├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── assets
│   ├── demo.gif
│   ├── dolphin.png
│   └── framework.png
├── chat.py
├── config
│   └── Dolphin.yaml
├── demo
│   ├── element_imgs
│   │   ├── block_formula.jpeg
│   │   ├── line_formula.jpeg
│   │   ├── para_1.jpg
│   │   ├── para_2.jpg
│   │   ├── para_3.jpeg
│   │   ├── table_1.jpeg
│   │   └── table_2.jpeg
│   └── page_imgs
│       ├── page_1.jpeg
│       ├── page_2.jpeg
│       ├── page_3.jpeg
│       ├── page_4.png
│       └── page_5.jpg
├── demo_element.py
├── demo_element_hf.py
├── demo_page.py
├── demo_page_hf.py
├── pyproject.toml
├── requirements.txt
└── utils
    ├── markdown_utils.py
    ├── model.py
    ├── processor.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .nox/
42 | .coverage
43 | *.cover
44 | *.py,cover
45 | .hypothesis/
46 | .pytest_cache/
47 | coverage.xml
48 | *.mo
49 | *.pot
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 | db.sqlite3-journal
60 |
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 |
65 | # Scrapy stuff:
66 | .scrapy
67 |
68 | # Sphinx documentation
69 | docs/_build/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 |
77 | # IPython
78 | profile_default/
79 | ipython_config.py
80 |
81 | # pyenv
82 | .python-version
83 |
84 | # pipenv
85 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
86 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
87 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
88 | # install all needed dependencies.
89 | #Pipfile.lock
90 |
91 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
92 | __pypackages__/
93 |
94 | # Celery stuff
95 | celerybeat-schedule
96 | celerybeat.pid
97 |
98 | # SageMath parsed files
99 | *.sage.py
100 |
101 | # Environments
102 | .env
103 | .venv
104 | env/
105 | venv/
106 | ENV/
107 | env.bak/
108 | venv.bak/
109 |
110 | # Spyder project settings
111 | .spyderproject
112 | .spyproject
113 |
114 | # Rope project settings
115 | .ropeproject
116 |
117 | # mkdocs documentation
118 | /site
119 |
120 | # mypy
121 | .mypy_cache/
122 | .dmypy.json
123 | dmypy.json
124 |
125 | # Pyre type checker
126 | .pyre/
127 |
128 | # pytype static type analyzer
129 | .pytype/
130 |
131 | # Cython debug symbols
132 | cython_debug/
133 |
134 | # PyCharm
135 | .idea/
136 | *.iml
137 |
138 | # VS Code
139 | .vscode/
140 | !.vscode/settings.json
141 | !.vscode/tasks.json
142 | !.vscode/launch.json
143 | !.vscode/extensions.json
144 |
145 | # macOS
146 | .DS_Store
147 |
148 | # Windows
149 | Thumbs.db
150 | ehthumbs.db
151 | Desktop.ini
152 |
153 | fusion_result.json
154 | kernel_meta/
155 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | # 1. isort - automatically sorts Python imports
3 | - repo: https://github.com/pycqa/isort
4 | rev: 6.0.1 # pinned version
5 | hooks:
6 | - id: isort
7 | name: isort (python)
8 | args: [--profile=black] # configuration compatible with Black
9 | language: python
10 |
11 | # 2. Black - automatically formats Python code
12 | - repo: https://github.com/psf/black
13 | rev: 25.1.0 # pinned version
14 | hooks:
15 | - id: black
16 | language: python
17 |
18 | # 3. flake8 - Python static analysis
19 | - repo: https://github.com/pycqa/flake8
20 | rev: 7.2.0
21 | hooks:
22 | - id: flake8
23 | args: [--max-line-length=120, --ignore=E203] # set max line length to 120
24 | additional_dependencies: [flake8-bugbear==24.12.12] # optional: extra checks
25 |
26 | # 4. pre-commit-hooks - general-purpose Git hooks
27 | - repo: https://github.com/pre-commit/pre-commit-hooks
28 | rev: v5.0.0
29 | hooks:
30 | - id: trailing-whitespace # strip trailing whitespace
31 | - id: end-of-file-fixer # ensure files end with a newline
32 | - id: check-yaml # validate YAML syntax
33 | - id: check-added-large-files # block large file commits
34 | args: ["--maxkb=512"]
35 |
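# Usage: install pre-commit locally with `pip install pre-commit`, enable the hooks
# once via `pre-commit install` (they then run on every commit), or check the whole
# repository with `pre-commit run --all-files`.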
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright 2025 ByteDance Ltd. and/or its affiliates
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
30 | # Dolphin: Document Image Parsing via Heterogeneous Anchor Prompting
31 |
32 | Dolphin (**Do**cument Image **P**arsing via **H**eterogeneous Anchor Prompt**in**g) is a novel multimodal document image parsing model following an analyze-then-parse paradigm. This repository contains the demo code and pre-trained models for Dolphin.
33 |
34 | ## 📑 Overview
35 |
36 | Document image parsing is challenging due to its complexly intertwined elements such as text paragraphs, figures, formulas, and tables. Dolphin addresses these challenges through a two-stage approach:
37 |
38 | 1. **🔍 Stage 1**: Comprehensive page-level layout analysis by generating element sequence in natural reading order
39 | 2. **🧩 Stage 2**: Efficient parallel parsing of document elements using heterogeneous anchors and task-specific prompts
40 |
41 |
42 | ![Dolphin framework: page-level layout analysis (Stage 1) followed by parallel element parsing (Stage 2)](./assets/framework.png)
43 |
44 |
45 | Dolphin achieves promising performance across diverse page-level and element-level parsing tasks while ensuring superior efficiency through its lightweight architecture and parallel parsing mechanism.
46 |
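The demo scripts in this repository implement exactly this flow. A minimal sketch of the Hugging Face variant (paths are illustrative; see Installation and Inference below, and `demo_page_hf.py` for the full pipeline):

```python
from PIL import Image

from demo_page_hf import DOLPHIN  # HF-based wrapper defined in this repo

model = DOLPHIN("./hf_model")  # local copy of the downloaded Hugging Face weights
page = Image.open("./demo/page_imgs/page_1.jpeg").convert("RGB")

# Stage 1: a single generation pass returns the page elements in natural reading order
layout = model.chat("Parse the reading order of this document.", page)

# Stage 2: each detected region is cropped and parsed with a task-specific prompt,
# batched for parallel decoding (see process_elements() in demo_page_hf.py):
#   "Parse the table in the image."  for tables
#   "Read text in the image."        for text paragraphs and formulas
```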
47 | ## 🚀 Demo
48 |
49 | Try our live demo at [Demo-Dolphin](http://115.190.42.15:8888/dolphin/).
50 |
51 |
52 | ## 📅 Changelog
53 | - 🔥 **2025.05.21** Our demo is released at [link](http://115.190.42.15:8888/dolphin/). Check it out!
54 | - 🔥 **2025.05.20** The pretrained model and inference code of Dolphin are released.
55 | - 🔥 **2025.05.16** Our paper has been accepted by ACL 2025. Paper link: [arXiv](https://arxiv.org/abs/2505.14059).
56 |
57 | ## 🛠️ Installation
58 |
59 | 1. Clone the repository:
60 | ```bash
61 | git clone https://github.com/ByteDance/Dolphin.git
62 | cd Dolphin
63 | ```
64 |
65 | 2. Install the dependencies:
66 | ```bash
67 | pip install -r requirements.txt
68 | ```
69 |
70 | 3. Download the pre-trained models using one of the following options:
71 |
72 | **Option A: Original Model Format (config-based)**
73 |
74 | Download from [Baidu Yun](https://pan.baidu.com/s/15zcARoX0CTOHKbW8bFZovQ?pwd=9rpx) or [Google Drive](https://drive.google.com/drive/folders/1PQJ3UutepXvunizZEw-uGaQ0BCzf-mie?usp=sharing) and put them in the `./checkpoints` folder.
75 |
76 | **Option B: Hugging Face Model Format**
77 |
78 | Visit our Huggingface [model card](https://huggingface.co/ByteDance/Dolphin), or download model by:
79 |
80 | ```bash
81 | # Download the model from Hugging Face Hub
82 | git lfs install
83 | git clone https://huggingface.co/ByteDance/Dolphin ./hf_model
84 | # Or use the Hugging Face CLI
85 | huggingface-cli download ByteDance/Dolphin --local-dir ./hf_model
86 | ```
87 |
88 | ## ⚡ Inference
89 |
90 | Dolphin provides two inference frameworks with support for two parsing granularities:
91 | - **Page-level Parsing**: Parse the entire document image into a structured JSON and Markdown format
92 | - **Element-level Parsing**: Parse individual document elements (text, table, formula)
93 |
94 | ### 📄 Page-level Parsing
95 |
96 | #### Using Original Framework (config-based)
97 |
98 | ```bash
99 | # Process a single document image
100 | python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs/page_1.jpeg --save_dir ./results
101 |
102 | # Process all document images in a directory
103 | python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs --save_dir ./results
104 |
105 | # Process with custom batch size for parallel element decoding
106 | python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs --save_dir ./results --max_batch_size 8
107 | ```
108 |
109 | #### Using Hugging Face Framework
110 |
111 | ```bash
112 | # Process a single document image
113 | python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs/page_1.jpeg --save_dir ./results
114 |
115 | # Process all document images in a directory
116 | python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs --save_dir ./results
117 |
118 | # Process with custom batch size for parallel element decoding
119 | python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs --save_dir ./results --max_batch_size 16
120 | ```
121 |
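Each input page produces a JSON file (and a Markdown rendering) under `--save_dir`. Every entry follows the structure assembled in `process_elements` with `label`, `bbox`, `text`, and `reading_order` fields; the values below are illustrative, not actual model output:

```json
[
  {"label": "sec",  "bbox": [105, 88, 960, 130],   "text": "1. Introduction", "reading_order": 0},
  {"label": "para", "bbox": [105, 150, 960, 420],  "text": "Document image parsing is challenging ...", "reading_order": 1},
  {"label": "tab",  "bbox": [120, 450, 940, 780],  "text": "<table>...</table>", "reading_order": 2},
  {"label": "fig",  "bbox": [130, 800, 930, 1100], "text": "", "reading_order": 3}
]
```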
122 | ### 🧩 Element-level Parsing
123 |
124 | #### Using Original Framework (config-based)
125 |
126 | ```bash
127 | # Process a single table image
128 | python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/table_1.jpeg --element_type table
129 |
130 | # Process a single formula image
131 | python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/line_formula.jpeg --element_type formula
132 |
133 | # Process a single text paragraph image
134 | python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/para_1.jpg --element_type text
135 | ```
136 |
137 | #### Using Hugging Face Framework
138 |
139 | ```bash
140 | # Process a single table image
141 | python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/table_1.jpeg --element_type table
142 |
143 | # Process a single formula image
144 | python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/line_formula.jpeg --element_type formula
145 |
146 | # Process a single text paragraph image
147 | python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/para_1.jpg --element_type text
148 | ```
149 |
150 | ## 🌟 Key Features
151 |
152 | - 🔄 Two-stage analyze-then-parse approach based on a single VLM
153 | - 📊 Promising performance on document parsing tasks
154 | - 🔍 Natural reading order element sequence generation
155 | - 🧩 Heterogeneous anchor prompting for different document elements
156 | - ⏱️ Efficient parallel parsing mechanism
157 | - 🤗 Support for Hugging Face Transformers for easier integration
158 |
159 |
160 | ## 📮 Notice
161 | **Call for Bad Cases:** If you have encountered any cases where the model performs poorly, we would greatly appreciate it if you could share them in an issue. We are continuously working to optimize and improve the model.
162 |
163 | ## 💖 Acknowledgement
164 |
165 | We would like to acknowledge the following open-source projects that provided inspiration and reference for this work:
166 | - [Donut](https://github.com/clovaai/donut/)
167 | - [Nougat](https://github.com/facebookresearch/nougat)
168 | - [GOT](https://github.com/Ucas-HaoranWei/GOT-OCR2.0)
169 | - [MinerU](https://github.com/opendatalab/MinerU/tree/master)
170 | - [Swin](https://github.com/microsoft/Swin-Transformer)
171 | - [Hugging Face Transformers](https://github.com/huggingface/transformers)
172 |
173 | ## 📝 Citation
174 |
175 | If you find this code useful for your research, please use the following BibTeX entry.
176 |
177 | ```bibtex
178 | @inproceedings{dolphin2025,
179 | title={Dolphin: Document Image Parsing via Heterogeneous Anchor Prompting},
180 | author={Feng, Hao and Wei, Shu and Fei, Xiang and Shi, Wei and Han, Yingdong and Liao, Lei and Lu, Jinghui and Wu, Binghong and Liu, Qi and Lin, Chunhui and Tang, Jingqun and Liu, Hao and Huang, Can},
181 | year={2025},
182 | booktitle={Proceedings of the 63rd Annual Meeting of the Association for Computational Linguistics (ACL)}
183 | }
184 | ```
185 |
186 | ## Star History
187 |
188 | [Star History Chart](https://www.star-history.com/#bytedance/Dolphin&Date)
189 |
--------------------------------------------------------------------------------
/assets/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/assets/demo.gif
--------------------------------------------------------------------------------
/assets/dolphin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/assets/dolphin.png
--------------------------------------------------------------------------------
/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/assets/framework.png
--------------------------------------------------------------------------------
/chat.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
3 | SPDX-License-Identifier: MIT
4 | """
5 |
6 | import os
7 | import warnings
8 | from collections import OrderedDict
9 |
10 | from omegaconf import ListConfig
11 |
12 | warnings.filterwarnings("ignore", category=UserWarning)
13 | warnings.filterwarnings("ignore", category=FutureWarning)
14 | os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python")
15 |
16 | import torch
17 | from PIL import Image
18 | from transformers import PreTrainedTokenizerFast
19 |
20 | from utils.model import DonutConfig, DonutModel, SwinEncoder
21 | from utils.processor import DolphinProcessor
22 |
23 |
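# Maps legacy checkpoint keys to the current module names ("encoder.*" -> "vpm.*",
# "decoder.*" -> "llm.*") and optionally re-saves the converted state dict to output_path.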
24 | def try_rename_lagacy_weights(ckpt, output_path=""):
25 | if "state_dict" in ckpt.keys():
26 | ckpt = ckpt["state_dict"]
27 | if "module" in ckpt.keys():
28 | ckpt = ckpt["module"]
29 | new_ckpt = OrderedDict()
30 | for k, v in ckpt.items():
31 | if k.startswith("model."):
32 | k = k[len("model.") :]
33 | if k.startswith("encoder"):
34 | new_ckpt["vpm" + k[len("encoder") :]] = v
35 | elif k.startswith("decoder"):
36 | new_ckpt["llm" + k[len("encoder") :]] = v
37 | else:
38 | new_ckpt[k] = v
39 | if output_path:
40 | torch.save(new_ckpt, output_path)
41 | return new_ckpt
42 |
43 |
44 | def convert_listconfig_to_list(config):
45 | new_config = {}
46 | for k, v in config.items():
47 | if isinstance(v, ListConfig):
48 | new_config[k] = list(v)
49 | else:
50 | new_config[k] = v
51 | return new_config
52 |
53 |
54 | class DOLPHIN:
55 | def __init__(self, config, ckpt_path="") -> None:
56 | self.model_args = config.model
57 | self.swin_args = config.model.pop("swin_args")
58 | self.swin_args = convert_listconfig_to_list(self.swin_args)
59 |
60 | vision_tower = SwinEncoder(
61 | input_size=self.swin_args["img_size"],
62 | patch_size=self.swin_args["patch_size"],
63 | embed_dim=self.swin_args["embed_dim"],
64 | window_size=self.swin_args["window_size"],
65 | encoder_layer=self.swin_args["encoder_layer"],
66 | num_heads=self.swin_args["num_heads"],
67 | align_long_axis=self.swin_args["align_long_axis"],
68 | )
69 |
70 | self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=self.model_args.tokenizer_path)
71 | self.tokenizer.pad_token = "<pad>"
72 | self.tokenizer.bos_token = "<s>"
73 | self.tokenizer.eos_token = "</s>"
74 | self.tokenizer.unk_token = "<unk>"
75 |
76 | if self.model_args.get("extra_answer_tokens", False):
77 | # print("Allowing multitask training: adding <Answer/> to the tokenizer.")
78 | prompt_end_token = " <Answer/>"
79 | self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set([prompt_end_token]))})
80 | self.tokenizer._prompt_end_token = prompt_end_token
81 | self.tokenizer._prompt_end_token_id = self.tokenizer.convert_tokens_to_ids(prompt_end_token)
82 |
83 | donut_config = DonutConfig(
84 | decoder_layer=self.model_args.decoder_layer,
85 | max_length=self.model_args.max_length,
86 | max_position_embeddings=self.model_args.max_position_embeddings,
87 | hidden_dimension=self.model_args.hidden_dimension,
88 | )
89 |
90 | self.model = DonutModel(config=donut_config, vision_tower=vision_tower, tokenizer=self.tokenizer)
91 | if self.model_args.model_name_or_path:
92 | ckpt = torch.load(self.model_args.model_name_or_path)
93 | ckpt = try_rename_lagacy_weights(ckpt)
94 | self.model.load_state_dict(ckpt, strict=True)
95 |
96 | self.model.to("cuda")
97 | self.model.eval()
98 | transform_args = {
99 | "input_size": self.swin_args["img_size"],
100 | "max_length": self.model_args.max_length,
101 | }
102 | self.processor = DolphinProcessor({}, self.tokenizer, transform_args=transform_args)
103 |
104 | def chat(
105 | self,
106 | question,
107 | image,
108 | return_raw=False,
109 | return_score=False,
110 | return_img_size=False,
111 | only_return_img_size=False,
112 | max_batch_size=16,
113 | ):
114 |
115 | def _preprocess_image(image):
116 | if isinstance(image, str):
117 | image = Image.open(image).convert("RGB")
118 | if return_img_size or only_return_img_size:
119 | image_tensor, ori_size = self.processor.process_image_for_inference(image, return_img_size=True)
120 | else:
121 | image_tensor = self.processor.process_image_for_inference(image, return_img_size=False)
122 | ori_size = None
123 | return image_tensor, ori_size
124 |
125 | def _preprocess_prompt(question):
126 | if self.model_args.get("extra_answer_tokens", False):
127 | if self.tokenizer._prompt_end_token not in question:
128 | question = question + self.tokenizer._prompt_end_token
129 | prompt_ids = self.processor.process_prompt_for_inference(question)
130 | return prompt_ids
131 |
132 | def _preprocess_prompt_batch(question):
133 | if self.model_args.get("extra_answer_tokens", False):
134 | for i in range(len(question)):
135 | if self.tokenizer._prompt_end_token not in question[i]:
136 | question[i] = question[i] + self.tokenizer._prompt_end_token
137 | if not question[i].startswith("<s>"):
138 | question[i] = "<s>" + question[i]
139 | return question
140 |
141 | def _postprocess(output, question):
142 | output = output.replace("<s>", "").replace(question, "").replace("</s>", "").replace("<pad>", "")
143 | if self.model_args.get("extra_answer_tokens", False):
144 | output = output.split(self.tokenizer._prompt_end_token)[-1]
145 | return output
146 |
147 | if isinstance(question, list):
148 | image_tensor_list = []
149 | for i in image:
150 | image_tensor, ori_size = _preprocess_image(i)
151 | image_tensor_list.append(image_tensor)
152 | image_tensor = torch.cat(image_tensor_list, dim=0)
153 |
154 | question = _preprocess_prompt_batch(question)
155 | self.processor.tokenizer.padding_side = "left"
156 | prompt_ids = self.processor.tokenizer(
157 | question, add_special_tokens=False, return_tensors="pt", padding=True
158 | ).input_ids
159 | else:
160 | image_tensor, ori_size = _preprocess_image(image)
161 | prompt_ids = _preprocess_prompt(question)
162 |
163 | if only_return_img_size:
164 | return ori_size
165 |
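# Run inference in chunks of max_batch_size, then merge the per-chunk outputs
# back into flat lists keyed the same way as a single model.inference() call.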
166 | model_output_batch = []
167 | for i in range(0, image_tensor.shape[0], max_batch_size):
168 | image_tensor_batch = image_tensor[i : i + max_batch_size]
169 | prompt_ids_batch = prompt_ids[i : i + max_batch_size]
170 | model_output = self.model.inference(image_tensors=image_tensor_batch, prompt_ids=prompt_ids_batch)
171 | model_output_batch.append(model_output)
172 | model_output = {}
173 | for k, v in model_output_batch[0].items():
174 | if isinstance(v, torch.Tensor):
175 | model_output[k] = sum(
176 | [v_batch[k].cpu().numpy().tolist() for v_batch in model_output_batch],
177 | [],
178 | )
179 | else:
180 | model_output[k] = sum([v_batch[k] for v_batch in model_output_batch], [])
181 |
182 | if return_raw:
183 | if return_img_size:
184 | return model_output, ori_size
185 | return model_output
186 | else:
187 | if isinstance(question, list):
188 | output = [_postprocess(model_output["repetitions"][i], question[i]) for i in range(len(question))]
189 | score = model_output["scores"]
190 | else:
191 | output = _postprocess(model_output["repetitions"][0], question)
192 | score = model_output["scores"][0]
193 | if return_score:
194 | return output, score
195 | if return_img_size:
196 | return output, ori_size
197 | return output
198 |
--------------------------------------------------------------------------------
/config/Dolphin.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | model_name_or_path: "./checkpoints/dolphin_model.bin"
3 | tokenizer_path: "./checkpoints/dolphin_tokenizer.json"
4 | extra_answer_tokens: True # add <Answer/> token
5 | max_length: 4096
6 | decoder_layer: 10
7 | max_position_embeddings: 4096
8 | hidden_dimension: 1024
9 | swin_args:
10 | name: 'swin'
11 | img_size: [896, 896]
12 | patch_size: 4
13 | embed_dim: 128
14 | align_long_axis: False
15 | window_size: 7
16 | encoder_layer: [2, 2, 14, 2]
17 | num_heads: [4, 8, 16, 32]
18 |
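# This file drives the original (config-based) demos, e.g.
#   config = OmegaConf.load("./config/Dolphin.yaml"); model = DOLPHIN(config)
# as done in demo_page.py / demo_element.py; the Hugging Face demos do not read it.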
--------------------------------------------------------------------------------
/demo/element_imgs/block_formula.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/element_imgs/block_formula.jpeg
--------------------------------------------------------------------------------
/demo/element_imgs/line_formula.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/element_imgs/line_formula.jpeg
--------------------------------------------------------------------------------
/demo/element_imgs/para_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/element_imgs/para_1.jpg
--------------------------------------------------------------------------------
/demo/element_imgs/para_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/element_imgs/para_2.jpg
--------------------------------------------------------------------------------
/demo/element_imgs/para_3.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/element_imgs/para_3.jpeg
--------------------------------------------------------------------------------
/demo/element_imgs/table_1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/element_imgs/table_1.jpeg
--------------------------------------------------------------------------------
/demo/element_imgs/table_2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/element_imgs/table_2.jpeg
--------------------------------------------------------------------------------
/demo/page_imgs/page_1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/page_imgs/page_1.jpeg
--------------------------------------------------------------------------------
/demo/page_imgs/page_2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/page_imgs/page_2.jpeg
--------------------------------------------------------------------------------
/demo/page_imgs/page_3.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/page_imgs/page_3.jpeg
--------------------------------------------------------------------------------
/demo/page_imgs/page_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/page_imgs/page_4.png
--------------------------------------------------------------------------------
/demo/page_imgs/page_5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/49f51871c6d2c56b2ef08112c386fd2222fbdd99/demo/page_imgs/page_5.jpg
--------------------------------------------------------------------------------
/demo_element.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
3 | SPDX-License-Identifier: MIT
4 | """
5 |
6 | import argparse
7 | import glob
8 | import os
9 |
10 | from omegaconf import OmegaConf
11 | from PIL import Image
12 |
13 | from chat import DOLPHIN
14 | from utils.utils import *
15 |
16 |
17 | def process_element(image_path, model, element_type, save_dir=None):
18 | """Process a single element image (text, table, formula)
19 |
20 | Args:
21 | image_path: Path to the element image
22 | model: DOLPHIN model instance
23 | element_type: Type of element ('text', 'table', 'formula')
24 | save_dir: Directory to save results (default: same as input directory)
25 |
26 | Returns:
27 | Parsed content of the element and recognition results
28 | """
29 | # Load and prepare image
30 | pil_image = Image.open(image_path).convert("RGB")
31 | pil_image = crop_margin(pil_image)
32 |
33 | # Select appropriate prompt based on element type
34 | if element_type == "table":
35 | prompt = "Parse the table in the image."
36 | label = "tab"
37 | elif element_type == "formula":
38 | prompt = "Read text in the image."
39 | label = "formula"
40 | else: # Default to text
41 | prompt = "Read text in the image."
42 | label = "text"
43 |
44 | # Process the element
45 | result = model.chat(prompt, pil_image)
46 |
47 | # Create recognition result in the same format as the document parser
48 | recognition_result = [
49 | {
50 | "label": label,
51 | "text": result.strip(),
52 | }
53 | ]
54 |
55 | # Save results if save_dir is provided
56 | if save_dir:
57 | save_outputs(recognition_result, image_path, save_dir)
58 | print(f"Results saved to {save_dir}")
59 |
60 | return result, recognition_result
61 |
62 |
63 | def main():
64 | parser = argparse.ArgumentParser(description="Element-level processing using DOLPHIN model")
65 | parser.add_argument("--config", default="./config/Dolphin.yaml", help="Path to configuration file")
66 | parser.add_argument("--input_path", type=str, required=True, help="Path to input image or directory of images")
67 | parser.add_argument(
68 | "--element_type",
69 | type=str,
70 | choices=["text", "table", "formula"],
71 | default="text",
72 | help="Type of element to process (text, table, formula)",
73 | )
74 | parser.add_argument(
75 | "--save_dir",
76 | type=str,
77 | default=None,
78 | help="Directory to save parsing results (default: same as input directory)",
79 | )
80 | parser.add_argument("--print_results", action="store_true", help="Print recognition results to console")
81 | args = parser.parse_args()
82 |
83 | # Load Model
84 | config = OmegaConf.load(args.config)
85 | model = DOLPHIN(config)
86 |
87 | # Set save directory
88 | save_dir = args.save_dir or (
89 | args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
90 | )
91 | setup_output_dirs(save_dir)
92 |
93 | # Collect Images
94 | if os.path.isdir(args.input_path):
95 | image_files = []
96 | for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]:
97 | image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
98 | image_files = sorted(image_files)
99 | else:
100 | if not os.path.exists(args.input_path):
101 | raise FileNotFoundError(f"Input path {args.input_path} does not exist")
102 | image_files = [args.input_path]
103 |
104 | total_samples = len(image_files)
105 | print(f"\nTotal samples to process: {total_samples}")
106 |
107 | # Process images one by one
108 | for image_path in image_files:
109 | print(f"\nProcessing {image_path}")
110 | try:
111 | result, recognition_result = process_element(
112 | image_path=image_path,
113 | model=model,
114 | element_type=args.element_type,
115 | save_dir=save_dir,
116 | )
117 |
118 | if args.print_results:
119 | print("\nRecognition result:")
120 | print(result)
121 | print("-" * 40)
122 |
123 | except Exception as e:
124 | print(f"Error processing {image_path}: {str(e)}")
125 | continue
126 |
127 |
128 | if __name__ == "__main__":
129 | main()
130 |
--------------------------------------------------------------------------------
/demo_element_hf.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
3 | SPDX-License-Identifier: MIT
4 | """
5 |
6 | import argparse
7 | import glob
8 | import os
9 |
10 | import torch
11 | from PIL import Image
12 | from transformers import AutoProcessor, VisionEncoderDecoderModel
13 |
14 | from utils.utils import *
15 |
16 |
17 | class DOLPHIN:
18 | def __init__(self, model_id_or_path):
19 | """Initialize the Hugging Face model
20 |
21 | Args:
22 | model_id_or_path: Path to local model or Hugging Face model ID
23 | """
24 | # Load model from local path or Hugging Face hub
25 | self.processor = AutoProcessor.from_pretrained(model_id_or_path)
26 | self.model = VisionEncoderDecoderModel.from_pretrained(model_id_or_path)
27 | self.model.eval()
28 |
29 | # Set device and precision
30 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
31 | self.model.to(self.device)
32 | self.model = self.model.half() # Always use half precision by default
33 |
34 | # set tokenizer
35 | self.tokenizer = self.processor.tokenizer
36 |
37 | def chat(self, prompt, image):
38 | """Process an image with the given prompt
39 |
40 | Args:
41 | prompt: Text prompt to guide the model
42 | image: PIL Image to process
43 |
44 | Returns:
45 | Generated text from the model
46 | """
47 | # Prepare image
48 | pixel_values = self.processor(image, return_tensors="pt").pixel_values
49 | pixel_values = pixel_values.half()
50 |
51 | # Prepare prompt
52 | prompt = f"<s>{prompt} <Answer/>"
53 | prompt_ids = self.tokenizer(
54 | prompt,
55 | add_special_tokens=False,
56 | return_tensors="pt"
57 | ).input_ids.to(self.device)
58 |
59 | decoder_attention_mask = torch.ones_like(prompt_ids)
60 |
61 | # Generate text
62 | outputs = self.model.generate(
63 | pixel_values=pixel_values.to(self.device),
64 | decoder_input_ids=prompt_ids,
65 | decoder_attention_mask=decoder_attention_mask,
66 | min_length=1,
67 | max_length=4096,
68 | pad_token_id=self.tokenizer.pad_token_id,
69 | eos_token_id=self.tokenizer.eos_token_id,
70 | use_cache=True,
71 | bad_words_ids=[[self.tokenizer.unk_token_id]],
72 | return_dict_in_generate=True,
73 | do_sample=False,
74 | num_beams=1,
75 | )
76 |
77 | # Process the output
78 | sequence = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0]
79 | sequence = sequence.replace(prompt, "").replace("<pad>", "").replace("</s>", "").strip()
80 |
81 | return sequence
82 |
83 | def process_element(image_path, model, element_type, save_dir=None):
84 | """Process a single element image (text, table, formula)
85 |
86 | Args:
87 | image_path: Path to the element image
88 | model: DOLPHIN model instance
89 | element_type: Type of element ('text', 'table', 'formula')
90 | save_dir: Directory to save results (default: same as input directory)
91 |
92 | Returns:
93 | Parsed content of the element and recognition results
94 | """
95 | # Load and prepare image
96 | pil_image = Image.open(image_path).convert("RGB")
97 | pil_image = crop_margin(pil_image)
98 |
99 | # Select appropriate prompt based on element type
100 | if element_type == "table":
101 | prompt = "Parse the table in the image."
102 | label = "tab"
103 | elif element_type == "formula":
104 | prompt = "Read text in the image."
105 | label = "formula"
106 | else: # Default to text
107 | prompt = "Read text in the image."
108 | label = "text"
109 |
110 | # Process the element
111 | result = model.chat(prompt, pil_image)
112 |
113 | # Create recognition result in the same format as the document parser
114 | recognition_result = [
115 | {
116 | "label": label,
117 | "text": result.strip(),
118 | }
119 | ]
120 |
121 | # Save results if save_dir is provided
122 | if save_dir:
123 | save_outputs(recognition_result, image_path, save_dir)
124 | print(f"Results saved to {save_dir}")
125 |
126 | return result, recognition_result
127 |
128 |
129 | def main():
130 | parser = argparse.ArgumentParser(description="Element-level processing using DOLPHIN model")
131 | parser.add_argument("--model_path", default="./hf_model", help="Path to Hugging Face model")
132 | parser.add_argument("--input_path", type=str, required=True, help="Path to input image or directory of images")
133 | parser.add_argument(
134 | "--element_type",
135 | type=str,
136 | choices=["text", "table", "formula"],
137 | default="text",
138 | help="Type of element to process (text, table, formula)",
139 | )
140 | parser.add_argument(
141 | "--save_dir",
142 | type=str,
143 | default=None,
144 | help="Directory to save parsing results (default: same as input directory)",
145 | )
146 | parser.add_argument("--print_results", action="store_true", help="Print recognition results to console")
147 | args = parser.parse_args()
148 |
149 | # Load Model
150 | model = DOLPHIN(args.model_path)
151 |
152 | # Set save directory
153 | save_dir = args.save_dir or (
154 | args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
155 | )
156 | setup_output_dirs(save_dir)
157 |
158 | # Collect Images
159 | if os.path.isdir(args.input_path):
160 | image_files = []
161 | for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]:
162 | image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
163 | image_files = sorted(image_files)
164 | else:
165 | if not os.path.exists(args.input_path):
166 | raise FileNotFoundError(f"Input path {args.input_path} does not exist")
167 | image_files = [args.input_path]
168 |
169 | total_samples = len(image_files)
170 | print(f"\nTotal samples to process: {total_samples}")
171 |
172 | # Process images one by one
173 | for image_path in image_files:
174 | print(f"\nProcessing {image_path}")
175 | try:
176 | result, recognition_result = process_element(
177 | image_path=image_path,
178 | model=model,
179 | element_type=args.element_type,
180 | save_dir=save_dir,
181 | )
182 |
183 | if args.print_results:
184 | print("\nRecognition result:")
185 | print(result)
186 | print("-" * 40)
187 | except Exception as e:
188 | print(f"Error processing {image_path}: {str(e)}")
189 | continue
190 |
191 |
192 | if __name__ == "__main__":
193 | main()
194 |
--------------------------------------------------------------------------------
/demo_page.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
3 | SPDX-License-Identifier: MIT
4 | """
5 |
6 | import argparse
7 | import glob
8 | import os
9 |
10 | import cv2
11 | from omegaconf import OmegaConf
12 | from PIL import Image
13 |
14 | from chat import DOLPHIN
15 | from utils.utils import *
16 |
17 |
18 | def process_page(image_path, model, save_dir, max_batch_size):
19 | """Parse document images with two stages"""
20 | # Stage 1: Page-level layout and reading order parsing
21 | pil_image = Image.open(image_path).convert("RGB")
22 | layout_output = model.chat("Parse the reading order of this document.", pil_image)
23 |
24 | # Stage 2: Element-level content parsing
25 | padded_image, dims = prepare_image(pil_image)
26 | recognition_results = process_elements(layout_output, padded_image, dims, model, max_batch_size)
27 |
28 | # Save outputs
29 | json_path = save_outputs(recognition_results, image_path, save_dir)
30 |
31 | return json_path, recognition_results
32 |
33 |
34 | def process_elements(layout_results, padded_image, dims, model, max_batch_size):
35 | """Parse all document elements with parallel decoding"""
36 | layout_results = parse_layout_string(layout_results)
37 |
38 | text_table_elements = [] # Elements that need processing
39 | figure_results = [] # Figure elements (no processing needed)
40 | previous_box = None
41 | reading_order = 0
42 |
43 | # Collect elements for processing
44 | for bbox, label in layout_results:
45 | try:
46 | # Adjust coordinates
47 | x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates(
48 | bbox, padded_image, dims, previous_box
49 | )
50 |
51 | # Crop and parse element
52 | cropped = padded_image[y1:y2, x1:x2]
53 | if cropped.size > 0:
54 | if label == "fig":
55 | # For figure regions, add empty text result immediately
56 | figure_results.append(
57 | {
58 | "label": label,
59 | "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
60 | "text": "",
61 | "reading_order": reading_order,
62 | }
63 | )
64 | else:
65 | # For text or table regions, prepare for parsing
66 | pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
67 | prompt = "Parse the table in the image." if label == "tab" else "Read text in the image."
68 | text_table_elements.append(
69 | {
70 | "crop": pil_crop,
71 | "prompt": prompt,
72 | "label": label,
73 | "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
74 | "reading_order": reading_order,
75 | }
76 | )
77 |
78 | reading_order += 1
79 |
80 | except Exception as e:
81 | print(f"Error processing bbox with label {label}: {str(e)}")
82 | continue
83 |
84 | # Parse text/table elements in parallel
85 | recognition_results = figure_results
86 | if text_table_elements:
87 | crops_list = [elem["crop"] for elem in text_table_elements]
88 | prompts_list = [elem["prompt"] for elem in text_table_elements]
89 |
90 | # Inference in batch
91 | batch_results = model.chat(prompts_list, crops_list, max_batch_size=max_batch_size)
92 |
93 | # Add batch results to recognition_results
94 | for i, result in enumerate(batch_results):
95 | elem = text_table_elements[i]
96 | recognition_results.append(
97 | {
98 | "label": elem["label"],
99 | "bbox": elem["bbox"],
100 | "text": result.strip(),
101 | "reading_order": elem["reading_order"],
102 | }
103 | )
104 |
105 | # Sort elements by reading order
106 | recognition_results.sort(key=lambda x: x.get("reading_order", 0))
107 |
108 | return recognition_results
109 |
110 |
111 | def main():
112 | parser = argparse.ArgumentParser(description="Document processing tool using DOLPHIN model")
113 | parser.add_argument("--config", default="./config/Dolphin.yaml", help="Path to configuration file")
114 | parser.add_argument("--input_path", type=str, default="./demo", help="Path to input image or directory of images")
115 | parser.add_argument(
116 | "--save_dir",
117 | type=str,
118 | default=None,
119 | help="Directory to save parsing results (default: same as input directory)",
120 | )
121 | parser.add_argument(
122 | "--max_batch_size",
123 | type=int,
124 | default=4,
125 | help="Maximum number of document elements to parse in a single batch (default: 4)",
126 | )
127 | args = parser.parse_args()
128 |
129 | # Load Model
130 | config = OmegaConf.load(args.config)
131 | model = DOLPHIN(config)
132 |
133 | # Collect Document Images
134 | if os.path.isdir(args.input_path):
135 | image_files = []
136 | for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]:
137 | image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
138 | image_files = sorted(image_files)
139 | else:
140 | if not os.path.exists(args.input_path):
141 | raise FileNotFoundError(f"Input path {args.input_path} does not exist")
142 | image_files = [args.input_path]
143 |
144 | save_dir = args.save_dir or (
145 | args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
146 | )
147 | setup_output_dirs(save_dir)
148 |
149 | total_samples = len(image_files)
150 | print(f"\nTotal samples to process: {total_samples}")
151 |
152 | # Process All Document Images
153 | for image_path in image_files:
154 | print(f"\nProcessing {image_path}")
155 | try:
156 | json_path, recognition_results = process_page(
157 | image_path=image_path,
158 | model=model,
159 | save_dir=save_dir,
160 | max_batch_size=args.max_batch_size,
161 | )
162 |
163 | print(f"Processing completed. Results saved to {save_dir}")
164 |
165 | except Exception as e:
166 | print(f"Error processing {image_path}: {str(e)}")
167 | continue
168 |
169 |
170 | if __name__ == "__main__":
171 | main()
172 |
--------------------------------------------------------------------------------
/demo_page_hf.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
3 | SPDX-License-Identifier: MIT
4 | """
5 |
6 | import argparse
7 | import glob
8 | import os
9 |
10 | import cv2
11 | import torch
12 | from PIL import Image
13 | from transformers import AutoProcessor, VisionEncoderDecoderModel
14 |
15 | from utils.utils import *
16 |
17 |
18 | class DOLPHIN:
19 | def __init__(self, model_id_or_path):
20 | """Initialize the Hugging Face model
21 |
22 | Args:
23 | model_id_or_path: Path to local model or Hugging Face model ID
24 | """
25 | # Load model from local path or Hugging Face hub
26 | self.processor = AutoProcessor.from_pretrained(model_id_or_path)
27 | self.model = VisionEncoderDecoderModel.from_pretrained(model_id_or_path)
28 | self.model.eval()
29 |
30 | # Set device and precision
31 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
32 | self.model.to(self.device)
33 | self.model = self.model.half() # Always use half precision by default
34 |
35 | # set tokenizer
36 | self.tokenizer = self.processor.tokenizer
37 |
38 | def chat(self, prompt, image):
39 | """Process an image or batch of images with the given prompt(s)
40 |
41 | Args:
42 | prompt: Text prompt or list of prompts to guide the model
43 | image: PIL Image or list of PIL Images to process
44 |
45 | Returns:
46 | Generated text or list of texts from the model
47 | """
48 | # Check if we're dealing with a batch
49 | is_batch = isinstance(image, list)
50 |
51 | if not is_batch:
52 | # Single image, wrap it in a list for consistent processing
53 | images = [image]
54 | prompts = [prompt]
55 | else:
56 | # Batch of images
57 | images = image
58 | prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
59 |
60 | # Prepare image
61 | batch_inputs = self.processor(images, return_tensors="pt", padding=True)
62 | batch_pixel_values = batch_inputs.pixel_values.half().to(self.device)
63 |
64 | # Prepare prompt
65 | prompts = [f"<s>{p} <Answer/>" for p in prompts]
66 | batch_prompt_inputs = self.tokenizer(
67 | prompts,
68 | add_special_tokens=False,
69 | return_tensors="pt"
70 | )
71 |
72 | batch_prompt_ids = batch_prompt_inputs.input_ids.to(self.device)
73 | batch_attention_mask = batch_prompt_inputs.attention_mask.to(self.device)
74 |
75 | # Generate text
76 | outputs = self.model.generate(
77 | pixel_values=batch_pixel_values,
78 | decoder_input_ids=batch_prompt_ids,
79 | decoder_attention_mask=batch_attention_mask,
80 | min_length=1,
81 | max_length=4096,
82 | pad_token_id=self.tokenizer.pad_token_id,
83 | eos_token_id=self.tokenizer.eos_token_id,
84 | use_cache=True,
85 | bad_words_ids=[[self.tokenizer.unk_token_id]],
86 | return_dict_in_generate=True,
87 | do_sample=False,
88 | num_beams=1,
89 | repetition_penalty=1.1
90 | )
91 |
92 | # Process output
93 | sequences = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
94 |
95 | # Clean prompt text from output
96 | results = []
97 | for i, sequence in enumerate(sequences):
98 | cleaned = sequence.replace(prompts[i], "").replace("<pad>", "").replace("</s>", "").strip()
99 | results.append(cleaned)
100 |
101 | # Return a single result for single image input
102 | if not is_batch:
103 | return results[0]
104 | return results
105 |
106 |
107 | def process_page(image_path, model, save_dir, max_batch_size=None):
108 | """Parse document images with two stages"""
109 | # Stage 1: Page-level layout and reading order parsing
110 | pil_image = Image.open(image_path).convert("RGB")
111 | layout_output = model.chat("Parse the reading order of this document.", pil_image)
112 |
113 | # Stage 2: Element-level content parsing
114 | padded_image, dims = prepare_image(pil_image)
115 | recognition_results = process_elements(layout_output, padded_image, dims, model, max_batch_size)
116 |
117 | # Save outputs
118 | json_path = save_outputs(recognition_results, image_path, save_dir)
119 |
120 | return json_path, recognition_results
121 |
122 |
123 | def process_elements(layout_results, padded_image, dims, model, max_batch_size=None):
124 | """Parse all document elements with parallel decoding"""
125 | layout_results = parse_layout_string(layout_results)
126 |
127 | # Store text and table elements separately
128 | text_elements = [] # Text elements
129 | table_elements = [] # Table elements
130 | figure_results = [] # Image elements (no processing needed)
131 | previous_box = None
132 | reading_order = 0
133 |
134 | # Collect elements to process and group by type
135 | for bbox, label in layout_results:
136 | try:
137 | # Adjust coordinates
138 | x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates(
139 | bbox, padded_image, dims, previous_box
140 | )
141 |
142 | # Crop and parse element
143 | cropped = padded_image[y1:y2, x1:x2]
144 | if cropped.size > 0:
145 | if label == "fig":
146 | # For figure regions, add empty text result immediately
147 | figure_results.append(
148 | {
149 | "label": label,
150 | "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
151 | "text": "",
152 | "reading_order": reading_order,
153 | }
154 | )
155 | else:
156 | # Prepare element for parsing
157 | pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
158 | element_info = {
159 | "crop": pil_crop,
160 | "label": label,
161 | "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
162 | "reading_order": reading_order,
163 | }
164 |
165 | # Group by type
166 | if label == "tab":
167 | table_elements.append(element_info)
168 | else: # Text elements
169 | text_elements.append(element_info)
170 |
171 | reading_order += 1
172 |
173 | except Exception as e:
174 | print(f"Error processing bbox with label {label}: {str(e)}")
175 | continue
176 |
177 | # Initialize results list
178 | recognition_results = figure_results.copy()
179 |
180 | # Process text elements (in batches)
181 | if text_elements:
182 | text_results = process_element_batch(text_elements, model, "Read text in the image.", max_batch_size)
183 | recognition_results.extend(text_results)
184 |
185 | # Process table elements (in batches)
186 | if table_elements:
187 | table_results = process_element_batch(table_elements, model, "Parse the table in the image.", max_batch_size)
188 | recognition_results.extend(table_results)
189 |
190 | # Sort elements by reading order
191 | recognition_results.sort(key=lambda x: x.get("reading_order", 0))
192 |
193 | return recognition_results
194 |
195 |
196 | def process_element_batch(elements, model, prompt, max_batch_size=None):
197 | """Process elements of the same type in batches"""
198 | results = []
199 |
200 | # Determine batch size
201 | batch_size = len(elements)
202 | if max_batch_size is not None and max_batch_size > 0:
203 | batch_size = min(batch_size, max_batch_size)
204 |
205 | # Process in batches
206 | for i in range(0, len(elements), batch_size):
207 | batch_elements = elements[i:i+batch_size]
208 | crops_list = [elem["crop"] for elem in batch_elements]
209 |
210 | # Use the same prompt for all elements in the batch
211 | prompts_list = [prompt] * len(crops_list)
212 |
213 | # Batch inference
214 | batch_results = model.chat(prompts_list, crops_list)
215 |
216 | # Add results
217 | for j, result in enumerate(batch_results):
218 | elem = batch_elements[j]
219 | results.append({
220 | "label": elem["label"],
221 | "bbox": elem["bbox"],
222 | "text": result.strip(),
223 | "reading_order": elem["reading_order"],
224 | })
225 |
226 | return results
227 |
228 |
229 | def main():
230 | parser = argparse.ArgumentParser(description="Document processing tool using DOLPHIN model")
231 | parser.add_argument("--model_path", default="./hf_model", help="Path to Hugging Face model")
232 | parser.add_argument("--input_path", type=str, default="./demo", help="Path to input image or directory of images")
233 | parser.add_argument(
234 | "--save_dir",
235 | type=str,
236 | default=None,
237 | help="Directory to save parsing results (default: same as input directory)",
238 | )
239 | parser.add_argument(
240 | "--max_batch_size",
241 | type=int,
242 | default=16,
243 | help="Maximum number of document elements to parse in a single batch (default: 16)",
244 | )
245 | args = parser.parse_args()
246 |
247 | # Load Model
248 | model = DOLPHIN(args.model_path)
249 |
250 | # Collect Document Images
251 | if os.path.isdir(args.input_path):
252 | image_files = []
253 | for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]:
254 | image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
255 | image_files = sorted(image_files)
256 | else:
257 | if not os.path.exists(args.input_path):
258 | raise FileNotFoundError(f"Input path {args.input_path} does not exist")
259 | image_files = [args.input_path]
260 |
261 | save_dir = args.save_dir or (
262 | args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
263 | )
264 | setup_output_dirs(save_dir)
265 |
266 | total_samples = len(image_files)
267 | print(f"\nTotal samples to process: {total_samples}")
268 |
269 | # Process All Document Images
270 | for image_path in image_files:
271 | print(f"\nProcessing {image_path}")
272 | try:
273 | json_path, recognition_results = process_page(
274 | image_path=image_path,
275 | model=model,
276 | save_dir=save_dir,
277 | max_batch_size=args.max_batch_size,
278 | )
279 |
280 | print(f"Processing completed. Results saved to {save_dir}")
281 |
282 | except Exception as e:
283 | print(f"Error processing {image_path}: {str(e)}")
284 | continue
285 |
286 |
287 | if __name__ == "__main__":
288 | main()
289 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 120
3 | include = '\.pyi?$'
4 | exclude = '''
5 | /(
6 | \.git
7 | | \.hg
8 | | \.mypy_cache
9 | | \.tox
10 | | \.venv
11 | | _build
12 | | buck-out
13 | | build
14 | | dist
15 | )/
16 | '''
17 |
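# Black picks these settings up automatically, whether invoked via the
# pre-commit hook in .pre-commit-config.yaml or manually with `black .`.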
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | albumentations==1.4.0
2 | numpy==1.24.4
3 | omegaconf==2.3.0
4 | opencv-python==4.11.0.86
5 | opencv-python-headless==4.5.5.64
6 | pillow==9.3.0
7 | timm==0.5.4
8 | torch==2.1.0
9 | torchvision==0.16.0
10 | transformers==4.47.0
11 | accelerate==1.6.0
--------------------------------------------------------------------------------
/utils/markdown_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
3 | SPDX-License-Identifier: MIT
4 | """
5 |
6 | import re
7 | import base64
8 | from typing import List, Dict, Any, Optional
9 |
10 |
11 | """
12 | Example input:
13 | [
14 | {"label": "tab", "bbox": [0.176, 0.74, 0.824, 0.82], "text": " | HellaSwag | Obqa | WinoGrande | ARC-c | ARC-e | boolq | piqa | Avg |
OPT-1.3B | 53.65 | 33.40 | 59.59 | 29.44 | 50.80 | 60.83 | 72.36 | 51.44 |
Pythia-1.0B | 47.16 | 31.40 | 53.43 | 27.05 | 48.99 | 57.83 | 69.21 | 48.30 |
Pythia-1.4B | 52.01 | 33.20 | 57.38 | 28.50 | 54.00 | 63.27 | 70.95 | 51.33 |
TinyLlama-1.1B | 59.20 | 36.00 | 59.12 | 30.10 | 55.25 | 57.83 | 73.29 | 52.99 |
", "reading_order": 6},
15 | {"label": "cap", "bbox": [0.28, 0.729, 0.711, 0.74], "text": "Table 2: Zero-shot performance on commonsense reasoning tasks", "reading_order": 7},
16 | {"label": "para", "bbox": [0.176, 0.848, 0.826, 0.873], "text": "We of performance during training We tracked the accuracy of TinyLlama on common-\nsense reasoning benchmarks during its pre-training, as shown in Fig. 2 . Generally, the performance of", "reading_order": 8},
17 | {"label": "fnote", "bbox": [0.176, 0.88, 0.824, 0.912], "text": "${ }^{4}$ Due to a bug in the config file, the learning rate did not decrease immediately after warmup and remained at\nthe maximum value for several steps before we fixed this.", "reading_order": 9},
18 | {"label": "foot", "bbox": [0.496, 0.939, 0.501, 0.95], "text": "14", "reading_order": 10}
19 | ]
20 | """
21 |
22 |
23 | def extract_table_from_html(html_string):
24 | """Extract and clean table tags from HTML string"""
25 | try:
26 | table_pattern = re.compile(r'<table.*?</table>', re.DOTALL)
27 | tables = table_pattern.findall(html_string)
28 | tables = [re.sub(r'<table[^>]*>', '<table>', table) for table in tables]
29 | return '\n'.join(tables)
30 | except Exception as e:
31 | print(f"extract_table_from_html error: {str(e)}")
32 | return f"Error extracting table: {str(e)} |
"
33 |
34 |
35 | class MarkdownConverter:
36 | """Convert structured recognition results to Markdown format"""
37 |
38 | def __init__(self):
39 | # Define heading levels for different section types
40 | self.heading_levels = {
41 | 'title': '#',
42 | 'sec': '##',
43 | 'sub_sec': '###'
44 | }
45 |
46 | # Define which labels need special handling
47 | self.special_labels = {
48 | 'tab', 'fig', 'title', 'sec', 'sub_sec',
49 | 'list', 'formula', 'reference', 'alg'
50 | }
51 |
52 | def try_remove_newline(self, text: str) -> str:
53 | try:
54 | # Preprocess text to handle line breaks
55 | text = text.strip()
56 | text = text.replace('-\n', '')
57 |
58 | # Handle Chinese text line breaks
59 | def is_chinese(char):
60 | return '\u4e00' <= char <= '\u9fff'
61 |
62 | lines = text.split('\n')
63 | processed_lines = []
64 |
65 | # Process all lines except the last one
66 | for i in range(len(lines)-1):
67 | current_line = lines[i].strip()
68 | next_line = lines[i+1].strip()
69 |
70 | # Always add the current line, but determine if we need a newline
71 | if current_line: # If current line is not empty
72 | if next_line: # If next line is not empty
73 | # For Chinese text handling
74 | if is_chinese(current_line[-1]) and is_chinese(next_line[0]):
75 | processed_lines.append(current_line)
76 | else:
77 | processed_lines.append(current_line + ' ')
78 | else:
79 | # Next line is empty, add current line with newline
80 | processed_lines.append(current_line + '\n')
81 | else:
82 | # Current line is empty, add an empty line
83 | processed_lines.append('\n')
84 |
85 | # Add the last line
86 | if lines and lines[-1].strip():
87 | processed_lines.append(lines[-1].strip())
88 |
89 | text = ''.join(processed_lines)
90 |
91 | return text
92 | except Exception as e:
93 | print(f"try_remove_newline error: {str(e)}")
94 | return text # Return original text on error
95 |
96 | def _handle_text(self, text: str) -> str:
97 | """
98 | Process regular text content, preserving paragraph structure
99 | """
100 | try:
101 | if not text:
102 | return ""
103 |
104 | if text.strip().startswith("\\begin{array}") and text.strip().endswith("\\end{array}"):
105 | text = "$$" + text + "$$"
106 | elif ("_{" in text or "^{" in text or "\\" in text or "_ {" in text or "^ {" in text) and ("$" not in text) and ("\\begin" not in text):
107 | text = "$" + text + "$"
108 |
109 | # Process formulas in text before handling other text processing
110 | text = self._process_formulas_in_text(text)
111 |
112 | text = self.try_remove_newline(text)
113 |
114 | # Return processed text
115 | return text
116 | except Exception as e:
117 | print(f"_handle_text error: {str(e)}")
118 | return text # Return original text on error
119 |
120 | def _process_formulas_in_text(self, text: str) -> str:
121 | """
122 | Process mathematical formulas in text by iteratively finding and replacing formulas.
123 | - Identify inline and block formulas
124 | - Replace newlines within formulas with \\
125 | """
126 | try:
127 | # Define formula delimiters and their corresponding patterns
128 | delimiters = [
129 | ('$$', '$$'), # Block formula with $$
130 | ('\\[', '\\]'), # Block formula with \[ \]
131 | ('$', '$'), # Inline formula with $
132 | ('\\(', '\\)') # Inline formula with \( \)
133 | ]
134 |
135 | # Process the text by iterating through each delimiter type
136 | result = text
137 |
138 | for start_delim, end_delim in delimiters:
139 | # Create a pattern that matches from start to end delimiter
140 | # Using a custom approach to avoid issues with nested delimiters
141 | current_pos = 0
142 | processed_parts = []
143 |
144 | while current_pos < len(result):
145 | # Find the next start delimiter
146 | start_pos = result.find(start_delim, current_pos)
147 | if start_pos == -1:
148 | # No more formulas of this type
149 | processed_parts.append(result[current_pos:])
150 | break
151 |
152 | # Add text before the formula
153 | processed_parts.append(result[current_pos:start_pos])
154 |
155 | # Find the matching end delimiter
156 | end_pos = result.find(end_delim, start_pos + len(start_delim))
157 | if end_pos == -1:
158 | # No matching end delimiter, treat as regular text
159 | processed_parts.append(result[start_pos:])
160 | break
161 |
162 | # Extract the formula content (without delimiters)
163 | formula_content = result[start_pos + len(start_delim):end_pos]
164 |
165 | # Process the formula content - replace newlines with \\
166 | processed_formula = formula_content.replace('\n', ' \\\\ ')
167 |
168 | # Add the processed formula with its delimiters
169 | processed_parts.append(f"{start_delim}{processed_formula}{end_delim}")
170 |
171 | # Move past this formula
172 | current_pos = end_pos + len(end_delim)
173 |
174 | # Update the result with processed text
175 | result = ''.join(processed_parts)
176 | return result
177 | except Exception as e:
178 | print(f"_process_formulas_in_text error: {str(e)}")
179 | return text # Return original text on error
180 |
181 | def _remove_newline_in_heading(self, text: str) -> str:
182 | """
183 | Remove newline in heading
184 | """
185 | try:
186 | # Handle Chinese text line breaks
187 | def is_chinese(char):
188 | return '\u4e00' <= char <= '\u9fff'
189 |
190 | # Check if the text contains Chinese characters
191 | if any(is_chinese(char) for char in text):
192 | return text.replace('\n', '')
193 | else:
194 | return text.replace('\n', ' ')
195 |
196 | except Exception as e:
197 | print(f"_remove_newline_in_heading error: {str(e)}")
198 | return text
199 |
200 | def _handle_heading(self, text: str, label: str) -> str:
201 | """
202 | Convert section headings to appropriate markdown format
203 | """
204 | try:
205 | level = self.heading_levels.get(label, '#')
206 | text = text.strip()
207 | text = self._remove_newline_in_heading(text)
208 | text = self._handle_text(text)
209 | return f"{level} {text}\n\n"
210 | except Exception as e:
211 | print(f"_handle_heading error: {str(e)}")
212 | return f"# Error processing heading: {text}\n\n"
213 |
214 | def _handle_list_item(self, text: str) -> str:
215 | """
216 | Convert list items to markdown list format
217 | """
218 | try:
219 | return f"- {text.strip()}\n"
220 | except Exception as e:
221 | print(f"_handle_list_item error: {str(e)}")
222 | return f"- Error processing list item: {text}\n"
223 |
224 | def _handle_figure(self, text: str, section_count: int) -> str:
225 | """
226 | Convert base64 encoded image to markdown image syntax
227 | """
228 | try:
229 | # Determine image format (assuming PNG if not specified)
230 | img_format = "png"
231 | if text.startswith("data:image/"):
232 | # Extract format from data URI
233 | img_format = text.split(";")[0].split("/")[1]
234 | elif ";" in text and "," in text:
235 | # Already in data URI format
236 | return f"![Figure {section_count}]({text})\n\n"
237 | else:
238 | # Raw base64, convert to data URI
239 | data_uri = f"data:image/{img_format};base64,{text}"
240 | return f"![Figure {section_count}]({data_uri})\n\n"
241 | except Exception as e:
242 | print(f"_handle_figure error: {str(e)}")
243 | return f"*[Error processing figure: {str(e)}]*\n\n"
244 |
245 | def _handle_table(self, text: str) -> str:
246 | """
247 | Convert table content to markdown format
248 | """
249 | try:
250 | markdown_content = []
251 | if '<table' in text.lower():
252 | # HTML table: extract and clean the <table> markup
253 | table_content = extract_table_from_html(text)
254 | markdown_content.append(f"{table_content}\n\n")
255 | else:
256 | # Plain-text or markdown table: keep it as-is
257 | markdown_content.append(f"{text}\n\n")
258 |
259 | return ''.join(markdown_content)
260 | except Exception as e:
261 | print(f"_handle_table error: {str(e)}")
262 | return f"*[Error processing table: {str(e)}]*\n\n"
263 |
264 | def _handle_algorithm(self, text: str) -> str:
272 | """
273 | Process algorithm blocks with proper formatting
274 | """
275 | try:
276 | # Remove algorithm environment tags if present
277 | text = re.sub(r'\\begin\{algorithm\}(.*?)\\end\{algorithm\}', r'\1', text, flags=re.DOTALL)
278 | text = text.replace('\\begin{algorithm}', '').replace('\\end{algorithm}', '')
279 | text = text.replace('\\begin{algorithmic}', '').replace('\\end{algorithmic}', '')
280 |
281 | # Process the algorithm text
282 | lines = text.strip().split('\n')
283 |
284 | # Check if there's a caption or label
285 | caption = ""
286 | algorithm_text = []
287 |
288 | for line in lines:
289 | if '\\caption' in line:
290 | # Extract caption text
291 | caption_match = re.search(r'\\caption\{(.*?)\}', line)
292 | if caption_match:
293 | caption = f"**{caption_match.group(1)}**\n\n"
294 | continue
295 | elif '\\label' in line:
296 | continue # Skip label lines
297 | else:
298 | algorithm_text.append(line)
299 |
300 | # Join the algorithm text and wrap in code block
301 | formatted_text = '\n'.join(algorithm_text)
302 |
303 | # Return the formatted algorithm with caption
304 | return f"{caption}```\n{formatted_text}\n```\n\n"
305 | except Exception as e:
306 | print(f"_handle_algorithm error: {str(e)}")
307 | return f"*[Error processing algorithm: {str(e)}]*\n\n{text}\n\n"
308 |
309 | def _handle_formula(self, text: str) -> str:
310 | """
311 | Handle formula-specific content
312 | """
313 | try:
314 | # Process the formula content
315 | processed_text = self._process_formulas_in_text(text)
316 |
317 | # For formula blocks, ensure they're properly formatted in markdown
318 | if '$$' not in processed_text and '\\[' not in processed_text:
319 | # If no block formula delimiters are present, wrap in $$ for block formula
320 | processed_text = f'$${processed_text}$$'
321 |
322 | return f"{processed_text}\n\n"
323 | except Exception as e:
324 | print(f"_handle_formula error: {str(e)}")
325 | return f"*[Error processing formula: {str(e)}]*\n\n"
326 |
327 | def convert(self, recognition_results: List[Dict[str, Any]]) -> str:
328 | """
329 | Convert recognition results to markdown format
330 | """
331 | try:
332 | markdown_content = []
333 |
334 | for section_count, result in enumerate(recognition_results):
335 | try:
336 | label = result.get('label', '')
337 | text = result.get('text', '').strip()
338 |
339 | # Skip empty text
340 | if not text:
341 | continue
342 |
343 | # Handle different content types
344 | if label in {'title', 'sec', 'sub_sec'}:
345 | markdown_content.append(self._handle_heading(text, label))
346 | elif label == 'list':
347 | markdown_content.append(self._handle_list_item(text))
348 | elif label == 'fig':
349 | markdown_content.append(self._handle_figure(text, section_count))
350 | elif label == 'tab':
351 | markdown_content.append(self._handle_table(text))
352 | elif label == 'alg':
353 | markdown_content.append(self._handle_algorithm(text))
354 | elif label == 'formula':
355 | markdown_content.append(self._handle_formula(text))
356 | elif label not in self.special_labels:
357 | # Handle regular text (paragraphs, etc.)
358 | processed_text = self._handle_text(text)
359 | markdown_content.append(f"{processed_text}\n\n")
360 | except Exception as e:
361 | print(f"Error processing item {section_count}: {str(e)}")
362 | # Add a placeholder for the failed item
363 | markdown_content.append(f"*[Error processing content]*\n\n")
364 |
365 | # Join all content and apply post-processing
366 | result = ''.join(markdown_content)
367 | return self._post_process(result)
368 | except Exception as e:
369 | print(f"convert error: {str(e)}")
370 | return f"Error generating markdown content: {str(e)}"
371 |
372 | def _post_process(self, markdown_content: str) -> str:
373 | """
374 | Apply post-processing fixes to the generated markdown content
375 | """
376 | try:
377 | # Handle author information
378 | author_pattern = re.compile(r'\\author\{(.*?)\}', re.DOTALL)
379 |
380 | def process_author_match(match):
381 | # Extract author content
382 | author_content = match.group(1)
383 | # Process the author content
384 | return self._handle_text(author_content)
385 |
386 | # Replace \author{...} with processed content
387 | markdown_content = author_pattern.sub(process_author_match, markdown_content)
388 |
389 | # Handle special case where author is inside math environment
390 | math_author_pattern = re.compile(r'\$(\\author\{.*?\})\$', re.DOTALL)
391 | match = math_author_pattern.search(markdown_content)
392 | if match:
393 | # Extract the author command
394 | author_cmd = match.group(1)
395 | # Extract content from author command
396 | author_content_match = re.search(r'\\author\{(.*?)\}', author_cmd, re.DOTALL)
397 | if author_content_match:
398 | # Get author content and process it
399 | author_content = author_content_match.group(1)
400 | processed_content = self._handle_text(author_content)
401 | # Replace the entire $\author{...}$ block with processed content
402 | markdown_content = markdown_content.replace(match.group(0), processed_content)
403 |
404 | # Replace LaTeX abstract environment with plain text
405 | markdown_content = re.sub(r'\\begin\{abstract\}(.*?)\\end\{abstract\}',
406 | r'**Abstract** \1',
407 | markdown_content,
408 | flags=re.DOTALL)
409 |
410 | # Replace standalone \begin{abstract} (without matching end)
411 | markdown_content = re.sub(r'\\begin\{abstract\}',
412 | r'**Abstract**',
413 | markdown_content)
414 |
415 | # Replace LaTeX equation numbers with tag format, handling cases with extra backslashes
416 | markdown_content = re.sub(r'\\eqno\{\((.*?)\)\}',
417 | r'\\tag{\1}',
418 | markdown_content)
419 |
420 | # Find the starting tag of the formula
421 | markdown_content = markdown_content.replace("\[ \\\\", "$$ \\\\")
422 |
423 | # Find the ending tag of the formula (ensure this is the only ending tag)
424 | markdown_content = markdown_content.replace("\\\\ \]", "\\\\ $$")
425 |
426 | # Fix other common LaTeX issues
427 | replacements = [
428 | # Fix spacing issues in subscripts and superscripts
429 | (r'_ {', r'_{'),
430 | (r'\^ {', r'^{'),
431 |
432 | # Fix potential issues with multiple consecutive newlines
433 | (r'\n{3,}', r'\n\n')
434 | ]
435 |
436 | for old, new in replacements:
437 | markdown_content = re.sub(old, new, markdown_content)
438 |
439 | return markdown_content
440 | except Exception as e:
441 | print(f"_post_process error: {str(e)}")
442 | return markdown_content # Return original content if post-processing fails
443 |
--------------------------------------------------------------------------------
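
For orientation, here is a minimal usage sketch of the `MarkdownConverter` defined above; the sample `recognition_results` are illustrative, not captured model output.

```python
# Minimal sketch: converting recognition results to markdown.
from utils.markdown_utils import MarkdownConverter

recognition_results = [
    {"label": "sec", "bbox": [0.1, 0.05, 0.9, 0.08], "text": "5 Results", "reading_order": 0},
    {"label": "para", "bbox": [0.1, 0.10, 0.9, 0.20],
     "text": "TinyLlama reaches competitive zero-shot accu-\nracy at 1.1B parameters.", "reading_order": 1},
    {"label": "formula", "bbox": [0.3, 0.25, 0.7, 0.30], "text": "E = mc^{2}", "reading_order": 2},
]

converter = MarkdownConverter()
print(converter.convert(recognition_results))
# Roughly:
# ## 5 Results
#
# TinyLlama reaches competitive zero-shot accuracy at 1.1B parameters.
#
# $$E = mc^{2}$$
```
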
/utils/model.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022-present NAVER Corp.
3 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
4 | MIT License
5 | This file has been modified by [ByteDance Ltd. and/or its affiliates] on 20250118.
6 | The original file available at https://github.com/clovaai/donut/blob/master/donut/model.py was released under the MIT license.
7 | This modified file is released under the same license.
8 | """
9 |
10 | import logging
11 | from collections import defaultdict
12 | from typing import List, Optional
13 |
14 | import torch
15 | import torch.nn.functional as F
16 | from PIL import Image
17 | from timm.models.swin_transformer import SwinTransformer
18 | from torch import nn
19 | from transformers import (
20 | MBartConfig,
21 | MBartForCausalLM,
22 | StoppingCriteria,
23 | StoppingCriteriaList,
24 | )
25 | from transformers.file_utils import ModelOutput
26 | from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
27 |
28 |
29 | class SwinEncoder(nn.Module):
30 | r"""
31 | Encoder based on SwinTransformer
32 | Set the initial weights and configuration with a pretrained SwinTransformer and then
33 | modify the detailed configurations
34 |
35 | Args:
36 | input_size: Input image size (width, height)
37 | align_long_axis: Whether to rotate image if height is greater than width
38 | window_size: Window size(=patch size) of SwinTransformer
39 | encoder_layer: Number of layers of SwinTransformer encoder
40 | name_or_path: Name of a pretrained model registered on huggingface.co or saved locally;
41 | otherwise, `swin_base_patch4_window12_384` will be used (using `timm`).
42 | """
43 |
44 | def __init__(
45 | self,
46 | input_size,
47 | align_long_axis: bool = False,
48 | window_size: int = 7,
49 | encoder_layer: List[int] = [2, 2, 14, 2],
50 | patch_size: List[int] = [4, 4],
51 | embed_dim: int = 128,
52 | num_heads: List[int] = [4, 8, 16, 32],
53 | ):
54 | super().__init__()
55 | if isinstance(input_size, int):
56 | input_size = [input_size, input_size]
57 | self.input_size = input_size
58 | self.align_long_axis = align_long_axis
59 | self.window_size = window_size
60 | self.encoder_layer = encoder_layer
61 | self.patch_size = patch_size
62 | self.embed_dim = embed_dim
63 | self.num_heads = num_heads
64 |
65 | self.model = SwinTransformer(
66 | img_size=self.input_size,
67 | depths=self.encoder_layer,
68 | window_size=self.window_size,
69 | patch_size=self.patch_size,
70 | embed_dim=self.embed_dim,
71 | num_heads=self.num_heads,
72 | num_classes=0,
73 | )
74 |
75 | def forward(self, x: torch.Tensor, text_embedding: Optional[torch.Tensor] = None) -> torch.Tensor:
76 | """
77 | Args:
78 | x: (batch_size, num_channels, height, width)
79 | """
80 | x = self.model.patch_embed(x)
81 | x = self.model.pos_drop(x)
82 | x = self.model.layers(x)
83 | return x
84 |
85 |
86 | class LayerNorm(nn.LayerNorm):
87 | """Subclass torch's LayerNorm to handle fp16."""
88 |
89 | def _set_dtype(self, dtype):
90 | self._dtype = dtype
91 |
92 | def forward(self, x: torch.Tensor):
93 | orig_type = x.dtype
94 | ret = super().forward(x.type(dtype=self._dtype))
95 | return ret.type(orig_type)
96 |
97 |
98 | class BARTDecoder(nn.Module):
99 | """
100 | Decoder based on Multilingual BART
101 | Set the initial weights and configuration with a pretrained multilingual BART model,
102 | and modify the detailed configurations as a Donut decoder
103 |
104 | Args:
105 | decoder_layer:
106 | Number of layers of BARTDecoder
107 | max_position_embeddings:
108 | The maximum sequence length to be trained
109 | name_or_path:
110 | Name of a pretrained model registered on huggingface.co or saved locally;
111 | otherwise, `facebook/mbart-large-50` will be used (using `transformers`)
112 | """
113 |
114 | def __init__(
115 | self,
116 | tokenizer,
117 | decoder_layer: int,
118 | max_position_embeddings: int,
119 | hidden_dimension: int = 1024,
120 | **kwargs,
121 | ):
122 | super().__init__()
123 | self.decoder_layer = decoder_layer
124 | self.max_position_embeddings = max_position_embeddings
125 | self.hidden_dimension = hidden_dimension
126 |
127 | self.tokenizer = tokenizer
128 |
129 | self.model = MBartForCausalLM(
130 | config=MBartConfig(
131 | tie_word_embeddings=True,
132 | is_decoder=True,
133 | is_encoder_decoder=False,
134 | add_cross_attention=True,
135 | decoder_layers=self.decoder_layer,
136 | max_position_embeddings=self.max_position_embeddings,
137 | vocab_size=len(self.tokenizer),
138 | scale_embedding=True,
139 | add_final_layer_norm=True,
140 | d_model=self.hidden_dimension,
141 | )
142 | )
143 | # self.model.config.is_encoder_decoder = True # to get cross-attention
144 | self.model.model.decoder.embed_tokens.padding_idx = self.tokenizer.pad_token_id
145 | self.model.prepare_inputs_for_generation = self.prepare_inputs_for_inference
146 |
147 | def add_special_tokens(self, list_of_tokens: List[str]):
148 | """
149 | Add special tokens to tokenizer and resize the token embeddings
150 | """
151 | newly_added_num = self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set(list_of_tokens))})
152 | if newly_added_num > 0:
153 | self.model.resize_token_embeddings(len(self.tokenizer))
154 |
155 | def add_tokens(self, list_of_tokens: List[str]):
156 | """
157 | Add special tokens to tokenizer and resize the token embeddings
158 | """
159 | newly_added_num = self.tokenizer.add_tokens(sorted(set(list_of_tokens)))
160 | if newly_added_num > 0:
161 | self.model.resize_token_embeddings(len(self.tokenizer))
162 |
163 | def prepare_inputs_for_inference(
164 | self,
165 | input_ids: torch.Tensor,
166 | encoder_outputs: torch.Tensor,
167 | past=None,
168 | past_key_values=None,
169 | use_cache: bool = None,
170 | attention_mask: torch.Tensor = None,
171 | **kwargs,
172 | ):
173 | """
174 | Args:
175 | input_ids: (batch_size, sequence_length)
176 |
177 | Returns:
178 | input_ids: (batch_size, sequence_length)
179 | attention_mask: (batch_size, sequence_length)
180 | encoder_hidden_states: (batch_size, sequence_length, embedding_dim)
181 | """
182 | attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
183 | past = past or past_key_values
184 | if past is not None:
185 | input_ids = input_ids[:, -1:]
186 | output = {
187 | "input_ids": input_ids,
188 | "attention_mask": attention_mask,
189 | "past_key_values": past,
190 | "use_cache": use_cache,
191 | "encoder_hidden_states": encoder_outputs.last_hidden_state,
192 | }
193 | return output
194 |
195 | def forward(
196 | self,
197 | input_ids: torch.LongTensor = None,
198 | attention_mask: Optional[torch.Tensor] = None,
199 | encoder_hidden_states: Optional[torch.Tensor] = None,
200 | past_key_values: Optional[torch.Tensor] = None,
201 | inputs_embeds: Optional[torch.FloatTensor] = None,
202 | labels: Optional[torch.Tensor] = None,
203 | use_cache: bool = None,
204 | output_attentions: Optional[torch.Tensor] = None,
205 | output_hidden_states: Optional[torch.Tensor] = None,
206 | return_dict: bool = None,
207 | ):
208 | return self.model.forward(
209 | input_ids=input_ids,
210 | attention_mask=attention_mask,
211 | labels=labels,
212 | encoder_hidden_states=encoder_hidden_states,
213 | past_key_values=past_key_values,
214 | inputs_embeds=inputs_embeds,
215 | use_cache=use_cache,
216 | output_attentions=output_attentions,
217 | output_hidden_states=output_hidden_states,
218 | return_dict=return_dict,
219 | )
220 |
221 | @staticmethod
222 | def resize_bart_abs_pos_emb(weight: torch.Tensor, max_length: int) -> torch.Tensor:
223 | """
224 | Resize position embeddings
225 | Truncate if sequence length of MBart backbone is greater than given max_length,
226 | else interpolate to max_length
227 | """
228 | if weight.shape[0] > max_length:
229 | weight = weight[:max_length, ...]
230 | else:
231 | weight = (
232 | F.interpolate(
233 | weight.permute(1, 0).unsqueeze(0),
234 | size=max_length,
235 | mode="linear",
236 | align_corners=False,
237 | )
238 | .squeeze(0)
239 | .permute(1, 0)
240 | )
241 | return weight
242 |
243 |
244 | class DonutConfig(PretrainedConfig):
245 |
246 | def __init__(
247 | self,
248 | decoder_layer: int = 10,
249 | max_position_embeddings: int = None,
250 | max_length: int = 4096,
251 | hidden_dimension: int = 1024,
252 | **kwargs,
253 | ):
254 | super().__init__()
255 | self.decoder_layer = decoder_layer
256 | self.max_position_embeddings = max_length if max_position_embeddings is None else max_position_embeddings
257 | self.max_length = max_length
258 | self.hidden_dimension = hidden_dimension
259 |
260 |
261 | class RunningVarTorch:
262 | def __init__(self, L=15, norm=False):
263 | self.values = None
264 | self.L = L
265 | self.norm = norm
266 |
267 | def push(self, x: torch.Tensor):
268 | assert x.dim() == 1
269 | if self.values is None:
270 | self.values = x[:, None]
271 | elif self.values.shape[1] < self.L:
272 | self.values = torch.cat((self.values, x[:, None]), 1)
273 | else:
274 | self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)
275 |
276 | def variance(self):
277 | if self.values is None:
278 | return
279 | if self.norm:
280 | return torch.var(self.values, 1) / self.values.shape[1]
281 | else:
282 | return torch.var(self.values, 1)
283 |
284 |
285 | class StoppingCriteriaScores(StoppingCriteria):
286 | def __init__(self, threshold: float = 0.015, window_size: int = 200):
287 | super().__init__()
288 | self.threshold = threshold
289 | self.vars = RunningVarTorch(norm=True)
290 | self.varvars = RunningVarTorch(L=window_size)
291 | self.stop_inds = defaultdict(int)
292 | self.stopped = defaultdict(bool)
293 | self.size = 0
294 | self.window_size = window_size
295 |
296 | @torch.no_grad()
297 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
298 | last_scores = scores[-1]
299 | self.vars.push(last_scores.max(1)[0].float().cpu())
300 | self.varvars.push(self.vars.variance())
301 | self.size += 1
302 | if self.size < self.window_size:
303 | return False
304 |
305 | varvar = self.varvars.variance()
306 | for b in range(len(last_scores)):
307 | if varvar[b] < self.threshold:
308 | if self.stop_inds[b] > 0 and not self.stopped[b]:
309 | self.stopped[b] = self.stop_inds[b] >= self.size
310 | else:
311 | self.stop_inds[b] = int(min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095))
312 | else:
313 | self.stop_inds[b] = 0
314 | self.stopped[b] = False
315 | return all(self.stopped.values()) and len(self.stopped) > 0
316 |
317 |
318 | def batch(l, b=15):
319 | subs = []
320 | for i in range(len(l) - b):
321 | subs.append(l[i : i + b])
322 | return subs
323 |
324 |
325 | def subdiv(l, b=10):
326 | subs = []
327 | for i in range(len(l) - b):
328 | subs.append(l[: i + b])
329 | return subs
330 |
331 |
332 | class DonutModel(PreTrainedModel):
333 | config_class = DonutConfig
334 | base_model_prefix = "donut"
335 |
336 | def __init__(self, config: DonutConfig, vision_tower=None, tokenizer=None):
337 | super().__init__(config)
338 | self.config = config
339 |
340 | self.tokenizer = tokenizer
341 | self.vpm = vision_tower
342 |
343 | # build language model
344 | self.llm = BARTDecoder(
345 | tokenizer=tokenizer,
346 | decoder_layer=self.config.decoder_layer,
347 | max_position_embeddings=self.config.max_position_embeddings,
348 | hidden_dimension=self.config.hidden_dimension,
349 | )
350 | self.ids_to_tokens = {id: content for content, id in self.llm.tokenizer.vocab.items()}
351 |
352 | def get_input_embeddings(self, tensor):
353 | return self.llm.model.get_input_embeddings()(tensor)
354 |
355 | def forward(
356 | self,
357 | inputs: dict,
358 | ):
359 | image_tensors = inputs["pixel_values"]
360 | input_ids = inputs["input_ids"].contiguous()
361 | attention_mask = inputs["attention_mask"]
362 | labels = inputs["labels"].contiguous()
363 |
364 | encoder_outputs = self.vpm(
365 | image_tensors,
366 | text_embedding=self.llm.model.get_input_embeddings()(input_ids),
367 | )
368 |
369 | decoder_outputs = self.llm(
370 | input_ids=input_ids,
371 | encoder_hidden_states=encoder_outputs,
372 | attention_mask=attention_mask,
373 | labels=labels,
374 | )
375 | return decoder_outputs
376 |
377 | def get_hidden_states_during_inference(
378 | self,
379 | prompt_ids: torch.Tensor,
380 | image: Image.Image = None,
381 | image_tensors: Optional[torch.Tensor] = None,
382 | ):
383 | if image_tensors is None:
384 | image_tensors = self.vpm.prepare_input(image).unsqueeze(0)
385 |
386 | if self.device.type != "mps":
387 | image_tensors = image_tensors.to(next(self.parameters()).dtype)
388 |
389 | image_tensors = image_tensors.to(self.device)
390 | prompt_ids = prompt_ids.to(self.device)
391 | all_hidden_states = self.vpm.forward_features(
392 | image_tensors, text_embedding=self.get_input_embeddings(prompt_ids)
393 | )
394 | return all_hidden_states
395 |
396 | def get_attn_weights_during_inference(
397 | self,
398 | prompt_ids: torch.Tensor,
399 | image: Image.Image = None,
400 | image_tensors: Optional[torch.Tensor] = None,
401 | ):
402 | if image_tensors is None:
403 | image_tensors = self.vpm.prepare_input(image).unsqueeze(0)
404 |
405 | if self.device.type != "mps":
406 | image_tensors = image_tensors.to(next(self.parameters()).dtype)
407 |
408 | image_tensors = image_tensors.to(self.device)
409 | prompt_ids = prompt_ids.to(self.device)
410 | last_attn_score = self.vpm.get_last_layer_cross_attn_score(
411 | image_tensors, text_embedding=self.get_input_embeddings(prompt_ids)
412 | )
413 | return last_attn_score
414 |
415 | def inference(
416 | self,
417 | prompt_ids: torch.Tensor,
418 | image: Image.Image = None,
419 | image_tensors: Optional[torch.Tensor] = None,
420 | return_attentions: bool = False,
421 | early_stopping: bool = True,
422 | ):
423 | """
424 | Generate a token sequence in an auto-regressive manner.
425 |
426 | Args:
427 | image: input document image (PIL.Image)
428 | image_tensors: (1, num_channels, height, width)
429 | the image is converted to a tensor if image_tensors is not fed
430 | """
431 | output = {
432 | "predictions": list(),
433 | "sequences": list(),
434 | "repeats": list(),
435 | "repetitions": list(),
436 | }
437 | if image is None and image_tensors is None:
438 | logging.warning("Image not found")
439 | return output
440 |
441 | if image_tensors is None:
442 | image_tensors = self.vpm.prepare_input(image).unsqueeze(0)
443 |
444 | if self.device.type != "mps":
445 | image_tensors = image_tensors.to(next(self.parameters()).dtype)
446 |
447 | image_tensors = image_tensors.to(self.device)
448 | prompt_ids = prompt_ids.to(self.device)
449 | last_hidden_state = self.vpm(image_tensors, text_embedding=self.get_input_embeddings(prompt_ids))
450 |
451 | encoder_outputs = ModelOutput(last_hidden_state=last_hidden_state, attentions=None)
452 | if len(encoder_outputs.last_hidden_state.size()) == 1:
453 | encoder_outputs.last_hidden_state = encoder_outputs.last_hidden_state.unsqueeze(0)
454 |
455 | # get decoder output
456 | decoder_output = self.llm.model.generate(
457 | input_ids=prompt_ids,
458 | encoder_outputs=encoder_outputs,
459 | min_length=1,
460 | max_length=self.config.max_length,
461 | pad_token_id=self.llm.tokenizer.pad_token_id,
462 | eos_token_id=self.llm.tokenizer.eos_token_id,
463 | use_cache=True,
464 | return_dict_in_generate=True,
465 | output_scores=True,
466 | output_attentions=return_attentions,
467 | do_sample=False,
468 | num_beams=1,
469 | stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()] if early_stopping else []),
470 | )
471 |
472 | output["repetitions"] = decoder_output.sequences.clone()
473 | output["sequences"] = decoder_output.sequences.clone()
474 | output["scores"] = torch.stack(decoder_output.scores, 1).softmax(-1).cpu().max(-1)[0]
475 |
476 | output["repetitions"] = self.llm.tokenizer.batch_decode(output["repetitions"], skip_special_tokens=False)
477 | return output
478 |
--------------------------------------------------------------------------------
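
To make the wiring between `SwinEncoder`, `BARTDecoder`, and `DonutModel` concrete, here is a minimal sketch with randomly initialized weights; the `facebook/mbart-large-50` tokenizer is an assumption for illustration, and in practice the pretrained Dolphin checkpoint and its own tokenizer are loaded as in the repo's demo scripts.

```python
# Sketch only: wires the pieces together with random weights so the shapes line up.
import torch
from transformers import AutoTokenizer

from utils.model import DonutConfig, DonutModel, SwinEncoder

tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-50")  # assumed stand-in tokenizer

vision_tower = SwinEncoder(
    input_size=[896, 896],   # matches DolphinProcessor's default input_size
    window_size=7,
    encoder_layer=[2, 2, 14, 2],
    embed_dim=128,
    num_heads=[4, 8, 16, 32],
)

config = DonutConfig(decoder_layer=10, max_length=4096, hidden_dimension=1024)
model = DonutModel(config=config, vision_tower=vision_tower, tokenizer=tokenizer).eval()

# Encoder alone: a padded 896x896 page becomes a sequence of patch features, which
# inference() later passes to the MBart decoder as encoder_hidden_states.
pixel_values = torch.randn(1, 3, 896, 896)
with torch.no_grad():
    features = vision_tower(pixel_values)
print(features.shape)  # e.g. (1, 784, 1024) with timm 0.5.x; newer timm returns NHWC feature maps
```
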
/utils/processor.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
3 | SPDX-License-Identifier: MIT
4 | """
5 |
6 | import numpy as np
7 | import torch
8 | from PIL import ImageOps
9 |
10 | from utils.utils import *
11 |
12 |
13 | class DolphinProcessor:
14 | def __init__(
15 | self,
16 | dp_config,
17 | tokenizer,
18 | **kwargs,
19 | ) -> None:
20 |
21 | self.tokenizer = tokenizer
22 | transform_args = kwargs.get("transform_args", {})
23 | self.max_length = transform_args.get("max_length", 2048)
24 | self.input_size = transform_args.get("input_size", [896, 896]) # height, width
25 | if isinstance(self.input_size, int):
26 | self.input_size = [self.input_size, self.input_size]
27 |
28 | try:
29 | self.answer_start_token = self.tokenizer._prompt_end_token
30 | except AttributeError as err:
31 | print('No answer_start_token found, use "" instead')
32 | self.answer_start_token = ""
33 |
34 | self.prefix_answer_space_flag = dp_config.get("prefix_answer_space_flag", True)
35 | self.suffix_prompt_space_flag = dp_config.get("suffix_prompt_space_flag", True)
36 |
37 | def process_prompt_for_inference(self, prompt):
38 | prompt = prompt.replace("\n", "")
39 | if not prompt.startswith("<s>"):
40 | prompt = "<s>" + prompt
41 | message_ids = [self.tokenizer.encode(prompt, add_special_tokens=False)]
42 | ids = torch.from_numpy(np.hstack(message_ids, dtype=np.int32))
43 | return ids.unsqueeze(0)
44 |
45 | def process_image_for_inference(self, image, return_img_size=False):
46 | image = resize(image, min(self.input_size))
47 |
48 | image.thumbnail((self.input_size[1], self.input_size[0]))
49 | origin_w, origin_h = image.size
50 |
51 | delta_width = self.input_size[1] - image.width
52 | delta_height = self.input_size[0] - image.height
53 | pad_width = delta_width // 2
54 | pad_height = delta_height // 2
55 | padding = (
56 | pad_width,
57 | pad_height,
58 | delta_width - pad_width,
59 | delta_height - pad_height,
60 | )
61 | image = ImageOps.expand(image, padding)
62 | if return_img_size:
63 | return test_transform(image).unsqueeze(0), (origin_w, origin_h)
64 | return test_transform(image).unsqueeze(0)
65 |
--------------------------------------------------------------------------------
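
A small sketch of what `DolphinProcessor.process_image_for_inference` produces; the dummy tokenizer and the demo image path are only for illustration.

```python
# Sketch: image preprocessing only, with a stand-in tokenizer object.
from PIL import Image

from utils.processor import DolphinProcessor


class DummyTokenizer:
    """Stand-in exposing only the attributes DolphinProcessor touches."""
    _prompt_end_token = ""
    pad_token_id = 1

    def encode(self, text, add_special_tokens=False):
        return [0]


processor = DolphinProcessor(dp_config={}, tokenizer=DummyTokenizer())

image = Image.open("demo/page_imgs/page_1.jpeg").convert("RGB")
tensor, (w, h) = processor.process_image_for_inference(image, return_img_size=True)

print(tensor.shape)  # torch.Size([1, 3, 896, 896]): resized, centre-padded, normalized
print((w, h))        # size of the resized image before padding was added
```
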
/utils/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
3 | SPDX-License-Identifier: MIT
4 | """
5 |
6 | import copy
7 | import json
8 | import os
9 | import re
10 | from dataclasses import dataclass
11 | from typing import List, Tuple
12 |
13 | import albumentations as alb
14 | import cv2
15 | import numpy as np
16 | from albumentations.pytorch import ToTensorV2
17 | from PIL import Image
18 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
19 | from torchvision.transforms.functional import resize
20 |
21 | from utils.markdown_utils import MarkdownConverter
22 |
23 |
24 | def alb_wrapper(transform):
25 | def f(im):
26 | return transform(image=np.asarray(im))["image"]
27 |
28 | return f
29 |
30 |
31 | test_transform = alb_wrapper(
32 | alb.Compose(
33 | [
34 | alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
35 | ToTensorV2(),
36 | ]
37 | )
38 | )
39 |
40 |
41 | def check_coord_valid(x1, y1, x2, y2, image_size=None, abs_coord=True):
42 | # print(f"check_coord_valid: {x1}, {y1}, {x2}, {y2}, {image_size}, {abs_coord}")
43 | if x2 <= x1 or y2 <= y1:
44 | return False, f"[{x1}, {y1}, {x2}, {y2}]"
45 | if x1 < 0 or y1 < 0:
46 | return False, f"[{x1}, {y1}, {x2}, {y2}]"
47 | if not abs_coord:
48 | if x2 > 1 or y2 > 1:
49 | return False, f"[{x1}, {y1}, {x2}, {y2}]"
50 | elif image_size is not None: # has image size
51 | if x2 > image_size[0] or y2 > image_size[1]:
52 | return False, f"[{x1}, {y1}, {x2}, {y2}]"
53 | return True, None
54 |
55 |
56 | def adjust_box_edges(image, boxes: List[List[float]], max_pixels=15, threshold=0.2):
57 | """
58 | image: cv2 image (numpy array), or a path to an image file
59 | boxes: list of boxes [[x1, y1, x2, y2]], in absolute coordinates
60 | """
61 | if isinstance(image, str):
62 | image = cv2.imread(image)
63 | img_h, img_w = image.shape[:2]
64 | new_boxes = []
65 | for box in boxes:
66 | best_box = copy.deepcopy(box)
67 |
68 | def check_edge(img, current_box, i, is_vertical):
69 | edge = current_box[i]
70 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
71 | _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
72 |
73 | if is_vertical:
74 | line = binary[current_box[1] : current_box[3] + 1, edge]
75 | else:
76 | line = binary[edge, current_box[0] : current_box[2] + 1]
77 |
78 | transitions = np.abs(np.diff(line))
79 | return np.sum(transitions) / len(transitions)
80 |
81 | # Only widen the box
82 | edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)]
83 |
84 | current_box = copy.deepcopy(box)
85 | # make sure the box is within the image
86 | current_box[0] = min(max(current_box[0], 0), img_w - 1)
87 | current_box[1] = min(max(current_box[1], 0), img_h - 1)
88 | current_box[2] = min(max(current_box[2], 0), img_w - 1)
89 | current_box[3] = min(max(current_box[3], 0), img_h - 1)
90 |
91 | for i, direction, is_vertical in edges:
92 | best_score = check_edge(image, current_box, i, is_vertical)
93 | if best_score <= threshold:
94 | continue
95 | for step in range(max_pixels):
96 | current_box[i] += direction
97 | if i == 0 or i == 2:
98 | current_box[i] = min(max(current_box[i], 0), img_w - 1)
99 | else:
100 | current_box[i] = min(max(current_box[i], 0), img_h - 1)
101 | score = check_edge(image, current_box, i, is_vertical)
102 |
103 | if score < best_score:
104 | best_score = score
105 | best_box = copy.deepcopy(current_box)
106 |
107 | if score <= threshold:
108 | break
109 | new_boxes.append(best_box)
110 |
111 | return new_boxes
112 |
113 |
114 | def parse_layout_string(bbox_str):
115 | """Parse layout string using regular expressions"""
116 | pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)"
117 | matches = re.finditer(pattern, bbox_str)
118 |
119 | parsed_results = []
120 | for match in matches:
121 | coords = [float(match.group(i)) for i in range(1, 5)]
122 | label = match.group(5).strip()
123 | parsed_results.append((coords, label))
124 |
125 | return parsed_results
126 |
127 |
128 | @dataclass
129 | class ImageDimensions:
130 | """Class to store image dimensions"""
131 | original_w: int
132 | original_h: int
133 | padded_w: int
134 | padded_h: int
135 |
136 |
137 | def map_to_original_coordinates(x1, y1, x2, y2, dims: ImageDimensions) -> Tuple[int, int, int, int]:
138 | """Map coordinates from padded image back to original image
139 |
140 | Args:
141 | x1, y1, x2, y2: Coordinates in padded image
142 | dims: Image dimensions object
143 |
144 | Returns:
145 | tuple: (x1, y1, x2, y2) coordinates in original image
146 | """
147 | try:
148 | # Calculate padding offsets
149 | top = (dims.padded_h - dims.original_h) // 2
150 | left = (dims.padded_w - dims.original_w) // 2
151 |
152 | # Map back to original coordinates
153 | orig_x1 = max(0, x1 - left)
154 | orig_y1 = max(0, y1 - top)
155 | orig_x2 = min(dims.original_w, x2 - left)
156 | orig_y2 = min(dims.original_h, y2 - top)
157 |
158 | # Ensure we have a valid box (width and height > 0)
159 | if orig_x2 <= orig_x1:
160 | orig_x2 = min(orig_x1 + 1, dims.original_w)
161 | if orig_y2 <= orig_y1:
162 | orig_y2 = min(orig_y1 + 1, dims.original_h)
163 |
164 | return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2)
165 | except Exception as e:
166 | print(f"map_to_original_coordinates error: {str(e)}")
167 | # Return safe coordinates
168 | return 0, 0, min(100, dims.original_w), min(100, dims.original_h)
169 |
170 |
171 | def map_to_relevant_coordinates(abs_coords, dims: ImageDimensions):
172 | """
173 | From absolute coordinates to relative coordinates
174 | e.g. for a 1000x500 image, [100, 100, 200, 200] -> [0.1, 0.2, 0.2, 0.4]
175 | """
176 | try:
177 | x1, y1, x2, y2 = abs_coords
178 | return round(x1 / dims.original_w, 3), round(y1 / dims.original_h, 3), round(x2 / dims.original_w, 3), round(y2 / dims.original_h, 3)
179 | except Exception as e:
180 | print(f"map_to_relevant_coordinates error: {str(e)}")
181 | return 0.0, 0.0, 1.0, 1.0 # Return full image coordinates
182 |
183 |
184 | def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None):
185 | """Process and adjust coordinates
186 |
187 | Args:
188 | coords: Normalized coordinates [x1, y1, x2, y2]
189 | padded_image: Padded image
190 | dims: Image dimensions object
191 | previous_box: Previous box coordinates for overlap adjustment
192 |
193 | Returns:
194 | tuple: (x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box)
195 | """
196 | try:
197 | # Convert normalized coordinates to absolute coordinates
198 | x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h)
199 | x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h)
200 |
201 | # Ensure coordinates are within image bounds before adjustment
202 | x1 = max(0, min(x1, dims.padded_w - 1))
203 | y1 = max(0, min(y1, dims.padded_h - 1))
204 | x2 = max(0, min(x2, dims.padded_w))
205 | y2 = max(0, min(y2, dims.padded_h))
206 |
207 | # Ensure width and height are at least 1 pixel
208 | if x2 <= x1:
209 | x2 = min(x1 + 1, dims.padded_w)
210 | if y2 <= y1:
211 | y2 = min(y1 + 1, dims.padded_h)
212 |
213 | # Extend box boundaries
214 | new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, y2]])
215 | x1, y1, x2, y2 = new_boxes[0]
216 |
217 | # Ensure coordinates are still within image bounds after adjustment
218 | x1 = max(0, min(x1, dims.padded_w - 1))
219 | y1 = max(0, min(y1, dims.padded_h - 1))
220 | x2 = max(0, min(x2, dims.padded_w))
221 | y2 = max(0, min(y2, dims.padded_h))
222 |
223 | # Ensure width and height are at least 1 pixel after adjustment
224 | if x2 <= x1:
225 | x2 = min(x1 + 1, dims.padded_w)
226 | if y2 <= y1:
227 | y2 = min(y1 + 1, dims.padded_h)
228 |
229 | # Check for overlap with previous box and adjust
230 | if previous_box is not None:
231 | prev_x1, prev_y1, prev_x2, prev_y2 = previous_box
232 | if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1):
233 | y1 = prev_y2
234 | # Ensure y1 is still valid
235 | y1 = min(y1, dims.padded_h - 1)
236 | # Make sure y2 is still greater than y1
237 | if y2 <= y1:
238 | y2 = min(y1 + 1, dims.padded_h)
239 |
240 | # Update previous box
241 | new_previous_box = [x1, y1, x2, y2]
242 |
243 | # Map to original coordinates
244 | orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates(
245 | x1, y1, x2, y2, dims
246 | )
247 |
248 | return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box
249 | except Exception as e:
250 | print(f"process_coordinates error: {str(e)}")
251 | # Return safe values
252 | orig_x1, orig_y1, orig_x2, orig_y2 = 0, 0, min(100, dims.original_w), min(100, dims.original_h)
253 | return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100]
254 |
255 |
256 | def prepare_image(image) -> Tuple[np.ndarray, ImageDimensions]:
257 | """Load and prepare image with padding while maintaining aspect ratio
258 |
259 | Args:
260 | image: PIL image
261 |
262 | Returns:
263 | tuple: (padded_image, image_dimensions)
264 | """
265 | try:
266 | # Convert PIL image to OpenCV format
267 | image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
268 | original_h, original_w = image.shape[:2]
269 |
270 | # Calculate padding to make square image
271 | max_size = max(original_h, original_w)
272 | top = (max_size - original_h) // 2
273 | bottom = max_size - original_h - top
274 | left = (max_size - original_w) // 2
275 | right = max_size - original_w - left
276 |
277 | # Apply padding
278 | padded_image = cv2.copyMakeBorder(image, top, bottom, left, right,
279 | cv2.BORDER_CONSTANT, value=(0, 0, 0))
280 |
281 | padded_h, padded_w = padded_image.shape[:2]
282 |
283 | dimensions = ImageDimensions(
284 | original_w=original_w,
285 | original_h=original_h,
286 | padded_w=padded_w,
287 | padded_h=padded_h
288 | )
289 |
290 | return padded_image, dimensions
291 | except Exception as e:
292 | print(f"prepare_image error: {str(e)}")
293 | # Create a minimal valid image and dimensions
294 | h, w = image.height, image.width
295 | dimensions = ImageDimensions(
296 | original_w=w,
297 | original_h=h,
298 | padded_w=w,
299 | padded_h=h
300 | )
301 | # Return a black image of the same size
302 | return np.zeros((h, w, 3), dtype=np.uint8), dimensions
303 |
304 |
305 |
306 |
307 | def setup_output_dirs(save_dir):
308 | """Create necessary output directories"""
309 | os.makedirs(save_dir, exist_ok=True)
310 | os.makedirs(os.path.join(save_dir, "markdown"), exist_ok=True)
311 | os.makedirs(os.path.join(save_dir, "recognition_json"), exist_ok=True)
312 |
313 |
314 | def save_outputs(recognition_results, image_path, save_dir):
315 | """Save JSON and markdown outputs"""
316 | basename = os.path.splitext(os.path.basename(image_path))[0]
317 |
318 | # Save JSON file
319 | json_path = os.path.join(save_dir, "recognition_json", f"{basename}.json")
320 | with open(json_path, "w", encoding="utf-8") as f:
321 | json.dump(recognition_results, f, ensure_ascii=False, indent=2)
322 |
323 | # Generate and save markdown file
324 | markdown_converter = MarkdownConverter()
325 | markdown_content = markdown_converter.convert(recognition_results)
326 | markdown_path = os.path.join(save_dir, "markdown", f"{basename}.md")
327 | with open(markdown_path, "w", encoding="utf-8") as f:
328 | f.write(markdown_content)
329 |
330 | return json_path
331 |
332 |
333 | def crop_margin(img: Image.Image) -> Image.Image:
334 | """Crop margins from image"""
335 | try:
336 | width, height = img.size
337 | if width == 0 or height == 0:
338 | print("Warning: Image has zero width or height")
339 | return img
340 |
341 | data = np.array(img.convert("L"))
342 | data = data.astype(np.uint8)
343 | max_val = data.max()
344 | min_val = data.min()
345 | if max_val == min_val:
346 | return img
347 | data = (data - min_val) / (max_val - min_val) * 255
348 | gray = 255 * (data < 200).astype(np.uint8)
349 |
350 | coords = cv2.findNonZero(gray) # Find all non-zero points (text)
351 | if coords is None:
352 | return img
353 | a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
354 |
355 | # Ensure crop coordinates are within image bounds
356 | a = max(0, a)
357 | b = max(0, b)
358 | w = min(w, width - a)
359 | h = min(h, height - b)
360 |
361 | # Only crop if we have a valid region
362 | if w > 0 and h > 0:
363 | return img.crop((a, b, a + w, b + h))
364 | return img
365 | except Exception as e:
366 | print(f"crop_margin error: {str(e)}")
367 | return img # Return original image on error
368 |
--------------------------------------------------------------------------------
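
Finally, a short sketch of how the helpers in `utils/utils.py` fit together when post-processing a layout prediction; the layout string below is made up for illustration, while real strings come from the model's layout-parsing stage.

```python
# Sketch: from a (made-up) layout string to crops and normalized boxes on the original page.
from PIL import Image

from utils.utils import (
    map_to_relevant_coordinates,
    parse_layout_string,
    prepare_image,
    process_coordinates,
)

image = Image.open("demo/page_imgs/page_1.jpeg").convert("RGB")
padded_image, dims = prepare_image(image)          # square-padded BGR array + size bookkeeping

layout_str = "[0.1, 0.05, 0.9, 0.12] title [0.1, 0.15, 0.9, 0.55] para"
previous_box = None
for coords, label in parse_layout_string(layout_str):
    x1, y1, x2, y2, ox1, oy1, ox2, oy2, previous_box = process_coordinates(
        coords, padded_image, dims, previous_box
    )
    crop = padded_image[y1:y2, x1:x2]              # element crop in the padded image
    rel_box = map_to_relevant_coordinates((ox1, oy1, ox2, oy2), dims)
    print(label, (ox1, oy1, ox2, oy2), rel_box, crop.shape)
```
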