├── tapnet ├── tapvid │ ├── requirements.txt │ ├── __init__.py │ ├── download_kinetics_videos.sh │ ├── visualize.py │ ├── generate_tapvid.py │ └── README.md ├── models │ ├── __init__.py │ ├── tsm_utils.py │ └── tapnet_model.py ├── torch │ └── __init__.py ├── trajan │ ├── __init__.py │ └── attention.py ├── utils │ ├── __init__.py │ ├── optimizers.py │ ├── transforms.py │ ├── index_utils.py │ ├── ssm_utils.py │ └── experiment_utils.py ├── tapvid3d │ ├── splits │ │ └── __init__.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── run_evaluate_model.sh │ │ └── evaluate_model.py │ ├── annotation_generation │ │ ├── __init__.py │ │ ├── generate_drivetrack.py │ │ ├── gcs_utils.py │ │ ├── generate_adt.py │ │ ├── generate_pstudio.py │ │ └── adt_utils.py │ ├── generate_all.sh │ └── README.md ├── __init__.py ├── training │ ├── README.md │ └── task.py ├── tapnext │ ├── pscan.py │ ├── torch_losses.py │ ├── losses.py │ ├── tapnext_benchmark_pytorch.ipynb │ ├── tapnext_torch_utils.py │ └── tapnext_torch.py ├── live_demo.py └── pytorch_live_demo.py ├── assets └── tapir_benchmark.png ├── requirements_inference.txt ├── requirements.txt ├── .gitignore ├── CONTRIBUTING.md ├── pyproject.toml ├── colabs └── tapir_clustering.ipynb └── configs ├── tapnet_config.py ├── tapir_config.py ├── causal_tapir_config.py └── tapir_bootstrap_config.py /tapnet/tapvid/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | ffmpeg 3 | ffmpeg-python 4 | mediapy 5 | numpy 6 | -------------------------------------------------------------------------------- /assets/tapir_benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/tapnet/HEAD/assets/tapir_benchmark.png -------------------------------------------------------------------------------- /requirements_inference.txt: -------------------------------------------------------------------------------- 1 | chex 2 | jax 3 | jaxline 4 | optax 5 | dm-haiku 6 | dm-tree 7 | typing_extensions 8 | matplotlib 9 | mediapy 10 | opencv-python 11 | einshape 12 | ipympl 13 | tqdm 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | notebook 3 | jupyter_http_over_ws 4 | chex 5 | jax 6 | jaxline 7 | optax 8 | dm-haiku 9 | dm-tree 10 | typing_extensions 11 | matplotlib 12 | mediapy 13 | opencv-python 14 | einshape 15 | einops 16 | tensorflow 17 | tensorflow-datasets 18 | tensorflow_graphics 19 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | -------------------------------------------------------------------------------- /tapnet/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | 17 | -------------------------------------------------------------------------------- /tapnet/tapvid/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | 17 | -------------------------------------------------------------------------------- /tapnet/torch/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | 17 | -------------------------------------------------------------------------------- /tapnet/trajan/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | 17 | -------------------------------------------------------------------------------- /tapnet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | 17 | -------------------------------------------------------------------------------- /tapnet/tapvid3d/splits/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | 17 | -------------------------------------------------------------------------------- /tapnet/tapvid3d/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | 17 | -------------------------------------------------------------------------------- /tapnet/tapvid3d/annotation_generation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | 17 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | ## Contributor License Agreement 4 | 5 | Contributions to this project must be accompanied by a Contributor License 6 | Agreement. You (or your employer) retain the copyright to your contribution, 7 | this simply gives us permission to use and redistribute your contributions as 8 | part of the project. Head over to to see 9 | your current agreements on file or to sign a new one. 10 | 11 | You generally only need to submit a CLA once, so if you've already submitted one 12 | (even if it was for a different project), you probably don't need to do it 13 | again. 14 | 15 | ## Code reviews 16 | 17 | All submissions, including submissions by project members, require review. We 18 | use GitHub pull requests for this purpose. Consult 19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 20 | information on using pull requests. 21 | 22 | ## Community Guidelines 23 | 24 | This project follows [Google's Open Source Community 25 | Guidelines](https://opensource.google/conduct/). 26 | -------------------------------------------------------------------------------- /tapnet/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Legacy API for TAP. Prefer importing from project subfolders.""" 17 | 18 | from tapnet.models import tapir_model # pylint:disable=g-importing-member 19 | from tapnet.models import tapnet_model # pylint:disable=g-importing-member 20 | from tapnet.robotap import tapir_clustering # pylint:disable=g-importing-member 21 | from tapnet.tapvid import evaluation_datasets # pylint:disable=g-importing-member 22 | -------------------------------------------------------------------------------- /tapnet/tapvid/download_kinetics_videos.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2025 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | 18 | git clone https://github.com/cvdfoundation/kinetics-dataset.git 19 | mkdir kinetics_videos 20 | cd kinetics_videos 21 | wget https://s3.amazonaws.com/kinetics/700_2020/val/k700_2020_val_path.txt 22 | bash ../kinetics-dataset/download.sh k700_2020_val_path.txt 23 | bash ../kinetics-dataset/extract.sh k700_2020_val_path.txt 24 | rm -f k700_val_* 25 | rm -f k700_2020_val_path.txt 26 | cd .. 27 | rm -rf kinetics-dataset 28 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "tapnet" 3 | version = "0.1.0" 4 | description = "Tracking-Any-Point codebase from Google DeepMind." 5 | dependencies = [ 6 | "chex", 7 | "jax", 8 | "jaxline", 9 | "optax", 10 | "dm-haiku", 11 | "dm-tree", 12 | "typing_extensions", 13 | "matplotlib", 14 | "mediapy", 15 | "opencv-python", 16 | "einshape", 17 | "ipympl", 18 | "tqdm", 19 | ] 20 | 21 | [project.optional-dependencies] 22 | train = [ 23 | "absl-py", 24 | "notebook", 25 | "jupyter_http_over_ws", 26 | "tensorflow", 27 | "tensorflow-datasets", 28 | "tensorflow_graphics", 29 | "kubric@git+https://github.com/google-research/kubric", 30 | "recurrentgemma@git+https://github.com/google-deepmind/recurrentgemma" 31 | ] 32 | 33 | torch = [ 34 | "torch", 35 | "torchvision", 36 | ] 37 | tapvid3d_eval = [ 38 | "einops>=0.8.0", 39 | "numpy>=1.25.2", 40 | "absl-py>=2.1.0", 41 | "tqdm>=4.66.4", 42 | "pillow>=9.4.0", 43 | ] 44 | tapvid3d_generation = [ 45 | "absl-py==2.1.0", 46 | "tqdm==4.66.4", 47 | "absl-py==2.1.0", 48 | "tqdm==4.66.4", 49 | "tensorflow==2.15.0", 50 | "numpy==1.25.2", 51 | "pillow==9.4.0", 52 | "projectaria-tools==1.5.2", 53 | "visu3d==1.5.3", 54 | "torch==2.3.0", 55 | "torchvision==0.18.0", 56 | "etils==1.7.0", 57 | "tensorflow-datasets", 58 | ] 59 | 60 | [tool.setuptools.packages.find] 61 | where = ["."] # list of folders that contain the packages (["."] by default) 62 | include = ["tapnet*"] # package names should match these glob patterns (["*"] by default) 63 | 64 | -------------------------------------------------------------------------------- /tapnet/tapvid3d/generate_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2025 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | 17 | 18 | set -x 19 | 20 | debug=0 21 | 22 | # Parse debug option 23 | while [[ "$#" -gt 0 ]]; do 24 | case "$1" in 25 | -d|--debug) 26 | debug=1 27 | ;; 28 | *) 29 | echo "Unknown option: $1" 30 | ;; 31 | esac 32 | shift 33 | done 34 | 35 | if [[ $debug -eq 1 ]]; then 36 | PYTHON_DEBUG="True" 37 | else 38 | PYTHON_DEBUG="False" 39 | fi 40 | 41 | python3 -m venv tapvid3d 42 | source tapvid3d/bin/activate 43 | 44 | # Download the ADT data and annotations 45 | ADT_OUTPUT_DIRECTORY="tapvid3d_dataset/adt/" 46 | mkdir -p $ADT_OUTPUT_DIRECTORY 47 | python3 -m tapnet.tapvid3d.annotation_generation.generate_adt --output_dir=$ADT_OUTPUT_DIRECTORY --debug=$PYTHON_DEBUG --split=all 48 | 49 | # Download the Panoptic Studio data and annotations 50 | PSTUDIO_OUTPUT_DIRECTORY="tapvid3d_dataset/pstudio/" 51 | mkdir -p $PSTUDIO_OUTPUT_DIRECTORY 52 | python3 -m tapnet.tapvid3d.annotation_generation.generate_pstudio --output_dir=$PSTUDIO_OUTPUT_DIRECTORY --debug=$PYTHON_DEBUG --split=all 53 | 54 | # Download the Waymo Open / DriveTrack data and annotations 55 | DT_OUTPUT_DIRECTORY="tapvid3d_dataset/drivetrack/" 56 | mkdir -p $DT_OUTPUT_DIRECTORY 57 | python3 -m tapnet.tapvid3d.annotation_generation.generate_drivetrack --output_dir=$DT_OUTPUT_DIRECTORY --debug=$PYTHON_DEBUG --split=all 58 | 59 | deactivate 60 | -------------------------------------------------------------------------------- /tapnet/tapvid3d/annotation_generation/generate_drivetrack.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Download the Waymo Open + DriveTrack videos and annotations. 17 | 18 | The *.npz files in the GCS bucket for the DriveTrack split contains both 19 | the preprocessed annotations and source videos, so all we need to do is 20 | bulk download them to the local machine. 21 | """ 22 | 23 | from collections.abc import Sequence 24 | import os 25 | 26 | from absl import app 27 | from absl import flags 28 | from tapnet.tapvid3d.annotation_generation import gcs_utils 29 | 30 | 31 | _OUTPUT_DIR = flags.DEFINE_string( 32 | "output_dir", 33 | "tapvid3d_dataset/drivetrack/", 34 | "Path to folder to store output npz files containing all fields.", 35 | ) 36 | 37 | _DEBUG = flags.DEFINE_boolean( 38 | "debug", 39 | False, 40 | "Whether to run in debug mode, downloads only one video.", 41 | ) 42 | 43 | _SPLIT = flags.DEFINE_enum( 44 | "split", 45 | "all", 46 | ["minival", "full_eval", "all"], 47 | """ 48 | If True, compute metrics on the minival split; 49 | otherwise uses the full_eval split. 
50 | """, 51 | ) 52 | 53 | 54 | def main(argv: Sequence[str]) -> None: 55 | if len(argv) > 1: 56 | raise app.UsageError("Too many command-line arguments.") 57 | 58 | if not os.path.exists(_OUTPUT_DIR.value): 59 | os.makedirs(_OUTPUT_DIR.value) 60 | 61 | gcs_utils.download_tapvid3d_files( 62 | _OUTPUT_DIR.value, _SPLIT.value, "drivetrack", _DEBUG.value 63 | ) 64 | 65 | 66 | if __name__ == "__main__": 67 | app.run(main) 68 | -------------------------------------------------------------------------------- /tapnet/tapvid/visualize.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Visualize frames of a random video of the given dataset.""" 17 | 18 | import io 19 | import pickle 20 | import random 21 | from typing import Sequence 22 | 23 | from absl import app 24 | from absl import flags 25 | from absl import logging 26 | import mediapy as media 27 | import numpy as np 28 | from PIL import Image 29 | 30 | from tapnet.utils import viz_utils 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | flags.DEFINE_string( 35 | 'input_path', None, 'Path to the pickle file.', required=True 36 | ) 37 | flags.DEFINE_string( 38 | 'output_path', None, 'Path to the output mp4 video.', required=True 39 | ) 40 | 41 | 42 | def main(argv: Sequence[str]) -> None: 43 | del argv 44 | 45 | logging.info('Loading data from %s. This takes time.', FLAGS.input_path) 46 | with open(FLAGS.input_path, 'rb') as f: 47 | data = pickle.load(f) 48 | if isinstance(data, dict): 49 | data = list(data.values()) 50 | 51 | idx = random.randint(0, len(data) - 1) 52 | video = data[idx] 53 | 54 | frames = video['video'] 55 | 56 | if isinstance(frames[0], bytes): 57 | # Tapnet is stored and JPEG bytes rather than `np.ndarray`s. 
58 | def decode(frame): 59 | byteio = io.BytesIO(frame) 60 | img = Image.open(byteio) 61 | return np.array(img) 62 | 63 | frames = np.array([decode(frame) for frame in frames]) 64 | 65 | if frames.shape[1] > 360: 66 | frames = media.resize_video(frames, (360, 640)) 67 | 68 | scale_factor = np.array(frames.shape[2:0:-1])[np.newaxis, np.newaxis, :] 69 | painted_frames = viz_utils.paint_point_track( 70 | frames, 71 | video['points'] * scale_factor, 72 | ~video['occluded'], 73 | ) 74 | 75 | media.write_video(FLAGS.output_path, painted_frames, fps=25) 76 | logging.info('Examplar point visualization saved to %s', FLAGS.output_path) 77 | 78 | 79 | if __name__ == '__main__': 80 | app.run(main) 81 | -------------------------------------------------------------------------------- /tapnet/tapvid3d/evaluation/run_evaluate_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2025 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | 18 | # Expects: 19 | # (1) to be run from this directory! `./run_evaluate_model.sh`` 20 | # (2) that a copy of the generated dataset is in TAPVID3D_DIR, defined below 21 | # and should be modified as necessary, depending on where you stored the 22 | # generated dataset. Format is as described in the README.md, in summary: 23 | # a folder with drivetrack/, adt/, pstudio/ subfolders, each containing 24 | # *.npz files, one per video containing ground truth 3D tracks and other 25 | # metadata. 26 | 27 | TAPVID3D_DIR="/tmp/datasets/tapvid3d/" 28 | 29 | python3 -m venv eval_tapvid3d 30 | source eval_tapvid3d/bin/activate 31 | pip install absl-py==2.1.0 tqdm==4.66.4 numpy==2.0.0 pillow=10.4.0 32 | 33 | cd ../.. 34 | pip install . 35 | cd tapvid3d 36 | 37 | # First, download the SpaTracker predictions on the DriveTrack subset from GCP 38 | PREDICTIONS_DIR="/tmp/model_outputs/spatracker/" 39 | mkdir -p "$PREDICTIONS_DIR" 40 | 41 | # Get the DriveTrack filenames 42 | PYTHON_SCRIPT=$(cat < Callable[[str, str, jnp.ndarray], bool]: 30 | """Logic for deciding which parameters to include for weight decay.. 31 | 32 | Args: 33 | exclude_names: an optional list of names to include for weight_decay. ['w'] 34 | by default. 35 | 36 | Returns: 37 | A predicate that returns False for params that need to be excluded from 38 | weight_decay. 39 | """ 40 | # By default weight_decay the weights but not the biases. 41 | if exclude_names is None: 42 | exclude_names = ["b"] 43 | 44 | def include(module_name: Text, name: Text, value: jnp.ndarray): 45 | del value 46 | # Do not weight decay the parameters of normalization blocks. 
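    # A parameter is treated as part of a normalization block if its Haiku
    # module path contains any of the substrings in NORM_NAMES; those
    # parameters are always excluded from weight decay, and everything else is
    # decayed unless its parameter name appears in `exclude_names` (just the
    # biases by default).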
47 | if any([norm_name in module_name for norm_name in NORM_NAMES]): 48 | return False 49 | else: 50 | return name not in exclude_names 51 | 52 | return include 53 | 54 | 55 | class AddWeightDecayState(NamedTuple): 56 | """Stateless transformation.""" 57 | 58 | 59 | def add_weight_decay( 60 | weight_decay: float, 61 | exclude_names: Optional[Sequence[Text]] = None, 62 | ) -> optax.GradientTransformation: 63 | """Add parameter scaled by `weight_decay` to the `updates`. 64 | 65 | Same as optax.add_decayed_weights but can exclude some parameters. 66 | 67 | Args: 68 | weight_decay: weight_decay coefficient. 69 | exclude_names: an optional list of names to exclude for weight_decay. ['b'] 70 | by default. 71 | 72 | Returns: 73 | An (init_fn, update_fn) tuple. 74 | """ 75 | 76 | def init_fn(_): 77 | return AddWeightDecayState() 78 | 79 | def update_fn(updates, state, params): 80 | include = _weight_decay_exclude(exclude_names=exclude_names) 81 | 82 | u_in, u_ex = hk.data_structures.partition(include, updates) 83 | p_in, _ = hk.data_structures.partition(include, params) 84 | u_in = jax.tree_util.tree_map(lambda g, p: g + weight_decay * p, u_in, p_in) 85 | updates = hk.data_structures.merge(u_ex, u_in) 86 | return updates, state 87 | 88 | return optax.GradientTransformation(init_fn, update_fn) 89 | -------------------------------------------------------------------------------- /tapnet/training/README.md: -------------------------------------------------------------------------------- 1 | # TAP training setup 2 | 3 | This directory contains a reference implementation for training and evaluating TAP-Net and TAPIR using Kubric, following our papers. 4 | 5 | ## Installation 6 | 7 | Install ffmpeg on your machine: 8 | 9 | ```sudo apt update``` 10 | 11 | ```sudo apt install ffmpeg``` 12 | 13 | Install OpenEXR: 14 | 15 | ```sudo apt-get install libopenexr-dev``` 16 | 17 | Clone the repository: 18 | 19 | ```git clone https://github.com/deepmind/tapnet.git``` 20 | 21 | Add current path (parent directory of where TapNet is installed) 22 | to ```PYTHONPATH```: 23 | 24 | ```export PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH``` 25 | 26 | Switch to the project directory: 27 | 28 | ```cd tapnet``` 29 | 30 | Install kubric as a subdirectory: 31 | 32 | ```git clone https://github.com/google-research/kubric.git``` 33 | 34 | Install requirements: 35 | 36 | ```pip install -r requirements.txt``` 37 | 38 | If you want to use CUDA, make sure you install the drivers and a version 39 | of JAX that's compatible with your CUDA and CUDNN versions. 40 | Refer to 41 | [the jax manual](https://github.com/jax-ml/jax#installation) 42 | to install the correct JAX version with CUDA. 43 | 44 | ## Training 45 | 46 | The configuration file is located at: ```./tapnet/configs/tapnet_config.py```. 47 | 48 | You can modify it for your need or create your own config file following 49 | the example of ```tapnet_config.py```. 50 | 51 | To launch experiment run the command: 52 | 53 | ```python3 -m tapnet.training.experiment --config ./tapnet/configs/tapnet_config.py``` 54 | 55 | or 56 | 57 | ```python3 -m tapnet.training.experiment --config ./tapnet/configs/tapir_config.py``` 58 | 59 | ## Evaluation 60 | 61 | You can run evaluation for a particular dataset (i.e. 
tapvid_davis) using the command: 62 | 63 | ```bash 64 | python3 -m tapnet.training.experiment \ 65 | --config=./tapnet/configs/tapir_config.py \ 66 | --jaxline_mode=eval_davis_points \ 67 | --config.checkpoint_dir=./tapnet/checkpoint/ \ 68 | --config.experiment_kwargs.config.davis_points_path=/path/to/tapvid_davis.pkl 69 | ``` 70 | 71 | Available eval datasets are listed in `supervised_point_prediction.py`. 72 | 73 | ## Inference 74 | 75 | You can run inference for a particular video (i.e. horsejump-high.mp4) using the command: 76 | 77 | ```bash 78 | python3 -m tapnet.training.experiment \ 79 | --config=./tapnet/configs/tapnet_config.py \ 80 | --jaxline_mode=eval_inference \ 81 | --config.checkpoint_dir=./tapnet/checkpoint/ \ 82 | --config.experiment_kwargs.config.inference.input_video_path=horsejump-high.mp4 \ 83 | --config.experiment_kwargs.config.inference.output_video_path=result.mp4 \ 84 | --config.experiment_kwargs.config.inference.resize_height=256 \ 85 | --config.experiment_kwargs.config.inference.resize_width=256 \ 86 | --config.experiment_kwargs.config.inference.num_points=20 87 | ``` 88 | 89 | The inference only serves as an example. It will resize the video to 256x256 resolution, sample 20 random query points on the first frame and track these random points in the rest frames. 90 | 91 | Note that this uses jaxline for model inference if you are training your own model. A more direct way for model inference can be found on the [colab and real-time demos](#tapir-demos). 92 | -------------------------------------------------------------------------------- /tapnet/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utilities for transforming image coordinates.""" 17 | 18 | from typing import Sequence 19 | 20 | import chex 21 | import numpy as np 22 | 23 | 24 | def convert_grid_coordinates( 25 | coords: chex.Array, 26 | input_grid_size: Sequence[int], 27 | output_grid_size: Sequence[int], 28 | coordinate_format: str = 'xy', 29 | ) -> chex.Array: 30 | """Convert image coordinates between image grids of different sizes. 31 | 32 | By default, it assumes that the image corners are aligned. Therefore, 33 | it adds .5 (since (0,0) is assumed to be the center of the upper-left grid 34 | cell), multiplies by the size ratio, and then subtracts .5. 35 | 36 | Args: 37 | coords: The coordinates to be converted. It is of shape [..., 2] if 38 | coordinate_format is 'xy' or [..., 3] if coordinate_format is 'tyx'. 39 | input_grid_size: The size of the image/grid that the coordinates currently 40 | are with respect to. This is a 2-tuple of the format [width, height] 41 | if coordinate_format is 'xy' or a 3-tuple of the format 42 | [num_frames, height, width] if coordinate_format is 'tyx'. 
43 | output_grid_size: The size of the target image/grid that you want the 44 | coordinates to be with respect to. This is a 2-tuple of the format 45 | [width, height] if coordinate_format is 'xy' or a 3-tuple of the format 46 | [num_frames, height, width] if coordinate_format is 'tyx'. 47 | coordinate_format: Which format the coordinates are in. This can be one 48 | of 'xy' (the default) or 'tyx', which are the only formats used in this 49 | project. 50 | 51 | Returns: 52 | The transformed coordinates, of the same shape as coordinates. 53 | 54 | Raises: 55 | ValueError: if coordinates don't match the given format. 56 | """ 57 | if isinstance(input_grid_size, tuple): 58 | input_grid_size = np.array(input_grid_size) 59 | if isinstance(output_grid_size, tuple): 60 | output_grid_size = np.array(output_grid_size) 61 | 62 | if coordinate_format == 'xy': 63 | if input_grid_size.shape[0] != 2 or output_grid_size.shape[0] != 2: 64 | raise ValueError( 65 | 'If coordinate_format is xy, the shapes must be length 2.') 66 | elif coordinate_format == 'tyx': 67 | if input_grid_size.shape[0] != 3 or output_grid_size.shape[0] != 3: 68 | raise ValueError( 69 | 'If coordinate_format is tyx, the shapes must be length 3.') 70 | if input_grid_size[0] != output_grid_size[0]: 71 | raise ValueError('converting frame count is not supported.') 72 | else: 73 | raise ValueError('Recognized coordinate formats are xy and tyx.') 74 | 75 | position_in_grid = coords 76 | position_in_grid = position_in_grid * output_grid_size / input_grid_size 77 | 78 | return position_in_grid 79 | -------------------------------------------------------------------------------- /tapnet/tapvid3d/annotation_generation/gcs_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Utils for downloading TAPVid3d data from GCS.""" 17 | 18 | import os 19 | import sys 20 | from absl import logging 21 | import requests 22 | import tqdm 23 | 24 | current_dir = os.path.dirname(os.path.abspath(__file__)) 25 | top_level_dir = os.path.abspath(os.path.join(current_dir, "..")) 26 | sys.path.insert(0, top_level_dir) 27 | from tapnet.tapvid3d.splits import tapvid3d_splits # pylint: disable=g-import-not-at-top, g-bad-import-order 28 | 29 | TAPVID3D_GCS_URL = ( 30 | "https://storage.googleapis.com/dm-tapnet/tapvid3d/release_files/v1.0" 31 | ) 32 | 33 | 34 | def download_tapvid3d_files( 35 | output_dir: str, split: str, subset: str, debug: bool 36 | ): 37 | """Downloads files from the given split and subset.""" 38 | os.makedirs(output_dir, exist_ok=True) 39 | if split == "minival": 40 | filenames_to_download = tapvid3d_splits.get_minival_files(subset) 41 | elif split == "full_eval": 42 | filenames_to_download = tapvid3d_splits.get_full_eval_files(subset) 43 | elif split == "all": 44 | filenames_to_download = tapvid3d_splits.get_all_files(subset) 45 | else: 46 | raise ValueError(f"Unknown split: {split}") 47 | 48 | logging.info("Downloading %s split", split) 49 | for filename in tqdm.tqdm( 50 | filenames_to_download, total=len(filenames_to_download) 51 | ): 52 | local_path = os.path.join(output_dir, filename) 53 | gcs_url = get_tapvid3d_gcs_urls( 54 | filenames=filename, 55 | url_postfix=subset, 56 | ) 57 | logging.info("Downloading %s to %s", gcs_url, local_path) 58 | download_file_from_url(input_url=gcs_url, output_filepath=local_path) 59 | if debug: 60 | logging.info("Stopping after one video, debug run.") 61 | break 62 | logging.info("Finished downloading all examples!") 63 | 64 | 65 | def download_file_from_url(input_url: str, output_filepath: str): 66 | """Download the GCS file from the given URL to the given output path.""" 67 | if os.path.exists(output_filepath): 68 | logging.info("Skipping download, file already exists: %s", output_filepath) 69 | return 70 | response = requests.get(input_url, stream=True) 71 | if response.status_code == 200: 72 | logging.info("Downloading: %s to %s", input_url, output_filepath) 73 | with open(output_filepath, "wb") as f: 74 | for chunk in response.iter_content(chunk_size=1024): 75 | f.write(chunk) 76 | logging.info("Downloaded!") 77 | else: 78 | logging.info("Download failed. HTTP Status Code: %d", response.status_code) 79 | 80 | 81 | def get_tapvid3d_gcs_urls( 82 | filenames: str | list[str], url_postfix: str = "" 83 | ) -> str | list[str]: 84 | if isinstance(filenames, str): 85 | return f"{TAPVID3D_GCS_URL}/{url_postfix}/{filenames}" 86 | else: 87 | return [ 88 | f"{TAPVID3D_GCS_URL}/{url_postfix}/{filename}" for filename in filenames 89 | ] 90 | -------------------------------------------------------------------------------- /tapnet/tapnext/pscan.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Parallel Scan Operation.""" 17 | 18 | import torch 19 | 20 | 21 | def safe_div(numerator, denominator): 22 | return torch.where( 23 | torch.abs(denominator) < 1e-5, 24 | numerator * 100000.0, 25 | numerator / denominator, 26 | ) 27 | 28 | 29 | class PScan(torch.autograd.Function): 30 | """Implements a parallel scan operation. 31 | 32 | Given A is (N, T, D) and X is (N, T, D), expands A and X in-place in O(T), 33 | and O(log(T)) if not core-bounded, so that: 34 | Y[:, 0] = Y_init 35 | Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t] 36 | can be computed as: 37 | Y[:, t] = A[:, t] * Y_init + X[:, t] 38 | """ 39 | 40 | @classmethod 41 | def expand(cls, weights, bias): 42 | if weights.size(1) == 1: 43 | return 44 | t_even = 2 * (weights.size(1) // 2) 45 | 46 | w_pairs = weights[:, :t_even].view(weights.size(0), t_even // 2, 2, -1) 47 | b_pairs = bias[:, :t_even].view(bias.size(0), t_even // 2, 2, -1) 48 | 49 | b_pairs[:, :, 1].add_(w_pairs[:, :, 1] * b_pairs[:, :, 0]) 50 | w_pairs[:, :, 1].mul_(w_pairs[:, :, 0]) 51 | 52 | PScan.expand(w_pairs[:, :, 1], b_pairs[:, :, 1]) 53 | 54 | b_pairs[:, 1:, 0].add_(w_pairs[:, 1:, 0] * b_pairs[:, :-1, 1]) 55 | w_pairs[:, 1:, 0].mul_(w_pairs[:, :-1, 1]) 56 | 57 | if t_even < weights.size(1): 58 | bias[:, -1].add_(weights[:, -1] * bias[:, -2]) 59 | weights[:, -1].mul_(weights[:, -2]) 60 | 61 | @classmethod 62 | def accrev(cls, tensor): 63 | if tensor.size(1) == 1: 64 | return 65 | t_even = 2 * (tensor.size(1) // 2) 66 | 67 | pairs = tensor[:, -t_even:].view(tensor.size(0), t_even // 2, 2, -1) 68 | 69 | pairs[:, :, 0].add_(pairs[:, :, 1]) 70 | PScan.accrev(pairs[:, :, 0]) 71 | pairs[:, :-1, 1].add_(pairs[:, 1:, 0]) 72 | 73 | if t_even < tensor.size(1): 74 | tensor[:, 0].add_(tensor[:, 1]) 75 | 76 | @classmethod 77 | def forward(cls, ctx, weights, bias, y_init): 78 | ctx.weights_orig = weights.clone() 79 | ctx.y_init_expanded = y_init[:, None, :].clone() 80 | ctx.weights_expanded = weights.clone() 81 | ctx.bias_expanded = bias.clone() 82 | 83 | PScan.expand(ctx.weights_expanded, ctx.bias_expanded) 84 | output = ctx.weights_expanded * ctx.y_init_expanded + ctx.bias_expanded 85 | return output 86 | 87 | @classmethod 88 | def backward(cls, ctx, grad_output): 89 | grad_input_wrt_output = grad_output * ctx.weights_expanded 90 | grad_accumulated = grad_input_wrt_output.clone() 91 | 92 | PScan.accrev(grad_accumulated) 93 | 94 | grad_weights = safe_div(ctx.y_init_expanded, ctx.weights_orig) 95 | grad_weights[:, 1:].add_( 96 | safe_div(ctx.bias_expanded[:, :-1], ctx.weights_expanded[:, 1:]) 97 | ) 98 | 99 | grad_bias = safe_div(grad_accumulated, ctx.weights_expanded) 100 | grad_y_init = grad_input_wrt_output.sum(dim=1) 101 | 102 | return grad_weights * grad_accumulated, grad_bias, grad_y_init 103 | 104 | 105 | pscan = PScan.apply 106 | -------------------------------------------------------------------------------- /tapnet/tapvid3d/annotation_generation/generate_adt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Download the ADT query points and compute the remaining annotations. 17 | 18 | The *.npz files in the GCS bucket for the ADT split contain only the query 19 | points. This script loads the ADT scenes from the GCS bucket, computes the 20 | remaining annotations and saves them to *.npz files in the local output folder. 21 | """ 22 | 23 | import collections 24 | import glob 25 | import os 26 | from typing import Sequence 27 | 28 | from absl import app 29 | from absl import flags 30 | from tapnet.tapvid3d.annotation_generation import adt_utils 31 | from tapnet.tapvid3d.annotation_generation import gcs_utils 32 | import tqdm 33 | 34 | 35 | _OUTPUT_DIR = flags.DEFINE_string( 36 | "output_dir", 37 | "tapvid3d_dataset/adt/", 38 | "Path to folder to store output npz files containing all fields.", 39 | ) 40 | 41 | _DEBUG = flags.DEFINE_boolean( 42 | "debug", 43 | False, 44 | "Whether to run in debug mode, downloads only one video.", 45 | ) 46 | 47 | _SPLIT = flags.DEFINE_enum( 48 | "split", 49 | "all", 50 | ["minival", "full_eval", "all"], 51 | """ 52 | If True, compute metrics on the minival split; 53 | otherwise uses the full_eval split. 54 | """, 55 | ) 56 | 57 | _ADT_BASE_PATH = flags.DEFINE_string( 58 | "adt_base_path", 59 | "", 60 | "Path to folder containing ADT scenes as subfolders.", 61 | ) 62 | 63 | 64 | def generate_adt_npz( 65 | adt_base_path: str, input_npz_dir: str, output_npz_dir: str 66 | ): 67 | """Generates the final ADT npz files, adding the remaining annotations.""" 68 | input_npz = sorted(glob.glob(os.path.join(input_npz_dir, "*.npz"))) 69 | done_npz = [ 70 | os.path.basename(x) 71 | for x in sorted(glob.glob(os.path.join(output_npz_dir, "*.npz"))) 72 | ] 73 | 74 | # Filter completed files. 75 | input_npz = list( 76 | filter(lambda x: os.path.basename(x) not in done_npz, input_npz) 77 | ) 78 | 79 | # Group pending files by video. 80 | pending_vid_chunks = collections.defaultdict(list) 81 | for f in input_npz: 82 | basename = os.path.basename(f)[:-4].split("_") 83 | vid = "_".join(basename[:-1]) 84 | chunk = int(basename[-1]) 85 | pending_vid_chunks[vid].append(chunk) 86 | 87 | for vid, chunks in tqdm.tqdm( 88 | pending_vid_chunks.items(), total=len(pending_vid_chunks) 89 | ): 90 | adt_utils.process_vid( 91 | adt_base_path, 92 | input_npz_dir, 93 | output_npz_dir, 94 | vid, 95 | chunks, 96 | ) 97 | 98 | 99 | def main(argv: Sequence[str]) -> None: 100 | if len(argv) > 1: 101 | raise app.UsageError("Too many command-line arguments.") 102 | tmp_adt_dir = os.path.join(_OUTPUT_DIR.value, "tmp") 103 | 104 | # Download ADT npz's containing query points only. 105 | gcs_utils.download_tapvid3d_files( 106 | tmp_adt_dir, _SPLIT.value, "adt", _DEBUG.value 107 | ) 108 | 109 | # Compute the remaining annotations and save them to *.npz files. 
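  # The raw ADT scenes must already be available locally under
  # --adt_base_path (one subfolder per scene); videos that already have an
  # output npz are skipped. A typical invocation (paths are illustrative):
  #   python3 -m tapnet.tapvid3d.annotation_generation.generate_adt \
  #     --adt_base_path=/path/to/adt_scenes \
  #     --output_dir=tapvid3d_dataset/adt/ --split=minival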
110 | generate_adt_npz(_ADT_BASE_PATH.value, tmp_adt_dir, _OUTPUT_DIR.value) 111 | 112 | 113 | if __name__ == "__main__": 114 | app.run(main) 115 | -------------------------------------------------------------------------------- /tapnet/tapvid3d/annotation_generation/generate_pstudio.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Download the PStudio annotations and join with video data. 17 | 18 | The *.npz files in the GCS bucket for the ADT split contain all annotation but 19 | don't contain the video data. This script downloads the video data and adds it 20 | to the *.npz files. 21 | """ 22 | 23 | # pylint: disable=g-import-not-at-top,g-bad-import-order 24 | import tensorflow as tf 25 | 26 | tf.config.set_visible_devices([], "GPU") 27 | 28 | import glob 29 | import os 30 | from typing import Sequence 31 | import urllib.request 32 | import zipfile 33 | 34 | from absl import app 35 | from absl import flags 36 | from tapnet.tapvid3d.annotation_generation import gcs_utils 37 | import PIL.Image 38 | import numpy as np 39 | import tqdm 40 | 41 | 42 | _OUTPUT_DIR = flags.DEFINE_string( 43 | "output_dir", 44 | "tapvid3d_dataset/pstudio/", 45 | "Path to folder to store output npz files containing all fields.", 46 | ) 47 | 48 | _DEBUG = flags.DEFINE_boolean( 49 | "debug", 50 | False, 51 | "Whether to run in debug mode, downloads only one video.", 52 | ) 53 | 54 | _SPLIT = flags.DEFINE_enum( 55 | "split", 56 | "all", 57 | ["minival", "full_eval", "all"], 58 | """ 59 | If True, compute metrics on the minival split; 60 | otherwise uses the full_eval split. 61 | """, 62 | ) 63 | 64 | _PSTUDIO_DATA = flags.DEFINE_string( 65 | "pstudio_url", 66 | "https://omnomnom.vision.rwth-aachen.de/data/Dynamic3DGaussians/data.zip", 67 | "URL of PStudio data.", 68 | ) 69 | 70 | 71 | def generate_pstudio_npz( 72 | pstudio_base_path: str, input_npz_dir: str, output_npz_dir: str 73 | ): 74 | """Generates the final PStudio npz files, adding the video data.""" 75 | input_npz = sorted(glob.glob(os.path.join(input_npz_dir, "*.npz"))) 76 | done_npz = [ 77 | os.path.basename(x) 78 | for x in sorted(glob.glob(os.path.join(output_npz_dir, "*.npz"))) 79 | ] 80 | 81 | # Filter completed files. 82 | input_npz = list( 83 | filter(lambda x: os.path.basename(x) not in done_npz, input_npz) 84 | ) 85 | 86 | # For each example, load the video data and add it to the npz file. 
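  # Each source npz already holds the annotations; here the matching PStudio
  # frames are read from `{seq}/ims/{cam_id}/*.jpg`, re-encoded as JPEG bytes
  # and stored under the `images_jpeg_bytes` key alongside the existing
  # fields.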
87 | for filename in tqdm.tqdm(input_npz): 88 | example = dict(np.load(filename, allow_pickle=True)) 89 | out_fn = os.path.join(output_npz_dir, os.path.basename(filename)) 90 | seq, cam_id = os.path.basename(filename)[:-4].split("_") 91 | # load rgb images 92 | im_fns = sorted( 93 | glob.glob(os.path.join(pstudio_base_path, f"{seq}/ims/{cam_id}/*.jpg")) 94 | ) 95 | ims = [np.array(PIL.Image.open(im_fn)) for im_fn in im_fns] 96 | ims_jpeg = [np.array(tf.io.encode_jpeg(im)).item() for im in ims] 97 | example["images_jpeg_bytes"] = ims_jpeg 98 | np.savez(out_fn, **example) 99 | 100 | 101 | def main(argv: Sequence[str]) -> None: 102 | if len(argv) > 1: 103 | raise app.UsageError("Too many command-line arguments.") 104 | tmp_pstudio_dir = os.path.join(_OUTPUT_DIR.value, "tmp") 105 | 106 | gcs_utils.download_tapvid3d_files( 107 | tmp_pstudio_dir, _SPLIT.value, "pstudio", _DEBUG.value 108 | ) 109 | # Download and extract PStudio video data. 110 | pstudio_zip = os.path.join(tmp_pstudio_dir, "data.zip") 111 | pstudio_data = os.path.join(tmp_pstudio_dir, "data") 112 | if not os.path.exists(pstudio_zip): 113 | print(f"Downloading PStudio data to {pstudio_zip}") 114 | urllib.request.urlretrieve(_PSTUDIO_DATA.value, pstudio_zip) 115 | else: 116 | print(f"Skipping download, PStudio data already exists: {pstudio_zip}") 117 | if not os.path.exists(pstudio_data): 118 | print(f"Extracting PStudio data to {pstudio_data}") 119 | with zipfile.ZipFile(pstudio_zip, "r") as zip_file: 120 | zip_file.extractall(tmp_pstudio_dir) 121 | else: 122 | print(f"Skipping extraction, PStudio data already exists: {pstudio_data}") 123 | 124 | # Compute the remaining annotations and save them to *.npz files. 125 | generate_pstudio_npz(pstudio_data, tmp_pstudio_dir, _OUTPUT_DIR.value) 126 | 127 | 128 | if __name__ == "__main__": 129 | app.run(main) 130 | -------------------------------------------------------------------------------- /tapnet/utils/index_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Index utils for TAPViT-SSM.""" 17 | 18 | import functools 19 | 20 | import chex 21 | import jax 22 | import jax.numpy as jnp 23 | # from kauldron.typing import Int, Bool, Float # pylint: disable=g-multiple-import,g-importing-member 24 | 25 | 26 | def scatter_inner( 27 | target: chex.Array, mask: chex.Array, timestep: chex.Array, data: chex.Array 28 | ) -> chex.Array: 29 | """Scatter data into target at timestep if mask is true. 
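  Concretely, this computes target.at[timestep].set(data) and keeps the
  unchanged target wherever mask is False.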
30 | 31 | Args: 32 | target: (T, c) float target tensor 33 | mask: ([]) bool mask 34 | timestep: ([]) int timestep 35 | data: (c,) data to scatter 36 | Returns: 37 | (T, c) updated target tensor 38 | """ 39 | updated_target = target.at[timestep].set(data) 40 | return jnp.where(mask, updated_target, target) 41 | 42 | 43 | @jax.vmap 44 | @functools.partial(jax.vmap, in_axes=(1, 0, 0, 0), out_axes=1) 45 | def scatter( 46 | target: chex.Array, mask: chex.Array, timestep: chex.Array, data: chex.Array 47 | ) -> chex.Array: 48 | """Scatter data into target at timestep if mask is true. 49 | 50 | (dimensions that are added via vmap are put into square brackets []) 51 | 52 | Args: 53 | target: ([B], T, [Q], c) float target tensor 54 | mask: ([B, Q]) bool mask 55 | timestep: ([B, Q]) int timestep 56 | data: ([B, Q], c,) data to scatter 57 | Returns: 58 | ([B], T, [Q], c) updated target tensor 59 | """ 60 | return scatter_inner(target, mask, timestep, data) 61 | 62 | 63 | @jax.vmap 64 | @functools.partial(jax.vmap, in_axes=(1, None, None, 0), out_axes=1) 65 | def scatter2( 66 | target: chex.Array, mask: chex.Array, timestep: chex.Array, data: chex.Array 67 | ) -> chex.Array: 68 | """Scatter data into target at timestep if mask is true. 69 | 70 | (dimensions that are added via vmap are put into square brackets []) 71 | 72 | Args: 73 | target: ([B], T, [N], c) float target tensor 74 | mask: ([B]) bool mask 75 | timestep: ([B]) int timestep 76 | data: ([B, N], c,) data to scatter 77 | Returns: 78 | ([B], T, [N], c) updated target tensor 79 | """ 80 | return scatter_inner(target, mask, timestep, data) 81 | 82 | 83 | @jax.vmap 84 | @functools.partial(jax.vmap, in_axes=(1, 0, 0, 0), out_axes=1) 85 | def scatter_prefix( 86 | target: chex.Array, mask: chex.Array, timestep: chex.Array, data: chex.Array 87 | ) -> chex.Array: 88 | """Scatter data into target before timestep if mask is true. 89 | 90 | Equivalent to 91 | 92 | updated_target = target.at[:timestep].set(data) 93 | return jnp.where(mask, updated_target, target) 94 | 95 | but works in a static way. 96 | 97 | (dimensions that are added via vmap are put into square brackets []) 98 | 99 | Args: 100 | target: ([B], T, [Q], c) float target tensor 101 | mask: ([B, Q]) bool mask 102 | timestep: ([B, Q]) int timestep 103 | data: ([B, Q], c,) data to scatter 104 | Returns: 105 | ([B], T, [Q], c) updated target tensor 106 | """ 107 | cond = (jnp.arange(target.shape[0]) < timestep) & mask 108 | return jnp.where( 109 | jnp.tile(cond[:, None], (1, target.shape[1])), 110 | jnp.tile(data, (target.shape[0], 1)), 111 | target, 112 | ) 113 | 114 | 115 | @jax.vmap 116 | @functools.partial(jax.vmap, in_axes=(1, 0, 0, 0), out_axes=1) 117 | def scatter_suffix( 118 | target: chex.Array, mask: chex.Array, timestep: chex.Array, data: chex.Array 119 | ) -> chex.Array: 120 | """Scatter data into target before timestep if mask is true. 121 | 122 | Equivalent to 123 | 124 | updated_target = target.at[timestep:].set(data) 125 | return jnp.where(mask, updated_target, target) 126 | 127 | but works in a static way. 
128 | 129 | (dimensions that are added via vmap are put into square brackets []) 130 | 131 | Args: 132 | target: ([B], T, [Q], c) float target tensor 133 | mask: ([B, Q]) bool mask 134 | timestep: ([B, Q]) int timestep 135 | data: ([B, Q], c,) data to scatter 136 | Returns: 137 | ([B], T, [Q], c) updated target tensor 138 | """ 139 | cond = (jnp.arange(target.shape[0]) >= timestep) & mask 140 | return jnp.where( 141 | jnp.tile(cond[:, None], (1, target.shape[1])), 142 | jnp.tile(data, (target.shape[0], 1)), 143 | target, 144 | ) 145 | -------------------------------------------------------------------------------- /tapnet/tapnext/torch_losses.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Losses for TAPNext.""" 17 | 18 | import einops 19 | import torch 20 | from torch.nn import functional as F 21 | 22 | 23 | def huber_coordinate_loss( 24 | pred_points, target_points, mask, delta=1.0, pixel_size=256 25 | ): 26 | """Computes the Huber loss between predicted and target coordinates. 27 | 28 | Args: 29 | pred_points (*shape, 2): point coordinates predicted by the model 30 | target_points (*shape, 2): target point coordinates 31 | mask (*shape): visibility mask 32 | delta (float): the threshold of the Huber loss 33 | pixel_size (int): pixel size of the image 34 | 35 | Returns: 36 | Continuous huber loss (*shape) 37 | """ 38 | pred_points = pred_points.float() 39 | target_points = target_points.float() 40 | target_points = target_points.clip(0, pixel_size - 1) 41 | error = pred_points - target_points 42 | error = error.clip(-1e8, 1e8) # add magnitude bound to prevent nan 43 | distsqr = torch.sum(torch.square(error), dim=-1, keepdims=True) 44 | dist = torch.sqrt(distsqr + 1e-12) 45 | loss = torch.where( 46 | dist < delta, 47 | distsqr / 2, 48 | delta * (torch.abs(dist) - delta / 2), 49 | ) 50 | mask = mask.float() 51 | loss = (loss * mask).sum() / mask.sum() 52 | return loss 53 | 54 | 55 | def coordinate_softmax(logits, labels, mask, pixel_size=256): 56 | """Computes the softmax loss between predicted logits and target coordinates. 57 | 58 | Args: 59 | logits (*shape, n_bins x 2): marginal softmax logits for predicting x and y 60 | coordinates 61 | labels (*shape, 2): taget coordinates 62 | mask (*shape): visibility mask 63 | pixel_size (int): pixel size of the image 64 | 65 | Returns: 66 | loss (float): the softmax loss 67 | """ 68 | logits = logits.float() 69 | labels = labels.float() 70 | labels -= 0.5 71 | labels = labels.clip(0, pixel_size - 1) 72 | labels = torch.round(labels).long() 73 | logits = einops.rearrange(logits, 'b ... c -> b c ...') 74 | labels = einops.rearrange(labels, 'b ... 
c -> b c ...') 75 | logits_x, logits_y = logits.chunk(2, dim=1) 76 | labels_x, labels_y = labels.chunk(2, dim=1) 77 | print(logits_x.shape, labels_x.shape) 78 | loss_x = F.cross_entropy(logits_x, labels_x.squeeze(1)) 79 | loss_y = F.cross_entropy(logits_y, labels_y.squeeze(1)) 80 | loss = loss_x + loss_y 81 | mask = mask.float() 82 | loss = (loss * mask).sum() / mask.sum() 83 | return loss 84 | 85 | 86 | def tapnext_loss_and_grad(model, batch, loss_weight=1.0): 87 | """Computes the TAPNext loss and performs backward pass on the model. 88 | 89 | Use the init arg `use_checkpointing=True` when constructing TAPNext to 90 | optimize memory; this does not have any impact on the inference speed/quality. 91 | 92 | Args: 93 | model (TAPNext): 94 | batch (dict): a dictionary with 4 keys: * 'video' - a float32 tensor of 95 | shape [batch, time, height, width, 3]; it should be mean-std normalized 96 | (e.g. ImageNet normalization) * 'query_points' - a float32 tensor of shape 97 | [batch, num_queries, 3] - queries have the form (t, x, y); where `t` - is 98 | in [0, time]; x is in [0, width] and y is in [0, height] * 'target_points' 99 | - a float32 tensor of shape [batch, num_queries, time, 2] - target points 100 | of the form (y, x), same ranges as query points * 'visible' - a float32 101 | tensor of shape [batch, num_queries, time, 1] - visibility flags (1. is 102 | visible and 0. is not visible) 103 | loss_weight (float): weight of the loss (default: 1.0) 104 | 105 | Returns: 106 | loss (float): the total loss 107 | """ 108 | pred_tracks, track_logits, visible_logits, _ = model( 109 | video=batch['video'], query_points=batch['query_points'] 110 | ) 111 | huber_loss = huber_coordinate_loss( 112 | pred_tracks, 113 | batch['target_points'].transpose(1, 2).flip(-1), 114 | batch['visible'].transpose(1, 2), 115 | ) 116 | softmax_loss = coordinate_softmax( 117 | track_logits, 118 | batch['target_points'].transpose(1, 2).flip(-1), 119 | batch['visible'].transpose(1, 2), 120 | ) 121 | coordinate_loss = 0.1 * huber_loss + 1.0 * softmax_loss 122 | visibility_loss = F.binary_cross_entropy_with_logits( 123 | visible_logits, batch['visible'].transpose(1, 2) 124 | ) 125 | loss = coordinate_loss + visibility_loss 126 | (loss * loss_weight).backward() 127 | return loss.item() 128 | -------------------------------------------------------------------------------- /tapnet/utils/ssm_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Util functions for processing videos with SSMs.""" 17 | 18 | import enum 19 | import chex 20 | import einops 21 | import flax.linen as nn 22 | import jax 23 | from jax import numpy as jnp 24 | import numpy as np 25 | from recurrentgemma._src.jax import scan 26 | 27 | from tapnet.utils import index_utils 28 | 29 | 30 | class CodeOrigin(enum.StrEnum): 31 | XLM = "xlm" 32 | THIRD_PARTY = "third_party" 33 | 34 | 35 | # @typechecked 36 | def transpose_flatten( 37 | x: chex.Array, like_shape: tuple[int, int, int, int], original_shape: str 38 | ) -> chex.Array: 39 | shape = dict(zip("btnc", like_shape, strict=True)) 40 | return einops.rearrange(x, original_shape + "-> (b n) t c", **shape) 41 | 42 | 43 | # @typechecked 44 | def unflatten_untranspose( 45 | x: chex.Array, like_shape: tuple[int, int, int, int], original_shape: str 46 | ) -> chex.Array: 47 | shape = dict(zip("btnc", like_shape, strict=True)) 48 | return einops.rearrange(x, "(b n) t c ->" + original_shape, **shape) 49 | 50 | 51 | def get_sharding_spec() -> scan.ShardingSpec: 52 | """Returns the sharding spec for the Pallas kernel.""" 53 | devices = np.asarray(jax.devices()) 54 | grid = devices.reshape(-1, len(devices)) 55 | 56 | # The axis names used in Kauldron are 'i' and 'j' 57 | mesh = jax.sharding.Mesh(grid, axis_names=("i", "j")) 58 | # It looks like `j` is the data axis 59 | scan_sharding_spec = scan.ShardingSpec( 60 | mesh=mesh, 61 | batch_axis_name="j", 62 | activations_axis_name="i", 63 | ) 64 | return scan_sharding_spec 65 | 66 | 67 | class TokenSubsampling(nn.Module): 68 | """Drops video tubes.""" 69 | 70 | # Hparams. 71 | drop_ratio: float 72 | drop_ratio_test: float = 0.0 73 | shuffle_tokens: bool = True 74 | 75 | # only true is supported for now 76 | mask_temporal_tokens: bool = True 77 | 78 | is_training: bool = False 79 | 80 | @nn.compact 81 | # @typechecked 82 | def __call__( 83 | self, 84 | tokens: chex.Array, # Float["*B T N D"], 85 | mask_token: chex.Array, # Float["*B T N D"], 86 | override_drop_ratio: float | None = None, 87 | ) -> tuple[chex.Array, chex.Array]: # Float["*B T N D"], Bool["*B T"]]: 88 | """Drops tokens randomly.""" 89 | n_batch, seq_len, num_tokens, _ = tokens.shape 90 | 91 | # By default tokens are only dropped for training. 92 | if override_drop_ratio is not None: 93 | drop_ratio = override_drop_ratio 94 | elif self.is_training: 95 | drop_ratio = self.drop_ratio 96 | else: 97 | drop_ratio = self.drop_ratio_test 98 | if drop_ratio == 0.0: 99 | return tokens, jnp.ones(tokens.shape[:2], dtype=jnp.bool_) 100 | 101 | # Drop tokens randomly. 102 | if self.mask_temporal_tokens: 103 | n_tokens = int(seq_len) - 1 104 | else: 105 | n_tokens = int(num_tokens) 106 | if len(n_batch) != 1: 107 | raise NotImplementedError("*B is not supported yet!") 108 | rng = self.make_rng("degradation") 109 | num_vis_patches = tokens.shape[2] 110 | 111 | subkey, _ = jax.random.split(rng, 2) 112 | 113 | masked_tokens = tokens 114 | 115 | subsample_size = jax.random.choice(subkey, n_tokens - 1, shape=(n_batch,)) 116 | subsample_size += 1 117 | # subsample_size is a random integer between 1 and T - 1 (inclusively) 118 | 119 | # mask_tokens - B T N D 120 | mask = jnp.ones( 121 | (n_batch, num_vis_patches), dtype=jnp.bool_) 122 | indices = jnp.tile( 123 | subsample_size[:, None], (1, num_vis_patches) 124 | ) 125 | # TODO(zholus): this works because we don't have temporal positional 126 | # embeddings (i.e. 
at all temporal positions masked tokens are the same). 127 | # when we have positional embeddings, we need dynamically choose 128 | # the mask token for each position according to `subsample_size`. 129 | scatter_data = mask_token[:, 0] 130 | # scatter_data.shape == [B, N, D] 131 | masked_tokens = index_utils.scatter_suffix( 132 | masked_tokens, mask, indices, scatter_data 133 | ) 134 | masked_positions = jnp.zeros( 135 | (n_batch, n_tokens + 1, 1, 1), dtype=jnp.bool_) 136 | mask = jnp.ones( 137 | (n_batch, 1), dtype=jnp.bool_) 138 | masked_positions = index_utils.scatter_suffix( 139 | masked_positions, mask, subsample_size[:, None], 140 | jnp.ones((n_batch, 1, 1), dtype=jnp.bool_) 141 | )[..., 0, 0] 142 | return masked_tokens, masked_positions 143 | -------------------------------------------------------------------------------- /colabs/tapir_clustering.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "id": "MWPOsk-I8o69" 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "# @title Install code and dependencies {form-width: \"25%\"}\n", 12 | "!pip install git+https://github.com/google-deepmind/tapnet.git" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": { 19 | "id": "dNWBx_DOHSSt" 20 | }, 21 | "outputs": [], 22 | "source": [ 23 | "# @title Download Model {form-width: \"25%\"}\n", 24 | "\n", 25 | "%mkdir tapnet/checkpoints\n", 26 | "\n", 27 | "!wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/causal_tapir_checkpoint.npy\n", 28 | "\n", 29 | "%ls tapnet/checkpoints\n", 30 | "\n", 31 | "checkpoint_path = 'tapnet/checkpoints/causal_tapir_checkpoint.npy'" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": { 38 | "id": "jtTNXUNCHVAL" 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "# @title Imports {form-width: \"25%\"}\n", 43 | "%matplotlib widget\n", 44 | "import functools\n", 45 | "\n", 46 | "import haiku as hk\n", 47 | "import jax\n", 48 | "import jax.numpy as jnp\n", 49 | "import matplotlib.pyplot as plt\n", 50 | "import mediapy as media\n", 51 | "import numpy as np\n", 52 | "from tqdm import tqdm\n", 53 | "import tree\n", 54 | "\n", 55 | "from tapnet.robotap import tapir_clustering\n", 56 | "from tapnet.utils import transforms\n", 57 | "from tapnet.utils import viz_utils" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": { 64 | "id": "0J9kVfSuHmqS" 65 | }, 66 | "outputs": [], 67 | "source": [ 68 | "# @title Load an Exemplar Video {form-width: \"25%\"}\n", 69 | "\n", 70 | "%mkdir tapnet/examplar_videos\n", 71 | "\n", 72 | "!wget -P tapnet/examplar_videos https://storage.googleapis.com/dm-tapnet/robotap/for_clustering.mp4\n", 73 | "\n", 74 | "video = media.read_video('tapnet/examplar_videos/for_clustering.mp4')\n", 75 | "height, width = video.shape[1:3]\n", 76 | "media.show_video(video[::5], fps=10)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": { 83 | "id": "7Vjhi4PdJ2W-" 84 | }, 85 | "outputs": [], 86 | "source": [ 87 | "# @title Run TAPIR to extract point tracks {form-width: \"25%\"}\n", 88 | "\n", 89 | "demo_videos = {\"dummy_id\":video}\n", 90 | "demo_episode_ids = list(demo_videos.keys())\n", 91 | "track_dict = tapir_clustering.track_many_points(\n", 92 | " demo_videos,\n", 93 | " demo_episode_ids,\n", 94 | " checkpoint_path,\n", 95 | " 
point_batch_size=1024,\n", 96 | " points_per_frame=10,\n", 97 | ")" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": { 104 | "id": "kU2yqJVTPgg-" 105 | }, 106 | "outputs": [], 107 | "source": [ 108 | "# @title Run the clustering {form-width: \"25%\"}\n", 109 | "\n", 110 | "clustered = tapir_clustering.compute_clusters(\n", 111 | " track_dict['separation_tracks'],\n", 112 | " track_dict['separation_visibility'],\n", 113 | " track_dict['demo_episode_ids'],\n", 114 | " track_dict['video_shape'],\n", 115 | " track_dict['query_features'],\n", 116 | " max_num_cats=12,\n", 117 | " final_num_cats=7,\n", 118 | ")" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": { 125 | "id": "FCNCAeLVQ0r2" 126 | }, 127 | "outputs": [], 128 | "source": [ 129 | "# @title Display the inferred clusters {form-width: \"25%\"}\n", 130 | "\n", 131 | "separation_visibility_trim = clustered['separation_visibility']\n", 132 | "separation_tracks_trim = clustered['separation_tracks']\n", 133 | "\n", 134 | "pointtrack_video = viz_utils.plot_tracks_v2(\n", 135 | " (demo_videos[demo_episode_ids[0]]).astype(np.uint8),\n", 136 | " separation_tracks_trim[demo_episode_ids[0]],\n", 137 | " 1.0-separation_visibility_trim[demo_episode_ids[0]],\n", 138 | " trackgroup=clustered['classes']\n", 139 | ")\n", 140 | "media.show_video(pointtrack_video, fps=20)" 141 | ] 142 | } 143 | ], 144 | "metadata": { 145 | "colab": { 146 | "provenance": [] 147 | }, 148 | "kernelspec": { 149 | "display_name": "Python 3", 150 | "name": "python3" 151 | }, 152 | "language_info": { 153 | "name": "python" 154 | } 155 | }, 156 | "nbformat": 4, 157 | "nbformat_minor": 0 158 | } 159 | -------------------------------------------------------------------------------- /configs/tapnet_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Default config to train the TapNet.""" 17 | 18 | from jaxline import base_config 19 | from ml_collections import config_dict 20 | 21 | 22 | # We define the experiment launch config in the same file as the experiment to 23 | # keep things self-contained in a single file. 24 | def get_config() -> config_dict.ConfigDict: 25 | """Return config object for training.""" 26 | config = base_config.get_base_config() 27 | 28 | # Experiment config. 29 | config.training_steps = 100000 30 | 31 | # NOTE: duplicates not allowed. 32 | config.shared_module_names = ('tapnet_model',) 33 | 34 | config.dataset_names = ('kubric',) 35 | # Note: eval modes must always start with 'eval_'. 
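# (Presumably each mode below reads its data from the matching *_path entry
# further down, e.g. eval_davis_points <-> davis_points_path; those paths are
# empty by default and need to point at the corresponding datasets.)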
36 | config.eval_modes = ( 37 | 'eval_davis_points', 38 | 'eval_jhmdb', 39 | 'eval_robotics_points', 40 | 'eval_kinetics_points', 41 | ) 42 | config.checkpoint_dir = '/tmp/tapnet_training/' 43 | config.evaluate_every = 10000 44 | 45 | config.experiment_kwargs = config_dict.ConfigDict( 46 | dict( 47 | config=dict( 48 | sweep_name='default_sweep', 49 | save_final_checkpoint_as_npy=True, 50 | # `enable_double_transpose` should always be false when using 1D. 51 | # For other D It is also completely untested and very unlikely 52 | # to work. 53 | optimizer=dict( 54 | base_lr=2e-3, 55 | max_norm=-1, # < 0 to turn off. 56 | weight_decay=1e-2, 57 | schedule_type='cosine', 58 | cosine_decay_kwargs=dict( 59 | init_value=0.0, 60 | warmup_steps=5000, 61 | end_value=0.0, 62 | ), 63 | optimizer='adam', 64 | # Optimizer-specific kwargs. 65 | adam_kwargs=dict( 66 | b1=0.9, 67 | b2=0.95, 68 | eps=1e-8, 69 | ), 70 | ), 71 | fast_variables=tuple(), 72 | shared_modules=dict( 73 | shared_module_names=config.get_oneway_ref( 74 | 'shared_module_names', 75 | ), 76 | tapnet_model_kwargs=dict(), 77 | ), 78 | datasets=dict( 79 | dataset_names=config.get_oneway_ref('dataset_names'), 80 | kubric_kwargs=dict( 81 | batch_dims=8, 82 | shuffle_buffer_size=128, 83 | train_size=(256, 256), 84 | ), 85 | ), 86 | supervised_point_prediction_kwargs=dict( 87 | prediction_algo='cost_volume_regressor', 88 | ), 89 | checkpoint_dir=config.get_oneway_ref('checkpoint_dir'), 90 | evaluate_every=config.get_oneway_ref('evaluate_every'), 91 | eval_modes=config.get_oneway_ref('eval_modes'), 92 | # If true, run evaluate() on the experiment once before 93 | # you load a checkpoint. 94 | # This is useful for getting initial values of metrics 95 | # at random weights, or when debugging locally if you 96 | # do not have any train job running. 97 | davis_points_path='', 98 | jhmdb_path='', 99 | robotics_points_path='', 100 | training=dict( 101 | # Note: to sweep n_training_steps, DO NOT sweep these 102 | # fields directly. Instead sweep config.training_steps. 103 | # Otherwise, decay/stopping logic 104 | # is not guaranteed to be consistent. 105 | n_training_steps=config.get_oneway_ref('training_steps'), 106 | ), 107 | inference=dict( 108 | input_video_path='', 109 | output_video_path='', 110 | resize_height=256, # video height resized to before inference 111 | resize_width=256, # video width resized to before inference 112 | num_points=20, # number of random points to sample 113 | ), 114 | ) 115 | ) 116 | ) 117 | 118 | # Set up where to store the resulting model. 119 | config.train_checkpoint_all_hosts = False 120 | config.save_checkpoint_interval = 10 121 | config.eval_initial_weights = True 122 | 123 | # Prevents accidentally setting keys that aren't recognized (e.g. in tests). 124 | config.lock() 125 | 126 | return config 127 | -------------------------------------------------------------------------------- /configs/tapir_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Default config to train the TAPIR.""" 17 | 18 | from jaxline import base_config 19 | from ml_collections import config_dict 20 | 21 | 22 | # We define the experiment launch config in the same file as the experiment to 23 | # keep things self-contained in a single file. 24 | def get_config() -> config_dict.ConfigDict: 25 | """Return config object for training.""" 26 | config = base_config.get_base_config() 27 | 28 | # Experiment config. 29 | config.training_steps = 100000 30 | 31 | # NOTE: duplicates not allowed. 32 | config.shared_module_names = ('tapir_model',) 33 | 34 | config.dataset_names = ('kubric',) 35 | # Note: eval modes must always start with 'eval_'. 36 | config.eval_modes = ( 37 | 'eval_davis_points', 38 | 'eval_jhmdb', 39 | 'eval_robotics_points', 40 | 'eval_kinetics_points', 41 | ) 42 | config.checkpoint_dir = '/tmp/tapnet_training/' 43 | config.evaluate_every = 10000 44 | 45 | config.experiment_kwargs = config_dict.ConfigDict( 46 | dict( 47 | config=dict( 48 | sweep_name='default_sweep', 49 | save_final_checkpoint_as_npy=True, 50 | # `enable_double_transpose` should always be false when using 1D. 51 | # For other D It is also completely untested and very unlikely 52 | # to work. 53 | optimizer=dict( 54 | base_lr=1e-3, 55 | max_norm=-1, # < 0 to turn off. 56 | weight_decay=1e-1, 57 | schedule_type='cosine', 58 | cosine_decay_kwargs=dict( 59 | init_value=0.0, 60 | warmup_steps=1000, 61 | end_value=0.0, 62 | ), 63 | optimizer='adam', 64 | # Optimizer-specific kwargs. 65 | adam_kwargs=dict( 66 | b1=0.9, 67 | b2=0.95, 68 | eps=1e-8, 69 | ), 70 | ), 71 | fast_variables=tuple(), 72 | shared_modules=dict( 73 | shared_module_names=config.get_oneway_ref( 74 | 'shared_module_names', 75 | ), 76 | tapir_model_kwargs=dict( 77 | bilinear_interp_with_depthwise_conv=False, 78 | pyramid_level=0, 79 | use_causal_conv=False, 80 | initial_resolution=(256, 256), 81 | ), 82 | ), 83 | datasets=dict( 84 | dataset_names=config.get_oneway_ref('dataset_names'), 85 | kubric_kwargs=dict( 86 | batch_dims=8, 87 | shuffle_buffer_size=128, 88 | train_size=(256, 256), 89 | ), 90 | ), 91 | supervised_point_prediction_kwargs=dict( 92 | prediction_algo='cost_volume_regressor', 93 | model_key='tapir_model', 94 | ), 95 | checkpoint_dir=config.get_oneway_ref('checkpoint_dir'), 96 | evaluate_every=config.get_oneway_ref('evaluate_every'), 97 | eval_modes=config.get_oneway_ref('eval_modes'), 98 | # If true, run evaluate() on the experiment once before 99 | # you load a checkpoint. 100 | # This is useful for getting initial values of metrics 101 | # at random weights, or when debugging locally if you 102 | # do not have any train job running. 103 | davis_points_path='', 104 | jhmdb_path='', 105 | robotics_points_path='', 106 | training=dict( 107 | # Note: to sweep n_training_steps, DO NOT sweep these 108 | # fields directly. Instead sweep config.training_steps. 109 | # Otherwise, decay/stopping logic 110 | # is not guaranteed to be consistent. 
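# For example, a hypothetical sweep entry would set
# config.training_steps = 50_000 at the top level; n_training_steps below
# then tracks it through the one-way reference.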
111 | n_training_steps=config.get_oneway_ref('training_steps'), 112 | ), 113 | inference=dict( 114 | input_video_path='', 115 | output_video_path='', 116 | resize_height=256, # video height resized to before inference 117 | resize_width=256, # video width resized to before inference 118 | num_points=20, # number of random points to sample 119 | ), 120 | ) 121 | ) 122 | ) 123 | 124 | # Set up where to store the resulting model. 125 | config.train_checkpoint_all_hosts = False 126 | config.save_checkpoint_interval = 10 127 | config.eval_initial_weights = True 128 | 129 | # Prevents accidentally setting keys that aren't recognized (e.g. in tests). 130 | config.lock() 131 | 132 | return config 133 | -------------------------------------------------------------------------------- /configs/causal_tapir_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Default config to train the TAPIR.""" 17 | 18 | from jaxline import base_config 19 | from ml_collections import config_dict 20 | 21 | 22 | # We define the experiment launch config in the same file as the experiment to 23 | # keep things self-contained in a single file. 24 | def get_config() -> config_dict.ConfigDict: 25 | """Return config object for training.""" 26 | config = base_config.get_base_config() 27 | 28 | # Experiment config. 29 | config.training_steps = 100000 30 | 31 | # NOTE: duplicates not allowed. 32 | config.shared_module_names = ('tapir_model',) 33 | 34 | config.dataset_names = ('kubric',) 35 | # Note: eval modes must always start with 'eval_'. 36 | config.eval_modes = ( 37 | 'eval_davis_points', 38 | 'eval_jhmdb', 39 | 'eval_robotics_points', 40 | 'eval_kinetics_points', 41 | ) 42 | config.checkpoint_dir = '/tmp/tapnet_training/' 43 | config.evaluate_every = 10000 44 | 45 | config.experiment_kwargs = config_dict.ConfigDict( 46 | dict( 47 | config=dict( 48 | sweep_name='default_sweep', 49 | save_final_checkpoint_as_npy=True, 50 | # `enable_double_transpose` should always be false when using 1D. 51 | # For other D It is also completely untested and very unlikely 52 | # to work. 53 | optimizer=dict( 54 | base_lr=1e-3, 55 | max_norm=-1, # < 0 to turn off. 56 | weight_decay=1e-1, 57 | schedule_type='cosine', 58 | cosine_decay_kwargs=dict( 59 | init_value=0.0, 60 | warmup_steps=1000, 61 | end_value=0.0, 62 | ), 63 | optimizer='adam', 64 | # Optimizer-specific kwargs. 
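# (Only the block matching the optimizer chosen above is expected to be
# read, i.e. adam_kwargs for optimizer='adam'.)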
65 | adam_kwargs=dict( 66 | b1=0.9, 67 | b2=0.95, 68 | eps=1e-8, 69 | ), 70 | ), 71 | fast_variables=tuple(), 72 | shared_modules=dict( 73 | shared_module_names=config.get_oneway_ref( 74 | 'shared_module_names', 75 | ), 76 | tapir_model_kwargs=dict( 77 | bilinear_interp_with_depthwise_conv=False, 78 | pyramid_level=1, 79 | use_causal_conv=True, 80 | initial_resolution=(256, 256), 81 | ), 82 | ), 83 | datasets=dict( 84 | dataset_names=config.get_oneway_ref('dataset_names'), 85 | kubric_kwargs=dict( 86 | batch_dims=8, 87 | shuffle_buffer_size=128, 88 | train_size=(256, 256), 89 | ), 90 | ), 91 | supervised_point_prediction_kwargs=dict( 92 | prediction_algo='cost_volume_regressor', 93 | model_key='tapir_model', 94 | ), 95 | checkpoint_dir=config.get_oneway_ref('checkpoint_dir'), 96 | evaluate_every=config.get_oneway_ref('evaluate_every'), 97 | eval_modes=config.get_oneway_ref('eval_modes'), 98 | # If true, run evaluate() on the experiment once before 99 | # you load a checkpoint. 100 | # This is useful for getting initial values of metrics 101 | # at random weights, or when debugging locally if you 102 | # do not have any train job running. 103 | davis_points_path='', 104 | jhmdb_path='', 105 | robotics_points_path='', 106 | training=dict( 107 | # Note: to sweep n_training_steps, DO NOT sweep these 108 | # fields directly. Instead sweep config.training_steps. 109 | # Otherwise, decay/stopping logic 110 | # is not guaranteed to be consistent. 111 | n_training_steps=config.get_oneway_ref('training_steps'), 112 | ), 113 | inference=dict( 114 | input_video_path='', 115 | output_video_path='', 116 | resize_height=256, # video height resized to before inference 117 | resize_width=256, # video width resized to before inference 118 | num_points=20, # number of random points to sample 119 | ), 120 | ) 121 | ) 122 | ) 123 | 124 | # Set up where to store the resulting model. 125 | config.train_checkpoint_all_hosts = False 126 | config.save_checkpoint_interval = 10 127 | config.eval_initial_weights = True 128 | 129 | # Prevents accidentally setting keys that aren't recognized (e.g. in tests). 130 | config.lock() 131 | 132 | return config 133 | -------------------------------------------------------------------------------- /tapnet/tapnext/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Losses for TAPNext.""" 17 | 18 | import dataclasses 19 | 20 | import jax 21 | import jax.numpy as jnp 22 | from kauldron import kontext 23 | from kauldron.losses import base 24 | from kauldron.typing import Bool, Float, typechecked # pylint: disable=g-multiple-import,g-importing-member 25 | import optax 26 | 27 | 28 | @dataclasses.dataclass(eq=True, frozen=True, kw_only=True) 29 | class Huber(base.Loss): 30 | """Huber loss for point track prediction.""" 31 | 32 | delta: float = 1.0 33 | 34 | pred_points: kontext.Key = kontext.REQUIRED 35 | target_points: kontext.Key = kontext.REQUIRED 36 | normalize_by: str = "values" 37 | 38 | @typechecked 39 | def get_values( 40 | self, 41 | pred_points: Float["*a 2"], 42 | target_points: Float["*a 2"], 43 | ) -> Float["*a 1"]: 44 | pred_points = jnp.astype(pred_points, jnp.float32) 45 | target_points = jnp.astype(target_points, jnp.float32) 46 | target_points = jnp.clip(target_points, 0, 255) 47 | error = pred_points - target_points 48 | error = jnp.clip(error, -1e8, 1e8) # add magnitude bound to prevent nan 49 | distsqr = jnp.sum(jnp.square(error), axis=-1, keepdims=True) 50 | dist = jnp.sqrt(distsqr + 1e-12) # add eps to prevent nan 51 | loss = jnp.where( 52 | dist < self.delta, 53 | distsqr / 2, 54 | self.delta * (jnp.abs(dist) - self.delta / 2), 55 | ) 56 | return loss 57 | 58 | 59 | @dataclasses.dataclass(eq=True, frozen=True, kw_only=True) 60 | class MaskedL1(base.Loss): 61 | """Masked L1 loss for predicting random image patches.""" 62 | pred_patches: kontext.Key = kontext.REQUIRED 63 | target_patches: kontext.Key = kontext.REQUIRED 64 | temporal_mask: kontext.Key = kontext.REQUIRED 65 | normalize_by: str = "values" 66 | image_norm: str = "sum" # "sum" or "mean" 67 | 68 | @typechecked 69 | def get_values( 70 | self, 71 | pred_patches: Float["*B T h w C"], 72 | target_patches: Float["*B T h w C"], 73 | temporal_mask: Bool["*B T"], 74 | ) -> Float["*a 1"]: 75 | pred_patches = jnp.astype(pred_patches, jnp.float32) 76 | target_patches = jnp.astype(target_patches, jnp.float32) 77 | loss = jnp.abs(pred_patches - target_patches) # * temporal_mask 78 | if self.image_norm == "sum": 79 | loss = jnp.sum(loss, axis=[-1, -2, -3]) / 1024. 80 | elif self.image_norm == "mean": 81 | loss = jnp.mean(loss, axis=[-1, -2, -3]) 82 | loss = jnp.mean(loss, axis=-1) 83 | return loss[..., None] 84 | 85 | 86 | @dataclasses.dataclass(eq=True, frozen=True, kw_only=True) 87 | class CoordinateSoftmaxCrossEntropyWithIntLabels(base.Loss): 88 | """Softmax cross-entropy loss with integer labels.""" 89 | 90 | logits: kontext.Key = kontext.REQUIRED # e.g. "preds.logits" 91 | labels: kontext.Key = kontext.REQUIRED # e.g. 
"batch.label" 92 | pixel_size: int = 256 93 | 94 | @typechecked 95 | def get_values( 96 | self, logits: Float["*a n"], labels: Float["*a 2"] 97 | ) -> Float["*a 1"]: 98 | logits = jnp.astype(logits, jnp.float32) 99 | labels = jnp.astype(labels, jnp.float32) 100 | labels -= 0.5 101 | labels = jnp.clip(labels, 0, self.pixel_size - 1) 102 | labels = jnp.round(labels).astype(jnp.int32) 103 | logits_x, logits_y = jnp.split(logits, 2, axis=-1) 104 | labels_x, labels_y = jnp.split(labels, 2, axis=-1) 105 | loss_x = optax.softmax_cross_entropy_with_integer_labels( 106 | logits=logits_x, labels=labels_x.squeeze(-1) 107 | )[..., None] 108 | loss_y = optax.softmax_cross_entropy_with_integer_labels( 109 | logits=logits_y, labels=labels_y.squeeze(-1) 110 | )[..., None] 111 | return loss_x + loss_y 112 | 113 | 114 | @dataclasses.dataclass(eq=True, frozen=True, kw_only=True) 115 | class Certainty(base.Loss): 116 | """Loss for point track uncertainty prediction. 117 | 118 | A point prediction is certain if it falls within threshold of ground truth. 119 | The 3rd term of the loss in Equation (1) of TAPIR paper 120 | https://arxiv.org/abs/2306.08637 121 | """ 122 | 123 | threshold: float = 1.0 124 | 125 | logits: kontext.Key = kontext.REQUIRED 126 | pred_points: kontext.Key = kontext.REQUIRED 127 | target_points: kontext.Key = kontext.REQUIRED 128 | normalize_by: str = "values" 129 | 130 | @typechecked 131 | def get_values( 132 | self, 133 | logits: Float["*a 1"], 134 | pred_points: Float["*a 2"], 135 | target_points: Float["*a 2"], 136 | ) -> Float["*a 1"]: 137 | logits = jnp.astype(logits, jnp.float32) 138 | pred_points = jnp.astype(pred_points, jnp.float32) 139 | target_points = jnp.astype(target_points, jnp.float32) 140 | pred_points = jax.lax.stop_gradient(pred_points) 141 | error = pred_points - target_points 142 | distsqr = jnp.sum(jnp.square(error), axis=-1, keepdims=True) 143 | is_certain = (distsqr <= self.threshold**2).astype(logits.dtype) 144 | loss = optax.sigmoid_binary_cross_entropy(logits, is_certain) 145 | return loss 146 | -------------------------------------------------------------------------------- /configs/tapir_bootstrap_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Default config to train the TAPIR.""" 17 | 18 | from jaxline import base_config 19 | from ml_collections import config_dict 20 | 21 | 22 | # We define the experiment launch config in the same file as the experiment to 23 | # keep things self-contained in a single file. 24 | def get_config() -> config_dict.ConfigDict: 25 | """Return config object for training.""" 26 | config = base_config.get_base_config() 27 | 28 | # Experiment config. 29 | config.training_steps = 100000 30 | 31 | # NOTE: duplicates not allowed. 
32 | config.shared_module_names = ('tapir_model',) 33 | 34 | config.dataset_names = ('kubric',) 35 | # Note: eval modes must always start with 'eval_'. 36 | config.eval_modes = ( 37 | 'eval_davis_points', 38 | 'eval_jhmdb', 39 | 'eval_robotics_points', 40 | 'eval_kinetics_points', 41 | ) 42 | config.checkpoint_dir = '/tmp/tapnet_training/' 43 | config.evaluate_every = 10000 44 | 45 | config.experiment_kwargs = config_dict.ConfigDict( 46 | dict( 47 | config=dict( 48 | sweep_name='default_sweep', 49 | save_final_checkpoint_as_npy=True, 50 | # `enable_double_transpose` should always be false when using 1D. 51 | # For other D It is also completely untested and very unlikely 52 | # to work. 53 | optimizer=dict( 54 | base_lr=1e-3, 55 | max_norm=-1, # < 0 to turn off. 56 | weight_decay=1e-1, 57 | schedule_type='cosine', 58 | cosine_decay_kwargs=dict( 59 | init_value=0.0, 60 | warmup_steps=1000, 61 | end_value=0.0, 62 | ), 63 | optimizer='adam', 64 | # Optimizer-specific kwargs. 65 | adam_kwargs=dict( 66 | b1=0.9, 67 | b2=0.95, 68 | eps=1e-8, 69 | ), 70 | ), 71 | fast_variables=tuple(), 72 | shared_modules=dict( 73 | shared_module_names=config.get_oneway_ref( 74 | 'shared_module_names', 75 | ), 76 | tapir_model_kwargs=dict( 77 | bilinear_interp_with_depthwise_conv=False, 78 | pyramid_level=1, 79 | use_causal_conv=False, 80 | initial_resolution=(256, 256), 81 | extra_convs=True, 82 | softmax_temperature=10.0, 83 | ), 84 | ), 85 | datasets=dict( 86 | dataset_names=config.get_oneway_ref('dataset_names'), 87 | kubric_kwargs=dict( 88 | batch_dims=8, 89 | shuffle_buffer_size=128, 90 | train_size=(256, 256), 91 | ), 92 | ), 93 | supervised_point_prediction_kwargs=dict( 94 | prediction_algo='cost_volume_regressor', 95 | model_key='tapir_model', 96 | ), 97 | checkpoint_dir=config.get_oneway_ref('checkpoint_dir'), 98 | evaluate_every=config.get_oneway_ref('evaluate_every'), 99 | eval_modes=config.get_oneway_ref('eval_modes'), 100 | # If true, run evaluate() on the experiment once before 101 | # you load a checkpoint. 102 | # This is useful for getting initial values of metrics 103 | # at random weights, or when debugging locally if you 104 | # do not have any train job running. 105 | davis_points_path='', 106 | jhmdb_path='', 107 | robotics_points_path='', 108 | training=dict( 109 | # Note: to sweep n_training_steps, DO NOT sweep these 110 | # fields directly. Instead sweep config.training_steps. 111 | # Otherwise, decay/stopping logic 112 | # is not guaranteed to be consistent. 113 | n_training_steps=config.get_oneway_ref('training_steps'), 114 | ), 115 | inference=dict( 116 | input_video_path='', 117 | output_video_path='', 118 | resize_height=256, # video height resized to before inference 119 | resize_width=256, # video width resized to before inference 120 | num_points=20, # number of random points to sample 121 | ), 122 | ) 123 | ) 124 | ) 125 | 126 | # Set up where to store the resulting model. 127 | config.train_checkpoint_all_hosts = False 128 | config.save_checkpoint_interval = 10 129 | config.eval_initial_weights = True 130 | 131 | # Prevents accidentally setting keys that aren't recognized (e.g. in tests). 
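# (Illustrative: after lock(), assigning a brand-new key such as
# config.not_a_real_key = 1 raises an error; ml_collections offers
# config.unlock() / the unlocked() context manager if a test genuinely
# needs to add fields.)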
132 | config.lock() 133 | 134 | return config 135 | -------------------------------------------------------------------------------- /tapnet/tapnext/tapnext_benchmark_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "id": "RHt3rGLLxfWs" 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import torch\n", 12 | "import torchvision" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": { 19 | "colab": { 20 | "base_uri": "https://localhost:8080/" 21 | }, 22 | "id": "-gBuXTWqxuMV", 23 | "outputId": "42ca4293-2496-4a03-ced5-a02d9253c721" 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "torch.__version__, torchvision.__version__" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": { 34 | "colab": { 35 | "base_uri": "https://localhost:8080/" 36 | }, 37 | "id": "gjdq8wLcQWd7", 38 | "outputId": "8c5ed594-021c-4c67-e0df-949151d1247f" 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "!pip install -q git+https://github.com/google-deepmind/tapnet.git" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": { 49 | "colab": { 50 | "base_uri": "https://localhost:8080/" 51 | }, 52 | "id": "mrQ_uQDeQee-", 53 | "outputId": "ce0d3868-7b82-4e8e-88ae-fd5ae960f95b" 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "!pip install -q git+https://github.com/google-deepmind/recurrentgemma.git@main" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": { 64 | "colab": { 65 | "base_uri": "https://localhost:8080/" 66 | }, 67 | "id": "NIZQxfcgyFM9", 68 | "outputId": "8cd10428-69a7-4c39-f3f9-eec9c07ed108" 69 | }, 70 | "outputs": [], 71 | "source": [ 72 | "!pip install \"numpy\u003c2.1.0\"" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": { 79 | "id": "gju7QkZH2XLL" 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "import tqdm" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": { 89 | "id": "uQ6Ako7IQERX" 90 | }, 91 | "source": [ 92 | "### TAPNext" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": { 99 | "id": "hy98wCMgQx6x" 100 | }, 101 | "outputs": [], 102 | "source": [ 103 | "import time\n", 104 | "import numpy as np\n", 105 | "from tapnet.tapnext.tapnext_torch import TAPNext\n", 106 | "import torch.nn.functional as F" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": { 112 | "id": "gJpHisHSQERc" 113 | }, 114 | "source": [ 115 | "### Create the model and load checkpoint" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": { 122 | "id": "pnTma4NKQERc" 123 | }, 124 | "outputs": [], 125 | "source": [ 126 | "model = TAPNext(image_size=(256, 256))\n", 127 | "model.cuda()" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": { 133 | "id": "CukzSyYSQERd" 134 | }, 135 | "source": [ 136 | "### Run inference" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": { 143 | "colab": { 144 | "base_uri": "https://localhost:8080/", 145 | "height": 477 146 | }, 147 | "id": "YNf98GICtJMa", 148 | "outputId": "05da0eef-b914-4d1b-f740-5152f47ddb4c" 149 | }, 150 | "outputs": [], 151 | "source": [ 152 | "model.eval()\n", 153 | "for p in model.parameters():\n", 154 | " p.requires_grad = 
False" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": { 161 | "id": "IUowukX9Qx6x" 162 | }, 163 | "outputs": [], 164 | "source": [ 165 | "NUM_QUERIES = 1024" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": { 172 | "id": "TyiVKz9MQx6x" 173 | }, 174 | "outputs": [], 175 | "source": [ 176 | "video = torch.zeros((1, 1024, 256, 256, 3), dtype=torch.float32).cuda()\n", 177 | "query_points = torch.zeros((1, NUM_QUERIES, 3), dtype=torch.float32).cuda()" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": { 184 | "id": "vLRpvleUQx6x" 185 | }, 186 | "outputs": [], 187 | "source": [ 188 | "DTYPE = torch.float16 # use fp16 or bf16" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "metadata": { 195 | "id": "n6x1NxPaQx6x" 196 | }, 197 | "outputs": [], 198 | "source": [ 199 | "fwd = torch.compile(model.forward)\n", 200 | "with torch.no_grad():\n", 201 | " with torch.amp.autocast('cuda', dtype=DTYPE, enabled=True):\n", 202 | " _, _, _, tracking_state = fwd(video=video[:, :1], query_points=query_points)\n", 203 | " c = 0\n", 204 | " for k in tqdm.tqdm(range(1, video.shape[1])):\n", 205 | " if k == 512:\n", 206 | " # we let the model to run for several GPU burn-in steps\n", 207 | " tt = time.time()\n", 208 | " c = 0\n", 209 | " _, _, _, tracking_state = fwd(\n", 210 | " video=video[:, k : k + 1], state=tracking_state\n", 211 | " )\n", 212 | " c += 1\n", 213 | " d = time.time() - tt\n", 214 | " print('FPS:', c / d, 'latency', 1000 * d / c, 'ms')" 215 | ] 216 | } 217 | ], 218 | "metadata": { 219 | "accelerator": "GPU", 220 | "colab": { 221 | "gpuType": "T4", 222 | "provenance": [] 223 | }, 224 | "kernelspec": { 225 | "display_name": "Python 3 (ipykernel)", 226 | "language": "python", 227 | "name": "python3" 228 | }, 229 | "language_info": { 230 | "codemirror_mode": { 231 | "name": "ipython", 232 | "version": 3 233 | }, 234 | "file_extension": ".py", 235 | "mimetype": "text/x-python", 236 | "name": "python", 237 | "nbconvert_exporter": "python", 238 | "pygments_lexer": "ipython3", 239 | "version": "3.11.11" 240 | } 241 | }, 242 | "nbformat": 4, 243 | "nbformat_minor": 4 244 | } 245 | -------------------------------------------------------------------------------- /tapnet/trajan/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Attention modules for TRAJAN.""" 17 | 18 | from __future__ import annotations 19 | 20 | from typing import Optional 21 | 22 | from flax import linen as nn 23 | import jax.numpy as jnp 24 | 25 | 26 | class ImprovedTransformer(nn.Module): 27 | """Improved Transformer using tricks from ViT-22B (w/ cross-attention). 
28 | 29 | 1) Normalize keys/queries w/ LayerNorm. 30 | 2) Do some ops in parallel (here: cross + self-attention, but not MLP). 31 | """ 32 | 33 | qkv_size: int 34 | num_heads: int 35 | mlp_size: int 36 | num_layers: int 37 | 38 | @nn.compact 39 | def __call__( 40 | self, 41 | queries, # float['... d1'], 42 | inputs_kv=None, # Optional[float['*b N D']] 43 | qk_mask=None, # Optional[bool['...']] 44 | qq_mask=None, # Optional[bool['...']] 45 | ): # -> float['... d2'] 46 | 47 | for i in range(self.num_layers): 48 | if qk_mask is not None and len(qk_mask.shape) == len(inputs_kv.shape): 49 | qk_mask = qk_mask[..., jnp.newaxis, :, :] 50 | if qq_mask is not None and len(qq_mask.shape) == len(queries.shape): 51 | qq_mask = qq_mask[..., jnp.newaxis, :, :] 52 | 53 | queries = ImprovedTransformerBlock( 54 | qkv_size=self.qkv_size, 55 | num_heads=self.num_heads, 56 | mlp_size=self.mlp_size, 57 | name=f'layer_{i}', 58 | )( 59 | queries, 60 | inputs_kv=inputs_kv, 61 | qq_mask=qq_mask, 62 | qk_mask=qk_mask, 63 | ) 64 | 65 | queries = nn.LayerNorm(use_bias=False, use_scale=True, name='norm_encoder')( 66 | queries 67 | ) 68 | 69 | return queries 70 | 71 | 72 | class ImprovedTransformerBlock(nn.Module): 73 | """Improved Transformer block using tricks from ViT-22B (w/ cross-attention). 74 | 75 | 1) RMSNorm instead of LayerNorm. 76 | 2) Normalize keys/queries w/ RMSNorm. 77 | 3) Do some ops in parallel (here: cross + self-attention, but not MLP). 78 | """ 79 | 80 | mlp_size: int 81 | num_heads: int 82 | qkv_size: int 83 | 84 | @nn.compact 85 | def __call__( 86 | self, 87 | queries, # float['*b n d'], 88 | inputs_kv, # Optional[float['*b N D']] 89 | qq_mask=None, # Optional[bool['...']] 90 | qk_mask=None, # Optional[bool['...']] 91 | remat_attn: bool = False, 92 | ): # -> Float['*b n d'] 93 | width = queries.shape[-1] 94 | normed_queries = nn.LayerNorm( 95 | use_bias=False, use_scale=True, name='norm_q' 96 | )(queries) 97 | attn_out = queries 98 | 99 | # Self-attention. 100 | self_attn_out = ImprovedMHDPAttention( 101 | num_heads=self.num_heads, qk_size=self.qkv_size, name='self_att' 102 | )( 103 | inputs_q=normed_queries, 104 | inputs_kv=normed_queries, 105 | mask=jnp.array(qq_mask, jnp.float32) if qq_mask is not None else None, 106 | ) 107 | attn_out += self_attn_out 108 | 109 | # Cross-attention. 110 | if inputs_kv is not None: 111 | cross_attn_out = ImprovedMHDPAttention( 112 | num_heads=self.num_heads, qk_size=self.qkv_size, name='cross_att' 113 | )( 114 | inputs_q=normed_queries, 115 | inputs_kv=inputs_kv, 116 | mask=jnp.array(qk_mask, jnp.float32) if qk_mask is not None else None, 117 | ) 118 | attn_out += cross_attn_out 119 | 120 | # MLP. 121 | normed_attn_out = nn.LayerNorm( 122 | use_bias=False, use_scale=True, name='norm_attn' 123 | )(attn_out) 124 | h = nn.gelu(nn.Dense(self.mlp_size, name='MLP_in')(normed_attn_out)) 125 | mlp_out = nn.Dense(width, name='MLP_out')(h) 126 | return attn_out + mlp_out 127 | 128 | 129 | class ImprovedMHDPAttention(nn.Module): 130 | """Multi-head dot-product attention. 131 | 132 | Simplified nn.MultiheadDotProductAttention with a few modifications: 133 | - include normalization of keys and queries 134 | - dropped out support for dropout 135 | 136 | Attributes: 137 | num_heads: Number of attention heads. 138 | qk_size: Total dimension of the keys and queries. 139 | v_size: Total dimension of the values. Defaults to qk_size. 
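  Example (illustrative, inside a parent @nn.compact module):
    out = ImprovedMHDPAttention(num_heads=8, qk_size=256)(
        inputs_q=queries, inputs_kv=tokens)
    # `out` has the same trailing dimension as `inputs_q`; qk_size (and
    # v_size, if set) must be divisible by num_heads.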
140 | """ 141 | 142 | num_heads: int 143 | qk_size: int 144 | v_size: Optional[int] = None 145 | 146 | @nn.compact 147 | def __call__( 148 | self, 149 | inputs_q, # float['*b q d1'], 150 | inputs_kv, # float['*b k d2'], 151 | mask=None, # Optional[float['*b #h #q #k']] 152 | ): # -> float['*b q d1'] 153 | """Applies multi-head dot product attention on the input data. 154 | 155 | Projects the inputs into multi-headed query, key, and value vectors, 156 | applies dot-product attention and project the results to an output vector. 157 | 158 | Args: 159 | inputs_q: Input tokens from which queries are computed. 160 | inputs_kv: Input tokens from which the keys and queries are computed. 161 | mask: Mask for the attention weights. 162 | 163 | Returns: 164 | output tokens (linear projection of an attention weighted average of value 165 | tokens per query). 166 | """ 167 | v_size = self.qk_size if self.v_size is None else self.v_size 168 | 169 | if self.qk_size % self.num_heads: 170 | raise ValueError(f'{self.num_heads=} must divide {self.qk_size=}.') 171 | if v_size % self.num_heads: 172 | raise ValueError(f'{v_size=} must divide {self.num_heads=}.') 173 | 174 | # Project inputs_q to multi-headed queries and keys. 175 | # dimensions are then [B..., Q, H, qk_size] 176 | query = nn.DenseGeneral( 177 | features=(self.num_heads, self.qk_size // self.num_heads), 178 | use_bias=False, 179 | name='dense_query', 180 | )(inputs_q) 181 | key = nn.DenseGeneral( 182 | features=(self.num_heads, self.qk_size // self.num_heads), 183 | use_bias=False, 184 | name='dense_key', 185 | )(inputs_kv) 186 | 187 | # Normalize keys and queries before attention. 188 | query = nn.RMSNorm(name='norm_query')(query) 189 | key = nn.RMSNorm(name='norm_key')(key) 190 | 191 | value = nn.DenseGeneral( 192 | features=(self.num_heads, v_size // self.num_heads), 193 | use_bias=False, 194 | name='dense_value', 195 | )(inputs_kv) 196 | 197 | x = nn.dot_product_attention(query, key, value, mask=mask) 198 | 199 | # Back to the original input dimensions. 200 | out = nn.DenseGeneral( 201 | features=inputs_q.shape[-1], 202 | axis=(-2, -1), 203 | use_bias=True, 204 | name='dense_out', 205 | )(x) 206 | 207 | return out 208 | -------------------------------------------------------------------------------- /tapnet/tapvid3d/annotation_generation/adt_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Utilities for generating TAPVid3d ADT npz files.""" 17 | 18 | # pylint: disable=g-import-not-at-top,g-bad-import-order 19 | import tensorflow as tf 20 | 21 | tf.config.set_visible_devices([], "GPU") 22 | 23 | import os 24 | 25 | import numpy as np 26 | import numpy.typing as npt 27 | from PIL import Image 28 | from projectaria_tools.core import calibration 29 | from projectaria_tools.core.stream_id import StreamId 30 | from projectaria_tools.projects.adt import AriaDigitalTwinDataPathsProvider 31 | from projectaria_tools.projects.adt import AriaDigitalTwinDataProvider 32 | 33 | import tqdm 34 | from tapnet.tapvid3d.annotation_generation import adt_v1v2_mappings 35 | 36 | 37 | # Fixed hyperparameters for generating the ADT data. 38 | N_FRAMES = 300 39 | HEIGHT = 512 40 | WIDTH = 512 41 | FOCAL_LENGTH = 280 42 | 43 | 44 | class ADTVideoProcessor: 45 | """ADT video processor.""" 46 | 47 | def __init__(self, sequence_path: str): 48 | self.timestamps_ns, self.gt_provider, self.stream_id = self._load_adt_data( 49 | sequence_path 50 | ) 51 | 52 | def _load_adt_data(self, sequence_path: str): 53 | """Loads ADT data for a given sequence.""" 54 | paths_provider = AriaDigitalTwinDataPathsProvider(sequence_path) 55 | selected_device_number = 0 56 | data_paths = paths_provider.get_datapaths_by_device_num( 57 | selected_device_number, False 58 | ) 59 | gt_provider = AriaDigitalTwinDataProvider(data_paths) 60 | stream_id = StreamId("214-1") 61 | timestamps_ns = np.array( 62 | gt_provider.get_aria_device_capture_timestamps_ns(stream_id) 63 | ) 64 | # Remove timestamps without annotations. 65 | timestamps_ns = timestamps_ns[ 66 | timestamps_ns > gt_provider.get_start_time_ns() 67 | ] 68 | timestamps_ns = timestamps_ns[timestamps_ns < gt_provider.get_end_time_ns()] 69 | return timestamps_ns, gt_provider, stream_id 70 | 71 | def extract_image_data( 72 | self, chunk_timestamps_ns: list[int] 73 | ) -> tuple[ 74 | list[npt.NDArray], list[npt.NDArray], list[npt.NDArray], list[int] 75 | ]: 76 | """Extracts image, depth and segmentation data for a given video chunk.""" 77 | sensor_name = ( 78 | self.gt_provider.raw_data_provider_ptr().get_label_from_stream_id( 79 | self.stream_id 80 | ) 81 | ) 82 | device_calib = ( 83 | self.gt_provider.raw_data_provider_ptr().get_device_calibration() 84 | ) 85 | src_calib = device_calib.get_camera_calib(sensor_name) 86 | identity_tnf = calibration.get_linear_camera_calibration( 87 | 1, 1, 1 88 | ).get_transform_device_camera() 89 | dst_calib = calibration.CameraCalibration( 90 | "camera-rgb", 91 | calibration.CameraModelType.LINEAR, 92 | np.array([FOCAL_LENGTH, FOCAL_LENGTH, WIDTH / 2, HEIGHT / 2]), 93 | identity_tnf, 94 | WIDTH, 95 | HEIGHT, 96 | None, 97 | np.pi, 98 | "LinearCameraCalibration", 99 | ) 100 | rgb_ims = [] 101 | depth_ims = [] 102 | segmentation_ims = [] 103 | ok_timestamps_ns = [] 104 | for select_timestamps_ns in chunk_timestamps_ns: 105 | depth_with_dt = self.gt_provider.get_depth_image_by_timestamp_ns( 106 | select_timestamps_ns, self.stream_id 107 | ) 108 | segmentation_with_dt = ( 109 | self.gt_provider.get_segmentation_image_by_timestamp_ns( 110 | select_timestamps_ns, self.stream_id 111 | ) 112 | ) 113 | image_with_dt = self.gt_provider.get_aria_image_by_timestamp_ns( 114 | select_timestamps_ns, self.stream_id 115 | ) 116 | # Check if the result is valid. 
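# (All three of the RGB image, depth and segmentation lookups must be valid
# for this timestamp; otherwise the frame is skipped.)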
117 | if ( 118 | image_with_dt.is_valid() 119 | and depth_with_dt.is_valid() 120 | and segmentation_with_dt.is_valid() 121 | ): 122 | ok_timestamps_ns.append(select_timestamps_ns) 123 | else: 124 | continue 125 | image = image_with_dt.data().to_numpy_array() 126 | image = calibration.distort_by_calibration(image, dst_calib, src_calib) 127 | rgb_ims.append(image) 128 | 129 | depth_image = depth_with_dt.data().to_numpy_array() 130 | depth_image = calibration.distort_depth_by_calibration( 131 | depth_image, dst_calib, src_calib 132 | ) 133 | depth_ims.append(depth_image) 134 | 135 | segmentation_data = segmentation_with_dt.data().to_numpy_array() 136 | segmentation_data = calibration.distort_label_by_calibration( 137 | segmentation_data, dst_calib, src_calib 138 | ) 139 | segmentation_ims.append(segmentation_data) 140 | chunk_timestamps_ns = ok_timestamps_ns 141 | return rgb_ims, depth_ims, segmentation_ims, chunk_timestamps_ns 142 | 143 | 144 | def process_vid( 145 | input_adt_path: str, 146 | input_npz_path: str, 147 | output_npz_path: str, 148 | seq_name: str, 149 | chunks: list[int], 150 | ): 151 | """Processes multiple chunks of a single video.""" 152 | adt_v2_name = adt_v1v2_mappings.ADT_MAPPINGS[seq_name] 153 | sequence_path = os.path.join(input_adt_path, adt_v2_name) 154 | adt_processor = ADTVideoProcessor(sequence_path) 155 | 156 | for chunk_idx in tqdm.tqdm(chunks): 157 | track_fn = os.path.join(output_npz_path, f"{seq_name}_{chunk_idx}.npz") 158 | chunk_timestamps_ns = adt_processor.timestamps_ns[ 159 | chunk_idx * N_FRAMES : (chunk_idx + 1) * N_FRAMES 160 | ] 161 | 162 | rgb_ims, _, _, _ = adt_processor.extract_image_data(chunk_timestamps_ns) 163 | rgb_ims = [np.array(Image.fromarray(im).rotate(-90)) for im in rgb_ims] 164 | rgb_jpegs = [np.array(tf.io.encode_jpeg(im)).item() for im in rgb_ims] 165 | 166 | # Load query points. 167 | in_npz = np.load( 168 | os.path.join(input_npz_path, f"{seq_name}_{chunk_idx}.npz"), 169 | allow_pickle=True, 170 | ) 171 | queries_xyt = in_npz["queries_xyt"] 172 | trajectories = in_npz["tracks_XYZ"] 173 | visibilities = in_npz["visibility"] 174 | 175 | # Verify video means. 176 | video_means = np.stack([np.mean(x, axis=(0, 1)) for x in rgb_ims], axis=0) 177 | assert np.allclose(video_means, in_npz["video_means"], atol=1e-3) 178 | 179 | example = { 180 | "images_jpeg_bytes": rgb_jpegs, 181 | "queries_xyt": queries_xyt, 182 | "tracks_XYZ": trajectories, 183 | "visibility": visibilities, 184 | "fx_fy_cx_cy": np.array( 185 | [FOCAL_LENGTH, FOCAL_LENGTH, WIDTH / 2, HEIGHT / 2] 186 | ), 187 | } 188 | np.savez(track_fn, **example) 189 | -------------------------------------------------------------------------------- /tapnet/live_demo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Live Demo for Online TAPIR.""" 17 | 18 | import time 19 | 20 | import cv2 21 | import jax 22 | import jax.numpy as jnp 23 | import numpy as np 24 | from tapnet.models import tapir_model 25 | from tapnet.utils import model_utils 26 | 27 | 28 | NUM_POINTS = 8 29 | 30 | 31 | def load_checkpoint(checkpoint_path): 32 | ckpt_state = np.load(checkpoint_path, allow_pickle=True).item() 33 | return ckpt_state["params"], ckpt_state["state"] 34 | 35 | print("Loading checkpoint...") 36 | # -------------------- 37 | # Load checkpoint and initialize 38 | params, state = load_checkpoint( 39 | "tapnet/checkpoints/causal_tapir_checkpoint.npy" 40 | ) 41 | 42 | tapir = tapir_model.ParameterizedTAPIR( 43 | params=params, 44 | state=state, 45 | tapir_kwargs=dict( 46 | use_causal_conv=True, bilinear_interp_with_depthwise_conv=False 47 | ), 48 | ) 49 | 50 | 51 | def online_model_init(frames, points): 52 | feature_grids = tapir.get_feature_grids(frames, is_training=False) 53 | features = tapir.get_query_features( 54 | frames, 55 | is_training=False, 56 | query_points=points, 57 | feature_grids=feature_grids, 58 | ) 59 | return features 60 | 61 | 62 | def online_model_predict(frames, features, causal_context): 63 | """Compute point tracks and occlusions given frames and query points.""" 64 | feature_grids = tapir.get_feature_grids(frames, is_training=False) 65 | trajectories = tapir.estimate_trajectories( 66 | frames.shape[-3:-1], 67 | is_training=False, 68 | feature_grids=feature_grids, 69 | query_features=features, 70 | query_points_in_video=None, 71 | query_chunk_size=64, 72 | causal_context=causal_context, 73 | get_causal_context=True, 74 | ) 75 | causal_context = trajectories["causal_context"] 76 | del trajectories["causal_context"] 77 | return {k: v[-1] for k, v in trajectories.items()}, causal_context 78 | 79 | 80 | def get_frame(video_capture): 81 | r_val, image = video_capture.read() 82 | trunc = np.abs(image.shape[1] - image.shape[0]) // 2 83 | if image.shape[1] > image.shape[0]: 84 | image = image[:, trunc:-trunc] 85 | elif image.shape[1] < image.shape[0]: 86 | image = image[trunc:-trunc] 87 | return r_val, image 88 | 89 | 90 | print("Welcome to the TAPIR live demo.") 91 | print("Please note that if the framerate is low (<~12 fps), TAPIR performance") 92 | print("may degrade and you may need a more powerful GPU.") 93 | 94 | print("Creating model...") 95 | online_init_apply = jax.jit(online_model_init) 96 | 97 | online_predict_apply = jax.jit(online_model_predict) 98 | 99 | print("Initializing camera...") 100 | # -------------------- 101 | # Start point tracking 102 | vc = cv2.VideoCapture(0) 103 | 104 | vc.set(cv2.CAP_PROP_FRAME_HEIGHT, 240) 105 | 106 | if vc.isOpened(): # try to get the first frame 107 | rval, frame = get_frame(vc) 108 | else: 109 | raise ValueError("Unable to open camera.") 110 | 111 | pos = tuple() 112 | query_frame = True 113 | have_point = [False] * NUM_POINTS 114 | query_features = None 115 | causal_state = None 116 | next_query_idx = 0 117 | 118 | print("Compiling jax functions (this may take a while...)") 119 | # -------------------- 120 | # Call one time to compile 121 | query_points = jnp.zeros([NUM_POINTS, 3], dtype=jnp.float32) 122 | _ = online_init_apply( 123 | frames=model_utils.preprocess_frames(frame[None, None]), 124 | points=query_points[None, 0:1], 125 | ) 126 | jax.block_until_ready(query_features) 127 | query_features = online_init_apply( 128 | 
frames=model_utils.preprocess_frames(frame[None, None]), 129 | points=query_points[None, :], 130 | ) 131 | 132 | causal_state = tapir.construct_initial_causal_state( 133 | NUM_POINTS, len(query_features.resolutions) - 1 134 | ) 135 | 136 | prediction, causal_state = online_predict_apply( 137 | frames=model_utils.preprocess_frames(frame[None, None]), 138 | features=query_features, 139 | causal_context=causal_state, 140 | ) 141 | 142 | jax.block_until_ready(prediction["tracks"]) 143 | 144 | last_click_time = 0 145 | 146 | 147 | def mouse_click(event, x, y, flags, param): 148 | del flags, param 149 | global pos, query_frame, last_click_time 150 | 151 | # event fires multiple times per click sometimes?? 152 | if (time.time() - last_click_time) < 0.5: 153 | return 154 | 155 | if event == cv2.EVENT_LBUTTONDOWN: 156 | pos = (y, frame.shape[1] - x) 157 | query_frame = True 158 | last_click_time = time.time() 159 | 160 | 161 | cv2.namedWindow("Point Tracking") 162 | cv2.setMouseCallback("Point Tracking", mouse_click) 163 | 164 | t = time.time() 165 | step_counter = 0 166 | 167 | print("Press ESC to exit.") 168 | 169 | while rval: 170 | rval, frame = get_frame(vc) 171 | if query_frame: 172 | query_points = jnp.array((0,) + pos, dtype=jnp.float32) 173 | 174 | init_query_features = online_init_apply( 175 | frames=model_utils.preprocess_frames(frame[None, None]), 176 | points=query_points[None, None], 177 | ) 178 | query_frame = False 179 | query_features, causal_state = tapir.update_query_features( 180 | query_features=query_features, 181 | new_query_features=init_query_features, 182 | idx_to_update=np.array([next_query_idx]), 183 | causal_state=causal_state, 184 | ) 185 | have_point[next_query_idx] = True 186 | next_query_idx = (next_query_idx + 1) % NUM_POINTS 187 | if pos: 188 | prediction, causal_state = online_predict_apply( 189 | frames=model_utils.preprocess_frames(frame[None, None]), 190 | features=query_features, 191 | causal_context=causal_state, 192 | ) 193 | track = prediction["tracks"][0, :, 0] 194 | occlusion = prediction["occlusion"][0, :, 0] 195 | expected_dist = prediction["expected_dist"][0, :, 0] 196 | visibles = model_utils.postprocess_occlusions(occlusion, expected_dist) 197 | track = np.round(track) 198 | 199 | for i, _ in enumerate(have_point): 200 | if visibles[i] and have_point[i]: 201 | cv2.circle( 202 | frame, (int(track[i, 0]), int(track[i, 1])), 5, (255, 0, 0), -1 203 | ) 204 | if track[i, 0] < 16 and track[i, 1] < 16: 205 | print((i, next_query_idx)) 206 | cv2.imshow("Point Tracking", frame[:, ::-1]) 207 | if pos: 208 | step_counter += 1 209 | if time.time() - t > 5: 210 | print(f"{step_counter/(time.time()-t)} frames per second") 211 | t = time.time() 212 | step_counter = 0 213 | else: 214 | t = time.time() 215 | key = cv2.waitKey(1) 216 | 217 | if key == 27: # exit on ESC 218 | break 219 | 220 | cv2.destroyWindow("Point Tracking") 221 | vc.release() 222 | -------------------------------------------------------------------------------- /tapnet/tapvid/generate_tapvid.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Python script to generate a pickle with annotations from raw videos.""" 17 | 18 | import csv 19 | import dataclasses 20 | import io 21 | import math 22 | import os 23 | import pickle 24 | from typing import Dict, Iterator, List, Sequence, Tuple 25 | 26 | from absl import app 27 | from absl import flags 28 | from absl import logging 29 | import ffmpeg 30 | import numpy as np 31 | from PIL import Image 32 | 33 | FLAGS = flags.FLAGS 34 | 35 | flags.DEFINE_string('input_csv_path', None, 'Path to the input csv.') 36 | flags.DEFINE_string( 37 | 'output_base_path', 38 | None, 39 | 'Path to the output folder where pickle files will be stored.', 40 | ) 41 | flags.DEFINE_string( 42 | 'video_root_path', 43 | None, 44 | ( 45 | 'Path to the root directory of the extracted Kinetics-700-2020 videos' 46 | ' from the validation set.' 47 | ), 48 | ) 49 | flags.DEFINE_integer('num_shards', 10, 'Number of pickle shards to output.') 50 | 51 | _JPEG_HEADER = b'\xff\xd8' 52 | 53 | 54 | @dataclasses.dataclass(frozen=True) 55 | class Point: 56 | x: float 57 | y: float 58 | occluded: bool 59 | 60 | 61 | @dataclasses.dataclass(frozen=True) 62 | class Track: 63 | points: Tuple[Point, ...] 64 | 65 | 66 | @dataclasses.dataclass(frozen=True) 67 | class Video: 68 | youtube_id: str 69 | start_time_sec: int 70 | end_time_sec: int 71 | video_path: str 72 | tracks: Tuple[Track, ...] 73 | 74 | 75 | def csv_to_dataset( 76 | csv_path: str, videos_path: Dict[str, str] 77 | ) -> Tuple[Video, ...]: 78 | """Reads the input CSV and creates a list of `Video`s out of that.""" 79 | 80 | def points(row: Sequence[str]) -> Iterator[Point]: 81 | for i in range(250): 82 | x, y, occ = row[3 + 3 * i : 3 + 3 * i + 3] 83 | x = float(x) 84 | y = float(y) 85 | assert occ in ('0', '1') 86 | occ = occ == '1' 87 | yield Point(x, y, occ) 88 | 89 | logging.info('Reading CSV "%s".', csv_path) 90 | 91 | with open(csv_path) as f: 92 | reader = csv.reader(f, delimiter=',') 93 | 94 | tracks_per_video: Dict[Tuple[str, int, int], List[Track]] = {} 95 | for row in reader: 96 | assert len(row) == 3 + 3 * 250 97 | 98 | youtube_id, start_time_sec, end_time_sec = row[:3] 99 | start_time_sec = int(start_time_sec) 100 | end_time_sec = int(end_time_sec) 101 | key = (youtube_id, start_time_sec, end_time_sec) 102 | 103 | track = Track(tuple(points(row))) 104 | 105 | if key not in tracks_per_video: 106 | tracks_per_video[key] = [] 107 | tracks_per_video[key].append(track) 108 | 109 | def videos() -> Iterator[Video]: 110 | for key, tracks in tracks_per_video.items(): 111 | youtube_id, start_time_sec, end_time_sec = key 112 | 113 | name = f'{youtube_id}_{start_time_sec:06}_{end_time_sec:06}' 114 | if name not in videos_path: 115 | logging.warning('Video "%s" not downloaded. 
Skipping it.', name) 116 | continue 117 | video_path = videos_path[name] 118 | 119 | yield Video( 120 | youtube_id, start_time_sec, end_time_sec, video_path, tuple(tracks) 121 | ) 122 | 123 | return tuple(videos()) 124 | 125 | 126 | def get_paths_to_videos(video_root_path: str) -> Dict[str, str]: 127 | """Returns the relative path to each downloaded video.""" 128 | logging.info('Reading all videos in subfolders of "%s".', video_root_path) 129 | video_to_path: Dict[str, str] = {} 130 | for folder_or_video in os.listdir(video_root_path): 131 | path = os.path.join(video_root_path, folder_or_video) 132 | if os.path.isdir(path): 133 | subfolder_paths = get_paths_to_videos(path) 134 | for k, v in subfolder_paths.items(): 135 | assert k not in video_to_path 136 | video_to_path[k] = v 137 | elif folder_or_video.endswith('.mp4'): 138 | name = folder_or_video[:-4] # Remove '.mp4'. 139 | assert name not in video_to_path 140 | video_to_path[name] = path 141 | 142 | return video_to_path 143 | 144 | 145 | def extract_frames(video_path: str, fps: float) -> Tuple[bytes, ...]: 146 | """Extracts list of jpeg bytes from the given video using ffmpeg.""" 147 | cmd = ( 148 | ffmpeg.input(video_path) 149 | .filter('fps', fps=fps) 150 | .output('pipe:', format='image2pipe') 151 | ) 152 | jpeg_bytes, _ = cmd.run(capture_stdout=True, quiet=True) 153 | jpeg_bytes = jpeg_bytes.split(_JPEG_HEADER)[1:] 154 | jpeg_bytes = map(lambda x: _JPEG_HEADER + x, jpeg_bytes) 155 | return tuple(jpeg_bytes) 156 | 157 | 158 | def generate_example(video: Video) -> Dict[str, np.ndarray]: 159 | """Generates a dictionary with the info from a `Video`.""" 160 | example: Dict[str, np.ndarray] = {} 161 | 162 | imgs_encoded = extract_frames(video.video_path, 25.0) 163 | if len(imgs_encoded) > 250: 164 | imgs_encoded = imgs_encoded[:250] 165 | 166 | if len(imgs_encoded) < 250: 167 | # Clip is shorter than 10s. 168 | num_frames = len(imgs_encoded) 169 | new_tracks = tuple( 170 | Track(tuple(t.points[:num_frames])) for t in video.tracks 171 | ) 172 | video = Video( 173 | video.youtube_id, 174 | video.start_time_sec, 175 | video.end_time_sec, 176 | video.video_path, 177 | new_tracks, 178 | ) 179 | 180 | example['video'] = np.array(imgs_encoded) 181 | byteio = io.BytesIO(imgs_encoded[0]) 182 | img = Image.open(byteio) 183 | height, width, _ = np.array(img).shape 184 | 185 | points = [] 186 | occluded = [] 187 | for track in video.tracks: 188 | points.append( 189 | [ 190 | [(p.x * width - 0.5) / width, (p.y * height - 0.5) / height] 191 | for p in track.points 192 | ] 193 | ) 194 | occluded.append([p.occluded for p in track.points]) 195 | 196 | example['points'] = np.array(points, dtype=np.float64) 197 | example['occluded'] = np.array(occluded, dtype=bool) 198 | 199 | return example 200 | 201 | 202 | def main(argv: Sequence[str]) -> None: 203 | del argv 204 | 205 | output_folder = FLAGS.output_base_path 206 | if output_folder and not os.path.exists(output_folder): 207 | os.makedirs(output_folder) 208 | 209 | # Reads data. 210 | videos_path = get_paths_to_videos(FLAGS.video_root_path) 211 | videos = csv_to_dataset(FLAGS.input_csv_path, videos_path) 212 | 213 | # Process the dataset and store pickles. 
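  # Examples are accumulated in `data` and flushed to a new pickle shard once
  # `num_examples_per_shard` examples are collected or the last video is reached.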
214 | num_examples_per_shard = int(math.ceil(len(videos) / FLAGS.num_shards)) 215 | shard = 0 216 | data = [] 217 | for i, video in enumerate(videos): 218 | print( 219 | 'Processing example %d of %d (%d%%) \r' 220 | % (i, len(videos), i * 100 / len(videos)), 221 | end='', 222 | ) 223 | data.append(generate_example(video)) 224 | if i == len(videos) - 1 or len(data) == num_examples_per_shard: 225 | shard_path = os.path.join( 226 | output_folder, f'{shard:04}_of_{FLAGS.num_shards:04}.pkl' 227 | ) 228 | logging.info('Writing file "%s".', shard_path) 229 | with open(shard_path, 'wb') as f: 230 | pickle.dump(data, f) 231 | data.clear() 232 | shard += 1 233 | 234 | 235 | if __name__ == '__main__': 236 | app.run(main) 237 | -------------------------------------------------------------------------------- /tapnet/pytorch_live_demo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Live Demo for PyTorch Online TAPIR.""" 17 | 18 | import time 19 | 20 | import cv2 21 | import numpy as np 22 | from tapnet.torch import tapir_model 23 | import torch 24 | import torch.nn.functional as F 25 | import tree 26 | 27 | NUM_POINTS = 8 28 | 29 | 30 | def preprocess_frames(frames): 31 | """Preprocess frames to model inputs. 32 | 33 | Args: 34 | frames: [num_frames, height, width, 3], [0, 255], np.uint8 35 | 36 | Returns: 37 | frames: [num_frames, height, width, 3], [-1, 1], np.float32 38 | """ 39 | frames = frames.float() 40 | frames = frames / 255 * 2 - 1 41 | return frames 42 | 43 | 44 | def online_model_init(frames, points): 45 | """Initialize query features for the query points.""" 46 | frames = preprocess_frames(frames) 47 | feature_grids = model.get_feature_grids(frames, is_training=False) 48 | features = model.get_query_features( 49 | frames, 50 | is_training=False, 51 | query_points=points, 52 | feature_grids=feature_grids, 53 | ) 54 | return features 55 | 56 | 57 | def postprocess_occlusions(occlusions, expected_dist): 58 | visibles = (1 - F.sigmoid(occlusions)) * (1 - F.sigmoid(expected_dist)) > 0.5 59 | return visibles 60 | 61 | 62 | def online_model_predict(frames, features, causal_context): 63 | """Compute point tracks and occlusions given frames and query points.""" 64 | frames = preprocess_frames(frames) 65 | feature_grids = model.get_feature_grids(frames, is_training=False) 66 | trajectories = model.estimate_trajectories( 67 | frames.shape[-3:-1], 68 | is_training=False, 69 | feature_grids=feature_grids, 70 | query_features=features, 71 | query_points_in_video=None, 72 | query_chunk_size=64, 73 | causal_context=causal_context, 74 | get_causal_context=True, 75 | ) 76 | causal_context = trajectories["causal_context"] 77 | del trajectories["causal_context"] 78 | # Take only the predictions for the final resolution. 
79 | # For running on higher resolution, it's typically better to average across 80 | # resolutions. 81 | tracks = trajectories["tracks"][-1] 82 | occlusions = trajectories["occlusion"][-1] 83 | uncertainty = trajectories["expected_dist"][-1] 84 | visibles = postprocess_occlusions(occlusions, uncertainty) 85 | return tracks, visibles, causal_context 86 | 87 | 88 | def get_frame(video_capture): 89 | r_val, image = video_capture.read() 90 | trunc = np.abs(image.shape[1] - image.shape[0]) // 2 91 | if image.shape[1] > image.shape[0]: 92 | image = image[:, trunc:-trunc] 93 | elif image.shape[1] < image.shape[0]: 94 | image = image[trunc:-trunc] 95 | return r_val, image 96 | 97 | 98 | print("Welcome to the TAPIR PyTorch live demo.") 99 | print("Please note that if the framerate is low (<~12 fps), TAPIR performance") 100 | print("may degrade and you may need a more powerful GPU.") 101 | 102 | if torch.cuda.is_available(): 103 | device = torch.device("cuda") 104 | else: 105 | device = torch.device("cpu") 106 | 107 | # -------------------- 108 | # Load checkpoint and initialize 109 | print("Creating model...") 110 | model = tapir_model.TAPIR(pyramid_level=1, use_casual_conv=True) 111 | print("Loading checkpoint...") 112 | model.load_state_dict( 113 | torch.load("tapnet/checkpoints/causal_bootstapir_checkpoint.pt") 114 | ) 115 | model = model.to(device) 116 | model = model.eval() 117 | torch.set_grad_enabled(False) 118 | 119 | # -------------------- 120 | # Start point tracking 121 | print("Initializing camera...") 122 | vc = cv2.VideoCapture(0) 123 | 124 | vc.set(cv2.CAP_PROP_FRAME_HEIGHT, 240) 125 | 126 | if vc.isOpened(): # try to get the first frame 127 | rval, frame = get_frame(vc) 128 | else: 129 | raise ValueError("Unable to open camera.") 130 | 131 | pos = tuple() 132 | query_frame = True 133 | have_point = [False] * NUM_POINTS 134 | 135 | query_points = torch.zeros([NUM_POINTS, 3], dtype=torch.float32) 136 | query_points = query_points.to(device) 137 | frame = torch.tensor(frame).to(device) 138 | 139 | query_features = online_model_init( 140 | frames=frame[None, None], points=query_points[None, :] 141 | ) 142 | 143 | causal_state = model.construct_initial_causal_state( 144 | NUM_POINTS, len(query_features.resolutions) - 1 145 | ) 146 | causal_state = tree.map_structure(lambda x: x.to(device), causal_state) 147 | 148 | prediction, visible, causal_state = online_model_predict( 149 | frames=frame[None, None], 150 | features=query_features, 151 | causal_context=causal_state, 152 | ) 153 | 154 | next_query_idx = 0 155 | last_click_time = 0 156 | 157 | 158 | def mouse_click(event, x, y, flags, param): 159 | del flags, param 160 | global pos, query_frame, last_click_time 161 | 162 | # event fires multiple times per click sometimes?? 
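  # Debounce: ignore events that arrive within 0.5 seconds of the last
  # accepted click.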
163 | if (time.time() - last_click_time) < 0.5: 164 | return 165 | 166 | if event == cv2.EVENT_LBUTTONDOWN: 167 | pos = (y, frame.shape[1] - x) 168 | query_frame = True 169 | last_click_time = time.time() 170 | 171 | 172 | cv2.namedWindow("Point Tracking") 173 | cv2.setMouseCallback("Point Tracking", mouse_click) 174 | 175 | t = time.time() 176 | step_counter = 0 177 | 178 | print("Press ESC to exit.") 179 | 180 | while rval: 181 | rval, frame = get_frame(vc) 182 | numpy_frame = frame 183 | if query_frame: 184 | query_points = np.array((0,) + pos, dtype=np.float32) 185 | frame = torch.tensor(frame).to(device) 186 | query_points = torch.tensor(query_points).to(device) 187 | 188 | init_query_features = online_model_init( 189 | frames=frame[None, None], points=query_points[None, None] 190 | ) 191 | query_frame = False 192 | query_features, causal_state = model.update_query_features( 193 | query_features=query_features, 194 | new_query_features=init_query_features, 195 | idx_to_update=np.array([next_query_idx]), 196 | causal_state=causal_state, 197 | ) 198 | have_point[next_query_idx] = True 199 | next_query_idx = (next_query_idx + 1) % NUM_POINTS 200 | if pos: 201 | frame = torch.tensor(frame).to(device) 202 | track, visible, causal_state = online_model_predict( 203 | frames=frame[None, None], 204 | features=query_features, 205 | causal_context=causal_state, 206 | ) 207 | track = np.round(track.cpu().numpy()) 208 | visible = visible.cpu().numpy() 209 | 210 | for i, _ in enumerate(have_point): 211 | if visible[0, i, 0] and have_point[i]: 212 | cv2.circle( 213 | numpy_frame, 214 | (int(track[0, i, 0, 0]), int(track[0, i, 0, 1])), 215 | 5, 216 | (255, 0, 0), 217 | -1, 218 | ) 219 | if track[0, i, 0, 0] < 16 and track[0, i, 0, 1] < 16: 220 | print((i, next_query_idx)) 221 | cv2.imshow("Point Tracking", numpy_frame[:, ::-1]) 222 | if pos: 223 | step_counter += 1 224 | if time.time() - t > 5: 225 | print(f"{step_counter/(time.time()-t)} frames per second") 226 | t = time.time() 227 | step_counter = 0 228 | else: 229 | t = time.time() 230 | key = cv2.waitKey(1) 231 | 232 | if key == 27: # exit on ESC 233 | break 234 | 235 | cv2.destroyWindow("Point Tracking") 236 | vc.release() 237 | -------------------------------------------------------------------------------- /tapnet/training/task.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Abstract task interface with documentation. 
17 | 18 | """ 19 | 20 | import abc 21 | from typing import Mapping, Optional, Tuple 22 | 23 | import chex 24 | import typing_extensions 25 | 26 | 27 | class SharedModule(typing_extensions.Protocol): 28 | 29 | def __call__( 30 | self, 31 | video: chex.Array, 32 | is_training: bool, 33 | query_points: chex.Array, 34 | query_chunk_size: Optional[int] = None, 35 | get_query_feats: bool = False, 36 | **kwargs, 37 | ) -> Mapping[str, chex.Array]: 38 | """Runs a forward pass of a module. 39 | 40 | Args: 41 | video: A 4-D or 5-D tensor representing a batch of sequences of images. In 42 | the 4-D case, we assume the entire batch has been concatenated along the 43 | batch dimension, one sequence after the other. This can speed up 44 | inference on the TPU and save memory. 45 | is_training: Whether we are training. 46 | query_points: The query points for which we compute tracks. 47 | query_chunk_size: When computing cost volumes, break the queries into 48 | chunks of this size to save memory. 49 | get_query_feats: If True, also return the features for each query obtained 50 | using bilinear interpolation from the feature grid 51 | **kwargs: Additional module-specific parameters. 52 | 53 | Returns: 54 | Module outputs. 55 | """ 56 | 57 | 58 | class WrappedForwardFn(typing_extensions.Protocol): 59 | """Forward function, wrapped by haiku. 60 | 61 | This wrapped forward function will inject the shared_modules and allow them 62 | to use shared params. It should be called inside a loss_fn using the same 63 | signature as `Task.forward_fn` (minus the shared_modules). 64 | """ 65 | 66 | def __call__( 67 | self, 68 | params: chex.ArrayTree, 69 | state: chex.ArrayTree, 70 | rng: chex.PRNGKey, 71 | inputs: chex.ArrayTree, 72 | is_training: bool, 73 | input_key: Optional[str] = None, 74 | query_chunk_size: int = 16, 75 | get_query_feats: bool = True, 76 | ) -> Mapping[str, chex.Array]: 77 | """Forward pass for predicting point tracks. 78 | 79 | Args: 80 | params: hk.Params with the model parameters 81 | state: hk.State with the model state 82 | rng: jax.random.PRNGKey for random number generation. 83 | inputs: Input dict. Inference will be performed on will be performed on 84 | inputs[input_key]['video'] (with fallback to the input_key specified in 85 | the constructor). Input videos should be a standard video tensor 86 | ([batch, num_frames, height, width, 3]) normalize to [-1,1]. 87 | inputs[input_key]['query_points'] specifies the query point locations, 88 | of shape [batch, num_queries, 3], where each query is [t,y,x] 89 | coordinates normalized to between -1 and 1. 90 | is_training: Whether the model is in training mode. 91 | input_key: Run on inputs[input_key]['video']. If None, use the input_key 92 | from the constructor. 93 | query_chunk_size: Compute predictions on this many queries simultaneously. 94 | This saves memory as the cost volumes can be very large. 95 | get_query_feats: If True, also return features for each query. 96 | 97 | Returns: 98 | Result dict produced by calling the model. 99 | """ 100 | 101 | 102 | class Task(abc.ABC): 103 | """An abstract Task definition.""" 104 | 105 | @abc.abstractmethod 106 | def forward_fn( 107 | self, 108 | inputs: chex.ArrayTree, 109 | is_training: bool, 110 | shared_modules: Optional[Mapping[str, SharedModule]] = None, 111 | ) -> chex.ArrayTree: 112 | """Run the model forward pass and construct all required Haiku modules. 113 | 114 | Args: 115 | inputs: model input tensors. 
This is a dict keyed by dataset name, where 116 | the value for each key is an item from the specified dataset. 117 | is_training: Is the forward pass in training mode or not. 118 | shared_modules: A dict of Haiku modules, keyed by module name, which 119 | can be used to construct the modules which are shared across different 120 | tasks. 121 | 122 | Returns: 123 | Anything. The important part is that this must construct all modules that 124 | Haiku needs to initialize. 125 | 126 | """ 127 | 128 | def get_gradients( 129 | self, 130 | params: chex.ArrayTree, 131 | state: chex.ArrayTree, 132 | inputs: chex.ArrayTree, 133 | rng: chex.PRNGKey, 134 | global_step: chex.Array, 135 | wrapped_forward_fn: WrappedForwardFn, 136 | is_training: bool = True, 137 | ) -> Tuple[chex.ArrayTree, chex.ArrayTree, Mapping[str, chex.Array]]: 138 | """Get gradients for this tasks's loss function. 139 | 140 | Params, state, inputs, rng, and global_step are pmapped, i.e. a separate 141 | copy on each device. 142 | 143 | Args: 144 | params: Haiku parameters 145 | state: Haiku state 146 | inputs: model input tensors. This is a dict keyed by dataset name, where 147 | the value for each key is an item from the specified dataset. 148 | rng: random number state 149 | global_step: global step 150 | wrapped_forward_fn: A wrapper for the forward function that will inject 151 | the shared_modules and allow them to use shared params. It should be 152 | called inside a loss_fn using the same signature as forward_fn 153 | (minus the shared_modules). 154 | is_training: Is the forward pass in training mode or not. 155 | 156 | Returns: 157 | grads: A set of gradients compatible with optax apply_gradients (these 158 | will be summed across tasks). 159 | state: An updated Haiku state. The returned state will be passed to the 160 | next task in the list. 161 | scalars: A dict of (pmapped) scalars to be logged for this task. All 162 | dict keys will have the task name prepended before they are logged. 163 | 164 | """ 165 | raise NotImplementedError() 166 | 167 | def evaluate( 168 | self, 169 | global_step: chex.Array, 170 | params: chex.ArrayTree, 171 | state: chex.ArrayTree, 172 | rng: chex.PRNGKey, 173 | wrapped_forward_fn: WrappedForwardFn, 174 | mode: str, 175 | ) -> Mapping[str, chex.Array]: 176 | """Evaluate this task's performance on a downstream benchmark. 177 | 178 | Args: 179 | global_step: global step 180 | params: Haiku parameters 181 | state: Haiku state 182 | rng: random number state 183 | wrapped_forward_fn: A wrapper for the forward function that will inject 184 | the shared_modules and allow them to use shared params. 185 | mode: A string mode used to determine, e.g., which dataset or split to 186 | evaluate on. This will be the same value as the 'mode' parameter 187 | used to launch different eval jobs in Jaxline. 188 | 189 | Returns: 190 | scalars: A dict of scalars to be logged for this task. All 191 | dict keys will have the task name prepended before they are logged. 192 | 193 | """ 194 | raise NotImplementedError() 195 | 196 | 197 | 198 | -------------------------------------------------------------------------------- /tapnet/tapnext/tapnext_torch_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utils for TAPNext torch implementation.""" 17 | 18 | import re 19 | import numpy as np 20 | import torch 21 | import torch.nn.functional as F 22 | 23 | 24 | def get_window(coord, softmax, radius: int = 8): 25 | b = coord.shape[0] 26 | start = torch.floor(coord - radius - 0.5).int() 27 | start.clamp_(min=0) 28 | indices = start + torch.arange(radius * 2 + 1, device=softmax.device).repeat( 29 | b, 1 30 | ) 31 | # this is to simulate one corner case of jax implementation 32 | shift = (indices.max(1).values - softmax.shape[1] + 1).clamp(min=0) 33 | indices -= shift.unsqueeze(1) 34 | softmax = softmax.gather(dim=1, index=indices) 35 | return softmax, indices + 0.5 36 | 37 | 38 | def tracker_certainty(coord_yx, track_logits, radius=8): 39 | """Computes the certainty of the tracker.""" 40 | shape = coord_yx.shape[:-1] 41 | coord_yx = coord_yx.flatten(0, -2) 42 | track_logits = track_logits.flatten(0, -2) 43 | # track_logits.shape == [b, 512] 44 | # coord_yx.shape == [b, 2] 45 | logits_y, logits_x = track_logits.chunk(2, dim=-1) 46 | track_softmax_y = F.softmax(logits_y, dim=-1) 47 | track_softmax_x = F.softmax(logits_x, dim=-1) 48 | sm_y, coord_y = get_window(coord_yx[:, 0:1], track_softmax_y) 49 | sm_x, coord_x = get_window(coord_yx[:, 1:2], track_softmax_x) 50 | sm = sm_y[..., :, None] * sm_x[..., None, :] 51 | grid_x, grid_y = torch.vmap(torch.meshgrid)(coord_x, coord_y) 52 | # grid_x.shape == [b, N, N] 53 | grid = torch.stack([grid_y, grid_x], dim=-1) 54 | in_radius = ((grid - coord_yx[:, None, None]) ** 2).sum(-1) <= ( 55 | (radius**2) + 1e-8 56 | ) 57 | return (sm * in_radius).sum(-1).sum(-1).reshape(*shape, 1) 58 | 59 | 60 | def restore_model_from_jax_checkpoint(model, ckpt_path): 61 | """Restores a TAPNext model from a JAX checkpoint.""" 62 | ckpt = {k: v for k, v in np.load(ckpt_path).items()} 63 | model.lin_proj.weight.data.copy_( 64 | torch.tensor(ckpt['backbone/embedding/kernel'][0]).permute(3, 2, 0, 1) 65 | ) 66 | model.lin_proj.bias.data.copy_(torch.tensor(ckpt['backbone/embedding/bias'])) 67 | model.mask_token.data.copy_(torch.tensor(ckpt['backbone/mask_token'])) 68 | model.point_query_token.data.copy_( 69 | torch.tensor(ckpt['backbone/point_query_token']) 70 | ) 71 | model.unknown_token.data.copy_(torch.tensor(ckpt['backbone/unknown_token'])) 72 | model.image_pos_emb.data.copy_(torch.tensor(ckpt['backbone/pos_embedding'])) 73 | model.encoder_norm.weight.data.copy_( 74 | torch.tensor(ckpt['backbone/Transformer/encoder_norm/scale']) 75 | ) 76 | model.encoder_norm.bias.data.copy_( 77 | torch.tensor(ckpt['backbone/Transformer/encoder_norm/bias']) 78 | ) 79 | for layer in range(12): 80 | # convert ssm part 81 | prefix = f'backbone/Transformer/encoderblock_{layer}/ssm_block' 82 | ssm_params = { 83 | key: torch.tensor( 84 | ckpt[ 85 | f'{prefix}/' 86 | + re.sub('weight', 'kernel', re.sub(r'\.', '/', key)) 87 | ] 88 | ) 89 | for key, _ in model.blocks[layer].ssm_block.named_parameters() 90 | } 91 | for key in ssm_params: 92 | if 'weight' in key: 93 | 
ssm_params[key] = ssm_params[key].T 94 | model.blocks[layer].ssm_block.load_state_dict(ssm_params) 95 | 96 | # convert vit part 97 | vit_params = { 98 | re.sub( 99 | f'backbone/Transformer/encoderblock_{layer}/vit_block/', '', k 100 | ): v 101 | for k, v in ckpt.items() 102 | if f'backbone/Transformer/encoderblock_{layer}/vit_block' in k 103 | } 104 | torch_vit_params = {} 105 | torch_vit_params['ln_1.weight'] = vit_params['LayerNorm_0/scale'] 106 | torch_vit_params['ln_1.bias'] = vit_params['LayerNorm_0/bias'] 107 | torch_vit_params['ln_2.weight'] = vit_params['LayerNorm_1/scale'] 108 | torch_vit_params['ln_2.bias'] = vit_params['LayerNorm_1/bias'] 109 | torch_vit_params['mlp.0.weight'] = vit_params['MlpBlock_0/Dense_0/kernel'].T 110 | torch_vit_params['mlp.0.bias'] = vit_params['MlpBlock_0/Dense_0/bias'] 111 | torch_vit_params['mlp.3.weight'] = vit_params['MlpBlock_0/Dense_1/kernel'].T 112 | torch_vit_params['mlp.3.bias'] = vit_params['MlpBlock_0/Dense_1/bias'] 113 | torch_vit_params['self_attention.in_proj_weight'] = np.concatenate( 114 | [ 115 | vit_params['MultiHeadDotProductAttention_0/query/kernel'] 116 | .reshape(768, 768) 117 | .T, 118 | vit_params['MultiHeadDotProductAttention_0/key/kernel'] 119 | .reshape(768, 768) 120 | .T, 121 | vit_params['MultiHeadDotProductAttention_0/value/kernel'] 122 | .reshape(768, 768) 123 | .T, 124 | ], 125 | axis=0, 126 | ) 127 | torch_vit_params['self_attention.in_proj_bias'] = np.concatenate([ 128 | vit_params['MultiHeadDotProductAttention_0/query/bias'].flatten(), 129 | vit_params['MultiHeadDotProductAttention_0/key/bias'].flatten(), 130 | vit_params['MultiHeadDotProductAttention_0/value/bias'].flatten(), 131 | ]) 132 | torch_vit_params['self_attention.out_proj.weight'] = ( 133 | vit_params['MultiHeadDotProductAttention_0/out/kernel'] 134 | .reshape(768, 768) 135 | .T 136 | ) 137 | torch_vit_params['self_attention.out_proj.bias'] = vit_params[ 138 | 'MultiHeadDotProductAttention_0/out/bias' 139 | ].flatten() 140 | for k in torch_vit_params: 141 | torch_vit_params[k] = torch.tensor(np.array(torch_vit_params[k])) 142 | model.blocks[layer].vit_block.load_state_dict(torch_vit_params) 143 | model.visible_head[0].weight.data.copy_( 144 | torch.from_numpy(ckpt['visible_head/layers_0/kernel'].T) 145 | ) 146 | model.visible_head[0].bias.data.copy_( 147 | torch.from_numpy(ckpt['visible_head/layers_0/bias']) 148 | ) 149 | model.visible_head[1].weight.data.copy_( 150 | torch.from_numpy(ckpt['visible_head/layers_1/scale']) 151 | ) 152 | model.visible_head[1].bias.data.copy_( 153 | torch.from_numpy(ckpt['visible_head/layers_1/bias']) 154 | ) 155 | model.visible_head[3].weight.data.copy_( 156 | torch.from_numpy(ckpt['visible_head/layers_3/kernel'].T) 157 | ) 158 | model.visible_head[3].bias.data.copy_( 159 | torch.from_numpy(ckpt['visible_head/layers_3/bias']) 160 | ) 161 | model.visible_head[4].weight.data.copy_( 162 | torch.from_numpy(ckpt['visible_head/layers_4/scale']) 163 | ) 164 | model.visible_head[4].bias.data.copy_( 165 | torch.from_numpy(ckpt['visible_head/layers_4/bias']) 166 | ) 167 | model.visible_head[6].weight.data.copy_( 168 | torch.from_numpy(ckpt['visible_head/layers_6/kernel'].T) 169 | ) 170 | model.visible_head[6].bias.data.copy_( 171 | torch.from_numpy(ckpt['visible_head/layers_6/bias']) 172 | ) 173 | 174 | model.coordinate_head[0].weight.data.copy_( 175 | torch.from_numpy(ckpt['coordinate_head/layers_0/kernel'].T) 176 | ) 177 | model.coordinate_head[0].bias.data.copy_( 178 | torch.from_numpy(ckpt['coordinate_head/layers_0/bias']) 179 | ) 
180 | model.coordinate_head[1].weight.data.copy_( 181 | torch.from_numpy(ckpt['coordinate_head/layers_1/scale']) 182 | ) 183 | model.coordinate_head[1].bias.data.copy_( 184 | torch.from_numpy(ckpt['coordinate_head/layers_1/bias']) 185 | ) 186 | model.coordinate_head[3].weight.data.copy_( 187 | torch.from_numpy(ckpt['coordinate_head/layers_3/kernel'].T) 188 | ) 189 | model.coordinate_head[3].bias.data.copy_( 190 | torch.from_numpy(ckpt['coordinate_head/layers_3/bias']) 191 | ) 192 | model.coordinate_head[4].weight.data.copy_( 193 | torch.from_numpy(ckpt['coordinate_head/layers_4/scale']) 194 | ) 195 | model.coordinate_head[4].bias.data.copy_( 196 | torch.from_numpy(ckpt['coordinate_head/layers_4/bias']) 197 | ) 198 | model.coordinate_head[6].weight.data.copy_( 199 | torch.from_numpy(ckpt['coordinate_head/layers_6/kernel'].T) 200 | ) 201 | model.coordinate_head[6].bias.data.copy_( 202 | torch.from_numpy(ckpt['coordinate_head/layers_6/bias']) 203 | ) 204 | return model 205 | -------------------------------------------------------------------------------- /tapnet/tapvid/README.md: -------------------------------------------------------------------------------- 1 | # TAP-Vid: A Benchmark for Tracking Any Point in a Video 2 | 3 | https://github.com/google-deepmind/tapnet/assets/4534987/ff5fa5e3-ed37-4480-ad39-42a1e2744d8b 4 | 5 | [TAP-Vid](https://tapvid.github.io) is a dataset of videos along with point tracks, either manually annotated or obtained from a simulator. The aim is to evaluate tracking of any trackable point on any solid physical surface. Algorithms receive a single query point on some frame, and must produce the rest of the track, i.e., including where that point has moved to (if visible), and whether it is visible, on every other frame. This requires point-level precision (unlike prior work on box and segment tracking) potentially on deformable surfaces (unlike structure from motion) over the long term (unlike optical flow) on potentially any object (i.e. class-agnostic, unlike prior class-specific keypoint tracking on humans). 6 | 7 | Our full benchmark incorporates 4 datasets: 30 videos from the [DAVIS val set](https://storage.googleapis.com/dm-tapnet/tapvid_davis.zip), 1000 videos from the [Kinetics val set](https://storage.googleapis.com/dm-tapnet/tapvid_kinetics.zip), 50 synthetic [Deepmind Robotics videos](https://storage.googleapis.com/dm-tapnet/tapvid_rgb_stacking.zip) for evaluation, and (almost infinite) point track ground truth on the large-scale synthetic [Kubric dataset](https://github.com/google-research/kubric/tree/main/challenges/point_tracking) for training. 8 | 9 | We also include a point tracking model TAP-Net, with code to train it on Kubric dataset. TAP-Net outperforms both optical flow and structure-from-motion methods on the TAP-Vid benchmark while achieving state-of-the-art performance on unsupervised human keypoint tracking on JHMDB, even though the model tracks points on clothes and skin rather than the joints as intended by the benchmark. 10 | 11 | ## Evaluating on TAP-Vid 12 | 13 | [`evaluation_datasets.py`](tapvid/evaluation_datasets.py) is intended to be a 14 | stand-alone, copy-and-pasteable reader and evaluator, which depends only 15 | on numpy and other basic tools. Tensorflow is required only for reading Kubric 16 | (which provides a tensorflow reader by default) as well as file operations, 17 | which should be straightforward to replace for systems without Tensorflow. 
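The reader and metric entry points described in the rest of this section fit together roughly as shown below. This is a minimal sketch rather than a reference script: the reader name (`create_davis_dataset`), the example keys, and the keyword arguments of `compute_tapvid_metrics` are assumptions, and `run_my_tracker` is a hypothetical stand-in for your own model, so consult the docstrings in `evaluation_datasets.py` for the exact interface.

```python
import numpy as np
from tapnet.tapvid import evaluation_datasets as eval_ds  # import path assumed

per_video_metrics = []
for example in eval_ds.create_davis_dataset('tapvid_davis.pkl', query_mode='first'):
  # Each example is a dict of numpy arrays: the video, the query points, the
  # target points and the occlusion flags (key names below are assumptions;
  # some readers may nest this dict under a dataset-name key).
  gt = example
  # `run_my_tracker` is a hypothetical stand-in returning predicted tracks and
  # occlusion flags with the same shapes as the ground-truth arrays.
  pred_tracks, pred_occluded = run_my_tracker(gt['video'], gt['query_points'])
  per_video_metrics.append(
      eval_ds.compute_tapvid_metrics(
          query_points=gt['query_points'],
          gt_occluded=gt['occluded'],
          gt_tracks=gt['target_points'],
          pred_occluded=pred_occluded,
          pred_tracks=pred_tracks,
          query_mode='first',
      )
  )

# The final paper numbers are a simple average over all videos in the dataset.
final_metrics = {
    k: float(np.mean([m[k] for m in per_video_metrics]))
    for k in per_video_metrics[0]
}
```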
18 | 19 | For each dataset, there is a basic reader which will produce examples, dicts of 20 | numpy arrays containing the video, the query points, the target points, and the 21 | occlusion flag. Evaluation datasets may be used with one of two possible values 22 | for `query_mode`: `strided` (each trajectory is queried multiple times, with 23 | a fixed-length stride between queries) or `first` (each trajectory is queried 24 | once, with only the first visible point on the query). For details on outputs, 25 | see the documentation for `sample_queries_strided` and `sample_queries_first`. 26 | 27 | To compute metrics, use `compute_tapvid_metrics` in the same file. This 28 | computes results on each batch; the final metrics for the paper can be computed 29 | by simple averaging across all videos in the dataset. See the documentation for 30 | more details. 31 | 32 | Note that the outputs for a single query point *should not depend on the other 33 | queries defined in the batch*: that is, the outputs should be the same whether 34 | the queries are passed one at a time or all at once. This is important because 35 | the other queries may leak information about how pixels are grouped and how they 36 | move. This property is not enforced in the current evaluation code, but 37 | algorithms which violate this principle should not be considered valid 38 | competitors on this benchmark. 39 | 40 | Our readers also supply videos resized at 256x256 resolution. If algorithms can handle it, we encourage using full-resolution videos instead; we anticipate that 41 | predictions on such videos would be scaled to match a 256x256 resolution 42 | before computing metrics. Such predictions would, however, be evaluated as a separate category: we don't consider them comparable to those produced from lower-resolution videos. 43 | 44 | ## Comparison of Tracking With and Without Optical Flow 45 | 46 | When annotating videos for the TAP-Vid benchmark, we use a track assist algorithm that interpolates between the sparse points that the annotators click, since requiring annotators to click every frame is prohibitively expensive. Specifically, we find tracks which minimize the discrepancy with the optical flow while still connecting the chosen points. Annotators will then check the interpolations and repeat the annotation until they observe no drift. 47 | 48 | To validate that this is a better approach than a simple linear interpolation between clicked points, we annotated several DAVIS videos twice and [compare them side by side](https://storage.googleapis.com/dm-tapnet/content/flow_tracker.html), once using the flow-based interpolation, and again using a naive linear interpolation, which simply moves the point at a constant velocity between points. 49 | 50 | **Flow assist point annotation**: You can run this colab demo to see how point tracks are annotated with optical flow assistance. 51 | 52 | ## Downloading TAP-Vid-DAVIS and TAP-Vid-RGB-Stacking 53 | 54 | The data are contained in pickle files with download links: [TAP-Vid-DAVIS](https://storage.googleapis.com/dm-tapnet/tapvid_davis.zip) and [TAP-Vid-RGB-stacking](https://storage.googleapis.com/dm-tapnet/tapvid_rgb_stacking.zip).
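After downloading and unzipping, the pickles can be inspected directly with plain Python. The snippet below is a minimal sketch: the file path and the inner key names (`video`, `points`, `occluded`) are assumptions consistent with the format described in the next paragraph, so adjust them if your copy differs.

```python
import pickle

# Path is illustrative; point it at the extracted tapvid_davis.pkl.
with open('tapvid_davis/tapvid_davis.pkl', 'rb') as f:
  data = pickle.load(f)  # dict: DAVIS video name -> example dict

name, example = next(iter(data.items()))
video = example['video']        # [num_frames, height, width, 3], uint8
points = example['points']      # [num_points, num_frames, 2], float32, x/y
occluded = example['occluded']  # [num_points, num_frames], bool
print(name, video.shape, points.shape, occluded.shape)

# TAP-Vid-RGB-Stacking (and RoboTAP, see below) use the same per-video
# structure; RGB-Stacking is stored as a list of such dicts instead of a
# dictionary keyed by video name.
```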
55 | 56 | For DAVIS, the pickle file contains a dictionary, where each key is a DAVIS video name, and the values are the frames (4D uint8 tensor), the points (float32 tensor with 3 axes; the first is point id, the second is time, and the third is x/y), and the occlusions (bool tensor with 2 axes; the first is point id, the second is time). RGB-Stacking is the same format, except there is no video name, so it is a list of these structures rather than a dictionary. 57 | 58 | ## Downloading and Processing TAP-Vid-Kinetics 59 | 60 | The labels are contained in a csv file with download link: [TAP-Vid-Kinetics](https://storage.googleapis.com/dm-tapnet/tapvid_kinetics.zip). 61 | 62 | The videos are expected as the raw clips from the [Kinetics700-2020](https://github.com/cvdfoundation/kinetics-dataset) validation set and stored in a local folder `<video_root_path>`. The clips should be stored as MP4, following the name pattern `f'{youtube_id}_{start_time_sec:06}_{end_time_sec:06}.mp4'`, e.g. 'abcdefghijk_000010_000020.mp4'. 63 | 64 | Clips can be stored in any subfolder within the `<video_root_path>`. The most common pattern is to store it as `<video_root_path>/<label>/<clip_name>.mp4`. 65 | 66 | Once the validation clips have been downloaded, a pickle file containing all the information can be generated using the provided script: 67 | 68 | ```bash 69 | pip3 install -r requirements.txt 70 | python3 generate_tapvid.py \ 71 | --input_csv_path=<path_to_input_csv> \ 72 | --output_base_path=<path_to_output_folder> \ 73 | --video_root_path=<video_root_path> \ 74 | --alsologtostderr 75 | ``` 76 | 77 | ## Downloading RoboTAP 78 | 79 | The data are contained in pickle files with download links: [RoboTAP](https://storage.googleapis.com/dm-tapnet/robotap/robotap.zip). 80 | 81 | RoboTAP follows the same annotation format as TAP-Vid-DAVIS and TAP-Vid-RGB-stacking. The pickle file contains a dictionary, where each key is a video name, and the values are the frames (4D uint8 tensor), the points (float32 tensor with 3 axes; the first is point id, the second is time, and the third is x/y), and the occlusions (bool tensor with 2 axes; the first is point id, the second is time). 82 | 83 | ## Visualizing TAP-Vid and RoboTAP Dataset 84 | 85 | We also provide a script generating an MP4 with the points painted on top of the frames. The script will work with any of the pickle files. A random clip is chosen from all the available ones and all the point tracks are painted. 86 | 87 | ```bash 88 | pip3 install -r requirements.txt 89 | python3 visualize.py \ 90 | --input_path=<path_to_pickle_file> \ 91 | --output_path=<path_to_output_video> \ 92 | --alsologtostderr 93 | ``` 94 | 95 | ## Exemplar Visualization 96 | For visualization examples, we have the full [TAP-Vid-DAVIS](https://storage.googleapis.com/dm-tapnet/content/davis_ground_truth_v2.html) as well as 10 examples from the synthetic [TAP-Vid-Kubric](https://storage.googleapis.com/dm-tapnet/content/kubric_ground_truth.html) and [TAP-Vid-RGB-Stacking](https://storage.googleapis.com/dm-tapnet/content/rgb_stacking_ground_truth_v2.html) datasets. 97 | 98 | ## Training with TAP-Vid-Kubric 99 | TAP-Vid-DAVIS, TAP-Vid-RGB-stacking and TAP-Vid-Kinetics are mainly used for evaluation purposes. To train the model, we use [TAP-Vid-Kubric](https://github.com/google-research/kubric/tree/main/challenges/point_tracking).
100 | -------------------------------------------------------------------------------- /tapnet/tapvid3d/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Tracking Any Point in 3D (TAPVid-3D) 3 | 4 | [Skanda Koppula](https://skoppula.com/), [Ignacio Rocco](https://www.irocco.info/), [Yi Yang](https://yangyi02.github.io/), [Joe Heyward](https://uk.linkedin.com/in/joe-heyward-71623595), [João Carreira](https://uk.linkedin.com/in/jo%C3%A3o-carreira-56238a7), [Andrew Zisserman](https://www.robots.ox.ac.uk/~az/), [Gabriel Brostow](http://www0.cs.ucl.ac.uk/staff/G.Brostow/), [Carl Doersch](http://www.carldoersch.com/) 5 | 6 | **[Google DeepMind](https://deepmind.google/)**, **[University College London](http://vis.cs.ucl.ac.uk/home/), [University of Oxford](https://www.robots.ox.ac.uk/~vgg/)** 7 | 8 | ### [`TAPVid-3D Website`](https://tapvid3d.github.io/) [`TAPVid-3D Paper`](https://arxiv.org/abs/2407.05921) [`Colab to Visualize Samples`](https://colab.research.google.com/github/google-deepmind/tapnet/blob/main/tapnet/tapvid3d/colab/load_and_visualize_tapvid3d_samples.ipynb) 9 | 10 | TAPVid-3D is a dataset and benchmark for evaluating the task of long-range 11 | Tracking Any Point in 3D (TAP-3D). 12 | 13 | The benchmark features 4,000+ real-world videos, along with their metric 3D 14 | position point trajectories and camera extrinsics. The dataset contains three 15 | different video sources, and spans a variety of object types, motion patterns, 16 | and indoor and outdoor environments. This repository folder contains the code to 17 | download and generate these annotations, along with dataset samples to view. 18 | 19 | **Note that in order to use the dataset, you must accept the licenses of the 20 | constituent original video data sources that you use!** In particular, you must 21 | adhere to the terms of service outlined in: 22 | 23 | 1. [Aria Digital Twin](https://www.projectaria.com/datasets/adt/license/) 24 | 2. [Waymo Open Dataset](https://waymo.com/open/terms/) 25 | 3. [Panoptic Studio](http://domedb.perception.cs.cmu.edu/) 26 | 27 | To measure performance on the TAP-3D task, we formulated metrics that extend the 28 | Jaccard-based metric used in 2D TAP to handle the complexities of ambiguous 29 | depth scales across models, occlusions, and multi-track spatio-temporal 30 | smoothness. Our implementation of these metrics (3D-AJ, APD, OA) can be found in 31 | `tapvid3d_metrics.py`. 32 | 33 | For more details, including the performance of multiple baseline 3D point 34 | tracking models on the benchmark, please see the paper: 35 | [TAPVid-3D: A Benchmark for Tracking Any Point in 3D](https://arxiv.org/abs/2407.05921). 36 | 37 | ### Getting Started: Installing 38 | 39 | (If you want to generate the dataset and want to use CUDA for running semantic segmentation, first install a CUDA-enabled PyTorch with `pip3 install torch==2.3.0 torchvision==0.18.0 --index-url https://download.pytorch.org/whl/cu118`) 40 | 41 | For generating the dataset, install the Tracking Any Point repo with: 42 | 43 | `pip install "git+https://github.com/google-deepmind/tapnet.git"[tapvid3d_eval,tapvid3d_generation]` 44 | 45 | If you only want to use the metrics, install with: 46 | 47 | `pip install "git+https://github.com/google-deepmind/tapnet.git"[tapvid3d_eval]` 48 | 49 | For a local editable installation, clone the repo and use: 50 | 51 | `pip install -e .[tapvid3d_eval,tapvid3d_generation]` 52 | 53 | or 54 | 55 | `pip install -e .[tapvid3d_eval]`.
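Once the benchmark `*.npz` files have been generated (next section), each one can be read back with plain NumPy. The sketch below follows the field list in the Data format section further down; the file path is illustrative, and note that the track key is written as `tracks_XYZ` by the ADT generator earlier in this repository while the README calls it `tracks_xyz`, so check your files for the exact spelling.

```python
import io
import numpy as np
from PIL import Image

# Illustrative path; any generated TAPVid-3D example works the same way.
ex = np.load('tapvid3d_dataset/adt/example_chunk_0.npz', allow_pickle=True)

# Decode the per-frame JPEG bytes (any JPEG decoder works).
frames = np.stack(
    [np.array(Image.open(io.BytesIO(b))) for b in ex['images_jpeg_bytes']]
)                                  # [num_frames, height, width, 3]

tracks_xyz = ex['tracks_XYZ']      # [num_frames, num_tracks, 3], metric 3D positions
visibility = ex['visibility']      # [num_frames, num_tracks] visibility flags
queries_xyt = ex['queries_xyt']    # [num_tracks, 3], (x, y, t) query points
# Camera intrinsics; the field name differs between the README ('intrinsics')
# and the ADT generator ('fx_fy_cx_cy').
key = 'fx_fy_cx_cy' if 'fx_fy_cx_cy' in ex else 'intrinsics'
fx, fy, cx, cy = ex[key]
# Waymo/ADT files additionally carry a 4x4 world-to-camera 'extrinsics_w2c';
# it is absent for Panoptic Studio, where the camera is static.
```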
56 | 57 | ### How to Download and Generate the Dataset 58 | 59 | Scripts to download and generate the annotations of each of the datasets can be 60 | found in the `annotation_generation` subdirectory, and can be run as: 61 | 62 | ``` 63 | python3 -m tapnet.tapvid3d.annotation_generation.generate_adt --help 64 | python3 -m tapnet.tapvid3d.annotation_generation.generate_pstudio --help 65 | python3 -m tapnet.tapvid3d.annotation_generation.generate_drivetrack --help 66 | ``` 67 | 68 | Because of license restrictions in distribution of the underlying source videos 69 | for Aria Digital Twin, you will need to accept their licence terms and download 70 | the ADT dataset by first getting the cdn json file from 71 | [Project Aria Explorer](https://explorer.projectaria.com/?v=%22Aria+Digital+Twin%22), 72 | and downloading the ADT `main_vrs`, `main_groundtruth`, `segmentation` and `depth` files with: 73 | 74 | `aria_dataset_downloader --cdn_file /PATH_TO/Aria_Digital_Twin_1720774989.json -o /OUTPUT_PATH -d 0 6 7 8` 75 | 76 | To run all generation scripts, follow the instructions and run the commands in 77 | `generate_all.sh`. This will generate all the `*.npz` files, and place them into 78 | a new folder, `tapvid3d_dataset`. To generate only the `minival` split, 79 | replace `--split=all` in this script with `--split=minival`. 80 | 81 | To test that things are working before full generation, you can run `./run_all.sh 82 | --debug`. This generates exactly one `*.npz`/video annotation per data 83 | source. 84 | 85 | ### Data format 86 | 87 | Once the benchmark files are fully generated, you will have roughly 4,500 88 | `*.npz` files, each one with exactly one dataset video+annotation. Each `*.npz` 89 | file contains: 90 | 91 | * `images_jpeg_bytes`: tensor of shape [`# of frames`, `height`, `width`, 3], 92 | each frame stored as JPEG bytes that must be decoded 93 | 94 | * `intrinsics`: (fx, fy, cx, cy) camera intrinsics of the video 95 | 96 | * `tracks_xyz`: tensor of shape (`# of frames`, `# of point tracks`, 3), 97 | representing the 3D point trajectories; the last dimension is the `(x, y, 98 | z)` point position in meters. 99 | 100 | * `visibility`: tensor of shape (`# of frames`, `# of point tracks`), 101 | representing the visibility of each point along its trajectory 102 | 103 | * `queries_xyt`: tensor of shape (`# of point tracks`, 3), representing the 104 | query point used in the benchmark as the initial given point to track. The 105 | last dimension is given in `(x, y, t)`, where `x,y` is the pixel location of 106 | the query point and `t` is the query frame. 107 | 108 | * `extrinsics_w2c`: for videos with a moving camera (videos from the Waymo 109 | Open Dataset and ADT), we provide camera extrinsics in the form of a world 110 | to camera transform, a 4x4 matrix consisting of the rotation matrix and 111 | translation matrix. This field is NOT present in the `pstudio` *.npz files, 112 | because in Panoptic Studio, the camera is static. 113 | 114 | ### Getting Started: Evaluating Your Own 3D Tracking Model 115 | 116 | To get started using this dataset to benchmark your own 3D tracking model, check 117 | out `evaluate_model.py`. This script reads the ground-truth tracks and the 118 | predicted tracks, and computes the TAPVid-3D evaluation metrics for those 119 | predictions.
An example of running this is provided in `run_evaluate_model.sh`, 120 | which computes metrics for our precomputed outputs from 121 | [SpatialTracker](https://henry123-boy.github.io/SpaTracker/) on the `minival` 122 | split, for only the `DriveTrack` subset. `evaluate_model.py` 123 | takes as input a folder with subdirectories (corresponding to up to 3 data 124 | sources [among PStudio/DriveTrack/ADT]). Each 125 | subdirectory should contain a list of `*.npz` files 126 | (one per TAPVid-3D video, each containing 127 | `tracks_XYZ` and `visibility` in the format above). 128 | Running this will return all TAP-3D metrics 129 | described in the paper (these are implemented in `tapvid3d_metrics.py`). 130 | 131 | ### Visualizing Samples in Colab 132 | 133 | You can view samples of the dataset, using a public Colab demo: 134 | TAPVid-3D Colab Visualization. 135 | 136 | ## Citing this Work 137 | 138 | If you find this work useful, consider citing the manuscript: 139 | 140 | ``` 141 | @inproceedings{koppula2024tapvid3d, 142 | title={TAPVid-3D: A Benchmark for Tracking Any Point in 3D}, 143 | author={Skanda Koppula and Ignacio Rocco and Yi Yang and Joe Heyward and João Carreira and Andrew Zisserman and Gabriel Brostow and Carl Doersch}, 144 | booktitle={NeurIPS}, 145 | year={2024}, 146 | } 147 | ``` 148 | 149 | ## License and Disclaimer 150 | 151 | Copyright 2024 Google LLC 152 | 153 | All software here is licensed under the Apache License, Version 2.0 (Apache 154 | 2.0); you may not use this file except in compliance with the Apache 2.0 155 | license. You may obtain a copy of the Apache 2.0 license at: 156 | 157 | https://www.apache.org/licenses/LICENSE-2.0 158 | 159 | All other materials, excepting the source videos and artifacts used to generate 160 | annotations, are licensed under the Creative Commons Attribution 4.0 161 | International License (CC-BY). You may obtain a copy of the CC-BY license at: 162 | https://creativecommons.org/licenses/by/4.0/legalcode 163 | 164 | The Aria Digital Twin, Waymo Open Dataset, and Panoptic Studio datasets all 165 | have individual usage terms and licenses that must be adhered to when using any 166 | software or materials from here. 167 | 168 | Unless required by applicable law or agreed to in writing, all software and 169 | materials distributed here under the Apache 2.0 or CC-BY licenses are 170 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 171 | either express or implied. See the licenses for the specific language governing 172 | permissions and limitations under those licenses. 173 | 174 | This is not an official Google product. 175 | -------------------------------------------------------------------------------- /tapnet/models/tsm_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Utils functions for TSM.""" 17 | 18 | from typing import Tuple 19 | 20 | import chex 21 | import jax 22 | import jax.numpy as jnp 23 | 24 | 25 | def prepare_inputs(inputs: chex.Array) -> Tuple[jnp.ndarray, str, int]: 26 | """Deduces input mode for TSM.""" 27 | # Deduce if we run on TPU based on input shape. 28 | if len(inputs.shape) == 5: 29 | # Input is given in the standard [B, T, H, W, 3] format. 30 | tsm_mode = 'gpu' 31 | num_frames = inputs.shape[1] 32 | inputs = jnp.reshape(inputs, [-1] + list(inputs.shape[2:])) 33 | else: 34 | # Input is given in the [T * B, H, W, 3] format. 35 | tsm_mode = 'tpu' 36 | num_frames = None 37 | return inputs, tsm_mode, num_frames 38 | 39 | 40 | def prepare_outputs( 41 | outputs: chex.Array, 42 | tsm_mode: str, 43 | num_frames: int, 44 | reduce_mean: bool = True, 45 | ) -> jnp.ndarray: 46 | """Processes output of TSM to undo the merging of batch and time.""" 47 | # Get the shape without the batch/time dimension (for TSM batch and time are 48 | # merged in the first dimension). 49 | shape_no_bt = list(outputs.shape[1:]) 50 | if tsm_mode == 'tpu': 51 | # Outputs are of the shape [num_frames * B, ..., n_channels] 52 | outputs = jnp.reshape(outputs, [num_frames, -1] + shape_no_bt) 53 | if reduce_mean: 54 | # We average over time and space. 55 | outputs = jnp.mean( 56 | outputs, axis=[0] + list(range(2, 57 | len(shape_no_bt) + 1))) 58 | else: 59 | outputs = jnp.transpose( 60 | outputs, axes=[1, 0] + list(range(2, 61 | len(shape_no_bt) + 2))) 62 | elif tsm_mode == 'gpu': 63 | # Outputs are of the shape [B * num_frames, ..., n_channels]. 64 | outputs = jnp.reshape(outputs, [-1, num_frames] + shape_no_bt) 65 | if reduce_mean: 66 | outputs = jnp.mean( 67 | outputs, axis=[1] + list(range(2, 68 | len(shape_no_bt) + 1))) 69 | elif tsm_mode.startswith('deflated'): 70 | # In deflated mode, outputs are already in the right format. 71 | pass 72 | else: 73 | raise ValueError('`tsm_mode` should be \'tpu\' or \'gpu\' or ' 74 | f'\'deflated_0.x\' ({tsm_mode} given)') 75 | return outputs # pytype: disable=bad-return-type # numpy-scalars 76 | 77 | 78 | def apply_temporal_shift( 79 | x: chex.Array, 80 | tsm_mode: str, 81 | num_frames: int, 82 | channel_shift_fraction: float = 0.125, 83 | ) -> jnp.ndarray: 84 | """Performs a temporal shift: https://arxiv.org/abs/1811.08383 with mode.""" 85 | if tsm_mode == 'tpu': 86 | outputs = temporal_shift_tpu(x, num_frames, channel_shift_fraction) 87 | elif tsm_mode == 'gpu': 88 | outputs = temporal_shift_gpu(x, num_frames, channel_shift_fraction) 89 | elif tsm_mode.startswith('deflated'): 90 | alpha = float(tsm_mode.split('_')[1]) 91 | outputs = temporal_shift_image_mode(x, channel_shift_fraction, alpha) 92 | else: 93 | raise ValueError('`tsm_mode` should be \'tpu\' or \'gpu\' or ' 94 | f'\'deflated_0.x\' ({tsm_mode} given)') 95 | return outputs 96 | 97 | 98 | def temporal_shift_image_mode(x, channel_shift_fraction=0.125, alpha=0.3): 99 | """Temporal shift applied on single image (to emulate a fixed video).""" 100 | # B, H, W, C = batch_size, im_height, im_width, channels. 101 | # Input is (B, H, W, C). 102 | orig_shp = tuple(x.shape) 103 | n_channels = orig_shp[-1] 104 | n_shift = int(n_channels * channel_shift_fraction) 105 | # Alpha emulates the effect of the padding when using a single frame. 
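  # (With a real video these shifted channels would come from the adjacent
  # frames; for a single image they are approximated by alpha-scaled copies of
  # the current frame's channels.)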
106 | shifted_backward = alpha * x[:, :, :, -n_shift:] 107 | shifted_forward = alpha * x[:, :, :, :n_shift] 108 | no_shift = x[:, :, :, n_shift:-n_shift] 109 | shifted_x = jnp.concatenate([shifted_backward, no_shift, shifted_forward], 110 | axis=3) 111 | return shifted_x 112 | 113 | 114 | def temporal_shift_gpu( 115 | x: chex.Array, 116 | num_frames: int, 117 | channel_shift_fraction: float = 0.125, 118 | ) -> jnp.ndarray: 119 | """Performs a temporal shift: https://arxiv.org/abs/1811.08383.""" 120 | # B, T, H, W, C = batch_size, num_frames, im_height, im_width, channels. 121 | # Input is (B * T, H, W, C). 122 | orig_shp = tuple(x.shape) 123 | reshaped_x = jnp.reshape(x, (-1, num_frames) + orig_shp[1:]) 124 | n_channels = orig_shp[-1] 125 | n_shift = int(n_channels * channel_shift_fraction) 126 | 127 | new_shp = tuple(reshaped_x.shape) 128 | 129 | # shifted_backward = reshaped_x[:, 1:, :, :, -n_shift:]. 130 | shifted_backward = jax.lax.slice( 131 | reshaped_x, (0, 1, 0, 0, new_shp[4] - n_shift), 132 | (new_shp[0], new_shp[1], new_shp[2], new_shp[3], new_shp[4])) 133 | shifted_backward_padding = ((0, 0), (0, 1), (0, 0), (0, 0), (0, 0)) 134 | shifted_backward = jnp.pad(shifted_backward, shifted_backward_padding) 135 | 136 | # shifted_forward = reshaped_x[:, :-1, :, :, :n_shift]. 137 | shifted_forward = jax.lax.slice( 138 | reshaped_x, (0, 0, 0, 0, 0), 139 | (new_shp[0], new_shp[1] - 1, new_shp[2], new_shp[3], n_shift)) 140 | shifted_forward_padding = ((0, 0), (1, 0), (0, 0), (0, 0), (0, 0)) 141 | shifted_forward = jnp.pad(shifted_forward, shifted_forward_padding) 142 | 143 | no_shift = reshaped_x[:, :, :, :, n_shift:-n_shift] 144 | shifted_x = jnp.concatenate([shifted_backward, no_shift, shifted_forward], 145 | axis=4) 146 | return jnp.reshape(shifted_x, (-1,) + orig_shp[1:]) 147 | 148 | 149 | def temporal_shift_tpu( 150 | x: chex.Array, 151 | num_frames: int, 152 | channel_shift_fraction: float = 0.125, 153 | ) -> jnp.ndarray: 154 | """Performs a temporal shift: https://arxiv.org/abs/1811.08383. 155 | 156 | TPU optimized version of TSM. Reshape is avoided by having the images 157 | reshaped in [T * B, :] so that frames corresponding to same time frame in 158 | videos are contiguous in memory. Finally, to avoid concatenate that prevent 159 | some fusion from happening we simply sum masked version of the features. 160 | Args: 161 | x: Input expected to be [T * B, H, W, C] (where the batch has been reshaped 162 | from a time major version of the input). 163 | num_frames: number of frames T per video. 164 | channel_shift_fraction: fraction of the channel to shift forward and 165 | backward. 166 | 167 | Returns: 168 | The temporal shifted version of x. 169 | """ 170 | # B, T, H, W, C = batch_size, num_frames, im_height, im_width, channels. 171 | # Input is (T * B, H, W, C). 172 | original_dtype = x.dtype 173 | original_shape = list(x.shape) 174 | 175 | batch_size = int(original_shape[0] / num_frames) 176 | n_channels = int(original_shape[-1]) 177 | n_shift = int(n_channels * channel_shift_fraction) 178 | 179 | # Cast to bfloat16. 180 | x = x.astype(jnp.bfloat16) 181 | 182 | # For the following, assume that x has 3 channels [x1, x2, x3] and n_shift=1. 183 | # Shift backward, we first pad by zeros [x1, x2, x3, 0, 0]. 
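  # The zero padding below (along the time-major batch axis and the channel
  # axis) lets a single `jax.lax.slice` yield a full-width tensor whose only
  # non-zero channels are the ones borrowed from frame t+1 (and from t-1 for
  # the forward shift), so the three pieces can later be summed rather than
  # concatenated.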
184 | orig_shp = list(x.shape) 185 | 186 | shifted_backward_padding = ((0, batch_size, 0), (0, 0, 0), (0, 0, 0), 187 | (0, n_channels - n_shift, 0)) 188 | x_backward_padding = jax.lax.pad( 189 | x, 190 | padding_value=jnp.bfloat16(0.), 191 | padding_config=shifted_backward_padding) 192 | # The following shift gets to [x3^+1, 0, 0] (where +1 means from the future). 193 | shifted_backward = jax.lax.slice(x_backward_padding, 194 | (batch_size, 0, 0, n_channels - n_shift), 195 | (orig_shp[0] + batch_size, orig_shp[1], 196 | orig_shp[2], 2 * n_channels - n_shift)) 197 | # Shift forward, we first pad by zeros [0, 0, x1, x2, x3]. 198 | shifted_forward_padding = ((batch_size, 0, 0), (0, 0, 0), (0, 0, 0), 199 | (n_channels - n_shift, 0, 0)) 200 | x_forward_padding = jax.lax.pad( 201 | x, padding_value=jnp.bfloat16(0.), padding_config=shifted_forward_padding) 202 | # The following shift gets to [0, 0, x1^-1] (where -1 means from the past). 203 | shifted_forward = jax.lax.slice( 204 | x_forward_padding, (0, 0, 0, 0), 205 | (orig_shp[0], orig_shp[1], orig_shp[2], n_channels)) 206 | # No shift is in the middle, this gets [0, x2, 0]. 207 | mask_noshift = (jnp.reshape((jnp.arange(n_channels) >= n_shift) & 208 | (jnp.arange(n_channels) < n_channels - n_shift), 209 | (1, 1, 1, -1))).astype(jnp.bfloat16) 210 | no_shift = mask_noshift * x 211 | # By summing everything together, we end up with [x3^+1, x2, x1^-1]. 212 | # Note: channels have been reordered but that doesn't matter for the model. 213 | shifted_x = shifted_backward + shifted_forward + no_shift 214 | 215 | return shifted_x.astype(original_dtype) 216 | -------------------------------------------------------------------------------- /tapnet/utils/experiment_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Logging and other experiment utilities.""" 17 | 18 | import os 19 | from typing import Mapping, Optional 20 | 21 | import jax 22 | from jaxline import utils 23 | from ml_collections import config_dict 24 | import numpy as np 25 | import optax 26 | import tensorflow as tf 27 | 28 | from tapnet.utils import optimizers 29 | 30 | 31 | def get_lr_schedule( 32 | total_steps: int, 33 | optimizer_config: config_dict.ConfigDict, 34 | ) -> optax.Schedule: 35 | """Build the LR schedule function.""" 36 | base_lr = optimizer_config.base_lr 37 | 38 | schedule_type = optimizer_config.schedule_type 39 | if schedule_type == 'cosine': 40 | warmup_steps = (optimizer_config.cosine_decay_kwargs.warmup_steps) 41 | # Batch scale the other lr values as well: 42 | init_value = optimizer_config.cosine_decay_kwargs.init_value 43 | end_value = optimizer_config.cosine_decay_kwargs.end_value 44 | 45 | schedule_fn = optax.warmup_cosine_decay_schedule( 46 | init_value=init_value, 47 | peak_value=base_lr, 48 | warmup_steps=warmup_steps, 49 | decay_steps=total_steps, 50 | end_value=end_value) 51 | elif schedule_type == 'constant_cosine': 52 | # Convert end_value to alpha, used by cosine_decay_schedule. 53 | alpha = optimizer_config.constant_cosine_decay_kwargs.end_value / base_lr 54 | 55 | # Number of steps spent in constant phase. 56 | constant_steps = int( 57 | optimizer_config.constant_cosine_decay_kwargs.constant_fraction * 58 | total_steps) 59 | decay_steps = total_steps - constant_steps 60 | 61 | constant_phase = optax.constant_schedule(value=base_lr) 62 | decay_phase = optax.cosine_decay_schedule( 63 | init_value=base_lr, decay_steps=decay_steps, alpha=alpha) 64 | schedule_fn = optax.join_schedules( 65 | schedules=[constant_phase, decay_phase], boundaries=[constant_steps]) 66 | else: 67 | raise ValueError(f'Unknown learning rate schedule: {schedule_type}') 68 | 69 | return schedule_fn 70 | 71 | 72 | def make_optimizer( 73 | optimizer_config: config_dict.ConfigDict, 74 | lr_schedule: optax.Schedule, 75 | ) -> optax.GradientTransformation: 76 | """Construct the optax optimizer with given LR schedule.""" 77 | # Decay learned position embeddings by default. 
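  # Only parameters named 'b' (bias parameters) are excluded from weight decay
  # below; everything else, including learned position embeddings, is decayed.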
78 | weight_decay_exclude_names = ['b'] 79 | 80 | optax_chain = [] 81 | if optimizer_config.max_norm > 0: 82 | optax_chain.append(optax.clip_by_global_norm(optimizer_config.max_norm)) 83 | 84 | if optimizer_config.optimizer == 'sgd': 85 | optax_chain.extend([ 86 | optax.trace(**optimizer_config.sgd_kwargs), 87 | optimizers.add_weight_decay( 88 | optimizer_config.weight_decay, 89 | exclude_names=weight_decay_exclude_names) 90 | ]) 91 | elif optimizer_config.optimizer == 'adam': 92 | optax_chain.extend([ 93 | optax.scale_by_adam(**optimizer_config.adam_kwargs), 94 | optimizers.add_weight_decay( 95 | optimizer_config.weight_decay, 96 | exclude_names=weight_decay_exclude_names) 97 | ]) 98 | else: 99 | raise ValueError(f'Undefined optimizer {optimizer_config.optimizer}') 100 | optax_chain.extend([ 101 | optax.scale_by_schedule(lr_schedule), 102 | optax.scale(-1), 103 | ]) 104 | 105 | optimizer = optax.chain(*optax_chain) 106 | optimizer = optax.apply_if_finite(optimizer, max_consecutive_errors=5) 107 | return optimizer 108 | 109 | 110 | class NumpyFileCheckpointer(utils.Checkpointer): 111 | """A Jaxline checkpointer which saves to numpy files on disk.""" 112 | 113 | def __init__(self, config: config_dict.ConfigDict, mode: str): 114 | self._checkpoint_file = os.path.join(config.checkpoint_dir, 115 | 'checkpoint.npy') 116 | self._checkpoint_state = config_dict.ConfigDict() 117 | del mode 118 | 119 | def get_experiment_state(self, ckpt_series: str) -> config_dict.ConfigDict: 120 | """Returns the experiment state for a given checkpoint series.""" 121 | if ckpt_series != 'latest': 122 | raise ValueError('multiple checkpoint series are not supported') 123 | return self._checkpoint_state 124 | 125 | def save(self, ckpt_series: str) -> None: 126 | """Saves the checkpoint.""" 127 | if ckpt_series != 'latest': 128 | raise ValueError('multiple checkpoint series are not supported') 129 | exp_mod = self._checkpoint_state.experiment_module 130 | global_step = self._checkpoint_state.global_step 131 | f_np = lambda x: np.array(jax.device_get(utils.get_first(x))) 132 | to_save = {} 133 | for attr, name in exp_mod.CHECKPOINT_ATTRS.items(): 134 | if name == 'global_step': 135 | raise ValueError( 136 | 'global_step attribute would overwrite jaxline global step') 137 | np_params = jax.tree_util.tree_map(f_np, getattr(exp_mod, attr)) 138 | to_save[name] = np_params 139 | to_save['global_step'] = global_step 140 | 141 | with tf.io.gfile.GFile(self._checkpoint_file + '_tmp', 'wb') as fp: 142 | np.save(fp, to_save) 143 | tf.io.gfile.rename( 144 | self._checkpoint_file + '_tmp', 145 | self._checkpoint_file, 146 | overwrite=True, 147 | ) 148 | 149 | def can_be_restored(self, ckpt_series: str) -> bool: 150 | """Returns whether or not a given checkpoint series can be restored.""" 151 | if ckpt_series != 'latest': 152 | raise ValueError('multiple checkpoint series are not supported') 153 | return tf.io.gfile.exists(self._checkpoint_file) 154 | 155 | def restore(self, ckpt_series: str) -> None: 156 | """Restores the checkpoint.""" 157 | experiment_state = self.get_experiment_state(ckpt_series) 158 | with tf.io.gfile.GFile(self._checkpoint_file, 'rb') as fp: 159 | ckpt_state = np.load(fp, allow_pickle=True).item() 160 | experiment_state.global_step = int(ckpt_state['global_step']) 161 | exp_mod = experiment_state.experiment_module 162 | for attr, name in exp_mod.CHECKPOINT_ATTRS.items(): 163 | setattr(exp_mod, attr, utils.bcast_local_devices(ckpt_state[name])) 164 | 165 | def restore_path(self, ckpt_series: str) -> 
Optional[str]: 166 | """Returns the restore path for the checkpoint, or None.""" 167 | if not self.can_be_restored(ckpt_series): 168 | return None 169 | return self._checkpoint_file 170 | 171 | def wait_for_checkpointing_to_finish(self) -> None: 172 | """Waits for any async checkpointing to complete.""" 173 | 174 | @classmethod 175 | def create( 176 | cls, 177 | config: config_dict.ConfigDict, 178 | mode: str, 179 | ) -> utils.Checkpointer: 180 | return cls(config, mode) 181 | 182 | 183 | def default_color_augmentation_fn( 184 | inputs: Mapping[str, tf.Tensor], 185 | zero_centering_image: bool = True, 186 | prob_color_augment: float = 0.8, 187 | prob_color_drop: float = 0.2, 188 | ) -> Mapping[str, tf.Tensor]: 189 | """Standard color augmentation for videos. 190 | 191 | Args: 192 | inputs: A DatasetElement containing the item 'video' which will have 193 | augmentations applied to it. 194 | zero_centering_image: Whether to zero center the image. 195 | prob_color_augment: Probability of applying color augmentation. 196 | prob_color_drop: Probability of applying color drop. 197 | 198 | Returns: 199 | A DatasetElement with all the same data as the original, except that 200 | the video has augmentations applied. 201 | """ 202 | frames = inputs['video'] 203 | if frames.dtype != tf.float32: 204 | raise ValueError('`frames` should be in float32.') 205 | 206 | def color_augment(video: tf.Tensor) -> tf.Tensor: 207 | """Do standard color augmentations.""" 208 | # Note the same augmentation will be applied to all frames of the video. 209 | if zero_centering_image: 210 | video = 0.5 * (video + 1.0) 211 | video = tf.image.random_brightness(video, max_delta=32. / 255.) 212 | video = tf.image.random_saturation(video, lower=0.6, upper=1.4) 213 | video = tf.image.random_contrast(video, lower=0.6, upper=1.4) 214 | video = tf.image.random_hue(video, max_delta=0.2) 215 | video = tf.clip_by_value(video, 0.0, 1.0) 216 | if zero_centering_image: 217 | video = 2 * (video-0.5) 218 | return video 219 | 220 | def color_drop(video: tf.Tensor) -> tf.Tensor: 221 | video = tf.image.rgb_to_grayscale(video) 222 | video = tf.tile(video, [1, 1, 1, 3]) 223 | return video 224 | 225 | # Eventually applies color augmentation. 226 | coin_toss_color_augment = tf.random.uniform( 227 | [], minval=0, maxval=1, dtype=tf.float32) 228 | frames = tf.cond( 229 | pred=tf.less(coin_toss_color_augment, 230 | tf.cast(prob_color_augment, tf.float32)), 231 | true_fn=lambda: color_augment(frames), 232 | false_fn=lambda: frames) 233 | 234 | # Eventually applies color drop. 235 | coin_toss_color_drop = tf.random.uniform( 236 | [], minval=0, maxval=1, dtype=tf.float32) 237 | frames = tf.cond( 238 | pred=tf.less(coin_toss_color_drop, tf.cast(prob_color_drop, tf.float32)), 239 | true_fn=lambda: color_drop(frames), 240 | false_fn=lambda: frames) 241 | result = {**inputs} 242 | result['video'] = frames 243 | 244 | return result 245 | 246 | 247 | def add_default_data_augmentation(ds: tf.data.Dataset) -> tf.data.Dataset: 248 | return ds.map( 249 | default_color_augmentation_fn, num_parallel_calls=tf.data.AUTOTUNE) 250 | -------------------------------------------------------------------------------- /tapnet/models/tapnet_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """TAP-Net model definition.""" 17 | 18 | import functools 19 | from typing import Optional, Mapping, Tuple 20 | 21 | import chex 22 | from einshape import jax_einshape as einshape 23 | import haiku as hk 24 | import jax 25 | import jax.numpy as jnp 26 | 27 | from tapnet.models import tsm_resnet 28 | from tapnet.utils import model_utils 29 | from tapnet.utils import transforms 30 | 31 | 32 | def create_batch_norm( 33 | x: chex.Array, is_training: bool, cross_replica_axis: Optional[str] 34 | ) -> chex.Array: 35 | """Function to allow TSM-ResNet to create batch norm layers.""" 36 | return hk.BatchNorm( 37 | create_scale=True, 38 | create_offset=True, 39 | decay_rate=0.9, 40 | cross_replica_axis=cross_replica_axis, 41 | )(x, is_training) 42 | 43 | 44 | class TAPNet(hk.Module): 45 | """Joint model for performing flow-based tasks.""" 46 | 47 | def __init__( 48 | self, 49 | feature_grid_stride: int = 8, 50 | num_heads: int = 1, 51 | cross_replica_axis: Optional[str] = 'i', 52 | num_frames: int = 24, 53 | ): 54 | """Initialize the model and provide kwargs for the various components. 55 | 56 | Args: 57 | feature_grid_stride: Stride to extract features. For TSM-ResNet, 58 | supported values are 8 (default), 16, and 32. 59 | num_heads: Number of heads in the cost volume. 60 | cross_replica_axis: Which cross replica axis to use for the batch norm. 61 | num_frames: Number of frames passed into TSM-ResNet. 62 | """ 63 | 64 | super().__init__() 65 | 66 | self.feature_grid_stride = feature_grid_stride 67 | self.num_heads = num_heads 68 | self.softmax_temperature = 10.0 69 | 70 | self.tsm_resnet = tsm_resnet.TSMResNetV2( 71 | normalize_fn=functools.partial( 72 | create_batch_norm, cross_replica_axis=cross_replica_axis 73 | ), 74 | num_frames=num_frames, 75 | channel_shift_fraction=[0.125, 0.125, 0.0, 0.0], 76 | name='tsm_resnet_video', 77 | ) 78 | 79 | self.cost_volume_track_mods = { 80 | 'hid1': 81 | hk.Conv3D( 82 | 16, 83 | [1, 3, 3], 84 | name='cost_volume_regression_1', 85 | stride=[1, 1, 1], 86 | ), 87 | 'hid2': 88 | hk.Conv3D( 89 | 1, 90 | [1, 3, 3], 91 | name='cost_volume_regression_2', 92 | stride=[1, 1, 1], 93 | ), 94 | 'hid3': 95 | hk.Conv3D( 96 | 32, 97 | [1, 3, 3], 98 | name='cost_volume_occlusion_1', 99 | stride=[1, 2, 2], 100 | ), 101 | 'hid4': 102 | hk.Linear(16, name='cost_volume_occlusion_2'), 103 | 'occ_out': 104 | hk.Linear(1, name='occlusion_out'), 105 | 'regression_hid': 106 | hk.Linear(128, name='regression_hid'), 107 | 'regression_out': 108 | hk.Linear(2, name='regression_out'), 109 | } 110 | 111 | def tracks_from_cost_volume( 112 | self, 113 | interp_feature_heads: chex.Array, 114 | feature_grid_heads: chex.Array, 115 | query_points: Optional[chex.Array], 116 | im_shp: Optional[chex.Shape] = None, 117 | ) -> Tuple[chex.Array, chex.Array]: 118 | """Converts features into tracks by computing a cost volume. 
119 | 120 | The computed cost volume will have shape 121 | [batch, num_queries, time, height, width, num_heads], which can be very 122 | memory intensive. 123 | 124 | Args: 125 | interp_feature_heads: A tensor of features for each query point, of shape 126 | [batch, num_queries, channels, heads]. 127 | feature_grid_heads: A tensor of features for the video, of shape [batch, 128 | time, height, width, channels, heads]. 129 | query_points: When computing tracks, we assume these points are given as 130 | ground truth and we reproduce them exactly. This is a set of points of 131 | shape [batch, num_points, 3], where each entry is [t, y, x] in frame/ 132 | raster coordinates. 133 | im_shp: The shape of the original image, i.e., [batch, num_frames, time, 134 | height, width, 3]. 135 | 136 | Returns: 137 | A 2-tuple of the inferred points (of shape 138 | [batch, num_points, num_frames, 2] where each point is [x, y]) and 139 | inferred occlusion (of shape [batch, num_points, num_frames], where 140 | each is a logit where higher means occluded) 141 | """ 142 | 143 | mods = self.cost_volume_track_mods 144 | # Note: time is first axis to prevent the TPU from padding 145 | cost_volume = jnp.einsum( 146 | 'bncd,bthwcd->tbnhwd', 147 | interp_feature_heads, 148 | feature_grid_heads, 149 | ) 150 | shape = cost_volume.shape 151 | cost_volume = einshape('tbnhwd->t(bn)hwd', cost_volume) 152 | 153 | occlusion = mods['hid1'](cost_volume) 154 | occlusion = jax.nn.relu(occlusion) 155 | 156 | pos = mods['hid2'](occlusion) 157 | pos = jax.nn.softmax(pos * self.softmax_temperature, axis=(-2, -3)) 158 | pos = einshape('t(bn)hw1->bnthw', pos, n=shape[2]) 159 | points = model_utils.heatmaps_to_points( 160 | pos, im_shp, query_points=query_points 161 | ) 162 | 163 | occlusion = mods['hid3'](occlusion) 164 | occlusion = jnp.mean(occlusion, axis=(-2, -3)) 165 | occlusion = mods['hid4'](occlusion) 166 | occlusion = jax.nn.relu(occlusion) 167 | occlusion = mods['occ_out'](occlusion) 168 | occlusion = jnp.transpose(occlusion, (1, 0, 2)) 169 | assert occlusion.shape[1] == shape[0] 170 | occlusion = jnp.reshape(occlusion, (shape[1], shape[2], shape[0])) 171 | return points, occlusion 172 | 173 | def __call__( 174 | self, 175 | video: chex.Array, 176 | is_training: bool, 177 | query_points: chex.Array, 178 | compute_regression: bool = True, 179 | query_chunk_size: Optional[int] = None, 180 | get_query_feats: bool = False, 181 | feature_grid: Optional[chex.Array] = None, 182 | ) -> Mapping[str, chex.Array]: 183 | """Runs a forward pass of the model. 184 | 185 | Args: 186 | video: A 4-D or 5-D tensor representing a batch of sequences of images. In 187 | the 4-D case, we assume the entire batch has been concatenated along the 188 | batch dimension, one sequence after the other. This can speed up 189 | inference on the TPU and save memory. 190 | is_training: Whether we are training. 191 | query_points: The query points for which we compute tracks. 192 | compute_regression: if True, compute tracks using cost volumes; otherwise 193 | simply compute features (required for the baseline) 194 | query_chunk_size: When computing cost volumes, break the queries into 195 | chunks of this size to save memory. 196 | get_query_feats: If True, also return the features for each query obtained 197 | using bilinear interpolation from the feature grid 198 | feature_grid: If specified, use this as the feature grid rather than 199 | computing it from the pixels. 
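        Expected to match the shape of the returned `feature_grid`, i.e.
        [batch, num_frames, height // stride, width // stride, channels].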
200 | 201 | Returns: 202 | A dict of outputs, including: 203 | feature_grid: a TSM-ResNet feature grid of shape 204 | [batch, num_frames, height//stride, width//stride, channels] 205 | query_feats (optional): A feature for each query point, of size 206 | [batch, num_queries, channels] 207 | occlusion: Occlusion logits, of shape [batch, num_queries, num_frames] 208 | where higher indicates more likely to be occluded. 209 | tracks: predicted point locations, of shape 210 | [batch, num_queries, num_frames, 2], where each point is [x, y] 211 | in raster coordinates 212 | """ 213 | num_frames = None 214 | if feature_grid is None: 215 | latent = self.tsm_resnet( 216 | video, 217 | is_training=is_training, 218 | output_stride=self.feature_grid_stride, 219 | out_num_frames=num_frames, 220 | final_endpoint='tsm_resnet_unit_2', 221 | ) 222 | 223 | feature_grid = latent / jnp.sqrt( 224 | jnp.maximum( 225 | jnp.sum(jnp.square(latent), axis=-1, keepdims=True), 226 | 1e-12, 227 | )) 228 | 229 | shape = video.shape 230 | if num_frames is not None and len(shape) < 5: 231 | shape = (shape[0] // num_frames, num_frames) + shape[1:] 232 | 233 | # shape is [batch_size, time, height, width, channels]; conversion needs 234 | # [time, width, height] 235 | position_in_grid = transforms.convert_grid_coordinates( 236 | query_points, 237 | shape[1:4], 238 | feature_grid.shape[1:4], 239 | coordinate_format='tyx', 240 | ) 241 | interp_features = jax.vmap( 242 | jax.vmap( 243 | model_utils.interp, 244 | in_axes=(3, None), 245 | out_axes=1, 246 | ) 247 | )(feature_grid, position_in_grid) 248 | feature_grid_heads = einshape( 249 | 'bthw(cd)->bthwcd', feature_grid, d=self.num_heads 250 | ) 251 | interp_features_heads = einshape( 252 | 'bn(cd)->bncd', 253 | interp_features, 254 | d=self.num_heads, 255 | ) 256 | out = {'feature_grid': feature_grid} 257 | if get_query_feats: 258 | out['query_feats'] = interp_features 259 | 260 | if compute_regression: 261 | assert query_chunk_size is not None 262 | all_occ = [] 263 | all_pts = [] 264 | infer = functools.partial(self.tracks_from_cost_volume, im_shp=shape) 265 | 266 | for i in range(0, query_points.shape[1], query_chunk_size): 267 | points, occlusion = infer( 268 | interp_features_heads[:, i:i + query_chunk_size], 269 | feature_grid_heads, 270 | query_points[:, i:i + query_chunk_size], 271 | ) 272 | all_occ.append(occlusion) 273 | all_pts.append(points) 274 | occlusion = jnp.concatenate(all_occ, axis=1) 275 | points = jnp.concatenate(all_pts, axis=1) 276 | 277 | out['occlusion'] = occlusion 278 | out['tracks'] = points 279 | 280 | return out 281 | -------------------------------------------------------------------------------- /tapnet/tapnext/tapnext_torch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """TAPNext implementation in torch.""" 17 | 18 | import dataclasses 19 | from typing import List 20 | 21 | import einops 22 | import numpy as np 23 | from tapnet.tapnext import tapnext_lru_modules 24 | import torch 25 | from torch import nn 26 | from torch.nn import functional as F 27 | from torchvision.models import vision_transformer 28 | 29 | 30 | def posemb_sincos_2d(h, w, width, temperature=10_000.0, dtype=np.float32): 31 | """Follows the MoCo v3 logic.""" 32 | y, x = np.mgrid[:h, :w] 33 | 34 | assert width % 4 == 0, 'Width must be mult of 4 for sincos posemb' 35 | omega = np.arange(width // 4) / (width // 4 - 1) 36 | omega = 1.0 / (temperature**omega) 37 | y = np.einsum('m,d->md', y.flatten(), omega) 38 | x = np.einsum('m,d->md', x.flatten(), omega) 39 | pe = np.concatenate([np.sin(x), np.cos(x), np.sin(y), np.cos(y)], axis=1) 40 | return np.asarray(pe, dtype)[None, :, :] 41 | 42 | 43 | class TRecViTBlock(nn.Module): 44 | """A block proposed by https://arxiv.org/abs/2412.14294.""" 45 | 46 | def __init__(self, depth, width, num_heads, lru_width, dtype, device): 47 | super().__init__() 48 | self.ssm_block = tapnext_lru_modules.ResidualBlock( 49 | width=width, 50 | mlp_expanded_width=width * 4, 51 | num_heads=num_heads, 52 | lru_width=lru_width, 53 | final_w_init_variance_scale=2.0 / depth, 54 | dtype=dtype, 55 | device=device, 56 | ) 57 | self.vit_block = vision_transformer.EncoderBlock( 58 | num_heads=num_heads, 59 | mlp_dim=width * 4, 60 | hidden_dim=width, 61 | attention_dropout=0.0, 62 | dropout=0.0, 63 | ) 64 | 65 | def forward(self, x, cache=None, use_linear_scan=True): 66 | b, t, n, _ = x.shape 67 | x = einops.rearrange(x, 'b t n c -> (b n) t c') 68 | x, ssm_cache = self.ssm_block(x, cache, use_linear_scan=use_linear_scan) 69 | x = einops.rearrange(x, '(b n) t c -> (b t) n c', b=b, n=n) 70 | x = self.vit_block(x) 71 | x = einops.rearrange(x, '(b t) n c -> b t n c', b=b, t=t) 72 | return x, ssm_cache 73 | 74 | 75 | @dataclasses.dataclass 76 | class TAPNextTrackingState: 77 | """State for TAPNext.""" 78 | 79 | step: int 80 | query_points: torch.Tensor # Float["*B Q 3"] 81 | hidden_state: List[tapnext_lru_modules.RecurrentBlockCache] = None 82 | 83 | 84 | class TAPNext(nn.Module): 85 | """TAPNext implementation in pytorch.""" 86 | 87 | def __init__( 88 | self, 89 | image_size, 90 | width=768, 91 | patch_size=(8, 8), 92 | num_heads=12, 93 | lru_width=768, 94 | depth=12, 95 | use_checkpointing=False, 96 | ): 97 | super().__init__() 98 | self.width = width 99 | self.patch_size = patch_size 100 | self.use_checkpointing = use_checkpointing 101 | self.image_size = image_size 102 | 103 | self.lin_proj = nn.Conv2d( 104 | in_channels=3, 105 | out_channels=self.width, 106 | kernel_size=self.patch_size, 107 | stride=self.patch_size, 108 | ) 109 | self.blocks = nn.ModuleList([ 110 | TRecViTBlock( 111 | depth=depth, 112 | width=width, 113 | num_heads=num_heads, 114 | lru_width=lru_width, 115 | dtype=torch.float32, 116 | device='cuda', 117 | ) 118 | for _ in range(depth) 119 | ]) 120 | self.encoder_norm = nn.LayerNorm(self.width) 121 | self.mask_token = nn.Parameter( 122 | torch.zeros((1, 1, 1, self.width)), requires_grad=True 123 | ) 124 | self.unknown_token = nn.Parameter( 125 | torch.zeros((1, 1, self.width)), requires_grad=True 126 | ) 127 | self.point_query_token = nn.Parameter( 128 | torch.zeros((1, 1, 1, self.width)), requires_grad=True 129 | ) 130 | h = self.image_size[0] // self.patch_size[0] 
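    # Number of patch tokens per frame along each spatial axis; e.g. an
    # image_size of (256, 256) with the default 8x8 patches would give a
    # 32x32 grid, i.e. 1024 tokens per frame (illustrative numbers only).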
131 | w = self.image_size[1] // self.patch_size[1] 132 | c = self.width 133 | self.image_pos_emb = nn.Parameter( 134 | torch.zeros((1, h * w, c)), requires_grad=True 135 | ) 136 | self.register_buffer( 137 | 'query_pos_embed', 138 | torch.tensor( 139 | posemb_sincos_2d(self.image_size[0], self.image_size[1], c) 140 | ), 141 | ) 142 | self.visible_head = nn.Sequential( 143 | nn.Linear(width, 256), 144 | nn.LayerNorm(256), 145 | nn.GELU(), 146 | nn.Linear(256, 256), 147 | nn.LayerNorm(256), 148 | nn.GELU(), 149 | nn.Linear(256, 1), 150 | ) 151 | self.coordinate_head = nn.Sequential( 152 | nn.Linear(width, 256), 153 | nn.LayerNorm(256), 154 | nn.GELU(), 155 | nn.Linear(256, 256), 156 | nn.LayerNorm(256), 157 | nn.GELU(), 158 | nn.Linear(256, 512), 159 | ) 160 | 161 | def embed_queries(self, timesteps, query_points): 162 | b, q, _ = query_points.shape 163 | t = timesteps 164 | c = self.width 165 | tiled_point_query_tokens = self.point_query_token.repeat(b, 1, q, 1) 166 | mask_tokens = self.mask_token.repeat(b, t, q, 1) 167 | unknown_tokens = self.unknown_token.repeat(b, t, q, 1) 168 | query_pos_embed = self.query_pos_embed.view( 169 | 1, self.image_size[0], self.image_size[1], c 170 | ).repeat(b, 1, 1, 1) 171 | # [B Q t 3] 172 | query_timesteps, query_positions = ( 173 | query_points[..., :1], 174 | query_points[..., 1:], 175 | ) 176 | # grid sample expects coordinates in [-1, 1] 177 | query_pos_embed_spatial = F.grid_sample( 178 | query_pos_embed.permute( 179 | 0, 3, 2, 1 180 | ), # we swap h and w due to how grid_sample works 181 | ( 182 | query_positions.unsqueeze(1) 183 | / torch.tensor(self.image_size, device=query_positions.device) 184 | ) 185 | * 2 186 | - 1, 187 | align_corners=False, 188 | ).permute( 189 | 0, 2, 3, 1 190 | ) # [b Q t c] 191 | point_query_tokens = tiled_point_query_tokens + query_pos_embed_spatial 192 | # NOTE: query hints are only implemented in jax but not in pytorch 193 | # the two masks below are used to not add point query (if queried later in 194 | # online tracking) 195 | if t == 1: # online tracking 196 | mask_and_query_tokens = torch.where( 197 | query_timesteps.unsqueeze(1) == 0, point_query_tokens, mask_tokens 198 | ) 199 | else: 200 | queries_are_late = query_timesteps >= t 201 | queries_are_early = query_timesteps < 0 202 | mask_and_query_tokens = mask_tokens.scatter( 203 | dim=1, 204 | index=query_timesteps.unsqueeze(1) 205 | .long() 206 | .clamp(0, t - 1) 207 | .repeat(1, 1, 1, c), 208 | src=point_query_tokens, 209 | ) 210 | mask_and_query_tokens = torch.where( 211 | (queries_are_late | queries_are_early).unsqueeze(1), 212 | mask_tokens, 213 | mask_and_query_tokens, 214 | ) 215 | is_unknown_token = torch.arange(t, device=query_points.device)[ 216 | None, :, None, None 217 | ] < query_timesteps.unsqueeze(1) 218 | mask_and_query_tokens = torch.where( 219 | is_unknown_token, unknown_tokens, mask_and_query_tokens 220 | ) 221 | return mask_and_query_tokens 222 | 223 | def prediction_heads(self, x): 224 | soft_argmax_threshold = 20 225 | softmax_temperature = 0.5 226 | track_logits = self.coordinate_head(x.float()) 227 | position_x, position_y = track_logits.chunk(2, dim=-1) 228 | argmax_x, argmax_y = position_x.argmax( 229 | dim=-1, keepdim=True 230 | ), position_y.argmax(dim=-1, keepdim=True) 231 | index = torch.arange(position_x.shape[-1], device=x.device).repeat( 232 | *argmax_x.shape[:-1], 1 233 | ) 234 | mask_x = (torch.abs(argmax_x - index) <= soft_argmax_threshold).float() 235 | mask_y = (torch.abs(argmax_y - index) <= soft_argmax_threshold).float() 
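    # Thresholded soft-argmax: the softmax over the coordinate logits is
    # masked to a window of +/- soft_argmax_threshold bins around the hard
    # argmax and renormalized, and its expectation over bin indices gives
    # sub-pixel x/y estimates; a 0.5 offset is then added below.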
236 | probs_x = F.softmax(position_x * softmax_temperature, dim=-1) * mask_x 237 | probs_y = F.softmax(position_y * softmax_temperature, dim=-1) * mask_y 238 | probs_x = probs_x / probs_x.sum(dim=-1, keepdim=True) 239 | probs_y = probs_y / probs_y.sum(dim=-1, keepdim=True) 240 | tracks_x = torch.sum(probs_x * index, dim=-1)[..., None] 241 | tracks_y = torch.sum(probs_y * index, dim=-1)[..., None] 242 | tracks = torch.cat([tracks_x, tracks_y], axis=-1) 243 | tracks += 0.5 244 | visible_logits = self.visible_head(x) 245 | return tracks, track_logits, visible_logits 246 | 247 | def forward(self, video, query_points=None, state=None): 248 | # video.shape 249 | b, t, _, _, _ = video.shape 250 | # [b, t, h, w, 3] -> [b, t, 3, h, w] 251 | video_tokens = self.lin_proj( 252 | einops.rearrange(video, 'b t h w c -> (b t) c h w') 253 | ) 254 | _, _, h, w = video_tokens.shape 255 | video_tokens = einops.rearrange( 256 | video_tokens, '(b t) c h w -> b t (h w) c', b=b, t=t 257 | ) 258 | video_tokens = video_tokens + self.image_pos_emb.unsqueeze(0) 259 | if state is not None: 260 | # in online tracking, we put query "back in time" 261 | if query_points is None: 262 | query_points = state.query_points 263 | query_points = torch.cat( 264 | [query_points[..., :1] - state.step, query_points[..., 1:]], dim=-1 265 | ) 266 | step = state.step 267 | else: 268 | step = 0 269 | point_tokens = self.embed_queries(t, query_points) # [b t Q c] 270 | x = torch.cat([video_tokens, point_tokens], dim=2) # [b t (h * w + Q) c] 271 | ssm_cache = [] 272 | use_linear_scan = not self.training 273 | for blk, cache in zip( 274 | self.blocks, state.hidden_state if state is not None else [None] * 12 275 | ): 276 | if self.use_checkpointing: 277 | x, ssm_cache_layer = torch.utils.checkpoint.checkpoint( 278 | blk, x, cache, use_linear_scan, use_reentrant=False 279 | ) 280 | else: 281 | x, ssm_cache_layer = blk( 282 | x, cache=cache, use_linear_scan=use_linear_scan 283 | ) 284 | ssm_cache.append(ssm_cache_layer) 285 | x = self.encoder_norm(x) 286 | video_tokens, point_tokens = x[:, :, : h * w, :], x[:, :, h * w :, :] 287 | return ( 288 | *self.prediction_heads(point_tokens), 289 | TAPNextTrackingState( 290 | step=step + t, 291 | query_points=state.query_points 292 | if state is not None 293 | else query_points, 294 | hidden_state=ssm_cache, 295 | ), 296 | ) 297 | 298 | torch.export.register_dataclass(TAPNextTrackingState) 299 | 300 | 301 | def flatten_tracking_state(state, _): 302 | return ( 303 | state.step, 304 | state.query_points, 305 | [[st for st in h] for h in state.hidden_state], 306 | ) 307 | 308 | 309 | torch.fx._pytree.register_pytree_flatten_spec( # pylint: disable=protected-access 310 | TAPNextTrackingState, flatten_tracking_state 311 | ) 312 | -------------------------------------------------------------------------------- /tapnet/tapvid3d/evaluation/evaluate_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Compute metrics on the saved outputs of a TAP-3D model.""" 17 | 18 | from collections.abc import Sequence 19 | import io 20 | import os 21 | 22 | from absl import app 23 | from absl import flags 24 | from absl import logging 25 | import numpy as np 26 | from PIL import Image 27 | from tapnet.tapvid3d.evaluation import metrics 28 | from tapnet.tapvid3d.splits import tapvid3d_splits 29 | import tqdm 30 | 31 | 32 | _TAPVID3D_DIR = flags.DEFINE_string( 33 | 'tapvid3d_dir', 34 | 'tapvid3d_dataset/', 35 | """ 36 | Path to folder that has all the ground truth npy files. Should 37 | contain subfolders for each data source (adt, pstudio, drivetrack). 38 | """, 39 | ) 40 | 41 | _TAPVID3D_PREDICTIONS = flags.DEFINE_string( 42 | 'tapvid3d_predictions', 43 | 'tapvid3d_predictions/', 44 | """ 45 | Must be a folder that contains a set of *.npz files, each one with 46 | the predicted 3D trajectories for that video, with filename 47 | exactly matching the corresponding ground truth file/video example. 48 | The npy file must have a tensor `tracks_xyz` of shape (T, N, 3) and 49 | a tensor `visible` of shape where [T, N], where N is the 50 | number of tracks in the video, and T is the number of frames. 51 | """, 52 | ) 53 | 54 | _USE_MINIVAL = flags.DEFINE_boolean( 55 | 'use_minival', 56 | True, 57 | """ 58 | If True, compute metrics on the minival split; 59 | otherwise uses the full_eval split. 60 | """, 61 | ) 62 | 63 | _DEPTH_SCALINGS = flags.DEFINE_list( 64 | 'depth_scalings', 65 | ['median'], 66 | 'Depth scaling strategies to use, list of [median, per_trajectory, etc..].', 67 | ) 68 | 69 | _DATA_SOURCES_TO_EVALUATE = flags.DEFINE_list( 70 | 'data_sources_to_evaluate', 71 | ['drivetrack', 'adt', 'pstudio'], 72 | 'Which data source subsets to evaluate.', 73 | ) 74 | 75 | _DEBUG = flags.DEFINE_boolean( 76 | 'debug', 77 | False, 78 | 'Whether to run in debug mode, downloads only one video.', 79 | ) 80 | 81 | 82 | ZERO_METRICS_DICT = { 83 | 'occlusion_accuracy': np.array([0.0]), 84 | 'pts_within_1': np.array([0.0]), 85 | 'jaccard_1': np.array([0.0]), 86 | 'pts_within_2': np.array([0.0]), 87 | 'jaccard_2': np.array([0.0]), 88 | 'pts_within_4': np.array([0.0]), 89 | 'jaccard_4': np.array([0.0]), 90 | 'pts_within_8': np.array([0.0]), 91 | 'jaccard_8': np.array([0.0]), 92 | 'pts_within_16': np.array([0.0]), 93 | 'jaccard_16': np.array([0.0]), 94 | 'average_jaccard': np.array([0.0]), 95 | 'average_pts_within_thresh': np.array([0.0]), 96 | } 97 | 98 | 99 | def get_new_hw_with_given_smallest_side_length( 100 | *, orig_height: int, orig_width: int, smallest_side_length: int = 256 101 | ): 102 | orig_shape = np.array([orig_height, orig_width]) 103 | scaling_factor = smallest_side_length / np.min(orig_shape) 104 | resized_shape = np.round(orig_shape * scaling_factor) 105 | return (int(resized_shape[0]), int(resized_shape[1])), scaling_factor 106 | 107 | 108 | def get_jpeg_byte_hw(jpeg_bytes: bytes): 109 | with io.BytesIO(jpeg_bytes) as img_bytes: 110 | img = Image.open(img_bytes) 111 | img = img.convert('RGB') 112 | return np.array(img).shape[:2] 113 | 114 | 115 | def get_average_over_metrics( 116 | list_of_metrics: list[dict[str, dict[str, np.ndarray]]] 117 | ): 118 | """Average over per video metrics in a nested metrics dictionaries.""" 119 | avg_metrics = {} 120 | for metric_category in 
list_of_metrics[0].keys(): 121 | avg_metrics[metric_category] = {} 122 | for metric_name in list_of_metrics[0][metric_category]: 123 | avg_metrics[metric_category][metric_name] = np.mean( 124 | [ 125 | video_metric[metric_category][metric_name] 126 | for video_metric in list_of_metrics 127 | ] 128 | ) 129 | return avg_metrics 130 | 131 | 132 | def evaluate_data_source( 133 | npz_filenames: list[str], 134 | ground_truth_dir: str, 135 | predictions_dir: str, 136 | depth_scalings: list[str], 137 | metric_eval_resolution: int = 256, 138 | ) -> dict[str, dict[str, np.ndarray]]: 139 | """Compute metrics on the set of 3D track predictions. 140 | 141 | This function expects as input two directories: `ground_truth_dir` and 142 | `predictions_dir`. They should have a structure as follows: 143 | 144 | - ground_truth_dir: 145 | - video1.npz (with contents+format described in tapvid3d/README.md) 146 | - video2.npz 147 | - video3.npz 148 | - ... 149 | - videoN.npz 150 | 151 | - predictions_dir: 152 | - video1.npz (contains two tensors: the predicted `visibility` and 153 | `tracks_XYZ`, with shapes described in tapvid3d/README.md) 154 | - video2.npz 155 | - video3.npz 156 | - ... 157 | - videoN.npz 158 | 159 | `ground_truth_dir` contains npz files with ground truth data, while 160 | `predictions_dir` should contain npz files with the same file name as their 161 | corresponding ground truth file in `ground_truth_dir`. The npz files can be 162 | named anything (they do not need to be named video{N}.npz as in the example), 163 | as long as the naming is consistent across both directories. 164 | 165 | Also note that this directory structure (for `ground_truth_dir`) is what is 166 | generated by the `generate_all.sh` script. The script generates three 167 | different directories: `tapvid3d_dataset/adt`, `tapvid3d_dataset/pstudio`, 168 | `tapvid3d_dataset/drivetrack`, each one with the structure described above. 169 | 170 | Args: 171 | npz_filenames: List of npz file names to compute metrics over. 172 | Each filename must exist in both `ground_truth_dir` and 173 | `predictions_dir`. 174 | ground_truth_dir: Path to the TAP-3D dataset, format described above. 175 | predictions_dir: Path to the TAP-3D predictions, format described above. 176 | depth_scalings: List of strings describing which depth scaling strategies 177 | to use. See `compute_tapvid3d_metrics()` in metrics.py for 178 | available scaling options and their descriptions. 179 | metric_eval_resolution: The resolution at which to evaluate the metrics. 180 | The default of 256 is used for all results 181 | in the TAPVid-3D paper and the original TAPVid paper. 182 | 183 | Returns: 184 | A dictionary of dictionaries: the outer dict mapping different 185 | depth scaling strategies to an inner metrics dict (mapping a metric name 186 | to the corresponding metric score).
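    For example, with `depth_scalings=['median']` the result has the form
    {'median': {'average_jaccard': ..., 'occlusion_accuracy': ..., ...}},
    where each inner value is a scalar averaged over the evaluated videos.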
187 | """ 188 | metrics_all_videos = [] 189 | for npy_file in tqdm.tqdm(npz_filenames, total=len(npz_filenames)): 190 | gt_file = os.path.join(ground_truth_dir, npy_file) 191 | with open(gt_file, 'rb') as in_f: 192 | in_npz = np.load(in_f, allow_pickle=True) 193 | images_jpeg_bytes = in_npz['images_jpeg_bytes'] 194 | queries_xyt = in_npz['queries_xyt'] 195 | tracks_xyz = in_npz['tracks_XYZ'] 196 | visibles = in_npz['visibility'] 197 | intrinsics_params = in_npz['fx_fy_cx_cy'] 198 | 199 | # Simple rescaling from original video resolution to the resolution 200 | # at which we want to compute the metrics (metric_eval_resolution) 201 | # Since camera intrinsics are for original video resolution, we need to 202 | # rescale them as well. 203 | # `intrinsics_params` is in format [fx, fy, cx, cy]. 204 | video_height, video_width = get_jpeg_byte_hw(images_jpeg_bytes[0]) 205 | smallest_side_length = metric_eval_resolution 206 | (_, _), scaling_factor = ( 207 | get_new_hw_with_given_smallest_side_length( 208 | orig_height=video_height, 209 | orig_width=video_width, 210 | smallest_side_length=smallest_side_length, 211 | ) 212 | ) 213 | intrinsics_params_resized = intrinsics_params * scaling_factor 214 | 215 | prediction_file = os.path.join(predictions_dir, npy_file) 216 | try: 217 | with open(prediction_file, 'rb') as in_f: 218 | predictor_data = np.load(in_f, allow_pickle=True) 219 | predicted_tracks_xyz = predictor_data['tracks_XYZ'] 220 | predicted_visibility = predictor_data['visibility'] 221 | 222 | except Exception: # pylint: disable=broad-exception-caught 223 | logging.exception('Failed to read %s', prediction_file) 224 | failure_metrics_dict = { 225 | scaling: ZERO_METRICS_DICT for scaling in depth_scalings 226 | } 227 | metrics_all_videos.append(failure_metrics_dict) 228 | if _DEBUG.value: 229 | logging.info('Stopping after one video, debug run.') 230 | break 231 | continue 232 | 233 | video_metrics = {} 234 | for depth_scaling in depth_scalings: 235 | try: 236 | metrics_for_scale = metrics.compute_tapvid3d_metrics( 237 | gt_occluded=np.logical_not(visibles), 238 | gt_tracks=tracks_xyz, 239 | pred_occluded=np.logical_not(predicted_visibility), 240 | pred_tracks=predicted_tracks_xyz, 241 | intrinsics_params=intrinsics_params_resized, 242 | scaling=depth_scaling, 243 | query_points=queries_xyt[..., ::-1], 244 | order='t n', 245 | ) 246 | except Exception: # pylint: disable=broad-exception-caught 247 | logging.exception('Failed to compute metrics for %s', npy_file) 248 | metrics_for_scale = ZERO_METRICS_DICT 249 | video_metrics[depth_scaling] = metrics_for_scale 250 | metrics_all_videos.append(video_metrics) 251 | 252 | if _DEBUG.value: 253 | logging.info('Stopping after one video, debug run.') 254 | break 255 | 256 | avg_metrics = get_average_over_metrics(metrics_all_videos) 257 | return avg_metrics 258 | 259 | 260 | def main(argv: Sequence[str]) -> None: 261 | if len(argv) > 1: 262 | raise app.UsageError('Too many command-line arguments.') 263 | 264 | metrics_all_sources = [] 265 | for data_source in _DATA_SOURCES_TO_EVALUATE.value: 266 | if _USE_MINIVAL.value: 267 | all_npz_files = tapvid3d_splits.get_minival_files(subset=data_source) 268 | else: 269 | all_npz_files = tapvid3d_splits.get_full_eval_files(subset=data_source) 270 | source_gt_dir = os.path.join(_TAPVID3D_DIR.value, data_source) 271 | source_pred_dir = os.path.join(_TAPVID3D_PREDICTIONS.value, data_source) 272 | source_metrics = evaluate_data_source( 273 | npz_filenames=all_npz_files, 274 | ground_truth_dir=source_gt_dir, 275 | 
predictions_dir=source_pred_dir, 276 | depth_scalings=_DEPTH_SCALINGS.value, 277 | ) 278 | metrics_all_sources.append(source_metrics) 279 | logging.info('Metrics for data source %s', data_source) 280 | logging.info(source_metrics) 281 | 282 | avg_metrics = get_average_over_metrics(metrics_all_sources) 283 | logging.info('Metrics, averaged across all data sources:') 284 | logging.info(avg_metrics) 285 | 286 | logging.info('Finished computing metrics!') 287 | 288 | 289 | if __name__ == '__main__': 290 | app.run(main) 291 | --------------------------------------------------------------------------------