├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
│   └── images
│       ├── block_diagram.png
│       ├── overview.png
│       └── qualitative.png
├── pytorch
│   ├── TIPS_Demo.ipynb
│   ├── __init__.py
│   ├── checkpoints
│   │   └── download_checkpoints.sh
│   ├── image_encoder.py
│   ├── run_image_encoder_inference.py
│   ├── run_text_encoder_inference.py
│   └── text_encoder.py
└── scenic
    ├── checkpoints
    │   └── download_checkpoints.sh
    ├── configs
    │   └── tips_model_config.py
    ├── images
    │   ├── example_image.jpg
    │   └── example_image_2.jpg
    ├── models
    │   ├── text.py
    │   ├── tips.py
    │   └── vit.py
    ├── notebooks
    │   └── TIPS_Demo.ipynb
    ├── run_tips_inference.py
    └── utils
        ├── checkpoint.py
        └── feature_viz.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Distribution / packaging
7 | .Python
8 | build/
9 | develop-eggs/
10 | dist/
11 | downloads/
12 | eggs/
13 | .eggs/
14 | lib/
15 | lib64/
16 | parts/
17 | sdist/
18 | var/
19 | wheels/
20 | share/python-wheels/
21 | *.egg-info/
22 | .installed.cfg
23 | *.egg
24 | MANIFEST
25 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | ## Contributor License Agreement
4 |
5 | Contributions to this project must be accompanied by a Contributor License
6 | Agreement. You (or your employer) retain the copyright to your contribution;
7 | this simply gives us permission to use and redistribute your contributions as
8 | part of the project. Head over to <https://cla.developers.google.com/> to see
9 | your current agreements on file or to sign a new one.
10 |
11 | You generally only need to submit a CLA once, so if you've already submitted one
12 | (even if it was for a different project), you probably don't need to do it
13 | again.
14 |
15 | ## Code reviews
16 |
17 | All submissions, including submissions by project members, require review. We
18 | use GitHub pull requests for this purpose. Consult
19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
20 | information on using pull requests.
21 |
22 | ## Community Guidelines
23 |
24 | This project follows [Google's Open Source Community
25 | Guidelines](https://opensource.google/conduct/).
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TIPS: Text-Image Pretraining with Spatial awareness (ICLR 2025)
2 |
3 | This repository contains the implementation and models introduced in
4 | TIPS: Text-Image Pretraining with Spatial Awareness, published at ICLR 2025.
5 |
6 | **Quick Links:**
7 | [Paper](https://arxiv.org/abs/2410.16512) |
8 | [Project Website](https://gdm-tips.github.io) |
9 | [Pytorch Notebook](./pytorch/TIPS_Demo.ipynb) |
10 | [Scenic Notebook](./scenic/notebooks/TIPS_Demo.ipynb)
11 |
12 | We provide both Pytorch and Jax (Scenic) implementations:
13 |
14 | - `tips/pytorch/`: PyTorch inference for the model. The image tower largely
15 | follows the official [DINOv2 definition](https://github.com/facebookresearch/dinov2).
16 | - `tips/scenic/`: Jax-based inference using the
17 | [scenic library](https://github.com/google-research/scenic).
18 |
19 |
20 |
24 |
25 |
26 | **Abstract**
27 |
28 | While image-text representation learning has become very popular
29 | in recent years, existing models tend to lack spatial awareness and have limited
30 | direct applicability for dense understanding tasks. For this reason,
31 | self-supervised image-only pretraining is still the go-to method for many dense
32 | vision applications (e.g. depth estimation, semantic segmentation), despite the
33 | lack of explicit supervisory signals. In this paper, we close this gap between
34 | image-text and self-supervised learning, by proposing a novel general-purpose
35 | image-text model, which can be effectively used off the shelf for dense and
36 | global vision tasks. Our method, which we refer to as Text-Image Pretraining
37 | with Spatial awareness (TIPS), leverages two simple and effective insights.
38 | First, on textual supervision: we reveal that replacing noisy web image captions
39 | by synthetically generated textual descriptions boosts dense understanding
40 | performance significantly, due to a much richer signal for learning spatially
41 | aware representations. We propose an adapted training method that combines noisy
42 | and synthetic captions, resulting in improvements across both dense and global
43 | understanding tasks. Second, on the learning technique: we propose to combine
44 | contrastive image-text learning with self-supervised masked image modeling, to
45 | encourage spatial coherence, unlocking substantial enhancements for downstream
46 | applications. Building on these two ideas, we scale our model using the
47 | transformer architecture, trained on a curated set of public images. Our
48 | experiments are conducted on 8 tasks involving 16 datasets in total,
49 | demonstrating strong off-the-shelf performance on both dense and global
50 | understanding, for several image-only and image-text tasks.
51 |
52 |
53 |
54 |
58 |
59 |
60 |
61 | ## Checkpoints
62 | We provide links to all available checkpoints, for both Pytorch and Jax model
63 | definitions, together with representative evals.
64 |
65 | Model size | #Params vision / text | Pytorch ckp. | Jax ckp. | PASCAL seg.↑ | NYU-depth↓ | ImageNet-KNN↑ | UNED-KNN↑ | Flickr T→I↑ | Flickr I→T↑
66 | :---------- | :--------------------- | :------------------------------------------------------: | :------------------------------------------------------: | :---------: | :-------: | :----------: | :------: | :--------: | :--------:
67 | g/14-HR | 1.1B / 389.1M | [vision][pth-g14-hr-vision] \| [text][pth-g14-hr-text] | [vision][jax-g14-hr-vision] \| [text][jax-g14-hr-text] | 83.1 | 0.363 | 83.2 | 68.4 | 93.8 | 83.8
68 | g/14-LR | 1.1B / 389.1M | [vision][pth-g14-lr-vision] \| [text][pth-g14-lr-text] | [vision][jax-g14-lr-vision] \| [text][jax-g14-lr-text] | 82.0 | 0.390 | 83.6 | 71.5 | 93.4 | 82.1
69 | SO/14-HR | 412.4M / 448.3M | [vision][pth-so14-hr-vision] \| [text][pth-so14-hr-text] | [vision][jax-so14-hr-vision] \| [text][jax-so14-hr-text] | 83.7 | 0.362 | 83.0 | 68.6 | 94.2 | 83.8
70 | L/14-HR | 303.2M / 183.9M | [vision][pth-l14-hr-vision] \| [text][pth-l14-hr-text] | [vision][jax-l14-hr-vision] \| [text][jax-l14-hr-text] | 83.9 | 0.372 | 82.5 | 67.8 | 93.6 | 83.5
71 | B/14-HR | 85.7M / 109.6M | [vision][pth-b14-hr-vision] \| [text][pth-b14-hr-text] | [vision][jax-b14-hr-vision] \| [text][jax-b14-hr-text] | 82.9 | 0.379 | 80.0 | 62.7 | 91.3 | 79.4
72 | S/14-HR | 21.6M / 33.6M | [vision][pth-s14-hr-vision] \| [text][pth-s14-hr-text] | [vision][jax-s14-hr-vision] \| [text][jax-s14-hr-text] | 80.6 | 0.425 | 75.1 | 57.7 | 86.3 | 74.7
73 |
74 | ## Using Pytorch
75 |
76 | ### Installation
77 | Manage dependencies with a custom environment (e.g., Conda):
78 |
79 | ```bash
80 | conda create -n tips python=3.11
81 |
82 | # Activate the environment.
83 | conda activate tips
84 | ```
85 |
86 | Install Pytorch dependencies.
87 |
88 | ```bash
89 | # Install pytorch (change to GPU version if needed)
90 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
91 |
92 | # Install other dependencies.
93 | pip install tensorflow_text mediapy jax jaxlib scikit-learn
94 |
95 | # Optionally, install Jupyter to use the notebook.
96 | pip install jupyter
97 | ```
98 |
99 | Clone the code from this repo.
100 |
101 | ```bash
102 | git clone https://github.com/google-deepmind/tips.git
103 |
104 | # Add the current directory to PYTHONPATH.
105 | export PYTHONPATH=$PYTHONPATH:$(pwd)
106 | ```
107 |
108 | Download the checkpoints locally. The script downloads all released checkpoints;
109 | edit it if you only need a subset (see the example below).
110 |
111 | ```bash
112 | cd tips/pytorch/checkpoints
113 | chmod +x download_checkpoints.sh
114 | ./download_checkpoints.sh
115 | cd ../../..
116 | ```
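
If only a single model is needed, the corresponding files can also be fetched
directly with `wget`, using the URLs from the checkpoints table above. A sketch
for B/14-HR (file names follow the pattern used in `download_checkpoints.sh`):

```bash
# Sketch: download only the B/14 high-res checkpoint plus the tokenizer.
cd tips/pytorch/checkpoints
wget https://storage.googleapis.com/tips_data/v1_0/checkpoints/tokenizer.model
wget https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_b14_highres_distilled_vision.npz
wget https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_b14_highres_distilled_text.npz
cd ../../..
```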
117 |
118 | ### Usage (Pytorch)
119 |
120 | To run inference on a single image and obtain the L2-normalized image embeddings
121 | from the 1st and 2nd CLS tokens, use the following:
122 |
123 | ```bash
124 | cd tips/pytorch && \
125 | python run_image_encoder_inference.py \
126 | --model_path=${PATH_TO_CHECKPOINT} \
127 | --image_file=${PATH_TO_IMAGE} \
128 | --model_variant=${MODEL_VARIANT}
129 | ```
130 |
131 | Use the `is_low_res` flag to specify whether a low-resolution or a high-resolution
132 | checkpoint is being loaded; an example invocation is shown below.
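
For example, a hypothetical invocation with the B/14 high-resolution checkpoint
downloaded above might look as follows (the exact `--model_variant` string and
flag defaults are assumptions; check them against `run_image_encoder_inference.py`):

```bash
cd tips/pytorch && \
python run_image_encoder_inference.py \
  --model_path=checkpoints/tips_oss_b14_highres_distilled_vision.npz \
  --image_file=../scenic/images/example_image.jpg \
  --model_variant=B/14 \
  --is_low_res=false
```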
133 |
134 | To run text encoder inference and obtain the L2-normalized text embedding, use
135 | the following command:
136 |
137 | ```bash
138 | cd tips/pytorch && \
139 | python run_text_encoder_inference.py \
140 | --model_path=${PATH_TO_CHECKPOINT} \
141 | --tokenizer_path=${PATH_TO_TOKENIZER} \
142 | --model_variant=${MODEL_VARIANT} \
143 | --text_input=${TEXT_INPUT}
144 | ```
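
For instance (again with hypothetical values; the tokenizer is the
`tokenizer.model` file fetched by the download script):

```bash
cd tips/pytorch && \
python run_text_encoder_inference.py \
  --model_path=checkpoints/tips_oss_b14_highres_distilled_text.npz \
  --tokenizer_path=checkpoints/tokenizer.model \
  --model_variant=B/14 \
  --text_input="a photo of a dog"
```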
145 |
146 | We also provide a simple notebook demo:
147 |
148 | ```bash
149 | jupyter-notebook
150 | ```
151 | Then navigate to `tips/pytorch/TIPS_Demo.ipynb`.
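
The vision tower can also be built programmatically from
`tips/pytorch/image_encoder.py`. The snippet below is a minimal sketch with
illustrative (not released) hyper-parameters; it runs a random tensor through an
untrained `VisionTransformer` to show the output structure, assuming the
repository's parent directory is on `PYTHONPATH` as in the installation step.
Loading released weights is handled by `run_image_encoder_inference.py` and the
notebook.

```python
# Minimal sketch: build a ViT from image_encoder.py and inspect its outputs.
# Hyper-parameters below are illustrative assumptions, not a released config.
import torch
from tips.pytorch import image_encoder

model = image_encoder.VisionTransformer(
    img_size=224,
    patch_size=14,           # TIPS variants use 14x14 patches (e.g. B/14).
    embed_dim=768,
    depth=12,
    num_heads=12,
    ffn_layer="mlp",
    num_register_tokens=1,   # assumption: one extra CLS ("register") token.
)
model.eval()

with torch.no_grad():
    out = model.forward_features(torch.randn(1, 3, 224, 224))

print(out["x_norm_1st_clstoken"].shape)  # 1st CLS token, (1, 1, 768)
print(out["x_norm_2nd_clstoken"].shape)  # 2nd CLS token(s), (1, 1, 768)
print(out["x_norm_patchtokens"].shape)   # dense patch tokens, (1, 256, 768)
```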
152 |
153 | ## Using Jax (Scenic)
154 |
155 | ### Installation
156 | As with Pytorch, manage dependencies with a custom environment.
157 |
158 | ```bash
159 | conda create -n tips python=3.11
160 |
161 | # Activate the environment.
162 | conda activate tips
163 | ```
164 |
165 | ```bash
166 | # Install scenic.
167 | git clone https://github.com/google-research/scenic.git scenic_src
168 | cd scenic_src
169 | pip install .
170 | cd ..
171 | rm -rf scenic_src
172 |
173 | # Install other dependencies.
174 | pip install pillow scikit-learn opencv-python tensorflow_text
175 |
176 | # Optionally, install Jupyter to use the notebook.
177 | pip install jupyter mediapy
178 |
179 | # In case of using CUDA, install the CUDA-supported JAX libraries.
180 | # For example, for CUDA 12 run:
181 | # pip install --upgrade "jax[cuda12_pip]" -f \
182 | # https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
183 | ```
184 |
185 | Clone the code from this repo.
186 |
187 | ```bash
188 | git clone https://github.com/google-deepmind/tips.git
189 |
190 | # Add the current directory to PYTHONPATH.
191 | export PYTHONPATH=$PYTHONPATH:$(pwd)
192 | ```
193 |
194 | Download the checkpoints (these are different files from the Pytorch ones).
195 |
196 | ```bash
197 | cd tips/scenic/checkpoints
198 | chmod +x download_checkpoints.sh
199 | ./download_checkpoints.sh
200 | cd ../../..
201 | ```
202 |
203 | ### Usage (Jax)
204 |
205 | To run inference on an image, use the following script:
206 |
207 | ```bash
208 | cd tips/scenic
209 | python run_tips_inference.py
210 | ```
211 |
212 | Alternatively, try the demo in the notebook:
213 |
214 | ```bash
215 | jupyter-notebook
216 | ```
217 | Then navigate to `tips/scenic/notebooks/TIPS_Demo.ipynb`.
218 |
219 | ## Citing this work
220 |
221 | The paper can be found on [arXiv](https://arxiv.org/abs/2410.16512).
222 | Please consider citing this work using:
223 |
224 | ```
225 | @InProceedings{tips_paper,
226 | Title={{TIPS: Text-Image Pretraining with Spatial Awareness}},
227 | Author={Maninis, Kevis-Kokitsi and Chen, Kaifeng and Ghosh, Soham and Karpur, Arjun and Chen, Koert and Xia, Ye and Cao, Bingyi and Salz, Daniel and Han, Guangxing and Dlabal, Jan and Gnanapragasam, Dan and Seyedhosseini, Mojtaba and Zhou, Howard and Araujo, Andr\'e},
228 | Booktitle={ICLR},
229 | year={2025},
230 | }
231 | ```
232 |
233 | ## License and disclaimer
234 |
235 | Copyright 2025 DeepMind Technologies Limited
236 |
237 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0);
238 | you may not use this file except in compliance with the Apache 2.0 license.
239 | You may obtain a copy of the Apache 2.0 license at:
240 | https://www.apache.org/licenses/LICENSE-2.0
241 |
242 | All other materials are licensed under the Creative Commons Attribution 4.0
243 | International License (CC-BY). You may obtain a copy of the CC-BY license at:
244 | https://creativecommons.org/licenses/by/4.0/legalcode
245 |
246 | Unless required by applicable law or agreed to in writing, all software and
247 | materials distributed here under the Apache 2.0 or CC-BY licenses are
248 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
249 | either express or implied. See the licenses for the specific language governing
250 | permissions and limitations under those licenses.
251 |
252 | This is not an official Google product.
253 |
254 | [jax-g14-hr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_g14_highres_vision.npz
255 | [jax-g14-hr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_g14_highres_text.npz
256 | [jax-g14-lr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_g14_lowres_vision.npz
257 | [jax-g14-lr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_g14_lowres_text.npz
258 | [jax-so14-hr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_so400m14_highres_largetext_distilled_vision.npz
259 | [jax-so14-hr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_so400m14_highres_largetext_distilled_text.npz
260 | [jax-l14-hr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_l14_highres_distilled_vision.npz
261 | [jax-l14-hr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_l14_highres_distilled_text.npz
262 | [jax-b14-hr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_b14_highres_distilled_vision.npz
263 | [jax-b14-hr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_b14_highres_distilled_text.npz
264 | [jax-s14-hr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_s14_highres_distilled_vision.npz
265 | [jax-s14-hr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/tips_oss_s14_highres_distilled_text.npz
266 |
267 | [pth-g14-hr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_g14_highres_vision.npz
268 | [pth-g14-hr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_g14_highres_text.npz
269 | [pth-g14-lr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_g14_lowres_vision.npz
270 | [pth-g14-lr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_g14_lowres_text.npz
271 | [pth-so14-hr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_so400m14_highres_largetext_distilled_vision.npz
272 | [pth-so14-hr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_so400m14_highres_largetext_distilled_text.npz
273 | [pth-l14-hr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_l14_highres_distilled_vision.npz
274 | [pth-l14-hr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_l14_highres_distilled_text.npz
275 | [pth-b14-hr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_b14_highres_distilled_vision.npz
276 | [pth-b14-hr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_b14_highres_distilled_text.npz
277 | [pth-s14-hr-vision]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_s14_highres_distilled_vision.npz
278 | [pth-s14-hr-text]: https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/tips_oss_s14_highres_distilled_text.npz
279 |
--------------------------------------------------------------------------------
/docs/images/block_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/tips/c1be95c4cba9345f1d207e652c6e285ed7c2ec04/docs/images/block_diagram.png
--------------------------------------------------------------------------------
/docs/images/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/tips/c1be95c4cba9345f1d207e652c6e285ed7c2ec04/docs/images/overview.png
--------------------------------------------------------------------------------
/docs/images/qualitative.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/tips/c1be95c4cba9345f1d207e652c6e285ed7c2ec04/docs/images/qualitative.png
--------------------------------------------------------------------------------
/pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | # All software is licensed under the Apache License, Version 2.0 (Apache 2.0);
17 | # you may not use this file except in compliance with the Apache 2.0 license.
18 | # You may obtain a copy of the Apache 2.0 license at:
19 | # https://www.apache.org/licenses/LICENSE-2.0
20 |
21 | # All other materials are licensed under the Creative Commons Attribution 4.0
22 | # International License (CC-BY). You may obtain a copy of the CC-BY license at:
23 | # https://creativecommons.org/licenses/by/4.0/legalcode
24 |
25 | # Unless required by applicable law or agreed to in writing, all software and
26 | # materials distributed here under the Apache 2.0 or CC-BY licenses are
27 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
28 | # either express or implied. See the licenses for the specific language
29 | # governing permissions and limitations under those licenses.
30 |
31 | # This is not an official Google product.
32 | """Import all files."""
33 |
--------------------------------------------------------------------------------
/pytorch/checkpoints/download_checkpoints.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2025 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 |
18 | # The model weights can be found in https://console.cloud.google.com/storage/browser/tips_data
19 | ALL_CHECKPOINTS=(
20 | "tips_oss_s14_highres_distilled"
21 | "tips_oss_b14_highres_distilled"
22 | "tips_oss_l14_highres_distilled"
23 | "tips_oss_so400m14_highres_largetext_distilled"
24 | "tips_oss_g14_lowres"
25 | "tips_oss_g14_highres"
26 | )
27 |
28 | echo "Downloading the tokenizer."
29 | wget https://storage.googleapis.com/tips_data/v1_0/checkpoints/tokenizer.model
30 |
31 | for CHECKPOINT in "${ALL_CHECKPOINTS[@]}"; do
32 | echo "Downloading ${CHECKPOINT} (vision encoder weights)"
33 | wget https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/${CHECKPOINT}_vision.npz
34 | echo "Downloading ${CHECKPOINT} (text encoder weights)"
35 | wget https://storage.googleapis.com/tips_data/v1_0/checkpoints/pytorch/${CHECKPOINT}_text.npz
36 | done
37 |
--------------------------------------------------------------------------------
/pytorch/image_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Vision encoder implementation in PyTorch."""
17 |
18 | import functools
19 | import math
20 | import os
21 | from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
22 | import warnings
23 | import torch
24 | from torch import nn
25 | import torch.nn.functional as F
26 | import torch.utils.checkpoint
27 |
28 |
29 | class Mlp(nn.Module):
30 | """Transformer MLP, following DINOv2 implementation."""
31 |
32 | def __init__(
33 | self,
34 | in_features: int,
35 | hidden_features: Optional[int] = None,
36 | out_features: Optional[int] = None,
37 | act_layer: Callable[..., nn.Module] = nn.GELU,
38 | drop: float = 0.0,
39 | bias: bool = True,
40 | ) -> None:
41 | super().__init__()
42 | out_features = out_features or in_features
43 | hidden_features = hidden_features or in_features
44 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
45 | self.act = act_layer()
46 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
47 | self.drop = nn.Dropout(drop)
48 |
49 | def forward(self, x: torch.Tensor) -> torch.Tensor:
50 | x = self.fc1(x)
51 | x = self.act(x)
52 | x = self.drop(x)
53 | x = self.fc2(x)
54 | x = self.drop(x)
55 | return x
56 |
57 |
58 | def make_2tuple(x):
59 | if isinstance(x, tuple):
60 | assert len(x) == 2
61 | return x
62 |
63 | assert isinstance(x, int)
64 | return (x, x)
65 |
66 |
67 | class PatchEmbed(nn.Module):
68 | """2D image to patch embedding: (B,C,H,W) -> (B,N,D)."""
69 |
70 | def __init__(
71 | self,
72 | img_size: Union[int, Tuple[int, int]] = 224,
73 | patch_size: Union[int, Tuple[int, int]] = 16,
74 | in_chans: int = 3,
75 | embed_dim: int = 768,
76 | norm_layer: Optional[Callable] = None, # pylint: disable=g-bare-generic
77 | flatten_embedding: bool = True,
78 | ) -> None:
79 | super().__init__()
80 |
81 | image_hw = make_2tuple(img_size)
82 | patch_hw = make_2tuple(patch_size)
83 | patch_grid_size = (
84 | image_hw[0] // patch_hw[0],
85 | image_hw[1] // patch_hw[1],
86 | )
87 |
88 | self.img_size = image_hw
89 | self.patch_size = patch_hw
90 | self.patches_resolution = patch_grid_size
91 | self.num_patches = patch_grid_size[0] * patch_grid_size[1]
92 |
93 | self.in_chans = in_chans
94 | self.embed_dim = embed_dim
95 |
96 | self.flatten_embedding = flatten_embedding
97 |
98 | self.proj = nn.Conv2d(
99 | in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw
100 | )
101 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
102 |
103 | def forward(self, x: torch.Tensor) -> torch.Tensor:
104 | _, _, h, w = x.shape
105 | patch_h, patch_w = self.patch_size
106 |
107 | assert (
108 | h % patch_h == 0
109 | ), f"Input image height {h} is not a multiple of patch height {patch_h}"
110 | assert (
111 | w % patch_w == 0
112 | ), f"Input image width {w} is not a multiple of patch width: {patch_w}"
113 |
114 | x = self.proj(x) # B C H W
115 | h, w = x.size(2), x.size(3)
116 | x = x.flatten(2).transpose(1, 2) # B HW C
117 | x = self.norm(x)
118 | if not self.flatten_embedding:
119 | x = x.reshape(-1, h, w, self.embed_dim) # B H W C
120 | return x
121 |
122 | def flops(self) -> float:
123 | ho, wo = self.patches_resolution
124 | flops = (
125 | ho
126 | * wo
127 | * self.embed_dim
128 | * self.in_chans
129 | * (self.patch_size[0] * self.patch_size[1])
130 | )
131 | if self.norm is not None:
132 | flops += ho * wo * self.embed_dim
133 | return flops
134 |
135 |
136 | class SwiGLUFFN(nn.Module):
137 | """SwiGLU FFN layer, following DINOv2 implementation."""
138 |
139 | def __init__(
140 | self,
141 | in_features: int,
142 | hidden_features: Optional[int] = None,
143 | out_features: Optional[int] = None,
144 | act_layer: Callable[..., nn.Module] = None,
145 | drop: float = 0.0,
146 | bias: bool = True,
147 | ) -> None:
148 | super().__init__()
149 | out_features = out_features or in_features
150 | hidden_features = hidden_features or in_features
151 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
152 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
153 |
154 | def forward(self, x: torch.Tensor) -> torch.Tensor:
155 | x12 = self.w12(x)
156 | x1, x2 = x12.chunk(2, dim=-1)
157 | hidden = F.silu(x1) * x2
158 | return self.w3(hidden)
159 |
160 |
161 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
162 | try:
163 | if XFORMERS_ENABLED:
164 | from xformers.ops import SwiGLU, memory_efficient_attention, unbind, fmha, scaled_index_add, index_select_cat # pylint: disable=g-multiple-import, g-import-not-at-top
165 |
166 | XFORMERS_AVAILABLE = True
167 | warnings.warn("xFormers is available (SwiGLU)")
168 | else:
169 | warnings.warn("xFormers is disabled (SwiGLU)")
170 | raise ImportError
171 | except ImportError:
172 | SwiGLU = SwiGLUFFN
173 | XFORMERS_AVAILABLE = False
174 |
175 | warnings.warn("xFormers is not available (SwiGLU)")
176 |
177 |
178 | class SwiGLUFFNFused(SwiGLU):
179 | """SwiGLU FFN layer, following DINOv2 implementation."""
180 |
181 | def __init__(
182 | self,
183 | in_features: int,
184 | hidden_features: Optional[int] = None,
185 | out_features: Optional[int] = None,
186 | act_layer: Callable[..., nn.Module] = None, # pylint: disable=unused-argument
187 | drop: float = 0.0, # pylint: disable=unused-argument
188 | bias: bool = True,
189 | ) -> None:
190 | out_features = out_features or in_features
191 | hidden_features = hidden_features or in_features
192 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
193 | super().__init__(
194 | in_features=in_features,
195 | hidden_features=hidden_features,
196 | out_features=out_features,
197 | bias=bias,
198 | )
199 |
200 |
201 | class Attention(nn.Module):
202 | """Attention layer, following DINOv2 implementation."""
203 |
204 | def __init__(
205 | self,
206 | dim: int,
207 | num_heads: int = 8,
208 | qkv_bias: bool = False,
209 | proj_bias: bool = True,
210 | attn_drop: float = 0.0,
211 | proj_drop: float = 0.0,
212 | ) -> None:
213 | super().__init__()
214 | self.num_heads = num_heads
215 | head_dim = dim // num_heads
216 | self.scale = head_dim**-0.5
217 |
218 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
219 | self.attn_drop = nn.Dropout(attn_drop)
220 | self.proj = nn.Linear(dim, dim, bias=proj_bias)
221 | self.proj_drop = nn.Dropout(proj_drop)
222 |
223 | def forward(self, x: torch.Tensor) -> torch.Tensor:
224 | b_dim, n_dim, c_dim = x.shape
225 | qkv = (
226 | self.qkv(x)
227 | .reshape(b_dim, n_dim, 3, self.num_heads, c_dim // self.num_heads)
228 | .permute(2, 0, 3, 1, 4)
229 | )
230 |
231 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
232 | attn = q @ k.transpose(-2, -1)
233 |
234 | attn = attn.softmax(dim=-1)
235 | attn = self.attn_drop(attn)
236 |
237 | x = (attn @ v).transpose(1, 2).reshape(b_dim, n_dim, c_dim)
238 | x = self.proj(x)
239 | x = self.proj_drop(x)
240 | return x
241 |
242 |
243 | class MemEffAttention(Attention):
244 | """Memory Efficient Attention layer, following DINOv2 implementation."""
245 |
246 | def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
247 | if not XFORMERS_AVAILABLE:
248 | if attn_bias is not None:
249 | raise AssertionError("xFormers is required for using nested tensors")
250 | return super().forward(x)
251 |
252 | b_dim, n_dim, c_dim = x.shape
253 | qkv = self.qkv(x).reshape(
254 | b_dim, n_dim, 3, self.num_heads, c_dim // self.num_heads
255 | )
256 |
257 | q, k, v = unbind(qkv, 2)
258 |
259 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
260 | x = x.reshape([b_dim, n_dim, c_dim])
261 |
262 | x = self.proj(x)
263 | x = self.proj_drop(x)
264 | return x
265 |
266 |
267 | class LayerScale(nn.Module):
268 | """Layer scale, following DINOv2 implementation."""
269 |
270 | def __init__(
271 | self,
272 | dim: int,
273 | init_values: Union[float, torch.Tensor] = 1e-5,
274 | inplace: bool = False,
275 | ) -> None:
276 | super().__init__()
277 | self.inplace = inplace
278 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
279 |
280 | def forward(self, x: torch.Tensor) -> torch.Tensor:
281 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
282 |
283 |
284 | def drop_path_impl(x, drop_prob: float = 0.0, training: bool = False):
285 | if drop_prob == 0.0 or not training:
286 | return x
287 | keep_prob = 1 - drop_prob
288 | shape = (x.shape[0],) + (1,) * (
289 | x.ndim - 1
290 | ) # work with diff dim tensors, not just 2D ConvNets
291 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
292 | if keep_prob > 0.0:
293 | random_tensor.div_(keep_prob)
294 | output = x * random_tensor
295 | return output
296 |
297 |
298 | class DropPath(nn.Module):
299 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
300 |
301 | def __init__(self, drop_prob=None):
302 | super(DropPath, self).__init__()
303 | self.drop_prob = drop_prob
304 |
305 | def forward(self, x):
306 | return drop_path_impl(x, self.drop_prob, self.training)
307 |
308 |
309 | class Block(nn.Module):
310 | """Transformer Block Implementation, following DINOv2 implementation."""
311 |
312 | def __init__(
313 | self,
314 | dim: int,
315 | num_heads: int,
316 | mlp_ratio: float = 4.0,
317 | qkv_bias: bool = False,
318 | proj_bias: bool = True,
319 | ffn_bias: bool = True,
320 | drop: float = 0.0,
321 | attn_drop: float = 0.0,
322 | init_values=None,
323 | drop_path: float = 0.0,
324 | act_layer: Callable[..., nn.Module] = nn.GELU,
325 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
326 | attn_class: Callable[..., nn.Module] = Attention,
327 | ffn_layer: Callable[..., nn.Module] = Mlp,
328 | ) -> None:
329 | super().__init__()
330 | self.norm1 = norm_layer(dim)
331 | self.attn = attn_class(
332 | dim,
333 | num_heads=num_heads,
334 | qkv_bias=qkv_bias,
335 | proj_bias=proj_bias,
336 | attn_drop=attn_drop,
337 | proj_drop=drop,
338 | )
339 | self.ls1 = (
340 | LayerScale(dim, init_values=init_values)
341 | if init_values
342 | else nn.Identity()
343 | )
344 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
345 |
346 | self.norm2 = norm_layer(dim)
347 | mlp_hidden_dim = int(dim * mlp_ratio)
348 | self.mlp = ffn_layer(
349 | in_features=dim,
350 | hidden_features=mlp_hidden_dim,
351 | act_layer=act_layer,
352 | drop=drop,
353 | bias=ffn_bias,
354 | )
355 | self.ls2 = (
356 | LayerScale(dim, init_values=init_values)
357 | if init_values
358 | else nn.Identity()
359 | )
360 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
361 |
362 | self.sample_drop_ratio = drop_path
363 |
364 | def forward(self, x: torch.Tensor) -> torch.Tensor:
365 | def attn_residual_func(x: torch.Tensor) -> torch.Tensor:
366 | return self.ls1(self.attn(self.norm1(x)))
367 |
368 | def ffn_residual_func(x: torch.Tensor) -> torch.Tensor:
369 | return self.ls2(self.mlp(self.norm2(x)))
370 |
371 | if self.training and self.sample_drop_ratio > 0.1:
372 | # the overhead is compensated only for a drop path rate larger than 0.1
373 | x = drop_add_residual_stochastic_depth(
374 | x,
375 | residual_func=attn_residual_func,
376 | sample_drop_ratio=self.sample_drop_ratio,
377 | )
378 | x = drop_add_residual_stochastic_depth(
379 | x,
380 | residual_func=ffn_residual_func,
381 | sample_drop_ratio=self.sample_drop_ratio,
382 | )
383 | elif self.training and self.sample_drop_ratio > 0.0:
384 | x = x + self.drop_path1(attn_residual_func(x))
385 | x = x + self.drop_path1(ffn_residual_func(x))
386 | else:
387 | x = x + attn_residual_func(x)
388 | x = x + ffn_residual_func(x)
389 | return x
390 |
391 |
392 | def drop_add_residual_stochastic_depth(
393 | x: torch.Tensor,
394 | residual_func: Callable[[torch.Tensor], torch.Tensor],
395 | sample_drop_ratio: float = 0.0,
396 | ) -> torch.Tensor:
397 | """This function is taken from the original implementation in DINOv2 to implement stochastic depth in the image encoder."""
398 | # 1) extract subset using permutation
399 | b, _, _ = x.shape
400 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
401 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
402 | x_subset = x[brange]
403 |
404 | # 2) apply residual_func to get residual
405 | residual = residual_func(x_subset)
406 |
407 | x_flat = x.flatten(1)
408 | residual = residual.flatten(1)
409 |
410 | residual_scale_factor = b / sample_subset_size
411 |
412 | # 3) add the residual
413 | x_plus_residual = torch.index_add(
414 | x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
415 | )
416 | return x_plus_residual.view_as(x)
417 |
418 |
419 | def get_branges_scales(x, sample_drop_ratio=0.0):
420 | b, _, _ = x.shape
421 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
422 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
423 | residual_scale_factor = b / sample_subset_size
424 | return brange, residual_scale_factor
425 |
426 |
427 | def add_residual(
428 | x, brange, residual, residual_scale_factor, scaling_vector=None
429 | ):
430 | """Implement residual addition in the image encoder."""
431 | if scaling_vector is None:
432 | x_flat = x.flatten(1)
433 | residual = residual.flatten(1)
434 | x_plus_residual = torch.index_add(
435 | x_flat,
436 | 0,
437 | brange,
438 | residual.to(dtype=x.dtype),
439 | alpha=residual_scale_factor,
440 | )
441 | else:
442 | x_plus_residual = scaled_index_add(
443 | x,
444 | brange,
445 | residual.to(dtype=x.dtype),
446 | scaling=scaling_vector,
447 | alpha=residual_scale_factor,
448 | )
449 | return x_plus_residual
450 |
451 |
452 | attn_bias_cache: Dict[Tuple, Any] = {} # pylint: disable=g-bare-generic
453 |
454 |
455 | def get_attn_bias_and_cat(x_list, branges=None):
456 | """this will perform the index select, cat the tensors, and provide the attn_bias from cache."""
457 | batch_sizes = (
458 | [b.shape[0] for b in branges]
459 | if branges is not None
460 | else [x.shape[0] for x in x_list]
461 | )
462 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
463 | if all_shapes not in attn_bias_cache.keys():
464 | seqlens = []
465 | for b, x in zip(batch_sizes, x_list):
466 | for _ in range(b):
467 | seqlens.append(x.shape[1])
468 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
469 | attn_bias._batch_sizes = batch_sizes # pylint: disable=protected-access
470 | attn_bias_cache[all_shapes] = attn_bias
471 |
472 | if branges is not None:
473 | cat_tensors = index_select_cat(
474 | [x.flatten(1) for x in x_list], branges
475 | ).view(1, -1, x_list[0].shape[-1])
476 | else:
477 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
478 | cat_tensors = torch.cat(tensors_bs1, dim=1)
479 |
480 | return attn_bias_cache[all_shapes], cat_tensors
481 |
482 |
483 | def drop_add_residual_stochastic_depth_list(
484 | x_list: List[torch.Tensor],
485 | residual_func: Callable[[torch.Tensor, Any], torch.Tensor],
486 | sample_drop_ratio: float = 0.0,
487 | scaling_vector=None,
488 | ) -> List[torch.Tensor]:
489 | """Add residual to a list of tensors."""
490 | # 1) generate random set of indices for dropping samples in the batch.
491 | branges_scales = [
492 | get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
493 | ]
494 | branges = [s[0] for s in branges_scales]
495 | residual_scale_factors = [s[1] for s in branges_scales]
496 |
497 | # 2) get attention bias and index+concat the tensors.
498 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
499 |
500 | # 3) apply residual_func to get residual, and split the result.
501 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
502 |
503 | outputs = []
504 | for x, brange, residual, residual_scale_factor in zip(
505 | x_list, branges, residual_list, residual_scale_factors
506 | ):
507 | outputs.append(
508 | add_residual(
509 | x, brange, residual, residual_scale_factor, scaling_vector
510 | ).view_as(x)
511 | )
512 | return outputs
513 |
514 |
515 | class NestedTensorBlock(Block):
516 | """Nested tensor block implementation."""
517 |
518 | def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
519 | """x_list contains a list of tensors to nest together and run."""
520 | assert isinstance(self.attn, MemEffAttention)
521 |
522 | if self.training and self.sample_drop_ratio > 0.0:
523 |
524 | def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
525 | return self.attn(self.norm1(x), attn_bias=attn_bias)
526 |
527 | def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
528 | del attn_bias
529 | return self.mlp(self.norm2(x))
530 |
531 | x_list = drop_add_residual_stochastic_depth_list(
532 | x_list,
533 | residual_func=attn_residual_func,
534 | sample_drop_ratio=self.sample_drop_ratio,
535 | scaling_vector=self.ls1.gamma
536 | if isinstance(self.ls1, LayerScale)
537 | else None,
538 | )
539 | x_list = drop_add_residual_stochastic_depth_list(
540 | x_list,
541 | residual_func=ffn_residual_func,
542 | sample_drop_ratio=self.sample_drop_ratio,
543 | scaling_vector=self.ls2.gamma
544 | if isinstance(self.ls1, LayerScale)
545 | else None,
546 | )
547 | return x_list
548 | else:
549 |
550 | def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
551 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
552 |
553 | def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
554 | del attn_bias
555 | return self.ls2(self.mlp(self.norm2(x)))
556 |
557 | attn_bias, x = get_attn_bias_and_cat(x_list)
558 | x = x + attn_residual_func(x, attn_bias=attn_bias)
559 | x = x + ffn_residual_func(x)
560 | return attn_bias.split(x)
561 |
562 | def forward(self, x):
563 | if isinstance(x, torch.Tensor):
564 | return super().forward(x)
565 | elif isinstance(x, list):
566 | if not XFORMERS_AVAILABLE:
567 | raise AssertionError("xFormers is required for using nested tensors")
568 | return self.forward_nested(x)
569 | else:
570 | raise AssertionError
571 |
572 |
573 | def named_apply(
574 | fn: Callable, # pylint: disable=g-bare-generic
575 | module: nn.Module,
576 | name="",
577 | depth_first=True,
578 | include_root=False,
579 | ) -> nn.Module:
580 | """Apply a function to a module and its children."""
581 | if not depth_first and include_root:
582 | fn(module=module, name=name)
583 | for child_name, child_module in module.named_children():
584 | child_name = ".".join((name, child_name)) if name else child_name
585 | named_apply(
586 | fn=fn,
587 | module=child_module,
588 | name=child_name,
589 | depth_first=depth_first,
590 | include_root=True,
591 | )
592 | if depth_first and include_root:
593 | fn(module=module, name=name)
594 | return module
595 |
596 |
597 | class BlockChunk(nn.ModuleList):
598 |
599 | def forward(self, x):
600 | for b in self:
601 | x = b(x)
602 | return x
603 |
604 |
605 | class VisionTransformer(nn.Module):
606 | """Vision Transformer implementation."""
607 |
608 | def __init__(
609 | self,
610 | img_size=224,
611 | patch_size=16,
612 | in_chans=3,
613 | embed_dim=768,
614 | depth=12,
615 | num_heads=12,
616 | mlp_ratio=4.0,
617 | qkv_bias=True,
618 | ffn_bias=True,
619 | proj_bias=True,
620 | drop_path_rate=0.0,
621 | drop_path_uniform=False,
622 | init_values=None, # for layerscale: None or 0 => no layerscale
623 | embed_layer=PatchEmbed,
624 | act_layer=nn.GELU,
625 | block_fn=Block,
626 | ffn_layer="mlp",
627 | block_chunks=1,
628 | num_register_tokens=0,
629 | interpolate_antialias=False,
630 | interpolate_offset=0.1,
631 | ):
632 | """Defines the Vision Transformer model.
633 |
634 | Args:
635 | img_size (int, tuple): input image size
636 | patch_size (int, tuple): patch size
637 | in_chans (int): number of input channels
638 | embed_dim (int): embedding dimension
639 | depth (int): depth of transformer
640 | num_heads (int): number of attention heads
641 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim
642 | qkv_bias (bool): enable bias for qkv if True
643 | ffn_bias (bool): enable bias for ffn if True
644 | proj_bias (bool): enable bias for proj in attn if True
645 | drop_path_rate (float): stochastic depth rate
646 | drop_path_uniform (bool): apply uniform drop rate across blocks
647 | init_values (float): layer-scale init values
648 | embed_layer (nn.Module): patch embedding layer
649 | act_layer (nn.Module): MLP activation layer
650 | block_fn (nn.Module): transformer block class
651 | ffn_layer (str): "mlp", "swiglu" or "swiglufused"
652 | block_chunks: (int) split block sequence into block_chunks units for FSDP
653 | wrap
654 | num_register_tokens: (int) number of extra cls tokens (so-called
655 | "registers")
656 | interpolate_antialias: (str) flag to apply anti-aliasing when
657 | interpolating positional embeddings
658 | interpolate_offset: (float) work-around offset to apply when interpolating
659 | positional embeddings
660 | """
661 | super().__init__()
662 | norm_layer = functools.partial(nn.LayerNorm, eps=1e-6)
663 |
664 | self.num_features = self.embed_dim = (
665 | embed_dim # num_features for consistency with other models
666 | )
667 | self.num_tokens = 1
668 | self.n_blocks = depth
669 | self.num_heads = num_heads
670 | self.patch_size = patch_size
671 | self.num_register_tokens = num_register_tokens
672 | self.interpolate_antialias = interpolate_antialias
673 | self.interpolate_offset = interpolate_offset
674 |
675 | self.patch_embed = embed_layer(
676 | img_size=img_size,
677 | patch_size=patch_size,
678 | in_chans=in_chans,
679 | embed_dim=embed_dim,
680 | )
681 | num_patches = self.patch_embed.num_patches
682 |
683 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
684 | self.pos_embed = nn.Parameter(
685 | torch.zeros(1, num_patches + self.num_tokens, embed_dim)
686 | )
687 | assert num_register_tokens >= 0
688 | self.register_tokens = (
689 | nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
690 | if num_register_tokens
691 | else None
692 | )
693 |
694 | if drop_path_uniform:
695 | dpr = [drop_path_rate] * depth
696 | else:
697 | dpr = [
698 | x.item() for x in torch.linspace(0, drop_path_rate, depth)
699 | ] # stochastic depth decay rule
700 |
701 | if ffn_layer == "mlp":
702 | ffn_layer = Mlp
703 | elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
704 | ffn_layer = SwiGLUFFNFused
705 | else:
706 | raise NotImplementedError
707 |
708 | blocks_list = [
709 | block_fn(
710 | dim=embed_dim,
711 | num_heads=num_heads,
712 | mlp_ratio=mlp_ratio,
713 | qkv_bias=qkv_bias,
714 | proj_bias=proj_bias,
715 | ffn_bias=ffn_bias,
716 | drop_path=dpr[i],
717 | norm_layer=norm_layer,
718 | act_layer=act_layer,
719 | ffn_layer=ffn_layer,
720 | init_values=init_values,
721 | )
722 | for i in range(depth)
723 | ]
724 | if block_chunks > 0:
725 | self.chunked_blocks = True
726 | chunked_blocks = []
727 | chunksize = depth // block_chunks
728 | for i in range(0, depth, chunksize):
729 | # this is to keep the block index consistent if we chunk the block list
730 | chunked_blocks.append(
731 | [nn.Identity()] * i + blocks_list[i : i + chunksize]
732 | )
733 | self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
734 | else:
735 | self.chunked_blocks = False
736 | self.blocks = nn.ModuleList(blocks_list)
737 |
738 | self.norm = norm_layer(embed_dim)
739 | self.head = nn.Identity()
740 |
741 | self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
742 |
743 | self.init_weights()
744 |
745 | def init_weights(self):
746 | nn.init.trunc_normal_(self.pos_embed, std=0.02)
747 | nn.init.normal_(self.cls_token, std=1e-6)
748 | if self.register_tokens is not None:
749 | nn.init.normal_(self.register_tokens, std=1e-6)
750 | named_apply(init_weights_vit_timm, self)
751 |
752 | def interpolate_pos_encoding(self, x, w, h):
753 | previous_dtype = x.dtype
754 | npatch = x.shape[1] - 1
755 | num_patches = self.pos_embed.shape[1] - 1
756 | if npatch == num_patches and w == h:
757 | return self.pos_embed
758 | pos_embed = self.pos_embed.float()
759 | class_pos_embed = pos_embed[:, 0]
760 | patch_pos_embed = pos_embed[:, 1:]
761 | dim = x.shape[-1]
762 | w0 = w // self.patch_size
763 | h0 = h // self.patch_size
764 | num_patches_dim = int(
765 | math.sqrt(num_patches)
766 | ) # Recover the number of patches in each dimension
767 | assert num_patches == num_patches_dim * num_patches_dim
768 | kwargs = {}
769 | if self.interpolate_offset:
770 | sx = float(w0 + self.interpolate_offset) / num_patches_dim
771 | sy = float(h0 + self.interpolate_offset) / num_patches_dim
772 | kwargs["scale_factor"] = (sx, sy)
773 | else:
774 | # Simply specify an output size instead of a scale factor
775 | kwargs["size"] = (w0, h0)
776 | patch_pos_embed = nn.functional.interpolate(
777 | patch_pos_embed.reshape(
778 | 1, num_patches_dim, num_patches_dim, dim
779 | ).permute(0, 3, 1, 2),
780 | mode="bilinear",
781 | antialias=self.interpolate_antialias,
782 | **kwargs,
783 | )
784 | assert (w0, h0) == patch_pos_embed.shape[-2:]
785 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
786 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
787 | previous_dtype
788 | )
789 |
790 | def prepare_tokens_with_masks(self, x, masks=None):
791 | _, _, w, h = x.shape
792 | x = self.patch_embed(x)
793 | if masks is not None:
794 | x = torch.where(
795 | masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
796 | )
797 |
798 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
799 | x = x + self.interpolate_pos_encoding(x, w, h)
800 |
801 | if self.register_tokens is not None:
802 | x = torch.cat(
803 | (
804 | x[:, :1],
805 | self.register_tokens.expand(x.shape[0], -1, -1),
806 | x[:, 1:],
807 | ),
808 | dim=1,
809 | )
810 |
811 | return x
812 |
813 | def forward_features_list(self, x_list, masks_list):
814 | x = [
815 | self.prepare_tokens_with_masks(x, masks)
816 | for x, masks in zip(x_list, masks_list)
817 | ]
818 | for blk in self.blocks:
819 | x = blk(x)
820 |
821 | all_x = x
822 | output = []
823 | for x, masks in zip(all_x, masks_list):
824 | x_norm = self.norm(x)
825 | output.append({
826 | "x_norm_1st_clstoken": x_norm[:, :1],
827 | "x_norm_2nd_clstoken": x_norm[:, 1 : self.num_register_tokens + 1],
828 | "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
829 | "x_prenorm": x,
830 | "masks": masks,
831 | })
832 | return output
833 |
834 | def forward_features(self, x, masks=None):
835 | if isinstance(x, list):
836 | return self.forward_features_list(x, masks)
837 |
838 | x = self.prepare_tokens_with_masks(x, masks)
839 |
840 | for blk in self.blocks:
841 | x = blk(x)
842 |
843 | x_norm = self.norm(x)
844 | return {
845 | "x_norm_1st_clstoken": x_norm[:, :1],
846 | "x_norm_2nd_clstoken": x_norm[:, 1 : self.num_register_tokens + 1],
847 | "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
848 | "x_prenorm": x,
849 | "masks": masks,
850 | }
851 |
852 | def _get_intermediate_layers_not_chunked(self, x, n=1):
853 | x = self.prepare_tokens_with_masks(x)
854 | # If n is an int, take the n last blocks. If it's a list, take them
855 | output, total_block_len = [], len(self.blocks)
856 | blocks_to_take = (
857 | range(total_block_len - n, total_block_len) if isinstance(n, int) else n
858 | )
859 | for i, blk in enumerate(self.blocks):
860 | x = blk(x)
861 | if i in blocks_to_take:
862 | output.append(x)
863 | assert len(output) == len(
864 | blocks_to_take
865 | ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
866 | return output
867 |
868 | def _get_intermediate_layers_chunked(self, x, n=1):
869 | x = self.prepare_tokens_with_masks(x)
870 | output, i, total_block_len = [], 0, len(self.blocks[-1])
871 | # If n is an int, take the n last blocks. If it's a list, take them
872 | blocks_to_take = (
873 | range(total_block_len - n, total_block_len) if isinstance(n, int) else n
874 | )
875 | for block_chunk in self.blocks:
876 | for blk in block_chunk[i:]: # Passing the nn.Identity()
877 | x = blk(x)
878 | if i in blocks_to_take:
879 | output.append(x)
880 | i += 1
881 | assert len(output) == len(
882 | blocks_to_take
883 | ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
884 | return output
885 |
886 | def get_intermediate_layers(
887 | self,
888 | x: torch.Tensor,
889 | n: Union[int, Sequence] = 1, # Layers or n last layers to take # pylint: disable=g-bare-generic
890 | reshape: bool = False,
891 | return_class_token: bool = False,
892 | norm=True,
893 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: # pylint: disable=g-one-element-tuple
894 | if self.chunked_blocks:
895 | outputs = self._get_intermediate_layers_chunked(x, n)
896 | else:
897 | outputs = self._get_intermediate_layers_not_chunked(x, n)
898 | if norm:
899 | outputs = [self.norm(out) for out in outputs]
900 | class_tokens = [out[:, 0] for out in outputs]
901 | outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
902 | if reshape:
903 | batch_size, _, w, h = x.shape
904 | outputs = [
905 | out.reshape(
906 | batch_size, w // self.patch_size, h // self.patch_size, -1
907 | )
908 | .permute(0, 3, 1, 2)
909 | .contiguous()
910 | for out in outputs
911 | ]
912 | if return_class_token:
913 | return tuple(zip(outputs, class_tokens))
914 | return tuple(outputs)
915 |
916 | def forward(self, *args, is_training=False, **kwargs):
917 | ret = self.forward_features(*args, **kwargs)
918 | if is_training:
919 | return ret
920 | else:
921 | return self.head(ret["x_norm_1st_clstoken"]), self.head(
922 | ret["x_norm_2nd_clstoken"]
923 | ), ret["x_norm_patchtokens"]
924 |
925 |
926 | def init_weights_vit_timm(module: nn.Module, name: str = ""): # pylint: disable=unused-argument
927 | """ViT weight initialization, original timm impl (for reproducibility)."""
928 | if isinstance(module, nn.Linear):
929 | nn.init.trunc_normal_(module.weight, std=0.02)
930 | if module.bias is not None:
931 | nn.init.zeros_(module.bias)
932 |
933 |
934 | def vit_small(patch_size=14, **kwargs):
935 | model = VisionTransformer(
936 | patch_size=patch_size,
937 | embed_dim=384,
938 | depth=12,
939 | num_heads=6,
940 | mlp_ratio=4,
941 | block_fn=functools.partial(Block, attn_class=MemEffAttention),
942 | num_register_tokens=1,
943 | **kwargs,
944 | )
945 | return model
946 |
947 |
948 | def vit_base(patch_size=14, **kwargs):
949 | model = VisionTransformer(
950 | patch_size=patch_size,
951 | embed_dim=768,
952 | depth=12,
953 | num_heads=12,
954 | mlp_ratio=4,
955 | block_fn=functools.partial(Block, attn_class=MemEffAttention),
956 | num_register_tokens=1,
957 | **kwargs,
958 | )
959 | return model
960 |
961 |
962 | def vit_large(patch_size=14, **kwargs):
963 | model = VisionTransformer(
964 | patch_size=patch_size,
965 | embed_dim=1024,
966 | depth=24,
967 | num_heads=16,
968 | mlp_ratio=4,
969 | block_fn=functools.partial(Block, attn_class=MemEffAttention),
970 | num_register_tokens=1,
971 | **kwargs,
972 | )
973 | return model
974 |
975 |
976 | def vit_so400m(patch_size=14, **kwargs):
977 | """SoViT 400M model (https://arxiv.org/abs/2305.13035)."""
978 | model = VisionTransformer(
979 | patch_size=patch_size,
980 | embed_dim=1152,
981 | depth=27,
982 | num_heads=16,
983 | mlp_ratio=4304 / 1152,
984 | block_fn=functools.partial(Block, attn_class=MemEffAttention),
985 | num_register_tokens=1,
986 | **kwargs,
987 | )
988 | return model
989 |
990 |
991 | def vit_giant2(patch_size=14, **kwargs):
992 | model = VisionTransformer(
993 | patch_size=patch_size,
994 | embed_dim=1536,
995 | depth=40,
996 | num_heads=24,
997 | mlp_ratio=4,
998 | block_fn=functools.partial(Block, attn_class=MemEffAttention),
999 | num_register_tokens=1,
1000 | **kwargs,
1001 | )
1002 | return model
1003 |
--------------------------------------------------------------------------------
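Before moving to the inference scripts, a minimal sketch (not part of the repository) of the dense-feature path exposed by `get_intermediate_layers` above. It uses randomly initialized weights and constructor arguments mirroring `run_image_encoder_inference.py` below, so the printed values are only illustrative; the shape follows from the 14-pixel patch size and the ViT-L embedding width of 1024.

```python
import torch

from tips.pytorch import image_encoder

# Random weights; arguments mirror run_image_encoder_inference.py below.
model = image_encoder.vit_large(
    img_size=224,
    patch_size=14,
    ffn_layer='mlp',
    block_chunks=0,
    init_values=1.0,
    interpolate_antialias=True,
    interpolate_offset=0.0,
)
model.eval()

images = torch.zeros(1, 3, 224, 224)
with torch.no_grad():
  # n=1 takes the last block; reshape=True returns a (B, C, H/14, W/14) grid.
  (dense_features,) = model.get_intermediate_layers(images, n=1, reshape=True)
print(dense_features.shape)  # torch.Size([1, 1024, 16, 16])
```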
/pytorch/run_image_encoder_inference.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | r"""Running TIPS (https://arxiv.org/abs/2410.16512) ViT-g model inference.
17 |
18 | Usage:
19 | ```bash
20 | python run_image_encoder_inference.py --model_path=${PATH_TO_LOW_RES_CHECKPOINT} \
21 | --image_file=${PATH_TO_IMAGE} --is_low_res --model_variant=g
22 | ```
23 | """
24 |
25 | import argparse
26 | import io
27 |
28 | import numpy as np
29 | from PIL import Image
30 | import torch
31 | from torchvision import transforms
32 |
33 | from tips.pytorch import image_encoder
34 |
35 | IMAGE_MEAN = (0, 0, 0)
36 | IMAGE_STD = (1.0, 1.0, 1.0)
37 | PATCH_SIZE = 14
38 |
39 | parser = argparse.ArgumentParser()
40 | parser.add_argument(
41 | '--model_path', default=None, required=True, help='The path to the model.'
42 | )
43 | parser.add_argument(
44 | '--image_file',
45 | default=None,
46 | required=True,
47 | help='The path to the image file for inference.',
48 | )
49 | parser.add_argument(
50 | '--is_low_res',
51 | action='store_true',
52 | help='Whether the model is low-resolution.',
53 | )
54 | parser.add_argument(
55 | '--model_variant',
56 | default=None,
57 | required=True,
58 | help='The variant of the model.',
59 | )
60 |
61 |
62 | def main(args):
63 |
64 | image_size = 224 if args.is_low_res else 448
65 | model_def = {
66 | 'S': image_encoder.vit_small,
67 | 'B': image_encoder.vit_base,
68 | 'L': image_encoder.vit_large,
69 | 'So400m': image_encoder.vit_so400m,
70 | 'g': image_encoder.vit_giant2,
71 | }[args.model_variant]
72 |
73 | ffn_layer = 'swiglu' if args.model_variant == 'g' else 'mlp'
74 |
75 | # Load checkpoint.
76 | checkpoint = dict(np.load(args.model_path, allow_pickle=False))
77 | for key in checkpoint:
78 | checkpoint[key] = torch.tensor(checkpoint[key])
79 |
80 | # Run inference on the image.
81 | with open(args.image_file, 'rb') as fd:
82 | image_bytes = io.BytesIO(fd.read())
83 | pil_image = Image.open(image_bytes).convert('RGB')
84 | transform = transforms.Compose([
85 | transforms.Resize((image_size, image_size)),
86 | transforms.ToTensor(),
87 | transforms.Normalize(IMAGE_MEAN, IMAGE_STD),
88 | ])
89 | input_tensor = transform(pil_image)
90 | input_batch = input_tensor.unsqueeze(0)
91 |
92 | with torch.no_grad():
93 | model = model_def(
94 | img_size=image_size,
95 | patch_size=PATCH_SIZE,
96 | ffn_layer=ffn_layer,
97 | block_chunks=0,
98 | init_values=1.0,
99 | interpolate_antialias=True,
100 | interpolate_offset=0.0,
101 | )
102 | model.load_state_dict(checkpoint)
103 |
104 | # Compute embeddings from two CLS tokens.
105 | outputs = model(input_batch)
106 | first_cls_token = outputs[0].detach().numpy().squeeze()
107 | second_cls_token = outputs[1].detach().numpy().squeeze()
108 |
109 | first_cls_token = first_cls_token / np.linalg.norm(
110 | first_cls_token, ord=2, axis=-1, keepdims=True
111 | ).clip(min=1e-3)
112 | second_cls_token = second_cls_token / np.linalg.norm(
113 | second_cls_token, ord=2, axis=-1, keepdims=True
114 | ).clip(min=1e-3)
115 | print('First cls token: ', first_cls_token.tolist())
116 | print('Second cls token: ', second_cls_token.tolist())
117 |
118 |
119 | if __name__ == '__main__':
120 | main(parser.parse_args())
121 |
--------------------------------------------------------------------------------
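A side note on preprocessing in the script above: with `IMAGE_MEAN = (0, 0, 0)` and `IMAGE_STD = (1.0, 1.0, 1.0)`, the `Normalize` step is an identity, so the model consumes RGB values in `[0, 1]` exactly as produced by `ToTensor`; the same defaults appear as `_MEAN_RGB`/`_STDDEV_RGB` in `scenic/configs/tips_model_config.py`. A quick standalone check (not part of the script):

```python
import torch
from torchvision import transforms

x = torch.rand(3, 4, 4)  # stand-in for a ToTensor output in [0, 1]
identity_normalize = transforms.Normalize((0, 0, 0), (1.0, 1.0, 1.0))
assert torch.allclose(identity_normalize(x), x)
```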
/pytorch/run_text_encoder_inference.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | r"""Running TIPS (https://arxiv.org/abs/2410.16512) text encoder inference.
17 |
18 | Usage:
19 | ```bash
20 | python run_text_encoder_inference.py --model_path=${PATH_TO_LOW_RES_CHECKPOINT} \
21 | --model_variant=g --tokenizer_path=${PATH_TO_TOKENIZER} \
22 | --text_input="Hello world."
23 | ```
24 | """
25 |
26 | import argparse
27 | import io
28 | import numpy as np
29 | import torch
30 | from tips.pytorch import text_encoder
31 |
32 | MAX_LEN = 64
33 | VOCAB_SIZE = 32000
34 |
35 | parser = argparse.ArgumentParser()
36 | parser.add_argument(
37 | '--model_path', default=None, required=True, help='The path to the model.'
38 | )
39 | parser.add_argument(
40 | '--model_variant',
41 | default=None,
42 | required=True,
43 | help='The variant of the model.',
44 | )
45 | parser.add_argument(
46 | '--tokenizer_path',
47 | default=None,
48 | required=True,
49 | help='The path to the tokenizer.',
50 | )
51 | parser.add_argument(
52 | '--text_input',
53 | default=None,
54 | required=True,
55 | help='The text input to the model.',
56 | )
57 |
58 |
59 | def get_config(v: str):
60 | return {
61 | 'hidden_size': {'S': 384, 'B': 768, 'L': 1024, 'So400m': 1152, 'g': 1536}[
62 | v
63 | ],
64 | 'mlp_dim': {'S': 1536, 'B': 3072, 'L': 4096, 'So400m': 4304, 'g': 6144}[
65 | v
66 | ],
67 | 'num_heads': {'S': 6, 'B': 12, 'L': 16, 'So400m': 16, 'g': 24}[v],
68 | 'num_layers': {'S': 12, 'B': 12, 'L': 12, 'So400m': 27, 'g': 12}[v],
69 | }
70 |
71 |
72 | def main(args):
73 |
74 | with open(args.model_path, 'rb') as fin:
75 | inbuffer = io.BytesIO(fin.read())
76 | np_weights_text = np.load(inbuffer, allow_pickle=False)
77 |
78 | pytorch_weights_text = {}
79 | for key, value in np_weights_text.items():
80 | pytorch_weights_text[key] = torch.from_numpy(value)
81 | pytorch_weights_text.pop('temperature')
82 |
83 | with torch.no_grad():
84 | # Define the text model.
85 | model_text = text_encoder.TextEncoder(
86 | get_config(args.model_variant),
87 | vocab_size=VOCAB_SIZE,
88 | )
89 | model_text.load_state_dict(pytorch_weights_text)
90 |
91 | tokenizer_obj = text_encoder.Tokenizer(tokenizer_path=args.tokenizer_path)
92 | text_ids, text_paddings = tokenizer_obj.tokenize(
93 | [args.text_input], max_len=MAX_LEN
94 | )
95 | text_embedding = (
96 | model_text(torch.from_numpy(text_ids), torch.from_numpy(text_paddings))
97 | .detach()
98 | .numpy()
99 | .squeeze()
100 | )
101 | text_embedding = text_embedding / np.linalg.norm(
102 | text_embedding, ord=2, axis=-1, keepdims=True
103 | ).clip(min=1e-3)
104 | print(text_embedding.tolist())
105 |
106 |
107 | if __name__ == '__main__':
108 | main(parser.parse_args())
109 |
--------------------------------------------------------------------------------
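The two scripts above each print an L2-normalized embedding but stop short of comparing them. The sketch below (not part of the repository) shows the scoring step: for unit-norm vectors, the dot product is the cosine similarity. The 4-d vectors are hypothetical stand-ins for the real embeddings (1536-d for the 'g' variant).

```python
import numpy as np


def cosine_similarity(image_embedding, text_embedding):
  # Both inputs are assumed L2-normalized, as in the scripts above, so the
  # dot product equals the cosine similarity.
  return float(np.dot(np.asarray(image_embedding), np.asarray(text_embedding)))


image_embedding = np.array([0.5, 0.5, 0.5, 0.5])  # hypothetical, unit norm
text_embedding = np.array([0.5, -0.5, 0.5, 0.5])  # hypothetical, unit norm
print(cosine_similarity(image_embedding, text_embedding))  # 0.5
```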
/pytorch/text_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Text encoder implementation in PyTorch."""
17 |
18 | import typing as t
19 |
20 | import tensorflow as tf
21 | import tensorflow_text
22 | import torch
23 | from torch import nn
24 | import torch.nn.functional as F
25 |
26 |
27 | class Tokenizer(object):
28 | """A simple tokenizer."""
29 |
30 | def __init__(self, tokenizer_path: str):
31 | """Initializes the tokenizer."""
32 | with open(tokenizer_path, 'rb') as f:
33 | model = f.read()
34 | self.tokenizer = tensorflow_text.SentencepieceTokenizer(
35 | model=model, add_eos=False, add_bos=False
36 | )
37 |
38 | def tokenize(self, input_text, max_len=64):
39 | tokens = self.tokenizer.tokenize(tf.strings.lower(input_text)).to_tensor()
40 | curr_len = tokens.shape[1]
41 | is_padding = tf.zeros((tokens.shape[0], max_len))
42 | if curr_len > max_len:
43 | tokens = tokens[:, :max_len]
44 | else:
45 | padding_len = max_len - curr_len
46 | tokens = tf.pad(tokens, [[0, 0], [0, padding_len]], constant_values=0)
47 | is_padding = tf.cast(tokens == 0, tf.int32)
48 | return tokens.numpy(), is_padding.numpy()
49 |
50 |
51 | class PositionalEmbedding(nn.Module):
52 | """Generates position embedding for a given 1-d sequence.
53 |
54 | Attributes:
55 | min_timescale: Start of the geometric index. Determines the periodicity of
56 | the added signal.
57 | max_timescale: End of the geometric index. Determines the frequency of the
58 | added signal.
59 | embedding_dim: Dimension of the embedding to be generated.
60 | """
61 |
62 | min_timescale: int = 1
63 | max_timescale: int = 10_000
64 | embedding_dim: int = 0
65 |
66 | def __init__(self, embedding_dim: int):
67 | super().__init__()
68 | self.embedding_dim = embedding_dim
69 |
70 | def __call__(self, seq_length: t.Optional[int] = None, position: t.Optional[torch.Tensor] = None):
71 | """Generates a torch.tensor of sinusoids with different frequencies.
72 |
73 | Args:
74 | seq_length: an optional Python int defining the output sequence length.
75 | Not required if the `position` argument is specified.
76 | position: [B, seq_length], optional position for each token in the
77 | sequence, only required when the sequence is packed.
78 |
79 | Returns:
80 | [B, seqlen, D] if `position` is specified, else [1, seqlen, D]
81 | """
82 | if position is None:
83 | assert seq_length is not None
84 | # [1, seqlen]
85 | position = torch.arange(seq_length, dtype=torch.float32)[None, :]
86 | else:
87 | assert position.ndim == 2, position.shape
88 |
89 | num_timescales = self.embedding_dim // 2
90 | log_timescale_increment = torch.log(
91 | torch.tensor(float(self.max_timescale) / float(self.min_timescale))
92 | ) / torch.maximum(
93 | torch.tensor(num_timescales, dtype=torch.float32) - 1, torch.tensor(1)
94 | )
95 | inv_timescales = self.min_timescale * torch.exp(
96 | torch.arange(num_timescales, dtype=torch.float32)
97 | * -log_timescale_increment
98 | )
99 | scaled_time = position[:, :, None] * inv_timescales[None, None, :]
100 | signal = torch.cat((torch.sin(scaled_time), torch.cos(scaled_time)), dim=2)
101 | # Pad the last dimension with one extra zero column when embedding_dim
102 | # is odd, so the output width always equals embedding_dim.
103 | signal = F.pad(signal, (0, self.embedding_dim % 2, 0, 0, 0, 0))
104 | return signal
105 |
106 |
107 | class MlpBlockWithMask(nn.Module):
108 | """Transformer MLP / feed-forward block that supports masking."""
109 |
110 | def __init__(
111 | self,
112 | mlp_dim: int,
113 | d_model: int,
114 | use_bias: bool = True,
115 | dtype: torch.dtype = torch.float32,
116 | activation_fn: nn.Module = nn.GELU,
117 | ):
118 | super().__init__()
119 |
120 | self.mlp_dim = mlp_dim
121 | self.d_model = d_model
122 | self.use_bias = use_bias
123 | self.dtype = dtype
124 | self.activation_fn = activation_fn
125 |
126 | self.c_fc = nn.Linear(
127 | in_features=self.d_model,
128 | out_features=self.mlp_dim,
129 | dtype=self.dtype,
130 | bias=self.use_bias,
131 | )
132 | self.c_proj = nn.Linear(
133 | in_features=self.mlp_dim,
134 | out_features=self.d_model,
135 | dtype=self.dtype,
136 | bias=self.use_bias,
137 | )
138 |
139 | def __call__(
140 | self, inputs: torch.Tensor, mlp_mask: torch.Tensor
141 | ) -> torch.Tensor:
142 | """Applies Transformer MlpBlock with mask module."""
143 | x = self.c_fc(inputs)
144 | x = self.activation_fn()(x)
145 | x = x * mlp_mask[..., None] # First masking.
146 | x = self.c_proj(x)
147 | x = x * mlp_mask[..., None] # Second masking.
148 | return x
149 |
150 |
151 | class ResidualAttentionBlock(nn.Module):
152 | """Transformer residual attention block."""
153 |
154 | def __init__(
155 | self,
156 | d_model: int,
157 | n_head: int,
158 | mlp_dim: int,
159 | dtype: torch.dtype = torch.float32,
160 | ):
161 | super().__init__()
162 | self.d_model = d_model
163 | self.n_head = n_head
164 | self.mlp_dim = mlp_dim
165 | self.dtype = dtype
166 |
167 | self.attn = nn.MultiheadAttention(d_model, n_head, dtype=self.dtype)
168 | self.ln_1 = nn.LayerNorm(d_model, dtype=self.dtype)
169 | self.mlp = MlpBlockWithMask(
170 | self.mlp_dim,
171 | d_model,
172 | use_bias=True,
173 | dtype=self.dtype,
174 | activation_fn=nn.ReLU,
175 | )
176 | self.ln_2 = nn.LayerNorm(d_model, dtype=self.dtype)
177 |
178 | def attention(self, x: torch.Tensor, mask: torch.Tensor):
179 | attn_mask = (
180 | mask[:, None, None, :]
181 | .repeat(1, self.n_head, x.shape[0], 1)
182 | .flatten(0, 1)
183 | )
184 | attn_mask[attn_mask == 0] = float('-inf')
185 | attn_mask[attn_mask == 1] = 0
186 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
187 |
188 | def forward(self, x: torch.Tensor, mask: torch.Tensor):
189 | x = x + self.attention(self.ln_1(x), mask.permute(1, 0))
190 | x = x + self.mlp(self.ln_2(x), mask)
191 | return x, mask
192 |
193 |
194 | class SequentialMultiInput(nn.Sequential):
195 | """Sequential module that can take multiple inputs."""
196 |
197 | def forward(self, *inputs):
198 | for module in self._modules.values():
199 | if isinstance(inputs, tuple):
200 | inputs = module(*inputs)
201 | else:
202 | inputs = module(inputs)
203 | return inputs
204 |
205 |
206 | class Transformer(nn.Module):
207 | """Transformer implementation."""
208 |
209 | def __init__(
210 | self,
211 | width: int,
212 | layers: int,
213 | heads: int,
214 | mlp_dim: int,
215 | dtype: torch.dtype = torch.float32,
216 | ):
217 | super().__init__()
218 | self.width = width
219 | self.layers = layers
220 | self.heads = heads
221 | self.mlp_dim = mlp_dim
222 | self.dtype = dtype
223 |
224 | self.resblocks = SequentialMultiInput(*[
225 | ResidualAttentionBlock(self.width, self.heads, self.mlp_dim, self.dtype)
226 | for _ in range(self.layers)
227 | ])
228 |
229 | def forward(self, x: torch.Tensor, mask: torch.Tensor):
230 | return self.resblocks(x, mask)[0]
231 |
232 |
233 | class GlobalAvgPooling(nn.Module):
234 | """Performs a simple global pooling over the input with optional paddings.
235 |
236 | Attributes:
237 | pooling_dims: A list of dims to perform pooling over.
238 | epsilon: Small constant added to the valid-token count to avoid division by zero.
239 | """
240 |
241 | pooling_dims: t.Sequence[int]
242 | epsilon: float = 1e-8
243 |
244 | def __init__(
245 | self, pooling_dims: t.Sequence[int], epsilon: float = 1e-8
246 | ):
247 | super().__init__()
248 | self.pooling_dims = pooling_dims
249 | self.epsilon = epsilon
250 |
251 | if not all([p_dims >= 0 for p_dims in self.pooling_dims]):
252 | raise ValueError('pooling_dims must be non-negative integers.')
253 |
254 | def __call__(
255 | self,
256 | inputs: torch.Tensor,
257 | compatible_paddings: torch.Tensor,
258 | ):
259 | """Applies global average spatial pooling to inputs.
260 |
261 | Args:
262 | inputs: An input tensor.
263 | compatible_paddings: paddings of inputs with shapes compatible with
264 | inputs, e.g. compatible_paddings with shape [B, 1] for inputs with shape
265 | [B, D].
266 |
267 | Returns:
268 | Output tensor with global pooling applied.
269 | """
270 | # Zero out the padded positions so they do not contribute to the sum.
271 | padded_value = torch.zeros_like(inputs)
272 | inputs = torch.where(compatible_paddings > 0, padded_value, inputs)
273 | valid_inputs = (
274 | torch.sum(
275 | 1.0 - compatible_paddings,
276 | self.pooling_dims,
277 | keepdims=True,
278 | dtype=inputs.dtype,
279 | )
280 | + self.epsilon
281 | )
282 | inputs_sum = torch.sum(inputs, self.pooling_dims, keepdims=True)
283 | outputs = torch.divide(inputs_sum, valid_inputs).type(inputs.dtype)
284 | outputs = torch.squeeze(outputs, axis=self.pooling_dims)
285 | return outputs
286 |
287 |
288 | class TextEncoder(nn.Module):
289 | """Text encoder implementation."""
290 |
291 | def __init__(
292 | self,
293 | config: t.Dict[str, int],
294 | vocab_size: int,
295 | dtype: torch.dtype = torch.float32,
296 | scale_sqrt_depth: bool = True,
297 | ):
298 | super().__init__()
299 | self.vocab_size = vocab_size
300 | self.dtype = dtype
301 | self.scale_sqrt_depth = scale_sqrt_depth
302 |
303 | # The text tower layers are fixed independent of vision tower size.
304 | self.transformer_layers = config['num_layers']
305 | self.embedding_dim = config['hidden_size']
306 | self.transformer_width = config['hidden_size']
307 | self.mlp_dim = config['mlp_dim']
308 | self.transformer_heads = config['num_heads']
309 |
310 | self.token_embedding = nn.Embedding(
311 | self.vocab_size, self.embedding_dim, dtype=self.dtype
312 | )
313 | self.pos_embedder = PositionalEmbedding(embedding_dim=self.embedding_dim)
314 | self.transformer = Transformer(
315 | width=self.transformer_width,
316 | layers=self.transformer_layers,
317 | heads=self.transformer_heads,
318 | mlp_dim=self.mlp_dim,
319 | dtype=self.dtype,
320 | )
321 | self.pooling = GlobalAvgPooling(pooling_dims=[1])
322 | self.ln_final = nn.LayerNorm(self.transformer_width, dtype=self.dtype)
323 |
324 | def __call__(
325 | self,
326 | ids: torch.Tensor,
327 | paddings: torch.Tensor,
328 | ):
329 | """Applies TextEncoder module."""
330 | _, seq_length = ids.shape
331 | mask = (paddings == 0).type(torch.float32)
332 | mask = mask.permute(1, 0) # NL -> LN
333 | x = self.token_embedding(ids)
334 | if self.scale_sqrt_depth:
335 | x = x * (self.embedding_dim**0.5)
336 | x = x + self.pos_embedder(seq_length=seq_length)
337 | x = x.permute(1, 0, 2) # NLD -> LND
338 | x = self.transformer(x, mask)
339 | x = x.permute(1, 0, 2) # LND -> NLD
340 | x = self.ln_final(x)
341 | x = self.pooling(x, compatible_paddings=paddings[:, :, None])
342 | return x
343 |
--------------------------------------------------------------------------------
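A small self-contained sketch (not part of the file above) of the masking convention used by `TextEncoder`: `paddings` is 1 at padded positions, the derived `mask` is 1 at valid tokens, and `ResidualAttentionBlock.attention` expands it into an additive bias that is 0 for valid keys and `-inf` for padded ones. The head count and sequence length below are arbitrary.

```python
import torch

n_head = 2
paddings = torch.tensor([[0, 0, 1, 1]], dtype=torch.float32)  # [B=1, L=4], 1 = pad
mask = (paddings == 0).float()  # 1 = real token
seq_len = mask.shape[1]

# Mirrors ResidualAttentionBlock.attention: one [L, L] bias per head.
attn_mask = mask[:, None, None, :].repeat(1, n_head, seq_len, 1).flatten(0, 1)
attn_mask[attn_mask == 0] = float('-inf')
attn_mask[attn_mask == 1] = 0.0

print(attn_mask.shape)  # torch.Size([2, 4, 4]) == (B * n_head, L, L)
print(attn_mask[0])     # the two padded key columns are -inf
```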
/scenic/checkpoints/download_checkpoints.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2025 Google LLC
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 |
18 | # The model weights can be found in https://console.cloud.google.com/storage/browser/tips_data
19 | ALL_CHECKPOINTS=(
20 | "tips_oss_s14_highres_distilled"
21 | "tips_oss_b14_highres_distilled"
22 | "tips_oss_l14_highres_distilled"
23 | "tips_oss_so400m14_highres_largetext_distilled"
24 | "tips_oss_g14_lowres"
25 | "tips_oss_g14_highres"
26 | )
27 |
28 | echo "Downloading the tokenizer."
29 | wget https://storage.googleapis.com/tips_data/v1_0/checkpoints/tokenizer.model
30 |
31 | for CHECKPOINT in "${ALL_CHECKPOINTS[@]}"; do
32 | echo "Downloading ${CHECKPOINT} (vision encoder weights)"
33 | wget https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/${CHECKPOINT}_vision.npz
34 | echo "Downloading ${CHECKPOINT} (text encoder weights)"
35 | wget https://storage.googleapis.com/tips_data/v1_0/checkpoints/scenic/${CHECKPOINT}_text.npz
36 | done
37 |
--------------------------------------------------------------------------------
/scenic/configs/tips_model_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """TIPS model config."""
17 |
18 | import ml_collections
19 |
20 | _MEAN_RGB = [0., 0., 0.]
21 | _STDDEV_RGB = [1., 1., 1.]
22 |
23 | # The 'g' variant refers to the DINO-v2 'giant2', which differs from ViT-g.
24 | # The differences are highlighted in https://arxiv.org/pdf/2304.07193 Section 5.
25 | _VARIANT_DICT = {
26 | 'tips_oss_g14_highres': 'g/14',
27 | 'tips_oss_g14_lowres': 'g/14',
28 | 'tips_oss_so400m14_highres_largetext_distilled': 'So400m/14',
29 | 'tips_oss_l14_highres_distilled': 'L/14',
30 | 'tips_oss_b14_highres_distilled': 'B/14',
31 | 'tips_oss_s14_highres_distilled': 'S/14',
32 | }
33 |
34 |
35 | def get_config(variant: str):
36 | """Returns the TIPS model config."""
37 | config = ml_collections.ConfigDict()
38 | if variant not in _VARIANT_DICT:
39 | raise ValueError(
40 | f'Unknown TIPS variant: {variant}. Please choose one of: '
41 | f'{list(_VARIANT_DICT.keys())}')
42 |
43 | config.variant = _VARIANT_DICT[variant]
44 | config.rgb_mean = _MEAN_RGB
45 | config.rgb_std = _STDDEV_RGB
46 |
47 | config.pooling = 'tok'
48 | config.pos_interpolation_method = 'bilinear'
49 |
50 | # TIPS defaults to 2 CLS tokens.
51 | config.num_cls_tokens = 2
52 |
53 | return config
54 |
--------------------------------------------------------------------------------
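Illustrative use of `get_config` above; the import path assumes the `tips.scenic` package layout used by the other modules in this repository.

```python
from tips.scenic.configs import tips_model_config

config = tips_model_config.get_config('tips_oss_b14_highres_distilled')
print(config.variant)         # 'B/14'
print(config.num_cls_tokens)  # 2
print(config.pooling)         # 'tok'
# An unknown name raises a ValueError listing the supported checkpoints:
# tips_model_config.get_config('tips_oss_unknown')
```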
/scenic/images/example_image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/tips/c1be95c4cba9345f1d207e652c6e285ed7c2ec04/scenic/images/example_image.jpg
--------------------------------------------------------------------------------
/scenic/images/example_image_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/tips/c1be95c4cba9345f1d207e652c6e285ed7c2ec04/scenic/images/example_image_2.jpg
--------------------------------------------------------------------------------
/scenic/models/text.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Text-encoder related modules."""
17 |
18 | import math
19 | import typing as t
20 |
21 | import flax.linen as nn
22 | import jax
23 | import jax.numpy as jnp
24 | import numpy as np
25 | from scenic.model_lib.layers import nn_layers
26 | import tensorflow as tf
27 | import tensorflow_text
28 |
29 |
30 | Initializer = t.Callable[[jnp.ndarray, t.Sequence[int], jnp.dtype], jnp.ndarray]
31 |
32 |
33 | class Tokenizer(object):
34 | """A simple tokenizer."""
35 |
36 | def __init__(self, tokenizer_path: str):
37 | """Initializes the tokenizer."""
38 | with open(tokenizer_path, 'rb') as f:
39 | model = f.read()
40 | self.tokenizer = tensorflow_text.SentencepieceTokenizer(
41 | model=model, add_eos=False, add_bos=False)
42 |
43 | def tokenize(self, input_text, max_len=64):
44 | tokens = self.tokenizer.tokenize(tf.strings.lower(input_text)).to_tensor()
45 | curr_len = tokens.shape[1]
46 | is_padding = tf.zeros((tokens.shape[0], max_len))
47 | if curr_len > max_len:
48 | tokens = tokens[:, :max_len]
49 | else:
50 | padding_len = max_len - curr_len
51 | tokens = tf.pad(tokens, [[0, 0], [0, padding_len]], constant_values=0)
52 | is_padding = tf.cast(tokens == 0, tf.int32)
53 | return tokens.numpy(), is_padding.numpy()
54 |
55 |
56 | class Embedding(nn.Module):
57 | """A simple embedding layer that performs embedding lookups from ids.
58 |
59 | Simple version of
60 | https://github.com/google/praxis/blob/main/praxis/layers/embedding_softmax.py#L97
61 |
62 | Attributes:
63 | num_classes: Number of tokens in the vocabulary.
64 | embedding_dim: Depth of the embedding output.
65 | scale_sqrt_depth: If set to True, activations are scaled with
66 | sqrt(embedding_dim) in emb_lookup.
67 | """
68 |
69 | num_classes: int = 0
70 | embedding_dim: int = 0
71 | scale_sqrt_depth: bool = True
72 |
73 | def setup(self) -> None:
74 | assert self.num_classes > 0
75 | assert self.embedding_dim > 0
76 |
77 | self.emb_var = self.param(
78 | 'emb_var',
79 | nn.initializers.variance_scaling(1.0, 'fan_in', 'normal', out_axis=0),
80 | (self.num_classes, self.embedding_dim),
81 | jnp.float32)
82 |
83 | def emb_lookup(self, ids: jnp.ndarray) -> jnp.ndarray:
84 | embs = self.emb_var[ids]
85 |
86 | if self.scale_sqrt_depth:
87 | embs *= self.embedding_dim**0.5
88 |
89 | return embs
90 |
91 | def __call__(self, ids: jnp.ndarray) -> jnp.ndarray:
92 | return self.emb_lookup(ids)
93 |
94 |
95 | class PositionalEmbedding(nn.Module):
96 | """Generates fixed position embedding for a given 1-d sequence.
97 |
98 | Simplified version of
99 | https://github.com/google/praxis/blob/main/praxis/layers/embedding_softmax.py#L1011
100 |
101 | Attributes:
102 | min_timescale: Start of the geometric index. Determines the periodicity of
103 | the added signal.
104 | max_timescale: End of the geometric index. Determines the frequency of the
105 | added signal.
106 | embedding_dim: Dimension of the embedding to be generated.
107 | """
108 |
109 | min_timescale: int = 1
110 | max_timescale: int = 10_000
111 | embedding_dim: int = 0
112 |
113 | def __call__(
114 | self, seq_length: int | None = None, position: jnp.ndarray | None = None
115 | ) -> jnp.ndarray:
116 | """Generates a jnp.ndarray of sinusoids with different frequencies.
117 |
118 | Args:
119 | seq_length: an optional Python int defining the output sequence length.
120 | Not required if the `position` argument is specified.
121 | position: [B, seq_length], optional position for each token in the
122 | sequence, only required when the sequence is packed.
123 |
124 | Returns:
125 | [B, seqlen, D] if `position` is specified, else [1, seqlen, D]
126 | """
127 | if position is None:
128 | assert seq_length is not None
129 | # [1, seqlen]
130 | position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :]
131 | else:
132 | assert position.ndim == 2, position.shape
133 |
134 | num_timescales = self.embedding_dim // 2
135 | log_timescale_increment = math.log(
136 | float(self.max_timescale) / float(self.min_timescale)
137 | ) / jnp.maximum(jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1)
138 | inv_timescales = self.min_timescale * jnp.exp(
139 | jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment
140 | )
141 | scaled_time = (
142 | position[:, :, jnp.newaxis]
143 | * inv_timescales[jnp.newaxis, jnp.newaxis, :]
144 | )
145 | signal = jnp.concatenate(
146 | [jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=2)
147 | # Force usage of `np` rather than `jnp` to compute static values at trace
148 | # time.
149 | signal = jnp.pad(
150 | signal, [[0, 0], [0, 0], [0, np.mod(self.embedding_dim, 2)]]
151 | )
152 | return signal
153 |
154 |
155 | class GlobalAvgPooling(nn.Module):
156 | """Performs a simple global pooling over the input with optional paddings.
157 |
158 | Attributes:
159 | pooling_dims: A list of dims to perform pooling over.
160 | epsilon: Small constant added to the valid-token count to avoid division by zero.
161 | """
162 | pooling_dims: t.Sequence[int] | None = None
163 | epsilon: float = 1e-8
164 |
165 | def setup(self) -> None:
166 | if self.pooling_dims is None:
167 | raise ValueError('pooling_dims must be set as a list.')
168 | else:
169 | if not all([p_dims >= 0 for p_dims in self.pooling_dims]):
170 | raise ValueError('pooling_dims must be non-negative integers.')
171 |
172 | def __call__(
173 | self,
174 | inputs: jnp.ndarray,
175 | compatible_paddings: jnp.ndarray,
176 | ) -> jnp.ndarray:
177 | """Applies global average spatial pooling to inputs.
178 |
179 | Args:
180 | inputs: An input tensor.
181 | compatible_paddings: paddings of inputs with shapes compatible
182 | with inputs, e.g. compatible_paddings with shape [B, 1] for inputs with
183 | shape [B, D].
184 |
185 | Returns:
186 | Output tensor with global pooling applied.
187 | """
188 | padded_value = jnp.zeros(shape=(), dtype=inputs.dtype)
189 | padded_value = jnp.ones_like(inputs) * padded_value
190 | inputs = jnp.where(compatible_paddings > 0, padded_value, inputs)
191 | valid_inputs = (
192 | jnp.sum(
193 | 1.0 - compatible_paddings,
194 | self.pooling_dims,
195 | keepdims=True,
196 | dtype=inputs.dtype)
197 | + self.epsilon)
198 | inputs_sum = jnp.sum(inputs, self.pooling_dims, keepdims=True)
199 | outputs = jnp.divide(inputs_sum, valid_inputs).astype(inputs.dtype)
200 | outputs = jnp.squeeze(outputs, axis=self.pooling_dims)
201 | return outputs
202 |
203 |
204 | class MlpBlockWithMask(nn.Module):
205 | """Transformer MLP / feed-forward block that supports masking."""
206 |
207 | mlp_dim: int
208 | out_dim: t.Optional[int] = None
209 | dropout_rate: float = 0.1
210 | use_bias: bool = True
211 | kernel_init: Initializer = nn.initializers.xavier_uniform()
212 | bias_init: Initializer = nn.initializers.normal(stddev=1e-6)
213 | activation_fn: t.Callable[[jnp.ndarray], jnp.ndarray] = nn.gelu
214 | precision: t.Optional[jax.lax.Precision] = None
215 | dtype: jnp.ndarray = jnp.float32
216 |
217 | @nn.compact
218 | def __call__(self, inputs: jnp.ndarray, *, mask, deterministic: bool):
219 | """Applies Transformer MlpBlock with mask module."""
220 | actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
221 | x = nn.Dense(
222 | self.mlp_dim,
223 | dtype=self.dtype,
224 | use_bias=self.use_bias,
225 | kernel_init=self.kernel_init,
226 | bias_init=self.bias_init,
227 | precision=self.precision)(
228 | inputs)
229 | x = nn_layers.IdentityLayer(name='mlp1')(self.activation_fn(x))
230 | x = x * mask[..., None] # First masking.
231 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
232 | output = nn.Dense(
233 | actual_out_dim,
234 | dtype=self.dtype,
235 | use_bias=self.use_bias,
236 | kernel_init=self.kernel_init,
237 | bias_init=self.bias_init,
238 | precision=self.precision)(x)
239 | output = output * mask[..., None] # Second masking.
240 | output = nn_layers.IdentityLayer(name='mlp2')(output)
241 | output = nn.Dropout(rate=self.dropout_rate)(
242 | output, deterministic=deterministic)
243 | return output
244 |
245 |
246 | class TextEncoder1DBlock(nn.Module):
247 | """Transformer text encoder layer.
248 |
249 | Attributes:
250 | mlp_dim: Dimension of the mlp on top of attention block.
251 | num_heads: Number of self-attention heads.
252 | dtype: The dtype of the computation (default: float32).
253 | dropout_rate: Dropout rate.
254 | attention_dropout_rate: Dropout for attention heads.
255 | stochastic_depth: probability of dropping a layer; the drop probability
256 | grows linearly from 0 to the provided value through the stack of
257 | encoder blocks.
258 |
259 | Returns:
260 | output after transformer encoder block.
261 | """
262 | mlp_dim: int
263 | num_heads: int
264 | dtype: t.Any = jnp.float32
265 | dropout_rate: float = 0.1
266 | attention_dropout_rate: float = 0.1
267 | stochastic_depth: float = 0.0
268 |
269 | @nn.compact
270 | def __call__(
271 | self, inputs: jnp.ndarray, mask: jnp.ndarray, deterministic: bool
272 | ) -> jnp.ndarray:
273 | """Applies Encoder1DBlock module.
274 |
275 | Args:
276 | inputs: Input data.
277 | mask: Input mask.
278 | deterministic: Deterministic or not (to apply dropout).
279 |
280 | Returns:
281 | Output after transformer encoder block.
282 | """
283 | # Attention block.
284 | assert inputs.ndim == 3
285 | x = nn.LayerNorm(name='LayerNorm_0', dtype=self.dtype)(inputs)
286 | x = nn.MultiHeadDotProductAttention(
287 | num_heads=self.num_heads,
288 | dtype=self.dtype,
289 | kernel_init=nn.initializers.xavier_uniform(),
290 | broadcast_dropout=False,
291 | deterministic=deterministic,
292 | dropout_rate=self.attention_dropout_rate)(
293 | x, x, mask=mask[:, jnp.newaxis, jnp.newaxis, :])
294 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic)
295 | x = nn_layers.StochasticDepth(rate=self.stochastic_depth)(x, deterministic)
296 | x = x + inputs
297 |
298 | # MLP block.
299 | y = nn.LayerNorm(dtype=self.dtype, name='LayerNorm_1')(x)
300 | mlp0 = MlpBlockWithMask(
301 | mlp_dim=self.mlp_dim,
302 | dtype=self.dtype,
303 | dropout_rate=self.dropout_rate,
304 | activation_fn=nn.relu, # ReLU is the choice for the PAX experiments.
305 | kernel_init=nn.initializers.xavier_uniform(),
306 | bias_init=nn.initializers.normal(stddev=1e-6),
307 | name='MlpBlock_0'
308 | )
309 | y = mlp0(y, mask=mask, deterministic=deterministic)
310 | y = nn_layers.StochasticDepth(rate=self.stochastic_depth)(y, deterministic)
311 | return x + y
312 |
313 |
314 | class StackedTransformer(nn.Module):
315 | """Stacked transformer."""
316 |
317 | mlp_dim: int
318 | num_layers: int
319 | num_heads: int
320 | dropout_rate: float = 0.1
321 | attention_dropout_rate: float = 0.1
322 | stochastic_depth: float = 0.0
323 | dtype: t.Any = jnp.float32
324 |
325 | def setup(self):
326 | encoder_blocks = []
327 | for lyr in range(self.num_layers):
328 | encoder_blocks.append(
329 | TextEncoder1DBlock(
330 | mlp_dim=self.mlp_dim,
331 | num_heads=self.num_heads,
332 | dropout_rate=self.dropout_rate,
333 | attention_dropout_rate=self.attention_dropout_rate,
334 | stochastic_depth=(
335 | (lyr / max(self.num_layers - 1, 1)) * self.stochastic_depth),
336 | name=f'encoderblock_{lyr}',
337 | ))
338 | self.encoder_blocks = encoder_blocks
339 |
340 | def __call__(
341 | self, x: jnp.ndarray, mask: jnp.ndarray, deterministic: bool
342 | ) -> jnp.ndarray:
343 | """Applies StackedTransformer module."""
344 | for block in self.encoder_blocks:
345 | x = block(x, mask, deterministic=deterministic)
346 | return x
347 |
348 |
--------------------------------------------------------------------------------
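A short usage sketch (not part of the file) for the `Tokenizer` defined above. It assumes the SentencePiece `tokenizer.model` fetched by `scenic/checkpoints/download_checkpoints.sh` is present in the working directory and that `tensorflow_text` is installed.

```python
from tips.scenic.models import text

tokenizer = text.Tokenizer(tokenizer_path='tokenizer.model')
tokens, paddings = tokenizer.tokenize(
    ['a photo of a dog', 'two dogs playing'], max_len=64)
# Both arrays have shape (2, 64); `paddings` is 1 wherever a position holds
# the pad id 0, matching the convention expected by TextEncoder.
print(tokens.shape, paddings.shape)
```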
/scenic/models/tips.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """The TIPS model definition."""
17 |
18 | import typing as t
19 |
20 | import flax.linen as nn
21 | import jax.numpy as jnp
22 |
23 | from tips.scenic.models import text
24 | from tips.scenic.models import vit
25 |
26 |
27 | class VisionEncoder(nn.Module):
28 | """TIPS vision encoder based on ViT."""
29 |
30 | variant: str
31 | pooling: str = 'tok'
32 | num_cls_tokens: int = 2 # TIPS defaults to 2 CLS tokens.
33 | dropout_rate: float = 0.0
34 | attention_dropout_rate: float = 0.0
35 | stochastic_depth: float = 0.0
36 | dtype: t.Any = jnp.float32
37 |
38 | def setup(self):
39 | super().setup()
40 |
41 | self.encoder = vit.ViT(
42 | variant=self.variant,
43 | num_cls_tokens=self.num_cls_tokens,
44 | dropout_rate=self.dropout_rate,
45 | attention_dropout_rate=self.attention_dropout_rate,
46 | stochastic_depth=self.stochastic_depth,
47 | dtype=self.dtype,
48 | )
49 | self.patches = self.encoder.patches
50 |
51 | def pool_features(self, x: jnp.ndarray) -> t.Tuple[jnp.ndarray, jnp.ndarray]:
52 | """Extracts the spatial and vector features from the backbone.
53 |
54 | Currently supports only 'tok' pooling (CLS tokens). The CLS tokens are
55 | always prepended to the spatial (patch) tokens.
56 |
57 | Args:
58 | x: The input features.
59 |
60 | Returns:
61 | x_patch: The spatial features.
62 | x_vec: The vector embedding(s).
63 | """
64 | if self.pooling == 'tok':
65 | x_vec = x[:, :self.num_cls_tokens, :]
66 | x_patch = x[:, self.num_cls_tokens:, :]
67 | else:
68 | raise ValueError(f'Invalid pooling: {self.pooling}')
69 | return x_patch, x_vec
70 |
71 | def reshape_spatial_features(
72 | self, x: jnp.ndarray, h: int, w: int) -> jnp.ndarray:
73 | """Re-shapes the spatial features according to the initial dimensions."""
74 | fh = h // self.patches[0]
75 | fw = w // self.patches[1]
76 | bs, l, f = x.shape
77 | if l != fh * fw:
78 | raise ValueError(f'Invalid shape: {x.shape}')
79 | return x.reshape(bs, fh, fw, f)
80 |
81 | @nn.compact
82 | def __call__(
83 | self, x: jnp.ndarray, *, train: bool, debug: bool = False
84 | ) -> t.Tuple[jnp.ndarray, jnp.ndarray]:
85 | del debug
86 | x = vit.maybe_center_pad(
87 | x, patch_h=self.patches[0], patch_w=self.patches[1])
88 | h, w = x.shape[1:3] # h, w of images after padding.
89 | x = self.encoder(x, train=train)
90 | x_patch, x_vec = self.pool_features(x)
91 | x_patch = self.reshape_spatial_features(x_patch, h, w)
92 |
93 | return x_patch, x_vec
94 |
95 |
96 | class TextEncoder(nn.Module):
97 | """TIPS Text encoder."""
98 |
99 | variant: str
100 | vocab_size: int = 32_000
101 | dropout_rate: float = 0.1
102 | attention_dropout_rate: float = 0.1
103 | stochastic_depth: float = 0.0
104 | dtype: t.Any = jnp.float32
105 | scale_sqrt_depth: bool = True # Default param in PAX experiments.
106 |
107 | def setup(self):
108 | super().setup()
109 | text_config = vit.get_vit_config(self.variant)
110 | # The text tower layers are fixed at 12, independent of the vision tower
111 | # size. The one exception is So400m/14, whose text tower is a symmetric
112 | # copy of the vision tower and therefore uses the vision tower depth
113 | # (27 layers).
114 | self.num_layers = 12
115 | if self.variant == 'So400m/14':
116 | self.num_layers = text_config['num_layers']
117 | self.embedding_dim = text_config['hidden_size']
118 | self.mlp_dim = text_config['mlp_dim']
119 | self.num_heads = text_config['num_heads']
120 | self.embedder = text.Embedding(
121 | name='token_emb',
122 | num_classes=self.vocab_size,
123 | embedding_dim=self.embedding_dim,
124 | scale_sqrt_depth=self.scale_sqrt_depth)
125 | self.pos_embedder = text.PositionalEmbedding(
126 | embedding_dim=self.embedding_dim)
127 | self.transformer = text.StackedTransformer(
128 | name='transformer',
129 | mlp_dim=self.mlp_dim,
130 | num_layers=self.num_layers,
131 | num_heads=self.num_heads,
132 | dropout_rate=self.dropout_rate,
133 | attention_dropout_rate=self.attention_dropout_rate,
134 | stochastic_depth=self.stochastic_depth,
135 | dtype=self.dtype,
136 | )
137 | self.pooling = text.GlobalAvgPooling(pooling_dims=[1])
138 | self.norm = nn.LayerNorm(dtype=self.dtype, name='text_encoder_norm')
139 |
140 | def __call__(
141 | self,
142 | ids: jnp.ndarray,
143 | paddings: jnp.ndarray,
144 | train: bool,
145 | ) -> jnp.ndarray:
146 | """Applies TextEncoder module."""
147 | _, seq_length = ids.shape
148 | mask = (paddings == 0).astype(jnp.int32)
149 | x = self.embedder(ids)
150 | x = x + self.pos_embedder(seq_length=seq_length)
151 | x = self.transformer(x, mask, deterministic=not train)
152 | x = self.norm(x)
153 | x = self.pooling(x, compatible_paddings=paddings[:, :, jnp.newaxis])
154 | return x
155 |
156 |
--------------------------------------------------------------------------------
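A minimal Flax sketch (not part of the repository) that builds the `VisionEncoder` above with randomly initialized parameters; real use would restore the released checkpoints instead. The variant string follows `get_vit_config` in `scenic/models/vit.py`, and the expected shapes follow from the 14-pixel patches and the two CLS tokens.

```python
import jax
import jax.numpy as jnp

from tips.scenic.models import tips

model = tips.VisionEncoder(variant='B/14', num_cls_tokens=2)
images = jnp.zeros((1, 224, 224, 3))  # NHWC, values already normalized.
variables = model.init(jax.random.PRNGKey(0), images, train=False)
x_patch, x_vec = model.apply(variables, images, train=False)
print(x_patch.shape)  # (1, 16, 16, 768): 16x16 grid of 768-d patch features
print(x_vec.shape)    # (1, 2, 768): the two CLS embeddings
```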
/scenic/models/vit.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Standard ViT model definition."""
17 |
18 | import logging
19 | import math
20 | import typing as t
21 |
22 | import flax.linen as nn
23 | import jax
24 | import jax.numpy as jnp
25 | import ml_collections
26 | import numpy as np
27 |
28 | from scenic.model_lib.layers import attention_layers
29 | from scenic.model_lib.layers import nn_layers
30 |
31 | Initializer = t.Callable[[jnp.ndarray, t.Sequence[int], jnp.dtype], jnp.ndarray]
32 |
33 |
34 | def get_vit_config(variant: str) -> t.Dict[str, t.Any]:
35 | v, patch = variant.split('/')
36 | return {
37 | # pylint:disable=line-too-long
38 | 'hidden_size': {'S': 384, 'B': 768, 'L': 1024, 'So400m': 1152, 'g': 1536}[v],
39 | 'num_layers': {'S': 12, 'B': 12, 'L': 24, 'So400m': 27, 'g': 40}[v],
40 | 'mlp_dim': {'S': 1536, 'B': 3072, 'L': 4096, 'So400m': 4304, 'g': 6144}[v],
41 | 'num_heads': {'S': 6, 'B': 12, 'L': 16, 'So400m': 16, 'g': 24}[v],
42 | 'patch_size': (int(patch), int(patch)),
43 | 'ffn_layer': {'S': 'mlp', 'B': 'mlp', 'L': 'mlp', 'So400m': 'mlp', 'g': 'swiglu'}[v],
44 | # pylint:enable=line-too-long
45 | }
46 |
47 |
48 | def maybe_center_pad(x: jnp.ndarray, patch_h: int, patch_w: int):
49 | """Pads the input to the next multiple of the patch size."""
50 | h_old, w_old = x.shape[1:3]
51 | pad_h = math.ceil(h_old / patch_h) * patch_h - h_old
52 | pad_w = math.ceil(w_old / patch_w) * patch_w - w_old
53 | if pad_w > 0 or pad_h > 0:
54 | pad_h_top = pad_h // 2
55 | pad_h_bottom = pad_h - pad_h_top
56 | pad_w_left = pad_w // 2
57 | pad_w_right = pad_w - pad_w_left
58 | logging.info(
59 | 'Applying center padding (%d, %d), (%d, %d)',
60 | pad_w_left, pad_w_right, pad_h_top, pad_h_bottom)
61 | x = jnp.pad(
62 | x, ((0, 0),
63 | (pad_h_top, pad_h_bottom),
64 | (pad_w_left, pad_w_right),
65 | (0, 0)))
66 | return x
67 |
68 |
69 | class ToTokenSequence(nn.Module):
70 | """Transform a batch of views into a sequence of tokens."""
71 |
72 | patches: ml_collections.ConfigDict
73 | hidden_size: int
74 | num_cls_tokens: int = 0
75 | posembs: t.Tuple[int, int] = (16, 16)
76 | pos_interpolation_method: str = 'bilinear'
77 |
78 | def add_positional_encodings(self, x: jnp.ndarray) -> jnp.ndarray:
79 | """Support a few variants for sinsuoidal 2D position embeddings."""
80 | n, h, w, c = x.shape
81 | posemb = self.param(
82 | 'posembed_input',
83 | nn.initializers.normal(stddev=1/np.sqrt(c)),
84 | (1, self.posembs[0], self.posembs[1], c), x.dtype)
85 | # Interpolate the positional encodings.
86 | if (h, w) != self.posembs:
87 | posemb = jax.image.resize(
88 | posemb, (1, h, w, c), self.pos_interpolation_method)
89 | x = x + posemb
90 | x = jnp.reshape(x, [n, h * w, c])
91 |
92 | assert x.ndim == 3 # Shape is `[batch, len, emb]`.
93 | return x
94 |
95 | @nn.compact
96 | def __call__(self, x: jnp.ndarray, *, seqlen: int = -1):
97 |
98 | fh, fw = self.patches
99 | # Extracting patches and then embedding is in fact a single convolution.
100 | x = nn.Conv(
101 | self.hidden_size, (fh, fw),
102 | strides=(fh, fw),
103 | padding='VALID',
104 | name='embedding')(x)
105 |
106 | # Add positional encodings.
107 | x = self.add_positional_encodings(x)
108 |
109 | # Add extra "cls" tokens.
110 | if self.num_cls_tokens > 0:
111 | n, _, c = x.shape
112 | cls_tok = self.param(
113 | 'cls',
114 | nn.initializers.zeros,
115 | (1, self.num_cls_tokens, c),
116 | x.dtype)
117 | cls_tok = jnp.tile(cls_tok, [n, 1, 1])
118 | x = jnp.concatenate([cls_tok, x], axis=1)
119 | return x
120 |
121 |
122 | class FFNSwiGluFused(nn.Module):
123 | """SwiGlu variant of the feed-forward block.
124 |
125 | https://arxiv.org/abs/2002.05202v1
126 | """
127 |
128 | mlp_dim: int
129 | out_dim: t.Optional[int] = None
130 | dropout_rate: float = 0.0
131 | use_bias: bool = False
132 | kernel_init: Initializer = nn.initializers.xavier_uniform()
133 | bias_init: Initializer = nn.initializers.zeros
134 | precision: t.Optional[jax.lax.Precision] = None
135 | dtype: jnp.ndarray = jnp.float32
136 |
137 | def _hidden_layer(self, inputs: jnp.ndarray) -> jnp.ndarray:
138 | # https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/swiglu_ffn.py#L57 # pylint: disable=line-too-long
139 | mlp_dim = (int(self.mlp_dim * 2 / 3) + 7) // 8 * 8
140 | xw = nn.Dense(
141 | mlp_dim,
142 | dtype=self.dtype,
143 | use_bias=self.use_bias,
144 | kernel_init=self.kernel_init,
145 | bias_init=self.bias_init,
146 | precision=self.precision,
147 | )(inputs)
148 | xv = nn.Dense(
149 | mlp_dim,
150 | dtype=self.dtype,
151 | use_bias=self.use_bias,
152 | kernel_init=self.kernel_init,
153 | bias_init=self.bias_init,
154 | precision=self.precision,
155 | )(inputs)
156 | xw = nn.swish(xw)
157 | x = xw * xv
158 | return x
159 |
160 | @nn.compact
161 | def __call__(
162 | self, inputs: jnp.ndarray, *, deterministic: bool
163 | ) -> jnp.ndarray:
164 | """Applies FFN module."""
165 | actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
166 | x = self._hidden_layer(inputs)
167 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
168 | output = nn.Dense(
169 | actual_out_dim,
170 | dtype=self.dtype,
171 | use_bias=self.use_bias,
172 | kernel_init=self.kernel_init,
173 | bias_init=self.bias_init,
174 | precision=self.precision)(x)
175 | output = nn.Dropout(rate=self.dropout_rate)(
176 | output, deterministic=deterministic)
177 | return output
178 |
179 |
180 | class VisionEncoder1DBlock(nn.Module):
181 | """Transformer encoder layer.
182 |
183 | Attributes:
184 | mlp_dim: Dimension of the mlp on top of attention block.
185 | num_heads: Number of self-attention heads.
186 | dtype: The dtype of the computation (default: float32).
187 | dropout_rate: Dropout rate.
188 | attention_dropout_rate: Dropout for attention heads.
189 | stochastic_depth: probability of dropping a layer; the drop probability
190 | grows linearly from 0 to the provided value through the stack.
191 | ffn_layer: type of the feed-forward layer. Options are 'mlp', 'swiglu'.
192 |
193 | Returns:
194 | output after transformer encoder block.
195 | """
196 | mlp_dim: int
197 | num_heads: int
198 | dtype: t.Any = jnp.float32
199 | dropout_rate: float = 0.0
200 | attention_dropout_rate: float = 0.0
201 | stochastic_depth: float = 0.0
202 | ffn_layer: str = 'mlp'
203 |
204 | def setup(self):
205 | super().setup()
206 |
207 | if self.ffn_layer == 'mlp':
208 | ffn_layer = attention_layers.MlpBlock(
209 | mlp_dim=self.mlp_dim,
210 | dtype=self.dtype,
211 | dropout_rate=self.dropout_rate,
212 | activation_fn=nn.gelu,
213 | kernel_init=nn.initializers.xavier_uniform(),
214 | bias_init=nn.initializers.normal(stddev=1e-6),
215 | name='MlpBlock_0')
216 | elif self.ffn_layer == 'swiglu':
217 | ffn_layer = FFNSwiGluFused(
218 | mlp_dim=self.mlp_dim,
219 | dtype=self.dtype,
220 | use_bias=True,
221 | dropout_rate=self.dropout_rate,
222 | kernel_init=nn.initializers.xavier_uniform(),
223 | bias_init=nn.initializers.normal(stddev=1e-6),
224 | name='FFNSwiGluFused_0')
225 | else:
226 | raise ValueError(f'Unsupported ffn_layer: {self.ffn_layer}')
227 | self.ffn = ffn_layer
228 | self.ln_0 = nn.LayerNorm(dtype=self.dtype, name='LayerNorm_0')
229 | self.ln_1 = nn.LayerNorm(dtype=self.dtype, name='LayerNorm_1')
230 | self.attention = nn.MultiHeadDotProductAttention(
231 | name='MultiHeadDotProductAttention_0',
232 | num_heads=self.num_heads,
233 | dtype=self.dtype,
234 | kernel_init=nn.initializers.xavier_uniform(),
235 | broadcast_dropout=False,
236 | dropout_rate=self.attention_dropout_rate)
237 |
238 | @nn.compact
239 | def __call__(
240 | self, inputs: jnp.ndarray, deterministic: bool) -> jnp.ndarray:
241 | """Applies Encoder1DBlock module.
242 |
243 | Args:
244 | inputs: Input data.
245 | deterministic: Deterministic or not (to apply dropout).
246 |
247 | Returns:
248 | Output after transformer encoder block.
249 | """
250 | # Attention block.
251 | assert inputs.ndim == 3
252 | x = self.ln_0(inputs)
253 | x = self.attention(x, x, deterministic=deterministic)
254 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic)
255 | x = nn_layers.StochasticDepth(rate=self.stochastic_depth)(x, deterministic)
256 | x = x + inputs
257 |
258 | # MLP block.
259 | y = self.ln_1(x)
260 | y = self.ffn(y, deterministic=deterministic)
261 | y = nn_layers.StochasticDepth(rate=self.stochastic_depth)(y, deterministic)
262 | return y + x
263 |
264 |
265 | class StackedTransformer(nn.Module):
266 | """Stacked transformer."""
267 |
268 | mlp_dim: int
269 | num_layers: int
270 | num_heads: int
271 | ffn_layer: str = 'mlp'
272 | dropout_rate: float = 0.0
273 | attention_dropout_rate: float = 0.0
274 | stochastic_depth: float = 0.0
275 | dtype: t.Any = jnp.float32
276 |
277 | def setup(self):
278 | encoder_blocks = []
279 | for lyr in range(self.num_layers):
280 | encoder_blocks.append(
281 | VisionEncoder1DBlock(
282 | mlp_dim=self.mlp_dim,
283 | num_heads=self.num_heads,
284 | dropout_rate=self.dropout_rate,
285 | attention_dropout_rate=self.attention_dropout_rate,
286 | stochastic_depth=(lyr / max(self.num_layers - 1, 1))
287 | * self.stochastic_depth,
288 | name=f'encoderblock_{lyr}',
289 | ffn_layer=self.ffn_layer,
290 | dtype=self.dtype))
291 | self.encoder_blocks = encoder_blocks
292 |
293 | def __call__(
294 | self, x: jnp.ndarray, deterministic: bool) -> jnp.ndarray:
295 | """Applies StackedTransformer module."""
296 | for block in self.encoder_blocks:
297 | x = block(x, deterministic=deterministic)
298 | return x
299 |
300 |
301 | class ViT(nn.Module):
302 | """Dense Features backbone based on ViT."""
303 |
304 | variant: str
305 | freeze_backbone: bool = False
306 | num_cls_tokens: int = 1
307 | dropout_rate: float = 0.1
308 | attention_dropout_rate: float = 0.1
309 | stochastic_depth: float = 0.0
310 | dtype: t.Any = jnp.float32
311 |
312 | def setup(self):
313 | super().setup()
314 | vit_config = get_vit_config(self.variant)
315 | self.patches = vit_config['patch_size']
316 | self.hidden_size = vit_config['hidden_size']
317 | self.num_layers = vit_config['num_layers']
318 | self.mlp_dim = vit_config['mlp_dim']
319 | self.num_heads = vit_config['num_heads']
320 | self.ffn_layer = vit_config['ffn_layer']
321 |
322 | # Setup for layers.
323 | self.token_fn = ToTokenSequence(
324 | name='ToTokenSequence_0',
325 | patches=self.patches,
326 | hidden_size=self.hidden_size,
327 | num_cls_tokens=self.num_cls_tokens,
328 | posembs=(16, 16),
329 | )
330 | self.norm = nn.LayerNorm(name='encoder_norm')
331 | self.transformer = StackedTransformer(
332 | name='transformer',
333 | mlp_dim=self.mlp_dim,
334 | num_layers=self.num_layers,
335 | num_heads=self.num_heads,
336 | dropout_rate=self.dropout_rate,
337 | attention_dropout_rate=self.attention_dropout_rate,
338 | stochastic_depth=self.stochastic_depth,
339 | dtype=self.dtype,
340 | ffn_layer=self.ffn_layer,
341 | )
342 |
343 | @nn.compact
344 | def __call__(
345 | self, x: jnp.ndarray, *, train: bool, debug: bool = False) -> jnp.ndarray:
346 | del debug
347 | logging.info('train=%s shape before padding=%s', train, x.shape)
348 | x = maybe_center_pad(x, patch_h=self.patches[0], patch_w=self.patches[1])
349 | logging.info('train=%s shape after padding=%s', train, x.shape)
350 |
351 | x = self.token_fn(x)
352 | x = self.transformer(x, deterministic=not train)
353 | x = self.norm(x)
354 |
355 | return x
356 |
--------------------------------------------------------------------------------
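For reference, a minimal stand-alone sketch (not part of the repository) of the two scalar formulas used in vit.py above: the SwiGLU hidden width, which scales mlp_dim to roughly two thirds and rounds up to a multiple of 8, and the per-layer stochastic-depth rate, which grows linearly from 0 to the configured maximum. The example values are illustrative only.

def swiglu_hidden_width(mlp_dim: int) -> int:
  # Scale to ~2/3 of mlp_dim, then round up to the next multiple of 8,
  # mirroring FFNSwiGluFused._hidden_layer.
  return (int(mlp_dim * 2 / 3) + 7) // 8 * 8


def stochastic_depth_rates(num_layers: int, max_rate: float) -> list[float]:
  # Rate grows linearly from 0 (first block) to max_rate (last block),
  # mirroring StackedTransformer.setup.
  return [lyr / max(num_layers - 1, 1) * max_rate for lyr in range(num_layers)]


print(swiglu_hidden_width(3072))        # 2048
print(stochastic_depth_rates(4, 0.3))   # approx. [0.0, 0.1, 0.2, 0.3]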
/scenic/run_tips_inference.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Runs TIPS inference."""
17 |
18 | import argparse
19 | import os
20 | import cv2
21 | import flax.linen as nn
22 | import jax
23 | import jax.numpy as jnp
24 | import numpy as np
25 | from PIL import Image
26 |
27 | from tips.scenic.configs import tips_model_config
28 | from tips.scenic.models import text
29 | from tips.scenic.models import tips
30 | from tips.scenic.utils import checkpoint
31 | from tips.scenic.utils import feature_viz
32 |
33 |
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument(
36 | '--image_width',
37 | type=int,
38 | default=448,
39 | help='Image width.',
40 | )
41 | parser.add_argument(
42 | '--variant',
43 | type=str,
44 | default='tips_oss_b14_highres_distilled',
45 | choices=(
46 | 'tips_oss_g14_highres',
47 | 'tips_oss_g14_lowres',
48 | 'tips_oss_so400m14_highres_largetext_distilled',
49 | 'tips_oss_l14_highres_distilled',
50 | 'tips_oss_b14_highres_distilled',
51 | 'tips_oss_s14_highres_distilled',
52 | ),
53 | help='Model variant.',
54 | )
55 | parser.add_argument(
56 | '--checkpoint_dir',
57 | type=str,
58 | default='checkpoints/',
59 | help='The directory of the checkpoints and the tokenizer.',
60 | )
61 | parser.add_argument(
62 | '--image_path',
63 | type=str,
64 | default='images/example_image.jpg',
65 | help='The path to the image file.'
66 | )
67 |
68 |
69 | def main() -> None:
70 | args = parser.parse_args()
71 | image_width = args.image_width
72 | image_shape = (image_width,) * 2
73 | variant = args.variant
74 | checkpoint_dir = args.checkpoint_dir
75 | image_path = args.image_path
76 |
77 | # Load the model configuration.
78 | model_config = tips_model_config.get_config(variant)
79 |
80 | # Load the vision encoder.
81 | model_vision = tips.VisionEncoder(
82 | variant=model_config.variant,
83 | pooling=model_config.pooling,
84 | num_cls_tokens=model_config.num_cls_tokens)
85 | init_params_vision = model_vision.init(
86 | jax.random.PRNGKey(0), jnp.ones([1, *image_shape, 3]), train=False)
87 | params_vision = checkpoint.load_checkpoint(
88 | os.path.join(checkpoint_dir, f'{variant}_vision.npz'),
89 | init_params_vision['params'])
90 |
91 | # Load the text encoder.
92 | tokenizer_path = os.path.join(checkpoint_dir, 'tokenizer.model')
93 | tokenizer = text.Tokenizer(tokenizer_path)
94 | model_text = tips.TextEncoder(variant=model_config.variant)
95 | init_params_text = model_text.init(
96 | jax.random.PRNGKey(0),
97 | ids=jnp.ones((4, 64), dtype=jnp.int32),
98 | paddings=jnp.zeros((4, 64), dtype=jnp.int32),
99 | train=False)
100 | init_params_text['params']['temperature_contrastive'] = (
101 | np.array(0, dtype=np.float32))
102 | params_text = checkpoint.load_checkpoint(
103 | os.path.join(checkpoint_dir, f'{variant}_text.npz'),
104 | init_params_text['params'])
105 |
106 | # Load and preprocess the image.
107 | image = jnp.array(Image.open(image_path)).astype(jnp.float32) / 255.
108 | image = jax.image.resize(image, (*image_shape, 3), method='bilinear')
109 | image = image.astype(jnp.float32)
110 |
111 | # Run inference on the image.
112 | spatial_features, embeddings_vision = model_vision.apply(
113 | {'params': params_vision}, image[None], train=False)
114 |   # We use the first CLS token (the second is better suited for dense tasks).
115 | cls_token = feature_viz.normalize(embeddings_vision[:, 0, :])
116 |
117 | # Run inference on text.
118 | text_input = [
119 | 'A ship', 'holidays', 'a toy dinosaur', 'Two astronauts',
120 | 'a real dinosaur', 'A streetview image of burger kings',
121 | 'a streetview image of mc donalds']
122 | text_ids, text_paddings = tokenizer.tokenize(text_input, max_len=64)
123 | embeddings_text = model_text.apply(
124 | {'params': params_text},
125 | ids=text_ids,
126 | paddings=text_paddings,
127 | train=False)
128 | embeddings_text = feature_viz.normalize(embeddings_text)
129 |
130 |   # Compute a softmax over the temperature-scaled cosine similarities.
131 | cos_sim = nn.softmax(
132 | ((cls_token @ embeddings_text.T) /
133 | params_text['temperature_contrastive']), axis=-1)
134 | label_idxs = jnp.argmax(cos_sim, axis=-1)
135 | cos_sim_max = jnp.max(cos_sim, axis=-1)
136 | label_predicted = text_input[label_idxs[0].item()]
137 | similarity = cos_sim_max[0].item()
138 |
139 | # Compute PCA of patch tokens.
140 | pca_obj = feature_viz.PCAVisualizer(spatial_features)
141 | image_pca = pca_obj(spatial_features)[0]
142 | image_pca = np.asarray(jax.image.resize(
143 | image_pca, (*image_shape, 3), method='nearest'))
144 |
145 | # Display the results.
146 | cv2.imshow(
147 | f'{label_predicted}, prob: {similarity*100:.1f}%',
148 | np.concatenate([image, image_pca], axis=1)[..., ::-1])
149 | cv2.waitKey(0)
150 | cv2.destroyAllWindows()
151 |
152 |
153 | if __name__ == '__main__':
154 | main()
155 |
--------------------------------------------------------------------------------
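The classification step in run_tips_inference.py reduces to a temperature-scaled softmax over cosine similarities between the L2-normalized image CLS token and the L2-normalized text embeddings. Below is a minimal sketch with dummy arrays; the shapes, random values, and temperature are placeholders, not checkpoint values.

import jax
import jax.numpy as jnp

# Placeholder embeddings: 1 image CLS token and 7 text prompts, 768-dim each.
key_img, key_txt = jax.random.split(jax.random.PRNGKey(0))
cls_token = jax.random.normal(key_img, (1, 768))
embeddings_text = jax.random.normal(key_txt, (7, 768))

# L2-normalize, as feature_viz.normalize does for the real embeddings.
cls_token /= jnp.linalg.norm(cls_token, axis=-1, keepdims=True)
embeddings_text /= jnp.linalg.norm(embeddings_text, axis=-1, keepdims=True)

# For unit vectors the dot product is the cosine similarity; dividing by the
# learned temperature sharpens the distribution before the softmax.
temperature = 0.07  # placeholder; the real value comes from the checkpoint.
probs = jax.nn.softmax(cls_token @ embeddings_text.T / temperature, axis=-1)
print(probs.shape, jnp.argmax(probs, axis=-1))  # (1, 7), best-matching index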
/scenic/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Checkpoint helper functions."""
17 |
18 | import logging
19 | import typing as t
20 | import flax
21 | import numpy as np
22 |
23 |
24 | def load_checkpoint(
25 | checkpoint_path: str,
26 | params_to_load: t.Dict[str, np.ndarray],
27 | strict: bool = True,
28 | ) -> t.Dict[str, np.ndarray]:
29 | """Loads a TIPS checkpoint and checks that the parameters are compatible."""
30 | params_to_load_flat = flax.traverse_util.flatten_dict(params_to_load, sep='/')
31 | params_loaded_flat = dict(np.load(checkpoint_path, allow_pickle=True))
32 |
33 | # Check that params to load are in the checkpoint, and have identical shapes.
34 | for k in params_to_load_flat:
35 | if k not in params_loaded_flat:
36 | raise ValueError(f'Param {k} not found in checkpoint.')
37 | if params_loaded_flat[k].shape != params_to_load_flat[k].shape:
38 | raise ValueError(f'Param {k} has wrong shape in checkpoint.')
39 |
40 | # Check that the checkpoint does not contain extra parameter groups.
41 | for k in params_loaded_flat:
42 | if k not in params_to_load_flat:
43 | if strict:
44 | raise ValueError(f'Param {k} not found in params_to_load.')
45 | else:
46 | logging.warning('Param %s not found in params_to_load.', k)
47 |
48 | return flax.traverse_util.unflatten_dict(params_loaded_flat, sep='/')
49 |
--------------------------------------------------------------------------------
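The loader above compares flattened parameter trees, so checkpoint keys must match the '/'-joined paths produced by flax.traverse_util. A minimal sketch of that flattening convention; the nested dict below is a made-up example, not an actual TIPS parameter tree.

import numpy as np
import flax

# Made-up nested parameter tree, shaped like the output of model.init().
params = {
    'encoder': {
        'Dense_0': {'kernel': np.zeros((8, 4)), 'bias': np.zeros((4,))},
    },
}

flat = flax.traverse_util.flatten_dict(params, sep='/')
print(sorted(flat))  # ['encoder/Dense_0/bias', 'encoder/Dense_0/kernel']

# load_checkpoint requires every flat key to exist in the .npz with the same
# shape (and, when strict=True, no extra keys); unflatten_dict reverses it.
nested = flax.traverse_util.unflatten_dict(flat, sep='/')
print(nested['encoder']['Dense_0']['kernel'].shape)  # (8, 4)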
/scenic/utils/feature_viz.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Visualization helpers for features."""
17 |
18 | import typing as t
19 | import jax.numpy as jnp
20 | import numpy as np
21 | from sklearn import decomposition
22 |
23 |
24 | _ArrayLike = t.Union[np.ndarray, jnp.ndarray]
25 |
26 |
27 | def normalize(x, order: int = 2):
28 | return x / np.linalg.norm(
29 | x, ord=order, axis=-1, keepdims=True).clip(min=1e-3)
30 |
31 |
32 | class PCAVisualizer:
33 | """PCA visualizer."""
34 |
35 | def __init__(
36 | self,
37 | features: _ArrayLike,
38 | n_samples: int = 100000,
39 | n_components: int = 3) -> None:
40 | """Creates a PCA object for visualizing features of shape [..., F]."""
41 | features = np.array(features)
42 | pca_object = decomposition.PCA(n_components=n_components)
43 | features = features.reshape([-1, features.shape[-1]])
44 | features = features[np.random.randint(0, features.shape[0], n_samples), :]
45 | pca_object.fit(features)
46 | self.pca_object = pca_object
47 | self.n_components = n_components
48 |
49 | def __call__(self, features: _ArrayLike) -> np.ndarray:
50 | """Apply PCA to features of shape [..., F]."""
51 | features = np.array(features)
52 | features_pca = self.pca_object.transform(
53 | features.reshape([-1, features.shape[-1]])
54 | ).reshape(features.shape[:-1] + (self.n_components,))
55 | return normalize(features_pca) * 0.5 + 0.5
56 |
--------------------------------------------------------------------------------
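A minimal usage sketch for PCAVisualizer, projecting dummy patch features of shape [batch, height, width, feature_dim] to a three-channel visualization in roughly [0, 1]. The random features and shapes are placeholders, and the import assumes the repository is importable as the tips package (as in run_tips_inference.py); with a real model the features would come from the spatial output of the vision encoder.

import numpy as np
from tips.scenic.utils.feature_viz import PCAVisualizer

# Placeholder spatial features: 1 image, a 32x32 patch grid, 768-dim features.
features = np.random.randn(1, 32, 32, 768).astype(np.float32)

# Fit a 3-component PCA on a sample of the patch features, then project.
pca = PCAVisualizer(features, n_samples=1024, n_components=3)
image_pca = pca(features)

print(image_pca.shape)  # (1, 32, 32, 3), ready to display as an RGB image.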