├── .gitignore ├── LICENSE ├── NOTICE ├── README.md ├── app.py ├── config ├── train_cord.yaml ├── train_docvqa.yaml ├── train_rvlcdip.yaml └── train_zhtrainticket.yaml ├── dataset └── .gitkeep ├── donut ├── __init__.py ├── _version.py ├── model.py └── util.py ├── lightning_module.py ├── misc ├── overview.png ├── sample_image_cord_test_receipt_00004.png ├── sample_image_donut_document.png ├── sample_synthdog.png └── screenshot_gradio_demos.png ├── result └── .gitkeep ├── setup.py ├── synthdog ├── README.md ├── config_en.yaml ├── config_ja.yaml ├── config_ko.yaml ├── config_zh.yaml ├── elements │ ├── __init__.py │ ├── background.py │ ├── content.py │ ├── document.py │ ├── paper.py │ └── textbox.py ├── layouts │ ├── __init__.py │ ├── grid.py │ └── grid_stack.py ├── resources │ ├── background │ │ ├── bedroom_83.jpg │ │ ├── bob+dylan_83.jpg │ │ ├── coffee_122.jpg │ │ ├── coffee_18.jpeg │ │ ├── crater_141.jpg │ │ ├── cream_124.jpg │ │ ├── eagle_110.jpg │ │ ├── farm_25.jpg │ │ └── hiking_18.jpg │ ├── corpus │ │ ├── enwiki.txt │ │ ├── jawiki.txt │ │ ├── kowiki.txt │ │ └── zhwiki.txt │ ├── font │ │ ├── en │ │ │ ├── NotoSans-Regular.ttf │ │ │ └── NotoSerif-Regular.ttf │ │ ├── ja │ │ │ ├── NotoSansJP-Regular.otf │ │ │ └── NotoSerifJP-Regular.otf │ │ ├── ko │ │ │ ├── NotoSansKR-Regular.otf │ │ │ └── NotoSerifKR-Regular.otf │ │ └── zh │ │ │ ├── NotoSansSC-Regular.otf │ │ │ └── NotoSerifSC-Regular.otf │ └── paper │ │ ├── paper_1.jpg │ │ ├── paper_2.jpg │ │ ├── paper_3.jpg │ │ ├── paper_4.jpg │ │ ├── paper_5.jpg │ │ └── paper_6.jpg └── template.py ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | core.* 2 | *.bin 3 | .nfs* 4 | .vscode/* 5 | dataset/* 6 | result/* 7 | misc/* 8 | !misc/*.png 9 | !dataset/.gitkeep 10 | !result/.gitkeep 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | pip-wheel-metadata/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT license 2 | 3 | Copyright (c) 2022-present NAVER Corp. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Donut 2 | Copyright (c) 2022-present NAVER Corp. 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy 5 | of this software and associated documentation files (the "Software"), to deal 6 | in the Software without restriction, including without limitation the rights 7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | copies of the Software, and to permit persons to whom the Software is 9 | furnished to do so, subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included in 12 | all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 | THE SOFTWARE. 21 | 22 | -------------------------------------------------------------------------------------- 23 | 24 | This project contains subcomponents with separate copyright notices and license terms. 25 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses. 26 | 27 | ===== 28 | 29 | googlefonts/noto-fonts 30 | https://fonts.google.com/specimen/Noto+Sans 31 | 32 | 33 | Copyright 2018 The Noto Project Authors (github.com/googlei18n/noto-fonts) 34 | 35 | This Font Software is licensed under the SIL Open Font License, 36 | Version 1.1. 37 | 38 | This license is copied below, and is also available with a FAQ at: 39 | http://scripts.sil.org/OFL 40 | 41 | ----------------------------------------------------------- 42 | SIL OPEN FONT LICENSE Version 1.1 - 26 February 2007 43 | ----------------------------------------------------------- 44 | 45 | PREAMBLE 46 | The goals of the Open Font License (OFL) are to stimulate worldwide 47 | development of collaborative font projects, to support the font 48 | creation efforts of academic and linguistic communities, and to 49 | provide a free and open framework in which fonts may be shared and 50 | improved in partnership with others. 51 | 52 | The OFL allows the licensed fonts to be used, studied, modified and 53 | redistributed freely as long as they are not sold by themselves. The 54 | fonts, including any derivative works, can be bundled, embedded, 55 | redistributed and/or sold with any software provided that any reserved 56 | names are not used by derivative works. The fonts and derivatives, 57 | however, cannot be released under any other type of license. The 58 | requirement for fonts to remain under this license does not apply to 59 | any document created using the fonts or their derivatives. 60 | 61 | DEFINITIONS 62 | "Font Software" refers to the set of files released by the Copyright 63 | Holder(s) under this license and clearly marked as such. This may 64 | include source files, build scripts and documentation. 65 | 66 | "Reserved Font Name" refers to any names specified as such after the 67 | copyright statement(s). 68 | 69 | "Original Version" refers to the collection of Font Software 70 | components as distributed by the Copyright Holder(s). 71 | 72 | "Modified Version" refers to any derivative made by adding to, 73 | deleting, or substituting -- in part or in whole -- any of the 74 | components of the Original Version, by changing formats or by porting 75 | the Font Software to a new environment. 76 | 77 | "Author" refers to any designer, engineer, programmer, technical 78 | writer or other person who contributed to the Font Software. 79 | 80 | PERMISSION & CONDITIONS 81 | Permission is hereby granted, free of charge, to any person obtaining 82 | a copy of the Font Software, to use, study, copy, merge, embed, 83 | modify, redistribute, and sell modified and unmodified copies of the 84 | Font Software, subject to the following conditions: 85 | 86 | 1) Neither the Font Software nor any of its individual components, in 87 | Original or Modified Versions, may be sold by itself. 
88 | 89 | 2) Original or Modified Versions of the Font Software may be bundled, 90 | redistributed and/or sold with any software, provided that each copy 91 | contains the above copyright notice and this license. These can be 92 | included either as stand-alone text files, human-readable headers or 93 | in the appropriate machine-readable metadata fields within text or 94 | binary files as long as those fields can be easily viewed by the user. 95 | 96 | 3) No Modified Version of the Font Software may use the Reserved Font 97 | Name(s) unless explicit written permission is granted by the 98 | corresponding Copyright Holder. This restriction only applies to the 99 | primary font name as presented to the users. 100 | 101 | 4) The name(s) of the Copyright Holder(s) or the Author(s) of the Font 102 | Software shall not be used to promote, endorse or advertise any 103 | Modified Version, except to acknowledge the contribution(s) of the 104 | Copyright Holder(s) and the Author(s) or with their explicit written 105 | permission. 106 | 107 | 5) The Font Software, modified or unmodified, in part or in whole, 108 | must be distributed entirely under this license, and must not be 109 | distributed under any other license. The requirement for fonts to 110 | remain under this license does not apply to any document created using 111 | the Font Software. 112 | 113 | TERMINATION 114 | This license becomes null and void if any of the above conditions are 115 | not met. 116 | 117 | DISCLAIMER 118 | THE FONT SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 119 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY WARRANTIES OF 120 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT 121 | OF COPYRIGHT, PATENT, TRADEMARK, OR OTHER RIGHT. IN NO EVENT SHALL THE 122 | COPYRIGHT HOLDER BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 123 | INCLUDING ANY GENERAL, SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL 124 | DAMAGES, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 125 | FROM, OUT OF THE USE OR INABILITY TO USE THE FONT SOFTWARE OR FROM 126 | OTHER DEALINGS IN THE FONT SOFTWARE. 127 | 128 | ===== 129 | 130 | huggingface/transformers 131 | https://github.com/huggingface/transformers 132 | 133 | 134 | Copyright [yyyy] [name of copyright owner] 135 | 136 | Licensed under the Apache License, Version 2.0 (the "License"); 137 | you may not use this file except in compliance with the License. 138 | You may obtain a copy of the License at 139 | 140 | http://www.apache.org/licenses/LICENSE-2.0 141 | 142 | Unless required by applicable law or agreed to in writing, software 143 | distributed under the License is distributed on an "AS IS" BASIS, 144 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 145 | See the License for the specific language governing permissions and limitations under the License. 146 | 147 | ===== 148 | 149 | clovaai/synthtiger 150 | https://github.com/clovaai/synthtiger 151 | 152 | 153 | Copyright (c) 2021-present NAVER Corp. 
154 | 155 | Permission is hereby granted, free of charge, to any person obtaining a copy 156 | of this software and associated documentation files (the "Software"), to deal 157 | in the Software without restriction, including without limitation the rights 158 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 159 | copies of the Software, and to permit persons to whom the Software is 160 | furnished to do so, subject to the following conditions: 161 | 162 | The above copyright notice and this permission notice shall be included in 163 | all copies or substantial portions of the Software. 164 | 165 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 166 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 167 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 168 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 169 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 170 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 171 | THE SOFTWARE. 172 | 173 | ===== 174 | 175 | rwightman/pytorch-image-models 176 | https://github.com/rwightman/pytorch-image-models 177 | 178 | 179 | Copyright 2019 Ross Wightman 180 | 181 | Licensed under the Apache License, Version 2.0 (the "License"); 182 | you may not use this file except in compliance with the License. 183 | You may obtain a copy of the License at 184 | 185 | http://www.apache.org/licenses/LICENSE-2.0 186 | 187 | Unless required by applicable law or agreed to in writing, software 188 | distributed under the License is distributed on an "AS IS" BASIS, 189 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 190 | See the License for the specific language governing permissions and 191 | limitations under the License. 192 | 193 | ===== 194 | 195 | ankush-me/SynthText 196 | https://github.com/ankush-me/SynthText 197 | 198 | 199 | Copyright 2017, Ankush Gupta. 200 | 201 | Licensed under the Apache License, Version 2.0 (the "License"); 202 | you may not use this file except in compliance with the License. 203 | You may obtain a copy of the License at 204 | 205 | http://www.apache.org/licenses/LICENSE-2.0 206 | 207 | Unless required by applicable law or agreed to in writing, software 208 | distributed under the License is distributed on an "AS IS" BASIS, 209 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 210 | See the License for the specific language governing permissions and 211 | limitations under the License. 212 | 213 | ===== 214 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Online Demo 2 | ### I fine-tuned the Donut multimodal model to recognize (Chinese) ID cards and obtained fairly good results. You can try it at the link below: 3 | URL: http://116.198.235.162:7860/ 4 | 5 | # Training Pipeline 6 | ### Basic fine-tuning guide 7 | Reference: https://www.philschmid.de/fine-tuning-donut 8 | 9 | ### ID card sample generation 10 | Tool used: https://github.com/airob0t/idcardgenerator 11 | 12 | ### Fine-tune the Donut model on the generated samples 13 | 14 | # Contact 15 | If you find the trained model good enough and would rather not train your own, you can add me on WeChat to purchase it. 16 | WeChat: lex_workshop 17 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | """ 2 | Donut 3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT License 5 | """ 6 | import argparse 7 | 8 | import gradio as gr 9 | import torch 10 | from PIL import Image 11 | 12 | from donut import DonutModel 13 | 14 | 15 | def demo_process_vqa(input_img, question): 16 | global pretrained_model, task_prompt, task_name 17 | input_img = Image.fromarray(input_img) 18 | user_prompt = task_prompt.replace("{user_input}", question) 19 | output = pretrained_model.inference(input_img, prompt=user_prompt)["predictions"][0] 20 | return output 21 | 22 | 23 | def demo_process(input_img): 24 | global pretrained_model, task_prompt, task_name 25 | input_img = Image.fromarray(input_img) 26 | output = pretrained_model.inference(image=input_img, prompt=task_prompt)["predictions"][0] 27 | return output 28 | 29 | 30 | if __name__ == "__main__": 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("--task", type=str, default="docvqa") 33 | parser.add_argument("--pretrained_path", type=str, default="naver-clova-ix/donut-base-finetuned-docvqa") 34 | parser.add_argument("--port", type=int, default=None) 35 | parser.add_argument("--url", type=str, default=None) 36 | parser.add_argument("--sample_img_path", type=str) 37 | args, left_argv = parser.parse_known_args() 38 | 39 | task_name = args.task 40 | if "docvqa" == task_name: 41 | task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>" 42 | else: # rvlcdip, cord, ... 43 | task_prompt = f"<s_{task_name}>" 44 | 45 | example_sample = [] 46 | if args.sample_img_path: 47 | example_sample.append(args.sample_img_path) 48 | 49 | pretrained_model = DonutModel.from_pretrained(args.pretrained_path) 50 | 51 | if torch.cuda.is_available(): 52 | pretrained_model.half() 53 | device = torch.device("cuda") 54 | pretrained_model.to(device) 55 | 56 | pretrained_model.eval() 57 | 58 | demo = gr.Interface( 59 | fn=demo_process_vqa if task_name == "docvqa" else demo_process, 60 | inputs=["image", "text"] if task_name == "docvqa" else "image", 61 | outputs="json", 62 | title=f"Donut 🍩 demonstration for `{task_name}` task", 63 | examples=[example_sample] if example_sample else None, 64 | ) 65 | demo.launch(server_name=args.url, server_port=args.port) 66 | -------------------------------------------------------------------------------- /config/train_cord.yaml: -------------------------------------------------------------------------------- 1 | resume_from_checkpoint_path: null # only used for resume_from_checkpoint option in PL 2 | result_path: "./result" 3 | pretrained_model_name_or_path: "naver-clova-ix/donut-base" # loading a pre-trained model (from the model hub or a local path) 4 | dataset_name_or_paths: ["naver-clova-ix/cord-v2"] # loading datasets (from the dataset hub or a local path) 5 | sort_json_key: False # cord dataset is preprocessed, and publicly available at https://huggingface.co/datasets/naver-clova-ix/cord-v2 6 | train_batch_sizes: [8] 7 | val_batch_sizes: [1] 8 | input_size: [1280, 960] # when the input resolution differs from the pre-training setting, some weights will be newly initialized (but the model training would be okay) 9 | max_length: 768 10 | align_long_axis: False 11 | num_nodes: 1 12 | seed: 2022 13 | lr: 3e-5 14 | warmup_steps: 300 # 800/8*30/10, 10% 15 | num_training_samples_per_epoch: 800 16 | max_epochs: 30 17 | max_steps: -1 18 | num_workers: 8 19 | val_check_interval: 1.0 20 | check_val_every_n_epoch: 3 21 | gradient_clip_val: 1.0 22 | verbose: True 23 | -------------------------------------------------------------------------------- /config/train_docvqa.yaml: -------------------------------------------------------------------------------- 1 |
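# A typical launch command for these config files (a sketch only; the exact
# train.py flags such as --exp_version are assumed here and may differ):
#
#   python train.py --config config/train_docvqa.yaml \
#                   --pretrained_model_name_or_path "naver-clova-ix/donut-base" \
#                   --result_path "./result" \
#                   --exp_version "docvqa_test_run"
#
# Any key in this file can likewise be overridden from the command line.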
resume_from_checkpoint_path: null 2 | result_path: "./result" 3 | pretrained_model_name_or_path: "naver-clova-ix/donut-base" 4 | dataset_name_or_paths: ["./dataset/docvqa"] # should be prepared from https://rrc.cvc.uab.es/?ch=17 5 | sort_json_key: True 6 | train_batch_sizes: [2] 7 | val_batch_sizes: [4] 8 | input_size: [2560, 1920] 9 | max_length: 128 10 | align_long_axis: False 11 | # num_nodes: 8 # memo: donut-base-finetuned-docvqa was trained with 8 nodes 12 | num_nodes: 1 13 | seed: 2022 14 | lr: 3e-5 15 | warmup_steps: 10000 16 | num_training_samples_per_epoch: 39463 17 | max_epochs: 300 18 | max_steps: -1 19 | num_workers: 8 20 | val_check_interval: 1.0 21 | check_val_every_n_epoch: 1 22 | gradient_clip_val: 0.25 23 | verbose: True 24 | -------------------------------------------------------------------------------- /config/train_rvlcdip.yaml: -------------------------------------------------------------------------------- 1 | resume_from_checkpoint_path: null 2 | result_path: "./result" 3 | pretrained_model_name_or_path: "naver-clova-ix/donut-base" 4 | dataset_name_or_paths: ["./dataset/rvlcdip"] # should be prepared from https://www.cs.cmu.edu/~aharley/rvl-cdip/ 5 | sort_json_key: True 6 | train_batch_sizes: [2] 7 | val_batch_sizes: [4] 8 | input_size: [2560, 1920] 9 | max_length: 8 10 | align_long_axis: False 11 | # num_nodes: 8 # memo: donut-base-finetuned-rvlcdip was trained with 8 nodes 12 | num_nodes: 1 13 | seed: 2022 14 | lr: 2e-5 15 | warmup_steps: 10000 16 | num_training_samples_per_epoch: 320000 17 | max_epochs: 100 18 | max_steps: -1 19 | num_workers: 8 20 | val_check_interval: 1.0 21 | check_val_every_n_epoch: 1 22 | gradient_clip_val: 1.0 23 | verbose: True 24 | -------------------------------------------------------------------------------- /config/train_zhtrainticket.yaml: -------------------------------------------------------------------------------- 1 | resume_from_checkpoint_path: null 2 | result_path: "./result" 3 | pretrained_model_name_or_path: "naver-clova-ix/donut-base" 4 | dataset_name_or_paths: ["./dataset/zhtrainticket"] # should be prepared from https://github.com/beacandler/EATEN 5 | sort_json_key: True 6 | train_batch_sizes: [8] 7 | val_batch_sizes: [1] 8 | input_size: [960, 1280] 9 | max_length: 256 10 | align_long_axis: False 11 | num_nodes: 1 12 | seed: 2022 13 | lr: 3e-5 14 | warmup_steps: 300 15 | num_training_samples_per_epoch: 1368 16 | max_epochs: 10 17 | max_steps: -1 18 | num_workers: 8 19 | val_check_interval: 1.0 20 | check_val_every_n_epoch: 1 21 | gradient_clip_val: 1.0 22 | verbose: True 23 | -------------------------------------------------------------------------------- /dataset/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /donut/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Donut 3 | Copyright (c) 2022-present NAVER Corp. 
4 | MIT License 5 | """ 6 | from .model import DonutConfig, DonutModel 7 | from .util import DonutDataset, JSONParseEvaluator, load_json, save_json 8 | 9 | __all__ = [ 10 | "DonutConfig", 11 | "DonutModel", 12 | "DonutDataset", 13 | "JSONParseEvaluator", 14 | "load_json", 15 | "save_json", 16 | ] 17 | -------------------------------------------------------------------------------- /donut/_version.py: -------------------------------------------------------------------------------- 1 | """ 2 | Donut 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT License 5 | """ 6 | __version__ = "1.0.9" 7 | -------------------------------------------------------------------------------- /donut/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Donut 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT License 5 | """ 6 | import math 7 | import os 8 | import re 9 | from typing import Any, List, Optional, Union 10 | 11 | import numpy as np 12 | import PIL 13 | import timm 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from PIL import ImageOps 18 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 19 | from timm.models.swin_transformer import SwinTransformer 20 | from torchvision import transforms 21 | from torchvision.transforms.functional import resize, rotate 22 | from transformers import MBartConfig, MBartForCausalLM, XLMRobertaTokenizer 23 | from transformers.file_utils import ModelOutput 24 | from transformers.modeling_utils import PretrainedConfig, PreTrainedModel 25 | 26 | 27 | class SwinEncoder(nn.Module): 28 | r""" 29 | Donut encoder based on SwinTransformer 30 | Set the initial weights and configuration with a pretrained SwinTransformer and then 31 | modify the detailed configurations as a Donut Encoder 32 | 33 | Args: 34 | input_size: Input image size (width, height) 35 | align_long_axis: Whether to rotate image if height is greater than width 36 | window_size: Window size(=patch size) of SwinTransformer 37 | encoder_layer: Number of layers of SwinTransformer encoder 38 | name_or_path: Name of a pretrained model name either registered in huggingface.co. or saved in local. 39 | otherwise, `swin_base_patch4_window12_384` will be set (using `timm`). 
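    Example:
        A minimal usage sketch assuming the donut-base defaults listed in
        DonutConfig below (shapes and file name are illustrative):

            encoder = SwinEncoder(
                input_size=[2560, 1920],
                align_long_axis=False,
                window_size=10,
                encoder_layer=[2, 2, 14, 2],
            )
            image = PIL.Image.open("document.png")
            tensor = encoder.prepare_input(image)   # (3, 2560, 1920), normalized
            feats = encoder(tensor.unsqueeze(0))    # (1, num_patches, hidden_dim)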
40 | """ 41 | 42 | def __init__( 43 | self, 44 | input_size: List[int], 45 | align_long_axis: bool, 46 | window_size: int, 47 | encoder_layer: List[int], 48 | name_or_path: Union[str, bytes, os.PathLike] = None, 49 | ): 50 | super().__init__() 51 | self.input_size = input_size 52 | self.align_long_axis = align_long_axis 53 | self.window_size = window_size 54 | self.encoder_layer = encoder_layer 55 | 56 | self.to_tensor = transforms.Compose( 57 | [ 58 | transforms.ToTensor(), 59 | transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), 60 | ] 61 | ) 62 | 63 | self.model = SwinTransformer( 64 | img_size=self.input_size, 65 | depths=self.encoder_layer, 66 | window_size=self.window_size, 67 | patch_size=4, 68 | embed_dim=128, 69 | num_heads=[4, 8, 16, 32], 70 | num_classes=0, 71 | ) 72 | self.model.norm = None 73 | 74 | # weight init with swin 75 | if not name_or_path: 76 | swin_state_dict = timm.create_model("swin_base_patch4_window12_384", pretrained=True).state_dict() 77 | new_swin_state_dict = self.model.state_dict() 78 | for x in new_swin_state_dict: 79 | if x.endswith("relative_position_index") or x.endswith("attn_mask"): 80 | pass 81 | elif ( 82 | x.endswith("relative_position_bias_table") 83 | and self.model.layers[0].blocks[0].attn.window_size[0] != 12 84 | ): 85 | pos_bias = swin_state_dict[x].unsqueeze(0)[0] 86 | old_len = int(math.sqrt(len(pos_bias))) 87 | new_len = int(2 * window_size - 1) 88 | pos_bias = pos_bias.reshape(1, old_len, old_len, -1).permute(0, 3, 1, 2) 89 | pos_bias = F.interpolate(pos_bias, size=(new_len, new_len), mode="bicubic", align_corners=False) 90 | new_swin_state_dict[x] = pos_bias.permute(0, 2, 3, 1).reshape(1, new_len ** 2, -1).squeeze(0) 91 | else: 92 | new_swin_state_dict[x] = swin_state_dict[x] 93 | self.model.load_state_dict(new_swin_state_dict) 94 | 95 | def forward(self, x: torch.Tensor) -> torch.Tensor: 96 | """ 97 | Args: 98 | x: (batch_size, num_channels, height, width) 99 | """ 100 | x = self.model.patch_embed(x) 101 | x = self.model.pos_drop(x) 102 | x = self.model.layers(x) 103 | return x 104 | 105 | def prepare_input(self, img: PIL.Image.Image, random_padding: bool = False) -> torch.Tensor: 106 | """ 107 | Convert PIL Image to tensor according to specified input_size after following steps below: 108 | - resize 109 | - rotate (if align_long_axis is True and image is not aligned longer axis with canvas) 110 | - pad 111 | """ 112 | img = img.convert("RGB") 113 | if self.align_long_axis and ( 114 | (self.input_size[0] > self.input_size[1] and img.width > img.height) 115 | or (self.input_size[0] < self.input_size[1] and img.width < img.height) 116 | ): 117 | img = rotate(img, angle=-90, expand=True) 118 | img = resize(img, min(self.input_size)) 119 | img.thumbnail((self.input_size[1], self.input_size[0])) 120 | delta_width = self.input_size[1] - img.width 121 | delta_height = self.input_size[0] - img.height 122 | if random_padding: 123 | pad_width = np.random.randint(low=0, high=delta_width + 1) 124 | pad_height = np.random.randint(low=0, high=delta_height + 1) 125 | else: 126 | pad_width = delta_width // 2 127 | pad_height = delta_height // 2 128 | padding = ( 129 | pad_width, 130 | pad_height, 131 | delta_width - pad_width, 132 | delta_height - pad_height, 133 | ) 134 | return self.to_tensor(ImageOps.expand(img, padding)) 135 | 136 | 137 | class BARTDecoder(nn.Module): 138 | """ 139 | Donut Decoder based on Multilingual BART 140 | Set the initial weights and configuration with a pretrained multilingual BART model, 141 | and modify the 
detailed configurations as a Donut decoder 142 | 143 | Args: 144 | decoder_layer: 145 | Number of layers of BARTDecoder 146 | max_position_embeddings: 147 | The maximum sequence length to be trained 148 | name_or_path: 149 | Name of a pretrained model name either registered in huggingface.co. or saved in local, 150 | otherwise, `hyunwoongko/asian-bart-ecjk` will be set (using `transformers`) 151 | """ 152 | 153 | def __init__( 154 | self, decoder_layer: int, max_position_embeddings: int, name_or_path: Union[str, bytes, os.PathLike] = None 155 | ): 156 | super().__init__() 157 | self.decoder_layer = decoder_layer 158 | self.max_position_embeddings = max_position_embeddings 159 | 160 | self.tokenizer = XLMRobertaTokenizer.from_pretrained( 161 | "hyunwoongko/asian-bart-ecjk" if not name_or_path else name_or_path 162 | ) 163 | 164 | self.model = MBartForCausalLM( 165 | config=MBartConfig( 166 | is_decoder=True, 167 | is_encoder_decoder=False, 168 | add_cross_attention=True, 169 | decoder_layers=self.decoder_layer, 170 | max_position_embeddings=self.max_position_embeddings, 171 | vocab_size=len(self.tokenizer), 172 | scale_embedding=True, 173 | add_final_layer_norm=True, 174 | ) 175 | ) 176 | self.model.forward = self.forward # to get cross attentions and utilize `generate` function 177 | 178 | self.model.config.is_encoder_decoder = True # to get cross-attention 179 | self.add_special_tokens([""]) # is used for representing a list in a JSON 180 | self.model.model.decoder.embed_tokens.padding_idx = self.tokenizer.pad_token_id 181 | self.model.prepare_inputs_for_generation = self.prepare_inputs_for_inference 182 | 183 | # weight init with asian-bart 184 | if not name_or_path: 185 | bart_state_dict = MBartForCausalLM.from_pretrained("hyunwoongko/asian-bart-ecjk").state_dict() 186 | new_bart_state_dict = self.model.state_dict() 187 | for x in new_bart_state_dict: 188 | if x.endswith("embed_positions.weight") and self.max_position_embeddings != 1024: 189 | new_bart_state_dict[x] = torch.nn.Parameter( 190 | self.resize_bart_abs_pos_emb( 191 | bart_state_dict[x], 192 | self.max_position_embeddings 193 | + 2, # https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L118-L119 194 | ) 195 | ) 196 | elif x.endswith("embed_tokens.weight") or x.endswith("lm_head.weight"): 197 | new_bart_state_dict[x] = bart_state_dict[x][: len(self.tokenizer), :] 198 | else: 199 | new_bart_state_dict[x] = bart_state_dict[x] 200 | self.model.load_state_dict(new_bart_state_dict) 201 | 202 | def add_special_tokens(self, list_of_tokens: List[str]): 203 | """ 204 | Add special tokens to tokenizer and resize the token embeddings 205 | """ 206 | newly_added_num = self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set(list_of_tokens))}) 207 | if newly_added_num > 0: 208 | self.model.resize_token_embeddings(len(self.tokenizer)) 209 | 210 | def prepare_inputs_for_inference(self, input_ids: torch.Tensor, encoder_outputs: torch.Tensor, past_key_values=None, past=None, use_cache: bool = None, attention_mask: torch.Tensor = None): 211 | """ 212 | Args: 213 | input_ids: (batch_size, sequence_lenth) 214 | Returns: 215 | input_ids: (batch_size, sequence_length) 216 | attention_mask: (batch_size, sequence_length) 217 | encoder_hidden_states: (batch_size, sequence_length, embedding_dim) 218 | """ 219 | # for compatibility with transformers==4.11.x 220 | if past is not None: 221 | past_key_values = past 222 | attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long() 
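        # The mask above covers every non-pad position of the full sequence; when
        # cached past_key_values are present, only the newest token id is fed to the
        # decoder (trimmed below) and earlier positions are reached through the cache.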
223 | if past_key_values is not None: 224 | input_ids = input_ids[:, -1:] 225 | output = { 226 | "input_ids": input_ids, 227 | "attention_mask": attention_mask, 228 | "past_key_values": past_key_values, 229 | "use_cache": use_cache, 230 | "encoder_hidden_states": encoder_outputs.last_hidden_state, 231 | } 232 | return output 233 | 234 | def forward( 235 | self, 236 | input_ids, 237 | attention_mask: Optional[torch.Tensor] = None, 238 | encoder_hidden_states: Optional[torch.Tensor] = None, 239 | past_key_values: Optional[torch.Tensor] = None, 240 | labels: Optional[torch.Tensor] = None, 241 | use_cache: bool = None, 242 | output_attentions: Optional[torch.Tensor] = None, 243 | output_hidden_states: Optional[torch.Tensor] = None, 244 | return_dict: bool = None, 245 | ): 246 | """ 247 | A forward fucntion to get cross attentions and utilize `generate` function 248 | 249 | Source: 250 | https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L1669-L1810 251 | 252 | Args: 253 | input_ids: (batch_size, sequence_length) 254 | attention_mask: (batch_size, sequence_length) 255 | encoder_hidden_states: (batch_size, sequence_length, hidden_size) 256 | 257 | Returns: 258 | loss: (1, ) 259 | logits: (batch_size, sequence_length, hidden_dim) 260 | hidden_states: (batch_size, sequence_length, hidden_size) 261 | decoder_attentions: (batch_size, num_heads, sequence_length, sequence_length) 262 | cross_attentions: (batch_size, num_heads, sequence_length, sequence_length) 263 | """ 264 | output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions 265 | output_hidden_states = ( 266 | output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states 267 | ) 268 | return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict 269 | outputs = self.model.model.decoder( 270 | input_ids=input_ids, 271 | attention_mask=attention_mask, 272 | encoder_hidden_states=encoder_hidden_states, 273 | past_key_values=past_key_values, 274 | use_cache=use_cache, 275 | output_attentions=output_attentions, 276 | output_hidden_states=output_hidden_states, 277 | return_dict=return_dict, 278 | ) 279 | 280 | logits = self.model.lm_head(outputs[0]) 281 | 282 | loss = None 283 | if labels is not None: 284 | loss_fct = nn.CrossEntropyLoss(ignore_index=-100) 285 | loss = loss_fct(logits.view(-1, self.model.config.vocab_size), labels.view(-1)) 286 | 287 | if not return_dict: 288 | output = (logits,) + outputs[1:] 289 | return (loss,) + output if loss is not None else output 290 | 291 | return ModelOutput( 292 | loss=loss, 293 | logits=logits, 294 | past_key_values=outputs.past_key_values, 295 | hidden_states=outputs.hidden_states, 296 | decoder_attentions=outputs.attentions, 297 | cross_attentions=outputs.cross_attentions, 298 | ) 299 | 300 | @staticmethod 301 | def resize_bart_abs_pos_emb(weight: torch.Tensor, max_length: int) -> torch.Tensor: 302 | """ 303 | Resize position embeddings 304 | Truncate if sequence length of Bart backbone is greater than given max_length, 305 | else interpolate to max_length 306 | """ 307 | if weight.shape[0] > max_length: 308 | weight = weight[:max_length, ...] 
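        # otherwise the trained table is shorter than max_length, so it is
        # linearly interpolated up to the requested length below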
309 | else: 310 | weight = ( 311 | F.interpolate( 312 | weight.permute(1, 0).unsqueeze(0), 313 | size=max_length, 314 | mode="linear", 315 | align_corners=False, 316 | ) 317 | .squeeze(0) 318 | .permute(1, 0) 319 | ) 320 | return weight 321 | 322 | 323 | class DonutConfig(PretrainedConfig): 324 | r""" 325 | This is the configuration class to store the configuration of a [`DonutModel`]. It is used to 326 | instantiate a Donut model according to the specified arguments, defining the model architecture 327 | 328 | Args: 329 | input_size: 330 | Input image size (canvas size) of Donut.encoder, SwinTransformer in this codebase 331 | align_long_axis: 332 | Whether to rotate image if height is greater than width 333 | window_size: 334 | Window size of Donut.encoder, SwinTransformer in this codebase 335 | encoder_layer: 336 | Depth of each Donut.encoder Encoder layer, SwinTransformer in this codebase 337 | decoder_layer: 338 | Number of hidden layers in the Donut.decoder, such as BART 339 | max_position_embeddings 340 | Trained max position embeddings in the Donut decoder, 341 | if not specified, it will have same value with max_length 342 | max_length: 343 | Max position embeddings(=maximum sequence length) you want to train 344 | name_or_path: 345 | Name of a pretrained model name either registered in huggingface.co. or saved in local 346 | """ 347 | 348 | model_type = "donut" 349 | 350 | def __init__( 351 | self, 352 | input_size: List[int] = [2560, 1920], 353 | align_long_axis: bool = False, 354 | window_size: int = 10, 355 | encoder_layer: List[int] = [2, 2, 14, 2], 356 | decoder_layer: int = 4, 357 | max_position_embeddings: int = None, 358 | max_length: int = 1536, 359 | name_or_path: Union[str, bytes, os.PathLike] = "", 360 | **kwargs, 361 | ): 362 | super().__init__() 363 | self.input_size = input_size 364 | self.align_long_axis = align_long_axis 365 | self.window_size = window_size 366 | self.encoder_layer = encoder_layer 367 | self.decoder_layer = decoder_layer 368 | self.max_position_embeddings = max_length if max_position_embeddings is None else max_position_embeddings 369 | self.max_length = max_length 370 | self.name_or_path = name_or_path 371 | 372 | 373 | class DonutModel(PreTrainedModel): 374 | r""" 375 | Donut: an E2E OCR-free Document Understanding Transformer. 
376 | The encoder maps an input document image into a set of embeddings, 377 | the decoder predicts a desired token sequence, that can be converted to a structured format, 378 | given a prompt and the encoder output embeddings 379 | """ 380 | config_class = DonutConfig 381 | base_model_prefix = "donut" 382 | 383 | def __init__(self, config: DonutConfig): 384 | super().__init__(config) 385 | self.config = config 386 | self.encoder = SwinEncoder( 387 | input_size=self.config.input_size, 388 | align_long_axis=self.config.align_long_axis, 389 | window_size=self.config.window_size, 390 | encoder_layer=self.config.encoder_layer, 391 | name_or_path=self.config.name_or_path, 392 | ) 393 | self.decoder = BARTDecoder( 394 | max_position_embeddings=self.config.max_position_embeddings, 395 | decoder_layer=self.config.decoder_layer, 396 | name_or_path=self.config.name_or_path, 397 | ) 398 | 399 | def forward(self, image_tensors: torch.Tensor, decoder_input_ids: torch.Tensor, decoder_labels: torch.Tensor): 400 | """ 401 | Calculate a loss given an input image and a desired token sequence, 402 | the model will be trained in a teacher-forcing manner 403 | 404 | Args: 405 | image_tensors: (batch_size, num_channels, height, width) 406 | decoder_input_ids: (batch_size, sequence_length, embedding_dim) 407 | decode_labels: (batch_size, sequence_length) 408 | """ 409 | encoder_outputs = self.encoder(image_tensors) 410 | decoder_outputs = self.decoder( 411 | input_ids=decoder_input_ids, 412 | encoder_hidden_states=encoder_outputs, 413 | labels=decoder_labels, 414 | ) 415 | return decoder_outputs 416 | 417 | def inference( 418 | self, 419 | image: PIL.Image = None, 420 | prompt: str = None, 421 | image_tensors: Optional[torch.Tensor] = None, 422 | prompt_tensors: Optional[torch.Tensor] = None, 423 | return_json: bool = True, 424 | return_attentions: bool = False, 425 | ): 426 | """ 427 | Generate a token sequence in an auto-regressive manner, 428 | the generated token sequence is convereted into an ordered JSON format 429 | 430 | Args: 431 | image: input document image (PIL.Image) 432 | prompt: task prompt (string) to guide Donut Decoder generation 433 | image_tensors: (1, num_channels, height, width) 434 | convert prompt to tensor if image_tensor is not fed 435 | prompt_tensors: (1, sequence_length) 436 | convert image to tensor if prompt_tensor is not fed 437 | """ 438 | # prepare backbone inputs (image and prompt) 439 | if image is None and image_tensors is None: 440 | raise ValueError("Expected either image or image_tensors") 441 | if all(v is None for v in {prompt, prompt_tensors}): 442 | raise ValueError("Expected either prompt or prompt_tensors") 443 | 444 | if image_tensors is None: 445 | image_tensors = self.encoder.prepare_input(image).unsqueeze(0) 446 | 447 | if self.device.type == "cuda": # half is not compatible in cpu implementation. 
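            # app.py calls pretrained_model.half() when CUDA is available, so the
            # image tensor is cast to fp16 here to match the encoder's weight dtype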
448 | image_tensors = image_tensors.half() 449 | image_tensors = image_tensors.to(self.device) 450 | 451 | if prompt_tensors is None: 452 | prompt_tensors = self.decoder.tokenizer(prompt, add_special_tokens=False, return_tensors="pt")["input_ids"] 453 | 454 | prompt_tensors = prompt_tensors.to(self.device) 455 | 456 | last_hidden_state = self.encoder(image_tensors) 457 | if self.device.type != "cuda": 458 | last_hidden_state = last_hidden_state.to(torch.float32) 459 | 460 | encoder_outputs = ModelOutput(last_hidden_state=last_hidden_state, attentions=None) 461 | 462 | if len(encoder_outputs.last_hidden_state.size()) == 1: 463 | encoder_outputs.last_hidden_state = encoder_outputs.last_hidden_state.unsqueeze(0) 464 | if len(prompt_tensors.size()) == 1: 465 | prompt_tensors = prompt_tensors.unsqueeze(0) 466 | 467 | # get decoder output 468 | decoder_output = self.decoder.model.generate( 469 | decoder_input_ids=prompt_tensors, 470 | encoder_outputs=encoder_outputs, 471 | max_length=self.config.max_length, 472 | early_stopping=True, 473 | pad_token_id=self.decoder.tokenizer.pad_token_id, 474 | eos_token_id=self.decoder.tokenizer.eos_token_id, 475 | use_cache=True, 476 | num_beams=1, 477 | bad_words_ids=[[self.decoder.tokenizer.unk_token_id]], 478 | return_dict_in_generate=True, 479 | output_attentions=return_attentions, 480 | ) 481 | 482 | output = {"predictions": list()} 483 | for seq in self.decoder.tokenizer.batch_decode(decoder_output.sequences): 484 | seq = seq.replace(self.decoder.tokenizer.eos_token, "").replace(self.decoder.tokenizer.pad_token, "") 485 | seq = re.sub(r"<.*?>", "", seq, count=1).strip() # remove first task start token 486 | if return_json: 487 | output["predictions"].append(self.token2json(seq)) 488 | else: 489 | output["predictions"].append(seq) 490 | 491 | if return_attentions: 492 | output["attentions"] = { 493 | "self_attentions": decoder_output.decoder_attentions, 494 | "cross_attentions": decoder_output.cross_attentions, 495 | } 496 | 497 | return output 498 | 499 | def json2token(self, obj: Any, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True): 500 | """ 501 | Convert an ordered JSON object into a token sequence 502 | """ 503 | if type(obj) == dict: 504 | if len(obj) == 1 and "text_sequence" in obj: 505 | return obj["text_sequence"] 506 | else: 507 | output = "" 508 | if sort_json_key: 509 | keys = sorted(obj.keys(), reverse=True) 510 | else: 511 | keys = obj.keys() 512 | for k in keys: 513 | if update_special_tokens_for_json_key: 514 | self.decoder.add_special_tokens([fr"", fr""]) 515 | output += ( 516 | fr"" 517 | + self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key) 518 | + fr"" 519 | ) 520 | return output 521 | elif type(obj) == list: 522 | return r"".join( 523 | [self.json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj] 524 | ) 525 | else: 526 | obj = str(obj) 527 | if f"<{obj}/>" in self.decoder.tokenizer.all_special_tokens: 528 | obj = f"<{obj}/>" # for categorical special tokens 529 | return obj 530 | 531 | def token2json(self, tokens, is_inner_value=False): 532 | """ 533 | Convert a (generated) token seuqnce into an ordered JSON format 534 | """ 535 | output = dict() 536 | 537 | while tokens: 538 | start_token = re.search(r"", tokens, re.IGNORECASE) 539 | if start_token is None: 540 | break 541 | key = start_token.group(1) 542 | end_token = re.search(fr"", tokens, re.IGNORECASE) 543 | start_token = start_token.group() 544 | if end_token is None: 545 | tokens = 
tokens.replace(start_token, "") 546 | else: 547 | end_token = end_token.group() 548 | start_token_escaped = re.escape(start_token) 549 | end_token_escaped = re.escape(end_token) 550 | content = re.search(f"{start_token_escaped}(.*?){end_token_escaped}", tokens, re.IGNORECASE) 551 | if content is not None: 552 | content = content.group(1).strip() 553 | if r""): 562 | leaf = leaf.strip() 563 | if ( 564 | leaf in self.decoder.tokenizer.get_added_vocab() 565 | and leaf[0] == "<" 566 | and leaf[-2:] == "/>" 567 | ): 568 | leaf = leaf[1:-2] # for categorical special tokens 569 | output[key].append(leaf) 570 | if len(output[key]) == 1: 571 | output[key] = output[key][0] 572 | 573 | tokens = tokens[tokens.find(end_token) + len(end_token) :].strip() 574 | if tokens[:6] == r"": # non-leaf nodes 575 | return [output] + self.token2json(tokens[6:], is_inner_value=True) 576 | 577 | if len(output): 578 | return [output] if is_inner_value else output 579 | else: 580 | return [] if is_inner_value else {"text_sequence": tokens} 581 | 582 | @classmethod 583 | def from_pretrained( 584 | cls, 585 | pretrained_model_name_or_path: Union[str, bytes, os.PathLike], 586 | *model_args, 587 | **kwargs, 588 | ): 589 | r""" 590 | Instantiate a pretrained donut model from a pre-trained model configuration 591 | 592 | Args: 593 | pretrained_model_name_or_path: 594 | Name of a pretrained model name either registered in huggingface.co. or saved in local, 595 | e.g., `naver-clova-ix/donut-base`, or `naver-clova-ix/donut-base-finetuned-rvlcdip` 596 | """ 597 | model = super(DonutModel, cls).from_pretrained(pretrained_model_name_or_path, revision="official", *model_args, **kwargs) 598 | 599 | # truncate or interplolate position embeddings of donut decoder 600 | max_length = kwargs.get("max_length", model.config.max_position_embeddings) 601 | if ( 602 | max_length != model.config.max_position_embeddings 603 | ): # if max_length of trained model differs max_length you want to train 604 | model.decoder.model.model.decoder.embed_positions.weight = torch.nn.Parameter( 605 | model.decoder.resize_bart_abs_pos_emb( 606 | model.decoder.model.model.decoder.embed_positions.weight, 607 | max_length 608 | + 2, # https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L118-L119 609 | ) 610 | ) 611 | model.config.max_position_embeddings = max_length 612 | 613 | return model 614 | -------------------------------------------------------------------------------- /donut/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Donut 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT License 5 | """ 6 | import json 7 | import os 8 | import random 9 | from collections import defaultdict 10 | from typing import Any, Dict, List, Tuple, Union 11 | 12 | import torch 13 | import zss 14 | from datasets import load_dataset 15 | from nltk import edit_distance 16 | from torch.utils.data import Dataset 17 | from transformers.modeling_utils import PreTrainedModel 18 | from zss import Node 19 | 20 | 21 | def save_json(write_path: Union[str, bytes, os.PathLike], save_obj: Any): 22 | with open(write_path, "w") as f: 23 | json.dump(save_obj, f) 24 | 25 | 26 | def load_json(json_path: Union[str, bytes, os.PathLike]): 27 | with open(json_path, "r") as f: 28 | return json.load(f) 29 | 30 | 31 | class DonutDataset(Dataset): 32 | """ 33 | DonutDataset which is saved in huggingface datasets format. 
(see details in https://huggingface.co/docs/datasets) 34 | Each row, consists of image path(png/jpg/jpeg) and gt data (json/jsonl/txt), 35 | and it will be converted into input_tensor(vectorized image) and input_ids(tokenized string) 36 | 37 | Args: 38 | dataset_name_or_path: name of dataset (available at huggingface.co/datasets) or the path containing image files and metadata.jsonl 39 | ignore_id: ignore_index for torch.nn.CrossEntropyLoss 40 | task_start_token: the special token to be fed to the decoder to conduct the target task 41 | """ 42 | 43 | def __init__( 44 | self, 45 | dataset_name_or_path: str, 46 | donut_model: PreTrainedModel, 47 | max_length: int, 48 | split: str = "train", 49 | ignore_id: int = -100, 50 | task_start_token: str = "", 51 | prompt_end_token: str = None, 52 | sort_json_key: bool = True, 53 | ): 54 | super().__init__() 55 | 56 | self.donut_model = donut_model 57 | self.max_length = max_length 58 | self.split = split 59 | self.ignore_id = ignore_id 60 | self.task_start_token = task_start_token 61 | self.prompt_end_token = prompt_end_token if prompt_end_token else task_start_token 62 | self.sort_json_key = sort_json_key 63 | 64 | self.dataset = load_dataset(dataset_name_or_path, split=self.split) 65 | self.dataset_length = len(self.dataset) 66 | 67 | self.gt_token_sequences = [] 68 | for sample in self.dataset: 69 | ground_truth = json.loads(sample["ground_truth"]) 70 | if "gt_parses" in ground_truth: # when multiple ground truths are available, e.g., docvqa 71 | assert isinstance(ground_truth["gt_parses"], list) 72 | gt_jsons = ground_truth["gt_parses"] 73 | else: 74 | assert "gt_parse" in ground_truth and isinstance(ground_truth["gt_parse"], dict) 75 | gt_jsons = [ground_truth["gt_parse"]] 76 | 77 | self.gt_token_sequences.append( 78 | [ 79 | task_start_token 80 | + self.donut_model.json2token( 81 | gt_json, 82 | update_special_tokens_for_json_key=self.split == "train", 83 | sort_json_key=self.sort_json_key, 84 | ) 85 | + self.donut_model.decoder.tokenizer.eos_token 86 | for gt_json in gt_jsons # load json from list of json 87 | ] 88 | ) 89 | 90 | self.donut_model.decoder.add_special_tokens([self.task_start_token, self.prompt_end_token]) 91 | self.prompt_end_token_id = self.donut_model.decoder.tokenizer.convert_tokens_to_ids(self.prompt_end_token) 92 | 93 | def __len__(self) -> int: 94 | return self.dataset_length 95 | 96 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 97 | """ 98 | Load image from image_path of given dataset_path and convert into input_tensor and labels. 
99 | Convert gt data into input_ids (tokenized string) 100 | 101 | Returns: 102 | input_tensor : preprocessed image 103 | input_ids : tokenized gt_data 104 | labels : masked labels (model doesn't need to predict prompt and pad token) 105 | """ 106 | sample = self.dataset[idx] 107 | 108 | # input_tensor 109 | input_tensor = self.donut_model.encoder.prepare_input(sample["image"], random_padding=self.split == "train") 110 | 111 | # input_ids 112 | processed_parse = random.choice(self.gt_token_sequences[idx]) # can be more than one, e.g., DocVQA Task 1 113 | input_ids = self.donut_model.decoder.tokenizer( 114 | processed_parse, 115 | add_special_tokens=False, 116 | max_length=self.max_length, 117 | padding="max_length", 118 | truncation=True, 119 | return_tensors="pt", 120 | )["input_ids"].squeeze(0) 121 | 122 | if self.split == "train": 123 | labels = input_ids.clone() 124 | labels[ 125 | labels == self.donut_model.decoder.tokenizer.pad_token_id 126 | ] = self.ignore_id # model doesn't need to predict pad token 127 | labels[ 128 | : torch.nonzero(labels == self.prompt_end_token_id).sum() + 1 129 | ] = self.ignore_id # model doesn't need to predict prompt (for VQA) 130 | return input_tensor, input_ids, labels 131 | else: 132 | prompt_end_index = torch.nonzero( 133 | input_ids == self.prompt_end_token_id 134 | ).sum() # return prompt end index instead of target output labels 135 | return input_tensor, input_ids, prompt_end_index, processed_parse 136 | 137 | 138 | class JSONParseEvaluator: 139 | """ 140 | Calculate n-TED(Normalized Tree Edit Distance) based accuracy and F1 accuracy score 141 | """ 142 | 143 | @staticmethod 144 | def flatten(data: dict): 145 | """ 146 | Convert Dictionary into Non-nested Dictionary 147 | Example: 148 | input(dict) 149 | { 150 | "menu": [ 151 | {"name" : ["cake"], "count" : ["2"]}, 152 | {"name" : ["juice"], "count" : ["1"]}, 153 | ] 154 | } 155 | output(list) 156 | [ 157 | ("menu.name", "cake"), 158 | ("menu.count", "2"), 159 | ("menu.name", "juice"), 160 | ("menu.count", "1"), 161 | ] 162 | """ 163 | flatten_data = list() 164 | 165 | def _flatten(value, key=""): 166 | if type(value) is dict: 167 | for child_key, child_value in value.items(): 168 | _flatten(child_value, f"{key}.{child_key}" if key else child_key) 169 | elif type(value) is list: 170 | for value_item in value: 171 | _flatten(value_item, key) 172 | else: 173 | flatten_data.append((key, value)) 174 | 175 | _flatten(data) 176 | return flatten_data 177 | 178 | @staticmethod 179 | def update_cost(node1: Node, node2: Node): 180 | """ 181 | Update cost for tree edit distance. 182 | If both are leaf node, calculate string edit distance between two labels (special token '' will be ignored). 183 | If one of them is leaf node, cost is length of string in leaf node + 1. 184 | If neither are leaf node, cost is 0 if label1 is same with label2 othewise 1 185 | """ 186 | label1 = node1.label 187 | label2 = node2.label 188 | label1_leaf = "" in label1 189 | label2_leaf = "" in label2 190 | if label1_leaf == True and label2_leaf == True: 191 | return edit_distance(label1.replace("", ""), label2.replace("", "")) 192 | elif label1_leaf == False and label2_leaf == True: 193 | return 1 + len(label2.replace("", "")) 194 | elif label1_leaf == True and label2_leaf == False: 195 | return 1 + len(label1.replace("", "")) 196 | else: 197 | return int(label1 != label2) 198 | 199 | @staticmethod 200 | def insert_and_remove_cost(node: Node): 201 | """ 202 | Insert and remove cost for tree edit distance. 
203 | If leaf node, cost is length of label name. 204 | Otherwise, 1 205 | """ 206 | label = node.label 207 | if "" in label: 208 | return len(label.replace("", "")) 209 | else: 210 | return 1 211 | 212 | def normalize_dict(self, data: Union[Dict, List, Any]): 213 | """ 214 | Sort by value, while iterate over element if data is list 215 | """ 216 | if not data: 217 | return {} 218 | 219 | if isinstance(data, dict): 220 | new_data = dict() 221 | for key in sorted(data.keys(), key=lambda k: (len(k), k)): 222 | value = self.normalize_dict(data[key]) 223 | if value: 224 | if not isinstance(value, list): 225 | value = [value] 226 | new_data[key] = value 227 | 228 | elif isinstance(data, list): 229 | if all(isinstance(item, dict) for item in data): 230 | new_data = [] 231 | for item in data: 232 | item = self.normalize_dict(item) 233 | if item: 234 | new_data.append(item) 235 | else: 236 | new_data = [str(item).strip() for item in data if type(item) in {str, int, float} and str(item).strip()] 237 | else: 238 | new_data = [str(data).strip()] 239 | 240 | return new_data 241 | 242 | def cal_f1(self, preds: List[dict], answers: List[dict]): 243 | """ 244 | Calculate global F1 accuracy score (field-level, micro-averaged) by counting all true positives, false negatives and false positives 245 | """ 246 | total_tp, total_fn_or_fp = 0, 0 247 | for pred, answer in zip(preds, answers): 248 | pred, answer = self.flatten(self.normalize_dict(pred)), self.flatten(self.normalize_dict(answer)) 249 | for field in pred: 250 | if field in answer: 251 | total_tp += 1 252 | answer.remove(field) 253 | else: 254 | total_fn_or_fp += 1 255 | total_fn_or_fp += len(answer) 256 | return total_tp / (total_tp + total_fn_or_fp / 2) 257 | 258 | def construct_tree_from_dict(self, data: Union[Dict, List], node_name: str = None): 259 | """ 260 | Convert Dictionary into Tree 261 | 262 | Example: 263 | input(dict) 264 | 265 | { 266 | "menu": [ 267 | {"name" : ["cake"], "count" : ["2"]}, 268 | {"name" : ["juice"], "count" : ["1"]}, 269 | ] 270 | } 271 | 272 | output(tree) 273 | 274 | | 275 | menu 276 | / \ 277 | 278 | / | | \ 279 | name count name count 280 | / | | \ 281 | cake 2 juice 1 282 | """ 283 | if node_name is None: 284 | node_name = "" 285 | 286 | node = Node(node_name) 287 | 288 | if isinstance(data, dict): 289 | for key, value in data.items(): 290 | kid_node = self.construct_tree_from_dict(value, key) 291 | node.addkid(kid_node) 292 | elif isinstance(data, list): 293 | if all(isinstance(item, dict) for item in data): 294 | for item in data: 295 | kid_node = self.construct_tree_from_dict( 296 | item, 297 | "", 298 | ) 299 | node.addkid(kid_node) 300 | else: 301 | for item in data: 302 | node.addkid(Node(f"{item}")) 303 | else: 304 | raise Exception(data, node_name) 305 | return node 306 | 307 | def cal_acc(self, pred: dict, answer: dict): 308 | """ 309 | Calculate normalized tree edit distance(nTED) based accuracy. 310 | 1) Construct tree from dict, 311 | 2) Get tree distance with insert/remove/update cost, 312 | 3) Divide distance with GT tree size (i.e., nTED), 313 | 4) Calculate nTED based accuracy. (= max(1 - nTED, 0 ). 
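        Example:
            A small usage sketch with toy inputs (values are illustrative):

                evaluator = JSONParseEvaluator()
                pred = {"menu": [{"name": "cake", "count": "2"}]}
                answer = {"menu": [{"name": "cake", "count": "3"}]}
                acc = evaluator.cal_acc(pred, answer)      # in [0, 1], nTED-based
                f1 = evaluator.cal_f1([pred], [answer])    # field-level micro F1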
314 | """ 315 | pred = self.construct_tree_from_dict(self.normalize_dict(pred)) 316 | answer = self.construct_tree_from_dict(self.normalize_dict(answer)) 317 | return max( 318 | 0, 319 | 1 320 | - ( 321 | zss.distance( 322 | pred, 323 | answer, 324 | get_children=zss.Node.get_children, 325 | insert_cost=self.insert_and_remove_cost, 326 | remove_cost=self.insert_and_remove_cost, 327 | update_cost=self.update_cost, 328 | return_operations=False, 329 | ) 330 | / zss.distance( 331 | self.construct_tree_from_dict(self.normalize_dict({})), 332 | answer, 333 | get_children=zss.Node.get_children, 334 | insert_cost=self.insert_and_remove_cost, 335 | remove_cost=self.insert_and_remove_cost, 336 | update_cost=self.update_cost, 337 | return_operations=False, 338 | ) 339 | ), 340 | ) 341 | -------------------------------------------------------------------------------- /lightning_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | Donut 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT License 5 | """ 6 | import math 7 | import random 8 | import re 9 | from pathlib import Path 10 | 11 | import numpy as np 12 | import pytorch_lightning as pl 13 | import torch 14 | from nltk import edit_distance 15 | from pytorch_lightning.utilities import rank_zero_only 16 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 17 | from torch.nn.utils.rnn import pad_sequence 18 | from torch.optim.lr_scheduler import LambdaLR 19 | from torch.utils.data import DataLoader 20 | 21 | from donut import DonutConfig, DonutModel 22 | 23 | 24 | class DonutModelPLModule(pl.LightningModule): 25 | def __init__(self, config): 26 | super().__init__() 27 | self.config = config 28 | 29 | if self.config.get("pretrained_model_name_or_path", False): 30 | self.model = DonutModel.from_pretrained( 31 | self.config.pretrained_model_name_or_path, 32 | input_size=self.config.input_size, 33 | max_length=self.config.max_length, 34 | align_long_axis=self.config.align_long_axis, 35 | ignore_mismatched_sizes=True, 36 | ) 37 | else: 38 | self.model = DonutModel( 39 | config=DonutConfig( 40 | input_size=self.config.input_size, 41 | max_length=self.config.max_length, 42 | align_long_axis=self.config.align_long_axis, 43 | # with DonutConfig, the architecture customization is available, e.g., 44 | # encoder_layer=[2,2,14,2], decoder_layer=4, ... 
45 |             )
46 |         )
47 |         self.pytorch_lightning_version_is_1 = int(pl.__version__[0]) < 2
48 |         self.num_of_loaders = len(self.config.dataset_name_or_paths)
49 |
50 |     def training_step(self, batch, batch_idx):
51 |         image_tensors, decoder_input_ids, decoder_labels = list(), list(), list()
52 |         for batch_data in batch:
53 |             image_tensors.append(batch_data[0])
54 |             decoder_input_ids.append(batch_data[1][:, :-1])
55 |             decoder_labels.append(batch_data[2][:, 1:])
56 |         image_tensors = torch.cat(image_tensors)
57 |         decoder_input_ids = torch.cat(decoder_input_ids)
58 |         decoder_labels = torch.cat(decoder_labels)
59 |         loss = self.model(image_tensors, decoder_input_ids, decoder_labels)[0]
60 |         self.log_dict({"train_loss": loss}, sync_dist=True)
61 |         if not self.pytorch_lightning_version_is_1:
62 |             self.log('loss', loss, prog_bar=True)
63 |         return loss
64 |
65 |     def on_validation_epoch_start(self) -> None:
66 |         super().on_validation_epoch_start()
67 |         self.validation_step_outputs = [[] for _ in range(self.num_of_loaders)]
68 |         return
69 |
70 |     def validation_step(self, batch, batch_idx, dataloader_idx=0):
71 |         image_tensors, decoder_input_ids, prompt_end_idxs, answers = batch
72 |         decoder_prompts = pad_sequence(
73 |             [input_id[: end_idx + 1] for input_id, end_idx in zip(decoder_input_ids, prompt_end_idxs)],
74 |             batch_first=True,
75 |         )
76 |
77 |         preds = self.model.inference(
78 |             image_tensors=image_tensors,
79 |             prompt_tensors=decoder_prompts,
80 |             return_json=False,
81 |             return_attentions=False,
82 |         )["predictions"]
83 |
84 |         scores = list()
85 |         for pred, answer in zip(preds, answers):
86 |             pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
87 |             answer = re.sub(r"<.*?>", "", answer, count=1)
88 |             answer = answer.replace(self.model.decoder.tokenizer.eos_token, "")
89 |             scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))
90 |
91 |             if self.config.get("verbose", False) and len(scores) == 1:
92 |                 self.print(f"Prediction: {pred}")
93 |                 self.print(f"    Answer: {answer}")
94 |                 self.print(f" Normed ED: {scores[0]}")
95 |
96 |         self.validation_step_outputs[dataloader_idx].append(scores)
97 |
98 |         return scores
99 |
100 |     def on_validation_epoch_end(self):
101 |         assert len(self.validation_step_outputs) == self.num_of_loaders
102 |         cnt = [0] * self.num_of_loaders
103 |         total_metric = [0] * self.num_of_loaders
104 |         val_metric = [0] * self.num_of_loaders
105 |         for i, results in enumerate(self.validation_step_outputs):
106 |             for scores in results:
107 |                 cnt[i] += len(scores)
108 |                 total_metric[i] += np.sum(scores)
109 |             val_metric[i] = total_metric[i] / cnt[i]
110 |             val_metric_name = f"val_metric_{i}th_dataset"
111 |             self.log_dict({val_metric_name: val_metric[i]}, sync_dist=True)
112 |         self.log_dict({"val_metric": np.sum(total_metric) / np.sum(cnt)}, sync_dist=True)
113 |
114 |     def configure_optimizers(self):
115 |
116 |         max_iter = None
117 |
118 |         if int(self.config.get("max_epochs", -1)) > 0:
119 |             assert len(self.config.train_batch_sizes) == 1, "Set max_epochs only if the number of datasets is 1"
120 |             max_iter = (self.config.max_epochs * self.config.num_training_samples_per_epoch) / (
121 |                 self.config.train_batch_sizes[0] * torch.cuda.device_count() * self.config.get("num_nodes", 1)
122 |             )
123 |
124 |         if int(self.config.get("max_steps", -1)) > 0:
125 |             max_iter = min(self.config.max_steps, max_iter) if max_iter is not None else self.config.max_steps
126 |
127 |         assert max_iter is not None
128 |         optimizer = torch.optim.Adam(self.parameters(), lr=self.config.lr)
129 |         scheduler = {
130 |             "scheduler": self.cosine_scheduler(optimizer, max_iter,
self.config.warmup_steps), 131 | "name": "learning_rate", 132 | "interval": "step", 133 | } 134 | return [optimizer], [scheduler] 135 | 136 | @staticmethod 137 | def cosine_scheduler(optimizer, training_steps, warmup_steps): 138 | def lr_lambda(current_step): 139 | if current_step < warmup_steps: 140 | return current_step / max(1, warmup_steps) 141 | progress = current_step - warmup_steps 142 | progress /= max(1, training_steps - warmup_steps) 143 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) 144 | 145 | return LambdaLR(optimizer, lr_lambda) 146 | 147 | @rank_zero_only 148 | def on_save_checkpoint(self, checkpoint): 149 | save_path = Path(self.config.result_path) / self.config.exp_name / self.config.exp_version 150 | self.model.save_pretrained(save_path) 151 | self.model.decoder.tokenizer.save_pretrained(save_path) 152 | 153 | 154 | class DonutDataPLModule(pl.LightningDataModule): 155 | def __init__(self, config): 156 | super().__init__() 157 | self.config = config 158 | self.train_batch_sizes = self.config.train_batch_sizes 159 | self.val_batch_sizes = self.config.val_batch_sizes 160 | self.train_datasets = [] 161 | self.val_datasets = [] 162 | self.g = torch.Generator() 163 | self.g.manual_seed(self.config.seed) 164 | 165 | def train_dataloader(self): 166 | loaders = list() 167 | for train_dataset, batch_size in zip(self.train_datasets, self.train_batch_sizes): 168 | loaders.append( 169 | DataLoader( 170 | train_dataset, 171 | batch_size=batch_size, 172 | num_workers=self.config.num_workers, 173 | pin_memory=True, 174 | worker_init_fn=self.seed_worker, 175 | generator=self.g, 176 | shuffle=True, 177 | ) 178 | ) 179 | return loaders 180 | 181 | def val_dataloader(self): 182 | loaders = list() 183 | for val_dataset, batch_size in zip(self.val_datasets, self.val_batch_sizes): 184 | loaders.append( 185 | DataLoader( 186 | val_dataset, 187 | batch_size=batch_size, 188 | pin_memory=True, 189 | shuffle=False, 190 | ) 191 | ) 192 | return loaders 193 | 194 | @staticmethod 195 | def seed_worker(wordker_id): 196 | worker_seed = torch.initial_seed() % 2 ** 32 197 | np.random.seed(worker_seed) 198 | random.seed(worker_seed) 199 | -------------------------------------------------------------------------------- /misc/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/misc/overview.png -------------------------------------------------------------------------------- /misc/sample_image_cord_test_receipt_00004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/misc/sample_image_cord_test_receipt_00004.png -------------------------------------------------------------------------------- /misc/sample_image_donut_document.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/misc/sample_image_donut_document.png -------------------------------------------------------------------------------- /misc/sample_synthdog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/misc/sample_synthdog.png -------------------------------------------------------------------------------- /misc/screenshot_gradio_demos.png: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/misc/screenshot_gradio_demos.png
--------------------------------------------------------------------------------
/result/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """
2 | Donut
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT License
5 | """
6 | import os
7 | from setuptools import find_packages, setup
8 |
9 | ROOT = os.path.abspath(os.path.dirname(__file__))
10 |
11 |
12 | def read_version():
13 |     data = {}
14 |     path = os.path.join(ROOT, "donut", "_version.py")
15 |     with open(path, "r", encoding="utf-8") as f:
16 |         exec(f.read(), data)
17 |     return data["__version__"]
18 |
19 |
20 | def read_long_description():
21 |     path = os.path.join(ROOT, "README.md")
22 |     with open(path, "r", encoding="utf-8") as f:
23 |         text = f.read()
24 |     return text
25 |
26 |
27 | setup(
28 |     name="donut-python",
29 |     version=read_version(),
30 |     description="OCR-free Document Understanding Transformer",
31 |     long_description=read_long_description(),
32 |     long_description_content_type="text/markdown",
33 |     author="Geewook Kim, Teakgyu Hong, Moonbin Yim, JeongYeon Nam, Jinyoung Park, Jinyeong Yim, Wonseok Hwang, Sangdoo Yun, Dongyoon Han, Seunghyun Park",
34 |     author_email="gwkim.rsrch@gmail.com",
35 |     url="https://github.com/clovaai/donut",
36 |     license="MIT",
37 |     packages=find_packages(
38 |         exclude=[
39 |             "config",
40 |             "dataset",
41 |             "misc",
42 |             "result",
43 |             "synthdog",
44 |             "app.py",
45 |             "lightning_module.py",
46 |             "README.md",
47 |             "train.py",
48 |             "test.py",
49 |         ]
50 |     ),
51 |     python_requires=">=3.7",
52 |     install_requires=[
53 |         "transformers>=4.11.3",
54 |         "timm",
55 |         "datasets[vision]",
56 |         "pytorch-lightning>=1.6.4",
57 |         "nltk",
58 |         "sentencepiece",
59 |         "zss",
60 |         "sconf>=0.2.3",
61 |     ],
62 |     classifiers=[
63 |         "Intended Audience :: Developers",
64 |         "Intended Audience :: Information Technology",
65 |         "Intended Audience :: Science/Research",
66 |         "License :: OSI Approved :: MIT License",
67 |         "Programming Language :: Python",
68 |         "Programming Language :: Python :: 3",
69 |         "Programming Language :: Python :: 3.7",
70 |         "Programming Language :: Python :: 3.8",
71 |         "Programming Language :: Python :: 3.9",
72 |         "Programming Language :: Python :: 3.10",
73 |         "Topic :: Scientific/Engineering :: Artificial Intelligence",
74 |         "Topic :: Software Development :: Libraries",
75 |         "Topic :: Software Development :: Libraries :: Python Modules",
76 |     ],
77 | )
78 |
--------------------------------------------------------------------------------
/synthdog/README.md:
--------------------------------------------------------------------------------
1 | # SynthDoG 🐶: Synthetic Document Generator
2 |
3 | SynthDoG is a synthetic document generator for visual document understanding (VDU).
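For orientation, each generated sample ends up on disk as a JPEG plus one line appended to a `metadata.jsonl` file for its split (this is what `template.py`, shown later in this listing, writes out). The snippet below is a minimal sketch of reading one such record back; the file name and text are illustrative placeholders, not actual outputs.

```python
import json

# One line of <output_dir>/train/metadata.jsonl as written by template.py;
# "image_0.jpg" and "Hello SynthDoG" are illustrative values only.
line = '{"file_name": "image_0.jpg", "ground_truth": "{\\"gt_parse\\": {\\"text_sequence\\": \\"Hello SynthDoG\\"}}"}'

record = json.loads(line)
# "ground_truth" is itself a JSON string holding the gt_parse payload.
gt_parse = json.loads(record["ground_truth"])["gt_parse"]
print(record["file_name"])        # image_0.jpg
print(gt_parse["text_sequence"])  # Hello SynthDoG
```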
4 | 5 | ![image](../misc/sample_synthdog.png) 6 | 7 | ## Prerequisites 8 | 9 | - python>=3.6 10 | - [synthtiger](https://github.com/clovaai/synthtiger) (`pip install synthtiger`) 11 | 12 | ## Usage 13 | 14 | ```bash 15 | # Set environment variable (for macOS) 16 | $ export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES 17 | 18 | synthtiger -o ./outputs/SynthDoG_en -c 50 -w 4 -v template.py SynthDoG config_en.yaml 19 | 20 | {'config': 'config_en.yaml', 21 | 'count': 50, 22 | 'name': 'SynthDoG', 23 | 'output': './outputs/SynthDoG_en', 24 | 'script': 'template.py', 25 | 'verbose': True, 26 | 'worker': 4} 27 | {'aspect_ratio': [1, 2], 28 | . 29 | . 30 | 'quality': [50, 95], 31 | 'short_size': [720, 1024]} 32 | Generated 1 data (task 3) 33 | Generated 2 data (task 0) 34 | Generated 3 data (task 1) 35 | . 36 | . 37 | Generated 49 data (task 48) 38 | Generated 50 data (task 49) 39 | 46.32 seconds elapsed 40 | ``` 41 | 42 | Some important arguments: 43 | 44 | - `-o` : directory path to save data. 45 | - `-c` : number of data to generate. 46 | - `-w` : number of workers. 47 | - `-s` : random seed. 48 | - `-v` : print error messages. 49 | 50 | To generate ECJK samples: 51 | ```bash 52 | # english 53 | synthtiger -o {dataset_path} -c {num_of_data} -w {num_of_workers} -v template.py SynthDoG config_en.yaml 54 | 55 | # chinese 56 | synthtiger -o {dataset_path} -c {num_of_data} -w {num_of_workers} -v template.py SynthDoG config_zh.yaml 57 | 58 | # japanese 59 | synthtiger -o {dataset_path} -c {num_of_data} -w {num_of_workers} -v template.py SynthDoG config_ja.yaml 60 | 61 | # korean 62 | synthtiger -o {dataset_path} -c {num_of_data} -w {num_of_workers} -v template.py SynthDoG config_ko.yaml 63 | ``` 64 | -------------------------------------------------------------------------------- /synthdog/config_en.yaml: -------------------------------------------------------------------------------- 1 | quality: [50, 95] 2 | landscape: 0.5 3 | short_size: [720, 1024] 4 | aspect_ratio: [1, 2] 5 | 6 | background: 7 | image: 8 | paths: [resources/background] 9 | weights: [1] 10 | 11 | effect: 12 | args: 13 | # gaussian blur 14 | - prob: 1 15 | args: 16 | sigma: [0, 10] 17 | 18 | document: 19 | fullscreen: 0.5 20 | landscape: 0.5 21 | short_size: [480, 1024] 22 | aspect_ratio: [1, 2] 23 | 24 | paper: 25 | image: 26 | paths: [resources/paper] 27 | weights: [1] 28 | alpha: [0, 0.2] 29 | grayscale: 1 30 | crop: 1 31 | 32 | content: 33 | margin: [0, 0.1] 34 | text: 35 | path: resources/corpus/enwiki.txt 36 | font: 37 | paths: [resources/font/en] 38 | weights: [1] 39 | bold: 0 40 | layout: 41 | text_scale: [0.0334, 0.1] 42 | max_row: 10 43 | max_col: 3 44 | fill: [0.5, 1] 45 | full: 0.1 46 | align: [left, right, center] 47 | stack_spacing: [0.0334, 0.0334] 48 | stack_fill: [0.5, 1] 49 | stack_full: 0.1 50 | textbox: 51 | fill: [0.5, 1] 52 | textbox_color: 53 | prob: 0.2 54 | args: 55 | gray: [0, 64] 56 | colorize: 1 57 | content_color: 58 | prob: 0.2 59 | args: 60 | gray: [0, 64] 61 | colorize: 1 62 | 63 | effect: 64 | args: 65 | # elastic distortion 66 | - prob: 1 67 | args: 68 | alpha: [0, 1] 69 | sigma: [0, 0.5] 70 | # gaussian noise 71 | - prob: 1 72 | args: 73 | scale: [0, 8] 74 | per_channel: 0 75 | # perspective 76 | - prob: 1 77 | args: 78 | weights: [750, 50, 50, 25, 25, 25, 25, 50] 79 | args: 80 | - percents: [[0.75, 1], [0.75, 1], [0.75, 1], [0.75, 1]] 81 | - percents: [[0.75, 1], [1, 1], [0.75, 1], [1, 1]] 82 | - percents: [[1, 1], [0.75, 1], [1, 1], [0.75, 1]] 83 | - percents: [[0.75, 1], [1, 1], [1, 1], [1, 1]] 84 | 
- percents: [[1, 1], [0.75, 1], [1, 1], [1, 1]] 85 | - percents: [[1, 1], [1, 1], [0.75, 1], [1, 1]] 86 | - percents: [[1, 1], [1, 1], [1, 1], [0.75, 1]] 87 | - percents: [[1, 1], [1, 1], [1, 1], [1, 1]] 88 | 89 | effect: 90 | args: 91 | # color 92 | - prob: 0.2 93 | args: 94 | rgb: [[0, 255], [0, 255], [0, 255]] 95 | alpha: [0, 0.2] 96 | # shadow 97 | - prob: 1 98 | args: 99 | intensity: [0, 160] 100 | amount: [0, 1] 101 | smoothing: [0.5, 1] 102 | bidirectional: 0 103 | # contrast 104 | - prob: 1 105 | args: 106 | alpha: [1, 1.5] 107 | # brightness 108 | - prob: 1 109 | args: 110 | beta: [-48, 0] 111 | # motion blur 112 | - prob: 0.5 113 | args: 114 | k: [3, 5] 115 | angle: [0, 360] 116 | # gaussian blur 117 | - prob: 1 118 | args: 119 | sigma: [0, 1.5] 120 | -------------------------------------------------------------------------------- /synthdog/config_ja.yaml: -------------------------------------------------------------------------------- 1 | quality: [50, 95] 2 | landscape: 0.5 3 | short_size: [720, 1024] 4 | aspect_ratio: [1, 2] 5 | 6 | background: 7 | image: 8 | paths: [resources/background] 9 | weights: [1] 10 | 11 | effect: 12 | args: 13 | # gaussian blur 14 | - prob: 1 15 | args: 16 | sigma: [0, 10] 17 | 18 | document: 19 | fullscreen: 0.5 20 | landscape: 0.5 21 | short_size: [480, 1024] 22 | aspect_ratio: [1, 2] 23 | 24 | paper: 25 | image: 26 | paths: [resources/paper] 27 | weights: [1] 28 | alpha: [0, 0.2] 29 | grayscale: 1 30 | crop: 1 31 | 32 | content: 33 | margin: [0, 0.1] 34 | text: 35 | path: resources/corpus/jawiki.txt 36 | font: 37 | paths: [resources/font/ja] 38 | weights: [1] 39 | bold: 0 40 | layout: 41 | text_scale: [0.0334, 0.1] 42 | max_row: 10 43 | max_col: 3 44 | fill: [0.5, 1] 45 | full: 0.1 46 | align: [left, right, center] 47 | stack_spacing: [0.0334, 0.0334] 48 | stack_fill: [0.5, 1] 49 | stack_full: 0.1 50 | textbox: 51 | fill: [0.5, 1] 52 | textbox_color: 53 | prob: 0.2 54 | args: 55 | gray: [0, 64] 56 | colorize: 1 57 | content_color: 58 | prob: 0.2 59 | args: 60 | gray: [0, 64] 61 | colorize: 1 62 | 63 | effect: 64 | args: 65 | # elastic distortion 66 | - prob: 1 67 | args: 68 | alpha: [0, 1] 69 | sigma: [0, 0.5] 70 | # gaussian noise 71 | - prob: 1 72 | args: 73 | scale: [0, 8] 74 | per_channel: 0 75 | # perspective 76 | - prob: 1 77 | args: 78 | weights: [750, 50, 50, 25, 25, 25, 25, 50] 79 | args: 80 | - percents: [[0.75, 1], [0.75, 1], [0.75, 1], [0.75, 1]] 81 | - percents: [[0.75, 1], [1, 1], [0.75, 1], [1, 1]] 82 | - percents: [[1, 1], [0.75, 1], [1, 1], [0.75, 1]] 83 | - percents: [[0.75, 1], [1, 1], [1, 1], [1, 1]] 84 | - percents: [[1, 1], [0.75, 1], [1, 1], [1, 1]] 85 | - percents: [[1, 1], [1, 1], [0.75, 1], [1, 1]] 86 | - percents: [[1, 1], [1, 1], [1, 1], [0.75, 1]] 87 | - percents: [[1, 1], [1, 1], [1, 1], [1, 1]] 88 | 89 | effect: 90 | args: 91 | # color 92 | - prob: 0.2 93 | args: 94 | rgb: [[0, 255], [0, 255], [0, 255]] 95 | alpha: [0, 0.2] 96 | # shadow 97 | - prob: 1 98 | args: 99 | intensity: [0, 160] 100 | amount: [0, 1] 101 | smoothing: [0.5, 1] 102 | bidirectional: 0 103 | # contrast 104 | - prob: 1 105 | args: 106 | alpha: [1, 1.5] 107 | # brightness 108 | - prob: 1 109 | args: 110 | beta: [-48, 0] 111 | # motion blur 112 | - prob: 0.5 113 | args: 114 | k: [3, 5] 115 | angle: [0, 360] 116 | # gaussian blur 117 | - prob: 1 118 | args: 119 | sigma: [0, 1.5] 120 | -------------------------------------------------------------------------------- /synthdog/config_ko.yaml: 
-------------------------------------------------------------------------------- 1 | quality: [50, 95] 2 | landscape: 0.5 3 | short_size: [720, 1024] 4 | aspect_ratio: [1, 2] 5 | 6 | background: 7 | image: 8 | paths: [resources/background] 9 | weights: [1] 10 | 11 | effect: 12 | args: 13 | # gaussian blur 14 | - prob: 1 15 | args: 16 | sigma: [0, 10] 17 | 18 | document: 19 | fullscreen: 0.5 20 | landscape: 0.5 21 | short_size: [480, 1024] 22 | aspect_ratio: [1, 2] 23 | 24 | paper: 25 | image: 26 | paths: [resources/paper] 27 | weights: [1] 28 | alpha: [0, 0.2] 29 | grayscale: 1 30 | crop: 1 31 | 32 | content: 33 | margin: [0, 0.1] 34 | text: 35 | path: resources/corpus/kowiki.txt 36 | font: 37 | paths: [resources/font/ko] 38 | weights: [1] 39 | bold: 0 40 | layout: 41 | text_scale: [0.0334, 0.1] 42 | max_row: 10 43 | max_col: 3 44 | fill: [0.5, 1] 45 | full: 0.1 46 | align: [left, right, center] 47 | stack_spacing: [0.0334, 0.0334] 48 | stack_fill: [0.5, 1] 49 | stack_full: 0.1 50 | textbox: 51 | fill: [0.5, 1] 52 | textbox_color: 53 | prob: 0.2 54 | args: 55 | gray: [0, 64] 56 | colorize: 1 57 | content_color: 58 | prob: 0.2 59 | args: 60 | gray: [0, 64] 61 | colorize: 1 62 | 63 | effect: 64 | args: 65 | # elastic distortion 66 | - prob: 1 67 | args: 68 | alpha: [0, 1] 69 | sigma: [0, 0.5] 70 | # gaussian noise 71 | - prob: 1 72 | args: 73 | scale: [0, 8] 74 | per_channel: 0 75 | # perspective 76 | - prob: 1 77 | args: 78 | weights: [750, 50, 50, 25, 25, 25, 25, 50] 79 | args: 80 | - percents: [[0.75, 1], [0.75, 1], [0.75, 1], [0.75, 1]] 81 | - percents: [[0.75, 1], [1, 1], [0.75, 1], [1, 1]] 82 | - percents: [[1, 1], [0.75, 1], [1, 1], [0.75, 1]] 83 | - percents: [[0.75, 1], [1, 1], [1, 1], [1, 1]] 84 | - percents: [[1, 1], [0.75, 1], [1, 1], [1, 1]] 85 | - percents: [[1, 1], [1, 1], [0.75, 1], [1, 1]] 86 | - percents: [[1, 1], [1, 1], [1, 1], [0.75, 1]] 87 | - percents: [[1, 1], [1, 1], [1, 1], [1, 1]] 88 | 89 | effect: 90 | args: 91 | # color 92 | - prob: 0.2 93 | args: 94 | rgb: [[0, 255], [0, 255], [0, 255]] 95 | alpha: [0, 0.2] 96 | # shadow 97 | - prob: 1 98 | args: 99 | intensity: [0, 160] 100 | amount: [0, 1] 101 | smoothing: [0.5, 1] 102 | bidirectional: 0 103 | # contrast 104 | - prob: 1 105 | args: 106 | alpha: [1, 1.5] 107 | # brightness 108 | - prob: 1 109 | args: 110 | beta: [-48, 0] 111 | # motion blur 112 | - prob: 0.5 113 | args: 114 | k: [3, 5] 115 | angle: [0, 360] 116 | # gaussian blur 117 | - prob: 1 118 | args: 119 | sigma: [0, 1.5] 120 | -------------------------------------------------------------------------------- /synthdog/config_zh.yaml: -------------------------------------------------------------------------------- 1 | quality: [50, 95] 2 | landscape: 0.5 3 | short_size: [720, 1024] 4 | aspect_ratio: [1, 2] 5 | 6 | background: 7 | image: 8 | paths: [resources/background] 9 | weights: [1] 10 | 11 | effect: 12 | args: 13 | # gaussian blur 14 | - prob: 1 15 | args: 16 | sigma: [0, 10] 17 | 18 | document: 19 | fullscreen: 0.5 20 | landscape: 0.5 21 | short_size: [480, 1024] 22 | aspect_ratio: [1, 2] 23 | 24 | paper: 25 | image: 26 | paths: [resources/paper] 27 | weights: [1] 28 | alpha: [0, 0.2] 29 | grayscale: 1 30 | crop: 1 31 | 32 | content: 33 | margin: [0, 0.1] 34 | text: 35 | path: resources/corpus/zhwiki.txt 36 | font: 37 | paths: [resources/font/zh] 38 | weights: [1] 39 | bold: 0 40 | layout: 41 | text_scale: [0.0334, 0.1] 42 | max_row: 10 43 | max_col: 3 44 | fill: [0.5, 1] 45 | full: 0.1 46 | align: [left, right, center] 47 | stack_spacing: [0.0334, 
0.0334] 48 | stack_fill: [0.5, 1] 49 | stack_full: 0.1 50 | textbox: 51 | fill: [0.5, 1] 52 | textbox_color: 53 | prob: 0.2 54 | args: 55 | gray: [0, 64] 56 | colorize: 1 57 | content_color: 58 | prob: 0.2 59 | args: 60 | gray: [0, 64] 61 | colorize: 1 62 | 63 | effect: 64 | args: 65 | # elastic distortion 66 | - prob: 1 67 | args: 68 | alpha: [0, 1] 69 | sigma: [0, 0.5] 70 | # gaussian noise 71 | - prob: 1 72 | args: 73 | scale: [0, 8] 74 | per_channel: 0 75 | # perspective 76 | - prob: 1 77 | args: 78 | weights: [750, 50, 50, 25, 25, 25, 25, 50] 79 | args: 80 | - percents: [[0.75, 1], [0.75, 1], [0.75, 1], [0.75, 1]] 81 | - percents: [[0.75, 1], [1, 1], [0.75, 1], [1, 1]] 82 | - percents: [[1, 1], [0.75, 1], [1, 1], [0.75, 1]] 83 | - percents: [[0.75, 1], [1, 1], [1, 1], [1, 1]] 84 | - percents: [[1, 1], [0.75, 1], [1, 1], [1, 1]] 85 | - percents: [[1, 1], [1, 1], [0.75, 1], [1, 1]] 86 | - percents: [[1, 1], [1, 1], [1, 1], [0.75, 1]] 87 | - percents: [[1, 1], [1, 1], [1, 1], [1, 1]] 88 | 89 | effect: 90 | args: 91 | # color 92 | - prob: 0.2 93 | args: 94 | rgb: [[0, 255], [0, 255], [0, 255]] 95 | alpha: [0, 0.2] 96 | # shadow 97 | - prob: 1 98 | args: 99 | intensity: [0, 160] 100 | amount: [0, 1] 101 | smoothing: [0.5, 1] 102 | bidirectional: 0 103 | # contrast 104 | - prob: 1 105 | args: 106 | alpha: [1, 1.5] 107 | # brightness 108 | - prob: 1 109 | args: 110 | beta: [-48, 0] 111 | # motion blur 112 | - prob: 0.5 113 | args: 114 | k: [3, 5] 115 | angle: [0, 360] 116 | # gaussian blur 117 | - prob: 1 118 | args: 119 | sigma: [0, 1.5] 120 | -------------------------------------------------------------------------------- /synthdog/elements/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Donut 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT License 5 | """ 6 | from elements.background import Background 7 | from elements.content import Content 8 | from elements.document import Document 9 | from elements.paper import Paper 10 | from elements.textbox import TextBox 11 | 12 | __all__ = ["Background", "Content", "Document", "Paper", "TextBox"] 13 | -------------------------------------------------------------------------------- /synthdog/elements/background.py: -------------------------------------------------------------------------------- 1 | """ 2 | Donut 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT License 5 | """ 6 | from synthtiger import components, layers 7 | 8 | 9 | class Background: 10 | def __init__(self, config): 11 | self.image = components.BaseTexture(**config.get("image", {})) 12 | self.effect = components.Iterator( 13 | [ 14 | components.Switch(components.GaussianBlur()), 15 | ], 16 | **config.get("effect", {}) 17 | ) 18 | 19 | def generate(self, size): 20 | bg_layer = layers.RectLayer(size, (255, 255, 255, 255)) 21 | self.image.apply([bg_layer]) 22 | self.effect.apply([bg_layer]) 23 | 24 | return bg_layer 25 | -------------------------------------------------------------------------------- /synthdog/elements/content.py: -------------------------------------------------------------------------------- 1 | """ 2 | Donut 3 | Copyright (c) 2022-present NAVER Corp. 
4 | MIT License 5 | """ 6 | from collections import OrderedDict 7 | 8 | import numpy as np 9 | from synthtiger import components 10 | 11 | from elements.textbox import TextBox 12 | from layouts import GridStack 13 | 14 | 15 | class TextReader: 16 | def __init__(self, path, cache_size=2 ** 28, block_size=2 ** 20): 17 | self.fp = open(path, "r", encoding="utf-8") 18 | self.length = 0 19 | self.offsets = [0] 20 | self.cache = OrderedDict() 21 | self.cache_size = cache_size 22 | self.block_size = block_size 23 | self.bucket_size = cache_size // block_size 24 | self.idx = 0 25 | 26 | while True: 27 | text = self.fp.read(self.block_size) 28 | if not text: 29 | break 30 | self.length += len(text) 31 | self.offsets.append(self.fp.tell()) 32 | 33 | def __len__(self): 34 | return self.length 35 | 36 | def __iter__(self): 37 | return self 38 | 39 | def __next__(self): 40 | char = self.get() 41 | self.next() 42 | return char 43 | 44 | def move(self, idx): 45 | self.idx = idx 46 | 47 | def next(self): 48 | self.idx = (self.idx + 1) % self.length 49 | 50 | def prev(self): 51 | self.idx = (self.idx - 1) % self.length 52 | 53 | def get(self): 54 | key = self.idx // self.block_size 55 | 56 | if key in self.cache: 57 | text = self.cache[key] 58 | else: 59 | if len(self.cache) >= self.bucket_size: 60 | self.cache.popitem(last=False) 61 | 62 | offset = self.offsets[key] 63 | self.fp.seek(offset, 0) 64 | text = self.fp.read(self.block_size) 65 | self.cache[key] = text 66 | 67 | self.cache.move_to_end(key) 68 | char = text[self.idx % self.block_size] 69 | return char 70 | 71 | 72 | class Content: 73 | def __init__(self, config): 74 | self.margin = config.get("margin", [0, 0.1]) 75 | self.reader = TextReader(**config.get("text", {})) 76 | self.font = components.BaseFont(**config.get("font", {})) 77 | self.layout = GridStack(config.get("layout", {})) 78 | self.textbox = TextBox(config.get("textbox", {})) 79 | self.textbox_color = components.Switch(components.Gray(), **config.get("textbox_color", {})) 80 | self.content_color = components.Switch(components.Gray(), **config.get("content_color", {})) 81 | 82 | def generate(self, size): 83 | width, height = size 84 | 85 | layout_left = width * np.random.uniform(self.margin[0], self.margin[1]) 86 | layout_top = height * np.random.uniform(self.margin[0], self.margin[1]) 87 | layout_width = max(width - layout_left * 2, 0) 88 | layout_height = max(height - layout_top * 2, 0) 89 | layout_bbox = [layout_left, layout_top, layout_width, layout_height] 90 | 91 | text_layers, texts = [], [] 92 | layouts = self.layout.generate(layout_bbox) 93 | self.reader.move(np.random.randint(len(self.reader))) 94 | 95 | for layout in layouts: 96 | font = self.font.sample() 97 | 98 | for bbox, align in layout: 99 | x, y, w, h = bbox 100 | text_layer, text = self.textbox.generate((w, h), self.reader, font) 101 | self.reader.prev() 102 | 103 | if text_layer is None: 104 | continue 105 | 106 | text_layer.center = (x + w / 2, y + h / 2) 107 | if align == "left": 108 | text_layer.left = x 109 | if align == "right": 110 | text_layer.right = x + w 111 | 112 | self.textbox_color.apply([text_layer]) 113 | text_layers.append(text_layer) 114 | texts.append(text) 115 | 116 | self.content_color.apply(text_layers) 117 | 118 | return text_layers, texts 119 | -------------------------------------------------------------------------------- /synthdog/elements/document.py: -------------------------------------------------------------------------------- 1 | """ 2 | Donut 3 | Copyright (c) 2022-present NAVER 
Corp. 4 | MIT License 5 | """ 6 | import numpy as np 7 | from synthtiger import components 8 | 9 | from elements.content import Content 10 | from elements.paper import Paper 11 | 12 | 13 | class Document: 14 | def __init__(self, config): 15 | self.fullscreen = config.get("fullscreen", 0.5) 16 | self.landscape = config.get("landscape", 0.5) 17 | self.short_size = config.get("short_size", [480, 1024]) 18 | self.aspect_ratio = config.get("aspect_ratio", [1, 2]) 19 | self.paper = Paper(config.get("paper", {})) 20 | self.content = Content(config.get("content", {})) 21 | self.effect = components.Iterator( 22 | [ 23 | components.Switch(components.ElasticDistortion()), 24 | components.Switch(components.AdditiveGaussianNoise()), 25 | components.Switch( 26 | components.Selector( 27 | [ 28 | components.Perspective(), 29 | components.Perspective(), 30 | components.Perspective(), 31 | components.Perspective(), 32 | components.Perspective(), 33 | components.Perspective(), 34 | components.Perspective(), 35 | components.Perspective(), 36 | ] 37 | ) 38 | ), 39 | ], 40 | **config.get("effect", {}), 41 | ) 42 | 43 | def generate(self, size): 44 | width, height = size 45 | fullscreen = np.random.rand() < self.fullscreen 46 | 47 | if not fullscreen: 48 | landscape = np.random.rand() < self.landscape 49 | max_size = width if landscape else height 50 | short_size = np.random.randint( 51 | min(width, height, self.short_size[0]), 52 | min(width, height, self.short_size[1]) + 1, 53 | ) 54 | aspect_ratio = np.random.uniform( 55 | min(max_size / short_size, self.aspect_ratio[0]), 56 | min(max_size / short_size, self.aspect_ratio[1]), 57 | ) 58 | long_size = int(short_size * aspect_ratio) 59 | size = (long_size, short_size) if landscape else (short_size, long_size) 60 | 61 | text_layers, texts = self.content.generate(size) 62 | paper_layer = self.paper.generate(size) 63 | self.effect.apply([*text_layers, paper_layer]) 64 | 65 | return paper_layer, text_layers, texts 66 | -------------------------------------------------------------------------------- /synthdog/elements/paper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Donut 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT License 5 | """ 6 | from synthtiger import components, layers 7 | 8 | 9 | class Paper: 10 | def __init__(self, config): 11 | self.image = components.BaseTexture(**config.get("image", {})) 12 | 13 | def generate(self, size): 14 | paper_layer = layers.RectLayer(size, (255, 255, 255, 255)) 15 | self.image.apply([paper_layer]) 16 | 17 | return paper_layer 18 | -------------------------------------------------------------------------------- /synthdog/elements/textbox.py: -------------------------------------------------------------------------------- 1 | """ 2 | Donut 3 | Copyright (c) 2022-present NAVER Corp. 
4 | MIT License 5 | """ 6 | import numpy as np 7 | from synthtiger import layers 8 | 9 | 10 | class TextBox: 11 | def __init__(self, config): 12 | self.fill = config.get("fill", [1, 1]) 13 | 14 | def generate(self, size, text, font): 15 | width, height = size 16 | 17 | char_layers, chars = [], [] 18 | fill = np.random.uniform(self.fill[0], self.fill[1]) 19 | width = np.clip(width * fill, height, width) 20 | font = {**font, "size": int(height)} 21 | left, top = 0, 0 22 | 23 | for char in text: 24 | if char in "\r\n": 25 | continue 26 | 27 | char_layer = layers.TextLayer(char, **font) 28 | char_scale = height / char_layer.height 29 | char_layer.bbox = [left, top, *(char_layer.size * char_scale)] 30 | if char_layer.right > width: 31 | break 32 | 33 | char_layers.append(char_layer) 34 | chars.append(char) 35 | left = char_layer.right 36 | 37 | text = "".join(chars).strip() 38 | if len(char_layers) == 0 or len(text) == 0: 39 | return None, None 40 | 41 | text_layer = layers.Group(char_layers).merge() 42 | 43 | return text_layer, text 44 | -------------------------------------------------------------------------------- /synthdog/layouts/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Donut 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT License 5 | """ 6 | from layouts.grid import Grid 7 | from layouts.grid_stack import GridStack 8 | 9 | __all__ = ["Grid", "GridStack"] 10 | -------------------------------------------------------------------------------- /synthdog/layouts/grid.py: -------------------------------------------------------------------------------- 1 | """ 2 | Donut 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT License 5 | """ 6 | import numpy as np 7 | 8 | 9 | class Grid: 10 | def __init__(self, config): 11 | self.text_scale = config.get("text_scale", [0.05, 0.1]) 12 | self.max_row = config.get("max_row", 5) 13 | self.max_col = config.get("max_col", 3) 14 | self.fill = config.get("fill", [0, 1]) 15 | self.full = config.get("full", 0) 16 | self.align = config.get("align", ["left", "right", "center"]) 17 | 18 | def generate(self, bbox): 19 | left, top, width, height = bbox 20 | 21 | text_scale = np.random.uniform(self.text_scale[0], self.text_scale[1]) 22 | text_size = min(width, height) * text_scale 23 | grids = np.random.permutation(self.max_row * self.max_col) 24 | 25 | for grid in grids: 26 | row = grid // self.max_col + 1 27 | col = grid % self.max_col + 1 28 | if text_size * (col * 2 - 1) <= width and text_size * row <= height: 29 | break 30 | else: 31 | return None 32 | 33 | bound = max(1 - text_size / width * (col - 1), 0) 34 | full = np.random.rand() < self.full 35 | fill = np.random.uniform(self.fill[0], self.fill[1]) 36 | fill = 1 if full else fill 37 | fill = np.clip(fill, 0, bound) 38 | 39 | padding = np.random.randint(4) if col > 1 else np.random.randint(1, 4) 40 | padding = (bool(padding // 2), bool(padding % 2)) 41 | 42 | weights = np.zeros(col * 2 + 1) 43 | weights[1:-1] = text_size / width 44 | probs = 1 - np.random.rand(col * 2 + 1) 45 | probs[0] = 0 if not padding[0] else probs[0] 46 | probs[-1] = 0 if not padding[-1] else probs[-1] 47 | probs[1::2] *= max(fill - sum(weights[1::2]), 0) / sum(probs[1::2]) 48 | probs[::2] *= max(1 - fill - sum(weights[::2]), 0) / sum(probs[::2]) 49 | weights += probs 50 | 51 | widths = [width * weights[c] for c in range(col * 2 + 1)] 52 | heights = [text_size for _ in range(row)] 53 | 54 | xs = np.cumsum([0] + widths) 55 | ys = np.cumsum([0] + heights) 56 | 57 | layout 
= [] 58 | 59 | for c in range(col): 60 | align = self.align[np.random.randint(len(self.align))] 61 | 62 | for r in range(row): 63 | x, y = xs[c * 2 + 1], ys[r] 64 | w, h = xs[c * 2 + 2] - x, ys[r + 1] - y 65 | bbox = [left + x, top + y, w, h] 66 | layout.append((bbox, align)) 67 | 68 | return layout 69 | -------------------------------------------------------------------------------- /synthdog/layouts/grid_stack.py: -------------------------------------------------------------------------------- 1 | """ 2 | Donut 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT License 5 | """ 6 | import numpy as np 7 | 8 | from layouts import Grid 9 | 10 | 11 | class GridStack: 12 | def __init__(self, config): 13 | self.text_scale = config.get("text_scale", [0.05, 0.1]) 14 | self.max_row = config.get("max_row", 5) 15 | self.max_col = config.get("max_col", 3) 16 | self.fill = config.get("fill", [0, 1]) 17 | self.full = config.get("full", 0) 18 | self.align = config.get("align", ["left", "right", "center"]) 19 | self.stack_spacing = config.get("stack_spacing", [0, 0.05]) 20 | self.stack_fill = config.get("stack_fill", [1, 1]) 21 | self.stack_full = config.get("stack_full", 0) 22 | self._grid = Grid( 23 | { 24 | "text_scale": self.text_scale, 25 | "max_row": self.max_row, 26 | "max_col": self.max_col, 27 | "align": self.align, 28 | } 29 | ) 30 | 31 | def generate(self, bbox): 32 | left, top, width, height = bbox 33 | 34 | stack_spacing = np.random.uniform(self.stack_spacing[0], self.stack_spacing[1]) 35 | stack_spacing *= min(width, height) 36 | 37 | stack_full = np.random.rand() < self.stack_full 38 | stack_fill = np.random.uniform(self.stack_fill[0], self.stack_fill[1]) 39 | stack_fill = 1 if stack_full else stack_fill 40 | 41 | full = np.random.rand() < self.full 42 | fill = np.random.uniform(self.fill[0], self.fill[1]) 43 | fill = 1 if full else fill 44 | self._grid.fill = [fill, fill] 45 | 46 | layouts = [] 47 | line = 0 48 | 49 | while True: 50 | grid_size = (width, height * stack_fill - line) 51 | text_scale = np.random.uniform(self.text_scale[0], self.text_scale[1]) 52 | text_size = min(width, height) * text_scale 53 | text_scale = text_size / min(grid_size) 54 | self._grid.text_scale = [text_scale, text_scale] 55 | 56 | layout = self._grid.generate([left, top + line, *grid_size]) 57 | if layout is None: 58 | break 59 | 60 | line = max(y + h - top for (_, y, _, h), _ in layout) + stack_spacing 61 | layouts.append(layout) 62 | 63 | line = max(line - stack_spacing, 0) 64 | space = max(height - line, 0) 65 | spaces = np.random.rand(len(layouts) + 1) 66 | spaces *= space / sum(spaces) if sum(spaces) > 0 else 0 67 | spaces = np.cumsum(spaces) 68 | 69 | for layout, space in zip(layouts, spaces): 70 | for bbox, _ in layout: 71 | x, y, w, h = bbox 72 | bbox[:] = [x, y + space, w, h] 73 | 74 | return layouts 75 | -------------------------------------------------------------------------------- /synthdog/resources/background/bedroom_83.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/background/bedroom_83.jpg -------------------------------------------------------------------------------- /synthdog/resources/background/bob+dylan_83.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/background/bob+dylan_83.jpg 
-------------------------------------------------------------------------------- /synthdog/resources/background/coffee_122.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/background/coffee_122.jpg -------------------------------------------------------------------------------- /synthdog/resources/background/coffee_18.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/background/coffee_18.jpeg -------------------------------------------------------------------------------- /synthdog/resources/background/crater_141.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/background/crater_141.jpg -------------------------------------------------------------------------------- /synthdog/resources/background/cream_124.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/background/cream_124.jpg -------------------------------------------------------------------------------- /synthdog/resources/background/eagle_110.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/background/eagle_110.jpg -------------------------------------------------------------------------------- /synthdog/resources/background/farm_25.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/background/farm_25.jpg -------------------------------------------------------------------------------- /synthdog/resources/background/hiking_18.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/background/hiking_18.jpg -------------------------------------------------------------------------------- /synthdog/resources/font/en/NotoSans-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/font/en/NotoSans-Regular.ttf -------------------------------------------------------------------------------- /synthdog/resources/font/en/NotoSerif-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/font/en/NotoSerif-Regular.ttf -------------------------------------------------------------------------------- /synthdog/resources/font/ja/NotoSansJP-Regular.otf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/font/ja/NotoSansJP-Regular.otf -------------------------------------------------------------------------------- 
/synthdog/resources/font/ja/NotoSerifJP-Regular.otf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/font/ja/NotoSerifJP-Regular.otf -------------------------------------------------------------------------------- /synthdog/resources/font/ko/NotoSansKR-Regular.otf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/font/ko/NotoSansKR-Regular.otf -------------------------------------------------------------------------------- /synthdog/resources/font/ko/NotoSerifKR-Regular.otf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/font/ko/NotoSerifKR-Regular.otf -------------------------------------------------------------------------------- /synthdog/resources/font/zh/NotoSansSC-Regular.otf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/font/zh/NotoSansSC-Regular.otf -------------------------------------------------------------------------------- /synthdog/resources/font/zh/NotoSerifSC-Regular.otf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/font/zh/NotoSerifSC-Regular.otf -------------------------------------------------------------------------------- /synthdog/resources/paper/paper_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/paper/paper_1.jpg -------------------------------------------------------------------------------- /synthdog/resources/paper/paper_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/paper/paper_2.jpg -------------------------------------------------------------------------------- /synthdog/resources/paper/paper_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/paper/paper_3.jpg -------------------------------------------------------------------------------- /synthdog/resources/paper/paper_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/paper/paper_4.jpg -------------------------------------------------------------------------------- /synthdog/resources/paper/paper_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/paper/paper_5.jpg -------------------------------------------------------------------------------- /synthdog/resources/paper/paper_6.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lexlang/donut/5dc8b717547a1a550c8ce19fab57ab518ae5fbe7/synthdog/resources/paper/paper_6.jpg -------------------------------------------------------------------------------- /synthdog/template.py: -------------------------------------------------------------------------------- 1 | """ 2 | Donut 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT License 5 | """ 6 | import json 7 | import os 8 | import re 9 | from typing import Any, List 10 | 11 | import numpy as np 12 | from elements import Background, Document 13 | from PIL import Image 14 | from synthtiger import components, layers, templates 15 | 16 | 17 | class SynthDoG(templates.Template): 18 | def __init__(self, config=None, split_ratio: List[float] = [0.8, 0.1, 0.1]): 19 | super().__init__(config) 20 | if config is None: 21 | config = {} 22 | 23 | self.quality = config.get("quality", [50, 95]) 24 | self.landscape = config.get("landscape", 0.5) 25 | self.short_size = config.get("short_size", [720, 1024]) 26 | self.aspect_ratio = config.get("aspect_ratio", [1, 2]) 27 | self.background = Background(config.get("background", {})) 28 | self.document = Document(config.get("document", {})) 29 | self.effect = components.Iterator( 30 | [ 31 | components.Switch(components.RGB()), 32 | components.Switch(components.Shadow()), 33 | components.Switch(components.Contrast()), 34 | components.Switch(components.Brightness()), 35 | components.Switch(components.MotionBlur()), 36 | components.Switch(components.GaussianBlur()), 37 | ], 38 | **config.get("effect", {}), 39 | ) 40 | 41 | # config for splits 42 | self.splits = ["train", "validation", "test"] 43 | self.split_ratio = split_ratio 44 | self.split_indexes = np.random.choice(3, size=10000, p=split_ratio) 45 | 46 | def generate(self): 47 | landscape = np.random.rand() < self.landscape 48 | short_size = np.random.randint(self.short_size[0], self.short_size[1] + 1) 49 | aspect_ratio = np.random.uniform(self.aspect_ratio[0], self.aspect_ratio[1]) 50 | long_size = int(short_size * aspect_ratio) 51 | size = (long_size, short_size) if landscape else (short_size, long_size) 52 | 53 | bg_layer = self.background.generate(size) 54 | paper_layer, text_layers, texts = self.document.generate(size) 55 | 56 | document_group = layers.Group([*text_layers, paper_layer]) 57 | document_space = np.clip(size - document_group.size, 0, None) 58 | document_group.left = np.random.randint(document_space[0] + 1) 59 | document_group.top = np.random.randint(document_space[1] + 1) 60 | roi = np.array(paper_layer.quad, dtype=int) 61 | 62 | layer = layers.Group([*document_group.layers, bg_layer]).merge() 63 | self.effect.apply([layer]) 64 | 65 | image = layer.output(bbox=[0, 0, *size]) 66 | label = " ".join(texts) 67 | label = label.strip() 68 | label = re.sub(r"\s+", " ", label) 69 | quality = np.random.randint(self.quality[0], self.quality[1] + 1) 70 | 71 | data = { 72 | "image": image, 73 | "label": label, 74 | "quality": quality, 75 | "roi": roi, 76 | } 77 | 78 | return data 79 | 80 | def init_save(self, root): 81 | if not os.path.exists(root): 82 | os.makedirs(root, exist_ok=True) 83 | 84 | def save(self, root, data, idx): 85 | image = data["image"] 86 | label = data["label"] 87 | quality = data["quality"] 88 | roi = data["roi"] 89 | 90 | # split 91 | split_idx = self.split_indexes[idx % len(self.split_indexes)] 92 | output_dirpath = os.path.join(root, self.splits[split_idx]) 93 | 94 | # save image 95 | image_filename = f"image_{idx}.jpg" 96 | image_filepath = os.path.join(output_dirpath, 
image_filename)
97 |         os.makedirs(os.path.dirname(image_filepath), exist_ok=True)
98 |         image = Image.fromarray(image[..., :3].astype(np.uint8))
99 |         image.save(image_filepath, quality=quality)
100 |
101 |         # save metadata (gt_json)
102 |         metadata_filename = "metadata.jsonl"
103 |         metadata_filepath = os.path.join(output_dirpath, metadata_filename)
104 |         os.makedirs(os.path.dirname(metadata_filepath), exist_ok=True)
105 |
106 |         metadata = self.format_metadata(image_filename=image_filename, keys=["text_sequence"], values=[label])
107 |         with open(metadata_filepath, "a") as fp:
108 |             json.dump(metadata, fp, ensure_ascii=False)
109 |             fp.write("\n")
110 |
111 |     def end_save(self, root):
112 |         pass
113 |
114 |     def format_metadata(self, image_filename: str, keys: List[str], values: List[Any]):
115 |         """
116 |         Fit gt_parse contents to huggingface dataset's format
117 |         keys and values, whose lengths are equal, are used to construct the 'gt_parse' field in the 'ground_truth' field
118 |         Args:
119 |             keys: List of task_name
120 |             values: List of actual gt data corresponding to each task_name
121 |         """
122 |         assert len(keys) == len(values), "Length does not match: keys({}), values({})".format(len(keys), len(values))
123 |
124 |         _gt_parse_v = dict()
125 |         for k, v in zip(keys, values):
126 |             _gt_parse_v[k] = v
127 |         gt_parse = {"gt_parse": _gt_parse_v}
128 |         gt_parse_str = json.dumps(gt_parse, ensure_ascii=False)
129 |         metadata = {"file_name": image_filename, "ground_truth": gt_parse_str}
130 |         return metadata
131 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | """
2 | Donut
3 | Copyright (c) 2022-present NAVER Corp.
4 | MIT License
5 | """
6 | import argparse
7 | import json
8 | import os
9 | import re
10 | from pathlib import Path
11 |
12 | import numpy as np
13 | import torch
14 | from datasets import load_dataset
15 | from PIL import Image
16 | from tqdm import tqdm
17 |
18 | from donut import DonutModel, JSONParseEvaluator, load_json, save_json
19 |
20 |
21 | def test(args):
22 |     pretrained_model = DonutModel.from_pretrained(args.pretrained_model_name_or_path)
23 |
24 |     if torch.cuda.is_available():
25 |         pretrained_model.half()
26 |         pretrained_model.to("cuda")
27 |
28 |     pretrained_model.eval()
29 |
30 |     if args.save_path:
31 |         os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
32 |
33 |     predictions = []
34 |     ground_truths = []
35 |     accs = []
36 |
37 |     evaluator = JSONParseEvaluator()
38 |     dataset = load_dataset(args.dataset_name_or_path, split=args.split)
39 |
40 |     for idx, sample in tqdm(enumerate(dataset), total=len(dataset)):
41 |         ground_truth = json.loads(sample["ground_truth"])
42 |
43 |         if args.task_name == "docvqa":
44 |             output = pretrained_model.inference(
45 |                 image=sample["image"],
46 |                 prompt=f"<s_{args.task_name}><s_question>{ground_truth['gt_parses'][0]['question'].lower()}</s_question><s_answer>",
47 |             )["predictions"][0]
48 |         else:
49 |             output = pretrained_model.inference(image=sample["image"], prompt=f"<s_{args.task_name}>")["predictions"][0]
50 |
51 |         if args.task_name == "rvlcdip":
52 |             gt = ground_truth["gt_parse"]
53 |             score = float(output["class"] == gt["class"])
54 |         elif args.task_name == "docvqa":
55 |             # Note: we evaluated the model on the official website.
56 | # In this script, an exact-match based score will be returned instead 57 | gt = ground_truth["gt_parses"] 58 | answers = set([qa_parse["answer"] for qa_parse in gt]) 59 | score = float(output["answer"] in answers) 60 | else: 61 | gt = ground_truth["gt_parse"] 62 | score = evaluator.cal_acc(output, gt) 63 | 64 | accs.append(score) 65 | 66 | predictions.append(output) 67 | ground_truths.append(gt) 68 | 69 | scores = { 70 | "ted_accuracies": accs, 71 | "ted_accuracy": np.mean(accs), 72 | "f1_accuracy": evaluator.cal_f1(predictions, ground_truths), 73 | } 74 | print( 75 | f"Total number of samples: {len(accs)}, Tree Edit Distance (TED) based accuracy score: {scores['ted_accuracy']}, F1 accuracy score: {scores['f1_accuracy']}" 76 | ) 77 | 78 | if args.save_path: 79 | scores["predictions"] = predictions 80 | scores["ground_truths"] = ground_truths 81 | save_json(args.save_path, scores) 82 | 83 | return predictions 84 | 85 | 86 | if __name__ == "__main__": 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument("--pretrained_model_name_or_path", type=str) 89 | parser.add_argument("--dataset_name_or_path", type=str) 90 | parser.add_argument("--split", type=str, default="test") 91 | parser.add_argument("--task_name", type=str, default=None) 92 | parser.add_argument("--save_path", type=str, default=None) 93 | args, left_argv = parser.parse_known_args() 94 | 95 | if args.task_name is None: 96 | args.task_name = os.path.basename(args.dataset_name_or_path) 97 | 98 | predictions = test(args) 99 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Donut 3 | Copyright (c) 2022-present NAVER Corp. 4 | MIT License 5 | """ 6 | import argparse 7 | import datetime 8 | import json 9 | import os 10 | import random 11 | from io import BytesIO 12 | from os.path import basename 13 | from pathlib import Path 14 | 15 | import numpy as np 16 | import pytorch_lightning as pl 17 | import torch 18 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 19 | from pytorch_lightning.loggers.tensorboard import TensorBoardLogger 20 | from pytorch_lightning.plugins import CheckpointIO 21 | from pytorch_lightning.utilities import rank_zero_only 22 | from sconf import Config 23 | 24 | from donut import DonutDataset 25 | from lightning_module import DonutDataPLModule, DonutModelPLModule 26 | 27 | 28 | class CustomCheckpointIO(CheckpointIO): 29 | def save_checkpoint(self, checkpoint, path, storage_options=None): 30 | del checkpoint["state_dict"] 31 | torch.save(checkpoint, path) 32 | 33 | def load_checkpoint(self, path, storage_options=None): 34 | checkpoint = torch.load(path + "artifacts.ckpt") 35 | state_dict = torch.load(path + "pytorch_model.bin") 36 | checkpoint["state_dict"] = {"model." 
+ key: value for key, value in state_dict.items()}
37 |         return checkpoint
38 |
39 |     def remove_checkpoint(self, path) -> None:
40 |         return super().remove_checkpoint(path)
41 |
42 |
43 | @rank_zero_only
44 | def save_config_file(config, path):
45 |     if not Path(path).exists():
46 |         os.makedirs(path)
47 |     save_path = Path(path) / "config.yaml"
48 |     print(config.dumps())
49 |     with open(save_path, "w") as f:
50 |         f.write(config.dumps(modified_color=None, quote_str=True))
51 |         print(f"Config is saved at {save_path}")
52 |
53 |
54 | class ProgressBar(pl.callbacks.TQDMProgressBar):
55 |     def __init__(self, config):
56 |         super().__init__()
57 |         self.enable = True
58 |         self.config = config
59 |
60 |     def disable(self):
61 |         self.enable = False
62 |
63 |     def get_metrics(self, trainer, model):
64 |         items = super().get_metrics(trainer, model)
65 |         items.pop("v_num", None)
66 |         items["exp_name"] = f"{self.config.get('exp_name', '')}"
67 |         items["exp_version"] = f"{self.config.get('exp_version', '')}"
68 |         return items
69 |
70 |
71 | def set_seed(seed):
72 |     pytorch_lightning_version = int(pl.__version__[0])
73 |     if pytorch_lightning_version < 2:
74 |         pl.utilities.seed.seed_everything(seed, workers=True)
75 |     else:
76 |         import lightning_fabric
77 |         lightning_fabric.utilities.seed.seed_everything(seed, workers=True)
78 |
79 |
80 | def train(config):
81 |     set_seed(config.get("seed", 42))
82 |
83 |     model_module = DonutModelPLModule(config)
84 |     data_module = DonutDataPLModule(config)
85 |
86 |     # add datasets to data_module
87 |     datasets = {"train": [], "validation": []}
88 |     for i, dataset_name_or_path in enumerate(config.dataset_name_or_paths):
89 |         task_name = os.path.basename(dataset_name_or_path)  # e.g., cord-v2, docvqa, rvlcdip, ...
90 |
91 |         # add categorical special tokens (optional)
92 |         if task_name == "rvlcdip":
93 |             model_module.model.decoder.add_special_tokens([
94 |                 "<advertisement/>", "<budget/>", "<email/>", "<file_folder/>",
95 |                 "<form/>", "<handwritten/>", "<invoice/>", "<letter/>",
96 |                 "<memo/>", "<news_article/>", "<presentation/>", "<questionnaire/>",
97 |                 "<resume/>", "<scientific_publication/>", "<scientific_report/>",
98 |                 "<specification/>"
99 |             ])
100 |         if task_name == "docvqa":
101 |             model_module.model.decoder.add_special_tokens(["<yes/>", "<no/>"])
102 |         for split in ["train", "validation"]:
103 |             datasets[split].append(
104 |                 DonutDataset(
105 |                     dataset_name_or_path=dataset_name_or_path,
106 |                     donut_model=model_module.model,
107 |                     max_length=config.max_length,
108 |                     split=split,
109 |                     task_start_token=config.task_start_tokens[i]
110 |                     if config.get("task_start_tokens", None)
111 |                     else f"<s_{task_name}>",
112 |                     prompt_end_token="<s_answer>" if "docvqa" in dataset_name_or_path else f"<s_{task_name}>",
113 |                     sort_json_key=config.sort_json_key,
114 |                 )
115 |             )
116 |             # prompt_end_token is used for ignoring a given prompt in a loss function
117 |             # for docvqa task, i.e., {"question": {used as a prompt}, "answer": {prediction target}},
118 |             # set prompt_end_token to "<s_answer>"
119 |     data_module.train_datasets = datasets["train"]
120 |     data_module.val_datasets = datasets["validation"]
121 |
122 |     logger = TensorBoardLogger(
123 |         save_dir=config.result_path,
124 |         name=config.exp_name,
125 |         version=config.exp_version,
126 |         default_hp_metric=False,
127 |     )
128 |
129 |     lr_callback = LearningRateMonitor(logging_interval="step")
130 |
131 |     checkpoint_callback = ModelCheckpoint(
132 |         monitor="val_metric",
133 |         dirpath=Path(config.result_path) / config.exp_name / config.exp_version,
134 |         filename="artifacts",
135 |         save_top_k=1,
136 |         save_last=False,
137 |         mode="min",
138 |     )
139 |
140 |     bar = ProgressBar(config)
141 |
142 |     custom_ckpt = CustomCheckpointIO()
143 |     trainer = pl.Trainer(
144 |         num_nodes=config.get("num_nodes", 1),
145 |         devices=torch.cuda.device_count(),
146 |         strategy="ddp",
147 |         accelerator="gpu",
148 |         plugins=custom_ckpt,
149 |         max_epochs=config.max_epochs,
150 |         max_steps=config.max_steps,
151 |         val_check_interval=config.val_check_interval,
152 |         check_val_every_n_epoch=config.check_val_every_n_epoch,
153 |         gradient_clip_val=config.gradient_clip_val,
154 |         precision=16,
155 |         num_sanity_val_steps=0,
156 |         logger=logger,
157 |         callbacks=[lr_callback, checkpoint_callback, bar],
158 |     )
159 |
160 |     trainer.fit(model_module, data_module, ckpt_path=config.get("resume_from_checkpoint_path", None))
161 |
162 |
163 | if __name__ == "__main__":
164 |     parser = argparse.ArgumentParser()
165 |     parser.add_argument("--config", type=str, required=True)
166 |     parser.add_argument("--exp_version", type=str, required=False)
167 |     args, left_argv = parser.parse_known_args()
168 |
169 |     config = Config(args.config)
170 |     config.argv_update(left_argv)
171 |
172 |     config.exp_name = basename(args.config).split(".")[0]
173 |     config.exp_version = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") if not args.exp_version else args.exp_version
174 |
175 |     save_config_file(config, Path(config.result_path) / config.exp_name / config.exp_version)
176 |     train(config)
177 |
--------------------------------------------------------------------------------