├── .flake8 ├── .gitignore ├── .pre-commit-config.yaml ├── .style.yapf ├── LICENSE ├── README.md ├── assets ├── Position_Embedding.png ├── fig1.png ├── fig2.png └── fig3.png ├── pyproject.toml ├── requirements-dev.txt ├── requirements.txt ├── scripts ├── convert_vit_weight.py ├── evaluate.py ├── infernce.py ├── train.py ├── visualise_dataset.py └── visualise_pos_embed.py ├── setup.py └── tubevit ├── __init__.py ├── dataset.py ├── model.py └── positional_encoding.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E402,E501,W503,W504,C901,W291,E266,W293,E126 3 | max-line-length = 120 4 | exclude = 5 | .git, 6 | __pycache__, 7 | build 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # ignore video meta 132 | *.pickle 133 | 134 | # ignore model weight 135 | *.pt 136 | 137 | # pytorch lightning log 138 | lightning_logs 139 | 140 | # pycharm 141 | .idea 142 | 143 | # ignore dataset 144 | data/ 145 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pycqa/isort 3 | rev: 5.12.0 4 | hooks: 5 | - id: isort 6 | name: isort (python) 7 | - repo: https://github.com/psf/black 8 | rev: 23.7.0 9 | hooks: 10 | - id: black 11 | - repo: https://github.com/astral-sh/ruff-pre-commit 12 | rev: v0.0.284 13 | hooks: 14 | - id: ruff 15 | - repo: https://github.com/pre-commit/pre-commit-hooks 16 | rev: v4.4.0 17 | hooks: 18 | - id: detect-private-key 19 | - id: check-added-large-files 20 | - id: check-merge-conflict 21 | - id: check-json 22 | - id: check-yaml 23 | - id: check-toml 24 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = pep8 3 | COLUMN_LIMIT = 120 4 | ALLOW_SPLIT_BEFORE_DICT_VALUE = false 5 | SPACES_AROUND_POWER_OPERATOR = true 6 | BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = false 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Su YR 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TubeViT 2 | 3 | An unofficial implementation of TubeViT 4 | in "[Rethinking Video ViTs: Sparse Video Tubes for Joint Image and Video Learning](https://arxiv.org/abs/2212.03229)" 5 | 6 | # Spec. 7 | 8 | - [x] Fixed Positional embedding 9 | - [ ] Sparse Tube Construction 10 | - [x] Multi-Tube 11 | - [x] Interpolated Kernels 12 | - [ ] Space To Depth 13 | - [ ] config of tubes 14 | - [ ] pipeline 15 | - [x] training 16 | - [x] evaluating 17 | - [ ] inference 18 | 19 | # Usage 20 | 21 | This project is based on `torch~=2.0.1` and [pytorch-lightning](https://github.com/Lightning-AI/lightning) (see `requirements.txt`). 22 | 23 | ## Setup 24 | 25 | 1. Install requirements 26 | 27 | ```commandline 28 | pip install -r requirements.txt 29 | ``` 30 | 31 | 2. Download the UCF101 dataset 32 | 33 | ## Convert ViT pre-trained weight 34 | 35 | Use `convert_vit_weight.py` to convert torchvision's pre-trained ViT weights to TubeViT. 36 | 37 | ```commandline 38 | python scripts/convert_vit_weight.py --help 39 | Usage: convert_vit_weight.py [OPTIONS] 40 | 41 | Options: 42 | -nc, --num-classes INTEGER num of classes of dataset. 43 | -f, --frames-per-clip INTEGER frame per clip. 44 | -v, --video-size ... 45 | frame per clip. 46 | -o, --output-path PATH output model weight name. 47 | --help Show this message and exit. 48 | ``` 49 | 50 | ### Example 51 | 52 | Convert the ImageNet pre-trained weights for UCF101. `--num-classes` is 101 by default. 53 | 54 | ```commandline 55 | python scripts/convert_vit_weight.py 56 | ``` 57 | 58 | ## Train 59 | 60 | Currently, `train.py` only trains on the PyTorch UCF101 dataset. 61 | Change the dataset if needed. 62 | 63 | `--dataset-root` and `--annotation-path` are based 64 | on [torchvision.datasets.UCF101](https://pytorch.org/vision/main/generated/torchvision.datasets.UCF101.html) 65 | 66 | ```commandline 67 | python scripts/train.py --help 68 | 69 | Usage: train.py [OPTIONS] 70 | 71 | Options: 72 | -r, --dataset-root PATH path to dataset. [required] 73 | -a, --annotation-path PATH path to dataset. [required] 74 | -nc, --num-classes INTEGER num of classes of dataset. 75 | -b, --batch-size INTEGER batch size. 76 | -f, --frames-per-clip INTEGER frame per clip. 77 | -v, --video-size ... 78 | frame per clip. 79 | --max-epochs INTEGER max epochs. 80 | --num-workers INTEGER 81 | --fast-dev-run 82 | --seed INTEGER random seed. 83 | --preview-video Show input video 84 | --help Show this message and exit. 85 | ``` 86 | 87 | ### Examples 88 | 89 | ```commandline 90 | python scripts/train.py -r path/to/dataset -a path/to/annotation 91 | ``` 92 | 93 | ## Evaluation 94 | 95 | ```commandline 96 | python scripts/evaluate.py --help 97 | 98 | Usage: evaluate.py [OPTIONS] 99 | 100 | Options: 101 | -r, --dataset-root PATH path to dataset. [required] 102 | -m, --model-path PATH path to model weight. [required] 103 | -a, --annotation-path PATH path to dataset. [required] 104 | --label-path PATH path to classInd.txt. [required] 105 | -nc, --num-classes INTEGER num of classes of dataset. 106 | -b, --batch-size INTEGER batch size. 107 | -f, --frames-per-clip INTEGER frame per clip. 108 | -v, --video-size ... 109 | frame per clip. 110 | --num-workers INTEGER 111 | --seed INTEGER random seed. 112 | --verbose Show input video 113 | --help Show this message and exit.
114 | ``` 115 | 116 | ### Examples 117 | 118 | ```commandline 119 | python scripts/evaluate.py -r path/to/dataset -m path/to/model.ckpt -a path/to/annotation --label-path path/to/classInd.txt 120 | ``` 121 | 122 | # Model Architecture 123 | 124 | ![fig1.png](assets/fig1.png) 125 | ![fig2.png](assets/fig2.png) 126 | ![fig3.png](assets/fig3.png) 127 | 128 | # Positional embedding 129 | 130 | ![Position_Embedding.png](assets/Position_Embedding.png) 131 | 132 | -------------------------------------------------------------------------------- /assets/Position_Embedding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daniel-code/TubeViT/8df0d99aec9674f3ec2c5182534fb0cd4ab9aa49/assets/Position_Embedding.png -------------------------------------------------------------------------------- /assets/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daniel-code/TubeViT/8df0d99aec9674f3ec2c5182534fb0cd4ab9aa49/assets/fig1.png -------------------------------------------------------------------------------- /assets/fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daniel-code/TubeViT/8df0d99aec9674f3ec2c5182534fb0cd4ab9aa49/assets/fig2.png -------------------------------------------------------------------------------- /assets/fig3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daniel-code/TubeViT/8df0d99aec9674f3ec2c5182534fb0cd4ab9aa49/assets/fig3.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | # Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default. 3 | select = [ 4 | "E", # pycodestyle error 5 | "F", # Pyflakes 6 | "I", # isort 7 | "PLE", # Pylint error 8 | "PLW", # Pylint warning 9 | ] 10 | ignore = [] 11 | 12 | # Allow autofix for all enabled rules (when `--fix` is provided). 13 | fixable = ["F841", "F541"] 14 | unfixable = [] 15 | 16 | # Exclude a variety of commonly ignored directories. 17 | exclude = [ 18 | ".bzr", 19 | ".direnv", 20 | ".eggs", 21 | ".git", 22 | ".git-rewrite", 23 | ".hg", 24 | ".mypy_cache", 25 | ".nox", 26 | ".pants.d", 27 | ".pytype", 28 | ".ruff_cache", 29 | ".svn", 30 | ".tox", 31 | ".venv", 32 | "__pypackages__", 33 | "_build", 34 | "buck-out", 35 | "build", 36 | "dist", 37 | "node_modules", 38 | "venv", 39 | ] 40 | 41 | # Same as Black. 42 | line-length = 120 43 | 44 | # Allow unused variables when underscore-prefixed. 45 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 46 | 47 | # Assume Python 3.8 48 | target-version = "py38" 49 | 50 | [tool.ruff.mccabe] 51 | # Unlike Flake8, default to a complexity level of 10.
52 | max-complexity = 10 53 | 54 | [tool.black] 55 | line-length = 120 56 | target-version = ['py38'] 57 | 58 | [tool.isort] 59 | py_version = 38 60 | profile = "black" 61 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | black 2 | ruff 3 | isort 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | click~=8.1.7 2 | torch~=2.0.1 3 | torchvision~=0.15.2 4 | lightning~=2.0.7 5 | torchmetrics~=1.0.3 6 | pytorchvideo==0.1.5 7 | matplotlib~=3.7.2 8 | seaborn~=0.12.2 9 | -------------------------------------------------------------------------------- /scripts/convert_vit_weight.py: -------------------------------------------------------------------------------- 1 | import click 2 | import numpy as np 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import functional as F 6 | from torchvision.models import ViT_B_16_Weights 7 | 8 | from tubevit.model import TubeViT 9 | 10 | 11 | @click.command() 12 | @click.option("-nc", "--num-classes", type=int, default=101, help="num of classes of dataset.") 13 | @click.option("-f", "--frames-per-clip", type=int, default=32, help="frame per clip.") 14 | @click.option("-v", "--video-size", type=click.Tuple([int, int]), default=(224, 224), help="frame per clip.") 15 | @click.option( 16 | "-o", 17 | "--output-path", 18 | type=click.Path(), 19 | default="tubevit_b_(a+iv)+(d+v)+(e+iv)+(f+v).pt", 20 | help="output model weight name.", 21 | ) 22 | def main(num_classes, frames_per_clip, video_size, output_path): 23 | x = np.random.random((1, 3, frames_per_clip, video_size[0], video_size[1])) 24 | x = Tensor(x) 25 | print("x: ", x.shape) 26 | 27 | y = np.random.randint(0, 1, size=(1, num_classes)) 28 | y = Tensor(y) 29 | print("y: ", y.shape) 30 | 31 | model = TubeViT( 32 | num_classes=num_classes, 33 | video_shape=x.shape[1:], 34 | num_layers=12, 35 | num_heads=12, 36 | hidden_dim=768, 37 | mlp_dim=3072, 38 | ) 39 | 40 | weights = ViT_B_16_Weights.DEFAULT.get_state_dict(progress=True) 41 | 42 | # inflated vit path convolution layer weight 43 | conv_proj_weight = weights["conv_proj.weight"] 44 | conv_proj_weight = F.interpolate(conv_proj_weight, (8, 8), mode="bilinear") 45 | conv_proj_weight = torch.unsqueeze(conv_proj_weight, dim=2) 46 | conv_proj_weight = conv_proj_weight.repeat(1, 1, 8, 1, 1) 47 | conv_proj_weight = conv_proj_weight / 8.0 48 | 49 | # remove missmatch parameters 50 | weights.pop("encoder.pos_embedding") 51 | weights.pop("heads.head.weight") 52 | weights.pop("heads.head.bias") 53 | 54 | model.load_state_dict(weights, strict=False) 55 | model.sparse_tubes_tokenizer.conv_proj_weight = torch.nn.Parameter(conv_proj_weight, requires_grad=True) 56 | 57 | torch.save(model.state_dict(), output_path) 58 | 59 | 60 | if __name__ == "__main__": 61 | main() 62 | -------------------------------------------------------------------------------- /scripts/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import click 5 | import matplotlib.pyplot as plt 6 | import pytorch_lightning as pl 7 | import seaborn as sns 8 | import torch 9 | from pytorchvideo.transforms import Normalize 10 | from torch.utils.data import DataLoader, RandomSampler 11 | from torchmetrics.functional import accuracy, auroc, 
confusion_matrix, f1_score 12 | from torchvision.transforms import transforms as T 13 | from torchvision.transforms._transforms_video import ToTensorVideo 14 | 15 | from tubevit.dataset import MyUCF101 16 | from tubevit.model import TubeViTLightningModule 17 | 18 | 19 | @click.command() 20 | @click.option("-r", "--dataset-root", type=click.Path(exists=True), required=True, help="path to dataset.") 21 | @click.option("-m", "--model-path", type=click.Path(exists=True), required=True, help="path to model weight.") 22 | @click.option("-a", "--annotation-path", type=click.Path(exists=True), required=True, help="path to dataset.") 23 | @click.option("--label-path", type=click.Path(exists=True), required=True, help="path to classInd.txt.") 24 | @click.option("-nc", "--num-classes", type=int, default=101, help="num of classes of dataset.") 25 | @click.option("-b", "--batch-size", type=int, default=32, help="batch size.") 26 | @click.option("-f", "--frames-per-clip", type=int, default=32, help="frame per clip.") 27 | @click.option("-v", "--video-size", type=click.Tuple([int, int]), default=(224, 224), help="frame per clip.") 28 | @click.option("--num-workers", type=int, default=0) 29 | @click.option("--seed", type=int, default=42, help="random seed.") 30 | @click.option("--verbose", type=bool, is_flag=True, show_default=True, default=False, help="Show input video") 31 | def main( 32 | dataset_root, 33 | model_path, 34 | annotation_path, 35 | label_path, 36 | num_classes, 37 | batch_size, 38 | frames_per_clip, 39 | video_size, 40 | num_workers, 41 | seed, 42 | verbose, 43 | ): 44 | pl.seed_everything(seed) 45 | 46 | with open(label_path, "r") as f: 47 | labels = f.read().splitlines() 48 | labels = list(map(lambda x: x.split(" ")[-1], labels)) 49 | 50 | test_transform = T.Compose( 51 | [ 52 | ToTensorVideo(), 53 | T.Resize(size=video_size), 54 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 55 | ] 56 | ) 57 | 58 | val_metadata_file = "ucf101-val-meta.pickle" 59 | val_precomputed_metadata = None 60 | if os.path.exists(val_metadata_file): 61 | with open(val_metadata_file, "rb") as f: 62 | val_precomputed_metadata = pickle.load(f) 63 | 64 | val_set = MyUCF101( 65 | root=dataset_root, 66 | annotation_path=annotation_path, 67 | _precomputed_metadata=val_precomputed_metadata, 68 | frames_per_clip=frames_per_clip, 69 | train=False, 70 | output_format="THWC", 71 | transform=test_transform, 72 | ) 73 | 74 | if not os.path.exists(val_metadata_file): 75 | with open(val_metadata_file, "wb") as f: 76 | pickle.dump(val_set.metadata, f, protocol=pickle.HIGHEST_PROTOCOL) 77 | 78 | val_sampler = RandomSampler(val_set, num_samples=len(val_set) // 5000) 79 | val_dataloader = DataLoader( 80 | val_set, 81 | batch_size=batch_size, 82 | num_workers=num_workers, 83 | shuffle=False, 84 | drop_last=True, 85 | sampler=val_sampler, 86 | ) 87 | 88 | x, y = next(iter(val_dataloader)) 89 | print(x.shape) 90 | 91 | model = TubeViTLightningModule.load_from_checkpoint(model_path) 92 | 93 | trainer = pl.Trainer(accelerator="auto", default_root_dir="lightning_predict_logs") 94 | predictions = trainer.predict(model, dataloaders=val_dataloader) 95 | 96 | y = torch.cat([item["y"] for item in predictions]) 97 | y_pred = torch.cat([item["y_pred"] for item in predictions]) 98 | y_prob = torch.cat([item["y_prob"] for item in predictions]) 99 | 100 | print("accuracy:", accuracy(y_prob, y, task="multiclass", num_classes=num_classes)) 101 | print("accuracy_top5:", accuracy(y_prob, y, task="multiclass", 
num_classes=num_classes, top_k=5)) 102 | print("auroc:", auroc(y_prob, y, task="multiclass", num_classes=num_classes)) 103 | print("f1_score:", f1_score(y_prob, y, task="multiclass", num_classes=num_classes)) 104 | 105 | cm = confusion_matrix(y_pred, y, task="multiclass", num_classes=num_classes) 106 | 107 | plt.figure(figsize=(20, 20), dpi=100) 108 | ax = sns.heatmap(cm, annot=False, fmt="d", xticklabels=labels, yticklabels=labels) 109 | ax.set_xlabel("Prediction") 110 | ax.set_ylabel("Ground Truth") 111 | ax.set_title("Confusion Matrix") 112 | plt.tight_layout() 113 | plt.savefig("output.png", dpi=300) 114 | if verbose: 115 | plt.show() 116 | 117 | 118 | if __name__ == "__main__": 119 | main() 120 | -------------------------------------------------------------------------------- /scripts/infernce.py: -------------------------------------------------------------------------------- 1 | import click 2 | import torch 3 | from pytorchvideo.data.encoded_video import EncodedVideo 4 | from pytorchvideo.transforms import ApplyTransformToKey, UniformTemporalSubsample, ShortSideScale 5 | from torchvision.transforms import Compose, Lambda 6 | from torchvision.transforms._transforms_video import NormalizeVideo, CenterCropVideo 7 | 8 | from tubevit.model import TubeViTLightningModule 9 | 10 | 11 | @click.command() 12 | @click.argument("video-path") 13 | @click.option("-m", "--model-path", type=click.Path(exists=True), required=True, help="path to model weight.") 14 | @click.option("--label-path", type=click.Path(exists=True), required=True, help="path to classInd.txt.") 15 | @click.option("-f", "--frames-per-clip", type=int, default=32, help="frame per clip.") 16 | @click.option("-v", "--video-size", type=click.Tuple([int, int]), default=(224, 224), help="frame per clip.") 17 | def main( 18 | video_path, 19 | model_path, 20 | label_path, 21 | frames_per_clip, 22 | video_size, 23 | ): 24 | with open(label_path, "r") as f: 25 | labels = f.read().splitlines() 26 | labels = list(map(lambda x: x.split(" ")[-1], labels)) 27 | 28 | # Compose video data transforms 29 | transform = ApplyTransformToKey( 30 | key="video", 31 | transform=Compose( 32 | [ 33 | UniformTemporalSubsample(frames_per_clip), 34 | Lambda(lambda x: x / 255.0), 35 | NormalizeVideo(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 36 | ShortSideScale( 37 | size=video_size[0] 38 | ), 39 | CenterCropVideo(crop_size=video_size) 40 | ] 41 | ), 42 | ) 43 | 44 | # Load video 45 | video = EncodedVideo.from_path(video_path) 46 | # Get clip 47 | clip_start_sec = 0.0 # secs 48 | clip_duration = 2.0 # secs 49 | duration = video.duration 50 | video_data = [] 51 | for i in range(10): 52 | if clip_start_sec + clip_duration * (i + 1) <= duration: 53 | data = video.get_clip(start_sec=clip_start_sec + clip_duration * i, 54 | end_sec=clip_start_sec + clip_duration * (i + 1)) 55 | data = transform(data) 56 | video_data.append(data['video']) 57 | 58 | video_data = torch.stack(video_data) 59 | model = TubeViTLightningModule.load_from_checkpoint(model_path) 60 | prediction = model.predict_step(batch=(video_data, None), batch_idx=0) 61 | print(video_data.shape) 62 | print('Predict:', labels[torch.argmax(torch.sum(prediction['y_prob'], dim=0)).to('cpu').item()]) 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import click 5 | 
import lightning.pytorch as pl 6 | import matplotlib.pyplot as plt 7 | from lightning.pytorch.loggers import TensorBoardLogger 8 | from pytorchvideo.transforms import Normalize, Permute, RandAugment 9 | from torch.utils.data import DataLoader 10 | from torchvision.transforms import transforms as T 11 | from torchvision.transforms._transforms_video import ToTensorVideo 12 | 13 | from tubevit.dataset import MyUCF101 14 | from tubevit.model import TubeViTLightningModule 15 | 16 | 17 | @click.command() 18 | @click.option("-r", "--dataset-root", type=click.Path(exists=True), required=True, help="path to dataset.") 19 | @click.option("-a", "--annotation-path", type=click.Path(exists=True), required=True, help="path to dataset.") 20 | @click.option("-nc", "--num-classes", type=int, default=101, help="num of classes of dataset.") 21 | @click.option("-b", "--batch-size", type=int, default=32, help="batch size.") 22 | @click.option("-f", "--frames-per-clip", type=int, default=32, help="frame per clip.") 23 | @click.option("-v", "--video-size", type=click.Tuple([int, int]), default=(224, 224), help="frame per clip.") 24 | @click.option("--max-epochs", type=int, default=10, help="max epochs.") 25 | @click.option("--num-workers", type=int, default=0) 26 | @click.option("--fast-dev-run", type=bool, is_flag=True, show_default=True, default=False) 27 | @click.option("--seed", type=int, default=42, help="random seed.") 28 | @click.option("--preview-video", type=bool, is_flag=True, show_default=True, default=False, help="Show input video") 29 | def main( 30 | dataset_root, 31 | annotation_path, 32 | num_classes, 33 | batch_size, 34 | frames_per_clip, 35 | video_size, 36 | max_epochs, 37 | num_workers, 38 | fast_dev_run, 39 | seed, 40 | preview_video, 41 | ): 42 | pl.seed_everything(seed) 43 | 44 | imagenet_mean = [0.485, 0.456, 0.406] 45 | imagenet_std = [0.229, 0.224, 0.225] 46 | 47 | train_transform = T.Compose( 48 | [ 49 | ToTensorVideo(), # C, T, H, W 50 | Permute(dims=[1, 0, 2, 3]), # T, C, H, W 51 | RandAugment(magnitude=10, num_layers=2), 52 | Permute(dims=[1, 0, 2, 3]), # C, T, H, W 53 | T.Resize(size=video_size), 54 | Normalize(mean=imagenet_mean, std=imagenet_std), 55 | ] 56 | ) 57 | 58 | test_transform = T.Compose( 59 | [ 60 | ToTensorVideo(), 61 | T.Resize(size=video_size), 62 | Normalize(mean=imagenet_mean, std=imagenet_std), 63 | ] 64 | ) 65 | 66 | train_metadata_file = "ucf101-train-meta.pickle" 67 | train_precomputed_metadata = None 68 | if os.path.exists(train_metadata_file): 69 | with open(train_metadata_file, "rb") as f: 70 | train_precomputed_metadata = pickle.load(f) 71 | 72 | train_set = MyUCF101( 73 | root=dataset_root, 74 | annotation_path=annotation_path, 75 | _precomputed_metadata=train_precomputed_metadata, 76 | frames_per_clip=frames_per_clip, 77 | train=True, 78 | output_format="THWC", 79 | transform=train_transform, 80 | ) 81 | 82 | if not os.path.exists(train_metadata_file): 83 | with open(train_metadata_file, "wb") as f: 84 | pickle.dump(train_set.metadata, f, protocol=pickle.HIGHEST_PROTOCOL) 85 | 86 | val_metadata_file = "ucf101-val-meta.pickle" 87 | val_precomputed_metadata = None 88 | if os.path.exists(val_metadata_file): 89 | with open(val_metadata_file, "rb") as f: 90 | val_precomputed_metadata = pickle.load(f) 91 | 92 | val_set = MyUCF101( 93 | root=dataset_root, 94 | annotation_path=annotation_path, 95 | _precomputed_metadata=val_precomputed_metadata, 96 | frames_per_clip=frames_per_clip, 97 | train=False, 98 | output_format="THWC", 99 | transform=test_transform, 100 
| ) 101 | 102 | if not os.path.exists(val_metadata_file): 103 | with open(val_metadata_file, "wb") as f: 104 | pickle.dump(val_set.metadata, f, protocol=pickle.HIGHEST_PROTOCOL) 105 | 106 | train_dataloader = DataLoader( 107 | train_set, 108 | batch_size=batch_size, 109 | num_workers=num_workers, 110 | shuffle=True, 111 | drop_last=True, 112 | pin_memory=True, 113 | ) 114 | 115 | val_dataloader = DataLoader( 116 | val_set, 117 | batch_size=batch_size, 118 | num_workers=num_workers, 119 | shuffle=False, 120 | drop_last=True, 121 | pin_memory=True, 122 | ) 123 | 124 | x, y = next(iter(train_dataloader)) 125 | print(x.shape) 126 | 127 | if preview_video: 128 | x = x.permute(0, 2, 3, 4, 1) 129 | fig, axs = plt.subplots(4, 8) 130 | for i in range(4): 131 | for j in range(8): 132 | axs[i][j].imshow(x[0][i * 8 + j]) 133 | axs[i][j].set_xticks([]) 134 | axs[i][j].set_yticks([]) 135 | plt.tight_layout() 136 | plt.show() 137 | 138 | model = TubeViTLightningModule( 139 | num_classes=num_classes, 140 | video_shape=x.shape[1:], 141 | num_layers=12, 142 | num_heads=12, 143 | hidden_dim=768, 144 | mlp_dim=3072, 145 | lr=1e-4, 146 | weight_decay=0.001, 147 | weight_path="tubevit_b_(a+iv)+(d+v)+(e+iv)+(f+v).pt", 148 | max_epochs=max_epochs, 149 | ) 150 | 151 | callbacks = [pl.callbacks.LearningRateMonitor(logging_interval="epoch")] 152 | logger = TensorBoardLogger("logs", name="TubeViT") 153 | 154 | trainer = pl.Trainer( 155 | max_epochs=max_epochs, 156 | accelerator="auto", 157 | fast_dev_run=fast_dev_run, 158 | logger=logger, 159 | callbacks=callbacks, 160 | ) 161 | trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader) 162 | trainer.save_checkpoint("./models/tubevit_ucf101.ckpt") 163 | 164 | 165 | if __name__ == "__main__": 166 | main() 167 | -------------------------------------------------------------------------------- /scripts/visualise_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import click 5 | import matplotlib.pyplot as plt 6 | import pytorch_lightning as pl 7 | from pytorchvideo.transforms import Normalize, Permute, RandAugment 8 | from torch.utils.data import DataLoader, RandomSampler 9 | from torchvision.transforms import transforms as T 10 | from torchvision.transforms._transforms_video import ToTensorVideo 11 | 12 | from tubevit.dataset import MyUCF101 13 | 14 | 15 | @click.command() 16 | @click.option("-r", "--dataset-root", type=click.Path(exists=True), required=True, help="path to dataset.") 17 | @click.option("-a", "--annotation-path", type=click.Path(exists=True), required=True, help="path to dataset.") 18 | @click.option("--label-path", type=click.Path(exists=True), required=True, help="path to classInd.txt.") 19 | @click.option("-b", "--batch-size", type=int, default=32, help="batch size.") 20 | @click.option("-f", "--frames-per-clip", type=int, default=32, help="frame per clip.") 21 | @click.option("-v", "--video-size", type=click.Tuple([int, int]), default=(224, 224), help="frame per clip.") 22 | @click.option("--num-workers", type=int, default=0) 23 | @click.option("--seed", type=int, default=42, help="random seed.") 24 | def main(dataset_root, video_size, annotation_path, label_path, frames_per_clip, batch_size, num_workers, seed): 25 | pl.seed_everything(seed) 26 | with open(label_path, "r") as f: 27 | labels = f.read().splitlines() 28 | labels = list(map(lambda x: x.split(" ")[-1], labels)) 29 | 30 | imagenet_mean = [0.485, 0.456, 0.406] 31 | imagenet_std = 
[0.229, 0.224, 0.225] 32 | 33 | train_transform = T.Compose( 34 | [ 35 | ToTensorVideo(), # C, T, H, W 36 | Permute(dims=[1, 0, 2, 3]), # T, C, H, W 37 | RandAugment(magnitude=10, num_layers=2), 38 | Permute(dims=[1, 0, 2, 3]), # C, T, H, W 39 | T.Resize(size=video_size), 40 | Normalize(mean=imagenet_mean, std=imagenet_std), 41 | ] 42 | ) 43 | 44 | train_metadata_file = "ucf101-train-meta.pickle" 45 | train_precomputed_metadata = None 46 | if os.path.exists(train_metadata_file): 47 | with open(train_metadata_file, "rb") as f: 48 | train_precomputed_metadata = pickle.load(f) 49 | 50 | train_set = MyUCF101( 51 | root=dataset_root, 52 | annotation_path=annotation_path, 53 | _precomputed_metadata=train_precomputed_metadata, 54 | frames_per_clip=frames_per_clip, 55 | train=True, 56 | output_format="THWC", 57 | transform=train_transform, 58 | ) 59 | 60 | if not os.path.exists(train_metadata_file): 61 | with open(train_metadata_file, "wb") as f: 62 | pickle.dump(train_set.metadata, f, protocol=pickle.HIGHEST_PROTOCOL) 63 | 64 | train_sampler = RandomSampler(train_set, num_samples=len(train_set) // 10) 65 | train_dataloader = DataLoader( 66 | train_set, 67 | batch_size=batch_size, 68 | num_workers=num_workers, 69 | shuffle=False, 70 | drop_last=True, 71 | sampler=train_sampler, 72 | ) 73 | 74 | x, y = next(iter(train_dataloader)) 75 | 76 | x = x.permute(0, 2, 3, 4, 1) # CTHW->THWC 77 | 78 | fig, axs = plt.subplots(batch_size // 4, 8, figsize=(batch_size // 4, 8)) 79 | for i in range(batch_size // 4): 80 | axs[i][0].set_title(labels[y[i]]) 81 | for j in range(8): 82 | axs[i][j].imshow(x[i][j]) 83 | axs[i][j].set_xticks([]) 84 | axs[i][j].set_yticks([]) 85 | plt.tight_layout() 86 | plt.show() 87 | 88 | 89 | if __name__ == "__main__": 90 | main() 91 | -------------------------------------------------------------------------------- /scripts/visualise_pos_embed.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | 4 | from tubevit.positional_encoding import get_3d_sincos_pos_embed 5 | 6 | if __name__ == "__main__": 7 | kernel_sizes = ( 8 | (8, 8, 8), 9 | (16, 4, 4), 10 | (4, 12, 12), 11 | (1, 16, 16), 12 | ) 13 | 14 | strides = ( 15 | (16, 32, 32), 16 | (6, 32, 32), 17 | (16, 32, 32), 18 | (32, 16, 16), 19 | ) 20 | 21 | offsets = ( 22 | (0, 0, 0), 23 | (4, 8, 8), 24 | (0, 16, 16), 25 | (0, 0, 0), 26 | ) 27 | 28 | tube_shape = ( 29 | (2, 7, 7), 30 | (3, 7, 7), 31 | (2, 7, 7), 32 | (1, 14, 14), 33 | ) 34 | 35 | pos_encode = [torch.zeros(1, 768)] 36 | 37 | for i in range(len(kernel_sizes)): 38 | pos_embed = get_3d_sincos_pos_embed( 39 | embed_dim=768, tube_shape=tube_shape[i], stride=strides[i], offset=offsets[i], kernel_size=kernel_sizes[i] 40 | ) 41 | # pos_embed = torch.Tensor(pos_embed) 42 | pos_encode.append(pos_embed) 43 | 44 | pos_encode = torch.cat(pos_encode) 45 | plt.imshow(pos_encode) 46 | plt.title("Position Embedding") 47 | plt.xlabel("embed_dim") 48 | plt.ylabel("index of tokens") 49 | plt.tight_layout() 50 | plt.savefig("Position_Embedding.png") 51 | plt.show() 52 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | # yapf: disable 4 | setup( 5 | name='tubevit', 6 | packages=find_packages(), 7 | version='0.1.0', 8 | description='An unofficial implementation of TubeViT in "Rethinking Video ViTs: Sparse Video Tubes for Joint Image 
and Video Learning".', 9 | author='Su YR', 10 | license="MIT", 11 | ) 12 | -------------------------------------------------------------------------------- /tubevit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daniel-code/TubeViT/8df0d99aec9674f3ec2c5182534fb0cd4ab9aa49/tubevit/__init__.py -------------------------------------------------------------------------------- /tubevit/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Tuple 2 | 3 | from torch import Tensor 4 | from torchvision.datasets import UCF101 5 | 6 | 7 | class MyUCF101(UCF101): 8 | def __init__(self, transform: Optional[Callable] = None, *args, **kwargs) -> None: 9 | super().__init__(*args, **kwargs) 10 | self.transform = transform 11 | 12 | def __getitem__(self, idx: int) -> Tuple[Tensor, int]: 13 | video, audio, info, video_idx = self.video_clips.get_clip(idx) 14 | label = self.samples[self.indices[video_idx]][1] 15 | 16 | if self.transform is not None: 17 | video = self.transform(video) 18 | 19 | return video, label 20 | -------------------------------------------------------------------------------- /tubevit/model.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Callable, List, Union 3 | 4 | import lightning.pytorch as pl 5 | import numpy as np 6 | import torch 7 | from torch import Tensor, nn, optim 8 | from torch.nn import functional as F 9 | from torchmetrics.functional import accuracy, f1_score 10 | from torchvision.models.vision_transformer import EncoderBlock 11 | from typing_extensions import OrderedDict 12 | 13 | from tubevit.positional_encoding import get_3d_sincos_pos_embed 14 | 15 | 16 | class Encoder(nn.Module): 17 | """ 18 | Transformer Model Encoder for sequence to sequence translation. 19 | Code from torch. 
20 | Move pos_embedding to TubeViT 21 | """ 22 | 23 | def __init__( 24 | self, 25 | num_layers: int, 26 | num_heads: int, 27 | hidden_dim: int, 28 | mlp_dim: int, 29 | dropout: float, 30 | attention_dropout: float, 31 | norm_layer: Callable[..., nn.Module] = partial(nn.LayerNorm, eps=1e-6), 32 | ): 33 | super().__init__() 34 | self.dropout = nn.Dropout(dropout) 35 | layers: OrderedDict[str, nn.Module] = OrderedDict() 36 | for i in range(num_layers): 37 | layers[f"encoder_layer_{i}"] = EncoderBlock( 38 | num_heads, 39 | hidden_dim, 40 | mlp_dim, 41 | dropout, 42 | attention_dropout, 43 | norm_layer, 44 | ) 45 | self.layers = nn.Sequential(layers) 46 | self.ln = norm_layer(hidden_dim) 47 | 48 | def forward(self, x: Tensor): 49 | torch._assert(x.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {x.shape}") 50 | return self.ln(self.layers(self.dropout(x))) 51 | 52 | 53 | class SparseTubesTokenizer(nn.Module): 54 | def __init__(self, hidden_dim, kernel_sizes, strides, offsets): 55 | super().__init__() 56 | self.hidden_dim = hidden_dim 57 | self.kernel_sizes = kernel_sizes 58 | self.strides = strides 59 | self.offsets = offsets 60 | 61 | self.conv_proj_weight = nn.Parameter( 62 | torch.empty((self.hidden_dim, 3, *self.kernel_sizes[0])).normal_(), requires_grad=True 63 | ) 64 | 65 | self.register_parameter("conv_proj_weight", self.conv_proj_weight) 66 | 67 | self.conv_proj_bias = nn.Parameter(torch.zeros(len(self.kernel_sizes), self.hidden_dim), requires_grad=True) 68 | self.register_parameter("conv_proj_bias", self.conv_proj_bias) 69 | 70 | def forward(self, x: Tensor) -> Tensor: 71 | n, c, t, h, w = x.shape # CTHW 72 | tubes = [] 73 | for i in range(len(self.kernel_sizes)): 74 | if i == 0: 75 | weight = self.conv_proj_weight 76 | else: 77 | weight = F.interpolate(self.conv_proj_weight, self.kernel_sizes[i], mode="trilinear") 78 | 79 | tube = F.conv3d( 80 | x[:, :, self.offsets[i][0] :, self.offsets[i][1] :, self.offsets[i][2] :], 81 | weight, 82 | bias=self.conv_proj_bias[i], 83 | stride=self.strides[i], 84 | ) 85 | 86 | tube = tube.reshape((n, self.hidden_dim, -1)) 87 | 88 | tubes.append(tube) 89 | 90 | x = torch.cat(tubes, dim=-1) 91 | x = x.permute(0, 2, 1).contiguous() 92 | return x 93 | 94 | 95 | class SelfAttentionPooling(nn.Module): 96 | """ 97 | Implementation of SelfAttentionPooling 98 | Original Paper: Self-Attention Encoding and Pooling for Speaker Recognition 99 | https://arxiv.org/pdf/2008.01077v1.pdf 100 | 101 | code from https://gist.github.com/pohanchi/c77f6dbfbcbc21c5215acde4f62e4362 102 | """ 103 | 104 | def __init__(self, input_dim): 105 | super(SelfAttentionPooling, self).__init__() 106 | self.W = nn.Linear(input_dim, 1) 107 | 108 | def forward(self, x): 109 | """ 110 | input: 111 | batch_rep : size (N, T, H), N: batch size, T: sequence length, H: Hidden dimension 112 | 113 | attention_weight: 114 | att_w : size (N, T, 1) 115 | 116 | return: 117 | utter_rep: size (N, H) 118 | """ 119 | 120 | # (N, T, H) -> (N, T) -> (N, T, 1) 121 | att_w = nn.functional.softmax(self.W(x).squeeze(dim=-1), dim=-1).unsqueeze(dim=-1) 122 | x = torch.sum(x * att_w, dim=1) 123 | return x 124 | 125 | 126 | class TubeViT(nn.Module): 127 | def __init__( 128 | self, 129 | num_classes: int, 130 | video_shape: Union[List[int], np.ndarray], # CTHW 131 | num_layers: int, 132 | num_heads: int, 133 | hidden_dim: int, 134 | mlp_dim: int, 135 | dropout: float = 0.0, 136 | attention_dropout: float = 0.0, 137 | representation_size=None, 138 | ): 139 | super(TubeViT, self).__init__() 140 | 
self.video_shape = np.array(video_shape) # CTHW 141 | self.num_classes = num_classes 142 | self.hidden_dim = hidden_dim 143 | self.kernel_sizes = ( 144 | (8, 8, 8), 145 | (16, 4, 4), 146 | (4, 12, 12), 147 | (1, 16, 16), 148 | ) 149 | 150 | self.strides = ( 151 | (16, 32, 32), 152 | (6, 32, 32), 153 | (16, 32, 32), 154 | (32, 16, 16), 155 | ) 156 | 157 | self.offsets = ( 158 | (0, 0, 0), 159 | (4, 8, 8), 160 | (0, 16, 16), 161 | (0, 0, 0), 162 | ) 163 | self.sparse_tubes_tokenizer = SparseTubesTokenizer( 164 | self.hidden_dim, self.kernel_sizes, self.strides, self.offsets 165 | ) 166 | 167 | self.pos_embedding = self._generate_position_embedding() 168 | self.pos_embedding = torch.nn.Parameter(self.pos_embedding, requires_grad=False) 169 | self.register_parameter("pos_embedding", self.pos_embedding) 170 | 171 | # Add a class token 172 | self.class_token = nn.Parameter(torch.zeros(1, 1, self.hidden_dim), requires_grad=True) 173 | self.register_parameter("class_token", self.class_token) 174 | 175 | self.encoder = Encoder( 176 | num_layers=num_layers, 177 | num_heads=num_heads, 178 | hidden_dim=self.hidden_dim, 179 | mlp_dim=mlp_dim, 180 | dropout=dropout, 181 | attention_dropout=attention_dropout, 182 | ) 183 | 184 | self.attention_pooling = SelfAttentionPooling(self.hidden_dim) 185 | 186 | heads_layers: OrderedDict[str, nn.Module] = OrderedDict() 187 | if representation_size is None: 188 | heads_layers["head"] = nn.Linear(self.hidden_dim, self.num_classes) 189 | else: 190 | heads_layers["pre_logits"] = nn.Linear(self.hidden_dim, representation_size) 191 | heads_layers["act"] = nn.Tanh() 192 | heads_layers["head"] = nn.Linear(representation_size, self.num_classes) 193 | 194 | self.heads = nn.Sequential(heads_layers) 195 | 196 | def forward(self, x): 197 | x = self.sparse_tubes_tokenizer(x) 198 | n = x.shape[0] 199 | 200 | # Expand the class token to the full batch 201 | batch_class_token = self.class_token.expand(n, -1, -1) 202 | x = torch.cat([batch_class_token, x], dim=1) 203 | 204 | x = x + self.pos_embedding 205 | 206 | x = self.encoder(x) 207 | 208 | # Attention pooling 209 | x = self.attention_pooling(x) 210 | 211 | x = self.heads(x) 212 | 213 | return x 214 | 215 | def _calc_conv_shape(self, kernel_size, stride, offset) -> np.ndarray: 216 | kernel_size = np.array(kernel_size) 217 | stride = np.array(stride) 218 | offset = np.array(offset) 219 | output = np.floor(((self.video_shape[[1, 2, 3]] - offset - kernel_size) / stride) + 1).astype(int) 220 | return output 221 | 222 | def _generate_position_embedding(self) -> torch.nn.Parameter: 223 | position_embedding = [torch.zeros(1, self.hidden_dim)] 224 | 225 | for i in range(len(self.kernel_sizes)): 226 | tube_shape = self._calc_conv_shape(self.kernel_sizes[i], self.strides[i], self.offsets[i]) 227 | pos_embed = get_3d_sincos_pos_embed( 228 | embed_dim=self.hidden_dim, 229 | tube_shape=tube_shape, 230 | kernel_size=self.kernel_sizes[i], 231 | stride=self.strides[i], 232 | offset=self.offsets[i], 233 | ) 234 | position_embedding.append(pos_embed) 235 | 236 | position_embedding = torch.cat(position_embedding, dim=0).contiguous() 237 | return position_embedding 238 | 239 | 240 | class TubeViTLightningModule(pl.LightningModule): 241 | def __init__( 242 | self, 243 | num_classes, 244 | video_shape, 245 | num_layers, 246 | num_heads, 247 | hidden_dim, 248 | mlp_dim, 249 | lr: float = 3e-4, 250 | weight_decay: float = 0, 251 | weight_path: str = None, 252 | max_epochs: int = None, 253 | label_smoothing: float = 0.0, 254 | dropout: float = 0.0, 255 
| attention_dropout: float = 0.0, 256 | **kwargs, 257 | ): 258 | self.save_hyperparameters() 259 | super().__init__() 260 | self.num_classes = num_classes 261 | self.model = TubeViT( 262 | num_classes=num_classes, 263 | video_shape=video_shape, 264 | num_layers=num_layers, 265 | num_heads=num_heads, 266 | hidden_dim=hidden_dim, 267 | mlp_dim=mlp_dim, 268 | dropout=dropout, 269 | attention_dropout=attention_dropout, 270 | ) 271 | 272 | self.lr = lr 273 | self.loss_func = nn.CrossEntropyLoss(label_smoothing=label_smoothing) 274 | self.example_input_array = Tensor(1, *video_shape) 275 | 276 | if weight_path is not None: 277 | self.model.load_state_dict(torch.load(weight_path), strict=False) 278 | self.max_epochs = max_epochs 279 | self.weight_decay = weight_decay 280 | 281 | def forward(self, x): 282 | return self.model(x) 283 | 284 | def training_step(self, batch, batch_idx): 285 | x, y = batch 286 | y_hat = self(x) 287 | 288 | loss = self.loss_func(y_hat, y) 289 | 290 | y_pred = torch.softmax(y_hat, dim=-1) 291 | 292 | # Logging to TensorBoard by default 293 | self.log("train_loss", loss, prog_bar=True) 294 | self.log("train_acc", accuracy(y_pred, y, task="multiclass", num_classes=self.num_classes), prog_bar=True) 295 | self.log("train_f1", f1_score(y_pred, y, task="multiclass", num_classes=self.num_classes), prog_bar=True) 296 | 297 | return loss 298 | 299 | def validation_step(self, batch, batch_idx): 300 | x, y = batch 301 | y_hat = self(x) 302 | 303 | loss = self.loss_func(y_hat, y) 304 | 305 | y_pred = torch.softmax(y_hat, dim=-1) 306 | 307 | # Logging to TensorBoard by default 308 | self.log("val_loss", loss, prog_bar=True) 309 | self.log("val_acc", accuracy(y_pred, y, task="multiclass", num_classes=self.num_classes), prog_bar=True) 310 | self.log("val_f1", f1_score(y_pred, y, task="multiclass", num_classes=self.num_classes), prog_bar=True) 311 | 312 | return loss 313 | 314 | def on_train_epoch_end(self) -> None: 315 | self.log("lr", self.optimizers().optimizer.param_groups[0]["lr"], on_step=False, on_epoch=True) 316 | 317 | def configure_optimizers(self): 318 | optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay) 319 | if self.max_epochs is not None: 320 | lr_scheduler = optim.lr_scheduler.OneCycleLR( 321 | optimizer=optimizer, max_lr=self.lr, total_steps=self.max_epochs 322 | ) 323 | return [optimizer], [lr_scheduler] 324 | else: 325 | return optimizer 326 | 327 | def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: 328 | x, y = batch 329 | y_hat = self(x) 330 | y_pred = torch.softmax(y_hat, dim=-1) 331 | 332 | return {"y": y, "y_pred": torch.argmax(y_pred, dim=-1), "y_prob": y_pred} 333 | -------------------------------------------------------------------------------- /tubevit/positional_encoding.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inspired by positional_encoding in [pytorchvideo](https://github.com/facebookresearch/pytorchvideo/blob/f7e7a88a9a04b70cb65a564acfc38538fe71ff7b/pytorchvideo/layers/positional_encoding.py). 3 | Convert to pytorch version. 4 | """ 5 | 6 | from typing import Tuple 7 | 8 | import torch 9 | 10 | 11 | def get_3d_sincos_pos_embed( 12 | embed_dim: int, tube_shape: Tuple[int, int, int], stride, offset, kernel_size, cls_token: bool = False 13 | ) -> torch.Tensor: 14 | """ 15 | Get 3D sine-cosine positional embedding. 
16 | Args: 17 | tube_shape: (t_size, grid_h_size, grid_w_size) 18 | kernel_size: 19 | offset: 20 | stride: 21 | embed_dim: 22 | cls_token: bool, whether to contain CLS token 23 | Returns: 24 | (torch.Tensor): [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] 25 | (w/ or w/o cls_token) 26 | """ 27 | assert embed_dim % 4 == 0 28 | embed_dim_spatial = embed_dim // 3 * 2 29 | embed_dim_temporal = embed_dim // 3 30 | 31 | # spatial 32 | grid_h_size = tube_shape[1] 33 | grid_h = torch.arange(grid_h_size, dtype=torch.float) 34 | grid_h = grid_h * stride[1] + offset[1] + kernel_size[1] // 2 35 | 36 | grid_w_size = tube_shape[2] 37 | grid_w = torch.arange(tube_shape[2], dtype=torch.float) 38 | grid_w = grid_w * stride[2] + offset[2] + kernel_size[2] // 2 39 | grid = torch.meshgrid(grid_w, grid_h, indexing="ij") 40 | grid = torch.stack(grid, dim=0) 41 | 42 | grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) 43 | pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid) 44 | 45 | # temporal 46 | t_size = tube_shape[0] 47 | grid_t = torch.arange(t_size, dtype=torch.float) 48 | grid_t = grid_t * stride[0] + offset[0] + kernel_size[0] // 2 49 | pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t) 50 | 51 | pos_embed_temporal = pos_embed_temporal[:, None, :] 52 | pos_embed_temporal = torch.repeat_interleave(pos_embed_temporal, grid_h_size * grid_w_size, dim=1) 53 | pos_embed_spatial = pos_embed_spatial[None, :, :] 54 | pos_embed_spatial = torch.repeat_interleave(pos_embed_spatial, t_size, dim=0) 55 | 56 | pos_embed = torch.cat([pos_embed_temporal, pos_embed_spatial], dim=-1) 57 | pos_embed = pos_embed.reshape([-1, embed_dim]) 58 | 59 | if cls_token: 60 | pos_embed = torch.cat([torch.zeros([1, embed_dim]), pos_embed], dim=0) 61 | return pos_embed 62 | 63 | 64 | def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = False) -> torch.Tensor: 65 | """ 66 | Get 2D sine-cosine positional embedding. 67 | Args: 68 | grid_size: int of the grid height and width 69 | cls_token: bool, whether to contain CLS token 70 | Returns: 71 | (torch.Tensor): [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 72 | """ 73 | grid_h = torch.arange(grid_size, dtype=torch.float) 74 | grid_w = torch.arange(grid_size, dtype=torch.float) 75 | grid = torch.meshgrid(grid_w, grid_h, indexing="ij") 76 | grid = torch.stack(grid, dim=0) 77 | 78 | grid = grid.reshape([2, 1, grid_size, grid_size]) 79 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 80 | if cls_token: 81 | pos_embed = torch.cat([torch.zeros([1, embed_dim]), pos_embed], dim=0) 82 | return pos_embed 83 | 84 | 85 | def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor: 86 | """ 87 | Get 2D sine-cosine positional embedding from grid. 88 | Args: 89 | embed_dim: embedding dimension. 90 | grid: positions 91 | Returns: 92 | (torch.Tensor): [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 93 | """ 94 | assert embed_dim % 2 == 0 95 | 96 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) 97 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) 98 | 99 | emb = torch.cat([emb_h, emb_w], dim=1) 100 | return emb 101 | 102 | 103 | def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor: 104 | """ 105 | Get 1D sine-cosine positional embedding. 
106 | Args: 107 | embed_dim: output dimension for each position 108 | pos: a list of positions to be encoded: size (M,) 109 | Returns: 110 | (torch.Tensor): tensor of shape (M, D) 111 | """ 112 | assert embed_dim % 2 == 0 113 | omega = torch.arange(embed_dim // 2, dtype=torch.float) 114 | omega /= embed_dim / 2.0 115 | omega = 1.0 / 10000**omega 116 | 117 | pos = pos.reshape(-1) 118 | out = torch.einsum("m,d->md", pos, omega) 119 | 120 | emb_sin = torch.sin(out) 121 | emb_cos = torch.cos(out) 122 | 123 | emb = torch.cat([emb_sin, emb_cos], dim=1) 124 | return emb 125 | --------------------------------------------------------------------------------
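For a quick end-to-end sanity check of the pieces above, the sketch below (not part of the repository) instantiates `TubeViT` from `tubevit/model.py` with the same ViT-B hyper-parameters used in `scripts/convert_vit_weight.py` and runs a random 32-frame clip through it. It assumes the package has been installed (e.g. `pip install -e .`) so that `tubevit` is importable.

```python
import torch

from tubevit.model import TubeViT

# Shapes mirror scripts/convert_vit_weight.py: one clip of 32 RGB frames at 224x224, laid out as (C, T, H, W).
video_shape = (3, 32, 224, 224)

model = TubeViT(
    num_classes=101,            # UCF101
    video_shape=video_shape,
    num_layers=12,
    num_heads=12,
    hidden_dim=768,
    mlp_dim=3072,
)

x = torch.rand(1, *video_shape)  # random clip, batch size 1
with torch.no_grad():
    logits = model(x)

print(logits.shape)  # expected: torch.Size([1, 101])
```

With the default kernel sizes, strides, and offsets, the four sparse tubes should yield 98 + 147 + 98 + 196 = 539 tokens for a 32×224×224 clip; together with the class token this matches the 540-row positional embedding built by `_generate_position_embedding`.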