├── videoprism ├── __init__.py ├── assets │ ├── testdata │ │ └── test_spm.model │ └── water_bottle_drumming.mp4 ├── utils_test.py ├── tokenizers_test.py ├── models_test.py ├── utils.py ├── tokenizers.py ├── colabs │ ├── videoprism_video_encoder_demo.ipynb │ └── videoprism_video_text_demo.ipynb ├── layers_test.py ├── encoders_test.py ├── models.py ├── encoders.py └── layers.py ├── requirements.txt ├── CONTRIBUTING.md ├── setup.py ├── README.md └── LICENSE /videoprism/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.26 2 | absl-py 3 | einops 4 | einshape 5 | flax 6 | huggingface-hub 7 | sentencepiece 8 | tensorflow-cpu -------------------------------------------------------------------------------- /videoprism/assets/testdata/test_spm.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/videoprism/HEAD/videoprism/assets/testdata/test_spm.model -------------------------------------------------------------------------------- /videoprism/assets/water_bottle_drumming.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/videoprism/HEAD/videoprism/assets/water_bottle_drumming.mp4 -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | ## Contributor License Agreement 4 | 5 | Contributions to this project must be accompanied by a Contributor License 6 | Agreement. You (or your employer) retain the copyright to your contribution, 7 | this simply gives us permission to use and redistribute your contributions as 8 | part of the project. Head over to to see 9 | your current agreements on file or to sign a new one. 10 | 11 | You generally only need to submit a CLA once, so if you've already submitted one 12 | (even if it was for a different project), you probably don't need to do it 13 | again. 14 | 15 | ## Code reviews 16 | 17 | All submissions, including submissions by project members, require review. We 18 | use GitHub pull requests for this purpose. Consult 19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 20 | information on using pull requests. 21 | 22 | ## Community Guidelines 23 | 24 | This project follows [Google's Open Source Community 25 | Guidelines](https://opensource.google/conduct/). -------------------------------------------------------------------------------- /videoprism/utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 VideoPrism Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for utility functions.""" 16 | 17 | from absl.testing import absltest 18 | from videoprism import utils 19 | 20 | 21 | class UtilsTest(absltest.TestCase): 22 | 23 | def test_canonicalize_text(self): 24 | self.assertEqual(utils.canonicalize_text("Hello, World!"), "hello world.") 25 | self.assertEqual(utils.canonicalize_text("Hello,World.."), "hello world.") 26 | self.assertEqual(utils.canonicalize_text(" Hello WORLD"), "hello world.") 27 | 28 | 29 | if __name__ == "__main__": 30 | absltest.main() 31 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 VideoPrism Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """setup.py for VideoPrism. 16 | 17 | Install for development: 18 | 19 | pip intall -e . .[testing] 20 | """ 21 | 22 | import setuptools 23 | 24 | # Get install requirements from the REQUIREMENTS file. 25 | with open("requirements.txt") as fp: 26 | install_requires_core = fp.read().splitlines() 27 | 28 | tests_require = [ 29 | "chex", 30 | "pytest", 31 | ] 32 | 33 | setuptools.setup( 34 | name="videoprism", 35 | version="1.0.0", 36 | description=( 37 | "VideoPrism: A Foundational Visual Encoder for Video Understanding." 38 | ), 39 | author="VideoPrism Authors", 40 | author_email="no-reply@google.com", 41 | long_description=open("README.md").read(), 42 | long_description_content_type="text/markdown", 43 | url="https://github.com/google-deepmind/videoprism", 44 | license="Apache 2.0", 45 | packages=setuptools.find_packages(), 46 | include_package_data=True, 47 | install_requires=install_requires_core, 48 | tests_require=tests_require, 49 | extras_require={ 50 | "testing": tests_require, 51 | }, 52 | classifiers=[ 53 | "Development Status :: 1 - Beta", 54 | "Intended Audience :: Developers", 55 | "Intended Audience :: VideoPrism/Research", 56 | "License :: OSI Approved :: Apache Software License", 57 | "Programming Language :: Python", 58 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 59 | ], 60 | keywords="VideoPrism", 61 | ) 62 | -------------------------------------------------------------------------------- /videoprism/tokenizers_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 VideoPrism Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for text tokenizers.""" 16 | 17 | from absl.testing import parameterized 18 | import numpy as np 19 | import tensorflow as tf 20 | from videoprism import tokenizers 21 | 22 | 23 | class PyToTfWrapper: 24 | """Allows to use `to_int_tf_op()` via `to_int()`.""" 25 | 26 | def __init__(self, model): 27 | self.model = model 28 | self.bos_token = model.bos_token 29 | self.eos_token = model.eos_token 30 | self.vocab_size = model.vocab_size 31 | 32 | def to_int(self, text, *, bos=False, eos=False): 33 | ret = self.model.to_int_tf_op(text, bos=bos, eos=eos) 34 | if isinstance(ret, tf.RaggedTensor): 35 | return [t.numpy().tolist() for t in ret] 36 | return ret.numpy().tolist() 37 | 38 | 39 | class TokenizersTest(tf.test.TestCase, parameterized.TestCase): 40 | 41 | def setUp(self): 42 | import os 43 | self.spm_path = os.path.join( 44 | os.path.dirname(__file__), "assets", "testdata", "test_spm.model" 45 | ) 46 | super().setUp() 47 | 48 | @parameterized.named_parameters( 49 | ("py", False), 50 | ("tf", True), 51 | ) 52 | def test_sentencepiece_tokenizer(self, wrap_model): 53 | model = tokenizers.SentencePieceTokenizer(self.spm_path) 54 | if wrap_model: 55 | model = PyToTfWrapper(model) 56 | self.assertEqual(model.vocab_size, 1000) 57 | bos, eos = model.bos_token, model.eos_token 58 | self.assertEqual(bos, 1) 59 | self.assertEqual(eos, 2) 60 | self.assertEqual(model.to_int("blah"), [80, 180, 60]) 61 | self.assertEqual(model.to_int("blah", bos=True), [bos, 80, 180, 60]) 62 | self.assertEqual(model.to_int("blah", eos=True), [80, 180, 60, eos]) 63 | self.assertEqual( 64 | model.to_int("blah", bos=True, eos=True), [bos, 80, 180, 60, eos] 65 | ) 66 | self.assertEqual( 67 | model.to_int(["blah", "blah blah"]), 68 | [[80, 180, 60], [80, 180, 60, 80, 180, 60]], 69 | ) 70 | 71 | def test_sentencepiece_tokenizer_tf_data(self): 72 | model = tokenizers.SentencePieceTokenizer(self.spm_path) 73 | 74 | def gen(): 75 | yield tf.convert_to_tensor(["blah"]) 76 | yield tf.convert_to_tensor(["blah", "blah blah"]) 77 | 78 | ds = tf.data.Dataset.from_generator(gen, tf.string, tf.TensorShape([None])) 79 | ds = ds.map(model.to_int_tf_op) 80 | res = [ 81 | [b.tolist() if isinstance(b, np.ndarray) else b for b in a.tolist()] 82 | for a in ds.as_numpy_iterator() 83 | ] 84 | print(res) 85 | self.assertAllEqual( 86 | res, [[[80, 180, 60]], [[80, 180, 60], [80, 180, 60, 80, 180, 60]]] 87 | ) 88 | 89 | 90 | if __name__ == "__main__": 91 | tf.test.main() 92 | -------------------------------------------------------------------------------- /videoprism/models_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 VideoPrism Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for VideoPrism models.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import jax 20 | from jax import numpy as jnp 21 | import numpy as np 22 | from videoprism import models 23 | from videoprism import tokenizers 24 | 25 | 26 | class ModelsTest(parameterized.TestCase): 27 | 28 | @parameterized.parameters( 29 | ('videoprism_public_v1_base', True), 30 | ('videoprism_public_v1_large', True), 31 | ('videoprism_public_v1_giant', False), 32 | ) 33 | def test_has_model(self, model_name, exists): 34 | self.assertEqual(models.has_model(model_name), exists) 35 | 36 | @parameterized.parameters(8, 16) 37 | def test_videoprism(self, num_frames): 38 | batch_size = 1 39 | np_inputs = np.random.normal( 40 | 0.0, 0.1, [batch_size, num_frames, 288, 288, 3] 41 | ).astype('float32') 42 | inputs = jnp.asarray(np_inputs) 43 | prng_key = jax.random.PRNGKey(seed=123) 44 | 45 | mdl = models.videoprism_v1_base() 46 | mdl_params = mdl.init(prng_key, inputs, train=False) 47 | 48 | @jax.jit 49 | def forward_fn(mdl_inputs): 50 | return mdl.apply(mdl_params, mdl_inputs, train=False) 51 | 52 | embeddings, _ = forward_fn(inputs) 53 | self.assertEqual(embeddings.shape, (batch_size, num_frames * 16**2, 768)) 54 | 55 | def test_videoprism_lvt(self): 56 | batch_size, num_frames = 1, 16 57 | np_inputs = np.random.normal( 58 | 0.0, 0.1, [batch_size, num_frames, 288, 288, 3] 59 | ).astype('float32') 60 | inputs = jnp.asarray(np_inputs) 61 | np_text_token_ids = np.random.randint( 62 | 0, 32_000, [batch_size, models.TEXT_MAX_LEN] 63 | ).astype('int32') 64 | text_token_ids = jnp.asarray(np_text_token_ids) 65 | np_text_paddings = np.zeros( 66 | [batch_size, models.TEXT_MAX_LEN], dtype='float32' 67 | ) 68 | np_text_paddings[:, models.TEXT_MAX_LEN // 2 :] = 1 69 | text_paddings = jnp.asarray(np_text_paddings) 70 | prng_key = jax.random.PRNGKey(seed=123) 71 | 72 | mdl = models.videoprism_lvt_v1_base() 73 | mdl_params = mdl.init( 74 | prng_key, inputs, text_token_ids, text_paddings, train=False 75 | ) 76 | 77 | @jax.jit 78 | def forward_fn(mdl_inputs, mdl_text_token_ids, mdl_text_paddings): 79 | return mdl.apply( 80 | mdl_params, 81 | mdl_inputs, 82 | mdl_text_token_ids, 83 | mdl_text_paddings, 84 | train=False, 85 | ) 86 | 87 | vision_embeddings, text_embeddings, _ = forward_fn( 88 | inputs, text_token_ids, text_paddings 89 | ) 90 | self.assertEqual(vision_embeddings.shape, (batch_size, 768)) 91 | self.assertEqual(text_embeddings.shape, (batch_size, 768)) 92 | 93 | def test_tokenize_texts(self): 94 | import os 95 | spm_path = os.path.join( 96 | os.path.dirname(__file__), 'assets', 'testdata', 'test_spm.model' 97 | ) 98 | model = tokenizers.SentencePieceTokenizer(spm_path) 99 | ids, paddings = models.tokenize_texts( 100 | model, 101 | ['blah', 'blah blah', 'blah blah blah'], 102 | max_length=6, 103 | add_bos=False, 104 | canonicalize=False, 105 | ) 106 | np.testing.assert_array_equal( 107 | ids, 108 | [ 109 | [80, 180, 60, 0, 0, 0], 110 | [80, 180, 60, 80, 180, 60], 111 | [80, 180, 60, 80, 180, 60], 112 | ], 113 | ) 114 | np.testing.assert_array_equal( 115 | paddings, [[0, 0, 0, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]] 116 | ) 117 | 118 | 119 | if __name__ == '__main__': 120 | absltest.main() 121 | -------------------------------------------------------------------------------- /videoprism/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 VideoPrism Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utility functions for checkpointing and other purposes.""" 16 | 17 | import collections 18 | from collections.abc import Mapping, Sequence 19 | import io 20 | import os 21 | import string 22 | 23 | import jax 24 | import numpy as np 25 | from tensorflow.io import gfile 26 | 27 | 28 | def traverse_with_names(tree, with_inner_nodes=False): 29 | """Traverses nested dicts and emits (leaf_name, leaf_val). 30 | 31 | Args: 32 | tree: JAX Pytree object. 33 | with_inner_nodes: Whether to traverse the non-leaf nodes. 34 | 35 | Yields: 36 | A pair of (leaf_name, leaf_val). 37 | """ 38 | # Don't output the non-leaf nodes. If the optimizer doesn't have a state 39 | # the tree leaves can be Nones which was interpreted as a leaf by this 40 | # function but not by the other functions (like jax.tree.map). 41 | if tree is None: 42 | return 43 | elif isinstance(tree, Mapping): 44 | keys = sorted(tree.keys()) 45 | for key in keys: 46 | for path, v in traverse_with_names(tree[key], with_inner_nodes): 47 | yield (key + "/" + path).rstrip("/"), v 48 | if with_inner_nodes: 49 | yield "", tree 50 | elif isinstance(tree, Sequence): 51 | for idx in range(len(tree)): 52 | for path, v in traverse_with_names(tree[idx], with_inner_nodes): 53 | yield (str(idx) + "/" + path).rstrip("/"), v 54 | if with_inner_nodes: 55 | yield "", tree 56 | else: 57 | yield "", tree 58 | 59 | 60 | def tree_flatten_with_names(tree): 61 | """Populates tree_flatten with leaf names. 62 | 63 | Args: 64 | tree: JAX Pytree object. 65 | 66 | Returns: 67 | A list of values with names: [(name, value), ...] 68 | """ 69 | vals, tree_def = jax.tree.flatten(tree) 70 | 71 | tokens = range(len(vals)) 72 | token_tree = tree_def.unflatten(tokens) 73 | val_names, perm = zip(*traverse_with_names(token_tree)) 74 | inv_perm = np.argsort(perm) 75 | 76 | # Custom traversal should visit the same number of leaves. 77 | assert len(val_names) == len(vals) 78 | 79 | return [(val_names[i], v) for i, v in zip(inv_perm, vals)] 80 | 81 | 82 | def recover_tree(keys, values): 83 | """Recovers a tree as a nested dict from flat names and values. 84 | 85 | Args: 86 | keys: A list of keys, where '/' is used as separator between nodes. 87 | values: A list of leaf values. 88 | 89 | Returns: 90 | A nested tree-like dict.
91 | """ 92 | tree = {} 93 | sub_trees = collections.defaultdict(list) 94 | for k, v in zip(keys, values): 95 | if "/" not in k: 96 | tree[k] = v 97 | else: 98 | k_left, k_right = k.split("/", 1) 99 | sub_trees[k_left].append((k_right, v)) 100 | for k, kv_pairs in sub_trees.items(): 101 | k_subtree, v_subtree = zip(*kv_pairs) 102 | tree[k] = recover_tree(k_subtree, v_subtree) 103 | return tree 104 | 105 | 106 | def npload(fname): 107 | """Loads `fname` and returns an np.ndarray or dict thereof.""" 108 | # Load the data; use local paths directly if possible: 109 | if os.path.exists(fname): 110 | loaded = np.load(fname, allow_pickle=False) 111 | else: 112 | # For other (remote) paths go via gfile+BytesIO as np.load requires seeks. 113 | with gfile.GFile(fname, "rb") as f: 114 | data = f.read() 115 | loaded = np.load(io.BytesIO(data), allow_pickle=False) 116 | 117 | # Support loading both single-array files (np.save) and zips (np.savez). 118 | if isinstance(loaded, np.ndarray): 119 | return loaded 120 | else: 121 | return dict(loaded) 122 | 123 | 124 | def load_checkpoint(npz): 125 | """Loads a jax Pytree from a npz file. 126 | 127 | Args: 128 | npz: Either path to the checkpoint file (.npz), or a dict-like. 129 | 130 | Returns: 131 | A Pytree that is the checkpoint. 132 | """ 133 | if isinstance(npz, str): # If not already loaded, then load. 134 | npz = npload(npz) 135 | keys, values = zip(*list(npz.items())) 136 | return recover_tree(keys, values) 137 | 138 | 139 | def canonicalize_text(text: str) -> str: 140 | """Canonicalizes text. 141 | 142 | Canonicalization includes: 143 | - Replace all punctuation with a whitespace. 144 | - Use all lower case. 145 | - Leave only one whitespace between words. 146 | - End with a period. 147 | 148 | Examples: 149 | "Hello, World!" -> "hello world." 150 | "Hello,World.." -> "hello world." 151 | " Hello WORLD" -> "hello world." 152 | 153 | Args: 154 | text: A string for the input text. 155 | 156 | Returns: 157 | A string for the canonicalized text. 158 | """ 159 | # Replace all punctuation with a whitespace. 160 | p = string.punctuation 161 | text = text.translate(str.maketrans(p, " " * len(p))) 162 | # Use all lower case. 163 | text = text.lower() 164 | # Leave only one whitespace between words. 165 | text = " ".join(text.split()) 166 | # End with a period. 167 | text = text + "." 168 | return text 169 | -------------------------------------------------------------------------------- /videoprism/tokenizers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 VideoPrism Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tokenizers for text encoders.""" 16 | 17 | from collections.abc import Sequence 18 | from typing import Protocol 19 | 20 | import tensorflow as tf 21 | from tensorflow.io import gfile 22 | 23 | import sentencepiece 24 | 25 | SentencePieceProcessor = sentencepiece.SentencePieceProcessor 26 | 27 | 28 | class Tokenizer(Protocol): 29 | """Tokenizer interface.""" 30 | 31 | def to_int( 32 | self, text: str | Sequence[str], *, bos: bool = False, eos: bool = False 33 | ) -> list[int] | list[list[int]]: 34 | """Tokenizes `text` into a list of integer tokens. 35 | 36 | Args: 37 | text: can be a single string, or a list of strings. 38 | bos: Whether a beginning-of-sentence token should be prepended. 39 | eos: Whether an end-of-sentence token should be appended. 40 | 41 | Returns: 42 | A list or list-of-list of tokens. 43 | """ 44 | 45 | def to_int_tf_op( 46 | self, text: str | Sequence[str], *, bos: bool = False, eos: bool = False 47 | ) -> tf.Tensor | tf.RaggedTensor: 48 | """Same as `to_int()`, but as TF ops to be used in data pipelines. 49 | 50 | Args: 51 | text: can be a single string, or a list of strings. 52 | bos: Whether a beginning-of-sentence token should be prepended. 53 | eos: Whether an end-of-sentence token should be appended. 54 | 55 | Returns: 56 | A tf.Tensor of tokens. 57 | """ 58 | 59 | @property 60 | def pad_token(self) -> int: 61 | """Token id of padding token.""" 62 | 63 | @property 64 | def eos_token(self) -> int: 65 | """Token id of end-of-sentence token.""" 66 | 67 | @property 68 | def bos_token(self) -> int: 69 | """Token id of beginning-of-sentence token.""" 70 | 71 | @property 72 | def vocab_size(self) -> int: 73 | """Returns the size of the vocabulary.""" 74 | 75 | 76 | class SentencePieceTokenizer(Tokenizer): 77 | """Wraps a SentencePiece model for tokenization.""" 78 | 79 | def __init__(self, model_path): 80 | """Initializes the tokenizer. 81 | 82 | Args: 83 | model_path: A path to load the SentencePiece model. 84 | """ 85 | with gfile.GFile(model_path, "rb") as f: 86 | model_bytes = f.read() 87 | 88 | self._model = SentencePieceProcessor() 89 | self._model.LoadFromSerializedProto(model_bytes) 90 | 91 | def to_int( 92 | self, text: str | Sequence[str], *, bos: bool = False, eos: bool = False 93 | ) -> list[int] | list[list[int]]: 94 | """Tokenizes `text` into a list of integer tokens. 95 | 96 | Args: 97 | text: can be a single string, or a list of strings. 98 | bos: Whether a beginning-of-sentence token should be prepended. 99 | eos: Whether an end-of-sentence token should be appended. 100 | 101 | Returns: 102 | A list or list-of-list of tokens. 103 | """ 104 | 105 | def _single(s: str) -> list[int]: 106 | return ( 107 | ([self.bos_token] if bos else []) 108 | + self._model.EncodeAsIds(s) 109 | + ([self.eos_token] if eos else []) 110 | ) 111 | 112 | if isinstance(text, str): 113 | return _single(text) 114 | return list([_single(s) for s in text]) 115 | 116 | def to_int_tf_op( 117 | self, text: str | Sequence[str], *, bos: bool = False, eos: bool = False 118 | ) -> tf.Tensor | tf.RaggedTensor: 119 | """Same as `to_int()`, but as TF ops to be used in data pipelines. 120 | 121 | Args: 122 | text: can be a single string, or a list of strings. 123 | bos: Whether a beginning-of-sentence token should be prepended. 124 | eos: Whether an end-of-sentence token should be appended. 125 | 126 | Returns: 127 | A tf.Tensor or tf.RaggedTensor of tokens. 
128 | """ 129 | text = tf.convert_to_tensor(text) 130 | if text.ndim == 0: 131 | 132 | def fn(txt): 133 | """Tokenizes a single string.""" 134 | s = txt.numpy().decode() 135 | return tf.constant(self.to_int(s, bos=bos, eos=eos), tf.int32) 136 | 137 | return tf.py_function(fn, [text], tf.int32) 138 | else: 139 | 140 | def fn(txt): 141 | """Tokenizes a list of strings.""" 142 | strings = [s.decode() for s in txt.numpy().tolist()] 143 | toks = self.to_int(strings, bos=bos, eos=eos) 144 | return tf.ragged.constant(toks) 145 | 146 | out_type = tf.RaggedTensorSpec([text.shape[0], None], tf.int32) 147 | return tf.py_function(fn, [text], Tout=out_type) 148 | 149 | @property 150 | def pad_token(self) -> int: 151 | """Token id of padding token.""" 152 | return self._model.pad_id() 153 | 154 | @property 155 | def eos_token(self) -> int: 156 | """Token id of end-of-sentence token.""" 157 | return self._model.eos_id() 158 | 159 | @property 160 | def bos_token(self) -> int: 161 | """Token id of beginning-of-sentence token.""" 162 | return self._model.bos_id() 163 | 164 | @property 165 | def vocab_size(self) -> int: 166 | """Returns the size of the vocabulary.""" 167 | return self._model.GetPieceSize() 168 | -------------------------------------------------------------------------------- /videoprism/colabs/videoprism_video_encoder_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "metadata": { 5 | "id": "KPPUiCpSbm53" 6 | }, 7 | "cell_type": "markdown", 8 | "source": [ 9 | "# VideoPrism Video Encoder Demo\n", 10 | "\n", 11 | "[![Paper](https://img.shields.io/badge/arXiv-2402.13217-red.svg)](https://arxiv.org/abs/2402.13217)\n", 12 | "[![Blog](https://img.shields.io/badge/Google_Research-Blog-green.svg)](https://research.google/blog/videoprism-a-foundational-visual-encoder-for-video-understanding/)\n", 13 | "[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)\n", 14 | "\n", 15 | "This notebook provides an example of video feature extraction with a pre-trained VideoPrism video encoder.\n", 16 | "\n", 17 | "Please run this demo on Google Colab with (faster) or without TPU." 
18 | ] 19 | }, 20 | { 21 | "metadata": { 22 | "id": "k08qFZ9-cn9v" 23 | }, 24 | "cell_type": "markdown", 25 | "source": [ 26 | "## Set up" 27 | ] 28 | }, 29 | { 30 | "metadata": { 31 | "id": "1dfyX8EyVsvL" 32 | }, 33 | "cell_type": "code", 34 | "source": [ 35 | "# @title Prepare environment\n", 36 | "\n", 37 | "import os\n", 38 | "import sys\n", 39 | "\n", 40 | "# Fetch VideoPrism repository if Python does not know about it and install\n", 41 | "# dependencies needed for this notebook.\n", 42 | "if not os.path.exists(\"videoprism_repo\"):\n", 43 | " !git clone --quiet --branch=main --depth=1 \\\n", 44 | " https://github.com/google-deepmind/videoprism.git videoprism_repo\n", 45 | " os.chdir('./videoprism_repo')\n", 46 | " !pip install .\n", 47 | " os.chdir('..')\n", 48 | "\n", 49 | "# Append VideoPrism code to Python import path.\n", 50 | "if \"videoprism_repo\" not in sys.path:\n", 51 | " sys.path.append(\"videoprism_repo\")\n", 52 | "\n", 53 | "# Install missing dependencies.\n", 54 | "!pip install mediapy\n", 55 | "\n", 56 | "import jax\n", 57 | "from jax.extend import backend\n", 58 | "import tensorflow as tf\n", 59 | "\n", 60 | "# Do not let TF use the GPU or TPUs.\n", 61 | "tf.config.set_visible_devices([], \"GPU\")\n", 62 | "tf.config.set_visible_devices([], \"TPU\")\n", 63 | "\n", 64 | "print(f\"JAX version: {jax.__version__}\")\n", 65 | "print(f\"JAX platform: {backend.get_backend().platform}\")\n", 66 | "print(f\"JAX devices: {jax.device_count()}\")" 67 | ], 68 | "outputs": [], 69 | "execution_count": null 70 | }, 71 | { 72 | "metadata": { 73 | "id": "zByA1K0IVKAI" 74 | }, 75 | "cell_type": "code", 76 | "source": [ 77 | "# @title Load dependencies and define utilities\n", 78 | "\n", 79 | "import mediapy\n", 80 | "import numpy as np\n", 81 | "from PIL import Image\n", 82 | "\n", 83 | "\n", 84 | "def read_and_preprocess_video(\n", 85 | " filename: str, target_num_frames: int, target_frame_size: tuple[int, int]\n", 86 | "):\n", 87 | " \"\"\"Reads and preprocesses a video.\"\"\"\n", 88 | "\n", 89 | " frames = mediapy.read_video(filename)\n", 90 | "\n", 91 | " # Sample to target number of frames.\n", 92 | " frame_indices = np.linspace(\n", 93 | " 0, len(frames), num=target_num_frames, endpoint=False, dtype=np.int32\n", 94 | " )\n", 95 | " frames = np.asarray([frames[i] for i in frame_indices])\n", 96 | "\n", 97 | " # Resize to target size.\n", 98 | " original_height, original_width = frames.shape[-3:-1]\n", 99 | " target_height, target_width = target_frame_size\n", 100 | " assert (\n", 101 | " original_height * target_width == original_width * target_height\n", 102 | " ), 'Currently does not support aspect ratio mismatch.'\n", 103 | " frames = mediapy.resize_video(frames, shape=target_frame_size)\n", 104 | "\n", 105 | " # Normalize pixel values to [0.0, 1.0].\n", 106 | " frames = mediapy.to_float01(frames)\n", 107 | "\n", 108 | " return frames" 109 | ], 110 | "outputs": [], 111 | "execution_count": null 112 | }, 113 | { 114 | "metadata": { 115 | "id": "WnYuzSgrXCL1" 116 | }, 117 | "cell_type": "code", 118 | "source": [ 119 | "# @title Load model\n", 120 | "\n", 121 | "import jax\n", 122 | "import jax.numpy as jnp\n", 123 | "from videoprism import models as vp\n", 124 | "\n", 125 | "MODEL_NAME = 'videoprism_public_v1_base' # @param ['videoprism_public_v1_base', 'videoprism_public_v1_large'] {allow-input: false}\n", 126 | "USE_BFLOAT16 = False # @param { type: \"boolean\" }\n", 127 | "NUM_FRAMES = 16\n", 128 | "FRAME_SIZE = 288\n", 129 | "\n", 130 | "fprop_dtype = jnp.bfloat16 if 
USE_BFLOAT16 else None\n", 131 | "flax_model = vp.get_model(MODEL_NAME, fprop_dtype=fprop_dtype)\n", 132 | "loaded_state = vp.load_pretrained_weights(MODEL_NAME)\n", 133 | "\n", 134 | "\n", 135 | "@jax.jit\n", 136 | "def forward_fn(inputs, train=False):\n", 137 | " return flax_model.apply(loaded_state, inputs, train=train)" 138 | ], 139 | "outputs": [], 140 | "execution_count": null 141 | }, 142 | { 143 | "metadata": { 144 | "id": "AliScLC0jo1s" 145 | }, 146 | "cell_type": "markdown", 147 | "source": [ 148 | "# Example: Video feature extraction\n", 149 | "\n", 150 | "In this example, we extract the spatiotemporal embeddings of an example video." 151 | ] 152 | }, 153 | { 154 | "metadata": { 155 | "id": "kLzkhP8CYUYj" 156 | }, 157 | "cell_type": "code", 158 | "source": [ 159 | "VIDEO_FILE_PATH = 'videoprism_repo/videoprism/assets/water_bottle_drumming.mp4' # @param {type: \"string\"}\n", 160 | "\n", 161 | "frames = read_and_preprocess_video(\n", 162 | " VIDEO_FILE_PATH,\n", 163 | " target_num_frames=NUM_FRAMES,\n", 164 | " target_frame_size=[FRAME_SIZE, FRAME_SIZE],\n", 165 | ")\n", 166 | "mediapy.show_video(frames, fps=6.0)\n", 167 | "\n", 168 | "frames = jnp.asarray(frames[None, ...]) # Add batch dimension.\n", 169 | "if USE_BFLOAT16:\n", 170 | " frames = frames.astype(jnp.bfloat16)\n", 171 | "print(f'Input shape: {frames.shape} [type: {frames.dtype}]')\n", 172 | "\n", 173 | "embeddings, _ = forward_fn(frames)\n", 174 | "print(f'Encoded embedding shape: {embeddings.shape} [type: {embeddings.dtype}]')\n" 175 | ], 176 | "outputs": [], 177 | "execution_count": null 178 | } 179 | ], 180 | "metadata": { 181 | "colab": { 182 | "provenance": [], 183 | "gpuType": "V28" 184 | }, 185 | "kernelspec": { 186 | "name": "python3", 187 | "display_name": "Python 3" 188 | }, 189 | "language_info": { 190 | "name": "python" 191 | }, 192 | "accelerator": "TPU" 193 | }, 194 | "nbformat": 4, 195 | "nbformat_minor": 0 196 | } 197 | -------------------------------------------------------------------------------- /videoprism/layers_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 VideoPrism Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for layer modules.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import chex 20 | import jax 21 | from jax import numpy as jnp 22 | import numpy as np 23 | from videoprism import layers 24 | 25 | 26 | class LayersTest(parameterized.TestCase): 27 | 28 | def test_identity(self): 29 | inputs = jnp.ones((4, 16, 8)) 30 | outputs = layers.identity(inputs) 31 | self.assertEqual(inputs.shape, outputs.shape) 32 | self.assertTrue(jnp.array_equal(inputs, outputs)) 33 | 34 | @chex.variants(with_jit=True, without_jit=True) 35 | @parameterized.parameters(True, False) 36 | def test_layer_norm(self, direct_scale: bool): 37 | np_inputs = np.random.normal(1.0, 0.5, [10, 10, 10, 3]).astype(np.float32) 38 | inputs = jnp.asarray(np_inputs) 39 | prng_key = jax.random.PRNGKey(seed=123) 40 | ln = layers.LayerNorm(name='ln', direct_scale=direct_scale) 41 | 42 | @self.variant 43 | def var_fn(): 44 | return ln.init_with_output(prng_key, inputs) 45 | 46 | outputs, params = var_fn() 47 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 2) 48 | self.assertEqual(inputs.shape, outputs.shape) 49 | 50 | @chex.variants(with_jit=True, without_jit=True) 51 | def test_feedforward_layer(self): 52 | np_inputs = np.random.normal(1.0, 0.5, [10, 10, 3]).astype(np.float32) 53 | inputs = jnp.asarray(np_inputs) 54 | prng_key = jax.random.PRNGKey(seed=123) 55 | ffn = layers.FeedForward(name='ffn', output_dim=20) 56 | 57 | @self.variant 58 | def var_fn(): 59 | return ffn.init_with_output(prng_key, inputs) 60 | 61 | outputs, params = var_fn() 62 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 2) 63 | self.assertEqual(outputs.shape, (10, 10, 20)) 64 | 65 | @chex.variants(with_jit=True, without_jit=True) 66 | @parameterized.parameters(True, False) 67 | def test_transformer_feedforward(self, train: bool): 68 | batch_size, seq_len, input_dims = 4, 512, 8 69 | np_inputs = np.random.normal( 70 | 1.0, 0.5, [batch_size, seq_len, input_dims] 71 | ).astype(np.float32) 72 | inputs = jnp.asarray(np_inputs) 73 | np_paddings = np.zeros([batch_size, seq_len], dtype=np.float32) 74 | input_paddings = jnp.asarray(np_paddings) 75 | prng_key = jax.random.PRNGKey(seed=123) 76 | ffwd = layers.TransformerFeedForward( 77 | name='ffwd', hidden_dim=32, activation_fn=layers.gelu 78 | ) 79 | 80 | @self.variant 81 | def var_fn(): 82 | return ffwd.init_with_output( 83 | prng_key, inputs, input_paddings, train=train 84 | ) 85 | 86 | outputs, params = var_fn() 87 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 6) 88 | self.assertEqual(outputs.shape, (batch_size, seq_len, input_dims)) 89 | 90 | @chex.variants(with_jit=True, without_jit=True) 91 | @parameterized.parameters( 92 | (16, 2, 5, False, [5, 16], [5, 2, 5]), 93 | (16, 2, 5, True, [5, 2, 5], [5, 16]), 94 | (256, 16, 16, True, [2, 16, 16], [2, 256]), 95 | ) 96 | def test_mhd_projection( 97 | self, 98 | input_dim: int, 99 | num_heads: int, 100 | dim_per_head: int, 101 | is_output_projection: bool, 102 | inputs_shape: list[int], 103 | expected_outputs_shape: list[int], 104 | ): 105 | np_inputs = np.random.normal(1.5, 2.0, inputs_shape).astype(np.float32) 106 | inputs = jnp.asarray(np_inputs) 107 | prng_key = jax.random.PRNGKey(seed=123) 108 | mh = layers.AttentionProjection( 109 | name='mh', 110 | output_dim=input_dim, 111 | num_heads=num_heads, 112 | dim_per_head=dim_per_head, 113 | is_output_projection=is_output_projection, 114 | ) 115 | 116 | @self.variant 117 | def var_fn(): 118 | return mh.init_with_output(prng_key, 
inputs) 119 | 120 | outputs, params = var_fn() 121 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 2) 122 | self.assertEqual(outputs.shape, tuple(expected_outputs_shape)) 123 | 124 | @chex.variants(with_jit=True, without_jit=True) 125 | def test_per_dim_scale(self): 126 | np_inputs = np.random.normal(1.5, 2.0, [5, 4]).astype(np.float32) 127 | inputs = jnp.asarray(np_inputs) 128 | prng_key = jax.random.PRNGKey(seed=123) 129 | mdl = layers.PerDimScale(name='scale') 130 | 131 | @self.variant 132 | def var_fn(): 133 | return mdl.init_with_output(prng_key, inputs) 134 | 135 | outputs, params = var_fn() 136 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 1) 137 | self.assertEqual(outputs.shape, (5, 4)) 138 | 139 | @chex.variants(with_jit=True, without_jit=True) 140 | @parameterized.product( 141 | enable_query_scale=[True, False], 142 | enable_per_dim_scale=[True, False], 143 | train=[True, False], 144 | ) 145 | def test_mha( 146 | self, 147 | enable_query_scale: bool, 148 | enable_per_dim_scale: bool, 149 | train: bool, 150 | ): 151 | batch_size, seq_len, num_heads, mdl_dim = 3, 8, 4, 16 152 | query_vec = np.random.normal(size=[batch_size, seq_len, mdl_dim]).astype( 153 | np.float32 154 | ) 155 | key_vec = np.random.normal(size=[batch_size, seq_len, mdl_dim]).astype( 156 | np.float32 157 | ) 158 | value_vec = np.random.normal(size=[batch_size, seq_len, mdl_dim]).astype( 159 | np.float32 160 | ) 161 | paddings = jnp.zeros(query_vec.shape[:-1], dtype=query_vec.dtype) 162 | atten_mask = layers.compute_attention_masks_for_fprop(query_vec, paddings) 163 | prng_key = jax.random.PRNGKey(seed=123) 164 | mha = layers.DotProductAttention( 165 | name='mha', 166 | hidden_dim=32, 167 | num_heads=num_heads, 168 | atten_logit_cap=20.0, 169 | internal_enable_query_scale=enable_query_scale, 170 | internal_enable_per_dim_scale=enable_per_dim_scale, 171 | ) 172 | 173 | @self.variant 174 | def var_fn(): 175 | return mha.init_with_output( 176 | prng_key, query_vec, key_vec, value_vec, atten_mask, train=train 177 | ) 178 | 179 | (outputs, probs), params = var_fn() 180 | expected_num_weights = 8 181 | if enable_query_scale and enable_per_dim_scale: 182 | expected_num_weights += 1 183 | self.assertLen(jax.tree_util.tree_flatten(params)[0], expected_num_weights) 184 | self.assertEqual(outputs.shape, (batch_size, seq_len, mdl_dim)) 185 | self.assertEqual(probs.shape, (batch_size, num_heads, seq_len, seq_len)) 186 | 187 | @chex.variants(with_jit=True, without_jit=True) 188 | @parameterized.parameters(True, False) 189 | def test_transformer_layer(self, train: bool): 190 | num_heads, batch_size, seq_len, dim = 8, 3, 12, 32 191 | np_inputs = np.random.normal(1.0, 0.5, [batch_size, seq_len, dim]).astype( 192 | 'float32' 193 | ) 194 | inputs = jnp.asarray(np_inputs) 195 | np_paddings = np.random.randint(0, 1, [batch_size, seq_len]).astype( 196 | 'float32' 197 | ) 198 | paddings = jnp.asarray(np_paddings) 199 | atten_mask = layers.compute_attention_masks_for_fprop(inputs, paddings) 200 | prng_key = jax.random.PRNGKey(seed=123) 201 | tfm = layers.Transformer( 202 | name='tfm', 203 | hidden_dim=128, 204 | num_heads=num_heads, 205 | ) 206 | 207 | @self.variant 208 | def var_fn(): 209 | return tfm.init_with_output( 210 | prng_key, inputs, paddings, atten_mask=atten_mask, train=train 211 | ) 212 | 213 | outputs, params = var_fn() 214 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 17) 215 | self.assertEqual(outputs.shape, (batch_size, seq_len, dim)) 216 | 217 | @chex.variants(with_jit=True, without_jit=True) 
218 | @parameterized.product( 219 | scan=[True, False], 220 | train=[True, False], 221 | ) 222 | def test_stacked_transformer_layer(self, scan: bool, train: bool): 223 | batch_size, seq_len, dim = 3, 12, 16 224 | np_inputs = np.random.normal(1.0, 0.5, [batch_size, seq_len, dim]).astype( 225 | 'float32' 226 | ) 227 | inputs = jnp.asarray(np_inputs) 228 | np_paddings = np.random.randint(0, 1, [batch_size, seq_len]).astype( 229 | 'float32' 230 | ) 231 | paddings = jnp.asarray(np_paddings) 232 | prng_key = jax.random.PRNGKey(seed=123) 233 | stacked_tfm = layers.StackedTransformer( 234 | name='stacked_tfm', 235 | hidden_dim=64, 236 | num_heads=8, 237 | num_layers=4, 238 | scan=scan, 239 | ) 240 | 241 | @self.variant 242 | def var_fn(): 243 | return stacked_tfm.init_with_output( 244 | prng_key, inputs, paddings, train=train 245 | ) 246 | 247 | outputs, params = var_fn() 248 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 17 if scan else 68) 249 | self.assertEqual(outputs.shape, (batch_size, seq_len, dim)) 250 | 251 | @chex.variants(with_jit=True, without_jit=True) 252 | @parameterized.product( 253 | num_queries=[1, 4], 254 | train=[True, False], 255 | ) 256 | def test_atten_token_pooling_layer( 257 | self, 258 | num_queries: int, 259 | train: bool, 260 | ): 261 | batch_size, seq_len, num_heads, input_dim = 3, 8, 4, 16 262 | np_inputs = np.random.normal( 263 | 1.5, 2.0, [batch_size, seq_len, input_dim] 264 | ).astype(np.float32) 265 | np_paddings = np.zeros([batch_size, seq_len], dtype=np.float32) 266 | inputs = jnp.asarray(np_inputs) 267 | input_paddings = jnp.asarray(np_paddings) 268 | prng_key = jax.random.PRNGKey(seed=123) 269 | pooler = layers.AttenTokenPoolingLayer( 270 | name='pooling', 271 | num_heads=num_heads, 272 | num_queries=num_queries, 273 | ) 274 | 275 | @self.variant 276 | def var_fn(): 277 | return pooler.init_with_output( 278 | prng_key, inputs, input_paddings, train=train 279 | ) 280 | 281 | outputs, params = var_fn() 282 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 12) 283 | self.assertEqual(outputs.shape, (batch_size, num_queries, input_dim)) 284 | 285 | 286 | if __name__ == '__main__': 287 | absltest.main() 288 | -------------------------------------------------------------------------------- /videoprism/colabs/videoprism_video_text_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [] 7 | }, 8 | "kernelspec": { 9 | "name": "python3", 10 | "display_name": "Python 3" 11 | }, 12 | "language_info": { 13 | "name": "python" 14 | } 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "source": [ 20 | "# VideoPrism Video-Text Encoder Demo\n", 21 | "\n", 22 | "[![Paper](https://img.shields.io/badge/arXiv-2402.13217-red.svg)](https://arxiv.org/abs/2402.13217)\n", 23 | "[![Blog](https://img.shields.io/badge/Google_Research-Blog-green.svg)](https://research.google/blog/videoprism-a-foundational-visual-encoder-for-video-understanding/)\n", 24 | "[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)\n", 25 | "\n", 26 | "This notebook provides an example of video and text feature extraction with a pre-trained VideoPrism video-text model for zero-shot video classification/retrieval.\n", 27 | "\n", 28 | "Please run this demo on Google Colab with (faster) or without TPU." 
29 | ], 30 | "metadata": { 31 | "id": "KPPUiCpSbm53" 32 | } 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "source": [ 37 | "## Set up" 38 | ], 39 | "metadata": { 40 | "id": "k08qFZ9-cn9v" 41 | } 42 | }, 43 | { 44 | "cell_type": "code", 45 | "source": [ 46 | "# @title Prepare environment\n", 47 | "\n", 48 | "import os\n", 49 | "import sys\n", 50 | "\n", 51 | "# Fetch VideoPrism repository if Python does not know about it and install\n", 52 | "# dependencies needed for this notebook.\n", 53 | "if not os.path.exists(\"videoprism_repo\"):\n", 54 | " !git clone --quiet --branch=main --depth=1 \\\n", 55 | " https://github.com/google-deepmind/videoprism.git videoprism_repo\n", 56 | " os.chdir('./videoprism_repo')\n", 57 | " !pip install .\n", 58 | " os.chdir('..')\n", 59 | "\n", 60 | "# Append VideoPrism code to Python import path.\n", 61 | "if \"videoprism_repo\" not in sys.path:\n", 62 | " sys.path.append(\"videoprism_repo\")\n", 63 | "\n", 64 | "# Install missing dependencies.\n", 65 | "!pip install mediapy\n", 66 | "\n", 67 | "import jax\n", 68 | "from jax.extend import backend\n", 69 | "import tensorflow as tf\n", 70 | "\n", 71 | "# Do not let TF use the GPU or TPUs.\n", 72 | "tf.config.set_visible_devices([], \"GPU\")\n", 73 | "tf.config.set_visible_devices([], \"TPU\")\n", 74 | "\n", 75 | "print(f\"JAX version: {jax.__version__}\")\n", 76 | "print(f\"JAX platform: {backend.get_backend().platform}\")\n", 77 | "print(f\"JAX devices: {jax.device_count()}\")" 78 | ], 79 | "metadata": { 80 | "id": "1dfyX8EyVsvL" 81 | }, 82 | "execution_count": null, 83 | "outputs": [] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": { 89 | "id": "zByA1K0IVKAI" 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "# @title Load dependencies and define utilities\n", 94 | "\n", 95 | "import mediapy\n", 96 | "import numpy as np\n", 97 | "from PIL import Image\n", 98 | "\n", 99 | "\n", 100 | "def read_and_preprocess_video(\n", 101 | " filename: str, target_num_frames: int, target_frame_size: tuple[int, int]\n", 102 | "):\n", 103 | " \"\"\"Reads and preprocesses a video.\"\"\"\n", 104 | "\n", 105 | " frames = mediapy.read_video(filename)\n", 106 | "\n", 107 | " # Sample to target number of frames.\n", 108 | " frame_indices = np.linspace(\n", 109 | " 0, len(frames), num=target_num_frames, endpoint=False, dtype=np.int32\n", 110 | " )\n", 111 | " frames = np.array([frames[i] for i in frame_indices])\n", 112 | "\n", 113 | " # Resize to target size.\n", 114 | " original_height, original_width = frames.shape[-3:-1]\n", 115 | " target_height, target_width = target_frame_size\n", 116 | " assert (\n", 117 | " original_height * target_width == original_width * target_height\n", 118 | " ), 'Currently does not support aspect ratio mismatch.'\n", 119 | " frames = mediapy.resize_video(frames, shape=target_frame_size)\n", 120 | "\n", 121 | " # Normalize pixel values to [0.0, 1.0].\n", 122 | " frames = mediapy.to_float01(frames)\n", 123 | "\n", 124 | " return frames\n", 125 | "\n", 126 | "\n", 127 | "def compute_similarity_matrix(\n", 128 | " video_embeddings,\n", 129 | " text_embeddings,\n", 130 | " temperature: float,\n", 131 | " apply_softmax: str | None = None,\n", 132 | ") -\u003e np.ndarray:\n", 133 | " \"\"\"Computes cosine similarity matrix.\"\"\"\n", 134 | " assert apply_softmax in [None, 'over_texts', 'over_videos']\n", 135 | " emb_dim = video_embeddings[0].shape[-1]\n", 136 | " assert emb_dim == text_embeddings[0].shape[-1]\n", 137 | "\n", 138 | " video_embeddings = 
np.array(video_embeddings).reshape(-1, emb_dim)\n", 139 | "  text_embeddings = np.array(text_embeddings).reshape(-1, emb_dim)\n", 140 | "  similarity_matrix = np.dot(video_embeddings, text_embeddings.T)\n", 141 | "\n", 142 | "  if temperature is not None:\n", 143 | "    similarity_matrix /= temperature\n", 144 | "\n", 145 | "  if apply_softmax == 'over_videos':\n", 146 | "    similarity_matrix = np.exp(similarity_matrix)\n", 147 | "    similarity_matrix = similarity_matrix / np.sum(\n", 148 | "        similarity_matrix, axis=0, keepdims=True\n", 149 | "    )\n", 150 | "  elif apply_softmax == 'over_texts':\n", 151 | "    similarity_matrix = np.exp(similarity_matrix)\n", 152 | "    similarity_matrix = similarity_matrix / np.sum(\n", 153 | "        similarity_matrix, axis=1, keepdims=True\n", 154 | "    )\n", 155 | "\n", 156 | "  return similarity_matrix" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "source": [ 162 | "# @title Load model\n", 163 | "\n", 164 | "import jax\n", 165 | "import jax.numpy as jnp\n", 166 | "from videoprism import models as vp\n", 167 | "\n", 168 | "MODEL_NAME = 'videoprism_lvt_public_v1_base' # @param ['videoprism_lvt_public_v1_base', 'videoprism_lvt_public_v1_large'] {allow-input: false}\n", 169 | "USE_BFLOAT16 = False # @param { type: \"boolean\" }\n", 170 | "NUM_FRAMES = 16\n", 171 | "FRAME_SIZE = 288\n", 172 | "\n", 173 | "fprop_dtype = jnp.bfloat16 if USE_BFLOAT16 else None\n", 174 | "flax_model = vp.get_model(MODEL_NAME, fprop_dtype=fprop_dtype)\n", 175 | "loaded_state = vp.load_pretrained_weights(MODEL_NAME)\n", 176 | "text_tokenizer = vp.load_text_tokenizer('c4_en')\n", 177 | "\n", 178 | "\n", 179 | "@jax.jit\n", 180 | "def forward_fn(inputs, text_token_ids, text_paddings, train=False):\n", 181 | "  return flax_model.apply(\n", 182 | "      loaded_state,\n", 183 | "      inputs,\n", 184 | "      text_token_ids,\n", 185 | "      text_paddings,\n", 186 | "      train=train,\n", 187 | "  )" 188 | ], 189 | "metadata": { 190 | "id": "WnYuzSgrXCL1" 191 | }, 192 | "execution_count": null, 193 | "outputs": [] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "source": [ 198 | "# Example: Zero-shot Video Classification/Retrieval\n", 199 | "\n", 200 | "In this example, we extract the embedding of an input video, and the embeddings of five sentences. We measure the cosine similarities between the video and the sentences."
201 | ], 202 | "metadata": { 203 | "id": "AliScLC0jo1s" 204 | } 205 | }, 206 | { 207 | "cell_type": "code", 208 | "source": [ 209 | "# @title Specify input video\n", 210 | "VIDEO_FILE_PATH = 'videoprism_repo/videoprism/assets/water_bottle_drumming.mp4' # @param {type: \"string\"}\n", 211 | "\n", 212 | "frames = read_and_preprocess_video(\n", 213 | " VIDEO_FILE_PATH,\n", 214 | " target_num_frames=NUM_FRAMES,\n", 215 | " target_frame_size=[FRAME_SIZE, FRAME_SIZE],\n", 216 | ")\n", 217 | "frames = jnp.asarray(frames[None, ...]) # Add batch dimension.\n", 218 | "if USE_BFLOAT16:\n", 219 | " frames = frames.astype(jnp.bfloat16)" 220 | ], 221 | "metadata": { 222 | "id": "sESN_CjfEiQR" 223 | }, 224 | "execution_count": null, 225 | "outputs": [] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "source": [ 230 | "# @title Specify input text queries\n", 231 | "TEXT_QUERY_CSV = 'playing drums,sitting,playing flute,playing at playground,concert' # @param {type: \"string\"}\n", 232 | "PROMPT_TEMPLATE = 'a video of {}.'\n", 233 | "\n", 234 | "text_queries = TEXT_QUERY_CSV.split(',')\n", 235 | "text_queries = [PROMPT_TEMPLATE.format(t) for t in text_queries]\n", 236 | "text_ids, text_paddings = vp.tokenize_texts(text_tokenizer, text_queries)\n", 237 | "if USE_BFLOAT16:\n", 238 | " text_paddings = text_paddings.astype(jnp.bfloat16)\n", 239 | "\n", 240 | "print('Input text queries:')\n", 241 | "for i, text in enumerate(text_queries):\n", 242 | " print(f'({i + 1}) {text}')" 243 | ], 244 | "metadata": { 245 | "id": "kLzkhP8CYUYj" 246 | }, 247 | "execution_count": null, 248 | "outputs": [] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "source": [ 253 | "# @title Compute video-to-text retrieval results\n", 254 | "video_embeddings, text_embeddings, _ = forward_fn(\n", 255 | " frames, text_ids, text_paddings)\n", 256 | "\n", 257 | "TEMPERATURE = 0.01 # @param {type: \"number\"}\n", 258 | "similarity_matrix = compute_similarity_matrix(\n", 259 | " video_embeddings,\n", 260 | " text_embeddings,\n", 261 | " temperature=TEMPERATURE,\n", 262 | " apply_softmax='over_texts',\n", 263 | ")" 264 | ], 265 | "metadata": { 266 | "id": "bfwT93Yz5oi_" 267 | }, 268 | "execution_count": null, 269 | "outputs": [] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "source": [ 274 | "v2t_similarity_vector = similarity_matrix[0]\n", 275 | "top_indices = np.argsort(v2t_similarity_vector)[::-1]\n", 276 | "\n", 277 | "print(f'Query video: {os.path.basename(VIDEO_FILE_PATH)}')\n", 278 | "mediapy.show_video(frames[0].astype(jnp.float32), fps=6.0)\n", 279 | "\n", 280 | "for k, j in enumerate(top_indices):\n", 281 | " print(\n", 282 | " 'Top-%d retrieved text: %s [Similarity = %0.4f]'\n", 283 | " % (k + 1, text_queries[j], v2t_similarity_vector[j])\n", 284 | " )\n", 285 | "print(f'\\nThis is {text_queries[top_indices[0]]}')" 286 | ], 287 | "metadata": { 288 | "id": "lZ8woxde6t_S" 289 | }, 290 | "execution_count": null, 291 | "outputs": [] 292 | } 293 | ] 294 | } 295 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VideoPrism: A Foundational Visual Encoder for Video Understanding 2 | 3 | [![Paper](https://img.shields.io/badge/arXiv-2402.13217-red.svg)](https://arxiv.org/abs/2402.13217) 4 | [![Blog](https://img.shields.io/badge/Google_Research-Blog-green.svg)](https://research.google/blog/videoprism-a-foundational-visual-encoder-for-video-understanding/) 5 | [![Video Encoder Colab 
Demo](https://img.shields.io/static/v1?label=Video%20Encoder%20Demo&message=Google%20Colab&logo=google&color=orange)](https://colab.research.google.com/github/google-deepmind/videoprism/blob/main/videoprism/colabs/videoprism_video_encoder_demo.ipynb) 6 | [![Video-Text Encoder Colab Demo](https://img.shields.io/static/v1?label=Video-Text%20Encoder%20Demo&message=Google%20Colab&logo=google&color=orange)](https://colab.research.google.com/github/google-deepmind/videoprism/blob/main/videoprism/colabs/videoprism_video_text_demo.ipynb) 7 | [![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-yellow)](https://huggingface.co/collections/google/videoprism-686e823d6070ec6ad9e4b1f2) 8 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 9 | 10 | [VideoPrism](https://arxiv.org/abs/2402.13217) is a general-purpose video 11 | encoder designed to handle a wide spectrum of video understanding tasks, 12 | including classification, retrieval, localization, captioning, and question 13 | answering. It is pre-trained on a massive and diverse dataset: 1 billion 14 | image-text pairs from [WebLI](https://arxiv.org/abs/2209.06794), 36 million 15 | high-quality video-text pairs, and 582 million video clips with noisy or 16 | machine-generated parallel text (subject to data wipeout). The pre-training 17 | approach is designed for these hybrid data, to learn both from video-text pairs 18 | and the videos themselves. VideoPrism is fairly easy to adapt to new video 19 | understanding tasks, and achieves state-of-the-art performance on 31 out of 33 20 | public video understanding benchmarks using a single frozen model. 21 | 22 | This repository releases the model weight checkpoints and hosts [JAX](https://github.com/jax-ml/jax)/[Flax](https://github.com/google/flax) utility 23 | functions for checkpoint loading and model inference. 24 | 25 | ## Updates 26 | 27 | * **[Jul-16-25]:** Released VideoPrism video-text encoders for cross-modal retrieval [[`Colab notebook`](https://colab.research.google.com/github/google-deepmind/videoprism/blob/main/videoprism/colabs/videoprism_video_text_demo.ipynb)]. :fire::fire: 28 | * **[Jun-15-25]:** Added models to [[`Hugging Face`](https://huggingface.co/collections/google/videoprism-686e823d6070ec6ad9e4b1f2)]. 29 | * **[Jun-05-25]:** Added video encoder demo [[`Colab notebook`](https://colab.research.google.com/github/google-deepmind/videoprism/blob/main/videoprism/colabs/videoprism_video_encoder_demo.ipynb)]. 30 | * **[Jun-03-25]:** Released VideoPrism video encoders (Base and Large) [[`Blog`](https://research.google/blog/videoprism-a-foundational-visual-encoder-for-video-understanding/)] [[`Paper`](https://arxiv.org/abs/2402.13217)]. :fire::fire: 31 | 32 | ## TODOs 33 | 34 | - [ ] Add PyTorch model support. 35 | 36 | ## Getting started 37 | 38 | You will need Python 3.9 or later. Download the code from GitHub and run: 39 | 40 | ```shell 41 | $ git clone https://github.com/google-deepmind/videoprism.git 42 | $ cd videoprism 43 | $ pip install . 
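$ # Optional (a suggested variant, not in the original instructions): for development
$ # with the test suite, the "testing" extra declared in setup.py can be used:
$ # pip install -e ".[testing]"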
44 | ``` 45 | 46 | Please get started with the following example code for model checkpoint loading 47 | and inference or use the [Colab notebook for video encoders](https://colab.research.google.com/github/google-deepmind/videoprism/blob/main/videoprism/colabs/videoprism_video_encoder_demo.ipynb) / [Colab notebook for video-text encoders](https://colab.research.google.com/github/google-deepmind/videoprism/blob/main/videoprism/colabs/videoprism_video_text_demo.ipynb): 48 | 49 | ```python 50 | import jax 51 | from videoprism import models as vp 52 | 53 | # Video encoders. 54 | model_name = 'videoprism_public_v1_base' # configuration name 55 | flax_model = vp.get_model(model_name) 56 | loaded_state = vp.load_pretrained_weights(model_name) 57 | 58 | @jax.jit 59 | def forward_fn(inputs): 60 | return flax_model.apply(loaded_state, inputs, train=False) 61 | 62 | video_inputs = ... # Shape = [batch_size, num_frames, height, width, 3]. 63 | outputs, _ = forward_fn(video_inputs) # Shape = [batch_size, num_tokens, feature_channels]. 64 | 65 | # Video-text encoders. 66 | model_name = 'videoprism_lvt_public_v1_base' # configuration name 67 | flax_model = vp.get_model(model_name) 68 | loaded_state = vp.load_pretrained_weights(model_name) 69 | text_tokenizer = vp.load_text_tokenizer('c4_en') 70 | 71 | @jax.jit 72 | def forward_fn(inputs, text_token_ids, text_token_paddings, train=False): 73 | return flax_model.apply( 74 | loaded_state, 75 | inputs, 76 | text_token_ids, 77 | text_token_paddings, 78 | train=train, 79 | ) 80 | 81 | video_inputs = ... # Shape = [batch_size, num_frames, height, width, 3]. 82 | text_queries = ... # A list of input text queries. 83 | text_ids, text_paddings = vp.tokenize_texts(text_tokenizer, text_queries) 84 | video_embeddings, text_embeddings, _ = forward_fn( 85 | video_inputs, text_ids, text_paddings) # Shape = [batch_size, feature_channels]. 86 | ``` 87 | 88 | ## Released models 89 | 90 | We release the following model variants: 91 | 92 | | Model Name | Configuration Name | Model Type | Backbone | #Params | File Size | Checkpoint | 93 | | -------- | -------- | ------- | :-------: | :-------: | :-------: | :-------: | 94 | | VideoPrism-B | `videoprism_public_v1_base` | Video encoder | ViT-B | 114M | 458MB | [link](https://huggingface.co/google/videoprism-base-f16r288) | 95 | | VideoPrism-L | `videoprism_public_v1_large` | Video encoder | ViT-L | 354M | 1.42GB | [link](https://huggingface.co/google/videoprism-large-f8r288) | 96 | | VideoPrism-LvT-B | `videoprism_lvt_public_v1_base` | Video-text encoders | ViT-B | 248M | 991MB | [link](https://huggingface.co/google/videoprism-lvt-base-f16r288) | 97 | | VideoPrism-LvT-L | `videoprism_lvt_public_v1_large` | Video-text encoders | ViT-L | 580M | 2.30GB | [link](https://huggingface.co/google/videoprism-lvt-large-f8r288) | 98 | 99 | Video encoders take videos with shape `(batch_size, num_frames, 288, 288, 3)` 100 | as inputs and output embeddings with shape 101 | `(batch_size, num_frames * 16 * 16, feature_channels)` which could be reshaped 102 | into `(batch_size, num_frames, 16, 16, feature_channels)` for spatiotemporal 103 | representations. During model training, `num_frames` is set to 16 and 8 for 104 | VideoPrism-B and VideoPrism-L, respectively. Both models are expected to work 105 | with arbitrary `num_frames` by interpolating the temporal positional embeddings. 106 | The RGB values of input videos should be normalized in [0.0, 1.0]. 
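As an illustration, here is a minimal sketch (not part of the original README) of recovering the spatiotemporal grid from the flattened video-encoder output in the snippet above; `num_frames` is assumed to match the number of frames in `video_inputs`:

```python
num_frames = 16  # Assumption: must equal the number of frames fed to the encoder.
batch_size, num_tokens, channels = outputs.shape
assert num_tokens == num_frames * 16 * 16
features = outputs.reshape(batch_size, num_frames, 16, 16, channels)
# For example, average over the spatial grid to get one embedding per frame.
frame_embeddings = features.mean(axis=(2, 3))  # [batch_size, num_frames, channels]
```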
107 | 108 | In video-text models, both video and text encoders produce global embeddings 109 | with shape `(batch_size, feature_channels)`, which can be compared by cosine 110 | similarity. We use the `c4_en` [SentencePiece](https://github.com/google/sentencepiece) model for text tokenization. During inference, embedding 111 | calculation for either modality can be skipped by providing `None` as the input. 112 | 113 | ## Results with frozen backbones 114 | 115 | *"Public"* denotes models we released in this repository. *"Paper"* and 116 | *"Prior SOTA"* denote our models and previous best-performing models reported 117 | in the [paper](https://arxiv.org/abs/2402.13217), respectively. Our *public* 118 | models perform slightly worse than the *paper* models because, subject to data 119 | policy, they were pre-trained on different image-text data. 120 | 121 | ### Video-focused tasks ([VideoGLUE](https://arxiv.org/abs/2307.03166)) 122 | 123 | | Models | K400 | MiT | SSv2 | D48 | Charades | ActivityNet | AVA | AVA-K | 124 | | -------- | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | 125 | | **VideoPrism-B (public)** | 82.9 | 39.7 | 62.2 | 64.3 | 43.5 | 36.5 | 28.3 | 30.8 | 126 | | **VideoPrism-L (public)** | 85.0 | 43.3 | 64.6 | 67.6 | 53.2 | 37.0 | 32.4 | 34.5 | 127 | | VideoPrism-B (paper) | 84.2 | 40.8 | 63.6 | 67.4 | 40.4 | 36.6 | 30.6 | 31.8 | 128 | | VideoPrism-g (paper) | 87.2 | 45.5 | 68.5 | 71.3 | 62.3 | 37.8 | 36.2 | 37.3 | 129 | | Prior SOTA (B) | 77.1 | 34.0 | 58.2 | 55.6 | 33.3 | 35.8 | 21.1 | 25.9 | 130 | | Prior SOTA (L+) | 82.8 | 40.3 | 67.4 | 69.6 | 39.9 | 36.7 | 24.4 | 26.2 | 131 | 132 | ### Zero-shot video-text retrieval 133 | 134 | | Models | MSRVTT-1K (v2t) | MSRVTT-1K (t2v) | VATEX (v2t) | VATEX (t2v) | ActivityNet (v2t) | ActivityNet (t2v) | 135 | | -------- | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | 136 | | **VideoPrism-LvT-B (public)** | 49.8 | 50.1 | 73.1 | 56.2 | 47.9 | 48.8 | 137 | | **VideoPrism-LvT-L (public)** | 50.6 | 50.1 | 75.0 | 57.2 | 49.1 | 51.3 | 138 | | VideoPrism-LvT-B (paper) | 50.2 | 51.4 | 76.2 | 57.7 | 47.9 | 49.6 | 139 | | VideoPrism-LvT-g (paper) | 51.7 | 52.7 | 77.1 | 62.5 | 50.3 | 52.7 | 140 | | Prior SOTA (B) | - | 34.0 | - | - | - | 30.6 | 141 | | Prior SOTA (L+) | 45.4 | 43.9 | 73.6 | 53.2 | 40.7 | 42.8 | 142 | 143 | ### Zero-shot video classification 144 | 145 | | Models | K400 | SSv2 (Temporal) | SSv2 (Events) | NExT-QA (Hard) | Charades | Charades (STA) | 146 | | -------- | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | 147 | | **VideoPrism-LvT-B (public)** | 69.2 | 14.6 | 11.3 | 31.1 | 26.9 | 48.6 | 148 | | **VideoPrism-LvT-L (public)** | 72.4 | 18.0 | 12.4 | 32.1 | 32.4 | 50.2 | 149 | | VideoPrism-LvT-B (paper) | 71.3 | 16.1 | 11.9 | 31.3 | 29.2 | 50.0 | 150 | | VideoPrism-LvT-g (paper) | 74.6 | 18.6 | 15.7 | 32.7 | 32.4 | 50.4 | 151 | | Prior SOTA (B) | - | 9.8 | 6.4 | 27.6 | 21.1 | - | 152 | | Prior SOTA (L+) | 72.0 | 15.2 | 11.4 | 25.2 | 25.8 | 47.2 | 153 | 154 | ## Citation 155 | 156 | If you use VideoPrism, please cite the following papers: 157 | 158 | 159 | ```bibtex 160 | @inproceedings{zhao2024videoprism, 161 | title = {{VideoPrism}: A Foundational Visual Encoder for Video Understanding}, 162 | author = {Long Zhao and Nitesh B. Gundavarapu and Liangzhe Yuan and Hao Zhou and Shen Yan and Jennifer J.
Sun and Luke Friedman and Rui Qian and Tobias Weyand and Yue Zhao and Rachel Hornung and Florian Schroff and Ming-Hsuan Yang and David A. Ross and Huisheng Wang and Hartwig Adam and Mikhail Sirotenko and Ting Liu and Boqing Gong}, 163 | booktitle = {International Conference on Machine Learning (ICML)}, 164 | year = {2024} 165 | } 166 | 167 | @article{yuan2024videoglue, 168 | title = {{VideoGLUE}: Video General Understanding Evaluation of Foundation Models}, 169 | author = {Liangzhe Yuan and Nitesh Bharadwaj Gundavarapu and Long Zhao and Hao Zhou and Yin Cui and Lu Jiang and Xuan Yang and Menglin Jia and Tobias Weyand and Luke Friedman and Mikhail Sirotenko and Huisheng Wang and Florian Schroff and Hartwig Adam and Ming-Hsuan Yang and Ting Liu and Boqing Gong}, 170 | journal = {Transactions on Machine Learning Research (TMLR)}, 171 | year = {2024} 172 | } 173 | ``` 174 | 175 | ## License 176 | 177 | Copyright 2025 Google LLC 178 | 179 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0); 180 | you may not use this file except in compliance with the Apache 2.0 license. You 181 | may obtain a copy of the Apache 2.0 license at: 182 | 183 | All other materials are licensed under the Creative Commons Attribution 4.0 International License (CC-BY). You may obtain a copy of the CC-BY license at: 184 | 185 | Unless required by applicable law or agreed to in writing, all software and 186 | materials distributed here under the Apache 2.0 or CC-BY licenses are 187 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 188 | either express or implied. See the licenses for the specific language governing 189 | permissions and limitations under those licenses. 190 | 191 | ## Disclaimer 192 | 193 | This is not an official Google product. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /videoprism/encoders_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 VideoPrism Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for encoder modules.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import chex 20 | import jax 21 | from jax import numpy as jnp 22 | import numpy as np 23 | from videoprism import encoders 24 | 25 | 26 | class EncodersTest(parameterized.TestCase): 27 | 28 | @chex.variants(with_jit=True, without_jit=True) 29 | def test_embedding_layer(self): 30 | num_classes, dim, input_shape = 8, 10, (5, 20) 31 | npy_input = np.random.randint(0, num_classes, input_shape).astype('int32') 32 | inputs = jnp.asarray(npy_input) 33 | prng_key = jax.random.PRNGKey(seed=123) 34 | emb_layer = encoders.Embedding( 35 | name='emb_lookup', 36 | num_classes=num_classes, 37 | input_dim=dim, 38 | scale_sqrt_depth=True, 39 | ) 40 | 41 | @self.variant 42 | def var_fn(): 43 | return emb_layer.init_with_output(prng_key, inputs) 44 | 45 | outputs, params = var_fn() 46 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 1) 47 | self.assertEqual(outputs.shape, input_shape + (dim,)) 48 | 49 | @chex.variants(with_jit=True, without_jit=True) 50 | def test_positional_embedding_layer(self): 51 | seq_len, dim = 8, 10 52 | prng_key = jax.random.PRNGKey(seed=123) 53 | emb_layer = encoders.PositionalEmbedding( 54 | name='pos_emb', 55 | embedding_dim=dim, 56 | min_timescale=10, 57 | max_timescale=20, 58 | ) 59 | 60 | @self.variant 61 | def var_fn(): 62 | return emb_layer.init_with_output(prng_key, seq_len) 63 | 64 | outputs, params = var_fn() 65 | self.assertEmpty(jax.tree_util.tree_flatten(params)[0]) 66 | self.assertEqual(outputs.shape, (1, seq_len, dim)) 67 | 68 | @chex.variants(with_jit=True, without_jit=True) 69 | def test_trainable_positional_embedding_layer(self): 70 | seq_len, dim = 8, 10 71 | prng_key = jax.random.PRNGKey(seed=123) 72 | emb_layer = encoders.TrainablePositionalEmbedding( 73 | name='pos_emb', 74 | max_seq_length=seq_len, 75 | embedding_dim=dim, 76 | lookup_style='matmul', 77 | ) 78 | 79 | @self.variant 80 | def var_fn(): 81 | return emb_layer.init_with_output(prng_key, seq_len) 82 | 83 | outputs, params = var_fn() 84 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 1) 85 | self.assertEqual(outputs.shape, (1, seq_len, dim)) 86 | 87 | @chex.variants(with_jit=True) 88 | @parameterized.product( 89 | scan=[True, False], 90 | train=[True, False], 91 | ) 92 | def test_vision_transformer(self, scan: bool, train: bool): 93 | batch_size, seq_len, dim = 1, 6, 4 94 | np_inputs = np.random.normal(1.0, 0.5, [batch_size, seq_len, dim]).astype( 95 | 'float32' 96 | ) 97 | inputs = jnp.asarray(np_inputs) 98 | prng_key = jax.random.PRNGKey(seed=123) 99 | vit = encoders.VisionTransformer( 100 | name='vit', 101 | num_tfm_layers=2, 102 | mlp_dim=4, 103 | num_heads=2, 104 | scan=scan, 105 | ) 106 | 107 | @self.variant 108 | def var_fn(): 109 | return vit.init_with_output(prng_key, inputs, train=train) 110 | 111 | outputs, params = var_fn() 112 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 16 if scan else 32) 113 | self.assertEqual(outputs.shape, (batch_size, seq_len, dim)) 114 | 115 | 
@chex.variants(with_jit=True) 116 | @parameterized.named_parameters( 117 | ('train', False, True, False, False), 118 | ('scan', True, False, False, False), 119 | ('scan_and_train', True, True, False, False), 120 | ('return_intermediate', True, False, True, False), 121 | ('use_frame_paddings', True, False, False, True), 122 | ) 123 | def test_factorized_encoder( 124 | self, 125 | scan: bool, 126 | train: bool, 127 | return_intermediate: bool, 128 | use_frame_paddings: bool, 129 | ): 130 | batch_size, num_frames, image_size, patch_size, dim = 1, 4, 16, 4, 8 131 | np_inputs = np.random.normal( 132 | 0.0, 133 | 0.1, 134 | [batch_size, num_frames, image_size, image_size, 3], 135 | ).astype('float32') 136 | inputs = jnp.asarray(np_inputs) 137 | 138 | frame_paddings = None 139 | if use_frame_paddings: 140 | np_frame_paddings = np.zeros((batch_size, num_frames), dtype='float32') 141 | np_frame_paddings[:, num_frames // 2 :] = 1 142 | frame_paddings = jnp.asarray(np_frame_paddings) 143 | 144 | prng_key = jax.random.PRNGKey(seed=123) 145 | enc = encoders.FactorizedEncoder( 146 | name='enc', 147 | patch_size=patch_size, 148 | pos_emb_shape=(16, 16, 16), 149 | model_dim=dim, 150 | num_spatial_layers=2, 151 | num_temporal_layers=2, 152 | num_heads=2, 153 | mlp_dim=4, 154 | atten_logit_cap=50.0, 155 | scan=scan, 156 | ) 157 | 158 | @self.variant 159 | def var_fn(): 160 | return enc.init_with_output( 161 | prng_key, 162 | inputs, 163 | train=train, 164 | return_intermediate=return_intermediate, 165 | frame_paddings=frame_paddings, 166 | ) 167 | 168 | (embeddings, outputs), params = var_fn() 169 | 170 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 40 if scan else 72) 171 | self.assertEqual( 172 | embeddings.shape, 173 | (batch_size, num_frames * (image_size // patch_size) ** 2, dim), 174 | ) 175 | if return_intermediate: 176 | self.assertEqual( 177 | outputs['spatial_features'].shape, 178 | (batch_size, num_frames * (image_size // patch_size) ** 2, dim), 179 | ) 180 | else: 181 | self.assertEmpty(outputs) 182 | 183 | @chex.variants(with_jit=True) 184 | @parameterized.named_parameters( 185 | ('train', True, False), 186 | ('return_intermediate', False, True), 187 | ) 188 | def test_factorized_video_classifier( 189 | self, train: bool, return_intermediate: bool 190 | ): 191 | batch_size, num_frames, image_size, patch_size, dim = 1, 4, 16, 4, 8 192 | np_inputs = np.random.normal( 193 | 0.0, 194 | 0.1, 195 | [batch_size, num_frames, image_size, image_size, 3], 196 | ).astype('float32') 197 | inputs = jnp.asarray(np_inputs) 198 | 199 | encoder_params = dict( 200 | patch_size=patch_size, 201 | pos_emb_shape=(16, 16, 16), 202 | model_dim=dim, 203 | num_spatial_layers=2, 204 | num_temporal_layers=2, 205 | num_heads=2, 206 | mlp_dim=4, 207 | atten_logit_cap=50.0, 208 | scan=True, 209 | ) 210 | prng_key = jax.random.PRNGKey(seed=123) 211 | classifier = encoders.FactorizedVideoClassifier( 212 | name='classifier', 213 | encoder_params=encoder_params, 214 | num_classes=10, 215 | ) 216 | 217 | @self.variant 218 | def var_fn(): 219 | return classifier.init_with_output( 220 | prng_key, inputs, train=train, return_intermediate=return_intermediate 221 | ) 222 | 223 | (logits, outputs), params = var_fn() 224 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 54) 225 | self.assertEqual(logits.shape, (batch_size, 10)) 226 | if return_intermediate: 227 | self.assertEqual( 228 | outputs['spatial_features'].shape, 229 | (batch_size, num_frames * (image_size // patch_size) ** 2, dim), 230 | ) 231 | self.assertEqual( 
232 | outputs['spatiotemporal_features'].shape, 233 | (batch_size, num_frames * (image_size // patch_size) ** 2, dim), 234 | ) 235 | self.assertEqual(outputs['global_embeddings'].shape, (batch_size, dim)) 236 | else: 237 | self.assertEmpty(outputs) 238 | 239 | @chex.variants(with_jit=True) 240 | @parameterized.named_parameters( 241 | ('train', False, True), 242 | ('scan', True, False), 243 | ('scan_and_train', True, True), 244 | ) 245 | def test_text_encoder(self, scan: bool, train: bool): 246 | batch_size, seq_len, vocab_size, dim = 1, 10, 20, 8 247 | np_inputs = np.random.randint(0, vocab_size, [batch_size, seq_len]).astype( 248 | 'int32' 249 | ) 250 | inputs = jnp.asarray(np_inputs) 251 | np_paddings = np.zeros([batch_size, seq_len], dtype='float32') 252 | np_paddings[:, seq_len // 2 :] = 1 253 | paddings = jnp.asarray(np_paddings) 254 | 255 | prng_key = jax.random.PRNGKey(seed=123) 256 | enc = encoders.TextEncoder( 257 | name='enc', 258 | vocabulary_size=vocab_size, 259 | num_class_tokens=1, 260 | model_dim=dim, 261 | num_layers=2, 262 | mlp_dim=4, 263 | num_heads=2, 264 | atten_logit_cap=50.0, 265 | scan=scan, 266 | ) 267 | 268 | @self.variant 269 | def var_fn(): 270 | return enc.init_with_output(prng_key, inputs, paddings, train=train) 271 | 272 | outputs, params = var_fn() 273 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 20 if scan else 36) 274 | self.assertEqual(outputs.shape, (batch_size, seq_len + 1, dim)) 275 | 276 | @chex.variants(with_jit=True) 277 | @parameterized.named_parameters( 278 | ('train', False, True, False), 279 | ('scan_and_train', True, True, False), 280 | ('return_intermediate', True, False, True), 281 | ( 282 | 'selectively_return_intermediate', 283 | True, 284 | False, 285 | {'spatial_features', 'frame_embeddings'}, 286 | ), 287 | ) 288 | def test_factorized_video_clip( 289 | self, scan: bool, train: bool, return_intermediate: bool 290 | ): 291 | batch_size, num_frames, image_size, patch_size, dim = 1, 4, 16, 4, 8 292 | np_inputs = np.random.normal( 293 | 0.0, 294 | 0.1, 295 | [batch_size, num_frames, image_size, image_size, 3], 296 | ).astype('float32') 297 | inputs = jnp.asarray(np_inputs) 298 | 299 | seq_len, vocab_size = 10, 20 300 | np_text_token_ids = np.random.randint( 301 | 0, vocab_size, [batch_size, seq_len] 302 | ).astype('int32') 303 | text_token_ids = jnp.asarray(np_text_token_ids) 304 | np_text_paddings = np.zeros([batch_size, seq_len], dtype='float32') 305 | np_text_paddings[:, seq_len // 2 :] = 1 306 | text_paddings = jnp.asarray(np_text_paddings) 307 | 308 | prng_key = jax.random.PRNGKey(seed=123) 309 | net = encoders.FactorizedVideoCLIP( 310 | name='net', 311 | patch_size=patch_size, 312 | pos_emb_shape=(16, 16, 16), 313 | num_spatial_layers=2, 314 | num_temporal_layers=2, 315 | mlp_dim=4, 316 | num_auxiliary_layers=1, 317 | vocabulary_size=vocab_size, 318 | enable_causal_atten=True, 319 | num_unimodal_layers=2, 320 | norm_policy='pre', 321 | model_dim=dim, 322 | num_heads=2, 323 | atten_logit_cap=50.0, 324 | scan=scan, 325 | ) 326 | 327 | @self.variant 328 | def var_fn(): 329 | return net.init_with_output( 330 | prng_key, 331 | inputs=inputs, 332 | text_token_ids=text_token_ids, 333 | text_paddings=text_paddings, 334 | train=train, 335 | normalize=True, 336 | return_intermediate=return_intermediate, 337 | ) 338 | 339 | (video_embeddings, text_embeddings, outputs), params = var_fn() 340 | self.assertLen(jax.tree_util.tree_flatten(params)[0], 88 if scan else 136) 341 | self.assertEqual(video_embeddings.shape, (batch_size, dim)) 342 | 
self.assertEqual(text_embeddings.shape, (batch_size, dim)) 343 | if not return_intermediate: 344 | self.assertEmpty(outputs) 345 | else: 346 | if return_intermediate is True: # pylint: disable=g-bool-id-comparison 347 | self.assertEqual( 348 | set(outputs.keys()), 349 | { 350 | 'frame_embeddings', 351 | 'spatial_features', 352 | 'spatiotemporal_features', 353 | }, 354 | ) 355 | else: 356 | self.assertEqual(set(outputs.keys()), set(return_intermediate)) 357 | 358 | if 'spatial_features' in outputs: 359 | self.assertEqual( 360 | outputs['spatial_features'].shape, 361 | (batch_size, num_frames * (image_size // patch_size) ** 2, dim), 362 | ) 363 | if 'spatiotemporal_features' in outputs: 364 | self.assertEqual( 365 | outputs['spatiotemporal_features'].shape, 366 | (batch_size, num_frames * (image_size // patch_size) ** 2, dim), 367 | ) 368 | if 'frame_embeddings' in outputs: 369 | self.assertEqual( 370 | outputs['frame_embeddings'].shape, (batch_size, num_frames, dim) 371 | ) 372 | 373 | 374 | if __name__ == '__main__': 375 | absltest.main() 376 | -------------------------------------------------------------------------------- /videoprism/models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 VideoPrism Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Provides builders and loaders of VideoPrism checkpoints. 16 | 17 | The v1 base model takes videos with shape (16, 288, 288) as inputs and outputs 18 | embeddings with shape (batch_size, 4096, 768) which could be reshaped into 19 | (batch_size, 16, 16, 16, 768) for spatiotemporal representations. The input 20 | videos should be normalized in [0.0, 1.0]. 21 | 22 | Example usage: 23 | ``` 24 | from videoprism import models as vp 25 | 26 | model_name = 'videoprism_public_v1_base' 27 | flax_model = vp.get_model(model_name) 28 | loaded_state = vp.load_pretrained_weights(model_name) 29 | 30 | @jax.jit 31 | def forward_fn(inputs): 32 | return flax_model.apply(loaded_state, inputs, train=False) 33 | 34 | model_inputs = ... 35 | outputs = forward_fn(model_inputs) 36 | ``` 37 | """ 38 | 39 | from collections.abc import Callable, Mapping, Sequence 40 | import functools 41 | 42 | from flax import linen as nn 43 | import jax 44 | import jax.numpy as jnp 45 | import huggingface_hub 46 | import numpy as np 47 | from videoprism import encoders 48 | from videoprism import tokenizers 49 | from videoprism import utils 50 | 51 | K400_NUM_CLASSES: int = 400 52 | SSV2_NUM_CLASSES: int = 174 53 | 54 | TEXT_MAX_LEN: int = 64 55 | TEXT_TOKENIZERS = { 56 | 'c4_en': { 57 | 'model_path': 'gs://t5-data/vocabs/cc_en.32000/sentencepiece.model', 58 | 'vocab_size': 32_000, 59 | }, 60 | } 61 | 62 | CHECKPOINTS = { 63 | # Hugging Face checkpoints (repository, filename). 
64 | 'videoprism_public_v1_base': ( 65 | 'google/videoprism-base-f16r288', 66 | 'flax_base_f16r288_repeated.npz', 67 | ), 68 | 'videoprism_public_v1_large': ( 69 | 'google/videoprism-large-f8r288', 70 | 'flax_large_f8r288_repeated.npz', 71 | ), 72 | 'videoprism_lvt_public_v1_base': ( 73 | 'google/videoprism-lvt-base-f16r288', 74 | 'flax_lvt_base_f16r288_repeated.npz', 75 | ), 76 | 'videoprism_lvt_public_v1_large': ( 77 | 'google/videoprism-lvt-large-f8r288', 78 | 'flax_lvt_large_f8r288_repeated.npz', 79 | ), 80 | } 81 | 82 | CONFIGS = { 83 | 'videoprism_v1_base': dict( 84 | patch_size=18, 85 | pos_emb_shape=(16, 16, 16), 86 | model_dim=768, 87 | num_spatial_layers=12, 88 | num_temporal_layers=4, 89 | num_heads=12, 90 | mlp_dim=3072, 91 | atten_logit_cap=50.0, 92 | scan=True, 93 | ), 94 | 'videoprism_v1_large': dict( 95 | patch_size=18, 96 | pos_emb_shape=(8, 16, 16), 97 | model_dim=1024, 98 | num_spatial_layers=24, 99 | num_temporal_layers=4, 100 | num_heads=16, 101 | mlp_dim=4096, 102 | atten_logit_cap=50.0, 103 | scan=True, 104 | ), 105 | 'videoprism_v1_giant': dict( 106 | patch_size=18, 107 | pos_emb_shape=(8, 16, 16), 108 | model_dim=1408, 109 | num_spatial_layers=40, 110 | num_temporal_layers=4, 111 | num_heads=16, 112 | mlp_dim=6144, 113 | atten_logit_cap=50.0, 114 | scan=True, 115 | ), 116 | 'videoprism_lvt_v1_base': dict( 117 | patch_size=18, 118 | pos_emb_shape=(16, 16, 16), 119 | num_spatial_layers=12, 120 | num_temporal_layers=4, 121 | mlp_dim=3072, 122 | num_auxiliary_layers=2, 123 | enable_causal_atten=True, 124 | num_unimodal_layers=12, 125 | norm_policy='pre', 126 | model_dim=768, 127 | num_heads=12, 128 | atten_logit_cap=50.0, 129 | scan=True, 130 | ), 131 | 'videoprism_lvt_v1_large': dict( 132 | patch_size=18, 133 | pos_emb_shape=(8, 16, 16), 134 | num_spatial_layers=24, 135 | num_temporal_layers=4, 136 | mlp_dim=4096, 137 | num_auxiliary_layers=2, 138 | enable_causal_atten=True, 139 | num_unimodal_layers=12, 140 | norm_policy='pre', 141 | model_dim=1024, 142 | num_heads=16, 143 | atten_logit_cap=50.0, 144 | scan=True, 145 | ), 146 | 'videoprism_lvt_v1_giant': dict( 147 | patch_size=18, 148 | pos_emb_shape=(8, 16, 16), 149 | num_spatial_layers=40, 150 | num_temporal_layers=4, 151 | mlp_dim=6144, 152 | num_auxiliary_layers=2, 153 | enable_causal_atten=True, 154 | num_unimodal_layers=16, 155 | norm_policy='primer_hybrid', 156 | model_dim=1408, 157 | num_heads=16, 158 | atten_logit_cap=50.0, 159 | scan=True, 160 | ), 161 | } 162 | 163 | 164 | def videoprism_v1_base(): 165 | """Builds VideoPrism v1 base model.""" 166 | return encoders.FactorizedEncoder(**CONFIGS['videoprism_v1_base']) 167 | 168 | 169 | def videoprism_v1_large(): 170 | """Builds VideoPrism v1 large model.""" 171 | return encoders.FactorizedEncoder(**CONFIGS['videoprism_v1_large']) 172 | 173 | 174 | def videoprism_v1_giant(): 175 | """Builds VideoPrism v1 giant model.""" 176 | return encoders.FactorizedEncoder(**CONFIGS['videoprism_v1_giant']) 177 | 178 | 179 | def videoprism_lvt_v1_base(text_tokenizer: str = 'c4_en'): 180 | """Builds VideoPrism LvT v1 base model.""" 181 | config = CONFIGS['videoprism_lvt_v1_base'] 182 | config['vocabulary_size'] = TEXT_TOKENIZERS[text_tokenizer]['vocab_size'] 183 | return encoders.FactorizedVideoCLIP(**config) 184 | 185 | 186 | def videoprism_lvt_v1_large(text_tokenizer: str = 'c4_en'): 187 | """Builds VideoPrism LvT v1 large model.""" 188 | config = CONFIGS['videoprism_lvt_v1_large'] 189 | config['vocabulary_size'] = TEXT_TOKENIZERS[text_tokenizer]['vocab_size'] 190 | return 
encoders.FactorizedVideoCLIP(**config) 191 | 192 | 193 | def videoprism_lvt_v1_giant(text_tokenizer: str = 'c4_en'): 194 | """Builds VideoPrism LvT v1 giant model.""" 195 | config = CONFIGS['videoprism_lvt_v1_giant'] 196 | config['vocabulary_size'] = TEXT_TOKENIZERS[text_tokenizer]['vocab_size'] 197 | return encoders.FactorizedVideoCLIP(**config) 198 | 199 | 200 | def videoprism_vc_v1_base(num_classes: int): 201 | """Builds VideoPrism Classification v1 base model.""" 202 | encoder_params = CONFIGS['videoprism_v1_base'] 203 | return encoders.FactorizedVideoClassifier( 204 | encoder_params=encoder_params, num_classes=num_classes 205 | ) 206 | 207 | 208 | def videoprism_vc_v1_large(num_classes: int): 209 | """Builds VideoPrism Classification v1 large model.""" 210 | encoder_params = CONFIGS['videoprism_v1_large'] 211 | return encoders.FactorizedVideoClassifier( 212 | encoder_params=encoder_params, num_classes=num_classes 213 | ) 214 | 215 | 216 | def videoprism_vc_v1_giant(num_classes: int): 217 | """Builds VideoPrism Classification v1 giant model.""" 218 | encoder_params = CONFIGS['videoprism_v1_giant'] 219 | return encoders.FactorizedVideoClassifier( 220 | encoder_params=encoder_params, num_classes=num_classes 221 | ) 222 | 223 | 224 | MODELS = { 225 | 'videoprism_public_v1_base': videoprism_v1_base, 226 | 'videoprism_public_v1_large': videoprism_v1_large, 227 | 'videoprism_lvt_public_v1_base': functools.partial( 228 | videoprism_lvt_v1_base, text_tokenizer='c4_en' 229 | ), 230 | 'videoprism_lvt_public_v1_large': functools.partial( 231 | videoprism_lvt_v1_large, text_tokenizer='c4_en' 232 | ), 233 | } 234 | 235 | 236 | def _get_model_name_by_hf_model_id(model_id: str) -> str | None: 237 | """Returns model name for the given Hugging Face model ID. 238 | 239 | Hugging Face model ID is typically the name of the repository, e.g., 240 | `google/videoprism-base-f16r288`. 241 | 242 | Args: 243 | model_id: A string for the Hugging Face model ID. 244 | 245 | Returns: 246 | The model name for the given Hugging Face model ID or None if not found. 247 | """ 248 | for model_name, value in CHECKPOINTS.items(): 249 | if isinstance(value, tuple) and value[0] == model_id: 250 | return model_name 251 | 252 | return None 253 | 254 | 255 | def has_model( 256 | model_name: str, 257 | models: Mapping[str, Callable[[], nn.Module]] | None = None, 258 | ) -> bool: 259 | """Returns whether the model is available.""" 260 | models = models or MODELS 261 | if model_name.startswith('google/'): 262 | # Handle Hugging Face model ID. 263 | model_name = _get_model_name_by_hf_model_id(model_name) 264 | 265 | return model_name is not None and model_name in models 266 | 267 | 268 | def get_model( 269 | model_name: str | None, 270 | model_fn: Callable[[], nn.Module] | None = None, 271 | models: Mapping[str, Callable[[], nn.Module]] | None = None, 272 | fprop_dtype: jax.typing.DTypeLike | None = None, 273 | ): 274 | """Returns VideoPrism model with the given name. 275 | 276 | Args: 277 | model_name: A string for the model name or Hugging Face model ID. 278 | model_fn: Optional function that returns the model. 279 | models: Mapping from model name to model creation function. Used with 280 | `model_name`. If None, use the default `MODELS`. 281 | 282 | Returns: 283 | A Flax VideoPrism model. 284 | """ 285 | 286 | if model_fn is None: 287 | assert model_name is not None 288 | models = models or MODELS 289 | if model_name.startswith('google/'): 290 | # Handle Hugging Face model ID. 
291 | model_name = _get_model_name_by_hf_model_id(model_name) 292 | if model_name is None: 293 | raise ValueError(f'Failed to find model name with `{model_name}`.') 294 | 295 | if model_name not in models: 296 | raise ValueError(f'Model `{model_name}` not found.') 297 | 298 | model_fn = models[model_name] 299 | 300 | model = model_fn() 301 | if fprop_dtype is not None: 302 | model.fprop_dtype = fprop_dtype 303 | return model 304 | 305 | 306 | def load_pretrained_weights( 307 | model_name: str | None, 308 | checkpoint_path: str | None = None, 309 | checkpoints: Mapping[str, str | tuple[str, str]] | None = None, 310 | ): 311 | """Loads pretrained model weights. 312 | 313 | Args: 314 | model_name: A string for the model name or Hugging Face model ID. 315 | checkpoint_path: Optional path of the model checkpoint. 316 | checkpoints: Mapping from model name to checkpoint path. Used with 317 | `model_name`. If None, use the default `CHECKPOINTS`. 318 | 319 | Returns: 320 | Restored Flax model weights. 321 | """ 322 | checkpoints = checkpoints or CHECKPOINTS 323 | 324 | if checkpoint_path is None: 325 | assert model_name is not None 326 | if model_name.startswith('google/'): 327 | # Handle Hugging Face model ID. 328 | model_name = _get_model_name_by_hf_model_id(model_name) 329 | 330 | repo_id, filename = checkpoints[model_name] 331 | checkpoint_path = huggingface_hub.hf_hub_download( 332 | repo_id=repo_id, filename=filename 333 | ) 334 | 335 | variables = utils.load_checkpoint(checkpoint_path) 336 | return jax.tree_util.tree_map(jnp.asarray, variables) 337 | 338 | 339 | def load_text_tokenizer(name: str) -> tokenizers.Tokenizer: 340 | """Loads a text tokenizer by name. 341 | 342 | Args: 343 | name: A string for the text tokenizer model name. 344 | 345 | Returns: 346 | A text tokenizer. 347 | """ 348 | if name not in TEXT_TOKENIZERS: 349 | raise ValueError(f'Text tokenizer `{name}` not found.') 350 | 351 | model_path = TEXT_TOKENIZERS[name]['model_path'] 352 | return tokenizers.SentencePieceTokenizer(model_path) 353 | 354 | 355 | def tokenize_texts( 356 | tokenizer: tokenizers.Tokenizer, 357 | inputs: Sequence[str], 358 | max_length: int = TEXT_MAX_LEN, 359 | add_bos: bool | None = None, 360 | canonicalize: bool = True, 361 | ) -> tuple[np.ndarray, np.ndarray]: 362 | """Tokenizes a batch of texts. 363 | 364 | Args: 365 | tokenizer: The tokenizer to use. 366 | inputs: The list of texts to tokenize. 367 | max_length: The maximum length of the tokenized texts. 368 | add_bos: Whether to add a beginning-of-sentence token. If None, the 369 | beginning-of-sentence token will be added if the tokenizer's bos_token is 370 | a non-negative integer. 371 | canonicalize: Whether to canonicalize the texts before tokenization. 372 | 373 | Returns: 374 | A tuple of two numpy arrays containing the padded token ids and the 375 | corresponding paddings, where 1 denotes padding token. 
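  Example (illustrative sketch; actual token ids depend on the loaded
  SentencePiece model):

    tokenizer = load_text_tokenizer('c4_en')
    token_ids, paddings = tokenize_texts(tokenizer, ['a dog runs'], max_length=8)
    # token_ids.shape == (1, 8); paddings.shape == (1, 8), with 1.0 at padded
    # positions.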
376 | """ 377 | 378 | if canonicalize: 379 | inputs = [utils.canonicalize_text(text) for text in inputs] 380 | 381 | batch_ids, batch_paddings = [], [] 382 | if add_bos is None: 383 | add_bos = tokenizer.bos_token >= 0 384 | 385 | for ids in tokenizer.to_int(inputs, bos=add_bos, eos=False): 386 | ids_seq_len = len(ids) 387 | if ids_seq_len > max_length: 388 | ids = ids[:max_length] 389 | 390 | ids = np.asarray(ids, dtype=np.int32) 391 | paddings = np.zeros_like(ids, dtype=np.float32) 392 | 393 | if ids_seq_len < max_length: 394 | ids = np.pad( 395 | ids, (0, max_length - ids_seq_len), 'constant', constant_values=0 396 | ) 397 | paddings = np.pad( 398 | paddings, 399 | (0, max_length - ids_seq_len), 400 | 'constant', 401 | constant_values=1.0, 402 | ) 403 | 404 | batch_ids.append(ids) 405 | batch_paddings.append(paddings) 406 | 407 | return np.asarray(batch_ids), np.asarray(batch_paddings) 408 | -------------------------------------------------------------------------------- /videoprism/encoders.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 VideoPrism Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Modules for video encoders.""" 16 | 17 | from collections.abc import Collection, Sequence 18 | import dataclasses 19 | import math 20 | from typing import Any 21 | 22 | import einops 23 | import einshape 24 | from flax import linen as nn 25 | import jax 26 | from jax import numpy as jnp 27 | import numpy as np 28 | from videoprism import layers 29 | 30 | Array = jax.Array 31 | Variables = nn.module.VariableDict 32 | 33 | default_kernel_init = layers.default_kernel_init 34 | 35 | 36 | def _contains(collection: Collection[str] | bool, key: str) -> bool: 37 | """Checks if a collection contains a key. 38 | 39 | Args: 40 | collection: A collection of strings or a boolean value. 41 | key: A string key to check. 42 | 43 | Returns: 44 | True if the collection contains the key, or if the collection is a True 45 | boolean. False otherwise. 46 | """ 47 | return collection if isinstance(collection, bool) else key in collection 48 | 49 | 50 | def _l2_normalize( 51 | x: Array, axis: int | Sequence[int] = -1, epsilon: float = 1e-12 52 | ) -> Array: 53 | """L2-normalizes a jax.Array along certain dimension. 54 | 55 | Args: 56 | x: An input jax.Array. 57 | axis: An integer or a sequence of integers for the axis to normalize. 58 | epsilon: A small constant for numerical stability. 59 | 60 | Returns: 61 | Normalized jax.Array. 62 | """ 63 | x_dtype = x.dtype 64 | # Always convert embed to float32 for all precisions. 65 | x = x.astype(jnp.float32) 66 | norm = jnp.sqrt(jnp.sum(x * x, axis=axis, keepdims=True) + epsilon) 67 | return (x / norm).astype(x_dtype) 68 | 69 | 70 | def _image_to_patch(inputs: Array, patch_size: int) -> Array: 71 | """Converts an image to patches. 72 | 73 | Args: 74 | inputs: A jax.Array of shape [B, H, W, C] , 75 | patch_size: An integer for dimension of a square patch. 
76 | 77 | Returns: 78 | batched_patches: [B, (H * W / P^2), P^2 * C]. 79 | """ 80 | if len(inputs.shape) < 4: 81 | raise ValueError( 82 | f'Image should be formatted as 4D [B, H, W, C], Shape: {inputs.shape}' 83 | ) 84 | height, width, channels = inputs.shape[-3:] 85 | 86 | if height % patch_size != 0 or width % patch_size != 0: 87 | raise ValueError( 88 | f'Image height ({height}) and width ({width}) should be multiples ' 89 | f'of patch_size ({patch_size}).' 90 | ) 91 | 92 | row_blocks = height // patch_size 93 | column_blocks = width // patch_size 94 | 95 | patches = einops.rearrange( 96 | inputs, 97 | '... (m p)(n q) c->...(m n)(p q c)', 98 | m=row_blocks, 99 | n=column_blocks, 100 | p=patch_size, 101 | q=patch_size, 102 | c=channels, 103 | ) 104 | return patches 105 | 106 | 107 | def _interpolate_emb_1d(emb: Array, target_emb_length: int) -> Array: 108 | """Interpolates a 1D positional embedding to a new shape. 109 | 110 | Args: 111 | emb: jax.Array, (1, N, D), flattened 1D positional embedding. 112 | target_emb_length: length of the target embedding. 113 | 114 | Returns: 115 | Flattened, interpolated embedding of shape (1, target_emb_length, D) 116 | """ 117 | 118 | if len(emb.shape) > 3 or emb.shape[0] != 1: 119 | raise ValueError('The shape of the embedding should be (1, N, D)') 120 | 121 | emb_dim = emb.shape[-1] 122 | emb = jnp.squeeze(emb, axis=0) 123 | 124 | target_emb = jax.image.resize( 125 | emb, (target_emb_length, emb_dim), method='bilinear' 126 | ) 127 | target_emb = jnp.reshape(target_emb, (1, target_emb_length, emb_dim)) 128 | return target_emb 129 | 130 | 131 | def _interpolate_emb_2d( 132 | emb: Array, 133 | source_emb_shape: tuple[int, int], 134 | target_emb_shape: tuple[int, int], 135 | ) -> Array: 136 | """Interpolates a 2D positional embedding to a new shape. 137 | 138 | Args: 139 | emb: A jax.Array of shape (1, H1xW1, D) for flattened 2D positional 140 | embedding. 141 | source_emb_shape: Tuple, (H1, W1), height and width of the source embedding. 142 | target_emb_shape: Tuple, (H2, W2), height and width of the target embedding. 143 | 144 | Returns: 145 | Flattened, interpolated embedding of shape (1, H2xW2, D) 146 | """ 147 | 148 | if len(emb.shape) > 3 or emb.shape[0] != 1: 149 | raise ValueError('The shape of the embedding should be (1, H * W, D)') 150 | 151 | if emb.shape[-2] != source_emb_shape[0] * source_emb_shape[1]: 152 | raise ValueError('The shape of the embedding does NOT match input specs.') 153 | 154 | emb_dim = emb.shape[-1] 155 | emb = jnp.reshape(emb, (source_emb_shape[0], source_emb_shape[1], emb_dim)) 156 | 157 | target_emb = jax.image.resize( 158 | emb, 159 | (target_emb_shape[0], target_emb_shape[1], emb_dim), 160 | method='bilinear', 161 | ) 162 | target_emb = jnp.reshape( 163 | target_emb, (1, target_emb_shape[0] * target_emb_shape[1], emb_dim) 164 | ) 165 | return target_emb 166 | 167 | 168 | class Embedding(layers.Module): 169 | """A simple embedding layer that performs embedding lookups from ids. 170 | 171 | Attributes: 172 | num_classes: Number of tokens in the vocabulary. 173 | input_dim: Depth of the embedding output. This is called `input_dim` as 174 | opposed to the more appropriate `embedding_dim` to be compatible with 175 | other embedding layers defined in this file. 176 | lookup_style: Style of lookup, one of index or matmul. 177 | scale_sqrt_depth: If set to True, activations are scaled with 178 | sqrt(embedding_dim) in embeding lookup. 
179 | set_nan_for_oob_id: If set to True, embeddings corresponding to 180 | out-of-boundaries ids will be set to NaN. 181 | """ 182 | 183 | num_classes: int = 0 184 | input_dim: int = 0 185 | lookup_style: str = 'index' 186 | scale_sqrt_depth: bool = False 187 | set_nan_for_oob_id: bool = False 188 | 189 | @nn.compact 190 | def __call__(self, ids: Array) -> Array: 191 | """Generates a jax.Array of embedding lookup result. 192 | 193 | Args: 194 | ids: Indexes of shape [...] for embedding lookup. 195 | 196 | Returns: 197 | A jax.Array of shape [..., input_dim]. 198 | """ 199 | emb_var = self._cast_to_fprop_dtype( 200 | self.param( 201 | 'emb_var', 202 | nn.initializers.normal(stddev=1.0 / math.sqrt(self.input_dim)), 203 | [self.num_classes, self.input_dim], 204 | self.dtype, 205 | ) 206 | ) 207 | if self.lookup_style == 'index': 208 | embs = jnp.asarray(emb_var)[(ids,)] 209 | elif self.lookup_style == 'matmul': 210 | one_hot_ids = jax.nn.one_hot( 211 | ids, self.num_classes, dtype=self.fprop_dtype 212 | ) 213 | embs = jnp.einsum('...y,yz->...z', one_hot_ids, emb_var) 214 | else: 215 | raise ValueError(f'Unknown lookup style: `{self.lookup_style}`.') 216 | 217 | # Map out-of-boundary ids to NaN. 218 | if self.set_nan_for_oob_id: 219 | embs = jnp.where(ids[..., jnp.newaxis] < self.num_classes, embs, jnp.nan) 220 | 221 | if self.scale_sqrt_depth: 222 | embs *= self.input_dim**0.5 223 | 224 | return embs 225 | 226 | 227 | class PositionalEmbedding(layers.Module): 228 | """Generates position embedding for a given 1-d sequence. 229 | 230 | Attributes: 231 | embedding_dim: Dimension of the embedding to be generated. 232 | min_timescale: Start of the geometric index. 233 | max_timescale: End of the geometric index. 234 | """ 235 | 236 | embedding_dim: int = 0 237 | min_timescale: int = 1 238 | max_timescale: int = 10_000 239 | 240 | def __call__(self, seq_length: int) -> Array: 241 | """Generates a jax.Array of embedding lookup result. 242 | 243 | Args: 244 | seq_length: Sequence length of the embeddings to be generated. 245 | 246 | Returns: 247 | A jax.Array of shape [1, seq_length, embedding_dim]. 248 | """ 249 | position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :] 250 | num_timescales = self.embedding_dim // 2 251 | log_timescale_increment = math.log( 252 | float(self.max_timescale) / float(self.min_timescale) 253 | ) / jnp.maximum(jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1) 254 | inv_timescales = self.min_timescale * jnp.exp( 255 | jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment 256 | ) 257 | scaled_time = ( 258 | position[:, :, jnp.newaxis] 259 | * inv_timescales[jnp.newaxis, jnp.newaxis, :] 260 | ) 261 | embs = jnp.concatenate( 262 | [jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1 263 | ).astype(self.fprop_dtype) 264 | # Force usage of `np` to compute static values at trace time. 265 | embs = jnp.pad(embs, [[0, 0], [0, 0], [0, np.mod(self.embedding_dim, 2)]]) 266 | return embs 267 | 268 | 269 | class TrainablePositionalEmbedding(layers.Module): 270 | """Generates trainable position embedding for a given 1-d sequence. 271 | 272 | Attributes: 273 | embedding_dim: Dimension of the embedding to be generated. 274 | max_seq_length: Max sequence length. 275 | lookup_style: Style of lookup, one of index or matmul. 
276 | """ 277 | 278 | embedding_dim: int = 0 279 | max_seq_length: int = 10_240 280 | lookup_style: str = 'matmul' 281 | 282 | @nn.compact 283 | def __call__(self, seq_length: int) -> Array: 284 | """Generates a jax.Array of embedding lookup result. 285 | 286 | Args: 287 | seq_length: Sequence length of the embeddings to be generated. 288 | 289 | Returns: 290 | A jax.Array of shape [1, seq_length, embedding_dim]. 291 | """ 292 | position = jnp.arange(seq_length, dtype=jnp.int32)[jnp.newaxis, :] 293 | pos_emb_var = self._cast_to_fprop_dtype( 294 | self.param( 295 | 'emb_var', 296 | default_kernel_init, 297 | [self.max_seq_length, self.embedding_dim], 298 | self.dtype, 299 | ) 300 | ) 301 | pos_emb_var = jax.lax.slice_in_dim(pos_emb_var, 0, seq_length, axis=0) 302 | if self.lookup_style == 'matmul': 303 | one_hot_ids = jax.nn.one_hot(position, seq_length, dtype=self.fprop_dtype) 304 | embs = jnp.einsum('...y,yz->...z', one_hot_ids, pos_emb_var) 305 | else: 306 | raise ValueError(f'Unknown lookup style: `{self.lookup_style}`.') 307 | return embs 308 | 309 | 310 | class VisionTransformer(layers.Module): 311 | """Vision transformer model. 312 | 313 | This class follows a minimalistic design pattern. Users need to configure the 314 | templates for the submodules themselves; this increases the generalizability 315 | of this class. 316 | 317 | Attributes: 318 | num_tfm_layers: Number of layers in this model. 319 | mlp_dim: The hidden layer dimension of FFN in Transformer layers. 320 | num_heads: Number of attention heads. 321 | xformer_has_bias: Whether to use bias. 322 | xformer_dropout_prob: Apply dropout at this prob at various places. 323 | xformer_atten_dropout_prob: Probability at which we apply dropout to the 324 | attention weights. 325 | xformer_residual_dropout_prob: Probability at which we apply dropout to the 326 | residual layers, such that, residual(x, y) = (x + dropout(y)). 327 | xformer_relu_dropout_prob: Probability at which we apply dropout to the FFN 328 | layers. 329 | atten_logit_cap: Cap the absolute values of logits by tanh. Enabled when a 330 | positive value is specified. May not be supported by a subclass. 331 | norm_policy: Policy for applying normalization wrt. transformations. Options 332 | are: (1) "pre", applied before transformation. (2) "primer_hybrid", 333 | applied before and after transformation. (3) "post", applied after 334 | transformation. (4) "post_skip", applied after the skip connection. 335 | scan: Whether to use `nn.remat` and`nn.scan`. 336 | """ 337 | 338 | num_tfm_layers: int = 12 339 | mlp_dim: int = 3072 340 | num_heads: int = 12 341 | xformer_has_bias: bool = True 342 | xformer_dropout_prob: float = 0.0 343 | xformer_atten_dropout_prob: float | None = None 344 | xformer_residual_dropout_prob: float | None = None 345 | xformer_relu_dropout_prob: float | None = None 346 | atten_logit_cap: float = 0.0 347 | norm_policy: str = 'pre' 348 | scan: bool = False 349 | 350 | @nn.compact 351 | def __call__( 352 | self, inputs: Array, paddings: Array | None = None, train: bool = False 353 | ) -> Array: 354 | """Applies the ViT model to the inputs. 355 | 356 | Args: 357 | inputs: Input tensor of shape [B, N, D], which are sequences of embeddings 358 | or patches. 359 | paddings: Optional [B, N] padding field of inputs when inputs are with [B, 360 | N, D]. 361 | train: If the model is in the train mode. 362 | 363 | Returns: 364 | Output tensor of shape [B, N, D]. 
365 | """ 366 | features = inputs 367 | if paddings is None: 368 | paddings = jnp.zeros(features.shape[:-1], dtype=features.dtype) 369 | features = layers.StackedTransformer( 370 | name='transformers_stack', 371 | num_layers=self.num_tfm_layers, 372 | hidden_dim=self.mlp_dim, 373 | num_heads=self.num_heads, 374 | dropout_prob=self.xformer_dropout_prob, 375 | atten_dropout_prob=self.xformer_atten_dropout_prob, 376 | residual_dropout_prob=self.xformer_residual_dropout_prob, 377 | relu_dropout_prob=self.xformer_relu_dropout_prob, 378 | use_bias=self.xformer_has_bias, 379 | atten_logit_cap=self.atten_logit_cap, 380 | norm_policy=self.norm_policy, 381 | internal_enable_per_dim_scale=False, 382 | activation_fn=layers.gelu, 383 | enable_causal_atten=False, 384 | scan=self.scan, 385 | dtype=self.dtype, 386 | fprop_dtype=self.fprop_dtype, 387 | )(features, paddings, train=train) 388 | return features 389 | 390 | 391 | class FactorizedEncoder(layers.Module): 392 | """Factorized encoder from the paper `ViViT: A Video Vision Transformer`. 393 | 394 | This is an implementation of model-2 in the paper. It applies ViT model for 395 | video data based on the factorized space-time encoder. 396 | 397 | Reference: https://arxiv.org/abs/2103.15691 398 | """ 399 | 400 | patch_size: int = 18 401 | pos_emb_shape: tuple[int, int, int] = (16, 16, 16) 402 | model_dim: int = 768 403 | num_spatial_layers: int = 12 404 | num_temporal_layers: int = 4 405 | num_heads: int = 12 406 | mlp_dim: int = 3072 407 | atten_logit_cap: float = 0.0 408 | norm_policy: str = 'pre' 409 | scan: bool = False 410 | 411 | def __call__( 412 | self, 413 | inputs: Array, 414 | train: bool = False, 415 | return_intermediate: bool | Collection[str] = False, 416 | frame_paddings: Array | None = None, 417 | ) -> tuple[Array, dict[str, Array]]: 418 | """Computes predictions for batched inputs. 419 | 420 | Args: 421 | inputs: Input image tensor of shape [B, T, H, W, 3] (H == W). 422 | train: If the model is in the train mode. 423 | return_intermediate: A boolean for whether all intermediate features are 424 | returned, or a container of intermediate feature names to return. 425 | frame_paddings: Optional binary tensor of shape [B, T] indicating padding. 426 | 1 denotes padding frame. 427 | 428 | Returns: 429 | embeddings: Output tensor for video embeddings of shape [B, T * N, D]. 430 | outputs: A dictionary of additional outputs, including `spatial_features` 431 | (shape = [B, T * N, D]). Empty if `return_intermediate` is False or does 432 | not contain 'spatial_features'. 433 | """ 434 | b, t, h, w, c = inputs.shape 435 | assert h == w 436 | reshaped_inputs = inputs.reshape(b * t, h, w, c) # (B * T, H, W, C). 437 | 438 | # Tokenization. 439 | patches = _image_to_patch(reshaped_inputs, self.patch_size) 440 | patches_paddings = None 441 | if frame_paddings is not None: 442 | assert frame_paddings.shape == (b, t) 443 | reshaped_frame_paddings = frame_paddings.reshape(b * t) # (B * T,). 444 | num_patches = patches.shape[1] 445 | patches_paddings = jnp.repeat( 446 | reshaped_frame_paddings[:, jnp.newaxis], num_patches, axis=-1 447 | ) # (B * T, num_patches). 
448 | 449 | embeddings, outputs = self.encode_with_patches( 450 | patches=patches, 451 | image_shape=(t, h, w), 452 | train=train, 453 | return_intermediate=return_intermediate, 454 | patches_paddings=patches_paddings, 455 | ) 456 | return embeddings, outputs 457 | 458 | @nn.compact 459 | def encode_with_patches( 460 | self, 461 | patches: Array, 462 | image_shape: tuple[int, int, int], 463 | train: bool = False, 464 | return_intermediate: bool | Collection[str] = False, 465 | patches_paddings: Array | None = None, 466 | ) -> tuple[Array, dict[str, Array]]: 467 | """Computes predictions for patches. 468 | 469 | Args: 470 | patches: Input patches tensor of shape [B * T, (H * W / P^2), P^2 * C]. 471 | image_shape: Original image shape (T, H, W). 472 | train: If the model is in the train mode. 473 | return_intermediate: A boolean for whether all intermediate features are 474 | returned, or a collection of intermediate feature names to return. 475 | patches_paddings: Optional binary tensor of shape [B * T, (H * W / P^2)] 476 | indicating padding. 1 denotes padded patch. 477 | 478 | Returns: 479 | embeddings: Output tensor for video embedding sequence of shape [B, T * N, 480 | D]. 481 | outputs: A dictionary of additional outputs, including `spatial_features` 482 | of shape [B, T * N, D]. Empty if `return_intermediate` is False or does 483 | not contain 'spatial_features'. 484 | """ 485 | t, h, w = image_shape 486 | b = patches.shape[0] // t 487 | 488 | patches = layers.FeedForward( # (B * T, N, D). 489 | name='patch_projection', 490 | output_dim=self.model_dim, 491 | activation_fn=layers.identity, 492 | dtype=self.dtype, 493 | fprop_dtype=self.fprop_dtype, 494 | )(patches) 495 | 496 | # Add spatial positional encoding. 497 | spatial_pos_emb_shape = self.pos_emb_shape[-2:] 498 | spatial_seq_length = np.prod(spatial_pos_emb_shape) 499 | spatial_pos_emb = TrainablePositionalEmbedding( 500 | name='spatial_pos_emb', 501 | embedding_dim=self.model_dim, 502 | max_seq_length=spatial_seq_length, 503 | dtype=self.dtype, 504 | fprop_dtype=self.fprop_dtype, 505 | )(seq_length=spatial_seq_length) 506 | num_row_patches = h // self.patch_size 507 | num_col_patches = w // self.patch_size 508 | if spatial_pos_emb_shape != (num_row_patches, num_col_patches): 509 | spatial_pos_emb = _interpolate_emb_2d( 510 | spatial_pos_emb, 511 | spatial_pos_emb_shape, 512 | (num_row_patches, num_col_patches), 513 | ) 514 | patches += spatial_pos_emb # (B * T, N, D). 515 | 516 | # Get features from the spatial encoder. 517 | features = VisionTransformer( # (B * T, N, D). 518 | name='spatial_encoder', 519 | num_tfm_layers=self.num_spatial_layers, 520 | mlp_dim=self.mlp_dim, 521 | num_heads=self.num_heads, 522 | atten_logit_cap=self.atten_logit_cap, 523 | norm_policy=self.norm_policy, 524 | scan=self.scan, 525 | dtype=self.dtype, 526 | fprop_dtype=self.fprop_dtype, 527 | )(patches, train=train, paddings=patches_paddings) 528 | features = layers.LayerNorm( 529 | name='spatial_ln', dtype=self.dtype, fprop_dtype=self.fprop_dtype 530 | )(features) 531 | spatial_features = features 532 | 533 | # Instead of mean pooling, we keep the spatial tokens. 534 | # Shape = (B * N, T, D). 535 | features = einshape.jax_einshape('(bt)nd->(bn)td', features, t=t) 536 | temporal_paddings = None 537 | if patches_paddings is not None: 538 | temporal_paddings = einshape.jax_einshape( 539 | '(bt)n->(bn)t', patches_paddings, t=t 540 | ) # (B * N, T). 541 | 542 | # Add temporal positional encoding. 
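    # [Editor's note] The einshape pattern '(bt)nd->(bn)td' above regroups the
    # spatial tokens so that the temporal encoder below attends across the T
    # frames at each spatial location independently. A standalone sketch of
    # the regrouping (illustrative toy shapes only):
    #
    #   x = jnp.arange(2 * 3 * 4 * 5.0).reshape(2 * 3, 4, 5)  # (B*T, N, D), B=2, T=3
    #   y = einshape.jax_einshape('(bt)nd->(bn)td', x, t=3)   # (B*N, T, D) == (8, 3, 5)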
543 | temporal_seq_length = self.pos_emb_shape[0] 544 | temporal_pos_emb = TrainablePositionalEmbedding( 545 | name='temporal_pos_emb', 546 | embedding_dim=self.model_dim, 547 | max_seq_length=temporal_seq_length, 548 | dtype=self.dtype, 549 | fprop_dtype=self.fprop_dtype, 550 | )(seq_length=temporal_seq_length) 551 | if temporal_seq_length != t: 552 | temporal_pos_emb = _interpolate_emb_1d(temporal_pos_emb, t) 553 | features += temporal_pos_emb 554 | 555 | # Get features from the temporal encoder. 556 | features = VisionTransformer( 557 | name='temporal_encoder', 558 | num_tfm_layers=self.num_temporal_layers, 559 | mlp_dim=self.mlp_dim, 560 | num_heads=self.num_heads, 561 | atten_logit_cap=self.atten_logit_cap, 562 | norm_policy=self.norm_policy, 563 | scan=self.scan, 564 | dtype=self.dtype, 565 | fprop_dtype=self.fprop_dtype, 566 | )(features, train=train, paddings=temporal_paddings) 567 | features = layers.LayerNorm( 568 | name='temporal_ln', dtype=self.dtype, fprop_dtype=self.fprop_dtype 569 | )(features) 570 | features = einshape.jax_einshape( # (B, T * N, D). 571 | '(bn)td->b(tn)d', features, b=b 572 | ) 573 | 574 | embeddings, outputs = features, {} 575 | if _contains(return_intermediate, 'spatial_features'): 576 | outputs['spatial_features'] = einshape.jax_einshape( 577 | '(bt)nd->b(tn)d', spatial_features, t=t 578 | ) 579 | 580 | return embeddings, outputs 581 | 582 | 583 | class FactorizedVideoClassifier(layers.Module): 584 | """Video classifier with `FactorizedEncoder` backbone. 585 | 586 | Attributes: 587 | encoder_params: A dictionary of parameters for `FactorizedEncoder`. 588 | num_classes: Number of output classes. 589 | """ 590 | 591 | encoder_params: dict[str, Any] = dataclasses.field(default_factory=dict) 592 | num_classes: int = 0 593 | 594 | @nn.compact 595 | def __call__( 596 | self, 597 | inputs: Array, 598 | train: bool = False, 599 | return_intermediate: bool | Collection[str] = False, 600 | frame_paddings: Array | None = None, 601 | ): 602 | """Applies video classifier to inputs. 603 | 604 | Args: 605 | inputs: Input tensor of shape [B, T, H, W, 3]. 606 | train: Whether the model is in the training mode. 607 | return_intermediate: A boolean for whether all intermediate features are 608 | returned, or a collection of intermediate feature names to return. 609 | frame_paddings: Optional binary tensor of shape [B, T] indicating padding. 610 | 1 denotes padding frame. 611 | 612 | Returns: 613 | logits: Output tensor of shape [B, num_classes]. 614 | outputs: A dictionary of additional outputs, including `spatial_features` 615 | of shape [B, T * N, D], `spatiotemporal_features` of shape [B, T * N, 616 | D], and `global_embeddings` of shape [B, D]. Empty if 617 | `return_intermediate` is False. 
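    Example (editor-added illustrative sketch; the configuration and class
    count are hypothetical):

      model = FactorizedVideoClassifier(
          encoder_params=dict(model_dim=768, num_heads=12), num_classes=400)
      videos = jnp.zeros((1, 8, 288, 288, 3))  # [B, T, H, W, 3]
      variables = model.init(jax.random.PRNGKey(0), videos)
      logits, _ = model.apply(variables, videos)  # logits: [1, 400]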
618 | """ 619 | features, outputs = FactorizedEncoder( 620 | name='encoder', 621 | dtype=self.dtype, 622 | fprop_dtype=self.fprop_dtype, 623 | **self.encoder_params, 624 | )( 625 | inputs, 626 | train=train, 627 | return_intermediate=return_intermediate, 628 | frame_paddings=frame_paddings, 629 | ) 630 | if _contains(return_intermediate, 'spatiotemporal_features'): 631 | outputs['spatiotemporal_features'] = features 632 | 633 | embeddings = layers.AttenTokenPoolingLayer( 634 | name='atten_pooler', 635 | num_heads=self.encoder_params['num_heads'], 636 | hidden_dim=self.encoder_params['model_dim'], 637 | num_queries=1, 638 | dtype=self.dtype, 639 | fprop_dtype=self.fprop_dtype, 640 | )(features, paddings=None, train=train) 641 | embeddings = jnp.squeeze(embeddings, axis=-2) 642 | 643 | if _contains(return_intermediate, 'global_embeddings'): 644 | outputs['global_embeddings'] = embeddings 645 | 646 | logits = layers.FeedForward( 647 | name='projection', 648 | output_dim=self.num_classes, 649 | activation_fn=layers.identity, 650 | dtype=self.dtype, 651 | fprop_dtype=self.fprop_dtype, 652 | )(embeddings) 653 | return logits, outputs 654 | 655 | 656 | class TextEncoder(layers.Module): 657 | """CoCa-style text encoder. 658 | 659 | Reference: https://arxiv.org/abs/2205.01917 660 | 661 | Attributes: 662 | vocabulary_size: Vocabulary size of the text tokens. 663 | num_class_tokens: Number of class tokens. 664 | enable_causal_atten: Whether to enable causal attention. 665 | model_dim: The model dimension. 666 | num_tfm_layers: Number of layers in this model. 667 | mlp_dim: The hidden layer dimension of FFN in Transformer layers. 668 | num_heads: Number of attention heads. 669 | enable_per_dim_scale: Whether to ensable rescaling of attention logits with 670 | 1/sqrt(dim) factor. 671 | atten_logit_cap: Cap the absolute values of logits by tanh. Enabled when a 672 | positive value is specified. May not be supported by a subclass. 673 | norm_policy: Policy for applying normalization wrt. transformations. Options 674 | are: (1) "pre", applied before transformation. (2) "primer_hybrid", 675 | applied before and after transformation. (3) "post", applied after 676 | transformation. (4) "post_skip", applied after the skip connection. 677 | scan: Whether to use `nn.remat` and`nn.scan`. 678 | """ 679 | 680 | vocabulary_size: int = 128 681 | num_class_tokens: int = 0 682 | enable_causal_atten: bool = True 683 | model_dim: int = 768 684 | num_layers: int = 12 685 | mlp_dim: int = 3072 686 | num_heads: int = 12 687 | atten_logit_cap: float = 0.0 688 | norm_policy: str = 'pre' 689 | enable_per_dim_scale: bool = False 690 | scan: bool = False 691 | 692 | @nn.compact 693 | def __call__( 694 | self, inputs: Array, paddings: Array, train: bool = False 695 | ) -> Array: 696 | """Applies the text encoder to the inputs. 697 | 698 | Args: 699 | inputs: Input tensor of shape [B, N] including sequences of token ids. 700 | paddings: Optional [B, N] padding field of inputs. 701 | train: If the model is in the train mode. 702 | 703 | Returns: 704 | Output tensor of shape [B, N, D]. 
705 | """ 706 | batch_size, seq_length = inputs.shape 707 | 708 | pos_emb = PositionalEmbedding( 709 | name='pos_emb', 710 | embedding_dim=self.model_dim, 711 | dtype=self.dtype, 712 | fprop_dtype=self.fprop_dtype, 713 | )(seq_length=seq_length) 714 | input_emb = Embedding( 715 | name='token_emb', 716 | num_classes=self.vocabulary_size, 717 | input_dim=self.model_dim, 718 | scale_sqrt_depth=True, 719 | dtype=self.dtype, 720 | fprop_dtype=self.fprop_dtype, 721 | )(inputs) 722 | features = input_emb + pos_emb 723 | 724 | if self.num_class_tokens > 0: 725 | cls_emb = self._cast_to_fprop_dtype( 726 | self.param( 727 | 'cls_emb', 728 | nn.initializers.normal(stddev=1.0 / math.sqrt(self.model_dim)), 729 | [1, self.num_class_tokens, self.model_dim], 730 | self.dtype, 731 | ) 732 | ) 733 | cls_emb = jnp.tile(cls_emb, [batch_size, 1, 1]) 734 | cls_emb *= self.model_dim**0.5 735 | features = jnp.concatenate([features, cls_emb], axis=-2) 736 | 737 | cls_paddings = jnp.zeros( 738 | [batch_size, self.num_class_tokens], dtype=paddings.dtype 739 | ) 740 | paddings = jnp.concatenate([paddings, cls_paddings], axis=-1) 741 | 742 | features = layers.StackedTransformer( 743 | name='unimodal_transformer', 744 | num_layers=self.num_layers, 745 | hidden_dim=self.mlp_dim, 746 | num_heads=self.num_heads, 747 | atten_logit_cap=self.atten_logit_cap, 748 | norm_policy=self.norm_policy, 749 | internal_enable_per_dim_scale=self.enable_per_dim_scale, 750 | activation_fn=jax.nn.relu, 751 | enable_causal_atten=self.enable_causal_atten, 752 | scan=self.scan, 753 | dtype=self.dtype, 754 | fprop_dtype=self.fprop_dtype, 755 | )(features, paddings, train=train) 756 | features = layers.LayerNorm( 757 | name='unimodal_ln', dtype=self.dtype, fprop_dtype=self.fprop_dtype 758 | )(features) 759 | return features 760 | 761 | 762 | class FactorizedVideoCLIP(layers.Module): 763 | """Video CLIP model with a factorized vision encoder.""" 764 | 765 | # Vision parameters. 766 | patch_size: int = 18 767 | pos_emb_shape: tuple[int, int, int] = (16, 16, 16) 768 | num_spatial_layers: int = 12 769 | num_temporal_layers: int = 4 770 | mlp_dim: int = 3072 771 | num_auxiliary_layers: int = 0 772 | # Text parameters. 773 | vocabulary_size: int = 128 774 | enable_causal_atten: bool = True 775 | num_unimodal_layers: int = 12 776 | norm_policy: str = 'pre' 777 | # Shared parameters. 778 | model_dim: int = 768 779 | num_heads: int = 12 780 | atten_logit_cap: float = 0.0 781 | scan: bool = False 782 | 783 | @nn.compact 784 | def __call__( 785 | self, 786 | inputs: Array | None = None, 787 | text_token_ids: Array | None = None, 788 | text_paddings: Array | None = None, 789 | train: bool = False, 790 | normalize: bool = True, 791 | return_intermediate: bool | Collection[str] = False, 792 | frame_paddings: Array | None = None, 793 | ) -> tuple[Array | None, Array | None, dict[str, Array]]: 794 | """Computes predictions for `input_batch`. 795 | 796 | Args: 797 | inputs: Input frame image tensor of shape [B, T, H, W, 3] (H == W). 798 | text_token_ids: Input text token id tensor of shape [B, L]. 799 | text_paddings: Input text paddings of shape [B, L]. Required if 800 | `text_token_ids` is not None. 801 | train: If the model is in the train mode. 802 | normalize: Whether to normalize the output embeddings. 803 | return_intermediate: A boolean for whether all intermediate features are 804 | returned, or a collection of intermediate feature names to return. 805 | frame_paddings: Optional binary tensor of shape [B, T] indicating padding. 
806 | 1 denotes padding frame. 807 | 808 | Returns: 809 | video_embeddings: Output contrastive video embeddings of shape [B, D]. 810 | None if `inputs` is None. 811 | text_embeddings: Output contrastive text embeddings of shape [B, D]. None 812 | if `text_token_ids` is None. 813 | outputs: A dictionary of additional outputs, including `spatial_features` 814 | of shape [B, T * N, D], `spatiotemporal_features` of shape [B, T * N, 815 | D], and `frame_embeddings` of shape [B, T, D]. Empty if 816 | `return_intermediate` is False or does not contain `spatial_features`. 817 | """ 818 | video_embeddings, text_embeddings, outputs = None, None, {} 819 | 820 | if inputs is not None: 821 | num_frames = inputs.shape[-4] 822 | vision_features, vision_outputs = FactorizedEncoder( 823 | name='vision_encoder', 824 | patch_size=self.patch_size, 825 | pos_emb_shape=self.pos_emb_shape, 826 | model_dim=self.model_dim, 827 | num_spatial_layers=self.num_spatial_layers, 828 | num_temporal_layers=self.num_temporal_layers, 829 | num_heads=self.num_heads, 830 | mlp_dim=self.mlp_dim, 831 | atten_logit_cap=self.atten_logit_cap, 832 | norm_policy='pre', 833 | scan=self.scan, 834 | dtype=self.dtype, 835 | fprop_dtype=self.fprop_dtype, 836 | )( 837 | inputs, 838 | train=train, 839 | return_intermediate=return_intermediate, 840 | frame_paddings=frame_paddings, 841 | ) 842 | outputs.update(vision_outputs) 843 | if _contains(return_intermediate, 'spatiotemporal_features'): 844 | outputs['spatiotemporal_features'] = vision_features 845 | 846 | if self.num_auxiliary_layers > 0: 847 | vision_features = VisionTransformer( 848 | name='auxiliary_encoder', 849 | num_tfm_layers=self.num_auxiliary_layers, 850 | mlp_dim=self.mlp_dim, 851 | num_heads=self.num_heads, 852 | atten_logit_cap=self.atten_logit_cap, 853 | norm_policy='pre', 854 | scan=self.scan, 855 | dtype=self.dtype, 856 | fprop_dtype=self.fprop_dtype, 857 | )(vision_features, train=train) 858 | 859 | pooling_layer = layers.AttenTokenPoolingLayer( 860 | name='contrastive_vision_pooler', 861 | hidden_dim=self.model_dim * 4, 862 | num_heads=self.num_heads, 863 | num_queries=1, 864 | dtype=self.dtype, 865 | fprop_dtype=self.fprop_dtype, 866 | ) 867 | video_embeddings = pooling_layer(vision_features, None, train=train) 868 | 869 | # Squeeze the query dimension in the pooler output. 870 | video_embeddings = jnp.squeeze(video_embeddings, axis=-2) 871 | if normalize: 872 | video_embeddings = _l2_normalize(video_embeddings, axis=-1) 873 | 874 | if _contains(return_intermediate, 'frame_embeddings'): 875 | frame_features = einshape.jax_einshape( 876 | 'b(tn)d->(bt)nd', vision_features, t=num_frames 877 | ) 878 | frame_embeddings = pooling_layer(frame_features, None, train=train) 879 | frame_embeddings = jnp.squeeze(frame_embeddings, axis=-2) 880 | frame_embeddings = einshape.jax_einshape( 881 | '(bt)d->btd', frame_embeddings, t=num_frames 882 | ) 883 | if normalize: 884 | frame_embeddings = _l2_normalize(frame_embeddings, axis=-1) 885 | outputs['frame_embeddings'] = frame_embeddings 886 | 887 | if text_token_ids is not None: 888 | assert text_paddings is not None, 'Text paddings are required.' 
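      # [Editor's note] The branch below appends one class token to the caption
      # tokens and takes its final hidden state as the contrastive text
      # embedding. End-to-end usage sketch (illustrative only; the toy inputs
      # and vocabulary size are hypothetical):
      #
      #   model = FactorizedVideoCLIP(vocabulary_size=32_000)
      #   videos = jnp.zeros((1, 8, 288, 288, 3))
      #   token_ids = jnp.zeros((1, 64), dtype=jnp.int32)
      #   paddings = jnp.zeros((1, 64))
      #   variables = model.init(jax.random.PRNGKey(0), videos, token_ids, paddings)
      #   v_emb, t_emb, _ = model.apply(variables, videos, token_ids, paddings)
      #   # Both are L2-normalized [1, D]; their dot product is the CLIP-style
      #   # video-text similarity.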
889 | text_features = TextEncoder( 890 | name='text_encoder', 891 | vocabulary_size=self.vocabulary_size, 892 | num_class_tokens=1, 893 | enable_causal_atten=self.enable_causal_atten, 894 | model_dim=self.model_dim, 895 | num_layers=self.num_unimodal_layers, 896 | num_heads=self.num_heads, 897 | mlp_dim=self.model_dim * 4, 898 | atten_logit_cap=self.atten_logit_cap, 899 | norm_policy=self.norm_policy, 900 | scan=self.scan, 901 | dtype=self.dtype, 902 | fprop_dtype=self.fprop_dtype, 903 | )(text_token_ids, text_paddings, train=train) 904 | 905 | # Take the last token (i.e., class token) as the text embedding. 906 | text_embeddings = text_features[:, -1] 907 | if normalize: 908 | text_embeddings = _l2_normalize(text_embeddings, axis=-1) 909 | 910 | return video_embeddings, text_embeddings, outputs 911 | -------------------------------------------------------------------------------- /videoprism/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 VideoPrism Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """VideoPrism Flax layers.""" 16 | 17 | from collections.abc import Callable 18 | import functools 19 | import string 20 | from typing import Any 21 | from flax import linen as nn 22 | import jax 23 | from jax import numpy as jnp 24 | import numpy as np 25 | 26 | Array = jax.Array 27 | ActivationFunc = Callable[[Array], Array] 28 | Initializer = nn.initializers.Initializer 29 | 30 | default_kernel_init = nn.initializers.lecun_normal() 31 | gelu = functools.partial(jax.nn.gelu, approximate=False) 32 | 33 | 34 | def identity(x: Array) -> Array: 35 | """Identity activation.""" 36 | return x 37 | 38 | 39 | def _get_large_negative_number(dtype: jax.typing.DTypeLike) -> Array: 40 | """Returns a large-magnitude negative value for the given dtype.""" 41 | # -0.7 is a float64 in JAX. Explicit cast output to target dtype. 42 | if jnp.issubdtype(dtype, jnp.inexact): 43 | dtype_max = jnp.finfo(dtype).max 44 | elif jnp.issubdtype(dtype, jnp.integer): 45 | dtype_max = jnp.iinfo(dtype).max 46 | else: 47 | raise ValueError('Unsupported dtype for inputs.') 48 | return jnp.asarray(-0.7 * dtype_max, dtype=dtype) 49 | 50 | 51 | def _apply_mask_to_logits(logits: Array, mask: Array) -> Array: 52 | """Applies a floating-point mask to a set of logits. 53 | 54 | The mask is represented as a float32 tensor where 0 represents true and values 55 | below a large negative number (here set to 56 | _get_large_negative_number(jnp.float32) / 2) represent false. Applying the 57 | mask leaves the logits alone in the true case and replaces them by 58 | _get_large_negative_number(jnp.float32) in the false case. Previously, this 59 | was done by adding the logits to the mask; however, this leads to a bad fusion 60 | decision in the compiler that saves the float32 values in memory rather than 61 | just the predicate. This implementation avoids that problem. 62 | 63 | Args: 64 | logits: A jax.Array of logit values. 
65 | mask: A jax.Array (float32) of mask values with the encoding described in 66 | the function documentation. 67 | 68 | Returns: 69 | Masked logits. 70 | """ 71 | min_value = _get_large_negative_number(logits.dtype) 72 | return jnp.where((mask >= min_value * 0.5), logits, min_value) 73 | 74 | 75 | def _convert_paddings_to_mask( 76 | paddings: Array, dtype: jax.typing.DTypeLike = jnp.float32 77 | ) -> Array: 78 | """Converts binary paddings to a logit mask ready to add to attention matrix. 79 | 80 | Args: 81 | paddings: A binary jax.Array of shape [B, T], with 1 denoting padding token. 82 | dtype: Data type of the input. 83 | 84 | Returns: 85 | A jax.Array of shape [B, 1, 1, T] ready to be added to attention logits. 86 | """ 87 | attention_mask = paddings[:, jnp.newaxis, jnp.newaxis, :] 88 | attention_mask *= _get_large_negative_number(dtype) 89 | return attention_mask 90 | 91 | 92 | def _causal_mask(input_t: Array) -> Array: 93 | """Computes and returns causal mask. 94 | 95 | Args: 96 | input_t: A jax.Array of shape [B, T, D]. 97 | 98 | Returns: 99 | An attention_mask jax.Array of shape [1, 1, T, T]. Attention mask has 100 | already been converted large negative values. 101 | """ 102 | assert jnp.issubdtype(input_t.dtype, jnp.floating), input_t.dtype 103 | large_negative_number = _get_large_negative_number(input_t.dtype) 104 | t = input_t.shape[-2] 105 | col_idx = jnp.tile(jnp.arange(t)[jnp.newaxis, :], [t, 1]) 106 | row_idx = jnp.tile(jnp.arange(t)[:, jnp.newaxis], [1, t]) 107 | mask = (row_idx < col_idx).astype(input_t.dtype) * large_negative_number 108 | return mask[jnp.newaxis, jnp.newaxis, :, :] 109 | 110 | 111 | def _merge_masks(a: Array, b: Array) -> Array: 112 | """Merges two masks. 113 | 114 | This function merges two masks with the same shape, where the smaller value 115 | will be chosen at the same position. Log-scale mask is expected but 0/1 mask 116 | is also fine. 117 | 118 | Args: 119 | a: A jax.Array of shape [1|B, 1, 1|T, S]. 120 | b: A jax.Array of shape [1|B, 1, 1|T, S]. 121 | 122 | Returns: 123 | A jax.Array of shape [1|B, 1, 1|T, S]. 124 | """ 125 | 126 | def expand_t(key_mask): 127 | """Expands the 1D mask to the 2D mask. 128 | 129 | Given [[1, 1, 0, 0]], this function returns the following mask, 130 | 1 1 0 0 131 | 1 1 0 0 132 | 0 0 0 0 133 | 0 0 0 0 134 | 135 | Args: 136 | key_mask: A jax.Array of the input 1D mask. 137 | 138 | Returns: 139 | A jax.Array of the expanded 2D mask. 140 | """ 141 | query_mask = jnp.transpose(key_mask, [0, 1, 3, 2]) 142 | return jnp.minimum(query_mask, key_mask) 143 | 144 | if a.shape[-2] != b.shape[-2]: 145 | if a.shape[-2] == 1: 146 | a = expand_t(a) 147 | else: 148 | assert b.shape[-2] == 1 149 | b = expand_t(b) 150 | 151 | assert a.shape[-3:] == b.shape[-3:], f'a.shape={a.shape}, b.shape={b.shape}.' 152 | return jnp.minimum(a, b) 153 | 154 | 155 | def compute_attention_masks_for_fprop( 156 | inputs: Array, 157 | paddings: Array, 158 | causal_attention: bool = False, 159 | ) -> Array: 160 | """Computes attention mask from inputs and paddings for fprop. 161 | 162 | Args: 163 | inputs: Input sequence jax.Array of shape [B, T, H]. 164 | paddings: Input paddings jax.Array of shape [B, T]. 165 | causal_attention: Boolean to apply causal masking. 166 | 167 | Returns: 168 | attention_mask: Attention mask jax.Array ready to be added to logits for 169 | self-attention of shape [1|B, 1, 1|T, T]. 170 | """ 171 | # Get paddings mask to [B, 1, 1, T]. 
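  # [Editor's note] Worked example of the masking convention used here
  # (illustrative, not part of the original source): for paddings
  # [[0., 0., 1.]], _convert_paddings_to_mask returns a [1, 1, 1, 3] mask of
  # roughly [0, 0, -2.4e38]; _merge_masks then keeps the elementwise minimum
  # with the causal mask, so a position stays attendable only if it is
  # neither padded nor in the future.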
172 | attention_mask = _convert_paddings_to_mask(paddings, inputs.dtype) 173 | 174 | # Causal mask of shape [1, 1, T, T]. 175 | if causal_attention: 176 | causal_mask = _causal_mask(inputs) 177 | attention_mask = _merge_masks(attention_mask, causal_mask) 178 | 179 | return attention_mask 180 | 181 | 182 | class Module(nn.Module): 183 | """Base class for layers with dtype configured. 184 | 185 | Attributes: 186 | dtype: Default dtype for all variables. 187 | fprop_dtype: Activations dtype to use. 188 | """ 189 | 190 | dtype: jnp.dtype = jnp.float32 191 | fprop_dtype: jnp.dtype = jnp.float32 192 | 193 | @nn.nowrap 194 | def _cast_to_fprop_dtype(self, value: Any) -> Any: 195 | """Casts values to the desired dtype.""" 196 | 197 | def _cast(x): 198 | if x is None: 199 | return None 200 | if self.fprop_dtype != x.dtype: 201 | if jnp.issubdtype(x.dtype, jnp.floating): 202 | return x.astype(self.fprop_dtype) 203 | return x 204 | 205 | return jax.tree_util.tree_map(_cast, value) 206 | 207 | 208 | class LayerNorm(Module): 209 | """Layer normalization. 210 | 211 | Attributes: 212 | direct_scale: Whether to apply scale directly without a +1.0. Var is 213 | initialized to 1.0 instead when True. 214 | epsilon: Tiny value to guard rsqrt. 215 | use_scale: Whether to use a learned scaling. 216 | use_bias: Whether to use bias. 217 | reductions_in_fp32: Whether to compute mean and variance in fp32. 218 | Recommended for stable training on GPUs. 219 | """ 220 | 221 | direct_scale: bool = False 222 | epsilon: float = 1e-6 223 | use_scale: bool = True 224 | use_bias: bool = True 225 | reductions_in_fp32: bool = False 226 | 227 | @nn.compact 228 | def __call__(self, inputs: Array) -> Array: 229 | """Applies layer norm to inputs. 230 | 231 | Args: 232 | inputs: A jax.Array for the inputs of shape [..., dim]. 233 | 234 | Returns: 235 | A jax.Aray for the normalized inputs of the same shape. 236 | """ 237 | inputs_dtype = inputs.dtype 238 | if self.reductions_in_fp32: 239 | inputs = inputs.astype(jnp.float32) 240 | mean = jnp.mean(inputs, axis=[-1], keepdims=True) 241 | var = jnp.mean(jnp.square(inputs - mean), axis=[-1], keepdims=True) 242 | normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon) 243 | if self.reductions_in_fp32: 244 | normed_inputs = normed_inputs.astype(inputs_dtype) 245 | 246 | input_dim = inputs.shape[-1] 247 | if self.use_scale: 248 | init_value = 1.0 if self.direct_scale else 0.0 249 | scale = self._cast_to_fprop_dtype( 250 | self.param( 251 | 'scale', 252 | nn.initializers.constant(init_value), 253 | [input_dim], 254 | self.dtype, 255 | ) 256 | ) 257 | if not self.direct_scale: 258 | scale += 1.0 259 | normed_inputs *= scale 260 | if self.use_bias: 261 | bias = self._cast_to_fprop_dtype( 262 | self.param( 263 | 'bias', 264 | nn.initializers.zeros_init(), 265 | [input_dim], 266 | self.dtype, 267 | ) 268 | ) 269 | normed_inputs += bias 270 | return normed_inputs 271 | 272 | 273 | class FeedForward(Module): 274 | """Feedforward layer with activation. 275 | 276 | Attributes: 277 | output_dim: Depth of the output. 278 | has_bias: Adds bias weights or not. 279 | activation_fn: Activation function to use. 280 | weight_init: Initializer function for the weight matrix. 281 | bias_init: Initializer function for the bias. 
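  Example (editor-added illustrative sketch):

    ffn = FeedForward(output_dim=8)
    x = jnp.ones((2, 4))
    variables = ffn.init(jax.random.PRNGKey(0), x)
    y = ffn.apply(variables, x)  # [2, 8], after the default ReLU activation.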
282 | """ 283 | 284 | output_dim: int = 0 285 | has_bias: bool = True 286 | activation_fn: ActivationFunc = nn.relu 287 | weight_init: Initializer = default_kernel_init 288 | bias_init: Initializer = nn.initializers.zeros_init() 289 | 290 | @nn.compact 291 | def __call__(self, inputs: Array) -> Array: 292 | 293 | def _promote_dtype(x, kernel, bias, dtype): 294 | """Promotes the dtype of the arrays to the desired dtype.""" 295 | del dtype 296 | # To be compatible with other layers, we do not promote the inputs as they 297 | # are expected to be in the `fprop_dtype`. 298 | return ( 299 | x, 300 | self._cast_to_fprop_dtype(kernel), 301 | self._cast_to_fprop_dtype(bias), 302 | ) 303 | 304 | projected_inputs = nn.Dense( 305 | self.output_dim, 306 | use_bias=self.has_bias, 307 | kernel_init=self.weight_init, 308 | bias_init=self.bias_init, 309 | name='linear', 310 | param_dtype=self.dtype, 311 | promote_dtype=_promote_dtype, 312 | )(inputs) 313 | return self.activation_fn(projected_inputs) 314 | 315 | 316 | class TransformerFeedForward(Module): 317 | """Transformer feedforward layer with residual connection and dropout. 318 | 319 | Attributes: 320 | output_dim: Depth of the output. The value of input_dim will be used when 321 | output_dim is 0. Must be equal to input_dim if add_skip_connection=True. 322 | hidden_dim: Hidden dimension of FFN. 323 | has_bias: Adds bias weights to Feedforward or not. 324 | activation_fn: Activation function to use. 325 | residual_dropout_prob: Residual dropout. 326 | relu_dropout_prob: FFN dropout. 327 | add_skip_connection: Whether to add residual connection. 328 | residual_weight: Weight of the residual connection. Output = fn(x) * 329 | residual_weight + x. 330 | norm_policy: Policy for applying normalization wrt. transformations. Options 331 | are: (1) "pre", applied before transformation. (2) "primer_hybrid", 332 | applied before and after transformation. (3) "post", applied after 333 | transformation, (4) "post_skip", applied after the skip connection. 
334 | """ 335 | 336 | output_dim: int = 0 337 | hidden_dim: int = 0 338 | has_bias: bool = True 339 | activation_fn: ActivationFunc = nn.relu 340 | residual_dropout_prob: float = 0.0 341 | relu_dropout_prob: float = 0.0 342 | add_skip_connection: bool = True 343 | residual_weight: float = 1.0 344 | norm_policy: str = 'pre' 345 | 346 | @nn.nowrap 347 | def _make_ln(self, name: str) -> LayerNorm: 348 | """Makes a LayerNorm module.""" 349 | return LayerNorm( 350 | name=name, 351 | use_bias=self.has_bias, 352 | dtype=self.dtype, 353 | fprop_dtype=self.fprop_dtype, 354 | ) 355 | 356 | @nn.nowrap 357 | def _make_ffn( 358 | self, output_dim: int, name: str, skip_activation: bool = False 359 | ) -> FeedForward: 360 | """Makes a FeedForward module.""" 361 | return FeedForward( 362 | name=name, 363 | output_dim=output_dim, 364 | has_bias=self.has_bias, 365 | activation_fn=identity if skip_activation else self.activation_fn, 366 | dtype=self.dtype, 367 | fprop_dtype=self.fprop_dtype, 368 | ) 369 | 370 | @nn.compact 371 | def __call__( 372 | self, inputs: Array, paddings: Array | None, train: bool 373 | ) -> Array: 374 | residual = inputs 375 | output_dim = self.output_dim 376 | if output_dim == 0: 377 | output_dim = inputs.shape[-1] 378 | if self.add_skip_connection and output_dim != inputs.shape[-1]: 379 | raise ValueError( 380 | 'Skip connections are only supported when input_dim == output_dim ' 381 | f'but got {self.input_dim} != {output_dim}' 382 | ) 383 | 384 | # Expand paddings to last dim if not None to have shape [batch, seq_len, 1]. 385 | if paddings is not None: 386 | paddings = jnp.expand_dims(paddings, axis=-1) 387 | 388 | if self.norm_policy == 'primer_hybrid': 389 | inputs = self._make_ln(name='pre_layer_norm')(inputs) 390 | elif self.norm_policy == 'pre': 391 | inputs = self._make_ln(name='layer_norm')(inputs) 392 | 393 | # Apply first FFN layer. 394 | activations = self._make_ffn(self.hidden_dim, name='ffn_layer1')(inputs) 395 | 396 | # Apply paddings if not None. 397 | if paddings is not None: 398 | activations *= 1.0 - paddings 399 | 400 | # Apply RELU dropout. 401 | activations = nn.Dropout(self.relu_dropout_prob, name='relu_dropout')( 402 | activations, deterministic=not train 403 | ) 404 | # Apply second FFN layer. 405 | outputs = self._make_ffn( 406 | output_dim, name='ffn_layer2', skip_activation=True 407 | )(activations) 408 | 409 | # Apply paddings if not None. 410 | if paddings is not None: 411 | outputs *= 1.0 - paddings 412 | 413 | # Apply Primer normalization before dropout. 414 | if self.norm_policy == 'primer_hybrid': 415 | outputs = self._make_ln(name='post_layer_norm')(outputs) 416 | elif self.norm_policy == 'post': 417 | outputs = self._make_ln(name='layer_norm')(outputs) 418 | 419 | # Apply residual dropout. 420 | outputs = nn.Dropout(self.residual_dropout_prob, name='residual_dropout')( 421 | outputs, deterministic=not train 422 | ) 423 | # Apply skip connection. 424 | if self.add_skip_connection: 425 | outputs = residual + outputs * self.residual_weight 426 | 427 | if self.norm_policy == 'post_skip': 428 | outputs = self._make_ln(name='layer_norm')(outputs) 429 | 430 | return outputs 431 | 432 | 433 | class AttentionProjection(Module): 434 | """Layer that computes multi heads projection. 435 | 436 | This layer is expected to be used within DotProductAttention below. 437 | 438 | Attributes: 439 | output_dim: Input dimension. 440 | num_heads: Number of attention heads. 441 | dim_per_head: Size of each head. 
442 | is_output_projection: Whether it is out projection or not. If False, we use 443 | "...D,DNH->...NH" for query,key,value projection. Otherwise we use 444 | "...NH,DNH->...D" for output projection. 445 | use_bias: Whether to add bias in projection or not. 446 | """ 447 | 448 | output_dim: int = 0 449 | num_heads: int = 0 450 | dim_per_head: int = 0 451 | is_output_projection: bool = False 452 | use_bias: bool = True 453 | 454 | @nn.compact 455 | def __call__(self, inputs: Array) -> Array: 456 | """Computes the multi headed projection for inputs. 457 | 458 | Args: 459 | inputs: A jax.Array with shape [..., num_heads, dim_per_head] if 460 | is_output_projection is True or [..., input_dim] otherwise. 461 | 462 | Returns: 463 | The projected jax.Array with shape [..., input_dim] if 464 | is_output_projection is True or [..., num_heads, dim_per_head] 465 | otherwise. 466 | """ 467 | # Sort the available symbols to avoid nondeterminism. 468 | eqn_sym = ''.join(sorted(set(string.ascii_uppercase) - set('DHN'))) 469 | output_dim = ( 470 | self.output_dim if self.is_output_projection else inputs.shape[-1] 471 | ) 472 | rank = len(inputs.shape) 473 | 474 | hd_shape = [self.num_heads, self.dim_per_head] 475 | pc_shape = [output_dim] + hd_shape 476 | w = self._cast_to_fprop_dtype( 477 | self.param('w', default_kernel_init, pc_shape, self.dtype) 478 | ) 479 | 480 | if self.is_output_projection: 481 | assert inputs.shape[-2:] == (self.num_heads, self.dim_per_head) 482 | batch_eqn = eqn_sym[: (rank - 2)] 483 | eqn = f'{batch_eqn}NH,DNH->{batch_eqn}D' 484 | else: 485 | batch_eqn = eqn_sym[: (rank - 1)] if rank else '...' 486 | eqn = f'{batch_eqn}D,DNH->{batch_eqn}NH' 487 | 488 | ret = jnp.einsum(eqn, inputs, w) 489 | if self.use_bias: 490 | b = self._cast_to_fprop_dtype( 491 | self.param( 492 | 'b', 493 | nn.initializers.zeros_init(), 494 | [output_dim] if self.is_output_projection else hd_shape, 495 | self.dtype, 496 | ) 497 | ) 498 | ret += b 499 | return ret 500 | 501 | 502 | class PerDimScale(Module): 503 | """A layer to scale individual dimensions of the input.""" 504 | 505 | @nn.compact 506 | def __call__(self, inputs: Array) -> Array: 507 | """Returns per_dim_scale * inputs / jnp.sqrt(dim)). 508 | 509 | Args: 510 | inputs: A jax.Array with shape [..., dim]. 511 | 512 | Returns: 513 | outputs: A jax.Array with shape [..., dim]. 514 | """ 515 | dim = inputs.shape[-1] 516 | per_dim_scale = self._cast_to_fprop_dtype( 517 | self.param( 518 | 'per_dim_scale', nn.initializers.zeros_init(), [dim], self.dtype 519 | ) 520 | ) 521 | 522 | # 1.0/jax.nn.softplus(0.0) = 1.442695041. Hard code this number so that we 523 | # can avoid unnecessary XLA op fusion mess on TPU. 524 | r_softplus_0 = 1.442695041 525 | scale = jnp.array(r_softplus_0 / np.sqrt(dim), dtype=self.fprop_dtype) 526 | scale *= jax.nn.softplus(per_dim_scale) 527 | return inputs * scale 528 | 529 | 530 | class DotProductAttention(Module): 531 | """Dot-product attention with multiple attention heads. 532 | 533 | Attributes: 534 | hidden_dim: Number of hidden nodes. 535 | num_heads: Number of attention heads. 536 | dim_per_head: Dimension of each attention head. If None then dim_per_head == 537 | hidden_dim // num_heads. 538 | atten_dropout_prob: Probability at which we apply dropout to the attention 539 | weights. 540 | use_bias: Whether to use bias for projection layers. 541 | internal_enable_query_scale: Internal. Enable scaling of query vector. 542 | internal_enable_per_dim_scale: Internal. 
Setting to False disables rescaling 543 | of attention logits with 1/sqrt(dim) factor. Some Transformer variants 544 | (GShard, T5) use internal_enable_per_dim_scale=False and adjust 545 | initialization of the linear transformations(einsums), in conjunction with 546 | Adafactor optimizer. 547 | scale_query_by_dim_per_head: whether to scale the query by dim_per_head, 548 | instead of default hidden_dim // num_heads (only activated when 549 | internal_enable_per_dim_scale = False). 550 | scale_logits_by_head_dims: Enables a 1/sqrt(head dim) scaling to the logits. 551 | This occurs prior to logit cap, if any. 552 | atten_logit_cap: Cap the absolute values of logits by tanh. Enabled when a 553 | positive value is specified. May not be supported by a subclass. 554 | use_qk_norm: If QK norm is used. 555 | """ 556 | 557 | hidden_dim: int = 0 558 | num_heads: int = 1 559 | dim_per_head: int | None = None 560 | atten_dropout_prob: float = 0.0 561 | use_bias: bool = True 562 | internal_enable_query_scale: bool = True 563 | internal_enable_per_dim_scale: bool = True 564 | scale_query_by_dim_per_head: bool = False 565 | scale_logits_by_head_dims: bool = False 566 | atten_logit_cap: float = 0.0 567 | use_qk_norm: bool = False 568 | 569 | def _scale_query(self, query: Array) -> Array: 570 | """Scales the query vector if enabled.""" 571 | if not self.internal_enable_query_scale: 572 | return query 573 | if self.internal_enable_per_dim_scale: 574 | query = PerDimScale( 575 | name='per_dim_scale', dtype=self.dtype, fprop_dtype=self.fprop_dtype 576 | )(query) 577 | else: 578 | if self.scale_query_by_dim_per_head and self.dim_per_head is not None: 579 | dim_per_head = self.dim_per_head 580 | else: 581 | dim_per_head = self.hidden_dim // self.num_heads 582 | 583 | query *= dim_per_head**-0.5 584 | return query 585 | 586 | def _cap_logits(self, logits: Array) -> Array: 587 | """Caps the logits by p.atten_logit_cap with tanh, if enabled.""" 588 | if not self.atten_logit_cap or self.atten_logit_cap <= 0.0: 589 | return logits 590 | cap = jnp.array(self.atten_logit_cap, dtype=self.fprop_dtype) 591 | # Note that since this caps the negative side as well, caller must defer the 592 | # pad-with-very-negative-logits logic to after this function returns. 593 | logits = cap * jnp.tanh(logits / cap) 594 | return logits 595 | 596 | def _atten_logits(self, query: Array, key: Array) -> Array: 597 | """Computes logits from query and key.""" 598 | logits = jnp.einsum('BTNH,BSNH->BNTS', query, key) 599 | return logits 600 | 601 | def _dot_atten( 602 | self, 603 | query: Array, 604 | key: Array, 605 | value: Array, 606 | atten_mask: Array, 607 | train: bool, 608 | ) -> tuple[Array, Array]: 609 | """Main attention function. 610 | 611 | Args: 612 | query: A jax.Array of shape [B, T, N, H]. 613 | key: A jax.Array of shape [B, S, N, H]. 614 | value: A jax.Array of shape [B, S, N, H]. 615 | atten_mask: A jax.Array of shape [1|B, 1, 1|T, S] which is a mask that is 616 | applied to prevent attention between unwanted pairs. This has already 617 | been converted into large negative logits. Note that the first and third 618 | dimension allow size 1 if the mask is shared by every item in the batch 619 | or every token in the target sequence. 620 | train: Whether the model is in the train mode. 621 | 622 | Returns: 623 | encoded: A jax.Array of shape [B, T, N, H]. 624 | atten_probs: A jax.Array of shape [B, N, T, S]. 625 | """ 626 | assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' 
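    # [Editor's note] Shape sketch for the attention math below (illustrative):
    # with B=2, T=S=5, N=8 heads and H=64, 'BTNH,BSNH->BNTS' yields logits of
    # shape (2, 8, 5, 5); the softmax over the source axis is computed in
    # float32 and cast back to `fprop_dtype` before the 'BNTS,BSNH->BTNH'
    # contraction with the values.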
627 | assert ( 628 | query.shape[:-3] == key.shape[:-3] == value.shape[:-3] 629 | ), 'q, k, v batch dims must match.' 630 | assert ( 631 | query.shape[-2] == key.shape[-2] == value.shape[-2] 632 | ), 'q, k, v num_heads must match.' 633 | assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' 634 | # If only padding bias is supplied, then atten_mask can be [B, 1, 1, S] 635 | # since each target token is prohibited from attending to the same set of 636 | # source tokens. In this case tiling is inefficient and unnecessary. 637 | # If there is no padding mask, and only causal mask then the shape can be 638 | # [1, 1, T, S]. 639 | assert atten_mask.ndim == 4 and atten_mask.shape[-1] == key.shape[-3] 640 | assert atten_mask.shape[-2] in [query.shape[-3], 1] 641 | assert atten_mask.shape[0] in [key.shape[0], 1] 642 | 643 | query = self._scale_query(query) 644 | logits = self._atten_logits(query, key) 645 | 646 | if self.scale_logits_by_head_dims: 647 | logits = jnp.multiply(logits, 1.0 / np.sqrt(key.shape[-1])) 648 | 649 | logits = self._cap_logits(logits) 650 | # Attention softmax is always carried out in fp32. 651 | logits = logits.astype(jnp.float32) 652 | # Apply attention masking. 653 | padded_logits = _apply_mask_to_logits(logits, atten_mask) 654 | probs = jax.nn.softmax(padded_logits, axis=-1).astype(self.fprop_dtype) 655 | # Apply attention dropout. 656 | probs = nn.Dropout(self.atten_dropout_prob, name='atten_dropout')( 657 | probs, deterministic=not train 658 | ) 659 | # Compute the attention context. 660 | encoded = jnp.einsum('BNTS,BSNH->BTNH', probs, value) 661 | return encoded, probs 662 | 663 | @nn.nowrap 664 | def _project_input(self, name: str, dim_per_head: int) -> AttentionProjection: 665 | """Builds an AttentionProjection module.""" 666 | return AttentionProjection( 667 | name=name, 668 | num_heads=self.num_heads, 669 | dim_per_head=dim_per_head, 670 | use_bias=self.use_bias, 671 | dtype=self.dtype, 672 | fprop_dtype=self.fprop_dtype, 673 | ) 674 | 675 | @nn.nowrap 676 | def _make_ln(self, name: str) -> LayerNorm: 677 | """Makes a LayerNorm module.""" 678 | return LayerNorm( 679 | name=name, 680 | use_bias=self.use_bias, 681 | dtype=self.dtype, 682 | fprop_dtype=self.fprop_dtype, 683 | ) 684 | 685 | @nn.compact 686 | def __call__( 687 | self, 688 | query_vec: Array, 689 | key_vec: Array, 690 | value_vec: Array, 691 | atten_mask: Array, 692 | train: bool, 693 | ) -> tuple[Array, Array]: 694 | """Computes the value vector given the current query output. 695 | 696 | Args: 697 | query_vec: jax.Array of shape [B, T, D]. 698 | key_vec: jax.Array of shape [B, S, D]. 699 | value_vec: jax.Array of shape [B, S, D]. 700 | atten_mask: jax.Array of shape [1|B, 1, 1|T, S] which is a mask that is 701 | applied to prevent attention between unwanted pairs. This has already 702 | been converted into large negative logits. Note that the first and third 703 | dimension allow size 1 if the mask is shared by every item in the batch 704 | or every token in the target sequence. 705 | train: If the model is in the train mode. 706 | 707 | Returns: 708 | encoded: jax.Array of shape [B, T, D]. 709 | atten_probs: jax.Array of shape [B, N, T, S]. 
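    Example (editor-added illustrative sketch with toy shapes):

      atten = DotProductAttention(hidden_dim=64, num_heads=4)
      x = jnp.zeros((2, 5, 64))
      mask = jnp.zeros((1, 1, 1, 5))  # All positions attendable.
      variables = atten.init(jax.random.PRNGKey(0), x, x, x, mask, False)
      encoded, probs = atten.apply(variables, x, x, x, mask, False)
      # encoded: [2, 5, 64], probs: [2, 4, 5, 5].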
710 | """ 711 | dim_per_head = self.dim_per_head 712 | if dim_per_head is None: 713 | dim_per_head = self.hidden_dim // self.num_heads 714 | assert ( 715 | dim_per_head * self.num_heads == self.hidden_dim 716 | ), f'{dim_per_head} * {self.num_heads} != {self.hidden_dim}' 717 | 718 | # Project inputs to key, value and query, respectively has shape 719 | # [B, S, N, H], [B, S, N, H], and [B, T, N, H]. 720 | query_proj = self._project_input('query', dim_per_head)(query_vec) 721 | key_proj = self._project_input('key', dim_per_head)(key_vec) 722 | value_proj = self._project_input('value', dim_per_head)(value_vec) 723 | 724 | if self.use_qk_norm: 725 | query_proj = self._make_ln(name='layer_norm_q')(query_proj) 726 | key_proj = self._make_ln(name='layer_norm_k')(key_proj) 727 | 728 | encoded, atten_probs = self._dot_atten( 729 | query_proj, key_proj, value_proj, atten_mask, train=train 730 | ) 731 | 732 | # Post projection. Setting is_output_projection=True to set the projection 733 | # direction from hidden dim to input dim. Output projection follows 734 | # query_input_dim. 735 | query_input_dim = query_vec.shape[-1] 736 | encoded = AttentionProjection( 737 | name='post', 738 | output_dim=query_input_dim, 739 | num_heads=self.num_heads, 740 | dim_per_head=dim_per_head, 741 | is_output_projection=True, 742 | use_bias=self.use_bias, 743 | dtype=self.dtype, 744 | fprop_dtype=self.fprop_dtype, 745 | )(encoded) 746 | return encoded, atten_probs 747 | 748 | 749 | class Transformer(Module): 750 | """Transformer layer with multi-headed attention. 751 | 752 | Attributes: 753 | hidden_dim: Hidden dimension of FFN layer. 754 | num_heads: Number of heads in self-attention. 755 | dim_per_head: Dimension of each attention head. If None then dim_per_head == 756 | hidden_dim // num_heads. 757 | atten_dropout_prob: Probability at which we apply dropout to the attention 758 | weights. 759 | residual_dropout_prob: Probability at which we apply dropout to the residual 760 | layers, such that, residual(x, y) = (x + dropout(y)). 761 | relu_dropout_prob: Probability at which we apply dropout to the FFN layers. 762 | norm_policy: Policy for applying normalization wrt. transformations. Options 763 | are: (1) "pre", applied before transformation. (2) "primer_hybrid", 764 | applied before and after transformation. (3) "post", applied after 765 | transformation. (4) "post_skip", applied after the skip connection. 766 | use_bias: Whether to use bias. 767 | activation_fn: Activation function to use. 768 | internal_enable_per_dim_scale: Internal. Setting to False disables rescaling 769 | of attention logits with 1/sqrt(dim) factor. 770 | atten_logit_cap: Cap the absolute values of logits by tanh. Enabled when a 771 | positive value is specified. May not be supported by a subclass. 
772 | """ 773 | 774 | hidden_dim: int = 0 775 | num_heads: int = 0 776 | dim_per_head: int | None = None 777 | atten_dropout_prob: float = 0.0 778 | residual_dropout_prob: float = 0.0 779 | relu_dropout_prob: float = 0.0 780 | norm_policy: str = 'pre' 781 | use_bias: bool = True 782 | activation_fn: ActivationFunc = nn.relu 783 | internal_enable_per_dim_scale: bool = True 784 | atten_logit_cap: float = 0.0 785 | 786 | @nn.nowrap 787 | def _make_ln(self, name: str) -> LayerNorm: 788 | """Makes a LayerNorm module.""" 789 | return LayerNorm( 790 | name=name, 791 | use_bias=self.use_bias, 792 | dtype=self.dtype, 793 | fprop_dtype=self.fprop_dtype, 794 | ) 795 | 796 | @nn.compact 797 | def __call__( 798 | self, 799 | inputs: Array, 800 | paddings: Array, 801 | atten_mask: Array, 802 | train: bool, 803 | ) -> Array: 804 | """Transformer decoder layer. 805 | 806 | Args: 807 | inputs: Input sequence jax.Array of shape [B, T, H]. 808 | paddings: Input paddings jax.Array of shape [B, T] (only used in FFN). 809 | atten_mask: Self attention mask ready to add to the logits. It can be of 810 | shape [1|B, 1, 1|T, T] which is broadcast compatible with the 811 | self-attention matrix of shape [B, N, T, T]. This is assumed to have 812 | combined paddings, causal masking as well as segment maskings. 813 | train: Whether the model is in the train mode. 814 | 815 | Returns: 816 | The fflayer output with shape [B, T, D]. 817 | """ 818 | 819 | if self.norm_policy == 'primer_hybrid': 820 | inputs_normalized = self._make_ln(name='pre_layer_norm')(inputs) 821 | elif self.norm_policy == 'pre': 822 | inputs_normalized = self._make_ln(name='layer_norm')(inputs) 823 | else: 824 | inputs_normalized = inputs 825 | 826 | # Compute self-attention, key/value vectors are the input itself. 827 | atten_outputs, _ = DotProductAttention( 828 | name='self_attention', 829 | hidden_dim=inputs_normalized.shape[-1], 830 | num_heads=self.num_heads, 831 | dim_per_head=self.dim_per_head, 832 | atten_dropout_prob=self.atten_dropout_prob, 833 | use_bias=self.use_bias, 834 | internal_enable_per_dim_scale=self.internal_enable_per_dim_scale, 835 | atten_logit_cap=self.atten_logit_cap, 836 | dtype=self.dtype, 837 | fprop_dtype=self.fprop_dtype, 838 | )( 839 | inputs_normalized, 840 | inputs_normalized, 841 | inputs_normalized, 842 | atten_mask=atten_mask, 843 | train=train, 844 | ) 845 | 846 | if self.norm_policy == 'primer_hybrid': 847 | atten_outputs = self._make_ln(name='post_layer_norm')(atten_outputs) 848 | elif self.norm_policy == 'post': 849 | atten_outputs = self._make_ln(name='layer_norm')(atten_outputs) 850 | 851 | # Residual dropout and connection. 852 | atten_outputs = nn.Dropout( 853 | self.residual_dropout_prob, name='residual_dropout' 854 | )(atten_outputs, deterministic=not train) 855 | atten_outputs += inputs 856 | 857 | if self.norm_policy == 'post_skip': 858 | atten_outputs = self._make_ln(name='layer_norm')(atten_outputs) 859 | 860 | # Apply FFN layer. 861 | outputs = TransformerFeedForward( 862 | name='ff_layer', 863 | hidden_dim=self.hidden_dim, 864 | has_bias=self.use_bias, 865 | activation_fn=self.activation_fn, 866 | residual_dropout_prob=self.residual_dropout_prob, 867 | relu_dropout_prob=self.relu_dropout_prob, 868 | norm_policy=self.norm_policy, 869 | dtype=self.dtype, 870 | fprop_dtype=self.fprop_dtype, 871 | )(atten_outputs, paddings=paddings, train=train) 872 | return outputs 873 | 874 | 875 | class Repeat(nn.Module): 876 | """A generic repeat layer with `nn.remat` and`nn.scan`. 
877 | 878 | Attributes: 879 | block_fn: The block function to repeat. 880 | times: The number of times to repeat block. 881 | checkpoint_policy: Checkpoint policy for `nn.remat`. 882 | """ 883 | 884 | block_fn: Callable[..., Any] 885 | times: int = 0 886 | checkpoint_policy: str = 'nothing_saveable' 887 | 888 | def __call__( 889 | self, 890 | inputs: Array, 891 | *args: Any, 892 | **kwargs: Any, 893 | ) -> Any: 894 | """Forwards inputs through the block layer stack. 895 | 896 | Block outputs are expected to be of the same structure as inputs. 897 | 898 | Args: 899 | inputs: A NestedMap of inputs that goes through the block layer stack. 900 | *args: Positional args to be passed to the forward method. 901 | **kwargs: Keyward args to be passed to the forward method. 902 | 903 | Returns: 904 | Output from the last layer. 905 | """ 906 | return self.call_with_custom_method( 907 | inputs, 908 | *args, 909 | main_fn=self.block_fn, 910 | **kwargs, 911 | ) 912 | 913 | def call_with_custom_method( 914 | self, 915 | inputs: Array, 916 | *args: Any, 917 | main_fn: Callable[..., Any], 918 | **kwargs: Any, 919 | ) -> Any: 920 | """Similar to __call__, but allows a custom way to create a layer method.""" 921 | 922 | def body_fn(fn, layer_inputs): 923 | return fn(layer_inputs, *args, **kwargs), None 924 | 925 | rematted_body_fn = nn.remat( 926 | body_fn, 927 | prevent_cse=False, 928 | policy=getattr(jax.checkpoint_policies, self.checkpoint_policy, None), 929 | ) 930 | scan_fn = nn.scan( 931 | rematted_body_fn, 932 | variable_axes={'params': 0}, 933 | split_rngs={'params': True, 'dropout': True}, 934 | length=self.times, 935 | ) 936 | outputs, _ = scan_fn(main_fn, inputs) 937 | return outputs 938 | 939 | 940 | class StackedTransformer(Module): 941 | """A stack of Transformer layers. 942 | 943 | Attributes: 944 | num_layers: Number of layers in this stack. 945 | hidden_dim: The hidden layer dimension of FFN in Transformer layers. 946 | num_heads: Number of attention heads. 947 | dim_per_head: Dimension of each attention head. If None then dim_per_head == 948 | model_dims // num_heads. 949 | dropout_prob: Apply dropout at this prob at various places. 950 | atten_dropout_prob: Probability at which we apply dropout to the attention 951 | weights. 952 | residual_dropout_prob: Probability at which we apply dropout to the residual 953 | layers, such that, residual(x, y) = (x + dropout(y)). 954 | relu_dropout_prob: Probability at which we apply dropout to the FFN layers. 955 | input_dropout_prob: Dropout probability applied to the input before any 956 | processing happens. 957 | norm_policy: Policy for applying normalization wrt. transformations. Options 958 | are: (1) "pre", applied before transformation. (2) "primer_hybrid", 959 | applied before and after transformation. (3) "post", applied after 960 | transformation. (4) "post_skip", applied after the skip connection. 961 | use_bias: Whether to use bias. 962 | activation_fn: Activation function to use. 963 | internal_enable_per_dim_scale: Internal. Setting to False disables rescaling 964 | of attention logits with 1/sqrt(dim) factor. 965 | atten_logit_cap: Cap the absolute values of logits by tanh. Enabled when a 966 | positive value is specified. May not be supported by a subclass. 967 | enable_causal_atten: Whether to enable causal attention. 968 | scan: Whether to use `nn.remat` and`nn.scan`. 
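  Example (editor-added illustrative sketch with toy shapes):

    stack = StackedTransformer(num_layers=2, hidden_dim=32, num_heads=2)
    x = jnp.zeros((2, 5, 8))
    paddings = jnp.zeros((2, 5))
    variables = stack.init(jax.random.PRNGKey(0), x, paddings, False)
    y = stack.apply(variables, x, paddings, False)  # [2, 5, 8]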
969 | """ 970 | 971 | num_layers: int = 0 972 | hidden_dim: int = 0 973 | num_heads: int = 0 974 | dim_per_head: int | None = None 975 | dropout_prob: float = 0.0 976 | atten_dropout_prob: float | None = None 977 | residual_dropout_prob: float | None = None 978 | relu_dropout_prob: float | None = None 979 | input_dropout_prob: float = 0.0 980 | norm_policy: str = 'pre' 981 | use_bias: bool = True 982 | activation_fn: ActivationFunc = nn.relu 983 | internal_enable_per_dim_scale: bool = True 984 | atten_logit_cap: float = 0.0 985 | enable_causal_atten: bool = False 986 | scan: bool = False 987 | 988 | @nn.compact 989 | def __call__( 990 | self, 991 | inputs: Array, 992 | paddings: Array, 993 | train: bool, 994 | ) -> Array: 995 | """Stacked Transformer layer. 996 | 997 | Args: 998 | inputs: Input sequence of shape [B, T, H]. 999 | paddings: Input paddings of shape [B, T]. 1000 | train: If the model is in the train mode. 1001 | 1002 | Returns: 1003 | Output vector with shape [B, T, D]. 1004 | """ 1005 | 1006 | atten_mask = compute_attention_masks_for_fprop( 1007 | inputs, paddings, causal_attention=self.enable_causal_atten 1008 | ) 1009 | 1010 | outputs = inputs 1011 | if self.input_dropout_prob > 0.0: 1012 | outputs = nn.Dropout(self.input_dropout_prob, name='input_dropout')( 1013 | outputs, deterministic=not train 1014 | ) 1015 | 1016 | transformer_kwargs = dict( 1017 | num_heads=self.num_heads, 1018 | dim_per_head=self.dim_per_head, 1019 | hidden_dim=self.hidden_dim, 1020 | atten_dropout_prob=self.atten_dropout_prob or self.dropout_prob, 1021 | residual_dropout_prob=self.residual_dropout_prob or self.dropout_prob, 1022 | relu_dropout_prob=self.relu_dropout_prob or self.dropout_prob, 1023 | norm_policy=self.norm_policy, 1024 | use_bias=self.use_bias, 1025 | activation_fn=self.activation_fn, 1026 | internal_enable_per_dim_scale=self.internal_enable_per_dim_scale, 1027 | atten_logit_cap=self.atten_logit_cap, 1028 | dtype=self.dtype, 1029 | fprop_dtype=self.fprop_dtype, 1030 | ) 1031 | if self.scan: 1032 | block_fn = Transformer(name='x_layers', **transformer_kwargs) 1033 | outputs = Repeat(block_fn=block_fn, times=self.num_layers)( 1034 | outputs, paddings, atten_mask, train 1035 | ) 1036 | else: 1037 | for i in range(self.num_layers): 1038 | outputs = Transformer(name=f'x_layers_{i}', **transformer_kwargs)( 1039 | outputs, paddings, atten_mask, train 1040 | ) 1041 | return outputs 1042 | 1043 | 1044 | class AttenTokenPoolingLayer(Module): 1045 | """Attentional token pooling layer. 1046 | 1047 | Attributes: 1048 | query_dim: The query dimension of attention. If None then query_dim == 1049 | input_dim. 1050 | hidden_dim: The hidden layer dimension of FFN in Transformer layers. 1051 | num_heads: Number of attention heads. 1052 | num_queries: Number of attention queries. 1053 | add_layer_norm: Whether to apply layer norm to the pooled tokens. 1054 | dropout_prob: The probability of dropout on the pooled tokens. 1055 | use_qk_norm: If QK norm is used. 1056 | use_bias: Whether to use bias. 1057 | internal_enable_per_dim_scale: Internal. Setting to False disables rescaling 1058 | of attention logits with 1/sqrt(dim) factor. 
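  Example (editor-added illustrative sketch with toy shapes):

    pooler = AttenTokenPoolingLayer(num_heads=2, num_queries=1)
    tokens = jnp.zeros((2, 7, 8))
    variables = pooler.init(jax.random.PRNGKey(0), tokens, None, False)
    pooled = pooler.apply(variables, tokens, None, False)  # [2, 1, 8]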
1059 | """ 1060 | 1061 | query_dim: int | None = None 1062 | hidden_dim: int = 0 1063 | num_heads: int = 1 1064 | num_queries: int = 1 1065 | add_layer_norm: bool = True 1066 | dropout_prob: float = 0.0 1067 | use_qk_norm: bool = False 1068 | use_bias: bool = True 1069 | internal_enable_per_dim_scale: bool = True 1070 | 1071 | @nn.compact 1072 | def __call__( 1073 | self, 1074 | tokens: Array, 1075 | paddings: Array | None, 1076 | train: bool, 1077 | ) -> Array: 1078 | """Computes the pooled tokens for inputs. 1079 | 1080 | Args: 1081 | tokens: Input tokens of shape [B, T, H]. 1082 | paddings: Input paddings of shape [B, T]. 1083 | train: If the model is in the train mode. 1084 | 1085 | Returns: 1086 | Output vector with shape [B, N, D]. 1087 | """ 1088 | input_dim = tokens.shape[-1] 1089 | query_dim = self.query_dim or input_dim 1090 | hidden_dim = self.hidden_dim if self.hidden_dim > 0 else 4 * input_dim 1091 | batch_size, seq_length = tokens.shape[0], tokens.shape[-2] 1092 | 1093 | query = self._cast_to_fprop_dtype( 1094 | self.param( 1095 | 'pooling_attention_query', 1096 | default_kernel_init, 1097 | [self.num_queries, query_dim], 1098 | self.dtype, 1099 | ) 1100 | ) 1101 | query = jnp.tile(query[jnp.newaxis, :, :], [batch_size, 1, 1]) 1102 | 1103 | if paddings is None: 1104 | paddings = jnp.zeros([batch_size, seq_length], dtype=tokens.dtype) 1105 | 1106 | atten_mask = _convert_paddings_to_mask(paddings, dtype=paddings.dtype) 1107 | outputs, _ = DotProductAttention( 1108 | name='pooling_attention', 1109 | hidden_dim=hidden_dim, 1110 | num_heads=self.num_heads, 1111 | use_bias=self.use_bias, 1112 | internal_enable_per_dim_scale=self.internal_enable_per_dim_scale, 1113 | use_qk_norm=self.use_qk_norm, 1114 | dtype=self.dtype, 1115 | fprop_dtype=self.fprop_dtype, 1116 | )( 1117 | query, 1118 | tokens, 1119 | tokens, 1120 | atten_mask=atten_mask, 1121 | train=train, 1122 | ) 1123 | 1124 | if self.add_layer_norm: 1125 | outputs = LayerNorm( 1126 | name='pooling_attention_layer_norm', 1127 | dtype=self.dtype, 1128 | fprop_dtype=self.fprop_dtype, 1129 | )(outputs) 1130 | 1131 | if self.dropout_prob > 0.0: 1132 | outputs = nn.Dropout(self.dropout_prob, name='attention_dropout')( 1133 | outputs, deterministic=not train 1134 | ) 1135 | 1136 | return outputs 1137 | --------------------------------------------------------------------------------