├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── assets ├── demo.gif ├── dolphin.png └── framework.png ├── chat.py ├── config └── Dolphin.yaml ├── demo ├── element_imgs │ ├── block_formula.jpeg │ ├── line_formula.jpeg │ ├── para_1.jpg │ ├── para_2.jpg │ ├── para_3.jpeg │ ├── table_1.jpeg │ └── table_2.jpeg └── page_imgs │ ├── page_1.jpeg │ ├── page_2.jpeg │ ├── page_3.jpeg │ ├── page_4.png │ └── page_5.jpg ├── demo_element.py ├── demo_element_hf.py ├── demo_page.py ├── demo_page_hf.py ├── pyproject.toml ├── requirements.txt └── utils ├── markdown_utils.py ├── model.py ├── processor.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .nox/ 42 | .coverage 43 | *.cover 44 | *.py,cover 45 | .hypothesis/ 46 | .pytest_cache/ 47 | coverage.xml 48 | *.mo 49 | *.pot 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | db.sqlite3-journal 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # IPython 78 | profile_default/ 79 | ipython_config.py 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # pipenv 85 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 86 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 87 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 88 | # install all needed dependencies. 89 | #Pipfile.lock 90 | 91 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 92 | __pypackages__/ 93 | 94 | # Celery stuff 95 | celerybeat-schedule 96 | celerybeat.pid 97 | 98 | # SageMath parsed files 99 | *.sage.py 100 | 101 | # Environments 102 | .env 103 | .venv 104 | env/ 105 | venv/ 106 | ENV/ 107 | env.bak/ 108 | venv.bak/ 109 | 110 | # Spyder project settings 111 | .spyderproject 112 | .spyproject 113 | 114 | # Rope project settings 115 | .ropeproject 116 | 117 | # mkdocs documentation 118 | /site 119 | 120 | # mypy 121 | .mypy_cache/ 122 | .dmypy.json 123 | dmypy.json 124 | 125 | # Pyre type checker 126 | .pyre/ 127 | 128 | # pytype static type analyzer 129 | .pytype/ 130 | 131 | # Cython debug symbols 132 | cython_debug/ 133 | 134 | # PyCharm 135 | .idea/ 136 | *.iml 137 | 138 | # VS Code 139 | .vscode/ 140 | !.vscode/settings.json 141 | !.vscode/tasks.json 142 | !.vscode/launch.json 143 | !.vscode/extensions.json 144 | 145 | # macOS 146 | .DS_Store 147 | 148 | # Windows 149 | Thumbs.db 150 | ehthumbs.db 151 | Desktop.ini 152 | 153 | fusion_result.json 154 | kernel_meta/ 155 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # 1. isort - automatically sort Python imports 3 | - repo: https://github.com/pycqa/isort 4 | rev: 6.0.1 # use a pinned version number 5 | hooks: 6 | - id: isort 7 | name: isort (python) 8 | args: [--profile=black] # configuration compatible with Black 9 | language: python 10 | 11 | # 2. Black - automatically format Python code 12 | - repo: https://github.com/psf/black 13 | rev: 25.1.0 # use a pinned version number 14 | hooks: 15 | - id: black 16 | language: python 17 | 18 | # 3. flake8 - Python static checks 19 | - repo: https://github.com/pycqa/flake8 20 | rev: 7.2.0 21 | hooks: 22 | - id: flake8 23 | args: [--max-line-length=120, --ignore=E203] # set max line length to 120 24 | additional_dependencies: [flake8-bugbear==24.12.12] # optional: extra checks 25 | 26 | # 4. pre-commit-hooks - general-purpose Git hooks 27 | - repo: https://github.com/pre-commit/pre-commit-hooks 28 | rev: v5.0.0 29 | hooks: 30 | - id: trailing-whitespace # remove trailing whitespace 31 | - id: end-of-file-fixer # ensure files end with a newline 32 | - id: check-yaml # validate YAML file syntax 33 | - id: check-added-large-files # block large files from being committed 34 | args: ["--maxkb=512"] 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright 2025 ByteDance Ltd. and/or its affiliates 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |
4 | 5 |
6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 |
23 | 24 |
25 | 26 |
27 | 28 |
29 | 30 | # Dolphin: Document Image Parsing via Heterogeneous Anchor Prompting 31 | 32 | Dolphin (**Do**cument Image **P**arsing via **H**eterogeneous Anchor Prompt**in**g) is a novel multimodal document image parsing model following an analyze-then-parse paradigm. This repository contains the demo code and pre-trained models for Dolphin. 33 | 34 | ## 📑 Overview 35 | 36 | Document image parsing is challenging due to its complexly intertwined elements such as text paragraphs, figures, formulas, and tables. Dolphin addresses these challenges through a two-stage approach: 37 | 38 | 1. **🔍 Stage 1**: Comprehensive page-level layout analysis by generating an element sequence in natural reading order 39 | 2. **🧩 Stage 2**: Efficient parallel parsing of document elements using heterogeneous anchors and task-specific prompts (see the code sketch below) 40 | 41 |
42 | 43 |
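For a concrete picture of how the two stages are driven in code, here is a minimal sketch built on the helpers in `demo_page_hf.py`. It assumes the Hugging Face weights have already been downloaded to `./hf_model` as described in the Installation section below, and a GPU is recommended because the demo wrapper casts the model to half precision; the prompt strings are the ones used by the demo scripts.

```python
from PIL import Image

from demo_page_hf import DOLPHIN, process_page  # HF-based wrapper and two-stage pipeline
from utils.utils import setup_output_dirs

# Load the model (moved to GPU when available and cast to half precision by the wrapper)
model = DOLPHIN("./hf_model")

# Option 1: run the full two-stage pipeline and save the parsing results
setup_output_dirs("./results")
json_path, elements = process_page(
    image_path="./demo/page_imgs/page_1.jpeg",
    model=model,
    save_dir="./results",
    max_batch_size=16,
)

# Option 2: drive the stages manually
page = Image.open("./demo/page_imgs/page_1.jpeg").convert("RGB")
layout = model.chat("Parse the reading order of this document.", page)  # Stage 1: layout + reading order
# Stage 2 then crops each predicted region and decodes it with an element-specific prompt,
# e.g. "Read text in the image." for paragraphs or "Parse the table in the image." for tables.
```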
44 | 45 | Dolphin achieves promising performance across diverse page-level and element-level parsing tasks while ensuring superior efficiency through its lightweight architecture and parallel parsing mechanism. 46 | 47 | ## 🚀 Demo 48 | 49 | Try our demo on [Demo-Dolphin](http://115.190.42.15:8888/dolphin/). 50 | 51 | 52 | ## 📅 Changelog 53 | - 🔥 **2025.05.21** Our demo is released at [link](http://115.190.42.15:8888/dolphin/). Check it out! 54 | - 🔥 **2025.05.20** The pretrained model and inference code of Dolphin are released. 55 | - 🔥 **2025.05.16** Our paper has been accepted by ACL 2025. Paper link: [arXiv](https://arxiv.org/abs/2505.14059). 56 | 57 | ## 🛠️ Installation 58 | 59 | 1. Clone the repository: 60 | ```bash 61 | git clone https://github.com/ByteDance/Dolphin.git 62 | cd Dolphin 63 | ``` 64 | 65 | 2. Install the dependencies: 66 | ```bash 67 | pip install -r requirements.txt 68 | ``` 69 | 70 | 3. Download the pre-trained models using one of the following options: 71 | 72 | **Option A: Original Model Format (config-based)** 73 | 74 | Download from [Baidu Yun](https://pan.baidu.com/s/15zcARoX0CTOHKbW8bFZovQ?pwd=9rpx) or [Google Drive](https://drive.google.com/drive/folders/1PQJ3UutepXvunizZEw-uGaQ0BCzf-mie?usp=sharing) and put them in the `./checkpoints` folder. 75 | 76 | **Option B: Hugging Face Model Format** 77 | 78 | Visit our Huggingface [model card](https://huggingface.co/ByteDance/Dolphin), or download model by: 79 | 80 | ```bash 81 | # Download the model from Hugging Face Hub 82 | git lfs install 83 | git clone https://huggingface.co/ByteDance/Dolphin ./hf_model 84 | # Or use the Hugging Face CLI 85 | huggingface-cli download ByteDance/Dolphin --local-dir ./hf_model 86 | ``` 87 | 88 | ## ⚡ Inference 89 | 90 | Dolphin provides two inference frameworks with support for two parsing granularities: 91 | - **Page-level Parsing**: Parse the entire document image into a structured JSON and Markdown format 92 | - **Element-level Parsing**: Parse individual document elements (text, table, formula) 93 | 94 | ### 📄 Page-level Parsing 95 | 96 | #### Using Original Framework (config-based) 97 | 98 | ```bash 99 | # Process a single document image 100 | python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs/page_1.jpeg --save_dir ./results 101 | 102 | # Process all document images in a directory 103 | python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs --save_dir ./results 104 | 105 | # Process with custom batch size for parallel element decoding 106 | python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs --save_dir ./results --max_batch_size 8 107 | ``` 108 | 109 | #### Using Hugging Face Framework 110 | 111 | ```bash 112 | # Process a single document image 113 | python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs/page_1.jpeg --save_dir ./results 114 | 115 | # Process all document images in a directory 116 | python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs --save_dir ./results 117 | 118 | # Process with custom batch size for parallel element decoding 119 | python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs --save_dir ./results --max_batch_size 16 120 | ``` 121 | 122 | ### 🧩 Element-level Parsing 123 | 124 | #### Using Original Framework (config-based) 125 | 126 | ```bash 127 | # Process a single table image 128 | python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/table_1.jpeg --element_type 
table 129 | 130 | # Process a single formula image 131 | python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/line_formula.jpeg --element_type formula 132 | 133 | # Process a single text paragraph image 134 | python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/para_1.jpg --element_type text 135 | ``` 136 | 137 | #### Using Hugging Face Framework 138 | 139 | ```bash 140 | # Process a single table image 141 | python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/table_1.jpeg --element_type table 142 | 143 | # Process a single formula image 144 | python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/line_formula.jpeg --element_type formula 145 | 146 | # Process a single text paragraph image 147 | python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/para_1.jpg --element_type text 148 | ``` 149 | 150 | ## 🌟 Key Features 151 | 152 | - 🔄 Two-stage analyze-then-parse approach based on a single VLM 153 | - 📊 Promising performance on document parsing tasks 154 | - 🔍 Natural reading order element sequence generation 155 | - 🧩 Heterogeneous anchor prompting for different document elements 156 | - ⏱️ Efficient parallel parsing mechanism 157 | - 🤗 Support for Hugging Face Transformers for easier integration 158 | 159 | 160 | ## 📮 Notice 161 | **Call for Bad Cases:** If you have encountered any cases where the model performs poorly, we would greatly appreciate it if you could share them in the issue. We are continuously working to optimize and improve the model. 162 | 163 | ## 💖 Acknowledgement 164 | 165 | We would like to acknowledge the following open-source projects that provided inspiration and reference for this work: 166 | - [Donut](https://github.com/clovaai/donut/) 167 | - [Nougat](https://github.com/facebookresearch/nougat) 168 | - [GOT](https://github.com/Ucas-HaoranWei/GOT-OCR2.0) 169 | - [MinerU](https://github.com/opendatalab/MinerU/tree/master) 170 | - [Swin](https://github.com/microsoft/Swin-Transformer) 171 | - [Hugging Face Transformers](https://github.com/huggingface/transformers) 172 | 173 | ## 📝 Citation 174 | 175 | If you find this code useful for your research, please use the following BibTeX entry. 
176 | 177 | ```bibtex 178 | @inproceedings{dolphin2025, 179 | title={Dolphin: Document Image Parsing via Heterogeneous Anchor Prompting}, 180 | author={Feng, Hao and Wei, Shu and Fei, Xiang and Shi, Wei and Han, Yingdong and Liao, Lei and Lu, Jinghui and Wu, Binghong and Liu, Qi and Lin, Chunhui and Tang, Jingqun and Liu, Hao and Huang, Can}, 181 | year={2025}, 182 | booktitle={Proceedings of the 63rd Annual Meeting of the Association for Computational Linguistics (ACL)} 183 | } 184 | ``` 185 | 186 | ## Star History 187 | 188 | [![Star History Chart](https://api.star-history.com/svg?repos=bytedance/Dolphin&type=Date)](https://www.star-history.com/#bytedance/Dolphin&Date) 189 | -------------------------------------------------------------------------------- /assets/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/assets/demo.gif -------------------------------------------------------------------------------- /assets/dolphin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/assets/dolphin.png -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/assets/framework.png -------------------------------------------------------------------------------- /chat.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bytedance Ltd.
and/or its affiliates 3 | SPDX-License-Identifier: MIT 4 | """ 5 | 6 | import os 7 | import warnings 8 | from collections import OrderedDict 9 | 10 | from omegaconf import ListConfig 11 | 12 | warnings.filterwarnings("ignore", category=UserWarning) 13 | warnings.filterwarnings("ignore", category=FutureWarning) 14 | os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python") 15 | 16 | import torch 17 | from PIL import Image 18 | from transformers import PreTrainedTokenizerFast 19 | 20 | from utils.model import DonutConfig, DonutModel, SwinEncoder 21 | from utils.processor import DolphinProcessor 22 | 23 | 24 | def try_rename_lagacy_weights(ckpt, output_path=""): 25 | if "state_dict" in ckpt.keys(): 26 | ckpt = ckpt["state_dict"] 27 | if "module" in ckpt.keys(): 28 | ckpt = ckpt["module"] 29 | new_ckpt = OrderedDict() 30 | for k, v in ckpt.items(): 31 | if k.startswith("model."): 32 | k = k[len("model.") :] 33 | if k.startswith("encoder"): 34 | new_ckpt["vpm" + k[len("encoder") :]] = v 35 | elif k.startswith("decoder"): 36 | new_ckpt["llm" + k[len("encoder") :]] = v 37 | else: 38 | new_ckpt[k] = v 39 | if output_path: 40 | torch.save(new_ckpt, output_path) 41 | return new_ckpt 42 | 43 | 44 | def convert_listconfig_to_list(config): 45 | new_config = {} 46 | for k, v in config.items(): 47 | if isinstance(v, ListConfig): 48 | new_config[k] = list(v) 49 | else: 50 | new_config[k] = v 51 | return new_config 52 | 53 | 54 | class DOLPHIN: 55 | def __init__(self, config, ckpt_path="") -> None: 56 | self.model_args = config.model 57 | self.swin_args = config.model.pop("swin_args") 58 | self.swin_args = convert_listconfig_to_list(self.swin_args) 59 | 60 | vision_tower = SwinEncoder( 61 | input_size=self.swin_args["img_size"], 62 | patch_size=self.swin_args["patch_size"], 63 | embed_dim=self.swin_args["embed_dim"], 64 | window_size=self.swin_args["window_size"], 65 | encoder_layer=self.swin_args["encoder_layer"], 66 | num_heads=self.swin_args["num_heads"], 67 | align_long_axis=self.swin_args["align_long_axis"], 68 | ) 69 | 70 | self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=self.model_args.tokenizer_path) 71 | self.tokenizer.pad_token = "<pad>" 72 | self.tokenizer.bos_token = "<s>" 73 | self.tokenizer.eos_token = "</s>" 74 | self.tokenizer.unk_token = "<unk>" 75 | 76 | if self.model_args.get("extra_answer_tokens", False): 77 | # print("Allowing multitask training: adding <Answer/> to the tokenizer.") 78 | prompt_end_token = " <Answer/>" 79 | self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set([prompt_end_token]))}) 80 | self.tokenizer._prompt_end_token = prompt_end_token 81 | self.tokenizer._prompt_end_token_id = self.tokenizer.convert_tokens_to_ids(prompt_end_token) 82 | 83 | donut_config = DonutConfig( 84 | decoder_layer=self.model_args.decoder_layer, 85 | max_length=self.model_args.max_length, 86 | max_position_embeddings=self.model_args.max_position_embeddings, 87 | hidden_dimension=self.model_args.hidden_dimension, 88 | ) 89 | 90 | self.model = DonutModel(config=donut_config, vision_tower=vision_tower, tokenizer=self.tokenizer) 91 | if self.model_args.model_name_or_path: 92 | ckpt = torch.load(self.model_args.model_name_or_path) 93 | ckpt = try_rename_lagacy_weights(ckpt) 94 | self.model.load_state_dict(ckpt, strict=True) 95 | 96 | self.model.to("cuda") 97 | self.model.eval() 98 | transform_args = { 99 | "input_size": self.swin_args["img_size"], 100 | "max_length": self.model_args.max_length, 101 | } 102 | self.processor = DolphinProcessor({}, self.tokenizer,
transform_args=transform_args) 103 | 104 | def chat( 105 | self, 106 | question, 107 | image, 108 | return_raw=False, 109 | return_score=False, 110 | return_img_size=False, 111 | only_return_img_size=False, 112 | max_batch_size=16, 113 | ): 114 | 115 | def _preprocess_image(image): 116 | if isinstance(image, str): 117 | image = Image.open(image).convert("RGB") 118 | if return_img_size or only_return_img_size: 119 | image_tensor, ori_size = self.processor.process_image_for_inference(image, return_img_size=True) 120 | else: 121 | image_tensor = self.processor.process_image_for_inference(image, return_img_size=False) 122 | ori_size = None 123 | return image_tensor, ori_size 124 | 125 | def _preprocess_prompt(question): 126 | if self.model_args.get("extra_answer_tokens", False): 127 | if self.tokenizer._prompt_end_token not in question: 128 | question = question + self.tokenizer._prompt_end_token 129 | prompt_ids = self.processor.process_prompt_for_inference(question) 130 | return prompt_ids 131 | 132 | def _preprocess_prompt_batch(question): 133 | if self.model_args.get("extra_answer_tokens", False): 134 | for i in range(len(question)): 135 | if self.tokenizer._prompt_end_token not in question[i]: 136 | question[i] = question[i] + self.tokenizer._prompt_end_token 137 | if not question[i].startswith("<s>"): 138 | question[i] = "<s>" + question[i] 139 | return question 140 | 141 | def _postprocess(output, question): 142 | output = output.replace("<s>", "").replace(question, "").replace("</s>", "").replace("<pad>", "") 143 | if self.model_args.get("extra_answer_tokens", False): 144 | output = output.split(self.tokenizer._prompt_end_token)[-1] 145 | return output 146 | 147 | if isinstance(question, list): 148 | image_tensor_list = [] 149 | for i in image: 150 | image_tensor, ori_size = _preprocess_image(i) 151 | image_tensor_list.append(image_tensor) 152 | image_tensor = torch.cat(image_tensor_list, dim=0) 153 | 154 | question = _preprocess_prompt_batch(question) 155 | self.processor.tokenizer.padding_side = "left" 156 | prompt_ids = self.processor.tokenizer( 157 | question, add_special_tokens=False, return_tensors="pt", padding=True 158 | ).input_ids 159 | else: 160 | image_tensor, ori_size = _preprocess_image(image) 161 | prompt_ids = _preprocess_prompt(question) 162 | 163 | if only_return_img_size: 164 | return ori_size 165 | 166 | model_output_batch = [] 167 | for i in range(0, image_tensor.shape[0], max_batch_size): 168 | image_tensor_batch = image_tensor[i : i + max_batch_size] 169 | prompt_ids_batch = prompt_ids[i : i + max_batch_size] 170 | model_output = self.model.inference(image_tensors=image_tensor_batch, prompt_ids=prompt_ids_batch) 171 | model_output_batch.append(model_output) 172 | model_output = {} 173 | for k, v in model_output_batch[0].items(): 174 | if isinstance(v, torch.Tensor): 175 | model_output[k] = sum( 176 | [v_batch[k].cpu().numpy().tolist() for v_batch in model_output_batch], 177 | [], 178 | ) 179 | else: 180 | model_output[k] = sum([v_batch[k] for v_batch in model_output_batch], []) 181 | 182 | if return_raw: 183 | if return_img_size: 184 | return model_output, ori_size 185 | return model_output 186 | else: 187 | if isinstance(question, list): 188 | output = [_postprocess(model_output["repetitions"][i], question[i]) for i in range(len(question))] 189 | score = model_output["scores"] 190 | else: 191 | output = _postprocess(model_output["repetitions"][0], question) 192 | score = model_output["scores"][0] 193 | if return_score: 194 | return output, score 195 | if return_img_size: 196 |
return output, ori_size 197 | return output 198 | -------------------------------------------------------------------------------- /config/Dolphin.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | model_name_or_path: "./checkpoints/dolphin_model.bin" 3 | tokenizer_path: "./checkpoints/dolphin_tokenizer.json" 4 | extra_answer_tokens: True # add token 5 | max_length: 4096 6 | decoder_layer: 10 7 | max_position_embeddings: 4096 8 | hidden_dimension: 1024 9 | swin_args: 10 | name: 'swin' 11 | img_size: [896, 896] 12 | patch_size: 4 13 | embed_dim: 128 14 | align_long_axis: False 15 | window_size: 7 16 | encoder_layer: [2, 2, 14, 2] 17 | num_heads: [4, 8, 16, 32] 18 | -------------------------------------------------------------------------------- /demo/element_imgs/block_formula.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/element_imgs/block_formula.jpeg -------------------------------------------------------------------------------- /demo/element_imgs/line_formula.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/element_imgs/line_formula.jpeg -------------------------------------------------------------------------------- /demo/element_imgs/para_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/element_imgs/para_1.jpg -------------------------------------------------------------------------------- /demo/element_imgs/para_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/element_imgs/para_2.jpg -------------------------------------------------------------------------------- /demo/element_imgs/para_3.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/element_imgs/para_3.jpeg -------------------------------------------------------------------------------- /demo/element_imgs/table_1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/element_imgs/table_1.jpeg -------------------------------------------------------------------------------- /demo/element_imgs/table_2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/element_imgs/table_2.jpeg -------------------------------------------------------------------------------- /demo/page_imgs/page_1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/page_imgs/page_1.jpeg -------------------------------------------------------------------------------- /demo/page_imgs/page_2.jpeg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/page_imgs/page_2.jpeg -------------------------------------------------------------------------------- /demo/page_imgs/page_3.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/page_imgs/page_3.jpeg -------------------------------------------------------------------------------- /demo/page_imgs/page_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/page_imgs/page_4.png -------------------------------------------------------------------------------- /demo/page_imgs/page_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/page_imgs/page_5.jpg -------------------------------------------------------------------------------- /demo_element.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 3 | SPDX-License-Identifier: MIT 4 | """ 5 | 6 | import argparse 7 | import glob 8 | import os 9 | 10 | from omegaconf import OmegaConf 11 | from PIL import Image 12 | 13 | from chat import DOLPHIN 14 | from utils.utils import * 15 | 16 | 17 | def process_element(image_path, model, element_type, save_dir=None): 18 | """Process a single element image (text, table, formula) 19 | 20 | Args: 21 | image_path: Path to the element image 22 | model: DOLPHIN model instance 23 | element_type: Type of element ('text', 'table', 'formula') 24 | save_dir: Directory to save results (default: same as input directory) 25 | 26 | Returns: 27 | Parsed content of the element and recognition results 28 | """ 29 | # Load and prepare image 30 | pil_image = Image.open(image_path).convert("RGB") 31 | pil_image = crop_margin(pil_image) 32 | 33 | # Select appropriate prompt based on element type 34 | if element_type == "table": 35 | prompt = "Parse the table in the image." 36 | label = "tab" 37 | elif element_type == "formula": 38 | prompt = "Read text in the image." 39 | label = "formula" 40 | else: # Default to text 41 | prompt = "Read text in the image." 
42 | label = "text" 43 | 44 | # Process the element 45 | result = model.chat(prompt, pil_image) 46 | 47 | # Create recognition result in the same format as the document parser 48 | recognition_result = [ 49 | { 50 | "label": label, 51 | "text": result.strip(), 52 | } 53 | ] 54 | 55 | # Save results if save_dir is provided 56 | if save_dir: 57 | save_outputs(recognition_result, image_path, save_dir) 58 | print(f"Results saved to {save_dir}") 59 | 60 | return result, recognition_result 61 | 62 | 63 | def main(): 64 | parser = argparse.ArgumentParser(description="Element-level processing using DOLPHIN model") 65 | parser.add_argument("--config", default="./config/Dolphin.yaml", help="Path to configuration file") 66 | parser.add_argument("--input_path", type=str, required=True, help="Path to input image or directory of images") 67 | parser.add_argument( 68 | "--element_type", 69 | type=str, 70 | choices=["text", "table", "formula"], 71 | default="text", 72 | help="Type of element to process (text, table, formula)", 73 | ) 74 | parser.add_argument( 75 | "--save_dir", 76 | type=str, 77 | default=None, 78 | help="Directory to save parsing results (default: same as input directory)", 79 | ) 80 | parser.add_argument("--print_results", action="store_true", help="Print recognition results to console") 81 | args = parser.parse_args() 82 | 83 | # Load Model 84 | config = OmegaConf.load(args.config) 85 | model = DOLPHIN(config) 86 | 87 | # Set save directory 88 | save_dir = args.save_dir or ( 89 | args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path) 90 | ) 91 | setup_output_dirs(save_dir) 92 | 93 | # Collect Images 94 | if os.path.isdir(args.input_path): 95 | image_files = [] 96 | for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]: 97 | image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}"))) 98 | image_files = sorted(image_files) 99 | else: 100 | if not os.path.exists(args.input_path): 101 | raise FileNotFoundError(f"Input path {args.input_path} does not exist") 102 | image_files = [args.input_path] 103 | 104 | total_samples = len(image_files) 105 | print(f"\nTotal samples to process: {total_samples}") 106 | 107 | # Process images one by one 108 | for image_path in image_files: 109 | print(f"\nProcessing {image_path}") 110 | try: 111 | result, recognition_result = process_element( 112 | image_path=image_path, 113 | model=model, 114 | element_type=args.element_type, 115 | save_dir=save_dir, 116 | ) 117 | 118 | if args.print_results: 119 | print("\nRecognition result:") 120 | print(result) 121 | print("-" * 40) 122 | 123 | except Exception as e: 124 | print(f"Error processing {image_path}: {str(e)}") 125 | continue 126 | 127 | 128 | if __name__ == "__main__": 129 | main() 130 | -------------------------------------------------------------------------------- /demo_element_hf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates 3 | SPDX-License-Identifier: MIT 4 | """ 5 | 6 | import argparse 7 | import glob 8 | import os 9 | 10 | import torch 11 | from PIL import Image 12 | from transformers import AutoProcessor, VisionEncoderDecoderModel 13 | 14 | from utils.utils import * 15 | 16 | 17 | class DOLPHIN: 18 | def __init__(self, model_id_or_path): 19 | """Initialize the Hugging Face model 20 | 21 | Args: 22 | model_id_or_path: Path to local model or Hugging Face model ID 23 | """ 24 | # Load model from local path or Hugging Face hub 25 | self.processor = AutoProcessor.from_pretrained(model_id_or_path) 26 | self.model = VisionEncoderDecoderModel.from_pretrained(model_id_or_path) 27 | self.model.eval() 28 | 29 | # Set device and precision 30 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 31 | self.model.to(self.device) 32 | self.model = self.model.half() # Always use half precision by default 33 | 34 | # set tokenizer 35 | self.tokenizer = self.processor.tokenizer 36 | 37 | def chat(self, prompt, image): 38 | """Process an image with the given prompt 39 | 40 | Args: 41 | prompt: Text prompt to guide the model 42 | image: PIL Image to process 43 | 44 | Returns: 45 | Generated text from the model 46 | """ 47 | # Prepare image 48 | pixel_values = self.processor(image, return_tensors="pt").pixel_values 49 | pixel_values = pixel_values.half() 50 | 51 | # Prepare prompt 52 | prompt = f"<s>{prompt} <Answer/>" 53 | prompt_ids = self.tokenizer( 54 | prompt, 55 | add_special_tokens=False, 56 | return_tensors="pt" 57 | ).input_ids.to(self.device) 58 | 59 | decoder_attention_mask = torch.ones_like(prompt_ids) 60 | 61 | # Generate text 62 | outputs = self.model.generate( 63 | pixel_values=pixel_values.to(self.device), 64 | decoder_input_ids=prompt_ids, 65 | decoder_attention_mask=decoder_attention_mask, 66 | min_length=1, 67 | max_length=4096, 68 | pad_token_id=self.tokenizer.pad_token_id, 69 | eos_token_id=self.tokenizer.eos_token_id, 70 | use_cache=True, 71 | bad_words_ids=[[self.tokenizer.unk_token_id]], 72 | return_dict_in_generate=True, 73 | do_sample=False, 74 | num_beams=1, 75 | ) 76 | 77 | # Process the output 78 | sequence = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0] 79 | sequence = sequence.replace(prompt, "").replace("</s>", "").replace("<pad>", "").strip() 80 | 81 | return sequence 82 | 83 | def process_element(image_path, model, element_type, save_dir=None): 84 | """Process a single element image (text, table, formula) 85 | 86 | Args: 87 | image_path: Path to the element image 88 | model: HFModel model instance 89 | element_type: Type of element ('text', 'table', 'formula') 90 | save_dir: Directory to save results (default: same as input directory) 91 | 92 | Returns: 93 | Parsed content of the element and recognition results 94 | """ 95 | # Load and prepare image 96 | pil_image = Image.open(image_path).convert("RGB") 97 | pil_image = crop_margin(pil_image) 98 | 99 | # Select appropriate prompt based on element type 100 | if element_type == "table": 101 | prompt = "Parse the table in the image." 102 | label = "tab" 103 | elif element_type == "formula": 104 | prompt = "Read text in the image." 105 | label = "formula" 106 | else: # Default to text 107 | prompt = "Read text in the image."
108 | label = "text" 109 | 110 | # Process the element 111 | result = model.chat(prompt, pil_image) 112 | 113 | # Create recognition result in the same format as the document parser 114 | recognition_result = [ 115 | { 116 | "label": label, 117 | "text": result.strip(), 118 | } 119 | ] 120 | 121 | # Save results if save_dir is provided 122 | if save_dir: 123 | save_outputs(recognition_result, image_path, save_dir) 124 | print(f"Results saved to {save_dir}") 125 | 126 | return result, recognition_result 127 | 128 | 129 | def main(): 130 | parser = argparse.ArgumentParser(description="Element-level processing using DOLPHIN model") 131 | parser.add_argument("--model_path", default="./hf_model", help="Path to Hugging Face model") 132 | parser.add_argument("--input_path", type=str, required=True, help="Path to input image or directory of images") 133 | parser.add_argument( 134 | "--element_type", 135 | type=str, 136 | choices=["text", "table", "formula"], 137 | default="text", 138 | help="Type of element to process (text, table, formula)", 139 | ) 140 | parser.add_argument( 141 | "--save_dir", 142 | type=str, 143 | default=None, 144 | help="Directory to save parsing results (default: same as input directory)", 145 | ) 146 | parser.add_argument("--print_results", action="store_true", help="Print recognition results to console") 147 | args = parser.parse_args() 148 | 149 | # Load Model 150 | model = DOLPHIN(args.model_path) 151 | 152 | # Set save directory 153 | save_dir = args.save_dir or ( 154 | args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path) 155 | ) 156 | setup_output_dirs(save_dir) 157 | 158 | # Collect Images 159 | if os.path.isdir(args.input_path): 160 | image_files = [] 161 | for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]: 162 | image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}"))) 163 | image_files = sorted(image_files) 164 | else: 165 | if not os.path.exists(args.input_path): 166 | raise FileNotFoundError(f"Input path {args.input_path} does not exist") 167 | image_files = [args.input_path] 168 | 169 | total_samples = len(image_files) 170 | print(f"\nTotal samples to process: {total_samples}") 171 | 172 | # Process images one by one 173 | for image_path in image_files: 174 | print(f"\nProcessing {image_path}") 175 | try: 176 | result, recognition_result = process_element( 177 | image_path=image_path, 178 | model=model, 179 | element_type=args.element_type, 180 | save_dir=save_dir, 181 | ) 182 | 183 | if args.print_results: 184 | print("\nRecognition result:") 185 | print(result) 186 | print("-" * 40) 187 | except Exception as e: 188 | print(f"Error processing {image_path}: {str(e)}") 189 | continue 190 | 191 | 192 | if __name__ == "__main__": 193 | main() 194 | -------------------------------------------------------------------------------- /demo_page.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates 3 | SPDX-License-Identifier: MIT 4 | """ 5 | 6 | import argparse 7 | import glob 8 | import os 9 | 10 | import cv2 11 | from omegaconf import OmegaConf 12 | from PIL import Image 13 | 14 | from chat import DOLPHIN 15 | from utils.utils import * 16 | 17 | 18 | def process_page(image_path, model, save_dir, max_batch_size): 19 | """Parse document images with two stages""" 20 | # Stage 1: Page-level layout and reading order parsing 21 | pil_image = Image.open(image_path).convert("RGB") 22 | layout_output = model.chat("Parse the reading order of this document.", pil_image) 23 | 24 | # Stage 2: Element-level content parsing 25 | padded_image, dims = prepare_image(pil_image) 26 | recognition_results = process_elements(layout_output, padded_image, dims, model, max_batch_size) 27 | 28 | # Save outputs 29 | json_path = save_outputs(recognition_results, image_path, save_dir) 30 | 31 | return json_path, recognition_results 32 | 33 | 34 | def process_elements(layout_results, padded_image, dims, model, max_batch_size): 35 | """Parse all document elements with parallel decoding""" 36 | layout_results = parse_layout_string(layout_results) 37 | 38 | text_table_elements = [] # Elements that need processing 39 | figure_results = [] # Figure elements (no processing needed) 40 | previous_box = None 41 | reading_order = 0 42 | 43 | # Collect elements for processing 44 | for bbox, label in layout_results: 45 | try: 46 | # Adjust coordinates 47 | x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates( 48 | bbox, padded_image, dims, previous_box 49 | ) 50 | 51 | # Crop and parse element 52 | cropped = padded_image[y1:y2, x1:x2] 53 | if cropped.size > 0: 54 | if label == "fig": 55 | # For figure regions, add empty text result immediately 56 | figure_results.append( 57 | { 58 | "label": label, 59 | "bbox": [orig_x1, orig_y1, orig_x2, orig_y2], 60 | "text": "", 61 | "reading_order": reading_order, 62 | } 63 | ) 64 | else: 65 | # For text or table regions, prepare for parsing 66 | pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)) 67 | prompt = "Parse the table in the image." if label == "tab" else "Read text in the image." 
68 | text_table_elements.append( 69 | { 70 | "crop": pil_crop, 71 | "prompt": prompt, 72 | "label": label, 73 | "bbox": [orig_x1, orig_y1, orig_x2, orig_y2], 74 | "reading_order": reading_order, 75 | } 76 | ) 77 | 78 | reading_order += 1 79 | 80 | except Exception as e: 81 | print(f"Error processing bbox with label {label}: {str(e)}") 82 | continue 83 | 84 | # Parse text/table elements in parallel 85 | recognition_results = figure_results 86 | if text_table_elements: 87 | crops_list = [elem["crop"] for elem in text_table_elements] 88 | prompts_list = [elem["prompt"] for elem in text_table_elements] 89 | 90 | # Inference in batch 91 | batch_results = model.chat(prompts_list, crops_list, max_batch_size=max_batch_size) 92 | 93 | # Add batch results to recognition_results 94 | for i, result in enumerate(batch_results): 95 | elem = text_table_elements[i] 96 | recognition_results.append( 97 | { 98 | "label": elem["label"], 99 | "bbox": elem["bbox"], 100 | "text": result.strip(), 101 | "reading_order": elem["reading_order"], 102 | } 103 | ) 104 | 105 | # Sort elements by reading order 106 | recognition_results.sort(key=lambda x: x.get("reading_order", 0)) 107 | 108 | return recognition_results 109 | 110 | 111 | def main(): 112 | parser = argparse.ArgumentParser(description="Document processing tool using DOLPHIN model") 113 | parser.add_argument("--config", default="./config/Dolphin.yaml", help="Path to configuration file") 114 | parser.add_argument("--input_path", type=str, default="./demo", help="Path to input image or directory of images") 115 | parser.add_argument( 116 | "--save_dir", 117 | type=str, 118 | default=None, 119 | help="Directory to save parsing results (default: same as input directory)", 120 | ) 121 | parser.add_argument( 122 | "--max_batch_size", 123 | type=int, 124 | default=4, 125 | help="Maximum number of document elements to parse in a single batch (default: 4)", 126 | ) 127 | args = parser.parse_args() 128 | 129 | # Load Model 130 | config = OmegaConf.load(args.config) 131 | model = DOLPHIN(config) 132 | 133 | # Collect Document Images 134 | if os.path.isdir(args.input_path): 135 | image_files = [] 136 | for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]: 137 | image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}"))) 138 | image_files = sorted(image_files) 139 | else: 140 | if not os.path.exists(args.input_path): 141 | raise FileNotFoundError(f"Input path {args.input_path} does not exist") 142 | image_files = [args.input_path] 143 | 144 | save_dir = args.save_dir or ( 145 | args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path) 146 | ) 147 | setup_output_dirs(save_dir) 148 | 149 | total_samples = len(image_files) 150 | print(f"\nTotal samples to process: {total_samples}") 151 | 152 | # Process All Document Images 153 | for image_path in image_files: 154 | print(f"\nProcessing {image_path}") 155 | try: 156 | json_path, recognition_results = process_page( 157 | image_path=image_path, 158 | model=model, 159 | save_dir=save_dir, 160 | max_batch_size=args.max_batch_size, 161 | ) 162 | 163 | print(f"Processing completed. 
Results saved to {save_dir}") 164 | 165 | except Exception as e: 166 | print(f"Error processing {image_path}: {str(e)}") 167 | continue 168 | 169 | 170 | if __name__ == "__main__": 171 | main() 172 | -------------------------------------------------------------------------------- /demo_page_hf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 3 | SPDX-License-Identifier: MIT 4 | """ 5 | 6 | import argparse 7 | import glob 8 | import os 9 | 10 | import cv2 11 | import torch 12 | from PIL import Image 13 | from transformers import AutoProcessor, VisionEncoderDecoderModel 14 | 15 | from utils.utils import * 16 | 17 | 18 | class DOLPHIN: 19 | def __init__(self, model_id_or_path): 20 | """Initialize the Hugging Face model 21 | 22 | Args: 23 | model_id_or_path: Path to local model or Hugging Face model ID 24 | """ 25 | # Load model from local path or Hugging Face hub 26 | self.processor = AutoProcessor.from_pretrained(model_id_or_path) 27 | self.model = VisionEncoderDecoderModel.from_pretrained(model_id_or_path) 28 | self.model.eval() 29 | 30 | # Set device and precision 31 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 32 | self.model.to(self.device) 33 | self.model = self.model.half() # Always use half precision by default 34 | 35 | # set tokenizer 36 | self.tokenizer = self.processor.tokenizer 37 | 38 | def chat(self, prompt, image): 39 | """Process an image or batch of images with the given prompt(s) 40 | 41 | Args: 42 | prompt: Text prompt or list of prompts to guide the model 43 | image: PIL Image or list of PIL Images to process 44 | 45 | Returns: 46 | Generated text or list of texts from the model 47 | """ 48 | # Check if we're dealing with a batch 49 | is_batch = isinstance(image, list) 50 | 51 | if not is_batch: 52 | # Single image, wrap it in a list for consistent processing 53 | images = [image] 54 | prompts = [prompt] 55 | else: 56 | # Batch of images 57 | images = image 58 | prompts = prompt if isinstance(prompt, list) else [prompt] * len(images) 59 | 60 | # Prepare image 61 | batch_inputs = self.processor(images, return_tensors="pt", padding=True) 62 | batch_pixel_values = batch_inputs.pixel_values.half().to(self.device) 63 | 64 | # Prepare prompt 65 | prompts = [f"<s>{p} <Answer/>" for p in prompts] 66 | batch_prompt_inputs = self.tokenizer( 67 | prompts, 68 | add_special_tokens=False, 69 | return_tensors="pt" 70 | ) 71 | 72 | batch_prompt_ids = batch_prompt_inputs.input_ids.to(self.device) 73 | batch_attention_mask = batch_prompt_inputs.attention_mask.to(self.device) 74 | 75 | # Generate text 76 | outputs = self.model.generate( 77 | pixel_values=batch_pixel_values, 78 | decoder_input_ids=batch_prompt_ids, 79 | decoder_attention_mask=batch_attention_mask, 80 | min_length=1, 81 | max_length=4096, 82 | pad_token_id=self.tokenizer.pad_token_id, 83 | eos_token_id=self.tokenizer.eos_token_id, 84 | use_cache=True, 85 | bad_words_ids=[[self.tokenizer.unk_token_id]], 86 | return_dict_in_generate=True, 87 | do_sample=False, 88 | num_beams=1, 89 | repetition_penalty=1.1 90 | ) 91 | 92 | # Process output 93 | sequences = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False) 94 | 95 | # Clean prompt text from output 96 | results = [] 97 | for i, sequence in enumerate(sequences): 98 | cleaned = sequence.replace(prompts[i], "").replace("</s>", "").replace("<pad>", "").strip() 99 | results.append(cleaned) 100 | 101 | # Return a single result for single image input 102 | if not
is_batch: 103 | return results[0] 104 | return results 105 | 106 | 107 | def process_page(image_path, model, save_dir, max_batch_size=None): 108 | """Parse document images with two stages""" 109 | # Stage 1: Page-level layout and reading order parsing 110 | pil_image = Image.open(image_path).convert("RGB") 111 | layout_output = model.chat("Parse the reading order of this document.", pil_image) 112 | 113 | # Stage 2: Element-level content parsing 114 | padded_image, dims = prepare_image(pil_image) 115 | recognition_results = process_elements(layout_output, padded_image, dims, model, max_batch_size) 116 | 117 | # Save outputs 118 | json_path = save_outputs(recognition_results, image_path, save_dir) 119 | 120 | return json_path, recognition_results 121 | 122 | 123 | def process_elements(layout_results, padded_image, dims, model, max_batch_size=None): 124 | """Parse all document elements with parallel decoding""" 125 | layout_results = parse_layout_string(layout_results) 126 | 127 | # Store text and table elements separately 128 | text_elements = [] # Text elements 129 | table_elements = [] # Table elements 130 | figure_results = [] # Image elements (no processing needed) 131 | previous_box = None 132 | reading_order = 0 133 | 134 | # Collect elements to process and group by type 135 | for bbox, label in layout_results: 136 | try: 137 | # Adjust coordinates 138 | x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates( 139 | bbox, padded_image, dims, previous_box 140 | ) 141 | 142 | # Crop and parse element 143 | cropped = padded_image[y1:y2, x1:x2] 144 | if cropped.size > 0: 145 | if label == "fig": 146 | # For figure regions, add empty text result immediately 147 | figure_results.append( 148 | { 149 | "label": label, 150 | "bbox": [orig_x1, orig_y1, orig_x2, orig_y2], 151 | "text": "", 152 | "reading_order": reading_order, 153 | } 154 | ) 155 | else: 156 | # Prepare element for parsing 157 | pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)) 158 | element_info = { 159 | "crop": pil_crop, 160 | "label": label, 161 | "bbox": [orig_x1, orig_y1, orig_x2, orig_y2], 162 | "reading_order": reading_order, 163 | } 164 | 165 | # Group by type 166 | if label == "tab": 167 | table_elements.append(element_info) 168 | else: # Text elements 169 | text_elements.append(element_info) 170 | 171 | reading_order += 1 172 | 173 | except Exception as e: 174 | print(f"Error processing bbox with label {label}: {str(e)}") 175 | continue 176 | 177 | # Initialize results list 178 | recognition_results = figure_results.copy() 179 | 180 | # Process text elements (in batches) 181 | if text_elements: 182 | text_results = process_element_batch(text_elements, model, "Read text in the image.", max_batch_size) 183 | recognition_results.extend(text_results) 184 | 185 | # Process table elements (in batches) 186 | if table_elements: 187 | table_results = process_element_batch(table_elements, model, "Parse the table in the image.", max_batch_size) 188 | recognition_results.extend(table_results) 189 | 190 | # Sort elements by reading order 191 | recognition_results.sort(key=lambda x: x.get("reading_order", 0)) 192 | 193 | return recognition_results 194 | 195 | 196 | def process_element_batch(elements, model, prompt, max_batch_size=None): 197 | """Process elements of the same type in batches""" 198 | results = [] 199 | 200 | # Determine batch size 201 | batch_size = len(elements) 202 | if max_batch_size is not None and max_batch_size > 0: 203 | batch_size = min(batch_size, 
max_batch_size) 204 | 205 | # Process in batches 206 | for i in range(0, len(elements), batch_size): 207 | batch_elements = elements[i:i+batch_size] 208 | crops_list = [elem["crop"] for elem in batch_elements] 209 | 210 | # Use the same prompt for all elements in the batch 211 | prompts_list = [prompt] * len(crops_list) 212 | 213 | # Batch inference 214 | batch_results = model.chat(prompts_list, crops_list) 215 | 216 | # Add results 217 | for j, result in enumerate(batch_results): 218 | elem = batch_elements[j] 219 | results.append({ 220 | "label": elem["label"], 221 | "bbox": elem["bbox"], 222 | "text": result.strip(), 223 | "reading_order": elem["reading_order"], 224 | }) 225 | 226 | return results 227 | 228 | 229 | def main(): 230 | parser = argparse.ArgumentParser(description="Document processing tool using DOLPHIN model") 231 | parser.add_argument("--model_path", default="./hf_model", help="Path to Hugging Face model") 232 | parser.add_argument("--input_path", type=str, default="./demo", help="Path to input image or directory of images") 233 | parser.add_argument( 234 | "--save_dir", 235 | type=str, 236 | default=None, 237 | help="Directory to save parsing results (default: same as input directory)", 238 | ) 239 | parser.add_argument( 240 | "--max_batch_size", 241 | type=int, 242 | default=16, 243 | help="Maximum number of document elements to parse in a single batch (default: 16)", 244 | ) 245 | args = parser.parse_args() 246 | 247 | # Load Model 248 | model = DOLPHIN(args.model_path) 249 | 250 | # Collect Document Images 251 | if os.path.isdir(args.input_path): 252 | image_files = [] 253 | for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]: 254 | image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}"))) 255 | image_files = sorted(image_files) 256 | else: 257 | if not os.path.exists(args.input_path): 258 | raise FileNotFoundError(f"Input path {args.input_path} does not exist") 259 | image_files = [args.input_path] 260 | 261 | save_dir = args.save_dir or ( 262 | args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path) 263 | ) 264 | setup_output_dirs(save_dir) 265 | 266 | total_samples = len(image_files) 267 | print(f"\nTotal samples to process: {total_samples}") 268 | 269 | # Process All Document Images 270 | for image_path in image_files: 271 | print(f"\nProcessing {image_path}") 272 | try: 273 | json_path, recognition_results = process_page( 274 | image_path=image_path, 275 | model=model, 276 | save_dir=save_dir, 277 | max_batch_size=args.max_batch_size, 278 | ) 279 | 280 | print(f"Processing completed. 
Results saved to {save_dir}") 281 | 282 | except Exception as e: 283 | print(f"Error processing {image_path}: {str(e)}") 284 | continue 285 | 286 | 287 | if __name__ == "__main__": 288 | main() 289 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | include = '\.pyi?$' 4 | exclude = ''' 5 | /( 6 | \.git 7 | | \.hg 8 | | \.mypy_cache 9 | | \.tox 10 | | \.venv 11 | | _build 12 | | buck-out 13 | | build 14 | | dist 15 | )/ 16 | ''' 17 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==1.4.0 2 | numpy==1.24.4 3 | omegaconf==2.3.0 4 | opencv-python==4.11.0.86 5 | opencv-python-headless==4.5.5.64 6 | pillow==9.3.0 7 | timm==0.5.4 8 | torch==2.1.0 9 | torchvision==0.16.0 10 | transformers==4.47.0 11 | accelerate==1.6.0 -------------------------------------------------------------------------------- /utils/markdown_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 3 | SPDX-License-Identifier: MIT 4 | """ 5 | 6 | import re 7 | import base64 8 | from typing import List, Dict, Any, Optional 9 | 10 | 11 | """ 12 | Example input: 13 | [ 14 | {"label": "tab", "bbox": [0.176, 0.74, 0.824, 0.82], "text": "
<table><tr><td></td><td>HellaSwag</td><td>Obqa</td><td>WinoGrande</td><td>ARC-c</td><td>ARC-e</td><td>boolq</td><td>piqa</td><td>Avg</td></tr><tr><td>OPT-1.3B</td><td>53.65</td><td>33.40</td><td>59.59</td><td>29.44</td><td>50.80</td><td>60.83</td><td>72.36</td><td>51.44</td></tr><tr><td>Pythia-1.0B</td><td>47.16</td><td>31.40</td><td>53.43</td><td>27.05</td><td>48.99</td><td>57.83</td><td>69.21</td><td>48.30</td></tr><tr><td>Pythia-1.4B</td><td>52.01</td><td>33.20</td><td>57.38</td><td>28.50</td><td>54.00</td><td>63.27</td><td>70.95</td><td>51.33</td></tr><tr><td>TinyLlama-1.1B</td><td>59.20</td><td>36.00</td><td>59.12</td><td>30.10</td><td>55.25</td><td>57.83</td><td>73.29</td><td>52.99</td></tr></table>
", "reading_order": 6}, 15 | {"label": "cap", "bbox": [0.28, 0.729, 0.711, 0.74], "text": "Table 2: Zero-shot performance on commonsense reasoning tasks", "reading_order": 7}, 16 | {"label": "para", "bbox": [0.176, 0.848, 0.826, 0.873], "text": "We of performance during training We tracked the accuracy of TinyLlama on common-\nsense reasoning benchmarks during its pre-training, as shown in Fig. 2 . Generally, the performance of", "reading_order": 8}, 17 | {"label": "fnote", "bbox": [0.176, 0.88, 0.824, 0.912], "text": "${ }^{4}$ Due to a bug in the config file, the learning rate did not decrease immediately after warmup and remained at\nthe maximum value for several steps before we fixed this.", "reading_order": 9}, 18 | {"label": "foot", "bbox": [0.496, 0.939, 0.501, 0.95], "text": "14", "reading_order": 10} 19 | ] 20 | """ 21 | 22 | 23 | def extract_table_from_html(html_string): 24 | """Extract and clean table tags from HTML string""" 25 | try: 26 | table_pattern = re.compile(r'.*?', re.DOTALL) 27 | tables = table_pattern.findall(html_string) 28 | tables = [re.sub(r']*>', '', table) for table in tables] 29 | return '\n'.join(tables) 30 | except Exception as e: 31 | print(f"extract_table_from_html error: {str(e)}") 32 | return f"
Error extracting table: {str(e)}
" 33 | 34 | 35 | class MarkdownConverter: 36 | """Convert structured recognition results to Markdown format""" 37 | 38 | def __init__(self): 39 | # Define heading levels for different section types 40 | self.heading_levels = { 41 | 'title': '#', 42 | 'sec': '##', 43 | 'sub_sec': '###' 44 | } 45 | 46 | # Define which labels need special handling 47 | self.special_labels = { 48 | 'tab', 'fig', 'title', 'sec', 'sub_sec', 49 | 'list', 'formula', 'reference', 'alg' 50 | } 51 | 52 | def try_remove_newline(self, text: str) -> str: 53 | try: 54 | # Preprocess text to handle line breaks 55 | text = text.strip() 56 | text = text.replace('-\n', '') 57 | 58 | # Handle Chinese text line breaks 59 | def is_chinese(char): 60 | return '\u4e00' <= char <= '\u9fff' 61 | 62 | lines = text.split('\n') 63 | processed_lines = [] 64 | 65 | # Process all lines except the last one 66 | for i in range(len(lines)-1): 67 | current_line = lines[i].strip() 68 | next_line = lines[i+1].strip() 69 | 70 | # Always add the current line, but determine if we need a newline 71 | if current_line: # If current line is not empty 72 | if next_line: # If next line is not empty 73 | # For Chinese text handling 74 | if is_chinese(current_line[-1]) and is_chinese(next_line[0]): 75 | processed_lines.append(current_line) 76 | else: 77 | processed_lines.append(current_line + ' ') 78 | else: 79 | # Next line is empty, add current line with newline 80 | processed_lines.append(current_line + '\n') 81 | else: 82 | # Current line is empty, add an empty line 83 | processed_lines.append('\n') 84 | 85 | # Add the last line 86 | if lines and lines[-1].strip(): 87 | processed_lines.append(lines[-1].strip()) 88 | 89 | text = ''.join(processed_lines) 90 | 91 | return text 92 | except Exception as e: 93 | print(f"try_remove_newline error: {str(e)}") 94 | return text # Return original text on error 95 | 96 | def _handle_text(self, text: str) -> str: 97 | """ 98 | Process regular text content, preserving paragraph structure 99 | """ 100 | try: 101 | if not text: 102 | return "" 103 | 104 | if text.strip().startswith("\\begin{array}") and text.strip().endswith("\\end{array}"): 105 | text = "$$" + text + "$$" 106 | elif ("_{" in text or "^{" in text or "\\" in text or "_ {" in text or "^ {" in text) and ("$" not in text) and ("\\begin" not in text): 107 | text = "$" + text + "$" 108 | 109 | # Process formulas in text before handling other text processing 110 | text = self._process_formulas_in_text(text) 111 | 112 | text = self.try_remove_newline(text) 113 | 114 | # Return processed text 115 | return text 116 | except Exception as e: 117 | print(f"_handle_text error: {str(e)}") 118 | return text # Return original text on error 119 | 120 | def _process_formulas_in_text(self, text: str) -> str: 121 | """ 122 | Process mathematical formulas in text by iteratively finding and replacing formulas. 
123 | - Identify inline and block formulas 124 | - Replace newlines within formulas with \\ 125 | """ 126 | try: 127 | # Define formula delimiters and their corresponding patterns 128 | delimiters = [ 129 | ('$$', '$$'), # Block formula with $$ 130 | ('\\[', '\\]'), # Block formula with \[ \] 131 | ('$', '$'), # Inline formula with $ 132 | ('\\(', '\\)') # Inline formula with \( \) 133 | ] 134 | 135 | # Process the text by iterating through each delimiter type 136 | result = text 137 | 138 | for start_delim, end_delim in delimiters: 139 | # Create a pattern that matches from start to end delimiter 140 | # Using a custom approach to avoid issues with nested delimiters 141 | current_pos = 0 142 | processed_parts = [] 143 | 144 | while current_pos < len(result): 145 | # Find the next start delimiter 146 | start_pos = result.find(start_delim, current_pos) 147 | if start_pos == -1: 148 | # No more formulas of this type 149 | processed_parts.append(result[current_pos:]) 150 | break 151 | 152 | # Add text before the formula 153 | processed_parts.append(result[current_pos:start_pos]) 154 | 155 | # Find the matching end delimiter 156 | end_pos = result.find(end_delim, start_pos + len(start_delim)) 157 | if end_pos == -1: 158 | # No matching end delimiter, treat as regular text 159 | processed_parts.append(result[start_pos:]) 160 | break 161 | 162 | # Extract the formula content (without delimiters) 163 | formula_content = result[start_pos + len(start_delim):end_pos] 164 | 165 | # Process the formula content - replace newlines with \\ 166 | processed_formula = formula_content.replace('\n', ' \\\\ ') 167 | 168 | # Add the processed formula with its delimiters 169 | processed_parts.append(f"{start_delim}{processed_formula}{end_delim}") 170 | 171 | # Move past this formula 172 | current_pos = end_pos + len(end_delim) 173 | 174 | # Update the result with processed text 175 | result = ''.join(processed_parts) 176 | return result 177 | except Exception as e: 178 | print(f"_process_formulas_in_text error: {str(e)}") 179 | return text # Return original text on error 180 | 181 | def _remove_newline_in_heading(self, text: str) -> str: 182 | """ 183 | Remove newline in heading 184 | """ 185 | try: 186 | # Handle Chinese text line breaks 187 | def is_chinese(char): 188 | return '\u4e00' <= char <= '\u9fff' 189 | 190 | # Check if the text contains Chinese characters 191 | if any(is_chinese(char) for char in text): 192 | return text.replace('\n', '') 193 | else: 194 | return text.replace('\n', ' ') 195 | 196 | except Exception as e: 197 | print(f"_remove_newline_in_heading error: {str(e)}") 198 | return text 199 | 200 | def _handle_heading(self, text: str, label: str) -> str: 201 | """ 202 | Convert section headings to appropriate markdown format 203 | """ 204 | try: 205 | level = self.heading_levels.get(label, '#') 206 | text = text.strip() 207 | text = self._remove_newline_in_heading(text) 208 | text = self._handle_text(text) 209 | return f"{level} {text}\n\n" 210 | except Exception as e: 211 | print(f"_handle_heading error: {str(e)}") 212 | return f"# Error processing heading: {text}\n\n" 213 | 214 | def _handle_list_item(self, text: str) -> str: 215 | """ 216 | Convert list items to markdown list format 217 | """ 218 | try: 219 | return f"- {text.strip()}\n" 220 | except Exception as e: 221 | print(f"_handle_list_item error: {str(e)}") 222 | return f"- Error processing list item: {text}\n" 223 | 224 | def _handle_figure(self, text: str, section_count: int) -> str: 225 | """ 226 | Convert base64 encoded image 
to markdown image syntax 227 | """ 228 | try: 229 | # Determine image format (assuming PNG if not specified) 230 | img_format = "png" 231 | if text.startswith("data:image/"): 232 | # Text is already a complete data URI, embed it directly 233 | return f"![Figure {section_count}]({text})\n\n" 234 | elif ";" in text and "," in text: 235 | # Already in data URI format 236 | return f"![Figure {section_count}]({text})\n\n" 237 | else: 238 | # Raw base64, convert to data URI 239 | data_uri = f"data:image/{img_format};base64,{text}" 240 | return f"![Figure {section_count}]({data_uri})\n\n" 241 | except Exception as e: 242 | print(f"_handle_figure error: {str(e)}") 243 | return f"*[Error processing figure: {str(e)}]*\n\n" 244 | 245 | def _handle_table(self, text: str) -> str: 246 | """ 247 | Convert table content to markdown format 248 | """ 249 | try: 250 | markdown_content = [] 251 | if '<table' in text: 252 | # HTML table produced by the recognizer: extract and clean the <table> markup 253 | markdown_table = extract_table_from_html(text) 254 | if markdown_table: 255 | markdown_content.append(markdown_table) 256 | else: 257 | # Fall back to the raw text when no table markup could be extracted 258 | markdown_content.append(text) 259 | else: 260 | # Plain-text or LaTeX table content: keep it as-is 261 | markdown_content.append(text) 262 | 263 | # Join the collected pieces into a single markdown block 264 | return '\n'.join(markdown_content) + '\n\n' 265 | 266 | except Exception as e: 267 | print(f"_handle_table error: {str(e)}") 268 | return f"*[Error processing table: {str(e)}]*\n\n" 269 | 270 | 271 | def _handle_algorithm(self, text: str) -> str: 272 | """ 273 | Process algorithm blocks with proper formatting 274 | """ 275 | try: 276 | # Remove algorithm environment tags if present 277 | text = re.sub(r'\\begin\{algorithm\}(.*?)\\end\{algorithm\}', r'\1', text, flags=re.DOTALL) 278 | text = text.replace('\\begin{algorithm}', '').replace('\\end{algorithm}', '') 279 | text = text.replace('\\begin{algorithmic}', '').replace('\\end{algorithmic}', '') 280 | 281 | # Process the algorithm text 282 | lines = text.strip().split('\n') 283 | 284 | # Check if there's a caption or label 285 | caption = "" 286 | algorithm_text = [] 287 | 288 | for line in lines: 289 | if '\\caption' in line: 290 | # Extract caption text 291 | caption_match = re.search(r'\\caption\{(.*?)\}', line) 292 | if caption_match: 293 | caption = f"**{caption_match.group(1)}**\n\n" 294 | continue 295 | elif '\\label' in line: 296 | continue # Skip label lines 297 | else: 298 | algorithm_text.append(line) 299 | 300 | # Join the algorithm text and wrap in code block 301 | formatted_text = '\n'.join(algorithm_text) 302 | 303 | # Return the formatted algorithm with caption 304 | return f"{caption}```\n{formatted_text}\n```\n\n" 305 | except Exception as e: 306 | print(f"_handle_algorithm error: {str(e)}") 307 | return f"*[Error processing algorithm: {str(e)}]*\n\n{text}\n\n" 308 | 309 | def _handle_formula(self, text: str) -> str: 310 | """ 311 | Handle formula-specific content 312 | """ 313 | try: 314 | # Process the formula content 315 | processed_text = self._process_formulas_in_text(text) 316 | 317 | # For formula blocks, ensure they're properly formatted in markdown 318 | if '$$' not in processed_text and '\\[' not in processed_text: 319 | # If no block formula delimiters are present, wrap in $$ for block formula 320 | processed_text = f'$${processed_text}$$' 321 | 322 | return f"{processed_text}\n\n" 323 | except Exception as e: 324 | print(f"_handle_formula error: {str(e)}") 325 | return f"*[Error processing formula: {str(e)}]*\n\n" 326 | 327 | def convert(self, recognition_results: List[Dict[str, Any]]) -> str: 328 | """ 329 | Convert recognition results to markdown format 330 | """ 331 | try: 332 | markdown_content = [] 333 | 334 | for section_count, result in enumerate(recognition_results): 335 | try: 336 | label = result.get('label', '') 337 | text = result.get('text', '').strip() 338 | 339 | # Skip empty text 340 | if not text: 341 | continue 342 | 343 | # Handle different content types 344 | if label in {'title', 'sec', 'sub_sec'}: 345 | markdown_content.append(self._handle_heading(text, label)) 346 | elif label == 'list': 347 | markdown_content.append(self._handle_list_item(text)) 348 | elif label == 'fig':
349 | markdown_content.append(self._handle_figure(text, section_count)) 350 | elif label == 'tab': 351 | markdown_content.append(self._handle_table(text)) 352 | elif label == 'alg': 353 | markdown_content.append(self._handle_algorithm(text)) 354 | elif label == 'formula': 355 | markdown_content.append(self._handle_formula(text)) 356 | elif label not in self.special_labels: 357 | # Handle regular text (paragraphs, etc.) 358 | processed_text = self._handle_text(text) 359 | markdown_content.append(f"{processed_text}\n\n") 360 | except Exception as e: 361 | print(f"Error processing item {section_count}: {str(e)}") 362 | # Add a placeholder for the failed item 363 | markdown_content.append("*[Error processing content]*\n\n") 364 | 365 | # Join all content and apply post-processing 366 | result = ''.join(markdown_content) 367 | return self._post_process(result) 368 | except Exception as e: 369 | print(f"convert error: {str(e)}") 370 | return f"Error generating markdown content: {str(e)}" 371 | 372 | def _post_process(self, markdown_content: str) -> str: 373 | """ 374 | Apply post-processing fixes to the generated markdown content 375 | """ 376 | try: 377 | # Handle author information 378 | author_pattern = re.compile(r'\\author\{(.*?)\}', re.DOTALL) 379 | 380 | def process_author_match(match): 381 | # Extract author content 382 | author_content = match.group(1) 383 | # Process the author content 384 | return self._handle_text(author_content) 385 | 386 | # Replace \author{...} with processed content 387 | markdown_content = author_pattern.sub(process_author_match, markdown_content) 388 | 389 | # Handle special case where author is inside math environment 390 | math_author_pattern = re.compile(r'\$(\\author\{.*?\})\$', re.DOTALL) 391 | match = math_author_pattern.search(markdown_content) 392 | if match: 393 | # Extract the author command 394 | author_cmd = match.group(1) 395 | # Extract content from author command 396 | author_content_match = re.search(r'\\author\{(.*?)\}', author_cmd, re.DOTALL) 397 | if author_content_match: 398 | # Get author content and process it 399 | author_content = author_content_match.group(1) 400 | processed_content = self._handle_text(author_content) 401 | # Replace the entire $\author{...}$ block with processed content 402 | markdown_content = markdown_content.replace(match.group(0), processed_content) 403 | 404 | # Replace LaTeX abstract environment with plain text 405 | markdown_content = re.sub(r'\\begin\{abstract\}(.*?)\\end\{abstract\}', 406 | r'**Abstract** \1', 407 | markdown_content, 408 | flags=re.DOTALL) 409 | 410 | # Replace standalone \begin{abstract} (without matching end) 411 | markdown_content = re.sub(r'\\begin\{abstract\}', 412 | r'**Abstract**', 413 | markdown_content) 414 | 415 | # Replace LaTeX equation numbers with tag format, handling cases with extra backslashes 416 | markdown_content = re.sub(r'\\eqno\{\((.*?)\)\}', 417 | r'\\tag{\1}', 418 | markdown_content) 419 | 420 | # Find the starting tag of the formula 421 | markdown_content = markdown_content.replace("\\[ \\\\", "$$ \\\\") 422 | 423 | # Find the ending tag of the formula (ensure this is the only ending tag) 424 | markdown_content = markdown_content.replace("\\\\ \\]", "\\\\ $$") 425 | 426 | # Fix other common LaTeX issues 427 | replacements = [ 428 | # Fix spacing issues in subscripts and superscripts 429 | (r'_ {', r'_{'), 430 | (r'\^ {', r'^{'), 431 | 432 | # Fix potential issues with multiple consecutive newlines 433 | (r'\n{3,}', r'\n\n') 434 | ] 435 | 436 | for old, new in
replacements: 437 | markdown_content = re.sub(old, new, markdown_content) 438 | 439 | return markdown_content 440 | except Exception as e: 441 | print(f"_post_process error: {str(e)}") 442 | return markdown_content # Return original content if post-processing fails 443 | -------------------------------------------------------------------------------- /utils/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022-present NAVER Corp. 3 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. 4 | MIT License 5 | This file has been modified by [ByteDance Ltd. and/or its affiliates] on 20250118. 6 | The original file available at https://github.com/clovaai/donut/blob/master/donut/model.py was released under the MIT license. 7 | This modified file is released under the same license. 8 | """ 9 | 10 | import logging 11 | from collections import defaultdict 12 | from typing import List, Optional 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | from PIL import Image 17 | from timm.models.swin_transformer import SwinTransformer 18 | from torch import nn 19 | from transformers import ( 20 | MBartConfig, 21 | MBartForCausalLM, 22 | StoppingCriteria, 23 | StoppingCriteriaList, 24 | ) 25 | from transformers.file_utils import ModelOutput 26 | from transformers.modeling_utils import PretrainedConfig, PreTrainedModel 27 | 28 | 29 | class SwinEncoder(nn.Module): 30 | r""" 31 | Encoder based on SwinTransformer 32 | Set the initial weights and configuration with a pretrained SwinTransformer and then 33 | modify the detailed configurations 34 | 35 | Args: 36 | input_size: Input image size (width, height) 37 | align_long_axis: Whether to rotate image if height is greater than width 38 | window_size: Window size(=patch size) of SwinTransformer 39 | encoder_layer: Number of layers of SwinTransformer encoder 40 | name_or_path: Name of a pretrained model name either registered in huggingface.co. or saved in local. 41 | otherwise, `swin_base_patch4_window12_384` will be set (using `timm`). 
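        Example (illustrative sketch; the 896x896 input mirrors the default in
        utils/processor.py, while the encoder settings actually used are read
        from config/Dolphin.yaml):

            encoder = SwinEncoder(input_size=[896, 896], encoder_layer=[2, 2, 14, 2], window_size=7)
            feats = encoder(torch.randn(1, 3, 896, 896))  # token features with embed_dim * 8 = 1024 channels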
42 | """ 43 | 44 | def __init__( 45 | self, 46 | input_size, 47 | align_long_axis: bool = False, 48 | window_size: int = 7, 49 | encoder_layer: List[int] = [2, 2, 14, 2], 50 | patch_size: int = [4, 4], 51 | embed_dim: int = 128, 52 | num_heads: List[int] = [4, 8, 16, 32], 53 | ): 54 | super().__init__() 55 | if isinstance(input_size, int): 56 | input_size = [input_size, input_size] 57 | self.input_size = input_size 58 | self.align_long_axis = align_long_axis 59 | self.window_size = window_size 60 | self.encoder_layer = encoder_layer 61 | self.patch_size = patch_size 62 | self.embed_dim = embed_dim 63 | self.num_heads = num_heads 64 | 65 | self.model = SwinTransformer( 66 | img_size=self.input_size, 67 | depths=self.encoder_layer, 68 | window_size=self.window_size, 69 | patch_size=self.patch_size, 70 | embed_dim=self.embed_dim, 71 | num_heads=self.num_heads, 72 | num_classes=0, 73 | ) 74 | 75 | def forward(self, x: torch.Tensor, text_embedding: Optional[torch.Tensor] = None) -> torch.Tensor: 76 | """ 77 | Args: 78 | x: (batch_size, num_channels, height, width) 79 | """ 80 | x = self.model.patch_embed(x) 81 | x = self.model.pos_drop(x) 82 | x = self.model.layers(x) 83 | return x 84 | 85 | 86 | class LayerNorm(nn.LayerNorm): 87 | """Subclass torch's LayerNorm to handle fp16.""" 88 | 89 | def _set_dtype(self, dtype): 90 | self._dtype = dtype 91 | 92 | def forward(self, x: torch.Tensor): 93 | orig_type = x.dtype 94 | ret = super().forward(x.type(dtype=self._dtype)) 95 | return ret.type(orig_type) 96 | 97 | 98 | class BARTDecoder(nn.Module): 99 | """ 100 | Decoder based on Multilingual BART 101 | Set the initial weights and configuration with a pretrained multilingual BART model, 102 | and modify the detailed configurations as a Donut decoder 103 | 104 | Args: 105 | decoder_layer: 106 | Number of layers of BARTDecoder 107 | max_position_embeddings: 108 | The maximum sequence length to be trained 109 | name_or_path: 110 | Name of a pretrained model name either registered in huggingface.co. 
or saved in local, 111 | otherwise, `facebook/mbart-large-50` will be set (using `transformers`) 112 | """ 113 | 114 | def __init__( 115 | self, 116 | tokenizer, 117 | decoder_layer: int, 118 | max_position_embeddings: int, 119 | hidden_dimension: int = 1024, 120 | **kwargs, 121 | ): 122 | super().__init__() 123 | self.decoder_layer = decoder_layer 124 | self.max_position_embeddings = max_position_embeddings 125 | self.hidden_dimension = hidden_dimension 126 | 127 | self.tokenizer = tokenizer 128 | 129 | self.model = MBartForCausalLM( 130 | config=MBartConfig( 131 | tie_word_embeddings=True, 132 | is_decoder=True, 133 | is_encoder_decoder=False, 134 | add_cross_attention=True, 135 | decoder_layers=self.decoder_layer, 136 | max_position_embeddings=self.max_position_embeddings, 137 | vocab_size=len(self.tokenizer), 138 | scale_embedding=True, 139 | add_final_layer_norm=True, 140 | d_model=self.hidden_dimension, 141 | ) 142 | ) 143 | # self.model.config.is_encoder_decoder = True # to get cross-attention 144 | self.model.model.decoder.embed_tokens.padding_idx = self.tokenizer.pad_token_id 145 | self.model.prepare_inputs_for_generation = self.prepare_inputs_for_inference 146 | 147 | def add_special_tokens(self, list_of_tokens: List[str]): 148 | """ 149 | Add special tokens to tokenizer and resize the token embeddings 150 | """ 151 | newly_added_num = self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set(list_of_tokens))}) 152 | if newly_added_num > 0: 153 | self.model.resize_token_embeddings(len(self.tokenizer)) 154 | 155 | def add_tokens(self, list_of_tokens: List[str]): 156 | """ 157 | Add special tokens to tokenizer and resize the token embeddings 158 | """ 159 | newly_added_num = self.tokenizer.add_tokens(sorted(set(list_of_tokens))) 160 | if newly_added_num > 0: 161 | self.model.resize_token_embeddings(len(self.tokenizer)) 162 | 163 | def prepare_inputs_for_inference( 164 | self, 165 | input_ids: torch.Tensor, 166 | encoder_outputs: torch.Tensor, 167 | past=None, 168 | past_key_values=None, 169 | use_cache: bool = None, 170 | attention_mask: torch.Tensor = None, 171 | **kwargs, 172 | ): 173 | """ 174 | Args: 175 | input_ids: (batch_size, sequence_length) 176 | 177 | Returns: 178 | input_ids: (batch_size, sequence_length) 179 | attention_mask: (batch_size, sequence_length) 180 | encoder_hidden_states: (batch_size, sequence_length, embedding_dim) 181 | """ 182 | attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long() 183 | past = past or past_key_values 184 | if past is not None: 185 | input_ids = input_ids[:, -1:] 186 | output = { 187 | "input_ids": input_ids, 188 | "attention_mask": attention_mask, 189 | "past_key_values": past, 190 | "use_cache": use_cache, 191 | "encoder_hidden_states": encoder_outputs.last_hidden_state, 192 | } 193 | return output 194 | 195 | def forward( 196 | self, 197 | input_ids: torch.LongTensor = None, 198 | attention_mask: Optional[torch.Tensor] = None, 199 | encoder_hidden_states: Optional[torch.Tensor] = None, 200 | past_key_values: Optional[torch.Tensor] = None, 201 | inputs_embeds: Optional[torch.FloatTensor] = None, 202 | labels: Optional[torch.Tensor] = None, 203 | use_cache: bool = None, 204 | output_attentions: Optional[torch.Tensor] = None, 205 | output_hidden_states: Optional[torch.Tensor] = None, 206 | return_dict: bool = None, 207 | ): 208 | return self.model.forward( 209 | input_ids=input_ids, 210 | attention_mask=attention_mask, 211 | labels=labels, 212 | encoder_hidden_states=encoder_hidden_states, 213 | 
past_key_values=past_key_values, 214 | inputs_embeds=inputs_embeds, 215 | use_cache=use_cache, 216 | output_attentions=output_attentions, 217 | output_hidden_states=output_hidden_states, 218 | return_dict=return_dict, 219 | ) 220 | 221 | @staticmethod 222 | def resize_bart_abs_pos_emb(weight: torch.Tensor, max_length: int) -> torch.Tensor: 223 | """ 224 | Resize position embeddings 225 | Truncate if sequence length of MBart backbone is greater than given max_length, 226 | else interpolate to max_length 227 | """ 228 | if weight.shape[0] > max_length: 229 | weight = weight[:max_length, ...] 230 | else: 231 | weight = ( 232 | F.interpolate( 233 | weight.permute(1, 0).unsqueeze(0), 234 | size=max_length, 235 | mode="linear", 236 | align_corners=False, 237 | ) 238 | .squeeze(0) 239 | .permute(1, 0) 240 | ) 241 | return weight 242 | 243 | 244 | class DonutConfig(PretrainedConfig): 245 | 246 | def __init__( 247 | self, 248 | decoder_layer: int = 10, 249 | max_position_embeddings: int = None, 250 | max_length: int = 4096, 251 | hidden_dimension: int = 1024, 252 | **kwargs, 253 | ): 254 | super().__init__() 255 | self.decoder_layer = decoder_layer 256 | self.max_position_embeddings = max_length if max_position_embeddings is None else max_position_embeddings 257 | self.max_length = max_length 258 | self.hidden_dimension = hidden_dimension 259 | 260 | 261 | class RunningVarTorch: 262 | def __init__(self, L=15, norm=False): 263 | self.values = None 264 | self.L = L 265 | self.norm = norm 266 | 267 | def push(self, x: torch.Tensor): 268 | assert x.dim() == 1 269 | if self.values is None: 270 | self.values = x[:, None] 271 | elif self.values.shape[1] < self.L: 272 | self.values = torch.cat((self.values, x[:, None]), 1) 273 | else: 274 | self.values = torch.cat((self.values[:, 1:], x[:, None]), 1) 275 | 276 | def variance(self): 277 | if self.values is None: 278 | return 279 | if self.norm: 280 | return torch.var(self.values, 1) / self.values.shape[1] 281 | else: 282 | return torch.var(self.values, 1) 283 | 284 | 285 | class StoppingCriteriaScores(StoppingCriteria): 286 | def __init__(self, threshold: float = 0.015, window_size: int = 200): 287 | super().__init__() 288 | self.threshold = threshold 289 | self.vars = RunningVarTorch(norm=True) 290 | self.varvars = RunningVarTorch(L=window_size) 291 | self.stop_inds = defaultdict(int) 292 | self.stopped = defaultdict(bool) 293 | self.size = 0 294 | self.window_size = window_size 295 | 296 | @torch.no_grad() 297 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): 298 | last_scores = scores[-1] 299 | self.vars.push(last_scores.max(1)[0].float().cpu()) 300 | self.varvars.push(self.vars.variance()) 301 | self.size += 1 302 | if self.size < self.window_size: 303 | return False 304 | 305 | varvar = self.varvars.variance() 306 | for b in range(len(last_scores)): 307 | if varvar[b] < self.threshold: 308 | if self.stop_inds[b] > 0 and not self.stopped[b]: 309 | self.stopped[b] = self.stop_inds[b] >= self.size 310 | else: 311 | self.stop_inds[b] = int(min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095)) 312 | else: 313 | self.stop_inds[b] = 0 314 | self.stopped[b] = False 315 | return all(self.stopped.values()) and len(self.stopped) > 0 316 | 317 | 318 | def batch(l, b=15): 319 | subs = [] 320 | for i in range(len(l) - b): 321 | subs.append(l[i : i + b]) 322 | return subs 323 | 324 | 325 | def subdiv(l, b=10): 326 | subs = [] 327 | for i in range(len(l) - b): 328 | subs.append(l[: i + b]) 329 | return subs 330 | 331 | 332 | class 
DonutModel(PreTrainedModel): 333 | config_class = DonutConfig 334 | base_model_prefix = "donut" 335 | 336 | def __init__(self, config: DonutConfig, vision_tower=None, tokenizer=None): 337 | super().__init__(config) 338 | self.config = config 339 | 340 | self.tokenizer = tokenizer 341 | self.vpm = vision_tower 342 | 343 | # build language model 344 | self.llm = BARTDecoder( 345 | tokenizer=tokenizer, 346 | decoder_layer=self.config.decoder_layer, 347 | max_position_embeddings=self.config.max_position_embeddings, 348 | hidden_dimension=self.config.hidden_dimension, 349 | ) 350 | self.ids_to_tokens = {id: content for content, id in self.llm.tokenizer.vocab.items()} 351 | 352 | def get_input_embeddings(self, tensor): 353 | return self.llm.model.get_input_embeddings()(tensor) 354 | 355 | def forward( 356 | self, 357 | inputs: dict, 358 | ): 359 | image_tensors = inputs["pixel_values"] 360 | input_ids = inputs["input_ids"].contiguous() 361 | attention_mask = inputs["attention_mask"] 362 | labels = inputs["labels"].contiguous() 363 | 364 | encoder_outputs = self.vpm( 365 | image_tensors, 366 | text_embedding=self.llm.model.get_input_embeddings()(input_ids), 367 | ) 368 | 369 | decoder_outputs = self.llm( 370 | input_ids=input_ids, 371 | encoder_hidden_states=encoder_outputs, 372 | attention_mask=attention_mask, 373 | labels=labels, 374 | ) 375 | return decoder_outputs 376 | 377 | def get_hidden_states_during_inference( 378 | self, 379 | prompt_ids: torch.Tensor, 380 | image: Image.Image = None, 381 | image_tensors: Optional[torch.Tensor] = None, 382 | ): 383 | if image_tensors is None: 384 | image_tensors = self.vpm.prepare_input(image).unsqueeze(0) 385 | 386 | if self.device.type != "mps": 387 | image_tensors = image_tensors.to(next(self.parameters()).dtype) 388 | 389 | image_tensors = image_tensors.to(self.device) 390 | prompt_ids = prompt_ids.to(self.device) 391 | all_hidden_states = self.vpm.forward_features( 392 | image_tensors, text_embedding=self.get_input_embeddings(prompt_ids) 393 | ) 394 | return all_hidden_states 395 | 396 | def get_attn_weights_during_inference( 397 | self, 398 | prompt_ids: torch.Tensor, 399 | image: Image.Image = None, 400 | image_tensors: Optional[torch.Tensor] = None, 401 | ): 402 | if image_tensors is None: 403 | image_tensors = self.vpm.prepare_input(image).unsqueeze(0) 404 | 405 | if self.device.type != "mps": 406 | image_tensors = image_tensors.to(next(self.parameters()).dtype) 407 | 408 | image_tensors = image_tensors.to(self.device) 409 | prompt_ids = prompt_ids.to(self.device) 410 | last_attn_score = self.vpm.get_last_layer_cross_attn_score( 411 | image_tensors, text_embedding=self.get_input_embeddings(prompt_ids) 412 | ) 413 | return last_attn_score 414 | 415 | def inference( 416 | self, 417 | prompt_ids: torch.Tensor, 418 | image: Image.Image = None, 419 | image_tensors: Optional[torch.Tensor] = None, 420 | return_attentions: bool = False, 421 | early_stopping: bool = True, 422 | ): 423 | """ 424 | Generate a token sequence in an auto-regressive manner. 
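        A minimal call sketch (the `processor`, `model`, and `element_image` objects and the
        prompt string are illustrative assumptions; the actual prompt construction lives in
        chat.py and the demo scripts):

            prompt_ids = processor.process_prompt_for_inference("Read text in the image.")
            pixel_values = processor.process_image_for_inference(element_image)
            decoded = model.inference(prompt_ids=prompt_ids, image_tensors=pixel_values)["repetitions"][0]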
425 | 426 | Args: 427 | image: input document image (PIL.Image) 428 | image_tensors: (1, num_channels, height, width) 429 | convert prompt to tensor if image_tensor is not fed 430 | """ 431 | output = { 432 | "predictions": list(), 433 | "sequences": list(), 434 | "repeats": list(), 435 | "repetitions": list(), 436 | } 437 | if image is None and image_tensors is None: 438 | logging.warn("Image not found") 439 | return output 440 | 441 | if image_tensors is None: 442 | image_tensors = self.vpm.prepare_input(image).unsqueeze(0) 443 | 444 | if self.device.type != "mps": 445 | image_tensors = image_tensors.to(next(self.parameters()).dtype) 446 | 447 | image_tensors = image_tensors.to(self.device) 448 | prompt_ids = prompt_ids.to(self.device) 449 | last_hidden_state = self.vpm(image_tensors, text_embedding=self.get_input_embeddings(prompt_ids)) 450 | 451 | encoder_outputs = ModelOutput(last_hidden_state=last_hidden_state, attentions=None) 452 | if len(encoder_outputs.last_hidden_state.size()) == 1: 453 | encoder_outputs.last_hidden_state = encoder_outputs.last_hidden_state.unsqueeze(0) 454 | 455 | # get decoder output 456 | decoder_output = self.llm.model.generate( 457 | input_ids=prompt_ids, 458 | encoder_outputs=encoder_outputs, 459 | min_length=1, 460 | max_length=self.config.max_length, 461 | pad_token_id=self.llm.tokenizer.pad_token_id, 462 | eos_token_id=self.llm.tokenizer.eos_token_id, 463 | use_cache=True, 464 | return_dict_in_generate=True, 465 | output_scores=True, 466 | output_attentions=return_attentions, 467 | do_sample=False, 468 | num_beams=1, 469 | stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()] if early_stopping else []), 470 | ) 471 | 472 | output["repetitions"] = decoder_output.sequences.clone() 473 | output["sequences"] = decoder_output.sequences.clone() 474 | output["scores"] = torch.stack(decoder_output.scores, 1).softmax(-1).cpu().max(-1)[0] 475 | 476 | output["repetitions"] = self.llm.tokenizer.batch_decode(output["repetitions"], skip_special_tokens=False) 477 | return output 478 | -------------------------------------------------------------------------------- /utils/processor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates 3 | SPDX-License-Identifier: MIT 4 | """ 5 | 6 | import numpy as np 7 | import torch 8 | from PIL import ImageOps 9 | 10 | from utils.utils import * 11 | 12 | 13 | class DolphinProcessor: 14 | def __init__( 15 | self, 16 | dp_config, 17 | tokenizer, 18 | **kwargs, 19 | ) -> None: 20 | 21 | self.tokenizer = tokenizer 22 | transform_args = kwargs.get("transform_args", {}) 23 | self.max_length = transform_args.get("max_length", 2048) 24 | self.input_size = transform_args.get("input_size", [896, 896]) # height, width 25 | if isinstance(self.input_size, int): 26 | self.input_size = [self.input_size, self.input_size] 27 | 28 | try: 29 | self.answer_start_token = self.tokenizer._prompt_end_token 30 | except AttributeError: 31 | print('No answer_start_token found, use "" instead') 32 | self.answer_start_token = "" 33 | 34 | self.prefix_answer_space_flag = dp_config.get("prefix_answer_space_flag", True) 35 | self.suffix_prompt_space_flag = dp_config.get("suffix_prompt_space_flag", True) 36 | 37 | def process_prompt_for_inference(self, prompt): 38 | prompt = prompt.replace("\n", "") 39 | if not prompt.startswith("<s>"): 40 | prompt = "<s>" + prompt 41 | message_ids = [self.tokenizer.encode(prompt, add_special_tokens=False)] 42 | ids = torch.from_numpy(np.hstack(message_ids, dtype=np.int32)) 43 | return ids.unsqueeze(0) 44 | 45 | def process_image_for_inference(self, image, return_img_size=False): 46 | image = resize(image, min(self.input_size)) 47 | 48 | image.thumbnail((self.input_size[1], self.input_size[0])) 49 | origin_w, origin_h = image.size 50 | 51 | delta_width = self.input_size[1] - image.width 52 | delta_height = self.input_size[0] - image.height 53 | pad_width = delta_width // 2 54 | pad_height = delta_height // 2 55 | padding = ( 56 | pad_width, 57 | pad_height, 58 | delta_width - pad_width, 59 | delta_height - pad_height, 60 | ) 61 | image = ImageOps.expand(image, padding) 62 | if return_img_size: 63 | return test_transform(image).unsqueeze(0), (origin_w, origin_h) 64 | return test_transform(image).unsqueeze(0) 65 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bytedance Ltd.
and/or its affiliates 3 | SPDX-License-Identifier: MIT 4 | """ 5 | 6 | import copy 7 | import json 8 | import os 9 | import re 10 | from dataclasses import dataclass 11 | from typing import List, Tuple 12 | 13 | import albumentations as alb 14 | import cv2 15 | import numpy as np 16 | from albumentations.pytorch import ToTensorV2 17 | from PIL import Image 18 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 19 | from torchvision.transforms.functional import resize 20 | 21 | from utils.markdown_utils import MarkdownConverter 22 | 23 | 24 | def alb_wrapper(transform): 25 | def f(im): 26 | return transform(image=np.asarray(im))["image"] 27 | 28 | return f 29 | 30 | 31 | test_transform = alb_wrapper( 32 | alb.Compose( 33 | [ 34 | alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), 35 | ToTensorV2(), 36 | ] 37 | ) 38 | ) 39 | 40 | 41 | def check_coord_valid(x1, y1, x2, y2, image_size=None, abs_coord=True): 42 | # print(f"check_coord_valid: {x1}, {y1}, {x2}, {y2}, {image_size}, {abs_coord}") 43 | if x2 <= x1 or y2 <= y1: 44 | return False, f"[{x1}, {y1}, {x2}, {y2}]" 45 | if x1 < 0 or y1 < 0: 46 | return False, f"[{x1}, {y1}, {x2}, {y2}]" 47 | if not abs_coord: 48 | if x2 > 1 or y2 > 1: 49 | return False, f"[{x1}, {y1}, {x2}, {y2}]" 50 | elif image_size is not None: # has image size 51 | if x2 > image_size[0] or y2 > image_size[1]: 52 | return False, f"[{x1}, {y1}, {x2}, {y2}]" 53 | return True, None 54 | 55 | 56 | def adjust_box_edges(image, boxes: List[List[float]], max_pixels=15, threshold=0.2): 57 | """ 58 | Image: cv2.image object, or Path 59 | Input: boxes: list of boxes [[x1, y1, x2, y2]]. Using absolute coordinates. 60 | """ 61 | if isinstance(image, str): 62 | image = cv2.imread(image) 63 | img_h, img_w = image.shape[:2] 64 | new_boxes = [] 65 | for box in boxes: 66 | best_box = copy.deepcopy(box) 67 | 68 | def check_edge(img, current_box, i, is_vertical): 69 | edge = current_box[i] 70 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 71 | _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) 72 | 73 | if is_vertical: 74 | line = binary[current_box[1] : current_box[3] + 1, edge] 75 | else: 76 | line = binary[edge, current_box[0] : current_box[2] + 1] 77 | 78 | transitions = np.abs(np.diff(line)) 79 | return np.sum(transitions) / len(transitions) 80 | 81 | # Only widen the box 82 | edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)] 83 | 84 | current_box = copy.deepcopy(box) 85 | # make sure the box is within the image 86 | current_box[0] = min(max(current_box[0], 0), img_w - 1) 87 | current_box[1] = min(max(current_box[1], 0), img_h - 1) 88 | current_box[2] = min(max(current_box[2], 0), img_w - 1) 89 | current_box[3] = min(max(current_box[3], 0), img_h - 1) 90 | 91 | for i, direction, is_vertical in edges: 92 | best_score = check_edge(image, current_box, i, is_vertical) 93 | if best_score <= threshold: 94 | continue 95 | for step in range(max_pixels): 96 | current_box[i] += direction 97 | if i == 0 or i == 2: 98 | current_box[i] = min(max(current_box[i], 0), img_w - 1) 99 | else: 100 | current_box[i] = min(max(current_box[i], 0), img_h - 1) 101 | score = check_edge(image, current_box, i, is_vertical) 102 | 103 | if score < best_score: 104 | best_score = score 105 | best_box = copy.deepcopy(current_box) 106 | 107 | if score <= threshold: 108 | break 109 | new_boxes.append(best_box) 110 | 111 | return new_boxes 112 | 113 | 114 | def parse_layout_string(bbox_str): 115 | """Parse layout string using regular 
expressions""" 116 | pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)" 117 | matches = re.finditer(pattern, bbox_str) 118 | 119 | parsed_results = [] 120 | for match in matches: 121 | coords = [float(match.group(i)) for i in range(1, 5)] 122 | label = match.group(5).strip() 123 | parsed_results.append((coords, label)) 124 | 125 | return parsed_results 126 | 127 | 128 | @dataclass 129 | class ImageDimensions: 130 | """Class to store image dimensions""" 131 | original_w: int 132 | original_h: int 133 | padded_w: int 134 | padded_h: int 135 | 136 | 137 | def map_to_original_coordinates(x1, y1, x2, y2, dims: ImageDimensions) -> Tuple[int, int, int, int]: 138 | """Map coordinates from padded image back to original image 139 | 140 | Args: 141 | x1, y1, x2, y2: Coordinates in padded image 142 | dims: Image dimensions object 143 | 144 | Returns: 145 | tuple: (x1, y1, x2, y2) coordinates in original image 146 | """ 147 | try: 148 | # Calculate padding offsets 149 | top = (dims.padded_h - dims.original_h) // 2 150 | left = (dims.padded_w - dims.original_w) // 2 151 | 152 | # Map back to original coordinates 153 | orig_x1 = max(0, x1 - left) 154 | orig_y1 = max(0, y1 - top) 155 | orig_x2 = min(dims.original_w, x2 - left) 156 | orig_y2 = min(dims.original_h, y2 - top) 157 | 158 | # Ensure we have a valid box (width and height > 0) 159 | if orig_x2 <= orig_x1: 160 | orig_x2 = min(orig_x1 + 1, dims.original_w) 161 | if orig_y2 <= orig_y1: 162 | orig_y2 = min(orig_y1 + 1, dims.original_h) 163 | 164 | return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2) 165 | except Exception as e: 166 | print(f"map_to_original_coordinates error: {str(e)}") 167 | # Return safe coordinates 168 | return 0, 0, min(100, dims.original_w), min(100, dims.original_h) 169 | 170 | 171 | def map_to_relevant_coordinates(abs_coords, dims: ImageDimensions): 172 | """ 173 | From absolute coordinates to relevant coordinates 174 | e.g. 
[100, 100, 200, 200] -> [0.1, 0.2, 0.3, 0.4] 175 | """ 176 | try: 177 | x1, y1, x2, y2 = abs_coords 178 | return round(x1 / dims.original_w, 3), round(y1 / dims.original_h, 3), round(x2 / dims.original_w, 3), round(y2 / dims.original_h, 3) 179 | except Exception as e: 180 | print(f"map_to_relevant_coordinates error: {str(e)}") 181 | return 0.0, 0.0, 1.0, 1.0 # Return full image coordinates 182 | 183 | 184 | def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None): 185 | """Process and adjust coordinates 186 | 187 | Args: 188 | coords: Normalized coordinates [x1, y1, x2, y2] 189 | padded_image: Padded image 190 | dims: Image dimensions object 191 | previous_box: Previous box coordinates for overlap adjustment 192 | 193 | Returns: 194 | tuple: (x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box) 195 | """ 196 | try: 197 | # Convert normalized coordinates to absolute coordinates 198 | x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h) 199 | x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h) 200 | 201 | # Ensure coordinates are within image bounds before adjustment 202 | x1 = max(0, min(x1, dims.padded_w - 1)) 203 | y1 = max(0, min(y1, dims.padded_h - 1)) 204 | x2 = max(0, min(x2, dims.padded_w)) 205 | y2 = max(0, min(y2, dims.padded_h)) 206 | 207 | # Ensure width and height are at least 1 pixel 208 | if x2 <= x1: 209 | x2 = min(x1 + 1, dims.padded_w) 210 | if y2 <= y1: 211 | y2 = min(y1 + 1, dims.padded_h) 212 | 213 | # Extend box boundaries 214 | new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, y2]]) 215 | x1, y1, x2, y2 = new_boxes[0] 216 | 217 | # Ensure coordinates are still within image bounds after adjustment 218 | x1 = max(0, min(x1, dims.padded_w - 1)) 219 | y1 = max(0, min(y1, dims.padded_h - 1)) 220 | x2 = max(0, min(x2, dims.padded_w)) 221 | y2 = max(0, min(y2, dims.padded_h)) 222 | 223 | # Ensure width and height are at least 1 pixel after adjustment 224 | if x2 <= x1: 225 | x2 = min(x1 + 1, dims.padded_w) 226 | if y2 <= y1: 227 | y2 = min(y1 + 1, dims.padded_h) 228 | 229 | # Check for overlap with previous box and adjust 230 | if previous_box is not None: 231 | prev_x1, prev_y1, prev_x2, prev_y2 = previous_box 232 | if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1): 233 | y1 = prev_y2 234 | # Ensure y1 is still valid 235 | y1 = min(y1, dims.padded_h - 1) 236 | # Make sure y2 is still greater than y1 237 | if y2 <= y1: 238 | y2 = min(y1 + 1, dims.padded_h) 239 | 240 | # Update previous box 241 | new_previous_box = [x1, y1, x2, y2] 242 | 243 | # Map to original coordinates 244 | orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates( 245 | x1, y1, x2, y2, dims 246 | ) 247 | 248 | return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box 249 | except Exception as e: 250 | print(f"process_coordinates error: {str(e)}") 251 | # Return safe values 252 | orig_x1, orig_y1, orig_x2, orig_y2 = 0, 0, min(100, dims.original_w), min(100, dims.original_h) 253 | return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100] 254 | 255 | 256 | def prepare_image(image) -> Tuple[np.ndarray, ImageDimensions]: 257 | """Load and prepare image with padding while maintaining aspect ratio 258 | 259 | Args: 260 | image: PIL image 261 | 262 | Returns: 263 | tuple: (padded_image, image_dimensions) 264 | """ 265 | try: 266 | # Convert PIL image to OpenCV format 267 | image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) 268 | original_h, original_w 
= image.shape[:2] 269 | 270 | # Calculate padding to make square image 271 | max_size = max(original_h, original_w) 272 | top = (max_size - original_h) // 2 273 | bottom = max_size - original_h - top 274 | left = (max_size - original_w) // 2 275 | right = max_size - original_w - left 276 | 277 | # Apply padding 278 | padded_image = cv2.copyMakeBorder(image, top, bottom, left, right, 279 | cv2.BORDER_CONSTANT, value=(0, 0, 0)) 280 | 281 | padded_h, padded_w = padded_image.shape[:2] 282 | 283 | dimensions = ImageDimensions( 284 | original_w=original_w, 285 | original_h=original_h, 286 | padded_w=padded_w, 287 | padded_h=padded_h 288 | ) 289 | 290 | return padded_image, dimensions 291 | except Exception as e: 292 | print(f"prepare_image error: {str(e)}") 293 | # Create a minimal valid image and dimensions 294 | h, w = image.height, image.width 295 | dimensions = ImageDimensions( 296 | original_w=w, 297 | original_h=h, 298 | padded_w=w, 299 | padded_h=h 300 | ) 301 | # Return a black image of the same size 302 | return np.zeros((h, w, 3), dtype=np.uint8), dimensions 303 | 304 | 305 | 306 | 307 | def setup_output_dirs(save_dir): 308 | """Create necessary output directories""" 309 | os.makedirs(save_dir, exist_ok=True) 310 | os.makedirs(os.path.join(save_dir, "markdown"), exist_ok=True) 311 | os.makedirs(os.path.join(save_dir, "recognition_json"), exist_ok=True) 312 | 313 | 314 | def save_outputs(recognition_results, image_path, save_dir): 315 | """Save JSON and markdown outputs""" 316 | basename = os.path.splitext(os.path.basename(image_path))[0] 317 | 318 | # Save JSON file 319 | json_path = os.path.join(save_dir, "recognition_json", f"{basename}.json") 320 | with open(json_path, "w", encoding="utf-8") as f: 321 | json.dump(recognition_results, f, ensure_ascii=False, indent=2) 322 | 323 | # Generate and save markdown file 324 | markdown_converter = MarkdownConverter() 325 | markdown_content = markdown_converter.convert(recognition_results) 326 | markdown_path = os.path.join(save_dir, "markdown", f"{basename}.md") 327 | with open(markdown_path, "w", encoding="utf-8") as f: 328 | f.write(markdown_content) 329 | 330 | return json_path 331 | 332 | 333 | def crop_margin(img: Image.Image) -> Image.Image: 334 | """Crop margins from image""" 335 | try: 336 | width, height = img.size 337 | if width == 0 or height == 0: 338 | print("Warning: Image has zero width or height") 339 | return img 340 | 341 | data = np.array(img.convert("L")) 342 | data = data.astype(np.uint8) 343 | max_val = data.max() 344 | min_val = data.min() 345 | if max_val == min_val: 346 | return img 347 | data = (data - min_val) / (max_val - min_val) * 255 348 | gray = 255 * (data < 200).astype(np.uint8) 349 | 350 | coords = cv2.findNonZero(gray) # Find all non-zero points (text) 351 | if coords is None: 352 | return img 353 | a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box 354 | 355 | # Ensure crop coordinates are within image bounds 356 | a = max(0, a) 357 | b = max(0, b) 358 | w = min(w, width - a) 359 | h = min(h, height - b) 360 | 361 | # Only crop if we have a valid region 362 | if w > 0 and h > 0: 363 | return img.crop((a, b, a + w, b + h)) 364 | return img 365 | except Exception as e: 366 | print(f"crop_margin error: {str(e)}") 367 | return img # Return original image on error 368 | --------------------------------------------------------------------------------
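Taken together, the helpers above implement the two-stage page-parsing flow driven by chat.py and demo_page.py: the page image is padded to a square, the model predicts the layout as a sequence of "[x1, y1, x2, y2] label" spans (the format parse_layout_string expects), each predicted box is widened and mapped back to original-image coordinates, and the per-element recognition results are converted to markdown. The sketch below illustrates how these pieces fit together; the function name, the layout prompt string, the placeholder element text, and the model/tokenizer arguments are assumptions for illustration only, since model loading and element-level recognition live in chat.py and the demo scripts.

    from PIL import Image

    from utils.processor import DolphinProcessor
    from utils.utils import (
        parse_layout_string,
        prepare_image,
        process_coordinates,
        save_outputs,
        setup_output_dirs,
    )


    def parse_page_sketch(model, tokenizer, image_path, save_dir,
                          layout_prompt="Parse the reading order of this document."):
        """Hypothetical end-to-end flow: model is a loaded DonutModel, tokenizer its tokenizer."""
        processor = DolphinProcessor(dp_config={}, tokenizer=tokenizer)

        page = Image.open(image_path).convert("RGB")
        padded_image, dims = prepare_image(page)

        # Stage 1: page-level layout prediction.
        prompt_ids = processor.process_prompt_for_inference(layout_prompt)
        pixel_values = processor.process_image_for_inference(page)
        layout_sequence = model.inference(prompt_ids=prompt_ids, image_tensors=pixel_values)["repetitions"][0]

        # Stage 2: map each predicted box back onto the original page.
        recognition_results = []
        previous_box = None
        for reading_order, (coords, label) in enumerate(parse_layout_string(layout_sequence)):
            x1, y1, x2, y2, ox1, oy1, ox2, oy2, previous_box = process_coordinates(
                coords, padded_image, dims, previous_box
            )
            # Element-level recognition of the crop padded_image[y1:y2, x1:x2] would go here;
            # a placeholder string keeps the sketch self-contained.
            recognition_results.append(
                {"label": label, "bbox": [ox1, oy1, ox2, oy2], "text": "...", "reading_order": reading_order}
            )

        # Stage 3: write the recognition JSON plus the markdown produced by MarkdownConverter.
        setup_output_dirs(save_dir)
        return save_outputs(recognition_results, image_path, save_dir)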