├── tapnet
│   ├── tapvid
│   │   ├── requirements.txt
│   │   ├── __init__.py
│   │   ├── download_kinetics_videos.sh
│   │   ├── visualize.py
│   │   ├── generate_tapvid.py
│   │   └── README.md
│   ├── models
│   │   ├── __init__.py
│   │   ├── tsm_utils.py
│   │   └── tapnet_model.py
│   ├── torch
│   │   └── __init__.py
│   ├── trajan
│   │   ├── __init__.py
│   │   └── attention.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── optimizers.py
│   │   ├── transforms.py
│   │   ├── index_utils.py
│   │   ├── ssm_utils.py
│   │   └── experiment_utils.py
│   ├── tapvid3d
│   │   ├── splits
│   │   │   └── __init__.py
│   │   ├── evaluation
│   │   │   ├── __init__.py
│   │   │   ├── run_evaluate_model.sh
│   │   │   └── evaluate_model.py
│   │   ├── annotation_generation
│   │   │   ├── __init__.py
│   │   │   ├── generate_drivetrack.py
│   │   │   ├── gcs_utils.py
│   │   │   ├── generate_adt.py
│   │   │   ├── generate_pstudio.py
│   │   │   └── adt_utils.py
│   │   ├── generate_all.sh
│   │   └── README.md
│   ├── __init__.py
│   ├── training
│   │   ├── README.md
│   │   └── task.py
│   ├── tapnext
│   │   ├── pscan.py
│   │   ├── torch_losses.py
│   │   ├── losses.py
│   │   ├── tapnext_benchmark_pytorch.ipynb
│   │   ├── tapnext_torch_utils.py
│   │   └── tapnext_torch.py
│   ├── live_demo.py
│   └── pytorch_live_demo.py
├── assets
│   └── tapir_benchmark.png
├── requirements_inference.txt
├── requirements.txt
├── .gitignore
├── CONTRIBUTING.md
├── pyproject.toml
├── colabs
│   └── tapir_clustering.ipynb
└── configs
    ├── tapnet_config.py
    ├── tapir_config.py
    ├── causal_tapir_config.py
    └── tapir_bootstrap_config.py
/tapnet/tapvid/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | ffmpeg
3 | ffmpeg-python
4 | mediapy
5 | numpy
6 |
--------------------------------------------------------------------------------
/assets/tapir_benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/tapnet/HEAD/assets/tapir_benchmark.png
--------------------------------------------------------------------------------
/requirements_inference.txt:
--------------------------------------------------------------------------------
1 | chex
2 | jax
3 | jaxline
4 | optax
5 | dm-haiku
6 | dm-tree
7 | typing_extensions
8 | matplotlib
9 | mediapy
10 | opencv-python
11 | einshape
12 | ipympl
13 | tqdm
14 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | notebook
3 | jupyter_http_over_ws
4 | chex
5 | jax
6 | jaxline
7 | optax
8 | dm-haiku
9 | dm-tree
10 | typing_extensions
11 | matplotlib
12 | mediapy
13 | opencv-python
14 | einshape
15 | einops
16 | tensorflow
17 | tensorflow-datasets
18 | tensorflow_graphics
19 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Distribution / packaging
7 | .Python
8 | build/
9 | develop-eggs/
10 | dist/
11 | downloads/
12 | eggs/
13 | .eggs/
14 | lib/
15 | lib64/
16 | parts/
17 | sdist/
18 | var/
19 | wheels/
20 | share/python-wheels/
21 | *.egg-info/
22 | .installed.cfg
23 | *.egg
24 | MANIFEST
25 |
--------------------------------------------------------------------------------
/tapnet/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 |
17 |
--------------------------------------------------------------------------------
/tapnet/tapvid/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 |
17 |
--------------------------------------------------------------------------------
/tapnet/torch/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 |
17 |
--------------------------------------------------------------------------------
/tapnet/trajan/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 |
17 |
--------------------------------------------------------------------------------
/tapnet/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 |
17 |
--------------------------------------------------------------------------------
/tapnet/tapvid3d/splits/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 |
17 |
--------------------------------------------------------------------------------
/tapnet/tapvid3d/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 |
17 |
--------------------------------------------------------------------------------
/tapnet/tapvid3d/annotation_generation/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 |
17 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | ## Contributor License Agreement
4 |
5 | Contributions to this project must be accompanied by a Contributor License
6 | Agreement. You (or your employer) retain the copyright to your contribution;
7 | this simply gives us permission to use and redistribute your contributions as
8 | part of the project. Head over to <https://cla.developers.google.com/> to see
9 | your current agreements on file or to sign a new one.
10 |
11 | You generally only need to submit a CLA once, so if you've already submitted one
12 | (even if it was for a different project), you probably don't need to do it
13 | again.
14 |
15 | ## Code reviews
16 |
17 | All submissions, including submissions by project members, require review. We
18 | use GitHub pull requests for this purpose. Consult
19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
20 | information on using pull requests.
21 |
22 | ## Community Guidelines
23 |
24 | This project follows [Google's Open Source Community
25 | Guidelines](https://opensource.google/conduct/).
26 |
--------------------------------------------------------------------------------
/tapnet/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Legacy API for TAP. Prefer importing from project subfolders."""
17 |
18 | from tapnet.models import tapir_model # pylint:disable=g-importing-member
19 | from tapnet.models import tapnet_model # pylint:disable=g-importing-member
20 | from tapnet.robotap import tapir_clustering # pylint:disable=g-importing-member
21 | from tapnet.tapvid import evaluation_datasets # pylint:disable=g-importing-member
22 |
--------------------------------------------------------------------------------
/tapnet/tapvid/download_kinetics_videos.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2025 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 |
18 | git clone https://github.com/cvdfoundation/kinetics-dataset.git
19 | mkdir kinetics_videos
20 | cd kinetics_videos
21 | wget https://s3.amazonaws.com/kinetics/700_2020/val/k700_2020_val_path.txt
22 | bash ../kinetics-dataset/download.sh k700_2020_val_path.txt
23 | bash ../kinetics-dataset/extract.sh k700_2020_val_path.txt
24 | rm -f k700_val_*
25 | rm -f k700_2020_val_path.txt
26 | cd ..
27 | rm -rf kinetics-dataset
28 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "tapnet"
3 | version = "0.1.0"
4 | description = "Tracking-Any-Point codebase from Google DeepMind."
5 | dependencies = [
6 | "chex",
7 | "jax",
8 | "jaxline",
9 | "optax",
10 | "dm-haiku",
11 | "dm-tree",
12 | "typing_extensions",
13 | "matplotlib",
14 | "mediapy",
15 | "opencv-python",
16 | "einshape",
17 | "ipympl",
18 | "tqdm",
19 | ]
20 |
21 | [project.optional-dependencies]
22 | train = [
23 | "absl-py",
24 | "notebook",
25 | "jupyter_http_over_ws",
26 | "tensorflow",
27 | "tensorflow-datasets",
28 | "tensorflow_graphics",
29 | "kubric@git+https://github.com/google-research/kubric",
30 | "recurrentgemma@git+https://github.com/google-deepmind/recurrentgemma"
31 | ]
32 |
33 | torch = [
34 | "torch",
35 | "torchvision",
36 | ]
37 | tapvid3d_eval = [
38 | "einops>=0.8.0",
39 | "numpy>=1.25.2",
40 | "absl-py>=2.1.0",
41 | "tqdm>=4.66.4",
42 | "pillow>=9.4.0",
43 | ]
44 | tapvid3d_generation = [
45 | "absl-py==2.1.0",
46 | "tqdm==4.66.4",
47 | "absl-py==2.1.0",
48 | "tqdm==4.66.4",
49 | "tensorflow==2.15.0",
50 | "numpy==1.25.2",
51 | "pillow==9.4.0",
52 | "projectaria-tools==1.5.2",
53 | "visu3d==1.5.3",
54 | "torch==2.3.0",
55 | "torchvision==0.18.0",
56 | "etils==1.7.0",
57 | "tensorflow-datasets",
58 | ]
59 |
60 | [tool.setuptools.packages.find]
61 | where = ["."] # list of folders that contain the packages (["."] by default)
62 | include = ["tapnet*"] # package names should match these glob patterns (["*"] by default)
63 |
64 |
--------------------------------------------------------------------------------
/tapnet/tapvid3d/generate_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2025 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 |
18 | set -x
19 |
20 | debug=0
21 |
22 | # Parse debug option
23 | while [[ "$#" -gt 0 ]]; do
24 | case "$1" in
25 | -d|--debug)
26 | debug=1
27 | ;;
28 | *)
29 | echo "Unknown option: $1"
30 | ;;
31 | esac
32 | shift
33 | done
34 |
35 | if [[ $debug -eq 1 ]]; then
36 | PYTHON_DEBUG="True"
37 | else
38 | PYTHON_DEBUG="False"
39 | fi
40 |
41 | python3 -m venv tapvid3d
42 | source tapvid3d/bin/activate
43 |
44 | # Download the ADT data and annotations
45 | ADT_OUTPUT_DIRECTORY="tapvid3d_dataset/adt/"
46 | mkdir -p $ADT_OUTPUT_DIRECTORY
47 | python3 -m tapnet.tapvid3d.annotation_generation.generate_adt --output_dir=$ADT_OUTPUT_DIRECTORY --debug=$PYTHON_DEBUG --split=all
48 |
49 | # Download the Panoptic Studio data and annotations
50 | PSTUDIO_OUTPUT_DIRECTORY="tapvid3d_dataset/pstudio/"
51 | mkdir -p $PSTUDIO_OUTPUT_DIRECTORY
52 | python3 -m tapnet.tapvid3d.annotation_generation.generate_pstudio --output_dir=$PSTUDIO_OUTPUT_DIRECTORY --debug=$PYTHON_DEBUG --split=all
53 |
54 | # Download the Waymo Open / DriveTrack data and annotations
55 | DT_OUTPUT_DIRECTORY="tapvid3d_dataset/drivetrack/"
56 | mkdir -p $DT_OUTPUT_DIRECTORY
57 | python3 -m tapnet.tapvid3d.annotation_generation.generate_drivetrack --output_dir=$DT_OUTPUT_DIRECTORY --debug=$PYTHON_DEBUG --split=all
58 |
59 | deactivate
60 |
--------------------------------------------------------------------------------
/tapnet/tapvid3d/annotation_generation/generate_drivetrack.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Download the Waymo Open + DriveTrack videos and annotations.
17 |
18 | The *.npz files in the GCS bucket for the DriveTrack split contain both
19 | the preprocessed annotations and source videos, so all we need to do is
20 | bulk download them to the local machine.
21 | """
22 |
23 | from collections.abc import Sequence
24 | import os
25 |
26 | from absl import app
27 | from absl import flags
28 | from tapnet.tapvid3d.annotation_generation import gcs_utils
29 |
30 |
31 | _OUTPUT_DIR = flags.DEFINE_string(
32 | "output_dir",
33 | "tapvid3d_dataset/drivetrack/",
34 | "Path to folder to store output npz files containing all fields.",
35 | )
36 |
37 | _DEBUG = flags.DEFINE_boolean(
38 | "debug",
39 | False,
40 | "Whether to run in debug mode, downloads only one video.",
41 | )
42 |
43 | _SPLIT = flags.DEFINE_enum(
44 | "split",
45 | "all",
46 | ["minival", "full_eval", "all"],
47 | """
48 | Which split to download: the minival split, the full_eval
49 | split, or all videos.
50 | """,
51 | )
52 |
53 |
54 | def main(argv: Sequence[str]) -> None:
55 | if len(argv) > 1:
56 | raise app.UsageError("Too many command-line arguments.")
57 |
58 | if not os.path.exists(_OUTPUT_DIR.value):
59 | os.makedirs(_OUTPUT_DIR.value)
60 |
61 | gcs_utils.download_tapvid3d_files(
62 | _OUTPUT_DIR.value, _SPLIT.value, "drivetrack", _DEBUG.value
63 | )
64 |
65 |
66 | if __name__ == "__main__":
67 | app.run(main)
68 |
--------------------------------------------------------------------------------
/tapnet/tapvid/visualize.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Visualize frames of a random video of the given dataset."""
17 |
18 | import io
19 | import pickle
20 | import random
21 | from typing import Sequence
22 |
23 | from absl import app
24 | from absl import flags
25 | from absl import logging
26 | import mediapy as media
27 | import numpy as np
28 | from PIL import Image
29 |
30 | from tapnet.utils import viz_utils
31 |
32 | FLAGS = flags.FLAGS
33 |
34 | flags.DEFINE_string(
35 | 'input_path', None, 'Path to the pickle file.', required=True
36 | )
37 | flags.DEFINE_string(
38 | 'output_path', None, 'Path to the output mp4 video.', required=True
39 | )
40 |
41 |
42 | def main(argv: Sequence[str]) -> None:
43 | del argv
44 |
45 | logging.info('Loading data from %s. This takes time.', FLAGS.input_path)
46 | with open(FLAGS.input_path, 'rb') as f:
47 | data = pickle.load(f)
48 | if isinstance(data, dict):
49 | data = list(data.values())
50 |
51 | idx = random.randint(0, len(data) - 1)
52 | video = data[idx]
53 |
54 | frames = video['video']
55 |
56 | if isinstance(frames[0], bytes):
57 | # Tapnet is stored as JPEG bytes rather than `np.ndarray`s.
58 | def decode(frame):
59 | byteio = io.BytesIO(frame)
60 | img = Image.open(byteio)
61 | return np.array(img)
62 |
63 | frames = np.array([decode(frame) for frame in frames])
64 |
65 | if frames.shape[1] > 360:
66 | frames = media.resize_video(frames, (360, 640))
67 |
68 | scale_factor = np.array(frames.shape[2:0:-1])[np.newaxis, np.newaxis, :]
69 | painted_frames = viz_utils.paint_point_track(
70 | frames,
71 | video['points'] * scale_factor,
72 | ~video['occluded'],
73 | )
74 |
75 | media.write_video(FLAGS.output_path, painted_frames, fps=25)
76 | logging.info('Exemplar point visualization saved to %s', FLAGS.output_path)
77 |
78 |
79 | if __name__ == '__main__':
80 | app.run(main)
81 |
--------------------------------------------------------------------------------
/tapnet/tapvid3d/evaluation/run_evaluate_model.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2025 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 |
18 | # Expects:
19 | # (1) to be run from this directory! `./run_evaluate_model.sh`
20 | # (2) that a copy of the generated dataset is in TAPVID3D_DIR, defined below
21 | # and should be modified as necessary, depending on where you stored the
22 | # generated dataset. Format is as described in the README.md, in summary:
23 | # a folder with drivetrack/, adt/, pstudio/ subfolders, each containing
24 | # *.npz files, one per video containing ground truth 3D tracks and other
25 | # metadata.
26 |
27 | TAPVID3D_DIR="/tmp/datasets/tapvid3d/"
28 |
29 | python3 -m venv eval_tapvid3d
30 | source eval_tapvid3d/bin/activate
31 | pip install absl-py==2.1.0 tqdm==4.66.4 numpy==2.0.0 pillow==10.4.0
32 |
33 | cd ../..
34 | pip install .
35 | cd tapvid3d
36 |
37 | # First, download the SpaTracker predictions on the DriveTrack subset from GCP
38 | PREDICTIONS_DIR="/tmp/model_outputs/spatracker/"
39 | mkdir -p "$PREDICTIONS_DIR"
40 |
41 | # Get the DriveTrack filenames
42 | PYTHON_SCRIPT=$(cat <<EOF
--------------------------------------------------------------------------------
/tapnet/utils/optimizers.py:
--------------------------------------------------------------------------------
27 | def _weight_decay_exclude(
28 | exclude_names: Optional[Sequence[Text]] = None,
29 | ) -> Callable[[str, str, jnp.ndarray], bool]:
30 | """Logic for deciding which parameters to include for weight decay..
31 |
32 | Args:
33 | exclude_names: an optional list of names to exclude for weight_decay. ['b']
34 | by default.
35 |
36 | Returns:
37 | A predicate that returns False for params that need to be excluded from
38 | weight_decay.
39 | """
40 | # By default weight_decay the weights but not the biases.
41 | if exclude_names is None:
42 | exclude_names = ["b"]
43 |
44 | def include(module_name: Text, name: Text, value: jnp.ndarray):
45 | del value
46 | # Do not weight decay the parameters of normalization blocks.
47 | if any([norm_name in module_name for norm_name in NORM_NAMES]):
48 | return False
49 | else:
50 | return name not in exclude_names
51 |
52 | return include
53 |
54 |
55 | class AddWeightDecayState(NamedTuple):
56 | """Stateless transformation."""
57 |
58 |
59 | def add_weight_decay(
60 | weight_decay: float,
61 | exclude_names: Optional[Sequence[Text]] = None,
62 | ) -> optax.GradientTransformation:
63 | """Add parameter scaled by `weight_decay` to the `updates`.
64 |
65 | Same as optax.add_decayed_weights but can exclude some parameters.
66 |
67 | Args:
68 | weight_decay: weight_decay coefficient.
69 | exclude_names: an optional list of names to exclude for weight_decay. ['b']
70 | by default.
71 |
72 | Returns:
73 | An (init_fn, update_fn) tuple.
74 | """
75 |
76 | def init_fn(_):
77 | return AddWeightDecayState()
78 |
79 | def update_fn(updates, state, params):
80 | include = _weight_decay_exclude(exclude_names=exclude_names)
81 |
82 | u_in, u_ex = hk.data_structures.partition(include, updates)
83 | p_in, _ = hk.data_structures.partition(include, params)
84 | u_in = jax.tree_util.tree_map(lambda g, p: g + weight_decay * p, u_in, p_in)
85 | updates = hk.data_structures.merge(u_ex, u_in)
86 | return updates, state
87 |
88 | return optax.GradientTransformation(init_fn, update_fn)
89 |
--------------------------------------------------------------------------------
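A minimal sketch of how `add_weight_decay` composes with optax, using a toy haiku-style parameter tree (the module name, shapes, and hyperparameters below are illustrative, not taken from the training configs):

```python
import jax.numpy as jnp
import optax

from tapnet.utils import optimizers

# Toy haiku-style params: the 'w' leaf is decayed, 'b' is excluded.
params = {"mlp/linear": {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))}}
grads = {"mlp/linear": {"w": jnp.ones((4, 4)), "b": jnp.ones((4,))}}

tx = optax.chain(
    optimizers.add_weight_decay(weight_decay=1e-4, exclude_names=["b"]),
    optax.sgd(learning_rate=1e-3),
)
opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)
```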
/tapnet/training/README.md:
--------------------------------------------------------------------------------
1 | # TAP training setup
2 |
3 | This directory contains a reference implementation for training and evaluating TAP-Net and TAPIR using Kubric, following our papers.
4 |
5 | ## Installation
6 |
7 | Install ffmpeg on your machine:
8 |
9 | ```sudo apt update```
10 |
11 | ```sudo apt install ffmpeg```
12 |
13 | Install OpenEXR:
14 |
15 | ```sudo apt-get install libopenexr-dev```
16 |
17 | Clone the repository:
18 |
19 | ```git clone https://github.com/deepmind/tapnet.git```
20 |
21 | Add the current path (the parent directory of where TapNet is installed)
22 | to ```PYTHONPATH```:
23 |
24 | ```export PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH```
25 |
26 | Switch to the project directory:
27 |
28 | ```cd tapnet```
29 |
30 | Install kubric as a subdirectory:
31 |
32 | ```git clone https://github.com/google-research/kubric.git```
33 |
34 | Install requirements:
35 |
36 | ```pip install -r requirements.txt```
37 |
38 | If you want to use CUDA, make sure you install the drivers and a version
39 | of JAX that's compatible with your CUDA and CUDNN versions.
40 | Refer to
41 | [the jax manual](https://github.com/jax-ml/jax#installation)
42 | to install the correct JAX version with CUDA.
43 |
44 | ## Training
45 |
46 | The configuration file is located at: ```./tapnet/configs/tapnet_config.py```.
47 |
48 | You can modify it for your needs or create your own config file following
49 | the example of ```tapnet_config.py```.
50 |
51 | To launch an experiment, run the command:
52 |
53 | ```python3 -m tapnet.training.experiment --config ./tapnet/configs/tapnet_config.py```
54 |
55 | or
56 |
57 | ```python3 -m tapnet.training.experiment --config ./tapnet/configs/tapir_config.py```
58 |
59 | ## Evaluation
60 |
61 | You can run evaluation for a particular dataset (e.g. tapvid_davis) using the command:
62 |
63 | ```bash
64 | python3 -m tapnet.training.experiment \
65 | --config=./tapnet/configs/tapir_config.py \
66 | --jaxline_mode=eval_davis_points \
67 | --config.checkpoint_dir=./tapnet/checkpoint/ \
68 | --config.experiment_kwargs.config.davis_points_path=/path/to/tapvid_davis.pkl
69 | ```
70 |
71 | Available eval datasets are listed in `supervised_point_prediction.py`.
72 |
73 | ## Inference
74 |
75 | You can run inference for a particular video (e.g. horsejump-high.mp4) using the command:
76 |
77 | ```bash
78 | python3 -m tapnet.training.experiment \
79 | --config=./tapnet/configs/tapnet_config.py \
80 | --jaxline_mode=eval_inference \
81 | --config.checkpoint_dir=./tapnet/checkpoint/ \
82 | --config.experiment_kwargs.config.inference.input_video_path=horsejump-high.mp4 \
83 | --config.experiment_kwargs.config.inference.output_video_path=result.mp4 \
84 | --config.experiment_kwargs.config.inference.resize_height=256 \
85 | --config.experiment_kwargs.config.inference.resize_width=256 \
86 | --config.experiment_kwargs.config.inference.num_points=20
87 | ```
88 |
89 | The inference only serves as an example. It will resize the video to 256x256 resolution, sample 20 random query points on the first frame, and track these points in the remaining frames.
90 |
91 | Note that this uses jaxline for model inference, which is mainly useful if you are training your own model. A more direct way to run inference can be found in the [colab and real-time demos](#tapir-demos).
92 |
--------------------------------------------------------------------------------
/tapnet/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utilities for transforming image coordinates."""
17 |
18 | from typing import Sequence
19 |
20 | import chex
21 | import numpy as np
22 |
23 |
24 | def convert_grid_coordinates(
25 | coords: chex.Array,
26 | input_grid_size: Sequence[int],
27 | output_grid_size: Sequence[int],
28 | coordinate_format: str = 'xy',
29 | ) -> chex.Array:
30 | """Convert image coordinates between image grids of different sizes.
31 |
32 | By default, it assumes that the image corners are aligned. Therefore,
33 | it adds .5 (since (0,0) is assumed to be the center of the upper-left grid
34 | cell), multiplies by the size ratio, and then subtracts .5.
35 |
36 | Args:
37 | coords: The coordinates to be converted. It is of shape [..., 2] if
38 | coordinate_format is 'xy' or [..., 3] if coordinate_format is 'tyx'.
39 | input_grid_size: The size of the image/grid that the coordinates currently
40 | are with respect to. This is a 2-tuple of the format [width, height]
41 | if coordinate_format is 'xy' or a 3-tuple of the format
42 | [num_frames, height, width] if coordinate_format is 'tyx'.
43 | output_grid_size: The size of the target image/grid that you want the
44 | coordinates to be with respect to. This is a 2-tuple of the format
45 | [width, height] if coordinate_format is 'xy' or a 3-tuple of the format
46 | [num_frames, height, width] if coordinate_format is 'tyx'.
47 | coordinate_format: Which format the coordinates are in. This can be one
48 | of 'xy' (the default) or 'tyx', which are the only formats used in this
49 | project.
50 |
51 | Returns:
52 | The transformed coordinates, of the same shape as coordinates.
53 |
54 | Raises:
55 | ValueError: if coordinates don't match the given format.
56 | """
57 | if isinstance(input_grid_size, tuple):
58 | input_grid_size = np.array(input_grid_size)
59 | if isinstance(output_grid_size, tuple):
60 | output_grid_size = np.array(output_grid_size)
61 |
62 | if coordinate_format == 'xy':
63 | if input_grid_size.shape[0] != 2 or output_grid_size.shape[0] != 2:
64 | raise ValueError(
65 | 'If coordinate_format is xy, the shapes must be length 2.')
66 | elif coordinate_format == 'tyx':
67 | if input_grid_size.shape[0] != 3 or output_grid_size.shape[0] != 3:
68 | raise ValueError(
69 | 'If coordinate_format is tyx, the shapes must be length 3.')
70 | if input_grid_size[0] != output_grid_size[0]:
71 | raise ValueError('converting frame count is not supported.')
72 | else:
73 | raise ValueError('Recognized coordinate formats are xy and tyx.')
74 |
75 | position_in_grid = coords
76 | position_in_grid = position_in_grid * output_grid_size / input_grid_size
77 |
78 | return position_in_grid
79 |
--------------------------------------------------------------------------------
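As a quick illustration of `convert_grid_coordinates`, mapping a point from a 1280x720 frame onto a 256x256 model grid just rescales each axis by the size ratio (the sizes and point below are arbitrary examples):

```python
import numpy as np

from tapnet.utils import transforms

# One (x, y) point in a 1280x720 image, mapped onto a 256x256 grid.
coords = np.array([[640.0, 360.0]])
out = transforms.convert_grid_coordinates(
    coords,
    input_grid_size=(1280, 720),   # (width, height) for 'xy' format
    output_grid_size=(256, 256),
)
print(out)  # [[128. 128.]]
```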
/tapnet/tapvid3d/annotation_generation/gcs_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utils for downloading TAPVid3d data from GCS."""
17 |
18 | import os
19 | import sys
20 | from absl import logging
21 | import requests
22 | import tqdm
23 |
24 | current_dir = os.path.dirname(os.path.abspath(__file__))
25 | top_level_dir = os.path.abspath(os.path.join(current_dir, ".."))
26 | sys.path.insert(0, top_level_dir)
27 | from tapnet.tapvid3d.splits import tapvid3d_splits # pylint: disable=g-import-not-at-top, g-bad-import-order
28 |
29 | TAPVID3D_GCS_URL = (
30 | "https://storage.googleapis.com/dm-tapnet/tapvid3d/release_files/v1.0"
31 | )
32 |
33 |
34 | def download_tapvid3d_files(
35 | output_dir: str, split: str, subset: str, debug: bool
36 | ):
37 | """Downloads files from the given split and subset."""
38 | os.makedirs(output_dir, exist_ok=True)
39 | if split == "minival":
40 | filenames_to_download = tapvid3d_splits.get_minival_files(subset)
41 | elif split == "full_eval":
42 | filenames_to_download = tapvid3d_splits.get_full_eval_files(subset)
43 | elif split == "all":
44 | filenames_to_download = tapvid3d_splits.get_all_files(subset)
45 | else:
46 | raise ValueError(f"Unknown split: {split}")
47 |
48 | logging.info("Downloading %s split", split)
49 | for filename in tqdm.tqdm(
50 | filenames_to_download, total=len(filenames_to_download)
51 | ):
52 | local_path = os.path.join(output_dir, filename)
53 | gcs_url = get_tapvid3d_gcs_urls(
54 | filenames=filename,
55 | url_postfix=subset,
56 | )
57 | logging.info("Downloading %s to %s", gcs_url, local_path)
58 | download_file_from_url(input_url=gcs_url, output_filepath=local_path)
59 | if debug:
60 | logging.info("Stopping after one video, debug run.")
61 | break
62 | logging.info("Finished downloading all examples!")
63 |
64 |
65 | def download_file_from_url(input_url: str, output_filepath: str):
66 | """Download the GCS file from the given URL to the given output path."""
67 | if os.path.exists(output_filepath):
68 | logging.info("Skipping download, file already exists: %s", output_filepath)
69 | return
70 | response = requests.get(input_url, stream=True)
71 | if response.status_code == 200:
72 | logging.info("Downloading: %s to %s", input_url, output_filepath)
73 | with open(output_filepath, "wb") as f:
74 | for chunk in response.iter_content(chunk_size=1024):
75 | f.write(chunk)
76 | logging.info("Downloaded!")
77 | else:
78 | logging.info("Download failed. HTTP Status Code: %d", response.status_code)
79 |
80 |
81 | def get_tapvid3d_gcs_urls(
82 | filenames: str | list[str], url_postfix: str = ""
83 | ) -> str | list[str]:
84 | if isinstance(filenames, str):
85 | return f"{TAPVID3D_GCS_URL}/{url_postfix}/{filenames}"
86 | else:
87 | return [
88 | f"{TAPVID3D_GCS_URL}/{url_postfix}/{filename}" for filename in filenames
89 | ]
90 |
--------------------------------------------------------------------------------
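The download helpers above can also be called directly, without going through the generation shell scripts; the output directory here is an arbitrary example path:

```python
from tapnet.tapvid3d.annotation_generation import gcs_utils

# Fetch only the minival split of the DriveTrack subset; debug=True stops
# after the first file.
gcs_utils.download_tapvid3d_files(
    output_dir="/tmp/tapvid3d/drivetrack",
    split="minival",
    subset="drivetrack",
    debug=True,
)
```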
/tapnet/tapnext/pscan.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Parallel Scan Operation."""
17 |
18 | import torch
19 |
20 |
21 | def safe_div(numerator, denominator):
22 | return torch.where(
23 | torch.abs(denominator) < 1e-5,
24 | numerator * 100000.0,
25 | numerator / denominator,
26 | )
27 |
28 |
29 | class PScan(torch.autograd.Function):
30 | """Implements a parallel scan operation.
31 |
32 | Given A is (N, T, D) and X is (N, T, D), expands A and X in-place in O(T),
33 | and O(log(T)) if not core-bounded, so that:
34 | Y[:, 0] = Y_init
35 | Y[:, t] = A[:, t] * Y[:, t-1] + X[:, t]
36 | can be computed as:
37 | Y[:, t] = A[:, t] * Y_init + X[:, t]
38 | """
39 |
40 | @classmethod
41 | def expand(cls, weights, bias):
42 | if weights.size(1) == 1:
43 | return
44 | t_even = 2 * (weights.size(1) // 2)
45 |
46 | w_pairs = weights[:, :t_even].view(weights.size(0), t_even // 2, 2, -1)
47 | b_pairs = bias[:, :t_even].view(bias.size(0), t_even // 2, 2, -1)
48 |
49 | b_pairs[:, :, 1].add_(w_pairs[:, :, 1] * b_pairs[:, :, 0])
50 | w_pairs[:, :, 1].mul_(w_pairs[:, :, 0])
51 |
52 | PScan.expand(w_pairs[:, :, 1], b_pairs[:, :, 1])
53 |
54 | b_pairs[:, 1:, 0].add_(w_pairs[:, 1:, 0] * b_pairs[:, :-1, 1])
55 | w_pairs[:, 1:, 0].mul_(w_pairs[:, :-1, 1])
56 |
57 | if t_even < weights.size(1):
58 | bias[:, -1].add_(weights[:, -1] * bias[:, -2])
59 | weights[:, -1].mul_(weights[:, -2])
60 |
61 | @classmethod
62 | def accrev(cls, tensor):
63 | if tensor.size(1) == 1:
64 | return
65 | t_even = 2 * (tensor.size(1) // 2)
66 |
67 | pairs = tensor[:, -t_even:].view(tensor.size(0), t_even // 2, 2, -1)
68 |
69 | pairs[:, :, 0].add_(pairs[:, :, 1])
70 | PScan.accrev(pairs[:, :, 0])
71 | pairs[:, :-1, 1].add_(pairs[:, 1:, 0])
72 |
73 | if t_even < tensor.size(1):
74 | tensor[:, 0].add_(tensor[:, 1])
75 |
76 | @classmethod
77 | def forward(cls, ctx, weights, bias, y_init):
78 | ctx.weights_orig = weights.clone()
79 | ctx.y_init_expanded = y_init[:, None, :].clone()
80 | ctx.weights_expanded = weights.clone()
81 | ctx.bias_expanded = bias.clone()
82 |
83 | PScan.expand(ctx.weights_expanded, ctx.bias_expanded)
84 | output = ctx.weights_expanded * ctx.y_init_expanded + ctx.bias_expanded
85 | return output
86 |
87 | @classmethod
88 | def backward(cls, ctx, grad_output):
89 | grad_input_wrt_output = grad_output * ctx.weights_expanded
90 | grad_accumulated = grad_input_wrt_output.clone()
91 |
92 | PScan.accrev(grad_accumulated)
93 |
94 | grad_weights = safe_div(ctx.y_init_expanded, ctx.weights_orig)
95 | grad_weights[:, 1:].add_(
96 | safe_div(ctx.bias_expanded[:, :-1], ctx.weights_expanded[:, 1:])
97 | )
98 |
99 | grad_bias = safe_div(grad_accumulated, ctx.weights_expanded)
100 | grad_y_init = grad_input_wrt_output.sum(dim=1)
101 |
102 | return grad_weights * grad_accumulated, grad_bias, grad_y_init
103 |
104 |
105 | pscan = PScan.apply
106 |
--------------------------------------------------------------------------------
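The recurrence computed by `PScan` can be sanity-checked against a plain sequential loop; this is a small self-contained sketch with arbitrary shapes and values:

```python
import torch

from tapnet.tapnext.pscan import pscan

N, T, D = 2, 8, 4
weights = torch.rand(N, T, D) * 0.9 + 0.05  # the A coefficients
bias = torch.randn(N, T, D)                 # the X inputs
y_init = torch.randn(N, D)

out = pscan(weights, bias, y_init)          # (N, T, D)

# Sequential reference: y_t = A_t * y_{t-1} + X_t, seeded with y_init.
y = y_init.clone()
for t in range(T):
    y = weights[:, t] * y + bias[:, t]
    assert torch.allclose(out[:, t], y, atol=1e-5)
```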
/tapnet/tapvid3d/annotation_generation/generate_adt.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Download the ADT query points and compute the remaining annotations.
17 |
18 | The *.npz files in the GCS bucket for the ADT split contain only the query
19 | points. This script loads the ADT scenes from the GCS bucket, computes the
20 | remaining annotations and saves them to *.npz files in the local output folder.
21 | """
22 |
23 | import collections
24 | import glob
25 | import os
26 | from typing import Sequence
27 |
28 | from absl import app
29 | from absl import flags
30 | from tapnet.tapvid3d.annotation_generation import adt_utils
31 | from tapnet.tapvid3d.annotation_generation import gcs_utils
32 | import tqdm
33 |
34 |
35 | _OUTPUT_DIR = flags.DEFINE_string(
36 | "output_dir",
37 | "tapvid3d_dataset/adt/",
38 | "Path to folder to store output npz files containing all fields.",
39 | )
40 |
41 | _DEBUG = flags.DEFINE_boolean(
42 | "debug",
43 | False,
44 | "Whether to run in debug mode, downloads only one video.",
45 | )
46 |
47 | _SPLIT = flags.DEFINE_enum(
48 | "split",
49 | "all",
50 | ["minival", "full_eval", "all"],
51 | """
52 | Which split to download: the minival split, the full_eval
53 | split, or all videos.
54 | """,
55 | )
56 |
57 | _ADT_BASE_PATH = flags.DEFINE_string(
58 | "adt_base_path",
59 | "",
60 | "Path to folder containing ADT scenes as subfolders.",
61 | )
62 |
63 |
64 | def generate_adt_npz(
65 | adt_base_path: str, input_npz_dir: str, output_npz_dir: str
66 | ):
67 | """Generates the final ADT npz files, adding the remaining annotations."""
68 | input_npz = sorted(glob.glob(os.path.join(input_npz_dir, "*.npz")))
69 | done_npz = [
70 | os.path.basename(x)
71 | for x in sorted(glob.glob(os.path.join(output_npz_dir, "*.npz")))
72 | ]
73 |
74 | # Filter completed files.
75 | input_npz = list(
76 | filter(lambda x: os.path.basename(x) not in done_npz, input_npz)
77 | )
78 |
79 | # Group pending files by video.
80 | pending_vid_chunks = collections.defaultdict(list)
81 | for f in input_npz:
82 | basename = os.path.basename(f)[:-4].split("_")
83 | vid = "_".join(basename[:-1])
84 | chunk = int(basename[-1])
85 | pending_vid_chunks[vid].append(chunk)
86 |
87 | for vid, chunks in tqdm.tqdm(
88 | pending_vid_chunks.items(), total=len(pending_vid_chunks)
89 | ):
90 | adt_utils.process_vid(
91 | adt_base_path,
92 | input_npz_dir,
93 | output_npz_dir,
94 | vid,
95 | chunks,
96 | )
97 |
98 |
99 | def main(argv: Sequence[str]) -> None:
100 | if len(argv) > 1:
101 | raise app.UsageError("Too many command-line arguments.")
102 | tmp_adt_dir = os.path.join(_OUTPUT_DIR.value, "tmp")
103 |
104 | # Download ADT npz's containing query points only.
105 | gcs_utils.download_tapvid3d_files(
106 | tmp_adt_dir, _SPLIT.value, "adt", _DEBUG.value
107 | )
108 |
109 | # Compute the remaining annotations and save them to *.npz files.
110 | generate_adt_npz(_ADT_BASE_PATH.value, tmp_adt_dir, _OUTPUT_DIR.value)
111 |
112 |
113 | if __name__ == "__main__":
114 | app.run(main)
115 |
--------------------------------------------------------------------------------
/tapnet/tapvid3d/annotation_generation/generate_pstudio.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Download the PStudio annotations and join with video data.
17 |
18 | The *.npz files in the GCS bucket for the PStudio split contain all annotations
19 | but not the video data. This script downloads the video data and adds it
20 | to the *.npz files.
21 | """
22 |
23 | # pylint: disable=g-import-not-at-top,g-bad-import-order
24 | import tensorflow as tf
25 |
26 | tf.config.set_visible_devices([], "GPU")
27 |
28 | import glob
29 | import os
30 | from typing import Sequence
31 | import urllib.request
32 | import zipfile
33 |
34 | from absl import app
35 | from absl import flags
36 | from tapnet.tapvid3d.annotation_generation import gcs_utils
37 | import PIL.Image
38 | import numpy as np
39 | import tqdm
40 |
41 |
42 | _OUTPUT_DIR = flags.DEFINE_string(
43 | "output_dir",
44 | "tapvid3d_dataset/pstudio/",
45 | "Path to folder to store output npz files containing all fields.",
46 | )
47 |
48 | _DEBUG = flags.DEFINE_boolean(
49 | "debug",
50 | False,
51 | "Whether to run in debug mode, downloads only one video.",
52 | )
53 |
54 | _SPLIT = flags.DEFINE_enum(
55 | "split",
56 | "all",
57 | ["minival", "full_eval", "all"],
58 | """
59 | Which split to download: the minival split, the full_eval
60 | split, or all videos.
61 | """,
62 | )
63 |
64 | _PSTUDIO_DATA = flags.DEFINE_string(
65 | "pstudio_url",
66 | "https://omnomnom.vision.rwth-aachen.de/data/Dynamic3DGaussians/data.zip",
67 | "URL of PStudio data.",
68 | )
69 |
70 |
71 | def generate_pstudio_npz(
72 | pstudio_base_path: str, input_npz_dir: str, output_npz_dir: str
73 | ):
74 | """Generates the final PStudio npz files, adding the video data."""
75 | input_npz = sorted(glob.glob(os.path.join(input_npz_dir, "*.npz")))
76 | done_npz = [
77 | os.path.basename(x)
78 | for x in sorted(glob.glob(os.path.join(output_npz_dir, "*.npz")))
79 | ]
80 |
81 | # Filter completed files.
82 | input_npz = list(
83 | filter(lambda x: os.path.basename(x) not in done_npz, input_npz)
84 | )
85 |
86 | # For each example, load the video data and add it to the npz file.
87 | for filename in tqdm.tqdm(input_npz):
88 | example = dict(np.load(filename, allow_pickle=True))
89 | out_fn = os.path.join(output_npz_dir, os.path.basename(filename))
90 | seq, cam_id = os.path.basename(filename)[:-4].split("_")
91 | # load rgb images
92 | im_fns = sorted(
93 | glob.glob(os.path.join(pstudio_base_path, f"{seq}/ims/{cam_id}/*.jpg"))
94 | )
95 | ims = [np.array(PIL.Image.open(im_fn)) for im_fn in im_fns]
96 | ims_jpeg = [np.array(tf.io.encode_jpeg(im)).item() for im in ims]
97 | example["images_jpeg_bytes"] = ims_jpeg
98 | np.savez(out_fn, **example)
99 |
100 |
101 | def main(argv: Sequence[str]) -> None:
102 | if len(argv) > 1:
103 | raise app.UsageError("Too many command-line arguments.")
104 | tmp_pstudio_dir = os.path.join(_OUTPUT_DIR.value, "tmp")
105 |
106 | gcs_utils.download_tapvid3d_files(
107 | tmp_pstudio_dir, _SPLIT.value, "pstudio", _DEBUG.value
108 | )
109 | # Download and extract PStudio video data.
110 | pstudio_zip = os.path.join(tmp_pstudio_dir, "data.zip")
111 | pstudio_data = os.path.join(tmp_pstudio_dir, "data")
112 | if not os.path.exists(pstudio_zip):
113 | print(f"Downloading PStudio data to {pstudio_zip}")
114 | urllib.request.urlretrieve(_PSTUDIO_DATA.value, pstudio_zip)
115 | else:
116 | print(f"Skipping download, PStudio data already exists: {pstudio_zip}")
117 | if not os.path.exists(pstudio_data):
118 | print(f"Extracting PStudio data to {pstudio_data}")
119 | with zipfile.ZipFile(pstudio_zip, "r") as zip_file:
120 | zip_file.extractall(tmp_pstudio_dir)
121 | else:
122 | print(f"Skipping extraction, PStudio data already exists: {pstudio_data}")
123 |
124 | # Compute the remaining annotations and save them to *.npz files.
125 | generate_pstudio_npz(pstudio_data, tmp_pstudio_dir, _OUTPUT_DIR.value)
126 |
127 |
128 | if __name__ == "__main__":
129 | app.run(main)
130 |
--------------------------------------------------------------------------------
/tapnet/utils/index_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Index utils for TAPViT-SSM."""
17 |
18 | import functools
19 |
20 | import chex
21 | import jax
22 | import jax.numpy as jnp
23 | # from kauldron.typing import Int, Bool, Float # pylint: disable=g-multiple-import,g-importing-member
24 |
25 |
26 | def scatter_inner(
27 | target: chex.Array, mask: chex.Array, timestep: chex.Array, data: chex.Array
28 | ) -> chex.Array:
29 | """Scatter data into target at timestep if mask is true.
30 |
31 | Args:
32 | target: (T, c) float target tensor
33 | mask: ([]) bool mask
34 | timestep: ([]) int timestep
35 | data: (c,) data to scatter
36 | Returns:
37 | (T, c) updated target tensor
38 | """
39 | updated_target = target.at[timestep].set(data)
40 | return jnp.where(mask, updated_target, target)
41 |
42 |
43 | @jax.vmap
44 | @functools.partial(jax.vmap, in_axes=(1, 0, 0, 0), out_axes=1)
45 | def scatter(
46 | target: chex.Array, mask: chex.Array, timestep: chex.Array, data: chex.Array
47 | ) -> chex.Array:
48 | """Scatter data into target at timestep if mask is true.
49 |
50 | (dimensions that are added via vmap are put into square brackets [])
51 |
52 | Args:
53 | target: ([B], T, [Q], c) float target tensor
54 | mask: ([B, Q]) bool mask
55 | timestep: ([B, Q]) int timestep
56 | data: ([B, Q], c,) data to scatter
57 | Returns:
58 | ([B], T, [Q], c) updated target tensor
59 | """
60 | return scatter_inner(target, mask, timestep, data)
61 |
62 |
63 | @jax.vmap
64 | @functools.partial(jax.vmap, in_axes=(1, None, None, 0), out_axes=1)
65 | def scatter2(
66 | target: chex.Array, mask: chex.Array, timestep: chex.Array, data: chex.Array
67 | ) -> chex.Array:
68 | """Scatter data into target at timestep if mask is true.
69 |
70 | (dimensions that are added via vmap are put into square brackets [])
71 |
72 | Args:
73 | target: ([B], T, [N], c) float target tensor
74 | mask: ([B]) bool mask
75 | timestep: ([B]) int timestep
76 | data: ([B, N], c,) data to scatter
77 | Returns:
78 | ([B], T, [N], c) updated target tensor
79 | """
80 | return scatter_inner(target, mask, timestep, data)
81 |
82 |
83 | @jax.vmap
84 | @functools.partial(jax.vmap, in_axes=(1, 0, 0, 0), out_axes=1)
85 | def scatter_prefix(
86 | target: chex.Array, mask: chex.Array, timestep: chex.Array, data: chex.Array
87 | ) -> chex.Array:
88 | """Scatter data into target before timestep if mask is true.
89 |
90 | Equivalent to
91 |
92 | updated_target = target.at[:timestep].set(data)
93 | return jnp.where(mask, updated_target, target)
94 |
95 | but works in a static way.
96 |
97 | (dimensions that are added via vmap are put into square brackets [])
98 |
99 | Args:
100 | target: ([B], T, [Q], c) float target tensor
101 | mask: ([B, Q]) bool mask
102 | timestep: ([B, Q]) int timestep
103 | data: ([B, Q], c,) data to scatter
104 | Returns:
105 | ([B], T, [Q], c) updated target tensor
106 | """
107 | cond = (jnp.arange(target.shape[0]) < timestep) & mask
108 | return jnp.where(
109 | jnp.tile(cond[:, None], (1, target.shape[1])),
110 | jnp.tile(data, (target.shape[0], 1)),
111 | target,
112 | )
113 |
114 |
115 | @jax.vmap
116 | @functools.partial(jax.vmap, in_axes=(1, 0, 0, 0), out_axes=1)
117 | def scatter_suffix(
118 | target: chex.Array, mask: chex.Array, timestep: chex.Array, data: chex.Array
119 | ) -> chex.Array:
120 | """Scatter data into target before timestep if mask is true.
121 |
122 | Equivalent to
123 |
124 | updated_target = target.at[timestep:].set(data)
125 | return jnp.where(mask, updated_target, target)
126 |
127 | but works in a static way.
128 |
129 | (dimensions that are added via vmap are put into square brackets [])
130 |
131 | Args:
132 | target: ([B], T, [Q], c) float target tensor
133 | mask: ([B, Q]) bool mask
134 | timestep: ([B, Q]) int timestep
135 | data: ([B, Q], c,) data to scatter
136 | Returns:
137 | ([B], T, [Q], c) updated target tensor
138 | """
139 | cond = (jnp.arange(target.shape[0]) >= timestep) & mask
140 | return jnp.where(
141 | jnp.tile(cond[:, None], (1, target.shape[1])),
142 | jnp.tile(data, (target.shape[0], 1)),
143 | target,
144 | )
145 |
--------------------------------------------------------------------------------
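A shape-level sketch of the scatter helpers above (batch, frame, and query sizes are arbitrary):

```python
import jax.numpy as jnp

from tapnet.utils import index_utils

B, T, Q, C = 2, 6, 3, 4
target = jnp.zeros((B, T, Q, C))
mask = jnp.ones((B, Q), dtype=bool)
timestep = jnp.full((B, Q), 2, dtype=jnp.int32)
data = jnp.ones((B, Q, C))

# Writes `data` into frame 2 of every query track.
out = index_utils.scatter(target, mask, timestep, data)           # (B, T, Q, C)

# Writes `data` into all frames strictly before / from frame 2 onwards.
prefix = index_utils.scatter_prefix(target, mask, timestep, data)
suffix = index_utils.scatter_suffix(target, mask, timestep, data)
```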
/tapnet/tapnext/torch_losses.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Losses for TAPNext."""
17 |
18 | import einops
19 | import torch
20 | from torch.nn import functional as F
21 |
22 |
23 | def huber_coordinate_loss(
24 | pred_points, target_points, mask, delta=1.0, pixel_size=256
25 | ):
26 | """Computes the Huber loss between predicted and target coordinates.
27 |
28 | Args:
29 | pred_points (*shape, 2): point coordinates predicted by the model
30 | target_points (*shape, 2): target point coordinates
31 | mask (*shape): visibility mask
32 | delta (float): the threshold of the Huber loss
33 | pixel_size (int): pixel size of the image
34 |
35 | Returns:
36 |     Scalar Huber loss averaged over visible points.
37 | """
38 | pred_points = pred_points.float()
39 | target_points = target_points.float()
40 | target_points = target_points.clip(0, pixel_size - 1)
41 | error = pred_points - target_points
42 | error = error.clip(-1e8, 1e8) # add magnitude bound to prevent nan
43 | distsqr = torch.sum(torch.square(error), dim=-1, keepdims=True)
44 | dist = torch.sqrt(distsqr + 1e-12)
45 | loss = torch.where(
46 | dist < delta,
47 | distsqr / 2,
48 | delta * (torch.abs(dist) - delta / 2),
49 | )
50 | mask = mask.float()
51 | loss = (loss * mask).sum() / mask.sum()
52 | return loss
53 |
54 |
55 | def coordinate_softmax(logits, labels, mask, pixel_size=256):
56 | """Computes the softmax loss between predicted logits and target coordinates.
57 |
58 | Args:
59 | logits (*shape, n_bins x 2): marginal softmax logits for predicting x and y
60 | coordinates
61 |     labels (*shape, 2): target coordinates
62 |     mask (*shape, 1): visibility mask
63 | pixel_size (int): pixel size of the image
64 |
65 | Returns:
66 | loss (float): the softmax loss
67 | """
68 | logits = logits.float()
69 | labels = labels.float()
70 | labels -= 0.5
71 | labels = labels.clip(0, pixel_size - 1)
72 | labels = torch.round(labels).long()
73 | logits = einops.rearrange(logits, 'b ... c -> b c ...')
74 | labels = einops.rearrange(labels, 'b ... c -> b c ...')
75 | logits_x, logits_y = logits.chunk(2, dim=1)
76 | labels_x, labels_y = labels.chunk(2, dim=1)
77 |   # Use unreduced losses so the visibility mask is applied per point below.
78 |   loss_x = F.cross_entropy(logits_x, labels_x.squeeze(1), reduction='none')
79 |   loss_y = F.cross_entropy(logits_y, labels_y.squeeze(1), reduction='none')
80 |   loss = (loss_x + loss_y)[..., None]
81 |   mask = mask.float()
82 |   loss = (loss * mask).sum() / mask.sum()
83 | return loss
84 |
85 |
86 | def tapnext_loss_and_grad(model, batch, loss_weight=1.0):
87 | """Computes the TAPNext loss and performs backward pass on the model.
88 |
89 | Use the init arg `use_checkpointing=True` when constructing TAPNext to
90 |   reduce memory usage; it has no impact on inference speed or quality.
91 |
92 | Args:
93 | model (TAPNext):
94 |     batch (dict): a dictionary with 4 keys:
95 |       * 'video': a float32 tensor of shape [batch, time, height, width, 3];
96 |         it should be mean-std normalized (e.g. ImageNet normalization)
97 |       * 'query_points': a float32 tensor of shape [batch, num_queries, 3];
98 |         queries are (t, x, y), with t in [0, time], x in [0, width], y in [0, height]
99 |       * 'target_points': a float32 tensor of shape [batch, num_queries, time, 2];
100 |         target points of the form (y, x), same ranges as query points
101 |       * 'visible': a float32 tensor of shape [batch, num_queries, time, 1];
102 |         visibility flags (1. is visible and 0. is not visible)
103 | loss_weight (float): weight of the loss (default: 1.0)
104 |
105 | Returns:
106 | loss (float): the total loss
107 | """
108 | pred_tracks, track_logits, visible_logits, _ = model(
109 | video=batch['video'], query_points=batch['query_points']
110 | )
111 | huber_loss = huber_coordinate_loss(
112 | pred_tracks,
113 | batch['target_points'].transpose(1, 2).flip(-1),
114 | batch['visible'].transpose(1, 2),
115 | )
116 | softmax_loss = coordinate_softmax(
117 | track_logits,
118 | batch['target_points'].transpose(1, 2).flip(-1),
119 | batch['visible'].transpose(1, 2),
120 | )
121 | coordinate_loss = 0.1 * huber_loss + 1.0 * softmax_loss
122 | visibility_loss = F.binary_cross_entropy_with_logits(
123 | visible_logits, batch['visible'].transpose(1, 2)
124 | )
125 | loss = coordinate_loss + visibility_loss
126 | (loss * loss_weight).backward()
127 | return loss.item()
128 |
--------------------------------------------------------------------------------
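Note: as a rough illustration of the expected batch layout, here is an untested sketch that builds a random toy batch matching the docstring above and runs one backward pass. The sizes, the random data, and the CPU placement are ours; a real setup would load a TAPNext checkpoint and move everything to GPU.

import torch
from tapnet.tapnext.tapnext_torch import TAPNext
from tapnet.tapnext import torch_losses

model = TAPNext(image_size=(256, 256))  # randomly initialized, for shape-checking only

B, T, N = 1, 8, 16  # toy sizes
batch = {
    # mean-std normalized video, [batch, time, height, width, 3]
    'video': torch.randn(B, T, 256, 256, 3),
    # queries (t, x, y), [batch, num_queries, 3]
    'query_points': torch.rand(B, N, 3) * torch.tensor([T - 1, 255.0, 255.0]),
    # targets (y, x), [batch, num_queries, time, 2]
    'target_points': torch.rand(B, N, T, 2) * 255.0,
    # visibility flags, [batch, num_queries, time, 1]
    'visible': torch.ones(B, N, T, 1),
}
loss = torch_losses.tapnext_loss_and_grad(model, batch)
print(loss)  # gradients are now accumulated on model.parameters()
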
/tapnet/utils/ssm_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Util functions for processing videos with SSMs."""
17 |
18 | import enum
19 | import chex
20 | import einops
21 | import flax.linen as nn
22 | import jax
23 | from jax import numpy as jnp
24 | import numpy as np
25 | from recurrentgemma._src.jax import scan
26 |
27 | from tapnet.utils import index_utils
28 |
29 |
30 | class CodeOrigin(enum.StrEnum):
31 | XLM = "xlm"
32 | THIRD_PARTY = "third_party"
33 |
34 |
35 | # @typechecked
36 | def transpose_flatten(
37 | x: chex.Array, like_shape: tuple[int, int, int, int], original_shape: str
38 | ) -> chex.Array:
39 | shape = dict(zip("btnc", like_shape, strict=True))
40 | return einops.rearrange(x, original_shape + "-> (b n) t c", **shape)
41 |
42 |
43 | # @typechecked
44 | def unflatten_untranspose(
45 | x: chex.Array, like_shape: tuple[int, int, int, int], original_shape: str
46 | ) -> chex.Array:
47 | shape = dict(zip("btnc", like_shape, strict=True))
48 | return einops.rearrange(x, "(b n) t c ->" + original_shape, **shape)
49 |
50 |
51 | def get_sharding_spec() -> scan.ShardingSpec:
52 | """Returns the sharding spec for the Pallas kernel."""
53 | devices = np.asarray(jax.devices())
54 | grid = devices.reshape(-1, len(devices))
55 |
56 | # The axis names used in Kauldron are 'i' and 'j'
57 | mesh = jax.sharding.Mesh(grid, axis_names=("i", "j"))
58 | # It looks like `j` is the data axis
59 | scan_sharding_spec = scan.ShardingSpec(
60 | mesh=mesh,
61 | batch_axis_name="j",
62 | activations_axis_name="i",
63 | )
64 | return scan_sharding_spec
65 |
66 |
67 | class TokenSubsampling(nn.Module):
68 | """Drops video tubes."""
69 |
70 | # Hparams.
71 | drop_ratio: float
72 | drop_ratio_test: float = 0.0
73 | shuffle_tokens: bool = True
74 |
75 | # only true is supported for now
76 | mask_temporal_tokens: bool = True
77 |
78 | is_training: bool = False
79 |
80 | @nn.compact
81 | # @typechecked
82 | def __call__(
83 | self,
84 | tokens: chex.Array, # Float["*B T N D"],
85 | mask_token: chex.Array, # Float["*B T N D"],
86 | override_drop_ratio: float | None = None,
87 | ) -> tuple[chex.Array, chex.Array]: # Float["*B T N D"], Bool["*B T"]]:
88 | """Drops tokens randomly."""
89 |     *batch_shape, seq_len, num_tokens, _ = tokens.shape
90 |
91 | # By default tokens are only dropped for training.
92 | if override_drop_ratio is not None:
93 | drop_ratio = override_drop_ratio
94 | elif self.is_training:
95 | drop_ratio = self.drop_ratio
96 | else:
97 | drop_ratio = self.drop_ratio_test
98 | if drop_ratio == 0.0:
99 | return tokens, jnp.ones(tokens.shape[:2], dtype=jnp.bool_)
100 |
101 | # Drop tokens randomly.
102 | if self.mask_temporal_tokens:
103 | n_tokens = int(seq_len) - 1
104 | else:
105 | n_tokens = int(num_tokens)
106 |     if len(batch_shape) != 1:
107 | raise NotImplementedError("*B is not supported yet!")
108 |     (n_batch,) = batch_shape
109 |     rng = self.make_rng("degradation")
110 |     num_vis_patches = tokens.shape[2]
111 | subkey, _ = jax.random.split(rng, 2)
112 |
113 | masked_tokens = tokens
114 |
115 | subsample_size = jax.random.choice(subkey, n_tokens - 1, shape=(n_batch,))
116 | subsample_size += 1
117 |     # subsample_size is a random integer between 1 and n_tokens - 1 (inclusive)
118 |
119 | # mask_tokens - B T N D
120 | mask = jnp.ones(
121 | (n_batch, num_vis_patches), dtype=jnp.bool_)
122 | indices = jnp.tile(
123 | subsample_size[:, None], (1, num_vis_patches)
124 | )
125 | # TODO(zholus): this works because we don't have temporal positional
126 | # embeddings (i.e. at all temporal positions masked tokens are the same).
127 |     # When we have positional embeddings, we need to dynamically choose
128 | # the mask token for each position according to `subsample_size`.
129 | scatter_data = mask_token[:, 0]
130 | # scatter_data.shape == [B, N, D]
131 | masked_tokens = index_utils.scatter_suffix(
132 | masked_tokens, mask, indices, scatter_data
133 | )
134 | masked_positions = jnp.zeros(
135 | (n_batch, n_tokens + 1, 1, 1), dtype=jnp.bool_)
136 | mask = jnp.ones(
137 | (n_batch, 1), dtype=jnp.bool_)
138 | masked_positions = index_utils.scatter_suffix(
139 | masked_positions, mask, subsample_size[:, None],
140 | jnp.ones((n_batch, 1, 1), dtype=jnp.bool_)
141 | )[..., 0, 0]
142 | return masked_tokens, masked_positions
143 |
--------------------------------------------------------------------------------
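Note: a small round-trip sketch for the two rearrange helpers above (the toy shapes are ours). `transpose_flatten` folds the query axis into the batch axis so a sequence model can scan over time, and `unflatten_untranspose` undoes it:

import jax.numpy as jnp
from tapnet.utils import ssm_utils

x = jnp.arange(2 * 4 * 3 * 5, dtype=jnp.float32).reshape(2, 4, 3, 5)  # (b, t, n, c)
flat = ssm_utils.transpose_flatten(x, x.shape, 'b t n c')
print(flat.shape)  # (6, 4, 5) == ((b n), t, c)

restored = ssm_utils.unflatten_untranspose(flat, x.shape, 'b t n c')
print(bool(jnp.all(restored == x)))  # True
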
/colabs/tapir_clustering.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "id": "MWPOsk-I8o69"
8 | },
9 | "outputs": [],
10 | "source": [
11 | "# @title Install code and dependencies {form-width: \"25%\"}\n",
12 | "!pip install git+https://github.com/google-deepmind/tapnet.git"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "metadata": {
19 | "id": "dNWBx_DOHSSt"
20 | },
21 | "outputs": [],
22 | "source": [
23 | "# @title Download Model {form-width: \"25%\"}\n",
24 | "\n",
25 | "%mkdir tapnet/checkpoints\n",
26 | "\n",
27 | "!wget -P tapnet/checkpoints https://storage.googleapis.com/dm-tapnet/causal_tapir_checkpoint.npy\n",
28 | "\n",
29 | "%ls tapnet/checkpoints\n",
30 | "\n",
31 | "checkpoint_path = 'tapnet/checkpoints/causal_tapir_checkpoint.npy'"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "metadata": {
38 | "id": "jtTNXUNCHVAL"
39 | },
40 | "outputs": [],
41 | "source": [
42 | "# @title Imports {form-width: \"25%\"}\n",
43 | "%matplotlib widget\n",
44 | "import functools\n",
45 | "\n",
46 | "import haiku as hk\n",
47 | "import jax\n",
48 | "import jax.numpy as jnp\n",
49 | "import matplotlib.pyplot as plt\n",
50 | "import mediapy as media\n",
51 | "import numpy as np\n",
52 | "from tqdm import tqdm\n",
53 | "import tree\n",
54 | "\n",
55 | "from tapnet.robotap import tapir_clustering\n",
56 | "from tapnet.utils import transforms\n",
57 | "from tapnet.utils import viz_utils"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "metadata": {
64 | "id": "0J9kVfSuHmqS"
65 | },
66 | "outputs": [],
67 | "source": [
68 | "# @title Load an Exemplar Video {form-width: \"25%\"}\n",
69 | "\n",
70 | "%mkdir tapnet/examplar_videos\n",
71 | "\n",
72 | "!wget -P tapnet/examplar_videos https://storage.googleapis.com/dm-tapnet/robotap/for_clustering.mp4\n",
73 | "\n",
74 | "video = media.read_video('tapnet/examplar_videos/for_clustering.mp4')\n",
75 | "height, width = video.shape[1:3]\n",
76 | "media.show_video(video[::5], fps=10)"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "metadata": {
83 | "id": "7Vjhi4PdJ2W-"
84 | },
85 | "outputs": [],
86 | "source": [
87 | "# @title Run TAPIR to extract point tracks {form-width: \"25%\"}\n",
88 | "\n",
89 | "demo_videos = {\"dummy_id\":video}\n",
90 | "demo_episode_ids = list(demo_videos.keys())\n",
91 | "track_dict = tapir_clustering.track_many_points(\n",
92 | " demo_videos,\n",
93 | " demo_episode_ids,\n",
94 | " checkpoint_path,\n",
95 | " point_batch_size=1024,\n",
96 | " points_per_frame=10,\n",
97 | ")"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {
104 | "id": "kU2yqJVTPgg-"
105 | },
106 | "outputs": [],
107 | "source": [
108 | "# @title Run the clustering {form-width: \"25%\"}\n",
109 | "\n",
110 | "clustered = tapir_clustering.compute_clusters(\n",
111 | " track_dict['separation_tracks'],\n",
112 | " track_dict['separation_visibility'],\n",
113 | " track_dict['demo_episode_ids'],\n",
114 | " track_dict['video_shape'],\n",
115 | " track_dict['query_features'],\n",
116 | " max_num_cats=12,\n",
117 | " final_num_cats=7,\n",
118 | ")"
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": null,
124 | "metadata": {
125 | "id": "FCNCAeLVQ0r2"
126 | },
127 | "outputs": [],
128 | "source": [
129 | "# @title Display the inferred clusters {form-width: \"25%\"}\n",
130 | "\n",
131 | "separation_visibility_trim = clustered['separation_visibility']\n",
132 | "separation_tracks_trim = clustered['separation_tracks']\n",
133 | "\n",
134 | "pointtrack_video = viz_utils.plot_tracks_v2(\n",
135 | " (demo_videos[demo_episode_ids[0]]).astype(np.uint8),\n",
136 | " separation_tracks_trim[demo_episode_ids[0]],\n",
137 | " 1.0-separation_visibility_trim[demo_episode_ids[0]],\n",
138 | " trackgroup=clustered['classes']\n",
139 | ")\n",
140 | "media.show_video(pointtrack_video, fps=20)"
141 | ]
142 | }
143 | ],
144 | "metadata": {
145 | "colab": {
146 | "provenance": []
147 | },
148 | "kernelspec": {
149 | "display_name": "Python 3",
150 | "name": "python3"
151 | },
152 | "language_info": {
153 | "name": "python"
154 | }
155 | },
156 | "nbformat": 4,
157 | "nbformat_minor": 0
158 | }
159 |
--------------------------------------------------------------------------------
/configs/tapnet_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Default config to train the TapNet."""
17 |
18 | from jaxline import base_config
19 | from ml_collections import config_dict
20 |
21 |
22 | # We define the experiment launch config in the same file as the experiment to
23 | # keep things self-contained in a single file.
24 | def get_config() -> config_dict.ConfigDict:
25 | """Return config object for training."""
26 | config = base_config.get_base_config()
27 |
28 | # Experiment config.
29 | config.training_steps = 100000
30 |
31 | # NOTE: duplicates not allowed.
32 | config.shared_module_names = ('tapnet_model',)
33 |
34 | config.dataset_names = ('kubric',)
35 | # Note: eval modes must always start with 'eval_'.
36 | config.eval_modes = (
37 | 'eval_davis_points',
38 | 'eval_jhmdb',
39 | 'eval_robotics_points',
40 | 'eval_kinetics_points',
41 | )
42 | config.checkpoint_dir = '/tmp/tapnet_training/'
43 | config.evaluate_every = 10000
44 |
45 | config.experiment_kwargs = config_dict.ConfigDict(
46 | dict(
47 | config=dict(
48 | sweep_name='default_sweep',
49 | save_final_checkpoint_as_npy=True,
50 | # `enable_double_transpose` should always be false when using 1D.
51 |             # For other D it is also completely untested and very unlikely
52 | # to work.
53 | optimizer=dict(
54 | base_lr=2e-3,
55 | max_norm=-1, # < 0 to turn off.
56 | weight_decay=1e-2,
57 | schedule_type='cosine',
58 | cosine_decay_kwargs=dict(
59 | init_value=0.0,
60 | warmup_steps=5000,
61 | end_value=0.0,
62 | ),
63 | optimizer='adam',
64 | # Optimizer-specific kwargs.
65 | adam_kwargs=dict(
66 | b1=0.9,
67 | b2=0.95,
68 | eps=1e-8,
69 | ),
70 | ),
71 | fast_variables=tuple(),
72 | shared_modules=dict(
73 | shared_module_names=config.get_oneway_ref(
74 | 'shared_module_names',
75 | ),
76 | tapnet_model_kwargs=dict(),
77 | ),
78 | datasets=dict(
79 | dataset_names=config.get_oneway_ref('dataset_names'),
80 | kubric_kwargs=dict(
81 | batch_dims=8,
82 | shuffle_buffer_size=128,
83 | train_size=(256, 256),
84 | ),
85 | ),
86 | supervised_point_prediction_kwargs=dict(
87 | prediction_algo='cost_volume_regressor',
88 | ),
89 | checkpoint_dir=config.get_oneway_ref('checkpoint_dir'),
90 | evaluate_every=config.get_oneway_ref('evaluate_every'),
91 | eval_modes=config.get_oneway_ref('eval_modes'),
92 | # If true, run evaluate() on the experiment once before
93 | # you load a checkpoint.
94 | # This is useful for getting initial values of metrics
95 | # at random weights, or when debugging locally if you
96 | # do not have any train job running.
97 | davis_points_path='',
98 | jhmdb_path='',
99 | robotics_points_path='',
100 | training=dict(
101 | # Note: to sweep n_training_steps, DO NOT sweep these
102 | # fields directly. Instead sweep config.training_steps.
103 | # Otherwise, decay/stopping logic
104 | # is not guaranteed to be consistent.
105 | n_training_steps=config.get_oneway_ref('training_steps'),
106 | ),
107 | inference=dict(
108 | input_video_path='',
109 | output_video_path='',
110 | resize_height=256, # video height resized to before inference
111 | resize_width=256, # video width resized to before inference
112 | num_points=20, # number of random points to sample
113 | ),
114 | )
115 | )
116 | )
117 |
118 | # Set up where to store the resulting model.
119 | config.train_checkpoint_all_hosts = False
120 | config.save_checkpoint_interval = 10
121 | config.eval_initial_weights = True
122 |
123 | # Prevents accidentally setting keys that aren't recognized (e.g. in tests).
124 | config.lock()
125 |
126 | return config
127 |
--------------------------------------------------------------------------------
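Note: for reference, a minimal sketch of loading this config and overriding a couple of values (the import path assumes the repository root is on PYTHONPATH; adjust to your install):

from configs import tapnet_config

config = tapnet_config.get_config()
exp_config = config.experiment_kwargs.config
print(exp_config.optimizer.base_lr)  # 0.002
print(config.eval_modes)             # ('eval_davis_points', 'eval_jhmdb', ...)

# lock() only prevents adding unknown keys; existing values can still be overridden.
config.training_steps = 200_000
exp_config.optimizer.base_lr = 1e-3
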
/configs/tapir_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Default config to train the TAPIR."""
17 |
18 | from jaxline import base_config
19 | from ml_collections import config_dict
20 |
21 |
22 | # We define the experiment launch config in the same file as the experiment to
23 | # keep things self-contained in a single file.
24 | def get_config() -> config_dict.ConfigDict:
25 | """Return config object for training."""
26 | config = base_config.get_base_config()
27 |
28 | # Experiment config.
29 | config.training_steps = 100000
30 |
31 | # NOTE: duplicates not allowed.
32 | config.shared_module_names = ('tapir_model',)
33 |
34 | config.dataset_names = ('kubric',)
35 | # Note: eval modes must always start with 'eval_'.
36 | config.eval_modes = (
37 | 'eval_davis_points',
38 | 'eval_jhmdb',
39 | 'eval_robotics_points',
40 | 'eval_kinetics_points',
41 | )
42 | config.checkpoint_dir = '/tmp/tapnet_training/'
43 | config.evaluate_every = 10000
44 |
45 | config.experiment_kwargs = config_dict.ConfigDict(
46 | dict(
47 | config=dict(
48 | sweep_name='default_sweep',
49 | save_final_checkpoint_as_npy=True,
50 | # `enable_double_transpose` should always be false when using 1D.
51 |             # For other D it is also completely untested and very unlikely
52 | # to work.
53 | optimizer=dict(
54 | base_lr=1e-3,
55 | max_norm=-1, # < 0 to turn off.
56 | weight_decay=1e-1,
57 | schedule_type='cosine',
58 | cosine_decay_kwargs=dict(
59 | init_value=0.0,
60 | warmup_steps=1000,
61 | end_value=0.0,
62 | ),
63 | optimizer='adam',
64 | # Optimizer-specific kwargs.
65 | adam_kwargs=dict(
66 | b1=0.9,
67 | b2=0.95,
68 | eps=1e-8,
69 | ),
70 | ),
71 | fast_variables=tuple(),
72 | shared_modules=dict(
73 | shared_module_names=config.get_oneway_ref(
74 | 'shared_module_names',
75 | ),
76 | tapir_model_kwargs=dict(
77 | bilinear_interp_with_depthwise_conv=False,
78 | pyramid_level=0,
79 | use_causal_conv=False,
80 | initial_resolution=(256, 256),
81 | ),
82 | ),
83 | datasets=dict(
84 | dataset_names=config.get_oneway_ref('dataset_names'),
85 | kubric_kwargs=dict(
86 | batch_dims=8,
87 | shuffle_buffer_size=128,
88 | train_size=(256, 256),
89 | ),
90 | ),
91 | supervised_point_prediction_kwargs=dict(
92 | prediction_algo='cost_volume_regressor',
93 | model_key='tapir_model',
94 | ),
95 | checkpoint_dir=config.get_oneway_ref('checkpoint_dir'),
96 | evaluate_every=config.get_oneway_ref('evaluate_every'),
97 | eval_modes=config.get_oneway_ref('eval_modes'),
98 | # If true, run evaluate() on the experiment once before
99 | # you load a checkpoint.
100 | # This is useful for getting initial values of metrics
101 | # at random weights, or when debugging locally if you
102 | # do not have any train job running.
103 | davis_points_path='',
104 | jhmdb_path='',
105 | robotics_points_path='',
106 | training=dict(
107 | # Note: to sweep n_training_steps, DO NOT sweep these
108 | # fields directly. Instead sweep config.training_steps.
109 | # Otherwise, decay/stopping logic
110 | # is not guaranteed to be consistent.
111 | n_training_steps=config.get_oneway_ref('training_steps'),
112 | ),
113 | inference=dict(
114 | input_video_path='',
115 | output_video_path='',
116 | resize_height=256, # video height resized to before inference
117 | resize_width=256, # video width resized to before inference
118 | num_points=20, # number of random points to sample
119 | ),
120 | )
121 | )
122 | )
123 |
124 | # Set up where to store the resulting model.
125 | config.train_checkpoint_all_hosts = False
126 | config.save_checkpoint_interval = 10
127 | config.eval_initial_weights = True
128 |
129 | # Prevents accidentally setting keys that aren't recognized (e.g. in tests).
130 | config.lock()
131 |
132 | return config
133 |
--------------------------------------------------------------------------------
/configs/causal_tapir_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Default config to train the TAPIR."""
17 |
18 | from jaxline import base_config
19 | from ml_collections import config_dict
20 |
21 |
22 | # We define the experiment launch config in the same file as the experiment to
23 | # keep things self-contained in a single file.
24 | def get_config() -> config_dict.ConfigDict:
25 | """Return config object for training."""
26 | config = base_config.get_base_config()
27 |
28 | # Experiment config.
29 | config.training_steps = 100000
30 |
31 | # NOTE: duplicates not allowed.
32 | config.shared_module_names = ('tapir_model',)
33 |
34 | config.dataset_names = ('kubric',)
35 | # Note: eval modes must always start with 'eval_'.
36 | config.eval_modes = (
37 | 'eval_davis_points',
38 | 'eval_jhmdb',
39 | 'eval_robotics_points',
40 | 'eval_kinetics_points',
41 | )
42 | config.checkpoint_dir = '/tmp/tapnet_training/'
43 | config.evaluate_every = 10000
44 |
45 | config.experiment_kwargs = config_dict.ConfigDict(
46 | dict(
47 | config=dict(
48 | sweep_name='default_sweep',
49 | save_final_checkpoint_as_npy=True,
50 | # `enable_double_transpose` should always be false when using 1D.
51 |             # For other D it is also completely untested and very unlikely
52 | # to work.
53 | optimizer=dict(
54 | base_lr=1e-3,
55 | max_norm=-1, # < 0 to turn off.
56 | weight_decay=1e-1,
57 | schedule_type='cosine',
58 | cosine_decay_kwargs=dict(
59 | init_value=0.0,
60 | warmup_steps=1000,
61 | end_value=0.0,
62 | ),
63 | optimizer='adam',
64 | # Optimizer-specific kwargs.
65 | adam_kwargs=dict(
66 | b1=0.9,
67 | b2=0.95,
68 | eps=1e-8,
69 | ),
70 | ),
71 | fast_variables=tuple(),
72 | shared_modules=dict(
73 | shared_module_names=config.get_oneway_ref(
74 | 'shared_module_names',
75 | ),
76 | tapir_model_kwargs=dict(
77 | bilinear_interp_with_depthwise_conv=False,
78 | pyramid_level=1,
79 | use_causal_conv=True,
80 | initial_resolution=(256, 256),
81 | ),
82 | ),
83 | datasets=dict(
84 | dataset_names=config.get_oneway_ref('dataset_names'),
85 | kubric_kwargs=dict(
86 | batch_dims=8,
87 | shuffle_buffer_size=128,
88 | train_size=(256, 256),
89 | ),
90 | ),
91 | supervised_point_prediction_kwargs=dict(
92 | prediction_algo='cost_volume_regressor',
93 | model_key='tapir_model',
94 | ),
95 | checkpoint_dir=config.get_oneway_ref('checkpoint_dir'),
96 | evaluate_every=config.get_oneway_ref('evaluate_every'),
97 | eval_modes=config.get_oneway_ref('eval_modes'),
98 | # If true, run evaluate() on the experiment once before
99 | # you load a checkpoint.
100 | # This is useful for getting initial values of metrics
101 | # at random weights, or when debugging locally if you
102 | # do not have any train job running.
103 | davis_points_path='',
104 | jhmdb_path='',
105 | robotics_points_path='',
106 | training=dict(
107 | # Note: to sweep n_training_steps, DO NOT sweep these
108 | # fields directly. Instead sweep config.training_steps.
109 | # Otherwise, decay/stopping logic
110 | # is not guaranteed to be consistent.
111 | n_training_steps=config.get_oneway_ref('training_steps'),
112 | ),
113 | inference=dict(
114 | input_video_path='',
115 | output_video_path='',
116 | resize_height=256, # video height resized to before inference
117 | resize_width=256, # video width resized to before inference
118 | num_points=20, # number of random points to sample
119 | ),
120 | )
121 | )
122 | )
123 |
124 | # Set up where to store the resulting model.
125 | config.train_checkpoint_all_hosts = False
126 | config.save_checkpoint_interval = 10
127 | config.eval_initial_weights = True
128 |
129 | # Prevents accidentally setting keys that aren't recognized (e.g. in tests).
130 | config.lock()
131 |
132 | return config
133 |
--------------------------------------------------------------------------------
/tapnet/tapnext/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Losses for TAPNext."""
17 |
18 | import dataclasses
19 |
20 | import jax
21 | import jax.numpy as jnp
22 | from kauldron import kontext
23 | from kauldron.losses import base
24 | from kauldron.typing import Bool, Float, typechecked # pylint: disable=g-multiple-import,g-importing-member
25 | import optax
26 |
27 |
28 | @dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
29 | class Huber(base.Loss):
30 | """Huber loss for point track prediction."""
31 |
32 | delta: float = 1.0
33 |
34 | pred_points: kontext.Key = kontext.REQUIRED
35 | target_points: kontext.Key = kontext.REQUIRED
36 | normalize_by: str = "values"
37 |
38 | @typechecked
39 | def get_values(
40 | self,
41 | pred_points: Float["*a 2"],
42 | target_points: Float["*a 2"],
43 | ) -> Float["*a 1"]:
44 | pred_points = jnp.astype(pred_points, jnp.float32)
45 | target_points = jnp.astype(target_points, jnp.float32)
46 | target_points = jnp.clip(target_points, 0, 255)
47 | error = pred_points - target_points
48 | error = jnp.clip(error, -1e8, 1e8) # add magnitude bound to prevent nan
49 | distsqr = jnp.sum(jnp.square(error), axis=-1, keepdims=True)
50 | dist = jnp.sqrt(distsqr + 1e-12) # add eps to prevent nan
51 | loss = jnp.where(
52 | dist < self.delta,
53 | distsqr / 2,
54 | self.delta * (jnp.abs(dist) - self.delta / 2),
55 | )
56 | return loss
57 |
58 |
59 | @dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
60 | class MaskedL1(base.Loss):
61 | """Masked L1 loss for predicting random image patches."""
62 | pred_patches: kontext.Key = kontext.REQUIRED
63 | target_patches: kontext.Key = kontext.REQUIRED
64 | temporal_mask: kontext.Key = kontext.REQUIRED
65 | normalize_by: str = "values"
66 | image_norm: str = "sum" # "sum" or "mean"
67 |
68 | @typechecked
69 | def get_values(
70 | self,
71 | pred_patches: Float["*B T h w C"],
72 | target_patches: Float["*B T h w C"],
73 | temporal_mask: Bool["*B T"],
74 | ) -> Float["*a 1"]:
75 | pred_patches = jnp.astype(pred_patches, jnp.float32)
76 | target_patches = jnp.astype(target_patches, jnp.float32)
77 | loss = jnp.abs(pred_patches - target_patches) # * temporal_mask
78 | if self.image_norm == "sum":
79 | loss = jnp.sum(loss, axis=[-1, -2, -3]) / 1024.
80 | elif self.image_norm == "mean":
81 | loss = jnp.mean(loss, axis=[-1, -2, -3])
82 | loss = jnp.mean(loss, axis=-1)
83 | return loss[..., None]
84 |
85 |
86 | @dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
87 | class CoordinateSoftmaxCrossEntropyWithIntLabels(base.Loss):
88 | """Softmax cross-entropy loss with integer labels."""
89 |
90 | logits: kontext.Key = kontext.REQUIRED # e.g. "preds.logits"
91 | labels: kontext.Key = kontext.REQUIRED # e.g. "batch.label"
92 | pixel_size: int = 256
93 |
94 | @typechecked
95 | def get_values(
96 | self, logits: Float["*a n"], labels: Float["*a 2"]
97 | ) -> Float["*a 1"]:
98 | logits = jnp.astype(logits, jnp.float32)
99 | labels = jnp.astype(labels, jnp.float32)
100 | labels -= 0.5
101 | labels = jnp.clip(labels, 0, self.pixel_size - 1)
102 | labels = jnp.round(labels).astype(jnp.int32)
103 | logits_x, logits_y = jnp.split(logits, 2, axis=-1)
104 | labels_x, labels_y = jnp.split(labels, 2, axis=-1)
105 | loss_x = optax.softmax_cross_entropy_with_integer_labels(
106 | logits=logits_x, labels=labels_x.squeeze(-1)
107 | )[..., None]
108 | loss_y = optax.softmax_cross_entropy_with_integer_labels(
109 | logits=logits_y, labels=labels_y.squeeze(-1)
110 | )[..., None]
111 | return loss_x + loss_y
112 |
113 |
114 | @dataclasses.dataclass(eq=True, frozen=True, kw_only=True)
115 | class Certainty(base.Loss):
116 | """Loss for point track uncertainty prediction.
117 |
118 |   A point prediction is considered certain if it falls within `threshold` of
119 |   the ground truth. This is the 3rd term of the loss in Equation (1) of the
120 |   TAPIR paper: https://arxiv.org/abs/2306.08637
121 | """
122 |
123 | threshold: float = 1.0
124 |
125 | logits: kontext.Key = kontext.REQUIRED
126 | pred_points: kontext.Key = kontext.REQUIRED
127 | target_points: kontext.Key = kontext.REQUIRED
128 | normalize_by: str = "values"
129 |
130 | @typechecked
131 | def get_values(
132 | self,
133 | logits: Float["*a 1"],
134 | pred_points: Float["*a 2"],
135 | target_points: Float["*a 2"],
136 | ) -> Float["*a 1"]:
137 | logits = jnp.astype(logits, jnp.float32)
138 | pred_points = jnp.astype(pred_points, jnp.float32)
139 | target_points = jnp.astype(target_points, jnp.float32)
140 | pred_points = jax.lax.stop_gradient(pred_points)
141 | error = pred_points - target_points
142 | distsqr = jnp.sum(jnp.square(error), axis=-1, keepdims=True)
143 | is_certain = (distsqr <= self.threshold**2).astype(logits.dtype)
144 | loss = optax.sigmoid_binary_cross_entropy(logits, is_certain)
145 | return loss
146 |
--------------------------------------------------------------------------------
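Note: to make the piecewise Huber definition in `Huber.get_values` concrete, here is a stand-alone restatement outside the kauldron `base.Loss` machinery (the helper name and the sample points are ours): errors within `delta` are penalized quadratically, larger errors linearly.

import jax.numpy as jnp

def huber(pred, target, delta=1.0):
  # Mirrors Huber.get_values above: quadratic inside delta, linear outside.
  target = jnp.clip(target, 0, 255)
  error = jnp.clip(pred - target, -1e8, 1e8)
  distsqr = jnp.sum(jnp.square(error), axis=-1, keepdims=True)
  dist = jnp.sqrt(distsqr + 1e-12)
  return jnp.where(dist < delta, distsqr / 2, delta * (dist - delta / 2))

pred = jnp.array([[10.5, 20.0], [14.0, 20.0]])
target = jnp.array([[10.0, 20.0], [10.0, 20.0]])
print(huber(pred, target))  # approx. [[0.125], [3.5]]
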
/configs/tapir_bootstrap_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Default config to train the TAPIR."""
17 |
18 | from jaxline import base_config
19 | from ml_collections import config_dict
20 |
21 |
22 | # We define the experiment launch config in the same file as the experiment to
23 | # keep things self-contained in a single file.
24 | def get_config() -> config_dict.ConfigDict:
25 | """Return config object for training."""
26 | config = base_config.get_base_config()
27 |
28 | # Experiment config.
29 | config.training_steps = 100000
30 |
31 | # NOTE: duplicates not allowed.
32 | config.shared_module_names = ('tapir_model',)
33 |
34 | config.dataset_names = ('kubric',)
35 | # Note: eval modes must always start with 'eval_'.
36 | config.eval_modes = (
37 | 'eval_davis_points',
38 | 'eval_jhmdb',
39 | 'eval_robotics_points',
40 | 'eval_kinetics_points',
41 | )
42 | config.checkpoint_dir = '/tmp/tapnet_training/'
43 | config.evaluate_every = 10000
44 |
45 | config.experiment_kwargs = config_dict.ConfigDict(
46 | dict(
47 | config=dict(
48 | sweep_name='default_sweep',
49 | save_final_checkpoint_as_npy=True,
50 | # `enable_double_transpose` should always be false when using 1D.
51 |             # For other D it is also completely untested and very unlikely
52 | # to work.
53 | optimizer=dict(
54 | base_lr=1e-3,
55 | max_norm=-1, # < 0 to turn off.
56 | weight_decay=1e-1,
57 | schedule_type='cosine',
58 | cosine_decay_kwargs=dict(
59 | init_value=0.0,
60 | warmup_steps=1000,
61 | end_value=0.0,
62 | ),
63 | optimizer='adam',
64 | # Optimizer-specific kwargs.
65 | adam_kwargs=dict(
66 | b1=0.9,
67 | b2=0.95,
68 | eps=1e-8,
69 | ),
70 | ),
71 | fast_variables=tuple(),
72 | shared_modules=dict(
73 | shared_module_names=config.get_oneway_ref(
74 | 'shared_module_names',
75 | ),
76 | tapir_model_kwargs=dict(
77 | bilinear_interp_with_depthwise_conv=False,
78 | pyramid_level=1,
79 | use_causal_conv=False,
80 | initial_resolution=(256, 256),
81 | extra_convs=True,
82 | softmax_temperature=10.0,
83 | ),
84 | ),
85 | datasets=dict(
86 | dataset_names=config.get_oneway_ref('dataset_names'),
87 | kubric_kwargs=dict(
88 | batch_dims=8,
89 | shuffle_buffer_size=128,
90 | train_size=(256, 256),
91 | ),
92 | ),
93 | supervised_point_prediction_kwargs=dict(
94 | prediction_algo='cost_volume_regressor',
95 | model_key='tapir_model',
96 | ),
97 | checkpoint_dir=config.get_oneway_ref('checkpoint_dir'),
98 | evaluate_every=config.get_oneway_ref('evaluate_every'),
99 | eval_modes=config.get_oneway_ref('eval_modes'),
100 | # If true, run evaluate() on the experiment once before
101 | # you load a checkpoint.
102 | # This is useful for getting initial values of metrics
103 | # at random weights, or when debugging locally if you
104 | # do not have any train job running.
105 | davis_points_path='',
106 | jhmdb_path='',
107 | robotics_points_path='',
108 | training=dict(
109 | # Note: to sweep n_training_steps, DO NOT sweep these
110 | # fields directly. Instead sweep config.training_steps.
111 | # Otherwise, decay/stopping logic
112 | # is not guaranteed to be consistent.
113 | n_training_steps=config.get_oneway_ref('training_steps'),
114 | ),
115 | inference=dict(
116 | input_video_path='',
117 | output_video_path='',
118 | resize_height=256, # video height resized to before inference
119 | resize_width=256, # video width resized to before inference
120 | num_points=20, # number of random points to sample
121 | ),
122 | )
123 | )
124 | )
125 |
126 | # Set up where to store the resulting model.
127 | config.train_checkpoint_all_hosts = False
128 | config.save_checkpoint_interval = 10
129 | config.eval_initial_weights = True
130 |
131 | # Prevents accidentally setting keys that aren't recognized (e.g. in tests).
132 | config.lock()
133 |
134 | return config
135 |
--------------------------------------------------------------------------------
/tapnet/tapnext/tapnext_benchmark_pytorch.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "id": "RHt3rGLLxfWs"
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import torch\n",
12 | "import torchvision"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "metadata": {
19 | "colab": {
20 | "base_uri": "https://localhost:8080/"
21 | },
22 | "id": "-gBuXTWqxuMV",
23 | "outputId": "42ca4293-2496-4a03-ced5-a02d9253c721"
24 | },
25 | "outputs": [],
26 | "source": [
27 | "torch.__version__, torchvision.__version__"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": null,
33 | "metadata": {
34 | "colab": {
35 | "base_uri": "https://localhost:8080/"
36 | },
37 | "id": "gjdq8wLcQWd7",
38 | "outputId": "8c5ed594-021c-4c67-e0df-949151d1247f"
39 | },
40 | "outputs": [],
41 | "source": [
42 | "!pip install -q git+https://github.com/google-deepmind/tapnet.git"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": null,
48 | "metadata": {
49 | "colab": {
50 | "base_uri": "https://localhost:8080/"
51 | },
52 | "id": "mrQ_uQDeQee-",
53 | "outputId": "ce0d3868-7b82-4e8e-88ae-fd5ae960f95b"
54 | },
55 | "outputs": [],
56 | "source": [
57 | "!pip install -q git+https://github.com/google-deepmind/recurrentgemma.git@main"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "metadata": {
64 | "colab": {
65 | "base_uri": "https://localhost:8080/"
66 | },
67 | "id": "NIZQxfcgyFM9",
68 | "outputId": "8cd10428-69a7-4c39-f3f9-eec9c07ed108"
69 | },
70 | "outputs": [],
71 | "source": [
72 | "!pip install \"numpy\u003c2.1.0\""
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": null,
78 | "metadata": {
79 | "id": "gju7QkZH2XLL"
80 | },
81 | "outputs": [],
82 | "source": [
83 | "import tqdm"
84 | ]
85 | },
86 | {
87 | "cell_type": "markdown",
88 | "metadata": {
89 | "id": "uQ6Ako7IQERX"
90 | },
91 | "source": [
92 | "### TAPNext"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": null,
98 | "metadata": {
99 | "id": "hy98wCMgQx6x"
100 | },
101 | "outputs": [],
102 | "source": [
103 | "import time\n",
104 | "import numpy as np\n",
105 | "from tapnet.tapnext.tapnext_torch import TAPNext\n",
106 | "import torch.nn.functional as F"
107 | ]
108 | },
109 | {
110 | "cell_type": "markdown",
111 | "metadata": {
112 | "id": "gJpHisHSQERc"
113 | },
114 | "source": [
115 | "### Create the model and load checkpoint"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": null,
121 | "metadata": {
122 | "id": "pnTma4NKQERc"
123 | },
124 | "outputs": [],
125 | "source": [
126 | "model = TAPNext(image_size=(256, 256))\n",
127 | "model.cuda()"
128 | ]
129 | },
130 | {
131 | "cell_type": "markdown",
132 | "metadata": {
133 | "id": "CukzSyYSQERd"
134 | },
135 | "source": [
136 | "### Run inference"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": null,
142 | "metadata": {
143 | "colab": {
144 | "base_uri": "https://localhost:8080/",
145 | "height": 477
146 | },
147 | "id": "YNf98GICtJMa",
148 | "outputId": "05da0eef-b914-4d1b-f740-5152f47ddb4c"
149 | },
150 | "outputs": [],
151 | "source": [
152 | "model.eval()\n",
153 | "for p in model.parameters():\n",
154 | " p.requires_grad = False"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "metadata": {
161 | "id": "IUowukX9Qx6x"
162 | },
163 | "outputs": [],
164 | "source": [
165 | "NUM_QUERIES = 1024"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": null,
171 | "metadata": {
172 | "id": "TyiVKz9MQx6x"
173 | },
174 | "outputs": [],
175 | "source": [
176 | "video = torch.zeros((1, 1024, 256, 256, 3), dtype=torch.float32).cuda()\n",
177 | "query_points = torch.zeros((1, NUM_QUERIES, 3), dtype=torch.float32).cuda()"
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": null,
183 | "metadata": {
184 | "id": "vLRpvleUQx6x"
185 | },
186 | "outputs": [],
187 | "source": [
188 | "DTYPE = torch.float16 # use fp16 or bf16"
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": null,
194 | "metadata": {
195 | "id": "n6x1NxPaQx6x"
196 | },
197 | "outputs": [],
198 | "source": [
199 | "fwd = torch.compile(model.forward)\n",
200 | "with torch.no_grad():\n",
201 | " with torch.amp.autocast('cuda', dtype=DTYPE, enabled=True):\n",
202 | " _, _, _, tracking_state = fwd(video=video[:, :1], query_points=query_points)\n",
203 | " c = 0\n",
204 | " for k in tqdm.tqdm(range(1, video.shape[1])):\n",
205 | " if k == 512:\n",
206 | " # we let the model to run for several GPU burn-in steps\n",
207 | " tt = time.time()\n",
208 | " c = 0\n",
209 | " _, _, _, tracking_state = fwd(\n",
210 | " video=video[:, k : k + 1], state=tracking_state\n",
211 | " )\n",
212 | " c += 1\n",
213 | " d = time.time() - tt\n",
214 | " print('FPS:', c / d, 'latency', 1000 * d / c, 'ms')"
215 | ]
216 | }
217 | ],
218 | "metadata": {
219 | "accelerator": "GPU",
220 | "colab": {
221 | "gpuType": "T4",
222 | "provenance": []
223 | },
224 | "kernelspec": {
225 | "display_name": "Python 3 (ipykernel)",
226 | "language": "python",
227 | "name": "python3"
228 | },
229 | "language_info": {
230 | "codemirror_mode": {
231 | "name": "ipython",
232 | "version": 3
233 | },
234 | "file_extension": ".py",
235 | "mimetype": "text/x-python",
236 | "name": "python",
237 | "nbconvert_exporter": "python",
238 | "pygments_lexer": "ipython3",
239 | "version": "3.11.11"
240 | }
241 | },
242 | "nbformat": 4,
243 | "nbformat_minor": 4
244 | }
245 |
--------------------------------------------------------------------------------
/tapnet/trajan/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Attention modules for TRAJAN."""
17 |
18 | from __future__ import annotations
19 |
20 | from typing import Optional
21 |
22 | from flax import linen as nn
23 | import jax.numpy as jnp
24 |
25 |
26 | class ImprovedTransformer(nn.Module):
27 | """Improved Transformer using tricks from ViT-22B (w/ cross-attention).
28 |
29 |   1) Normalize keys/queries w/ RMSNorm.
30 | 2) Do some ops in parallel (here: cross + self-attention, but not MLP).
31 | """
32 |
33 | qkv_size: int
34 | num_heads: int
35 | mlp_size: int
36 | num_layers: int
37 |
38 | @nn.compact
39 | def __call__(
40 | self,
41 | queries, # float['... d1'],
42 | inputs_kv=None, # Optional[float['*b N D']]
43 | qk_mask=None, # Optional[bool['...']]
44 | qq_mask=None, # Optional[bool['...']]
45 | ): # -> float['... d2']
46 |
47 | for i in range(self.num_layers):
48 | if qk_mask is not None and len(qk_mask.shape) == len(inputs_kv.shape):
49 | qk_mask = qk_mask[..., jnp.newaxis, :, :]
50 | if qq_mask is not None and len(qq_mask.shape) == len(queries.shape):
51 | qq_mask = qq_mask[..., jnp.newaxis, :, :]
52 |
53 | queries = ImprovedTransformerBlock(
54 | qkv_size=self.qkv_size,
55 | num_heads=self.num_heads,
56 | mlp_size=self.mlp_size,
57 | name=f'layer_{i}',
58 | )(
59 | queries,
60 | inputs_kv=inputs_kv,
61 | qq_mask=qq_mask,
62 | qk_mask=qk_mask,
63 | )
64 |
65 | queries = nn.LayerNorm(use_bias=False, use_scale=True, name='norm_encoder')(
66 | queries
67 | )
68 |
69 | return queries
70 |
71 |
72 | class ImprovedTransformerBlock(nn.Module):
73 | """Improved Transformer block using tricks from ViT-22B (w/ cross-attention).
74 |
75 | 1) RMSNorm instead of LayerNorm.
76 | 2) Normalize keys/queries w/ RMSNorm.
77 | 3) Do some ops in parallel (here: cross + self-attention, but not MLP).
78 | """
79 |
80 | mlp_size: int
81 | num_heads: int
82 | qkv_size: int
83 |
84 | @nn.compact
85 | def __call__(
86 | self,
87 | queries, # float['*b n d'],
88 | inputs_kv, # Optional[float['*b N D']]
89 | qq_mask=None, # Optional[bool['...']]
90 | qk_mask=None, # Optional[bool['...']]
91 | remat_attn: bool = False,
92 | ): # -> Float['*b n d']
93 | width = queries.shape[-1]
94 | normed_queries = nn.LayerNorm(
95 | use_bias=False, use_scale=True, name='norm_q'
96 | )(queries)
97 | attn_out = queries
98 |
99 | # Self-attention.
100 | self_attn_out = ImprovedMHDPAttention(
101 | num_heads=self.num_heads, qk_size=self.qkv_size, name='self_att'
102 | )(
103 | inputs_q=normed_queries,
104 | inputs_kv=normed_queries,
105 | mask=jnp.array(qq_mask, jnp.float32) if qq_mask is not None else None,
106 | )
107 | attn_out += self_attn_out
108 |
109 | # Cross-attention.
110 | if inputs_kv is not None:
111 | cross_attn_out = ImprovedMHDPAttention(
112 | num_heads=self.num_heads, qk_size=self.qkv_size, name='cross_att'
113 | )(
114 | inputs_q=normed_queries,
115 | inputs_kv=inputs_kv,
116 | mask=jnp.array(qk_mask, jnp.float32) if qk_mask is not None else None,
117 | )
118 | attn_out += cross_attn_out
119 |
120 | # MLP.
121 | normed_attn_out = nn.LayerNorm(
122 | use_bias=False, use_scale=True, name='norm_attn'
123 | )(attn_out)
124 | h = nn.gelu(nn.Dense(self.mlp_size, name='MLP_in')(normed_attn_out))
125 | mlp_out = nn.Dense(width, name='MLP_out')(h)
126 | return attn_out + mlp_out
127 |
128 |
129 | class ImprovedMHDPAttention(nn.Module):
130 | """Multi-head dot-product attention.
131 |
132 | Simplified nn.MultiheadDotProductAttention with a few modifications:
133 | - include normalization of keys and queries
134 |     - dropped support for dropout
135 |
136 | Attributes:
137 | num_heads: Number of attention heads.
138 | qk_size: Total dimension of the keys and queries.
139 | v_size: Total dimension of the values. Defaults to qk_size.
140 | """
141 |
142 | num_heads: int
143 | qk_size: int
144 | v_size: Optional[int] = None
145 |
146 | @nn.compact
147 | def __call__(
148 | self,
149 | inputs_q, # float['*b q d1'],
150 | inputs_kv, # float['*b k d2'],
151 | mask=None, # Optional[float['*b #h #q #k']]
152 | ): # -> float['*b q d1']
153 | """Applies multi-head dot product attention on the input data.
154 |
155 | Projects the inputs into multi-headed query, key, and value vectors,
156 |     applies dot-product attention and projects the results to an output vector.
157 |
158 | Args:
159 | inputs_q: Input tokens from which queries are computed.
160 |       inputs_kv: Input tokens from which the keys and values are computed.
161 | mask: Mask for the attention weights.
162 |
163 | Returns:
164 | output tokens (linear projection of an attention weighted average of value
165 | tokens per query).
166 | """
167 | v_size = self.qk_size if self.v_size is None else self.v_size
168 |
169 | if self.qk_size % self.num_heads:
170 | raise ValueError(f'{self.num_heads=} must divide {self.qk_size=}.')
171 | if v_size % self.num_heads:
172 |       raise ValueError(f'{self.num_heads=} must divide {v_size=}.')
173 |
174 | # Project inputs_q to multi-headed queries and keys.
175 | # dimensions are then [B..., Q, H, qk_size]
176 | query = nn.DenseGeneral(
177 | features=(self.num_heads, self.qk_size // self.num_heads),
178 | use_bias=False,
179 | name='dense_query',
180 | )(inputs_q)
181 | key = nn.DenseGeneral(
182 | features=(self.num_heads, self.qk_size // self.num_heads),
183 | use_bias=False,
184 | name='dense_key',
185 | )(inputs_kv)
186 |
187 | # Normalize keys and queries before attention.
188 | query = nn.RMSNorm(name='norm_query')(query)
189 | key = nn.RMSNorm(name='norm_key')(key)
190 |
191 | value = nn.DenseGeneral(
192 | features=(self.num_heads, v_size // self.num_heads),
193 | use_bias=False,
194 | name='dense_value',
195 | )(inputs_kv)
196 |
197 | x = nn.dot_product_attention(query, key, value, mask=mask)
198 |
199 | # Back to the original input dimensions.
200 | out = nn.DenseGeneral(
201 | features=inputs_q.shape[-1],
202 | axis=(-2, -1),
203 | use_bias=True,
204 | name='dense_out',
205 | )(x)
206 |
207 | return out
208 |
--------------------------------------------------------------------------------
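Note: a minimal sketch of instantiating and applying `ImprovedTransformer` with Flax (the sizes below are illustrative, not the ones used by TRAJAN):

import jax
import jax.numpy as jnp
from tapnet.trajan import attention

model = attention.ImprovedTransformer(
    qkv_size=64, num_heads=4, mlp_size=256, num_layers=2)

queries = jnp.zeros((2, 16, 128))   # ['*b n d'] query tokens
inputs_kv = jnp.zeros((2, 32, 96))  # ['*b N D'] tokens to cross-attend into

params = model.init(jax.random.PRNGKey(0), queries, inputs_kv=inputs_kv)
out = model.apply(params, queries, inputs_kv=inputs_kv)
print(out.shape)  # (2, 16, 128) - queries keep their original width
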
/tapnet/tapvid3d/annotation_generation/adt_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utilities for generating TAPVid3d ADT npz files."""
17 |
18 | # pylint: disable=g-import-not-at-top,g-bad-import-order
19 | import tensorflow as tf
20 |
21 | tf.config.set_visible_devices([], "GPU")
22 |
23 | import os
24 |
25 | import numpy as np
26 | import numpy.typing as npt
27 | from PIL import Image
28 | from projectaria_tools.core import calibration
29 | from projectaria_tools.core.stream_id import StreamId
30 | from projectaria_tools.projects.adt import AriaDigitalTwinDataPathsProvider
31 | from projectaria_tools.projects.adt import AriaDigitalTwinDataProvider
32 |
33 | import tqdm
34 | from tapnet.tapvid3d.annotation_generation import adt_v1v2_mappings
35 |
36 |
37 | # Fixed hyperparameters for generating the ADT data.
38 | N_FRAMES = 300
39 | HEIGHT = 512
40 | WIDTH = 512
41 | FOCAL_LENGTH = 280
42 |
43 |
44 | class ADTVideoProcessor:
45 | """ADT video processor."""
46 |
47 | def __init__(self, sequence_path: str):
48 | self.timestamps_ns, self.gt_provider, self.stream_id = self._load_adt_data(
49 | sequence_path
50 | )
51 |
52 | def _load_adt_data(self, sequence_path: str):
53 | """Loads ADT data for a given sequence."""
54 | paths_provider = AriaDigitalTwinDataPathsProvider(sequence_path)
55 | selected_device_number = 0
56 | data_paths = paths_provider.get_datapaths_by_device_num(
57 | selected_device_number, False
58 | )
59 | gt_provider = AriaDigitalTwinDataProvider(data_paths)
60 | stream_id = StreamId("214-1")
61 | timestamps_ns = np.array(
62 | gt_provider.get_aria_device_capture_timestamps_ns(stream_id)
63 | )
64 | # Remove timestamps without annotations.
65 | timestamps_ns = timestamps_ns[
66 | timestamps_ns > gt_provider.get_start_time_ns()
67 | ]
68 | timestamps_ns = timestamps_ns[timestamps_ns < gt_provider.get_end_time_ns()]
69 | return timestamps_ns, gt_provider, stream_id
70 |
71 | def extract_image_data(
72 | self, chunk_timestamps_ns: list[int]
73 | ) -> tuple[
74 | list[npt.NDArray], list[npt.NDArray], list[npt.NDArray], list[int]
75 | ]:
76 | """Extracts image, depth and segmentation data for a given video chunk."""
77 | sensor_name = (
78 | self.gt_provider.raw_data_provider_ptr().get_label_from_stream_id(
79 | self.stream_id
80 | )
81 | )
82 | device_calib = (
83 | self.gt_provider.raw_data_provider_ptr().get_device_calibration()
84 | )
85 | src_calib = device_calib.get_camera_calib(sensor_name)
86 | identity_tnf = calibration.get_linear_camera_calibration(
87 | 1, 1, 1
88 | ).get_transform_device_camera()
89 | dst_calib = calibration.CameraCalibration(
90 | "camera-rgb",
91 | calibration.CameraModelType.LINEAR,
92 | np.array([FOCAL_LENGTH, FOCAL_LENGTH, WIDTH / 2, HEIGHT / 2]),
93 | identity_tnf,
94 | WIDTH,
95 | HEIGHT,
96 | None,
97 | np.pi,
98 | "LinearCameraCalibration",
99 | )
100 | rgb_ims = []
101 | depth_ims = []
102 | segmentation_ims = []
103 | ok_timestamps_ns = []
104 | for select_timestamps_ns in chunk_timestamps_ns:
105 | depth_with_dt = self.gt_provider.get_depth_image_by_timestamp_ns(
106 | select_timestamps_ns, self.stream_id
107 | )
108 | segmentation_with_dt = (
109 | self.gt_provider.get_segmentation_image_by_timestamp_ns(
110 | select_timestamps_ns, self.stream_id
111 | )
112 | )
113 | image_with_dt = self.gt_provider.get_aria_image_by_timestamp_ns(
114 | select_timestamps_ns, self.stream_id
115 | )
116 | # Check if the result is valid.
117 | if (
118 | image_with_dt.is_valid()
119 | and depth_with_dt.is_valid()
120 | and segmentation_with_dt.is_valid()
121 | ):
122 | ok_timestamps_ns.append(select_timestamps_ns)
123 | else:
124 | continue
125 | image = image_with_dt.data().to_numpy_array()
126 | image = calibration.distort_by_calibration(image, dst_calib, src_calib)
127 | rgb_ims.append(image)
128 |
129 | depth_image = depth_with_dt.data().to_numpy_array()
130 | depth_image = calibration.distort_depth_by_calibration(
131 | depth_image, dst_calib, src_calib
132 | )
133 | depth_ims.append(depth_image)
134 |
135 | segmentation_data = segmentation_with_dt.data().to_numpy_array()
136 | segmentation_data = calibration.distort_label_by_calibration(
137 | segmentation_data, dst_calib, src_calib
138 | )
139 | segmentation_ims.append(segmentation_data)
140 | chunk_timestamps_ns = ok_timestamps_ns
141 | return rgb_ims, depth_ims, segmentation_ims, chunk_timestamps_ns
142 |
143 |
144 | def process_vid(
145 | input_adt_path: str,
146 | input_npz_path: str,
147 | output_npz_path: str,
148 | seq_name: str,
149 | chunks: list[int],
150 | ):
151 | """Processes multiple chunks of a single video."""
152 | adt_v2_name = adt_v1v2_mappings.ADT_MAPPINGS[seq_name]
153 | sequence_path = os.path.join(input_adt_path, adt_v2_name)
154 | adt_processor = ADTVideoProcessor(sequence_path)
155 |
156 | for chunk_idx in tqdm.tqdm(chunks):
157 | track_fn = os.path.join(output_npz_path, f"{seq_name}_{chunk_idx}.npz")
158 | chunk_timestamps_ns = adt_processor.timestamps_ns[
159 | chunk_idx * N_FRAMES : (chunk_idx + 1) * N_FRAMES
160 | ]
161 |
162 | rgb_ims, _, _, _ = adt_processor.extract_image_data(chunk_timestamps_ns)
163 | rgb_ims = [np.array(Image.fromarray(im).rotate(-90)) for im in rgb_ims]
164 | rgb_jpegs = [np.array(tf.io.encode_jpeg(im)).item() for im in rgb_ims]
165 |
166 | # Load query points.
167 | in_npz = np.load(
168 | os.path.join(input_npz_path, f"{seq_name}_{chunk_idx}.npz"),
169 | allow_pickle=True,
170 | )
171 | queries_xyt = in_npz["queries_xyt"]
172 | trajectories = in_npz["tracks_XYZ"]
173 | visibilities = in_npz["visibility"]
174 |
175 | # Verify video means.
176 | video_means = np.stack([np.mean(x, axis=(0, 1)) for x in rgb_ims], axis=0)
177 | assert np.allclose(video_means, in_npz["video_means"], atol=1e-3)
178 |
179 | example = {
180 | "images_jpeg_bytes": rgb_jpegs,
181 | "queries_xyt": queries_xyt,
182 | "tracks_XYZ": trajectories,
183 | "visibility": visibilities,
184 | "fx_fy_cx_cy": np.array(
185 | [FOCAL_LENGTH, FOCAL_LENGTH, WIDTH / 2, HEIGHT / 2]
186 | ),
187 | }
188 | np.savez(track_fn, **example)
189 |
--------------------------------------------------------------------------------
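Note: a hypothetical invocation of `process_vid` (all paths, the sequence name, and the chunk indices below are placeholders; the real values come from the TAPVid-3D generation scripts):

from tapnet.tapvid3d.annotation_generation import adt_utils

adt_utils.process_vid(
    input_adt_path='/data/adt',             # root of the downloaded ADT release
    input_npz_path='/data/tapvid3d_npz',    # npz files with precomputed annotations
    output_npz_path='/data/tapvid3d_out',   # where the final per-chunk npz files go
    seq_name='Apartment_release_work_seq107',  # placeholder sequence name
    chunks=[0, 1],                          # which 300-frame chunks to process
)
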
/tapnet/live_demo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Live Demo for Online TAPIR."""
17 |
18 | import time
19 |
20 | import cv2
21 | import jax
22 | import jax.numpy as jnp
23 | import numpy as np
24 | from tapnet.models import tapir_model
25 | from tapnet.utils import model_utils
26 |
27 |
28 | NUM_POINTS = 8
29 |
30 |
31 | def load_checkpoint(checkpoint_path):
32 | ckpt_state = np.load(checkpoint_path, allow_pickle=True).item()
33 | return ckpt_state["params"], ckpt_state["state"]
34 |
35 | print("Loading checkpoint...")
36 | # --------------------
37 | # Load checkpoint and initialize
38 | params, state = load_checkpoint(
39 | "tapnet/checkpoints/causal_tapir_checkpoint.npy"
40 | )
41 |
42 | tapir = tapir_model.ParameterizedTAPIR(
43 | params=params,
44 | state=state,
45 | tapir_kwargs=dict(
46 | use_causal_conv=True, bilinear_interp_with_depthwise_conv=False
47 | ),
48 | )
49 |
50 |
51 | def online_model_init(frames, points):
52 | feature_grids = tapir.get_feature_grids(frames, is_training=False)
53 | features = tapir.get_query_features(
54 | frames,
55 | is_training=False,
56 | query_points=points,
57 | feature_grids=feature_grids,
58 | )
59 | return features
60 |
61 |
62 | def online_model_predict(frames, features, causal_context):
63 | """Compute point tracks and occlusions given frames and query points."""
64 | feature_grids = tapir.get_feature_grids(frames, is_training=False)
65 | trajectories = tapir.estimate_trajectories(
66 | frames.shape[-3:-1],
67 | is_training=False,
68 | feature_grids=feature_grids,
69 | query_features=features,
70 | query_points_in_video=None,
71 | query_chunk_size=64,
72 | causal_context=causal_context,
73 | get_causal_context=True,
74 | )
75 | causal_context = trajectories["causal_context"]
76 | del trajectories["causal_context"]
77 | return {k: v[-1] for k, v in trajectories.items()}, causal_context
78 |
79 |
80 | def get_frame(video_capture):
81 | r_val, image = video_capture.read()
82 | trunc = np.abs(image.shape[1] - image.shape[0]) // 2
83 | if image.shape[1] > image.shape[0]:
84 | image = image[:, trunc:-trunc]
85 | elif image.shape[1] < image.shape[0]:
86 | image = image[trunc:-trunc]
87 | return r_val, image
88 |
89 |
90 | print("Welcome to the TAPIR live demo.")
91 | print("Please note that if the framerate is low (<~12 fps), TAPIR performance")
92 | print("may degrade and you may need a more powerful GPU.")
93 |
94 | print("Creating model...")
95 | online_init_apply = jax.jit(online_model_init)
96 |
97 | online_predict_apply = jax.jit(online_model_predict)
98 |
99 | print("Initializing camera...")
100 | # --------------------
101 | # Start point tracking
102 | vc = cv2.VideoCapture(0)
103 |
104 | vc.set(cv2.CAP_PROP_FRAME_HEIGHT, 240)
105 |
106 | if vc.isOpened(): # try to get the first frame
107 | rval, frame = get_frame(vc)
108 | else:
109 | raise ValueError("Unable to open camera.")
110 |
111 | pos = tuple()
112 | query_frame = True
113 | have_point = [False] * NUM_POINTS
114 | query_features = None
115 | causal_state = None
116 | next_query_idx = 0
117 |
118 | print("Compiling jax functions (this may take a while...)")
119 | # --------------------
120 | # Call one time to compile
121 | query_points = jnp.zeros([NUM_POINTS, 3], dtype=jnp.float32)
122 | _ = online_init_apply(
123 | frames=model_utils.preprocess_frames(frame[None, None]),
124 | points=query_points[None, 0:1],
125 | )
126 | jax.block_until_ready(query_features)
127 | query_features = online_init_apply(
128 | frames=model_utils.preprocess_frames(frame[None, None]),
129 | points=query_points[None, :],
130 | )
131 |
132 | causal_state = tapir.construct_initial_causal_state(
133 | NUM_POINTS, len(query_features.resolutions) - 1
134 | )
135 |
136 | prediction, causal_state = online_predict_apply(
137 | frames=model_utils.preprocess_frames(frame[None, None]),
138 | features=query_features,
139 | causal_context=causal_state,
140 | )
141 |
142 | jax.block_until_ready(prediction["tracks"])
143 |
144 | last_click_time = 0
145 |
146 |
147 | def mouse_click(event, x, y, flags, param):
148 | del flags, param
149 | global pos, query_frame, last_click_time
150 |
151 | # event fires multiple times per click sometimes??
152 | if (time.time() - last_click_time) < 0.5:
153 | return
154 |
155 | if event == cv2.EVENT_LBUTTONDOWN:
156 | pos = (y, frame.shape[1] - x)
157 | query_frame = True
158 | last_click_time = time.time()
159 |
160 |
161 | cv2.namedWindow("Point Tracking")
162 | cv2.setMouseCallback("Point Tracking", mouse_click)
163 |
164 | t = time.time()
165 | step_counter = 0
166 |
167 | print("Press ESC to exit.")
168 |
169 | while rval:
170 | rval, frame = get_frame(vc)
171 | if query_frame:
172 | query_points = jnp.array((0,) + pos, dtype=jnp.float32)
173 |
174 | init_query_features = online_init_apply(
175 | frames=model_utils.preprocess_frames(frame[None, None]),
176 | points=query_points[None, None],
177 | )
178 | query_frame = False
179 | query_features, causal_state = tapir.update_query_features(
180 | query_features=query_features,
181 | new_query_features=init_query_features,
182 | idx_to_update=np.array([next_query_idx]),
183 | causal_state=causal_state,
184 | )
185 | have_point[next_query_idx] = True
186 | next_query_idx = (next_query_idx + 1) % NUM_POINTS
187 | if pos:
188 | prediction, causal_state = online_predict_apply(
189 | frames=model_utils.preprocess_frames(frame[None, None]),
190 | features=query_features,
191 | causal_context=causal_state,
192 | )
193 | track = prediction["tracks"][0, :, 0]
194 | occlusion = prediction["occlusion"][0, :, 0]
195 | expected_dist = prediction["expected_dist"][0, :, 0]
196 | visibles = model_utils.postprocess_occlusions(occlusion, expected_dist)
197 | track = np.round(track)
198 |
199 | for i, _ in enumerate(have_point):
200 | if visibles[i] and have_point[i]:
201 | cv2.circle(
202 | frame, (int(track[i, 0]), int(track[i, 1])), 5, (255, 0, 0), -1
203 | )
204 | if track[i, 0] < 16 and track[i, 1] < 16:
205 | print((i, next_query_idx))
206 | cv2.imshow("Point Tracking", frame[:, ::-1])
207 | if pos:
208 | step_counter += 1
209 | if time.time() - t > 5:
210 | print(f"{step_counter/(time.time()-t)} frames per second")
211 | t = time.time()
212 | step_counter = 0
213 | else:
214 | t = time.time()
215 | key = cv2.waitKey(1)
216 |
217 | if key == 27: # exit on ESC
218 | break
219 |
220 | cv2.destroyWindow("Point Tracking")
221 | vc.release()
222 |
--------------------------------------------------------------------------------
/tapnet/tapvid/generate_tapvid.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Python script to generate a pickle with annotations from raw videos."""
17 |
18 | import csv
19 | import dataclasses
20 | import io
21 | import math
22 | import os
23 | import pickle
24 | from typing import Dict, Iterator, List, Sequence, Tuple
25 |
26 | from absl import app
27 | from absl import flags
28 | from absl import logging
29 | import ffmpeg
30 | import numpy as np
31 | from PIL import Image
32 |
33 | FLAGS = flags.FLAGS
34 |
35 | flags.DEFINE_string('input_csv_path', None, 'Path to the input csv.')
36 | flags.DEFINE_string(
37 | 'output_base_path',
38 | None,
39 | 'Path to the output folder where pickle files will be stored.',
40 | )
41 | flags.DEFINE_string(
42 | 'video_root_path',
43 | None,
44 | (
45 | 'Path to the root directory of the extracted Kinetics-700-2020 videos'
46 | ' from the validation set.'
47 | ),
48 | )
49 | flags.DEFINE_integer('num_shards', 10, 'Number of pickle shards to output.')
50 |
51 | _JPEG_HEADER = b'\xff\xd8'
52 |
53 |
54 | @dataclasses.dataclass(frozen=True)
55 | class Point:
56 | x: float
57 | y: float
58 | occluded: bool
59 |
60 |
61 | @dataclasses.dataclass(frozen=True)
62 | class Track:
63 | points: Tuple[Point, ...]
64 |
65 |
66 | @dataclasses.dataclass(frozen=True)
67 | class Video:
68 | youtube_id: str
69 | start_time_sec: int
70 | end_time_sec: int
71 | video_path: str
72 | tracks: Tuple[Track, ...]
73 |
74 |
75 | def csv_to_dataset(
76 | csv_path: str, videos_path: Dict[str, str]
77 | ) -> Tuple[Video, ...]:
78 | """Reads the input CSV and creates a list of `Video`s out of that."""
79 |
80 | def points(row: Sequence[str]) -> Iterator[Point]:
81 | for i in range(250):
82 | x, y, occ = row[3 + 3 * i : 3 + 3 * i + 3]
83 | x = float(x)
84 | y = float(y)
85 | assert occ in ('0', '1')
86 | occ = occ == '1'
87 | yield Point(x, y, occ)
88 |
89 | logging.info('Reading CSV "%s".', csv_path)
90 |
91 | with open(csv_path) as f:
92 | reader = csv.reader(f, delimiter=',')
93 |
94 | tracks_per_video: Dict[Tuple[str, int, int], List[Track]] = {}
95 | for row in reader:
96 | assert len(row) == 3 + 3 * 250
97 |
98 | youtube_id, start_time_sec, end_time_sec = row[:3]
99 | start_time_sec = int(start_time_sec)
100 | end_time_sec = int(end_time_sec)
101 | key = (youtube_id, start_time_sec, end_time_sec)
102 |
103 | track = Track(tuple(points(row)))
104 |
105 | if key not in tracks_per_video:
106 | tracks_per_video[key] = []
107 | tracks_per_video[key].append(track)
108 |
109 | def videos() -> Iterator[Video]:
110 | for key, tracks in tracks_per_video.items():
111 | youtube_id, start_time_sec, end_time_sec = key
112 |
113 | name = f'{youtube_id}_{start_time_sec:06}_{end_time_sec:06}'
114 | if name not in videos_path:
115 | logging.warning('Video "%s" not downloaded. Skipping it.', name)
116 | continue
117 | video_path = videos_path[name]
118 |
119 | yield Video(
120 | youtube_id, start_time_sec, end_time_sec, video_path, tuple(tracks)
121 | )
122 |
123 | return tuple(videos())
124 |
125 |
126 | def get_paths_to_videos(video_root_path: str) -> Dict[str, str]:
127 | """Returns the relative path to each downloaded video."""
128 | logging.info('Reading all videos in subfolders of "%s".', video_root_path)
129 | video_to_path: Dict[str, str] = {}
130 | for folder_or_video in os.listdir(video_root_path):
131 | path = os.path.join(video_root_path, folder_or_video)
132 | if os.path.isdir(path):
133 | subfolder_paths = get_paths_to_videos(path)
134 | for k, v in subfolder_paths.items():
135 | assert k not in video_to_path
136 | video_to_path[k] = v
137 | elif folder_or_video.endswith('.mp4'):
138 | name = folder_or_video[:-4] # Remove '.mp4'.
139 | assert name not in video_to_path
140 | video_to_path[name] = path
141 |
142 | return video_to_path
143 |
144 |
145 | def extract_frames(video_path: str, fps: float) -> Tuple[bytes, ...]:
146 | """Extracts list of jpeg bytes from the given video using ffmpeg."""
147 | cmd = (
148 | ffmpeg.input(video_path)
149 | .filter('fps', fps=fps)
150 | .output('pipe:', format='image2pipe')
151 | )
152 | jpeg_bytes, _ = cmd.run(capture_stdout=True, quiet=True)
153 | jpeg_bytes = jpeg_bytes.split(_JPEG_HEADER)[1:]
154 | jpeg_bytes = map(lambda x: _JPEG_HEADER + x, jpeg_bytes)
155 | return tuple(jpeg_bytes)
156 |
157 |
158 | def generate_example(video: Video) -> Dict[str, np.ndarray]:
159 | """Generates a dictionary with the info from a `Video`."""
160 | example: Dict[str, np.ndarray] = {}
161 |
162 | imgs_encoded = extract_frames(video.video_path, 25.0)
163 | if len(imgs_encoded) > 250:
164 | imgs_encoded = imgs_encoded[:250]
165 |
166 | if len(imgs_encoded) < 250:
167 | # Clip is shorter than 10s.
168 | num_frames = len(imgs_encoded)
169 | new_tracks = tuple(
170 | Track(tuple(t.points[:num_frames])) for t in video.tracks
171 | )
172 | video = Video(
173 | video.youtube_id,
174 | video.start_time_sec,
175 | video.end_time_sec,
176 | video.video_path,
177 | new_tracks,
178 | )
179 |
180 | example['video'] = np.array(imgs_encoded)
181 | byteio = io.BytesIO(imgs_encoded[0])
182 | img = Image.open(byteio)
183 | height, width, _ = np.array(img).shape
184 |
185 | points = []
186 | occluded = []
187 | for track in video.tracks:
188 | points.append(
189 | [
190 | [(p.x * width - 0.5) / width, (p.y * height - 0.5) / height]
191 | for p in track.points
192 | ]
193 | )
194 | occluded.append([p.occluded for p in track.points])
195 |
196 | example['points'] = np.array(points, dtype=np.float64)
197 | example['occluded'] = np.array(occluded, dtype=bool)
198 |
199 | return example
200 |
201 |
202 | def main(argv: Sequence[str]) -> None:
203 | del argv
204 |
205 | output_folder = FLAGS.output_base_path
206 | if output_folder and not os.path.exists(output_folder):
207 | os.makedirs(output_folder)
208 |
209 | # Reads data.
210 | videos_path = get_paths_to_videos(FLAGS.video_root_path)
211 | videos = csv_to_dataset(FLAGS.input_csv_path, videos_path)
212 |
213 | # Process the dataset and store pickles.
214 | num_examples_per_shard = int(math.ceil(len(videos) / FLAGS.num_shards))
215 | shard = 0
216 | data = []
217 | for i, video in enumerate(videos):
218 | print(
219 | 'Processing example %d of %d (%d%%) \r'
220 | % (i, len(videos), i * 100 / len(videos)),
221 | end='',
222 | )
223 | data.append(generate_example(video))
224 | if i == len(videos) - 1 or len(data) == num_examples_per_shard:
225 | shard_path = os.path.join(
226 | output_folder, f'{shard:04}_of_{FLAGS.num_shards:04}.pkl'
227 | )
228 | logging.info('Writing file "%s".', shard_path)
229 | with open(shard_path, 'wb') as f:
230 | pickle.dump(data, f)
231 | data.clear()
232 | shard += 1
233 |
234 |
235 | if __name__ == '__main__':
236 | app.run(main)
237 |
--------------------------------------------------------------------------------
/tapnet/pytorch_live_demo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Live Demo for PyTorch Online TAPIR."""
17 |
18 | import time
19 |
20 | import cv2
21 | import numpy as np
22 | from tapnet.torch import tapir_model
23 | import torch
24 | import torch.nn.functional as F
25 | import tree
26 |
27 | NUM_POINTS = 8
28 |
29 |
30 | def preprocess_frames(frames):
31 | """Preprocess frames to model inputs.
32 |
33 | Args:
34 |     frames: [num_frames, height, width, 3], [0, 255], uint8 torch.Tensor
35 |
36 |   Returns:
37 |     frames: [num_frames, height, width, 3], [-1, 1], float32 torch.Tensor
38 | """
39 | frames = frames.float()
40 | frames = frames / 255 * 2 - 1
41 | return frames
42 |
43 |
44 | def online_model_init(frames, points):
45 | """Initialize query features for the query points."""
46 | frames = preprocess_frames(frames)
47 | feature_grids = model.get_feature_grids(frames, is_training=False)
48 | features = model.get_query_features(
49 | frames,
50 | is_training=False,
51 | query_points=points,
52 | feature_grids=feature_grids,
53 | )
54 | return features
55 |
56 |
57 | def postprocess_occlusions(occlusions, expected_dist):
58 | visibles = (1 - F.sigmoid(occlusions)) * (1 - F.sigmoid(expected_dist)) > 0.5
59 | return visibles
60 |
61 |
62 | def online_model_predict(frames, features, causal_context):
63 | """Compute point tracks and occlusions given frames and query points."""
64 | frames = preprocess_frames(frames)
65 | feature_grids = model.get_feature_grids(frames, is_training=False)
66 | trajectories = model.estimate_trajectories(
67 | frames.shape[-3:-1],
68 | is_training=False,
69 | feature_grids=feature_grids,
70 | query_features=features,
71 | query_points_in_video=None,
72 | query_chunk_size=64,
73 | causal_context=causal_context,
74 | get_causal_context=True,
75 | )
76 | causal_context = trajectories["causal_context"]
77 | del trajectories["causal_context"]
78 | # Take only the predictions for the final resolution.
79 | # For running on higher resolution, it's typically better to average across
80 | # resolutions.
81 | tracks = trajectories["tracks"][-1]
82 | occlusions = trajectories["occlusion"][-1]
83 | uncertainty = trajectories["expected_dist"][-1]
84 | visibles = postprocess_occlusions(occlusions, uncertainty)
85 | return tracks, visibles, causal_context
86 |
87 |
88 | def get_frame(video_capture):
89 | r_val, image = video_capture.read()
90 | trunc = np.abs(image.shape[1] - image.shape[0]) // 2
91 | if image.shape[1] > image.shape[0]:
92 | image = image[:, trunc:-trunc]
93 | elif image.shape[1] < image.shape[0]:
94 | image = image[trunc:-trunc]
95 | return r_val, image
96 |
97 |
98 | print("Welcome to the TAPIR PyTorch live demo.")
99 | print("Please note that if the framerate is low (<~12 fps), TAPIR performance")
100 | print("may degrade and you may need a more powerful GPU.")
101 |
102 | if torch.cuda.is_available():
103 | device = torch.device("cuda")
104 | else:
105 | device = torch.device("cpu")
106 |
107 | # --------------------
108 | # Load checkpoint and initialize
109 | print("Creating model...")
110 | model = tapir_model.TAPIR(pyramid_level=1, use_casual_conv=True)
111 | print("Loading checkpoint...")
112 | model.load_state_dict(
113 | torch.load("tapnet/checkpoints/causal_bootstapir_checkpoint.pt")
114 | )
115 | model = model.to(device)
116 | model = model.eval()
117 | torch.set_grad_enabled(False)
118 |
119 | # --------------------
120 | # Start point tracking
121 | print("Initializing camera...")
122 | vc = cv2.VideoCapture(0)
123 |
124 | vc.set(cv2.CAP_PROP_FRAME_HEIGHT, 240)
125 |
126 | if vc.isOpened(): # try to get the first frame
127 | rval, frame = get_frame(vc)
128 | else:
129 | raise ValueError("Unable to open camera.")
130 |
131 | pos = tuple()
132 | query_frame = True
133 | have_point = [False] * NUM_POINTS
134 |
135 | query_points = torch.zeros([NUM_POINTS, 3], dtype=torch.float32)
136 | query_points = query_points.to(device)
137 | frame = torch.tensor(frame).to(device)
138 |
139 | query_features = online_model_init(
140 | frames=frame[None, None], points=query_points[None, :]
141 | )
142 |
143 | causal_state = model.construct_initial_causal_state(
144 | NUM_POINTS, len(query_features.resolutions) - 1
145 | )
146 | causal_state = tree.map_structure(lambda x: x.to(device), causal_state)
147 |
148 | prediction, visible, causal_state = online_model_predict(
149 | frames=frame[None, None],
150 | features=query_features,
151 | causal_context=causal_state,
152 | )
153 |
154 | next_query_idx = 0
155 | last_click_time = 0
156 |
157 |
158 | def mouse_click(event, x, y, flags, param):
159 | del flags, param
160 | global pos, query_frame, last_click_time
161 |
162 | # event fires multiple times per click sometimes??
163 | if (time.time() - last_click_time) < 0.5:
164 | return
165 |
166 | if event == cv2.EVENT_LBUTTONDOWN:
167 | pos = (y, frame.shape[1] - x)
168 | query_frame = True
169 | last_click_time = time.time()
170 |
171 |
172 | cv2.namedWindow("Point Tracking")
173 | cv2.setMouseCallback("Point Tracking", mouse_click)
174 |
175 | t = time.time()
176 | step_counter = 0
177 |
178 | print("Press ESC to exit.")
179 |
180 | while rval:
181 | rval, frame = get_frame(vc)
182 | numpy_frame = frame
183 | if query_frame:
184 | query_points = np.array((0,) + pos, dtype=np.float32)
185 | frame = torch.tensor(frame).to(device)
186 | query_points = torch.tensor(query_points).to(device)
187 |
188 | init_query_features = online_model_init(
189 | frames=frame[None, None], points=query_points[None, None]
190 | )
191 | query_frame = False
192 | query_features, causal_state = model.update_query_features(
193 | query_features=query_features,
194 | new_query_features=init_query_features,
195 | idx_to_update=np.array([next_query_idx]),
196 | causal_state=causal_state,
197 | )
198 | have_point[next_query_idx] = True
199 | next_query_idx = (next_query_idx + 1) % NUM_POINTS
200 | if pos:
201 | frame = torch.tensor(frame).to(device)
202 | track, visible, causal_state = online_model_predict(
203 | frames=frame[None, None],
204 | features=query_features,
205 | causal_context=causal_state,
206 | )
207 | track = np.round(track.cpu().numpy())
208 | visible = visible.cpu().numpy()
209 |
210 | for i, _ in enumerate(have_point):
211 | if visible[0, i, 0] and have_point[i]:
212 | cv2.circle(
213 | numpy_frame,
214 | (int(track[0, i, 0, 0]), int(track[0, i, 0, 1])),
215 | 5,
216 | (255, 0, 0),
217 | -1,
218 | )
219 | if track[0, i, 0, 0] < 16 and track[0, i, 0, 1] < 16:
220 | print((i, next_query_idx))
221 | cv2.imshow("Point Tracking", numpy_frame[:, ::-1])
222 | if pos:
223 | step_counter += 1
224 | if time.time() - t > 5:
225 | print(f"{step_counter/(time.time()-t)} frames per second")
226 | t = time.time()
227 | step_counter = 0
228 | else:
229 | t = time.time()
230 | key = cv2.waitKey(1)
231 |
232 | if key == 27: # exit on ESC
233 | break
234 |
235 | cv2.destroyWindow("Point Tracking")
236 | vc.release()
237 |
--------------------------------------------------------------------------------
/tapnet/training/task.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Abstract task interface with documentation.
17 |
18 | """
19 |
20 | import abc
21 | from typing import Mapping, Optional, Tuple
22 |
23 | import chex
24 | import typing_extensions
25 |
26 |
27 | class SharedModule(typing_extensions.Protocol):
28 |
29 | def __call__(
30 | self,
31 | video: chex.Array,
32 | is_training: bool,
33 | query_points: chex.Array,
34 | query_chunk_size: Optional[int] = None,
35 | get_query_feats: bool = False,
36 | **kwargs,
37 | ) -> Mapping[str, chex.Array]:
38 | """Runs a forward pass of a module.
39 |
40 | Args:
41 | video: A 4-D or 5-D tensor representing a batch of sequences of images. In
42 | the 4-D case, we assume the entire batch has been concatenated along the
43 | batch dimension, one sequence after the other. This can speed up
44 | inference on the TPU and save memory.
45 | is_training: Whether we are training.
46 | query_points: The query points for which we compute tracks.
47 | query_chunk_size: When computing cost volumes, break the queries into
48 | chunks of this size to save memory.
49 | get_query_feats: If True, also return the features for each query obtained
50 |         using bilinear interpolation from the feature grid.
51 | **kwargs: Additional module-specific parameters.
52 |
53 | Returns:
54 | Module outputs.
55 | """
56 |
57 |
58 | class WrappedForwardFn(typing_extensions.Protocol):
59 | """Forward function, wrapped by haiku.
60 |
61 | This wrapped forward function will inject the shared_modules and allow them
62 | to use shared params. It should be called inside a loss_fn using the same
63 | signature as `Task.forward_fn` (minus the shared_modules).
64 | """
65 |
66 | def __call__(
67 | self,
68 | params: chex.ArrayTree,
69 | state: chex.ArrayTree,
70 | rng: chex.PRNGKey,
71 | inputs: chex.ArrayTree,
72 | is_training: bool,
73 | input_key: Optional[str] = None,
74 | query_chunk_size: int = 16,
75 | get_query_feats: bool = True,
76 | ) -> Mapping[str, chex.Array]:
77 | """Forward pass for predicting point tracks.
78 |
79 | Args:
80 | params: hk.Params with the model parameters
81 | state: hk.State with the model state
82 | rng: jax.random.PRNGKey for random number generation.
83 |       inputs: Input dict. Inference will be performed on
84 |         inputs[input_key]['video'] (with fallback to the input_key specified in
85 |         the constructor). Input videos should be a standard video tensor
86 |         ([batch, num_frames, height, width, 3]) normalized to [-1, 1].
87 | inputs[input_key]['query_points'] specifies the query point locations,
88 | of shape [batch, num_queries, 3], where each query is [t,y,x]
89 | coordinates normalized to between -1 and 1.
90 | is_training: Whether the model is in training mode.
91 | input_key: Run on inputs[input_key]['video']. If None, use the input_key
92 | from the constructor.
93 | query_chunk_size: Compute predictions on this many queries simultaneously.
94 | This saves memory as the cost volumes can be very large.
95 | get_query_feats: If True, also return features for each query.
96 |
97 | Returns:
98 | Result dict produced by calling the model.
99 | """
100 |
101 |
102 | class Task(abc.ABC):
103 | """An abstract Task definition."""
104 |
105 | @abc.abstractmethod
106 | def forward_fn(
107 | self,
108 | inputs: chex.ArrayTree,
109 | is_training: bool,
110 | shared_modules: Optional[Mapping[str, SharedModule]] = None,
111 | ) -> chex.ArrayTree:
112 | """Run the model forward pass and construct all required Haiku modules.
113 |
114 | Args:
115 | inputs: model input tensors. This is a dict keyed by dataset name, where
116 | the value for each key is an item from the specified dataset.
117 | is_training: Is the forward pass in training mode or not.
118 | shared_modules: A dict of Haiku modules, keyed by module name, which
119 | can be used to construct the modules which are shared across different
120 | tasks.
121 |
122 | Returns:
123 | Anything. The important part is that this must construct all modules that
124 | Haiku needs to initialize.
125 |
126 | """
127 |
128 | def get_gradients(
129 | self,
130 | params: chex.ArrayTree,
131 | state: chex.ArrayTree,
132 | inputs: chex.ArrayTree,
133 | rng: chex.PRNGKey,
134 | global_step: chex.Array,
135 | wrapped_forward_fn: WrappedForwardFn,
136 | is_training: bool = True,
137 | ) -> Tuple[chex.ArrayTree, chex.ArrayTree, Mapping[str, chex.Array]]:
138 | """Get gradients for this tasks's loss function.
139 |
140 | Params, state, inputs, rng, and global_step are pmapped, i.e. a separate
141 | copy on each device.
142 |
143 | Args:
144 | params: Haiku parameters
145 | state: Haiku state
146 | inputs: model input tensors. This is a dict keyed by dataset name, where
147 | the value for each key is an item from the specified dataset.
148 | rng: random number state
149 | global_step: global step
150 | wrapped_forward_fn: A wrapper for the forward function that will inject
151 | the shared_modules and allow them to use shared params. It should be
152 | called inside a loss_fn using the same signature as forward_fn
153 | (minus the shared_modules).
154 | is_training: Is the forward pass in training mode or not.
155 |
156 | Returns:
157 | grads: A set of gradients compatible with optax apply_gradients (these
158 | will be summed across tasks).
159 | state: An updated Haiku state. The returned state will be passed to the
160 | next task in the list.
161 | scalars: A dict of (pmapped) scalars to be logged for this task. All
162 | dict keys will have the task name prepended before they are logged.
163 |
164 | """
165 | raise NotImplementedError()
166 |
167 | def evaluate(
168 | self,
169 | global_step: chex.Array,
170 | params: chex.ArrayTree,
171 | state: chex.ArrayTree,
172 | rng: chex.PRNGKey,
173 | wrapped_forward_fn: WrappedForwardFn,
174 | mode: str,
175 | ) -> Mapping[str, chex.Array]:
176 | """Evaluate this task's performance on a downstream benchmark.
177 |
178 | Args:
179 | global_step: global step
180 | params: Haiku parameters
181 | state: Haiku state
182 | rng: random number state
183 | wrapped_forward_fn: A wrapper for the forward function that will inject
184 | the shared_modules and allow them to use shared params.
185 | mode: A string mode used to determine, e.g., which dataset or split to
186 | evaluate on. This will be the same value as the 'mode' parameter
187 | used to launch different eval jobs in Jaxline.
188 |
189 | Returns:
190 | scalars: A dict of scalars to be logged for this task. All
191 | dict keys will have the task name prepended before they are logged.
192 |
193 | """
194 | raise NotImplementedError()
195 |
196 |
197 |
198 |
--------------------------------------------------------------------------------
/tapnet/tapnext/tapnext_torch_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utils for TAPNext torch implementation."""
17 |
18 | import re
19 | import numpy as np
20 | import torch
21 | import torch.nn.functional as F
22 |
23 |
24 | def get_window(coord, softmax, radius: int = 8):
25 | b = coord.shape[0]
26 | start = torch.floor(coord - radius - 0.5).int()
27 | start.clamp_(min=0)
28 | indices = start + torch.arange(radius * 2 + 1, device=softmax.device).repeat(
29 | b, 1
30 | )
31 | # this is to simulate one corner case of jax implementation
32 | shift = (indices.max(1).values - softmax.shape[1] + 1).clamp(min=0)
33 | indices -= shift.unsqueeze(1)
34 | softmax = softmax.gather(dim=1, index=indices)
35 | return softmax, indices + 0.5
36 |
37 |
38 | def tracker_certainty(coord_yx, track_logits, radius=8):
39 | """Computes the certainty of the tracker."""
40 | shape = coord_yx.shape[:-1]
41 | coord_yx = coord_yx.flatten(0, -2)
42 | track_logits = track_logits.flatten(0, -2)
43 | # track_logits.shape == [b, 512]
44 | # coord_yx.shape == [b, 2]
45 | logits_y, logits_x = track_logits.chunk(2, dim=-1)
46 | track_softmax_y = F.softmax(logits_y, dim=-1)
47 | track_softmax_x = F.softmax(logits_x, dim=-1)
48 | sm_y, coord_y = get_window(coord_yx[:, 0:1], track_softmax_y)
49 | sm_x, coord_x = get_window(coord_yx[:, 1:2], track_softmax_x)
50 | sm = sm_y[..., :, None] * sm_x[..., None, :]
51 | grid_x, grid_y = torch.vmap(torch.meshgrid)(coord_x, coord_y)
52 | # grid_x.shape == [b, N, N]
53 | grid = torch.stack([grid_y, grid_x], dim=-1)
54 | in_radius = ((grid - coord_yx[:, None, None]) ** 2).sum(-1) <= (
55 | (radius**2) + 1e-8
56 | )
57 | return (sm * in_radius).sum(-1).sum(-1).reshape(*shape, 1)
58 |
59 |
60 | def restore_model_from_jax_checkpoint(model, ckpt_path):
61 | """Restores a TAPNext model from a JAX checkpoint."""
62 | ckpt = {k: v for k, v in np.load(ckpt_path).items()}
63 | model.lin_proj.weight.data.copy_(
64 | torch.tensor(ckpt['backbone/embedding/kernel'][0]).permute(3, 2, 0, 1)
65 | )
66 | model.lin_proj.bias.data.copy_(torch.tensor(ckpt['backbone/embedding/bias']))
67 | model.mask_token.data.copy_(torch.tensor(ckpt['backbone/mask_token']))
68 | model.point_query_token.data.copy_(
69 | torch.tensor(ckpt['backbone/point_query_token'])
70 | )
71 | model.unknown_token.data.copy_(torch.tensor(ckpt['backbone/unknown_token']))
72 | model.image_pos_emb.data.copy_(torch.tensor(ckpt['backbone/pos_embedding']))
73 | model.encoder_norm.weight.data.copy_(
74 | torch.tensor(ckpt['backbone/Transformer/encoder_norm/scale'])
75 | )
76 | model.encoder_norm.bias.data.copy_(
77 | torch.tensor(ckpt['backbone/Transformer/encoder_norm/bias'])
78 | )
79 | for layer in range(12):
80 | # convert ssm part
81 | prefix = f'backbone/Transformer/encoderblock_{layer}/ssm_block'
82 | ssm_params = {
83 | key: torch.tensor(
84 | ckpt[
85 | f'{prefix}/'
86 | + re.sub('weight', 'kernel', re.sub(r'\.', '/', key))
87 | ]
88 | )
89 | for key, _ in model.blocks[layer].ssm_block.named_parameters()
90 | }
91 | for key in ssm_params:
92 | if 'weight' in key:
93 | ssm_params[key] = ssm_params[key].T
94 | model.blocks[layer].ssm_block.load_state_dict(ssm_params)
95 |
96 | # convert vit part
97 | vit_params = {
98 | re.sub(
99 | f'backbone/Transformer/encoderblock_{layer}/vit_block/', '', k
100 | ): v
101 | for k, v in ckpt.items()
102 | if f'backbone/Transformer/encoderblock_{layer}/vit_block' in k
103 | }
104 | torch_vit_params = {}
105 | torch_vit_params['ln_1.weight'] = vit_params['LayerNorm_0/scale']
106 | torch_vit_params['ln_1.bias'] = vit_params['LayerNorm_0/bias']
107 | torch_vit_params['ln_2.weight'] = vit_params['LayerNorm_1/scale']
108 | torch_vit_params['ln_2.bias'] = vit_params['LayerNorm_1/bias']
109 | torch_vit_params['mlp.0.weight'] = vit_params['MlpBlock_0/Dense_0/kernel'].T
110 | torch_vit_params['mlp.0.bias'] = vit_params['MlpBlock_0/Dense_0/bias']
111 | torch_vit_params['mlp.3.weight'] = vit_params['MlpBlock_0/Dense_1/kernel'].T
112 | torch_vit_params['mlp.3.bias'] = vit_params['MlpBlock_0/Dense_1/bias']
113 | torch_vit_params['self_attention.in_proj_weight'] = np.concatenate(
114 | [
115 | vit_params['MultiHeadDotProductAttention_0/query/kernel']
116 | .reshape(768, 768)
117 | .T,
118 | vit_params['MultiHeadDotProductAttention_0/key/kernel']
119 | .reshape(768, 768)
120 | .T,
121 | vit_params['MultiHeadDotProductAttention_0/value/kernel']
122 | .reshape(768, 768)
123 | .T,
124 | ],
125 | axis=0,
126 | )
127 | torch_vit_params['self_attention.in_proj_bias'] = np.concatenate([
128 | vit_params['MultiHeadDotProductAttention_0/query/bias'].flatten(),
129 | vit_params['MultiHeadDotProductAttention_0/key/bias'].flatten(),
130 | vit_params['MultiHeadDotProductAttention_0/value/bias'].flatten(),
131 | ])
132 | torch_vit_params['self_attention.out_proj.weight'] = (
133 | vit_params['MultiHeadDotProductAttention_0/out/kernel']
134 | .reshape(768, 768)
135 | .T
136 | )
137 | torch_vit_params['self_attention.out_proj.bias'] = vit_params[
138 | 'MultiHeadDotProductAttention_0/out/bias'
139 | ].flatten()
140 | for k in torch_vit_params:
141 | torch_vit_params[k] = torch.tensor(np.array(torch_vit_params[k]))
142 | model.blocks[layer].vit_block.load_state_dict(torch_vit_params)
143 | model.visible_head[0].weight.data.copy_(
144 | torch.from_numpy(ckpt['visible_head/layers_0/kernel'].T)
145 | )
146 | model.visible_head[0].bias.data.copy_(
147 | torch.from_numpy(ckpt['visible_head/layers_0/bias'])
148 | )
149 | model.visible_head[1].weight.data.copy_(
150 | torch.from_numpy(ckpt['visible_head/layers_1/scale'])
151 | )
152 | model.visible_head[1].bias.data.copy_(
153 | torch.from_numpy(ckpt['visible_head/layers_1/bias'])
154 | )
155 | model.visible_head[3].weight.data.copy_(
156 | torch.from_numpy(ckpt['visible_head/layers_3/kernel'].T)
157 | )
158 | model.visible_head[3].bias.data.copy_(
159 | torch.from_numpy(ckpt['visible_head/layers_3/bias'])
160 | )
161 | model.visible_head[4].weight.data.copy_(
162 | torch.from_numpy(ckpt['visible_head/layers_4/scale'])
163 | )
164 | model.visible_head[4].bias.data.copy_(
165 | torch.from_numpy(ckpt['visible_head/layers_4/bias'])
166 | )
167 | model.visible_head[6].weight.data.copy_(
168 | torch.from_numpy(ckpt['visible_head/layers_6/kernel'].T)
169 | )
170 | model.visible_head[6].bias.data.copy_(
171 | torch.from_numpy(ckpt['visible_head/layers_6/bias'])
172 | )
173 |
174 | model.coordinate_head[0].weight.data.copy_(
175 | torch.from_numpy(ckpt['coordinate_head/layers_0/kernel'].T)
176 | )
177 | model.coordinate_head[0].bias.data.copy_(
178 | torch.from_numpy(ckpt['coordinate_head/layers_0/bias'])
179 | )
180 | model.coordinate_head[1].weight.data.copy_(
181 | torch.from_numpy(ckpt['coordinate_head/layers_1/scale'])
182 | )
183 | model.coordinate_head[1].bias.data.copy_(
184 | torch.from_numpy(ckpt['coordinate_head/layers_1/bias'])
185 | )
186 | model.coordinate_head[3].weight.data.copy_(
187 | torch.from_numpy(ckpt['coordinate_head/layers_3/kernel'].T)
188 | )
189 | model.coordinate_head[3].bias.data.copy_(
190 | torch.from_numpy(ckpt['coordinate_head/layers_3/bias'])
191 | )
192 | model.coordinate_head[4].weight.data.copy_(
193 | torch.from_numpy(ckpt['coordinate_head/layers_4/scale'])
194 | )
195 | model.coordinate_head[4].bias.data.copy_(
196 | torch.from_numpy(ckpt['coordinate_head/layers_4/bias'])
197 | )
198 | model.coordinate_head[6].weight.data.copy_(
199 | torch.from_numpy(ckpt['coordinate_head/layers_6/kernel'].T)
200 | )
201 | model.coordinate_head[6].bias.data.copy_(
202 | torch.from_numpy(ckpt['coordinate_head/layers_6/bias'])
203 | )
204 | return model
205 |
--------------------------------------------------------------------------------
/tapnet/tapvid/README.md:
--------------------------------------------------------------------------------
1 | # TAP-Vid: A Benchmark for Tracking Any Point in a Video
2 |
3 | https://github.com/google-deepmind/tapnet/assets/4534987/ff5fa5e3-ed37-4480-ad39-42a1e2744d8b
4 |
5 | [TAP-Vid](https://tapvid.github.io) is a dataset of videos along with point tracks, either manually annotated or obtained from a simulator. The aim is to evaluate tracking of any trackable point on any solid physical surface. Algorithms receive a single query point on some frame, and must produce the rest of the track, i.e., including where that point has moved to (if visible), and whether it is visible, on every other frame. This requires point-level precision (unlike prior work on box and segment tracking) potentially on deformable surfaces (unlike structure from motion) over the long term (unlike optical flow) on potentially any object (i.e. class-agnostic, unlike prior class-specific keypoint tracking on humans).
6 |
7 | Our full benchmark incorporates 4 datasets: 30 videos from the [DAVIS val set](https://storage.googleapis.com/dm-tapnet/tapvid_davis.zip), 1000 videos from the [Kinetics val set](https://storage.googleapis.com/dm-tapnet/tapvid_kinetics.zip), 50 synthetic [Deepmind Robotics videos](https://storage.googleapis.com/dm-tapnet/tapvid_rgb_stacking.zip) for evaluation, and (almost infinite) point track ground truth on the large-scale synthetic [Kubric dataset](https://github.com/google-research/kubric/tree/main/challenges/point_tracking) for training.
8 |
9 | We also include a point tracking model, TAP-Net, with code to train it on the Kubric dataset. TAP-Net outperforms both optical flow and structure-from-motion methods on the TAP-Vid benchmark while achieving state-of-the-art performance on unsupervised human keypoint tracking on JHMDB, even though the model tracks points on clothes and skin rather than the joints as intended by the benchmark.
10 |
11 | ## Evaluating on TAP-Vid
12 |
13 | [`evaluation_datasets.py`](tapvid/evaluation_datasets.py) is intended to be a
14 | stand-alone, copy-and-pasteable reader and evaluator, which depends only
15 | on numpy and other basic tools. Tensorflow is required only for reading Kubric
16 | (which provides a tensorflow reader by default) as well as file operations,
17 | which should be straightforward to replace for systems without Tensorflow.
18 |
19 | For each dataset, there is a basic reader which will produce examples, dicts of
20 | numpy arrays containing the video, the query points, the target points, and the
21 | occlusion flag. Evaluation datasets may be used with one of two possible values
22 | for `query_mode`: `strided` (each trajectory is queried multiple times, with
23 | a fixed-length stride between queries) or `first` (each trajectory is queried
24 | once, with only the first visible point on the query). For details on outputs,
25 | see the documentation for `sample_queries_strided` and `sample_queries_first`.
26 |
27 | To compute metrics, use `compute_tapvid_metrics` in the same file. This
28 | computes results on each batch; the final metrics for the paper can be computed
29 | by simple averaging across all videos in the dataset. See the documentation for
30 | more details.
31 |
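For orientation, here is a minimal sketch of such an evaluation loop, assuming a reader that yields the dicts described above. The dict keys (`video`, `query_points`, `target_points`, `occluded`) and the `my_model` callable are illustrative assumptions; defer to `evaluation_datasets.py` for the authoritative names and signatures.

```python
# Minimal, hedged sketch of an evaluation loop (not verbatim repo code).
import numpy as np
from tapnet.tapvid import evaluation_datasets


def evaluate(dataset_iterator, my_model, query_mode='first'):
  """Averages TAP-Vid metrics over every video produced by a reader."""
  per_video = []
  for example in dataset_iterator:
    # Hypothetical model call: given the video and query points, predict
    # tracks and occlusion flags for every frame.
    pred_tracks, pred_occluded = my_model(
        example['video'], example['query_points'])
    metrics = evaluation_datasets.compute_tapvid_metrics(
        query_points=example['query_points'],
        gt_occluded=example['occluded'],
        gt_tracks=example['target_points'],
        pred_occluded=pred_occluded,
        pred_tracks=pred_tracks,
        query_mode=query_mode,
    )
    per_video.append(metrics)
  # The final benchmark numbers are a simple average across videos.
  return {k: np.mean([m[k] for m in per_video]) for k in per_video[0]}
```
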
32 | Note that the outputs for a single query point *should not depend on the other
33 | queries defined in the batch*: that is, the outputs should be the same whether
34 | the queries are passed one at a time or all at once. This is important because
35 | the other queries may leak information about how pixels are grouped and how they
36 | move. This property is not enforced in the current evaluation code, but
37 | algorithms which violate this principle should not be considered valid
38 | competitors on this benchmark.
39 |
40 | Our readers also supply videos resized at 256x256 resolution. If algorithms can handle it, we encourage using full-resolution videos instead; we anticipate that
41 | predictions on such videos would be scaled to match a 256x256 resolution
42 | before computing metrics. Such predictions would, however, be evaluated as a separate category: we don't consider them comparable to those produced from lower-resolution videos.
43 |
44 | ## Comparison of Tracking With and Without Optical Flow
45 |
46 | When annotating videos for the TAP-Vid benchmark, we use a track assist algorithm that interpolates between the sparse points that the annotators click, since requiring annotators to click every frame is prohibitively expensive. Specifically, we find tracks which minimize the discrepancy with the optical flow while still connecting the chosen points. Annotators then check the interpolations and repeat the annotation until they observe no drift.
47 |
48 | To validate that this is a better approach than a simple linear interpolation between clicked points, we annotated several DAVIS videos twice and [compare them side by side](https://storage.googleapis.com/dm-tapnet/content/flow_tracker.html), once using the flow-based interpolation, and again using a naive linear interpolation, which simply moves the point at a constant velocity between points.
49 |
50 | **Flow assist point annotation**: You can run this colab demo to see how point tracks are annotated with optical flow assistance.
51 |
52 | ## Downloading TAP-Vid-DAVIS and TAP-Vid-RGB-Stacking
53 |
54 | The data are contained in pickle files with download links: [TAP-Vid-DAVIS](https://storage.googleapis.com/dm-tapnet/tapvid_davis.zip) and [TAP-Vid-RGB-stacking](https://storage.googleapis.com/dm-tapnet/tapvid_rgb_stacking.zip).
55 |
56 | For DAVIS, the pickle file contains a dictionary, where each key is a DAVIS video name, and the values are the frames (4D uint8 tensor), the points (float32 tensor with 3 axes; the first is point id, the second is time, and the third is x/y), and the occlusions (bool tensor with 2 axes; the first is point id, the second is time). RGB-Stacking is the same format, except there is no video name, so it is a list of these structures rather than a dictionary.
57 |
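As a concrete illustration, the following minimal snippet loads the DAVIS pickle and prints the shapes for one video. The path is a placeholder, and the key names `video`, `points`, and `occluded` are assumed to match the fields described above.

```python
import pickle

# Placeholder path: wherever tapvid_davis.zip was extracted.
with open('tapvid_davis/tapvid_davis.pkl', 'rb') as f:
  davis = pickle.load(f)

video_name, example = next(iter(davis.items()))
print(video_name)
print(example['video'].shape)     # (num_frames, height, width, 3), uint8
print(example['points'].shape)    # (num_points, num_frames, 2), x/y coordinates
print(example['occluded'].shape)  # (num_points, num_frames), bool
```
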
58 | ## Downloading and Processing TAP-Vid-Kinetics
59 |
60 | The labels are contained in a csv file with download link: [TAP-Vid-Kinetics](https://storage.googleapis.com/dm-tapnet/tapvid_kinetics.zip).
61 |
62 | The videos are expected as the raw clips from [Kinetics700-2020](https://github.com/cvdfoundation/kinetics-dataset) validation set and stored in a local folder ``. The clips should be stored as MP4, following the name pattern `f'{youtube_id}_{start_time_sec:06}_{end_time_sec:06}.mp4'`, e.g. 'abcdefghijk_000010_000020.mp4'.
63 |
64 | Clips can be stored in any subfolder within the ``. The most common pattern is to store it as `//`.
65 |
66 | Once the validation clips have been downloaded, a pickle file containing all the information can be generated using the provided script:
67 |
68 | ```bash
69 | pip3 install -r requirements.txt
70 | python3 generate_tapvid.py \
71 | --input_csv_path= \
72 | --output_base_path= \
73 | --video_root_path= \
74 | --alsologtostderr
75 | ```
76 |
77 | ## Downloading RoboTAP
78 |
79 | The data are contained in pickle files with download links: [RoboTAP](https://storage.googleapis.com/dm-tapnet/robotap/robotap.zip).
80 |
81 | RoboTAP follows the same annotation format as TAP-Vid-DAVIS and TAP-Vid-RGB-stacking. The pickle file contains a dictionary, where each key is a video name, and the values are the frames (4D uint8 tensor), the points (float32 tensor with 3 axes; the first is point id, the second is time, and the third is x/y), and the occlusions (bool tensor with 2 axes; the first is point id, the second is time).
82 |
83 | ## Visualizing TAP-Vid and RoboTAP Dataset
84 |
85 | We also provide a script that generates an MP4 with the points painted on top of the frames. The script works with any of the pickle files. A random clip is chosen from all the available ones, and all of its point tracks are painted.
86 |
87 | ```bash
88 | pip3 install -r requirements.txt
89 | python3 visualize.py \
90 | --input_path= \
91 | --output_path= \
92 | --alsologtostderr
93 | ```
94 |
95 | ## Exemplar Visualization
96 | For visualization examples, we have the full [TAP-Vid-DAVIS](https://storage.googleapis.com/dm-tapnet/content/davis_ground_truth_v2.html) as well as 10 examples from the synthetic [TAP-Vid-Kubric](https://storage.googleapis.com/dm-tapnet/content/kubric_ground_truth.html) and [TAP-Vid-RGB-Stacking](https://storage.googleapis.com/dm-tapnet/content/rgb_stacking_ground_truth_v2.html) datasets.
97 |
98 | ## Training with TAP-Vid-Kubric
99 | TAP-Vid-DAVIS, TAP-Vid-RGB-stacking and TAP-Vid-Kinetics are mainly used for evaluation purposes. To train the model, we use [TAP-Vid-Kubric](https://github.com/google-research/kubric/tree/main/challenges/point_tracking).
100 |
--------------------------------------------------------------------------------
/tapnet/tapvid3d/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Tracking Any Point in 3D (TAPVid-3D)
3 |
4 | [Skanda Koppula](https://skoppula.com/), [Ignacio Rocco](https://www.irocco.info/), [Yi Yang](https://yangyi02.github.io/), [Joe Heyward](https://uk.linkedin.com/in/joe-heyward-71623595), [João Carreira](https://uk.linkedin.com/in/jo%C3%A3o-carreira-56238a7), [Andrew Zisserman](https://www.robots.ox.ac.uk/~az/), [Gabriel Brostow](http://www0.cs.ucl.ac.uk/staff/G.Brostow/), [Carl Doersch](http://www.carldoersch.com/)
5 |
6 | **[Google DeepMind](https://deepmind.google/)**, **[University College London](http://vis.cs.ucl.ac.uk/home/), [University of Oxford](https://www.robots.ox.ac.uk/~vgg/)**
7 |
8 | ### [`TAPVid-3D Website`](https://tapvid3d.github.io/) [`TAPVid-3D Paper`](https://arxiv.org/abs/2407.05921) [`Colab to Visualize Samples`](https://colab.research.google.com/github/google-deepmind/tapnet/blob/main/tapnet/tapvid3d/colab/load_and_visualize_tapvid3d_samples.ipynb)
9 |
10 | TAPVid-3D is a dataset and benchmark for evaluating the task of long-range
11 | Tracking Any Point in 3D (TAP-3D).
12 |
13 | The benchmark features 4,000+ real-world videos, along with their metric 3D
14 | position point trajectories and camera extrinsics. The dataset contains three
15 | different video sources and spans a variety of object types, motion patterns,
16 | and indoor and outdoor environments. This repository folder contains the code to
17 | download and generate these annotations, and to view dataset samples.
18 |
19 | **Note that in order to use the dataset, you must accept the licenses of the
20 | constituent original video data sources that you use!** In particular, you must
21 | adhere to the terms of service outlined in:
22 |
23 | 1. [Aria Digital Twin](https://www.projectaria.com/datasets/adt/license/)
24 | 2. [Waymo Open Dataset](https://waymo.com/open/terms/)
25 | 3. [Panoptic Studio](http://domedb.perception.cs.cmu.edu/)
26 |
27 | To measure performance on the TAP-3D task, we formulated metrics that extend the
28 | Jaccard-based metric used in 2D TAP to handle the complexities of ambiguous
29 | depth scales across models, occlusions, and multi-track spatio-temporal
30 | smoothness. Our implementation of these metrics (3D-AJ, APD, OA) can be found in
31 | `tapvid3d_metrics.py`.
32 |
33 | For more details, including the performance of multiple baseline 3D point
34 | tracking models on the benchmark, please see the paper:
35 | [TAPVid-3D: A Benchmark for Tracking Any Point in 3D](https://arxiv.org/abs/2407.05921).
36 |
37 | ### Getting Started: Installing
38 |
39 | (If you want to generate the dataset and want to use CUDA for running semantic segmentation, first install a CUDA-enabled PyTorch with `pip3 install torch==2.3.0 torchvision==0.18.0 --index-url https://download.pytorch.org/whl/cu118`)
40 |
41 | For generating the dataset, install the Tracking Any Point repo with:
42 |
43 | `pip install "git+https://github.com/google-deepmind/tapnet.git"[tapvid3d_eval,tapvid3d_generation]`
44 |
45 | If you only want to use the metrics, install with:
46 |
47 | `pip install "git+https://github.com/google-deepmind/tapnet.git"[tapvid3d_eval]`
48 |
49 | For a local editable installation, clone the repo and use:
50 |
51 | `pip install -e .[tapvid3d_eval,tapvid3d_generation]`
52 |
53 | or
54 |
55 | `pip install -e .[tapvid3d_eval]`.
56 |
57 | ### How to Download and Generate the Dataset
58 |
59 | Scripts to download and generate the annotations of each of the datasets can be
60 | found in the `annotation_generation` subdirectory, and can be run as:
61 |
62 | ```
63 | python3 -m tapnet.tapvid3d.annotation_generation.generate_adt --help
64 | python3 -m tapnet.tapvid3d.annotation_generation.generate_pstudio --help
65 | python3 -m tapnet.tapvid3d.annotation_generation.generate_drivetrack --help
66 | ```
67 |
68 | Because of license restrictions in distribution of the underlying source videos
69 | for Aria Digital Twin, you will need to accept their license terms and download
70 | the ADT dataset by first getting the cdn json file from
71 | [Project Aria Explorer](https://explorer.projectaria.com/?v=%22Aria+Digital+Twin%22),
72 | and downloading the ADT `main_vrs`, `main_groundtruth`, `segmentation` and `depth` files with:
73 |
74 | `aria_dataset_downloader --cdn_file /PATH_TO/Aria_Digital_Twin_1720774989.json -o /OUTPUT_PATH -d 0 6 7 8`
75 |
76 | To run all generation scripts, follow the instructions and run the commands in
77 | `generate_all.sh`. This will generate all the `*.npz` files, and place them into
78 | a new folder, `tapvid3d_dataset`. To generate only the `minival` split,
79 | replace `--split=all` in this script with `--split=minival`.
80 |
81 | To test that things are working before running the full generation, you can run
82 | `./generate_all.sh --debug`. This generates exactly one `*.npz` video annotation
83 | per data source.
84 |
85 | ### Data format
86 |
87 | Once the benchmark files are fully generated, you will have roughly 4,500
88 | `*.npz` files, each one containing exactly one dataset video and its annotations.
89 | Each `*.npz` file contains the following fields (a loading sketch follows the list):
90 |
91 | * `images_jpeg_bytes`: tensor of shape [`# of frames`, `height`, `width`, 3],
92 | each frame stored as JPEG bytes that must be decoded
93 |
94 | * `intrinsics`: (fx, fy, cx, cy) camera intrinsics of the video
95 |
96 | *   `tracks_XYZ`: tensor of shape (`# of frames`, `# of point tracks`, 3),
97 |     representing the 3D point trajectories; the last dimension is the `(x, y,
98 |     z)` point position in meters.
99 |
100 | * `visibility`: tensor of shape (`# of frames`, `# of point tracks`),
101 | representing the visibility of each point along its trajectory
102 |
103 | * `queries_xyt`: tensor of shape (`# of point tracks`, 3), representing the
104 | query point used in the benchmark as the initial given point to track. The
105 | last dimension is given in `(x, y, t)`, where `x,y` is the pixel location of
106 | the query point and `t` is the query frame.
107 |
108 | * `extrinsics_w2c`: for videos with a moving camera (videos from the Waymo
109 | Open Dataset and ADT), we provide camera extrinsics in the form of a world
110 | to camera transform, a 4x4 matrix consisting of the rotation matrix and
111 | translation matrix. This field is NOT present in the `pstudio` *.npz files,
112 | because in Panoptic Studio, the camera is static.
113 |
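A minimal loading sketch follows. The path is a placeholder; note that the generation scripts in this repository store the camera intrinsics under the key `fx_fy_cx_cy` and the trajectories under `tracks_XYZ`.

```python
import io

import numpy as np
from PIL import Image

# Placeholder path to one generated benchmark file.
example = np.load('tapvid3d_dataset/adt/example_chunk_0.npz', allow_pickle=True)

# Decode the per-frame JPEG bytes into a single video tensor.
frames = np.stack(
    [np.array(Image.open(io.BytesIO(b))) for b in example['images_jpeg_bytes']]
)
fx, fy, cx, cy = example['fx_fy_cx_cy']  # Camera intrinsics.
tracks_xyz = example['tracks_XYZ']       # (num_frames, num_tracks, 3), meters.
visibility = example['visibility']       # (num_frames, num_tracks), bool.
queries_xyt = example['queries_xyt']     # (num_tracks, 3), (x, y, t).
print(frames.shape, tracks_xyz.shape, visibility.shape, queries_xyt.shape)
```
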
114 | ### Getting Started: Evaluating Your Own 3D Tracking Model
115 |
116 | To get started using this dataset to benchmark your own 3D tracking model, check
117 | out `evaluate_model.py`. This script reads the ground-truth tracks and the
118 | predicted tracks, and computes the TAPVid-3D evaluation metrics for those
119 | predictions. An example of running this is provided in `run_evaluate_model.sh`,
120 | which computes metrics for our precomputed outputs from
121 | [SpatialTracker](https://henry123-boy.github.io/SpaTracker/) on the `minival`
122 | split, for only the `DriveTrack` subset. `evaluate_model.py`
123 | takes as input a folder with subdirectories (corresponding to up to 3 data
124 | sources [among PStudio/DriveTrack/ADT]). Each
125 | subdirectory should contain a list of `*.npz` files
126 | (one per TAPVid-3D video, each containing
127 | `tracks_XYZ` and `visibility` in the format above).
128 | Running this will return all TAP-3D metrics
129 | described in the paper (these are implemented in `tapvid3d_metrics.py`).
130 |
131 | ### Visualizing Samples in Colab
132 |
133 | You can view samples of the dataset using a public Colab demo:
134 | [`Colab to Visualize Samples`](https://colab.research.google.com/github/google-deepmind/tapnet/blob/main/tapnet/tapvid3d/colab/load_and_visualize_tapvid3d_samples.ipynb).
135 |
136 | ## Citing this Work
137 |
138 | If you find this work useful, consider citing the manuscript:
139 |
140 | ```
141 | @inproceedings{koppula2024tapvid3d,
142 | title={TAPVid-3D: A Benchmark for Tracking Any Point in 3D},
143 | author={Skanda Koppula and Ignacio Rocco and Yi Yang and Joe Heyward and João Carreira and Andrew Zisserman and Gabriel Brostow and Carl Doersch},
144 | booktitle={NeurIPS},
145 | year={2024},
146 | }
147 | ```
148 |
149 | ## License and Disclaimer
150 |
151 | Copyright 2024 Google LLC
152 |
153 | All software here is licensed under the Apache License, Version 2.0 (Apache
154 | 2.0); you may not use this file except in compliance with the Apache 2.0
155 | license. You may obtain a copy of the Apache 2.0 license at:
156 |
157 | https://www.apache.org/licenses/LICENSE-2.0
158 |
159 | All other materials, excepting the source videos and artifacts used to generate
160 | annotations, are licensed under the Creative Commons Attribution 4.0
161 | International License (CC-BY). You may obtain a copy of the CC-BY license at:
162 | https://creativecommons.org/licenses/by/4.0/legalcode
163 |
164 | The Aria Digital Twin, Waymo Open Dataset, and Panoptic Studio datasets all
165 | have individual usage terms and licenses that must be adhered to when using any
166 | software or materials from here.
167 |
168 | Unless required by applicable law or agreed to in writing, all software and
169 | materials distributed here under the Apache 2.0 or CC-BY licenses are
170 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
171 | either express or implied. See the licenses for the specific language governing
172 | permissions and limitations under those licenses.
173 |
174 | This is not an official Google product.
175 |
--------------------------------------------------------------------------------
/tapnet/models/tsm_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utils functions for TSM."""
17 |
18 | from typing import Tuple
19 |
20 | import chex
21 | import jax
22 | import jax.numpy as jnp
23 |
24 |
25 | def prepare_inputs(inputs: chex.Array) -> Tuple[jnp.ndarray, str, int]:
26 | """Deduces input mode for TSM."""
27 | # Deduce if we run on TPU based on input shape.
28 | if len(inputs.shape) == 5:
29 | # Input is given in the standard [B, T, H, W, 3] format.
30 | tsm_mode = 'gpu'
31 | num_frames = inputs.shape[1]
32 | inputs = jnp.reshape(inputs, [-1] + list(inputs.shape[2:]))
33 | else:
34 | # Input is given in the [T * B, H, W, 3] format.
35 | tsm_mode = 'tpu'
36 | num_frames = None
37 | return inputs, tsm_mode, num_frames
38 |
39 |
40 | def prepare_outputs(
41 | outputs: chex.Array,
42 | tsm_mode: str,
43 | num_frames: int,
44 | reduce_mean: bool = True,
45 | ) -> jnp.ndarray:
46 | """Processes output of TSM to undo the merging of batch and time."""
47 | # Get the shape without the batch/time dimension (for TSM batch and time are
48 | # merged in the first dimension).
49 | shape_no_bt = list(outputs.shape[1:])
50 | if tsm_mode == 'tpu':
51 | # Outputs are of the shape [num_frames * B, ..., n_channels]
52 | outputs = jnp.reshape(outputs, [num_frames, -1] + shape_no_bt)
53 | if reduce_mean:
54 | # We average over time and space.
55 | outputs = jnp.mean(
56 | outputs, axis=[0] + list(range(2,
57 | len(shape_no_bt) + 1)))
58 | else:
59 | outputs = jnp.transpose(
60 | outputs, axes=[1, 0] + list(range(2,
61 | len(shape_no_bt) + 2)))
62 | elif tsm_mode == 'gpu':
63 | # Outputs are of the shape [B * num_frames, ..., n_channels].
64 | outputs = jnp.reshape(outputs, [-1, num_frames] + shape_no_bt)
65 | if reduce_mean:
66 | outputs = jnp.mean(
67 | outputs, axis=[1] + list(range(2,
68 | len(shape_no_bt) + 1)))
69 | elif tsm_mode.startswith('deflated'):
70 | # In deflated mode, outputs are already in the right format.
71 | pass
72 | else:
73 | raise ValueError('`tsm_mode` should be \'tpu\' or \'gpu\' or '
74 | f'\'deflated_0.x\' ({tsm_mode} given)')
75 | return outputs # pytype: disable=bad-return-type # numpy-scalars
76 |
77 |
78 | def apply_temporal_shift(
79 | x: chex.Array,
80 | tsm_mode: str,
81 | num_frames: int,
82 | channel_shift_fraction: float = 0.125,
83 | ) -> jnp.ndarray:
84 | """Performs a temporal shift: https://arxiv.org/abs/1811.08383 with mode."""
85 | if tsm_mode == 'tpu':
86 | outputs = temporal_shift_tpu(x, num_frames, channel_shift_fraction)
87 | elif tsm_mode == 'gpu':
88 | outputs = temporal_shift_gpu(x, num_frames, channel_shift_fraction)
89 | elif tsm_mode.startswith('deflated'):
90 | alpha = float(tsm_mode.split('_')[1])
91 | outputs = temporal_shift_image_mode(x, channel_shift_fraction, alpha)
92 | else:
93 | raise ValueError('`tsm_mode` should be \'tpu\' or \'gpu\' or '
94 | f'\'deflated_0.x\' ({tsm_mode} given)')
95 | return outputs
96 |
97 |
98 | def temporal_shift_image_mode(x, channel_shift_fraction=0.125, alpha=0.3):
99 |   """Temporal shift applied to a single image (to emulate a fixed video)."""
100 | # B, H, W, C = batch_size, im_height, im_width, channels.
101 | # Input is (B, H, W, C).
102 | orig_shp = tuple(x.shape)
103 | n_channels = orig_shp[-1]
104 | n_shift = int(n_channels * channel_shift_fraction)
105 | # Alpha emulates the effect of the padding when using a single frame.
106 | shifted_backward = alpha * x[:, :, :, -n_shift:]
107 | shifted_forward = alpha * x[:, :, :, :n_shift]
108 | no_shift = x[:, :, :, n_shift:-n_shift]
109 | shifted_x = jnp.concatenate([shifted_backward, no_shift, shifted_forward],
110 | axis=3)
111 | return shifted_x
112 |
113 |
114 | def temporal_shift_gpu(
115 | x: chex.Array,
116 | num_frames: int,
117 | channel_shift_fraction: float = 0.125,
118 | ) -> jnp.ndarray:
119 | """Performs a temporal shift: https://arxiv.org/abs/1811.08383."""
120 | # B, T, H, W, C = batch_size, num_frames, im_height, im_width, channels.
121 | # Input is (B * T, H, W, C).
122 | orig_shp = tuple(x.shape)
123 | reshaped_x = jnp.reshape(x, (-1, num_frames) + orig_shp[1:])
124 | n_channels = orig_shp[-1]
125 | n_shift = int(n_channels * channel_shift_fraction)
126 |
127 | new_shp = tuple(reshaped_x.shape)
128 |
129 | # shifted_backward = reshaped_x[:, 1:, :, :, -n_shift:].
130 | shifted_backward = jax.lax.slice(
131 | reshaped_x, (0, 1, 0, 0, new_shp[4] - n_shift),
132 | (new_shp[0], new_shp[1], new_shp[2], new_shp[3], new_shp[4]))
133 | shifted_backward_padding = ((0, 0), (0, 1), (0, 0), (0, 0), (0, 0))
134 | shifted_backward = jnp.pad(shifted_backward, shifted_backward_padding)
135 |
136 | # shifted_forward = reshaped_x[:, :-1, :, :, :n_shift].
137 | shifted_forward = jax.lax.slice(
138 | reshaped_x, (0, 0, 0, 0, 0),
139 | (new_shp[0], new_shp[1] - 1, new_shp[2], new_shp[3], n_shift))
140 | shifted_forward_padding = ((0, 0), (1, 0), (0, 0), (0, 0), (0, 0))
141 | shifted_forward = jnp.pad(shifted_forward, shifted_forward_padding)
142 |
143 | no_shift = reshaped_x[:, :, :, :, n_shift:-n_shift]
144 | shifted_x = jnp.concatenate([shifted_backward, no_shift, shifted_forward],
145 | axis=4)
146 | return jnp.reshape(shifted_x, (-1,) + orig_shp[1:])
147 |
148 |
149 | def temporal_shift_tpu(
150 | x: chex.Array,
151 | num_frames: int,
152 | channel_shift_fraction: float = 0.125,
153 | ) -> jnp.ndarray:
154 | """Performs a temporal shift: https://arxiv.org/abs/1811.08383.
155 |
156 |   TPU-optimized version of TSM. A reshape is avoided by laying out the images
157 |   as [T * B, :] so that frames from the same time step across videos are
158 |   contiguous in memory. Finally, to avoid a concatenate that would prevent
159 |   some fusions from happening, we simply sum masked versions of the features.
160 | Args:
161 | x: Input expected to be [T * B, H, W, C] (where the batch has been reshaped
162 | from a time major version of the input).
163 | num_frames: number of frames T per video.
164 | channel_shift_fraction: fraction of the channel to shift forward and
165 | backward.
166 |
167 | Returns:
168 | The temporal shifted version of x.
169 | """
170 | # B, T, H, W, C = batch_size, num_frames, im_height, im_width, channels.
171 | # Input is (T * B, H, W, C).
172 | original_dtype = x.dtype
173 | original_shape = list(x.shape)
174 |
175 | batch_size = int(original_shape[0] / num_frames)
176 | n_channels = int(original_shape[-1])
177 | n_shift = int(n_channels * channel_shift_fraction)
178 |
179 | # Cast to bfloat16.
180 | x = x.astype(jnp.bfloat16)
181 |
182 | # For the following, assume that x has 3 channels [x1, x2, x3] and n_shift=1.
183 | # Shift backward, we first pad by zeros [x1, x2, x3, 0, 0].
184 | orig_shp = list(x.shape)
185 |
186 | shifted_backward_padding = ((0, batch_size, 0), (0, 0, 0), (0, 0, 0),
187 | (0, n_channels - n_shift, 0))
188 | x_backward_padding = jax.lax.pad(
189 | x,
190 | padding_value=jnp.bfloat16(0.),
191 | padding_config=shifted_backward_padding)
192 | # The following shift gets to [x3^+1, 0, 0] (where +1 means from the future).
193 | shifted_backward = jax.lax.slice(x_backward_padding,
194 | (batch_size, 0, 0, n_channels - n_shift),
195 | (orig_shp[0] + batch_size, orig_shp[1],
196 | orig_shp[2], 2 * n_channels - n_shift))
197 | # Shift forward, we first pad by zeros [0, 0, x1, x2, x3].
198 | shifted_forward_padding = ((batch_size, 0, 0), (0, 0, 0), (0, 0, 0),
199 | (n_channels - n_shift, 0, 0))
200 | x_forward_padding = jax.lax.pad(
201 | x, padding_value=jnp.bfloat16(0.), padding_config=shifted_forward_padding)
202 | # The following shift gets to [0, 0, x1^-1] (where -1 means from the past).
203 | shifted_forward = jax.lax.slice(
204 | x_forward_padding, (0, 0, 0, 0),
205 | (orig_shp[0], orig_shp[1], orig_shp[2], n_channels))
206 | # No shift is in the middle, this gets [0, x2, 0].
207 | mask_noshift = (jnp.reshape((jnp.arange(n_channels) >= n_shift) &
208 | (jnp.arange(n_channels) < n_channels - n_shift),
209 | (1, 1, 1, -1))).astype(jnp.bfloat16)
210 | no_shift = mask_noshift * x
211 | # By summing everything together, we end up with [x3^+1, x2, x1^-1].
212 | # Note: channels have been reordered but that doesn't matter for the model.
213 | shifted_x = shifted_backward + shifted_forward + no_shift
214 |
215 | return shifted_x.astype(original_dtype)
216 |
--------------------------------------------------------------------------------
/tapnet/utils/experiment_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Logging and other experiment utilities."""
17 |
18 | import os
19 | from typing import Mapping, Optional
20 |
21 | import jax
22 | from jaxline import utils
23 | from ml_collections import config_dict
24 | import numpy as np
25 | import optax
26 | import tensorflow as tf
27 |
28 | from tapnet.utils import optimizers
29 |
30 |
31 | def get_lr_schedule(
32 | total_steps: int,
33 | optimizer_config: config_dict.ConfigDict,
34 | ) -> optax.Schedule:
35 | """Build the LR schedule function."""
36 | base_lr = optimizer_config.base_lr
37 |
38 | schedule_type = optimizer_config.schedule_type
39 | if schedule_type == 'cosine':
40 |     warmup_steps = optimizer_config.cosine_decay_kwargs.warmup_steps
41 | # Batch scale the other lr values as well:
42 | init_value = optimizer_config.cosine_decay_kwargs.init_value
43 | end_value = optimizer_config.cosine_decay_kwargs.end_value
44 |
45 | schedule_fn = optax.warmup_cosine_decay_schedule(
46 | init_value=init_value,
47 | peak_value=base_lr,
48 | warmup_steps=warmup_steps,
49 | decay_steps=total_steps,
50 | end_value=end_value)
51 | elif schedule_type == 'constant_cosine':
52 | # Convert end_value to alpha, used by cosine_decay_schedule.
53 | alpha = optimizer_config.constant_cosine_decay_kwargs.end_value / base_lr
54 |
55 | # Number of steps spent in constant phase.
56 | constant_steps = int(
57 | optimizer_config.constant_cosine_decay_kwargs.constant_fraction *
58 | total_steps)
59 | decay_steps = total_steps - constant_steps
60 |
61 | constant_phase = optax.constant_schedule(value=base_lr)
62 | decay_phase = optax.cosine_decay_schedule(
63 | init_value=base_lr, decay_steps=decay_steps, alpha=alpha)
64 | schedule_fn = optax.join_schedules(
65 | schedules=[constant_phase, decay_phase], boundaries=[constant_steps])
66 | else:
67 | raise ValueError(f'Unknown learning rate schedule: {schedule_type}')
68 |
69 | return schedule_fn
70 |
71 |
72 | def make_optimizer(
73 | optimizer_config: config_dict.ConfigDict,
74 | lr_schedule: optax.Schedule,
75 | ) -> optax.GradientTransformation:
76 | """Construct the optax optimizer with given LR schedule."""
77 |   # Weight-decay learned position embeddings by default; exclude only biases.
78 | weight_decay_exclude_names = ['b']
79 |
80 | optax_chain = []
81 | if optimizer_config.max_norm > 0:
82 | optax_chain.append(optax.clip_by_global_norm(optimizer_config.max_norm))
83 |
84 | if optimizer_config.optimizer == 'sgd':
85 | optax_chain.extend([
86 | optax.trace(**optimizer_config.sgd_kwargs),
87 | optimizers.add_weight_decay(
88 | optimizer_config.weight_decay,
89 | exclude_names=weight_decay_exclude_names)
90 | ])
91 | elif optimizer_config.optimizer == 'adam':
92 | optax_chain.extend([
93 | optax.scale_by_adam(**optimizer_config.adam_kwargs),
94 | optimizers.add_weight_decay(
95 | optimizer_config.weight_decay,
96 | exclude_names=weight_decay_exclude_names)
97 | ])
98 | else:
99 | raise ValueError(f'Undefined optimizer {optimizer_config.optimizer}')
100 | optax_chain.extend([
101 | optax.scale_by_schedule(lr_schedule),
102 | optax.scale(-1),
103 | ])
104 |
105 | optimizer = optax.chain(*optax_chain)
106 | optimizer = optax.apply_if_finite(optimizer, max_consecutive_errors=5)
107 | return optimizer
108 |
109 |
110 | class NumpyFileCheckpointer(utils.Checkpointer):
111 | """A Jaxline checkpointer which saves to numpy files on disk."""
112 |
113 | def __init__(self, config: config_dict.ConfigDict, mode: str):
114 | self._checkpoint_file = os.path.join(config.checkpoint_dir,
115 | 'checkpoint.npy')
116 | self._checkpoint_state = config_dict.ConfigDict()
117 | del mode
118 |
119 | def get_experiment_state(self, ckpt_series: str) -> config_dict.ConfigDict:
120 | """Returns the experiment state for a given checkpoint series."""
121 | if ckpt_series != 'latest':
122 | raise ValueError('multiple checkpoint series are not supported')
123 | return self._checkpoint_state
124 |
125 | def save(self, ckpt_series: str) -> None:
126 | """Saves the checkpoint."""
127 | if ckpt_series != 'latest':
128 | raise ValueError('multiple checkpoint series are not supported')
129 | exp_mod = self._checkpoint_state.experiment_module
130 | global_step = self._checkpoint_state.global_step
131 | f_np = lambda x: np.array(jax.device_get(utils.get_first(x)))
132 | to_save = {}
133 | for attr, name in exp_mod.CHECKPOINT_ATTRS.items():
134 | if name == 'global_step':
135 | raise ValueError(
136 | 'global_step attribute would overwrite jaxline global step')
137 | np_params = jax.tree_util.tree_map(f_np, getattr(exp_mod, attr))
138 | to_save[name] = np_params
139 | to_save['global_step'] = global_step
140 |
141 | with tf.io.gfile.GFile(self._checkpoint_file + '_tmp', 'wb') as fp:
142 | np.save(fp, to_save)
143 | tf.io.gfile.rename(
144 | self._checkpoint_file + '_tmp',
145 | self._checkpoint_file,
146 | overwrite=True,
147 | )
148 |
149 | def can_be_restored(self, ckpt_series: str) -> bool:
150 | """Returns whether or not a given checkpoint series can be restored."""
151 | if ckpt_series != 'latest':
152 | raise ValueError('multiple checkpoint series are not supported')
153 | return tf.io.gfile.exists(self._checkpoint_file)
154 |
155 | def restore(self, ckpt_series: str) -> None:
156 | """Restores the checkpoint."""
157 | experiment_state = self.get_experiment_state(ckpt_series)
158 | with tf.io.gfile.GFile(self._checkpoint_file, 'rb') as fp:
159 | ckpt_state = np.load(fp, allow_pickle=True).item()
160 | experiment_state.global_step = int(ckpt_state['global_step'])
161 | exp_mod = experiment_state.experiment_module
162 | for attr, name in exp_mod.CHECKPOINT_ATTRS.items():
163 | setattr(exp_mod, attr, utils.bcast_local_devices(ckpt_state[name]))
164 |
165 | def restore_path(self, ckpt_series: str) -> Optional[str]:
166 | """Returns the restore path for the checkpoint, or None."""
167 | if not self.can_be_restored(ckpt_series):
168 | return None
169 | return self._checkpoint_file
170 |
171 | def wait_for_checkpointing_to_finish(self) -> None:
172 | """Waits for any async checkpointing to complete."""
173 |
174 | @classmethod
175 | def create(
176 | cls,
177 | config: config_dict.ConfigDict,
178 | mode: str,
179 | ) -> utils.Checkpointer:
180 | return cls(config, mode)
181 |
182 |
183 | def default_color_augmentation_fn(
184 | inputs: Mapping[str, tf.Tensor],
185 | zero_centering_image: bool = True,
186 | prob_color_augment: float = 0.8,
187 | prob_color_drop: float = 0.2,
188 | ) -> Mapping[str, tf.Tensor]:
189 | """Standard color augmentation for videos.
190 |
191 | Args:
192 | inputs: A DatasetElement containing the item 'video' which will have
193 | augmentations applied to it.
194 | zero_centering_image: Whether to zero center the image.
195 | prob_color_augment: Probability of applying color augmentation.
196 | prob_color_drop: Probability of applying color drop.
197 |
198 | Returns:
199 | A DatasetElement with all the same data as the original, except that
200 | the video has augmentations applied.
201 | """
202 | frames = inputs['video']
203 | if frames.dtype != tf.float32:
204 | raise ValueError('`frames` should be in float32.')
205 |
206 | def color_augment(video: tf.Tensor) -> tf.Tensor:
207 | """Do standard color augmentations."""
208 | # Note the same augmentation will be applied to all frames of the video.
209 | if zero_centering_image:
210 | video = 0.5 * (video + 1.0)
211 | video = tf.image.random_brightness(video, max_delta=32. / 255.)
212 | video = tf.image.random_saturation(video, lower=0.6, upper=1.4)
213 | video = tf.image.random_contrast(video, lower=0.6, upper=1.4)
214 | video = tf.image.random_hue(video, max_delta=0.2)
215 | video = tf.clip_by_value(video, 0.0, 1.0)
216 | if zero_centering_image:
217 | video = 2 * (video-0.5)
218 | return video
219 |
220 | def color_drop(video: tf.Tensor) -> tf.Tensor:
221 | video = tf.image.rgb_to_grayscale(video)
222 | video = tf.tile(video, [1, 1, 1, 3])
223 | return video
224 |
225 |   # Applies color augmentation with probability `prob_color_augment`.
226 | coin_toss_color_augment = tf.random.uniform(
227 | [], minval=0, maxval=1, dtype=tf.float32)
228 | frames = tf.cond(
229 | pred=tf.less(coin_toss_color_augment,
230 | tf.cast(prob_color_augment, tf.float32)),
231 | true_fn=lambda: color_augment(frames),
232 | false_fn=lambda: frames)
233 |
234 |   # Applies color drop (grayscale) with probability `prob_color_drop`.
235 | coin_toss_color_drop = tf.random.uniform(
236 | [], minval=0, maxval=1, dtype=tf.float32)
237 | frames = tf.cond(
238 | pred=tf.less(coin_toss_color_drop, tf.cast(prob_color_drop, tf.float32)),
239 | true_fn=lambda: color_drop(frames),
240 | false_fn=lambda: frames)
241 | result = {**inputs}
242 | result['video'] = frames
243 |
244 | return result
245 |
246 |
247 | def add_default_data_augmentation(ds: tf.data.Dataset) -> tf.data.Dataset:
248 | return ds.map(
249 | default_color_augmentation_fn, num_parallel_calls=tf.data.AUTOTUNE)
250 |
--------------------------------------------------------------------------------
/tapnet/models/tapnet_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """TAP-Net model definition."""
17 |
18 | import functools
19 | from typing import Optional, Mapping, Tuple
20 |
21 | import chex
22 | from einshape import jax_einshape as einshape
23 | import haiku as hk
24 | import jax
25 | import jax.numpy as jnp
26 |
27 | from tapnet.models import tsm_resnet
28 | from tapnet.utils import model_utils
29 | from tapnet.utils import transforms
30 |
31 |
32 | def create_batch_norm(
33 | x: chex.Array, is_training: bool, cross_replica_axis: Optional[str]
34 | ) -> chex.Array:
35 | """Function to allow TSM-ResNet to create batch norm layers."""
36 | return hk.BatchNorm(
37 | create_scale=True,
38 | create_offset=True,
39 | decay_rate=0.9,
40 | cross_replica_axis=cross_replica_axis,
41 | )(x, is_training)
42 |
43 |
44 | class TAPNet(hk.Module):
45 | """Joint model for performing flow-based tasks."""
46 |
47 | def __init__(
48 | self,
49 | feature_grid_stride: int = 8,
50 | num_heads: int = 1,
51 | cross_replica_axis: Optional[str] = 'i',
52 | num_frames: int = 24,
53 | ):
54 | """Initialize the model and provide kwargs for the various components.
55 |
56 | Args:
57 | feature_grid_stride: Stride to extract features. For TSM-ResNet,
58 | supported values are 8 (default), 16, and 32.
59 | num_heads: Number of heads in the cost volume.
60 | cross_replica_axis: Which cross replica axis to use for the batch norm.
61 | num_frames: Number of frames passed into TSM-ResNet.
62 | """
63 |
64 | super().__init__()
65 |
66 | self.feature_grid_stride = feature_grid_stride
67 | self.num_heads = num_heads
68 | self.softmax_temperature = 10.0
69 |
70 | self.tsm_resnet = tsm_resnet.TSMResNetV2(
71 | normalize_fn=functools.partial(
72 | create_batch_norm, cross_replica_axis=cross_replica_axis
73 | ),
74 | num_frames=num_frames,
75 | channel_shift_fraction=[0.125, 0.125, 0.0, 0.0],
76 | name='tsm_resnet_video',
77 | )
78 |
79 | self.cost_volume_track_mods = {
80 | 'hid1':
81 | hk.Conv3D(
82 | 16,
83 | [1, 3, 3],
84 | name='cost_volume_regression_1',
85 | stride=[1, 1, 1],
86 | ),
87 | 'hid2':
88 | hk.Conv3D(
89 | 1,
90 | [1, 3, 3],
91 | name='cost_volume_regression_2',
92 | stride=[1, 1, 1],
93 | ),
94 | 'hid3':
95 | hk.Conv3D(
96 | 32,
97 | [1, 3, 3],
98 | name='cost_volume_occlusion_1',
99 | stride=[1, 2, 2],
100 | ),
101 | 'hid4':
102 | hk.Linear(16, name='cost_volume_occlusion_2'),
103 | 'occ_out':
104 | hk.Linear(1, name='occlusion_out'),
105 | 'regression_hid':
106 | hk.Linear(128, name='regression_hid'),
107 | 'regression_out':
108 | hk.Linear(2, name='regression_out'),
109 | }
110 |
111 | def tracks_from_cost_volume(
112 | self,
113 | interp_feature_heads: chex.Array,
114 | feature_grid_heads: chex.Array,
115 | query_points: Optional[chex.Array],
116 | im_shp: Optional[chex.Shape] = None,
117 | ) -> Tuple[chex.Array, chex.Array]:
118 | """Converts features into tracks by computing a cost volume.
119 |
120 | The computed cost volume will have shape
121 | [batch, num_queries, time, height, width, num_heads], which can be very
122 | memory intensive.
123 |
124 | Args:
125 | interp_feature_heads: A tensor of features for each query point, of shape
126 | [batch, num_queries, channels, heads].
127 | feature_grid_heads: A tensor of features for the video, of shape [batch,
128 | time, height, width, channels, heads].
129 | query_points: When computing tracks, we assume these points are given as
130 | ground truth and we reproduce them exactly. This is a set of points of
131 | shape [batch, num_points, 3], where each entry is [t, y, x] in frame/
132 | raster coordinates.
133 |       im_shp: The shape of the original video, i.e., [batch, num_frames,
134 |         height, width, 3].
135 |
136 | Returns:
137 | A 2-tuple of the inferred points (of shape
138 | [batch, num_points, num_frames, 2] where each point is [x, y]) and
139 | inferred occlusion (of shape [batch, num_points, num_frames], where
140 | each is a logit where higher means occluded)
141 | """
142 |
143 | mods = self.cost_volume_track_mods
144 | # Note: time is first axis to prevent the TPU from padding
145 | cost_volume = jnp.einsum(
146 | 'bncd,bthwcd->tbnhwd',
147 | interp_feature_heads,
148 | feature_grid_heads,
149 | )
150 | shape = cost_volume.shape
151 | cost_volume = einshape('tbnhwd->t(bn)hwd', cost_volume)
152 |
153 | occlusion = mods['hid1'](cost_volume)
154 | occlusion = jax.nn.relu(occlusion)
155 |
156 | pos = mods['hid2'](occlusion)
157 | pos = jax.nn.softmax(pos * self.softmax_temperature, axis=(-2, -3))
158 | pos = einshape('t(bn)hw1->bnthw', pos, n=shape[2])
159 | points = model_utils.heatmaps_to_points(
160 | pos, im_shp, query_points=query_points
161 | )
162 |
163 | occlusion = mods['hid3'](occlusion)
164 | occlusion = jnp.mean(occlusion, axis=(-2, -3))
165 | occlusion = mods['hid4'](occlusion)
166 | occlusion = jax.nn.relu(occlusion)
167 | occlusion = mods['occ_out'](occlusion)
168 | occlusion = jnp.transpose(occlusion, (1, 0, 2))
169 | assert occlusion.shape[1] == shape[0]
170 | occlusion = jnp.reshape(occlusion, (shape[1], shape[2], shape[0]))
171 | return points, occlusion
172 |
173 | def __call__(
174 | self,
175 | video: chex.Array,
176 | is_training: bool,
177 | query_points: chex.Array,
178 | compute_regression: bool = True,
179 | query_chunk_size: Optional[int] = None,
180 | get_query_feats: bool = False,
181 | feature_grid: Optional[chex.Array] = None,
182 | ) -> Mapping[str, chex.Array]:
183 | """Runs a forward pass of the model.
184 |
185 | Args:
186 | video: A 4-D or 5-D tensor representing a batch of sequences of images. In
187 | the 4-D case, we assume the entire batch has been concatenated along the
188 | batch dimension, one sequence after the other. This can speed up
189 | inference on the TPU and save memory.
190 | is_training: Whether we are training.
191 | query_points: The query points for which we compute tracks.
192 | compute_regression: if True, compute tracks using cost volumes; otherwise
193 | simply compute features (required for the baseline)
194 | query_chunk_size: When computing cost volumes, break the queries into
195 | chunks of this size to save memory.
196 | get_query_feats: If True, also return the features for each query obtained
197 | using bilinear interpolation from the feature grid
198 | feature_grid: If specified, use this as the feature grid rather than
199 | computing it from the pixels.
200 |
201 | Returns:
202 | A dict of outputs, including:
203 | feature_grid: a TSM-ResNet feature grid of shape
204 | [batch, num_frames, height//stride, width//stride, channels]
205 | query_feats (optional): A feature for each query point, of size
206 | [batch, num_queries, channels]
207 | occlusion: Occlusion logits, of shape [batch, num_queries, num_frames]
208 | where higher indicates more likely to be occluded.
209 | tracks: predicted point locations, of shape
210 | [batch, num_queries, num_frames, 2], where each point is [x, y]
211 | in raster coordinates
212 | """
213 | num_frames = None
214 | if feature_grid is None:
215 | latent = self.tsm_resnet(
216 | video,
217 | is_training=is_training,
218 | output_stride=self.feature_grid_stride,
219 | out_num_frames=num_frames,
220 | final_endpoint='tsm_resnet_unit_2',
221 | )
222 |
223 | feature_grid = latent / jnp.sqrt(
224 | jnp.maximum(
225 | jnp.sum(jnp.square(latent), axis=-1, keepdims=True),
226 | 1e-12,
227 | ))
228 |
229 | shape = video.shape
230 | if num_frames is not None and len(shape) < 5:
231 | shape = (shape[0] // num_frames, num_frames) + shape[1:]
232 |
233 |     # shape is [batch_size, time, height, width, channels]; the conversion
234 |     # needs grid sizes as (time, height, width) to match the 'tyx' format.
235 | position_in_grid = transforms.convert_grid_coordinates(
236 | query_points,
237 | shape[1:4],
238 | feature_grid.shape[1:4],
239 | coordinate_format='tyx',
240 | )
241 | interp_features = jax.vmap(
242 | jax.vmap(
243 | model_utils.interp,
244 | in_axes=(3, None),
245 | out_axes=1,
246 | )
247 | )(feature_grid, position_in_grid)
248 | feature_grid_heads = einshape(
249 | 'bthw(cd)->bthwcd', feature_grid, d=self.num_heads
250 | )
251 | interp_features_heads = einshape(
252 | 'bn(cd)->bncd',
253 | interp_features,
254 | d=self.num_heads,
255 | )
256 | out = {'feature_grid': feature_grid}
257 | if get_query_feats:
258 | out['query_feats'] = interp_features
259 |
260 | if compute_regression:
261 | assert query_chunk_size is not None
262 | all_occ = []
263 | all_pts = []
264 | infer = functools.partial(self.tracks_from_cost_volume, im_shp=shape)
265 |
266 | for i in range(0, query_points.shape[1], query_chunk_size):
267 | points, occlusion = infer(
268 | interp_features_heads[:, i:i + query_chunk_size],
269 | feature_grid_heads,
270 | query_points[:, i:i + query_chunk_size],
271 | )
272 | all_occ.append(occlusion)
273 | all_pts.append(points)
274 | occlusion = jnp.concatenate(all_occ, axis=1)
275 | points = jnp.concatenate(all_pts, axis=1)
276 |
277 | out['occlusion'] = occlusion
278 | out['tracks'] = points
279 |
280 | return out
281 |
--------------------------------------------------------------------------------
/tapnet/tapnext/tapnext_torch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """TAPNext implementation in torch."""
17 |
18 | import dataclasses
19 | from typing import List
20 |
21 | import einops
22 | import numpy as np
23 | from tapnet.tapnext import tapnext_lru_modules
24 | import torch
25 | from torch import nn
26 | from torch.nn import functional as F
27 | from torchvision.models import vision_transformer
28 |
29 |
30 | def posemb_sincos_2d(h, w, width, temperature=10_000.0, dtype=np.float32):
31 | """Follows the MoCo v3 logic."""
32 | y, x = np.mgrid[:h, :w]
33 |
34 | assert width % 4 == 0, 'Width must be mult of 4 for sincos posemb'
35 | omega = np.arange(width // 4) / (width // 4 - 1)
36 | omega = 1.0 / (temperature**omega)
37 | y = np.einsum('m,d->md', y.flatten(), omega)
38 | x = np.einsum('m,d->md', x.flatten(), omega)
39 | pe = np.concatenate([np.sin(x), np.cos(x), np.sin(y), np.cos(y)], axis=1)
40 | return np.asarray(pe, dtype)[None, :, :]
41 |
42 |
43 | class TRecViTBlock(nn.Module):
44 | """A block proposed by https://arxiv.org/abs/2412.14294."""
45 |
46 | def __init__(self, depth, width, num_heads, lru_width, dtype, device):
47 | super().__init__()
48 | self.ssm_block = tapnext_lru_modules.ResidualBlock(
49 | width=width,
50 | mlp_expanded_width=width * 4,
51 | num_heads=num_heads,
52 | lru_width=lru_width,
53 | final_w_init_variance_scale=2.0 / depth,
54 | dtype=dtype,
55 | device=device,
56 | )
57 | self.vit_block = vision_transformer.EncoderBlock(
58 | num_heads=num_heads,
59 | mlp_dim=width * 4,
60 | hidden_dim=width,
61 | attention_dropout=0.0,
62 | dropout=0.0,
63 | )
64 |
65 | def forward(self, x, cache=None, use_linear_scan=True):
66 | b, t, n, _ = x.shape
67 | x = einops.rearrange(x, 'b t n c -> (b n) t c')
68 | x, ssm_cache = self.ssm_block(x, cache, use_linear_scan=use_linear_scan)
69 | x = einops.rearrange(x, '(b n) t c -> (b t) n c', b=b, n=n)
70 | x = self.vit_block(x)
71 | x = einops.rearrange(x, '(b t) n c -> b t n c', b=b, t=t)
72 | return x, ssm_cache
73 |
74 |
75 | @dataclasses.dataclass
76 | class TAPNextTrackingState:
77 | """State for TAPNext."""
78 |
79 | step: int
80 | query_points: torch.Tensor # Float["*B Q 3"]
81 | hidden_state: List[tapnext_lru_modules.RecurrentBlockCache] = None
82 |
83 |
84 | class TAPNext(nn.Module):
85 | """TAPNext implementation in pytorch."""
86 |
87 | def __init__(
88 | self,
89 | image_size,
90 | width=768,
91 | patch_size=(8, 8),
92 | num_heads=12,
93 | lru_width=768,
94 | depth=12,
95 | use_checkpointing=False,
96 | ):
97 | super().__init__()
98 | self.width = width
99 | self.patch_size = patch_size
100 | self.use_checkpointing = use_checkpointing
101 | self.image_size = image_size
102 |
103 | self.lin_proj = nn.Conv2d(
104 | in_channels=3,
105 | out_channels=self.width,
106 | kernel_size=self.patch_size,
107 | stride=self.patch_size,
108 | )
109 | self.blocks = nn.ModuleList([
110 | TRecViTBlock(
111 | depth=depth,
112 | width=width,
113 | num_heads=num_heads,
114 | lru_width=lru_width,
115 | dtype=torch.float32,
116 | device='cuda',
117 | )
118 | for _ in range(depth)
119 | ])
120 | self.encoder_norm = nn.LayerNorm(self.width)
121 | self.mask_token = nn.Parameter(
122 | torch.zeros((1, 1, 1, self.width)), requires_grad=True
123 | )
124 | self.unknown_token = nn.Parameter(
125 | torch.zeros((1, 1, self.width)), requires_grad=True
126 | )
127 | self.point_query_token = nn.Parameter(
128 | torch.zeros((1, 1, 1, self.width)), requires_grad=True
129 | )
130 | h = self.image_size[0] // self.patch_size[0]
131 | w = self.image_size[1] // self.patch_size[1]
132 | c = self.width
133 | self.image_pos_emb = nn.Parameter(
134 | torch.zeros((1, h * w, c)), requires_grad=True
135 | )
136 | self.register_buffer(
137 | 'query_pos_embed',
138 | torch.tensor(
139 | posemb_sincos_2d(self.image_size[0], self.image_size[1], c)
140 | ),
141 | )
142 | self.visible_head = nn.Sequential(
143 | nn.Linear(width, 256),
144 | nn.LayerNorm(256),
145 | nn.GELU(),
146 | nn.Linear(256, 256),
147 | nn.LayerNorm(256),
148 | nn.GELU(),
149 | nn.Linear(256, 1),
150 | )
151 | self.coordinate_head = nn.Sequential(
152 | nn.Linear(width, 256),
153 | nn.LayerNorm(256),
154 | nn.GELU(),
155 | nn.Linear(256, 256),
156 | nn.LayerNorm(256),
157 | nn.GELU(),
158 | nn.Linear(256, 512),
159 | )
160 |
161 | def embed_queries(self, timesteps, query_points):
162 | b, q, _ = query_points.shape
163 | t = timesteps
164 | c = self.width
165 | tiled_point_query_tokens = self.point_query_token.repeat(b, 1, q, 1)
166 | mask_tokens = self.mask_token.repeat(b, t, q, 1)
167 | unknown_tokens = self.unknown_token.repeat(b, t, q, 1)
168 | query_pos_embed = self.query_pos_embed.view(
169 | 1, self.image_size[0], self.image_size[1], c
170 | ).repeat(b, 1, 1, 1)
171 | # [B Q t 3]
172 | query_timesteps, query_positions = (
173 | query_points[..., :1],
174 | query_points[..., 1:],
175 | )
176 | # grid sample expects coordinates in [-1, 1]
177 | query_pos_embed_spatial = F.grid_sample(
178 | query_pos_embed.permute(
179 | 0, 3, 2, 1
180 | ), # we swap h and w due to how grid_sample works
181 | (
182 | query_positions.unsqueeze(1)
183 | / torch.tensor(self.image_size, device=query_positions.device)
184 | )
185 | * 2
186 | - 1,
187 | align_corners=False,
188 | ).permute(
189 | 0, 2, 3, 1
190 | ) # [b Q t c]
191 | point_query_tokens = tiled_point_query_tokens + query_pos_embed_spatial
192 |     # NOTE: query hints are only implemented in JAX, not in PyTorch.
193 |     # The two masks below are used to avoid adding the point query token
194 |     # when the point is queried later (as in online tracking).
195 | if t == 1: # online tracking
196 | mask_and_query_tokens = torch.where(
197 | query_timesteps.unsqueeze(1) == 0, point_query_tokens, mask_tokens
198 | )
199 | else:
200 | queries_are_late = query_timesteps >= t
201 | queries_are_early = query_timesteps < 0
202 | mask_and_query_tokens = mask_tokens.scatter(
203 | dim=1,
204 | index=query_timesteps.unsqueeze(1)
205 | .long()
206 | .clamp(0, t - 1)
207 | .repeat(1, 1, 1, c),
208 | src=point_query_tokens,
209 | )
210 | mask_and_query_tokens = torch.where(
211 | (queries_are_late | queries_are_early).unsqueeze(1),
212 | mask_tokens,
213 | mask_and_query_tokens,
214 | )
215 | is_unknown_token = torch.arange(t, device=query_points.device)[
216 | None, :, None, None
217 | ] < query_timesteps.unsqueeze(1)
218 | mask_and_query_tokens = torch.where(
219 | is_unknown_token, unknown_tokens, mask_and_query_tokens
220 | )
221 | return mask_and_query_tokens
222 |
223 | def prediction_heads(self, x):
224 | soft_argmax_threshold = 20
225 | softmax_temperature = 0.5
226 | track_logits = self.coordinate_head(x.float())
227 | position_x, position_y = track_logits.chunk(2, dim=-1)
228 | argmax_x, argmax_y = position_x.argmax(
229 | dim=-1, keepdim=True
230 | ), position_y.argmax(dim=-1, keepdim=True)
231 | index = torch.arange(position_x.shape[-1], device=x.device).repeat(
232 | *argmax_x.shape[:-1], 1
233 | )
234 | mask_x = (torch.abs(argmax_x - index) <= soft_argmax_threshold).float()
235 | mask_y = (torch.abs(argmax_y - index) <= soft_argmax_threshold).float()
236 | probs_x = F.softmax(position_x * softmax_temperature, dim=-1) * mask_x
237 | probs_y = F.softmax(position_y * softmax_temperature, dim=-1) * mask_y
238 | probs_x = probs_x / probs_x.sum(dim=-1, keepdim=True)
239 | probs_y = probs_y / probs_y.sum(dim=-1, keepdim=True)
240 | tracks_x = torch.sum(probs_x * index, dim=-1)[..., None]
241 | tracks_y = torch.sum(probs_y * index, dim=-1)[..., None]
242 | tracks = torch.cat([tracks_x, tracks_y], axis=-1)
243 | tracks += 0.5
244 | visible_logits = self.visible_head(x)
245 | return tracks, track_logits, visible_logits
246 |
247 | def forward(self, video, query_points=None, state=None):
248 |     # video: [b, t, h, w, 3]
249 | b, t, _, _, _ = video.shape
250 | # [b, t, h, w, 3] -> [b, t, 3, h, w]
251 | video_tokens = self.lin_proj(
252 | einops.rearrange(video, 'b t h w c -> (b t) c h w')
253 | )
254 | _, _, h, w = video_tokens.shape
255 | video_tokens = einops.rearrange(
256 | video_tokens, '(b t) c h w -> b t (h w) c', b=b, t=t
257 | )
258 | video_tokens = video_tokens + self.image_pos_emb.unsqueeze(0)
259 | if state is not None:
260 | # in online tracking, we put query "back in time"
261 | if query_points is None:
262 | query_points = state.query_points
263 | query_points = torch.cat(
264 | [query_points[..., :1] - state.step, query_points[..., 1:]], dim=-1
265 | )
266 | step = state.step
267 | else:
268 | step = 0
269 | point_tokens = self.embed_queries(t, query_points) # [b t Q c]
270 | x = torch.cat([video_tokens, point_tokens], dim=2) # [b t (h * w + Q) c]
271 | ssm_cache = []
272 | use_linear_scan = not self.training
273 |     num_layers = len(self.blocks)
274 |     caches = state.hidden_state if state is not None else [None] * num_layers
275 |     for blk, cache in zip(self.blocks, caches):
276 | if self.use_checkpointing:
277 | x, ssm_cache_layer = torch.utils.checkpoint.checkpoint(
278 | blk, x, cache, use_linear_scan, use_reentrant=False
279 | )
280 | else:
281 | x, ssm_cache_layer = blk(
282 | x, cache=cache, use_linear_scan=use_linear_scan
283 | )
284 | ssm_cache.append(ssm_cache_layer)
285 | x = self.encoder_norm(x)
286 | video_tokens, point_tokens = x[:, :, : h * w, :], x[:, :, h * w :, :]
287 | return (
288 | *self.prediction_heads(point_tokens),
289 | TAPNextTrackingState(
290 | step=step + t,
291 | query_points=state.query_points
292 | if state is not None
293 | else query_points,
294 | hidden_state=ssm_cache,
295 | ),
296 | )
297 |
298 | torch.export.register_dataclass(TAPNextTrackingState)
299 |
300 |
301 | def flatten_tracking_state(state, _):
302 | return (
303 | state.step,
304 | state.query_points,
305 | [[st for st in h] for h in state.hidden_state],
306 | )
307 |
308 |
309 | torch.fx._pytree.register_pytree_flatten_spec( # pylint: disable=protected-access
310 | TAPNextTrackingState, flatten_tracking_state
311 | )
312 |
--------------------------------------------------------------------------------
/tapnet/tapvid3d/evaluation/evaluate_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Compute metrics on the saved outputs of a TAP-3D model."""
17 |
18 | from collections.abc import Sequence
19 | import io
20 | import os
21 |
22 | from absl import app
23 | from absl import flags
24 | from absl import logging
25 | import numpy as np
26 | from PIL import Image
27 | from tapnet.tapvid3d.evaluation import metrics
28 | from tapnet.tapvid3d.splits import tapvid3d_splits
29 | import tqdm
30 |
31 |
32 | _TAPVID3D_DIR = flags.DEFINE_string(
33 | 'tapvid3d_dir',
34 | 'tapvid3d_dataset/',
35 | """
36 |     Path to the folder that has all the ground truth npz files. Should
37 | contain subfolders for each data source (adt, pstudio, drivetrack).
38 | """,
39 | )
40 |
41 | _TAPVID3D_PREDICTIONS = flags.DEFINE_string(
42 | 'tapvid3d_predictions',
43 | 'tapvid3d_predictions/',
44 | """
45 | Must be a folder that contains a set of *.npz files, each one with
46 | the predicted 3D trajectories for that video, with filename
47 | exactly matching the corresponding ground truth file/video example.
48 |     Each npz file must have a tensor `tracks_XYZ` of shape [T, N, 3] and
49 |     a tensor `visibility` of shape [T, N], where N is the
50 |     number of tracks in the video and T is the number of frames.
51 | """,
52 | )
53 |
54 | _USE_MINIVAL = flags.DEFINE_boolean(
55 | 'use_minival',
56 | True,
57 | """
58 | If True, compute metrics on the minival split;
59 | otherwise uses the full_eval split.
60 | """,
61 | )
62 |
63 | _DEPTH_SCALINGS = flags.DEFINE_list(
64 | 'depth_scalings',
65 | ['median'],
66 | 'Depth scaling strategies to use, list of [median, per_trajectory, etc..].',
67 | )
68 |
69 | _DATA_SOURCES_TO_EVALUATE = flags.DEFINE_list(
70 | 'data_sources_to_evaluate',
71 | ['drivetrack', 'adt', 'pstudio'],
72 | 'Which data source subsets to evaluate.',
73 | )
74 |
75 | _DEBUG = flags.DEFINE_boolean(
76 | 'debug',
77 | False,
78 | 'Whether to run in debug mode, downloads only one video.',
79 | )
80 |
81 |
82 | ZERO_METRICS_DICT = {
83 | 'occlusion_accuracy': np.array([0.0]),
84 | 'pts_within_1': np.array([0.0]),
85 | 'jaccard_1': np.array([0.0]),
86 | 'pts_within_2': np.array([0.0]),
87 | 'jaccard_2': np.array([0.0]),
88 | 'pts_within_4': np.array([0.0]),
89 | 'jaccard_4': np.array([0.0]),
90 | 'pts_within_8': np.array([0.0]),
91 | 'jaccard_8': np.array([0.0]),
92 | 'pts_within_16': np.array([0.0]),
93 | 'jaccard_16': np.array([0.0]),
94 | 'average_jaccard': np.array([0.0]),
95 | 'average_pts_within_thresh': np.array([0.0]),
96 | }
97 |
98 |
99 | def get_new_hw_with_given_smallest_side_length(
100 | *, orig_height: int, orig_width: int, smallest_side_length: int = 256
101 | ):
102 | orig_shape = np.array([orig_height, orig_width])
103 | scaling_factor = smallest_side_length / np.min(orig_shape)
104 | resized_shape = np.round(orig_shape * scaling_factor)
105 | return (int(resized_shape[0]), int(resized_shape[1])), scaling_factor
106 |
107 |
108 | def get_jpeg_byte_hw(jpeg_bytes: bytes):
109 | with io.BytesIO(jpeg_bytes) as img_bytes:
110 | img = Image.open(img_bytes)
111 | img = img.convert('RGB')
112 | return np.array(img).shape[:2]
113 |
114 |
115 | def get_average_over_metrics(
116 | list_of_metrics: list[dict[str, dict[str, np.ndarray]]]
117 | ):
118 | """Average over per video metrics in a nested metrics dictionaries."""
119 | avg_metrics = {}
120 | for metric_category in list_of_metrics[0].keys():
121 | avg_metrics[metric_category] = {}
122 | for metric_name in list_of_metrics[0][metric_category]:
123 | avg_metrics[metric_category][metric_name] = np.mean(
124 | [
125 | video_metric[metric_category][metric_name]
126 | for video_metric in list_of_metrics
127 | ]
128 | )
129 | return avg_metrics
130 |
131 |
132 | def evaluate_data_source(
133 | npz_filenames: list[str],
134 | ground_truth_dir: str,
135 | predictions_dir: str,
136 | depth_scalings: list[str],
137 | metric_eval_resolution: int = 256,
138 | ) -> dict[str, dict[str, np.ndarray]]:
139 | """Compute metrics on the set of 3D track predictions.
140 |
141 | This function expects as input two directories: `ground_truth_dir` and
142 | `predictions_dir`. They should have a structure as follows:
143 |
144 | - ground_truth_dir:
145 | - video1.npz (with contents+format described in tapvid3d/README.md)
146 | - video2.npz
147 | - video3.npz
148 | - ...
149 | - videoN.npz
150 |
151 | - predictions_dir:
152 | - video1.npz (contains two tensors: the predicted `visibility` and
153 | `tracks_XYZ`, with shapes described in tapvid3d/README.md)
154 | - video2.npz
155 | - video3.npz
156 | - ...
157 | - videoN.npz
158 |
159 | `ground_truth_dir` contains npz files with ground truth data, while
160 |   `predictions_dir` should contain npz files whose file names match their
161 |   corresponding ground truth files in `ground_truth_dir`. The npz files can
162 |   be named anything (not necessarily video{N}.npz as in the example), as
163 |   long as the names are consistent across both directories.
164 |
165 |   Also note that this directory structure (for `ground_truth_dir`) is what
166 |   the `generate_all.sh` script produces. The script creates three different
167 | directories: `tapvid3d_dataset/adt`, `tapvid3d_dataset/pstudio`,
168 | `tapvid3d_dataset/drivetrack`, each one with the structure described above.
169 |
170 | Args:
171 | npz_filenames: List of npz file names to compute metrics over.
172 | Each filename must exist in both `ground_truth_dir` and
173 | `predictions_dir`.
174 | ground_truth_dir: Path to the TAP-3D dataset, format described above.
175 | predictions_dir: Path to the TAP-3D predictions, format described above.
176 | depth_scalings: List of strings, describing which depth scaling strategies
177 | to use. See `compute_tapvid3d_metrics()` in metrics.py for
178 | available scaling options and their descriptions.
179 | metric_eval_resolution: The resolution at which to evaluate the metrics,
180 | by default is 256, which is used for all results
181 | in the TAPVid-3D paper and original TAPVid paper.
182 |
183 | Returns:
184 | A dictionary of dictionaries: the outer dict mapping different
185 | depth scaling strategies to an inner metrics dict (mapping a metric name
186 | to the corresponding metric score).
187 | """
188 | metrics_all_videos = []
189 | for npy_file in tqdm.tqdm(npz_filenames, total=len(npz_filenames)):
190 | gt_file = os.path.join(ground_truth_dir, npy_file)
191 | with open(gt_file, 'rb') as in_f:
192 | in_npz = np.load(in_f, allow_pickle=True)
193 | images_jpeg_bytes = in_npz['images_jpeg_bytes']
194 | queries_xyt = in_npz['queries_xyt']
195 | tracks_xyz = in_npz['tracks_XYZ']
196 | visibles = in_npz['visibility']
197 | intrinsics_params = in_npz['fx_fy_cx_cy']
198 |
199 | # Simple rescaling from original video resolution to the resolution
200 | # at which we want to compute the metrics (metric_eval_resolution)
201 | # Since camera intrinsics are for original video resolution, we need to
202 | # rescale them as well.
203 | # `intrinsics_params` is in format [fx, fy, cx, cy].
204 | video_height, video_width = get_jpeg_byte_hw(images_jpeg_bytes[0])
205 | smallest_side_length = metric_eval_resolution
206 | (_, _), scaling_factor = (
207 | get_new_hw_with_given_smallest_side_length(
208 | orig_height=video_height,
209 | orig_width=video_width,
210 | smallest_side_length=smallest_side_length,
211 | )
212 | )
213 | intrinsics_params_resized = intrinsics_params * scaling_factor
214 |
215 | prediction_file = os.path.join(predictions_dir, npy_file)
216 | try:
217 | with open(prediction_file, 'rb') as in_f:
218 | predictor_data = np.load(in_f, allow_pickle=True)
219 | predicted_tracks_xyz = predictor_data['tracks_XYZ']
220 | predicted_visibility = predictor_data['visibility']
221 |
222 | except Exception: # pylint: disable=broad-exception-caught
223 | logging.exception('Failed to read %s', prediction_file)
224 | failure_metrics_dict = {
225 | scaling: ZERO_METRICS_DICT for scaling in depth_scalings
226 | }
227 | metrics_all_videos.append(failure_metrics_dict)
228 | if _DEBUG.value:
229 | logging.info('Stopping after one video, debug run.')
230 | break
231 | continue
232 |
233 | video_metrics = {}
234 | for depth_scaling in depth_scalings:
235 | try:
236 | metrics_for_scale = metrics.compute_tapvid3d_metrics(
237 | gt_occluded=np.logical_not(visibles),
238 | gt_tracks=tracks_xyz,
239 | pred_occluded=np.logical_not(predicted_visibility),
240 | pred_tracks=predicted_tracks_xyz,
241 | intrinsics_params=intrinsics_params_resized,
242 | scaling=depth_scaling,
243 | query_points=queries_xyt[..., ::-1],
244 | order='t n',
245 | )
246 | except Exception: # pylint: disable=broad-exception-caught
247 | logging.exception('Failed to compute metrics for %s', npy_file)
248 | metrics_for_scale = ZERO_METRICS_DICT
249 | video_metrics[depth_scaling] = metrics_for_scale
250 | metrics_all_videos.append(video_metrics)
251 |
252 | if _DEBUG.value:
253 | logging.info('Stopping after one video, debug run.')
254 | break
255 |
256 | avg_metrics = get_average_over_metrics(metrics_all_videos)
257 | return avg_metrics
258 |
259 |
260 | def main(argv: Sequence[str]) -> None:
261 | if len(argv) > 1:
262 | raise app.UsageError('Too many command-line arguments.')
263 |
264 | metrics_all_sources = []
265 | for data_source in _DATA_SOURCES_TO_EVALUATE.value:
266 | if _USE_MINIVAL.value:
267 | all_npz_files = tapvid3d_splits.get_minival_files(subset=data_source)
268 | else:
269 | all_npz_files = tapvid3d_splits.get_full_eval_files(subset=data_source)
270 | source_gt_dir = os.path.join(_TAPVID3D_DIR.value, data_source)
271 | source_pred_dir = os.path.join(_TAPVID3D_PREDICTIONS.value, data_source)
272 | source_metrics = evaluate_data_source(
273 | npz_filenames=all_npz_files,
274 | ground_truth_dir=source_gt_dir,
275 | predictions_dir=source_pred_dir,
276 | depth_scalings=_DEPTH_SCALINGS.value,
277 | )
278 | metrics_all_sources.append(source_metrics)
279 | logging.info('Metrics for data source %s', data_source)
280 | logging.info(source_metrics)
281 |
282 | avg_metrics = get_average_over_metrics(metrics_all_sources)
283 | logging.info('Metrics, averaged across all data sources:')
284 | logging.info(avg_metrics)
285 |
286 | logging.info('Finished computing metrics!')
287 |
288 |
289 | if __name__ == '__main__':
290 | app.run(main)
291 |
--------------------------------------------------------------------------------