├── .gitignore ├── LICENSE ├── README.md ├── demo.ipynb ├── demo_script.py ├── requirements.txt └── uio ├── configs.py ├── decoding.py ├── model.py ├── network.py ├── runner.py ├── t5x_layers.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UnifiedIO
2 |
3 | This repo contains code to run models from our paper [Unified-IO: A Unified Model for Vision, Language, and Multi-Modal Tasks](https://arxiv.org/abs/2206.08916).
4 |
5 | ## Installation
6 | Install [jax](https://github.com/google/jax#installation); note that this might require manually installing
7 | the CUDA and cuDNN toolkits if using GPUs.
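For reference, a CPU-only install (enough to try the small model, though much slower than a GPU) can usually be done directly with pip; GPU builds depend on your local CUDA/cuDNN versions, so follow the linked jax instructions for those. The commands below are an illustrative sketch rather than a pinned, tested configuration:

```bash
# CPU-only jax install (illustrative; see the jax install guide for GPU wheels)
pip install --upgrade pip
pip install --upgrade jax jaxlib
```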
8 |
9 | Then install the supporting libraries with:
10 |
11 | ```
12 | pip install -r requirements.txt
13 | ```
14 |
15 | ## Model weights
16 | Model weights can be found on AWS:
17 | - XL: [https://ai2-prior-uio.s3.us-west-2.amazonaws.com/public/model-weights-bin/xl_1000k.bin](https://ai2-prior-uio.s3.us-west-2.amazonaws.com/public/model-weights-bin/xl_1000k.bin) (10.9gb)
18 | - Large: [https://ai2-prior-uio.s3.us-west-2.amazonaws.com/public/model-weights-bin/large_1000k.bin](https://ai2-prior-uio.s3.us-west-2.amazonaws.com/public/model-weights-bin/large_1000k.bin) (3.2gb)
19 | - Base: [https://ai2-prior-uio.s3.us-west-2.amazonaws.com/public/model-weights-bin/base_1000k.bin](https://ai2-prior-uio.s3.us-west-2.amazonaws.com/public/model-weights-bin/base_1000k.bin) (1.2gb)
20 | - Small: [https://ai2-prior-uio.s3.us-west-2.amazonaws.com/public/model-weights-bin/small_1000k.bin](https://ai2-prior-uio.s3.us-west-2.amazonaws.com/public/model-weights-bin/small_1000k.bin) (0.6gb)
21 |
22 | To download run:
23 |
24 | ```wget
25 | wget https://ai2-prior-uio.s3.us-west-2.amazonaws.com/public/model-weights-bin/small_1000k.bin -O small.bin
26 | ```
27 |
28 | or download with aws-cli:
29 | ```aws
30 | aws s3 cp s3://ai2-prior-uio/public/model-weights-bin/small_1000k.bin small.bin
31 | ```
32 |
33 | ## Usage
34 | Download an image to test on:
35 | ```bash
36 | wget https://farm2.staticflickr.com/1362/1261465554_95741e918b_z.jpg -O dbg_img.png
37 | ```
38 |
39 | Then tasks can be done using the `ModelRunner` class:
40 |
41 | ```python
42 | from uio import runner
43 | from PIL import Image
44 | import numpy as np
45 |
46 | model = runner.ModelRunner("small", "small.bin")
47 |
48 | with Image.open("dbg_img.png") as img:
49 |   image = np.array(img.convert('RGB'))
50 |
51 | # Answer a VQA question, note this might take over a minute the first time it is
52 | # called while the function is compiled by jax
53 | output = model.vqa(image, "What color is the sofa?")
54 | print(output["text"]) # Should print `green`
55 | ```
56 |
57 | This example can be run end-to-end by `demo_script.py`. `ModelRunner` supports many more tasks;
58 | examples can be seen in the demo notebook.
59 |
60 |
61 | `ModelRunner` also provides a lower-level API that can be called with arbitrary text/image inputs and
62 | can generate text/image outputs, as well as supporting batch input:
63 |
64 | ```python
65 | out = model.run([image], ["What is the depth map of the image ?"],
66 |                 output_text_len=1, generate_image=True, num_decodes=None)
67 | depth_image = out["image"][0]
68 | ```
69 |
70 | ## Demo notebook
71 | More tasks are shown in demo.ipynb; this additionally requires installing jupyter and matplotlib:
72 |
73 | ```
74 | pip install matplotlib notebook
75 | ```
76 |
77 | Then it can be run with:
78 |
79 | ```bash
80 | jupyter notebook demo.ipynb
81 | ```
82 |
83 |
84 | ## Just-in-time compilation
85 | By default `ModelRunner` compiles the underlying inference calls the first time they are used;
86 | this results in faster performance at a one-time cost. This can be disabled by setting the
87 | `compile` parameter to `False`. You can set the environment variable `JAX_LOG_COMPILES=1`
88 | to see when a function is being compiled.
89 |
90 | ## Implementation Details
91 | Running UnifiedIO on a task is a 4-step process:
92 |
93 | 1. Convert task inputs into (image_input, prompt) pairs; the image_input can be `None`.
94 | This step is task-specific and involves things like selecting a prompt for the task
95 | or converting region locations into region location tokens that are then embedded in the prompt.
96 | 2. Preprocess these components, done by `utils.preprocess_image` and by converting the input prompt into
97 | tokens using a `T5Tokenizer`.
98 | 3. Run the model on these pre-processed inputs, done in `model.py`. This produces text
99 | tokens and/or a 256x256 image as output.
100 | 4. Post-process the results; this step is task-specific and can involve converting the output
101 | tokens into text or image locations and/or resizing/cropping the output image.
102 |
103 | In `ModelRunner`, `run` does steps 2 and 3 and the task-specific methods do steps 1 and 4
104 | for various tasks.
105 |
106 | The main neural network code itself can be found in `modules.Transformer`.
107 |
108 | ## Hardware requirements
109 | We have run the XL model on GPUs with 24GB of memory; lower-memory GPUs should be able to run
110 | the smaller models but might not be able to run the XL model.
111 |
112 | ## Citation
113 | If you use this codebase, please cite:
114 |
115 | ```
116 | @article{lu2022unified,
117 |   title={Unified-IO: A Unified Model for Vision, Language, and Multi-Modal Tasks},
118 |   author={Lu, Jiasen and Clark, Christopher and Zellers, Rowan and Mottaghi, Roozbeh and Kembhavi, Aniruddha},
119 |   journal={arXiv preprint arXiv:2206.08916},
120 |   year={2022}
121 | }
122 | ```
123 |
--------------------------------------------------------------------------------
/demo_script.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from os.path import exists
3 |
4 | from PIL import Image
5 |
6 | from uio import runner
7 | from uio.configs import CONFIGS
8 | import numpy as np
9 |
10 | from absl import logging
11 | import warnings
12 |
13 | # flax kicks up a lot of future warnings at the moment, ignore them
14 | warnings.simplefilter(action='ignore', category=FutureWarning)
15 |
16 | # To see INFO messages from `ModelRunner`
17 | logging.set_verbosity(logging.INFO)
18 |
19 |
20 | def main():
21 |   parser = argparse.ArgumentParser()
22 |   parser.add_argument("model_size", choices=list(CONFIGS))
23 |   parser.add_argument("model_weights")
24 |   args = parser.parse_args()
25 |
26 |   if not exists("dbg_img.png"):
27 |     logging.info("Downloading image")
28 |     import urllib.request
29 |     urllib.request.urlretrieve(
30 |       "https://farm2.staticflickr.com/1362/1261465554_95741e918b_z.jpg",
31 |       filename="dbg_img.png")
32 |
33 |   model = runner.ModelRunner(args.model_size, args.model_weights)
34 |   with Image.open("dbg_img.png") as img:
35 |     image = np.array(img.convert('RGB'))
36 |   output = model.vqa(image, "What color is the sofa?")
37 |   print(output["text"]) # Should print `green`
38 |
39 |
40 | if __name__ == "__main__":
41 |   main()
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops
2 | flax
3 | absl-py
4 |
5 | # For T5 tokenizer
6 | transformers
7 | sentencepiece
8 |
9 | # Just used for image resizing
10 | torch
11 | torchvision
12 |
--------------------------------------------------------------------------------
/uio/configs.py:
--------------------------------------------------------------------------------
1 | from uio.network import UnifiedIOConfig, VAEConfig
2 |
3 |
4 | DTYPE = "float32"
5 |
6 |
7 | # Shared between all model sizes
8 |
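# Note: the VAE config below describes the discrete image auto-encoder (a
# VQ-GAN-style tokenizer with a 16384-entry codebook) used to convert 256x256
# images to and from discrete tokens; this image tokenizer is shared by every
# model size, while the CONFIGS dict further down selects the transformer size.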
VAE_CONFIG = VAEConfig( 9 | embed_dim=256, 10 | n_embed=16384, 11 | double_z=False, 12 | z_channels=256, 13 | resolution=256, 14 | in_channels=3, 15 | out_ch=3, 16 | ch=128, 17 | ch_mult=(1,1,2,2,4), 18 | num_res_blocks=2, 19 | attn_resolutions=(16,), 20 | dropout=0, 21 | dtype=DTYPE, 22 | ) 23 | 24 | CONFIGS = { 25 | "small": UnifiedIOConfig( 26 | dtype=DTYPE, 27 | emb_dim=512, 28 | num_heads=6, 29 | num_encoder_layers=8, 30 | num_decoder_layers=8, 31 | mlp_dim=1024, 32 | ), 33 | "base": UnifiedIOConfig( 34 | dtype=DTYPE, 35 | emb_dim=768, 36 | num_heads=12, 37 | num_encoder_layers=12, 38 | num_decoder_layers=12, 39 | mlp_dim=2048, 40 | vocab_size=33152, 41 | ), 42 | "large": UnifiedIOConfig( 43 | dtype=DTYPE, 44 | emb_dim=1024, 45 | num_heads=16, 46 | num_encoder_layers=24, 47 | num_decoder_layers=24, 48 | mlp_dim=2816, 49 | ), 50 | "xl": UnifiedIOConfig( 51 | dtype=DTYPE, 52 | emb_dim=2048, 53 | num_heads=32, 54 | num_encoder_layers=24, 55 | num_decoder_layers=24, 56 | mlp_dim=5120, 57 | num_seg_emb=8 58 | ) 59 | } -------------------------------------------------------------------------------- /uio/decoding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The T5X Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Fast decoding routines for inference from a trained model.""" 16 | import functools 17 | 18 | from typing import Callable, Mapping, Optional, Tuple 19 | import flax 20 | from flax import traverse_util 21 | import jax 22 | from jax import lax 23 | from jax import random 24 | import jax.numpy as jnp 25 | import numpy as np 26 | 27 | PyTreeDef = type(jax.tree_util.tree_structure(None)) 28 | SamplingLoopState = Tuple[int, jnp.ndarray, Mapping[str, jnp.ndarray], 29 | jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray] 30 | 31 | # Constants 32 | # "Effective negative infinity" constant for masking in beam search. 33 | NEG_INF = np.array(-1.0e7) 34 | 35 | #------------------------------------------------------------------------------ 36 | # Temperature Sampling 37 | #------------------------------------------------------------------------------ 38 | _dynamic_update_vector_slice_in_dim = jax.vmap( 39 | lax.dynamic_update_slice_in_dim, in_axes=(0, 0, 0, None)) 40 | 41 | 42 | def temperature_sample( 43 | inputs: jnp.ndarray, 44 | cache: Mapping[str, jnp.ndarray], 45 | tokens_to_logits: Callable[[jnp.ndarray, Mapping[str, jnp.ndarray]], 46 | Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]], 47 | eos_id: int, 48 | decode_rng: Optional[jnp.ndarray] = None, 49 | num_decodes: int = 1, 50 | temperature: float = 1.0, 51 | topk: int = 1, 52 | topp: float = 0.0, 53 | cache_offset: int = 0, 54 | initial_index: Optional[jnp.ndarray] = None, 55 | max_decode_steps: Optional[int] = None, 56 | ): 57 | """Temperature sampling for language model generation. 
58 | 59 | The temperature sampling is performed `num_decodes` times in a vectorized 60 | manner by expanding the batch dimension. This is similar to how beam search 61 | expands the batch dimension to process each batch element with multiple beams. 62 | 63 | This function dynamically updates the `inputs` array by sampling from the 64 | model logits, which is provided by `tokens_to_logits` callable. The input 65 | sequences are expanded at the end, populated and sliced by dropping the first 66 | position. 67 | 68 | If `inputs` has non-zero entries, those values are not modified, i.e., 69 | the sampled values for those positions are discarded. This simulates the 70 | teacher forcing on the prefix positions. 71 | 72 | There are a few important observations related to this function. 73 | 74 | 1. The `inputs` is assumed to be a non-packed sequence. 75 | 76 | 2. If `initial_index=None`, then `inputs`[:, 0] is ignored. We will use 0 as a 77 | BOS token to start the generation. This inherently assumes that `inputs` is 78 | already shifted to the right by one position. If `initial_index=an_array`, 79 | the token values at `inputs`[:, initial_index] are used as the token to 80 | start the generation. 81 | 82 | 3. The loop index, i, is a vector of shape [batch_size]. When beginning 83 | generation from scratch, each value will always have the same value. When 84 | beginning with a partially filled cache, the loop index of different 85 | elements can differ, via providing a value for `initial_index`. 86 | 87 | 3. Unless all batch elements generated the eos_id before reaching the end, we 88 | always make `max_decode_len = inputs.shape[1]` number of calls to 89 | `tokens_to_logits` when decoding from scratch and 90 | `max_decode_len - jnp.minimum(initial_index)` number of calls when starting 91 | from a partially filled cache. 92 | 93 | 4. Let `output` be the output sequences, i.e.,`sequences`[:, 1:]. Then 94 | `output`[:, j] are the tokens generated when the while loop counter `i = 95 | j`. Therefore, we generate the last token when `i = max_decode_len - 1` 96 | and exit the while loop as all `i`s are incremented to `max_decode_len`. 97 | 98 | 5. Once `eos_id = 1` is generated, the subsequent predictions are all replaced 99 | by padding token 0. 100 | 101 | 6. When using a partially filled cache, different batch elements can have 102 | different lengths. This means an input that has a longer input will have 103 | fewer steps until its `i` value reaches `max_decode_len` than an input with 104 | a shorter input. We keep these longer examples alive, doing busy work 105 | continually overwriting a new garbage token at the end of the sequence 106 | until shorter examples finish. 107 | 108 | 7. When using a partially filled cache, providing a value for `initial_index`, 109 | the attention cache index should be a vector of [batch_size]. 110 | 111 | We show three examples to illustrate how this function works. In addition to 112 | input and output of the function, we also show two intermediate values: 113 | `expanded_prompt_inputs` and `final_sequences`. Also for simplicity, the 114 | examples are limited to `num_decodes = 1` usage and the `num_decodes` 115 | dimension is omitted. 116 | 117 | ``` 118 | Example 1: 119 | inputs = [0, 5, 6, 1, 0] 120 | expanded_prompt_inputs = [0, 5, 6, 1, 0, 0] 121 | final_sequences = [0, 5, 6, 1, a, b] # before slicing. 122 | output = [5, 6, 1, a, b] 123 | where `a` is prediction while taking 1 as input and `b` is prediction while 124 | taking `a` as input. 
125 | 126 | Example 2 (early stopping): 127 | inputs = [[0, 5, 1, 0, 0, 0, 0], 128 | [0, 8, 0, 0, 0, 0, 0] 129 | expanded_prompt_inputs = [[0, 5, 1, 0, 0, 0, 0, 0], 130 | [0, 8, 0, 0, 0, 0, 0, 0] 131 | final_sequences = [[0, 5, 1, a, b, c=1, 0, 0], 132 | [0, 8, d, e, f=1, g=0, 0, 0]] 133 | output = [[5, 1, a, b, c=1, 0, 0], 134 | [8, d, e, f=1, g=0, 0, 0]] 135 | 136 | In this example, there are two sequences. Let's look at sequence 0. The 137 | first generated token is `a`, which is in turn used to generate `b`. 138 | Finally, `c = 1` is generated with the input `b`. Then the loop terminates 139 | early because 1 is the `eos_id`. 140 | 141 | Now consider sequence 1. The when `f = 1` was generated, it is considered 142 | done. Since sequence 0 is not done at this point, the next prediction, i.e., 143 | `g` is zerod out. This continues until the end. 144 | 145 | Example 3 (prefilled cache): 146 | inputs = [[0, 5, 2, 6, 1, 0], 147 | [0, 8, 1, 0, 0, 0]] 148 | expanded_prompt_inputs = [[0, 5, 2, 6, 1, 0, 0, 0], 149 | [0, 8, 1, 0, 0, 0, 0, 0]] 150 | max_decode_length = 6 151 | i = [4, 2] 152 | input_tokens = [[1], 153 | [1]] 154 | output_tokens = [[a], 155 | [b]] 156 | expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, 0, 0], 157 | [0, 8, 1, b, 0, 0, 0, 0]] 158 | i = [5, 3] 159 | input_tokens = [[a], 160 | [b]] 161 | output_tokens = [[c], 162 | [d]] 163 | expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, c, 0], 164 | [0, 8, 1, b, d, 0, 0, 0]] 165 | i = [6, 4] 166 | input_tokens = [[c], 167 | [d]] 168 | output_tokens = [[y], 169 | [e]] 170 | expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, c, y], 171 | [0, 8, 1, b, d, e, 0, 0]] 172 | i = [6, 5] 173 | input_tokens = [[z], 174 | [e]] 175 | output_tokens = [[z], 176 | [f]] 177 | expanded_prompt_inputs = [[0, 5, 2, 6, 1, a, c, z], 178 | [0, 8, 1, b, d, e, f, 0]] 179 | i = [6, 6] 180 | exit 181 | outputs = [[5, 2, 6, 1, a, c], 182 | [8, 1, b, d, e, f]] 183 | 184 | In this example, there are two sequences with different input lengths. Thus 185 | the two caches had been filled to different positions. As we decode, the 186 | first sequence hits the max decode length before the second. In order to 187 | avoid prematurely ending decoding for the second sequence, the first 188 | sequence continually overwrites the final token. 189 | ``` 190 | 191 | Args: 192 | inputs: array: [batch_size, max_decode_len] int32 sequence of tokens. 193 | cache: flax attention cache. 194 | tokens_to_logits: fast autoregressive decoder function taking single token 195 | slices and cache and returning next-token logits and updated cache. 196 | eos_id: int: end-of-sentence token for target vocabulary. 197 | decode_rng: JAX PRNGKey. 198 | num_decodes: number of decoded sequences to be returned. 199 | temperature: float: sampling temperature factor. As it approaches zero this 200 | becomes equivalent to greedy sampling. 201 | topk: integer: if nonzero only use the top-k logits to sample next token, if 202 | zero don't use any cutoff and sample from full logits over vocabulary. 203 | topp: float: if nonzero only use the smallest number of logits whose 204 | cumulative sum of probs adds up to (at least) topp. Will raise ValueError 205 | if it's nonzero when topk is nonzero. 206 | cache_offset: axis offset for cache, arising from scanned layers. 207 | initial_index: Optional[array]: [batch_size] int32 a vector of loop indexes 208 | to start decoding at. 209 | max_decode_steps: int: an optional maximum number of decoding steps. 
If 210 | None, it will decode until the full input shape `inputs.shape[1]` is 211 | filled. max_decode_steps begins counting after the prompt, so it will 212 | decode at most len(prompt) + max_decode_steps tokens. 213 | 214 | Returns: 215 | A tuple (decodes, log_prob) where `decodes` is sampled sequences with shape 216 | [batch_size, num_decodes, max_decode_len] sorted by `log_prob`, which is log 217 | probability of each of the sampled sequences. 218 | """ 219 | if decode_rng is None: 220 | decode_rng = jax.random.PRNGKey(0) 221 | 222 | # [batch, len] -> [batch * num_decodes, len] 223 | expanded_inputs = flat_batch_beam_expand(inputs, num_decodes) 224 | expanded_cache = cache_map( 225 | functools.partial( 226 | flat_batch_beam_expand, beam_size=num_decodes, offset=cache_offset), 227 | cache, 228 | # When we start with a prefilled cache, the cache index is no longer a 229 | # scalar that will broadcast across multiple decodes, it is a vector and 230 | # needs to be updated to handle the multiple decodes. 231 | apply_to_index=initial_index is not None) 232 | if initial_index is not None: 233 | initial_index = flat_batch_beam_expand(initial_index, num_decodes) 234 | 235 | # expanded_decodes: [batch * num_decodes, len] 236 | # expanded_log_prob: [batch * num_decodes] 237 | expanded_decodes, expanded_log_prob, expanded_all_logprob = _temperature_sample_single_trial( 238 | expanded_inputs, 239 | expanded_cache, 240 | tokens_to_logits, 241 | eos_id, 242 | decode_rng, 243 | temperature, 244 | topk, 245 | topp, 246 | initial_index=initial_index, 247 | max_decode_steps=max_decode_steps) 248 | 249 | batch_size = inputs.shape[0] 250 | # [batch * num_decodes, len] -> [batch, num_decodes, len] 251 | decodes = unflatten_beam_dim(expanded_decodes, batch_size, num_decodes) 252 | # [batch * num_decodes] -> [batch, num_decodes] 253 | log_prob = unflatten_beam_dim(expanded_log_prob, batch_size, num_decodes) 254 | all_logprob = unflatten_beam_dim(expanded_all_logprob, batch_size, num_decodes) 255 | 256 | # Sort `decodes` and `log_prob` by increasing log probabilities of the sampled 257 | # sequence. 258 | # [batch, num_decodes, 1] 259 | idxs = jnp.expand_dims(jnp.argsort(log_prob, axis=-1), axis=-1) 260 | 261 | # returns [batch, num_decodes, len], [batch, num_decodes] in sorted order. 262 | return jnp.take_along_axis( 263 | decodes, idxs, axis=1), jnp.take_along_axis( 264 | log_prob, jnp.squeeze(idxs, axis=-1), axis=-1), jnp.take_along_axis(all_logprob, idxs, axis=1) 265 | 266 | def _temperature_sample_single_trial( 267 | inputs: jnp.ndarray, 268 | cache: Mapping[str, jnp.ndarray], 269 | tokens_to_logits: Callable[[jnp.ndarray, Mapping[str, jnp.ndarray]], 270 | Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]], 271 | eos_id: int, 272 | prng_key: jnp.ndarray, 273 | temperature: float = 1.0, 274 | topk: int = 20, 275 | topp: float = 0.0, 276 | initial_index: Optional[jnp.ndarray] = None, 277 | max_decode_steps: Optional[int] = None) -> jnp.ndarray: 278 | """A helper function for `temperature_sample`.""" 279 | if topp and topk: 280 | raise ValueError('At most one of `topp` or `topk` must be non-zero.') 281 | batch_size, max_decode_len = inputs.shape 282 | 283 | if max_decode_steps is not None: 284 | if max_decode_steps > inputs.shape[1]: 285 | raise ValueError('Cannot decode more steps than the sequence length.') 286 | 287 | # the number of decode steps required to process the prefix is the number 288 | # of non-zero tokens, since inputs[0] == 0 is the BOS token. 
289 | max_decode_len = jnp.sum(inputs != 0, axis=1) + max_decode_steps 290 | max_decode_len = jnp.minimum(inputs.shape[1], max_decode_len) 291 | 292 | # We start with a dummy token in the beginning so extend the maximum length. 293 | # [batch, length] -> [batch, length+1] 294 | # 295 | # In the case of starting generation from a non-zero index, it is possible for 296 | # one batch element to reach `max_decode_len` number of decoding steps before 297 | # another. In order to let the last element decoder all the way to 298 | # `max_decode_len` number of steps, we add a final garbage token to the end of 299 | # the sequences. Any element that has reached `max_decode_len` before the rest 300 | # of the elements will continually overwrite this token until all elements 301 | # finish. 302 | # [batch, length+1] -> [batch, length+2] 303 | expanded_prompt_inputs = jnp.append( 304 | inputs, jnp.zeros((batch_size, 2), dtype=inputs.dtype), axis=1) 305 | # end_marker = jnp.array(eos_id) 306 | end_marker = jnp.array(10000000) 307 | 308 | # TODO(hwchung): handle zero temperature case in an optimized manner. 309 | # Add a small number to avoid division by zero when `temperature = 0.0`. 310 | temperature = jnp.array(temperature) + 1e-7 311 | 312 | # Initialize sampling loop state. 313 | # initial loop PRNGKey 314 | rng0 = prng_key 315 | # the per batch-item holding current token in loop. 316 | if initial_index is None: 317 | # the per batch-item loop position counter. 318 | i0 = jnp.zeros((batch_size), dtype=jnp.int32) 319 | # the per batch-item holding current token in loop. 320 | token0 = jnp.zeros((batch_size, 1), dtype=jnp.int32) 321 | else: 322 | # the per batch-item loop position counter. 323 | i0 = initial_index 324 | # the per batch-item holding current token in loop. 325 | # Select the token that the initial index is pointing to. 326 | token0 = jnp.take_along_axis( 327 | expanded_prompt_inputs, jnp.expand_dims(i0, axis=1), axis=1) 328 | # per batch-item state bit indicating if sentence has finished. 329 | ended0 = jnp.zeros((batch_size, 1), dtype=jnp.bool_) 330 | # (batch, length+2) array containing prefix prompt tokens for sampling loop 331 | # as well as the generated output of newly sampled tokens. 332 | sequences0 = expanded_prompt_inputs 333 | log_prob0 = jnp.zeros((batch_size,), dtype=jnp.float32) 334 | all_log_prob0 = jnp.zeros((batch_size,max_decode_len), dtype=jnp.float32) 335 | 336 | # Sampling loop state is stored in a simple tuple. 337 | sampling_loop_init_state = (i0, sequences0, cache, token0, ended0, rng0, 338 | log_prob0, all_log_prob0) 339 | # Initial eos count to be used to determine whether eos is "generated". Many 340 | # inputs follow the format bos, inputs..., eos, targets..., eos. By counting 341 | # the number of eos tokens we can detect when a new one is added, instead of 342 | # just finding the one that probably ends the inputs. 343 | # [batch, 1] 344 | initial_eos_count = jnp.sum(sequences0 == end_marker, axis=-1, keepdims=True) 345 | 346 | def sampling_loop_cond_fn(state: SamplingLoopState) -> bool: 347 | """Sampling loop termination condition.""" 348 | (_, _, _, _, ended, _, _, _) = state 349 | 350 | # Have all sampled sequences reached an end marker? 351 | # Different elements in the batch can be at different loop indices, if any 352 | # of our examples are not at the end, keep going. 
353 | all_sequences_ended = jnp.all(ended) 354 | return ~all_sequences_ended 355 | 356 | def sampling_loop_body_fn(state: SamplingLoopState) -> SamplingLoopState: 357 | """Sampling loop state update.""" 358 | i, sequences, cache, cur_token, ended, rng, log_prob, all_log_prob = state 359 | # Split RNG for sampling. 360 | rng1, rng2 = random.split(rng) 361 | # Call fast-decoder model on current tokens to get next-position logits. 362 | logits, new_cache = tokens_to_logits(cur_token, cache, i[0], sequences) 363 | # Sample next token from logits. 364 | if topp: 365 | logits_sorted = jnp.sort(logits, axis=-1)[:, ::-1] # sort descending 366 | sorted_cum_probs = jnp.cumsum( 367 | jax.nn.softmax(logits_sorted, axis=-1), axis=-1) 368 | cutoff_index = jnp.sum(sorted_cum_probs < topp, axis=-1, keepdims=True) 369 | cutoff_logit = jnp.take_along_axis(logits_sorted, cutoff_index, axis=-1) 370 | logits = jnp.where(logits < cutoff_logit, jnp.full_like(logits, NEG_INF), 371 | logits) 372 | if topk: 373 | # Get top-k logits and their indices, sample within these top-k tokens. 374 | topk_logits, topk_idxs = lax.top_k(logits, topk) 375 | topk_token = jnp.expand_dims( 376 | random.categorical(rng1, topk_logits / temperature).astype(jnp.int32), 377 | axis=-1) 378 | # Return the original indices corresponding to the sampled top-k tokens. 379 | # [batch] 380 | next_token = jnp.squeeze( 381 | jnp.take_along_axis(topk_idxs, topk_token, axis=-1), axis=-1) 382 | else: 383 | # [batch] 384 | next_token = random.categorical(rng1, 385 | logits / temperature).astype(jnp.int32) 386 | 387 | # log probability of the current token conditioned on the previously sampled 388 | # and prefix tokens. 389 | # [batch, vocab] -> [batch, vocab] 390 | log_probs = jax.nn.log_softmax(logits) 391 | # [batch, vocab] -> [batch] 392 | next_log_prob = jnp.squeeze( 393 | jnp.take_along_axis( 394 | log_probs, jnp.expand_dims(next_token, axis=1), axis=-1), 395 | axis=-1) 396 | 397 | one_hot_indices = jax.nn.one_hot(i, all_log_prob.shape[-1], dtype=next_log_prob.dtype) 398 | all_log_prob = all_log_prob + one_hot_indices * jnp.expand_dims(next_log_prob, -1) 399 | 400 | # When different batch elements are at different points in the loop counter, 401 | # it is possible that an element that started at a higher index will reach 402 | # `max_decode_len` before other elements. When this happens we need to make 403 | # sure this element continuous overwrites our new garbage collection index. 404 | # Here we clamp `i` to `max_decode_len`. This will cause the a write to 405 | # `max_decode_len + 1` which is the final index in `sequences`. Subsequent 406 | # loop body executions will also get their value clamped causing continual 407 | # overwriting of the final garbage position until all examples are finished. 408 | i = jnp.minimum(i, max_decode_len) 409 | 410 | # Only use sampled tokens if we're past provided prefix tokens. 411 | # Select the next token from sequences. 412 | # [batch] 413 | next_input_token = jnp.squeeze( 414 | jnp.take_along_axis(sequences, jnp.expand_dims(i + 1, axis=1), axis=1), 415 | axis=1) 416 | # Check if the next token is padding (a target) or non-padding (an input). 417 | # Mask will have `1` for targets and `0` for inputs. 418 | out_of_prompt = (next_input_token == 0) 419 | # Select the sampled next token for targets and the actual next token for 420 | # inputs (teacher forcing). 
421 | # [batch] 422 | next_token = ( 423 | next_token * out_of_prompt + next_input_token * ~out_of_prompt) 424 | 425 | # only add probability if outside prefix region 426 | # [batch] -> [batch] 427 | next_log_prob = log_prob + (next_log_prob * out_of_prompt) * jnp.squeeze( 428 | ~ended, axis=-1).astype(jnp.int32) 429 | 430 | # [batch] -> [batch, 1] 431 | next_token = jnp.expand_dims(next_token, axis=-1) 432 | 433 | # If end-marker reached for batch item, only emit padding tokens. 434 | # [batch, 1] * [batch, 1] -> [batch, 1] 435 | next_token_or_endpad = next_token * ~ended 436 | # Add current sampled tokens to recorded sequences. 437 | one_hot = jax.nn.one_hot(i + 1, sequences.shape[1], dtype=sequences.dtype) 438 | new_sequences = sequences * (1 - one_hot) + next_token_or_endpad * one_hot 439 | # new_sequences = dynamic_update_vector_slice_in_dim(sequences, 440 | # next_token_or_endpad, 441 | # i + 1, 442 | # 0) 443 | # Count eos tokens in the sequences and compare to the initial count 444 | # [batch, 1] 445 | cur_eos_count = jnp.sum(new_sequences == end_marker, axis=-1, keepdims=True) 446 | # [batch, 1] 447 | 448 | # Have we reached max decoding length? 449 | # We generally index into sequences[:, i + 1], and sequences.shape[1] = 450 | # max_decode_len + 2, therefore i == max_decode_len - 1 will write to 451 | # sequences[-2] which is our last valid location. i == max_decode_len will 452 | # write to sequences[-1] which is our garbage collection token. Thus `i` 453 | # should be strictly less than max_decode_len. 454 | has_additional_eos = cur_eos_count > initial_eos_count 455 | ended |= has_additional_eos | jnp.expand_dims( 456 | i >= max_decode_len - 1, axis=1) 457 | 458 | return (i + 1, new_sequences, new_cache, next_token_or_endpad, ended, rng2, 459 | next_log_prob, all_log_prob) 460 | 461 | # Run sampling loop and collect final state. 462 | final_state = lax.while_loop(sampling_loop_cond_fn, sampling_loop_body_fn, 463 | sampling_loop_init_state) 464 | 465 | # Pick part of the state corresponding to the sampled sequences. 466 | final_sequences = final_state[1] 467 | log_prob = final_state[-2] 468 | all_logprob = final_state[-1] 469 | # Drop the first position because they are dummy bos tokens. Drop the new 470 | # garbage collection token at the end too. 471 | return final_sequences[:, 1:-1], log_prob, all_logprob 472 | 473 | 474 | #------------------------------------------------------------------------------ 475 | # BEAM Sampling 476 | #------------------------------------------------------------------------------ 477 | 478 | 479 | def brevity_penalty(alpha: float, length: int) -> jnp.ndarray: 480 | """Brevity penalty function for beam search penalizing short sequences. 481 | 482 | Args: 483 | alpha: float: brevity-penalty scaling parameter. 484 | length: int: length of considered sequence. 485 | 486 | Returns: 487 | Brevity penalty score as jax scalar. 488 | """ 489 | return jnp.power(((5.0 + length) / 6.0), alpha) 490 | 491 | 492 | # Beam handling utility functions: 493 | 494 | 495 | def cache_map(fn, cache, apply_to_index: bool = False): 496 | """Maps function over that caches, even multiple caches in various layers. 497 | 498 | Args: 499 | fn: The function to apply. 500 | cache: The cache to apply it to. 501 | apply_to_index: Whether to apply the function to the cache index. 502 | 503 | Returns: 504 | The result of applying `fn` to the cache. 
505 | """ 506 | frozen = isinstance(cache, flax.core.FrozenDict) 507 | if frozen: 508 | cache = flax.core.unfreeze(cache) 509 | flat_cache = traverse_util.flatten_dict(cache) 510 | if apply_to_index: 511 | keyvals = flat_cache 512 | else: 513 | keyvals = {k: v for k, v in flat_cache.items() if k[-1] != 'cache_index'} 514 | # Exclude cached relative position bias from beam expansion, etc. 515 | # Also excludes scalar index in absolute position embedder from expansion. 516 | # TODO(levskaya): generalize cache_map to accept a list of leaf names to 517 | # map over, instead of doing this ad-hoc. 518 | exclusion_list = ['cached_bias', 'position_embedder_index'] 519 | keyvals = {k: v for k, v in keyvals.items() if k[-1] not in exclusion_list} 520 | 521 | keyvals = jax.tree_util.tree_map(fn, keyvals) 522 | flat_cache.update(keyvals) 523 | new_cache = traverse_util.unflatten_dict(flat_cache) 524 | if frozen: 525 | new_cache = flax.core.freeze(new_cache) 526 | return new_cache 527 | 528 | 529 | def add_beam_dim(x: jnp.ndarray, 530 | beam_size: int, 531 | offset: int = 0) -> jnp.ndarray: 532 | """Creates new beam dimension in non-scalar array and tiles into it.""" 533 | x = jnp.expand_dims(x, axis=offset + 1) 534 | tile_dims = [1] * x.ndim 535 | tile_dims[offset + 1] = beam_size 536 | return jnp.tile(x, tile_dims) 537 | 538 | 539 | def flatten_beam_dim(x: jnp.ndarray, offset: int = 0) -> jnp.ndarray: 540 | """Flattens the first two dimensions of a non-scalar array.""" 541 | xshape = list(x.shape) 542 | b_sz = xshape.pop(offset) 543 | xshape[offset] *= b_sz 544 | return x.reshape(xshape) 545 | 546 | 547 | def unflatten_beam_dim(x: jnp.ndarray, 548 | batch_size: int, 549 | beam_size: int, 550 | offset: int = 0) -> jnp.ndarray: 551 | """Unflattens the first, flat batch*beam dimension of a non-scalar array.""" 552 | assert batch_size * beam_size == x.shape[offset] 553 | xshape = list(x.shape) 554 | newshape = xshape[:offset] + [batch_size, beam_size] + xshape[offset + 1:] 555 | return x.reshape(newshape) 556 | 557 | 558 | def flat_batch_beam_expand(x: jnp.ndarray, 559 | beam_size: int, 560 | offset: int = 0) -> jnp.ndarray: 561 | """Expands the each batch item by beam_size in batch_dimension.""" 562 | return flatten_beam_dim(add_beam_dim(x, beam_size, offset), offset) 563 | 564 | 565 | def cache_gather_beams(nested: PyTreeDef, 566 | beam_indices: jnp.ndarray, 567 | batch_size: int, 568 | old_beam_size: int, 569 | new_beam_size: int, 570 | one_hot: bool = True, 571 | offset: int = 0) -> jnp.ndarray: 572 | """Gathers the cache beam slices indexed by beam_indices into new beam array. 573 | 574 | Args: 575 | nested: cache pytree. 576 | beam_indices: array of beam_indices 577 | batch_size: size of batch. 578 | old_beam_size: size of _old_ beam dimension. 579 | new_beam_size: size of _new_ beam dimension. 580 | one_hot: whether to perform gathers by one-hot contraction or directly. 581 | offset: cache axis offset from scanned layers. 582 | 583 | Returns: 584 | New pytree with new beam arrays. 585 | [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...] 586 | """ 587 | assert offset in (0, 1), 'general offsets not supported' 588 | if one_hot: 589 | # Gather via one-hot contraction, needed for SPMD partitioning. 
590 | oh_beam_indices = jax.nn.one_hot( 591 | beam_indices, old_beam_size, dtype=jnp.int32) 592 | if offset == 0: 593 | 594 | def gather_fn(x): 595 | return jnp.einsum('beo,bo...->be...', oh_beam_indices, 596 | x).astype(x.dtype) 597 | else: 598 | 599 | def gather_fn(x): 600 | return jnp.einsum('beo,lbo...->lbe...', oh_beam_indices, 601 | x).astype(x.dtype) 602 | 603 | return cache_map(gather_fn, nested) 604 | 605 | else: 606 | # True gather via fancy indexing. 607 | batch_indices = jnp.reshape( 608 | jnp.arange(batch_size * new_beam_size) // new_beam_size, 609 | (batch_size, new_beam_size)) 610 | if offset == 0: 611 | 612 | def gather_fn(x): 613 | return x[batch_indices, beam_indices] 614 | else: 615 | 616 | def gather_fn(x): 617 | return x[:, batch_indices, beam_indices] 618 | 619 | return cache_map(gather_fn, nested) 620 | 621 | 622 | def gather_beams(nested: PyTreeDef, 623 | beam_indices: jnp.ndarray, 624 | batch_size: int, 625 | old_beam_size: int, 626 | new_beam_size: int, 627 | one_hot: bool = True) -> jnp.ndarray: 628 | """Gathers the beam slices indexed by beam_indices into new beam array. 629 | 630 | Args: 631 | nested: pytree of arrays or scalars (the latter ignored). 632 | beam_indices: array of beam_indices 633 | batch_size: size of batch. 634 | old_beam_size: size of _old_ beam dimension. 635 | new_beam_size: size of _new_ beam dimension. 636 | one_hot: whether to perform gathers by one-hot contraction or directly. 637 | 638 | Returns: 639 | New pytree with new beam arrays. 640 | [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...] 641 | """ 642 | if one_hot: 643 | # Gather via one-hot contraction, needed for SPMD partitioning. 644 | oh_beam_indices = jax.nn.one_hot( 645 | beam_indices, old_beam_size, dtype=jnp.int32) 646 | 647 | def gather_fn(x): 648 | return jnp.einsum('beo,bo...->be...', oh_beam_indices, x).astype(x.dtype) 649 | 650 | return jax.tree_util.tree_map(gather_fn, nested) 651 | else: 652 | # True gather via fancy indexing. 653 | batch_indices = jnp.reshape( 654 | jnp.arange(batch_size * new_beam_size) // new_beam_size, 655 | (batch_size, new_beam_size)) 656 | 657 | def gather_fn(x): 658 | return x[batch_indices, beam_indices] 659 | 660 | return jax.tree_util.tree_map(gather_fn, nested) 661 | 662 | 663 | def top_k_two_stage(x, k): 664 | """Wrapper around lax.top_k with low-batch optimization. 665 | 666 | Args: 667 | x: tensor with shape f32[batch, num_samples]. 668 | k: integer indicating how many top values to return. 669 | 670 | Returns: 671 | Largest k values and indices with shape (f32[batch, k], s32[batch, k]). 672 | """ 673 | 674 | batch, num_samples = x.shape 675 | num_lanes = 128 676 | if (isinstance(batch, int) and batch <= 8 and 677 | num_samples > 8 * num_lanes * k): 678 | # At small batch, when num_samples is sufficiently large, optimize 679 | # execution on TPU by doing TopK in two stages. Reshaping 'x' to fill 680 | # lanes reduces tensor padding in TopK call. 681 | if num_samples % num_lanes != 0: 682 | # Pad input tensor to multiples of num_lanes. 683 | num_samples_rounded_up = num_samples + ( 684 | num_lanes - num_samples % num_lanes) 685 | x = jnp.pad( 686 | x, ((0, 0), (0, num_samples_rounded_up - num_samples)), 687 | mode='constant', 688 | constant_values=np.NINF) 689 | num_samples = num_samples_rounded_up 690 | # Reshape input tensor to fill lanes. 691 | num_samples_sublanes = int(num_samples / num_lanes) 692 | x_reshaped = jnp.reshape(x, (batch * num_lanes, num_samples_sublanes)) 693 | # First stage top_k. 
694 | vals, indices = lax.top_k(x_reshaped, k) 695 | indices = jnp.reshape(indices, (batch, num_lanes, k)) 696 | index_offsets = jnp.reshape(num_samples_sublanes * jnp.arange(num_lanes), 697 | (1, num_lanes, 1)) 698 | indices = jnp.reshape( 699 | jnp.add(index_offsets, indices), (batch, num_lanes * k)) 700 | vals = jnp.reshape(vals, (batch, num_lanes * k)) 701 | # Second stage top_k. 702 | vals_s2, indices_s2 = lax.top_k(vals, k) 703 | indices_s2 = jnp.take_along_axis(indices, indices_s2, axis=1) 704 | return vals_s2, indices_s2 705 | else: 706 | # Use default TopK implementation. 707 | return lax.top_k(x, k) 708 | 709 | 710 | def gather_topk_beams(nested: PyTreeDef, score_or_log_prob: jnp.ndarray, 711 | batch_size: int, new_beam_size: int) -> jnp.ndarray: 712 | """Gathers the top-k beam slices given by score_or_log_prob array. 713 | 714 | Args: 715 | nested: pytree of arrays or scalars (the latter ignored). 716 | score_or_log_prob: [batch_size, old_beam_size] array of values to sort by 717 | for top-k selection of beam slices. 718 | batch_size: int: size of batch. 719 | new_beam_size: int: size of _new_ top-k selected beam dimension 720 | 721 | Returns: 722 | New pytree with new beam arrays containing top k new_beam_size slices. 723 | [batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...] 724 | """ 725 | _, topk_indices = lax.top_k(score_or_log_prob, k=new_beam_size) 726 | topk_indices = jnp.flip(topk_indices, axis=1) 727 | return gather_beams(nested, topk_indices, batch_size, 728 | score_or_log_prob.shape[1], new_beam_size) 729 | 730 | 731 | # Beam search state: 732 | 733 | 734 | @flax.struct.dataclass 735 | class BeamState: 736 | """Holds beam search state data.""" 737 | # The position of the decoding loop in the length dimension. 738 | cur_index: jnp.DeviceArray # scalar int32: current decoded length index 739 | # The active sequence log probabilities and finished sequence scores. 740 | live_logprobs: jnp.DeviceArray # float32: [batch_size, beam_size] 741 | all_logprobs: jnp.DeviceArray 742 | finished_scores: jnp.DeviceArray # float32: [batch_size, beam_size] 743 | # The current active-beam-searching and finished sequences. 744 | live_seqs: jnp.DeviceArray # int32: [batch_size, beam_size, max_decode_len] 745 | finished_seqs: jnp.DeviceArray # int32: [batch_size, beam_size, 746 | # max_decode_len] 747 | # Records which of the 'finished_seqs' is occupied and not a filler slot. 748 | finished_flags: jnp.DeviceArray # bool: [batch_size, beam_size] 749 | # The current state of the autoregressive decoding caches. 750 | cache: PyTreeDef # Any pytree of arrays, e.g. 
flax attention Cache object 751 | 752 | 753 | def beam_init(batch_size: int, 754 | beam_size: int, 755 | max_decode_len: int, 756 | cache: Mapping[str, jnp.ndarray], 757 | offset: int = 0) -> BeamState: 758 | """Initializes the beam search state data structure.""" 759 | cur_index0 = jnp.array(0) 760 | live_logprobs0 = jnp.tile( 761 | jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1]) 762 | all_logprobs0 = jnp.tile( 763 | jnp.array([0.0] + [0.0] * (beam_size - 1)), [batch_size, max_decode_len]) 764 | all_logprobs0 = jnp.reshape(all_logprobs0, [batch_size, beam_size, max_decode_len]) 765 | finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF 766 | live_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32) 767 | finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32) 768 | finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_) 769 | # add beam dimension to attention cache pytree elements 770 | beam_cache0 = cache_map(lambda x: add_beam_dim(x, beam_size, offset), cache) 771 | return BeamState( 772 | cur_index=cur_index0, 773 | live_logprobs=live_logprobs0, 774 | all_logprobs=all_logprobs0, 775 | finished_scores=finished_scores0, 776 | live_seqs=live_seqs0, 777 | finished_seqs=finished_seqs0, 778 | finished_flags=finished_flags0, 779 | cache=beam_cache0) 780 | 781 | # Beam search routine: 782 | 783 | 784 | def beam_search(inputs: jnp.ndarray, 785 | cache: Mapping[str, jnp.ndarray], 786 | tokens_to_logits: Callable[ 787 | [jnp.ndarray, Mapping[str, jnp.ndarray]], 788 | Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]], 789 | eos_id: int, 790 | num_decodes: int = 4, 791 | alpha: float = 0.6, 792 | max_decode_len: Optional[int] = None, 793 | decode_rng: Optional[jnp.ndarray] = None, 794 | cache_offset: int = 0): 795 | """Beam search for transformer machine translation. 796 | 797 | Args: 798 | inputs: array: [batch_size, length] int32 sequence of tokens. 799 | cache: flax attention cache. 800 | tokens_to_logits: fast autoregressive decoder function taking single token 801 | slices and cache and returning next-token logits and updated cache. 802 | eos_id: int: id of end-of-sentence token for target vocabulary. 803 | num_decodes: number of decoded sequences to be returned. This is equivalent 804 | to the number of beams used in the beam search. 805 | alpha: float: scaling factor for brevity penalty. 806 | max_decode_len: int: an optional maximum length of decoded sequence. If 807 | None, it uses `inputs.shape[1]` as `max_decode_len`. 808 | decode_rng: Unused decoder RNG seed. 809 | cache_offset: axis offset for cache, arising from scanned layers. 810 | 811 | Returns: 812 | Tuple of: 813 | [batch_size, beam_size, max_decode_len] top-scoring sequences 814 | [batch_size, beam_size] beam-search scores. 815 | """ 816 | del decode_rng 817 | # We liberally annotate shape information for clarity below. 818 | 819 | beam_size = num_decodes 820 | 821 | batch_size = inputs.shape[0] 822 | end_marker = jnp.array(10000000) 823 | if max_decode_len is None: 824 | max_decode_len = inputs.shape[1] 825 | # We start with a dummy token in the beginning so extend the maximum length. 826 | max_decode_len += 1 827 | 828 | # initialize beam search state 829 | beam_search_init_state = beam_init(batch_size, beam_size, max_decode_len, 830 | cache, cache_offset) 831 | 832 | def beam_search_loop_cond_fn(state: BeamState) -> bool: 833 | """Beam search loop termination condition.""" 834 | # Have we reached max decoding length? 
835 | # Because we mutate the "i+1" position, we stop one token before the end. 836 | not_at_end = (state.cur_index < max_decode_len - 1) 837 | 838 | # Is no further progress in the beam search possible? 839 | # Get the best possible scores from alive sequences. 840 | min_brevity_penalty = brevity_penalty(alpha, max_decode_len) 841 | best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty 842 | # Get the worst scores from finished sequences. 843 | worst_finished_scores = jnp.min( 844 | state.finished_scores, axis=1, keepdims=True) 845 | # Mask out scores from slots without any actual finished sequences. 846 | worst_finished_scores = jnp.where(state.finished_flags, 847 | worst_finished_scores, NEG_INF) 848 | # If no best possible live score is better than current worst finished 849 | # scores, the search cannot improve the finished set further. 850 | search_terminated = jnp.all(worst_finished_scores > best_live_scores) 851 | 852 | # If we're not at the max decode length, and the search hasn't terminated, 853 | # continue looping. 854 | return not_at_end & (~search_terminated) 855 | 856 | def beam_search_loop_body_fn(state: BeamState) -> BeamState: 857 | """Beam search loop state update function.""" 858 | # Collect the current position slice along length to feed the fast 859 | # autoregressive decoder model. Flatten the beam dimension into batch 860 | # dimension for feeding into the model. 861 | # --> [batch * beam, 1] 862 | flat_ids = flatten_beam_dim( 863 | lax.dynamic_slice(state.live_seqs, (0, 0, state.cur_index), 864 | (batch_size, beam_size, 1))) 865 | # Flatten beam dimension into batch to be compatible with model. 866 | # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...} 867 | flat_cache = cache_map( 868 | functools.partial(flatten_beam_dim, offset=cache_offset), state.cache) 869 | 870 | # Call fast-decoder model on current tokens to get next-position logits. 871 | # --> [batch * beam, vocab] 872 | flat_logits, new_flat_cache = tokens_to_logits(flat_ids, flat_cache, state.cur_index, state.live_seqs) 873 | 874 | # unflatten beam dimension 875 | # [batch * beam, vocab] --> [batch, beam, vocab] 876 | 877 | logits = unflatten_beam_dim(flat_logits, batch_size, beam_size) 878 | # Unflatten beam dimension in attention cache arrays 879 | # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...} 880 | new_cache = cache_map( 881 | lambda x: unflatten_beam_dim(x, batch_size, beam_size, cache_offset), 882 | new_flat_cache) 883 | 884 | # Gather log probabilities from logits 885 | candidate_log_probs = jax.nn.log_softmax(logits) 886 | # Add new logprobs to existing prefix logprobs. 887 | # --> [batch, beam, vocab] 888 | log_probs = ( 889 | candidate_log_probs + jnp.expand_dims(state.live_logprobs, axis=2)) 890 | 891 | # We'll need the vocab size, gather it from the log probability dimension. 892 | vocab_size = log_probs.shape[-1] 893 | 894 | # Each item in batch has beam_size * vocab_size candidate sequences. 895 | # For each item, get the top 2*k candidates with the highest log- 896 | # probabilities. We gather the top 2*K beams here so that even if the best 897 | # K sequences reach EOS simultaneously, we have another K sequences 898 | # remaining to continue the live beam search. 899 | beams_to_keep = 2 * beam_size 900 | # Flatten beam and vocab dimensions. 901 | flat_log_probs = log_probs.reshape((batch_size, beam_size * vocab_size)) 902 | # Gather the top 2*K scores from _all_ beams. 
903 | # --> [batch, 2*beams], [batch, 2*beams] 904 | topk_log_probs, topk_indices = top_k_two_stage( 905 | flat_log_probs, k=beams_to_keep) 906 | # Recover the beam index by floor division. 907 | topk_beam_indices = topk_indices // vocab_size 908 | # Gather 2*k top beams. 909 | # --> [batch, 2*beams, length] 910 | topk_seq = gather_beams(state.live_seqs, topk_beam_indices, batch_size, 911 | beam_size, beams_to_keep) 912 | 913 | # Append the most probable 2*K token IDs to the top 2*K sequences 914 | # Recover token id by modulo division and expand Id array for broadcasting. 915 | # --> [batch, 2*beams, 1] 916 | topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2) 917 | # Update sequences for the 2*K top-k new sequences. 918 | # --> [batch, 2*beams, length] 919 | topk_seq = lax.dynamic_update_slice(topk_seq, topk_ids, 920 | (0, 0, state.cur_index + 1)) 921 | 922 | # Update LIVE (in-progress) sequences: 923 | # Did any of these sequences reach an end marker? 924 | # --> [batch, 2*beams] 925 | newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker) 926 | # To prevent these newly finished sequences from being added to the LIVE 927 | # set of active beam search sequences, set their log probs to a very large 928 | # negative value. 929 | new_log_probs = topk_log_probs + newly_finished * NEG_INF 930 | # Determine the top k beam indices (from top 2*k beams) from log probs. 931 | # --> [batch, beams] 932 | _, new_topk_indices = lax.top_k(new_log_probs, k=beam_size) 933 | new_topk_indices = jnp.flip(new_topk_indices, axis=1) 934 | # Gather the top k beams (from top 2*k beams). 935 | # --> [batch, beams, length], [batch, beams] 936 | top_alive_seq, top_alive_log_probs = gather_beams([topk_seq, new_log_probs], 937 | new_topk_indices, 938 | batch_size, 2 * beam_size, 939 | beam_size) 940 | 941 | one_hot_indices = jax.nn.one_hot(state.cur_index, state.all_logprobs.shape[2], dtype=state.all_logprobs.dtype) 942 | 943 | # curr_log_probs = top_alive_log_probs - state.live_logprobs 944 | def identity_fn(x): return x 945 | def update_fn(x, y): return x - y 946 | 947 | curr_log_probs = jax.lax.cond( 948 | jax.lax.eq(state.cur_index, 0), 949 | lambda: identity_fn(top_alive_log_probs), 950 | lambda: update_fn(top_alive_log_probs, state.live_logprobs), 951 | ) 952 | 953 | one_hot_indices = jnp.reshape(one_hot_indices, [1,1,-1]) 954 | curr_log_probs = jnp.expand_dims(curr_log_probs, -1) 955 | all_logprobs = state.all_logprobs + one_hot_indices * curr_log_probs 956 | # Determine the top k beam indices from the original set of all beams. 957 | # --> [batch, beams] 958 | top_alive_indices = gather_beams(topk_beam_indices, new_topk_indices, 959 | batch_size, 2 * beam_size, beam_size) 960 | # With these, gather the top k beam-associated caches. 961 | # --> {[batch, beams, ...], ...} 962 | top_alive_cache = cache_gather_beams(new_cache, top_alive_indices, 963 | batch_size, beam_size, beam_size, True, 964 | cache_offset) 965 | 966 | # Update FINISHED (reached end of sentence) sequences: 967 | # Calculate new seq scores from log probabilities. 968 | new_scores = topk_log_probs / brevity_penalty(alpha, state.cur_index + 1) 969 | # Mask out the still unfinished sequences by adding large negative value. 970 | # --> [batch, 2*beams] 971 | new_scores += (~newly_finished) * NEG_INF 972 | 973 | # Combine sequences, scores, and flags along the beam dimension and compare 974 | # new finished sequence scores to existing finished scores and select the 975 | # best from the new set of beams. 
976 | finished_seqs = jnp.concatenate( # --> [batch, 3*beams, length] 977 | [state.finished_seqs, topk_seq], 978 | axis=1) 979 | finished_scores = jnp.concatenate( # --> [batch, 3*beams] 980 | [state.finished_scores, new_scores], axis=1) 981 | finished_flags = jnp.concatenate( # --> [batch, 3*beams] 982 | [state.finished_flags, newly_finished], axis=1) 983 | # --> [batch, beams, length], [batch, beams], [batch, beams] 984 | top_finished_seq, top_finished_scores, top_finished_flags = ( 985 | gather_topk_beams([finished_seqs, finished_scores, finished_flags], 986 | finished_scores, batch_size, beam_size)) 987 | 988 | return BeamState( 989 | cur_index=state.cur_index + 1, 990 | live_logprobs=top_alive_log_probs, 991 | all_logprobs=all_logprobs, 992 | finished_scores=top_finished_scores, 993 | live_seqs=top_alive_seq, 994 | finished_seqs=top_finished_seq, 995 | finished_flags=top_finished_flags, 996 | cache=top_alive_cache) 997 | 998 | # Run while loop and get final beam search state. 999 | final_state = lax.while_loop(beam_search_loop_cond_fn, 1000 | beam_search_loop_body_fn, beam_search_init_state) 1001 | 1002 | # Account for the edge-case where there are no finished sequences for a 1003 | # particular batch item. If so, return live sequences for that batch item. 1004 | # --> [batch] 1005 | none_finished = jnp.any(final_state.finished_flags, axis=1) 1006 | # --> [batch, beams, length] 1007 | finished_seqs = jnp.where(none_finished[:, None, None], 1008 | final_state.finished_seqs, final_state.live_seqs) 1009 | # --> [batch, beams] 1010 | finished_scores = jnp.where(none_finished[:, 1011 | None], final_state.finished_scores, 1012 | final_state.live_logprobs) 1013 | 1014 | finished_logprobs = final_state.all_logprobs 1015 | 1016 | # Drop the first dummy 0 token. 1017 | return finished_seqs[:, :, 1:], finished_scores, finished_logprobs 1018 | -------------------------------------------------------------------------------- /uio/model.py: -------------------------------------------------------------------------------- 1 | # Modified from code from T5X (https://github.com/google-research/t5x) 2 | 3 | import functools 4 | from typing import Any, Callable, Mapping, MutableMapping, Optional, Tuple, Type, Union 5 | from flax import linen as nn 6 | from flax.core import scope as flax_scope 7 | from flax.training import common_utils 8 | import jax 9 | import jax.numpy as jnp 10 | import numpy as np 11 | import typing_extensions 12 | 13 | from uio import decoding 14 | 15 | Array = Union[np.ndarray, jnp.ndarray, jax.pxla.ShardedDeviceArray] 16 | PyTreeDef = type(jax.tree_util.tree_structure(None)) 17 | 18 | 19 | # Sentinel used instead of None to indicate missing values 20 | _NoValueSentinel = object() 21 | 22 | 23 | class TokensIdsToLogitsCallable(typing_extensions.Protocol): 24 | """Token ids to logits mapping call signature.""" 25 | 26 | def __call__( 27 | self, token_ids: jnp.ndarray, cache: Mapping[str, jnp.ndarray] 28 | ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: 29 | """Performs forward pass to convert token ids to logits. 30 | 31 | Args: 32 | token_ids: [batch_size, 1] int32 tokens for single position used during 33 | incremental decoding. Non-0 prefix tokens to be used as a forced prompt. 34 | cache: flax attention cache. 35 | 36 | Returns: 37 | a tuple of logits with a shape [batch_size, vocab_size] and an updated 38 | cache. 39 | """ 40 | ... 
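# A minimal sketch of a callable satisfying this protocol, assuming some
# `model`, `params` and a flax attention `cache` initialized with decode=True
# (these names are illustrative, not from this codebase):
#
#   def tokens_to_logits(token_ids, cache):
#     # token_ids: [batch * beam, 1] -> logits: [batch * beam, vocab]
#     logits, new_vars = model.apply(
#         {'params': params, 'cache': cache},
#         token_ids, decode=True, mutable=['cache'])
#     return jnp.squeeze(logits, axis=1), new_vars['cache']
#
# In this codebase the concrete implementation is
# `UnifiedIOModel._compute_logits_from_slice`, partially bound with
# `functools.partial` in `predict_batch_with_aux` before being handed to
# `decoding.beam_search` / `decoding.temperature_sample`.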
41 | 42 | 43 | class DecodeFnCallable(typing_extensions.Protocol): 44 | """Decoding function call signature.""" 45 | 46 | def __call__(self, *, inputs: jnp.ndarray, cache: Mapping[str, jnp.ndarray], 47 | tokens_to_logits: TokensIdsToLogitsCallable, eos_id: int, 48 | num_decodes: int, decode_rng: Optional[jnp.ndarray], 49 | **kwargs) -> Tuple[jnp.ndarray, jnp.ndarray]: 50 | """Decoding function interface. 51 | 52 | Args: 53 | inputs: [batch_size, max_decode_len] int32 sequence of tokens, with non-0 54 | prefix tokens to be used as a forced prompt. 55 | cache: flax attention cache. 56 | tokens_to_logits: fast autoregressive decoder function taking single token 57 | slices and cache and returning next-token logits and updated cache. 58 | eos_id: end-of-sentence token for target vocabulary. 59 | num_decodes: number of decoded sequences to be returned. 60 | decode_rng: an optional JAX PRNG Key for stochastic sampling routines. 61 | **kwargs: an optional kwargs. One common usecase of this is passing 62 | decoding parameters at the callsite. 63 | 64 | Returns: 65 | decodes: Array of sequences: [batch_size, num_decodes, max_decode_len]. 66 | The `num_decodes` dimension is expected to be sorted by the `scores`, 67 | i.e., `decodes[:, -1, :] has the highest scores among `num_decodes` 68 | decoded sequences. 69 | scores: Array of log likelihood scores: [batch_size, num_decodes] 70 | """ 71 | ... 72 | 73 | 74 | def cross_entropy_with_logits(logits: jnp.ndarray, targets: jnp.ndarray) -> jnp.ndarray: 75 | logits_sum = jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True) 76 | log_softmax = logits - logits_sum 77 | loss = -jnp.sum(targets * log_softmax, axis=-1) 78 | return loss 79 | 80 | 81 | class UnifiedIOModel(nn.Module): 82 | """Wrapper that provides generation methods using a `Transformer` module""" 83 | 84 | def __init__( 85 | self, 86 | module: nn.Module, 87 | text_decoder_length=None, 88 | image_decoder_length=None, 89 | ): 90 | self.module = module 91 | self._text_decoder_length = text_decoder_length 92 | self._image_decoder_length = image_decoder_length 93 | 94 | def _compute_logits( 95 | self, 96 | params: PyTreeDef, 97 | batch: Mapping[str, jnp.ndarray], 98 | dropout_rng: Optional[jnp.ndarray] = None, 99 | mutable: flax_scope.CollectionFilter = False 100 | ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, flax_scope.FrozenVariableDict]]: 101 | """Computes logits via a forward pass of `self.module_cls`.""" 102 | # Dropout is provided only for the training mode. 
103 | rngs = {'dropout': dropout_rng} if dropout_rng is not None else None 104 | 105 | return self.module.apply( 106 | {'params': params}, 107 | batch['text_encoder_inputs'], 108 | batch['image_encoder_inputs'], 109 | batch['text_decoder_inputs'], 110 | batch['image_decoder_targets'], 111 | batch['text_decoder_targets'], 112 | text_encoder_masks=batch.get('text_encoder_masks', None), 113 | image_encoder_masks=batch.get('image_input_masks', None), 114 | image_encoder_pos_ids=batch.get('image_encoder_pos_ids', None), 115 | text_encoder_pos_ids=batch.get('text_encoder_pos_ids', None), 116 | text_decoder_masks=batch.get('text_decoder_masks', None), 117 | image_decoder_masks=batch.get('image_target_masks', None), 118 | text_decoder_segment_ids=batch.get('text_decoder_segment_ids', None), 119 | text_decoder_positions=batch.get('text_decoder_positions', None), 120 | cache_text_length=self._text_decoder_length, 121 | cache_image_length=self._image_decoder_length, 122 | decode=False, 123 | enable_dropout=rngs is not None, 124 | rngs=rngs, 125 | mutable=mutable) 126 | 127 | def get_initial_variables( 128 | self, 129 | rng: jnp.ndarray, 130 | input_shapes: Mapping[str, Array], 131 | input_types: Optional[Mapping[str, jnp.dtype]] = None, 132 | ) -> flax_scope.FrozenVariableDict: 133 | """Get the initial variables for an encoder-decoder model.""" 134 | input_types = {} if input_types is None else input_types 135 | text_encoder_shape = input_shapes['text_encoder_inputs'] 136 | text_encoder_type = input_types.get('text_encoder_inputs', jnp.float32) 137 | image_encoder_shape = input_shapes['image_encoder_inputs'] 138 | image_encoder_type = input_types.get('image_encoder_inputs', jnp.float32) 139 | text_decoder_shape = input_shapes['text_decoder_inputs'] 140 | text_decoder_type = input_types.get('text_decoder_inputs', jnp.float32) 141 | image_decoder_shape = input_shapes['image_decoder_targets'] 142 | image_decoder_type = input_types.get('image_decoder_targets', jnp.float32) 143 | initial_variables = self.module.init( 144 | rng, 145 | jnp.ones(text_encoder_shape, text_encoder_type), 146 | jnp.ones(image_encoder_shape, image_encoder_type), 147 | jnp.ones(text_decoder_shape, text_decoder_type), 148 | jnp.ones(image_decoder_shape, image_decoder_type), 149 | jnp.ones(text_decoder_shape, text_decoder_type), 150 | decode=False, 151 | enable_dropout=False, 152 | cache_text_length=self._text_decoder_length, 153 | cache_image_length=self._image_decoder_length, 154 | vae_decode=True) 155 | return initial_variables 156 | 157 | def predict_with_answer_options( 158 | self, 159 | params: PyTreeDef, 160 | batch: Mapping[str, jnp.ndarray], 161 | max_options=800, 162 | average_loss=False 163 | ): 164 | text_answer_options = len(batch["output_options"].shape) == 3 165 | text_encoder_inputs = batch['text_encoder_inputs'] 166 | text_encoder_masks = batch.get('text_encoder_masks') 167 | if text_encoder_masks is None: 168 | text_encoder_masks = text_encoder_inputs > 0 169 | 170 | _encoded_inputs, _encoder_masks = self.module.apply( 171 | {'params': params}, 172 | text_encoder_inputs, 173 | batch['image_encoder_inputs'], 174 | text_encoder_masks, 175 | batch['image_input_masks'], 176 | image_encoder_pos_ids=batch.get('image_encoder_pos_ids', None), 177 | text_encoder_pos_ids=batch.get('text_encoder_pos_ids', None), 178 | enable_dropout=False, 179 | method=self.module.encode 180 | ) 181 | 182 | all_losses = [] 183 | n_options = batch["output_options"].shape[1] 184 | 185 | n_groups = (n_options + max_options - 1) // max_options 
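    # Score the answer options in chunks of at most `max_options` (ceil division):
    # e.g. 1000 options with max_options=800 gives n_groups == 2 (a chunk of 800, then one of 200).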
186 | for i in range(n_groups): 187 | output_options = batch["output_options"][:, i*max_options:(i+1)*max_options] 188 | batch_size, num_option = output_options.shape[:2] 189 | encoded, encoder_position_embedding = _encoded_inputs 190 | encoded = decoding.flat_batch_beam_expand(encoded, num_option) 191 | encoder_position_embedding = decoding.flat_batch_beam_expand(encoder_position_embedding, num_option) 192 | encoder_masks = decoding.flat_batch_beam_expand(_encoder_masks, num_option) 193 | encoded_inputs = (encoded, encoder_position_embedding) 194 | decoded_size = batch_size*num_option 195 | 196 | if text_answer_options: 197 | # Text answer options 198 | # `output_options` does not have EOS or BOS, we need to do a bit work to correctly-formatted 199 | # text inputs/outputs here 200 | text_decoder_inputs = output_options.reshape((decoded_size, -1)) 201 | text_decoder_targets = text_decoder_inputs 202 | text_decoder_targets = jnp.pad(text_decoder_targets, [[0, 0], [0, 1]]) # Add room for EOS 203 | 204 | text_decoder_masks = text_decoder_inputs > 0 205 | text_decoder_inputs = jnp.pad(text_decoder_inputs, [[0, 0], [1, 0]]) 206 | text_decoder_masks = jnp.pad(text_decoder_masks, [[0, 0], [1, 0]], constant_values=True) 207 | 208 | eos_mask = jnp.logical_and(text_decoder_masks, text_decoder_targets == 0) 209 | text_decoder_targets = text_decoder_targets + eos_mask 210 | 211 | image_decoder_inputs = jnp.zeros([encoded.shape[0], 1], jnp.int32) 212 | image_decoder_targets = jnp.zeros([encoded.shape[0], 1], jnp.int32) 213 | image_decoder_masks = jnp.zeros([encoded.shape[0], 1], jnp.int32) 214 | else: 215 | # Image answer options 216 | image_decoder_masks = batch["output_options_masks"][:, i*max_options:(i+1)*max_options] 217 | image_decoder_masks = image_decoder_masks.reshape(-1, 256) 218 | output_options = output_options.reshape([decoded_size] + list(output_options.shape[2:])) 219 | 220 | # Apply the VAE to get the target tokens 221 | image_decoder_targets = self.module.apply( 222 | {'params': params}, 223 | output_options, 224 | method=self.module.encode_target_image 225 | ) 226 | 227 | # Build auto-regressive inputs 228 | image_start_token = self.module.config.vocab_size - 1 229 | image_decoder_inputs = jnp.concatenate([ 230 | jnp.zeros((image_decoder_targets.shape[0], 1), dtype=jnp.int32) + image_start_token, 231 | image_decoder_targets[:, :-1]], axis=1) 232 | 233 | # Predict EOS to start following the training scheme 234 | text_decoder_inputs = jnp.zeros([decoded_size, 1], jnp.int32) 235 | text_decoder_targets = jnp.ones([decoded_size, 1], jnp.int32) 236 | text_decoder_masks = jnp.ones([decoded_size, 1], jnp.int32) 237 | 238 | text_logits, image_logits, image_decoder_targets = self.module.apply( 239 | {'params': params}, 240 | encoded_inputs, 241 | encoder_masks, 242 | text_decoder_inputs, 243 | image_decoder_inputs, 244 | text_decoder_targets, 245 | image_decoder_targets, 246 | text_decoder_masks=text_decoder_masks, 247 | image_decoder_masks=image_decoder_masks, 248 | enable_dropout=False, 249 | method=self.module.decode 250 | ) 251 | 252 | vocab_size = 33152 253 | if text_answer_options: 254 | soft_targets = common_utils.onehot(text_decoder_targets, text_logits.shape[-1], on_value=1.0, off_value=0.0) 255 | total_loss = cross_entropy_with_logits(text_logits, soft_targets) 256 | total_loss = total_loss * text_decoder_masks 257 | total_loss = jnp.sum(total_loss, axis=1) 258 | if average_loss: 259 | total_loss = total_loss / jnp.sum(text_decoder_masks, axis=1) 260 | total_loss = 
jnp.reshape(total_loss, [batch_size, -1]) 261 | else: 262 | soft_targets = common_utils.onehot(image_decoder_targets+vocab_size, image_logits.shape[-1]) 263 | total_loss = cross_entropy_with_logits(image_logits, soft_targets) 264 | total_loss = total_loss * image_decoder_masks 265 | total_loss = jnp.sum(total_loss, axis=1) 266 | if average_loss: 267 | total_loss = total_loss / jnp.sum(image_decoder_masks, axis=1) 268 | total_loss = jnp.reshape(total_loss, [batch_size, -1]) 269 | 270 | all_losses.append(total_loss) 271 | 272 | text_loss = jnp.concatenate(all_losses, -1) 273 | selected_option_ix = jnp.argmin(text_loss, -1) 274 | ix = jnp.arange(0, len(selected_option_ix)) 275 | selected_options = batch["output_options"][ix, selected_option_ix] 276 | selected_loss = text_loss[ix, selected_option_ix] 277 | out = {'scores': selected_loss, "all_scores": text_loss} 278 | if text_answer_options: 279 | out['text_tokens'] = selected_options 280 | else: 281 | out['image'] = jnp.clip((selected_options+1)/2.0, 0, 1) 282 | return out 283 | 284 | def _compute_logits_from_slice( 285 | self, flat_ids: jnp.ndarray, flat_cache: Mapping[str, jnp.ndarray], cur_index: int, 286 | live_seqs: jnp.ndarray, params: PyTreeDef, encoded_inputs: jnp.ndarray, encoder_masks: jnp.ndarray, 287 | text_length: int, image_length: int, logit_masks: jnp.ndarray = None) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: 288 | """Token slice to logits from decoder model.""" 289 | # flat_ids: [batch * beam, seq_len=1] 290 | # cache is expanded inside beam_search to become flat_cache 291 | # flat_cache: [batch * beam, num_heads, depth_per_head, max_decode_len] 292 | # flat_logits: [batch * beam, seq_len=1, vocab] 293 | 294 | def update_flat_ids(x): 295 | x = jnp.zeros_like(x) + self.module.config.vocab_size - 1 296 | return x 297 | 298 | def update_pos_ids(x): 299 | x = x + self.module.config.max_text_length - text_length 300 | return x 301 | 302 | def identity_fn(x): 303 | return x 304 | 305 | def update_ones(x): 306 | x = jnp.zeros_like(x) + 1 307 | return x 308 | 309 | def update_zeros(x): 310 | x = jnp.zeros_like(x) 311 | return x 312 | 313 | flat_ids = jax.lax.cond( 314 | jax.lax.eq(cur_index, text_length), 315 | lambda: update_flat_ids(flat_ids), 316 | lambda: identity_fn(flat_ids)) 317 | 318 | seg_ids = jax.lax.cond( 319 | jax.lax.ge(cur_index, text_length), 320 | lambda: update_ones(flat_ids), 321 | lambda: update_zeros(flat_ids)) 322 | 323 | decoder_masks = jax.lax.cond(cur_index < text_length, 324 | lambda: jnp.reshape((live_seqs == 1).sum(axis=-1) == 0, (-1,1)), 325 | lambda: jnp.ones(flat_ids.shape, dtype=jnp.bool_)) 326 | 327 | flat_logits, new_vars = self.module.apply( 328 | { 329 | 'params': params, 330 | 'cache': flat_cache 331 | }, 332 | encoded_inputs, 333 | encoder_masks, # only needed for encoder padding mask 334 | flat_ids, 335 | decoder_masks=decoder_masks, 336 | decoder_segments=seg_ids, 337 | enable_dropout=False, 338 | decode=True, 339 | image_decode_length=image_length, 340 | text_decode_length=text_length, 341 | cur_index=cur_index, 342 | mutable=['cache'], 343 | method=self.module.sample) 344 | # Remove sequence length dimension since it's always 1 during decoding. 
345 | flat_logits = jnp.squeeze(flat_logits, axis=1) 346 | new_flat_cache = new_vars['cache'] 347 | 348 | cfg = self.module.config 349 | total_vocab_size = cfg.vocab_size + cfg.image_vocab_size 350 | logit_range = jnp.reshape(jnp.arange(total_vocab_size), [1, 1, -1]) 351 | image_logits_mask = jnp.reshape(logit_range < cfg.vocab_size, [1, -1]) 352 | text_logits_mask = jnp.reshape(logit_range >= cfg.vocab_size, [1, -1]) 353 | 354 | flat_logits = jax.lax.cond( 355 | jax.lax.ge(cur_index, text_length), 356 | lambda: jnp.where(image_logits_mask, -1e10, flat_logits), 357 | lambda: jnp.where(text_logits_mask, -1e10, flat_logits)) 358 | 359 | def update_mask(flat_logits, logit_masks, cur_index): 360 | mask = jnp.reshape(logit_masks[cur_index], [1, -1]) 361 | flat_logits = jnp.where(mask, -1e10, flat_logits) 362 | return flat_logits 363 | 364 | # apply mask here. 365 | if logit_masks is not None: 366 | flat_logits = jax.lax.cond( 367 | jax.lax.lt(cur_index, logit_masks.shape[0]), 368 | lambda: update_mask(flat_logits, logit_masks, cur_index), 369 | lambda: identity_fn(flat_logits)) 370 | 371 | return flat_logits, new_flat_cache 372 | 373 | def predict_batch_with_aux( 374 | self, 375 | params: PyTreeDef, 376 | batch: Mapping[str, jnp.ndarray], 377 | decoder_params: Optional[MutableMapping[str, Any]] = None, 378 | return_all_decodes: bool = False, 379 | num_decodes: int=1, 380 | text_length=64, 381 | image_length=256, 382 | logit_mask_fn=None, 383 | beam_search=None, 384 | ) -> Mapping[str, jnp.ndarray]: 385 | """Generate outputs from the model. 386 | 387 | Args: 388 | params: model parameters. 389 | batch: a batch of inputs. 390 | decoder_params: additional (model-independent) parameters for the decoder. 391 | return_all_decodes: whether to return the entire beam or just the top-1. 392 | num_decodes: the number of beams to use in beam search. 393 | 394 | Returns: 395 | A tuple containing: 396 | the batch of predictions, with the entire beam if requested 397 | an auxiliary dictionary of decoder scores 398 | """ 399 | if "output_options" in batch: 400 | return self.predict_with_answer_options(params, batch) 401 | 402 | # [batch, input_len] 403 | text_encoder_inputs = batch['text_encoder_inputs'] 404 | image_encoder_inputs = batch['image_encoder_inputs'] 405 | image_input_masks = batch['image_input_masks'] 406 | text_encoder_masks = batch.get('text_encoder_masks') 407 | if text_encoder_masks is None: 408 | text_encoder_masks = text_encoder_inputs > 0 409 | 410 | # Prepare zeroed-out autoregressive cache. 411 | # [batch, input_len] 412 | text_type = batch['text_encoder_inputs'].dtype 413 | bs = text_encoder_inputs.shape[0] 414 | 415 | _, variables_with_cache = self.module.apply( 416 | {'params': params}, 417 | jnp.ones_like(text_encoder_inputs), 418 | jnp.ones_like(image_encoder_inputs), 419 | jnp.ones((bs, text_length), text_type), 420 | jnp.ones((bs, 256, 256, 3), image_encoder_inputs.dtype), 421 | jnp.ones((bs, text_length), text_type), 422 | decode=True, 423 | enable_dropout=False, 424 | vae_decode=False, 425 | cache_text_length=text_length, 426 | cache_image_length=image_length, 427 | mutable=['cache']) 428 | 429 | cache = variables_with_cache['cache'] 430 | 431 | # Prepare transformer fast-decoder call for beam search: for beam search, we 432 | # need to set up our decoder model to handle a batch size equal to 433 | # batch_size * num_decodes, where each batch item's data is expanded 434 | # in-place rather than tiled. 435 | # i.e. 
if we denote each batch element subtensor as el[n]: 436 | # [el0, el1, el2] --> beamsize=2 --> [el0,el0,el1,el1,el2,el2] 437 | # [batch * num_decodes, input_len, emb_dim] 438 | encoded_inputs, encoder_masks = self.module.apply({'params': params}, 439 | text_encoder_inputs, 440 | image_encoder_inputs, 441 | text_encoder_masks, 442 | image_input_masks, 443 | image_encoder_pos_ids=batch.get('image_encoder_pos_ids', None), 444 | text_encoder_pos_ids=batch.get('text_encoder_pos_ids', None), 445 | enable_dropout=False, 446 | method=self.module.encode) 447 | 448 | encoded, encoder_position_embedding = encoded_inputs 449 | encoded = decoding.flat_batch_beam_expand(encoded, num_decodes) 450 | encoder_masks = decoding.flat_batch_beam_expand(encoder_masks, num_decodes) 451 | encoded_inputs = (encoded, encoder_position_embedding) 452 | 453 | if logit_mask_fn is not None: 454 | logit_masks = logit_mask_fn() 455 | else: 456 | logit_masks = None 457 | 458 | tokens_ids_to_logits = functools.partial( 459 | self._compute_logits_from_slice, 460 | params=params, 461 | encoded_inputs=encoded_inputs, 462 | encoder_masks=encoder_masks, 463 | text_length=text_length, 464 | image_length=image_length, 465 | logit_masks=logit_masks) 466 | 467 | if decoder_params is None: 468 | decoder_params = {} 469 | 470 | # For beam search, `decoder_prompt_inputs` is only used to obtain batch size 471 | # and max decode length information. For temperature sampling, 472 | # `decoder_prompt_inputs` will be filled with the sampled ids. 473 | decoder_prompt_inputs = jnp.zeros([bs, text_length+image_length], text_type) 474 | 475 | # TODO(hwchung): rename the returned value names to more generic ones. 476 | # Using the above-defined single-step decoder function, run a 477 | # beam search over possible sequences given input encoding.
478 | # decodes: [batch, num_decodes, max_decode_len + 1] 479 | # scores: [batch, num_decodes] 480 | scanned = hasattr(self.module, 'scan_layers') and self.module.scan_layers 481 | 482 | if isinstance(beam_search, Callable): # For fine-grain hyper-parameter control 483 | decodes, scores, logprobs = beam_search( 484 | inputs=decoder_prompt_inputs, 485 | cache=cache, 486 | tokens_to_logits=tokens_ids_to_logits, 487 | num_decodes=num_decodes, 488 | cache_offset=1 if scanned else 0, 489 | ) 490 | elif beam_search: 491 | decodes, scores, logprobs = decoding.beam_search( 492 | inputs=decoder_prompt_inputs, 493 | cache=cache, 494 | alpha=0.0, 495 | tokens_to_logits=tokens_ids_to_logits, 496 | eos_id=1, 497 | num_decodes=num_decodes, 498 | cache_offset=1 if scanned else 0, 499 | **decoder_params) 500 | else: 501 | decodes, scores, logprobs = decoding.temperature_sample( 502 | inputs=decoder_prompt_inputs, 503 | cache=cache, 504 | tokens_to_logits=tokens_ids_to_logits, 505 | eos_id=1, 506 | num_decodes=num_decodes, 507 | topk = 0, 508 | topp = 0.9, 509 | cache_offset=1 if scanned else 0, 510 | **decoder_params) 511 | 512 | scores = jax.lax.stop_gradient(scores) 513 | 514 | out = {} 515 | 516 | if image_length == 256: 517 | # Get the image tokens and decode with the VAE 518 | if return_all_decodes: 519 | image_decodes = decodes[:, :, -256:].reshape(-1, 256) 520 | else: 521 | image_decodes = decodes[:, -1, -256:] 522 | decodes = decodes[:, :, :-256] 523 | 524 | image_decodes = image_decodes - self.module.config.vocab_size 525 | img = self.module.apply( 526 | {'params': params}, 527 | method=self.module.decode_code, 528 | code_b=image_decodes) 529 | 530 | if return_all_decodes: 531 | img = jnp.reshape(img, decodes.shape[:2] + img.shape[1:]) 532 | image_decodes = jnp.reshape(image_decodes, decodes.shape[:2] + image_decodes.shape[1:]) 533 | out["image"] = jnp.clip((img+1)/2.0, 0, 1) 534 | out["image_tokens"] = image_decodes 535 | 536 | if not return_all_decodes: 537 | # Beam search returns [n_batch, n_beam, n_length] with beam dimension sorted 538 | # in increasing order of log-probability. 539 | # Return the highest scoring beam sequence. 540 | decodes = decodes[:, -1] 541 | scores = scores[:, -1] 542 | 543 | out["text_tokens"] = decodes 544 | out["scores"] = scores 545 | return out 546 | -------------------------------------------------------------------------------- /uio/network.py: -------------------------------------------------------------------------------- 1 | """Defines the modules that make up the UnifiedIO model""" 2 | # Modified from code from T5X (https://github.com/google-research/t5x) 3 | 4 | import logging 5 | from dataclasses import dataclass 6 | from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union 7 | import numpy as np 8 | import math 9 | 10 | import jax 11 | from flax import linen as nn 12 | from flax import struct 13 | import jax.numpy as jnp 14 | import uio.t5x_layers as layers 15 | 16 | 17 | @dataclass 18 | class UnifiedIOConfig: 19 | vocab_size: int = 33152 20 | image_vocab_size: int = 16384 21 | image_patch_size: int = 16 22 | # Activation dtypes. 23 | dtype: Any = jnp.float32 24 | emb_dim: int = 512 25 | num_heads: int = 8 26 | num_encoder_layers: int = 6 27 | num_decoder_layers: int = 6 28 | head_dim: int = 64 29 | mlp_dim: int = 2048 30 | # Activation functions are retrieved from Flax. 31 | mlp_activations: Sequence[str] = ('gelu', 'linear') 32 | dropout_rate: float = 0.0 33 | # the embedding weights are used in the decoder output layer. 
34 | logits_via_embedding: bool = True 35 | # Whether to accumulate attention logits in float32 regardless of dtype. 36 | float32_attention_logits: bool = False 37 | encoder_max_image_length: int = 576 38 | encoder_max_text_length: int = 256 39 | decoder_max_image_length: int = 256 40 | decoder_max_text_length: int = 256 41 | visual_backbone_type: str = None 42 | visual_backbone_feature: str = None 43 | default_image_size: Sequence[int] = (384, 384) 44 | num_seg_emb: int = 2 45 | 46 | 47 | @dataclass 48 | class VAEConfig: 49 | embed_dim: int = 256 50 | n_embed: int = 1024 51 | double_z: bool = False 52 | z_channels: int = 256 53 | resolution: int = 256 54 | in_channels: int = 3 55 | out_ch: int = 3 56 | ch: int = 128 57 | ch_mult: Sequence[int] = (1,1,2,2,4) 58 | num_res_blocks: int = 2 59 | attn_resolutions: Sequence[int] = (16,) 60 | dropout: float = 0 61 | dtype: Any = jnp.float32 62 | 63 | 64 | class AttnBlock(nn.Module): 65 | n_in: int 66 | dtype: Any = jnp.float32 67 | 68 | @nn.compact 69 | def __call__(self, x, training=False): 70 | h_ = x 71 | h_ = layers.GroupNorm(name='norm')(h_) 72 | q = layers.Conv( 73 | features=self.n_in, 74 | kernel_size=(1, 1), 75 | dtype=self.dtype, 76 | kernel_axes=('axis_0', 'axis_1', 'axis_2', 'axis_3'), 77 | bias_axes=('axis_3',), 78 | name='q')(h_) 79 | 80 | k = layers.Conv( 81 | features=self.n_in, 82 | kernel_size=(1, 1), 83 | dtype=self.dtype, 84 | kernel_axes=('axis_0', 'axis_1', 'axis_2', 'axis_3'), 85 | bias_axes=('axis_3',), 86 | name='k')(h_) 87 | 88 | v = layers.Conv( 89 | features=self.n_in, 90 | kernel_size=(1, 1), 91 | dtype=self.dtype, 92 | kernel_axes=('axis_0', 'axis_1', 'axis_2', 'axis_3'), 93 | bias_axes=('axis_3',), 94 | name='v')(h_) 95 | 96 | b, h, w, c = q.shape 97 | 98 | w_ = jnp.einsum('bqc,bkc->bqk', jnp.reshape(q, (b, h*w, c)), jnp.reshape(k, (b, h*w, c))) 99 | w_ = w_ * (c ** -0.5) 100 | w_ = jax.nn.softmax(w_).astype(self.dtype) 101 | h_ = jnp.einsum('bqk,bkd->bqd', w_, jnp.reshape(v, (b, h*w, c))) 102 | h_ = jnp.reshape(h_, (b, h, w, c)) 103 | h_ = layers.Conv( 104 | features=self.n_in, 105 | kernel_size=(1, 1), 106 | dtype=self.dtype, 107 | kernel_axes=('axis_0', 'axis_1', 'axis_2', 'axis_3'), 108 | bias_axes=('axis_3',), 109 | name='proj_out')(h_) 110 | 111 | return x+h_ 112 | 113 | 114 | class Downsample(nn.Module): 115 | n_in: int 116 | dtype: Any = jnp.float32 117 | 118 | @nn.compact 119 | def __call__(self, x, training=False): 120 | return layers.Conv( 121 | features=self.n_in, 122 | kernel_size=(3, 3), 123 | strides=(2,2), 124 | dtype=self.dtype, 125 | kernel_axes=('axis_0', 'axis_1', 'axis_2', 'axis_3'), 126 | bias_axes=('axis_3',), 127 | name='conv')(x) 128 | 129 | 130 | class Upsample(nn.Module): 131 | n_in: int 132 | dtype: Any = jnp.float32 133 | 134 | @nn.compact 135 | def __call__(self, x, training=False): 136 | B, H, W, C = x.shape 137 | x = jax.image.resize(x, shape=(B, H * 2, W * 2, C), method='nearest') 138 | x = layers.Conv( 139 | features=self.n_in, 140 | kernel_size=(3, 3), 141 | strides=(1,1), 142 | dtype=self.dtype, 143 | kernel_axes=('axis_0', 'axis_1', 'axis_2', 'axis_3'), 144 | bias_axes=('axis_3',), 145 | name='conv')(x) 146 | 147 | return x 148 | 149 | 150 | class ResBlock(nn.Module): 151 | n_in: int 152 | n_out: int 153 | dtype: Any = jnp.float32 154 | 155 | @nn.compact 156 | def __call__(self, x, training=False): 157 | h = x 158 | h = layers.GroupNorm(name='norm1')(h) 159 | h = layers.nonlinearity(h) 160 | h = layers.Conv( 161 | features=self.n_out, 162 | kernel_size=(3, 3), 163 | 
dtype=self.dtype, 164 | kernel_axes=('axis_0', 'axis_1', 'axis_2', 'axis_3'), 165 | bias_axes=('axis_3',), 166 | name='conv1')(h) 167 | 168 | h = layers.GroupNorm(name='norm2')(h) 169 | h = layers.nonlinearity(h) 170 | h = layers.Conv( 171 | features=self.n_out, 172 | kernel_size=(3, 3), 173 | dtype=self.dtype, 174 | kernel_axes=('axis_0', 'axis_1', 'axis_2', 'axis_3'), 175 | bias_axes=('axis_3',), 176 | name='conv2')(h) 177 | 178 | if self.n_in != self.n_out: 179 | x = layers.Conv( 180 | features=self.n_out, 181 | kernel_size=(1,1), 182 | dtype=self.dtype, 183 | kernel_axes=('axis_0', 'axis_1', 'axis_2', 'axis_3'), 184 | bias_axes=('axis_3',), 185 | name='nin_shortcut')(x) 186 | return x + h 187 | 188 | 189 | class VAE_Encoder(nn.Module): 190 | """Jax implementation of Taming VAE encoder""" 191 | config: VAEConfig 192 | 193 | @nn.compact 194 | def __call__(self, x, training=False): 195 | cfg = self.config 196 | curr_res = cfg.resolution 197 | num_resolutions = len(cfg.ch_mult) 198 | in_ch_mult = (1,)+tuple(cfg.ch_mult) 199 | 200 | hs = layers.Conv( 201 | features=1 * cfg.ch, 202 | kernel_size=(3, 3), 203 | strides=(1, 1), 204 | dtype=cfg.dtype, 205 | kernel_axes=('axis_0', 'axis_1', 'axis_2', 'axis_3'), 206 | bias_axes=('axis_3',), 207 | name='conv_in')(x) 208 | 209 | for i_level in range(num_resolutions): 210 | block_in = cfg.ch * in_ch_mult[i_level] 211 | block_out = cfg.ch * cfg.ch_mult[i_level] 212 | for i_block in range(cfg.num_res_blocks): 213 | hs = ResBlock( 214 | block_in, 215 | block_out, 216 | cfg.dtype, 217 | name=f"down_{i_level}_block_{i_block}")(hs) 218 | block_in = block_out 219 | if curr_res in cfg.attn_resolutions: 220 | hs = AttnBlock( 221 | block_in, 222 | name=f"down_{i_level}_attn_{i_block}")(hs) 223 | 224 | if i_level != num_resolutions-1: 225 | hs = Downsample( 226 | block_in, 227 | name=f"down_{i_level}_downsample")(hs) 228 | curr_res = curr_res // 2 229 | 230 | hs = ResBlock(block_in, block_in, name='mid_block_1')(hs) 231 | hs = AttnBlock(block_in, name='mid_attn_1')(hs) 232 | hs = ResBlock(block_in, block_in, name='mid_block_2')(hs) 233 | hs = layers.GroupNorm(name='norm_out')(hs) 234 | 235 | hs = layers.nonlinearity(hs) 236 | hs = layers.Conv( 237 | features=cfg.z_channels, 238 | kernel_size=(3, 3), 239 | dtype=cfg.dtype, 240 | kernel_axes=('axis_0', 'axis_1', 'axis_2', 'axis_3'), 241 | bias_axes=('axis_3',), 242 | name='conv_out')(hs) 243 | 244 | return hs 245 | 246 | class VAE_Decoder(nn.Module): 247 | """Jax implementation of Taming VAE encoder""" 248 | config: VAEConfig 249 | 250 | @nn.compact 251 | def __call__(self, x, training=False): 252 | 253 | cfg = self.config 254 | in_ch_mult = (1,)+tuple(cfg.ch_mult) 255 | num_resolutions = len(cfg.ch_mult) 256 | curr_res = cfg.resolution // 2**(num_resolutions-1) 257 | block_in = cfg.ch*cfg.ch_mult[num_resolutions-1] 258 | 259 | # z to block_in 260 | h = layers.Conv( 261 | features=block_in, 262 | kernel_size=(3, 3), 263 | strides=(1, 1), 264 | dtype=cfg.dtype, 265 | kernel_axes=('axis_0', 'axis_1', 'axis_2', 'axis_3'), 266 | bias_axes=('axis_3',), 267 | name='conv_in')(x) 268 | 269 | h = ResBlock(block_in, block_in, name='mid_block_1')(h) 270 | h = AttnBlock(block_in, name='mid_attn_1')(h) 271 | h = ResBlock(block_in, block_in, name='mid_block_2')(h) 272 | 273 | for i_level in reversed(range(num_resolutions)): 274 | i_idx = num_resolutions - i_level-1 275 | block_out = cfg.ch * cfg.ch_mult[i_level] 276 | for i_block in range(cfg.num_res_blocks+1): 277 | h = ResBlock(block_in, block_out, 
name=f"up_{i_idx}_block_{i_block}")(h) 278 | block_in = block_out 279 | if curr_res in cfg.attn_resolutions: 280 | h = AttnBlock(block_in, name=f"up_{i_idx}_attn_{i_block}")(h) 281 | if i_level != 0: 282 | h = Upsample(block_in, name=f"up_{i_idx}_upsample")(h) 283 | curr_res = curr_res * 2 284 | 285 | h = layers.GroupNorm(name='norm_out')(h) 286 | h = layers.nonlinearity(h) 287 | h = layers.Conv( 288 | features=cfg.out_ch, 289 | kernel_size=(3, 3), 290 | strides=(1, 1), 291 | dtype=cfg.dtype, 292 | kernel_axes=('axis_0', 'axis_1', 'axis_2', 'axis_3'), 293 | bias_axes=('axis_3',), 294 | name='conv_out')(h) 295 | 296 | return h 297 | 298 | class DiscreteVAE(nn.Module): 299 | """Jax implementation of Taming VAE""" 300 | config: VAEConfig 301 | 302 | def setup(self): 303 | cfg = self.config 304 | self.encoder = VAE_Encoder(cfg) 305 | self.quant_conv = layers.Conv( 306 | features=cfg.z_channels, 307 | kernel_size=(1, 1), 308 | dtype=cfg.dtype, 309 | kernel_axes=('axis_0', 'axis_1', 'axis_2', 'axis_3'), 310 | bias_axes=('axis_3',), 311 | name='quant_conv') 312 | 313 | self.quantize = layers.VectorQuantizer( 314 | cfg.n_embed, 315 | cfg.embed_dim, 316 | beta=0.25) 317 | 318 | self.post_quant_conv = layers.Conv( 319 | features=cfg.z_channels, 320 | kernel_size=(1, 1), 321 | dtype=cfg.dtype, 322 | kernel_axes=('axis_0', 'axis_1', 'axis_2', 'axis_3'), 323 | bias_axes=('axis_3',), 324 | name='post_quant_conv') 325 | 326 | self.decoder = VAE_Decoder(cfg) 327 | 328 | def encode(self, x, training=False): 329 | h = self.encoder(x, training) 330 | h = self.quant_conv(h) 331 | quant, emb_loss, info = self.quantize(h) 332 | return quant, emb_loss, info 333 | 334 | def decode(self, quant, training=False): 335 | quant = self.post_quant_conv(quant) 336 | dec = self.decoder(quant, training) 337 | return dec 338 | 339 | def decode_code(self, code_b): 340 | quant_b = self.quantize.get_codebook_entry(code_b) 341 | bs, seq_len, dim = quant_b.shape 342 | size = int(math.sqrt(seq_len)) 343 | quant_b = jnp.reshape(quant_b, (bs, size, size, dim)) 344 | dec = self.decode(quant_b) 345 | return dec 346 | 347 | def get_codebook_indices(self, x, vae_decode=False, training=False): 348 | h = self.encoder(x, training) 349 | h = self.quant_conv(h) 350 | z, _, [_, _, indices] = self.quantize(h) 351 | 352 | if vae_decode: 353 | _ = self.decode(z, training) 354 | 355 | return jnp.reshape(indices, (jnp.shape(h)[0], -1)) 356 | 357 | @nn.compact 358 | def __call__(self, x, training=False): 359 | quant, diff, _ = self.encode(x, training) 360 | dec = self.decode(quant, training) 361 | return dec 362 | 363 | class EncoderLayer(nn.Module): 364 | """Transformer encoder layer.""" 365 | config: UnifiedIOConfig 366 | relative_embedding: nn.Module 367 | 368 | @nn.compact 369 | def __call__(self, inputs, txt_position_ids, img_position_ids, abs_pos_bias, encoder_mask=None, deterministic=False): 370 | cfg = self.config 371 | 372 | # Relative position embedding as attention biases. 373 | encoder_bias = self.relative_embedding(txt_position_ids, img_position_ids, 374 | True) 375 | # Attention block. 
376 | assert inputs.ndim == 3 377 | x = layers.LayerNorm( 378 | dtype=cfg.dtype, name='pre_attention_layer_norm')( 379 | inputs) 380 | # [batch, length, emb_dim] -> [batch, length, emb_dim] 381 | x = layers.MultiHeadDotProductAttention( 382 | num_heads=cfg.num_heads, 383 | dtype=cfg.dtype, 384 | head_dim=cfg.head_dim, 385 | dropout_rate=cfg.dropout_rate, 386 | float32_logits=cfg.float32_attention_logits, 387 | name='attention')( 388 | x, x, encoder_mask, encoder_bias, abs_pos_bias, deterministic=deterministic) 389 | 390 | x = nn.Dropout( 391 | rate=cfg.dropout_rate, broadcast_dims=(-2,))( 392 | x, deterministic=deterministic) 393 | 394 | x = x + inputs 395 | 396 | # MLP block. 397 | y = layers.LayerNorm(dtype=cfg.dtype, name='pre_mlp_layer_norm')(x) 398 | # [batch, length, emb_dim] -> [batch, length, emb_dim] 399 | y = layers.MlpBlock( 400 | intermediate_dim=cfg.mlp_dim, 401 | activations=cfg.mlp_activations, 402 | intermediate_dropout_rate=cfg.dropout_rate, 403 | dtype=cfg.dtype, 404 | name='mlp', 405 | )(y, deterministic=deterministic) 406 | 407 | y = nn.Dropout( 408 | rate=cfg.dropout_rate, broadcast_dims=(-2,))( 409 | y, deterministic=deterministic) 410 | y = y + x 411 | return y 412 | 413 | class DecoderLayer(nn.Module): 414 | """Transformer decoder layer that attends to the encoder.""" 415 | config: UnifiedIOConfig 416 | relative_embedding: nn.Module 417 | 418 | @nn.compact 419 | def __call__(self, 420 | inputs, 421 | encoded, 422 | self_abs_pos_bias, 423 | cross_abs_pos_bias, 424 | decoder_mask=None, 425 | encoder_decoder_mask=None, 426 | deterministic=False, 427 | decode=False, 428 | image_decoder_positions=None, 429 | text_decoder_positions=None): 430 | 431 | cfg = self.config 432 | 433 | # Relative position embedding as attention biases. 434 | # l = max_decode_length if decode and max_decode_length else inputs.shape[-2] 435 | decoder_bias = self.relative_embedding(text_decoder_positions, image_decoder_positions, False) 436 | # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] 437 | x = layers.LayerNorm( 438 | dtype=cfg.dtype, name='pre_self_attention_layer_norm')( 439 | inputs) 440 | # Self-attention block 441 | x = layers.MultiHeadDotProductAttention( 442 | num_heads=cfg.num_heads, 443 | dtype=cfg.dtype, 444 | head_dim=cfg.head_dim, 445 | dropout_rate=cfg.dropout_rate, 446 | float32_logits=cfg.float32_attention_logits, 447 | name='self_attention')( 448 | x, 449 | x, 450 | decoder_mask, 451 | decoder_bias, 452 | self_abs_pos_bias, 453 | deterministic=deterministic, 454 | decode=decode) 455 | 456 | x = nn.Dropout( 457 | rate=cfg.dropout_rate, broadcast_dims=(-2,))( 458 | x, deterministic=deterministic) 459 | 460 | x = x + inputs 461 | # Encoder-Decoder block. 462 | y = layers.LayerNorm( 463 | dtype=cfg.dtype, name='pre_cross_attention_layer_norm')( 464 | x) 465 | y = layers.MultiHeadDotProductAttention( 466 | num_heads=cfg.num_heads, 467 | dtype=cfg.dtype, 468 | head_dim=cfg.head_dim, 469 | dropout_rate=cfg.dropout_rate, 470 | float32_logits=cfg.float32_attention_logits, 471 | name='encoder_decoder_attention')( 472 | y, 473 | encoded, 474 | encoder_decoder_mask, 475 | None, 476 | cross_abs_pos_bias, 477 | deterministic=deterministic) 478 | 479 | 480 | y = nn.Dropout( 481 | rate=cfg.dropout_rate, broadcast_dims=(-2,))( 482 | y, deterministic=deterministic) 483 | 484 | y = y + x 485 | 486 | # MLP block. 
487 | z = layers.LayerNorm(dtype=cfg.dtype, name='pre_mlp_layer_norm')(y) 488 | z = layers.MlpBlock( 489 | intermediate_dim=cfg.mlp_dim, 490 | activations=cfg.mlp_activations, 491 | intermediate_dropout_rate=cfg.dropout_rate, 492 | dtype=cfg.dtype, 493 | name='mlp', 494 | )(z, deterministic=deterministic) 495 | z = nn.Dropout( 496 | rate=cfg.dropout_rate, broadcast_dims=(-2,))( 497 | z, deterministic=deterministic) 498 | z = z + y 499 | 500 | return z 501 | 502 | 503 | class Encoder(nn.Module): 504 | """A stack of encoder layers.""" 505 | config: UnifiedIOConfig 506 | shared_embedding: nn.Module 507 | 508 | def setup(self): 509 | cfg = self.config 510 | self.segment_embedding = layers.Embed( 511 | num_embeddings=cfg.num_seg_emb, 512 | features=cfg.emb_dim, 513 | dtype=cfg.dtype, 514 | attend_dtype=jnp.float32, # for logit training stability 515 | embedding_init=nn.initializers.normal(stddev=1.0), 516 | one_hot=True, 517 | name='segment_embedding') 518 | 519 | self.positon_embedding = layers.Embed( 520 | num_embeddings=cfg.encoder_max_text_length+cfg.encoder_max_image_length, 521 | features=cfg.emb_dim, 522 | dtype=cfg.dtype, 523 | attend_dtype=jnp.float32, # for logit training stability 524 | embedding_init=nn.initializers.normal(stddev=1.0), 525 | one_hot=True, 526 | name='position_embedding') 527 | 528 | @nn.compact 529 | def __call__(self, 530 | text_encoder_inputs, 531 | image_encoder_inputs, 532 | txt_position_ids, 533 | img_position_ids, 534 | encoder_masks=None, 535 | deterministic=False): 536 | cfg = self.config 537 | assert text_encoder_inputs.ndim == 2 # [batch, length] 538 | if image_encoder_inputs.ndim == 3: 539 | # use default length 540 | bs = image_encoder_inputs.shape[0] 541 | h, w = cfg.default_image_size 542 | else: 543 | bs, h, w, _ = image_encoder_inputs.shape 544 | 545 | txt_length = text_encoder_inputs.shape[1] 546 | 547 | rel_emb = layers.RelativePositionBiases( 548 | num_buckets=32, 549 | img_num_buckets=8, 550 | max_distance=128, 551 | img_max_distance=20, 552 | num_heads=cfg.num_heads, 553 | img_width=w//cfg.image_patch_size, 554 | img_height=h//cfg.image_patch_size, 555 | dtype=cfg.dtype, 556 | embedding_init=nn.initializers.variance_scaling(1.0, 'fan_avg', 557 | 'uniform'), 558 | name='relpos_bias') 559 | 560 | # do the image encoding. 561 | if image_encoder_inputs.ndim == 4: 562 | img_emb = layers.space_to_depth(image_encoder_inputs, spatial_block_size=cfg.image_patch_size) 563 | else: 564 | img_emb = image_encoder_inputs 565 | 566 | txt_pos_emb = self.positon_embedding(txt_position_ids) 567 | img_pos_emb = self.positon_embedding(img_position_ids+cfg.encoder_max_text_length) 568 | 569 | if (image_encoder_inputs.ndim == 4 and 570 | img_emb.shape[1] != cfg.encoder_max_image_length and 571 | img_emb.shape[1] != 1): 572 | # Our input is a full-sized image that has more or less patches than our default 573 | # `img_emb.shape[1] != 1` catches the case of being give 574 | pos_size = int(cfg.encoder_max_image_length ** 0.5) 575 | target_size = int(img_emb.shape[1] ** 0.5) 576 | img_pos_emb = jnp.reshape(img_pos_emb, [1, pos_size, pos_size, cfg.emb_dim]) 577 | img_pos_emb = jax.image.resize(img_pos_emb, [1, target_size, target_size, cfg.emb_dim], "bicubic") 578 | img_pos_emb = jnp.reshape(img_pos_emb, [1, -1, cfg.emb_dim]) 579 | # update image position ids for relative position encoding. 
580 | img_position_ids = jnp.arange(img_emb.shape[1], dtype=jnp.int32) 581 | img_position_ids = jnp.expand_dims(img_position_ids, axis=0) 582 | 583 | img_emb = layers.DenseGeneral( 584 | cfg.emb_dim, 585 | dtype=cfg.dtype, 586 | kernel_axes=('image_patch', 'embed'), 587 | name='image_projection', 588 | )(img_emb) 589 | 590 | # do the text encoding 591 | # [batch, length] -> [batch, length, emb_dim] 592 | txt_emb = self.shared_embedding(text_encoder_inputs.astype('int32')) 593 | 594 | txt_segments = jnp.zeros(txt_emb.shape[1], dtype=jnp.int32)[None,...] 595 | img_segments = jnp.ones(img_emb.shape[1], dtype=jnp.int32)[None,...] 596 | 597 | txt_emb += self.segment_embedding(txt_segments) 598 | img_emb += self.segment_embedding(img_segments) 599 | 600 | txt_emb += txt_pos_emb 601 | img_emb += img_pos_emb 602 | 603 | txt_emb = layers.LayerNorm( 604 | dtype=cfg.dtype, name='txt_emb_pre_ln')(txt_emb) 605 | 606 | img_emb = layers.LayerNorm( 607 | dtype=cfg.dtype, name='img_emb_pre_ln')(img_emb) 608 | 609 | position_embedding =jnp.concatenate([txt_pos_emb, img_pos_emb], axis=1) 610 | 611 | position_embedding = layers.LayerNorm( 612 | dtype=cfg.dtype, name='pe_pre_ln')(position_embedding) 613 | 614 | # get absolute position bias. 615 | pos_q = layers.DenseGeneral( 616 | features=(cfg.num_heads, cfg.head_dim), 617 | dtype=cfg.dtype, 618 | kernel_axes=('embed', 'joined_kv'), 619 | name='position_q_linear', 620 | )(position_embedding) 621 | 622 | pos_k = layers.DenseGeneral( 623 | features=(cfg.num_heads, cfg.head_dim), 624 | dtype=cfg.dtype, 625 | kernel_axes=('embed', 'joined_kv'), 626 | name='position_k_linear', 627 | )(position_embedding) 628 | 629 | pos_scaling = float(cfg.emb_dim / cfg.num_heads) ** -0.5 630 | abs_pos_bias = jnp.einsum('bqhd,bkhd->bhqk', pos_q, pos_k) * pos_scaling 631 | 632 | x = jnp.concatenate([txt_emb, img_emb], axis=1) 633 | x = nn.Dropout( 634 | rate=cfg.dropout_rate, broadcast_dims=(-2,))( 635 | x, deterministic=deterministic) 636 | x = x.astype(cfg.dtype) 637 | 638 | for lyr in range(cfg.num_encoder_layers): 639 | # [batch, length, emb_dim] -> [batch, length, emb_dim] 640 | x = EncoderLayer( 641 | config=cfg, relative_embedding=rel_emb, 642 | name=f'layers_{lyr}')(x, txt_position_ids, img_position_ids, abs_pos_bias, encoder_masks, deterministic) 643 | 644 | x = layers.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x) 645 | return nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic), position_embedding 646 | 647 | 648 | class Decoder(nn.Module): 649 | """A stack of decoder layers as a part of an encoder-decoder architecture.""" 650 | config: UnifiedIOConfig 651 | shared_embedding: nn.Module 652 | 653 | @nn.compact 654 | def __call__(self, 655 | encoded, 656 | decoder_inputs, 657 | decoder_positions=None, 658 | decoder_segments=None, 659 | decoder_attn_mask=None, 660 | encoder_decoder_mask=None, 661 | deterministic=False, 662 | decode=False, 663 | image_decoder_positions=None, 664 | text_decoder_positions=None, 665 | cur_index=None): 666 | 667 | cfg = self.config 668 | assert decoder_inputs.ndim == 2 # [batch, len] 669 | encoded, encoder_position_embedding = encoded 670 | 671 | rel_emb = layers.RelativePositionBiases( 672 | num_buckets=32, 673 | img_num_buckets=8, 674 | max_distance=128, 675 | img_max_distance=20, 676 | num_heads=cfg.num_heads, 677 | img_width=16, 678 | img_height=16, 679 | dtype=cfg.dtype, 680 | embedding_init=nn.initializers.variance_scaling(1.0, 'fan_avg', 681 | 'uniform'), 682 | name='relpos_bias') 683 | 684 | # [batch, length] -> [batch, 
length, emb_dim] 685 | y = self.shared_embedding(decoder_inputs.astype('int32')) 686 | 687 | position_embedding = layers.Embed( 688 | num_embeddings=cfg.decoder_max_text_length + cfg.decoder_max_image_length, 689 | features=cfg.emb_dim, 690 | dtype=cfg.dtype, 691 | attend_dtype=jnp.float32, # for logit training stability 692 | embedding_init=nn.initializers.normal(stddev=1.0), 693 | one_hot=True, 694 | name='position_embedding')(decoder_positions) 695 | 696 | if cur_index is None: 697 | y += position_embedding 698 | else: 699 | y += position_embedding[:,cur_index][:,None,:] 700 | 701 | y += layers.Embed( 702 | num_embeddings=cfg.num_seg_emb, 703 | features=cfg.emb_dim, 704 | dtype=cfg.dtype, 705 | attend_dtype=jnp.float32, # for logit training stability 706 | embedding_init=nn.initializers.normal(stddev=1.0), 707 | one_hot=True, 708 | name='segments_embedding')(decoder_segments) 709 | 710 | y = layers.LayerNorm(dtype=cfg.dtype, name='pre_ln')(y) 711 | 712 | position_embedding = layers.LayerNorm( 713 | dtype=cfg.dtype, name='pe_pre_ln')(position_embedding) 714 | 715 | # get absolute position bias. 716 | self_pos_q = layers.DenseGeneral( 717 | features=(cfg.num_heads, cfg.head_dim), 718 | dtype=cfg.dtype, 719 | kernel_axes=('embed', 'joined_kv'), 720 | name='self_position_q_linear', 721 | )(position_embedding) 722 | 723 | self_pos_k = layers.DenseGeneral( 724 | features=(cfg.num_heads, cfg.head_dim), 725 | dtype=cfg.dtype, 726 | kernel_axes=('embed', 'joined_kv'), 727 | name='self_position_k_linear', 728 | )(position_embedding) 729 | 730 | pos_scaling = float(cfg.emb_dim / cfg.num_heads) ** -0.5 731 | self_abs_pos_bias = jnp.einsum('bqhd,bkhd->bhqk', self_pos_q, self_pos_k) * pos_scaling 732 | 733 | # get absolute position bias. 734 | cross_pos_q = layers.DenseGeneral( 735 | features=(cfg.num_heads, cfg.head_dim), 736 | dtype=cfg.dtype, 737 | kernel_axes=('embed', 'joined_kv'), 738 | name='cross_position_q_linear', 739 | )(position_embedding) 740 | 741 | cross_pos_k = layers.DenseGeneral( 742 | features=(cfg.num_heads, cfg.head_dim), 743 | dtype=cfg.dtype, 744 | kernel_axes=('embed', 'joined_kv'), 745 | name='cross_position_k_linear', 746 | )(encoder_position_embedding) 747 | 748 | cross_abs_pos_bias = jnp.einsum('bqhd,bkhd->bhqk', cross_pos_q, cross_pos_k) * pos_scaling 749 | 750 | y = nn.Dropout( 751 | rate=cfg.dropout_rate, broadcast_dims=(-2,))( 752 | y, deterministic=deterministic) 753 | y = y.astype(cfg.dtype) 754 | 755 | for lyr in range(cfg.num_decoder_layers): 756 | # [batch, length, emb_dim] -> [batch, length, emb_dim] 757 | y = DecoderLayer( 758 | config=cfg, 759 | relative_embedding=rel_emb, 760 | name=f'layers_{lyr}')( 761 | y, 762 | encoded, 763 | self_abs_pos_bias, 764 | cross_abs_pos_bias, 765 | decoder_mask=decoder_attn_mask, 766 | encoder_decoder_mask=encoder_decoder_mask, 767 | deterministic=deterministic, 768 | decode=decode, 769 | image_decoder_positions=image_decoder_positions, 770 | text_decoder_positions=text_decoder_positions) 771 | 772 | y = layers.LayerNorm(dtype=cfg.dtype, name='decoder_norm')(y) 773 | y = nn.Dropout( 774 | rate=cfg.dropout_rate, broadcast_dims=(-2,))( 775 | y, deterministic=deterministic) 776 | 777 | # [batch, length, emb_dim] -> [batch, length, vocab_size] 778 | if cfg.logits_via_embedding: 779 | # Use the transpose of embedding matrix for logit transform. 780 | logits = self.shared_embedding.attend(y) 781 | # Correctly normalize pre-softmax logits for this shared case. 
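      # Dividing by sqrt(emb_dim) follows the T5/T5X convention for tied
      # input/output embeddings, presumably keeping the logit scale comparable
      # to an untied output projection.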
782 | logits = logits / jnp.sqrt(y.shape[-1]) 783 | else: 784 | logits = layers.DenseGeneral( 785 | cfg.vocab_size + cfg.image_vocab_size, 786 | dtype=jnp.float32, # Use float32 for stabiliity. 787 | kernel_axes=('embed', 'vocab'), 788 | name='logits_dense')(y) 789 | 790 | return logits 791 | 792 | 793 | class Transformer(nn.Module): 794 | """The ynderlying UnifiedIO network""" 795 | 796 | config: UnifiedIOConfig 797 | vae_config: VAEConfig 798 | 799 | def setup(self): 800 | cfg = self.config 801 | vae_config = self.vae_config 802 | 803 | self.shared_embedding = layers.Embed( 804 | num_embeddings=cfg.vocab_size + cfg.image_vocab_size, 805 | features=cfg.emb_dim, 806 | dtype=cfg.dtype, 807 | attend_dtype=jnp.float32, # for logit training stability 808 | embedding_init=nn.initializers.normal(stddev=1.0), 809 | one_hot=True, 810 | name='token_embedder') 811 | 812 | self.discrete_vae = DiscreteVAE(config=vae_config) 813 | self.encoder = Encoder( 814 | config=cfg, 815 | shared_embedding=self.shared_embedding, 816 | ) 817 | self.decoder = Decoder( 818 | config=cfg, 819 | shared_embedding=self.shared_embedding) 820 | 821 | total_vocab_size = cfg.vocab_size + cfg.image_vocab_size 822 | self.logit_range = jnp.reshape(jnp.arange(total_vocab_size), [1, 1, -1]) 823 | self.image_logits_mask = jnp.reshape(self.logit_range < cfg.vocab_size, [1, -1]) 824 | self.text_logits_mask = jnp.reshape(self.logit_range >= cfg.vocab_size, [1, -1]) 825 | 826 | def encode(self, 827 | text_encoder_inputs, 828 | image_encoder_inputs, 829 | text_encoder_masks, 830 | image_encoder_masks, 831 | image_encoder_pos_ids, 832 | text_encoder_pos_ids, 833 | enable_dropout=True): 834 | """Applies Transformer encoder-branch on the inputs.""" 835 | cfg = self.config 836 | assert text_encoder_inputs.ndim == 2 # (batch, len) 837 | bs = text_encoder_inputs.shape[0] 838 | 839 | if text_encoder_masks is None: 840 | text_encoder_masks = text_encoder_inputs > 0 841 | 842 | if image_encoder_inputs.ndim == 3: 843 | image_length = image_encoder_inputs.shape[1] 844 | else: 845 | image_length = int(np.prod(image_encoder_inputs.shape[1:3]) / (cfg.image_patch_size**2)) 846 | 847 | if image_encoder_masks is None: 848 | image_encoder_masks = jnp.ones([bs, image_length], dtype=jnp.bool_) 849 | 850 | if image_encoder_pos_ids is None: 851 | image_encoder_pos_ids = jnp.arange(image_length, dtype=jnp.int32) 852 | image_encoder_pos_ids = jnp.expand_dims(image_encoder_pos_ids, axis=0) 853 | image_encoder_pos_ids = jnp.tile(image_encoder_pos_ids, [bs, 1]) 854 | 855 | if text_encoder_pos_ids is None: 856 | text_encoder_pos_ids = jnp.arange(text_encoder_inputs.shape[1], dtype=jnp.int32) 857 | text_encoder_pos_ids = jnp.expand_dims(text_encoder_pos_ids, axis=0) 858 | text_encoder_pos_ids = jnp.tile(text_encoder_pos_ids, [bs, 1]) 859 | 860 | encoder_masks = jnp.concatenate([text_encoder_masks, image_encoder_masks], axis=1) 861 | encoder_attn_masks = layers.make_attention_mask( 862 | encoder_masks, encoder_masks, dtype=cfg.dtype) 863 | 864 | return self.encoder( 865 | text_encoder_inputs, 866 | image_encoder_inputs, 867 | text_encoder_pos_ids, 868 | image_encoder_pos_ids, 869 | encoder_attn_masks, 870 | deterministic=not enable_dropout 871 | ), encoder_masks 872 | 873 | def decode( 874 | self, 875 | encoded, 876 | encoder_masks, 877 | text_decoder_inputs, 878 | image_decoder_inputs, 879 | text_decoder_targets, 880 | image_decoder_targets, 881 | text_decoder_masks=None, 882 | image_decoder_masks=None, 883 | text_decoder_segment_ids=None, 884 | 
text_decoder_positions=None, 885 | enable_dropout=True, 886 | decode=False, 887 | max_decode_length=None): 888 | """Applies Transformer decoder-branch on encoded-input and target.""" 889 | cfg = self.config 890 | 891 | if text_decoder_masks is None: 892 | text_decoder_masks = text_decoder_targets > 0 893 | 894 | if image_decoder_masks is None: 895 | image_decoder_masks = jnp.ones(image_decoder_inputs.shape, dtype=jnp.bool_) 896 | 897 | if text_decoder_segment_ids is not None: 898 | decoder_segment_ids = jnp.concatenate([text_decoder_segment_ids, jnp.ones(image_decoder_masks.shape)], axis=1) 899 | else: 900 | decoder_segment_ids = None 901 | 902 | decoder_masks = jnp.concatenate([text_decoder_masks, image_decoder_masks], axis=1) 903 | decoder_attn_mask = layers.make_decoder_mask( 904 | decoder_target_tokens=decoder_masks, 905 | dtype=cfg.dtype, 906 | decoder_segment_ids=decoder_segment_ids) 907 | 908 | encoder_decoder_mask = layers.make_attention_mask( 909 | decoder_masks, encoder_masks, dtype=cfg.dtype) 910 | 911 | decoder_inputs = jnp.concatenate([text_decoder_inputs, image_decoder_inputs], axis=1) 912 | 913 | if text_decoder_positions is None: 914 | text_decoder_positions = jnp.arange(text_decoder_inputs.shape[1], dtype=jnp.int32)[None,...] 915 | image_decoder_positions = jnp.arange(image_decoder_inputs.shape[1], dtype=jnp.int32)[None,...] 916 | else: 917 | image_decoder_positions = jnp.arange(image_decoder_inputs.shape[1], dtype=jnp.int32)[None,...] 918 | image_decoder_positions = jnp.tile(image_decoder_positions, [image_decoder_inputs.shape[0], 1]) 919 | 920 | decoder_positions = jnp.concatenate([ 921 | text_decoder_positions, 922 | cfg.decoder_max_text_length+image_decoder_positions], 923 | axis=1) 924 | 925 | decoder_segments = jnp.expand_dims( 926 | jnp.concatenate([ 927 | jnp.zeros(text_decoder_inputs.shape[1], dtype=jnp.int32), 928 | jnp.ones(image_decoder_inputs.shape[1], dtype=jnp.int32)], 929 | axis=0), 930 | axis=0) 931 | 932 | logging.info(f"Decode called with EncodeLen={encoded[0].shape[1]}, DecodeInputLen={decoder_inputs.shape[1]}") 933 | logits = self.decoder( 934 | encoded, 935 | decoder_positions=decoder_positions, 936 | decoder_segments=decoder_segments, 937 | decoder_inputs=decoder_inputs, 938 | decoder_attn_mask=decoder_attn_mask, 939 | encoder_decoder_mask=encoder_decoder_mask, 940 | deterministic=not enable_dropout, 941 | decode=decode, 942 | image_decoder_positions=image_decoder_positions, 943 | text_decoder_positions=text_decoder_positions) 944 | 945 | # mask the logits. 
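    # The decoder produces one concatenated sequence of [text positions, image positions];
    # text positions (< text_length) may only emit text-vocabulary tokens and image
    # positions may only emit image-codebook tokens (>= vocab_size), so disallowed
    # logits are pushed to -1e10 below.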
946 | text_length = text_decoder_inputs.shape[1] 947 | seq_range = jnp.reshape(jnp.arange(logits.shape[1]), [1, -1, 1]) 948 | logits_mask = (((seq_range >= text_length) & (self.logit_range < cfg.vocab_size)) | 949 | (seq_range < text_length) & (self.logit_range >= cfg.vocab_size)) 950 | logits = jnp.where(logits_mask, -1e10, logits) 951 | text_logits = logits[:,:text_length] 952 | image_logits = logits[:,text_length:] 953 | 954 | return text_logits, image_logits, image_decoder_targets 955 | 956 | def decode_code(self, code_b): 957 | return self.discrete_vae.decode_code(code_b) 958 | 959 | def encode_target_image(self, image): 960 | return self.discrete_vae.get_codebook_indices(image) 961 | 962 | def sample( 963 | self, 964 | encoded, 965 | encoder_masks, 966 | decoder_inputs, 967 | decoder_masks=None, 968 | decoder_segments=None, 969 | enable_dropout=True, 970 | decode=False, 971 | cur_index=None, 972 | image_decode_length=None, 973 | text_decode_length=None): 974 | 975 | cfg = self.config 976 | encoder_decoder_mask = layers.make_attention_mask( 977 | jnp.ones_like(decoder_inputs), 978 | encoder_masks, 979 | dtype=cfg.dtype) 980 | 981 | if decoder_masks is not None: 982 | decoder_attn_mask = layers.make_decoder_mask( 983 | decoder_target_tokens=decoder_masks, 984 | dtype=cfg.dtype) 985 | else: 986 | decoder_attn_mask = None 987 | 988 | image_decoder_positions = jnp.arange(image_decode_length)[None,...] 989 | text_decoder_positions = jnp.arange(text_decode_length)[None,...] 990 | 991 | decoder_positions = jnp.concatenate([ 992 | text_decoder_positions, 993 | cfg.decoder_max_text_length+image_decoder_positions], 994 | axis=1) 995 | 996 | logits = self.decoder( 997 | encoded, 998 | decoder_inputs=decoder_inputs, 999 | decoder_positions=decoder_positions, 1000 | decoder_segments=decoder_segments, 1001 | decoder_attn_mask=decoder_attn_mask, 1002 | encoder_decoder_mask=encoder_decoder_mask, 1003 | deterministic=not enable_dropout, 1004 | decode=decode, 1005 | image_decoder_positions=image_decoder_positions, 1006 | text_decoder_positions=text_decoder_positions, 1007 | cur_index=cur_index) 1008 | 1009 | return logits 1010 | 1011 | def __call__(self, 1012 | text_encoder_inputs, 1013 | image_encoder_inputs, 1014 | text_decoder_inputs, 1015 | image_decoder_targets, 1016 | text_decoder_targets, 1017 | text_encoder_masks=None, 1018 | image_encoder_masks=None, 1019 | text_decoder_masks=None, 1020 | image_decoder_masks=None, 1021 | image_encoder_pos_ids=None, 1022 | text_encoder_pos_ids=None, 1023 | text_decoder_segment_ids=None, 1024 | text_decoder_positions=None, 1025 | *, 1026 | enable_dropout: bool = True, 1027 | decode: bool = False, 1028 | cache_text_length = None, 1029 | cache_image_length = None, 1030 | vae_decode: bool = False, 1031 | return_targets = False 1032 | ): 1033 | """Applies Transformer model on the inputs. 1034 | 1035 | This method requires both decoder_target_tokens and decoder_input_tokens, 1036 | which is a shifted version of the former. For a packed dataset, it usually 1037 | has additional processing applied. For example, the first element of each 1038 | sequence has id 0 instead of the shifted EOS id from the previous sequence. 1039 | 1040 | Args: 1041 | encoder_input_tokens: input data to the encoder. 1042 | decoder_input_tokens: input token to the decoder. 1043 | decoder_target_tokens: target token to the decoder. 1044 | encoder_segment_ids: encoder segmentation info for packed examples. 1045 | decoder_segment_ids: decoder segmentation info for packed examples. 
1046 | encoder_positions: encoder subsequence positions for packed examples. 1047 | decoder_positions: decoder subsequence positions for packed examples. 1048 | enable_dropout: Ensables dropout if set to True. 1049 | decode: Whether to prepare and use an autoregressive cache. 1050 | 1051 | Returns: 1052 | logits array from full transformer. 1053 | """ 1054 | cfg = self.config 1055 | 1056 | if image_decoder_targets.shape[1] > 1: 1057 | image_decoder_tokens = self.discrete_vae.get_codebook_indices(image_decoder_targets, vae_decode) # 0 is the start token. 1058 | # stop gradient. 1059 | image_decoder_tokens = image_decoder_tokens + cfg.vocab_size 1060 | image_decoder_tokens = jax.lax.stop_gradient(image_decoder_tokens) 1061 | else: 1062 | # Dummy input image, use a single token as output that will be masked out 1063 | bs = image_decoder_targets.shape[0] 1064 | image_decoder_tokens = jnp.zeros((bs, 1), dtype=jnp.int32) 1065 | # Client should ensure this are also size 1 and mask out the image token 1066 | assert image_decoder_targets.shape[1] == 1 1067 | if image_decoder_masks is not None: 1068 | assert image_decoder_masks.shape[1] == 1 1069 | 1070 | image_decoder_inputs = jnp.concatenate([ 1071 | jnp.zeros((image_decoder_tokens.shape[0], 1), dtype=jnp.int32) + cfg.vocab_size - 1, 1072 | image_decoder_tokens[:,:-1]], axis=1) 1073 | 1074 | encoded, encoder_masks = self.encode( 1075 | text_encoder_inputs, 1076 | image_encoder_inputs, 1077 | text_encoder_masks, 1078 | image_encoder_masks, 1079 | image_encoder_pos_ids, 1080 | text_encoder_pos_ids, 1081 | enable_dropout=enable_dropout) 1082 | 1083 | if cache_image_length is not None: 1084 | image_decoder_inputs = image_decoder_inputs[:,:cache_image_length] 1085 | image_decoder_tokens = image_decoder_tokens[:,:cache_image_length] 1086 | if image_decoder_masks is not None: 1087 | image_decoder_masks = image_decoder_masks[:,:cache_image_length] 1088 | 1089 | if cache_text_length is not None: 1090 | text_decoder_inputs = text_decoder_inputs[:,:cache_text_length] 1091 | text_decoder_targets = text_decoder_targets[:,:cache_text_length] 1092 | if text_decoder_masks is not None: 1093 | text_decoder_masks = text_decoder_masks[:,:cache_text_length] 1094 | 1095 | logits = self.decode( 1096 | encoded, 1097 | encoder_masks, 1098 | text_decoder_inputs, 1099 | image_decoder_inputs, 1100 | text_decoder_targets, 1101 | image_decoder_tokens, 1102 | text_decoder_masks=text_decoder_masks, 1103 | image_decoder_masks=image_decoder_masks, 1104 | text_decoder_segment_ids=text_decoder_segment_ids, 1105 | text_decoder_positions=text_decoder_positions, 1106 | enable_dropout=enable_dropout, 1107 | decode=decode) 1108 | 1109 | if return_targets: 1110 | return logits 1111 | else: 1112 | return logits 1113 | 1114 | 1115 | -------------------------------------------------------------------------------- /uio/runner.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | from absl import logging 7 | from transformers import T5Tokenizer 8 | 9 | from uio.configs import CONFIGS, VAE_CONFIG 10 | from uio import network 11 | from uio import utils 12 | from uio.model import UnifiedIOModel 13 | 14 | CAPTIONING_PROMPT = 'What does the image describe ?' 15 | DEPTH_PROMPT = "What is the depth map of the image ?" 16 | SURFACE_NORMAL_PROMPT = 'What is the surface normal of the image ?' 17 | OBJECT_SEGMENTATION = 'What is the segmentation of " {} " ?' 
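# Editorial note: the " {} " placeholder in each prompt template is filled in with
# str.replace by the task helpers further down, e.g. (illustrative)
#   OBJECT_SEGMENTATION.replace("{}", "dog") == 'What is the segmentation of " dog " ?'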
18 | IMAGE_GENERATION = 'What is the complete image? Text: " {} " .'
19 | REFEXP_PROMPT = 'Which region does the text " {} " describe ?'
20 | REGION_CAPTION = 'What does the region " {} " describe ?'
21 | REGION_CLASSIFICATION = 'What is the category of region " {} " ?'
22 | IMAGE_TAGGING = 'What is this in the image ?'
23 | IMAGE_INPAINTING = 'Filling the blank region " {} " ?'
24 | POSE_ESTIMATION = 'Find the human joints in the region " {} " .'
25 | SEGMENTATION_BASED_GENERATION = 'What is the complete image? Segmentation color: " {} "'
26 | 
27 | 
28 | GEN_SEGMENTATION_COLORS = np.array([
29 |   [255, 0, 0],
30 |   [255, 0, 0],
31 |   [0, 255, 0],
32 |   [0, 0, 255],
33 |   [255, 255, 255],
34 |   [128, 128, 128],
35 |   [255, 0, 255],
36 |   [255, 255, 0],
37 |   [0, 255, 255],
38 |   [192, 192, 192],
39 |   [128, 0, 0],
40 |   [128, 128, 0],
41 |   [0, 128, 0],
42 |   [0, 128, 128],
43 |   [0, 0, 128],
44 |   [128, 0, 128],
45 | ], dtype=np.uint8)
46 | 
47 | 
48 | GEN_SEGMENTATION_COLOR_NAMES = [
49 |   "white",
50 |   'red',
51 |   'lime',
52 |   'blue',
53 |   'white',
54 |   'gray',
55 |   'fuchsia',
56 |   'yellow',
57 |   'aqua',
58 |   'silver',
59 |   'maroon',
60 |   'olive',
61 |   'green',
62 |   'teal',
63 |   'navy',
64 |   'purple'
65 | ]
66 | 
67 | 
68 | class ModelRunner:
69 |   """High-level API to run UnifiedIO
70 | 
71 |   This is intended to provide an easy way to test out examples and
72 |   to demonstrate the pre-/post-processing we use for different tasks
73 |   """
74 | 
75 |   def __init__(self, size, param_file, pad_input_to_max=None, max_input_len=64,
76 |                max_options=800, compiled=False, log_inputs=True):
77 |     """Construct the ModelRunner
78 | 
79 |     :param size: Model size (small, base, large, xl)
80 |     :param param_file: .bin file storing the parameters
81 |     :param pad_input_to_max: Always pad input text tokens to this value, this can avoid excess
82 |        jax.jit re-compilations when `compiled` is set, defaults to the value of `compiled`
83 |     :param max_input_len: if `pad_to_max` is true, the max value to pad to, longer values will
84 |        result in more expensive inference. We support up to 256 tokens, but
85 |        we default to 64, which is enough for almost any task.
86 | :param max_options: For input with answer options, max number of options to process at once 87 | :param compiled: Compile the underlying prediction function, faster inference at a one-time 88 | cost when using the same input shapes 89 | :param log_inputs: Log the input text run on 90 | """ 91 | self.max_input_len = max_input_len 92 | if pad_input_to_max is None: 93 | pad_input_to_max = compiled 94 | self.pad_to_max = pad_input_to_max 95 | self.max_options = max_options 96 | self.compiled = compiled 97 | self.log_inputs = log_inputs 98 | 99 | conf = CONFIGS[size] 100 | module = network.Transformer(config=conf, vae_config=VAE_CONFIG) 101 | 102 | logging.info("Setting up model...") 103 | self.model = UnifiedIOModel(module, text_decoder_length=32, image_decoder_length=1) 104 | 105 | # extra_ids are used as location tokens 106 | # uio is trained to use at most 256 input tokens 107 | self.tokenizer = T5Tokenizer.from_pretrained( 108 | "t5-base", model_max_length=256, extra_ids=1100) 109 | 110 | logging.info("Loading parameters...") 111 | self.params = utils.load_checkpoint(param_file) 112 | logging.info("Model is ready") 113 | 114 | self._compiled_batch_fn = None 115 | self._compiled_option_fn = None 116 | 117 | def _get_batch_fn(self): 118 | if self.compiled: 119 | if self._compiled_batch_fn is None: 120 | self._compiled_batch_fn = jax.jit( 121 | self.model.predict_batch_with_aux, 122 | static_argnums=list(range(3, 9))) 123 | return self._compiled_batch_fn 124 | else: 125 | return self.model.predict_batch_with_aux 126 | 127 | def _get_answer_options_fn(self): 128 | if self.compiled: 129 | if self._compiled_option_fn is None: 130 | self._compiled_option_fn = jax.jit( 131 | self.model.predict_with_answer_options, static_argnums=[2, 3]) 132 | return self._compiled_option_fn 133 | else: 134 | return self.model.predict_with_answer_options 135 | 136 | def run(self, input_images, input_texts, output_text_len=128, generate_image=False, 137 | beam_search=None, num_decodes=None, answer_options=None, 138 | mask_regions=None, average_loss=False) -> Dict: 139 | """Runs UnifiedIO on input images/texts and produces output images/text 140 | 141 | :param input_images: List of images as [h, w, 3] float32/uint8 arrays or None 142 | :param input_texts: List of string prompts 143 | :param output_text_len: Max text tokens to generate, less max tokens will result in faster 144 | inference 145 | :param generate_image: Generate an image, if false inference will be faster 146 | :param beam_search: Use beam search rather than sampling, if None using beam_search when 147 | not generating an image and sampling otherwise 148 | :param num_decodes: if `None` return one generation for an input, otherwise generate a list 149 | `num_decodes` outputs for each example. Also defines the beam size if 150 | doing beam search. 
151 | :param answer_options: List of strings or images, limits text/image generation to one of these options 152 | :param mask_regions: Mask these regions from ech image, used for inpainting 153 | :param average_loss: If using answer_options, compute the average per-token loss instead of the 154 | total loss 155 | :return: dictionary outputs with the output text, image, scores and tokens generated 156 | """ 157 | if answer_options is not None: 158 | if num_decodes is not None: 159 | raise NotImplementedError("Not support if `answer_options` is given") 160 | 161 | assert output_text_len <= 128, "128 is the max output text len" 162 | assert len(input_images) == len(input_texts), "Different number of text/image inputs" 163 | 164 | if beam_search is None: 165 | beam_search = not generate_image 166 | 167 | input_tokens = np.array(self.tokenizer( 168 | input_texts, max_length=self.max_input_len, truncation=True, 169 | padding='max_length' if self.pad_to_max else 'longest')["input_ids"], dtype=np.int32) 170 | 171 | image_tensor = [] 172 | mask_tensor = [] 173 | for ix, image in enumerate(input_images): 174 | if image is not None: 175 | assert len(image.shape) == 3 and image.shape[-1] == 3 176 | image, image_mask = utils.preprocess_image( 177 | image, None if mask_regions is None else mask_regions[ix]) 178 | image_tensor.append(image) 179 | mask_tensor.append(image_mask) 180 | 181 | batch = { 182 | 'image_encoder_inputs': np.stack(image_tensor), 183 | 'image_input_masks': np.stack(mask_tensor), 184 | 'text_encoder_inputs': input_tokens, 185 | } 186 | 187 | if not answer_options: 188 | if self.log_inputs: 189 | logging.info(f"Running model text_inputs={input_texts}") 190 | out = self._get_batch_fn()( 191 | params=self.params, batch=batch, text_length=output_text_len, 192 | image_length=256 if generate_image else 1, 193 | beam_search=beam_search, num_decodes=1 if num_decodes is None else num_decodes, 194 | return_all_decodes=True 195 | ) 196 | else: 197 | if isinstance(answer_options[0], str): 198 | # One set of strings options for the entire batch 199 | output_options = np.array(self.tokenizer( 200 | answer_options, max_length=self.max_input_len, truncation=True, 201 | padding='longest')["input_ids"], dtype=np.int32) 202 | output_options = np.expand_dims(output_options, 0) 203 | bs = len(input_texts) 204 | output_options = np.tile(output_options, [bs, 1, 1]) 205 | batch["output_options"] = output_options 206 | elif isinstance(answer_options[0], np.ndarray): 207 | # One set of image options for the entire batch 208 | preprocessed = [utils.preprocess_target_image(x) for x in answer_options] 209 | output_options = np.stack([x[0] for x in preprocessed], 0) 210 | output_options_mask = np.stack([x[1] for x in preprocessed], 0) 211 | bs = len(input_texts) 212 | # [batch, n_options, h, w, c] 213 | output_options = np.tile(np.expand_dims(output_options, 0), [bs, 1, 1, 1, 1]) 214 | # [batch, n_options, n_patches] 215 | output_options_mask = np.tile(np.expand_dims(output_options_mask, 0), [bs, 1, 1]) 216 | batch["output_options"] = output_options 217 | batch["output_options_masks"] = output_options_mask 218 | else: 219 | raise NotImplementedError("Per-example answer options") 220 | 221 | if self.log_inputs: 222 | logging.info(f"Running model text_inputs={input_texts} and " 223 | f"{output_options.shape[1]} answer options") 224 | out = self._get_answer_options_fn()( 225 | params=self.params, batch=batch, max_options=self.max_options, average_loss=average_loss) 226 | # Add a fake beam dimensi7on to be 
compatible with the no answer options case 227 | out = {k: jnp.expand_dims(v, 1) for k, v in out.items()} 228 | 229 | if generate_image: 230 | output_image = out["image"] 231 | else: 232 | output_image = None 233 | 234 | if output_text_len > 1: 235 | output_text = [] 236 | for batch_out in out["text_tokens"]: 237 | beam_text = [] 238 | for beam_out in batch_out: 239 | row = np.array(beam_out) 240 | # Manually cutoff at the EOS since jax beam search method will generate tokens beyond it 241 | eos = np.where(row == 1)[0] 242 | if len(eos) != 0: 243 | row = row[:np.min(eos)] 244 | text = self.tokenizer.decode(row, skip_special_tokens=False) 245 | beam_text.append(text) 246 | output_text.append(beam_text) 247 | else: 248 | output_text = None 249 | 250 | if num_decodes is None: 251 | if output_text is not None: 252 | output_text = [x[0] for x in output_text] 253 | if output_image is not None: 254 | output_image = [x[0] for x in output_image] 255 | outputs = dict( 256 | text_tokens=np.array(out["text_tokens"]) if "text_tokens" in out else None, 257 | text=output_text, 258 | image_tokens=np.array(out["image_tokens"]) if "image_tokens" in out else None, 259 | image=np.array(output_image), 260 | score=np.array(out["scores"]), 261 | ) 262 | if "all_scores" in out: 263 | outputs["all_scores"] = np.array(out["all_scores"]) 264 | return outputs 265 | 266 | def _extract_text(self, out): 267 | return {k: out[k][0] for k in ["text", "score"]} 268 | 269 | def _extract_image(self, out): 270 | return {k: out[k][0] for k in ["image", "score"]} 271 | 272 | def _extract_pose(self, out, image_size): 273 | tokens = out["text_tokens"][0] 274 | if len(tokens) == 1: 275 | points, labels, invalid = utils.extract_keypoints( 276 | tokens[0], self.tokenizer, image_size) 277 | else: 278 | kp = [] 279 | for line in tokens: 280 | kp.append(utils.extract_keypoints(line, self.tokenizer, image_size)) 281 | points, labels, invalid = utils.transpose_lists(kp) 282 | 283 | out = dict(points=points, labels=labels, invalid=invalid, 284 | score=out["score"], text_tokens=out["text_tokens"]) 285 | return out 286 | 287 | def _extract_boxes(self, out, image_size, include_labels=False): 288 | tokens = out["text_tokens"][0] 289 | if len(tokens) == 1: 290 | all_labels, all_boxes = utils.tokens_to_regions(tokens[0], image_size) 291 | else: 292 | all_boxes = [] 293 | all_labels = [] 294 | for line in tokens: 295 | labels, boxes = utils.tokens_to_regions(line, image_size) 296 | all_labels.append(labels) 297 | all_boxes.append(boxes) 298 | 299 | out = dict(boxes=all_boxes, text=out["text"], score=out["score"], text_tokens=out["text_tokens"]) 300 | if include_labels: 301 | out["labels"] = all_labels 302 | return out 303 | 304 | def caption(self, image, num_decodes=None) -> Dict: 305 | """Generate a caption for `image`""" 306 | out = self.run([image], [CAPTIONING_PROMPT], output_text_len=32, 307 | generate_image=False, num_decodes=num_decodes) 308 | return self._extract_text(out) 309 | 310 | def vqa(self, image, question, num_decodes=None) -> Dict: 311 | """Answer `question` for `image`""" 312 | # We trained on lowercase question so lowercasing is recommended 313 | out = self.run([image], [question.lower()], output_text_len=32, 314 | generate_image=False, num_decodes=num_decodes) 315 | return self._extract_text(out) 316 | 317 | def depth(self, image, num_decodes=None, beam_search=None) -> Dict: 318 | """Produce a grayscale depth map for `image`""" 319 | out = self.run([image], [DEPTH_PROMPT], output_text_len=1, generate_image=True, 320 | 
num_decodes=num_decodes, beam_search=beam_search) 321 | rescaled_image = utils.undo_image_preprocessing(out["image"][0], image.shape[:2]) 322 | return { 323 | "image": out["image"][0], 324 | "rescaled_image": rescaled_image, 325 | "score": out["score"][0], 326 | } 327 | 328 | def surface_normal(self, image, num_decodes=None, beam_search=None) -> Dict: 329 | """Produce a RGB surface normal map for `image`""" 330 | out = self.run([image], [SURFACE_NORMAL_PROMPT], output_text_len=1, generate_image=True, 331 | num_decodes=num_decodes, beam_search=beam_search) 332 | # Rescale the output image to the size of the original image 333 | rescaled_image = utils.undo_image_preprocessing(out["image"][0], image.shape[:2]) 334 | return { 335 | "image": out["image"][0], 336 | "rescaled_image": rescaled_image, 337 | "score": out["score"][0], 338 | } 339 | 340 | def image_generation(self, description, num_decodes=None) -> Dict: 341 | """Generate an image based on `description`""" 342 | prompt = IMAGE_GENERATION.replace("{}", description) 343 | out = self.run( 344 | [None], [prompt], output_text_len=1, generate_image=True, num_decodes=num_decodes) 345 | return self._extract_image(out) 346 | 347 | def image_inpainting(self, image, location, replace_with: str, num_decodes=None) -> Dict: 348 | """Generate an image with `location` in-painted with `replace_with`""" 349 | region = utils.region_to_tokens(location, image.shape[1], image.shape[0]) 350 | region.append(replace_with) 351 | prompt = IMAGE_INPAINTING.replace("{}", " ".join(region)) 352 | out = self.run( 353 | [image], [prompt], output_text_len=1, generate_image=True, num_decodes=num_decodes, 354 | mask_regions=[np.array(location)] 355 | ) 356 | return self._extract_image(out) 357 | 358 | def object_segmentation(self, image, object_name, num_decodes=None) -> Dict: 359 | """Generate instances masks for occurrences of `object_name` in `image`""" 360 | prompt = OBJECT_SEGMENTATION.replace("{}", object_name) 361 | out = self.run( 362 | [image], [prompt], output_text_len=1, generate_image=True, num_decodes=num_decodes) 363 | if num_decodes is None: 364 | masks = utils.extract_segmentation_masks(out["image"][0]) 365 | else: 366 | masks = [utils.extract_segmentation_masks(x) for x in out["image"][0]] 367 | return dict(mask=masks, image=out["image"][0], score=out["score"][0]) 368 | 369 | def refexp(self, image, expression, num_decodes=None) -> Dict: 370 | """Return the `location` corresponding to `expression`""" 371 | prompt = REFEXP_PROMPT.replace("{}", expression) 372 | out = self.run( 373 | [image], [prompt], output_text_len=32, generate_image=False, num_decodes=num_decodes) 374 | return self._extract_boxes(out, image.shape) 375 | 376 | def object_localization(self, image, object_name, num_decodes=None) -> Dict: 377 | """Return the `locations` of `object_name` in `image`""" 378 | # Same prompt/setup as refex 379 | return self.refexp(image, object_name, num_decodes) 380 | 381 | def region_caption(self, image, location, num_decodes=None) -> Dict: 382 | """Generate a caption for `location` in `image`""" 383 | region = utils.region_to_tokens(location, image.shape[1], image.shape[0]) 384 | prompt = REGION_CAPTION.replace("{}", " ".join(region)) 385 | out = self.run( 386 | [image], [prompt], output_text_len=32, generate_image=False, num_decodes=num_decodes) 387 | return self._extract_text(out) 388 | 389 | def region_classification(self, image, location, num_decodes=None, answer_options=None) -> Dict: 390 | """Return the class of the object in `location` in `image`, 
391 | constrain the outputs to `answer_options` if given""" 392 | region = utils.region_to_tokens(location, image.shape[1], image.shape[0]) 393 | prompt = REGION_CLASSIFICATION.replace("{}", " ".join(region)) 394 | out = self.run( 395 | [image], [prompt], output_text_len=32, generate_image=False, 396 | num_decodes=num_decodes, answer_options=answer_options) 397 | return self._extract_text(out) 398 | 399 | def image_classification(self, image, num_decodes=None, answer_options=None) -> Dict: 400 | """Return the class of the `image`, constrain the outputs to `answer_options` if given""" 401 | out = self.run( 402 | [image], [IMAGE_TAGGING], output_text_len=32, generate_image=False, 403 | num_decodes=num_decodes, answer_options=answer_options) 404 | return self._extract_text(out) 405 | 406 | def pose(self, image, location, num_decodes=None) -> Dict: 407 | """Return points and labels of human joints in `location`""" 408 | region = utils.region_to_tokens(location, image.shape[1], image.shape[0]) 409 | prompt = POSE_ESTIMATION.replace("{}", " ".join(region)) 410 | out = self.run( 411 | [image], [prompt], output_text_len=128, generate_image=False, 412 | num_decodes=num_decodes, beam_search=False) 413 | return self._extract_pose(out, image.shape[:2]) 414 | 415 | def segmentation_based_generation( 416 | self, binary_masks: List[np.ndarray], labels: List[str], num_decodes=None) -> Dict: 417 | """Return an image where pixels in each `binary_mask` belong to corresponding class in 418 | `labels""" 419 | assert len(binary_masks) <= len(GEN_SEGMENTATION_COLOR_NAMES) 420 | assert len(binary_masks) == len(labels) 421 | assert len(binary_masks) > 0 422 | 423 | h, w = binary_masks[0].shape 424 | image = np.zeros((h, w, 3), dtype=np.uint8) 425 | for ix, mask in enumerate(binary_masks): 426 | image[mask, :] = GEN_SEGMENTATION_COLORS[ix] 427 | text = " , ".join(f"{a} : {b}" for a, b in zip(labels, GEN_SEGMENTATION_COLOR_NAMES)) 428 | text = text.lower() 429 | prompt = SEGMENTATION_BASED_GENERATION.replace("{}", text) 430 | out = self.run( 431 | [image], [prompt], output_text_len=1, generate_image=True, num_decodes=num_decodes) 432 | return self._extract_image(out) 433 | -------------------------------------------------------------------------------- /uio/t5x_layers.py: -------------------------------------------------------------------------------- 1 | # Modified from code from T5X (https://github.com/google-research/t5x) 2 | 3 | import dataclasses 4 | import functools 5 | import operator 6 | from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union 7 | 8 | import einops 9 | import jax 10 | from jax import random 11 | from flax import linen as nn 12 | from jax import lax 13 | import jax.numpy as jnp 14 | from flax.linen.module import Module, compact, merge_param 15 | 16 | import numpy as np 17 | 18 | 19 | from flax.linen import partitioning as nn_partitioning 20 | 21 | default_kernel_init = nn.initializers.lecun_normal() 22 | 23 | # from flax.linen.partitioning import param_with_axes, with_sharding_constraint 24 | param_with_axes = nn_partitioning.param_with_axes 25 | with_sharding_constraint = nn_partitioning.with_sharding_constraint 26 | 27 | 28 | # Type annotations 29 | Array = jnp.ndarray 30 | DType = jnp.dtype 31 | PRNGKey = jnp.ndarray 32 | Shape = Iterable[int] 33 | Activation = Callable[..., Array] 34 | Axes = Union[int, Iterable[int]] 35 | 36 | # Parameter initializers. 
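# An Initializer is any callable taking (rng_key, shape, dtype) and returning an array
# of that shape, as with the flax.linen initializers used throughout this file.
# Illustrative usage (editorial comment):
#   default_kernel_init(jax.random.PRNGKey(0), (32, 8), jnp.float32)  # -> (32, 8) array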
37 | Initializer = Callable[[PRNGKey, Shape, DType], Array] 38 | 39 | default_embed_init = nn.initializers.variance_scaling( 40 | 1.0, 'fan_in', 'normal', out_axis=0) 41 | 42 | 43 | def reverse_space_to_depth( 44 | frames: jnp.ndarray, 45 | temporal_block_size: int = 1, 46 | spatial_block_size: int = 1) -> jnp.ndarray: 47 | """Reverse space to depth transform.""" 48 | if len(frames.shape) == 4: 49 | return einops.rearrange( 50 | frames, 'b h w (dh dw c) -> b (h dh) (w dw) c', 51 | dh=spatial_block_size, dw=spatial_block_size) 52 | elif len(frames.shape) == 5: 53 | return einops.rearrange( 54 | frames, 'b t h w (dt dh dw c) -> b (t dt) (h dh) (w dw) c', 55 | dt=temporal_block_size, dh=spatial_block_size, dw=spatial_block_size) 56 | else: 57 | raise ValueError( 58 | 'Frames should be of rank 4 (batch, height, width, channels)' 59 | ' or rank 5 (batch, time, height, width, channels)') 60 | 61 | 62 | def space_to_depth( 63 | frames: jnp.ndarray, 64 | temporal_block_size: int = 1, 65 | spatial_block_size: int = 1) -> jnp.ndarray: 66 | """Space to depth transform.""" 67 | if len(frames.shape) == 4: 68 | return einops.rearrange( 69 | frames, 'b (h dh) (w dw) c -> b (h w) (dh dw c)', 70 | dh=spatial_block_size, dw=spatial_block_size) 71 | elif len(frames.shape) == 5: 72 | return einops.rearrange( 73 | frames, 'b (t dt) (h dh) (w dw) c -> b t (h w) (dt dh dw c)', 74 | dt=temporal_block_size, dh=spatial_block_size, dw=spatial_block_size) 75 | else: 76 | raise ValueError( 77 | 'Frames should be of rank 4 (batch, height, width, channels)' 78 | ' or rank 5 (batch, time, height, width, channels)') 79 | 80 | 81 | def dot_product_attention(query: Array, 82 | key: Array, 83 | value: Array, 84 | bias: Optional[Array] = None, 85 | dropout_rng: Optional[PRNGKey] = None, 86 | dropout_rate: float = 0., 87 | deterministic: bool = False, 88 | dtype: DType = jnp.float32, 89 | float32_logits: bool = False): 90 | """Computes dot-product attention given query, key, and value. 91 | 92 | This is the core function for applying attention based on 93 | https://arxiv.org/abs/1706.03762. It calculates the attention weights given 94 | query and key and combines the values using the attention weights. 95 | 96 | Args: 97 | query: queries for calculating attention with shape of `[batch, q_length, 98 | num_heads, qk_depth_per_head]`. 99 | key: keys for calculating attention with shape of `[batch, kv_length, 100 | num_heads, qk_depth_per_head]`. 101 | value: values to be used in attention with shape of `[batch, kv_length, 102 | num_heads, v_depth_per_head]`. 103 | bias: bias for the attention weights. This should be broadcastable to the 104 | shape `[batch, num_heads, q_length, kv_length]` This can be used for 105 | incorporating causal masks, padding masks, proximity bias, etc. 106 | dropout_rng: JAX PRNGKey: to be used for dropout 107 | dropout_rate: dropout rate 108 | deterministic: bool, deterministic or not (to apply dropout) 109 | dtype: the dtype of the computation (default: float32) 110 | float32_logits: bool, if True then compute logits in float32 to avoid 111 | numerical issues with bfloat16. 112 | 113 | Returns: 114 | Output of shape `[batch, length, num_heads, v_depth_per_head]`. 115 | """ 116 | assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' 
117 | assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], ( 118 | 'q, k, v batch dims must match.') 119 | assert query.shape[-2] == key.shape[-2] == value.shape[-2], ( 120 | 'q, k, v num_heads must match.') 121 | assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' 122 | assert query.shape[-1] == key.shape[-1], 'q, k depths must match.' 123 | 124 | # Casting logits and softmax computation for float32 for model stability. 125 | if float32_logits: 126 | query = query.astype(jnp.float32) 127 | key = key.astype(jnp.float32) 128 | 129 | # `attn_weights`: [batch, num_heads, q_length, kv_length] 130 | attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key) 131 | 132 | # Apply attention bias: masking, dropout, proximity bias, etc. 133 | if bias is not None: 134 | attn_weights = attn_weights + bias.astype(attn_weights.dtype) 135 | # Normalize the attention weights across `kv_length` dimension. 136 | attn_weights = jax.nn.softmax(attn_weights).astype(dtype) 137 | 138 | # Apply attention dropout. 139 | if not deterministic and dropout_rate > 0.: 140 | keep_prob = 1.0 - dropout_rate 141 | # T5 broadcasts along the "length" dim, but unclear which one that 142 | # corresponds to in positional dimensions here, assuming query dim. 143 | dropout_shape = list(attn_weights.shape) 144 | dropout_shape[-2] = 1 145 | keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) 146 | keep = jnp.broadcast_to(keep, attn_weights.shape) 147 | multiplier = ( 148 | keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)) 149 | attn_weights = attn_weights * multiplier 150 | 151 | # Take the linear combination of `value`. 152 | return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value) 153 | 154 | 155 | dynamic_vector_slice_in_dim = jax.vmap( 156 | lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) 157 | 158 | 159 | class MultiHeadDotProductAttention(nn.Module): 160 | """Multi-head dot-product attention. 161 | 162 | Attributes: 163 | num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) 164 | should be divisible by the number of heads. 165 | head_dim: dimension of each head. 166 | dtype: the dtype of the computation. 167 | dropout_rate: dropout rate 168 | kernel_init: initializer for the kernel of the Dense layers. 169 | float32_logits: bool, if True then compute logits in float32 to avoid 170 | numerical issues with bfloat16. 171 | """ 172 | 173 | num_heads: int 174 | head_dim: int 175 | dtype: DType = jnp.float32 176 | dropout_rate: float = 0. 177 | kernel_init: Initializer = nn.initializers.variance_scaling( 178 | 1.0, 'fan_in', 'normal') 179 | float32_logits: bool = False # computes logits in float32 for stability. 180 | 181 | @nn.compact 182 | def __call__(self, 183 | inputs_q: Array, 184 | inputs_kv: Array, 185 | mask: Optional[Array] = None, 186 | bias: Optional[Array] = None, 187 | abs_bias: Optional[Array] = None, 188 | *, 189 | decode: bool = False, 190 | deterministic: bool = False) -> Array: 191 | """Applies multi-head dot product attention on the input data. 192 | 193 | Projects the inputs into multi-headed query, key, and value vectors, 194 | applies dot-product attention and project the results to an output vector. 195 | 196 | There are two modes: decoding and non-decoding (e.g., training). The mode is 197 | determined by `decode` argument. For decoding, this method is called twice, 198 | first to initialize the cache and then for an actual decoding process. 
The 199 | two calls are differentiated by the presence of 'cached_key' in the variable 200 | dict. In the cache initialization stage, the cache variables are initialized 201 | as zeros and will be filled in the subsequent decoding process. 202 | 203 | In the cache initialization call, `inputs_q` has a shape [batch, length, 204 | q_features] and `inputs_kv`: [batch, length, kv_features]. During the 205 | incremental decoding stage, query, key and value all have the shape [batch, 206 | 1, qkv_features] corresponding to a single step. 207 | 208 | Args: 209 | inputs_q: input queries of shape `[batch, q_length, q_features]`. 210 | inputs_kv: key/values of shape `[batch, kv_length, kv_features]`. 211 | mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`. 212 | bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`. 213 | decode: Whether to prepare and use an autoregressive cache. 214 | deterministic: Disables dropout if set to True. 215 | 216 | Returns: 217 | output of shape `[batch, length, q_features]`. 218 | """ 219 | projection = functools.partial( 220 | DenseGeneral, 221 | axis=-1, 222 | features=(self.num_heads, self.head_dim), 223 | kernel_axes=('embed', 'joined_kv'), 224 | dtype=self.dtype) 225 | 226 | # NOTE: T5 does not explicitly rescale the attention logits by 227 | # 1/sqrt(depth_kq)! This is folded into the initializers of the 228 | # linear transformations, which is equivalent under Adafactor. 229 | depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) 230 | query_init = lambda *args: self.kernel_init(*args) / depth_scaling 231 | 232 | # Project inputs_q to multi-headed q/k/v 233 | # dimensions are then [batch, length, num_heads, head_dim] 234 | query = projection(kernel_init=query_init, name='query')(inputs_q) 235 | key = projection(kernel_init=self.kernel_init, name='key')(inputs_kv) 236 | value = projection(kernel_init=self.kernel_init, name='value')(inputs_kv) 237 | 238 | query = with_sharding_constraint(query, ('batch', 'length', 'heads', 'kv')) 239 | key = with_sharding_constraint(key, ('batch', 'length', 'heads', 'kv')) 240 | value = with_sharding_constraint(value, ('batch', 'length', 'heads', 'kv')) 241 | 242 | if decode: 243 | # Detect if we're initializing by absence of existing cache data. 244 | is_initialized = self.has_variable('cache', 'cached_key') 245 | # The key and value have dimension [batch, length, num_heads, head_dim], 246 | # but we cache them as [batch, num_heads, head_dim, length] as a TPU 247 | # fusion optimization. This also enables the "scatter via one-hot 248 | # broadcast" trick, which means we do a one-hot broadcast instead of a 249 | # scatter/gather operations, resulting in a 3-4x speedup in practice. 250 | swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3]) 251 | cached_key = self.variable('cache', 'cached_key', jnp.zeros, 252 | swap_dims(key.shape), key.dtype) 253 | cached_value = self.variable('cache', 'cached_value', jnp.zeros, 254 | swap_dims(value.shape), value.dtype) 255 | cache_index = self.variable('cache', 'cache_index', 256 | lambda: jnp.array(0, dtype=jnp.int32)) 257 | cache_mask = self.variable('cache', 'cache_mask', jnp.zeros, 258 | (query.shape[0], 1, 1, query.shape[1]), jnp.float32) 259 | if is_initialized: 260 | batch, num_heads, head_dim, length = (cached_key.value.shape) 261 | # During fast autoregressive decoding, we feed one position at a time, 262 | # and cache the keys and values step by step. 263 | # Sanity shape check of cached key against input query. 
264 | expected_shape = (batch, 1, num_heads, head_dim) 265 | if expected_shape != query.shape: 266 | raise ValueError('Autoregressive cache shape error, ' 267 | 'expected query shape %s instead got %s.' % 268 | (expected_shape, query.shape)) 269 | 270 | # Create a OHE of the current index. NOTE: the index is increased below. 271 | cur_index = cache_index.value 272 | one_hot_indices = jax.nn.one_hot(cur_index, length, dtype=key.dtype) 273 | # In order to update the key, value caches with the current key and 274 | # value, we move the length axis to the back, similar to what we did for 275 | # the cached ones above. 276 | # Note these are currently the key and value of a single position, since 277 | # we feed one position at a time. 278 | one_token_key = jnp.moveaxis(key, -3, -1) 279 | one_token_value = jnp.moveaxis(value, -3, -1) 280 | # Update key, value caches with our new 1d spatial slices. 281 | # We implement an efficient scatter into the cache via one-hot 282 | # broadcast and addition. 283 | key = cached_key.value + one_token_key * one_hot_indices 284 | value = cached_value.value + one_token_value * one_hot_indices 285 | cached_key.value = key 286 | cached_value.value = value 287 | cache_index.value = cache_index.value + 1 288 | # Move the keys and values back to their original shapes. 289 | key = jnp.moveaxis(key, -1, -3) 290 | value = jnp.moveaxis(value, -1, -3) 291 | 292 | # Causal mask for cached decoder self-attention: our single query 293 | # position should only attend to those key positions that have already 294 | # been generated and cached, not the remaining zero elements. 295 | # mask = jnp.logical_or(cache_mask.value, mask).astype(jnp.int32) 296 | # cache_mask.value = mask 297 | 298 | # if cur_index == 20: 299 | # import ipdb; ipdb.set_trace() 300 | 301 | mask = (cache_mask.value + mask * one_hot_indices).astype(jnp.float32) 302 | cache_mask.value = mask 303 | 304 | mask = combine_masks( 305 | mask, 306 | jnp.broadcast_to( 307 | jnp.arange(length) <= cur_index, 308 | # (1, 1, length) represent (head dim, query length, key length) 309 | # query length is 1 because during decoding we deal with one 310 | # index. 311 | # The same mask is applied to all batch elements and heads. 312 | (batch, 1, 1, length))) 313 | 314 | 315 | # Grab the correct relative attention bias during decoding. This is 316 | # only required during single step decoding. 317 | if bias is not None: 318 | # The bias is a full attention matrix, but during decoding we only 319 | # have to take a slice of it. 320 | # This is equivalent to bias[..., cur_index:cur_index+1, :]. 321 | bias = dynamic_vector_slice_in_dim( 322 | jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2) 323 | 324 | abs_bias = dynamic_vector_slice_in_dim( 325 | jnp.squeeze(abs_bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2) 326 | 327 | # Convert the boolean attention mask to an attention bias. 328 | if mask is not None: 329 | # attention mask in the form of attention bias 330 | attention_bias = lax.select( 331 | mask > 0, 332 | jnp.full(mask.shape, 0.).astype(self.dtype), 333 | jnp.full(mask.shape, -1e10).astype(self.dtype)) 334 | else: 335 | attention_bias = None 336 | 337 | # Add provided bias term (e.g. relative position embedding). 338 | if bias is not None: 339 | attention_bias = combine_biases(attention_bias, bias, abs_bias) 340 | 341 | dropout_rng = None 342 | if not deterministic and self.dropout_rate > 0.: 343 | dropout_rng = self.make_rng('dropout') 344 | 345 | # Apply attention. 
346 | x = dot_product_attention( 347 | query, 348 | key, 349 | value, 350 | bias=attention_bias, 351 | dropout_rng=dropout_rng, 352 | dropout_rate=self.dropout_rate, 353 | deterministic=deterministic, 354 | dtype=self.dtype, 355 | float32_logits=self.float32_logits) 356 | 357 | # Back to the original inputs dimensions. 358 | out = DenseGeneral( 359 | features=inputs_q.shape[-1], # output dim is set to the input dim. 360 | axis=(-2, -1), 361 | kernel_init=self.kernel_init, 362 | kernel_axes=('joined_kv', 'embed'), 363 | dtype=self.dtype, 364 | name='out')( 365 | x) 366 | return out 367 | 368 | 369 | def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: 370 | # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. 371 | return tuple([ax if ax >= 0 else ndim + ax for ax in axes]) 372 | 373 | 374 | def _canonicalize_tuple(x): 375 | if isinstance(x, Iterable): 376 | return tuple(x) 377 | else: 378 | return (x,) 379 | 380 | #------------------------------------------------------------------------------ 381 | # Convolution layers 382 | #------------------------------------------------------------------------------ 383 | 384 | class VectorQuantizer(nn.Module): 385 | n_e: int 386 | e_dim: int 387 | beta: float = 0.25 388 | embedding_init: Initializer = default_embed_init 389 | dtype: Any = jnp.float32 390 | 391 | def setup(self): 392 | self.embedding = param_with_axes( 393 | 'embedding', 394 | self.embedding_init, (self.n_e, self.e_dim), 395 | jnp.float32, 396 | axes=(('vocab', 'embed'))) 397 | 398 | def get_codebook_entry(self, indices): 399 | min_encodings = jax.nn.one_hot(indices, self.n_e, dtype=self.dtype) 400 | z_q = jnp.einsum('bqk,kd->bqd', min_encodings, self.embedding) 401 | return z_q 402 | 403 | @nn.compact 404 | def __call__(self, z: Array) -> Array: 405 | 406 | z_flattened = jnp.reshape(z, (-1, self.e_dim)) 407 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 408 | d = jnp.sum(z_flattened ** 2, axis=1, keepdims=True) + \ 409 | jnp.sum(self.embedding ** 2, axis=1) - 2 * \ 410 | jnp.einsum('ij,kj->ik', z_flattened, self.embedding) 411 | 412 | min_encoding_indices = jnp.argmin(d, axis=1) 413 | z_q = jnp.asarray(self.embedding, self.dtype)[min_encoding_indices] 414 | z_q = jnp.reshape(z_q, z.shape) 415 | 416 | perplexity = None 417 | min_encodings = None 418 | loss = jnp.mean((jax.lax.stop_gradient(z_q)-z)**2) + self.beta * \ 419 | jnp.mean((z_q - jax.lax.stop_gradient(z)) ** 2) 420 | 421 | z_q = z + jax.lax.stop_gradient(z_q - z) 422 | 423 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 424 | 425 | 426 | def nonlinearity(x): 427 | # swish 428 | return x*nn.sigmoid(x) 429 | 430 | def _conv_dimension_numbers(input_shape): 431 | """Computes the dimension numbers based on the input shape.""" 432 | ndim = len(input_shape) 433 | lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1)) 434 | rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2)) 435 | out_spec = lhs_spec 436 | return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec) 437 | 438 | class Conv(nn.Module): 439 | """Convolution Module with flexible axes. 440 | Attributes: 441 | features: number of convolution filters. 442 | kernel_size: shape of the convolutional kernel. For 1D convolution, 443 | the kernel size can be passed as an integer. For all other cases, it must 444 | be a sequence of integers. 445 | strides: an integer or a sequence of `n` integers, representing the 446 | inter-window strides (default: 1). 
447 | padding: either the string `'SAME'`, the string `'VALID'`, the string 448 | `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low, 449 | high)` integer pairs that give the padding to apply before and after each 450 | spatial dimension. 451 | use_bias: whether to add a bias to the output (default: True). 452 | dtype: the dtype of the computation (default: float32). 453 | kernel_init: initializer for the convolutional kernel. 454 | bias_init: initializer for the bias. 455 | """ 456 | features: int 457 | kernel_size: Iterable[int] 458 | strides: Union[None, int, Iterable[int]] = 1 459 | padding: Union[str, Iterable[Tuple[int, int]]] = 'SAME' 460 | input_dilation: Union[None, int, Iterable[int]] = 1 461 | kernel_dilation: Union[None, int, Iterable[int]] = 1 462 | feature_group_count: int = 1 463 | use_bias: bool = True 464 | dtype: DType = jnp.float32 465 | param_dtype: DType = jnp.float32 466 | kernel_init: Initializer = default_kernel_init 467 | bias_init: Initializer = nn.initializers.zeros 468 | precision: Any = None 469 | kernel_axes: Tuple[str, ...] = () 470 | bias_axes: Tuple[str, ...] = () 471 | 472 | @nn.compact 473 | def __call__(self, inputs: Array) -> Array: 474 | """Applies a convolution to the inputs. 475 | 476 | Args: 477 | inputs: input data with dimensions (batch, spatial_dims..., features). 478 | This is the channels-last convention, i.e. NHWC for a 2d convolution 479 | and NDHWC for a 3D convolution. Note: this is different from the input 480 | convention used by `lax.conv_general_dilated`, which puts the spatial 481 | dimensions last. 482 | Returns: 483 | The convolved data. 484 | """ 485 | inputs = jnp.asarray(inputs, self.dtype) 486 | if isinstance(self.kernel_size, int): 487 | raise TypeError('The kernel size must be specified as a' 488 | ' tuple/list of integers (eg.: [3, 3]).') 489 | else: 490 | kernel_size = tuple(self.kernel_size) 491 | 492 | def maybe_broadcast(x): 493 | if x is None: 494 | # backward compatibility with using None as sentinel for 495 | # broadcast 1 496 | x = 1 497 | if isinstance(x, int): 498 | return (x,) * len(kernel_size) 499 | return x 500 | 501 | is_single_input = False 502 | if inputs.ndim == len(kernel_size) + 1: 503 | is_single_input = True 504 | inputs = jnp.expand_dims(inputs, axis=0) 505 | 506 | strides = maybe_broadcast(self.strides) # self.strides or (1,) * (inputs.ndim - 2) 507 | input_dilation = maybe_broadcast(self.input_dilation) 508 | kernel_dilation = maybe_broadcast(self.kernel_dilation) 509 | 510 | in_features = inputs.shape[-1] 511 | assert in_features % self.feature_group_count == 0 512 | kernel_shape = kernel_size + ( 513 | in_features // self.feature_group_count, self.features) 514 | 515 | kernel = param_with_axes( 516 | 'kernel', 517 | self.kernel_init, 518 | kernel_shape, 519 | self.param_dtype, 520 | axes=self.kernel_axes) 521 | kernel = jnp.asarray(kernel, self.dtype) 522 | if self.padding == 'CIRCULAR': 523 | kernel_size_dilated = [(k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation)] 524 | pads = [(0, 0)] + [((k - 1) // 2, k // 2) for k in kernel_size_dilated] + [(0, 0)] 525 | inputs = jnp.pad(inputs, pads, mode='wrap') 526 | padding_lax = 'VALID' 527 | else: 528 | padding_lax = self.padding 529 | 530 | dimension_numbers = _conv_dimension_numbers(inputs.shape) 531 | y = lax.conv_general_dilated( 532 | inputs, 533 | kernel, 534 | strides, 535 | padding_lax, 536 | lhs_dilation=input_dilation, 537 | rhs_dilation=kernel_dilation, 538 | dimension_numbers=dimension_numbers, 539 | 
feature_group_count=self.feature_group_count, 540 | precision=self.precision) 541 | 542 | if is_single_input: 543 | y = jnp.squeeze(y, axis=0) 544 | if self.use_bias: 545 | bias = param_with_axes( 546 | 'bias', 547 | self.bias_init, 548 | (self.features,), 549 | self.param_dtype, 550 | axes=self.bias_axes) 551 | 552 | bias = jnp.asarray(bias, self.dtype) 553 | y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) 554 | return y 555 | 556 | #------------------------------------------------------------------------------ 557 | # DenseGeneral for attention layers. 558 | #------------------------------------------------------------------------------ 559 | class DenseGeneral(nn.Module): 560 | """A linear transformation (without bias) with flexible axes. 561 | 562 | Attributes: 563 | features: tuple with numbers of output features. 564 | axis: tuple with axes to apply the transformation on. 565 | dtype: the dtype of the computation (default: float32). 566 | kernel_init: initializer function for the weight matrix. 567 | """ 568 | features: Union[Iterable[int], int] 569 | axis: Union[Iterable[int], int] = -1 570 | dtype: DType = jnp.float32 571 | kernel_init: Initializer = nn.initializers.variance_scaling( 572 | 1.0, 'fan_in', 'truncated_normal') 573 | kernel_axes: Tuple[str, ...] = () 574 | 575 | @nn.compact 576 | def __call__(self, inputs: Array) -> Array: 577 | """Applies a linear transformation to the inputs along multiple dimensions. 578 | 579 | Args: 580 | inputs: The nd-array to be transformed. 581 | 582 | Returns: 583 | The transformed input. 584 | """ 585 | features = _canonicalize_tuple(self.features) 586 | axis = _canonicalize_tuple(self.axis) 587 | 588 | inputs = jnp.asarray(inputs, self.dtype) 589 | axis = _normalize_axes(axis, inputs.ndim) 590 | 591 | kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features 592 | kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), 593 | np.prod(features)) 594 | kernel = param_with_axes( 595 | 'kernel', 596 | self.kernel_init, 597 | kernel_param_shape, 598 | jnp.float32, 599 | axes=self.kernel_axes) 600 | kernel = jnp.asarray(kernel, self.dtype) 601 | kernel = jnp.reshape(kernel, kernel_shape) 602 | 603 | contract_ind = tuple(range(0, len(axis))) 604 | return lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ()))) 605 | 606 | 607 | def _convert_to_activation_function( 608 | fn_or_string: Union[str, Callable]) -> Callable: 609 | """Convert a string to an activation function.""" 610 | if fn_or_string == 'linear': 611 | return lambda x: x 612 | elif isinstance(fn_or_string, str): 613 | return getattr(nn, fn_or_string) 614 | elif callable(fn_or_string): 615 | return fn_or_string 616 | else: 617 | raise ValueError("don't know how to convert %s to an activation function" % 618 | (fn_or_string,)) 619 | 620 | 621 | class MlpBlock(nn.Module): 622 | """Transformer MLP / feed-forward block. 623 | 624 | Attributes: 625 | intermediate_dim: Shared dimension of hidden layers. 626 | activations: Type of activations for each layer. Each element is either 627 | 'linear', a string function name in flax.linen, or a function. 628 | kernel_init: Kernel function, passed to the dense layers. 629 | deterministic: Whether the dropout layers should be deterministic. 630 | intermediate_dropout_rate: Dropout rate used after the intermediate layers. 631 | dtype: Type for the dense layer. 
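
  For example (illustrative): activations=('linear', 'gelu') gives the gated-GELU
  feed-forward used in T5.1.1, i.e. two input projections 'wi_0' and 'wi_1' whose
  outputs are multiplied elementwise before the output projection 'wo'.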
632 | """ 633 | intermediate_dim: int = 2048 634 | activations: Sequence[Union[str, Callable]] = ('relu',) 635 | kernel_init: Initializer = nn.initializers.variance_scaling( 636 | 1.0, 'fan_in', 'truncated_normal') 637 | intermediate_dropout_rate: float = 0.1 638 | dtype: Any = jnp.float32 639 | 640 | @nn.compact 641 | def __call__(self, inputs, decode: bool = False, deterministic: bool = False): 642 | """Applies Transformer MlpBlock module.""" 643 | # Iterate over specified MLP input activation functions. 644 | # e.g. ('relu',) or ('linear', 'gelu') for gated-gelu. 645 | activations = [] 646 | for idx, act_fn in enumerate(self.activations): 647 | dense_name = 'wi' if len(self.activations) == 1 else f'wi_{idx}' 648 | x = DenseGeneral( 649 | self.intermediate_dim, 650 | dtype=self.dtype, 651 | kernel_init=self.kernel_init, 652 | kernel_axes=('embed', 'mlp'), 653 | name=dense_name)( 654 | inputs) 655 | x = _convert_to_activation_function(act_fn)(x) 656 | activations.append(x) 657 | 658 | # Take elementwise product of above intermediate activations. 659 | x = functools.reduce(operator.mul, activations) 660 | # Apply dropout and final dense output projection. 661 | x = nn.Dropout( 662 | rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))( 663 | x, deterministic=deterministic) # Broadcast along length. 664 | x = with_sharding_constraint(x, ('batch', 'length', 'mlp')) 665 | output = DenseGeneral( 666 | inputs.shape[-1], 667 | dtype=self.dtype, 668 | kernel_init=self.kernel_init, 669 | kernel_axes=('mlp', 'embed'), 670 | name='wo')( 671 | x) 672 | return output 673 | 674 | 675 | class Embed(nn.Module): 676 | """A parameterized function from integers [0, n) to d-dimensional vectors. 677 | 678 | Attributes: 679 | num_embeddings: number of embeddings. 680 | features: number of feature dimensions for each embedding. 681 | dtype: the dtype of the embedding vectors (default: float32). 682 | embedding_init: embedding initializer. 683 | one_hot: performs the gather with a one-hot contraction rather than a true 684 | gather. This is currently needed for SPMD partitioning. 685 | """ 686 | num_embeddings: int 687 | features: int 688 | cast_input_dtype: Optional[DType] = None 689 | dtype: DType = jnp.float32 690 | attend_dtype: Optional[DType] = None 691 | embedding_init: Initializer = default_embed_init 692 | one_hot: bool = False 693 | embedding: Array = dataclasses.field(init=False) 694 | 695 | def setup(self): 696 | self.embedding = param_with_axes( 697 | 'embedding', 698 | self.embedding_init, (self.num_embeddings, self.features), 699 | jnp.float32, 700 | axes=('vocab', 'embed')) 701 | 702 | def __call__(self, inputs: Array) -> Array: 703 | """Embeds the inputs along the last dimension. 704 | 705 | Args: 706 | inputs: input data, all dimensions are considered batch dimensions. 707 | 708 | Returns: 709 | Output which is embedded input data. The output shape follows the input, 710 | with an additional `features` dimension appended. 
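
      For example (illustrative), int32 token ids of shape [batch, length] produce an
      output of shape [batch, length, features].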
711 | """ 712 | if self.cast_input_dtype: 713 | inputs = inputs.astype(self.cast_input_dtype) 714 | if not jnp.issubdtype(inputs.dtype, jnp.integer): 715 | raise ValueError('Input type must be an integer or unsigned integer.') 716 | if self.one_hot: 717 | iota = lax.iota(jnp.int32, self.num_embeddings) 718 | one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype) 719 | output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype)) 720 | else: 721 | output = jnp.asarray(self.embedding, self.dtype)[inputs] 722 | output = with_sharding_constraint(output, ('batch', 'length', 'embed')) 723 | return output 724 | 725 | def attend(self, query: Array) -> Array: 726 | """Attend over the embedding using a query array. 727 | 728 | Args: 729 | query: array with last dimension equal the feature depth `features` of the 730 | embedding. 731 | 732 | Returns: 733 | An array with final dim `num_embeddings` corresponding to the batched 734 | inner-product of the array of query vectors against each embedding. 735 | Commonly used for weight-sharing between embeddings and logit transform 736 | in NLP models. 737 | """ 738 | dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype 739 | return jnp.dot(query, jnp.asarray(self.embedding, dtype).T) 740 | 741 | 742 | class RelativePositionBiases(nn.Module): 743 | """Adds T5-style relative positional embeddings to the attention logits. 744 | 745 | Attributes: 746 | num_buckets: Number of buckets to bucket distances between key and query 747 | positions into. 748 | max_distance: Maximum distance before everything is lumped into the last 749 | distance bucket. 750 | num_heads: Number of heads in the attention layer. Each head will get a 751 | different relative position weighting. 752 | dtype: Type of arrays through this module. 753 | embedding_init: initializer for relative embedding table. 754 | """ 755 | num_buckets: int 756 | img_num_buckets: int 757 | max_distance: int 758 | img_max_distance: int 759 | num_heads: int 760 | img_width: int 761 | img_height: int 762 | dtype: Any 763 | embedding_init: Callable[..., Array] = nn.linear.default_embed_init 764 | 765 | @staticmethod 766 | def _relative_position_bucket(relative_position, 767 | bidirectional=True, 768 | num_buckets=32, 769 | max_distance=128): 770 | """Translate relative position to a bucket number for relative attention. 771 | 772 | The relative position is defined as memory_position - query_position, i.e. 773 | the distance in tokens from the attending position to the attended-to 774 | position. If bidirectional=False, then positive relative positions are 775 | invalid. 776 | We use smaller buckets for small absolute relative_position and larger 777 | buckets for larger absolute relative_positions. All relative 778 | positions >=max_distance map to the same bucket. All relative 779 | positions <=-max_distance map to the same bucket. This should allow for 780 | more graceful generalization to longer sequences than the model has been 781 | trained on. 
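
    For example (illustrative, using the default arguments bidirectional=True,
    num_buckets=32, max_distance=128): a key 3 positions before the query maps to
    bucket 3, a key 3 positions after the query maps to bucket 16 + 3 = 19, and all
    offsets at or beyond max_distance collapse into the last bucket on each side
    (bucket 15 for past positions, bucket 31 for future ones).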
782 | 783 | Args: 784 | relative_position: an int32 array 785 | bidirectional: a boolean - whether the attention is bidirectional 786 | num_buckets: an integer 787 | max_distance: an integer 788 | 789 | Returns: 790 | a Tensor with the same shape as relative_position, containing int32 791 | values in the range [0, num_buckets) 792 | """ 793 | ret = 0 794 | n = -relative_position 795 | if bidirectional: 796 | num_buckets //= 2 797 | ret += (n < 0).astype(jnp.int32) * num_buckets 798 | n = jnp.abs(n) 799 | else: 800 | n = jnp.maximum(n, 0) 801 | 802 | # now n is in the range [0, inf) 803 | max_exact = num_buckets // 2 804 | is_small = (n < max_exact) 805 | val_if_large = max_exact + ( 806 | jnp.log(n.astype(jnp.float32) / max_exact + jnp.finfo(jnp.float32).eps) / 807 | jnp.log(max_distance / max_exact) * 808 | (num_buckets - max_exact)).astype(jnp.int32) 809 | 810 | val_if_large = jnp.minimum(val_if_large, num_buckets - 1) 811 | ret += jnp.where(is_small, n, val_if_large) 812 | return ret 813 | 814 | @staticmethod 815 | def _img_relative_position_bucket(relative_position_x, 816 | relative_position_y, 817 | num_buckets=8, 818 | max_distance=20): 819 | 820 | max_exact = num_buckets // 2 821 | nx = -relative_position_x 822 | ny = -relative_position_y 823 | 824 | total_buckets = num_buckets ** 2 825 | ret = 0 826 | ret += (jnp.logical_and(nx <=0, ny <0)).astype(jnp.int32) * total_buckets * 3 827 | ret += (jnp.logical_and(nx <0, ny >=0)).astype(jnp.int32) * total_buckets * 2 828 | ret += (jnp.logical_and(nx >0, ny <=0)).astype(jnp.int32) * total_buckets * 1 829 | 830 | nx = jnp.abs(nx) 831 | ny = jnp.abs(ny) 832 | 833 | is_small_x = nx < max_exact 834 | val_x_if_large = max_exact + (jnp.log(nx.astype(jnp.float32) / 835 | max_exact + jnp.finfo(jnp.float32).eps) / jnp.log(max_distance / 836 | max_exact) * (num_buckets - max_exact)).astype(np.int32) 837 | 838 | val_x_if_large = jnp.minimum(val_x_if_large, num_buckets - 1) 839 | 840 | is_small_y = ny < max_exact 841 | val_y_if_large = max_exact + (jnp.log(ny.astype(jnp.float32) / 842 | max_exact + jnp.finfo(jnp.float32).eps) / jnp.log(max_distance / 843 | max_exact) * (num_buckets - max_exact)).astype(jnp.int32) 844 | val_y_if_large = jnp.minimum(val_y_if_large, num_buckets - 1) 845 | 846 | xx = jnp.where(is_small_x, nx, val_x_if_large) 847 | yy = jnp.where(is_small_y, ny, val_y_if_large) 848 | ret += xx + num_buckets * yy 849 | return ret 850 | 851 | @nn.compact 852 | def __call__(self, txt_position_ids, img_position_ids, bidirectional=True): 853 | """Produce relative position embedding attention biases. 854 | 855 | Args: 856 | txt_position_ids: attention query length. 857 | img_position_ids: attention key length. 858 | bidirectional: whether to allow positive memory-query relative position 859 | embeddings. 860 | 861 | Returns: 862 | output: `(1, len, q_len, k_len)` attention bias 863 | """ 864 | # TODO(levskaya): should we be computing this w. numpy as a program 865 | # constant? 866 | 867 | # compute text position encoding first. 868 | txt_context_position = txt_position_ids[:, :, None] 869 | txt_memory_position = txt_position_ids[:, None, :] 870 | txt_relative_position = txt_memory_position - txt_context_position # shape (qlen, klen) 871 | 872 | # different way to compute relative position. 
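    # Editorial summary of what follows: a 1-D T5-style bucketed bias is computed for
    # the text block and a 2-D (x/y) bucketed bias for the image block; each is then
    # zero-padded to the full text+image length and the two are summed, so text-text
    # and image-image pairs receive learned relative biases while text-image pairs get
    # zero relative bias from this module.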
873 | rp_bucket = self._relative_position_bucket( 874 | txt_relative_position, 875 | bidirectional=bidirectional, 876 | num_buckets=self.num_buckets, 877 | max_distance=self.max_distance) 878 | 879 | relative_attention_bias = param_with_axes( 880 | 'rel_embedding', 881 | self.embedding_init, (self.num_heads, self.num_buckets), 882 | jnp.float32, 883 | axes=('heads', 'relpos_buckets')) 884 | 885 | img_position_x = img_position_ids % self.img_width 886 | img_position_y = img_position_ids // self.img_width 887 | img_context_position_x = img_position_x[:,:,None] 888 | img_memory_position_x = img_position_x[:, None, :] 889 | img_context_position_y = img_position_y[:,:,None] 890 | img_memory_position_y = img_position_y[:, None, :] 891 | img_relative_position_x = img_memory_position_x - img_context_position_x 892 | img_relative_position_y = img_memory_position_y - img_context_position_y 893 | 894 | img_rp_bucket = self._img_relative_position_bucket( 895 | img_relative_position_x, 896 | img_relative_position_y, 897 | num_buckets=self.img_num_buckets, 898 | max_distance=self.img_max_distance) 899 | 900 | image_num_rel_dis = self.img_num_buckets ** 2 * 4 901 | img_relative_attention_bias = param_with_axes( 902 | 'image_rel_embedding', 903 | self.embedding_init, (self.num_heads, image_num_rel_dis), 904 | jnp.float32, 905 | axes=('heads', 'relpos_buckets')) 906 | 907 | relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype) 908 | img_relative_attention_bias = jnp.asarray(img_relative_attention_bias, self.dtype) 909 | # Instead of using a slow gather, we create a leading-dimension one-hot 910 | # array from rp_bucket and use it to perform the gather-equivalent via a 911 | # contraction, i.e.: 912 | # (num_head, num_buckets) x (num_buckets one-hot, qlen, klen). 913 | # This is equivalent to relative_attention_bias[:, rp_bucket] 914 | bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1, 1), 0) 915 | rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype) 916 | 917 | img_bcast_iota = lax.broadcasted_iota(jnp.int32, (image_num_rel_dis, 1, 1, 1), 0) 918 | img_rp_bucket_one_hot = jnp.array( 919 | img_rp_bucket[jnp.newaxis, ...] == img_bcast_iota, dtype=self.dtype) 920 | # --> shape (qlen, klen, num_heads) 921 | t_values = lax.dot_general( 922 | relative_attention_bias, 923 | rp_bucket_one_hot, 924 | ( 925 | ((1,), (0,)), 926 | ((), ()))) # no batched dims 927 | i_values = lax.dot_general( 928 | img_relative_attention_bias, 929 | img_rp_bucket_one_hot, 930 | ( 931 | ((1,), (0,)), # rhs, lhs contracting dims 932 | ((), ()))) # no batched dims 933 | 934 | t_values_pad = jax.lax.pad( 935 | t_values, 936 | jnp.array(0, dtype=t_values.dtype), 937 | [(0,0,0),(0,0,0),(0,img_position_ids.shape[1],0),(0,img_position_ids.shape[1],0)]) 938 | 939 | i_values_pad = jax.lax.pad( 940 | i_values, 941 | jnp.array(0, dtype=i_values.dtype), 942 | [(0,0,0),(0,0,0),(txt_position_ids.shape[1],0,0),(txt_position_ids.shape[1],0,0)]) 943 | values = t_values_pad + i_values_pad 944 | return jnp.transpose(values, (1,0,2,3)) 945 | 946 | #------------------------------------------------------------------------------ 947 | # T5 Layernorm - no subtraction of mean or bias. 
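# (Equivalent to RMSNorm: y = x * rsqrt(mean(x**2) + epsilon) * scale, with no learned bias.)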
948 | #------------------------------------------------------------------------------ 949 | class LayerNorm(nn.Module): 950 | """T5 Layer normalization operating on the last axis of the input data.""" 951 | epsilon: float = 1e-6 952 | dtype: Any = jnp.float32 953 | scale_init: Initializer = nn.initializers.ones 954 | 955 | @nn.compact 956 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 957 | """Applies layer normalization on the input.""" 958 | x = jnp.asarray(x, jnp.float32) 959 | features = x.shape[-1] 960 | mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) 961 | y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype) 962 | scale = param_with_axes( 963 | 'scale', self.scale_init, (features,), jnp.float32, axes=('embed',)) 964 | 965 | scale = jnp.asarray(scale, self.dtype) 966 | return y * scale 967 | 968 | def _canonicalize_axes(rank: int, axes: Axes) -> Iterable[int]: 969 | """Returns a tuple of deduplicated, sorted, and positive axes.""" 970 | if not isinstance(axes, Iterable): 971 | axes = (axes,) 972 | return tuple(set([rank + axis if axis < 0 else axis for axis in axes])) 973 | 974 | def _abs_sq(x): 975 | """Computes the elementwise square of the absolute value |x|^2.""" 976 | if jnp.iscomplexobj(x): 977 | return lax.square(lax.real(x)) + lax.square(lax.imag(x)) 978 | else: 979 | return lax.square(x) 980 | 981 | def _compute_stats(x: Array, axes: Axes, 982 | axis_name: Optional[str] = None, 983 | axis_index_groups: Any = None): 984 | """Computes mean and variance statistics. 985 | This implementation takes care of a few important details: 986 | - Computes in float32 precision for half precision inputs 987 | - mean and variance is computable in a single XLA fusion, 988 | by using Var = E[|x|^2] - |E[x]|^2 instead of Var = E[|x - E[x]|^2]). 989 | - Clips negative variances to zero which can happen due to 990 | roundoff errors. This avoids downstream NaNs. 991 | - Supports averaging across a parallel axis and subgroups of a parallel axis 992 | with a single `lax.pmean` call to avoid latency. 993 | """ 994 | # promote x to at least float32, this avoids half precision computation 995 | # but preserves double or complex floating points 996 | x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x))) 997 | mean = jnp.mean(x, axes) 998 | mean2 = jnp.mean(_abs_sq(x), axes) 999 | if axis_name is not None: 1000 | concatenated_mean = jnp.concatenate([mean, mean2]) 1001 | mean, mean2 = jnp.split( 1002 | lax.pmean( 1003 | concatenated_mean, 1004 | axis_name=axis_name, 1005 | axis_index_groups=axis_index_groups), 2) 1006 | # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due 1007 | # to floating point round-off errors. 1008 | var = jnp.maximum(0., mean2 - _abs_sq(mean)) 1009 | return mean, var 1010 | 1011 | def _normalize(mdl: Module, x: Array, mean: Array, var: Array, 1012 | reduction_axes: Axes, feature_axes: Axes, 1013 | dtype: Any, param_dtype: Any, 1014 | epsilon: float, 1015 | use_bias: bool, use_scale: bool, 1016 | bias_init: Initializer, 1017 | scale_init: Initializer): 1018 | """"Normalizes the input of a normalization layer and optionally applies a learned scale and bias. 1019 | A seperate bias and scale is learned for each feature as specified by feature_axes. 
1020 | """ 1021 | reduction_axes = _canonicalize_axes(x.ndim, reduction_axes) 1022 | feature_axes = _canonicalize_axes(x.ndim, feature_axes) 1023 | stats_shape = list(x.shape) 1024 | for axis in reduction_axes: 1025 | stats_shape[axis] = 1 1026 | mean = mean.reshape(stats_shape) 1027 | var = var.reshape(stats_shape) 1028 | feature_shape = [1] * x.ndim 1029 | reduced_feature_shape = [] 1030 | for ax in feature_axes: 1031 | feature_shape[ax] = x.shape[ax] 1032 | reduced_feature_shape.append(x.shape[ax]) 1033 | y = x - mean 1034 | mul = lax.rsqrt(var + epsilon) 1035 | if use_scale: 1036 | scale = param_with_axes('scale', scale_init, reduced_feature_shape, 1037 | param_dtype, axes=('axis_0',)).reshape(feature_shape) 1038 | mul *= scale 1039 | y *= mul 1040 | if use_bias: 1041 | bias = param_with_axes('bias', bias_init, reduced_feature_shape, 1042 | param_dtype, axes=('axis_0',)).reshape(feature_shape) 1043 | y += bias 1044 | return jnp.asarray(y, dtype) 1045 | 1046 | class GroupNorm(Module): 1047 | num_groups: Optional[int] = 32 1048 | group_size: Optional[int] = None 1049 | epsilon: float = 1e-6 1050 | dtype: Any = jnp.float32 1051 | param_dtype: Any = jnp.float32 1052 | use_bias: bool = True 1053 | use_scale: bool = True 1054 | bias_init: Initializer = nn.initializers.zeros 1055 | scale_init: Initializer = nn.initializers.ones 1056 | 1057 | @nn.compact 1058 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 1059 | reduction_axes = list(range(1, x.ndim - 1)) + [-1] 1060 | feature_axes = (-1,) 1061 | 1062 | if ((self.num_groups is None and self.group_size is None) or 1063 | (self.num_groups is not None and self.group_size is not None)): 1064 | raise ValueError('Either `num_groups` or `group_size` should be ' 1065 | 'specified, but not both of them.') 1066 | num_groups = self.num_groups 1067 | 1068 | channels = x.shape[-1] 1069 | if self.group_size is not None: 1070 | if channels % self.group_size != 0: 1071 | raise ValueError('Number of channels ({}) is not multiple of the ' 1072 | 'group size ({}).'.format(channels, self.group_size)) 1073 | num_groups = channels // self.group_size 1074 | 1075 | if num_groups <= 0 or channels % num_groups != 0: 1076 | raise ValueError('Number of groups ({}) does not divide the number' 1077 | ' of channels ({}).'.format(num_groups, channels)) 1078 | 1079 | group_size = x.shape[-1] // num_groups 1080 | group_shape = x.shape[:-1] + (num_groups, group_size) 1081 | 1082 | def broadcast_stat(stat): 1083 | stat = jnp.broadcast_to(stat[..., None], (x.shape[0], num_groups, group_size)) 1084 | return stat.reshape((x.shape[0], num_groups * group_size)) 1085 | 1086 | # TODO suport axis_name for model parallelism? 1087 | mean, var = _compute_stats(x.reshape(group_shape), reduction_axes, None, None) 1088 | mean = broadcast_stat(mean) 1089 | var = broadcast_stat(var) 1090 | 1091 | return _normalize( 1092 | self, x, mean, var, reduction_axes[:-1], feature_axes, 1093 | self.dtype, self.param_dtype, self.epsilon, 1094 | self.use_bias, self.use_scale, 1095 | self.bias_init, self.scale_init) 1096 | 1097 | #------------------------------------------------------------------------------ 1098 | # Mask-making utility functions. 1099 | #------------------------------------------------------------------------------ 1100 | def make_attention_mask(query_input: Array, 1101 | key_input: Array, 1102 | pairwise_fn: Callable = jnp.multiply, 1103 | extra_batch_dims: int = 0, 1104 | dtype: DType = jnp.float32) -> Array: 1105 | """Mask-making helper for attention weights. 
1106 | 1107 | In case of 1d inputs (i.e., `[batch, len_q]`, `[batch, len_kv]`, the 1108 | attention weights will be `[batch, heads, len_q, len_kv]` and this 1109 | function will produce `[batch, 1, len_q, len_kv]`. 1110 | 1111 | Args: 1112 | query_input: a batched, flat input of query_length size 1113 | key_input: a batched, flat input of key_length size 1114 | pairwise_fn: broadcasting elementwise comparison function 1115 | extra_batch_dims: number of extra batch dims to add singleton axes for, none 1116 | by default 1117 | dtype: mask return dtype 1118 | 1119 | Returns: 1120 | A `[batch, 1, len_q, len_kv]` shaped mask for 1d attention. 1121 | """ 1122 | # [batch, len_q, len_kv] 1123 | mask = pairwise_fn( 1124 | # [batch, len_q] -> [batch, len_q, 1] 1125 | jnp.expand_dims(query_input, axis=-1), 1126 | # [batch, len_q] -> [batch, 1, len_kv] 1127 | jnp.expand_dims(key_input, axis=-2)) 1128 | 1129 | # [batch, 1, len_q, len_kv]. This creates the head dim. 1130 | mask = jnp.expand_dims(mask, axis=-3) 1131 | mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims))) 1132 | return mask.astype(dtype) 1133 | 1134 | 1135 | def make_causal_mask(x: Array, 1136 | extra_batch_dims: int = 0, 1137 | dtype: DType = jnp.float32) -> Array: 1138 | """Make a causal mask for self-attention. 1139 | 1140 | In case of 1d inputs (i.e., `[batch, len]`, the self-attention weights 1141 | will be `[batch, heads, len, len]` and this function will produce a 1142 | causal mask of shape `[batch, 1, len, len]`. 1143 | 1144 | Note that a causal mask does not depend on the values of x; it only depends on 1145 | the shape. If x has padding elements, they will not be treated in a special 1146 | manner. 1147 | 1148 | Args: 1149 | x: input array of shape `[batch, len]` 1150 | extra_batch_dims: number of batch dims to add singleton axes for, none by 1151 | default 1152 | dtype: mask return dtype 1153 | 1154 | Returns: 1155 | A `[batch, 1, len, len]` shaped causal mask for 1d attention. 1156 | """ 1157 | idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape) 1158 | return make_attention_mask( 1159 | idxs, 1160 | idxs, 1161 | jnp.greater_equal, 1162 | extra_batch_dims=extra_batch_dims, 1163 | dtype=dtype) 1164 | 1165 | 1166 | def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32): 1167 | """Combine attention masks. 1168 | 1169 | Args: 1170 | *masks: set of attention mask arguments to combine, some can be None. 1171 | dtype: final mask dtype 1172 | 1173 | Returns: 1174 | Combined mask, reduced by logical and, returns None if no masks given. 1175 | """ 1176 | masks = [m for m in masks if m is not None] 1177 | if not masks: 1178 | return None 1179 | assert all(map(lambda x: x.ndim == masks[0].ndim, masks)), ( 1180 | f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}') 1181 | mask, *other_masks = masks 1182 | for other_mask in other_masks: 1183 | mask = jnp.logical_and(mask, other_mask) 1184 | return mask.astype(dtype) 1185 | 1186 | 1187 | def combine_biases(*masks: Optional[Array]): 1188 | """Combine attention biases. 1189 | 1190 | Args: 1191 | *masks: set of attention bias arguments to combine, some can be None. 1192 | 1193 | Returns: 1194 | Combined mask, reduced by summation, returns None if no masks given. 
1195 | """ 1196 | masks = [m for m in masks if m is not None] 1197 | if not masks: 1198 | return None 1199 | assert all(map(lambda x: x.ndim == masks[0].ndim, masks)), ( 1200 | f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}') 1201 | mask, *other_masks = masks 1202 | for other_mask in other_masks: 1203 | mask = mask + other_mask 1204 | return mask 1205 | 1206 | 1207 | def make_decoder_mask(decoder_target_tokens: Array, 1208 | dtype: DType, 1209 | decoder_causal_attention: Optional[Array] = None, 1210 | decoder_segment_ids: Optional[Array] = None) -> Array: 1211 | """Compute the self-attention mask for a decoder. 1212 | 1213 | Decoder mask is formed by combining a causal mask, a padding mask and an 1214 | optional packing mask. If decoder_causal_attention is passed, it makes the 1215 | masking non-causal for positions that have value of 1. 1216 | 1217 | A prefix LM is applied to a dataset which has a notion of "inputs" and 1218 | "targets", e.g., a machine translation task. The inputs and targets are 1219 | concatenated to form a new target. `decoder_target_tokens` is the concatenated 1220 | decoder output tokens. 1221 | 1222 | The "inputs" portion of the concatenated sequence can attend to other "inputs" 1223 | tokens even for those at a later time steps. In order to control this 1224 | behavior, `decoder_causal_attention` is necessary. This is a binary mask with 1225 | a value of 1 indicating that the position belonged to "inputs" portion of the 1226 | original dataset. 1227 | 1228 | Example: 1229 | 1230 | Suppose we have a dataset with two examples. 1231 | 1232 | ds = [{"inputs": [6, 7], "targets": [8]}, 1233 | {"inputs": [3, 4], "targets": [5]}] 1234 | 1235 | After the data preprocessing with packing, the two examples are packed into 1236 | one example with the following three fields (some fields are skipped for 1237 | simplicity). 1238 | 1239 | decoder_target_tokens = [[6, 7, 8, 3, 4, 5, 0]] 1240 | decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]] 1241 | decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]] 1242 | 1243 | where each array has [batch, length] shape with batch size being 1. Then, 1244 | this function computes the following mask. 1245 | 1246 | mask = [[[[1, 1, 0, 0, 0, 0, 0], 1247 | [1, 1, 0, 0, 0, 0, 0], 1248 | [1, 1, 1, 0, 0, 0, 0], 1249 | [0, 0, 0, 1, 1, 0, 0], 1250 | [0, 0, 0, 1, 1, 0, 0], 1251 | [0, 0, 0, 1, 1, 1, 0], 1252 | [0, 0, 0, 0, 0, 0, 0]]]] 1253 | 1254 | mask[b, 1, :, :] represents the mask for the example `b` in the batch. 1255 | Because mask is for a self-attention layer, the mask's shape is a square of 1256 | shape [query length, key length]. 1257 | 1258 | mask[b, 1, i, j] = 1 means that the query token at position i can attend to 1259 | the key token at position j. 1260 | 1261 | Args: 1262 | decoder_target_tokens: decoder output tokens. [batch, length] 1263 | dtype: dtype of the output mask. 1264 | decoder_causal_attention: a binary mask indicating which position should 1265 | only attend to earlier positions in the sequence. Others will attend 1266 | bidirectionally. [batch, length] 1267 | decoder_segment_ids: decoder segmentation info for packed examples. [batch, 1268 | length] 1269 | 1270 | Returns: 1271 | the combined decoder mask. 1272 | """ 1273 | masks = [] 1274 | # The same mask is applied to all attention heads. So the head dimension is 1, 1275 | # i.e., the mask will be broadcast along the heads dim. 
1276 | # [batch, 1, length, length] 1277 | causal_mask = make_causal_mask(decoder_target_tokens, dtype=dtype) 1278 | 1279 | # Positions with value 1 in `decoder_causal_attneition` can attend 1280 | # bidirectionally. 1281 | if decoder_causal_attention is not None: 1282 | # [batch, 1, lengtlength] 1283 | inputs_mask = make_attention_mask( 1284 | decoder_causal_attention, 1285 | decoder_causal_attention, 1286 | jnp.logical_and, 1287 | dtype=dtype) 1288 | masks.append(jnp.logical_or(causal_mask, inputs_mask).astype(dtype)) 1289 | else: 1290 | masks.append(causal_mask) 1291 | 1292 | # Padding mask. 1293 | masks.append( 1294 | make_attention_mask( 1295 | decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=dtype)) 1296 | 1297 | # Packing mask 1298 | if decoder_segment_ids is not None: 1299 | masks.append( 1300 | make_attention_mask( 1301 | decoder_segment_ids, decoder_segment_ids, jnp.equal, dtype=dtype)) 1302 | 1303 | return combine_masks(*masks, dtype=dtype) 1304 | -------------------------------------------------------------------------------- /uio/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union, Sequence 2 | 3 | import jax 4 | import numpy as np 5 | import jax.numpy as jnp 6 | from flax.serialization import from_bytes 7 | 8 | # Constants used by all UnifiedIO Models 9 | import torch 10 | from torchvision.transforms import InterpolationMode 11 | from torchvision.transforms import functional as F 12 | 13 | vocab_size = 33100 14 | BIN_START = vocab_size - 1100 15 | NUM_DETECTION_BIN = 1000 16 | VOCAB_START = 100 17 | IMAGE_INPUT_SIZE = [384, 384] 18 | IMAGE_INPUT_PATCH_SIZE = 16 19 | 20 | IMAGE_TARGET_SIZE = [256, 256] 21 | 22 | 23 | def load_checkpoint(checkpoint): 24 | """Load a bin file as a tree of jax arrays""" 25 | with open(checkpoint, "rb") as state_f: 26 | state = from_bytes(None, state_f.read()) 27 | state = jax.tree_util.tree_map(jnp.array, state) 28 | return state 29 | 30 | 31 | def transpose_lists(lsts): 32 | """Transpose a list of lists.""" 33 | return [list(i) for i in zip(*lsts)] 34 | 35 | 36 | def region_to_tokens(box, img_w, img_h) -> List[str]: 37 | """Convert a region into a text sequence 38 | 39 | :param box: [x1, y1, x2, y2] non-normalized bounding box 40 | :param img_w: image width 41 | :param img_h: image height 42 | :return: text tokens representation of the region 43 | """ 44 | # Convert to yx format normalized to the padded input image. 
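  # Dividing by the longer image side mirrors resize_and_pad below, which scales the
  # image so its longer side fits the square model input and pads the rest; each
  # normalized coordinate is then quantized into one of NUM_DETECTION_BIN bins that
  # are rendered as location tokens.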
45 | scale = max(img_w, img_h) 46 | box = np.array([box[1], box[0], box[3], box[2]]) / scale 47 | # Quantize 48 | quantized_boxes = ((NUM_DETECTION_BIN-1) * box).astype(np.int32) 49 | # Convert to tokens 50 | return [f"" for i in quantized_boxes] 51 | 52 | 53 | def tokens_to_regions(predicted_tokens, image_size, token_per_label=4) -> Tuple[List[str], np.ndarray]: 54 | """Convert tokens into a list of image locations and labels 55 | 56 | :param predicted_tokens: Integer tokens from UnifiedIO 57 | :param image_size: original image size 58 | :param token_per_label: number of location tokens preceding each object label 59 | :return: 60 | labels: List[str] of object labels 61 | locations: np.ndarray [n_objects, token_per_label] image coordinates 62 | """ 63 | predicted_tokens = np.array(predicted_tokens) 64 | locations = [] 65 | labels = [] 66 | cur = 0 67 | while True: 68 | if cur >= len(predicted_tokens) or predicted_tokens[cur] == 1: 69 | # end of sequence 70 | break 71 | if not np.all(predicted_tokens[cur:cur+token_per_label] > BIN_START): 72 | # error, should be a list of locations then label 73 | raise ValueError() 74 | locations.append(vocab_size-predicted_tokens[cur:cur+token_per_label] - 100) 75 | cur += token_per_label 76 | label_end = cur 77 | while label_end < len(predicted_tokens) and 1 < predicted_tokens[label_end] <= BIN_START: 78 | label_end += 1 79 | labels.append(predicted_tokens[cur:label_end]) 80 | cur = label_end 81 | 82 | locations = np.array(locations) 83 | locations = locations.reshape((-1, 2))[:, ::-1].reshape((-1, token_per_label)) # [yx to xy] 84 | # Account for image resizing 85 | factor = max(image_size) 86 | locations = locations * (factor / 1000) 87 | return labels, locations 88 | 89 | 90 | def extract_keypoints(tokens, tokenizer, image_size): 91 | """Read keypoints from UnifiedIO output 92 | 93 | :param tokens: integer tokens generated 94 | :param tokenizer: T5Tokenizer 95 | :param image_size: size of the input image 96 | :return: 97 | points: [17, 2] keypoint coordinates 98 | labels: [17] integer labels between 0 and 2 99 | invalid: bool, true if `tokens` did not correctly conform the keypoint output format, 100 | if missing/invalid points will be filled by the mean coordiantes of the visible points 101 | """ 102 | labels, points = tokens_to_regions(tokens, image_size, token_per_label=2) 103 | points = np.array(points) 104 | invalid = False # Is this text a valid keypoint prediction 105 | 106 | # Convert label to integers 107 | for i, l in enumerate(labels): 108 | l = tokenizer.decode(l) 109 | try: 110 | l = int(l) - 1 111 | if not (0 <= l <= 2): 112 | invalid = True 113 | l = 0 114 | except ValueError: 115 | invalid = True 116 | l = 0 117 | labels[i] = l 118 | labels = np.array(labels) 119 | if np.sum(labels) == 0: 120 | # No visible points predicted 121 | return None, None, invalid 122 | 123 | # replace non visible point with mean so we do something non-crazy if the 124 | # GT turns out to be `visible` 125 | mean = np.mean(points[labels != 0], 0, keepdims=True) 126 | points[labels == 0] = mean 127 | 128 | if len(points) > 17: 129 | # Truncate if we generated extra for some reason 130 | invalid = True 131 | points = points[:17] 132 | labels = labels[:17] 133 | elif len(points) < 17: 134 | # Replace with mean if we generated too few points 135 | invalid = True 136 | mean = np.mean(points, 0, keepdims=True) 137 | n = 17 - len(points) 138 | points = np.concatenate([points, np.tile(mean, (n, 1))], 0) 139 | labels = np.concatenate([labels, np.zeros((n,), 
labels.dtype)]) 140 | 141 | assert points.shape == (17, 2) 142 | return points, labels, invalid 143 | 144 | 145 | def clean_mask(mask, min_size): 146 | """Remove connected components that have less than `min_size` pixels""" 147 | from scipy import ndimage 148 | label, n_obj = ndimage.measurements.label(mask) 149 | cleaned = None 150 | for c in range(1, n_obj+1): 151 | is_c = label == c 152 | if np.sum(is_c) > min_size: 153 | if cleaned is None: 154 | cleaned = is_c 155 | else: 156 | cleaned = np.logical_or(cleaned, is_c) 157 | return cleaned 158 | 159 | 160 | def extract_segmentation_masks(img, segmention_mode="coarse_color") -> List[np.ndarray]: 161 | """Extract a list of binary segmentation masks from `img`""" 162 | if not np.issubdtype(img.dtype, np.integer): 163 | img = (img*255).astype(np.uint8) 164 | 165 | if segmention_mode == "any_pixel": 166 | # Assume there is only a single instance 167 | is_instance = img.mean(-1) > 30 168 | return [is_instance] 169 | 170 | elif segmention_mode == "coarse_color": 171 | # Find instances based on coarse-grained color detection, and then clean them for 172 | # extra/floating background pixels. Pretty slow, I think because `clean_mask` is slow 173 | w, h = img.shape[:2] 174 | img = np.array(img).reshape((-1, 3)) # [n_pixels, 3] 175 | 176 | img = img.astype(np.float64) 177 | means = img.mean(axis=-1) 178 | mean_diff = img - means[:, None] 179 | 180 | # Background pixels are black or nearly black 181 | background = means <= 30 182 | 183 | # First object pixels are gray/white, we allow gray since the VAE will often put gray 184 | # pixels around the white blobs it is supposed to predict 185 | # We detect such pixels if all RGB values are close to the mean 186 | first_obj = np.logical_and(np.logical_not(background), np.abs(mean_diff).sum(-1) < 100) 187 | used = np.logical_and(background, first_obj) # Pixel already assigned 188 | out = [] 189 | first_obj = clean_mask(first_obj, 10) 190 | if np.any(first_obj): 191 | out.append(first_obj) 192 | 193 | color = np.argmax(img, -1) 194 | for c in range(3): 195 | # Find pixels if each color they must have that color's value 196 | # be the largest RGB value be large then the mean by a reasonable margin 197 | candidate = np.logical_and(np.logical_not(used), color == c) 198 | color_map = np.logical_and(candidate, np.abs(mean_diff[:, c]) > 40) 199 | color_map = clean_mask(color_map, 10) 200 | if np.any(color_map): 201 | out.append(color_map) 202 | used = np.logical_and(used, color_map) 203 | return [x.reshape(w, h) for x in out] 204 | 205 | else: 206 | raise NotImplementedError() 207 | 208 | 209 | def _resize(image: np.ndarray, target_size: Sequence[int], 210 | mode: Union[str, InterpolationMode]="bilinear", antialias=True): 211 | if isinstance(mode, str): 212 | mode = InterpolationMode(mode) 213 | if image.dtype == np.uint8: 214 | image = image / 255.0 215 | image = F.resize(torch.as_tensor(image.transpose((2, 0, 1))), target_size, antialias=antialias, 216 | interpolation=mode) 217 | image = np.transpose(image.numpy().astype(np.float32), [1, 2, 0]) 218 | return image 219 | 220 | 221 | def resize_and_pad(image: np.ndarray, size) -> Tuple[np.ndarray, np.ndarray]: 222 | """Resize and pad `image` to `size` and returns a mask over pixels introduced by padding""" 223 | h, w = image.shape[:2] 224 | scale = size[0] / max(h, w) 225 | if scale != 1.0: 226 | scale_to = (int(h*scale), int(w*scale)) 227 | image = _resize(image, scale_to) 228 | else: 229 | scale_to = (h, w) 230 | image_mask = np.zeros(size, dtype=np.bool) 231 | 
image_mask[:scale_to[0], :scale_to[1]] = True 232 | image = np.pad( 233 | image, 234 | [[0, size[0] - scale_to[0]], 235 | [0, size[1] - scale_to[1]], 236 | [0, 0] 237 | ] 238 | ) 239 | return image, image_mask 240 | 241 | 242 | def undo_image_preprocessing(image, original_size, mode="nearest", antialias=False): 243 | """Resize image generated from UnifiedIO to the size of `original_size`, this undoes 244 | the padding and down-scaling done in `preprocess_image`. 245 | 246 | By default, we use near-neighbor interpolation and not anti-aliasing since that makes the most 247 | sense for tasks involving non-natural images like segmentation and surface normals 248 | """ 249 | h, w = original_size 250 | ratio = image.shape[0] / max(w, h) 251 | # undo the padding 252 | if h > w: 253 | image_rescale = image[:, :int(ratio*w)] 254 | else: 255 | image_rescale = image[:int(ratio*h), :] 256 | # Undo the scaling 257 | return _resize(np.copy(image_rescale), (h, w), mode=mode, antialias=antialias) 258 | 259 | 260 | def preprocess_image(input_image, mask_region=None) -> Tuple[np.ndarray, np.ndarray]: 261 | """Preprocess an image for processing UnifiedIO 262 | 263 | :param input_image: image array in [h, w, 3] in float or uint8 format 264 | :param mask_region: Optional region to include in the image mask, used for image inpaintin 265 | :return: preprocessed image and image-patch mask 266 | """ 267 | n_patches = 384//16 268 | if input_image is not None: 269 | original_size = input_image.shape 270 | input_image, image_mask = resize_and_pad(input_image, IMAGE_INPUT_SIZE) 271 | 272 | if mask_region is not None: 273 | region = mask_region / max(original_size[:2]) * max(input_image.shape[:2]) 274 | x1, y1, x2, y2 = np.round(region).astype(np.int32) 275 | region_mask = np.ones_like(image_mask) 276 | region_mask[y1:y2, x1:x2] = 0 277 | image_mask = image_mask*region_mask 278 | 279 | # Convert mask over pixels to mask of image patches 280 | image_mask = _resize( 281 | np.expand_dims(image_mask, 2), [n_patches, n_patches], 282 | InterpolationMode.NEAREST, antialias=False 283 | ) 284 | image_mask = image_mask.reshape((-1,)).astype(np.int32) 285 | else: 286 | if mask_region is not None: 287 | raise ValueError() 288 | # Masked, dummy values since this code does not support skipping the image 289 | input_image = np.zeros((384, 384, 3), np.float32) 290 | image_mask = np.zeros((n_patches*n_patches, ), dtype=np.int32) 291 | input_image = normalize_image(input_image) 292 | return input_image, image_mask 293 | 294 | 295 | def preprocess_target_image(target_image) -> Tuple[np.ndarray, np.ndarray]: 296 | """Preprocess a target image for processing UnifiedIO 297 | 298 | :param target_image: image array in [h, w, 3] in float or uint8 format 299 | :return: preprocessed image and image-patch mask 300 | """ 301 | n_patches = IMAGE_TARGET_SIZE[0]//16 302 | if target_image is not None: 303 | input_image, image_mask = resize_and_pad(target_image, IMAGE_TARGET_SIZE) 304 | 305 | # Convert mask over pixels to mask of image patches 306 | image_mask = _resize( 307 | np.expand_dims(image_mask, 2), [n_patches, n_patches], 308 | InterpolationMode.NEAREST, antialias=False 309 | ) 310 | image_mask = image_mask.reshape((-1,)).astype(np.int32) 311 | else: 312 | input_image = np.zeros(IMAGE_TARGET_SIZE + [3], np.float32) 313 | image_mask = np.zeros((n_patches*n_patches, ), dtype=np.int32) 314 | input_image = input_image * 2 - 1 # VAE pre-processing 315 | return input_image, image_mask 316 | 317 | 318 | BIAS = np.array([0.485, 0.456, 0.406]) 319 | 
SCALE = np.array([0.229, 0.224, 0.225]) 320 | 321 | 322 | def normalize_image(image) -> np.ndarray: 323 | """Pixel normalization (ImageNet mean/std) used by UnifiedIO""" 324 | image -= BIAS.reshape((1, 1, 3)) 325 | image /= SCALE.reshape((1, 1, 3)) 326 | return image 327 | 328 | --------------------------------------------------------------------------------
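# Usage sketch (not from the repository): a minimal, hedged example of how the helpers
# in uio/utils.py fit together for a single image. It assumes the repository root is on
# PYTHONPATH and the pinned requirements.txt versions (numpy/torch/torchvision) are
# installed; the random array below is only a stand-in for a real RGB photo.
import numpy as np
from uio import utils

image = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)   # [h, w, 3] uint8

# Resize/pad to 384x384, normalize with ImageNet mean/std, and build the patch mask.
input_image, image_mask = utils.preprocess_image(image)
print(input_image.shape)   # (384, 384, 3)
print(image_mask.shape)    # (576,) -- one entry per 16x16 patch of the 384x384 input

# Encode a bounding box ([x1, y1, x2, y2] in pixels) as quantized location tokens for
# the text stream; one token is produced per coordinate.
tokens = utils.region_to_tokens([100, 50, 300, 200], img_w=640, img_h=480)
print(len(tokens))         # 4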