├── LICENSE ├── README.md ├── requirements-exact.txt ├── requirements.txt ├── setup.py ├── teaser-short.svg └── uio2 ├── __init__.py ├── audio_embedder.py ├── audio_utils.py ├── audio_vqgan.py ├── config.py ├── convert_checkpoint.py ├── data_utils.py ├── get_modality_processor.py ├── get_model.py ├── hifigan ├── README.md ├── __init__.py ├── checkpoints │ ├── config.json │ └── g_00930000 ├── models.py └── utils.py ├── image_embedder.py ├── image_vqgan.py ├── input_modalities.py ├── layers.py ├── model.py ├── perceiver.py ├── preprocessing.py ├── prompt.py ├── runner.py ├── seq_features.py ├── target_modalities.py ├── utils.py ├── video_utils.py └── vocabulary.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UnifiedIO-2.PyTorch 2 | 3 | This repo is an official pytorch port of [UnifiedIO-2](https://unified-io-2.allenai.org/). The original jax code can be found 4 | [here](https://github.com/allenai/unified-io-2). UnifiedIO 2 is a multi-modal multi-task model capable of performing a wide 5 | range of tasks. 
6 | 7 | ![test](teaser-short.svg) 8 | 9 | ## Installation 10 | Install [pytorch](https://pytorch.org/) following the recommendation for your system. Then install with 11 | 12 | ``` 13 | git clone https://github.com/allenai/unified-io-2.pytorch 14 | cd unified-io-2.pytorch 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | ## Loading the model 19 | 20 | Load the model with 21 | ``` 22 | from uio2.model import UnifiedIOModel 23 | model = UnifiedIOModel.from_pretrained("allenai/uio2-large") 24 | ``` 25 | This loads the large (1B) model; load the XL (3B) or XXL (7B) models with 26 | `allenai/uio2-xl` or `allenai/uio2-xxl`. 27 | 28 | This model requires pre-processed tensor inputs. Pre-processing is done by `UnifiedIOPreprocessor`: 29 | 30 | ``` 31 | from uio2.preprocessing import UnifiedIOPreprocessor 32 | preprocessor = UnifiedIOPreprocessor.from_pretrained("allenai/uio2-preprocessor", tokenizer="/path/to/tokenizer") 33 | ``` 34 | 35 | Here "/path/to/tokenizer" needs to point to the LLaMA tokenizer file. The tokenizer 36 | file needs to be downloaded manually from [LLaMA](https://llama.meta.com/). 37 | 38 | You can remove modality-specific components you don't need. For example, 39 | if you only want to do text-to-image tasks, run: 40 | 41 | ``` 42 | model.set_modalities(input_modalities=["text"], target_modalities=["image"]) 43 | ``` 44 | 45 | 46 | This will remove some unneeded parameters from the model. 47 | 48 | ### Initializing from Scratch 49 | The model can also be built from scratch by directly using a config: 50 | 51 | ``` 52 | from uio2 import config 53 | preprocessor = UnifiedIOPreprocessor.from_config(config.LARGE, "/path/to/tokenizer") 54 | model = UnifiedIOModel(config.LARGE) 55 | ``` 56 | 57 | ### Using bfloat16 58 | The model can be run in `bfloat16`; typically we have done this while keeping the ViTs 59 | and VQGANs as `float32`. To convert the model to this format, run: 60 | ``` 61 | model.to_dtype(torch.bfloat16, vit_dtype=torch.float32, vqgan_dtype=torch.float32) 62 | ``` 63 | 64 | We provide pre-trained models in this format to reduce bandwidth/memory requirements 65 | when downloading/loading the models: 66 | 67 | ``` 68 | model = UnifiedIOModel.from_pretrained("allenai/uio2-large-bfloat16") 69 | ``` 70 | 71 | ## Usage 72 | ### Generation 73 | Do text generation with 74 | 75 | ``` 76 | from uio2.preprocessing import build_batch 77 | preprocessed_example = preprocessor(text_inputs="What color is the sky?", target_modality="text") 78 | batch = build_batch([preprocessed_example], device=model.device) 79 | tokens = model.generate(batch, modality="text", max_new_tokens=128) 80 | ``` 81 | 82 | `modality` can be set to `"image"` or `"audio"`. Image will return a `[256, 256, 3]` image, and 83 | audio will return a `[128, 256, 1]` mel-spectrogram. See `UnifiedIOPreprocessor` for the various 84 | kinds of input the model supports. 85 | 86 | To see many other examples of generation and how to best configure the model and post-process 87 | the output, see `TaskRunner`: 88 | 89 | ``` 90 | from uio2.runner import TaskRunner 91 | 92 | runner = TaskRunner(model, preprocessor) 93 | image = runner.image_generation("a cat") 94 | waveform = runner.audio_generation("dogs barking") 95 | box = runner.refexp("/path/to/image", "the green car") 96 | keypoint = runner.keypoint("/path/to/image") 97 | # And many more, see TaskRunner 98 | ``` 99 | 100 | ### Answer Scoring 101 | `model.score_answer_options` can compute the loss of several possible 102 | outputs given one set of inputs. See `TaskRunner.categorization` or `TaskRunner.box_categorization` for 103 | examples of how to use it. 104 | 105 | ``` 106 | runner.categorization("/path/to/image", ["cat", "dog"]) 107 | ``` 108 | 109 | 110 | ### Computing the Loss 111 | Calling the model will produce logits, masks, and targets for each modality. 112 | If using `forward`, at least one target modality should be set when calling the 113 | preprocessor. 114 | 115 | The loss for an example can then be computed like this: 116 | 117 | ``` 118 | from torch.nn import functional as F 119 | from uio2.preprocessing import build_batch 120 | preprocessed_example = preprocessor( 121 | text_inputs="What is 1+1?", text_targets="2", target_modality="text") 122 | batch = build_batch([preprocessed_example], device=model.device) 123 | out = model(batch) 124 | total_loss = 0 125 | for modality, (logits, targets, mask) in out.items(): 126 | losses = F.cross_entropy( 127 | logits.view(-1, logits.shape[-1]), targets.view(-1).to(torch.long), reduction="none") 128 | total_loss += (losses.reshape(logits.shape[:2])*mask).sum()/mask.sum() 129 | print(total_loss) 130 | ``` 131 | 132 | The `preprocessor` supports inputs and outputs for all modalities. 133 | 134 | To train the model, run `preprocessor` and `build_batch` in a DataLoader and then 135 | backprop on the loss, as sketched below. 136 |
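For illustration only, a minimal training loop might look like the following sketch. The toy question/answer pairs, the optimizer, and the hyper-parameters are placeholders (not the recipe used to train UnifiedIO 2), and `preprocessor`/`model` are assumed to be built as shown above:

```
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from uio2.preprocessing import build_batch

# Toy text-to-text data; any modality combination the preprocessor supports works the same way
pairs = [("What is 1+1?", "2"), ("What is 2+2?", "4")]
examples = [
    preprocessor(text_inputs=q, text_targets=a, target_modality="text")
    for q, a in pairs
]

# Keep the pre-processed dicts as-is and let build_batch do the collation
loader = DataLoader(examples, batch_size=2, shuffle=True, collate_fn=lambda exs: exs)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

for exs in loader:
    batch = build_batch(exs, device=model.device)
    out = model(batch)
    total_loss = 0
    for modality, (logits, targets, mask) in out.items():
        losses = F.cross_entropy(
            logits.view(-1, logits.shape[-1]), targets.view(-1).to(torch.long), reduction="none")
        total_loss += (losses.reshape(logits.shape[:2])*mask).sum()/mask.sum()
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
```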
137 | ## Citation 138 | 139 | ```bibtex 140 | @article{lu2023uio2, 141 | title = {Unified-IO 2: Scaling Autoregressive Multimodal Models with Vision, Language, Audio, and Action}, 142 | author = {Jiasen Lu and Christopher Clark and Sangho Lee and Zichen Zhang and Savya Khosla and Ryan Marten and Derek Hoiem and Aniruddha Kembhavi}, 143 | journal = {arXiv preprint arXiv:2312.17172}, 144 | year = {2023}, 145 | } 146 | ``` 147 | -------------------------------------------------------------------------------- /requirements-exact.txt: -------------------------------------------------------------------------------- 1 | transformers==4.37.2 2 | numpy==1.24.3 3 | scipy==1.12.0 4 | tensorflow==2.15.0 5 | sentencepiece==0.1.99 6 | einops==0.7.0 7 | protobuf==3.20.* 8 | tqdm==4.66.1 9 | pillow==10.2.0 10 | librosa==0.10.1 11 | scikit-video==1.1.11 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Need new caching abstraction from transformers 4.36.0 2 | # Later versions can hit incompatibilities with how we have set up generation 3 | transformers==v4.36.0 4 | numpy 5 | scipy 6 | tensorflow # Some of our pre-processing code is still in tensorflow 7 | sentencepiece # For the tokenizer 8 | einops 9 | protobuf==3.20.* # Downgrade, new versions cannot load the tokenizer 10 | tqdm 11 | pillow # For image processing 12 | librosa # For audio processing 13 | scikit-video # For video processing -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree.
3 | 4 | import os 5 | 6 | from setuptools import find_packages, setup 7 | 8 | _PATH_ROOT = os.path.dirname(__file__) 9 | 10 | with open(os.path.join(_PATH_ROOT, "README.md")) as fo: 11 | readme = fo.read() 12 | 13 | with open(os.path.join(_PATH_ROOT, "requirements.txt")) as fo: 14 | requirements = [x.strip().split()[0] for x in fo.readlines() if x.strip()] 15 | 16 | 17 | setup( 18 | name="Unified-IO-2-PyTorch", 19 | version="0.1.0", 20 | description="A multi-task multi-modal model", 21 | author="UnifiedIO Team", 22 | url="https://github.com/allenai/unified-io-2.pytorch", 23 | install_requires=requirements, 24 | packages=find_packages(), 25 | long_description=readme, 26 | long_description_content_type="text/markdown", 27 | ) -------------------------------------------------------------------------------- /uio2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/unified-io-2.pytorch/6e487bec6be8f9b909453a5f9833c49914f4a777/uio2/__init__.py -------------------------------------------------------------------------------- /uio2/audio_embedder.py: -------------------------------------------------------------------------------- 1 | """Audio ViT that builds features from spectograms""" 2 | from typing import Any, Optional 3 | import torch 4 | 5 | from uio2 import layers 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | 10 | class MLP(nn.Module): 11 | def __init__(self, config): 12 | super().__init__() 13 | self.config = config 14 | self.fc1 = nn.Linear(config.emb_dim, config.mlp_dim, bias=True) 15 | self.gelu = nn.GELU(approximate='tanh') # The uio2 jax code used the tanh approximation. 16 | self.fc2 = nn.Linear(config.mlp_dim, config.emb_dim, bias=True) 17 | 18 | def forward(self, x): 19 | x = self.fc1(x) 20 | x = self.gelu(x) 21 | x = self.fc2(x) 22 | return x 23 | 24 | 25 | class MultiHeadDotProductAttention(nn.Module): 26 | def __init__( 27 | self, 28 | emb_dim, 29 | num_heads: int, 30 | head_dim: int, 31 | dropout_rate: float = 0., 32 | float32_logits: bool = False # computes logits in float32 for stability. 
33 | ): 34 | super().__init__() 35 | self.num_heads = num_heads 36 | self.head_dim = head_dim 37 | assert emb_dim == num_heads * head_dim, "embed_dim must be divisible by num_heads" 38 | self.scale = self.head_dim ** -0.5 39 | self.dropout_rate = dropout_rate 40 | self.float32_logits = float32_logits 41 | 42 | self.query_in_proj_weight = nn.Parameter(torch.randn(emb_dim, emb_dim) * self.scale) 43 | self.query_in_proj_bias = nn.Parameter(torch.zeros(emb_dim)) 44 | self.key_in_proj_weight = nn.Parameter(torch.randn(emb_dim, emb_dim) * self.scale) 45 | self.key_in_proj_bias = nn.Parameter(torch.zeros(emb_dim)) 46 | self.value_in_proj_weight = nn.Parameter(torch.randn(emb_dim, emb_dim) * self.scale) 47 | self.value_in_proj_bias = nn.Parameter(torch.zeros(emb_dim)) 48 | 49 | self.attn_drop = layers.Dropout(dropout_rate, broadcast_dims=(-2, )) 50 | self.out_proj = nn.Linear(emb_dim, emb_dim, bias=True) 51 | 52 | def forward(self, inputs_q, inputs_kv, attn_mask: Optional[torch.Tensor] = None): 53 | # inputs_q: [batch_size, len_q, emb_dim] 54 | # inputs_kv: [batch_size, len_kv, emb_dim] 55 | # attn_mask: [batch_size, num_heads, len_q, len_kv] 56 | 57 | # Project inputs_q/inputs_kv to multi-headed q/k/v 58 | # dimensions are then [batch, len, num_heads, head_dim] 59 | bs, q_len, emb_dim = inputs_q.shape 60 | kv_len = inputs_kv.shape[1] 61 | query = F.linear(inputs_q, self.query_in_proj_weight, self.query_in_proj_bias).reshape( 62 | bs, q_len, self.num_heads, self.head_dim 63 | ) 64 | key = F.linear(inputs_kv, self.key_in_proj_weight, self.key_in_proj_bias).reshape( 65 | bs, kv_len, self.num_heads, self.head_dim 66 | ) 67 | value = F.linear(inputs_kv, self.value_in_proj_weight, self.value_in_proj_bias).reshape( 68 | bs, kv_len, self.num_heads, self.head_dim 69 | ) 70 | 71 | if self.float32_logits: 72 | query = query.to(torch.float32) 73 | key = key.to(torch.float32) 74 | 75 | query = query * self.scale 76 | # `attn_weights`: [batch, num_heads, len_q, len_kv] 77 | attn_weights = torch.einsum("bqhd,bkhd->bhqk", query, key) 78 | 79 | if attn_mask is not None: 80 | new_attn_mask = torch.zeros_like(attn_mask, dtype=attn_weights.dtype) 81 | new_attn_mask.masked_fill_(~(attn_mask > 0), -1e10) 82 | attn_mask = new_attn_mask 83 | attn_weights += attn_mask 84 | 85 | attn_weights = F.softmax(attn_weights, dim=-1).to(inputs_q.dtype) 86 | attn_weights = self.attn_drop(attn_weights) 87 | 88 | # `attn_out`: [batch, len_q, num_heads, head_dim] 89 | attn_out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value) 90 | # `out`: [batch, len_q, emb_dim] 91 | out = self.out_proj(attn_out.reshape(bs, q_len, emb_dim)) 92 | 93 | return out 94 | 95 | 96 | class ResidualAttentionBlock(nn.Module): 97 | def __init__(self, config): 98 | super().__init__() 99 | self.config = config 100 | self.ln_1 = nn.LayerNorm(config.emb_dim, eps=1e-6) 101 | self.attn = MultiHeadDotProductAttention( 102 | config.emb_dim, 103 | config.num_heads, 104 | config.head_dim, 105 | config.dropout_rate, 106 | # The uio2 jax code did not use this parameter. 
107 | # float32_logits=config.float32_attention_logits 108 | ) 109 | self.ln_2 = nn.LayerNorm(config.emb_dim, eps=1e-6) 110 | self.mlp = MLP(config) 111 | 112 | def forward(self, x, attn_mask): 113 | x1 = self.ln_1(x) 114 | x2 = self.attn(x1, x1, attn_mask) 115 | x = x + x2 116 | x1 = self.ln_2(x) 117 | x2 = self.mlp(x1) 118 | x = x + x2 119 | return x 120 | 121 | 122 | class Transformer(nn.Module): 123 | def __init__(self, config): 124 | super().__init__() 125 | self.config = config 126 | self.num_layers = config.num_layers 127 | resblocks = [] 128 | for i in range(config.num_layers): 129 | resblocks.append(ResidualAttentionBlock(config)) 130 | self.resblocks = nn.ModuleList(resblocks) 131 | 132 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 133 | xs = [] 134 | for r in self.resblocks: 135 | x = r(x, attn_mask) 136 | xs.append(x) 137 | 138 | return x, xs 139 | 140 | 141 | def _expand_token(token, batch_size: int): 142 | return token.view(1, 1, -1).expand(batch_size, -1, -1) 143 | 144 | 145 | class VisionTransformer(nn.Module): 146 | def __init__(self, config): 147 | super().__init__() 148 | self.config = config 149 | 150 | input_dim = config.patch_size * config.patch_size * 1 151 | self.embedding = nn.Linear(input_dim, config.emb_dim, bias=True) 152 | self.cls_token = nn.Parameter(torch.zeros(config.emb_dim)) 153 | self.dist_token = nn.Parameter(torch.zeros(config.emb_dim)) 154 | self.positional_embedding = nn.Parameter(torch.zeros(514, config.emb_dim)) 155 | self.transformer = Transformer(config) 156 | 157 | def add_pos_emb(self, x, pos_ids): 158 | cls_emb = self.positional_embedding[0] 159 | dist_emb = self.positional_embedding[1] 160 | pos_emb = self.positional_embedding[2:][pos_ids] 161 | 162 | x = x + torch.cat( 163 | [ 164 | _expand_token(cls_emb, x.shape[0]), 165 | _expand_token(dist_emb, x.shape[0]), 166 | pos_emb, 167 | ], 168 | dim=1, 169 | ).to(x.dtype) 170 | return x 171 | 172 | def forward(self, x, mask, pos_ids, *, patch_num: Any = (16, 16)): 173 | B = x.shape[0] 174 | x = self.embedding(x) 175 | x = torch.cat([_expand_token(self.cls_token, B).to(x.dtype), _expand_token(self.dist_token, B).to(x.dtype), x], dim=1) 176 | 177 | mask = torch.cat( 178 | [ 179 | torch.ones([B, 1], dtype=torch.int32, device=mask.device), 180 | torch.ones([B, 1], dtype=torch.int32, device=mask.device), 181 | mask, 182 | ], 183 | dim=1 184 | ) 185 | 186 | x = self.add_pos_emb(x, pos_ids) 187 | 188 | attn_mask = layers.make_attention_mask(mask, mask).to(x.dtype) 189 | 190 | x, xs = self.transformer(x, attn_mask) 191 | 192 | # remove the cls/dist token 193 | x = x[:, 2:, :] 194 | 195 | x1 = xs[1][:, 2:, :] 196 | 197 | return x, x1 198 | 199 | 200 | def transpose_input(pos_ids, input_size, patch_size): 201 | h, w = ( 202 | int(input_size[0] / patch_size), 203 | int(input_size[1] / patch_size), 204 | ) 205 | w_coord = pos_ids % w 206 | h_coord = pos_ids // w 207 | pos_ids_t = w_coord * h + h_coord 208 | 209 | return pos_ids_t 210 | 211 | 212 | class AudioFeature(nn.Module): 213 | """Converts mel-spectrograms into features""" 214 | 215 | def __init__(self, config) -> None: 216 | super().__init__() 217 | self.config = config 218 | self.vision_transformer = VisionTransformer(config) 219 | 220 | def forward(self, x, mask, pos_ids, *, patch_num: Any = (16, 8)): 221 | if self.config.transpose_input: 222 | pos_ids = transpose_input(pos_ids, self.config.default_input_size, self.config.patch_size) 223 | x, x1 = self.vision_transformer(x, mask, pos_ids) 224 | return x, x1 225 | 
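# NOTE (editor's addition, not part of audio_embedder.py): a minimal, hypothetical
# sketch of how AudioFeature consumes a patchified mel-spectrogram, using the default
# AudioVitFeatureConfig from uio2.config. A (128, 256) spectrogram split into 16x16
# patches gives (256/16) * (128/16) = 128 patches of 256 values each; in practice these
# tensors are built by uio2.preprocessing, and the weights below are randomly
# initialized, so this only illustrates the expected input/output shapes.
import torch
from uio2.config import AudioVitFeatureConfig
from uio2.audio_embedder import AudioFeature

cfg = AudioVitFeatureConfig()
audio_vit = AudioFeature(cfg)

n_patches = 128
patches = torch.randn(1, n_patches, cfg.patch_size * cfg.patch_size)  # [B, patches, 16*16]
mask = torch.ones(1, n_patches, dtype=torch.int32)                    # all patches valid
pos_ids = torch.arange(n_patches).unsqueeze(0)                        # row-major patch ids
features, second_layer_features = audio_vit(patches, mask, pos_ids)
# features and second_layer_features both have shape [1, n_patches, cfg.emb_dim]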
-------------------------------------------------------------------------------- /uio2/audio_utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for pre-processing audio""" 2 | import logging 3 | import subprocess 4 | from os.path import exists 5 | from typing import Optional, List 6 | 7 | import numpy as np 8 | import scipy 9 | 10 | from uio2 import config 11 | 12 | 13 | BUFFER_FROM_END = 0.1 14 | WAV_MAX_VALUE = 32768.0 15 | 16 | 17 | def get_num_segments(audio_length, audio_segment_length): 18 | num_segments = int(audio_length // audio_segment_length) 19 | 20 | # allows extra frame only if the midpoint is an available to extract video frames 21 | if (audio_length % audio_segment_length) - BUFFER_FROM_END > ( 22 | audio_segment_length / 2.0 23 | ): 24 | num_segments += 1 25 | 26 | if num_segments == 0 and audio_length > 0: 27 | num_segments = 1 28 | 29 | return num_segments 30 | 31 | 32 | def get_audio_length(audio_path): 33 | out = subprocess.check_output( 34 | [ 35 | "ffprobe", 36 | "-v", 37 | "error", 38 | "-select_streams", 39 | "a:0", 40 | "-show_entries", 41 | "stream=duration", 42 | "-of", 43 | "default=noprint_wrappers=1:nokey=1", 44 | audio_path, 45 | ], 46 | ) 47 | duration = float(out.decode("utf-8").strip()) 48 | return duration 49 | 50 | 51 | def read_audio_file(src, sr=config.AUDIO_SAMPLING_RATE): 52 | """Load wavform from file or file handle""" 53 | try: 54 | import librosa 55 | except ImportError as e: 56 | raise ValueError("Librosa must be install for audio pre-processing", e) 57 | waveform, _sr = librosa.core.load(src, sr=sr) 58 | assert _sr == sr 59 | waveform = waveform.astype(np.float32) 60 | if len(waveform.shape) > 1: 61 | waveform = np.mean(waveform, axis=1) 62 | return waveform 63 | 64 | 65 | def make_spectrogram(waveform, sample_rate=16000): 66 | """Make spectrogram from waveform""" 67 | try: 68 | from librosa.feature import melspectrogram 69 | except ImportError as e: 70 | raise ValueError("Librosa must be install for audio pre-processing", e) 71 | 72 | # Parameters we manually selected for sound quality 73 | params = { 74 | 'n_fft': 1024, 75 | 'hop_length': 256, 76 | 'window': scipy.signal.windows.hann, 77 | 'n_mels': 128, 78 | 'fmin': 0.0, 79 | 'fmax': sample_rate / 2.0, 80 | 'center': True, 81 | 'pad_mode': 'reflect', 82 | } 83 | mel = melspectrogram(y=waveform, sr=sample_rate, **params) 84 | return mel 85 | 86 | 87 | def extract_spectrograms_from_audio( 88 | waveform: np.ndarray, 89 | audio_length, 90 | audio_segment_length: float = config.AUDIO_SEGMENT_LENGTH, 91 | spectrogram_length: float = config.AUDIO_SPECTRUM_LENGTH, 92 | sampling_rate: int = config.AUDIO_SAMPLING_RATE, 93 | ) -> List[np.ndarray]: 94 | """Turns a waveform in a list of melspectograms UIO2 can process""" 95 | num_segments = get_num_segments(audio_length, audio_segment_length) 96 | boundaries = np.linspace( 97 | 0, num_segments * audio_segment_length, num_segments + 1 98 | ).tolist() 99 | 100 | # Pad to max time just in case, crop if longer 101 | max_samples = int(sampling_rate * num_segments * audio_segment_length) 102 | if waveform.size < max_samples: 103 | waveform = np.concatenate( 104 | [waveform, np.zeros(max_samples - waveform.size, dtype=np.float32)], 0 105 | ) 106 | waveform = waveform[:max_samples] 107 | 108 | # split waveform into segments 109 | spectrograms = [] 110 | for i in range(num_segments): 111 | if audio_segment_length <= spectrogram_length: 112 | ts_start = int(boundaries[i] * sampling_rate) 113 | ts_end 
= int(boundaries[i + 1] * sampling_rate) 114 | waveform_segment = waveform[ts_start:ts_end] 115 | num_pad = int(sampling_rate * spectrogram_length) - (ts_end - ts_start) 116 | if num_pad > 0: 117 | waveform_segment = np.concatenate( 118 | [ 119 | np.zeros(num_pad // 2, dtype=np.float32), 120 | waveform_segment, 121 | np.zeros(num_pad - num_pad // 2, dtype=np.float32), 122 | ], 123 | 0, 124 | ) 125 | waveform_segment = waveform_segment[ 126 | : int(sampling_rate * spectrogram_length) 127 | ] 128 | else: 129 | ts_start = int(boundaries[i] * sampling_rate) 130 | ts_end = int(boundaries[i + 1] * sampling_rate) 131 | ts_mid = (ts_start + ts_end) / 2 132 | start = int(ts_mid - sampling_rate * spectrogram_length / 2) 133 | end = start + int(sampling_rate * spectrogram_length) 134 | waveform_segment = waveform[start:end] 135 | 136 | # Create spectrogram from waveform 137 | spectrogram = make_spectrogram( 138 | waveform_segment, sampling_rate, 139 | ) # shape (128, 256) 140 | spectrograms.append(spectrogram) 141 | 142 | if len(spectrograms) == 0: 143 | assert num_segments == 0 144 | raise ValueError("Couldn't make spectrograms: num_segments is 0") 145 | 146 | # (N,128,256) is (# of segments, # of mel bands in spectrogram, # of hops in spectrogram) 147 | spectrograms = np.stack(spectrograms).astype(np.float32) 148 | assert spectrograms.shape[1:] == (128, 256) 149 | return spectrograms 150 | 151 | 152 | def load_audio( 153 | path: str, 154 | audio_segment_length=config.AUDIO_SEGMENT_LENGTH, 155 | spectrogram_length=config.AUDIO_SEGMENT_LENGTH, 156 | max_audio_length: Optional[float] = None, 157 | ): 158 | """Loads audio as a spectrogram from `path`""" 159 | if not exists(path): 160 | raise FileNotFoundError(f"{path} not found") 161 | audio_length = get_audio_length(path) 162 | if max_audio_length and max_audio_length > audio_length: 163 | logging.warning(f"Use the input audio length of {max_audio_length} (original {audio_length}) seconds.") 164 | audio_length = max_audio_length 165 | 166 | wavform = read_audio_file(path) 167 | 168 | return extract_spectrograms_from_audio( 169 | wavform, 170 | audio_length=audio_length, 171 | audio_segment_length=audio_segment_length, 172 | spectrogram_length=spectrogram_length, 173 | ) 174 | 175 | -------------------------------------------------------------------------------- /uio2/audio_vqgan.py: -------------------------------------------------------------------------------- 1 | """ViTVQGAN model implementation in PyTorch""" 2 | import torch 3 | 4 | from uio2.config import AudioViTVQGANConfig 5 | 6 | from uio2 import layers 7 | import math 8 | from torch import nn 9 | 10 | 11 | class MlpBlock(nn.Module): 12 | """Transformer MLP / feed-forward block. 13 | 14 | Attributes: 15 | emb_dim; input/output dimension of MLP 16 | intermediate_dim: Shared dimension of hidden layers. 17 | activation: Type of activation for each layer. It is either 18 | 'linear', a string function name in torch.nn.functional, or a function. 19 | intermediate_dropout_rate: Dropout rate used after the intermediate layers. 
20 | """ 21 | def __init__( 22 | self, 23 | emb_dim: int, 24 | mlp_dim: int, 25 | act_fn: str = 'relu', 26 | dropout_rate: float = 0.0, 27 | ): 28 | super().__init__() 29 | self.act_fn = act_fn 30 | self.fc1 = nn.Linear(emb_dim, mlp_dim, bias=True) 31 | self.dropout = nn.Dropout(dropout_rate) 32 | self.fc2 = nn.Linear(mlp_dim, emb_dim, bias=True) 33 | 34 | nn.init.xavier_uniform_(self.fc1.weight) 35 | nn.init.normal_(self.fc1.bias, std=1e-6) 36 | nn.init.xavier_uniform_(self.fc2.weight) 37 | nn.init.normal_(self.fc2.bias, std=1e-6) 38 | 39 | def forward(self, inputs): 40 | """Applies Transformer MlpBlock module.""" 41 | x = self.fc1(inputs) 42 | x = layers._convert_to_activation_function(self.act_fn)(x) 43 | x = self.dropout(x) 44 | output = self.fc2(x) 45 | output = self.dropout(output) 46 | return output 47 | 48 | 49 | class TransformerLayer(nn.Module): 50 | """Transformer layer""" 51 | def __init__( 52 | self, 53 | emb_dim: int, 54 | mlp_dim: int, 55 | num_heads: int, 56 | head_dim: int, 57 | dropout_rate: float = 0.0, 58 | droppath_rate: float = 0.0, 59 | attention_dropout_rate: float = 0.0, 60 | act_fn: str = 'relu', 61 | float32_attention_logits: bool = False 62 | ): 63 | super().__init__() 64 | self.ln_1 = nn.LayerNorm(emb_dim, eps=1e-6) 65 | self.attn = layers.MultiHeadDotProductAttention( 66 | emb_dim, 67 | num_heads, 68 | head_dim, 69 | dropout_rate=attention_dropout_rate, 70 | float32_logits=float32_attention_logits, 71 | qk_norm=False, 72 | depth_normalize=True, 73 | scaled_cosine=False, 74 | ) 75 | self.dropout = nn.Dropout(dropout_rate) 76 | self.droppath = layers.DropPath(droppath_rate) 77 | self.ln_2 = nn.LayerNorm(emb_dim, eps=1e-6) 78 | self.mlp = MlpBlock(emb_dim, mlp_dim, act_fn, dropout_rate) 79 | 80 | def forward(self, inputs): 81 | x = self.ln_1(inputs) 82 | x = self.attn(x, x) 83 | x = self.dropout(x) 84 | x = self.droppath(x) + inputs 85 | 86 | y = self.ln_2(x) 87 | y = self.mlp(y) 88 | return x + self.droppath(y) 89 | 90 | 91 | class Transformer(nn.Module): 92 | """Transformer Model for sequence to sequence translation. 93 | Attributes: 94 | num_layers: number of layers 95 | mlp_dim: dimension of the mlp on top of attention block 96 | num_heads: Number of heads in nn.MultiHeadDotProductAttention 97 | dropout_rate: dropout rate. 98 | attention_dropout_rate: dropout rate in self attention. 
99 | """ 100 | def __init__( 101 | self, 102 | num_layers: int, 103 | emb_dim: int, 104 | mlp_dim: int, 105 | num_heads: int, 106 | head_dim: int, 107 | dropout_rate: float = 0.0, 108 | droppath_rate: float = 0.0, 109 | attention_dropout_rate: float = 0.0, 110 | act_fn: str = 'relu', 111 | ): 112 | super().__init__() 113 | self.dropout = nn.Dropout(dropout_rate) 114 | self.num_layers = num_layers 115 | dpr = [x.item() for x in torch.linspace(0, droppath_rate, num_layers)] 116 | for lyr in range(self.num_layers): 117 | self.add_module( 118 | f"encoderblock_{lyr}", TransformerLayer( 119 | emb_dim=emb_dim, 120 | mlp_dim=mlp_dim, 121 | num_heads=num_heads, 122 | head_dim=head_dim, 123 | dropout_rate=dropout_rate, 124 | droppath_rate=dpr[lyr], 125 | attention_dropout_rate=attention_dropout_rate, 126 | act_fn=act_fn)) 127 | 128 | self.encoder_norm = nn.LayerNorm(emb_dim, eps=1e-6) 129 | 130 | def forward(self, x): 131 | x = self.dropout(x) 132 | for lyr in range(self.num_layers): 133 | x = self.__getattr__(f"encoderblock_{lyr}")(x) 134 | x = self.encoder_norm(x) 135 | 136 | return x 137 | 138 | 139 | class ViTEncoder(nn.Module): 140 | def __init__(self, config: AudioViTVQGANConfig): 141 | super().__init__() 142 | self.config = config 143 | cfg = self.config 144 | 145 | self.register_buffer("encoder_position_embedding", 146 | layers.get_2d_sincos_pos_embed( 147 | emb_dim=cfg.encoder_hidden_size, 148 | image_size=cfg.default_input_size, 149 | image_patch_size=cfg.patch_size, 150 | class_token=False), persistent=False) 151 | in_size = cfg.output_channel * cfg.patch_size[0] * cfg.patch_size[1] 152 | self.embedding = nn.Linear(in_size, cfg.encoder_hidden_size, bias=True) 153 | self.transformer = Transformer( 154 | num_layers=cfg.encoder_num_layers, 155 | emb_dim=cfg.encoder_hidden_size, 156 | mlp_dim=cfg.encoder_mlp_dim, 157 | num_heads=cfg.encoder_num_heads, 158 | head_dim=cfg.encoder_head_dim, 159 | dropout_rate=cfg.dropout_rate, 160 | droppath_rate=cfg.droppath_rate, 161 | attention_dropout_rate=cfg.attention_dropout_rate, 162 | act_fn=cfg.act_fn, 163 | ) 164 | self.act_fn = cfg.act_fn 165 | self.encoder_proj = nn.Linear(cfg.encoder_hidden_size, cfg.proj_dim, bias=cfg.use_bias) 166 | self.encoder_norm = layers.LayerNorm(cfg.proj_dim, eps=1e-6, weight=False) 167 | 168 | nn.init.trunc_normal_(self.embedding.weight, std=math.sqrt(1 / in_size), a=-2.0, b=2.0) 169 | nn.init.zeros_(self.embedding.bias) 170 | nn.init.trunc_normal_(self.encoder_proj.weight, std=math.sqrt(1 / in_size), a=-2.0, b=2.0) 171 | if cfg.use_bias: 172 | nn.init.zeros_(self.encoder_proj.bias) 173 | nn.init.ones_(self.encoder_norm.bias) 174 | 175 | def forward(self, x): 176 | # reshape [bs, h, w, c] to [bs, (h/dh) * (w/dw), c*dh*dw] 177 | x = layers.space_to_depth(x, spatial_block_size=self.config.patch_size[0]) 178 | x = self.embedding(x) 179 | x += self.encoder_position_embedding.unsqueeze(0) 180 | x = self.transformer(x) 181 | x = layers._convert_to_activation_function(self.act_fn)(x) 182 | x = self.encoder_proj(x) 183 | x = self.encoder_norm(x) 184 | return x 185 | 186 | 187 | class ViTDecoder(nn.Module): 188 | def __init__(self, config: AudioViTVQGANConfig): 189 | super().__init__() 190 | self.config = config 191 | cfg = self.config 192 | 193 | self.register_buffer("decoder_position_embedding", 194 | layers.get_2d_sincos_pos_embed( 195 | emb_dim=cfg.encoder_hidden_size, 196 | image_size=cfg.default_input_size, 197 | image_patch_size=cfg.patch_size, 198 | class_token=False), persistent=False) 199 | self.decoder_proj = 
nn.Linear(cfg.proj_dim, cfg.decoder_hidden_size, bias=cfg.use_bias) 200 | self.transformer = Transformer( 201 | num_layers=cfg.decoder_num_layers, 202 | emb_dim=cfg.decoder_hidden_size, 203 | mlp_dim=cfg.decoder_mlp_dim, 204 | num_heads=cfg.decoder_num_heads, 205 | head_dim=cfg.decoder_head_dim, 206 | dropout_rate=cfg.dropout_rate, 207 | droppath_rate=cfg.droppath_rate, 208 | attention_dropout_rate=cfg.attention_dropout_rate, 209 | act_fn=cfg.act_fn, 210 | ) 211 | 212 | self.conv_transpose = nn.ConvTranspose2d( 213 | cfg.decoder_hidden_size, 214 | cfg.output_channel, 215 | kernel_size=cfg.patch_size, 216 | stride=cfg.patch_size, 217 | bias=cfg.use_bias, 218 | ) 219 | 220 | nn.init.trunc_normal_(self.decoder_proj.weight, std=math.sqrt(1 / cfg.proj_dim), a=-2.0, b=2.0) 221 | # the weight shape of ConvTranspose2d is (in_channels, out_channels/groups, kernel_size[0], kernel_size[1]) 222 | # while that of Conv2d is (out_channels/groups, in_channels, kernel_size[0], kernel_size[1]). 223 | # Thus, get fan_out 224 | _, fan_out = nn.init._calculate_fan_in_and_fan_out(self.conv_transpose.weight) 225 | nn.init.trunc_normal_(self.conv_transpose.weight, std=math.sqrt(1 / fan_out), a=-2.0, b=2.0) 226 | if cfg.use_bias: 227 | nn.init.zeros_(self.decoder_proj.bias) 228 | nn.init.zeros_(self.conv_transpose.bias) 229 | 230 | def forward(self, x): 231 | # [bs, (h/dh) * (w/dw), c*dh*dw] -> [bs, c, h, w] 232 | cfg = self.config 233 | bs = x.shape[0] 234 | x = self.decoder_proj(x) 235 | x += self.decoder_position_embedding.unsqueeze(0) 236 | x = self.transformer(x) 237 | img_size = cfg.default_input_size 238 | patch_size = cfg.patch_size 239 | x = x.reshape( 240 | bs, img_size[0] // patch_size[0], img_size[1] // patch_size[1], cfg.decoder_hidden_size) 241 | x = x.permute(0, 3, 1, 2).contiguous() 242 | output_size = (x.shape[0], cfg.output_channel, img_size[0], img_size[1]) 243 | x = self.conv_transpose(x, output_size=output_size) 244 | return x 245 | 246 | 247 | class ViTVQGAN(nn.Module): 248 | """Pytorch Implementation of ViT-VQGAN""" 249 | def __init__(self, config: AudioViTVQGANConfig): 250 | super().__init__() 251 | self.config = config 252 | cfg = self.config 253 | 254 | self.quantize = layers.VectorQuantizer( 255 | n_e=cfg.vocab_size, 256 | e_dim=cfg.proj_dim, 257 | beta=0.25, 258 | uniform_init=True, 259 | legacy=False, 260 | l2_norm=True, 261 | ) 262 | self.encoder = ViTEncoder(cfg) 263 | self.decoder = ViTDecoder(cfg) 264 | 265 | def encode(self, x): 266 | return self.encoder(x) 267 | 268 | def decode(self, x): 269 | return self.decoder(x) 270 | 271 | def get_quantize_from_emb(self, h): 272 | z, _, [_, _, indices] = self.quantize(h) 273 | return indices.reshape(h.shape[0], -1) 274 | 275 | def decode_code(self, code_b): 276 | quant_b = self.quantize.get_codebook_entry(code_b) 277 | dec = self.decode(quant_b) 278 | return dec 279 | 280 | def get_codebook_indices(self, x, vqgan_decode=False): 281 | h = self.encode(x) 282 | z, _, [_, _, indices] = self.quantize(h) 283 | if vqgan_decode: 284 | dec = self.decode(z) 285 | 286 | return indices.reshape(h.shape[0], -1) 287 | 288 | def forward(self, x): 289 | # x: [bs, h, w, c] 290 | h = self.encode(x) 291 | z, _, [_, _, indices] = self.quantize(h) 292 | if self.config.use_decoder: 293 | # [bs, c, h, w] 294 | dec = self.decode(z) 295 | else: 296 | dec = None 297 | return z, dec -------------------------------------------------------------------------------- /uio2/config.py: -------------------------------------------------------------------------------- 1 | 
"""Configuration settings used in UIO2""" 2 | import dataclasses 3 | from dataclasses import dataclass, field 4 | from typing import Any, Sequence, Dict, Tuple 5 | 6 | import torch 7 | import math 8 | 9 | from uio2.vocabulary import SentencePieceVocabulary 10 | 11 | PAD_ID = 0 12 | EOS_ID = 1 13 | BOS_ID = 0 14 | MAX_TEXT_LEN = 512 15 | 16 | # Constants used when encoding region 17 | VOCAB_START = 200 18 | NUM_DETECTION_BIN = 1000 19 | POS_MAX_VALUE = 50 20 | POS_MIN_VALUE = -50 21 | 22 | D_THETA_MAX_VALUE = math.pi 23 | D_THETA_MIN_VALUE = -math.pi 24 | D_RADIUS_MAX_VALUE = 0.7 25 | D_RADIUS_MIN_VALUE = -0.7 26 | D_SINUSOID_MAX_VALUE = 1.0 27 | D_SINUSOID_MIN_VALUE = -1.0 28 | 29 | 30 | # Controls data augmentation 31 | RANDOM_SCALE_MAX = 1.3333 32 | RANDOM_SCALE_MIN = 0.75 33 | RANDOM_SCALE_RATIO = 0.5 34 | 35 | # Image pre-processing 36 | IMAGE_INPUT_SIZE = [384, 384] 37 | IMAGE_INPUT_D = 16 38 | IMAGE_INPUT_PATCHES = (IMAGE_INPUT_SIZE[0] // IMAGE_INPUT_D, IMAGE_INPUT_SIZE[1] // IMAGE_INPUT_D) 39 | IMAGE_HISTORY_INPUT_SIZE = [256, 256] 40 | IMAGE_HISTORY_INPUT_D = 16 41 | IMAGE_VIT_MEAN = [0.48145466, 0.4578275, 0.40821073] 42 | IMAGE_VIT_STD = [0.26862954, 0.26130258, 0.27577711] 43 | 44 | IMAGE_TARGET_SIZE = [256, 256] 45 | IMAGE_TARGET_D = 8 46 | 47 | # Control parameters for 3D tasks 48 | LOCATION_RANGE = [-0.1, 1.1] 49 | DIMENSION_RANGE = [0, 6] 50 | DEPTH_RANGE = [-0.001, 0.1] 51 | ANGLE_RANGE = [0, 6.283185307179586] 52 | 53 | # Controls input/output audio sizes 54 | AUDIO_INPUT_SIZE = [256, 128] 55 | AUDIO_INPUT_D = 16 56 | AUDIO_TARGET_SIZE = [256, 128] 57 | AUDIO_TARGET_D = 8 58 | AUDIO_HISTORY_INPUT_SIZE = [256, 128] 59 | AUDIO_HISTORY_INPUT_D = 16 60 | AUDIO_SEGMENT_LENGTH = 4.08 61 | AUDIO_SPECTRUM_LENGTH = 4.08 62 | AUDIO_SAMPLING_RATE = 16000 63 | 64 | # Used for audio pre-processing 65 | AUDIOSET_MEAN = -5.0945 66 | AUDIOSET_STD = 3.8312 67 | AUDIO_VIT_MEAN = -4.26 68 | AUDIO_VIT_STD = 9.14 69 | 70 | 71 | DEFAULT_EXTRA_IDS = VOCAB_START + NUM_DETECTION_BIN 72 | MODALITY_EXTRA_ID_N_FRAMES = 8 # 8 frames just in case 73 | if MODALITY_EXTRA_ID_N_FRAMES: 74 | MODALITY_EXTRA_IDS = (1 + MODALITY_EXTRA_ID_N_FRAMES) * 2 # image/audio input + n * image/audio history 75 | else: 76 | MODALITY_EXTRA_IDS = 0 77 | 78 | 79 | def get_tokenizer(path): 80 | """Gets the UIO2 tokenizer 81 | 82 | This is the LLaMaTokenizer but with bos=0, pad=0, eos=1. `path` should point to a 83 | `llama_tokenizer.model` file 84 | """ 85 | return SentencePieceVocabulary( 86 | path, 87 | extra_ids=DEFAULT_EXTRA_IDS, 88 | reverse_extra_ids=True, 89 | modality_extra_id_n_frames=MODALITY_EXTRA_ID_N_FRAMES, 90 | hack_to_t5_start_tokens=True, 91 | prefix_as_special_token=True, 92 | ) 93 | 94 | 95 | @dataclass 96 | class T5Config: 97 | """Configures the main transformer""" 98 | vocab_size: int = 33280 99 | image_vocab_size: int = 16512 100 | image_patch_size: int = 16 101 | audio_vocab_size: int = 8320 102 | emb_dim: int = 512 103 | num_heads: int = 8 104 | num_encoder_layers: int = 6 105 | num_decoder_layers: int = 6 106 | head_dim: int = 64 107 | mlp_dim: int = 2048 108 | mlp_activations: Sequence[str] = ('silu', 'linear') 109 | dropout_rate: float = 0.0 110 | dropout_broadcast_dims: Sequence[int] = (-2, ) 111 | # the embedding weights are used in the decoder output layer. 112 | logits_via_embedding: bool = True 113 | # Whether to accumulate attention logits in float32 regardless of dtype. 
114 | float32_attention_logits: bool = True 115 | decoder_xattention_internval: int = 1 116 | qk_norm: bool = True 117 | dalle_attn_mask: bool = True 118 | # Whether to use dynamic masking when computing the loss of a target image 119 | dynamic_unk_mask: bool = True 120 | 121 | # Used to for ROPE 122 | encoder_max_image_length: int = IMAGE_INPUT_PATCHES[0]*IMAGE_INPUT_PATCHES[1] 123 | encoder_max_audio_length: int = 128 124 | encoder_max_text_length: int = MAX_TEXT_LEN 125 | decoder_max_image_length: int = 1024 126 | decoder_max_audio_length: int = 512 127 | decoder_max_text_length: int = MAX_TEXT_LEN 128 | text_pos_emb: str = 'llama_rope' # '1d-sincos' # 'learnable' 129 | image_pos_emb: str = 'llama_rope' 130 | audio_pos_emb: str = 'llama_rope' 131 | image_history_pos_emb: str = 'llama_rope' 132 | audio_history_pos_emb: str = 'llama_rope' 133 | 134 | # Used for encoding and pre-processing input modalities 135 | image_tokenizer_type: str = "vqgan" 136 | default_image_size: Sequence[int] = (256, 256) 137 | default_image_vit_size: Sequence[int] = tuple(IMAGE_INPUT_SIZE) # for vit-large model 138 | default_image_history_vit_size: Sequence[int] = (256, 256) 139 | default_audio_size: Sequence[int] = (256, 128) 140 | default_audio_vit_size: Sequence[int] = (256, 128) 141 | default_audio_history_vit_size: Sequence[int] = (256, 128) 142 | image_vit_patch_size: int = 16 143 | audio_patch_size: int = 16 144 | audio_vit_patch_size: int = 16 145 | 146 | 147 | # Modality-specific processing configs 148 | 149 | @dataclass 150 | class VQGANConfig: 151 | embed_dim: int = 4 152 | n_embed: int = 16384 153 | double_z: bool = False 154 | z_channels: int = 4 155 | resolution: int = 256 156 | in_channels: int = 3 157 | out_ch: int = 3 158 | ch: int = 128 159 | ch_mult: Sequence[int] = (1,2,2,4) 160 | num_res_blocks: int = 2 161 | attn_resolutions: Sequence[int] = (32,) 162 | dropout: float = 0 163 | default_input_size: Sequence[int] = (256,256) 164 | patch_size: Sequence[int] = (8, 8) 165 | checkpoint_path: str = '' 166 | 167 | 168 | @dataclass 169 | class ImageVitFeatureConfig: 170 | patch_size: int = 16 171 | pos_patch_size: int = 16 172 | emb_dim: int = 768 173 | num_heads: int = 12 174 | num_layers: int = 11 # -2 layer 175 | head_dim: int = 64 176 | mlp_dim: int = 3072 177 | mlp_activations: Sequence[str] = ('gelu',) 178 | dropout_rate: float = 0.0 179 | dropout_broadcast_dims: Sequence[int] = () 180 | float32_attention_logits: bool = True 181 | default_input_size: Sequence[int] = (256, 256) 182 | num_pos: int = 197 183 | 184 | 185 | @dataclass 186 | class AudioVitFeatureConfig: 187 | vit_embed: bool = True 188 | patch_size: int = 16 189 | pos_patch_size: int = 16 190 | emb_dim: int = 768 191 | num_heads: int = 12 192 | num_layers: int = 11 # -2 layer 193 | head_dim: int = 64 194 | mlp_dim: int = 3072 195 | mlp_activations: Sequence[str] = ('gelu',) 196 | dropout_rate: float = 0.0 197 | dropout_broadcast_dims: Sequence[int] = () 198 | float32_attention_logits: bool = True 199 | default_input_size: Sequence[int] = (256, 128) 200 | transpose_input: bool = True 201 | 202 | 203 | @dataclass 204 | class ImageResamplerConfig: 205 | resampler_type: str = "perceiver" # linear, perceiver, v2 206 | max_frames: int = 8 207 | latents_size: int = 32 208 | emb_dim: int = 768 209 | num_heads: int = 12 210 | num_layers: int = 2 211 | xattention_index: Sequence[int] = (0, 1) 212 | head_dim: int = 64 213 | mlp_dim: int = 2048 214 | mlp_activations: Sequence[str] = ('gelu',) 215 | dropout_rate: float = 0.0 216 | 
dropout_broadcast_dims: Sequence[int] = (-2,) 217 | droppath_rate: float = 0.0 218 | layer_drop: float = 0.0 219 | xattn_qk_norm: bool = True 220 | xattn_scaled_cosine: bool = False 221 | attn_qk_norm: bool = True 222 | attn_scaled_cosine: bool = False 223 | float32_attention_logits: bool = True 224 | clip_attn_logit: Any = None 225 | 226 | 227 | @dataclass 228 | class AudioResamplerConfig: 229 | resampler_type: str = "perceiver" # perceiver, attention 230 | max_frames: int = 8 231 | latents_size: int = 16 232 | emb_dim: int = 768 233 | num_heads: int = 12 234 | num_layers: int = 2 235 | xattention_index: Sequence[int] = (0, 1) 236 | head_dim: int = 64 237 | mlp_dim: int = 2048 238 | mlp_activations: Sequence[str] = ('gelu',) 239 | dropout_rate: float = 0.0 240 | dropout_broadcast_dims: Sequence[int] = (-2,) 241 | droppath_rate: float = 0.0 242 | layer_drop: float = 0.0 243 | xattn_qk_norm: bool = True 244 | xattn_scaled_cosine: bool = False 245 | attn_qk_norm: bool = True 246 | attn_scaled_cosine: bool = False 247 | float32_attention_logits: bool = True 248 | clip_attn_logit: Any = None 249 | 250 | 251 | @dataclass 252 | class ImageViTVQGANConfig: 253 | # VIT-VQGAN CONFIG 254 | vocab_size: int = 8192 255 | proj_dim: int = 32 256 | # Transformers 257 | encoder_hidden_size: int = 512 258 | encoder_num_layers: int = 8 259 | encoder_mlp_dim: int = 2048 260 | encoder_num_heads: int = 8 261 | encoder_head_dim: int = 64 262 | 263 | decoder_hidden_size: int = 512 264 | decoder_num_layers: int = 8 265 | decoder_mlp_dim: int = 2048 266 | decoder_num_heads: int = 8 267 | decoder_head_dim: int = 64 268 | 269 | dropout_rate: float = 0.0 270 | droppath_rate: float = 0.0 271 | attention_dropout_rate: float = 0.0 272 | use_bias: bool = False 273 | act_fn: str = 'relu' 274 | # Misc. 275 | default_input_size: Sequence[int] = (256,256) 276 | patch_size: Sequence[int] = (8, 8) 277 | 278 | output_channel: int = 3 279 | # checkpoint path for initialization. 280 | checkpoint_path: str = '' 281 | use_decoder: bool = True 282 | 283 | 284 | @dataclass 285 | class AudioViTVQGANConfig: 286 | # VIT-VQGAN CONFIG 287 | vocab_size: int = 8192 288 | proj_dim: int = 32 289 | # Transformers 290 | encoder_hidden_size: int = 512 291 | encoder_num_layers: int = 8 292 | encoder_mlp_dim: int = 2048 293 | encoder_num_heads: int = 8 294 | encoder_head_dim: int = 64 295 | 296 | decoder_hidden_size: int = 512 297 | decoder_num_layers: int = 8 298 | decoder_mlp_dim: int = 2048 299 | decoder_num_heads: int = 8 300 | decoder_head_dim: int = 64 301 | 302 | dropout_rate: float = 0.0 303 | droppath_rate: float = 0.0 304 | attention_dropout_rate: float = 0.0 305 | use_bias: bool = False 306 | act_fn: str = 'relu' 307 | # Misc. 308 | default_input_size: Sequence[int] = (128, 256) # we need to keep this to make it 309 | patch_size: Sequence[int] = (8, 8) 310 | 311 | output_channel: int = 1 312 | # checkpoint path for initialization. 
313 | checkpoint_path: str = '' 314 | use_decoder: bool = True 315 | 316 | 317 | DEFAULT_SEQUENCE_LEN = { 318 | "is_training": True, 319 | "image_input_samples": 576, 320 | "image_history_input_samples": 256, 321 | "audio_input_samples": 128, 322 | "audio_history_input_samples": 128, 323 | 'num_frames': 4, 324 | } 325 | 326 | INPUT_MODALITIES = ['text', 'image', 'image_history', 'audio', 'audio_history'] 327 | TARGET_MODALITIES = ['text', 'image', 'audio'] 328 | 329 | 330 | @dataclass 331 | class Config: 332 | """Complete config that includes pre-processing and modality-specific configs""" 333 | t5_config: T5Config 334 | image_history_cfg: ImageResamplerConfig=ImageResamplerConfig() 335 | audio_history_cfg: AudioResamplerConfig=AudioResamplerConfig() 336 | freeze_vit: bool = True 337 | input_modalities: Tuple = tuple(INPUT_MODALITIES) 338 | target_modalities : Tuple = tuple(TARGET_MODALITIES) 339 | sequence_length: Dict = field(default_factory=lambda: dict(DEFAULT_SEQUENCE_LEN)) 340 | image_vqgan: VQGANConfig=VQGANConfig() 341 | audio_vqgan: AudioViTVQGANConfig=AudioViTVQGANConfig() 342 | image_vit_cfg: ImageVitFeatureConfig=ImageVitFeatureConfig() 343 | audio_vit_cfg: AudioVitFeatureConfig=AudioVitFeatureConfig() 344 | use_image_vit: bool = True 345 | use_audio_vit: bool = True 346 | use_image_history_vit: bool = True 347 | use_audio_history_vit: bool = True 348 | 349 | def to_dict(self) -> Dict: 350 | return dataclasses.asdict(self) 351 | 352 | @staticmethod 353 | def from_dict(data: Dict) -> 'Config': 354 | return Config( 355 | t5_config=T5Config(**data["t5_config"]), 356 | image_history_cfg=ImageResamplerConfig(**data["image_history_cfg"]), 357 | audio_history_cfg=AudioResamplerConfig(**data["audio_history_cfg"]), 358 | image_vqgan=VQGANConfig(**data["image_vqgan"]), 359 | audio_vqgan=AudioViTVQGANConfig(**data["audio_vqgan"]), 360 | image_vit_cfg=ImageVitFeatureConfig(**data["image_vit_cfg"]), 361 | audio_vit_cfg=AudioVitFeatureConfig(**data["audio_vit_cfg"]), 362 | **{k: v for k, v in data.items() if not ("cfg" in k or "vqgan" in k or k == "t5_config")} 363 | ) 364 | 365 | 366 | # Configs used for our trained models 367 | LARGE = Config( 368 | t5_config=T5Config( 369 | emb_dim=1024, 370 | num_heads=16, 371 | num_encoder_layers=24, 372 | num_decoder_layers=24, 373 | head_dim=64, 374 | mlp_dim=2816 375 | ), 376 | ) 377 | 378 | XL = Config( 379 | t5_config=T5Config( 380 | emb_dim=2048, 381 | num_heads=16, 382 | num_encoder_layers=24, 383 | num_decoder_layers=24, 384 | head_dim=128, 385 | mlp_dim=5120, 386 | ), 387 | image_history_cfg=ImageResamplerConfig( 388 | emb_dim=1024, 389 | num_heads=16, 390 | head_dim=64, 391 | mlp_dim=4096, 392 | ), 393 | audio_history_cfg=AudioResamplerConfig( 394 | emb_dim=1024, 395 | num_heads=16, 396 | head_dim=64, 397 | mlp_dim=4096, 398 | ) 399 | ) 400 | 401 | 402 | XXL = Config( 403 | t5_config=T5Config( 404 | emb_dim=3072, 405 | num_heads=24, 406 | num_encoder_layers=24, 407 | num_decoder_layers=24, 408 | head_dim=128, 409 | mlp_dim=8192, 410 | ), 411 | image_history_cfg=ImageResamplerConfig( 412 | emb_dim=1024, 413 | num_heads=16, 414 | head_dim=64, 415 | mlp_dim=4096, 416 | xattn_qk_norm=False, 417 | xattn_scaled_cosine=True, 418 | attn_qk_norm=False, 419 | attn_scaled_cosine=True, 420 | ), 421 | audio_history_cfg=AudioResamplerConfig( 422 | emb_dim=1024, 423 | num_heads=16, 424 | head_dim=64, 425 | mlp_dim=4096, 426 | xattn_qk_norm=False, 427 | xattn_scaled_cosine=True, 428 | attn_qk_norm=False, 429 | attn_scaled_cosine=True, 430 | ) 431 | ) 432 | 
433 | CONFIG_MAP = dict( 434 | large=LARGE, 435 | xl=XL, 436 | xxl=XXL 437 | ) -------------------------------------------------------------------------------- /uio2/convert_checkpoint.py: -------------------------------------------------------------------------------- 1 | """Maps jax parameters to pytorch ones""" 2 | from typing import Tuple, Dict 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def convert_param(name: str, param: np.ndarray) -> Tuple[str, np.ndarray]: 9 | """Converts a jax UIO2 name/parameter into a torch name/parameter""" 10 | parts = name.split(".") 11 | if len(parts) == 0: 12 | return name, param 13 | 14 | if parts[0] in {"input_encoders_audio_history", "input_encoders_image_history"}: 15 | # resampler translation 16 | if parts[0] == "input_encoders_image_history": 17 | parts[0] = "input_embedders.image_history" 18 | else: 19 | parts[0] = "input_embedders.audio_history" 20 | 21 | if parts[-1] == "resampler_latents": 22 | parts[-1] = "latents" 23 | if len(parts) > 2 and parts[2] == "PerceiverResampler_0": 24 | parts[2] = "perceiver" 25 | 26 | if parts[0] in {'input_image_encoder', 'input_audio_encoder'}: 27 | # encoder translation 28 | if "audio" in parts[0]: 29 | parts[0] = "input_embedders.audio" 30 | else: 31 | parts[0] = "input_embedders.image" 32 | 33 | if ".class_embedding." in name: 34 | return name, param.T 35 | 36 | if len(parts) > 3 and parts[3] == "Transformer_0": 37 | parts[3] = "transformer" 38 | if parts[4].startswith("ResidualAttentionBlock"): 39 | num = parts[4].split("_")[1] 40 | parts[4] = f"resblocks.{num}" 41 | if parts[5] == "MultiHeadDotProductAttention_0": 42 | parts[5] = "attn" 43 | if parts[-2] == "out": 44 | parts[-2] = "out_proj" 45 | elif parts[-1] == "bias": 46 | parts = parts[:-2] + [f"{parts[-2]}_in_proj_bias"] 47 | elif parts[-1] == "kernel": 48 | parts = parts[:-2] + [f"{parts[-2]}_in_proj_weight"] 49 | param = param.T 50 | elif parts[5] == "MLP_0": 51 | parts[5] = "mlp" 52 | if parts[6] == "c_fc": 53 | parts[6] = "fc1" 54 | if parts[6] == "c_proj": 55 | parts[6] = "fc2" 56 | if parts[-1] == "kernel": 57 | parts[-1] = "weight" 58 | param = param.T 59 | if parts[-2] == "norm1": 60 | parts[-2] = "ln_1" 61 | elif parts[-2] == "norm2": 62 | parts[-2] = "ln_2" 63 | if parts[-2] in {"pre_ln", "ln_1", "ln_2", "lin_2", "lin_1"} and parts[-1] == "scale": 64 | parts[-1] = "weight" 65 | if parts[-1] == "pos_embed": 66 | parts[-1] = "positional_embedding" 67 | 68 | if parts[0] == "input_text_encoder": 69 | parts[0] = "input_embedders.text" 70 | 71 | for ix, p in enumerate(parts): 72 | if p.endswith("layer_norm"): 73 | parts[ix] = p[:-len("layer_norm")] + "norm" 74 | 75 | if parts[0] == "target_encoders_text": 76 | parts[0] = "target_embedders.text" 77 | 78 | if parts[0] == "target_encoders_image": 79 | # image target encoder translation 80 | parts[0] = "target_embedders.image" 81 | if parts[1] == "discrete_vae": 82 | parts[1] = "vqgan" 83 | if parts[2] == "quantize": 84 | parts.append("weight") 85 | elif parts[-2].startswith("norm"): 86 | if parts[-1] == "scale": 87 | parts[-1] = "weight" 88 | elif parts[-1] == "kernel": 89 | parts[-1] = "weight" 90 | return ".".join(parts), np.transpose(param, (3, 2, 0, 1)) 91 | 92 | if parts[0] == "target_encoders_audio": 93 | # audio target encoder translation 94 | parts[0] = "target_embedders.audio" 95 | if parts[1] == "discrete_vae": 96 | parts[1] = "vqgan" 97 | if parts[2] == "quantize": 98 | parts.append("weight") 99 | elif len(parts) > 3 and parts[3] == "Transformer_0": 100 | parts[3] = "transformer" 
101 | if parts[4].startswith("encoderblock"): 102 | if parts[5] == "MultiHeadDotProductAttention_0": 103 | parts[5] = "attn" 104 | elif parts[5] == "MlpBlock_0": 105 | parts[5] = "mlp" 106 | num = int(parts[6].split("_")[1]) + 1 107 | parts[6] = f"fc{num}" 108 | elif parts[5].startswith("LayerNormWithBias"): 109 | num = int(parts[5].split("_")[1]) + 1 110 | parts[5] = f"ln_{num}" 111 | elif parts[3] == "ConvTranspose_0": 112 | parts[3] = "conv_transpose" 113 | parts[-1] = "weight" 114 | v = np.transpose(param, (2, 3, 0, 1)) 115 | v = np.flip(v, [2, 3]) 116 | return ".".join(parts), v.copy() 117 | 118 | if parts[-1] == "scale": 119 | parts[-1] = "weight" 120 | 121 | if parts[-2] == "attention": 122 | parts[-1] = "weight" 123 | 124 | if parts[-1] == "embedding": 125 | parts[-1] = "weight" 126 | 127 | if parts[-1] == "kernel": 128 | parts[-1] = "weight" 129 | param = param.T 130 | return ".".join(parts), param 131 | 132 | 133 | def convert_params(params: Dict) -> Dict: 134 | """Convert a dictionary of jax parameters into torch ones""" 135 | mapped_params = {} 136 | for k, v in params.items(): 137 | k, v = convert_param(k, v) 138 | mapped_params[k] = torch.as_tensor(v) 139 | return mapped_params 140 | 141 | 142 | def flatten_checkpoint(src, prefix, out): 143 | if isinstance(src, dict): 144 | for k, v in src.items(): 145 | flatten_checkpoint(v, prefix + "." + k, out) 146 | elif isinstance(src, np.ndarray) and src.dtype == np.object_: 147 | flatten_checkpoint(src.item(), prefix, out) 148 | else: 149 | out[prefix] = src 150 | return {k.lstrip("."): v for k, v in out.items()} 151 | 152 | 153 | def load_uio2_checkpoint(checkpoint, input_modalities=("text",), target_modalities=("text",)): 154 | """Load UIO2 parameters stored in a npz file as a torch compatible state dict""" 155 | prefixes = [ 156 | 'decoder', 'encoder', 157 | 'audio_token_embedder', 'image_token_embedder', 'text_token_embedder', 158 | 'input_text_encoder', 159 | ] 160 | if "image" in input_modalities: 161 | prefixes.append("input_image_encoder") 162 | if "audio" in input_modalities: 163 | prefixes.append('input_audio_encoder') 164 | if "audio_history" in input_modalities: 165 | prefixes.append('input_encoders_audio_history') 166 | if "image_history" in input_modalities: 167 | prefixes.append('input_encoders_image_history') 168 | if "image" in target_modalities: 169 | prefixes.append('target_encoders_image') 170 | if "text" in target_modalities: 171 | prefixes.append('target_encoders_text') 172 | if "audio" in target_modalities: 173 | prefixes.append('target_encoders_audio') 174 | 175 | if checkpoint.endswith(".npz"): 176 | params = np.load(checkpoint, allow_pickle=True) 177 | params = {k: params[k] for k in params if any(k.startswith(x) for x in prefixes)} 178 | params = flatten_checkpoint(params, '', {}) 179 | mapped_params = convert_params(params) 180 | else: 181 | raise NotImplementedError() 182 | return mapped_params 183 | 184 | 185 | -------------------------------------------------------------------------------- /uio2/data_utils.py: -------------------------------------------------------------------------------- 1 | """Utility pre-processing functions""" 2 | from typing import Optional 3 | 4 | import tensorflow as tf 5 | from tensorflow.python.ops import control_flow_ops 6 | 7 | from uio2 import config 8 | 9 | 10 | def apply_with_random_selector(x, func, num_cases): 11 | """Computes func(x, sel), with sel sampled from [0...num_cases-1]. 12 | Args: 13 | x: input Tensor. 14 | func: Python function to apply. 
15 | num_cases: Python int32, number of cases to sample sel from. 16 | Returns: 17 | The result of func(x, sel), where func receives the value of the 18 | selector as a python integer, but sel is sampled dynamically. 19 | """ 20 | sel = tf.random.uniform([], maxval=num_cases, dtype=tf.int32) 21 | # Pass the real x only to one of the func calls. 22 | return control_flow_ops.merge([ 23 | func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case) 24 | for case in range(num_cases)])[0] 25 | 26 | 27 | def get_non_empty_box_indices(boxes): 28 | """Get indices for non-empty boxes.""" 29 | height = boxes[:, 2] - boxes[:, 0] 30 | width = boxes[:, 3] - boxes[:, 1] 31 | indices = tf.where( 32 | tf.logical_and(tf.greater(height, 0), tf.greater(width, 0))) 33 | return indices[:, 0] 34 | 35 | 36 | def clip_boxes(boxes, image_shape): 37 | """Clips boxes to image boundaries. 38 | Args: 39 | boxes: a tensor whose last dimension is 4 representing the coordinates of 40 | boxes in ymin, xmin, ymax, xmax order. 41 | image_shape: a list of two integers, a two-element vector or a tensor such 42 | that all but the last dimensions are `broadcastable` to `boxes`. The last 43 | dimension is 2, which represents [height, width]. 44 | Returns: 45 | clipped_boxes: a tensor whose shape is the same as `boxes` representing the 46 | clipped boxes. 47 | Raises: 48 | ValueError: If the last dimension of boxes is not 4. 49 | """ 50 | if boxes.shape[-1] != 4: 51 | raise ValueError('boxes.shape[-1] is {:d}, but must be 4.'.format( 52 | boxes.shape[-1])) 53 | 54 | with tf.name_scope('clip_boxes'): 55 | if isinstance(image_shape, list) or isinstance(image_shape, tuple): 56 | height, width = image_shape 57 | max_length = [height, width, height, width] 58 | else: 59 | image_shape = tf.cast(image_shape, dtype=boxes.dtype) 60 | height, width = tf.unstack(image_shape, axis=-1) 61 | max_length = tf.stack( 62 | [height, width, height, width], axis=-1) 63 | 64 | clipped_boxes = tf.math.maximum(tf.math.minimum(boxes, max_length), 0.0) 65 | return clipped_boxes 66 | 67 | 68 | def resize_and_crop_boxes(boxes, image_scale, output_size, offset, paddings): 69 | """Resizes boxes to output size with scale and offset. 70 | Args: 71 | boxes: `Tensor` of shape [N, 4] representing ground truth boxes. 72 | image_scale: 2D float `Tensor` representing scale factors that apply to 73 | [height, width] of input image. 74 | output_size: 2D `Tensor` or `int` representing [height, width] of target 75 | output image size. 76 | offset: 2D `Tensor` representing top-left corner [y0, x0] to crop scaled 77 | boxes. 78 | paddings: 2D `Tensor` representing top/left paddings. 79 | Returns: 80 | boxes: `Tensor` of shape [N, 4] representing the scaled boxes. 81 | """ 82 | # Adjusts box coordinates based on image_scale, offset and paddings. 83 | boxes *= tf.tile(tf.expand_dims(image_scale, axis=0), [1, 2]) 84 | boxes -= tf.tile(tf.expand_dims(offset, axis=0), [1, 2]) 85 | boxes += tf.tile(tf.expand_dims(paddings, axis=0), [1, 2]) 86 | # Clips the boxes. 87 | boxes = clip_boxes(boxes, output_size) 88 | return boxes 89 | 90 | 91 | def denormalize_boxes(boxes, image_shape): 92 | """Converts boxes normalized by [height, width] to pixel coordinates. 93 | Args: 94 | boxes: a tensor whose last dimension is 4 representing the coordinates of 95 | boxes in ymin, xmin, ymax, xmax order. 96 | image_shape: a list of two integers, a two-element vector or a tensor such 97 | that all but the last dimensions are `broadcastable` to `boxes`. 
The last 98 | dimension is 2, which represents [height, width]. 99 | Returns: 100 | denormalized_boxes: a tensor whose shape is the same as `boxes` representing 101 | the denormalized boxes. 102 | Raises: 103 | ValueError: If the last dimension of boxes is not 4. 104 | """ 105 | with tf.name_scope('denormalize_boxes'): 106 | if isinstance(image_shape, list) or isinstance(image_shape, tuple): 107 | height, width = image_shape 108 | height = tf.cast(height, dtype=boxes.dtype) 109 | width = tf.cast(width, dtype=boxes.dtype) 110 | else: 111 | image_shape = tf.cast(image_shape, dtype=boxes.dtype) 112 | height, width = tf.split(image_shape, 2, axis=-1) 113 | 114 | ymin, xmin, ymax, xmax = tf.split(boxes, 4, axis=-1) 115 | ymin = ymin * height 116 | xmin = xmin * width 117 | ymax = ymax * height 118 | xmax = xmax * width 119 | 120 | denormalized_boxes = tf.concat([ymin, xmin, ymax, xmax], axis=-1) 121 | return denormalized_boxes 122 | 123 | 124 | def resize_and_pad_default( 125 | image, is_training, is_input=True, masks=None, boxes=None, box_labels=None, 126 | random_scale_min=None, random_scale_max=None, random_scale_ratio=None, 127 | resize_method=None, is_history=False 128 | ): 129 | """Apply `resize_and_pad` with default settings""" 130 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 131 | if masks is not None: 132 | masks = tf.image.convert_image_dtype(masks, dtype=tf.float32) 133 | if random_scale_min is None: 134 | random_scale_min = config.RANDOM_SCALE_MIN 135 | if random_scale_max is None: 136 | random_scale_max = config.RANDOM_SCALE_MAX 137 | if random_scale_ratio is None: 138 | random_scale_ratio = config.RANDOM_SCALE_RATIO 139 | if resize_method is None: 140 | resize_method ='random' if is_training else tf.image.ResizeMethod.BILINEAR 141 | if is_history: 142 | output_size = config.IMAGE_HISTORY_INPUT_SIZE 143 | elif is_input: 144 | output_size = config.IMAGE_INPUT_SIZE 145 | else: 146 | assert masks is None 147 | output_size = config.IMAGE_TARGET_SIZE 148 | return resize_and_pad( 149 | image, output_size, 150 | masks, boxes, box_labels, 151 | random_scale_min=random_scale_min, 152 | random_scale_max=random_scale_max, 153 | do_random_scale=is_training, 154 | random_scale_ratio=random_scale_ratio, 155 | resize_method=resize_method, 156 | desired_target_size=config.IMAGE_TARGET_SIZE 157 | ) 158 | 159 | 160 | def resize_and_pad( 161 | image, desired_output_size, target_image=None, boxes=None, box_labels=None, 162 | random_scale_min=0.1, random_scale_max=2.0, do_random_scale=False, 163 | shrink_both_sides=True, filter_box=True, desired_target_size=None, random_scale_ratio=0.0, 164 | resize_method=tf.image.ResizeMethod.BILINEAR, boxes_normalized=False 165 | ): 166 | """Resizes and pads an input image/video to `desired_output_size` 167 | 168 | Support random scaling augmentation if `do_random_scale` is True 169 | 170 | If `masks` or `boxes` are given, the same transformation that is applied ot the image 171 | is applied to them. Boxes can be completely removed if doing scaling augmentation, in which 172 | case the deleted boxes will not be returned. 
173 | 174 | outputs: 175 | image: The resized image/video 176 | image_mask: A mask showing which pixels are padding in the output image 177 | meta-data: Meta-data about the transformation and the boxes/masks that were also transformed 178 | """ 179 | desired_height, desired_width = desired_output_size 180 | desired_height_f = tf.cast(desired_height, dtype=tf.float32) 181 | desired_width_f = tf.cast(desired_width, dtype=tf.float32) 182 | 183 | is_video = len(image.shape) == 4 184 | 185 | if is_video: 186 | height = tf.cast(tf.shape(image)[1], tf.float32) 187 | width = tf.cast(tf.shape(image)[2], tf.float32) 188 | else: 189 | height = tf.cast(tf.shape(image)[0], tf.float32) 190 | width = tf.cast(tf.shape(image)[1], tf.float32) 191 | 192 | if boxes is not None and boxes_normalized: 193 | # Converts boxes from normalized coordinates to pixel coordinates. 194 | # Now the coordinates of boxes are w.r.t. the original image. 195 | boxes = denormalize_boxes(boxes, [height, width]) 196 | 197 | if do_random_scale: 198 | random_scale_factor = tf.random.uniform([], random_scale_min, random_scale_max) 199 | if not shrink_both_sides: 200 | # Max random is where scale * W > W_desired 201 | # scale * H > H_desired 202 | rsf_max = tf.maximum(desired_width_f / width, desired_height_f / height) 203 | random_scale_factor = tf.minimum(rsf_max, random_scale_factor) 204 | 205 | scaled_y = tf.cast(random_scale_factor * desired_height_f, tf.int32) 206 | scaled_x = tf.cast(random_scale_factor * desired_width_f, tf.int32) 207 | 208 | # Recompute the accurate scale_factor using rounded scaled image size. 209 | image_scale_y = tf.cast(scaled_y, tf.float32) / height 210 | image_scale_x = tf.cast(scaled_x, tf.float32) / width 211 | 212 | image_scale = tf.cond(tf.less( 213 | tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32), 214 | tf.cast(random_scale_ratio, tf.float32)), 215 | lambda: tf.maximum(image_scale_x, image_scale_y), 216 | lambda: tf.minimum(image_scale_x, image_scale_y)) 217 | 218 | # Don't scale any side lower than to 64 219 | # For very wide images, this truncates the edge in order to keep the resolution 220 | # reasonable 221 | image_scale = tf.maximum(image_scale, 64.0 / tf.minimum(height, width)) 222 | 223 | # Select non-zero random offset (x, y) if scaled image is larger than 224 | scaled_height = tf.cast(height * image_scale, tf.int32) 225 | scaled_width = tf.cast(width * image_scale, tf.int32) 226 | offset_y = tf.cast(scaled_height - desired_height, tf.float32) 227 | offset_x = tf.cast(scaled_width - desired_width, tf.float32) 228 | offset_y = tf.maximum(0.0, offset_y) * tf.random.uniform([], 0, 1) 229 | offset_x = tf.maximum(0.0, offset_x) * tf.random.uniform([], 0, 1) 230 | offset_y = tf.cast(offset_y, tf.int32) 231 | offset_x = tf.cast(offset_x, tf.int32) 232 | else: 233 | image_scale_y = desired_height_f / height 234 | image_scale_x = desired_width_f / width 235 | image_scale = tf.minimum(image_scale_x, image_scale_y) 236 | scaled_height = tf.cast(height * image_scale, tf.int32) 237 | scaled_width = tf.cast(width * image_scale, tf.int32) 238 | offset_y = tf.constant(0) 239 | offset_x = tf.constant(0) 240 | 241 | # Now resize and crop 242 | if resize_method == 'random' and do_random_scale and (not tf.executing_eagerly()): 243 | resize_methods = sorted([k for k in tf.image.ResizeMethod.__dict__.keys() if k.isupper()]) 244 | # print("Random resize method:\n{}".format(','.join(resize_methods))) 245 | image = apply_with_random_selector( 246 | image, 247 | lambda x, method_idx: 
tf.image.resize(x, [scaled_height, scaled_width], 248 | tf.image.ResizeMethod.__dict__[resize_methods[method_idx]], 249 | antialias=True), 250 | num_cases=len(resize_methods)) 251 | 252 | elif resize_method != 'random': 253 | image = tf.image.resize(image, [scaled_height, scaled_width], method=resize_method, antialias=True) 254 | else: 255 | image = tf.image.resize(image, [scaled_height, scaled_width], 256 | method=tf.image.ResizeMethod.BILINEAR, antialias=True) 257 | 258 | image = tf.clip_by_value(image, 0.0, 1.0) 259 | 260 | if is_video: 261 | # frames x H x W x C 262 | image = image[:,offset_y:offset_y + desired_height, offset_x:offset_x + desired_width, :] 263 | H = tf.shape(image)[1] 264 | W = tf.shape(image)[2] 265 | else: 266 | # H x W x C 267 | image = image[offset_y:offset_y + desired_height, offset_x:offset_x + desired_width, :] 268 | H = tf.shape(image)[0] 269 | W = tf.shape(image)[1] 270 | 271 | top_pad = (desired_height - H) // 2 272 | left_pad = (desired_width - W) // 2 273 | 274 | # Get the mask which indicates which regions were padded 275 | mask = tf.ones(tf.concat([tf.shape(image)[:-1], [1]], 0), dtype=tf.int32) 276 | image_mask = tf.squeeze(tf.image.pad_to_bounding_box( 277 | mask, top_pad, left_pad, desired_height, desired_width), -1) 278 | 279 | image = tf.image.pad_to_bounding_box( 280 | image, top_pad, left_pad, desired_height, desired_width) 281 | 282 | if is_video: 283 | image.set_shape([None, desired_height, desired_width, 3]) 284 | else: 285 | image.set_shape([desired_height, desired_width, 3]) 286 | 287 | if target_image is not None and tf.size(target_image) != 0: 288 | target_image = tf.image.resize( 289 | target_image, [scaled_height, scaled_width], 290 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) 291 | if len(target_image.shape) == 3: 292 | target_image = target_image[offset_y:offset_y + desired_height, offset_x:offset_x + desired_width] 293 | else: 294 | target_image = target_image[:, offset_y:offset_y + desired_height, offset_x:offset_x + desired_width] 295 | 296 | target_image = tf.image.pad_to_bounding_box( 297 | target_image, top_pad, left_pad, desired_height, desired_width) 298 | target = tf.image.resize(target_image, desired_target_size, 299 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) 300 | else: 301 | target = None 302 | 303 | indices = None 304 | if boxes is not None: 305 | boxes = resize_and_crop_boxes( 306 | boxes, 307 | tf.stack([image_scale, image_scale]), 308 | [desired_height, desired_width], 309 | tf.cast(tf.stack([offset_y, offset_x]), dtype=tf.float32), 310 | tf.cast(tf.stack([top_pad, left_pad]), dtype=tf.float32)) 311 | 312 | if filter_box: 313 | indices = get_non_empty_box_indices(boxes) 314 | else: 315 | indices = tf.range(tf.shape(boxes)[0]) 316 | boxes = tf.gather(boxes, indices) 317 | 318 | if box_labels is not None: 319 | box_labels = tf.gather(box_labels, indices) 320 | 321 | # Stores meta meta-data about how the image was resized, needed if we want 322 | # reverse the padding/resizing later 323 | image_info = tf.stack([ 324 | tf.cast(top_pad, tf.float32), 325 | tf.cast(left_pad, tf.float32), 326 | 1.0 / image_scale, 327 | height, 328 | width, 329 | tf.cast(offset_y, dtype=tf.float32) / height, 330 | tf.cast(offset_x, dtype=tf.float32) / width, 331 | tf.cast(offset_y, dtype=tf.float32), 332 | tf.cast(offset_x, dtype=tf.float32), 333 | tf.cast(scaled_height, dtype=tf.float32), 334 | tf.cast(scaled_width, dtype=tf.float32), 335 | ]) 336 | 337 | outputs = (image_info, target, boxes, box_labels, indices) 338 | return image, 
image_mask, outputs 339 | 340 | 341 | def trim_or_pad_tf(x, seq_len, pad_constant=0): 342 | x = x[:seq_len] 343 | sh = list(x.shape) 344 | sh[0] = seq_len 345 | x = tf.pad( 346 | x, 347 | [[0, seq_len-tf.shape(x)[0]]] + [[0, 0]]*(len(sh)-1), 348 | constant_values=pad_constant, 349 | ) 350 | return tf.ensure_shape(x, sh) 351 | 352 | 353 | def trim_or_pad_tf_2d(x, batch, seq_len): 354 | x = x[:batch, :seq_len] 355 | sh = [batch, seq_len] + list(x.shape)[2:] 356 | x = tf.pad(x, 357 | [[0, batch-tf.shape(x)[0]]] + 358 | [[0, seq_len-tf.shape(x)[1]]] + 359 | [[0, 0]]*(len(sh)-2)) 360 | return tf.ensure_shape(x, sh) 361 | 362 | 363 | def values_to_tokens(vals, clss=None): 364 | """Convert real values to quantized text tokens""" 365 | vals = tf.convert_to_tensor(vals) 366 | num_bins = config.NUM_DETECTION_BIN 367 | vocab_start = config.VOCAB_START 368 | quantized_boxes = tf.cast(vals * (num_bins-1), tf.int32) 369 | 370 | # For values that were exactly one 371 | vals = tf.constant([f'' for i in range(vocab_start, vocab_start+num_bins)]) 372 | tokens = tf.gather(vals, quantized_boxes) 373 | 374 | if clss is not None: 375 | tokens = tf.concat([tokens, tf.expand_dims(clss, 1)], axis=-1) 376 | 377 | return tokens 378 | 379 | 380 | def _shift_right_by_one(tensor: tf.Tensor, bos_id: int = 0) -> tf.Tensor: 381 | """Shift the input tensor to the right by one position without wrapping 382 | 383 | From seqio: https://github.com/google/seqio 384 | """ 385 | 386 | if not (tensor.dtype.is_integer or tensor.dtype.is_floating): 387 | raise ValueError(f"Only numeric types are supported. Got: {tensor.dtype}") 388 | # tf.roll wraps around the axis. 389 | rolled = tf.roll(tensor, shift=1, axis=0) 390 | 391 | # Zero out the first position by multiplying with [0, 1, 1, ..., 1]. 392 | depth = tf.shape(tensor)[0] 393 | mask = tf.one_hot(0, depth=depth, on_value=0, off_value=1, dtype=tensor.dtype) 394 | 395 | # Expand dims of mask to broadcast to rolled. 396 | dim_expansion = [slice(None, None)] + [None] * (len(rolled.shape) - 1) 397 | mask = mask[dim_expansion] 398 | return rolled * mask + (1 - mask) * bos_id 399 | 400 | 401 | def make_autoregressive_inputs( 402 | targets: tf.Tensor, 403 | sequence_id: tf.Tensor = None, 404 | output_dtype: Optional[tf.dtypes.DType] = None, 405 | bos_id: int = 0, 406 | ) -> tf.Tensor: 407 | """Shift tokens right and add BOS to build decoder inputs 408 | 409 | from seqio: https://github.com/google/seqio 410 | """ 411 | 412 | output_dtype = output_dtype or targets.dtype 413 | if sequence_id is not None and not sequence_id.dtype.is_integer: 414 | raise ValueError( 415 | "The sequence_id should be integer-valued tensors for a packed dataset." 416 | ) 417 | if sequence_id is not None and len(targets.shape) > 1: 418 | raise ValueError( 419 | "Only 1-D sequences are supported with packing. Got a " 420 | f"packed {len(targets.shape)}-D sequence." 421 | ) 422 | 423 | inputs = _shift_right_by_one(targets, bos_id) 424 | if inputs.dtype != output_dtype: 425 | inputs = tf.cast(inputs, output_dtype) 426 | 427 | # We should have a 0 at the beginning of each sequence rather than the 428 | # shifted EOS (e.g. 1) from the previous sequence. 
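  # e.g. with bos_id=0, targets [3, 8, 1, 9, 4, 1] (EOS=1) and sequence_id [1, 1, 1, 2, 2, 2]: the plain shift gives [0, 3, 8, 1, 9, 4], and the masking below replaces the leaked EOS with BOS, giving [0, 3, 8, 0, 9, 4].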
429 | if sequence_id is not None: 430 | not_first_in_sequence = tf.equal( 431 | sequence_id, _shift_right_by_one(sequence_id) 432 | ) 433 | not_first_in_sequence = tf.cast(not_first_in_sequence, output_dtype) 434 | first_ids = tf.cast((1 - not_first_in_sequence) * bos_id, output_dtype) 435 | inputs = inputs * not_first_in_sequence + first_ids 436 | return inputs 437 | 438 | 439 | def normalize_image(image, 440 | offset=(0.48145466, 0.4578275, 0.40821073), 441 | scale=(0.26862954, 0.26130258, 0.27577711)): 442 | """Normalizes the image, using CLIP's scale/offset values by default""" 443 | shape = [1]*(len(image.shape) - 1) + [3] 444 | image -= tf.constant(offset, dtype=image.dtype, shape=shape) 445 | image /= tf.constant(scale, dtype=image.dtype, shape=shape) 446 | return image 447 | 448 | 449 | def unnormalize_image(image, 450 | offset=(0.48145466, 0.4578275, 0.40821073), 451 | scale=(0.26862954, 0.26130258, 0.27577711)): 452 | shape = [1]*(len(image.shape) - 1) + [3] 453 | image *= tf.constant(scale, dtype=image.dtype, shape=shape) 454 | image += tf.constant(offset, dtype=image.dtype, shape=shape) 455 | return image 456 | 457 | 458 | def sample_patches(mask, n_patches): 459 | """Select `n_patches` positions from `mask`""" 460 | input_sample_valid = tf.boolean_mask(tf.range(tf.shape(mask)[0]), mask) 461 | input_sample_masked = tf.boolean_mask(tf.range(tf.shape(mask)[0]), mask == 0) 462 | encoder_pos_ids = tf.concat([ 463 | tf.random.shuffle(input_sample_valid), 464 | tf.random.shuffle(input_sample_masked)], axis=0)[:n_patches] 465 | encoder_pos_ids = tf.reshape(encoder_pos_ids, (n_patches,)) 466 | encoder_pos_ids = tf.cast(encoder_pos_ids, tf.int32) 467 | return encoder_pos_ids 468 | -------------------------------------------------------------------------------- /uio2/get_modality_processor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | 3 | from uio2 import config 4 | from uio2.audio_embedder import AudioFeature 5 | from uio2.config import ImageVitFeatureConfig, AudioVitFeatureConfig, ImageResamplerConfig, \ 6 | AudioResamplerConfig, AudioViTVQGANConfig, VQGANConfig 7 | from uio2.input_modalities import InputImageViTEncoder, InputImageHistoryViTEncoder, \ 8 | InputAudioViTEncoder, InputAudioHistoryViTEncoder, InputTextEncoder, ModalityEncoder 9 | from uio2.target_modalities import TargetTextEncoder, TargetImageVQGANEmbedder, \ 10 | TargetAudioVQGANEmbedder 11 | from uio2.image_embedder import ImageFeature 12 | 13 | 14 | class ModuleReference: 15 | # Used as part of a hack to handle a case where multiple modules want a reference to 16 | # a shared submodule in UIO2. 17 | # 18 | # In particular `InputImageHistoryViTEncoder` and `InputImageViTEncoder` both need a reference to 19 | # the `ImageFeature` module, which causes issues where the state dict includes the 20 | # `ImageFeature` parameters twice, once for each reference. 21 | # 22 | # I am not sure what the canonical solution to this is, but as a hack we wrap the 23 | # `ImageFeature` in this class for the history encoder so it has a reference to the module, 24 | # but does not register the module.
Then the state_dict will onlu include one copy of the 25 | # `ImageFeature` parameters 26 | def __init__(self, module): 27 | self.module = module 28 | 29 | @property 30 | def config(self): 31 | return self.module.config 32 | 33 | def __call__(self, *args, **kwargs): 34 | return self.module(*args, **kwargs) 35 | 36 | 37 | def get_input_modalities( 38 | input_modality=tuple(config.INPUT_MODALITIES), 39 | image_vit_cfg: ImageVitFeatureConfig=ImageVitFeatureConfig(), 40 | audio_vit_cfg: AudioVitFeatureConfig=AudioVitFeatureConfig(), 41 | image_history_cfg: ImageResamplerConfig=ImageResamplerConfig(), 42 | audio_history_cfg: AudioResamplerConfig=AudioResamplerConfig(), 43 | use_image_vit = False, 44 | use_audio_vit = False, 45 | freeze_vit=False, 46 | use_image_history_vit = False, 47 | use_audio_history_vit = False, 48 | ) -> Dict[str, ModalityEncoder]: 49 | """Returns the ModalityEncoder for the input modalities""" 50 | 51 | out = dict() 52 | if 'text' in input_modality: 53 | out["text"] = InputTextEncoder() 54 | 55 | image_encoder = None 56 | if "image" in input_modality or "image_history" in input_modality: 57 | if use_image_vit or use_image_history_vit: 58 | image_encoder = ImageFeature(image_vit_cfg) 59 | 60 | audio_encoder = None 61 | if "audio" in input_modality or "audio_history" in input_modality: 62 | if use_audio_vit or use_audio_history_vit: 63 | audio_encoder = AudioFeature(audio_vit_cfg) 64 | 65 | if 'image' in input_modality: 66 | out["image"] = InputImageViTEncoder( 67 | image_encoder if use_image_vit else None, use_image_vit, freeze_vit) 68 | 69 | if 'image_history' in input_modality: 70 | encoder = image_encoder if use_image_history_vit else None 71 | if "image" in input_modality and encoder is not None: 72 | encoder = ModuleReference(encoder) 73 | out["image_history"] = InputImageHistoryViTEncoder(encoder, image_history_cfg) 74 | 75 | if 'audio' in input_modality: 76 | out["audio"] = InputAudioViTEncoder(audio_encoder if use_audio_vit else None, use_audio_vit, freeze_vit) 77 | 78 | if 'audio_history' in input_modality: 79 | encoder = audio_encoder if use_audio_history_vit else None 80 | if "audio" in input_modality and encoder is not None: 81 | encoder = ModuleReference(encoder) 82 | out["audio_history"] = InputAudioHistoryViTEncoder(encoder, audio_history_cfg) 83 | assert len(out) > 0 84 | return out 85 | 86 | 87 | def get_target_modalities( 88 | target_modality=tuple(config.TARGET_MODALITIES), 89 | image_vqgan_config: VQGANConfig=VQGANConfig(), 90 | audio_vqgan_config: AudioViTVQGANConfig=AudioViTVQGANConfig(), 91 | ) -> Dict[str, ModalityEncoder]: 92 | """Return the encoders to use for target modalities""" 93 | 94 | out = {} 95 | if 'text' in target_modality: 96 | out['text'] = TargetTextEncoder() 97 | if 'image' in target_modality: 98 | out['image'] = TargetImageVQGANEmbedder(image_vqgan_config) 99 | if 'audio' in target_modality: 100 | out['audio'] = TargetAudioVQGANEmbedder(audio_vqgan_config) 101 | assert len(out) > 0 102 | return out 103 | 104 | -------------------------------------------------------------------------------- /uio2/get_model.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from uio2.config import Config 4 | from uio2.get_modality_processor import get_input_modalities, get_target_modalities 5 | from uio2.preprocessing import UnifiedIOPreprocessor 6 | from uio2.model import UnifiedIOModel 7 | 8 | 9 | def get_model(config: Config, tokenizer_path) -> Tuple[UnifiedIOPreprocessor, 
UnifiedIOModel]: 10 | """Return a model (with newly initialized parameters) and a preprocessor for the given configuration""" 11 | preprocessor = UnifiedIOPreprocessor.from_config(config, tokenizer_path) 12 | model = UnifiedIOModel(config) 13 | return preprocessor, model 14 | -------------------------------------------------------------------------------- /uio2/hifigan/README.md: -------------------------------------------------------------------------------- 1 | # HiFi-GAN 2 | 3 | UIO2 uses a custom-trained HiFi-GAN to convert audio spectrograms into waveforms. 4 | The modelling code comes from [here](https://github.com/jik876/hifi-gan). 5 | -------------------------------------------------------------------------------- /uio2/hifigan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/unified-io-2.pytorch/6e487bec6be8f9b909453a5f9833c49914f4a777/uio2/hifigan/__init__.py -------------------------------------------------------------------------------- /uio2/hifigan/checkpoints/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "2", 3 | "num_gpus": 8, 4 | "batch_size": 512, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [8,8,4], 12 | "upsample_kernel_sizes": [16,16,8], 13 | "upsample_initial_channel": 256, 14 | "resblock_kernel_sizes": [3,5,7], 15 | "resblock_dilation_sizes": [[1,2], [2,6], [3,12]], 16 | 17 | "segment_size": 8192, 18 | "num_mels": 128, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 256, 22 | "win_size": 1024, 23 | 24 | "sampling_rate": 16000, 25 | 26 | "fmin": 0, 27 | "fmax": 8000, 28 | "fmax_for_loss": null, 29 | 30 | "num_workers": 8, 31 | 32 | "dist_config": { 33 | "dist_backend": "nccl", 34 | "dist_url": "tcp://127.0.0.1:52111", 35 | "world_size": 1 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /uio2/hifigan/checkpoints/g_00930000: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/unified-io-2.pytorch/6e487bec6be8f9b909453a5f9833c49914f4a777/uio2/hifigan/checkpoints/g_00930000 -------------------------------------------------------------------------------- /uio2/hifigan/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 5 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 6 | from .utils import init_weights, get_padding 7 | 8 | 9 | LRELU_SLOPE = 0.1 10 | 11 | 12 | class ResBlock1(torch.nn.Module): 13 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 14 | super(ResBlock1, self).__init__() 15 | self.h = h 16 | self.convs1 = nn.ModuleList([ 17 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 18 | padding=get_padding(kernel_size, dilation[0]))), 19 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 20 | padding=get_padding(kernel_size, dilation[1]))), 21 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 22 | padding=get_padding(kernel_size, dilation[2]))) 23 | ]) 24 | self.convs1.apply(init_weights) 25 | 26 | self.convs2 = nn.ModuleList([ 27 | weight_norm(Conv1d(channels, channels, kernel_size, 1,
dilation=1, 28 | padding=get_padding(kernel_size, 1))), 29 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 30 | padding=get_padding(kernel_size, 1))), 31 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 32 | padding=get_padding(kernel_size, 1))) 33 | ]) 34 | self.convs2.apply(init_weights) 35 | 36 | def forward(self, x): 37 | for c1, c2 in zip(self.convs1, self.convs2): 38 | xt = F.leaky_relu(x, LRELU_SLOPE) 39 | xt = c1(xt) 40 | xt = F.leaky_relu(xt, LRELU_SLOPE) 41 | xt = c2(xt) 42 | x = xt + x 43 | return x 44 | 45 | def remove_weight_norm(self): 46 | for l in self.convs1: 47 | remove_weight_norm(l) 48 | for l in self.convs2: 49 | remove_weight_norm(l) 50 | 51 | 52 | class ResBlock2(torch.nn.Module): 53 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): 54 | super(ResBlock2, self).__init__() 55 | self.h = h 56 | self.convs = nn.ModuleList([ 57 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 58 | padding=get_padding(kernel_size, dilation[0]))), 59 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 60 | padding=get_padding(kernel_size, dilation[1]))) 61 | ]) 62 | self.convs.apply(init_weights) 63 | 64 | def forward(self, x): 65 | for c in self.convs: 66 | xt = F.leaky_relu(x, LRELU_SLOPE) 67 | xt = c(xt) 68 | x = xt + x 69 | return x 70 | 71 | def remove_weight_norm(self): 72 | for l in self.convs: 73 | remove_weight_norm(l) 74 | 75 | 76 | class Generator(torch.nn.Module): 77 | def __init__(self, h): 78 | super(Generator, self).__init__() 79 | self.h = h 80 | self.num_kernels = len(h.resblock_kernel_sizes) 81 | self.num_upsamples = len(h.upsample_rates) 82 | self.conv_pre = weight_norm(Conv1d(128, h.upsample_initial_channel, 7, 1, padding=3)) 83 | resblock = ResBlock1 if h.resblock == '1' else ResBlock2 84 | 85 | self.ups = nn.ModuleList() 86 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 87 | self.ups.append(weight_norm( 88 | ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)), 89 | k, u, padding=(k-u)//2))) 90 | 91 | self.resblocks = nn.ModuleList() 92 | for i in range(len(self.ups)): 93 | ch = h.upsample_initial_channel//(2**(i+1)) 94 | for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): 95 | self.resblocks.append(resblock(h, ch, k, d)) 96 | 97 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 98 | self.ups.apply(init_weights) 99 | self.conv_post.apply(init_weights) 100 | 101 | def forward(self, x): 102 | x = self.conv_pre(x) 103 | for i in range(self.num_upsamples): 104 | x = F.leaky_relu(x, LRELU_SLOPE) 105 | x = self.ups[i](x) 106 | xs = None 107 | for j in range(self.num_kernels): 108 | if xs is None: 109 | xs = self.resblocks[i*self.num_kernels+j](x) 110 | else: 111 | xs += self.resblocks[i*self.num_kernels+j](x) 112 | x = xs / self.num_kernels 113 | x = F.leaky_relu(x) 114 | x = self.conv_post(x) 115 | x = torch.tanh(x) 116 | 117 | return x 118 | 119 | def remove_weight_norm(self): 120 | for l in self.ups: 121 | remove_weight_norm(l) 122 | for l in self.resblocks: 123 | l.remove_weight_norm() 124 | remove_weight_norm(self.conv_pre) 125 | remove_weight_norm(self.conv_post) 126 | 127 | 128 | class DiscriminatorP(torch.nn.Module): 129 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 130 | super(DiscriminatorP, self).__init__() 131 | self.period = period 132 | norm_f = weight_norm if use_spectral_norm == False else 
spectral_norm 133 | self.convs = nn.ModuleList([ 134 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 135 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 136 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 137 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 138 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), 139 | ]) 140 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 141 | 142 | def forward(self, x): 143 | fmap = [] 144 | 145 | # 1d to 2d 146 | b, c, t = x.shape 147 | if t % self.period != 0: # pad first 148 | n_pad = self.period - (t % self.period) 149 | x = F.pad(x, (0, n_pad), "reflect") 150 | t = t + n_pad 151 | x = x.view(b, c, t // self.period, self.period) 152 | 153 | for l in self.convs: 154 | x = l(x) 155 | x = F.leaky_relu(x, LRELU_SLOPE) 156 | fmap.append(x) 157 | x = self.conv_post(x) 158 | fmap.append(x) 159 | x = torch.flatten(x, 1, -1) 160 | 161 | return x, fmap 162 | 163 | 164 | class MultiPeriodDiscriminator(torch.nn.Module): 165 | def __init__(self): 166 | super(MultiPeriodDiscriminator, self).__init__() 167 | self.discriminators = nn.ModuleList([ 168 | DiscriminatorP(2), 169 | DiscriminatorP(3), 170 | DiscriminatorP(5), 171 | DiscriminatorP(7), 172 | DiscriminatorP(11), 173 | ]) 174 | 175 | def forward(self, y, y_hat): 176 | y_d_rs = [] 177 | y_d_gs = [] 178 | fmap_rs = [] 179 | fmap_gs = [] 180 | for i, d in enumerate(self.discriminators): 181 | y_d_r, fmap_r = d(y) 182 | y_d_g, fmap_g = d(y_hat) 183 | y_d_rs.append(y_d_r) 184 | fmap_rs.append(fmap_r) 185 | y_d_gs.append(y_d_g) 186 | fmap_gs.append(fmap_g) 187 | 188 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 189 | 190 | 191 | class DiscriminatorS(torch.nn.Module): 192 | def __init__(self, use_spectral_norm=False): 193 | super(DiscriminatorS, self).__init__() 194 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 195 | self.convs = nn.ModuleList([ 196 | norm_f(Conv1d(1, 128, 15, 1, padding=7)), 197 | norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), 198 | norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), 199 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), 200 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), 201 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), 202 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 203 | ]) 204 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 205 | 206 | def forward(self, x): 207 | fmap = [] 208 | for l in self.convs: 209 | x = l(x) 210 | x = F.leaky_relu(x, LRELU_SLOPE) 211 | fmap.append(x) 212 | x = self.conv_post(x) 213 | fmap.append(x) 214 | x = torch.flatten(x, 1, -1) 215 | 216 | return x, fmap 217 | 218 | 219 | class MultiScaleDiscriminator(torch.nn.Module): 220 | def __init__(self): 221 | super(MultiScaleDiscriminator, self).__init__() 222 | self.discriminators = nn.ModuleList([ 223 | DiscriminatorS(use_spectral_norm=True), 224 | DiscriminatorS(), 225 | DiscriminatorS(), 226 | ]) 227 | self.meanpools = nn.ModuleList([ 228 | AvgPool1d(4, 2, padding=2), 229 | AvgPool1d(4, 2, padding=2) 230 | ]) 231 | 232 | def forward(self, y, y_hat): 233 | y_d_rs = [] 234 | y_d_gs = [] 235 | fmap_rs = [] 236 | fmap_gs = [] 237 | for i, d in enumerate(self.discriminators): 238 | if i != 0: 239 | y = self.meanpools[i-1](y) 240 | y_hat = self.meanpools[i-1](y_hat) 241 | y_d_r, fmap_r = d(y) 242 | y_d_g, fmap_g = d(y_hat) 243 | 
y_d_rs.append(y_d_r) 244 | fmap_rs.append(fmap_r) 245 | y_d_gs.append(y_d_g) 246 | fmap_gs.append(fmap_g) 247 | 248 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 249 | 250 | 251 | def feature_loss(fmap_r, fmap_g): 252 | loss = 0 253 | for dr, dg in zip(fmap_r, fmap_g): 254 | for rl, gl in zip(dr, dg): 255 | loss += torch.mean(torch.abs(rl - gl)) 256 | 257 | return loss*2 258 | 259 | 260 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 261 | loss = 0 262 | r_losses = [] 263 | g_losses = [] 264 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 265 | r_loss = torch.mean((1-dr)**2) 266 | g_loss = torch.mean(dg**2) 267 | loss += (r_loss + g_loss) 268 | r_losses.append(r_loss.item()) 269 | g_losses.append(g_loss.item()) 270 | 271 | return loss, r_losses, g_losses 272 | 273 | 274 | def generator_loss(disc_outputs): 275 | loss = 0 276 | gen_losses = [] 277 | for dg in disc_outputs: 278 | l = torch.mean((1-dg)**2) 279 | gen_losses.append(l) 280 | loss += l 281 | 282 | return loss, gen_losses 283 | 284 | -------------------------------------------------------------------------------- /uio2/hifigan/utils.py: -------------------------------------------------------------------------------- 1 | from torch.nn.utils import weight_norm 2 | 3 | def init_weights(m, mean=0.0, std=0.01): 4 | classname = m.__class__.__name__ 5 | if classname.find("Conv") != -1: 6 | m.weight.data.normal_(mean, std) 7 | 8 | 9 | def apply_weight_norm(m): 10 | classname = m.__class__.__name__ 11 | if classname.find("Conv") != -1: 12 | weight_norm(m) 13 | 14 | 15 | def get_padding(kernel_size, dilation=1): 16 | return int((kernel_size*dilation - dilation)/2) 17 | 18 | -------------------------------------------------------------------------------- /uio2/image_embedder.py: -------------------------------------------------------------------------------- 1 | """Model that builds patch features from an image""" 2 | import math 3 | from typing import Any, Optional 4 | 5 | import torch 6 | 7 | from uio2.config import ImageVitFeatureConfig, AudioVitFeatureConfig 8 | 9 | from uio2 import layers 10 | from torch import nn 11 | from torch.nn import functional as F 12 | 13 | 14 | class MLP(nn.Module): 15 | def __init__(self, config): 16 | super().__init__() 17 | self.config = config 18 | self.fc1 = nn.Linear(config.emb_dim, config.mlp_dim, bias=True) 19 | self.gelu = nn.GELU(approximate='none') 20 | self.fc2 = nn.Linear(config.mlp_dim, config.emb_dim, bias=True) 21 | 22 | def forward(self, x): 23 | x = self.fc1(x) 24 | x = self.gelu(x) 25 | x = self.fc2(x) 26 | return x 27 | 28 | 29 | class MultiHeadDotProductAttention(nn.Module): 30 | def __init__( 31 | self, 32 | emb_dim, 33 | num_heads: int, 34 | head_dim: int, 35 | dropout_rate: float = 0., 36 | float32_logits: bool = False # computes logits in float32 for stability. 
37 | ): 38 | super().__init__() 39 | self.num_heads = num_heads 40 | self.head_dim = head_dim 41 | assert emb_dim == num_heads * head_dim, "embed_dim must be divisible by num_heads" 42 | self.scale = self.head_dim ** -0.5 43 | self.dropout_rate = dropout_rate 44 | self.float32_logits = float32_logits 45 | 46 | self.query_in_proj_weight = nn.Parameter(torch.randn(emb_dim, emb_dim) * self.scale) 47 | self.query_in_proj_bias = nn.Parameter(torch.zeros(emb_dim)) 48 | self.key_in_proj_weight = nn.Parameter(torch.randn(emb_dim, emb_dim) * self.scale) 49 | self.key_in_proj_bias = nn.Parameter(torch.zeros(emb_dim)) 50 | self.value_in_proj_weight = nn.Parameter(torch.randn(emb_dim, emb_dim) * self.scale) 51 | self.value_in_proj_bias = nn.Parameter(torch.zeros(emb_dim)) 52 | 53 | self.attn_drop = layers.Dropout(dropout_rate, broadcast_dims=(-2, )) 54 | self.out_proj = nn.Linear(emb_dim, emb_dim, bias=True) 55 | 56 | def forward(self, inputs_q, inputs_kv, attn_mask: Optional[torch.Tensor] = None): 57 | # inputs_q: [batch_size, len_q, emb_dim] 58 | # inputs_kv: [batch_size, len_kv, emb_dim] 59 | # attn_mask: [batch_size, num_heads, len_q, len_kv] 60 | 61 | # Project inputs_q/inputs_kv to multi-headed q/k/v 62 | # dimensions are then [batch, len, num_heads, head_dim] 63 | bs, q_len, emb_dim = inputs_q.shape 64 | kv_len = inputs_kv.shape[1] 65 | query = F.linear(inputs_q, self.query_in_proj_weight, self.query_in_proj_bias).reshape( 66 | bs, q_len, self.num_heads, self.head_dim 67 | ) 68 | key = F.linear(inputs_kv, self.key_in_proj_weight, self.key_in_proj_bias).reshape( 69 | bs, kv_len, self.num_heads, self.head_dim 70 | ) 71 | value = F.linear(inputs_kv, self.value_in_proj_weight, self.value_in_proj_bias).reshape( 72 | bs, kv_len, self.num_heads, self.head_dim 73 | ) 74 | 75 | if self.float32_logits: 76 | query = query.to(torch.float32) 77 | key = key.to(torch.float32) 78 | 79 | query = query * self.scale 80 | # `attn_weights`: [batch, num_heads, len_q, len_kv] 81 | attn_weights = torch.einsum("bqhd,bkhd->bhqk", query, key) 82 | 83 | if attn_mask is not None: 84 | new_attn_mask = torch.zeros_like(attn_mask, dtype=attn_weights.dtype) 85 | new_attn_mask.masked_fill_(~(attn_mask > 0), -1e10) 86 | attn_mask = new_attn_mask 87 | attn_weights += attn_mask 88 | 89 | attn_weights = F.softmax(attn_weights, dim=-1).to(inputs_q.dtype) 90 | attn_weights = self.attn_drop(attn_weights) 91 | 92 | # `attn_out`: [batch, len_q, num_heads, head_dim] 93 | attn_out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value) 94 | # `out`: [batch, len_q, emb_dim] 95 | out = self.out_proj(attn_out.reshape(bs, q_len, emb_dim)) 96 | 97 | return out 98 | 99 | 100 | class ResidualAttentionBlock(nn.Module): 101 | def __init__(self, config): 102 | super().__init__() 103 | self.config = config 104 | self.ln_1 = nn.LayerNorm(config.emb_dim, eps=1e-5) 105 | self.attn = MultiHeadDotProductAttention( 106 | config.emb_dim, 107 | config.num_heads, 108 | config.head_dim, 109 | config.dropout_rate, 110 | # The uio2 jax code did not use this parameter. 
111 | # float32_logits=config.float32_attention_logits 112 | ) 113 | self.ln_2 = nn.LayerNorm(config.emb_dim, eps=1e-5) 114 | self.mlp = MLP(config) 115 | 116 | def forward(self, x, attn_mask): 117 | x1 = self.ln_1(x) 118 | x2 = self.attn(x1, x1, attn_mask) 119 | x = x + x2 120 | x1 = self.ln_2(x) 121 | x2 = self.mlp(x1) 122 | x = x + x2 123 | return x 124 | 125 | 126 | class Transformer(nn.Module): 127 | def __init__(self, config): 128 | super().__init__() 129 | self.config = config 130 | self.num_layers = config.num_layers 131 | resblocks = [] 132 | for i in range(config.num_layers): 133 | resblocks.append(ResidualAttentionBlock(config)) 134 | self.resblocks = nn.ModuleList(resblocks) 135 | 136 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 137 | xs = [] 138 | for r in self.resblocks: 139 | x = r(x, attn_mask) 140 | xs.append(x) 141 | 142 | return x, xs 143 | 144 | 145 | def _expand_token(token, batch_size: int): 146 | return token.view(1, 1, -1).expand(batch_size, -1, -1) 147 | 148 | 149 | class VisionTransformer(nn.Module): 150 | def __init__(self, config): 151 | super().__init__() 152 | self.config = config 153 | 154 | input_dim = config.patch_size * config.patch_size * 3 155 | self.embedding = nn.Linear(input_dim, config.emb_dim, bias=False) 156 | scale = config.emb_dim 157 | self.class_embedding = nn.Parameter(scale * torch.randn(config.emb_dim)) 158 | self.positional_embedding = nn.Parameter(scale * torch.randn(config.num_pos, config.emb_dim)) 159 | self.pre_ln = nn.LayerNorm(config.emb_dim, eps=1e-5) 160 | self.transformer = Transformer(config) 161 | 162 | def add_pos_emb(self, x, pos_ids, patch_num): 163 | cls_emb = self.positional_embedding[0] 164 | pos_emb = self.positional_embedding[1:] 165 | 166 | pos_emb = pos_emb.reshape( 167 | (int(math.sqrt(pos_emb.shape[0])), int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1]) 168 | ) 169 | 170 | (patch_num_0, patch_num_1) = patch_num 171 | # assert patch_num_0 == self.config.patch_size and patch_num_1 == self.config.patch_size_1 172 | if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1: 173 | # Dervied from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 174 | # antialias: default True in jax.image.resize 175 | pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2) 176 | pos_emb = F.interpolate( 177 | pos_emb, size=(patch_num_0, patch_num_1), mode="bicubic", align_corners=False, antialias=True, 178 | ) 179 | pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0) 180 | 181 | pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])[pos_ids] 182 | x = x + torch.cat([_expand_token(cls_emb, x.shape[0]), pos_emb], dim=1).to(x.dtype) 183 | return x 184 | 185 | def forward(self, x, mask, pos_ids, *, patch_num: Any = (16, 16)): 186 | B = x.shape[0] 187 | x = self.embedding(x) 188 | x = torch.cat([_expand_token(self.class_embedding, B).to(x.dtype), x], dim=1) 189 | 190 | mask = torch.cat([torch.ones([B, 1], dtype=torch.int32, device=mask.device), mask], dim=1) 191 | 192 | x = self.add_pos_emb(x, pos_ids, patch_num) 193 | 194 | x = self.pre_ln(x) 195 | 196 | attn_mask = layers.make_attention_mask(mask, mask).to(x.dtype) 197 | 198 | x, xs = self.transformer(x, attn_mask) 199 | 200 | # remove the cls token 201 | x = x[:, 1:, :] 202 | 203 | x1 = xs[1][:, 1:, :] 204 | 205 | return x, x1 206 | 207 | 208 | class ImageFeature(nn.Module): 209 | """Image features""" 210 | def __init__(self, config) -> None: 211 | super().__init__() 212 | self.config = config 213 | self.vision_transformer = 
VisionTransformer(config) 214 | 215 | def forward(self, x, mask, pos_ids, *, patch_num: Any = (16, 16)): 216 | x, x1 = self.vision_transformer(x, mask, pos_ids, patch_num=patch_num) 217 | return x, x1 218 | -------------------------------------------------------------------------------- /uio2/image_vqgan.py: -------------------------------------------------------------------------------- 1 | """VQGAN model implementation in PyTorch 2 | Derived from https://github.com/CompVis/taming-transformers. 3 | """ 4 | import math 5 | 6 | import torch 7 | from einops import einops 8 | 9 | from uio2.config import VQGANConfig 10 | from torch import nn 11 | from torch.nn import functional as F 12 | from uio2 import layers 13 | 14 | 15 | def Normalize(in_channels): 16 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) # default num_groups=32, eps=1e-6 in our jax implementation 17 | 18 | 19 | class ResBlock(nn.Module): 20 | def __init__(self, n_in: int, n_out: int): 21 | """ResNet Block""" 22 | super().__init__() 23 | self.norm1 = Normalize(n_in) 24 | self.nonlinear = nn.SiLU() 25 | self.conv1 = nn.Conv2d(n_in, n_out, kernel_size=3, stride=1, padding=1) 26 | self.norm2 = Normalize(n_out) 27 | self.conv2 = nn.Conv2d(n_out, n_out, kernel_size=3, stride=1, padding=1) 28 | self.nin_shortcut = nn.Conv2d(n_in, n_out, kernel_size=1, stride=1, padding=0) if n_in != n_out else None 29 | 30 | def forward(self, x): 31 | # [bs, c, h, w] 32 | h = x 33 | 34 | h = self.norm1(h) 35 | h = self.nonlinear(h) 36 | h = self.conv1(h) 37 | h = self.norm2(h) 38 | h = self.nonlinear(h) 39 | h = self.conv2(h) 40 | 41 | if self.nin_shortcut is not None: 42 | x = self.nin_shortcut(x) 43 | 44 | return x + h 45 | 46 | 47 | class AttnBlock(nn.Module): 48 | def __init__(self, n_in: int): 49 | """Single head self-attention layer""" 50 | super().__init__() 51 | self.norm = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-06) 52 | self.q = nn.Conv2d(n_in, n_in, kernel_size=1, stride=1, padding=0) 53 | self.k = nn.Conv2d(n_in, n_in, kernel_size=1, stride=1, padding=0) 54 | self.v = nn.Conv2d(n_in, n_in, kernel_size=1, stride=1, padding=0) 55 | self.proj_out = nn.Conv2d(n_in, n_in, kernel_size=1, stride=1, padding=0) 56 | 57 | def forward(self, x): 58 | # [bs, c, h, w] 59 | h_ = x 60 | h_ = self.norm(h_) 61 | q = self.q(h_) 62 | k = self.k(h_) 63 | v = self.v(h_) 64 | 65 | # compute attention 66 | b, c, h, w = q.shape 67 | 68 | # attend to values 69 | w_ = torch.einsum('bcq,bck->bqk', q.reshape(b, c, h*w), k.reshape(b, c, h*w)) 70 | w_ = w_ * (c ** -0.5) 71 | w_ = F.softmax(w_, dim=-1) 72 | h_ = torch.einsum('bqk,bck->bcq', w_, v.reshape(b, c, h*w)) 73 | h_ = h_.reshape(b, c, h, w) 74 | 75 | h_ = self.proj_out(h_) 76 | 77 | return x + h_ 78 | 79 | 80 | class Downsample(nn.Module): 81 | def __init__(self, n_in): 82 | """Downsampling layer""" 83 | super().__init__() 84 | # no asymmetric padding in torch conv, must do it ourselves 85 | self.conv = nn.Conv2d(n_in, n_in, kernel_size=3, stride=2, padding=0) 86 | 87 | def forward(self, x): 88 | # [bs, c, h, w] 89 | pad = (0, 1, 0, 1) 90 | x = F.pad(x, pad, mode="constant", value=0) 91 | x = self.conv(x) 92 | return x 93 | 94 | 95 | class Upsample(nn.Module): 96 | def __init__(self, n_in): 97 | """Upsampling layer""" 98 | super().__init__() 99 | self.conv = nn.Conv2d(n_in, n_in, kernel_size=3, stride=1, padding=1) 100 | 101 | def forward(self, x): 102 | # [bs, c, h, w] 103 | x = F.interpolate(x, scale_factor=2.0, mode="nearest") 104 | x = self.conv(x) 105 
| return x 106 | 107 | 108 | class Encoder(nn.Module): 109 | def __init__(self, config: VQGANConfig): 110 | super().__init__() 111 | self.config = config 112 | cfg = self.config 113 | 114 | curr_res = cfg.resolution 115 | self.num_resolutions = len(cfg.ch_mult) 116 | self.num_res_blocks = cfg.num_res_blocks 117 | in_ch_mult = (1, ) + tuple(cfg.ch_mult) 118 | 119 | # downsampling 120 | self.conv_in = nn.Conv2d(cfg.in_channels, 1 * cfg.ch, kernel_size=3, stride=1, padding=1) 121 | self.attn_levels = set() 122 | for i_level in range(self.num_resolutions): 123 | block_in = cfg.ch * in_ch_mult[i_level] 124 | block_out = cfg.ch * cfg.ch_mult[i_level] 125 | if curr_res in cfg.attn_resolutions: 126 | self.attn_levels.add(i_level) 127 | for i_block in range(self.num_res_blocks): 128 | self.add_module(f"down_{i_level}_block_{i_block}", ResBlock(block_in, block_out)) 129 | block_in = block_out 130 | if i_level in self.attn_levels: 131 | self.add_module(f"down_{i_level}_attn_{i_block}", AttnBlock(block_in)) 132 | 133 | if i_level != self.num_resolutions - 1: 134 | self.add_module(f"down_{i_level}_downsample", Downsample(block_in)) 135 | curr_res = curr_res // 2 136 | 137 | # middle 138 | self.mid_block_1 = ResBlock(block_in, block_in) 139 | self.mid_attn_1 = AttnBlock(block_in) 140 | self.mid_block_2 = ResBlock(block_in, block_in) 141 | 142 | # end 143 | self.norm_out = Normalize(block_in) 144 | self.nonlinear = nn.SiLU() 145 | self.conv_out = nn.Conv2d( 146 | block_in, 147 | 2 * cfg.z_channels if cfg.double_z else cfg.z_channels, 148 | kernel_size=3, 149 | stride=1, 150 | padding=1, 151 | ) 152 | 153 | def forward(self, x): 154 | # [bs, c, h, w] 155 | # downsampling 156 | h = self.conv_in(x) 157 | for i_level in range(self.num_resolutions): 158 | for i_block in range(self.num_res_blocks): 159 | h = self.__getattr__(f"down_{i_level}_block_{i_block}")(h) 160 | if i_level in self.attn_levels: 161 | h = self.__getattr__(f"down_{i_level}_attn_{i_block}")(h) 162 | if i_level != self.num_resolutions - 1: 163 | h = self.__getattr__(f"down_{i_level}_downsample")(h) 164 | 165 | # middle 166 | h = self.mid_block_1(h) 167 | h = self.mid_attn_1(h) 168 | h = self.mid_block_2(h) 169 | 170 | # end 171 | h = self.norm_out(h) 172 | h = self.nonlinear(h) 173 | h = self.conv_out(h) 174 | 175 | return h 176 | 177 | 178 | class Decoder(nn.Module): 179 | def __init__(self, config: VQGANConfig): 180 | super().__init__() 181 | self.config = config 182 | cfg = self.config 183 | 184 | self.num_resolutions = len(cfg.ch_mult) 185 | self.num_res_blocks = cfg.num_res_blocks 186 | 187 | # compute in_ch_mult, block_in and curr_res at lowest res 188 | in_ch_mult = (1, ) + tuple(cfg.ch_mult) 189 | curr_res = cfg.resolution // (2 ** (self.num_resolutions - 1)) 190 | block_in = cfg.ch * cfg.ch_mult[self.num_resolutions - 1] 191 | 192 | # z to block_in 193 | self.conv_in = nn.Conv2d( 194 | cfg.z_channels, 195 | block_in, 196 | kernel_size=3, 197 | stride=1, 198 | padding=1, 199 | ) 200 | 201 | # middle 202 | self.mid_block_1 = ResBlock(block_in, block_in) 203 | self.mid_attn_1 = AttnBlock(block_in) 204 | self.mid_block_2 = ResBlock(block_in, block_in) 205 | 206 | # upsampling 207 | self.attn_levels = set() 208 | for i_level in reversed(range(self.num_resolutions)): 209 | i_idx = self.num_resolutions - i_level - 1 210 | block_out = cfg.ch * cfg.ch_mult[i_level] 211 | if curr_res in cfg.attn_resolutions: 212 | self.attn_levels.add(i_level) 213 | for i_block in range(self.num_res_blocks + 1): 214 | 
self.add_module(f"up_{i_idx}_block_{i_block}", ResBlock(block_in, block_out)) 215 | block_in = block_out 216 | if i_level in self.attn_levels: 217 | self.add_module(f"up_{i_idx}_attn_{i_block}", AttnBlock(block_in)) 218 | if i_level != 0: 219 | self.add_module(f"up_{i_idx}_upsample", Upsample(block_in)) 220 | curr_res = curr_res * 2 221 | 222 | # end 223 | self.norm_out = Normalize(block_in) 224 | self.nonlinear = nn.SiLU() 225 | self.conv_out = nn.Conv2d(block_in, cfg.out_ch, kernel_size=3, stride=1, padding=1) 226 | 227 | def forward(self, z): 228 | # [bs, z_channels, h, w] 229 | # z to block_in 230 | h = self.conv_in(z) 231 | 232 | # middle 233 | h = self.mid_block_1(h) 234 | h = self.mid_attn_1(h) 235 | h = self.mid_block_2(h) 236 | 237 | # upsampling 238 | for i_level in reversed(range(self.num_resolutions)): 239 | i_idx = self.num_resolutions - i_level - 1 240 | for i_block in range(self.num_res_blocks + 1): 241 | h = self.__getattr__(f"up_{i_idx}_block_{i_block}")(h) 242 | if i_level in self.attn_levels: 243 | h = self.__getattr__(f"up_{i_idx}_attn_{i_block}")(h) 244 | if i_level != 0: 245 | h = self.__getattr__(f"up_{i_idx}_upsample")(h) 246 | 247 | # end 248 | h = self.norm_out(h) 249 | h = self.nonlinear(h) 250 | h = self.conv_out(h) 251 | 252 | return h 253 | 254 | 255 | class VQGAN(nn.Module): 256 | def __init__(self, config: VQGANConfig): 257 | """VQGAN""" 258 | super().__init__() 259 | self.config = config 260 | cfg = self.config 261 | self.embed_dim = cfg.embed_dim 262 | 263 | self.encoder = Encoder(cfg) 264 | self.quant_conv = nn.Conv2d(cfg.z_channels, cfg.embed_dim, kernel_size=1, stride=1, padding=0) 265 | self.quantize = layers.VectorQuantizer(cfg.n_embed, cfg.embed_dim, beta=0.25) 266 | self.post_quant_conv = nn.Conv2d(cfg.embed_dim, cfg.z_channels, kernel_size=1, stride=1, padding=0) 267 | self.decoder = Decoder(cfg) 268 | 269 | # initialize nn.Conv2d 270 | self.apply(self._init_weights) 271 | 272 | def _init_weights(self, m): 273 | if isinstance(m, nn.Conv2d): 274 | # lecun normal initialization 275 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight) 276 | nn.init.trunc_normal_(m.weight, std=math.sqrt(1 / fan_in), a=-2.0, b=2.0) 277 | if m.bias is not None: 278 | nn.init.zeros_(m.bias) 279 | 280 | def encode(self, x): 281 | h = self.encoder(x) 282 | h = self.quant_conv(h) 283 | quant, emb_loss, info = self.quantize(h) 284 | return quant, emb_loss, info 285 | 286 | def decode(self, quant): 287 | quant = self.post_quant_conv(quant) 288 | dec = self.decoder(quant) 289 | return dec 290 | 291 | def decode_code(self, code_b): 292 | bs, seq_len = code_b.shape 293 | size = int(math.sqrt(seq_len)) 294 | # (bs, h*w) -> (bs, c, h, w) 295 | quant_b = self.quantize.get_codebook_entry(code_b, (bs, size, size, self.embed_dim)) 296 | dec = self.decode(quant_b) 297 | return dec 298 | 299 | def get_codebook_indices(self, x, vqgan_decode=False): 300 | h = self.encoder(x) 301 | h = self.quant_conv(h) 302 | z, _, [_, _, indices] = self.quantize(h) 303 | 304 | if vqgan_decode: 305 | _ = self.decode(z) 306 | 307 | return indices.reshape(h.shape[0], -1) 308 | 309 | def forward(self, x): 310 | # [bs, c, h, w] 311 | quant, diff, _ = self.encode(x) 312 | dec = self.decode(quant) 313 | return dec -------------------------------------------------------------------------------- /uio2/perceiver.py: -------------------------------------------------------------------------------- 1 | """Resampler used for history inputs""" 2 | from typing import Union 3 | 4 | import torch 5 | 6 | from 
uio2.config import ImageResamplerConfig, AudioResamplerConfig 7 | 8 | from uio2 import layers 9 | from torch import nn 10 | 11 | 12 | class CrossAttention(nn.Module): 13 | def __init__(self, config: Union[ImageResamplerConfig, AudioResamplerConfig], droppath_rate: float = 0.0): 14 | """Cross-attention layer.""" 15 | super().__init__() 16 | self.config = config 17 | self.pre_xattention_norm = layers.UIOLayerNorm(config.emb_dim) 18 | self.xattention = layers.MultiHeadDotProductAttention( 19 | emb_dim=config.emb_dim, 20 | num_heads=config.num_heads, 21 | head_dim=config.head_dim, 22 | dropout_rate=config.dropout_rate, 23 | dropout_broadcast_dims=config.dropout_broadcast_dims, 24 | float32_logits=config.float32_attention_logits, 25 | qk_norm=config.xattn_qk_norm, 26 | clip_attn_logit=config.clip_attn_logit, 27 | scaled_cosine=config.xattn_scaled_cosine, 28 | ) 29 | self.dropout = layers.Dropout(p=config.dropout_rate, broadcast_dims=config.dropout_broadcast_dims) 30 | self.post_xattn_droppath = layers.DropPath(droppath_rate) 31 | self.pre_mlp_norm = layers.UIOLayerNorm(config.emb_dim) 32 | self.mlp = layers.MlpBlock( 33 | emb_dim=config.emb_dim, 34 | intermediate_dim=config.mlp_dim, 35 | activations=config.mlp_activations, 36 | intermediate_dropout_rate=config.dropout_rate, 37 | dropout_broadcast_dims=config.dropout_broadcast_dims, 38 | ) 39 | self.post_mlp_droppath = layers.DropPath(droppath_rate) 40 | 41 | def forward(self, latents, context, mask=None): 42 | # Cross attention block. 43 | assert context.ndim == 3 44 | assert latents.ndim == 3 45 | assert latents.shape[-1] == context.shape[-1] 46 | 47 | # q: latents. [batch, latent_length, emb_dim] 48 | # kv: context. [batch, context_length, emb_dim] 49 | inputs_q = self.pre_xattention_norm(latents) 50 | inputs_kv = context 51 | 52 | # Cross-attention 53 | # [batch, latent_length, emb_dim] x [batch, context_length, emb_dim] 54 | # => [batch, latent_length, emb_dim] 55 | x = self.xattention(inputs_q, inputs_kv, mask=mask) 56 | 57 | x = self.dropout(x) 58 | 59 | x = self.post_xattn_droppath(x) + latents 60 | 61 | # MLP block. 
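# Pre-norm residual MLP: LayerNorm -> MlpBlock -> DropPath, added back to the attention output x.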
62 | y = self.pre_mlp_norm(x) 63 | 64 | # [batch, length, emb_dim] -> [batch, length, emb_dim] 65 | y = self.mlp(y) 66 | 67 | y = self.post_mlp_droppath(y) + x 68 | return y 69 | 70 | 71 | class Attention(nn.Module): 72 | def __init__(self, config: Union[ImageResamplerConfig, AudioResamplerConfig], droppath_rate: float = 0.0): 73 | """Self-attention layer.""" 74 | super().__init__() 75 | self.config = config 76 | self.pre_attention_norm = layers.UIOLayerNorm(config.emb_dim) 77 | self.attention = layers.MultiHeadDotProductAttention( 78 | emb_dim=config.emb_dim, 79 | num_heads=config.num_heads, 80 | head_dim=config.head_dim, 81 | dropout_rate=config.dropout_rate, 82 | dropout_broadcast_dims=config.dropout_broadcast_dims, 83 | float32_logits=config.float32_attention_logits, 84 | qk_norm=config.attn_qk_norm, 85 | clip_attn_logit=config.clip_attn_logit, 86 | scaled_cosine=config.attn_scaled_cosine, 87 | ) 88 | self.dropout = layers.Dropout(p=config.dropout_rate, broadcast_dims=config.dropout_broadcast_dims) 89 | self.post_attn_droppath = layers.DropPath(droppath_rate) 90 | self.pre_mlp_norm = layers.UIOLayerNorm(config.emb_dim) 91 | self.mlp = layers.MlpBlock( 92 | emb_dim=config.emb_dim, 93 | intermediate_dim=config.mlp_dim, 94 | activations=config.mlp_activations, 95 | intermediate_dropout_rate=config.dropout_rate, 96 | dropout_broadcast_dims=config.dropout_broadcast_dims, 97 | ) 98 | self.post_mlp_droppath = layers.DropPath(droppath_rate) 99 | 100 | def forward(self, latents, mask=None): 101 | # Self-attention block. 102 | 103 | # qkv: latents. [batch, latent_length, emb_dim] 104 | x = self.pre_attention_norm(latents) 105 | 106 | # Self-attention 107 | # [batch, latent_length, emb_dim] 108 | # => [batch, latent_length, emb_dim] 109 | x = self.attention(x, x, mask=mask) 110 | 111 | x = self.dropout(x) 112 | 113 | x = self.post_attn_droppath(x) + latents 114 | 115 | # MLP block.
116 | y = self.pre_mlp_norm(x) 117 | # [batch, length, emb_dim] -> [batch, length, emb_dim] 118 | y = self.mlp(y) 119 | 120 | y = self.post_mlp_droppath(y) + x 121 | return y 122 | 123 | 124 | class PerceiverResampler(nn.Module): 125 | def __init__(self, config: Union[ImageResamplerConfig, AudioResamplerConfig]) -> None: 126 | super().__init__() 127 | """Perceiver resampler: a stack of cross-attention layers.""" 128 | self.config = config 129 | 130 | self.latents = nn.Parameter(torch.empty(config.latents_size, config.emb_dim)) 131 | # default_embedding_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal', out_axis=0) 132 | nn.init.kaiming_normal_(self.latents, mode='fan_in', nonlinearity='linear') 133 | self.context_norm = layers.UIOLayerNorm(config.emb_dim) 134 | self.perceiver_norm = layers.UIOLayerNorm(config.emb_dim) 135 | 136 | dpr = [x.item() for x in torch.linspace(0, config.droppath_rate, config.num_layers)] 137 | for lyr in range(config.num_layers): 138 | if lyr in config.xattention_index: 139 | self.add_module(f'layers_{lyr}', CrossAttention(config, droppath_rate=dpr[lyr])) 140 | else: 141 | self.add_module(f'layers_{lyr}', Attention(config, droppath_rate=dpr[lyr])) 142 | 143 | def forward(self, embed, *, mask=None): 144 | bs, seq_len, dim = embed.shape 145 | 146 | if mask is None: 147 | mask = torch.ones([bs, seq_len], dtype=torch.int32, device=embed.device) 148 | 149 | embed = embed.reshape((bs, seq_len, dim)) 150 | query_mask = torch.ones([bs, self.config.latents_size], dtype=mask.dtype, device=mask.device) 151 | key_mask = mask.reshape((bs, seq_len)) 152 | latents = torch.unsqueeze(self.latents, dim=0) 153 | latents = latents.expand(bs, -1, -1).to(embed.dtype) 154 | 155 | embed = self.context_norm(embed) 156 | xattention_mask = layers.make_attention_mask(query_mask, key_mask).to(embed.dtype) 157 | attention_mask = layers.make_attention_mask(query_mask, query_mask).to(embed.dtype) 158 | 159 | for lyr in range(self.config.num_layers): 160 | if lyr in self.config.xattention_index: 161 | latents = getattr(self, f'layers_{lyr}')(latents, embed, xattention_mask) 162 | else: 163 | latents = getattr(self, f'layers_{lyr}')(latents, attention_mask) 164 | 165 | latents = self.perceiver_norm(latents) 166 | 167 | return latents 168 | 169 | 170 | class Resampler(nn.Module): 171 | def __init__(self, config: Union[ImageResamplerConfig, AudioResamplerConfig]) -> None: 172 | super().__init__() 173 | self.config = config 174 | self.perceiver = PerceiverResampler(config) 175 | 176 | """Perceiver resampler: a stack of cross-attention layers.""" 177 | def forward(self, embed, *, mask=None): 178 | embed = self.perceiver(embed, mask=mask) 179 | return embed 180 | -------------------------------------------------------------------------------- /uio2/preprocessing.py: -------------------------------------------------------------------------------- 1 | """UIO2 pre-processor""" 2 | import dataclasses 3 | import json 4 | from typing import Dict, List 5 | 6 | import numpy as np 7 | import tensorflow as tf 8 | import torch 9 | from huggingface_hub import PyTorchModelHubMixin 10 | from transformers import ProcessorMixin, FeatureExtractionMixin 11 | from transformers.utils import PushToHubMixin 12 | 13 | from uio2 import config 14 | from uio2.audio_utils import load_audio 15 | from uio2.config import get_tokenizer, Config 16 | from uio2.data_utils import resize_and_pad_default, values_to_tokens 17 | from uio2.get_modality_processor import get_input_modalities, get_target_modalities 18 | from 
uio2.utils import flatten_dict 19 | from uio2.video_utils import load_video, remove_bars_from_frames 20 | 21 | 22 | class UnifiedIOPreprocessor(FeatureExtractionMixin): 23 | 24 | PREFIXES = { 25 | "text": "[Text] [S] ", 26 | "audio": "[Audio] [S] ", 27 | "image": "[Image] [S] " 28 | } 29 | 30 | @staticmethod 31 | def from_config(cfg: Config, tokenizer): 32 | input_encoders = get_input_modalities( 33 | cfg.input_modalities, cfg.image_vit_cfg, cfg.audio_vit_cfg, 34 | cfg.image_history_cfg, cfg.audio_history_cfg, cfg.use_image_vit, cfg.use_audio_vit, 35 | cfg.freeze_vit, cfg.use_image_history_vit, cfg.use_audio_history_vit, 36 | ) 37 | target_encoders = get_target_modalities( 38 | cfg.target_modalities, cfg.image_vqgan, cfg.audio_vqgan) 39 | return UnifiedIOPreprocessor( 40 | input_encoders, target_encoders, cfg.sequence_length, tokenizer, cfg) 41 | 42 | @staticmethod 43 | def from_dict(data, tokenizer=None, sequence_length=None): 44 | if tokenizer is None: 45 | raise ValueError("Tokenizer path must be given: `tokenizer=path/to/tokenizer`") 46 | cfg = Config.from_dict(data["config"]) 47 | if sequence_length is not None: 48 | cfg.sequence_length = sequence_length 49 | return UnifiedIOPreprocessor.from_config(cfg, tokenizer) 50 | 51 | def __init__( 52 | self, 53 | input_encoders, 54 | target_encoders, 55 | sequence_length, 56 | tokenizer, 57 | config: config.Config=None 58 | ): 59 | super().__init__() 60 | self.input_encoders = input_encoders 61 | self.target_encoders = target_encoders 62 | self.sequence_length = sequence_length 63 | if isinstance(tokenizer, str): 64 | # Assume a path to the tokenizer file 65 | tokenizer = get_tokenizer(tokenizer) 66 | self.tokenizer = tokenizer 67 | self.config = config # Only needed if saving the Preprocessor 68 | 69 | def to_dict(self): 70 | # Our configuration does not cleanly distinguish pre-processing and model config options 71 | # To avoid a significant re-write, we just dump everything as part of the pre-processor config 72 | if self.config is None: 73 | raise ValueError("Config must be given to convert to dictionary") 74 | out = dict(config=self.config.to_dict()) 75 | out["sequence_length"] = self.sequence_length 76 | return out 77 | 78 | def load_image(self, image): 79 | try: 80 | from PIL import Image 81 | except ImportError: 82 | raise ImportError("Loading images require PIL to be installed") 83 | with Image.open(image) as img: 84 | return np.array(img.convert('RGB')) 85 | 86 | def __call__( 87 | self, 88 | text_inputs, 89 | target_modality=None, 90 | box_inputs=None, image_inputs=None, audio_inputs=None, 91 | video_inputs=None, use_video_audio=True, 92 | encode_frame_as_image=-1, 93 | encode_audio_segment_as_audio=-1, 94 | image_history=None, 95 | 96 | image_targets=None, audio_targets=None, text_targets=None, 97 | 98 | # Other 99 | is_training=False, 100 | ) -> Dict[str, np.ndarray]: 101 | """General pre-processing function 102 | 103 | Args: 104 | target_modality: image, audio or text, the target output modality, 105 | if None will be inferred from the targets 106 | 107 | # inputs 108 | text_inputs: String text inputs 109 | box_input: [x1, y1, x2, y2] pixel coordinates relative to image_inputs, this box 110 | will be tokenized and replace the keyword ``{box}` in text_inputs 111 | image_inputs: RGB image or image file 112 | audio_inputs: Audio spectrograms in [N, 256, 128] format or audio file 113 | video_inputs: [n, W, H, 3] tensor of images or a video file 114 | use_video_audio: Extract audio from the `video_inputs` if it is a file 115 | 
encode_frame_as_image: If given a video, encode this frame of that video as an image 116 | encode_audio_segment_as_audio: Encode this audio segment with the audio modality 117 | image_history: List of images, can not be set if `video_inputs` is used. 118 | 119 | # Targets 120 | text_targets: String text targets 121 | image_targets: RGB image or image file 122 | audio_targets: Audio spectrograms in [256, 128] format or audio file of < 4.08 seconds 123 | 124 | # Other 125 | is_training: Do rescaling augmentation 126 | 127 | Returns batch of tensors that can be passed into the UIO2 model 128 | """ 129 | targets = [image_targets, audio_targets, text_targets] 130 | assert sum(x is not None for x in targets) <= 1, "Can have at most one target" 131 | if target_modality is None: 132 | if sum(x is not None for x in targets) == 0: 133 | raise ValueError("No targets and not `target_modality` given") 134 | if image_targets is not None: 135 | target_modality = "image" 136 | elif audio_targets is not None: 137 | target_modality = "audio" 138 | else: 139 | target_modality = "text" 140 | 141 | features = {} 142 | 143 | # Add the target-modality prefix which tells the model what to generate 144 | text_inputs = self.PREFIXES[target_modality] + text_inputs 145 | 146 | if box_inputs is not None: 147 | # Need something the box references 148 | assert (image_inputs is not None or 149 | (video_inputs is not None and encode_frame_as_image is not None)) 150 | # To yxyx 151 | box_inputs = [box_inputs[1], box_inputs[0], box_inputs[3], box_inputs[2]] 152 | boxes = np.asarray(box_inputs, dtype=np.float32)[None, :] 153 | else: 154 | boxes = None 155 | 156 | if isinstance(image_targets, str): 157 | image_targets = self.load_image(image_targets) 158 | if isinstance(image_inputs, str): 159 | image_inputs = self.load_image(image_inputs) 160 | 161 | # Information about how the input image was resized 162 | resize_meta = None 163 | 164 | if image_history is not None: 165 | assert video_inputs is None 166 | image_history = [self.load_image(x) if isinstance(x, str) else x for x in image_history] 167 | parts = [resize_and_pad_default(x, is_training, is_input=True, is_history=True) 168 | for x in image_history] 169 | features["image_history_inputs"] = tf.stack([x[0] for x in parts]) 170 | features["image_history_input_masks"] = tf.stack([x[1] for x in parts]) 171 | 172 | video_audio = None 173 | if video_inputs is not None: 174 | if encode_frame_as_image is not None and image_inputs is not None: 175 | raise ValueError("Asked to encode a frame as an image, but also given an image input") 176 | max_frame = self.sequence_length["num_frames"] 177 | if encode_frame_as_image is not None: 178 | # image_inputs will use the last frame 179 | max_frame += 1 180 | if isinstance(video_inputs, str): 181 | video_inputs, video_audio = load_video(video_inputs, max_frame, use_audio=use_video_audio) 182 | else: 183 | assert video_inputs.shape[0] <= max_frame 184 | assert len(video_inputs.shape) == 4 and video_inputs.shape[-1] == 3 185 | 186 | # remove black bars 187 | video_inputs = remove_bars_from_frames(video_inputs, black_bar=True, threshold=16) 188 | 189 | if encode_frame_as_image is None: 190 | video_inputs, video_mask, _ = resize_and_pad_default( 191 | video_inputs, is_training, is_input=True, is_history=True) 192 | elif not is_training: 193 | image_inputs = video_inputs[encode_frame_as_image] 194 | video_inputs = np.delete(video_inputs, encode_frame_as_image, axis=0) 195 | video_inputs, video_mask, _ = resize_and_pad_default( 196 | 
video_inputs, is_training, is_input=True, is_history=True) 197 | else: 198 | # Make sure augmentation effects the image and history in the same way 199 | # by applying `resize_and_pad_default` to them in the same way 200 | video_inputs, video_mask, resize_meta = resize_and_pad_default( 201 | video_inputs, is_training, boxes=boxes, 202 | masks=image_targets, is_input=True) 203 | features["meta/image_info"] = resize_meta[1] 204 | features["image_inputs"] = video_inputs[encode_frame_as_image] 205 | features["image_input_masks"] = video_mask[encode_frame_as_image] 206 | video_inputs = np.delete(video_inputs, encode_frame_as_image, axis=0) 207 | video_mask = np.delete(video_mask, encode_frame_as_image, axis=0) 208 | # now resize the video into the correct video size 209 | video_inputs = tf.image.resize( 210 | video_inputs, 211 | config.IMAGE_HISTORY_INPUT_SIZE, 212 | method=tf.image.ResizeMethod.BICUBIC) 213 | video_mask = tf.squeeze(tf.image.resize( 214 | tf.expand_dims(video_mask, 3), 215 | config.IMAGE_HISTORY_INPUT_SIZE, 216 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR), -1) 217 | 218 | features["image_history_inputs"] = video_inputs 219 | features["image_history_input_masks"] = video_mask 220 | 221 | if video_audio is not None or audio_inputs is not None: 222 | if video_audio is not None and audio_inputs is not None: 223 | raise ValueError("Have audio from both the video and as `audio_inputs`") 224 | if isinstance(audio_inputs, str): 225 | spectograms = load_audio(audio_inputs) 226 | elif isinstance(audio_inputs, np.ndarray): 227 | spectograms = audio_inputs 228 | if len(spectograms.shape) == 2: 229 | spectograms = np.expand_dims(spectograms, 0) 230 | else: 231 | spectograms = video_audio 232 | 233 | # spectogram pre-processing 234 | spectograms = np.transpose(spectograms, [0, 2, 1]) 235 | mask = (spectograms != 0).astype(np.int32) 236 | audio = tf.math.log(tf.clip_by_value(spectograms, 1e-5, 1e5)) 237 | audio = audio * mask 238 | audio = tf.expand_dims(audio, -1) 239 | 240 | if encode_audio_segment_as_audio is not None: 241 | features["audio_inputs"] = audio[encode_audio_segment_as_audio] 242 | features["audio_input_masks"] = mask[encode_audio_segment_as_audio] 243 | audio = np.delete(audio, encode_audio_segment_as_audio, axis=0) 244 | mask = np.delete(mask, encode_audio_segment_as_audio, axis=0) 245 | if len(audio) > 0: 246 | features["audio_history_inputs"] = audio 247 | features["audio_history_input_masks"] = mask 248 | 249 | if image_inputs is not None: 250 | image_inputs, image_inputs_mask, resize_meta = resize_and_pad_default( 251 | image_inputs, is_training, boxes=boxes, 252 | masks=image_targets, is_input=True) 253 | features["image_inputs"] = image_inputs 254 | features["image_input_masks"] = image_inputs_mask 255 | 256 | if resize_meta is not None: 257 | features["meta/image_info"] = resize_meta[1] 258 | 259 | if box_inputs: 260 | resized_boxes = resize_meta[2] 261 | if len(resized_boxes) == 0: 262 | # Can happen if `is_training=True` and the box gets cropped during rescaling augmentation 263 | return None 264 | box_text = values_to_tokens(resized_boxes / image_inputs.shape[0]) 265 | assert "{box}" in text_inputs 266 | box_text = " ".join([x.decode("utf-8") for x in box_text.numpy()[0]]) 267 | text_inputs = text_inputs.replace("{box}", box_text) 268 | 269 | if image_targets is not None: 270 | if resize_meta is not None: 271 | # Image was resized in way that matches input image/video 272 | features["image_targets"] = resize_meta[1] 273 | target_mask = tf.image.resize( 274 | 
tf.expand_dims(tf.cast(features["image_input_masks"], tf.float32), -1), 275 | config.IMAGE_TARGET_SIZE, 276 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)[:, :, 0] 277 | features["image_target_masks"] = target_mask 278 | else: 279 | # Resize the image independently 280 | image_targets, image_targets_mask, other = resize_and_pad_default( 281 | image_targets, is_training, is_input=False) 282 | features["image_targets"] = image_targets 283 | features["image_target_masks"] = image_targets_mask 284 | 285 | if audio_targets is not None: 286 | if isinstance(audio_targets, str): 287 | target_spectograms = load_audio(audio_targets) 288 | assert target_spectograms.shape[0] == 1, "Target audio does not fit in one segment" 289 | target_spectograms = target_spectograms[0] 290 | else: 291 | target_spectograms = audio_targets[:, :, None] 292 | 293 | mask = (target_spectograms != 0).astype(np.int32) 294 | audio = tf.math.log(tf.clip_by_value(target_spectograms, 1e-5, 1e5)) 295 | audio = audio * mask 296 | features["audio_targets"] = audio 297 | features["audio_target_masks"] = mask[:, :, 0] 298 | 299 | if text_targets: 300 | features["text_targets"] = text_targets 301 | 302 | if resize_meta: 303 | features["meta/image_info"] = resize_meta[0] 304 | 305 | features["text_inputs"] = text_inputs 306 | features = self.unified_io_preprocessor(features) 307 | return {k: v.numpy() for k, v in features.items()} 308 | 309 | def unified_io_preprocessor(self, features): 310 | input_features = {} 311 | for k, v in self.input_encoders.items(): 312 | fe = v.preprocess_inputs(features, self.tokenizer, self.sequence_length) 313 | if fe: 314 | input_features[k] = fe 315 | 316 | target_features = {} 317 | for k, v in self.target_encoders.items(): 318 | fe = v.preprocess_inputs(features, self.tokenizer, self.sequence_length) 319 | if fe: 320 | target_features[k] = fe 321 | 322 | # Extra features that might be needed by metric functions or for evaluations 323 | if "meta" in features: 324 | meta = features["meta"] 325 | else: 326 | meta = {} 327 | for k in features: 328 | if k.startswith("meta/"): 329 | meta[k[len("meta/"):]] = features[k] 330 | 331 | out = dict( 332 | inputs=input_features, 333 | targets=target_features, 334 | meta=meta 335 | ) 336 | 337 | # Special cases that might need to be used inference 338 | if "choices" in features: 339 | out["choices"] = self.target_encoders["text"].convert_choices( 340 | features["choices"], self.sequence_length) 341 | return flatten_dict(out, sep="/") 342 | 343 | 344 | def build_batch(examples: List[Dict[str, np.ndarray]], device=None) -> Dict[str, np.ndarray]: 345 | """Batch examples from `UnifiedIOPreprocess`""" 346 | keys = set(examples[0]) 347 | for ex in examples[1:]: 348 | keys.update(ex) 349 | out_dict = {} 350 | for key in keys: 351 | vals = [ex.get(key) for ex in examples] 352 | val = [v for v in vals if v is not None][0] 353 | sh = list(val.shape[1:]) 354 | max_len = max(len(v) if v is not None else 0 for v in vals) 355 | out = np.zeros([len(examples), max_len]+sh, dtype=val.dtype) 356 | for ix, v in enumerate(vals): 357 | if v is not None: 358 | out[ix, :len(v)] = v 359 | out_dict[key] = out 360 | 361 | if device is not None: 362 | out_dict = {k: torch.as_tensor(v, device=device) for k, v in out_dict.items()} 363 | return out_dict 364 | -------------------------------------------------------------------------------- /uio2/runner.py: -------------------------------------------------------------------------------- 1 | """Runner to use the model for specific tasks""" 2 | 
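# Example sketch (editorial, not part of the original module): typical use of the
# `TaskRunner` defined below, assuming a model and `UnifiedIOPreprocessor` have
# already been built elsewhere (e.g. via the helpers in `uio2.get_model`):
#
#   runner = TaskRunner(model, preprocessor)
#   caption = runner.image_captioning("example.jpg")
#   boxes = runner.localization("example.jpg", "dog")
#   answer = runner.vqa("example.jpg", "What color is the dog?")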
import json 3 | import logging 4 | import re 5 | from os.path import join, dirname 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | from typing import List 10 | 11 | import torch 12 | from PIL import Image 13 | from transformers import LogitsProcessor 14 | 15 | from uio2 import config 16 | from uio2.hifigan.models import Generator as HifiganGenerator 17 | from uio2.preprocessing import UnifiedIOPreprocessor 18 | from uio2.prompt import Prompt 19 | from uio2.utils import flatten_dict, pad_and_stack, token_to_float, undo_box_preprocessing, \ 20 | extra_id_to_float, extract_locations_from_token_ids, undo_image_preprocessing 21 | 22 | HUMAN_POSE_PART = [ 23 | "nose", "left eye", "right eye", "left ear", "right ear", "left shoulder", 24 | "right shoulder", "left elbow", "right elbow", "left wrist", "right wrist", 25 | "left hip", "right hip", "left knee", "right knee", "left ankle", "right ankle"] 26 | 27 | part_name_re = re.compile(r"<extra_id_([0-9]+)> <extra_id_([0-9]+)> ([a-z ]+)") 28 | 29 | labelled_box_re = re.compile( 30 | r"<extra_id_([0-9]+)> <extra_id_([0-9]+)> <extra_id_([0-9]+)> <extra_id_([0-9]+)> ([a-z ]+)") 31 | 32 | 33 | class ClfFreeGuidanceProcessor(LogitsProcessor): 34 | """Apply classifier-free guidance, assuming the bottom half of the scores comes from the guidance batches""" 35 | def __init__(self, alpha): 36 | self.alpha = alpha 37 | 38 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 39 | scores = torch.log_softmax(scores, -1) 40 | n = scores.shape[0] // 2 41 | guidance_scores = scores[n:] 42 | main_scores = scores[:n] 43 | out = (1 + self.alpha) * main_scores - self.alpha * guidance_scores 44 | return torch.cat([out, out], 0) 45 | 46 | 47 | class ForceKeypointPrediction(LogitsProcessor): 48 | """Force a keypoint prediction from the model that makes a guess for every point 49 | 50 | During training, we don't train the model to predict coordinates for invisible keypoints, 51 | but during inference it is helpful to make a guess for every point since the 52 | KP metric does not penalize you for guessing at an invisible point 53 | """ 54 | 55 | def __init__(self, tokenizer): 56 | mask = [] 57 | for part in HUMAN_POSE_PART: 58 | mask.append(None) 59 | mask.append(None) 60 | mask += tokenizer.encode(part) 61 | mask.append(1) # EOS 62 | self.mask = mask 63 | 64 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 65 | cur_index = input_ids.shape[1] - 1 # Minus one for the BOS 66 | if cur_index >= len(self.mask): 67 | return scores 68 | mask = self.mask[cur_index] 69 | if mask is None: 70 | # Force a location token prediction 71 | scores[:, :32000] = -10000 72 | scores[:, 33000:] = -10000 73 | else: 74 | # Force the next part name 75 | scores = scores*0 76 | scores[:, mask] = 1000 77 | return scores 78 | 79 | 80 | def extract_labelled_boxes(text): 81 | """Extract labelled boxes from UIO2 output text""" 82 | labels = [] 83 | boxes = [] 84 | for y1, x1, y2, x2, name in labelled_box_re.findall(text): 85 | labels.append(name) 86 | boxes.append([int(y1), int(x1), int(y2), int(x2)]) 87 | if boxes: 88 | boxes = extra_id_to_float(np.array(boxes)) 89 | return boxes, labels 90 | 91 | 92 | def extract_keypoints(text, image_info): 93 | """Extract keypoint prediction from UIO2 output text""" 94 | invalid = False # Is this text a valid keypoint prediction 95 | points, labels = [], [] 96 | for id1, id2, part in part_name_re.findall(text): 97 | ids = (int(id1), int(id2)) 98 | if all(200 <= x < 1200 for x in ids): 99 | labels.append(part) 100 | points.append(ids) 101 | else: 102 | invalid = True 103 | points =
extra_id_to_float(np.array(points)) 104 | points *= config.IMAGE_INPUT_SIZE[0] 105 | 106 | part_map = {k: i for i, k in enumerate(HUMAN_POSE_PART)} 107 | output_points = np.zeros([17, 2]) 108 | output_labels = np.zeros([17]) 109 | for point, label in zip(points, labels): 110 | lower = label.strip().lower() 111 | ix = part_map.get(lower) 112 | if ix is None: 113 | invalid = True 114 | elif output_labels[ix] != 0: 115 | # Generated a part twice, skip the later one 116 | invalid = True 117 | else: 118 | output_points[ix] = point 119 | output_labels[ix] = 2 120 | points, labels = output_points, output_labels 121 | 122 | if np.sum(labels) == 0: 123 | # No visible points predicted 124 | return None, invalid 125 | 126 | if image_info is not None: 127 | points = undo_box_preprocessing(np.tile(points, [1, 2]), image_info)[:, :2] 128 | points = points[:, ::-1] # convert to xy 129 | 130 | # replace non visible point with mean so we do something non-crazy if the 131 | # GT turns out to be `visible` 132 | mean = np.mean(points[labels != 0], 0, keepdims=True) 133 | points[labels == 0] = mean 134 | 135 | assert points.shape == (17, 2) 136 | points = np.concatenate([points, labels.astype(points.dtype)[:, None]], -1) 137 | return points, invalid 138 | 139 | 140 | class PredictBoxesPreprocessor(LogitsProcessor): 141 | """Force the model to predict a location tokens if the total probability mass on 142 | all locations > then a threshold. 143 | 144 | Used to prevent a bias towards short sequence caused by EOS becoming the most probable tokens 145 | when probability mass gets spread out over many location tokens 146 | """ 147 | def __init__(self, thresh=0.5, require_one_box=False): 148 | self.require_one_box = require_one_box 149 | self.thresh = thresh 150 | 151 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 152 | logits = torch.log_softmax(scores, dim=-1) 153 | # Total probability on a location token 154 | probs = torch.exp(torch.logsumexp(logits[:, 32000:33000], dim=-1)) 155 | use_loc = probs > self.thresh 156 | if use_loc: 157 | scores[:, :32000] = -10000 158 | scores[:, 33000:] = -10000 159 | if self.require_one_box and input_ids.shape[1] == 1: 160 | # Prevent starting with EOS 161 | scores[:, config.EOS_ID] = -10000 162 | return scores 163 | 164 | 165 | # Default prompts use to train the model in the classifier free settings 166 | IMAGE_CLF_FREE_PROMPT = "An image of a random picture." 167 | AUDIO_CLF_FREE_PROMPT = "A video of a random audio." 
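# Example sketch (illustrative only; this helper is not part of the original file and is
# never called): the classifier-free guidance rule applied by `ClfFreeGuidanceProcessor`
# above. The prompts defined here are encoded as the "negative" batch, and the final
# log-probabilities are extrapolated away from that unconditional prediction.
def _cfg_combine_example(scores: torch.Tensor, alpha: float = 10.0) -> torch.Tensor:
  # `scores` stacks conditional rows on top of guidance rows: [2 * n, vocab_size]
  log_probs = torch.log_softmax(scores, -1)
  n = log_probs.shape[0] // 2
  main, guidance = log_probs[:n], log_probs[n:]
  # Same combination as ClfFreeGuidanceProcessor.__call__
  return (1 + alpha) * main - alpha * guidance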
168 | 169 | 170 | class SpectogramConverter: 171 | """Convert UIO2 audio spectograms into waveforms that can be played""" 172 | def __init__(self, use_hifigan=True): 173 | self.use_hi_fi_gan = use_hifigan 174 | self.hifigan = None 175 | 176 | def __call__(self, spectogram): 177 | """ 178 | Args: 179 | spectogram: UIO2 spectogram [128, 256, 1] 180 | 181 | Returns waveform with 16000 sampling rate 182 | """ 183 | if self.use_hi_fi_gan: 184 | if self.hifigan is None: 185 | src = join(dirname(__file__), "hifigan") 186 | logging.info("Loading hi-fi-gan") 187 | config_file = f"{src}/checkpoints/config.json" 188 | checkpoint = f"{src}/checkpoints/g_00930000" 189 | with open(config_file) as f: 190 | json_config = json.load(f) 191 | torch_device = torch.device("cpu") 192 | 193 | class ObjConfig: # `Generator` uses attribute lookup, so wrap the json in a dummy class 194 | def __getattr__(self, item): 195 | return json_config[item] 196 | 197 | checkpoint_dict = torch.load(checkpoint, map_location=torch_device) 198 | hifigan_generator = HifiganGenerator(ObjConfig()).to(torch_device) 199 | hifigan_generator.load_state_dict(checkpoint_dict["generator"]) 200 | hifigan_generator.eval() 201 | hifigan_generator.remove_weight_norm() 202 | self.hifigan = hifigan_generator 203 | 204 | spectrogram = np.array(spectogram * 3.8312 - 5.0945)[:, :, 0] 205 | spectrogram = torch.as_tensor(spectrogram, dtype=torch.float32, device=torch.device("cpu")) 206 | 207 | with torch.no_grad(): 208 | y_g_hat = self.hifigan(spectrogram) 209 | return y_g_hat.squeeze().cpu().numpy() 210 | else: 211 | import librosa 212 | spectrogram = np.exp(spectogram * 3.8312 - 5.0945)[:, :, 0] 213 | return librosa.feature.inverse.mel_to_audio( # type: ignore 214 | spectrogram, 215 | sr=16000, 216 | n_fft=1024, 217 | hop_length=256, 218 | win_length=None, 219 | window="hann", 220 | center=True, 221 | pad_mode="reflect", 222 | power=2.0, 223 | n_iter=32, 224 | ) 225 | 226 | 227 | class TaskRunner: 228 | """Wraps a UIO2 model and UIO2 preprocessor and does a set of tasks. 229 | 230 | This is intended mostly to demonstrate how to use the model for these different tasks. 231 | To run these tasks efficiently batch the inputs and run the pre-processing inside a DataLoader. 232 | """ 233 | 234 | def __init__(self, model, uio2_preprocessor: UnifiedIOPreprocessor, prompts=None, 235 | use_hifigan_for_audio=True): 236 | self.model = model 237 | self.uio2_preprocessor = uio2_preprocessor 238 | if prompts is None: 239 | prompts = Prompt() 240 | self.prompt = prompts 241 | self.spectogram_converter = SpectogramConverter(use_hifigan_for_audio) 242 | 243 | @property 244 | def tokenizer(self): 245 | return self.uio2_preprocessor.tokenizer 246 | 247 | @property 248 | def device(self): 249 | return self.model.device 250 | 251 | def singleton_batch(self, batch): 252 | return {k: torch.as_tensor(v, device=self.device)[None, ...] 
for k, v in batch.items()} 253 | 254 | def predict_text(self, example, max_tokens, detokenize=True, **gen_args): 255 | tokens = self.model.generate( 256 | batch=self.singleton_batch(example), modality="text", 257 | use_cache=True, max_new_tokens=max_tokens, 258 | **gen_args 259 | ) 260 | tokens = tokens[0].cpu() 261 | if detokenize: 262 | return self.tokenizer.decode(tokens) 263 | else: 264 | return tokens 265 | 266 | def refexp(self, image, expression) -> List[float]: 267 | """Perform referring expression 268 | 269 | Args: 270 | image: image or image file to examine 271 | expression: expression to locate 272 | 273 | Returns: bounding box of `expression` in `image` in x1, y1, x2, y2 form 274 | """ 275 | prompt = self.prompt.random_prompt("Refexp") 276 | prompt = prompt.replace("{}", expression) 277 | batch = self.uio2_preprocessor(text_inputs=prompt, image_inputs=image, target_modality="text") 278 | tokens = self.predict_text(batch, max_tokens=6, detokenize=False) 279 | if len(tokens) != 6 or (tokens[0] != 0) or (tokens[-1] != 1): 280 | raise ValueError(f"Output not a bounding box {tokens}") 281 | box = token_to_float(np.array(tokens[1:-1])) 282 | box *= config.IMAGE_INPUT_SIZE[0] # de-normalized w.r.t the preprocessed image 283 | box = undo_box_preprocessing(box, batch["/meta/image_info"]) # -> coordinates for the input image 284 | box = box.tolist() 285 | box = [box[1], box[0], box[3], box[2]] # yxyx to xyxy 286 | return box 287 | 288 | def vqa(self, image, question) -> str: 289 | """Perform VQA 290 | 291 | Args: 292 | image: image or image_file to loo at 293 | question: short answer question 294 | 295 | Returns then answer 296 | """ 297 | prompt = self.prompt.random_prompt("VQA_short_prompt") 298 | prompt = prompt.replace("{}", question) 299 | example = self.uio2_preprocessor(text_inputs=prompt, image_inputs=image, target_modality="text") 300 | out = self.predict_text(example, max_tokens=32) 301 | return out 302 | 303 | def box_categorization(self, image, box, answer_options, batch_size=50): 304 | """Categorization the object in an image region 305 | 306 | Args: 307 | image: image to examine 308 | box: x1y1x2y2 region coordinates 309 | answer_options: possible classes 310 | 311 | Returns: the most probable class 312 | """ 313 | if isinstance(answer_options, list): 314 | tensors = pad_and_stack( 315 | [self.tokenizer.encode(x) + [1] for x in answer_options], add_eos=True) 316 | tensors = tensors.to(self.device) 317 | else: 318 | # assume options are already in tensor form 319 | tensors = answer_options 320 | prompt = self.prompt.random_prompt("Box_Classification_Scene") 321 | example = self.uio2_preprocessor( 322 | text_inputs=prompt, image_inputs=image, target_modality="text", box_inputs=box) 323 | batch = self.singleton_batch(example) 324 | scores = self.model.score_answer_options(batch, tensors, batch_size) 325 | ix = torch.argmin(scores) 326 | if isinstance(answer_options, list): 327 | return answer_options[ix] 328 | else: 329 | return self.tokenizer.decode(tensors[ix]) 330 | 331 | def categorization(self, image, answer_options, batch_size=50): 332 | """Categorize the image, return a class in `answer_options`""" 333 | # imagenet prompt is generic, but using a prompt that give a better hint about what kind 334 | # of classes to consider can help 335 | prompt = self.prompt.random_prompt("image_tagging_imagenet2012") 336 | batch = self.uio2_preprocessor(text_inputs=prompt, image_inputs=image, target_modality="text") 337 | batch = self.singleton_batch(batch) 338 | tensors = 
pad_and_stack( 339 | [self.tokenizer.encode(x) + [1] for x in answer_options], add_eos=True) 340 | tensors = tensors.to(self.device) 341 | scores = self.model.score_answer_options(batch, tensors, batch_size) 342 | ix = torch.argmin(scores) 343 | return answer_options[ix] 344 | 345 | def localization(self, image, cls, thresh=0.3, nms=0.8, no_cat=False): 346 | """Find all locations where `cls` occurs in `image` 347 | 348 | Args: 349 | image: Image to look at 350 | cls: class name 351 | thresh: always produce a location token if total probabilty on locations is > `thresh` 352 | used to prevent premature EOS during beam search due to probability getting 353 | distributed over many similar location tokens 354 | nms: Apply NMS, if `thresh` is we can occasionally get repeated boxes, 355 | we use NMS with a high threshold to prune them 356 | no_cat: Don't prompt the model to repeat the object categories, makes the response 357 | more token-efficient, but off by default since we did not eval grit with this on 358 | Returns: List of [x1, y1, x2, y2] boxes 359 | """ 360 | if no_cat: 361 | prompt = self.prompt.random_prompt("Object_Detection_No_Cat") 362 | else: 363 | prompt = self.prompt.random_prompt("Object_Detection") 364 | prompt = prompt.replace("{}", cls) 365 | batch = self.uio2_preprocessor( 366 | text_inputs=prompt, image_inputs=image, target_modality="text") 367 | out = self.predict_text( 368 | batch, max_tokens=256, 369 | logits_processor=[PredictBoxesPreprocessor(thresh)], 370 | detokenize=False) 371 | boxes = extract_locations_from_token_ids(out) 372 | if len(boxes) > 0: 373 | boxes = boxes*config.IMAGE_INPUT_SIZE[0] 374 | boxes = undo_box_preprocessing(boxes, batch["/meta/image_info"]) 375 | if nms is not None and len(boxes) > 1: 376 | ixs = tf.image.non_max_suppression( 377 | np.array(boxes), 378 | max_output_size=len(boxes), 379 | scores=np.arange(len(boxes))[::-1], 380 | iou_threshold=nms 381 | ).numpy() 382 | boxes = boxes[ixs] 383 | boxes = np.stack([ 384 | boxes[:, 1], boxes[:, 0], 385 | boxes[:, 3], boxes[:, 2] 386 | ], 1) 387 | return boxes 388 | else: 389 | return np.zeros((0, 4), dtype=np.int32) 390 | 391 | def keypoint_box(self, image, target_box, free_form=False): 392 | """Find keypoint for the person in `target_box` 393 | 394 | Args: 395 | image: image to examine 396 | target_box: person box in x1, y1, x2, y2 coordinates 397 | free_form: Don't force a prediction for every keypoint, including non-visible points 398 | 399 | Returns: the points in [17, 3] if (x1, y1, visible) triples or None 400 | """ 401 | prompt = self.prompt.random_prompt("Pose_Estimation") 402 | prompt = prompt.replace("{}", "{box}") 403 | batch = self.uio2_preprocessor( 404 | text_inputs=prompt, image_inputs=image, target_modality="text", 405 | box_inputs=target_box) 406 | text = self.predict_text( 407 | batch, max_tokens=128, 408 | logits_processor=None if free_form else [ForceKeypointPrediction(self.tokenizer)]) 409 | kps, valid = extract_keypoints(text, batch["/meta/image_info"]) 410 | return kps, text 411 | 412 | def keypoint(self, image): 413 | """End-to-end keypoint, requires multiple rounds of generation 414 | 415 | Args: 416 | image: Image to get keypoints for 417 | 418 | Returns: points: List of [17, 3] keypoint arrays 419 | """ 420 | boxes = self.localization(image, "person", thresh=0.5) 421 | all_points = [] 422 | for box in boxes: 423 | all_points.append(self.keypoint_box(image, box)[0]) 424 | return all_points 425 | 426 | def object_detection(self, image, coco_prompt=False, thresh=0.5, 
nms=0.8, max_tokens=256): 427 | """Returns a list of x1 y2 x2 y2 boxes, and list string box labels 428 | 429 | note this task can be pretty unreliable for UIO2, particularly for crowded images 430 | """ 431 | if coco_prompt: 432 | # Prompt used for the COCO training data 433 | prompt = self.prompt.random_prompt("Detection_COCO") 434 | else: 435 | # Prompt for other detection datasets, can result in detecting more classes 436 | prompt = self.prompt.random_prompt("Detection_Generic") 437 | batch = self.uio2_preprocessor(text_inputs=prompt, image_inputs=image, target_modality="text") 438 | out = self.predict_text( 439 | batch, max_tokens=max_tokens, logits_processor=[PredictBoxesPreprocessor(thresh)]) 440 | boxes, labels = extract_labelled_boxes(out) 441 | if len(boxes) > 0: 442 | boxes = boxes*config.IMAGE_INPUT_SIZE[0] 443 | boxes = undo_box_preprocessing(boxes, batch["/meta/image_info"]) 444 | if nms is not None and len(boxes) > 1: 445 | ixs = tf.image.non_max_suppression( 446 | np.array(boxes), 447 | max_output_size=len(boxes), 448 | scores=np.arange(len(boxes))[::-1], 449 | iou_threshold=nms 450 | ).numpy() 451 | boxes = boxes[ixs] 452 | labels = [labels[i] for i in ixs] 453 | boxes = np.stack([ 454 | boxes[:, 1], boxes[:, 0], 455 | boxes[:, 3], boxes[:, 2] 456 | ], 1) 457 | return boxes, labels 458 | 459 | def video_tagging(self, video): 460 | """Classify a video 461 | 462 | Args: 463 | video: video file path, or a sequence of frames 464 | 465 | Returns: Predicted text class 466 | """ 467 | prompt = self.prompt.random_prompt("video_tagging") 468 | batch = self.uio2_preprocessor( 469 | text_inputs=prompt, video_inputs=video, use_video_audio=False, target_modality="text") 470 | text = self.predict_text(batch, max_tokens=16) 471 | return text 472 | 473 | def video_captioning(self, video): 474 | """Caption a video 475 | 476 | Args: 477 | video: video file path, or a sequence of frames 478 | 479 | Returns: Text video caption 480 | """ 481 | prompt = self.prompt.random_prompt("video_captioning") 482 | batch = self.uio2_preprocessor( 483 | text_inputs=prompt, video_inputs=video, use_video_audio=False, target_modality="text") 484 | text = self.predict_text(batch, max_tokens=64) 485 | return text 486 | 487 | def audio_captioning(self, audio): 488 | """Caption an audio clip 489 | 490 | Args: 491 | audio: audio file path, or a sequence of spectograms 492 | 493 | Returns: Text audio caption 494 | """ 495 | prompt = self.prompt.random_prompt("audio_caption") 496 | batch = self.uio2_preprocessor( 497 | text_inputs=prompt, audio_inputs=audio, target_modality="text") 498 | text = self.predict_text(batch, max_tokens=64) 499 | return text 500 | 501 | def image_captioning(self, image): 502 | """Caption an image 503 | 504 | Args: 505 | image: image file path or RGB image array 506 | 507 | Returns: Text caption 508 | """ 509 | # This prompt will get a COCO-like caption, which is generally expected 510 | prompt = self.prompt.random_prompt("image_caption_coco_2017") 511 | batch = self.uio2_preprocessor(text_inputs=prompt, image_inputs=image, target_modality="text") 512 | return self.predict_text(batch, max_tokens=64) 513 | 514 | def image_generation(self, text, guidance_scale=10, top_p=0.9, num_out=None, 515 | use_prompt=True): 516 | """Generate a natural image 517 | 518 | Args: 519 | text: Text o match 520 | guidance_scale: Guidance scale for classifier free guidance 521 | top_p: top during sampling 522 | num_out: number of examples to generate 523 | use_prompt: Embed `text` in an image generation prompt 524 
| 525 | Returns: List of PIL.Image of lengths `num_out` if num_out, else one PIL.Image 526 | """ 527 | if use_prompt: 528 | prompt = self.prompt.random_prompt("image_generation_coco_2017") 529 | prompt = prompt.replace("{}", text) 530 | else: 531 | prompt = text 532 | example = self.uio2_preprocessor(text_inputs=prompt, target_modality="image") 533 | example = self.singleton_batch(example) 534 | 535 | if guidance_scale: 536 | negative_prompt = self.uio2_preprocessor( 537 | text_inputs=IMAGE_CLF_FREE_PROMPT, target_modality="image") 538 | negative_prompt = self.singleton_batch(negative_prompt) 539 | else: 540 | negative_prompt = None 541 | 542 | if num_out: 543 | # A bit wasteful since we end up re-encoding the same inputs multiple times, 544 | # but GenerationMixin doesn't seem to support multiple outputs 545 | example = {k: v.expand(*([num_out] + [-1]*(len(v.shape)-1))) for k, v in example.items()} 546 | 547 | out = self.model.generate( 548 | example, 549 | negative_prompt=negative_prompt, 550 | guidance_scale=guidance_scale, 551 | top_p=top_p, 552 | top_k=None, 553 | do_sample=True, 554 | modality="image" 555 | ) 556 | out = out.cpu().numpy() 557 | out = (out*255).astype(np.uint8) 558 | if num_out: 559 | return [Image.fromarray(x) for x in out] 560 | else: 561 | return Image.fromarray(out[0]) 562 | 563 | def surface_normal_estimation(self, image, top_p=0.9, temperature=0.9, original_size=True): 564 | """Returns: a RGB surface normal encoding for `image``""" 565 | prompt = self.prompt.random_prompt("Surface_Normals_Estimation") 566 | example = self.uio2_preprocessor(text_inputs=prompt, image_inputs=image, target_modality="image") 567 | out = self.model.generate( 568 | self.singleton_batch(example), 569 | top_p=top_p, 570 | top_k=None, 571 | do_sample=True, 572 | temperature=temperature, 573 | modality="image" 574 | ) 575 | data = out.cpu().numpy()[0] 576 | if original_size: 577 | return undo_image_preprocessing(data, example["/meta/image_info"], to_int=True) 578 | else: 579 | return (data*255).astype(np.uint8) 580 | 581 | def depth_estimation(self, image, top_p=0.9, temperature=0.9, original_size=True): 582 | """Returns: a gray-scale depth map `image`` 583 | 584 | white=0meters, black=10meters, note UIO2 seems to be under-trained on this tasks so 585 | results are often not great 586 | """ 587 | prompt = self.prompt.random_prompt("Depth_Estimation") 588 | example = self.uio2_preprocessor(text_inputs=prompt, image_inputs=image, target_modality="image") 589 | out = self.model.generate( 590 | self.singleton_batch(example), 591 | top_p=top_p, 592 | top_k=None, 593 | do_sample=True, 594 | temperature=temperature, 595 | modality="image" 596 | ) 597 | data = out.cpu().numpy()[0] 598 | if original_size: 599 | return undo_image_preprocessing(data, example["/meta/image_info"], gray_scale=True) 600 | else: 601 | return data.mean(-1) 602 | 603 | def segmentation_box(self, image, target_class, target_box, top_p=0.95, 604 | temperature=0.9, original_size=True): 605 | """Returns a binary mask over the instances of `target_class` in `target_box`""" 606 | prompt = self.prompt.random_prompt("Object_Segmentation") 607 | prompt = prompt.replace("{}", "{box} " + target_class) 608 | example = self.uio2_preprocessor( 609 | text_inputs=prompt, image_inputs=image, box_inputs=target_box, target_modality="image") 610 | out = self.model.generate( 611 | self.singleton_batch(example), 612 | top_p=top_p, 613 | top_k=None, 614 | do_sample=True, 615 | temperature=temperature, 616 | modality="image" 617 | ) 618 | data = 
out.cpu().numpy() 619 | if original_size: 620 | image = undo_image_preprocessing(data[0], example["/meta/image_info"], gray_scale=True) 621 | image = np.squeeze(image, -1) 622 | else: 623 | image = image.mean(-1) 624 | return image > 0.5 625 | 626 | def segmentation_class(self, image, target_class): 627 | """Return binary masks for each instance of `target_class` in `image`""" 628 | masks = [] 629 | for box in self.localization(image, target_class): 630 | mask = self.segmentation_box(image, target_class, box) 631 | if np.any(mask): 632 | masks.append(mask) 633 | return masks 634 | 635 | def audio_generation(self, text, use_prompt=True, guidance_scale=0, num_out=None, top_p=0.9): 636 | """Generate an audio clip from text""" 637 | if use_prompt: 638 | prompt = self.prompt.random_prompt("Audio_Generation") 639 | prompt = prompt.replace("{}", text) 640 | else: 641 | prompt = text 642 | example = self.uio2_preprocessor(text_inputs=prompt, target_modality="audio") 643 | example = self.singleton_batch(example) 644 | 645 | if guidance_scale: 646 | # Generally not helpful for audio, but can be worth experimenting with 647 | negative_prompt = self.uio2_preprocessor( 648 | text_inputs=AUDIO_CLF_FREE_PROMPT, target_modality="audio") 649 | negative_prompt = self.singleton_batch(negative_prompt) 650 | else: 651 | negative_prompt = None 652 | 653 | if num_out: 654 | example = {k: v.expand(*([num_out] + [-1]*(len(v.shape)-1))) for k, v in example.items()} 655 | 656 | out = self.model.generate( 657 | example, 658 | negative_prompt=negative_prompt, 659 | guidance_scale=guidance_scale, 660 | top_p=top_p, 661 | top_k=None, 662 | do_sample=True, 663 | modality="audio" 664 | ) 665 | out = out.cpu().numpy() 666 | if num_out: 667 | return [self.spectogram_converter(x) for x in out] 668 | else: 669 | return self.spectogram_converter(out[0]) 670 | -------------------------------------------------------------------------------- /uio2/seq_features.py: -------------------------------------------------------------------------------- 1 | """Abstracts sequence of tokens we can encode/decode to making mixing modalities easier""" 2 | import dataclasses 3 | from typing import Optional, List 4 | from dataclasses import dataclass 5 | from torch.nn import functional as F 6 | import torch 7 | 8 | 9 | @dataclass 10 | class TargetSequence: 11 | """Target sequence we can train a decoder to predict""" 12 | 13 | input_embedding: torch.Tensor 14 | """Input embeddings to the decoder""" 15 | 16 | position_embed: torch.Tensor 17 | """Int position ids or embedding""" 18 | 19 | modality_id: torch.Tensor 20 | """Modality ids's of the tokens, can be a scalar if all the same""" 21 | 22 | mask: Optional[torch.Tensor] 23 | """Mask of valid tokens""" 24 | 25 | attn_pattern_mask: Optional[torch.Tensor] = None 26 | """[batch, n_heads, seq_len, seq_len] of relative attention bias""" 27 | 28 | target_tokens: Optional[torch.Tensor] = None 29 | """Target tokens used to compute the loss""" 30 | 31 | subsegments: Optional[torch.Tensor] = None 32 | """ids of targets that should be independently predicted from the encoding of one example""" 33 | 34 | segment_ids: Optional[torch.Tensor] = None 35 | """If packed, an example id for each token""" 36 | 37 | loss_mask: Optional[torch.Tensor] = None 38 | """Mask of tokens to use when computing the loss""" 39 | 40 | @property 41 | def seq_len(self): 42 | return self.input_embedding.shape[1] 43 | 44 | @property 45 | def batch_size(self): 46 | return self.input_embedding.shape[0] 47 | 48 | def __post_init__(self): 
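# Shape/dtype sanity checks: per-token fields must be [batch, seq_len] (or broadcastable
# to it), and id/mask tensors must be integer-typed.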
49 | bs, seq_len = self.input_embedding.shape[:2] 50 | 51 | if self.position_embed is not None: 52 | assert self.position_embed.shape[:2] in [(1, seq_len), (bs, seq_len)] 53 | 54 | assert self.modality_id.shape in [(), (1, seq_len), (bs, seq_len)] 55 | assert self.modality_id.dtype == torch.int32 56 | 57 | if self.target_tokens is not None: 58 | assert self.target_tokens.shape == (bs, seq_len) 59 | assert self.target_tokens.dtype == torch.int32 60 | 61 | if self.mask is not None: 62 | assert self.mask.shape == (bs, seq_len) 63 | assert self.mask.dtype == torch.int32 or self.mask.dtype == torch.bool 64 | 65 | if self.attn_pattern_mask is not None: 66 | assert self.attn_pattern_mask.shape[0] in [1, bs] 67 | 68 | if self.subsegments is not None: 69 | assert self.subsegments.shape == (bs, seq_len) 70 | assert self.subsegments.dtype == torch.int32 71 | 72 | if self.segment_ids is not None: 73 | assert self.segment_ids.shape == (bs, seq_len) 74 | assert self.segment_ids.dtype == torch.int32 75 | 76 | def get_all_subsegments(self): 77 | subsegments = [self.subsegments, self.segment_ids, 78 | None if len(self.modality_id.shape) <= 1 else self.modality_id] 79 | all_subsegments = None 80 | for part in subsegments: 81 | if part is None: 82 | continue 83 | if all_subsegments is None: 84 | all_subsegments = part 85 | continue 86 | all_subsegments = all_subsegments*(part.max()+1) + part 87 | return all_subsegments 88 | 89 | 90 | @dataclass 91 | class InputSequence: 92 | """Input sequence we can encode with an Encoder""" 93 | 94 | embed: torch.Tensor 95 | """Token input embedding""" 96 | 97 | mask: Optional[torch.Tensor] 98 | """Mask over valid time steps""" 99 | 100 | segment_ids: Optional[torch.Tensor]=None 101 | """If packed, an example id for each token""" 102 | 103 | position_embed: Optional[torch.Tensor]=None 104 | """Positional bias embedding""" 105 | 106 | @property 107 | def seq_len(self): 108 | return self.embed.shape[1] 109 | 110 | @property 111 | def batch_size(self): 112 | return self.embed.shape[0] 113 | 114 | @staticmethod 115 | def empty(bs, seq_len, cfg) -> 'InputSequence': 116 | return InputSequence( 117 | torch.zeros((bs, seq_len, cfg.emb_dim), dtype=cfg.dtype), 118 | torch.zeros((bs, seq_len), dtype=torch.int32), 119 | position_embed=torch.zeros((bs, seq_len, cfg.emb_dim), dtype=cfg.dtype), 120 | ) 121 | 122 | def __post_init__(self): 123 | assert len(self.embed.shape) == 3 124 | bs, seq_len = self.embed.shape[:2] 125 | 126 | if self.position_embed is not None: 127 | assert len(self.position_embed.shape) == 3 128 | assert self.position_embed.shape[:2] in [(bs, seq_len), (1, seq_len)] 129 | if self.mask is not None: 130 | assert self.mask.shape == (bs, seq_len) 131 | if self.segment_ids is not None: 132 | assert self.segment_ids.shape == (bs, seq_len) 133 | 134 | 135 | def expand_scalar(val, seq_len): 136 | if val is None: 137 | return None 138 | elif len(val.shape) <= 1: 139 | val = torch.reshape(val, (1, 1)) 140 | return torch.tile(val, [1, seq_len]) 141 | else: 142 | return val 143 | 144 | 145 | def seq_seq_concat(args): 146 | total_len = sum(x.shape[-1] for x in args) 147 | on = 0 148 | out_list = [] 149 | for args in args: 150 | n = args.shape[-1] 151 | out_list.append(F.pad(args, [0, 0, on, total_len-on-n])) 152 | on += n 153 | return torch.concatenate(out_list, -1) 154 | 155 | 156 | def concat_sequences(seqs: List): 157 | """Concats along the sequence dimension (i.e., horizontally)""" 158 | seq_lens = [x.seq_len for x in seqs] 159 | out = {} 160 | for k in 
dataclasses.fields(seqs[0]): 161 | k = k.name 162 | args = [expand_scalar(getattr(seq, k), seq.seq_len) for seq in seqs] 163 | 164 | if all(x is None for x in args): 165 | out[k] = None 166 | continue 167 | 168 | max_bs = max(x.shape[0] for x in args if x is not None) 169 | full_sized = [x for x in args if (x is not None and x.shape[0] == max_bs)] 170 | shape = list(full_sized[0].shape) 171 | 172 | if len(full_sized) != len(args): 173 | # Replace scalar/None values with blank/full values 174 | padded_args = [] 175 | for ix, x in enumerate(args): 176 | if x is not None and x.shape[0] == max_bs: 177 | padded_args.append(x) # Full sized 178 | 179 | elif x is not None and x.shape[0] != max_bs: 180 | assert x.shape[0] == 1 # broadcasts the batch dim, tile to max_bs 181 | padded_args.append(torch.tile(x, [max_bs] + [1]*(len(x.shape)-1))) 182 | 183 | else: 184 | assert x is None # replace with zero array of the correct shape 185 | arg_shape = list(shape) 186 | arg_shape[0] = max_bs 187 | if len(shape) <= 3: 188 | arg_shape[1] = seq_lens[ix] 189 | elif len(shape) == 4: 190 | arg_shape = arg_shape[:2] + [seq_lens[ix], seq_lens[ix]] 191 | 192 | padded_args.append(torch.zeros( 193 | *arg_shape, device=full_sized[0].device, dtype=full_sized[0].dtype)) 194 | args = padded_args 195 | if len(shape) == 4: 196 | out[k] = seq_seq_concat(args) 197 | else: 198 | out[k] = torch.concat(args, dim=1) 199 | 200 | if isinstance(seqs[0], InputSequence): 201 | return InputSequence(**out) 202 | else: 203 | return TargetSequence(**out) 204 | -------------------------------------------------------------------------------- /uio2/target_modalities.py: -------------------------------------------------------------------------------- 1 | """Target modality processing""" 2 | from typing import Dict, Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | 8 | from uio2.config import T5Config, VQGANConfig, AudioViTVQGANConfig 9 | from uio2.data_utils import make_autoregressive_inputs 10 | from uio2.input_modalities import ModalityEncoder 11 | from uio2.seq_features import TargetSequence 12 | from uio2.image_vqgan import VQGAN 13 | from uio2.audio_vqgan import ViTVQGAN 14 | from uio2 import layers, config 15 | import tensorflow as tf 16 | 17 | 18 | TEXT_MODALITY_INDEX = 0 19 | IMAGE_MODALITY_INDEX = 1 20 | AUDIO_MODALITY_INDEX = 2 21 | 22 | 23 | class TextEmbedder(nn.Module): 24 | def __init__(self, config): 25 | super().__init__() 26 | self.config = config 27 | 28 | cfg = self.config 29 | self.register_buffer("pos_emb_cache", layers.get_1d_position_embedding( 30 | cfg.text_pos_emb, cfg.decoder_max_text_length, cfg.emb_dim, cfg.head_dim, True, 1), persistent=False) 31 | if "llama_rope" in cfg.text_pos_emb: 32 | self.modality_embedding = nn.Parameter(torch.empty(cfg.emb_dim).normal_(std=0.02)) 33 | 34 | def forward(self, inputs, shared_embed, mask=None, pos_ids=None, segment_ids=None, 35 | targets=None, cur_index=None): 36 | cfg = self.config 37 | bs = inputs.shape[0] 38 | 39 | if pos_ids is None: 40 | if cur_index is not None: 41 | pos_ids = torch.full_like(inputs, cur_index) 42 | else: 43 | pos_ids = torch.arange(inputs.shape[1], dtype=torch.int32, device=inputs.device)[None, ...] 
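# Full-sequence decoding: positions 0..seq_len-1, broadcast across the batch; when
# `cur_index` is set (incremental decoding) every position is pinned to the current step instead.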
44 | pos_ids = pos_ids.expand(bs, inputs.shape[1]) 45 | 46 | x = shared_embed(inputs) 47 | 48 | pos_emb = self.pos_emb_cache[pos_ids] 49 | 50 | if "llama_rope" in cfg.text_pos_emb: 51 | x += self.modality_embedding[None, None, :].to(x.dtype) 52 | 53 | attn_pattern_mask = torch.ones( 54 | (bs, 4, x.shape[1], x.shape[1]), dtype=x.dtype, device=x.device) 55 | modality_id = torch.full((), TEXT_MODALITY_INDEX, device=x.device, dtype=torch.int32) 56 | return TargetSequence( 57 | x, pos_emb, modality_id, mask, attn_pattern_mask=attn_pattern_mask, 58 | subsegments=segment_ids, target_tokens=targets, loss_mask=mask 59 | ) 60 | 61 | 62 | class TargetTextEncoder(ModalityEncoder): 63 | """Tokenize and embed input text, handles multiple target texts""" 64 | 65 | def preprocess_inputs(self, features, vocab, sequence_length) -> Dict: 66 | text_targets = features.get(f"text_targets") 67 | if "segment_ids" in features: 68 | raise NotImplementedError() 69 | if text_targets is None: 70 | return {} 71 | 72 | if isinstance(text_targets, str): 73 | tokens = tf.convert_to_tensor(vocab.encode(text_targets)) 74 | else: 75 | tokens = text_targets 76 | 77 | tokens = tokens[..., :config.MAX_TEXT_LEN-1] 78 | tokens = tf.pad(tokens, paddings=[[0, 1]], constant_values=config.EOS_ID) 79 | sh = tokens.shape[0] 80 | return { 81 | "targets": tokens, 82 | "inputs": make_autoregressive_inputs(tokens, bos_id=config.BOS_ID), 83 | "pos_ids": tf.range(sh, dtype=tf.int32), 84 | "segment_ids": tf.ones((sh,), dtype=tf.int32), 85 | "mask": tf.cast(tokens > config.PAD_ID, tf.int32) 86 | } 87 | 88 | def get_encoder(self, config: T5Config) -> nn.Module: 89 | return TextEmbedder(config) 90 | 91 | 92 | def _init_mask(height, width, is_bool_mask=False): 93 | attn_size = height * width 94 | mask = torch.tril(torch.ones(attn_size, attn_size, dtype=torch.bool if is_bool_mask else torch.float32)) 95 | return mask 96 | 97 | 98 | def get_row_mask(height=32, width=32, is_bool_mask=False): 99 | mask = _init_mask(height, width, is_bool_mask=is_bool_mask) 100 | step = width + 1 101 | for col in range(mask.shape[1]): 102 | mask[col + step:, col] = False if is_bool_mask else 0.0 103 | return mask 104 | 105 | 106 | def get_col_mask(height=32, width=32, is_bool_mask=False): 107 | mask = _init_mask(height, width, is_bool_mask=is_bool_mask) 108 | step = width - 1 109 | for col in range(mask.shape[1]): 110 | for i in range(1, mask.shape[0], step+1): 111 | mask[col + i: col + i + step, col] = False if is_bool_mask else 0.0 112 | return mask 113 | 114 | 115 | def get_conv_mask(height=32, width=32, kernel=11, is_bool_mask=False, hf_version='v3'): 116 | mask = _init_mask(height, width, is_bool_mask=is_bool_mask) 117 | shift = kernel // 2 118 | for pos in range(mask.shape[1]): 119 | mask[pos+1:, pos] = False if is_bool_mask else 0.0 120 | img = torch.zeros([height, width]) 121 | pixel_id = pos 122 | row = pixel_id // width 123 | col = pixel_id % width 124 | for r in range(-shift, shift+1): 125 | for c in range(-shift, shift+1): 126 | c_abs = max(min(c + col, width - 1), 0) 127 | r_abs = max(min(r + row, height - 1), 0) 128 | img[r_abs, c_abs] = 0.2 129 | cell_id = r_abs * width + c_abs 130 | if cell_id > pos: 131 | mask[cell_id, pos] = True if is_bool_mask else 1.0 132 | img[row, col] = 1.0 133 | return mask 134 | 135 | 136 | class ImageVQGAN(nn.Module): 137 | def __init__(self, config: T5Config, vqgan_config: VQGANConfig): 138 | super().__init__() 139 | self.config = config 140 | self.vqgan_config = vqgan_config 141 | 142 | cfg = self.config 143 | vqgan_cfg = 
self.vqgan_config 144 | self.grid_size = [ 145 | self.config.default_image_size[0] // self.vqgan_config.patch_size[0], 146 | self.config.default_image_size[1] // self.vqgan_config.patch_size[1], 147 | ] 148 | 149 | assert cfg.image_tokenizer_type == 'vqgan', "Only VQGAN is supported for image." 150 | self.vqgan = VQGAN(vqgan_config) 151 | 152 | # construct the row, col and conv mask. 153 | row_mask = get_row_mask(self.grid_size[0], self.grid_size[1]) 154 | col_mask = get_col_mask(self.grid_size[0], self.grid_size[1]) 155 | conv_mask = get_conv_mask(self.grid_size[0], self.grid_size[1]) 156 | full_mask = _init_mask(self.grid_size[0], self.grid_size[1]) 157 | 158 | self.register_buffer( 159 | "attn_mask", torch.stack([row_mask, col_mask, conv_mask, full_mask], dim=0), persistent=False) 160 | 161 | self.register_buffer("pos_emb_cache", layers.get_2d_position_embedding( 162 | cfg.image_pos_emb, 163 | vqgan_cfg.default_input_size, 164 | vqgan_cfg.patch_size, 165 | cfg.emb_dim, 166 | cfg.head_dim, 167 | 2), persistent=False) 168 | 169 | if "llama_rope" in cfg.image_pos_emb: 170 | self.modality_embedding = nn.Parameter(torch.empty(cfg.emb_dim).normal_(std=0.02)) 171 | 172 | def target_image_to_seq(self, image: torch.Tensor, loss_mask: torch.Tensor = None): 173 | cfg = self.config 174 | bs = image.shape[0] 175 | 176 | # reshape image to (batch, channel, height, width) 177 | image = image.permute(0, 3, 1, 2).contiguous() 178 | target_tokens = self.vqgan.get_codebook_indices(image) 179 | 180 | # 0: start token 181 | # 1: [MASK] token 182 | # from 2: normal tokens 183 | target_tokens = target_tokens + 2 184 | target_tokens = target_tokens.detach() 185 | 186 | input_tokens = torch.cat([ 187 | torch.zeros((target_tokens.shape[0], 1), dtype=torch.int32, device=target_tokens.device), 188 | target_tokens[:, :-1]], dim=1) 189 | 190 | return input_tokens, target_tokens, loss_mask 191 | 192 | def get_target_sequence(self, input_tokens, shared_embed, mask, target_tokens=None, task_mask=None, 193 | loss_mask=None, segment_ids=None, cur_index=None, pos_ids=None): 194 | cfg = self.config 195 | bs = input_tokens.shape[0] 196 | 197 | x = shared_embed(input_tokens) 198 | 199 | if cur_index is not None: 200 | pos_emb = self.pos_emb_cache[cur_index:cur_index+1,:][None, :, :] 201 | else: 202 | pos_emb = self.pos_emb_cache[:x.shape[1]][None, :, :] 203 | 204 | pos_emb = pos_emb.expand(bs, -1, -1) 205 | 206 | if "llama_rope" in cfg.image_pos_emb: 207 | x += self.modality_embedding[None, None, :].to(x.dtype) 208 | 209 | if mask is None: 210 | mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device) 211 | 212 | if cfg.dalle_attn_mask: 213 | attn_pattern_mask = self.attn_mask[None,:,:,:].expand(x.shape[0], -1, -1, -1) 214 | else: 215 | # use full mask if we are not using dalle attn mask. 
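# `self.attn_mask` stacks the [row, col, conv, full] patterns built in __init__; index -1 selects the plain lower-triangular (fully causal) mask, and expanding it 4x keeps the (batch, 4, seq, seq) shape the decoder expects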
216 | attn_pattern_mask = self.attn_mask[None,-1,:,:].expand(x.shape[0], 4, -1, -1) 217 | 218 | # task_mask: 1 if we should mask the corresponding token 219 | if cfg.dynamic_unk_mask and task_mask is not None: 220 | noise_mask = 1 - task_mask 221 | # shift the mask by 1 222 | noise_mask = torch.cat([ 223 | torch.ones(noise_mask.shape[0], 1, dtype=noise_mask.dtype, device=noise_mask.device), 224 | noise_mask[:, :-1]], dim=1) 225 | dynamic_unk_mask = layers.make_attention_mask(noise_mask, noise_mask) 226 | identity_mask = torch.eye(x.shape[1], dtype=dynamic_unk_mask.dtype, device=dynamic_unk_mask.device) 227 | dynamic_unk_mask = torch.logical_or(dynamic_unk_mask, identity_mask) 228 | attn_pattern_mask = layers.combine_masks(dynamic_unk_mask, attn_pattern_mask).to(attn_pattern_mask.dtype) 229 | 230 | modality_id = torch.full((), IMAGE_MODALITY_INDEX, device=x.device, dtype=torch.int32) 231 | seq = TargetSequence( 232 | x, pos_emb, modality_id, mask, attn_pattern_mask=attn_pattern_mask, 233 | subsegments=segment_ids, target_tokens=target_tokens, loss_mask=loss_mask) 234 | 235 | return seq 236 | 237 | def forward(self, image, shared_embed, mask=None, loss_mask=None, task_mask=None, segment_ids=None, 238 | cur_index=None, pos_ids=None): 239 | 240 | cfg = self.config 241 | if cur_index is not None: 242 | return self.get_target_sequence(image, shared_embed, mask, segment_ids, cur_index=cur_index) 243 | else: 244 | input_tokens, target_tokens, loss_mask = self.target_image_to_seq(image, loss_mask) 245 | 246 | return self.get_target_sequence(input_tokens, shared_embed, mask, target_tokens, task_mask, 247 | loss_mask, segment_ids, pos_ids=pos_ids) 248 | 249 | 250 | class TargetImageVQGANEmbedder(ModalityEncoder): 251 | def __init__(self, config): 252 | super().__init__() 253 | self.config = config 254 | 255 | def preprocess_inputs( 256 | self, features: Dict, tokenizer, sequence_length) -> Optional[Dict[str, tf.Tensor]]: 257 | image_target_size = config.IMAGE_TARGET_SIZE 258 | image_target_d = config.IMAGE_TARGET_D 259 | target_padding_size = tf.constant( 260 | np.array(image_target_size) / image_target_d, tf.int32) 261 | 262 | image_targets = features.pop("image_targets", None) 263 | image_target_masks = features.pop("image_target_masks", None) 264 | image_target_task_masks = features.pop("image_target_task_masks", None) 265 | if image_targets is None: 266 | return {} 267 | else: 268 | image_targets = image_targets * 2.0 - 1 # VQGAN pre-processing 269 | # In case the dimension were unknown 270 | image_targets = tf.ensure_shape(image_targets, image_target_size + [3]) 271 | assert image_target_masks is not None 272 | if len(image_target_masks.shape) == 1: 273 | # Given mask is on the patches rather then pixels, used in depth_preprocessing 274 | image_target_masks = image_target_masks 275 | else: 276 | image_target_masks = tf.image.resize( 277 | tf.expand_dims(image_target_masks, -1), 278 | target_padding_size, 279 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) 280 | image_target_masks = tf.cast(tf.reshape(image_target_masks, [-1]), tf.int32) 281 | if image_target_task_masks is None: 282 | image_target_task_masks = tf.zeros(image_target_masks.shape, tf.int32) 283 | else: 284 | if len(image_target_task_masks.shape) == 1: 285 | image_target_task_masks = image_target_task_masks 286 | else: 287 | image_target_task_masks = tf.image.resize( 288 | tf.expand_dims(image_target_task_masks, -1), 289 | target_padding_size, 290 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) 291 | image_target_task_masks = 
tf.cast(tf.reshape(image_target_task_masks, [-1]), tf.int32) 292 | 293 | loss_mask = features.get('image_target_loss_masks', image_target_masks) 294 | 295 | return dict( 296 | image=image_targets, 297 | mask=image_target_masks, 298 | loss_mask=loss_mask, 299 | task_mask=image_target_task_masks, 300 | ) 301 | 302 | def get_encoder(self, config: T5Config) -> nn.Module: 303 | return ImageVQGAN(config, self.config) 304 | 305 | 306 | class AudioVQGAN(nn.Module): 307 | def __init__(self, config: T5Config, vqgan_config: AudioViTVQGANConfig): 308 | super().__init__() 309 | self.config = config 310 | self.vqgan_config = vqgan_config 311 | 312 | cfg = self.config 313 | vqgan_cfg = self.vqgan_config 314 | self.grid_size = [ 315 | self.config.default_audio_size[0] // self.vqgan_config.patch_size[0], 316 | self.config.default_audio_size[1] // self.vqgan_config.patch_size[1], 317 | ] 318 | 319 | self.vqgan = ViTVQGAN(vqgan_config) 320 | 321 | # construct the row, col and conv mask. 322 | row_mask = get_row_mask(self.grid_size[0], self.grid_size[1]) 323 | col_mask = get_col_mask(self.grid_size[0], self.grid_size[1]) 324 | conv_mask = get_conv_mask(self.grid_size[0], self.grid_size[1]) 325 | full_mask = _init_mask(self.grid_size[0], self.grid_size[1]) 326 | 327 | self.register_buffer( 328 | "attn_mask", torch.stack([row_mask, col_mask, conv_mask, full_mask], dim=0), persistent=False) 329 | 330 | self.register_buffer("pos_emb_cache", layers.get_2d_position_embedding( 331 | cfg.audio_pos_emb, 332 | vqgan_cfg.default_input_size, 333 | vqgan_cfg.patch_size, 334 | cfg.emb_dim, 335 | cfg.head_dim, 336 | 3), persistent=False) 337 | 338 | if "llama_rope" in cfg.image_pos_emb: 339 | self.modality_embedding = nn.Parameter(torch.empty(cfg.emb_dim).normal_(std=0.02)) 340 | 341 | def target_audio_to_seq(self, audio: torch.Tensor, loss_mask: torch.Tensor = None): 342 | # audio: (batch, height, width, channel) 343 | cfg = self.config 344 | bs = audio.shape[0] 345 | 346 | # since the vit-vqgan takes as input of shape [128, 256], we need to tranpose this first. 
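# (batch, H, W, C) -> (batch, W, H, C); the resulting code grid is permuted back below so the flattened tokens follow this module's own (H, W) raster order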
347 | audio = audio.permute(0, 2, 1, 3).contiguous() 348 | target_tokens = self.vqgan.get_codebook_indices(audio) 349 | 350 | # reshape the target back to the original shape: (batch, height=256, width=128) 351 | target_tokens = target_tokens.reshape(bs, self.grid_size[1], self.grid_size[0]) 352 | target_tokens = target_tokens.permute(0, 2, 1).contiguous().view(bs, -1) 353 | 354 | # 0: start token 355 | # 1: [MASK] token 356 | # from 2: normal tokens 357 | target_tokens = target_tokens + 2 358 | target_tokens = target_tokens.detach() 359 | 360 | input_tokens = torch.cat([ 361 | torch.zeros((target_tokens.shape[0], 1), dtype=torch.int32, device=target_tokens.device), 362 | target_tokens[:, :-1]], dim=1) 363 | 364 | return input_tokens, target_tokens, loss_mask 365 | 366 | def get_target_sequence(self, input_tokens, shared_embed, mask, target_tokens=None, task_mask=None, 367 | loss_mask=None, segment_ids=None, cur_index=None): 368 | cfg = self.config 369 | vqgan_cfg = self.vqgan_config 370 | bs = input_tokens.shape[0] 371 | 372 | x = shared_embed(input_tokens) 373 | 374 | if cur_index is not None: 375 | pos_emb = self.pos_emb_cache[cur_index:cur_index+1,:][None, :, :] 376 | else: 377 | pos_emb = self.pos_emb_cache[:x.shape[1]][None, :, :] 378 | 379 | pos_emb = pos_emb.expand(bs, -1, -1) 380 | 381 | if "llama_rope" in cfg.image_pos_emb: 382 | x += self.modality_embedding[None, None, :].to(x.dtype) 383 | 384 | if mask is None: 385 | mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device) 386 | 387 | if cfg.dalle_attn_mask: 388 | attn_pattern_mask = self.attn_mask[None,:,:,:].expand(x.shape[0], -1, -1, -1) 389 | else: 390 | # use full mask if we are not using dalle attn mask. 391 | attn_pattern_mask = self.attn_mask[None,-1,:,:].expand(x.shape[0], 4, -1, -1) 392 | 393 | # task_mask: 1 if we should mask the corresponding token 394 | if cfg.dynamic_unk_mask and task_mask is not None: 395 | noise_mask = 1 - task_mask 396 | # shift the mask by 1 397 | noise_mask = torch.cat([ 398 | torch.ones(noise_mask.shape[0], 1, dtype=noise_mask.dtype, device=noise_mask.device), 399 | noise_mask[:, :-1]], dim=1) 400 | dynamic_unk_mask = layers.make_attention_mask(noise_mask, noise_mask) 401 | identity_mask = torch.eye(x.shape[1], dtype=dynamic_unk_mask.dtype, device=dynamic_unk_mask.device) 402 | dynamic_unk_mask = torch.logical_or(dynamic_unk_mask, identity_mask) 403 | attn_pattern_mask = layers.combine_masks(dynamic_unk_mask, attn_pattern_mask).to(attn_pattern_mask.dtype) 404 | 405 | modality_id = torch.full((), AUDIO_MODALITY_INDEX, device=x.device, dtype=torch.int32) 406 | seq = TargetSequence( 407 | x, pos_emb, modality_id, mask, attn_pattern_mask=attn_pattern_mask, 408 | subsegments=segment_ids, target_tokens=target_tokens, loss_mask=loss_mask) 409 | 410 | return seq 411 | 412 | def forward(self, audio, shared_embed, mask=None, loss_mask=None, task_mask=None, segment_ids=None, 413 | cur_index=None, pos_ids=None): 414 | 415 | cfg = self.config 416 | if cur_index is not None: 417 | return self.get_target_sequence(audio, shared_embed, mask, segment_ids, cur_index=cur_index) 418 | else: 419 | input_tokens, target_tokens, loss_mask = self.target_audio_to_seq(audio, loss_mask) 420 | 421 | return self.get_target_sequence(input_tokens, shared_embed, mask, target_tokens, task_mask, 422 | loss_mask, segment_ids) 423 | 424 | 425 | class TargetAudioVQGANEmbedder(ModalityEncoder): 426 | def __init__(self, config): 427 | super().__init__() 428 | self.config = config 429 | 430 | def 
get_encoder(self, config: T5Config) -> nn.Module: 431 | return AudioVQGAN(config, self.config) 432 | 433 | def preprocess_inputs( 434 | self, features: Dict, tokenizer, sequence_length) -> Optional[Dict[str, tf.Tensor]]: 435 | target_size = config.AUDIO_TARGET_SIZE 436 | target_d = config.AUDIO_TARGET_D 437 | 438 | target_padding_size = tf.constant( 439 | np.array(target_size) / target_d, tf.int32) 440 | 441 | targets = features.pop("audio_targets", None) 442 | target_masks = features.pop("audio_target_masks", None) 443 | target_task_masks = features.pop("audio_target_task_masks", None) 444 | 445 | if targets is None: 446 | return {} 447 | else: 448 | targets = (targets - config.AUDIOSET_MEAN) / config.AUDIOSET_STD 449 | # In case the dimension were unknown 450 | targets = tf.ensure_shape(targets, target_size + [1]) 451 | assert target_masks is not None 452 | if len(target_masks.shape) == 1: 453 | raise ValueError("Mask should be over pixels") 454 | else: 455 | target_masks = tf.image.resize( 456 | tf.expand_dims(target_masks, -1), 457 | target_padding_size, 458 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) 459 | target_masks = tf.cast(tf.reshape(target_masks, [-1]), tf.int32) 460 | if target_task_masks is None: 461 | target_task_masks = tf.zeros(target_masks.shape, tf.int32) 462 | else: 463 | if len(target_task_masks.shape) == 1: 464 | target_task_masks = target_task_masks 465 | else: 466 | target_task_masks = tf.image.resize( 467 | tf.expand_dims(target_task_masks, -1), 468 | target_padding_size, 469 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) 470 | target_task_masks = tf.cast(tf.reshape(target_task_masks, [-1]), tf.int32) 471 | 472 | loss_mask = features.get('audio_target_loss_masks', target_masks) 473 | 474 | return dict( 475 | audio=targets, 476 | mask=target_masks, 477 | loss_mask=loss_mask, 478 | task_mask=target_task_masks, 479 | ) 480 | -------------------------------------------------------------------------------- /uio2/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for training and inference.""" 2 | from typing import Union 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | import torch 7 | import torch.utils._device 8 | from torch.nn import functional as F 9 | 10 | 11 | def flatten_dict(d, sep="/"): 12 | _out = dict() 13 | 14 | def _fn(part, prefix): 15 | if isinstance(part, dict): 16 | for k, v in part.items(): 17 | _fn(v, prefix + sep + k) 18 | else: 19 | _out[prefix] = part 20 | _fn(d, prefix="") 21 | return _out 22 | 23 | 24 | def unflatten_dict(d, sep='/'): 25 | out = {} 26 | for k, v in d.items(): 27 | parts = k.lstrip(sep).split(sep) 28 | k_out = out 29 | for key in parts[:-1]: 30 | if key not in k_out: 31 | k_out[key] = {} 32 | k_out = k_out[key] 33 | k_out[parts[-1]] = v 34 | return out 35 | 36 | 37 | def pad_and_stack(data, add_eos=False): 38 | data = [np.asarray(x) for x in data] 39 | max_len = max(x.shape[0] for x in data) 40 | if add_eos: 41 | max_len += 1 42 | out = np.zeros((len(data), max_len), dtype=np.int32) 43 | for ix, x in enumerate(data): 44 | out[ix, :len(x)] = x 45 | if add_eos: 46 | out[ix, len(x)] = 1 47 | return torch.as_tensor(out) 48 | 49 | 50 | def extract_locations_from_token_ids(tokens, n=4): 51 | """Extract consecutive location tokens from `tokens`""" 52 | boxes = [] 53 | box = [] 54 | for token in tokens: 55 | if 32000 <= token < 33000: 56 | box.append(token) 57 | if len(box) == n: 58 | boxes.append(token_to_float(np.array(box))) 59 | box = [] 60 | else: 61 | 
# sequence 0: 79 | b = F.pad(b, other + [0, diff]) 80 | else: 81 | a = F.pad(a, other + [0, -diff]) 82 | return torch.cat([a, b]) 83 | 84 | 85 | def extra_id_to_float(extra_id: Union[int, np.ndarray]): 86 | """Converts extra id numbers from location text tokens to floats 87 | 88 | e.g., means location `extra_id_to_float(201)` 89 | """ 90 | if isinstance(extra_id, int): 91 | assert 200 <= extra_id < 1200 92 | else: 93 | assert np.all(200 <= extra_id) and np.all(extra_id < 1200) 94 | return (extra_id - 200) / (1000 - 1) 95 | 96 | 97 | def undo_box_preprocessing(boxes, image_info): 98 | """Converts bounding boxes to boundings on the original image scale""" 99 | top_pad, left_pad = image_info[0], image_info[1] 100 | paddings = np.array([top_pad, left_pad, top_pad, left_pad], dtype=boxes.dtype) 101 | 102 | if len(boxes.shape) == 1: 103 | boxes = boxes - paddings 104 | else: 105 | boxes = boxes - paddings[None, :] 106 | 107 | # Not sure how to handle offsets at the moment (simple addition?) 108 | # for now just require them to be zero as should be the case during eval 109 | off_y = int(image_info[7]) 110 | off_x = int(image_info[8]) 111 | assert off_x == off_y == 0 112 | 113 | # Undo the scaling 114 | inv_scale = image_info[2] 115 | boxes = boxes * inv_scale 116 | 117 | # clip in case the model predicted a region in the padded area 118 | h, w = image_info[3:5] 119 | boxes = np.maximum(boxes, 0) 120 | boxes = np.minimum(boxes, [h, w, h, w]) 121 | return boxes 122 | 123 | 124 | def undo_image_preprocessing(image, image_info, gray_scale=False, 125 | resize_method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, to_int=False): 126 | """Resizes/crops an image to match the size/scale before pre-processing""" 127 | if gray_scale: 128 | image = tf.reduce_mean(image, -1, keepdims=True) 129 | 130 | off_x = int(image_info[7]) 131 | off_y = int(image_info[8]) 132 | if not (off_x == 0 and off_y == 0): 133 | raise NotImplementedError() 134 | 135 | src_h = int(image_info[3]) 136 | src_w = int(image_info[4]) 137 | 138 | w = max(src_h, src_w) 139 | image = tf.image.resize(image, [w, w], method=resize_method) 140 | if src_h > src_w: 141 | delta = (src_h - src_w) // 2 142 | image = image[:, delta:delta+src_w] 143 | else: 144 | delta = (src_w - src_h) // 2 145 | image = image[delta:delta+src_h, :] 146 | 147 | if to_int: 148 | image = tf.image.convert_image_dtype(image, dtype=tf.uint8) 149 | return image.numpy() 150 | 151 | -------------------------------------------------------------------------------- /uio2/video_utils.py: -------------------------------------------------------------------------------- 1 | """Video utils for video pre-processing""" 2 | import logging 3 | import os.path 4 | import subprocess 5 | from io import BytesIO 6 | import numpy 7 | numpy.float = numpy.float64 8 | numpy.int = numpy.int_ 9 | import numpy as np 10 | 11 | from uio2.audio_utils import read_audio_file, extract_spectrograms_from_audio 12 | 13 | from skvideo import io as skvideo_io 14 | 15 | 16 | # found by trial and error with ffmpeg 17 | BUFFER_FROM_END = 0.1 18 | 19 | WAV_MAX_VALUE = 32768.0 20 | 21 | 22 | def get_video_length(video_path): 23 | # this gets just the video stream length (in the case audio stream is longer) 24 | # E.g. 
k700-2020/train/watering plants/af3epdZsrTc_000178_000188.mp4 25 | # if audio is shorter than video stream, just pad that 26 | # "-select_streams v:0" gets the video stream, '-select_streams a:0" is audio stream 27 | proc = subprocess.Popen(['ffprobe', '-v', 'error', '-select_streams', 'v:0', '-show_entries', 'stream=duration', 28 | '-of', 'default=noprint_wrappers=1:nokey=1', video_path], 29 | stdout=subprocess.PIPE, stderr=subprocess.STDOUT) 30 | out, _ = proc.communicate() 31 | duration = out.decode('utf-8') 32 | duration = float(duration.strip()) 33 | return duration 34 | 35 | 36 | def exact_audio_from_video(video_file: str, timeout=None, sampling_rate:int=16000): 37 | out = subprocess.run( 38 | ['ffmpeg', '-y', '-i', str(video_file), '-ac', '1', '-ar', 39 | str(sampling_rate), "-f", "wav", "pipe:1"], 40 | timeout=timeout, capture_output=True 41 | ) 42 | if out.returncode != 0: 43 | # Assume no audio in the video file 44 | return None 45 | return out.stdout 46 | 47 | 48 | def extract_single_frame_from_video(video_file, t, verbosity=0): 49 | 50 | timecode = '{:.3f}'.format(t) 51 | try: 52 | reader = skvideo_io.FFmpegReader( 53 | video_file, 54 | inputdict={'-ss': timecode, '-threads': '1'}, 55 | outputdict={'-r': '1', '-q:v': '2', '-pix_fmt': 'rgb24', '-frames:v': '1'}, 56 | verbosity=verbosity 57 | ) 58 | except ValueError as err: 59 | raise ValueError(f"Error on loading {video_file}", err) 60 | 61 | try: 62 | frame = next(iter(reader.nextFrame())) 63 | except StopIteration: 64 | raise ValueError(f"Error on getting frame at time {timecode}s from {video_file}") 65 | 66 | return frame 67 | 68 | 69 | def get_num_segments(video_length, video_segment_length): 70 | num_segments = int(video_length // video_segment_length) 71 | 72 | # allows extra frame only if the midpoint is an available to extract video frames 73 | if (video_length % video_segment_length) - BUFFER_FROM_END > (video_segment_length / 2.0): 74 | num_segments += 1 75 | 76 | return num_segments 77 | 78 | 79 | def extract_frames_from_video(video_path, 80 | video_length, 81 | video_segment_length=None, 82 | times=None, 83 | num_frames=None): 84 | if times is None: # automatically calculate the times if not set 85 | 86 | # make sure one and only one of video_segment_length and num_frames is None 87 | assert video_segment_length is not None or num_frames is not None 88 | assert video_segment_length is None or num_frames is None 89 | 90 | if num_frames is None: 91 | # allows extra frame only if for >=50% of the segment video is available 92 | num_segments = get_num_segments(video_length, video_segment_length) 93 | else: 94 | num_segments = num_frames 95 | 96 | # frames are located at the midpoint of a segment 97 | boundaries = np.linspace(0, video_length, num_segments + 1).tolist() 98 | extract_times = [(boundaries[i] + boundaries[i+1]) / 2.0 for i in range(num_segments)] 99 | else: 100 | extract_times = times 101 | boundaries = None 102 | 103 | # TODO can we do this in one call to ffmpeg? 
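# note: the list comprehension below spins up one ffmpeg reader per requested timestamp, so extraction cost grows linearly with the number of frames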
104 | frames = [extract_single_frame_from_video(video_path, time) for time in extract_times] 105 | 106 | # check to see if any extraction failed 107 | if any([x is None for x in frames]) or frames is None or len(frames) == 0: 108 | raise ValueError(f"Failed to extract frames from {video_path}") 109 | 110 | return np.stack(frames).astype(np.uint8) 111 | 112 | 113 | def extract_frames_and_spectrograms_from_video( 114 | video_file, 115 | video_length=None, 116 | video_segment_length=None, 117 | audio_segment_length=None, 118 | times=None, 119 | num_frames=None, 120 | *, 121 | use_audio, 122 | ): 123 | if times is None: 124 | # get actual video length 125 | if video_length is None: 126 | video_length = get_video_length(video_file) 127 | if video_length is None: 128 | raise ValueError(f"Couldn't get video length for {video_file}") 129 | 130 | # make sure one and only one of video_segment_length and num_frames is None 131 | assert video_segment_length is not None or num_frames is not None 132 | assert video_segment_length is None or num_frames is None 133 | 134 | _video_segment_length = video_length / num_frames if video_segment_length is None else video_segment_length 135 | if video_length < (_video_segment_length / 2.0) - BUFFER_FROM_END: 136 | raise ValueError( 137 | f"Video is too short ({video_length}s is less than half the segment length of {_video_segment_length}s segments") 138 | else: 139 | # don't need this if times is given 140 | video_length = None 141 | 142 | frames = extract_frames_from_video( 143 | video_file, 144 | video_length, 145 | video_segment_length=video_segment_length, 146 | times=times, 147 | num_frames=num_frames, 148 | ) 149 | 150 | spectrograms = None 151 | if use_audio: 152 | assert times is None, "Can't use audio with specific times" 153 | wav_bytes = exact_audio_from_video(video_file) 154 | if wav_bytes is not None: 155 | waveform = read_audio_file(BytesIO(wav_bytes)) 156 | spectrograms = extract_spectrograms_from_audio( 157 | waveform, 158 | audio_length=video_length, 159 | audio_segment_length=_video_segment_length, 160 | spectrogram_length=audio_segment_length, 161 | ) 162 | 163 | return frames, spectrograms 164 | 165 | 166 | def load_video( 167 | path: str, 168 | max_frames: int = 5, 169 | audio_segment_length: float = 4.08, 170 | use_audio: bool=True, 171 | ): 172 | if skvideo_io is None: 173 | raise ValueError("Need to install skvideo to load videos") 174 | 175 | assert os.path.exists(path), path 176 | 177 | frames, spectrograms = extract_frames_and_spectrograms_from_video( 178 | path, 179 | audio_segment_length=audio_segment_length, 180 | num_frames=max_frames, 181 | use_audio=use_audio, 182 | ) 183 | return frames, spectrograms 184 | 185 | 186 | def remove_bars_from_frames(frames, black_bar=True, threshold=32, max_perc_to_trim=0.3): 187 | """ 188 | :param frames: [num_frames, height, width, 3] 189 | :param blackbar_threshold: Pixels must be this intense for us to not trim 190 | :param max_perc_to_prim: Will trim x% by default of the image at most in each dimension 191 | :return: 192 | """ 193 | # Detect black bars#################### 194 | h, w = frames.shape[1], frames.shape[2] 195 | if black_bar: 196 | has_content = frames.max(axis=(0, -1)) >= threshold 197 | else: 198 | has_content = frames.min(axis=(0, -1)) <= threshold 199 | 200 | y_frames = np.where(has_content.any(1))[0] 201 | if y_frames.size == 0: 202 | y_frames = [h // 2] 203 | 204 | y1 = min(y_frames[0], int(h * max_perc_to_trim)) 205 | y2 = max(y_frames[-1] + 1, int(h * (1 - max_perc_to_trim))) 
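# repeat the same content detection and clamped trim along the horizontal axis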
206 | 207 | x_frames = np.where(has_content.any(0))[0] 208 | if x_frames.size == 0: 209 | x_frames = [w // 2] 210 | 211 | x1 = min(x_frames[0], int(w * max_perc_to_trim)) 212 | x2 = max(x_frames[-1] + 1, int(w * (1 - max_perc_to_trim))) 213 | 214 | frames = frames[:, y1:y2, x1:x2] 215 | return frames -------------------------------------------------------------------------------- /uio2/vocabulary.py: -------------------------------------------------------------------------------- 1 | """Tokenizer for UIO2, lightly modified from seqio""" 2 | # Copyright 2023 The SeqIO Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Modified for UIO2 to use with the LLaMa tokenizer 17 | # For backward compatibility reasons, our tokenizer 18 | # is changed so that EOS is 1 and BOS 0 19 | 20 | import dataclasses 21 | import functools 22 | import threading 23 | from typing import ClassVar, Iterable, Optional, Sequence, Union 24 | 25 | import tensorflow.compat.v2 as tf 26 | 27 | from sentencepiece import sentencepiece_model_pb2 28 | import sentencepiece as sentencepiece_processor 29 | 30 | 31 | class SentencePieceVocabulary: 32 | """Wrapper for nlp/sentencepiece encoder. 33 | 34 | If using extra ids, you can represent them in string-form as `<extra_id_0>`, 35 | `<extra_id_1>`, etc. They will be indexed starting from the end of the 36 | vocabulary to match how the masking preprocessors are set up. 37 | 38 | IMPORTANT NOTE: these placeholders only work properly when they are used at 39 | word starts (e.g., "I like peanut butter and <extra_id_0> sandwiches." or 40 | "I like peanut butter and <extra_id_0>ly sandwiches" are both okay, but 41 | "I like peanut butter and jel<extra_id_0> sandwiches" is not.). 42 | """ 43 | 44 | @dataclasses.dataclass 45 | class _ModelContext: 46 | tokenizer: sentencepiece_processor.SentencePieceProcessor 47 | sp_model: bytes 48 | 49 | _load_model_lock: ClassVar[threading.Lock] = threading.Lock() 50 | 51 | def __init__( 52 | self, 53 | sentencepiece_model_file: str, 54 | extra_ids: int = 0, 55 | normalizer_spec_overrides: Optional[ 56 | sentencepiece_model_pb2.NormalizerSpec 57 | ] = None, 58 | reverse_extra_ids: bool = False, 59 | modality_extra_id_n_frames: int = 0, 60 | hack_to_t5_start_tokens: bool = True, 61 | prefix_as_special_token: bool = True, 62 | ): 63 | """Create a SentencePieceVocabulary. 64 | 65 | Optionally, specify a number of extra ids to add to the end of the 66 | vocabulary for use as sentinels. 67 | 68 | Args: 69 | sentencepiece_model_file: path of the sentence piece model. 70 | extra_ids: number of extra ids to include. 71 | normalizer_spec_overrides: If not None, this proto will be merged into the 72 | model's normalizer and denormalizer specs. Thus, any options set on this 73 | object will override the values of those options in the loaded model. 74 | reverse_extra_ids: if True, extra_ids are numbered in descending order, so 75 | the first extra_id has the highest number. This is done for 76 | compatibility with span_corruption mask generation in T5.
77 | """ 78 | self._sentencepiece_model_file = sentencepiece_model_file 79 | self._normalizer_spec_overrides = normalizer_spec_overrides 80 | self._reverse_extra_ids = reverse_extra_ids 81 | self._model: Optional[SentencePieceVocabulary._ModelContext] = None 82 | self._modality_extra_id_n_frames = modality_extra_id_n_frames 83 | self._hack_to_t5_start_tokens = hack_to_t5_start_tokens 84 | self._prefix_as_special_token = prefix_as_special_token 85 | self._extra_ids = extra_ids or 0 86 | 87 | def __getstate__(self): 88 | state = self.__dict__.copy() 89 | # Gin config makes a deep copy of the keyword arguments of configurables. 90 | # When a SentencePieceVocabulary vocabulary is used as a keyword argument 91 | # in a Gin configurable, it must be picklable. We therefore remove 92 | # _model; will be initialized lazily as needed. 93 | del state["_model"] 94 | return state 95 | 96 | def __setstate__(self, state): 97 | self.__dict__.update(state) 98 | self._model = None 99 | 100 | def load_model(self) -> None: 101 | _ = self._model_context() 102 | 103 | def _model_context( 104 | self, 105 | ) -> _ModelContext: 106 | """Loads model if not yet loaded and returns the model context. 107 | 108 | Returns: 109 | The model context as a tuple of (tokenizer, sp_model). 110 | """ 111 | if self._model: 112 | return self._model 113 | 114 | normalizer_spec_overrides_serialized = ( 115 | self._normalizer_spec_overrides.SerializeToString(deterministic=True) 116 | if self._normalizer_spec_overrides 117 | else None 118 | ) 119 | 120 | self._model = self._load_model( 121 | self._sentencepiece_model_file, 122 | self._extra_ids, 123 | normalizer_spec_overrides_serialized, 124 | self._reverse_extra_ids, 125 | modality_extra_id_n_frames=self._modality_extra_id_n_frames, 126 | hack_to_t5_start_tokens=self._hack_to_t5_start_tokens, 127 | prefix_as_special_token=self._prefix_as_special_token 128 | ) 129 | return self._model 130 | 131 | @classmethod 132 | @functools.lru_cache(maxsize=None) 133 | def _load_model( 134 | cls, 135 | sentencepiece_model_file: str, 136 | extra_ids: int, 137 | normalizer_spec_overrides_serialized: Optional[bytes] = None, 138 | reverse_extra_ids: bool = True, 139 | modality_extra_id_n_frames: int = 0, 140 | hack_to_t5_start_tokens=True, 141 | prefix_as_special_token=True, 142 | ) -> _ModelContext: 143 | """Load SPM, Python tokenizer, and cache results to the class definition.""" 144 | # SentencePieceProcessor::LoadFromSerializedProto is not thread-safe. 145 | # Without a lock, users may randomly see SIGSEGV on 146 | # sentencepiece::ModelInterface::pad_piece when using the vocabulary in 147 | # SeqIO preprocessors. 148 | with cls._load_model_lock: 149 | # Handle cases where SP can't load the file, but gfile can. 150 | with tf.io.gfile.GFile(sentencepiece_model_file, "rb") as f: 151 | sp_model = f.read() 152 | model = sentencepiece_model_pb2.ModelProto.FromString(sp_model) 153 | 154 | if hack_to_t5_start_tokens: 155 | # PAD token would still be 0 same as BOS for consistency as previous! 156 | unk = model.pieces[0] 157 | bos = model.pieces[1] 158 | eos = model.pieces[2] 159 | model.pieces.remove(unk) 160 | model.pieces.remove(bos) 161 | model.pieces.remove(eos) 162 | model.pieces.insert(0, bos) # BOS is token 0 163 | model.pieces.insert(1, eos) # EOS is token 1 164 | model.pieces.insert(2, unk) # UNK is token 2 165 | 166 | # Add placeholder strings for extra IDs. 167 | if extra_ids: 168 | # By default, we them in reverse order to match span corruption. 
169 | if reverse_extra_ids: 170 | extra_id_tokens = reversed(range(extra_ids)) 171 | else: 172 | extra_id_tokens = range(extra_ids) 173 | 174 | for i in extra_id_tokens: 175 | model.pieces.add( 176 | piece=f"▁", 177 | score=0.0, 178 | type=sentencepiece_model_pb2.ModelProto.SentencePiece.USER_DEFINED, 179 | ) 180 | 181 | if modality_extra_id_n_frames: 182 | # Note: start from 1, not affect by `reverse_extra_ids` and not counted in `extra_ids` 183 | for i in range(1, modality_extra_id_n_frames + 1): 184 | model.pieces.add( 185 | piece=f"▁", 186 | score=0.0, 187 | type=sentencepiece_model_pb2.ModelProto.SentencePiece.USER_DEFINED, 188 | ) 189 | model.pieces.add( 190 | piece=f"▁", 191 | score=0.0, 192 | type=sentencepiece_model_pb2.ModelProto.SentencePiece.USER_DEFINED, 193 | ) 194 | model.pieces.add( 195 | piece=f"▁", 196 | score=0.0, 197 | type=sentencepiece_model_pb2.ModelProto.SentencePiece.USER_DEFINED, 198 | ) 199 | model.pieces.add( 200 | piece=f"▁", 201 | score=0.0, 202 | type=sentencepiece_model_pb2.ModelProto.SentencePiece.USER_DEFINED, 203 | ) 204 | 205 | if prefix_as_special_token: 206 | model.pieces.add( 207 | piece=f"▁[Text]▁[S]", 208 | score=0.0, 209 | type=sentencepiece_model_pb2.ModelProto.SentencePiece.USER_DEFINED, 210 | ) 211 | model.pieces.add( 212 | piece=f"▁[Text]▁[R]", 213 | score=0.0, 214 | type=sentencepiece_model_pb2.ModelProto.SentencePiece.USER_DEFINED, 215 | ) 216 | model.pieces.add( 217 | piece=f"▁[Text]▁[X]", 218 | score=0.0, 219 | type=sentencepiece_model_pb2.ModelProto.SentencePiece.USER_DEFINED, 220 | ) 221 | model.pieces.add( 222 | piece=f"▁[Image]▁[S]", 223 | score=0.0, 224 | type=sentencepiece_model_pb2.ModelProto.SentencePiece.USER_DEFINED, 225 | ) 226 | model.pieces.add( 227 | piece=f"▁[Image]▁[R]", 228 | score=0.0, 229 | type=sentencepiece_model_pb2.ModelProto.SentencePiece.USER_DEFINED, 230 | ) 231 | model.pieces.add( 232 | piece=f"▁[Audio]▁[S]", 233 | score=0.0, 234 | type=sentencepiece_model_pb2.ModelProto.SentencePiece.USER_DEFINED, 235 | ) 236 | model.pieces.add( 237 | piece=f"▁[Audio]▁[R]", 238 | score=0.0, 239 | type=sentencepiece_model_pb2.ModelProto.SentencePiece.USER_DEFINED, 240 | ) 241 | 242 | if normalizer_spec_overrides_serialized is not None: 243 | normalizer_spec_overrides = ( 244 | sentencepiece_model_pb2.NormalizerSpec.FromString( 245 | normalizer_spec_overrides_serialized 246 | ) 247 | ) 248 | 249 | model.normalizer_spec.MergeFrom(normalizer_spec_overrides) 250 | model.denormalizer_spec.MergeFrom(normalizer_spec_overrides) 251 | sp_model = model.SerializeToString() 252 | # Load Python tokenizer and ensure the EOS and PAD IDs are correct. 
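# (after the piece reordering above, bos_id() is 0 and eos_id() is 1, while PAD shares id 0 with BOS, as the module header notes)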
253 | tokenizer = sentencepiece_processor.SentencePieceProcessor() 254 | tokenizer.LoadFromSerializedProto(sp_model) 255 | return cls._ModelContext(tokenizer=tokenizer, sp_model=sp_model) 256 | 257 | @property 258 | def modality_extra_ids(self): 259 | if self._modality_extra_id_n_frames: 260 | # image/audio input + n * image/audio history + R/S * 3 modalities + [Text] [X] 261 | return (self._modality_extra_id_n_frames + 1) * 2 + self._prefix_as_special_token * (2 * 3 + 1) 262 | return 0 + self._prefix_as_special_token * (2 * 3 + 1) 263 | 264 | @property 265 | def bos_id(self) -> Optional[int]: 266 | return self.tokenizer.bos_id() 267 | 268 | @property 269 | def pad_id(self) -> Optional[int]: 270 | return 0 271 | 272 | @property 273 | def eos_id(self) -> Optional[int]: 274 | return self.tokenizer.eos_id() 275 | 276 | @property 277 | def unk_id(self) -> Optional[int]: 278 | return self.tokenizer.unk_id() 279 | 280 | @property 281 | def sp_model(self) -> Optional[bytes]: 282 | """Retrieve the SPM.""" 283 | return self._model_context().sp_model 284 | 285 | @property 286 | def sentencepiece_model_file(self) -> str: 287 | return self._sentencepiece_model_file 288 | 289 | @property 290 | def tokenizer(self) -> sentencepiece_processor.SentencePieceProcessor: 291 | """Returns the Python tokenizer.""" 292 | return self._model_context().tokenizer 293 | 294 | @property 295 | def vocab_size(self): 296 | return self._base_vocab_size 297 | 298 | @property 299 | def _base_vocab_size(self): 300 | return self.tokenizer.GetPieceSize() 301 | 302 | def _encode(self, s): 303 | return self.tokenizer.EncodeAsIds(s) 304 | 305 | def _decode(self, ids): 306 | # convert all the extra ids (sentinels) to UNK=2 307 | unk_id = self.tokenizer.unk_id() 308 | piece_size = self.tokenizer.GetPieceSize() 309 | ids = [unk_id if i >= piece_size else int(i) for i in ids] 310 | return self.tokenizer.DecodeIds(ids) 311 | 312 | @property 313 | def extra_ids(self) -> int: 314 | return self._extra_ids 315 | 316 | def encode(self, s: Union[Sequence[int], str]) -> Sequence[int]: 317 | """Tokenizes string to an int sequence, without adding EOS.""" 318 | return self._encode(s) 319 | 320 | def decode(self, ids: Iterable[int]): 321 | """Detokenizes int32 iterable to a string, up through first EOS.""" 322 | clean_ids = list(ids) 323 | 324 | if self.unk_id is not None: 325 | vocab_size = self._base_vocab_size 326 | clean_ids = [self.unk_id if i >= vocab_size else i for i in clean_ids] 327 | 328 | if self.eos_id is not None and self.eos_id in clean_ids: 329 | clean_ids = clean_ids[: clean_ids.index(self.eos_id) + 1] 330 | 331 | return self._decode(clean_ids) 332 | 333 | @property 334 | def tf_tokenizer(self): 335 | """Instantiate and return a TF tokenizer.""" 336 | # TF tokenize is not used in the pytorch version, so import here to keep the 337 | # dependency optional 338 | import tensorflow_text as tf_text 339 | return tf_text.SentencepieceTokenizer(model=self.sp_model) 340 | 341 | def encode_tf(self, s: tf.Tensor) -> tf.Tensor: 342 | """Tokenizes string Scalar to an int32 Tensor, without adding EOS.""" 343 | return self._encode_tf(s) 344 | 345 | def decode_tf(self, ids: tf.Tensor) -> tf.Tensor: 346 | """Detokenizes int32 batched Tensor through first EOS.""" 347 | clean_ids = ids 348 | 349 | if self.unk_id is not None: 350 | base_vocab_size = self._base_vocab_size 351 | clean_ids = tf.where( 352 | tf.less(clean_ids, base_vocab_size), clean_ids, self.unk_id 353 | ) 354 | 355 | if self.eos_id is not None: 356 | after_eos = tf.cumsum( 357 | 
tf.cast(tf.equal(clean_ids, self.eos_id), tf.int32), 358 | exclusive=True, 359 | axis=-1, 360 | ) 361 | clean_ids = tf.where(tf.cast(after_eos, tf.bool), self.pad_id, clean_ids) 362 | 363 | return self._decode_tf(clean_ids) 364 | 365 | def _encode_tf(self, s): 366 | return self.tf_tokenizer.tokenize(s) 367 | 368 | def _decode_tf(self, ids): 369 | return self.tf_tokenizer.detokenize(ids) 370 | --------------------------------------------------------------------------------