├── assets
│   ├── method_fig.png
│   ├── results_fig.png
│   ├── teaser_fig.png
│   ├── MTL_t5_xxl_text_embeddings.npy
│   ├── Cat101_t5_xxl_text_embeddings.npy
│   └── FineDiving_t5_xxl_text_embeddings.npy
├── libs
│   ├── core
│   │   ├── __init__.py
│   │   └── config.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── datasets.py
│   │   ├── data_utils.py
│   │   ├── mtl_aqa.py
│   │   └── finediving.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── postprocessing.py
│   │   ├── preprocessing.py
│   │   ├── metrics.py
│   │   ├── lr_schedulers.py
│   │   └── train_utils.py
│   └── modeling
│       ├── __init__.py
│       ├── models.py
│       ├── weight_init.py
│       ├── necks.py
│       ├── backbones.py
│       ├── losses.py
│       ├── meta_archs.py
│       ├── i3d.py
│       └── blocks.py
├── requirements.txt
├── INSTALL.md
├── LICENSE
├── README.md
├── .gitignore
├── GETTING_STARTED.md
├── configs
│   ├── mtl_aqa
│   │   ├── deter_mtl_aqa_text_data_query.yaml
│   │   └── stoch_mtl_aqa_text_data_query.yaml
│   └── fine
│       ├── deter_fine_diving_text_data_query.yaml
│       └── stoch_fine_diving_text_data_query.yaml
├── tools
│   ├── mtl_aqa
│   │   └── mtl_t5xxl_text_embed_extraction.py
│   └── finediving
│       └── finediving_t5xxl_text_embed_extraction.py
├── eval.py
└── train.py
/assets/method_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrarmajeedi/rica2_aqa/HEAD/assets/method_fig.png
--------------------------------------------------------------------------------
/assets/results_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrarmajeedi/rica2_aqa/HEAD/assets/results_fig.png
--------------------------------------------------------------------------------
/assets/teaser_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrarmajeedi/rica2_aqa/HEAD/assets/teaser_fig.png
--------------------------------------------------------------------------------
/assets/MTL_t5_xxl_text_embeddings.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrarmajeedi/rica2_aqa/HEAD/assets/MTL_t5_xxl_text_embeddings.npy
--------------------------------------------------------------------------------
/assets/Cat101_t5_xxl_text_embeddings.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrarmajeedi/rica2_aqa/HEAD/assets/Cat101_t5_xxl_text_embeddings.npy
--------------------------------------------------------------------------------
/libs/core/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import load_default_config, load_config
2 |
3 | __all__ = ['load_default_config', 'load_config']
4 |
--------------------------------------------------------------------------------
/assets/FineDiving_t5_xxl_text_embeddings.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abrarmajeedi/rica2_aqa/HEAD/assets/FineDiving_t5_xxl_text_embeddings.npy
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ## Requirements
2 | tensorboard>=2.14.0
3 | PyYAML>=6.0.2
4 | einops>=0.8.0
5 | scipy>=1.10.1
6 | numpy>=1.24.1
7 | Pillow>=10.2.0
8 | transformers>=4.45.2
9 |
--------------------------------------------------------------------------------
/libs/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_utils import worker_init_reset_seed
2 | from .datasets import make_dataset, make_data_loader
3 | from . import finediving, mtl_aqa # other datasets go here
4 |
5 | __all__ = ['worker_init_reset_seed',
6 | 'make_dataset', 'make_data_loader']
7 |
--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | ## Conda environment
4 | Python >= 3.8
5 |
6 | conda create --name aqa_env python=3.8
7 | conda activate aqa_env
8 |
9 | ## PyTorch
10 |
11 | PyTorch >= 2.4.1
12 | [(Get PyTorch Here)](https://pytorch.org/get-started/locally/)
13 |
14 | ## Requirements
15 | Install the required libraries with:
16 |
17 | pip3 install -r requirements.txt
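18 |
19 | ## Verify the setup (optional)
20 |
21 | A quick sanity check that the environment imports correctly (the second value assumes a CUDA-capable GPU; CPU-only installs will print False):
22 |
23 | python3 -c "import torch; print(torch.__version__); print(torch.cuda.is_available())"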
--------------------------------------------------------------------------------
/libs/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .train_utils import (Logger, make_optimizer, make_scheduler, save_checkpoint,
2 | AverageMeter, train_one_epoch, valid_one_epoch,
3 | fix_random_seed, ModelEma)
4 |
5 | # from .postprocessing import postprocess_results
6 |
7 | __all__ = ['Logger', 'make_optimizer', 'make_scheduler', 'save_checkpoint',
8 | 'AverageMeter', 'train_one_epoch', 'valid_one_epoch',
9 | 'fix_random_seed', 'ModelEma']
10 |
--------------------------------------------------------------------------------
/libs/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | from .blocks import (MaskedConv1D,MaskedMHCA, MaskedMHA, LayerNorm,
2 | TransformerBlock, ConvBlock, Scale, AffineDropPath)
3 | from .models import make_backbone, make_neck, make_meta_arch
4 | from . import backbones # backbones
5 | from . import necks # necks
6 | from . import meta_archs # full models
7 | from .i3d import I3D
8 | __all__ = ['MaskedConv1D','MaskedMHCA', 'MaskedMHA', 'LayerNorm',
9 | 'TransformerBlock', 'ConvBlock', 'Scale', 'AffineDropPath',
10 | 'make_backbone', 'make_neck', 'make_meta_arch']
--------------------------------------------------------------------------------
/libs/modeling/models.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | # backbone (e.g., conv / transformer)
4 | backbones = {}
5 | def register_backbone(name):
6 | def decorator(cls):
7 | backbones[name] = cls
8 | return cls
9 | return decorator
10 |
11 | # neck (e.g., FPN)
12 | necks = {}
13 | def register_neck(name):
14 | def decorator(cls):
15 | necks[name] = cls
16 | return cls
17 | return decorator
18 |
19 | # meta arch (the actual implementation of each model)
20 | meta_archs = {}
21 | def register_meta_arch(name):
22 | def decorator(cls):
23 | meta_archs[name] = cls
24 | return cls
25 | return decorator
26 |
27 | # builder functions
28 | def make_backbone(name, **kwargs):
29 | backbone = backbones[name](**kwargs)
30 | return backbone
31 |
32 | def make_neck(name, **kwargs):
33 | neck = necks[name](**kwargs)
34 | return neck
35 |
36 | def make_meta_arch(name, **kwargs):
37 | meta_arch = meta_archs[name](**kwargs)
38 | return meta_arch
39 |
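40 | # Example usage (illustrative; see backbones.py / necks.py / meta_archs.py for the
41 | # registered classes, and the configs for the registered names, e.g. 'convEncoder'):
42 | #
43 | #   @register_backbone("convEncoder")
44 | #   class ConvEncoderBackbone(nn.Module):
45 | #       ...
46 | #
47 | #   backbone = make_backbone("convEncoder", **backbone_kwargs)  # backbone_kwargs is a placeholder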
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 University of Wisconsin-Madison
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/libs/datasets/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from .data_utils import trivial_batch_collator, worker_init_reset_seed
4 |
5 | datasets = {}
6 | def register_dataset(name):
7 | def decorator(cls):
8 | datasets[name] = cls
9 | return cls
10 | return decorator
11 |
12 | def make_dataset(name, is_training, split, **kwargs):
13 | """
14 | A simple dataset builder
15 | """
16 | dataset = datasets[name](is_training, split, **kwargs)
17 | return dataset
18 |
19 | def make_data_loader(dataset, is_training, generator, train_batch_size, test_batch_size, num_workers):
20 | """
21 | A simple dataloder builder
22 | """
23 | loader = torch.utils.data.DataLoader(
24 | dataset,
25 | batch_size=train_batch_size if is_training else test_batch_size,
26 | num_workers=num_workers,
27 | collate_fn=trivial_batch_collator,
28 | worker_init_fn=(worker_init_reset_seed if is_training else None),
29 | shuffle=is_training,
30 | drop_last=is_training,
31 | generator=generator,
32 | persistent_workers=True,
33 | pin_memory=True
34 | )
35 | return loader
36 |
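37 | # Example usage (illustrative; mirrors create_train_val_dataloaders() in eval.py):
38 | #   dataset = make_dataset(cfg['dataset_name'], True, cfg['train_split'], **cfg['dataset'])
39 | #   loader = make_data_loader(dataset, True, rng_generator, **cfg['loader'])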
--------------------------------------------------------------------------------
/libs/utils/postprocessing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import time
4 | import json
5 | import pickle
6 | from typing import Dict
7 |
8 | import numpy as np
9 |
10 | import torch
11 | import torch.nn as nn
12 |
13 | @torch.no_grad()
14 | def update_metric_dict_with_model_output(metric_dict, model_output, gt_scores, difficulties, is_val, cfg=None):
15 | # B, samples, 1
16 | all_sample_outputs = model_output["all_sample_outputs"].detach()
17 |
18 | with_dd = cfg["dataset"]["with_dd"]
19 |
20 | batched_pred_quality_scores = get_pred_scores(all_sample_outputs, cfg).detach().cpu().numpy()
21 | batched_gt_scores = gt_scores.detach().cpu().numpy()
22 |
23 | if with_dd:
24 | batched_pred_quality_scores = batched_pred_quality_scores * difficulties.detach().cpu().numpy()
25 |
26 | metric_dict.update("pred_scores", batched_pred_quality_scores)
27 | metric_dict.update("gt_scores", batched_gt_scores)
28 |
29 | if is_val:
30 | if "global_sqrt_var_emb" in model_output.keys():
31 | metric_dict.update("global_sqrt_var_emb", model_output["global_sqrt_var_emb"].detach().cpu().numpy())
32 | return
33 |
34 |
35 | def get_pred_scores(all_sample_outputs, cfg):
36 | #B, samples, 1: all_sample_outputs.size()
37 |
38 | if cfg["dataset_name"] == "jigsaws":
39 | batched_pred_quality_scores = all_sample_outputs.mean(1).squeeze()
40 | if cfg["dataset"]["six_item_score_scaling"]:
41 | batched_pred_quality_scores = batched_pred_quality_scores * 6.0
42 | else:
43 | batched_pred_quality_scores = all_sample_outputs.mean(1).squeeze()
44 | if cfg["dataset"]["three_judge_score_scaling"]:
45 | batched_pred_quality_scores = batched_pred_quality_scores * 3.0
46 |
47 | return batched_pred_quality_scores
48 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RICA2: Rubric-Informed, Calibrated Assessment of Actions (ECCV 2024)
2 |
3 | [](https://abrarmajeedi.github.io/rica2_aqa/)
4 | [](https://arxiv.org/abs/2408.02138)
5 |
6 |
7 |
8 | Check out the new Medium post on RICA2! [](https://namburisrinath.medium.com/rica%C2%B2-rubric-informed-calibrated-assessment-of-actions-92b5715a9163)
9 |
10 |
11 |
12 | ## Abstract
13 |
14 | 
15 |
16 | The ability to quantify how well an action is carried out, also known as action quality assessment (AQA), has attracted recent interest in the vision community. Unfortunately, prior methods often ignore the score rubric used by human experts and fall short of quantifying the uncertainty of the model prediction. To bridge the gap, we present RICA^2 - a deep probabilistic model that integrates the score rubric and accounts for prediction uncertainty in AQA. Central to our method are stochastic embeddings of action steps, defined on a graph structure that encodes the score rubric. The embeddings spread probabilistic density in the latent space and allow our method to represent model uncertainty. The graph encodes the scoring criteria, based on which the quality scores can be decoded. We demonstrate that our method establishes new state of the art on public benchmarks, including FineDiving, MTL-AQA, and JIGSAWS, with superior performance in score prediction and uncertainty calibration.
17 |
18 |
19 | ## Method
20 |
21 | 
22 |
23 |
24 |
25 | ## Results
26 |
27 |
28 | 
29 |
30 |
31 | ## Code
32 | Please find the installation instructions in [INSTALL.md](./INSTALL.md)
33 |
34 | Instructions to run the code can be found in [GETTING_STARTED.md](./GETTING_STARTED.md)
35 |
36 |
37 | ## Citation
38 | If you find our work useful, please consider citing:
39 |
40 | ```bibtex
41 | @inproceedings{majeedi24rica2,
42 |   title={RICA^2: Rubric-Informed, Calibrated Assessment of Actions},
43 |   author={Majeedi, Abrar and Gajjala, Viswanatha Reddy and
44 |           GNVV, Satya Sai Srinath Namburi and Li, Yin},
45 |   booktitle={European Conference on Computer Vision (ECCV)},
46 |   year={2024}
47 | }
48 | ```
49 |
50 |
51 | ## Contact
52 | Email: majeedi+at+wisc+dot+edu
53 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 | ckpt/
3 | feats/
4 | exp/
5 | images/reliablity_diagrams/
6 | libs/utils/dist
7 | libs/utils/nms_1d_cpu.egg-info
8 | **/build
9 | *.out
10 | *pth
11 |
12 | slurm*
13 | # Byte-compiled / optimized / DLL files
14 | __pycache__/
15 | *.py[cod]
16 | *$py.class
17 |
18 | # C extensions
19 | *.so
20 |
21 | # Distribution / packaging
22 | .Python
23 | build/
24 | develop-eggs/
25 | dist/
26 | downloads/
27 | eggs/
28 | .eggs/
29 | lib/
30 | lib64/
31 | parts/
32 | sdist/
33 | var/
34 | wheels/
35 | pip-wheel-metadata/
36 | share/python-wheels/
37 | *.egg-info/
38 | .installed.cfg
39 | *.egg
40 | MANIFEST
41 |
42 | # PyInstaller
43 | # Usually these files are written by a python script from a template
44 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
45 | *.manifest
46 | *.spec
47 |
48 | # Installer logs
49 | pip-log.txt
50 | pip-delete-this-directory.txt
51 |
52 | # Unit test / coverage reports
53 | htmlcov/
54 | .tox/
55 | .nox/
56 | .coverage
57 | .coverage.*
58 | .cache
59 | nosetests.xml
60 | coverage.xml
61 | *.cover
62 | *.py,cover
63 | .hypothesis/
64 | .pytest_cache/
65 |
66 | # Translations
67 | *.mo
68 | *.pot
69 |
70 | # Django stuff:
71 | local_settings.py
72 | db.sqlite3
73 | db.sqlite3-journal
74 |
75 | # Flask stuff:
76 | instance/
77 | .webassets-cache
78 |
79 | # Scrapy stuff:
80 | .scrapy
81 |
82 | # Sphinx documentation
83 | docs/_build/
84 |
85 | # PyBuilder
86 | target/
87 |
88 | # Jupyter Notebook
89 | .ipynb_checkpoints
90 |
91 | # IPython
92 | profile_default/
93 | ipython_config.py
94 |
95 | # pyenv
96 | .python-version
97 |
98 | # pipenv
99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
102 | # install all needed dependencies.
103 | #Pipfile.lock
104 |
105 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
106 | __pypackages__/
107 |
108 | # Celery stuff
109 | celerybeat-schedule
110 | celerybeat.pid
111 |
112 | # SageMath parsed files
113 | *.sage.py
114 |
115 | # Environments
116 | .env
117 | .venv
118 | env/
119 | venv/
120 | ENV/
121 | env.bak/
122 | venv.bak/
123 |
124 | # Spyder project settings
125 | .spyderproject
126 | .spyproject
127 |
128 | # Rope project settings
129 | .ropeproject
130 |
131 | # mkdocs documentation
132 | /site
133 |
134 | # mypy
135 | .mypy_cache/
136 | .dmypy.json
137 | dmypy.json
138 |
139 | # Pyre type checker
140 | .pyre/
141 |
--------------------------------------------------------------------------------
/libs/modeling/weight_init.py:
--------------------------------------------------------------------------------
1 | # from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
2 | import torch
3 | import math
4 | import warnings
5 |
6 |
7 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
8 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
9 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
10 | def norm_cdf(x):
11 | # Computes standard normal cumulative distribution function
12 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
13 |
14 | if (mean < a - 2 * std) or (mean > b + 2 * std):
15 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
16 | "The distribution of values may be incorrect.",
17 | stacklevel=2)
18 |
19 | with torch.no_grad():
20 | # Values are generated by using a truncated uniform distribution and
21 | # then using the inverse CDF for the normal distribution.
22 | # Get upper and lower cdf values
23 | l = norm_cdf((a - mean) / std)
24 | u = norm_cdf((b - mean) / std)
25 |
26 | # Uniformly fill tensor with values from [l, u], then translate to
27 | # [2l-1, 2u-1].
28 | tensor.uniform_(2 * l - 1, 2 * u - 1)
29 |
30 | # Use inverse cdf transform for normal distribution to get truncated
31 | # standard normal
32 | tensor.erfinv_()
33 |
34 | # Transform to proper mean, std
35 | tensor.mul_(std * math.sqrt(2.))
36 | tensor.add_(mean)
37 |
38 | # Clamp to ensure it's in the proper range
39 | tensor.clamp_(min=a, max=b)
40 | return tensor
41 |
42 |
43 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
44 | # type: (Tensor, float, float, float, float) -> Tensor
45 | r"""Fills the input Tensor with values drawn from a truncated
46 | normal distribution. The values are effectively drawn from the
47 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
48 | with values outside :math:`[a, b]` redrawn until they are within
49 | the bounds. The method used for generating the random values works
50 | best when :math:`a \leq \text{mean} \leq b`.
51 | Args:
52 | tensor: an n-dimensional `torch.Tensor`
53 | mean: the mean of the normal distribution
54 | std: the standard deviation of the normal distribution
55 | a: the minimum cutoff value
56 | b: the maximum cutoff value
57 | Examples:
58 | >>> w = torch.empty(3, 5)
59 | >>> nn.init.trunc_normal_(w)
60 | """
61 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
62 |
--------------------------------------------------------------------------------
/libs/core/config.py:
--------------------------------------------------------------------------------
1 | import yaml
2 |
3 |
4 | DEFAULTS = {
5 | # random seed for reproducibility, a large number is preferred
6 | "init_rand_seed": 1234567891,
7 | # dataset loader, specify the dataset here
8 | "dataset_name": "",
9 | "devices": ['cuda:0'], # default: single gpu
10 | "train_split": ('training', ),
11 | "val_split": ('validation', ),
12 | "model_name": "XXX",
13 | "dataset": {
14 |
15 | },
16 | "loader": {
17 | "train_batch_size": 8,
18 | "test_batch_size": 8,
19 | "num_workers": 1,
20 | },
21 | # network architecture
22 | "model": {
23 | "finetune_feat_extractor": False,
24 | "feat_extractor_type": None,
25 | "feat_extractor_weights_path": None,
26 | "backbone_type": 'xxx',
27 | # disable abs position encoding (added to input embedding)
28 | "neck_type": 'xxx',
29 | "decoder_params": {
30 | "decoder_ffn_dim": 2048,
31 | "decoder_activation": 'gelu',
32 | },
33 | },
34 | "train_cfg": {
35 | # gradient cliping, not needed for pre-LN transformer
36 | "clip_grad_l2norm": -1,
37 | },
38 | "test_cfg": {
39 | },
40 | # optimizer (for training)
41 | "opt": {
42 | # solver
43 | "type": "AdamW", # SGD or AdamW
44 | # solver params
45 | "momentum": 0.9,
46 | "weight_decay": 0.0,
47 | "learning_rate": 1e-3,
48 | # excluding the warmup epochs
49 | "epochs": 30,
50 | # lr scheduler: cosine / multistep
51 | "warmup": True,
52 | "warmup_epochs": 5,
53 | "schedule_type": "cosine",
54 | # in #epochs excluding warmup
55 | "schedule_steps": [],
56 | "schedule_gamma": 0.1,
57 | }
58 | }
59 |
60 | def _merge(src, dst):
61 | for k, v in src.items():
62 | if k in dst:
63 | if isinstance(v, dict):
64 | _merge(src[k], dst[k])
65 | else:
66 | dst[k] = v
67 |
68 | def load_default_config():
69 | config = DEFAULTS
70 | return config
71 |
72 | def _update_config(config):
73 | # fill in derived fields
74 | config["model"]["train_cfg"] = config["train_cfg"]
75 | config["model"]["test_cfg"] = config["test_cfg"]
76 | config["model"]["frames_per_clip"] = config["dataset"]["frames_per_clip"]
77 | config["model"]["max_seq_len"] = config["dataset"]["max_seq_len"]
78 | return config
79 |
80 | def load_config(config_file, defaults=DEFAULTS):
81 | with open(config_file, "r") as fd:
82 | config = yaml.load(fd, Loader=yaml.FullLoader)
83 | _merge(defaults, config)
84 | config = _update_config(config)
85 | return config
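86 |
87 | # Example (illustrative):
88 | #   cfg = load_config("configs/mtl_aqa/stoch_mtl_aqa_text_data_query.yaml")
89 | #   # the YAML is merged with DEFAULTS and the derived model fields
90 | #   # (train_cfg, test_cfg, frames_per_clip, max_seq_len) are filled in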
--------------------------------------------------------------------------------
/GETTING_STARTED.md:
--------------------------------------------------------------------------------
1 |
2 | # Getting Started
3 |
4 |
5 |
6 | ## Requirements
7 |
8 | Please find the installation instructions in [INSTALL.md](./INSTALL.md)
9 |
10 | ## Code
11 | Pull the code using
12 |
13 | git clone https://github.com/abrarmajeedi/rica2_aqa
14 |
15 | and navigate into the code directory
16 |
17 |
18 | cd rica2_aqa
19 |
20 |
21 | ## Data
22 |
23 | You can download the zipped data from the Google drive [link](https://drive.google.com/file/d/1CjYtxnjHZzDkWDYrLMFbph9b-EZ8fdFT/view?usp=sharing).
24 |
25 |
26 | Once downloaded, unzip the archive into `./data` inside the code directory.
27 |
28 |
29 | Make sure the data follows this structure
30 | ```markdown
31 | ├── data
32 | │
33 | │   ├── finediving
34 | │   │   ├── Annotations
35 | │   │   │   └── Annotation files (*.pkl)
36 | │   │   │
37 | │   │   └── FINADiving_MTL_256
38 | │   │       └── Video frame directories
39 | │
40 | │   └── mtl_aqa
41 | │       ├── frames_long
42 | │       │   └── Video frame directories
43 | │       │
44 | │       └── info
45 | │           └── Annotation files (*.pkl)
46 | ```
47 |
48 | ## Pretrained I3D weights
49 |
50 | You can download the pretrained I3D weights from the Google drive [link](https://drive.google.com/file/d/1vi-C3V_i4Sy_4Y3yJLLeiGRzpz8Evvid/view?usp=sharing).
51 |
52 |
53 | Once downloaded, place the file in `./pre_trained/model_rgb.pth`
54 |
55 |
56 |
57 | ## Running the code
58 |
59 | Use the following commands to run the code
60 |
61 | ### FineDiving
62 |
63 | python -u train.py configs/fine/stoch_fine_diving_text_data_query.yaml
64 |
65 | To run the deterministic variant (RICA2†):
66 |
67 | python -u train.py configs/fine/deter_fine_diving_text_data_query.yaml
68 |
69 | ### MTL-AQA
70 |
71 | python -u train.py configs/mtl_aqa/stoch_mtl_aqa_text_data_query.yaml
72 |
73 | To run the deterministic variant (RICA2†):
74 |
75 | python -u train.py configs/mtl_aqa/deter_mtl_aqa_text_data_query.yaml
76 |
77 |
78 | These commands train the specified models and automatically run evaluation, reporting the results at the end. To evaluate a saved checkpoint on its own, see the [BONUS] section on evaluating a saved checkpoint at the end of this guide.
79 |
80 |
81 | ## [BONUS] Tuning Experiment Parameters
82 |
83 | Our code allows easy change of model and experiment parameters:
84 |
85 | ### Modifying hyperparameters
86 |
87 | You can modify the model and training hyperparameters by changing values in the config files in `./configs`.
88 |
89 |
90 | ### Generating text embeddings
91 |
92 | RICA2 incorporates the step information of an action via LLM embeddings extracted from the textual step descriptions. The extraction scripts can be found in `./tools`.
93 |
94 | For FineDiving
95 |
96 | python ./tools/finediving/finediving_t5xxl_text_embed_extraction.py
97 |
98 | For MTL-AQA
99 |
100 | python ./tools/mtl_aqa/mtl_t5xxl_text_embed_extraction.py
101 |
102 |
103 | You can easily change the models used for extracting embeddings from the step descriptions by following the user-friendly [HuggingFace documentation](https://huggingface.co/docs).
104 |
105 |
106 | ## Contact
107 |
108 | Email: majeedi+at+wisc+dot+edu
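109 |
110 | ## [BONUS] Evaluating a saved checkpoint
111 |
112 | Training already reports evaluation results at the end, but `eval.py` can also evaluate a saved checkpoint on its own. A minimal example (the checkpoint path is a placeholder; point it at a checkpoint written under `./ckpt/` by your training run):
113 |
114 | python -u eval.py configs/mtl_aqa/stoch_mtl_aqa_text_data_query.yaml --ckpt ./ckpt/<experiment_dir>/<checkpoint_file>
115 |
116 | This prints the SRCC and RL2 on the test split and saves the raw predictions as a pickle file next to the checkpoint.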
--------------------------------------------------------------------------------
/configs/mtl_aqa/deter_mtl_aqa_text_data_query.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: mtl_aqa
2 | model_name : 'aqa-model'
3 | train_split: 'train'
4 | val_split: 'test'
5 | devices: [0]
6 | dataset: {
7 | data_root : ./data/mtl_aqa,
8 | video_dir: frames_long,
9 | label_file : augmented_final_annotations_dict.pkl,
10 | train_datafile : train_split_0.pkl ,
11 | test_datafile : test_split_0.pkl ,
12 | use_feats : False,
13 | feat_dir : "",
14 | frames_per_clip : 103,
15 | window_size : 8,
16 | stride : 5,
17 | max_seq_len: 20,
18 | input_frame_size : [256,455],
19 | crop_frame_size : 224,
20 | with_dd : True,
21 | three_judge_score_scaling : False,
22 | }
23 | model: {
24 | backbone_type : 'convEncoder',
25 | input_feat_dim: 1024, #feats or from the feat extractor
26 | embed_dim: 512,
27 | conv_dropout: 0.1, #drop out for initial conv layers
28 | conv_kernel_size: 3,
29 | neck_type: 'decoder-neck',
30 | num_layers : {
31 | n_conv_layers: 2,
32 | n_encoder_layers: 2,
33 | n_decoder_layers: 2,
34 | n_mlp_head_layers : 2, #for now this is hardcoded to 3 layers
35 | },
36 | encoder_params: {
37 | n_encoder_heads: 8,
38 | attn_pdrop: 0.1,
39 | proj_pdrop: 0.1,
40 | path_pdrop: 0.1,
41 | use_abs_pe: False,
42 | },
43 | decoder_params: {
44 | n_decoder_heads: 8,
45 | stride: 1,
46 | attn_pdrop: 0.1,
47 | proj_pdrop: 0.1,
48 | path_pdrop: 0.1,
49 | xattn_mode: 'affine',
50 | with_ln: True,
51 | query_config: {
52 | text_embeddings_path: assets/MTL_t5_xxl_text_embeddings.npy, #relative to the project root
53 | freeze_text_embeddings: True,
54 | text_queries_emb_dim: 4096,
55 | },
56 | },
57 | use_stochastic_embd: False,
58 | num_random_samples: 1,
59 | num_phases: 24, # =num of total subaction, check dataloader for more details
60 | score_bins: 1,
61 | }
62 | opt: {
63 | learning_rate: 0.0005, #this will get scaled by batchsize
64 | warmup_epochs: 3,
65 | schedule_type: "no_decay",
66 | epochs: 350,
67 | weight_decay: 0.0,
68 | feature_extractor_factor: 0.3, #lr for the feature extractor will be scaled by this factor
69 | neck_lr_factor: 1.0, #lr for the neck will be scaled by this factor
70 | }
71 | loader: {
72 | #for train and test batch_size provide how many samples can fit on one 8GB GPU e.g. for MTL this is 1 and 2, for finediving, this is 2 and 4
73 | train_batch_size : 8, #this can get overwritten by arg
74 | test_batch_size: 16, #this can get overwritten by arg
75 | num_workers: 4, #this will also be changed dynamically based on cpu count, this is min number, max is 20
76 | }
77 | train_cfg: {
78 | dataset_name: mtl_aqa,
79 | clip_grad_l2norm: 1.0,
80 | accumulation_steps: 1, #how many steps/batches to accumulate gradients before taking an optimizer step
81 | loss_weights: {
82 | loss: mse,
83 | quality_score: 1.0,
84 | phase_vib : 0.0, # 1e-3
85 | scale_vib: False,
86 | ranking: 0.05, # use > 0.0 to enable
87 | sparsity: 0.05 # use > 0.0 to enable
88 | },
89 | }
90 | test_cfg: {
91 | }
92 | output_folder: ./ckpt/
93 |
--------------------------------------------------------------------------------
/configs/mtl_aqa/stoch_mtl_aqa_text_data_query.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: mtl_aqa
2 | model_name : 'aqa-model'
3 | train_split: 'train'
4 | val_split: 'test'
5 | devices: [0]
6 | dataset: {
7 | data_root : ./data/mtl_aqa,
8 | video_dir: frames_long,
9 | label_file : augmented_final_annotations_dict.pkl,
10 | train_datafile : train_split_0.pkl ,
11 | test_datafile : test_split_0.pkl ,
12 | use_feats : False,
13 | feat_dir : "",
14 | frames_per_clip : 103,
15 | window_size : 8,
16 | stride : 5,
17 | max_seq_len: 20,
18 | input_frame_size : [256,455],
19 | crop_frame_size : 224,
20 | with_dd : True,
21 | three_judge_score_scaling : False,
22 | }
23 | model: {
24 | backbone_type : 'convEncoder',
25 | input_feat_dim: 1024, #feats or from the feat extractor
26 | embed_dim: 512,
27 | conv_dropout: 0.1, #drop out for initial conv layers
28 | conv_kernel_size: 3,
29 | neck_type: 'decoder-neck',
30 | num_layers : {
31 | n_conv_layers: 2,
32 | n_encoder_layers: 2,
33 | n_decoder_layers: 2,
34 | n_mlp_head_layers : 2, #for now this is hardcoded to 3 layers
35 | },
36 | encoder_params: {
37 | n_encoder_heads: 8,
38 | attn_pdrop: 0.1,
39 | proj_pdrop: 0.1,
40 | path_pdrop: 0.1,
41 | use_abs_pe: False,
42 | },
43 | decoder_params: {
44 | n_decoder_heads: 8,
45 | stride: 1,
46 | attn_pdrop: 0.1,
47 | proj_pdrop: 0.1,
48 | path_pdrop: 0.1,
49 | xattn_mode: 'affine',
50 | with_ln: True,
51 | query_config: {
52 | text_embeddings_path: assets/MTL_t5_xxl_text_embeddings.npy, #relative to the project root
53 | freeze_text_embeddings: True,
54 | text_queries_emb_dim: 4096,
55 | },
56 | },
57 | use_stochastic_embd: True,
58 | num_random_samples: 20,
59 | num_phases: 24, # =num of total subaction, check dataloader for more details
60 | score_bins: 1,
61 | }
62 | opt: {
63 | learning_rate: 0.0005, #this will get scaled by batchsize
64 | warmup_epochs: 3,
65 | schedule_type: "no_decay",
66 | epochs: 350,
67 | weight_decay: 0.0,
68 | feature_extractor_factor: 0.3, #lr for the feature extractor will be scaled by this factor
69 | neck_lr_factor: 1.0, #lr for the neck will be scaled by this factor
70 | }
71 | loader: {
72 | #for train and test batch_size provide how many samples can fit on one 8GB GPU e.g. for MTL this is 1 and 2, for finediving, this is 2 and 4
73 | train_batch_size : 8, #this can get overwritten by arg
74 | test_batch_size: 16, #this can get overwritten by arg
75 | num_workers: 4, #this will also be changed dynamically based on cpu count, this is min number, max is 20
76 | }
77 | train_cfg: {
78 | dataset_name: mtl_aqa,
79 | clip_grad_l2norm: 1.0,
80 | accumulation_steps: 1, #how many steps/batches to accumulate gradients before taking an optimizer step
81 | loss_weights: {
82 | loss: mse,
83 | quality_score: 1.0,
84 | phase_vib : 0.00001, # 1e-3
85 | scale_vib: True,
86 | ranking: 0.05, # use > 0.0 to enable
87 | sparsity: 0.05 # use > 0.0 to enable
88 | },
89 | }
90 | test_cfg: {
91 | }
92 | output_folder: ./ckpt/
93 |
--------------------------------------------------------------------------------
/configs/fine/deter_fine_diving_text_data_query.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: finediving
2 | model_name : 'aqa-model'
3 | train_split: 'train'
4 | val_split: 'test'
5 | devices: [0]
6 | dataset: {
7 | data_root : ./data/finediving,
8 | video_dir: FINADiving_MTL_256s,
9 | label_file : FineDiving_fine-grained_annotation.pkl,
10 | coarse_label_file : FineDiving_coarse_annotation.pkl,
11 | train_datafile : train_split.pkl ,
12 | test_datafile : test_split.pkl ,
13 | use_feats : False,
14 | feat_dir : "",
15 | frames_per_clip : 96,
16 | window_size : 16,
17 | stride : 10,
18 | max_seq_len: 9,
19 | input_frame_size : [112,200],
20 | crop_frame_size : 112,
21 | with_dd : True,
22 | three_judge_score_scaling : False,
23 | }
24 | model: {
25 | backbone_type : 'convEncoder',
26 | input_feat_dim: 1024, #feats or from the feat extractor
27 | embed_dim: 512,
28 | conv_dropout: 0.1, #drop out for initial conv layers
29 | conv_kernel_size: 3,
30 | neck_type: 'decoder-neck',
31 | num_layers : {
32 | n_conv_layers: 2,
33 | n_encoder_layers: 2,
34 | n_decoder_layers: 2,
35 | n_mlp_head_layers : 2, #for now this is hardcoded to 3 layers
36 | },
37 | encoder_params: {
38 | n_encoder_heads: 8,
39 | attn_pdrop: 0.1,
40 | proj_pdrop: 0.1,
41 | path_pdrop: 0.1,
42 | use_abs_pe: False,
43 | },
44 | decoder_params: {
45 | n_decoder_heads: 8,
46 | stride: 1,
47 | attn_pdrop: 0.1,
48 | proj_pdrop: 0.1,
49 | path_pdrop: 0.1,
50 | xattn_mode: 'affine',
51 | with_ln: True,
52 | query_config: {
53 | text_embeddings_path: assets/FineDiving_t5_xxl_text_embeddings.npy, #relative to the project root
54 | freeze_text_embeddings: True,
55 | text_queries_emb_dim: 4096,
56 | },
57 | },
58 | use_stochastic_embd: False,
59 | num_random_samples: 1,
60 | num_phases: 29, # =num of total subaction, check dataloader for more details
61 | score_bins: 1,
62 | }
63 | opt: {
64 | learning_rate: 0.0005, #this will get scaled by batchsize
65 | warmup_epochs: 3,
66 | schedule_type: "no_decay",
67 | epochs: 350,
68 | weight_decay: 0.0,
69 | feature_extractor_factor: 0.3, #lr for the feature extractor will be scaled by this factor
70 | neck_lr_factor: 1.0, #lr for the neck will be scaled by this factor
71 | }
72 | loader: {
73 | #for train and test batch_size provide how many samples can fit on one 8GB GPU e.g. for MTL this is 1 and 2, for finediving, this is 2 and 4
74 | train_batch_size : 8, #this can get overwritten by arg
75 | test_batch_size: 16, #this can get overwritten by arg
76 | num_workers: 4, #this will also be changed dynamically based on cpu count, this is min number, max is 20
77 | }
78 | train_cfg: {
79 | dataset_name: finediving,
80 | clip_grad_l2norm: 1.0,
81 | accumulation_steps: 1, #how many steps/batches to accumulate gradients before taking an optimizer step
82 | loss_weights: {
83 | loss: mse,
84 | quality_score: 1.0,
85 | phase_vib : 0.0, # 1e-3
86 | scale_vib: False,
87 | ranking: 0.05, # use > 0.0 to enable
88 | sparsity: 0.05 # use > 0.0 to enable
89 | },
90 | }
91 | test_cfg: {
92 | }
93 | output_folder: ./ckpt/
94 |
95 |
96 |
--------------------------------------------------------------------------------
/configs/fine/stoch_fine_diving_text_data_query.yaml:
--------------------------------------------------------------------------------
1 | dataset_name: finediving
2 | model_name : 'aqa-model'
3 | train_split: 'train'
4 | val_split: 'test'
5 | devices: [0]
6 | dataset: {
7 | data_root : ./data/finediving,
8 | video_dir: FINADiving_MTL_256s,
9 | label_file : FineDiving_fine-grained_annotation.pkl,
10 | coarse_label_file : FineDiving_coarse_annotation.pkl,
11 | train_datafile : train_split.pkl ,
12 | test_datafile : test_split.pkl ,
13 | use_feats : False,
14 | feat_dir : "",
15 | frames_per_clip : 96,
16 | window_size : 16,
17 | stride : 10,
18 | max_seq_len: 9,
19 | input_frame_size : [112,200],
20 | crop_frame_size : 112,
21 | with_dd : True,
22 | three_judge_score_scaling : False,
23 | }
24 | model: {
25 | backbone_type : 'convEncoder',
26 | input_feat_dim: 1024, #feats or from the feat extractor
27 | embed_dim: 512,
28 | conv_dropout: 0.1, #drop out for initial conv layers
29 | conv_kernel_size: 3,
30 | neck_type: 'decoder-neck',
31 | num_layers : {
32 | n_conv_layers: 2,
33 | n_encoder_layers: 2,
34 | n_decoder_layers: 2,
35 | n_mlp_head_layers : 2, #for now this is hardcoded to 3 layers
36 | },
37 | encoder_params: {
38 | n_encoder_heads: 8,
39 | attn_pdrop: 0.1,
40 | proj_pdrop: 0.1,
41 | path_pdrop: 0.1,
42 | use_abs_pe: False,
43 | },
44 | decoder_params: {
45 | n_decoder_heads: 8,
46 | stride: 1,
47 | attn_pdrop: 0.1,
48 | proj_pdrop: 0.1,
49 | path_pdrop: 0.1,
50 | xattn_mode: 'affine',
51 | with_ln: True,
52 | query_config: {
53 | text_embeddings_path: assets/FineDiving_t5_xxl_text_embeddings.npy, #relative to the project root
54 | freeze_text_embeddings: True,
55 | text_queries_emb_dim: 4096,
56 | },
57 | },
58 | use_stochastic_embd: True,
59 | num_random_samples: 20,
60 | num_phases: 29, # =num of total subaction, check dataloader for more details
61 | score_bins: 1,
62 | }
63 | opt: {
64 | learning_rate: 0.0005, #this will get scaled by batchsize
65 | warmup_epochs: 3,
66 | schedule_type: "no_decay",
67 | epochs: 350,
68 | weight_decay: 0.0,
69 | feature_extractor_factor: 0.3, #lr for the feature extractor will be scaled by this factor
70 | neck_lr_factor: 1.0, #lr for the neck will be scaled by this factor
71 | }
72 | loader: {
73 | #for train and test batch_size provide how many samples can fit on one 8GB GPU e.g. for MTL this is 1 and 2, for finediving, this is 2 and 4
74 | train_batch_size : 8, #this can get overwritten by arg
75 | test_batch_size: 16, #this can get overwritten by arg
76 | num_workers: 4, #this will also be changed dynamically based on cpu count, this is min number, max is 20
77 | }
78 | train_cfg: {
79 | dataset_name: finediving,
80 | clip_grad_l2norm: 1.0,
81 | accumulation_steps: 1, #how many steps/batches to accumulate gradients before taking an optimizer step
82 | loss_weights: {
83 | loss: mse,
84 | quality_score: 1.0,
85 | phase_vib : 0.00001, # 1e-3
86 | scale_vib: True,
87 | ranking: 0.05, # use > 0.0 to enable
88 | sparsity: 0.05 # use > 0.0 to enable
89 | },
90 | }
91 | test_cfg: {
92 | }
93 | output_folder: ./ckpt/
94 |
95 |
96 |
--------------------------------------------------------------------------------
/libs/utils/preprocessing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | @torch.no_grad()
6 | def multi_label_multi_class_one_hot_encode(labels, num_phases):
7 | """
8 | Convert a list of labels to a binary one hot encoding
9 | This function is for one hot encoding multi-label multi-class labels
10 | """
11 | labels = torch.LongTensor(labels)
12 |
13 | y_onehot = nn.functional.one_hot(labels, num_classes=num_phases)
14 |
15 | y_onehot = y_onehot.sum(dim=0).float()
16 |
17 | return y_onehot
18 |
19 | @torch.no_grad()
20 | def preprocessing(video_list,
21 | feat_extractor_type,
22 | max_seq_len, num_phases,
23 | padding_val=0.0
24 | ):
25 | """
26 | Generate batched features and masks from a list of dict items
27 | """
28 | if feat_extractor_type is not None:
29 | #each item in batch has frames as a list of tensors of shape Nwin x T x C x H x W
30 | frames = [x['window_frames'] for x in video_list]
31 |
32 | #B, Nwin x T x C x H x W -> B x Nwin x T x C x H x W
33 | batched_inputs = torch.stack(frames).cuda()
34 |
35 | batched_masks = torch.ones((len(video_list), 1, frames[0].shape[0])).cuda()
36 |
37 | else:
38 | feats = [x['feats'] for x in video_list]
39 | feats_lens = torch.as_tensor([feat.shape[-1] for feat in feats])
40 | max_len = feats_lens.max(0).values.item()
41 |
42 |
43 | assert max_len <= max_seq_len, f"Input length must not exceed max_seq_len during training, max_len = {max_len}, max_seq_len = {max_seq_len}"
44 | # set max_len to self.max_seq_len
45 | max_len = max_seq_len
46 | # batch input shape B, C, T
47 | batch_shape = [len(feats), feats[0].shape[0], max_len]
48 | batched_inputs = feats[0].new_full(batch_shape, padding_val)
49 | for feat, pad_feat in zip(feats, batched_inputs):
50 | pad_feat[..., :feat.shape[-1]].copy_(feat)
51 |
52 | # generate the mask
53 | batched_masks = torch.arange(max_len)[None, :] < feats_lens[:, None]
54 |
55 | # push to device
56 | batched_inputs = batched_inputs.cuda()
57 | batched_masks = batched_masks.unsqueeze(1).cuda()
58 |
59 | actions = [multi_label_multi_class_one_hot_encode(x['actions_present'], num_phases) for x in video_list]
60 |
61 | gt_actions = torch.stack(actions, dim=0).cuda()
62 |
63 | gt_scores = torch.tensor([x['gt_score'] for x in video_list]).cuda()
64 |
65 | video_ids = [x['video_id'] for x in video_list]
66 |
67 |
68 | if 'target' in video_list[0]:
69 | target = torch.tensor([x['target'] for x in video_list]).cuda()
70 | else:
71 | target = []
72 |
73 | if "difficulty" in video_list[0]:
74 | difficulties = torch.tensor([x['difficulty'] for x in video_list])
75 | else:
76 | difficulties = []
77 |
78 | ret_dict = {
79 | 'video_ids': video_ids,
80 | 'batched_inputs': batched_inputs,
81 | 'batched_masks': batched_masks,
82 | 'gt_actions': gt_actions,
83 | 'gt_scores': gt_scores,
84 | 'target': target,
85 | 'difficulties': difficulties
86 | }
87 |
88 | return ret_dict
89 |
90 |
91 |
--------------------------------------------------------------------------------
/libs/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import stats
3 |
4 | from scipy.stats import hmean, kendalltau
5 |
6 |
7 | np.seterr(divide='ignore', invalid='ignore')
8 | np.set_printoptions(precision=2)
9 |
10 |
11 |
12 | class MetricsDict:
13 | """A class to store all the data for computing metrics.
14 | """
15 | def __init__(self):
16 | self.metric_dict = {
17 | "pred_scores": [],
18 | "gt_scores": [],
19 | "gt_actions": [],
20 | "all_judge_scores": [],
21 | "difficulty": [],
22 | }
23 |
24 | def update(self, key, value):
25 | if key not in self.metric_dict.keys():
26 | self.metric_dict[key] = []
27 |
28 | self.metric_dict[key].append(value)
29 |
30 | def get_metric_dict(self):
31 | return self.metric_dict
32 |
33 |
34 | def evaluate(metric_dict,
35 | is_train,
36 | dataset_name,
37 | curr_epoch =-1
38 | ):
39 | metric_dict = metric_dict.get_metric_dict()
40 |
41 | if not is_train:
42 | #if var is available use that for uncertainty else use logits
43 | plot_rejection_curve(metric_dict)
44 |
45 | pred_scores = metric_dict["pred_scores"]
46 | true_scores = metric_dict["gt_scores"]
47 |
48 | pred_scores = np.concatenate([np.atleast_1d(x) for x in pred_scores])
49 | true_scores = np.concatenate([np.atleast_1d(x) for x in true_scores])
50 |
51 | min_true_score = true_scores.min()
52 | max_true_score = true_scores.max()
53 |
54 | rho, p = stats.spearmanr(pred_scores, true_scores)
55 | L2 = np.power(pred_scores - true_scores, 2).sum() / true_scores.shape[0]
56 | RL2 = np.power((pred_scores - true_scores) / (max_true_score - min_true_score), 2).sum() / true_scores.shape[0]
57 |
58 | if dataset_name == "cat101":
59 | accuracy = (np.round(pred_scores) == true_scores).mean()
60 | print(f"Epoch {curr_epoch} Accuracy: {accuracy}")
61 |
62 | return rho*100, L2, RL2*100
63 |
64 | def plot_rejection_curve(metric_dict):
65 | if "global_sqrt_var_emb" in metric_dict.keys():
66 | global_sqrt_var_emb = metric_dict["global_sqrt_var_emb"]
67 | global_sqrt_var_emb = np.concatenate(global_sqrt_var_emb)
68 | mean_sqrt_var = hmean(global_sqrt_var_emb, axis = 1)
69 | uncertainties = mean_sqrt_var
70 | else:
71 | return
72 |
73 | #matplotlib.use('module://drawilleplot')
74 | pred_scores = metric_dict["pred_scores"]
75 | true_scores = metric_dict["gt_scores"]
76 |
77 | pred_scores = np.concatenate([np.atleast_1d(x) for x in pred_scores])
78 | true_scores = np.concatenate([np.atleast_1d(x) for x in true_scores])
79 |
80 | y_mae, bins = rejection_plot(pred_scores, true_scores, uncertainties)
81 |
82 | print("Calibration Kendall Tau (MAE): {:0.4f}".format(kendalltau(y_mae, bins)[0]))
83 |
84 |
85 |
86 | def rejection_plot(preds, gts, uncertainty, num_bins=11):
87 | all_mae = []
88 | bins = np.linspace(0, 100, num_bins)
89 | # e.g. [0, 10, 20, ..., 100] for num_bins=11
90 |
91 | conf_sort_idx = np.argsort(uncertainty)
92 | uncertainty = uncertainty[conf_sort_idx]
93 |
94 | preds = preds[conf_sort_idx]
95 | gts = gts[conf_sort_idx]
96 |
97 | for bin_num in range(bins.shape[0]-1):
98 | bin_low = np.percentile(uncertainty, bins[bin_num])
99 |
100 | if bin_num != bins.shape[0]-2:
101 | bin_high = np.percentile(uncertainty, bins[bin_num+1])
102 | else:
103 | bin_high = 10000000000000000
104 | bin_idxs = np.where((uncertainty >= bin_low) & (uncertainty < bin_high))[0]
77 |
78 | def exceeds_range(frame):
79 | return frame.pts * tb >= t2
80 |
81 | for frame in buffer:
82 | if is_in_range(frame):
83 | ret.append(frame)
84 |
85 | prev_pts = None
86 |
87 | # This try/except block avoids the EOF error that arises when t2 exceeds the frame range
88 | try:
89 | for frame in container.decode(video=0):
90 | if frame.pts is None:
91 | raise AssertionError("frame is None")
92 | if prev_pts is not None and frame.pts < prev_pts:
93 | raise AssertionError("failed assumption pts in order: ")
94 | if not isinstance(frame, av.VideoFrame):
95 | raise AssertionError("other packets not supported")
96 | prev_pts = frame.pts
97 |
98 | buffer.append(frame)
99 | if len(buffer) > max_buffer_size:
100 | del buffer[0]
101 |
102 | if is_in_range(frame):
103 | ret.append(frame)
104 | elif exceeds_range(frame):
105 | break
106 | except:
107 | pass
108 |
109 | pts_in_ret = [frame.pts for frame in ret]
110 | if not (np.diff(pts_in_ret) > 0).all():
111 | raise AssertionError("not increasing sequence of frames")
112 | return ret
113 |
114 |
115 | @property
116 | def duration(self) -> float:
117 | vstream = self.vid._container.streams.video[0]
118 | return vstream.duration * vstream.time_base
119 |
--------------------------------------------------------------------------------
/libs/modeling/necks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | import numpy as np
5 |
6 | from .models import register_neck
7 | from .blocks import TransformerDecoder, LayerNorm
8 |
9 | @register_neck("decoder-neck")
10 | class Decoder(nn.Module):
11 | def __init__(self, d_model=512,
12 | n_heads=8,
13 | stride=1,
14 | num_decoder_layers=6,
15 | attn_pdrop=0.0,
16 | proj_pdrop=0.0,
17 | path_pdrop=0.0,
18 | num_phases=-1,
19 | xattn_mode='affine',
20 | with_ln=True,
21 | use_rel_pe=False,
22 | query_config=None):
23 | super().__init__()
24 |
25 | self.layers = nn.ModuleList()
26 | for _ in range(num_decoder_layers):
27 | self.layers.append(
28 | TransformerDecoder(
29 | embd_dim = d_model,
30 | kv_dim = d_model,
31 | stride=stride,
32 | n_heads=n_heads,
33 | attn_pdrop=attn_pdrop,
34 | proj_pdrop=proj_pdrop,
35 | path_pdrop=path_pdrop,
36 | xattn_mode=xattn_mode,
37 | use_rel_pe=use_rel_pe
38 | )
39 | )
40 |
41 | self.ln_out = LayerNorm(d_model) if with_ln else None
42 |
43 | assert num_phases > 0, "Number of phases must be > 0"
44 |
45 | self.num_phases = num_phases
46 |
47 | self._reset_parameters()
48 | text_queries_emb_dim = query_config['text_queries_emb_dim']
49 |
50 | self.d_model = d_model
51 | self.n_heads = n_heads
52 | self.query_config = query_config
53 |
54 | text_embd = torch.from_numpy(np.load(query_config['text_embeddings_path']))
55 | self.text_queries = nn.Embedding.from_pretrained(text_embd.squeeze(),
56 | freeze=query_config['freeze_text_embeddings'])
57 | self.text_queries_projection = nn.Linear(
58 | in_features=text_queries_emb_dim, out_features=d_model, bias=True)
59 |
60 | def _get_queries(self):
61 | # num phases x C
62 | text_queries = self.text_queries.weight
63 |
64 | # num phases x C -> num phases x d_model
65 | queries = self.text_queries_projection(text_queries)
66 |
67 | return queries
68 |
69 | def _reset_parameters(self):
70 | for p in self.parameters():
71 | if p.dim() > 1:
72 | nn.init.xavier_uniform_(p)
73 |
74 | def _forward(self, q, q_mask, kv, kv_mask, kv_size=None, video_ids=None, curr_epoch=None):
75 | layerwise_cross_attn = []
76 | for layer in self.layers:
77 | q, q_mask, cross_attn = layer(q, q_mask, kv, kv_mask, kv_size,
78 | video_ids=video_ids, curr_epoch=curr_epoch)
79 | layerwise_cross_attn.append(cross_attn)
80 |
81 | q = self.ln_out(q) if self.ln_out is not None else q
82 |
83 | return q, q_mask, layerwise_cross_attn
84 |
85 |
86 | def forward(self, vid_embd, vid_masks, batched_gt_actions, video_ids=None, curr_epoch=None):
87 |
88 | """
89 | vid_embd: B, C, T
90 | vid_masks: B, T
91 | batched_gt_actions: B, phases
92 | """
93 | B, C, T = vid_embd.shape
94 |
95 | #phases x d_model
96 | queries = self._get_queries()
97 |
98 | #phases x d_model -> B x phases x d_model
99 | queries = queries.repeat(B, 1, 1)
100 |
101 | #B x phases x d_model -> B x d_model x phases
102 | queries = queries.permute(0, 2, 1)
103 |
104 | #B x 1 x Phases
105 | query_masks = batched_gt_actions.unsqueeze(1).bool().detach()
106 |
107 | out, out_masks, cross_attns = self._forward(queries, query_masks, vid_embd, vid_masks, kv_size=T,
108 | video_ids=video_ids, curr_epoch=curr_epoch)
109 |
110 | #B x d_model x phases -> B x phases x d_model
111 | out = out.permute(0, 2, 1)
112 |
113 | return out, cross_attns
114 |
--------------------------------------------------------------------------------
/tools/mtl_aqa/mtl_t5xxl_text_embed_extraction.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import List
3 |
4 | import torch
5 | import numpy as np
6 | from transformers import AutoTokenizer, AutoModel
7 |
8 | ##########################MTL########
9 | @dataclass
10 | class DataClass:
11 | sub_action_descriptions: List[str] = field(
12 | default_factory=lambda: [
13 | "Armstand: In this position, athletes start by standing on their hands on the edge of the diving board and perform their dive while maintaining this handstand position.",
14 | "Inwards: In this rotation type, athletes perform a forward-facing takeoff and rotate inward toward the diving board as they execute their dive.",
15 | "Reverse: In this rotation type, athletes perform a backward-facing takeoff and rotate backward away from the diving board as they execute their dive.",
16 | "Backward: In this rotation type, athletes perform a backward-facing takeoff and rotate backward toward the diving board as they execute their dive.",
17 | "Forward: In this rotation type, athletes perform a forward-facing takeoff and rotate forward away from the diving board as they execute their dive.",
18 | "Free: In this position, athletes have the freedom to perform any combination of dives from various categories without any restrictions or limitations.",
19 | "Tuck: In this position, athletes bring their knees to their chest and hold onto their shins while maintaining a compact shape throughout their dive.",
20 | "Pike: In this position, athletes maintain a straight body with their legs extended and their toes pointed out while bending at the waist to bring their hands toward their toes.",
21 | "0.5 Somersault: Athletes perform a half rotation in the air during their dive.",
22 | "1 Somersault: Athletes perform a full forward or backward rotation in the air during their dive.",
23 | "1.5 Somersault: Athletes perform a full rotation and an additional half rotation in the air during their dive.",
24 | "2 Somersault: Athletes perform two full forward or backward rotations in the air during their dive.",
25 | "2.5 Somersault: Athletes perform two full rotations and an additional half rotation in the air during their dive.",
26 | "3 Somersault: Athletes perform three full forward or backward rotations in the air during their dive.",
27 | "3.5 Somersault: Athletes perform three full rotations and an additional half rotation in the air during their dive.",
28 | "4.5 Somersault: Athletes perform four full rotations and an additional half rotation in the air during their dive.",
29 | "0.5 Twist: Athletes perform a half twist in the air during their dive.",
30 | "1 Twist: Athletes perform one full twist in the air during their dive.",
31 | "1.5 Twist: Athletes perform one and a half twists in the air during their dive.",
32 | "2 Twist: Athletes perform two full twists in the air during their dive.",
33 | "2.5 Twist: Athletes perform two and a half twists in the air during their dive.",
34 | "3 Twist: Athletes perform three full twists in the air during their dive.",
35 | "3.5 Twist: Athletes perform three and a half twists in the air during their dive.",
36 | "Entry: A diving technique involving a entry into the water, typically performed at the end of a dive.",
37 | ]
38 | )
39 |
40 | class GetTextEmbeddings:
41 | def __init__(self, output_path) -> None:
42 | """
43 | Args:
44 | output_path (str): path to save the embeddings
45 | """
46 | self.data = DataClass()
47 | self.output_path = output_path
48 |
49 | def get_huggingface_embeddings(self):
50 | def average_pool(
51 | last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
52 | ) -> torch.Tensor:
53 | last_hidden = last_hidden_states.masked_fill(
54 | ~attention_mask[..., None].bool(), 0.0
55 | )
56 | return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
57 |
58 | input_texts = [
59 | "This is a " + desc.replace(": ", " action: ")
60 | for desc in self.data.sub_action_descriptions
61 | ]
62 |
63 | model_id="google/flan-t5-xxl" # "intfloat/e5-large"
64 | tokenizer = AutoTokenizer.from_pretrained(model_id)
65 | model = AutoModel.from_pretrained(model_id)
66 |
67 | # Tokenize the input texts
68 | batch_dict = tokenizer(
69 | input_texts,
70 | max_length=512,
71 | padding=True,
72 | truncation=True,
73 | return_tensors="pt",
74 | )
75 |
76 | with torch.no_grad():
77 | outputs= model(input_ids= batch_dict['input_ids'], decoder_input_ids=batch_dict['input_ids'])
78 | embeddings = average_pool(
79 | outputs.last_hidden_state, batch_dict["attention_mask"]
80 | )
81 | np.save(self.output_path, embeddings.detach().cpu().numpy())
82 | return embeddings.detach().cpu().numpy()
83 |
84 |
85 | get_text_embeddings = GetTextEmbeddings("MTL_t5_xxl_text_embeddings.npy")
86 | text_embeddings = get_text_embeddings.get_huggingface_embeddings()
87 | print("Text Embed shape: ", text_embeddings.shape)
88 |
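89 | # Note (suggestion): to try a different text encoder, e.g. an encoder-only model such
90 | # as "intfloat/e5-large" (commented above), change model_id and replace the forward
91 | # pass with: outputs = model(**batch_dict); average_pool() can then be reused as is.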
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | # python imports
2 | import argparse
3 | import os
4 | from pprint import pprint
5 | import sys, pickle
6 |
7 | # torch imports
8 | import torch
9 | import torch.nn as nn
10 | import torch.utils.data
11 | # for visualization
12 | from torch.utils.tensorboard import SummaryWriter
13 |
14 | # our code
15 | from libs.core import load_config
16 | from libs.datasets import make_dataset, make_data_loader
17 | from libs.modeling import make_meta_arch
18 | from libs.utils import (Logger, valid_one_epoch,
19 | fix_random_seed, ModelEma)
20 |
21 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
22 | def override_cfg_params(cfg, args):
23 | if args.gpu is not None:
24 | cfg["devices"] = args.gpu
25 | else:
26 | cfg['devices'] = [i for i in range(torch.cuda.device_count())]
27 |
28 | if args.data_root is not None:
29 | cfg["dataset"]["data_root"] = args.data_root
30 |
31 | if args.test_batch_size > 0:
32 | cfg["loader"]["test_batch_size"] = args.test_batch_size
33 |
34 | if cfg["train_cfg"]["loss_weights"]["loss"] == "corn":
35 | cfg["model"]["score_bins"] -= 1
36 |
37 | if args.cv_fold > -1:
38 | cfg["dataset"]["cross_val_id"] = args.cv_fold
39 |
40 | if args.cv_split_file != "":
41 | cfg["dataset"]["cross_val_split_file"] = args.cv_split_file
42 |
43 | if cfg["dataset"]["use_feats"] == False:
44 | cfg["model"]["finetune_feat_extractor"]= True
45 | cfg["model"]["feat_extractor_type"]= 'i3d'
46 | cfg["model"]["feat_extractor_weights_path"]= './pre_trained/model_rgb.pth'
47 |
48 | if cfg["model"]["use_stochastic_embd"] == False:
49 | cfg["train_cfg"]["loss_weights"]["phase_vib"] = 0.0
50 |
51 | return cfg
52 |
53 |
54 |
55 | def create_train_val_dataloaders(cfg, rng_generator):
56 | train_dataset = make_dataset(
57 | cfg['dataset_name'],
58 | True,
59 | cfg['train_split'],
60 | **cfg['dataset']
61 | )
62 |
63 |
64 | val_dataset = make_dataset(
65 | cfg['dataset_name'],
66 | False,
67 | cfg['val_split'],
68 | **cfg['dataset']
69 | )
70 |
71 | # data loaders
72 | train_loader = make_data_loader(
73 | train_dataset, True, rng_generator, **cfg['loader'])
74 |
75 | val_loader = make_data_loader(
76 | val_dataset, False, None, **cfg['loader'])
77 |
78 | return (train_loader, val_loader)
79 |
80 | ################################################################################
81 | def main(args):
82 | """main function that handles training / inference"""
83 |
84 | """1. setup parameters / folders"""
85 |
86 | # parse args
87 | args.start_epoch = 0
88 |
89 | if os.path.isfile(args.config):
90 | cfg = load_config(args.config)
91 | else:
92 | raise ValueError("Config file does not exist.")
93 |
94 | torch.set_warn_always(False)
95 |
96 | # fix the random seeds (this will fix everything)
97 | rng_generator = fix_random_seed(cfg['init_rand_seed'], include_cuda=True)
98 |
99 | cfg = override_cfg_params(cfg, args)
100 |
101 | print("Args")
102 | pprint(vars(args), indent=4, stream=sys.__stdout__,sort_dicts=False)
103 | pprint(cfg, stream=sys.__stdout__,sort_dicts=False)
104 |
105 | """2. create dataset / dataloader"""
106 | train_loader, val_loader = create_train_val_dataloaders(cfg, rng_generator)
107 |
108 | """3. create model, optimizer, and scheduler"""
109 | # model
110 | model = make_meta_arch(cfg['model_name'], **cfg['model'])
111 |
112 | # not ideal for multi GPU training, ok for now
113 | # gpu_ids = ','.join(str(device_id) for device_id in cfg['devices'])
114 | # os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids
115 | # model = nn.DataParallel(model, device_ids=cfg['devices'])
116 | model = nn.DataParallel(model).cuda()
117 |
118 | ckpt_file = args.ckpt
119 |
120 | if not os.path.isfile(ckpt_file):
121 | raise ValueError("CKPT file does not exist!")
122 |
123 | """4. load ckpt"""
124 | print("=> loading checkpoint '{}'".format(ckpt_file))
125 | # load ckpt, reset epoch / best rmse
126 | checkpoint = torch.load(
127 | ckpt_file,
128 | map_location = lambda storage, loc: storage.cuda(cfg['devices'][0])
129 | )
130 | model.load_state_dict(checkpoint['state_dict'], strict=True)
131 | del checkpoint
132 |
133 |
134 | """5. validation loop"""
135 | print("\nStart testing model {:s} ...".format(cfg['model_name']))
136 |
137 | with torch.no_grad():
138 | curr_srcc, curr_rl2, metric_dict = valid_one_epoch(
139 | val_loader,
140 | model,
141 | -1,
142 | cfg = cfg,
143 | tb_writer=None,
144 | print_freq=args.print_freq,
145 | save_predictions=True
146 | )
147 |
148 | print("SRCC: {:.4f}, RL2: {:.4f}".format(curr_srcc, curr_rl2))
149 |
150 | with open(os.path.join(os.path.dirname(ckpt_file), "epoch_{:03d}_srcc_{:.3f}_rl2_{:.3f}_outputs.pkl".format(-1, curr_srcc, curr_rl2)), "wb") as f:
151 | pickle.dump(metric_dict, f)
152 |
153 |
154 | print("All done!")
155 | return
156 |
157 | ################################################################################
158 | if __name__ == '__main__':
159 | """Entry Point"""
160 | # the arg parser
161 | parser = argparse.ArgumentParser(
162 |         description='Evaluate a trained model')
163 | parser.add_argument('config', metavar='DIR',
164 | help='path to a config file')
165 | parser.add_argument('-p', '--print-freq', default=10, type=int,
166 | help='print frequency (default: 10 iterations)')
167 | parser.add_argument('--ckpt', default='', type=str,
168 |                         help='path to a checkpoint file (default: none)')
169 | parser.add_argument('--data_root', type=str, metavar='PATH',)
170 | parser.add_argument('--test_batch_size', default=-1, type=int)
171 | parser.add_argument('--cv_fold', default=-1, type=int)
172 |
173 | parser.add_argument('--cv_split_file', default='', type=str)
174 | parser.add_argument('--gpu', nargs='*')
175 |
176 | args = parser.parse_args()
177 |
178 | main(args)
179 |
--------------------------------------------------------------------------------
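Note on the checkpoint format expected by eval.py: the model is wrapped in nn.DataParallel before load_state_dict, so the saved weights carry a "module." prefix and would not load into a bare model with strict=True. A minimal, self-contained sketch of that load pattern follows (ToyModel and the demo.pth.tar path are hypothetical stand-ins; eval.py maps the tensors onto cfg['devices'][0] rather than the CPU).

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 1)

    def forward(self, x):
        return self.fc(x)

# Wrap before loading, as eval.py does, so state_dict keys match ("module.fc.weight", ...).
model = nn.DataParallel(ToyModel())

ckpt_path = "demo.pth.tar"  # hypothetical path
torch.save({"state_dict": model.state_dict(), "epoch": 0}, ckpt_path)

# An explicit map_location keeps the load device under control; eval.py maps onto the first configured GPU.
checkpoint = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(checkpoint["state_dict"], strict=True)
print("loaded epoch:", checkpoint["epoch"])
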
/libs/modeling/backbones.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 | from torch.nn import functional as F
5 |
6 | from .models import register_backbone
7 | from .blocks import (TransformerBlock, MaskedConv1D,
8 | LayerNorm)
9 | from .i3d import I3D
10 |
11 |
12 | @register_backbone("convEncoder")
13 | class ConvEncoderBackbone(nn.Module):
14 | """
15 |     A backbone with conv layers followed by transformer encoder blocks
16 | """
17 | def __init__(
18 | self,
19 | n_in, # input feature dimension
20 | n_embd, # embedding dimension (after convolution)
21 | n_embd_ks, # conv kernel size of the embedding network
22 | n_conv_layers,
23 | conv_dropout,
24 |         conv_ln, # whether to use layernorm
25 | n_encoder_layers,
26 | n_enc_head,
27 | attn_pdrop,
28 | proj_pdrop,
29 | path_pdrop,
30 | pos_embd
31 | ):
32 | super().__init__()
33 | self.n_in = n_in
34 |
35 | self.conv_dropout = conv_dropout
36 | self.n_encoder_layers = n_encoder_layers
37 | self.relu = nn.ReLU(inplace=True)
38 |
39 | self.pos_embd = pos_embd
40 |
41 | # embedding network using convs
42 | self.embd = nn.ModuleList()
43 | self.embd_norm = nn.ModuleList()
44 | for idx in range(n_conv_layers):
45 | n_in = n_embd if idx > 0 else n_in
46 | self.embd.append(
47 | MaskedConv1D(
48 | n_in, n_embd, n_embd_ks,
49 | stride=1, padding=n_embd_ks//2, bias=(not conv_ln)
50 | )
51 | )
52 | if conv_ln:
53 | self.embd_norm.append(LayerNorm(n_embd))
54 | else:
55 | self.embd_norm.append(nn.Identity())
56 |
57 | self.encoder_blocks = nn.ModuleList([TransformerBlock(n_embd= n_embd,
58 | n_head = n_enc_head,
59 | n_hidden= n_embd,
60 | n_out = n_embd,
61 | attn_pdrop=attn_pdrop,
62 | proj_pdrop=proj_pdrop,
63 | path_pdrop=path_pdrop,
64 | use_rel_pe = False, #only for local attn, not applicable here
65 | ) for _ in range(n_encoder_layers)])
66 |
67 | # init weights
68 | self.apply(self.__init_weights__)
69 |
70 | def __init_weights__(self, module):
71 |         # init Linear / Conv1d weights with normal(mean=0, std=0.02); bias terms are set to 0
72 | if isinstance(module, (nn.Linear, nn.Conv1d)):
73 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
74 | if module.bias is not None:
75 | torch.nn.init.constant_(module.bias, 0.)
76 |
77 | def forward(self, x, mask, video_ids = None, curr_epoch = None):
78 | # x: batch size, feature channel, sequence length,
79 | # mask: batch size, 1, sequence length (bool)
80 | B, C, T = x.size()
81 |
82 | # embedding network
83 | for idx in range(len(self.embd)):
84 | x, mask = self.embd[idx](x, mask)
85 | x = self.relu(self.embd_norm[idx](x))
86 |
87 | if idx != len(self.embd) - 1:
88 |                 x = F.dropout(x, p=self.conv_dropout, training=self.training)
89 |
90 | B, C, T = x.size()
91 |
92 | if self.pos_embd is not None:
93 | x = x + self.pos_embd[:, :, :T].cuda() * mask.to(x.dtype)
94 |
95 |
96 | for idx in range(self.n_encoder_layers):
97 | x, mask = self.encoder_blocks[idx](x, mask)
98 |
99 | return x, mask
100 |
101 |
102 |
103 | @register_backbone("i3d")
104 | class I3D_feat_extractor(nn.Module):
105 | def __init__(self, I3D_ckpt_path, finetune):
106 | super(I3D_feat_extractor, self).__init__()
107 | self.i3d = I3D(num_classes=400, modality='rgb', dropout_prob=0.5)
108 |
109 | if I3D_ckpt_path is not None:
110 | print("loading I3D weights from: ", I3D_ckpt_path)
111 | self.i3d.load_state_dict(torch.load(I3D_ckpt_path))
112 |
113 | self.finetune = finetune
114 | self.se = False
115 |
116 | self.avg_pool = torch.nn.AdaptiveAvgPool3d((1, 1, 1))
117 |
118 |
119 | def get_feature_dim(self):
120 | return self.i3d.get_logits_dim()
121 |
122 | def forward(self, videos, video_ids = None):
123 |
124 | if videos.dim() == 6:
125 | #B x Nwin x T x C x H x W
126 | B, N_Win, T, C, H, W = videos.size()
127 |
128 | #B x N_Win x T x C x H x W -> B*N_Win x T x C x H x W
129 | videos_reshaped = videos.reshape(B*N_Win, T, C, H, W)
130 |
131 | #B*N_Win, T, C, H, W -> B*N_Win, C, T, H, W
132 | videos_tensor = videos_reshaped.permute(0, 2, 1, 3, 4)
133 |
134 | if videos.dim() == 7:
135 | #B x Nwin x T x C x H x W
136 | B, N_Win, crops, T, C, H, W = videos.size()
137 |
138 | #B x N_Win x T x C x H x W -> B*N_Win*crops x T x C x H x W
139 | videos_reshaped = videos.reshape(B*N_Win*crops, T, C, H, W)
140 |
141 | #B*N_Win*crops, T, C, H, W -> B*N_Win*crops, C, T, H, W
142 | videos_tensor = videos_reshaped.permute(0, 2, 1, 3, 4)
143 |
144 | if not self.finetune:
145 | with torch.no_grad():
146 | self.i3d.eval()
147 | video_feature = self.i3d(videos_tensor)
148 | else:
149 | video_feature = self.i3d(videos_tensor)
150 |
151 |
152 | #Video -> B*N_Win, C
153 | video_feature = self.avg_pool(video_feature).squeeze()
154 |
155 |
156 | if videos.dim() == 6:
157 | #Split into batches (B x T x C), because N_Win is the final T
158 | batch_feats = video_feature.reshape(B, N_Win, -1)
159 |
160 | if videos.dim() == 7:
161 | batch_feats = video_feature.reshape(B, N_Win, crops, -1)
162 |
163 | batch_feats = batch_feats.mean(axis=2)
164 |
165 | #B x T x C -> B x C x T
166 | batch_feats = batch_feats.permute(0, 2, 1)
167 |
168 | return batch_feats
169 |
170 |
171 |
--------------------------------------------------------------------------------
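The windowed-video path in I3D_feat_extractor above boils down to two reshapes around the (optionally frozen) I3D. A shape-only sketch, with a random tensor standing in for the pooled I3D output and all sizes chosen arbitrarily:

import torch

B, N_win, T, C, H, W = 2, 10, 16, 3, 64, 64   # hypothetical sizes
videos = torch.randn(B, N_win, T, C, H, W)

# Fold the window dimension into the batch, then move channels before time:
# (B, N_win, T, C, H, W) -> (B*N_win, C, T, H, W), the layout I3D expects.
clips = videos.reshape(B * N_win, T, C, H, W).permute(0, 2, 1, 3, 4)

# Stand-in for self.i3d(...) followed by adaptive average pooling: one vector per clip.
feat_dim = 1024
clip_feats = torch.randn(B * N_win, feat_dim)

# Unfold back to (B, N_win, feat_dim); N_win then serves as the temporal axis for the
# downstream conv/transformer backbone, which consumes (B, C, T).
batch_feats = clip_feats.reshape(B, N_win, feat_dim).permute(0, 2, 1)
print(clips.shape, batch_feats.shape)  # (20, 3, 16, 64, 64) and (2, 1024, 10)
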
/tools/finediving/finediving_t5xxl_text_embed_extraction.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from dataclasses import dataclass, field
3 | from typing import List
4 |
5 | # import open_clip
6 | # from open_clip import tokenizer
7 | import torch
8 | import numpy as np
9 | import torch.nn.functional as F
10 | from transformers import AutoTokenizer, AutoModel
11 |
12 |
13 | """
14 | Format in the annotation pkl file
15 | {'Forward': 1,
16 | 'Back': 2,
17 | 'Reverse': 3,
18 | 'Inward': 4,
19 | 'Arm.Forward': 5,
20 | 'Arm.Back': 6,
21 | 'Arm.Reverse': 7,
22 | '1 Som.Pike': 12,
23 | '1.5 Soms.Pike': 13,
24 | '2 Soms.Pike': 14,
25 | '2.5 Soms.Pike': 15,
26 | '3 Soms.Pike': 16,
27 | '3.5 Soms.Pike': 17,
28 | '4.5 Soms.Pike': 19,
29 | '1.5 Soms.Tuck': 21,
30 | '2 Soms.Tuck': 22,
31 | '2.5 Soms.Tuck': 23,
32 | '3 Soms.Tuck': 24,
33 | '3.5 Soms.Tuck': 25,
34 | '4.5 Soms.Tuck': 27,
35 | '0.5 Twist': 29,
36 | '1 Twist': 30,
37 | '1.5 Twists': 31,
38 | '2 Twists': 32,
39 | '2.5 Twists': 33,
40 | '3 Twists': 34,
41 | '3.5 Twists': 35,
42 | 'Entry': 36,
43 | '0.5 Som.Pike': 37}
44 | """
45 |
46 |
47 | @dataclass
48 | class DataClass:
49 | sub_action_descriptions: List[str] = field(
50 | default_factory=lambda: [
51 | "Forward: A diving technique involving a front-facing takeoff and entry.",
52 | "Back: A diving technique involving a back-facing takeoff and entry.",
53 | "Reverse: A diving technique involving a back-facing takeoff and entry while rotating forward.",
54 | "Inward: A diving technique involving a front-facing takeoff and entry while rotating backwards.",
55 | "Arm Forward: A diving technique involving a front-facing takeoff and entry with arms extended and hands meeting above the head.",
56 | "Arm Back: A diving technique involving a back-facing takeoff and entry with arms extended and hands meeting above the head.",
57 | "Arm Reverse: A diving technique involving a back-facing takeoff and entry with arms extended and hands meeting above the head while rotating forward.",
58 | "0.5 Somersault Pike: A diving technique involving a take-off with half a somersault in the pike position before entering the water.",
59 | "1 Somersault Pike: A diving technique involving a takeoff and rotating forward to form a pike position with one somersault.",
60 | "1.5 Somersaults Pike: A diving technique involving a takeoff and rotating forward to form a pike position with one and a half somersaults.",
61 | "2 Somersaults Pike: A diving technique involving a takeoff and rotating forward to form a pike position with two somersaults.",
62 | "2.5 Somersaults Pike: A diving technique involving a takeoff and rotating forward to form a pike position with two and a half somersaults.",
63 | "3 Somersaults Pike: A diving technique involving a takeoff and rotating forward to form a pike position with three somersaults.",
64 | "3.5 Somersaults Pike: A diving technique involving a takeoff and rotating forward to form a pike position with three and a half somersaults.",
65 | "4.5 Somersaults Pike: A diving technique involving a takeoff and rotating forward to form a pike position with four and a half somersaults.",
66 | "1.5 Somersaults Tuck: A diving technique involving a takeoff and rotating forward to bend at the waist with one and a half somersaults.",
67 | "2 Somersaults Tuck: A diving technique involving a takeoff and rotating forward to bend at the waist with two somersaults.",
68 | "2.5 Somersaults Tuck: A diving technique involving a takeoff and rotating forward to bend at the waist with two and a half somersaults.",
69 | "3 Somersaults Tuck: A diving technique involving a takeoff and rotating forward to bend at the waist with three somersaults.",
70 | "3.5 Somersaults Tuck: A diving technique involving a takeoff and rotating forward to bend at the waist with three and a half somersaults.",
71 | "4.5 Somersaults Tuck: A diving technique involving a takeoff and rotating forward to bend at the waist with four and a half somersaults.",
72 | "0.5 Twist: A diving technique involving a takeoff and half a twist before entering the water.",
73 | "1 Twist: A diving technique involving a takeoff and one full twist before entering the water.",
74 | "1.5 Twists: A diving technique involving a takeoff and one and a half twists before entering the water.",
75 | "2 Twists: A diving technique involving a takeoff and two full twists before entering the water.",
76 | "2.5 Twists: A diving technique involving a takeoff and two and a half twists before entering the water.",
77 | "3 Twists: A diving technique involving a takeoff with three twists before entering the water.",
78 | "3.5 Twists: A diving technique involving a takeoff with three and a half twists before entering the water.",
79 |             "Entry: A diving technique involving an entry into the water, typically performed at the end of a dive."
80 | ]
81 | )
82 |
83 | class GetTextEmbeddings:
84 | def __init__(self, output_path) -> None:
85 | """
86 | Args:
87 |             output_path (str): path to save the embeddings
88 | """
89 | self.data = DataClass()
90 | self.output_path = output_path
91 |
92 | def get_huggingface_embeddings(self):
93 | def average_pool(
94 | last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
95 | ) -> torch.Tensor:
96 | last_hidden = last_hidden_states.masked_fill(
97 | ~attention_mask[..., None].bool(), 0.0
98 | )
99 | return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
100 |
101 | input_texts = [
102 | "This is a " + desc.replace(": ", " action: ")
103 | for desc in self.data.sub_action_descriptions
104 | ]
105 |
106 | model_id="google/flan-t5-xxl" # "intfloat/e5-large"
107 | tokenizer = AutoTokenizer.from_pretrained(model_id)
108 | model = AutoModel.from_pretrained(model_id)
109 |
110 |
111 | # Tokenize the input texts
112 | batch_dict = tokenizer(
113 | input_texts,
114 | max_length=512,
115 | padding=True,
116 | truncation=True,
117 | return_tensors="pt",
118 | )
119 |
120 | with torch.no_grad():
121 |             outputs = model(input_ids=batch_dict['input_ids'], decoder_input_ids=batch_dict['input_ids'])
122 | embeddings = average_pool(
123 | outputs.last_hidden_state, batch_dict["attention_mask"]
124 | )
125 | np.save(self.output_path, embeddings.detach().cpu().numpy())
126 | return embeddings.detach().cpu().numpy()
127 |
128 |
129 | if __name__ == "__main__":
130 | parser = argparse.ArgumentParser()
131 |     parser.add_argument(
132 |         "--output_path",
133 |         default="FineDiving_t5_xxl_text_embeddings.npy",
134 |         help="path to save the embeddings",
135 |     )
136 |     args = parser.parse_args()
137 | 
138 |     get_text_embeddings = GetTextEmbeddings(args.output_path)
139 | text_embeddings = get_text_embeddings.get_huggingface_embeddings()
140 | print("Text Embed shape: ", text_embeddings.shape)
141 |
142 |
--------------------------------------------------------------------------------
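The average_pool helper above is a length-normalized mean over non-padded tokens. A self-contained sketch with toy tensors (shapes and values are illustrative only):

import torch

last_hidden = torch.randn(2, 5, 4)                # (batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])  # second sequence has 2 padded tokens

# Zero out padded positions, then divide each sequence's sum by its true length,
# matching average_pool() in the script above.
masked = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0)
embeddings = masked.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
print(embeddings.shape)  # torch.Size([2, 4])
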
/libs/datasets/mtl_aqa.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import numpy as np
4 | from PIL import Image
5 |
6 | import torch
7 | from torch.utils.data import Dataset
8 | from torchvision import transforms
9 | from torch.nn import functional as F
10 |
11 | from .datasets import register_dataset
12 |
13 | import pickle
14 |
15 | @register_dataset("mtl_aqa")
16 | class MTL_AQA(Dataset):
17 | def __init__(self,
18 | is_training,
19 | split,
20 | data_root,
21 | video_dir,
22 | label_file,
23 | train_datafile,
24 | test_datafile,
25 | stride,
26 | window_size,
27 | frames_per_clip,
28 | max_seq_len,
29 | input_frame_size = None,
30 | crop_frame_size = None,
31 | with_dd = True,
32 | three_judge_score_scaling = False,
33 | use_feats = None,
34 | feat_dir = None,
35 | ):
36 |
37 | self.subset = split
38 | self.is_training = is_training
39 |
40 | self.use_feats = use_feats
41 |
42 | self.crop_frame_size = crop_frame_size
43 |
44 | self.with_dd = with_dd
45 | self.three_judge_score_scaling = three_judge_score_scaling
46 |
47 |
48 |
49 | if self.subset == 'test':
50 | self.transforms = transforms.Compose(
51 | [
52 | transforms.Resize(input_frame_size),
53 | transforms.CenterCrop(crop_frame_size),
54 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
55 | ]
56 | )
57 |
58 |
59 | elif self.subset == 'train':
60 | self.transforms = transforms.Compose(
61 | [
62 | transforms.RandomHorizontalFlip(p=0.5),
63 | transforms.Resize(input_frame_size),
64 | transforms.RandomCrop(crop_frame_size),
65 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
66 | ]
67 | )
68 |
69 | else:
70 | raise ValueError("subset should be train or test")
71 |
72 | print(f"subset: {self.subset}, is_training: {self.is_training}")
73 |
74 | self.pil_2_tensor = transforms.ToTensor()
75 | self.stride = stride
76 | self.window_size = window_size
77 | self.frames_per_clip = frames_per_clip
78 | self.max_seq_len = max_seq_len
79 |
80 | # file paths
81 | self.data_root = data_root
82 | self.video_dir = os.path.join(self.data_root, video_dir)
83 |
84 | if self.use_feats:
85 | self.feat_dir = os.path.join(self.data_root, feat_dir)
86 |
87 | train_datafile_path = os.path.join(self.data_root, "info", train_datafile)
88 | test_datafile_path = os.path.join(self.data_root,"info", test_datafile)
89 |
90 | self.data_anno = self.read_pickle(os.path.join(self.data_root, "info", label_file))
91 |
92 | if self.subset == 'test':
93 | self.datalist = self._load_annotation(test_datafile_path)
94 | elif self.subset == 'train':
95 | self.datalist = self._load_annotation(train_datafile_path)
96 | else:
97 | raise ValueError("subset should be train or test")
98 |
99 |
100 | def _load_annotation(self, pkl_file):
101 |
102 | data_list = self.read_pickle(pkl_file)
103 | processed_data_list = []
104 |
105 | for video_id in data_list:
106 | data = {}
107 | data['video_id'] = video_id
108 |
109 | data['actions_present'] = self.get_actions_present(self.data_anno[video_id])
110 |
111 | data['final_score'] = self.data_anno[video_id]["final_score"]
112 |
113 | data['difficulty'] = self.data_anno[video_id]["difficulty"]
114 |
115 | data['gt_score'] = data['final_score']
116 |
117 | processed_data_list.append(data)
118 |
119 |
120 | return processed_data_list
121 |
122 | def get_actions_present(self, anno):
123 |
124 | """
125 |         armstand: No, Yes
126 |         rotation_type: Inward, Reverse, Backward, Forward
127 | 
128 |         positions: Free, Tuck, Pike
129 |         SS (somersaults): 9 unique values, including 0 for no somersault
130 |         TW (twists): 8 unique values, including 0 for no twist
131 | 
132 |         indexing: armstand = 0, rotation = 1-4, position = 5-7, SS = 8-15, TW = 16-22, entry = 23
133 | """
134 | if anno["armstand"] != 0:
135 | armstand_idx = 0
136 | else:
137 | armstand_idx = "MISSING"
138 |
139 | rotation_type_idx = 1 + anno["rotation_type"]
140 |
141 | pos_idx = 5 + anno["position"]
142 |
143 | if anno["ss_no"] != 0:
144 | ss_idx = 7 + anno["ss_no"]
145 | else:
146 | ss_idx = "MISSING"
147 |
148 | if anno["tw_no"] != 0:
149 | tw_idx = 15 + anno["tw_no"]
150 | else:
151 | tw_idx = "MISSING"
152 |
153 | actions_present = []
154 | for idx in [pos_idx, armstand_idx, rotation_type_idx, ss_idx, tw_idx]:
155 | if idx != "MISSING":
156 | actions_present.append(idx)
157 |
158 | #action for entry
159 | actions_present.append(23)
160 |
161 | return actions_present
162 |
163 | def read_pickle(self, pickle_path):
164 | with open(pickle_path,'rb') as f:
165 | pickle_data = pickle.load(f)
166 | return pickle_data
167 |
168 |
169 | def load_video(self, video_file_name):
170 | first, second = video_file_name[0], video_file_name[1]
171 |
172 | if first < 10:
173 | first = "0"+str(first)
174 |
175 | if second < 10:
176 | second = "0"+str(second)
177 | image_list = sorted((glob.glob(os.path.join(self.video_dir, str(first)+"_"+str(second), '*.jpg'))))
178 |
179 | if self.is_training:
180 | #start from 0-1-2-3
181 | start_idx = torch.randint(0,4,[1]).item()
182 | image_list = image_list[start_idx:start_idx+self.frames_per_clip]
183 |
184 | video = [Image.open(image) for image in image_list]
185 | video = [transforms.ToTensor()(frame) for frame in video]
186 | video = torch.stack(video)
187 | video = self.transforms(video)
188 |
189 | #extract windows
190 | start_idx = [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95]
191 |
192 | video_pack = torch.stack([video[i:i+self.window_size,:,:,:] for i in start_idx])
193 |
194 | return video_pack
195 |
196 | def load_feats(self, video_file_name):
197 | feats = np.load(os.path.join(self.feat_dir, str(video_file_name[0])+"-abr-"+str(video_file_name[1])+'.npz'))["arr_0"]
198 | feats = torch.from_numpy(feats).float()
199 |
200 | return feats
201 |
202 |
203 | def __getitem__(self, index):
204 | video_data = self.datalist[index]
205 |
206 | video_id = video_data["video_id"]
207 |
208 | data = {"video_id": video_id}
209 | if self.use_feats:
210 | data['feats'] = self.load_feats(video_id)
211 | else:
212 | data['window_frames'] = self.load_video(video_id)
213 |
214 | data["video_name"] = video_id
215 | data["difficulty"] = video_data["difficulty"]
216 | data["actions_present"] = video_data["actions_present"]
217 |
218 | target = video_data["gt_score"]
219 | data["gt_score"] = video_data["gt_score"]
220 |
221 | if self.with_dd:
222 | target = target / data['difficulty']
223 |
224 | if self.three_judge_score_scaling:
225 | target = target / 3.0
226 |
227 | data["target"] = target
228 |
229 | return data
230 |
231 |
232 | def __len__(self):
233 | return len(self.datalist)
234 |
--------------------------------------------------------------------------------
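load_video() above packs each clip into overlapping, fixed-stride windows before the I3D feature extractor. A shape-only sketch of that slicing; the frame count and window size below are assumptions, the real values come from the dataset config:

import torch

frames, window_size, C, H, W = 103, 8, 3, 112, 112   # hypothetical sizes
video = torch.randn(frames, C, H, W)

# Same fixed window starts as load_video(): one window every 5 frames.
start_idx = list(range(0, 100, 5))   # [0, 5, ..., 95]
video_pack = torch.stack([video[i:i + window_size] for i in start_idx])
print(video_pack.shape)  # (20, 8, 3, 112, 112) -> N_win x window_size x C x H x W
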
/libs/modeling/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from einops import repeat
6 |
7 |
8 | """
9 | ranking and sparsity loss, modified from TPT
10 | Uses broadcasted operations to calculate the loss in a fully vectorized manner.
11 | https://github.com/baiyang4/aqa_tpt/blob/cf6d1631ec53c676f108926fc5480afe7c47ff56/train_pairencode1_decoder_1selfatt_self8head_ffn_sp_new.py#L289
12 | """
13 | def get_att_loss(logits_all, gt_actions,ranking_loss_wt, sparsity_loss_wt, hinge_loss):
14 |
15 | # logits_all: B, Q, T
16 | B, Q, T = logits_all.size()
17 | #gt_actions One hot encoded: B, Q
18 |
19 | logits_all = logits_all.permute(0, 2, 1) #.transpose(-1,-2)
20 | softmax_dim = logits_all.shape[1]
21 | temp_idx = repeat(torch.arange(1, softmax_dim + 1), 't -> b t k', b=logits_all.shape[0], k=logits_all.shape[-1]).float().to(logits_all.device)
22 | cluster_mean = (logits_all * temp_idx).sum(1)
23 | var = (torch.abs(temp_idx - repeat(cluster_mean, 'b k -> b t k', t=softmax_dim)) * logits_all).sum(1)
24 |
25 | if ranking_loss_wt == 0.0:
26 | return 0.0, var.mean()
27 |
28 | # Extract active clusters based on gt_actions for all samples in batch
29 | active_indices = gt_actions.nonzero(as_tuple=True) # Indices where gt_actions is 1
30 | active_clusters = cluster_mean[active_indices] # Extract active clusters
31 |
32 | # Calculate hinge loss for ranking the active clusters in a fully vectorized manner
33 | cluster_counts = gt_actions.sum(dim=1).long() # Number of active clusters per batch
34 | max_clusters = cluster_counts.max().item()
35 |
36 | # Pad active clusters to have the same length per batch
37 | padded_active_clusters = torch.zeros((B, max_clusters), device=cluster_mean.device)
38 | mask = torch.arange(max_clusters, device=cluster_mean.device).unsqueeze(0) < cluster_counts.unsqueeze(1)
39 | padded_active_clusters[mask] = active_clusters
40 |
41 | # Create shifted versions of the active clusters for comparison
42 | current_clusters = padded_active_clusters[:, :-1] # B, max_clusters - 1
43 | next_clusters = padded_active_clusters[:, 1:] # B, max_clusters - 1
44 | valid_pairs_mask = mask[:, :-1] & mask[:, 1:] # Mask to identify valid pairs for loss calculation
45 |
46 | # Apply the valid pairs mask to current and next clusters
47 | valid_current_clusters = current_clusters[valid_pairs_mask]
48 | valid_next_clusters = next_clusters[valid_pairs_mask]
49 |
50 | # Calculate hinge loss for valid pairs
51 | ones = torch.ones_like(valid_current_clusters)
52 | loss = hinge_loss(valid_next_clusters, valid_current_clusters, ones).mean()
53 |
54 | # Add boundary conditions (first and last clusters only once per batch)
55 | first_clusters = padded_active_clusters[:, 0] # B
56 | last_clusters = padded_active_clusters[torch.arange(B), cluster_counts - 1] # B
57 | boundary_mask = cluster_counts > 0 # Mask to identify batches with at least one active cluster
58 |
59 | loss += hinge_loss(first_clusters[boundary_mask], torch.ones_like(first_clusters[boundary_mask]), torch.ones_like(first_clusters[boundary_mask])).sum()
60 | loss += hinge_loss(torch.ones_like(last_clusters[boundary_mask]) * softmax_dim, last_clusters[boundary_mask], torch.ones_like(last_clusters[boundary_mask])).sum()
61 |
62 | return loss, var.mean()
63 |
64 |
65 |
66 | def criterion(model_output, target, difficulties, inp_gt_actions, loss_weights, with_dd=True, three_judge_score_scaling=False):
67 |
68 | batched_gt_judge_scores = target
69 |
70 | #B, samples, 1
71 | batched_pred_quality_scores = model_output["all_sample_outputs"]
72 |
73 | difficulties = difficulties.cuda()
74 |
75 |     # note: difficulty (DD) scaling below may need to be disabled when training on JIGSAWS (TODO)
76 | if with_dd:
77 | batched_gt_judge_scores = batched_gt_judge_scores * difficulties
78 | batched_pred_quality_scores = batched_pred_quality_scores * difficulties.unsqueeze(-1).unsqueeze(-1)
79 |
80 | if three_judge_score_scaling:
81 | batched_gt_judge_scores = batched_gt_judge_scores * 3.0
82 | batched_pred_quality_scores = batched_pred_quality_scores * 3.0
83 |
84 | if loss_weights["phase_vib"] != 0:
85 |
86 | gt_actions = inp_gt_actions.detach()
87 | gt_actions = gt_actions.reshape(-1)
88 |
89 | phase_mean_emb, phase_var_emb = model_output["phase_mean_emb"], model_output["phase_var_emb"]
90 |
91 | batch_size, num_phases, channels = phase_mean_emb.shape
92 | phase_mean_emb = phase_mean_emb.reshape(-1, channels)
93 | phase_var_emb = phase_var_emb.reshape(-1, channels)
94 |
95 | #mask out the non-action phases
96 | if num_phases > 1:
97 | phase_mean_emb = phase_mean_emb[gt_actions.nonzero().squeeze().long(),:]
98 | phase_var_emb = phase_var_emb[gt_actions.nonzero().squeeze().long(),:]
99 |
100 | #B*phases x channels -> B*phases
101 | phase_vib_loss = torch.sum(torch.pow(phase_mean_emb, 2) + phase_var_emb - torch.log(phase_var_emb) - 1.0, dim=1) * 0.5
102 |
103 | phase_vib_loss = phase_vib_loss.sum()
104 | phase_vib_loss = phase_vib_loss / (batch_size * channels)
105 |
106 |
107 | # batched_pred_quality_score_logits shape -> batch_size x num_samples x 1
108 | # batched_gt_judge_score_class shape -> B x 1
109 | B, num_samples, bins = batched_pred_quality_scores.shape
110 |
111 | if loss_weights["loss"] == "mse":
112 | ground_truth_expanded = batched_gt_judge_scores.unsqueeze(1).expand(-1, num_samples)
113 |
114 | batched_pred_quality_scores_flat = batched_pred_quality_scores.reshape(-1)
115 | ground_truth_flat = ground_truth_expanded.reshape(-1)
116 |
117 | quality_score_loss = nn.MSELoss(reduction='mean')(batched_pred_quality_scores_flat, ground_truth_flat)
118 |
119 | elif loss_weights["loss"] =="xentropy":
120 | ground_truth_expanded = batched_gt_judge_scores.unsqueeze(1).expand(-1, num_samples)
121 |
122 | batched_pred_quality_scores_flat = batched_pred_quality_scores.reshape(B, bins)
123 | ground_truth_flat = ground_truth_expanded.reshape(B).long()
124 | quality_score_loss = nn.CrossEntropyLoss(reduction="mean")(batched_pred_quality_scores_flat, ground_truth_flat)
125 | else:
126 | raise ValueError("Invalid loss function")
127 |
128 | if loss_weights["ranking"] > 0.0 or loss_weights["sparsity"] > 0.0:
129 | # bz x heads x num_queries x T
130 | # Flatten the logits and gt_actions across batch, heads, and decoder layers dimensions
131 | batch_size, num_heads, num_queries, num_clips = model_output["cross_attn"][0].shape
132 | num_decoder_layers = len(model_output["cross_attn"])
133 |
134 | # Concatenate across decoder layers and heads to flatten
135 | logits_all_flat = torch.cat([model_output["cross_attn"][decoder_layer_idx][:, head_idx, :, :]
136 | for decoder_layer_idx in range(num_decoder_layers)
137 | for head_idx in range(num_heads)], dim=0) # Shape: (B * num_heads * num_decoder_layers), Q, T
138 |
139 | # Repeat the ground truth actions for each head and decoder layer
140 | gt_actions_flat = inp_gt_actions.detach().repeat(num_heads * num_decoder_layers, 1) # Shape: (B * num_heads * num_decoder_layers), Q
141 |
142 | ranking_loss, sparsity_loss = get_att_loss(logits_all_flat,
143 | gt_actions_flat,
144 | ranking_loss_wt=loss_weights["ranking"],
145 | sparsity_loss_wt=loss_weights["sparsity"],
146 | hinge_loss=nn.MarginRankingLoss(1.0))
147 |
148 |
149 | loss_dict, final_loss = {}, 0.0
150 | valid_loss_keys = ["quality_score", "phase_vib", "ranking", "sparsity"]
151 |
152 | for key, value in loss_weights.items():
153 | if key in valid_loss_keys and isinstance(value, (int, float)) and value != 0:
154 | loss_dict[f"{key}_loss"] = value * locals()[f"{key}_loss"]
155 | final_loss += loss_dict[f"{key}_loss"]
156 |
157 | loss_dict["final_loss"] = final_loss
158 |
159 | return loss_dict
160 |
--------------------------------------------------------------------------------
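The phase_vib term in criterion() is the closed-form KL divergence between each phase's Gaussian N(mu, var) and a standard normal prior. A small sketch that reproduces the formula on toy tensors and cross-checks it against torch.distributions (shapes are illustrative):

import torch
from torch.distributions import Normal, kl_divergence

phase_mean_emb = torch.randn(6, 4)        # (batch*phases, channels), toy values
phase_var_emb = torch.rand(6, 4) + 0.1    # strictly positive variances

# Closed-form KL(N(mu, var) || N(0, I)) summed over channels, as in criterion().
vib = 0.5 * torch.sum(phase_mean_emb ** 2 + phase_var_emb - torch.log(phase_var_emb) - 1.0, dim=1)

# Cross-check against the library implementation of the same quantity.
ref = kl_divergence(Normal(phase_mean_emb, phase_var_emb.sqrt()), Normal(0.0, 1.0)).sum(dim=1)
print(torch.allclose(vib, ref, atol=1e-5))  # True
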
/libs/datasets/finediving.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import glob
4 | import numpy as np
5 | from PIL import Image
6 |
7 | import torch
8 | from torch.utils.data import Dataset
9 | from torchvision import transforms
10 | from torch.nn import functional as F
11 |
12 | from .datasets import register_dataset
13 |
14 | import pickle
15 |
16 |
17 | @register_dataset("finediving")
18 | class FineDiving(Dataset):
19 | def __init__(self,
20 | is_training,
21 | split,
22 | data_root,
23 | video_dir,
24 | label_file,
25 | coarse_label_file,
26 | train_datafile,
27 | test_datafile,
28 | stride,
29 | window_size,
30 | frames_per_clip,
31 | max_seq_len,
32 | input_frame_size = None,
33 | crop_frame_size = 224,
34 | with_dd = True,
35 | three_judge_score_scaling = False,
36 | use_feats = None,
37 | feat_dir = None
38 | ):
39 |
40 | self.subset = split
41 | self.is_training = is_training
42 |
43 | self.use_feats = use_feats
44 | self.crop_frame_size = crop_frame_size
45 |
46 | self.with_dd = with_dd
47 | self.three_judge_score_scaling = three_judge_score_scaling
48 |
49 | if self.subset == 'test':
50 | self.transforms = transforms.Compose(
51 | [
52 | transforms.Resize(input_frame_size),
53 | transforms.CenterCrop(crop_frame_size),
54 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
55 | ]
56 | )
57 |
58 | elif self.subset == 'train':
59 | self.transforms = transforms.Compose(
60 | [
61 | transforms.RandomHorizontalFlip(p=0.5),
62 | transforms.Resize(input_frame_size),
63 | transforms.RandomCrop(crop_frame_size),
64 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
65 | ]
66 | )
67 |
68 | else:
69 | raise ValueError("subset should be train or test")
70 |
71 | print(f"subset: {self.subset}, is_training: {self.is_training}")
72 |
73 |
74 | self.pil_2_tensor = transforms.ToTensor()
75 |
76 | self.stride = stride
77 | self.window_size = window_size
78 | self.frames_per_clip = frames_per_clip
79 | self.max_seq_len = max_seq_len
80 |
81 | # file paths
82 | self.data_root = data_root
83 | self.video_dir = os.path.join(self.data_root, video_dir)
84 |
85 | if self.use_feats:
86 | self.feat_dir = os.path.join(self.data_root, feat_dir)
87 |
88 | self.label_path = os.path.join(self.data_root,"Annotations", label_file)
89 | self.coarse_label_path = os.path.join(self.data_root,"Annotations", coarse_label_file)
90 |
91 | train_datafile_path = os.path.join(self.data_root, "Annotations",train_datafile)
92 | test_datafile_path = os.path.join(self.data_root,"Annotations", test_datafile)
93 |
94 |         # mapping from sub-action label to index; see tools/finediving/finediving_t5xxl_text_embed_extraction.py for the corresponding text descriptions
95 | self.label_to_idx_dict = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 37: 7, 12: 8, 13: 9, 14: 10, 15: 11,
96 | 16: 12, 17: 13, 19: 14, 21: 15, 22: 16, 23: 17, 24: 18, 25: 19, 27: 20, 29: 21,
97 | 30: 22, 31: 23, 32: 24, 33: 25, 34: 26, 35: 27, 36: 28}
98 |
99 |
100 | self.data_anno = self.read_pickle(self.label_path)
101 | self.coarse_data_anno = self.read_pickle(self.coarse_label_path)
102 |
103 | if self.subset == 'test':
104 | self.datalist = self._load_annotations(test_datafile_path)
105 | elif self.subset == 'train':
106 | self.datalist = self._load_annotations(train_datafile_path)
107 | else:
108 | raise ValueError("subset should be train or test")
109 |
110 |
111 | def _load_annotations(self, datafile_path):
112 | data_list = self.read_pickle(datafile_path)
113 |
114 | processed_data_list = []
115 |
116 | for video_id in data_list:
117 | data = {}
118 | data['video_id'] = video_id
119 |
120 | data['final_score'] = self.data_anno[video_id]["dive_score"]
121 |
122 | data['difficulty'] = self.data_anno[video_id]["difficulty"]
123 |
124 | data['gt_score'] = data['final_score']
125 |
126 | processed_data_list.append(data)
127 |
128 | return processed_data_list
129 |
130 | def load_video(self, video_file_name):
131 | image_list = sorted((glob.glob(os.path.join(self.video_dir, video_file_name[0], str(video_file_name[1]), '*.jpg'))))
132 |
133 | if self.is_training:
134 | #start from 0-1-2-3
135 | start_idx = torch.randint(0,4,[1]).item()
136 |
137 | # randomly drop the end frames
138 | end_idx = torch.randint(0,4,[1]).item()
139 | image_list = image_list[start_idx:start_idx + len(image_list) - end_idx]
140 |
141 | start_frame = int(image_list[0].split("/")[-1][:-4])
142 | end_frame = int(image_list[-1].split("/")[-1][:-4])
143 |
144 | frame_list = np.linspace(start_frame, end_frame, self.frames_per_clip).astype(np.int32)
145 | image_frame_idx = [frame_list[i] - start_frame for i in range(self.frames_per_clip)]
146 | video = [Image.open(image_list[image_frame_idx[i]]) for i in range(self.frames_per_clip)]
147 | video = [transforms.ToTensor()(frame) for frame in video]
148 | video = torch.stack(video)
149 | video = self.transforms(video)
150 |
151 | #extract windows
152 | start_idx = list(range(0, 90, 10))
153 |
154 | video_pack = torch.stack([video[i:i+self.window_size,:,:,:] for i in start_idx])
155 |
156 | frames_labels = self.data_anno[video_file_name]["frames_labels"]
157 | #mapping labels to label indices
158 | frames_labels = [self.label_to_idx_dict[l] for l in frames_labels]
159 |
160 | return video_pack, frames_labels
161 |
162 |
163 |
164 | def load_feats(self, video_file_name):
165 | feats = np.load(os.path.join(self.feat_dir, video_file_name[0]+"-abr-"+str(video_file_name[1])+'.npz'))["feats"]
166 |
167 | feats = feats.transpose(1,0)
168 |
169 | feats = torch.from_numpy(feats).float()
170 |
171 | frames_labels = self.data_anno[video_file_name]["frames_labels"]
172 |
173 | #mapping labels to label indices
174 | frames_labels = [self.label_to_idx_dict[l] for l in frames_labels]
175 |
176 | return feats, frames_labels
177 |
178 |
179 | def read_pickle(self, pickle_path):
180 | with open(pickle_path,'rb') as f:
181 | pickle_data = pickle.load(f)
182 | return pickle_data
183 |
184 |
185 | def __getitem__(self, index):
186 | video_data = self.datalist[index]
187 |
188 | video_id = video_data["video_id"]
189 |
190 | data = {"video_id": video_id}
191 | if self.use_feats:
192 | data['feats'], frame_labels = self.load_feats(video_id)
193 | else:
194 | data['window_frames'], frame_labels = self.load_video(video_id)
195 |
196 |
197 | data["video_name"] = video_id
198 |
199 | data['difficulty'] = video_data["difficulty"]
200 |
201 | target = video_data["gt_score"]
202 | data["gt_score"] = video_data["gt_score"]
203 |
204 | if self.with_dd:
205 | target = target / data['difficulty']
206 |
207 | if self.three_judge_score_scaling:
208 | target = target / 3.0
209 |
210 | data["target"] = target
211 |
212 | #if entry is missing, add 28
213 | if 28 not in frame_labels:
214 | frame_labels.append(28)
215 |
216 | #take only the presence of actions
217 | data['actions_present'] = sorted(list(set(frame_labels)))
218 |
219 | return data
220 |
221 | def __len__(self):
222 | return len(self.datalist)
223 |
--------------------------------------------------------------------------------
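FineDiving clips vary in length on disk, so load_video() resamples each one to a fixed frames_per_clip with np.linspace before windowing. A tiny sketch of that index arithmetic (the frame numbers and clip length here are made up):

import numpy as np

start_frame, end_frame, frames_per_clip = 42, 137, 96   # hypothetical clip

# Uniformly spaced absolute frame numbers between the first and last frame...
frame_list = np.linspace(start_frame, end_frame, frames_per_clip).astype(np.int32)

# ...converted into offsets into the sorted image_list, exactly as in the dataset class.
image_frame_idx = [frame_list[i] - start_frame for i in range(frames_per_clip)]
print(len(image_frame_idx), image_frame_idx[0], image_frame_idx[-1])  # 96 0 95
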
/train.py:
--------------------------------------------------------------------------------
1 | # python imports
2 | import argparse, os
3 | import time
4 | import datetime
5 | from pprint import pprint
6 | import sys, pickle
7 |
8 | # torch imports
9 | import torch
10 | import torch.nn as nn
11 | import torch.utils.data
12 | # for visualization
13 | from torch.utils.tensorboard import SummaryWriter
14 |
15 | # our code
16 | from libs.core import load_config
17 | from libs.datasets import make_dataset, make_data_loader
18 | from libs.modeling import make_meta_arch
19 | from libs.utils import (Logger, train_one_epoch, valid_one_epoch,
20 | save_checkpoint, make_optimizer, make_scheduler,
21 | fix_random_seed)
22 |
23 | def override_cfg_params(cfg, args):
24 | if args.gpu is not None:
25 | cfg["devices"] = args.gpu
26 | else:
27 | cfg['devices'] = [i for i in range(torch.cuda.device_count())]
28 |
29 | if args.data_root is not None:
30 | cfg["dataset"]["data_root"] = args.data_root
31 |
32 | if args.train_batch_size > 0:
33 | cfg["loader"]["train_batch_size"] = args.train_batch_size
34 |
35 | if args.test_batch_size > 0:
36 | cfg["loader"]["test_batch_size"] = args.test_batch_size
37 |
38 | if args.cv_fold > -1:
39 | cfg["dataset"]["cross_val_id"] = args.cv_fold
40 |
41 | if args.cv_split_file != "":
42 | cfg["dataset"]["cross_val_split_file"] = args.cv_split_file
43 |
44 | if cfg["dataset"]["use_feats"] == False:
45 | cfg["model"]["finetune_feat_extractor"]= True
46 | cfg["model"]["feat_extractor_type"]= 'i3d'
47 | cfg["model"]["feat_extractor_weights_path"]= './pre_trained/model_rgb.pth'
48 |
49 | if cfg["model"]["use_stochastic_embd"] == False:
50 | cfg["train_cfg"]["loss_weights"]["phase_vib"] = 0.0
51 | cfg["train_cfg"]["loss_weights"]["scale_vib"] = False
52 |
53 | return cfg
54 |
55 |
56 | def create_checkpoint_folder(cfg, args):
57 | if not os.path.exists(cfg['output_folder']):
58 | os.mkdir(cfg['output_folder'])
59 |
60 | cfg_filename = os.path.basename(args.config).replace('.yaml', '')
61 |
62 | if len(args.output) == 0:
63 | ts = datetime.datetime.fromtimestamp(int(time.time()))
64 | ts = str(ts).replace(' ', '_').replace(':', '-')
65 | ckpt_folder = os.path.join(cfg['output_folder'], cfg_filename + '_' + str(ts))
66 | else:
67 | ckpt_folder = os.path.join(cfg['output_folder'], cfg_filename + '_' + str(args.output))
68 |
69 | if not os.path.exists(ckpt_folder):
70 | os.mkdir(ckpt_folder)
71 |
72 | return ckpt_folder
73 |
74 |
75 | def create_train_val_dataloaders(cfg, rng_generator):
76 | train_dataset = make_dataset(
77 | cfg['dataset_name'],
78 | True,
79 | cfg['train_split'],
80 | **cfg['dataset']
81 | )
82 |
83 |
84 | val_dataset = make_dataset(
85 | cfg['dataset_name'],
86 | False,
87 | cfg['val_split'],
88 | **cfg['dataset']
89 | )
90 |
91 | print("Train dataset size: {:d}".format(len(train_dataset)))
92 | print("Val dataset size: {:d}".format(len(val_dataset)))
93 |
94 | # data loaders
95 | train_loader = make_data_loader(
96 | train_dataset, True, rng_generator, **cfg['loader'])
97 |
98 | val_loader = make_data_loader(
99 | val_dataset, False, None, **cfg['loader'])
100 |
101 | return (train_loader, val_loader)
102 |
103 | ################################################################################
104 | def main(args):
105 | """main function that handles training / inference"""
106 |
107 | """1. setup parameters / folders"""
108 |
109 | # parse args
110 | args.start_epoch = 0
111 |
112 | if os.path.isfile(args.config):
113 | cfg = load_config(args.config)
114 | else:
115 | raise ValueError("Config file does not exist.")
116 |
117 | ckpt_folder = create_checkpoint_folder(cfg, args)
118 | print("Checkpoint folder: {:s}".format(ckpt_folder))
119 | # tensorboard writer
120 | tb_writer = SummaryWriter(os.path.join(ckpt_folder, 'logs'))
121 |
122 | logger = Logger(os.path.join(ckpt_folder, '0_log.txt'))
123 |     print("Note: uncomment the next line to redirect stdout to the log file; keep it commented when debugging with ipdb")
124 | #sys.stdout = logger
125 |
126 | torch.set_warn_always(False)
127 |
128 | # fix the random seeds (this will fix everything)
129 | rng_generator = fix_random_seed(cfg['init_rand_seed'], include_cuda=True)
130 |
131 | cfg = override_cfg_params(cfg, args)
132 |
133 | print("Args")
134 | pprint(vars(args), indent=4, stream=sys.__stdout__,sort_dicts=False)
135 | pprint(cfg, stream=sys.__stdout__,sort_dicts=False)
136 |
137 | """2. create dataset / dataloader"""
138 | train_loader, val_loader = create_train_val_dataloaders(cfg, rng_generator)
139 |
140 | """3. create model, optimizer, and scheduler"""
141 | # model
142 | model = make_meta_arch(cfg['model_name'], **cfg['model'])
143 |
144 | model = nn.DataParallel(model).cuda()
145 |
146 | optimizer = make_optimizer(model, cfg['opt'])
147 |
148 | # schedule
149 | num_iters_per_epoch = len(train_loader) / cfg["train_cfg"]["accumulation_steps"]
150 | scheduler = make_scheduler(optimizer, cfg['opt'], num_iters_per_epoch)
151 |
152 | """4. Resume from model / Misc"""
153 | # resume from a checkpoint?
154 | if args.resume:
155 | if os.path.isfile(args.resume):
156 | # load ckpt, reset epoch / best rmse
157 | checkpoint = torch.load(args.resume,
158 | map_location = lambda storage, loc: storage.cuda(
159 | cfg['devices'][0]))
160 | args.start_epoch = checkpoint['epoch']
161 | model.load_state_dict(checkpoint['state_dict'])
162 | # also load the optimizer / scheduler if necessary
163 | optimizer.load_state_dict(checkpoint['optimizer'])
164 | scheduler.load_state_dict(checkpoint['scheduler'])
165 |             print("=> loaded checkpoint '{:s}' (epoch {:d})".format(
166 | args.resume, checkpoint['epoch']
167 | ))
168 | del checkpoint
169 | else:
170 | print("=> no checkpoint found at '{}'".format(args.resume))
171 | return
172 |
173 | # save the current config
174 | with open(os.path.join(ckpt_folder, 'config.txt'), 'w') as fid:
175 | pprint(cfg, stream=fid)
176 | fid.flush()
177 |
178 |     """5. training / validation loop"""
179 | print("\nStart training model {:s} ...".format(cfg['model_name']))
180 |
181 | # start training
182 | max_epochs = cfg['opt'].get(
183 | 'early_stop_epochs',
184 | cfg['opt']['epochs'] + cfg['opt']['warmup_epochs']
185 | )
186 | best_rl2 = 1000
187 | srcc_at_best_rl2 = -100
188 | best_rl2_epoch = 0
189 |
190 | best_srcc = -100
191 | rl2_at_best_srcc = 1000
192 | best_srcc_epoch = 0
193 |
194 | for epoch in range(args.start_epoch, max_epochs):
195 |
196 | # train for one epoch
197 | #start_time = time.time()
198 | train_one_epoch(
199 | train_loader,
200 | model,
201 | optimizer,
202 | scheduler,
203 | epoch,
204 | cfg = cfg,
205 | tb_writer=tb_writer,
206 | print_freq=args.print_freq
207 | )
208 | #print("Time taken for epoch: {:.2f} s".format(time.time() - start_time))
209 |
210 | if cfg["train_cfg"]["loss_weights"]["phase_vib"] > 0.0 and cfg["train_cfg"]["loss_weights"]["scale_vib"]:
211 | if epoch > 30 and epoch % 10 == 0:
212 | cfg["train_cfg"]["loss_weights"]["phase_vib"] *= 3
213 | cfg["train_cfg"]["loss_weights"]["phase_vib"] = min(cfg["train_cfg"]["loss_weights"]["phase_vib"], 0.005)
214 |
215 |
216 | if (epoch == 0) or ((epoch+1) <= 30 and (epoch+1) % 5 == 0 ) or (30 < (epoch+1) <= 120 and (epoch+1) % 3 == 0) or (epoch+1) > 120 or (epoch+1) == max_epochs:
217 | curr_srcc, curr_rl2, metric_dict = valid_one_epoch(
218 | val_loader,
219 | model,
220 | epoch,
221 | cfg = cfg,
222 | tb_writer=tb_writer,
223 | print_freq=args.print_freq,
224 | save_predictions=True
225 | )
226 |
227 | if curr_srcc > best_srcc:
228 | best_srcc = curr_srcc
229 | rl2_at_best_srcc = curr_rl2
230 | best_srcc_epoch = epoch
231 | srcc_improved = True
232 | else:
233 | srcc_improved = False
234 |
235 | if curr_rl2 < best_rl2:
236 | best_rl2 = curr_rl2
237 | srcc_at_best_rl2 = curr_srcc
238 | best_rl2_epoch = epoch
239 | rl2_improved = True
240 | else:
241 | rl2_improved = False
242 |
243 | print("Best SRCC: {:.4f}, corres. RL2: {:.4f} at epoch {:d}".format(best_srcc, rl2_at_best_srcc, best_srcc_epoch))
244 | print("Best RL2: {:.4f}, corres. SRCC: {:.4f} at epoch {:d}".format(best_rl2, srcc_at_best_rl2, best_rl2_epoch))
245 |
246 | save_states = {
247 | 'epoch': epoch + 1,
248 | 'state_dict': model.state_dict(),
249 | 'scheduler': scheduler.state_dict(),
250 | 'optimizer': optimizer.state_dict(),
251 | }
252 |
253 | if srcc_improved:
254 | save_checkpoint(
255 | save_states,
256 | False,
257 | file_folder=ckpt_folder,
258 | file_name='srcc_best.pth.tar'
259 | )
260 | with open(os.path.join(ckpt_folder, "srcc_best_outputs.pkl"), "wb") as f:
261 | pickle.dump(metric_dict, f)
262 |
263 | if rl2_improved:
264 | save_checkpoint(
265 | save_states,
266 | False,
267 | file_folder=ckpt_folder,
268 | file_name='rl2_best.pth.tar'
269 | )
270 | with open(os.path.join(ckpt_folder, "rl2_best_outputs.pkl"), "wb") as f:
271 | pickle.dump(metric_dict, f)
272 |
273 | del save_states
274 |
275 | print("Best SRCC: {:.4f}".format(best_srcc))
276 | print("Best RL2: {:.4f}".format(best_rl2))
277 |
278 | # wrap up
279 | tb_writer.close()
280 | print("All done!")
281 | return
282 |
283 | ################################################################################
284 | if __name__ == '__main__':
285 | """Entry Point"""
286 | # the arg parser
287 | parser = argparse.ArgumentParser(
288 | description='Train')
289 | parser.add_argument('config', metavar='DIR',
290 | help='path to a config file')
291 | parser.add_argument('-p', '--print-freq', default=10, type=int,
292 | help='print frequency (default: 10 iterations)')
293 | parser.add_argument('-c', '--ckpt-freq', default=5, type=int,
294 | help='checkpoint frequency (default: every 5 epochs)')
295 | parser.add_argument('--output', default='', type=str,
296 | help='name of exp folder (default: none)')
297 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
298 | help='path to a checkpoint (default: none)')
299 |
300 | parser.add_argument('--data_root', type=str, metavar='PATH',)
301 | parser.add_argument('--train_batch_size', default=-1, type=int)
302 | parser.add_argument('--test_batch_size', default=-1, type=int)
303 | parser.add_argument('--cv_fold', default=-1, type=int)
304 |
305 | parser.add_argument('--cv_split_file', default='', type=str)
306 | parser.add_argument('--gpu', nargs='*')
307 |
308 | args = parser.parse_args()
309 |
310 | main(args)
311 |
--------------------------------------------------------------------------------
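For reference, the epoch gating around valid_one_epoch() in the training loop above can be read as a small predicate: validate at epoch 0, every 5 epochs up to epoch 30, every 3 epochs up to epoch 120, and every epoch afterwards. The sketch below only restates that condition (max_epochs is an arbitrary example value):

def should_validate(epoch: int, max_epochs: int) -> bool:
    # Mirrors the condition used before valid_one_epoch() in train.py.
    e = epoch + 1
    return (
        epoch == 0
        or (e <= 30 and e % 5 == 0)
        or (30 < e <= 120 and e % 3 == 0)
        or e > 120
        or e == max_epochs
    )

max_epochs = 125  # made-up value
val_epochs = [ep for ep in range(max_epochs) if should_validate(ep, max_epochs)]
print(val_epochs[:8], "...", val_epochs[-3:])
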
/libs/modeling/meta_archs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | from .models import register_meta_arch, make_backbone, make_neck
6 | from .blocks import get_sinusoid_encoding, PhaseDistComposer
7 | from .losses import criterion
8 |
9 |
10 | class Quality_Score_Head(nn.Module):
11 | """
12 |     Head for predicting the judge quality score
13 | """
14 |
15 | def __init__(self, input_dim, score_bins, num_random_samples=1, use_stochastic_embd=True, dataset=None):
16 | super().__init__()
17 |
18 | self.score_bins = score_bins
19 | self.input_dim = input_dim
20 |
21 | self.common_base = nn.Sequential(
22 | nn.Linear(input_dim, input_dim // 2),
23 | nn.ReLU()
24 | )
25 |
26 | self.scoring_head = nn.Sequential(
27 | nn.Linear(input_dim // 2, input_dim // 2),
28 | nn.ReLU(),
29 | nn.Linear(input_dim // 2, score_bins)
30 | )
31 |
32 | self.use_stochastic_embd = use_stochastic_embd
33 | self.phase_composer = PhaseDistComposer(dataset_name=dataset, dim=input_dim//2)
34 |
35 | if self.use_stochastic_embd:
36 | self.num_random_samples = num_random_samples
37 | self.mean_var_linear = nn.Sequential(
38 | nn.Linear(input_dim // 2, input_dim // 2),
39 | nn.ReLU(),
40 | nn.Linear(input_dim // 2, input_dim),
41 | )
42 |
43 | self.temperature = nn.Parameter(torch.ones(1) * 2)
44 |
45 | self._reset_parameters()
46 |
47 | def _reset_parameters(self):
48 | for p in self.parameters():
49 | if p.dim() > 1:
50 | nn.init.xavier_uniform_(p)
51 |
52 | def temperature_scale(self, logits):
53 | temperature = self.temperature.unsqueeze(0).unsqueeze(0).expand(*logits.shape)
54 | return logits / temperature
55 |
56 | def forward(self, x, gt_actions= None):
57 | common_input = self.common_base(x)
58 |
59 | #common_input_shape -> B x phases x dim//2
60 | if self.use_stochastic_embd:
61 | #B x phases x channels*2
62 | phase_mean_log_var = self.mean_var_linear(common_input)
63 |
64 | #first C/2 channels are mean, last C/2 channels are log_var
65 | phase_mean_emb = phase_mean_log_var[:, :, :self.input_dim // 2]
66 | phase_log_var = phase_mean_log_var[:,:,self.input_dim // 2: ]
67 |
68 | phase_var_emb = torch.exp(phase_log_var)
69 |
70 | #shapes = B
71 | common_input, global_sqrt_var = self.phase_composer.process_phases(
72 | phase_mean_emb = phase_mean_emb,
73 | phase_var_emb = phase_var_emb,
74 | num_samples = 1 if self.training else self.num_random_samples,
75 | gt_actions = gt_actions.detach()
76 | )
77 | else:
78 |             # if phase_var_emb is None, the head becomes deterministic
79 | common_input, _ = self.phase_composer.process_phases(
80 | common_input, None, 1, gt_actions.detach()
81 | )
82 |
83 | #num_samples x batch_size x 1
84 | all_sample_output = self.scoring_head(common_input)
85 |
86 | #num_samples x batch_size x 1 --> batch_size x num_samples x 1
87 | all_sample_output = all_sample_output.permute(1, 0, 2)
88 |
89 | if self.use_stochastic_embd:
90 | return (
91 | phase_mean_emb,
92 | phase_var_emb,
93 | global_sqrt_var,
94 | all_sample_output
95 | )
96 | else:
97 | return all_sample_output
98 |
99 |
100 |
101 | @register_meta_arch("aqa-model")
102 | class AQA_Model(nn.Module):
103 | """
104 |     Transformer-based model for action quality assessment (AQA)
105 | """
106 |
107 | def __init__(
108 | self,
109 | feat_extractor_type,
110 | finetune_feat_extractor,
111 | feat_extractor_weights_path,
112 | backbone_type,
113 | input_feat_dim,
114 | embed_dim,
115 | conv_dropout,
116 | neck_type,
117 | num_layers,
118 | conv_kernel_size,
119 | encoder_params,
120 | decoder_params,
121 | num_phases,
122 | score_bins,
123 | train_cfg, # other cfg for training
124 | test_cfg, # other cfg for testing
125 | max_seq_len, # max sequence length in features (used for training)
126 | frames_per_clip, # original sequence length in frames
127 | use_stochastic_embd = False,
128 | num_random_samples = None
129 | ):
130 | super().__init__()
131 |
132 | self.feat_extractor_type = feat_extractor_type
133 | self.finetune_feat_extractor = finetune_feat_extractor
134 |
135 | self.embed_dim = embed_dim
136 | self.use_phases = False
137 |
138 | self.num_phases = num_phases
139 |
140 | self.input_feat_dim = input_feat_dim
141 | self.max_seq_len = max_seq_len
142 | self.frames_per_clip = frames_per_clip
143 |
144 | self.score_bins = score_bins
145 |
146 | self.train_cfg = train_cfg
147 |
148 |
149 | self.use_stochastic_embd = use_stochastic_embd
150 |
151 | self.num_random_samples = num_random_samples
152 |
153 | if encoder_params["use_abs_pe"]:
154 | pos_embd = get_sinusoid_encoding(self.max_seq_len, embed_dim) / (
155 | embed_dim**0.5
156 | )
157 | self.register_buffer("pos_embd", pos_embd, persistent=False)
158 | self.pos_embd = self.pos_embd
159 |
160 | assert feat_extractor_type in ["i3d", None]
161 |
162 | if feat_extractor_type == "i3d":
163 | self.feat_extractor = make_backbone(
164 | "i3d",
165 | **{
166 | "I3D_ckpt_path": feat_extractor_weights_path,
167 | "finetune": self.finetune_feat_extractor,
168 | }
169 | )
170 | else:
171 | self.feat_extractor = None
172 |
173 | # backbone network: conv + transformer
174 | assert backbone_type in ["conv", "convEncoder"]
175 |
176 | if backbone_type == "convEncoder":
177 | self.backbone = make_backbone(
178 | "convEncoder",
179 | **{
180 | "n_in": input_feat_dim, # input feature dimension
181 | "n_embd": embed_dim, # embedding dimension (after convolution)
182 | "conv_dropout": conv_dropout, # dropout rate for conv layers in the initial projection network
183 | "n_conv_layers": num_layers["n_conv_layers"], # number of layers
184 | "n_embd_ks": conv_kernel_size, # conv kernel size of the embedding network
185 | "conv_ln": False, # whether to use layer norm
186 | "n_encoder_layers": num_layers[
187 | "n_encoder_layers"
188 | ], # number of encoder layers
189 | "n_enc_head": encoder_params[
190 | "n_encoder_heads"
191 | ], # number of heads in the encoder
192 | "attn_pdrop": encoder_params["attn_pdrop"],
193 | "proj_pdrop": encoder_params["proj_pdrop"],
194 | "path_pdrop": encoder_params["path_pdrop"],
195 | "pos_embd": self.pos_embd if encoder_params["use_abs_pe"] else None,
196 | }
197 | )
198 | else:
199 | self.backbone = None
200 |
201 | if neck_type == "decoder-neck":
202 | self.neck = make_neck(
203 | "decoder-neck",
204 | **{
205 | "d_model": embed_dim,
206 | "n_heads": decoder_params["n_decoder_heads"],
207 | "stride": decoder_params["stride"],
208 | "num_decoder_layers": num_layers["n_decoder_layers"],
209 | "attn_pdrop": decoder_params["attn_pdrop"],
210 | "proj_pdrop": decoder_params["proj_pdrop"],
211 | "path_pdrop": decoder_params["path_pdrop"],
212 | "xattn_mode": decoder_params["xattn_mode"],
213 | "with_ln": decoder_params["with_ln"],
214 | "num_phases": num_phases,
215 | "query_config": decoder_params["query_config"],
216 | }
217 | )
218 | else:
219 | self.neck = None
220 |
221 |
222 | self.quality_score_head = Quality_Score_Head(
223 | input_dim=embed_dim,
224 | score_bins=self.score_bins,
225 | num_random_samples = self.num_random_samples if self.training else 1,
226 | use_stochastic_embd = self.use_stochastic_embd,
227 | dataset=train_cfg["dataset_name"]
228 | )
229 |
230 | @property
231 | def device(self):
232 | # a hacky way to get the device type
233 | # will throw an error if parameters are on different devices
234 | return list(set(p.device for p in self.parameters()))[0]
235 |
236 | def forward(
237 | self,
238 | batched_inputs,
239 | batched_masks,
240 | batched_gt_actions,
241 | video_ids = None,
242 | curr_epoch = None,
243 | criterion_args = None
244 | ):
245 | if self.feat_extractor is not None:
246 | batched_inputs = self.feat_extractor(batched_inputs, video_ids)
247 |
248 | # x = B x C x T, mask = B x 1 x T
249 | x, masks = self.backbone(batched_inputs, batched_masks, video_ids=video_ids, curr_epoch=curr_epoch)
250 |
251 |
252 | if self.neck is not None:
253 | # output = B x #phases x C
254 | x, cross_attns = self.neck(x, masks, batched_gt_actions, video_ids=video_ids, curr_epoch=curr_epoch)
255 | else:
256 | # B x C x T -> B x T x C
257 | x = x.permute(0, 2, 1)
258 |
259 | if self.use_stochastic_embd:
260 | (
261 | phase_mean_emb,
262 | phase_var_emb,
263 | global_sqrt_var_emb,
264 | all_sample_outputs
265 | ) = self.quality_score_head(x, batched_gt_actions)
266 | else:
267 | #batch_size x num_samples x bins
268 | all_sample_outputs = self.quality_score_head(x, batched_gt_actions)
269 |
270 |
271 | out_dict = {
272 | "all_sample_outputs": all_sample_outputs,
273 | "gt_actions": batched_gt_actions,
274 | "cross_attn": cross_attns if self.neck is not None else None,
275 | }
276 |
277 | if self.use_stochastic_embd:
278 | out_dict["phase_mean_emb"] = phase_mean_emb
279 | out_dict["phase_var_emb"] = phase_var_emb
280 | out_dict["global_sqrt_var_emb"] = global_sqrt_var_emb
281 |
282 |
283 | if criterion_args is not None:
284 | losses = criterion(out_dict,
285 | criterion_args['target'],
286 | criterion_args["difficulties"],
287 | criterion_args["gt_actions"],
288 | loss_weights = criterion_args['loss_weights'],
289 | with_dd = criterion_args['with_dd'],
290 | three_judge_score_scaling = criterion_args['three_judge_score_scaling'],
291 | )
292 | else:
293 | losses = None
294 |
295 | #remove cross attns from out_dict
296 | if self.neck is not None:
297 | out_dict.pop("cross_attn")
298 | return out_dict, losses
--------------------------------------------------------------------------------
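Quality_Score_Head parameterizes each phase embedding as a Gaussian by splitting a single projection into mean and log-variance halves. Below is a shape-only sketch of that split plus a standard reparameterized draw; the actual per-phase composition and sampling live in PhaseDistComposer (blocks.py, not shown here), so the final line is an illustrative assumption rather than the repository's exact sampling code.

import torch

B, phases, input_dim = 2, 6, 16   # hypothetical sizes; input_dim plays the role of embed_dim
phase_mean_log_var = torch.randn(B, phases, input_dim)

# Same split as Quality_Score_Head.forward(): first half of the channels is the mean,
# second half is the log-variance, exponentiated to get a positive variance.
phase_mean_emb = phase_mean_log_var[:, :, : input_dim // 2]
phase_var_emb = torch.exp(phase_mean_log_var[:, :, input_dim // 2:])

# A standard reparameterized sample from N(mean, var) -- illustration only.
sample = phase_mean_emb + phase_var_emb.sqrt() * torch.randn_like(phase_mean_emb)
print(phase_mean_emb.shape, phase_var_emb.shape, sample.shape)
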
/libs/utils/lr_schedulers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 | from collections import Counter
4 | from bisect import bisect_right
5 |
6 | import torch
7 | from torch.optim.lr_scheduler import _LRScheduler
8 |
9 |
10 | class LinearWarmupCosineAnnealingLR(_LRScheduler):
11 | """
12 | Sets the learning rate of each parameter group to follow a linear warmup schedule
13 | between warmup_start_lr and base_lr followed by a cosine annealing schedule between
14 | base_lr and eta_min.
15 |
16 | .. warning::
17 | It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR`
18 | after each iteration as calling it after each epoch will keep the starting lr at
19 | warmup_start_lr for the first epoch which is 0 in most cases.
20 |
21 | .. warning::
22 | passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING.
23 | It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of
24 | :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing
25 | epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling
26 | train and validation methods.
27 |
28 | Example:
29 | >>> layer = nn.Linear(10, 1)
30 | >>> optimizer = Adam(layer.parameters(), lr=0.02)
31 | >>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40)
32 | >>> #
33 | >>> # the default case
34 | >>> for epoch in range(40):
35 | ... # train(...)
36 | ... # validate(...)
37 | ... scheduler.step()
38 | >>> #
39 | >>> # passing epoch param case
40 | >>> for epoch in range(40):
41 | ... scheduler.step(epoch)
42 | ... # train(...)
43 | ... # validate(...)
44 | """
45 |
46 | def __init__(
47 | self,
48 | optimizer,
49 | warmup_epochs,
50 | max_epochs,
51 | warmup_start_lr = 0.0,
52 | eta_min = 1e-8,
53 | last_epoch = -1,
54 | ):
55 | """
56 | Args:
57 | optimizer (Optimizer): Wrapped optimizer.
58 | warmup_epochs (int): Maximum number of iterations for linear warmup
59 | max_epochs (int): Maximum number of iterations
60 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
61 |             eta_min (float): Minimum learning rate. Default: 1e-8.
62 | last_epoch (int): The index of last epoch. Default: -1.
63 | """
64 | self.warmup_epochs = warmup_epochs
65 | self.max_epochs = max_epochs
66 | self.warmup_start_lr = warmup_start_lr
67 | self.eta_min = eta_min
68 |
69 | super(LinearWarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)
70 |
71 | def get_lr(self):
72 | """
73 | Compute learning rate using chainable form of the scheduler
74 | """
75 | if not self._get_lr_called_within_step:
76 | warnings.warn(
77 | "To get the last learning rate computed by the scheduler, "
78 | "please use `get_last_lr()`.",
79 | UserWarning,
80 | )
81 |
82 | if self.last_epoch == 0:
83 | return [self.warmup_start_lr] * len(self.base_lrs)
84 | elif self.last_epoch < self.warmup_epochs:
85 | return [
86 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
87 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
88 | ]
89 | elif self.last_epoch == self.warmup_epochs:
90 | return self.base_lrs
91 | elif (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
92 | return [
93 | group["lr"] + (base_lr - self.eta_min) *
94 | (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
95 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
96 | ]
97 |
98 | return [
99 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) /
100 | (
101 | 1 +
102 | math.cos(math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs))
103 | ) * (group["lr"] - self.eta_min) + self.eta_min for group in self.optimizer.param_groups
104 | ]
105 |
106 | def _get_closed_form_lr(self):
107 | """
108 | Called when epoch is passed as a param to the `step` function of the scheduler.
109 | """
110 | if self.last_epoch < self.warmup_epochs:
111 | return [
112 | self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
113 | for base_lr in self.base_lrs
114 | ]
115 |
116 | return [
117 | self.eta_min + 0.5 * (base_lr - self.eta_min) *
118 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
119 | for base_lr in self.base_lrs
120 | ]
121 |
122 |
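# --- Editor's note: illustrative usage sketch, not part of the original file. --
# In this codebase the scheduler is stepped once per *iteration* (see
# make_scheduler in libs/utils/train_utils.py), so warmup_epochs / max_epochs
# are passed in units of optimizer steps. Assuming a dummy model and 100
# iterations per epoch:
#
# >>> import torch
# >>> layer = torch.nn.Linear(10, 1)
# >>> optimizer = torch.optim.AdamW(layer.parameters(), lr=1e-3)
# >>> iters_per_epoch = 100
# >>> scheduler = LinearWarmupCosineAnnealingLR(
# ...     optimizer,
# ...     warmup_epochs=5 * iters_per_epoch,        # warmup steps
# ...     max_epochs=(30 + 5) * iters_per_epoch,    # total steps (train + warmup)
# ... )
# >>> for _ in range(35 * iters_per_epoch):
# ...     optimizer.step()
# ...     scheduler.step()
# -------------------------------------------------------------------------------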
123 | class LinearWarmupMultiStepLR(_LRScheduler):
124 | """
125 | Sets the learning rate of each parameter group to follow a linear warmup schedule
126 |     between warmup_start_lr and base_lr, followed by a multi-step decay at the given milestones.
127 |
128 | .. warning::
129 |         It is recommended to call :func:`.step()` for :class:`LinearWarmupMultiStepLR`
130 | after each iteration as calling it after each epoch will keep the starting lr at
131 | warmup_start_lr for the first epoch which is 0 in most cases.
132 |
133 | .. warning::
134 | passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING.
135 | It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of
136 | :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing
137 | epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling
138 | train and validation methods.
139 | """
140 |
141 | def __init__(
142 | self,
143 | optimizer,
144 | warmup_epochs,
145 | milestones,
146 | warmup_start_lr = 0.0,
147 | gamma = 0.1,
148 | last_epoch = -1,
149 | ):
150 | """
151 | Args:
152 | optimizer (Optimizer): Wrapped optimizer.
153 | warmup_epochs (int): Maximum number of iterations for linear warmup
154 |             milestones (list): List of epoch (or step) indices at which the learning
155 |                 rate decays, counted from the end of warmup. Must be increasing.
156 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
157 | gamma (float): Multiplicative factor of learning rate decay.
158 | Default: 0.1.
159 | last_epoch (int): The index of last epoch. Default: -1.
160 | """
161 | self.warmup_epochs = warmup_epochs
162 | self.warmup_start_lr = warmup_start_lr
163 | self.milestones = Counter(milestones)
164 | self.gamma = gamma
165 |
166 | super(LinearWarmupMultiStepLR, self).__init__(optimizer, last_epoch)
167 |
168 | def get_lr(self):
169 | """
170 | Compute learning rate using chainable form of the scheduler
171 | """
172 | if not self._get_lr_called_within_step:
173 | warnings.warn("To get the last learning rate computed by the scheduler, "
174 | "please use `get_last_lr()`.", UserWarning)
175 |
176 | if self.last_epoch == 0:
177 | # starting warm up
178 | return [self.warmup_start_lr] * len(self.base_lrs)
179 | elif self.last_epoch < self.warmup_epochs:
180 | # linear warm up (0 ~ self.warmup_epochs -1)
181 | return [
182 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
183 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
184 | ]
185 | elif self.last_epoch == self.warmup_epochs:
186 | # end of warm up (reset to base lrs)
187 | return self.base_lrs
188 | elif (self.last_epoch - self.warmup_epochs) not in self.milestones:
189 | # in between the steps
190 | return [group['lr'] for group in self.optimizer.param_groups]
191 |
192 | return [
193 | group['lr'] * self.gamma ** self.milestones[self.last_epoch - self.warmup_epochs]
194 | for group in self.optimizer.param_groups
195 | ]
196 |
197 | def _get_closed_form_lr(self):
198 | """
199 | Called when epoch is passed as a param to the `step` function of the scheduler.
200 | """
201 | if self.last_epoch < self.warmup_epochs:
202 | return [
203 | self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
204 | for base_lr in self.base_lrs
205 | ]
206 |
207 | milestones = list(sorted(self.milestones.elements()))
208 | return [base_lr * self.gamma ** bisect_right(milestones, self.last_epoch - self.warmup_epochs)
209 | for base_lr in self.base_lrs]
210 |
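# --- Editor's note: illustrative usage sketch, not part of the original file. --
# For LinearWarmupMultiStepLR the milestones are counted from the end of warmup:
# the learning rate is multiplied by gamma whenever
# (last_epoch - warmup_epochs) hits a milestone. Stepping per iteration as above:
#
# >>> import torch
# >>> layer = torch.nn.Linear(10, 1)
# >>> optimizer = torch.optim.SGD(layer.parameters(), lr=0.1, momentum=0.9)
# >>> scheduler = LinearWarmupMultiStepLR(
# ...     optimizer,
# ...     warmup_epochs=100,        # warmup steps
# ...     milestones=[500, 800],    # steps after warmup at which the lr decays
# ...     gamma=0.1,
# ... )
# -------------------------------------------------------------------------------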
211 |
212 |
213 |
214 | class LinearWarmupNoDecayLR(_LRScheduler):
215 | """
216 | Sets the learning rate of each parameter group to follow a linear warmup schedule
217 | between warmup_start_lr and base_lr followed by a constant lr.
218 |
219 | .. warning::
220 |         It is recommended to call :func:`.step()` for :class:`LinearWarmupNoDecayLR`
221 | after each iteration as calling it after each epoch will keep the starting lr at
222 | warmup_start_lr for the first epoch which is 0 in most cases.
223 |
224 | .. warning::
225 | passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING.
226 | It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of
227 | :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing
228 | epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling
229 | train and validation methods.
230 |
231 | Example:
232 | >>> layer = nn.Linear(10, 1)
233 | >>> optimizer = Adam(layer.parameters(), lr=0.02)
234 | >>> scheduler = LinearWarmupNoDecayLR(optimizer, warmup_epochs=10, max_epochs=40)
235 | >>> #
236 | >>> # the default case
237 | >>> for epoch in range(40):
238 | ... # train(...)
239 | ... # validate(...)
240 | ... scheduler.step()
241 | >>> #
242 | >>> # passing epoch param case
243 | >>> for epoch in range(40):
244 | ... scheduler.step(epoch)
245 | ... # train(...)
246 | ... # validate(...)
247 | """
248 |
249 | def __init__(
250 | self,
251 | optimizer,
252 | warmup_epochs,
253 | max_epochs,
254 | warmup_start_lr = 0.0,
255 | eta_min = 1e-8,
256 | last_epoch = -1,
257 | ):
258 | """
259 | Args:
260 | optimizer (Optimizer): Wrapped optimizer.
261 | warmup_epochs (int): Maximum number of iterations for linear warmup
262 | max_epochs (int): Maximum number of iterations
263 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
264 |             eta_min (float): Minimum learning rate (unused by this scheduler). Default: 1e-8.
265 | last_epoch (int): The index of last epoch. Default: -1.
266 | """
267 | self.warmup_epochs = warmup_epochs
268 | self.max_epochs = max_epochs
269 | self.warmup_start_lr = warmup_start_lr
270 | self.eta_min = eta_min
271 |
272 | super(LinearWarmupNoDecayLR, self).__init__(optimizer, last_epoch)
273 |
274 | def get_lr(self):
275 | """
276 | Compute learning rate using chainable form of the scheduler
277 | """
278 | if not self._get_lr_called_within_step:
279 | warnings.warn(
280 | "To get the last learning rate computed by the scheduler, "
281 | "please use `get_last_lr()`.",
282 | UserWarning,
283 | )
284 |
285 | if self.last_epoch == 0:
286 | return [self.warmup_start_lr] * len(self.base_lrs)
287 | elif self.last_epoch < self.warmup_epochs:
288 | return [
289 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
290 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
291 | ]
292 | else:
293 | return self.base_lrs
294 |
295 | def _get_closed_form_lr(self):
296 | """
297 | Called when epoch is passed as a param to the `step` function of the scheduler.
298 | """
299 | if self.last_epoch < self.warmup_epochs:
300 | return [
301 | self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
302 | for base_lr in self.base_lrs
303 | ]
304 |
305 | return self.base_lrs
--------------------------------------------------------------------------------
/libs/modeling/i3d.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import numpy as np
4 | import torch
5 |
6 | def get_padding_shape(filter_shape, stride):
7 | def _pad_top_bottom(filter_dim, stride_val):
8 | pad_along = max(filter_dim - stride_val, 0)
9 | pad_top = pad_along // 2
10 | pad_bottom = pad_along - pad_top
11 | return pad_top, pad_bottom
12 |
13 | padding_shape = []
14 | for filter_dim, stride_val in zip(filter_shape, stride):
15 | pad_top, pad_bottom = _pad_top_bottom(filter_dim, stride_val)
16 | padding_shape.append(pad_top)
17 | padding_shape.append(pad_bottom)
18 | depth_top = padding_shape.pop(0)
19 | depth_bottom = padding_shape.pop(0)
20 | padding_shape.append(depth_top)
21 | padding_shape.append(depth_bottom)
22 |
23 | return tuple(padding_shape)
24 |
25 |
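# --- Editor's note: illustrative check, not part of the original file. ---------
# get_padding_shape reproduces TensorFlow's 'SAME' padding for a (D, H, W) kernel
# and rotates the temporal (depth) padding to the end so the tuple can be fed to
# torch.nn.ConstantPad3d, which pads the last dimension first (the spatial
# kernels used here are square, so the H/W order is interchangeable):
#
# >>> get_padding_shape((3, 3, 3), (1, 1, 1))
# (1, 1, 1, 1, 1, 1)
# >>> get_padding_shape((1, 3, 3), (1, 2, 2))    # pooling config used in I3D below
# (0, 1, 0, 1, 0, 0)
# -------------------------------------------------------------------------------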
26 | def simplify_padding(padding_shapes):
27 | all_same = True
28 | padding_init = padding_shapes[0]
29 | for pad in padding_shapes[1:]:
30 | if pad != padding_init:
31 | all_same = False
32 | return all_same, padding_init
33 |
34 |
35 | class Unit3Dpy(torch.nn.Module):
36 | def __init__(self,
37 | in_channels,
38 | out_channels,
39 | kernel_size=(1, 1, 1),
40 | stride=(1, 1, 1),
41 | activation='relu',
42 | padding='SAME',
43 | use_bias=False,
44 | use_bn=True):
45 | super(Unit3Dpy, self).__init__()
46 |
47 | self.padding = padding
48 | self.activation = activation
49 | self.use_bn = use_bn
50 | if padding == 'SAME':
51 | padding_shape = get_padding_shape(kernel_size, stride)
52 | simplify_pad, pad_size = simplify_padding(padding_shape)
53 | self.simplify_pad = simplify_pad
54 | elif padding == 'VALID':
55 | padding_shape = 0
56 | else:
57 | raise ValueError(
58 | 'padding should be in [VALID|SAME] but got {}'.format(padding))
59 |
60 | if padding == 'SAME':
61 | if not simplify_pad:
62 | self.pad = torch.nn.ConstantPad3d(padding_shape, 0)
63 | self.conv3d = torch.nn.Conv3d(
64 | in_channels,
65 | out_channels,
66 | kernel_size,
67 | stride=stride,
68 | bias=use_bias)
69 | else:
70 | self.conv3d = torch.nn.Conv3d(
71 | in_channels,
72 | out_channels,
73 | kernel_size,
74 | stride=stride,
75 | padding=pad_size,
76 | bias=use_bias)
77 | elif padding == 'VALID':
78 | self.conv3d = torch.nn.Conv3d(
79 | in_channels,
80 | out_channels,
81 | kernel_size,
82 | padding=padding_shape,
83 | stride=stride,
84 | bias=use_bias)
85 | else:
86 | raise ValueError(
87 | 'padding should be in [VALID|SAME] but got {}'.format(padding))
88 |
89 | if self.use_bn:
90 | self.batch3d = torch.nn.BatchNorm3d(out_channels)
91 |
92 | if activation == 'relu':
93 | self.activation = torch.nn.functional.relu
94 |
95 | def forward(self, inp):
96 | if self.padding == 'SAME' and self.simplify_pad is False:
97 | inp = self.pad(inp)
98 | out = self.conv3d(inp)
99 | if self.use_bn:
100 | out = self.batch3d(out)
101 | if self.activation is not None:
102 | out = torch.nn.functional.relu(out)
103 | return out
104 |
105 |
106 | class MaxPool3dTFPadding(torch.nn.Module):
107 | def __init__(self, kernel_size, stride=None, padding='SAME'):
108 | super(MaxPool3dTFPadding, self).__init__()
109 | if padding == 'SAME':
110 | padding_shape = get_padding_shape(kernel_size, stride)
111 | self.padding_shape = padding_shape
112 | self.pad = torch.nn.ConstantPad3d(padding_shape, 0)
113 | self.pool = torch.nn.MaxPool3d(kernel_size, stride, ceil_mode=True)
114 |
115 | def forward(self, inp):
116 | inp = self.pad(inp)
117 | out = self.pool(inp)
118 | return out
119 |
120 |
121 | class Mixed(torch.nn.Module):
122 | def __init__(self, in_channels, out_channels):
123 | super(Mixed, self).__init__()
124 | # Branch 0
125 | self.branch_0 = Unit3Dpy(
126 | in_channels, out_channels[0], kernel_size=(1, 1, 1))
127 |
128 | # Branch 1
129 | branch_1_conv1 = Unit3Dpy(
130 | in_channels, out_channels[1], kernel_size=(1, 1, 1))
131 | branch_1_conv2 = Unit3Dpy(
132 | out_channels[1], out_channels[2], kernel_size=(3, 3, 3))
133 | self.branch_1 = torch.nn.Sequential(branch_1_conv1, branch_1_conv2)
134 |
135 | # Branch 2
136 | branch_2_conv1 = Unit3Dpy(
137 | in_channels, out_channels[3], kernel_size=(1, 1, 1))
138 | branch_2_conv2 = Unit3Dpy(
139 | out_channels[3], out_channels[4], kernel_size=(3, 3, 3))
140 | self.branch_2 = torch.nn.Sequential(branch_2_conv1, branch_2_conv2)
141 |
142 | # Branch3
143 | branch_3_pool = MaxPool3dTFPadding(
144 | kernel_size=(3, 3, 3), stride=(1, 1, 1), padding='SAME')
145 | branch_3_conv2 = Unit3Dpy(
146 | in_channels, out_channels[5], kernel_size=(1, 1, 1))
147 | self.branch_3 = torch.nn.Sequential(branch_3_pool, branch_3_conv2)
148 |
149 | def forward(self, inp):
150 | out_0 = self.branch_0(inp)
151 | out_1 = self.branch_1(inp)
152 | out_2 = self.branch_2(inp)
153 | out_3 = self.branch_3(inp)
154 | out = torch.cat((out_0, out_1, out_2, out_3), 1)
155 | return out
156 |
157 |
158 | class I3D(torch.nn.Module):
159 | def __init__(self,
160 | num_classes,
161 | modality='rgb',
162 | dropout_prob=0,
163 | name='inception'):
164 | super(I3D, self).__init__()
165 |
166 | self.name = name
167 | self.num_classes = num_classes
168 | if modality == 'rgb':
169 | in_channels = 3
170 | elif modality == 'flow':
171 | in_channels = 2
172 | else:
173 | raise ValueError(
174 | '{} not among known modalities [rgb|flow]'.format(modality))
175 | self.modality = modality
176 |
177 | conv3d_1a_7x7 = Unit3Dpy(
178 | out_channels=64,
179 | in_channels=in_channels,
180 | kernel_size=(7, 7, 7),
181 | stride=(2, 2, 2),
182 | padding='SAME')
183 | # 1st conv-pool
184 | self.conv3d_1a_7x7 = conv3d_1a_7x7
185 | self.maxPool3d_2a_3x3 = MaxPool3dTFPadding(
186 | kernel_size=(1, 3, 3), stride=(1, 2, 2), padding='SAME')
187 | # conv conv
188 | conv3d_2b_1x1 = Unit3Dpy(
189 | out_channels=64,
190 | in_channels=64,
191 | kernel_size=(1, 1, 1),
192 | padding='SAME')
193 | self.conv3d_2b_1x1 = conv3d_2b_1x1
194 | conv3d_2c_3x3 = Unit3Dpy(
195 | out_channels=192,
196 | in_channels=64,
197 | kernel_size=(3, 3, 3),
198 | padding='SAME')
199 | self.conv3d_2c_3x3 = conv3d_2c_3x3
200 | self.maxPool3d_3a_3x3 = MaxPool3dTFPadding(
201 | kernel_size=(1, 3, 3), stride=(1, 2, 2), padding='SAME')
202 |
203 | # Mixed_3b
204 | self.mixed_3b = Mixed(192, [64, 96, 128, 16, 32, 32])
205 | self.mixed_3c = Mixed(256, [128, 128, 192, 32, 96, 64])
206 |
207 | self.maxPool3d_4a_3x3 = MaxPool3dTFPadding(
208 | kernel_size=(3, 3, 3), stride=(2, 2, 2), padding='SAME')
209 |
210 | # Mixed 4
211 | self.mixed_4b = Mixed(480, [192, 96, 208, 16, 48, 64])
212 | self.mixed_4c = Mixed(512, [160, 112, 224, 24, 64, 64])
213 | self.mixed_4d = Mixed(512, [128, 128, 256, 24, 64, 64])
214 | self.mixed_4e = Mixed(512, [112, 144, 288, 32, 64, 64])
215 | self.mixed_4f = Mixed(528, [256, 160, 320, 32, 128, 128])
216 |
217 | self.maxPool3d_5a_2x2 = MaxPool3dTFPadding(
218 | kernel_size=(2, 2, 2), stride=(2, 2, 2), padding='SAME')
219 |
220 | # Mixed 5
221 | self.mixed_5b = Mixed(832, [256, 160, 320, 32, 128, 128])
222 | self.mixed_5c = Mixed(832, [384, 192, 384, 48, 128, 128])
223 |
224 | # self.avg_pool = torch.nn.AvgPool3d((2, 7, 7), (1, 1, 1))
225 | # keep temporal dim
226 | # self.avg_pool = torch.nn.AvgPool3d((1, 7, 7), (1, 1, 1))
227 | # self.spacial_se = torch.nn.Sequential(
228 | # torch.nn.Conv2d(1024, 1, 3, 1, 1),
229 | # torch.nn.Sigmoid()
230 | # )
231 | # self.avg_pool_se = torch.nn.AvgPool2d(7, 7)
232 |
233 | # self.dropout = torch.nn.Dropout(dropout_prob)
234 | self.conv3d_0c_1x1 = Unit3Dpy(
235 | in_channels=1024,
236 | out_channels=self.num_classes,
237 | kernel_size=(1, 1, 1),
238 | activation=None,
239 | use_bias=True,
240 | use_bn=False)
241 | # self.softmax = torch.nn.Softmax(1)
242 | # self.softmax = torch.nn.Sigmiod(1)
243 |
244 | def get_logits_dim(self):
245 | return 1024
246 |
247 | def forward(self, inp):
248 | # Preprocessing
249 | out = self.conv3d_1a_7x7(inp)
250 | out = self.maxPool3d_2a_3x3(out)
251 | out = self.conv3d_2b_1x1(out)
252 | out = self.conv3d_2c_3x3(out)
253 | out = self.maxPool3d_3a_3x3(out)
254 | out = self.mixed_3b(out)
255 | out = self.mixed_3c(out)
256 | out = self.maxPool3d_4a_3x3(out)
257 | out = self.mixed_4b(out)
258 | out = self.mixed_4c(out)
259 | out = self.mixed_4d(out)
260 | out = self.mixed_4e(out)
261 | out = self.mixed_4f(out)
262 | out = self.maxPool3d_5a_2x2(out)
263 | out = self.mixed_5b(out)
264 | out = self.mixed_5c(out)
265 |
266 | # out = out.mean(2)
267 | # feature = self.avg_pool_se(self.spacial_se(out) * out)
268 | # out = self.avg_pool(out)
269 | return out
270 | # out = self.dropout(feature)
271 | # out = self.conv3d_0c_1x1(out)
272 | # out = out.squeeze(3)
273 | # out = out.squeeze(3)
274 | # out,_ = out.max(2)
275 | # out_logits = out
276 | # out = self.softmax(out_logits)
277 | # # out = self.sigmoid(out_logits) # B * C
278 | # return feature, out, out_logits
279 |
280 | def load_tf_weights(self, sess):
281 | state_dict = {}
282 | if self.modality == 'rgb':
283 | prefix = 'RGB/inception_i3d'
284 | elif self.modality == 'flow':
285 | prefix = 'Flow/inception_i3d'
286 | load_conv3d(state_dict, 'conv3d_1a_7x7', sess,
287 | os.path.join(prefix, 'Conv3d_1a_7x7'))
288 | load_conv3d(state_dict, 'conv3d_2b_1x1', sess,
289 | os.path.join(prefix, 'Conv3d_2b_1x1'))
290 | load_conv3d(state_dict, 'conv3d_2c_3x3', sess,
291 | os.path.join(prefix, 'Conv3d_2c_3x3'))
292 |
293 | load_mixed(state_dict, 'mixed_3b', sess,
294 | os.path.join(prefix, 'Mixed_3b'))
295 | load_mixed(state_dict, 'mixed_3c', sess,
296 | os.path.join(prefix, 'Mixed_3c'))
297 | load_mixed(state_dict, 'mixed_4b', sess,
298 | os.path.join(prefix, 'Mixed_4b'))
299 | load_mixed(state_dict, 'mixed_4c', sess,
300 | os.path.join(prefix, 'Mixed_4c'))
301 | load_mixed(state_dict, 'mixed_4d', sess,
302 | os.path.join(prefix, 'Mixed_4d'))
303 | load_mixed(state_dict, 'mixed_4e', sess,
304 | os.path.join(prefix, 'Mixed_4e'))
305 |         # from here on, the max error w.r.t. the TF weights goes up to 0.1
306 | load_mixed(state_dict, 'mixed_4f', sess,
307 | os.path.join(prefix, 'Mixed_4f'))
308 |
309 | load_mixed(
310 | state_dict,
311 | 'mixed_5b',
312 | sess,
313 | os.path.join(prefix, 'Mixed_5b'),
314 | fix_typo=True)
315 | load_mixed(state_dict, 'mixed_5c', sess,
316 | os.path.join(prefix, 'Mixed_5c'))
317 | load_conv3d(
318 | state_dict,
319 | 'conv3d_0c_1x1',
320 | sess,
321 | os.path.join(prefix, 'Logits', 'Conv3d_0c_1x1'),
322 | bias=True,
323 | bn=False)
324 | self.load_state_dict(state_dict)
325 |
326 |
327 | def get_conv_params(sess, name, bias=False):
328 | # Get conv weights
329 | conv_weights_tensor = sess.graph.get_tensor_by_name(
330 | os.path.join(name, 'w:0'))
331 | if bias:
332 | conv_bias_tensor = sess.graph.get_tensor_by_name(
333 | os.path.join(name, 'b:0'))
334 | conv_bias = sess.run(conv_bias_tensor)
335 | conv_weights = sess.run(conv_weights_tensor)
336 | conv_shape = conv_weights.shape
337 |
338 | kernel_shape = conv_shape[0:3]
339 | in_channels = conv_shape[3]
340 | out_channels = conv_shape[4]
341 |
342 | conv_op = sess.graph.get_operation_by_name(
343 | os.path.join(name, 'convolution'))
344 | padding_name = conv_op.get_attr('padding')
345 | padding = _get_padding(padding_name, kernel_shape)
346 | all_strides = conv_op.get_attr('strides')
347 | strides = all_strides[1:4]
348 | conv_params = [
349 | conv_weights, kernel_shape, in_channels, out_channels, strides, padding
350 | ]
351 | if bias:
352 | conv_params.append(conv_bias)
353 | return conv_params
354 |
355 |
356 | def get_bn_params(sess, name):
357 | moving_mean_tensor = sess.graph.get_tensor_by_name(
358 | os.path.join(name, 'moving_mean:0'))
359 | moving_var_tensor = sess.graph.get_tensor_by_name(
360 | os.path.join(name, 'moving_variance:0'))
361 | beta_tensor = sess.graph.get_tensor_by_name(os.path.join(name, 'beta:0'))
362 | moving_mean = sess.run(moving_mean_tensor)
363 | moving_var = sess.run(moving_var_tensor)
364 | beta = sess.run(beta_tensor)
365 | return moving_mean, moving_var, beta
366 |
367 |
368 | def _get_padding(padding_name, conv_shape):
369 | padding_name = padding_name.decode("utf-8")
370 | if padding_name == "VALID":
371 | return [0, 0]
372 | elif padding_name == "SAME":
373 | # return [math.ceil(int(conv_shape[0])/2), math.ceil(int(conv_shape[1])/2)]
374 | return [
375 | math.floor(int(conv_shape[0]) / 2),
376 | math.floor(int(conv_shape[1]) / 2),
377 | math.floor(int(conv_shape[2]) / 2)
378 | ]
379 | else:
380 | raise ValueError('Invalid padding name ' + padding_name)
381 |
382 |
383 | def load_conv3d(state_dict, name_pt, sess, name_tf, bias=False, bn=True):
384 | # Transfer convolution params
385 | conv_name_tf = os.path.join(name_tf, 'conv_3d')
386 | conv_params = get_conv_params(sess, conv_name_tf, bias=bias)
387 | if bias:
388 | conv_weights, kernel_shape, in_channels, out_channels, strides, padding, conv_bias = conv_params
389 | else:
390 | conv_weights, kernel_shape, in_channels, out_channels, strides, padding = conv_params
391 |
392 | conv_weights_rs = np.transpose(
393 | conv_weights, (4, 3, 0, 1,
394 | 2)) # to pt format (out_c, in_c, depth, height, width)
395 | state_dict[name_pt + '.conv3d.weight'] = torch.from_numpy(conv_weights_rs)
396 | if bias:
397 | state_dict[name_pt + '.conv3d.bias'] = torch.from_numpy(conv_bias)
398 |
399 | # Transfer batch norm params
400 | if bn:
401 | conv_tf_name = os.path.join(name_tf, 'batch_norm')
402 | moving_mean, moving_var, beta = get_bn_params(sess, conv_tf_name)
403 |
404 | out_planes = conv_weights_rs.shape[0]
405 | state_dict[name_pt + '.batch3d.weight'] = torch.ones(out_planes)
406 | state_dict[name_pt +
407 | '.batch3d.bias'] = torch.from_numpy(beta.squeeze())
408 | state_dict[name_pt
409 | + '.batch3d.running_mean'] = torch.from_numpy(moving_mean.squeeze())
410 | state_dict[name_pt
411 | + '.batch3d.running_var'] = torch.from_numpy(moving_var.squeeze())
412 |
413 |
414 | def load_mixed(state_dict, name_pt, sess, name_tf, fix_typo=False):
415 | # Branch 0
416 | load_conv3d(state_dict, name_pt + '.branch_0', sess,
417 | os.path.join(name_tf, 'Branch_0/Conv3d_0a_1x1'))
418 |
419 |     # Branch 1
420 | load_conv3d(state_dict, name_pt + '.branch_1.0', sess,
421 | os.path.join(name_tf, 'Branch_1/Conv3d_0a_1x1'))
422 | load_conv3d(state_dict, name_pt + '.branch_1.1', sess,
423 | os.path.join(name_tf, 'Branch_1/Conv3d_0b_3x3'))
424 |
425 | # Branch 2
426 | load_conv3d(state_dict, name_pt + '.branch_2.0', sess,
427 | os.path.join(name_tf, 'Branch_2/Conv3d_0a_1x1'))
428 | if fix_typo:
429 | load_conv3d(state_dict, name_pt + '.branch_2.1', sess,
430 | os.path.join(name_tf, 'Branch_2/Conv3d_0a_3x3'))
431 | else:
432 | load_conv3d(state_dict, name_pt + '.branch_2.1', sess,
433 | os.path.join(name_tf, 'Branch_2/Conv3d_0b_3x3'))
434 |
435 | # Branch 3
436 | load_conv3d(state_dict, name_pt + '.branch_3.1', sess,
437 | os.path.join(name_tf, 'Branch_3/Conv3d_0b_1x1'))
438 |
439 |
440 |
441 |
442 |
443 |
--------------------------------------------------------------------------------
/libs/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import sys
4 | import numpy as np
5 | import random
6 | from copy import deepcopy
7 | from pprint import pprint
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.optim as optim
12 | import torch.backends.cudnn as cudnn
13 | import torch.nn.functional as F
14 | import gc
15 |
16 | from .lr_schedulers import LinearWarmupMultiStepLR, LinearWarmupCosineAnnealingLR, LinearWarmupNoDecayLR
17 |
18 | from ..modeling import MaskedConv1D, Scale, AffineDropPath, LayerNorm
19 | from .metrics import evaluate, MetricsDict
20 | from .preprocessing import preprocessing
21 | from .postprocessing import update_metric_dict_with_model_output
22 | from collections import defaultdict
23 |
24 |
25 | from torch.profiler import profile, record_function, ProfilerActivity
26 |
27 | class Logger(object):
28 | def __init__(self, log_file):
29 | self.terminal = sys.stdout
30 | self.log_file = log_file
31 |
32 | def write(self, message):
33 | #import ipdb; ipdb.set_trace()
34 | message = str(message)
35 | #pprint(message, stream = sys.__stdout__)
36 | self.terminal.write(message)
37 | #print(message + "\n", file=sys.__stdout__, end = "")
38 | with open(self.log_file, 'a') as f:
39 | f.write(message)
40 | def flush(self):
41 | self.terminal.flush()
42 |
43 | def set_bn_eval(m):
44 | classname = m.__class__.__name__
45 | if classname.find('BatchNorm') != -1:
46 | m.eval()
47 |
48 | def freeze_top_3_blocks(m):
49 | classname = m.__class__.__name__
50 | if (
51 | classname.find('conv3d_1a_7x7') != -1 or
52 | classname.find('conv3d_2b_1x1') != -1 or
53 | classname.find('conv3d_2c_3x3') != -1 or
54 | classname.find('mixed_3b') != -1 or
55 | classname.find('mixed_3c') != -1 or
56 | classname.find('mixed_4b') != -1
57 | ):
58 | m.requires_grad = False
59 | m.eval()
60 |
61 |
62 | ################################################################################
63 | def fix_random_seed(seed, include_cuda=True):
64 | rng_generator = torch.manual_seed(seed)
65 | np.random.seed(seed)
66 | random.seed(seed)
67 | os.environ["PYTHONHASHSEED"] = str(seed)
68 | if include_cuda:
69 | # training: disable cudnn benchmark to ensure the reproducibility
70 | cudnn.enabled = True
71 | cudnn.benchmark = False
72 | cudnn.deterministic = True
73 | torch.cuda.manual_seed(seed)
74 | torch.cuda.manual_seed_all(seed)
75 | # this is needed for CUDA >= 10.2
76 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
77 |         # enforce deterministic algorithms (warn only when an op has no deterministic implementation)
78 | torch.use_deterministic_algorithms(True, warn_only=True)
79 | else:
80 | cudnn.enabled = True
81 | cudnn.benchmark = True
82 | return rng_generator
83 |
84 |
85 | def save_checkpoint(state, is_best, file_folder,
86 | file_name='checkpoint.pth.tar'):
87 | """save checkpoint to file"""
88 | if not os.path.exists(file_folder):
89 | os.mkdir(file_folder)
90 | torch.save(state, os.path.join(file_folder, file_name))
91 | if is_best:
92 | # skip the optimization / scheduler state
93 | state.pop('optimizer', None)
94 | state.pop('scheduler', None)
95 | torch.save(state, os.path.join(file_folder, 'model_best.pth.tar'))
96 |
97 |
98 | def print_model_params(model):
99 | for name, param in model.named_parameters():
100 | print(name, param.min().item(), param.max().item(), param.mean().item())
101 | return
102 |
103 |
104 | def make_optimizer(model, optimizer_config):
105 | """create optimizer
106 | return a supported optimizer
107 | """
108 | # separate out all parameters that with / without weight decay
109 | # see https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L134
110 |
111 | low_lr_decay = set()
112 | low_lr_no_decay = set()
113 |
114 | med_lr_decay = set()
115 | med_lr_no_decay = set()
116 |
117 | high_lr_decay = set()
118 | high_lr_no_decay = set()
119 |
120 | whitelist_weight_modules = (torch.nn.Linear, nn.Parameter, torch.nn.Conv1d, torch.nn.Conv2d, MaskedConv1D, torch.nn.Conv3d, torch.nn.MultiheadAttention, torch.nn.ConvTranspose1d)
121 | blacklist_weight_modules = (LayerNorm, torch.nn.GroupNorm,torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.Embedding, torch.nn.LayerNorm)
122 |
123 | # three sets of lr: low_lr for feat_extractor/i3d, med_lr for neck/backbone, and high_lr for heads.
124 | # loop over all modules / params
125 | for mn, m in model.named_modules():
126 | for pn, p in m.named_parameters():
127 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
128 | #low lr for feat extractor
129 | if "feat_extractor" in fpn or "feat_extractor" in pn:
130 | if pn.endswith('bias'):
131 | # all biases will not be decayed
132 | low_lr_no_decay.add(fpn)
133 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
134 | # weights of whitelist modules will be weight decayed
135 | low_lr_decay.add(fpn)
136 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
137 | # weights of blacklist modules will NOT be weight decayed
138 | low_lr_no_decay.add(fpn)
139 | elif pn.endswith('scale') and isinstance(m, (Scale, AffineDropPath)):
140 | # corner case of our scale layer
141 | low_lr_no_decay.add(fpn)
142 | elif pn.endswith('rel_pe'):
143 | # corner case for relative position encoding
144 | low_lr_no_decay.add(fpn)
145 | # med lr for neck
146 | elif ("neck" in fpn or "neck" in pn) or ("backbone" in fpn or "backbone" in pn):
147 | if pn.endswith('bias'):
148 | med_lr_no_decay.add(fpn)
149 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
150 | med_lr_decay.add(fpn)
151 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
152 | med_lr_no_decay.add(fpn)
153 | elif pn.endswith('scale') and isinstance(m, (Scale, AffineDropPath)):
154 | med_lr_no_decay.add(fpn)
155 | elif pn.endswith('rel_pe'):
156 | med_lr_no_decay.add(fpn)
157 | # high lr for quality_score_head
158 | elif ("quality_score_head" in fpn or "quality_score_head" in pn):
159 | if pn.endswith('bias'):
160 | # all biases will not be decayed
161 | high_lr_no_decay.add(fpn)
162 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
163 | # weights of whitelist modules will be weight decayed
164 | high_lr_decay.add(fpn)
165 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
166 | # weights of blacklist modules will NOT be weight decayed
167 | high_lr_no_decay.add(fpn)
168 | elif pn.endswith('scale') and isinstance(m, (Scale, AffineDropPath)):
169 | # corner case of our scale layer
170 | high_lr_no_decay.add(fpn)
171 | elif pn.endswith('rel_pe'):
172 | # corner case for relative position encoding
173 | high_lr_no_decay.add(fpn)
174 | elif pn.endswith('temperature'):
175 | high_lr_no_decay.add(fpn)
176 | else:
177 | raise ValueError(f"Unrecognized parameter: {fpn}")
178 |
179 | # validate that we considered every parameter
180 | param_dict = {pn: p for pn, p in model.named_parameters()}
181 |
182 | # @TODO add assert statements for med_lr_decay, med_lr_no_decay?
183 | #check that all intersections are empty
184 | assert len(high_lr_decay & high_lr_no_decay) == 0, \
185 | "parameters %s were in both high_lr_decay and high_lr_no_decay set!" \
186 | % (str(high_lr_decay & high_lr_no_decay), )
187 | assert len(low_lr_decay & low_lr_no_decay) == 0, \
188 | "parameters %s were in both low_lr_decay and low_lr_no_decay set!" \
189 | % (str(low_lr_decay & low_lr_no_decay), )
190 | assert len(high_lr_decay & low_lr_decay) == 0, \
191 | "parameters %s were in both high_lr_decay and low_lr_decay set!" \
192 | % (str(high_lr_decay & low_lr_decay), )
193 | assert len(high_lr_no_decay & low_lr_no_decay) == 0, \
194 | "parameters %s were in both high_lr_no_decay and low_lr_no_decay set!" \
195 | % (str(high_lr_no_decay & low_lr_no_decay), )
196 |
197 |
198 | union_params = high_lr_decay | high_lr_no_decay | low_lr_decay | low_lr_no_decay | med_lr_decay | med_lr_no_decay
199 | assert len(param_dict.keys() - union_params) == 0, \
200 | "parameters %s were not separated into either decay/no_decay set!" \
201 | % (str(param_dict.keys() - union_params), )
202 |
203 | # create the pytorch optimizer object
204 | if (len(low_lr_decay) + len(low_lr_no_decay) + len(med_lr_decay) + len(med_lr_no_decay)) > 0:
205 | optim_groups = [
206 | {"params": [param_dict[pn] for pn in sorted(list(high_lr_decay))],
207 | "weight_decay": optimizer_config['weight_decay'],
208 | 'name': 'high_lr_decay'},
209 | {"params": [param_dict[pn] for pn in sorted(list(high_lr_no_decay))],
210 | "weight_decay": 0.0,
211 | 'name': 'high_lr_no_decay'},
212 | {"params": [param_dict[pn] for pn in sorted(list(med_lr_decay))],
213 | "weight_decay": optimizer_config['weight_decay'],
214 | "lr": optimizer_config["learning_rate"] * optimizer_config["neck_lr_factor"],
215 | 'name': 'med_lr_decay'},
216 | {"params": [param_dict[pn] for pn in sorted(list(med_lr_no_decay))],
217 | "weight_decay": 0.0,
218 | "lr": optimizer_config["learning_rate"] * optimizer_config["neck_lr_factor"],
219 | 'name': 'med_lr_no_decay'},
220 | {"params": [param_dict[pn] for pn in sorted(list(low_lr_decay))],
221 | "weight_decay": optimizer_config['weight_decay'],
222 | "lr": optimizer_config["learning_rate"] * optimizer_config["feature_extractor_factor"],
223 | 'name': 'low_lr_decay'},
224 | {"params": [param_dict[pn] for pn in sorted(list(low_lr_no_decay))],
225 | "weight_decay": 0.0,
226 | "lr": optimizer_config["learning_rate"] * optimizer_config["feature_extractor_factor"],
227 | 'name': 'low_lr_no_decay'}
228 | ]
229 | else:
230 | optim_groups = [
231 | {"params": [param_dict[pn] for pn in sorted(list(high_lr_decay))], "weight_decay": optimizer_config['weight_decay']},
232 | {"params": [param_dict[pn] for pn in sorted(list(high_lr_no_decay))], "weight_decay": 0.0}
233 | ]
234 |
235 | if optimizer_config["type"] == "SGD":
236 | optimizer = optim.SGD(
237 | optim_groups,
238 | lr=optimizer_config["learning_rate"],
239 | momentum=optimizer_config["momentum"]
240 | )
241 | elif optimizer_config["type"] == "AdamW":
242 | optimizer = optim.AdamW(
243 | optim_groups,
244 | lr=optimizer_config["learning_rate"]
245 | )
246 | else:
247 | raise TypeError("Unsupported optimizer!")
248 |
249 | return optimizer
250 |
251 |
252 | def make_scheduler(
253 | optimizer,
254 | optimizer_config,
255 | num_iters_per_epoch,
256 | last_epoch=-1
257 | ):
258 | """create scheduler
259 | return a supported scheduler
260 | All scheduler returned by this function should step every iteration
261 | """
262 | if optimizer_config["warmup"]:
263 | max_epochs = optimizer_config["epochs"] + optimizer_config["warmup_epochs"]
264 | max_steps = max_epochs * num_iters_per_epoch
265 |
266 | # get warmup params
267 | warmup_epochs = optimizer_config["warmup_epochs"]
268 | warmup_steps = warmup_epochs * num_iters_per_epoch
269 |
270 | # with linear warmup: call our custom schedulers
271 | if optimizer_config["schedule_type"] == "cosine":
272 | # Cosine
273 | scheduler = LinearWarmupCosineAnnealingLR(
274 | optimizer,
275 | warmup_steps,
276 | max_steps,
277 | last_epoch=last_epoch
278 | )
279 |
280 | elif optimizer_config["schedule_type"] == "multistep":
281 | # Multi step
282 | steps = [num_iters_per_epoch * step for step in optimizer_config["schedule_steps"]]
283 | scheduler = LinearWarmupMultiStepLR(
284 | optimizer,
285 | warmup_steps,
286 | steps,
287 | gamma=optimizer_config["schedule_gamma"],
288 | last_epoch=last_epoch
289 | )
290 | elif optimizer_config["schedule_type"] == "no_decay":
291 | steps = [num_iters_per_epoch * step for step in optimizer_config["schedule_steps"]]
292 | scheduler = LinearWarmupNoDecayLR(
293 | optimizer,
294 | warmup_steps,
295 | steps,
296 | last_epoch=last_epoch
297 | )
298 | else:
299 | raise TypeError("Unsupported scheduler!")
300 |
301 | else:
302 | max_epochs = optimizer_config["epochs"]
303 | max_steps = max_epochs * num_iters_per_epoch
304 |
305 | # without warmup: call default schedulers
306 | if optimizer_config["schedule_type"] == "cosine":
307 | # step per iteration
308 | scheduler = optim.lr_scheduler.CosineAnnealingLR(
309 | optimizer,
310 | max_steps,
311 | last_epoch=last_epoch
312 | )
313 |
314 | elif optimizer_config["schedule_type"] == "multistep":
315 | # step every some epochs
316 | steps = [num_iters_per_epoch * step for step in optimizer_config["schedule_steps"]]
317 | scheduler = optim.lr_scheduler.MultiStepLR(
318 | optimizer,
319 | steps,
320 |                 gamma=optimizer_config["schedule_gamma"],
321 | last_epoch=last_epoch
322 | )
323 | else:
324 | raise TypeError("Unsupported scheduler!")
325 |
326 | return scheduler
327 |
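# --- Editor's note: illustrative sketch, not part of the original file. --------
# The key names below mirror the accesses in make_optimizer / make_scheduler
# above; the values are placeholders (the real ones live in the YAML files under
# configs/), and `model` / `train_loader` are assumed to be built elsewhere.
#
# >>> opt_cfg = {
# ...     "type": "AdamW",                    # or "SGD" (then "momentum" is read)
# ...     "learning_rate": 1e-3,              # high lr, used for the score head
# ...     "weight_decay": 1e-4,
# ...     "neck_lr_factor": 0.1,              # med lr = learning_rate * neck_lr_factor
# ...     "feature_extractor_factor": 0.01,   # low lr = learning_rate * feature_extractor_factor
# ...     "warmup": True,
# ...     "warmup_epochs": 5,
# ...     "epochs": 30,
# ...     "schedule_type": "cosine",          # or "multistep" / "no_decay"
# ...     "schedule_steps": [20, 25],         # used by "multistep" / "no_decay"
# ...     "schedule_gamma": 0.1,
# ... }
# >>> optimizer = make_optimizer(model, opt_cfg)
# >>> scheduler = make_scheduler(optimizer, opt_cfg, num_iters_per_epoch=len(train_loader))
# -------------------------------------------------------------------------------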
328 |
329 | class AverageMeter(object):
330 | """Computes and stores the average and current value.
331 | Used to compute dataset stats from mini-batches
332 | """
333 | def __init__(self):
334 | self.initialized = False
335 | self.val = None
336 | self.avg = None
337 | self.sum = None
338 | self.count = 0.0
339 |
340 | def initialize(self, val, n):
341 | self.val = val
342 | self.avg = val
343 | self.sum = val * n
344 | self.count = n
345 | self.initialized = True
346 |
347 | def update(self, val, n=1):
348 | if not self.initialized:
349 | self.initialize(val, n)
350 | else:
351 | self.add(val, n)
352 |
353 | def add(self, val, n):
354 | self.val = val
355 | self.sum += val * n
356 | self.count += n
357 | self.avg = self.sum / self.count
358 |
359 |
360 | class ModelEma(torch.nn.Module):
361 | def __init__(self, model, decay=0.999, device=None):
362 | super().__init__()
363 | # make a copy of the model for accumulating moving average of weights
364 | self.module = deepcopy(model)
365 | self.module.eval()
366 | self.decay = decay
367 | self.device = device # perform ema on different device from model if set
368 | if self.device is not None:
369 | self.module.to(device=device)
370 |
371 | def _update(self, model, update_fn):
372 | with torch.no_grad():
373 | for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
374 | if self.device is not None:
375 | model_v = model_v.to(device=self.device)
376 | ema_v.copy_(update_fn(ema_v, model_v))
377 |
378 | def update(self, model):
379 | self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
380 |
381 | def set(self, model):
382 | self._update(model, update_fn=lambda e, m: m)
383 |
384 |
385 | def my_corn_label_from_logits(logits):
386 | probs = torch.sigmoid(logits)
387 | cum_prob = torch.cumprod(probs, dim=-1)
388 | predict_levels = cum_prob > 0.5
389 | predict_labels = torch.sum(predict_levels, dim=-1)
390 | return predict_labels
391 |
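# --- Editor's note: illustrative check, not part of the original file. ---------
# my_corn_label_from_logits decodes CORN-style ordinal logits: sigmoid gives the
# conditional probabilities P(y > k | y > k-1), their cumulative product gives
# P(y > k), and the predicted label counts how many of these exceed 0.5.
#
# >>> logits = torch.tensor([[3.0, 2.0, -1.0, -2.0]])
# >>> my_corn_label_from_logits(logits)
# tensor([2])
# -------------------------------------------------------------------------------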
392 |
393 |
394 | def print_training_progress(curr_epoch, curr_iter, total_iters, batch_time, losses_tracker, rho, L2, RL2):
395 | # Format the training progress information
396 | progress_str = [
397 | f"Epoch: [{curr_epoch:03d}][{curr_iter:05d}/{total_iters:05d}]",
398 | f"Time: {batch_time.val:.2f}s ({batch_time.avg:.2f}s)\n",
399 | f"Total Loss: {losses_tracker['final_loss'].val:.2f} ({losses_tracker['final_loss'].avg:.2f})"
400 | ]
401 |
402 | loss_info = ", ".join([f"{key}: {value.val:.2f} ({value.avg:.2f})" for key, value in losses_tracker.items() if key != "final_loss"])
403 | progress_str.append(loss_info)
404 |
405 | progress_str.append(f"Running -> Rho: {rho:.4f}, L2: {L2:.2f}, RL2: {RL2:.4f}")
406 |
407 | print('\t'.join(progress_str))
408 |
409 |
410 |
411 |
412 | ################################################################################
413 | def train_one_epoch(
414 | train_loader,
415 | model,
416 | optimizer,
417 | scheduler,
418 | curr_epoch,
419 | cfg = None,
420 | tb_writer = None,
421 | print_freq = 20
422 | ):
423 | """Training the model for one epoch"""
424 | # set up meters
425 | batch_time = AverageMeter()
426 | losses_tracker = defaultdict(AverageMeter)
427 | metric_dict = MetricsDict()
428 | # number of iterations per epoch
429 | num_iters = len(train_loader)
430 | # switch to train mode
431 | model.train()
432 |
433 | #fix bn
434 | model.apply(set_bn_eval)
435 |
436 | if "freeze_early_layers" in cfg["train_cfg"]:
437 | if cfg["train_cfg"]["freeze_early_layers"]:
438 | print("Freezing early layers")
439 | model.apply(freeze_top_3_blocks)
440 |
441 | # main training loop
442 | print("\n[Train]: Epoch {:d} started".format(curr_epoch))
443 | start = time.time()
444 |
445 | optimizer.zero_grad()
446 | for iter_idx, video_list in enumerate(train_loader, 0):
447 |         # preprocess the mini-batch (gradients are zeroed at the accumulation step below)
448 | preprocessed_dict = preprocessing(video_list,
449 | feat_extractor_type = cfg["model"]["feat_extractor_type"],
450 | max_seq_len= cfg["model"]["max_seq_len"],
451 | num_phases= cfg["model"]["num_phases"]
452 | )
453 |
454 | #loss is calculated in the model
455 | criterion_args = {"target" : preprocessed_dict['target'],
456 | "difficulties" : preprocessed_dict["difficulties"],
457 | "gt_actions" : preprocessed_dict["gt_actions"],
458 | "loss_weights" : cfg["train_cfg"]["loss_weights"],
459 | "with_dd" : cfg["dataset"]["with_dd"],
460 | "three_judge_score_scaling" : cfg["dataset"]["three_judge_score_scaling"]
461 | }
462 | # forward
463 | model_output, losses = model(**dict(batched_inputs = preprocessed_dict["batched_inputs"],
464 | batched_masks = preprocessed_dict["batched_masks"],
465 | batched_gt_actions = preprocessed_dict["gt_actions"],
466 | video_ids = preprocessed_dict["video_ids"],
467 | curr_epoch = curr_epoch,
468 | criterion_args = criterion_args
469 | ))
470 |
471 |
472 | for key, value in losses.items():
473 | losses[key] = losses[key].mean()
474 |
475 | losses["final_loss"] /= cfg["train_cfg"]["accumulation_steps"]
476 |
477 | losses['final_loss'].backward()
478 |
479 |         # gradient clipping (to stabilize training if necessary)
480 | if cfg["train_cfg"]["clip_grad_l2norm"] > 0.0:
481 | torch.nn.utils.clip_grad_norm_(
482 | model.parameters(),
483 | cfg["train_cfg"]["clip_grad_l2norm"]
484 | )
485 |
486 | if ((iter_idx + 1) % cfg["train_cfg"]["accumulation_steps"]) == 0 or (iter_idx == num_iters - 1):
487 | optimizer.step()
488 | optimizer.zero_grad()
489 | scheduler.step()
490 |
491 |
492 | with torch.no_grad():
493 | update_metric_dict_with_model_output(metric_dict, model_output, preprocessed_dict['gt_scores'],preprocessed_dict['difficulties'], is_val=False, cfg=cfg)
494 |
495 | # track all losses
496 | for key, value in losses.items():
497 | losses_tracker[key].update(value.item())
498 |
499 | del losses
500 |
501 | gc.collect()
502 |
503 | # printing (only check the stats when necessary to avoid extra cost)
504 | if ((iter_idx != 0) and (iter_idx % print_freq) == 0) or (iter_idx == num_iters - 1):
505 | # measure elapsed time (sync all kernels)
506 | torch.cuda.synchronize()
507 | batch_time.update((time.time() - start) / print_freq)
508 | start = time.time()
509 |
510 | rho, L2, RL2 = evaluate(metric_dict,
511 | is_train=True,
512 | dataset_name=cfg["dataset_name"],
513 | )
514 |
515 | # log to tensor board
516 | lr = scheduler.get_last_lr()[0]
517 | global_step = curr_epoch * num_iters + iter_idx
518 | if tb_writer is not None:
519 | # learning rate (after stepping)
520 | tb_writer.add_scalar(
521 | 'train/learning_rate',
522 | lr,
523 | global_step
524 | )
525 | # all losses
526 | losses_vals_dict = {key : value.val for key, value in losses_tracker.items()}
527 |
528 | tb_writer.add_scalars('train/all_losses', losses_vals_dict, global_step)
529 | tb_writer.add_scalars('train/L2, RL2', {'L2': L2, 'RL2': RL2}, global_step)
530 | tb_writer.add_scalar('train/rho', rho, global_step)
531 |
532 | print_training_progress(curr_epoch, iter_idx, num_iters, batch_time, losses_tracker, rho, L2, RL2)
533 |
534 | rho, L2, RL2 = evaluate(metric_dict,
535 | is_train = True,
536 | dataset_name = cfg["dataset_name"]
537 | )
538 |
539 |
540 | print('Full epoch metrics -> Rho : {:.4f}, L2 : {:.2f}, RL2 : {:.4f}'.format(rho, L2, RL2))
541 |
542 | lr = scheduler.get_last_lr()[0]
543 | print("[Train]: Epoch {:d} finished with lr={:.8f}\n".format(curr_epoch, lr))
544 |
545 | return
546 |
547 |
548 | def valid_one_epoch(
549 | val_loader,
550 | model,
551 | curr_epoch,
552 | cfg,
553 | tb_writer,
554 | print_freq,
555 | save_predictions = False
556 | ):
557 |
558 | """Test the model on the validation set"""
559 |
560 | # switch to evaluate mode
561 | model.eval()
562 |
563 | num_iters = len(val_loader)
564 | metric_dict = MetricsDict()
565 |
566 | # loop over validation set
567 |     print("[Validation]")
568 | for iter_idx, video_list in enumerate(val_loader, 0):
569 | # forward the model (wo. grad)
570 | with torch.no_grad():
571 | preprocessed_dict = preprocessing(video_list,
572 | feat_extractor_type = cfg["model"]["feat_extractor_type"],
573 | max_seq_len= cfg["model"]["max_seq_len"],
574 | num_phases= cfg["model"]["num_phases"]
575 | )
576 | # forward / backward the model
577 | model_output, _ = model(**dict(batched_inputs = preprocessed_dict["batched_inputs"],
578 | batched_masks = preprocessed_dict["batched_masks"],
579 | batched_gt_actions = preprocessed_dict["gt_actions"],
580 | video_ids = preprocessed_dict["video_ids"],
581 | curr_epoch = curr_epoch
582 | ))
583 | # only used for evaluation
584 | metric_dict.update("video_ids", preprocessed_dict["video_ids"])
585 |
586 | update_metric_dict_with_model_output(metric_dict, model_output, preprocessed_dict['gt_scores'],preprocessed_dict['difficulties'], is_val=True, cfg=cfg)
587 |
588 | if ((iter_idx != 0) and (iter_idx % print_freq) == 0) or (iter_idx == num_iters - 1):
589 | torch.cuda.synchronize()
590 |
591 | block0 = f'Epoch: [{curr_epoch:03d}][{iter_idx:05d}/{num_iters:05d}]'
592 | print(block0)
593 |
594 |
595 | rho, L2, RL2 = evaluate(metric_dict,
596 | is_train=False,
597 | dataset_name = cfg["dataset_name"],
598 | curr_epoch = curr_epoch
599 | )
600 |
601 | print('Eval Metrics -> Rho : {:.4f}, L2 : {:.2f}, RL2 : {:.4f}'.format(rho, L2, RL2))
602 |
603 | # log metrics to tb_writer
604 | if tb_writer is not None:
605 | tb_writer.add_scalar('validation/rho', rho, curr_epoch)
606 | tb_writer.add_scalars('validation/L2, RL2', {'L2': L2, 'RL2': RL2}, curr_epoch)
607 |
608 | if save_predictions:
609 | metric_dict = metric_dict.get_metric_dict()
610 | else:
611 | metric_dict = None
612 |
613 | return rho, RL2, metric_dict
--------------------------------------------------------------------------------
/libs/modeling/blocks.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn
7 | from .weight_init import trunc_normal_
8 |
9 | def reparameterize(B_mu, B_sigma, num_samples=1):
10 | expanded_mu = B_mu.expand(num_samples, *B_mu.shape)
11 | expanded_sigma = B_sigma.expand(num_samples, *B_sigma.shape)
12 | norm_v = torch.randn_like(expanded_mu).detach()
13 | # reparameterization trick
14 | samples = expanded_mu + expanded_sigma * norm_v
15 | return samples
16 |
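# --- Editor's note: illustrative check, not part of the original file. ---------
# reparameterize draws num_samples Gaussian samples per element via the
# reparameterization trick (mu + sigma * eps with eps ~ N(0, I)), so gradients
# can flow through mu and sigma. Samples are stacked along a new leading dim:
#
# >>> mu = torch.zeros(2, 3, 4)      # B x phases x C
# >>> sigma = torch.ones(2, 3, 4)
# >>> reparameterize(mu, sigma, num_samples=5).shape
# torch.Size([5, 2, 3, 4])
# -------------------------------------------------------------------------------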
17 | class PhaseDistComposer(nn.Module):
18 | def __init__(self, dataset_name='finediving', dim = 256):
19 | super().__init__()
20 | self.dataset_name = dataset_name
21 | if dataset_name == 'finediving':
22 | self.level_0_to_1_index = [torch.tensor(x).detach().long() for x in [[0,1,2,3,4,5,6],[7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27],[28]]]
23 | elif dataset_name =="mtl_aqa":
24 | self.level_0_to_1_index = [torch.tensor(x).detach().long() for x in [[0,1,2,3,4],[5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22],[23]]]
25 | elif dataset_name == "needle_passing":
26 | self.level_0_to_1_index = [torch.tensor(x).detach().long() for x in [[0,1,2,3,4,5,6,7],[0,1,2,3,4,5,6,7],[0,1,2,3,4,5,6,7]]]
27 | elif dataset_name == "knot_tying":
28 | self.level_0_to_1_index = [torch.tensor(x).detach().long() for x in [[0,1,2,3,4,5],[0,1,2,3,4,5],[0,1,2,3,4,5]]]
29 | elif dataset_name == "suturing":
30 | self.level_0_to_1_index = [torch.tensor(x).detach().long() for x in [[0,1,2,3,4,5,6,7,8,9],[0,1,2,3,4,5,6,7,8,9],[0,1,2,3,4,5,6,7,8,9]]]
31 |
32 | self.level_2_mlp = nn.Sequential(
33 | nn.Linear(dim,dim),
34 | nn.ReLU(),
35 | nn.Linear(dim,dim)
36 | )
37 |
38 | self.level_3_mlp = nn.Sequential(
39 | nn.Linear(dim,dim),
40 | nn.ReLU(),
41 | nn.Linear(dim,dim)
42 | )
43 |
44 | def get_one_sample(self, phase_mean_emb, phase_var_emb, gt_actions):
45 | #B, Phases, C
46 | if phase_var_emb is None:
47 | level_0_samples = phase_mean_emb
48 | else:
49 | phase_sqrt_var_emb = torch.sqrt(phase_var_emb)
50 | level_0_samples = reparameterize(phase_mean_emb, phase_sqrt_var_emb, num_samples=1).squeeze(0)
51 |
52 | B, num_phases, C = phase_mean_emb.shape
53 |
54 | level_1_inputs_separated = [level_0_samples[:,node_input_idx,:] for node_input_idx in self.level_0_to_1_index]
55 |
56 | level_1_presence_separated = [gt_actions[:,node_input_idx].unsqueeze(-1) for node_input_idx in self.level_0_to_1_index]
57 |
58 | #num nodes, B x phases going into the node x C
59 | level_1_inputs_filtered_by_presence = [level_1_inputs_separated[node_num] * level_1_presence_separated[node_num] for node_num in range(len(level_1_inputs_separated))]
60 |
61 | #num nodes, B x C
62 | level_1_inputs_summed_separate = [torch.sum(level_1_inputs_filtered_by_presence[node_num], dim=1)
63 | / torch.sum(level_1_presence_separated[node_num], dim=1)
64 | for node_num in range(len(level_1_inputs_filtered_by_presence))]
65 |
66 | #num nodes, B x C
67 | level_2_outputs = [self.level_2_mlp(x) for x in level_1_inputs_summed_separate]
68 |
69 | #num_nodes x B x C
70 | level_2_outputs = torch.stack(level_2_outputs, dim=0)
71 |
72 | #level_3_input is the same as output since we have more decoding heads later
73 | level_3_output = torch.sum(level_2_outputs, dim=0) / level_2_outputs.shape[0]
74 |
75 | #check if nan in level_3_output
76 | if torch.isnan(level_3_output).any():
77 | import ipdb; ipdb.set_trace()
78 |
79 | return level_3_output
80 |
81 |
82 | def process_phases(self, phase_mean_emb, phase_var_emb, num_samples, gt_actions):
83 | all_samples = []
84 | for i in range(num_samples):
85 | one_sample_per_item = self.get_one_sample(phase_mean_emb, phase_var_emb, gt_actions)
86 | all_samples.append(one_sample_per_item)
87 |
88 | #num_samples, B, C
89 | all_samples = torch.stack(all_samples, dim=0)
90 |
91 | if phase_var_emb is None:
92 | return all_samples, None
93 |
94 | masked_phase_var_emb = phase_var_emb * gt_actions.unsqueeze(-1)
95 |
96 | #global variance
97 | global_masked_var = torch.sum(masked_phase_var_emb, dim=1) / torch.sum(gt_actions, dim=1).unsqueeze(-1).detach()
98 |
99 | #global sigma
100 | global_sigma = torch.sqrt(global_masked_var)
101 |
102 | return all_samples, global_sigma
103 |
104 |
105 | class MaskedConv1D(nn.Module):
106 | """
107 | Masked 1D convolution. Interface remains the same as Conv1d.
108 |     Only supports a subset of 1d convs
109 | """
110 | def __init__(
111 | self,
112 | in_channels,
113 | out_channels,
114 | kernel_size,
115 | stride=1,
116 | padding=0,
117 | dilation=1,
118 | groups=1,
119 | bias=True,
120 | padding_mode='zeros'
121 | ):
122 | super().__init__()
123 | # element must be aligned
124 | assert (kernel_size % 2 == 1) and (kernel_size // 2 == padding)
125 | # stride
126 | self.stride = stride
127 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
128 | stride, padding, dilation, groups, bias, padding_mode)
129 | # zero out the bias term if it exists
130 | if bias:
131 | torch.nn.init.constant_(self.conv.bias, 0.)
132 |
133 | def forward(self, x, mask):
134 | # x: batch size, feature channel, sequence length,
135 | # mask: batch size, 1, sequence length (bool)
136 | B, C, T = x.size()
137 |
138 | # input length must be divisible by stride
139 | assert T % self.stride == 0
140 |
141 | # conv
142 | out_conv = self.conv(x)
143 | # compute the mask
144 | if self.stride > 1:
145 | # downsample the mask using nearest neighbor
146 | out_mask = F.interpolate(
147 | mask.to(x.dtype), size=out_conv.size(-1), mode='nearest'
148 | )
149 | else:
150 | # masking out the features
151 | out_mask = mask.to(x.dtype)
152 |
153 | # masking the output, stop grad to mask
154 | out_conv = out_conv * out_mask.detach()
155 | out_mask = out_mask.bool()
156 | return out_conv, out_mask
157 |
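# --- Editor's note: illustrative check, not part of the original file. ---------
# A minimal shape sketch for MaskedConv1D: the boolean mask is downsampled along
# with the feature map when stride > 1, and features at padded positions are
# zeroed in the output.
#
# >>> conv = MaskedConv1D(16, 32, kernel_size=3, stride=2, padding=1)
# >>> x = torch.randn(4, 16, 8)                       # B x C x T
# >>> mask = torch.ones(4, 1, 8, dtype=torch.bool)    # B x 1 x T
# >>> out, out_mask = conv(x, mask)
# >>> out.shape, out_mask.shape
# (torch.Size([4, 32, 4]), torch.Size([4, 1, 4]))
# -------------------------------------------------------------------------------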
158 |
159 | class LayerNorm(nn.Module):
160 | """
161 | LayerNorm that supports inputs of size B, C, T
162 | """
163 | def __init__(
164 | self,
165 | num_channels,
166 | eps = 1e-5,
167 | affine = True,
168 | device = None,
169 | dtype = None,
170 | ):
171 | super().__init__()
172 | factory_kwargs = {'device': device, 'dtype': dtype}
173 | self.num_channels = num_channels
174 | self.eps = eps
175 | self.affine = affine
176 |
177 | if self.affine:
178 | self.weight = nn.Parameter(
179 | torch.ones([1, num_channels, 1], **factory_kwargs))
180 | self.bias = nn.Parameter(
181 | torch.zeros([1, num_channels, 1], **factory_kwargs))
182 | else:
183 | self.register_parameter('weight', None)
184 | self.register_parameter('bias', None)
185 |
186 | def forward(self, x):
187 | assert x.dim() == 3
188 | assert x.shape[1] == self.num_channels
189 |
190 | # normalization along C channels
191 | mu = torch.mean(x, dim=1, keepdim=True)
192 | res_x = x - mu
193 | sigma = torch.mean(res_x**2, dim=1, keepdim=True)
194 | out = res_x / torch.sqrt(sigma + self.eps)
195 |
196 | # apply weight and bias
197 | if self.affine:
198 | out *= self.weight
199 | out += self.bias
200 |
201 | return out
202 |
203 | class MaskedAvgPool1D(nn.Module):
204 | """
205 | Masked 1D average pooling
206 | """
207 | def __init__(self):
208 | super(MaskedAvgPool1D, self).__init__()
209 |
210 | def forward(self, x, mask):
211 | x_sum = torch.sum(x * mask.float(), dim=-1, keepdim=True)
212 | n = torch.sum(mask, dim=-1, keepdim=True)
213 | x = x_sum / n
214 |
215 | return x
216 |
217 |
218 | # helper functions for Transformer blocks
219 | def get_sinusoid_encoding(n_position, d_hid):
220 | ''' Sinusoid position encoding table '''
221 |
222 | def get_position_angle_vec(position):
223 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
224 |
225 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
226 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
227 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
228 |
229 | # return a tensor of size 1 C T
230 | return torch.FloatTensor(sinusoid_table).unsqueeze(0).transpose(1, 2)
231 |
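# --- Editor's note: illustrative check, not part of the original file. ---------
# get_sinusoid_encoding returns a fixed table of shape 1 x C x T that can be
# added to a B x C x T feature map as an absolute positional encoding:
#
# >>> pe = get_sinusoid_encoding(n_position=192, d_hid=256)
# >>> pe.shape
# torch.Size([1, 256, 192])
# -------------------------------------------------------------------------------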
232 | class MaskedMHA(nn.Module):
233 | """
234 | Multi Head Attention with mask
235 | NOTE: This implementation supports
236 | - global and local self-attention
237 | - (global) cross-attention
238 |
239 | Modified from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
240 | """
241 | def __init__(
242 | self,
243 | embd_dim, # embedding dimension
244 | q_dim=None, # query dimension
245 | kv_dim=None, # key / value dimension
246 | out_dim=None, # output dimension
247 | n_heads=4, # number of attention heads
248 | window_size=0, # local attention window size (0 for global attention)
249 | attn_pdrop=0.0, # dropout rate for attention map
250 | proj_pdrop=0.0, # dropout rate for projection
251 | use_rel_pe=False, # whether to apply relative position encoding
252 | ):
253 | super(MaskedMHA, self).__init__()
254 |
255 | assert embd_dim % n_heads == 0
256 | self.embd_dim = embd_dim
257 |
258 | if q_dim is None:
259 | q_dim = embd_dim
260 | if kv_dim is None:
261 | kv_dim = embd_dim
262 | if out_dim is None:
263 | out_dim = q_dim
264 |
265 | self.n_heads = n_heads
266 | self.n_channels = embd_dim // n_heads
267 | self.scale = 1.0 / math.sqrt(math.sqrt(self.n_channels))
268 | self.out_dim = out_dim
269 |
270 | self.query = nn.Conv1d(q_dim, embd_dim, 1)
271 | self.key = nn.Conv1d(kv_dim, embd_dim, 1)
272 | self.value = nn.Conv1d(kv_dim, embd_dim, 1)
273 | self.proj = nn.Conv1d(embd_dim, out_dim, 1)
274 |
275 | self.attn_drop = nn.Dropout(attn_pdrop)
276 | self.proj_drop = nn.Dropout(proj_pdrop)
277 |
278 | # local attention window size
279 | assert window_size == 0 or window_size % 2 == 1
280 | self.window_size = window_size
281 | self.stride = window_size // 2
282 |
283 | # masks for local attention (left / right paddings)
284 | l_mask = torch.ones(self.stride, self.stride + 1).tril().flip(dims=(0,))
285 | r_mask = torch.ones(self.stride, self.stride + 1).tril().flip(dims=(1,))
286 | self.register_buffer('l_mask', l_mask.bool(), persistent=False)
287 | self.register_buffer('r_mask', r_mask.bool(), persistent=False)
288 |
289 | # relative position encoding for local attention
290 | if window_size > 0 and use_rel_pe:
291 | self.rel_pe = nn.Parameter(torch.zeros(n_heads, 1, window_size))
292 | trunc_normal_(self.rel_pe, std=(2.0 / embd_dim) ** 0.5)
293 | else:
294 | self.rel_pe = None
295 |
296 | def _chunk(self, x, size):
297 | """
298 | Convert feature sequence into temporally overlapping chunks.
299 |
300 | Args:
301 | x (float tensor, (n, t, d)): feature sequence.
302 | size (int): chunk size.
303 |
304 | Returns:
305 | x (float tensor, (n, k, s, d)): chunked features.
306 | """
307 | n, t, d = x.size()
308 | assert (t + self.stride - size) % self.stride == 0
309 | n_chunks = (t + self.stride - size) // self.stride
310 |
311 | chunk_size = (n, n_chunks, size, d)
312 | chunk_stride = (x.stride(0), self.stride * x.stride(1), *x.stride()[1:])
313 | x = x.as_strided(size=chunk_size, stride=chunk_stride)
314 |
315 | return x
316 |
317 | def _query_key_matmul(self, q, k):
318 | """
319 | Chunk-wise query-key product.
320 |
321 | Args:
322 | q (float tensor, (n, t, d)): query tensor.
323 | k (float tensor, (n, t, d)): key tensor.
324 |
325 | Returns:
326 | attn (float tensor, (n, t, w)): unnormalized attention scores.
327 | """
328 | assert q.size() == k.size()
329 | n, t, _ = q.size()
330 | w, s = self.window_size, self.stride
331 |
332 | # chunk query and key tensors: (n, t, d) -> (n, t // s - 1, 2s, d)
333 | q_chunks = self._chunk(q.contiguous(), size=2 * s)
334 | k_chunks = self._chunk(k.contiguous(), size=2 * s)
335 | n_chunks = q_chunks.size(1)
336 |
337 | # chunk-wise attention scores: (n, t // s - 1, 2s, 2s)
338 | chunk_attn = torch.einsum('bcxd,bcyd->bcxy', (q_chunks, k_chunks))
339 |
340 | # shift diagonals into columns: (n, t // s - 1, 2s, w)
341 | chunk_attn = F.pad(chunk_attn, (0, 0, 0, 1))
342 | chunk_attn = chunk_attn.view(n, n_chunks, 2 * s, w)
343 |
344 | # fill in the overall attention matrix: (n, t // s, s, w)
345 | attn = chunk_attn.new_zeros(n, t // s, s, w)
346 | attn[:, :-1, :, s:] = chunk_attn[:, :, :s, :s + 1]
347 | attn[:, -1, :, s:] = chunk_attn[:, -1, s:, :s + 1]
348 | attn[:, 1:, :, :s] = chunk_attn[:, :, -(s + 1):-1, s + 1:]
349 | attn[:, 0, 1:s, 1:s] = chunk_attn[:, 0, :s - 1, w - (s - 1):]
350 | attn = attn.view(n, t, w)
351 |
352 | # mask invalid attention scores
353 | attn[:, :s, :s + 1].masked_fill_(self.l_mask, float('-inf'))
354 | attn[:, -s:, -(s + 1):].masked_fill_(self.r_mask, float('-inf'))
355 |
356 | return attn
357 |
358 | def _attn_normalize(self, attn, mask):
359 | """
360 | Normalize attention scores over valid positions.
361 |
362 | Args:
363 | attn (float tensor, (bs, h, t, w)): unnormalized attention scores.
364 | mask (bool tensor, (bs, t, 1)): mask (1 for valid positions).
365 |
366 | Returns:
367 | attn (float tensor, (bs, h, t, w)): normalized attention map.
368 | """
369 | bs, h, t, w = attn.size()
370 |
371 | # inverse mask (0 for valid positions, -inf for invalid ones)
372 | inv_mask = torch.logical_not(mask)
373 | inv_mask_float = inv_mask.float().masked_fill(inv_mask, -1e4)
374 |
375 | # additive attention mask: (bs, t, w)
376 | attn_mask = self._query_key_matmul(
377 | torch.ones_like(inv_mask_float), inv_mask_float
378 | )
379 | attn += attn_mask.view(bs, 1, t, w)
380 |
381 | # normalize
382 | attn = F.softmax(attn, dim=-1)
383 |
384 | # if all key / value positions in a local window are invalid
385 | # (i.e., when the query position is invalid), softmax returns NaN.
386 | # Replace NaNs with 0
387 | attn = attn.masked_fill(inv_mask.unsqueeze(1), 0.0)
388 |
389 | return attn
390 |
391 | def _attn_value_matmul(self, attn, v):
392 | """
393 | Chunk-wise attention-value product.
394 |
395 | Args:
396 | attn (float tensor, (n, t, w)): attention map.
397 | v (float tensor, (n, t, d)): value tensor.
398 |
399 | Returns:
400 | out (float tensor, (n, t, d)): attention-weighted sum of values.
401 | """
402 | n, t, d = v.size()
403 | w, s = self.window_size, self.stride
404 |
405 | # chunk attention map: (n, t, w) -> (n, t // s, s, w)
406 | attn_chunks = attn.view(n, t // s, s, w)
407 |
408 | # shift columns into diagonals: (n, t // s, s, 3s)
409 | attn_chunks = F.pad(attn_chunks, (0, s))
410 | attn_chunks = attn_chunks.view(n, t // s, -1)[..., :-s]
411 | attn_chunks = attn_chunks.view(n, t // s, s, 3 * s)
412 |
413 | # chunk value tensor: (n, t + 2s, d) -> (n, t // s, 3s, d)
414 | v = F.pad(v, (0, 0, s, s))
415 | v_chunks = self._chunk(v.contiguous(), size=3 * s)
416 |
417 | # chunk-wise attention-weighted sum: (n, t // s, s, d)
418 | out = torch.einsum('bcwd,bcdh->bcwh', (attn_chunks, v_chunks))
419 | out = out.view(n, t, d)
420 |
421 | return out
422 |
423 | def forward(self, q, k=None, v=None, kv_mask=None, kv_size=None,
424 | video_ids=None, layer_idx=None, curr_epoch=None, q_mask=None):
425 | """
426 | Args:
427 | q (float tensor, (bs, c, t1)): query feature sequence.
428 | k (float tensor, (bs, c, t2)): key feature sequence.
429 | v (float tensor, (bs, c, t2)): value feature sequence.
430 | kv_mask (bool tensor, (bs, 1, t2)): key / value mask.
431 | kv_size (int tensor, (bs,)): number of times to repeat each sample.
432 | """
433 | bs, c = q.size(0), self.embd_dim
434 | h, d, w = self.n_heads, self.n_channels, self.window_size
435 |
436 | if k is None:
437 | k = q
438 | if v is None:
439 | v = k
440 |
441 | # if mask is not given, assume all positions are valid
442 | if kv_mask is None:
443 | kv_mask = torch.ones_like(k[:, :1], dtype=torch.bool)
444 |
445 | q = self.query(q)
446 | k = self.key(k)
447 | v = self.value(v)
448 |
449 | # repeat query to match the size of key / value
450 | if kv_size is not None and k.size(0) != bs:
451 | q = q.repeat_interleave(kv_size, dim=0)
452 | bs = q.size(0)
453 |
454 | if self.window_size > 0:
455 | q = q.view(bs, h, d, -1).flatten(0, 1).transpose(1, 2)
456 | k = k.view(bs, h, d, -1).flatten(0, 1).transpose(1, 2)
457 | v = v.view(bs, h, d, -1).flatten(0, 1).transpose(1, 2)
458 |
459 | # attention scores: (bs * h, t, w)
460 | attn = self._query_key_matmul(q * self.scale, k * self.scale)
461 | attn = attn.view(bs, h, -1, w)
462 | if self.rel_pe is not None:
463 | attn += self.rel_pe
464 |
465 | # normalized attention map: (bs, h, t, w)
466 | attn = self._attn_normalize(attn, kv_mask.transpose(1, 2))
467 | attn = self.attn_drop(attn)
468 | attn = attn.view(bs * h, -1, w)
469 |
470 |             # attention-weighted sum of values: (bs * h, t, d)
471 | q = self._attn_value_matmul(attn, v)
472 | q = q.view(bs, h, -1, d)
473 | else:
474 | q = q.view(bs, h, d, -1).transpose(2, 3)
475 | k = k.view(bs, h, d, -1)
476 | v = v.view(bs, h, d, -1).transpose(2, 3)
477 |
478 | attn = (q * self.scale) @ (k * self.scale) # (bs, h, t1, t2)
479 | attn = attn.masked_fill(
480 | mask=torch.logical_not(kv_mask[:, :, None, :]),
481 | value=float('-inf'),
482 | )
483 | attn = F.softmax(attn, dim=-1)
484 |
485 | ret_attn = attn.clone()
486 |
487 | attn = self.attn_drop(attn)
488 | q = attn @ v # (bs, h, t1, d)
489 |
490 | q = q.transpose(2, 3).reshape(bs, c, -1) # (bs, c, t1)
491 | out = self.proj_drop(self.proj(q))
492 |
493 | return out, ret_attn * q_mask.unsqueeze(-1).expand(attn.shape).requires_grad_(False)
494 |
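# [Editor's note] Usage sketch for the global-attention path (window_size=0),
# with hypothetical shapes. Note that the attention map returned as the second
# output is only computed in this global branch, and q_mask must be provided
# since the returned map is masked by q_mask.
def _masked_mha_example():
    mha = MaskedMHA(embd_dim=64, n_heads=4)          # global self-attention
    x = torch.randn(2, 64, 10)                       # (B, C, T)
    mask = torch.ones(2, 1, 10, dtype=torch.bool)
    out, attn = mha(x, kv_mask=mask, q_mask=mask)
    # out: (2, 64, 10); attn: (B, heads, T, T) = (2, 4, 10, 10)
    return out, attn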
495 |
496 | class AttNPool1D(nn.Module):
497 | def __init__(self, embd_dim, n_heads=4):
498 | super(AttNPool1D, self).__init__()
499 |
500 | self.pool = MaskedAvgPool1D()
501 | self.attn = MaskedMHA(embd_dim, n_heads=n_heads)
502 |
503 | def forward(self, x, mask):
504 | x_mean = self.pool(x, mask)
505 | h = torch.cat((x_mean, x), dim=-1)
506 | mask = torch.cat((mask[..., :1], mask), dim=-1)
507 |
508 | pool = self.attn(h, kv_mask=mask)[..., :1]
509 | x = torch.cat((pool, x), dim=-1)
510 |
511 | return x, mask
512 |
513 |
514 |
515 | class ConvXAttNLayer(nn.Module):
516 | """
517 | Multi Head Conv Cross Attention with mask
518 |
519 |     With the current implementation, the downsampled features will be aligned
520 |     to every s+1 time steps, where s is the down-sampling stride. This allows
521 |     us to easily interpolate the corresponding position encodings.
522 | """
523 | def __init__(
524 | self,
525 | embd_dim, # embedding dimension
526 | kv_dim, # key / value dimension
527 | out_dim=None, # output dimension
528 | stride=1, # convolution stride
529 | n_heads=4, # number of attention heads
530 | attn_pdrop=0.0, # dropout rate for attention map
531 | proj_pdrop=0.0, # dropout rate for projection
532 | use_offset=False, # whether to add offsets to down-sampled points
533 | ):
534 | super(ConvXAttNLayer, self).__init__()
535 |
536 | self.use_conv = stride > 0
537 | if self.use_conv:
538 | assert stride == 1 or stride % 2 == 0
539 | kernel_size = stride + 1 + use_offset if stride > 1 else 3
540 | padding = (kernel_size - 1) // 2
541 |
542 | self.q_conv = MaskedConv1D(
543 | embd_dim, embd_dim,
544 | kernel_size=kernel_size,
545 | stride=stride, padding=padding,
546 | groups=embd_dim, bias=False,
547 | )
548 | self.q_norm = LayerNorm(embd_dim)
549 | else:
550 | self.q_conv = self.q_norm = None
551 |
552 | if out_dim is None:
553 | out_dim = embd_dim
554 |
555 | # cross-attention
556 | self.xattn = MaskedMHA(
557 | embd_dim,
558 | kv_dim=kv_dim, out_dim=out_dim,
559 | n_heads=n_heads,
560 | attn_pdrop=attn_pdrop, proj_pdrop=proj_pdrop
561 | )
562 |
563 | def forward(self, q, q_mask, kv, kv_mask, kv_size=None,
564 | video_ids=None, layer_idx=None, curr_epoch=None):
565 | if self.use_conv:
566 | q, q_mask = self.q_conv(q, q_mask)
567 | q = self.q_norm(q)
568 |
569 | out, cross_attn = self.xattn(q, kv, None, kv_mask, kv_size,
570 | video_ids=video_ids, layer_idx=layer_idx,
571 | curr_epoch=curr_epoch, q_mask = q_mask)
572 | if kv_size is not None and out.size(0) != q_mask.size(0):
573 | q_mask = q_mask.repeat_interleave(kv_size, dim=0)
574 |
575 | return out, q_mask, cross_attn
576 |
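# [Editor's note] Illustrative sketch (hypothetical shapes): with stride=1 the
# query sequence keeps its length, and the layer returns the cross-attention
# map from each query position over the key/value sequence.
def _conv_xattn_example():
    layer = ConvXAttNLayer(embd_dim=64, kv_dim=128, stride=1, n_heads=4)
    q = torch.randn(2, 64, 8)                        # queries   (B, C, Tq)
    q_mask = torch.ones(2, 1, 8, dtype=torch.bool)
    kv = torch.randn(2, 128, 32)                     # key/value (B, C_kv, Tk)
    kv_mask = torch.ones(2, 1, 32, dtype=torch.bool)
    out, out_mask, cross_attn = layer(q, q_mask, kv, kv_mask)
    # out: (2, 64, 8); cross_attn: (2, 4, 8, 32)
    return out, out_mask, cross_attn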
577 |
578 | class FFN(nn.Module):
579 | """
580 | Feed Forward Network (MLP) in Transformer.
581 | """
582 | def __init__(self, channels, expansion=4, pdrop=0.0):
583 | super(FFN, self).__init__()
584 |
585 | self.fc = nn.Conv1d(channels, channels * expansion, 1)
586 | self.actv = nn.GELU()
587 | self.proj = nn.Conv1d(channels * expansion, channels, 1)
588 | self.dropout = nn.Dropout(pdrop)
589 |
590 | def forward(self, x):
591 | x = self.dropout(self.actv(self.fc(x)))
592 | x = self.dropout(self.proj(x))
593 |
594 | return x
595 |
596 |
597 | class TransformerDecoder(nn.Module):
598 | """
599 | Transformer Decoder (w/o self-attention).
600 | (optional depth-wise conv -> xattn -> FFN)
601 | """
602 | def __init__(
603 | self,
604 | embd_dim, # embedding dimension
605 | kv_dim, # key / value dimension
606 | stride=1, # convolution stride (0 if disable convs)
607 | n_heads=4, # number of attention heads
608 | window_size=0, # MHA window size (0 for global attention)
609 | expansion=4, # expansion factor for FFN
610 | attn_pdrop=0.0, # dropout rate for attention map
611 | proj_pdrop=0.0, # dropout rate for projection
612 | path_pdrop=0.0, # dropout rate for residual paths
613 | xattn_mode='adaln', # cross-attention mode (affine | adaln)
614 | use_offset=False, # whether to add offsets to down-sampled points
615 | use_rel_pe=False, # whether to apply relative position encoding
616 | ):
617 | super(TransformerDecoder, self).__init__()
618 |
619 | # cross-attention
620 | assert xattn_mode in ('affine', 'adaln')
621 | self.xattn = ConvXAttNLayer(
622 | embd_dim, kv_dim,
623 | out_dim=embd_dim * 2,
624 | stride=stride, n_heads=n_heads,
625 | attn_pdrop=attn_pdrop, proj_pdrop=path_pdrop,
626 | use_offset=use_offset
627 | )
628 | self.ln_xattn_q = LayerNorm(embd_dim)
629 | self.ln_xattn_kv = LayerNorm(kv_dim)
630 |
631 | if xattn_mode == 'adaln':
632 | self.adaln = LayerNorm(embd_dim, affine=False)
633 | else:
634 | self.adaln = None
635 |
636 | # FFN
637 | self.ffn = FFN(embd_dim, expansion, proj_pdrop)
638 | self.ln_ffn = LayerNorm(embd_dim)
639 |
640 | # drop path
641 | if path_pdrop > 0.0:
642 | self.drop_path_ffn = AffineDropPath(embd_dim, drop_prob=path_pdrop)
643 | else:
644 | self.drop_path_ffn = nn.Identity()
645 |
646 | def forward(self, q, q_mask, kv, kv_mask, kv_size=None,
647 | video_ids=None, curr_epoch=None):
648 | if q_mask is None:
649 | q_mask = torch.ones_like(q[:, :1], dtype=torch.bool)
650 |
651 | q = q * q_mask.float()
652 |
653 |
654 | # cross-attention (optionally with depth-wise conv)
655 | out, q_mask, cross_attn = self.xattn(
656 | self.ln_xattn_q(q), q_mask, self.ln_xattn_kv(kv), kv_mask, kv_size,
657 | video_ids=video_ids, curr_epoch=curr_epoch
658 | )
659 | if kv_size is not None and q.size(0) != out.size(0):
660 | q = q.repeat_interleave(kv_size, dim=0)
661 |
662 | q = q * q_mask.float()
663 | weight, bias = out.split(q.size(1), dim=1)
664 |
665 | if self.adaln is not None:
666 | q = self.adaln(q)
667 |
668 | q = q * weight + bias
669 |
670 | # FFN
671 | out = self.ffn(self.ln_ffn(q)) * q_mask.float()
672 | q = q + self.drop_path_ffn(out)
673 |
674 | return q, q_mask, cross_attn
675 |
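# [Editor's note] Usage sketch with hypothetical shapes: a small set of query
# tokens attends to a longer sequence of video features. The cross-attention
# output is split into (weight, bias) and modulates the AdaLN-normalized
# queries, followed by the FFN. Passing q_mask=None treats all queries as valid.
def _transformer_decoder_example():
    dec = TransformerDecoder(embd_dim=64, kv_dim=128, n_heads=4)
    q = torch.randn(2, 64, 8)                        # query tokens   (B, C, Tq)
    kv = torch.randn(2, 128, 32)                     # video features (B, C_kv, Tk)
    kv_mask = torch.ones(2, 1, 32, dtype=torch.bool)
    out, out_mask, cross_attn = dec(q, None, kv, kv_mask)
    # out: (2, 64, 8); cross_attn: (2, 4, 8, 32)
    return out, out_mask, cross_attn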
676 |
677 | class Scale(nn.Module):
678 | """
679 | Multiply the output regression range by a learnable constant value
680 | """
681 |
682 | def __init__(self, init=1.0):
683 | """
684 | init_value : initial value for the scalar
685 | """
686 | super(Scale, self).__init__()
687 |
688 | self.scale = nn.Parameter(torch.as_tensor(init, dtype=torch.float))
689 |
690 | def forward(self, x):
691 | return x * self.scale
692 |
693 |
694 | def drop_path(x, drop_prob=0.0, training=False):
695 | """
696 | Stochastic Depth per sample.
697 | """
698 | if drop_prob == 0.0 or not training:
699 | return x
700 |
701 | keep_prob = 1 - drop_prob
702 | shape = (x.shape[0],) + (1,) * (x.ndim - 1)
703 | mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
704 | mask.floor_()
705 | x = x.div(keep_prob) * mask
706 |
707 | return x
708 |
709 |
710 | class DropPath(nn.Module):
711 | """
712 | Drop paths per sample (when applied in main path of residual blocks).
713 | """
714 | def __init__(self, drop_prob=None):
715 | super(DropPath, self).__init__()
716 |
717 | self.drop_prob = drop_prob
718 |
719 | def forward(self, x):
720 | return drop_path(x, self.drop_prob, self.training)
721 |
722 |
723 | class AffineDropPath(nn.Module):
724 | """
725 | Drop paths per sample (when applied in main path of residual blocks)
726 | with a per channel scaling factor (and zero init).
727 |
728 | https://arxiv.org/pdf/2103.17239.pdf
729 | """
730 | def __init__(self, dim, drop_prob=0.0, init_scale=1e-4):
731 | super(AffineDropPath, self).__init__()
732 |
733 | self.scale = nn.Parameter(init_scale * torch.ones((1, dim, 1)))
734 | self.drop_prob = drop_prob
735 |
736 | def forward(self, x):
737 | return drop_path(self.scale * x, self.drop_prob, self.training)
738 |
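# [Editor's note] Behavior sketch (hypothetical shapes): in training mode each
# sample's residual branch is dropped with probability drop_prob and rescaled
# by 1 / (1 - drop_prob); in eval mode only the learnable per-channel scale
# (initialized to init_scale=1e-4) is applied.
def _affine_drop_path_example():
    dp = AffineDropPath(dim=64, drop_prob=0.1)
    x = torch.randn(8, 64, 16)                       # (B, C, T)
    dp.train()
    y_train = dp(x)                                  # stochastic per-sample drop
    dp.eval()
    y_eval = dp(x)                                   # deterministic: scale * x
    return y_train, y_eval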
739 |
740 | class ConvBlock(nn.Module):
741 | """
742 | A simple conv block similar to the basic block used in ResNet
743 | """
744 | def __init__(
745 | self,
746 | n_embd, # dimension of the input features
747 | kernel_size=3, # conv kernel size
748 | n_ds_stride=1, # downsampling stride for the current layer
749 | expansion_factor=2, # expansion factor of feat dims
750 | n_out=None, # output dimension, if None, set to input dim
751 | act_layer=nn.ReLU, # nonlinear activation used after conv, default ReLU
752 | ):
753 | super().__init__()
754 | # must use odd sized kernel
755 | assert (kernel_size % 2 == 1) and (kernel_size > 1)
756 | padding = kernel_size // 2
757 | if n_out is None:
758 | n_out = n_embd
759 |
760 | # 1x3 (strided) -> 1x3 (basic block in resnet)
761 | width = n_embd * expansion_factor
762 | self.conv1 = MaskedConv1D(
763 | n_embd, width, kernel_size, n_ds_stride, padding=padding)
764 | self.conv2 = MaskedConv1D(
765 | width, n_out, kernel_size, 1, padding=padding)
766 |
767 | # attach downsampling conv op
768 | if n_ds_stride > 1:
769 | # 1x1 strided conv (same as resnet)
770 | self.downsample = MaskedConv1D(n_embd, n_out, 1, n_ds_stride)
771 | else:
772 | self.downsample = None
773 |
774 | self.act = act_layer()
775 |
776 | def forward(self, x, mask, pos_embd=None):
777 | identity = x
778 | out, out_mask = self.conv1(x, mask)
779 | out = self.act(out)
780 | out, out_mask = self.conv2(out, out_mask)
781 |
782 | # downsampling
783 | if self.downsample is not None:
784 | identity, _ = self.downsample(x, mask)
785 |
786 | # residual connection
787 | out += identity
788 | out = self.act(out)
789 |
790 | return out, out_mask
791 |
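# [Editor's note] Usage sketch (hypothetical shapes): with n_ds_stride=2 the
# block halves the temporal length, and the strided 1x1 shortcut keeps the
# residual connection shape-compatible, mirroring a ResNet basic block.
def _conv_block_example():
    block = ConvBlock(n_embd=64, kernel_size=3, n_ds_stride=2)
    x = torch.randn(2, 64, 16)                       # (B, C, T)
    mask = torch.ones(2, 1, 16, dtype=torch.bool)
    out, out_mask = block(x, mask)                   # (2, 64, 8), (2, 1, 8)
    return out, out_mask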
792 |
793 | class MaskedMHCA(nn.Module):
794 | """
795 | Multi Head Conv Attention with mask
796 |
797 | Add a depthwise convolution within a standard MHA
798 | The extra conv op can be used to
799 |     (1) encode relative position information (replacing position encoding);
800 | (2) downsample the features if needed;
801 | (3) match the feature channels
802 |
803 | Note: With current implementation, the downsampled feature will be aligned
804 | to every s+1 time step, where s is the downsampling stride. This allows us
805 | to easily interpolate the corresponding positional embeddings.
806 |
807 | Modified from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
808 | """
809 |
810 | def __init__(
811 | self,
812 | n_embd, # dimension of the output features
813 | n_head, # number of heads in multi-head self-attention
814 |         n_qx_stride=1, # downsampling stride for query and input
815 | n_kv_stride=1, # downsampling stride for key and value
816 | attn_pdrop=0.0, # dropout rate for the attention map
817 | proj_pdrop=0.0, # dropout rate for projection op
818 | ):
819 | super().__init__()
820 | assert n_embd % n_head == 0
821 | self.n_embd = n_embd
822 | self.n_head = n_head
823 | self.n_channels = n_embd // n_head
824 | self.scale = 1.0 / math.sqrt(self.n_channels)
825 |
826 | # conv/pooling operations
827 | assert (n_qx_stride == 1) or (n_qx_stride % 2 == 0)
828 | assert (n_kv_stride == 1) or (n_kv_stride % 2 == 0)
829 | self.n_qx_stride = n_qx_stride
830 | self.n_kv_stride = n_kv_stride
831 |
832 | # query conv (depthwise)
833 | kernel_size = self.n_qx_stride + 1 if self.n_qx_stride > 1 else 3
834 |         stride, padding = self.n_qx_stride, kernel_size // 2
835 | self.query_conv = MaskedConv1D(
836 | self.n_embd, self.n_embd, kernel_size,
837 | stride=stride, padding=padding, groups=self.n_embd, bias=False
838 | )
839 | self.query_norm = LayerNorm(self.n_embd)
840 |
841 | # key, value conv (depthwise)
842 | kernel_size = self.n_kv_stride + 1 if self.n_kv_stride > 1 else 3
843 | stride, padding = self.n_kv_stride, kernel_size // 2
844 | self.key_conv = MaskedConv1D(
845 | self.n_embd, self.n_embd, kernel_size,
846 | stride=stride, padding=padding, groups=self.n_embd, bias=False
847 | )
848 | self.key_norm = LayerNorm(self.n_embd)
849 | self.value_conv = MaskedConv1D(
850 | self.n_embd, self.n_embd, kernel_size,
851 | stride=stride, padding=padding, groups=self.n_embd, bias=False
852 | )
853 | self.value_norm = LayerNorm(self.n_embd)
854 |
855 | # key, query, value projections for all heads
856 |         # it is OK to ignore masking here, as the mask is applied to the attention map
857 | self.key = nn.Conv1d(self.n_embd, self.n_embd, 1)
858 | self.query = nn.Conv1d(self.n_embd, self.n_embd, 1)
859 | self.value = nn.Conv1d(self.n_embd, self.n_embd, 1)
860 |
861 | # regularization
862 | self.attn_drop = nn.Dropout(attn_pdrop)
863 | self.proj_drop = nn.Dropout(proj_pdrop)
864 |
865 | # output projection
866 | self.proj = nn.Conv1d(self.n_embd, self.n_embd, 1)
867 |
868 | def forward(self, x, mask):
869 | # x: batch size, feature channel, sequence length,
870 | # mask: batch size, 1, sequence length (bool)
871 | B, C, T = x.size()
872 |
873 | # query conv -> (B, nh * hs, T')
874 | q, qx_mask = self.query_conv(x, mask)
875 | q = self.query_norm(q)
876 | # key, value conv -> (B, nh * hs, T'')
877 | k, kv_mask = self.key_conv(x, mask)
878 | k = self.key_norm(k)
879 | v, _ = self.value_conv(x, mask)
880 | v = self.value_norm(v)
881 |
882 | # projections
883 | q = self.query(q)
884 | k = self.key(k)
885 | v = self.value(v)
886 |
887 | # move head forward to be the batch dim
888 | # (B, nh * hs, T'/T'') -> (B, nh, T'/T'', hs)
889 | k = k.view(B, self.n_head, self.n_channels, -1).transpose(2, 3)
890 | q = q.view(B, self.n_head, self.n_channels, -1).transpose(2, 3)
891 | v = v.view(B, self.n_head, self.n_channels, -1).transpose(2, 3)
892 |
893 | # self-attention: (B, nh, T', hs) x (B, nh, hs, T'') -> (B, nh, T', T'')
894 | att = (q * self.scale) @ k.transpose(-2, -1)
895 | # prevent q from attending to invalid tokens
896 | att = att.masked_fill(torch.logical_not(kv_mask[:, :, None, :]), float('-inf'))
897 | # softmax attn
898 | att = F.softmax(att, dim=-1)
899 | att = self.attn_drop(att)
900 | # (B, nh, T', T'') x (B, nh, T'', hs) -> (B, nh, T', hs)
901 | out = att @ (v * kv_mask[:, :, :, None].to(v.dtype))
902 | # re-assemble all head outputs side by side
903 | out = out.transpose(2, 3).contiguous().view(B, C, -1)
904 |
905 | # output projection + skip connection
906 | out = self.proj_drop(self.proj(out)) * qx_mask.to(out.dtype)
907 | return out, qx_mask
908 |
909 |
910 |
911 | class TransformerBlock(nn.Module):
912 | """
913 |     A simple (pre layer norm) Transformer block
914 | Modified from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
915 | """
916 | def __init__(
917 | self,
918 | n_embd, # dimension of the input features
919 | n_head, # number of attention heads
920 | n_ds_strides=(1, 1), # downsampling strides for q & x, k & v
921 | n_out=None, # output dimension, if None, set to input dim
922 | n_hidden=None, # dimension of the hidden layer in MLP
923 | act_layer=nn.GELU, # nonlinear activation used in MLP, default GELU
924 | attn_pdrop=0.0, # dropout rate for the attention map
925 | proj_pdrop=0.0, # dropout rate for the projection / MLP
926 | path_pdrop=0.0, # drop path rate
927 | mha_win_size=-1, # > 0 to use window mha
928 | use_rel_pe=False # if to add rel position encoding to attention
929 | ):
930 | super().__init__()
931 | assert len(n_ds_strides) == 2
932 | # layer norm for order (B C T)
933 | self.ln1 = LayerNorm(n_embd)
934 | self.ln2 = LayerNorm(n_embd)
935 |
936 | # specify the attention module
937 | if mha_win_size > 1:
938 | self.attn = LocalMaskedMHCA(
939 | n_embd,
940 | n_head,
941 | window_size=mha_win_size,
942 | n_qx_stride=n_ds_strides[0],
943 | n_kv_stride=n_ds_strides[1],
944 | attn_pdrop=attn_pdrop,
945 | proj_pdrop=proj_pdrop,
946 | use_rel_pe=use_rel_pe # only valid for local attention
947 | )
948 | else:
949 | self.attn = MaskedMHCA(
950 | n_embd,
951 | n_head,
952 | n_qx_stride=n_ds_strides[0],
953 | n_kv_stride=n_ds_strides[1],
954 | attn_pdrop=attn_pdrop,
955 | proj_pdrop=proj_pdrop
956 | )
957 |
958 | # input
959 | if n_ds_strides[0] > 1:
960 | kernel_size, stride, padding = \
961 | n_ds_strides[0] + 1, n_ds_strides[0], (n_ds_strides[0] + 1)//2
962 | self.pool_skip = nn.MaxPool1d(
963 | kernel_size, stride=stride, padding=padding)
964 | else:
965 | self.pool_skip = nn.Identity()
966 |
967 | # two layer mlp
968 | if n_hidden is None:
969 | n_hidden = 4 * n_embd # default
970 | if n_out is None:
971 | n_out = n_embd
972 | # ok to use conv1d here with stride=1
973 | self.mlp = nn.Sequential(
974 | nn.Conv1d(n_embd, n_hidden, 1),
975 | act_layer(),
976 | nn.Dropout(proj_pdrop, inplace=True),
977 | nn.Conv1d(n_hidden, n_out, 1),
978 | nn.Dropout(proj_pdrop, inplace=True),
979 | )
980 |
981 | # drop path
982 | if path_pdrop > 0.0:
983 | self.drop_path_attn = AffineDropPath(n_embd, drop_prob = path_pdrop)
984 | self.drop_path_mlp = AffineDropPath(n_out, drop_prob = path_pdrop)
985 | else:
986 | self.drop_path_attn = nn.Identity()
987 | self.drop_path_mlp = nn.Identity()
988 |
989 | def forward(self, x, mask, pos_embd=None):
990 | # pre-LN transformer: https://arxiv.org/pdf/2002.04745.pdf
991 | out, out_mask = self.attn(self.ln1(x), mask)
992 | out_mask_float = out_mask.to(out.dtype)
993 | out = self.pool_skip(x) * out_mask_float + self.drop_path_attn(out)
994 | # FFN
995 | out = out + self.drop_path_mlp(self.mlp(self.ln2(out)) * out_mask_float)
996 | # optionally add pos_embd to the output
997 | if pos_embd is not None:
998 | out += pos_embd * out_mask_float
999 | return out, out_mask
1000 |
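# [Editor's note] Usage sketch (hypothetical shapes): with the default
# n_ds_strides=(1, 1) and mha_win_size=-1, the block runs global masked
# self-attention followed by the MLP, and the sequence length is preserved.
def _transformer_block_example():
    blk = TransformerBlock(n_embd=64, n_head=4)
    x = torch.randn(2, 64, 16)                       # (B, C, T)
    mask = torch.ones(2, 1, 16, dtype=torch.bool)
    out, out_mask = blk(x, mask)                     # (2, 64, 16), mask unchanged
    return out, out_mask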
--------------------------------------------------------------------------------