├── assets ├── method_fig.png ├── results_fig.png ├── teaser_fig.png ├── MTL_t5_xxl_text_embeddings.npy ├── Cat101_t5_xxl_text_embeddings.npy └── FineDiving_t5_xxl_text_embeddings.npy ├── libs ├── core │ ├── __init__.py │ └── config.py ├── datasets │ ├── __init__.py │ ├── datasets.py │ ├── data_utils.py │ ├── mtl_aqa.py │ └── finediving.py ├── utils │ ├── __init__.py │ ├── postprocessing.py │ ├── preprocessing.py │ ├── metrics.py │ ├── lr_schedulers.py │ └── train_utils.py └── modeling │ ├── __init__.py │ ├── models.py │ ├── weight_init.py │ ├── necks.py │ ├── backbones.py │ ├── losses.py │ ├── meta_archs.py │ ├── i3d.py │ └── blocks.py ├── requirements.txt ├── INSTALL.md ├── LICENSE ├── README.md ├── .gitignore ├── GETTING_STARTED.md ├── configs ├── mtl_aqa │ ├── deter_mtl_aqa_text_data_query.yaml │ └── stoch_mtl_aqa_text_data_query.yaml └── fine │ ├── deter_fine_diving_text_data_query.yaml │ └── stoch_fine_diving_text_data_query.yaml ├── tools ├── mtl_aqa │ └── mtl_t5xxl_text_embed_extraction.py └── finediving │ └── finediving_t5xxl_text_embed_extraction.py ├── eval.py └── train.py /assets/method_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrarmajeedi/rica2_aqa/HEAD/assets/method_fig.png -------------------------------------------------------------------------------- /assets/results_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrarmajeedi/rica2_aqa/HEAD/assets/results_fig.png -------------------------------------------------------------------------------- /assets/teaser_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrarmajeedi/rica2_aqa/HEAD/assets/teaser_fig.png -------------------------------------------------------------------------------- /assets/MTL_t5_xxl_text_embeddings.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrarmajeedi/rica2_aqa/HEAD/assets/MTL_t5_xxl_text_embeddings.npy -------------------------------------------------------------------------------- /assets/Cat101_t5_xxl_text_embeddings.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrarmajeedi/rica2_aqa/HEAD/assets/Cat101_t5_xxl_text_embeddings.npy -------------------------------------------------------------------------------- /libs/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import load_default_config, load_config 2 | 3 | __all__ = ['load_default_config', 'load_config'] 4 | -------------------------------------------------------------------------------- /assets/FineDiving_t5_xxl_text_embeddings.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abrarmajeedi/rica2_aqa/HEAD/assets/FineDiving_t5_xxl_text_embeddings.npy -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ## Requirements 2 | Tensorboard >= 2.14.0 3 | PyYaml >= 6.0.2 4 | einops >=0.8.0 5 | scipy >=1.10.1 6 | numpy>=1.24.1 7 | Pillow>=10.2.0 8 | transformers>=4.45.2 9 | -------------------------------------------------------------------------------- /libs/datasets/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .data_utils import worker_init_reset_seed 2 | from .datasets import make_dataset, make_data_loader 3 | from . import finediving, mtl_aqa # other datasets go here 4 | 5 | __all__ = ['worker_init_reset_seed', 6 | 'make_dataset', 'make_data_loader'] 7 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ## Conda environment 4 | Python >= 3.8 5 | 6 | conda create --name aqa_env python=3.8 7 | conda activate aqa_env 8 | 9 | ## PyTorch 10 | 11 | PyTorch >= 2.4.1 12 | [(Get PyTorch Here)](https://pytorch.org/get-started/locally/) 13 | 14 | ## Requirements 15 | Install the required libraries with : 16 | 17 | pip3 install -r requirements.txt -------------------------------------------------------------------------------- /libs/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .train_utils import (Logger, make_optimizer, make_scheduler, save_checkpoint, 2 | AverageMeter, train_one_epoch, valid_one_epoch, 3 | fix_random_seed, ModelEma) 4 | 5 | # from .postprocessing import postprocess_results 6 | 7 | __all__ = ['Logger','make_optimizer', 'make_scheduler', 'save_checkpoint', 8 | 'AverageMeter', 'train_one_epoch', 'valid_one_epoch', 9 | 'postprocess_results', 'fix_random_seed', 'ModelEma', 'remove_duplicate_annotations'] 10 | -------------------------------------------------------------------------------- /libs/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .blocks import (MaskedConv1D,MaskedMHCA, MaskedMHA, LayerNorm, 2 | TransformerBlock, ConvBlock, Scale, AffineDropPath) 3 | from .models import make_backbone, make_neck, make_meta_arch 4 | from . import backbones # backbones 5 | from . import necks # necks 6 | from . 
import meta_archs # full models 7 | from .i3d import I3D 8 | __all__ = ['MaskedConv1D','MaskedMHCA', 'MaskedMHA', 'LayerNorm', 9 | 'TransformerBlock', 'ConvBlock', 'Scale', 'AffineDropPath', 10 | 'make_backbone', 'make_neck', 'make_meta_arch'] -------------------------------------------------------------------------------- /libs/modeling/models.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # backbone (e.g., conv / transformer) 4 | backbones = {} 5 | def register_backbone(name): 6 | def decorator(cls): 7 | backbones[name] = cls 8 | return cls 9 | return decorator 10 | 11 | # neck (e.g., FPN) 12 | necks = {} 13 | def register_neck(name): 14 | def decorator(cls): 15 | necks[name] = cls 16 | return cls 17 | return decorator 18 | 19 | # meta arch (the actual implementation of each model) 20 | meta_archs = {} 21 | def register_meta_arch(name): 22 | def decorator(cls): 23 | meta_archs[name] = cls 24 | return cls 25 | return decorator 26 | 27 | # builder functions 28 | def make_backbone(name, **kwargs): 29 | backbone = backbones[name](**kwargs) 30 | return backbone 31 | 32 | def make_neck(name, **kwargs): 33 | neck = necks[name](**kwargs) 34 | return neck 35 | 36 | def make_meta_arch(name, **kwargs): 37 | meta_arch = meta_archs[name](**kwargs) 38 | return meta_arch 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 University of Wisconsin-Madison 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /libs/datasets/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from .data_utils import trivial_batch_collator, worker_init_reset_seed 4 | 5 | datasets = {} 6 | def register_dataset(name): 7 | def decorator(cls): 8 | datasets[name] = cls 9 | return cls 10 | return decorator 11 | 12 | def make_dataset(name, is_training, split, **kwargs): 13 | """ 14 | A simple dataset builder 15 | """ 16 | dataset = datasets[name](is_training, split, **kwargs) 17 | return dataset 18 | 19 | def make_data_loader(dataset, is_training, generator, train_batch_size, test_batch_size, num_workers): 20 | """ 21 | A simple dataloder builder 22 | """ 23 | loader = torch.utils.data.DataLoader( 24 | dataset, 25 | batch_size=train_batch_size if is_training else test_batch_size, 26 | num_workers=num_workers, 27 | collate_fn=trivial_batch_collator, 28 | worker_init_fn=(worker_init_reset_seed if is_training else None), 29 | shuffle=is_training, 30 | drop_last=is_training, 31 | generator=generator, 32 | persistent_workers=True, 33 | pin_memory=True 34 | ) 35 | return loader 36 | -------------------------------------------------------------------------------- /libs/utils/postprocessing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import time 4 | import json 5 | import pickle 6 | from typing import Dict 7 | 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | @torch.no_grad() 14 | def update_metric_dict_with_model_output(metric_dict, model_output, gt_scores, difficulties, is_val, cfg=None): 15 | # B, samples, 1 16 | all_sample_outputs = model_output["all_sample_outputs"].detach() 17 | 18 | with_dd = cfg["dataset"]["with_dd"] 19 | 20 | batched_pred_quality_scores = get_pred_scores(all_sample_outputs, cfg).detach().cpu().numpy() 21 | batched_gt_scores = gt_scores.detach().cpu().numpy() 22 | 23 | if with_dd: 24 | batched_pred_quality_scores = batched_pred_quality_scores * difficulties.detach().cpu().numpy() 25 | 26 | metric_dict.update("pred_scores", batched_pred_quality_scores) 27 | metric_dict.update("gt_scores", batched_gt_scores) 28 | 29 | if is_val: 30 | if "global_sqrt_var_emb" in model_output.keys(): 31 | metric_dict.update("global_sqrt_var_emb", model_output["global_sqrt_var_emb"].detach().cpu().numpy()) 32 | return 33 | 34 | 35 | def get_pred_scores(all_sample_outputs, cfg): 36 | #B, samples, 1: all_sample_outputs.size() 37 | 38 | if cfg["dataset_name"] == "jigsaws": 39 | batched_pred_quality_scores = all_sample_outputs.mean(1).squeeze() 40 | if cfg["dataset"]["six_item_score_scaling"]: 41 | batched_pred_quality_scores = batched_pred_quality_scores * 6.0 42 | else: 43 | batched_pred_quality_scores = all_sample_outputs.mean(1).squeeze() 44 | if cfg["dataset"]["three_judge_score_scaling"]: 45 | batched_pred_quality_scores = batched_pred_quality_scores * 3.0 46 | 47 | return batched_pred_quality_scores 48 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RICA2: Rubric-Informed, Calibrated Assessment of Actions (ECCV 2024) 2 | 3 | [![Project Page](https://img.shields.io/badge/Project-Page-blue.svg)](https://abrarmajeedi.github.io/rica2_aqa/) 4 | 
[![arXiv](https://img.shields.io/badge/arXiv-2312.04364-b31b1b.svg)](https://arxiv.org/abs/2408.02138) 5 | 6 | 7 | 8 | Check out the new Medium post on RICA2! [![Medium](https://img.shields.io/badge/Medium-%23000000.svg?logo=medium&logoColor=white)](https://namburisrinath.medium.com/rica%C2%B2-rubric-informed-calibrated-assessment-of-actions-92b5715a9163) 9 | 10 | 11 | 12 | ## Abstract 13 | 14 | ![Teaser figure](assets/teaser_fig.png) 15 | 16 | The ability to quantify how well an action is carried out, also known as action quality assessment (AQA), has attracted recent interest in the vision community. Unfortunately, prior methods often ignore the score rubric used by human experts and fall short of quantifying the uncertainty of the model prediction. To bridge the gap, we present RICA^2 - a deep probabilistic model that integrates score rubric and accounts for prediction uncertainty for AQA. Central to our method lies in stochastic embeddings of action steps, defined on a graph structure that encodes the score rubric. The embeddings spread probabilistic density in the latent space and allow our method to represent model uncertainty. The graph encodes the scoring criteria, based on which the quality scores can be decoded. We demonstrate that our method establishes new state of the art on public benchmarks, including FineDiving, MTL-AQA, and JIGSAWS, with superior performance in score prediction and uncertainty calibration 17 | 18 | 19 | ## Method 20 | 21 | ![Main method figure](assets/method_fig.png) 22 | 23 | 24 | 25 | ## Results 26 | 27 | 28 | ![Result figure](assets/results_fig.png) 29 | 30 | 31 | ## Code 32 | Please find the installation instructions in [INSTALL.md](./INSTALL.md) 33 | 34 | Instructions to run the code can be found in [ GETTING_STARTED.md](./GETTING_STARTED.md) 35 | 36 | 37 | ## Citation 38 | If you find our work useful, please consider citing: 39 | 40 | ```bibtex 41 | @article{majeedi24rica2, 42 | title={RICA^2: Rubric-Informed, Calibrated Assessment of Actions}, 43 | author={Majeedi, Abrar and Gajjala, Viswanatha Reddy and GNVV, Satya Sai Srinath Namburi and Li, Yin}, 44 | year={2024}, 45 | booktitle={European conference on computer vision}, 46 | primaryClass={cs.CV} 47 | } 48 | ``` 49 | 50 | 51 | ## Contact 52 | Email: majeedi+at+wisc+dot+edu 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | ckpt/ 3 | feats/ 4 | exp/ 5 | images/reliablity_diagrams/ 6 | libs/utils/dist 7 | libs/utils/nms_1d_cpu.egg-info 8 | **/build 9 | *.out 10 | *pth 11 | 12 | slurm* 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | pip-wheel-metadata/ 36 | share/python-wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .nox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | *.py,cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 106 | __pypackages__/ 107 | 108 | # Celery stuff 109 | celerybeat-schedule 110 | celerybeat.pid 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .venv 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | -------------------------------------------------------------------------------- /libs/modeling/weight_init.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py 2 | import torch 3 | import math 4 | import warnings 5 | 6 | 7 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 8 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 9 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 10 | def norm_cdf(x): 11 | # Computes standard normal cumulative distribution function 12 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 13 | 14 | if (mean < a - 2 * std) or (mean > b + 2 * std): 15 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 16 | "The distribution of values may be incorrect.", 17 | stacklevel=2) 18 | 19 | with torch.no_grad(): 20 | # Values are generated by using a truncated uniform distribution and 21 | # then using the inverse CDF for the normal distribution. 22 | # Get upper and lower cdf values 23 | l = norm_cdf((a - mean) / std) 24 | u = norm_cdf((b - mean) / std) 25 | 26 | # Uniformly fill tensor with values from [l, u], then translate to 27 | # [2l-1, 2u-1]. 
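# Equivalent closed form of the transform applied below: for p ~ Uniform(l, u),
# Phi^{-1}(p) = sqrt(2) * erfinv(2p - 1), so sampling in [2l - 1, 2u - 1], applying
# erfinv_, then mul_(std * math.sqrt(2.)) and add_(mean) gives a normal truncated to [a, b].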
28 | tensor.uniform_(2 * l - 1, 2 * u - 1) 29 | 30 | # Use inverse cdf transform for normal distribution to get truncated 31 | # standard normal 32 | tensor.erfinv_() 33 | 34 | # Transform to proper mean, std 35 | tensor.mul_(std * math.sqrt(2.)) 36 | tensor.add_(mean) 37 | 38 | # Clamp to ensure it's in the proper range 39 | tensor.clamp_(min=a, max=b) 40 | return tensor 41 | 42 | 43 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 44 | # type: (Tensor, float, float, float, float) -> Tensor 45 | r"""Fills the input Tensor with values drawn from a truncated 46 | normal distribution. The values are effectively drawn from the 47 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 48 | with values outside :math:`[a, b]` redrawn until they are within 49 | the bounds. The method used for generating the random values works 50 | best when :math:`a \leq \text{mean} \leq b`. 51 | Args: 52 | tensor: an n-dimensional `torch.Tensor` 53 | mean: the mean of the normal distribution 54 | std: the standard deviation of the normal distribution 55 | a: the minimum cutoff value 56 | b: the maximum cutoff value 57 | Examples: 58 | >>> w = torch.empty(3, 5) 59 | >>> nn.init.trunc_normal_(w) 60 | """ 61 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 62 | -------------------------------------------------------------------------------- /libs/core/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | 4 | DEFAULTS = { 5 | # random seed for reproducibility, a large number is preferred 6 | "init_rand_seed": 1234567891, 7 | # dataset loader, specify the dataset here 8 | "dataset_name": "", 9 | "devices": ['cuda:0'], # default: single gpu 10 | "train_split": ('training', ), 11 | "val_split": ('validation', ), 12 | "model_name": "XXX", 13 | "dataset": { 14 | 15 | }, 16 | "loader": { 17 | "train_batch_size": 8, 18 | "test_batch_size": 8, 19 | "num_workers": 1, 20 | }, 21 | # network architecture 22 | "model": { 23 | "finetune_feat_extractor": False, 24 | "feat_extractor_type": None, 25 | "feat_extractor_weights_path": None, 26 | "backbone_type": 'xxx', 27 | # disable abs position encoding (added to input embedding) 28 | "neck_type": 'xxx', 29 | "decoder_params": { 30 | "decoder_ffn_dim": 2048, 31 | "decoder_activation": 'gelu', 32 | }, 33 | }, 34 | "train_cfg": { 35 | # gradient cliping, not needed for pre-LN transformer 36 | "clip_grad_l2norm": -1, 37 | }, 38 | "test_cfg": { 39 | }, 40 | # optimizer (for training) 41 | "opt": { 42 | # solver 43 | "type": "AdamW", # SGD or AdamW 44 | # solver params 45 | "momentum": 0.9, 46 | "weight_decay": 0.0, 47 | "learning_rate": 1e-3, 48 | # excluding the warmup epochs 49 | "epochs": 30, 50 | # lr scheduler: cosine / multistep 51 | "warmup": True, 52 | "warmup_epochs": 5, 53 | "schedule_type": "cosine", 54 | # in #epochs excluding warmup 55 | "schedule_steps": [], 56 | "schedule_gamma": 0.1, 57 | } 58 | } 59 | 60 | def _merge(src, dst): 61 | for k, v in src.items(): 62 | if k in dst: 63 | if isinstance(v, dict): 64 | _merge(src[k], dst[k]) 65 | else: 66 | dst[k] = v 67 | 68 | def load_default_config(): 69 | config = DEFAULTS 70 | return config 71 | 72 | def _update_config(config): 73 | # fill in derived fields 74 | config["model"]["train_cfg"] = config["train_cfg"] 75 | config["model"]["test_cfg"] = config["test_cfg"] 76 | config["model"]["frames_per_clip"] = config["dataset"]["frames_per_clip"] 77 | config["model"]["max_seq_len"] = config["dataset"]["max_seq_len"] 78 | return config 
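# Usage sketch (illustrative): the training / evaluation entry points call load_config()
# below on a YAML file from ./configs; DEFAULTS is meant to fill in unspecified keys and
# _update_config() mirrors train_cfg / test_cfg plus frames_per_clip and max_seq_len into
# config["model"], e.g.
#   cfg = load_config("configs/mtl_aqa/stoch_mtl_aqa_text_data_query.yaml")
#   model = make_meta_arch(cfg["model_name"], **cfg["model"])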
79 | 80 | def load_config(config_file, defaults=DEFAULTS): 81 | with open(config_file, "r") as fd: 82 | config = yaml.load(fd, Loader=yaml.FullLoader) 83 | _merge(defaults, config) 84 | config = _update_config(config) 85 | return config -------------------------------------------------------------------------------- /GETTING_STARTED.md: -------------------------------------------------------------------------------- 1 | 2 | # Getting Started 3 | 4 | 5 | 6 | ## Requirements 7 | 8 | Please find the installation instructions in [INSTALL.md](./INSTALL.md) 9 | 10 | ## Code 11 | Pull the code using 12 | 13 | git clone https://github.com/abrarmajeedi/rica2_aqa 14 | 15 | and navigate into the code directory 16 | 17 | 18 | cd rica2_aqa 19 | 20 | 21 | ## Data 22 | 23 | You can download the zipped data from the Google drive [link](https://drive.google.com/file/d/1CjYtxnjHZzDkWDYrLMFbph9b-EZ8fdFT/view?usp=sharing). 24 | 25 | 26 | Once downloaded, unzip the archive into ./data into the code directory 27 | 28 | 29 | Make sure the data follows this structure 30 | ```markdown 31 | ├── data 32 | │ 33 | │ ├── finediving 34 | │ │ ├── Annotations 35 | │ │ │ ├── Annotation files (**.pkl) 36 | │ 37 | │ │ ├── FINADiving_MTL_256 38 | │ │ │ ├── Video Frame directories 39 | │ 40 | │ ├── mtl_aqa 41 | │ │ ├── frames_long 42 | │ │ │ ├── Video frame directories 43 | │ 44 | │ │ ├── info 45 | │ │ │ ├── Annotation files (**.pkl) 46 | ``` 47 | 48 | ## Pretrained I3D weights 49 | 50 | You can download the pretrained I3D weights from the Google drive [link](https://drive.google.com/file/d/1vi-C3V_i4Sy_4Y3yJLLeiGRzpz8Evvid/view?usp=sharing). 51 | 52 | 53 | Once downloaded, place the file in `./pre_trained/model_rgb.pth` 54 | 55 | 56 | 57 | ## Running the code 58 | 59 | Use the following commands to run the code 60 | 61 | ### FineDiving 62 | 63 | python -u train.py configs/fine/stoch_fine_diving_text_data_query.yaml 64 | 65 | To run the deterministic RICA2† 66 | 67 | python -u train.py configs/fine/deter_fine_diving_text_data_query.yaml 68 | 69 | ### MTL-AQA 70 | 71 | python -u train.py configs/mtl_aqa/stoch_mtl_diving_text_data_query.yaml 72 | 73 | To run the deterministic RICA2† 74 | 75 | python -u train.py configs/mtl_aqa/deter_mtl_diving_text_data_query.yaml 76 | 77 | 78 | These commands will train the specified models and automatically run the evaluation, generating the evaluation results at the end. 79 | 80 | 81 | ## [BONUS] Tuning Experiment Parameters 82 | 83 | Our code allows easy change of model and experiment parameters: 84 | 85 | ### Modifying hyperparameters 86 | 87 | You can modify different hyperparameters of the models and training by changing values within in the config files in `./configs` 88 | 89 | 90 | ### Generating text embeddings 91 | 92 | RICA2 incorporates the step information in the an action via LLM embedddings extracted from the textual step descriptions. These can be found `./tools`. 93 | 94 | For FineDiving 95 | 96 | python ./tools/finediving/finediving_t5xxl_text_embed_extraction.py 97 | 98 | For MTL-AQA 99 | 100 | python ./tools/mtl_aqa/mtl_t5xxl_text_embed_extraction.py 101 | 102 | 103 | You can easily change the models used for extracting embeddings from the step descriptions by following the user-friendly [HuggingFace documentation](https://huggingface.co/docs). 
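As a minimal sketch (not part of the released scripts), swapping in a different text encoder mostly amounts to changing the `model_id` and the pooling. The example below assumes the encoder-only model `intfloat/e5-large`, which is already noted as an alternative in the extraction scripts, and uses placeholder step descriptions:

```python
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel

# Stand-ins for the per-dataset step descriptions defined in ./tools
descriptions = [
    "Forward: athletes rotate forward away from the diving board.",
    "Entry: the final entry into the water at the end of the dive.",
]

model_id = "intfloat/e5-large"  # assumed alternative, encoder-only model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)

batch = tokenizer(descriptions, max_length=512, padding=True,
                  truncation=True, return_tensors="pt")
with torch.no_grad():
    out = model(**batch)  # no decoder_input_ids needed for encoder-only models
    mask = batch["attention_mask"].unsqueeze(-1).bool()
    # Mean-pool over valid tokens, mirroring average_pool() in the provided scripts
    emb = (out.last_hidden_state.masked_fill(~mask, 0.0).sum(dim=1)
           / batch["attention_mask"].sum(dim=1, keepdim=True))

np.save("assets/custom_text_embeddings.npy", emb.cpu().numpy())
```

If you switch encoders, also update `text_embeddings_path` and `text_queries_emb_dim` in the config (4096 for T5-XXL; e.g. 1024 for `e5-large`) so that the query projection matches the new embedding size.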
104 | 105 | 106 | ## Contact 107 | 108 | Email: majeedi+at+wisc+dot+edu -------------------------------------------------------------------------------- /configs/mtl_aqa/deter_mtl_aqa_text_data_query.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: mtl_aqa 2 | model_name : 'aqa-model' 3 | train_split: 'train' 4 | val_split: 'test' 5 | devices: [0] 6 | dataset: { 7 | data_root : ./data/mtl_aqa, 8 | video_dir: frames_long, 9 | label_file : augmented_final_annotations_dict.pkl, 10 | train_datafile : train_split_0.pkl , 11 | test_datafile : test_split_0.pkl , 12 | use_feats : False, 13 | feat_dir : "", 14 | frames_per_clip : 103, 15 | window_size : 8, 16 | stride : 5, 17 | max_seq_len: 20, 18 | input_frame_size : [256,455], 19 | crop_frame_size : 224, 20 | with_dd : True, 21 | three_judge_score_scaling : False, 22 | } 23 | model: { 24 | backbone_type : 'convEncoder', 25 | input_feat_dim: 1024, #feats or from the feat extractor 26 | embed_dim: 512, 27 | conv_dropout: 0.1, #drop out for initial conv layers 28 | conv_kernel_size: 3, 29 | neck_type: 'decoder-neck', 30 | num_layers : { 31 | n_conv_layers: 2, 32 | n_encoder_layers: 2, 33 | n_decoder_layers: 2, 34 | n_mlp_head_layers : 2, #for now this is hardcoded to 3 layers 35 | }, 36 | encoder_params: { 37 | n_encoder_heads: 8, 38 | attn_pdrop: 0.1, 39 | proj_pdrop: 0.1, 40 | path_pdrop: 0.1, 41 | use_abs_pe: False, 42 | }, 43 | decoder_params: { 44 | n_decoder_heads: 8, 45 | stride: 1, 46 | attn_pdrop: 0.1, 47 | proj_pdrop: 0.1, 48 | path_pdrop: 0.1, 49 | xattn_mode: 'affine', 50 | with_ln: True, 51 | query_config: { 52 | text_embeddings_path: assets/MTL_t5_xxl_text_embeddings.npy, #relative to the project root 53 | freeze_text_embeddings: True, 54 | text_queries_emb_dim: 4096, 55 | }, 56 | }, 57 | use_stochastic_embd: False, 58 | num_random_samples: 1, 59 | num_phases: 24, # =num of total subaction, check dataloader for more details 60 | score_bins: 1, 61 | } 62 | opt: { 63 | learning_rate: 0.0005, #this will get scaled by batchsize 64 | warmup_epochs: 3, 65 | schedule_type: "no_decay", 66 | epochs: 350, 67 | weight_decay: 0.0, 68 | feature_extractor_factor: 0.3, #lr for the feature extractor will be scaled by this factor 69 | neck_lr_factor: 1.0, #lr for the neck will be scaled by this factor 70 | } 71 | loader: { 72 | #for train and test batch_size provide how many samples can fit on one 8GB GPU e.g. 
for MTL this is 1 and 2, for finediving, this is 2 and 4 73 | train_batch_size : 8, #this can get overwritten by arg 74 | test_batch_size: 16, #this can get overwritten by arg 75 | num_workers: 4, #this will also be changed dynamically based on cpu count, this is min number, max is 20 76 | } 77 | train_cfg: { 78 | dataset_name: mtl_aqa, 79 | clip_grad_l2norm: 1.0, 80 | accumulation_steps: 1, #how many steps/batches to accumulate gradients before taking an optimizer step 81 | loss_weights: { 82 | loss: mse, 83 | quality_score: 1.0, 84 | phase_vib : 0.0, # 1e-3 85 | scale_vib: False, 86 | ranking: 0.05, # use > 0.0 to enable 87 | sparsity: 0.05 # use > 0.0 to enable 88 | }, 89 | } 90 | test_cfg: { 91 | } 92 | output_folder: ./ckpt/ 93 | -------------------------------------------------------------------------------- /configs/mtl_aqa/stoch_mtl_aqa_text_data_query.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: mtl_aqa 2 | model_name : 'aqa-model' 3 | train_split: 'train' 4 | val_split: 'test' 5 | devices: [0] 6 | dataset: { 7 | data_root : ./data/mtl_aqa, 8 | video_dir: frames_long, 9 | label_file : augmented_final_annotations_dict.pkl, 10 | train_datafile : train_split_0.pkl , 11 | test_datafile : test_split_0.pkl , 12 | use_feats : False, 13 | feat_dir : "", 14 | frames_per_clip : 103, 15 | window_size : 8, 16 | stride : 5, 17 | max_seq_len: 20, 18 | input_frame_size : [256,455], 19 | crop_frame_size : 224, 20 | with_dd : True, 21 | three_judge_score_scaling : False, 22 | } 23 | model: { 24 | backbone_type : 'convEncoder', 25 | input_feat_dim: 1024, #feats or from the feat extractor 26 | embed_dim: 512, 27 | conv_dropout: 0.1, #drop out for initial conv layers 28 | conv_kernel_size: 3, 29 | neck_type: 'decoder-neck', 30 | num_layers : { 31 | n_conv_layers: 2, 32 | n_encoder_layers: 2, 33 | n_decoder_layers: 2, 34 | n_mlp_head_layers : 2, #for now this is hardcoded to 3 layers 35 | }, 36 | encoder_params: { 37 | n_encoder_heads: 8, 38 | attn_pdrop: 0.1, 39 | proj_pdrop: 0.1, 40 | path_pdrop: 0.1, 41 | use_abs_pe: False, 42 | }, 43 | decoder_params: { 44 | n_decoder_heads: 8, 45 | stride: 1, 46 | attn_pdrop: 0.1, 47 | proj_pdrop: 0.1, 48 | path_pdrop: 0.1, 49 | xattn_mode: 'affine', 50 | with_ln: True, 51 | query_config: { 52 | text_embeddings_path: assets/MTL_t5_xxl_text_embeddings.npy, #relative to the project root 53 | freeze_text_embeddings: True, 54 | text_queries_emb_dim: 4096, 55 | }, 56 | }, 57 | use_stochastic_embd: True, 58 | num_random_samples: 20, 59 | num_phases: 24, # =num of total subaction, check dataloader for more details 60 | score_bins: 1, 61 | } 62 | opt: { 63 | learning_rate: 0.0005, #this will get scaled by batchsize 64 | warmup_epochs: 3, 65 | schedule_type: "no_decay", 66 | epochs: 350, 67 | weight_decay: 0.0, 68 | feature_extractor_factor: 0.3, #lr for the feature extractor will be scaled by this factor 69 | neck_lr_factor: 1.0, #lr for the neck will be scaled by this factor 70 | } 71 | loader: { 72 | #for train and test batch_size provide how many samples can fit on one 8GB GPU e.g. 
for MTL this is 1 and 2, for finediving, this is 2 and 4 73 | train_batch_size : 8, #this can get overwritten by arg 74 | test_batch_size: 16, #this can get overwritten by arg 75 | num_workers: 4, #this will also be changed dynamically based on cpu count, this is min number, max is 20 76 | } 77 | train_cfg: { 78 | dataset_name: mtl_aqa, 79 | clip_grad_l2norm: 1.0, 80 | accumulation_steps: 1, #how many steps/batches to accumulate gradients before taking an optimizer step 81 | loss_weights: { 82 | loss: mse, 83 | quality_score: 1.0, 84 | phase_vib : 0.00001, # 1e-3 85 | scale_vib: True, 86 | ranking: 0.05, # use > 0.0 to enable 87 | sparsity: 0.05 # use > 0.0 to enable 88 | }, 89 | } 90 | test_cfg: { 91 | } 92 | output_folder: ./ckpt/ 93 | -------------------------------------------------------------------------------- /configs/fine/deter_fine_diving_text_data_query.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: finediving 2 | model_name : 'aqa-model' 3 | train_split: 'train' 4 | val_split: 'test' 5 | devices: [0] 6 | dataset: { 7 | data_root : ./data/finediving, 8 | video_dir: FINADiving_MTL_256s, 9 | label_file : FineDiving_fine-grained_annotation.pkl, 10 | coarse_label_file : FineDiving_coarse_annotation.pkl, 11 | train_datafile : train_split.pkl , 12 | test_datafile : test_split.pkl , 13 | use_feats : False, 14 | feat_dir : "", 15 | frames_per_clip : 96, 16 | window_size : 16, 17 | stride : 10, 18 | max_seq_len: 9, 19 | input_frame_size : [112,200], 20 | crop_frame_size : 112, 21 | with_dd : True, 22 | three_judge_score_scaling : False, 23 | } 24 | model: { 25 | backbone_type : 'convEncoder', 26 | input_feat_dim: 1024, #feats or from the feat extractor 27 | embed_dim: 512, 28 | conv_dropout: 0.1, #drop out for initial conv layers 29 | conv_kernel_size: 3, 30 | neck_type: 'decoder-neck', 31 | num_layers : { 32 | n_conv_layers: 2, 33 | n_encoder_layers: 2, 34 | n_decoder_layers: 2, 35 | n_mlp_head_layers : 2, #for now this is hardcoded to 3 layers 36 | }, 37 | encoder_params: { 38 | n_encoder_heads: 8, 39 | attn_pdrop: 0.1, 40 | proj_pdrop: 0.1, 41 | path_pdrop: 0.1, 42 | use_abs_pe: False, 43 | }, 44 | decoder_params: { 45 | n_decoder_heads: 8, 46 | stride: 1, 47 | attn_pdrop: 0.1, 48 | proj_pdrop: 0.1, 49 | path_pdrop: 0.1, 50 | xattn_mode: 'affine', 51 | with_ln: True, 52 | query_config: { 53 | text_embeddings_path: assets/FineDiving_t5_xxl_text_embeddings.npy, #relative to the project root 54 | freeze_text_embeddings: True, 55 | text_queries_emb_dim: 4096, 56 | }, 57 | }, 58 | use_stochastic_embd: False, 59 | num_random_samples: 1, 60 | num_phases: 29, # =num of total subaction, check dataloader for more details 61 | score_bins: 1, 62 | } 63 | opt: { 64 | learning_rate: 0.0005, #this will get scaled by batchsize 65 | warmup_epochs: 3, 66 | schedule_type: "no_decay", 67 | epochs: 350, 68 | weight_decay: 0.0, 69 | feature_extractor_factor: 0.3, #lr for the feature extractor will be scaled by this factor 70 | neck_lr_factor: 1.0, #lr for the neck will be scaled by this factor 71 | } 72 | loader: { 73 | #for train and test batch_size provide how many samples can fit on one 8GB GPU e.g. 
for MTL this is 1 and 2, for finediving, this is 2 and 4 74 | train_batch_size : 8, #this can get overwritten by arg 75 | test_batch_size: 16, #this can get overwritten by arg 76 | num_workers: 4, #this will also be changed dynamically based on cpu count, this is min number, max is 20 77 | } 78 | train_cfg: { 79 | dataset_name: finediving, 80 | clip_grad_l2norm: 1.0, 81 | accumulation_steps: 1, #how many steps/batches to accumulate gradients before taking an optimizer step 82 | loss_weights: { 83 | loss: mse, 84 | quality_score: 1.0, 85 | phase_vib : 0.0, # 1e-3 86 | scale_vib: False, 87 | ranking: 0.05, # use > 0.0 to enable 88 | sparsity: 0.05 # use > 0.0 to enable 89 | }, 90 | } 91 | test_cfg: { 92 | } 93 | output_folder: ./ckpt/ 94 | 95 | 96 | -------------------------------------------------------------------------------- /configs/fine/stoch_fine_diving_text_data_query.yaml: -------------------------------------------------------------------------------- 1 | dataset_name: finediving 2 | model_name : 'aqa-model' 3 | train_split: 'train' 4 | val_split: 'test' 5 | devices: [0] 6 | dataset: { 7 | data_root : ./data/finediving, 8 | video_dir: FINADiving_MTL_256s, 9 | label_file : FineDiving_fine-grained_annotation.pkl, 10 | coarse_label_file : FineDiving_coarse_annotation.pkl, 11 | train_datafile : train_split.pkl , 12 | test_datafile : test_split.pkl , 13 | use_feats : False, 14 | feat_dir : "", 15 | frames_per_clip : 96, 16 | window_size : 16, 17 | stride : 10, 18 | max_seq_len: 9, 19 | input_frame_size : [112,200], 20 | crop_frame_size : 112, 21 | with_dd : True, 22 | three_judge_score_scaling : False, 23 | } 24 | model: { 25 | backbone_type : 'convEncoder', 26 | input_feat_dim: 1024, #feats or from the feat extractor 27 | embed_dim: 512, 28 | conv_dropout: 0.1, #drop out for initial conv layers 29 | conv_kernel_size: 3, 30 | neck_type: 'decoder-neck', 31 | num_layers : { 32 | n_conv_layers: 2, 33 | n_encoder_layers: 2, 34 | n_decoder_layers: 2, 35 | n_mlp_head_layers : 2, #for now this is hardcoded to 3 layers 36 | }, 37 | encoder_params: { 38 | n_encoder_heads: 8, 39 | attn_pdrop: 0.1, 40 | proj_pdrop: 0.1, 41 | path_pdrop: 0.1, 42 | use_abs_pe: False, 43 | }, 44 | decoder_params: { 45 | n_decoder_heads: 8, 46 | stride: 1, 47 | attn_pdrop: 0.1, 48 | proj_pdrop: 0.1, 49 | path_pdrop: 0.1, 50 | xattn_mode: 'affine', 51 | with_ln: True, 52 | query_config: { 53 | text_embeddings_path: assets/FineDiving_t5_xxl_text_embeddings.npy, #relative to the project root 54 | freeze_text_embeddings: True, 55 | text_queries_emb_dim: 4096, 56 | }, 57 | }, 58 | use_stochastic_embd: True, 59 | num_random_samples: 20, 60 | num_phases: 29, # =num of total subaction, check dataloader for more details 61 | score_bins: 1, 62 | } 63 | opt: { 64 | learning_rate: 0.0005, #this will get scaled by batchsize 65 | warmup_epochs: 3, 66 | schedule_type: "no_decay", 67 | epochs: 350, 68 | weight_decay: 0.0, 69 | feature_extractor_factor: 0.3, #lr for the feature extractor will be scaled by this factor 70 | neck_lr_factor: 1.0, #lr for the neck will be scaled by this factor 71 | } 72 | loader: { 73 | #for train and test batch_size provide how many samples can fit on one 8GB GPU e.g. 
for MTL this is 1 and 2, for finediving, this is 2 and 4 74 | train_batch_size : 8, #this can get overwritten by arg 75 | test_batch_size: 16, #this can get overwritten by arg 76 | num_workers: 4, #this will also be changed dynamically based on cpu count, this is min number, max is 20 77 | } 78 | train_cfg: { 79 | dataset_name: finediving, 80 | clip_grad_l2norm: 1.0, 81 | accumulation_steps: 1, #how many steps/batches to accumulate gradients before taking an optimizer step 82 | loss_weights: { 83 | loss: mse, 84 | quality_score: 1.0, 85 | phase_vib : 0.00001, # 1e-3 86 | scale_vib: True, 87 | ranking: 0.05, # use > 0.0 to enable 88 | sparsity: 0.05 # use > 0.0 to enable 89 | }, 90 | } 91 | test_cfg: { 92 | } 93 | output_folder: ./ckpt/ 94 | 95 | 96 | -------------------------------------------------------------------------------- /libs/utils/preprocessing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | @torch.no_grad() 6 | def multi_label_multi_class_one_hot_encode(labels, num_phases): 7 | """ 8 | Convert a list of labels to a binary one hot encoding 9 | This function is for one hot encoding multi-label multi-class labels 10 | """ 11 | labels = torch.LongTensor(labels) 12 | 13 | y_onehot = nn.functional.one_hot(labels, num_classes=num_phases) 14 | 15 | y_onehot = y_onehot.sum(dim=0).float() 16 | 17 | return y_onehot 18 | 19 | @torch.no_grad() 20 | def preprocessing(video_list, 21 | feat_extractor_type, 22 | max_seq_len, num_phases, 23 | padding_val=0.0 24 | ): 25 | """ 26 | Generate batched features and masks from a list of dict items 27 | """ 28 | if feat_extractor_type is not None: 29 | #each item in batch has frames as a list of tensors of shape Nwin x T x C x H x W 30 | frames = [x['window_frames'] for x in video_list] 31 | 32 | #B, Nwin x T x C x H x W -> B x Nwin x T x C x H x W 33 | batched_inputs = torch.stack(frames).cuda() 34 | 35 | batched_masks = torch.ones((len(video_list), 1, frames[0].shape[0])).cuda() 36 | 37 | else: 38 | feats = [x['feats'] for x in video_list] 39 | feats_lens = torch.as_tensor([feat.shape[-1] for feat in feats]) 40 | max_len = feats_lens.max(0).values.item() 41 | 42 | 43 | assert max_len <= max_seq_len, f"Input length must be smaller than max_seq_len during training, max len = {max_len}, max_seq_len = {max_seq_len}" 44 | # set max_len to self.max_seq_len 45 | max_len = max_seq_len 46 | # batch input shape B, C, T 47 | batch_shape = [len(feats), feats[0].shape[0], max_len] 48 | batched_inputs = feats[0].new_full(batch_shape, padding_val) 49 | for feat, pad_feat in zip(feats, batched_inputs): 50 | pad_feat[..., :feat.shape[-1]].copy_(feat) 51 | 52 | # generate the mask 53 | batched_masks = torch.arange(max_len)[None, :] < feats_lens[:, None] 54 | 55 | # push to device 56 | batched_inputs = batched_inputs.cuda() 57 | batched_masks = batched_masks.unsqueeze(1).cuda() 58 | 59 | actions = [multi_label_multi_class_one_hot_encode(x['actions_present'], num_phases) for x in video_list] 60 | 61 | gt_actions = torch.stack(actions, dim=0).cuda() 62 | 63 | gt_scores = torch.tensor([x['gt_score'] for x in video_list]).cuda() 64 | 65 | video_ids = [x['video_id'] for x in video_list] 66 | 67 | 68 | if 'target' in video_list[0]: 69 | target = torch.tensor([x['target'] for x in video_list]).cuda() 70 | else: 71 | target = [] 72 | 73 | if "difficulty" in video_list[0]: 74 | difficulties = torch.tensor([x['difficulty'] for x in video_list]) 75 | else: 76 | difficulties = [] 77 | 78 | 
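# Shapes at this point, for reference: batched_inputs is either raw frame windows
# (B x Nwin x T x C x H x W) or padded features (B x C x max_seq_len); batched_masks is
# B x 1 x T; gt_actions is a B x num_phases multi-hot matrix; gt_scores is B; target and
# difficulties remain empty lists when the dataset does not provide them.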
ret_dict = { 79 | 'video_ids': video_ids, 80 | 'batched_inputs': batched_inputs, 81 | 'batched_masks': batched_masks, 82 | 'gt_actions': gt_actions, 83 | 'gt_scores': gt_scores, 84 | 'target': target, 85 | 'difficulties': difficulties 86 | } 87 | 88 | return ret_dict 89 | 90 | 91 | -------------------------------------------------------------------------------- /libs/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import stats 3 | 4 | from scipy.stats import hmean, kendalltau 5 | 6 | 7 | np.seterr(divide='ignore', invalid='ignore') 8 | np.set_printoptions(precision=2) 9 | 10 | 11 | 12 | class MetricsDict: 13 | """A class to store all the data for computing metrics. 14 | """ 15 | def __init__(self): 16 | self.metric_dict = { 17 | "pred_scores": [], 18 | "gt_scores": [], 19 | "gt_actions": [], 20 | "all_judge_scores": [], 21 | "difficulty": [], 22 | } 23 | 24 | def update(self, key, value): 25 | if key not in self.metric_dict.keys(): 26 | self.metric_dict[key] = [] 27 | 28 | self.metric_dict[key].append(value) 29 | 30 | def get_metric_dict(self): 31 | return self.metric_dict 32 | 33 | 34 | def evaluate(metric_dict, 35 | is_train, 36 | dataset_name, 37 | curr_epoch =-1 38 | ): 39 | metric_dict = metric_dict.get_metric_dict() 40 | 41 | if not is_train: 42 | #if var is available use that for uncertainty else use logits 43 | plot_rejection_curve(metric_dict) 44 | 45 | pred_scores = metric_dict["pred_scores"] 46 | true_scores = metric_dict["gt_scores"] 47 | 48 | pred_scores = np.concatenate([np.atleast_1d(x) for x in pred_scores]) 49 | true_scores = np.concatenate([np.atleast_1d(x) for x in true_scores]) 50 | 51 | min_true_score = true_scores.min() 52 | max_true_score = true_scores.max() 53 | 54 | rho, p = stats.spearmanr(pred_scores, true_scores) 55 | L2 = np.power(pred_scores - true_scores, 2).sum() / true_scores.shape[0] 56 | RL2 = np.power((pred_scores - true_scores) / (max_true_score - min_true_score), 2).sum() / true_scores.shape[0] 57 | 58 | if dataset_name == "cat101": 59 | accuracy = (np.round(pred_scores) == true_scores).mean() 60 | print(f"Epoch {curr_epoch} Accuracy: {accuracy}") 61 | 62 | return rho*100, L2, RL2*100 63 | 64 | def plot_rejection_curve(metric_dict): 65 | if "global_sqrt_var_emb" in metric_dict.keys(): 66 | global_sqrt_var_emb = metric_dict["global_sqrt_var_emb"] 67 | global_sqrt_var_emb = np.concatenate(global_sqrt_var_emb) 68 | mean_sqrt_var = hmean(global_sqrt_var_emb, axis = 1) 69 | uncertainties = mean_sqrt_var 70 | else: 71 | return 72 | 73 | #matplotlib.use('module://drawilleplot') 74 | pred_scores = metric_dict["pred_scores"] 75 | true_scores = metric_dict["gt_scores"] 76 | 77 | pred_scores = np.concatenate([np.atleast_1d(x) for x in pred_scores]) 78 | true_scores = np.concatenate([np.atleast_1d(x) for x in true_scores]) 79 | 80 | y_mae, bins = rejection_plot(pred_scores, true_scores, uncertainties) 81 | 82 | print("Calibration Kendall Tau (MAE): {:0.4f}".format(kendalltau(y_mae, bins)[0])) 83 | 84 | 85 | 86 | def rejection_plot(preds, gts, uncertainty, num_bins=11): 87 | all_mae = [] 88 | bins = np.linspace(0, 100, num_bins) 89 | #[ 0. 11.11 22.22 33.33 44.44 55.56 66.67 77.78 88.89 100. 
] 90 | 91 | conf_sort_idx = np.argsort(uncertainty) 92 | uncertainty = uncertainty[conf_sort_idx] 93 | 94 | preds = preds[conf_sort_idx] 95 | gts = gts[conf_sort_idx] 96 | 97 | for bin_num in range(bins.shape[0]-1): 98 | bin_low = np.percentile(uncertainty, bins[bin_num]) 99 | 100 | if bin_num != bins.shape[0]-2: 101 | bin_high = np.percentile(uncertainty, bins[bin_num+1]) 102 | else: 103 | bin_high = 10000000000000000 104 | bin_idxs = np.where((uncertainty>=bin_low) & (uncertainty= t1 and t < t2 77 | 78 | def exceeds_range(frame): 79 | return frame.pts * tb >= t2 80 | 81 | for frame in buffer: 82 | if is_in_range(frame): 83 | ret.append(frame) 84 | 85 | prev_pts = None 86 | 87 | # This try except block is to avoid the EOF error that arrives because t2 exceeds the frame range 88 | try: 89 | for frame in container.decode(video=0): 90 | if frame.pts is None: 91 | raise AssertionError("frame is None") 92 | if prev_pts is not None and frame.pts < prev_pts: 93 | raise AssertionError("failed assumption pts in order: ") 94 | if not isinstance(frame, av.VideoFrame): 95 | raise AssertionError("other packets not supported") 96 | prev_pts = frame.pts 97 | 98 | buffer.append(frame) 99 | if len(buffer) > max_buffer_size: 100 | del buffer[0] 101 | 102 | if is_in_range(frame): 103 | ret.append(frame) 104 | elif exceeds_range(frame): 105 | break 106 | except: 107 | pass 108 | 109 | pts_in_ret = [frame.pts for frame in ret] 110 | if not (np.diff(pts_in_ret) > 0).all(): 111 | raise AssertionError("not increasing sequence of frames") 112 | return ret 113 | 114 | 115 | @property 116 | def duration(self) -> float: 117 | vstream = self.vid._container.streams.video[0] 118 | return vstream.duration * vstream.time_base 119 | -------------------------------------------------------------------------------- /libs/modeling/necks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | 6 | from .models import register_neck 7 | from .blocks import TransformerDecoder, LayerNorm 8 | 9 | @register_neck("decoder-neck") 10 | class Decoder(nn.Module): 11 | def __init__(self, d_model=512, 12 | n_heads=8, 13 | stride=1, 14 | num_decoder_layers=6, 15 | attn_pdrop=0.0, 16 | proj_pdrop=0.0, 17 | path_pdrop=0.0, 18 | num_phases=-1, 19 | xattn_mode='affine', 20 | with_ln=True, 21 | use_rel_pe=False, 22 | query_config=None): 23 | super().__init__() 24 | 25 | self.layers = nn.ModuleList() 26 | for _ in range(num_decoder_layers): 27 | self.layers.append( 28 | TransformerDecoder( 29 | embd_dim = d_model, 30 | kv_dim = d_model, 31 | stride=stride, 32 | n_heads=n_heads, 33 | attn_pdrop=attn_pdrop, 34 | proj_pdrop=proj_pdrop, 35 | path_pdrop=path_pdrop, 36 | xattn_mode=xattn_mode, 37 | use_rel_pe=use_rel_pe 38 | ) 39 | ) 40 | 41 | self.ln_out = LayerNorm(d_model) if with_ln else None 42 | 43 | assert num_phases > 0, "Number of phases must be > 0" 44 | 45 | self.num_phases = num_phases 46 | 47 | self._reset_parameters() 48 | text_queries_emb_dim = query_config['text_queries_emb_dim'] 49 | 50 | self.d_model = d_model 51 | self.n_heads = n_heads 52 | self.query_config = query_config 53 | 54 | text_embd = torch.from_numpy(np.load(query_config['text_embeddings_path'])) 55 | self.text_queries = nn.Embedding.from_pretrained(text_embd.squeeze(), 56 | freeze=query_config['freeze_text_embeddings']) 57 | self.text_queries_projection = nn.Linear( 58 | in_features=text_queries_emb_dim, out_features=d_model, 
bias=True) 59 | 60 | def _get_queries(self): 61 | # num phases x C 62 | text_queries = self.text_queries.weight 63 | 64 | # num phases x C -> num phases x d_model 65 | queries = self.text_queries_projection(text_queries) 66 | 67 | return queries 68 | 69 | def _reset_parameters(self): 70 | for p in self.parameters(): 71 | if p.dim() > 1: 72 | nn.init.xavier_uniform_(p) 73 | 74 | def _forward(self, q, q_mask, kv, kv_mask, kv_size=None, video_ids=None, curr_epoch=None): 75 | layerwise_cross_attn = [] 76 | for layer in self.layers: 77 | q, q_mask, cross_attn = layer(q, q_mask, kv, kv_mask, kv_size, 78 | video_ids=video_ids, curr_epoch=curr_epoch) 79 | layerwise_cross_attn.append(cross_attn) 80 | 81 | q = self.ln_out(q) 82 | 83 | return q, q_mask, layerwise_cross_attn 84 | 85 | 86 | def forward(self, vid_embd, vid_masks, batched_gt_actions, video_ids=None, curr_epoch=None): 87 | 88 | """ 89 | vid_embd: B, C, T 90 | vid_masks: B, T 91 | batched_gt_actions: B, phases 92 | """ 93 | B, C, T = vid_embd.shape 94 | 95 | #phases x d_model 96 | queries = self._get_queries() 97 | 98 | #phases x d_model -> B x phases x d_model 99 | queries = queries.repeat(B, 1, 1) 100 | 101 | #B x phases x d_model -> B x d_model x phases 102 | queries = queries.permute(0, 2, 1) 103 | 104 | #B x 1 x Phases 105 | query_masks = batched_gt_actions.unsqueeze(1).bool().detach() 106 | 107 | out, out_masks, cross_attns = self._forward(queries, query_masks, vid_embd, vid_masks, kv_size=T, 108 | video_ids=video_ids, curr_epoch=curr_epoch) 109 | 110 | #B x d_model x phases -> B x phases x d_model 111 | out = out.permute(0, 2, 1) 112 | 113 | return out, cross_attns 114 | -------------------------------------------------------------------------------- /tools/mtl_aqa/mtl_t5xxl_text_embed_extraction.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List 3 | 4 | import torch 5 | import numpy as np 6 | from transformers import AutoTokenizer, AutoModel 7 | 8 | ##########################MTL######## 9 | @dataclass 10 | class DataClass: 11 | sub_action_descriptions: List[str] = field( 12 | default_factory=lambda: [ 13 | "Armstand: In this position, athletes start by standing on their hands on the edge of the diving board and perform their dive while maintaining this handstand position.", 14 | "Inwards: In this rotation type, athletes perform a forward-facing takeoff and rotate inward toward the diving board as they execute their dive.", 15 | "Reverse: In this rotation type, athletes perform a backward-facing takeoff and rotate backward away from the diving board as they execute their dive.", 16 | "Backward: In this rotation type, athletes perform a backward-facing takeoff and rotate backward toward the diving board as they execute their dive.", 17 | "Forward: In this rotation type, athletes perform a forward-facing takeoff and rotate forward away from the diving board as they execute their dive.", 18 | "Free: In this position, athletes have the freedom to perform any combination of dives from various categories without any restrictions or limitations.", 19 | "Tuck: In this position, athletes bring their knees to their chest and hold onto their shins while maintaining a compact shape throughout their dive.", 20 | "Pike: In this position, athletes maintain a straight body with their legs extended and their toes pointed out while bending at the waist to bring their hands toward their toes.", 21 | "0.5 Somersault: Athletes perform a half 
rotation in the air during their dive.", 22 | "1 Somersault: Athletes perform a full forward or backward rotation in the air during their dive.", 23 | "1.5 Somersault: Athletes perform a full rotation and an additional half rotation in the air during their dive.", 24 | "2 Somersault: Athletes perform two full forward or backward rotations in the air during their dive.", 25 | "2.5 Somersault: Athletes perform two full rotations and an additional half rotation in the air during their dive.", 26 | "3 Somersault: Athletes perform three full forward or backward rotations in the air during their dive.", 27 | "3.5 Somersault: Athletes perform three full rotations and an additional half rotation in the air during their dive.", 28 | "4.5 Somersault: Athletes perform four full rotations and an additional half rotation in the air during their dive.", 29 | "0.5 Twist: Athletes perform a half twist in the air during their dive.", 30 | "1 Twist: Athletes perform one full twist in the air during their dive.", 31 | "1.5 Twist: Athletes perform one and a half twists in the air during their dive.", 32 | "2 Twist: Athletes perform two full twists in the air during their dive.", 33 | "2.5 Twist: Athletes perform two and a half twists in the air during their dive.", 34 | "3 Twist: Athletes perform three full twists in the air during their dive.", 35 | "3.5 Twist: Athletes perform three and a half twists in the air during their dive.", 36 | "Entry: A diving technique involving a entry into the water, typically performed at the end of a dive.", 37 | ] 38 | ) 39 | 40 | class GetTextEmbeddings: 41 | def __init__(self, output_path) -> None: 42 | """ 43 | Args: 44 | output_path (_type_): output_path to save the embeddibngs 45 | """ 46 | self.data = DataClass() 47 | self.output_path = output_path 48 | 49 | def get_huggingface_embeddings(self): 50 | def average_pool( 51 | last_hidden_states: torch.Tensor, attention_mask: torch.Tensor 52 | ) -> torch.Tensor: 53 | last_hidden = last_hidden_states.masked_fill( 54 | ~attention_mask[..., None].bool(), 0.0 55 | ) 56 | return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 57 | 58 | input_texts = [ 59 | "This is a " + desc.replace(": ", " action: ") 60 | for desc in self.data.sub_action_descriptions 61 | ] 62 | 63 | model_id="google/flan-t5-xxl" # "intfloat/e5-large" 64 | tokenizer = AutoTokenizer.from_pretrained(model_id) 65 | model = AutoModel.from_pretrained(model_id) 66 | 67 | # Tokenize the input texts 68 | batch_dict = tokenizer( 69 | input_texts, 70 | max_length=512, 71 | padding=True, 72 | truncation=True, 73 | return_tensors="pt", 74 | ) 75 | 76 | with torch.no_grad(): 77 | outputs= model(input_ids= batch_dict['input_ids'], decoder_input_ids=batch_dict['input_ids']) 78 | embeddings = average_pool( 79 | outputs.last_hidden_state, batch_dict["attention_mask"] 80 | ) 81 | np.save(self.output_path, embeddings.detach().cpu().numpy()) 82 | return embeddings.detach().cpu().numpy() 83 | 84 | 85 | get_text_embeddings = GetTextEmbeddings("MTL_t5_xxl_text_embeddings.npy") 86 | text_embeddings = get_text_embeddings.get_huggingface_embeddings() 87 | print("Text Embed shape: ", text_embeddings.shape) 88 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # python imports 2 | import argparse 3 | import os 4 | from pprint import pprint 5 | import sys, pickle 6 | 7 | # torch imports 8 | import torch 9 | import torch.nn as nn 10 | import 
torch.utils.data 11 | # for visualization 12 | from torch.utils.tensorboard import SummaryWriter 13 | 14 | # our code 15 | from libs.core import load_config 16 | from libs.datasets import make_dataset, make_data_loader 17 | from libs.modeling import make_meta_arch 18 | from libs.utils import (Logger, valid_one_epoch, 19 | fix_random_seed, ModelEma) 20 | 21 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 22 | def override_cfg_params(cfg, args): 23 | if args.gpu is not None: 24 | cfg["devices"] = args.gpu 25 | else: 26 | cfg['devices'] = [i for i in range(torch.cuda.device_count())] 27 | 28 | if args.data_root is not None: 29 | cfg["dataset"]["data_root"] = args.data_root 30 | 31 | if args.test_batch_size > 0: 32 | cfg["loader"]["test_batch_size"] = args.test_batch_size 33 | 34 | if cfg["train_cfg"]["loss_weights"]["loss"] == "corn": 35 | cfg["model"]["score_bins"] -= 1 36 | 37 | if args.cv_fold > -1: 38 | cfg["dataset"]["cross_val_id"] = args.cv_fold 39 | 40 | if args.cv_split_file != "": 41 | cfg["dataset"]["cross_val_split_file"] = args.cv_split_file 42 | 43 | if cfg["dataset"]["use_feats"] == False: 44 | cfg["model"]["finetune_feat_extractor"]= True 45 | cfg["model"]["feat_extractor_type"]= 'i3d' 46 | cfg["model"]["feat_extractor_weights_path"]= './pre_trained/model_rgb.pth' 47 | 48 | if cfg["model"]["use_stochastic_embd"] == False: 49 | cfg["train_cfg"]["loss_weights"]["phase_vib"] = 0.0 50 | 51 | return cfg 52 | 53 | 54 | 55 | def create_train_val_dataloaders(cfg, rng_generator): 56 | train_dataset = make_dataset( 57 | cfg['dataset_name'], 58 | True, 59 | cfg['train_split'], 60 | **cfg['dataset'] 61 | ) 62 | 63 | 64 | val_dataset = make_dataset( 65 | cfg['dataset_name'], 66 | False, 67 | cfg['val_split'], 68 | **cfg['dataset'] 69 | ) 70 | 71 | # data loaders 72 | train_loader = make_data_loader( 73 | train_dataset, True, rng_generator, **cfg['loader']) 74 | 75 | val_loader = make_data_loader( 76 | val_dataset, False, None, **cfg['loader']) 77 | 78 | return (train_loader, val_loader) 79 | 80 | ################################################################################ 81 | def main(args): 82 | """main function that handles training / inference""" 83 | 84 | """1. setup parameters / folders""" 85 | 86 | # parse args 87 | args.start_epoch = 0 88 | 89 | if os.path.isfile(args.config): 90 | cfg = load_config(args.config) 91 | else: 92 | raise ValueError("Config file does not exist.") 93 | 94 | torch.set_warn_always(False) 95 | 96 | # fix the random seeds (this will fix everything) 97 | rng_generator = fix_random_seed(cfg['init_rand_seed'], include_cuda=True) 98 | 99 | cfg = override_cfg_params(cfg, args) 100 | 101 | print("Args") 102 | pprint(vars(args), indent=4, stream=sys.__stdout__,sort_dicts=False) 103 | pprint(cfg, stream=sys.__stdout__,sort_dicts=False) 104 | 105 | """2. create dataset / dataloader""" 106 | train_loader, val_loader = create_train_val_dataloaders(cfg, rng_generator) 107 | 108 | """3. create model, optimizer, and scheduler""" 109 | # model 110 | model = make_meta_arch(cfg['model_name'], **cfg['model']) 111 | 112 | # not ideal for multi GPU training, ok for now 113 | # gpu_ids = ','.join(str(device_id) for device_id in cfg['devices']) 114 | # os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids 115 | # model = nn.DataParallel(model, device_ids=cfg['devices']) 116 | model = nn.DataParallel(model).cuda() 117 | 118 | ckpt_file = args.ckpt 119 | 120 | if not os.path.isfile(ckpt_file): 121 | raise ValueError("CKPT file does not exist!") 122 | 123 | """4. 
load ckpt""" 124 | print("=> loading checkpoint '{}'".format(ckpt_file)) 125 | # load ckpt, reset epoch / best rmse 126 | checkpoint = torch.load( 127 | ckpt_file, 128 | map_location = lambda storage, loc: storage.cuda(cfg['devices'][0]) 129 | ) 130 | model.load_state_dict(checkpoint['state_dict'], strict=True) 131 | del checkpoint 132 | 133 | 134 | """5. validation loop""" 135 | print("\nStart testing model {:s} ...".format(cfg['model_name'])) 136 | 137 | with torch.no_grad(): 138 | curr_srcc, curr_rl2, metric_dict = valid_one_epoch( 139 | val_loader, 140 | model, 141 | -1, 142 | cfg = cfg, 143 | tb_writer=None, 144 | print_freq=args.print_freq, 145 | save_predictions=True 146 | ) 147 | 148 | print("SRCC: {:.4f}, RL2: {:.4f}".format(curr_srcc, curr_rl2)) 149 | 150 | with open(os.path.join(os.path.dirname(ckpt_file), "epoch_{:03d}_srcc_{:.3f}_rl2_{:.3f}_outputs.pkl".format(-1, curr_srcc, curr_rl2)), "wb") as f: 151 | pickle.dump(metric_dict, f) 152 | 153 | 154 | print("All done!") 155 | return 156 | 157 | ################################################################################ 158 | if __name__ == '__main__': 159 | """Entry Point""" 160 | # the arg parser 161 | parser = argparse.ArgumentParser( 162 | description='Train') 163 | parser.add_argument('config', metavar='DIR', 164 | help='path to a config file') 165 | parser.add_argument('-p', '--print-freq', default=10, type=int, 166 | help='print frequency (default: 10 iterations)') 167 | parser.add_argument('--ckpt', default='', type=str, 168 | help='name of exp folder (default: none)') 169 | parser.add_argument('--data_root', type=str, metavar='PATH',) 170 | parser.add_argument('--test_batch_size', default=-1, type=int) 171 | parser.add_argument('--cv_fold', default=-1, type=int) 172 | 173 | parser.add_argument('--cv_split_file', default='', type=str) 174 | parser.add_argument('--gpu', nargs='*') 175 | 176 | args = parser.parse_args() 177 | 178 | main(args) 179 | -------------------------------------------------------------------------------- /libs/modeling/backbones.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | from torch.nn import functional as F 5 | 6 | from .models import register_backbone 7 | from .blocks import (TransformerBlock, MaskedConv1D, 8 | LayerNorm) 9 | from .i3d import I3D 10 | 11 | 12 | @register_backbone("convEncoder") 13 | class ConvEncoderBackbone(nn.Module): 14 | """ 15 | A backbone that with convs and encoder blocks 16 | """ 17 | def __init__( 18 | self, 19 | n_in, # input feature dimension 20 | n_embd, # embedding dimension (after convolution) 21 | n_embd_ks, # conv kernel size of the embedding network 22 | n_conv_layers, 23 | conv_dropout, 24 | conv_ln, # if to use layernorm 25 | n_encoder_layers, 26 | n_enc_head, 27 | attn_pdrop, 28 | proj_pdrop, 29 | path_pdrop, 30 | pos_embd 31 | ): 32 | super().__init__() 33 | self.n_in = n_in 34 | 35 | self.conv_dropout = conv_dropout 36 | self.n_encoder_layers = n_encoder_layers 37 | self.relu = nn.ReLU(inplace=True) 38 | 39 | self.pos_embd = pos_embd 40 | 41 | # embedding network using convs 42 | self.embd = nn.ModuleList() 43 | self.embd_norm = nn.ModuleList() 44 | for idx in range(n_conv_layers): 45 | n_in = n_embd if idx > 0 else n_in 46 | self.embd.append( 47 | MaskedConv1D( 48 | n_in, n_embd, n_embd_ks, 49 | stride=1, padding=n_embd_ks//2, bias=(not conv_ln) 50 | ) 51 | ) 52 | if conv_ln: 53 | self.embd_norm.append(LayerNorm(n_embd)) 54 | else: 55 | 
self.embd_norm.append(nn.Identity()) 56 | 57 | self.encoder_blocks = nn.ModuleList([TransformerBlock(n_embd= n_embd, 58 | n_head = n_enc_head, 59 | n_hidden= n_embd, 60 | n_out = n_embd, 61 | attn_pdrop=attn_pdrop, 62 | proj_pdrop=proj_pdrop, 63 | path_pdrop=path_pdrop, 64 | use_rel_pe = False, #only for local attn, not applicable here 65 | ) for _ in range(n_encoder_layers)]) 66 | 67 | # init weights 68 | self.apply(self.__init_weights__) 69 | 70 | def __init_weights__(self, module): 71 | # set nn.Linear bias term to 0 72 | if isinstance(module, (nn.Linear, nn.Conv1d)): 73 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 74 | if module.bias is not None: 75 | torch.nn.init.constant_(module.bias, 0.) 76 | 77 | def forward(self, x, mask, video_ids = None, curr_epoch = None): 78 | # x: batch size, feature channel, sequence length, 79 | # mask: batch size, 1, sequence length (bool) 80 | B, C, T = x.size() 81 | 82 | # embedding network 83 | for idx in range(len(self.embd)): 84 | x, mask = self.embd[idx](x, mask) 85 | x = self.relu(self.embd_norm[idx](x)) 86 | 87 | if idx != len(self.embd) - 1: 88 | x = nn.Dropout(self.conv_dropout)(x) 89 | 90 | B, C, T = x.size() 91 | 92 | if self.pos_embd is not None: 93 | x = x + self.pos_embd[:, :, :T].cuda() * mask.to(x.dtype) 94 | 95 | 96 | for idx in range(self.n_encoder_layers): 97 | x, mask = self.encoder_blocks[idx](x, mask) 98 | 99 | return x, mask 100 | 101 | 102 | 103 | @register_backbone("i3d") 104 | class I3D_feat_extractor(nn.Module): 105 | def __init__(self, I3D_ckpt_path, finetune): 106 | super(I3D_feat_extractor, self).__init__() 107 | self.i3d = I3D(num_classes=400, modality='rgb', dropout_prob=0.5) 108 | 109 | if I3D_ckpt_path is not None: 110 | print("loading I3D weights from: ", I3D_ckpt_path) 111 | self.i3d.load_state_dict(torch.load(I3D_ckpt_path)) 112 | 113 | self.finetune = finetune 114 | self.se = False 115 | 116 | self.avg_pool = torch.nn.AdaptiveAvgPool3d((1, 1, 1)) 117 | 118 | 119 | def get_feature_dim(self): 120 | return self.i3d.get_logits_dim() 121 | 122 | def forward(self, videos, video_ids = None): 123 | 124 | if videos.dim() == 6: 125 | #B x Nwin x T x C x H x W 126 | B, N_Win, T, C, H, W = videos.size() 127 | 128 | #B x N_Win x T x C x H x W -> B*N_Win x T x C x H x W 129 | videos_reshaped = videos.reshape(B*N_Win, T, C, H, W) 130 | 131 | #B*N_Win, T, C, H, W -> B*N_Win, C, T, H, W 132 | videos_tensor = videos_reshaped.permute(0, 2, 1, 3, 4) 133 | 134 | if videos.dim() == 7: 135 | #B x Nwin x T x C x H x W 136 | B, N_Win, crops, T, C, H, W = videos.size() 137 | 138 | #B x N_Win x T x C x H x W -> B*N_Win*crops x T x C x H x W 139 | videos_reshaped = videos.reshape(B*N_Win*crops, T, C, H, W) 140 | 141 | #B*N_Win*crops, T, C, H, W -> B*N_Win*crops, C, T, H, W 142 | videos_tensor = videos_reshaped.permute(0, 2, 1, 3, 4) 143 | 144 | if not self.finetune: 145 | with torch.no_grad(): 146 | self.i3d.eval() 147 | video_feature = self.i3d(videos_tensor) 148 | else: 149 | video_feature = self.i3d(videos_tensor) 150 | 151 | 152 | #Video -> B*N_Win, C 153 | video_feature = self.avg_pool(video_feature).squeeze() 154 | 155 | 156 | if videos.dim() == 6: 157 | #Split into batches (B x T x C), because N_Win is the final T 158 | batch_feats = video_feature.reshape(B, N_Win, -1) 159 | 160 | if videos.dim() == 7: 161 | batch_feats = video_feature.reshape(B, N_Win, crops, -1) 162 | 163 | batch_feats = batch_feats.mean(axis=2) 164 | 165 | #B x T x C -> B x C x T 166 | batch_feats = batch_feats.permute(0, 2, 1) 167 | 168 | return 
batch_feats 169 | 170 | 171 | -------------------------------------------------------------------------------- /tools/finediving/finediving_t5xxl_text_embed_extraction.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from dataclasses import dataclass, field 3 | from typing import List 4 | 5 | # import open_clip 6 | # from open_clip import tokenizer 7 | import torch 8 | import numpy as np 9 | import torch.nn.functional as F 10 | from transformers import AutoTokenizer, AutoModel 11 | 12 | 13 | """ 14 | Format in the annotation pkl file 15 | {'Forward': 1, 16 | 'Back': 2, 17 | 'Reverse': 3, 18 | 'Inward': 4, 19 | 'Arm.Forward': 5, 20 | 'Arm.Back': 6, 21 | 'Arm.Reverse': 7, 22 | '1 Som.Pike': 12, 23 | '1.5 Soms.Pike': 13, 24 | '2 Soms.Pike': 14, 25 | '2.5 Soms.Pike': 15, 26 | '3 Soms.Pike': 16, 27 | '3.5 Soms.Pike': 17, 28 | '4.5 Soms.Pike': 19, 29 | '1.5 Soms.Tuck': 21, 30 | '2 Soms.Tuck': 22, 31 | '2.5 Soms.Tuck': 23, 32 | '3 Soms.Tuck': 24, 33 | '3.5 Soms.Tuck': 25, 34 | '4.5 Soms.Tuck': 27, 35 | '0.5 Twist': 29, 36 | '1 Twist': 30, 37 | '1.5 Twists': 31, 38 | '2 Twists': 32, 39 | '2.5 Twists': 33, 40 | '3 Twists': 34, 41 | '3.5 Twists': 35, 42 | 'Entry': 36, 43 | '0.5 Som.Pike': 37} 44 | """ 45 | 46 | 47 | @dataclass 48 | class DataClass: 49 | sub_action_descriptions: List[str] = field( 50 | default_factory=lambda: [ 51 | "Forward: A diving technique involving a front-facing takeoff and entry.", 52 | "Back: A diving technique involving a back-facing takeoff and entry.", 53 | "Reverse: A diving technique involving a back-facing takeoff and entry while rotating forward.", 54 | "Inward: A diving technique involving a front-facing takeoff and entry while rotating backwards.", 55 | "Arm Forward: A diving technique involving a front-facing takeoff and entry with arms extended and hands meeting above the head.", 56 | "Arm Back: A diving technique involving a back-facing takeoff and entry with arms extended and hands meeting above the head.", 57 | "Arm Reverse: A diving technique involving a back-facing takeoff and entry with arms extended and hands meeting above the head while rotating forward.", 58 | "0.5 Somersault Pike: A diving technique involving a take-off with half a somersault in the pike position before entering the water.", 59 | "1 Somersault Pike: A diving technique involving a takeoff and rotating forward to form a pike position with one somersault.", 60 | "1.5 Somersaults Pike: A diving technique involving a takeoff and rotating forward to form a pike position with one and a half somersaults.", 61 | "2 Somersaults Pike: A diving technique involving a takeoff and rotating forward to form a pike position with two somersaults.", 62 | "2.5 Somersaults Pike: A diving technique involving a takeoff and rotating forward to form a pike position with two and a half somersaults.", 63 | "3 Somersaults Pike: A diving technique involving a takeoff and rotating forward to form a pike position with three somersaults.", 64 | "3.5 Somersaults Pike: A diving technique involving a takeoff and rotating forward to form a pike position with three and a half somersaults.", 65 | "4.5 Somersaults Pike: A diving technique involving a takeoff and rotating forward to form a pike position with four and a half somersaults.", 66 | "1.5 Somersaults Tuck: A diving technique involving a takeoff and rotating forward to bend at the waist with one and a half somersaults.", 67 | "2 Somersaults Tuck: A diving technique involving a takeoff and rotating forward to 
bend at the waist with two somersaults.", 68 | "2.5 Somersaults Tuck: A diving technique involving a takeoff and rotating forward to bend at the waist with two and a half somersaults.", 69 | "3 Somersaults Tuck: A diving technique involving a takeoff and rotating forward to bend at the waist with three somersaults.", 70 | "3.5 Somersaults Tuck: A diving technique involving a takeoff and rotating forward to bend at the waist with three and a half somersaults.", 71 | "4.5 Somersaults Tuck: A diving technique involving a takeoff and rotating forward to bend at the waist with four and a half somersaults.", 72 | "0.5 Twist: A diving technique involving a takeoff and half a twist before entering the water.", 73 | "1 Twist: A diving technique involving a takeoff and one full twist before entering the water.", 74 | "1.5 Twists: A diving technique involving a takeoff and one and a half twists before entering the water.", 75 | "2 Twists: A diving technique involving a takeoff and two full twists before entering the water.", 76 | "2.5 Twists: A diving technique involving a takeoff and two and a half twists before entering the water.", 77 | "3 Twists: A diving technique involving a takeoff with three twists before entering the water.", 78 | "3.5 Twists: A diving technique involving a takeoff with three and a half twists before entering the water.", 79 | "Entry: A diving technique involving a entry into the water, typically performed at the end of a dive." 80 | ] 81 | ) 82 | 83 | class GetTextEmbeddings: 84 | def __init__(self, output_path) -> None: 85 | """ 86 | Args: 87 | output_path (_type_): output_path to save the embeddibngs 88 | """ 89 | self.data = DataClass() 90 | self.output_path = output_path 91 | 92 | def get_huggingface_embeddings(self): 93 | def average_pool( 94 | last_hidden_states: torch.Tensor, attention_mask: torch.Tensor 95 | ) -> torch.Tensor: 96 | last_hidden = last_hidden_states.masked_fill( 97 | ~attention_mask[..., None].bool(), 0.0 98 | ) 99 | return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 100 | 101 | input_texts = [ 102 | "This is a " + desc.replace(": ", " action: ") 103 | for desc in self.data.sub_action_descriptions 104 | ] 105 | 106 | model_id="google/flan-t5-xxl" # "intfloat/e5-large" 107 | tokenizer = AutoTokenizer.from_pretrained(model_id) 108 | model = AutoModel.from_pretrained(model_id) 109 | 110 | 111 | # Tokenize the input texts 112 | batch_dict = tokenizer( 113 | input_texts, 114 | max_length=512, 115 | padding=True, 116 | truncation=True, 117 | return_tensors="pt", 118 | ) 119 | 120 | with torch.no_grad(): 121 | outputs= model(input_ids= batch_dict['input_ids'], decoder_input_ids=batch_dict['input_ids']) 122 | embeddings = average_pool( 123 | outputs.last_hidden_state, batch_dict["attention_mask"] 124 | ) 125 | np.save(self.output_path, embeddings.detach().cpu().numpy()) 126 | return embeddings.detach().cpu().numpy() 127 | 128 | 129 | if __name__ == "__main__": 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument( 132 | "--output_path", 133 | help="path to save the embeddings", 134 | nargs="?" 
135 | ) 136 | args = parser.parse_args() 137 | 138 | get_text_embeddings = GetTextEmbeddings("FineDiving_t5_xxl_text_embeddings.npy") 139 | text_embeddings = get_text_embeddings.get_huggingface_embeddings() 140 | print("Text Embed shape: ", text_embeddings.shape) 141 | 142 | -------------------------------------------------------------------------------- /libs/datasets/mtl_aqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | from PIL import Image 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | from torchvision import transforms 9 | from torch.nn import functional as F 10 | 11 | from .datasets import register_dataset 12 | 13 | import pickle 14 | 15 | @register_dataset("mtl_aqa") 16 | class MTL_AQA(Dataset): 17 | def __init__(self, 18 | is_training, 19 | split, 20 | data_root, 21 | video_dir, 22 | label_file, 23 | train_datafile, 24 | test_datafile, 25 | stride, 26 | window_size, 27 | frames_per_clip, 28 | max_seq_len, 29 | input_frame_size = None, 30 | crop_frame_size = None, 31 | with_dd = True, 32 | three_judge_score_scaling = False, 33 | use_feats = None, 34 | feat_dir = None, 35 | ): 36 | 37 | self.subset = split 38 | self.is_training = is_training 39 | 40 | self.use_feats = use_feats 41 | 42 | self.crop_frame_size = crop_frame_size 43 | 44 | self.with_dd = with_dd 45 | self.three_judge_score_scaling = three_judge_score_scaling 46 | 47 | 48 | 49 | if self.subset == 'test': 50 | self.transforms = transforms.Compose( 51 | [ 52 | transforms.Resize(input_frame_size), 53 | transforms.CenterCrop(crop_frame_size), 54 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]) 55 | ] 56 | ) 57 | 58 | 59 | elif self.subset == 'train': 60 | self.transforms = transforms.Compose( 61 | [ 62 | transforms.RandomHorizontalFlip(p=0.5), 63 | transforms.Resize(input_frame_size), 64 | transforms.RandomCrop(crop_frame_size), 65 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]), 66 | ] 67 | ) 68 | 69 | else: 70 | raise ValueError("subset should be train or test") 71 | 72 | print(f"subset: {self.subset}, is_training: {self.is_training}") 73 | 74 | self.pil_2_tensor = transforms.ToTensor() 75 | self.stride = stride 76 | self.window_size = window_size 77 | self.frames_per_clip = frames_per_clip 78 | self.max_seq_len = max_seq_len 79 | 80 | # file paths 81 | self.data_root = data_root 82 | self.video_dir = os.path.join(self.data_root, video_dir) 83 | 84 | if self.use_feats: 85 | self.feat_dir = os.path.join(self.data_root, feat_dir) 86 | 87 | train_datafile_path = os.path.join(self.data_root, "info", train_datafile) 88 | test_datafile_path = os.path.join(self.data_root,"info", test_datafile) 89 | 90 | self.data_anno = self.read_pickle(os.path.join(self.data_root, "info", label_file)) 91 | 92 | if self.subset == 'test': 93 | self.datalist = self._load_annotation(test_datafile_path) 94 | elif self.subset == 'train': 95 | self.datalist = self._load_annotation(train_datafile_path) 96 | else: 97 | raise ValueError("subset should be train or test") 98 | 99 | 100 | def _load_annotation(self, pkl_file): 101 | 102 | data_list = self.read_pickle(pkl_file) 103 | processed_data_list = [] 104 | 105 | for video_id in data_list: 106 | data = {} 107 | data['video_id'] = video_id 108 | 109 | data['actions_present'] = self.get_actions_present(self.data_anno[video_id]) 110 | 111 | data['final_score'] = self.data_anno[video_id]["final_score"] 112 | 113 | data['difficulty'] = 
self.data_anno[video_id]["difficulty"] 114 | 115 | data['gt_score'] = data['final_score'] 116 | 117 | processed_data_list.append(data) 118 | 119 | 120 | return processed_data_list 121 | 122 | def get_actions_present(self, anno): 123 | 124 | """ 125 | armstand: No, Yes 126 | rotation_type: Inward, reverse, backward, forward 127 | 128 | positions: Free, tuck, Pike 129 | #SS: 9 unique , including 0 for no ss 130 | #tw: 8 unique, including 0 for no tw 131 | 132 | indexing: 0, 1,2,3,4, 5,6,7, 8,9,10,11,12,13,14,15, 16,17,18,19,20,21,22 133 | """ 134 | if anno["armstand"] != 0: 135 | armstand_idx = 0 136 | else: 137 | armstand_idx = "MISSING" 138 | 139 | rotation_type_idx = 1 + anno["rotation_type"] 140 | 141 | pos_idx = 5 + anno["position"] 142 | 143 | if anno["ss_no"] != 0: 144 | ss_idx = 7 + anno["ss_no"] 145 | else: 146 | ss_idx = "MISSING" 147 | 148 | if anno["tw_no"] != 0: 149 | tw_idx = 15 + anno["tw_no"] 150 | else: 151 | tw_idx = "MISSING" 152 | 153 | actions_present = [] 154 | for idx in [pos_idx, armstand_idx, rotation_type_idx, ss_idx, tw_idx]: 155 | if idx != "MISSING": 156 | actions_present.append(idx) 157 | 158 | #action for entry 159 | actions_present.append(23) 160 | 161 | return actions_present 162 | 163 | def read_pickle(self, pickle_path): 164 | with open(pickle_path,'rb') as f: 165 | pickle_data = pickle.load(f) 166 | return pickle_data 167 | 168 | 169 | def load_video(self, video_file_name): 170 | first, second = video_file_name[0], video_file_name[1] 171 | 172 | if first < 10: 173 | first = "0"+str(first) 174 | 175 | if second < 10: 176 | second = "0"+str(second) 177 | image_list = sorted((glob.glob(os.path.join(self.video_dir, str(first)+"_"+str(second), '*.jpg')))) 178 | 179 | if self.is_training: 180 | #start from 0-1-2-3 181 | start_idx = torch.randint(0,4,[1]).item() 182 | image_list = image_list[start_idx:start_idx+self.frames_per_clip] 183 | 184 | video = [Image.open(image) for image in image_list] 185 | video = [transforms.ToTensor()(frame) for frame in video] 186 | video = torch.stack(video) 187 | video = self.transforms(video) 188 | 189 | #extract windows 190 | start_idx = [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95] 191 | 192 | video_pack = torch.stack([video[i:i+self.window_size,:,:,:] for i in start_idx]) 193 | 194 | return video_pack 195 | 196 | def load_feats(self, video_file_name): 197 | feats = np.load(os.path.join(self.feat_dir, str(video_file_name[0])+"-abr-"+str(video_file_name[1])+'.npz'))["arr_0"] 198 | feats = torch.from_numpy(feats).float() 199 | 200 | return feats 201 | 202 | 203 | def __getitem__(self, index): 204 | video_data = self.datalist[index] 205 | 206 | video_id = video_data["video_id"] 207 | 208 | data = {"video_id": video_id} 209 | if self.use_feats: 210 | data['feats'] = self.load_feats(video_id) 211 | else: 212 | data['window_frames'] = self.load_video(video_id) 213 | 214 | data["video_name"] = video_id 215 | data["difficulty"] = video_data["difficulty"] 216 | data["actions_present"] = video_data["actions_present"] 217 | 218 | target = video_data["gt_score"] 219 | data["gt_score"] = video_data["gt_score"] 220 | 221 | if self.with_dd: 222 | target = target / data['difficulty'] 223 | 224 | if self.three_judge_score_scaling: 225 | target = target / 3.0 226 | 227 | data["target"] = target 228 | 229 | return data 230 | 231 | 232 | def __len__(self): 233 | return len(self.datalist) 234 | -------------------------------------------------------------------------------- /libs/modeling/losses.py: 
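Before the loss definitions in this file, a note on the `gt_actions` tensor they consume: the datasets above return `actions_present` as a plain Python list of sub-action indices (0–23 for MTL-AQA, 0–28 for FineDiving, with the "entry" index always appended), while `get_att_loss` and `criterion` below expect a multi-hot tensor of shape (B, Q). That conversion presumably happens at batch-collation time in the data pipeline (e.g., in `libs/datasets/data_utils.py`, which is not shown here). The snippet below is only a minimal illustrative sketch of that idea; the helper name `to_multi_hot` and the hard-coded query count are assumptions for this example, not code from the repository.

```python
# Illustrative sketch (not part of the repository): turning a batch of
# sub-action index lists, such as MTL-AQA's `actions_present` (values 0..23),
# into the multi-hot `gt_actions` tensor of shape (B, Q) that the losses
# in this file operate on. NUM_QUERIES and the helper name are assumptions.
import torch

NUM_QUERIES = 24  # MTL-AQA uses indices 0..23 (23 is the shared "entry" action)

def to_multi_hot(actions_present_batch, num_queries=NUM_QUERIES):
    """Convert a batch of index lists into a (B, Q) multi-hot float tensor."""
    gt_actions = torch.zeros(len(actions_present_batch), num_queries)
    for b, indices in enumerate(actions_present_batch):
        # set the positions of the sub-actions present in sample b to 1
        gt_actions[b, torch.as_tensor(indices, dtype=torch.long)] = 1.0
    return gt_actions

# Example: two dives with different sub-action sets
batch = [[5, 2, 10, 17, 23], [6, 3, 9, 23]]
print(to_multi_hot(batch).shape)  # torch.Size([2, 24])
```

In the loss code that follows, `gt_actions.nonzero(as_tuple=True)` recovers exactly these active per-sample indices to select the clusters used for the ranking term, and `gt_actions.sum(dim=1)` gives the number of active sub-actions per sample.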
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from einops import repeat 6 | 7 | 8 | """ 9 | ranking and sparsity loss, modified from TPT 10 | Uses broadcasted operations to calculate the the loss in a fully vectorized manner. 11 | https://github.com/baiyang4/aqa_tpt/blob/cf6d1631ec53c676f108926fc5480afe7c47ff56/train_pairencode1_decoder_1selfatt_self8head_ffn_sp_new.py#L289 12 | """ 13 | def get_att_loss(logits_all, gt_actions,ranking_loss_wt, sparsity_loss_wt, hinge_loss): 14 | 15 | # logits_all: B, Q, T 16 | B, Q, T = logits_all.size() 17 | #gt_actions One hot encoded: B, Q 18 | 19 | logits_all = logits_all.permute(0, 2, 1) #.transpose(-1,-2) 20 | softmax_dim = logits_all.shape[1] 21 | temp_idx = repeat(torch.arange(1, softmax_dim + 1), 't -> b t k', b=logits_all.shape[0], k=logits_all.shape[-1]).float().to(logits_all.device) 22 | cluster_mean = (logits_all * temp_idx).sum(1) 23 | var = (torch.abs(temp_idx - repeat(cluster_mean, 'b k -> b t k', t=softmax_dim)) * logits_all).sum(1) 24 | 25 | if ranking_loss_wt == 0.0: 26 | return 0.0, var.mean() 27 | 28 | # Extract active clusters based on gt_actions for all samples in batch 29 | active_indices = gt_actions.nonzero(as_tuple=True) # Indices where gt_actions is 1 30 | active_clusters = cluster_mean[active_indices] # Extract active clusters 31 | 32 | # Calculate hinge loss for ranking the active clusters in a fully vectorized manner 33 | cluster_counts = gt_actions.sum(dim=1).long() # Number of active clusters per batch 34 | max_clusters = cluster_counts.max().item() 35 | 36 | # Pad active clusters to have the same length per batch 37 | padded_active_clusters = torch.zeros((B, max_clusters), device=cluster_mean.device) 38 | mask = torch.arange(max_clusters, device=cluster_mean.device).unsqueeze(0) < cluster_counts.unsqueeze(1) 39 | padded_active_clusters[mask] = active_clusters 40 | 41 | # Create shifted versions of the active clusters for comparison 42 | current_clusters = padded_active_clusters[:, :-1] # B, max_clusters - 1 43 | next_clusters = padded_active_clusters[:, 1:] # B, max_clusters - 1 44 | valid_pairs_mask = mask[:, :-1] & mask[:, 1:] # Mask to identify valid pairs for loss calculation 45 | 46 | # Apply the valid pairs mask to current and next clusters 47 | valid_current_clusters = current_clusters[valid_pairs_mask] 48 | valid_next_clusters = next_clusters[valid_pairs_mask] 49 | 50 | # Calculate hinge loss for valid pairs 51 | ones = torch.ones_like(valid_current_clusters) 52 | loss = hinge_loss(valid_next_clusters, valid_current_clusters, ones).mean() 53 | 54 | # Add boundary conditions (first and last clusters only once per batch) 55 | first_clusters = padded_active_clusters[:, 0] # B 56 | last_clusters = padded_active_clusters[torch.arange(B), cluster_counts - 1] # B 57 | boundary_mask = cluster_counts > 0 # Mask to identify batches with at least one active cluster 58 | 59 | loss += hinge_loss(first_clusters[boundary_mask], torch.ones_like(first_clusters[boundary_mask]), torch.ones_like(first_clusters[boundary_mask])).sum() 60 | loss += hinge_loss(torch.ones_like(last_clusters[boundary_mask]) * softmax_dim, last_clusters[boundary_mask], torch.ones_like(last_clusters[boundary_mask])).sum() 61 | 62 | return loss, var.mean() 63 | 64 | 65 | 66 | def criterion(model_output, target, difficulties, inp_gt_actions, loss_weights, with_dd=True, three_judge_score_scaling=False): 67 | 68 | batched_gt_judge_scores = 
target 69 | 70 | #B, samples, 1 71 | batched_pred_quality_scores = model_output["all_sample_outputs"] 72 | 73 | difficulties = difficulties.cuda() 74 | 75 | # comment these out when training jigsaws?? TODO 76 | if with_dd: 77 | batched_gt_judge_scores = batched_gt_judge_scores * difficulties 78 | batched_pred_quality_scores = batched_pred_quality_scores * difficulties.unsqueeze(-1).unsqueeze(-1) 79 | 80 | if three_judge_score_scaling: 81 | batched_gt_judge_scores = batched_gt_judge_scores * 3.0 82 | batched_pred_quality_scores = batched_pred_quality_scores * 3.0 83 | 84 | if loss_weights["phase_vib"] != 0: 85 | 86 | gt_actions = inp_gt_actions.detach() 87 | gt_actions = gt_actions.reshape(-1) 88 | 89 | phase_mean_emb, phase_var_emb = model_output["phase_mean_emb"], model_output["phase_var_emb"] 90 | 91 | batch_size, num_phases, channels = phase_mean_emb.shape 92 | phase_mean_emb = phase_mean_emb.reshape(-1, channels) 93 | phase_var_emb = phase_var_emb.reshape(-1, channels) 94 | 95 | #mask out the non-action phases 96 | if num_phases > 1: 97 | phase_mean_emb = phase_mean_emb[gt_actions.nonzero().squeeze().long(),:] 98 | phase_var_emb = phase_var_emb[gt_actions.nonzero().squeeze().long(),:] 99 | 100 | #B*phases x channels -> B*phases 101 | phase_vib_loss = torch.sum(torch.pow(phase_mean_emb, 2) + phase_var_emb - torch.log(phase_var_emb) - 1.0, dim=1) * 0.5 102 | 103 | phase_vib_loss = phase_vib_loss.sum() 104 | phase_vib_loss = phase_vib_loss / (batch_size * channels) 105 | 106 | 107 | # batched_pred_quality_score_logits shape -> batch_size x num_samples x 1 108 | # batched_gt_judge_score_class shape -> B x 1 109 | B, num_samples, bins = batched_pred_quality_scores.shape 110 | 111 | if loss_weights["loss"] == "mse": 112 | ground_truth_expanded = batched_gt_judge_scores.unsqueeze(1).expand(-1, num_samples) 113 | 114 | batched_pred_quality_scores_flat = batched_pred_quality_scores.reshape(-1) 115 | ground_truth_flat = ground_truth_expanded.reshape(-1) 116 | 117 | quality_score_loss = nn.MSELoss(reduction='mean')(batched_pred_quality_scores_flat, ground_truth_flat) 118 | 119 | elif loss_weights["loss"] =="xentropy": 120 | ground_truth_expanded = batched_gt_judge_scores.unsqueeze(1).expand(-1, num_samples) 121 | 122 | batched_pred_quality_scores_flat = batched_pred_quality_scores.reshape(B, bins) 123 | ground_truth_flat = ground_truth_expanded.reshape(B).long() 124 | quality_score_loss = nn.CrossEntropyLoss(reduction="mean")(batched_pred_quality_scores_flat, ground_truth_flat) 125 | else: 126 | raise ValueError("Invalid loss function") 127 | 128 | if loss_weights["ranking"] > 0.0 or loss_weights["sparsity"] > 0.0: 129 | # bz x heads x num_queries x T 130 | # Flatten the logits and gt_actions across batch, heads, and decoder layers dimensions 131 | batch_size, num_heads, num_queries, num_clips = model_output["cross_attn"][0].shape 132 | num_decoder_layers = len(model_output["cross_attn"]) 133 | 134 | # Concatenate across decoder layers and heads to flatten 135 | logits_all_flat = torch.cat([model_output["cross_attn"][decoder_layer_idx][:, head_idx, :, :] 136 | for decoder_layer_idx in range(num_decoder_layers) 137 | for head_idx in range(num_heads)], dim=0) # Shape: (B * num_heads * num_decoder_layers), Q, T 138 | 139 | # Repeat the ground truth actions for each head and decoder layer 140 | gt_actions_flat = inp_gt_actions.detach().repeat(num_heads * num_decoder_layers, 1) # Shape: (B * num_heads * num_decoder_layers), Q 141 | 142 | ranking_loss, sparsity_loss = get_att_loss(logits_all_flat, 143 
| gt_actions_flat, 144 | ranking_loss_wt=loss_weights["ranking"], 145 | sparsity_loss_wt=loss_weights["sparsity"], 146 | hinge_loss=nn.MarginRankingLoss(1.0)) 147 | 148 | 149 | loss_dict, final_loss = {}, 0.0 150 | valid_loss_keys = ["quality_score", "phase_vib", "ranking", "sparsity"] 151 | 152 | for key, value in loss_weights.items(): 153 | if key in valid_loss_keys and isinstance(value, (int, float)) and value != 0: 154 | loss_dict[f"{key}_loss"] = value * locals()[f"{key}_loss"] 155 | final_loss += loss_dict[f"{key}_loss"] 156 | 157 | loss_dict["final_loss"] = final_loss 158 | 159 | return loss_dict 160 | -------------------------------------------------------------------------------- /libs/datasets/finediving.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import glob 4 | import numpy as np 5 | from PIL import Image 6 | 7 | import torch 8 | from torch.utils.data import Dataset 9 | from torchvision import transforms 10 | from torch.nn import functional as F 11 | 12 | from .datasets import register_dataset 13 | 14 | import pickle 15 | 16 | 17 | @register_dataset("finediving") 18 | class FineDiving(Dataset): 19 | def __init__(self, 20 | is_training, 21 | split, 22 | data_root, 23 | video_dir, 24 | label_file, 25 | coarse_label_file, 26 | train_datafile, 27 | test_datafile, 28 | stride, 29 | window_size, 30 | frames_per_clip, 31 | max_seq_len, 32 | input_frame_size = None, 33 | crop_frame_size = 224, 34 | with_dd = True, 35 | three_judge_score_scaling = False, 36 | use_feats = None, 37 | feat_dir = None 38 | ): 39 | 40 | self.subset = split 41 | self.is_training = is_training 42 | 43 | self.use_feats = use_feats 44 | self.crop_frame_size = crop_frame_size 45 | 46 | self.with_dd = with_dd 47 | self.three_judge_score_scaling = three_judge_score_scaling 48 | 49 | if self.subset == 'test': 50 | self.transforms = transforms.Compose( 51 | [ 52 | transforms.Resize(input_frame_size), 53 | transforms.CenterCrop(crop_frame_size), 54 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]) 55 | ] 56 | ) 57 | 58 | elif self.subset == 'train': 59 | self.transforms = transforms.Compose( 60 | [ 61 | transforms.RandomHorizontalFlip(p=0.5), 62 | transforms.Resize(input_frame_size), 63 | transforms.RandomCrop(crop_frame_size), 64 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]), 65 | ] 66 | ) 67 | 68 | else: 69 | raise ValueError("subset should be train or test") 70 | 71 | print(f"subset: {self.subset}, is_training: {self.is_training}") 72 | 73 | 74 | self.pil_2_tensor = transforms.ToTensor() 75 | 76 | self.stride = stride 77 | self.window_size = window_size 78 | self.frames_per_clip = frames_per_clip 79 | self.max_seq_len = max_seq_len 80 | 81 | # file paths 82 | self.data_root = data_root 83 | self.video_dir = os.path.join(self.data_root, video_dir) 84 | 85 | if self.use_feats: 86 | self.feat_dir = os.path.join(self.data_root, feat_dir) 87 | 88 | self.label_path = os.path.join(self.data_root,"Annotations", label_file) 89 | self.coarse_label_path = os.path.join(self.data_root,"Annotations", coarse_label_file) 90 | 91 | train_datafile_path = os.path.join(self.data_root, "Annotations",train_datafile) 92 | test_datafile_path = os.path.join(self.data_root,"Annotations", test_datafile) 93 | 94 | #mapping from subaction to index, refer to gen_text_embeddings.py for more details 95 | self.label_to_idx_dict = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 37: 7, 12: 8, 13: 9, 14: 10, 15: 11, 96 | 16: 
12, 17: 13, 19: 14, 21: 15, 22: 16, 23: 17, 24: 18, 25: 19, 27: 20, 29: 21, 97 | 30: 22, 31: 23, 32: 24, 33: 25, 34: 26, 35: 27, 36: 28} 98 | 99 | 100 | self.data_anno = self.read_pickle(self.label_path) 101 | self.coarse_data_anno = self.read_pickle(self.coarse_label_path) 102 | 103 | if self.subset == 'test': 104 | self.datalist = self._load_annotations(test_datafile_path) 105 | elif self.subset == 'train': 106 | self.datalist = self._load_annotations(train_datafile_path) 107 | else: 108 | raise ValueError("subset should be train or test") 109 | 110 | 111 | def _load_annotations(self, datafile_path): 112 | data_list = self.read_pickle(datafile_path) 113 | 114 | processed_data_list = [] 115 | 116 | for video_id in data_list: 117 | data = {} 118 | data['video_id'] = video_id 119 | 120 | data['final_score'] = self.data_anno[video_id]["dive_score"] 121 | 122 | data['difficulty'] = self.data_anno[video_id]["difficulty"] 123 | 124 | data['gt_score'] = data['final_score'] 125 | 126 | processed_data_list.append(data) 127 | 128 | return processed_data_list 129 | 130 | def load_video(self, video_file_name): 131 | image_list = sorted((glob.glob(os.path.join(self.video_dir, video_file_name[0], str(video_file_name[1]), '*.jpg')))) 132 | 133 | if self.is_training: 134 | #start from 0-1-2-3 135 | start_idx = torch.randint(0,4,[1]).item() 136 | 137 | # randomly drop the end frames 138 | end_idx = torch.randint(0,4,[1]).item() 139 | image_list = image_list[start_idx:start_idx + len(image_list) - end_idx] 140 | 141 | start_frame = int(image_list[0].split("/")[-1][:-4]) 142 | end_frame = int(image_list[-1].split("/")[-1][:-4]) 143 | 144 | frame_list = np.linspace(start_frame, end_frame, self.frames_per_clip).astype(np.int32) 145 | image_frame_idx = [frame_list[i] - start_frame for i in range(self.frames_per_clip)] 146 | video = [Image.open(image_list[image_frame_idx[i]]) for i in range(self.frames_per_clip)] 147 | video = [transforms.ToTensor()(frame) for frame in video] 148 | video = torch.stack(video) 149 | video = self.transforms(video) 150 | 151 | #extract windows 152 | start_idx = list(range(0, 90, 10)) 153 | 154 | video_pack = torch.stack([video[i:i+self.window_size,:,:,:] for i in start_idx]) 155 | 156 | frames_labels = self.data_anno[video_file_name]["frames_labels"] 157 | #mapping labels to label indices 158 | frames_labels = [self.label_to_idx_dict[l] for l in frames_labels] 159 | 160 | return video_pack, frames_labels 161 | 162 | 163 | 164 | def load_feats(self, video_file_name): 165 | feats = np.load(os.path.join(self.feat_dir, video_file_name[0]+"-abr-"+str(video_file_name[1])+'.npz'))["feats"] 166 | 167 | feats = feats.transpose(1,0) 168 | 169 | feats = torch.from_numpy(feats).float() 170 | 171 | frames_labels = self.data_anno[video_file_name]["frames_labels"] 172 | 173 | #mapping labels to label indices 174 | frames_labels = [self.label_to_idx_dict[l] for l in frames_labels] 175 | 176 | return feats, frames_labels 177 | 178 | 179 | def read_pickle(self, pickle_path): 180 | with open(pickle_path,'rb') as f: 181 | pickle_data = pickle.load(f) 182 | return pickle_data 183 | 184 | 185 | def __getitem__(self, index): 186 | video_data = self.datalist[index] 187 | 188 | video_id = video_data["video_id"] 189 | 190 | data = {"video_id": video_id} 191 | if self.use_feats: 192 | data['feats'], frame_labels = self.load_feats(video_id) 193 | else: 194 | data['window_frames'], frame_labels = self.load_video(video_id) 195 | 196 | 197 | data["video_name"] = video_id 198 | 199 | data['difficulty'] = 
video_data["difficulty"] 200 | 201 | target = video_data["gt_score"] 202 | data["gt_score"] = video_data["gt_score"] 203 | 204 | if self.with_dd: 205 | target = target / data['difficulty'] 206 | 207 | if self.three_judge_score_scaling: 208 | target = target / 3.0 209 | 210 | data["target"] = target 211 | 212 | #if entry is missing, add 28 213 | if 28 not in frame_labels: 214 | frame_labels.append(28) 215 | 216 | #take only the presence of actions 217 | data['actions_present'] = sorted(list(set(frame_labels))) 218 | 219 | return data 220 | 221 | def __len__(self): 222 | return len(self.datalist) 223 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # python imports 2 | import argparse, os 3 | import time 4 | import datetime 5 | from pprint import pprint 6 | import sys, pickle 7 | 8 | # torch imports 9 | import torch 10 | import torch.nn as nn 11 | import torch.utils.data 12 | # for visualization 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | # our code 16 | from libs.core import load_config 17 | from libs.datasets import make_dataset, make_data_loader 18 | from libs.modeling import make_meta_arch 19 | from libs.utils import (Logger, train_one_epoch, valid_one_epoch, 20 | save_checkpoint, make_optimizer, make_scheduler, 21 | fix_random_seed) 22 | 23 | def override_cfg_params(cfg, args): 24 | if args.gpu is not None: 25 | cfg["devices"] = args.gpu 26 | else: 27 | cfg['devices'] = [i for i in range(torch.cuda.device_count())] 28 | 29 | if args.data_root is not None: 30 | cfg["dataset"]["data_root"] = args.data_root 31 | 32 | if args.train_batch_size > 0: 33 | cfg["loader"]["train_batch_size"] = args.train_batch_size 34 | 35 | if args.test_batch_size > 0: 36 | cfg["loader"]["test_batch_size"] = args.test_batch_size 37 | 38 | if args.cv_fold > -1: 39 | cfg["dataset"]["cross_val_id"] = args.cv_fold 40 | 41 | if args.cv_split_file != "": 42 | cfg["dataset"]["cross_val_split_file"] = args.cv_split_file 43 | 44 | if cfg["dataset"]["use_feats"] == False: 45 | cfg["model"]["finetune_feat_extractor"]= True 46 | cfg["model"]["feat_extractor_type"]= 'i3d' 47 | cfg["model"]["feat_extractor_weights_path"]= './pre_trained/model_rgb.pth' 48 | 49 | if cfg["model"]["use_stochastic_embd"] == False: 50 | cfg["train_cfg"]["loss_weights"]["phase_vib"] = 0.0 51 | cfg["train_cfg"]["loss_weights"]["scale_vib"] = False 52 | 53 | return cfg 54 | 55 | 56 | def create_checkpoint_folder(cfg, args): 57 | if not os.path.exists(cfg['output_folder']): 58 | os.mkdir(cfg['output_folder']) 59 | 60 | cfg_filename = os.path.basename(args.config).replace('.yaml', '') 61 | 62 | if len(args.output) == 0: 63 | ts = datetime.datetime.fromtimestamp(int(time.time())) 64 | ts = str(ts).replace(' ', '_').replace(':', '-') 65 | ckpt_folder = os.path.join(cfg['output_folder'], cfg_filename + '_' + str(ts)) 66 | else: 67 | ckpt_folder = os.path.join(cfg['output_folder'], cfg_filename + '_' + str(args.output)) 68 | 69 | if not os.path.exists(ckpt_folder): 70 | os.mkdir(ckpt_folder) 71 | 72 | return ckpt_folder 73 | 74 | 75 | def create_train_val_dataloaders(cfg, rng_generator): 76 | train_dataset = make_dataset( 77 | cfg['dataset_name'], 78 | True, 79 | cfg['train_split'], 80 | **cfg['dataset'] 81 | ) 82 | 83 | 84 | val_dataset = make_dataset( 85 | cfg['dataset_name'], 86 | False, 87 | cfg['val_split'], 88 | **cfg['dataset'] 89 | ) 90 | 91 | print("Train dataset size: 
{:d}".format(len(train_dataset))) 92 | print("Val dataset size: {:d}".format(len(val_dataset))) 93 | 94 | # data loaders 95 | train_loader = make_data_loader( 96 | train_dataset, True, rng_generator, **cfg['loader']) 97 | 98 | val_loader = make_data_loader( 99 | val_dataset, False, None, **cfg['loader']) 100 | 101 | return (train_loader, val_loader) 102 | 103 | ################################################################################ 104 | def main(args): 105 | """main function that handles training / inference""" 106 | 107 | """1. setup parameters / folders""" 108 | 109 | # parse args 110 | args.start_epoch = 0 111 | 112 | if os.path.isfile(args.config): 113 | cfg = load_config(args.config) 114 | else: 115 | raise ValueError("Config file does not exist.") 116 | 117 | ckpt_folder = create_checkpoint_folder(cfg, args) 118 | print("Checkpoint folder: {:s}".format(ckpt_folder)) 119 | # tensorboard writer 120 | tb_writer = SummaryWriter(os.path.join(ckpt_folder, 'logs')) 121 | 122 | logger = Logger(os.path.join(ckpt_folder, '0_log.txt')) 123 | print("If you plan to debug using ipdb then comment the following line") 124 | #sys.stdout = logger 125 | 126 | torch.set_warn_always(False) 127 | 128 | # fix the random seeds (this will fix everything) 129 | rng_generator = fix_random_seed(cfg['init_rand_seed'], include_cuda=True) 130 | 131 | cfg = override_cfg_params(cfg, args) 132 | 133 | print("Args") 134 | pprint(vars(args), indent=4, stream=sys.__stdout__,sort_dicts=False) 135 | pprint(cfg, stream=sys.__stdout__,sort_dicts=False) 136 | 137 | """2. create dataset / dataloader""" 138 | train_loader, val_loader = create_train_val_dataloaders(cfg, rng_generator) 139 | 140 | """3. create model, optimizer, and scheduler""" 141 | # model 142 | model = make_meta_arch(cfg['model_name'], **cfg['model']) 143 | 144 | model = nn.DataParallel(model).cuda() 145 | 146 | optimizer = make_optimizer(model, cfg['opt']) 147 | 148 | # schedule 149 | num_iters_per_epoch = len(train_loader) / cfg["train_cfg"]["accumulation_steps"] 150 | scheduler = make_scheduler(optimizer, cfg['opt'], num_iters_per_epoch) 151 | 152 | """4. Resume from model / Misc""" 153 | # resume from a checkpoint? 154 | if args.resume: 155 | if os.path.isfile(args.resume): 156 | # load ckpt, reset epoch / best rmse 157 | checkpoint = torch.load(args.resume, 158 | map_location = lambda storage, loc: storage.cuda( 159 | cfg['devices'][0])) 160 | args.start_epoch = checkpoint['epoch'] 161 | model.load_state_dict(checkpoint['state_dict']) 162 | # also load the optimizer / scheduler if necessary 163 | optimizer.load_state_dict(checkpoint['optimizer']) 164 | scheduler.load_state_dict(checkpoint['scheduler']) 165 | print("=> loaded checkpoint '{:s}' (epoch {:d}".format( 166 | args.resume, checkpoint['epoch'] 167 | )) 168 | del checkpoint 169 | else: 170 | print("=> no checkpoint found at '{}'".format(args.resume)) 171 | return 172 | 173 | # save the current config 174 | with open(os.path.join(ckpt_folder, 'config.txt'), 'w') as fid: 175 | pprint(cfg, stream=fid) 176 | fid.flush() 177 | 178 | """4. 
training / validation loop""" 179 | print("\nStart training model {:s} ...".format(cfg['model_name'])) 180 | 181 | # start training 182 | max_epochs = cfg['opt'].get( 183 | 'early_stop_epochs', 184 | cfg['opt']['epochs'] + cfg['opt']['warmup_epochs'] 185 | ) 186 | best_rl2 = 1000 187 | srcc_at_best_rl2 = -100 188 | best_rl2_epoch = 0 189 | 190 | best_srcc = -100 191 | rl2_at_best_srcc = 1000 192 | best_srcc_epoch = 0 193 | 194 | for epoch in range(args.start_epoch, max_epochs): 195 | 196 | # train for one epoch 197 | #start_time = time.time() 198 | train_one_epoch( 199 | train_loader, 200 | model, 201 | optimizer, 202 | scheduler, 203 | epoch, 204 | cfg = cfg, 205 | tb_writer=tb_writer, 206 | print_freq=args.print_freq 207 | ) 208 | #print("Time taken for epoch: {:.2f} s".format(time.time() - start_time)) 209 | 210 | if cfg["train_cfg"]["loss_weights"]["phase_vib"] > 0.0 and cfg["train_cfg"]["loss_weights"]["scale_vib"]: 211 | if epoch > 30 and epoch % 10 == 0: 212 | cfg["train_cfg"]["loss_weights"]["phase_vib"] *= 3 213 | cfg["train_cfg"]["loss_weights"]["phase_vib"] = min(cfg["train_cfg"]["loss_weights"]["phase_vib"], 0.005) 214 | 215 | 216 | if (epoch == 0) or ((epoch+1) <= 30 and (epoch+1) % 5 == 0 ) or (30 < (epoch+1) <= 120 and (epoch+1) % 3 == 0) or (epoch+1) > 120 or (epoch+1) == max_epochs: 217 | curr_srcc, curr_rl2, metric_dict = valid_one_epoch( 218 | val_loader, 219 | model, 220 | epoch, 221 | cfg = cfg, 222 | tb_writer=tb_writer, 223 | print_freq=args.print_freq, 224 | save_predictions=True 225 | ) 226 | 227 | if curr_srcc > best_srcc: 228 | best_srcc = curr_srcc 229 | rl2_at_best_srcc = curr_rl2 230 | best_srcc_epoch = epoch 231 | srcc_improved = True 232 | else: 233 | srcc_improved = False 234 | 235 | if curr_rl2 < best_rl2: 236 | best_rl2 = curr_rl2 237 | srcc_at_best_rl2 = curr_srcc 238 | best_rl2_epoch = epoch 239 | rl2_improved = True 240 | else: 241 | rl2_improved = False 242 | 243 | print("Best SRCC: {:.4f}, corres. RL2: {:.4f} at epoch {:d}".format(best_srcc, rl2_at_best_srcc, best_srcc_epoch)) 244 | print("Best RL2: {:.4f}, corres. 
SRCC: {:.4f} at epoch {:d}".format(best_rl2, srcc_at_best_rl2, best_rl2_epoch)) 245 | 246 | save_states = { 247 | 'epoch': epoch + 1, 248 | 'state_dict': model.state_dict(), 249 | 'scheduler': scheduler.state_dict(), 250 | 'optimizer': optimizer.state_dict(), 251 | } 252 | 253 | if srcc_improved: 254 | save_checkpoint( 255 | save_states, 256 | False, 257 | file_folder=ckpt_folder, 258 | file_name='srcc_best.pth.tar' 259 | ) 260 | with open(os.path.join(ckpt_folder, "srcc_best_outputs.pkl"), "wb") as f: 261 | pickle.dump(metric_dict, f) 262 | 263 | if rl2_improved: 264 | save_checkpoint( 265 | save_states, 266 | False, 267 | file_folder=ckpt_folder, 268 | file_name='rl2_best.pth.tar' 269 | ) 270 | with open(os.path.join(ckpt_folder, "rl2_best_outputs.pkl"), "wb") as f: 271 | pickle.dump(metric_dict, f) 272 | 273 | del save_states 274 | 275 | print("Best SRCC: {:.4f}".format(best_srcc)) 276 | print("Best RL2: {:.4f}".format(best_rl2)) 277 | 278 | # wrap up 279 | tb_writer.close() 280 | print("All done!") 281 | return 282 | 283 | ################################################################################ 284 | if __name__ == '__main__': 285 | """Entry Point""" 286 | # the arg parser 287 | parser = argparse.ArgumentParser( 288 | description='Train') 289 | parser.add_argument('config', metavar='DIR', 290 | help='path to a config file') 291 | parser.add_argument('-p', '--print-freq', default=10, type=int, 292 | help='print frequency (default: 10 iterations)') 293 | parser.add_argument('-c', '--ckpt-freq', default=5, type=int, 294 | help='checkpoint frequency (default: every 5 epochs)') 295 | parser.add_argument('--output', default='', type=str, 296 | help='name of exp folder (default: none)') 297 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 298 | help='path to a checkpoint (default: none)') 299 | 300 | parser.add_argument('--data_root', type=str, metavar='PATH',) 301 | parser.add_argument('--train_batch_size', default=-1, type=int) 302 | parser.add_argument('--test_batch_size', default=-1, type=int) 303 | parser.add_argument('--cv_fold', default=-1, type=int) 304 | 305 | parser.add_argument('--cv_split_file', default='', type=str) 306 | parser.add_argument('--gpu', nargs='*') 307 | 308 | args = parser.parse_args() 309 | 310 | main(args) 311 | -------------------------------------------------------------------------------- /libs/modeling/meta_archs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from .models import register_meta_arch, make_backbone, make_neck 6 | from .blocks import get_sinusoid_encoding, PhaseDistComposer 7 | from .losses import criterion 8 | 9 | 10 | class Quality_Score_Head(nn.Module): 11 | """ 12 | Head for multiple judge quality score 13 | """ 14 | 15 | def __init__(self, input_dim, score_bins, num_random_samples=1, use_stochastic_embd=True, dataset=None): 16 | super().__init__() 17 | 18 | self.score_bins = score_bins 19 | self.input_dim = input_dim 20 | 21 | self.common_base = nn.Sequential( 22 | nn.Linear(input_dim, input_dim // 2), 23 | nn.ReLU() 24 | ) 25 | 26 | self.scoring_head = nn.Sequential( 27 | nn.Linear(input_dim // 2, input_dim // 2), 28 | nn.ReLU(), 29 | nn.Linear(input_dim // 2, score_bins) 30 | ) 31 | 32 | self.use_stochastic_embd = use_stochastic_embd 33 | self.phase_composer = PhaseDistComposer(dataset_name=dataset, dim=input_dim//2) 34 | 35 | if self.use_stochastic_embd: 36 | self.num_random_samples 
= num_random_samples 37 | self.mean_var_linear = nn.Sequential( 38 | nn.Linear(input_dim // 2, input_dim // 2), 39 | nn.ReLU(), 40 | nn.Linear(input_dim // 2, input_dim), 41 | ) 42 | 43 | self.temperature = nn.Parameter(torch.ones(1) * 2) 44 | 45 | self._reset_parameters() 46 | 47 | def _reset_parameters(self): 48 | for p in self.parameters(): 49 | if p.dim() > 1: 50 | nn.init.xavier_uniform_(p) 51 | 52 | def temperature_scale(self, logits): 53 | temperature = self.temperature.unsqueeze(0).unsqueeze(0).expand(*logits.shape) 54 | return logits / temperature 55 | 56 | def forward(self, x, gt_actions= None): 57 | common_input = self.common_base(x) 58 | 59 | #common_input_shape -> B x phases x dim//2 60 | if self.use_stochastic_embd: 61 | #B x phases x channels*2 62 | phase_mean_log_var = self.mean_var_linear(common_input) 63 | 64 | #first C/2 channels are mean, last C/2 channels are log_var 65 | phase_mean_emb = phase_mean_log_var[:, :, :self.input_dim // 2] 66 | phase_log_var = phase_mean_log_var[:,:,self.input_dim // 2: ] 67 | 68 | phase_var_emb = torch.exp(phase_log_var) 69 | 70 | #shapes = B 71 | common_input, global_sqrt_var = self.phase_composer.process_phases( 72 | phase_mean_emb = phase_mean_emb, 73 | phase_var_emb = phase_var_emb, 74 | num_samples = 1 if self.training else self.num_random_samples, 75 | gt_actions = gt_actions.detach() 76 | ) 77 | else: 78 | # #if phase var emb is none, then it become deterministic 79 | common_input, _ = self.phase_composer.process_phases( 80 | common_input, None, 1, gt_actions.detach() 81 | ) 82 | 83 | #num_samples x batch_size x 1 84 | all_sample_output = self.scoring_head(common_input) 85 | 86 | #num_samples x batch_size x 1 --> batch_size x num_samples x 1 87 | all_sample_output = all_sample_output.permute(1, 0, 2) 88 | 89 | if self.use_stochastic_embd: 90 | return ( 91 | phase_mean_emb, 92 | phase_var_emb, 93 | global_sqrt_var, 94 | all_sample_output 95 | ) 96 | else: 97 | return all_sample_output 98 | 99 | 100 | 101 | @register_meta_arch("aqa-model") 102 | class AQA_Model(nn.Module): 103 | """ 104 | Transformer based model for single stage action localization 105 | """ 106 | 107 | def __init__( 108 | self, 109 | feat_extractor_type, 110 | finetune_feat_extractor, 111 | feat_extractor_weights_path, 112 | backbone_type, 113 | input_feat_dim, 114 | embed_dim, 115 | conv_dropout, 116 | neck_type, 117 | num_layers, 118 | conv_kernel_size, 119 | encoder_params, 120 | decoder_params, 121 | num_phases, 122 | score_bins, 123 | train_cfg, # other cfg for training 124 | test_cfg, # other cfg for testing 125 | max_seq_len, # max sequence length in features (used for training) 126 | frames_per_clip, # original sequence length in frames 127 | use_stochastic_embd = False, 128 | num_random_samples = None 129 | ): 130 | super().__init__() 131 | 132 | self.feat_extractor_type = feat_extractor_type 133 | self.finetune_feat_extractor = finetune_feat_extractor 134 | 135 | self.embed_dim = embed_dim 136 | self.use_phases = False 137 | 138 | self.num_phases = num_phases 139 | 140 | self.input_feat_dim = input_feat_dim 141 | self.max_seq_len = max_seq_len 142 | self.frames_per_clip = frames_per_clip 143 | 144 | self.score_bins = score_bins 145 | 146 | self.train_cfg = train_cfg 147 | 148 | 149 | self.use_stochastic_embd = use_stochastic_embd 150 | 151 | self.num_random_samples = num_random_samples 152 | 153 | if encoder_params["use_abs_pe"]: 154 | pos_embd = get_sinusoid_encoding(self.max_seq_len, embed_dim) / ( 155 | embed_dim**0.5 156 | ) 157 | 
self.register_buffer("pos_embd", pos_embd, persistent=False) 158 | self.pos_embd = self.pos_embd 159 | 160 | assert feat_extractor_type in ["i3d", None] 161 | 162 | if feat_extractor_type == "i3d": 163 | self.feat_extractor = make_backbone( 164 | "i3d", 165 | **{ 166 | "I3D_ckpt_path": feat_extractor_weights_path, 167 | "finetune": self.finetune_feat_extractor, 168 | } 169 | ) 170 | else: 171 | self.feat_extractor = None 172 | 173 | # backbone network: conv + transformer 174 | assert backbone_type in ["conv", "convEncoder"] 175 | 176 | if backbone_type == "convEncoder": 177 | self.backbone = make_backbone( 178 | "convEncoder", 179 | **{ 180 | "n_in": input_feat_dim, # input feature dimension 181 | "n_embd": embed_dim, # embedding dimension (after convolution) 182 | "conv_dropout": conv_dropout, # dropout rate for conv layers in the initial projection network 183 | "n_conv_layers": num_layers["n_conv_layers"], # number of layers 184 | "n_embd_ks": conv_kernel_size, # conv kernel size of the embedding network 185 | "conv_ln": False, # whether to use layer norm 186 | "n_encoder_layers": num_layers[ 187 | "n_encoder_layers" 188 | ], # number of encoder layers 189 | "n_enc_head": encoder_params[ 190 | "n_encoder_heads" 191 | ], # number of heads in the encoder 192 | "attn_pdrop": encoder_params["attn_pdrop"], 193 | "proj_pdrop": encoder_params["proj_pdrop"], 194 | "path_pdrop": encoder_params["path_pdrop"], 195 | "pos_embd": self.pos_embd if encoder_params["use_abs_pe"] else None, 196 | } 197 | ) 198 | else: 199 | self.backbone = None 200 | 201 | if neck_type == "decoder-neck": 202 | self.neck = make_neck( 203 | "decoder-neck", 204 | **{ 205 | "d_model": embed_dim, 206 | "n_heads": decoder_params["n_decoder_heads"], 207 | "stride": decoder_params["stride"], 208 | "num_decoder_layers": num_layers["n_decoder_layers"], 209 | "attn_pdrop": decoder_params["attn_pdrop"], 210 | "proj_pdrop": decoder_params["proj_pdrop"], 211 | "path_pdrop": decoder_params["path_pdrop"], 212 | "xattn_mode": decoder_params["xattn_mode"], 213 | "with_ln": decoder_params["with_ln"], 214 | "num_phases": num_phases, 215 | "query_config": decoder_params["query_config"], 216 | } 217 | ) 218 | else: 219 | self.neck = None 220 | 221 | 222 | self.quality_score_head = Quality_Score_Head( 223 | input_dim=embed_dim, 224 | score_bins=self.score_bins, 225 | num_random_samples = self.num_random_samples if self.training else 1, 226 | use_stochastic_embd = self.use_stochastic_embd, 227 | dataset=train_cfg["dataset_name"] 228 | ) 229 | 230 | @property 231 | def device(self): 232 | # a hacky way to get the device type 233 | # will throw an error if parameters are on different devices 234 | return list(set(p.device for p in self.parameters()))[0] 235 | 236 | def forward( 237 | self, 238 | batched_inputs, 239 | batched_masks, 240 | batched_gt_actions, 241 | video_ids = None, 242 | curr_epoch = None, 243 | criterion_args = None 244 | ): 245 | if self.feat_extractor is not None: 246 | batched_inputs = self.feat_extractor(batched_inputs, video_ids) 247 | 248 | # x = B x C x T, mask = B x 1 x T 249 | x, masks = self.backbone(batched_inputs, batched_masks, video_ids=video_ids, curr_epoch=curr_epoch) 250 | 251 | 252 | if self.neck is not None: 253 | # output = B x #phases x C 254 | x, cross_attns = self.neck(x, masks, batched_gt_actions, video_ids=video_ids, curr_epoch=curr_epoch) 255 | else: 256 | # B x C x T -> B x T x C 257 | x = x.permute(0, 2, 1) 258 | 259 | if self.use_stochastic_embd: 260 | ( 261 | phase_mean_emb, 262 | phase_var_emb, 263 
| global_sqrt_var_emb, 264 | all_sample_outputs 265 | ) = self.quality_score_head(x, batched_gt_actions) 266 | else: 267 | #batch_size x num_samples x bins 268 | all_sample_outputs = self.quality_score_head(x, batched_gt_actions) 269 | 270 | 271 | out_dict = { 272 | "all_sample_outputs": all_sample_outputs, 273 | "gt_actions": batched_gt_actions, 274 | "cross_attn": cross_attns if self.neck is not None else None, 275 | } 276 | 277 | if self.use_stochastic_embd: 278 | out_dict["phase_mean_emb"] = phase_mean_emb 279 | out_dict["phase_var_emb"] = phase_var_emb 280 | out_dict["global_sqrt_var_emb"] = global_sqrt_var_emb 281 | 282 | 283 | if criterion_args is not None: 284 | losses = criterion(out_dict, 285 | criterion_args['target'], 286 | criterion_args["difficulties"], 287 | criterion_args["gt_actions"], 288 | loss_weights = criterion_args['loss_weights'], 289 | with_dd = criterion_args['with_dd'], 290 | three_judge_score_scaling = criterion_args['three_judge_score_scaling'], 291 | ) 292 | else: 293 | losses = None 294 | 295 | #remove cross attns from out_dict 296 | if self.neck is not None: 297 | out_dict.pop("cross_attn") 298 | return out_dict, losses -------------------------------------------------------------------------------- /libs/utils/lr_schedulers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from collections import Counter 4 | from bisect import bisect_right 5 | 6 | import torch 7 | from torch.optim.lr_scheduler import _LRScheduler 8 | 9 | 10 | class LinearWarmupCosineAnnealingLR(_LRScheduler): 11 | """ 12 | Sets the learning rate of each parameter group to follow a linear warmup schedule 13 | between warmup_start_lr and base_lr followed by a cosine annealing schedule between 14 | base_lr and eta_min. 15 | 16 | .. warning:: 17 | It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR` 18 | after each iteration as calling it after each epoch will keep the starting lr at 19 | warmup_start_lr for the first epoch which is 0 in most cases. 20 | 21 | .. warning:: 22 | passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING. 23 | It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of 24 | :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing 25 | epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling 26 | train and validation methods. 27 | 28 | Example: 29 | >>> layer = nn.Linear(10, 1) 30 | >>> optimizer = Adam(layer.parameters(), lr=0.02) 31 | >>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40) 32 | >>> # 33 | >>> # the default case 34 | >>> for epoch in range(40): 35 | ... # train(...) 36 | ... # validate(...) 37 | ... scheduler.step() 38 | >>> # 39 | >>> # passing epoch param case 40 | >>> for epoch in range(40): 41 | ... scheduler.step(epoch) 42 | ... # train(...) 43 | ... # validate(...) 44 | """ 45 | 46 | def __init__( 47 | self, 48 | optimizer, 49 | warmup_epochs, 50 | max_epochs, 51 | warmup_start_lr = 0.0, 52 | eta_min = 1e-8, 53 | last_epoch = -1, 54 | ): 55 | """ 56 | Args: 57 | optimizer (Optimizer): Wrapped optimizer. 58 | warmup_epochs (int): Maximum number of iterations for linear warmup 59 | max_epochs (int): Maximum number of iterations 60 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. 61 | eta_min (float): Minimum learning rate. Default: 0. 
62 | last_epoch (int): The index of last epoch. Default: -1. 63 | """ 64 | self.warmup_epochs = warmup_epochs 65 | self.max_epochs = max_epochs 66 | self.warmup_start_lr = warmup_start_lr 67 | self.eta_min = eta_min 68 | 69 | super(LinearWarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch) 70 | 71 | def get_lr(self): 72 | """ 73 | Compute learning rate using chainable form of the scheduler 74 | """ 75 | if not self._get_lr_called_within_step: 76 | warnings.warn( 77 | "To get the last learning rate computed by the scheduler, " 78 | "please use `get_last_lr()`.", 79 | UserWarning, 80 | ) 81 | 82 | if self.last_epoch == 0: 83 | return [self.warmup_start_lr] * len(self.base_lrs) 84 | elif self.last_epoch < self.warmup_epochs: 85 | return [ 86 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 87 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 88 | ] 89 | elif self.last_epoch == self.warmup_epochs: 90 | return self.base_lrs 91 | elif (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0: 92 | return [ 93 | group["lr"] + (base_lr - self.eta_min) * 94 | (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2 95 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 96 | ] 97 | 98 | return [ 99 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) / 100 | ( 101 | 1 + 102 | math.cos(math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)) 103 | ) * (group["lr"] - self.eta_min) + self.eta_min for group in self.optimizer.param_groups 104 | ] 105 | 106 | def _get_closed_form_lr(self): 107 | """ 108 | Called when epoch is passed as a param to the `step` function of the scheduler. 109 | """ 110 | if self.last_epoch < self.warmup_epochs: 111 | return [ 112 | self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 113 | for base_lr in self.base_lrs 114 | ] 115 | 116 | return [ 117 | self.eta_min + 0.5 * (base_lr - self.eta_min) * 118 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) 119 | for base_lr in self.base_lrs 120 | ] 121 | 122 | 123 | class LinearWarmupMultiStepLR(_LRScheduler): 124 | """ 125 | Sets the learning rate of each parameter group to follow a linear warmup schedule 126 | between warmup_start_lr and base_lr followed byt co 127 | 128 | .. warning:: 129 | It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR` 130 | after each iteration as calling it after each epoch will keep the starting lr at 131 | warmup_start_lr for the first epoch which is 0 in most cases. 132 | 133 | .. warning:: 134 | passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING. 135 | It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of 136 | :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing 137 | epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling 138 | train and validation methods. 139 | """ 140 | 141 | def __init__( 142 | self, 143 | optimizer, 144 | warmup_epochs, 145 | milestones, 146 | warmup_start_lr = 0.0, 147 | gamma = 0.1, 148 | last_epoch = -1, 149 | ): 150 | """ 151 | Args: 152 | optimizer (Optimizer): Wrapped optimizer. 
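For intuition, the linear-warmup-plus-cosine closed form implemented in _get_closed_form_lr above can be reproduced as a small standalone function; the base_lr and epoch counts below are arbitrary illustrative values.

import math

def warmup_cosine_lr(epoch, base_lr, warmup_epochs, max_epochs,
                     warmup_start_lr=0.0, eta_min=1e-8):
    # linear warmup from warmup_start_lr up to base_lr
    if epoch < warmup_epochs:
        return warmup_start_lr + epoch * (base_lr - warmup_start_lr) / (warmup_epochs - 1)
    # cosine annealing from base_lr down to eta_min
    progress = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * progress))

for e in (0, 5, 10, 25, 40):
    print(e, round(warmup_cosine_lr(e, base_lr=1e-3, warmup_epochs=10, max_epochs=40), 8))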
153 | warmup_epochs (int): Maximum number of iterations for linear warmup 154 | max_epochs (int): Maximum number of iterations 155 | milestones (list): List of epoch indices. Must be increasing. 156 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. 157 | gamma (float): Multiplicative factor of learning rate decay. 158 | Default: 0.1. 159 | last_epoch (int): The index of last epoch. Default: -1. 160 | """ 161 | self.warmup_epochs = warmup_epochs 162 | self.warmup_start_lr = warmup_start_lr 163 | self.milestones = Counter(milestones) 164 | self.gamma = gamma 165 | 166 | super(LinearWarmupMultiStepLR, self).__init__(optimizer, last_epoch) 167 | 168 | def get_lr(self): 169 | """ 170 | Compute learning rate using chainable form of the scheduler 171 | """ 172 | if not self._get_lr_called_within_step: 173 | warnings.warn("To get the last learning rate computed by the scheduler, " 174 | "please use `get_last_lr()`.", UserWarning) 175 | 176 | if self.last_epoch == 0: 177 | # starting warm up 178 | return [self.warmup_start_lr] * len(self.base_lrs) 179 | elif self.last_epoch < self.warmup_epochs: 180 | # linear warm up (0 ~ self.warmup_epochs -1) 181 | return [ 182 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 183 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 184 | ] 185 | elif self.last_epoch == self.warmup_epochs: 186 | # end of warm up (reset to base lrs) 187 | return self.base_lrs 188 | elif (self.last_epoch - self.warmup_epochs) not in self.milestones: 189 | # in between the steps 190 | return [group['lr'] for group in self.optimizer.param_groups] 191 | 192 | return [ 193 | group['lr'] * self.gamma ** self.milestones[self.last_epoch - self.warmup_epochs] 194 | for group in self.optimizer.param_groups 195 | ] 196 | 197 | def _get_closed_form_lr(self): 198 | """ 199 | Called when epoch is passed as a param to the `step` function of the scheduler. 200 | """ 201 | if self.last_epoch < self.warmup_epochs: 202 | return [ 203 | self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 204 | for base_lr in self.base_lrs 205 | ] 206 | 207 | milestones = list(sorted(self.milestones.elements())) 208 | return [base_lr * self.gamma ** bisect_right(milestones, self.last_epoch - self.warmup_epochs) 209 | for base_lr in self.base_lrs] 210 | 211 | 212 | 213 | 214 | class LinearWarmupNoDecayLR(_LRScheduler): 215 | """ 216 | Sets the learning rate of each parameter group to follow a linear warmup schedule 217 | between warmup_start_lr and base_lr followed by a constant lr. 218 | 219 | .. warning:: 220 | It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR` 221 | after each iteration as calling it after each epoch will keep the starting lr at 222 | warmup_start_lr for the first epoch which is 0 in most cases. 223 | 224 | .. warning:: 225 | passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING. 226 | It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of 227 | :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing 228 | epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling 229 | train and validation methods. 
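A usage sketch for the multi-step warmup scheduler defined above, assuming the repository root is on PYTHONPATH; the toy model, milestone values and iteration counts are arbitrary. As the warning above recommends, the scheduler is stepped once per iteration.

import torch
from torch import nn
from torch.optim import AdamW

from libs.utils.lr_schedulers import LinearWarmupMultiStepLR  # assumed import path

layer = nn.Linear(10, 1)
optimizer = AdamW(layer.parameters(), lr=1e-3)

# warm up for 100 iterations, then decay by 10x at iterations 1000 and 2000
scheduler = LinearWarmupMultiStepLR(optimizer, warmup_epochs=100,
                                    milestones=[1000, 2000], gamma=0.1)

for _ in range(3000):
    optimizer.step()    # the actual training step is omitted here
    scheduler.step()    # stepped per iteration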
230 | 231 | Example: 232 | >>> layer = nn.Linear(10, 1) 233 | >>> optimizer = Adam(layer.parameters(), lr=0.02) 234 | >>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40) 235 | >>> # 236 | >>> # the default case 237 | >>> for epoch in range(40): 238 | ... # train(...) 239 | ... # validate(...) 240 | ... scheduler.step() 241 | >>> # 242 | >>> # passing epoch param case 243 | >>> for epoch in range(40): 244 | ... scheduler.step(epoch) 245 | ... # train(...) 246 | ... # validate(...) 247 | """ 248 | 249 | def __init__( 250 | self, 251 | optimizer, 252 | warmup_epochs, 253 | max_epochs, 254 | warmup_start_lr = 0.0, 255 | eta_min = 1e-8, 256 | last_epoch = -1, 257 | ): 258 | """ 259 | Args: 260 | optimizer (Optimizer): Wrapped optimizer. 261 | warmup_epochs (int): Maximum number of iterations for linear warmup 262 | max_epochs (int): Maximum number of iterations 263 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. 264 | eta_min (float): Minimum learning rate. Default: 0. 265 | last_epoch (int): The index of last epoch. Default: -1. 266 | """ 267 | self.warmup_epochs = warmup_epochs 268 | self.max_epochs = max_epochs 269 | self.warmup_start_lr = warmup_start_lr 270 | self.eta_min = eta_min 271 | 272 | super(LinearWarmupNoDecayLR, self).__init__(optimizer, last_epoch) 273 | 274 | def get_lr(self): 275 | """ 276 | Compute learning rate using chainable form of the scheduler 277 | """ 278 | if not self._get_lr_called_within_step: 279 | warnings.warn( 280 | "To get the last learning rate computed by the scheduler, " 281 | "please use `get_last_lr()`.", 282 | UserWarning, 283 | ) 284 | 285 | if self.last_epoch == 0: 286 | return [self.warmup_start_lr] * len(self.base_lrs) 287 | elif self.last_epoch < self.warmup_epochs: 288 | return [ 289 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 290 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 291 | ] 292 | else: 293 | return self.base_lrs 294 | 295 | def _get_closed_form_lr(self): 296 | """ 297 | Called when epoch is passed as a param to the `step` function of the scheduler. 
298 | """ 299 | if self.last_epoch < self.warmup_epochs: 300 | return [ 301 | self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 302 | for base_lr in self.base_lrs 303 | ] 304 | 305 | return self.base_lrs -------------------------------------------------------------------------------- /libs/modeling/i3d.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import numpy as np 4 | import torch 5 | 6 | def get_padding_shape(filter_shape, stride): 7 | def _pad_top_bottom(filter_dim, stride_val): 8 | pad_along = max(filter_dim - stride_val, 0) 9 | pad_top = pad_along // 2 10 | pad_bottom = pad_along - pad_top 11 | return pad_top, pad_bottom 12 | 13 | padding_shape = [] 14 | for filter_dim, stride_val in zip(filter_shape, stride): 15 | pad_top, pad_bottom = _pad_top_bottom(filter_dim, stride_val) 16 | padding_shape.append(pad_top) 17 | padding_shape.append(pad_bottom) 18 | depth_top = padding_shape.pop(0) 19 | depth_bottom = padding_shape.pop(0) 20 | padding_shape.append(depth_top) 21 | padding_shape.append(depth_bottom) 22 | 23 | return tuple(padding_shape) 24 | 25 | 26 | def simplify_padding(padding_shapes): 27 | all_same = True 28 | padding_init = padding_shapes[0] 29 | for pad in padding_shapes[1:]: 30 | if pad != padding_init: 31 | all_same = False 32 | return all_same, padding_init 33 | 34 | 35 | class Unit3Dpy(torch.nn.Module): 36 | def __init__(self, 37 | in_channels, 38 | out_channels, 39 | kernel_size=(1, 1, 1), 40 | stride=(1, 1, 1), 41 | activation='relu', 42 | padding='SAME', 43 | use_bias=False, 44 | use_bn=True): 45 | super(Unit3Dpy, self).__init__() 46 | 47 | self.padding = padding 48 | self.activation = activation 49 | self.use_bn = use_bn 50 | if padding == 'SAME': 51 | padding_shape = get_padding_shape(kernel_size, stride) 52 | simplify_pad, pad_size = simplify_padding(padding_shape) 53 | self.simplify_pad = simplify_pad 54 | elif padding == 'VALID': 55 | padding_shape = 0 56 | else: 57 | raise ValueError( 58 | 'padding should be in [VALID|SAME] but got {}'.format(padding)) 59 | 60 | if padding == 'SAME': 61 | if not simplify_pad: 62 | self.pad = torch.nn.ConstantPad3d(padding_shape, 0) 63 | self.conv3d = torch.nn.Conv3d( 64 | in_channels, 65 | out_channels, 66 | kernel_size, 67 | stride=stride, 68 | bias=use_bias) 69 | else: 70 | self.conv3d = torch.nn.Conv3d( 71 | in_channels, 72 | out_channels, 73 | kernel_size, 74 | stride=stride, 75 | padding=pad_size, 76 | bias=use_bias) 77 | elif padding == 'VALID': 78 | self.conv3d = torch.nn.Conv3d( 79 | in_channels, 80 | out_channels, 81 | kernel_size, 82 | padding=padding_shape, 83 | stride=stride, 84 | bias=use_bias) 85 | else: 86 | raise ValueError( 87 | 'padding should be in [VALID|SAME] but got {}'.format(padding)) 88 | 89 | if self.use_bn: 90 | self.batch3d = torch.nn.BatchNorm3d(out_channels) 91 | 92 | if activation == 'relu': 93 | self.activation = torch.nn.functional.relu 94 | 95 | def forward(self, inp): 96 | if self.padding == 'SAME' and self.simplify_pad is False: 97 | inp = self.pad(inp) 98 | out = self.conv3d(inp) 99 | if self.use_bn: 100 | out = self.batch3d(out) 101 | if self.activation is not None: 102 | out = torch.nn.functional.relu(out) 103 | return out 104 | 105 | 106 | class MaxPool3dTFPadding(torch.nn.Module): 107 | def __init__(self, kernel_size, stride=None, padding='SAME'): 108 | super(MaxPool3dTFPadding, self).__init__() 109 | if padding == 'SAME': 110 | padding_shape = 
get_padding_shape(kernel_size, stride) 111 | self.padding_shape = padding_shape 112 | self.pad = torch.nn.ConstantPad3d(padding_shape, 0) 113 | self.pool = torch.nn.MaxPool3d(kernel_size, stride, ceil_mode=True) 114 | 115 | def forward(self, inp): 116 | inp = self.pad(inp) 117 | out = self.pool(inp) 118 | return out 119 | 120 | 121 | class Mixed(torch.nn.Module): 122 | def __init__(self, in_channels, out_channels): 123 | super(Mixed, self).__init__() 124 | # Branch 0 125 | self.branch_0 = Unit3Dpy( 126 | in_channels, out_channels[0], kernel_size=(1, 1, 1)) 127 | 128 | # Branch 1 129 | branch_1_conv1 = Unit3Dpy( 130 | in_channels, out_channels[1], kernel_size=(1, 1, 1)) 131 | branch_1_conv2 = Unit3Dpy( 132 | out_channels[1], out_channels[2], kernel_size=(3, 3, 3)) 133 | self.branch_1 = torch.nn.Sequential(branch_1_conv1, branch_1_conv2) 134 | 135 | # Branch 2 136 | branch_2_conv1 = Unit3Dpy( 137 | in_channels, out_channels[3], kernel_size=(1, 1, 1)) 138 | branch_2_conv2 = Unit3Dpy( 139 | out_channels[3], out_channels[4], kernel_size=(3, 3, 3)) 140 | self.branch_2 = torch.nn.Sequential(branch_2_conv1, branch_2_conv2) 141 | 142 | # Branch3 143 | branch_3_pool = MaxPool3dTFPadding( 144 | kernel_size=(3, 3, 3), stride=(1, 1, 1), padding='SAME') 145 | branch_3_conv2 = Unit3Dpy( 146 | in_channels, out_channels[5], kernel_size=(1, 1, 1)) 147 | self.branch_3 = torch.nn.Sequential(branch_3_pool, branch_3_conv2) 148 | 149 | def forward(self, inp): 150 | out_0 = self.branch_0(inp) 151 | out_1 = self.branch_1(inp) 152 | out_2 = self.branch_2(inp) 153 | out_3 = self.branch_3(inp) 154 | out = torch.cat((out_0, out_1, out_2, out_3), 1) 155 | return out 156 | 157 | 158 | class I3D(torch.nn.Module): 159 | def __init__(self, 160 | num_classes, 161 | modality='rgb', 162 | dropout_prob=0, 163 | name='inception'): 164 | super(I3D, self).__init__() 165 | 166 | self.name = name 167 | self.num_classes = num_classes 168 | if modality == 'rgb': 169 | in_channels = 3 170 | elif modality == 'flow': 171 | in_channels = 2 172 | else: 173 | raise ValueError( 174 | '{} not among known modalities [rgb|flow]'.format(modality)) 175 | self.modality = modality 176 | 177 | conv3d_1a_7x7 = Unit3Dpy( 178 | out_channels=64, 179 | in_channels=in_channels, 180 | kernel_size=(7, 7, 7), 181 | stride=(2, 2, 2), 182 | padding='SAME') 183 | # 1st conv-pool 184 | self.conv3d_1a_7x7 = conv3d_1a_7x7 185 | self.maxPool3d_2a_3x3 = MaxPool3dTFPadding( 186 | kernel_size=(1, 3, 3), stride=(1, 2, 2), padding='SAME') 187 | # conv conv 188 | conv3d_2b_1x1 = Unit3Dpy( 189 | out_channels=64, 190 | in_channels=64, 191 | kernel_size=(1, 1, 1), 192 | padding='SAME') 193 | self.conv3d_2b_1x1 = conv3d_2b_1x1 194 | conv3d_2c_3x3 = Unit3Dpy( 195 | out_channels=192, 196 | in_channels=64, 197 | kernel_size=(3, 3, 3), 198 | padding='SAME') 199 | self.conv3d_2c_3x3 = conv3d_2c_3x3 200 | self.maxPool3d_3a_3x3 = MaxPool3dTFPadding( 201 | kernel_size=(1, 3, 3), stride=(1, 2, 2), padding='SAME') 202 | 203 | # Mixed_3b 204 | self.mixed_3b = Mixed(192, [64, 96, 128, 16, 32, 32]) 205 | self.mixed_3c = Mixed(256, [128, 128, 192, 32, 96, 64]) 206 | 207 | self.maxPool3d_4a_3x3 = MaxPool3dTFPadding( 208 | kernel_size=(3, 3, 3), stride=(2, 2, 2), padding='SAME') 209 | 210 | # Mixed 4 211 | self.mixed_4b = Mixed(480, [192, 96, 208, 16, 48, 64]) 212 | self.mixed_4c = Mixed(512, [160, 112, 224, 24, 64, 64]) 213 | self.mixed_4d = Mixed(512, [128, 128, 256, 24, 64, 64]) 214 | self.mixed_4e = Mixed(512, [112, 144, 288, 32, 64, 64]) 215 | self.mixed_4f = Mixed(528, [256, 160, 
320, 32, 128, 128]) 216 | 217 | self.maxPool3d_5a_2x2 = MaxPool3dTFPadding( 218 | kernel_size=(2, 2, 2), stride=(2, 2, 2), padding='SAME') 219 | 220 | # Mixed 5 221 | self.mixed_5b = Mixed(832, [256, 160, 320, 32, 128, 128]) 222 | self.mixed_5c = Mixed(832, [384, 192, 384, 48, 128, 128]) 223 | 224 | # self.avg_pool = torch.nn.AvgPool3d((2, 7, 7), (1, 1, 1)) 225 | # keep temporal dim 226 | # self.avg_pool = torch.nn.AvgPool3d((1, 7, 7), (1, 1, 1)) 227 | # self.spacial_se = torch.nn.Sequential( 228 | # torch.nn.Conv2d(1024, 1, 3, 1, 1), 229 | # torch.nn.Sigmoid() 230 | # ) 231 | # self.avg_pool_se = torch.nn.AvgPool2d(7, 7) 232 | 233 | # self.dropout = torch.nn.Dropout(dropout_prob) 234 | self.conv3d_0c_1x1 = Unit3Dpy( 235 | in_channels=1024, 236 | out_channels=self.num_classes, 237 | kernel_size=(1, 1, 1), 238 | activation=None, 239 | use_bias=True, 240 | use_bn=False) 241 | # self.softmax = torch.nn.Softmax(1) 242 | # self.softmax = torch.nn.Sigmiod(1) 243 | 244 | def get_logits_dim(self): 245 | return 1024 246 | 247 | def forward(self, inp): 248 | # Preprocessing 249 | out = self.conv3d_1a_7x7(inp) 250 | out = self.maxPool3d_2a_3x3(out) 251 | out = self.conv3d_2b_1x1(out) 252 | out = self.conv3d_2c_3x3(out) 253 | out = self.maxPool3d_3a_3x3(out) 254 | out = self.mixed_3b(out) 255 | out = self.mixed_3c(out) 256 | out = self.maxPool3d_4a_3x3(out) 257 | out = self.mixed_4b(out) 258 | out = self.mixed_4c(out) 259 | out = self.mixed_4d(out) 260 | out = self.mixed_4e(out) 261 | out = self.mixed_4f(out) 262 | out = self.maxPool3d_5a_2x2(out) 263 | out = self.mixed_5b(out) 264 | out = self.mixed_5c(out) 265 | 266 | # out = out.mean(2) 267 | # feature = self.avg_pool_se(self.spacial_se(out) * out) 268 | # out = self.avg_pool(out) 269 | return out 270 | # out = self.dropout(feature) 271 | # out = self.conv3d_0c_1x1(out) 272 | # out = out.squeeze(3) 273 | # out = out.squeeze(3) 274 | # out,_ = out.max(2) 275 | # out_logits = out 276 | # out = self.softmax(out_logits) 277 | # # out = self.sigmoid(out_logits) # B * C 278 | # return feature, out, out_logits 279 | 280 | def load_tf_weights(self, sess): 281 | state_dict = {} 282 | if self.modality == 'rgb': 283 | prefix = 'RGB/inception_i3d' 284 | elif self.modality == 'flow': 285 | prefix = 'Flow/inception_i3d' 286 | load_conv3d(state_dict, 'conv3d_1a_7x7', sess, 287 | os.path.join(prefix, 'Conv3d_1a_7x7')) 288 | load_conv3d(state_dict, 'conv3d_2b_1x1', sess, 289 | os.path.join(prefix, 'Conv3d_2b_1x1')) 290 | load_conv3d(state_dict, 'conv3d_2c_3x3', sess, 291 | os.path.join(prefix, 'Conv3d_2c_3x3')) 292 | 293 | load_mixed(state_dict, 'mixed_3b', sess, 294 | os.path.join(prefix, 'Mixed_3b')) 295 | load_mixed(state_dict, 'mixed_3c', sess, 296 | os.path.join(prefix, 'Mixed_3c')) 297 | load_mixed(state_dict, 'mixed_4b', sess, 298 | os.path.join(prefix, 'Mixed_4b')) 299 | load_mixed(state_dict, 'mixed_4c', sess, 300 | os.path.join(prefix, 'Mixed_4c')) 301 | load_mixed(state_dict, 'mixed_4d', sess, 302 | os.path.join(prefix, 'Mixed_4d')) 303 | load_mixed(state_dict, 'mixed_4e', sess, 304 | os.path.join(prefix, 'Mixed_4e')) 305 | # Here goest to 0.1 max error with tf 306 | load_mixed(state_dict, 'mixed_4f', sess, 307 | os.path.join(prefix, 'Mixed_4f')) 308 | 309 | load_mixed( 310 | state_dict, 311 | 'mixed_5b', 312 | sess, 313 | os.path.join(prefix, 'Mixed_5b'), 314 | fix_typo=True) 315 | load_mixed(state_dict, 'mixed_5c', sess, 316 | os.path.join(prefix, 'Mixed_5c')) 317 | load_conv3d( 318 | state_dict, 319 | 'conv3d_0c_1x1', 320 | sess, 321 | 
os.path.join(prefix, 'Logits', 'Conv3d_0c_1x1'), 322 | bias=True, 323 | bn=False) 324 | self.load_state_dict(state_dict) 325 | 326 | 327 | def get_conv_params(sess, name, bias=False): 328 | # Get conv weights 329 | conv_weights_tensor = sess.graph.get_tensor_by_name( 330 | os.path.join(name, 'w:0')) 331 | if bias: 332 | conv_bias_tensor = sess.graph.get_tensor_by_name( 333 | os.path.join(name, 'b:0')) 334 | conv_bias = sess.run(conv_bias_tensor) 335 | conv_weights = sess.run(conv_weights_tensor) 336 | conv_shape = conv_weights.shape 337 | 338 | kernel_shape = conv_shape[0:3] 339 | in_channels = conv_shape[3] 340 | out_channels = conv_shape[4] 341 | 342 | conv_op = sess.graph.get_operation_by_name( 343 | os.path.join(name, 'convolution')) 344 | padding_name = conv_op.get_attr('padding') 345 | padding = _get_padding(padding_name, kernel_shape) 346 | all_strides = conv_op.get_attr('strides') 347 | strides = all_strides[1:4] 348 | conv_params = [ 349 | conv_weights, kernel_shape, in_channels, out_channels, strides, padding 350 | ] 351 | if bias: 352 | conv_params.append(conv_bias) 353 | return conv_params 354 | 355 | 356 | def get_bn_params(sess, name): 357 | moving_mean_tensor = sess.graph.get_tensor_by_name( 358 | os.path.join(name, 'moving_mean:0')) 359 | moving_var_tensor = sess.graph.get_tensor_by_name( 360 | os.path.join(name, 'moving_variance:0')) 361 | beta_tensor = sess.graph.get_tensor_by_name(os.path.join(name, 'beta:0')) 362 | moving_mean = sess.run(moving_mean_tensor) 363 | moving_var = sess.run(moving_var_tensor) 364 | beta = sess.run(beta_tensor) 365 | return moving_mean, moving_var, beta 366 | 367 | 368 | def _get_padding(padding_name, conv_shape): 369 | padding_name = padding_name.decode("utf-8") 370 | if padding_name == "VALID": 371 | return [0, 0] 372 | elif padding_name == "SAME": 373 | # return [math.ceil(int(conv_shape[0])/2), math.ceil(int(conv_shape[1])/2)] 374 | return [ 375 | math.floor(int(conv_shape[0]) / 2), 376 | math.floor(int(conv_shape[1]) / 2), 377 | math.floor(int(conv_shape[2]) / 2) 378 | ] 379 | else: 380 | raise ValueError('Invalid padding name ' + padding_name) 381 | 382 | 383 | def load_conv3d(state_dict, name_pt, sess, name_tf, bias=False, bn=True): 384 | # Transfer convolution params 385 | conv_name_tf = os.path.join(name_tf, 'conv_3d') 386 | conv_params = get_conv_params(sess, conv_name_tf, bias=bias) 387 | if bias: 388 | conv_weights, kernel_shape, in_channels, out_channels, strides, padding, conv_bias = conv_params 389 | else: 390 | conv_weights, kernel_shape, in_channels, out_channels, strides, padding = conv_params 391 | 392 | conv_weights_rs = np.transpose( 393 | conv_weights, (4, 3, 0, 1, 394 | 2)) # to pt format (out_c, in_c, depth, height, width) 395 | state_dict[name_pt + '.conv3d.weight'] = torch.from_numpy(conv_weights_rs) 396 | if bias: 397 | state_dict[name_pt + '.conv3d.bias'] = torch.from_numpy(conv_bias) 398 | 399 | # Transfer batch norm params 400 | if bn: 401 | conv_tf_name = os.path.join(name_tf, 'batch_norm') 402 | moving_mean, moving_var, beta = get_bn_params(sess, conv_tf_name) 403 | 404 | out_planes = conv_weights_rs.shape[0] 405 | state_dict[name_pt + '.batch3d.weight'] = torch.ones(out_planes) 406 | state_dict[name_pt + 407 | '.batch3d.bias'] = torch.from_numpy(beta.squeeze()) 408 | state_dict[name_pt 409 | + '.batch3d.running_mean'] = torch.from_numpy(moving_mean.squeeze()) 410 | state_dict[name_pt 411 | + '.batch3d.running_var'] = torch.from_numpy(moving_var.squeeze()) 412 | 413 | 414 | def load_mixed(state_dict, name_pt, 
sess, name_tf, fix_typo=False): 415 | # Branch 0 416 | load_conv3d(state_dict, name_pt + '.branch_0', sess, 417 | os.path.join(name_tf, 'Branch_0/Conv3d_0a_1x1')) 418 | 419 | # Branch .1 420 | load_conv3d(state_dict, name_pt + '.branch_1.0', sess, 421 | os.path.join(name_tf, 'Branch_1/Conv3d_0a_1x1')) 422 | load_conv3d(state_dict, name_pt + '.branch_1.1', sess, 423 | os.path.join(name_tf, 'Branch_1/Conv3d_0b_3x3')) 424 | 425 | # Branch 2 426 | load_conv3d(state_dict, name_pt + '.branch_2.0', sess, 427 | os.path.join(name_tf, 'Branch_2/Conv3d_0a_1x1')) 428 | if fix_typo: 429 | load_conv3d(state_dict, name_pt + '.branch_2.1', sess, 430 | os.path.join(name_tf, 'Branch_2/Conv3d_0a_3x3')) 431 | else: 432 | load_conv3d(state_dict, name_pt + '.branch_2.1', sess, 433 | os.path.join(name_tf, 'Branch_2/Conv3d_0b_3x3')) 434 | 435 | # Branch 3 436 | load_conv3d(state_dict, name_pt + '.branch_3.1', sess, 437 | os.path.join(name_tf, 'Branch_3/Conv3d_0b_1x1')) 438 | 439 | 440 | 441 | 442 | 443 | -------------------------------------------------------------------------------- /libs/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import sys 4 | import numpy as np 5 | import random 6 | from copy import deepcopy 7 | from pprint import pprint 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | import torch.backends.cudnn as cudnn 13 | import torch.nn.functional as F 14 | import gc 15 | 16 | from .lr_schedulers import LinearWarmupMultiStepLR, LinearWarmupCosineAnnealingLR, LinearWarmupNoDecayLR 17 | 18 | from ..modeling import MaskedConv1D, Scale, AffineDropPath, LayerNorm 19 | from .metrics import evaluate, MetricsDict 20 | from .preprocessing import preprocessing 21 | from .postprocessing import update_metric_dict_with_model_output 22 | from collections import defaultdict 23 | 24 | 25 | from torch.profiler import profile, record_function, ProfilerActivity 26 | 27 | class Logger(object): 28 | def __init__(self, log_file): 29 | self.terminal = sys.stdout 30 | self.log_file = log_file 31 | 32 | def write(self, message): 33 | #import ipdb; ipdb.set_trace() 34 | message = str(message) 35 | #pprint(message, stream = sys.__stdout__) 36 | self.terminal.write(message) 37 | #print(message + "\n", file=sys.__stdout__, end = "") 38 | with open(self.log_file, 'a') as f: 39 | f.write(message) 40 | def flush(self): 41 | self.terminal.flush() 42 | 43 | def set_bn_eval(m): 44 | classname = m.__class__.__name__ 45 | if classname.find('BatchNorm') != -1: 46 | m.eval() 47 | 48 | def freeze_top_3_blocks(m): 49 | classname = m.__class__.__name__ 50 | if ( 51 | classname.find('conv3d_1a_7x7') != -1 or 52 | classname.find('conv3d_2b_1x1') != -1 or 53 | classname.find('conv3d_2c_3x3') != -1 or 54 | classname.find('mixed_3b') != -1 or 55 | classname.find('mixed_3c') != -1 or 56 | classname.find('mixed_4b') != -1 57 | ): 58 | m.requires_grad = False 59 | m.eval() 60 | 61 | 62 | ################################################################################ 63 | def fix_random_seed(seed, include_cuda=True): 64 | rng_generator = torch.manual_seed(seed) 65 | np.random.seed(seed) 66 | random.seed(seed) 67 | os.environ["PYTHONHASHSEED"] = str(seed) 68 | if include_cuda: 69 | # training: disable cudnn benchmark to ensure the reproducibility 70 | cudnn.enabled = True 71 | cudnn.benchmark = False 72 | cudnn.deterministic = True 73 | torch.cuda.manual_seed(seed) 74 | torch.cuda.manual_seed_all(seed) 75 | # this is needed 
for CUDA >= 10.2 76 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" 77 | # print("Determininstic is off") 78 | torch.use_deterministic_algorithms(True, warn_only=True) 79 | else: 80 | cudnn.enabled = True 81 | cudnn.benchmark = True 82 | return rng_generator 83 | 84 | 85 | def save_checkpoint(state, is_best, file_folder, 86 | file_name='checkpoint.pth.tar'): 87 | """save checkpoint to file""" 88 | if not os.path.exists(file_folder): 89 | os.mkdir(file_folder) 90 | torch.save(state, os.path.join(file_folder, file_name)) 91 | if is_best: 92 | # skip the optimization / scheduler state 93 | state.pop('optimizer', None) 94 | state.pop('scheduler', None) 95 | torch.save(state, os.path.join(file_folder, 'model_best.pth.tar')) 96 | 97 | 98 | def print_model_params(model): 99 | for name, param in model.named_parameters(): 100 | print(name, param.min().item(), param.max().item(), param.mean().item()) 101 | return 102 | 103 | 104 | def make_optimizer(model, optimizer_config): 105 | """create optimizer 106 | return a supported optimizer 107 | """ 108 | # separate out all parameters that with / without weight decay 109 | # see https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L134 110 | 111 | low_lr_decay = set() 112 | low_lr_no_decay = set() 113 | 114 | med_lr_decay = set() 115 | med_lr_no_decay = set() 116 | 117 | high_lr_decay = set() 118 | high_lr_no_decay = set() 119 | 120 | whitelist_weight_modules = (torch.nn.Linear, nn.Parameter, torch.nn.Conv1d, torch.nn.Conv2d, MaskedConv1D, torch.nn.Conv3d, torch.nn.MultiheadAttention, torch.nn.ConvTranspose1d) 121 | blacklist_weight_modules = (LayerNorm, torch.nn.GroupNorm,torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.Embedding, torch.nn.LayerNorm) 122 | 123 | # three sets of lr: low_lr for feat_extractor/i3d, med_lr for neck/backbone, and high_lr for heads. 
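The comment above describes three learning-rate tiers (feature extractor, neck/backbone, heads), each further split into weight-decay and no-decay parameters. A much-simplified sketch of the per-tier grouping is below; the module names, factors and decay value are illustrative, not the exact rules that make_optimizer applies.

import torch
from torch import nn
from torch.optim import AdamW

model = nn.ModuleDict({
    "feat_extractor": nn.Conv3d(3, 8, kernel_size=3),
    "backbone": nn.Conv1d(8, 16, kernel_size=3),
    "quality_score_head": nn.Linear(16, 1),
})
base_lr, weight_decay = 1e-3, 1e-4
groups = [
    {"params": model["feat_extractor"].parameters(), "lr": base_lr * 0.01},       # low lr
    {"params": model["backbone"].parameters(), "lr": base_lr * 0.1},              # medium lr
    {"params": model["quality_score_head"].parameters(), "lr": base_lr},          # high lr
]
optimizer = AdamW(groups, lr=base_lr, weight_decay=weight_decay)
print([g["lr"] for g in optimizer.param_groups])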
124 | # loop over all modules / params 125 | for mn, m in model.named_modules(): 126 | for pn, p in m.named_parameters(): 127 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 128 | #low lr for feat extractor 129 | if "feat_extractor" in fpn or "feat_extractor" in pn: 130 | if pn.endswith('bias'): 131 | # all biases will not be decayed 132 | low_lr_no_decay.add(fpn) 133 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 134 | # weights of whitelist modules will be weight decayed 135 | low_lr_decay.add(fpn) 136 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 137 | # weights of blacklist modules will NOT be weight decayed 138 | low_lr_no_decay.add(fpn) 139 | elif pn.endswith('scale') and isinstance(m, (Scale, AffineDropPath)): 140 | # corner case of our scale layer 141 | low_lr_no_decay.add(fpn) 142 | elif pn.endswith('rel_pe'): 143 | # corner case for relative position encoding 144 | low_lr_no_decay.add(fpn) 145 | # med lr for neck 146 | elif ("neck" in fpn or "neck" in pn) or ("backbone" in fpn or "backbone" in pn): 147 | if pn.endswith('bias'): 148 | med_lr_no_decay.add(fpn) 149 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 150 | med_lr_decay.add(fpn) 151 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 152 | med_lr_no_decay.add(fpn) 153 | elif pn.endswith('scale') and isinstance(m, (Scale, AffineDropPath)): 154 | med_lr_no_decay.add(fpn) 155 | elif pn.endswith('rel_pe'): 156 | med_lr_no_decay.add(fpn) 157 | # high lr for quality_score_head 158 | elif ("quality_score_head" in fpn or "quality_score_head" in pn): 159 | if pn.endswith('bias'): 160 | # all biases will not be decayed 161 | high_lr_no_decay.add(fpn) 162 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 163 | # weights of whitelist modules will be weight decayed 164 | high_lr_decay.add(fpn) 165 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 166 | # weights of blacklist modules will NOT be weight decayed 167 | high_lr_no_decay.add(fpn) 168 | elif pn.endswith('scale') and isinstance(m, (Scale, AffineDropPath)): 169 | # corner case of our scale layer 170 | high_lr_no_decay.add(fpn) 171 | elif pn.endswith('rel_pe'): 172 | # corner case for relative position encoding 173 | high_lr_no_decay.add(fpn) 174 | elif pn.endswith('temperature'): 175 | high_lr_no_decay.add(fpn) 176 | else: 177 | raise ValueError(f"Unrecognized parameter: {fpn}") 178 | 179 | # validate that we considered every parameter 180 | param_dict = {pn: p for pn, p in model.named_parameters()} 181 | 182 | # @TODO add assert statements for med_lr_decay, med_lr_no_decay? 183 | #check that all intersections are empty 184 | assert len(high_lr_decay & high_lr_no_decay) == 0, \ 185 | "parameters %s were in both high_lr_decay and high_lr_no_decay set!" \ 186 | % (str(high_lr_decay & high_lr_no_decay), ) 187 | assert len(low_lr_decay & low_lr_no_decay) == 0, \ 188 | "parameters %s were in both low_lr_decay and low_lr_no_decay set!" \ 189 | % (str(low_lr_decay & low_lr_no_decay), ) 190 | assert len(high_lr_decay & low_lr_decay) == 0, \ 191 | "parameters %s were in both high_lr_decay and low_lr_decay set!" \ 192 | % (str(high_lr_decay & low_lr_decay), ) 193 | assert len(high_lr_no_decay & low_lr_no_decay) == 0, \ 194 | "parameters %s were in both high_lr_no_decay and low_lr_no_decay set!" 
\ 195 | % (str(high_lr_no_decay & low_lr_no_decay), ) 196 | 197 | 198 | union_params = high_lr_decay | high_lr_no_decay | low_lr_decay | low_lr_no_decay | med_lr_decay | med_lr_no_decay 199 | assert len(param_dict.keys() - union_params) == 0, \ 200 | "parameters %s were not separated into either decay/no_decay set!" \ 201 | % (str(param_dict.keys() - union_params), ) 202 | 203 | # create the pytorch optimizer object 204 | if (len(low_lr_decay) + len(low_lr_no_decay) + len(med_lr_decay) + len(med_lr_no_decay)) > 0: 205 | optim_groups = [ 206 | {"params": [param_dict[pn] for pn in sorted(list(high_lr_decay))], 207 | "weight_decay": optimizer_config['weight_decay'], 208 | 'name': 'high_lr_decay'}, 209 | {"params": [param_dict[pn] for pn in sorted(list(high_lr_no_decay))], 210 | "weight_decay": 0.0, 211 | 'name': 'high_lr_no_decay'}, 212 | {"params": [param_dict[pn] for pn in sorted(list(med_lr_decay))], 213 | "weight_decay": optimizer_config['weight_decay'], 214 | "lr": optimizer_config["learning_rate"] * optimizer_config["neck_lr_factor"], 215 | 'name': 'med_lr_decay'}, 216 | {"params": [param_dict[pn] for pn in sorted(list(med_lr_no_decay))], 217 | "weight_decay": 0.0, 218 | "lr": optimizer_config["learning_rate"] * optimizer_config["neck_lr_factor"], 219 | 'name': 'med_lr_no_decay'}, 220 | {"params": [param_dict[pn] for pn in sorted(list(low_lr_decay))], 221 | "weight_decay": optimizer_config['weight_decay'], 222 | "lr": optimizer_config["learning_rate"] * optimizer_config["feature_extractor_factor"], 223 | 'name': 'low_lr_decay'}, 224 | {"params": [param_dict[pn] for pn in sorted(list(low_lr_no_decay))], 225 | "weight_decay": 0.0, 226 | "lr": optimizer_config["learning_rate"] * optimizer_config["feature_extractor_factor"], 227 | 'name': 'low_lr_no_decay'} 228 | ] 229 | else: 230 | optim_groups = [ 231 | {"params": [param_dict[pn] for pn in sorted(list(high_lr_decay))], "weight_decay": optimizer_config['weight_decay']}, 232 | {"params": [param_dict[pn] for pn in sorted(list(high_lr_no_decay))], "weight_decay": 0.0} 233 | ] 234 | 235 | if optimizer_config["type"] == "SGD": 236 | optimizer = optim.SGD( 237 | optim_groups, 238 | lr=optimizer_config["learning_rate"], 239 | momentum=optimizer_config["momentum"] 240 | ) 241 | elif optimizer_config["type"] == "AdamW": 242 | optimizer = optim.AdamW( 243 | optim_groups, 244 | lr=optimizer_config["learning_rate"] 245 | ) 246 | else: 247 | raise TypeError("Unsupported optimizer!") 248 | 249 | return optimizer 250 | 251 | 252 | def make_scheduler( 253 | optimizer, 254 | optimizer_config, 255 | num_iters_per_epoch, 256 | last_epoch=-1 257 | ): 258 | """create scheduler 259 | return a supported scheduler 260 | All scheduler returned by this function should step every iteration 261 | """ 262 | if optimizer_config["warmup"]: 263 | max_epochs = optimizer_config["epochs"] + optimizer_config["warmup_epochs"] 264 | max_steps = max_epochs * num_iters_per_epoch 265 | 266 | # get warmup params 267 | warmup_epochs = optimizer_config["warmup_epochs"] 268 | warmup_steps = warmup_epochs * num_iters_per_epoch 269 | 270 | # with linear warmup: call our custom schedulers 271 | if optimizer_config["schedule_type"] == "cosine": 272 | # Cosine 273 | scheduler = LinearWarmupCosineAnnealingLR( 274 | optimizer, 275 | warmup_steps, 276 | max_steps, 277 | last_epoch=last_epoch 278 | ) 279 | 280 | elif optimizer_config["schedule_type"] == "multistep": 281 | # Multi step 282 | steps = [num_iters_per_epoch * step for step in optimizer_config["schedule_steps"]] 283 | scheduler = 
LinearWarmupMultiStepLR( 284 | optimizer, 285 | warmup_steps, 286 | steps, 287 | gamma=optimizer_config["schedule_gamma"], 288 | last_epoch=last_epoch 289 | ) 290 | elif optimizer_config["schedule_type"] == "no_decay": 291 | steps = [num_iters_per_epoch * step for step in optimizer_config["schedule_steps"]] 292 | scheduler = LinearWarmupNoDecayLR( 293 | optimizer, 294 | warmup_steps, 295 | steps, 296 | last_epoch=last_epoch 297 | ) 298 | else: 299 | raise TypeError("Unsupported scheduler!") 300 | 301 | else: 302 | max_epochs = optimizer_config["epochs"] 303 | max_steps = max_epochs * num_iters_per_epoch 304 | 305 | # without warmup: call default schedulers 306 | if optimizer_config["schedule_type"] == "cosine": 307 | # step per iteration 308 | scheduler = optim.lr_scheduler.CosineAnnealingLR( 309 | optimizer, 310 | max_steps, 311 | last_epoch=last_epoch 312 | ) 313 | 314 | elif optimizer_config["schedule_type"] == "multistep": 315 | # step every few epochs 316 | steps = [num_iters_per_epoch * step for step in optimizer_config["schedule_steps"]] 317 | scheduler = optim.lr_scheduler.MultiStepLR( 318 | optimizer, 319 | steps, 320 | gamma=optimizer_config["schedule_gamma"], 321 | last_epoch=last_epoch 322 | ) 323 | else: 324 | raise TypeError("Unsupported scheduler!") 325 | 326 | return scheduler 327 | 328 | 329 | class AverageMeter(object): 330 | """Computes and stores the average and current value. 331 | Used to compute dataset stats from mini-batches 332 | """ 333 | def __init__(self): 334 | self.initialized = False 335 | self.val = None 336 | self.avg = None 337 | self.sum = None 338 | self.count = 0.0 339 | 340 | def initialize(self, val, n): 341 | self.val = val 342 | self.avg = val 343 | self.sum = val * n 344 | self.count = n 345 | self.initialized = True 346 | 347 | def update(self, val, n=1): 348 | if not self.initialized: 349 | self.initialize(val, n) 350 | else: 351 | self.add(val, n) 352 | 353 | def add(self, val, n): 354 | self.val = val 355 | self.sum += val * n 356 | self.count += n 357 | self.avg = self.sum / self.count 358 | 359 | 360 | class ModelEma(torch.nn.Module): 361 | def __init__(self, model, decay=0.999, device=None): 362 | super().__init__() 363 | # make a copy of the model for accumulating moving average of weights 364 | self.module = deepcopy(model) 365 | self.module.eval() 366 | self.decay = decay 367 | self.device = device # perform ema on different device from model if set 368 | if self.device is not None: 369 | self.module.to(device=device) 370 | 371 | def _update(self, model, update_fn): 372 | with torch.no_grad(): 373 | for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): 374 | if self.device is not None: 375 | model_v = model_v.to(device=self.device) 376 | ema_v.copy_(update_fn(ema_v, model_v)) 377 | 378 | def update(self, model): 379 | self._update(model, update_fn=lambda e, m: self.decay * e + (1.
- self.decay) * m) 380 | 381 | def set(self, model): 382 | self._update(model, update_fn=lambda e, m: m) 383 | 384 | 385 | def my_corn_label_from_logits(logits): 386 | probs = torch.sigmoid(logits) 387 | cum_prob = torch.cumprod(probs, dim=-1) 388 | predict_levels = cum_prob > 0.5 389 | predict_labels = torch.sum(predict_levels, dim=-1) 390 | return predict_labels 391 | 392 | 393 | 394 | def print_training_progress(curr_epoch, curr_iter, total_iters, batch_time, losses_tracker, rho, L2, RL2): 395 | # Format the training progress information 396 | progress_str = [ 397 | f"Epoch: [{curr_epoch:03d}][{curr_iter:05d}/{total_iters:05d}]", 398 | f"Time: {batch_time.val:.2f}s ({batch_time.avg:.2f}s)\n", 399 | f"Total Loss: {losses_tracker['final_loss'].val:.2f} ({losses_tracker['final_loss'].avg:.2f})" 400 | ] 401 | 402 | loss_info = ", ".join([f"{key}: {value.val:.2f} ({value.avg:.2f})" for key, value in losses_tracker.items() if key != "final_loss"]) 403 | progress_str.append(loss_info) 404 | 405 | progress_str.append(f"Running -> Rho: {rho:.4f}, L2: {L2:.2f}, RL2: {RL2:.4f}") 406 | 407 | print('\t'.join(progress_str)) 408 | 409 | 410 | 411 | 412 | ################################################################################ 413 | def train_one_epoch( 414 | train_loader, 415 | model, 416 | optimizer, 417 | scheduler, 418 | curr_epoch, 419 | cfg = None, 420 | tb_writer = None, 421 | print_freq = 20 422 | ): 423 | """Training the model for one epoch""" 424 | # set up meters 425 | batch_time = AverageMeter() 426 | losses_tracker = defaultdict(AverageMeter) 427 | metric_dict = MetricsDict() 428 | # number of iterations per epoch 429 | num_iters = len(train_loader) 430 | # switch to train mode 431 | model.train() 432 | 433 | #fix bn 434 | model.apply(set_bn_eval) 435 | 436 | if "freeze_early_layers" in cfg["train_cfg"]: 437 | if cfg["train_cfg"]["freeze_early_layers"]: 438 | print("Freezing early layers") 439 | model.apply(freeze_top_3_blocks) 440 | 441 | # main training loop 442 | print("\n[Train]: Epoch {:d} started".format(curr_epoch)) 443 | start = time.time() 444 | 445 | optimizer.zero_grad() 446 | for iter_idx, video_list in enumerate(train_loader, 0): 447 | # zero out optimizer 448 | preprocessed_dict = preprocessing(video_list, 449 | feat_extractor_type = cfg["model"]["feat_extractor_type"], 450 | max_seq_len= cfg["model"]["max_seq_len"], 451 | num_phases= cfg["model"]["num_phases"] 452 | ) 453 | 454 | #loss is calculated in the model 455 | criterion_args = {"target" : preprocessed_dict['target'], 456 | "difficulties" : preprocessed_dict["difficulties"], 457 | "gt_actions" : preprocessed_dict["gt_actions"], 458 | "loss_weights" : cfg["train_cfg"]["loss_weights"], 459 | "with_dd" : cfg["dataset"]["with_dd"], 460 | "three_judge_score_scaling" : cfg["dataset"]["three_judge_score_scaling"] 461 | } 462 | # forward 463 | model_output, losses = model(**dict(batched_inputs = preprocessed_dict["batched_inputs"], 464 | batched_masks = preprocessed_dict["batched_masks"], 465 | batched_gt_actions = preprocessed_dict["gt_actions"], 466 | video_ids = preprocessed_dict["video_ids"], 467 | curr_epoch = curr_epoch, 468 | criterion_args = criterion_args 469 | )) 470 | 471 | 472 | for key, value in losses.items(): 473 | losses[key] = losses[key].mean() 474 | 475 | losses["final_loss"] /= cfg["train_cfg"]["accumulation_steps"] 476 | 477 | losses['final_loss'].backward() 478 | 479 | # gradient cliping (to stabilize training if necessary) 480 | if cfg["train_cfg"]["clip_grad_l2norm"] > 0.0: 481 | 
torch.nn.utils.clip_grad_norm_( 482 | model.parameters(), 483 | cfg["train_cfg"]["clip_grad_l2norm"] 484 | ) 485 | 486 | if ((iter_idx + 1) % cfg["train_cfg"]["accumulation_steps"]) == 0 or (iter_idx == num_iters - 1): 487 | optimizer.step() 488 | optimizer.zero_grad() 489 | scheduler.step() 490 | 491 | 492 | with torch.no_grad(): 493 | update_metric_dict_with_model_output(metric_dict, model_output, preprocessed_dict['gt_scores'],preprocessed_dict['difficulties'], is_val=False, cfg=cfg) 494 | 495 | # track all losses 496 | for key, value in losses.items(): 497 | losses_tracker[key].update(value.item()) 498 | 499 | del losses 500 | 501 | gc.collect() 502 | 503 | # printing (only check the stats when necessary to avoid extra cost) 504 | if ((iter_idx != 0) and (iter_idx % print_freq) == 0) or (iter_idx == num_iters - 1): 505 | # measure elapsed time (sync all kernels) 506 | torch.cuda.synchronize() 507 | batch_time.update((time.time() - start) / print_freq) 508 | start = time.time() 509 | 510 | rho, L2, RL2 = evaluate(metric_dict, 511 | is_train=True, 512 | dataset_name=cfg["dataset_name"], 513 | ) 514 | 515 | # log to tensor board 516 | lr = scheduler.get_last_lr()[0] 517 | global_step = curr_epoch * num_iters + iter_idx 518 | if tb_writer is not None: 519 | # learning rate (after stepping) 520 | tb_writer.add_scalar( 521 | 'train/learning_rate', 522 | lr, 523 | global_step 524 | ) 525 | # all losses 526 | losses_vals_dict = {key : value.val for key, value in losses_tracker.items()} 527 | 528 | tb_writer.add_scalars('train/all_losses', losses_vals_dict, global_step) 529 | tb_writer.add_scalars('train/L2, RL2', {'L2': L2, 'RL2': RL2}, global_step) 530 | tb_writer.add_scalar('train/rho', rho, global_step) 531 | 532 | print_training_progress(curr_epoch, iter_idx, num_iters, batch_time, losses_tracker, rho, L2, RL2) 533 | 534 | rho, L2, RL2 = evaluate(metric_dict, 535 | is_train = True, 536 | dataset_name = cfg["dataset_name"] 537 | ) 538 | 539 | 540 | print('Full epoch metrics -> Rho : {:.4f}, L2 : {:.2f}, RL2 : {:.4f}'.format(rho, L2, RL2)) 541 | 542 | lr = scheduler.get_last_lr()[0] 543 | print("[Train]: Epoch {:d} finished with lr={:.8f}\n".format(curr_epoch, lr)) 544 | 545 | return 546 | 547 | 548 | def valid_one_epoch( 549 | val_loader, 550 | model, 551 | curr_epoch, 552 | cfg, 553 | tb_writer, 554 | print_freq, 555 | save_predictions = False 556 | ): 557 | 558 | """Test the model on the validation set""" 559 | 560 | # switch to evaluate mode 561 | model.eval() 562 | 563 | num_iters = len(val_loader) 564 | metric_dict = MetricsDict() 565 | 566 | # loop over validation set 567 | print(["Validation"]) 568 | for iter_idx, video_list in enumerate(val_loader, 0): 569 | # forward the model (wo. 
grad) 570 | with torch.no_grad(): 571 | preprocessed_dict = preprocessing(video_list, 572 | feat_extractor_type = cfg["model"]["feat_extractor_type"], 573 | max_seq_len= cfg["model"]["max_seq_len"], 574 | num_phases= cfg["model"]["num_phases"] 575 | ) 576 | # forward / backward the model 577 | model_output, _ = model(**dict(batched_inputs = preprocessed_dict["batched_inputs"], 578 | batched_masks = preprocessed_dict["batched_masks"], 579 | batched_gt_actions = preprocessed_dict["gt_actions"], 580 | video_ids = preprocessed_dict["video_ids"], 581 | curr_epoch = curr_epoch 582 | )) 583 | # only used for evaluation 584 | metric_dict.update("video_ids", preprocessed_dict["video_ids"]) 585 | 586 | update_metric_dict_with_model_output(metric_dict, model_output, preprocessed_dict['gt_scores'],preprocessed_dict['difficulties'], is_val=True, cfg=cfg) 587 | 588 | if ((iter_idx != 0) and (iter_idx % print_freq) == 0) or (iter_idx == num_iters - 1): 589 | torch.cuda.synchronize() 590 | 591 | block0 = f'Epoch: [{curr_epoch:03d}][{iter_idx:05d}/{num_iters:05d}]' 592 | print(block0) 593 | 594 | 595 | rho, L2, RL2 = evaluate(metric_dict, 596 | is_train=False, 597 | dataset_name = cfg["dataset_name"], 598 | curr_epoch = curr_epoch 599 | ) 600 | 601 | print('Eval Metrics -> Rho : {:.4f}, L2 : {:.2f}, RL2 : {:.4f}'.format(rho, L2, RL2)) 602 | 603 | # log metrics to tb_writer 604 | if tb_writer is not None: 605 | tb_writer.add_scalar('validation/rho', rho, curr_epoch) 606 | tb_writer.add_scalars('validation/L2, RL2', {'L2': L2, 'RL2': RL2}, curr_epoch) 607 | 608 | if save_predictions: 609 | metric_dict = metric_dict.get_metric_dict() 610 | else: 611 | metric_dict = None 612 | 613 | return rho, RL2, metric_dict -------------------------------------------------------------------------------- /libs/modeling/blocks.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | from .weight_init import trunc_normal_ 8 | 9 | def reparameterize(B_mu, B_sigma, num_samples=1): 10 | expanded_mu = B_mu.expand(num_samples, *B_mu.shape) 11 | expanded_sigma = B_sigma.expand(num_samples, *B_sigma.shape) 12 | norm_v = torch.randn_like(expanded_mu).detach() 13 | # reparameterization trick 14 | samples = expanded_mu + expanded_sigma * norm_v 15 | return samples 16 | 17 | class PhaseDistComposer(nn.Module): 18 | def __init__(self, dataset_name='finediving', dim = 256): 19 | super().__init__() 20 | self.dataset_name = dataset_name 21 | if dataset_name == 'finediving': 22 | self.level_0_to_1_index = [torch.tensor(x).detach().long() for x in [[0,1,2,3,4,5,6],[7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27],[28]]] 23 | elif dataset_name =="mtl_aqa": 24 | self.level_0_to_1_index = [torch.tensor(x).detach().long() for x in [[0,1,2,3,4],[5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22],[23]]] 25 | elif dataset_name == "needle_passing": 26 | self.level_0_to_1_index = [torch.tensor(x).detach().long() for x in [[0,1,2,3,4,5,6,7],[0,1,2,3,4,5,6,7],[0,1,2,3,4,5,6,7]]] 27 | elif dataset_name == "knot_tying": 28 | self.level_0_to_1_index = [torch.tensor(x).detach().long() for x in [[0,1,2,3,4,5],[0,1,2,3,4,5],[0,1,2,3,4,5]]] 29 | elif dataset_name == "suturing": 30 | self.level_0_to_1_index = [torch.tensor(x).detach().long() for x in [[0,1,2,3,4,5,6,7,8,9],[0,1,2,3,4,5,6,7,8,9],[0,1,2,3,4,5,6,7,8,9]]] 31 | 32 | self.level_2_mlp = nn.Sequential( 33 | nn.Linear(dim,dim), 34 | 
nn.ReLU(), 35 | nn.Linear(dim,dim) 36 | ) 37 | 38 | self.level_3_mlp = nn.Sequential( 39 | nn.Linear(dim,dim), 40 | nn.ReLU(), 41 | nn.Linear(dim,dim) 42 | ) 43 | 44 | def get_one_sample(self, phase_mean_emb, phase_var_emb, gt_actions): 45 | #B, Phases, C 46 | if phase_var_emb is None: 47 | level_0_samples = phase_mean_emb 48 | else: 49 | phase_sqrt_var_emb = torch.sqrt(phase_var_emb) 50 | level_0_samples = reparameterize(phase_mean_emb, phase_sqrt_var_emb, num_samples=1).squeeze(0) 51 | 52 | B, num_phases, C = phase_mean_emb.shape 53 | 54 | level_1_inputs_separated = [level_0_samples[:,node_input_idx,:] for node_input_idx in self.level_0_to_1_index] 55 | 56 | level_1_presence_separated = [gt_actions[:,node_input_idx].unsqueeze(-1) for node_input_idx in self.level_0_to_1_index] 57 | 58 | #num nodes, B x phases going into the node x C 59 | level_1_inputs_filtered_by_presence = [level_1_inputs_separated[node_num] * level_1_presence_separated[node_num] for node_num in range(len(level_1_inputs_separated))] 60 | 61 | #num nodes, B x C 62 | level_1_inputs_summed_separate = [torch.sum(level_1_inputs_filtered_by_presence[node_num], dim=1) 63 | / torch.sum(level_1_presence_separated[node_num], dim=1) 64 | for node_num in range(len(level_1_inputs_filtered_by_presence))] 65 | 66 | #num nodes, B x C 67 | level_2_outputs = [self.level_2_mlp(x) for x in level_1_inputs_summed_separate] 68 | 69 | #num_nodes x B x C 70 | level_2_outputs = torch.stack(level_2_outputs, dim=0) 71 | 72 | #level_3_input is the same as output since we have more decoding heads later 73 | level_3_output = torch.sum(level_2_outputs, dim=0) / level_2_outputs.shape[0] 74 | 75 | #check if nan in level_3_output 76 | if torch.isnan(level_3_output).any(): 77 | import ipdb; ipdb.set_trace() 78 | 79 | return level_3_output 80 | 81 | 82 | def process_phases(self, phase_mean_emb, phase_var_emb, num_samples, gt_actions): 83 | all_samples = [] 84 | for i in range(num_samples): 85 | one_sample_per_item = self.get_one_sample(phase_mean_emb, phase_var_emb, gt_actions) 86 | all_samples.append(one_sample_per_item) 87 | 88 | #num_samples, B, C 89 | all_samples = torch.stack(all_samples, dim=0) 90 | 91 | if phase_var_emb is None: 92 | return all_samples, None 93 | 94 | masked_phase_var_emb = phase_var_emb * gt_actions.unsqueeze(-1) 95 | 96 | #global variance 97 | global_masked_var = torch.sum(masked_phase_var_emb, dim=1) / torch.sum(gt_actions, dim=1).unsqueeze(-1).detach() 98 | 99 | #global sigma 100 | global_sigma = torch.sqrt(global_masked_var) 101 | 102 | return all_samples, global_sigma 103 | 104 | 105 | class MaskedConv1D(nn.Module): 106 | """ 107 | Masked 1D convolution. Interface remains the same as Conv1d. 108 | Only support a sub set of 1d convs 109 | """ 110 | def __init__( 111 | self, 112 | in_channels, 113 | out_channels, 114 | kernel_size, 115 | stride=1, 116 | padding=0, 117 | dilation=1, 118 | groups=1, 119 | bias=True, 120 | padding_mode='zeros' 121 | ): 122 | super().__init__() 123 | # element must be aligned 124 | assert (kernel_size % 2 == 1) and (kernel_size // 2 == padding) 125 | # stride 126 | self.stride = stride 127 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, 128 | stride, padding, dilation, groups, bias, padding_mode) 129 | # zero out the bias term if it exists 130 | if bias: 131 | torch.nn.init.constant_(self.conv.bias, 0.) 
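A quick usage sketch for the masked 1D convolution defined here (its forward pass follows below): the temporal length must be divisible by the stride, and the boolean mask of shape (B, 1, T) is downsampled along with the features. Shapes are illustrative, and the import assumes the repository root is on PYTHONPATH.

import torch
from libs.modeling.blocks import MaskedConv1D  # assumed import path

B, C, T = 2, 16, 64
x = torch.randn(B, C, T)
mask = torch.ones(B, 1, T, dtype=torch.bool)
mask[:, :, 48:] = False                       # treat the last 16 steps as padding

conv = MaskedConv1D(C, 32, kernel_size=3, stride=2, padding=1)
out, out_mask = conv(x, mask)
print(out.shape, out_mask.shape)              # torch.Size([2, 32, 32]) torch.Size([2, 1, 32])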
132 | 133 | def forward(self, x, mask): 134 | # x: batch size, feature channel, sequence length, 135 | # mask: batch size, 1, sequence length (bool) 136 | B, C, T = x.size() 137 | 138 | # input length must be divisible by stride 139 | assert T % self.stride == 0 140 | 141 | # conv 142 | out_conv = self.conv(x) 143 | # compute the mask 144 | if self.stride > 1: 145 | # downsample the mask using nearest neighbor 146 | out_mask = F.interpolate( 147 | mask.to(x.dtype), size=out_conv.size(-1), mode='nearest' 148 | ) 149 | else: 150 | # masking out the features 151 | out_mask = mask.to(x.dtype) 152 | 153 | # masking the output, stop grad to mask 154 | out_conv = out_conv * out_mask.detach() 155 | out_mask = out_mask.bool() 156 | return out_conv, out_mask 157 | 158 | 159 | class LayerNorm(nn.Module): 160 | """ 161 | LayerNorm that supports inputs of size B, C, T 162 | """ 163 | def __init__( 164 | self, 165 | num_channels, 166 | eps = 1e-5, 167 | affine = True, 168 | device = None, 169 | dtype = None, 170 | ): 171 | super().__init__() 172 | factory_kwargs = {'device': device, 'dtype': dtype} 173 | self.num_channels = num_channels 174 | self.eps = eps 175 | self.affine = affine 176 | 177 | if self.affine: 178 | self.weight = nn.Parameter( 179 | torch.ones([1, num_channels, 1], **factory_kwargs)) 180 | self.bias = nn.Parameter( 181 | torch.zeros([1, num_channels, 1], **factory_kwargs)) 182 | else: 183 | self.register_parameter('weight', None) 184 | self.register_parameter('bias', None) 185 | 186 | def forward(self, x): 187 | assert x.dim() == 3 188 | assert x.shape[1] == self.num_channels 189 | 190 | # normalization along C channels 191 | mu = torch.mean(x, dim=1, keepdim=True) 192 | res_x = x - mu 193 | sigma = torch.mean(res_x**2, dim=1, keepdim=True) 194 | out = res_x / torch.sqrt(sigma + self.eps) 195 | 196 | # apply weight and bias 197 | if self.affine: 198 | out *= self.weight 199 | out += self.bias 200 | 201 | return out 202 | 203 | class MaskedAvgPool1D(nn.Module): 204 | """ 205 | Masked 1D average pooling 206 | """ 207 | def __init__(self): 208 | super(MaskedAvgPool1D, self).__init__() 209 | 210 | def forward(self, x, mask): 211 | x_sum = torch.sum(x * mask.float(), dim=-1, keepdim=True) 212 | n = torch.sum(mask, dim=-1, keepdim=True) 213 | x = x_sum / n 214 | 215 | return x 216 | 217 | 218 | # helper functions for Transformer blocks 219 | def get_sinusoid_encoding(n_position, d_hid): 220 | ''' Sinusoid position encoding table ''' 221 | 222 | def get_position_angle_vec(position): 223 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 224 | 225 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 226 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 227 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 228 | 229 | # return a tensor of size 1 C T 230 | return torch.FloatTensor(sinusoid_table).unsqueeze(0).transpose(1, 2) 231 | 232 | class MaskedMHA(nn.Module): 233 | """ 234 | Multi Head Attention with mask 235 | NOTE: This implementation supports 236 | - global and local self-attention 237 | - (global) cross-attention 238 | 239 | Modified from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py 240 | """ 241 | def __init__( 242 | self, 243 | embd_dim, # embedding dimension 244 | q_dim=None, # query dimension 245 | kv_dim=None, # key / value dimension 246 | out_dim=None, # output dimension 247 | n_heads=4, # number of attention heads 248 | window_size=0, # local 
attention window size (0 for global attention) 249 | attn_pdrop=0.0, # dropout rate for attention map 250 | proj_pdrop=0.0, # dropout rate for projection 251 | use_rel_pe=False, # whether to apply relative position encoding 252 | ): 253 | super(MaskedMHA, self).__init__() 254 | 255 | assert embd_dim % n_heads == 0 256 | self.embd_dim = embd_dim 257 | 258 | if q_dim is None: 259 | q_dim = embd_dim 260 | if kv_dim is None: 261 | kv_dim = embd_dim 262 | if out_dim is None: 263 | out_dim = q_dim 264 | 265 | self.n_heads = n_heads 266 | self.n_channels = embd_dim // n_heads 267 | self.scale = 1.0 / math.sqrt(math.sqrt(self.n_channels)) 268 | self.out_dim = out_dim 269 | 270 | self.query = nn.Conv1d(q_dim, embd_dim, 1) 271 | self.key = nn.Conv1d(kv_dim, embd_dim, 1) 272 | self.value = nn.Conv1d(kv_dim, embd_dim, 1) 273 | self.proj = nn.Conv1d(embd_dim, out_dim, 1) 274 | 275 | self.attn_drop = nn.Dropout(attn_pdrop) 276 | self.proj_drop = nn.Dropout(proj_pdrop) 277 | 278 | # local attention window size 279 | assert window_size == 0 or window_size % 2 == 1 280 | self.window_size = window_size 281 | self.stride = window_size // 2 282 | 283 | # masks for local attention (left / right paddings) 284 | l_mask = torch.ones(self.stride, self.stride + 1).tril().flip(dims=(0,)) 285 | r_mask = torch.ones(self.stride, self.stride + 1).tril().flip(dims=(1,)) 286 | self.register_buffer('l_mask', l_mask.bool(), persistent=False) 287 | self.register_buffer('r_mask', r_mask.bool(), persistent=False) 288 | 289 | # relative position encoding for local attention 290 | if window_size > 0 and use_rel_pe: 291 | self.rel_pe = nn.Parameter(torch.zeros(n_heads, 1, window_size)) 292 | trunc_normal_(self.rel_pe, std=(2.0 / embd_dim) ** 0.5) 293 | else: 294 | self.rel_pe = None 295 | 296 | def _chunk(self, x, size): 297 | """ 298 | Convert feature sequence into temporally overlapping chunks. 299 | 300 | Args: 301 | x (float tensor, (n, t, d)): feature sequence. 302 | size (int): chunk size. 303 | 304 | Returns: 305 | x (float tensor, (n, k, s, d)): chunked features. 306 | """ 307 | n, t, d = x.size() 308 | assert (t + self.stride - size) % self.stride == 0 309 | n_chunks = (t + self.stride - size) // self.stride 310 | 311 | chunk_size = (n, n_chunks, size, d) 312 | chunk_stride = (x.stride(0), self.stride * x.stride(1), *x.stride()[1:]) 313 | x = x.as_strided(size=chunk_size, stride=chunk_stride) 314 | 315 | return x 316 | 317 | def _query_key_matmul(self, q, k): 318 | """ 319 | Chunk-wise query-key product. 320 | 321 | Args: 322 | q (float tensor, (n, t, d)): query tensor. 323 | k (float tensor, (n, t, d)): key tensor. 324 | 325 | Returns: 326 | attn (float tensor, (n, t, w)): unnormalized attention scores. 
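The _chunk helper above lays the sequence out as temporally overlapping windows via as_strided. The standalone sketch below shows the same overlapping-window layout with Tensor.unfold, purely to visualize the shapes; it does not reproduce the attention math.

import torch

n, t, d = 1, 8, 3
stride = 2                # half of a local window of size 4
size = 2 * stride         # chunk length used for queries / keys
x = torch.arange(n * t * d, dtype=torch.float32).view(n, t, d)

# overlapping windows over time: (n, num_chunks, d, size) -> (n, num_chunks, size, d)
chunks = x.unfold(1, size, stride).permute(0, 1, 3, 2)
print(chunks.shape)       # torch.Size([1, 3, 4, 3]): 3 overlapping chunks of 4 timesteps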
327 | """ 328 | assert q.size() == k.size() 329 | n, t, _ = q.size() 330 | w, s = self.window_size, self.stride 331 | 332 | # chunk query and key tensors: (n, t, d) -> (n, t // s - 1, 2s, d) 333 | q_chunks = self._chunk(q.contiguous(), size=2 * s) 334 | k_chunks = self._chunk(k.contiguous(), size=2 * s) 335 | n_chunks = q_chunks.size(1) 336 | 337 | # chunk-wise attention scores: (n, t // s - 1, 2s, 2s) 338 | chunk_attn = torch.einsum('bcxd,bcyd->bcxy', (q_chunks, k_chunks)) 339 | 340 | # shift diagonals into columns: (n, t // s - 1, 2s, w) 341 | chunk_attn = F.pad(chunk_attn, (0, 0, 0, 1)) 342 | chunk_attn = chunk_attn.view(n, n_chunks, 2 * s, w) 343 | 344 | # fill in the overall attention matrix: (n, t // s, s, w) 345 | attn = chunk_attn.new_zeros(n, t // s, s, w) 346 | attn[:, :-1, :, s:] = chunk_attn[:, :, :s, :s + 1] 347 | attn[:, -1, :, s:] = chunk_attn[:, -1, s:, :s + 1] 348 | attn[:, 1:, :, :s] = chunk_attn[:, :, -(s + 1):-1, s + 1:] 349 | attn[:, 0, 1:s, 1:s] = chunk_attn[:, 0, :s - 1, w - (s - 1):] 350 | attn = attn.view(n, t, w) 351 | 352 | # mask invalid attention scores 353 | attn[:, :s, :s + 1].masked_fill_(self.l_mask, float('-inf')) 354 | attn[:, -s:, -(s + 1):].masked_fill_(self.r_mask, float('-inf')) 355 | 356 | return attn 357 | 358 | def _attn_normalize(self, attn, mask): 359 | """ 360 | Normalize attention scores over valid positions. 361 | 362 | Args: 363 | attn (float tensor, (bs, h, t, w)): unnormalized attention scores. 364 | mask (bool tensor, (bs, t, 1)): mask (1 for valid positions). 365 | 366 | Returns: 367 | attn (float tensor, (bs, h, t, w)): normalized attention map. 368 | """ 369 | bs, h, t, w = attn.size() 370 | 371 | # inverse mask (0 for valid positions, -inf for invalid ones) 372 | inv_mask = torch.logical_not(mask) 373 | inv_mask_float = inv_mask.float().masked_fill(inv_mask, -1e4) 374 | 375 | # additive attention mask: (bs, t, w) 376 | attn_mask = self._query_key_matmul( 377 | torch.ones_like(inv_mask_float), inv_mask_float 378 | ) 379 | attn += attn_mask.view(bs, 1, t, w) 380 | 381 | # normalize 382 | attn = F.softmax(attn, dim=-1) 383 | 384 | # if all key / value positions in a local window are invalid 385 | # (i.e., when the query position is invalid), softmax returns NaN. 386 | # Replace NaNs with 0 387 | attn = attn.masked_fill(inv_mask.unsqueeze(1), 0.0) 388 | 389 | return attn 390 | 391 | def _attn_value_matmul(self, attn, v): 392 | """ 393 | Chunk-wise attention-value product. 394 | 395 | Args: 396 | attn (float tensor, (n, t, w)): attention map. 397 | v (float tensor, (n, t, d)): value tensor. 398 | 399 | Returns: 400 | out (float tensor, (n, t, d)): attention-weighted sum of values. 
401 | """ 402 | n, t, d = v.size() 403 | w, s = self.window_size, self.stride 404 | 405 | # chunk attention map: (n, t, w) -> (n, t // s, s, w) 406 | attn_chunks = attn.view(n, t // s, s, w) 407 | 408 | # shift columns into diagonals: (n, t // s, s, 3s) 409 | attn_chunks = F.pad(attn_chunks, (0, s)) 410 | attn_chunks = attn_chunks.view(n, t // s, -1)[..., :-s] 411 | attn_chunks = attn_chunks.view(n, t // s, s, 3 * s) 412 | 413 | # chunk value tensor: (n, t + 2s, d) -> (n, t // s, 3s, d) 414 | v = F.pad(v, (0, 0, s, s)) 415 | v_chunks = self._chunk(v.contiguous(), size=3 * s) 416 | 417 | # chunk-wise attention-weighted sum: (n, t // s, s, d) 418 | out = torch.einsum('bcwd,bcdh->bcwh', (attn_chunks, v_chunks)) 419 | out = out.view(n, t, d) 420 | 421 | return out 422 | 423 | def forward(self, q, k=None, v=None, kv_mask=None, kv_size=None, 424 | video_ids=None, layer_idx=None, curr_epoch=None, q_mask=None): 425 | """ 426 | Args: 427 | q (float tensor, (bs, c, t1)): query feature sequence. 428 | k (float tensor, (bs, c, t2)): key feature sequence. 429 | v (float tensor, (bs, c, t2)): value feature sequence. 430 | kv_mask (bool tensor, (bs, 1, t2)): key / value mask. 431 | kv_size (int tensor, (bs,)): number of times to repeat each sample. 432 | """ 433 | bs, c = q.size(0), self.embd_dim 434 | h, d, w = self.n_heads, self.n_channels, self.window_size 435 | 436 | if k is None: 437 | k = q 438 | if v is None: 439 | v = k 440 | 441 | # if mask is not given, assume all positions are valid 442 | if kv_mask is None: 443 | kv_mask = torch.ones_like(k[:, :1], dtype=torch.bool) 444 | 445 | q = self.query(q) 446 | k = self.key(k) 447 | v = self.value(v) 448 | 449 | # repeat query to match the size of key / value 450 | if kv_size is not None and k.size(0) != bs: 451 | q = q.repeat_interleave(kv_size, dim=0) 452 | bs = q.size(0) 453 | 454 | if self.window_size > 0: 455 | q = q.view(bs, h, d, -1).flatten(0, 1).transpose(1, 2) 456 | k = k.view(bs, h, d, -1).flatten(0, 1).transpose(1, 2) 457 | v = v.view(bs, h, d, -1).flatten(0, 1).transpose(1, 2) 458 | 459 | # attention scores: (bs * h, t, w) 460 | attn = self._query_key_matmul(q * self.scale, k * self.scale) 461 | attn = attn.view(bs, h, -1, w) 462 | if self.rel_pe is not None: 463 | attn += self.rel_pe 464 | 465 | # normalized attention map: (bs, h, t, w) 466 | attn = self._attn_normalize(attn, kv_mask.transpose(1, 2)) 467 | attn = self.attn_drop(attn) 468 | attn = attn.view(bs * h, -1, w) 469 | 470 | # attention-weighted sum of values: # (bs * h, t, d) 471 | q = self._attn_value_matmul(attn, v) 472 | q = q.view(bs, h, -1, d) 473 | else: 474 | q = q.view(bs, h, d, -1).transpose(2, 3) 475 | k = k.view(bs, h, d, -1) 476 | v = v.view(bs, h, d, -1).transpose(2, 3) 477 | 478 | attn = (q * self.scale) @ (k * self.scale) # (bs, h, t1, t2) 479 | attn = attn.masked_fill( 480 | mask=torch.logical_not(kv_mask[:, :, None, :]), 481 | value=float('-inf'), 482 | ) 483 | attn = F.softmax(attn, dim=-1) 484 | 485 | ret_attn = attn.clone() 486 | 487 | attn = self.attn_drop(attn) 488 | q = attn @ v # (bs, h, t1, d) 489 | 490 | q = q.transpose(2, 3).reshape(bs, c, -1) # (bs, c, t1) 491 | out = self.proj_drop(self.proj(q)) 492 | 493 | return out, ret_attn * q_mask.unsqueeze(-1).expand(attn.shape).requires_grad_(False) 494 | 495 | 496 | class AttNPool1D(nn.Module): 497 | def __init__(self, embd_dim, n_heads=4): 498 | super(AttNPool1D, self).__init__() 499 | 500 | self.pool = MaskedAvgPool1D() 501 | self.attn = MaskedMHA(embd_dim, n_heads=n_heads) 502 | 503 | def forward(self, x, 
mask): 504 | x_mean = self.pool(x, mask) 505 | h = torch.cat((x_mean, x), dim=-1) 506 | mask = torch.cat((mask[..., :1], mask), dim=-1) 507 | 508 | pool = self.attn(h, kv_mask=mask)[..., :1] 509 | x = torch.cat((pool, x), dim=-1) 510 | 511 | return x, mask 512 | 513 | 514 | 515 | class ConvXAttNLayer(nn.Module): 516 | """ 517 | Multi Head Conv Cross Attention with mask 518 | 519 | With current implementation, the downsampled features will be aligned with 520 | every s+1 time steps, where s is the down-sampling stride. This allows us 521 | to easily interpolate the corresponding position encoding. 522 | """ 523 | def __init__( 524 | self, 525 | embd_dim, # embedding dimension 526 | kv_dim, # key / value dimension 527 | out_dim=None, # output dimension 528 | stride=1, # convolution stride 529 | n_heads=4, # number of attention heads 530 | attn_pdrop=0.0, # dropout rate for attention map 531 | proj_pdrop=0.0, # dropout rate for projection 532 | use_offset=False, # whether to add offsets to down-sampled points 533 | ): 534 | super(ConvXAttNLayer, self).__init__() 535 | 536 | self.use_conv = stride > 0 537 | if self.use_conv: 538 | assert stride == 1 or stride % 2 == 0 539 | kernel_size = stride + 1 + use_offset if stride > 1 else 3 540 | padding = (kernel_size - 1) // 2 541 | 542 | self.q_conv = MaskedConv1D( 543 | embd_dim, embd_dim, 544 | kernel_size=kernel_size, 545 | stride=stride, padding=padding, 546 | groups=embd_dim, bias=False, 547 | ) 548 | self.q_norm = LayerNorm(embd_dim) 549 | else: 550 | self.q_conv = self.q_norm = None 551 | 552 | if out_dim is None: 553 | out_dim = embd_dim 554 | 555 | # cross-attention 556 | self.xattn = MaskedMHA( 557 | embd_dim, 558 | kv_dim=kv_dim, out_dim=out_dim, 559 | n_heads=n_heads, 560 | attn_pdrop=attn_pdrop, proj_pdrop=proj_pdrop 561 | ) 562 | 563 | def forward(self, q, q_mask, kv, kv_mask, kv_size=None, 564 | video_ids=None, layer_idx=None, curr_epoch=None): 565 | if self.use_conv: 566 | q, q_mask = self.q_conv(q, q_mask) 567 | q = self.q_norm(q) 568 | 569 | out, cross_attn = self.xattn(q, kv, None, kv_mask, kv_size, 570 | video_ids=video_ids, layer_idx=layer_idx, 571 | curr_epoch=curr_epoch, q_mask=q_mask) 572 | if kv_size is not None and out.size(0) != q_mask.size(0): 573 | q_mask = q_mask.repeat_interleave(kv_size, dim=0) 574 | 575 | return out, q_mask, cross_attn 576 | 577 | 578 | class FFN(nn.Module): 579 | """ 580 | Feed Forward Network (MLP) in Transformer. 581 | """ 582 | def __init__(self, channels, expansion=4, pdrop=0.0): 583 | super(FFN, self).__init__() 584 | 585 | self.fc = nn.Conv1d(channels, channels * expansion, 1) 586 | self.actv = nn.GELU() 587 | self.proj = nn.Conv1d(channels * expansion, channels, 1) 588 | self.dropout = nn.Dropout(pdrop) 589 | 590 | def forward(self, x): 591 | x = self.dropout(self.actv(self.fc(x))) 592 | x = self.dropout(self.proj(x)) 593 | 594 | return x 595 | 596 | 597 | class TransformerDecoder(nn.Module): 598 | """ 599 | Transformer Decoder (w/o self-attention).
600 | (optional depth-wise conv -> xattn -> FFN) 601 | """ 602 | def __init__( 603 | self, 604 | embd_dim, # embedding dimension 605 | kv_dim, # key / value dimension 606 | stride=1, # convolution stride (0 if disable convs) 607 | n_heads=4, # number of attention heads 608 | window_size=0, # MHA window size (0 for global attention) 609 | expansion=4, # expansion factor for FFN 610 | attn_pdrop=0.0, # dropout rate for attention map 611 | proj_pdrop=0.0, # dropout rate for projection 612 | path_pdrop=0.0, # dropout rate for residual paths 613 | xattn_mode='adaln', # cross-attention mode (affine | adaln) 614 | use_offset=False, # whether to add offsets to down-sampled points 615 | use_rel_pe=False, # whether to apply relative position encoding 616 | ): 617 | super(TransformerDecoder, self).__init__() 618 | 619 | # cross-attention 620 | assert xattn_mode in ('affine', 'adaln') 621 | self.xattn = ConvXAttNLayer( 622 | embd_dim, kv_dim, 623 | out_dim=embd_dim * 2, 624 | stride=stride, n_heads=n_heads, 625 | attn_pdrop=attn_pdrop, proj_pdrop=path_pdrop, 626 | use_offset=use_offset 627 | ) 628 | self.ln_xattn_q = LayerNorm(embd_dim) 629 | self.ln_xattn_kv = LayerNorm(kv_dim) 630 | 631 | if xattn_mode == 'adaln': 632 | self.adaln = LayerNorm(embd_dim, affine=False) 633 | else: 634 | self.adaln = None 635 | 636 | # FFN 637 | self.ffn = FFN(embd_dim, expansion, proj_pdrop) 638 | self.ln_ffn = LayerNorm(embd_dim) 639 | 640 | # drop path 641 | if path_pdrop > 0.0: 642 | self.drop_path_ffn = AffineDropPath(embd_dim, drop_prob=path_pdrop) 643 | else: 644 | self.drop_path_ffn = nn.Identity() 645 | 646 | def forward(self, q, q_mask, kv, kv_mask, kv_size=None, 647 | video_ids=None, curr_epoch=None): 648 | if q_mask is None: 649 | q_mask = torch.ones_like(q[:, :1], dtype=torch.bool) 650 | 651 | q = q * q_mask.float() 652 | 653 | 654 | # cross-attention (optionally with depth-wise conv) 655 | out, q_mask, cross_attn = self.xattn( 656 | self.ln_xattn_q(q), q_mask, self.ln_xattn_kv(kv), kv_mask, kv_size, 657 | video_ids=video_ids, curr_epoch=curr_epoch 658 | ) 659 | if kv_size is not None and q.size(0) != out.size(0): 660 | q = q.repeat_interleave(kv_size, dim=0) 661 | 662 | q = q * q_mask.float() 663 | weight, bias = out.split(q.size(1), dim=1) 664 | 665 | if self.adaln is not None: 666 | q = self.adaln(q) 667 | 668 | q = q * weight + bias 669 | 670 | # FFN 671 | out = self.ffn(self.ln_ffn(q)) * q_mask.float() 672 | q = q + self.drop_path_ffn(out) 673 | 674 | return q, q_mask, cross_attn 675 | 676 | 677 | class Scale(nn.Module): 678 | """ 679 | Multiply the output regression range by a learnable constant value 680 | """ 681 | 682 | def __init__(self, init=1.0): 683 | """ 684 | init_value : initial value for the scalar 685 | """ 686 | super(Scale, self).__init__() 687 | 688 | self.scale = nn.Parameter(torch.as_tensor(init, dtype=torch.float)) 689 | 690 | def forward(self, x): 691 | return x * self.scale 692 | 693 | 694 | def drop_path(x, drop_prob=0.0, training=False): 695 | """ 696 | Stochastic Depth per sample. 697 | """ 698 | if drop_prob == 0.0 or not training: 699 | return x 700 | 701 | keep_prob = 1 - drop_prob 702 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) 703 | mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 704 | mask.floor_() 705 | x = x.div(keep_prob) * mask 706 | 707 | return x 708 | 709 | 710 | class DropPath(nn.Module): 711 | """ 712 | Drop paths per sample (when applied in main path of residual blocks). 
713 | """ 714 | def __init__(self, drop_prob=None): 715 | super(DropPath, self).__init__() 716 | 717 | self.drop_prob = drop_prob 718 | 719 | def forward(self, x): 720 | return drop_path(x, self.drop_prob, self.training) 721 | 722 | 723 | class AffineDropPath(nn.Module): 724 | """ 725 | Drop paths per sample (when applied in main path of residual blocks) 726 | with a per channel scaling factor (and zero init). 727 | 728 | https://arxiv.org/pdf/2103.17239.pdf 729 | """ 730 | def __init__(self, dim, drop_prob=0.0, init_scale=1e-4): 731 | super(AffineDropPath, self).__init__() 732 | 733 | self.scale = nn.Parameter(init_scale * torch.ones((1, dim, 1))) 734 | self.drop_prob = drop_prob 735 | 736 | def forward(self, x): 737 | return drop_path(self.scale * x, self.drop_prob, self.training) 738 | 739 | 740 | class ConvBlock(nn.Module): 741 | """ 742 | A simple conv block similar to the basic block used in ResNet 743 | """ 744 | def __init__( 745 | self, 746 | n_embd, # dimension of the input features 747 | kernel_size=3, # conv kernel size 748 | n_ds_stride=1, # downsampling stride for the current layer 749 | expansion_factor=2, # expansion factor of feat dims 750 | n_out=None, # output dimension, if None, set to input dim 751 | act_layer=nn.ReLU, # nonlinear activation used after conv, default ReLU 752 | ): 753 | super().__init__() 754 | # must use odd sized kernel 755 | assert (kernel_size % 2 == 1) and (kernel_size > 1) 756 | padding = kernel_size // 2 757 | if n_out is None: 758 | n_out = n_embd 759 | 760 | # 1x3 (strided) -> 1x3 (basic block in resnet) 761 | width = n_embd * expansion_factor 762 | self.conv1 = MaskedConv1D( 763 | n_embd, width, kernel_size, n_ds_stride, padding=padding) 764 | self.conv2 = MaskedConv1D( 765 | width, n_out, kernel_size, 1, padding=padding) 766 | 767 | # attach downsampling conv op 768 | if n_ds_stride > 1: 769 | # 1x1 strided conv (same as resnet) 770 | self.downsample = MaskedConv1D(n_embd, n_out, 1, n_ds_stride) 771 | else: 772 | self.downsample = None 773 | 774 | self.act = act_layer() 775 | 776 | def forward(self, x, mask, pos_embd=None): 777 | identity = x 778 | out, out_mask = self.conv1(x, mask) 779 | out = self.act(out) 780 | out, out_mask = self.conv2(out, out_mask) 781 | 782 | # downsampling 783 | if self.downsample is not None: 784 | identity, _ = self.downsample(x, mask) 785 | 786 | # residual connection 787 | out += identity 788 | out = self.act(out) 789 | 790 | return out, out_mask 791 | 792 | 793 | class MaskedMHCA(nn.Module): 794 | """ 795 | Multi Head Conv Attention with mask 796 | 797 | Add a depthwise convolution within a standard MHA 798 | The extra conv op can be used to 799 | (1) encode relative position information (replacing position encoding); 800 | (2) downsample the features if needed; 801 | (3) match the feature channels 802 | 803 | Note: With current implementation, the downsampled feature will be aligned 804 | to every s+1 time step, where s is the downsampling stride. This allows us 805 | to easily interpolate the corresponding positional embeddings.
806 | 807 | Modified from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py 808 | """ 809 | 810 | def __init__( 811 | self, 812 | n_embd, # dimension of the output features 813 | n_head, # number of heads in multi-head self-attention 814 | n_qx_stride=1, # downsampling stride for query and input 815 | n_kv_stride=1, # downsampling stride for key and value 816 | attn_pdrop=0.0, # dropout rate for the attention map 817 | proj_pdrop=0.0, # dropout rate for projection op 818 | ): 819 | super().__init__() 820 | assert n_embd % n_head == 0 821 | self.n_embd = n_embd 822 | self.n_head = n_head 823 | self.n_channels = n_embd // n_head 824 | self.scale = 1.0 / math.sqrt(self.n_channels) 825 | 826 | # conv/pooling operations 827 | assert (n_qx_stride == 1) or (n_qx_stride % 2 == 0) 828 | assert (n_kv_stride == 1) or (n_kv_stride % 2 == 0) 829 | self.n_qx_stride = n_qx_stride 830 | self.n_kv_stride = n_kv_stride 831 | 832 | # query conv (depthwise) 833 | kernel_size = self.n_qx_stride + 1 if self.n_qx_stride > 1 else 3 834 | stride, padding = self.n_qx_stride, kernel_size // 2 835 | self.query_conv = MaskedConv1D( 836 | self.n_embd, self.n_embd, kernel_size, 837 | stride=stride, padding=padding, groups=self.n_embd, bias=False 838 | ) 839 | self.query_norm = LayerNorm(self.n_embd) 840 | 841 | # key, value conv (depthwise) 842 | kernel_size = self.n_kv_stride + 1 if self.n_kv_stride > 1 else 3 843 | stride, padding = self.n_kv_stride, kernel_size // 2 844 | self.key_conv = MaskedConv1D( 845 | self.n_embd, self.n_embd, kernel_size, 846 | stride=stride, padding=padding, groups=self.n_embd, bias=False 847 | ) 848 | self.key_norm = LayerNorm(self.n_embd) 849 | self.value_conv = MaskedConv1D( 850 | self.n_embd, self.n_embd, kernel_size, 851 | stride=stride, padding=padding, groups=self.n_embd, bias=False 852 | ) 853 | self.value_norm = LayerNorm(self.n_embd) 854 | 855 | # key, query, value projections for all heads 856 | # it is OK to ignore masking, as the mask will be applied to the attention 857 | self.key = nn.Conv1d(self.n_embd, self.n_embd, 1) 858 | self.query = nn.Conv1d(self.n_embd, self.n_embd, 1) 859 | self.value = nn.Conv1d(self.n_embd, self.n_embd, 1) 860 | 861 | # regularization 862 | self.attn_drop = nn.Dropout(attn_pdrop) 863 | self.proj_drop = nn.Dropout(proj_pdrop) 864 | 865 | # output projection 866 | self.proj = nn.Conv1d(self.n_embd, self.n_embd, 1) 867 | 868 | def forward(self, x, mask): 869 | # x: batch size, feature channel, sequence length, 870 | # mask: batch size, 1, sequence length (bool) 871 | B, C, T = x.size() 872 | 873 | # query conv -> (B, nh * hs, T') 874 | q, qx_mask = self.query_conv(x, mask) 875 | q = self.query_norm(q) 876 | # key, value conv -> (B, nh * hs, T'') 877 | k, kv_mask = self.key_conv(x, mask) 878 | k = self.key_norm(k) 879 | v, _ = self.value_conv(x, mask) 880 | v = self.value_norm(v) 881 | 882 | # projections 883 | q = self.query(q) 884 | k = self.key(k) 885 | v = self.value(v) 886 | 887 | # move head forward to be the batch dim 888 | # (B, nh * hs, T'/T'') -> (B, nh, T'/T'', hs) 889 | k = k.view(B, self.n_head, self.n_channels, -1).transpose(2, 3) 890 | q = q.view(B, self.n_head, self.n_channels, -1).transpose(2, 3) 891 | v = v.view(B, self.n_head, self.n_channels, -1).transpose(2, 3) 892 | 893 | # self-attention: (B, nh, T', hs) x (B, nh, hs, T'') -> (B, nh, T', T'') 894 | att = (q * self.scale) @ k.transpose(-2, -1) 895 | # prevent q from attending to invalid tokens 896 | att = att.masked_fill(torch.logical_not(kv_mask[:, :, None, :]),
float('-inf')) 897 | # softmax attn 898 | att = F.softmax(att, dim=-1) 899 | att = self.attn_drop(att) 900 | # (B, nh, T', T'') x (B, nh, T'', hs) -> (B, nh, T', hs) 901 | out = att @ (v * kv_mask[:, :, :, None].to(v.dtype)) 902 | # re-assemble all head outputs side by side 903 | out = out.transpose(2, 3).contiguous().view(B, C, -1) 904 | 905 | # output projection + skip connection 906 | out = self.proj_drop(self.proj(out)) * qx_mask.to(out.dtype) 907 | return out, qx_mask 908 | 909 | 910 | 911 | class TransformerBlock(nn.Module): 912 | """ 913 | A simple (pre layer norm) Transformer block 914 | Modified from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py 915 | """ 916 | def __init__( 917 | self, 918 | n_embd, # dimension of the input features 919 | n_head, # number of attention heads 920 | n_ds_strides=(1, 1), # downsampling strides for q & x, k & v 921 | n_out=None, # output dimension, if None, set to input dim 922 | n_hidden=None, # dimension of the hidden layer in MLP 923 | act_layer=nn.GELU, # nonlinear activation used in MLP, default GELU 924 | attn_pdrop=0.0, # dropout rate for the attention map 925 | proj_pdrop=0.0, # dropout rate for the projection / MLP 926 | path_pdrop=0.0, # drop path rate 927 | mha_win_size=-1, # > 1 to use window mha 928 | use_rel_pe=False # whether to add rel position encoding to attention 929 | ): 930 | super().__init__() 931 | assert len(n_ds_strides) == 2 932 | # layer norm for order (B C T) 933 | self.ln1 = LayerNorm(n_embd) 934 | self.ln2 = LayerNorm(n_embd) 935 | 936 | # specify the attention module 937 | if mha_win_size > 1: 938 | self.attn = LocalMaskedMHCA( 939 | n_embd, 940 | n_head, 941 | window_size=mha_win_size, 942 | n_qx_stride=n_ds_strides[0], 943 | n_kv_stride=n_ds_strides[1], 944 | attn_pdrop=attn_pdrop, 945 | proj_pdrop=proj_pdrop, 946 | use_rel_pe=use_rel_pe # only valid for local attention 947 | ) 948 | else: 949 | self.attn = MaskedMHCA( 950 | n_embd, 951 | n_head, 952 | n_qx_stride=n_ds_strides[0], 953 | n_kv_stride=n_ds_strides[1], 954 | attn_pdrop=attn_pdrop, 955 | proj_pdrop=proj_pdrop 956 | ) 957 | 958 | # pool the skip connection to match any temporal downsampling 959 | if n_ds_strides[0] > 1: 960 | kernel_size, stride, padding = \ 961 | n_ds_strides[0] + 1, n_ds_strides[0], (n_ds_strides[0] + 1)//2 962 | self.pool_skip = nn.MaxPool1d( 963 | kernel_size, stride=stride, padding=padding) 964 | else: 965 | self.pool_skip = nn.Identity() 966 | 967 | # two layer mlp 968 | if n_hidden is None: 969 | n_hidden = 4 * n_embd # default 970 | if n_out is None: 971 | n_out = n_embd 972 | # ok to use conv1d here with stride=1 973 | self.mlp = nn.Sequential( 974 | nn.Conv1d(n_embd, n_hidden, 1), 975 | act_layer(), 976 | nn.Dropout(proj_pdrop, inplace=True), 977 | nn.Conv1d(n_hidden, n_out, 1), 978 | nn.Dropout(proj_pdrop, inplace=True), 979 | ) 980 | 981 | # drop path 982 | if path_pdrop > 0.0: 983 | self.drop_path_attn = AffineDropPath(n_embd, drop_prob = path_pdrop) 984 | self.drop_path_mlp = AffineDropPath(n_out, drop_prob = path_pdrop) 985 | else: 986 | self.drop_path_attn = nn.Identity() 987 | self.drop_path_mlp = nn.Identity() 988 | 989 | def forward(self, x, mask, pos_embd=None): 990 | # pre-LN transformer: https://arxiv.org/pdf/2002.04745.pdf 991 | out, out_mask = self.attn(self.ln1(x), mask) 992 | out_mask_float = out_mask.to(out.dtype) 993 | out = self.pool_skip(x) * out_mask_float + self.drop_path_attn(out) 994 | # FFN 995 | out = out + self.drop_path_mlp(self.mlp(self.ln2(out)) * out_mask_float) 996 | # optionally add pos_embd to the output 997 | if pos_embd is not None: 998
| out += pos_embd * out_mask_float 999 | return out, out_mask 1000 | --------------------------------------------------------------------------------
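
Usage note for blocks.py: the sketch below is a minimal, illustrative smoke test of the masked blocks defined above; it is not part of the repository. It assumes the repository root is on PYTHONPATH so that libs.modeling.blocks can be imported directly, and it uses random tensors in place of real clip features and decoder queries; the shapes (batch 2, 128 channels, 16 time steps, 5 query tokens) are arbitrary choices for illustration.

import torch
from libs.modeling.blocks import (
    TransformerBlock, TransformerDecoder, get_sinusoid_encoding
)

B, C, T = 2, 128, 16                          # batch, channels, temporal length (arbitrary)
x = torch.randn(B, C, T)                      # stand-in for per-clip video features
mask = torch.ones(B, 1, T, dtype=torch.bool)
mask[1, :, 12:] = False                       # pretend the second sequence is padded after t = 12

# add sinusoidal position encoding (returned as 1 x C x T) and zero out padded steps
x = (x + get_sinusoid_encoding(T, C)) * mask.float()

# self-attention encoder block: masked conv attention + MLP, both with residual paths
enc_block = TransformerBlock(n_embd=C, n_head=4)
enc, enc_mask = enc_block(x, mask)
print(enc.shape, enc_mask.shape)              # torch.Size([2, 128, 16]) torch.Size([2, 1, 16])

# cross-attention decoder block: query tokens attend to the encoded sequence
Q = 5                                         # number of query tokens (arbitrary)
q = torch.randn(B, C, Q)
q_mask = torch.ones(B, 1, Q, dtype=torch.bool)
dec_block = TransformerDecoder(embd_dim=C, kv_dim=C, xattn_mode='adaln')
dec, dec_mask, cross_attn = dec_block(q, q_mask, enc, enc_mask)
print(dec.shape, cross_attn.shape)            # torch.Size([2, 128, 5]) torch.Size([2, 4, 5, 16])

Every block takes and returns an explicit boolean mask alongside the features, so padded time steps are excluded from attention, pooling, and convolution outputs throughout the stack; the decoder additionally returns the masked cross-attention map along with its outputs.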