├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs └── images │ ├── block_diagram.png │ ├── overview.png │ └── qualitative.png ├── pytorch ├── TIPS_Demo.ipynb ├── __init__.py ├── checkpoints │ └── download_checkpoints.sh ├── image_encoder.py ├── run_image_encoder_inference.py ├── run_text_encoder_inference.py └── text_encoder.py └── scenic ├── checkpoints └── download_checkpoints.sh ├── configs └── tips_model_config.py ├── images ├── example_image.jpg └── example_image_2.jpg ├── models ├── text.py ├── tips.py └── vit.py ├── notebooks └── TIPS_Demo.ipynb ├── run_tips_inference.py └── utils ├── checkpoint.py └── feature_viz.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | ## Contributor License Agreement 4 | 5 | Contributions to this project must be accompanied by a Contributor License 6 | Agreement. You (or your employer) retain the copyright to your contribution, 7 | this simply gives us permission to use and redistribute your contributions as 8 | part of the project. Head over to to see 9 | your current agreements on file or to sign a new one. 10 | 11 | You generally only need to submit a CLA once, so if you've already submitted one 12 | (even if it was for a different project), you probably don't need to do it 13 | again. 14 | 15 | ## Code reviews 16 | 17 | All submissions, including submissions by project members, require review. We 18 | use GitHub pull requests for this purpose. Consult 19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 20 | information on using pull requests. 21 | 22 | ## Community Guidelines 23 | 24 | This project follows [Google's Open Source Community 25 | Guidelines](https://opensource.google/conduct/). 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 
23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TIPS: Text-Image Pretraining with Spatial awareness (ICLR 2025) 2 | 3 | This repository contains the implementation and models introduced in 4 | TIPS: Text-Image Pretraining with Spatial Awareness, published at ICLR 2025. 5 | 6 | **Quick Links:** 7 | [Paper](https://arxiv.org/abs/2410.16512) | 8 | [Project Website](https://gdm-tips.github.io) | 9 | [Pytorch Notebook](./pytorch/TIPS_Demo.ipynb) | 10 | [Scenic Notebook](./scenic/notebooks/TIPS_Demo.ipynb) 11 | 12 | We provide both Pytorch and Jax (Scenic) implementations: 13 | 14 | - `tips/pytorch/`: PyTorch inference for the model. The image tower largely 15 | follows the official [DINOv2 definition](https://github.com/facebookresearch/dinov2). 16 | - `tips/scenic/`: Jax-based inference using the 17 | [scenic library](https://github.com/google-research/scenic). 18 | 19 |

20 | <img src="./docs/images/overview.png" alt="Overview of TIPS."/> 24 |

25 | 26 | **Abstract** 27 |
28 | While image-text representation learning has become very popular 29 | in recent years, existing models tend to lack spatial awareness and have limited 30 | direct applicability for dense understanding tasks. For this reason, 31 | self-supervised image-only pretraining is still the go-to method for many dense 32 | vision applications (e.g. depth estimation, semantic segmentation), despite the 33 | lack of explicit supervisory signals. In this paper, we close this gap between 34 | image-text and self-supervised learning, by proposing a novel general-purpose 35 | image-text model, which can be effectively used off the shelf for dense and 36 | global vision tasks. Our method, which we refer to as Text-Image Pretraining 37 | with Spatial awareness (TIPS), leverages two simple and effective insights. 38 | First, on textual supervision: we reveal that replacing noisy web image captions 39 | by synthetically generated textual descriptions boosts dense understanding 40 | performance significantly, due to a much richer signal for learning spatially 41 | aware representations. We propose an adapted training method that combines noisy 42 | and synthetic captions, resulting in improvements across both dense and global 43 | understanding tasks. Second, on the learning technique: we propose to combine 44 | contrastive image-text learning with self-supervised masked image modeling, to 45 | encourage spatial coherence, unlocking substantial enhancements for downstream 46 | applications. Building on these two ideas, we scale our model using the 47 | transformer architecture, trained on a curated set of public images. Our 48 | experiments are conducted on 8 tasks involving 16 datasets in total, 49 | demonstrating strong off-the-shelf performance on both dense and global 50 | understanding, for several image-only and image-text tasks. 51 |
52 | 53 |

54 | <img src="./docs/images/block_diagram.png" alt="TIPS model block diagram."/> 58 |

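The two training ingredients described in the abstract (contrastive image-text learning on noisy plus synthetic captions, and self-supervised masked image modeling) can be summarized with the rough sketch below. This is a conceptual illustration only: the code released in this repository covers inference, not training, and the helper names, loss forms, and weighting here are hypothetical rather than the paper's exact formulation.

```python
import torch
import torch.nn.functional as F


def tips_style_objective(
    image_emb,           # (B, D) L2-normalized global image embeddings
    noisy_text_emb,      # (B, D) embeddings of noisy web captions
    synthetic_text_emb,  # (B, D) embeddings of synthetic captions
    masked_pred,         # (B, M, D) predictions for masked patch tokens
    masked_target,       # (B, M, D) targets for the masked patches
    temperature: float = 0.07,
    mim_weight: float = 1.0,
):
  """Illustrative combination of contrastive and masked-image-modeling losses."""

  def contrastive(img, txt):
    # Symmetric InfoNCE over the batch.
    logits = img @ txt.t() / temperature
    labels = torch.arange(img.shape[0], device=img.device)
    return 0.5 * (F.cross_entropy(logits, labels) +
                  F.cross_entropy(logits.t(), labels))

  loss_noisy = contrastive(image_emb, noisy_text_emb)
  loss_synthetic = contrastive(image_emb, synthetic_text_emb)
  # Stand-in for the masked image modeling term that encourages spatially
  # coherent patch features; the paper's actual formulation differs in detail.
  loss_mim = F.mse_loss(masked_pred, masked_target)
  return loss_noisy + loss_synthetic + mim_weight * loss_mim
```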
59 | 60 | 61 | ## Checkpoints 62 | We provide links to all available checkpoints, for both Pytorch and Jax model 63 | definitions, together with representative evals. 64 | 65 | Model size | #Params vision / text | Pytorch ckp. | Jax ckp. | PASCAL seg.↑ | NYU-depth↓ | ImageNet-KNN↑ | UNED-KNN↑ | Flickr T→I↑ | Flickr I→T↑ 66 | :---------- | :--------------------- | :------------------------------------------------------: | :------------------------------------------------------: | :---------: | :-------: | :----------: | :------: | :--------: | :--------: 67 | g/14-HR | 1.1B / 389.1M | [vision][pth-g14-hr-vision] \| [text][pth-g14-hr-text] | [vision][jax-g14-hr-vision] \| [text][jax-g14-hr-text] | 83.1 | 0.363 | 83.2 | 68.4 | 93.8 | 83.8 68 | g/14-LR | 1.1B / 389.1M | [vision][pth-g14-lr-vision] \| [text][pth-g14-lr-text] | [vision][jax-g14-lr-vision] \| [text][jax-g14-lr-text] | 82.0 | 0.390 | 83.6 | 71.5 | 93.4 | 82.1 69 | SO/14-HR | 412.4M / 448.3M | [vision][pth-so14-hr-vision] \| [text][pth-so14-hr-text] | [vision][jax-so14-hr-vision] \| [text][jax-so14-hr-text] | 83.7 | 0.362 | 83.0 | 68.6 | 94.2 | 83.8 70 | L/14-HR | 303.2M / 183.9M | [vision][pth-l14-hr-vision] \| [text][pth-l14-hr-text] | [vision][jax-l14-hr-vision] \| [text][jax-l14-hr-text] | 83.9 | 0.372 | 82.5 | 67.8 | 93.6 | 83.5 71 | B/14-HR | 85.7M / 109.6M | [vision][pth-b14-hr-vision] \| [text][pth-b14-hr-text] | [vision][jax-b14-hr-vision] \| [text][jax-b14-hr-text] | 82.9 | 0.379 | 80.0 | 62.7 | 91.3 | 79.4 72 | S/14-HR | 21.6M / 33.6M | [vision][pth-s14-hr-vision] \| [text][pth-s14-hr-text] | [vision][jax-s14-hr-vision] \| [text][jax-s14-hr-text] | 80.6 | 0.425 | 75.1 | 57.7 | 86.3 | 74.7 73 | 74 | ## Using Pytorch 75 | 76 | ### Installation 77 | Manage dependencies with a custom environment (eg. Conda) 78 | 79 | ```bash 80 | conda create -n tips python=3.11 81 | 82 | # Activate the environment. 83 | conda activate tips 84 | ``` 85 | 86 | Install Pytorch dependencies. 87 | 88 | ```bash 89 | # Install pytorch (change to GPU version if needed) 90 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu 91 | 92 | # Install other dependencies. 93 | pip install tensorflow_text mediapy jax jaxlib scikit-learn 94 | 95 | # Optionally, install Jupyter to use the notebook. 96 | pip install jupyter 97 | ``` 98 | 99 | Clone the code from this repo. 100 | 101 | ```bash 102 | git clone https://github.com/google-deepmind/tips.git 103 | 104 | # Add the current directory to PYTHONPATH. 105 | export PYTHONPATH=$PYTHONPATH:$(pwd) 106 | ``` 107 | 108 | Download the checkpoints locally. The script downloads all released checkpoints. 109 | Please adjust accordingly. 110 | 111 | ```bash 112 | cd tips/pytorch/checkpoints 113 | chmod +x download_checkpoints.sh 114 | ./download_checkpoints.sh 115 | cd ../../.. 116 | ``` 117 | 118 | ### Usage (Pytorch) 119 | 120 | To run inference on one image and get the L2-normalized image embedding from the 121 | 1st and 2nd CLS token, one can use the following: 122 | 123 | ```bash 124 | cd tips/pytorch && \ 125 | python run_image_encoder_inference.py \ 126 | --model_path=${PATH_TO_CHECKPOINT} \ 127 | --image_file=${PATH_TO_IMAGE} \ 128 | --model_variant=${MODEL_VARIANT} 129 | ``` 130 | 131 | One can use `is_low_res` to specify whether a low-resolution or high-resolution 132 | checkpoint is used. 
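For quick experimentation it can also be convenient to call the image tower directly from Python instead of the wrapper script above. The sketch below instantiates the ViT-S/14 definition from `tips/pytorch/image_encoder.py` with random weights (checkpoint loading and image preprocessing follow `run_image_encoder_inference.py` and are omitted here) and inspects the three outputs: first CLS token, second CLS token, and patch tokens.

```python
import torch
import torch.nn.functional as F

from tips.pytorch import image_encoder

# Build the ViT-S/14 architecture (patch size 14, one extra "register" CLS token).
model = image_encoder.vit_small(img_size=224)
model.eval()

# The inference script uses IMAGE_MEAN = (0, 0, 0) and IMAGE_STD = (1, 1, 1),
# i.e. pixels scaled to [0, 1]; a random tensor stands in for a preprocessed
# image here. Height and width must be multiples of the patch size (14).
image = torch.rand(1, 3, 224, 224)

with torch.no_grad():
  cls_token_1st, cls_token_2nd, patch_tokens = model(image)

print(cls_token_1st.shape)  # torch.Size([1, 1, 384])
print(cls_token_2nd.shape)  # torch.Size([1, 1, 384])
print(patch_tokens.shape)   # torch.Size([1, 256, 384]), i.e. a 16x16 token grid

# The model itself returns unnormalized features; L2-normalize to obtain
# unit-norm image embeddings as in the inference script.
image_embedding = F.normalize(cls_token_1st.squeeze(1), dim=-1)
```

The patch tokens are the dense features intended for spatial tasks such as segmentation and depth estimation, while the two CLS tokens provide global image embeddings.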
133 | 134 | To run text model inference and get the L2-normalized text embedding, please use 135 | the following cmd 136 | 137 | ```bash 138 | cd tips/pytorch && \ 139 | python run_text_encoder_inference.py \ 140 | --model_path=${PATH_TO_CHECKPOINT} \ 141 | --tokenizer_path=${PATH_TO_TOKENIZER} \ 142 | --model_variant=${MODEL_VARIANT} \ 143 | --text_input=${TEXT_INPUT} 144 | ``` 145 | 146 | We also provide a simple notebook demo: 147 | 148 | ```bash 149 | jupyter-notebook 150 | ``` 151 | Then navigate to `tips/pytorch/TIPS_Demo.ipynb`. 152 | 153 | ## Using Jax (Scenic) 154 | 155 | ### Installation 156 | Similar to using Pytorch, manage dependencies with a custom environment. 157 | 158 | ```bash 159 | conda create -n tips python=3.11 160 | 161 | # Activate the environment. 162 | conda activate tips 163 | ``` 164 | 165 | ```bash 166 | # Install scenic. 167 | git clone https://github.com/google-research/scenic.git scenic_src 168 | cd scenic_src 169 | pip install . 170 | cd .. 171 | rm -rf scenic_src 172 | 173 | # Install other dependencies. 174 | pip install pillow scikit-learn opencv-python tensorflow_text 175 | 176 | # Optionally, install Jupyter to use the notebook. 177 | pip install jupyter mediapy 178 | 179 | # In case of using CUDA, install the CUDA-supported JAX libraries. 180 | # For example, for CUDA 12 run: 181 | # pip install --upgrade "jax[cuda12_pip]" -f \ 182 | # https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 183 | ``` 184 | 185 | Clone the code from the this repo. 186 | 187 | ```bash 188 | git clone https://github.com/google-deepmind/tips.git 189 | 190 | # Add the current directory to PYTHONPATH. 191 | export PYTHONPATH=$PYTHONPATH:$(pwd) 192 | ``` 193 | 194 | Download the checkpoints (different files from Pytorch). 195 | 196 | ```bash 197 | cd tips/scenic/checkpoints 198 | chmod +x download_checkpoints.sh 199 | ./download_checkpoints.sh 200 | cd ../../.. 201 | ``` 202 | 203 | ### Usage (Jax) 204 | 205 | To run inference on an image, use the following script: 206 | 207 | ```bash 208 | cd tips/scenic 209 | python run_tips_inference.py 210 | ``` 211 | 212 | Alternatively, try the demo in the notebook: 213 | 214 | ```bash 215 | jupyter-notebook 216 | ``` 217 | Then navigate to `tips/scenic/notebooks/TIPS_Demo.ipynb`. 218 | 219 | ## Citing this work 220 | 221 | The paper can be found on [arXiv](https://arxiv.org/abs/2410.16512). 222 | Please consider citing this work using: 223 | 224 | ``` 225 | @InProceedings{tips_paper, 226 | Title={{TIPS: Text-Image Pretraining with Spatial Awareness}}, 227 | Author={Maninis, Kevis-Kokitsi and Chen, Kaifeng and Ghosh, Soham and Karpur, Arjun and Chen, Koert and Xia, Ye and Cao, Bingyi and Salz, Daniel and Han, Guangxing and Dlabal, Jan and Gnanapragasam, Dan and Seyedhosseini, Mojtaba and Zhou, Howard and Araujo, Andr\'e}, 228 | Booktitle={ICLR}, 229 | year={2025}, 230 | } 231 | ``` 232 | 233 | ## License and disclaimer 234 | 235 | Copyright 2025 DeepMind Technologies Limited 236 | 237 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0); 238 | you may not use this file except in compliance with the Apache 2.0 license. 239 | You may obtain a copy of the Apache 2.0 license at: 240 | https://www.apache.org/licenses/LICENSE-2.0 241 | 242 | All other materials are licensed under the Creative Commons Attribution 4.0 243 | International License (CC-BY). 
You may obtain a copy of the CC-BY license at: 244 | https://creativecommons.org/licenses/by/4.0/legalcode 245 | 246 | Unless required by applicable law or agreed to in writing, all software and 247 | materials distributed here under the Apache 2.0 or CC-BY licenses are 248 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 249 | either express or implied. See the licenses for the specific language governing 250 | permissions and limitations under those licenses. 251 | 252 | This is not an official Google product. 253 | 254 | [jax-g14-hr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_g14_highres_vision.npz 255 | [jax-g14-hr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_g14_highres_text.npz 256 | [jax-g14-lr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_g14_lowres_vision.npz 257 | [jax-g14-lr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_g14_lowres_text.npz 258 | [jax-so14-hr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_so400m14_highres_largetext_distilled_vision.npz 259 | [jax-so14-hr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_so400m14_highres_largetext_distilled_text.npz 260 | [jax-l14-hr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_l14_highres_distilled_vision.npz 261 | [jax-l14-hr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_l14_highres_distilled_text.npz 262 | [jax-b14-hr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_b14_highres_distilled_vision.npz 263 | [jax-b14-hr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_b14_highres_distilled_text.npz 264 | [jax-s14-hr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_s14_highres_distilled_vision.npz 265 | [jax-s14-hr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_s14_highres_distilled_text.npz 266 | 267 | [pth-g14-hr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_g14_highres_vision.npz 268 | [pth-g14-hr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_g14_highres_text.npz 269 | [pth-g14-lr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_g14_lowres_vision.npz 270 | [pth-g14-lr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_g14_lowres_text.npz 271 | [pth-so14-hr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_so400m14_highres_largetext_distilled_vision.npz 272 | [pth-so14-hr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_so400m14_highres_largetext_distilled_text.npz 273 | [pth-l14-hr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_l14_highres_distilled_vision.npz 274 | [pth-l14-hr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_l14_highres_distilled_text.npz 275 | [pth-b14-hr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_b14_highres_distilled_vision.npz 276 | [pth-b14-hr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_b14_highres_distilled_text.npz 277 | [pth-s14-hr-vision]: 
https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_s14_highres_distilled_vision.npz 278 | [pth-s14-hr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_s14_highres_distilled_text.npz 279 | -------------------------------------------------------------------------------- /docs/images/block_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/tips/c1be95c4cba9345f1d207e652c6e285ed7c2ec04/docs/images/block_diagram.png -------------------------------------------------------------------------------- /docs/images/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/tips/c1be95c4cba9345f1d207e652c6e285ed7c2ec04/docs/images/overview.png -------------------------------------------------------------------------------- /docs/images/qualitative.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/tips/c1be95c4cba9345f1d207e652c6e285ed7c2ec04/docs/images/qualitative.png -------------------------------------------------------------------------------- /pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | # All software is licensed under the Apache License, Version 2.0 (Apache 2.0); 17 | # you may not use this file except in compliance with the Apache 2.0 license. 18 | # You may obtain a copy of the Apache 2.0 license at: 19 | # https://www.apache.org/licenses/LICENSE-2.0 20 | 21 | # All other materials are licensed under the Creative Commons Attribution 4.0 22 | # International License (CC-BY). You may obtain a copy of the CC-BY license at: 23 | # https://creativecommons.org/licenses/by/4.0/legalcode 24 | 25 | # Unless required by applicable law or agreed to in writing, all software and 26 | # materials distributed here under the Apache 2.0 or CC-BY licenses are 27 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 28 | # either express or implied. See the licenses for the specific language 29 | # governing permissions and limitations under those licenses. 30 | 31 | # This is not an official Google product. 32 | """Import all files.""" 33 | -------------------------------------------------------------------------------- /pytorch/checkpoints/download_checkpoints.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2025 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | 18 | # The model weights can be found in https://console.cloud.google.com/storage/browser/tips_data 19 | ALL_CHECKPOINTS=( 20 | "tips_oss_s14_highres_distilled" 21 | "tips_oss_b14_highres_distilled" 22 | "tips_oss_l14_highres_distilled" 23 | "tips_oss_so400m14_highres_largetext_distilled" 24 | "tips_oss_g14_lowres" 25 | "tips_oss_g14_highres" 26 | ) 27 | 28 | echo "Downloading the tokenizer." 29 | wget https://storage.googleapis.com/tips_data/v1_0/checkpoints/tokenizer.model 30 | 31 | for CHECKPOINT in "${ALL_CHECKPOINTS[@]}"; do 32 | echo "Downloading ${CHECKPOINT} (vision encoder weights)" 33 | wget https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/${CHECKPOINT}_vision.npz 34 | echo "Downloading ${CHECKPOINT} (text encoder weights)" 35 | wget https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/${CHECKPOINT}_text.npz 36 | done 37 | -------------------------------------------------------------------------------- /pytorch/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Vision encoder implementation in PyTorch.""" 17 | 18 | import functools 19 | import math 20 | import os 21 | from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union 22 | import warnings 23 | import torch 24 | from torch import nn 25 | import torch.nn.functional as F 26 | import torch.utils.checkpoint 27 | 28 | 29 | class Mlp(nn.Module): 30 | """Transformer MLP, following DINOv2 implementation.""" 31 | 32 | def __init__( 33 | self, 34 | in_features: int, 35 | hidden_features: Optional[int] = None, 36 | out_features: Optional[int] = None, 37 | act_layer: Callable[..., nn.Module] = nn.GELU, 38 | drop: float = 0.0, 39 | bias: bool = True, 40 | ) -> None: 41 | super().__init__() 42 | out_features = out_features or in_features 43 | hidden_features = hidden_features or in_features 44 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 45 | self.act = act_layer() 46 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 47 | self.drop = nn.Dropout(drop) 48 | 49 | def forward(self, x: torch.Tensor) -> torch.Tensor: 50 | x = self.fc1(x) 51 | x = self.act(x) 52 | x = self.drop(x) 53 | x = self.fc2(x) 54 | x = self.drop(x) 55 | return x 56 | 57 | 58 | def make_2tuple(x): 59 | if isinstance(x, tuple): 60 | assert len(x) == 2 61 | return x 62 | 63 | assert isinstance(x, int) 64 | return (x, x) 65 | 66 | 67 | class PatchEmbed(nn.Module): 68 | """2D image to patch embedding: (B,C,H,W) -> (B,N,D).""" 69 | 70 | def __init__( 71 | self, 72 | img_size: Union[int, Tuple[int, int]] = 224, 73 | patch_size: Union[int, Tuple[int, int]] = 16, 74 | in_chans: int = 3, 75 | embed_dim: int = 768, 76 | norm_layer: Optional[Callable] = None, # pylint: disable=g-bare-generic 77 | flatten_embedding: bool = True, 78 | ) -> None: 79 | super().__init__() 80 | 81 | image_hw = make_2tuple(img_size) 82 | patch_hw = make_2tuple(patch_size) 83 | patch_grid_size = ( 84 | image_hw[0] // patch_hw[0], 85 | image_hw[1] // patch_hw[1], 86 | ) 87 | 88 | self.img_size = image_hw 89 | self.patch_size = patch_hw 90 | self.patches_resolution = patch_grid_size 91 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 92 | 93 | self.in_chans = in_chans 94 | self.embed_dim = embed_dim 95 | 96 | self.flatten_embedding = flatten_embedding 97 | 98 | self.proj = nn.Conv2d( 99 | in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw 100 | ) 101 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 102 | 103 | def forward(self, x: torch.Tensor) -> torch.Tensor: 104 | _, _, h, w = x.shape 105 | patch_h, patch_w = self.patch_size 106 | 107 | assert ( 108 | h % patch_h == 0 109 | ), f"Input image height {h} is not a multiple of patch height {patch_h}" 110 | assert ( 111 | w % patch_w == 0 112 | ), f"Input image width {w} is not a multiple of patch width: {patch_w}" 113 | 114 | x = self.proj(x) # B C H W 115 | h, w = x.size(2), x.size(3) 116 | x = x.flatten(2).transpose(1, 2) # B HW C 117 | x = self.norm(x) 118 | if not self.flatten_embedding: 119 | x = x.reshape(-1, h, w, self.embed_dim) # B H W C 120 | return x 121 | 122 | def flops(self) -> float: 123 | ho, wo = self.patches_resolution 124 | flops = ( 125 | ho 126 | * wo 127 | * self.embed_dim 128 | * self.in_chans 129 | * (self.patch_size[0] * self.patch_size[1]) 130 | ) 131 | if self.norm is not None: 132 | flops += ho * wo * self.embed_dim 133 | return flops 134 | 135 | 136 | class SwiGLUFFN(nn.Module): 137 | """SwiGLU FFN layer, 
following DINOv2 implementation.""" 138 | 139 | def __init__( 140 | self, 141 | in_features: int, 142 | hidden_features: Optional[int] = None, 143 | out_features: Optional[int] = None, 144 | act_layer: Callable[..., nn.Module] = None, 145 | drop: float = 0.0, 146 | bias: bool = True, 147 | ) -> None: 148 | super().__init__() 149 | out_features = out_features or in_features 150 | hidden_features = hidden_features or in_features 151 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 152 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 153 | 154 | def forward(self, x: torch.Tensor) -> torch.Tensor: 155 | x12 = self.w12(x) 156 | x1, x2 = x12.chunk(2, dim=-1) 157 | hidden = F.silu(x1) * x2 158 | return self.w3(hidden) 159 | 160 | 161 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 162 | try: 163 | if XFORMERS_ENABLED: 164 | from xformers.ops import SwiGLU, memory_efficient_attention, unbind, fmha, scaled_index_add, index_select_cat # pylint: disable=g-multiple-import, g-import-not-at-top 165 | 166 | XFORMERS_AVAILABLE = True 167 | warnings.warn("xFormers is available (SwiGLU)") 168 | else: 169 | warnings.warn("xFormers is disabled (SwiGLU)") 170 | raise ImportError 171 | except ImportError: 172 | SwiGLU = SwiGLUFFN 173 | XFORMERS_AVAILABLE = False 174 | 175 | warnings.warn("xFormers is not available (SwiGLU)") 176 | 177 | 178 | class SwiGLUFFNFused(SwiGLU): 179 | """SwiGLU FFN layer, following DINOv2 implementation.""" 180 | 181 | def __init__( 182 | self, 183 | in_features: int, 184 | hidden_features: Optional[int] = None, 185 | out_features: Optional[int] = None, 186 | act_layer: Callable[..., nn.Module] = None, # pylint: disable=unused-argument 187 | drop: float = 0.0, # pylint: disable=unused-argument 188 | bias: bool = True, 189 | ) -> None: 190 | out_features = out_features or in_features 191 | hidden_features = hidden_features or in_features 192 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 193 | super().__init__( 194 | in_features=in_features, 195 | hidden_features=hidden_features, 196 | out_features=out_features, 197 | bias=bias, 198 | ) 199 | 200 | 201 | class Attention(nn.Module): 202 | """Attention layer, following DINOv2 implementation.""" 203 | 204 | def __init__( 205 | self, 206 | dim: int, 207 | num_heads: int = 8, 208 | qkv_bias: bool = False, 209 | proj_bias: bool = True, 210 | attn_drop: float = 0.0, 211 | proj_drop: float = 0.0, 212 | ) -> None: 213 | super().__init__() 214 | self.num_heads = num_heads 215 | head_dim = dim // num_heads 216 | self.scale = head_dim**-0.5 217 | 218 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 219 | self.attn_drop = nn.Dropout(attn_drop) 220 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 221 | self.proj_drop = nn.Dropout(proj_drop) 222 | 223 | def forward(self, x: torch.Tensor) -> torch.Tensor: 224 | b_dim, n_dim, c_dim = x.shape 225 | qkv = ( 226 | self.qkv(x) 227 | .reshape(b_dim, n_dim, 3, self.num_heads, c_dim // self.num_heads) 228 | .permute(2, 0, 3, 1, 4) 229 | ) 230 | 231 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 232 | attn = q @ k.transpose(-2, -1) 233 | 234 | attn = attn.softmax(dim=-1) 235 | attn = self.attn_drop(attn) 236 | 237 | x = (attn @ v).transpose(1, 2).reshape(b_dim, n_dim, c_dim) 238 | x = self.proj(x) 239 | x = self.proj_drop(x) 240 | return x 241 | 242 | 243 | class MemEffAttention(Attention): 244 | """Memory Efficient Attention layer, following DINOv2 implementation.""" 245 | 246 | def forward(self, x: torch.Tensor, attn_bias=None) -> 
torch.Tensor: 247 | if not XFORMERS_AVAILABLE: 248 | if attn_bias is not None: 249 | raise AssertionError("xFormers is required for using nested tensors") 250 | return super().forward(x) 251 | 252 | b_dim, n_dim, c_dim = x.shape 253 | qkv = self.qkv(x).reshape( 254 | b_dim, n_dim, 3, self.num_heads, c_dim // self.num_heads 255 | ) 256 | 257 | q, k, v = unbind(qkv, 2) 258 | 259 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 260 | x = x.reshape([b_dim, n_dim, c_dim]) 261 | 262 | x = self.proj(x) 263 | x = self.proj_drop(x) 264 | return x 265 | 266 | 267 | class LayerScale(nn.Module): 268 | """Layer scale, following DINOv2 implementation.""" 269 | 270 | def __init__( 271 | self, 272 | dim: int, 273 | init_values: Union[float, torch.Tensor] = 1e-5, 274 | inplace: bool = False, 275 | ) -> None: 276 | super().__init__() 277 | self.inplace = inplace 278 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 279 | 280 | def forward(self, x: torch.Tensor) -> torch.Tensor: 281 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 282 | 283 | 284 | def drop_path_impl(x, drop_prob: float = 0.0, training: bool = False): 285 | if drop_prob == 0.0 or not training: 286 | return x 287 | keep_prob = 1 - drop_prob 288 | shape = (x.shape[0],) + (1,) * ( 289 | x.ndim - 1 290 | ) # work with diff dim tensors, not just 2D ConvNets 291 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 292 | if keep_prob > 0.0: 293 | random_tensor.div_(keep_prob) 294 | output = x * random_tensor 295 | return output 296 | 297 | 298 | class DropPath(nn.Module): 299 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 300 | 301 | def __init__(self, drop_prob=None): 302 | super(DropPath, self).__init__() 303 | self.drop_prob = drop_prob 304 | 305 | def forward(self, x): 306 | return drop_path_impl(x, self.drop_prob, self.training) 307 | 308 | 309 | class Block(nn.Module): 310 | """Transformer Block Implementation, following DINOv2 implementation.""" 311 | 312 | def __init__( 313 | self, 314 | dim: int, 315 | num_heads: int, 316 | mlp_ratio: float = 4.0, 317 | qkv_bias: bool = False, 318 | proj_bias: bool = True, 319 | ffn_bias: bool = True, 320 | drop: float = 0.0, 321 | attn_drop: float = 0.0, 322 | init_values=None, 323 | drop_path: float = 0.0, 324 | act_layer: Callable[..., nn.Module] = nn.GELU, 325 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm, 326 | attn_class: Callable[..., nn.Module] = Attention, 327 | ffn_layer: Callable[..., nn.Module] = Mlp, 328 | ) -> None: 329 | super().__init__() 330 | self.norm1 = norm_layer(dim) 331 | self.attn = attn_class( 332 | dim, 333 | num_heads=num_heads, 334 | qkv_bias=qkv_bias, 335 | proj_bias=proj_bias, 336 | attn_drop=attn_drop, 337 | proj_drop=drop, 338 | ) 339 | self.ls1 = ( 340 | LayerScale(dim, init_values=init_values) 341 | if init_values 342 | else nn.Identity() 343 | ) 344 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 345 | 346 | self.norm2 = norm_layer(dim) 347 | mlp_hidden_dim = int(dim * mlp_ratio) 348 | self.mlp = ffn_layer( 349 | in_features=dim, 350 | hidden_features=mlp_hidden_dim, 351 | act_layer=act_layer, 352 | drop=drop, 353 | bias=ffn_bias, 354 | ) 355 | self.ls2 = ( 356 | LayerScale(dim, init_values=init_values) 357 | if init_values 358 | else nn.Identity() 359 | ) 360 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 361 | 362 | self.sample_drop_ratio = drop_path 363 | 364 | def forward(self, x: torch.Tensor) -> 
torch.Tensor: 365 | def attn_residual_func(x: torch.Tensor) -> torch.Tensor: 366 | return self.ls1(self.attn(self.norm1(x))) 367 | 368 | def ffn_residual_func(x: torch.Tensor) -> torch.Tensor: 369 | return self.ls2(self.mlp(self.norm2(x))) 370 | 371 | if self.training and self.sample_drop_ratio > 0.1: 372 | # the overhead is compensated only for a drop path rate larger than 0.1 373 | x = drop_add_residual_stochastic_depth( 374 | x, 375 | residual_func=attn_residual_func, 376 | sample_drop_ratio=self.sample_drop_ratio, 377 | ) 378 | x = drop_add_residual_stochastic_depth( 379 | x, 380 | residual_func=ffn_residual_func, 381 | sample_drop_ratio=self.sample_drop_ratio, 382 | ) 383 | elif self.training and self.sample_drop_ratio > 0.0: 384 | x = x + self.drop_path1(attn_residual_func(x)) 385 | x = x + self.drop_path1(ffn_residual_func(x)) 386 | else: 387 | x = x + attn_residual_func(x) 388 | x = x + ffn_residual_func(x) 389 | return x 390 | 391 | 392 | def drop_add_residual_stochastic_depth( 393 | x: torch.Tensor, 394 | residual_func: Callable[[torch.Tensor], torch.Tensor], 395 | sample_drop_ratio: float = 0.0, 396 | ) -> torch.Tensor: 397 | """This function is taken from the original implementation in DINOv2 to implement stochastic depth in the image encoder.""" 398 | # 1) extract subset using permutation 399 | b, _, _ = x.shape 400 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 401 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 402 | x_subset = x[brange] 403 | 404 | # 2) apply residual_func to get residual 405 | residual = residual_func(x_subset) 406 | 407 | x_flat = x.flatten(1) 408 | residual = residual.flatten(1) 409 | 410 | residual_scale_factor = b / sample_subset_size 411 | 412 | # 3) add the residual 413 | x_plus_residual = torch.index_add( 414 | x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor 415 | ) 416 | return x_plus_residual.view_as(x) 417 | 418 | 419 | def get_branges_scales(x, sample_drop_ratio=0.0): 420 | b, _, _ = x.shape 421 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 422 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 423 | residual_scale_factor = b / sample_subset_size 424 | return brange, residual_scale_factor 425 | 426 | 427 | def add_residual( 428 | x, brange, residual, residual_scale_factor, scaling_vector=None 429 | ): 430 | """Implement residual addition in the image encoder.""" 431 | if scaling_vector is None: 432 | x_flat = x.flatten(1) 433 | residual = residual.flatten(1) 434 | x_plus_residual = torch.index_add( 435 | x_flat, 436 | 0, 437 | brange, 438 | residual.to(dtype=x.dtype), 439 | alpha=residual_scale_factor, 440 | ) 441 | else: 442 | x_plus_residual = scaled_index_add( 443 | x, 444 | brange, 445 | residual.to(dtype=x.dtype), 446 | scaling=scaling_vector, 447 | alpha=residual_scale_factor, 448 | ) 449 | return x_plus_residual 450 | 451 | 452 | attn_bias_cache: Dict[Tuple, Any] = {} # pylint: disable=g-bare-generic 453 | 454 | 455 | def get_attn_bias_and_cat(x_list, branges=None): 456 | """this will perform the index select, cat the tensors, and provide the attn_bias from cache.""" 457 | batch_sizes = ( 458 | [b.shape[0] for b in branges] 459 | if branges is not None 460 | else [x.shape[0] for x in x_list] 461 | ) 462 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) 463 | if all_shapes not in attn_bias_cache.keys(): 464 | seqlens = [] 465 | for b, x in zip(batch_sizes, x_list): 466 | for _ in range(b): 467 | 
seqlens.append(x.shape[1]) 468 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) 469 | attn_bias._batch_sizes = batch_sizes # pylint: disable=protected-access 470 | attn_bias_cache[all_shapes] = attn_bias 471 | 472 | if branges is not None: 473 | cat_tensors = index_select_cat( 474 | [x.flatten(1) for x in x_list], branges 475 | ).view(1, -1, x_list[0].shape[-1]) 476 | else: 477 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) 478 | cat_tensors = torch.cat(tensors_bs1, dim=1) 479 | 480 | return attn_bias_cache[all_shapes], cat_tensors 481 | 482 | 483 | def drop_add_residual_stochastic_depth_list( 484 | x_list: List[torch.Tensor], 485 | residual_func: Callable[[torch.Tensor, Any], torch.Tensor], 486 | sample_drop_ratio: float = 0.0, 487 | scaling_vector=None, 488 | ) -> torch.Tensor: 489 | """Add residual to a list of tensors.""" 490 | # 1) generate random set of indices for dropping samples in the batch. 491 | branges_scales = [ 492 | get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list 493 | ] 494 | branges = [s[0] for s in branges_scales] 495 | residual_scale_factors = [s[1] for s in branges_scales] 496 | 497 | # 2) get attention bias and index+concat the tensors. 498 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) 499 | 500 | # 3) apply residual_func to get residual, and split the result. 501 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore 502 | 503 | outputs = [] 504 | for x, brange, residual, residual_scale_factor in zip( 505 | x_list, branges, residual_list, residual_scale_factors 506 | ): 507 | outputs.append( 508 | add_residual( 509 | x, brange, residual, residual_scale_factor, scaling_vector 510 | ).view_as(x) 511 | ) 512 | return outputs 513 | 514 | 515 | class NestedTensorBlock(Block): 516 | """Nested tensor block implementation.""" 517 | 518 | def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]: 519 | """x_list contains a list of tensors to nest together and run.""" 520 | assert isinstance(self.attn, MemEffAttention) 521 | 522 | if self.training and self.sample_drop_ratio > 0.0: 523 | 524 | def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: 525 | return self.attn(self.norm1(x), attn_bias=attn_bias) 526 | 527 | def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: 528 | del attn_bias 529 | return self.mlp(self.norm2(x)) 530 | 531 | x_list = drop_add_residual_stochastic_depth_list( 532 | x_list, 533 | residual_func=attn_residual_func, 534 | sample_drop_ratio=self.sample_drop_ratio, 535 | scaling_vector=self.ls1.gamma 536 | if isinstance(self.ls1, LayerScale) 537 | else None, 538 | ) 539 | x_list = drop_add_residual_stochastic_depth_list( 540 | x_list, 541 | residual_func=ffn_residual_func, 542 | sample_drop_ratio=self.sample_drop_ratio, 543 | scaling_vector=self.ls2.gamma 544 | if isinstance(self.ls1, LayerScale) 545 | else None, 546 | ) 547 | return x_list 548 | else: 549 | 550 | def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: 551 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) 552 | 553 | def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: 554 | del attn_bias 555 | return self.ls2(self.mlp(self.norm2(x))) 556 | 557 | attn_bias, x = get_attn_bias_and_cat(x_list) 558 | x = x + attn_residual_func(x, attn_bias=attn_bias) 559 | x = x + ffn_residual_func(x) 560 | return attn_bias.split(x) 561 | 562 | def forward(self, x): 563 | if isinstance(x, 
torch.Tensor): 564 | return super().forward(x) 565 | elif isinstance(x, list): 566 | if not XFORMERS_AVAILABLE: 567 | raise AssertionError("xFormers is required for using nested tensors") 568 | return self.forward_nested(x) 569 | else: 570 | raise AssertionError 571 | 572 | 573 | def named_apply( 574 | fn: Callable, # pylint: disable=g-bare-generic 575 | module: nn.Module, 576 | name="", 577 | depth_first=True, 578 | include_root=False, 579 | ) -> nn.Module: 580 | """Apply a function to a module and its children.""" 581 | if not depth_first and include_root: 582 | fn(module=module, name=name) 583 | for child_name, child_module in module.named_children(): 584 | child_name = ".".join((name, child_name)) if name else child_name 585 | named_apply( 586 | fn=fn, 587 | module=child_module, 588 | name=child_name, 589 | depth_first=depth_first, 590 | include_root=True, 591 | ) 592 | if depth_first and include_root: 593 | fn(module=module, name=name) 594 | return module 595 | 596 | 597 | class BlockChunk(nn.ModuleList): 598 | 599 | def forward(self, x): 600 | for b in self: 601 | x = b(x) 602 | return x 603 | 604 | 605 | class VisionTransformer(nn.Module): 606 | """Vision Transformer implementation.""" 607 | 608 | def __init__( 609 | self, 610 | img_size=224, 611 | patch_size=16, 612 | in_chans=3, 613 | embed_dim=768, 614 | depth=12, 615 | num_heads=12, 616 | mlp_ratio=4.0, 617 | qkv_bias=True, 618 | ffn_bias=True, 619 | proj_bias=True, 620 | drop_path_rate=0.0, 621 | drop_path_uniform=False, 622 | init_values=None, # for layerscale: None or 0 => no layerscale 623 | embed_layer=PatchEmbed, 624 | act_layer=nn.GELU, 625 | block_fn=Block, 626 | ffn_layer="mlp", 627 | block_chunks=1, 628 | num_register_tokens=0, 629 | interpolate_antialias=False, 630 | interpolate_offset=0.1, 631 | ): 632 | """Defines the Vision Transformer model. 
633 | 634 | Args: 635 | img_size (int, tuple): input image size 636 | patch_size (int, tuple): patch size 637 | in_chans (int): number of input channels 638 | embed_dim (int): embedding dimension 639 | depth (int): depth of transformer 640 | num_heads (int): number of attention heads 641 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 642 | qkv_bias (bool): enable bias for qkv if True 643 | ffn_bias (bool): enable bias for ffn if True 644 | proj_bias (bool): enable bias for proj in attn if True 645 | drop_path_rate (float): stochastic depth rate 646 | drop_path_uniform (bool): apply uniform drop rate across blocks 647 | init_values (float): layer-scale init values 648 | embed_layer (nn.Module): patch embedding layer 649 | act_layer (nn.Module): MLP activation layer 650 | block_fn (nn.Module): transformer block class 651 | ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" 652 | block_chunks: (int) split block sequence into block_chunks units for FSDP 653 | wrap 654 | num_register_tokens: (int) number of extra cls tokens (so-called 655 | "registers") 656 | interpolate_antialias: (str) flag to apply anti-aliasing when 657 | interpolating positional embeddings 658 | interpolate_offset: (float) work-around offset to apply when interpolating 659 | positional embeddings 660 | """ 661 | super().__init__() 662 | norm_layer = functools.partial(nn.LayerNorm, eps=1e-6) 663 | 664 | self.num_features = self.embed_dim = ( 665 | embed_dim # num_features for consistency with other models 666 | ) 667 | self.num_tokens = 1 668 | self.n_blocks = depth 669 | self.num_heads = num_heads 670 | self.patch_size = patch_size 671 | self.num_register_tokens = num_register_tokens 672 | self.interpolate_antialias = interpolate_antialias 673 | self.interpolate_offset = interpolate_offset 674 | 675 | self.patch_embed = embed_layer( 676 | img_size=img_size, 677 | patch_size=patch_size, 678 | in_chans=in_chans, 679 | embed_dim=embed_dim, 680 | ) 681 | num_patches = self.patch_embed.num_patches 682 | 683 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 684 | self.pos_embed = nn.Parameter( 685 | torch.zeros(1, num_patches + self.num_tokens, embed_dim) 686 | ) 687 | assert num_register_tokens >= 0 688 | self.register_tokens = ( 689 | nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) 690 | if num_register_tokens 691 | else None 692 | ) 693 | 694 | if drop_path_uniform: 695 | dpr = [drop_path_rate] * depth 696 | else: 697 | dpr = [ 698 | x.item() for x in torch.linspace(0, drop_path_rate, depth) 699 | ] # stochastic depth decay rule 700 | 701 | if ffn_layer == "mlp": 702 | ffn_layer = Mlp 703 | elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": 704 | ffn_layer = SwiGLUFFNFused 705 | else: 706 | raise NotImplementedError 707 | 708 | blocks_list = [ 709 | block_fn( 710 | dim=embed_dim, 711 | num_heads=num_heads, 712 | mlp_ratio=mlp_ratio, 713 | qkv_bias=qkv_bias, 714 | proj_bias=proj_bias, 715 | ffn_bias=ffn_bias, 716 | drop_path=dpr[i], 717 | norm_layer=norm_layer, 718 | act_layer=act_layer, 719 | ffn_layer=ffn_layer, 720 | init_values=init_values, 721 | ) 722 | for i in range(depth) 723 | ] 724 | if block_chunks > 0: 725 | self.chunked_blocks = True 726 | chunked_blocks = [] 727 | chunksize = depth // block_chunks 728 | for i in range(0, depth, chunksize): 729 | # this is to keep the block index consistent if we chunk the block list 730 | chunked_blocks.append( 731 | [nn.Identity()] * i + blocks_list[i : i + chunksize] 732 | ) 733 | self.blocks = nn.ModuleList([BlockChunk(p) 
for p in chunked_blocks]) 734 | else: 735 | self.chunked_blocks = False 736 | self.blocks = nn.ModuleList(blocks_list) 737 | 738 | self.norm = norm_layer(embed_dim) 739 | self.head = nn.Identity() 740 | 741 | self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) 742 | 743 | self.init_weights() 744 | 745 | def init_weights(self): 746 | nn.init.trunc_normal_(self.pos_embed, std=0.02) 747 | nn.init.normal_(self.cls_token, std=1e-6) 748 | if self.register_tokens is not None: 749 | nn.init.normal_(self.register_tokens, std=1e-6) 750 | named_apply(init_weights_vit_timm, self) 751 | 752 | def interpolate_pos_encoding(self, x, w, h): 753 | previous_dtype = x.dtype 754 | npatch = x.shape[1] - 1 755 | num_patches = self.pos_embed.shape[1] - 1 756 | if npatch == num_patches and w == h: 757 | return self.pos_embed 758 | pos_embed = self.pos_embed.float() 759 | class_pos_embed = pos_embed[:, 0] 760 | patch_pos_embed = pos_embed[:, 1:] 761 | dim = x.shape[-1] 762 | w0 = w // self.patch_size 763 | h0 = h // self.patch_size 764 | num_patches_dim = int( 765 | math.sqrt(num_patches) 766 | ) # Recover the number of patches in each dimension 767 | assert num_patches == num_patches_dim * num_patches_dim 768 | kwargs = {} 769 | if self.interpolate_offset: 770 | sx = float(w0 + self.interpolate_offset) / num_patches_dim 771 | sy = float(h0 + self.interpolate_offset) / num_patches_dim 772 | kwargs["scale_factor"] = (sx, sy) 773 | else: 774 | # Simply specify an output size instead of a scale factor 775 | kwargs["size"] = (w0, h0) 776 | patch_pos_embed = nn.functional.interpolate( 777 | patch_pos_embed.reshape( 778 | 1, num_patches_dim, num_patches_dim, dim 779 | ).permute(0, 3, 1, 2), 780 | mode="bilinear", 781 | antialias=self.interpolate_antialias, 782 | **kwargs, 783 | ) 784 | assert (w0, h0) == patch_pos_embed.shape[-2:] 785 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 786 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to( 787 | previous_dtype 788 | ) 789 | 790 | def prepare_tokens_with_masks(self, x, masks=None): 791 | _, _, w, h = x.shape 792 | x = self.patch_embed(x) 793 | if masks is not None: 794 | x = torch.where( 795 | masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x 796 | ) 797 | 798 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 799 | x = x + self.interpolate_pos_encoding(x, w, h) 800 | 801 | if self.register_tokens is not None: 802 | x = torch.cat( 803 | ( 804 | x[:, :1], 805 | self.register_tokens.expand(x.shape[0], -1, -1), 806 | x[:, 1:], 807 | ), 808 | dim=1, 809 | ) 810 | 811 | return x 812 | 813 | def forward_features_list(self, x_list, masks_list): 814 | x = [ 815 | self.prepare_tokens_with_masks(x, masks) 816 | for x, masks in zip(x_list, masks_list) 817 | ] 818 | for blk in self.blocks: 819 | x = blk(x) 820 | 821 | all_x = x 822 | output = [] 823 | for x, masks in zip(all_x, masks_list): 824 | x_norm = self.norm(x) 825 | output.append({ 826 | "x_norm_1st_clstoken": x_norm[:, :1], 827 | "x_norm_2nd_clstoken": x_norm[:, 1 : self.num_register_tokens + 1], 828 | "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], 829 | "x_prenorm": x, 830 | "masks": masks, 831 | }) 832 | return output 833 | 834 | def forward_features(self, x, masks=None): 835 | if isinstance(x, list): 836 | return self.forward_features_list(x, masks) 837 | 838 | x = self.prepare_tokens_with_masks(x, masks) 839 | 840 | for blk in self.blocks: 841 | x = blk(x) 842 | 843 | x_norm = self.norm(x) 844 | return { 845 | 
"x_norm_1st_clstoken": x_norm[:, :1], 846 | "x_norm_2nd_clstoken": x_norm[:, 1 : self.num_register_tokens + 1], 847 | "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], 848 | "x_prenorm": x, 849 | "masks": masks, 850 | } 851 | 852 | def _get_intermediate_layers_not_chunked(self, x, n=1): 853 | x = self.prepare_tokens_with_masks(x) 854 | # If n is an int, take the n last blocks. If it's a list, take them 855 | output, total_block_len = [], len(self.blocks) 856 | blocks_to_take = ( 857 | range(total_block_len - n, total_block_len) if isinstance(n, int) else n 858 | ) 859 | for i, blk in enumerate(self.blocks): 860 | x = blk(x) 861 | if i in blocks_to_take: 862 | output.append(x) 863 | assert len(output) == len( 864 | blocks_to_take 865 | ), f"only {len(output)} / {len(blocks_to_take)} blocks found" 866 | return output 867 | 868 | def _get_intermediate_layers_chunked(self, x, n=1): 869 | x = self.prepare_tokens_with_masks(x) 870 | output, i, total_block_len = [], 0, len(self.blocks[-1]) 871 | # If n is an int, take the n last blocks. If it's a list, take them 872 | blocks_to_take = ( 873 | range(total_block_len - n, total_block_len) if isinstance(n, int) else n 874 | ) 875 | for block_chunk in self.blocks: 876 | for blk in block_chunk[i:]: # Passing the nn.Identity() 877 | x = blk(x) 878 | if i in blocks_to_take: 879 | output.append(x) 880 | i += 1 881 | assert len(output) == len( 882 | blocks_to_take 883 | ), f"only {len(output)} / {len(blocks_to_take)} blocks found" 884 | return output 885 | 886 | def get_intermediate_layers( 887 | self, 888 | x: torch.torch.Tensor, 889 | n: Union[int, Sequence] = 1, # Layers or n last layers to take # pylint: disable=g-bare-generic 890 | reshape: bool = False, 891 | return_class_token: bool = False, 892 | norm=True, 893 | ) -> Tuple[Union[torch.torch.Tensor, Tuple[torch.torch.Tensor]]]: # pylint: disable=g-one-element-tuple 894 | if self.chunked_blocks: 895 | outputs = self._get_intermediate_layers_chunked(x, n) 896 | else: 897 | outputs = self._get_intermediate_layers_not_chunked(x, n) 898 | if norm: 899 | outputs = [self.norm(out) for out in outputs] 900 | class_tokens = [out[:, 0] for out in outputs] 901 | outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] 902 | if reshape: 903 | batch_size, _, w, h = x.shape 904 | outputs = [ 905 | out.reshape( 906 | batch_size, w // self.patch_size, h // self.patch_size, -1 907 | ) 908 | .permute(0, 3, 1, 2) 909 | .contiguous() 910 | for out in outputs 911 | ] 912 | if return_class_token: 913 | return tuple(zip(outputs, class_tokens)) 914 | return tuple(outputs) 915 | 916 | def forward(self, *args, is_training=False, **kwargs): 917 | ret = self.forward_features(*args, **kwargs) 918 | if is_training: 919 | return ret 920 | else: 921 | return self.head(ret["x_norm_1st_clstoken"]), self.head( 922 | ret["x_norm_2nd_clstoken"] 923 | ), ret["x_norm_patchtokens"] 924 | 925 | 926 | def init_weights_vit_timm(module: nn.Module, name: str = ""): # pylint: disable=unused-argument 927 | """ViT weight initialization, original timm impl (for reproducibility).""" 928 | if isinstance(module, nn.Linear): 929 | nn.init.trunc_normal_(module.weight, std=0.02) 930 | if module.bias is not None: 931 | nn.init.zeros_(module.bias) 932 | 933 | 934 | def vit_small(patch_size=14, **kwargs): 935 | model = VisionTransformer( 936 | patch_size=patch_size, 937 | embed_dim=384, 938 | depth=12, 939 | num_heads=6, 940 | mlp_ratio=4, 941 | block_fn=functools.partial(Block, attn_class=MemEffAttention), 942 | 
num_register_tokens=1, 943 | **kwargs, 944 | ) 945 | return model 946 | 947 | 948 | def vit_base(patch_size=14, **kwargs): 949 | model = VisionTransformer( 950 | patch_size=patch_size, 951 | embed_dim=768, 952 | depth=12, 953 | num_heads=12, 954 | mlp_ratio=4, 955 | block_fn=functools.partial(Block, attn_class=MemEffAttention), 956 | num_register_tokens=1, 957 | **kwargs, 958 | ) 959 | return model 960 | 961 | 962 | def vit_large(patch_size=14, **kwargs): 963 | model = VisionTransformer( 964 | patch_size=patch_size, 965 | embed_dim=1024, 966 | depth=24, 967 | num_heads=16, 968 | mlp_ratio=4, 969 | block_fn=functools.partial(Block, attn_class=MemEffAttention), 970 | num_register_tokens=1, 971 | **kwargs, 972 | ) 973 | return model 974 | 975 | 976 | def vit_so400m(patch_size=14, **kwargs): 977 | """SoViT 400M model (https://arxiv.org/abs/2305.13035).""" 978 | model = VisionTransformer( 979 | patch_size=patch_size, 980 | embed_dim=1152, 981 | depth=27, 982 | num_heads=16, 983 | mlp_ratio=4304 / 1152, 984 | block_fn=functools.partial(Block, attn_class=MemEffAttention), 985 | num_register_tokens=1, 986 | **kwargs, 987 | ) 988 | return model 989 | 990 | 991 | def vit_giant2(patch_size=14, **kwargs): 992 | model = VisionTransformer( 993 | patch_size=patch_size, 994 | embed_dim=1536, 995 | depth=40, 996 | num_heads=24, 997 | mlp_ratio=4, 998 | block_fn=functools.partial(Block, attn_class=MemEffAttention), 999 | num_register_tokens=1, 1000 | **kwargs, 1001 | ) 1002 | return model 1003 | -------------------------------------------------------------------------------- /pytorch/run_image_encoder_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | r"""Running TIPS (https://arxiv.org/abs/2410.16512) ViT-g model inference. 17 | 18 | Usage: 19 | ```python 20 | python run_image_encoder_inference.py --model_path=${PATH_TO_LOW_RES_CHECKPOINT} \ 21 | --image_file=${PATH_TO_IMAGE} --is_low_res --model_variant=g 22 | ``` 23 | """ 24 | 25 | import argparse 26 | import io 27 | 28 | import numpy as np 29 | from PIL import Image 30 | import torch 31 | from torchvision import transforms 32 | 33 | from tips.pytorch import image_encoder 34 | 35 | IMAGE_MEAN = (0, 0, 0) 36 | IMAGE_STD = (1.0, 1.0, 1.0) 37 | PATCH_SIZE = 14 38 | 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument( 41 | '--model_path', default=None, required=True, help='The path to the model.' 
42 | ) 43 | parser.add_argument( 44 | '--image_file', 45 | default=None, 46 | required=True, 47 | help='The path to the image file for inference.', 48 | ) 49 | parser.add_argument( 50 | '--is_low_res', 51 | action='store_true', 52 | help='Whether the model is low-resolution.', 53 | ) 54 | parser.add_argument( 55 | '--model_variant', 56 | default=None, 57 | required=True, 58 | help='The variant of the model.', 59 | ) 60 | 61 | 62 | def main(args): 63 | 64 | image_size = 224 if args.is_low_res else 448 65 | model_def = { 66 | 'S': image_encoder.vit_small, 67 | 'B': image_encoder.vit_base, 68 | 'L': image_encoder.vit_large, 69 | 'So400m': image_encoder.vit_so400m, 70 | 'g': image_encoder.vit_giant2, 71 | }[args.model_variant] 72 | 73 | ffn_layer = 'swiglu' if args.model_variant == 'g' else 'mlp' 74 | 75 | # Load checkpoint. 76 | checkpoint = dict(np.load(args.model_path, allow_pickle=False)) 77 | for key in checkpoint: 78 | checkpoint[key] = torch.tensor(checkpoint[key]) 79 | 80 | # Run inference on the image. 81 | with open(args.image_file, 'rb') as fd: 82 | image_bytes = io.BytesIO(fd.read()) 83 | pil_image = Image.open(image_bytes) 84 | transform = transforms.Compose([ 85 | transforms.Resize((image_size, image_size)), 86 | transforms.ToTensor(), 87 | transforms.Normalize(IMAGE_MEAN, IMAGE_STD), 88 | ]) 89 | input_tensor = transform(pil_image) 90 | input_batch = input_tensor.unsqueeze(0) 91 | 92 | with torch.no_grad(): 93 | model = model_def( 94 | img_size=image_size, 95 | patch_size=PATCH_SIZE, 96 | ffn_layer=ffn_layer, 97 | block_chunks=0, 98 | init_values=1.0, 99 | interpolate_antialias=True, 100 | interpolate_offset=0.0, 101 | ) 102 | model.load_state_dict(checkpoint) 103 | 104 | # Compute embeddings from two CLS tokens. 105 | outputs = model(input_batch) 106 | first_cls_token = outputs[0].detach().numpy().squeeze() 107 | second_cls_token = outputs[1].detach().numpy().squeeze() 108 | 109 | first_cls_token = first_cls_token / np.linalg.norm( 110 | first_cls_token, ord=2, axis=-1, keepdims=True 111 | ).clip(min=1e-3) 112 | second_cls_token = second_cls_token / np.linalg.norm( 113 | second_cls_token, ord=2, axis=-1, keepdims=True 114 | ).clip(min=1e-3) 115 | print('First cls token: ', first_cls_token.tolist()) 116 | print('Second cls token: ', second_cls_token.tolist()) 117 | 118 | 119 | if __name__ == '__main__': 120 | main(parser.parse_args()) 121 | -------------------------------------------------------------------------------- /pytorch/run_text_encoder_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | r"""Running TIPS (https://arxiv.org/abs/2410.16512) text encoder inference. 
17 | 18 | Usage: 19 | ```python 20 | python run_text_encoder_inference.py --model_path=${PATH_TO_LOW_RES_CHECKPOINT} \ 21 | --model_variant=g --tokenizer_path=${PATH_TO_TOKENIZER} \ 22 | --text_input="Hello world." 23 | ``` 24 | """ 25 | 26 | import argparse 27 | import io 28 | import numpy as np 29 | import torch 30 | from tips.pytorch import text_encoder 31 | 32 | MAX_LEN = 64 33 | VOCAB_SIZE = 32000 34 | 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument( 37 | '--model_path', default=None, required=True, help='The path to the model.' 38 | ) 39 | parser.add_argument( 40 | '--model_variant', 41 | default=None, 42 | required=True, 43 | help='The variant of the model.', 44 | ) 45 | parser.add_argument( 46 | '--tokenizer_path', 47 | default=None, 48 | required=True, 49 | help='The path to the tokenizer.', 50 | ) 51 | parser.add_argument( 52 | '--text_input', 53 | default=None, 54 | required=True, 55 | help='The text input to the model.', 56 | ) 57 | 58 | 59 | def get_config(v: str): 60 | return { 61 | 'hidden_size': {'S': 384, 'B': 768, 'L': 1024, 'So400m': 1152, 'g': 1536}[ 62 | v 63 | ], 64 | 'mlp_dim': {'S': 1536, 'B': 3072, 'L': 4096, 'So400m': 4304, 'g': 6144}[ 65 | v 66 | ], 67 | 'num_heads': {'S': 6, 'B': 12, 'L': 16, 'So400m': 16, 'g': 24}[v], 68 | 'num_layers': {'S': 12, 'B': 12, 'L': 12, 'So400m': 27, 'g': 12}[v], 69 | } 70 | 71 | 72 | def main(args): 73 | 74 | with open(args.model_path, 'rb') as fin: 75 | inbuffer = io.BytesIO(fin.read()) 76 | np_weights_text = np.load(inbuffer, allow_pickle=False) 77 | 78 | pytorch_weights_text = {} 79 | for key, value in np_weights_text.items(): 80 | pytorch_weights_text[key] = torch.from_numpy(value) 81 | pytorch_weights_text.pop('temperature') 82 | 83 | with torch.no_grad(): 84 | # Define the text model. 85 | model_text = text_encoder.TextEncoder( 86 | get_config(args.model_variant), 87 | vocab_size=VOCAB_SIZE, 88 | ) 89 | model_text.load_state_dict(pytorch_weights_text) 90 | 91 | tokenizer_obj = text_encoder.Tokenizer(tokenizer_path=args.tokenizer_path) 92 | text_ids, text_paddings = tokenizer_obj.tokenize( 93 | [args.text_input], max_len=MAX_LEN 94 | ) 95 | text_embedding = ( 96 | model_text(torch.from_numpy(text_ids), torch.from_numpy(text_paddings)) 97 | .detach() 98 | .numpy() 99 | .squeeze() 100 | ) 101 | text_embedding = text_embedding / np.linalg.norm( 102 | text_embedding, ord=2, axis=-1, keepdims=True 103 | ).clip(min=1e-3) 104 | print(text_embedding.tolist()) 105 | 106 | 107 | if __name__ == '__main__': 108 | main(parser.parse_args()) 109 | -------------------------------------------------------------------------------- /pytorch/text_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
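The two inference scripts above each print L2-normalized embeddings (two CLS-token embeddings for an image, one pooled embedding per text input). A minimal sketch of how such outputs could be compared for image-text matching; the shapes and placeholder arrays below are illustrative stand-ins, not values produced by the scripts:

```python
import numpy as np

# Placeholders for the script outputs: one normalized first-CLS-token image
# embedding and one normalized text embedding per caption (ViT-g width 1536).
rng = np.random.default_rng(0)
image_embedding = rng.standard_normal(1536)
image_embedding /= np.linalg.norm(image_embedding)
text_embeddings = rng.standard_normal((3, 1536))
text_embeddings /= np.linalg.norm(text_embeddings, axis=-1, keepdims=True)

# For unit-norm vectors, cosine similarity reduces to a dot product.
scores = text_embeddings @ image_embedding  # shape: (3,)
best = int(np.argmax(scores))
print(f'best caption index: {best}, cosine similarity: {scores[best]:.3f}')
```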
14 | # ============================================================================== 15 | 16 | """Text encoder implementation in PyTorch.""" 17 | 18 | import typing as t 19 | 20 | import tensorflow as tf 21 | import tensorflow_text 22 | import torch 23 | from torch import nn 24 | import torch.nn.functional as F 25 | 26 | 27 | class Tokenizer(object): 28 | """A simple tokenizer.""" 29 | 30 | def __init__(self, tokenizer_path: str): 31 | """Initializes the tokenizer.""" 32 | with open(tokenizer_path, 'rb') as f: 33 | model = f.read() 34 | self.tokenizer = tensorflow_text.SentencepieceTokenizer( 35 | model=model, add_eos=False, add_bos=False 36 | ) 37 | 38 | def tokenize(self, input_text, max_len=64): 39 | tokens = self.tokenizer.tokenize(tf.strings.lower(input_text)).to_tensor() 40 | curr_len = tokens.shape[1] 41 | is_padding = tf.zeros((tokens.shape[0], max_len)) 42 | if curr_len > max_len: 43 | tokens = tokens[:, :max_len] 44 | else: 45 | padding_len = max_len - curr_len 46 | tokens = tf.pad(tokens, [[0, 0], [0, padding_len]], constant_values=0) 47 | is_padding = tf.cast(tokens == 0, tf.int32) 48 | return tokens.numpy(), is_padding.numpy() 49 | 50 | 51 | class PositionalEmbedding(nn.Module): 52 | """Generates position embedding for a given 1-d sequence. 53 | 54 | Attributes: 55 | min_timescale: Start of the geometric index. Determines the periodicity of 56 | the added signal. 57 | max_timescale: End of the geometric index. Determines the frequency of the 58 | added signal. 59 | embedding_dim: Dimension of the embedding to be generated. 60 | """ 61 | 62 | min_timescale: int = 1 63 | max_timescale: int = 10_000 64 | embedding_dim: int = 0 65 | 66 | def __init__(self, embedding_dim: int): 67 | super().__init__() 68 | self.embedding_dim = embedding_dim 69 | 70 | def __call__(self, seq_length: int = None, position: torch.tensor = None): 71 | """Generates a torch.tensor of sinusoids with different frequencies. 72 | 73 | Args: 74 | seq_length: an optional Python int defining the output sequence length. 75 | if the `position` argument is specified. 76 | position: [B, seq_length], optional position for each token in the 77 | sequence, only required when the sequence is packed. 78 | 79 | Returns: 80 | [B, seqlen, D] if `position` is specified, else [1, seqlen, D] 81 | """ 82 | if position is None: 83 | assert seq_length is not None 84 | # [1, seqlen] 85 | position = torch.arange(seq_length, dtype=torch.float32)[None, :] 86 | else: 87 | assert position.ndim == 2, position.shape 88 | 89 | num_timescales = self.embedding_dim // 2 90 | log_timescale_increment = torch.log( 91 | torch.tensor(float(self.max_timescale) / float(self.min_timescale)) 92 | ) / torch.maximum( 93 | torch.tensor(num_timescales, dtype=torch.float32) - 1, torch.tensor(1) 94 | ) 95 | inv_timescales = self.min_timescale * torch.exp( 96 | torch.arange(num_timescales, dtype=torch.float32) 97 | * -log_timescale_increment 98 | ) 99 | scaled_time = position[:, :, None] * inv_timescales[None, None, :] 100 | signal = torch.cat((torch.sin(scaled_time), torch.cos(scaled_time)), dim=2) 101 | # Force usage of `np` rather than `jnp` to compute static values at trace 102 | # time. 
103 | signal = F.pad(signal, (0, self.embedding_dim % 2, 0, 0, 0, 0)) 104 | return signal 105 | 106 | 107 | class MlpBlockWithMask(nn.Module): 108 | """Transformer MLP / feed-forward block that supports masking.""" 109 | 110 | def __init__( 111 | self, 112 | mlp_dim: int, 113 | d_model: int, 114 | use_bias: bool = True, 115 | dtype: torch.dtype = torch.float32, 116 | activation_fn: nn.Module = nn.GELU, 117 | ): 118 | super().__init__() 119 | 120 | self.mlp_dim = mlp_dim 121 | self.d_model = d_model 122 | self.use_bias = use_bias 123 | self.dtype = dtype 124 | self.activation_fn = activation_fn 125 | 126 | self.c_fc = nn.Linear( 127 | in_features=self.d_model, 128 | out_features=self.mlp_dim, 129 | dtype=self.dtype, 130 | bias=self.use_bias, 131 | ) 132 | self.c_proj = nn.Linear( 133 | in_features=self.mlp_dim, 134 | out_features=self.d_model, 135 | dtype=self.dtype, 136 | bias=self.use_bias, 137 | ) 138 | 139 | def __call__( 140 | self, inputs: torch.Tensor, mlp_mask: torch.Tensor 141 | ) -> torch.Tensor: 142 | """Applies Transformer MlpBlock with mask module.""" 143 | x = self.c_fc(inputs) 144 | x = self.activation_fn()(x) 145 | x = x * mlp_mask[..., None] # First masking. 146 | x = self.c_proj(x) 147 | x = x * mlp_mask[..., None] # Second masking. 148 | return x 149 | 150 | 151 | class ResidualAttentionBlock(nn.Module): 152 | """Transformer residual attention block.""" 153 | 154 | def __init__( 155 | self, 156 | d_model: int, 157 | n_head: int, 158 | mlp_dim: int, 159 | dtype: torch.dtype = torch.float32, 160 | ): 161 | super().__init__() 162 | self.d_model = d_model 163 | self.n_head = n_head 164 | self.mlp_dim = mlp_dim 165 | self.dtype = dtype 166 | 167 | self.attn = nn.MultiheadAttention(d_model, n_head, dtype=self.dtype) 168 | self.ln_1 = nn.LayerNorm(d_model, dtype=self.dtype) 169 | self.mlp = MlpBlockWithMask( 170 | self.mlp_dim, 171 | d_model, 172 | use_bias=True, 173 | dtype=self.dtype, 174 | activation_fn=nn.ReLU, 175 | ) 176 | self.ln_2 = nn.LayerNorm(d_model, dtype=self.dtype) 177 | 178 | def attention(self, x: torch.Tensor, mask: torch.Tensor): 179 | attn_mask = ( 180 | mask[:, None, None, :] 181 | .repeat(1, self.n_head, x.shape[0], 1) 182 | .flatten(0, 1) 183 | ) 184 | attn_mask[attn_mask == 0] = float('-inf') 185 | attn_mask[attn_mask == 1] = 0 186 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] 187 | 188 | def forward(self, x: torch.Tensor, mask: torch.Tensor): 189 | x = x + self.attention(self.ln_1(x), mask.permute(1, 0)) 190 | x = x + self.mlp(self.ln_2(x), mask) 191 | return x, mask 192 | 193 | 194 | class SequentialMultiInput(nn.Sequential): 195 | """Sequential module that can take multiple inputs.""" 196 | 197 | def forward(self, *inputs): 198 | for module in self._modules.values(): 199 | if isinstance(inputs, tuple): 200 | inputs = module(*inputs) 201 | else: 202 | inputs = module(inputs) 203 | return inputs 204 | 205 | 206 | class Transformer(nn.Module): 207 | """Transformer implementation.""" 208 | 209 | def __init__( 210 | self, 211 | width: int, 212 | layers: int, 213 | heads: int, 214 | mlp_dim: int, 215 | dtype: torch.dtype = torch.float32, 216 | ): 217 | super().__init__() 218 | self.width = width 219 | self.layers = layers 220 | self.heads = heads 221 | self.mlp_dim = mlp_dim 222 | self.dtype = dtype 223 | 224 | self.resblocks = SequentialMultiInput(*[ 225 | ResidualAttentionBlock(self.width, self.heads, self.mlp_dim, self.dtype) 226 | for _ in range(self.layers) 227 | ]) 228 | 229 | def forward(self, x: torch.Tensor, mask: 
torch.Tensor): 230 | return self.resblocks(x, mask)[0] 231 | 232 | 233 | class GlobalAvgPooling(nn.Module): 234 | """Performs a simple global pooling over the input with optional paddings. 235 | 236 | Attributes: 237 | pooling_dims: A list of dims to perform pooling over. 238 | keepdims: If True, keep dimension of inputs after pooling. 239 | """ 240 | 241 | pooling_dims: t.Sequence[int] 242 | epsilon: float = 1e-8 243 | 244 | def __init__( 245 | self, pooling_dims: t.Sequence[int], epsilon: float = 1e-8 246 | ): 247 | super().__init__() 248 | self.pooling_dims = pooling_dims 249 | self.epsilon = epsilon 250 | 251 | if not all([p_dims >= 0 for p_dims in self.pooling_dims]): 252 | raise ValueError('pooling_dims must be non-negative integers.') 253 | 254 | def __call__( 255 | self, 256 | inputs: torch.tensor, 257 | compatible_paddings: torch.tensor, 258 | ): 259 | """Applies global average spatial pooling to inputs. 260 | 261 | Args: 262 | inputs: An input tensor. 263 | compatible_paddings: paddings of inputs with shapes compatible with 264 | inputs, e.g. compatible_paddings with shape [B, 1] for inputs with shape 265 | [B, D]. 266 | 267 | Returns: 268 | Output tensor with global pooling applied. 269 | """ 270 | padded_value = torch.zeros_like(inputs) 271 | padded_value = torch.ones_like(inputs) * padded_value 272 | inputs = torch.where(compatible_paddings > 0, padded_value, inputs) 273 | valid_inputs = ( 274 | torch.sum( 275 | 1.0 - compatible_paddings, 276 | self.pooling_dims, 277 | keepdims=True, 278 | dtype=inputs.dtype, 279 | ) 280 | + self.epsilon 281 | ) 282 | inputs_sum = torch.sum(inputs, self.pooling_dims, keepdims=True) 283 | outputs = torch.divide(inputs_sum, valid_inputs).type(inputs.dtype) 284 | outputs = torch.squeeze(outputs, axis=self.pooling_dims) 285 | return outputs 286 | 287 | 288 | class TextEncoder(nn.Module): 289 | """Text encoder implementation.""" 290 | 291 | def __init__( 292 | self, 293 | config: t.Dict[str, int], 294 | vocab_size: int, 295 | dtype: torch.dtype = torch.float32, 296 | scale_sqrt_depth: bool = True, 297 | ): 298 | super().__init__() 299 | self.vocab_size = vocab_size 300 | self.dtype = dtype 301 | self.scale_sqrt_depth = scale_sqrt_depth 302 | 303 | # The text tower layers are fixed independent of vision tower size. 
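    # (One exception: the So400m text tower mirrors its 27-layer vision tower;
    # see get_config in run_text_encoder_inference.py.)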
304 | self.transformer_layers = config['num_layers'] 305 | self.embedding_dim = config['hidden_size'] 306 | self.transformer_width = config['hidden_size'] 307 | self.mlp_dim = config['mlp_dim'] 308 | self.transformer_heads = config['num_heads'] 309 | 310 | self.token_embedding = nn.Embedding( 311 | self.vocab_size, self.embedding_dim, dtype=self.dtype 312 | ) 313 | self.pos_embedder = PositionalEmbedding(embedding_dim=self.embedding_dim) 314 | self.transformer = Transformer( 315 | width=self.transformer_width, 316 | layers=self.transformer_layers, 317 | heads=self.transformer_heads, 318 | mlp_dim=self.mlp_dim, 319 | dtype=self.dtype, 320 | ) 321 | self.pooling = GlobalAvgPooling(pooling_dims=[1]) 322 | self.ln_final = nn.LayerNorm(self.transformer_width, dtype=self.dtype) 323 | 324 | def __call__( 325 | self, 326 | ids: torch.tensor, 327 | paddings: torch.tensor, 328 | ): 329 | """Applies TextEncoder module.""" 330 | _, seq_length = ids.shape 331 | mask = (paddings == 0).type(torch.float32) 332 | mask = mask.permute(1, 0) # NL -> LN 333 | x = self.token_embedding(ids) 334 | if self.scale_sqrt_depth: 335 | x = x * (self.embedding_dim**0.5) 336 | x = x + self.pos_embedder(seq_length=seq_length) 337 | x = x.permute(1, 0, 2) # NLD -> LND 338 | x = self.transformer(x, mask) 339 | x = x.permute(1, 0, 2) # LND -> NLD 340 | x = self.ln_final(x) 341 | x = self.pooling(x, compatible_paddings=paddings[:, :, None]) 342 | return x 343 | -------------------------------------------------------------------------------- /scenic/checkpoints/download_checkpoints.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2025 Google LLC 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | 18 | # The model weights can be found in https://console.cloud.google.com/storage/browser/tips_data 19 | ALL_CHECKPOINTS=( 20 | "tips_oss_s14_highres_distilled" 21 | "tips_oss_b14_highres_distilled" 22 | "tips_oss_l14_highres_distilled" 23 | "tips_oss_so400m14_highres_largetext_distilled" 24 | "tips_oss_g14_lowres" 25 | "tips_oss_g14_highres" 26 | ) 27 | 28 | echo "Downloading the tokenizer." 
29 | wget https://storage.googleapis.com/tips_data/v1_0/checkpoints/tokenizer.model 30 | 31 | for CHECKPOINT in "${ALL_CHECKPOINTS[@]}"; do 32 | echo "Downloading ${CHECKPOINT} (vision encoder weights)" 33 | wget https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/${CHECKPOINT}_vision.npz 34 | echo "Downloading ${CHECKPOINT} (text encoder weights)" 35 | wget https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/${CHECKPOINT}_text.npz 36 | done 37 | -------------------------------------------------------------------------------- /scenic/configs/tips_model_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """TIPS model config.""" 17 | 18 | import ml_collections 19 | 20 | _MEAN_RGB = [0., 0., 0.] 21 | _STDDEV_RGB = [1., 1., 1.] 22 | 23 | # The 'g' variant refers to the DINO-v2 'giant2', which differs from ViT-g. 24 | # The differences are highlighted in https://arxiv.org/pdf/2304.07193 Section 5. 25 | _VARIANT_DICT = { 26 | 'tips_oss_g14_highres': 'g/14', 27 | 'tips_oss_g14_lowres': 'g/14', 28 | 'tips_oss_so400m14_highres_largetext_distilled': 'So400m/14', 29 | 'tips_oss_l14_highres_distilled': 'L/14', 30 | 'tips_oss_b14_highres_distilled': 'B/14', 31 | 'tips_oss_s14_highres_distilled': 'S/14', 32 | } 33 | 34 | 35 | def get_config(variant: str): 36 | """Returns the TIPS model config.""" 37 | config = ml_collections.ConfigDict() 38 | if variant not in _VARIANT_DICT: 39 | raise ValueError( 40 | f'Unknown TIPS variant: {variant}. Please choose one of: ' 41 | f'{list(_VARIANT_DICT.keys())}') 42 | 43 | config.variant = _VARIANT_DICT[variant] 44 | config.rgb_mean = _MEAN_RGB 45 | config.rgb_std = _STDDEV_RGB 46 | 47 | config.pooling = 'tok' 48 | config.pos_interpolation_method = 'bilinear' 49 | 50 | # TIPS defaults to 2 CLS tokens. 
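  # The first CLS token is the one used for image-text (contrastive) matching
  # in run_tips_inference.py; the second is noted there as working better for
  # dense/spatial tasks.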
51 | config.num_cls_tokens = 2 52 | 53 | return config 54 | -------------------------------------------------------------------------------- /scenic/images/example_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/tips/c1be95c4cba9345f1d207e652c6e285ed7c2ec04/scenic/images/example_image.jpg -------------------------------------------------------------------------------- /scenic/images/example_image_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/tips/c1be95c4cba9345f1d207e652c6e285ed7c2ec04/scenic/images/example_image_2.jpg -------------------------------------------------------------------------------- /scenic/models/text.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Text-encoder related modules.""" 17 | 18 | import math 19 | import typing as t 20 | 21 | import flax.linen as nn 22 | import jax 23 | import jax.numpy as jnp 24 | import numpy as np 25 | from scenic.model_lib.layers import nn_layers 26 | import tensorflow as tf 27 | import tensorflow_text 28 | 29 | 30 | Initializer = t.Callable[[jnp.ndarray, t.Sequence[int], jnp.dtype], jnp.ndarray] 31 | 32 | 33 | class Tokenizer(object): 34 | """A simple tokenizer.""" 35 | 36 | def __init__(self, tokenizer_path: str): 37 | """Initializes the tokenizer.""" 38 | with open(tokenizer_path, 'rb') as f: 39 | model = f.read() 40 | self.tokenizer = tensorflow_text.SentencepieceTokenizer( 41 | model=model, add_eos=False, add_bos=False) 42 | 43 | def tokenize(self, input_text, max_len=64): 44 | tokens = self.tokenizer.tokenize(tf.strings.lower(input_text)).to_tensor() 45 | curr_len = tokens.shape[1] 46 | is_padding = tf.zeros((tokens.shape[0], max_len)) 47 | if curr_len > max_len: 48 | tokens = tokens[:, :max_len] 49 | else: 50 | padding_len = max_len - curr_len 51 | tokens = tf.pad(tokens, [[0, 0], [0, padding_len]], constant_values=0) 52 | is_padding = tf.cast(tokens == 0, tf.int32) 53 | return tokens.numpy(), is_padding.numpy() 54 | 55 | 56 | class Embedding(nn.Module): 57 | """A simple embedding layer that performs embedding lookups from ids. 58 | 59 | Simple version of 60 | https://github.com/google/praxis/blob/main/praxis/layers/embedding_softmax.py#L97 61 | 62 | Attributes: 63 | num_classes: Number of tokens in the vocabulary. 64 | embedding_dim: Depth of the embedding output. 65 | scale_sqrt_depth: If set to True, activations are scaled with 66 | sqrt(embedding_dim) in emb_lookup. 
67 | """ 68 | 69 | num_classes: int = 0 70 | embedding_dim: int = 0 71 | scale_sqrt_depth: bool = True 72 | 73 | def setup(self) -> None: 74 | assert self.num_classes > 0 75 | assert self.embedding_dim > 0 76 | 77 | self.emb_var = self.param( 78 | 'emb_var', 79 | nn.initializers.variance_scaling(1.0, 'fan_in', 'normal', out_axis=0), 80 | (self.num_classes, self.embedding_dim), 81 | jnp.float32) 82 | 83 | def emb_lookup(self, ids: jnp.ndarray) -> jnp.ndarray: 84 | embs = self.emb_var[ids] 85 | 86 | if self.scale_sqrt_depth: 87 | embs *= self.embedding_dim**0.5 88 | 89 | return embs 90 | 91 | def __call__(self, ids: jnp.ndarray) -> jnp.ndarray: 92 | return self.emb_lookup(ids) 93 | 94 | 95 | class PositionalEmbedding(nn.Module): 96 | """Generates fixed position embedding for a given 1-d sequence. 97 | 98 | Simplified version of 99 | https://github.com/google/praxis/blob/main/praxis/layers/embedding_softmax.py#L1011 100 | 101 | Attributes: 102 | min_timescale: Start of the geometric index. Determines the periodicity of 103 | the added signal. 104 | max_timescale: End of the geometric index. Determines the frequency of the 105 | added signal. 106 | embedding_dim: Dimension of the embedding to be generated. 107 | """ 108 | 109 | min_timescale: int = 1 110 | max_timescale: int = 10_000 111 | embedding_dim: int = 0 112 | 113 | def __call__( 114 | self, seq_length: int | None = None, position: jnp.ndarray | None = None 115 | ) -> jnp.ndarray: 116 | """Generates a jnp.ndarray of sinusoids with different frequencies. 117 | 118 | Args: 119 | seq_length: an optional Python int definiing the output sequence length. 120 | if the `position` argument is specified. 121 | position: [B, seq_length], optional position for each token in the 122 | sequence, only required when the sequence is packed. 123 | 124 | Returns: 125 | [B, seqlen, D] if `position` is specified, else [1, seqlen, D] 126 | """ 127 | if position is None: 128 | assert seq_length is not None 129 | # [1, seqlen] 130 | position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :] 131 | else: 132 | assert position.ndim == 2, position.shape 133 | 134 | num_timescales = self.embedding_dim // 2 135 | log_timescale_increment = math.log( 136 | float(self.max_timescale) / float(self.min_timescale) 137 | ) / jnp.maximum(jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1) 138 | inv_timescales = self.min_timescale * jnp.exp( 139 | jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment 140 | ) 141 | scaled_time = ( 142 | position[:, :, jnp.newaxis] 143 | * inv_timescales[jnp.newaxis, jnp.newaxis, :] 144 | ) 145 | signal = jnp.concatenate( 146 | [jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=2) 147 | # Force usage of `np` rather than `jnp` to compute static values at trace 148 | # time. 149 | signal = jnp.pad( 150 | signal, [[0, 0], [0, 0], [0, np.mod(self.embedding_dim, 2)]] 151 | ) 152 | return signal 153 | 154 | 155 | class GlobalAvgPooling(nn.Module): 156 | """Performs a simple global pooling over the input with optional paddings. 157 | 158 | Attributes: 159 | pooling_dims: A list of dims to perform pooling over. 160 | keepdims: If True, keep dimension of inputs after pooling. 
161 | """ 162 | pooling_dims: t.Sequence[int] | None = None 163 | epsilon: float = 1e-8 164 | 165 | def setup(self) -> None: 166 | if self.pooling_dims is None: 167 | raise ValueError('pooling_dims must be set as a list.') 168 | else: 169 | if not all([p_dims >= 0 for p_dims in self.pooling_dims]): 170 | raise ValueError('pooling_dims must be non-negative integers.') 171 | 172 | def __call__( 173 | self, 174 | inputs: jnp.ndarray, 175 | compatible_paddings: jnp.ndarray, 176 | ) -> jnp.ndarray: 177 | """Applies global average spatial pooling to inputs. 178 | 179 | Args: 180 | inputs: An input tensor. 181 | compatible_paddings: paddings of inputs with shapes compatible 182 | with inputs, e.g. compatible_paddings with shape [B, 1] for inputs with 183 | shape [B, D]. 184 | 185 | Returns: 186 | Output tensor with global pooling applied. 187 | """ 188 | padded_value = jnp.zeros(shape=(), dtype=inputs.dtype) 189 | padded_value = jnp.ones_like(inputs) * padded_value 190 | inputs = jnp.where(compatible_paddings > 0, padded_value, inputs) 191 | valid_inputs = ( 192 | jnp.sum( 193 | 1.0 - compatible_paddings, 194 | self.pooling_dims, 195 | keepdims=True, 196 | dtype=inputs.dtype) 197 | + self.epsilon) 198 | inputs_sum = jnp.sum(inputs, self.pooling_dims, keepdims=True) 199 | outputs = jnp.divide(inputs_sum, valid_inputs).astype(inputs.dtype) 200 | outputs = jnp.squeeze(outputs, axis=self.pooling_dims) 201 | return outputs 202 | 203 | 204 | class MlpBlockWithMask(nn.Module): 205 | """Transformer MLP / feed-forward block that supports masking.""" 206 | 207 | mlp_dim: int 208 | out_dim: t.Optional[int] = None 209 | dropout_rate: float = 0.1 210 | use_bias: bool = True 211 | kernel_init: Initializer = nn.initializers.xavier_uniform() 212 | bias_init: Initializer = nn.initializers.normal(stddev=1e-6) 213 | activation_fn: t.Callable[[jnp.ndarray], jnp.ndarray] = nn.gelu 214 | precision: t.Optional[jax.lax.Precision] = None 215 | dtype: jnp.ndarray = jnp.float32 216 | 217 | @nn.compact 218 | def __call__(self, inputs: jnp.ndarray, *, mask, deterministic: bool): 219 | """Applies Transformer MlpBlock with mask module.""" 220 | actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim 221 | x = nn.Dense( 222 | self.mlp_dim, 223 | dtype=self.dtype, 224 | use_bias=self.use_bias, 225 | kernel_init=self.kernel_init, 226 | bias_init=self.bias_init, 227 | precision=self.precision)( 228 | inputs) 229 | x = nn_layers.IdentityLayer(name='mlp1')(self.activation_fn(x)) 230 | x = x * mask[..., None] # First masking. 231 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) 232 | output = nn.Dense( 233 | actual_out_dim, 234 | dtype=self.dtype, 235 | use_bias=self.use_bias, 236 | kernel_init=self.kernel_init, 237 | bias_init=self.bias_init, 238 | precision=self.precision)(x) 239 | output = output * mask[..., None] # Second masking. 240 | output = nn_layers.IdentityLayer(name='mlp2')(output) 241 | output = nn.Dropout(rate=self.dropout_rate)( 242 | output, deterministic=deterministic) 243 | return output 244 | 245 | 246 | class TextEncoder1DBlock(nn.Module): 247 | """Transformer text encoder layer. 248 | 249 | Attributes: 250 | mlp_dim: Dimension of the mlp on top of attention block. 251 | num_heads: Number of self-attention heads. 252 | dtype: The dtype of the computation (default: float32). 253 | dropout_rate: Dropout rate. 254 | attention_dropout_rate: Dropout for attention heads. 255 | stochastic_depth: probability of dropping a layer linearly grows 256 | from 0 to the provided value. 
257 | ffn_layer: type of the feed-forward layer. Options are 'mlp', 'swiglufused'. 258 | 259 | Returns: 260 | output after transformer encoder block. 261 | """ 262 | mlp_dim: int 263 | num_heads: int 264 | dtype: t.Any = jnp.float32 265 | dropout_rate: float = 0.1 266 | attention_dropout_rate: float = 0.1 267 | stochastic_depth: float = 0.0 268 | 269 | @nn.compact 270 | def __call__( 271 | self, inputs: jnp.ndarray, mask: jnp.ndarray, deterministic: bool 272 | ) -> jnp.ndarray: 273 | """Applies Encoder1DBlock module. 274 | 275 | Args: 276 | inputs: Input data. 277 | mask: Input mask. 278 | deterministic: Deterministic or not (to apply dropout). 279 | 280 | Returns: 281 | Output after transformer encoder block. 282 | """ 283 | # Attention block. 284 | assert inputs.ndim == 3 285 | x = nn.LayerNorm(name='LayerNorm_0', dtype=self.dtype)(inputs) 286 | x = nn.MultiHeadDotProductAttention( 287 | num_heads=self.num_heads, 288 | dtype=self.dtype, 289 | kernel_init=nn.initializers.xavier_uniform(), 290 | broadcast_dropout=False, 291 | deterministic=deterministic, 292 | dropout_rate=self.attention_dropout_rate)( 293 | x, x, mask=mask[:, jnp.newaxis, jnp.newaxis, :]) 294 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic) 295 | x = nn_layers.StochasticDepth(rate=self.stochastic_depth)(x, deterministic) 296 | x = x + inputs 297 | 298 | # MLP block. 299 | y = nn.LayerNorm(dtype=self.dtype, name='LayerNorm_1')(x) 300 | mlp0 = MlpBlockWithMask( 301 | mlp_dim=self.mlp_dim, 302 | dtype=self.dtype, 303 | dropout_rate=self.dropout_rate, 304 | activation_fn=nn.relu, # ReLU is the choice for the PAX experiments. 305 | kernel_init=nn.initializers.xavier_uniform(), 306 | bias_init=nn.initializers.normal(stddev=1e-6), 307 | name='MlpBlock_0' 308 | ) 309 | y = mlp0(y, mask=mask, deterministic=deterministic) 310 | y = nn_layers.StochasticDepth(rate=self.stochastic_depth)(y, deterministic) 311 | return x + y 312 | 313 | 314 | class StackedTransformer(nn.Module): 315 | """Stacked transformer.""" 316 | 317 | mlp_dim: int 318 | num_layers: int 319 | num_heads: int 320 | dropout_rate: float = 0.1 321 | attention_dropout_rate: float = 0.1 322 | stochastic_depth: float = 0.0 323 | dtype: t.Any = jnp.float32 324 | 325 | def setup(self): 326 | encoder_blocks = [] 327 | for lyr in range(self.num_layers): 328 | encoder_blocks.append( 329 | TextEncoder1DBlock( 330 | mlp_dim=self.mlp_dim, 331 | num_heads=self.num_heads, 332 | dropout_rate=self.dropout_rate, 333 | attention_dropout_rate=self.attention_dropout_rate, 334 | stochastic_depth=( 335 | (lyr / max(self.num_layers - 1, 1)) * self.stochastic_depth), 336 | name=f'encoderblock_{lyr}', 337 | )) 338 | self.encoder_blocks = encoder_blocks 339 | 340 | def __call__( 341 | self, x: jnp.ndarray, mask: jnp.ndarray, deterministic: bool 342 | ) -> jnp.ndarray: 343 | """Applies StackedTransformer module.""" 344 | for block in self.encoder_blocks: 345 | x = block(x, mask, deterministic=deterministic) 346 | return x 347 | 348 | -------------------------------------------------------------------------------- /scenic/models/tips.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """The TIPS model definition.""" 17 | 18 | import typing as t 19 | 20 | import flax.linen as nn 21 | import jax.numpy as jnp 22 | 23 | from tips.scenic.models import text 24 | from tips.scenic.models import vit 25 | 26 | 27 | class VisionEncoder(nn.Module): 28 | """TIPS vision encoder based on ViT.""" 29 | 30 | variant: str 31 | pooling: str = 'tok' 32 | num_cls_tokens: int = 2 # TIPS defaults to 2 CLS tokens. 33 | dropout_rate: float = 0.0 34 | attention_dropout_rate: float = 0.0 35 | stochastic_depth: float = 0.0 36 | dtype: t.Any = jnp.float32 37 | 38 | def setup(self): 39 | super().setup() 40 | 41 | self.encoder = vit.ViT( 42 | variant=self.variant, 43 | num_cls_tokens=self.num_cls_tokens, 44 | dropout_rate=self.dropout_rate, 45 | attention_dropout_rate=self.attention_dropout_rate, 46 | stochastic_depth=self.stochastic_depth, 47 | dtype=self.dtype, 48 | ) 49 | self.patches = self.encoder.patches 50 | 51 | def pool_features(self, x: jnp.ndarray)-> t.Tuple[jnp.ndarray, jnp.ndarray]: 52 | """Extracts the spatial and vector features from the backhone. 53 | 54 | Currently supports only 'tok' pooling (CLS tokens). The CLS tokens are 55 | always prepended to the spatial (patch) tokens. 56 | 57 | Args: 58 | x: The input features. 59 | 60 | Returns: 61 | x_patch: The spatial features. 62 | x_vec: The vector embedding(s). 63 | """ 64 | if self.pooling == 'tok': 65 | x_vec = x[:, :self.num_cls_tokens, :] 66 | x_patch = x[:, self.num_cls_tokens:, :] 67 | else: 68 | raise ValueError(f'Invalid pooling: {self.pooling}') 69 | return x_patch, x_vec 70 | 71 | def reshape_spatial_features( 72 | self, x: jnp.ndarray, h: int, w: int) -> jnp.ndarray: 73 | """Re-shapes the spatial features according to the initial dimensions.""" 74 | fh = h // self.patches[0] 75 | fw = w // self.patches[1] 76 | bs, l, f = x.shape 77 | if l != fh * fw: 78 | raise ValueError(f'Invalid shape: {x.shape}') 79 | return x.reshape(bs, fh, fw, f) 80 | 81 | @nn.compact 82 | def __call__( 83 | self, x: jnp.ndarray, *, train: bool, debug: bool = False 84 | ) -> t.Tuple[jnp.ndarray, jnp.ndarray]: 85 | del debug 86 | x = vit.maybe_center_pad( 87 | x, patch_h=self.patches[0], patch_w=self.patches[1]) 88 | h, w = x.shape[1:3] # w, h of images after padding. 89 | x = self.encoder(x, train=train) 90 | x_patch, x_vec = self.pool_features(x) 91 | x_patch = self.reshape_spatial_features(x_patch, h, w) 92 | 93 | return x_patch, x_vec 94 | 95 | 96 | class TextEncoder(nn.Module): 97 | """TIPS Text encoder.""" 98 | 99 | variant: str 100 | vocab_size: int = 32_000 101 | dropout_rate: float = 0.1 102 | attention_dropout_rate: float = 0.1 103 | stochastic_depth: float = 0.0 104 | dtype: t.Any = jnp.float32 105 | scale_sqrt_depth: bool = True # Default param in PAX experiments. 106 | 107 | def setup(self): 108 | super().setup() 109 | text_config = vit.get_vit_config(self.variant) 110 | text_config['num_layers'] = 12 111 | # The text tower layers are fixed independent of vision tower size. 
112 | # Exception: The So400m/14 text tower is a symmetric copy of the vision 113 | # tower. 114 | self.num_layers = 12 115 | if self.variant != 'So400m/14': 116 | self.num_layers = text_config['num_layers'] 117 | self.embedding_dim = text_config['hidden_size'] 118 | self.mlp_dim = text_config['mlp_dim'] 119 | self.num_heads = text_config['num_heads'] 120 | self.embedder = text.Embedding( 121 | name='token_emb', 122 | num_classes=self.vocab_size, 123 | embedding_dim=self.embedding_dim, 124 | scale_sqrt_depth=self.scale_sqrt_depth) 125 | self.pos_embedder = text.PositionalEmbedding( 126 | embedding_dim=self.embedding_dim) 127 | self.transformer = text.StackedTransformer( 128 | name='transformer', 129 | mlp_dim=self.mlp_dim, 130 | num_layers=self.num_layers, 131 | num_heads=self.num_heads, 132 | dropout_rate=self.dropout_rate, 133 | attention_dropout_rate=self.attention_dropout_rate, 134 | stochastic_depth=self.stochastic_depth, 135 | dtype=self.dtype, 136 | ) 137 | self.pooling = text.GlobalAvgPooling(pooling_dims=[1]) 138 | self.norm = nn.LayerNorm(dtype=self.dtype, name='text_encoder_norm') 139 | 140 | def __call__( 141 | self, 142 | ids: jnp.ndarray, 143 | paddings: jnp.ndarray, 144 | train: bool, 145 | ) -> jnp.ndarray: 146 | """Applies TextEncoder module.""" 147 | _, seq_length = ids.shape 148 | mask = (paddings == 0).astype(jnp.int32) 149 | x = self.embedder(ids) 150 | x = x + self.pos_embedder(seq_length=seq_length) 151 | x = self.transformer(x, mask, deterministic=not train) 152 | x = self.norm(x) 153 | x = self.pooling(x, compatible_paddings=paddings[:, :, jnp.newaxis]) 154 | return x 155 | 156 | -------------------------------------------------------------------------------- /scenic/models/vit.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
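Both stacked transformers in this repo scale stochastic depth linearly with layer index, via the expression `(lyr / max(num_layers - 1, 1)) * stochastic_depth` above. A small sketch of the resulting per-layer drop rates (the layer count and overall rate are illustrative only; the module defaults here are 0.0):

```python
num_layers, stochastic_depth = 12, 0.1  # illustrative values only
rates = [(lyr / max(num_layers - 1, 1)) * stochastic_depth
         for lyr in range(num_layers)]
print([round(r, 3) for r in rates])
# [0.0, 0.009, 0.018, ..., 0.091, 0.1] -- grows linearly from 0 to the full rate.
```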
14 | # ============================================================================== 15 | 16 | """Standard ViT model definition.""" 17 | 18 | import logging 19 | import math 20 | import typing as t 21 | 22 | import flax.linen as nn 23 | import jax 24 | import jax.numpy as jnp 25 | import ml_collections 26 | import numpy as np 27 | 28 | from scenic.model_lib.layers import attention_layers 29 | from scenic.model_lib.layers import nn_layers 30 | 31 | Initializer = t.Callable[[jnp.ndarray, t.Sequence[int], jnp.dtype], jnp.ndarray] 32 | 33 | 34 | def get_vit_config(variant: str) -> t.Dict[str, t.Any]: 35 | v, patch = variant.split('/') 36 | return { 37 | # pylint:disable=line-too-long 38 | 'hidden_size': {'S': 384, 'B': 768, 'L': 1024, 'So400m': 1152, 'g': 1536}[v], 39 | 'num_layers': {'S': 12, 'B': 12, 'L': 24, 'So400m': 27, 'g': 40}[v], 40 | 'mlp_dim': {'S': 1536, 'B': 3072, 'L': 4096, 'So400m': 4304, 'g': 6144}[v], 41 | 'num_heads': {'S': 6, 'B': 12, 'L': 16, 'So400m': 16, 'g': 24}[v], 42 | 'patch_size': (int(patch), int(patch)), 43 | 'ffn_layer': {'S': 'mlp', 'B': 'mlp', 'L': 'mlp', 'So400m': 'mlp', 'g': 'swiglu'}[v], 44 | # pylint:enable=line-too-long 45 | } 46 | 47 | 48 | def maybe_center_pad(x: jnp.ndarray, patch_h: int, patch_w: int): 49 | """Pads the input to the next multiple of the patch size.""" 50 | h_old, w_old = x.shape[1:3] 51 | pad_h = math.ceil(h_old / patch_h) * patch_h - h_old 52 | pad_w = math.ceil(w_old / patch_w) * patch_w - w_old 53 | if pad_w > 0 or pad_h > 0: 54 | pad_h_top = pad_h // 2 55 | pad_h_bottom = pad_h - pad_h_top 56 | pad_w_left = pad_w // 2 57 | pad_w_right = pad_w - pad_w_left 58 | logging.info( 59 | 'Applying center padding (%d, %d), (%d, %d)', 60 | pad_w_left, pad_w_right, pad_h_top, pad_h_bottom) 61 | x = jnp.pad( 62 | x, ((0, 0), 63 | (pad_h_top, pad_h_bottom), 64 | (pad_w_left, pad_w_right), 65 | (0, 0))) 66 | return x 67 | 68 | 69 | class ToTokenSequence(nn.Module): 70 | """Transform a batch of views into a sequence of tokens.""" 71 | 72 | patches: ml_collections.ConfigDict 73 | hidden_size: int 74 | num_cls_tokens: int = 0 75 | posembs: t.Tuple[int, int] = (16, 16) 76 | pos_interpolation_method: str = 'bilinear' 77 | 78 | def add_positional_encodings(self, x: jnp.ndarray) -> jnp.ndarray: 79 | """Support a few variants for sinsuoidal 2D position embeddings.""" 80 | n, h, w, c = x.shape 81 | posemb = self.param( 82 | 'posembed_input', 83 | nn.initializers.normal(stddev=1/np.sqrt(c)), 84 | (1, self.posembs[0], self.posembs[1], c), x.dtype) 85 | # Interpolate the positional encodings. 86 | if (h, w) != self.posembs: 87 | posemb = jax.image.resize( 88 | posemb, (1, h, w, c), self.pos_interpolation_method) 89 | x = x + posemb 90 | x = jnp.reshape(x, [n, h * w, c]) 91 | 92 | assert x.ndim == 3 # Shape is `[batch, len, emb]`. 93 | return x 94 | 95 | @nn.compact 96 | def __call__(self, x: jnp.ndarray, *, seqlen: int = -1): 97 | 98 | fh, fw = self.patches 99 | # Extracting patches and then embedding is in fact a single convolution. 100 | x = nn.Conv( 101 | self.hidden_size, (fh, fw), 102 | strides=(fh, fw), 103 | padding='VALID', 104 | name='embedding')(x) 105 | 106 | # Add positional encodings. 107 | x = self.add_positional_encodings(x) 108 | 109 | # Add extra "cls" tokens. 
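    # The CLS tokens are prepended to the patch tokens; downstream,
    # VisionEncoder.pool_features (models/tips.py) slices them back off the
    # front to separate the global embeddings from the spatial features.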
110 | if self.num_cls_tokens > 0: 111 | n, _, c = x.shape 112 | cls_tok = self.param( 113 | 'cls', 114 | nn.initializers.zeros, 115 | (1, self.num_cls_tokens, c), 116 | x.dtype) 117 | cls_tok = jnp.tile(cls_tok, [n, 1, 1]) 118 | x = jnp.concatenate([cls_tok, x], axis=1) 119 | return x 120 | 121 | 122 | class FFNSwiGluFused(nn.Module): 123 | """SwiGlu variant of the feed-forward block. 124 | 125 | https://arxiv.org/abs/2002.05202v1 126 | """ 127 | 128 | mlp_dim: int 129 | out_dim: t.Optional[int] = None 130 | dropout_rate: float = 0.0 131 | use_bias: bool = False 132 | kernel_init: Initializer = nn.initializers.xavier_uniform() 133 | bias_init: Initializer = nn.initializers.zeros 134 | precision: t.Optional[jax.lax.Precision] = None 135 | dtype: jnp.ndarray = jnp.float32 136 | 137 | def _hidden_layer(self, inputs: jnp.ndarray) -> jnp.ndarray: 138 | # https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py#L57 # pylint: disable=line-too-long 139 | mlp_dim = (int(self.mlp_dim * 2 / 3) + 7) // 8 * 8 140 | xw = nn.Dense( 141 | mlp_dim, 142 | dtype=self.dtype, 143 | use_bias=self.use_bias, 144 | kernel_init=self.kernel_init, 145 | bias_init=self.bias_init, 146 | precision=self.precision, 147 | )(inputs) 148 | xv = nn.Dense( 149 | mlp_dim, 150 | dtype=self.dtype, 151 | use_bias=self.use_bias, 152 | kernel_init=self.kernel_init, 153 | bias_init=self.bias_init, 154 | precision=self.precision, 155 | )(inputs) 156 | xw = nn.swish(xw) 157 | x = xw * xv 158 | return x 159 | 160 | @nn.compact 161 | def __call__( 162 | self, inputs: jnp.ndarray, *, deterministic: bool 163 | ) -> jnp.ndarray: 164 | """Applies FFN module.""" 165 | actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim 166 | x = self._hidden_layer(inputs) 167 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) 168 | output = nn.Dense( 169 | actual_out_dim, 170 | dtype=self.dtype, 171 | use_bias=self.use_bias, 172 | kernel_init=self.kernel_init, 173 | bias_init=self.bias_init, 174 | precision=self.precision)(x) 175 | output = nn.Dropout(rate=self.dropout_rate)( 176 | output, deterministic=deterministic) 177 | return output 178 | 179 | 180 | class VisionEncoder1DBlock(nn.Module): 181 | """Transformer encoder layer. 182 | 183 | Attributes: 184 | mlp_dim: Dimension of the mlp on top of attention block. 185 | num_heads: Number of self-attention heads. 186 | dtype: The dtype of the computation (default: float32). 187 | dropout_rate: Dropout rate. 188 | attention_dropout_rate: Dropout for attention heads. 189 | stochastic_depth: probability of dropping a layer linearly grows 190 | from 0 to the provided value. 191 | ffn_layer: type of the feed-forward layer. Options are 'mlp', 'swiglufused'. 192 | 193 | Returns: 194 | output after transformer encoder block. 
195 | """ 196 | mlp_dim: int 197 | num_heads: int 198 | dtype: t.Any = jnp.float32 199 | dropout_rate: float = 0.0 200 | attention_dropout_rate: float = 0.0 201 | stochastic_depth: float = 0.0 202 | ffn_layer: str = 'mlp' 203 | 204 | def setup(self): 205 | super().setup() 206 | 207 | if self.ffn_layer == 'mlp': 208 | ffn_layer = attention_layers.MlpBlock( 209 | mlp_dim=self.mlp_dim, 210 | dtype=self.dtype, 211 | dropout_rate=self.dropout_rate, 212 | activation_fn=nn.gelu, 213 | kernel_init=nn.initializers.xavier_uniform(), 214 | bias_init=nn.initializers.normal(stddev=1e-6), 215 | name='MlpBlock_0') 216 | elif self.ffn_layer == 'swiglu': 217 | ffn_layer = FFNSwiGluFused( 218 | mlp_dim=self.mlp_dim, 219 | dtype=self.dtype, 220 | use_bias=True, 221 | dropout_rate=self.dropout_rate, 222 | kernel_init=nn.initializers.xavier_uniform(), 223 | bias_init=nn.initializers.normal(stddev=1e-6), 224 | name='FFNSwiGluFused_0') 225 | else: 226 | raise ValueError(f'Unsupported ffn_layer: {self.ffn_layer}') 227 | self.ffn = ffn_layer 228 | self.ln_0 = nn.LayerNorm(dtype=self.dtype, name='LayerNorm_0') 229 | self.ln_1 = nn.LayerNorm(dtype=self.dtype, name='LayerNorm_1') 230 | self.attention = nn.MultiHeadDotProductAttention( 231 | name='MultiHeadDotProductAttention_0', 232 | num_heads=self.num_heads, 233 | dtype=self.dtype, 234 | kernel_init=nn.initializers.xavier_uniform(), 235 | broadcast_dropout=False, 236 | dropout_rate=self.attention_dropout_rate) 237 | 238 | @nn.compact 239 | def __call__( 240 | self, inputs: jnp.ndarray, deterministic: bool) -> jnp.ndarray: 241 | """Applies Encoder1DBlock module. 242 | 243 | Args: 244 | inputs: Input data. 245 | deterministic: Deterministic or not (to apply dropout). 246 | 247 | Returns: 248 | Output after transformer encoder block. 249 | """ 250 | # Attention block. 251 | assert inputs.ndim == 3 252 | x = self.ln_0(inputs) 253 | x = self.attention(x, x, deterministic=deterministic) 254 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic) 255 | x = nn_layers.StochasticDepth(rate=self.stochastic_depth)(x, deterministic) 256 | x = x + inputs 257 | 258 | # MLP block. 
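    # Pre-LayerNorm residual branch: the configured FFN ('mlp' with GELU, or
    # the fused SwiGLU used by the g variant), with stochastic depth applied
    # before the skip connection.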
259 | y = self.ln_1(x) 260 | y = self.ffn(y, deterministic=deterministic) 261 | y = nn_layers.StochasticDepth(rate=self.stochastic_depth)(y, deterministic) 262 | return y + x 263 | 264 | 265 | class StackedTransformer(nn.Module): 266 | """Stacked transformer.""" 267 | 268 | mlp_dim: int 269 | num_layers: int 270 | num_heads: int 271 | ffn_layer: str = 'mlp' 272 | dropout_rate: float = 0.0 273 | attention_dropout_rate: float = 0.0 274 | stochastic_depth: float = 0.0 275 | dtype: t.Any = jnp.float32 276 | 277 | def setup(self): 278 | encoder_blocks = [] 279 | for lyr in range(self.num_layers): 280 | encoder_blocks.append( 281 | VisionEncoder1DBlock( 282 | mlp_dim=self.mlp_dim, 283 | num_heads=self.num_heads, 284 | dropout_rate=self.dropout_rate, 285 | attention_dropout_rate=self.attention_dropout_rate, 286 | stochastic_depth=(lyr / max(self.num_layers - 1, 1)) 287 | * self.stochastic_depth, 288 | name=f'encoderblock_{lyr}', 289 | ffn_layer=self.ffn_layer, 290 | dtype=self.dtype)) 291 | self.encoder_blocks = encoder_blocks 292 | 293 | def __call__( 294 | self, x: jnp.ndarray, deterministic: bool) -> jnp.ndarray: 295 | """Applies StackedTransformer module.""" 296 | for block in self.encoder_blocks: 297 | x = block(x, deterministic=deterministic) 298 | return x 299 | 300 | 301 | class ViT(nn.Module): 302 | """Dense Features backbone based on ViT.""" 303 | 304 | variant: str 305 | freeze_backbone: bool = False 306 | num_cls_tokens: int = 1 307 | dropout_rate: float = 0.1 308 | attention_dropout_rate: float = 0.1 309 | stochastic_depth: float = 0.0 310 | dtype: t.Any = jnp.float32 311 | 312 | def setup(self): 313 | super().setup() 314 | vit_config = get_vit_config(self.variant) 315 | self.patches = vit_config['patch_size'] 316 | self.hidden_size = vit_config['hidden_size'] 317 | self.num_layers = vit_config['num_layers'] 318 | self.mlp_dim = vit_config['mlp_dim'] 319 | self.num_heads = vit_config['num_heads'] 320 | self.ffn_layer = vit_config['ffn_layer'] 321 | 322 | # Setup for layers. 323 | self.token_fn = ToTokenSequence( 324 | name='ToTokenSequence_0', 325 | patches=self.patches, 326 | hidden_size=self.hidden_size, 327 | num_cls_tokens=self.num_cls_tokens, 328 | posembs=(16, 16), 329 | ) 330 | self.norm = nn.LayerNorm(name='encoder_norm') 331 | self.transformer = StackedTransformer( 332 | name='transformer', 333 | mlp_dim=self.mlp_dim, 334 | num_layers=self.num_layers, 335 | num_heads=self.num_heads, 336 | dropout_rate=self.dropout_rate, 337 | attention_dropout_rate=self.attention_dropout_rate, 338 | stochastic_depth=self.stochastic_depth, 339 | dtype=self.dtype, 340 | ffn_layer=self.ffn_layer, 341 | ) 342 | 343 | @nn.compact 344 | def __call__( 345 | self, x: jnp.ndarray, *, train: bool, debug: bool = False) -> jnp.ndarray: 346 | del debug 347 | logging.info('train=%s shape before padding=%s', train, x.shape) 348 | x = maybe_center_pad(x, patch_h=self.patches[0], patch_w=self.patches[1]) 349 | logging.info('train=%s shape after padding=%s', train, x.shape) 350 | 351 | x = self.token_fn(x) 352 | x = self.transformer(x, deterministic=not train) 353 | x = self.norm(x) 354 | 355 | return x 356 | -------------------------------------------------------------------------------- /scenic/run_tips_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
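As a concrete illustration of the center padding and token count above (the input resolution is hypothetical; patch size 14 and 2 CLS tokens match the TIPS defaults):

```python
import math

def padded_hw(h: int, w: int, patch: int = 14):
  # Output size of maybe_center_pad: each spatial dim rounded up to a
  # multiple of the patch size.
  return math.ceil(h / patch) * patch, math.ceil(w / patch) * patch

h, w = padded_hw(375, 500)                # -> (378, 504)
num_patch_tokens = (h // 14) * (w // 14)  # 27 * 36 = 972
num_tokens = num_patch_tokens + 2         # plus 2 CLS tokens -> 974
print(h, w, num_tokens)
```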
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Runs TIPS inference."""
17 |
18 | import argparse
19 | import os
20 | import cv2
21 | import flax.linen as nn
22 | import jax
23 | import jax.numpy as jnp
24 | import numpy as np
25 | from PIL import Image
26 |
27 | from tips.scenic.configs import tips_model_config
28 | from tips.scenic.models import text
29 | from tips.scenic.models import tips
30 | from tips.scenic.utils import checkpoint
31 | from tips.scenic.utils import feature_viz
32 |
33 |
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument(
36 |     '--image_width',
37 |     type=int,
38 |     default=448,
39 |     help='Image width.',
40 | )
41 | parser.add_argument(
42 |     '--variant',
43 |     type=str,
44 |     default='tips_oss_b14_highres_distilled',
45 |     choices=(
46 |         'tips_oss_g14_highres',
47 |         'tips_oss_g14_lowres',
48 |         'tips_oss_so400m14_highres_largetext_distilled',
49 |         'tips_oss_l14_highres_distilled',
50 |         'tips_oss_b14_highres_distilled',
51 |         'tips_oss_s14_highres_distilled',
52 |     ),
53 |     help='Model variant.',
54 | )
55 | parser.add_argument(
56 |     '--checkpoint_dir',
57 |     type=str,
58 |     default='checkpoints/',
59 |     help='The directory of the checkpoints and the tokenizer.',
60 | )
61 | parser.add_argument(
62 |     '--image_path',
63 |     type=str,
64 |     default='images/example_image.jpg',
65 |     help='The path to the image file.'
66 | )
67 |
68 |
69 | def main() -> None:
70 |   args = parser.parse_args()
71 |   image_width = args.image_width
72 |   image_shape = (image_width,) * 2
73 |   variant = args.variant
74 |   checkpoint_dir = args.checkpoint_dir
75 |   image_path = args.image_path
76 |
77 |   # Load the model configuration.
78 |   model_config = tips_model_config.get_config(variant)
79 |
80 |   # Load the vision encoder.
81 |   model_vision = tips.VisionEncoder(
82 |       variant=model_config.variant,
83 |       pooling=model_config.pooling,
84 |       num_cls_tokens=model_config.num_cls_tokens)
85 |   init_params_vision = model_vision.init(
86 |       jax.random.PRNGKey(0), jnp.ones([1, *image_shape, 3]), train=False)
87 |   params_vision = checkpoint.load_checkpoint(
88 |       os.path.join(checkpoint_dir, f'{variant}_vision.npz'),
89 |       init_params_vision['params'])
90 |
91 |   # Load the text encoder.
92 |   tokenizer_path = os.path.join(checkpoint_dir, 'tokenizer.model')
93 |   tokenizer = text.Tokenizer(tokenizer_path)
94 |   model_text = tips.TextEncoder(variant=model_config.variant)
95 |   init_params_text = model_text.init(
96 |       jax.random.PRNGKey(0),
97 |       ids=jnp.ones((4, 64), dtype=jnp.int32),
98 |       paddings=jnp.zeros((4, 64), dtype=jnp.int32),
99 |       train=False)
100 |   init_params_text['params']['temperature_contrastive'] = (
101 |       np.array(0, dtype=np.float32))
102 |   params_text = checkpoint.load_checkpoint(
103 |       os.path.join(checkpoint_dir, f'{variant}_text.npz'),
104 |       init_params_text['params'])
105 |
106 |   # Load and preprocess the image.
107 |   image = jnp.array(Image.open(image_path)).astype(jnp.float32) / 255.
108 |   image = jax.image.resize(image, (*image_shape, 3), method='bilinear')
109 |   image = image.astype(jnp.float32)
110 |
111 |   # Run inference on the image.
112 |   spatial_features, embeddings_vision = model_vision.apply(
113 |       {'params': params_vision}, image[None], train=False)
114 |   # We choose the first CLS token (the second one is better for dense tasks).
115 |   cls_token = feature_viz.normalize(embeddings_vision[:, 0, :])
116 |
117 |   # Run inference on text.
118 |   text_input = [
119 |       'A ship', 'holidays', 'a toy dinosaur', 'Two astronauts',
120 |       'a real dinosaur', 'A streetview image of burger kings',
121 |       'a streetview image of mc donalds']
122 |   text_ids, text_paddings = tokenizer.tokenize(text_input, max_len=64)
123 |   embeddings_text = model_text.apply(
124 |       {'params': params_text},
125 |       ids=text_ids,
126 |       paddings=text_paddings,
127 |       train=False)
128 |   embeddings_text = feature_viz.normalize(embeddings_text)
129 |
130 |   # Compute the softmax over temperature-scaled cosine similarities.
131 |   cos_sim = nn.softmax(
132 |       ((cls_token @ embeddings_text.T) /
133 |        params_text['temperature_contrastive']), axis=-1)
134 |   label_idxs = jnp.argmax(cos_sim, axis=-1)
135 |   cos_sim_max = jnp.max(cos_sim, axis=-1)
136 |   label_predicted = text_input[label_idxs[0].item()]
137 |   similarity = cos_sim_max[0].item()
138 |
139 |   # Compute PCA of patch tokens.
140 |   pca_obj = feature_viz.PCAVisualizer(spatial_features)
141 |   image_pca = pca_obj(spatial_features)[0]
142 |   image_pca = np.asarray(jax.image.resize(
143 |       image_pca, (*image_shape, 3), method='nearest'))
144 |
145 |   # Display the results.
146 |   cv2.imshow(
147 |       f'{label_predicted}, prob: {similarity*100:.1f}%',
148 |       np.concatenate([image, image_pca], axis=1)[..., ::-1])
149 |   cv2.waitKey(0)
150 |   cv2.destroyAllWindows()
151 |
152 |
153 | if __name__ == '__main__':
154 |   main()
155 |
--------------------------------------------------------------------------------
/scenic/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Checkpoint helper functions."""
17 |
18 | import logging
19 | import typing as t
20 | import flax
21 | import numpy as np
22 |
23 |
24 | def load_checkpoint(
25 |     checkpoint_path: str,
26 |     params_to_load: t.Dict[str, np.ndarray],
27 |     strict: bool = True,
28 | ) -> t.Dict[str, np.ndarray]:
29 |   """Loads a TIPS checkpoint and checks that the parameters are compatible."""
30 |   params_to_load_flat = flax.traverse_util.flatten_dict(params_to_load, sep='/')
31 |   params_loaded_flat = dict(np.load(checkpoint_path, allow_pickle=True))
32 |
33 |   # Check that params to load are in the checkpoint, and have identical shapes.
34 |   for k in params_to_load_flat:
35 |     if k not in params_loaded_flat:
36 |       raise ValueError(f'Param {k} not found in checkpoint.')
37 |     if params_loaded_flat[k].shape != params_to_load_flat[k].shape:
38 |       raise ValueError(f'Param {k} has wrong shape in checkpoint.')
39 |
40 |   # Check that the checkpoint does not contain extra parameter groups.
41 |   for k in params_loaded_flat:
42 |     if k not in params_to_load_flat:
43 |       if strict:
44 |         raise ValueError(f'Param {k} not found in params_to_load.')
45 |       else:
46 |         logging.warning('Param %s not found in params_to_load.', k)
47 |
48 |   return flax.traverse_util.unflatten_dict(params_loaded_flat, sep='/')
49 |
--------------------------------------------------------------------------------
/scenic/utils/feature_viz.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Visualization helpers for features."""
17 |
18 | import typing as t
19 | import jax.numpy as jnp
20 | import numpy as np
21 | from sklearn import decomposition
22 |
23 |
24 | _ArrayLike = t.Union[np.ndarray, jnp.ndarray]
25 |
26 |
27 | def normalize(x, order: int = 2):
28 |   return x / np.linalg.norm(
29 |       x, ord=order, axis=-1, keepdims=True).clip(min=1e-3)
30 |
31 |
32 | class PCAVisualizer:
33 |   """PCA visualizer."""
34 |
35 |   def __init__(
36 |       self,
37 |       features: _ArrayLike,
38 |       n_samples: int = 100000,
39 |       n_components: int = 3) -> None:
40 |     """Creates a PCA object for visualizing features of shape [..., F]."""
41 |     features = np.array(features)
42 |     pca_object = decomposition.PCA(n_components=n_components)
43 |     features = features.reshape([-1, features.shape[-1]])
44 |     features = features[np.random.randint(0, features.shape[0], n_samples), :]
45 |     pca_object.fit(features)
46 |     self.pca_object = pca_object
47 |     self.n_components = n_components
48 |
49 |   def __call__(self, features: _ArrayLike) -> np.ndarray:
50 |     """Apply PCA to features of shape [..., F]."""
51 |     features = np.array(features)
52 |     features_pca = self.pca_object.transform(
53 |         features.reshape([-1, features.shape[-1]])
54 |     ).reshape(features.shape[:-1] + (self.n_components,))
55 |     return normalize(features_pca) * 0.5 + 0.5
56 |
--------------------------------------------------------------------------------
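In scenic/models/vit.py, StackedTransformer scales stochastic depth linearly with depth: block lyr receives (lyr / max(num_layers - 1, 1)) * stochastic_depth, so the first block always keeps rate 0 and the last block gets the full configured rate. A quick check of that schedule; the num_layers and stochastic_depth values below are assumptions for illustration, not a config from the repo.

num_layers = 12
stochastic_depth = 0.2  # Assumed top rate for the deepest block.

# Same expression as in StackedTransformer.setup().
rates = [(lyr / max(num_layers - 1, 1)) * stochastic_depth
         for lyr in range(num_layers)]
print([round(r, 3) for r in rates])
# [0.0, 0.018, 0.036, ..., 0.2] -- increases linearly from 0 to the top rate.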
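The zero-shot classification step in scenic/run_tips_inference.py reduces to a temperature-scaled softmax over cosine similarities between the L2-normalized image CLS token and the L2-normalized text embeddings. Below is a minimal NumPy sketch of just that scoring step; the embedding values and the temperature are made up, and only normalize mirrors the helper in scenic/utils/feature_viz.py.

import numpy as np


def normalize(x, order=2):
  # Unit-normalize along the last axis, as in feature_viz.normalize.
  return x / np.linalg.norm(x, ord=order, axis=-1, keepdims=True).clip(min=1e-3)


def softmax(logits, axis=-1):
  z = logits - logits.max(axis=axis, keepdims=True)
  e = np.exp(z)
  return e / e.sum(axis=axis, keepdims=True)


# Toy stand-ins: one image embedding and three text embeddings of dimension 4.
cls_token = normalize(np.array([[0.2, -0.1, 0.7, 0.1]]))
embeddings_text = normalize(np.array([
    [0.3, -0.2, 0.6, 0.0],   # e.g. 'A ship'
    [-0.5, 0.4, 0.1, 0.2],   # e.g. 'holidays'
    [0.0, 0.9, -0.3, 0.1],   # e.g. 'a toy dinosaur'
]))
temperature = 0.07  # Assumed; the real value is loaded from the text checkpoint.

probs = softmax((cls_token @ embeddings_text.T) / temperature, axis=-1)
best = int(np.argmax(probs, axis=-1)[0])
print(best, float(probs[0, best]))  # Index of the best prompt and its probability.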
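scenic/utils/checkpoint.py expects the .npz archive to hold one array per parameter, keyed by the '/'-joined path produced by flax.traverse_util.flatten_dict. A small round-trip sketch of that layout, assuming the tips package is importable as in run_tips_inference.py; the parameter names, shapes, and the /tmp path are invented for illustration.

import flax
import numpy as np

from tips.scenic.utils import checkpoint

# Invented parameter tree standing in for `model.init(...)['params']`.
params_to_load = {
    'Dense_0': {'kernel': np.zeros((8, 4), np.float32),
                'bias': np.zeros((4,), np.float32)},
}

# Write a checkpoint in the expected layout: flat '/'-joined keys -> arrays.
flat = flax.traverse_util.flatten_dict(params_to_load, sep='/')
np.savez('/tmp/toy_checkpoint.npz', **flat)

# load_checkpoint verifies every expected path exists with the same shape and,
# with strict=True (the default), that the checkpoint has no extra entries.
params = checkpoint.load_checkpoint('/tmp/toy_checkpoint.npz', params_to_load)
print(params['Dense_0']['kernel'].shape)  # (8, 4)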
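PCAVisualizer in scenic/utils/feature_viz.py is fit once on a batch of features and then maps any array of shape [..., F] to a 3-channel visualization in roughly [0, 1], which the inference script resizes and displays next to the input image. A sketch on random features; the shapes and n_samples below are arbitrary choices, not values from the repo.

import numpy as np

from tips.scenic.utils import feature_viz

# Random stand-in for spatial features: batch of 1, a 32x32 patch grid, 768 dims.
features = np.random.randn(1, 32, 32, 768).astype(np.float32)

# Fit the PCA on (a subsample of) the flattened features, then project them.
pca = feature_viz.PCAVisualizer(features, n_samples=2048, n_components=3)
rgb = pca(features)[0]  # Shape (32, 32, 3), values roughly in [0, 1].
print(rgb.shape, rgb.min(), rgb.max())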