├── .gitignore ├── LICENSE ├── README.md ├── configs ├── convert │ ├── vitdet_b.txt │ └── vivit_b.txt ├── detectron │ ├── vitdet_b_coco.py │ └── vitdet_b_vid.py ├── evaluate │ ├── vitdet_vid │ │ ├── _ablate_av.yml │ │ ├── _base.yml │ │ ├── _half.yml │ │ ├── _size_1024.yml │ │ ├── _size_672.yml │ │ ├── _spatial.yml │ │ ├── _stgt.yml │ │ ├── _temporal.yml │ │ ├── _tokenwise.yml │ │ ├── ablate_av_1024.yml │ │ ├── ablate_av_672.yml │ │ ├── base_1024.yml │ │ ├── base_672.yml │ │ ├── base_half_1024.yml │ │ ├── compare_ln_1024.yml │ │ ├── spatial_1024.yml │ │ ├── spatial_672.yml │ │ ├── spatial_half_1024.yml │ │ ├── spatial_half_672.yml │ │ ├── spatiotemporal_1024.yml │ │ ├── spatiotemporal_672.yml │ │ ├── spatiotemporal_full_1024.yml │ │ ├── spatiotemporal_full_672.yml │ │ ├── stgt_1024.yml │ │ ├── stgt_672.yml │ │ ├── temporal_1024.yml │ │ ├── temporal_672.yml │ │ ├── temporal_full_1024.yml │ │ ├── temporal_full_672.yml │ │ ├── threshold_1024.yml │ │ ├── tokenwise_1024.yml │ │ └── tokenwise_672.yml │ ├── vivit_epic_kitchens │ │ ├── _ats.yml │ │ ├── _base.yml │ │ ├── _temporal.yml │ │ ├── ats.yml │ │ ├── base.yml │ │ ├── temporal_100.yml │ │ ├── temporal_200.yml │ │ ├── temporal_50.yml │ │ ├── temporal_ats_200.yml │ │ └── temporal_naive_100.yml │ └── vivit_kinetics400 │ │ ├── _base.yml │ │ ├── _temporal.yml │ │ ├── base.yml │ │ ├── temporal_24.yml │ │ ├── temporal_48.yml │ │ └── temporal_96.yml ├── models │ ├── vitdet_b_coco.yml │ ├── vitdet_b_vid.yml │ ├── vivit_b_epic_kitchens.yml │ └── vivit_b_kinetics400.yml ├── spatial │ ├── vivit_epic_kitchens │ │ ├── 100.yml │ │ ├── 200.yml │ │ ├── 50.yml │ │ └── _base.yml │ └── vivit_kinetics400 │ │ ├── 24.yml │ │ ├── 48.yml │ │ ├── 96.yml │ │ └── _base.yml ├── time │ ├── vitdet_vid │ │ ├── _base.yml │ │ ├── _cpu.yml │ │ ├── _cuda.yml │ │ ├── _size_1024.yml │ │ ├── _size_672.yml │ │ ├── _spatial.yml │ │ ├── _temporal.yml │ │ ├── base_1024_cpu.yml │ │ ├── base_1024_cuda.yml │ │ ├── base_672_cpu.yml │ │ ├── base_672_cuda.yml │ │ ├── spatial_1024_cpu.yml │ │ ├── spatial_1024_cuda.yml │ │ ├── spatial_672_cpu.yml │ │ ├── spatial_672_cuda.yml │ │ ├── spatiotemporal_1024_cpu.yml │ │ ├── spatiotemporal_1024_cuda.yml │ │ ├── spatiotemporal_672_cpu.yml │ │ ├── spatiotemporal_672_cuda.yml │ │ ├── temporal_1024_cpu.yml │ │ ├── temporal_1024_cuda.yml │ │ ├── temporal_672_cpu.yml │ │ └── temporal_672_cuda.yml │ └── vivit_epic_kitchens │ │ ├── _base.yml │ │ ├── _cpu.yml │ │ ├── _cuda.yml │ │ ├── _temporal.yml │ │ ├── base_cpu.yml │ │ ├── base_cuda.yml │ │ ├── temporal_cpu.yml │ │ └── temporal_cuda.yml └── train │ ├── vivit_epic_kitchens │ ├── _base.yml │ ├── final_100.yml │ ├── final_200.yml │ └── final_50.yml │ └── vivit_kinetics400 │ ├── _base.yml │ ├── final_24.yml │ ├── final_48.yml │ └── final_96.yml ├── datasets ├── epic_kitchens.py ├── kinetics400.py ├── vid.py └── vivit_spatial.py ├── environment.yml ├── eventful_transformer ├── backbones.py ├── base.py ├── blocks.py ├── counting.py ├── modules.py ├── policies.py └── utils.py ├── models ├── vitdet.py └── vivit.py ├── scripts ├── convert │ ├── vitdet.py │ └── vivit.py ├── evaluate │ ├── vitdet_vid.py │ ├── vitdet_vid.sh │ ├── vivit_epic_kitchens.py │ └── vivit_kinetics400.py ├── misc │ └── measure_vitdet_padding.py ├── spatial │ ├── vivit_epic_kitchens.py │ └── vivit_kinetics400.py ├── time │ ├── vitdet_vid.py │ └── vivit_epic_kitchens.py └── train │ ├── vivit_epic_kitchens.py │ └── vivit_kinetics400.py └── utils ├── config.py ├── evaluate.py ├── image.py ├── misc.py ├── spatial.py └── train.py /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Sorted in ascending order 2 | .DS_Store 3 | .idea 4 | __pycache__ 5 | inputs 6 | outputs 7 | weights 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Matthew Dutson 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Overview 2 | 3 | This is the PyTorch code for our ICCV 2023 paper "Eventful Transformers: Leveraging Temporal Redundancy in Vision Transformers." Please see our [paper webpage](https://wisionlab.com/project/eventful-transformers/) and the [arXiv paper](https://arxiv.org/abs/2308.13494). 4 | 5 | ## Disclaimer 6 | 7 | This is research-grade code, so it's possible you will encounter some hiccups. [Contact me](https://github.com/mattdutson/) if you encounter problems or if the documentation is unclear, and I will do my best to help. 8 | 9 | ## TLDR 10 | 11 | Most of the interesting code (implementation of our core contributions) is in `eventful_transformer/blocks.py`, `eventful_transformer/modules.py`, and `eventful_transformer/policies.py`. 12 | 13 | ## Dependencies 14 | 15 | Dependencies are managed using Conda. The environment is defined in `environment.yml`. 16 | 17 | To create the environment, run: 18 | ``` 19 | conda env create -f environment.yml 20 | ``` 21 | Then activate the environment with: 22 | ``` 23 | conda activate eventful-transformer 24 | ``` 25 | 26 | ## Running Scripts 27 | 28 | Scripts should be run from the repo's base directory. 29 | 30 | Many scripts expect a `.yml` configuration file as a command-line argument. These configuration files are in `configs`. The structure of the `configs` folder is set to mirror the structure of the `scripts` folder. For example, to run the `base_672` evaluation for the ViTDet VID model: 31 | ``` 32 | ./scripts/evaluate/vitdet_vid.py ./configs/evaluate/vitdet_vid/base_672.yml 33 | ``` 34 | 35 | ## Weights 36 | 37 | Weights for the ViViT action recognition model (on Kinetics-400 and EPIC-Kitchens) are available [here](https://github.com/alibaba-mmai-research/TAdaConv/blob/main/MODEL_ZOO.md). We use the "ViViT Fact. Enc." weights. 
38 | 39 | Weights for the ViTDet object detection model (on COCO) are available [here](https://github.com/facebookresearch/detectron2/tree/main/projects/ViTDet). We use the "Cascade Mask R-CNN, ViTDet, ViT-B" weights. Weights on ImageNet VID are available [here](https://drive.google.com/drive/folders/1tNtIOYlCIlzb2d_fCsIbmjgIETd-xzW-) (`frcnn_vitdet_final.pth`). 40 | 41 | The weight names need to be remapped to work with this codebase. To remap the ViViT weights, run: 42 | ``` 43 | ./scripts/convert/vivit.py ./configs/convert/vivit_b.txt <original> <converted> 44 | ``` 45 | with `<original>` and `<converted>` replaced by the path of the downloaded weights and the path where the converted weights should be saved, respectively. 46 | 47 | To remap the ViTDet weights, run: 48 | ``` 49 | ./scripts/convert/vitdet.py ./configs/convert/vitdet_b.txt <original> <converted> 50 | ``` 51 | 52 | Some ViViT evaluation scripts assume a fine-tuned temporal sub-model. Fine-tuned weights can be downloaded [here](https://drive.google.com/drive/folders/1AP-NRhO4l_spEJ6ZXvfVO3PLlLOsXCmM?usp=sharing). 53 | 54 | Alternatively, you can run the fine-tuning yourself. To do this, run a `spatial` configuration (to cache the forward pass of the spatial sub-model), followed by a `train` configuration. For example: 55 | ``` 56 | ./scripts/spatial/vivit_epic_kitchens.py ./configs/spatial/vivit_epic_kitchens/50.yml 57 | ``` 58 | then 59 | ``` 60 | ./scripts/train/vivit_epic_kitchens.py ./configs/train/vivit_epic_kitchens/final_50.yml 61 | ``` 62 | This will produce `weights/vivit_b_epic_kitchens_final_50.pth`. 63 | 64 | ## Data 65 | 66 | The `datasets` folder defines PyTorch `Dataset` classes for Kinetics-400, VID, and EPIC-Kitchens. 67 | 68 | The Kinetics-400 class will automatically download and prepare the dataset on first use. 69 | 70 | VID requires a manual download. Download `vid_data.tar` from [here](https://drive.google.com/drive/folders/1tNtIOYlCIlzb2d_fCsIbmjgIETd-xzW-) and place it at `./data/vid/data.tar`. The VID class will take care of unpacking and preparing the data on first use. 71 | 72 | EPIC-Kitchens also requires a manual download. Download the videos from [here](https://drive.google.com/drive/folders/1OKJpgSKR1QnWa2tMMafknLF-CpEaxDbY) and place them in `./data/epic_kitchens/videos`. Download the labels `EPIC_100_train.csv` and `EPIC_100_validation.csv` from [here](https://github.com/epic-kitchens/epic-kitchens-100-annotations) and place them in `./data/epic_kitchens`. The EPICKitchens class will prepare the data on first use. 73 | 74 | ## Other Setup 75 | 76 | Scripts assume that the current working directory is on the Python path. In the Bash shell, run 77 | ``` 78 | export PYTHONPATH="$PYTHONPATH:." 79 | ``` 80 | Or in the Fish shell: 81 | ``` 82 | set -ax PYTHONPATH . 83 | ``` 84 | 85 | ## Code Style 86 | 87 | Format all code using [Black](https://black.readthedocs.io/en/stable/). Use a line limit of 88 characters (the default). To format a file, use the command: 88 | ``` 89 | black <file> 90 | ``` 91 | -------------------------------------------------------------------------------- /configs/convert/vitdet_b.txt: -------------------------------------------------------------------------------- 1 | ^backbone\.net\.patch_embed\.proj\. 2 | embedding.conv. 3 | 4 | ^backbone\.net\.pos_embed$ 5 | backbone.position_encoding.encoding 6 | 7 | ^backbone\.net\.blocks\.(\d*)\.norm1\. 8 | backbone.blocks.\1.input_layer_norm. 9 | 10 | ^backbone\.net\.blocks\.(\d*)\.attn\.qkv\. 11 | backbone.blocks.\1.qkv.
12 | 13 | ^backbone\.net\.blocks\.(\d*)\.attn\.rel_pos_h$ 14 | backbone.blocks.\1.relative_position.y_embedding 15 | 16 | ^backbone\.net\.blocks\.(\d*)\.attn\.rel_pos_w$ 17 | backbone.blocks.\1.relative_position.x_embedding 18 | 19 | ^backbone\.net\.blocks\.(\d*)\.attn\.proj\. 20 | backbone.blocks.\1.projection. 21 | 22 | ^backbone\.net\.blocks\.(\d*)\.norm2\. 23 | backbone.blocks.\1.mlp_layer_norm. 24 | 25 | ^backbone\.net\.blocks\.(\d*)\.mlp\.fc1\. 26 | backbone.blocks.\1.mlp_1. 27 | 28 | ^backbone\.net\.blocks\.(\d*)\.mlp\.fc2\. 29 | backbone.blocks.\1.mlp_2. 30 | 31 | ^backbone\.simfp_2\.([013])\. 32 | pyramid.stages.0.\1. 33 | 34 | ^backbone\.simfp_2\.4\.weight$ 35 | pyramid.stages.0.4.weight 36 | 37 | ^backbone\.simfp_2\.4\.norm\. 38 | pyramid.stages.0.5. 39 | 40 | ^backbone\.simfp_2\.5\.weight$ 41 | pyramid.stages.0.6.weight 42 | 43 | ^backbone\.simfp_2\.5\.norm\. 44 | pyramid.stages.0.7. 45 | 46 | ^backbone\.simfp_3\.0\. 47 | pyramid.stages.1.0. 48 | 49 | ^backbone\.simfp_3\.1\.weight$ 50 | pyramid.stages.1.1.weight 51 | 52 | ^backbone\.simfp_3\.1\.norm\. 53 | pyramid.stages.1.2. 54 | 55 | ^backbone\.simfp_3\.2\.weight$ 56 | pyramid.stages.1.3.weight 57 | 58 | ^backbone\.simfp_3\.2\.norm\. 59 | pyramid.stages.1.4. 60 | 61 | ^backbone\.simfp_4\.0\.weight$ 62 | pyramid.stages.2.0.weight 63 | 64 | ^backbone\.simfp_4\.0\.norm\. 65 | pyramid.stages.2.1. 66 | 67 | ^backbone\.simfp_4\.1\.weight$ 68 | pyramid.stages.2.2.weight 69 | 70 | ^backbone\.simfp_4\.1\.norm\. 71 | pyramid.stages.2.3. 72 | 73 | ^backbone\.simfp_5\.1\.weight$ 74 | pyramid.stages.3.1.weight 75 | 76 | ^backbone\.simfp_5\.1\.norm\. 77 | pyramid.stages.3.2. 78 | 79 | ^backbone\.simfp_5\.2\.weight$ 80 | pyramid.stages.3.3.weight 81 | 82 | ^backbone\.simfp_5\.2\.norm\. 83 | pyramid.stages.3.4. 84 | -------------------------------------------------------------------------------- /configs/convert/vivit_b.txt: -------------------------------------------------------------------------------- 1 | ^backbone\.stem\.conv1\.weight$ 2 | embedding.conv.weight 3 | 4 | ^backbone\.stem\.conv1\.bias$ 5 | embedding.conv.bias 6 | 7 | ^backbone\.cls_token$ 8 | spatial_model.class_token 9 | 10 | ^backbone\.pos_embd$ 11 | spatial_model.backbone.position_encoding.encoding 12 | 13 | ^backbone\.layers\.(\d*)\.norm\. 14 | spatial_model.backbone.blocks.\1.input_layer_norm. 15 | 16 | ^backbone\.layers\.(\d*)\.attn\.to_qkv\. 17 | spatial_model.backbone.blocks.\1.qkv. 18 | 19 | ^backbone\.layers\.(\d*)\.attn\.proj\. 20 | spatial_model.backbone.blocks.\1.projection. 21 | 22 | ^backbone\.layers\.(\d*)\.norm_ffn\. 23 | spatial_model.backbone.blocks.\1.mlp_layer_norm. 24 | 25 | ^backbone\.layers\.(\d*)\.ffn\.net\.0\. 26 | spatial_model.backbone.blocks.\1.mlp_1. 27 | 28 | ^backbone\.layers\.(\d*)\.ffn\.net\.3\. 29 | spatial_model.backbone.blocks.\1.mlp_2. 30 | 31 | ^backbone\.norm\. 32 | spatial_model.layer_norm. 33 | 34 | ^backbone\.cls_token_out$ 35 | temporal_model.class_token 36 | 37 | ^backbone\.temp_embd$ 38 | temporal_model.backbone.position_encoding.encoding 39 | 40 | ^backbone\.layers_temporal\.(\d*)\.norm\. 41 | temporal_model.backbone.blocks.\1.input_layer_norm. 42 | 43 | ^backbone\.layers_temporal\.(\d*)\.attn\.to_qkv\. 44 | temporal_model.backbone.blocks.\1.qkv. 45 | 46 | ^backbone\.layers_temporal\.(\d*)\.attn\.proj\. 47 | temporal_model.backbone.blocks.\1.projection. 48 | 49 | ^backbone\.layers_temporal\.(\d*)\.norm_ffn\. 50 | temporal_model.backbone.blocks.\1.mlp_layer_norm. 51 | 52 | ^backbone\.layers_temporal\.(\d*)\.ffn\.net\.0\. 
53 | temporal_model.backbone.blocks.\1.mlp_1. 54 | 55 | ^backbone\.layers_temporal\.(\d*)\.ffn\.net\.3\. 56 | temporal_model.backbone.blocks.\1.mlp_2. 57 | 58 | ^backbone\.norm_out\. 59 | temporal_model.layer_norm. 60 | 61 | ^head\.linear\. 62 | classifier. 63 | 64 | ^head\.linear1\. 65 | classifier. 66 | 67 | ^head\.linear2\. 68 | DISCARD 69 | -------------------------------------------------------------------------------- /configs/detectron/vitdet_b_coco.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from detectron2 import model_zoo 4 | from detectron2.config import LazyCall as L 5 | from detectron2.layers import ShapeSpec 6 | from detectron2.modeling.box_regression import Box2BoxTransform 7 | from detectron2.modeling.matcher import Matcher 8 | from detectron2.modeling.roi_heads import ( 9 | FastRCNNOutputLayers, 10 | FastRCNNConvFCHead, 11 | CascadeROIHeads, 12 | ) 13 | 14 | # This file is derived from: 15 | # https://github.com/facebookresearch/detectron2/blob/main/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_b_100ep.py 16 | 17 | model = model_zoo.get_config(str(Path("common", "models", "mask_rcnn_vitdet.py"))).model 18 | 19 | # arguments that don't exist for Cascade R-CNN 20 | [model.roi_heads.pop(k) for k in ["box_head", "box_predictor", "proposal_matcher"]] 21 | 22 | model.roi_heads.update( 23 | _target_=CascadeROIHeads, 24 | box_heads=[ 25 | L(FastRCNNConvFCHead)( 26 | input_shape=ShapeSpec(channels=256, height=7, width=7), 27 | conv_dims=[256, 256, 256, 256], 28 | fc_dims=[1024], 29 | conv_norm="LN", 30 | ) 31 | for _ in range(3) 32 | ], 33 | box_predictors=[ 34 | L(FastRCNNOutputLayers)( 35 | input_shape=ShapeSpec(channels=1024), 36 | test_score_thresh=0.05, 37 | box2box_transform=L(Box2BoxTransform)(weights=(w1, w1, w2, w2)), 38 | cls_agnostic_bbox_reg=True, 39 | num_classes="${...num_classes}", 40 | ) 41 | for (w1, w2) in [(10, 5), (20, 10), (30, 15)] 42 | ], 43 | proposal_matchers=[ 44 | L(Matcher)(thresholds=[th], labels=[0, 1], allow_low_quality_matches=False) 45 | for th in [0.5, 0.6, 0.7] 46 | ], 47 | ) 48 | 49 | # We only need configuration for the proposal generator and ROI heads. 
50 | model = dict(proposal_generator=model.proposal_generator, roi_heads=model.roi_heads) 51 | -------------------------------------------------------------------------------- /configs/detectron/vitdet_b_vid.py: -------------------------------------------------------------------------------- 1 | from detectron2.config import LazyCall as L 2 | from detectron2.layers import ShapeSpec 3 | from detectron2.modeling.anchor_generator import DefaultAnchorGenerator 4 | from detectron2.modeling.box_regression import Box2BoxTransform 5 | from detectron2.modeling.matcher import Matcher 6 | from detectron2.modeling.poolers import ROIPooler 7 | from detectron2.modeling.proposal_generator import RPN, StandardRPNHead 8 | from detectron2.modeling.roi_heads import FastRCNNOutputLayers, FastRCNNConvFCHead 9 | from detectron2.modeling.roi_heads import StandardROIHeads 10 | 11 | # This file is derived from: 12 | # https://github.com/happyharrycn/detectron2_vitdet_vid/blob/main/projects/ViTDet-VID/configs/frcnn_vitdet.py 13 | 14 | model = dict( 15 | proposal_generator=L(RPN)( 16 | in_features=["p2", "p3", "p4", "p5", "p6"], 17 | head=L(StandardRPNHead)(in_channels=256, num_anchors=3, conv_dims=[-1, -1]), 18 | anchor_generator=L(DefaultAnchorGenerator)( 19 | sizes=[[32], [64], [128], [256], [512]], 20 | aspect_ratios=[0.5, 1.0, 2.0], 21 | strides=[4, 8, 16, 32, 64], 22 | offset=0.0, 23 | ), 24 | anchor_matcher=L(Matcher)( 25 | thresholds=[0.3, 0.7], labels=[0, -1, 1], allow_low_quality_matches=True 26 | ), 27 | box2box_transform=L(Box2BoxTransform)(weights=[1.0, 1.0, 1.0, 1.0]), 28 | batch_size_per_image=256, 29 | positive_fraction=0.5, 30 | pre_nms_topk=(2000, 1000), 31 | post_nms_topk=(1000, 300), 32 | nms_thresh=0.7, 33 | ), 34 | roi_heads=L(StandardROIHeads)( 35 | num_classes=30, 36 | batch_size_per_image=128, 37 | positive_fraction=0.25, 38 | proposal_matcher=L(Matcher)( 39 | thresholds=[0.5], labels=[0, 1], allow_low_quality_matches=False 40 | ), 41 | box_in_features=["p2", "p3", "p4", "p5"], 42 | box_pooler=L(ROIPooler)( 43 | output_size=7, 44 | scales=(1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0 / 32), 45 | sampling_ratio=0, 46 | pooler_type="ROIAlignV2", 47 | ), 48 | box_head=L(FastRCNNConvFCHead)( 49 | input_shape=ShapeSpec(channels=256, height=7, width=7), 50 | conv_dims=[256, 256, 256, 256], 51 | conv_norm="LN", 52 | fc_dims=[1024], 53 | ), 54 | box_predictor=L(FastRCNNOutputLayers)( 55 | input_shape=ShapeSpec(channels=1024), 56 | test_score_thresh=0.05, 57 | box2box_transform=L(Box2BoxTransform)(weights=(10, 10, 5, 5)), 58 | num_classes="${..num_classes}", 59 | box_reg_loss_type="giou", 60 | loss_weight={"loss_box_reg": 2.0}, 61 | ), 62 | ), 63 | ) 64 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/_ablate_av.yml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone_config: 3 | block_class: "EventfulMatmul1Block" 4 | windowed_class: "EventfulTokenwiseBlock" 5 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/_base.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "configs/models/vitdet_b_vid.yml" 3 | _output: "results/evaluate/vitdet_vid/${_name}/" 4 | split: "vid_val" 5 | vanilla: false 6 | weights: "weights/vitdet_b_vid.pth" 7 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/_half.yml: 
-------------------------------------------------------------------------------- 1 | model: 2 | backbone_config: 3 | block_config: 4 | matmul_2_cast: "float16" 5 | windowed_overrides: 6 | matmul_2_cast: null 7 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/_size_1024.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | model: 4 | input_shape: [3, 1024, 1024] 5 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/_size_672.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | model: 4 | input_shape: [3, 672, 672] 5 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/_spatial.yml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone_config: 3 | block_config: 4 | pool_size: 2 5 | windowed_overrides: 6 | pool_size: null 7 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/_stgt.yml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone_config: 3 | block_class: "EventfulTokenwiseBlock" 4 | block_config: 5 | stgt: true 6 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/_temporal.yml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone_config: 3 | block_class: "EventfulBlock" 4 | windowed_class: "EventfulTokenwiseBlock" 5 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/_tokenwise.yml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone_config: 3 | block_class: "EventfulTokenwiseBlock" 4 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/ablate_av_1024.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_1024.yml" 3 | - "_ablate_av.yml" 4 | token_top_k: [256, 512, 768, 1024, 1536, 2048] 5 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/ablate_av_672.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_672.yml" 3 | - "_ablate_av.yml" 4 | token_top_k: [128, 256, 384, 512, 768, 1024] 5 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/base_1024.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_1024.yml" 3 | vanilla: true 4 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/base_672.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_672.yml" 3 | vanilla: true 4 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/base_half_1024.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_1024.yml" 3 | - "_half.yml" 4 | vanilla: true 5 | 
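The evaluation configs above compose behavior through their `_defaults` lists; `base_half_1024.yml`, for example, layers `_half.yml` (half-precision `matmul_2_cast`) on top of `_size_1024.yml`, which itself pulls in `_base.yml` and the ViTDet VID model config. The actual resolution logic lives in `utils/config.py` (not shown in this listing); the snippet below is only a minimal sketch of how such hierarchical defaults could be merged, assuming a recursive dict merge in which later defaults and the file's own keys take precedence (the `${_name}` interpolation seen in `_output` is ignored here):

```python
import yaml
from pathlib import Path


def merge(base, override):
    # Recursively merge "override" into "base". Nested dicts are merged
    # key by key; scalars and lists in "override" simply win.
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge(merged[key], value)
        else:
            merged[key] = value
    return merged


def load_config(path):
    # Load a .yml config, first resolving every entry in its "_defaults"
    # list (repo-relative paths are tried first, then sibling paths).
    path = Path(path)
    config = yaml.safe_load(path.read_text())
    resolved = {}
    for default in config.pop("_defaults", []):
        default_path = Path(default)
        if not default_path.is_file():
            default_path = path.parent / default
        resolved = merge(resolved, load_config(default_path))
    return merge(resolved, config)
```

Under this sketch, loading `configs/evaluate/vitdet_vid/base_half_1024.yml` would yield a single dict with `model.input_shape` of `[3, 1024, 1024]`, `matmul_2_cast` of `"float16"` (nulled again under `windowed_overrides`), and `vanilla: true`.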
-------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/compare_ln_1024.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_1024.yml" 3 | model: 4 | backbone_config: 5 | block_class: "EventfulTokenwiseBlock" 6 | block_config: 7 | gate_before_ln: true 8 | token_top_k: [512, 1024, 2048] 9 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/spatial_1024.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_1024.yml" 3 | - "_spatial.yml" 4 | vanilla: true 5 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/spatial_672.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_672.yml" 3 | - "_spatial.yml" 4 | vanilla: true 5 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/spatial_half_1024.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_1024.yml" 3 | - "_half.yml" 4 | - "_spatial.yml" 5 | vanilla: true 6 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/spatial_half_672.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_672.yml" 3 | - "_half.yml" 4 | - "_spatial.yml" 5 | vanilla: true 6 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/spatiotemporal_1024.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_1024.yml" 3 | - "_half.yml" 4 | - "_spatial.yml" 5 | - "_temporal.yml" 6 | token_top_k: [256, 512, 768, 1024, 1536, 2048] 7 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/spatiotemporal_672.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_672.yml" 3 | - "_half.yml" 4 | - "_spatial.yml" 5 | - "_temporal.yml" 6 | token_top_k: [128, 256, 384, 512, 768, 1024] 7 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/spatiotemporal_full_1024.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_1024.yml" 3 | - "_spatial.yml" 4 | - "_temporal.yml" 5 | token_top_k: [512] 6 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/spatiotemporal_full_672.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_672.yml" 3 | - "_spatial.yml" 4 | - "_temporal.yml" 5 | token_top_k: [256] 6 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/stgt_1024.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_1024.yml" 3 | - "_stgt.yml" 4 | token_top_k: [256, 512, 768, 1024, 1536, 2048] 5 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/stgt_672.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - 
"_size_672.yml" 3 | - "_stgt.yml" 4 | token_top_k: [128, 256, 384, 512, 768, 1024] 5 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/temporal_1024.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_1024.yml" 3 | - "_half.yml" 4 | - "_temporal.yml" 5 | token_top_k: [256, 512, 768, 1024, 1536, 2048] 6 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/temporal_672.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_672.yml" 3 | - "_half.yml" 4 | - "_temporal.yml" 5 | token_top_k: [128, 256, 384, 512, 768, 1024] 6 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/temporal_full_1024.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_1024.yml" 3 | - "_temporal.yml" 4 | token_top_k: [512] 5 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/temporal_full_672.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_672.yml" 3 | - "_temporal.yml" 4 | token_top_k: [256] 5 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/threshold_1024.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_1024.yml" 3 | - "_half.yml" 4 | - "_temporal.yml" 5 | token_thresholds: [0.2, 1.0, 5.0] 6 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/tokenwise_1024.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_1024.yml" 3 | - "_tokenwise.yml" 4 | token_top_k: [256, 512, 768, 1024, 1536, 2048] 5 | -------------------------------------------------------------------------------- /configs/evaluate/vitdet_vid/tokenwise_672.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_672.yml" 3 | - "_tokenwise.yml" 4 | token_top_k: [128, 256, 384, 512, 768, 1024] 5 | -------------------------------------------------------------------------------- /configs/evaluate/vivit_epic_kitchens/_ats.yml: -------------------------------------------------------------------------------- 1 | model: 2 | spatial_config: 3 | block_config: 4 | ats_fraction: 0.9 5 | -------------------------------------------------------------------------------- /configs/evaluate/vivit_epic_kitchens/_base.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "configs/models/vivit_b_epic_kitchens.yml" 3 | _output: "results/evaluate/vivit_epic_kitchens/${_name}/" 4 | split: "validation" 5 | vanilla: false 6 | -------------------------------------------------------------------------------- /configs/evaluate/vivit_epic_kitchens/_temporal.yml: -------------------------------------------------------------------------------- 1 | model: 2 | spatial_config: 3 | block_class: "EventfulBlock" 4 | block_config: 5 | matmul_2_cast: "float16" 6 | -------------------------------------------------------------------------------- /configs/evaluate/vivit_epic_kitchens/ats.yml: 
-------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | - "_ats.yml" 4 | vanilla: true 5 | weights: "weights/vivit_b_epic_kitchens.pth" 6 | -------------------------------------------------------------------------------- /configs/evaluate/vivit_epic_kitchens/base.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | vanilla: true 4 | weights: "weights/vivit_b_epic_kitchens.pth" 5 | -------------------------------------------------------------------------------- /configs/evaluate/vivit_epic_kitchens/temporal_100.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | - "_temporal.yml" 4 | token_top_k: [60, 80, 100, 120, 140] 5 | weights: "weights/vivit_b_epic_kitchens_final_100.pth" 6 | -------------------------------------------------------------------------------- /configs/evaluate/vivit_epic_kitchens/temporal_200.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | - "_temporal.yml" 4 | token_top_k: [120, 160, 200, 240, 280] 5 | weights: "weights/vivit_b_epic_kitchens_final_200.pth" 6 | -------------------------------------------------------------------------------- /configs/evaluate/vivit_epic_kitchens/temporal_50.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | - "_temporal.yml" 4 | token_top_k: [30, 40, 50, 60, 70] 5 | weights: "weights/vivit_b_epic_kitchens_final_50.pth" 6 | -------------------------------------------------------------------------------- /configs/evaluate/vivit_epic_kitchens/temporal_ats_200.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | - "_ats.yml" 4 | - "_temporal.yml" 5 | token_top_fraction: [0.5] 6 | weights: "weights/vivit_b_epic_kitchens_final_200.pth" 7 | -------------------------------------------------------------------------------- /configs/evaluate/vivit_epic_kitchens/temporal_naive_100.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | - "_temporal.yml" 4 | token_top_k: [100] 5 | weights: "weights/vivit_b_epic_kitchens.pth" 6 | -------------------------------------------------------------------------------- /configs/evaluate/vivit_kinetics400/_base.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "configs/models/vivit_b_kinetics400.yml" 3 | _output: "results/evaluate/vivit_kinetics400/${_name}/" 4 | vanilla: false 5 | -------------------------------------------------------------------------------- /configs/evaluate/vivit_kinetics400/_temporal.yml: -------------------------------------------------------------------------------- 1 | model: 2 | spatial_config: 3 | block_class: "EventfulBlock" 4 | block_config: 5 | matmul_2_cast: "float16" 6 | -------------------------------------------------------------------------------- /configs/evaluate/vivit_kinetics400/base.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | vanilla: true 4 | weights: "weights/vivit_b_kinetics400.pth" 5 | -------------------------------------------------------------------------------- /configs/evaluate/vivit_kinetics400/temporal_24.yml: 
-------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | - "_temporal.yml" 4 | token_top_k: [24] 5 | weights: "weights/vivit_b_kinetics400_final_24.pth" 6 | -------------------------------------------------------------------------------- /configs/evaluate/vivit_kinetics400/temporal_48.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | - "_temporal.yml" 4 | token_top_k: [48] 5 | weights: "weights/vivit_b_kinetics400_final_48.pth" 6 | -------------------------------------------------------------------------------- /configs/evaluate/vivit_kinetics400/temporal_96.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | - "_temporal.yml" 4 | token_top_k: [96] 5 | weights: "weights/vivit_b_kinetics400_final_96.pth" 6 | -------------------------------------------------------------------------------- /configs/models/vitdet_b_coco.yml: -------------------------------------------------------------------------------- 1 | model: 2 | classes: 80 3 | detectron2_config: "configs/detectron/vitdet_b_coco.py" 4 | input_shape: [3, 1024, 1024] 5 | normalize_mean: [123.675, 116.28, 103.53] 6 | normalize_std: [58.395, 57.12, 57.375] 7 | output_channels: 256 8 | patch_size: [16, 16] 9 | scale_factors: [4.0, 2.0, 1.0, 0.5] 10 | backbone_config: 11 | depth: 12 12 | position_encoding_size: [14, 14] 13 | window_indices: [0, 1, 3, 4, 6, 7, 9, 10] 14 | block_config: 15 | dim: 768 16 | relative_embedding_size: [64, 64] 17 | heads: 12 18 | mlp_ratio: 4 19 | window_size: [14, 14] 20 | -------------------------------------------------------------------------------- /configs/models/vitdet_b_vid.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "vitdet_b_coco.yml" 3 | model: 4 | classes: 30 5 | detectron2_config: "configs/detectron/vitdet_b_vid.py" 6 | -------------------------------------------------------------------------------- /configs/models/vivit_b_epic_kitchens.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "vivit_b_kinetics400.yml" 3 | model: 4 | classes: 97 # Verb classes only 5 | input_shape: [32, 3, 320, 320] 6 | temporal_stride: 1 # The original model uses stride 2 at 60 fps, but our data is 30fps. 
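# (at 30 fps, stride 1 over 32 frames covers the same ~1.07 s window that stride 2 covers at 60 fps)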
7 | spatial_config: 8 | position_encoding_size: [20, 20] 9 | -------------------------------------------------------------------------------- /configs/models/vivit_b_kinetics400.yml: -------------------------------------------------------------------------------- 1 | model: 2 | classes: 400 3 | input_shape: [32, 3, 224, 224] 4 | normalize_mean: 0.45 5 | normalize_std: 0.225 6 | spatial_views: 3 7 | temporal_stride: 2 8 | temporal_views: 4 9 | tubelet_shape: [2, 16, 16] 10 | spatial_config: 11 | depth: 12 12 | position_encoding_size: [14, 14] 13 | block_config: 14 | dim: 768 15 | heads: 12 16 | mlp_ratio: 4 17 | temporal_config: 18 | depth: 4 19 | position_encoding_size: [16] 20 | block_config: 21 | dim: 768 22 | heads: 12 23 | mlp_ratio: 4 24 | -------------------------------------------------------------------------------- /configs/spatial/vivit_epic_kitchens/100.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | k: 100 4 | -------------------------------------------------------------------------------- /configs/spatial/vivit_epic_kitchens/200.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | k: 200 4 | -------------------------------------------------------------------------------- /configs/spatial/vivit_epic_kitchens/50.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | k: 50 4 | -------------------------------------------------------------------------------- /configs/spatial/vivit_epic_kitchens/_base.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "configs/models/vivit_b_epic_kitchens.yml" 3 | model: 4 | spatial_config: 5 | block_class: "EventfulBlock" 6 | block_config: 7 | matmul_2_cast: "float16" 8 | spatial_only: true 9 | weights: "weights/vivit_b_epic_kitchens.pth" 10 | -------------------------------------------------------------------------------- /configs/spatial/vivit_kinetics400/24.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | k: 24 4 | -------------------------------------------------------------------------------- /configs/spatial/vivit_kinetics400/48.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | k: 48 4 | -------------------------------------------------------------------------------- /configs/spatial/vivit_kinetics400/96.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | k: 96 4 | -------------------------------------------------------------------------------- /configs/spatial/vivit_kinetics400/_base.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "configs/models/vivit_b_kinetics400.yml" 3 | 4 | # The full training dataset is huge (>200k videos) - using max_tars=40 limits 5 | # the size to about 40k videos. 
6 | max_tars: 40 7 | model: 8 | spatial_config: 9 | block_class: "EventfulBlock" 10 | block_config: 11 | matmul_2_cast: "float16" 12 | spatial_only: true 13 | weights: "weights/vivit_b_kinetics400.pth" 14 | -------------------------------------------------------------------------------- /configs/time/vitdet_vid/_base.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "configs/models/vitdet_b_vid.yml" 3 | _output: "results/time/vitdet_vid/${_name}/" 4 | split: "vid_val" 5 | vanilla: false 6 | weights: "weights/vitdet_b_vid.pth" 7 | -------------------------------------------------------------------------------- /configs/time/vitdet_vid/_cpu.yml: -------------------------------------------------------------------------------- 1 | device: "cpu" 2 | model: 3 | backbone_config: 4 | block_config: 5 | matmul_2_cast: "bfloat16" 6 | windowed_overrides: 7 | matmul_2_cast: null 8 | n_items: 1 # 242 frames 9 | threads: 8 10 | -------------------------------------------------------------------------------- /configs/time/vitdet_vid/_cuda.yml: -------------------------------------------------------------------------------- 1 | device: "cuda" 2 | model: 3 | backbone_config: 4 | block_config: 5 | matmul_2_cast: "float16" 6 | windowed_overrides: 7 | matmul_2_cast: null 8 | n_items: 5 9 | -------------------------------------------------------------------------------- /configs/time/vitdet_vid/_size_1024.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | input_size: 1024 4 | model: 5 | input_shape: [3, 1024, 1024] 6 | -------------------------------------------------------------------------------- /configs/time/vitdet_vid/_size_672.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | input_size: 672 4 | model: 5 | input_shape: [3, 672, 672] 6 | -------------------------------------------------------------------------------- /configs/time/vitdet_vid/_spatial.yml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone_config: 3 | block_config: 4 | pool_size: 2 5 | windowed_overrides: 6 | pool_size: null 7 | -------------------------------------------------------------------------------- /configs/time/vitdet_vid/_temporal.yml: -------------------------------------------------------------------------------- 1 | model: 2 | backbone_config: 3 | block_class: "EventfulBlock" 4 | windowed_class: "EventfulTokenwiseBlock" 5 | -------------------------------------------------------------------------------- /configs/time/vitdet_vid/base_1024_cpu.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_1024.yml" 3 | - "_cpu.yml" 4 | vanilla: true 5 | -------------------------------------------------------------------------------- /configs/time/vitdet_vid/base_1024_cuda.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_1024.yml" 3 | - "_cuda.yml" 4 | vanilla: true 5 | -------------------------------------------------------------------------------- /configs/time/vitdet_vid/base_672_cpu.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_672.yml" 3 | - "_cpu.yml" 4 | vanilla: true 5 | -------------------------------------------------------------------------------- 
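The `_cpu.yml` and `_cuda.yml` timing configs above select a reduced-precision dtype for the second attention matmul through `matmul_2_cast` (`bfloat16` on CPU, `float16` on CUDA, with the cast disabled for windowed blocks via `windowed_overrides`). The real behavior is implemented in `eventful_transformer/blocks.py` (not shown in this listing); the snippet below is only a rough sketch, under the assumption that the option casts the attention-weights-times-value product to the requested dtype and back:

```python
import torch


def attention_matmul_2(attention, value, matmul_2_cast=None):
    # Optionally compute the A @ V product in a reduced-precision dtype
    # ("float16" or "bfloat16"), then cast the result back to the
    # original dtype. With matmul_2_cast=None this is a plain matmul.
    if matmul_2_cast is None:
        return attention @ value
    dtype = getattr(torch, matmul_2_cast)
    product = attention.to(dtype) @ value.to(dtype)
    return product.to(attention.dtype)
```

Presumably the cast trades a little precision in the A·V product for faster matrix multiplication, with the result immediately returned to the original dtype.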
/configs/time/vitdet_vid/base_672_cuda.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_672.yml" 3 | - "_cuda.yml" 4 | vanilla: true 5 | -------------------------------------------------------------------------------- /configs/time/vitdet_vid/spatial_1024_cpu.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_1024.yml" 3 | - "_cpu.yml" 4 | - "_spatial.yml" 5 | vanilla: true 6 | -------------------------------------------------------------------------------- /configs/time/vitdet_vid/spatial_1024_cuda.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_1024.yml" 3 | - "_cuda.yml" 4 | - "_spatial.yml" 5 | vanilla: true 6 | -------------------------------------------------------------------------------- /configs/time/vitdet_vid/spatial_672_cpu.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_672.yml" 3 | - "_cpu.yml" 4 | - "_spatial.yml" 5 | vanilla: true 6 | -------------------------------------------------------------------------------- /configs/time/vitdet_vid/spatial_672_cuda.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_672.yml" 3 | - "_cuda.yml" 4 | - "_spatial.yml" 5 | vanilla: true 6 | -------------------------------------------------------------------------------- /configs/time/vitdet_vid/spatiotemporal_1024_cpu.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_1024.yml" 3 | - "_cpu.yml" 4 | - "_spatial.yml" 5 | - "_temporal.yml" 6 | token_top_k: [512] 7 | -------------------------------------------------------------------------------- /configs/time/vitdet_vid/spatiotemporal_1024_cuda.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_1024.yml" 3 | - "_cuda.yml" 4 | - "_spatial.yml" 5 | - "_temporal.yml" 6 | token_top_k: [512] 7 | -------------------------------------------------------------------------------- /configs/time/vitdet_vid/spatiotemporal_672_cpu.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_672.yml" 3 | - "_cpu.yml" 4 | - "_spatial.yml" 5 | - "_temporal.yml" 6 | token_top_k: [256] 7 | -------------------------------------------------------------------------------- /configs/time/vitdet_vid/spatiotemporal_672_cuda.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_672.yml" 3 | - "_cuda.yml" 4 | - "_spatial.yml" 5 | - "_temporal.yml" 6 | token_top_k: [256] 7 | -------------------------------------------------------------------------------- /configs/time/vitdet_vid/temporal_1024_cpu.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_1024.yml" 3 | - "_cpu.yml" 4 | - "_temporal.yml" 5 | token_top_k: [512] 6 | -------------------------------------------------------------------------------- /configs/time/vitdet_vid/temporal_1024_cuda.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_1024.yml" 3 | - "_cuda.yml" 4 | - "_temporal.yml" 5 | token_top_k: [512] 6 | -------------------------------------------------------------------------------- /configs/time/vitdet_vid/temporal_672_cpu.yml: 
-------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_672.yml" 3 | - "_cpu.yml" 4 | - "_temporal.yml" 5 | token_top_k: [256] 6 | -------------------------------------------------------------------------------- /configs/time/vitdet_vid/temporal_672_cuda.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_size_672.yml" 3 | - "_cuda.yml" 4 | - "_temporal.yml" 5 | token_top_k: [256] 6 | -------------------------------------------------------------------------------- /configs/time/vivit_epic_kitchens/_base.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "configs/models/vivit_b_epic_kitchens.yml" 3 | _output: "results/time/vivit_epic_kitchens/${_name}/" 4 | vanilla: false 5 | weights: "weights/vivit_b_epic_kitchens.pth" 6 | -------------------------------------------------------------------------------- /configs/time/vivit_epic_kitchens/_cpu.yml: -------------------------------------------------------------------------------- 1 | device: "cpu" 2 | model: 3 | spatial_config: 4 | block_config: 5 | matmul_2_cast: "bfloat16" 6 | n_items: 5 7 | threads: 8 8 | -------------------------------------------------------------------------------- /configs/time/vivit_epic_kitchens/_cuda.yml: -------------------------------------------------------------------------------- 1 | device: "cuda" 2 | model: 3 | spatial_config: 4 | block_config: 5 | matmul_2_cast: "float16" 6 | n_items: 100 7 | -------------------------------------------------------------------------------- /configs/time/vivit_epic_kitchens/_temporal.yml: -------------------------------------------------------------------------------- 1 | model: 2 | spatial_config: 3 | block_class: "EventfulBlock" 4 | -------------------------------------------------------------------------------- /configs/time/vivit_epic_kitchens/base_cpu.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | - "_cpu.yml" 4 | vanilla: true -------------------------------------------------------------------------------- /configs/time/vivit_epic_kitchens/base_cuda.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | - "_cuda.yml" 4 | vanilla: true 5 | -------------------------------------------------------------------------------- /configs/time/vivit_epic_kitchens/temporal_cpu.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | - "_cpu.yml" 4 | - "_temporal.yml" 5 | token_top_k: [50] 6 | -------------------------------------------------------------------------------- /configs/time/vivit_epic_kitchens/temporal_cuda.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | - "_cuda.yml" 4 | - "_temporal.yml" 5 | token_top_k: [50] 6 | -------------------------------------------------------------------------------- /configs/train/vivit_epic_kitchens/_base.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "configs/models/vivit_b_epic_kitchens.yml" 3 | _output: "results/train/vivit_epic_kitchens/${_name}/" 4 | epochs: 5 5 | model: 6 | dropout_rate: 0.5 7 | temporal_only: true 8 | optimizer: "AdamW" 9 | optimizer_kwargs: 10 | lr: 1.0e-5 11 | weight_decay: 0.05 12 | output_weights: 
"weights/vivit_b_epic_kitchens_${_name}.pth" 13 | starting_weights: "weights/vivit_b_epic_kitchens.pth" 14 | tensorboard: "tensorboard/${_name}" 15 | train_batch_size: 8 16 | val_batch_size: 8 17 | -------------------------------------------------------------------------------- /configs/train/vivit_epic_kitchens/final_100.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | k: 100 4 | -------------------------------------------------------------------------------- /configs/train/vivit_epic_kitchens/final_200.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | k: 200 4 | -------------------------------------------------------------------------------- /configs/train/vivit_epic_kitchens/final_50.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | k: 50 4 | -------------------------------------------------------------------------------- /configs/train/vivit_kinetics400/_base.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "configs/models/vivit_b_kinetics400.yml" 3 | _output: "results/train/vivit_kinetics400/${_name}/" 4 | epochs: 10 5 | model: 6 | dropout_rate: 0.5 7 | temporal_only: true 8 | optimizer: "AdamW" 9 | optimizer_kwargs: 10 | lr: 2.0e-6 11 | weight_decay: 0.05 12 | output_weights: "weights/vivit_b_kinetics400_${_name}.pth" 13 | starting_weights: "weights/vivit_b_kinetics400.pth" 14 | tensorboard: "tensorboard/${_name}" 15 | train_batch_size: 16 16 | val_batch_size: 16 17 | -------------------------------------------------------------------------------- /configs/train/vivit_kinetics400/final_24.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | k: 24 4 | -------------------------------------------------------------------------------- /configs/train/vivit_kinetics400/final_48.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | k: 48 4 | -------------------------------------------------------------------------------- /configs/train/vivit_kinetics400/final_96.yml: -------------------------------------------------------------------------------- 1 | _defaults: 2 | - "_base.yml" 3 | k: 96 4 | -------------------------------------------------------------------------------- /datasets/epic_kitchens.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import shutil 3 | from pathlib import Path 4 | from sys import stderr 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | from torchvision.io import read_image 9 | from tqdm import tqdm 10 | 11 | from utils.misc import decode_video, seeded_shuffle 12 | 13 | SPLITS = ["train", "validation"] 14 | 15 | 16 | class EPICKitchens(Dataset): 17 | """ 18 | A loader for the EPIC-Kitchens 100 dataset. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | location, 24 | split="validation", 25 | shuffle=True, 26 | shuffle_seed=42, 27 | video_transform=None, 28 | ): 29 | """ 30 | Initializes the loader. On the first call, this constructor will 31 | do some one-time setup. 32 | 33 | :param location: Directory containing the dataset (e.g., 34 | data/epic_kitchens). See the project README. 
35 | :param split: Either "train" or "validation" 36 | :param shuffle: Whether to shuffle videos 37 | :param shuffle_seed: The seed to use if shuffling 38 | :param video_transform: A callable to be applied to each video 39 | as it is loaded 40 | """ 41 | assert split in SPLITS 42 | self.video_transform = video_transform 43 | 44 | # Make sure the dataset has been set up. 45 | Path(location, split).mkdir(parents=True, exist_ok=True) 46 | if not self.is_decoded(location, split): 47 | self.clean_decoded(location, split) 48 | self.decode(location, split) 49 | 50 | # Load information about each clip in the dataset. 51 | self.frames_path = Path(location, split, "frames") 52 | self.clips_info = self._get_clips_info(location, split) 53 | 54 | # Optionally shuffle the videos. 55 | if shuffle: 56 | seeded_shuffle(self.clips_info, shuffle_seed) 57 | 58 | def __getitem__(self, index): 59 | """ 60 | Loads and returns an item from the dataset. 61 | 62 | :param index: The index of the item to load 63 | :return: A (video, class_id) tuple, where "video" is a tensor and 64 | "class_id" is the integer verb class of the clip 65 | """ 66 | clip_info = self.clips_info[index] 67 | clip_path = self.frames_path / f"{clip_info['clip_id']:05d}" 68 | frame_paths = sorted(clip_path.glob("*.jpg")) 69 | video = torch.stack([read_image(str(frame_path)) for frame_path in frame_paths]) 70 | if self.video_transform is not None: 71 | video = self.video_transform(video) 72 | return video, clip_info["class_id"] 73 | 74 | def __len__(self): 75 | """ 76 | Returns the number of items in the dataset. 77 | """ 78 | return len(self.clips_info) 79 | 80 | @staticmethod 81 | def clean_decoded(location, split): 82 | """ 83 | Deletes one-time setup data. 84 | 85 | :param location: The location of the dataset (see __init__) 86 | :param split: The split ("train" or "validation") 87 | """ 88 | base_path = Path(location, split) 89 | (base_path / "decoded").unlink(missing_ok=True) 90 | folder_path = base_path / "frames" 91 | if folder_path.is_dir(): 92 | shutil.rmtree(folder_path) 93 | 94 | @staticmethod 95 | def decode(location, split): 96 | """ 97 | Performs one-time setup. Decode/extract videos based on 98 | information in the CSV files. 99 | 100 | :param location: The location of the dataset (see __init__) 101 | :param split: The split ("train" or "validation") 102 | """ 103 | base_path = Path(location, split) 104 | frames_path = base_path / "frames" 105 | frames_path.mkdir(exist_ok=True) 106 | 107 | # Decode videos into images. 108 | print("Decoding clips...", file=stderr, flush=True) 109 | clips_info = EPICKitchens._get_clips_info(location, split) 110 | for clip_info in tqdm(clips_info, total=len(clips_info), ncols=0): 111 | video_path = Path(location, "videos", f"{clip_info['video_id']}.mp4") 112 | decode_path = frames_path / f"{clip_info['clip_id']:05d}" 113 | ffmpeg_input_args = [ 114 | "-ss", 115 | clip_info["start_time"], 116 | "-to", 117 | clip_info["end_time"], 118 | ] 119 | ffmpeg_output_args = ["-qscale:v", "2"] 120 | return_code = decode_video( 121 | video_path, 122 | decode_path, 123 | name_format="%4d", 124 | image_format="jpg", 125 | ffmpeg_input_args=ffmpeg_input_args, 126 | ffmpeg_output_args=ffmpeg_output_args, 127 | ) 128 | if return_code != 0: 129 | print( 130 | f"Decoding failed for clip {clip_info['clip_id']}", 131 | file=stderr, 132 | flush=True, 133 | ) 134 | shutil.rmtree(decode_path) 135 | 136 | # Create an empty indicator file.
137 | print("Decoding complete.", file=stderr, flush=True) 138 | (base_path / f"decoded").touch() 139 | 140 | @staticmethod 141 | def is_decoded(location, split): 142 | """ 143 | Returns true if one-time setup has been completed. 144 | 145 | :param location: The location of the dataset (see __init__) 146 | :param split: The split ("train" or "validation") 147 | """ 148 | return Path(location, split, "decoded").is_file() 149 | 150 | @staticmethod 151 | def _get_clips_info(location, split): 152 | clips_info = [] 153 | with open(Path(location, f"EPIC_100_{split}.csv"), "r") as csv_file: 154 | csv_reader = csv.reader(csv_file) 155 | next(csv_reader) # Skip header line 156 | for i, line in enumerate(csv_reader): 157 | clips_info.append( 158 | { 159 | "clip_id": i, 160 | "video_id": line[2], 161 | "start_time": line[4], 162 | "end_time": line[5], 163 | "label": line[9], 164 | "class_id": int(line[10]), 165 | } 166 | ) 167 | return clips_info 168 | -------------------------------------------------------------------------------- /datasets/kinetics400.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import shutil 3 | from pathlib import Path 4 | from sys import stderr 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | from torchvision.io import read_image 9 | from tqdm import tqdm 10 | 11 | from utils.misc import decode_video, download_file, seeded_shuffle 12 | 13 | CLASSES = [ 14 | "abseiling", 15 | "air drumming", 16 | "answering questions", 17 | "applauding", 18 | "applying cream", 19 | "archery", 20 | "arm wrestling", 21 | "arranging flowers", 22 | "assembling computer", 23 | "auctioning", 24 | "baby waking up", 25 | "baking cookies", 26 | "balloon blowing", 27 | "bandaging", 28 | "barbequing", 29 | "bartending", 30 | "beatboxing", 31 | "bee keeping", 32 | "belly dancing", 33 | "bench pressing", 34 | "bending back", 35 | "bending metal", 36 | "biking through snow", 37 | "blasting sand", 38 | "blowing glass", 39 | "blowing leaves", 40 | "blowing nose", 41 | "blowing out candles", 42 | "bobsledding", 43 | "bookbinding", 44 | "bouncing on trampoline", 45 | "bowling", 46 | "braiding hair", 47 | "breading or breadcrumbing", 48 | "breakdancing", 49 | "brush painting", 50 | "brushing hair", 51 | "brushing teeth", 52 | "building cabinet", 53 | "building shed", 54 | "bungee jumping", 55 | "busking", 56 | "canoeing or kayaking", 57 | "capoeira", 58 | "carrying baby", 59 | "cartwheeling", 60 | "carving pumpkin", 61 | "catching fish", 62 | "catching or throwing baseball", 63 | "catching or throwing frisbee", 64 | "catching or throwing softball", 65 | "celebrating", 66 | "changing oil", 67 | "changing wheel", 68 | "checking tires", 69 | "cheerleading", 70 | "chopping wood", 71 | "clapping", 72 | "clay pottery making", 73 | "clean and jerk", 74 | "cleaning floor", 75 | "cleaning gutters", 76 | "cleaning pool", 77 | "cleaning shoes", 78 | "cleaning toilet", 79 | "cleaning windows", 80 | "climbing a rope", 81 | "climbing ladder", 82 | "climbing tree", 83 | "contact juggling", 84 | "cooking chicken", 85 | "cooking egg", 86 | "cooking on campfire", 87 | "cooking sausages", 88 | "counting money", 89 | "country line dancing", 90 | "cracking neck", 91 | "crawling baby", 92 | "crossing river", 93 | "crying", 94 | "curling hair", 95 | "cutting nails", 96 | "cutting pineapple", 97 | "cutting watermelon", 98 | "dancing ballet", 99 | "dancing charleston", 100 | "dancing gangnam style", 101 | "dancing macarena", 102 | "deadlifting", 103 | "decorating 
the christmas tree", 104 | "digging", 105 | "dining", 106 | "disc golfing", 107 | "diving cliff", 108 | "dodgeball", 109 | "doing aerobics", 110 | "doing laundry", 111 | "doing nails", 112 | "drawing", 113 | "dribbling basketball", 114 | "drinking", 115 | "drinking beer", 116 | "drinking shots", 117 | "driving car", 118 | "driving tractor", 119 | "drop kicking", 120 | "drumming fingers", 121 | "dunking basketball", 122 | "dying hair", 123 | "eating burger", 124 | "eating cake", 125 | "eating carrots", 126 | "eating chips", 127 | "eating doughnuts", 128 | "eating hotdog", 129 | "eating ice cream", 130 | "eating spaghetti", 131 | "eating watermelon", 132 | "egg hunting", 133 | "exercising arm", 134 | "exercising with an exercise ball", 135 | "extinguishing fire", 136 | "faceplanting", 137 | "feeding birds", 138 | "feeding fish", 139 | "feeding goats", 140 | "filling eyebrows", 141 | "finger snapping", 142 | "fixing hair", 143 | "flipping pancake", 144 | "flying kite", 145 | "folding clothes", 146 | "folding napkins", 147 | "folding paper", 148 | "front raises", 149 | "frying vegetables", 150 | "garbage collecting", 151 | "gargling", 152 | "getting a haircut", 153 | "getting a tattoo", 154 | "giving or receiving award", 155 | "golf chipping", 156 | "golf driving", 157 | "golf putting", 158 | "grinding meat", 159 | "grooming dog", 160 | "grooming horse", 161 | "gymnastics tumbling", 162 | "hammer throw", 163 | "headbanging", 164 | "headbutting", 165 | "high jump", 166 | "high kick", 167 | "hitting baseball", 168 | "hockey stop", 169 | "holding snake", 170 | "hopscotch", 171 | "hoverboarding", 172 | "hugging", 173 | "hula hooping", 174 | "hurdling", 175 | "hurling (sport)", 176 | "ice climbing", 177 | "ice fishing", 178 | "ice skating", 179 | "ironing", 180 | "javelin throw", 181 | "jetskiing", 182 | "jogging", 183 | "juggling balls", 184 | "juggling fire", 185 | "juggling soccer ball", 186 | "jumping into pool", 187 | "jumpstyle dancing", 188 | "kicking field goal", 189 | "kicking soccer ball", 190 | "kissing", 191 | "kitesurfing", 192 | "knitting", 193 | "krumping", 194 | "laughing", 195 | "laying bricks", 196 | "long jump", 197 | "lunge", 198 | "making a cake", 199 | "making a sandwich", 200 | "making bed", 201 | "making jewelry", 202 | "making pizza", 203 | "making snowman", 204 | "making sushi", 205 | "making tea", 206 | "marching", 207 | "massaging back", 208 | "massaging feet", 209 | "massaging legs", 210 | "massaging person's head", 211 | "milking cow", 212 | "mopping floor", 213 | "motorcycling", 214 | "moving furniture", 215 | "mowing lawn", 216 | "news anchoring", 217 | "opening bottle", 218 | "opening present", 219 | "paragliding", 220 | "parasailing", 221 | "parkour", 222 | "passing American football (in game)", 223 | "passing American football (not in game)", 224 | "peeling apples", 225 | "peeling potatoes", 226 | "petting animal (not cat)", 227 | "petting cat", 228 | "picking fruit", 229 | "planting trees", 230 | "plastering", 231 | "playing accordion", 232 | "playing badminton", 233 | "playing bagpipes", 234 | "playing basketball", 235 | "playing bass guitar", 236 | "playing cards", 237 | "playing cello", 238 | "playing chess", 239 | "playing clarinet", 240 | "playing controller", 241 | "playing cricket", 242 | "playing cymbals", 243 | "playing didgeridoo", 244 | "playing drums", 245 | "playing flute", 246 | "playing guitar", 247 | "playing harmonica", 248 | "playing harp", 249 | "playing ice hockey", 250 | "playing keyboard", 251 | "playing kickball", 252 | "playing monopoly", 
253 | "playing organ", 254 | "playing paintball", 255 | "playing piano", 256 | "playing poker", 257 | "playing recorder", 258 | "playing saxophone", 259 | "playing squash or racquetball", 260 | "playing tennis", 261 | "playing trombone", 262 | "playing trumpet", 263 | "playing ukulele", 264 | "playing violin", 265 | "playing volleyball", 266 | "playing xylophone", 267 | "pole vault", 268 | "presenting weather forecast", 269 | "pull ups", 270 | "pumping fist", 271 | "pumping gas", 272 | "punching bag", 273 | "punching person (boxing)", 274 | "push up", 275 | "pushing car", 276 | "pushing cart", 277 | "pushing wheelchair", 278 | "reading book", 279 | "reading newspaper", 280 | "recording music", 281 | "riding a bike", 282 | "riding camel", 283 | "riding elephant", 284 | "riding mechanical bull", 285 | "riding mountain bike", 286 | "riding mule", 287 | "riding or walking with horse", 288 | "riding scooter", 289 | "riding unicycle", 290 | "ripping paper", 291 | "robot dancing", 292 | "rock climbing", 293 | "rock scissors paper", 294 | "roller skating", 295 | "running on treadmill", 296 | "sailing", 297 | "salsa dancing", 298 | "sanding floor", 299 | "scrambling eggs", 300 | "scuba diving", 301 | "setting table", 302 | "shaking hands", 303 | "shaking head", 304 | "sharpening knives", 305 | "sharpening pencil", 306 | "shaving head", 307 | "shaving legs", 308 | "shearing sheep", 309 | "shining shoes", 310 | "shooting basketball", 311 | "shooting goal (soccer)", 312 | "shot put", 313 | "shoveling snow", 314 | "shredding paper", 315 | "shuffling cards", 316 | "side kick", 317 | "sign language interpreting", 318 | "singing", 319 | "situp", 320 | "skateboarding", 321 | "ski jumping", 322 | "skiing (not slalom or crosscountry)", 323 | "skiing crosscountry", 324 | "skiing slalom", 325 | "skipping rope", 326 | "skydiving", 327 | "slacklining", 328 | "slapping", 329 | "sled dog racing", 330 | "smoking", 331 | "smoking hookah", 332 | "snatch weight lifting", 333 | "sneezing", 334 | "sniffing", 335 | "snorkeling", 336 | "snowboarding", 337 | "snowkiting", 338 | "snowmobiling", 339 | "somersaulting", 340 | "spinning poi", 341 | "spray painting", 342 | "spraying", 343 | "springboard diving", 344 | "squat", 345 | "sticking tongue out", 346 | "stomping grapes", 347 | "stretching arm", 348 | "stretching leg", 349 | "strumming guitar", 350 | "surfing crowd", 351 | "surfing water", 352 | "sweeping floor", 353 | "swimming backstroke", 354 | "swimming breast stroke", 355 | "swimming butterfly stroke", 356 | "swing dancing", 357 | "swinging legs", 358 | "swinging on something", 359 | "sword fighting", 360 | "tai chi", 361 | "taking a shower", 362 | "tango dancing", 363 | "tap dancing", 364 | "tapping guitar", 365 | "tapping pen", 366 | "tasting beer", 367 | "tasting food", 368 | "testifying", 369 | "texting", 370 | "throwing axe", 371 | "throwing ball", 372 | "throwing discus", 373 | "tickling", 374 | "tobogganing", 375 | "tossing coin", 376 | "tossing salad", 377 | "training dog", 378 | "trapezing", 379 | "trimming or shaving beard", 380 | "trimming trees", 381 | "triple jump", 382 | "tying bow tie", 383 | "tying knot (not on a tie)", 384 | "tying tie", 385 | "unboxing", 386 | "unloading truck", 387 | "using computer", 388 | "using remote controller (not gaming)", 389 | "using segway", 390 | "vault", 391 | "waiting in line", 392 | "walking the dog", 393 | "washing dishes", 394 | "washing feet", 395 | "washing hair", 396 | "washing hands", 397 | "water skiing", 398 | "water sliding", 399 | "watering plants", 400 | 
"waxing back", 401 | "waxing chest", 402 | "waxing eyebrows", 403 | "waxing legs", 404 | "weaving basket", 405 | "welding", 406 | "whistling", 407 | "windsurfing", 408 | "wrapping present", 409 | "wrestling", 410 | "writing", 411 | "yawning", 412 | "yoga", 413 | "zumba", 414 | ] 415 | 416 | CLASS_IDS = {name: i for i, name in enumerate(CLASSES)} 417 | 418 | SPLITS = ["train", "test", "val"] 419 | 420 | # https://github.com/cvdfoundation/kinetics-dataset/blob/main/k400_downloader.sh 421 | LABEL_DOWNLOADS = { 422 | split: f"https://s3.amazonaws.com/kinetics/400/annotations/{split}.csv" 423 | for split in SPLITS 424 | } 425 | VIDEO_DOWNLOADS = { 426 | split: f"https://s3.amazonaws.com/kinetics/400/{split}/k400_{split}_path.txt" 427 | for split in SPLITS 428 | } 429 | 430 | 431 | class Kinetics400(Dataset): 432 | """ 433 | A loader for the Kinetics-400 dataset. 434 | """ 435 | 436 | def __init__( 437 | self, 438 | location, 439 | split="val", 440 | decode_size=None, 441 | decode_fps=None, 442 | max_tars=None, 443 | shuffle=True, 444 | shuffle_seed=42, 445 | video_transform=None, 446 | ): 447 | """ 448 | Initializes the loader. On the first call, this constructor will 449 | do some one-time setup (including downloading data). 450 | 451 | :param location: Directory where the dataset should be stored 452 | :param split: Either "train", "test", or "val" 453 | :param decode_size: The short-edge length for decoded frames 454 | :param decode_fps: The fps for decoded frames 455 | :param max_tars: Set a cap on the number of tar files to 456 | download for this split. Each tar contains about 1k videos. 457 | :param shuffle: Whether to shuffle videos 458 | :param shuffle_seed: The seed to use if shuffling 459 | :param video_transform: A callable to be applied to each video 460 | as it is loaded 461 | """ 462 | assert split in SPLITS 463 | self.video_transform = video_transform 464 | 465 | base_split = split 466 | if max_tars is not None: 467 | split = f"{split}_{max_tars}" 468 | 469 | # Make sure the dataset has been set up. 470 | Path(location, split).mkdir(parents=True, exist_ok=True) 471 | if not self.is_downloaded(location, split): 472 | self.clean_downloaded(location, split) 473 | self.download(location, base_split, split, max_tars) 474 | if not self.is_unpacked(location, split): 475 | self.clean_unpacked(location, split) 476 | self.unpack(location, split) 477 | if not self.is_decoded(location, split, decode_size, decode_fps): 478 | self.clean_decoded(location, split, decode_size, decode_fps) 479 | self.decode(location, split, decode_size, decode_fps) 480 | 481 | # Load information about each video in the dataset. 482 | self.frames_path = Path(location, split, f"frames_{decode_size}_{decode_fps}") 483 | self.videos_info = self._get_videos_info( 484 | location, split, decode_size, decode_fps 485 | ) 486 | 487 | # Optionally shuffle the videos (by default they are sorted). 488 | if shuffle: 489 | seeded_shuffle(self.videos_info, shuffle_seed) 490 | 491 | def __getitem__(self, index): 492 | """ 493 | Loads and returns an item from the dataset. 494 | 495 | :param index: The index of the item to load 496 | :return: A (video, label) tuple, where "video" is a tensor and 497 | "label" is the class label. 
498 | """ 499 | video_info = self.videos_info[index] 500 | video_path = self.frames_path / video_info["video_id"] 501 | video = torch.stack( 502 | [read_image(str(video_path / frame)) for frame in video_info["frames"]] 503 | ) 504 | if self.video_transform is not None: 505 | video = self.video_transform(video) 506 | return video, video_info["label"] 507 | 508 | def __len__(self): 509 | """ 510 | Returns the number of items in the dataset. 511 | """ 512 | return len(self.videos_info) 513 | 514 | @staticmethod 515 | def clean_decoded(location, split, decode_size, decode_fps): 516 | """ 517 | Deletes one-time setup data (decoded frames). 518 | 519 | :param location: The location of the dataset (see __init__) 520 | :param split: The split name 521 | :param decode_size: The short-edge length for decoded frames 522 | :param decode_fps: The fps for decoded frames 523 | """ 524 | base_path = Path(location, split) 525 | (base_path / f"decoded_{decode_size}_{decode_fps}").unlink(missing_ok=True) 526 | folder_path = base_path / f"frames_{decode_size}_{decode_fps}" 527 | if folder_path.is_dir(): 528 | shutil.rmtree(folder_path) 529 | 530 | @staticmethod 531 | def clean_downloaded(location, split): 532 | """ 533 | Deletes downloaded data (e.g., tar files). 534 | 535 | :param location: The location of the dataset (see __init__) 536 | :param split: The split name 537 | """ 538 | base_path = Path(location, split) 539 | (base_path / "downloaded").unlink(missing_ok=True) 540 | (base_path / "labels.csv").unlink(missing_ok=True) 541 | folder_path = base_path / "downloads" 542 | if folder_path.is_dir(): 543 | shutil.rmtree(folder_path) 544 | 545 | @staticmethod 546 | def clean_unpacked(location, split): 547 | """ 548 | Deletes one-time setup data (unpacked tars). 549 | 550 | :param location: The location of the dataset (see __init__) 551 | :param split: The split name 552 | """ 553 | base_path = Path(location, split) 554 | (base_path / "unpacked").unlink(missing_ok=True) 555 | folder_path = base_path / "videos" 556 | if folder_path.is_dir(): 557 | shutil.rmtree(folder_path) 558 | 559 | @staticmethod 560 | def decode(location, split, decode_size, decode_fps): 561 | """ 562 | Performs one-time setup (frame decoding). 563 | 564 | :param location: The location of the dataset (see __init__) 565 | :param split: The split name 566 | :param decode_size: The short-edge length for decoded frames 567 | :param decode_fps: The fps for decoded frames 568 | """ 569 | base_path = Path(location, split) 570 | frames_path = base_path / f"frames_{decode_size}_{decode_fps}" 571 | frames_path.mkdir(exist_ok=True) 572 | 573 | # Decode videos into images. 574 | print("Decoding videos...", file=stderr, flush=True) 575 | video_list = list((base_path / "videos").glob("*.mp4")) 576 | for video_path in tqdm(video_list, total=len(video_list), ncols=0): 577 | ffmpeg_output_args = ["-qscale:v", "2"] 578 | if decode_size is not None: 579 | # These options resize the short side to decode_size. 580 | ffmpeg_output_args += [ 581 | "-filter:v", 582 | f"scale={decode_size}:{decode_size}:force_original_aspect_ratio=increase", 583 | ] 584 | if decode_fps is not None: 585 | # Use the "framerate" or "minterpolate" filters for 586 | # higher-quality FPS adjustments. 
587 | # https://ffmpeg.org/ffmpeg-filters.html 588 | ffmpeg_output_args += ["-r", f"{decode_fps}"] 589 | decode_path = frames_path / video_path.stem 590 | return_code = decode_video( 591 | video_path, 592 | decode_path, 593 | name_format="%3d", 594 | image_format="jpg", 595 | ffmpeg_output_args=ffmpeg_output_args, 596 | ) 597 | if return_code != 0: 598 | print( 599 | f"Decoding failed for video {video_path.stem}.", 600 | file=stderr, 601 | flush=True, 602 | ) 603 | shutil.rmtree(decode_path) 604 | 605 | # Create an empty indicator file. 606 | print("Decoding complete.", file=stderr, flush=True) 607 | (base_path / f"decoded_{decode_size}_{decode_fps}").touch() 608 | 609 | @staticmethod 610 | def download(location, base_split, split, max_tars): 611 | """ 612 | Performs one-time setup (downloading data). 613 | 614 | :param location: The location of the dataset (see __init__) 615 | :param base_split: The main split ("train", "test", or "val") 616 | :param split: The qualified split name (e.g., train_40 for 617 | max_tars=40) 618 | :param max_tars: Set a cap on the number of tar files to 619 | download for this split. Each tar contains about 1k videos. 620 | """ 621 | base_path = Path(location, split) 622 | downloads_path = base_path / "downloads" 623 | downloads_path.mkdir(exist_ok=True) 624 | 625 | # Download the class labels. 626 | download_file(LABEL_DOWNLOADS[base_split], base_path / "labels.csv") 627 | 628 | # Download the video archive files. 629 | download_file(VIDEO_DOWNLOADS[base_split], downloads_path / "download_list.txt") 630 | n = 0 631 | with open(downloads_path / "download_list.txt", "r") as download_list: 632 | for url in download_list: 633 | if (max_tars is not None) and (n >= max_tars): 634 | break 635 | url = url.strip() 636 | filename = url.split("/")[-1] 637 | download_file(url, downloads_path / filename) 638 | n += 1 639 | 640 | # Create an empty indicator file. 641 | print("Downloads complete.", file=stderr, flush=True) 642 | (base_path / "downloaded").touch() 643 | 644 | @staticmethod 645 | def is_decoded(location, split, decode_size, decode_fps): 646 | """ 647 | Returns true if one-time setup (frame decoding) has been 648 | completed. 649 | 650 | :param location: The location of the dataset (see __init__) 651 | :param split: The split name 652 | :param decode_size: The short-edge length for decoded frames 653 | :param decode_fps: The fps for decoded frames 654 | """ 655 | return Path(location, split, f"decoded_{decode_size}_{decode_fps}").is_file() 656 | 657 | @staticmethod 658 | def is_downloaded(location, split): 659 | """ 660 | Returns true if one-time setup (data download) has been 661 | completed. 662 | 663 | :param location: The location of the dataset (see __init__) 664 | :param split: The split name 665 | """ 666 | return Path(location, split, "downloaded").is_file() 667 | 668 | @staticmethod 669 | def is_unpacked(location, split): 670 | """ 671 | Returns true if one-time setup (tar unpacking) has been 672 | completed. 673 | 674 | :param location: The location of the dataset (see __init__) 675 | :param split: The split name 676 | """ 677 | return Path(location, split, "unpacked").is_file() 678 | 679 | @staticmethod 680 | def unpack(location, split): 681 | """ 682 | Performs one-time setup (unpacking tars). 
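For orientation, here is a minimal usage sketch of the `Kinetics400` loader defined above. The location and the `decode_size`, `decode_fps`, and `max_tars` values are illustrative placeholders rather than values taken from the repo's configs; the first call triggers the one-time download, unpack, and decode steps implemented by the methods above.
```
# Hypothetical usage sketch; the path and parameters are illustrative only.
from datasets.kinetics400 import CLASSES, Kinetics400

dataset = Kinetics400(
    "data/kinetics400",  # hypothetical storage directory
    split="val",
    decode_size=224,
    decode_fps=25,
    max_tars=1,  # cap the download; each tar holds roughly 1k videos
)
video, label = dataset[0]  # video: (frames, channels, height, width) uint8 tensor
print(video.shape, CLASSES[label])
```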
683 | 684 | :param location: The location of the dataset (see __init__) 685 | :param split: The split name 686 | """ 687 | base_path = Path(location, split) 688 | downloads_path = base_path / "downloads" 689 | videos_path = base_path / "videos" 690 | videos_path.mkdir(exist_ok=True) 691 | 692 | # Unpack the video archive files. 693 | with open(downloads_path / "download_list.txt", "r") as download_list: 694 | for url in download_list: 695 | url = url.strip() 696 | filename = url.split("/")[-1] 697 | filepath = downloads_path / url.split("/")[-1] 698 | if filepath.exists(): 699 | print(f"Unpacking {filename}...", file=stderr, flush=True) 700 | shutil.unpack_archive(filepath, videos_path) 701 | 702 | # Create an empty indicator file. 703 | print("Unpacking complete.", file=stderr, flush=True) 704 | (base_path / "unpacked").touch() 705 | 706 | @staticmethod 707 | def _get_videos_info(location, split, decode_size, decode_fps): 708 | videos_info = [] 709 | frames_path = Path(location, split, f"frames_{decode_size}_{decode_fps}") 710 | with open(Path(location, split, "labels.csv"), "r") as csv_file: 711 | csv_reader = csv.reader(csv_file) 712 | next(csv_reader) # Skip header line 713 | for line in csv_reader: 714 | video_id = f"{line[1]}_{int(line[2]):06d}_{int(line[3]):06d}" 715 | label = CLASS_IDS[line[0]] 716 | video_path = frames_path / video_id 717 | if not video_path.is_dir(): 718 | continue 719 | frames = [path.name for path in video_path.glob("*.jpg")] 720 | frames.sort() 721 | videos_info.append( 722 | {"video_id": video_id, "label": label, "frames": frames} 723 | ) 724 | videos_info.sort(key=lambda x: x["video_id"]) 725 | return videos_info 726 | -------------------------------------------------------------------------------- /datasets/vid.py: -------------------------------------------------------------------------------- 1 | import json 2 | import shutil 3 | from collections import defaultdict 4 | from copy import deepcopy 5 | from pathlib import Path 6 | from sys import stderr 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import Dataset 11 | from torchvision.io import read_image 12 | 13 | from utils.image import rescale 14 | from utils.misc import seeded_shuffle 15 | 16 | CLASSES = [ 17 | "airplane", 18 | "antelope", 19 | "bear", 20 | "bicycle", 21 | "bird", 22 | "bus", 23 | "car", 24 | "cattle", 25 | "dog", 26 | "domestic_cat", 27 | "elephant", 28 | "fox", 29 | "giant_panda", 30 | "hamster", 31 | "horse", 32 | "lion", 33 | "lizard", 34 | "monkey", 35 | "motorcycle", 36 | "rabbit", 37 | "red_panda", 38 | "sheep", 39 | "snake", 40 | "squirrel", 41 | "tiger", 42 | "train", 43 | "turtle", 44 | "watercraft", 45 | "whale", 46 | "zebra", 47 | ] 48 | 49 | SPLITS = ["det_train", "vid_train", "vid_val", "vid_minival"] 50 | 51 | 52 | class VID(Dataset): 53 | """ 54 | A loader for the ImageNet VID dataset. 55 | """ 56 | 57 | def __init__( 58 | self, 59 | location, 60 | split="vid_val", 61 | tar_path=None, 62 | shuffle=True, 63 | shuffle_seed=42, 64 | frame_transform=None, 65 | annotation_transform=None, 66 | combined_transform=None, 67 | ): 68 | """ 69 | Initializes the loader. On the first call, this constructor 70 | will do some one-time setup. 71 | 72 | :param location: Directory containing the dataset (e.g., 73 | data/vid). See the project README. 74 | :param split: Either "det_train", "vid_train", "vid_val", or 75 | "vid_minival" 76 | :param tar_path: Location of the data.tar file (e.g., 77 | data/vid/data.tar). See the project README.
78 | :param shuffle: Whether to shuffle videos. 79 | :param shuffle_seed: The seed to use if shuffling. 80 | :param frame_transform: A callable to be applied to each frame 81 | as it is loaded. Passed to VIDItem constructor. 82 | :param annotation_transform: A callable to be applied to each 83 | bounding-box annotation as it is loaded. Passed to VIDItem 84 | constructor. 85 | :param combined_transform: A callable to be applied to each 86 | (frame, annotation) tuple as it is loaded. Passed to VIDItem 87 | constructor. 88 | """ 89 | assert split in SPLITS 90 | self.frame_transform = frame_transform 91 | self.annotation_transform = annotation_transform 92 | self.combined_transform = combined_transform 93 | 94 | # Make sure the dataset has been set up. 95 | if not self.is_unpacked(location): 96 | assert tar_path is not None 97 | self.clean_unpacked(location) 98 | self.unpack(location, tar_path) 99 | 100 | # Load information about each video in the dataset. 101 | self.frames_path = Path(location, split, "frames") 102 | self.video_info = self._get_videos_info(location, split) 103 | 104 | # Optionally shuffle the videos (by default they are sorted). 105 | if shuffle: 106 | seeded_shuffle(self.video_info, shuffle_seed) 107 | 108 | def __getitem__(self, index): 109 | """ 110 | Loads and returns an item from the dataset. 111 | 112 | :param index: The index of the item to load 113 | :return: A VIDItem object. 114 | """ 115 | video_info = self.video_info[index] 116 | video_path = self.frames_path / video_info["video_id"] 117 | frame_paths = [ 118 | str(video_path / frame["filename"]) for frame in video_info["frames"] 119 | ] 120 | annotations = [frame["annotations"] for frame in video_info["frames"]] 121 | vid_item = VIDItem( 122 | frame_paths, 123 | annotations, 124 | self.frame_transform, 125 | self.annotation_transform, 126 | self.combined_transform, 127 | ) 128 | return vid_item 129 | 130 | def __len__(self): 131 | """ 132 | Returns the number of items in the dataset. 133 | """ 134 | return len(self.video_info) 135 | 136 | @staticmethod 137 | def clean_unpacked(location): 138 | """ 139 | Deletes one-time setup data. 140 | 141 | :param location: The location of the dataset (see __init__) 142 | """ 143 | base_path = Path(location) 144 | (base_path / "unpacked").unlink(missing_ok=True) 145 | for split in SPLITS: 146 | split_path = base_path / split 147 | if split_path.is_dir(): 148 | shutil.rmtree(split_path) 149 | 150 | @staticmethod 151 | def is_unpacked(location): 152 | """ 153 | Returns true if one-time setup has been completed. 154 | 155 | :param location: The location of the dataset (see __init__) 156 | """ 157 | return Path(location, "unpacked").is_file() 158 | 159 | @staticmethod 160 | def unpack(location, tar_path): 161 | """ 162 | Performs one-time setup. Extract data from the data.tar file. 163 | 164 | :param location: The location of the dataset (see __init__) 165 | :param tar_path: The location of the data.tar file (see 166 | __init__) 167 | """ 168 | base_path = Path(location) 169 | base_path.mkdir(exist_ok=True) 170 | 171 | # Unpack the tar archive. 172 | print(f"Unpacking {tar_path.name}...", file=stderr, flush=True) 173 | shutil.unpack_archive(tar_path, base_path) 174 | unpacked_path = base_path / "vid_data" 175 | print("Unpacking complete.", file=stderr, flush=True) 176 | 177 | # Move the annotations. 
178 | print(f"Reorganizing data...", file=stderr, flush=True) 179 | for split in SPLITS: 180 | split_path = base_path / split 181 | split_path.mkdir(exist_ok=True) 182 | annotations_path = unpacked_path / "annotations" / f"{split}.json" 183 | annotations_path.rename(split_path / "labels.json") 184 | 185 | # Reorganize the images. 186 | for split in SPLITS[:-1]: 187 | split_path = base_path / split 188 | frames_path = split_path / "frames" 189 | frames_path.mkdir(exist_ok=True) 190 | for filename in (unpacked_path / split).glob("*.JPEG"): 191 | video_id, frame_number = filename.stem.split("_")[-2:] 192 | video_path = frames_path / video_id 193 | video_path.mkdir(exist_ok=True) 194 | filename.rename(video_path / f"{frame_number}.jpg") 195 | 196 | # Symlink vid_minival/frames to vid_val/frames. 197 | link_from = base_path / SPLITS[-1] / "frames" 198 | link_to = base_path / SPLITS[-2] / "frames" 199 | link_from.symlink_to(link_to.resolve(), target_is_directory=True) 200 | print(f"Reorganization complete.", file=stderr, flush=True) 201 | 202 | # Clean up and create an empty indicator file. 203 | shutil.rmtree(unpacked_path) 204 | (base_path / "unpacked").touch() 205 | 206 | @staticmethod 207 | def _get_videos_info(location, split): 208 | with Path(location, split, "labels.json").open("r") as json_file: 209 | json_data = json.load(json_file) 210 | 211 | # Place frames in a dictionary with their ID as the key. 212 | frame_dict = {} 213 | for item in json_data["images"]: 214 | filename = Path(item["file_name"]) 215 | video_id, frame_number = filename.stem.split("_")[-2:] 216 | frame_dict[item["id"]] = { 217 | "video_id": video_id, 218 | "filename": f"{frame_number}.jpg", 219 | "annotations": {"boxes": [], "labels": []}, 220 | } 221 | 222 | # Assign each bounding box annotation to the correct frame. 223 | for item in json_data["annotations"]: 224 | annotations = frame_dict[item["image_id"]]["annotations"] 225 | # Convert from xywh to xyxy (what ViTDet outputs). 226 | x, y, w, h = item["bbox"] 227 | annotations["boxes"].append([x, y, x + w, y + h]) 228 | 229 | # Convert to zero-based category IDs (what ViTDet outputs). 230 | annotations["labels"].append(item["category_id"] - 1) 231 | 232 | # Convert annotations to tensors and organize frames by video. 233 | video_dict = defaultdict(list) 234 | for frame in frame_dict.values(): 235 | for key in "boxes", "labels": 236 | frame["annotations"][key] = torch.tensor(frame["annotations"][key]) 237 | video_dict[frame.pop("video_id")].append(frame) 238 | videos_info = [] 239 | for video_id, video in video_dict.items(): 240 | video.sort(key=lambda v: v["filename"]) 241 | # Some videos contain several non-contiguous segments. We 242 | # need to split these into distinct sequences. 243 | last = None 244 | segment = [] 245 | for frame in video: 246 | i = int(Path(frame["filename"]).stem) 247 | if (last is not None) and (i > last + 1): 248 | videos_info.append({"video_id": video_id, "frames": segment}) 249 | segment = [] 250 | segment.append(frame) 251 | last = i 252 | if len(segment) > 0: 253 | videos_info.append({"video_id": video_id, "frames": segment}) 254 | 255 | videos_info.sort(key=lambda v: v["video_id"] + v["frames"][0]["filename"]) 256 | return videos_info 257 | 258 | 259 | class VIDItem(Dataset): 260 | """ 261 | A Dataset subclass for iterating over a single VID item. Necessary 262 | due to the very long length of some videos (loading into a single 263 | tensor would exhaust memory). 
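To show how `VID` and `VIDItem` are meant to be used together, here is a minimal sketch. The path is hypothetical, and it assumes the `data.tar` archive has already been unpacked (otherwise, pass `tar_path` as described in `VID.__init__`).
```
# Hypothetical usage sketch; frames are read lazily, one at a time.
from datasets.vid import VID

dataset = VID("data/vid", split="vid_minival")
item = dataset[0]  # a VIDItem holding one video segment
for frame, annotations in item:
    boxes = annotations["boxes"]  # xyxy boxes (the format ViTDet outputs)
    labels = annotations["labels"]  # zero-based class IDs
    break  # this sketch only touches the first frame
```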
264 | """ 265 | 266 | def __init__( 267 | self, 268 | frame_paths, 269 | annotations, 270 | frame_transform, 271 | annotation_transform, 272 | combined_transform, 273 | ): 274 | """ 275 | Initializes the item. 276 | 277 | :param frame_paths: A list of frame paths for this item 278 | :param annotations: A list of annotations for this item 279 | :param frame_transform: A callable to be applied to each frame 280 | as it is loaded. 281 | :param annotation_transform: A callable to be applied to each 282 | bounding-box annotation as it is loaded. 283 | :param combined_transform: A callable to be applied to each 284 | (frame, annotation) tuple as it is loaded. 285 | """ 286 | self.frame_paths = frame_paths 287 | self.annotations = annotations 288 | self.frame_transform = frame_transform 289 | self.annotation_transform = annotation_transform 290 | self.combined_transform = combined_transform 291 | 292 | def __getitem__(self, index): 293 | """ 294 | Loads and returns a frame and the corresponding labels. 295 | 296 | :param index: The frame index 297 | :return: A (frame, annotations) tuple 298 | """ 299 | frame = read_image(self.frame_paths[index]) 300 | if self.frame_transform is not None: 301 | frame = self.frame_transform(frame) 302 | annotations = self.annotations[index] 303 | if self.annotation_transform is not None: 304 | annotations = self.annotation_transform(annotations) 305 | if self.combined_transform is not None: 306 | return self.combined_transform((frame, annotations)) 307 | else: 308 | return frame, annotations 309 | 310 | def __len__(self): 311 | """ 312 | Returns the number of frames in the item. 313 | """ 314 | return len(self.frame_paths) 315 | 316 | 317 | # short_edge_length=640 and max_size=1024: 318 | # https://github.com/happyharrycn/detectron2_vitdet_vid/blob/main/projects/ViTDet-VID/configs/frcnn_vitdet.py#L103 319 | class VIDResize(nn.Module): 320 | """ 321 | A PyTorch module for simultaneously resizing frames and annotations. 322 | Should be passed to VID.__init__ as a combined_transform. 323 | """ 324 | 325 | def __init__(self, short_edge_length, max_size): 326 | """ 327 | 328 | :param short_edge_length: The size to which the short edge 329 | should be resized 330 | :param max_size: The maximum size of the long edge (this 331 | overrides short_edge_length if there is a conflict) 332 | """ 333 | super().__init__() 334 | self.short_edge_length = short_edge_length 335 | self.max_size = max_size 336 | 337 | def forward(self, x): 338 | frame, annotations = x 339 | short_edge = min(frame.shape[-2:]) 340 | long_edge = max(frame.shape[-2:]) 341 | scale = min(self.short_edge_length / short_edge, self.max_size / long_edge) 342 | frame = rescale(frame, scale) 343 | annotations = deepcopy(annotations) 344 | annotations["boxes"] *= scale 345 | return frame, annotations 346 | -------------------------------------------------------------------------------- /datasets/vivit_spatial.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | from utils.misc import seeded_shuffle 8 | 9 | 10 | class ViViTSpatial(Dataset): 11 | """ 12 | A loader for intermediate outputs of the ViViT spatial model. See 13 | scripts/train/vivit_epic_kitchens.py. 14 | """ 15 | 16 | def __init__( 17 | self, 18 | location, 19 | split="train", 20 | base_name="spatial", 21 | k=None, 22 | shuffle=True, 23 | shuffle_seed=42, 24 | ): 25 | """ 26 | Initializes the loader.
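A brief usage sketch for this loader (the location and `k` are hypothetical, and the `.npz` files must already have been generated as noted in the class docstring):
```
# Hypothetical usage sketch for precomputed spatial-model outputs.
from datasets.vivit_spatial import ViViTSpatial

dataset = ViViTSpatial("data/epic_kitchens", split="train", k=50)
spatial, label = dataset[0]  # tensors loaded from one .npz file
print(spatial.shape, label)
```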
27 | 28 | :param location: Location where the base dataset is stored 29 | (e.g., data/epic_kitchens) 30 | :param split: The split for the base dataset (e.g., "train") 31 | :param base_name: The name of the intermediate output folder 32 | containing .npz files (e.g., "spatial_50") 33 | :param k: If not None, this is appended as f"{base_name}_{k}" 34 | :param shuffle: Whether to shuffle items 35 | :param shuffle_seed: The seed to use if shuffling 36 | """ 37 | # Load the path of each item in the dataset. 38 | name = base_name if (k is None) else f"{base_name}_{k}" 39 | paths = sorted(Path(location, split, name).glob("*.npz")) 40 | self.item_paths = [str(path) for path in paths] 41 | 42 | # Optionally shuffle the items. 43 | if shuffle: 44 | seeded_shuffle(self.item_paths, shuffle_seed) 45 | 46 | def __getitem__(self, index): 47 | """ 48 | Loads and returns an item from the dataset. 49 | 50 | :param index: The index of the item to load 51 | :return: A (spatial, label) tuple, where "spatial" is the 52 | intermediate output of the ViViT spatial model, and "label" is 53 | the class label. 54 | """ 55 | item = np.load(self.item_paths[index]) 56 | return torch.tensor(item["spatial"]), torch.tensor(item["label"]) 57 | 58 | def __len__(self): 59 | """ 60 | Returns the number of items in the dataset. 61 | """ 62 | return len(self.item_paths) 63 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: eventful-transformer 2 | channels: 3 | - defaults 4 | dependencies: 5 | - ffmpeg=4.2 6 | - matplotlib=3.7 7 | - numpy=1.23 8 | - nvidia::cuda=11.8 9 | - nvidia::cuda-nvcc=11.8 # For building Detectron2 10 | - pandas=1.5 11 | - pip=23.0 12 | - python=3.10 13 | - pytorch::pytorch-cuda=11.8 14 | - pytorch::pytorch=2.0 15 | - pytorch::torchvision=0.15 16 | - requests=2.28 17 | - scipy=1.10 18 | - tensorboard=2.11 19 | - torchmetrics=0.11 20 | - tqdm=4.65.* 21 | - pip: 22 | - av==10.0.* 23 | - omegaconf==2.3.* 24 | - opencv-python==4.7.* # Required for Detectron visualizer 25 | 26 | # This builds Detectron2 from source (requires gcc >= 5.4). 27 | # The specified commit hash corresponds to Detectron2 0.6. 28 | - git+https://github.com/facebookresearch/detectron2.git@88217cad6d741ea1510d13e54089739f5a0f4d7d 29 | -------------------------------------------------------------------------------- /eventful_transformer/backbones.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from eventful_transformer import blocks 4 | from eventful_transformer.base import ExtendedModule 5 | from eventful_transformer.utils import PositionEncoding 6 | 7 | 8 | class ViTBackbone(ExtendedModule): 9 | """ 10 | Common backbone for vision Transformers. 
11 | """ 12 | 13 | def __init__( 14 | self, 15 | block_config, 16 | depth, 17 | position_encoding_size, 18 | input_size, 19 | block_class="Block", 20 | has_class_token=False, 21 | window_indices=(), 22 | windowed_class=None, 23 | windowed_overrides=None, 24 | ): 25 | """ 26 | :param block_config: A dict containing kwargs for the 27 | block_class constructor 28 | :param depth: The number of blocks to use 29 | :param position_encoding_size: The size (in tokens) assumed for 30 | position encodings 31 | :param input_size: The expected size of the inputs in tokens 32 | :param block_class: The specific Block class to use (see 33 | blocks.py for options) 34 | :param has_class_token: Whether to add an extra class token 35 | :param window_indices: Block indices that should use windowed 36 | attention 37 | :param windowed_class: The specific Block class to use with 38 | windowed attention (if None, fall back to block_class) 39 | :param windowed_overrides: A dict containing kwargs overrides 40 | for windowed_class 41 | """ 42 | super().__init__() 43 | self.position_encoding = PositionEncoding( 44 | block_config["dim"], position_encoding_size, input_size, has_class_token 45 | ) 46 | self.blocks = nn.Sequential() 47 | for i in range(depth): 48 | block_class_i = block_class 49 | block_config_i = block_config.copy() 50 | if i in window_indices: 51 | if windowed_class is not None: 52 | block_class_i = windowed_class 53 | if windowed_overrides is not None: 54 | block_config_i |= windowed_overrides 55 | else: 56 | block_config_i["window_size"] = None 57 | self.blocks.append( 58 | getattr(blocks, block_class_i)(input_size=input_size, **block_config_i) 59 | ) 60 | 61 | def forward(self, x): 62 | x = self.position_encoding(x) 63 | x = self.blocks(x) 64 | return x 65 | -------------------------------------------------------------------------------- /eventful_transformer/base.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from sys import stdout 3 | 4 | from torch import nn as nn 5 | 6 | 7 | class Counts(defaultdict): 8 | """ 9 | A utility class for counting operations. Essentially, a dict with 10 | arithmetic operations on values. 11 | """ 12 | 13 | def __init__(self, *args, **kwargs): 14 | if len(args) > 0 or len(kwargs) > 0: 15 | super().__init__(*args, **kwargs) 16 | else: 17 | super().__init__(int) 18 | 19 | def __add__(self, other): 20 | result = self.copy() 21 | if isinstance(other, Counts): 22 | for key, value in other.items(): 23 | result[key] += value 24 | else: 25 | for key, value in result.items(): 26 | result[key] += other 27 | return result 28 | 29 | def __mul__(self, other): 30 | result = self.copy() 31 | for key in result: 32 | result[key] *= other 33 | return result 34 | 35 | def __neg__(self): 36 | result = self.copy() 37 | for key, value in result.items(): 38 | result[key] = -value 39 | return result 40 | 41 | def __radd__(self, other): 42 | return self.__add__(other) 43 | 44 | def __rmul__(self, other): 45 | return self.__mul__(other) 46 | 47 | def __rsub__(self, other): 48 | return self.__neg__().__add__(other) 49 | 50 | def __sub__(self, other): 51 | return self.__add__(-other) 52 | 53 | def __truediv__(self, other): 54 | return self.__mul__(1.0 / other) 55 | 56 | def csv_header(self): 57 | """ 58 | Generates a CSV header line from the keys. 59 | """ 60 | return dict_csv_header(self) 61 | 62 | def csv_line(self): 63 | """ 64 | Generates a CSV data line from the values. 
65 | """ 66 | return dict_csv_line(self) 67 | 68 | def pretty_print(self, indent=4, value_format=".3e", file=stdout, flush=False): 69 | """ 70 | Prints the count data in a human-readable format. 71 | 72 | :param indent: Number of spaces to use for indents 73 | :param value_format: Number format for count values 74 | :param file: File where output should be printed (default is 75 | stdout) 76 | :param flush: Whether to flush the output buffer 77 | """ 78 | print(dict_string(self, indent, value_format), file=file, flush=flush) 79 | 80 | 81 | class ExtendedModule(nn.Module): 82 | """ 83 | An extended nn.Module that adds tooling for useful features: 84 | operation counting, state resets, and sub-module filtering. 85 | """ 86 | 87 | def __init__(self): 88 | super().__init__() 89 | self.count_mode = False 90 | self.counts = Counts() 91 | 92 | def clear_counts(self): 93 | """ 94 | Resets the operation counts for this module and all 95 | ExtendedModule submodules. 96 | """ 97 | for module in self.extended_modules(): 98 | module.counts.clear() 99 | 100 | def counting(self, mode=True): 101 | """ 102 | Sets the operation counting mode (enables counting by default). 103 | 104 | :param mode: A True/False counting mode 105 | """ 106 | for module in self.extended_modules(): 107 | module.count_mode = mode 108 | 109 | def extended_modules(self): 110 | """ 111 | Enumerates all ExtendedModule submodules. 112 | """ 113 | return self.modules_of_type(ExtendedModule) 114 | 115 | def modules_of_type(self, module_type): 116 | """ 117 | Enumerates all submodules of the specified type. 118 | 119 | :param module_type: Enumerate children of this class 120 | :return: An iterator of children of the specified class 121 | """ 122 | return filter(lambda x: isinstance(x, module_type), self.modules()) 123 | 124 | def no_counting(self): 125 | """ 126 | Disables operation counting. 127 | """ 128 | self.counting(mode=False) 129 | 130 | def reset(self): 131 | """ 132 | Resets extra state for this module and all submodules. 133 | """ 134 | for module in self.extended_modules(): 135 | module.reset_self() 136 | 137 | def reset_self(self): 138 | """ 139 | Resets extra state in this module (but not in submodules). Child 140 | classes can define reset logic by overriding this method. 141 | """ 142 | pass 143 | 144 | def total_counts(self): 145 | """ 146 | Returns a sum of operation counts over this module and all 147 | ExtendedModule submodules. 148 | """ 149 | return sum(x.counts for x in self.extended_modules()) 150 | 151 | 152 | def numeric_tuple(x, length): 153 | """ 154 | Expands a single numeric value (int, float, complex, or bool) into a 155 | tuple of a specified length. If the value is not of the specified 156 | types, does nothing. 157 | 158 | :param x: The input value 159 | :param length: The length of tuple to return if x is of the 160 | specified types 161 | """ 162 | return (x,) * length if isinstance(x, (int, float, complex, bool)) else tuple(x) 163 | 164 | 165 | def dict_csv_header(x): 166 | """ 167 | Returns a CSV-header string containing the keys of a dict. 168 | 169 | :param x: A dict 170 | """ 171 | return ",".join(k for k in sorted(x.keys())) 172 | 173 | 174 | def dict_csv_line(x): 175 | """ 176 | Returns a CSV-content string containing the values of a dict. 177 | 178 | :param x: A dict 179 | """ 180 | return ",".join(f"{x[k]:g}" for k in sorted(x.keys())) 181 | 182 | 183 | def dict_string(x, indent=4, value_format=".4g"): 184 | """ 185 | Returns a human-readable string for a dict. 
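To make the arithmetic on the `Counts` container above concrete, here is a small self-contained example. The key name `"qkv_flops"` is an arbitrary illustration, not a key used by the library:
```
# Illustrative example of the Counts container defined above.
from eventful_transformer.base import Counts

c1 = Counts()
c1["qkv_flops"] += 100
c2 = Counts()
c2["qkv_flops"] += 50
total = sum([c1, c2])  # __radd__ lets sum() start from the integer 0
print(total.csv_header())  # qkv_flops
print(total.csv_line())  # 150
total.pretty_print()  # indented output with scientific-notation values
```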
186 | 187 | :param indent: Number of spaces to use for indents 188 | :param value_format: Number format for count values 189 | """ 190 | lines = [] 191 | key_length = max(len(str(key)) for key in x.keys()) 192 | format_str = " " * indent + f"{{:<{key_length + 1}}} {{:{value_format}}}" 193 | for key in sorted(x.keys()): 194 | lines.append(format_str.format(f"{key}:", x[key])) 195 | return "\n".join(lines) 196 | -------------------------------------------------------------------------------- /eventful_transformer/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | from math import sqrt, prod 5 | 6 | from eventful_transformer.base import ExtendedModule, numeric_tuple 7 | from eventful_transformer.counting import CountedAdd, CountedLinear, CountedMatmul 8 | from eventful_transformer.modules import ( 9 | SimpleSTGTGate, 10 | TokenBuffer, 11 | TokenDeltaGate, 12 | TokenGate, 13 | MatmulDeltaAccumulator, 14 | MatmulBuffer, 15 | ) 16 | from eventful_transformer.utils import ( 17 | DropPath, 18 | RelativePositionEmbedding, 19 | expand_row_index, 20 | ) 21 | from utils.image import pad_to_size 22 | 23 | LN_EPS = 1e-6 24 | 25 | 26 | class Block(ExtendedModule): 27 | """ 28 | Defines a base (non-eventful) Transformer block. Includes a couple 29 | of extra features: a simple implementation of Adaptive Token 30 | Sampling (ATS - Fayyaz et al. 2022) and self-attention pooling. 31 | These features are controlled via the ats_fraction and pool_size 32 | parameters. 33 | """ 34 | 35 | def __init__( 36 | self, 37 | dim, 38 | heads, 39 | input_size, 40 | mlp_ratio, 41 | ats_fraction=None, 42 | drop_path_rate=0.0, 43 | relative_embedding_size=None, 44 | matmul_2_cast=None, 45 | pool_size=None, 46 | window_size=None, 47 | ): 48 | """ 49 | :param dim: The number of dimensions in a token 50 | :param heads: The number of attention heads (None for no 51 | multi-headed attention) 52 | :param input_size: The expected size of the inputs in tokens 53 | :param mlp_ratio: The ratio of the MLP dimensionality to the 54 | token dimensionality 55 | :param ats_fraction: The fraction of tokens to retain if 56 | using Adaptive Token Sampling (ATS) 57 | :param drop_path_rate: Drop path ratio (for use when training) 58 | :param relative_embedding_size: The size (in tokens) assumed for 59 | relative position embeddings 60 | :param matmul_2_cast: Typecast for the attention-value product 61 | (None, "float16", or "bfloat16"). Helps save some memory when 62 | using an A-gate, without a noticeable impact on accuracy. 63 | :param pool_size: Pooling ratio to use with self-attention 64 | pooling. 65 | :param window_size: Self-attention window size (None to use 66 | global, non-windowed attention). 
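To make the expected tensor shapes concrete, here is a minimal sketch of running a plain (non-eventful) `Block`. The sizes are illustrative (a 14x14 token grid of 768-dimensional tokens) rather than values from a particular config, and the counted wrappers (`CountedLinear`, `CountedMatmul`, `CountedAdd`) are assumed to act as drop-in replacements for their standard counterparts, which is how they are used throughout this file:
```
# Illustrative sketch; the sizes are arbitrary examples.
import torch

from eventful_transformer.blocks import Block

block = Block(dim=768, heads=12, input_size=(14, 14), mlp_ratio=4)
tokens = torch.randn(1, 14 * 14, 768)  # (batch, token, dim)
output = block(tokens)  # same shape as the input
```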
67 | """ 68 | super().__init__() 69 | self.heads = heads 70 | self.input_size = tuple(input_size) 71 | if ats_fraction is not None: 72 | assert pool_size is None 73 | assert window_size is None 74 | assert not (ats_fraction < 0.0 or ats_fraction > 1.0) 75 | assert not (drop_path_rate < 0.0 or drop_path_rate > 1.0) 76 | assert matmul_2_cast in [None, "float16", "bfloat16"] 77 | self.ats_fraction = ats_fraction 78 | self.last_ats_indices = None 79 | self.matmul_2_cast = matmul_2_cast 80 | if pool_size is None: 81 | self.pool_size = None 82 | else: 83 | self.pool_size = numeric_tuple(pool_size, length=2) 84 | if window_size is None: 85 | self.window_size = None 86 | attention_size = input_size 87 | else: 88 | self.window_size = numeric_tuple(window_size, length=2) 89 | attention_size = self.window_size 90 | if relative_embedding_size is not None: 91 | relative_embedding_size = self.window_size 92 | self.scale = sqrt(dim // heads) 93 | 94 | # Set up submodules. 95 | self.input_layer_norm = nn.LayerNorm(dim, eps=LN_EPS) 96 | self.qkv = CountedLinear(in_features=dim, out_features=dim * 3) 97 | self.drop_path = ( 98 | DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() 99 | ) 100 | if relative_embedding_size is not None: 101 | self.relative_position = RelativePositionEmbedding( 102 | attention_size, 103 | relative_embedding_size, 104 | dim // heads, 105 | pool_size=self.pool_size, 106 | ) 107 | else: 108 | self.relative_position = None 109 | self.matmul = CountedMatmul() 110 | self.projection = CountedLinear(in_features=dim, out_features=dim) 111 | self.add = CountedAdd() 112 | self.mlp_layer_norm = nn.LayerNorm(dim, eps=LN_EPS) 113 | self.mlp_1 = CountedLinear(in_features=dim, out_features=dim * mlp_ratio) 114 | self.gelu = nn.GELU() 115 | self.mlp_2 = CountedLinear(in_features=dim * mlp_ratio, out_features=dim) 116 | 117 | def forward(self, x): 118 | skip_1 = x 119 | x = self.input_layer_norm(x) 120 | 121 | # Linearly project x into qkv space. 122 | x = self.qkv(x) 123 | 124 | # Compute attention on the qkv representation. 125 | x, ats_indices = self._forward_attention(x) 126 | skip_1 = self._gather_ats_skip(skip_1, ats_indices) 127 | 128 | # Apply the post-attention linear transform and add the skip. 129 | x = self.projection(x) 130 | x = self.add(self.drop_path(x), skip_1) 131 | 132 | # Apply the token-wise MLP. 133 | skip_2 = x 134 | x = self.mlp_layer_norm(x) 135 | x = self._forward_mlp(x) 136 | x = self.add(self.drop_path(x), skip_2) 137 | return x 138 | 139 | def reset_self(self): 140 | self.last_ats_indices = None 141 | 142 | # A simple version of the method from 143 | # "Adaptive Token Sampling for Efficient Vision Transformers" 144 | # (Fayyaz et al., ECCV 2022) 145 | # For now we just use the top-k version of ATS (select the tokens 146 | # with the k highest scores). Using CDF-based token sampling should 147 | # also be possible, but it would be more complex to implement (we 148 | # would need a mechanism for masking the K' < K active tokens in 149 | # gates and buffers). 150 | def _adaptive_token_sampling(self, a, v): 151 | if self.ats_fraction is None: 152 | return a, None 153 | 154 | class_scores = a[..., 0] 155 | raw_scores = class_scores * torch.linalg.vector_norm(v[...], dim=-1) 156 | scores = raw_scores / raw_scores[..., 1:].sum(dim=-1, keepdim=True) 157 | 158 | # Always select the class token. 159 | scores[..., 0] = float("inf") 160 | 161 | # Sum scores over heads. 
162 | scores = scores.sum(dim=-3) 163 | 164 | # Add +1 for the class token 165 | n_select = int(self.ats_fraction * (scores.shape[-1] - 1)) + 1 166 | 167 | # Select the k tokens with the highest scores. 168 | ats_indices = scores.topk(n_select, sorted=False)[1] 169 | 170 | # Sort the token indices (for stabilization). This seems to 171 | # work pretty well, although we could probably come up with 172 | # better/more sophisticated. E.g., we could try to find the 173 | # permutation of indices that minimized some norm between the 174 | # previous and current ats_indices. 175 | ats_indices = self._stabilize_ats_indices(ats_indices) 176 | self.last_ats_indices = ats_indices 177 | 178 | return ( 179 | a.gather(dim=-2, index=expand_row_index(ats_indices, a.shape)), 180 | ats_indices, 181 | ) 182 | 183 | def _cast_matmul_2(self, x, v): 184 | old_dtype = x.dtype 185 | if self.matmul_2_cast is not None: 186 | dtype = getattr(torch, self.matmul_2_cast) 187 | x = x.to(dtype) 188 | v = v.to(dtype) 189 | return x, v, old_dtype 190 | 191 | def _compute_window_padding(self): 192 | pad_h = -self.input_size[0] % self.window_size[0] 193 | pad_w = -self.input_size[1] % self.window_size[1] 194 | return pad_h, pad_w 195 | 196 | @staticmethod 197 | def _gather_ats_skip(skip_1, ats_indices): 198 | if ats_indices is None: 199 | return skip_1 200 | else: 201 | return skip_1.gather( 202 | dim=-2, index=expand_row_index(ats_indices, skip_1.shape) 203 | ) 204 | 205 | def _forward_attention(self, x): 206 | # (batch, token, dim) 207 | 208 | # Partition the windows and attention heads. _window_partition 209 | # is a noop if self.window_size is None. Windows are arranged 210 | # along the batch dimension. 211 | x = self._partition_windows(x, in_qkv_domain=True) 212 | q, k, v = self._partition_heads(x) 213 | # (batch, heads, token, dim / heads) 214 | 215 | # Token pooling is a noop if self.pool_size is None. 216 | k = self._pool_tokens(k) 217 | v = self._pool_tokens(v) 218 | 219 | # Perform the actual attention computation. 220 | # The output of this first matmul is huge - hence it's much 221 | # faster to scale one of the inputs than it is to scale the 222 | # output. 223 | x = self.matmul(q / self.scale, k.transpose(-2, -1)) 224 | if self.relative_position is not None: 225 | x = self.relative_position(x, q) 226 | x = x.softmax(dim=-1) 227 | 228 | # Adaptive token sampling is a noop if self.ats_fraction is None. 229 | x, ats_indices = self._adaptive_token_sampling(x, v) 230 | 231 | x, v, old_dtype = self._cast_matmul_2(x, v) 232 | x = self.matmul(x, v) 233 | # (batch, heads, token, dim / heads) 234 | 235 | x = self._recombine_heads(x) 236 | x = self._recombine_windows(x) 237 | x = self._uncast_matmul_2(x, old_dtype) 238 | # (batch, token, dim) 239 | 240 | return x, ats_indices 241 | 242 | def _forward_mlp(self, x): 243 | x = self.mlp_1(x) 244 | x = self.gelu(x) 245 | x = self.mlp_2(x) 246 | return x 247 | 248 | def _partition_heads(self, x): 249 | # (batch, token, dim) 250 | 251 | x = x.view(x.shape[:-1] + (3, self.heads, x.shape[-1] // (3 * self.heads))) 252 | q, k, v = x.permute(2, 0, 3, 1, 4) 253 | # (batch, heads, token, dim / heads) 254 | 255 | return q, k, v 256 | 257 | def _partition_windows(self, x, in_qkv_domain): 258 | if self.window_size is None: 259 | return x 260 | 261 | p = self._compute_window_padding() 262 | d = self.window_size 263 | # (batch, token, dim) 264 | 265 | # Unflatten the spatial dimensions. 
266 | x = x.view(x.shape[:1] + self.input_size + x.shape[2:]) 267 | # (batch, height, width, dim) 268 | 269 | if any(p): 270 | s = x.shape 271 | pad_tensor = torch.zeros( 272 | (1,) * (x.ndim - 1) + s[-1:], dtype=x.dtype, device=x.device 273 | ) 274 | 275 | # The attention computation expects padded tokens to equal 276 | # _forward_qkv(zero). If x has already been mapped to the 277 | # QKV domain, then we need to transform the padded zero 278 | # values to the QKV domain. Only the bias portion of the 279 | # linear transform has an effect on the zero padding vector. 280 | if in_qkv_domain: 281 | pad_tensor = self.qkv.forward_bias(pad_tensor) 282 | 283 | # Pad to a multiple of the window size. 284 | # func.pad seems broken (see the comments in pad_to_size). 285 | # In the meantime we'll use pad_to_size. 286 | # x = func.pad(x, (0, 0, 0, p[1], 0, p[0])) 287 | x = pad_to_size(x, (s[-3] + p[0], s[-2] + p[1], s[-1]), pad_tensor) 288 | # (batch, height, width, dim) 289 | 290 | # Partition into windows. 291 | s = x.shape 292 | x = x.view(-1, s[-3] // d[0], d[0], s[-2] // d[1], d[1], s[-1]) 293 | x = x.transpose(-3, -4) 294 | # (batch, window_y, window_x, token_y, token_x, dim) 295 | 296 | # Re-flatten the spatial dimensions. Can't use x.view here 297 | # because of the transpose. 298 | x = x.reshape(-1, prod(d), s[-1]) 299 | # (batch * window, token, dim) 300 | 301 | return x 302 | 303 | def _pool_tokens(self, x): 304 | # (batch, heads, token, dim) 305 | 306 | if self.pool_size is None: 307 | return x 308 | w = self.input_size if (self.window_size is None) else self.window_size 309 | s = x.shape 310 | 311 | # Can't use x.view here because of the permutation in 312 | # _partition_heads. 313 | x = x.reshape((-1,) + w + x.shape[-1:]) 314 | # (batch * heads, token_y, token_x, dim) 315 | 316 | x = x.permute(0, 3, 1, 2) 317 | x = func.avg_pool2d(x, self.pool_size) 318 | # (batch * heads, dim, token_y, token_x) 319 | 320 | x = x.permute(0, 2, 3, 1) 321 | # (batch * heads, token_y, token_x, dim) 322 | 323 | x = x.view(s[:-2] + (-1,) + s[-1:]) 324 | # (batch, heads, token, dim) 325 | 326 | return x 327 | 328 | @staticmethod 329 | def _recombine_heads(x): 330 | # (batch, heads, token, dim / heads) 331 | 332 | # Can't use x.view here because of the permutation. 333 | x = x.permute(0, 2, 1, 3) 334 | x_reshaped = x.reshape(x.shape[:-2] + (-1,)) 335 | # (batch, token, dim) 336 | 337 | # We assume that x.reshape actually copies the data. We can run 338 | # into problems if this is not the case, i.e., we may end up 339 | # with a gate being passed a raw reference to an accumulator 340 | # state. For an example, see EventfulMatmul1Block. 341 | assert x.data_ptr() != x_reshaped.data_ptr() 342 | x = x_reshaped 343 | 344 | return x 345 | 346 | def _recombine_windows(self, x): 347 | if self.window_size is None: 348 | return x 349 | 350 | p = self._compute_window_padding() 351 | d = self.window_size 352 | s = self.input_size 353 | total_h = p[0] + s[0] 354 | total_w = p[1] + s[1] 355 | # (batch * window, token, dim) 356 | 357 | # Unflatten the spatial dimensions. 358 | x = x.view(-1, total_h // d[0], total_w // d[1], d[0], d[1], x.shape[-1]) 359 | # (batch, window_y, window_x, token_y, token_x, dim) 360 | 361 | # Recombine the window partitions. Can't use x.view here because 362 | # of the transpose. 363 | x = x.transpose(-3, -4) 364 | x = x.reshape(-1, total_h, total_w, x.shape[-1]) 365 | # (batch, height, width, dim) 366 | 367 | # Remove padding. 
368 | if any(p): 369 | x = x[:, : s[0], : s[1]] 370 | # (batch, height, width, dim) 371 | 372 | # Re-flatten the spatial dimensions. 373 | x = x.flatten(start_dim=1, end_dim=2) 374 | # (batch, token, dim) 375 | 376 | return x 377 | 378 | def _stabilize_ats_indices(self, ats_indices): 379 | ats_indices = ats_indices.sort(dim=-1)[0] 380 | if self.last_ats_indices is None: 381 | return ats_indices 382 | 383 | # Faster on the CPU 384 | new_indices = ats_indices.flatten(end_dim=-2).cpu() 385 | old_indices = self.last_ats_indices.flatten(end_dim=-2).cpu() 386 | stabilized = old_indices.clone() 387 | for i in range(new_indices.shape[0]): 388 | old_not_in_new = torch.isin(old_indices[i], new_indices[i], invert=True) 389 | new_not_in_old = torch.isin(new_indices[i], old_indices[i], invert=True) 390 | stabilized[i, old_not_in_new] = new_indices[i, new_not_in_old] 391 | return stabilized.to(ats_indices.device).view(ats_indices.shape) 392 | 393 | def _uncast_matmul_2(self, x, old_dtype): 394 | if self.matmul_2_cast is not None: 395 | x = x.to(old_dtype) 396 | return x 397 | 398 | 399 | class EventfulTokenwiseBlock(Block): 400 | """ 401 | A Transformer block that adds eventfulness to token-wise operations. 402 | """ 403 | 404 | def __init__(self, gate_before_ln=False, stgt=False, **super_kwargs): 405 | """ 406 | :param gate_before_ln: Determines whether token gates are placed 407 | before or after layer norm operations 408 | :param stgt: Whether to use the SimpleSTGTGate (instead of our 409 | TokenGate) for benchmarking 410 | :param super_kwargs: Kwargs for the super class (Block) 411 | """ 412 | super().__init__(**super_kwargs) 413 | self.gate_before_ln = gate_before_ln 414 | token_gate_class = SimpleSTGTGate if stgt else TokenGate 415 | self.qkv_gate = token_gate_class() 416 | self.qkv_accumulator = TokenBuffer() 417 | self.projection_gate = token_gate_class() 418 | self.projection_accumulator = TokenBuffer() 419 | self.mlp_gate = token_gate_class() 420 | self.mlp_accumulator = TokenBuffer() 421 | 422 | def forward(self, x): 423 | skip_1, x, index = self._forward_pre_attention(x) 424 | x = self.qkv_accumulator(x, index) 425 | x, ats_indices = self._forward_attention(x) 426 | skip_1 = self._gather_ats_skip(skip_1, ats_indices) 427 | x = self._forward_post_attention(x, skip_1) 428 | return x 429 | 430 | def _forward_post_attention(self, x, skip_1): 431 | # Gate-accumulator block 2 432 | x, index = self.projection_gate(x) 433 | x = self.projection(x) 434 | x = self.projection_accumulator(x, index) 435 | 436 | x = self.add(self.drop_path(x), skip_1) 437 | skip_2 = x 438 | 439 | # Gate-accumulator block 3 440 | if self.gate_before_ln: 441 | x, index = self.mlp_gate(x) 442 | x = self.mlp_layer_norm(x) 443 | else: 444 | x = self.mlp_layer_norm(x) 445 | x, index = self.mlp_gate(x) 446 | x = self._forward_mlp(x) 447 | x = self.mlp_accumulator(x, index) 448 | x = self.add(self.drop_path(x), skip_2) 449 | 450 | return x 451 | 452 | def _forward_pre_attention(self, x): 453 | skip_1 = x 454 | 455 | # Gate-accumulator block 1 456 | if self.gate_before_ln: 457 | x, index = self.qkv_gate(x) 458 | x = self.input_layer_norm(x) 459 | else: 460 | x = self.input_layer_norm(x) 461 | x, index = self.qkv_gate(x) 462 | x = self.qkv(x) 463 | return skip_1, x, index 464 | 465 | 466 | class EventfulMatmul1Block(EventfulTokenwiseBlock): 467 | """ 468 | An EventfulTokenWiseBlock that adds eventfulness to the query-key 469 | product (in addition to token-wise operations). 
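# A toy illustration of the index stabilization performed by
# _stabilize_ats_indices above: indices that survive from the previous
# frame keep their old slots, and only newly selected indices move into
# the vacated slots. The index values here are made up.
import torch

old = torch.tensor([2, 5, 7, 9])    # previous frame's (sorted) selection
new = torch.tensor([3, 7, 9, 11])   # current frame's (sorted) selection
stabilized = old.clone()
old_gone = torch.isin(old, new, invert=True)  # 2 and 5 were dropped
new_come = torch.isin(new, old, invert=True)  # 3 and 11 were added
stabilized[old_gone] = new[new_come]
print(stabilized)  # tensor([ 3, 11,  7,  9]) -- 7 and 9 keep their slots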
470 | """ 471 | 472 | def __init__(self, **super_kwargs): 473 | """ 474 | :param super_kwargs: Kwargs for the super class ( 475 | EventfulTokenwiseBlock) 476 | """ 477 | super().__init__(**super_kwargs) 478 | 479 | # self._pool_index assumes that the input size is divisible by 480 | # the pooling size. 481 | if self.pool_size is not None: 482 | assert all(s % p == 0 for s, p in zip(self.input_size, self.pool_size)) 483 | 484 | # This class only supports non-windowed attention for now. 485 | assert self.window_size is None 486 | 487 | self.matmul_accumulator_1 = MatmulBuffer() 488 | 489 | def forward(self, x): 490 | skip_1, x, index = self._forward_pre_attention(x) 491 | x = self.qkv_accumulator(x, index) 492 | x, ats_indices = self._forward_attention((x, index)) 493 | skip_1 = self._gather_ats_skip(skip_1, ats_indices) 494 | x = self._forward_post_attention(x, skip_1) 495 | return x 496 | 497 | def _forward_attention(self, x): 498 | x, v, _ = self._forward_matmul_1(x) 499 | x, ats_indices = self._adaptive_token_sampling(x, v) 500 | x, v, old_dtype = self._cast_matmul_2(x, v) 501 | x = self.matmul(x, v) 502 | x = self._recombine_heads(x) 503 | x = self._uncast_matmul_2(x, old_dtype) 504 | return x, ats_indices 505 | 506 | def _forward_matmul_1(self, x): 507 | x, index = x 508 | q, k, v = self._partition_heads(x) 509 | k = self._pool_tokens(k) 510 | v = self._pool_tokens(v) 511 | index_k = self._pool_index(index) 512 | 513 | # See comment in Block._forward_attention. 514 | x = self.matmul_accumulator_1( 515 | q / self.scale, k.transpose(-2, -1), index, index_k 516 | ) 517 | 518 | if self.relative_position is not None: 519 | # We need inplace=False because x is a direct reference to 520 | # an accumulator state. 521 | x = self.relative_position(x, q, inplace=False) 522 | x = x.softmax(dim=-1) 523 | return x, v, index_k 524 | 525 | def _pool_index(self, index): 526 | if (self.pool_size is None) or (index is None): 527 | return index 528 | width = self.input_size[1] 529 | index_y = index.div(width, rounding_mode="floor") 530 | index_x = index.remainder(width) 531 | index_y = index_y.div(self.pool_size[0], rounding_mode="floor") 532 | index_x = index_x.div(self.pool_size[1], rounding_mode="floor") 533 | index = index_y * (width // self.pool_size[1]) + index_x 534 | 535 | # Calling .unique() still works if there are multiple items in 536 | # the batch. However, the output size along dim=-1 will be the 537 | # largest of the individual output sizes. This could result in 538 | # some redundant downstream computation. 539 | index = index.unique(dim=-1) 540 | return index 541 | 542 | 543 | class EventfulBlock(EventfulMatmul1Block): 544 | """ 545 | An EventfulMatmul1Block that also adds eventfulness to the 546 | attention-value product. 547 | """ 548 | def __init__(self, **super_kwargs): 549 | """ 550 | :param super_kwargs: Kwargs for the super class ( 551 | EventfulTokenwiseBlock) 552 | """ 553 | super().__init__(**super_kwargs) 554 | self.v_gate = TokenDeltaGate() 555 | self.matmul_gate = TokenDeltaGate(structure="col") 556 | self.matmul_accumulator_2 = MatmulDeltaAccumulator() 557 | 558 | def _forward_attention(self, a): 559 | a, v, index_k = self._forward_matmul_1(a) 560 | 561 | a, v, old_dtype = self._cast_matmul_2(a, v) 562 | a, ats_indices = self._adaptive_token_sampling(a, v) 563 | if not self.matmul_2_cast: 564 | # We clone v here because it may be a direct reference to 565 | # self.qkv_accumulator.a. 
566 | v = v.clone() 567 | v_n_tilde, v_delta_tilde, index_v = self.v_gate(v, forced_index=index_k) 568 | a_n_tilde, a_delta_tilde, _ = self.matmul_gate(a, forced_index=index_v) 569 | a = self.matmul_accumulator_2( 570 | a_n_tilde, v_n_tilde, a_delta_tilde, v_delta_tilde 571 | ) 572 | 573 | a = self._recombine_heads(a) 574 | a = self._uncast_matmul_2(a, old_dtype) 575 | return a, ats_indices 576 | -------------------------------------------------------------------------------- /eventful_transformer/counting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | from math import prod 5 | 6 | from eventful_transformer.base import ExtendedModule, numeric_tuple 7 | 8 | 9 | class CountedAdd(ExtendedModule): 10 | """ 11 | An addition operator that counts flops. 12 | """ 13 | 14 | def forward(self, a, b, inplace=False): 15 | if inplace: 16 | a += b 17 | result = a 18 | else: 19 | result = a + b 20 | if self.count_mode: 21 | self.counts["add_flops"] += result.numel() 22 | return result 23 | 24 | 25 | class CountedBias(ExtendedModule): 26 | """ 27 | A bias-addition module that counts flops. 28 | """ 29 | 30 | def __init__(self, features, spatial_dims=0, device=None, dtype=None): 31 | """ 32 | :param features: Dimensionality of the bias (size of feature 33 | dimension) 34 | :param spatial_dims: The number of trailing spatial dimensions 35 | of the input 36 | :param device: Bias device 37 | :param dtype: Bias data type 38 | """ 39 | super().__init__() 40 | self.features = features 41 | self.spatial_dims = spatial_dims 42 | self.bias = nn.Parameter(torch.zeros(features, device=device, dtype=dtype)) 43 | 44 | def forward(self, x): 45 | result = x + self.bias.view((self.features,) + (1,) * self.spatial_dims) 46 | if self.count_mode: 47 | self.counts["bias_flops"] += result.numel() 48 | return result 49 | 50 | 51 | class CountedConv(ExtendedModule): 52 | """ 53 | A convolution module that counts flops. 
54 | """ 55 | 56 | def __init__( 57 | self, 58 | spatial_dims, 59 | in_channels, 60 | out_channels, 61 | kernel_size, 62 | stride=1, 63 | padding=0, 64 | dilation=1, 65 | groups=1, 66 | device=None, 67 | dtype=None, 68 | ): 69 | """ 70 | :param spatial_dims: The number of spatial dims (e.g., 2 for 2D 71 | convolution) 72 | :param in_channels: The number of input channels 73 | :param out_channels: The number of output channels 74 | :param kernel_size: The kernel size (int or tuple) 75 | :param stride: The convolution stride (int or tuple) 76 | :param padding: The amount of padding 77 | :param dilation: Dilation ratio 78 | :param groups: Number of channel groups 79 | :param device: Convolution kernel device 80 | :param dtype: Convolution kernel data type 81 | """ 82 | super().__init__() 83 | self.spatial_dims = spatial_dims 84 | self.in_channels = in_channels 85 | self.out_channels = out_channels 86 | self.kernel_size = numeric_tuple(kernel_size, length=spatial_dims) 87 | self.stride = numeric_tuple(stride, length=spatial_dims) 88 | if isinstance(padding, int): 89 | self.padding = numeric_tuple(padding, length=spatial_dims) 90 | else: 91 | self.padding = padding 92 | self.dilation = numeric_tuple(dilation, length=spatial_dims) 93 | self.groups = groups 94 | self.conv_function = getattr(func, f"conv{self.spatial_dims}d") 95 | shape = (out_channels, in_channels // groups) + self.kernel_size 96 | self.weight = nn.Parameter(torch.zeros(shape, device=device, dtype=dtype)) 97 | 98 | def forward(self, x): 99 | result = self.conv_function( 100 | x, 101 | self.weight, 102 | stride=self.stride, 103 | padding=self.padding, 104 | dilation=self.dilation, 105 | groups=self.groups, 106 | ) 107 | if self.count_mode: 108 | fan_in = (self.in_channels // self.groups) * prod(self.kernel_size) 109 | self.counts[f"conv{self.spatial_dims}d_flops"] += result.numel() * fan_in 110 | return result 111 | 112 | 113 | class CountedEinsum(ExtendedModule): 114 | """ 115 | Einsum (Einstein summation) operation that counts flops. 116 | """ 117 | 118 | def forward(self, equation, *operands): 119 | if self.count_mode: 120 | # There might be some cases here I haven't considered. But 121 | # this works fine for inner products. 122 | op_map = torch.einsum(equation, *[torch.ones_like(x) for x in operands]) 123 | self.counts["einsum_flops"] += int(op_map.sum()) 124 | return torch.einsum(equation, *operands) 125 | 126 | 127 | class CountedLinear(ExtendedModule): 128 | """ 129 | Linear transform operation that counts flops. 
130 | """ 131 | 132 | def __init__(self, in_features, out_features, device=None, dtype=None): 133 | """ 134 | :param in_features: Dimensionality of input vectors 135 | :param out_features: Dimensionality of output vectors 136 | :param device: Transform matrix device 137 | :param dtype: Transform matrix data type 138 | """ 139 | super().__init__() 140 | self.in_features = in_features 141 | self.out_features = out_features 142 | shape = (out_features, in_features) 143 | self.weight = nn.Parameter(torch.zeros(shape, device=device, dtype=dtype)) 144 | self.bias = nn.Parameter(torch.zeros(out_features, device=device, dtype=dtype)) 145 | 146 | def forward_bias(self, x): 147 | result = x + self.bias 148 | if self.count_mode: 149 | self.counts["bias_flops"] += result.numel() 150 | return result 151 | 152 | def forward_linear(self, x): 153 | if self.count_mode: 154 | self.counts["linear_flops"] += x.numel() * self.out_features 155 | return func.linear(x, self.weight) 156 | 157 | def forward(self, x): 158 | result = func.linear(x, self.weight, self.bias) 159 | if self.count_mode: 160 | self.counts["bias_flops"] += result.numel() 161 | self.counts["linear_flops"] += x.numel() * self.out_features 162 | return result 163 | 164 | 165 | class CountedMatmul(ExtendedModule): 166 | """ 167 | Matrix multiplication operation that counts flops. We assume a 168 | batched 2D matrix multiplication. 169 | """ 170 | 171 | def forward(self, a, b): 172 | result = a @ b 173 | if self.count_mode: 174 | self.counts["matmul_flops"] += result.numel() * a.shape[-1] 175 | return result 176 | -------------------------------------------------------------------------------- /eventful_transformer/modules.py: -------------------------------------------------------------------------------- 1 | from eventful_transformer.base import ExtendedModule 2 | from eventful_transformer.counting import CountedMatmul 3 | from eventful_transformer.utils import expand_col_index, expand_row_index 4 | 5 | 6 | class SimpleSTGTGate(ExtendedModule): 7 | """ 8 | This class implements a simple version of the gating logic described 9 | in "Spatio-Temporal Gated Transformers for Efficient Video 10 | Processing". This is intended to be used as an experimental 11 | baseline. 12 | """ 13 | 14 | def __init__(self, structure="row"): 15 | """ 16 | :param structure: Options other than structure="row" have not 17 | yet been implemented 18 | """ 19 | super().__init__() 20 | 21 | # Currently, 22 | assert structure == "row" 23 | 24 | self.first = True 25 | self.policy = None 26 | self.p = None 27 | 28 | def forward(self, c): 29 | if self.first: 30 | return self.forward_first(c) 31 | else: 32 | return self.forward_incremental(c) 33 | 34 | def forward_first(self, c): 35 | self.first = False 36 | self.p = c 37 | return c, None 38 | 39 | def forward_incremental(self, c): 40 | if self.count_mode: 41 | self.counts["gate_flops"] += c.numel() 42 | index = self.policy(c - self.p, dim=-1) 43 | c_tilde = c.gather(dim=-2, index=expand_row_index(index, c.shape)) 44 | self.p = c 45 | return c_tilde, index 46 | 47 | def reset_self(self): 48 | self.first = True 49 | self.p = None 50 | 51 | 52 | class TokenBuffer(ExtendedModule): 53 | """ 54 | Defines a token buffer. 
55 | """ 56 | 57 | def __init__(self, structure="row"): 58 | """ 59 | :param structure: Whether tokens should be indexed along the 60 | last ("col") or second-to-last ("row") dimension 61 | """ 62 | super().__init__() 63 | assert structure in ["row", "col"] 64 | self.structure = structure 65 | self.first = True 66 | self.b = None 67 | 68 | def forward(self, x, index): 69 | """ 70 | Warning - the output is a direct reference to self.b (a state 71 | tensor). 72 | """ 73 | if self.first: 74 | return self.forward_first(x) 75 | else: 76 | return self.forward_incremental(x, index) 77 | 78 | def forward_first(self, x): 79 | """ 80 | Forward pass on the first time step (flush). 81 | """ 82 | self.first = False 83 | self.b = x.clone() 84 | return self.b 85 | 86 | def forward_incremental(self, x, index): 87 | """ 88 | Forward pass after the first time step (incremental update). 89 | """ 90 | if self.structure == "row": 91 | index = expand_row_index(index, self.b.shape) 92 | dim = -2 93 | else: 94 | index = expand_col_index(index, self.b.shape) 95 | dim = -1 96 | self.b.scatter_(dim=dim, index=index, src=x) 97 | return self.b 98 | 99 | def reset_self(self): 100 | self.first = True 101 | self.b = None 102 | 103 | 104 | class TokenGate(ExtendedModule): 105 | """ 106 | Defines a token gate. 107 | 108 | TokenGate.policy defines the token selection policy. 109 | """ 110 | 111 | def __init__(self, structure="row"): 112 | """ 113 | :param structure: Whether tokens should be indexed along the 114 | last ("col") or second-to-last ("row") dimension 115 | """ 116 | super().__init__() 117 | assert structure in ["row", "col"] 118 | self.structure = structure 119 | self.first = True 120 | self.policy = None 121 | self.p = None 122 | 123 | def forward(self, c, forced_index=None): 124 | """ 125 | :param c: Warning - self.p (a state tensor) retains a direct 126 | reference to the last value of this input 127 | :param forced_index: A set of indices to force-update (instead 128 | of letting the policy decide) 129 | """ 130 | if self.first: 131 | return self.forward_first(c) 132 | else: 133 | return self.forward_incremental(c, forced_index=forced_index) 134 | 135 | def forward_first(self, c): 136 | """ 137 | Forward pass on the first time step (flush). 138 | """ 139 | self.first = False 140 | self.p = c 141 | return c, None 142 | 143 | def forward_incremental(self, c, forced_index=None): 144 | """ 145 | Forward pass after the first time step (incremental update). 146 | """ 147 | if self.count_mode: 148 | self.counts["gate_flops"] += self.p.numel() 149 | dim, expanded, index = self._apply_policy(c - self.p, forced_index) 150 | c_tilde = c.gather(dim=dim, index=expanded) 151 | self.p.scatter_(dim=dim, index=expanded, src=c_tilde) 152 | return c_tilde, index 153 | 154 | def _apply_policy(self, x, forced_index): 155 | dim = -2 if (self.structure == "row") else -1 156 | if forced_index is None: 157 | index = self.policy(x, dim=(-1 if (self.structure == "row") else -2)) 158 | else: 159 | index = forced_index 160 | if self.structure == "row": 161 | expanded = expand_row_index(index, x.shape) 162 | else: 163 | expanded = expand_col_index(index, x.shape) 164 | return dim, expanded, index 165 | 166 | def reset_self(self): 167 | self.first = True 168 | self.p = None 169 | 170 | 171 | class TokenDeltaGate(TokenGate): 172 | """ 173 | Defines a token delta gate. 
174 | """ 175 | 176 | def __init__(self, structure="row"): 177 | """ 178 | :param structure: Whether tokens should be indexed along the 179 | last ("col") or second-to-last ("row") dimension 180 | """ 181 | super().__init__(structure=structure) 182 | 183 | def forward_first(self, c): 184 | c = super().forward_first(c)[0] 185 | return c, None, None 186 | 187 | def forward_incremental(self, c, forced_index=None): 188 | """ 189 | :param c: Warning - self.p (a state tensor) retains a direct 190 | reference to the last value of this input 191 | :param forced_index: A set of indices to force-update (instead 192 | of letting the policy decide) 193 | """ 194 | if self.count_mode: 195 | self.counts["gate_flops"] += self.p.numel() 196 | e = c - self.p 197 | dim, expanded, index = self._apply_policy(e, forced_index) 198 | c_tilde = c.gather(dim=dim, index=expanded) 199 | e_tilde = e.gather(dim=dim, index=expanded) 200 | self.p.scatter_(dim=dim, index=expanded, src=c_tilde) 201 | return c_tilde, e_tilde, index 202 | 203 | 204 | class MatmulBuffer(ExtendedModule): 205 | """ 206 | Defines a buffer for updating the query-key product. 207 | """ 208 | def __init__(self): 209 | super().__init__() 210 | self.first = True 211 | self.product = None 212 | self.matmul = CountedMatmul() 213 | 214 | def forward(self, q, k, index_q, index_k): 215 | """ 216 | Warning - the output is a direct reference to self.product (a 217 | state tensor). 218 | """ 219 | if self.first: 220 | return self.forward_first(q, k) 221 | else: 222 | return self.forward_incremental(q, k, index_q, index_k) 223 | 224 | def forward_first(self, q, k): 225 | """ 226 | Forward pass on the first time step (flush). 227 | """ 228 | self.first = False 229 | self.product = self.matmul(q, k) 230 | return self.product 231 | 232 | def forward_incremental(self, q, k, index_q, index_k): 233 | """ 234 | Forward pass after the first time step (incremental update). 235 | """ 236 | q_tilde = q.gather(dim=-2, index=expand_row_index(index_q, q.shape)) 237 | k_tilde = k.gather(dim=-1, index=expand_col_index(index_k, k.shape)) 238 | self.product.scatter_( 239 | dim=-2, 240 | index=expand_row_index(index_q, self.product.shape), 241 | src=self.matmul(q_tilde, k), 242 | ) 243 | self.product.scatter_( 244 | dim=-1, 245 | index=expand_col_index(index_k, self.product.shape), 246 | src=self.matmul(q, k_tilde), 247 | ) 248 | return self.product 249 | 250 | def reset_self(self): 251 | self.first = True 252 | self.product = None 253 | 254 | 255 | class MatmulDeltaAccumulator(ExtendedModule): 256 | """ 257 | Defines a buffer for updating the attention-value product. 258 | """ 259 | def __init__(self): 260 | super().__init__() 261 | self.first = True 262 | self.product = None 263 | self.matmul = CountedMatmul() 264 | 265 | def forward(self, a_n_tilde, v_n_tilde, a_delta_tilde, v_delta_tilde): 266 | """ 267 | Warning - the output is a direct reference to self.product (a 268 | state tensor). 269 | """ 270 | if self.first: 271 | return self.forward_first(a_n_tilde, v_n_tilde) 272 | else: 273 | return self.forward_incremental( 274 | a_n_tilde, v_n_tilde, a_delta_tilde, v_delta_tilde 275 | ) 276 | 277 | def forward_first(self, a, v): 278 | """ 279 | Forward pass on the first time step (flush). 280 | """ 281 | self.first = False 282 | self.product = self.matmul(a, v) 283 | return self.product 284 | 285 | def forward_incremental(self, a_n_tilde, v_n_tilde, a_delta_tilde, v_delta_tilde): 286 | """ 287 | Forward pass after the first time step (incremental update). 
288 | """ 289 | if self.count_mode: 290 | self.counts["accumulator_flops"] += ( 291 | v_n_tilde.numel() + 2 * self.product.numel() 292 | ) 293 | self.product += self.matmul(a_n_tilde, v_delta_tilde) 294 | self.product += self.matmul(a_delta_tilde, v_n_tilde - v_delta_tilde) 295 | return self.product 296 | 297 | def reset_self(self): 298 | self.first = True 299 | self.product = None 300 | -------------------------------------------------------------------------------- /eventful_transformer/policies.py: -------------------------------------------------------------------------------- 1 | from torch.linalg import vector_norm 2 | 3 | from eventful_transformer.base import ExtendedModule 4 | 5 | 6 | class TokenNormThreshold(ExtendedModule): 7 | """ 8 | Defines a policy that selects tokens whose error norm exceeds a 9 | threshold. 10 | """ 11 | def __init__(self, threshold=0.0, order=2): 12 | """ 13 | :param threshold: The token norm threshold 14 | :param order: The type of norm (e.g., 2 for L2 norm) 15 | """ 16 | super().__init__() 17 | self.threshold = threshold 18 | self.order = order 19 | 20 | def forward(self, x, dim=-1): 21 | """ 22 | :param x: A tensor of token errors 23 | :param dim: The dimension along which we should reduce the norm 24 | """ 25 | assert all(size == 1 for size in x.shape[:-2]) 26 | 27 | # Note: The call to nonzero is very slow. 28 | index = vector_norm(x, ord=self.order, dim=dim).gt(self.threshold).nonzero() 29 | 30 | # Note: We assume here that all the leading dimensions (i.e., 31 | # the batch dimension) have size 1. See the assertion above. 32 | return index[..., -1].view((1,) * (x.ndim - 2) + (-1,)) 33 | 34 | # Alternative: 35 | # norm = vector_norm(x, ord=self.order, dim=-1) 36 | # return norm.topk(norm.gt(self.threshold).sum(), sorted=False)[1] 37 | 38 | 39 | class TokenNormTopK(ExtendedModule): 40 | """ 41 | Defines a policy that selects the k tokens with the largest error 42 | norm. 43 | """ 44 | def __init__(self, k, order=2, save_status=False): 45 | """ 46 | :param k: Select k tokens 47 | :param order: The type of norm (e.g., 2 for L2 norm) 48 | :param save_status: Cache inputs and outputs for debugging and 49 | visualization 50 | """ 51 | super().__init__() 52 | self.k = k 53 | self.order = order 54 | self.save_status = save_status 55 | self.last_input = None 56 | self.last_output = None 57 | 58 | def forward(self, x, dim=-1): 59 | """ 60 | :param x: A tensor of token errors 61 | :param dim: The dimension along which we should reduce the norm 62 | """ 63 | output = vector_norm(x, ord=self.order, dim=dim).topk(self.k, sorted=False)[1] 64 | if self.save_status: 65 | # Clone to protect against external modification. 66 | self.last_input = x.clone() 67 | self.last_output = output.clone() 68 | return output 69 | 70 | 71 | class TokenNormTopFraction(ExtendedModule): 72 | """ 73 | Defines a policy that selects some fraction of tokens with the 74 | largest error norm. 
75 | """ 76 | def __init__(self, fraction, order=2): 77 | """ 78 | :param fraction: Select this fraction of tokens (e.g., 0.5 for 79 | half of tokens) 80 | :param order: The type of norm (e.g., 2 for L2 norm) 81 | """ 82 | super().__init__() 83 | assert not (fraction < 0.0 or fraction > 1.0) 84 | self.fraction = fraction 85 | self.order = order 86 | 87 | 88 | def forward(self, x, dim=-1): 89 | """ 90 | :param x: A tensor of token errors 91 | :param dim: The dimension along which we should reduce the norm 92 | """ 93 | x_norm = vector_norm(x, ord=self.order, dim=dim) 94 | k = int(self.fraction * x_norm.shape[-1]) 95 | return x_norm.topk(k, sorted=False)[1] 96 | -------------------------------------------------------------------------------- /eventful_transformer/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from math import prod 3 | from torch import nn as nn 4 | from torch.nn import functional as func 5 | 6 | from eventful_transformer.base import ExtendedModule 7 | from eventful_transformer.counting import CountedAdd, CountedEinsum 8 | 9 | 10 | class DropPath(ExtendedModule): 11 | """ 12 | Defines a drop-path module. 13 | 14 | Reference: https://github.com/alibaba-mmai-research/TAdaConv/blob/main/models/base/base_blocks.py 15 | """ 16 | def __init__(self, drop_rate): 17 | """ 18 | :param drop_rate: Fraction that should be dropped 19 | """ 20 | super().__init__() 21 | self.drop_rate = drop_rate 22 | 23 | def forward(self, x): 24 | if not self.training: 25 | return x 26 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) 27 | keep_mask = torch.rand(shape, device=x.device) > self.drop_rate 28 | output = x.div(1.0 - self.drop_rate) * keep_mask.to(x.dtype) 29 | return output 30 | 31 | 32 | class PositionEncoding(ExtendedModule): 33 | """ 34 | Defines a position encoding module. 35 | """ 36 | def __init__(self, dim, encoding_size, input_size, has_class_token): 37 | """ 38 | :param dim: The dimensionality of token vectors 39 | :param encoding_size: The size (in tokens) assumed for position 40 | encodings 41 | :param input_size: The expected size of the inputs in tokens 42 | :param has_class_token: Whether the input has a class token 43 | """ 44 | super().__init__() 45 | self.encoding_size = tuple(encoding_size) 46 | self.input_size = tuple(input_size) 47 | self.has_class_token = has_class_token 48 | tokens = prod(self.encoding_size) + int(has_class_token) 49 | self.encoding = nn.Parameter(torch.zeros(1, tokens, dim)) 50 | self.add = CountedAdd() 51 | self.cached_encoding = None 52 | 53 | def forward(self, x): 54 | if self.training: 55 | self.cached_encoding = None 56 | encoding = self._compute_sized_encoding() 57 | else: 58 | # Cache the resized encoding during inference (assuming the 59 | # weights don't change, its value doesn't change between 60 | # model invocations). 61 | if self.cached_encoding is None: 62 | self.cached_encoding = self._compute_sized_encoding() 63 | encoding = self.cached_encoding 64 | 65 | # Add the position encoding. 66 | x = self.add(x, encoding) 67 | return x 68 | 69 | def _compute_sized_encoding(self): 70 | encoding = self.encoding 71 | 72 | # Interpolate the position encoding if needed. 73 | if self.input_size != self.encoding_size: 74 | # (batch, patch, dim) 75 | 76 | if self.has_class_token: 77 | # The class token comes *first* (see ViViTSubModel). 
78 | class_token = encoding[:, :1] 79 | encoding = encoding[:, 1:] 80 | else: 81 | class_token = None 82 | encoding = encoding.transpose(1, 2) 83 | encoding = encoding.view(encoding.shape[:-1] + self.encoding_size) 84 | # (batch, dim) + encoding_size 85 | 86 | # Note: We do not count operations from this interpolation, 87 | # even though it is in the backbone. This is because the 88 | # cost of interpolating is amortized over many invocations. 89 | encoding = func.interpolate( 90 | encoding, self.input_size, mode="bicubic", align_corners=False 91 | ) 92 | # (batch, dim) + embedding_size 93 | 94 | encoding = encoding.flatten(start_dim=2) 95 | encoding = encoding.transpose(1, 2) 96 | if self.has_class_token: 97 | encoding = torch.concat([class_token, encoding], dim=1) 98 | # (batch, patch, dim) 99 | 100 | return torch.Tensor(encoding) 101 | 102 | def reset_self(self): 103 | # Clear the cached value of sized_encoding whenever the model is 104 | # reset (just in case new weights get loaded). 105 | self.cached_encoding = None 106 | 107 | 108 | class RelativePositionEmbedding(ExtendedModule): 109 | """ 110 | Defines relative position embeddings. 111 | """ 112 | def __init__(self, attention_size, embedding_size, head_dim, pool_size=None): 113 | """ 114 | :param attention_size: The expected size of the attention window 115 | :param embedding_size: The size (in tokens) assumed for position 116 | embeddings 117 | :param head_dim: The dimensionality of each attention head 118 | :param pool_size: The pooling size (if self-attention pooling is 119 | being used - see the pool_size parameter to Block. 120 | """ 121 | super().__init__() 122 | self.attention_size = attention_size 123 | self.embedding_size = embedding_size 124 | self.pool_size = pool_size 125 | self.y_embedding = nn.Parameter( 126 | torch.zeros(2 * embedding_size[0] - 1, head_dim) 127 | ) 128 | self.x_embedding = nn.Parameter( 129 | torch.zeros(2 * embedding_size[1] - 1, head_dim) 130 | ) 131 | self.add = CountedAdd() 132 | self.einsum = CountedEinsum() 133 | self.y_relative = None 134 | self.x_relative = None 135 | 136 | # This is based on the add_decomposed_rel_pos function here: 137 | # https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/utils.py 138 | # noinspection PyTypeChecker 139 | def forward(self, x, q, inplace=True): 140 | a = self.attention_size 141 | 142 | # Unflatten the spatial dimensions. 143 | if self.pool_size is None: 144 | p = a 145 | else: 146 | p = (a[0] // self.pool_size[0], a[1] // self.pool_size[1]) 147 | x = x.view(x.shape[:2] + a + p) 148 | q = q.view(q.shape[:2] + a + q.shape[-1:]) 149 | 150 | # Apply the relative position embedding. 151 | if self.y_relative is None: 152 | # Cache y_relative and x_relative (assuming the weights 153 | # don't change, their values don't change between model 154 | # invocations). 155 | self.y_relative = self._get_relative(self.y_embedding, dim=0) 156 | self.x_relative = self._get_relative(self.x_embedding, dim=1) 157 | x = self.add( 158 | x, 159 | self.einsum("abhwc,hkc->abhwk", q, self.y_relative).unsqueeze(dim=-1), 160 | inplace=inplace, 161 | ) 162 | x = self.add( 163 | x, 164 | self.einsum("abhwc,wkc->abhwk", q, self.x_relative).unsqueeze(dim=-2), 165 | inplace=True, 166 | ) 167 | 168 | # Re-flatten the spatial dimensions. 
169 | x = x.view(x.shape[:2] + (prod(a), prod(p))) 170 | 171 | return x 172 | 173 | # This is a simplification of the get_rel_pos function here: 174 | # https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/utils.py 175 | def _get_relative(self, embedding, dim): 176 | range_0 = torch.arange(self.embedding_size[dim]).unsqueeze(dim=1) 177 | range_1 = torch.arange(self.embedding_size[dim]).unsqueeze(dim=0) 178 | relative = embedding[range_0 - range_1 + self.embedding_size[dim] - 1] 179 | if self.embedding_size != self.attention_size: 180 | relative = relative.transpose(0, 2).unsqueeze(dim=0) 181 | relative = func.interpolate( 182 | relative, self.attention_size, mode="bicubic", align_corners=False 183 | ) 184 | relative = relative.squeeze(dim=0).transpose(0, 2) 185 | if self.pool_size is not None: 186 | relative = relative.transpose(1, 2) 187 | relative = func.avg_pool1d(relative, self.pool_size[dim]) 188 | relative = relative.transpose(1, 2) 189 | return relative 190 | 191 | def reset_self(self): 192 | # Clear the cached values of x_relative and y_relative whenever 193 | # the model is reset (just in case new weights get loaded). 194 | self.y_relative = None 195 | self.x_relative = None 196 | 197 | 198 | def expand_col_index(index, target_shape): 199 | old_shape = index.shape 200 | new_dims = len(target_shape) - index.ndim 201 | index = index.view(old_shape[:-1] + (1,) * new_dims + old_shape[-1:]) 202 | index = index.expand(target_shape[:-1] + (-1,)) 203 | return index 204 | 205 | 206 | def expand_row_index(index, target_shape): 207 | old_shape = index.shape 208 | new_dims = len(target_shape) - index.ndim 209 | index = index.view(old_shape[:-1] + (1,) * (new_dims - 1) + (old_shape[-1], 1)) 210 | index = index.expand(target_shape[:-2] + (-1, target_shape[-1])) 211 | return index 212 | -------------------------------------------------------------------------------- /models/vitdet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from detectron2.config import LazyConfig, instantiate 3 | from detectron2.structures import ImageList 4 | from torchvision.transforms import Normalize 5 | 6 | from eventful_transformer.backbones import ViTBackbone 7 | from eventful_transformer.base import ExtendedModule, numeric_tuple 8 | from eventful_transformer.blocks import LN_EPS 9 | from utils.image import as_float32, pad_to_size 10 | 11 | 12 | # Resources consulted: 13 | # https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/utils.py 14 | # https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py 15 | 16 | 17 | class LinearEmbedding(nn.Module): 18 | """ 19 | The initial linear patch-embedding layer for ViTDet. Linearly 20 | transforms each input patch into a token vector. 21 | """ 22 | 23 | def __init__(self, input_channels, dim, patch_size): 24 | """ 25 | :param input_channels: The number of image channels (e.g., 3 for 26 | RGB images) 27 | :param dim: The dimensionality of token vectors 28 | :param patch_size: The patch size for each token (a 2-element 29 | tuple/list) 30 | """ 31 | super().__init__() 32 | self.conv = nn.Conv2d( 33 | in_channels=input_channels, 34 | out_channels=dim, 35 | kernel_size=patch_size, 36 | stride=patch_size, 37 | ) 38 | 39 | def forward(self, x): 40 | # (batch, dim, height, width) 41 | 42 | x = self.conv(x) 43 | # (batch, dim, height, width) 44 | 45 | # Flatten the spatial axes. 
46 | x = x.flatten(start_dim=-2) 47 | # (batch, dim, patch) 48 | 49 | x = x.transpose(1, 2) 50 | # (batch, patch, dim) 51 | 52 | return x 53 | 54 | 55 | class PointwiseLayerNorm2d(nn.LayerNorm): 56 | """ 57 | A LayerNorm operation which performs x.permute(0, 2, 3, 1) before 58 | applying the normalization. The permutation is inverted after 59 | normalization. 60 | """ 61 | 62 | def forward(self, x): 63 | # (batch, dim, height, width) 64 | 65 | x = x.permute(0, 2, 3, 1) 66 | # (batch, height, width, dim) 67 | 68 | x = super().forward(x) 69 | x = x.permute(0, 3, 1, 2) 70 | # (batch, dim, height, width) 71 | 72 | return x 73 | 74 | 75 | class SimplePyramid(nn.Module): 76 | """ 77 | The ViTDet feature pyramid (precedes the object detection head). 78 | """ 79 | 80 | def __init__(self, scale_factors, dim, out_channels): 81 | """ 82 | :param scale_factors: A list of spatial scale factors 83 | :param dim: The dimensionality of token vectors in the 84 | Transformer backbone 85 | :param out_channels: The number of output channels (the number 86 | of channels expected by the object detection head) 87 | """ 88 | super().__init__() 89 | self.stages = nn.ModuleList( 90 | self._build_scale(scale, dim, out_channels) for scale in scale_factors 91 | ) 92 | self.max_pool = nn.MaxPool2d(kernel_size=1, stride=2, padding=0) 93 | 94 | def forward(self, x): 95 | x = [stage(x) for stage in self.stages] 96 | x.append(self.max_pool(x[-1])) 97 | return x 98 | 99 | @staticmethod 100 | def _build_scale(scale, dim, out_channels): 101 | assert scale in [4.0, 2.0, 1.0, 0.5] 102 | if scale == 0.5: 103 | mid_dim = dim 104 | start_layers = [nn.MaxPool2d(kernel_size=2, stride=2)] 105 | elif scale == 1.0: 106 | mid_dim = dim 107 | start_layers = [] 108 | elif scale == 2.0: 109 | mid_dim = dim // 2 110 | start_layers = [nn.ConvTranspose2d(dim, mid_dim, kernel_size=2, stride=2)] 111 | else: # scale == 4.0 112 | mid_dim = dim // 4 113 | start_layers = [ 114 | nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2), 115 | PointwiseLayerNorm2d(dim // 2, eps=LN_EPS), 116 | nn.GELU(), 117 | nn.ConvTranspose2d(dim // 2, mid_dim, kernel_size=2, stride=2), 118 | ] 119 | common_layers = [ 120 | nn.Conv2d(mid_dim, out_channels, kernel_size=1, bias=False), 121 | PointwiseLayerNorm2d(out_channels, eps=LN_EPS), 122 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), 123 | PointwiseLayerNorm2d(out_channels, eps=LN_EPS), 124 | ] 125 | return nn.Sequential(*start_layers, *common_layers) 126 | 127 | 128 | class ViTDet(ExtendedModule): 129 | """ 130 | The ViTDet object detection Transformer model. See 131 | configs/models/vitdet_b_coco.yml for an example configuration. 
132 | """ 133 | 134 | def __init__( 135 | self, 136 | backbone_config, 137 | classes, 138 | detectron2_config, 139 | input_shape, 140 | normalize_mean, 141 | normalize_std, 142 | output_channels, 143 | patch_size, 144 | scale_factors, 145 | ): 146 | """ 147 | :param backbone_config: A dict containing kwargs for the 148 | backbone constructor 149 | :param classes: The number of object classes 150 | :param detectron2_config: Path of a Python file containing a 151 | Detectron2 config for the detection head 152 | :param input_shape: The (c, h, w) shape for inputs (the 153 | preprocessing spatially pads inputs to this shape) 154 | :param normalize_mean: The mean to use with 155 | torchvision.transforms.Normalize 156 | :param normalize_std: The standard deviation to use with 157 | torchvision.transforms.Normalize 158 | :param output_channels: The number of channels expected by the 159 | object detection head 160 | :param patch_size: The patch size for each token (a 2-element 161 | tuple/list) 162 | :param scale_factors: Scale factors for the SimplePyramid module 163 | """ 164 | super().__init__() 165 | input_c, input_h, input_w = input_shape 166 | patch_size = numeric_tuple(patch_size, length=2) 167 | self.backbone_input_size = (input_h // patch_size[0], input_w // patch_size[1]) 168 | 169 | # Set up submodules. 170 | self.preprocessing = ViTDetPreprocessing( 171 | input_shape, normalize_mean, normalize_std 172 | ) 173 | dim = backbone_config["block_config"]["dim"] 174 | self.embedding = LinearEmbedding(input_c, dim, patch_size) 175 | self.backbone = ViTBackbone( 176 | input_size=self.backbone_input_size, 177 | **backbone_config, 178 | ) 179 | self.pyramid = SimplePyramid(scale_factors, dim, output_channels) 180 | detectron2_config = LazyConfig.load(detectron2_config)["model"] 181 | self.proposal_generator = instantiate(detectron2_config["proposal_generator"]) 182 | roi_heads_config = detectron2_config["roi_heads"] 183 | roi_heads_config["num_classes"] = classes 184 | self.roi_heads = instantiate(roi_heads_config) 185 | 186 | def forward(self, x): 187 | images, x = self.pre_backbone(x) 188 | x = self.backbone(x) 189 | results = self.post_backbone(images, x) 190 | return results 191 | 192 | def post_backbone(self, images, x): 193 | """ 194 | Computes the portion of the model after the Transformer 195 | backbone. 196 | """ 197 | x = x.transpose(-1, -2) 198 | x = x.view(x.shape[:-1] + self.backbone_input_size) 199 | x = self.pyramid(x) 200 | 201 | # Compute region proposals and bounding boxes. 202 | x = dict(zip(self.proposal_generator.in_features, x)) 203 | proposals = self.proposal_generator(images, x, None)[0] 204 | result = self.roi_heads(images, x, proposals, None)[0] 205 | result = [ 206 | {"boxes": y.pred_boxes.tensor, "scores": y.scores, "labels": y.pred_classes} 207 | for y in result 208 | ] 209 | return result 210 | 211 | def pre_backbone(self, x): 212 | """ 213 | Computes the portion of the model before the Transformer 214 | backbone. 215 | """ 216 | x = as_float32(x) # Range [0, 1] 217 | x = self.preprocessing(x) 218 | images = ImageList.from_tensors([x]) 219 | x = self.embedding(x) 220 | return images, x 221 | 222 | 223 | class ViTDetPreprocessing(nn.Module): 224 | """ 225 | Preprocessing for ViTDet. Applies value normalization and square 226 | padding. Expects inputs scaled to the range [0, 1]. 
227 | """ 228 | 229 | def __init__(self, input_shape, normalize_mean, normalize_std): 230 | """ 231 | :param input_shape: The (c, h, w) shape to which inputs should 232 | be padded 233 | :param normalize_mean: The mean to use with 234 | torchvision.transforms.Normalize 235 | :param normalize_std: The standard deviation to use with 236 | torchvision.transforms.Normalize 237 | """ 238 | super().__init__() 239 | self.input_shape = tuple(input_shape) 240 | self.normalization = Normalize(normalize_mean, normalize_std) 241 | 242 | def forward(self, x): 243 | # This normalization assumes x in the range [0, 255], but the 244 | # parent model (ViTDet) scales the input image to [0, 1]. 245 | x = self.normalization(x * 255.0) 246 | 247 | # This is bottom-right padding, so it won't affect the bounding 248 | # box coordinates. 249 | x = pad_to_size(x, self.input_shape[-2:]) 250 | 251 | return x 252 | -------------------------------------------------------------------------------- /models/vivit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.transforms import Normalize 4 | 5 | from eventful_transformer.backbones import ViTBackbone 6 | from eventful_transformer.base import ExtendedModule 7 | from eventful_transformer.blocks import LN_EPS 8 | from eventful_transformer.counting import CountedLinear 9 | from utils.image import as_float32, resize_to_fit 10 | 11 | 12 | # Resources consulted: 13 | # https://github.com/google-research/scenic/tree/main/scenic/projects/vivit 14 | # https://github.com/alibaba-mmai-research/TAdaConv 15 | # https://github.com/alibaba-mmai-research/TAdaConv/blob/main/models/base/transformer.py 16 | 17 | 18 | class FactorizedViViT(ExtendedModule): 19 | """ 20 | The spatio-temporal factorized ViViT action recognition Transformer 21 | model. See configs/models/vivit_b_kinetics400.yml for an example 22 | configuration. 
23 | """ 24 | 25 | def __init__( 26 | self, 27 | classes, 28 | input_shape, 29 | normalize_mean, 30 | normalize_std, 31 | spatial_config, 32 | spatial_views, 33 | temporal_config, 34 | temporal_stride, 35 | temporal_views, 36 | tubelet_shape, 37 | batch_views=True, 38 | dropout_rate=0.0, 39 | spatial_only=False, 40 | temporal_only=False, 41 | ): 42 | """ 43 | :param classes: The number of output classes 44 | :param input_shape: The (t, c, h, w) shape for each input view 45 | (the preprocessing crops inputs to this shape) 46 | :param normalize_mean: The mean to use with 47 | torchvision.transforms.Normalize 48 | :param normalize_std: The standard deviation to use with 49 | torchvision.transforms.Normalize 50 | :param spatial_config: A dict containing kwargs for the 51 | spatial sub-model constructor 52 | :param spatial_views: The number of spatial views of the video 53 | :param temporal_config: A dict containing kwargs for the 54 | temporal sub-model constructor 55 | :param temporal_stride: The temporal stride at which input 56 | frames should be sampled 57 | :param temporal_views: The number of temporal views of the video 58 | :param tubelet_shape: The (t, h, w) shape that should be mapped 59 | to a token vector (the 3D analog of patch size) 60 | :param batch_views: Whether the spatial and temporal views 61 | should be arranged along the batch axis and processed in 62 | parallel (this is more efficient, but doesn't always work) 63 | :param dropout_rate: The dropout rate to use before the final 64 | classification layer 65 | :param spatial_only: Only apply the spatial sub-model 66 | :param temporal_only: Only apply the temporal sub-model 67 | """ 68 | super().__init__() 69 | assert not (spatial_only and temporal_only) 70 | assert not (dropout_rate < 0.0 or dropout_rate > 1.0) 71 | input_shape = tuple(input_shape) 72 | tubelet_shape = tuple(tubelet_shape) 73 | input_t, input_c, input_h, input_w = input_shape 74 | backbone_input_size = (input_h // tubelet_shape[1], input_w // tubelet_shape[2]) 75 | self.batch_views = batch_views 76 | self.spatial_only = spatial_only 77 | self.temporal_only = temporal_only 78 | 79 | # Set up submodules. 80 | self.preprocessing = ViViTPreprocessing( 81 | input_shape, 82 | normalize_mean, 83 | normalize_std, 84 | spatial_views, 85 | temporal_stride, 86 | temporal_views, 87 | ) 88 | dim = spatial_config["block_config"]["dim"] 89 | self.embedding = TubeletEmbedding(input_c, dim, tubelet_shape) 90 | self.spatial_model = ViViTSubModel(backbone_input_size, spatial_config) 91 | backbone_input_size = (input_t // tubelet_shape[0],) 92 | self.temporal_model = ViViTSubModel(backbone_input_size, temporal_config) 93 | self.dropout = ( 94 | nn.Dropout(dropout_rate) if (dropout_rate > 0.0) else nn.Identity() 95 | ) 96 | self.classifier = CountedLinear(in_features=dim, out_features=classes) 97 | 98 | def forward(self, x): 99 | batch_size = x.shape[0] 100 | if not self.temporal_only: 101 | x = self._forward_spatial(x) 102 | if not self.spatial_only: 103 | x = self._forward_temporal(x, batch_size) 104 | return x 105 | 106 | def _forward_spatial(self, x): 107 | # Performance note: We can improve performance somewhat by 108 | # batching multiple views together and calling _forward_view 109 | # once. However, this complicates things when we select 110 | # different numbers of tokens from different views (e.g., with 111 | # a threshold policy) and want to perform some operation (e.g., 112 | # a matrix multiply) where the batch dimension needs to be 113 | # maintained. 
We may want to revisit this. 114 | # Idea: We could restrict the number of active tokens to a small 115 | # number of preset values (e.g., 0, 32, 64, 128) and group items 116 | # in the batch by the number of active tokens. This would reduce 117 | # the number of kernel invocations - from the number of items in 118 | # the batch to the number of preset values (4 in the previous 119 | # example). 120 | x = self.preprocessing(x) 121 | if self.batch_views: 122 | x = torch.stack(x, dim=1).flatten(end_dim=1) 123 | x = self._forward_view(x) 124 | else: 125 | x = [self._forward_view(view) for view in x] 126 | x = torch.stack(x, dim=1).flatten(end_dim=1) 127 | return x 128 | 129 | def _forward_temporal(self, x, batch_size): 130 | x = x.view((-1,) + x.shape[-2:]) 131 | x = self.temporal_model(x) 132 | x = self.dropout(x) 133 | x = self.classifier(x) 134 | x = x.view(batch_size, -1, x.shape[-1]) 135 | x = x.mean(dim=-2) 136 | x = x.softmax(dim=-1) 137 | return x 138 | 139 | def _forward_view(self, x): 140 | # (batch, time, channel, height, width) 141 | 142 | x = self.embedding(x) 143 | # (batch, time, patch, dim) 144 | 145 | # Apply the spatial model to each time step. 146 | self.spatial_model.reset() 147 | x = torch.stack([self.spatial_model(x[:, t]) for t in range(x.shape[1])], dim=1) 148 | # (batch, time, dim) 149 | 150 | return x 151 | 152 | 153 | class TubeletEmbedding(nn.Module): 154 | """ 155 | The initial linear tubelet-embedding layer for ViViT. Linearly 156 | transforms each input tubelet (t, h, w) into a token vector. 157 | """ 158 | 159 | def __init__(self, input_channels, dim, tubelet_shape): 160 | """ 161 | 162 | :param input_channels: The number of image channels (e.g., 3 for 163 | RGB images) 164 | :param dim: The dimensionality of token vectors 165 | :param tubelet_shape: The tubelet size for each token (a 166 | 3-element t, h, w tuple/list) 167 | """ 168 | super().__init__() 169 | self.conv = nn.Conv3d( 170 | in_channels=input_channels, 171 | out_channels=dim, 172 | kernel_size=tubelet_shape, 173 | stride=tubelet_shape, 174 | ) 175 | 176 | def forward(self, x): 177 | # (batch, time, dim, height, width) 178 | 179 | # Permute so all 3 dimensions for Conv3d are adjacent. 180 | x = x.permute(0, 2, 1, 3, 4) 181 | 182 | x = self.conv(x) 183 | # (batch, dim, time, height, width) 184 | 185 | # Flatten the spatial axes. 186 | x = x.flatten(start_dim=-2) 187 | # (batch, dim, time, patch) 188 | 189 | x = x.permute(0, 2, 3, 1) 190 | # (batch, time, patch, dim) 191 | 192 | return x 193 | 194 | 195 | class ViViTPreprocessing(nn.Module): 196 | """ 197 | Preprocessing for ViViT. Applies value normalization and chops the 198 | video into multiple spatial and temporal views. 
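# A small worked example of how the temporal views are positioned (the
# numbers here are made up, not taken from a config). With t = 8 frames
# per view, stride 2, and 3 temporal views over a 40-frame clip:
t, stride, views, clip_len = 8, 2, 3, 40
view_size = stride * t                             # 16 frames spanned per view
spacing = (clip_len - view_size) / (views - 1)     # 12.0
starts = [int(k * spacing) for k in range(views)]  # [0, 12, 24]
# Each view then samples frames start, start + 2, ..., start + 14.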
199 | """ 200 | 201 | def __init__( 202 | self, 203 | input_shape, 204 | normalize_mean, 205 | normalize_std, 206 | spatial_views, 207 | temporal_stride, 208 | temporal_views, 209 | ): 210 | """ 211 | :param input_shape: The (t, c, h, w) shape for each input view 212 | :param normalize_mean: The mean to use with 213 | torchvision.transforms.Normalize 214 | :param normalize_std: The standard deviation to use with 215 | torchvision.transforms.Normalize 216 | :param spatial_views: The number of spatial views of the video 217 | :param temporal_stride: The temporal stride at which input 218 | frames should be sampled 219 | :param temporal_views: The number of spatial views of the video 220 | """ 221 | super().__init__() 222 | self.input_shape = input_shape 223 | self.temporal_stride = temporal_stride 224 | self.temporal_views = temporal_views 225 | self.spatial_views = spatial_views 226 | self.normalization = Normalize(normalize_mean, normalize_std) 227 | 228 | def forward(self, x): 229 | t, _, h, w = self.input_shape 230 | 231 | # Repeat the last frame if the video is too short. 232 | view_size = self.temporal_stride * t 233 | if x.shape[1] < view_size: 234 | n_pad = view_size - x.shape[1] 235 | frame = x[:, -1:] 236 | pad_frames = frame.expand(frame.shape[:1] + (n_pad,) + frame.shape[2:]) 237 | x = torch.concat([x, pad_frames], dim=1) 238 | 239 | # Chop the video into multiple temporal views. 240 | if self.temporal_views == 1: 241 | start_positions = [(x.shape[1] - view_size) // 2] 242 | else: 243 | spacing = (x.shape[1] - view_size) / (self.temporal_views - 1) 244 | start_positions = [int(k * spacing) for k in range(self.temporal_views)] 245 | x = [x[:, i : i + view_size : self.temporal_stride] for i in start_positions] 246 | 247 | # Normalize and resize the video. 248 | x = [as_float32(x_i) for x_i in x] # Range [0, 1] 249 | x = [self.normalization(x_i) for x_i in x] 250 | x = [ 251 | torch.stack( 252 | [resize_to_fit(x_i[:, t], (h, w)) for t in range(x_i.shape[1])], dim=1 253 | ) 254 | for x_i in x 255 | ] 256 | 257 | # Chop each temporal view into multiple spatial views. 258 | if self.spatial_views == 1: 259 | start_positions = [((x[0].shape[-2] - h) // 2, (x[0].shape[-1] - w) // 2)] 260 | else: 261 | h_spacing = (x[0].shape[-2] - h) / (self.spatial_views - 1) 262 | w_spacing = (x[0].shape[-1] - w) / (self.spatial_views - 1) 263 | start_positions = [ 264 | (int(k * h_spacing), int(k * w_spacing)) 265 | for k in range(self.spatial_views) 266 | ] 267 | x = [view[..., i : i + h, j : j + w] for i, j in start_positions for view in x] 268 | 269 | return x 270 | 271 | 272 | class ViViTSubModel(ExtendedModule): 273 | """ 274 | A factorized ViViT sub-model (spatial or temporal). 275 | """ 276 | 277 | def __init__(self, input_size, backbone_config): 278 | """ 279 | 280 | :param input_size: The input size in tokens - (h, w) for the 281 | spatial sub-model or (t,) for the temporal sub-model 282 | :param backbone_config: A dict containing kwargs for the 283 | backbone constructor 284 | """ 285 | super().__init__() 286 | dim = backbone_config["block_config"]["dim"] 287 | self.class_token = nn.Parameter(torch.zeros(1, 1, dim)) 288 | self.backbone = ViTBackbone( 289 | input_size=input_size, has_class_token=True, **backbone_config 290 | ) 291 | self.layer_norm = nn.LayerNorm(dim, eps=LN_EPS) 292 | 293 | def forward(self, x): 294 | # Append the class token. 
295 | expand_shape = (x.shape[0],) + self.class_token.shape[1:] 296 | x = torch.concat([self.class_token.expand(expand_shape), x], dim=1) 297 | 298 | x = self.backbone(x) 299 | x = self.layer_norm(x) 300 | 301 | # Extract the class embedding token. 302 | x = x[:, 0] 303 | return x 304 | -------------------------------------------------------------------------------- /scripts/convert/vitdet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import pickle 4 | from argparse import ArgumentParser 5 | 6 | import torch 7 | 8 | from utils.misc import parse_patterns, remap_weights 9 | 10 | 11 | # Weight sources: 12 | # https://github.com/facebookresearch/detectron2/tree/main/projects/ViTDet 13 | # https://github.com/happyharrycn/detectron2_vitdet_vid/tree/main/projects/ViTDet-VID 14 | 15 | 16 | def main(args): 17 | if args.in_file.endswith(".pkl"): 18 | with open(args.in_file, "rb") as pickle_file: 19 | in_weights = pickle.load(pickle_file) 20 | else: 21 | in_weights = torch.load(args.in_file) 22 | in_weights = in_weights["model"] 23 | 24 | # Throw out the class embedding token. 25 | in_weights["backbone.net.pos_embed"] = in_weights["backbone.net.pos_embed"][:, 1:] 26 | 27 | patterns = parse_patterns(args.pattern_file) 28 | out_weights, n_remapped = remap_weights(in_weights, patterns, args.verbose) 29 | for key, weight in out_weights.items(): 30 | # Modifying in place while iterating is okay because the keys 31 | # aren't changing. 32 | if not isinstance(out_weights[key], torch.Tensor): 33 | out_weights[key] = torch.tensor(weight) 34 | torch.save(out_weights, args.out_file) 35 | print(f"Remapped {n_remapped}/{len(in_weights)} weights.") 36 | 37 | 38 | def parse_args(): 39 | parser = ArgumentParser() 40 | parser.add_argument("in_file", help="the input .pkl or .pth file") 41 | parser.add_argument("out_file", help=".pth file where the output should be saved") 42 | parser.add_argument("pattern_file", help=".txt file containing regex patterns") 43 | parser.add_argument( 44 | "-v", "--verbose", action="store_true", help="print detailed output" 45 | ) 46 | return parser.parse_args() 47 | 48 | 49 | if __name__ == "__main__": 50 | main(parse_args()) 51 | -------------------------------------------------------------------------------- /scripts/convert/vivit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from argparse import ArgumentParser 4 | 5 | import torch 6 | 7 | from utils.misc import parse_patterns, remap_weights 8 | 9 | 10 | # Weight source: 11 | # https://github.com/alibaba-mmai-research/TAdaConv/blob/main/MODEL_ZOO.md 12 | 13 | 14 | def main(args): 15 | in_weights = torch.load(args.in_file)["model_state"] 16 | patterns = parse_patterns(args.pattern_file) 17 | out_weights, n_remapped = remap_weights(in_weights, patterns, args.verbose) 18 | torch.save(out_weights, args.out_file) 19 | print(f"Remapped {n_remapped}/{len(in_weights)} weights.") 20 | 21 | 22 | def parse_args(): 23 | parser = ArgumentParser() 24 | parser.add_argument("in_file", help="the input .pyth file") 25 | parser.add_argument("out_file", help=".pth file where the output should be saved") 26 | parser.add_argument("pattern_file", help=".txt file containing regex patterns") 27 | parser.add_argument( 28 | "-v", "--verbose", action="store_true", help="print detailed output" 29 | ) 30 | return parser.parse_args() 31 | 32 | 33 | if __name__ == "__main__": 34 | main(parse_args()) 35 |
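For reference, typical invocations of the two conversion scripts look like the following; the input and output weight file names are placeholders, while the pattern files are the ones in configs/convert:
```
./scripts/convert/vitdet.py <downloaded_vitdet_weights>.pkl weights/vitdet_b_vid.pth configs/convert/vitdet_b.txt
./scripts/convert/vivit.py <downloaded_vivit_weights>.pyth weights/vivit_b_kinetics400.pth configs/convert/vivit_b.txt
```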
-------------------------------------------------------------------------------- /scripts/evaluate/vitdet_vid.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from pathlib import Path 4 | 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from torchmetrics.detection.mean_ap import MeanAveragePrecision 8 | from tqdm import tqdm 9 | 10 | from datasets.vid import VIDResize, VID 11 | from models.vitdet import ViTDet 12 | from utils.config import initialize_run 13 | from utils.evaluate import run_evaluations 14 | from utils.misc import dict_to_device, squeeze_dict 15 | 16 | 17 | def evaluate_vitdet_metrics(device, model, data, config): 18 | model.counting() 19 | model.clear_counts() 20 | n_frames = 0 21 | outputs = [] 22 | labels = [] 23 | n_items = config.get("n_items", len(data)) 24 | for _, vid_item in tqdm(zip(range(n_items), data), total=n_items, ncols=0): 25 | vid_item = DataLoader(vid_item, batch_size=1) 26 | n_frames += len(vid_item) 27 | model.reset() 28 | for frame, annotations in vid_item: 29 | with torch.inference_mode(): 30 | outputs.extend(model(frame.to(device))) 31 | labels.append(squeeze_dict(dict_to_device(annotations, device), dim=0)) 32 | 33 | # MeanAveragePrecision is extremely slow. It seems fastest to call 34 | # update() and compute() just once, after all predictions are done. 35 | mean_ap = MeanAveragePrecision() 36 | mean_ap.update(outputs, labels) 37 | metrics = mean_ap.compute() 38 | 39 | counts = model.total_counts() / n_frames 40 | model.clear_counts() 41 | return {"metrics": metrics, "counts": counts} 42 | 43 | 44 | def main(): 45 | config = initialize_run(config_location=Path("configs", "evaluate", "vitdet_vid")) 46 | long_edge = max(config["model"]["input_shape"][-2:]) 47 | data = VID( 48 | Path("data", "vid"), 49 | split=config["split"], 50 | tar_path=Path("data", "vid", "data.tar"), 51 | combined_transform=VIDResize( 52 | short_edge_length=640 * long_edge // 1024, max_size=long_edge 53 | ), 54 | ) 55 | run_evaluations(config, ViTDet, data, evaluate_vitdet_metrics) 56 | 57 | 58 | if __name__ == "__main__": 59 | main() 60 | -------------------------------------------------------------------------------- /scripts/evaluate/vitdet_vid.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Slurm wrapper for scripts/evaluate/vitdet_vid.py. Usage: 4 | # sbatch -J <config-name> ./scripts/evaluate/vitdet_vid.sh 5 | # where <config-name> is the name of the config in configs/evaluate/vitdet_vid. 6 | 7 | # To override the time limit, use the -t/--time command-line argument. 8 | 9 | # To request a specific GPU, use the argument --gres=gpu:<gpu-type>:1 with 10 | # <gpu-type> replaced by the type of GPU (e.g., a100).
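# For example, to queue the base_672 evaluation (assuming the cluster
# defaults below are appropriate):
#   sbatch -J base_672 ./scripts/evaluate/vitdet_vid.sh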
11 | 12 | #SBATCH --cpus-per-task=16 13 | #SBATCH --output=slurm/%x.txt 14 | #SBATCH --gres=gpu:1 15 | #SBATCH --mem=48GB 16 | #SBATCH --partition=research 17 | #SBATCH --time=4-00:00:00 18 | 19 | ./scripts/evaluate/vitdet_vid.py "$SLURM_JOB_NAME" 20 | -------------------------------------------------------------------------------- /scripts/evaluate/vivit_epic_kitchens.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from pathlib import Path 4 | 5 | from datasets.epic_kitchens import EPICKitchens 6 | from models.vivit import FactorizedViViT 7 | from utils.config import initialize_run 8 | from utils.evaluate import evaluate_vivit_metrics, run_evaluations 9 | 10 | 11 | def main(): 12 | config = initialize_run( 13 | config_location=Path("configs", "evaluate", "vivit_epic_kitchens") 14 | ) 15 | data = EPICKitchens(Path("data", "epic_kitchens"), split="validation") 16 | run_evaluations(config, FactorizedViViT, data, evaluate_vivit_metrics) 17 | 18 | 19 | if __name__ == "__main__": 20 | main() 21 | -------------------------------------------------------------------------------- /scripts/evaluate/vivit_kinetics400.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from pathlib import Path 4 | 5 | from datasets.kinetics400 import Kinetics400 6 | from models.vivit import FactorizedViViT 7 | from utils.config import initialize_run 8 | from utils.evaluate import run_evaluations, evaluate_vivit_metrics 9 | 10 | 11 | def main(): 12 | config = initialize_run( 13 | config_location=Path("configs", "evaluate", "vivit_kinetics400") 14 | ) 15 | data = Kinetics400( 16 | Path("data", "kinetics400"), split="val", decode_size=224, decode_fps=25 17 | ) 18 | run_evaluations(config, FactorizedViViT, data, evaluate_vivit_metrics) 19 | 20 | 21 | if __name__ == "__main__": 22 | main() 23 | -------------------------------------------------------------------------------- /scripts/misc/measure_vitdet_padding.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from pathlib import Path 4 | 5 | from tqdm import tqdm 6 | 7 | from datasets.vid import VIDResize, VID 8 | 9 | 10 | def main(): 11 | for size in 672, 1024: 12 | data = VID( 13 | Path("data", "vid"), 14 | split="vid_val", 15 | tar_path=Path("data", "vid", "data.tar"), 16 | combined_transform=VIDResize( 17 | short_edge_length=640 * size // 1024, max_size=size 18 | ), 19 | ) 20 | weighted_sum = 0.0 21 | total_frames = 0 22 | # noinspection PyTypeChecker 23 | for vid_item in tqdm(data, ncols=0): 24 | frame = vid_item[0][0] 25 | padding_ratio = frame.shape[-1] * frame.shape[-2] / (size**2) 26 | weighted_sum += len(vid_item) * padding_ratio 27 | total_frames += len(vid_item) 28 | print(f"Size {size}: {weighted_sum / total_frames:.5g}") 29 | 30 | 31 | if __name__ == "__main__": 32 | main() 33 | -------------------------------------------------------------------------------- /scripts/spatial/vivit_epic_kitchens.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from pathlib import Path 4 | 5 | from datasets.epic_kitchens import EPICKitchens 6 | from utils.config import get_cli_config 7 | from utils.spatial import compute_vivit_spatial 8 | 9 | 10 | def main(): 11 | config = get_cli_config( 12 | config_location=Path("configs", "spatial", "vivit_epic_kitchens") 13 | ) 14 | k = config["k"] 15 | location = 
Path("data", "epic_kitchens") 16 | for split in "train", "validation": 17 | print(f"{split.capitalize()}, k={k}", flush=True) 18 | data = EPICKitchens(location, split=split, shuffle=False) 19 | compute_vivit_spatial(config, location / split / f"spatial_{k}", data) 20 | 21 | 22 | if __name__ == "__main__": 23 | main() 24 | -------------------------------------------------------------------------------- /scripts/spatial/vivit_kinetics400.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from pathlib import Path 4 | 5 | from datasets.kinetics400 import Kinetics400 6 | from utils.config import get_cli_config 7 | from utils.spatial import compute_vivit_spatial 8 | 9 | 10 | def main(): 11 | config = get_cli_config( 12 | config_location=Path("configs", "spatial", "vivit_kinetics400") 13 | ) 14 | k = config["k"] 15 | location = Path("data", "kinetics400") 16 | for split in "train", "val": 17 | print(f"{split.capitalize()}, k={k}", flush=True) 18 | max_tars = config.get("max_tars", None) if (split == "train") else None 19 | data = Kinetics400( 20 | location, 21 | split=split, 22 | decode_size=224, 23 | decode_fps=25, 24 | max_tars=max_tars, 25 | shuffle=False, 26 | ) 27 | if max_tars is not None: 28 | split = f"{split}_{max_tars}" 29 | compute_vivit_spatial(config, location / split / f"spatial_224_25_{k}", data) 30 | 31 | 32 | if __name__ == "__main__": 33 | main() 34 | -------------------------------------------------------------------------------- /scripts/time/vitdet_vid.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import time 4 | from pathlib import Path 5 | 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from tqdm import tqdm 9 | 10 | from datasets.vid import VIDResize, VID 11 | from models.vitdet import ViTDet 12 | from utils.config import initialize_run 13 | from utils.evaluate import run_evaluations 14 | from utils.misc import MeanValue 15 | 16 | 17 | def evaluate_vitdet_runtime(device, model, data, config): 18 | model.no_counting() 19 | backbone = MeanValue() 20 | backbone_non_first = MeanValue() 21 | other = MeanValue() 22 | other_non_first = MeanValue() 23 | n_items = config.get("n_items", len(data)) 24 | for _, vid_item in tqdm(zip(range(n_items), data), total=n_items, ncols=0): 25 | vid_item = DataLoader(vid_item, batch_size=1) 26 | model.reset() 27 | for t, (frame, annotations) in enumerate(vid_item): 28 | with torch.inference_mode(): 29 | frame = frame.to(device) 30 | torch.cuda.synchronize() 31 | t_0 = time.time() 32 | images, x = model.pre_backbone(frame) 33 | torch.cuda.synchronize() 34 | t_1 = time.time() 35 | x = model.backbone(x) 36 | torch.cuda.synchronize() 37 | t_2 = time.time() 38 | model.post_backbone(images, x) 39 | torch.cuda.synchronize() 40 | t_3 = time.time() 41 | t_backbone = t_2 - t_1 42 | t_other = (t_3 - t_2) + (t_1 - t_0) 43 | backbone.update(t_backbone) 44 | other.update(t_other) 45 | if t > 0: 46 | backbone_non_first.update(t_backbone) 47 | other_non_first.update(t_other) 48 | times = { 49 | "backbone": backbone.compute(), 50 | "backbone_non_first": backbone_non_first.compute(), 51 | "other": other.compute(), 52 | "other_non_first": other_non_first.compute(), 53 | "total": backbone.compute() + other.compute(), 54 | "total_non_first": backbone_non_first.compute() + other_non_first.compute(), 55 | } 56 | return {"times": times} 57 | 58 | 59 | def main(): 60 | config = 
initialize_run(config_location=Path("configs", "time", "vitdet_vid")) 61 | input_size = config.get("input_size", 1024) 62 | data = VID( 63 | Path("data", "vid"), 64 | split=config["split"], 65 | tar_path=Path("data", "vid", "data.tar"), 66 | combined_transform=VIDResize( 67 | short_edge_length=640 * input_size // 1024, max_size=input_size 68 | ), 69 | ) 70 | run_evaluations(config, ViTDet, data, evaluate_vitdet_runtime) 71 | 72 | 73 | if __name__ == "__main__": 74 | main() 75 | -------------------------------------------------------------------------------- /scripts/time/vivit_epic_kitchens.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import time 4 | from pathlib import Path 5 | 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from tqdm import tqdm 9 | 10 | from datasets.epic_kitchens import EPICKitchens 11 | from models.vivit import FactorizedViViT 12 | from utils.config import initialize_run 13 | from utils.evaluate import run_evaluations 14 | from utils.misc import MeanValue 15 | 16 | 17 | def evaluate_vivit_runtime(device, model, data, config): 18 | model.no_counting() 19 | spatial = MeanValue() 20 | temporal = MeanValue() 21 | data = DataLoader(data, batch_size=1) 22 | n_items = config.get("n_items", len(data)) 23 | for _, (video, label) in tqdm(zip(range(n_items), data), total=n_items, ncols=0): 24 | model.reset() 25 | with torch.inference_mode(): 26 | video = video.to(device) 27 | torch.cuda.synchronize() 28 | t_0 = time.time() 29 | model.spatial_only = True 30 | model.temporal_only = False 31 | x = model(video) 32 | torch.cuda.synchronize() 33 | t_1 = time.time() 34 | model.spatial_only = False 35 | model.temporal_only = True 36 | model(x) 37 | t_2 = time.time() 38 | torch.cuda.synchronize() 39 | spatial.update(t_1 - t_0) 40 | temporal.update(t_2 - t_1) 41 | times = { 42 | "spatial": spatial.compute(), 43 | "temporal": temporal.compute(), 44 | "total": spatial.compute() + temporal.compute(), 45 | } 46 | return {"times": times} 47 | 48 | 49 | def main(): 50 | config = initialize_run( 51 | config_location=Path("configs", "time", "vivit_epic_kitchens") 52 | ) 53 | data = EPICKitchens(Path("data", "epic_kitchens"), split="validation") 54 | run_evaluations(config, FactorizedViViT, data, evaluate_vivit_runtime) 55 | 56 | 57 | if __name__ == "__main__": 58 | main() 59 | -------------------------------------------------------------------------------- /scripts/train/vivit_epic_kitchens.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from pathlib import Path 4 | 5 | from datasets.vivit_spatial import ViViTSpatial 6 | from utils.config import get_cli_config 7 | from utils.train import train_vivit_temporal 8 | 9 | 10 | def main(): 11 | config = get_cli_config( 12 | config_location=Path("configs", "train", "vivit_epic_kitchens") 13 | ) 14 | train_data = ViViTSpatial( 15 | Path("data", "epic_kitchens"), split="train", k=config["k"] 16 | ) 17 | val_data = ViViTSpatial( 18 | Path("data", "epic_kitchens"), split="validation", k=config["k"] 19 | ) 20 | train_vivit_temporal(config, train_data, val_data) 21 | 22 | 23 | if __name__ == "__main__": 24 | main() 25 | -------------------------------------------------------------------------------- /scripts/train/vivit_kinetics400.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from pathlib import Path 4 | 5 | from 
datasets.vivit_spatial import ViViTSpatial 6 | from utils.config import get_cli_config 7 | from utils.train import train_vivit_temporal 8 | 9 | 10 | def main(): 11 | config = get_cli_config( 12 | config_location=Path("configs", "train", "vivit_kinetics400") 13 | ) 14 | train_data = ViViTSpatial( 15 | Path("data", "kinetics400"), 16 | split="train_40", 17 | base_name="spatial_224_25", 18 | k=config["k"], 19 | ) 20 | val_data = ViViTSpatial( 21 | Path("data", "kinetics400"), 22 | split="val", 23 | base_name="spatial_224_25", 24 | k=config["k"], 25 | ) 26 | train_vivit_temporal(config, train_data, val_data) 27 | 28 | 29 | if __name__ == "__main__": 30 | main() 31 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from pathlib import Path 3 | 4 | from omegaconf import OmegaConf 5 | 6 | 7 | def get_cli_config(config_location="."): 8 | # Parse command-line arguments. 9 | parser = ArgumentParser() 10 | parser.add_argument( 11 | "name", 12 | help=f'the configuration name (the file is "{config_location}/<name>.yml")', 13 | ) 14 | parser.add_argument( 15 | "overrides", nargs="*", help="configuration overrides (like a.b.c=value)" 16 | ) 17 | args = parser.parse_args() 18 | 19 | # Merge the configuration file and command-line overrides. 20 | config_path = Path(config_location, f"{args.name}.yml") 21 | config = load_config(config_path, to_container=False) 22 | config = OmegaConf.merge(config, OmegaConf.from_dotlist(args.overrides)) 23 | 24 | # Generate a unique name for this configuration. 25 | if "_name" not in config: 26 | if len(args.overrides) == 0: 27 | name = config_path.stem 28 | else: 29 | name = f"{config_path.stem}-{'-'.join(args.overrides)}" 30 | config["_name"] = name 31 | 32 | return OmegaConf.to_container(config, resolve=True) 33 | 34 | 35 | def initialize_run(config_location="."): 36 | config = get_cli_config(config_location=config_location) 37 | 38 | if "_output" in config: 39 | # Create an output directory and save the merged configuration. 40 | output_dir = Path(config["_output"]) 41 | output_dir.mkdir(parents=True, exist_ok=True) 42 | OmegaConf.save(config, output_dir / "config.yml", resolve=True) 43 | 44 | return config 45 | 46 | 47 | def load_config(config_path, to_container=True): 48 | # Load the specified config, composing with defaults if necessary. 
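# A config may list parent configs under "_defaults". Each parent is loaded
# recursively (resolved relative to this config's directory when such a file
# exists) and merged in order, with this config's own values taking the
# highest precedence. Illustrative (hypothetical) child config:
#   _defaults: [_base.yml, _size_672.yml]
#   n_items: 500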
49 | config = OmegaConf.load(config_path) 50 | defaults = [] 51 | for defaults_path in config.pop("_defaults", []): 52 | relative_path = Path(config_path).parent / defaults_path 53 | chosen_path = relative_path if relative_path.is_file() else defaults_path 54 | defaults.append(load_config(chosen_path, to_container=False)) 55 | config = OmegaConf.merge(*defaults, config) 56 | return OmegaConf.to_container(config, resolve=True) if to_container else config 57 | -------------------------------------------------------------------------------- /utils/evaluate.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from tqdm import tqdm 6 | 7 | from eventful_transformer.base import dict_csv_header, dict_csv_line, dict_string 8 | from eventful_transformer.policies import ( 9 | TokenNormThreshold, 10 | TokenNormTopK, 11 | TokenNormTopFraction, 12 | ) 13 | from utils.misc import ( 14 | TopKAccuracy, 15 | get_device_description, 16 | get_pytorch_device, 17 | set_policies, 18 | tee_print, 19 | ) 20 | 21 | 22 | def evaluate_vivit_metrics(device, model, data, config): 23 | model.counting() 24 | model.clear_counts() 25 | top_1 = TopKAccuracy(k=1) 26 | top_5 = TopKAccuracy(k=5) 27 | data = DataLoader(data, batch_size=1) 28 | n_items = config.get("n_items", len(data)) 29 | for _, (video, label) in tqdm(zip(range(n_items), data), total=n_items, ncols=0): 30 | model.reset() 31 | with torch.inference_mode(): 32 | output = model(video.to(device)) 33 | label = label.to(device) 34 | top_1.update(output, label) 35 | top_5.update(output, label) 36 | metrics = {"top_1": top_1.compute(), "top_5": top_5.compute()} 37 | counts = model.total_counts() / n_items 38 | model.clear_counts() 39 | return {"metrics": metrics, "counts": counts} 40 | 41 | 42 | def run_evaluations(config, model_class, data, evaluate_function): 43 | device = config.get("device", get_pytorch_device()) 44 | if "threads" in config: 45 | torch.set_num_threads(config["threads"]) 46 | 47 | # Load and set up the model. 48 | model = model_class(**(config["model"])) 49 | model.load_state_dict(torch.load(config["weights"])) 50 | model = model.to(device) 51 | 52 | completed = [] 53 | output_dir = Path(config["_output"]) 54 | 55 | def do_evaluation(title): 56 | with open(output_dir / "output.txt", "a") as tee_file: 57 | # Run the evaluation. 58 | model.eval() 59 | results = evaluate_function(device, model, data, config) 60 | 61 | # Print and save results. 62 | tee_print(title, tee_file) 63 | tee_print(get_device_description(device), tee_file) 64 | if isinstance(results, dict): 65 | save_csv_results(results, output_dir, first_run=(len(completed) == 0)) 66 | for key, val in results.items(): 67 | tee_print(key.capitalize(), tee_file) 68 | tee_print(dict_string(val), tee_file) 69 | else: 70 | tee_print(results, tee_file) 71 | tee_print("", tee_file) 72 | completed.append(title) 73 | 74 | # Evaluate the model. 
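# The config decides which variants run: an optional vanilla pass, then
# sweeps over token_top_k, token_top_fraction, and token_thresholds. Each
# sweep entry installs a new gate policy via set_policies() before rerunning
# the evaluation. Illustrative (hypothetical) config fragment:
#   vanilla: true
#   token_top_k: [256, 512]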
75 | if config.get("vanilla", False): 76 | do_evaluation("Vanilla") 77 | for k in config.get("token_top_k", []): 78 | set_policies(model, TokenNormTopK, k=k) 79 | do_evaluation(f"Token top k={k}") 80 | for fraction in config.get("token_top_fraction", []): 81 | set_policies(model, TokenNormTopFraction, fraction=fraction) 82 | do_evaluation(f"Token top {fraction * 100:.1f}%") 83 | for threshold in config.get("token_thresholds", []): 84 | set_policies(model, TokenNormThreshold, threshold=threshold) 85 | do_evaluation(f"Token threshold {threshold}") 86 | 87 | 88 | def save_csv_results(results, output_dir, first_run=False): 89 | for key, val in results.items(): 90 | with open(output_dir / f"{key}.csv", "a") as csv_file: 91 | if first_run: 92 | print(dict_csv_header(val), file=csv_file) 93 | print(dict_csv_line(val), file=csv_file) 94 | -------------------------------------------------------------------------------- /utils/image.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | import torchvision 5 | import torchvision.transforms.functional as func 6 | from torchvision import transforms 7 | 8 | 9 | def as_float32(x): 10 | if isinstance(x, torch.Tensor) and x.dtype == torch.uint8: 11 | return x.float() / 255.0 12 | elif isinstance(x, np.ndarray) and x.dtype == np.uint8: 13 | return x.astype(np.float32) / 255.0 14 | elif type(x) in (tuple, list) and isinstance(x[0], int): 15 | return type(x)(x_i / 255.0 for x_i in x) 16 | else: 17 | return x 18 | 19 | 20 | def as_uint8(x): 21 | if isinstance(x, torch.Tensor) and x.dtype != torch.uint8: 22 | return (x * 255.0).byte() 23 | elif isinstance(x, np.ndarray) and x.dtype != np.uint8: 24 | return (x * 255.0).astype(np.uint8) 25 | elif type(x) in (tuple, list) and isinstance(x[0], float): 26 | return type(x)(int(x_i * 255.0) for x_i in x) 27 | else: 28 | return x 29 | 30 | 31 | def pad_to_size(x, size, pad_tensor=None): 32 | # padding = [0, size[1] - x.shape[-1], 0, size[0] - x.shape[-2]] 33 | # x = func.pad(x, padding, fill=0, padding_mode="constant") 34 | # The two lines above are not working as expected - maybe there's a 35 | # bug in func.pad? In the meantime we'll use the concat-based 36 | # padding code below. 37 | if pad_tensor is None: 38 | pad_tensor = torch.zeros((1,) * x.ndim, dtype=x.dtype, device=x.device) 39 | for dim in list(range(-1, -len(size) - 1, -1)): 40 | expand_shape = list(x.shape) 41 | expand_shape[dim] = size[dim] - x.shape[dim] 42 | if expand_shape[dim] == 0: 43 | continue 44 | 45 | # torch.concat allocates a new tensor. So, we're safe to use 46 | # torch.expand here (instead of torch.repeat) without worrying 47 | # about different elements of x referencing the same data. 
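# Concatenating the broadcast zero block grows this dimension from
# x.shape[dim] to size[dim].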
48 | x = torch.concat([x, pad_tensor.expand(expand_shape)], dim) 49 | return x 50 | 51 | 52 | def rescale( 53 | x, scale, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True 54 | ): 55 | if scale != 1.0: 56 | x = func.resize( 57 | x, 58 | [round(scale * x.shape[-2]), round(scale * x.shape[-1])], 59 | interpolation=interpolation, 60 | antialias=antialias, 61 | ) 62 | return x 63 | 64 | 65 | def resize_to_fit( 66 | x, size, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True 67 | ): 68 | scale = max(size[0] / x.shape[-2], size[1] / x.shape[-1]) 69 | x = rescale(x, scale, interpolation=interpolation, antialias=antialias) 70 | return x 71 | 72 | 73 | def save_image_mpl(image, pathname, **imshow_kwargs): 74 | fig, ax = plt.subplots() 75 | ax.imshow(image, **imshow_kwargs) 76 | ax.axis("off") 77 | fig.savefig(pathname, bbox_inches="tight", pad_inches=0.0) 78 | plt.close() 79 | 80 | 81 | def write_image(filename, image, **kwargs): 82 | filename = str(filename) 83 | lower = filename.lower() 84 | image = torch.as_tensor(image) 85 | assert any(lower.endswith(ext) for ext in [".png", ".jpg", ".jpeg"]) 86 | if lower.endswith(".png"): 87 | torchvision.io.write_png(image, filename, **kwargs) 88 | else: 89 | torchvision.io.write_jpeg(image, filename, **kwargs) 90 | 91 | 92 | def write_video(filename, video, fps=30, is_chw=True): 93 | filename = str(filename) 94 | video = torch.as_tensor(video) 95 | if is_chw: 96 | video = video.permute(0, 2, 3, 1) 97 | torchvision.io.write_video(filename, video, fps=fps) 98 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import re 2 | import subprocess 3 | from pathlib import Path 4 | from random import Random 5 | 6 | import requests 7 | import torch 8 | 9 | from eventful_transformer.modules import SimpleSTGTGate, TokenDeltaGate, TokenGate 10 | 11 | 12 | class MeanValue: 13 | def __init__(self): 14 | self.sum = 0.0 15 | self.count = 0 16 | 17 | def compute(self): 18 | return 0.0 if (self.count == 0) else self.sum / self.count 19 | 20 | def reset(self): 21 | self.sum = 0.0 22 | self.count = 0 23 | 24 | def update(self, value): 25 | self.sum += value 26 | self.count += 1 27 | 28 | 29 | class TopKAccuracy: 30 | def __init__(self, k): 31 | self.k = k 32 | self.correct = 0 33 | self.total = 0 34 | 35 | def compute(self): 36 | return self.correct / self.total 37 | 38 | def reset(self): 39 | self.correct = 0 40 | self.total = 0 41 | 42 | def update(self, pred, true): 43 | _, top_k = pred.topk(self.k, dim=-1) 44 | self.correct += true.eq(top_k).sum().item() 45 | self.total += true.numel() 46 | 47 | 48 | def decode_video( 49 | input_path, 50 | output_path, 51 | name_format="%d", 52 | image_format="png", 53 | ffmpeg_input_args=None, 54 | ffmpeg_output_args=None, 55 | ): 56 | output_path = Path(output_path) 57 | output_path.mkdir(exist_ok=True) 58 | if ffmpeg_input_args is None: 59 | ffmpeg_input_args = [] 60 | if ffmpeg_output_args is None: 61 | ffmpeg_output_args = [] 62 | return subprocess.call( 63 | ["ffmpeg", "-loglevel", "error"] 64 | + ffmpeg_input_args 65 | + ["-i", input_path] 66 | + ffmpeg_output_args 67 | + [output_path / f"{name_format}.{image_format}"] 68 | ) 69 | 70 | 71 | def dict_to_device(x, device): 72 | return {key: value.to(device) for key, value in x.items()} 73 | 74 | 75 | # https://gist.github.com/wasi0013/ab73f314f8070951b92f6670f68b2d80 76 | def download_file(url, output_path, 
chunk_size=4096, verbose=True): 77 | if verbose: 78 | print(f"Downloading {url}...", flush=True) 79 | with requests.get(url, stream=True) as source: 80 | with open(output_path, "wb") as output_file: 81 | for chunk in source.iter_content(chunk_size=chunk_size): 82 | if chunk: 83 | output_file.write(chunk) 84 | 85 | 86 | def get_device_description(device): 87 | if device == "cuda": 88 | return torch.cuda.get_device_name() 89 | else: 90 | return f"CPU with {torch.get_num_threads()} threads" 91 | 92 | 93 | def get_pytorch_device(): 94 | return "cuda" if torch.cuda.is_available() else "cpu" 95 | 96 | 97 | def parse_patterns(pattern_file): 98 | patterns = [] 99 | last_regex = None 100 | with open(pattern_file, "r") as text: 101 | for line in text: 102 | line = line.strip() 103 | if line == "": 104 | continue 105 | elif last_regex is None: 106 | last_regex = re.compile(line) 107 | else: 108 | patterns.append((last_regex, line)) 109 | last_regex = None 110 | return patterns 111 | 112 | 113 | def remap_weights(in_weights, patterns, verbose=False): 114 | n_remapped = 0 115 | out_weights = {} 116 | for in_key, weight in in_weights.items(): 117 | out_key = in_key 118 | discard = False 119 | for regex, replacement in patterns: 120 | out_key, n_matches = regex.subn(replacement, out_key) 121 | if n_matches > 0: 122 | if replacement == "DISCARD": 123 | discard = True 124 | out_key = "DISCARD" 125 | n_remapped += 1 126 | if verbose: 127 | print(f"{in_key} ==> {out_key}") 128 | break 129 | if not discard: 130 | out_weights[out_key] = weight 131 | return out_weights, n_remapped 132 | 133 | 134 | def seeded_shuffle(sequence, seed): 135 | rng = Random() 136 | rng.seed(seed) 137 | rng.shuffle(sequence) 138 | 139 | 140 | def set_policies(model, policy_class, **policy_kwargs): 141 | for gate_class in [SimpleSTGTGate, TokenDeltaGate, TokenGate]: 142 | for gate in model.modules_of_type(gate_class): 143 | gate.policy = policy_class(**policy_kwargs) 144 | 145 | 146 | def squeeze_dict(x, dim=None): 147 | return {key: value.squeeze(dim=dim) for key, value in x.items()} 148 | 149 | 150 | def tee_print(s, file, flush=True): 151 | print(s, flush=flush) 152 | print(s, file=file, flush=flush) 153 | -------------------------------------------------------------------------------- /utils/spatial.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from tqdm import tqdm 7 | 8 | from eventful_transformer.policies import TokenNormTopK 9 | from models.vivit import FactorizedViViT 10 | from utils.misc import get_pytorch_device, set_policies 11 | 12 | 13 | def compute_vivit_spatial(config, output_dir, data): 14 | device = get_pytorch_device() 15 | 16 | # Load and set up the model. 
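# The model runs under a TokenNormTopK gating policy (k comes from the
# config), and its per-clip spatial outputs are cached to .npz files below.
# The temporal-model training scripts later read these cached features, so
# the spatial stage only needs to run once per split.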
17 | model = FactorizedViViT(**(config["model"])) 18 | model.load_state_dict(torch.load(config["weights"])) 19 | model = model.to(device) 20 | 21 | set_policies(model, TokenNormTopK, k=config["k"]) 22 | output_dir = Path(output_dir) 23 | output_dir.mkdir(exist_ok=True) 24 | data = DataLoader(data, batch_size=1, drop_last=False) 25 | for i, (video, label) in tqdm(enumerate(data), total=len(data), ncols=0): 26 | model.reset() 27 | with torch.inference_mode(): 28 | spatial = model(video.to(device)) 29 | np.savez( 30 | output_dir / f"{i:05d}.npz", 31 | spatial=spatial.cpu().numpy(), 32 | label=label[0].cpu().numpy(), 33 | ) 34 | -------------------------------------------------------------------------------- /utils/train.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | import torch 4 | from torch import optim 5 | from torch.utils.data import DataLoader 6 | from torch.utils.tensorboard import SummaryWriter 7 | from tqdm import tqdm 8 | 9 | from models.vivit import FactorizedViViT 10 | from utils.misc import MeanValue, TopKAccuracy, get_pytorch_device 11 | 12 | 13 | def train_vivit_temporal(config, train_data, val_data): 14 | device = get_pytorch_device() 15 | torch.random.manual_seed(42) 16 | 17 | # Set up the dataset. 18 | train_data = DataLoader( 19 | train_data, batch_size=config["train_batch_size"], shuffle=True 20 | ) 21 | val_data = DataLoader(val_data, batch_size=config["val_batch_size"]) 22 | 23 | # Load and set up the model. 24 | model = FactorizedViViT(**(config["model"])) 25 | model.load_state_dict(torch.load(config["starting_weights"])) 26 | model = model.to(device) 27 | 28 | # Set up the optimizer. 29 | optimizer_class = getattr(optim, config["optimizer"]) 30 | optimizer = optimizer_class( 31 | list(model.temporal_model.parameters()) + list(model.classifier.parameters()), 32 | **config["optimizer_kwargs"], 33 | ) 34 | 35 | # Set up the loss and metrics. 36 | loss_function = torch.nn.CrossEntropyLoss() 37 | mean_loss = MeanValue() 38 | top_1 = TopKAccuracy(k=1) 39 | top_5 = TopKAccuracy(k=5) 40 | 41 | # Set up TensorBoard logging. 
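# If the config provides a "tensorboard" key, its value is used as the base
# log-directory name and a timestamp suffix is appended (e.g., a hypothetical
# value of outputs/tensorboard/final_96 would produce a directory named
# outputs/tensorboard/final_96_<date>_<time>). Without the key, TensorBoard
# logging is skipped.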
42 | if "tensorboard" in config: 43 | base_name = config["tensorboard"] 44 | now_str = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") 45 | tensorboard = SummaryWriter(f"{base_name}_{now_str}") 46 | else: 47 | tensorboard = None 48 | 49 | def log_epoch(tb_key, step): 50 | value_loss = mean_loss.compute() 51 | value_1 = top_1.compute() 52 | value_5 = top_5.compute() 53 | if tensorboard is not None: 54 | tensorboard.add_scalars("loss", {tb_key: value_loss}, step) 55 | tensorboard.add_scalars("top_1", {tb_key: value_1}, step) 56 | tensorboard.add_scalars("top_5", {tb_key: value_5}, step) 57 | print(f"Loss: {value_loss:.4f}; Top-1: {value_1:.4f}; Top-5: {value_5:.4f}") 58 | 59 | def train_pass(step): 60 | model.train() 61 | mean_loss.reset() 62 | top_1.reset() 63 | top_5.reset() 64 | print("Training pass", flush=True) 65 | for spatial, label in tqdm(train_data, total=len(train_data), ncols=0): 66 | label = label.to(device) 67 | output = model(spatial.to(device)) 68 | loss = loss_function(output, label) 69 | optimizer.zero_grad() 70 | loss.backward() 71 | optimizer.step() 72 | mean_loss.update(loss.item()) 73 | top_1.update(output, label.unsqueeze(dim=-1)) 74 | top_5.update(output, label.unsqueeze(dim=-1)) 75 | log_epoch("train", step) 76 | 77 | def val_pass(step): 78 | model.eval() 79 | mean_loss.reset() 80 | top_1.reset() 81 | top_5.reset() 82 | print("Validation pass", flush=True) 83 | for spatial, label in tqdm(val_data, total=len(val_data), ncols=0): 84 | label = label.to(device) 85 | output = model(spatial.to(device)) 86 | loss = loss_function(output, label) 87 | mean_loss.update(loss.item()) 88 | top_1.update(output, label.unsqueeze(dim=-1)) 89 | top_5.update(output, label.unsqueeze(dim=-1)) 90 | log_epoch("val", step) 91 | 92 | val_pass(0) 93 | n_epochs = config["epochs"] 94 | for epoch in range(n_epochs): 95 | print(f"\nEpoch {epoch + 1}/{n_epochs}", flush=True) 96 | train_pass(epoch + 1) 97 | val_pass(epoch + 1) 98 | 99 | if tensorboard is not None: 100 | tensorboard.close() 101 | 102 | # Save the final weights. 103 | weight_path = config["output_weights"] 104 | torch.save(model.state_dict(), weight_path) 105 | print(f"Saved weights to {weight_path}", flush=True) 106 | --------------------------------------------------------------------------------
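The `train_vivit_temporal` routine above is driven by the entry points in `scripts/train`, which assume that the corresponding `scripts/spatial` script has already cached the spatial features they load. A plausible end-to-end sequence for Kinetics-400, using config names from `configs/spatial/vivit_kinetics400` and `configs/train/vivit_kinetics400` (dataset locations and starting weights must already be set up as those configs expect):

```
./scripts/spatial/vivit_kinetics400.py 96
./scripts/train/vivit_kinetics400.py final_96
```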