├── videoprism
│   ├── __init__.py
│   ├── assets
│   │   ├── testdata
│   │   │   └── test_spm.model
│   │   └── water_bottle_drumming.mp4
│   ├── utils_test.py
│   ├── tokenizers_test.py
│   ├── models_test.py
│   ├── utils.py
│   ├── tokenizers.py
│   ├── colabs
│   │   ├── videoprism_video_encoder_demo.ipynb
│   │   └── videoprism_video_text_demo.ipynb
│   ├── layers_test.py
│   ├── encoders_test.py
│   ├── models.py
│   ├── encoders.py
│   └── layers.py
├── requirements.txt
├── CONTRIBUTING.md
├── setup.py
├── README.md
└── LICENSE
/videoprism/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.26
2 | absl-py
3 | einops
4 | einshape
5 | flax
6 | huggingface-hub
7 | sentencepiece
8 | tensorflow-cpu
--------------------------------------------------------------------------------
/videoprism/assets/testdata/test_spm.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/videoprism/HEAD/videoprism/assets/testdata/test_spm.model
--------------------------------------------------------------------------------
/videoprism/assets/water_bottle_drumming.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/videoprism/HEAD/videoprism/assets/water_bottle_drumming.mp4
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | ## Contributor License Agreement
4 |
5 | Contributions to this project must be accompanied by a Contributor License
6 | Agreement. You (or your employer) retain the copyright to your contribution;
7 | this simply gives us permission to use and redistribute your contributions as
8 | part of the project. Head over to <https://cla.developers.google.com/> to see
9 | your current agreements on file or to sign a new one.
10 |
11 | You generally only need to submit a CLA once, so if you've already submitted one
12 | (even if it was for a different project), you probably don't need to do it
13 | again.
14 |
15 | ## Code reviews
16 |
17 | All submissions, including submissions by project members, require review. We
18 | use GitHub pull requests for this purpose. Consult
19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
20 | information on using pull requests.
21 |
22 | ## Community Guidelines
23 |
24 | This project follows [Google's Open Source Community
25 | Guidelines](https://opensource.google/conduct/).
--------------------------------------------------------------------------------
/videoprism/utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 VideoPrism Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for utility functions."""
16 |
17 | from absl.testing import absltest
18 | from videoprism import utils
19 |
20 |
21 | class UtilsTest(absltest.TestCase):
22 |
23 | def test_canonicalize_text(self):
24 | self.assertEqual(utils.canonicalize_text("Hello, World!"), "hello world.")
25 | self.assertEqual(utils.canonicalize_text("Hello,World.."), "hello world.")
26 | self.assertEqual(utils.canonicalize_text(" Hello WORLD"), "hello world.")
27 |
28 |
29 | if __name__ == "__main__":
30 | absltest.main()
31 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 VideoPrism Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """setup.py for VideoPrism.
16 |
17 | Install for development:
18 |
19 | pip install -e .[testing]
20 | """
21 |
22 | import setuptools
23 |
24 | # Get install requirements from the REQUIREMENTS file.
25 | with open("requirements.txt") as fp:
26 | install_requires_core = fp.read().splitlines()
27 |
28 | tests_require = [
29 | "chex",
30 | "pytest",
31 | ]
32 |
33 | setuptools.setup(
34 | name="videoprism",
35 | version="1.0.0",
36 | description=(
37 | "VideoPrism: A Foundational Visual Encoder for Video Understanding."
38 | ),
39 | author="VideoPrism Authors",
40 | author_email="no-reply@google.com",
41 | long_description=open("README.md").read(),
42 | long_description_content_type="text/markdown",
43 | url="https://github.com/google-deepmind/videoprism",
44 | license="Apache 2.0",
45 | packages=setuptools.find_packages(),
46 | include_package_data=True,
47 | install_requires=install_requires_core,
48 | tests_require=tests_require,
49 | extras_require={
50 | "testing": tests_require,
51 | },
52 | classifiers=[
53 | "Development Status :: 1 - Beta",
54 | "Intended Audience :: Developers",
55 | "Intended Audience :: VideoPrism/Research",
56 | "License :: OSI Approved :: Apache Software License",
57 | "Programming Language :: Python",
58 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
59 | ],
60 | keywords="VideoPrism",
61 | )
62 |
--------------------------------------------------------------------------------
/videoprism/tokenizers_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 VideoPrism Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for text tokenizers."""
16 |
17 | from absl.testing import parameterized
18 | import numpy as np
19 | import tensorflow as tf
20 | from videoprism import tokenizers
21 |
22 |
23 | class PyToTfWrapper:
24 | """Allows to use `to_int_tf_op()` via `to_int()`."""
25 |
26 | def __init__(self, model):
27 | self.model = model
28 | self.bos_token = model.bos_token
29 | self.eos_token = model.eos_token
30 | self.vocab_size = model.vocab_size
31 |
32 | def to_int(self, text, *, bos=False, eos=False):
33 | ret = self.model.to_int_tf_op(text, bos=bos, eos=eos)
34 | if isinstance(ret, tf.RaggedTensor):
35 | return [t.numpy().tolist() for t in ret]
36 | return ret.numpy().tolist()
37 |
38 |
39 | class TokenizersTest(tf.test.TestCase, parameterized.TestCase):
40 |
41 | def setUp(self):
42 | import os
43 | self.spm_path = os.path.join(
44 | os.path.dirname(__file__), "assets", "testdata", "test_spm.model"
45 | )
46 | super().setUp()
47 |
48 | @parameterized.named_parameters(
49 | ("py", False),
50 | ("tf", True),
51 | )
52 | def test_sentencepiece_tokenizer(self, wrap_model):
53 | model = tokenizers.SentencePieceTokenizer(self.spm_path)
54 | if wrap_model:
55 | model = PyToTfWrapper(model)
56 | self.assertEqual(model.vocab_size, 1000)
57 | bos, eos = model.bos_token, model.eos_token
58 | self.assertEqual(bos, 1)
59 | self.assertEqual(eos, 2)
60 | self.assertEqual(model.to_int("blah"), [80, 180, 60])
61 | self.assertEqual(model.to_int("blah", bos=True), [bos, 80, 180, 60])
62 | self.assertEqual(model.to_int("blah", eos=True), [80, 180, 60, eos])
63 | self.assertEqual(
64 | model.to_int("blah", bos=True, eos=True), [bos, 80, 180, 60, eos]
65 | )
66 | self.assertEqual(
67 | model.to_int(["blah", "blah blah"]),
68 | [[80, 180, 60], [80, 180, 60, 80, 180, 60]],
69 | )
70 |
71 | def test_sentencepiece_tokenizer_tf_data(self):
72 | model = tokenizers.SentencePieceTokenizer(self.spm_path)
73 |
74 | def gen():
75 | yield tf.convert_to_tensor(["blah"])
76 | yield tf.convert_to_tensor(["blah", "blah blah"])
77 |
78 | ds = tf.data.Dataset.from_generator(gen, tf.string, tf.TensorShape([None]))
79 | ds = ds.map(model.to_int_tf_op)
80 | res = [
81 | [b.tolist() if isinstance(b, np.ndarray) else b for b in a.tolist()]
82 | for a in ds.as_numpy_iterator()
83 | ]
84 | print(res)
85 | self.assertAllEqual(
86 | res, [[[80, 180, 60]], [[80, 180, 60], [80, 180, 60, 80, 180, 60]]]
87 | )
88 |
89 |
90 | if __name__ == "__main__":
91 | tf.test.main()
92 |
--------------------------------------------------------------------------------
/videoprism/models_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 VideoPrism Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for VideoPrism models."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import jax
20 | from jax import numpy as jnp
21 | import numpy as np
22 | from videoprism import models
23 | from videoprism import tokenizers
24 |
25 |
26 | class ModelsTest(parameterized.TestCase):
27 |
28 | @parameterized.parameters(
29 | ('videoprism_public_v1_base', True),
30 | ('videoprism_public_v1_large', True),
31 | ('videoprism_public_v1_giant', False),
32 | )
33 | def test_has_model(self, model_name, exists):
34 | self.assertEqual(models.has_model(model_name), exists)
35 |
36 | @parameterized.parameters(8, 16)
37 | def test_videoprism(self, num_frames):
38 | batch_size = 1
39 | np_inputs = np.random.normal(
40 | 0.0, 0.1, [batch_size, num_frames, 288, 288, 3]
41 | ).astype('float32')
42 | inputs = jnp.asarray(np_inputs)
43 | prng_key = jax.random.PRNGKey(seed=123)
44 |
45 | mdl = models.videoprism_v1_base()
46 | mdl_params = mdl.init(prng_key, inputs, train=False)
47 |
48 | @jax.jit
49 | def forward_fn(mdl_inputs):
50 | return mdl.apply(mdl_params, mdl_inputs, train=False)
51 |
52 | embeddings, _ = forward_fn(inputs)
53 | self.assertEqual(embeddings.shape, (batch_size, num_frames * 16**2, 768))
54 |
55 | def test_videoprism_lvt(self):
56 | batch_size, num_frames = 1, 16
57 | np_inputs = np.random.normal(
58 | 0.0, 0.1, [batch_size, num_frames, 288, 288, 3]
59 | ).astype('float32')
60 | inputs = jnp.asarray(np_inputs)
61 | np_text_token_ids = np.random.randint(
62 | 0, 32_000, [batch_size, models.TEXT_MAX_LEN]
63 | ).astype('int32')
64 | text_token_ids = jnp.asarray(np_text_token_ids)
65 | np_text_paddings = np.zeros(
66 | [batch_size, models.TEXT_MAX_LEN], dtype='float32'
67 | )
68 | np_text_paddings[:, models.TEXT_MAX_LEN // 2 :] = 1
69 | text_paddings = jnp.asarray(np_text_paddings)
70 | prng_key = jax.random.PRNGKey(seed=123)
71 |
72 | mdl = models.videoprism_lvt_v1_base()
73 | mdl_params = mdl.init(
74 | prng_key, inputs, text_token_ids, text_paddings, train=False
75 | )
76 |
77 | @jax.jit
78 | def forward_fn(mdl_inputs, mdl_text_token_ids, mdl_text_paddings):
79 | return mdl.apply(
80 | mdl_params,
81 | mdl_inputs,
82 | mdl_text_token_ids,
83 | mdl_text_paddings,
84 | train=False,
85 | )
86 |
87 | vision_embeddings, text_embeddings, _ = forward_fn(
88 | inputs, text_token_ids, text_paddings
89 | )
90 | self.assertEqual(vision_embeddings.shape, (batch_size, 768))
91 | self.assertEqual(text_embeddings.shape, (batch_size, 768))
92 |
93 | def test_tokenize_texts(self):
94 | import os
95 | spm_path = os.path.join(
96 | os.path.dirname(__file__), 'assets', 'testdata', 'test_spm.model'
97 | )
98 | model = tokenizers.SentencePieceTokenizer(spm_path)
99 | ids, paddings = models.tokenize_texts(
100 | model,
101 | ['blah', 'blah blah', 'blah blah blah'],
102 | max_length=6,
103 | add_bos=False,
104 | canonicalize=False,
105 | )
106 | np.testing.assert_array_equal(
107 | ids,
108 | [
109 | [80, 180, 60, 0, 0, 0],
110 | [80, 180, 60, 80, 180, 60],
111 | [80, 180, 60, 80, 180, 60],
112 | ],
113 | )
114 | np.testing.assert_array_equal(
115 | paddings, [[0, 0, 0, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]
116 | )
117 |
118 |
119 | if __name__ == '__main__':
120 | absltest.main()
121 |
--------------------------------------------------------------------------------
/videoprism/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 VideoPrism Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utility functions for checkpointing and other purposes."""
16 |
17 | import collections
18 | from collections.abc import Mapping, Sequence
19 | import io
20 | import os
21 | import string
22 |
23 | import jax
24 | import numpy as np
25 | from tensorflow.io import gfile
26 |
27 |
28 | def traverse_with_names(tree, with_inner_nodes=False):
29 | """Traverses nested dicts and emits (leaf_name, leaf_val).
30 |
31 | Args:
32 | tree: JAX Pytree object.
33 | with_inner_nodes: Whether to traverse the non-leaf nodes.
34 |
35 | Yields:
36 | A pair of (leaf_name, leaf_val).
37 | """
38 |   # Don't output the non-leaf nodes. If the optimizer doesn't have a state,
39 |   # the tree leaves can be Nones, which would be treated as leaves by this
40 |   # function but not by other functions (like jax.tree.map).
41 | if tree is None:
42 | return
43 | elif isinstance(tree, Mapping):
44 | keys = sorted(tree.keys())
45 | for key in keys:
46 | for path, v in traverse_with_names(tree[key], with_inner_nodes):
47 | yield (key + "/" + path).rstrip("/"), v
48 | if with_inner_nodes:
49 | yield "", tree
50 | elif isinstance(tree, Sequence):
51 | for idx in range(len(tree)):
52 | for path, v in traverse_with_names(tree[idx], with_inner_nodes):
53 | yield (str(idx) + "/" + path).rstrip("/"), v
54 | if with_inner_nodes:
55 | yield "", tree
56 | else:
57 | yield "", tree
58 |
59 |
60 | def tree_flatten_with_names(tree):
61 | """Populates tree_flatten with leaf names.
62 |
63 | Args:
64 | tree: JAX Pytree object.
65 |
66 | Returns:
67 | A list of values with names: [(name, value), ...]
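
  Example (illustrative, for a plain nested dict):
    tree_flatten_with_names({"a": {"b": 1}, "c": 2})
      => [("a/b", 1), ("c", 2)]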
68 | """
69 | vals, tree_def = jax.tree.flatten(tree)
70 |
71 | tokens = range(len(vals))
72 | token_tree = tree_def.unflatten(tokens)
73 | val_names, perm = zip(*traverse_with_names(token_tree))
74 | inv_perm = np.argsort(perm)
75 |
76 |   # Custom traversal should visit the same number of leaves.
77 | assert len(val_names) == len(vals)
78 |
79 | return [(val_names[i], v) for i, v in zip(inv_perm, vals)]
80 |
81 |
82 | def recover_tree(keys, values):
83 | """Recovers a tree as a nested dict from flat names and values.
84 |
85 | Args:
86 | keys: A list of keys, where '/' is used as separator between nodes.
87 | values: A list of leaf values.
88 |
89 | Returns:
90 | A nested tree-like dict.
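
  Example (illustrative, the inverse of the flattening above):
    recover_tree(["a/b", "a/c", "d"], [1, 2, 3])
      => {"a": {"b": 1, "c": 2}, "d": 3}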
91 | """
92 | tree = {}
93 | sub_trees = collections.defaultdict(list)
94 | for k, v in zip(keys, values):
95 | if "/" not in k:
96 | tree[k] = v
97 | else:
98 | k_left, k_right = k.split("/", 1)
99 | sub_trees[k_left].append((k_right, v))
100 | for k, kv_pairs in sub_trees.items():
101 | k_subtree, v_subtree = zip(*kv_pairs)
102 | tree[k] = recover_tree(k_subtree, v_subtree)
103 | return tree
104 |
105 |
106 | def npload(fname):
107 | """Loads `fname` and returns an np.ndarray or dict thereof."""
108 | # Load the data; use local paths directly if possible:
109 | if os.path.exists(fname):
110 | loaded = np.load(fname, allow_pickle=False)
111 | else:
112 | # For other (remote) paths go via gfile+BytesIO as np.load requires seeks.
113 | with gfile.GFile(fname, "rb") as f:
114 | data = f.read()
115 | loaded = np.load(io.BytesIO(data), allow_pickle=False)
116 |
117 | # Support loading both single-array files (np.save) and zips (np.savez).
118 | if isinstance(loaded, np.ndarray):
119 | return loaded
120 | else:
121 | return dict(loaded)
122 |
123 |
124 | def load_checkpoint(npz):
125 | """Loads a jax Pytree from a npz file.
126 |
127 | Args:
128 | npz: Either path to the checkpoint file (.npz), or a dict-like.
129 |
130 | Returns:
131 | A Pytree that is the checkpoint.
132 | """
133 | if isinstance(npz, str): # If not already loaded, then load.
134 | npz = npload(npz)
135 | keys, values = zip(*list(npz.items()))
136 | return recover_tree(keys, values)
137 |
138 |
139 | def canonicalize_text(text: str) -> str:
140 | """Canonicalizes text.
141 |
142 | Canonicalization includes:
143 | - Replace all punctuation with a whitespace.
144 | - Use all lower case.
145 | - Leave only one whitespace between words.
146 | - End with a period.
147 |
148 | Examples:
149 | "Hello, World!" -> "hello world."
150 | "Hello,World.." -> "hello world."
151 | " Hello WORLD" -> "hello world."
152 |
153 | Args:
154 | text: A string for the input text.
155 |
156 | Returns:
157 | A string for the canonicalized text.
158 | """
159 | # Replace all punctuation with a whitespace.
160 | p = string.punctuation
161 | text = text.translate(str.maketrans(p, " " * len(p)))
162 | # Use all lower case.
163 | text = text.lower()
164 | # Leave only one whitespace between words.
165 | text = " ".join(text.split())
166 | # End with a period.
167 | text = text + "."
168 | return text
169 |
--------------------------------------------------------------------------------
/videoprism/tokenizers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 VideoPrism Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tokenizers for text encoders."""
16 |
17 | from collections.abc import Sequence
18 | from typing import Protocol
19 |
20 | import tensorflow as tf
21 | from tensorflow.io import gfile
22 |
23 | import sentencepiece
24 |
25 | SentencePieceProcessor = sentencepiece.SentencePieceProcessor
26 |
27 |
28 | class Tokenizer(Protocol):
29 | """Tokenizer interface."""
30 |
31 | def to_int(
32 | self, text: str | Sequence[str], *, bos: bool = False, eos: bool = False
33 | ) -> list[int] | list[list[int]]:
34 | """Tokenizes `text` into a list of integer tokens.
35 |
36 | Args:
37 | text: can be a single string, or a list of strings.
38 | bos: Whether a beginning-of-sentence token should be prepended.
39 | eos: Whether an end-of-sentence token should be appended.
40 |
41 | Returns:
42 | A list or list-of-list of tokens.
43 | """
44 |
45 | def to_int_tf_op(
46 | self, text: str | Sequence[str], *, bos: bool = False, eos: bool = False
47 | ) -> tf.Tensor | tf.RaggedTensor:
48 | """Same as `to_int()`, but as TF ops to be used in data pipelines.
49 |
50 | Args:
51 | text: can be a single string, or a list of strings.
52 | bos: Whether a beginning-of-sentence token should be prepended.
53 | eos: Whether an end-of-sentence token should be appended.
54 |
55 | Returns:
56 |       A tf.Tensor or tf.RaggedTensor of tokens.
57 | """
58 |
59 | @property
60 | def pad_token(self) -> int:
61 | """Token id of padding token."""
62 |
63 | @property
64 | def eos_token(self) -> int:
65 | """Token id of end-of-sentence token."""
66 |
67 | @property
68 | def bos_token(self) -> int:
69 | """Token id of beginning-of-sentence token."""
70 |
71 | @property
72 | def vocab_size(self) -> int:
73 | """Returns the size of the vocabulary."""
74 |
75 |
76 | class SentencePieceTokenizer(Tokenizer):
77 | """Wraps a SentencePiece model for tokenization."""
78 |
79 | def __init__(self, model_path):
80 | """Initializes the tokenizer.
81 |
82 | Args:
83 | model_path: A path to load the SentencePiece model.
84 | """
85 | with gfile.GFile(model_path, "rb") as f:
86 | model_bytes = f.read()
87 |
88 | self._model = SentencePieceProcessor()
89 | self._model.LoadFromSerializedProto(model_bytes)
90 |
91 | def to_int(
92 | self, text: str | Sequence[str], *, bos: bool = False, eos: bool = False
93 | ) -> list[int] | list[list[int]]:
94 | """Tokenizes `text` into a list of integer tokens.
95 |
96 | Args:
97 | text: can be a single string, or a list of strings.
98 | bos: Whether a beginning-of-sentence token should be prepended.
99 | eos: Whether an end-of-sentence token should be appended.
100 |
101 | Returns:
102 | A list or list-of-list of tokens.
103 | """
104 |
105 | def _single(s: str) -> list[int]:
106 | return (
107 | ([self.bos_token] if bos else [])
108 | + self._model.EncodeAsIds(s)
109 | + ([self.eos_token] if eos else [])
110 | )
111 |
112 | if isinstance(text, str):
113 | return _single(text)
114 |     return [_single(s) for s in text]
115 |
116 | def to_int_tf_op(
117 | self, text: str | Sequence[str], *, bos: bool = False, eos: bool = False
118 | ) -> tf.Tensor | tf.RaggedTensor:
119 | """Same as `to_int()`, but as TF ops to be used in data pipelines.
120 |
121 | Args:
122 | text: can be a single string, or a list of strings.
123 | bos: Whether a beginning-of-sentence token should be prepended.
124 | eos: Whether an end-of-sentence token should be appended.
125 |
126 | Returns:
127 | A tf.Tensor or tf.RaggedTensor of tokens.
128 | """
129 | text = tf.convert_to_tensor(text)
130 | if text.ndim == 0:
131 |
132 | def fn(txt):
133 | """Tokenizes a single string."""
134 | s = txt.numpy().decode()
135 | return tf.constant(self.to_int(s, bos=bos, eos=eos), tf.int32)
136 |
137 | return tf.py_function(fn, [text], tf.int32)
138 | else:
139 |
140 | def fn(txt):
141 | """Tokenizes a list of strings."""
142 | strings = [s.decode() for s in txt.numpy().tolist()]
143 | toks = self.to_int(strings, bos=bos, eos=eos)
144 | return tf.ragged.constant(toks)
145 |
146 | out_type = tf.RaggedTensorSpec([text.shape[0], None], tf.int32)
147 | return tf.py_function(fn, [text], Tout=out_type)
148 |
149 | @property
150 | def pad_token(self) -> int:
151 | """Token id of padding token."""
152 | return self._model.pad_id()
153 |
154 | @property
155 | def eos_token(self) -> int:
156 | """Token id of end-of-sentence token."""
157 | return self._model.eos_id()
158 |
159 | @property
160 | def bos_token(self) -> int:
161 | """Token id of beginning-of-sentence token."""
162 | return self._model.bos_id()
163 |
164 | @property
165 | def vocab_size(self) -> int:
166 | """Returns the size of the vocabulary."""
167 | return self._model.GetPieceSize()
168 |
--------------------------------------------------------------------------------
/videoprism/colabs/videoprism_video_encoder_demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "metadata": {
5 | "id": "KPPUiCpSbm53"
6 | },
7 | "cell_type": "markdown",
8 | "source": [
9 | "# VideoPrism Video Encoder Demo\n",
10 | "\n",
11 | "[](https://arxiv.org/abs/2402.13217)\n",
12 | "[](https://research.google/blog/videoprism-a-foundational-visual-encoder-for-video-understanding/)\n",
13 | "[](https://opensource.org/licenses/Apache-2.0)\n",
14 | "\n",
15 | "This notebook provides an example of video feature extraction with a pre-trained VideoPrism video encoder.\n",
16 | "\n",
17 | "Please run this demo on Google Colab with (faster) or without TPU."
18 | ]
19 | },
20 | {
21 | "metadata": {
22 | "id": "k08qFZ9-cn9v"
23 | },
24 | "cell_type": "markdown",
25 | "source": [
26 | "## Set up"
27 | ]
28 | },
29 | {
30 | "metadata": {
31 | "id": "1dfyX8EyVsvL"
32 | },
33 | "cell_type": "code",
34 | "source": [
35 | "# @title Prepare environment\n",
36 | "\n",
37 | "import os\n",
38 | "import sys\n",
39 | "\n",
40 | "# Fetch VideoPrism repository if Python does not know about it and install\n",
41 | "# dependencies needed for this notebook.\n",
42 | "if not os.path.exists(\"videoprism_repo\"):\n",
43 | " !git clone --quiet --branch=main --depth=1 \\\n",
44 | " https://github.com/google-deepmind/videoprism.git videoprism_repo\n",
45 | " os.chdir('./videoprism_repo')\n",
46 | " !pip install .\n",
47 | " os.chdir('..')\n",
48 | "\n",
49 | "# Append VideoPrism code to Python import path.\n",
50 | "if \"videoprism_repo\" not in sys.path:\n",
51 | " sys.path.append(\"videoprism_repo\")\n",
52 | "\n",
53 | "# Install missing dependencies.\n",
54 | "!pip install mediapy\n",
55 | "\n",
56 | "import jax\n",
57 | "from jax.extend import backend\n",
58 | "import tensorflow as tf\n",
59 | "\n",
60 | "# Do not let TF use the GPU or TPUs.\n",
61 | "tf.config.set_visible_devices([], \"GPU\")\n",
62 | "tf.config.set_visible_devices([], \"TPU\")\n",
63 | "\n",
64 | "print(f\"JAX version: {jax.__version__}\")\n",
65 | "print(f\"JAX platform: {backend.get_backend().platform}\")\n",
66 | "print(f\"JAX devices: {jax.device_count()}\")"
67 | ],
68 | "outputs": [],
69 | "execution_count": null
70 | },
71 | {
72 | "metadata": {
73 | "id": "zByA1K0IVKAI"
74 | },
75 | "cell_type": "code",
76 | "source": [
77 | "# @title Load dependencies and define utilities\n",
78 | "\n",
79 | "import mediapy\n",
80 | "import numpy as np\n",
81 | "from PIL import Image\n",
82 | "\n",
83 | "\n",
84 | "def read_and_preprocess_video(\n",
85 | " filename: str, target_num_frames: int, target_frame_size: tuple[int, int]\n",
86 | "):\n",
87 | " \"\"\"Reads and preprocesses a video.\"\"\"\n",
88 | "\n",
89 | " frames = mediapy.read_video(filename)\n",
90 | "\n",
91 | " # Sample to target number of frames.\n",
92 | " frame_indices = np.linspace(\n",
93 | " 0, len(frames), num=target_num_frames, endpoint=False, dtype=np.int32\n",
94 | " )\n",
95 | " frames = np.asarray([frames[i] for i in frame_indices])\n",
96 | "\n",
97 | " # Resize to target size.\n",
98 | " original_height, original_width = frames.shape[-3:-1]\n",
99 | " target_height, target_width = target_frame_size\n",
100 | " assert (\n",
101 | " original_height * target_width == original_width * target_height\n",
102 | " ), 'Currently does not support aspect ratio mismatch.'\n",
103 | " frames = mediapy.resize_video(frames, shape=target_frame_size)\n",
104 | "\n",
105 | " # Normalize pixel values to [0.0, 1.0].\n",
106 | " frames = mediapy.to_float01(frames)\n",
107 | "\n",
108 | " return frames"
109 | ],
110 | "outputs": [],
111 | "execution_count": null
112 | },
113 | {
114 | "metadata": {
115 | "id": "WnYuzSgrXCL1"
116 | },
117 | "cell_type": "code",
118 | "source": [
119 | "# @title Load model\n",
120 | "\n",
121 | "import jax\n",
122 | "import jax.numpy as jnp\n",
123 | "from videoprism import models as vp\n",
124 | "\n",
125 | "MODEL_NAME = 'videoprism_public_v1_base' # @param ['videoprism_public_v1_base', 'videoprism_public_v1_large'] {allow-input: false}\n",
126 | "USE_BFLOAT16 = False # @param { type: \"boolean\" }\n",
127 | "NUM_FRAMES = 16\n",
128 | "FRAME_SIZE = 288\n",
129 | "\n",
130 | "fprop_dtype = jnp.bfloat16 if USE_BFLOAT16 else None\n",
131 | "flax_model = vp.get_model(MODEL_NAME, fprop_dtype=fprop_dtype)\n",
132 | "loaded_state = vp.load_pretrained_weights(MODEL_NAME)\n",
133 | "\n",
134 | "\n",
135 | "@jax.jit\n",
136 | "def forward_fn(inputs, train=False):\n",
137 | " return flax_model.apply(loaded_state, inputs, train=train)"
138 | ],
139 | "outputs": [],
140 | "execution_count": null
141 | },
142 | {
143 | "metadata": {
144 | "id": "AliScLC0jo1s"
145 | },
146 | "cell_type": "markdown",
147 | "source": [
148 | "# Example: Video feature extraction\n",
149 | "\n",
150 | "In this example, we extract the spatiotemporal embeddings of an example video."
151 | ]
152 | },
153 | {
154 | "metadata": {
155 | "id": "kLzkhP8CYUYj"
156 | },
157 | "cell_type": "code",
158 | "source": [
159 | "VIDEO_FILE_PATH = 'videoprism_repo/videoprism/assets/water_bottle_drumming.mp4' # @param {type: \"string\"}\n",
160 | "\n",
161 | "frames = read_and_preprocess_video(\n",
162 | " VIDEO_FILE_PATH,\n",
163 | " target_num_frames=NUM_FRAMES,\n",
164 | " target_frame_size=[FRAME_SIZE, FRAME_SIZE],\n",
165 | ")\n",
166 | "mediapy.show_video(frames, fps=6.0)\n",
167 | "\n",
168 | "frames = jnp.asarray(frames[None, ...]) # Add batch dimension.\n",
169 | "if USE_BFLOAT16:\n",
170 | " frames = frames.astype(jnp.bfloat16)\n",
171 | "print(f'Input shape: {frames.shape} [type: {frames.dtype}]')\n",
172 | "\n",
173 | "embeddings, _ = forward_fn(frames)\n",
174 | "print(f'Encoded embedding shape: {embeddings.shape} [type: {embeddings.dtype}]')\n"
175 | ],
176 | "outputs": [],
177 | "execution_count": null
178 | }
179 | ],
180 | "metadata": {
181 | "colab": {
182 | "provenance": [],
183 | "gpuType": "V28"
184 | },
185 | "kernelspec": {
186 | "name": "python3",
187 | "display_name": "Python 3"
188 | },
189 | "language_info": {
190 | "name": "python"
191 | },
192 | "accelerator": "TPU"
193 | },
194 | "nbformat": 4,
195 | "nbformat_minor": 0
196 | }
197 |
--------------------------------------------------------------------------------
/videoprism/layers_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 VideoPrism Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for layer modules."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import chex
20 | import jax
21 | from jax import numpy as jnp
22 | import numpy as np
23 | from videoprism import layers
24 |
25 |
26 | class LayersTest(parameterized.TestCase):
27 |
28 | def test_identity(self):
29 | inputs = jnp.ones((4, 16, 8))
30 | outputs = layers.identity(inputs)
31 | self.assertEqual(inputs.shape, outputs.shape)
32 | self.assertTrue(jnp.array_equal(inputs, outputs))
33 |
34 | @chex.variants(with_jit=True, without_jit=True)
35 | @parameterized.parameters(True, False)
36 | def test_layer_norm(self, direct_scale: bool):
37 | np_inputs = np.random.normal(1.0, 0.5, [10, 10, 10, 3]).astype(np.float32)
38 | inputs = jnp.asarray(np_inputs)
39 | prng_key = jax.random.PRNGKey(seed=123)
40 | ln = layers.LayerNorm(name='ln', direct_scale=direct_scale)
41 |
42 | @self.variant
43 | def var_fn():
44 | return ln.init_with_output(prng_key, inputs)
45 |
46 | outputs, params = var_fn()
47 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 2)
48 | self.assertEqual(inputs.shape, outputs.shape)
49 |
50 | @chex.variants(with_jit=True, without_jit=True)
51 | def test_feedforward_layer(self):
52 | np_inputs = np.random.normal(1.0, 0.5, [10, 10, 3]).astype(np.float32)
53 | inputs = jnp.asarray(np_inputs)
54 | prng_key = jax.random.PRNGKey(seed=123)
55 | ffn = layers.FeedForward(name='ffn', output_dim=20)
56 |
57 | @self.variant
58 | def var_fn():
59 | return ffn.init_with_output(prng_key, inputs)
60 |
61 | outputs, params = var_fn()
62 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 2)
63 | self.assertEqual(outputs.shape, (10, 10, 20))
64 |
65 | @chex.variants(with_jit=True, without_jit=True)
66 | @parameterized.parameters(True, False)
67 | def test_transformer_feedforward(self, train: bool):
68 | batch_size, seq_len, input_dims = 4, 512, 8
69 | np_inputs = np.random.normal(
70 | 1.0, 0.5, [batch_size, seq_len, input_dims]
71 | ).astype(np.float32)
72 | inputs = jnp.asarray(np_inputs)
73 | np_paddings = np.zeros([batch_size, seq_len], dtype=np.float32)
74 | input_paddings = jnp.asarray(np_paddings)
75 | prng_key = jax.random.PRNGKey(seed=123)
76 | ffwd = layers.TransformerFeedForward(
77 | name='ffwd', hidden_dim=32, activation_fn=layers.gelu
78 | )
79 |
80 | @self.variant
81 | def var_fn():
82 | return ffwd.init_with_output(
83 | prng_key, inputs, input_paddings, train=train
84 | )
85 |
86 | outputs, params = var_fn()
87 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 6)
88 | self.assertEqual(outputs.shape, (batch_size, seq_len, input_dims))
89 |
90 | @chex.variants(with_jit=True, without_jit=True)
91 | @parameterized.parameters(
92 | (16, 2, 5, False, [5, 16], [5, 2, 5]),
93 | (16, 2, 5, True, [5, 2, 5], [5, 16]),
94 | (256, 16, 16, True, [2, 16, 16], [2, 256]),
95 | )
96 | def test_mhd_projection(
97 | self,
98 | input_dim: int,
99 | num_heads: int,
100 | dim_per_head: int,
101 | is_output_projection: bool,
102 | inputs_shape: list[int],
103 | expected_outputs_shape: list[int],
104 | ):
105 | np_inputs = np.random.normal(1.5, 2.0, inputs_shape).astype(np.float32)
106 | inputs = jnp.asarray(np_inputs)
107 | prng_key = jax.random.PRNGKey(seed=123)
108 | mh = layers.AttentionProjection(
109 | name='mh',
110 | output_dim=input_dim,
111 | num_heads=num_heads,
112 | dim_per_head=dim_per_head,
113 | is_output_projection=is_output_projection,
114 | )
115 |
116 | @self.variant
117 | def var_fn():
118 | return mh.init_with_output(prng_key, inputs)
119 |
120 | outputs, params = var_fn()
121 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 2)
122 | self.assertEqual(outputs.shape, tuple(expected_outputs_shape))
123 |
124 | @chex.variants(with_jit=True, without_jit=True)
125 | def test_per_dim_scale(self):
126 | np_inputs = np.random.normal(1.5, 2.0, [5, 4]).astype(np.float32)
127 | inputs = jnp.asarray(np_inputs)
128 | prng_key = jax.random.PRNGKey(seed=123)
129 | mdl = layers.PerDimScale(name='scale')
130 |
131 | @self.variant
132 | def var_fn():
133 | return mdl.init_with_output(prng_key, inputs)
134 |
135 | outputs, params = var_fn()
136 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 1)
137 | self.assertEqual(outputs.shape, (5, 4))
138 |
139 | @chex.variants(with_jit=True, without_jit=True)
140 | @parameterized.product(
141 | enable_query_scale=[True, False],
142 | enable_per_dim_scale=[True, False],
143 | train=[True, False],
144 | )
145 | def test_mha(
146 | self,
147 | enable_query_scale: bool,
148 | enable_per_dim_scale: bool,
149 | train: bool,
150 | ):
151 | batch_size, seq_len, num_heads, mdl_dim = 3, 8, 4, 16
152 | query_vec = np.random.normal(size=[batch_size, seq_len, mdl_dim]).astype(
153 | np.float32
154 | )
155 | key_vec = np.random.normal(size=[batch_size, seq_len, mdl_dim]).astype(
156 | np.float32
157 | )
158 | value_vec = np.random.normal(size=[batch_size, seq_len, mdl_dim]).astype(
159 | np.float32
160 | )
161 | paddings = jnp.zeros(query_vec.shape[:-1], dtype=query_vec.dtype)
162 | atten_mask = layers.compute_attention_masks_for_fprop(query_vec, paddings)
163 | prng_key = jax.random.PRNGKey(seed=123)
164 | mha = layers.DotProductAttention(
165 | name='mha',
166 | hidden_dim=32,
167 | num_heads=num_heads,
168 | atten_logit_cap=20.0,
169 | internal_enable_query_scale=enable_query_scale,
170 | internal_enable_per_dim_scale=enable_per_dim_scale,
171 | )
172 |
173 | @self.variant
174 | def var_fn():
175 | return mha.init_with_output(
176 | prng_key, query_vec, key_vec, value_vec, atten_mask, train=train
177 | )
178 |
179 | (outputs, probs), params = var_fn()
180 | expected_num_weights = 8
181 | if enable_query_scale and enable_per_dim_scale:
182 | expected_num_weights += 1
183 | self.assertLen(jax.tree_util.tree_flatten(params)[0], expected_num_weights)
184 | self.assertEqual(outputs.shape, (batch_size, seq_len, mdl_dim))
185 | self.assertEqual(probs.shape, (batch_size, num_heads, seq_len, seq_len))
186 |
187 | @chex.variants(with_jit=True, without_jit=True)
188 | @parameterized.parameters(True, False)
189 | def test_transformer_layer(self, train: bool):
190 | num_heads, batch_size, seq_len, dim = 8, 3, 12, 32
191 | np_inputs = np.random.normal(1.0, 0.5, [batch_size, seq_len, dim]).astype(
192 | 'float32'
193 | )
194 | inputs = jnp.asarray(np_inputs)
195 | np_paddings = np.random.randint(0, 1, [batch_size, seq_len]).astype(
196 | 'float32'
197 | )
198 | paddings = jnp.asarray(np_paddings)
199 | atten_mask = layers.compute_attention_masks_for_fprop(inputs, paddings)
200 | prng_key = jax.random.PRNGKey(seed=123)
201 | tfm = layers.Transformer(
202 | name='tfm',
203 | hidden_dim=128,
204 | num_heads=num_heads,
205 | )
206 |
207 | @self.variant
208 | def var_fn():
209 | return tfm.init_with_output(
210 | prng_key, inputs, paddings, atten_mask=atten_mask, train=train
211 | )
212 |
213 | outputs, params = var_fn()
214 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 17)
215 | self.assertEqual(outputs.shape, (batch_size, seq_len, dim))
216 |
217 | @chex.variants(with_jit=True, without_jit=True)
218 | @parameterized.product(
219 | scan=[True, False],
220 | train=[True, False],
221 | )
222 | def test_stacked_transformer_layer(self, scan: bool, train: bool):
223 | batch_size, seq_len, dim = 3, 12, 16
224 | np_inputs = np.random.normal(1.0, 0.5, [batch_size, seq_len, dim]).astype(
225 | 'float32'
226 | )
227 | inputs = jnp.asarray(np_inputs)
228 | np_paddings = np.random.randint(0, 1, [batch_size, seq_len]).astype(
229 | 'float32'
230 | )
231 | paddings = jnp.asarray(np_paddings)
232 | prng_key = jax.random.PRNGKey(seed=123)
233 | stacked_tfm = layers.StackedTransformer(
234 | name='stacked_tfm',
235 | hidden_dim=64,
236 | num_heads=8,
237 | num_layers=4,
238 | scan=scan,
239 | )
240 |
241 | @self.variant
242 | def var_fn():
243 | return stacked_tfm.init_with_output(
244 | prng_key, inputs, paddings, train=train
245 | )
246 |
247 | outputs, params = var_fn()
248 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 17 if scan else 68)
249 | self.assertEqual(outputs.shape, (batch_size, seq_len, dim))
250 |
251 | @chex.variants(with_jit=True, without_jit=True)
252 | @parameterized.product(
253 | num_queries=[1, 4],
254 | train=[True, False],
255 | )
256 | def test_atten_token_pooling_layer(
257 | self,
258 | num_queries: int,
259 | train: bool,
260 | ):
261 | batch_size, seq_len, num_heads, input_dim = 3, 8, 4, 16
262 | np_inputs = np.random.normal(
263 | 1.5, 2.0, [batch_size, seq_len, input_dim]
264 | ).astype(np.float32)
265 | np_paddings = np.zeros([batch_size, seq_len], dtype=np.float32)
266 | inputs = jnp.asarray(np_inputs)
267 | input_paddings = jnp.asarray(np_paddings)
268 | prng_key = jax.random.PRNGKey(seed=123)
269 | pooler = layers.AttenTokenPoolingLayer(
270 | name='pooling',
271 | num_heads=num_heads,
272 | num_queries=num_queries,
273 | )
274 |
275 | @self.variant
276 | def var_fn():
277 | return pooler.init_with_output(
278 | prng_key, inputs, input_paddings, train=train
279 | )
280 |
281 | outputs, params = var_fn()
282 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 12)
283 | self.assertEqual(outputs.shape, (batch_size, num_queries, input_dim))
284 |
285 |
286 | if __name__ == '__main__':
287 | absltest.main()
288 |
--------------------------------------------------------------------------------
/videoprism/colabs/videoprism_video_text_demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": []
7 | },
8 | "kernelspec": {
9 | "name": "python3",
10 | "display_name": "Python 3"
11 | },
12 | "language_info": {
13 | "name": "python"
14 | }
15 | },
16 | "cells": [
17 | {
18 | "cell_type": "markdown",
19 | "source": [
20 | "# VideoPrism Video-Text Encoder Demo\n",
21 | "\n",
22 | "[](https://arxiv.org/abs/2402.13217)\n",
23 | "[](https://research.google/blog/videoprism-a-foundational-visual-encoder-for-video-understanding/)\n",
24 | "[](https://opensource.org/licenses/Apache-2.0)\n",
25 | "\n",
26 | "This notebook provides an example of video and text feature extraction with a pre-trained VideoPrism video-text model for zero-shot video classification/retrieval.\n",
27 | "\n",
28 | "Please run this demo on Google Colab with (faster) or without TPU."
29 | ],
30 | "metadata": {
31 | "id": "KPPUiCpSbm53"
32 | }
33 | },
34 | {
35 | "cell_type": "markdown",
36 | "source": [
37 | "## Set up"
38 | ],
39 | "metadata": {
40 | "id": "k08qFZ9-cn9v"
41 | }
42 | },
43 | {
44 | "cell_type": "code",
45 | "source": [
46 | "# @title Prepare environment\n",
47 | "\n",
48 | "import os\n",
49 | "import sys\n",
50 | "\n",
51 | "# Fetch VideoPrism repository if Python does not know about it and install\n",
52 | "# dependencies needed for this notebook.\n",
53 | "if not os.path.exists(\"videoprism_repo\"):\n",
54 | " !git clone --quiet --branch=main --depth=1 \\\n",
55 | " https://github.com/google-deepmind/videoprism.git videoprism_repo\n",
56 | " os.chdir('./videoprism_repo')\n",
57 | " !pip install .\n",
58 | " os.chdir('..')\n",
59 | "\n",
60 | "# Append VideoPrism code to Python import path.\n",
61 | "if \"videoprism_repo\" not in sys.path:\n",
62 | " sys.path.append(\"videoprism_repo\")\n",
63 | "\n",
64 | "# Install missing dependencies.\n",
65 | "!pip install mediapy\n",
66 | "\n",
67 | "import jax\n",
68 | "from jax.extend import backend\n",
69 | "import tensorflow as tf\n",
70 | "\n",
71 | "# Do not let TF use the GPU or TPUs.\n",
72 | "tf.config.set_visible_devices([], \"GPU\")\n",
73 | "tf.config.set_visible_devices([], \"TPU\")\n",
74 | "\n",
75 | "print(f\"JAX version: {jax.__version__}\")\n",
76 | "print(f\"JAX platform: {backend.get_backend().platform}\")\n",
77 | "print(f\"JAX devices: {jax.device_count()}\")"
78 | ],
79 | "metadata": {
80 | "id": "1dfyX8EyVsvL"
81 | },
82 | "execution_count": null,
83 | "outputs": []
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": null,
88 | "metadata": {
89 | "id": "zByA1K0IVKAI"
90 | },
91 | "outputs": [],
92 | "source": [
93 | "# @title Load dependencies and define utilities\n",
94 | "\n",
95 | "import mediapy\n",
96 | "import numpy as np\n",
97 | "from PIL import Image\n",
98 | "\n",
99 | "\n",
100 | "def read_and_preprocess_video(\n",
101 | " filename: str, target_num_frames: int, target_frame_size: tuple[int, int]\n",
102 | "):\n",
103 | " \"\"\"Reads and preprocesses a video.\"\"\"\n",
104 | "\n",
105 | " frames = mediapy.read_video(filename)\n",
106 | "\n",
107 | " # Sample to target number of frames.\n",
108 | " frame_indices = np.linspace(\n",
109 | " 0, len(frames), num=target_num_frames, endpoint=False, dtype=np.int32\n",
110 | " )\n",
111 | " frames = np.array([frames[i] for i in frame_indices])\n",
112 | "\n",
113 | " # Resize to target size.\n",
114 | " original_height, original_width = frames.shape[-3:-1]\n",
115 | " target_height, target_width = target_frame_size\n",
116 | " assert (\n",
117 | " original_height * target_width == original_width * target_height\n",
118 | " ), 'Currently does not support aspect ratio mismatch.'\n",
119 | " frames = mediapy.resize_video(frames, shape=target_frame_size)\n",
120 | "\n",
121 | " # Normalize pixel values to [0.0, 1.0].\n",
122 | " frames = mediapy.to_float01(frames)\n",
123 | "\n",
124 | " return frames\n",
125 | "\n",
126 | "\n",
127 | "def compute_similarity_matrix(\n",
128 | " video_embeddings,\n",
129 | " text_embeddings,\n",
130 | " temperature: float,\n",
131 | " apply_softmax: str | None = None,\n",
132 | ") -\u003e np.ndarray:\n",
133 | " \"\"\"Computes cosine similarity matrix.\"\"\"\n",
134 | " assert apply_softmax in [None, 'over_texts', 'over_videos']\n",
135 | " emb_dim = video_embeddings[0].shape[-1]\n",
136 | " assert emb_dim == text_embeddings[0].shape[-1]\n",
137 | "\n",
138 | " video_embeddings = np.array(video_embeddings).reshape(-1, emb_dim)\n",
139 | " text_embeddings = np.array(text_embeddings).reshape(-1, emb_dim)\n",
140 | " similarity_matrix = np.dot(video_embeddings, text_embeddings.T)\n",
141 | "\n",
142 | " if temperature is not None:\n",
143 | " similarity_matrix /= temperature\n",
144 | "\n",
145 | " if apply_softmax == 'over_videos':\n",
146 | " similarity_matrix = np.exp(similarity_matrix)\n",
147 | " similarity_matrix = similarity_matrix / np.sum(\n",
148 | " similarity_matrix, axis=0, keepdims=True\n",
149 | " )\n",
150 | " elif apply_softmax == 'over_texts':\n",
151 | " similarity_matrix = np.exp(similarity_matrix)\n",
152 | " similarity_matrix = similarity_matrix / np.sum(\n",
153 | " similarity_matrix, axis=1, keepdims=True\n",
154 | " )\n",
155 | "\n",
156 | " return similarity_matrix"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "source": [
162 | "# @title Load model\n",
163 | "\n",
164 | "import jax\n",
165 | "import jax.numpy as jnp\n",
166 | "from videoprism import models as vp\n",
167 | "\n",
168 | "MODEL_NAME = 'videoprism_lvt_public_v1_base' # @param ['videoprism_lvt_public_v1_base', 'videoprism_lvt_public_v1_large'] {allow-input: false}\n",
169 | "USE_BFLOAT16 = False # @param { type: \"boolean\" }\n",
170 | "NUM_FRAMES = 16\n",
171 | "FRAME_SIZE = 288\n",
172 | "\n",
173 | "fprop_dtype = jnp.bfloat16 if USE_BFLOAT16 else None\n",
174 | "flax_model = vp.get_model(MODEL_NAME, fprop_dtype=fprop_dtype)\n",
175 | "loaded_state = vp.load_pretrained_weights(MODEL_NAME)\n",
176 | "text_tokenizer = vp.load_text_tokenizer('c4_en')\n",
177 | "\n",
178 | "\n",
179 | "@jax.jit\n",
180 | "def forward_fn(inputs, text_token_ids, text_paddings, train=False):\n",
181 | " return flax_model.apply(\n",
182 | " loaded_state,\n",
183 | " inputs,\n",
184 | " text_token_ids,\n",
185 | " text_paddings,\n",
186 | " train=train,\n",
187 | " )"
188 | ],
189 | "metadata": {
190 | "id": "WnYuzSgrXCL1"
191 | },
192 | "execution_count": null,
193 | "outputs": []
194 | },
195 | {
196 | "cell_type": "markdown",
197 | "source": [
198 | "# Example: Zero-shot Video Classification/Retrieval\n",
199 | "\n",
200 | "In this example, we extract the embedding of an input video, and the embeddings of five senetence. We measure the cosine similarites between the videos and sentences."
201 | ],
202 | "metadata": {
203 | "id": "AliScLC0jo1s"
204 | }
205 | },
206 | {
207 | "cell_type": "code",
208 | "source": [
209 | "# @title Specify input video\n",
210 | "VIDEO_FILE_PATH = 'videoprism_repo/videoprism/assets/water_bottle_drumming.mp4' # @param {type: \"string\"}\n",
211 | "\n",
212 | "frames = read_and_preprocess_video(\n",
213 | " VIDEO_FILE_PATH,\n",
214 | " target_num_frames=NUM_FRAMES,\n",
215 | " target_frame_size=[FRAME_SIZE, FRAME_SIZE],\n",
216 | ")\n",
217 | "frames = jnp.asarray(frames[None, ...]) # Add batch dimension.\n",
218 | "if USE_BFLOAT16:\n",
219 | " frames = frames.astype(jnp.bfloat16)"
220 | ],
221 | "metadata": {
222 | "id": "sESN_CjfEiQR"
223 | },
224 | "execution_count": null,
225 | "outputs": []
226 | },
227 | {
228 | "cell_type": "code",
229 | "source": [
230 | "# @title Specify input text queries\n",
231 | "TEXT_QUERY_CSV = 'playing drums,sitting,playing flute,playing at playground,concert' # @param {type: \"string\"}\n",
232 | "PROMPT_TEMPLATE = 'a video of {}.'\n",
233 | "\n",
234 | "text_queries = TEXT_QUERY_CSV.split(',')\n",
235 | "text_queries = [PROMPT_TEMPLATE.format(t) for t in text_queries]\n",
236 | "text_ids, text_paddings = vp.tokenize_texts(text_tokenizer, text_queries)\n",
237 | "if USE_BFLOAT16:\n",
238 | " text_paddings = text_paddings.astype(jnp.bfloat16)\n",
239 | "\n",
240 | "print('Input text queries:')\n",
241 | "for i, text in enumerate(text_queries):\n",
242 | " print(f'({i + 1}) {text}')"
243 | ],
244 | "metadata": {
245 | "id": "kLzkhP8CYUYj"
246 | },
247 | "execution_count": null,
248 | "outputs": []
249 | },
250 | {
251 | "cell_type": "code",
252 | "source": [
253 | "# @title Compute video-to-text retrieval results\n",
254 | "video_embeddings, text_embeddings, _ = forward_fn(\n",
255 | " frames, text_ids, text_paddings)\n",
256 | "\n",
257 | "TEMPERATURE = 0.01 # @param {type: \"number\"}\n",
258 | "similarity_matrix = compute_similarity_matrix(\n",
259 | " video_embeddings,\n",
260 | " text_embeddings,\n",
261 | " temperature=TEMPERATURE,\n",
262 | " apply_softmax='over_texts',\n",
263 | ")"
264 | ],
265 | "metadata": {
266 | "id": "bfwT93Yz5oi_"
267 | },
268 | "execution_count": null,
269 | "outputs": []
270 | },
271 | {
272 | "cell_type": "code",
273 | "source": [
274 | "v2t_similarity_vector = similarity_matrix[0]\n",
275 | "top_indices = np.argsort(v2t_similarity_vector)[::-1]\n",
276 | "\n",
277 | "print(f'Query video: {os.path.basename(VIDEO_FILE_PATH)}')\n",
278 | "mediapy.show_video(frames[0].astype(jnp.float32), fps=6.0)\n",
279 | "\n",
280 | "for k, j in enumerate(top_indices):\n",
281 | " print(\n",
282 | " 'Top-%d retrieved text: %s [Similarity = %0.4f]'\n",
283 | " % (k + 1, text_queries[j], v2t_similarity_vector[j])\n",
284 | " )\n",
285 | "print(f'\\nThis is {text_queries[top_indices[0]]}')"
286 | ],
287 | "metadata": {
288 | "id": "lZ8woxde6t_S"
289 | },
290 | "execution_count": null,
291 | "outputs": []
292 | }
293 | ]
294 | }
295 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VideoPrism: A Foundational Visual Encoder for Video Understanding
2 |
3 | [](https://arxiv.org/abs/2402.13217)
4 | [](https://research.google/blog/videoprism-a-foundational-visual-encoder-for-video-understanding/)
5 | [](https://colab.research.google.com/github/google-deepmind/videoprism/blob/main/videoprism/colabs/videoprism_video_encoder_demo.ipynb)
6 | [](https://colab.research.google.com/github/google-deepmind/videoprism/blob/main/videoprism/colabs/videoprism_video_text_demo.ipynb)
7 | [](https://huggingface.co/collections/google/videoprism-686e823d6070ec6ad9e4b1f2)
8 | [](https://opensource.org/licenses/Apache-2.0)
9 |
10 | [VideoPrism](https://arxiv.org/abs/2402.13217) is a general-purpose video
11 | encoder designed to handle a wide spectrum of video understanding tasks,
12 | including classification, retrieval, localization, captioning, and question
13 | answering. It is pre-trained on a massive and diverse dataset: 1 billion
14 | image-text pairs from [WebLI](https://arxiv.org/abs/2209.06794), 36 million
15 | high-quality video-text pairs, and 582 million video clips with noisy or
16 | machine-generated parallel text (subject to data wipeout). The pre-training
17 | approach is designed for this hybrid data, learning both from video-text pairs
18 | and from the videos themselves. VideoPrism is fairly easy to adapt to new video
19 | understanding tasks, and achieves state-of-the-art performance on 31 out of 33
20 | public video understanding benchmarks using a single frozen model.
21 |
22 | This repository releases the model weight checkpoints and hosts [JAX](https://github.com/jax-ml/jax)/[Flax](https://github.com/google/flax) utility
23 | functions for checkpoint loading and model inference.
24 |
25 | ## Updates
26 |
27 | * **[Jul-16-25]:** Released VideoPrism video-text encoders for cross-modal retrieval [[`Colab notebook`](https://colab.research.google.com/github/google-deepmind/videoprism/blob/main/videoprism/colabs/videoprism_video_text_demo.ipynb)]. :fire::fire:
28 | * **[Jun-15-25]:** Added models to [[`Hugging Face`](https://huggingface.co/collections/google/videoprism-686e823d6070ec6ad9e4b1f2)].
29 | * **[Jun-05-25]:** Added video encoder demo [[`Colab notebook`](https://colab.research.google.com/github/google-deepmind/videoprism/blob/main/videoprism/colabs/videoprism_video_encoder_demo.ipynb)].
30 | * **[Jun-03-25]:** Released VideoPrism video encoders (Base and Large) [[`Blog`](https://research.google/blog/videoprism-a-foundational-visual-encoder-for-video-understanding/)] [[`Paper`](https://arxiv.org/abs/2402.13217)]. :fire::fire:
31 |
32 | ## TODOs
33 |
34 | - [ ] Add PyTorch model support.
35 |
36 | ## Getting started
37 |
38 | You will need Python 3.9 or later. Download the code from GitHub and run:
39 |
40 | ```shell
41 | $ git clone https://github.com/google-deepmind/videoprism.git
42 | $ cd videoprism
43 | $ pip install .
44 | ```
45 |
46 | Please get started with the following example code for model checkpoint loading
47 | and inference or use the [Colab notebook for video encoders](https://colab.research.google.com/github/google-deepmind/videoprism/blob/main/videoprism/colabs/videoprism_video_encoder_demo.ipynb) / [Colab notebook for video-text encoders](https://colab.research.google.com/github/google-deepmind/videoprism/blob/main/videoprism/colabs/videoprism_video_text_demo.ipynb):
48 |
49 | ```python
50 | import jax
51 | from videoprism import models as vp
52 |
53 | # Video encoders.
54 | model_name = 'videoprism_public_v1_base' # configuration name
55 | flax_model = vp.get_model(model_name)
56 | loaded_state = vp.load_pretrained_weights(model_name)
57 |
58 | @jax.jit
59 | def forward_fn(inputs):
60 | return flax_model.apply(loaded_state, inputs, train=False)
61 |
62 | video_inputs = ... # Shape = [batch_size, num_frames, height, width, 3].
63 | outputs, _ = forward_fn(video_inputs) # Shape = [batch_size, num_tokens, feature_channels].
64 |
65 | # Video-text encoders.
66 | model_name = 'videoprism_lvt_public_v1_base' # configuration name
67 | flax_model = vp.get_model(model_name)
68 | loaded_state = vp.load_pretrained_weights(model_name)
69 | text_tokenizer = vp.load_text_tokenizer('c4_en')
70 |
71 | @jax.jit
72 | def forward_fn(inputs, text_token_ids, text_token_paddings, train=False):
73 | return flax_model.apply(
74 | loaded_state,
75 | inputs,
76 | text_token_ids,
77 | text_token_paddings,
78 | train=train,
79 | )
80 |
81 | video_inputs = ... # Shape = [batch_size, num_frames, height, width, 3].
82 | text_queries = ... # A list of input text queries.
83 | text_ids, text_paddings = vp.tokenize_texts(text_tokenizer, text_queries)
84 | video_embeddings, text_embeddings, _ = forward_fn(
85 | video_inputs, text_ids, text_paddings) # Shape = [batch_size, feature_channels].
86 | ```
87 |
88 | ## Released models
89 |
90 | We release the following model variants:
91 |
92 | | Model Name | Configuration Name | Model Type | Backbone | #Params | File Size | Checkpoint |
93 | | -------- | -------- | ------- | :-------: | :-------: | :-------: | :-------: |
94 | | VideoPrism-B | `videoprism_public_v1_base` | Video encoder | ViT-B | 114M | 458MB | [link](https://huggingface.co/google/videoprism-base-f16r288) |
95 | | VideoPrism-L | `videoprism_public_v1_large` | Video encoder | ViT-L | 354M | 1.42GB | [link](https://huggingface.co/google/videoprism-large-f8r288) |
96 | | VideoPrism-LvT-B | `videoprism_lvt_public_v1_base` | Video-text encoders | ViT-B | 248M | 991MB | [link](https://huggingface.co/google/videoprism-lvt-base-f16r288) |
97 | | VideoPrism-LvT-L | `videoprism_lvt_public_v1_large` | Video-text encoders | ViT-L | 580M | 2.30GB | [link](https://huggingface.co/google/videoprism-lvt-large-f8r288) |
98 |
99 | Video encoders take videos with shape `(batch_size, num_frames, 288, 288, 3)`
100 | as inputs and output embeddings with shape
101 | `(batch_size, num_frames * 16 * 16, feature_channels)`, which can be reshaped
102 | into `(batch_size, num_frames, 16, 16, feature_channels)` for spatiotemporal
103 | representations. During model training, `num_frames` is set to 16 for
104 | VideoPrism-B and to 8 for VideoPrism-L. Both models are expected to work with
105 | arbitrary `num_frames` by interpolating the temporal positional embeddings.
106 | The RGB values of input videos should be normalized to the range [0.0, 1.0].
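107 |
108 | For example, a minimal self-contained sketch of running the base video encoder
109 | on a dummy clip and reshaping its output tokens into a spatiotemporal grid:
110 |
111 | ```python
112 | import jax.numpy as jnp
113 | from videoprism import models as vp
114 |
115 | model_name = 'videoprism_public_v1_base'
116 | flax_model = vp.get_model(model_name)
117 | loaded_state = vp.load_pretrained_weights(model_name)
118 |
119 | # Dummy clip for illustration; real RGB inputs must lie in [0.0, 1.0].
120 | batch_size, num_frames = 1, 16
121 | videos = jnp.zeros((batch_size, num_frames, 288, 288, 3), dtype=jnp.float32)
122 |
123 | embeddings, _ = flax_model.apply(loaded_state, videos, train=False)
124 | # (batch_size, num_frames * 16 * 16, channels) -> (B, T, 16, 16, channels).
125 | tokens = embeddings.reshape(batch_size, num_frames, 16, 16, -1)
126 | ```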
107 |
108 | In video-text models, both the video and text encoders produce global
109 | embeddings with shape `(batch_size, feature_channels)`, whose similarity can be
110 | measured by cosine similarity. We use the `c4_en` [SentencePiece](https://github.com/google/sentencepiece) model for text tokenization. During inference, the embedding
111 | calculation for either modality can be skipped by passing `None` as the corresponding input.
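112 |
113 | As a sketch, video-to-text retrieval scores can then be computed from the
114 | global embeddings returned by `forward_fn` in the example above, by
115 | normalizing them and taking dot products (i.e., cosine similarities):
116 |
117 | ```python
118 | import jax.numpy as jnp
119 |
120 | def cosine_similarities(video_embeddings, text_embeddings):
121 |   """Returns a [num_videos, num_texts] cosine similarity matrix."""
122 |   v = video_embeddings / jnp.linalg.norm(video_embeddings, axis=-1, keepdims=True)
123 |   t = text_embeddings / jnp.linalg.norm(text_embeddings, axis=-1, keepdims=True)
124 |   return v @ t.T
125 |
126 | scores = cosine_similarities(video_embeddings, text_embeddings)
127 | best_texts = jnp.argmax(scores, axis=-1)  # Best matching text per video.
128 | ```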
112 |
113 | ## Results with frozen backbones
114 |
115 | *"Public"* denotes models we released in this repository. *"Paper"* and
116 | *"Prior SOTA"* denote our models and previous best-performing models reported
117 | in the [paper](https://arxiv.org/abs/2402.13217), respectively. Our *public*
118 | models perform slightly worse than the *paper* models due to different
119 | pre-training image-text data we used subject to data policy.
120 |
121 | ### Video-focused tasks ([VideoGLUE](https://arxiv.org/abs/2307.03166))
122 |
123 | | Models | K400 | MiT | SSv2 | D48 | Charades | ActivityNet | AVA | AVA-K |
124 | | -------- | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: |
125 | | **VideoPrism-B (public)** | 82.9 | 39.7 | 62.2 | 64.3 | 43.5 | 36.5 | 28.3 | 30.8 |
126 | | **VideoPrism-L (public)** | 85.0 | 43.3 | 64.6 | 67.6 | 53.2 | 37.0 | 32.4 | 34.5 |
127 | | VideoPrism-B (paper) | 84.2 | 40.8 | 63.6 | 67.4 | 40.4 | 36.6 | 30.6 | 31.8 |
128 | | VideoPrism-g (paper) | 87.2 | 45.5 | 68.5 | 71.3 | 62.3 | 37.8 | 36.2 | 37.3 |
129 | | Prior SOTA (B) | 77.1 | 34.0 | 58.2 | 55.6 | 33.3 | 35.8 | 21.1 | 25.9 |
130 | | Prior SOTA (L+) | 82.8 | 40.3 | 67.4 | 69.6 | 39.9 | 36.7 | 24.4 | 26.2 |
131 |
132 | ### Zero-shot video-text retrieval
133 |
134 | | Models | MSRVTT-1K (v2t) | MSRVTT-1K (t2v) | VATEX (v2t) | VATEX (t2v) | ActivityNet (v2t) | ActivityNet (t2v) |
135 | | -------- | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: |
136 | | **VideoPrism-LvT-B (public)** | 49.8 | 50.1 | 73.1 | 56.2 | 47.9 | 48.8 |
137 | | **VideoPrism-LvT-L (public)** | 50.6 | 50.1 | 75.0 | 57.2 | 49.1 | 51.3 |
138 | | VideoPrism-LvT-B (paper) | 50.2 | 51.4 | 76.2 | 57.7 | 47.9 | 49.6 |
139 | | VideoPrism-LvT-g (paper) | 51.7 | 52.7 | 77.1 | 62.5 | 50.3 | 52.7 |
140 | | Prior SOTA (B) | - | 34.0 | - | - | - | 30.6 |
141 | | Prior SOTA (L+) | 45.4 | 43.9 | 73.6 | 53.2 | 40.7 | 42.8 |
142 |
143 | ### Zero-shot video classification
144 |
145 | | Models | K400 | SSv2 (Temporal) | SSv2 (Events) | NExT-QA (Hard) | Charades | Charades (STA) |
146 | | -------- | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: |
147 | | **VideoPrism-LvT-B (public)** | 69.2 | 14.6 | 11.3 | 31.1 | 26.9 | 48.6 |
148 | | **VideoPrism-LvT-L (public)** | 72.4 | 18.0 | 12.4 | 32.1 | 32.4 | 50.2 |
149 | | VideoPrism-LvT-B (paper) | 71.3 | 16.1 | 11.9 | 31.3 | 29.2 | 50.0 |
150 | | VideoPrism-LvT-g (paper) | 74.6 | 18.6 | 15.7 | 32.7 | 32.4 | 50.4 |
151 | | Prior SOTA (B) | - | 9.8 | 6.4 | 27.6 | 21.1 | - |
152 | | Prior SOTA (L+) | 72.0 | 15.2 | 11.4 | 25.2 | 25.8 | 47.2 |
153 |
154 | ## Citation
155 |
156 | If you use VideoPrism, please cite the following papers:
157 |
158 |
159 | ```bibtex
160 | @inproceedings{zhao2024videoprism,
161 | title = {{VideoPrism}: A Foundational Visual Encoder for Video Understanding},
162 | author = {Long Zhao and Nitesh B. Gundavarapu and Liangzhe Yuan and Hao Zhou and Shen Yan and Jennifer J. Sun and Luke Friedman and Rui Qian and Tobias Weyand and Yue Zhao and Rachel Hornung and Florian Schroff and Ming-Hsuan Yang and David A. Ross and Huisheng Wang and Hartwig Adam and Mikhail Sirotenko and Ting Liu and Boqing Gong},
163 | booktitle = {International Conference on Machine Learning (ICML)},
164 | year = {2024}
165 | }
166 |
167 | @article{yuan2024videoglue,
168 | title = {{VideoGLUE}: Video General Understanding Evaluation of Foundation Models},
169 | author = {Liangzhe Yuan and Nitesh Bharadwaj Gundavarapu and Long Zhao and Hao Zhou and Yin Cui and Lu Jiang and Xuan Yang and Menglin Jia and Tobias Weyand and Luke Friedman and Mikhail Sirotenko and Huisheng Wang and Florian Schroff and Hartwig Adam and Ming-Hsuan Yang and Ting Liu and Boqing Gong},
170 | journal = {Transactions on Machine Learning Research (TMLR)},
171 | year = {2024}
172 | }
173 | ```
174 |
175 | ## License
176 |
177 | Copyright 2025 Google LLC
178 |
179 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0);
180 | you may not use this file except in compliance with the Apache 2.0 license. You
181 | may obtain a copy of the Apache 2.0 license at: https://www.apache.org/licenses/LICENSE-2.0
182 |
183 | All other materials are licensed under the Creative Commons Attribution 4.0 International License (CC-BY). You may obtain a copy of the CC-BY license at: https://creativecommons.org/licenses/by/4.0/legalcode
184 |
185 | Unless required by applicable law or agreed to in writing, all software and
186 | materials distributed here under the Apache 2.0 or CC-BY licenses are
187 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
188 | either express or implied. See the licenses for the specific language governing
189 | permissions and limitations under those licenses.
190 |
191 | ## Disclaimer
192 |
193 | This is not an official Google product.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/videoprism/encoders_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 VideoPrism Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for encoder modules."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import chex
20 | import jax
21 | from jax import numpy as jnp
22 | import numpy as np
23 | from videoprism import encoders
24 |
25 |
26 | class EncodersTest(parameterized.TestCase):
27 |
28 | @chex.variants(with_jit=True, without_jit=True)
29 | def test_embedding_layer(self):
30 | num_classes, dim, input_shape = 8, 10, (5, 20)
31 | npy_input = np.random.randint(0, num_classes, input_shape).astype('int32')
32 | inputs = jnp.asarray(npy_input)
33 | prng_key = jax.random.PRNGKey(seed=123)
34 | emb_layer = encoders.Embedding(
35 | name='emb_lookup',
36 | num_classes=num_classes,
37 | input_dim=dim,
38 | scale_sqrt_depth=True,
39 | )
40 |
41 | @self.variant
42 | def var_fn():
43 | return emb_layer.init_with_output(prng_key, inputs)
44 |
45 | outputs, params = var_fn()
46 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 1)
47 | self.assertEqual(outputs.shape, input_shape + (dim,))
48 |
49 | @chex.variants(with_jit=True, without_jit=True)
50 | def test_positional_embedding_layer(self):
51 | seq_len, dim = 8, 10
52 | prng_key = jax.random.PRNGKey(seed=123)
53 | emb_layer = encoders.PositionalEmbedding(
54 | name='pos_emb',
55 | embedding_dim=dim,
56 | min_timescale=10,
57 | max_timescale=20,
58 | )
59 |
60 | @self.variant
61 | def var_fn():
62 | return emb_layer.init_with_output(prng_key, seq_len)
63 |
64 | outputs, params = var_fn()
65 | self.assertEmpty(jax.tree_util.tree_flatten(params)[0])
66 | self.assertEqual(outputs.shape, (1, seq_len, dim))
67 |
68 | @chex.variants(with_jit=True, without_jit=True)
69 | def test_trainable_positional_embedding_layer(self):
70 | seq_len, dim = 8, 10
71 | prng_key = jax.random.PRNGKey(seed=123)
72 | emb_layer = encoders.TrainablePositionalEmbedding(
73 | name='pos_emb',
74 | max_seq_length=seq_len,
75 | embedding_dim=dim,
76 | lookup_style='matmul',
77 | )
78 |
79 | @self.variant
80 | def var_fn():
81 | return emb_layer.init_with_output(prng_key, seq_len)
82 |
83 | outputs, params = var_fn()
84 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 1)
85 | self.assertEqual(outputs.shape, (1, seq_len, dim))
86 |
87 | @chex.variants(with_jit=True)
88 | @parameterized.product(
89 | scan=[True, False],
90 | train=[True, False],
91 | )
92 | def test_vision_transformer(self, scan: bool, train: bool):
93 | batch_size, seq_len, dim = 1, 6, 4
94 | np_inputs = np.random.normal(1.0, 0.5, [batch_size, seq_len, dim]).astype(
95 | 'float32'
96 | )
97 | inputs = jnp.asarray(np_inputs)
98 | prng_key = jax.random.PRNGKey(seed=123)
99 | vit = encoders.VisionTransformer(
100 | name='vit',
101 | num_tfm_layers=2,
102 | mlp_dim=4,
103 | num_heads=2,
104 | scan=scan,
105 | )
106 |
107 | @self.variant
108 | def var_fn():
109 | return vit.init_with_output(prng_key, inputs, train=train)
110 |
111 | outputs, params = var_fn()
112 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 16 if scan else 32)
113 | self.assertEqual(outputs.shape, (batch_size, seq_len, dim))
114 |
115 | @chex.variants(with_jit=True)
116 | @parameterized.named_parameters(
117 | ('train', False, True, False, False),
118 | ('scan', True, False, False, False),
119 | ('scan_and_train', True, True, False, False),
120 | ('return_intermediate', True, False, True, False),
121 | ('use_frame_paddings', True, False, False, True),
122 | )
123 | def test_factorized_encoder(
124 | self,
125 | scan: bool,
126 | train: bool,
127 | return_intermediate: bool,
128 | use_frame_paddings: bool,
129 | ):
130 | batch_size, num_frames, image_size, patch_size, dim = 1, 4, 16, 4, 8
131 | np_inputs = np.random.normal(
132 | 0.0,
133 | 0.1,
134 | [batch_size, num_frames, image_size, image_size, 3],
135 | ).astype('float32')
136 | inputs = jnp.asarray(np_inputs)
137 |
138 | frame_paddings = None
139 | if use_frame_paddings:
140 | np_frame_paddings = np.zeros((batch_size, num_frames), dtype='float32')
141 | np_frame_paddings[:, num_frames // 2 :] = 1
142 | frame_paddings = jnp.asarray(np_frame_paddings)
143 |
144 | prng_key = jax.random.PRNGKey(seed=123)
145 | enc = encoders.FactorizedEncoder(
146 | name='enc',
147 | patch_size=patch_size,
148 | pos_emb_shape=(16, 16, 16),
149 | model_dim=dim,
150 | num_spatial_layers=2,
151 | num_temporal_layers=2,
152 | num_heads=2,
153 | mlp_dim=4,
154 | atten_logit_cap=50.0,
155 | scan=scan,
156 | )
157 |
158 | @self.variant
159 | def var_fn():
160 | return enc.init_with_output(
161 | prng_key,
162 | inputs,
163 | train=train,
164 | return_intermediate=return_intermediate,
165 | frame_paddings=frame_paddings,
166 | )
167 |
168 | (embeddings, outputs), params = var_fn()
169 |
170 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 40 if scan else 72)
171 | self.assertEqual(
172 | embeddings.shape,
173 | (batch_size, num_frames * (image_size // patch_size) ** 2, dim),
174 | )
175 | if return_intermediate:
176 | self.assertEqual(
177 | outputs['spatial_features'].shape,
178 | (batch_size, num_frames * (image_size // patch_size) ** 2, dim),
179 | )
180 | else:
181 | self.assertEmpty(outputs)
182 |
183 | @chex.variants(with_jit=True)
184 | @parameterized.named_parameters(
185 | ('train', True, False),
186 | ('return_intermediate', False, True),
187 | )
188 | def test_factorized_video_classifier(
189 | self, train: bool, return_intermediate: bool
190 | ):
191 | batch_size, num_frames, image_size, patch_size, dim = 1, 4, 16, 4, 8
192 | np_inputs = np.random.normal(
193 | 0.0,
194 | 0.1,
195 | [batch_size, num_frames, image_size, image_size, 3],
196 | ).astype('float32')
197 | inputs = jnp.asarray(np_inputs)
198 |
199 | encoder_params = dict(
200 | patch_size=patch_size,
201 | pos_emb_shape=(16, 16, 16),
202 | model_dim=dim,
203 | num_spatial_layers=2,
204 | num_temporal_layers=2,
205 | num_heads=2,
206 | mlp_dim=4,
207 | atten_logit_cap=50.0,
208 | scan=True,
209 | )
210 | prng_key = jax.random.PRNGKey(seed=123)
211 | classifier = encoders.FactorizedVideoClassifier(
212 | name='classifier',
213 | encoder_params=encoder_params,
214 | num_classes=10,
215 | )
216 |
217 | @self.variant
218 | def var_fn():
219 | return classifier.init_with_output(
220 | prng_key, inputs, train=train, return_intermediate=return_intermediate
221 | )
222 |
223 | (logits, outputs), params = var_fn()
224 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 54)
225 | self.assertEqual(logits.shape, (batch_size, 10))
226 | if return_intermediate:
227 | self.assertEqual(
228 | outputs['spatial_features'].shape,
229 | (batch_size, num_frames * (image_size // patch_size) ** 2, dim),
230 | )
231 | self.assertEqual(
232 | outputs['spatiotemporal_features'].shape,
233 | (batch_size, num_frames * (image_size // patch_size) ** 2, dim),
234 | )
235 | self.assertEqual(outputs['global_embeddings'].shape, (batch_size, dim))
236 | else:
237 | self.assertEmpty(outputs)
238 |
239 | @chex.variants(with_jit=True)
240 | @parameterized.named_parameters(
241 | ('train', False, True),
242 | ('scan', True, False),
243 | ('scan_and_train', True, True),
244 | )
245 | def test_text_encoder(self, scan: bool, train: bool):
246 | batch_size, seq_len, vocab_size, dim = 1, 10, 20, 8
247 | np_inputs = np.random.randint(0, vocab_size, [batch_size, seq_len]).astype(
248 | 'int32'
249 | )
250 | inputs = jnp.asarray(np_inputs)
251 | np_paddings = np.zeros([batch_size, seq_len], dtype='float32')
252 | np_paddings[:, seq_len // 2 :] = 1
253 | paddings = jnp.asarray(np_paddings)
254 |
255 | prng_key = jax.random.PRNGKey(seed=123)
256 | enc = encoders.TextEncoder(
257 | name='enc',
258 | vocabulary_size=vocab_size,
259 | num_class_tokens=1,
260 | model_dim=dim,
261 | num_layers=2,
262 | mlp_dim=4,
263 | num_heads=2,
264 | atten_logit_cap=50.0,
265 | scan=scan,
266 | )
267 |
268 | @self.variant
269 | def var_fn():
270 | return enc.init_with_output(prng_key, inputs, paddings, train=train)
271 |
272 | outputs, params = var_fn()
273 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 20 if scan else 36)
274 | self.assertEqual(outputs.shape, (batch_size, seq_len + 1, dim))
275 |
276 | @chex.variants(with_jit=True)
277 | @parameterized.named_parameters(
278 | ('train', False, True, False),
279 | ('scan_and_train', True, True, False),
280 | ('return_intermediate', True, False, True),
281 | (
282 | 'selectively_return_intermediate',
283 | True,
284 | False,
285 | {'spatial_features', 'frame_embeddings'},
286 | ),
287 | )
288 | def test_factorized_video_clip(
289 | self, scan: bool, train: bool, return_intermediate: bool
290 | ):
291 | batch_size, num_frames, image_size, patch_size, dim = 1, 4, 16, 4, 8
292 | np_inputs = np.random.normal(
293 | 0.0,
294 | 0.1,
295 | [batch_size, num_frames, image_size, image_size, 3],
296 | ).astype('float32')
297 | inputs = jnp.asarray(np_inputs)
298 |
299 | seq_len, vocab_size = 10, 20
300 | np_text_token_ids = np.random.randint(
301 | 0, vocab_size, [batch_size, seq_len]
302 | ).astype('int32')
303 | text_token_ids = jnp.asarray(np_text_token_ids)
304 | np_text_paddings = np.zeros([batch_size, seq_len], dtype='float32')
305 | np_text_paddings[:, seq_len // 2 :] = 1
306 | text_paddings = jnp.asarray(np_text_paddings)
307 |
308 | prng_key = jax.random.PRNGKey(seed=123)
309 | net = encoders.FactorizedVideoCLIP(
310 | name='net',
311 | patch_size=patch_size,
312 | pos_emb_shape=(16, 16, 16),
313 | num_spatial_layers=2,
314 | num_temporal_layers=2,
315 | mlp_dim=4,
316 | num_auxiliary_layers=1,
317 | vocabulary_size=vocab_size,
318 | enable_causal_atten=True,
319 | num_unimodal_layers=2,
320 | norm_policy='pre',
321 | model_dim=dim,
322 | num_heads=2,
323 | atten_logit_cap=50.0,
324 | scan=scan,
325 | )
326 |
327 | @self.variant
328 | def var_fn():
329 | return net.init_with_output(
330 | prng_key,
331 | inputs=inputs,
332 | text_token_ids=text_token_ids,
333 | text_paddings=text_paddings,
334 | train=train,
335 | normalize=True,
336 | return_intermediate=return_intermediate,
337 | )
338 |
339 | (video_embeddings, text_embeddings, outputs), params = var_fn()
340 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 88 if scan else 136)
341 | self.assertEqual(video_embeddings.shape, (batch_size, dim))
342 | self.assertEqual(text_embeddings.shape, (batch_size, dim))
343 | if not return_intermediate:
344 | self.assertEmpty(outputs)
345 | else:
346 | if return_intermediate is True: # pylint: disable=g-bool-id-comparison
347 | self.assertEqual(
348 | set(outputs.keys()),
349 | {
350 | 'frame_embeddings',
351 | 'spatial_features',
352 | 'spatiotemporal_features',
353 | },
354 | )
355 | else:
356 | self.assertEqual(set(outputs.keys()), set(return_intermediate))
357 |
358 | if 'spatial_features' in outputs:
359 | self.assertEqual(
360 | outputs['spatial_features'].shape,
361 | (batch_size, num_frames * (image_size // patch_size) ** 2, dim),
362 | )
363 | if 'spatiotemporal_features' in outputs:
364 | self.assertEqual(
365 | outputs['spatiotemporal_features'].shape,
366 | (batch_size, num_frames * (image_size // patch_size) ** 2, dim),
367 | )
368 | if 'frame_embeddings' in outputs:
369 | self.assertEqual(
370 | outputs['frame_embeddings'].shape, (batch_size, num_frames, dim)
371 | )
372 |
373 |
374 | if __name__ == '__main__':
375 | absltest.main()
376 |
--------------------------------------------------------------------------------
/videoprism/models.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 VideoPrism Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Provides builders and loaders of VideoPrism checkpoints.
16 |
17 | The v1 base model takes videos with shape (batch_size, 16, 288, 288, 3) as
18 | inputs and outputs embeddings with shape (batch_size, 4096, 768), which can be
19 | reshaped into (batch_size, 16, 16, 16, 768) for spatiotemporal representations.
20 | The input videos should be normalized to the range [0.0, 1.0].
21 |
22 | Example usage:
23 | ```
24 | from videoprism import models as vp
25 |
26 | model_name = 'videoprism_public_v1_base'
27 | flax_model = vp.get_model(model_name)
28 | loaded_state = vp.load_pretrained_weights(model_name)
29 |
30 | @jax.jit
31 | def forward_fn(inputs):
32 | return flax_model.apply(loaded_state, inputs, train=False)
33 |
34 | model_inputs = ...
35 | outputs, _ = forward_fn(model_inputs)
36 | ```
37 | """
38 |
39 | from collections.abc import Callable, Mapping, Sequence
40 | import functools
41 |
42 | from flax import linen as nn
43 | import jax
44 | import jax.numpy as jnp
45 | import huggingface_hub
46 | import numpy as np
47 | from videoprism import encoders
48 | from videoprism import tokenizers
49 | from videoprism import utils
50 |
51 | K400_NUM_CLASSES: int = 400
52 | SSV2_NUM_CLASSES: int = 174
53 |
54 | TEXT_MAX_LEN: int = 64
55 | TEXT_TOKENIZERS = {
56 | 'c4_en': {
57 | 'model_path': 'gs://t5-data/vocabs/cc_en.32000/sentencepiece.model',
58 | 'vocab_size': 32_000,
59 | },
60 | }
61 |
62 | CHECKPOINTS = {
63 | # Hugging Face checkpoints (repository, filename).
64 | 'videoprism_public_v1_base': (
65 | 'google/videoprism-base-f16r288',
66 | 'flax_base_f16r288_repeated.npz',
67 | ),
68 | 'videoprism_public_v1_large': (
69 | 'google/videoprism-large-f8r288',
70 | 'flax_large_f8r288_repeated.npz',
71 | ),
72 | 'videoprism_lvt_public_v1_base': (
73 | 'google/videoprism-lvt-base-f16r288',
74 | 'flax_lvt_base_f16r288_repeated.npz',
75 | ),
76 | 'videoprism_lvt_public_v1_large': (
77 | 'google/videoprism-lvt-large-f8r288',
78 | 'flax_lvt_large_f8r288_repeated.npz',
79 | ),
80 | }
81 |
82 | CONFIGS = {
83 | 'videoprism_v1_base': dict(
84 | patch_size=18,
85 | pos_emb_shape=(16, 16, 16),
86 | model_dim=768,
87 | num_spatial_layers=12,
88 | num_temporal_layers=4,
89 | num_heads=12,
90 | mlp_dim=3072,
91 | atten_logit_cap=50.0,
92 | scan=True,
93 | ),
94 | 'videoprism_v1_large': dict(
95 | patch_size=18,
96 | pos_emb_shape=(8, 16, 16),
97 | model_dim=1024,
98 | num_spatial_layers=24,
99 | num_temporal_layers=4,
100 | num_heads=16,
101 | mlp_dim=4096,
102 | atten_logit_cap=50.0,
103 | scan=True,
104 | ),
105 | 'videoprism_v1_giant': dict(
106 | patch_size=18,
107 | pos_emb_shape=(8, 16, 16),
108 | model_dim=1408,
109 | num_spatial_layers=40,
110 | num_temporal_layers=4,
111 | num_heads=16,
112 | mlp_dim=6144,
113 | atten_logit_cap=50.0,
114 | scan=True,
115 | ),
116 | 'videoprism_lvt_v1_base': dict(
117 | patch_size=18,
118 | pos_emb_shape=(16, 16, 16),
119 | num_spatial_layers=12,
120 | num_temporal_layers=4,
121 | mlp_dim=3072,
122 | num_auxiliary_layers=2,
123 | enable_causal_atten=True,
124 | num_unimodal_layers=12,
125 | norm_policy='pre',
126 | model_dim=768,
127 | num_heads=12,
128 | atten_logit_cap=50.0,
129 | scan=True,
130 | ),
131 | 'videoprism_lvt_v1_large': dict(
132 | patch_size=18,
133 | pos_emb_shape=(8, 16, 16),
134 | num_spatial_layers=24,
135 | num_temporal_layers=4,
136 | mlp_dim=4096,
137 | num_auxiliary_layers=2,
138 | enable_causal_atten=True,
139 | num_unimodal_layers=12,
140 | norm_policy='pre',
141 | model_dim=1024,
142 | num_heads=16,
143 | atten_logit_cap=50.0,
144 | scan=True,
145 | ),
146 | 'videoprism_lvt_v1_giant': dict(
147 | patch_size=18,
148 | pos_emb_shape=(8, 16, 16),
149 | num_spatial_layers=40,
150 | num_temporal_layers=4,
151 | mlp_dim=6144,
152 | num_auxiliary_layers=2,
153 | enable_causal_atten=True,
154 | num_unimodal_layers=16,
155 | norm_policy='primer_hybrid',
156 | model_dim=1408,
157 | num_heads=16,
158 | atten_logit_cap=50.0,
159 | scan=True,
160 | ),
161 | }
162 |
163 |
164 | def videoprism_v1_base():
165 | """Builds VideoPrism v1 base model."""
166 | return encoders.FactorizedEncoder(**CONFIGS['videoprism_v1_base'])
167 |
168 |
169 | def videoprism_v1_large():
170 | """Builds VideoPrism v1 large model."""
171 | return encoders.FactorizedEncoder(**CONFIGS['videoprism_v1_large'])
172 |
173 |
174 | def videoprism_v1_giant():
175 | """Builds VideoPrism v1 giant model."""
176 | return encoders.FactorizedEncoder(**CONFIGS['videoprism_v1_giant'])
177 |
178 |
179 | def videoprism_lvt_v1_base(text_tokenizer: str = 'c4_en'):
180 | """Builds VideoPrism LvT v1 base model."""
181 |   config = dict(CONFIGS['videoprism_lvt_v1_base'])  # Avoid mutating CONFIGS.
182 | config['vocabulary_size'] = TEXT_TOKENIZERS[text_tokenizer]['vocab_size']
183 | return encoders.FactorizedVideoCLIP(**config)
184 |
185 |
186 | def videoprism_lvt_v1_large(text_tokenizer: str = 'c4_en'):
187 | """Builds VideoPrism LvT v1 large model."""
188 |   config = dict(CONFIGS['videoprism_lvt_v1_large'])  # Avoid mutating CONFIGS.
189 | config['vocabulary_size'] = TEXT_TOKENIZERS[text_tokenizer]['vocab_size']
190 | return encoders.FactorizedVideoCLIP(**config)
191 |
192 |
193 | def videoprism_lvt_v1_giant(text_tokenizer: str = 'c4_en'):
194 | """Builds VideoPrism LvT v1 giant model."""
195 |   config = dict(CONFIGS['videoprism_lvt_v1_giant'])  # Avoid mutating CONFIGS.
196 | config['vocabulary_size'] = TEXT_TOKENIZERS[text_tokenizer]['vocab_size']
197 | return encoders.FactorizedVideoCLIP(**config)
198 |
199 |
200 | def videoprism_vc_v1_base(num_classes: int):
201 | """Builds VideoPrism Classification v1 base model."""
202 | encoder_params = CONFIGS['videoprism_v1_base']
203 | return encoders.FactorizedVideoClassifier(
204 | encoder_params=encoder_params, num_classes=num_classes
205 | )
206 |
207 |
208 | def videoprism_vc_v1_large(num_classes: int):
209 | """Builds VideoPrism Classification v1 large model."""
210 | encoder_params = CONFIGS['videoprism_v1_large']
211 | return encoders.FactorizedVideoClassifier(
212 | encoder_params=encoder_params, num_classes=num_classes
213 | )
214 |
215 |
216 | def videoprism_vc_v1_giant(num_classes: int):
217 | """Builds VideoPrism Classification v1 giant model."""
218 | encoder_params = CONFIGS['videoprism_v1_giant']
219 | return encoders.FactorizedVideoClassifier(
220 | encoder_params=encoder_params, num_classes=num_classes
221 | )
222 |
223 |
224 | MODELS = {
225 | 'videoprism_public_v1_base': videoprism_v1_base,
226 | 'videoprism_public_v1_large': videoprism_v1_large,
227 | 'videoprism_lvt_public_v1_base': functools.partial(
228 | videoprism_lvt_v1_base, text_tokenizer='c4_en'
229 | ),
230 | 'videoprism_lvt_public_v1_large': functools.partial(
231 | videoprism_lvt_v1_large, text_tokenizer='c4_en'
232 | ),
233 | }
234 |
235 |
236 | def _get_model_name_by_hf_model_id(model_id: str) -> str | None:
237 | """Returns model name for the given Hugging Face model ID.
238 |
239 | Hugging Face model ID is typically the name of the repository, e.g.,
240 | `google/videoprism-base-f16r288`.
241 |
242 | Args:
243 | model_id: A string for the Hugging Face model ID.
244 |
245 | Returns:
246 | The model name for the given Hugging Face model ID or None if not found.
247 | """
248 | for model_name, value in CHECKPOINTS.items():
249 | if isinstance(value, tuple) and value[0] == model_id:
250 | return model_name
251 |
252 | return None
253 |
254 |
255 | def has_model(
256 | model_name: str,
257 | models: Mapping[str, Callable[[], nn.Module]] | None = None,
258 | ) -> bool:
259 | """Returns whether the model is available."""
260 | models = models or MODELS
261 | if model_name.startswith('google/'):
262 | # Handle Hugging Face model ID.
263 | model_name = _get_model_name_by_hf_model_id(model_name)
264 |
265 | return model_name is not None and model_name in models
266 |
267 |
268 | def get_model(
269 | model_name: str | None,
270 | model_fn: Callable[[], nn.Module] | None = None,
271 | models: Mapping[str, Callable[[], nn.Module]] | None = None,
272 | fprop_dtype: jax.typing.DTypeLike | None = None,
273 | ):
274 | """Returns VideoPrism model with the given name.
275 |
276 | Args:
277 | model_name: A string for the model name or Hugging Face model ID.
278 | model_fn: Optional function that returns the model.
279 | models: Mapping from model name to model creation function. Used with
280 | `model_name`. If None, use the default `MODELS`.
281 |     fprop_dtype: Optional dtype for the model's forward pass.
282 | Returns:
283 | A Flax VideoPrism model.
284 | """
285 |
286 | if model_fn is None:
287 | assert model_name is not None
288 | models = models or MODELS
289 | if model_name.startswith('google/'):
290 | # Handle Hugging Face model ID.
291 |       model_name = _get_model_name_by_hf_model_id(hf_model_id := model_name)
292 |       if model_name is None:
293 |         raise ValueError(f'Unknown Hugging Face model ID `{hf_model_id}`.')
294 |
295 | if model_name not in models:
296 | raise ValueError(f'Model `{model_name}` not found.')
297 |
298 | model_fn = models[model_name]
299 |
300 | model = model_fn()
301 | if fprop_dtype is not None:
302 | model.fprop_dtype = fprop_dtype
303 | return model
304 |
305 |
306 | def load_pretrained_weights(
307 | model_name: str | None,
308 | checkpoint_path: str | None = None,
309 | checkpoints: Mapping[str, str | tuple[str, str]] | None = None,
310 | ):
311 | """Loads pretrained model weights.
312 |
313 | Args:
314 | model_name: A string for the model name or Hugging Face model ID.
315 | checkpoint_path: Optional path of the model checkpoint.
316 | checkpoints: Mapping from model name to checkpoint path. Used with
317 | `model_name`. If None, use the default `CHECKPOINTS`.
318 |
319 | Returns:
320 | Restored Flax model weights.
321 | """
322 | checkpoints = checkpoints or CHECKPOINTS
323 |
324 | if checkpoint_path is None:
325 | assert model_name is not None
326 | if model_name.startswith('google/'):
327 | # Handle Hugging Face model ID.
328 | model_name = _get_model_name_by_hf_model_id(model_name)
329 |
330 | repo_id, filename = checkpoints[model_name]
331 | checkpoint_path = huggingface_hub.hf_hub_download(
332 | repo_id=repo_id, filename=filename
333 | )
334 |
335 | variables = utils.load_checkpoint(checkpoint_path)
336 | return jax.tree_util.tree_map(jnp.asarray, variables)
337 |
338 |
339 | def load_text_tokenizer(name: str) -> tokenizers.Tokenizer:
340 | """Loads a text tokenizer by name.
341 |
342 | Args:
343 | name: A string for the text tokenizer model name.
344 |
345 | Returns:
346 | A text tokenizer.
347 | """
348 | if name not in TEXT_TOKENIZERS:
349 | raise ValueError(f'Text tokenizer `{name}` not found.')
350 |
351 | model_path = TEXT_TOKENIZERS[name]['model_path']
352 | return tokenizers.SentencePieceTokenizer(model_path)
353 |
354 |
355 | def tokenize_texts(
356 | tokenizer: tokenizers.Tokenizer,
357 | inputs: Sequence[str],
358 | max_length: int = TEXT_MAX_LEN,
359 | add_bos: bool | None = None,
360 | canonicalize: bool = True,
361 | ) -> tuple[np.ndarray, np.ndarray]:
362 | """Tokenizes a batch of texts.
363 |
364 | Args:
365 | tokenizer: The tokenizer to use.
366 | inputs: The list of texts to tokenize.
367 | max_length: The maximum length of the tokenized texts.
368 | add_bos: Whether to add a beginning-of-sentence token. If None, the
369 | beginning-of-sentence token will be added if the tokenizer's bos_token is
370 | a non-negative integer.
371 | canonicalize: Whether to canonicalize the texts before tokenization.
372 |
373 | Returns:
374 | A tuple of two numpy arrays containing the padded token ids and the
375 |     corresponding paddings, where 1 denotes a padding token.
376 | """
377 |
378 | if canonicalize:
379 | inputs = [utils.canonicalize_text(text) for text in inputs]
380 |
381 | batch_ids, batch_paddings = [], []
382 | if add_bos is None:
383 | add_bos = tokenizer.bos_token >= 0
384 |
385 | for ids in tokenizer.to_int(inputs, bos=add_bos, eos=False):
386 | ids_seq_len = len(ids)
387 | if ids_seq_len > max_length:
388 | ids = ids[:max_length]
389 |
390 | ids = np.asarray(ids, dtype=np.int32)
391 | paddings = np.zeros_like(ids, dtype=np.float32)
392 |
393 | if ids_seq_len < max_length:
394 | ids = np.pad(
395 | ids, (0, max_length - ids_seq_len), 'constant', constant_values=0
396 | )
397 | paddings = np.pad(
398 | paddings,
399 | (0, max_length - ids_seq_len),
400 | 'constant',
401 | constant_values=1.0,
402 | )
403 |
404 | batch_ids.append(ids)
405 | batch_paddings.append(paddings)
406 |
407 | return np.asarray(batch_ids), np.asarray(batch_paddings)
408 |
--------------------------------------------------------------------------------
/videoprism/encoders.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 VideoPrism Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Modules for video encoders."""
16 |
17 | from collections.abc import Collection, Sequence
18 | import dataclasses
19 | import math
20 | from typing import Any
21 |
22 | import einops
23 | import einshape
24 | from flax import linen as nn
25 | import jax
26 | from jax import numpy as jnp
27 | import numpy as np
28 | from videoprism import layers
29 |
30 | Array = jax.Array
31 | Variables = nn.module.VariableDict
32 |
33 | default_kernel_init = layers.default_kernel_init
34 |
35 |
36 | def _contains(collection: Collection[str] | bool, key: str) -> bool:
37 | """Checks if a collection contains a key.
38 |
39 | Args:
40 | collection: A collection of strings or a boolean value.
41 | key: A string key to check.
42 |
43 | Returns:
44 | True if the collection contains the key, or if the collection is a True
45 | boolean. False otherwise.
46 | """
47 | return collection if isinstance(collection, bool) else key in collection
48 |
49 |
50 | def _l2_normalize(
51 | x: Array, axis: int | Sequence[int] = -1, epsilon: float = 1e-12
52 | ) -> Array:
53 | """L2-normalizes a jax.Array along certain dimension.
54 |
55 | Args:
56 | x: An input jax.Array.
57 | axis: An integer or a sequence of integers for the axis to normalize.
58 | epsilon: A small constant for numerical stability.
59 |
60 | Returns:
61 | Normalized jax.Array.
62 | """
63 | x_dtype = x.dtype
64 | # Always convert embed to float32 for all precisions.
65 | x = x.astype(jnp.float32)
66 | norm = jnp.sqrt(jnp.sum(x * x, axis=axis, keepdims=True) + epsilon)
67 | return (x / norm).astype(x_dtype)
68 |
69 |
70 | def _image_to_patch(inputs: Array, patch_size: int) -> Array:
71 | """Converts an image to patches.
72 |
73 | Args:
74 |     inputs: A jax.Array of shape [B, H, W, C].
75 | patch_size: An integer for dimension of a square patch.
76 |
77 | Returns:
78 | batched_patches: [B, (H * W / P^2), P^2 * C].
79 | """
80 | if len(inputs.shape) < 4:
81 | raise ValueError(
82 |         f'Image should be formatted as 4D [B, H, W, C]; got shape {inputs.shape}.'
83 | )
84 | height, width, channels = inputs.shape[-3:]
85 |
86 | if height % patch_size != 0 or width % patch_size != 0:
87 | raise ValueError(
88 | f'Image height ({height}) and width ({width}) should be multiples '
89 | f'of patch_size ({patch_size}).'
90 | )
91 |
92 | row_blocks = height // patch_size
93 | column_blocks = width // patch_size
94 |
95 | patches = einops.rearrange(
96 | inputs,
97 | '... (m p)(n q) c->...(m n)(p q c)',
98 | m=row_blocks,
99 | n=column_blocks,
100 | p=patch_size,
101 | q=patch_size,
102 | c=channels,
103 | )
104 | return patches
105 |
106 |
107 | def _interpolate_emb_1d(emb: Array, target_emb_length: int) -> Array:
108 | """Interpolates a 1D positional embedding to a new shape.
109 |
110 | Args:
111 | emb: jax.Array, (1, N, D), flattened 1D positional embedding.
112 | target_emb_length: length of the target embedding.
113 |
114 | Returns:
115 | Flattened, interpolated embedding of shape (1, target_emb_length, D)
116 | """
117 |
118 | if len(emb.shape) > 3 or emb.shape[0] != 1:
119 | raise ValueError('The shape of the embedding should be (1, N, D)')
120 |
121 | emb_dim = emb.shape[-1]
122 | emb = jnp.squeeze(emb, axis=0)
123 |
124 | target_emb = jax.image.resize(
125 | emb, (target_emb_length, emb_dim), method='bilinear'
126 | )
127 | target_emb = jnp.reshape(target_emb, (1, target_emb_length, emb_dim))
128 | return target_emb
129 |
130 |
131 | def _interpolate_emb_2d(
132 | emb: Array,
133 | source_emb_shape: tuple[int, int],
134 | target_emb_shape: tuple[int, int],
135 | ) -> Array:
136 | """Interpolates a 2D positional embedding to a new shape.
137 |
138 | Args:
139 | emb: A jax.Array of shape (1, H1xW1, D) for flattened 2D positional
140 | embedding.
141 | source_emb_shape: Tuple, (H1, W1), height and width of the source embedding.
142 | target_emb_shape: Tuple, (H2, W2), height and width of the target embedding.
143 |
144 | Returns:
145 | Flattened, interpolated embedding of shape (1, H2xW2, D)
146 | """
147 |
148 | if len(emb.shape) > 3 or emb.shape[0] != 1:
149 | raise ValueError('The shape of the embedding should be (1, H * W, D)')
150 |
151 | if emb.shape[-2] != source_emb_shape[0] * source_emb_shape[1]:
152 | raise ValueError('The shape of the embedding does NOT match input specs.')
153 |
154 | emb_dim = emb.shape[-1]
155 | emb = jnp.reshape(emb, (source_emb_shape[0], source_emb_shape[1], emb_dim))
156 |
157 | target_emb = jax.image.resize(
158 | emb,
159 | (target_emb_shape[0], target_emb_shape[1], emb_dim),
160 | method='bilinear',
161 | )
162 | target_emb = jnp.reshape(
163 | target_emb, (1, target_emb_shape[0] * target_emb_shape[1], emb_dim)
164 | )
165 | return target_emb
166 |
167 |
168 | class Embedding(layers.Module):
169 | """A simple embedding layer that performs embedding lookups from ids.
170 |
171 | Attributes:
172 | num_classes: Number of tokens in the vocabulary.
173 | input_dim: Depth of the embedding output. This is called `input_dim` as
174 | opposed to the more appropriate `embedding_dim` to be compatible with
175 | other embedding layers defined in this file.
176 | lookup_style: Style of lookup, one of index or matmul.
177 | scale_sqrt_depth: If set to True, activations are scaled with
178 |       sqrt(embedding_dim) in embedding lookup.
179 | set_nan_for_oob_id: If set to True, embeddings corresponding to
180 | out-of-boundaries ids will be set to NaN.
181 | """
182 |
183 | num_classes: int = 0
184 | input_dim: int = 0
185 | lookup_style: str = 'index'
186 | scale_sqrt_depth: bool = False
187 | set_nan_for_oob_id: bool = False
188 |
189 | @nn.compact
190 | def __call__(self, ids: Array) -> Array:
191 | """Generates a jax.Array of embedding lookup result.
192 |
193 | Args:
194 | ids: Indexes of shape [...] for embedding lookup.
195 |
196 | Returns:
197 | A jax.Array of shape [..., input_dim].
198 | """
199 | emb_var = self._cast_to_fprop_dtype(
200 | self.param(
201 | 'emb_var',
202 | nn.initializers.normal(stddev=1.0 / math.sqrt(self.input_dim)),
203 | [self.num_classes, self.input_dim],
204 | self.dtype,
205 | )
206 | )
207 | if self.lookup_style == 'index':
208 | embs = jnp.asarray(emb_var)[(ids,)]
209 | elif self.lookup_style == 'matmul':
210 | one_hot_ids = jax.nn.one_hot(
211 | ids, self.num_classes, dtype=self.fprop_dtype
212 | )
213 | embs = jnp.einsum('...y,yz->...z', one_hot_ids, emb_var)
214 | else:
215 | raise ValueError(f'Unknown lookup style: `{self.lookup_style}`.')
216 |
217 | # Map out-of-boundary ids to NaN.
218 | if self.set_nan_for_oob_id:
219 | embs = jnp.where(ids[..., jnp.newaxis] < self.num_classes, embs, jnp.nan)
220 |
221 | if self.scale_sqrt_depth:
222 | embs *= self.input_dim**0.5
223 |
224 | return embs
225 |
226 |
227 | class PositionalEmbedding(layers.Module):
228 | """Generates position embedding for a given 1-d sequence.
229 |
230 | Attributes:
231 | embedding_dim: Dimension of the embedding to be generated.
232 | min_timescale: Start of the geometric index.
233 | max_timescale: End of the geometric index.
234 | """
235 |
236 | embedding_dim: int = 0
237 | min_timescale: int = 1
238 | max_timescale: int = 10_000
239 |
240 | def __call__(self, seq_length: int) -> Array:
241 | """Generates a jax.Array of embedding lookup result.
242 |
243 | Args:
244 | seq_length: Sequence length of the embeddings to be generated.
245 |
246 | Returns:
247 | A jax.Array of shape [1, seq_length, embedding_dim].
248 | """
249 | position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :]
250 | num_timescales = self.embedding_dim // 2
251 | log_timescale_increment = math.log(
252 | float(self.max_timescale) / float(self.min_timescale)
253 | ) / jnp.maximum(jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1)
254 | inv_timescales = self.min_timescale * jnp.exp(
255 | jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment
256 | )
257 | scaled_time = (
258 | position[:, :, jnp.newaxis]
259 | * inv_timescales[jnp.newaxis, jnp.newaxis, :]
260 | )
261 | embs = jnp.concatenate(
262 | [jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1
263 | ).astype(self.fprop_dtype)
264 | # Force usage of `np` to compute static values at trace time.
265 | embs = jnp.pad(embs, [[0, 0], [0, 0], [0, np.mod(self.embedding_dim, 2)]])
266 | return embs
267 |
268 |
269 | class TrainablePositionalEmbedding(layers.Module):
270 | """Generates trainable position embedding for a given 1-d sequence.
271 |
272 | Attributes:
273 | embedding_dim: Dimension of the embedding to be generated.
274 | max_seq_length: Max sequence length.
275 | lookup_style: Style of lookup, one of index or matmul.
276 | """
277 |
278 | embedding_dim: int = 0
279 | max_seq_length: int = 10_240
280 | lookup_style: str = 'matmul'
281 |
282 | @nn.compact
283 | def __call__(self, seq_length: int) -> Array:
284 | """Generates a jax.Array of embedding lookup result.
285 |
286 | Args:
287 | seq_length: Sequence length of the embeddings to be generated.
288 |
289 | Returns:
290 | A jax.Array of shape [1, seq_length, embedding_dim].
291 | """
292 | position = jnp.arange(seq_length, dtype=jnp.int32)[jnp.newaxis, :]
293 | pos_emb_var = self._cast_to_fprop_dtype(
294 | self.param(
295 | 'emb_var',
296 | default_kernel_init,
297 | [self.max_seq_length, self.embedding_dim],
298 | self.dtype,
299 | )
300 | )
301 | pos_emb_var = jax.lax.slice_in_dim(pos_emb_var, 0, seq_length, axis=0)
302 | if self.lookup_style == 'matmul':
303 | one_hot_ids = jax.nn.one_hot(position, seq_length, dtype=self.fprop_dtype)
304 | embs = jnp.einsum('...y,yz->...z', one_hot_ids, pos_emb_var)
305 | else:
306 | raise ValueError(f'Unknown lookup style: `{self.lookup_style}`.')
307 | return embs
308 |
309 |
310 | class VisionTransformer(layers.Module):
311 | """Vision transformer model.
312 |
313 | This class follows a minimalistic design pattern. Users need to configure the
314 | templates for the submodules themselves; this increases the generalizability
315 | of this class.
316 |
317 | Attributes:
318 | num_tfm_layers: Number of layers in this model.
319 | mlp_dim: The hidden layer dimension of FFN in Transformer layers.
320 | num_heads: Number of attention heads.
321 | xformer_has_bias: Whether to use bias.
322 |     xformer_dropout_prob: Probability of applying dropout at various places.
323 | xformer_atten_dropout_prob: Probability at which we apply dropout to the
324 | attention weights.
325 | xformer_residual_dropout_prob: Probability at which we apply dropout to the
326 |       residual layers, such that residual(x, y) = (x + dropout(y)).
327 | xformer_relu_dropout_prob: Probability at which we apply dropout to the FFN
328 | layers.
329 | atten_logit_cap: Cap the absolute values of logits by tanh. Enabled when a
330 | positive value is specified. May not be supported by a subclass.
331 | norm_policy: Policy for applying normalization wrt. transformations. Options
332 | are: (1) "pre", applied before transformation. (2) "primer_hybrid",
333 | applied before and after transformation. (3) "post", applied after
334 | transformation. (4) "post_skip", applied after the skip connection.
335 |     scan: Whether to use `nn.remat` and `nn.scan`.
336 | """
337 |
338 | num_tfm_layers: int = 12
339 | mlp_dim: int = 3072
340 | num_heads: int = 12
341 | xformer_has_bias: bool = True
342 | xformer_dropout_prob: float = 0.0
343 | xformer_atten_dropout_prob: float | None = None
344 | xformer_residual_dropout_prob: float | None = None
345 | xformer_relu_dropout_prob: float | None = None
346 | atten_logit_cap: float = 0.0
347 | norm_policy: str = 'pre'
348 | scan: bool = False
349 |
350 | @nn.compact
351 | def __call__(
352 | self, inputs: Array, paddings: Array | None = None, train: bool = False
353 | ) -> Array:
354 | """Applies the ViT model to the inputs.
355 |
356 | Args:
357 |     inputs: Input tensor of shape [B, N, D] containing sequences of embeddings
358 |       or patches.
359 |     paddings: Optional [B, N] padding field for inputs of shape [B, N, D],
360 |       where 1 denotes a padded position.
361 | train: If the model is in the train mode.
362 |
363 | Returns:
364 | Output tensor of shape [B, N, D].
365 | """
366 | features = inputs
367 | if paddings is None:
368 | paddings = jnp.zeros(features.shape[:-1], dtype=features.dtype)
369 | features = layers.StackedTransformer(
370 | name='transformers_stack',
371 | num_layers=self.num_tfm_layers,
372 | hidden_dim=self.mlp_dim,
373 | num_heads=self.num_heads,
374 | dropout_prob=self.xformer_dropout_prob,
375 | atten_dropout_prob=self.xformer_atten_dropout_prob,
376 | residual_dropout_prob=self.xformer_residual_dropout_prob,
377 | relu_dropout_prob=self.xformer_relu_dropout_prob,
378 | use_bias=self.xformer_has_bias,
379 | atten_logit_cap=self.atten_logit_cap,
380 | norm_policy=self.norm_policy,
381 | internal_enable_per_dim_scale=False,
382 | activation_fn=layers.gelu,
383 | enable_causal_atten=False,
384 | scan=self.scan,
385 | dtype=self.dtype,
386 | fprop_dtype=self.fprop_dtype,
387 | )(features, paddings, train=train)
388 | return features
389 |
390 |
391 | class FactorizedEncoder(layers.Module):
392 | """Factorized encoder from the paper `ViViT: A Video Vision Transformer`.
393 |
394 | This is an implementation of model-2 in the paper. It applies ViT model for
395 | video data based on the factorized space-time encoder.
396 |
397 | Reference: https://arxiv.org/abs/2103.15691
398 | """
399 |
400 | patch_size: int = 18
401 | pos_emb_shape: tuple[int, int, int] = (16, 16, 16)
402 | model_dim: int = 768
403 | num_spatial_layers: int = 12
404 | num_temporal_layers: int = 4
405 | num_heads: int = 12
406 | mlp_dim: int = 3072
407 | atten_logit_cap: float = 0.0
408 | norm_policy: str = 'pre'
409 | scan: bool = False
410 |
411 | def __call__(
412 | self,
413 | inputs: Array,
414 | train: bool = False,
415 | return_intermediate: bool | Collection[str] = False,
416 | frame_paddings: Array | None = None,
417 | ) -> tuple[Array, dict[str, Array]]:
418 | """Computes predictions for batched inputs.
419 |
420 | Args:
421 | inputs: Input image tensor of shape [B, T, H, W, 3] (H == W).
422 | train: If the model is in the train mode.
423 | return_intermediate: A boolean for whether all intermediate features are
424 | returned, or a container of intermediate feature names to return.
425 | frame_paddings: Optional binary tensor of shape [B, T] indicating padding.
426 |         1 denotes a padded frame.
427 |
428 | Returns:
429 | embeddings: Output tensor for video embeddings of shape [B, T * N, D].
430 | outputs: A dictionary of additional outputs, including `spatial_features`
431 | (shape = [B, T * N, D]). Empty if `return_intermediate` is False or does
432 | not contain 'spatial_features'.
433 | """
434 | b, t, h, w, c = inputs.shape
435 | assert h == w
436 | reshaped_inputs = inputs.reshape(b * t, h, w, c) # (B * T, H, W, C).
437 |
438 | # Tokenization.
439 | patches = _image_to_patch(reshaped_inputs, self.patch_size)
440 | patches_paddings = None
441 | if frame_paddings is not None:
442 | assert frame_paddings.shape == (b, t)
443 | reshaped_frame_paddings = frame_paddings.reshape(b * t) # (B * T,).
444 | num_patches = patches.shape[1]
445 | patches_paddings = jnp.repeat(
446 | reshaped_frame_paddings[:, jnp.newaxis], num_patches, axis=-1
447 | ) # (B * T, num_patches).
448 |
449 | embeddings, outputs = self.encode_with_patches(
450 | patches=patches,
451 | image_shape=(t, h, w),
452 | train=train,
453 | return_intermediate=return_intermediate,
454 | patches_paddings=patches_paddings,
455 | )
456 | return embeddings, outputs
457 |
458 | @nn.compact
459 | def encode_with_patches(
460 | self,
461 | patches: Array,
462 | image_shape: tuple[int, int, int],
463 | train: bool = False,
464 | return_intermediate: bool | Collection[str] = False,
465 | patches_paddings: Array | None = None,
466 | ) -> tuple[Array, dict[str, Array]]:
467 | """Computes predictions for patches.
468 |
469 | Args:
470 | patches: Input patches tensor of shape [B * T, (H * W / P^2), P^2 * C].
471 | image_shape: Original image shape (T, H, W).
472 | train: If the model is in the train mode.
473 | return_intermediate: A boolean for whether all intermediate features are
474 | returned, or a collection of intermediate feature names to return.
475 | patches_paddings: Optional binary tensor of shape [B * T, (H * W / P^2)]
476 | indicating padding. 1 denotes padded patch.
477 |
478 | Returns:
479 | embeddings: Output tensor for video embedding sequence of shape [B, T * N,
480 | D].
481 | outputs: A dictionary of additional outputs, including `spatial_features`
482 | of shape [B, T * N, D]. Empty if `return_intermediate` is False or does
483 | not contain 'spatial_features'.
484 | """
485 | t, h, w = image_shape
486 | b = patches.shape[0] // t
487 |
488 | patches = layers.FeedForward( # (B * T, N, D).
489 | name='patch_projection',
490 | output_dim=self.model_dim,
491 | activation_fn=layers.identity,
492 | dtype=self.dtype,
493 | fprop_dtype=self.fprop_dtype,
494 | )(patches)
495 |
496 | # Add spatial positional encoding.
497 | spatial_pos_emb_shape = self.pos_emb_shape[-2:]
498 | spatial_seq_length = np.prod(spatial_pos_emb_shape)
499 | spatial_pos_emb = TrainablePositionalEmbedding(
500 | name='spatial_pos_emb',
501 | embedding_dim=self.model_dim,
502 | max_seq_length=spatial_seq_length,
503 | dtype=self.dtype,
504 | fprop_dtype=self.fprop_dtype,
505 | )(seq_length=spatial_seq_length)
506 | num_row_patches = h // self.patch_size
507 | num_col_patches = w // self.patch_size
508 | if spatial_pos_emb_shape != (num_row_patches, num_col_patches):
509 | spatial_pos_emb = _interpolate_emb_2d(
510 | spatial_pos_emb,
511 | spatial_pos_emb_shape,
512 | (num_row_patches, num_col_patches),
513 | )
514 | patches += spatial_pos_emb # (B * T, N, D).
515 |
516 | # Get features from the spatial encoder.
517 | features = VisionTransformer( # (B * T, N, D).
518 | name='spatial_encoder',
519 | num_tfm_layers=self.num_spatial_layers,
520 | mlp_dim=self.mlp_dim,
521 | num_heads=self.num_heads,
522 | atten_logit_cap=self.atten_logit_cap,
523 | norm_policy=self.norm_policy,
524 | scan=self.scan,
525 | dtype=self.dtype,
526 | fprop_dtype=self.fprop_dtype,
527 | )(patches, train=train, paddings=patches_paddings)
528 | features = layers.LayerNorm(
529 | name='spatial_ln', dtype=self.dtype, fprop_dtype=self.fprop_dtype
530 | )(features)
531 | spatial_features = features
532 |
533 | # Instead of mean pooling, we keep the spatial tokens.
534 | # Shape = (B * N, T, D).
535 | features = einshape.jax_einshape('(bt)nd->(bn)td', features, t=t)
536 | temporal_paddings = None
537 | if patches_paddings is not None:
538 | temporal_paddings = einshape.jax_einshape(
539 | '(bt)n->(bn)t', patches_paddings, t=t
540 | ) # (B * N, T).
541 |
542 | # Add temporal positional encoding.
543 | temporal_seq_length = self.pos_emb_shape[0]
544 | temporal_pos_emb = TrainablePositionalEmbedding(
545 | name='temporal_pos_emb',
546 | embedding_dim=self.model_dim,
547 | max_seq_length=temporal_seq_length,
548 | dtype=self.dtype,
549 | fprop_dtype=self.fprop_dtype,
550 | )(seq_length=temporal_seq_length)
551 | if temporal_seq_length != t:
552 | temporal_pos_emb = _interpolate_emb_1d(temporal_pos_emb, t)
553 | features += temporal_pos_emb
554 |
555 | # Get features from the temporal encoder.
556 | features = VisionTransformer(
557 | name='temporal_encoder',
558 | num_tfm_layers=self.num_temporal_layers,
559 | mlp_dim=self.mlp_dim,
560 | num_heads=self.num_heads,
561 | atten_logit_cap=self.atten_logit_cap,
562 | norm_policy=self.norm_policy,
563 | scan=self.scan,
564 | dtype=self.dtype,
565 | fprop_dtype=self.fprop_dtype,
566 | )(features, train=train, paddings=temporal_paddings)
567 | features = layers.LayerNorm(
568 | name='temporal_ln', dtype=self.dtype, fprop_dtype=self.fprop_dtype
569 | )(features)
570 | features = einshape.jax_einshape( # (B, T * N, D).
571 | '(bn)td->b(tn)d', features, b=b
572 | )
573 |
574 | embeddings, outputs = features, {}
575 | if _contains(return_intermediate, 'spatial_features'):
576 | outputs['spatial_features'] = einshape.jax_einshape(
577 | '(bt)nd->b(tn)d', spatial_features, t=t
578 | )
579 |
580 | return embeddings, outputs
581 |
582 |
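# Shape-walkthrough sketch (illustrative only, deliberately tiny
# hyperparameters): with P = patch_size, a [B, T, H, W, 3] clip becomes
# T * (H // P) * (W // P) tokens of width model_dim. `pos_emb_shape` matches
# the token grid here, so no positional-embedding interpolation is triggered.
def _example_factorized_encoder_usage():
  encoder = FactorizedEncoder(
      patch_size=4, pos_emb_shape=(4, 4, 4), model_dim=32,
      num_spatial_layers=1, num_temporal_layers=1, num_heads=2, mlp_dim=64)
  videos = jnp.zeros([1, 4, 16, 16, 3])  # [B, T, H, W, 3].
  params = encoder.init(jax.random.PRNGKey(0), videos)
  embeddings, _ = encoder.apply(params, videos)  # [1, 4 * 16, 32].
  return embeddings
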
583 | class FactorizedVideoClassifier(layers.Module):
584 | """Video classifier with `FactorizedEncoder` backbone.
585 |
586 | Attributes:
587 | encoder_params: A dictionary of parameters for `FactorizedEncoder`.
588 | num_classes: Number of output classes.
589 | """
590 |
591 | encoder_params: dict[str, Any] = dataclasses.field(default_factory=dict)
592 | num_classes: int = 0
593 |
594 | @nn.compact
595 | def __call__(
596 | self,
597 | inputs: Array,
598 | train: bool = False,
599 | return_intermediate: bool | Collection[str] = False,
600 | frame_paddings: Array | None = None,
601 | ):
602 | """Applies video classifier to inputs.
603 |
604 | Args:
605 | inputs: Input tensor of shape [B, T, H, W, 3].
606 | train: Whether the model is in the training mode.
607 | return_intermediate: A boolean for whether all intermediate features are
608 | returned, or a collection of intermediate feature names to return.
609 | frame_paddings: Optional binary tensor of shape [B, T] indicating padding.
610 |         1 denotes a padded frame.
611 |
612 | Returns:
613 | logits: Output tensor of shape [B, num_classes].
614 | outputs: A dictionary of additional outputs, including `spatial_features`
615 | of shape [B, T * N, D], `spatiotemporal_features` of shape [B, T * N,
616 | D], and `global_embeddings` of shape [B, D]. Empty if
617 | `return_intermediate` is False.
618 | """
619 | features, outputs = FactorizedEncoder(
620 | name='encoder',
621 | dtype=self.dtype,
622 | fprop_dtype=self.fprop_dtype,
623 | **self.encoder_params,
624 | )(
625 | inputs,
626 | train=train,
627 | return_intermediate=return_intermediate,
628 | frame_paddings=frame_paddings,
629 | )
630 | if _contains(return_intermediate, 'spatiotemporal_features'):
631 | outputs['spatiotemporal_features'] = features
632 |
633 | embeddings = layers.AttenTokenPoolingLayer(
634 | name='atten_pooler',
635 | num_heads=self.encoder_params['num_heads'],
636 | hidden_dim=self.encoder_params['model_dim'],
637 | num_queries=1,
638 | dtype=self.dtype,
639 | fprop_dtype=self.fprop_dtype,
640 | )(features, paddings=None, train=train)
641 | embeddings = jnp.squeeze(embeddings, axis=-2)
642 |
643 | if _contains(return_intermediate, 'global_embeddings'):
644 | outputs['global_embeddings'] = embeddings
645 |
646 | logits = layers.FeedForward(
647 | name='projection',
648 | output_dim=self.num_classes,
649 | activation_fn=layers.identity,
650 | dtype=self.dtype,
651 | fprop_dtype=self.fprop_dtype,
652 | )(embeddings)
653 | return logits, outputs
654 |
655 |
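# Usage sketch (illustrative only): the classifier forwards `encoder_params`
# to `FactorizedEncoder`, attention-pools the [B, T * N, D] features into one
# embedding per clip, and projects it to `num_classes` logits. Note that
# `encoder_params` must include 'num_heads' and 'model_dim', which are read
# directly above.
def _example_video_classifier_usage():
  classifier = FactorizedVideoClassifier(
      encoder_params=dict(
          patch_size=4, pos_emb_shape=(4, 4, 4), model_dim=32,
          num_spatial_layers=1, num_temporal_layers=1, num_heads=2,
          mlp_dim=64),
      num_classes=10)
  videos = jnp.zeros([2, 4, 16, 16, 3])
  params = classifier.init(jax.random.PRNGKey(0), videos)
  logits, _ = classifier.apply(params, videos)  # [2, 10].
  return logits
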
656 | class TextEncoder(layers.Module):
657 | """CoCa-style text encoder.
658 |
659 | Reference: https://arxiv.org/abs/2205.01917
660 |
661 | Attributes:
662 | vocabulary_size: Vocabulary size of the text tokens.
663 | num_class_tokens: Number of class tokens.
664 | enable_causal_atten: Whether to enable causal attention.
665 | model_dim: The model dimension.
666 |     num_layers: Number of Transformer layers in this model.
667 | mlp_dim: The hidden layer dimension of FFN in Transformer layers.
668 | num_heads: Number of attention heads.
669 |     enable_per_dim_scale: Whether to enable rescaling of attention logits with
670 | 1/sqrt(dim) factor.
671 | atten_logit_cap: Cap the absolute values of logits by tanh. Enabled when a
672 | positive value is specified. May not be supported by a subclass.
673 | norm_policy: Policy for applying normalization wrt. transformations. Options
674 | are: (1) "pre", applied before transformation. (2) "primer_hybrid",
675 | applied before and after transformation. (3) "post", applied after
676 | transformation. (4) "post_skip", applied after the skip connection.
677 |     scan: Whether to use `nn.remat` and `nn.scan`.
678 | """
679 |
680 | vocabulary_size: int = 128
681 | num_class_tokens: int = 0
682 | enable_causal_atten: bool = True
683 | model_dim: int = 768
684 | num_layers: int = 12
685 | mlp_dim: int = 3072
686 | num_heads: int = 12
687 | atten_logit_cap: float = 0.0
688 | norm_policy: str = 'pre'
689 | enable_per_dim_scale: bool = False
690 | scan: bool = False
691 |
692 | @nn.compact
693 | def __call__(
694 | self, inputs: Array, paddings: Array, train: bool = False
695 | ) -> Array:
696 | """Applies the text encoder to the inputs.
697 |
698 | Args:
699 |       inputs: Input tensor of shape [B, N] containing sequences of token ids.
700 |       paddings: Input [B, N] padding field, where 1 denotes a padding token.
701 | train: If the model is in the train mode.
702 |
703 | Returns:
704 | Output tensor of shape [B, N, D].
705 | """
706 | batch_size, seq_length = inputs.shape
707 |
708 | pos_emb = PositionalEmbedding(
709 | name='pos_emb',
710 | embedding_dim=self.model_dim,
711 | dtype=self.dtype,
712 | fprop_dtype=self.fprop_dtype,
713 | )(seq_length=seq_length)
714 | input_emb = Embedding(
715 | name='token_emb',
716 | num_classes=self.vocabulary_size,
717 | input_dim=self.model_dim,
718 | scale_sqrt_depth=True,
719 | dtype=self.dtype,
720 | fprop_dtype=self.fprop_dtype,
721 | )(inputs)
722 | features = input_emb + pos_emb
723 |
724 | if self.num_class_tokens > 0:
725 | cls_emb = self._cast_to_fprop_dtype(
726 | self.param(
727 | 'cls_emb',
728 | nn.initializers.normal(stddev=1.0 / math.sqrt(self.model_dim)),
729 | [1, self.num_class_tokens, self.model_dim],
730 | self.dtype,
731 | )
732 | )
733 | cls_emb = jnp.tile(cls_emb, [batch_size, 1, 1])
734 | cls_emb *= self.model_dim**0.5
735 | features = jnp.concatenate([features, cls_emb], axis=-2)
736 |
737 | cls_paddings = jnp.zeros(
738 | [batch_size, self.num_class_tokens], dtype=paddings.dtype
739 | )
740 | paddings = jnp.concatenate([paddings, cls_paddings], axis=-1)
741 |
742 | features = layers.StackedTransformer(
743 | name='unimodal_transformer',
744 | num_layers=self.num_layers,
745 | hidden_dim=self.mlp_dim,
746 | num_heads=self.num_heads,
747 | atten_logit_cap=self.atten_logit_cap,
748 | norm_policy=self.norm_policy,
749 | internal_enable_per_dim_scale=self.enable_per_dim_scale,
750 | activation_fn=jax.nn.relu,
751 | enable_causal_atten=self.enable_causal_atten,
752 | scan=self.scan,
753 | dtype=self.dtype,
754 | fprop_dtype=self.fprop_dtype,
755 | )(features, paddings, train=train)
756 | features = layers.LayerNorm(
757 | name='unimodal_ln', dtype=self.dtype, fprop_dtype=self.fprop_dtype
758 | )(features)
759 | return features
760 |
761 |
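# Usage sketch (illustrative only): with `num_class_tokens=1` the encoder
# appends one learned class token, so an input of length N yields N + 1
# output embeddings, with the class token at the last position.
def _example_text_encoder_usage():
  encoder = TextEncoder(
      vocabulary_size=64, num_class_tokens=1, model_dim=32,
      num_layers=1, mlp_dim=64, num_heads=2)
  token_ids = jnp.zeros([2, 8], dtype=jnp.int32)  # [B, N].
  paddings = jnp.zeros([2, 8], dtype=jnp.float32)
  params = encoder.init(jax.random.PRNGKey(0), token_ids, paddings)
  return encoder.apply(params, token_ids, paddings)  # [2, 8 + 1, 32].
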
762 | class FactorizedVideoCLIP(layers.Module):
763 | """Video CLIP model with a factorized vision encoder."""
764 |
765 | # Vision parameters.
766 | patch_size: int = 18
767 | pos_emb_shape: tuple[int, int, int] = (16, 16, 16)
768 | num_spatial_layers: int = 12
769 | num_temporal_layers: int = 4
770 | mlp_dim: int = 3072
771 | num_auxiliary_layers: int = 0
772 | # Text parameters.
773 | vocabulary_size: int = 128
774 | enable_causal_atten: bool = True
775 | num_unimodal_layers: int = 12
776 | norm_policy: str = 'pre'
777 | # Shared parameters.
778 | model_dim: int = 768
779 | num_heads: int = 12
780 | atten_logit_cap: float = 0.0
781 | scan: bool = False
782 |
783 | @nn.compact
784 | def __call__(
785 | self,
786 | inputs: Array | None = None,
787 | text_token_ids: Array | None = None,
788 | text_paddings: Array | None = None,
789 | train: bool = False,
790 | normalize: bool = True,
791 | return_intermediate: bool | Collection[str] = False,
792 | frame_paddings: Array | None = None,
793 | ) -> tuple[Array | None, Array | None, dict[str, Array]]:
794 | """Computes predictions for `input_batch`.
795 |
796 | Args:
797 | inputs: Input frame image tensor of shape [B, T, H, W, 3] (H == W).
798 | text_token_ids: Input text token id tensor of shape [B, L].
799 | text_paddings: Input text paddings of shape [B, L]. Required if
800 | `text_token_ids` is not None.
801 | train: If the model is in the train mode.
802 | normalize: Whether to normalize the output embeddings.
803 | return_intermediate: A boolean for whether all intermediate features are
804 | returned, or a collection of intermediate feature names to return.
805 | frame_paddings: Optional binary tensor of shape [B, T] indicating padding.
806 |         1 denotes a padded frame.
807 |
808 | Returns:
809 | video_embeddings: Output contrastive video embeddings of shape [B, D].
810 | None if `inputs` is None.
811 | text_embeddings: Output contrastive text embeddings of shape [B, D]. None
812 | if `text_token_ids` is None.
813 | outputs: A dictionary of additional outputs, including `spatial_features`
814 | of shape [B, T * N, D], `spatiotemporal_features` of shape [B, T * N,
815 |         D], and `frame_embeddings` of shape [B, T, D]. Empty if
816 |         `return_intermediate` is False or contains none of these names.
817 | """
818 | video_embeddings, text_embeddings, outputs = None, None, {}
819 |
820 | if inputs is not None:
821 | num_frames = inputs.shape[-4]
822 | vision_features, vision_outputs = FactorizedEncoder(
823 | name='vision_encoder',
824 | patch_size=self.patch_size,
825 | pos_emb_shape=self.pos_emb_shape,
826 | model_dim=self.model_dim,
827 | num_spatial_layers=self.num_spatial_layers,
828 | num_temporal_layers=self.num_temporal_layers,
829 | num_heads=self.num_heads,
830 | mlp_dim=self.mlp_dim,
831 | atten_logit_cap=self.atten_logit_cap,
832 | norm_policy='pre',
833 | scan=self.scan,
834 | dtype=self.dtype,
835 | fprop_dtype=self.fprop_dtype,
836 | )(
837 | inputs,
838 | train=train,
839 | return_intermediate=return_intermediate,
840 | frame_paddings=frame_paddings,
841 | )
842 | outputs.update(vision_outputs)
843 | if _contains(return_intermediate, 'spatiotemporal_features'):
844 | outputs['spatiotemporal_features'] = vision_features
845 |
846 | if self.num_auxiliary_layers > 0:
847 | vision_features = VisionTransformer(
848 | name='auxiliary_encoder',
849 | num_tfm_layers=self.num_auxiliary_layers,
850 | mlp_dim=self.mlp_dim,
851 | num_heads=self.num_heads,
852 | atten_logit_cap=self.atten_logit_cap,
853 | norm_policy='pre',
854 | scan=self.scan,
855 | dtype=self.dtype,
856 | fprop_dtype=self.fprop_dtype,
857 | )(vision_features, train=train)
858 |
859 | pooling_layer = layers.AttenTokenPoolingLayer(
860 | name='contrastive_vision_pooler',
861 | hidden_dim=self.model_dim * 4,
862 | num_heads=self.num_heads,
863 | num_queries=1,
864 | dtype=self.dtype,
865 | fprop_dtype=self.fprop_dtype,
866 | )
867 | video_embeddings = pooling_layer(vision_features, None, train=train)
868 |
869 | # Squeeze the query dimension in the pooler output.
870 | video_embeddings = jnp.squeeze(video_embeddings, axis=-2)
871 | if normalize:
872 | video_embeddings = _l2_normalize(video_embeddings, axis=-1)
873 |
874 | if _contains(return_intermediate, 'frame_embeddings'):
875 | frame_features = einshape.jax_einshape(
876 | 'b(tn)d->(bt)nd', vision_features, t=num_frames
877 | )
878 | frame_embeddings = pooling_layer(frame_features, None, train=train)
879 | frame_embeddings = jnp.squeeze(frame_embeddings, axis=-2)
880 | frame_embeddings = einshape.jax_einshape(
881 | '(bt)d->btd', frame_embeddings, t=num_frames
882 | )
883 | if normalize:
884 | frame_embeddings = _l2_normalize(frame_embeddings, axis=-1)
885 | outputs['frame_embeddings'] = frame_embeddings
886 |
887 | if text_token_ids is not None:
888 | assert text_paddings is not None, 'Text paddings are required.'
889 | text_features = TextEncoder(
890 | name='text_encoder',
891 | vocabulary_size=self.vocabulary_size,
892 | num_class_tokens=1,
893 | enable_causal_atten=self.enable_causal_atten,
894 | model_dim=self.model_dim,
895 | num_layers=self.num_unimodal_layers,
896 | num_heads=self.num_heads,
897 | mlp_dim=self.model_dim * 4,
898 | atten_logit_cap=self.atten_logit_cap,
899 | norm_policy=self.norm_policy,
900 | scan=self.scan,
901 | dtype=self.dtype,
902 | fprop_dtype=self.fprop_dtype,
903 | )(text_token_ids, text_paddings, train=train)
904 |
905 | # Take the last token (i.e., class token) as the text embedding.
906 | text_embeddings = text_features[:, -1]
907 | if normalize:
908 | text_embeddings = _l2_normalize(text_embeddings, axis=-1)
909 |
910 | return video_embeddings, text_embeddings, outputs
911 |
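# End-to-end retrieval sketch (illustrative only): both towers emit
# L2-normalized [B, D] embeddings by default, so a plain dot product gives
# cosine similarities between every video-text pair in the batch.
def _example_video_clip_usage():
  model = FactorizedVideoCLIP(
      patch_size=4, pos_emb_shape=(4, 4, 4), model_dim=32,
      num_spatial_layers=1, num_temporal_layers=1, mlp_dim=64,
      vocabulary_size=64, num_unimodal_layers=1, num_heads=2)
  videos = jnp.zeros([2, 4, 16, 16, 3])
  token_ids = jnp.zeros([2, 8], dtype=jnp.int32)
  paddings = jnp.zeros([2, 8], dtype=jnp.float32)
  params = model.init(jax.random.PRNGKey(0), videos, token_ids, paddings)
  video_emb, text_emb, _ = model.apply(params, videos, token_ids, paddings)
  return video_emb @ text_emb.T  # [2, 2] cosine similarities.
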
--------------------------------------------------------------------------------
/videoprism/layers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 VideoPrism Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """VideoPrism Flax layers."""
16 |
17 | from collections.abc import Callable
18 | import functools
19 | import string
20 | from typing import Any
21 | from flax import linen as nn
22 | import jax
23 | from jax import numpy as jnp
24 | import numpy as np
25 |
26 | Array = jax.Array
27 | ActivationFunc = Callable[[Array], Array]
28 | Initializer = nn.initializers.Initializer
29 |
30 | default_kernel_init = nn.initializers.lecun_normal()
31 | gelu = functools.partial(jax.nn.gelu, approximate=False)
32 |
33 |
34 | def identity(x: Array) -> Array:
35 | """Identity activation."""
36 | return x
37 |
38 |
39 | def _get_large_negative_number(dtype: jax.typing.DTypeLike) -> Array:
40 | """Returns a large-magnitude negative value for the given dtype."""
41 |   # -0.7 is a float64 in JAX. Explicitly cast the output to the target dtype.
42 | if jnp.issubdtype(dtype, jnp.inexact):
43 | dtype_max = jnp.finfo(dtype).max
44 | elif jnp.issubdtype(dtype, jnp.integer):
45 | dtype_max = jnp.iinfo(dtype).max
46 | else:
47 | raise ValueError('Unsupported dtype for inputs.')
48 | return jnp.asarray(-0.7 * dtype_max, dtype=dtype)
49 |
50 |
51 | def _apply_mask_to_logits(logits: Array, mask: Array) -> Array:
52 | """Applies a floating-point mask to a set of logits.
53 |
54 | The mask is represented as a float32 tensor where 0 represents true and values
55 | below a large negative number (here set to
56 | _get_large_negative_number(jnp.float32) / 2) represent false. Applying the
57 | mask leaves the logits alone in the true case and replaces them by
58 | _get_large_negative_number(jnp.float32) in the false case. Previously, this
59 | was done by adding the logits to the mask; however, this leads to a bad fusion
60 | decision in the compiler that saves the float32 values in memory rather than
61 | just the predicate. This implementation avoids that problem.
62 |
63 | Args:
64 | logits: A jax.Array of logit values.
65 | mask: A jax.Array (float32) of mask values with the encoding described in
66 | the function documentation.
67 |
68 | Returns:
69 | Masked logits.
70 | """
71 | min_value = _get_large_negative_number(logits.dtype)
72 | return jnp.where((mask >= min_value * 0.5), logits, min_value)
73 |
74 |
75 | def _convert_paddings_to_mask(
76 | paddings: Array, dtype: jax.typing.DTypeLike = jnp.float32
77 | ) -> Array:
78 | """Converts binary paddings to a logit mask ready to add to attention matrix.
79 |
80 | Args:
81 | paddings: A binary jax.Array of shape [B, T], with 1 denoting padding token.
82 | dtype: Data type of the input.
83 |
84 | Returns:
85 | A jax.Array of shape [B, 1, 1, T] ready to be added to attention logits.
86 | """
87 | attention_mask = paddings[:, jnp.newaxis, jnp.newaxis, :]
88 | attention_mask *= _get_large_negative_number(dtype)
89 | return attention_mask
90 |
91 |
92 | def _causal_mask(input_t: Array) -> Array:
93 | """Computes and returns causal mask.
94 |
95 | Args:
96 | input_t: A jax.Array of shape [B, T, D].
97 |
98 | Returns:
99 | An attention_mask jax.Array of shape [1, 1, T, T]. Attention mask has
100 |     already been converted to large negative values.
101 | """
102 | assert jnp.issubdtype(input_t.dtype, jnp.floating), input_t.dtype
103 | large_negative_number = _get_large_negative_number(input_t.dtype)
104 | t = input_t.shape[-2]
105 | col_idx = jnp.tile(jnp.arange(t)[jnp.newaxis, :], [t, 1])
106 | row_idx = jnp.tile(jnp.arange(t)[:, jnp.newaxis], [1, t])
107 | mask = (row_idx < col_idx).astype(input_t.dtype) * large_negative_number
108 | return mask[jnp.newaxis, jnp.newaxis, :, :]
109 |
110 |
111 | def _merge_masks(a: Array, b: Array) -> Array:
112 | """Merges two masks.
113 |
114 | This function merges two masks with the same shape, where the smaller value
115 | will be chosen at the same position. Log-scale mask is expected but 0/1 mask
116 | is also fine.
117 |
118 | Args:
119 | a: A jax.Array of shape [1|B, 1, 1|T, S].
120 | b: A jax.Array of shape [1|B, 1, 1|T, S].
121 |
122 | Returns:
123 | A jax.Array of shape [1|B, 1, 1|T, S].
124 | """
125 |
126 | def expand_t(key_mask):
127 | """Expands the 1D mask to the 2D mask.
128 |
129 | Given [[1, 1, 0, 0]], this function returns the following mask,
130 | 1 1 0 0
131 | 1 1 0 0
132 | 0 0 0 0
133 | 0 0 0 0
134 |
135 | Args:
136 | key_mask: A jax.Array of the input 1D mask.
137 |
138 | Returns:
139 | A jax.Array of the expanded 2D mask.
140 | """
141 | query_mask = jnp.transpose(key_mask, [0, 1, 3, 2])
142 | return jnp.minimum(query_mask, key_mask)
143 |
144 | if a.shape[-2] != b.shape[-2]:
145 | if a.shape[-2] == 1:
146 | a = expand_t(a)
147 | else:
148 | assert b.shape[-2] == 1
149 | b = expand_t(b)
150 |
151 | assert a.shape[-3:] == b.shape[-3:], f'a.shape={a.shape}, b.shape={b.shape}.'
152 | return jnp.minimum(a, b)
153 |
154 |
155 | def compute_attention_masks_for_fprop(
156 | inputs: Array,
157 | paddings: Array,
158 | causal_attention: bool = False,
159 | ) -> Array:
160 | """Computes attention mask from inputs and paddings for fprop.
161 |
162 | Args:
163 | inputs: Input sequence jax.Array of shape [B, T, H].
164 | paddings: Input paddings jax.Array of shape [B, T].
165 | causal_attention: Boolean to apply causal masking.
166 |
167 | Returns:
168 | attention_mask: Attention mask jax.Array ready to be added to logits for
169 | self-attention of shape [1|B, 1, 1|T, T].
170 | """
171 | # Get paddings mask to [B, 1, 1, T].
172 | attention_mask = _convert_paddings_to_mask(paddings, inputs.dtype)
173 |
174 | # Causal mask of shape [1, 1, T, T].
175 | if causal_attention:
176 | causal_mask = _causal_mask(inputs)
177 | attention_mask = _merge_masks(attention_mask, causal_mask)
178 |
179 | return attention_mask
180 |
181 |
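# Tiny worked example (illustrative only, not part of the original file): for
# one length-3 sequence whose last token is padding, the merged causal and
# padding mask is ~0 where attention is allowed and a large negative number
# elsewhere; only the lower triangle over the first two tokens stays unmasked.
def _example_attention_mask():
  inputs = jnp.zeros([1, 3, 8])  # [B, T, H].
  paddings = jnp.array([[0.0, 0.0, 1.0]])  # 1 == padding.
  return compute_attention_masks_for_fprop(
      inputs, paddings, causal_attention=True)  # [1, 1, 3, 3].
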
182 | class Module(nn.Module):
183 | """Base class for layers with dtype configured.
184 |
185 | Attributes:
186 | dtype: Default dtype for all variables.
187 | fprop_dtype: Activations dtype to use.
188 | """
189 |
190 | dtype: jnp.dtype = jnp.float32
191 | fprop_dtype: jnp.dtype = jnp.float32
192 |
193 | @nn.nowrap
194 | def _cast_to_fprop_dtype(self, value: Any) -> Any:
195 | """Casts values to the desired dtype."""
196 |
197 | def _cast(x):
198 | if x is None:
199 | return None
200 | if self.fprop_dtype != x.dtype:
201 | if jnp.issubdtype(x.dtype, jnp.floating):
202 | return x.astype(self.fprop_dtype)
203 | return x
204 |
205 | return jax.tree_util.tree_map(_cast, value)
206 |
207 |
208 | class LayerNorm(Module):
209 | """Layer normalization.
210 |
211 | Attributes:
212 |     direct_scale: Whether to apply the scale directly without a +1.0. The
213 |       scale variable is initialized to 1.0 instead when True.
214 | epsilon: Tiny value to guard rsqrt.
215 | use_scale: Whether to use a learned scaling.
216 | use_bias: Whether to use bias.
217 | reductions_in_fp32: Whether to compute mean and variance in fp32.
218 | Recommended for stable training on GPUs.
219 | """
220 |
221 | direct_scale: bool = False
222 | epsilon: float = 1e-6
223 | use_scale: bool = True
224 | use_bias: bool = True
225 | reductions_in_fp32: bool = False
226 |
227 | @nn.compact
228 | def __call__(self, inputs: Array) -> Array:
229 | """Applies layer norm to inputs.
230 |
231 | Args:
232 | inputs: A jax.Array for the inputs of shape [..., dim].
233 |
234 | Returns:
235 |       A jax.Array for the normalized inputs of the same shape.
236 | """
237 | inputs_dtype = inputs.dtype
238 | if self.reductions_in_fp32:
239 | inputs = inputs.astype(jnp.float32)
240 | mean = jnp.mean(inputs, axis=[-1], keepdims=True)
241 | var = jnp.mean(jnp.square(inputs - mean), axis=[-1], keepdims=True)
242 | normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon)
243 | if self.reductions_in_fp32:
244 | normed_inputs = normed_inputs.astype(inputs_dtype)
245 |
246 | input_dim = inputs.shape[-1]
247 | if self.use_scale:
248 | init_value = 1.0 if self.direct_scale else 0.0
249 | scale = self._cast_to_fprop_dtype(
250 | self.param(
251 | 'scale',
252 | nn.initializers.constant(init_value),
253 | [input_dim],
254 | self.dtype,
255 | )
256 | )
257 | if not self.direct_scale:
258 | scale += 1.0
259 | normed_inputs *= scale
260 | if self.use_bias:
261 | bias = self._cast_to_fprop_dtype(
262 | self.param(
263 | 'bias',
264 | nn.initializers.zeros_init(),
265 | [input_dim],
266 | self.dtype,
267 | )
268 | )
269 | normed_inputs += bias
270 | return normed_inputs
271 |
272 |
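# Sketch of the scale convention (illustrative only): with the default
# direct_scale=False, the 'scale' parameter is initialized to 0 and applied
# as (scale + 1.0), so a freshly initialized LayerNorm rescales by exactly 1.
def _example_layer_norm():
  ln = LayerNorm()
  x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
  params = ln.init(jax.random.PRNGKey(0), x)
  return ln.apply(params, x)  # Zero mean, unit variance on the last axis.
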
273 | class FeedForward(Module):
274 | """Feedforward layer with activation.
275 |
276 | Attributes:
277 | output_dim: Depth of the output.
278 | has_bias: Adds bias weights or not.
279 | activation_fn: Activation function to use.
280 | weight_init: Initializer function for the weight matrix.
281 | bias_init: Initializer function for the bias.
282 | """
283 |
284 | output_dim: int = 0
285 | has_bias: bool = True
286 | activation_fn: ActivationFunc = nn.relu
287 | weight_init: Initializer = default_kernel_init
288 | bias_init: Initializer = nn.initializers.zeros_init()
289 |
290 | @nn.compact
291 | def __call__(self, inputs: Array) -> Array:
292 |
293 | def _promote_dtype(x, kernel, bias, dtype):
294 | """Promotes the dtype of the arrays to the desired dtype."""
295 | del dtype
296 | # To be compatible with other layers, we do not promote the inputs as they
297 | # are expected to be in the `fprop_dtype`.
298 | return (
299 | x,
300 | self._cast_to_fprop_dtype(kernel),
301 | self._cast_to_fprop_dtype(bias),
302 | )
303 |
304 | projected_inputs = nn.Dense(
305 | self.output_dim,
306 | use_bias=self.has_bias,
307 | kernel_init=self.weight_init,
308 | bias_init=self.bias_init,
309 | name='linear',
310 | param_dtype=self.dtype,
311 | promote_dtype=_promote_dtype,
312 | )(inputs)
313 | return self.activation_fn(projected_inputs)
314 |
315 |
316 | class TransformerFeedForward(Module):
317 | """Transformer feedforward layer with residual connection and dropout.
318 |
319 | Attributes:
320 | output_dim: Depth of the output. The value of input_dim will be used when
321 | output_dim is 0. Must be equal to input_dim if add_skip_connection=True.
322 | hidden_dim: Hidden dimension of FFN.
323 | has_bias: Adds bias weights to Feedforward or not.
324 | activation_fn: Activation function to use.
325 | residual_dropout_prob: Residual dropout.
326 | relu_dropout_prob: FFN dropout.
327 | add_skip_connection: Whether to add residual connection.
328 | residual_weight: Weight of the residual connection. Output = fn(x) *
329 | residual_weight + x.
330 | norm_policy: Policy for applying normalization wrt. transformations. Options
331 | are: (1) "pre", applied before transformation. (2) "primer_hybrid",
332 | applied before and after transformation. (3) "post", applied after
333 | transformation, (4) "post_skip", applied after the skip connection.
334 | """
335 |
336 | output_dim: int = 0
337 | hidden_dim: int = 0
338 | has_bias: bool = True
339 | activation_fn: ActivationFunc = nn.relu
340 | residual_dropout_prob: float = 0.0
341 | relu_dropout_prob: float = 0.0
342 | add_skip_connection: bool = True
343 | residual_weight: float = 1.0
344 | norm_policy: str = 'pre'
345 |
346 | @nn.nowrap
347 | def _make_ln(self, name: str) -> LayerNorm:
348 | """Makes a LayerNorm module."""
349 | return LayerNorm(
350 | name=name,
351 | use_bias=self.has_bias,
352 | dtype=self.dtype,
353 | fprop_dtype=self.fprop_dtype,
354 | )
355 |
356 | @nn.nowrap
357 | def _make_ffn(
358 | self, output_dim: int, name: str, skip_activation: bool = False
359 | ) -> FeedForward:
360 | """Makes a FeedForward module."""
361 | return FeedForward(
362 | name=name,
363 | output_dim=output_dim,
364 | has_bias=self.has_bias,
365 | activation_fn=identity if skip_activation else self.activation_fn,
366 | dtype=self.dtype,
367 | fprop_dtype=self.fprop_dtype,
368 | )
369 |
370 | @nn.compact
371 | def __call__(
372 | self, inputs: Array, paddings: Array | None, train: bool
373 | ) -> Array:
374 | residual = inputs
375 | output_dim = self.output_dim
376 | if output_dim == 0:
377 | output_dim = inputs.shape[-1]
378 | if self.add_skip_connection and output_dim != inputs.shape[-1]:
379 | raise ValueError(
380 | 'Skip connections are only supported when input_dim == output_dim '
381 |           f'but got {inputs.shape[-1]} != {output_dim}'
382 | )
383 |
384 | # Expand paddings to last dim if not None to have shape [batch, seq_len, 1].
385 | if paddings is not None:
386 | paddings = jnp.expand_dims(paddings, axis=-1)
387 |
388 | if self.norm_policy == 'primer_hybrid':
389 | inputs = self._make_ln(name='pre_layer_norm')(inputs)
390 | elif self.norm_policy == 'pre':
391 | inputs = self._make_ln(name='layer_norm')(inputs)
392 |
393 | # Apply first FFN layer.
394 | activations = self._make_ffn(self.hidden_dim, name='ffn_layer1')(inputs)
395 |
396 | # Apply paddings if not None.
397 | if paddings is not None:
398 | activations *= 1.0 - paddings
399 |
400 | # Apply RELU dropout.
401 | activations = nn.Dropout(self.relu_dropout_prob, name='relu_dropout')(
402 | activations, deterministic=not train
403 | )
404 | # Apply second FFN layer.
405 | outputs = self._make_ffn(
406 | output_dim, name='ffn_layer2', skip_activation=True
407 | )(activations)
408 |
409 | # Apply paddings if not None.
410 | if paddings is not None:
411 | outputs *= 1.0 - paddings
412 |
413 | # Apply Primer normalization before dropout.
414 | if self.norm_policy == 'primer_hybrid':
415 | outputs = self._make_ln(name='post_layer_norm')(outputs)
416 | elif self.norm_policy == 'post':
417 | outputs = self._make_ln(name='layer_norm')(outputs)
418 |
419 | # Apply residual dropout.
420 | outputs = nn.Dropout(self.residual_dropout_prob, name='residual_dropout')(
421 | outputs, deterministic=not train
422 | )
423 | # Apply skip connection.
424 | if self.add_skip_connection:
425 | outputs = residual + outputs * self.residual_weight
426 |
427 | if self.norm_policy == 'post_skip':
428 | outputs = self._make_ln(name='layer_norm')(outputs)
429 |
430 | return outputs
431 |
432 |
433 | class AttentionProjection(Module):
434 | """Layer that computes multi heads projection.
435 |
436 | This layer is expected to be used within DotProductAttention below.
437 |
438 | Attributes:
439 |     output_dim: Output dimension; used only when is_output_projection is True.
440 | num_heads: Number of attention heads.
441 | dim_per_head: Size of each head.
442 | is_output_projection: Whether it is out projection or not. If False, we use
443 | "...D,DNH->...NH" for query,key,value projection. Otherwise we use
444 | "...NH,DNH->...D" for output projection.
445 | use_bias: Whether to add bias in projection or not.
446 | """
447 |
448 | output_dim: int = 0
449 | num_heads: int = 0
450 | dim_per_head: int = 0
451 | is_output_projection: bool = False
452 | use_bias: bool = True
453 |
454 | @nn.compact
455 | def __call__(self, inputs: Array) -> Array:
456 | """Computes the multi headed projection for inputs.
457 |
458 | Args:
459 | inputs: A jax.Array with shape [..., num_heads, dim_per_head] if
460 | is_output_projection is True or [..., input_dim] otherwise.
461 |
462 | Returns:
463 | The projected jax.Array with shape [..., input_dim] if
464 | is_output_projection is True or [..., num_heads, dim_per_head]
465 | otherwise.
466 | """
467 | # Sort the available symbols to avoid nondeterminism.
468 | eqn_sym = ''.join(sorted(set(string.ascii_uppercase) - set('DHN')))
469 | output_dim = (
470 | self.output_dim if self.is_output_projection else inputs.shape[-1]
471 | )
472 | rank = len(inputs.shape)
473 |
474 | hd_shape = [self.num_heads, self.dim_per_head]
475 | pc_shape = [output_dim] + hd_shape
476 | w = self._cast_to_fprop_dtype(
477 | self.param('w', default_kernel_init, pc_shape, self.dtype)
478 | )
479 |
480 | if self.is_output_projection:
481 | assert inputs.shape[-2:] == (self.num_heads, self.dim_per_head)
482 | batch_eqn = eqn_sym[: (rank - 2)]
483 | eqn = f'{batch_eqn}NH,DNH->{batch_eqn}D'
484 | else:
485 | batch_eqn = eqn_sym[: (rank - 1)] if rank else '...'
486 | eqn = f'{batch_eqn}D,DNH->{batch_eqn}NH'
487 |
488 | ret = jnp.einsum(eqn, inputs, w)
489 | if self.use_bias:
490 | b = self._cast_to_fprop_dtype(
491 | self.param(
492 | 'b',
493 | nn.initializers.zeros_init(),
494 | [output_dim] if self.is_output_projection else hd_shape,
495 | self.dtype,
496 | )
497 | )
498 | ret += b
499 | return ret
500 |
501 |
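# Sketch of the two einsum directions (illustrative only): the input
# projection maps [..., D] -> [..., N, H] with "...D,DNH->...NH", while
# is_output_projection=True maps [..., N, H] -> [..., D] with
# "...NH,DNH->...D".
def _example_attention_projection():
  proj = AttentionProjection(num_heads=2, dim_per_head=4)
  x = jnp.zeros([3, 5, 8])  # [B, T, D].
  params = proj.init(jax.random.PRNGKey(0), x)
  return proj.apply(params, x)  # [3, 5, 2, 4], i.e. [B, T, N, H].
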
502 | class PerDimScale(Module):
503 | """A layer to scale individual dimensions of the input."""
504 |
505 | @nn.compact
506 | def __call__(self, inputs: Array) -> Array:
507 | """Returns per_dim_scale * inputs / jnp.sqrt(dim)).
508 |
509 | Args:
510 | inputs: A jax.Array with shape [..., dim].
511 |
512 | Returns:
513 | outputs: A jax.Array with shape [..., dim].
514 | """
515 | dim = inputs.shape[-1]
516 | per_dim_scale = self._cast_to_fprop_dtype(
517 | self.param(
518 | 'per_dim_scale', nn.initializers.zeros_init(), [dim], self.dtype
519 | )
520 | )
521 |
522 | # 1.0/jax.nn.softplus(0.0) = 1.442695041. Hard code this number so that we
523 | # can avoid unnecessary XLA op fusion mess on TPU.
524 | r_softplus_0 = 1.442695041
525 | scale = jnp.array(r_softplus_0 / np.sqrt(dim), dtype=self.fprop_dtype)
526 | scale *= jax.nn.softplus(per_dim_scale)
527 | return inputs * scale
528 |
529 |
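# Arithmetic check (illustrative only): softplus(0) = ln(2), so r_softplus_0
# is 1 / ln(2) ~= 1.442695041 and the overall factor at initialization
# (per_dim_scale == 0) reduces to exactly 1 / sqrt(dim), the standard query
# scaling.
def _example_per_dim_scale():
  layer = PerDimScale()
  x = jnp.ones([1, 16])
  params = layer.init(jax.random.PRNGKey(0), x)
  return layer.apply(params, x)  # == x / 4.0, since sqrt(16) == 4.
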
530 | class DotProductAttention(Module):
531 | """Dot-product attention with multiple attention heads.
532 |
533 | Attributes:
534 | hidden_dim: Number of hidden nodes.
535 | num_heads: Number of attention heads.
536 | dim_per_head: Dimension of each attention head. If None then dim_per_head ==
537 | hidden_dim // num_heads.
538 | atten_dropout_prob: Probability at which we apply dropout to the attention
539 | weights.
540 | use_bias: Whether to use bias for projection layers.
541 | internal_enable_query_scale: Internal. Enable scaling of query vector.
542 | internal_enable_per_dim_scale: Internal. Setting to False disables rescaling
543 | of attention logits with 1/sqrt(dim) factor. Some Transformer variants
544 | (GShard, T5) use internal_enable_per_dim_scale=False and adjust
545 |       initialization of the linear transformations (einsums), in conjunction
546 |       with the Adafactor optimizer.
547 |     scale_query_by_dim_per_head: Whether to scale the query by dim_per_head,
548 | instead of default hidden_dim // num_heads (only activated when
549 | internal_enable_per_dim_scale = False).
550 | scale_logits_by_head_dims: Enables a 1/sqrt(head dim) scaling to the logits.
551 | This occurs prior to logit cap, if any.
552 | atten_logit_cap: Cap the absolute values of logits by tanh. Enabled when a
553 | positive value is specified. May not be supported by a subclass.
554 | use_qk_norm: If QK norm is used.
555 | """
556 |
557 | hidden_dim: int = 0
558 | num_heads: int = 1
559 | dim_per_head: int | None = None
560 | atten_dropout_prob: float = 0.0
561 | use_bias: bool = True
562 | internal_enable_query_scale: bool = True
563 | internal_enable_per_dim_scale: bool = True
564 | scale_query_by_dim_per_head: bool = False
565 | scale_logits_by_head_dims: bool = False
566 | atten_logit_cap: float = 0.0
567 | use_qk_norm: bool = False
568 |
569 | def _scale_query(self, query: Array) -> Array:
570 | """Scales the query vector if enabled."""
571 | if not self.internal_enable_query_scale:
572 | return query
573 | if self.internal_enable_per_dim_scale:
574 | query = PerDimScale(
575 | name='per_dim_scale', dtype=self.dtype, fprop_dtype=self.fprop_dtype
576 | )(query)
577 | else:
578 | if self.scale_query_by_dim_per_head and self.dim_per_head is not None:
579 | dim_per_head = self.dim_per_head
580 | else:
581 | dim_per_head = self.hidden_dim // self.num_heads
582 |
583 | query *= dim_per_head**-0.5
584 | return query
585 |
586 | def _cap_logits(self, logits: Array) -> Array:
587 | """Caps the logits by p.atten_logit_cap with tanh, if enabled."""
588 | if not self.atten_logit_cap or self.atten_logit_cap <= 0.0:
589 | return logits
590 | cap = jnp.array(self.atten_logit_cap, dtype=self.fprop_dtype)
591 | # Note that since this caps the negative side as well, caller must defer the
592 | # pad-with-very-negative-logits logic to after this function returns.
593 | logits = cap * jnp.tanh(logits / cap)
594 | return logits
595 |
596 | def _atten_logits(self, query: Array, key: Array) -> Array:
597 | """Computes logits from query and key."""
598 | logits = jnp.einsum('BTNH,BSNH->BNTS', query, key)
599 | return logits
600 |
601 | def _dot_atten(
602 | self,
603 | query: Array,
604 | key: Array,
605 | value: Array,
606 | atten_mask: Array,
607 | train: bool,
608 | ) -> tuple[Array, Array]:
609 | """Main attention function.
610 |
611 | Args:
612 | query: A jax.Array of shape [B, T, N, H].
613 | key: A jax.Array of shape [B, S, N, H].
614 | value: A jax.Array of shape [B, S, N, H].
615 | atten_mask: A jax.Array of shape [1|B, 1, 1|T, S] which is a mask that is
616 | applied to prevent attention between unwanted pairs. This has already
617 | been converted into large negative logits. Note that the first and third
618 | dimension allow size 1 if the mask is shared by every item in the batch
619 | or every token in the target sequence.
620 | train: Whether the model is in the train mode.
621 |
622 | Returns:
623 | encoded: A jax.Array of shape [B, T, N, H].
624 | atten_probs: A jax.Array of shape [B, N, T, S].
625 | """
626 | assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
627 | assert (
628 | query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
629 | ), 'q, k, v batch dims must match.'
630 | assert (
631 | query.shape[-2] == key.shape[-2] == value.shape[-2]
632 | ), 'q, k, v num_heads must match.'
633 | assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'
634 | # If only padding bias is supplied, then atten_mask can be [B, 1, 1, S]
635 | # since each target token is prohibited from attending to the same set of
636 | # source tokens. In this case tiling is inefficient and unnecessary.
637 | # If there is no padding mask, and only causal mask then the shape can be
638 | # [1, 1, T, S].
639 | assert atten_mask.ndim == 4 and atten_mask.shape[-1] == key.shape[-3]
640 | assert atten_mask.shape[-2] in [query.shape[-3], 1]
641 | assert atten_mask.shape[0] in [key.shape[0], 1]
642 |
643 | query = self._scale_query(query)
644 | logits = self._atten_logits(query, key)
645 |
646 | if self.scale_logits_by_head_dims:
647 | logits = jnp.multiply(logits, 1.0 / np.sqrt(key.shape[-1]))
648 |
649 | logits = self._cap_logits(logits)
650 | # Attention softmax is always carried out in fp32.
651 | logits = logits.astype(jnp.float32)
652 | # Apply attention masking.
653 | padded_logits = _apply_mask_to_logits(logits, atten_mask)
654 | probs = jax.nn.softmax(padded_logits, axis=-1).astype(self.fprop_dtype)
655 | # Apply attention dropout.
656 | probs = nn.Dropout(self.atten_dropout_prob, name='atten_dropout')(
657 | probs, deterministic=not train
658 | )
659 | # Compute the attention context.
660 | encoded = jnp.einsum('BNTS,BSNH->BTNH', probs, value)
661 | return encoded, probs
662 |
663 | @nn.nowrap
664 | def _project_input(self, name: str, dim_per_head: int) -> AttentionProjection:
665 | """Builds an AttentionProjection module."""
666 | return AttentionProjection(
667 | name=name,
668 | num_heads=self.num_heads,
669 | dim_per_head=dim_per_head,
670 | use_bias=self.use_bias,
671 | dtype=self.dtype,
672 | fprop_dtype=self.fprop_dtype,
673 | )
674 |
675 | @nn.nowrap
676 | def _make_ln(self, name: str) -> LayerNorm:
677 | """Makes a LayerNorm module."""
678 | return LayerNorm(
679 | name=name,
680 | use_bias=self.use_bias,
681 | dtype=self.dtype,
682 | fprop_dtype=self.fprop_dtype,
683 | )
684 |
685 | @nn.compact
686 | def __call__(
687 | self,
688 | query_vec: Array,
689 | key_vec: Array,
690 | value_vec: Array,
691 | atten_mask: Array,
692 | train: bool,
693 | ) -> tuple[Array, Array]:
694 | """Computes the value vector given the current query output.
695 |
696 | Args:
697 | query_vec: jax.Array of shape [B, T, D].
698 | key_vec: jax.Array of shape [B, S, D].
699 | value_vec: jax.Array of shape [B, S, D].
700 | atten_mask: jax.Array of shape [1|B, 1, 1|T, S] which is a mask that is
701 | applied to prevent attention between unwanted pairs. This has already
702 | been converted into large negative logits. Note that the first and third
703 | dimension allow size 1 if the mask is shared by every item in the batch
704 | or every token in the target sequence.
705 | train: If the model is in the train mode.
706 |
707 | Returns:
708 | encoded: jax.Array of shape [B, T, D].
709 | atten_probs: jax.Array of shape [B, N, T, S].
710 | """
711 | dim_per_head = self.dim_per_head
712 | if dim_per_head is None:
713 | dim_per_head = self.hidden_dim // self.num_heads
714 | assert (
715 | dim_per_head * self.num_heads == self.hidden_dim
716 | ), f'{dim_per_head} * {self.num_heads} != {self.hidden_dim}'
717 |
718 | # Project inputs to key, value and query, respectively has shape
719 | # [B, S, N, H], [B, S, N, H], and [B, T, N, H].
720 | query_proj = self._project_input('query', dim_per_head)(query_vec)
721 | key_proj = self._project_input('key', dim_per_head)(key_vec)
722 | value_proj = self._project_input('value', dim_per_head)(value_vec)
723 |
724 | if self.use_qk_norm:
725 | query_proj = self._make_ln(name='layer_norm_q')(query_proj)
726 | key_proj = self._make_ln(name='layer_norm_k')(key_proj)
727 |
728 | encoded, atten_probs = self._dot_atten(
729 | query_proj, key_proj, value_proj, atten_mask, train=train
730 | )
731 |
732 | # Post projection. Setting is_output_projection=True to set the projection
733 | # direction from hidden dim to input dim. Output projection follows
734 | # query_input_dim.
735 | query_input_dim = query_vec.shape[-1]
736 | encoded = AttentionProjection(
737 | name='post',
738 | output_dim=query_input_dim,
739 | num_heads=self.num_heads,
740 | dim_per_head=dim_per_head,
741 | is_output_projection=True,
742 | use_bias=self.use_bias,
743 | dtype=self.dtype,
744 | fprop_dtype=self.fprop_dtype,
745 | )(encoded)
746 | return encoded, atten_probs
747 |
748 |
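# Cross-attention sketch (illustrative only): queries and keys/values may
# have different lengths, and an all-zero mask permits every query-key pair.
def _example_dot_product_attention():
  atten = DotProductAttention(hidden_dim=8, num_heads=2)
  q = jnp.zeros([1, 4, 8])   # [B, T, D].
  kv = jnp.zeros([1, 6, 8])  # [B, S, D].
  mask = jnp.zeros([1, 1, 1, 6])  # Zeros == attend everywhere.
  params = atten.init(jax.random.PRNGKey(0), q, kv, kv, mask, False)
  encoded, probs = atten.apply(params, q, kv, kv, mask, False)
  return encoded, probs  # [1, 4, 8] and [1, 2, 4, 6].
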
749 | class Transformer(Module):
750 | """Transformer layer with multi-headed attention.
751 |
752 | Attributes:
753 | hidden_dim: Hidden dimension of FFN layer.
754 | num_heads: Number of heads in self-attention.
755 | dim_per_head: Dimension of each attention head. If None then dim_per_head ==
756 |       the input feature dimension // num_heads.
757 | atten_dropout_prob: Probability at which we apply dropout to the attention
758 | weights.
759 | residual_dropout_prob: Probability at which we apply dropout to the residual
760 | layers, such that, residual(x, y) = (x + dropout(y)).
761 | relu_dropout_prob: Probability at which we apply dropout to the FFN layers.
762 | norm_policy: Policy for applying normalization wrt. transformations. Options
763 | are: (1) "pre", applied before transformation. (2) "primer_hybrid",
764 | applied before and after transformation. (3) "post", applied after
765 | transformation. (4) "post_skip", applied after the skip connection.
766 | use_bias: Whether to use bias.
767 | activation_fn: Activation function to use.
768 | internal_enable_per_dim_scale: Internal. Setting to False disables rescaling
769 | of attention logits with 1/sqrt(dim) factor.
770 | atten_logit_cap: Cap the absolute values of logits by tanh. Enabled when a
771 | positive value is specified. May not be supported by a subclass.
772 | """
773 |
774 | hidden_dim: int = 0
775 | num_heads: int = 0
776 | dim_per_head: int | None = None
777 | atten_dropout_prob: float = 0.0
778 | residual_dropout_prob: float = 0.0
779 | relu_dropout_prob: float = 0.0
780 | norm_policy: str = 'pre'
781 | use_bias: bool = True
782 | activation_fn: ActivationFunc = nn.relu
783 | internal_enable_per_dim_scale: bool = True
784 | atten_logit_cap: float = 0.0
785 |
786 | @nn.nowrap
787 | def _make_ln(self, name: str) -> LayerNorm:
788 | """Makes a LayerNorm module."""
789 | return LayerNorm(
790 | name=name,
791 | use_bias=self.use_bias,
792 | dtype=self.dtype,
793 | fprop_dtype=self.fprop_dtype,
794 | )
795 |
796 | @nn.compact
797 | def __call__(
798 | self,
799 | inputs: Array,
800 | paddings: Array,
801 | atten_mask: Array,
802 | train: bool,
803 | ) -> Array:
804 | """Transformer decoder layer.
805 |
806 | Args:
807 | inputs: Input sequence jax.Array of shape [B, T, H].
808 | paddings: Input paddings jax.Array of shape [B, T] (only used in FFN).
809 | atten_mask: Self attention mask ready to add to the logits. It can be of
810 | shape [1|B, 1, 1|T, T] which is broadcast compatible with the
811 | self-attention matrix of shape [B, N, T, T]. This is assumed to have
812 | combined paddings, causal masking as well as segment maskings.
813 | train: Whether the model is in the train mode.
814 |
815 | Returns:
816 |       The feedforward layer output with shape [B, T, D].
817 | """
818 |
819 | if self.norm_policy == 'primer_hybrid':
820 | inputs_normalized = self._make_ln(name='pre_layer_norm')(inputs)
821 | elif self.norm_policy == 'pre':
822 | inputs_normalized = self._make_ln(name='layer_norm')(inputs)
823 | else:
824 | inputs_normalized = inputs
825 |
826 | # Compute self-attention, key/value vectors are the input itself.
827 | atten_outputs, _ = DotProductAttention(
828 | name='self_attention',
829 | hidden_dim=inputs_normalized.shape[-1],
830 | num_heads=self.num_heads,
831 | dim_per_head=self.dim_per_head,
832 | atten_dropout_prob=self.atten_dropout_prob,
833 | use_bias=self.use_bias,
834 | internal_enable_per_dim_scale=self.internal_enable_per_dim_scale,
835 | atten_logit_cap=self.atten_logit_cap,
836 | dtype=self.dtype,
837 | fprop_dtype=self.fprop_dtype,
838 | )(
839 | inputs_normalized,
840 | inputs_normalized,
841 | inputs_normalized,
842 | atten_mask=atten_mask,
843 | train=train,
844 | )
845 |
846 | if self.norm_policy == 'primer_hybrid':
847 | atten_outputs = self._make_ln(name='post_layer_norm')(atten_outputs)
848 | elif self.norm_policy == 'post':
849 | atten_outputs = self._make_ln(name='layer_norm')(atten_outputs)
850 |
851 | # Residual dropout and connection.
852 | atten_outputs = nn.Dropout(
853 | self.residual_dropout_prob, name='residual_dropout'
854 | )(atten_outputs, deterministic=not train)
855 | atten_outputs += inputs
856 |
857 | if self.norm_policy == 'post_skip':
858 | atten_outputs = self._make_ln(name='layer_norm')(atten_outputs)
859 |
860 | # Apply FFN layer.
861 | outputs = TransformerFeedForward(
862 | name='ff_layer',
863 | hidden_dim=self.hidden_dim,
864 | has_bias=self.use_bias,
865 | activation_fn=self.activation_fn,
866 | residual_dropout_prob=self.residual_dropout_prob,
867 | relu_dropout_prob=self.relu_dropout_prob,
868 | norm_policy=self.norm_policy,
869 | dtype=self.dtype,
870 | fprop_dtype=self.fprop_dtype,
871 | )(atten_outputs, paddings=paddings, train=train)
872 | return outputs
873 |
874 |
875 | class Repeat(nn.Module):
876 | """A generic repeat layer with `nn.remat` and`nn.scan`.
877 |
878 | Attributes:
879 | block_fn: The block function to repeat.
880 | times: The number of times to repeat block.
881 | checkpoint_policy: Checkpoint policy for `nn.remat`.
882 | """
883 |
884 | block_fn: Callable[..., Any]
885 | times: int = 0
886 | checkpoint_policy: str = 'nothing_saveable'
887 |
888 | def __call__(
889 | self,
890 | inputs: Array,
891 | *args: Any,
892 | **kwargs: Any,
893 | ) -> Any:
894 | """Forwards inputs through the block layer stack.
895 |
896 | Block outputs are expected to be of the same structure as inputs.
897 |
898 | Args:
899 |       inputs: An input jax.Array that goes through the block layer stack.
900 | *args: Positional args to be passed to the forward method.
901 |       **kwargs: Keyword args to be passed to the forward method.
902 |
903 | Returns:
904 | Output from the last layer.
905 | """
906 | return self.call_with_custom_method(
907 | inputs,
908 | *args,
909 | main_fn=self.block_fn,
910 | **kwargs,
911 | )
912 |
913 | def call_with_custom_method(
914 | self,
915 | inputs: Array,
916 | *args: Any,
917 | main_fn: Callable[..., Any],
918 | **kwargs: Any,
919 | ) -> Any:
920 | """Similar to __call__, but allows a custom way to create a layer method."""
921 |
922 | def body_fn(fn, layer_inputs):
923 | return fn(layer_inputs, *args, **kwargs), None
924 |
925 | rematted_body_fn = nn.remat(
926 | body_fn,
927 | prevent_cse=False,
928 | policy=getattr(jax.checkpoint_policies, self.checkpoint_policy, None),
929 | )
930 | scan_fn = nn.scan(
931 | rematted_body_fn,
932 | variable_axes={'params': 0},
933 | split_rngs={'params': True, 'dropout': True},
934 | length=self.times,
935 | )
936 | outputs, _ = scan_fn(main_fn, inputs)
937 | return outputs
938 |
939 |
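# Note on the scanned parameter layout (illustrative, based on the `nn.scan`
# call above): `variable_axes={'params': 0}` creates the block's parameters
# once with a leading axis of length `times` (e.g. a [D, H] kernel is stored
# as [times, D, H]) instead of `times` separate subtrees, and `split_rngs`
# gives each repetition its own 'params'/'dropout' streams. The result matches
# a Python loop over `times` applications of `block_fn`, with rematerialized
# activations and a smaller unrolled graph to compile.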
940 | class StackedTransformer(Module):
941 | """A stack of Transformer layers.
942 |
943 | Attributes:
944 | num_layers: Number of layers in this stack.
945 | hidden_dim: The hidden layer dimension of FFN in Transformer layers.
946 | num_heads: Number of attention heads.
947 | dim_per_head: Dimension of each attention head. If None then dim_per_head ==
948 |       the input feature dimension // num_heads.
949 |     dropout_prob: Default dropout probability applied at various places when
950 |       the more specific probabilities below are None.
950 | atten_dropout_prob: Probability at which we apply dropout to the attention
951 | weights.
952 | residual_dropout_prob: Probability at which we apply dropout to the residual
953 | layers, such that, residual(x, y) = (x + dropout(y)).
954 | relu_dropout_prob: Probability at which we apply dropout to the FFN layers.
955 | input_dropout_prob: Dropout probability applied to the input before any
956 | processing happens.
957 | norm_policy: Policy for applying normalization wrt. transformations. Options
958 | are: (1) "pre", applied before transformation. (2) "primer_hybrid",
959 | applied before and after transformation. (3) "post", applied after
960 | transformation. (4) "post_skip", applied after the skip connection.
961 | use_bias: Whether to use bias.
962 | activation_fn: Activation function to use.
963 | internal_enable_per_dim_scale: Internal. Setting to False disables rescaling
964 | of attention logits with 1/sqrt(dim) factor.
965 | atten_logit_cap: Cap the absolute values of logits by tanh. Enabled when a
966 | positive value is specified. May not be supported by a subclass.
967 | enable_causal_atten: Whether to enable causal attention.
968 |     scan: Whether to use `nn.remat` and `nn.scan`.
969 | """
970 |
971 | num_layers: int = 0
972 | hidden_dim: int = 0
973 | num_heads: int = 0
974 | dim_per_head: int | None = None
975 | dropout_prob: float = 0.0
976 | atten_dropout_prob: float | None = None
977 | residual_dropout_prob: float | None = None
978 | relu_dropout_prob: float | None = None
979 | input_dropout_prob: float = 0.0
980 | norm_policy: str = 'pre'
981 | use_bias: bool = True
982 | activation_fn: ActivationFunc = nn.relu
983 | internal_enable_per_dim_scale: bool = True
984 | atten_logit_cap: float = 0.0
985 | enable_causal_atten: bool = False
986 | scan: bool = False
987 |
988 | @nn.compact
989 | def __call__(
990 | self,
991 | inputs: Array,
992 | paddings: Array,
993 | train: bool,
994 | ) -> Array:
995 | """Stacked Transformer layer.
996 |
997 | Args:
998 | inputs: Input sequence of shape [B, T, H].
999 | paddings: Input paddings of shape [B, T].
1000 | train: If the model is in the train mode.
1001 |
1002 | Returns:
1003 | Output vector with shape [B, T, D].
1004 | """
1005 |
1006 | atten_mask = compute_attention_masks_for_fprop(
1007 | inputs, paddings, causal_attention=self.enable_causal_atten
1008 | )
1009 |
1010 | outputs = inputs
1011 | if self.input_dropout_prob > 0.0:
1012 | outputs = nn.Dropout(self.input_dropout_prob, name='input_dropout')(
1013 | outputs, deterministic=not train
1014 | )
1015 |
1016 | transformer_kwargs = dict(
1017 | num_heads=self.num_heads,
1018 | dim_per_head=self.dim_per_head,
1019 | hidden_dim=self.hidden_dim,
1020 | atten_dropout_prob=self.atten_dropout_prob or self.dropout_prob,
1021 | residual_dropout_prob=self.residual_dropout_prob or self.dropout_prob,
1022 | relu_dropout_prob=self.relu_dropout_prob or self.dropout_prob,
1023 | norm_policy=self.norm_policy,
1024 | use_bias=self.use_bias,
1025 | activation_fn=self.activation_fn,
1026 | internal_enable_per_dim_scale=self.internal_enable_per_dim_scale,
1027 | atten_logit_cap=self.atten_logit_cap,
1028 | dtype=self.dtype,
1029 | fprop_dtype=self.fprop_dtype,
1030 | )
1031 | if self.scan:
1032 | block_fn = Transformer(name='x_layers', **transformer_kwargs)
1033 | outputs = Repeat(block_fn=block_fn, times=self.num_layers)(
1034 | outputs, paddings, atten_mask, train
1035 | )
1036 | else:
1037 | for i in range(self.num_layers):
1038 | outputs = Transformer(name=f'x_layers_{i}', **transformer_kwargs)(
1039 | outputs, paddings, atten_mask, train
1040 | )
1041 | return outputs
1042 |
1043 |
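# Usage sketch (illustrative only). Note the `x or y` fallbacks above: each
# per-kind dropout probability defaults to `dropout_prob` when it is None, and
# an explicit 0.0 is also treated as unset, since 0.0 is falsy in Python.
def _example_stacked_transformer():
  stack = StackedTransformer(
      num_layers=2, hidden_dim=32, num_heads=2, scan=True)
  inputs = jnp.zeros([1, 5, 16])  # [B, T, H].
  paddings = jnp.zeros([1, 5])    # 0 == real token.
  params = stack.init(jax.random.PRNGKey(0), inputs, paddings, False)
  return stack.apply(params, inputs, paddings, False)  # [1, 5, 16].
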
1044 | class AttenTokenPoolingLayer(Module):
1045 | """Attentional token pooling layer.
1046 |
1047 | Attributes:
1048 | query_dim: The query dimension of attention. If None then query_dim ==
1049 | input_dim.
1050 | hidden_dim: The hidden layer dimension of FFN in Transformer layers.
1051 | num_heads: Number of attention heads.
1052 | num_queries: Number of attention queries.
1053 | add_layer_norm: Whether to apply layer norm to the pooled tokens.
1054 | dropout_prob: The probability of dropout on the pooled tokens.
1055 | use_qk_norm: If QK norm is used.
1056 | use_bias: Whether to use bias.
1057 | internal_enable_per_dim_scale: Internal. Setting to False disables rescaling
1058 | of attention logits with 1/sqrt(dim) factor.
1059 | """
1060 |
1061 | query_dim: int | None = None
1062 | hidden_dim: int = 0
1063 | num_heads: int = 1
1064 | num_queries: int = 1
1065 | add_layer_norm: bool = True
1066 | dropout_prob: float = 0.0
1067 | use_qk_norm: bool = False
1068 | use_bias: bool = True
1069 | internal_enable_per_dim_scale: bool = True
1070 |
1071 | @nn.compact
1072 | def __call__(
1073 | self,
1074 | tokens: Array,
1075 | paddings: Array | None,
1076 | train: bool,
1077 | ) -> Array:
1078 | """Computes the pooled tokens for inputs.
1079 |
1080 | Args:
1081 | tokens: Input tokens of shape [B, T, H].
1082 | paddings: Input paddings of shape [B, T].
1083 | train: If the model is in the train mode.
1084 |
1085 | Returns:
1086 | Output vector with shape [B, N, D].
1087 | """
1088 | input_dim = tokens.shape[-1]
1089 | query_dim = self.query_dim or input_dim
1090 | hidden_dim = self.hidden_dim if self.hidden_dim > 0 else 4 * input_dim
1091 | batch_size, seq_length = tokens.shape[0], tokens.shape[-2]
1092 |
1093 | query = self._cast_to_fprop_dtype(
1094 | self.param(
1095 | 'pooling_attention_query',
1096 | default_kernel_init,
1097 | [self.num_queries, query_dim],
1098 | self.dtype,
1099 | )
1100 | )
1101 | query = jnp.tile(query[jnp.newaxis, :, :], [batch_size, 1, 1])
1102 |
1103 | if paddings is None:
1104 | paddings = jnp.zeros([batch_size, seq_length], dtype=tokens.dtype)
1105 |
1106 | atten_mask = _convert_paddings_to_mask(paddings, dtype=paddings.dtype)
1107 | outputs, _ = DotProductAttention(
1108 | name='pooling_attention',
1109 | hidden_dim=hidden_dim,
1110 | num_heads=self.num_heads,
1111 | use_bias=self.use_bias,
1112 | internal_enable_per_dim_scale=self.internal_enable_per_dim_scale,
1113 | use_qk_norm=self.use_qk_norm,
1114 | dtype=self.dtype,
1115 | fprop_dtype=self.fprop_dtype,
1116 | )(
1117 | query,
1118 | tokens,
1119 | tokens,
1120 | atten_mask=atten_mask,
1121 | train=train,
1122 | )
1123 |
1124 | if self.add_layer_norm:
1125 | outputs = LayerNorm(
1126 | name='pooling_attention_layer_norm',
1127 | dtype=self.dtype,
1128 | fprop_dtype=self.fprop_dtype,
1129 | )(outputs)
1130 |
1131 | if self.dropout_prob > 0.0:
1132 | outputs = nn.Dropout(self.dropout_prob, name='attention_dropout')(
1133 | outputs, deterministic=not train
1134 | )
1135 |
1136 | return outputs
1137 |
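# Usage sketch (illustrative only): a learned [num_queries, D] query
# cross-attends over the tokens, pooling [B, T, H] down to
# [B, num_queries, H]; passing paddings=None treats every token as valid.
def _example_atten_token_pooling():
  pooler = AttenTokenPoolingLayer(num_heads=2, num_queries=1)
  tokens = jnp.zeros([2, 7, 16])  # [B, T, H].
  params = pooler.init(jax.random.PRNGKey(0), tokens, None, False)
  return pooler.apply(params, tokens, None, False)  # [2, 1, 16].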
--------------------------------------------------------------------------------