├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── pix2struct ├── __init__.py ├── configs │ ├── __init__.py │ ├── init │ │ ├── __init__.py │ │ ├── pix2struct_base_init.gin │ │ ├── pix2struct_large_init.gin │ │ ├── random_init.gin │ │ ├── warmup_base_init.gin │ │ └── warmup_large_init.gin │ ├── models │ │ ├── __init__.py │ │ ├── pix2struct.gin │ │ └── t5_1_1_flaxformer.gin │ ├── optimizers │ │ ├── __init__.py │ │ └── adafactor.gin │ ├── runs │ │ ├── __init__.py │ │ ├── eval.gin │ │ ├── inference.gin │ │ └── train.gin │ ├── schedules │ │ ├── __init__.py │ │ ├── ai2d.gin │ │ ├── chartqa.gin │ │ ├── docvqa.gin │ │ ├── infographicvqa.gin │ │ ├── ocrvqa.gin │ │ ├── refexp.gin │ │ ├── screen2words.gin │ │ ├── textcaps.gin │ │ └── widget_captioning.gin │ └── sizes │ │ ├── __init__.py │ │ ├── base.gin │ │ ├── large.gin │ │ └── tiny.gin ├── demo.py ├── demo_utils.py ├── example_inference.py ├── inference_utils.py ├── metrics.py ├── metrics_test.py ├── models.py ├── models_test.py ├── postprocessors.py ├── preprocessing │ ├── __init__.py │ ├── convert_ai2d.py │ ├── convert_chartqa.py │ ├── convert_docvqa.py │ ├── convert_ocrvqa.py │ ├── convert_refexp.py │ ├── convert_screen2words.py │ ├── convert_textcaps.py │ ├── convert_widget_captioning.py │ └── preprocessing_utils.py ├── preprocessors.py ├── preprocessors_test.py ├── tasks.py ├── transfer_utils.py └── web │ ├── static │ └── style.css │ └── templates │ └── demo_screenshot.html └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled python modules. 2 | *.pyc 3 | 4 | # Byte-compiled 5 | __pycache__/ 6 | .cache/ 7 | 8 | # Poetry, setuptools, PyPI distribution artifacts. 9 | /*.egg-info 10 | .eggs/ 11 | build/ 12 | dist/ 13 | poetry.lock 14 | 15 | # Tests 16 | .pytest_cache/ 17 | 18 | # Type checking 19 | .pytype/ 20 | 21 | # Other 22 | *.DS_Store 23 | 24 | # PyCharm 25 | .idea 26 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | External contributions are not accepted, sorry! 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pix2Struct 2 | This repository contains code for [Pix2Struct: Screenshot Parsing as Pretraining 3 | for Visual Language Understanding](https://arxiv.org/abs/2210.03347). 4 | 5 | We release pretrained checkpoints for the Base and Large models and code for 6 | finetuning them on the nine downstream tasks discussed in the paper. 7 | We are unable to release the pretraining data, but it can be replicated using 8 | the publicly available URLs released in the 9 | [C4 dataset](https://www.tensorflow.org/datasets/catalog/c4). 10 | 11 | # Getting Started 12 | Clone the GitHub repository, install the `pix2struct` package, and run 13 | the tests to ensure that all dependencies were successfully installed. 14 | 15 | ``` 16 | git clone https://github.com/google-research/pix2struct.git 17 | cd pix2struct 18 | conda create -n pix2struct python=3.9 19 | conda activate pix2struct 20 | pip install -e ."[dev]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 21 | pytest 22 | ``` 23 | 24 | You may first need to install Java (`sudo apt install default-jre`) and 25 | [conda](https://conda.io/projects/conda/en/latest/user-guide/install/index.html) 26 | if not already installed. 27 | 28 | We will be using Google Cloud Storage (GCS) for data and model storage. For the 29 | remaining documentation we will assume that the path to your own bucket and 30 | directory is in the `PIX2STRUCT_DIR` environment variable: 31 | 32 | ``` 33 | export PIX2STRUCT_DIR="gs://<your_bucket>/<dir_name>" 34 | ``` 35 | 36 | The code for running experiments assumes this environment variable when looking 37 | for the preprocessed data. 38 | 39 | # Data Preprocessing 40 | 41 | Our data preprocessing scripts are run with [Dataflow](https://cloud.google.com/dataflow/docs/quickstarts/create-pipeline-python) 42 | by default using the [Apache Beam library](https://cloud.google.com/dataflow/docs/concepts/beam-programming-model). 43 | They can also be run locally by omitting the flags appearing after `--`. 44 | 45 | For the remaining documentation we will assume that GCP project information is 46 | in the following environment variables: 47 | 48 | ``` 49 | export GCP_PROJECT=<your_project> 50 | export GCP_REGION=<your_region> 51 | ``` 52 | 53 | Below are the commands required to preprocess each dataset. The results will 54 | be written to `$PIX2STRUCT_DIR/data/<task>/preprocessed/`, which is the 55 | file structure assumed in `tasks.py`. 56 | 57 | ## TextCaps 58 | ``` 59 | mkdir -p data/textcaps 60 | cd data/textcaps 61 | curl -O https://dl.fbaipublicfiles.com/textvqa/data/textcaps/TextCaps_0.1_train.json 62 | curl -O https://dl.fbaipublicfiles.com/textvqa/data/textcaps/TextCaps_0.1_val.json 63 | curl -O https://dl.fbaipublicfiles.com/textvqa/data/textcaps/TextCaps_0.1_test.json 64 | curl -O https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip 65 | curl -O https://dl.fbaipublicfiles.com/textvqa/images/test_images.zip 66 | unzip train_val_images.zip 67 | rm train_val_images.zip 68 | unzip test_images.zip 69 | rm test_images.zip 70 | cd .. 
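# Note: the two steps below follow the same pattern as every converter in this
# README. First copy the raw data into GCS with gsutil, then launch the Apache
# Beam conversion job. The flags after `--` select the DataflowRunner; omitting
# them runs the pipeline locally on Beam's default DirectRunner instead.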
71 | gsutil -m cp -r textcaps $PIX2STRUCT_DIR/data/textcaps 72 | python -m pix2struct.preprocessing.convert_textcaps \ 73 | --textcaps_dir=$PIX2STRUCT_DIR/data/textcaps \ 74 | --output_dir=$PIX2STRUCT_DIR/data/textcaps/preprocessed \ 75 | -- \ 76 | --runner=DataflowRunner \ 77 | --save_main_session \ 78 | --project=$GCP_PROJECT \ 79 | --region=$GCP_REGION \ 80 | --temp_location=$PIX2STRUCT_DIR/data/temp \ 81 | --staging_location=$PIX2STRUCT_DIR/data/staging \ 82 | --setup_file=./setup.py 83 | ``` 84 | 85 | ## ChartQA 86 | ``` 87 | mkdir -p data/chartqa 88 | cd data/chartqa 89 | git clone https://github.com/vis-nlp/ChartQA.git 90 | cp -r ChartQA/ChartQA\ Dataset/* ./ 91 | rm -rf ChartQA 92 | cd .. 93 | gsutil -m cp -r chartqa $PIX2STRUCT_DIR/data/chartqa 94 | python -m pix2struct.preprocessing.convert_chartqa \ 95 | --data_dir=$PIX2STRUCT_DIR/data/chartqa \ 96 | -- \ 97 | --runner=DataflowRunner \ 98 | --save_main_session \ 99 | --project=$GCP_PROJECT \ 100 | --region=$GCP_REGION \ 101 | --temp_location=$PIX2STRUCT_DIR/data/temp \ 102 | --staging_location=$PIX2STRUCT_DIR/data/staging \ 103 | --setup_file=./setup.py 104 | ``` 105 | 106 | ## RICO Images 107 | Screen2Words, RefExp, and Widget Captioning all require images from the RICO 108 | dataset. If you'd like to use any of these datasets, please process RICO images 109 | before proceeding. 110 | 111 | ``` 112 | cd data 113 | wget https://storage.googleapis.com/crowdstf-rico-uiuc-4540/rico_dataset_v0.1/unique_uis.tar.gz 114 | tar xvfz unique_uis.tar.gz 115 | rm unique_uis.tar.gz 116 | gsutil -m cp -r combined $PIX2STRUCT_DIR/data/rico_images 117 | ``` 118 | 119 | ## Widget Captioning 120 | If you haven't already set up RICO, please do so before you proceed. 121 | 122 | ``` 123 | mkdir -p data/widget_captioning 124 | cd data/widget_captioning 125 | git clone https://github.com/google-research-datasets/widget-caption.git 126 | cp widget-caption/widget_captions.csv ./ 127 | cp widget-caption/split/*.txt ./ 128 | mv dev.txt val.txt 129 | rm -rf widget-caption 130 | cd .. 131 | gsutil -m cp -r widget_captioning $PIX2STRUCT_DIR/data/widget_captioning 132 | python -m pix2struct.preprocessing.convert_widget_captioning \ 133 | --data_dir=$PIX2STRUCT_DIR/data/widget_captioning \ 134 | --image_dir=$PIX2STRUCT_DIR/data/rico_images \ 135 | -- \ 136 | --runner=DataflowRunner \ 137 | --save_main_session \ 138 | --project=$GCP_PROJECT \ 139 | --region=$GCP_REGION \ 140 | --temp_location=$PIX2STRUCT_DIR/data/temp \ 141 | --staging_location=$PIX2STRUCT_DIR/data/staging \ 142 | --setup_file=./setup.py 143 | ``` 144 | 145 | ## Screen2Words 146 | If you haven't already set up RICO, please do so before you proceed. 147 | 148 | ``` 149 | cd data 150 | git clone https://github.com/google-research-datasets/screen2words.git 151 | gsutil -m cp -r screen2words $PIX2STRUCT_DIR/data/screen2words 152 | python -m pix2struct.preprocessing.convert_screen2words \ 153 | --screen2words_dir=$PIX2STRUCT_DIR/data/screen2words \ 154 | --rico_dir=$PIX2STRUCT_DIR/data/rico_images \ 155 | -- \ 156 | --runner=DataflowRunner \ 157 | --save_main_session \ 158 | --project=$GCP_PROJECT \ 159 | --region=$GCP_REGION \ 160 | --temp_location=$PIX2STRUCT_DIR/data/temp \ 161 | --staging_location=$PIX2STRUCT_DIR/data/staging \ 162 | --setup_file=./setup.py 163 | ``` 164 | 165 | ## RefExp 166 | If you haven't already set up RICO, please do so before you proceed. 
167 | 168 | ``` 169 | mkdir -p data/refexp 170 | cd data/refexp 171 | wget https://github.com/google-research-datasets/uibert/raw/main/ref_exp/train.tfrecord 172 | wget https://github.com/google-research-datasets/uibert/raw/main/ref_exp/dev.tfrecord 173 | wget https://github.com/google-research-datasets/uibert/raw/main/ref_exp/test.tfrecord 174 | mv dev.tfrecord val.tfrecord 175 | cd .. 176 | gsutil -m cp -r refexp $PIX2STRUCT_DIR/data/refexp 177 | python -m pix2struct.preprocessing.convert_refexp \ 178 | --data_dir=$PIX2STRUCT_DIR/data/refexp \ 179 | --image_dir=$PIX2STRUCT_DIR/data/rico_images \ 180 | -- \ 181 | --runner=DataflowRunner \ 182 | --save_main_session \ 183 | --project=$GCP_PROJECT \ 184 | --region=$GCP_REGION \ 185 | --temp_location=$PIX2STRUCT_DIR/data/temp \ 186 | --staging_location=$PIX2STRUCT_DIR/data/staging \ 187 | --setup_file=./setup.py 188 | ``` 189 | 190 | ## DocVQA 191 | ``` 192 | mkdir -p data/docvqa 193 | cd data/docvqa 194 | ``` 195 | Download DocVQA (Single Document Visual Question Answering) from 196 | [the official source](https://rrc.cvc.uab.es/?ch=17&com=downloads) (requires 197 | registration). The following steps assume that the train/val/test.tar.gz files 198 | are in `data/docvqa`. 199 | 200 | ``` 201 | tar xvf train.tar.gz 202 | tar xvf val.tar.gz 203 | tar xvf test.tar.gz 204 | rm -r *.tar.gz */ocr_results 205 | 206 | cd .. 207 | gsutil -m cp -r docvqa $PIX2STRUCT_DIR/data/docvqa 208 | python -m pix2struct.preprocessing.convert_docvqa \ 209 | --data_dir=$PIX2STRUCT_DIR/data/docvqa \ 210 | -- \ 211 | --runner=DataflowRunner \ 212 | --save_main_session \ 213 | --project=$GCP_PROJECT \ 214 | --region=$GCP_REGION \ 215 | --temp_location=$PIX2STRUCT_DIR/data/temp \ 216 | --staging_location=$PIX2STRUCT_DIR/data/staging \ 217 | --setup_file=./setup.py 218 | ``` 219 | 220 | ## InfographicVQA 221 | ``` 222 | mkdir -p data/infographicvqa 223 | cd data/infographicvqa 224 | ``` 225 | Download InfographicVQA Task 1 from [this](https://rrc.cvc.uab.es/?ch=17&com=downloads) 226 | website (requires registration). The following steps assume that the 227 | `train/val/test.json` and the `zip` files are in `data/infographicvqa`. Since InfographicVQA shares the DocVQA data format, the same `convert_docvqa` script is used below. 228 | 229 | ``` 230 | for split in train val test 231 | do 232 | unzip infographicVQA_${split}_v1.0_images.zip 233 | mv infographicVQA_${split}_v1.0_images $split 234 | mv infographicVQA_${split}_v1.0.json $split/${split}_v1.0.json 235 | done 236 | rm *.zip 237 | 238 | cd .. 239 | gsutil -m cp -r infographicvqa $PIX2STRUCT_DIR/data/infographicvqa 240 | python -m pix2struct.preprocessing.convert_docvqa \ 241 | --data_dir=$PIX2STRUCT_DIR/data/infographicvqa \ 242 | -- \ 243 | --runner=DataflowRunner \ 244 | --save_main_session \ 245 | --project=$GCP_PROJECT \ 246 | --region=$GCP_REGION \ 247 | --temp_location=$PIX2STRUCT_DIR/data/temp \ 248 | --staging_location=$PIX2STRUCT_DIR/data/staging \ 249 | --setup_file=./setup.py 250 | ``` 251 | 252 | ## OCR-VQA 253 | ``` 254 | mkdir -p data/ocrvqa 255 | cd data/ocrvqa 256 | ``` 257 | Follow instructions on the [OCR-VQA](https://ocr-vqa.github.io/) website to 258 | download the data into `data/ocrvqa` (requires crawling). The following steps 259 | assume that `data/ocrvqa` contains a directory called `images` and a file called 260 | `dataset.json`. 261 | 262 | ``` 263 | cd .. 
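# At this point `data/ocrvqa` should contain the crawled `images` directory
# and `dataset.json` described above; the converter below expects exactly that
# layout under $PIX2STRUCT_DIR/data/ocrvqa after the copy.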
264 | gsutil -m cp -r ocrvqa $PIX2STRUCT_DIR/data/ocrvqa 265 | python -m pix2struct.preprocessing.convert_ocrvqa \ 266 | --data_dir=$PIX2STRUCT_DIR/data/ocrvqa \ 267 | -- \ 268 | --runner=DataflowRunner \ 269 | --save_main_session \ 270 | --project=$GCP_PROJECT \ 271 | --region=$GCP_REGION \ 272 | --temp_location=$PIX2STRUCT_DIR/data/temp \ 273 | --staging_location=$PIX2STRUCT_DIR/data/staging \ 274 | --setup_file=./setup.py 275 | ``` 276 | 277 | ## AI2D 278 | ``` 279 | mkdir -p data/ 280 | cd data/ 281 | wget https://ai2-public-datasets.s3.amazonaws.com/diagrams/ai2d-all.zip 282 | unzip ai2d-all.zip 283 | rm ai2d-all.zip 284 | gsutil -m cp -r ai2d $PIX2STRUCT_DIR/data/ai2d 285 | python -m pix2struct.preprocessing.convert_ai2d \ 286 | --data_dir=$PIX2STRUCT_DIR/data/ai2d \ 287 | --test_ids_path=gs://pix2struct-data/ai2d_test_ids.csv \ 288 | -- \ 289 | --runner=DataflowRunner \ 290 | --save_main_session \ 291 | --project=$GCP_PROJECT \ 292 | --region=$GCP_REGION \ 293 | --temp_location=$PIX2STRUCT_DIR/data/temp \ 294 | --staging_location=$PIX2STRUCT_DIR/data/staging \ 295 | --setup_file=./setup.py 296 | ``` 297 | 298 | # Running experiments 299 | 300 | The main experiments are implemented as a light wrapper around the 301 | [T5X](https://github.com/google-research/t5x) library. For brevity, we 302 | illustrate an example workflow of finetuning the pretrained base Pix2Struct 303 | model on the Screen2Words dataset. To scale up to larger setups, please refer 304 | to the T5X documentation. 305 | 306 | ## Setting up the TPU 307 | 308 | Follow the official [instructions](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm) 309 | for running JAX on a Cloud TPU VM, which allows you to `ssh` directly into the 310 | TPU host. 311 | 312 | In this example, we are using a `v3-8` TPU: 313 | 314 | ``` 315 | TPU_TYPE=v3-8 316 | TPU_NAME=pix2struct-$TPU_TYPE 317 | TPU_ZONE=europe-west4-a 318 | gcloud compute tpus tpu-vm create $TPU_NAME \ 319 | --zone=$TPU_ZONE \ 320 | --accelerator-type=$TPU_TYPE \ 321 | --version=tpu-vm-base 322 | gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$TPU_ZONE 323 | ``` 324 | 325 | Once you have `ssh`ed into the TPU host, follow the "Getting Started" 326 | instructions to install the `pix2struct` package. 327 | 328 | ## Training 329 | The following command will initiate the training loop, which consists of train 330 | steps interleaved with evaluations on the validation set. 331 | 332 | ``` 333 | python -m t5x.train \ 334 | --gin_search_paths="pix2struct/configs" \ 335 | --gin_file="models/pix2struct.gin" \ 336 | --gin_file="runs/train.gin" \ 337 | --gin_file="sizes/base.gin" \ 338 | --gin_file="optimizers/adafactor.gin" \ 339 | --gin_file="schedules/screen2words.gin" \ 340 | --gin_file="init/pix2struct_base_init.gin" \ 341 | --gin.MIXTURE_OR_TASK_NAME="'screen2words'" \ 342 | --gin.MODEL_DIR="'$PIX2STRUCT_DIR/experiments/screen2words_base'" \ 343 | --gin.TASK_FEATURE_LENGTHS="{'inputs': 4096, 'targets': 128}" \ 344 | --gin.BATCH_SIZE=32 345 | ``` 346 | 347 | ## Evaluation 348 | The following command evaluates the model on the test set. You will need to 349 | replace the checkpoint path with the one that was actually selected based on the 350 | validation performance. 
351 | 352 | ``` 353 | python -m t5x.eval \ 354 | --gin_search_paths="pix2struct/configs" \ 355 | --gin_file="models/pix2struct.gin" \ 356 | --gin_file="runs/eval.gin" \ 357 | --gin_file="sizes/base.gin" \ 358 | --gin.MIXTURE_OR_TASK_NAME="'screen2words'" \ 359 | --gin.CHECKPOINT_PATH="'$PIX2STRUCT_DIR/experiments/screen2words_base/checkpoint_286600'" \ 360 | --gin.EVAL_OUTPUT_DIR="'$PIX2STRUCT_DIR/experiments/test_exp/test_eval'" \ 361 | --gin.EVAL_SPLIT="'test'" \ 362 | --gin.TASK_FEATURE_LENGTHS="{'inputs': 4096, 'targets': 128}" \ 363 | --gin.BATCH_SIZE=32 364 | ``` 365 | 366 | ## Finetuned Checkpoints 367 | In addition to the pretrained checkpoints released and specified in the 368 | `configs/init` directory, we also release checkpoints for models finetuned 369 | on each of the tasks below. 370 | 371 | | Task | GCS Path (Base) | GCS Path (Large) | 372 | | -----------------| ------------------------------------------------------------- | -------------------------------------------------------------- | 373 | | TextCaps | `gs://pix2struct-data/textcaps_base/checkpoint_280400` | `gs://pix2struct-data/textcaps_large/checkpoint_180600` | 374 | | ChartQA | `gs://pix2struct-data/chartqa_base/checkpoint_287600` | `gs://pix2struct-data/chartqa_large/checkpoint_182600` | 375 | | WidgetCaptioning | `gs://pix2struct-data/widget_captioning_base/checkpoint_281600` | `gs://pix2struct-data/widget_captioning_large/checkpoint_181600` | 376 | | Screen2Words | `gs://pix2struct-data/screen2words_base/checkpoint_282600` | `gs://pix2struct-data/screen2words_large/checkpoint_183000` | 377 | | RefExp | `gs://pix2struct-data/refexp_base/checkpoint_290000` | `gs://pix2struct-data/refexp_large/checkpoint_187800` | 378 | | DocVQA | `gs://pix2struct-data/docvqa_base/checkpoint_284400` | `gs://pix2struct-data/docvqa_large/checkpoint_184000` | 379 | | InfographicVQA | `gs://pix2struct-data/infographicvqa_base/checkpoint_284000` | `gs://pix2struct-data/infographicvqa_large/checkpoint_182000` | 380 | | OCR-VQA | `gs://pix2struct-data/ocrvqa_base/checkpoint_290000` | `gs://pix2struct-data/ocrvqa_large/checkpoint_188400` | 381 | | AI2D | `gs://pix2struct-data/ai2d_base/checkpoint_284400` | `gs://pix2struct-data/ai2d_large/checkpoint_184000` | 382 | 383 | These checkpoints are compatible with the eval command documented above and the 384 | two ways of performing inference mentioned below. Please ensure that the config 385 | file under `configs/sizes` is set to be consistent with the checkpoint. 386 | 387 | 388 | ## Inference 389 | 390 | We provide two ways of performing inference. For testing and demoing purposes, 391 | these may be run on CPU. In that case, please set the `JAX_PLATFORMS` 392 | environment variable to `cpu`. 393 | 394 | ### Command-line example 395 | 396 | We provide a minimal script for performing inference on a single example. This 397 | path has only been tested at extremely small scale and is not meant for 398 | larger-scale inference. For large-scale inference, we recommend setting up a custom 399 | task with placeholder labels and running the evaluation script (`t5x.eval`) as 400 | documented above. 401 | 402 | In the following example, we show the command for predicting the caption of an 403 | image using a base-sized checkpoint finetuned on the TextCaps task. For a task 404 | that also accepts textual prompts such as questions in VQA, you can also supply 405 | the question via the `text` flag (in addition to specifying the image with the 406 | `image` flag). 
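If you prefer to drive inference from Python rather than through the script, the utilities that `example_inference.py` and `demo.py` are built on can be called directly. The following is a minimal, unofficial sketch of that pattern; the gin files, bindings, and checkpoint mirror the command below, and the local image path is a placeholder.

```
# Unofficial sketch: single-image inference from Python, mirroring demo.py.
import gin
from pix2struct import demo_utils
from pix2struct import inference_utils
from t5x import gin_utils

# Make the inference-fn factory gin-configurable, then parse the same gin
# files and bindings that the command-line example below passes as flags.
get_inference_fns = gin.configurable(inference_utils.get_inference_fns)
gin_utils.parse_gin_flags(
    gin_search_paths=["pix2struct/configs"],
    gin_files=["models/pix2struct.gin", "runs/inference.gin", "sizes/base.gin"],
    gin_bindings=[
        "MIXTURE_OR_TASK_NAME = 'placeholder_pix2struct'",
        "TASK_FEATURE_LENGTHS = {'inputs': 2048, 'targets': 128}",
        "BATCH_SIZE = 1",
        "CHECKPOINT_PATH = 'gs://pix2struct-data/textcaps_base/checkpoint_280400'",
    ],
)
predict_fn = get_inference_fns()["predict"]

with open("test_image.jpg", "rb") as f:  # placeholder image path
  image_bytes = f.read()
print(demo_utils.apply_single_inference(predict_fn, image_bytes))
```

The command-line invocation itself is: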
407 | 408 | ``` 409 | python -m pix2struct.example_inference \ 410 | --gin_search_paths="pix2struct/configs" \ 411 | --gin_file=models/pix2struct.gin \ 412 | --gin_file=runs/inference.gin \ 413 | --gin_file=sizes/base.gin \ 414 | --gin.MIXTURE_OR_TASK_NAME="'placeholder_pix2struct'" \ 415 | --gin.TASK_FEATURE_LENGTHS="{'inputs': 2048, 'targets': 128}" \ 416 | --gin.BATCH_SIZE=1 \ 417 | --gin.CHECKPOINT_PATH="'gs://pix2struct-data/textcaps_base/checkpoint_280400'" \ 418 | --image=$HOME/test_image.jpg 419 | ``` 420 | 421 | ### Web Demo 422 | 423 | For a more user-friendly demo, we also provide a web-based alternative to the 424 | inference script above. While this command is running, the web demo can be accessed 425 | at `localhost:8080` (or any port specified via the `port` flag), assuming you 426 | are running the demo locally. You can then upload your custom image and optional 427 | prompt instead of specifying them via the command line. 428 | 429 | ``` 430 | python -m pix2struct.demo \ 431 | --gin_search_paths="pix2struct/configs" \ 432 | --gin_file=models/pix2struct.gin \ 433 | --gin_file=runs/inference.gin \ 434 | --gin_file=sizes/base.gin \ 435 | --gin.MIXTURE_OR_TASK_NAME="'placeholder_pix2struct'" \ 436 | --gin.TASK_FEATURE_LENGTHS="{'inputs': 2048, 'targets': 128}" \ 437 | --gin.BATCH_SIZE=1 \ 438 | --gin.CHECKPOINT_PATH="'gs://pix2struct-data/textcaps_base/checkpoint_280400'" 439 | ``` 440 | 441 | ## Clean up 442 | When you are done with your TPU VM, remember to delete the instance: 443 | 444 | ``` 445 | gcloud compute tpus tpu-vm delete $TPU_NAME --zone=$TPU_ZONE 446 | ``` 447 | 448 | # Note 449 | 450 | *This is not an officially supported Google product.* 451 | -------------------------------------------------------------------------------- /pix2struct/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """pix2struct API.""" 16 | 17 | # A new PyPI release will be pushed every time `__version__` is increased 18 | __version__ = '0.1.0' 19 | -------------------------------------------------------------------------------- /pix2struct/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # This empty file is needed for loading the gin files in this directory. 16 | -------------------------------------------------------------------------------- /pix2struct/configs/init/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # This empty file is needed for loading the gin files in this directory. 16 | -------------------------------------------------------------------------------- /pix2struct/configs/init/pix2struct_base_init.gin: -------------------------------------------------------------------------------- 1 | INITIAL_CHECKPOINT_PATH = 'gs://pix2struct-data/pix2struct_base' 2 | INIT_STEPS = 280000 3 | -------------------------------------------------------------------------------- /pix2struct/configs/init/pix2struct_large_init.gin: -------------------------------------------------------------------------------- 1 | INITIAL_CHECKPOINT_PATH = 'gs://pix2struct-data/pix2struct_large' 2 | INIT_STEPS = 180000 3 | -------------------------------------------------------------------------------- /pix2struct/configs/init/random_init.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | from t5x import utils 3 | 4 | INITIAL_CHECKPOINT_PATH = None 5 | INIT_STEPS = 0 6 | utils.CheckpointConfig: 7 | restore = None -------------------------------------------------------------------------------- /pix2struct/configs/init/warmup_base_init.gin: -------------------------------------------------------------------------------- 1 | INITIAL_CHECKPOINT_PATH = 'gs://pix2struct-data/warmup_base' 2 | INIT_STEPS = 30000 3 | -------------------------------------------------------------------------------- /pix2struct/configs/init/warmup_large_init.gin: -------------------------------------------------------------------------------- 1 | INITIAL_CHECKPOINT_PATH = 'gs://pix2struct-data/warmup_large' 2 | INIT_STEPS = 30000 3 | -------------------------------------------------------------------------------- /pix2struct/configs/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # This empty file is needed for loading the gin files in this directory. 
16 | -------------------------------------------------------------------------------- /pix2struct/configs/models/pix2struct.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | import seqio 4 | from t5x import adafactor 5 | from t5x import utils 6 | from pix2struct import models 7 | from flaxformer.architectures.t5 import t5_architecture 8 | from flaxformer.components import dense 9 | from flaxformer.components import embedding 10 | from flaxformer.components import layer_norm 11 | 12 | include 'pix2struct/configs/models/t5_1_1_flaxformer.gin' 13 | 14 | NUM_EMBEDDINGS = 50244 15 | ACTIVATION_PARTITIONING_DIMS = None 16 | OPTIMIZER = None 17 | 18 | seqio.PassThroughVocabulary: 19 | size = 0 20 | 21 | embedding.PositionEmbed: 22 | num_embeddings = 4096 23 | features = %EMBED_DIM 24 | dtype = %ACTIVATION_DTYPE 25 | 26 | patch_projection/dense.DenseGeneral: 27 | features = %EMBED_DIM 28 | use_bias = True 29 | dtype = %ACTIVATION_DTYPE 30 | kernel_axis_names = ['embed', 'mlp'] 31 | name = 'patch_projection' 32 | 33 | models.PatchEmbed: 34 | num_extra_embedders = 2 # rows and columns 35 | patch_projection_factory = @patch_projection/dense.DenseGeneral 36 | embedder_factory = @embedding.PositionEmbed 37 | 38 | models.ImageEncoder: 39 | num_layers = %NUM_ENCODER_LAYERS 40 | layer_factory = @t5_architecture.EncoderLayer 41 | input_dropout_factory = %DROPOUT_FACTORY 42 | output_dropout_factory = %DROPOUT_FACTORY 43 | layer_norm_factory = @layer_norm.T5LayerNorm 44 | token_embedder_factory = @models.PatchEmbed 45 | shared_relative_position_bias_factory = None 46 | dtype = %ACTIVATION_DTYPE 47 | 48 | t5_architecture.Decoder: 49 | token_embedder_factory = @embedding.Embed 50 | 51 | models.ImageEncoderTextDecoder: 52 | encoder_factory = @models.ImageEncoder 53 | decoder_factory = @t5_architecture.Decoder 54 | dtype = %ACTIVATION_DTYPE 55 | 56 | seqio.SentencePieceVocabulary: 57 | sentencepiece_model_file = "gs://pix2struct-data/sentencepiece.model" 58 | 59 | MODEL = @models.ImageToTextModel() 60 | models.ImageToTextModel: 61 | module = @models.ImageEncoderTextDecoder() 62 | input_vocabulary = @seqio.PassThroughVocabulary() 63 | output_vocabulary = @seqio.SentencePieceVocabulary() 64 | optimizer_def = %OPTIMIZER 65 | z_loss = 0.0001 66 | -------------------------------------------------------------------------------- /pix2struct/configs/models/t5_1_1_flaxformer.gin: -------------------------------------------------------------------------------- 1 | # Flaxformer implementation of T5.1.1 architecture. 2 | # 3 | # Required to be overridden: 4 | # 5 | # - NUM_ENCODER_LAYERS 6 | # - NUM_DECODER_LAYERS 7 | # - NUM_HEADS 8 | # - HEAD_DIM 9 | # - EMBED_DIM 10 | # - MLP_DIM 11 | from __gin__ import dynamic_registration 12 | 13 | from flax import linen 14 | 15 | from flaxformer.architectures.t5 import t5_architecture 16 | from flaxformer.components.attention import dense_attention 17 | from flaxformer.components import dense 18 | from flaxformer.components import embedding 19 | from flaxformer.components import layer_norm 20 | from flaxformer.components import relative_position_biases 21 | 22 | # Must be overridden. 
23 | NUM_ENCODER_LAYERS = %gin.REQUIRED 24 | NUM_DECODER_LAYERS = %gin.REQUIRED 25 | NUM_HEADS = %gin.REQUIRED 26 | HEAD_DIM = %gin.REQUIRED 27 | EMBED_DIM = %gin.REQUIRED 28 | MLP_DIM = %gin.REQUIRED 29 | NUM_EMBEDDINGS = %gin.REQUIRED 30 | 31 | # Constants (may be overridden) 32 | ACTIVATION_DTYPE = 'bfloat16' 33 | ACTIVATION_PARTITIONING_DIMS = 1 34 | SCALE = 1.0 35 | DROPOUT_RATE = 0.0 36 | 37 | # Macros 38 | BIAS_INIT = @bias_init/linen.initializers.normal() 39 | bias_init/linen.initializers.normal.stddev = 1e-6 40 | DROPOUT_FACTORY = @dropout_factory/linen.Dropout 41 | dropout_factory/linen.Dropout: 42 | rate = %DROPOUT_RATE 43 | broadcast_dims = (-2,) 44 | 45 | # Architecture (Flax Module) 46 | ARCHITECTURE = @t5_architecture.EncoderDecoder() 47 | t5_architecture.EncoderDecoder: 48 | encoder_factory = @t5_architecture.Encoder 49 | decoder_factory = @t5_architecture.Decoder 50 | shared_token_embedder_factory = @embedding.Embed 51 | dtype = %ACTIVATION_DTYPE 52 | 53 | # Encoder 54 | t5_architecture.Encoder: 55 | num_layers = %NUM_ENCODER_LAYERS 56 | layer_factory = @t5_architecture.EncoderLayer 57 | input_dropout_factory = %DROPOUT_FACTORY 58 | output_dropout_factory = %DROPOUT_FACTORY 59 | layer_norm_factory = @layer_norm.T5LayerNorm 60 | position_embedder_factory = None 61 | shared_relative_position_bias_factory = @relative_position_biases.RelativePositionBiases 62 | dtype = %ACTIVATION_DTYPE 63 | 64 | # Encoder Layer 65 | t5_architecture.EncoderLayer: 66 | attention = @dense_attention.MultiHeadDotProductAttention() 67 | mlp = @dense.MlpBlock() 68 | dropout_factory = %DROPOUT_FACTORY 69 | layer_norm_factory = @layer_norm.T5LayerNorm 70 | activation_partitioning_dims = %ACTIVATION_PARTITIONING_DIMS 71 | 72 | # Decoder 73 | t5_architecture.Decoder: 74 | num_layers = %NUM_DECODER_LAYERS 75 | layer_factory = @t5_architecture.DecoderLayer 76 | dropout_factory = %DROPOUT_FACTORY 77 | layer_norm_factory = @layer_norm.T5LayerNorm 78 | position_embedder_factory = None 79 | shared_relative_position_bias_factory = @relative_position_biases.RelativePositionBiases 80 | output_logits_factory = @output_logits/dense.DenseGeneral 81 | dtype = %ACTIVATION_DTYPE 82 | 83 | # Decoupled embedding 84 | output_logits/dense.DenseGeneral: 85 | features = %NUM_EMBEDDINGS 86 | use_bias = False 87 | dtype = 'float32' 88 | kernel_init = @output_logits_kernel_init/linen.initializers.variance_scaling() 89 | bias_init = %BIAS_INIT 90 | kernel_axis_names = ["embed", "vocab"] 91 | output_logits_kernel_init/linen.initializers.variance_scaling: 92 | scale = %SCALE 93 | mode = 'fan_in' 94 | distribution = 'truncated_normal' 95 | 96 | # Decoder Layer 97 | t5_architecture.DecoderLayer: 98 | self_attention = @dense_attention.MultiHeadDotProductAttention() 99 | encoder_decoder_attention = @dense_attention.MultiHeadDotProductAttention() 100 | mlp = @dense.MlpBlock() 101 | dropout_factory = %DROPOUT_FACTORY 102 | layer_norm_factory = @layer_norm.T5LayerNorm 103 | activation_partitioning_dims = %ACTIVATION_PARTITIONING_DIMS 104 | 105 | # Token Embedder (shared) 106 | embedding.Embed: 107 | num_embeddings= %NUM_EMBEDDINGS 108 | features = %EMBED_DIM 109 | cast_input_dtype = 'int32' 110 | dtype = %ACTIVATION_DTYPE 111 | attend_dtype = 'float32' # for logit training stability 112 | embedding_init = @token_embedder_init/linen.initializers.normal() 113 | one_hot = True 114 | name = 'token_embedder' 115 | token_embedder_init/linen.initializers.normal.stddev = 1.0 116 | 117 | # Attention (encoder, decoder, self-attention) 118 | 
dense_attention.MultiHeadDotProductAttention: 119 | num_heads = %NUM_HEADS 120 | dtype = %ACTIVATION_DTYPE 121 | head_dim = %HEAD_DIM 122 | kernel_init = @attention_kernel_init/linen.initializers.variance_scaling() 123 | bias_init = %BIAS_INIT 124 | use_bias = False 125 | broadcast_dropout = True 126 | dropout_rate = %DROPOUT_RATE 127 | attention_kernel_init/linen.initializers.variance_scaling: 128 | scale = %SCALE 129 | mode = 'fan_in' 130 | distribution = 'normal' 131 | 132 | # Relative position biases (encoder, decoder) 133 | relative_position_biases.RelativePositionBiases: 134 | num_heads = %NUM_HEADS 135 | dtype = %ACTIVATION_DTYPE 136 | num_buckets = 32 137 | max_distance = 128 138 | embedding_init = @relative_position_bias_init/linen.initializers.variance_scaling() 139 | relative_position_bias_init/linen.initializers.variance_scaling: 140 | scale = %SCALE 141 | mode = 'fan_avg' 142 | distribution = 'uniform' 143 | 144 | # MLP (encoder, decoder) 145 | dense.MlpBlock: 146 | use_bias = False 147 | intermediate_dim = %MLP_DIM 148 | activations = ('gelu', 'linear') 149 | kernel_init = @mlp_kernel_init/linen.initializers.variance_scaling() 150 | bias_init = %BIAS_INIT 151 | intermediate_dropout_rate = %DROPOUT_RATE 152 | final_dropout_rate = 0 153 | dtype = %ACTIVATION_DTYPE 154 | mlp_kernel_init/linen.initializers.variance_scaling: 155 | scale = %SCALE 156 | mode = 'fan_in' 157 | distribution = 'truncated_normal' 158 | 159 | layer_norm.T5LayerNorm.dtype = %ACTIVATION_DTYPE 160 | -------------------------------------------------------------------------------- /pix2struct/configs/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # This empty file is needed for loading the gin files in this directory. 16 | -------------------------------------------------------------------------------- /pix2struct/configs/optimizers/adafactor.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | 3 | from t5x import adafactor 4 | from t5x import trainer 5 | from t5x import utils 6 | import optax 7 | from pix2struct import transfer_utils 8 | 9 | OPTIMIZER = @adafactor.Adafactor() 10 | 11 | adafactor.Adafactor: 12 | weight_decay_rate = 1e-5 13 | 14 | transfer_utils.transfer_warmup_cosine_decay_schedule: 15 | start_step = %INIT_STEPS 16 | peak_value = 1e-2 17 | warmup_steps = 1000 18 | end_step = %TRAIN_STEPS 19 | 20 | trainer.Trainer: 21 | learning_rate_fn = @transfer_utils.transfer_warmup_cosine_decay_schedule() 22 | -------------------------------------------------------------------------------- /pix2struct/configs/runs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # This empty file is needed for loading the gin files in this directory. 16 | -------------------------------------------------------------------------------- /pix2struct/configs/runs/eval.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | from t5x import utils 3 | import pix2struct.tasks 4 | 5 | include 't5x/configs/runs/eval.gin' 6 | 7 | USE_CACHED_TASKS = False 8 | EVAL_SPLIT = 'validation' 9 | 10 | utils.DatasetConfig: 11 | split = %EVAL_SPLIT 12 | batch_size = %BATCH_SIZE 13 | use_memory_cache = False 14 | -------------------------------------------------------------------------------- /pix2struct/configs/runs/inference.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | from t5x import partitioning 3 | from pix2struct import inference_utils 4 | import pix2struct.tasks 5 | 6 | TASK_FEATURE_LENGTHS = %gin.REQUIRED 7 | BATCH_SIZE = %gin.REQUIRED 8 | CHECKPOINT_PATH = %gin.REQUIRED 9 | MIXTURE_OR_TASK_NAME = %gin.REQUIRED 10 | 11 | inference_utils.get_inference_fns: 12 | task_name = %MIXTURE_OR_TASK_NAME 13 | batch_size = %BATCH_SIZE 14 | sequence_length = %TASK_FEATURE_LENGTHS 15 | model = %MODEL 16 | checkpoint_path = %CHECKPOINT_PATH 17 | partitioner = @partitioning.PjitPartitioner() 18 | 19 | partitioning.PjitPartitioner: 20 | num_partitions = 1 21 | logical_axis_rules = @partitioning.standard_logical_axis_rules() 22 | -------------------------------------------------------------------------------- /pix2struct/configs/runs/train.gin: -------------------------------------------------------------------------------- 1 | from __gin__ import dynamic_registration 2 | import __main__ as train_script 3 | from t5x import checkpoints 4 | from t5x import utils 5 | import pix2struct.tasks 6 | from pix2struct import transfer_utils 7 | 8 | include 't5x/configs/runs/finetune.gin' 9 | 10 | CHECKPOINT_PERIOD = %gin.REQUIRED 11 | EVAL_PERIOD = %gin.REQUIRED 12 | EVALUATOR_NUM_EXAMPLES = %gin.REQUIRED 13 | STAGE_STEPS = %gin.REQUIRED 14 | METRIC_NAME = %gin.REQUIRED 15 | 16 | DROPOUT_RATE = 0.0 17 | USE_CACHED_TASKS = False 18 | 19 | train_script.train: 20 | eval_period = %EVAL_PERIOD 21 | 22 | train/utils.DatasetConfig: 23 | pack = False 24 | 25 | train_eval/utils.DatasetConfig: 26 | pack = False 27 | 28 | infer_eval/utils.DatasetConfig: 29 | task_feature_lengths = %TASK_FEATURE_LENGTHS 30 | 31 | utils.CheckpointConfig: 32 | restore = @transfer_utils.TransferRestoreCheckpointConfig() 33 | 34 | transfer_utils.TransferRestoreCheckpointConfig: 35 | path = %INITIAL_CHECKPOINT_PATH 36 | mode = 'specific' 37 | dtype = 'float32' 38 | steps = %INIT_STEPS 39 | 40 | compute_train_steps/transfer_utils.add: 41 | b = %INIT_STEPS 42 | a = %STAGE_STEPS 43 | TRAIN_STEPS = @compute_train_steps/transfer_utils.add() 44 | 45 | utils.SaveCheckpointConfig: 46 | 
period = %CHECKPOINT_PERIOD 47 | keep = 1 48 | save_dataset = False 49 | checkpointer_cls = @checkpoints.SaveBestCheckpointer 50 | 51 | checkpoints.SaveBestCheckpointer: 52 | metric_name_to_monitor = %METRIC_NAME 53 | metric_mode = 'max' 54 | -------------------------------------------------------------------------------- /pix2struct/configs/schedules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # This empty file is needed for loading the gin files in this directory. 16 | -------------------------------------------------------------------------------- /pix2struct/configs/schedules/ai2d.gin: -------------------------------------------------------------------------------- 1 | EVALUATOR_NUM_EXAMPLES = None 2 | STAGE_STEPS = 5000 3 | CHECKPOINT_PERIOD = 200 4 | EVAL_PERIOD = 200 5 | METRIC_NAME = 'inference_eval/ai2d/eval/exact_match' 6 | -------------------------------------------------------------------------------- /pix2struct/configs/schedules/chartqa.gin: -------------------------------------------------------------------------------- 1 | EVALUATOR_NUM_EXAMPLES = None 2 | STAGE_STEPS = 10000 3 | CHECKPOINT_PERIOD = 200 4 | EVAL_PERIOD = 200 5 | METRIC_NAME = 'inference_eval/chartqa_human/eval/relaxed_accuracy' 6 | -------------------------------------------------------------------------------- /pix2struct/configs/schedules/docvqa.gin: -------------------------------------------------------------------------------- 1 | EVALUATOR_NUM_EXAMPLES = None 2 | STAGE_STEPS = 10000 3 | CHECKPOINT_PERIOD = 200 4 | EVAL_PERIOD = 200 5 | METRIC_NAME = 'inference_eval/docvqa/eval/anls' 6 | -------------------------------------------------------------------------------- /pix2struct/configs/schedules/infographicvqa.gin: -------------------------------------------------------------------------------- 1 | EVALUATOR_NUM_EXAMPLES = None 2 | STAGE_STEPS = 10000 3 | CHECKPOINT_PERIOD = 200 4 | EVAL_PERIOD = 200 5 | METRIC_NAME = 'inference_eval/infographicvqa/eval/anls' 6 | -------------------------------------------------------------------------------- /pix2struct/configs/schedules/ocrvqa.gin: -------------------------------------------------------------------------------- 1 | EVALUATOR_NUM_EXAMPLES = 5000 2 | STAGE_STEPS = 10000 3 | CHECKPOINT_PERIOD = 200 4 | EVAL_PERIOD = 200 5 | METRIC_NAME = 'inference_eval/ocrvqa/eval/exact_match' 6 | -------------------------------------------------------------------------------- /pix2struct/configs/schedules/refexp.gin: -------------------------------------------------------------------------------- 1 | EVALUATOR_NUM_EXAMPLES = None 2 | STAGE_STEPS = 10000 3 | CHECKPOINT_PERIOD = 200 4 | EVAL_PERIOD = 200 5 | METRIC_NAME = 'inference_eval/refexp_infer/eval/group_accuracy' 6 | -------------------------------------------------------------------------------- 
/pix2struct/configs/schedules/screen2words.gin: -------------------------------------------------------------------------------- 1 | EVALUATOR_NUM_EXAMPLES = 5000 2 | STAGE_STEPS = 10000 3 | CHECKPOINT_PERIOD = 200 4 | EVAL_PERIOD = 200 5 | METRIC_NAME = 'inference_eval/screen2words/eval/cider' 6 | -------------------------------------------------------------------------------- /pix2struct/configs/schedules/textcaps.gin: -------------------------------------------------------------------------------- 1 | EVALUATOR_NUM_EXAMPLES = 5000 2 | STAGE_STEPS = 10000 3 | CHECKPOINT_PERIOD = 200 4 | EVAL_PERIOD = 200 5 | METRIC_NAME = 'inference_eval/textcaps/eval/cider' 6 | -------------------------------------------------------------------------------- /pix2struct/configs/schedules/widget_captioning.gin: -------------------------------------------------------------------------------- 1 | EVALUATOR_NUM_EXAMPLES = None 2 | STAGE_STEPS = 5000 3 | CHECKPOINT_PERIOD = 200 4 | EVAL_PERIOD = 200 5 | METRIC_NAME = 'inference_eval/widget_captioning/eval/cider' 6 | -------------------------------------------------------------------------------- /pix2struct/configs/sizes/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # This empty file is needed for loading the gin files in this directory. 16 | -------------------------------------------------------------------------------- /pix2struct/configs/sizes/base.gin: -------------------------------------------------------------------------------- 1 | NUM_ENCODER_LAYERS = 12 2 | NUM_DECODER_LAYERS = 12 3 | NUM_HEADS = 12 4 | HEAD_DIM = 64 5 | MLP_DIM = 2048 6 | EMBED_DIM = 768 7 | -------------------------------------------------------------------------------- /pix2struct/configs/sizes/large.gin: -------------------------------------------------------------------------------- 1 | NUM_ENCODER_LAYERS = 18 2 | NUM_DECODER_LAYERS = 18 3 | NUM_HEADS = 24 4 | HEAD_DIM = 64 5 | MLP_DIM = 3968 6 | EMBED_DIM = 1536 -------------------------------------------------------------------------------- /pix2struct/configs/sizes/tiny.gin: -------------------------------------------------------------------------------- 1 | NUM_ENCODER_LAYERS = 2 2 | NUM_DECODER_LAYERS = 2 3 | NUM_HEADS = 2 4 | HEAD_DIM = 2 5 | MLP_DIM = 2 6 | EMBED_DIM = 2 7 | -------------------------------------------------------------------------------- /pix2struct/demo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Web demo of Pix2Struct.""" 16 | import base64 17 | import html 18 | import io 19 | import os 20 | import wsgiref.simple_server 21 | 22 | from absl import flags 23 | import gin 24 | import jinja2 25 | import PIL.Image 26 | from pix2struct import demo_utils 27 | from pix2struct import inference_utils 28 | from t5x import gin_utils 29 | import tensorflow as tf 30 | import tornado.web 31 | import tornado.wsgi 32 | 33 | 34 | flags.DEFINE_multi_string("gin_file", None, "Gin files.") 35 | flags.DEFINE_multi_string("gin_bindings", [], "Individual gin bindings.") 36 | flags.DEFINE_list("gin_search_paths", ["."], "Gin search paths.") 37 | flags.DEFINE_integer("port", 8080, "Port number for localhost.") 38 | 39 | 40 | FLAGS = flags.FLAGS 41 | 42 | 43 | class ScreenshotHandler(tornado.web.RequestHandler): 44 | """Main handler.""" 45 | _tmpl = None 46 | _demo_fn = None 47 | 48 | def initialize(self, 49 | env=None, 50 | demo_fn=None): 51 | self._demo_fn = demo_fn 52 | self._tmpl = env.get_template("demo_screenshot.html") 53 | 54 | def get(self): 55 | self.post() 56 | 57 | def post(self): 58 | if "image" in self.request.files: 59 | image_bytes = self.request.files["image"][0]["body"] 60 | image_bytes = demo_utils.maybe_add_question( 61 | question=self.get_argument("question", default=""), 62 | image_bytes=image_bytes, 63 | ) 64 | if self._demo_fn is None: 65 | raise ValueError("self._demo_fn is None") 66 | prediction = html.escape( 67 | demo_utils.apply_single_inference(self._demo_fn, image_bytes) 68 | ) 69 | image = tf.compat.as_str(base64.b64encode(image_bytes)) 70 | else: 71 | prediction = "" 72 | image = "" 73 | if self._tmpl is None: 74 | raise ValueError("self._tmpl is None") 75 | self.write(self._tmpl.render( 76 | image=image, 77 | prediction=prediction)) 78 | 79 | 80 | def main(_): 81 | get_demo_fns_using_gin = gin.configurable(inference_utils.get_inference_fns) 82 | gin_utils.parse_gin_flags( 83 | gin_search_paths=FLAGS.gin_search_paths, 84 | gin_files=FLAGS.gin_file, 85 | gin_bindings=FLAGS.gin_bindings) 86 | demo_fn = get_demo_fns_using_gin()["predict"] 87 | 88 | print("Warming up demo function...") 89 | placeholder_bytes = io.BytesIO() 90 | PIL.Image.new("RGB", size=(1, 1)).save(placeholder_bytes, "png") 91 | demo_utils.apply_single_inference(demo_fn, placeholder_bytes.getvalue()) 92 | print("Done warming up demo function.") 93 | 94 | web_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "web") 95 | env = jinja2.Environment( 96 | loader=jinja2.FileSystemLoader(os.path.join(web_path, "templates"))) 97 | application = tornado.wsgi.WSGIApplication([ 98 | (r"/", ScreenshotHandler, { 99 | "env": env, 100 | "demo_fn": demo_fn, 101 | }), 102 | (r"/static/(.*)", tornado.web.StaticFileHandler, { 103 | "path": os.path.join(web_path, "static") 104 | }) 105 | ]) 106 | server = wsgiref.simple_server.make_server("", FLAGS.port, application) 107 | print("") 108 | server.serve_forever() 109 | 110 | if __name__ == "__main__": 111 | gin_utils.run(main) 112 | -------------------------------------------------------------------------------- 
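Note: once demo.py above is serving, ScreenshotHandler reads a multipart "image" file and an optional "question" form field from the POST request. A minimal client-side sketch, not part of the repository; it assumes the server is running on localhost:8080, the third-party `requests` package is installed, and `screenshot.png` is a hypothetical local file:

import requests

with open("screenshot.png", "rb") as f:
    response = requests.post(
        "http://localhost:8080/",
        # Read by the handler via self.request.files["image"][0]["body"].
        files={"image": f},
        # Optional; rendered onto the image as a header by maybe_add_question.
        data={"question": "What is the title?"},
    )

# The response body is the rendered demo_screenshot.html, which contains the
# HTML-escaped prediction.
print(response.text)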
/pix2struct/demo_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Demo utils.""" 16 | 17 | import io 18 | from typing import Any, Callable, Iterable 19 | 20 | import PIL.Image 21 | from pix2struct.preprocessing import preprocessing_utils 22 | import tensorflow as tf 23 | 24 | 25 | def maybe_add_question(question, image_bytes): 26 | if question: 27 | # If it exists, add a question as a header. 28 | image = PIL.Image.open(io.BytesIO(image_bytes)).convert("RGB") 29 | output_image = preprocessing_utils.render_header(image, question) 30 | output_image_bytes = io.BytesIO() 31 | output_image.save(output_image_bytes, format="PNG") 32 | return output_image_bytes.getvalue() 33 | else: 34 | return image_bytes 35 | 36 | 37 | def apply_single_inference( 38 | inference_fn: Callable[[tf.data.Dataset], Iterable[Any]], image_bytes: bytes 39 | ) -> Any: 40 | dataset = tf.data.Dataset.from_tensors( 41 | {"id": "", "group_id": "", "image": image_bytes, "parse": [""]} 42 | ) 43 | return next(iter(inference_fn(dataset))) 44 | -------------------------------------------------------------------------------- /pix2struct/example_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Example script for Pix2Struct Inference.""" 16 | from absl import flags 17 | import gin 18 | from pix2struct import demo_utils 19 | from pix2struct import inference_utils 20 | from t5x import gin_utils 21 | import tensorflow as tf 22 | 23 | flags.DEFINE_multi_string("gin_file", None, "Gin files.") 24 | flags.DEFINE_multi_string("gin_bindings", [], "Individual gin bindings.") 25 | flags.DEFINE_list("gin_search_paths", ["."], "Gin search paths.") 26 | flags.DEFINE_string("image", "", "Path to the image file.") 27 | flags.DEFINE_string("text", None, "Optional text (e.g. 
question).") 28 | 29 | FLAGS = flags.FLAGS 30 | 31 | 32 | def main(_) -> None: 33 | get_inference_fns_using_gin = gin.configurable( 34 | inference_utils.get_inference_fns 35 | ) 36 | gin_utils.parse_gin_flags( 37 | gin_search_paths=FLAGS.gin_search_paths, 38 | gin_files=FLAGS.gin_file, 39 | gin_bindings=FLAGS.gin_bindings, 40 | ) 41 | inference_fns = get_inference_fns_using_gin() 42 | predict_fn = inference_fns["predict"] 43 | 44 | with tf.io.gfile.GFile(FLAGS.image, "rb") as f: 45 | image_bytes = f.read() 46 | image_bytes = demo_utils.maybe_add_question(FLAGS.text, image_bytes) 47 | prediction = demo_utils.apply_single_inference(predict_fn, image_bytes) 48 | print(prediction) 49 | 50 | if __name__ == "__main__": 51 | gin_utils.run(main) 52 | -------------------------------------------------------------------------------- /pix2struct/inference_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Inference utils.""" 16 | 17 | from typing import Any, Callable, Dict, Iterable, Mapping 18 | import jax 19 | import seqio 20 | from t5x import models 21 | from t5x import partitioning 22 | from t5x import utils 23 | import tensorflow as tf 24 | 25 | 26 | def get_inference_fns( 27 | task_name: str, 28 | batch_size: int, 29 | sequence_length: Mapping[str, int], 30 | model: models.BaseTransformerModel, 31 | checkpoint_path: str, 32 | partitioner: partitioning.BasePartitioner 33 | ) -> Dict[str, Callable[[tf.data.Dataset], Iterable[Any]]]: 34 | """Get inference function.""" 35 | task = seqio.get_mixture_or_task(task_name) 36 | feature_converter = model.FEATURE_CONVERTER_CLS(pack=False) 37 | 38 | def _task_to_dataset(t: seqio.Task) -> tf.data.Dataset: 39 | d = t.get_dataset( 40 | sequence_length=sequence_length, 41 | split=task.splits[0], 42 | shuffle=False, 43 | num_epochs=1, 44 | use_cached=False) 45 | return feature_converter(d, sequence_length) 46 | 47 | input_shapes = { 48 | k: (batch_size,) + spec.shape for k, spec in 49 | _task_to_dataset(task).element_spec.items() 50 | } 51 | train_state_initializer = utils.TrainStateInitializer( # pytype: disable=wrong-arg-types # jax-array 52 | optimizer_def=None, 53 | init_fn=model.get_initial_variables, 54 | input_shapes=input_shapes, 55 | partitioner=partitioner) 56 | restore_checkpoint_cfg = utils.RestoreCheckpointConfig( 57 | path=checkpoint_path, 58 | mode="specific", 59 | strict=False) 60 | 61 | train_state = train_state_initializer.from_checkpoint( 62 | [restore_checkpoint_cfg]) 63 | assert train_state is not None 64 | 65 | def _dataset_to_batches(dataset: tf.data.Dataset) -> Iterable[Any]: 66 | temp_task = seqio.Task( 67 | name="tmp", 68 | source=seqio.FunctionDataSource( 69 | dataset_fn=lambda split, shuffle_files: dataset, 70 | splits=["tmp"]), 71 | output_features=task.output_features, 72 | preprocessors=task.preprocessors) # pytype: disable=attribute-error # 
always-use-return-annotations 73 | temp_dataset = _task_to_dataset(temp_task) 74 | temp_dataset = temp_dataset.batch(batch_size) 75 | return temp_dataset.as_numpy_iterator() 76 | 77 | predict_batch_jit = jax.jit(model.predict_batch) 78 | score_batch_jit = jax.jit(model.score_batch) 79 | 80 | vocabulary = task.output_features["targets"].vocabulary 81 | def _predict(dataset: tf.data.Dataset) -> Iterable[str]: 82 | for batch in _dataset_to_batches(dataset): 83 | for token_ids in predict_batch_jit(train_state.params, batch): 84 | yield vocabulary.decode(token_ids) 85 | 86 | def _intermediates(dataset: tf.data.Dataset) -> Iterable[Any]: 87 | for batch in _dataset_to_batches(dataset): 88 | _, intermediates = score_batch_jit( 89 | train_state.params, batch, return_intermediates=True) 90 | yield batch, intermediates 91 | 92 | return { 93 | "predict": _predict, 94 | "intermediates": _intermediates, 95 | } 96 | -------------------------------------------------------------------------------- /pix2struct/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Metrics.""" 16 | import collections 17 | import itertools 18 | from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union 19 | 20 | from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer 21 | from pycocoevalcap.cider.cider import Cider 22 | import editdistance 23 | 24 | 25 | def aggregate_metrics( 26 | targets: Sequence[Sequence[str]], 27 | predictions: Sequence[str], 28 | metric_fn: Callable[[str, str], Any], 29 | normalize_fn: Callable[[str], str] = lambda v: v) -> float: 30 | """Aggregate target-prediction pair metrics over a dataset.""" 31 | assert len(targets) == len(predictions) 32 | total = 0 33 | for prediction, target in zip(predictions, targets): 34 | p = normalize_fn(prediction) 35 | total += max(metric_fn(normalize_fn(t), p) for t in target) 36 | return (100.0 * total) / len(targets) 37 | 38 | 39 | def cider( 40 | targets: Sequence[Sequence[str]], 41 | predictions: Sequence[str]) -> float: 42 | """Compute CIDEr score.""" 43 | coco_tokenizer = PTBTokenizer() 44 | 45 | scorer = Cider() 46 | avg_score, _ = scorer.compute_score( 47 | gts=coco_tokenizer.tokenize({ 48 | str(i): [{"caption": t} for t in target] 49 | for i, target in enumerate(targets) 50 | }), 51 | res=coco_tokenizer.tokenize({ 52 | str(i): [{"caption": prediction}] 53 | for i, prediction in enumerate(predictions) 54 | })) 55 | return float(avg_score) * 100.0 56 | 57 | 58 | def anls_metric(target: str, prediction: str, theta: float = 0.5): 59 | """Calculates ANLS for DocVQA. 60 | 61 | There does not seem to be an official evaluation script. 
62 | Public implementation on which this implementation is based: 63 | https://github.com/herobd/layoutlmv2/blob/main/eval_docvqa.py#L92 64 | 65 | Original paper (see Eq 1): https://arxiv.org/pdf/1907.00490.pdf 66 | 67 | Args: 68 | target: Target string. 69 | prediction: Predicted string. 70 | theta: Filter threshold set to 0.5 for DocVQA. 71 | 72 | Returns: 73 | ANLS score. 74 | """ 75 | 76 | edit_distance = editdistance.eval(target, prediction) 77 | normalized_ld = edit_distance / max(len(target), len(prediction)) 78 | return 1 - normalized_ld if normalized_ld < theta else 0 79 | 80 | 81 | def relaxed_correctness(target: str, 82 | prediction: str, 83 | max_relative_change: float = 0.05) -> bool: 84 | """Calculates relaxed correctness. 85 | 86 | The correctness tolerates a certain error ratio defined by max_relative_change. 87 | See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1: 88 | “Following Methani et al. (2020), we use a relaxed accuracy measure for the 89 | numeric answers to allow a minor inaccuracy that may result from the automatic 90 | data extraction process. We consider an answer to be correct if it is within 91 | 5% of the gold answer. For non-numeric answers, we still need an exact match 92 | to consider an answer to be correct.” 93 | 94 | Args: 95 | target: Target string. 96 | prediction: Predicted string. 97 | max_relative_change: Maximum relative change. 98 | 99 | Returns: 100 | Whether the prediction was correct given the specified tolerance. 101 | """ 102 | 103 | def _to_float(text: str) -> Optional[float]: 104 | try: 105 | if text.endswith("%"): 106 | # Convert percentages to floats. 107 | return float(text.rstrip("%")) / 100.0 108 | else: 109 | return float(text) 110 | except ValueError: 111 | return None 112 | 113 | prediction_float = _to_float(prediction) 114 | target_float = _to_float(target) 115 | if prediction_float is not None and target_float: 116 | relative_change = abs(prediction_float - target_float) / abs(target_float) 117 | return relative_change <= max_relative_change 118 | else: 119 | return prediction.lower() == target.lower() 120 | 121 | 122 | def pix2struct_metrics( 123 | targets: Sequence[Sequence[str]], 124 | predictions: Sequence[str]) -> Mapping[str, float]: 125 | """Calculates evaluation metrics. 126 | 127 | Args: 128 | targets: list of list of strings. 129 | predictions: list of strings. 130 | 131 | Returns: 132 | dictionary with metric names as keys and metric values as values. 133 | """ 134 | return dict( 135 | exact_match=aggregate_metrics( 136 | targets=targets, 137 | predictions=predictions, 138 | metric_fn=lambda x, y: x == y), 139 | anls=aggregate_metrics( 140 | targets=targets, 141 | predictions=predictions, 142 | metric_fn=anls_metric, 143 | normalize_fn=lambda v: v.lower()), 144 | relaxed_accuracy=aggregate_metrics( 145 | targets=targets, 146 | predictions=predictions, 147 | metric_fn=relaxed_correctness), 148 | cider=cider( 149 | targets=targets, 150 | predictions=predictions)) 151 | 152 | 153 | def instance_ranking_metrics( 154 | targets: List[Dict[str, Any]], 155 | predictions: List[str], 156 | aux_values: Dict[str, Any], 157 | group_fn: Callable[[Any], Any], 158 | correct_fn: Callable[[Any], bool], 159 | ranking_fn: Callable[[str, float], Any], 160 | return_correctness: bool = False 161 | ) -> Union[Mapping[str, float], Tuple[Mapping[str, float], List[bool]]]: 162 | """Compute accuracy of instance ranking. 163 | 164 | Args: 165 | targets: List of dictionaries after the postprocessor is applied.
166 | predictions: List of predicted strings. 167 | aux_values: Dictionary where the "scores" entry has a list of float scores. 168 | group_fn: Function that maps a target to a grouping key. 169 | correct_fn: Function that maps a target to a boolean indicating correctness. 170 | Must return `True` for exactly one instance per group. 171 | ranking_fn: Function that maps a (prediction, score) pair to something 172 | that can be used as a key to rank instances. 173 | return_correctness: Whether or not to also return a list of judgments 174 | about correctness. Used for testing only. 175 | Returns: 176 | Dictionary with metric names as keys and metric values as values. Optionally 177 | also returns a list of correctness if specified. 178 | """ 179 | Instance = collections.namedtuple( 180 | "Instance", ["target", "prediction", "score"]) 181 | assert len(targets) == len(predictions) == len(aux_values["scores"]) 182 | instances = [Instance(t, p, s) for t, p, s in 183 | zip(targets, predictions, aux_values["scores"])] 184 | is_correct = [] 185 | total_groups = 0 186 | for _, group in itertools.groupby( 187 | sorted(instances, key=lambda i: group_fn(i.target)), 188 | lambda i: group_fn(i.target)): 189 | group = list(group) 190 | best_idx, _ = max( 191 | enumerate(group), 192 | key=lambda idx_i: ranking_fn(idx_i[1].prediction, idx_i[1].score)) 193 | (true_idx,) = [idx for idx, i in enumerate(group) 194 | if correct_fn(i.target)] 195 | is_correct.append(best_idx == true_idx) 196 | total_groups += 1 197 | eval_dict = dict( 198 | group_accuracy=sum(is_correct) * 100.0 / total_groups, 199 | total_groups=total_groups) 200 | if return_correctness: 201 | return eval_dict, is_correct 202 | else: 203 | return eval_dict 204 | -------------------------------------------------------------------------------- /pix2struct/metrics_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for metrics.""" 16 | from absl.testing import absltest 17 | from pix2struct import metrics 18 | 19 | 20 | class MetricsTest(absltest.TestCase): 21 | 22 | def test_instance_ranking_metrics(self): 23 | eval_dict, is_correct = metrics.instance_ranking_metrics( 24 | predictions=[ 25 | # Rely on score ranking between 'true' labels. 26 | "true", 27 | "true", 28 | "true", 29 | # Rely on score ranking between 'false' labels. 30 | "false", 31 | "false", 32 | "false", 33 | # Rely on predicted label regardless of score. 34 | "false", 35 | "false", 36 | "true", 37 | # Rely on both predicted label and score. 38 | "false", 39 | "true", 40 | "true", 41 | ], 42 | aux_values={"scores": [ 43 | # Rely on score ranking between all 'true' predictions. 44 | -1, 45 | 1, 46 | 0, 47 | # Rely on score ranking between all 'false' predictions. 48 | 1, 49 | 2, 50 | 3, 51 | # Rely on predicted label regardless of score.
52 | 0, 53 | 0, 54 | 0, 55 | # Rely on both predicted label and score. 56 | 2, 57 | 0, 58 | 1, 59 | ]}, 60 | targets=[ 61 | # Rely on score ranking between 'true' labels. 62 | {"group_id": "0", "id": "0_0", "parse": ["false"]}, 63 | {"group_id": "0", "id": "0_1", "parse": ["true"]}, 64 | {"group_id": "0", "id": "0_2", "parse": ["false"]}, 65 | # Rely on score ranking between 'false' labels. 66 | {"group_id": "1", "id": "1_0", "parse": ["true"]}, 67 | {"group_id": "1", "id": "1_1", "parse": ["false"]}, 68 | {"group_id": "1", "id": "1_2", "parse": ["false"]}, 69 | # Rely on predicted label regardless of score. 70 | {"group_id": "2", "id": "2_0", "parse": ["false"]}, 71 | {"group_id": "2", "id": "2_1", "parse": ["false"]}, 72 | {"group_id": "2", "id": "2_2", "parse": ["true"]}, 73 | # Rely on both predicted label and score. 74 | {"group_id": "3", "id": "3_0", "parse": ["false"]}, 75 | {"group_id": "3", "id": "3_1", "parse": ["false"]}, 76 | {"group_id": "3", "id": "3_2", "parse": ["true"]}, 77 | ], 78 | group_fn=lambda t: t["group_id"], 79 | correct_fn=lambda t: t["parse"][0] == "true", 80 | ranking_fn=lambda p, s: (p == "true", s * (1 if p == "true" else -1)), 81 | return_correctness=True) 82 | self.assertEqual([True, True, True, True], is_correct) 83 | self.assertEqual( 84 | { 85 | "group_accuracy": 100.0, 86 | "total_groups": 4 87 | }, 88 | eval_dict) 89 | 90 | def test_pix2struct_metrics(self): 91 | eval_dict = metrics.pix2struct_metrics( 92 | predictions=[ 93 | "abc", 94 | "abc", 95 | "Abc", 96 | "100%", 97 | "100%", 98 | "100%", 99 | "100%", 100 | "Don't", 101 | ], 102 | targets=[ 103 | ["abc"], 104 | ["Abc"], 105 | ["abc"], 106 | ["96%"], 107 | ["94%"], 108 | ["0.96"], 109 | ["0.94"], 110 | ["Won't"], 111 | ]) 112 | for k, v in { 113 | "exact_match": 12.5, 114 | "anls": 47.5, 115 | "relaxed_accuracy": 62.5, 116 | "cider": 128.6 117 | }.items(): 118 | self.assertAlmostEqual(v, eval_dict[k], places=1) 119 | 120 | if __name__ == "__main__": 121 | absltest.main() 122 | -------------------------------------------------------------------------------- /pix2struct/models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Models.""" 16 | from typing import Callable 17 | 18 | from flax import linen as nn 19 | import jax.numpy as jnp 20 | import numpy as np 21 | import seqio 22 | from t5x import models 23 | import tensorflow as tf 24 | 25 | from flaxformer import types 26 | from flaxformer.architectures.t5 import t5_architecture 27 | from flaxformer.components import embedding 28 | 29 | 30 | class ImageToTextFeatureConverter(seqio.EncDecFeatureConverter): 31 | """Feature converter for an image-to-text encoder-decoder architecture.""" 32 | 33 | TASK_FEATURES = { 34 | "inputs": seqio.FeatureConverter.FeatureSpec(dtype=tf.float32, rank=2), 35 | "targets": seqio.FeatureConverter.FeatureSpec(dtype=tf.int32), 36 | } 37 | MODEL_FEATURES = { 38 | "encoder_input_tokens": seqio.FeatureConverter.FeatureSpec( 39 | dtype=tf.float32, rank=2), 40 | "decoder_target_tokens": seqio.FeatureConverter.FeatureSpec( 41 | dtype=tf.int32), 42 | "decoder_input_tokens": seqio.FeatureConverter.FeatureSpec( 43 | dtype=tf.int32), 44 | "decoder_loss_weights": seqio.FeatureConverter.FeatureSpec( 45 | dtype=tf.int32), 46 | } 47 | 48 | 49 | class ImageToTextModel(models.EncoderDecoderModel): 50 | """ImageToTextModel.""" 51 | 52 | FEATURE_CONVERTER_CLS = ImageToTextFeatureConverter 53 | 54 | 55 | class ImageEncoder(t5_architecture.Encoder): 56 | """ImageEncoder.""" 57 | 58 | def __call__(self, 59 | inputs, 60 | inputs_positions=None, 61 | encoder_mask=None, 62 | *, 63 | segment_ids=None, 64 | enable_dropout: bool = True): 65 | 66 | assert inputs.ndim == 3 67 | # We assume `inputs_positions` and `segment_ids` are not present because 68 | # (1) we do not support packing and (2) positional information is encoded as 69 | # the first several channels of the `inputs`. 70 | assert inputs_positions is None 71 | assert segment_ids is None 72 | 73 | assert encoder_mask is not None 74 | 75 | embedded_inputs = self.embedder(token_ids=inputs) 76 | embedded_inputs = self.input_dropout( 77 | embedded_inputs, deterministic=not enable_dropout) 78 | encoder_outputs = self.encode_from_continuous_inputs( 79 | embedded_inputs, 80 | encoder_mask=encoder_mask, 81 | enable_dropout=enable_dropout) 82 | return encoder_outputs 83 | 84 | 85 | class ImageEncoderTextDecoder(t5_architecture.EncoderDecoder): 86 | """ImageEncoderTextDecoder.""" 87 | 88 | def setup(self): 89 | # Having a shared token embedder for images and text doesn't make sense. 90 | assert self.shared_token_embedder_factory is None 91 | self.token_embedder = None 92 | self.encoder = self.encoder_factory() 93 | self.decoder = self.decoder_factory() 94 | 95 | def _make_padding_attention_mask(self, 96 | query_tokens: types.Array, 97 | key_tokens: types.Array) -> types.Array: 98 | del query_tokens 99 | 100 | # Use padding from the positional information from the first channel to 101 | # detect padding. 102 | row_ids = key_tokens[:, :, 0].astype(jnp.int32) 103 | key_mask = row_ids > 0 104 | 105 | # Add singleton axis -3 for broadcasting to the attention heads and 106 | # singleton axis -2 for broadcasting to the queries. 107 | return jnp.expand_dims(key_mask, axis=(-3, -2)).astype(self.dtype) 108 | 109 | 110 | class PatchEmbed(nn.Module, embedding.Embedder[types.Array]): 111 | """Patch embed.""" 112 | # In addition to the patches with the pixels, the first `num_extra_embedders` 113 | # channels in the inputs are assumed to contain additional ids that represent 114 | # any metadata such as positional information. 
115 | num_extra_embedders: int 116 | embedder_factory: Callable[[], embedding.Embed] 117 | patch_projection_factory: Callable[[], nn.Module] 118 | 119 | def setup(self): 120 | self.patch_projection = self.patch_projection_factory() 121 | self.embedders = [self.embedder_factory() 122 | for _ in range(self.num_extra_embedders)] 123 | 124 | def __call__(self, inputs, **kwargs): 125 | # Inputs: [id_0, id_1, ..., id_{num_extra_embedders - 1}, 126 | # pixel_0, pixel_1, ..., pixel_{patch_size - 1}] 127 | split_inputs = jnp.split( 128 | inputs, np.arange(self.num_extra_embedders) + 1, -1) 129 | 130 | ids = split_inputs[:-1] 131 | patches = split_inputs[-1] 132 | embeddings = [embedder(i.astype(jnp.int32).squeeze(-1)) for embedder, i in 133 | zip(self.embedders, ids)] 134 | embeddings.append(self.patch_projection(patches)) 135 | return sum(embeddings) 136 | -------------------------------------------------------------------------------- /pix2struct/models_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for pix2struct.models.""" 16 | from absl.testing import absltest 17 | import gin 18 | import jax 19 | import numpy as np 20 | from t5x import partitioning 21 | from t5x import trainer as trainer_lib 22 | from t5x import utils 23 | 24 | 25 | class ModelsTest(absltest.TestCase): 26 | 27 | def setUp(self): 28 | super().setUp() 29 | 30 | gin.clear_config() 31 | gin.add_config_file_search_path("pix2struct/configs") 32 | gin.parse_config_file("models/pix2struct.gin") 33 | gin.parse_config_file("optimizers/adafactor.gin") 34 | gin.parse_config_file("sizes/tiny.gin") 35 | gin.parse_config_file("init/random_init.gin") 36 | 37 | # Our Adafactor implementation requires knowing the total number of steps. 38 | # Don't use a real output vocab to keep this test hermetic.
39 | gin.parse_config(""" 40 | TRAIN_STEPS = 1 41 | models.ImageToTextModel.output_vocabulary = @seqio.PassThroughVocabulary() 42 | """) 43 | gin.finalize() 44 | self.model = gin.query_parameter("%MODEL").scoped_configurable_fn() 45 | 46 | self.input_data = { 47 | "encoder_input_tokens": np.ones(shape=(8, 4, 5), dtype=np.float32), 48 | "decoder_input_tokens": np.ones(shape=(8, 3), dtype=np.int32), 49 | "decoder_target_tokens": np.ones(shape=(8, 3), dtype=np.int32) 50 | } 51 | self.partitioner = partitioning.PjitPartitioner(num_partitions=1) 52 | self.train_state_initializer = utils.TrainStateInitializer( 53 | optimizer_def=self.model.optimizer_def, 54 | init_fn=self.model.get_initial_variables, 55 | input_shapes={k: v.shape for k, v in self.input_data.items()}, 56 | partitioner=self.partitioner) 57 | self.train_state = self.train_state_initializer.from_scratch( 58 | jax.random.PRNGKey(0)) 59 | 60 | def test_image_encoder_text_decoder_train(self): 61 | trainer = trainer_lib.Trainer( 62 | self.model, 63 | train_state=self.train_state, 64 | partitioner=self.partitioner, 65 | eval_names=[], 66 | summary_dir=None, 67 | train_state_axes=self.train_state_initializer.train_state_axes, # pytype: disable=attribute-error # jax-api-types 68 | rng=jax.random.PRNGKey(0), 69 | learning_rate_fn=lambda x: 0.001, 70 | num_microbatches=1) 71 | 72 | trainer.train( 73 | batch_iter=iter([self.input_data]), 74 | num_steps=1) 75 | 76 | def test_image_encoder_text_decoder_predict(self): 77 | predictions = self.model.predict_batch( 78 | params=self.train_state.params, 79 | batch=self.input_data) 80 | self.assertSequenceEqual(predictions.shape, [8, 3]) 81 | 82 | if __name__ == "__main__": 83 | absltest.main() 84 | -------------------------------------------------------------------------------- /pix2struct/postprocessors.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Postprocessors.""" 16 | import tensorflow as tf 17 | 18 | 19 | def multi_target(output, example=None, is_target=False): 20 | if is_target: 21 | return [tf.compat.as_text(p) for p in example["parse"]] 22 | return output 23 | 24 | 25 | def group_target(output, example=None, is_target=False): 26 | if is_target: 27 | return { 28 | "group_id": tf.compat.as_text(example["group_id"]), 29 | "parse": [tf.compat.as_text(p) for p in example["parse"]] 30 | } 31 | return output 32 | -------------------------------------------------------------------------------- /pix2struct/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /pix2struct/preprocessing/convert_ai2d.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Convert AI2D data. 16 | """ 17 | import json 18 | import logging 19 | import os 20 | import string 21 | from typing import Iterable 22 | 23 | from absl import app 24 | from absl import flags 25 | import apache_beam as beam 26 | from PIL import Image 27 | from pix2struct.preprocessing import preprocessing_utils 28 | import tensorflow as tf 29 | 30 | flags.DEFINE_string( 31 | "data_dir", 32 | None, 33 | "Directory containing the AI2D data.") 34 | 35 | flags.DEFINE_string( 36 | "test_ids_path", 37 | None, 38 | "Path to CSV file containing the ids of the test datapoints.") 39 | 40 | 41 | def convert(input_path: str, data_dir: str) -> Iterable[tf.train.Example]: 42 | """Convert example.""" 43 | with tf.io.gfile.GFile(input_path) as f: 44 | data = json.load(f) 45 | with tf.io.gfile.GFile( 46 | os.path.join(data_dir, "images", data["imageName"]), "rb") as f: 47 | image = Image.open(f) 48 | with tf.io.gfile.GFile( 49 | os.path.join(data_dir, "annotations", f"{data['imageName']}.json")) as f: 50 | annotation = json.load(f) 51 | 52 | image_with_placeholders = image.copy() 53 | for v in annotation["text"].values(): 54 | preprocessing_utils.render_text_on_bounding_box( 55 | text=v["replacementText"], 56 | bounding_box=v["rectangle"], 57 | image=image_with_placeholders) 58 | 59 | for k, v in data["questions"].items(): 60 | example = tf.train.Example() 61 | 62 | # The `image_id` field is only used to ensure correct splitting of the data. 
63 | preprocessing_utils.add_text_feature(example, "image_id", data["imageName"]) 64 | options = " ".join( 65 | f"({string.ascii_lowercase[i]}) {a}" 66 | for i, a in enumerate(v["answerTexts"]) 67 | ) 68 | 69 | image_with_header = preprocessing_utils.render_header( 70 | image=image_with_placeholders if v["abcLabel"] else image, 71 | header=f"{k} {options}", 72 | ) 73 | preprocessing_utils.add_bytes_feature( 74 | example, "image", preprocessing_utils.image_to_bytes(image_with_header) 75 | ) 76 | parse = v["answerTexts"][v["correctAnswer"]] 77 | preprocessing_utils.add_text_feature(example, "parse", parse) 78 | yield example 79 | 80 | 81 | def pipeline(root): 82 | with tf.io.gfile.GFile(flags.FLAGS.test_ids_path) as f: 83 | test_ids = {f"{l.strip()}.png" for l in f if l.strip()} 84 | _ = (root 85 | | "Create" >> beam.Create(tf.io.gfile.glob( 86 | os.path.join(flags.FLAGS.data_dir, "questions", "*.json"))) 87 | | "Convert" >> beam.FlatMap(convert, data_dir=flags.FLAGS.data_dir) 88 | | "Write" >> preprocessing_utils.SplitAndWriteTFRecords( 89 | output_dir=os.path.join(flags.FLAGS.data_dir, "processed"), 90 | key="image_id", 91 | validation_percent=1, 92 | is_test=lambda x: x in test_ids)) 93 | 94 | 95 | def main(argv): 96 | with beam.Pipeline( 97 | options=beam.options.pipeline_options.PipelineOptions(argv[1:])) as root: 98 | pipeline(root) 99 | 100 | if __name__ == "__main__": 101 | logging.getLogger().setLevel(logging.INFO) 102 | flags.mark_flag_as_required("data_dir") 103 | app.run(main) 104 | -------------------------------------------------------------------------------- /pix2struct/preprocessing/convert_chartqa.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Convert ChartQA data to the common Pix2Struct format. 
16 | """ 17 | import json 18 | import logging 19 | import os 20 | 21 | from absl import app 22 | from absl import flags 23 | import apache_beam as beam 24 | from PIL import Image 25 | from pix2struct.preprocessing import preprocessing_utils 26 | import tensorflow as tf 27 | 28 | flags.DEFINE_string("data_dir", 29 | None, 30 | "Directory containing the ChartQA data.") 31 | 32 | 33 | class ProcessSplit(beam.PTransform): 34 | """Process split.""" 35 | 36 | def __init__(self, split: str, version: str): 37 | self._split = split 38 | self._data_dir = flags.FLAGS.data_dir 39 | self._version = version 40 | 41 | def convert_to_tf_examples( 42 | self, example_id, json_example 43 | ) -> tf.train.Example: 44 | with tf.io.gfile.GFile( 45 | os.path.join( 46 | self._data_dir, self._split, "png", json_example["imgname"] 47 | ), 48 | "rb", 49 | ) as f: 50 | image = Image.open(f) 51 | 52 | tf_example = tf.train.Example() 53 | image_with_question = preprocessing_utils.render_header( 54 | image, json_example["query"] 55 | ) 56 | preprocessing_utils.add_bytes_feature( 57 | tf_example, 58 | "image", 59 | preprocessing_utils.image_to_bytes(image_with_question), 60 | ) 61 | preprocessing_utils.add_text_feature( 62 | tf_example, "id", f"{self._split}_{self._version}_{example_id}" 63 | ) 64 | parse = json_example["label"] 65 | preprocessing_utils.add_text_feature(tf_example, "parse", parse) 66 | return tf_example 67 | 68 | def expand(self, root): 69 | assert self._version in ("human", "augmented") 70 | data_path = os.path.join( 71 | self._data_dir, self._split, f"{self._split}_{self._version}.json" 72 | ) 73 | with tf.io.gfile.GFile(data_path) as data_file: 74 | data = json.load(data_file) 75 | 76 | output_path = os.path.join( 77 | self._data_dir, f"processed_{self._version}", f"{self._split}.tfr" 78 | ) 79 | return ( 80 | root 81 | | "Create" >> beam.Create(enumerate(data)) 82 | | "Convert" >> beam.MapTuple(self.convert_to_tf_examples) 83 | | "Shuffle" >> beam.Reshuffle() 84 | | "Write" 85 | >> beam.io.WriteToTFRecord( 86 | output_path, coder=beam.coders.ProtoCoder(tf.train.Example) 87 | ) 88 | ) 89 | 90 | 91 | def pipeline(root): 92 | _ = root | "ProcessTrainHuman" >> ProcessSplit("train", "human") 93 | _ = root | "ProcessValHuman" >> ProcessSplit("val", "human") 94 | _ = root | "ProcessTestHuman" >> ProcessSplit("test", "human") 95 | 96 | _ = root | "ProcessTrainAugmented" >> ProcessSplit("train", "augmented") 97 | _ = root | "ProcessValAugmented" >> ProcessSplit("val", "augmented") 98 | _ = root | "ProcessTestAugmented" >> ProcessSplit("test", "augmented") 99 | 100 | 101 | def main(argv): 102 | with beam.Pipeline( 103 | options=beam.options.pipeline_options.PipelineOptions(argv[1:])) as root: 104 | pipeline(root) 105 | 106 | if __name__ == "__main__": 107 | logging.getLogger().setLevel(logging.INFO) 108 | flags.mark_flag_as_required("data_dir") 109 | app.run(main) 110 | -------------------------------------------------------------------------------- /pix2struct/preprocessing/convert_docvqa.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Convert DocVQA/InfographicVQA data to the common Pix2Struct format. 16 | """ 17 | 18 | import json 19 | import logging 20 | import os 21 | 22 | from absl import app 23 | from absl import flags 24 | import apache_beam as beam 25 | from PIL import Image 26 | from pix2struct.preprocessing import preprocessing_utils 27 | import tensorflow as tf 28 | 29 | flags.DEFINE_string( 30 | "data_dir", 31 | None, 32 | "Directory containing the DocVQA or InfographicVQA data.") 33 | 34 | 35 | class ProcessSplit(beam.PTransform): 36 | """Process split.""" 37 | 38 | def __init__(self, split: str): 39 | self._split = split 40 | self._data_dir = flags.FLAGS.data_dir 41 | 42 | def read_image(self, filename): 43 | with tf.io.gfile.GFile(os.path.join( 44 | self._data_dir, self._split, filename), "rb") as f: 45 | return Image.open(f) 46 | 47 | def convert_to_tf_examples(self, json_example) -> tf.train.Example: 48 | if "image" in json_example: 49 | image = self.read_image(json_example["image"]) 50 | else: 51 | image = self.read_image(json_example["image_local_name"]) 52 | 53 | tf_example = tf.train.Example() 54 | image_with_question = preprocessing_utils.render_header( 55 | image, json_example["question"] 56 | ) 57 | preprocessing_utils.add_bytes_feature( 58 | tf_example, 59 | "image", 60 | preprocessing_utils.image_to_bytes(image_with_question), 61 | ) 62 | preprocessing_utils.add_text_feature( 63 | tf_example, "id", str(json_example["questionId"]) 64 | ) 65 | # "N/A" parse for the test set where the answers are not available 66 | for parse in json_example.get("answers", ["N/A"]): 67 | preprocessing_utils.add_text_feature(tf_example, "parse", parse) 68 | return tf_example 69 | 70 | def expand(self, root): 71 | data_path = os.path.join( 72 | self._data_dir, self._split, f"{self._split}_v1.0.json" 73 | ) 74 | with tf.io.gfile.GFile(data_path) as data_file: 75 | data = json.load(data_file) 76 | assert data["dataset_name"] in ("docvqa", "infographicVQA") 77 | assert data["dataset_version"] == "1.0" 78 | assert data["dataset_split"] == self._split 79 | 80 | output_path = os.path.join( 81 | self._data_dir, "processed", f"{self._split}.tfr") 82 | return (root 83 | | "Create" >> beam.Create(data["data"]) 84 | | "Convert" >> beam.Map(self.convert_to_tf_examples) 85 | | "Shuffle" >> beam.Reshuffle() 86 | | "Write" >> beam.io.WriteToTFRecord( 87 | output_path, 88 | coder=beam.coders.ProtoCoder(tf.train.Example))) 89 | 90 | 91 | def pipeline(root): 92 | _ = (root | "ProcessTrain" >> ProcessSplit("train")) 93 | _ = (root | "ProcessVal" >> ProcessSplit("val")) 94 | _ = (root | "ProcessTest" >> ProcessSplit("test")) 95 | 96 | 97 | def main(argv): 98 | with beam.Pipeline( 99 | options=beam.options.pipeline_options.PipelineOptions(argv[1:])) as root: 100 | pipeline(root) 101 | 102 | if __name__ == "__main__": 103 | logging.getLogger().setLevel(logging.INFO) 104 | flags.mark_flag_as_required("data_dir") 105 | app.run(main) 106 | -------------------------------------------------------------------------------- /pix2struct/preprocessing/convert_ocrvqa.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Convert OCR-VQA data to the common Pix2Struct format. 16 | """ 17 | 18 | 19 | import json 20 | import logging 21 | import os 22 | from typing import Iterable 23 | 24 | from absl import app 25 | from absl import flags 26 | import apache_beam as beam 27 | from PIL import Image 28 | from pix2struct.preprocessing import preprocessing_utils 29 | import tensorflow as tf 30 | 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | flags.DEFINE_string( 35 | "data_dir", 36 | None, 37 | "Directory containing OCR-VQA data.") 38 | 39 | 40 | class ProcessSplit(beam.PTransform): 41 | """Process split.""" 42 | 43 | # fixed extension since all images have already been converted to .jpg 44 | def __init__(self, split: str, extension: str = ".jpg"): 45 | self._split = split 46 | split_indexes = {"train": 1, "val": 2, "test": 3} 47 | self._split_index = split_indexes[split] 48 | self._extension = extension 49 | self._data_dir = FLAGS.data_dir 50 | self._error_counter = beam.metrics.Metrics.counter( 51 | "example", "errors") 52 | self._processed_counter = beam.metrics.Metrics.counter( 53 | "example", "processed") 54 | 55 | def read_image(self, filename): 56 | with tf.io.gfile.GFile(os.path.join( 57 | self._data_dir, filename), "rb") as f: 58 | return Image.open(f) 59 | 60 | def convert_to_tf_examples(self, json_example) -> Iterable[tf.train.Example]: 61 | """Returns a list, either empty or with a single example.""" 62 | 63 | # OutOfRangeError happens due to unknown reasons on some training examples; 64 | # we discard those. 
65 | try: 66 | image = self.read_image(json_example["image"]) 67 | except tf.errors.OutOfRangeError: 68 | self._error_counter.inc() 69 | return 70 | 71 | tf_example = tf.train.Example() 72 | image_with_question = preprocessing_utils.render_header( 73 | image, json_example["question"]) 74 | preprocessing_utils.add_bytes_feature( 75 | tf_example, "image", 76 | preprocessing_utils.image_to_bytes(image_with_question)) 77 | preprocessing_utils.add_text_feature( 78 | tf_example, "id", str(json_example["questionId"])) 79 | for parse in json_example["answers"]: 80 | preprocessing_utils.add_text_feature(tf_example, "parse", parse) 81 | self._processed_counter.inc() 82 | yield tf_example 83 | 84 | def expand(self, root): 85 | data_path = os.path.join( 86 | self._data_dir, "dataset.json") 87 | with tf.io.gfile.GFile(data_path) as data_file: 88 | json_data = json.load(data_file) 89 | data = [] 90 | question_id = 0 91 | for image_id, image_data in json_data.items(): 92 | if image_data["split"] != self._split_index: 93 | continue 94 | for question, answer in zip(image_data["questions"], 95 | image_data["answers"]): 96 | data.append({ 97 | "question": question, 98 | "questionId": str(question_id), 99 | "answers": [answer], 100 | "image": f"images/{image_id}{self._extension}" 101 | }) 102 | question_id += 1 103 | 104 | output_path = os.path.join( 105 | self._data_dir, "processed", f"{self._split}.tfr") 106 | return (root 107 | | "Create" >> beam.Create(data) 108 | | "Convert" >> beam.FlatMap(self.convert_to_tf_examples) 109 | | "Shuffle" >> beam.Reshuffle() 110 | | "Write" >> beam.io.WriteToTFRecord( 111 | output_path, coder=beam.coders.ProtoCoder(tf.train.Example))) 112 | 113 | 114 | def pipeline(root): 115 | _ = (root | "ProcessTrain" >> ProcessSplit("train")) 116 | _ = (root | "ProcessVal" >> ProcessSplit("val")) 117 | _ = (root | "ProcessTest" >> ProcessSplit("test")) 118 | 119 | 120 | def main(argv): 121 | with beam.Pipeline( 122 | options=beam.options.pipeline_options.PipelineOptions(argv[1:])) as root: 123 | pipeline(root) 124 | 125 | if __name__ == "__main__": 126 | logging.getLogger().setLevel(logging.INFO) 127 | flags.mark_flag_as_required("data_dir") 128 | app.run(main) 129 | -------------------------------------------------------------------------------- /pix2struct/preprocessing/convert_refexp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Convert refexp data to the common Pix2Struct format. 
16 | """ 17 | import logging 18 | import os 19 | import random 20 | from typing import Iterable 21 | 22 | from absl import app 23 | from absl import flags 24 | import apache_beam as beam 25 | import numpy as np 26 | from PIL import Image 27 | from PIL import ImageDraw 28 | from pix2struct.preprocessing import preprocessing_utils 29 | import tensorflow as tf 30 | 31 | flags.DEFINE_string("data_dir", None, "Directory containing the refexp data.") 32 | 33 | flags.DEFINE_string( 34 | "image_dir", 35 | None, 36 | "Directory containing the images referenced in refexp data.") 37 | 38 | flags.DEFINE_integer( 39 | "num_negative_samples", 40 | 5, 41 | "Number of negative samples per instance.") 42 | 43 | 44 | class ProcessSplit(beam.PTransform): 45 | """Process split.""" 46 | 47 | def __init__(self, split: str): 48 | self._split = split 49 | self._data_dir = flags.FLAGS.data_dir 50 | self._image_dir = flags.FLAGS.image_dir 51 | 52 | def get_image(self, image_id): 53 | filename = image_id + ".jpg" 54 | with tf.io.gfile.GFile(os.path.join(self._image_dir, filename), "rb") as f: 55 | return Image.open(f) 56 | 57 | def draw_bounding_box(self, image, candidate_idx, example): 58 | def _get_coordinate(key, max_value): 59 | float_val = example.features.feature[key].float_list.value[candidate_idx] 60 | return round(float_val * max_value) 61 | image_dims = np.asarray(image).shape 62 | xmin = _get_coordinate("image/object/bbox/xmin", image_dims[1]) 63 | xmax = _get_coordinate("image/object/bbox/xmax", image_dims[1]) 64 | ymin = _get_coordinate("image/object/bbox/ymin", image_dims[0]) 65 | ymax = _get_coordinate("image/object/bbox/ymax", image_dims[0]) 66 | img_draw = ImageDraw.Draw(image, "RGBA") 67 | img_draw.rectangle( 68 | xy=((xmin, ymax), 69 | (xmax, ymin)), 70 | fill=(0, 0, 255, 0), 71 | outline=(0, 0, 255, 255)) 72 | return image 73 | 74 | def convert_to_tf_examples(self, record_id, record 75 | ) -> Iterable[tf.train.Example]: 76 | raw_example = tf.train.Example().FromString(record.numpy()) 77 | record_id = record_id.numpy().item() 78 | try: 79 | label = preprocessing_utils.get_int_feature(raw_example, 80 | "image/ref_exp/label") 81 | num_candidates = int( 82 | preprocessing_utils.get_float_feature(raw_example, 83 | "image/object/num")) 84 | query = preprocessing_utils.get_text_feature(raw_example, 85 | "image/ref_exp/text") 86 | image_id = preprocessing_utils.get_text_feature(raw_example, "image/id") 87 | image = self.get_image(image_id) 88 | except (IndexError, tf.errors.NotFoundError): 89 | return 90 | 91 | if flags.FLAGS.num_negative_samples and self._split == "train": 92 | num_negative_samples = flags.FLAGS.num_negative_samples 93 | else: 94 | num_negative_samples = num_candidates 95 | 96 | candidates = list(cand for cand in range(num_candidates) if cand != label) 97 | random.shuffle(candidates) 98 | candidates = candidates[:num_negative_samples] + [label] 99 | for candidate_idx in candidates: 100 | tf_example = tf.train.Example() 101 | candidate_image = image.copy() 102 | candidate_image = self.draw_bounding_box(candidate_image, candidate_idx, 103 | raw_example) 104 | candidate_image = preprocessing_utils.render_header( 105 | candidate_image, query) 106 | is_correct = label == candidate_idx 107 | # pix2struct features 108 | preprocessing_utils.add_bytes_feature( 109 | tf_example, "image", 110 | preprocessing_utils.image_to_bytes(candidate_image)) 111 | preprocessing_utils.add_text_feature( 112 | tf_example, "parse", str(is_correct).lower()) 113 | preprocessing_utils.add_text_feature( 114 | 
tf_example, "id", str(f"{record_id}_{candidate_idx}")) 115 | # pix2box features 116 | preprocessing_utils.add_text_feature( 117 | tf_example, "group_id", str(record_id)) 118 | preprocessing_utils.add_text_feature( 119 | tf_example, "candidate_id", str(candidate_idx)) 120 | yield tf_example 121 | 122 | def expand(self, root): 123 | data_path = os.path.join( 124 | self._data_dir, f"{self._split}.tfrecord") 125 | raw_dataset = tf.data.TFRecordDataset([data_path]) 126 | # get a unique id per record 127 | raw_dataset = raw_dataset.enumerate(start=0) 128 | output_path = os.path.join( 129 | self._data_dir, "processed", f"{self._split}.tfr") 130 | 131 | return (root 132 | | "Create" >> beam.Create(raw_dataset) 133 | | "Convert" >> beam.FlatMapTuple(self.convert_to_tf_examples) 134 | | "Shuffle" >> beam.Reshuffle() 135 | | "Write" >> beam.io.WriteToTFRecord( 136 | output_path, 137 | coder=beam.coders.ProtoCoder(tf.train.Example))) 138 | 139 | 140 | def pipeline(root): 141 | _ = (root | "ProcessTrain" >> ProcessSplit("train")) 142 | _ = (root | "ProcessVal" >> ProcessSplit("val")) 143 | _ = (root | "ProcessTest" >> ProcessSplit("test")) 144 | 145 | 146 | def main(argv): 147 | with beam.Pipeline( 148 | options=beam.options.pipeline_options.PipelineOptions(argv[1:])) as root: 149 | pipeline(root) 150 | 151 | if __name__ == "__main__": 152 | logging.getLogger().setLevel(logging.INFO) 153 | flags.mark_flag_as_required("data_dir") 154 | flags.mark_flag_as_required("image_dir") 155 | app.run(main) 156 | -------------------------------------------------------------------------------- /pix2struct/preprocessing/convert_screen2words.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Convert Screen2Words data to a format that can be streamed. 
16 | """ 17 | import logging 18 | import os 19 | from typing import Iterable, Tuple 20 | 21 | from absl import app 22 | from absl import flags 23 | import apache_beam as beam 24 | from pix2struct.preprocessing import preprocessing_utils 25 | import tensorflow as tf 26 | 27 | flags.DEFINE_string("screen2words_dir", None, 28 | "Directory containing Screen2Words data.") 29 | 30 | flags.DEFINE_string("rico_dir", None, "Directory containing RICO data.") 31 | 32 | 33 | def parse_summary_line(line: str) -> Iterable[Tuple[str, str]]: 34 | line = line.strip() 35 | if line and line != "screenId,summary": 36 | screen_id, summary = line.split(",", 1) 37 | yield screen_id, summary 38 | 39 | 40 | class ProcessSplit(beam.PTransform): 41 | """Examples from task.""" 42 | 43 | def __init__(self, split: str): 44 | self._split = split 45 | self._rico_dir = flags.FLAGS.rico_dir 46 | self._screen2words_dir = flags.FLAGS.screen2words_dir 47 | 48 | def get_image(self, screen_id: str) -> bytes: 49 | with tf.io.gfile.GFile( 50 | os.path.join(self._rico_dir, f"{screen_id}.jpg"), "rb") as f: 51 | return f.read() 52 | 53 | def convert_to_tf_examples( 54 | self, 55 | screen_id: str, 56 | summaries_and_placeholder: Tuple[Iterable[str], Iterable[bool]], 57 | ) -> Iterable[tf.train.Example]: 58 | """Convert the results of joining examples and splits to TF examples.""" 59 | summaries, placeholder = summaries_and_placeholder 60 | # Only yield examples if there was a non-empty join with the intended split. 61 | if any(placeholder): 62 | example = tf.train.Example() 63 | preprocessing_utils.add_bytes_feature( 64 | example, "image", self.get_image(screen_id)) 65 | for summary in summaries: 66 | preprocessing_utils.add_text_feature( 67 | example, "parse", summary) 68 | yield example 69 | 70 | def expand(self, root_and_summaries): 71 | root, summaries = root_and_summaries 72 | screens_path = os.path.join(flags.FLAGS.screen2words_dir, "split", 73 | f"{self._split}_screens.txt") 74 | screen_ids_for_split = (root 75 | | "Read" >> beam.io.ReadFromText(screens_path) 76 | | "Parse" >> beam.Map(lambda l: (l.strip(), True))) 77 | output_path = os.path.join( 78 | self._screen2words_dir, "processed", f"{self._split}.tfr") 79 | return ((summaries, screen_ids_for_split) 80 | | "Join" >> beam.CoGroupByKey() 81 | | "Convert" >> beam.FlatMapTuple(self.convert_to_tf_examples) 82 | | "Shuffle" >> beam.Reshuffle() 83 | | "Write" >> beam.io.WriteToTFRecord( 84 | output_path, 85 | coder=beam.coders.ProtoCoder(tf.train.Example))) 86 | 87 | 88 | def pipeline(root): 89 | """Pipeline.""" 90 | summaries = ( 91 | root 92 | | "ReadSummaries" >> beam.io.ReadFromText( 93 | os.path.join(flags.FLAGS.screen2words_dir, "screen_summaries.csv")) 94 | | "ParseSummaries" >> beam.FlatMap(parse_summary_line)) 95 | _ = ((root, summaries) | "ProcessTrain" >> ProcessSplit("train")) 96 | _ = ((root, summaries) | "ProcessDev" >> ProcessSplit("dev")) 97 | _ = ((root, summaries) | "ProcessTest" >> ProcessSplit("test")) 98 | 99 | 100 | def main(argv): 101 | with beam.Pipeline( 102 | options=beam.options.pipeline_options.PipelineOptions(argv[1:])) as root: 103 | pipeline(root) 104 | 105 | if __name__ == "__main__": 106 | logging.getLogger().setLevel(logging.INFO) 107 | flags.mark_flag_as_required("screen2words_dir") 108 | flags.mark_flag_as_required("rico_dir") 109 | app.run(main) 110 | -------------------------------------------------------------------------------- /pix2struct/preprocessing/convert_textcaps.py: 
--------------------------------------------------------------------------------
/pix2struct/preprocessing/convert_textcaps.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The pix2struct Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Convert TextCaps data.
16 | """
17 | import collections
18 | import io
19 | import json
20 | import logging
21 | import os
22 | from typing import Iterable, List, Tuple
23 | from absl import app
24 | from absl import flags
25 | import apache_beam as beam
26 | import PIL.Image  # Explicit submodule import; bare "import PIL" does not expose PIL.Image.
27 | from pix2struct.preprocessing import preprocessing_utils
28 | import tensorflow as tf
29 |
30 | flags.DEFINE_string("textcaps_dir", None, "Directory containing the TextCaps JSON annotations and image folders.")
31 |
32 | flags.DEFINE_string("output_dir", None, "Directory where the converted TFRecords are written.")
33 |
34 |
35 | class ConvertDataForSplit(beam.PTransform):
36 |   """Convert data for split."""
37 |
38 |   def __init__(self, split: str, image_dir: str):
39 |     self._split = split
40 |     self._textcaps_dir = flags.FLAGS.textcaps_dir
41 |     self._output_dir = flags.FLAGS.output_dir
42 |     self._image_dir = image_dir
43 |
44 |   def json_to_image_ids_and_captions(self) -> Iterable[Tuple[str, List[str]]]:
45 |     with tf.io.gfile.GFile(os.path.join(
46 |         self._textcaps_dir, f"TextCaps_0.1_{self._split}.json")) as f:
47 |       data = json.load(f)
48 |     assert data["dataset_name"] == "textcaps"
49 |     assert data["dataset_type"] == self._split
50 |     example_dict = collections.defaultdict(list)
51 |     for example in data["data"]:
52 |       example_dict[example["image_id"]].append(
53 |           example.get("caption_str", "N/A"))
54 |     for image_id, captions in example_dict.items():
55 |       yield image_id, captions
56 |
57 |   def image_id_and_caption_to_example(self, image_id: str,
58 |                                       captions: List[str]) -> tf.train.Example:
59 |     image_bytes = io.BytesIO()
60 |     with tf.io.gfile.GFile(os.path.join(
61 |         self._textcaps_dir, self._image_dir, f"{image_id}.jpg"), "rb") as f:
62 |       PIL.Image.open(f).convert("RGB").save(image_bytes, format="PNG")
63 |     example = tf.train.Example()
64 |     preprocessing_utils.add_text_feature(example, "id", image_id)
65 |     preprocessing_utils.add_bytes_feature(
66 |         example, "image", image_bytes.getvalue())
67 |     for caption in captions:
68 |       preprocessing_utils.add_text_feature(example, "parse", caption)
69 |     return example
70 |
71 |   def expand(self, pcoll):
72 |     return (pcoll
73 |             | "Read" >> beam.Create(self.json_to_image_ids_and_captions())
74 |             | "Reshuffle" >> beam.Reshuffle()
75 |             | "ToExample" >> beam.MapTuple(self.image_id_and_caption_to_example)
76 |             | "Write" >> beam.io.WriteToTFRecord(
77 |                 os.path.join(self._output_dir, f"{self._split}.tfr"),
78 |                 coder=beam.coders.ProtoCoder(tf.train.Example)))
79 |
80 |
81 | def pipeline(root):
82 |   """Pipeline."""
83 |   _ = (root | "ConvertTrain" >> ConvertDataForSplit("train", "train_images"))
84 |   _ = (root | "ConvertVal" >> ConvertDataForSplit("val", "train_images"))
85 |   _ = (root | "ConvertTest" >> ConvertDataForSplit("test", "test_images"))
86 |
87 |
88 | def main(argv):
89 |   with beam.Pipeline(
90 |       options=beam.options.pipeline_options.PipelineOptions(argv[1:])) as root:
91 |     pipeline(root)
92 |
93 | if __name__ == "__main__":
94 |   logging.getLogger().setLevel(logging.INFO)
95 |   flags.mark_flag_as_required("textcaps_dir")
96 |   flags.mark_flag_as_required("output_dir")
97 |   app.run(main)
98 |
--------------------------------------------------------------------------------
/pix2struct/preprocessing/convert_widget_captioning.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The pix2struct Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | r"""Convert widget captioning data to the common Pix2Struct format.
16 | """
17 | import csv
18 | import io
19 | import json
20 | import logging
21 | import os
22 | from typing import Iterable
23 |
24 | from absl import app
25 | from absl import flags
26 | import apache_beam as beam
27 | import numpy as np
28 | from PIL import Image
29 | from PIL import ImageDraw
30 | from pix2struct.preprocessing import preprocessing_utils
31 | import tensorflow as tf
32 |
33 | flags.DEFINE_string(
34 |     "data_dir", None,
35 |     "Directory containing the widget captioning data.")
36 |
37 | flags.DEFINE_string("data_file", "widget_captions.csv", "Data file name.")
38 |
39 | flags.DEFINE_integer("rico_canvas_y", 2560,
40 |                      "Dataset property indicating the y-dim of the canvas.")
41 |
42 | flags.DEFINE_string(
43 |     "image_dir",
44 |     None,
45 |     "Directory containing the images referenced in the widget captioning data.")
46 |
47 | flags.DEFINE_string(
48 |     "processed_dir",
49 |     "processed",
50 |     "Sub-directory containing the processed widget captioning data.")
51 |
52 |
53 | class ProcessSplit(beam.PTransform):
54 |   """Process split."""
55 |
56 |   def __init__(self, split: str):
57 |     self._split = split
58 |     self._data_dir = flags.FLAGS.data_dir
59 |     self._image_dir = flags.FLAGS.image_dir
60 |     self._rico_canvas_y = flags.FLAGS.rico_canvas_y
61 |
62 |   def get_node_box(self, screen_id, node_id, image_dims):
63 |     index_list = [int(i) for i in node_id.split(".")[1:]]
64 |     with tf.io.gfile.GFile(os.path.join(self._image_dir,
65 |                                         screen_id + ".json")) as f:
66 |       view = json.load(f)
67 |     curr_node = view["activity"]["root"]
68 |     for index in index_list:
69 |       curr_node = curr_node["children"][index]
70 |     normalized_bounds = map(lambda x: x * image_dims[0] / self._rico_canvas_y,
71 |                             curr_node["bounds"])
72 |     return normalized_bounds
73 |
74 |   def convert_to_tf_examples(self, screen_id, node_id,
75 |                              captions) -> Iterable[tf.train.Example]:
76 |     # get image
77 |     with tf.io.gfile.GFile(
78 |         os.path.join(self._image_dir, screen_id + ".jpg"), "rb") as f:
79 |       image = Image.open(f)
80 |       image_dims = np.asarray(image).shape
81 |     # get bounding box coordinates
82 |     xmin, ymin, xmax, ymax = self.get_node_box(screen_id, node_id, image_dims)
83 |     # draw bounding box ((x0, y0) is the top-left corner, (x1, y1) the bottom-right)
84 |     img_draw = ImageDraw.Draw(image, "RGBA")
85 |     img_draw.rectangle(
86 |         xy=((xmin, ymin),
87 |             (xmax, ymax)),
88 |         fill=(0, 0, 255, 0),  # fully transparent fill; only the outline is visible
89 |         outline=(0, 0, 255, 255))
90 |     tf_example = tf.train.Example()
91 |     # Convert the image to bytes.
92 |     img_byte_arr = io.BytesIO()
93 |     image.save(img_byte_arr, format="PNG")
94 |     preprocessing_utils.add_bytes_feature(tf_example, "image",
95 |                                           img_byte_arr.getvalue())
96 |     preprocessing_utils.add_text_feature(tf_example, "id",
97 |                                          str(f"{screen_id}_{node_id}"))
98 |     for caption in captions.split("|"):
99 |       preprocessing_utils.add_text_feature(tf_example, "parse", caption)
100 |     yield tf_example
101 |
102 |   def expand(self, root):
103 |     # read split screen ids
104 |     split_screen_ids = set()
105 |     with tf.io.gfile.GFile(os.path.join(self._data_dir,
106 |                                         self._split + ".txt")) as f:
107 |       for line in f:
108 |         split_screen_ids.add(line.strip())
109 |
110 |     data = []
111 |     with tf.io.gfile.GFile(os.path.join(self._data_dir,
112 |                                         flags.FLAGS.data_file)) as f:
113 |       reader = csv.DictReader(f, delimiter=",")
114 |       for row in reader:
115 |         if row["screenId"] in split_screen_ids:
116 |           data.append((row["screenId"], row["nodeId"], row["captions"]))
117 |
118 |     output_path = os.path.join(
119 |         self._data_dir, flags.FLAGS.processed_dir, f"{self._split}.tfr")
120 |     return (root
121 |             | "Create" >> beam.Create(data)
122 |             | "Convert" >> beam.FlatMapTuple(self.convert_to_tf_examples)
123 |             | "Shuffle" >> beam.Reshuffle()
124 |             | "Write" >> beam.io.WriteToTFRecord(
125 |                 output_path, coder=beam.coders.ProtoCoder(tf.train.Example)))
126 |
127 |
128 | def pipeline(root):
129 |   _ = (root | "ProcessTrain" >> ProcessSplit("train"))
130 |   _ = (root | "ProcessVal" >> ProcessSplit("val"))
131 |   _ = (root | "ProcessTest" >> ProcessSplit("test"))
132 |
133 |
134 | def main(argv):
135 |   with beam.Pipeline(
136 |       options=beam.options.pipeline_options.PipelineOptions(argv[1:])) as root:
137 |     pipeline(root)
138 |
139 | if __name__ == "__main__":
140 |   logging.getLogger().setLevel(logging.INFO)
141 |   flags.mark_flag_as_required("data_dir")
142 |   flags.mark_flag_as_required("image_dir")
143 |   app.run(main)
144 |
--------------------------------------------------------------------------------
/pix2struct/preprocessing/preprocessing_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The pix2struct Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 15 | """Preprocessing utils.""" 16 | import hashlib 17 | import io 18 | import os 19 | import random 20 | import textwrap 21 | from typing import Any, Callable, Iterable, List, Optional 22 | 23 | import apache_beam as beam 24 | from PIL import Image 25 | from PIL import ImageDraw 26 | from PIL import ImageFont 27 | import tensorflow as tf 28 | 29 | DEFAULT_FONT_PATH = "arial.ttf" 30 | 31 | 32 | def add_int_feature(example: tf.train.Example, 33 | key: str, 34 | value: int) -> None: 35 | example.features.feature[key].int64_list.value.append(value) 36 | 37 | 38 | def add_bytes_feature(example: tf.train.Example, 39 | key: str, 40 | value: bytes) -> None: 41 | example.features.feature[key].bytes_list.value.append(value) 42 | 43 | 44 | def add_text_feature(example: tf.train.Example, key: str, value: str) -> None: 45 | add_bytes_feature(example, key, value.encode("utf-8")) 46 | 47 | 48 | def get_bytes_feature(example: tf.train.Example, key: str) -> bytes: 49 | return example.features.feature[key].bytes_list.value[0] 50 | 51 | 52 | def get_text_feature(example: tf.train.Example, key: str) -> str: 53 | return get_bytes_feature(example, key).decode("utf-8") 54 | 55 | 56 | def get_text_features(example: tf.train.Example, key: str) -> List[str]: 57 | return [v.decode("utf-8") 58 | for v in example.features.feature[key].bytes_list.value] 59 | 60 | 61 | def get_int_feature(example: tf.train.Example, key: str) -> int: 62 | return example.features.feature[key].int64_list.value[0] 63 | 64 | 65 | def get_float_feature(example: tf.train.Example, key: str) -> float: 66 | return example.features.feature[key].float_list.value[0] 67 | 68 | 69 | def get_hash(key: str) -> int: 70 | return int(hashlib.sha1(key.encode("utf-8")).hexdigest(), 16) 71 | 72 | 73 | def keep_every(_: Any, ratio: float) -> bool: 74 | return random.random() < ratio 75 | 76 | 77 | def deterministic_sample(items: Iterable[Any], value_fn) -> Any: 78 | return max(items, key=lambda x: get_hash(value_fn(x))) 79 | 80 | 81 | def image_to_bytes(image: Image.Image) -> bytes: 82 | img_byte_arr = io.BytesIO() 83 | image.save(img_byte_arr, format="PNG") 84 | return img_byte_arr.getvalue() 85 | 86 | 87 | def render_header(image: Image.Image, header: str) -> Image.Image: 88 | """Renders a header on a PIL image and returns a new PIL image.""" 89 | header_image = render_text(header) 90 | new_width = max(header_image.width, image.width) 91 | 92 | new_height = int(image.height * (new_width / image.width)) 93 | new_header_height = int( 94 | header_image.height * (new_width / header_image.width)) 95 | 96 | new_image = Image.new( 97 | "RGB", 98 | (new_width, new_height + new_header_height), 99 | "white") 100 | new_image.paste(header_image.resize((new_width, new_header_height)), (0, 0)) 101 | new_image.paste(image.resize((new_width, new_height)), (0, new_header_height)) 102 | 103 | return new_image 104 | 105 | 106 | def render_text(text: str, 107 | text_size: int = 36, 108 | text_color: str = "black", 109 | background_color: str = "white", 110 | left_padding: int = 5, 111 | right_padding: int = 5, 112 | top_padding: int = 5, 113 | bottom_padding: int = 5, 114 | font_bytes: Optional[bytes] = None) -> Image.Image: 115 | """Render text.""" 116 | # Add new lines so that each line is no more than 80 characters. 
117 |   wrapper = textwrap.TextWrapper(width=80)
118 |   lines = wrapper.wrap(text=text)
119 |   wrapped_text = "\n".join(lines)
120 |
121 |   if font_bytes is not None:
122 |     font_spec = io.BytesIO(font_bytes)
123 |   else:
124 |     font_spec = DEFAULT_FONT_PATH
125 |   font = ImageFont.truetype(font_spec, encoding="UTF-8", size=text_size)
126 |
127 |   # Use a temporary canvas to determine the width and height in pixels when
128 |   # rendering the text.
129 |   temp_draw = ImageDraw.Draw(Image.new("RGB", (1, 1), background_color))
130 |   _, _, text_width, text_height = temp_draw.textbbox((0, 0), wrapped_text, font)
131 |
132 |   # Create the actual image with a bit of padding around the text.
133 |   image_width = int(text_width + left_padding + right_padding)
134 |   image_height = int(text_height + top_padding + bottom_padding)
135 |   image = Image.new("RGB", (image_width, image_height), background_color)
136 |   draw = ImageDraw.Draw(image)
137 |   draw.text(
138 |       xy=(left_padding, top_padding),
139 |       text=wrapped_text,
140 |       fill=text_color,
141 |       font=font)
142 |   return image
143 |
144 |
145 | def render_text_on_bounding_box(
146 |     text: str,
147 |     bounding_box: Iterable[Iterable[int]],
148 |     image: Image.Image):
149 |   """Render text on top of a specific bounding box."""
150 |   draw = ImageDraw.Draw(image)
151 |   (x0, y0), (x1, y1) = bounding_box
152 |   draw.rectangle(xy=[(x0, y0), (x1, y1)], fill=(255, 255, 255, 255))
153 |
154 |   fontsize = 1
155 |   def _can_increment_font(ratio=0.95):
156 |     next_font = ImageFont.truetype(
157 |         DEFAULT_FONT_PATH, encoding="UTF-8", size=fontsize + 1)
158 |     width, height = next_font.getbbox(text)[2:]  # bbox right/bottom; getsize() was removed in Pillow 10.
159 |     return width < ratio * (x1 - x0) and height < ratio * (y1 - y0)
160 |
161 |   while _can_increment_font():
162 |     fontsize += 1
163 |   font = ImageFont.truetype(DEFAULT_FONT_PATH, encoding="UTF-8", size=fontsize)
164 |
165 |   draw.text(
166 |       xy=((x0 + x1)/2, (y0 + y1)/2),
167 |       text=text,
168 |       font=font,
169 |       fill="black",
170 |       anchor="mm")
171 |
172 |
173 | def increment_counter(item, counter):
174 |   counter.inc()
175 |   return item
176 |
177 |
178 | class SplitAndWriteTFRecords(beam.PTransform):
179 |   """Split and write TFRecords."""
180 |
181 |   def __init__(self,
182 |                output_dir: str,
183 |                key: str,
184 |                validation_percent: Optional[int] = 10,
185 |                train_file_name: str = "train.tfr",
186 |                val_file_name: str = "val.tfr",
187 |                test_file_name: str = "test.tfr",
188 |                is_test: Optional[Callable[[str], bool]] = None):
189 |     self._output_dir = output_dir
190 |     self._key = key
191 |     self._validation_percent = validation_percent
192 |     self._train_file_name = train_file_name
193 |     self._val_file_name = val_file_name
194 |     self._test_file_name = test_file_name
195 |     self._is_test = is_test
196 |     self._train_counter = beam.metrics.Metrics.counter(
197 |         "SplitAndWriteTFRecords", "train")
198 |     self._val_counter = beam.metrics.Metrics.counter(
199 |         "SplitAndWriteTFRecords", "val")
200 |     self._test_counter = beam.metrics.Metrics.counter(
201 |         "SplitAndWriteTFRecords", "test")
202 |
203 |   def _partition_index(self,
204 |                        example: tf.train.Example,
205 |                        num_partitions: int) -> int:
206 |     assert num_partitions == 3
207 |     key_feature = get_text_feature(example, self._key)
208 |     if self._is_test is not None and self._is_test(key_feature):
209 |       return 2
210 |     else:
211 |       return int(get_hash(key_feature) % 100 < self._validation_percent)
212 |
213 |   def expand(self, pcoll):
214 |     train, val, test = (pcoll
215 |                         | "Shuffle" >>
beam.Reshuffle() 216 | | "Partition" >> beam.Partition( 217 | self._partition_index, 3)) 218 | _ = (train 219 | | "CountTrain" >> beam.Map(increment_counter, self._train_counter) 220 | | "WriteTrain" >> beam.io.WriteToTFRecord( 221 | os.path.join(self._output_dir, self._train_file_name), 222 | coder=beam.coders.ProtoCoder(tf.train.Example))) 223 | _ = (val 224 | | "CountVal" >> beam.Map(increment_counter, self._val_counter) 225 | | "WriteVal" >> beam.io.WriteToTFRecord( 226 | os.path.join(self._output_dir, self._val_file_name), 227 | coder=beam.coders.ProtoCoder(tf.train.Example))) 228 | if self._is_test is not None: 229 | _ = (test 230 | | "CountTest" >> beam.Map(increment_counter, self._test_counter) 231 | | "WriteTest" >> beam.io.WriteToTFRecord( 232 | os.path.join(self._output_dir, self._test_file_name), 233 | coder=beam.coders.ProtoCoder(tf.train.Example))) 234 | 235 | 236 | class DeterministicSamplePerKey(beam.PTransform): 237 | """Deterministic sample per key.""" 238 | 239 | def __init__(self, 240 | key_fn: Callable[[Any], str], 241 | value_fn: Callable[[Any], str]): 242 | self._key_fn = key_fn 243 | self._value_fn = value_fn 244 | 245 | def expand(self, pcoll): 246 | return (pcoll 247 | | "AddKeys" >> beam.WithKeys(self._key_fn) 248 | | "SampleOne" >> beam.CombinePerKey( 249 | deterministic_sample, value_fn=self._value_fn) 250 | | "DropKeys" >> beam.Values()) 251 | -------------------------------------------------------------------------------- /pix2struct/preprocessors.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # Lint as: python3 16 | """Preprocessors.""" 17 | from typing import Callable, Dict, Mapping, Tuple 18 | 19 | import seqio 20 | import tensorflow as tf 21 | 22 | TensorMapping = Callable[[tf.Tensor], tf.Tensor] 23 | FeaturesDict = Dict[str, tf.Tensor] 24 | FeaturesMapping = Callable[[FeaturesDict], FeaturesDict] 25 | 26 | 27 | def map_feature(key: str, map_fn: TensorMapping) -> FeaturesMapping: 28 | @seqio.utils.map_over_dataset 29 | def _mapper(features: FeaturesDict) -> FeaturesDict: 30 | features[key] = map_fn(features[key]) 31 | return features 32 | return _mapper 33 | 34 | 35 | def image_decoder(key: str, channels: int) -> FeaturesMapping: 36 | return map_feature( 37 | key=key, 38 | map_fn=lambda f: tf.io.decode_png(f, channels=channels)) 39 | 40 | 41 | def read_image(key: str, image_dir: str) -> FeaturesMapping: 42 | return map_feature( 43 | key=key, 44 | map_fn=lambda f: tf.io.read_file(tf.strings.join([image_dir, "/", f]))) 45 | 46 | 47 | def normalize_image(key: str) -> FeaturesMapping: 48 | return map_feature( 49 | key=key, 50 | map_fn=tf.image.per_image_standardization) 51 | 52 | 53 | def sample_one(key: str) -> FeaturesMapping: 54 | return map_feature( 55 | key=key, 56 | map_fn=lambda v: tf.random.shuffle(v)[0]) 57 | 58 | 59 | def patch_sequence( 60 | image: tf.Tensor, 61 | max_patches: int, 62 | patch_size: Tuple[int, int]) -> Tuple[tf.Tensor, tf.Tensor]: 63 | """Extract patch sequence.""" 64 | patch_height, patch_width = patch_size 65 | image_shape = tf.shape(image) 66 | image_height = image_shape[0] 67 | image_width = image_shape[1] 68 | image_channels = image_shape[2] 69 | image_height = tf.cast(image_height, tf.float32) 70 | image_width = tf.cast(image_width, tf.float32) 71 | 72 | # maximize scale s.t. 73 | # ceil(scale * image_height / patch_height) * 74 | # ceil(scale * image_width / patch_width) <= max_patches 75 | scale = tf.sqrt( 76 | max_patches * 77 | (patch_height / image_height) * 78 | (patch_width / image_width)) 79 | num_feasible_rows = tf.maximum(tf.minimum( 80 | tf.math.floor(scale * image_height / patch_height), 81 | max_patches), 1) 82 | num_feasible_cols = tf.maximum(tf.minimum( 83 | tf.math.floor(scale * image_width / patch_width), 84 | max_patches), 1) 85 | resized_height = tf.maximum( 86 | tf.cast(num_feasible_rows * patch_height, tf.int32), 1) 87 | resized_width = tf.maximum( 88 | tf.cast(num_feasible_cols * patch_width, tf.int32), 1) 89 | 90 | image = tf.image.resize( 91 | images=image, 92 | size=(resized_height, resized_width), 93 | preserve_aspect_ratio=False, 94 | antialias=True) 95 | 96 | # [1, rows, columns, patch_height * patch_width * image_channels] 97 | patches = tf.image.extract_patches( 98 | images=tf.expand_dims(image, 0), 99 | sizes=[1, patch_height, patch_width, 1], 100 | strides=[1, patch_height, patch_width, 1], 101 | rates=[1, 1, 1, 1], 102 | padding="SAME") 103 | 104 | patches_shape = tf.shape(patches) 105 | rows = patches_shape[1] 106 | columns = patches_shape[2] 107 | depth = patches_shape[3] 108 | 109 | # [rows * columns, patch_height * patch_width * image_channels] 110 | patches = tf.reshape(patches, [rows * columns, depth]) 111 | 112 | # [rows * columns, 1] 113 | row_ids = tf.reshape( 114 | tf.tile(tf.expand_dims(tf.range(rows), 1), [1, columns]), 115 | [rows * columns, 1]) 116 | col_ids = tf.reshape( 117 | tf.tile(tf.expand_dims(tf.range(columns), 0), [rows, 1]), 118 | [rows * columns, 1]) 119 | 120 | # Offset by 1 so the ids do not contain zeros, which represent padding. 
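# (Illustration: a 2-row x 3-column grid yields, after this offset,
# row_ids = [1, 1, 1, 2, 2, 2] and col_ids = [1, 2, 3, 1, 2, 3]; the rows
# appended by tf.pad further below are all zeros, so padding remains
# distinguishable from real patches.)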
121 | row_ids += 1 122 | col_ids += 1 123 | 124 | # Prepare additional patch information for concatenation with real values. 125 | row_ids = tf.cast(row_ids, tf.float32) 126 | col_ids = tf.cast(col_ids, tf.float32) 127 | 128 | # [rows * columns, 2 + patch_height * patch_width * image_channels] 129 | result = tf.concat([row_ids, col_ids, patches], -1) 130 | 131 | # [max_patches, 2 + patch_height * patch_width * image_channels] 132 | result = tf.pad(result, [[0, max_patches - (rows * columns)], [0, 0]]) 133 | 134 | original_shape = tf.stack( 135 | [rows, columns, patch_height, patch_width, image_channels]) 136 | return result, original_shape 137 | 138 | 139 | def image_to_patches( 140 | key: str, 141 | patch_size: Tuple[int, int] = (16, 16)): 142 | """Image to patches.""" 143 | 144 | @seqio.utils.map_over_dataset 145 | def _mapper(features: FeaturesDict, 146 | sequence_length: Mapping[str, int]) -> FeaturesDict: 147 | inputs, original_shape = patch_sequence( 148 | image=features[key], 149 | max_patches=sequence_length[key], 150 | patch_size=patch_size) 151 | features[key] = inputs 152 | features["original_shape"] = original_shape 153 | return features 154 | 155 | return _mapper 156 | -------------------------------------------------------------------------------- /pix2struct/preprocessors_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The pix2struct Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for preprocessors.""" 16 | import random 17 | import numpy as np 18 | from pix2struct import preprocessors 19 | import tensorflow as tf 20 | 21 | 22 | class PreprocessorsTest(tf.test.TestCase): 23 | 24 | def test_patch_sequence_divisible(self): 25 | max_patches = 512 26 | patch_size = (16, 16) 27 | expected_depth = (patch_size[0] * patch_size[1] * 3) + 2 28 | 29 | # Perfectly divisible without resizing. 30 | random_image = tf.random.uniform((512, 256, 3)) 31 | patches, original_shape = preprocessors.patch_sequence( 32 | image=random_image, 33 | max_patches=max_patches, 34 | patch_size=patch_size) 35 | valid_patches = patches[patches[:, 0] > 0] 36 | positions = valid_patches[:, :2] 37 | self.assertAllEqual(patches.shape, [max_patches, expected_depth]) 38 | self.assertAllEqual(valid_patches.shape, [max_patches, expected_depth]) 39 | self.assertAllEqual(original_shape.shape, [5]) 40 | self.assertAllGreater(positions, 0) 41 | self.assertAllLessEqual(positions, valid_patches.shape[0]) 42 | 43 | # Perfectly divisible after scaling up. 
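# (Worked numbers for the case below: scale = sqrt(512 * (16/1) * (16/2)) = 256,
# so the 1x2 image is resized to 256x512 and cut into 16x32 = 512 patches,
# exactly filling max_patches.)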
44 |     random_image = tf.random.uniform((1, 2, 3))
45 |     patches, original_shape = preprocessors.patch_sequence(
46 |         image=random_image,
47 |         max_patches=max_patches,
48 |         patch_size=patch_size)
49 |     valid_patches = patches[patches[:, 0] > 0]
50 |     positions = valid_patches[:, :2]
51 |     self.assertAllEqual(patches.shape, [max_patches, expected_depth])
52 |     self.assertAllEqual(valid_patches.shape, [max_patches, expected_depth])
53 |     self.assertAllEqual(original_shape.shape, [5])
54 |     self.assertAllGreater(positions, 0)
55 |     self.assertAllLessEqual(positions, valid_patches.shape[0])
56 |
57 |     # Perfectly divisible after scaling down.
58 |     random_image = tf.random.uniform((2048, 1024, 3))
59 |     patches, original_shape = preprocessors.patch_sequence(
60 |         image=random_image,
61 |         max_patches=max_patches,
62 |         patch_size=patch_size)
63 |     valid_patches = patches[patches[:, 0] > 0]
64 |     positions = valid_patches[:, :2]
65 |     self.assertAllEqual(patches.shape, [max_patches, expected_depth])
66 |     self.assertAllEqual(valid_patches.shape, [max_patches, expected_depth])
67 |     self.assertAllEqual(original_shape.shape, [5])
68 |     self.assertAllGreater(positions, 0)
69 |     self.assertAllLessEqual(positions, valid_patches.shape[0])
70 |
71 |   def test_patch_sequence_random(self):
72 |     # Test that random image sizes always respect the `max_patches` constraint
73 |     # and always fill up at least half of the capacity.
74 |     total_padding = 0
75 |     num_trials = 100
76 |     max_patches = 512
77 |     patch_size = (16, 16)
78 |     expected_depth = (patch_size[0] * patch_size[1] * 3) + 2
79 |
80 |     for _ in range(num_trials):
81 |       random_width = random.randint(1, 10000)
82 |       random_height = random.randint(1, 10000)
83 |       random_image = tf.random.uniform((random_height, random_width, 3))
84 |       patches, original_shape = preprocessors.patch_sequence(
85 |           image=random_image,
86 |           max_patches=max_patches,
87 |           patch_size=patch_size)
88 |       valid_patches = patches[patches[:, 0] > 0]
89 |       positions = valid_patches[:, :2]
90 |       total_padding += patches.shape[0] - valid_patches.shape[0]
91 |       self.assertAllEqual(patches.shape, [max_patches, expected_depth])
92 |       self.assertLessEqual(valid_patches.shape[0], max_patches)
93 |       self.assertAllEqual(original_shape.shape, [5])
94 |       self.assertGreaterEqual(valid_patches.shape[0], max_patches / 2)
95 |       self.assertAllGreater(positions, 0)
96 |       self.assertAllLessEqual(positions, valid_patches.shape[0])
97 |
98 |     # Average padding should be between 0 and half the sequence length.
99 |     mean_padding = total_padding / num_trials
100 |     self.assertGreater(mean_padding, 0)
101 |     self.assertLess(mean_padding, max_patches / 2)
102 |
103 |   def test_image_to_patches(self):
104 |     random_image = tf.random.uniform((4, 4, 3))
105 |     preprocessor = preprocessors.image_to_patches(
106 |         key="inputs",
107 |         patch_size=(1, 1))
108 |     sequence_length = {"inputs": 7}
109 |     dataset = tf.data.Dataset.from_tensors({"inputs": random_image})
110 |     dataset = preprocessor(dataset, sequence_length=sequence_length)
111 |     np.set_printoptions(threshold=np.inf)
112 |     print(list(dataset.as_numpy_iterator()))
113 |
114 | if __name__ == "__main__":
115 |   tf.test.main()
116 |
--------------------------------------------------------------------------------
/pix2struct/tasks.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The pix2struct Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Pix2Struct tasks.""" 16 | import functools 17 | import io 18 | import os 19 | from typing import Any, Callable, List, Optional 20 | import PIL.Image 21 | from pix2struct import metrics 22 | from pix2struct import postprocessors 23 | from pix2struct import preprocessors 24 | import seqio 25 | import tensorflow as tf 26 | 27 | OUTPUT_FEATURES = dict( 28 | inputs=seqio.ContinuousFeature(rank=2, dtype=tf.float32), 29 | targets=seqio.Feature( 30 | vocabulary=seqio.SentencePieceVocabulary( 31 | "gs://pix2struct-data/sentencepiece.model"))) 32 | 33 | KEY_MAP = dict( 34 | inputs="image", 35 | targets="parse", 36 | parse="parse", 37 | image="image", 38 | id="id", 39 | group_id="group_id") 40 | 41 | PREPROCESSORS = [ 42 | functools.partial(seqio.preprocessors.rekey, key_map=KEY_MAP), 43 | preprocessors.sample_one(key="targets"), 44 | preprocessors.image_decoder(key="inputs", channels=3), 45 | preprocessors.normalize_image(key="inputs"), 46 | preprocessors.image_to_patches(key="inputs"), 47 | seqio.preprocessors.tokenize_and_append_eos, 48 | ] 49 | 50 | FEATURE_DESCRIPTION = { 51 | "id": tf.io.FixedLenFeature([], tf.string, default_value="no-id"), 52 | "image": tf.io.FixedLenFeature([], tf.string), 53 | "parse": tf.io.FixedLenSequenceFeature([], tf.string, allow_missing=True), 54 | "group_id": tf.io.FixedLenFeature( 55 | [], tf.string, default_value="no-group-id"), 56 | } 57 | 58 | 59 | def add_pix2struct_task( 60 | name: str, 61 | base_dir: str, 62 | train_file_pattern: str, 63 | valid_file_pattern: str, 64 | test_file_pattern: Optional[str] = None, 65 | metric_fns: Optional[List[seqio.dataset_providers.MetricFnCallable]] = None, 66 | postprocess_fn: Optional[Callable[..., Any]] = None): 67 | """Add pix2struct task.""" 68 | split_to_filepattern = { 69 | "train": os.path.join(base_dir, train_file_pattern), 70 | "validation": os.path.join(base_dir, valid_file_pattern) 71 | } 72 | if test_file_pattern is not None: 73 | split_to_filepattern["test"] = os.path.join(base_dir, test_file_pattern) 74 | 75 | seqio.TaskRegistry.add( 76 | name=name, 77 | source=seqio.TFExampleDataSource( 78 | split_to_filepattern=split_to_filepattern, 79 | feature_description=FEATURE_DESCRIPTION), 80 | preprocessors=PREPROCESSORS, 81 | output_features=OUTPUT_FEATURES, 82 | postprocess_fn=postprocess_fn or postprocessors.multi_target, 83 | metric_fns=metric_fns or [metrics.pix2struct_metrics]) 84 | 85 | 86 | # Placeholder task to be used during demos. 87 | placeholder_bytes = io.BytesIO() 88 | PIL.Image.new("RGB", size=(1, 1)).save(placeholder_bytes, "png") 89 | placeholder_dataset = tf.data.Dataset.from_tensors({ 90 | "image": placeholder_bytes.getvalue(), 91 | "parse": [""], 92 | "id": "", 93 | "group_id": "", 94 | }) 95 | seqio.TaskRegistry.add( 96 | name="placeholder_pix2struct", 97 | source=seqio.FunctionDataSource( 98 | dataset_fn=lambda split, shuffle_files: placeholder_dataset, 99 | splits=("placeholder",), 100 | ), 101 | preprocessors=PREPROCESSORS, 102 | output_features=OUTPUT_FEATURES, 103 | ) 104 | 105 | # TextCaps dataset from https://textvqa.org/textcaps/. 
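# (These TFRecords are produced by preprocessing/convert_textcaps.py; the file
# patterns below assume its --output_dir pointed at
# $PIX2STRUCT_DIR/data/textcaps/processed.)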
106 | add_pix2struct_task( 107 | name="textcaps", 108 | base_dir=os.environ.get("PIX2STRUCT_DIR", "") + "/data", 109 | train_file_pattern="textcaps/processed/train.tfr*", 110 | valid_file_pattern="textcaps/processed/val.tfr*", 111 | test_file_pattern="textcaps/processed/test.tfr*") 112 | 113 | # Screen2Words dataset. 114 | add_pix2struct_task( 115 | name="screen2words", 116 | base_dir=os.environ.get("PIX2STRUCT_DIR", "") + "/data", 117 | train_file_pattern="screen2words/processed/train.tfr*", 118 | valid_file_pattern="screen2words/processed/dev.tfr*", 119 | test_file_pattern="screen2words/processed/test.tfr*", 120 | ) 121 | 122 | # DocVQA (https://arxiv.org/abs/2007.00398). 123 | add_pix2struct_task( 124 | name="docvqa", 125 | base_dir=os.environ.get("PIX2STRUCT_DIR", "") + "/data", 126 | train_file_pattern="docvqa/processed/train.tfr*", 127 | valid_file_pattern="docvqa/processed/val.tfr*", 128 | test_file_pattern="docvqa/processed/test.tfr*") 129 | 130 | add_pix2struct_task( 131 | name="infographicvqa", 132 | base_dir=os.environ.get("PIX2STRUCT_DIR", "") + "/data", 133 | train_file_pattern="infographicvqa/processed/train.tfr*", 134 | valid_file_pattern="infographicvqa/processed/val.tfr*", 135 | test_file_pattern="infographicvqa/processed/test.tfr*") 136 | 137 | add_pix2struct_task( 138 | name="ocrvqa", 139 | base_dir=os.environ.get("PIX2STRUCT_DIR", "") + "/data", 140 | train_file_pattern="ocrvqa/processed/train.tfr*", 141 | valid_file_pattern="ocrvqa/processed/val.tfr*", 142 | test_file_pattern="ocrvqa/processed/test.tfr*") 143 | 144 | add_pix2struct_task( 145 | name="chartqa_augmented", 146 | base_dir=os.environ.get("PIX2STRUCT_DIR", "") + "/data", 147 | train_file_pattern="chartqa/processed_augmented/train.tfr*", 148 | valid_file_pattern="chartqa/processed_augmented/val.tfr*", 149 | test_file_pattern="chartqa/processed_augmented/test.tfr*") 150 | 151 | add_pix2struct_task( 152 | name="chartqa_human", 153 | base_dir=os.environ.get("PIX2STRUCT_DIR", "") + "/data", 154 | train_file_pattern="chartqa/processed_human/train.tfr*", 155 | valid_file_pattern="chartqa/processed_human/val.tfr*", 156 | test_file_pattern="chartqa/processed_human/test.tfr*") 157 | 158 | seqio.MixtureRegistry.add( 159 | "chartqa", 160 | ["chartqa_human", "chartqa_augmented"], 161 | default_rate=1.0) 162 | 163 | add_pix2struct_task( 164 | name="ai2d", 165 | base_dir=os.environ.get("PIX2STRUCT_DIR", "") + "/data", 166 | train_file_pattern="ai2d/processed/train.tfr*", 167 | valid_file_pattern="ai2d/processed/val.tfr*", 168 | test_file_pattern="ai2d/processed/test.tfr*") 169 | 170 | add_pix2struct_task( 171 | name="refexp", 172 | base_dir=os.environ.get("PIX2STRUCT_DIR", "") + "/data", 173 | train_file_pattern="refexp/processed/train.tfr*", 174 | valid_file_pattern="refexp/processed/val.tfr*", 175 | test_file_pattern="refexp/processed/test.tfr*", 176 | metric_fns=[functools.partial( 177 | metrics.instance_ranking_metrics, 178 | group_fn=lambda t: t["group_id"], 179 | correct_fn=lambda t: t["parse"][0] == "true", 180 | ranking_fn=lambda p, s: (p == "true", s * (1 if p == "true" else -1)) 181 | )], 182 | postprocess_fn=postprocessors.group_target) 183 | 184 | add_pix2struct_task( 185 | name="widget_captioning", 186 | base_dir=os.environ.get("PIX2STRUCT_DIR", "") + "/data", 187 | train_file_pattern="widget_captioning/processed/train.tfr*", 188 | valid_file_pattern="widget_captioning/processed/val.tfr*", 189 | test_file_pattern="widget_captioning/processed/test.tfr*") 190 | 
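# A hypothetical sketch of registering a new dataset with the helper above
# (the "my_screenshots" name and paths are illustrative, not part of the
# repo): any TFRecords whose features match FEATURE_DESCRIPTION can be
# exposed as a seqio task this way.
#
# add_pix2struct_task(
#     name="my_screenshots",
#     base_dir=os.environ.get("PIX2STRUCT_DIR", "") + "/data",
#     train_file_pattern="my_screenshots/processed/train.tfr*",
#     valid_file_pattern="my_screenshots/processed/val.tfr*")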
--------------------------------------------------------------------------------
/pix2struct/transfer_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The pix2struct Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Transfer utils."""
16 | import dataclasses
17 | import os
18 | from typing import Optional
19 | import gin
20 |
21 | import optax
22 | from t5x import utils
23 |
24 |
25 | @dataclasses.dataclass
26 | class TransferRestoreCheckpointConfig(utils.RestoreCheckpointConfig):
27 |   """Transfer restore checkpoint config."""
28 |   steps: Optional[int] = None
29 |
30 |   def __post_init__(self):
31 |     super().__post_init__()
32 |     if self.steps is not None:
33 |       assert self.mode == "specific"
34 |       self.path = os.path.join(self.path, f"checkpoint_{self.steps}")
35 |
36 |
37 | def transfer_warmup_cosine_decay_schedule(
38 |     peak_value: float,
39 |     warmup_steps: int,
40 |     start_step: int,
41 |     end_step: int,
42 |     end_value: float = 0.0,
43 |     cycle_length_ratio: float = 1.0,
44 | ) -> optax.Schedule:
45 |   """Warmup cosine decay schedule with offset."""
46 |   assert end_step >= start_step
47 |
48 |   # Optionally adjust the cycle length to overshoot the actual number of steps
49 |   # so the schedule does not stop at exactly 0. See https://arxiv.org/abs/2203.15556.
50 |   decay_steps = int((end_step - start_step) * cycle_length_ratio)
51 |
52 |   schedules = [
53 |       optax.linear_schedule(
54 |           init_value=0,
55 |           end_value=0,
56 |           transition_steps=start_step),
57 |       optax.warmup_cosine_decay_schedule(
58 |           init_value=0,
59 |           peak_value=peak_value,
60 |           warmup_steps=warmup_steps,
61 |           decay_steps=decay_steps,
62 |           end_value=end_value)]
63 |   return optax.join_schedules(schedules, [start_step])
64 |
65 |
66 | @gin.configurable
67 | def add(a: int = gin.REQUIRED, b: int = gin.REQUIRED):
68 |   return a + b
69 |
--------------------------------------------------------------------------------
/pix2struct/web/static/style.css:
--------------------------------------------------------------------------------
1 | .pix2struct-table {
2 |   word-wrap: break-word;
3 |   border-collapse: separate;
4 |   border-spacing: 20px;
5 |   table-layout: fixed;
6 |   width: 100%;
7 | }
8 |
9 | .pix2struct-cell {
10 |   vertical-align: top;
11 | }
12 |
13 | .pix2struct-image {
14 |   border: 1px solid #000000;
15 |   max-width: 100%;
16 | }
17 |
--------------------------------------------------------------------------------
/pix2struct/web/templates/demo_screenshot.html:
--------------------------------------------------------------------------------
[The template's HTML markup was lost in extraction; only its text content is
recoverable: the page title and header "Pix2Struct Demo", the instruction
"Please upload an image with some text (e.g. web screenshot, document page,
etc.) and maybe write a question if this server is hosting a VQA model.", an
upload/question form, and a results table with "Image" and "Prediction"
columns whose prediction cell renders the {{prediction}} template variable.]
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The pix2struct Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """setup.py for pix2struct."""
16 | import setuptools
17 |
18 | setuptools.setup(
19 |     name="pix2struct",
20 |     version="0.0.1",
21 |     packages=setuptools.find_packages(),
22 |     package_data={"": ["web/**/*.css", "web/**/*.html", "configs/**/*.gin"]},
23 |     extras_require={
24 |         "dev": [
25 |             "pytest",
26 |             "gin-config",
27 |             "t5x[tpu] @ git+https://github.com/google-research/t5x",
28 |             "flaxformer @ git+https://github.com/google/flaxformer",
29 |             "pycocoevalcap",
30 |             "apache-beam[gcp]",
31 |             "jinja2",
32 |             "tornado==3.2.2",
33 |         ],
34 |     },
35 | )
36 |
--------------------------------------------------------------------------------
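A minimal sketch of consuming one of the tasks registered in pix2struct/tasks.py
with seqio. Assumptions (not shown in the repo): the converted TFRecords exist
under $PIX2STRUCT_DIR/data, the sentencepiece model at gs://pix2struct-data is
readable, and the sequence lengths here are illustrative.

import seqio
from pix2struct import tasks  # noqa: F401 -- importing registers the tasks

task = seqio.get_mixture_or_task("screen2words")
ds = task.get_dataset(
    sequence_length={"inputs": 2048, "targets": 128},
    split="validation",
    shuffle=False)
for features in ds.take(1):
  # Each "inputs" row is [row_id, col_id, flattened 16x16x3 patch] = 770 floats.
  print(features["inputs"].shape)  # (2048, 770)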