├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── args.py
├── data
│   ├── augs.py
│   ├── data.py
│   ├── data_kubric.py
│   ├── data_plugin.py
│   ├── data_utils.py
│   └── data_vis.py
├── datasets
│   ├── rubric_all_videos.txt
│   ├── rubric_cupgames_videos.txt
│   ├── rubric_davytb_videos.txt
│   └── rubric_office_videos.txt
├── demo
│   ├── teaduck2.mp4
│   ├── teaduck2_135_occl.png
│   ├── teaduck2_15_query.png
│   ├── teaduck2_195_snitch.png
│   └── teaduck2_75_occl.png
├── eval
│   ├── inference.py
│   ├── metrics.py
│   ├── pick_represent.py
│   └── test.py
├── gen_kubric
│   ├── export_kub_cont.py
│   ├── export_kub_rand.py
│   ├── kubric_constants.py
│   ├── kubric_sim.py
│   └── scene_type_utils.py
├── loss.py
├── model
│   ├── mask_tracker.py
│   ├── resnet.py
│   ├── seeker.py
│   └── vision_tf.py
├── pipeline.py
├── rep_lists
│   ├── kubric_containers.txt
│   ├── kubric_random.txt
│   ├── rubric_cupgames.txt
│   ├── rubric_davytb.txt
│   └── rubric_office.txt
├── requirements.txt
├── third_party
│   └── TimeSformer
│       ├── .gitignore
│       ├── CODE_OF_CONDUCT.md
│       ├── CONTRIBUTING.md
│       ├── LICENSE
│       ├── README.md
│       ├── environment.yml
│       ├── example.ipynb
│       ├── setup.cfg
│       ├── setup.py
│       └── timesformer
│           ├── __init__.py
│           ├── config
│           │   ├── __init__.py
│           │   └── defaults.py
│           ├── datasets
│           │   ├── DATASET.md
│           │   ├── __init__.py
│           │   ├── build.py
│           │   ├── cv2_transform.py
│           │   ├── decoder.py
│           │   ├── kinetics.py
│           │   ├── loader.py
│           │   ├── multigrid_helper.py
│           │   ├── ssv2.py
│           │   ├── transform.py
│           │   ├── utils.py
│           │   └── video_container.py
│           ├── models
│           │   ├── __init__.py
│           │   ├── batchnorm_helper.py
│           │   ├── build.py
│           │   ├── conv2d_same.py
│           │   ├── custom_video_model_builder.py
│           │   ├── features.py
│           │   ├── head_helper.py
│           │   ├── helpers.py
│           │   ├── linear.py
│           │   ├── losses.py
│           │   ├── nonlocal_helper.py
│           │   ├── operators.py
│           │   ├── optimizer.py
│           │   ├── resnet_helper.py
│           │   ├── stem_helper.py
│           │   ├── video_model_builder.py
│           │   ├── vit.py
│           │   └── vit_utils.py
│           ├── utils
│           │   ├── __init__.py
│           │   ├── ava_eval_helper.py
│           │   ├── ava_evaluation
│           │   │   ├── README.md
│           │   │   ├── __init__.py
│           │   │   ├── ava_action_list_v2.1_for_activitynet_2018.pbtxt.txt
│           │   │   ├── label_map_util.py
│           │   │   ├── metrics.py
│           │   │   ├── np_box_list.py
│           │   │   ├── np_box_list_ops.py
│           │   │   ├── np_box_mask_list.py
│           │   │   ├── np_box_mask_list_ops.py
│           │   │   ├── np_box_ops.py
│           │   │   ├── np_mask_ops.py
│           │   │   ├── object_detection_evaluation.py
│           │   │   ├── per_image_evaluation.py
│           │   │   └── standard_fields.py
│           │   ├── benchmark.py
│           │   ├── bn_helper.py
│           │   ├── c2_model_loading.py
│           │   ├── checkpoint.py
│           │   ├── distributed.py
│           │   ├── env.py
│           │   ├── logging.py
│           │   ├── lr_policy.py
│           │   ├── meters.py
│           │   ├── metrics.py
│           │   ├── misc.py
│           │   ├── multigrid.py
│           │   ├── multiprocessing.py
│           │   ├── parser.py
│           │   └── weight_init_helper.py
│           └── visualization
│               ├── __init__.py
│               ├── tensorboard_vis.py
│               └── utils.py
├── train.py
└── utils
    ├── geometry.py
    ├── logvis.py
    ├── logvisgen.py
    ├── my_utils.py
    └── visualization.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Custom. 
2 | checkpoints/ 3 | checkpoints_*/ 4 | logs/ 5 | logs_*/ 6 | wandb/ 7 | *.egg-info 8 | docker_python/ 9 | kubric_output*/ 10 | tk_output*/ 11 | experimental/unit_test/ 12 | unit_test/ 13 | *.lnk 14 | pretrained/*.pyth 15 | plugin-*-ready/ 16 | plugin-ready/ 17 | tmp*.mp4 18 | *.pth 19 | pruned_*/ 20 | .cache/ 21 | cache/ 22 | datasets/*/ 23 | datasets/kubric* 24 | datasets/rubric* 25 | !datasets/*.txt 26 | 27 | # Byte-compiled / optimized / DLL files 28 | __pycache__/ 29 | *.py[cod] 30 | *$py.class 31 | 32 | # C extensions 33 | *.so 34 | 35 | # Distribution / packaging 36 | .Python 37 | build/ 38 | develop-eggs/ 39 | dist/ 40 | downloads/ 41 | eggs/ 42 | .eggs/ 43 | lib/ 44 | lib64/ 45 | parts/ 46 | sdist/ 47 | var/ 48 | wheels/ 49 | pip-wheel-metadata/ 50 | share/python-wheels/ 51 | *.egg-info/ 52 | .installed.cfg 53 | *.egg 54 | MANIFEST 55 | 56 | # PyInstaller 57 | # Usually these files are written by a python script from a template 58 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 59 | *.manifest 60 | *.spec 61 | 62 | # Installer logs 63 | pip-log.txt 64 | pip-delete-this-directory.txt 65 | 66 | # Unit test / coverage reports 67 | htmlcov/ 68 | .tox/ 69 | .nox/ 70 | .coverage 71 | .coverage.* 72 | .cache 73 | nosetests.xml 74 | coverage.xml 75 | *.cover 76 | *.py,cover 77 | .hypothesis/ 78 | .pytest_cache/ 79 | 80 | # Translations 81 | *.mo 82 | *.pot 83 | 84 | # Django stuff: 85 | *.log 86 | local_settings.py 87 | db.sqlite3 88 | db.sqlite3-journal 89 | 90 | # Flask stuff: 91 | instance/ 92 | .webassets-cache 93 | 94 | # Scrapy stuff: 95 | .scrapy 96 | 97 | # Sphinx documentation 98 | docs/_build/ 99 | 100 | # PyBuilder 101 | target/ 102 | 103 | # Jupyter Notebook 104 | .ipynb_checkpoints 105 | 106 | # IPython 107 | profile_default/ 108 | ipython_config.py 109 | 110 | # pyenv 111 | .python-version 112 | 113 | # pipenv 114 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 115 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 116 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 117 | # install all needed dependencies. 118 | #Pipfile.lock 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Basile Van Hoorick 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | These imports are shared across all files. 3 | Created by Basile Van Hoorick for TCOW. 4 | ''' 5 | 6 | # Library imports. 
7 | import argparse 8 | import collections 9 | import collections.abc 10 | import copy 11 | import cv2 12 | import imageio 13 | import itertools 14 | import joblib 15 | import json 16 | import lovely_numpy 17 | import lovely_tensors 18 | import matplotlib.colors 19 | import matplotlib.pyplot as plt 20 | import multiprocessing as mp 21 | import numpy as np 22 | import os 23 | import pandas as pd 24 | import pathlib 25 | import pickle 26 | import platform 27 | import random 28 | import rich 29 | import rich.console 30 | import rich.logging 31 | import rich.progress 32 | import scipy 33 | import seaborn as sns 34 | import shutil 35 | import sklearn 36 | import sklearn.decomposition 37 | import sys 38 | import time 39 | import torch 40 | import torch.nn 41 | import torch.nn.functional 42 | import torch.optim 43 | import torch.utils 44 | import torch.utils.data 45 | import torchvision 46 | import torchvision.datasets 47 | import torchvision.io 48 | import torchvision.models 49 | import torchvision.transforms 50 | import torchvision.utils 51 | import tqdm 52 | import tqdm.rich 53 | import warnings 54 | from collections import defaultdict 55 | from einops import rearrange, repeat 56 | from lovely_numpy import lo 57 | from rich import print 58 | 59 | PROJECT_NAME = 'tcow' 60 | 61 | sys.path.append(os.getcwd()) 62 | sys.path.append(os.path.join(os.getcwd(), 'data/')) 63 | sys.path.append(os.path.join(os.getcwd(), 'eval/')) 64 | sys.path.append(os.path.join(os.getcwd(), 'model/')) 65 | sys.path.append(os.path.join(os.getcwd(), 'third_party/')) 66 | sys.path.append(os.path.join(os.getcwd(), 'utils/')) 67 | 68 | lovely_tensors.monkey_patch() 69 | 70 | 71 | # Quick functions for usage during debugging: 72 | 73 | def mmm(x): 74 | return (x.min(), x.mean(), x.max()) 75 | 76 | 77 | def st(x): 78 | return (x.dtype, x.shape) 79 | 80 | 81 | def stmmm(x): 82 | return (*st(x), *mmm(x)) 83 | -------------------------------------------------------------------------------- /data/data_vis.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Dataset and annotation visualization methods, usually for temporary debugging. 3 | Created by Basile Van Hoorick for TCOW. 4 | ''' 5 | 6 | from __init__ import * 7 | 8 | # Internal imports. 9 | import my_utils 10 | import visualization 11 | 12 | 13 | def depth_to_rgb_vis(depth, max_depth=None): 14 | ''' 15 | :depth (*, 1) array of float32. 16 | :return rgb_vis (*, 3) array of uint8. 17 | ''' 18 | min_depth = 0.0 19 | if max_depth is None: 20 | max_depth = max(np.max(depth), 1e-6) 21 | 22 | depth = depth.copy().squeeze(-1) 23 | depth = np.clip(depth, 0.0, max_depth) 24 | depth = (depth - min_depth) / (max_depth - min_depth) 25 | 26 | rgb_vis = plt.cm.viridis(2.0 / (depth + 1.0) - 1.0)[..., :3] 27 | rgb_vis = (rgb_vis * 255.0).astype(np.uint8) 28 | 29 | return rgb_vis 30 | 31 | 32 | def segm_rgb_to_ids_kubric(segm_rgb): # , num_inst=None): 33 | ''' 34 | :param segm_rgb (*, 3) array of RGB values. 35 | :return segm_ids (*, 1) array of 1-based instance IDs (0 = background). 36 | ''' 37 | # We assume that hues are distributed across the range [0, 1] for instances in the image, ranked 38 | # by their integer ID. Check kubric plotting.hls_palette() for more details. 
39 | hsv = matplotlib.colors.rgb_to_hsv(segm_rgb) 40 | to_rank = hsv[..., 0] # + hsv[..., 2] * 1e-5 41 | unique_hues = np.sort(np.unique(to_rank)) 42 | hue_start = 0.01 43 | assert np.isclose(unique_hues[0], 0.0, rtol=1e-3, atol=1e-3), str(unique_hues) 44 | 45 | # commented this because object ID 0 may not be always visible in every frame: 46 | # assert np.isclose(unique_hues[1], hue_start, rtol=1e-3, atol=1e-3), str(unique_hues) 47 | 48 | # The smallest jump inbetween subsequent hues determines the highest instance ID that is VALO, 49 | # which is <= the total number of instances. Skip the very first hue, which is always 0 and 50 | # corresponds to background. 51 | hue_steps = np.array([unique_hues[i] - unique_hues[i - 1] for i in range(2, len(unique_hues))]) 52 | 53 | # For this sanity check to work, we must never have more than ~95 instances per scene. 54 | assert np.all(hue_steps >= 1e-2), str(hue_steps) 55 | 56 | # IMPORTANT NOTE: The current VALO set may be a strict SUBSET of the original VALO set (recorded 57 | # in the metadata), because we already applied frame subsampling here! In practice, this 58 | # sometimes causes big (i.e. integer multiple) jumps in hue_steps. 59 | # NEW: Ignore outliers the smart way. 60 | adjacent_steps = hue_steps[hue_steps <= np.min(hue_steps) * 1.5] 61 | hue_step = np.mean(adjacent_steps) 62 | 63 | # The jump from background to first instance is a special case, so ensure even distribution. 64 | nice_rank = to_rank.copy() 65 | nice_rank[nice_rank >= hue_start] += hue_step - hue_start 66 | ids_approx = (nice_rank / hue_step) 67 | 68 | segm_ids = np.round(ids_approx)[..., None].astype(np.int32) # (T, H, W, 1). 69 | return segm_ids 70 | 71 | 72 | def segm_ids_to_rgb(segm_ids, num_inst=None): 73 | ''' 74 | NOTE: This is NOT consistent with segm_rgb_to_ids_kubric(), because background (0) gets mapped 75 | to red! 76 | :segm_ids (*, 1) array of uint32. 77 | :return segm_rgb (*, 3) array of uint8. 78 | ''' 79 | if num_inst is None: 80 | num_inst = np.max(segm_ids) + 1 81 | num_inst = max(num_inst, 1) 82 | 83 | segm_ids = segm_ids.copy().squeeze(-1) 84 | segm_ids = segm_ids / num_inst 85 | 86 | segm_rgb = plt.cm.hsv(segm_ids)[..., :3] 87 | segm_rgb = (segm_rgb * 255.0).astype(np.uint8) 88 | 89 | return segm_rgb 90 | -------------------------------------------------------------------------------- /datasets/rubric_all_videos.txt: -------------------------------------------------------------------------------- 1 | # Relative paths to test set examples for TCOW Rubric benchmark. 2 | # Entries here can either be a single video, or a directory containing images. 
3 | # (See data.py to see how this is processed) 4 | 5 | rubric/2_teaduck/teaduck1.mp4 6 | rubric/2_teaduck/teaduck2.mp4 7 | rubric/2_teaduck/teaduck3_reveal.mp4 8 | rubric/2_teaduck/teaduck4_reveal.mp4 9 | rubric/2_teaduck/teaduck5_cammove.mp4 10 | rubric/2_teaduck/teaduck6_teamove.mp4 11 | 12 | rubric/3_mugduck/mugduck1_mugmove.mp4 13 | rubric/3_mugduck/mugduck2_reveal.mp4 14 | rubric/3_mugduck/mugduck3_reveal.mp4 15 | rubric/3_mugduck/mugduck4_mugmove.mp4 16 | rubric/3_mugduck/multicupduck1_game.mp4 17 | 18 | rubric/4_home/chips1_oof.mp4 19 | rubric/4_home/pump1_scan.mp4 20 | rubric/4_home/pumpcookie1_reveal.mp4 21 | 22 | rubric/5_bagmugduck/bagduck1_move.mp4 23 | rubric/5_bagmugduck/bagduck3_recurse.mp4 24 | rubric/5_bagmugduck/bagduck4_transfer.mp4 25 | rubric/5_bagmugduck/mugduck5_stay.mp4 26 | rubric/5_bagmugduck/mugduck6_shuffle.mp4 27 | rubric/5_bagmugduck/mugduck7_shuffle.mp4 28 | 29 | rubric/6_handball/handball1_wave.mp4 30 | rubric/6_handball/handball2_boxoccl.mp4 31 | rubric/6_handball/handball3_occlpick.mp4 32 | 33 | rubric/7_ballbounce/lightball4_occl1x.mp4 34 | rubric/7_ballbounce/lightball5_occl2x.mp4 35 | rubric/7_ballbounce/lightball6_occl4x.mp4 36 | 37 | rubric/8_plantcupball/plantcupball1.mp4 38 | rubric/8_plantcupball/plantcupball2.mp4 39 | rubric/8_plantcupball/plantcupball3.mp4 40 | 41 | rubric/9_bowlbox/manycont1_nested.mp4 42 | rubric/9_bowlbox/manycont2_transfer.mp4 43 | rubric/9_bowlbox/manycont3_nested.mp4 44 | rubric/9_bowlbox/manycont4_transfer.mp4 45 | rubric/9_bowlbox/manycont5_transfer.mp4 46 | 47 | rubric/10_confuse/confuse1.mp4 48 | rubric/10_confuse/confuse2.mp4 49 | rubric/10_confuse/confuse3.mp4 50 | 51 | rubric/dmpt_cgt/cgt_frames_0002.mp4 52 | rubric/dmpt_cgt/cgt_frames_0011.mp4 53 | rubric/dmpt_cgt/cgt_frames_0026.mp4 54 | rubric/dmpt_cgt/cgt_frames_0061.mp4 55 | rubric/dmpt_cgt/cgt_frames_0065.mp4 56 | rubric/dmpt_cgt/cgt_frames_0076.mp4 57 | rubric/dmpt_cgt/cgt_frames_0092.mp4 58 | rubric/dmpt_cgt/cgt_frames_0113.mp4 59 | rubric/dmpt_cgt/cgt_frames_0126.mp4 60 | rubric/dmpt_cgt/cgt_frames_0136.mp4 61 | rubric/dmpt_cgt/cgt_frames_0137.mp4 62 | rubric/dmpt_cgt/cgt_frames_0154.mp4 63 | rubric/dmpt_cgt/cgt_frames_0172.mp4 64 | rubric/dmpt_cgt/cgt_frames_0211.mp4 65 | 66 | rubric/davis_2017/test_giant-slalom 67 | rubric/davis_2017/test_people-sunset 68 | rubric/davis_2017/test_salsa1 69 | rubric/davis_2017/test_salsa2 70 | rubric/davis_2017/test_salsa3 71 | rubric/davis_2017/test_subway1 72 | rubric/davis_2017/test_subway2 73 | rubric/davis_2017/test_subway3 74 | rubric/davis_2017/train_bmx-bumps 75 | rubric/davis_2017/train_dancing 76 | rubric/davis_2017/train_lindy-hop1 77 | rubric/davis_2017/train_lindy-hop2 78 | rubric/davis_2017/train_lindy-hop3 79 | rubric/davis_2017/train_rhino 80 | rubric/davis_2017/train_scooter-board 81 | rubric/davis_2017/train_soccerball 82 | rubric/davis_2017/val_bmx-trees 83 | rubric/davis_2017/val_india1 84 | rubric/davis_2017/val_india2 85 | rubric/davis_2017/val_india3 86 | rubric/davis_2017/val_libby 87 | rubric/davis_2017/val_pigs 88 | 89 | rubric/ytvos_2019/val_0a49f5265b 90 | rubric/ytvos_2019/val_0b97736357 91 | rubric/ytvos_2019/val_0c04834d61 92 | rubric/ytvos_2019/val_0e4068b53f 93 | rubric/ytvos_2019/val_1b85035216 94 | rubric/ytvos_2019/val_1e0257109e 95 | rubric/ytvos_2019/val_1e6efb0b5f 96 | rubric/ytvos_2019/val_3b72dc1941_rev 97 | rubric/ytvos_2019/val_3f4bacb16a_1 98 | rubric/ytvos_2019/val_3f2012d518 99 | rubric/ytvos_2019/val_4bef684040 100 | rubric/ytvos_2019/val_5c3d2d3155 101 | 
rubric/ytvos_2019/val_5d2020eff8 102 | rubric/ytvos_2019/val_7e625db8c4 103 | rubric/ytvos_2019/val_24e2b52a4d_fish1 104 | rubric/ytvos_2019/val_24e2b52a4d_fish2 105 | rubric/ytvos_2019/val_33c8dcbe09 106 | rubric/ytvos_2019/val_42d810ba9d 107 | rubric/ytvos_2019/val_91f5ad52e9 108 | -------------------------------------------------------------------------------- /datasets/rubric_cupgames_videos.txt: -------------------------------------------------------------------------------- 1 | # Relative paths to test set examples for TCOW Rubric benchmark. 2 | # Entries here can either be a single video, or a directory containing images. 3 | # (See data.py to see how this is processed) 4 | 5 | rubric/dmpt_cgt/cgt_frames_0002.mp4 6 | rubric/dmpt_cgt/cgt_frames_0011.mp4 7 | rubric/dmpt_cgt/cgt_frames_0026.mp4 8 | rubric/dmpt_cgt/cgt_frames_0061.mp4 9 | rubric/dmpt_cgt/cgt_frames_0065.mp4 10 | rubric/dmpt_cgt/cgt_frames_0076.mp4 11 | rubric/dmpt_cgt/cgt_frames_0092.mp4 12 | rubric/dmpt_cgt/cgt_frames_0113.mp4 13 | rubric/dmpt_cgt/cgt_frames_0126.mp4 14 | rubric/dmpt_cgt/cgt_frames_0136.mp4 15 | rubric/dmpt_cgt/cgt_frames_0137.mp4 16 | rubric/dmpt_cgt/cgt_frames_0154.mp4 17 | rubric/dmpt_cgt/cgt_frames_0172.mp4 18 | rubric/dmpt_cgt/cgt_frames_0211.mp4 19 | -------------------------------------------------------------------------------- /datasets/rubric_davytb_videos.txt: -------------------------------------------------------------------------------- 1 | # Relative paths to test set examples for TCOW Rubric benchmark. 2 | # Entries here can either be a single video, or a directory containing images. 3 | # (See data.py to see how this is processed) 4 | 5 | rubric/davis_2017/test_giant-slalom 6 | rubric/davis_2017/test_people-sunset 7 | rubric/davis_2017/test_salsa1 8 | rubric/davis_2017/test_salsa2 9 | rubric/davis_2017/test_salsa3 10 | rubric/davis_2017/test_subway1 11 | rubric/davis_2017/test_subway2 12 | rubric/davis_2017/test_subway3 13 | rubric/davis_2017/train_bmx-bumps 14 | rubric/davis_2017/train_dancing 15 | rubric/davis_2017/train_lindy-hop1 16 | rubric/davis_2017/train_lindy-hop2 17 | rubric/davis_2017/train_lindy-hop3 18 | rubric/davis_2017/train_rhino 19 | rubric/davis_2017/train_scooter-board 20 | rubric/davis_2017/train_soccerball 21 | rubric/davis_2017/val_bmx-trees 22 | rubric/davis_2017/val_india1 23 | rubric/davis_2017/val_india2 24 | rubric/davis_2017/val_india3 25 | rubric/davis_2017/val_libby 26 | rubric/davis_2017/val_pigs 27 | 28 | rubric/ytvos_2019/val_0a49f5265b 29 | rubric/ytvos_2019/val_0b97736357 30 | rubric/ytvos_2019/val_0c04834d61 31 | rubric/ytvos_2019/val_0e4068b53f 32 | rubric/ytvos_2019/val_1b85035216 33 | rubric/ytvos_2019/val_1e0257109e 34 | rubric/ytvos_2019/val_1e6efb0b5f 35 | rubric/ytvos_2019/val_3b72dc1941_rev 36 | rubric/ytvos_2019/val_3f4bacb16a_1 37 | rubric/ytvos_2019/val_3f2012d518 38 | rubric/ytvos_2019/val_4bef684040 39 | rubric/ytvos_2019/val_5c3d2d3155 40 | rubric/ytvos_2019/val_5d2020eff8 41 | rubric/ytvos_2019/val_7e625db8c4 42 | rubric/ytvos_2019/val_24e2b52a4d_fish1 43 | rubric/ytvos_2019/val_24e2b52a4d_fish2 44 | rubric/ytvos_2019/val_33c8dcbe09 45 | rubric/ytvos_2019/val_42d810ba9d 46 | rubric/ytvos_2019/val_91f5ad52e9 47 | -------------------------------------------------------------------------------- /datasets/rubric_office_videos.txt: -------------------------------------------------------------------------------- 1 | # Relative paths to test set examples for TCOW Rubric benchmark. 
2 | # Entries here can either be a single video, or a directory containing images. 3 | # (See data.py to see how this is processed) 4 | 5 | rubric/2_teaduck/teaduck1.mp4 6 | rubric/2_teaduck/teaduck2.mp4 7 | rubric/2_teaduck/teaduck3_reveal.mp4 8 | rubric/2_teaduck/teaduck4_reveal.mp4 9 | rubric/2_teaduck/teaduck5_cammove.mp4 10 | rubric/2_teaduck/teaduck6_teamove.mp4 11 | 12 | rubric/3_mugduck/mugduck1_mugmove.mp4 13 | rubric/3_mugduck/mugduck2_reveal.mp4 14 | rubric/3_mugduck/mugduck3_reveal.mp4 15 | rubric/3_mugduck/mugduck4_mugmove.mp4 16 | rubric/3_mugduck/multicupduck1_game.mp4 17 | 18 | rubric/4_home/chips1_oof.mp4 19 | rubric/4_home/pump1_scan.mp4 20 | rubric/4_home/pumpcookie1_reveal.mp4 21 | 22 | rubric/5_bagmugduck/bagduck1_move.mp4 23 | rubric/5_bagmugduck/bagduck3_recurse.mp4 24 | rubric/5_bagmugduck/bagduck4_transfer.mp4 25 | rubric/5_bagmugduck/mugduck5_stay.mp4 26 | rubric/5_bagmugduck/mugduck6_shuffle.mp4 27 | rubric/5_bagmugduck/mugduck7_shuffle.mp4 28 | 29 | rubric/6_handball/handball1_wave.mp4 30 | rubric/6_handball/handball2_boxoccl.mp4 31 | rubric/6_handball/handball3_occlpick.mp4 32 | 33 | rubric/7_ballbounce/lightball4_occl1x.mp4 34 | rubric/7_ballbounce/lightball5_occl2x.mp4 35 | rubric/7_ballbounce/lightball6_occl4x.mp4 36 | 37 | rubric/8_plantcupball/plantcupball1.mp4 38 | rubric/8_plantcupball/plantcupball2.mp4 39 | rubric/8_plantcupball/plantcupball3.mp4 40 | 41 | rubric/9_bowlbox/manycont1_nested.mp4 42 | rubric/9_bowlbox/manycont2_transfer.mp4 43 | rubric/9_bowlbox/manycont3_nested.mp4 44 | rubric/9_bowlbox/manycont4_transfer.mp4 45 | rubric/9_bowlbox/manycont5_transfer.mp4 46 | 47 | rubric/10_confuse/confuse1.mp4 48 | rubric/10_confuse/confuse2.mp4 49 | rubric/10_confuse/confuse3.mp4 50 | -------------------------------------------------------------------------------- /demo/teaduck2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basilevh/tcow/a72e3e13a45e4156137328e5290f9e848d360367/demo/teaduck2.mp4 -------------------------------------------------------------------------------- /demo/teaduck2_135_occl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basilevh/tcow/a72e3e13a45e4156137328e5290f9e848d360367/demo/teaduck2_135_occl.png -------------------------------------------------------------------------------- /demo/teaduck2_15_query.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basilevh/tcow/a72e3e13a45e4156137328e5290f9e848d360367/demo/teaduck2_15_query.png -------------------------------------------------------------------------------- /demo/teaduck2_195_snitch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basilevh/tcow/a72e3e13a45e4156137328e5290f9e848d360367/demo/teaduck2_195_snitch.png -------------------------------------------------------------------------------- /demo/teaduck2_75_occl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basilevh/tcow/a72e3e13a45e4156137328e5290f9e848d360367/demo/teaduck2_75_occl.png -------------------------------------------------------------------------------- /eval/inference.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Evaluation tools. 3 | Created by Basile Van Hoorick for TCOW. 
4 | ''' 5 | 6 | import os 7 | import sys 8 | sys.path.insert(0, os.path.join(os.getcwd(), 'eval/')) 9 | sys.path.insert(0, os.getcwd()) 10 | 11 | from __init__ import * 12 | 13 | # Internal imports. 14 | import my_utils 15 | import seeker 16 | import pipeline 17 | 18 | 19 | def load_networks(checkpoint_path, device, logger, epoch=-1): 20 | ''' 21 | :param checkpoint_path (str): Path to model checkpoint folder or file. 22 | :param epoch (int): If >= 0, desired checkpoint epoch to load. 23 | :return (networks, train_args, dset_args, model_args, epoch). 24 | networks (dict). 25 | train_args (dict). 26 | train_dset_args (dict). 27 | model_args (dict). 28 | epoch (int). 29 | ''' 30 | print_fn = logger.info if logger is not None else print 31 | 32 | assert os.path.exists(checkpoint_path) 33 | if os.path.isdir(checkpoint_path): 34 | model_fn = f'model_{epoch}.pth' if epoch >= 0 else 'checkpoint.pth' 35 | checkpoint_path = os.path.join(checkpoint_path, model_fn) 36 | 37 | print_fn('Loading weights from: ' + checkpoint_path) 38 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 39 | 40 | # Load all arguments for later use. 41 | train_args = checkpoint['train_args'] 42 | train_dset_args = checkpoint['dset_args'] 43 | 44 | # Get network instance parameters. 45 | seeker_args = checkpoint['seeker_args'] 46 | 47 | model_args = {'seeker': seeker_args} 48 | 49 | # Instantiate networks. 50 | seeker_net = seeker.Seeker(logger, **seeker_args) 51 | seeker_net = seeker_net.to(device) 52 | seeker_net.load_state_dict(checkpoint['net_seeker']) 53 | networks = {'seeker': seeker_net} 54 | epoch = checkpoint['epoch'] 55 | print_fn('=> Loaded epoch (1-based): ' + str(epoch + 1)) 56 | 57 | return (networks, train_args, train_dset_args, model_args, epoch) 58 | 59 | 60 | def perform_inference(data_retval, networks, device, logger, all_args, cur_step): 61 | # **pipeline_args): 62 | ''' 63 | Generates test time predictions. 64 | :param data_retval (dict): Data loader element. 65 | :param all_args (dict): train, test, train_dset, test_dset, model. 66 | ''' 67 | # Following DRY, prepare pipeline instance, *BUT* take care of shared args by updating them. 68 | used_args = copy.deepcopy(all_args['train']) 69 | used_args.num_queries = all_args['test'].num_queries 70 | 71 | my_pipeline = pipeline.MyTrainPipeline(used_args, logger, networks, device) 72 | my_pipeline.set_phase('test') # This calls eval() on all submodules. 73 | 74 | include_loss = True 75 | metrics_only = (data_retval['source_name'][0] == 'plugin') 76 | 77 | temp_st = time.time() 78 | (model_retval, loss_retval) = my_pipeline( 79 | data_retval, cur_step, cur_step, 0, 1.0, include_loss, metrics_only) 80 | logger.debug(f'(perform_inference) my_pipeline: {time.time() - temp_st:.3f}s') 81 | 82 | # Calculate various evaluation metrics. 83 | loss_retval = my_pipeline.process_entire_batch( 84 | data_retval, model_retval, loss_retval, cur_step, cur_step, 0, 1.0) \ 85 | if loss_retval is not None else None 86 | 87 | # Organize and return relevant info, moving stuff to CPU and/or converting to numpy as needed. 88 | inference_retval = dict() 89 | inference_retval['model_retval'] = model_retval 90 | inference_retval['loss_retval'] = loss_retval 91 | inference_retval = my_utils.dict_to_cpu(inference_retval) 92 | 93 | return inference_retval 94 | -------------------------------------------------------------------------------- /eval/test.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Evaluation logic. 
3 | Created by Basile Van Hoorick for TCOW. 4 | ''' 5 | 6 | import os 7 | import sys 8 | sys.path.insert(0, os.path.join(os.getcwd(), 'eval/')) 9 | sys.path.insert(0, os.getcwd()) 10 | 11 | from __init__ import * 12 | 13 | # Internal imports. 14 | import args 15 | import data 16 | import data_utils 17 | import inference 18 | import logvis 19 | import metrics 20 | import my_utils 21 | 22 | 23 | def _test_inner(all_args, networks, data_loader, device, logger, step_offset): 24 | 25 | num_steps = len(data_loader) 26 | start_time = time.time() 27 | inference_retvals = [] 28 | 29 | # for cur_step, data_retval in enumerate(tqdm.tqdm(data_loader)): 30 | for cur_step, data_retval in enumerate(data_loader): 31 | 32 | real_step = cur_step + step_offset 33 | 34 | if cur_step == 0: 35 | logger.info(f'Enter first data loader iteration took {time.time() - start_time:.3f}s') 36 | 37 | data_retval['within_batch_idx'] = torch.arange(all_args['test'].batch_size) # (B). 38 | 39 | # Perform inference (independently per example). 40 | inference_retval = inference.perform_inference( 41 | data_retval, networks, device, logger, all_args, real_step) 42 | 43 | # Print and visualize stuff. 44 | if not(all_args['test'].log_rarely): 45 | friendly_short_name = logger.handle_test_step( 46 | real_step, num_steps, data_retval, inference_retval, all_args) 47 | inference_retval['friendly_short_name'] = friendly_short_name 48 | 49 | # Save all data if desired. 50 | data_retval_pruned = data_utils.clean_remain_reproducible(data_retval) 51 | inference_retval['data_retval_pruned'] = data_retval_pruned 52 | if all_args['test'].store_results: 53 | logger.save_pickle( 54 | inference_retval, f'inference_retval_s{real_step}.p', step=real_step) 55 | 56 | # Save some information to be aggregated across the entire test set. 57 | inference_retval = my_utils.dict_to_cpu(inference_retval) 58 | inference_retvals.append(inference_retval) 59 | 60 | return inference_retvals 61 | 62 | 63 | def _test_outer(all_args, networks, device, logger): 64 | ''' 65 | :param all_args (dict): train, test, train_dset, test_dset, model. 66 | ''' 67 | outer_start_time = time.time() 68 | 69 | for net in networks.values(): 70 | net.eval() 71 | torch.set_grad_enabled(False) 72 | 73 | orig_test_args = copy.deepcopy(all_args['test']) 74 | 75 | assert isinstance(all_args['test'].data_path, list) 76 | actual_data_paths = data_utils.get_data_paths_from_args(all_args['test'].data_path) 77 | assert isinstance(actual_data_paths, list) 78 | 79 | inference_retvals = [] 80 | step_offset = 0 81 | 82 | logger.info('Starting outer test loop over individual data paths...') 83 | for outer_step, cur_data_path in enumerate(tqdm.tqdm(actual_data_paths)): 84 | 85 | # Temporarily overwrite value in args object because data.py uses this. We pretend to the 86 | # callee that only a list of size one was given. 87 | all_args['test'].data_path = [cur_data_path] 88 | 89 | logger.info('Initializing current data loader...') 90 | start_time = time.time() 91 | 92 | # Instantiate dataset. 
93 | (cur_test_loader, test_dset_args) = data.create_test_data_loader( 94 | all_args['train'], all_args['test'], all_args['train_dset'], logger) 95 | if outer_step == 0: 96 | logger.info('Final (first) test dataset args: ' + str(test_dset_args)) 97 | all_args['test_dset'] = test_dset_args 98 | 99 | logger.info(f'Took {time.time() - start_time:.3f}s') 100 | 101 | cur_inference_retvals = _test_inner( 102 | all_args, networks, cur_test_loader, device, logger, step_offset) 103 | 104 | inference_retvals += cur_inference_retvals 105 | 106 | step_offset += len(cur_test_loader) 107 | 108 | del cur_test_loader 109 | 110 | all_args['test'] = orig_test_args # Restore test_args. 111 | 112 | _test_postprocess(inference_retvals, logger) 113 | 114 | logger.info() 115 | total_time = time.time() - outer_start_time 116 | logger.info(f'Total time: {total_time / 3600.0:.3f} hours') 117 | 118 | pass 119 | 120 | 121 | def _test_postprocess(inference_retvals, logger): 122 | 123 | # Report mean metrics over all examples (consider both per-frame and per-scene). 124 | if inference_retvals[0]['loss_retval'] is not None: 125 | 126 | metrics_retvals = [x['loss_retval']['metrics'] for x in inference_retvals] 127 | 128 | final_weighted_metrics = metrics.calculate_weighted_averages(metrics_retvals) 129 | final_unweighted_metrics = metrics.calculate_unweighted_averages(metrics_retvals) 130 | metrics.pretty_print_aggregated( 131 | logger, final_weighted_metrics, final_unweighted_metrics, len(metrics_retvals)) 132 | 133 | test_results = metrics.test_results_to_dataframe(inference_retvals) 134 | 135 | # Export itemized results to CSV file to allow for easy sorting. 136 | csv_fp = os.path.join(logger.log_dir, 'itemized_results.csv') 137 | test_results.to_csv(csv_fp) 138 | logger.info(f'Exported quantitative results to: {csv_fp}') 139 | 140 | # Sanity check: verify_* should be exactly equal (both keys and values) to final_*_metrics. 141 | verify_weighted = metrics.calculate_weighted_averages_dataframe(test_results) 142 | verify_unweighted = metrics.calculate_unweighted_averages_dataframe(test_results) 143 | for k in verify_weighted.keys(): 144 | if not(np.isnan(verify_weighted[k]) or np.isnan(final_weighted_metrics[k])): 145 | if not(np.isclose(verify_weighted[k], final_weighted_metrics[k])): 146 | logger.error(f'Weighted metric {k} does not match! ' 147 | f'{verify_weighted[k]} vs {final_weighted_metrics[k]}') 148 | for k in verify_unweighted.keys(): 149 | if not(np.isnan(verify_unweighted[k]) or np.isnan(final_unweighted_metrics[k])): 150 | if not(np.isclose(verify_unweighted[k], final_unweighted_metrics[k])): 151 | logger.error(f'Unweighted metric {k} does not match! ' 152 | f'{verify_unweighted[k]} vs {final_unweighted_metrics[k]}') 153 | 154 | if len(inference_retvals) >= 20: 155 | logger.warning() 156 | logger.warning('Note that the metrics reported here are based on ALL available clips ' 157 | 'sampled from the input video files. If you ran the test script on a Kubric ' 158 | 'or Rubric dataset, we recommend ignoring these numbers and instead running ' 159 | 'pick_represent for a more balanced evaluation (which are also the stats we ' 160 | 'use in our paper). 
See README for instructions.') 161 | logger.warning() 162 | 163 | pass 164 | 165 | 166 | def main(test_args, logger): 167 | 168 | logger.info() 169 | logger.info('torch version: ' + str(torch.__version__)) 170 | logger.info('torchvision version: ' + str(torchvision.__version__)) 171 | logger.save_args(test_args) 172 | 173 | # NOTE: This current test script is not even dependent on any randomness / seed at all! 174 | np.random.seed(test_args.seed) 175 | random.seed(test_args.seed) 176 | torch.manual_seed(test_args.seed) 177 | if test_args.device == 'cuda': 178 | torch.cuda.manual_seed_all(test_args.seed) 179 | 180 | logger.info('Initializing model...') 181 | start_time = time.time() 182 | 183 | # Instantiate networks and load weights. 184 | if test_args.device == 'cuda': 185 | device = torch.device('cuda:' + str(test_args.gpu_id)) 186 | else: 187 | device = torch.device(test_args.device) 188 | (networks, train_args, train_dset_args, model_args, epoch) = \ 189 | inference.load_networks(test_args.resume, device, logger, epoch=test_args.epoch) 190 | 191 | logger.info(f'Took {time.time() - start_time:.3f}s') 192 | 193 | if test_args.avoid_wandb < 2: 194 | logger.init_wandb(PROJECT_NAME, test_args, networks.values(), name=test_args.name, 195 | group=test_args.wandb_group) 196 | 197 | # Print test arguments. 198 | logger.info('Train command args: ' + str(train_args)) 199 | logger.info('Train dataset args: ' + str(train_dset_args)) 200 | logger.info('Final test command args: ' + str(test_args)) 201 | 202 | # Combine arguments for later use. 203 | all_args = dict() 204 | all_args['train'] = train_args 205 | all_args['test'] = test_args 206 | all_args['train_dset'] = train_dset_args 207 | all_args['model'] = model_args 208 | 209 | # Run actual test loop. 210 | _test_outer(all_args, networks, device, logger) 211 | 212 | 213 | if __name__ == '__main__': 214 | 215 | # WARNING: This is slow, but we can detect NaNs this way: 216 | # torch.autograd.set_detect_anomaly(True) 217 | 218 | np.set_printoptions(precision=3, suppress=True) 219 | torch.set_printoptions(precision=3, sci_mode=False) 220 | 221 | # https://github.com/pytorch/pytorch/issues/11201 222 | torch.multiprocessing.set_sharing_strategy('file_system') 223 | torch.cuda.empty_cache() 224 | 225 | test_args = args.test_args() 226 | 227 | logger = logvis.MyLogger(test_args, context='test_' + test_args.name, 228 | log_level=test_args.log_level.upper()) 229 | 230 | try: 231 | 232 | main(test_args, logger) 233 | 234 | except Exception as e: 235 | 236 | logger.exception(e) 237 | 238 | logger.warning('Shutting down due to exception...') 239 | -------------------------------------------------------------------------------- /gen_kubric/kubric_constants.py: -------------------------------------------------------------------------------- 1 | 2 | GSO_CONTAINER_IDS = [ 3 | 'Utana_5_Porcelain_Ramekin_Large', 4 | 'Threshold_Porcelain_Pitcher_White', 5 | 'Threshold_Porcelain_Coffee_Mug_All_Over_Bead_White', 6 | 'Threshold_Bead_Cereal_Bowl_White', 7 | 'Threshold_Basket_Natural_Finish_Fabric_Liner_Small', 8 | 'Target_Basket_Medium', 9 | 'Sterilite_Caddy_Blue_Sky_17_58_x_12_58_x_9_14', 10 | 'Spritz_Easter_Basket_Plastic_Teal', 11 | 'Smith_Hawken_Woven_BasketTray_Organizer_with_3_Compartments_95_x_9_x_13', 12 | 'Sapota_Threshold_4_Ceramic_Round_Planter_Red', 13 | 'Room_Essentials_Mug_White_Yellow', 14 | 'Room_Essentials_Fabric_Cube_Lavender', 15 | 'Room_Essentials_Bowl_Turquiose', 16 | 'RJ_Rabbit_Easter_Basket_Blue', 17 | 'Pennington_Electric_Pot_Cabana_4', 18 | 
'Now_Designs_Bowl_Akita_Black', 19 | 'Nordic_Ware_Original_Bundt_Pan', 20 | 'Markings_Desk_Caddy', 21 | 'Hefty_Waste_Basket_Decorative_Bronze_85_liter', 22 | 'Full_Circle_Happy_Scraps_Out_Collector_Gray', 23 | 'Footed_Bowl_Sand', 24 | 'Ecoforms_Pot_Nova_6_Turquoise', 25 | 'Ecoforms_Planter_Pot_QP6Ebony', 26 | 'Ecoforms_Planter_Pot_GP12AAvocado', 27 | 'Ecoforms_Planter_Bowl_Cole_Hardware', 28 | 'Ecoforms_Plant_Pot_GP9_SAND', 29 | 'Ecoforms_Plant_Pot_GP9AAvocado', 30 | 'Ecoforms_Plant_Container_Urn_55_Mocha', 31 | 'Ecoforms_Plant_Container_Urn_55_Avocado', 32 | 'Ecoforms_Plant_Container_URN_SAN', 33 | 'Ecoforms_Plant_Container_URN_NAT', 34 | 'Ecoforms_Plant_Container_SB9Turquoise', 35 | 'Ecoforms_Plant_Container_Quadra_Turquoise_QP12', 36 | 'Ecoforms_Plant_Container_Quadra_Sand_QP6', 37 | 'Ecoforms_Plant_Container_GP16A_Coral', 38 | 'Ecoforms_Plant_Container_QP6CORAL', 39 | 'Ecoforms_Plant_Container_QP6HARVEST', 40 | 'Ecoforms_Plant_Container_QP_Harvest', 41 | 'Ecoforms_Plant_Container_QP_Turquoise', 42 | 'Ecoforms_Plant_Container_GP16AMOCHA', 43 | 'Ecoforms_Plant_Container_FB6_Tur', 44 | 'Ecoforms_Plant_Container_B4_Har', 45 | 'Ecoforms_Plant_Container_12_Pot_Nova', 46 | 'Ecoforms_Plant_Bowl_Turquoise_7', 47 | 'Down_To_Earth_Orchid_Pot_Ceramic_Lime', 48 | 'Down_To_Earth_Orchid_Pot_Ceramic_Red', 49 | 'Ecoforms_Cup_B4_SAN', 50 | 'Ecoforms_Garden_Pot_GP16ATurquois', 51 | 'Down_To_Earth_Ceramic_Orchid_Pot_Asst_Blue', 52 | 'Curver_Storage_Bin_Black_Small', 53 | 'Cole_Hardware_Orchid_Pot_85', 54 | 'Cole_Hardware_Mug_Classic_Blue', 55 | 'Cole_Hardware_Flower_Pot_1025', 56 | 'Cole_Hardware_Electric_Pot_Cabana_55', 57 | 'Cole_Hardware_Electric_Pot_Assortment_55', 58 | 'Cole_Hardware_Deep_Bowl_Good_Earth_1075', 59 | 'Cole_Hardware_Bowl_Scirocco_YellowBlue', 60 | 'Closetmaid_Premium_Fabric_Cube_Red', 61 | 'Central_Garden_Flower_Pot_Goo_425', 62 | 'Calphalon_Kitchen_Essentials_12_Cast_Iron_Fry_Pan_Black', 63 | 'Bradshaw_International_11642_7_Qt_MP_Plastic_Bowl', 64 | 'BIA_Porcelain_Ramekin_With_Glazed_Rim_35_45_oz_cup', 65 | 'BIA_Cordon_Bleu_White_Porcelain_Utensil_Holder_900028', 66 | '45oz_RAMEKIN_ASST_DEEP_COLORS', 67 | ] 68 | GSO_CARRIER_IDS = [ 69 | 'Top_Paw_Dog_Bowl_Blue_Paw_Bone_Ceramic_25_fl_oz_total', 70 | 'Top_Paw_Dog_Bow_Bone_Ceramic_13_fl_oz_total', 71 | 'Threshold_Tray_Rectangle_Porcelain', 72 | 'Threshold_Salad_Plate_Square_Rim_Porcelain', 73 | 'Threshold_Porcelain_Spoon_Rest_White', 74 | 'Threshold_Porcelain_Serving_Bowl_Coupe_White', 75 | 'Threshold_Dinner_Plate_Square_Rim_White_Porcelain', 76 | 'Threshold_Bistro_Ceramic_Dinner_Plate_Ruby_Ring', 77 | 'Threshold_Bamboo_Ceramic_Soap_Dish', 78 | 'Sea_to_Summit_Xl_Bowl', 79 | 'Room_Essentials_Salad_Plate_Turquoise', 80 | 'Room_Essentials_Dish_Drainer_Collapsible_White', 81 | 'Neat_Solutions_Character_Bib_2_pack', 82 | 'Markings_Letter_Holder', 83 | 'Kotobuki_Saucer_Dragon_Fly', 84 | 'Grreatv_Choice_Dog_Bowl_Gray_Bones_Plastic_20_fl_oz_total', 85 | 'Grreat_Choice_Dog_Double_Dish_Plastic_Blue', 86 | 'Ecoforms_Saucer_SQ3_Turquoise', 87 | 'Ecoforms_Quadra_Saucer_SQ1_Avocado', 88 | 'Ecoforms_Plate_S20Avocado', 89 | 'Ecoforms_Plant_Saucer_SQ8COR', 90 | 'Ecoforms_Plant_Saucer_SQ1HARVEST', 91 | 'Ecoforms_Plant_Saucer_S20MOCHA', 92 | 'Ecoforms_Plant_Saucer_S17MOCHA', 93 | 'Ecoforms_Plant_Saucer_S14NATURAL', 94 | 'Ecoforms_Plant_Saucer_S14MOCHA', 95 | 'Ecoforms_Plant_Plate_S11Turquoise', 96 | 'Ecoforms_Plant_Container_S24Turquoise', 97 | 'Ecoforms_Plant_Container_S24NATURAL', 98 | 'Ecoforms_Plant_Container_S14Turquoise', 99 | 
'Ecoforms_Plant_Bowl_Atlas_Low', 100 | 'Dixie_10_ounce_Bowls_35_ct', 101 | 'Cole_Hardware_Saucer_Glazed_6', 102 | 'Design_Ideas_Drawer_Store_Organizer', 103 | 'Corningware_CW_by_Corningware_3qt_Oblong_Casserole_Dish_Blue', 104 | 'Cole_Hardware_Saucer_Electric', 105 | 'Cole_Hardware_Plant_Saucer_Glazed_9', 106 | 'Cole_Hardware_Plant_Saucer_Brown_125', 107 | 'Chefmate_8_Frypan', 108 | 'Chef_Style_Round_Cake_Pan_9_inch_pan', 109 | 'Avengers_Thor_PLlrpYniaeB', 110 | '3D_Dollhouse_Swing', 111 | '3D_Dollhouse_Sofa', 112 | 'Travel_Mate_P_series_Notebook', 113 | 'RedBlack_Nintendo_3DSXL', 114 | 'Lenovo_Yoga_2_11', 115 | 'BlueBlack_Nintendo_3DSXL', 116 | 'BlackBlack_Nintendo_3DSXL', 117 | 'Melissa_Doug_Pound_and_Roll', 118 | 'GEOMETRIC_PEG_BOARD', 119 | ] 120 | 121 | # These are also containers but must be spawned in upside down rotation. 122 | GSO_HAT_IDS = [ 123 | 'Retail_Leadership_Summit_tQFCizMt6g0', 124 | 'Retail_Leadership_Summit_eCT3zqHYIkX', 125 | 'Retail_Leadership_Summit', 126 | 'DPC_tropical_Trends_Hat', 127 | 'DPC_Handmade_Hat_Brown', 128 | ] 129 | 130 | GSO_SHOE_CONTAINS = [ 131 | 'adipure', 132 | 'adistar', 133 | 'adizero', 134 | 'adrenaline', 135 | 'amberlight', 136 | 'asics', 137 | 'boot', 138 | 'cascadia', 139 | 'chelsea', 140 | 'climacool', 141 | 'colton', 142 | 'court', 143 | 'crazy_', 144 | 'd_rose', 145 | 'dragon_w', 146 | 'fyw', 147 | 'ghost', 148 | 'glycerin', 149 | 'great_jones', 150 | 'grand_prix', 151 | 'mens', 152 | 'pheehan', 153 | 'predito', 154 | 'predator', 155 | 'purecadence', 156 | 'pureconnect', 157 | 'pureflow', 158 | 'ravenna', 159 | 'reebok', 160 | 'reef', 161 | 'samba', 162 | 'samoa', 163 | 'slack', 164 | 'suede', 165 | 'superstar', 166 | 'terrex', 167 | 'tieks', 168 | 'timberland', 169 | 'top_ten', 170 | 'topsider', 171 | 'tory_', 172 | 'trochilus', 173 | 'trx', 174 | 'tzx', 175 | 'ugg_', 176 | 'wings', 177 | 'zigkick', 178 | 'zx700', 179 | ] 180 | 181 | # These are perfect boxes / cuboids. 182 | # NOT included are: 183 | # - Non-perfect cuboids, extra things sticking out. 184 | # - Too flat boxes (such as cereal boxes). 185 | # - Too exotic or unrealistic covers (such as shampoo bottles). 186 | GSO_BOX_CONTAINS = [ 187 | 'apples', 188 | 'asus', 189 | 'bean', 190 | 'blender', 191 | 'brother_printing', 192 | 'capsules', 193 | 'chocolate_cube', 194 | 'crayola', # Sometimes has tags on top. 195 | 'crunch_girl', 196 | 'deskstar', 197 | 'international_paper', 198 | 'jarrosil', 199 | 'latte', 200 | 'lego_bricks', 201 | 'mist', 202 | 'momento', 203 | 'nestle_carnation', 204 | 'nestle_nips', 205 | 'nestle_skinny', 206 | 'netgear', 207 | 'office_depot', 208 | 'packet', # Covers Whey stuff. 209 | 'paint', 210 | 'perricoen', 211 | 'perricone', 212 | 'pepsi', 213 | 'philips', 214 | 'pure_life', 215 | 'same_200', 216 | 'snacks', 217 | 'super_berry', 218 | 'soda', 219 | 'yum', 220 | ] 221 | 222 | # These are other things used for pushing that are not boxes. 223 | # (abandoned) 224 | # GSO_PUSHER_CONTAINS = [ 225 | # ] 226 | -------------------------------------------------------------------------------- /model/mask_tracker.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Neural network architecture description. 3 | Created by Basile Van Hoorick for TCOW. 
4 | ''' 5 | 6 | import os 7 | import sys 8 | sys.path.insert(0, os.path.join(os.getcwd(), 'seeker/')) 9 | sys.path.insert(0, os.path.join(os.getcwd(), 'third_party/aot-benchmark/')) 10 | sys.path.insert(0, os.getcwd()) 11 | 12 | from __init__ import * 13 | 14 | # Internal imports. 15 | import resnet 16 | import vision_tf 17 | 18 | 19 | class QueryMaskTracker(torch.nn.Module): 20 | ''' 21 | X 22 | ''' 23 | 24 | def __init__(self, logger, num_total_frames=24, num_visible_frames=16, frame_height=224, 25 | frame_width=288, tracker_pretrained=False, attention_type='divided_space_time', 26 | patch_size=16, causal_attention=False, norm_embeddings=False, drop_path_rate=0.1, 27 | network_depth=12, track_map_stride=4, track_map_resize='bilinear', 28 | query_channels=1, output_channels=3, flag_channels=3): 29 | super().__init__() 30 | self.logger = logger 31 | self.num_total_frames = num_total_frames 32 | self.num_visible_frames = num_visible_frames 33 | self.frame_height = frame_height 34 | self.frame_width = frame_width 35 | self.attention_type = attention_type 36 | self.patch_size = patch_size 37 | self.causal_attention = causal_attention 38 | self.norm_embeddings = norm_embeddings 39 | self.drop_path_rate = drop_path_rate 40 | self.network_depth = network_depth 41 | self.track_map_stride = track_map_stride 42 | self.track_map_resize = track_map_resize 43 | self.query_channels = query_channels 44 | self.output_channels = output_channels 45 | self.flag_channels = flag_channels 46 | 47 | # Determine precise input shapes. 48 | self.input_channels = 3 + self.query_channels 49 | 50 | # Determine precise output shapes. 51 | self.output_channels = output_channels 52 | 53 | # Translate given pretrained info. 54 | self.pretrained_path = '' 55 | if isinstance(tracker_pretrained, bool): 56 | self.tracker_pretrained = tracker_pretrained 57 | elif isinstance(tracker_pretrained, str): 58 | # Consistent with _str2bool(). 59 | if tracker_pretrained.lower() in ['1', 'y', 'yes', 't', 'true']: 60 | self.tracker_pretrained = True # Defaults to vit_base_patch16_224 (ImageNet). 61 | elif len(tracker_pretrained) <= 5: 62 | self.tracker_pretrained = False 63 | else: 64 | self.tracker_pretrained = True 65 | self.pretrained_path = tracker_pretrained # Custom file path on disk. 66 | else: 67 | raise ValueError(f'Invalid tracker_pretrained value: {tracker_pretrained}.') 68 | self.logger.info(f'(QueryMaskTracker) tracker_pretrained: {self.tracker_pretrained} ' 69 | f'pretrained_path: {self.pretrained_path}') 70 | 71 | # Instantiate actual network components. 72 | # Instantiate tracker backbone. 73 | self.tracker_backbone = vision_tf.MyDenseTimeSformerBackbone( 74 | self.logger, num_frames=self.num_total_frames, frame_height=self.frame_height, 75 | frame_width=self.frame_width, patch_dim=self.patch_size, 76 | in_channels=self.input_channels, pretrained=self.tracker_pretrained, 77 | pretrained_path=self.pretrained_path, attention_type=self.attention_type, 78 | causal_attention=self.causal_attention, norm_embeddings=self.norm_embeddings, 79 | drop_path_rate=self.drop_path_rate, network_depth=self.network_depth) 80 | self.use_feature_dim = self.tracker_backbone.output_feature_dim 81 | 82 | # This applies to every spatiotemporal patch (typically C x 1 x 16 x 16). 
83 | self.tracker_post_linear = torch.nn.Linear( 84 | self.use_feature_dim, self.output_channels * self.patch_size * self.patch_size) 85 | if self.flag_channels > 0: 86 | self.flag_post_linear = torch.nn.Linear(self.use_feature_dim, self.flag_channels) 87 | # Flags are typically (occluded, contained, soft_fraction). 88 | 89 | assert self.frame_height % self.patch_size == 0 90 | assert self.frame_width % self.patch_size == 0 91 | 92 | def forward(self, input_frames, query_mask): 93 | ''' 94 | Assumes input frames are already blacked out as appropriate. 95 | :param input_frames (B, 3-7, T, Hf, Wf) tensor. 96 | :param query_mask (B, C, T, Hf, Wf) tensor. 97 | :return (output_mask, output_flags). 98 | output_mask (B, C, T, Hf, Wf) tensor. 99 | output_flags (B, T, F) tensor. 100 | ''' 101 | # Append query information in desired way. 102 | (B, _, T, Hf, Wf) = input_frames.shape 103 | input_frames = input_frames.type(torch.float32) 104 | query_mask = query_mask.type(torch.float32) 105 | assert query_mask.shape[1] == 1 106 | 107 | input_with_query = input_frames.clone() 108 | input_with_query = torch.cat([input_with_query, query_mask], dim=1) 109 | 110 | (output_features, _) = self.tracker_backbone(input_with_query, None) 111 | 112 | output_features = rearrange(output_features, 'B D T H W -> B T H W D') 113 | output_patches = self.tracker_post_linear(output_features) # (B, T, H, W, D). 114 | output_mask = rearrange(output_patches, 'B T H W (C h w) -> B C T (H h) (W w)', 115 | C=self.output_channels, h=self.patch_size, w=self.patch_size) 116 | 117 | # Make output coarser such that optimization (hopefully) focuses less on precise boundaries. 118 | if self.track_map_stride > 1: 119 | # Awkward way to do this, but it works. 120 | output_mask = rearrange(output_mask, 'B C T Hf Wf -> (B T) C Hf Wf') 121 | output_mask = torch.nn.functional.avg_pool2d( 122 | output_mask, self.track_map_stride, self.track_map_stride) 123 | 124 | if self.track_map_resize == 'nearest': 125 | output_mask = torch.nn.functional.interpolate( 126 | output_mask, scale_factor=self.track_map_stride, mode='nearest') 127 | elif self.track_map_resize == 'bilinear': 128 | output_mask = torch.nn.functional.interpolate( 129 | output_mask, scale_factor=self.track_map_stride, mode='bilinear', 130 | align_corners=True) 131 | 132 | output_mask = rearrange(output_mask, '(B T) C Hf Wf -> B C T Hf Wf', B=B, T=T) 133 | 134 | # Calculate extra info per frame. 135 | if self.flag_channels > 0: 136 | output_flags = self.flag_post_linear(output_features) # (B, T, H, W, F). 137 | output_flags = output_flags.mean(dim=[-2, -3]) # (B, T, F). 138 | 139 | else: 140 | output_flags = None 141 | 142 | return (output_mask, output_flags) # (B, C, T, Hf, Wf), (B, T, F). 143 | -------------------------------------------------------------------------------- /model/resnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Neural network architectures for dense prediction via CNN-based image and/or video models. 3 | Created by Basile Van Hoorick for TCOW. 4 | ''' 5 | 6 | import os 7 | import sys 8 | sys.path.append(os.getcwd()) 9 | sys.path.append(os.path.join(os.getcwd(), 'model/')) 10 | 11 | from __init__ import * 12 | 13 | # Library imports. 14 | import os 15 | import sys 16 | import timm 17 | 18 | 19 | # NOTE: Not used in my augs, BUT used in most pretrained models. 
20 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/constants.py 21 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 22 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 23 | 24 | 25 | class DenseResNet(torch.nn.Module): 26 | 27 | def __init__(self, logger, timm_name, pretrained, frame_height, frame_width, patch_dim, 28 | in_channels): 29 | super().__init__() 30 | self.logger = logger 31 | self.timm_name = timm_name 32 | self.pretrained = pretrained 33 | # Frame size. 34 | self.Hf = frame_height 35 | self.Wf = frame_width 36 | # Number of patches. 37 | self.Ho = frame_height // patch_dim 38 | self.Wo = frame_width // patch_dim 39 | # Patch size. 40 | self.ho = patch_dim 41 | self.wo = patch_dim 42 | # Number of channels. 43 | self.Ci = in_channels 44 | 45 | # Instantiate model. 46 | self.resnet = timm.create_model(timm_name, pretrained=pretrained) 47 | 48 | for bottleneck in self.resnet.layer3: 49 | bottleneck.act3 = torch.nn.Sequential() 50 | assert self.ho == 16 and self.wo == 16 # We have four downsampling operations. 51 | self.output_feature_dim = 1024 # layer3 output size = final embedding size. 52 | 53 | # Replace first convolutional layer to accommodate non-standard inputs. 54 | if in_channels != 3: 55 | assert not(pretrained) 56 | self.resnet.conv1 = torch.nn.Conv2d( 57 | in_channels=in_channels, out_channels=64, kernel_size=(7, 7), stride=(2, 2), 58 | padding=(3, 3), bias=False) 59 | 60 | def forward(self, input_pixels): 61 | ''' 62 | :param input_pixels (B, C, H, W) tensor. 63 | ''' 64 | 65 | # Normalize if pretrained. 66 | if self.pretrained: 67 | mean = torch.tensor(IMAGENET_DEFAULT_MEAN, dtype=input_pixels.dtype, 68 | device=input_pixels.device) 69 | mean = mean[:, None, None].expand_as(input_pixels[0]) 70 | std = torch.tensor(IMAGENET_DEFAULT_STD, dtype=input_pixels.dtype, 71 | device=input_pixels.device) 72 | std = std[:, None, None].expand_as(input_pixels[0]) 73 | input_pixels = input_pixels - mean 74 | input_pixels = input_pixels / std 75 | 76 | # Adapted from 77 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/resnet.py. 78 | # Skip layer3 final ReLU, layer4, global_pool, flatten, and fc. 79 | x = self.resnet.conv1(input_pixels) 80 | x = self.resnet.bn1(x) 81 | x = self.resnet.act1(x) 82 | x = self.resnet.maxpool(x) 83 | x = self.resnet.layer1(x) 84 | x = self.resnet.layer2(x) 85 | x = self.resnet.layer3(x) 86 | output_features = x # (B, D, H, W). 87 | 88 | assert output_features.shape[1] == self.output_feature_dim 89 | 90 | return output_features 91 | 92 | 93 | class MyDenseResNetBackbone(DenseResNet): 94 | ''' 95 | Trainable variant of the DenseResNet. 96 | ''' 97 | 98 | def __init__(self, logger, frame_height=224, frame_width=288, in_channels=3, pretrained=False): 99 | super().__init__(logger, 'resnet50', pretrained, frame_height, frame_width, 16, in_channels) 100 | -------------------------------------------------------------------------------- /model/seeker.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Neural network architecture description. 3 | Created by Basile Van Hoorick for TCOW. 4 | ''' 5 | 6 | import os 7 | import sys 8 | sys.path.append(os.getcwd()) 9 | sys.path.append(os.path.join(os.getcwd(), 'seeker/')) 10 | 11 | from __init__ import * 12 | 13 | # Internal imports. 
14 | import mask_tracker 15 | 16 | 17 | class Seeker(torch.nn.Module): 18 | 19 | def __init__(self, logger, **kwargs): 20 | super().__init__() 21 | self.logger = logger 22 | self.seeker = mask_tracker.QueryMaskTracker(logger, **kwargs) 23 | 24 | def forward(self, *args): 25 | return self.seeker(*args) 26 | -------------------------------------------------------------------------------- /model/vision_tf.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Neural network architectures for dense prediction via transformer-based image and/or video models. 3 | Created by Basile Van Hoorick for TCOW. 4 | ''' 5 | 6 | import os 7 | import sys 8 | sys.path.append(os.getcwd()) 9 | sys.path.append(os.path.join(os.getcwd(), 'model/')) 10 | 11 | from __init__ import * 12 | 13 | # Library imports. 14 | import os 15 | import sys 16 | import timm 17 | 18 | # Internal imports. 19 | from timesformer.models.vit import TimeSformer 20 | 21 | 22 | # https://github.com/facebookresearch/TimeSformer/blob/main/timesformer/config/defaults.py 23 | TIMESFORMER_MEAN = (0.45, 0.45, 0.45) 24 | TIMESFORMER_STD = (0.225, 0.225, 0.225) 25 | 26 | 27 | class DenseTimeSformer(torch.nn.Module): 28 | ''' 29 | Based on https://github.com/facebookresearch/TimeSformer. 30 | ''' 31 | 32 | def __init__(self, logger, pretrained, pretrained_path, frame_height, frame_width, 33 | patch_dim, in_channels, num_frames, attention_type, causal_attention, 34 | norm_embeddings, drop_path_rate, network_depth): 35 | super().__init__() 36 | self.logger = logger 37 | self.pretrained = pretrained 38 | # Frame size. 39 | self.Hf = frame_height 40 | self.Wf = frame_width 41 | # Number of patches. 42 | self.Ho = frame_height // patch_dim 43 | self.Wo = frame_width // patch_dim 44 | # Patch size. 45 | self.ho = patch_dim 46 | self.wo = patch_dim 47 | # Number of channels. 48 | self.Ci = in_channels 49 | # Extra options. 50 | self.T = num_frames 51 | self.attention_type = attention_type 52 | self.causal_attention = causal_attention 53 | self.norm_embeddings = norm_embeddings 54 | self.drop_path_rate = drop_path_rate 55 | self.network_depth = network_depth 56 | 57 | self.timesformer = TimeSformer( 58 | img_size=(self.Hf, self.Wf), patch_size=patch_dim, num_classes=0, num_frames=self.T, 59 | attention_type=self.attention_type, causal_attention=self.causal_attention, 60 | drop_path_rate=self.drop_path_rate, network_depth=network_depth, 61 | pretrained=self.pretrained, pretrained_model=pretrained_path, in_chans=self.Ci) 62 | self.output_feature_dim = self.timesformer.model.embed_dim # Typically 768 or 1024. 63 | 64 | # Taken from their dataset code (Kinetics and SSv2): 65 | # self.data_mean = [0.45, 0.45, 0.45] 66 | # self.data_std = [0.225, 0.225, 0.225] 67 | 68 | def forward(self, input_pixels, extra_token_in): 69 | ''' 70 | :param input_pixels (B, C, T, H, W) tensor. 71 | :param extra_token_in (B, D, N) tensor. 72 | :return output_features or (output_features, extra_token_out). 73 | output_features (B, D, T, H, W) tensor. 74 | extra_token_out (B, D, N) tensor. 75 | ''' 76 | 77 | # Normalize if pretrained. 78 | # https://github.com/facebookresearch/TimeSformer/issues/10 79 | # NOTE: If there are more than 3 channels, we assume the first 3 are RGB and leave the rest 80 | # untouched. 
81 | if self.pretrained: 82 | mean = torch.tensor(TIMESFORMER_MEAN, dtype=input_pixels.dtype, 83 | device=input_pixels.device) 84 | mean = mean[:, None, None, None].expand_as(input_pixels[0, 0:3]) 85 | std = torch.tensor(TIMESFORMER_STD, dtype=input_pixels.dtype, 86 | device=input_pixels.device) 87 | std = std[:, None, None, None].expand_as(input_pixels[0, 0:3]) 88 | input_pixels[:, 0:3] = input_pixels[:, 0:3] - mean[None] 89 | input_pixels[:, 0:3] = input_pixels[:, 0:3] / std[None] 90 | 91 | # Adapted from 92 | # https://github.com/facebookresearch/TimeSformer/blob/main/timesformer/models/vit.py 93 | # See fr-timesformer (or Ctrl+Click TimeSformer) for actual code. 94 | B = input_pixels.shape[0] 95 | x, T, W = self.timesformer.model.patch_embed(input_pixels) 96 | assert T == self.T 97 | assert W == self.Wo 98 | 99 | cls_tokens = self.timesformer.model.cls_token.expand(x.size(0), -1, -1) 100 | x = torch.cat((cls_tokens, x), dim=1) 101 | 102 | # resizing the positional embeddings in case they don't match the input at inference 103 | if x.size(1) != self.timesformer.model.pos_embed.size(1): 104 | pos_embed = self.timesformer.model.pos_embed 105 | cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1) 106 | other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2) 107 | P = int(other_pos_embed.size(2) ** 0.5) 108 | H = x.size(1) // W 109 | other_pos_embed = other_pos_embed.reshape(1, x.size(2), P, P) 110 | new_pos_embed = torch.nn.functional.interpolate( 111 | other_pos_embed, size=(H, W), mode='nearest') 112 | new_pos_embed = new_pos_embed.flatten(2) 113 | new_pos_embed = new_pos_embed.transpose(1, 2) 114 | new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1) 115 | x = x + new_pos_embed 116 | else: 117 | x = x + self.timesformer.model.pos_embed 118 | 119 | x = self.timesformer.model.pos_drop(x) 120 | 121 | # Time Embeddings 122 | cls_tokens = x[:B, 0, :].unsqueeze(1) 123 | x = x[:, 1:] 124 | x = rearrange(x, '(b t) n m -> (b n) t m', b=B, t=T) 125 | 126 | # Resizing time embeddings in case they don't match 127 | if T != self.timesformer.model.time_embed.size(1): 128 | time_embed = self.timesformer.model.time_embed.transpose(1, 2) 129 | new_time_embed = torch.nn.functional.interpolate( 130 | time_embed, size=(T), mode='nearest') 131 | new_time_embed = new_time_embed.transpose(1, 2) 132 | x = x + new_time_embed 133 | else: 134 | x = x + self.timesformer.model.time_embed 135 | 136 | x = self.timesformer.model.time_drop(x) 137 | x = rearrange(x, '(b n) t m -> b (n t) m', b=B, t=T) 138 | x = torch.cat((cls_tokens, x), dim=1) 139 | 140 | # BVH MOD: 141 | if extra_token_in is not None: 142 | # Simply overwrite cls_token to avoid having to modify attention mask stuff in vit.py. 143 | assert extra_token_in.shape[-1] == 1 144 | x[:, 0, :] = extra_token_in.squeeze(-1) # (B, D). 145 | 146 | # Attention blocks 147 | y = x 148 | for blk in self.timesformer.model.blocks: 149 | y = blk(y, B, T, W) 150 | 151 | # Layer normalization is traditionally applied for all tokens. 152 | if self.norm_embeddings: 153 | y = self.timesformer.model.norm(y) 154 | 155 | # BVH MOD: 156 | # Separate output corresponding to cls / extra input token position. 157 | extra_token_out = y[:, 0, :].unsqueeze(-1) # (B, D, 1). 158 | 159 | # Discard cls_token altogether. 160 | # Skip head (traditionally for cls only). 161 | y = y[:, 1:] 162 | 163 | y = rearrange(y, 'B (H W T) D -> B D T H W', 164 | B=B, T=T, H=self.Ho, W=self.Wo, D=self.output_feature_dim) 165 | output_features = y # (B, D, T, H, W). 
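# Worked shape example for the rearrange above, assuming 224x288 frames with 16x16 patches:
# Ho = 224 // 16 = 14 and Wo = 288 // 16 = 18, so each frame contributes 14 * 18 = 252
# tokens, and y holds 252 * T tokens of dimension D once the cls token is dropped.
#
#     import torch
#     from einops import rearrange
#     y_demo = torch.randn(2, 252 * 16, 768)                    # (B, H*W*T, D) with T = 16
#     grid = rearrange(y_demo, 'B (H W T) D -> B D T H W', H=14, W=18, T=16)
#     assert grid.shape == (2, 768, 16, 14, 18)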
166 | 167 | assert output_features.shape[1] == self.output_feature_dim 168 | 169 | return (output_features, extra_token_out) 170 | 171 | 172 | class MyDenseTimeSformerBackbone(DenseTimeSformer): 173 | ''' 174 | Trainable variant of the DenseTimeSformerBackbone. 175 | ''' 176 | 177 | def __init__(self, logger, num_frames=16, frame_height=224, frame_width=288, 178 | patch_dim=16, in_channels=3, pretrained=False, pretrained_path='', 179 | attention_type='divided_space_time', causal_attention=False, norm_embeddings=False, 180 | drop_path_rate=0.1, network_depth=12): 181 | super().__init__( 182 | logger, pretrained, pretrained_path, frame_height, frame_width, patch_dim, in_channels, num_frames, 183 | attention_type, causal_attention, norm_embeddings, drop_path_rate, network_depth) 184 | 185 | 186 | if __name__ == '__main__': 187 | 188 | (B, T, H, W, C) = (2, 18, 192, 160, 3) 189 | patch_size = 16 190 | 191 | # print('MyDenseVisionTransformerBackbone') 192 | # my_vit = MyDenseVisionTransformerBackbone(None, H, W, C) 193 | 194 | # x = torch.randn(B, C, H, W) 195 | # print('x:', x.shape, x.min().item(), x.mean().item(), x.max().item()) 196 | 197 | # y = my_vit(x) 198 | # print('y:', y.shape, y.min().item(), y.mean().item(), y.max().item()) 199 | # print() 200 | 201 | # assert y.shape == (B, my_vit.output_feature_dim, H // 16, W // 16) 202 | 203 | for attention_type in ['divided_space_time', 'joint_space_time']: 204 | 205 | print('MyDenseTimeSformerBackbone') 206 | print('attention_type:', attention_type) 207 | my_tsf = MyDenseTimeSformerBackbone(None, T, H, W, patch_size, C, attention_type) 208 | 209 | x = torch.randn(B, C, T, H, W) 210 | print('x:', x.shape, x.min().item(), x.mean().item(), x.max().item()) 211 | 212 | y = my_tsf(x, None) 213 | print('y:', y.shape, y.min().item(), y.mean().item(), y.max().item()) 214 | print() 215 | 216 | assert y.shape == (B, my_tsf.output_feature_dim, T, H // patch_size, W // patch_size) 217 | 218 | pass 219 | -------------------------------------------------------------------------------- /rep_lists/kubric_containers.txt: -------------------------------------------------------------------------------- 1 | # File name patterns for guiding TCOW evaluation. 2 | # (See pick_represent.py to see how this is processed) 3 | 4 | kubbench_v3,s0_ku_d0_ 5 | kubbench_v3,s1_ku_d1_ 6 | kubbench_v3,s2_ku_d2_ 7 | kubbench_v3,s3_ku_d3_ 8 | kubbench_v3,s4_ku_d4_ 9 | kubbench_v3,s5_ku_d5_ 10 | kubbench_v3,s6_ku_d6_ 11 | kubbench_v3,s7_ku_d7_ 12 | kubbench_v3,s8_ku_d8_ 13 | kubbench_v3,s9_ku_d9_ 14 | kubbench_v3,s10_ku_d10_ 15 | kubbench_v3,s12_ku_d12_ 16 | kubbench_v3,s13_ku_d13_ 17 | kubbench_v3,s14_ku_d14_ 18 | kubbench_v3,s15_ku_d15_ 19 | kubbench_v3,s16_ku_d16_ 20 | kubbench_v3,s17_ku_d17_ 21 | kubbench_v3,s18_ku_d18_ 22 | kubbench_v3,s19_ku_d19_ 23 | kubbench_v3,s20_ku_d20_ 24 | kubbench_v3,s21_ku_d21_ 25 | kubbench_v3,s22_ku_d22_ 26 | kubbench_v3,s24_ku_d24_ 27 | kubbench_v3,s25_ku_d25_ 28 | kubbench_v3,s26_ku_d26_ 29 | kubbench_v3,s28_ku_d28_ 30 | kubbench_v3,s29_ku_d29_ 31 | -------------------------------------------------------------------------------- /rep_lists/kubric_random.txt: -------------------------------------------------------------------------------- 1 | # File name patterns for guiding TCOW evaluation. 2 | # (See pick_represent.py to see how this is processed) 3 | 4 | kubcon_v10, # Match in every row via scene_dn. 
5 | -------------------------------------------------------------------------------- /rep_lists/rubric_cupgames.txt: -------------------------------------------------------------------------------- 1 | # File name patterns for guiding TCOW evaluation. 2 | # (See pick_represent.py to see how this is processed) 3 | 4 | cgt_frames_0002_i1_f100 # no movement yet 5 | cgt_frames_0002_i3_f100 6 | cgt_frames_0002_i4_f100 # reveal empty only 7 | cgt_frames_0002_i6_f100 # reveal empty + labeled snitch 8 | cgt_frames_0011_i5_f90 9 | cgt_frames_0011_i6_f90 # reveal empty only 10 | cgt_frames_0011_i8_f90 # reveal empty + labeled snitch 11 | cgt_frames_0026_i3_f60 # no movement yet 12 | cgt_frames_0026_i5_f60 13 | cgt_frames_0026_i7_f60 14 | cgt_frames_0026_i9_f60 # reveal empty only 15 | cgt_frames_0061_i2_f40 # no movement yet, also NL 16 | cgt_frames_0061_i4_f40 # also NL 17 | cgt_frames_0061_i5_f40 # also NL 18 | cgt_frames_0061_i6_f40 # also NL 19 | cgt_frames_0061_i7_f40 # also NL 20 | cgt_frames_0065_i2_f80 # only unlabeled movement (still implies MC) 21 | cgt_frames_0065_i5_f80 22 | cgt_frames_0065_i8_f80 # reveal empty + labeled snitch 23 | cgt_frames_0076_i2_f105 # no movement yet 24 | cgt_frames_0076_i8_f105 # reveal labeled snitch only 25 | cgt_frames_0092_i2_f80 # no movement yet 26 | cgt_frames_0092_i4_f80 27 | cgt_frames_0092_i7_f80 28 | cgt_frames_0113_i2_f40 # no movement yet, also NL 29 | cgt_frames_0113_i5_f40 # also NL 30 | cgt_frames_0113_i7_f40 # also NL 31 | cgt_frames_0113_i10_f40 # reveal labeled snitch only, also NL 32 | cgt_frames_0126_i2_f80 # no movement yet 33 | cgt_frames_0126_i5_f80 34 | cgt_frames_0126_i7_f80 35 | cgt_frames_0126_i8_f80 # reveal 1/2 empty only 36 | cgt_frames_0126_i10_f80 # reveal 1/2 empty + labeled snitch 37 | cgt_frames_0136_i2_f195 # only empty movement (still implies MC due to IC), also NL 38 | cgt_frames_0136_i4_f195 # reveal labeled snitch only, also NL 39 | cgt_frames_0137_i1_f60 # no movement yet, also NL 40 | cgt_frames_0137_i3_f60 # also NL 41 | cgt_frames_0137_i5_f60 # also NL 42 | cgt_frames_0137_i7_f60 # reveal 2/2 empty only, also NL 43 | cgt_frames_0154_i2_f60 # only unlabeled movement (still implies MC), also NL 44 | cgt_frames_0154_i4_f60 # also NL 45 | cgt_frames_0154_i5_f60 # also NL 46 | cgt_frames_0154_i7_f60 # reveal empty only, also NL 47 | cgt_frames_0154_i9_f60 # reveal empty + labeled snitch, also NL 48 | cgt_frames_0172_i2_f20 # only unlabeled movement (still implies MC) 49 | cgt_frames_0172_i3_f20 50 | cgt_frames_0172_i4_f20 51 | cgt_frames_0172_i5_f20 # reveal empty + labeled snitch at once at end 52 | cgt_frames_0172_i7_f20 # reveal empty + labeled snitch in middle, then keep shuffling 53 | cgt_frames_0211_i1_f90 # no movement yet 54 | cgt_frames_0211_i4_f90 55 | cgt_frames_0211_i7_f90 # reveal empty only 56 | cgt_frames_0211_i9_f90 # reveal empty + labeled snitch 57 | -------------------------------------------------------------------------------- /rep_lists/rubric_davytb.txt: -------------------------------------------------------------------------------- 1 | # File name patterns for guiding TCOW evaluation. 2 | # (See pick_represent.py to see how this is processed) 3 | 4 | # DAVIS. 5 | 6 | # No (or only partial) occlusion. 7 | train_rhino_i3_f0 8 | train_soccerball_i1_f0 9 | val_bmx-trees_i3_f3 # also NL 10 | val_libby_i1_f7 11 | 12 | # Full occlusion. 
13 | test_giant-slalom_i3_f5 14 | test_people-sunset_i1_f6 15 | test_salsa1_i1_f0 # also IOT 16 | test_salsa2_i2_f0 # also IOT 17 | test_salsa3_i3_f0 # also IOT 18 | test_subway1_i3_f0 19 | test_subway2_i1_f0 20 | test_subway3_i2_f0 21 | train_bmx-bumps_i3_f0 # also NL 22 | train_dancing_i2_f3 23 | train_lindy-hop1_i2_f0 # also IOT 24 | train_lindy-hop2_i2_f1 # also IOT 25 | train_lindy-hop3_i2_f0 # also IOT 26 | train_scooter-board_i3_f0 # also NL 27 | val_india1_i2_f3 # also IOT 28 | val_india2_i2_f3 # also IOT 29 | val_india3_i2_f3 # also IOT 30 | val_pigs_i2_f4 # also IOT 31 | 32 | # YouTube-VOS. 33 | 34 | # No (or only partial) occlusion. 35 | val_0b97736357_i5_f20 36 | val_3f2012d518_i3_f8 # also NL 37 | 38 | # Full occlusion. 39 | val_0a49f5265b_i3_f6 # also NL 40 | val_0c04834d61_i2_f10 41 | val_0e4068b53f_i2_f35 # also IOT 42 | val_1b85035216_i5_f6 # also NL 43 | val_1e0257109e_i2_f10 44 | val_1e6efb0b5f_i2_f14 45 | val_3b72dc1941_rev_i1_f8 # also IOT 46 | val_3f4bacb16a_1_i1_f6 # also IOT 47 | val_4bef684040_i4_f20 # also IOT 48 | val_5c3d2d3155_i2_f37 49 | val_5d2020eff8_i3_f0 # also IOT 50 | val_7e625db8c4_i3_f0 # also IOT, NL 51 | val_24e2b52a4d_fish1_i1_f100 # also IOT 52 | val_24e2b52a4d_fish2_i1_f80 # also IOT 53 | val_33c8dcbe09_i2_f61 # also IOT 54 | val_42d810ba9d_i2_f42 # also IOT 55 | val_91f5ad52e9_i3_f9 56 | -------------------------------------------------------------------------------- /rep_lists/rubric_office.txt: -------------------------------------------------------------------------------- 1 | # File name patterns for guiding TCOW evaluation. 2 | # (See pick_represent.py to see how this is processed) 3 | 4 | # No (or only partial) occlusion or containment. 5 | chips1_oof_i5_f40 # 4_home, 1x labeled, also NL 6 | chips1_oof_i9_f40 # 4_home, 2x labeled, also NL 7 | pump1_scan_i5_f30 # 4_home, 1x labeled 8 | pump1_scan_i9_f30 # 4_home, 2x labeled 9 | 10 | # Full occlusion or containment. 
11 | teaduck1_i3_f50 # 2_teaduck, no reveal 12 | teaduck1_i5_f50 # 2_teaduck, labeled reveal end 13 | teaduck2_i3_f15 # 2_teaduck, no reveal 14 | teaduck2_i5_f15 # 2_teaduck, unlabeled reveal end 15 | teaduck3_reveal_i5_f15 # 2_teaduck, unlabeled reveal end 16 | teaduck3_reveal_i7_f15 # 2_teaduck, labeled reveal end 17 | teaduck4_reveal_i3_f15 # 2_teaduck, no reveal 18 | teaduck4_reveal_i5_f15 # 2_teaduck, labeled reveal end 19 | teaduck4_reveal_i9_f15 # 2_teaduck, 2x labeled reveal 20 | teaduck5_cammove_i5_f15 # 2_teaduck, no reveal, also NL 21 | teaduck5_cammove_i7_f15 # 2_teaduck, no reveal, also NL 22 | teaduck6_teamove_i5_f15 # 2_teaduck, no reveal, also MC 23 | mugduck1_mugmove_i3_f10 # 3_mugduck, no reveal, maybe HD 24 | mugduck1_mugmove_i5_f10 # 3_mugduck, no reveal, also MC, maybe HD 25 | mugduck2_reveal_i5_f20 # 3_mugduck, unlabeled reveal end 26 | mugduck2_reveal_i7_f20 # 3_mugduck, labeled reveal end 27 | mugduck3_reveal_i3_f15 # 3_mugduck, no reveal, v111 bad 28 | mugduck3_reveal_i7_f15 # 3_mugduck, labeled reveal end, v111 bad 29 | mugduck4_mugmove_i3_f40 # 3_mugduck, no reveal 30 | mugduck4_mugmove_i5_f40 # 3_mugduck, no reveal, also MC 31 | mugduck4_mugmove_i7_f40 # 3_mugduck, no reveal, also MC 32 | multicupduck1_game_i3_f30 # 3_mugduck, no reveal, also IC 33 | multicupduck1_game_i5_f30 # 3_mugduck, no reveal, also MC, IC 34 | multicupduck1_game_i7_f30 # 3_mugduck, no reveal, also MC, IC 35 | multicupduck1_game_i9_f30 # 3_mugduck, no reveal, also MC, IC 36 | pumpcookie1_reveal_i5_f20 # 4_home 37 | bagduck1_move_i4_f20 # 5_bagmugduck, no reveal, also MC, HD 38 | bagduck1_move_i5_f20 # 5_bagmugduck, no reveal, also MC, HD 39 | bagduck1_move_i3_f50 # 5_bagmugduck, no reveal, also MC, HD 40 | bagduck1_move_i4_f50 # 5_bagmugduck, no reveal, also MC, HD 41 | bagduck3_recurse_i5_f15 # 5_bagmugduck, no reveal, also MC, HD, RT 42 | bagduck4_transfer_i5_f22 # 5_bagmugduck, brief reveal, also MC, HD, RT 43 | mugduck5_stay_i3_f20 # 5_bagmugduck, no reveal 44 | mugduck5_stay_i5_f20 # 5_bagmugduck, labeled reveal end 45 | mugduck6_shuffle_i5_f30 # 5_bagmugduck, unlabeled reveal end, also MC 46 | mugduck7_shuffle_i3_f30 # 5_bagmugduck, no reveal 47 | mugduck7_shuffle_i6_f30 # 5_bagmugduck, labeled reveal end, also MC 48 | handball1_wave_i3_f5 # 6_handball, no reveal, also MC, HD 49 | handball2_boxoccl_i3_f30 # 6_handball, no reveal, also MC, HD 50 | handball2_boxoccl_i5_f30 # 6_handball, no reveal, also MC, HD 51 | handball2_boxoccl_i7_f30 # 6_handball, labeled reveal end, also MC, HD 52 | handball3_occlpick_i3_f20 # 6_handball, no reveal, also MC, HD 53 | handball3_occlpick_i5_f20 # 6_handball, no reveal, also MC, HD, SA 54 | lightball4_occl1x_i10_f40 # 7_ballbounce, 2x labeled reveal 55 | lightball5_occl2x_i7_f30 # 7_ballbounce, 2x labeled reveal 56 | lightball5_occl2x_i9_f30 # 7_ballbounce, 2x labeled reveal 57 | lightball6_occl4x_i8_f15 # 7_ballbounce, labeled reveal end 58 | lightball6_occl4x_i10_f15 # 7_ballbounce, 2x labeled reveal 59 | plantcupball1_i3_f35 # 8_plantcupball, no reveal, also RT 60 | plantcupball1_i5_f35 # 8_plantcupball, no reveal, also RT 61 | plantcupball1_i7_f35 # 8_plantcupball, unlabeled reveal end, also RT 62 | plantcupball2_i3_f25 # 8_plantcupball, no reveal, also RT 63 | plantcupball2_i5_f25 # 8_plantcupball, no reveal, also MC, RT 64 | plantcupball2_i7_f25 # 8_plantcupball, no reveal, also MC, RT 65 | plantcupball2_i8_f25 # 8_plantcupball, no reveal, also MC, RT 66 | plantcupball3_i3_f45 # 8_plantcupball, no reveal 67 | plantcupball3_i5_f45 # 
8_plantcupball, no reveal, also RT 68 | manycont1_nested_i3_f10 # 9_manycont, no reveal 69 | manycont1_nested_i5_f10 # 9_manycont, no reveal, also MC, RT 70 | manycont2_transfer_i3_f10 # 9_manycont, no reveal 71 | manycont2_transfer_i5_f10 # 9_manycont, brief reveal, also MC, RT 72 | manycont3_nested_i2_f0 # 9_manycont, no reveal 73 | manycont3_nested_i4_f0 # 9_manycont, no reveal, also MC, RT 74 | manycont4_transfer_i3_f10 # 9_manycont, no reveal 75 | manycont4_transfer_i7_f10 # 9_manycont, brief reveal, also MC, RT 76 | manycont5_transfer_i3_f5 # 9_manycont, no reveal 77 | manycont5_transfer_i5_f5 # 9_manycont, brief reveal, also MC, RT 78 | confuse1_i2_f20 # 10_confuse, no reveal 79 | confuse1_i4_f20 # 10_confuse, labeled reveal end 80 | confuse2_i4_f40 # 10_confuse, no reveal 81 | confuse3_i4_f15 # 10_confuse, no reveal 82 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | av 2 | einops 3 | ffmpeg 4 | fvcore 5 | imageio 6 | imageio[ffmpeg] 7 | joblib 8 | kubric 9 | lovely_numpy 10 | lovely_tensors 11 | matplotlib 12 | numpy 13 | opencv_python 14 | pandas 15 | Pillow 16 | Pillow 17 | psutil 18 | pybullet 19 | rich 20 | scikit_learn 21 | scipy 22 | seaborn 23 | setuptools 24 | simplejson 25 | timm 26 | torch_optimizer 27 | tqdm 28 | wandb 29 | -------------------------------------------------------------------------------- /third_party/TimeSformer/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | # Docker file from Python is inspired from here : 6 | # https://github.com/github/gitignore/blob/master/Python.gitignore 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | tests/report/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | 138 | # pytype static type analyzer 139 | .pytype/ 140 | 141 | 142 | # Cython debug symbols 143 | cython_debug/ 144 | -------------------------------------------------------------------------------- /third_party/TimeSformer/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /third_party/TimeSformer/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to TimeSformer 2 | 3 | ## Pull Requests 4 | We actively welcome your pull requests. 5 | 6 | 1. Fork the repo and create your branch from `master`. 7 | 2. If you've added code that should be tested, add tests. 8 | 3. If you've changed APIs, update the documentation. 9 | 4. Ensure the test suite passes. 10 | 5. Make sure your code lints. 11 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 12 | 13 | ## Contributor License Agreement ("CLA") 14 | In order to accept your pull request, we need you to submit a CLA. You only need 15 | to do this once to work on any of Facebook's open source projects. 16 | 17 | Complete your CLA here: 18 | 19 | ## Issues 20 | We use GitHub issues to track public bugs. Please ensure your description is 21 | clear and has sufficient instructions to be able to reproduce the issue. 22 | 23 | ## License 24 | By contributing to TimeSformer, you agree that your contributions will be licensed 25 | under the [LICENSE.md](LICENSE.md) file in the root directory of this source tree. 
26 | -------------------------------------------------------------------------------- /third_party/TimeSformer/README.md: -------------------------------------------------------------------------------- 1 | # TimeSformer 2 | 3 | Taken from: 4 | 5 | https://github.com/facebookresearch/TimeSformer 6 | 7 | Usage: 8 | 9 | Clone this repository, go inside, and run `pip install -e .` 10 | -------------------------------------------------------------------------------- /third_party/TimeSformer/environment.yml: -------------------------------------------------------------------------------- 1 | name: timesformer 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - python>3.7 8 | - jupyterlab 9 | - pandas>=1.2 10 | - numpy>1.19 11 | - pytorch>=1.6 12 | - torchvision>=0.7 13 | - scikit-learn>=0.22 14 | - opencv>=4.2 15 | - pyyaml>=5.1 16 | - yacs>=0.1.6 17 | - einops>=0.3 18 | - tensorboard 19 | - psutil 20 | - tqdm 21 | - matplotlib 22 | - simplejson 23 | - pip 24 | - pip: 25 | - fvcore 26 | - av -------------------------------------------------------------------------------- /third_party/TimeSformer/example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "08fe0c59", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from pathlib import Path\n", 11 | "\n", 12 | "import torch\n", 13 | "from timesformer.models.vit import TimeSformer" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "id": "10239d32", 20 | "metadata": {}, 21 | "outputs": [ 22 | { 23 | "data": { 24 | "text/plain": [ 25 | "True" 26 | ] 27 | }, 28 | "execution_count": 2, 29 | "metadata": {}, 30 | "output_type": "execute_result" 31 | } 32 | ], 33 | "source": [ 34 | "model_file = Path.home()/'TimeSformer/models/TimeSformer_divST_8x32_224_K600.pyth'\n", 35 | "model_file.exists()" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 3, 41 | "id": "652fb03e", 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "model = TimeSformer(img_size=224, num_classes=600, num_frames=8, attention_type='divided_space_time', pretrained_model=str(model_file))\n", 46 | "\n", 47 | "dummy_video = torch.randn(2, 3, 8, 224, 224) # (batch x channels x frames x height x width)\n", 48 | "\n", 49 | "pred = model(dummy_video,) # (2, 600)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 6, 55 | "id": "83de13c5-791c-4db7-aba4-6d29ce88584e", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "assert pred.shape == (2,600)" 60 | ] 61 | } 62 | ], 63 | "metadata": { 64 | "kernelspec": { 65 | "display_name": "Python 3", 66 | "language": "python", 67 | "name": "python3" 68 | }, 69 | "language_info": { 70 | "codemirror_mode": { 71 | "name": "ipython", 72 | "version": 3 73 | }, 74 | "file_extension": ".py", 75 | "mimetype": "text/x-python", 76 | "name": "python", 77 | "nbconvert_exporter": "python", 78 | "pygments_lexer": "ipython3", 79 | "version": "3.9.4" 80 | } 81 | }, 82 | "nbformat": 4, 83 | "nbformat_minor": 5 84 | } 85 | -------------------------------------------------------------------------------- /third_party/TimeSformer/setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length=100 3 | multi_line_output=4 4 | known_standard_library=numpy,setuptools 5 | known_myself=timesformer 6 | 
known_third_party=fvcore,av,torch,pycocotools,yacs,termcolor,scipy,simplejson,matplotlib,torchvision,yaml,tqdm,psutil,opencv-python,pandas,tensorboard,moviepy,sklearn,cv2 7 | no_lines_before=STDLIB,THIRDPARTY 8 | sections=FUTURE,STDLIB,THIRDPARTY,myself,FIRSTPARTY,LOCALFOLDER 9 | default_section=FIRSTPARTY 10 | 11 | [mypy] 12 | python_version=3.6 13 | ignore_missing_imports = True 14 | warn_unused_configs = True 15 | disallow_untyped_defs = True 16 | check_untyped_defs = True 17 | warn_unused_ignores = True 18 | warn_redundant_casts = True 19 | show_column_numbers = True 20 | follow_imports = silent 21 | allow_redefinition = True 22 | ; Require all functions to be annotated 23 | disallow_incomplete_defs = True 24 | -------------------------------------------------------------------------------- /third_party/TimeSformer/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | from setuptools import find_packages, setup 4 | 5 | setup( 6 | name="timesformer", 7 | version="1.0", 8 | author="FBAI", 9 | url="unknown", 10 | description="TimeSformer", 11 | keywords = [ 12 | 'artificial intelligence', 13 | 'attention mechanism', 14 | 'transformers', 15 | 'video classification', 16 | ], 17 | install_requires=[ 18 | 'einops>=0.3', 19 | 'torch>=1.6' 20 | ], 21 | extras_require={"tensorboard_video_visualization": ["moviepy"]}, 22 | packages=find_packages(exclude=("configs", "tests")), 23 | ) 24 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | from timesformer.utils.env import setup_environment 4 | 5 | setup_environment() 6 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/config/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/datasets/DATASET.md: -------------------------------------------------------------------------------- 1 | # Dataset Preparation 2 | 3 | ## Kinetics 4 | 5 | The Kinetics Dataset could be downloaded from the following [link](https://github.com/cvdfoundation/kinetics-dataset): 6 | 7 | After all the videos were downloaded, resize the video to the short edge size of 256, then prepare the csv files for training, validation, and testing set as `train.csv`, `val.csv`, `test.csv`. The format of the csv file is: 8 | 9 | ``` 10 | path_to_video_1 label_1 11 | path_to_video_2 label_2 12 | path_to_video_3 label_3 13 | ... 14 | path_to_video_N label_N 15 | ``` 16 | 17 | ## Something-Something V2 18 | 1. Please download the dataset and annotations from [dataset provider](https://20bn.com/datasets/something-something). 19 | 20 | 2. Download the *frame list* from the following links: ([train](https://dl.fbaipublicfiles.com/pyslowfast/dataset/ssv2/frame_lists/train.csv), [val](https://dl.fbaipublicfiles.com/pyslowfast/dataset/ssv2/frame_lists/val.csv)). 21 | 22 | 3. Extract the frames at 30 FPS. (We used ffmpeg-4.1.3 with command 23 | `ffmpeg -i "${video}" -r 30 -q:v 1 "${out_name}"` 24 | in experiments.) 
Please put the frames in a structure consistent with the frame lists. 25 | 26 | Please put all annotation json files and the frame lists in the same folder, and set `DATA.PATH_TO_DATA_DIR` to the path. Set `DATA.PATH_PREFIX` to be the path to the folder containing extracted frames. 27 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | from .build import DATASET_REGISTRY, build_dataset # noqa 4 | from .kinetics import Kinetics # noqa 5 | from .ssv2 import Ssv2 # noqa 6 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/datasets/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | from fvcore.common.registry import Registry 4 | 5 | DATASET_REGISTRY = Registry("DATASET") 6 | DATASET_REGISTRY.__doc__ = """ 7 | Registry for dataset. 8 | 9 | The registered object will be called with `obj(cfg, split)`. 10 | The call should return a `torch.utils.data.Dataset` object. 11 | """ 12 | 13 | 14 | def build_dataset(dataset_name, cfg, split): 15 | """ 16 | Build a dataset, defined by `dataset_name`. 17 | Args: 18 | dataset_name (str): the name of the dataset to be constructed. 19 | cfg (CfgNode): configs. Details can be found in 20 | slowfast/config/defaults.py 21 | split (str): the split of the data loader. Options include `train`, 22 | `val`, and `test`. 23 | Returns: 24 | Dataset: a constructed dataset specified by dataset_name. 25 | """ 26 | # Capitalize the the first letter of the dataset_name since the dataset_name 27 | # in configs may be in lowercase but the name of dataset class should always 28 | # start with an uppercase letter. 29 | name = dataset_name.capitalize() 30 | return DATASET_REGISTRY.get(name)(cfg, split) 31 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/datasets/loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Data loader.""" 4 | 5 | import itertools 6 | import numpy as np 7 | import torch 8 | from torch.utils.data._utils.collate import default_collate 9 | from torch.utils.data.distributed import DistributedSampler 10 | from torch.utils.data.sampler import RandomSampler 11 | 12 | from timesformer.datasets.multigrid_helper import ShortCycleBatchSampler 13 | 14 | from . import utils as utils 15 | from .build import build_dataset 16 | 17 | 18 | def detection_collate(batch): 19 | """ 20 | Collate function for detection task. Concatanate bboxes, labels and 21 | metadata from different samples in the first dimension instead of 22 | stacking them to have a batch-size dimension. 23 | Args: 24 | batch (tuple or list): data batch to collate. 25 | Returns: 26 | (tuple): collated detection data batch. 
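        Example (illustrative): for a batch of two samples with 2 and 1 boxes, each
        (Ni, 4) "boxes" array gets its sample index prepended as an extra first column
        and the results are concatenated, giving a (3, 5) tensor whose first column
        is [0., 0., 1.].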
27 | """ 28 | inputs, labels, video_idx, extra_data = zip(*batch) 29 | inputs, video_idx = default_collate(inputs), default_collate(video_idx) 30 | labels = torch.tensor(np.concatenate(labels, axis=0)).float() 31 | 32 | collated_extra_data = {} 33 | for key in extra_data[0].keys(): 34 | data = [d[key] for d in extra_data] 35 | if key == "boxes" or key == "ori_boxes": 36 | # Append idx info to the bboxes before concatenating them. 37 | bboxes = [ 38 | np.concatenate( 39 | [np.full((data[i].shape[0], 1), float(i)), data[i]], axis=1 40 | ) 41 | for i in range(len(data)) 42 | ] 43 | bboxes = np.concatenate(bboxes, axis=0) 44 | collated_extra_data[key] = torch.tensor(bboxes).float() 45 | elif key == "metadata": 46 | collated_extra_data[key] = torch.tensor( 47 | list(itertools.chain(*data)) 48 | ).view(-1, 2) 49 | else: 50 | collated_extra_data[key] = default_collate(data) 51 | 52 | return inputs, labels, video_idx, collated_extra_data 53 | 54 | 55 | def construct_loader(cfg, split, is_precise_bn=False): 56 | """ 57 | Constructs the data loader for the given dataset. 58 | Args: 59 | cfg (CfgNode): configs. Details can be found in 60 | slowfast/config/defaults.py 61 | split (str): the split of the data loader. Options include `train`, 62 | `val`, and `test`. 63 | """ 64 | assert split in ["train", "val", "test"] 65 | if split in ["train"]: 66 | dataset_name = cfg.TRAIN.DATASET 67 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS)) 68 | shuffle = True 69 | drop_last = True 70 | elif split in ["val"]: 71 | dataset_name = cfg.TRAIN.DATASET 72 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS)) 73 | shuffle = False 74 | drop_last = False 75 | elif split in ["test"]: 76 | dataset_name = cfg.TEST.DATASET 77 | batch_size = int(cfg.TEST.BATCH_SIZE / max(1, cfg.NUM_GPUS)) 78 | shuffle = False 79 | drop_last = False 80 | 81 | # Construct the dataset 82 | dataset = build_dataset(dataset_name, cfg, split) 83 | 84 | if cfg.MULTIGRID.SHORT_CYCLE and split in ["train"] and not is_precise_bn: 85 | # Create a sampler for multi-process training 86 | sampler = utils.create_sampler(dataset, shuffle, cfg) 87 | batch_sampler = ShortCycleBatchSampler( 88 | sampler, batch_size=batch_size, drop_last=drop_last, cfg=cfg 89 | ) 90 | # Create a loader 91 | loader = torch.utils.data.DataLoader( 92 | dataset, 93 | batch_sampler=batch_sampler, 94 | num_workers=cfg.DATA_LOADER.NUM_WORKERS, 95 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY, 96 | worker_init_fn=utils.loader_worker_init_fn(dataset), 97 | ) 98 | else: 99 | # Create a sampler for multi-process training 100 | sampler = utils.create_sampler(dataset, shuffle, cfg) 101 | # Create a loader 102 | loader = torch.utils.data.DataLoader( 103 | dataset, 104 | batch_size=batch_size, 105 | shuffle=(False if sampler else shuffle), 106 | sampler=sampler, 107 | num_workers=cfg.DATA_LOADER.NUM_WORKERS, 108 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY, 109 | drop_last=drop_last, 110 | collate_fn=detection_collate if cfg.DETECTION.ENABLE else None, 111 | worker_init_fn=utils.loader_worker_init_fn(dataset), 112 | ) 113 | return loader 114 | 115 | 116 | def shuffle_dataset(loader, cur_epoch): 117 | """ " 118 | Shuffles the data. 119 | Args: 120 | loader (loader): data loader to perform shuffle. 121 | cur_epoch (int): number of the current epoch. 
122 | """ 123 | sampler = ( 124 | loader.batch_sampler.sampler 125 | if isinstance(loader.batch_sampler, ShortCycleBatchSampler) 126 | else loader.sampler 127 | ) 128 | assert isinstance( 129 | sampler, (RandomSampler, DistributedSampler) 130 | ), "Sampler type '{}' not supported".format(type(sampler)) 131 | # RandomSampler handles shuffling automatically 132 | if isinstance(sampler, DistributedSampler): 133 | # DistributedSampler shuffles data based on epoch 134 | sampler.set_epoch(cur_epoch) 135 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/datasets/multigrid_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Helper functions for multigrid training.""" 4 | 5 | import numpy as np 6 | from torch._six import int_classes as _int_classes 7 | from torch.utils.data.sampler import Sampler 8 | 9 | 10 | class ShortCycleBatchSampler(Sampler): 11 | """ 12 | Extend Sampler to support "short cycle" sampling. 13 | See paper "A Multigrid Method for Efficiently Training Video Models", 14 | Wu et al., 2019 (https://arxiv.org/abs/1912.00998) for details. 15 | """ 16 | 17 | def __init__(self, sampler, batch_size, drop_last, cfg): 18 | if not isinstance(sampler, Sampler): 19 | raise ValueError( 20 | "sampler should be an instance of " 21 | "torch.utils.data.Sampler, but got sampler={}".format(sampler) 22 | ) 23 | if ( 24 | not isinstance(batch_size, _int_classes) 25 | or isinstance(batch_size, bool) 26 | or batch_size <= 0 27 | ): 28 | raise ValueError( 29 | "batch_size should be a positive integer value, " 30 | "but got batch_size={}".format(batch_size) 31 | ) 32 | if not isinstance(drop_last, bool): 33 | raise ValueError( 34 | "drop_last should be a boolean value, but got " 35 | "drop_last={}".format(drop_last) 36 | ) 37 | self.sampler = sampler 38 | self.drop_last = drop_last 39 | 40 | bs_factor = [ 41 | int( 42 | round( 43 | ( 44 | float(cfg.DATA.TRAIN_CROP_SIZE) 45 | / (s * cfg.MULTIGRID.DEFAULT_S) 46 | ) 47 | ** 2 48 | ) 49 | ) 50 | for s in cfg.MULTIGRID.SHORT_CYCLE_FACTORS 51 | ] 52 | 53 | self.batch_sizes = [ 54 | batch_size * bs_factor[0], 55 | batch_size * bs_factor[1], 56 | batch_size, 57 | ] 58 | 59 | def __iter__(self): 60 | counter = 0 61 | batch_size = self.batch_sizes[0] 62 | batch = [] 63 | for idx in self.sampler: 64 | batch.append((idx, counter % 3)) 65 | if len(batch) == batch_size: 66 | yield batch 67 | counter += 1 68 | batch_size = self.batch_sizes[counter % 3] 69 | batch = [] 70 | if len(batch) > 0 and not self.drop_last: 71 | yield batch 72 | 73 | def __len__(self): 74 | avg_batch_size = sum(self.batch_sizes) / 3.0 75 | if self.drop_last: 76 | return int(np.floor(len(self.sampler) / avg_batch_size)) 77 | else: 78 | return int(np.ceil(len(self.sampler) / avg_batch_size)) 79 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/datasets/video_container.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | import av 4 | 5 | 6 | def get_video_container(path_to_vid, multi_thread_decode=False, backend="pyav"): 7 | """ 8 | Given the path to the video, return the pyav video container. 9 | Args: 10 | path_to_vid (str): path to the video. 11 | multi_thread_decode (bool): if True, perform multi-thread decoding. 
12 | backend (str): decoder backend, options include `pyav` and 13 | `torchvision`, default is `pyav`. 14 | Returns: 15 | container (container): video container. 16 | """ 17 | if backend == "torchvision": 18 | with open(path_to_vid, "rb") as fp: 19 | container = fp.read() 20 | return container 21 | elif backend == "pyav": 22 | #try: 23 | container = av.open(path_to_vid) 24 | if multi_thread_decode: 25 | # Enable multiple threads for decoding. 26 | container.streams.video[0].thread_type = "AUTO" 27 | #except: 28 | # container = None 29 | return container 30 | else: 31 | raise NotImplementedError("Unknown backend {}".format(backend)) 32 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | from .build import MODEL_REGISTRY, build_model # noqa 4 | from .custom_video_model_builder import * # noqa 5 | from .video_model_builder import ResNet, SlowFast # noqa 6 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/models/batchnorm_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """BatchNorm (BN) utility functions and custom batch-size BN implementations""" 4 | 5 | from functools import partial 6 | import torch 7 | import torch.distributed as dist 8 | import torch.nn as nn 9 | from torch.autograd.function import Function 10 | 11 | import timesformer.utils.distributed as du 12 | 13 | 14 | def get_norm(cfg): 15 | """ 16 | Args: 17 | cfg (CfgNode): model building configs, details are in the comments of 18 | the config file. 19 | Returns: 20 | nn.Module: the normalization layer. 21 | """ 22 | if cfg.BN.NORM_TYPE == "batchnorm": 23 | return nn.BatchNorm3d 24 | elif cfg.BN.NORM_TYPE == "sub_batchnorm": 25 | return partial(SubBatchNorm3d, num_splits=cfg.BN.NUM_SPLITS) 26 | elif cfg.BN.NORM_TYPE == "sync_batchnorm": 27 | return partial( 28 | NaiveSyncBatchNorm3d, num_sync_devices=cfg.BN.NUM_SYNC_DEVICES 29 | ) 30 | else: 31 | raise NotImplementedError( 32 | "Norm type {} is not supported".format(cfg.BN.NORM_TYPE) 33 | ) 34 | 35 | 36 | class SubBatchNorm3d(nn.Module): 37 | """ 38 | The standard BN layer computes stats across all examples in a GPU. In some 39 | cases it is desirable to compute stats across only a subset of examples 40 | (e.g., in multigrid training https://arxiv.org/abs/1912.00998). 41 | SubBatchNorm3d splits the batch dimension into N splits, and run BN on 42 | each of them separately (so that the stats are computed on each subset of 43 | examples (1/N of batch) independently. During evaluation, it aggregates 44 | the stats from all splits into one BN. 45 | """ 46 | 47 | def __init__(self, num_splits, **args): 48 | """ 49 | Args: 50 | num_splits (int): number of splits. 51 | args (list): other arguments. 52 | """ 53 | super(SubBatchNorm3d, self).__init__() 54 | self.num_splits = num_splits 55 | num_features = args["num_features"] 56 | # Keep only one set of weight and bias. 
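# Illustrative sketch of the split trick used by SubBatchNorm3d.forward, assuming
# num_splits = 2: the batch axis is folded into the channel axis, so split_bn computes
# independent statistics for each subset of examples.
#
#     import torch
#     x = torch.randn(8, 16, 4, 7, 7)                 # (N, C, T, H, W)
#     x_split = x.view(8 // 2, 16 * 2, 4, 7, 7)       # (N / splits, C * splits, T, H, W)
#     assert x_split.shape == (4, 32, 4, 7, 7)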
57 | if args.get("affine", True): 58 | self.affine = True 59 | args["affine"] = False 60 | self.weight = torch.nn.Parameter(torch.ones(num_features)) 61 | self.bias = torch.nn.Parameter(torch.zeros(num_features)) 62 | else: 63 | self.affine = False 64 | self.bn = nn.BatchNorm3d(**args) 65 | args["num_features"] = num_features * num_splits 66 | self.split_bn = nn.BatchNorm3d(**args) 67 | 68 | def _get_aggregated_mean_std(self, means, stds, n): 69 | """ 70 | Calculate the aggregated mean and stds. 71 | Args: 72 | means (tensor): mean values. 73 | stds (tensor): standard deviations. 74 | n (int): number of sets of means and stds. 75 | """ 76 | mean = means.view(n, -1).sum(0) / n 77 | std = ( 78 | stds.view(n, -1).sum(0) / n 79 | + ((means.view(n, -1) - mean) ** 2).view(n, -1).sum(0) / n 80 | ) 81 | return mean.detach(), std.detach() 82 | 83 | def aggregate_stats(self): 84 | """ 85 | Synchronize running_mean, and running_var. Call this before eval. 86 | """ 87 | if self.split_bn.track_running_stats: 88 | ( 89 | self.bn.running_mean.data, 90 | self.bn.running_var.data, 91 | ) = self._get_aggregated_mean_std( 92 | self.split_bn.running_mean, 93 | self.split_bn.running_var, 94 | self.num_splits, 95 | ) 96 | 97 | def forward(self, x): 98 | if self.training: 99 | n, c, t, h, w = x.shape 100 | x = x.view(n // self.num_splits, c * self.num_splits, t, h, w) 101 | x = self.split_bn(x) 102 | x = x.view(n, c, t, h, w) 103 | else: 104 | x = self.bn(x) 105 | if self.affine: 106 | x = x * self.weight.view((-1, 1, 1, 1)) 107 | x = x + self.bias.view((-1, 1, 1, 1)) 108 | return x 109 | 110 | 111 | class GroupGather(Function): 112 | """ 113 | GroupGather performs all gather on each of the local process/ GPU groups. 114 | """ 115 | 116 | @staticmethod 117 | def forward(ctx, input, num_sync_devices, num_groups): 118 | """ 119 | Perform forwarding, gathering the stats across different process/ GPU 120 | group. 121 | """ 122 | ctx.num_sync_devices = num_sync_devices 123 | ctx.num_groups = num_groups 124 | 125 | input_list = [ 126 | torch.zeros_like(input) for k in range(du.get_local_size()) 127 | ] 128 | dist.all_gather( 129 | input_list, input, async_op=False, group=du._LOCAL_PROCESS_GROUP 130 | ) 131 | 132 | inputs = torch.stack(input_list, dim=0) 133 | if num_groups > 1: 134 | rank = du.get_local_rank() 135 | group_idx = rank // num_sync_devices 136 | inputs = inputs[ 137 | group_idx 138 | * num_sync_devices : (group_idx + 1) 139 | * num_sync_devices 140 | ] 141 | inputs = torch.sum(inputs, dim=0) 142 | return inputs 143 | 144 | @staticmethod 145 | def backward(ctx, grad_output): 146 | """ 147 | Perform backwarding, gathering the gradients across different process/ GPU 148 | group. 149 | """ 150 | grad_output_list = [ 151 | torch.zeros_like(grad_output) for k in range(du.get_local_size()) 152 | ] 153 | dist.all_gather( 154 | grad_output_list, 155 | grad_output, 156 | async_op=False, 157 | group=du._LOCAL_PROCESS_GROUP, 158 | ) 159 | 160 | grads = torch.stack(grad_output_list, dim=0) 161 | if ctx.num_groups > 1: 162 | rank = du.get_local_rank() 163 | group_idx = rank // ctx.num_sync_devices 164 | grads = grads[ 165 | group_idx 166 | * ctx.num_sync_devices : (group_idx + 1) 167 | * ctx.num_sync_devices 168 | ] 169 | grads = torch.sum(grads, dim=0) 170 | return grads, None, None 171 | 172 | 173 | class NaiveSyncBatchNorm3d(nn.BatchNorm3d): 174 | def __init__(self, num_sync_devices, **args): 175 | """ 176 | Naive version of Synchronized 3D BatchNorm. 
177 | Args: 178 | num_sync_devices (int): number of device to sync. 179 | args (list): other arguments. 180 | """ 181 | self.num_sync_devices = num_sync_devices 182 | if self.num_sync_devices > 0: 183 | assert du.get_local_size() % self.num_sync_devices == 0, ( 184 | du.get_local_size(), 185 | self.num_sync_devices, 186 | ) 187 | self.num_groups = du.get_local_size() // self.num_sync_devices 188 | else: 189 | self.num_sync_devices = du.get_local_size() 190 | self.num_groups = 1 191 | super(NaiveSyncBatchNorm3d, self).__init__(**args) 192 | 193 | def forward(self, input): 194 | if du.get_local_size() == 1 or not self.training: 195 | return super().forward(input) 196 | 197 | assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs" 198 | C = input.shape[1] 199 | mean = torch.mean(input, dim=[0, 2, 3, 4]) 200 | meansqr = torch.mean(input * input, dim=[0, 2, 3, 4]) 201 | 202 | vec = torch.cat([mean, meansqr], dim=0) 203 | vec = GroupGather.apply(vec, self.num_sync_devices, self.num_groups) * ( 204 | 1.0 / self.num_sync_devices 205 | ) 206 | 207 | mean, meansqr = torch.split(vec, C) 208 | var = meansqr - mean * mean 209 | self.running_mean += self.momentum * (mean.detach() - self.running_mean) 210 | self.running_var += self.momentum * (var.detach() - self.running_var) 211 | 212 | invstd = torch.rsqrt(var + self.eps) 213 | scale = self.weight * invstd 214 | bias = self.bias - mean * scale 215 | scale = scale.reshape(1, -1, 1, 1, 1) 216 | bias = bias.reshape(1, -1, 1, 1, 1) 217 | return input * scale + bias 218 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/models/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Model construction functions.""" 4 | 5 | import torch 6 | from fvcore.common.registry import Registry 7 | 8 | MODEL_REGISTRY = Registry("MODEL") 9 | MODEL_REGISTRY.__doc__ = """ 10 | Registry for video model. 11 | 12 | The registered object will be called with `obj(cfg)`. 13 | The call should return a `torch.nn.Module` object. 14 | """ 15 | 16 | 17 | def build_model(cfg, gpu_id=None): 18 | """ 19 | Builds the video model. 20 | Args: 21 | cfg (configs): configs that contains the hyper-parameters to build the 22 | backbone. Details can be seen in slowfast/config/defaults.py. 23 | gpu_id (Optional[int]): specify the gpu index to build model. 24 | """ 25 | if torch.cuda.is_available(): 26 | assert ( 27 | cfg.NUM_GPUS <= torch.cuda.device_count() 28 | ), "Cannot use more GPU devices than available" 29 | else: 30 | assert ( 31 | cfg.NUM_GPUS == 0 32 | ), "Cuda is not available. Please set `NUM_GPUS: 0 for running on CPUs." 
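# Illustrative sketch of how MODEL_REGISTRY.get(name)(cfg) below resolves a model: model
# classes register themselves with the registry (MyToyModel is a hypothetical name used
# only for illustration; the repo's real models register the same way, e.g. in
# video_model_builder.py and vit.py).
#
#     import torch.nn as nn
#     from timesformer.models.build import MODEL_REGISTRY
#
#     @MODEL_REGISTRY.register()
#     class MyToyModel(nn.Module):
#         def __init__(self, cfg):
#             super().__init__()
#             self.head = nn.Linear(16, cfg.MODEL.NUM_CLASSES)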
33 | 34 | # Construct the model 35 | name = cfg.MODEL.MODEL_NAME 36 | model = MODEL_REGISTRY.get(name)(cfg) 37 | 38 | if cfg.NUM_GPUS: 39 | if gpu_id is None: 40 | # Determine the GPU used by the current process 41 | cur_device = torch.cuda.current_device() 42 | else: 43 | cur_device = gpu_id 44 | # Transfer the model to the current GPU device 45 | model = model.cuda(device=cur_device) 46 | 47 | 48 | # Use multi-process data parallel model in the multi-gpu setting 49 | if cfg.NUM_GPUS > 1: 50 | # Make model replica operate on the current device 51 | model = torch.nn.parallel.DistributedDataParallel( 52 | module=model, device_ids=[cur_device], output_device=cur_device 53 | ) 54 | return model 55 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/models/conv2d_same.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Ross Wightman 2 | # Conv2d w/ Same Padding 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from typing import Tuple, Optional 8 | 9 | import math 10 | from typing import List, Tuple 11 | #from .padding import pad_same, get_padding_value 12 | 13 | # Dynamically pad input x with 'SAME' padding for conv with specified args 14 | def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): 15 | ih, iw = x.size()[-2:] 16 | pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) 17 | if pad_h > 0 or pad_w > 0: 18 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) 19 | return x 20 | 21 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 22 | def get_same_padding(x: int, k: int, s: int, d: int): 23 | return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) 24 | 25 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: 26 | dynamic = False 27 | if isinstance(padding, str): 28 | # for any string padding, the padding will be calculated for you, one of three ways 29 | padding = padding.lower() 30 | if padding == 'same': 31 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 32 | if is_static_pad(kernel_size, **kwargs): 33 | # static case, no extra overhead 34 | padding = get_padding(kernel_size, **kwargs) 35 | else: 36 | # dynamic 'SAME' padding, has runtime/GPU memory overhead 37 | padding = 0 38 | dynamic = True 39 | elif padding == 'valid': 40 | # 'VALID' padding, same as padding=0 41 | padding = 0 42 | else: 43 | # Default to PyTorch style 'same'-ish symmetric padding 44 | padding = get_padding(kernel_size, **kwargs) 45 | return padding, dynamic 46 | 47 | def conv2d_same( 48 | x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), 49 | padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): 50 | x = pad_same(x, weight.shape[-2:], stride, dilation) 51 | return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) 52 | 53 | 54 | class Conv2dSame(nn.Conv2d): 55 | """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions 56 | """ 57 | 58 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 59 | padding=0, dilation=1, groups=1, bias=True): 60 | super(Conv2dSame, self).__init__( 61 | in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 62 | 63 | def forward(self, x): 64 | return conv2d_same(x, self.weight, self.bias, 
self.stride, self.padding, self.dilation, self.groups) 65 | 66 | 67 | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): 68 | padding = kwargs.pop('padding', '') 69 | kwargs.setdefault('bias', False) 70 | padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) 71 | if is_dynamic: 72 | return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) 73 | else: 74 | return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) 75 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/models/custom_video_model_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | 4 | """A More Flexible Video models.""" 5 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/models/head_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """ResNe(X)t Head helper.""" 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | class ResNetBasicHead(nn.Module): 9 | """ 10 | ResNe(X)t 3D head. 11 | This layer performs a fully-connected projection during training, when the 12 | input size is 1x1x1. It performs a convolutional projection during testing 13 | when the input size is larger than 1x1x1. If the inputs are from multiple 14 | different pathways, the inputs will be concatenated after pooling. 15 | """ 16 | 17 | def __init__( 18 | self, 19 | dim_in, 20 | num_classes, 21 | pool_size, 22 | dropout_rate=0.0, 23 | act_func="softmax", 24 | ): 25 | """ 26 | The `__init__` method of any subclass should also contain these 27 | arguments. 28 | ResNetBasicHead takes p pathways as input where p in [1, infty]. 29 | 30 | Args: 31 | dim_in (list): the list of channel dimensions of the p inputs to the 32 | ResNetHead. 33 | num_classes (int): the channel dimensions of the p outputs to the 34 | ResNetHead. 35 | pool_size (list): the list of kernel sizes of p spatial temporal 36 | poolings, temporal pool kernel size, spatial pool kernel size, 37 | spatial pool kernel size in order. 38 | dropout_rate (float): dropout rate. If equal to 0.0, perform no 39 | dropout. 40 | act_func (string): activation function to use. 'softmax': applies 41 | softmax on the output. 'sigmoid': applies sigmoid on the output. 42 | """ 43 | super(ResNetBasicHead, self).__init__() 44 | assert ( 45 | len({len(pool_size), len(dim_in)}) == 1 46 | ), "pathway dimensions are not consistent." 47 | self.num_pathways = len(pool_size) 48 | 49 | for pathway in range(self.num_pathways): 50 | if pool_size[pathway] is None: 51 | avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) 52 | else: 53 | avg_pool = nn.AvgPool3d(pool_size[pathway], stride=1) 54 | self.add_module("pathway{}_avgpool".format(pathway), avg_pool) 55 | 56 | if dropout_rate > 0.0: 57 | self.dropout = nn.Dropout(dropout_rate) 58 | # Perform FC in a fully convolutional manner. The FC layer will be 59 | # initialized with a different std comparing to convolutional layers. 60 | self.projection = nn.Linear(sum(dim_in), num_classes, bias=True) 61 | 62 | # Softmax for evaluation and testing. 
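# Note on dim=4: forward() permutes the features from (N, C, T, H, W) to (N, T, H, W, C)
# before the projection, so the softmax below acts on the class axis of each
# spatio-temporal location.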
63 | if act_func == "softmax": 64 | self.act = nn.Softmax(dim=4) 65 | elif act_func == "sigmoid": 66 | self.act = nn.Sigmoid() 67 | else: 68 | raise NotImplementedError( 69 | "{} is not supported as an activation" 70 | "function.".format(act_func) 71 | ) 72 | 73 | def forward(self, inputs): 74 | assert ( 75 | len(inputs) == self.num_pathways 76 | ), "Input tensor does not contain {} pathway".format(self.num_pathways) 77 | pool_out = [] 78 | for pathway in range(self.num_pathways): 79 | m = getattr(self, "pathway{}_avgpool".format(pathway)) 80 | pool_out.append(m(inputs[pathway])) 81 | x = torch.cat(pool_out, 1) 82 | # (N, C, T, H, W) -> (N, T, H, W, C). 83 | x = x.permute((0, 2, 3, 4, 1)) 84 | # Perform dropout. 85 | if hasattr(self, "dropout"): 86 | x = self.dropout(x) 87 | x = self.projection(x) 88 | 89 | # Performs fully convlutional inference. 90 | if not self.training: 91 | x = self.act(x) 92 | x = x.mean([1, 2, 3]) 93 | 94 | x = x.view(x.shape[0], -1) 95 | return x 96 | 97 | 98 | class X3DHead(nn.Module): 99 | """ 100 | X3D head. 101 | This layer performs a fully-connected projection during training, when the 102 | input size is 1x1x1. It performs a convolutional projection during testing 103 | when the input size is larger than 1x1x1. If the inputs are from multiple 104 | different pathways, the inputs will be concatenated after pooling. 105 | """ 106 | 107 | def __init__( 108 | self, 109 | dim_in, 110 | dim_inner, 111 | dim_out, 112 | num_classes, 113 | pool_size, 114 | dropout_rate=0.0, 115 | act_func="softmax", 116 | inplace_relu=True, 117 | eps=1e-5, 118 | bn_mmt=0.1, 119 | norm_module=nn.BatchNorm3d, 120 | bn_lin5_on=False, 121 | ): 122 | """ 123 | The `__init__` method of any subclass should also contain these 124 | arguments. 125 | X3DHead takes a 5-dim feature tensor (BxCxTxHxW) as input. 126 | 127 | Args: 128 | dim_in (float): the channel dimension C of the input. 129 | num_classes (int): the channel dimensions of the output. 130 | pool_size (float): a single entry list of kernel size for 131 | spatiotemporal pooling for the TxHxW dimensions. 132 | dropout_rate (float): dropout rate. If equal to 0.0, perform no 133 | dropout. 134 | act_func (string): activation function to use. 'softmax': applies 135 | softmax on the output. 'sigmoid': applies sigmoid on the output. 136 | inplace_relu (bool): if True, calculate the relu on the original 137 | input without allocating new memory. 138 | eps (float): epsilon for batch norm. 139 | bn_mmt (float): momentum for batch norm. Noted that BN momentum in 140 | PyTorch = 1 - BN momentum in Caffe2. 141 | norm_module (nn.Module): nn.Module for the normalization layer. The 142 | default is nn.BatchNorm3d. 143 | bn_lin5_on (bool): if True, perform normalization on the features 144 | before the classifier. 
145 | """ 146 | super(X3DHead, self).__init__() 147 | self.pool_size = pool_size 148 | self.dropout_rate = dropout_rate 149 | self.num_classes = num_classes 150 | self.act_func = act_func 151 | self.eps = eps 152 | self.bn_mmt = bn_mmt 153 | self.inplace_relu = inplace_relu 154 | self.bn_lin5_on = bn_lin5_on 155 | self._construct_head(dim_in, dim_inner, dim_out, norm_module) 156 | 157 | def _construct_head(self, dim_in, dim_inner, dim_out, norm_module): 158 | 159 | self.conv_5 = nn.Conv3d( 160 | dim_in, 161 | dim_inner, 162 | kernel_size=(1, 1, 1), 163 | stride=(1, 1, 1), 164 | padding=(0, 0, 0), 165 | bias=False, 166 | ) 167 | self.conv_5_bn = norm_module( 168 | num_features=dim_inner, eps=self.eps, momentum=self.bn_mmt 169 | ) 170 | self.conv_5_relu = nn.ReLU(self.inplace_relu) 171 | 172 | if self.pool_size is None: 173 | self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) 174 | else: 175 | self.avg_pool = nn.AvgPool3d(self.pool_size, stride=1) 176 | 177 | self.lin_5 = nn.Conv3d( 178 | dim_inner, 179 | dim_out, 180 | kernel_size=(1, 1, 1), 181 | stride=(1, 1, 1), 182 | padding=(0, 0, 0), 183 | bias=False, 184 | ) 185 | if self.bn_lin5_on: 186 | self.lin_5_bn = norm_module( 187 | num_features=dim_out, eps=self.eps, momentum=self.bn_mmt 188 | ) 189 | self.lin_5_relu = nn.ReLU(self.inplace_relu) 190 | 191 | if self.dropout_rate > 0.0: 192 | self.dropout = nn.Dropout(self.dropout_rate) 193 | # Perform FC in a fully convolutional manner. The FC layer will be 194 | # initialized with a different std comparing to convolutional layers. 195 | self.projection = nn.Linear(dim_out, self.num_classes, bias=True) 196 | 197 | # Softmax for evaluation and testing. 198 | if self.act_func == "softmax": 199 | self.act = nn.Softmax(dim=4) 200 | elif self.act_func == "sigmoid": 201 | self.act = nn.Sigmoid() 202 | else: 203 | raise NotImplementedError( 204 | "{} is not supported as an activation" 205 | "function.".format(self.act_func) 206 | ) 207 | 208 | def forward(self, inputs): 209 | # In its current design the X3D head is only useable for a single 210 | # pathway input. 211 | assert len(inputs) == 1, "Input tensor does not contain 1 pathway" 212 | x = self.conv_5(inputs[0]) 213 | x = self.conv_5_bn(x) 214 | x = self.conv_5_relu(x) 215 | x = self.avg_pool(x) 216 | 217 | x = self.lin_5(x) 218 | if self.bn_lin5_on: 219 | x = self.lin_5_bn(x) 220 | x = self.lin_5_relu(x) 221 | 222 | # (N, C, T, H, W) -> (N, T, H, W, C). 223 | x = x.permute((0, 2, 3, 4, 1)) 224 | # Perform dropout. 225 | if hasattr(self, "dropout"): 226 | x = self.dropout(x) 227 | x = self.projection(x) 228 | 229 | # Performs fully convlutional inference. 
230 | if not self.training: 231 | x = self.act(x) 232 | x = x.mean([1, 2, 3]) 233 | 234 | x = x.view(x.shape[0], -1) 235 | return x 236 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/models/linear.py: -------------------------------------------------------------------------------- 1 | """ Linear layer (alternate definition) 2 | """ 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn as nn 6 | 7 | class Linear(nn.Linear): 8 | def forward(self, input: torch.Tensor) -> torch.Tensor: 9 | if torch.jit.is_scripting(): 10 | bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None 11 | return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) 12 | else: 13 | return F.linear(input, self.weight, self.bias) 14 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/models/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Loss functions.""" 4 | 5 | import torch.nn as nn 6 | 7 | _LOSSES = { 8 | "cross_entropy": nn.CrossEntropyLoss, 9 | "bce": nn.BCELoss, 10 | "bce_logit": nn.BCEWithLogitsLoss, 11 | } 12 | 13 | 14 | def get_loss_func(loss_name): 15 | """ 16 | Retrieve the loss given the loss name. 17 | Args (int): 18 | loss_name: the name of the loss to use. 19 | """ 20 | if loss_name not in _LOSSES.keys(): 21 | raise NotImplementedError("Loss {} is not supported".format(loss_name)) 22 | return _LOSSES[loss_name] 23 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/models/nonlocal_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Non-local helper""" 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class Nonlocal(nn.Module): 10 | """ 11 | Builds Non-local Neural Networks as a generic family of building 12 | blocks for capturing long-range dependencies. Non-local Network 13 | computes the response at a position as a weighted sum of the 14 | features at all positions. This building block can be plugged into 15 | many computer vision architectures. 16 | More details in the paper: https://arxiv.org/pdf/1711.07971.pdf 17 | """ 18 | 19 | def __init__( 20 | self, 21 | dim, 22 | dim_inner, 23 | pool_size=None, 24 | instantiation="softmax", 25 | zero_init_final_conv=False, 26 | zero_init_final_norm=True, 27 | norm_eps=1e-5, 28 | norm_momentum=0.1, 29 | norm_module=nn.BatchNorm3d, 30 | ): 31 | """ 32 | Args: 33 | dim (int): number of dimension for the input. 34 | dim_inner (int): number of dimension inside of the Non-local block. 35 | pool_size (list): the kernel size of spatial temporal pooling, 36 | temporal pool kernel size, spatial pool kernel size, spatial 37 | pool kernel size in order. By default pool_size is None, 38 | then there would be no pooling used. 39 | instantiation (string): supports two different instantiation method: 40 | "dot_product": normalizing correlation matrix with L2. 41 | "softmax": normalizing correlation matrix with Softmax. 42 | zero_init_final_conv (bool): If true, zero initializing the final 43 | convolution of the Non-local block. 44 | zero_init_final_norm (bool): 45 | If true, zero initializing the final batch norm of the Non-local 46 | block. 
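`get_loss_func` in `losses.py` above returns the loss class, not an instance, so the caller instantiates it with whatever constructor arguments it needs. A small usage sketch:

```python
import torch

loss_cls = get_loss_func("bce_logit")      # -> nn.BCEWithLogitsLoss
loss_fn = loss_cls(reduction="mean")
logits = torch.randn(4, 10)
targets = torch.randint(0, 2, (4, 10)).float()
print(loss_fn(logits, targets).item())
```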
47 | norm_module (nn.Module): nn.Module for the normalization layer. The 48 | default is nn.BatchNorm3d. 49 | """ 50 | super(Nonlocal, self).__init__() 51 | self.dim = dim 52 | self.dim_inner = dim_inner 53 | self.pool_size = pool_size 54 | self.instantiation = instantiation 55 | self.use_pool = ( 56 | False 57 | if pool_size is None 58 | else any((size > 1 for size in pool_size)) 59 | ) 60 | self.norm_eps = norm_eps 61 | self.norm_momentum = norm_momentum 62 | self._construct_nonlocal( 63 | zero_init_final_conv, zero_init_final_norm, norm_module 64 | ) 65 | 66 | def _construct_nonlocal( 67 | self, zero_init_final_conv, zero_init_final_norm, norm_module 68 | ): 69 | # Three convolution heads: theta, phi, and g. 70 | self.conv_theta = nn.Conv3d( 71 | self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0 72 | ) 73 | self.conv_phi = nn.Conv3d( 74 | self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0 75 | ) 76 | self.conv_g = nn.Conv3d( 77 | self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0 78 | ) 79 | 80 | # Final convolution output. 81 | self.conv_out = nn.Conv3d( 82 | self.dim_inner, self.dim, kernel_size=1, stride=1, padding=0 83 | ) 84 | # Zero initializing the final convolution output. 85 | self.conv_out.zero_init = zero_init_final_conv 86 | 87 | # TODO: change the name to `norm` 88 | self.bn = norm_module( 89 | num_features=self.dim, 90 | eps=self.norm_eps, 91 | momentum=self.norm_momentum, 92 | ) 93 | # Zero initializing the final bn. 94 | self.bn.transform_final_bn = zero_init_final_norm 95 | 96 | # Optional to add the spatial-temporal pooling. 97 | if self.use_pool: 98 | self.pool = nn.MaxPool3d( 99 | kernel_size=self.pool_size, 100 | stride=self.pool_size, 101 | padding=[0, 0, 0], 102 | ) 103 | 104 | def forward(self, x): 105 | x_identity = x 106 | N, C, T, H, W = x.size() 107 | 108 | theta = self.conv_theta(x) 109 | 110 | # Perform temporal-spatial pooling to reduce the computation. 111 | if self.use_pool: 112 | x = self.pool(x) 113 | 114 | phi = self.conv_phi(x) 115 | g = self.conv_g(x) 116 | 117 | theta = theta.view(N, self.dim_inner, -1) 118 | phi = phi.view(N, self.dim_inner, -1) 119 | g = g.view(N, self.dim_inner, -1) 120 | 121 | # (N, C, TxHxW) * (N, C, TxHxW) => (N, TxHxW, TxHxW). 122 | theta_phi = torch.einsum("nct,ncp->ntp", (theta, phi)) 123 | # For original Non-local paper, there are two main ways to normalize 124 | # the affinity tensor: 125 | # 1) Softmax normalization (norm on exp). 126 | # 2) dot_product normalization. 127 | if self.instantiation == "softmax": 128 | # Normalizing the affinity tensor theta_phi before softmax. 129 | theta_phi = theta_phi * (self.dim_inner ** -0.5) 130 | theta_phi = nn.functional.softmax(theta_phi, dim=2) 131 | elif self.instantiation == "dot_product": 132 | spatial_temporal_dim = theta_phi.shape[2] 133 | theta_phi = theta_phi / spatial_temporal_dim 134 | else: 135 | raise NotImplementedError( 136 | "Unknown norm type {}".format(self.instantiation) 137 | ) 138 | 139 | # (N, TxHxW, TxHxW) * (N, C, TxHxW) => (N, C, TxHxW). 140 | theta_phi_g = torch.einsum("ntg,ncg->nct", (theta_phi, g)) 141 | 142 | # (N, C, TxHxW) => (N, C, T, H, W). 
143 | theta_phi_g = theta_phi_g.view(N, self.dim_inner, T, H, W) 144 | 145 | p = self.conv_out(theta_phi_g) 146 | p = self.bn(p) 147 | return x_identity + p 148 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/models/operators.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Custom operators.""" 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class Swish(nn.Module): 10 | """Swish activation function: x * sigmoid(x).""" 11 | 12 | def __init__(self): 13 | super(Swish, self).__init__() 14 | 15 | def forward(self, x): 16 | return SwishEfficient.apply(x) 17 | 18 | 19 | class SwishEfficient(torch.autograd.Function): 20 | """Swish activation function: x * sigmoid(x).""" 21 | 22 | @staticmethod 23 | def forward(ctx, x): 24 | result = x * torch.sigmoid(x) 25 | ctx.save_for_backward(x) 26 | return result 27 | 28 | @staticmethod 29 | def backward(ctx, grad_output): 30 | x = ctx.saved_variables[0] 31 | sigmoid_x = torch.sigmoid(x) 32 | return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x))) 33 | 34 | 35 | class SE(nn.Module): 36 | """Squeeze-and-Excitation (SE) block w/ Swish: AvgPool, FC, Swish, FC, Sigmoid.""" 37 | 38 | def _round_width(self, width, multiplier, min_width=8, divisor=8): 39 | """ 40 | Round width of filters based on width multiplier 41 | Args: 42 | width (int): the channel dimensions of the input. 43 | multiplier (float): the multiplication factor. 44 | min_width (int): the minimum width after multiplication. 45 | divisor (int): the new width should be dividable by divisor. 46 | """ 47 | if not multiplier: 48 | return width 49 | 50 | width *= multiplier 51 | min_width = min_width or divisor 52 | width_out = max( 53 | min_width, int(width + divisor / 2) // divisor * divisor 54 | ) 55 | if width_out < 0.9 * width: 56 | width_out += divisor 57 | return int(width_out) 58 | 59 | def __init__(self, dim_in, ratio, relu_act=True): 60 | """ 61 | Args: 62 | dim_in (int): the channel dimensions of the input. 63 | ratio (float): the channel reduction ratio for squeeze. 64 | relu_act (bool): whether to use ReLU activation instead 65 | of Swish (default). 66 | divisor (int): the new width should be dividable by divisor. 67 | """ 68 | super(SE, self).__init__() 69 | self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) 70 | dim_fc = self._round_width(dim_in, ratio) 71 | self.fc1 = nn.Conv3d(dim_in, dim_fc, 1, bias=True) 72 | self.fc1_act = nn.ReLU() if relu_act else Swish() 73 | self.fc2 = nn.Conv3d(dim_fc, dim_in, 1, bias=True) 74 | 75 | self.fc2_sig = nn.Sigmoid() 76 | 77 | def forward(self, x): 78 | x_in = x 79 | for module in self.children(): 80 | x = module(x) 81 | return x_in * x 82 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/models/optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Optimizer.""" 4 | 5 | import torch 6 | 7 | import timesformer.utils.lr_policy as lr_policy 8 | 9 | 10 | def construct_optimizer(model, cfg): 11 | """ 12 | Construct a stochastic gradient descent or ADAM optimizer with momentum. 13 | Details can be found in: 14 | Herbert Robbins, and Sutton Monro. "A stochastic approximation method." 15 | and 16 | Diederik P.Kingma, and Jimmy Ba. 
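The `Nonlocal` block above is shape-preserving: the attention-weighted response is projected back to `dim` channels and added to the identity input, while the optional pooling only reduces the key/value resolution. A hedged sketch with illustrative sizes:

```python
import torch

# Sketch only: Nonlocal as defined in nonlocal_helper.py above.
nl = Nonlocal(dim=256, dim_inner=128, pool_size=[1, 2, 2], instantiation="softmax")
x = torch.randn(2, 256, 4, 14, 14)   # (N, C, T, H, W)
y = nl(x)
print(y.shape)  # torch.Size([2, 256, 4, 14, 14]) -- same shape as the input
```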
17 | "Adam: A Method for Stochastic Optimization." 18 | 19 | Args: 20 | model (model): model to perform stochastic gradient descent 21 | optimization or ADAM optimization. 22 | cfg (config): configs of hyper-parameters of SGD or ADAM, includes base 23 | learning rate, momentum, weight_decay, dampening, and etc. 24 | """ 25 | # Batchnorm parameters. 26 | bn_params = [] 27 | # Non-batchnorm parameters. 28 | non_bn_parameters = [] 29 | for name, p in model.named_parameters(): 30 | if "bn" in name: 31 | bn_params.append(p) 32 | else: 33 | non_bn_parameters.append(p) 34 | # Apply different weight decay to Batchnorm and non-batchnorm parameters. 35 | # In Caffe2 classification codebase the weight decay for batchnorm is 0.0. 36 | # Having a different weight decay on batchnorm might cause a performance 37 | # drop. 38 | optim_params = [ 39 | {"params": bn_params, "weight_decay": cfg.BN.WEIGHT_DECAY}, 40 | {"params": non_bn_parameters, "weight_decay": cfg.SOLVER.WEIGHT_DECAY}, 41 | ] 42 | # Check all parameters will be passed into optimizer. 43 | assert len(list(model.parameters())) == len(non_bn_parameters) + len( 44 | bn_params 45 | ), "parameter size does not match: {} + {} != {}".format( 46 | len(non_bn_parameters), len(bn_params), len(list(model.parameters())) 47 | ) 48 | 49 | if cfg.SOLVER.OPTIMIZING_METHOD == "sgd": 50 | return torch.optim.SGD( 51 | optim_params, 52 | lr=cfg.SOLVER.BASE_LR, 53 | momentum=cfg.SOLVER.MOMENTUM, 54 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 55 | dampening=cfg.SOLVER.DAMPENING, 56 | nesterov=cfg.SOLVER.NESTEROV, 57 | ) 58 | elif cfg.SOLVER.OPTIMIZING_METHOD == "adam": 59 | return torch.optim.Adam( 60 | optim_params, 61 | lr=cfg.SOLVER.BASE_LR, 62 | betas=(0.9, 0.999), 63 | eps=1e-08, 64 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 65 | ) 66 | elif cfg.SOLVER.OPTIMIZING_METHOD == "adamw": 67 | return torch.optim.AdamW( 68 | optim_params, 69 | lr=cfg.SOLVER.BASE_LR, 70 | betas=(0.9, 0.999), 71 | eps=1e-08, 72 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 73 | ) 74 | else: 75 | raise NotImplementedError( 76 | "Does not support {} optimizer".format(cfg.SOLVER.OPTIMIZING_METHOD) 77 | ) 78 | 79 | 80 | def get_epoch_lr(cur_epoch, cfg): 81 | """ 82 | Retrieves the lr for the given epoch (as specified by the lr policy). 83 | Args: 84 | cfg (config): configs of hyper-parameters of ADAM, includes base 85 | learning rate, betas, and weight decays. 86 | cur_epoch (float): the number of epoch of the current training stage. 87 | """ 88 | return lr_policy.get_lr_at_epoch(cfg, cur_epoch) 89 | 90 | 91 | def set_lr(optimizer, new_lr): 92 | """ 93 | Sets the optimizer lr to the specified value. 94 | Args: 95 | optimizer (optim): the optimizer using to optimize the current network. 96 | new_lr (float): the new learning rate to set. 
97 | """ 98 | for param_group in optimizer.param_groups: 99 | param_group["lr"] = new_lr 100 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/models/vit_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Ross Wightman 2 | # Various utility functions 3 | 4 | import torch 5 | import torch.nn as nn 6 | from functools import partial 7 | import math 8 | import warnings 9 | import torch.nn.functional as F 10 | 11 | from timesformer.models.helpers import load_pretrained 12 | from .build import MODEL_REGISTRY 13 | from itertools import repeat 14 | # from torch._six import container_abcs 15 | import collections.abc as container_abcs 16 | 17 | DEFAULT_CROP_PCT = 0.875 18 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 19 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 20 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) 21 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) 22 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) 23 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) 24 | 25 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 26 | def norm_cdf(x): 27 | # Computes standard normal cumulative distribution function 28 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 29 | 30 | if (mean < a - 2 * std) or (mean > b + 2 * std): 31 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 32 | "The distribution of values may be incorrect.", 33 | stacklevel=2) 34 | 35 | with torch.no_grad(): 36 | # Values are generated by using a truncated uniform distribution and 37 | # then using the inverse CDF for the normal distribution. 38 | # Get upper and lower cdf values 39 | l = norm_cdf((a - mean) / std) 40 | u = norm_cdf((b - mean) / std) 41 | 42 | # Uniformly fill tensor with values from [l, u], then translate to 43 | # [2l-1, 2u-1]. 44 | tensor.uniform_(2 * l - 1, 2 * u - 1) 45 | 46 | # Use inverse cdf transform for normal distribution to get truncated 47 | # standard normal 48 | tensor.erfinv_() 49 | 50 | # Transform to proper mean, std 51 | tensor.mul_(std * math.sqrt(2.)) 52 | tensor.add_(mean) 53 | 54 | # Clamp to ensure it's in the proper range 55 | tensor.clamp_(min=a, max=b) 56 | return tensor 57 | 58 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 59 | # type: (Tensor, float, float, float, float) -> Tensor 60 | r"""Fills the input Tensor with values drawn from a truncated 61 | normal distribution. The values are effectively drawn from the 62 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 63 | with values outside :math:`[a, b]` redrawn until they are within 64 | the bounds. The method used for generating the random values works 65 | best when :math:`a \leq \text{mean} \leq b`. 
66 | Args: 67 | tensor: an n-dimensional `torch.Tensor` 68 | mean: the mean of the normal distribution 69 | std: the standard deviation of the normal distribution 70 | a: the minimum cutoff value 71 | b: the maximum cutoff value 72 | Examples: 73 | >>> w = torch.empty(3, 5) 74 | >>> nn.init.trunc_normal_(w) 75 | """ 76 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 77 | 78 | # From PyTorch internals 79 | def _ntuple(n): 80 | def parse(x): 81 | if isinstance(x, container_abcs.Iterable): 82 | return x 83 | return tuple(repeat(x, n)) 84 | return parse 85 | to_2tuple = _ntuple(2) 86 | 87 | # Calculate symmetric padding for a convolution 88 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: 89 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 90 | return padding 91 | 92 | def get_padding_value(padding, kernel_size, **kwargs): 93 | dynamic = False 94 | if isinstance(padding, str): 95 | # for any string padding, the padding will be calculated for you, one of three ways 96 | padding = padding.lower() 97 | if padding == 'same': 98 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 99 | if is_static_pad(kernel_size, **kwargs): 100 | # static case, no extra overhead 101 | padding = get_padding(kernel_size, **kwargs) 102 | else: 103 | # dynamic 'SAME' padding, has runtime/GPU memory overhead 104 | padding = 0 105 | dynamic = True 106 | elif padding == 'valid': 107 | # 'VALID' padding, same as padding=0 108 | padding = 0 109 | else: 110 | # Default to PyTorch style 'same'-ish symmetric padding 111 | padding = get_padding(kernel_size, **kwargs) 112 | return padding, dynamic 113 | 114 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 115 | def get_same_padding(x: int, k: int, s: int, d: int): 116 | return max((int(math.ceil(x // s)) - 1) * s + (k - 1) * d + 1 - x, 0) 117 | 118 | 119 | # Can SAME padding for given args be done statically? 120 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): 121 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 122 | 123 | 124 | # Dynamically pad input x with 'SAME' padding for conv with specified args 125 | #def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): 126 | def pad_same(x, k, s, d=(1, 1), value= 0): 127 | ih, iw = x.size()[-2:] 128 | pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) 129 | if pad_h > 0 or pad_w > 0: 130 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) 131 | return x 132 | 133 | def adaptive_pool_feat_mult(pool_type='avg'): 134 | if pool_type == 'catavgmax': 135 | return 2 136 | else: 137 | return 1 138 | 139 | def drop_path(x, drop_prob: float = 0., training: bool = False): 140 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 141 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 142 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 143 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 144 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 145 | 'survival rate' as the argument. 146 | """ 147 | if drop_prob == 0. 
or not training: 148 | return x 149 | keep_prob = 1 - drop_prob 150 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 151 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 152 | random_tensor.floor_() # binarize 153 | output = x.div(keep_prob) * random_tensor 154 | return output 155 | 156 | class DropPath(nn.Module): 157 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 158 | """ 159 | def __init__(self, drop_prob=None): 160 | super(DropPath, self).__init__() 161 | self.drop_prob = drop_prob 162 | 163 | def forward(self, x): 164 | return drop_path(x, self.drop_prob, self.training) 165 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/utils/ava_evaluation/README.md: -------------------------------------------------------------------------------- 1 | The code under this folder is from the official [ActivityNet repo](https://github.com/activitynet/ActivityNet). 2 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/utils/ava_evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basilevh/tcow/a72e3e13a45e4156137328e5290f9e848d360367/third_party/TimeSformer/timesformer/utils/ava_evaluation/__init__.py -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/utils/ava_evaluation/ava_action_list_v2.1_for_activitynet_2018.pbtxt.txt: -------------------------------------------------------------------------------- 1 | item { 2 | name: "bend/bow (at the waist)" 3 | id: 1 4 | } 5 | item { 6 | name: "crouch/kneel" 7 | id: 3 8 | } 9 | item { 10 | name: "dance" 11 | id: 4 12 | } 13 | item { 14 | name: "fall down" 15 | id: 5 16 | } 17 | item { 18 | name: "get up" 19 | id: 6 20 | } 21 | item { 22 | name: "jump/leap" 23 | id: 7 24 | } 25 | item { 26 | name: "lie/sleep" 27 | id: 8 28 | } 29 | item { 30 | name: "martial art" 31 | id: 9 32 | } 33 | item { 34 | name: "run/jog" 35 | id: 10 36 | } 37 | item { 38 | name: "sit" 39 | id: 11 40 | } 41 | item { 42 | name: "stand" 43 | id: 12 44 | } 45 | item { 46 | name: "swim" 47 | id: 13 48 | } 49 | item { 50 | name: "walk" 51 | id: 14 52 | } 53 | item { 54 | name: "answer phone" 55 | id: 15 56 | } 57 | item { 58 | name: "carry/hold (an object)" 59 | id: 17 60 | } 61 | item { 62 | name: "climb (e.g., a mountain)" 63 | id: 20 64 | } 65 | item { 66 | name: "close (e.g., a door, a box)" 67 | id: 22 68 | } 69 | item { 70 | name: "cut" 71 | id: 24 72 | } 73 | item { 74 | name: "dress/put on clothing" 75 | id: 26 76 | } 77 | item { 78 | name: "drink" 79 | id: 27 80 | } 81 | item { 82 | name: "drive (e.g., a car, a truck)" 83 | id: 28 84 | } 85 | item { 86 | name: "eat" 87 | id: 29 88 | } 89 | item { 90 | name: "enter" 91 | id: 30 92 | } 93 | item { 94 | name: "hit (an object)" 95 | id: 34 96 | } 97 | item { 98 | name: "lift/pick up" 99 | id: 36 100 | } 101 | item { 102 | name: "listen (e.g., to music)" 103 | id: 37 104 | } 105 | item { 106 | name: "open (e.g., a window, a car door)" 
107 | id: 38 108 | } 109 | item { 110 | name: "play musical instrument" 111 | id: 41 112 | } 113 | item { 114 | name: "point to (an object)" 115 | id: 43 116 | } 117 | item { 118 | name: "pull (an object)" 119 | id: 45 120 | } 121 | item { 122 | name: "push (an object)" 123 | id: 46 124 | } 125 | item { 126 | name: "put down" 127 | id: 47 128 | } 129 | item { 130 | name: "read" 131 | id: 48 132 | } 133 | item { 134 | name: "ride (e.g., a bike, a car, a horse)" 135 | id: 49 136 | } 137 | item { 138 | name: "sail boat" 139 | id: 51 140 | } 141 | item { 142 | name: "shoot" 143 | id: 52 144 | } 145 | item { 146 | name: "smoke" 147 | id: 54 148 | } 149 | item { 150 | name: "take a photo" 151 | id: 56 152 | } 153 | item { 154 | name: "text on/look at a cellphone" 155 | id: 57 156 | } 157 | item { 158 | name: "throw" 159 | id: 58 160 | } 161 | item { 162 | name: "touch (an object)" 163 | id: 59 164 | } 165 | item { 166 | name: "turn (e.g., a screwdriver)" 167 | id: 60 168 | } 169 | item { 170 | name: "watch (e.g., TV)" 171 | id: 61 172 | } 173 | item { 174 | name: "work on a computer" 175 | id: 62 176 | } 177 | item { 178 | name: "write" 179 | id: 63 180 | } 181 | item { 182 | name: "fight/hit (a person)" 183 | id: 64 184 | } 185 | item { 186 | name: "give/serve (an object) to (a person)" 187 | id: 65 188 | } 189 | item { 190 | name: "grab (a person)" 191 | id: 66 192 | } 193 | item { 194 | name: "hand clap" 195 | id: 67 196 | } 197 | item { 198 | name: "hand shake" 199 | id: 68 200 | } 201 | item { 202 | name: "hand wave" 203 | id: 69 204 | } 205 | item { 206 | name: "hug (a person)" 207 | id: 70 208 | } 209 | item { 210 | name: "kiss (a person)" 211 | id: 72 212 | } 213 | item { 214 | name: "lift (a person)" 215 | id: 73 216 | } 217 | item { 218 | name: "listen to (a person)" 219 | id: 74 220 | } 221 | item { 222 | name: "push (another person)" 223 | id: 76 224 | } 225 | item { 226 | name: "sing to (e.g., self, a person, a group)" 227 | id: 77 228 | } 229 | item { 230 | name: "take (an object) from (a person)" 231 | id: 78 232 | } 233 | item { 234 | name: "talk to (e.g., self, a person, a group)" 235 | id: 79 236 | } 237 | item { 238 | name: "watch (a person)" 239 | id: 80 240 | } 241 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/utils/ava_evaluation/label_map_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
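Stepping back to `vit_utils.py` above: `drop_path` implements stochastic depth by zeroing whole samples of a residual branch and rescaling the survivors by `1 / keep_prob`, and `DropPath` wraps it as a module so it is only active in training mode. A hedged sketch:

```python
import torch

# Sketch only: DropPath as defined in vit_utils.py above.
dp = DropPath(drop_prob=0.2)
x = torch.ones(8, 16)          # 8 samples flowing through a residual branch
dp.train()
y = dp(x)                      # ~20% of rows zeroed, the rest scaled by 1/0.8
dp.eval()
print(torch.equal(dp(x), x))   # True -- identity at inference time
```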
14 | # ============================================================================== 15 | """Label map utility functions.""" 16 | 17 | from __future__ import ( 18 | absolute_import, 19 | division, 20 | print_function, 21 | unicode_literals, 22 | ) 23 | import logging 24 | 25 | # from google.protobuf import text_format 26 | # from google3.third_party.tensorflow_models.object_detection.protos import string_int_label_map_pb2 27 | 28 | 29 | def _validate_label_map(label_map): 30 | """Checks if a label map is valid. 31 | 32 | Args: 33 | label_map: StringIntLabelMap to validate. 34 | 35 | Raises: 36 | ValueError: if label map is invalid. 37 | """ 38 | for item in label_map.item: 39 | if item.id < 1: 40 | raise ValueError("Label map ids should be >= 1.") 41 | 42 | 43 | def create_category_index(categories): 44 | """Creates dictionary of COCO compatible categories keyed by category id. 45 | 46 | Args: 47 | categories: a list of dicts, each of which has the following keys: 48 | 'id': (required) an integer id uniquely identifying this category. 49 | 'name': (required) string representing category name 50 | e.g., 'cat', 'dog', 'pizza'. 51 | 52 | Returns: 53 | category_index: a dict containing the same entries as categories, but keyed 54 | by the 'id' field of each category. 55 | """ 56 | category_index = {} 57 | for cat in categories: 58 | category_index[cat["id"]] = cat 59 | return category_index 60 | 61 | 62 | def get_max_label_map_index(label_map): 63 | """Get maximum index in label map. 64 | 65 | Args: 66 | label_map: a StringIntLabelMapProto 67 | 68 | Returns: 69 | an integer 70 | """ 71 | return max([item.id for item in label_map.item]) 72 | 73 | 74 | def convert_label_map_to_categories( 75 | label_map, max_num_classes, use_display_name=True 76 | ): 77 | """Loads label map proto and returns categories list compatible with eval. 78 | 79 | This function loads a label map and returns a list of dicts, each of which 80 | has the following keys: 81 | 'id': (required) an integer id uniquely identifying this category. 82 | 'name': (required) string representing category name 83 | e.g., 'cat', 'dog', 'pizza'. 84 | We only allow class into the list if its id-label_id_offset is 85 | between 0 (inclusive) and max_num_classes (exclusive). 86 | If there are several items mapping to the same id in the label map, 87 | we will only keep the first one in the categories list. 88 | 89 | Args: 90 | label_map: a StringIntLabelMapProto or None. If None, a default categories 91 | list is created with max_num_classes categories. 92 | max_num_classes: maximum number of (consecutive) label indices to include. 93 | use_display_name: (boolean) choose whether to load 'display_name' field 94 | as category name. If False or if the display_name field does not exist, 95 | uses 'name' field as category names instead. 96 | Returns: 97 | categories: a list of dictionaries representing all possible categories. 
98 | """ 99 | categories = [] 100 | list_of_ids_already_added = [] 101 | if not label_map: 102 | label_id_offset = 1 103 | for class_id in range(max_num_classes): 104 | categories.append( 105 | { 106 | "id": class_id + label_id_offset, 107 | "name": "category_{}".format(class_id + label_id_offset), 108 | } 109 | ) 110 | return categories 111 | for item in label_map.item: 112 | if not 0 < item.id <= max_num_classes: 113 | logging.info( 114 | "Ignore item %d since it falls outside of requested " 115 | "label range.", 116 | item.id, 117 | ) 118 | continue 119 | if use_display_name and item.HasField("display_name"): 120 | name = item.display_name 121 | else: 122 | name = item.name 123 | if item.id not in list_of_ids_already_added: 124 | list_of_ids_already_added.append(item.id) 125 | categories.append({"id": item.id, "name": name}) 126 | return categories 127 | 128 | 129 | def load_labelmap(path): 130 | """Loads label map proto. 131 | 132 | Args: 133 | path: path to StringIntLabelMap proto text file. 134 | Returns: 135 | a StringIntLabelMapProto 136 | """ 137 | with open(path, "r") as fid: 138 | label_map_string = fid.read() 139 | label_map = string_int_label_map_pb2.StringIntLabelMap() 140 | try: 141 | text_format.Merge(label_map_string, label_map) 142 | except text_format.ParseError: 143 | label_map.ParseFromString(label_map_string) 144 | _validate_label_map(label_map) 145 | return label_map 146 | 147 | 148 | def get_label_map_dict(label_map_path, use_display_name=False): 149 | """Reads a label map and returns a dictionary of label names to id. 150 | 151 | Args: 152 | label_map_path: path to label_map. 153 | use_display_name: whether to use the label map items' display names as keys. 154 | 155 | Returns: 156 | A dictionary mapping label names to id. 157 | """ 158 | label_map = load_labelmap(label_map_path) 159 | label_map_dict = {} 160 | for item in label_map.item: 161 | if use_display_name: 162 | label_map_dict[item.display_name] = item.id 163 | else: 164 | label_map_dict[item.name] = item.id 165 | return label_map_dict 166 | 167 | 168 | def create_category_index_from_labelmap(label_map_path): 169 | """Reads a label map and returns a category index. 170 | 171 | Args: 172 | label_map_path: Path to `StringIntLabelMap` proto text file. 173 | 174 | Returns: 175 | A category index, which is a dictionary that maps integer ids to dicts 176 | containing categories, e.g. 177 | {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...} 178 | """ 179 | label_map = load_labelmap(label_map_path) 180 | max_num_classes = max(item.id for item in label_map.item) 181 | categories = convert_label_map_to_categories(label_map, max_num_classes) 182 | return create_category_index(categories) 183 | 184 | 185 | def create_class_agnostic_category_index(): 186 | """Creates a category index with a single `object` class.""" 187 | return {1: {"id": 1, "name": "object"}} 188 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/utils/ava_evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Functions for computing metrics like precision, recall, CorLoc and etc.""" 17 | from __future__ import division 18 | import numpy as np 19 | 20 | 21 | def compute_precision_recall(scores, labels, num_gt): 22 | """Compute precision and recall. 23 | 24 | Args: 25 | scores: A float numpy array representing detection score 26 | labels: A boolean numpy array representing true/false positive labels 27 | num_gt: Number of ground truth instances 28 | 29 | Raises: 30 | ValueError: if the input is not of the correct format 31 | 32 | Returns: 33 | precision: Fraction of positive instances over detected ones. This value is 34 | None if no ground truth labels are present. 35 | recall: Fraction of detected positive instance over all positive instances. 36 | This value is None if no ground truth labels are present. 37 | 38 | """ 39 | if ( 40 | not isinstance(labels, np.ndarray) 41 | or labels.dtype != np.bool 42 | or len(labels.shape) != 1 43 | ): 44 | raise ValueError("labels must be single dimension bool numpy array") 45 | 46 | if not isinstance(scores, np.ndarray) or len(scores.shape) != 1: 47 | raise ValueError("scores must be single dimension numpy array") 48 | 49 | if num_gt < np.sum(labels): 50 | raise ValueError( 51 | "Number of true positives must be smaller than num_gt." 52 | ) 53 | 54 | if len(scores) != len(labels): 55 | raise ValueError("scores and labels must be of the same size.") 56 | 57 | if num_gt == 0: 58 | return None, None 59 | 60 | sorted_indices = np.argsort(scores) 61 | sorted_indices = sorted_indices[::-1] 62 | labels = labels.astype(int) 63 | true_positive_labels = labels[sorted_indices] 64 | false_positive_labels = 1 - true_positive_labels 65 | cum_true_positives = np.cumsum(true_positive_labels) 66 | cum_false_positives = np.cumsum(false_positive_labels) 67 | precision = cum_true_positives.astype(float) / ( 68 | cum_true_positives + cum_false_positives 69 | ) 70 | recall = cum_true_positives.astype(float) / num_gt 71 | return precision, recall 72 | 73 | 74 | def compute_average_precision(precision, recall): 75 | """Compute Average Precision according to the definition in VOCdevkit. 76 | 77 | Precision is modified to ensure that it does not decrease as recall 78 | decrease. 79 | 80 | Args: 81 | precision: A float [N, 1] numpy array of precisions 82 | recall: A float [N, 1] numpy array of recalls 83 | 84 | Raises: 85 | ValueError: if the input is not of the correct format 86 | 87 | Returns: 88 | average_precison: The area under the precision recall curve. NaN if 89 | precision and recall are None. 
90 | 91 | """ 92 | if precision is None: 93 | if recall is not None: 94 | raise ValueError("If precision is None, recall must also be None") 95 | return np.NAN 96 | 97 | if not isinstance(precision, np.ndarray) or not isinstance( 98 | recall, np.ndarray 99 | ): 100 | raise ValueError("precision and recall must be numpy array") 101 | if precision.dtype != np.float or recall.dtype != np.float: 102 | raise ValueError("input must be float numpy array.") 103 | if len(precision) != len(recall): 104 | raise ValueError("precision and recall must be of the same size.") 105 | if not precision.size: 106 | return 0.0 107 | if np.amin(precision) < 0 or np.amax(precision) > 1: 108 | raise ValueError("Precision must be in the range of [0, 1].") 109 | if np.amin(recall) < 0 or np.amax(recall) > 1: 110 | raise ValueError("recall must be in the range of [0, 1].") 111 | if not all(recall[i] <= recall[i + 1] for i in range(len(recall) - 1)): 112 | raise ValueError("recall must be a non-decreasing array") 113 | 114 | recall = np.concatenate([[0], recall, [1]]) 115 | precision = np.concatenate([[0], precision, [0]]) 116 | 117 | # Preprocess precision to be a non-decreasing array 118 | for i in range(len(precision) - 2, -1, -1): 119 | precision[i] = np.maximum(precision[i], precision[i + 1]) 120 | 121 | indices = np.where(recall[1:] != recall[:-1])[0] + 1 122 | average_precision = np.sum( 123 | (recall[indices] - recall[indices - 1]) * precision[indices] 124 | ) 125 | return average_precision 126 | 127 | 128 | def compute_cor_loc( 129 | num_gt_imgs_per_class, num_images_correctly_detected_per_class 130 | ): 131 | """Compute CorLoc according to the definition in the following paper. 132 | 133 | https://www.robots.ox.ac.uk/~vgg/rg/papers/deselaers-eccv10.pdf 134 | 135 | Returns nans if there are no ground truth images for a class. 136 | 137 | Args: 138 | num_gt_imgs_per_class: 1D array, representing number of images containing 139 | at least one object instance of a particular class 140 | num_images_correctly_detected_per_class: 1D array, representing number of 141 | images that are correctly detected at least one object instance of a 142 | particular class 143 | 144 | Returns: 145 | corloc_per_class: A float numpy array represents the corloc score of each 146 | class 147 | """ 148 | # Divide by zero expected for classes with no gt examples. 149 | with np.errstate(divide="ignore", invalid="ignore"): 150 | return np.where( 151 | num_gt_imgs_per_class == 0, 152 | np.nan, 153 | num_images_correctly_detected_per_class / num_gt_imgs_per_class, 154 | ) 155 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/utils/ava_evaluation/np_box_list.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
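A worked example for the two metric functions above. Note that `compute_precision_recall` checks `labels.dtype != np.bool` and `compute_average_precision` checks `precision.dtype != np.float`; those aliases were removed in NumPy 1.24, so this sketch assumes an older NumPy (or that the checks are updated to the builtin `bool`/`float`):

```python
import numpy as np

scores = np.array([0.9, 0.8, 0.6, 0.3])
labels = np.array([True, False, True, False])   # true/false positive per detection
precision, recall = compute_precision_recall(scores, labels, num_gt=2)
# precision ~ [1.0, 0.5, 0.667, 0.5], recall = [0.5, 0.5, 1.0, 1.0]
ap = compute_average_precision(precision, recall)
print(round(ap, 3))   # 0.833
```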
14 | # ============================================================================== 15 | 16 | """Numpy BoxList classes and functions.""" 17 | 18 | from __future__ import ( 19 | absolute_import, 20 | division, 21 | print_function, 22 | unicode_literals, 23 | ) 24 | import numpy as np 25 | 26 | 27 | class BoxList(object): 28 | """Box collection. 29 | 30 | BoxList represents a list of bounding boxes as numpy array, where each 31 | bounding box is represented as a row of 4 numbers, 32 | [y_min, x_min, y_max, x_max]. It is assumed that all bounding boxes within a 33 | given list correspond to a single image. 34 | 35 | Optionally, users can add additional related fields (such as 36 | objectness/classification scores). 37 | """ 38 | 39 | def __init__(self, data): 40 | """Constructs box collection. 41 | 42 | Args: 43 | data: a numpy array of shape [N, 4] representing box coordinates 44 | 45 | Raises: 46 | ValueError: if bbox data is not a numpy array 47 | ValueError: if invalid dimensions for bbox data 48 | """ 49 | if not isinstance(data, np.ndarray): 50 | raise ValueError("data must be a numpy array.") 51 | if len(data.shape) != 2 or data.shape[1] != 4: 52 | raise ValueError("Invalid dimensions for box data.") 53 | if data.dtype != np.float32 and data.dtype != np.float64: 54 | raise ValueError( 55 | "Invalid data type for box data: float is required." 56 | ) 57 | if not self._is_valid_boxes(data): 58 | raise ValueError( 59 | "Invalid box data. data must be a numpy array of " 60 | "N*[y_min, x_min, y_max, x_max]" 61 | ) 62 | self.data = {"boxes": data} 63 | 64 | def num_boxes(self): 65 | """Return number of boxes held in collections.""" 66 | return self.data["boxes"].shape[0] 67 | 68 | def get_extra_fields(self): 69 | """Return all non-box fields.""" 70 | return [k for k in self.data.keys() if k != "boxes"] 71 | 72 | def has_field(self, field): 73 | return field in self.data 74 | 75 | def add_field(self, field, field_data): 76 | """Add data to a specified field. 77 | 78 | Args: 79 | field: a string parameter used to speficy a related field to be accessed. 80 | field_data: a numpy array of [N, ...] representing the data associated 81 | with the field. 82 | Raises: 83 | ValueError: if the field is already exist or the dimension of the field 84 | data does not matches the number of boxes. 85 | """ 86 | if self.has_field(field): 87 | raise ValueError("Field " + field + "already exists") 88 | if len(field_data.shape) < 1 or field_data.shape[0] != self.num_boxes(): 89 | raise ValueError("Invalid dimensions for field data") 90 | self.data[field] = field_data 91 | 92 | def get(self): 93 | """Convenience function for accesssing box coordinates. 94 | 95 | Returns: 96 | a numpy array of shape [N, 4] representing box corners 97 | """ 98 | return self.get_field("boxes") 99 | 100 | def get_field(self, field): 101 | """Accesses data associated with the specified field in the box collection. 102 | 103 | Args: 104 | field: a string parameter used to speficy a related field to be accessed. 105 | 106 | Returns: 107 | a numpy 1-d array representing data of an associated field 108 | 109 | Raises: 110 | ValueError: if invalid field 111 | """ 112 | if not self.has_field(field): 113 | raise ValueError("field {} does not exist".format(field)) 114 | return self.data[field] 115 | 116 | def get_coordinates(self): 117 | """Get corner coordinates of boxes. 
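A minimal sketch of the `BoxList` container above: boxes are an `[N, 4]` float array in `[y_min, x_min, y_max, x_max]` order, and any extra per-box field must have a matching leading dimension:

```python
import numpy as np

boxes = np.array([[0.1, 0.1, 0.5, 0.5],
                  [0.2, 0.2, 0.9, 0.8]], dtype=np.float32)
box_list = BoxList(boxes)
box_list.add_field("scores", np.array([0.9, 0.75], dtype=np.float32))
print(box_list.num_boxes())           # 2
print(box_list.get_extra_fields())    # ['scores']
```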
118 | 119 | Returns: 120 | a list of 4 1-d numpy arrays [y_min, x_min, y_max, x_max] 121 | """ 122 | box_coordinates = self.get() 123 | y_min = box_coordinates[:, 0] 124 | x_min = box_coordinates[:, 1] 125 | y_max = box_coordinates[:, 2] 126 | x_max = box_coordinates[:, 3] 127 | return [y_min, x_min, y_max, x_max] 128 | 129 | def _is_valid_boxes(self, data): 130 | """Check whether data fullfills the format of N*[ymin, xmin, ymax, xmin]. 131 | 132 | Args: 133 | data: a numpy array of shape [N, 4] representing box coordinates 134 | 135 | Returns: 136 | a boolean indicating whether all ymax of boxes are equal or greater than 137 | ymin, and all xmax of boxes are equal or greater than xmin. 138 | """ 139 | if data.shape[0] > 0: 140 | for i in range(data.shape[0]): 141 | if data[i, 0] > data[i, 2] or data[i, 1] > data[i, 3]: 142 | return False 143 | return True 144 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/utils/ava_evaluation/np_box_mask_list.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Numpy BoxMaskList classes and functions.""" 17 | 18 | from __future__ import ( 19 | absolute_import, 20 | division, 21 | print_function, 22 | unicode_literals, 23 | ) 24 | import numpy as np 25 | 26 | from . import np_box_list 27 | 28 | 29 | class BoxMaskList(np_box_list.BoxList): 30 | """Convenience wrapper for BoxList with masks. 31 | 32 | BoxMaskList extends the np_box_list.BoxList to contain masks as well. 33 | In particular, its constructor receives both boxes and masks. Note that the 34 | masks correspond to the full image. 35 | """ 36 | 37 | def __init__(self, box_data, mask_data): 38 | """Constructs box collection. 39 | 40 | Args: 41 | box_data: a numpy array of shape [N, 4] representing box coordinates 42 | mask_data: a numpy array of shape [N, height, width] representing masks 43 | with values are in {0,1}. The masks correspond to the full 44 | image. The height and the width will be equal to image height and width. 45 | 46 | Raises: 47 | ValueError: if bbox data is not a numpy array 48 | ValueError: if invalid dimensions for bbox data 49 | ValueError: if mask data is not a numpy array 50 | ValueError: if invalid dimension for mask data 51 | """ 52 | super(BoxMaskList, self).__init__(box_data) 53 | if not isinstance(mask_data, np.ndarray): 54 | raise ValueError("Mask data must be a numpy array.") 55 | if len(mask_data.shape) != 3: 56 | raise ValueError("Invalid dimensions for mask data.") 57 | if mask_data.dtype != np.uint8: 58 | raise ValueError( 59 | "Invalid data type for mask data: uint8 is required." 60 | ) 61 | if mask_data.shape[0] != box_data.shape[0]: 62 | raise ValueError( 63 | "There should be the same number of boxes and masks." 
64 | ) 65 | self.data["masks"] = mask_data 66 | 67 | def get_masks(self): 68 | """Convenience function for accessing masks. 69 | 70 | Returns: 71 | a numpy array of shape [N, height, width] representing masks 72 | """ 73 | return self.get_field("masks") 74 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/utils/ava_evaluation/np_box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Operations for [N, 4] numpy arrays representing bounding boxes. 17 | 18 | Example box operations that are supported: 19 | * Areas: compute bounding box areas 20 | * IOU: pairwise intersection-over-union scores 21 | """ 22 | from __future__ import ( 23 | absolute_import, 24 | division, 25 | print_function, 26 | unicode_literals, 27 | ) 28 | import numpy as np 29 | 30 | 31 | def area(boxes): 32 | """Computes area of boxes. 33 | 34 | Args: 35 | boxes: Numpy array with shape [N, 4] holding N boxes 36 | 37 | Returns: 38 | a numpy array with shape [N*1] representing box areas 39 | """ 40 | return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 41 | 42 | 43 | def intersection(boxes1, boxes2): 44 | """Compute pairwise intersection areas between boxes. 45 | 46 | Args: 47 | boxes1: a numpy array with shape [N, 4] holding N boxes 48 | boxes2: a numpy array with shape [M, 4] holding M boxes 49 | 50 | Returns: 51 | a numpy array with shape [N*M] representing pairwise intersection area 52 | """ 53 | [y_min1, x_min1, y_max1, x_max1] = np.split(boxes1, 4, axis=1) 54 | [y_min2, x_min2, y_max2, x_max2] = np.split(boxes2, 4, axis=1) 55 | 56 | all_pairs_min_ymax = np.minimum(y_max1, np.transpose(y_max2)) 57 | all_pairs_max_ymin = np.maximum(y_min1, np.transpose(y_min2)) 58 | intersect_heights = np.maximum( 59 | np.zeros(all_pairs_max_ymin.shape), 60 | all_pairs_min_ymax - all_pairs_max_ymin, 61 | ) 62 | all_pairs_min_xmax = np.minimum(x_max1, np.transpose(x_max2)) 63 | all_pairs_max_xmin = np.maximum(x_min1, np.transpose(x_min2)) 64 | intersect_widths = np.maximum( 65 | np.zeros(all_pairs_max_xmin.shape), 66 | all_pairs_min_xmax - all_pairs_max_xmin, 67 | ) 68 | return intersect_heights * intersect_widths 69 | 70 | 71 | def iou(boxes1, boxes2): 72 | """Computes pairwise intersection-over-union between box collections. 73 | 74 | Args: 75 | boxes1: a numpy array with shape [N, 4] holding N boxes. 76 | boxes2: a numpy array with shape [M, 4] holding N boxes. 77 | 78 | Returns: 79 | a numpy array with shape [N, M] representing pairwise iou scores. 
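The box operations above are plain NumPy and broadcast over the two collections, so `np_box_ops.iou` returns an `[N, M]` matrix. A quick numeric check with arbitrary coordinates:

```python
import numpy as np

boxes1 = np.array([[0.0, 0.0, 2.0, 2.0]])                        # one 2x2 box
boxes2 = np.array([[1.0, 1.0, 3.0, 3.0], [0.0, 0.0, 2.0, 2.0]])  # shifted copy, exact copy
print(iou(boxes1, boxes2))   # [[0.14285714 1.        ]]  (1/7 overlap, identical box)
```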
80 | """ 81 | intersect = intersection(boxes1, boxes2) 82 | area1 = area(boxes1) 83 | area2 = area(boxes2) 84 | union = ( 85 | np.expand_dims(area1, axis=1) 86 | + np.expand_dims(area2, axis=0) 87 | - intersect 88 | ) 89 | return intersect / union 90 | 91 | 92 | def ioa(boxes1, boxes2): 93 | """Computes pairwise intersection-over-area between box collections. 94 | 95 | Intersection-over-area (ioa) between two boxes box1 and box2 is defined as 96 | their intersection area over box2's area. Note that ioa is not symmetric, 97 | that is, IOA(box1, box2) != IOA(box2, box1). 98 | 99 | Args: 100 | boxes1: a numpy array with shape [N, 4] holding N boxes. 101 | boxes2: a numpy array with shape [M, 4] holding N boxes. 102 | 103 | Returns: 104 | a numpy array with shape [N, M] representing pairwise ioa scores. 105 | """ 106 | intersect = intersection(boxes1, boxes2) 107 | areas = np.expand_dims(area(boxes2), axis=0) 108 | return intersect / areas 109 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/utils/ava_evaluation/np_mask_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Operations for [N, height, width] numpy arrays representing masks. 17 | 18 | Example mask operations that are supported: 19 | * Areas: compute mask areas 20 | * IOU: pairwise intersection-over-union scores 21 | """ 22 | from __future__ import ( 23 | absolute_import, 24 | division, 25 | print_function, 26 | unicode_literals, 27 | ) 28 | import numpy as np 29 | 30 | EPSILON = 1e-7 31 | 32 | 33 | def area(masks): 34 | """Computes area of masks. 35 | 36 | Args: 37 | masks: Numpy array with shape [N, height, width] holding N masks. Masks 38 | values are of type np.uint8 and values are in {0,1}. 39 | 40 | Returns: 41 | a numpy array with shape [N*1] representing mask areas. 42 | 43 | Raises: 44 | ValueError: If masks.dtype is not np.uint8 45 | """ 46 | if masks.dtype != np.uint8: 47 | raise ValueError("Masks type should be np.uint8") 48 | return np.sum(masks, axis=(1, 2), dtype=np.float32) 49 | 50 | 51 | def intersection(masks1, masks2): 52 | """Compute pairwise intersection areas between masks. 53 | 54 | Args: 55 | masks1: a numpy array with shape [N, height, width] holding N masks. Masks 56 | values are of type np.uint8 and values are in {0,1}. 57 | masks2: a numpy array with shape [M, height, width] holding M masks. Masks 58 | values are of type np.uint8 and values are in {0,1}. 59 | 60 | Returns: 61 | a numpy array with shape [N*M] representing pairwise intersection area. 62 | 63 | Raises: 64 | ValueError: If masks1 and masks2 are not of type np.uint8. 
65 | """ 66 | if masks1.dtype != np.uint8 or masks2.dtype != np.uint8: 67 | raise ValueError("masks1 and masks2 should be of type np.uint8") 68 | n = masks1.shape[0] 69 | m = masks2.shape[0] 70 | answer = np.zeros([n, m], dtype=np.float32) 71 | for i in np.arange(n): 72 | for j in np.arange(m): 73 | answer[i, j] = np.sum( 74 | np.minimum(masks1[i], masks2[j]), dtype=np.float32 75 | ) 76 | return answer 77 | 78 | 79 | def iou(masks1, masks2): 80 | """Computes pairwise intersection-over-union between mask collections. 81 | 82 | Args: 83 | masks1: a numpy array with shape [N, height, width] holding N masks. Masks 84 | values are of type np.uint8 and values are in {0,1}. 85 | masks2: a numpy array with shape [M, height, width] holding N masks. Masks 86 | values are of type np.uint8 and values are in {0,1}. 87 | 88 | Returns: 89 | a numpy array with shape [N, M] representing pairwise iou scores. 90 | 91 | Raises: 92 | ValueError: If masks1 and masks2 are not of type np.uint8. 93 | """ 94 | if masks1.dtype != np.uint8 or masks2.dtype != np.uint8: 95 | raise ValueError("masks1 and masks2 should be of type np.uint8") 96 | intersect = intersection(masks1, masks2) 97 | area1 = area(masks1) 98 | area2 = area(masks2) 99 | union = ( 100 | np.expand_dims(area1, axis=1) 101 | + np.expand_dims(area2, axis=0) 102 | - intersect 103 | ) 104 | return intersect / np.maximum(union, EPSILON) 105 | 106 | 107 | def ioa(masks1, masks2): 108 | """Computes pairwise intersection-over-area between box collections. 109 | 110 | Intersection-over-area (ioa) between two masks, mask1 and mask2 is defined as 111 | their intersection area over mask2's area. Note that ioa is not symmetric, 112 | that is, IOA(mask1, mask2) != IOA(mask2, mask1). 113 | 114 | Args: 115 | masks1: a numpy array with shape [N, height, width] holding N masks. Masks 116 | values are of type np.uint8 and values are in {0,1}. 117 | masks2: a numpy array with shape [M, height, width] holding N masks. Masks 118 | values are of type np.uint8 and values are in {0,1}. 119 | 120 | Returns: 121 | a numpy array with shape [N, M] representing pairwise ioa scores. 122 | 123 | Raises: 124 | ValueError: If masks1 and masks2 are not of type np.uint8. 125 | """ 126 | if masks1.dtype != np.uint8 or masks2.dtype != np.uint8: 127 | raise ValueError("masks1 and masks2 should be of type np.uint8") 128 | intersect = intersection(masks1, masks2) 129 | areas = np.expand_dims(area(masks2), axis=0) 130 | return intersect / (areas + EPSILON) 131 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/utils/benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Functions for benchmarks. 4 | """ 5 | 6 | import numpy as np 7 | import pprint 8 | import torch 9 | import tqdm 10 | from fvcore.common.timer import Timer 11 | 12 | import timesformer.utils.logging as logging 13 | import timesformer.utils.misc as misc 14 | from timesformer.datasets import loader 15 | from timesformer.utils.env import setup_environment 16 | 17 | logger = logging.get_logger(__name__) 18 | 19 | 20 | def benchmark_data_loading(cfg): 21 | """ 22 | Benchmark the speed of data loading in PySlowFast. 23 | Args: 24 | 25 | cfg (CfgNode): configs. Details can be found in 26 | lib/config/defaults.py 27 | """ 28 | # Set up environment. 29 | setup_environment() 30 | # Set random seed from configs. 
31 | np.random.seed(cfg.RNG_SEED) 32 | torch.manual_seed(cfg.RNG_SEED) 33 | 34 | # Setup logging format. 35 | logging.setup_logging(cfg.OUTPUT_DIR) 36 | 37 | # Print config. 38 | logger.info("Benchmark data loading with config:") 39 | logger.info(pprint.pformat(cfg)) 40 | 41 | timer = Timer() 42 | dataloader = loader.construct_loader(cfg, "train") 43 | logger.info( 44 | "Initialize loader using {:.2f} seconds.".format(timer.seconds()) 45 | ) 46 | # Total batch size across different machines. 47 | batch_size = cfg.TRAIN.BATCH_SIZE * cfg.NUM_SHARDS 48 | log_period = cfg.BENCHMARK.LOG_PERIOD 49 | epoch_times = [] 50 | # Test for a few epochs. 51 | for cur_epoch in range(cfg.BENCHMARK.NUM_EPOCHS): 52 | timer = Timer() 53 | timer_epoch = Timer() 54 | iter_times = [] 55 | if cfg.BENCHMARK.SHUFFLE: 56 | loader.shuffle_dataset(dataloader, cur_epoch) 57 | for cur_iter, _ in enumerate(tqdm.tqdm(dataloader)): 58 | if cur_iter > 0 and cur_iter % log_period == 0: 59 | iter_times.append(timer.seconds()) 60 | ram_usage, ram_total = misc.cpu_mem_usage() 61 | logger.info( 62 | "Epoch {}: {} iters ({} videos) in {:.2f} seconds. " 63 | "RAM Usage: {:.2f}/{:.2f} GB.".format( 64 | cur_epoch, 65 | log_period, 66 | log_period * batch_size, 67 | iter_times[-1], 68 | ram_usage, 69 | ram_total, 70 | ) 71 | ) 72 | timer.reset() 73 | epoch_times.append(timer_epoch.seconds()) 74 | ram_usage, ram_total = misc.cpu_mem_usage() 75 | logger.info( 76 | "Epoch {}: in total {} iters ({} videos) in {:.2f} seconds. " 77 | "RAM Usage: {:.2f}/{:.2f} GB.".format( 78 | cur_epoch, 79 | len(dataloader), 80 | len(dataloader) * batch_size, 81 | epoch_times[-1], 82 | ram_usage, 83 | ram_total, 84 | ) 85 | ) 86 | logger.info( 87 | "Epoch {}: on average every {} iters ({} videos) take {:.2f}/{:.2f} " 88 | "(avg/std) seconds.".format( 89 | cur_epoch, 90 | log_period, 91 | log_period * batch_size, 92 | np.mean(iter_times), 93 | np.std(iter_times), 94 | ) 95 | ) 96 | logger.info( 97 | "On average every epoch ({} videos) takes {:.2f}/{:.2f} " 98 | "(avg/std) seconds.".format( 99 | len(dataloader) * batch_size, 100 | np.mean(epoch_times), 101 | np.std(epoch_times), 102 | ) 103 | ) 104 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/utils/bn_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """bn helper.""" 4 | 5 | import itertools 6 | import torch 7 | 8 | 9 | @torch.no_grad() 10 | def compute_and_update_bn_stats(model, data_loader, num_batches=200): 11 | """ 12 | Compute and update the batch norm stats to make it more precise. During 13 | training both bn stats and the weight are changing after every iteration, 14 | so the bn can not precisely reflect the latest stats of the current model. 15 | Here the bn stats is recomputed without change of weights, to make the 16 | running mean and running var more precise. 17 | Args: 18 | model (model): the model using to compute and update the bn stats. 19 | data_loader (dataloader): dataloader using to provide inputs. 20 | num_batches (int): running iterations using to compute the stats. 21 | """ 22 | 23 | # Prepares all the bn layers. 
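    # Editor-added note (illustrative, not part of the upstream file): the routine
    # below sets BN momentum to 1.0 so that each forward pass leaves the pure
    # per-batch statistics in the BN buffers, then folds them into an incremental
    # average, e.g. mean_k = mean_{k-1} + (batch_mean_k - mean_{k-1}) / k, and
    # finally recovers the variance via Var(x) = E(x^2) - E(x)^2.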
24 | bn_layers = [ 25 | m 26 | for m in model.modules() 27 | if any( 28 | ( 29 | isinstance(m, bn_type) 30 | for bn_type in ( 31 | torch.nn.BatchNorm1d, 32 | torch.nn.BatchNorm2d, 33 | torch.nn.BatchNorm3d, 34 | ) 35 | ) 36 | ) 37 | ] 38 | 39 | # In order to make the running stats only reflect the current batch, the 40 | # momentum is disabled. 41 | # bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean 42 | # Setting the momentum to 1.0 to compute the stats without momentum. 43 | momentum_actual = [bn.momentum for bn in bn_layers] 44 | for bn in bn_layers: 45 | bn.momentum = 1.0 46 | 47 | # Calculates the running iterations for precise stats computation. 48 | running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers] 49 | running_square_mean = [torch.zeros_like(bn.running_var) for bn in bn_layers] 50 | 51 | for ind, (inputs, _, _) in enumerate( 52 | itertools.islice(data_loader, num_batches) 53 | ): 54 | # Forwards the model to update the bn stats. 55 | if isinstance(inputs, (list,)): 56 | for i in range(len(inputs)): 57 | inputs[i] = inputs[i].float().cuda(non_blocking=True) 58 | else: 59 | inputs = inputs.cuda(non_blocking=True) 60 | model(inputs) 61 | 62 | for i, bn in enumerate(bn_layers): 63 | # Accumulates the bn stats. 64 | running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1) 65 | # $E(x^2) = Var(x) + E(x)^2$. 66 | cur_square_mean = bn.running_var + bn.running_mean ** 2 67 | running_square_mean[i] += ( 68 | cur_square_mean - running_square_mean[i] 69 | ) / (ind + 1) 70 | 71 | for i, bn in enumerate(bn_layers): 72 | bn.running_mean = running_mean[i] 73 | # Var(x) = $E(x^2) - E(x)^2$. 74 | bn.running_var = running_square_mean[i] - bn.running_mean ** 2 75 | # Sets the precise bn stats. 76 | bn.momentum = momentum_actual[i] 77 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/utils/c2_model_loading.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Caffe2 to PyTorch checkpoint name converting utility.""" 4 | 5 | import re 6 | 7 | 8 | def get_name_convert_func(): 9 | """ 10 | Get the function to convert Caffe2 layer names to PyTorch layer names. 11 | Returns: 12 | (func): function to convert parameter name from Caffe2 format to PyTorch 13 | format. 
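    Example (editor-added illustrative sketch, based on the regex pairs below):
        >>> convert = get_name_convert_func()
        >>> convert("res4_0_branch1_w")
        's4.pathway0_res0.branch1.weight'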
14 | """ 15 | pairs = [ 16 | # ------------------------------------------------------------ 17 | # 'nonlocal_conv3_1_theta_w' -> 's3.pathway0_nonlocal3.conv_g.weight' 18 | [ 19 | r"^nonlocal_conv([0-9]+)_([0-9]+)_(.*)", 20 | r"s\1.pathway0_nonlocal\2_\3", 21 | ], 22 | # 'theta' -> 'conv_theta' 23 | [r"^(.*)_nonlocal([0-9]+)_(theta)(.*)", r"\1_nonlocal\2.conv_\3\4"], 24 | # 'g' -> 'conv_g' 25 | [r"^(.*)_nonlocal([0-9]+)_(g)(.*)", r"\1_nonlocal\2.conv_\3\4"], 26 | # 'phi' -> 'conv_phi' 27 | [r"^(.*)_nonlocal([0-9]+)_(phi)(.*)", r"\1_nonlocal\2.conv_\3\4"], 28 | # 'out' -> 'conv_out' 29 | [r"^(.*)_nonlocal([0-9]+)_(out)(.*)", r"\1_nonlocal\2.conv_\3\4"], 30 | # 'nonlocal_conv4_5_bn_s' -> 's4.pathway0_nonlocal3.bn.weight' 31 | [r"^(.*)_nonlocal([0-9]+)_(bn)_(.*)", r"\1_nonlocal\2.\3.\4"], 32 | # ------------------------------------------------------------ 33 | # 't_pool1_subsample_bn' -> 's1_fuse.conv_f2s.bn.running_mean' 34 | [r"^t_pool1_subsample_bn_(.*)", r"s1_fuse.bn.\1"], 35 | # 't_pool1_subsample' -> 's1_fuse.conv_f2s' 36 | [r"^t_pool1_subsample_(.*)", r"s1_fuse.conv_f2s.\1"], 37 | # 't_res4_5_branch2c_bn_subsample_bn_rm' -> 's4_fuse.conv_f2s.bias' 38 | [ 39 | r"^t_res([0-9]+)_([0-9]+)_branch2c_bn_subsample_bn_(.*)", 40 | r"s\1_fuse.bn.\3", 41 | ], 42 | # 't_pool1_subsample' -> 's1_fuse.conv_f2s' 43 | [ 44 | r"^t_res([0-9]+)_([0-9]+)_branch2c_bn_subsample_(.*)", 45 | r"s\1_fuse.conv_f2s.\3", 46 | ], 47 | # ------------------------------------------------------------ 48 | # 'res4_4_branch_2c_bn_b' -> 's4.pathway0_res4.branch2.c_bn_b' 49 | [ 50 | r"^res([0-9]+)_([0-9]+)_branch([0-9]+)([a-z])_(.*)", 51 | r"s\1.pathway0_res\2.branch\3.\4_\5", 52 | ], 53 | # 'res_conv1_bn_' -> 's1.pathway0_stem.bn.' 54 | [r"^res_conv1_bn_(.*)", r"s1.pathway0_stem.bn.\1"], 55 | # 'conv1_xy_w_momentum' -> 's1.pathway0_stem.conv_xy.' 56 | [r"^conv1_xy(.*)", r"s1.pathway0_stem.conv_xy\1"], 57 | # 'conv1_w_momentum' -> 's1.pathway0_stem.conv.' 58 | [r"^conv1_(.*)", r"s1.pathway0_stem.conv.\1"], 59 | # 'res4_0_branch1_w' -> 'S4.pathway0_res0.branch1.weight' 60 | [ 61 | r"^res([0-9]+)_([0-9]+)_branch([0-9]+)_(.*)", 62 | r"s\1.pathway0_res\2.branch\3_\4", 63 | ], 64 | # 'res_conv1_' -> 's1.pathway0_stem.conv.' 65 | [r"^res_conv1_(.*)", r"s1.pathway0_stem.conv.\1"], 66 | # ------------------------------------------------------------ 67 | # 'res4_4_branch_2c_bn_b' -> 's4.pathway0_res4.branch2.c_bn_b' 68 | [ 69 | r"^t_res([0-9]+)_([0-9]+)_branch([0-9]+)([a-z])_(.*)", 70 | r"s\1.pathway1_res\2.branch\3.\4_\5", 71 | ], 72 | # 'res_conv1_bn_' -> 's1.pathway0_stem.bn.' 73 | [r"^t_res_conv1_bn_(.*)", r"s1.pathway1_stem.bn.\1"], 74 | # 'conv1_w_momentum' -> 's1.pathway0_stem.conv.' 75 | [r"^t_conv1_(.*)", r"s1.pathway1_stem.conv.\1"], 76 | # 'res4_0_branch1_w' -> 'S4.pathway0_res0.branch1.weight' 77 | [ 78 | r"^t_res([0-9]+)_([0-9]+)_branch([0-9]+)_(.*)", 79 | r"s\1.pathway1_res\2.branch\3_\4", 80 | ], 81 | # 'res_conv1_' -> 's1.pathway0_stem.conv.' 82 | [r"^t_res_conv1_(.*)", r"s1.pathway1_stem.conv.\1"], 83 | # ------------------------------------------------------------ 84 | # pred_ -> head.projection. 85 | [r"pred_(.*)", r"head.projection.\1"], 86 | # '.b_bn_fc' -> '.se.fc' 87 | [r"(.*)b_bn_fc(.*)", r"\1se.fc\2"], 88 | # conv_5 -> head.conv_5. 89 | [r"conv_5(.*)", r"head.conv_5\1"], 90 | # conv_5 -> head.conv_5. 
91 | [r"lin_5(.*)", r"head.lin_5\1"], 92 | # '.bn_b' -> '.weight' 93 | [r"(.*)bn.b\Z", r"\1bn.bias"], 94 | # '.bn_s' -> '.weight' 95 | [r"(.*)bn.s\Z", r"\1bn.weight"], 96 | # '_bn_rm' -> '.running_mean' 97 | [r"(.*)bn.rm\Z", r"\1bn.running_mean"], 98 | # '_bn_riv' -> '.running_var' 99 | [r"(.*)bn.riv\Z", r"\1bn.running_var"], 100 | # '_b' -> '.bias' 101 | [r"(.*)[\._]b\Z", r"\1.bias"], 102 | # '_w' -> '.weight' 103 | [r"(.*)[\._]w\Z", r"\1.weight"], 104 | ] 105 | 106 | def convert_caffe2_name_to_pytorch(caffe2_layer_name): 107 | """ 108 | Convert the caffe2_layer_name to pytorch format by apply the list of 109 | regular expressions. 110 | Args: 111 | caffe2_layer_name (str): caffe2 layer name. 112 | Returns: 113 | (str): pytorch layer name. 114 | """ 115 | for source, dest in pairs: 116 | caffe2_layer_name = re.sub(source, dest, caffe2_layer_name) 117 | return caffe2_layer_name 118 | 119 | return convert_caffe2_name_to_pytorch 120 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/utils/env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Set up Environment.""" 4 | 5 | import timesformer.utils.logging as logging 6 | 7 | _ENV_SETUP_DONE = False 8 | 9 | 10 | def setup_environment(): 11 | global _ENV_SETUP_DONE 12 | if _ENV_SETUP_DONE: 13 | return 14 | _ENV_SETUP_DONE = True 15 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/utils/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Logging.""" 4 | 5 | import atexit 6 | import builtins 7 | import decimal 8 | import functools 9 | import logging 10 | import os 11 | import sys 12 | import simplejson 13 | from fvcore.common.file_io import PathManager 14 | 15 | import timesformer.utils.distributed as du 16 | 17 | 18 | def _suppress_print(): 19 | """ 20 | Suppresses printing from the current process. 21 | """ 22 | 23 | def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False): 24 | pass 25 | 26 | builtins.print = print_pass 27 | 28 | 29 | @functools.lru_cache(maxsize=None) 30 | def _cached_log_stream(filename): 31 | io = PathManager.open(filename, "a", buffering=1024) 32 | atexit.register(io.close) 33 | return io 34 | 35 | 36 | def setup_logging(output_dir=None): 37 | """ 38 | Sets up the logging for multiple processes. Only enable the logging for the 39 | master process, and suppress logging for the non-master processes. 40 | """ 41 | # Set up logging format. 42 | _FORMAT = "[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s" 43 | 44 | if du.is_master_proc(): 45 | # Enable logging for the master process. 46 | logging.root.handlers = [] 47 | else: 48 | # Suppress logging for non-master processes. 
49 | _suppress_print() 50 | 51 | logger = logging.getLogger() 52 | logger.setLevel(logging.DEBUG) 53 | logger.propagate = False 54 | plain_formatter = logging.Formatter( 55 | "[%(asctime)s][%(levelname)s] %(filename)s: %(lineno)3d: %(message)s", 56 | datefmt="%m/%d %H:%M:%S", 57 | ) 58 | 59 | if du.is_master_proc(): 60 | ch = logging.StreamHandler(stream=sys.stdout) 61 | ch.setLevel(logging.DEBUG) 62 | ch.setFormatter(plain_formatter) 63 | logger.addHandler(ch) 64 | 65 | if output_dir is not None and du.is_master_proc(du.get_world_size()): 66 | filename = os.path.join(output_dir, "stdout.log") 67 | fh = logging.StreamHandler(_cached_log_stream(filename)) 68 | fh.setLevel(logging.DEBUG) 69 | fh.setFormatter(plain_formatter) 70 | logger.addHandler(fh) 71 | 72 | 73 | def get_logger(name): 74 | """ 75 | Retrieve the logger with the specified name or, if name is None, return a 76 | logger which is the root logger of the hierarchy. 77 | Args: 78 | name (string): name of the logger. 79 | """ 80 | return logging.getLogger(name) 81 | 82 | 83 | def log_json_stats(stats): 84 | """ 85 | Logs json stats. 86 | Args: 87 | stats (dict): a dictionary of statistical information to log. 88 | """ 89 | stats = { 90 | k: decimal.Decimal("{:.5f}".format(v)) if isinstance(v, float) else v 91 | for k, v in stats.items() 92 | } 93 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) 94 | logger = get_logger(__name__) 95 | logger.info("json_stats: {:s}".format(json_stats)) 96 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/utils/lr_policy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Learning rate policy.""" 4 | 5 | import math 6 | 7 | 8 | def get_lr_at_epoch(cfg, cur_epoch): 9 | """ 10 | Retrieve the learning rate of the current epoch with the option to perform 11 | warm up in the beginning of the training stage. 12 | Args: 13 | cfg (CfgNode): configs. Details can be found in 14 | slowfast/config/defaults.py 15 | cur_epoch (float): the number of epoch of the current training stage. 16 | """ 17 | lr = get_lr_func(cfg.SOLVER.LR_POLICY)(cfg, cur_epoch) 18 | # Perform warm up. 19 | if cur_epoch < cfg.SOLVER.WARMUP_EPOCHS: 20 | lr_start = cfg.SOLVER.WARMUP_START_LR 21 | lr_end = get_lr_func(cfg.SOLVER.LR_POLICY)( 22 | cfg, cfg.SOLVER.WARMUP_EPOCHS 23 | ) 24 | alpha = (lr_end - lr_start) / cfg.SOLVER.WARMUP_EPOCHS 25 | lr = cur_epoch * alpha + lr_start 26 | return lr 27 | 28 | 29 | def lr_func_cosine(cfg, cur_epoch): 30 | """ 31 | Retrieve the learning rate to specified values at specified epoch with the 32 | cosine learning rate schedule. Details can be found in: 33 | Ilya Loshchilov, and Frank Hutter 34 | SGDR: Stochastic Gradient Descent With Warm Restarts. 35 | Args: 36 | cfg (CfgNode): configs. Details can be found in 37 | slowfast/config/defaults.py 38 | cur_epoch (float): the number of epoch of the current training stage. 39 | """ 40 | assert cfg.SOLVER.COSINE_END_LR < cfg.SOLVER.BASE_LR 41 | return ( 42 | cfg.SOLVER.COSINE_END_LR 43 | + (cfg.SOLVER.BASE_LR - cfg.SOLVER.COSINE_END_LR) 44 | * (math.cos(math.pi * cur_epoch / cfg.SOLVER.MAX_EPOCH) + 1.0) 45 | * 0.5 46 | ) 47 | 48 | 49 | def lr_func_steps_with_relative_lrs(cfg, cur_epoch): 50 | """ 51 | Retrieve the learning rate to specified values at specified epoch with the 52 | steps with relative learning rate schedule. 
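    Example (editor-added, illustrative values): with SOLVER.STEPS = [0, 10, 20],
    SOLVER.LRS = [1, 0.1, 0.01] and SOLVER.BASE_LR = 0.1, epochs in [0, 10) get
    lr 0.1, epochs in [10, 20) get 0.01, and later epochs get 0.001 (before the
    warmup adjustment applied in get_lr_at_epoch).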
53 | Args: 54 | cfg (CfgNode): configs. Details can be found in 55 | slowfast/config/defaults.py 56 | cur_epoch (float): the number of epoch of the current training stage. 57 | """ 58 | ind = get_step_index(cfg, cur_epoch) 59 | return cfg.SOLVER.LRS[ind] * cfg.SOLVER.BASE_LR 60 | 61 | 62 | def get_step_index(cfg, cur_epoch): 63 | """ 64 | Retrieves the lr step index for the given epoch. 65 | Args: 66 | cfg (CfgNode): configs. Details can be found in 67 | slowfast/config/defaults.py 68 | cur_epoch (float): the number of epoch of the current training stage. 69 | """ 70 | steps = cfg.SOLVER.STEPS + [cfg.SOLVER.MAX_EPOCH] 71 | for ind, step in enumerate(steps): # NoQA 72 | if cur_epoch < step: 73 | break 74 | return ind - 1 75 | 76 | 77 | def get_lr_func(lr_policy): 78 | """ 79 | Given the configs, retrieve the specified lr policy function. 80 | Args: 81 | lr_policy (string): the learning rate policy to use for the job. 82 | """ 83 | policy = "lr_func_" + lr_policy 84 | if policy not in globals(): 85 | raise NotImplementedError("Unknown LR policy: {}".format(lr_policy)) 86 | else: 87 | return globals()[policy] 88 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/utils/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Functions for computing metrics.""" 4 | 5 | import torch 6 | import numpy as np 7 | 8 | def topks_correct(preds, labels, ks): 9 | """ 10 | Given the predictions, labels, and a list of top-k values, compute the 11 | number of correct predictions for each top-k value. 12 | 13 | Args: 14 | preds (array): array of predictions. Dimension is batchsize 15 | N x ClassNum. 16 | labels (array): array of labels. Dimension is batchsize N. 17 | ks (list): list of top-k values. For example, ks = [1, 5] correspods 18 | to top-1 and top-5. 19 | 20 | Returns: 21 | topks_correct (list): list of numbers, where the `i`-th entry 22 | corresponds to the number of top-`ks[i]` correct predictions. 23 | """ 24 | assert preds.size(0) == labels.size( 25 | 0 26 | ), "Batch dim of predictions and labels must match" 27 | # Find the top max_k predictions for each sample 28 | _top_max_k_vals, top_max_k_inds = torch.topk( 29 | preds, max(ks), dim=1, largest=True, sorted=True 30 | ) 31 | # (batch_size, max_k) -> (max_k, batch_size). 32 | top_max_k_inds = top_max_k_inds.t() 33 | # (batch_size, ) -> (max_k, batch_size). 34 | rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds) 35 | # (i, j) = 1 if top i-th prediction for the j-th sample is correct. 36 | top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels) 37 | # Compute the number of topk correct predictions for each k. 38 | topks_correct = [top_max_k_correct[:k, :].float().sum() for k in ks] 39 | return topks_correct 40 | 41 | 42 | def topk_errors(preds, labels, ks): 43 | """ 44 | Computes the top-k error for each k. 45 | Args: 46 | preds (array): array of predictions. Dimension is N. 47 | labels (array): array of labels. Dimension is N. 48 | ks (list): list of ks to calculate the top accuracies. 49 | """ 50 | num_topks_correct = topks_correct(preds, labels, ks) 51 | return [(1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct] 52 | 53 | 54 | def topk_accuracies(preds, labels, ks): 55 | """ 56 | Computes the top-k accuracy for each k. 57 | Args: 58 | preds (array): array of predictions. Dimension is N. 59 | labels (array): array of labels. 
Dimension is N. 60 | ks (list): list of ks to calculate the top accuracies. 61 | """ 62 | num_topks_correct = topks_correct(preds, labels, ks) 63 | return [(x / preds.size(0)) * 100.0 for x in num_topks_correct] 64 | 65 | def multitask_topks_correct(preds, labels, ks=(1,)): 66 | """ 67 | Args: 68 | preds: tuple(torch.FloatTensor), each tensor should be of shape 69 | [batch_size, class_count], class_count can vary on a per task basis, i.e. 70 | outputs[i].shape[1] can be different to outputs[j].shape[j]. 71 | labels: tuple(torch.LongTensor), each tensor should be of shape [batch_size] 72 | ks: tuple(int), compute accuracy at top-k for the values of k specified 73 | in this parameter. 74 | Returns: 75 | tuple(float), same length at topk with the corresponding accuracy@k in. 76 | """ 77 | max_k = int(np.max(ks)) 78 | task_count = len(preds) 79 | batch_size = labels[0].size(0) 80 | all_correct = torch.zeros(max_k, batch_size).type(torch.ByteTensor) 81 | if torch.cuda.is_available(): 82 | all_correct = all_correct.cuda() 83 | for output, label in zip(preds, labels): 84 | _, max_k_idx = output.topk(max_k, dim=1, largest=True, sorted=True) 85 | # Flip batch_size, class_count as .view doesn't work on non-contiguous 86 | max_k_idx = max_k_idx.t() 87 | correct_for_task = max_k_idx.eq(label.view(1, -1).expand_as(max_k_idx)) 88 | all_correct.add_(correct_for_task) 89 | 90 | multitask_topks_correct = [ 91 | torch.ge(all_correct[:k].float().sum(0), task_count).float().sum(0) for k in ks 92 | ] 93 | 94 | return multitask_topks_correct 95 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/utils/multigrid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Helper functions for multigrid training.""" 4 | 5 | import numpy as np 6 | 7 | import timesformer.utils.logging as logging 8 | 9 | logger = logging.get_logger(__name__) 10 | 11 | 12 | class MultigridSchedule(object): 13 | """ 14 | This class defines multigrid training schedule and update cfg accordingly. 15 | """ 16 | 17 | def init_multigrid(self, cfg): 18 | """ 19 | Update cfg based on multigrid settings. 20 | Args: 21 | cfg (configs): configs that contains training and multigrid specific 22 | hyperparameters. Details can be seen in 23 | slowfast/config/defaults.py. 24 | Returns: 25 | cfg (configs): the updated cfg. 26 | """ 27 | self.schedule = None 28 | # We may modify cfg.TRAIN.BATCH_SIZE, cfg.DATA.NUM_FRAMES, and 29 | # cfg.DATA.TRAIN_CROP_SIZE during training, so we store their original 30 | # value in cfg and use them as global variables. 31 | cfg.MULTIGRID.DEFAULT_B = cfg.TRAIN.BATCH_SIZE 32 | cfg.MULTIGRID.DEFAULT_T = cfg.DATA.NUM_FRAMES 33 | cfg.MULTIGRID.DEFAULT_S = cfg.DATA.TRAIN_CROP_SIZE 34 | 35 | if cfg.MULTIGRID.LONG_CYCLE: 36 | self.schedule = self.get_long_cycle_schedule(cfg) 37 | cfg.SOLVER.STEPS = [0] + [s[-1] for s in self.schedule] 38 | # Fine-tuning phase. 39 | cfg.SOLVER.STEPS[-1] = ( 40 | cfg.SOLVER.STEPS[-2] + cfg.SOLVER.STEPS[-1] 41 | ) // 2 42 | cfg.SOLVER.LRS = [ 43 | cfg.SOLVER.GAMMA ** s[0] * s[1][0] for s in self.schedule 44 | ] 45 | # Fine-tuning phase. 
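            # Editor-added note (illustrative values): the statement below repeats the
            # second-to-last LR so that SOLVER.LRS stays aligned with the extra
            # fine-tuning step just appended to SOLVER.STEPS, e.g.
            # [1.0, 0.1, 0.01] -> [1.0, 0.1, 0.1, 0.01].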
46 | cfg.SOLVER.LRS = cfg.SOLVER.LRS[:-1] + [ 47 | cfg.SOLVER.LRS[-2], 48 | cfg.SOLVER.LRS[-1], 49 | ] 50 | 51 | cfg.SOLVER.MAX_EPOCH = self.schedule[-1][-1] 52 | 53 | elif cfg.MULTIGRID.SHORT_CYCLE: 54 | cfg.SOLVER.STEPS = [ 55 | int(s * cfg.MULTIGRID.EPOCH_FACTOR) for s in cfg.SOLVER.STEPS 56 | ] 57 | cfg.SOLVER.MAX_EPOCH = int( 58 | cfg.SOLVER.MAX_EPOCH * cfg.MULTIGRID.EPOCH_FACTOR 59 | ) 60 | return cfg 61 | 62 | def update_long_cycle(self, cfg, cur_epoch): 63 | """ 64 | Before every epoch, check if long cycle shape should change. If it 65 | should, update cfg accordingly. 66 | Args: 67 | cfg (configs): configs that contains training and multigrid specific 68 | hyperparameters. Details can be seen in 69 | slowfast/config/defaults.py. 70 | cur_epoch (int): current epoch index. 71 | Returns: 72 | cfg (configs): the updated cfg. 73 | changed (bool): do we change long cycle shape at this epoch? 74 | """ 75 | base_b, base_t, base_s = get_current_long_cycle_shape( 76 | self.schedule, cur_epoch 77 | ) 78 | if base_s != cfg.DATA.TRAIN_CROP_SIZE or base_t != cfg.DATA.NUM_FRAMES: 79 | 80 | cfg.DATA.NUM_FRAMES = base_t 81 | cfg.DATA.TRAIN_CROP_SIZE = base_s 82 | cfg.TRAIN.BATCH_SIZE = base_b * cfg.MULTIGRID.DEFAULT_B 83 | 84 | bs_factor = ( 85 | float(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS) 86 | / cfg.MULTIGRID.BN_BASE_SIZE 87 | ) 88 | 89 | if bs_factor < 1: 90 | cfg.BN.NORM_TYPE = "sync_batchnorm" 91 | cfg.BN.NUM_SYNC_DEVICES = int(1.0 / bs_factor) 92 | elif bs_factor > 1: 93 | cfg.BN.NORM_TYPE = "sub_batchnorm" 94 | cfg.BN.NUM_SPLITS = int(bs_factor) 95 | else: 96 | cfg.BN.NORM_TYPE = "batchnorm" 97 | 98 | cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE = cfg.DATA.SAMPLING_RATE * ( 99 | cfg.MULTIGRID.DEFAULT_T // cfg.DATA.NUM_FRAMES 100 | ) 101 | logger.info("Long cycle updates:") 102 | logger.info("\tBN.NORM_TYPE: {}".format(cfg.BN.NORM_TYPE)) 103 | if cfg.BN.NORM_TYPE == "sync_batchnorm": 104 | logger.info( 105 | "\tBN.NUM_SYNC_DEVICES: {}".format(cfg.BN.NUM_SYNC_DEVICES) 106 | ) 107 | elif cfg.BN.NORM_TYPE == "sub_batchnorm": 108 | logger.info("\tBN.NUM_SPLITS: {}".format(cfg.BN.NUM_SPLITS)) 109 | logger.info("\tTRAIN.BATCH_SIZE: {}".format(cfg.TRAIN.BATCH_SIZE)) 110 | logger.info( 111 | "\tDATA.NUM_FRAMES x LONG_CYCLE_SAMPLING_RATE: {}x{}".format( 112 | cfg.DATA.NUM_FRAMES, cfg.MULTIGRID.LONG_CYCLE_SAMPLING_RATE 113 | ) 114 | ) 115 | logger.info( 116 | "\tDATA.TRAIN_CROP_SIZE: {}".format(cfg.DATA.TRAIN_CROP_SIZE) 117 | ) 118 | return cfg, True 119 | else: 120 | return cfg, False 121 | 122 | def get_long_cycle_schedule(self, cfg): 123 | """ 124 | Based on multigrid hyperparameters, define the schedule of a long cycle. 125 | Args: 126 | cfg (configs): configs that contains training and multigrid specific 127 | hyperparameters. Details can be seen in 128 | slowfast/config/defaults.py. 129 | Returns: 130 | schedule (list): Specifies a list long cycle base shapes and their 131 | corresponding training epochs. 132 | """ 133 | 134 | steps = cfg.SOLVER.STEPS 135 | 136 | default_size = float( 137 | cfg.DATA.NUM_FRAMES * cfg.DATA.TRAIN_CROP_SIZE ** 2 138 | ) 139 | default_iters = steps[-1] 140 | 141 | # Get shapes and average batch size for each long cycle shape. 
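        # Editor-added note (illustrative numbers): with DEFAULT_T = 8 frames,
        # DEFAULT_S = 224 crop, and a long cycle factor pair (0.5, 0.707), the
        # loop below yields base_t = 4, base_s = 158, and a relative batch size
        # of round(8 * 224**2 / (4 * 158**2)) = 4 in the (T, S) -> (B, T, S) step.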
142 | avg_bs = [] 143 | all_shapes = [] 144 | for t_factor, s_factor in cfg.MULTIGRID.LONG_CYCLE_FACTORS: 145 | base_t = int(round(cfg.DATA.NUM_FRAMES * t_factor)) 146 | base_s = int(round(cfg.DATA.TRAIN_CROP_SIZE * s_factor)) 147 | if cfg.MULTIGRID.SHORT_CYCLE: 148 | shapes = [ 149 | [ 150 | base_t, 151 | cfg.MULTIGRID.DEFAULT_S 152 | * cfg.MULTIGRID.SHORT_CYCLE_FACTORS[0], 153 | ], 154 | [ 155 | base_t, 156 | cfg.MULTIGRID.DEFAULT_S 157 | * cfg.MULTIGRID.SHORT_CYCLE_FACTORS[1], 158 | ], 159 | [base_t, base_s], 160 | ] 161 | else: 162 | shapes = [[base_t, base_s]] 163 | 164 | # (T, S) -> (B, T, S) 165 | shapes = [ 166 | [int(round(default_size / (s[0] * s[1] * s[1]))), s[0], s[1]] 167 | for s in shapes 168 | ] 169 | avg_bs.append(np.mean([s[0] for s in shapes])) 170 | all_shapes.append(shapes) 171 | 172 | # Get schedule regardless of cfg.MULTIGRID.EPOCH_FACTOR. 173 | total_iters = 0 174 | schedule = [] 175 | for step_index in range(len(steps) - 1): 176 | step_epochs = steps[step_index + 1] - steps[step_index] 177 | 178 | for long_cycle_index, shapes in enumerate(all_shapes): 179 | cur_epochs = ( 180 | step_epochs * avg_bs[long_cycle_index] / sum(avg_bs) 181 | ) 182 | 183 | cur_iters = cur_epochs / avg_bs[long_cycle_index] 184 | total_iters += cur_iters 185 | schedule.append((step_index, shapes[-1], cur_epochs)) 186 | 187 | iter_saving = default_iters / total_iters 188 | 189 | final_step_epochs = cfg.SOLVER.MAX_EPOCH - steps[-1] 190 | 191 | # We define the fine-tuning phase to have the same amount of iteration 192 | # saving as the rest of the training. 193 | ft_epochs = final_step_epochs / iter_saving * avg_bs[-1] 194 | 195 | schedule.append((step_index + 1, all_shapes[-1][2], ft_epochs)) 196 | 197 | # Obtrain final schedule given desired cfg.MULTIGRID.EPOCH_FACTOR. 198 | x = ( 199 | cfg.SOLVER.MAX_EPOCH 200 | * cfg.MULTIGRID.EPOCH_FACTOR 201 | / sum(s[-1] for s in schedule) 202 | ) 203 | 204 | final_schedule = [] 205 | total_epochs = 0 206 | for s in schedule: 207 | epochs = s[2] * x 208 | total_epochs += epochs 209 | final_schedule.append((s[0], s[1], int(round(total_epochs)))) 210 | print_schedule(final_schedule) 211 | return final_schedule 212 | 213 | 214 | def print_schedule(schedule): 215 | """ 216 | Log schedule. 217 | """ 218 | logger.info("Long cycle index\tBase shape\tEpochs") 219 | for s in schedule: 220 | logger.info("{}\t{}\t{}".format(s[0], s[1], s[2])) 221 | 222 | 223 | def get_current_long_cycle_shape(schedule, epoch): 224 | """ 225 | Given a schedule and epoch index, return the long cycle base shape. 226 | Args: 227 | schedule (configs): configs that contains training and multigrid specific 228 | hyperparameters. Details can be seen in 229 | slowfast/config/defaults.py. 230 | cur_epoch (int): current epoch index. 231 | Returns: 232 | shapes (list): A list describing the base shape in a long cycle: 233 | [batch size relative to default, 234 | number of frames, spatial dimension]. 235 | """ 236 | for s in schedule: 237 | if epoch < s[-1]: 238 | return s[1] 239 | return schedule[-1][1] 240 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/utils/multiprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | 3 | """Multiprocessing helpers.""" 4 | 5 | import torch 6 | 7 | 8 | def run( 9 | local_rank, 10 | num_proc, 11 | func, 12 | init_method, 13 | shard_id, 14 | num_shards, 15 | backend, 16 | cfg, 17 | output_queue=None, 18 | ): 19 | """ 20 | Runs a function from a child process. 21 | Args: 22 | local_rank (int): rank of the current process on the current machine. 23 | num_proc (int): number of processes per machine. 24 | func (function): function to execute on each of the process. 25 | init_method (string): method to initialize the distributed training. 26 | TCP initialization: equiring a network address reachable from all 27 | processes followed by the port. 28 | Shared file-system initialization: makes use of a file system that 29 | is shared and visible from all machines. The URL should start with 30 | file:// and contain a path to a non-existent file on a shared file 31 | system. 32 | shard_id (int): the rank of the current machine. 33 | num_shards (int): number of overall machines for the distributed 34 | training job. 35 | backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are 36 | supports, each with different capabilities. Details can be found 37 | here: 38 | https://pytorch.org/docs/stable/distributed.html 39 | cfg (CfgNode): configs. Details can be found in 40 | slowfast/config/defaults.py 41 | output_queue (queue): can optionally be used to return values from the 42 | master process. 43 | """ 44 | # Initialize the process group. 45 | world_size = num_proc * num_shards 46 | rank = shard_id * num_proc + local_rank 47 | 48 | try: 49 | torch.distributed.init_process_group( 50 | backend=backend, 51 | init_method=init_method, 52 | world_size=world_size, 53 | rank=rank, 54 | ) 55 | except Exception as e: 56 | raise e 57 | 58 | torch.cuda.set_device(local_rank) 59 | ret = func(cfg) 60 | if output_queue is not None and local_rank == 0: 61 | output_queue.put(ret) 62 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/utils/parser.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Argument parser functions.""" 4 | 5 | import argparse 6 | import sys 7 | 8 | import timesformer.utils.checkpoint as cu 9 | from timesformer.config.defaults import get_cfg 10 | 11 | 12 | def parse_args(): 13 | """ 14 | Parse the following arguments for a default parser for PySlowFast users. 15 | Args: 16 | shard_id (int): shard id for the current machine. Starts from 0 to 17 | num_shards - 1. If single machine is used, then set shard id to 0. 18 | num_shards (int): number of shards using by the job. 19 | init_method (str): initialization method to launch the job with multiple 20 | devices. Options includes TCP or shared file-system for 21 | initialization. details can be find in 22 | https://pytorch.org/docs/stable/distributed.html#tcp-initialization 23 | cfg (str): path to the config file. 24 | opts (argument): provide addtional options from the command line, it 25 | overwrites the config loaded from file. 26 | """ 27 | parser = argparse.ArgumentParser( 28 | description="Provide SlowFast video training and testing pipeline." 
29 | ) 30 | parser.add_argument( 31 | "--shard_id", 32 | help="The shard id of current node, Starts from 0 to num_shards - 1", 33 | default=0, 34 | type=int, 35 | ) 36 | parser.add_argument( 37 | "--num_shards", 38 | help="Number of shards using by the job", 39 | default=1, 40 | type=int, 41 | ) 42 | parser.add_argument( 43 | "--init_method", 44 | help="Initialization method, includes TCP or shared file-system", 45 | default="tcp://localhost:9999", 46 | type=str, 47 | ) 48 | parser.add_argument( 49 | "--cfg", 50 | dest="cfg_file", 51 | help="Path to the config file", 52 | default="configs/Kinetics/SLOWFAST_4x16_R50.yaml", 53 | type=str, 54 | ) 55 | parser.add_argument( 56 | "opts", 57 | help="See slowfast/config/defaults.py for all options", 58 | default=None, 59 | nargs=argparse.REMAINDER, 60 | ) 61 | if len(sys.argv) == 1: 62 | parser.print_help() 63 | return parser.parse_args() 64 | 65 | 66 | def load_config(args): 67 | """ 68 | Given the arguemnts, load and initialize the configs. 69 | Args: 70 | args (argument): arguments includes `shard_id`, `num_shards`, 71 | `init_method`, `cfg_file`, and `opts`. 72 | """ 73 | # Setup cfg. 74 | cfg = get_cfg() 75 | # Load config from cfg. 76 | if args.cfg_file is not None: 77 | cfg.merge_from_file(args.cfg_file) 78 | # Load config from command line, overwrite config from opts. 79 | if args.opts is not None: 80 | cfg.merge_from_list(args.opts) 81 | 82 | # Inherit parameters from args. 83 | if hasattr(args, "num_shards") and hasattr(args, "shard_id"): 84 | cfg.NUM_SHARDS = args.num_shards 85 | cfg.SHARD_ID = args.shard_id 86 | if hasattr(args, "rng_seed"): 87 | cfg.RNG_SEED = args.rng_seed 88 | if hasattr(args, "output_dir"): 89 | cfg.OUTPUT_DIR = args.output_dir 90 | 91 | # Create the checkpoint dir. 92 | cu.make_checkpoint_dir(cfg.OUTPUT_DIR) 93 | return cfg 94 | -------------------------------------------------------------------------------- /third_party/TimeSformer/timesformer/utils/weight_init_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | 3 | """Utility function for weight initialization""" 4 | 5 | import torch.nn as nn 6 | from fvcore.nn.weight_init import c2_msra_fill 7 | 8 | 9 | def init_weights(model, fc_init_std=0.01, zero_init_final_bn=True): 10 | """ 11 | Performs ResNet style weight initialization. 12 | Args: 13 | fc_init_std (float): the expected standard deviation for fc layer. 14 | zero_init_final_bn (bool): if True, zero initialize the final bn for 15 | every bottleneck. 16 | """ 17 | for m in model.modules(): 18 | if isinstance(m, nn.Conv3d): 19 | """ 20 | Follow the initialization method proposed in: 21 | {He, Kaiming, et al. 22 | "Delving deep into rectifiers: Surpassing human-level 23 | performance on imagenet classification." 
24 |             arXiv preprint arXiv:1502.01852 (2015)}
25 |             """
26 |             c2_msra_fill(m)
27 |         elif isinstance(m, nn.BatchNorm3d):
28 |             if (
29 |                 hasattr(m, "transform_final_bn")
30 |                 and m.transform_final_bn
31 |                 and zero_init_final_bn
32 |             ):
33 |                 batchnorm_weight = 0.0
34 |             else:
35 |                 batchnorm_weight = 1.0
36 |             if m.weight is not None:
37 |                 m.weight.data.fill_(batchnorm_weight)
38 |             if m.bias is not None:
39 |                 m.bias.data.zero_()
40 |         if isinstance(m, nn.Linear):
41 |             m.weight.data.normal_(mean=0.0, std=fc_init_std)
42 |             if m.bias is not None:
43 |                 m.bias.data.zero_()
44 |
--------------------------------------------------------------------------------
/third_party/TimeSformer/timesformer/visualization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 |
--------------------------------------------------------------------------------
/utils/geometry.py:
--------------------------------------------------------------------------------
1 | '''
2 | Tools / utilities / helper methods pertaining to camera projections and other 3D stuff.
3 | Created by Basile Van Hoorick for TCOW.
4 | '''
5 |
6 | import os
7 | import sys
8 | sys.path.append(os.getcwd())
9 | sys.path.append(os.path.join(os.getcwd(), 'utils/'))
10 |
11 | from __init__ import *
12 |
13 | # Library imports.
14 | import numpy as np
15 |
16 |
17 | def box_to_tf_matrix(box, rows):
18 |     '''
19 |     :param box (8, 3) array: All corners in XYZ space of 3D cube surrounding object.
20 |     :return (tf_matrix, rows).
21 |         tf_matrix (4, 4) array: Coordinate transformation matrix from object space to world space.
22 |         rows (3) array: Indices of rows in box that form an edge with the first row (= origin).
23 |     '''
24 |     # We make minimal assumptions about the ordering of the 3D points, except that the first two
25 |     # rows must make up an edge of the box. Then, we look for the next two orthogonal vectors.
26 |     origin = box[0]
27 |
28 |     if rows is None:
29 |         axis1 = box[1] - origin
30 |         axis2 = None
31 |         axis3 = None
32 |         row1 = 1
33 |         row2 = None
34 |         row3 = None
35 |
36 |         for i in range(2, 8):
37 |             cand_axis = box[i] - origin
38 |             if axis2 is None:
39 |                 if np.abs(np.dot(axis1, cand_axis)) < 1e-7:
40 |                     axis2 = cand_axis
41 |                     row2 = i
42 |             elif axis3 is None:
43 |                 if np.abs(np.dot(axis1, cand_axis)) < 1e-7 and np.abs(np.dot(axis2, cand_axis)) < 1e-7:
44 |                     axis3 = cand_axis
45 |                     row3 = i
46 |
47 |         assert axis2 is not None and axis3 is not None, \
48 |             'Could not find orthogonal axes for object_box'
49 |         rows = np.array([row1, row2, row3])
50 |
51 |     else:
52 |         axis1 = box[rows[0]] - origin
53 |         axis2 = box[rows[1]] - origin
54 |         axis3 = box[rows[2]] - origin
55 |
56 |     object_to_world = np.stack([axis1, axis2, axis3, origin], axis=1)
57 |     object_to_world = np.concatenate([object_to_world, [[0.0, 0.0, 0.0, 1.0]]], axis=0)
58 |     # Sanity check while debugging: origin + axis1 must be close to object_to_world @ [1, 0, 0, 1].
59 |     # NOTE: object_to_world is generally not orthonormal, because the axis lengths follow the size
60 |     # of the container box, not unit vectors.
61 |
62 |     return (object_to_world, rows)
63 |
64 |
65 | def get_containment_fraction_approx(inside_box, outside_box):
66 |     '''
67 |     Calculates a sampled approximation of how much volume of a non-aligned 3D bounding box of a
68 |     candidate object intersects (i.e. is inside of) that of a reference object.
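    (Editor-added note: the implementation below samples a 6 x 6 x 6 grid of points in the
    candidate box, maps them into the reference box's unit coordinate frame, and returns the
    fraction of points that land inside [0, 1]^3.)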
69 |     :param inside_box (8, 3) array of float: All corners in XYZ space of candidate containee cube.
70 |     :param outside_box (8, 3) array of float: All corners in XYZ space of reference container cube.
71 |     :return cf_approx (float): Approximate fraction of the candidate box volume that is inside the reference box.
72 |     '''
73 |     # NEW: Work with sampling. This is kind of brute-force, but at least it is simple and correct.
74 |     # https://stackoverflow.com/questions/1827489/numpy-meshgrid-in-3d
75 |     (x, y, z) = np.meshgrid(np.linspace(0, 1, 6), np.linspace(0, 1, 6), np.linspace(0, 1, 6),
76 |                             indexing='ij')
77 |     xyz = np.stack([x, y, z], axis=-1).reshape((-1, 3))  # (216, 3).
78 |     xyz_homo = np.concatenate([xyz, np.ones((xyz.shape[0], 1))], axis=1)  # (216, 4).
79 |
80 |     # Study the inside box in the coordinate system of the outside box.
81 |     (outside_to_world, rows) = box_to_tf_matrix(outside_box, None)
82 |     (inside_to_world, rows) = box_to_tf_matrix(inside_box, None)
83 |     world_to_outside = np.linalg.inv(outside_to_world)
84 |     inside_to_outside = np.matmul(world_to_outside, inside_to_world)
85 |     # NOTE: Unlike outside_to_world, world_to_outside (and inside_to_outside) are generally not
86 |     # even orthogonal, because outside_to_world is not orthonormal!
87 |
88 |     xyz_warped = np.matmul(inside_to_outside, xyz_homo.T).T
89 |     assert np.all(np.abs(xyz_warped[..., -1] - 1.0) < 1e-5), \
90 |         'Homogeneous coordinate is not 1'
91 |     xyz_warped = xyz_warped[..., 0:3]
92 |     points_contained = np.logical_and(np.all(xyz_warped >= 0.0, axis=1),
93 |                                       np.all(xyz_warped <= 1.0, axis=1))
94 |     cf_approx = np.mean(points_contained.astype(np.float32))
95 |
96 |     return cf_approx
97 |
--------------------------------------------------------------------------------