├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── README_CN.md
├── assets
│   ├── demo.gif
│   ├── dolphin.png
│   └── framework.png
├── chat.py
├── config
│   └── Dolphin.yaml
├── demo
│   ├── element_imgs
│   │   ├── block_formula.jpeg
│   │   ├── line_formula.jpeg
│   │   ├── para_1.jpg
│   │   ├── para_2.jpg
│   │   ├── para_3.jpeg
│   │   ├── table_1.jpeg
│   │   └── table_2.jpeg
│   └── page_imgs
│       ├── page_1.jpeg
│       ├── page_2.jpeg
│       ├── page_3.jpeg
│       ├── page_4.png
│       ├── page_5.jpg
│       ├── page_6.pdf
│       └── page_7.jpeg
├── demo_element.py
├── demo_element_hf.py
├── demo_page.py
├── demo_page_hf.py
├── deployment
│   ├── ReadMe.md
│   ├── tensorrt_llm
│   │   ├── ReadMe.md
│   │   ├── api_client.py
│   │   ├── api_server.py
│   │   ├── convert
│   │   │   ├── __init__.py
│   │   │   ├── build_visual_engine.py
│   │   │   ├── convert_checkpoint.py
│   │   │   └── helper.py
│   │   ├── convert_dolphin.sh
│   │   ├── dolphin_runner.py
│   │   ├── run_dolphin.py
│   │   ├── run_dolphin.sh
│   │   ├── start_dolphin_server.sh
│   │   └── utils.py
│   └── vllm
│       ├── ReadMe.md
│       ├── api_client.py
│       ├── api_server.py
│       └── demo_vllm.py
├── pyproject.toml
├── requirements.txt
└── utils
    ├── markdown_utils.py
    ├── model.py
    ├── processor.py
    └── utils.py


/.gitignore:
--------------------------------------------------------------------------------
  1 | # Byte-compiled / optimized / DLL files
  2 | __pycache__/
  3 | *.py[cod]
  4 | *$py.class
  5 | 
  6 | # C extensions
  7 | *.so
  8 | 
  9 | # Distribution / packaging
 10 | .Python
 11 | build/
 12 | develop-eggs/
 13 | dist/
 14 | downloads/
 15 | eggs/
 16 | .eggs/
 17 | lib/
 18 | lib64/
 19 | parts/
 20 | sdist/
 21 | var/
 22 | wheels/
 23 | *.egg-info/
 24 | .installed.cfg
 25 | *.egg
 26 | MANIFEST
 27 | 
 28 | # PyInstaller
 29 | #  Usually these files are written by a python script from a template
 30 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
 31 | *.manifest
 32 | *.spec
 33 | 
 34 | # Installer logs
 35 | pip-log.txt
 36 | pip-delete-this-directory.txt
 37 | 
 38 | # Unit test / coverage reports
 39 | htmlcov/
 40 | .tox/
 41 | .nox/
 42 | .coverage
 43 | *.cover
 44 | *.py,cover
 45 | .hypothesis/
 46 | .pytest_cache/
 47 | coverage.xml
 48 | *.mo
 49 | *.pot
 50 | 
 51 | # Translations
 52 | *.mo
 53 | *.pot
 54 | 
 55 | # Django stuff:
 56 | *.log
 57 | local_settings.py
 58 | db.sqlite3
 59 | db.sqlite3-journal
 60 | 
 61 | # Flask stuff:
 62 | instance/
 63 | .webassets-cache
 64 | 
 65 | # Scrapy stuff:
 66 | .scrapy
 67 | 
 68 | # Sphinx documentation
 69 | docs/_build/
 70 | 
 71 | # PyBuilder
 72 | target/
 73 | 
 74 | # Jupyter Notebook
 75 | .ipynb_checkpoints
 76 | 
 77 | # IPython
 78 | profile_default/
 79 | ipython_config.py
 80 | 
 81 | # pyenv
 82 | .python-version
 83 | 
 84 | # pipenv
 85 | #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
 86 | #   However, in case of collaboration, if having platform-specific dependencies or dependencies
 87 | #   having no cross-platform support, pipenv may install dependencies that don't work, or not
 88 | #   install all needed dependencies.
 89 | #Pipfile.lock
 90 | 
 91 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
 92 | __pypackages__/
 93 | 
 94 | # Celery stuff
 95 | celerybeat-schedule
 96 | celerybeat.pid
 97 | 
 98 | # SageMath parsed files
 99 | *.sage.py
100 | 
101 | # Environments
102 | .env
103 | .venv
104 | env/
105 | venv/
106 | ENV/
107 | env.bak/
108 | venv.bak/
109 | 
110 | # Spyder project settings
111 | .spyderproject
112 | .spyproject
113 | 
114 | # Rope project settings
115 | .ropeproject
116 | 
117 | # mkdocs documentation
118 | /site
119 | 
120 | # mypy
121 | .mypy_cache/
122 | .dmypy.json
123 | dmypy.json
124 | 
125 | # Pyre type checker
126 | .pyre/
127 | 
128 | # pytype static type analyzer
129 | .pytype/
130 | 
131 | # Cython debug symbols
132 | cython_debug/
133 | 
134 | # PyCharm
135 | .idea/
136 | *.iml
137 | 
138 | # VS Code
139 | .vscode/
140 | !.vscode/settings.json
141 | !.vscode/tasks.json
142 | !.vscode/launch.json
143 | !.vscode/extensions.json
144 | 
145 | # macOS
146 | .DS_Store
147 | 
148 | # Windows
149 | Thumbs.db
150 | ehthumbs.db
151 | Desktop.ini
152 | 
153 | fusion_result.json
154 | kernel_meta/
155 | 


--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
 1 | repos:
 2 |   # 1. isort - automatically sorts Python imports
 3 |   - repo: https://github.com/pycqa/isort
 4 |     rev: 6.0.1  # pin to a fixed version
 5 |     hooks:
 6 |       - id: isort
 7 |         name: isort (python)
 8 |         args: [--profile=black]  # Black-compatible profile
 9 |         language: python
10 | 
11 |   # 2. Black - automatically formats Python code
12 |   - repo: https://github.com/psf/black
13 |     rev: 25.1.0  # pin to a fixed version
14 |     hooks:
15 |       - id: black
16 |         language: python
17 | 
18 |   # 3. flake8 - static analysis for Python
19 |   - repo: https://github.com/pycqa/flake8
20 |     rev: 7.2.0
21 |     hooks:
22 |       - id: flake8
23 |         args: [--max-line-length=120, --ignore=E203]  # set max line length to 120
24 |         additional_dependencies: [flake8-bugbear==24.12.12]  # optional: extra checks
25 | 
26 |   # 4. pre-commit-hooks - general-purpose Git hooks
27 |   - repo: https://github.com/pre-commit/pre-commit-hooks
28 |     rev: v5.0.0
29 |     hooks:
30 |       - id: trailing-whitespace  # strip trailing whitespace
31 |       - id: end-of-file-fixer    # ensure files end with a newline
32 |       - id: check-yaml           # validate YAML syntax
33 |       - id: check-added-large-files  # block large files from being committed
34 |         args: ["--maxkb=512"]
35 | 


--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright 2025 ByteDance Ltd. and/or its affiliates
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 | 
7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 | 
9 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | <div align="center">
  2 |   <img src="./assets/dolphin.png" width="300">
  3 | </div>
  4 | 
  5 | <div align="center">
  6 |   <a href="https://arxiv.org/abs/2505.14059">
  7 |     <img src="https://img.shields.io/badge/Paper-arXiv-red">
  8 |   </a>
  9 |   <a href="https://huggingface.co/ByteDance/Dolphin">
 10 |     <img src="https://img.shields.io/badge/HuggingFace-Dolphin-yellow">
 11 |   </a>
 12 |   <a href="https://modelscope.cn/models/ByteDance/Dolphin">
 13 |     <img src="https://img.shields.io/badge/ModelScope-Dolphin-purple">
 14 |   </a>
 15 |   <a href="http://115.190.42.15:8888/dolphin/">
 16 |     <img src="https://img.shields.io/badge/Demo-Dolphin-blue">
 17 |   </a>
 18 |   <a href="https://github.com/bytedance/Dolphin">
 19 |     <img src="https://img.shields.io/badge/Code-Github-green">
 20 |   </a>
 21 |   <a href="https://opensource.org/licenses/MIT">
 22 |     <img src="https://img.shields.io/badge/License-MIT-lightgray">
 23 |   </a>
 24 |   <br>
 25 | </div>
 26 | 
 27 | <br>
 28 | 
 29 | <div align="center">
 30 |   <img src="./assets/demo.gif" width="800">
 31 | </div>
 32 | 
 33 | # Dolphin: Document Image Parsing via Heterogeneous Anchor Prompting
 34 | 
 35 | Dolphin (**Do**cument Image **P**arsing via **H**eterogeneous Anchor Prompt**in**g) is a novel multimodal document image parsing model following an analyze-then-parse paradigm. This repository contains the demo code and pre-trained models for Dolphin.
 36 | 
 37 | ## 📑 Overview
 38 | 
  39 | Document image parsing is challenging due to its intricately intertwined elements, such as text paragraphs, figures, formulas, and tables. Dolphin addresses these challenges through a two-stage approach:
  40 | 
  41 | 1. **🔍 Stage 1**: Comprehensive page-level layout analysis by generating an element sequence in natural reading order
 42 | 2. **🧩 Stage 2**: Efficient parallel parsing of document elements using heterogeneous anchors and task-specific prompts
 43 | 
 44 | <div align="center">
 45 |   <img src="./assets/framework.png" width="680">
 46 | </div>
 47 | 
 48 | Dolphin achieves promising performance across diverse page-level and element-level parsing tasks while ensuring superior efficiency through its lightweight architecture and parallel parsing mechanism.
 49 | 
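For concreteness, here is a minimal sketch of the two-stage flow, condensed from `demo_page.py` in this repo. It assumes the config-based setup; the image path is an illustrative demo asset, and real Stage 2 crops are cut from the boxes returned by Stage 1.

```python
from omegaconf import OmegaConf
from PIL import Image

from chat import DOLPHIN

# Load the config-based model (see config/Dolphin.yaml)
model = DOLPHIN(OmegaConf.load("./config/Dolphin.yaml"))
page = Image.open("./demo/page_imgs/page_1.jpeg").convert("RGB")

# Stage 1: page-level layout analysis in natural reading order
layout = model.chat("Parse the reading order of this document.", page)

# Stage 2: parse element crops in parallel with task-specific prompts
crops = [page]  # placeholder: real crops come from the Stage 1 boxes
prompts = ["Read text in the image."]  # "Parse the table in the image." for tables
results = model.chat(prompts, crops, max_batch_size=8)
```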
 50 | ## 🚀 Demo
 51 | Try our demo on [Demo-Dolphin](http://115.190.42.15:8888/dolphin/).
 52 | 
 53 | ## 📅 Changelog
 54 | - 🔥 **2025.07.10** Released the *Fox-Page Benchmark*, a manually refined subset of the original [Fox dataset](https://github.com/ucaslcl/Fox). Download via: [Baidu Yun](https://pan.baidu.com/share/init?surl=t746ULp6iU5bUraVrPlMSw&pwd=fox1) | [Google Drive](https://drive.google.com/file/d/1yZQZqI34QCqvhB4Tmdl3X_XEvYvQyP0q/view?usp=sharing).
 55 | - 🔥 **2025.06.30** Added [TensorRT-LLM support](https://github.com/bytedance/Dolphin/blob/master/deployment/tensorrt_llm/ReadMe.md) for accelerated inference!
 56 | - 🔥 **2025.06.27** Added [vLLM support](https://github.com/bytedance/Dolphin/blob/master/deployment/vllm/ReadMe.md) for accelerated inference!
 57 | - 🔥 **2025.06.13** Added multi-page PDF document parsing capability.
 58 | - 🔥 **2025.05.21** Our demo is released at [link](http://115.190.42.15:8888/dolphin/). Check it out!
 59 | - 🔥 **2025.05.20** The pretrained model and inference code of Dolphin are released.
 60 | - 🔥 **2025.05.16** Our paper has been accepted by ACL 2025. Paper link: [arXiv](https://arxiv.org/abs/2505.14059).
 61 | 
 62 | ## 🛠️ Installation
 63 | 
 64 | 1. Clone the repository:
 65 |    ```bash
 66 |    git clone https://github.com/ByteDance/Dolphin.git
 67 |    cd Dolphin
 68 |    ```
 69 | 
 70 | 2. Install the dependencies:
 71 |    ```bash
 72 |    pip install -r requirements.txt
 73 |    ```
 74 | 
 75 | 3. Download the pre-trained models using one of the following options:
 76 | 
 77 |    **Option A: Original Model Format (config-based)**
 78 |    
 79 |    Download from [Baidu Yun](https://pan.baidu.com/s/15zcARoX0CTOHKbW8bFZovQ?pwd=9rpx) or [Google Drive](https://drive.google.com/drive/folders/1PQJ3UutepXvunizZEw-uGaQ0BCzf-mie?usp=sharing) and put them in the `./checkpoints` folder.
 80 | 
 81 |    **Option B: Hugging Face Model Format**
 82 |    
 83 |    Visit our Huggingface [model card](https://huggingface.co/ByteDance/Dolphin), or download model by:
 84 |    
 85 |    ```bash
 86 |    # Download the model from Hugging Face Hub
 87 |    git lfs install
 88 |    git clone https://huggingface.co/ByteDance/Dolphin ./hf_model
 89 |    # Or use the Hugging Face CLI
 90 |    pip install huggingface_hub
 91 |    huggingface-cli download ByteDance/Dolphin --local-dir ./hf_model
 92 |    ```
 93 | 
 94 | ## ⚡ Inference
 95 | 
 96 | Dolphin provides two inference frameworks with support for two parsing granularities:
  97 | - **Page-level Parsing**: Parse the entire document page into structured JSON and Markdown formats
 98 | - **Element-level Parsing**: Parse individual document elements (text, table, formula)
 99 | 
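Beyond the CLI entry points below, here is a minimal element-level sketch using the Hugging Face weights, condensed from `demo_element_hf.py`. It assumes a CUDA device (the demo code runs in half precision) and the `./hf_model` download from the installation step:

```python
import torch
from PIL import Image
from transformers import AutoProcessor, VisionEncoderDecoderModel

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained("./hf_model")
model = VisionEncoderDecoderModel.from_pretrained("./hf_model").half().eval().to(device)

image = Image.open("./demo/element_imgs/para_1.jpg").convert("RGB")
pixel_values = processor(image, return_tensors="pt").pixel_values.half().to(device)

# Dolphin prompts follow the "<s>{prompt} <Answer/>" format
prompt_ids = processor.tokenizer(
    "<s>Read text in the image. <Answer/>", add_special_tokens=False, return_tensors="pt"
).input_ids.to(device)

outputs = model.generate(
    pixel_values=pixel_values,
    decoder_input_ids=prompt_ids,
    max_length=4096,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
)
print(processor.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```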
100 | ### 📄 Page-level Parsing
101 | 
102 | #### Using Original Framework (config-based)
103 | 
104 | ```bash
105 | # Process a single document image
106 | python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs/page_1.jpeg --save_dir ./results
107 | 
108 | # Process a single document pdf
109 | python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs/page_6.pdf --save_dir ./results
110 | 
111 | # Process all documents in a directory
112 | python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs --save_dir ./results
113 | 
114 | # Process with custom batch size for parallel element decoding
115 | python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs --save_dir ./results --max_batch_size 8
116 | ```
117 | 
118 | #### Using Hugging Face Framework
119 | 
120 | ```bash
121 | # Process a single document image
122 | python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs/page_1.jpeg --save_dir ./results
123 | 
124 | # Process a single document pdf
125 | python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs/page_6.pdf --save_dir ./results
126 | 
127 | # Process all documents in a directory
128 | python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs --save_dir ./results
129 | 
130 | # Process with custom batch size for parallel element decoding
131 | python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs --save_dir ./results --max_batch_size 16
132 | ```
133 | 
134 | ### 🧩 Element-level Parsing
135 | 
136 | #### Using Original Framework (config-based)
137 | 
138 | ```bash
139 | # Process a single table image
140 | python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/table_1.jpeg --element_type table
141 | 
142 | # Process a single formula image
143 | python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/line_formula.jpeg --element_type formula
144 | 
145 | # Process a single text paragraph image
146 | python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/para_1.jpg --element_type text
147 | ```
148 | 
149 | #### Using Hugging Face Framework
150 | 
151 | ```bash
152 | # Process a single table image
153 | python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/table_1.jpeg --element_type table
154 | 
155 | # Process a single formula image
156 | python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/line_formula.jpeg --element_type formula
157 | 
158 | # Process a single text paragraph image
159 | python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/para_1.jpg --element_type text
160 | ```
161 | 
162 | ## 🌟 Key Features
163 | 
164 | - 🔄 Two-stage analyze-then-parse approach based on a single VLM
165 | - 📊 Promising performance on document parsing tasks
166 | - 🔍 Natural reading order element sequence generation
167 | - 🧩 Heterogeneous anchor prompting for different document elements
168 | - ⏱️ Efficient parallel parsing mechanism
169 | - 🤗 Support for Hugging Face Transformers for easier integration
170 | 
171 | 
172 | ## 📮 Notice
173 | **Call for Bad Cases:** If you have encountered any cases where the model performs poorly, we would greatly appreciate it if you could share them in the issue. We are continuously working to optimize and improve the model.
174 | 
175 | ## 💖 Acknowledgement
176 | 
177 | We would like to acknowledge the following open-source projects that provided inspiration and reference for this work:
178 | - [Donut](https://github.com/clovaai/donut/)
179 | - [Nougat](https://github.com/facebookresearch/nougat)
180 | - [GOT](https://github.com/Ucas-HaoranWei/GOT-OCR2.0)
181 | - [MinerU](https://github.com/opendatalab/MinerU/tree/master)
182 | - [Swin](https://github.com/microsoft/Swin-Transformer)
183 | - [Hugging Face Transformers](https://github.com/huggingface/transformers)
184 | 
185 | ## 📝 Citation
186 | 
187 | If you find this code useful for your research, please use the following BibTeX entry.
188 | 
189 | ```bibtex
190 | @article{feng2025dolphin,
191 |   title={Dolphin: Document Image Parsing via Heterogeneous Anchor Prompting},
192 |   author={Feng, Hao and Wei, Shu and Fei, Xiang and Shi, Wei and Han, Yingdong and Liao, Lei and Lu, Jinghui and Wu, Binghong and Liu, Qi and Lin, Chunhui and others},
193 |   journal={arXiv preprint arXiv:2505.14059},
194 |   year={2025}
195 | }
196 | ```
197 | 
198 | ## Star History
199 | 
200 | [![Star History Chart](https://api.star-history.com/svg?repos=bytedance/Dolphin&type=Date)](https://www.star-history.com/#bytedance/Dolphin&Date)
201 | 


--------------------------------------------------------------------------------
/README_CN.md:
--------------------------------------------------------------------------------
  1 | <div align="center">
  2 |   <img src="./assets/dolphin.png" width="300">
  3 | </div>
  4 | 
  5 | <div align="center">
  6 |   <a href="https://arxiv.org/abs/2505.14059">
  7 |     <img src="https://img.shields.io/badge/论文-arXiv-red">
  8 |   </a>
  9 |   <a href="https://huggingface.co/ByteDance/Dolphin">
 10 |     <img src="https://img.shields.io/badge/HuggingFace-Dolphin-yellow">
 11 |   </a>
 12 |   <a href="https://modelscope.cn/models/ByteDance/Dolphin">
 13 |     <img src="https://img.shields.io/badge/ModelScope-Dolphin-purple">
 14 |   </a>
 15 |   <a href="https://huggingface.co/spaces/ByteDance/Dolphin">
 16 |     <img src="https://img.shields.io/badge/演示-Dolphin-blue">
 17 |   </a>
 18 |   <a href="https://github.com/bytedance/Dolphin">
 19 |     <img src="https://img.shields.io/badge/代码-Github-green">
 20 |   </a>
 21 |   <a href="https://opensource.org/licenses/MIT">
 22 |     <img src="https://img.shields.io/badge/许可证-MIT-lightgray">
 23 |   </a>
 24 |   <br>
 25 | </div>
 26 | 
 27 | <br>
 28 | 
 29 | <div align="center">
 30 |   <img src="./assets/demo.gif" width="800">
 31 | </div>
 32 | 
 33 | # Dolphin: 基于异构锚点提示的文档图像解析
 34 | 
 35 | Dolphin(**Do**cument Image **P**arsing via **H**eterogeneous Anchor Prompt**in**g)是一个创新的多模态文档图像解析模型,采用"分析-解析"的两阶段范式。本仓库包含Dolphin的演示代码和预训练模型。
 36 | 
 37 | ## 📑 概述
 38 | 
 39 | 由于文档图像中文本段落、图表、公式和表格等元素的复杂交织,文档图像解析具有挑战性。Dolphin通过两阶段方法解决这些挑战:
 40 | 
 41 | 1. **🔍 第一阶段**:通过按自然阅读顺序生成元素序列进行全面的页面级布局分析
 42 | 2. **🧩 第二阶段**:使用异构锚点和任务特定提示高效并行解析文档元素
 43 | 
 44 | <div align="center">
 45 |   <img src="./assets/framework.png" width="680">
 46 | </div>
 47 | 
 48 | Dolphin在多样化的页面级和元素级解析任务中取得了优异的性能,同时通过其轻量级架构和并行解析机制确保了卓越的效率。
 49 | 
 50 | ## 🚀 演示
 51 | 在 [Demo-Dolphin](http://115.190.42.15:8888/dolphin/) 上试用我们的演示。
 52 | 
 53 | ## 📅 更新日志
 54 | - 🔥 **2025.06.30** 新增[TensorRT-LLM](https://github.com/bytedance/Dolphin/blob/master/deployment/tensorrt_llm/ReadMe.md)支持,提升推理速度!
 55 | - 🔥 **2025.06.27** 新增[vLLM](https://github.com/bytedance/Dolphin/blob/master/deployment/vllm/ReadMe.md)支持,提升推理速度!
 56 | - 🔥 **2025.06.13** 新增多页PDF文档解析功能。
 57 | - 🔥 **2025.05.21** 我们的演示已在 [链接](http://115.190.42.15:8888/dolphin/) 发布。快来体验吧!
 58 | - 🔥 **2025.05.20** Dolphin的预训练模型和推理代码已发布。
 59 | - 🔥 **2025.05.16** 我们的论文已被ACL 2025接收。论文链接:[arXiv](https://arxiv.org/abs/2505.14059)。
 60 | 
 61 | ## 🛠️ 安装
 62 | 
 63 | 1. 克隆仓库:
 64 |    ```bash
 65 |    git clone https://github.com/ByteDance/Dolphin.git
 66 |    cd Dolphin
 67 |    ```
 68 | 
 69 | 2. 安装依赖:
 70 |    ```bash
 71 |    pip install -r requirements.txt
 72 |    ```
 73 | 
 74 | 3. 使用以下选项之一下载预训练模型:
 75 | 
 76 |    **选项A:原始模型格式(基于配置文件)**
 77 |    
 78 |    从 [百度网盘](https://pan.baidu.com/s/15zcARoX0CTOHKbW8bFZovQ?pwd=9rpx) 或 [Google Drive](https://drive.google.com/drive/folders/1PQJ3UutepXvunizZEw-uGaQ0BCzf-mie?usp=sharing) 下载,并将其放在 `./checkpoints` 文件夹中。
 79 | 
 80 |    **选项B:Hugging Face模型格式**
 81 |    
 82 |    访问我们的Huggingface [模型卡片](https://huggingface.co/ByteDance/Dolphin),或通过以下方式下载模型:
 83 |    
 84 |    ```bash
 85 |    # 从Hugging Face Hub下载模型
 86 |    git lfs install
 87 |    git clone https://huggingface.co/ByteDance/Dolphin ./hf_model
 88 |    # 或使用Hugging Face CLI
 89 |    pip install huggingface_hub
 90 |    huggingface-cli download ByteDance/Dolphin --local-dir ./hf_model
 91 |    ```
 92 | 
 93 | ## ⚡ 推理
 94 | 
 95 | Dolphin提供两个推理框架,支持两种解析粒度:
 96 | - **页面级解析**:将整个文档页面解析为结构化的JSON和Markdown格式
 97 | - **元素级解析**:解析单个文档元素(文本、表格、公式)
 98 | 
 99 | ### 📄 页面级解析
100 | 
101 | #### 使用原始框架(基于配置文件)
102 | 
103 | ```bash
104 | # 处理单个文档图像
105 | python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs/page_1.jpeg --save_dir ./results
106 | 
107 | # 处理单个文档PDF
108 | python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs/page_6.pdf --save_dir ./results
109 | 
110 | # 处理目录中的所有文档
111 | python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs --save_dir ./results
112 | 
113 | # 使用自定义批次大小进行并行元素解码
114 | python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs --save_dir ./results --max_batch_size 8
115 | ```
116 | 
117 | #### 使用Hugging Face框架
118 | 
119 | ```bash
120 | # 处理单个文档图像
121 | python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs/page_1.jpeg --save_dir ./results
122 | 
123 | # 处理单个文档PDF
124 | python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs/page_6.pdf --save_dir ./results
125 | 
126 | # 处理目录中的所有文档
127 | python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs --save_dir ./results
128 | 
129 | # 使用自定义批次大小进行并行元素解码
130 | python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs --save_dir ./results --max_batch_size 16
131 | ```
132 | 
133 | ### 🧩 元素级解析
134 | 
135 | #### 使用原始框架(基于配置文件)
136 | 
137 | ```bash
138 | # 处理单个表格图像
139 | python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/table_1.jpeg --element_type table
140 | 
141 | # 处理单个公式图像
142 | python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/line_formula.jpeg --element_type formula
143 | 
144 | # 处理单个文本段落图像
145 | python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/para_1.jpg --element_type text
146 | ```
147 | 
148 | #### 使用Hugging Face框架
149 | 
150 | ```bash
151 | # 处理单个表格图像
152 | python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/table_1.jpeg --element_type table
153 | 
154 | # 处理单个公式图像
155 | python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/line_formula.jpeg --element_type formula
156 | 
157 | # 处理单个文本段落图像
158 | python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/para_1.jpg --element_type text
159 | ```
160 | 
161 | ## 🌟 主要特性
162 | 
163 | - 🔄 基于单一VLM的两阶段分析-解析方法
164 | - 📊 在文档解析任务上的优异性能
165 | - 🔍 自然阅读顺序元素序列生成
166 | - 🧩 针对不同文档元素的异构锚点提示
167 | - ⏱️ 高效的并行解析机制
168 | - 🤗 支持Hugging Face Transformers,便于集成
169 | 
170 | ## 📮 通知
171 | **征集不良案例:** 如果您遇到模型表现不佳的案例,我们非常欢迎您在issue中分享。我们正在持续优化和改进模型。
172 | 
173 | ## 💖 致谢
174 | 
175 | 我们要感谢以下开源项目为本工作提供的灵感和参考:
176 | - [Donut](https://github.com/clovaai/donut/)
177 | - [Nougat](https://github.com/facebookresearch/nougat)
178 | - [GOT](https://github.com/Ucas-HaoranWei/GOT-OCR2.0)
179 | - [MinerU](https://github.com/opendatalab/MinerU/tree/master)
180 | - [Swin](https://github.com/microsoft/Swin-Transformer)
181 | - [Hugging Face Transformers](https://github.com/huggingface/transformers)
182 | 
183 | ## 📝 引用
184 | 
185 | 如果您在研究中发现此代码有用,请使用以下BibTeX条目。
186 | 
187 | ```bibtex
188 | @article{feng2025dolphin,
189 |   title={Dolphin: Document Image Parsing via Heterogeneous Anchor Prompting},
190 |   author={Feng, Hao and Wei, Shu and Fei, Xiang and Shi, Wei and Han, Yingdong and Liao, Lei and Lu, Jinghui and Wu, Binghong and Liu, Qi and Lin, Chunhui and others},
191 |   journal={arXiv preprint arXiv:2505.14059},
192 |   year={2025}
193 | }
194 | ```
195 | 
196 | ## 星标历史
197 | 
198 | [![Star History Chart](https://api.star-history.com/svg?repos=bytedance/Dolphin&type=Date)](https://www.star-history.com/#bytedance/Dolphin&Date) 
199 | 


--------------------------------------------------------------------------------
/assets/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/assets/demo.gif


--------------------------------------------------------------------------------
/assets/dolphin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/assets/dolphin.png


--------------------------------------------------------------------------------
/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/assets/framework.png


--------------------------------------------------------------------------------
/chat.py:
--------------------------------------------------------------------------------
  1 | """ 
  2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
  3 | SPDX-License-Identifier: MIT
  4 | """
  5 | 
  6 | import os
  7 | import warnings
  8 | from collections import OrderedDict
  9 | 
 10 | from omegaconf import ListConfig
 11 | 
 12 | warnings.filterwarnings("ignore", category=UserWarning)
 13 | warnings.filterwarnings("ignore", category=FutureWarning)
 14 | os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python")
 15 | 
 16 | import torch
 17 | from PIL import Image
 18 | from transformers import PreTrainedTokenizerFast
 19 | 
 20 | from utils.model import DonutConfig, DonutModel, SwinEncoder
 21 | from utils.processor import DolphinProcessor
 22 | 
 23 | 
 24 | def try_rename_legacy_weights(ckpt, output_path=""):
 25 |     if "state_dict" in ckpt.keys():
 26 |         ckpt = ckpt["state_dict"]
 27 |     if "module" in ckpt.keys():
 28 |         ckpt = ckpt["module"]
 29 |     new_ckpt = OrderedDict()
 30 |     for k, v in ckpt.items():
 31 |         if k.startswith("model."):
 32 |             k = k[len("model.") :]
 33 |         if k.startswith("encoder"):
 34 |             new_ckpt["vpm" + k[len("encoder") :]] = v
 35 |         elif k.startswith("decoder"):
 36 |             new_ckpt["llm" + k[len("decoder") :]] = v
 37 |         else:
 38 |             new_ckpt[k] = v
 39 |     if output_path:
 40 |         torch.save(new_ckpt, output_path)
 41 |     return new_ckpt
 42 | 
 43 | 
 44 | def convert_listconfig_to_list(config):
 45 |     new_config = {}
 46 |     for k, v in config.items():
 47 |         if isinstance(v, ListConfig):
 48 |             new_config[k] = list(v)
 49 |         else:
 50 |             new_config[k] = v
 51 |     return new_config
 52 | 
 53 | 
 54 | class DOLPHIN:
 55 |     def __init__(self, config, ckpt_path="") -> None:
 56 |         self.model_args = config.model
 57 |         self.swin_args = config.model.pop("swin_args")
 58 |         self.swin_args = convert_listconfig_to_list(self.swin_args)
 59 | 
 60 |         vision_tower = SwinEncoder(
 61 |             input_size=self.swin_args["img_size"],
 62 |             patch_size=self.swin_args["patch_size"],
 63 |             embed_dim=self.swin_args["embed_dim"],
 64 |             window_size=self.swin_args["window_size"],
 65 |             encoder_layer=self.swin_args["encoder_layer"],
 66 |             num_heads=self.swin_args["num_heads"],
 67 |             align_long_axis=self.swin_args["align_long_axis"],
 68 |         )
 69 | 
 70 |         self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=self.model_args.tokenizer_path)
 71 |         self.tokenizer.pad_token = "<pad>"
 72 |         self.tokenizer.bos_token = "<s>"
 73 |         self.tokenizer.eos_token = "</s>"
 74 |         self.tokenizer.unk_token = "<unk>"
 75 | 
 76 |         if self.model_args.get("extra_answer_tokens", False):
 77 |             # print("Allowing multitask training: adding <Answer/> to the tokenizer.")
 78 |             prompt_end_token = " <Answer/>"
 79 |             self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set([prompt_end_token]))})
 80 |             self.tokenizer._prompt_end_token = prompt_end_token
 81 |             self.tokenizer._prompt_end_token_id = self.tokenizer.convert_tokens_to_ids(prompt_end_token)
 82 | 
 83 |         donut_config = DonutConfig(
 84 |             decoder_layer=self.model_args.decoder_layer,
 85 |             max_length=self.model_args.max_length,
 86 |             max_position_embeddings=self.model_args.max_position_embeddings,
 87 |             hidden_dimension=self.model_args.hidden_dimension,
 88 |         )
 89 | 
 90 |         self.model = DonutModel(config=donut_config, vision_tower=vision_tower, tokenizer=self.tokenizer)
 91 |         if self.model_args.model_name_or_path:
 92 |             ckpt = torch.load(self.model_args.model_name_or_path)
 93 |             ckpt = try_rename_legacy_weights(ckpt)
 94 |             self.model.load_state_dict(ckpt, strict=True)
 95 | 
 96 |         device = "cuda" if torch.cuda.is_available() else "cpu"
 97 |         self.model.to(device)
 98 |         self.model.eval()
 99 |         transform_args = {
100 |             "input_size": self.swin_args["img_size"],
101 |             "max_length": self.model_args.max_length,
102 |         }
103 |         self.processor = DolphinProcessor({}, self.tokenizer, transform_args=transform_args)
104 | 
105 |     def chat(
106 |         self,
107 |         question,
108 |         image,
109 |         return_raw=False,
110 |         return_score=False,
111 |         return_img_size=False,
112 |         only_return_img_size=False,
113 |         max_batch_size=16,
114 |     ):
115 | 
116 |         def _preprocess_image(image):
117 |             if isinstance(image, str):
118 |                 image = Image.open(image).convert("RGB")
119 |             if return_img_size or only_return_img_size:
120 |                 image_tensor, ori_size = self.processor.process_image_for_inference(image, return_img_size=True)
121 |             else:
122 |                 image_tensor = self.processor.process_image_for_inference(image, return_img_size=False)
123 |                 ori_size = None
124 |             return image_tensor, ori_size
125 | 
126 |         def _preprocess_prompt(question):
127 |             if self.model_args.get("extra_answer_tokens", False):
128 |                 if self.tokenizer._prompt_end_token not in question:
129 |                     question = question + self.tokenizer._prompt_end_token
130 |             prompt_ids = self.processor.process_prompt_for_inference(question)
131 |             return prompt_ids
132 | 
133 |         def _preprocess_prompt_batch(question):
134 |             if self.model_args.get("extra_answer_tokens", False):
135 |                 for i in range(len(question)):
136 |                     if self.tokenizer._prompt_end_token not in question[i]:
137 |                         question[i] = question[i] + self.tokenizer._prompt_end_token
138 |                     if not question[i].startswith("<s>"):
139 |                         question[i] = "<s>" + question[i]
140 |             return question
141 | 
142 |         def _postprocess(output, question):
143 |             output = output.replace("<s>", "").replace(question, "").replace("</s>", "").replace("<pad>", "")
144 |             if self.model_args.get("extra_answer_tokens", False):
145 |                 output = output.split(self.tokenizer._prompt_end_token)[-1]
146 |             return output
147 | 
148 |         if isinstance(question, list):
149 |             image_tensor_list = []
150 |             for i in image:
151 |                 image_tensor, ori_size = _preprocess_image(i)
152 |                 image_tensor_list.append(image_tensor)
153 |             image_tensor = torch.cat(image_tensor_list, dim=0)
154 | 
155 |             question = _preprocess_prompt_batch(question)
156 |             self.processor.tokenizer.padding_side = "left"
157 |             prompt_ids = self.processor.tokenizer(
158 |                 question, add_special_tokens=False, return_tensors="pt", padding=True
159 |             ).input_ids
160 |         else:
161 |             image_tensor, ori_size = _preprocess_image(image)
162 |             prompt_ids = _preprocess_prompt(question)
163 | 
164 |         if only_return_img_size:
165 |             return ori_size
166 | 
167 |         model_output_batch = []
168 |         for i in range(0, image_tensor.shape[0], max_batch_size):
169 |             image_tensor_batch = image_tensor[i : i + max_batch_size]
170 |             prompt_ids_batch = prompt_ids[i : i + max_batch_size]
171 |             model_output = self.model.inference(image_tensors=image_tensor_batch, prompt_ids=prompt_ids_batch)
172 |             model_output_batch.append(model_output)
173 |         model_output = {}
174 |         for k, v in model_output_batch[0].items():
175 |             if isinstance(v, torch.Tensor):
176 |                 model_output[k] = sum(
177 |                     [v_batch[k].cpu().numpy().tolist() for v_batch in model_output_batch],
178 |                     [],
179 |                 )
180 |             else:
181 |                 model_output[k] = sum([v_batch[k] for v_batch in model_output_batch], [])
182 | 
183 |         if return_raw:
184 |             if return_img_size:
185 |                 return model_output, ori_size
186 |             return model_output
187 |         else:
188 |             if isinstance(question, list):
189 |                 output = [_postprocess(model_output["repetitions"][i], question[i]) for i in range(len(question))]
190 |                 score = model_output["scores"]
191 |             else:
192 |                 output = _postprocess(model_output["repetitions"][0], question)
193 |                 score = model_output["scores"][0]
194 |             if return_score:
195 |                 return output, score
196 |             if return_img_size:
197 |                 return output, ori_size
198 |             return output
199 | 
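For reference, an illustrative usage sketch of the `chat()` entry point above; the config path matches `config/Dolphin.yaml` and the image paths are demo assets from this repo.

```python
from omegaconf import OmegaConf

from chat import DOLPHIN

model = DOLPHIN(OmegaConf.load("./config/Dolphin.yaml"))

# Single image: also return the decoding confidence score
text, score = model.chat(
    "Read text in the image.", "./demo/element_imgs/para_1.jpg", return_score=True
)

# Batched prompts/images are chunked internally by max_batch_size
texts = model.chat(
    ["Read text in the image."] * 2,
    ["./demo/element_imgs/para_1.jpg", "./demo/element_imgs/para_2.jpg"],
    max_batch_size=2,
)
```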


--------------------------------------------------------------------------------
/config/Dolphin.yaml:
--------------------------------------------------------------------------------
 1 | model:
 2 |   model_name_or_path: "./checkpoints/dolphin_model.bin"
 3 |   tokenizer_path: "./checkpoints/dolphin_tokenizer.json"
 4 |   extra_answer_tokens: True   # add <Answer/> token
 5 |   max_length: 4096
 6 |   decoder_layer: 10
 7 |   max_position_embeddings: 4096
 8 |   hidden_dimension: 1024
 9 |   swin_args:
10 |     name: 'swin'
11 |     img_size: [896, 896]
12 |     patch_size: 4
13 |     embed_dim: 128
14 |     align_long_axis: False
15 |     window_size: 7
16 |     encoder_layer: [2, 2, 14, 2]
17 |     num_heads: [4, 8, 16, 32]
18 | 
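A small, illustrative snippet showing how this config is consumed (mirroring `chat.py` above, where `OmegaConf` loads the file and `ListConfig` fields are converted to plain lists before the Swin encoder is built):

```python
from omegaconf import OmegaConf

from chat import convert_listconfig_to_list

config = OmegaConf.load("./config/Dolphin.yaml")
swin_args = convert_listconfig_to_list(config.model.swin_args)
print(swin_args["img_size"])    # [896, 896] as a plain Python list
print(config.model.max_length)  # 4096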


--------------------------------------------------------------------------------
/demo/element_imgs/block_formula.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/element_imgs/block_formula.jpeg


--------------------------------------------------------------------------------
/demo/element_imgs/line_formula.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/element_imgs/line_formula.jpeg


--------------------------------------------------------------------------------
/demo/element_imgs/para_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/element_imgs/para_1.jpg


--------------------------------------------------------------------------------
/demo/element_imgs/para_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/element_imgs/para_2.jpg


--------------------------------------------------------------------------------
/demo/element_imgs/para_3.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/element_imgs/para_3.jpeg


--------------------------------------------------------------------------------
/demo/element_imgs/table_1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/element_imgs/table_1.jpeg


--------------------------------------------------------------------------------
/demo/element_imgs/table_2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/element_imgs/table_2.jpeg


--------------------------------------------------------------------------------
/demo/page_imgs/page_1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/page_imgs/page_1.jpeg


--------------------------------------------------------------------------------
/demo/page_imgs/page_2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/page_imgs/page_2.jpeg


--------------------------------------------------------------------------------
/demo/page_imgs/page_3.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/page_imgs/page_3.jpeg


--------------------------------------------------------------------------------
/demo/page_imgs/page_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/page_imgs/page_4.png


--------------------------------------------------------------------------------
/demo/page_imgs/page_5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/page_imgs/page_5.jpg


--------------------------------------------------------------------------------
/demo/page_imgs/page_6.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/page_imgs/page_6.pdf


--------------------------------------------------------------------------------
/demo/page_imgs/page_7.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/page_imgs/page_7.jpeg


--------------------------------------------------------------------------------
/demo_element.py:
--------------------------------------------------------------------------------
  1 | """ 
  2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
  3 | SPDX-License-Identifier: MIT
  4 | """
  5 | 
  6 | import argparse
  7 | import glob
  8 | import os
  9 | 
 10 | from omegaconf import OmegaConf
 11 | from PIL import Image
 12 | 
 13 | from chat import DOLPHIN
 14 | from utils.utils import *
 15 | 
 16 | 
 17 | def process_element(image_path, model, element_type, save_dir=None):
 18 |     """Process a single element image (text, table, formula)
 19 | 
 20 |     Args:
 21 |         image_path: Path to the element image
 22 |         model: DOLPHIN model instance
 23 |         element_type: Type of element ('text', 'table', 'formula')
 24 |         save_dir: Directory to save results (default: same as input directory)
 25 | 
 26 |     Returns:
 27 |         Parsed content of the element and recognition results
 28 |     """
 29 |     # Load and prepare image
 30 |     pil_image = Image.open(image_path).convert("RGB")
 31 |     pil_image = crop_margin(pil_image)
 32 | 
 33 |     # Select appropriate prompt based on element type
 34 |     if element_type == "table":
 35 |         prompt = "Parse the table in the image."
 36 |         label = "tab"
 37 |     elif element_type == "formula":
 38 |         prompt = "Read text in the image."
 39 |         label = "formula"
 40 |     else:  # Default to text
 41 |         prompt = "Read text in the image."
 42 |         label = "text"
 43 | 
 44 |     # Process the element
 45 |     result = model.chat(prompt, pil_image)
 46 | 
 47 |     # Create recognition result in the same format as the document parser
 48 |     recognition_result = [
 49 |         {
 50 |             "label": label,
 51 |             "text": result.strip(),
 52 |         }
 53 |     ]
 54 | 
 55 |     # Save results if save_dir is provided
 56 |     if save_dir:
 57 |         save_outputs(recognition_result, image_path, save_dir)
 58 |         print(f"Results saved to {save_dir}")
 59 | 
 60 |     return result, recognition_result
 61 | 
 62 | 
 63 | def main():
 64 |     parser = argparse.ArgumentParser(description="Element-level processing using DOLPHIN model")
 65 |     parser.add_argument("--config", default="./config/Dolphin.yaml", help="Path to configuration file")
 66 |     parser.add_argument("--input_path", type=str, required=True, help="Path to input image or directory of images")
 67 |     parser.add_argument(
 68 |         "--element_type",
 69 |         type=str,
 70 |         choices=["text", "table", "formula"],
 71 |         default="text",
 72 |         help="Type of element to process (text, table, formula)",
 73 |     )
 74 |     parser.add_argument(
 75 |         "--save_dir",
 76 |         type=str,
 77 |         default=None,
 78 |         help="Directory to save parsing results (default: same as input directory)",
 79 |     )
 80 |     parser.add_argument("--print_results", action="store_true", help="Print recognition results to console")
 81 |     args = parser.parse_args()
 82 | 
 83 |     # Load Model
 84 |     config = OmegaConf.load(args.config)
 85 |     model = DOLPHIN(config)
 86 | 
 87 |     # Set save directory
 88 |     save_dir = args.save_dir or (
 89 |         args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
 90 |     )
 91 |     setup_output_dirs(save_dir)
 92 | 
 93 |     # Collect Images
 94 |     if os.path.isdir(args.input_path):
 95 |         image_files = []
 96 |         for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]:
 97 |             image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
 98 |         image_files = sorted(image_files)
 99 |     else:
100 |         if not os.path.exists(args.input_path):
101 |             raise FileNotFoundError(f"Input path {args.input_path} does not exist")
102 |         image_files = [args.input_path]
103 | 
104 |     total_samples = len(image_files)
105 |     print(f"\nTotal samples to process: {total_samples}")
106 | 
107 |     # Process images one by one
108 |     for image_path in image_files:
109 |         print(f"\nProcessing {image_path}")
110 |         try:
111 |             result, recognition_result = process_element(
112 |                 image_path=image_path,
113 |                 model=model,
114 |                 element_type=args.element_type,
115 |                 save_dir=save_dir,
116 |             )
117 | 
118 |             if args.print_results:
119 |                 print("\nRecognition result:")
120 |                 print(result)
121 |                 print("-" * 40)
122 | 
123 |         except Exception as e:
124 |             print(f"Error processing {image_path}: {str(e)}")
125 |             continue
126 | 
127 | 
128 | if __name__ == "__main__":
129 |     main()
130 | 


--------------------------------------------------------------------------------
/demo_element_hf.py:
--------------------------------------------------------------------------------
  1 | """
  2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
  3 | SPDX-License-Identifier: MIT
  4 | """
  5 | 
  6 | import argparse
  7 | import glob
  8 | import os
  9 | 
 10 | import torch
 11 | from PIL import Image
 12 | from transformers import AutoProcessor, VisionEncoderDecoderModel
 13 | 
 14 | from utils.utils import *
 15 | 
 16 | 
 17 | class DOLPHIN:
 18 |     def __init__(self, model_id_or_path):
 19 |         """Initialize the Hugging Face model
 20 |         
 21 |         Args:
 22 |             model_id_or_path: Path to local model or Hugging Face model ID
 23 |         """
 24 |         # Load model from local path or Hugging Face hub
 25 |         self.processor = AutoProcessor.from_pretrained(model_id_or_path)
 26 |         self.model = VisionEncoderDecoderModel.from_pretrained(model_id_or_path)
 27 |         self.model.eval()
 28 |         
 29 |         # Set device and precision
 30 |         self.device = "cuda" if torch.cuda.is_available() else "cpu"
 31 |         self.model.to(self.device)
 32 |         self.model = self.model.half()  # Always use half precision by default
 33 |         
 34 |         # set tokenizer
 35 |         self.tokenizer = self.processor.tokenizer
 36 |         
 37 |     def chat(self, prompt, image):
 38 |         """Process an image with the given prompt
 39 |         
 40 |         Args:
 41 |             prompt: Text prompt to guide the model
 42 |             image: PIL Image to process
 43 |             
 44 |         Returns:
 45 |             Generated text from the model
 46 |         """
 47 |         # Prepare image
 48 |         pixel_values = self.processor(image, return_tensors="pt").pixel_values
 49 |         pixel_values = pixel_values.half()
 50 |             
 51 |         # Prepare prompt
 52 |         prompt = f"<s>{prompt} <Answer/>"
 53 |         prompt_ids = self.tokenizer(
 54 |             prompt, 
 55 |             add_special_tokens=False, 
 56 |             return_tensors="pt"
 57 |         ).input_ids.to(self.device)
 58 |         
 59 |         decoder_attention_mask = torch.ones_like(prompt_ids)
 60 |         
 61 |         # Generate text
 62 |         outputs = self.model.generate(
 63 |             pixel_values=pixel_values.to(self.device),
 64 |             decoder_input_ids=prompt_ids,
 65 |             decoder_attention_mask=decoder_attention_mask,
 66 |             min_length=1,
 67 |             max_length=4096,
 68 |             pad_token_id=self.tokenizer.pad_token_id,
 69 |             eos_token_id=self.tokenizer.eos_token_id,
 70 |             use_cache=True,
 71 |             bad_words_ids=[[self.tokenizer.unk_token_id]],
 72 |             return_dict_in_generate=True,
 73 |             do_sample=False,
 74 |             num_beams=1,
 75 |             repetition_penalty=1.1,
 76 |             temperature=1.0
 77 |         )
 78 |         
 79 |         # Process the output
 80 |         sequence = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0]
 81 |         sequence = sequence.replace(prompt, "").replace("<pad>", "").replace("</s>", "").strip()
 82 |         
 83 |         return sequence
 84 | 
 85 | def process_element(image_path, model, element_type, save_dir=None):
 86 |     """Process a single element image (text, table, formula)
 87 |     
 88 |     Args:
 89 |         image_path: Path to the element image
 90 |         model: DOLPHIN model instance
 91 |         element_type: Type of element ('text', 'table', 'formula')
 92 |         save_dir: Directory to save results (default: same as input directory)
 93 |         
 94 |     Returns:
 95 |         Parsed content of the element and recognition results
 96 |     """
 97 |     # Load and prepare image
 98 |     pil_image = Image.open(image_path).convert("RGB")
 99 |     pil_image = crop_margin(pil_image)
100 |     
101 |     # Select appropriate prompt based on element type
102 |     if element_type == "table":
103 |         prompt = "Parse the table in the image."
104 |         label = "tab"
105 |     elif element_type == "formula":
106 |         prompt = "Read text in the image."
107 |         label = "formula"
108 |     else:  # Default to text
109 |         prompt = "Read text in the image."
110 |         label = "text"
111 |     
112 |     # Process the element
113 |     result = model.chat(prompt, pil_image)
114 |     
115 |     # Create recognition result in the same format as the document parser
116 |     recognition_result = [
117 |         {
118 |             "label": label,
119 |             "text": result.strip(),
120 |         }
121 |     ]
122 |     
123 |     # Save results if save_dir is provided
124 |     if save_dir:
125 |         save_outputs(recognition_result, image_path, save_dir)
126 |         print(f"Results saved to {save_dir}")
127 |     
128 |     return result, recognition_result
129 | 
130 | 
131 | def main():
132 |     parser = argparse.ArgumentParser(description="Element-level processing using DOLPHIN model")
133 |     parser.add_argument("--model_path", default="./hf_model", help="Path to Hugging Face model")
134 |     parser.add_argument("--input_path", type=str, required=True, help="Path to input image or directory of images")
135 |     parser.add_argument(
136 |         "--element_type",
137 |         type=str,
138 |         choices=["text", "table", "formula"],
139 |         default="text",
140 |         help="Type of element to process (text, table, formula)",
141 |     )
142 |     parser.add_argument(
143 |         "--save_dir",
144 |         type=str,
145 |         default=None,
146 |         help="Directory to save parsing results (default: same as input directory)",
147 |     )
148 |     parser.add_argument("--print_results", action="store_true", help="Print recognition results to console")
149 |     args = parser.parse_args()
150 |     
151 |     # Load Model
152 |     model = DOLPHIN(args.model_path)
153 |     
154 |     # Set save directory
155 |     save_dir = args.save_dir or (
156 |         args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
157 |     )
158 |     setup_output_dirs(save_dir)
159 |     
160 |     # Collect Images
161 |     if os.path.isdir(args.input_path):
162 |         image_files = []
163 |         for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]:
164 |             image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
165 |         image_files = sorted(image_files)
166 |     else:
167 |         if not os.path.exists(args.input_path):
168 |             raise FileNotFoundError(f"Input path {args.input_path} does not exist")
169 |         image_files = [args.input_path]
170 |     
171 |     total_samples = len(image_files)
172 |     print(f"\nTotal samples to process: {total_samples}")
173 |     
174 |     # Process images one by one
175 |     for image_path in image_files:
176 |         print(f"\nProcessing {image_path}")
177 |         try:
178 |             result, recognition_result = process_element(
179 |                 image_path=image_path,
180 |                 model=model,
181 |                 element_type=args.element_type,
182 |                 save_dir=save_dir,
183 |             )
184 | 
185 |             if args.print_results:
186 |                 print("\nRecognition result:")
187 |                 print(result)
188 |                 print("-" * 40)
189 |         except Exception as e:
190 |             print(f"Error processing {image_path}: {str(e)}")
191 |             continue
192 | 
193 | 
194 | if __name__ == "__main__":
195 |     main()
196 | 


--------------------------------------------------------------------------------
/demo_page.py:
--------------------------------------------------------------------------------
  1 | """ 
  2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
  3 | SPDX-License-Identifier: MIT
  4 | """
  5 | 
  6 | import argparse
  7 | import glob
  8 | import os
  9 | 
 10 | import cv2
 11 | from omegaconf import OmegaConf
 12 | from PIL import Image
 13 | 
 14 | from chat import DOLPHIN
 15 | from utils.utils import *
 16 | 
 17 | 
 18 | def process_document(document_path, model, save_dir, max_batch_size):
 19 |     """Parse documents - Handles both images and PDFs"""
 20 |     file_ext = os.path.splitext(document_path)[1].lower()
 21 |     
 22 |     if file_ext == '.pdf':
 23 |         # Process PDF file
 24 |         # Convert PDF to images
 25 |         images = convert_pdf_to_images(document_path)
 26 |         if not images:
 27 |             raise Exception(f"Failed to convert PDF {document_path} to images")
 28 |         
 29 |         all_results = []
 30 |         
 31 |         # Process each page
 32 |         for page_idx, pil_image in enumerate(images):
 33 |             print(f"Processing page {page_idx + 1}/{len(images)}")
 34 |             
 35 |             # Generate output name for this page
 36 |             base_name = os.path.splitext(os.path.basename(document_path))[0]
 37 |             page_name = f"{base_name}_page_{page_idx + 1:03d}"
 38 |             
 39 |             # Process this page (don't save individual page results)
 40 |             json_path, recognition_results = process_single_image(
 41 |                 pil_image, model, save_dir, page_name, max_batch_size, save_individual=False
 42 |             )
 43 |             
 44 |             # Add page information to results
 45 |             page_results = {
 46 |                 "page_number": page_idx + 1,
 47 |                 "elements": recognition_results
 48 |             }
 49 |             all_results.append(page_results)
 50 |         
 51 |         # Save combined results for multi-page PDF
 52 |         combined_json_path = save_combined_pdf_results(all_results, document_path, save_dir)
 53 |         
 54 |         return combined_json_path, all_results
 55 | 
 56 |     else:
 57 |         # Process regular image file
 58 |         pil_image = Image.open(document_path).convert("RGB")
 59 |         base_name = os.path.splitext(os.path.basename(document_path))[0]
 60 |         return process_single_image(pil_image, model, save_dir, base_name, max_batch_size)
 61 | 
 62 | 
 63 | def process_single_image(image, model, save_dir, image_name, max_batch_size, save_individual=True):
 64 |     """Process a single image (either from file or converted from PDF page)
 65 |     
 66 |     Args:
 67 |         image: PIL Image object
 68 |         model: DOLPHIN model instance
 69 |         save_dir: Directory to save results
 70 |         image_name: Name for the output file
 71 |         max_batch_size: Maximum batch size for processing
 72 |         save_individual: Whether to save individual results (False for PDF pages)
 73 |         
 74 |     Returns:
 75 |         Tuple of (json_path, recognition_results)
 76 |     """
 77 |     # Stage 1: Page-level layout and reading order parsing
 78 |     layout_output = model.chat("Parse the reading order of this document.", image)
 79 | 
 80 |     # Stage 2: Element-level content parsing
 81 |     padded_image, dims = prepare_image(image)
 82 |     recognition_results = process_elements(layout_output, padded_image, dims, model, max_batch_size, save_dir, image_name)
 83 | 
 84 |     # Save outputs only if requested (skip for PDF pages)
 85 |     json_path = None
 86 |     if save_individual:
 87 |         # Create a dummy image path for save_outputs function
 88 |         dummy_image_path = f"{image_name}.jpg"  # Extension doesn't matter, only basename is used
 89 |         json_path = save_outputs(recognition_results, dummy_image_path, save_dir)
 90 | 
 91 |     return json_path, recognition_results
 92 | 
 93 | 
 94 | def process_elements(layout_results, padded_image, dims, model, max_batch_size, save_dir=None, image_name=None):
 95 |     """Parse all document elements with parallel decoding"""
 96 |     layout_results = parse_layout_string(layout_results)
 97 | 
 98 |     text_table_elements = []  # Elements that need processing
 99 |     figure_results = []  # Figure elements (no processing needed)
100 |     previous_box = None
101 |     reading_order = 0
102 | 
103 |     # Collect elements for processing
104 |     for bbox, label in layout_results:
105 |         try:
106 |             # Adjust coordinates
107 |             x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates(
108 |                 bbox, padded_image, dims, previous_box
109 |             )
110 | 
111 |             # Crop and parse element
112 |             cropped = padded_image[y1:y2, x1:x2]
113 |             if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3:
114 |                 if label == "fig":
115 |                     pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
116 |                     
117 |                     figure_filename = save_figure_to_local(pil_crop, save_dir, image_name, reading_order)
118 |                     
119 |                     # For figure regions, store relative path instead of base64
120 |                     figure_results.append(
121 |                         {
122 |                             "label": label,
123 |                             "text": f"![Figure](figures/{figure_filename})",
124 |                             "figure_path": f"figures/{figure_filename}",
125 |                             "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
126 |                             "reading_order": reading_order,
127 |                         }
128 |                     )
129 |                 else:
130 |                     # For text or table regions, prepare for parsing
131 |                     pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
132 |                     prompt = "Parse the table in the image." if label == "tab" else "Read text in the image."
133 |                     text_table_elements.append(
134 |                         {
135 |                             "crop": pil_crop,
136 |                             "prompt": prompt,
137 |                             "label": label,
138 |                             "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
139 |                             "reading_order": reading_order,
140 |                         }
141 |                     )
142 | 
143 |             reading_order += 1
144 | 
145 |         except Exception as e:
146 |             print(f"Error processing bbox with label {label}: {str(e)}")
147 |             continue
148 | 
149 |     # Parse text/table elements in parallel
150 |     recognition_results = figure_results.copy()
151 |     if text_table_elements:
152 |         crops_list = [elem["crop"] for elem in text_table_elements]
153 |         prompts_list = [elem["prompt"] for elem in text_table_elements]
154 | 
155 |         # Inference in batch
156 |         batch_results = model.chat(prompts_list, crops_list, max_batch_size=max_batch_size)
157 | 
158 |         # Add batch results to recognition_results
159 |         for i, result in enumerate(batch_results):
160 |             elem = text_table_elements[i]
161 |             recognition_results.append(
162 |                 {
163 |                     "label": elem["label"],
164 |                     "bbox": elem["bbox"],
165 |                     "text": result.strip(),
166 |                     "reading_order": elem["reading_order"],
167 |                 }
168 |             )
169 | 
170 |     # Sort elements by reading order
171 |     recognition_results.sort(key=lambda x: x.get("reading_order", 0))
172 | 
173 |     return recognition_results
174 | 
175 | 
176 | def main():
177 |     parser = argparse.ArgumentParser(description="Document parsing based on DOLPHIN")
178 |     parser.add_argument("--config", default="./config/Dolphin.yaml", help="Path to configuration file")
179 |     parser.add_argument("--input_path", type=str, default="./demo", help="Path to input image/PDF or directory of files")
180 |     parser.add_argument(
181 |         "--save_dir",
182 |         type=str,
183 |         default=None,
184 |         help="Directory to save parsing results (default: same as input directory)",
185 |     )
186 |     parser.add_argument(
187 |         "--max_batch_size",
188 |         type=int,
189 |         default=4,
190 |         help="Maximum number of document elements to parse in a single batch (default: 4)",
191 |     )
192 |     args = parser.parse_args()
193 | 
194 |     # Load Model
195 |     config = OmegaConf.load(args.config)
196 |     model = DOLPHIN(config)
197 | 
198 |     # Collect Document Files (images and PDFs)
199 |     if os.path.isdir(args.input_path):
200 |         # Support both image and PDF files
201 |         file_extensions = [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG", ".pdf", ".PDF"]
202 |         
203 |         document_files = []
204 |         for ext in file_extensions:
205 |             document_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
206 |         document_files = sorted(document_files)
207 |     else:
208 |         if not os.path.exists(args.input_path):
209 |             raise FileNotFoundError(f"Input path {args.input_path} does not exist")
210 |         
211 |         # Check if it's a supported file type
212 |         file_ext = os.path.splitext(args.input_path)[1].lower()
213 |         supported_exts = ['.jpg', '.jpeg', '.png', '.pdf']
214 |         
215 |         if file_ext not in supported_exts:
216 |             raise ValueError(f"Unsupported file type: {file_ext}. Supported types: {supported_exts}")
217 |         
218 |         document_files = [args.input_path]
219 | 
220 |     save_dir = args.save_dir or (
221 |         args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
222 |     )
223 |     setup_output_dirs(save_dir)
224 | 
225 |     total_samples = len(document_files)
226 |     print(f"\nTotal files to process: {total_samples}")
227 | 
228 |     # Process All Document Files
229 |     for file_path in document_files:
230 |         print(f"\nProcessing {file_path}")
231 |         try:
232 |             json_path, recognition_results = process_document(
233 |                 document_path=file_path,
234 |                 model=model,
235 |                 save_dir=save_dir,
236 |                 max_batch_size=args.max_batch_size,
237 |             )
238 | 
239 |             print(f"Processing completed. Results saved to {save_dir}")
240 | 
241 |         except Exception as e:
242 |             print(f"Error processing {file_path}: {str(e)}")
243 |             continue
244 | 
245 | 
246 | if __name__ == "__main__":
247 |     main()
248 | 


--------------------------------------------------------------------------------
/demo_page_hf.py:
--------------------------------------------------------------------------------
  1 | """ 
  2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
  3 | SPDX-License-Identifier: MIT
  4 | """
  5 | 
  6 | import argparse
  7 | import glob
  8 | import os
  9 | 
 10 | import cv2
 11 | import torch
 12 | from PIL import Image
 13 | from transformers import AutoProcessor, VisionEncoderDecoderModel
 14 | 
 15 | from utils.utils import *
 16 | 
 17 | 
 18 | class DOLPHIN:
 19 |     def __init__(self, model_id_or_path):
 20 |         """Initialize the Hugging Face model
 21 |         
 22 |         Args:
 23 |             model_id_or_path: Path to local model or Hugging Face model ID
 24 |         """
 25 |         # Load model from local path or Hugging Face hub
 26 |         self.processor = AutoProcessor.from_pretrained(model_id_or_path)
 27 |         self.model = VisionEncoderDecoderModel.from_pretrained(model_id_or_path)
 28 |         self.model.eval()
 29 |         
 30 |         # Set device and precision
 31 |         self.device = "cuda" if torch.cuda.is_available() else "cpu"
 32 |         self.model.to(self.device)
 33 |         self.model = self.model.half()  # use half precision (fp16) by default for faster inference
 34 |         
 35 |         # set tokenizer
 36 |         self.tokenizer = self.processor.tokenizer
 37 |         
 38 |     def chat(self, prompt, image):
 39 |         """Process an image or batch of images with the given prompt(s)
 40 |         
 41 |         Args:
 42 |             prompt: Text prompt or list of prompts to guide the model
 43 |             image: PIL Image or list of PIL Images to process
 44 |             
 45 |         Returns:
 46 |             Generated text or list of texts from the model
 47 |         """
 48 |         # Check if we're dealing with a batch
 49 |         is_batch = isinstance(image, list)
 50 |         
 51 |         if not is_batch:
 52 |             # Single image, wrap it in a list for consistent processing
 53 |             images = [image]
 54 |             prompts = [prompt]
 55 |         else:
 56 |             # Batch of images
 57 |             images = image
 58 |             prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
 59 |         
 60 |         # Prepare image
 61 |         batch_inputs = self.processor(images, return_tensors="pt", padding=True)
 62 |         batch_pixel_values = batch_inputs.pixel_values.half().to(self.device)
 63 |         
 64 |         # Prepare prompt
 65 |         prompts = [f"<s>{p} <Answer/>" for p in prompts]
 66 |         batch_prompt_inputs = self.tokenizer(
 67 |             prompts,
 68 |             add_special_tokens=False,
 69 |             return_tensors="pt"
 70 |         )
 71 | 
 72 |         batch_prompt_ids = batch_prompt_inputs.input_ids.to(self.device)
 73 |         batch_attention_mask = batch_prompt_inputs.attention_mask.to(self.device)
 74 |         
 75 |         # Generate text
 76 |         outputs = self.model.generate(
 77 |             pixel_values=batch_pixel_values,
 78 |             decoder_input_ids=batch_prompt_ids,
 79 |             decoder_attention_mask=batch_attention_mask,
 80 |             min_length=1,
 81 |             max_length=4096,
 82 |             pad_token_id=self.tokenizer.pad_token_id,
 83 |             eos_token_id=self.tokenizer.eos_token_id,
 84 |             use_cache=True,
 85 |             bad_words_ids=[[self.tokenizer.unk_token_id]],
 86 |             return_dict_in_generate=True,
 87 |             do_sample=False,
 88 |             num_beams=1,
 89 |             repetition_penalty=1.1,
 90 |             temperature=1.0
 91 |         )
 92 |         
 93 |         # Process output
 94 |         sequences = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
 95 |         
 96 |         # Clean prompt text from output
 97 |         results = []
 98 |         for i, sequence in enumerate(sequences):
 99 |             cleaned = sequence.replace(prompts[i], "").replace("<pad>", "").replace("</s>", "").strip()
100 |             results.append(cleaned)
101 |             
102 |         # Return a single result for single image input
103 |         if not is_batch:
104 |             return results[0]
105 |         return results
106 | 
107 | 
108 | def process_document(document_path, model, save_dir, max_batch_size=None):
109 |     """Parse documents with two stages - Handles both images and PDFs"""
110 |     file_ext = os.path.splitext(document_path)[1].lower()
111 |     
112 |     if file_ext == '.pdf':
113 |         # Process PDF file:
114 |         # convert each page to an image, then parse page by page
115 |         images = convert_pdf_to_images(document_path)
116 |         if not images:
117 |             raise Exception(f"Failed to convert PDF {document_path} to images")
118 |         
119 |         all_results = []
120 |         
121 |         # Process each page
122 |         for page_idx, pil_image in enumerate(images):
123 |             print(f"Processing page {page_idx + 1}/{len(images)}")
124 |             
125 |             # Generate output name for this page
126 |             base_name = os.path.splitext(os.path.basename(document_path))[0]
127 |             page_name = f"{base_name}_page_{page_idx + 1:03d}"
128 |             
129 |             # Process this page (don't save individual page results)
130 |             json_path, recognition_results = process_single_image(
131 |                 pil_image, model, save_dir, page_name, max_batch_size, save_individual=False
132 |             )
133 |             
134 |             # Add page information to results
135 |             page_results = {
136 |                 "page_number": page_idx + 1,
137 |                 "elements": recognition_results
138 |             }
139 |             all_results.append(page_results)
140 |         
141 |         # Save combined results for multi-page PDF
142 |         combined_json_path = save_combined_pdf_results(all_results, document_path, save_dir)
143 |         
144 |         return combined_json_path, all_results
145 |     
146 |     else:
147 |         # Process regular image file
148 |         pil_image = Image.open(document_path).convert("RGB")
149 |         base_name = os.path.splitext(os.path.basename(document_path))[0]
150 |         return process_single_image(pil_image, model, save_dir, base_name, max_batch_size)
151 | 
152 | 
153 | def process_single_image(image, model, save_dir, image_name, max_batch_size=None, save_individual=True):
154 |     """Process a single image (either from file or converted from PDF page)
155 |     
156 |     Args:
157 |         image: PIL Image object
158 |         model: DOLPHIN model instance
159 |         save_dir: Directory to save results
160 |         image_name: Name for the output file
161 |         max_batch_size: Maximum batch size for processing
162 |         save_individual: Whether to save individual results (False for PDF pages)
163 |         
164 |     Returns:
165 |         Tuple of (json_path, recognition_results)
166 |     """
167 |     # Stage 1: Page-level layout and reading order parsing
168 |     layout_output = model.chat("Parse the reading order of this document.", image)
169 | 
170 |     # Stage 2: Element-level content parsing
171 |     padded_image, dims = prepare_image(image)
172 |     recognition_results = process_elements(layout_output, padded_image, dims, model, max_batch_size, save_dir, image_name)
173 | 
174 |     # Save outputs only if requested (skip for PDF pages)
175 |     json_path = None
176 |     if save_individual:
177 |         # Create a dummy image path for save_outputs function
178 |         dummy_image_path = f"{image_name}.jpg"  # Extension doesn't matter, only basename is used
179 |         json_path = save_outputs(recognition_results, dummy_image_path, save_dir)
180 | 
181 |     return json_path, recognition_results
182 | 
183 | 
184 | def process_elements(layout_results, padded_image, dims, model, max_batch_size, save_dir=None, image_name=None):
185 |     """Parse all document elements with parallel decoding"""
186 |     layout_results = parse_layout_string(layout_results)
187 | 
188 |     # Store text and table elements separately
189 |     text_elements = []  # Text elements
190 |     table_elements = []  # Table elements
191 |     figure_results = []  # Image elements (no processing needed)
192 |     previous_box = None
193 |     reading_order = 0
194 | 
195 |     # Collect elements to process and group by type
196 |     for bbox, label in layout_results:
197 |         try:
198 |             # Adjust coordinates
199 |             x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates(
200 |                 bbox, padded_image, dims, previous_box
201 |             )
202 | 
203 |             # Crop and parse element
204 |             cropped = padded_image[y1:y2, x1:x2]
205 |             if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3:
206 |                 if label == "fig":
207 |                     pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
208 |                     
209 |                     figure_filename = save_figure_to_local(pil_crop, save_dir, image_name, reading_order)
210 |                     
211 |                     # For figure regions, store relative path instead of base64
212 |                     figure_results.append(
213 |                         {
214 |                             "label": label,
215 |                             "text": f"![Figure](figures/{figure_filename})",
216 |                             "figure_path": f"figures/{figure_filename}",
217 |                             "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
218 |                             "reading_order": reading_order,
219 |                         }
220 |                     )
221 |                 else:
222 |                     # Prepare element for parsing
223 |                     pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
224 |                     element_info = {
225 |                         "crop": pil_crop,
226 |                         "label": label,
227 |                         "bbox": [orig_x1, orig_y1, orig_x2, orig_y2],
228 |                         "reading_order": reading_order,
229 |                     }
230 |                     
231 |                     # Group by type
232 |                     if label == "tab":
233 |                         table_elements.append(element_info)
234 |                     else:  # Text elements
235 |                         text_elements.append(element_info)
236 | 
237 |             reading_order += 1
238 | 
239 |         except Exception as e:
240 |             print(f"Error processing bbox with label {label}: {str(e)}")
241 |             continue
242 | 
243 |     # Initialize results list
244 |     recognition_results = figure_results.copy()
245 |     
246 |     # Process text elements (in batches)
247 |     if text_elements:
248 |         text_results = process_element_batch(text_elements, model, "Read text in the image.", max_batch_size)
249 |         recognition_results.extend(text_results)
250 |     
251 |     # Process table elements (in batches)
252 |     if table_elements:
253 |         table_results = process_element_batch(table_elements, model, "Parse the table in the image.", max_batch_size)
254 |         recognition_results.extend(table_results)
255 | 
256 |     # Sort elements by reading order
257 |     recognition_results.sort(key=lambda x: x.get("reading_order", 0))
258 | 
259 |     return recognition_results
260 | 
261 | 
262 | def process_element_batch(elements, model, prompt, max_batch_size=None):
263 |     """Process elements of the same type in batches"""
264 |     results = []
265 |     
266 |     # Determine batch size
267 |     batch_size = len(elements)
268 |     if max_batch_size is not None and max_batch_size > 0:
269 |         batch_size = min(batch_size, max_batch_size)
270 |     
271 |     # Process in batches
272 |     for i in range(0, len(elements), batch_size):
273 |         batch_elements = elements[i:i+batch_size]
274 |         crops_list = [elem["crop"] for elem in batch_elements]
275 |         
276 |         # Use the same prompt for all elements in the batch
277 |         prompts_list = [prompt] * len(crops_list)
278 |         
279 |         # Batch inference
280 |         batch_results = model.chat(prompts_list, crops_list)
281 |         
282 |         # Add results
283 |         for j, result in enumerate(batch_results):
284 |             elem = batch_elements[j]
285 |             results.append({
286 |                 "label": elem["label"],
287 |                 "bbox": elem["bbox"],
288 |                 "text": result.strip(),
289 |                 "reading_order": elem["reading_order"],
290 |             })
291 |     
292 |     return results
293 | 
294 | 
295 | def main():
296 |     parser = argparse.ArgumentParser(description="Document parsing based on DOLPHIN")
297 |     parser.add_argument("--model_path", default="./hf_model", help="Path to Hugging Face model")
298 |     parser.add_argument("--input_path", type=str, default="./demo", help="Path to input image/PDF or directory of files")
299 |     parser.add_argument(
300 |         "--save_dir",
301 |         type=str,
302 |         default=None,
303 |         help="Directory to save parsing results (default: same as input directory)",
304 |     )
305 |     parser.add_argument(
306 |         "--max_batch_size",
307 |         type=int,
308 |         default=16,
309 |         help="Maximum number of document elements to parse in a single batch (default: 16)",
310 |     )
311 |     args = parser.parse_args()
312 | 
313 |     # Load Model
314 |     model = DOLPHIN(args.model_path)
315 | 
316 |     # Collect Document Files (images and PDFs)
317 |     if os.path.isdir(args.input_path):
318 |         # Support both image and PDF files
319 |         file_extensions = [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG", ".pdf", ".PDF"]
320 |         
321 |         document_files = []
322 |         for ext in file_extensions:
323 |             document_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}")))
324 |         document_files = sorted(document_files)
325 |     else:
326 |         if not os.path.exists(args.input_path):
327 |             raise FileNotFoundError(f"Input path {args.input_path} does not exist")
328 |         
329 |         # Check if it's a supported file type
330 |         file_ext = os.path.splitext(args.input_path)[1].lower()
331 |         supported_exts = ['.jpg', '.jpeg', '.png', '.pdf']
332 |         
333 |         if file_ext not in supported_exts:
334 |             raise ValueError(f"Unsupported file type: {file_ext}. Supported types: {supported_exts}")
335 |         
336 |         document_files = [args.input_path]
337 | 
338 |     save_dir = args.save_dir or (
339 |         args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path)
340 |     )
341 |     setup_output_dirs(save_dir)
342 | 
343 |     total_samples = len(document_files)
344 |     print(f"\nTotal files to process: {total_samples}")
345 | 
346 |     # Process All Document Files
347 |     for file_path in document_files:
348 |         print(f"\nProcessing {file_path}")
349 |         try:
350 |             json_path, recognition_results = process_document(
351 |                 document_path=file_path,
352 |                 model=model,
353 |                 save_dir=save_dir,
354 |                 max_batch_size=args.max_batch_size,
355 |             )
356 | 
357 |             print(f"Processing completed. Results saved to {save_dir}")
358 | 
359 |         except Exception as e:
360 |             print(f"Error processing {file_path}: {str(e)}")
361 |             continue
362 | 
363 | 
364 | if __name__ == "__main__":
365 |     main()
366 | 


--------------------------------------------------------------------------------
/deployment/ReadMe.md:
--------------------------------------------------------------------------------
 1 | <h1 align="center">
 2 | 🚀 Dolphin Inference/Serving
 3 | </h1>
 4 | 
 5 | ## vLLM
 6 | > [Doc](./vllm/ReadMe.md)
 7 | 
 8 | ## TensorRT-LLM
 9 | > [Doc](./tensorrt_llm/ReadMe.md)
10 | 
11 | ## Others
12 | 
13 | 


--------------------------------------------------------------------------------
/deployment/tensorrt_llm/ReadMe.md:
--------------------------------------------------------------------------------
 1 | <h1 align="center">
 2 | 🚀 Dolphin TensorRT-LLM Demo
 3 | </h1>
 4 | 
 5 | ## ✅ Introduction
 6 | The Dolphin model employs a **Swin Encoder + MBart Decoder** architecture. In the HuggingFace Transformers [Config](https://huggingface.co/ByteDance/Dolphin/blob/main/config.json),
 7 | its `architectures` field is specified as `"VisionEncoderDecoderModel"`. **Dolphin**, **[Nougat](https://huggingface.co/docs/transformers/model_doc/nougat)**, and **[Donut](https://huggingface.co/docs/transformers/model_doc/donut)** share the same model architecture, and TensorRT-LLM already supports Nougat.
 8 | Following Nougat's conversion script, we have successfully implemented Dolphin on TensorRT-LLM.
 9 | 
 10 | **Note:** [prompt_ids](./dolphin_runner.py#L120) MUST be of **int32** type; otherwise TensorRT-LLM will produce incorrect results.
11 | 
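A minimal sketch of the required cast, mirroring [dolphin_runner.py](./dolphin_runner.py#L120) (`tokenizer`, `prompts`, and `torch` as defined in that file):

```
prompt_ids = tokenizer(prompts, add_special_tokens=False, return_tensors="pt").input_ids.to("cuda")
# 🚨 int64 ids (the tokenizer default) silently produce wrong generations
prompt_ids = prompt_ids.to(torch.int32)
```
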
12 | ## 🛠️ Installation
 13 | > We have only tested TensorRT-LLM 0.18.1 on Linux.
14 | 
15 | https://nvidia.github.io/TensorRT-LLM/0.18.1/installation/linux.html
16 | 
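Per the linked guide, installation on Linux is typically a pip install from NVIDIA's package index, e.g. `pip3 install tensorrt_llm==0.18.1 --extra-index-url https://pypi.nvidia.com` (an assumed command; follow the guide if your platform differs).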
17 | 
18 | ## ⚡ Offline Inference
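> Build the TensorRT engines first; see [convert_dolphin.sh](./convert_dolphin.sh).
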
19 | ```
20 | export MODEL_NAME="Dolphin"
21 | 
22 | # predict elements reading order
23 | python run_dolphin.py \
24 |     --batch_size 1 \
25 |     --hf_model_dir tmp/hf_models/${MODEL_NAME} \
26 |     --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
27 |     --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
28 |     --max_new_tokens 4096 \
29 |     --repetition_penalty 1.0 \
30 |     --input_text "Parse the reading order of this document." \
31 |     --image_path "../../demo/page_imgs/page_1.jpeg"
32 | 
33 | # recognize text/latex
34 | python run_dolphin.py \
35 |     --batch_size 1 \
36 |     --hf_model_dir tmp/hf_models/${MODEL_NAME} \
37 |     --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
38 |     --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
39 |     --max_new_tokens 4096 \
40 |     --repetition_penalty 1.0 \
41 |     --input_text "Read text in the image." \
42 |     --image_path "../../demo/element_imgs/block_formula.jpeg"
43 | 
44 | 
45 | python run_dolphin.py \
46 |     --batch_size 1 \
47 |     --hf_model_dir tmp/hf_models/${MODEL_NAME} \
48 |     --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
49 |     --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
50 |     --max_new_tokens 4096 \
51 |     --repetition_penalty 1.0 \
52 |     --input_text "Read text in the image." \
53 |     --image_path "../../demo/element_imgs/para_1.jpg"
54 | 
55 | # recognize table
56 | python run_dolphin.py \
57 |     --batch_size 1 \
58 |     --hf_model_dir tmp/hf_models/${MODEL_NAME} \
59 |     --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
60 |     --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
61 |     --max_new_tokens 4096 \
62 |     --repetition_penalty 1.0 \
63 |     --input_text "Parse the table in the image." \
64 |     --image_path "../../demo/element_imgs/table_1.jpeg"
65 | ```
66 | 
67 | 
68 | ## ⚡ Online Inference
69 | ```
70 | # 1. Start Api Server
71 | export MODEL_NAME="Dolphin"
72 | 
73 | python api_server.py \
74 |     --hf_model_dir tmp/hf_models/${MODEL_NAME} \
75 |     --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
76 |     --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
77 |     --max_batch_size 16
78 | 
79 | # 2. Predict
80 | # predict elements reading order
81 | python deployment/tensorrt_llm/api_client.py --image_path ./demo/page_imgs/page_1.jpeg --prompt "Parse the reading order of this document."
82 | 
83 | # recognize text/latex
84 | python deployment/tensorrt_llm/api_client.py --image_path ./demo/element_imgs/block_formula.jpeg --prompt "Read text in the image."
85 | python deployment/tensorrt_llm/api_client.py --image_path ./demo/element_imgs/para_1.jpg --prompt "Read text in the image."
86 | 
87 | # recognize table
88 | python deployment/tensorrt_llm/api_client.py --image_path ./demo/element_imgs/table_1.jpeg --prompt "Parse the table in the image."
89 | ```
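
The request/response contract is minimal: POST JSON with `prompt` and `image_base64` to `/generate`, and the server replies with `{"text": ...}`. A standalone sketch (assuming the server from step 1 is listening on `localhost:8000`):

```
import base64

import requests

with open("./demo/page_imgs/page_1.jpeg", "rb") as f:
    image_base64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "prompt": "Parse the reading order of this document.",
    "image_base64": image_base64,
}
response = requests.post("http://localhost:8000/generate", json=payload)
print(response.json()["text"])
```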


--------------------------------------------------------------------------------
/deployment/tensorrt_llm/api_client.py:
--------------------------------------------------------------------------------
  1 | # SPDX-License-Identifier: Apache-2.0
  2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
  3 | """Example Python client for `vllm.entrypoints.api_server`
  4 | Start the demo server:
  5 |     python -m vllm.entrypoints.api_server --model <model_name>
  6 | 
  7 | NOTE: The API server is used only for demonstration and simple performance
  8 | benchmarks. It is not intended for production use.
  9 | For production use, we recommend `vllm serve` and the OpenAI client API.
 10 | """
 11 | 
 12 | import argparse
 13 | import base64
 14 | import json
 15 | from argparse import Namespace
 16 | from collections.abc import Iterable
 17 | 
 18 | import requests
 19 | 
 20 | 
 21 | def clear_line(n: int = 1) -> None:
 22 |     LINE_UP = "\033[1A"
 23 |     LINE_CLEAR = "\x1b[2K"
 24 |     for _ in range(n):
 25 |         print(LINE_UP, end=LINE_CLEAR, flush=True)
 26 | 
 27 | 
 28 | def encode_image_base64(image_path: str) -> str:
 29 |     """Encode local image to base64 format."""
 30 | 
 31 |     with open(image_path, "rb") as f:
 32 |         image_data = f.read()
 33 |         result = base64.b64encode(image_data).decode("utf-8")
 34 | 
 35 |     return result
 36 | 
 37 | 
 38 | def post_http_request(
 39 |         prompt: str, image_path: str, api_url: str, stream: bool = False
 40 | ) -> requests.Response:
 41 |     headers = {"User-Agent": "Test Client"}
 42 |     pload = {
 43 |         "prompt": prompt,
 44 |         "image_base64": encode_image_base64(image_path),
 45 |     }
 46 |     response = requests.post(api_url, headers=headers, json=pload, stream=stream)
 47 |     return response
 48 | 
 49 | 
 50 | def get_streaming_response(response: requests.Response) -> Iterable[list[str]]:
 51 |     for chunk in response.iter_lines(
 52 |             chunk_size=8192, decode_unicode=False, delimiter=b"\n"
 53 |     ):
 54 |         if chunk:
 55 |             data = json.loads(chunk.decode("utf-8"))
 56 |             output = data["text"]
 57 |             yield output
 58 | 
 59 | 
 60 | def get_response(response: requests.Response) -> list[str]:
 61 |     data = json.loads(response.content)
 62 |     output = data["text"]
 63 |     return output
 64 | 
 65 | 
 66 | def parse_args():
 67 |     parser = argparse.ArgumentParser()
 68 |     parser.add_argument("--host", type=str, default="localhost")
 69 |     parser.add_argument("--port", type=int, default=8000)
 70 |     parser.add_argument("--prompt", type=str, default="Parse the reading order of this document.")
 71 |     parser.add_argument("--image_path", type=str, default="./demo/page_imgs/page_1.jpeg")
 72 |     parser.add_argument("--stream", action="store_true")
 73 |     return parser.parse_args()
 74 | 
 75 | 
 76 | def main(args: Namespace):
 77 |     prompt = args.prompt
 78 |     image_path = args.image_path
 79 |     api_url = f"http://{args.host}:{args.port}/generate"
 80 |     stream = args.stream
 81 | 
 82 |     print(f"Prompt: {prompt!r}\n", flush=True)
 83 |     response = post_http_request(prompt, image_path, api_url, stream)
 84 | 
 85 |     if stream:
 86 |         num_printed_lines = 0
 87 |         for h in get_streaming_response(response):
 88 |             clear_line(num_printed_lines)
 89 |             num_printed_lines = 0
 90 |             for i, line in enumerate(h):
 91 |                 num_printed_lines += 1
 92 |                 print(f"Response {i}: {line!r}", flush=True)
 93 |     else:
 94 |         output = get_response(response)
 95 |         print(f"Response: {output!r}", flush=True)
 96 | 
 97 | 
 98 | if __name__ == "__main__":
 99 |     args = parse_args()
100 |     main(args)
101 | 


--------------------------------------------------------------------------------
/deployment/tensorrt_llm/api_server.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/env python
  2 | # copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.1/examples/apps/fastapi_server.py
  3 | 
  4 | import asyncio
  5 | import base64
  6 | import io
  7 | import logging
  8 | import signal
  9 | from http import HTTPStatus
 10 | from PIL import Image
 11 | from typing import Optional
 12 | 
 13 | import click
 14 | import uvicorn
 15 | from fastapi import FastAPI, Request
 16 | from fastapi.responses import JSONResponse, Response
 17 | 
 18 | from tensorrt_llm.executor import CppExecutorError, RequestError
 19 | from dolphin_runner import DolphinRunner, InferenceConfig
 20 | 
 21 | TIMEOUT_KEEP_ALIVE = 5  # seconds.
 22 | 
 23 | 
 24 | async def decode_image(image_base64: str) -> Image.Image:
 25 |     image_data = base64.b64decode(image_base64)
 26 |     image = Image.open(io.BytesIO(image_data))
 27 |     return image
 28 | 
 29 | 
 30 | class LlmServer:
 31 |     def __init__(self, runner: DolphinRunner):
 32 |         self.runner = runner
 33 |         self.app = FastAPI()
 34 |         self.register_routes()
 35 | 
 36 |     def register_routes(self):
 37 |         self.app.add_api_route("/health", self.health, methods=["GET"])
 38 |         self.app.add_api_route("/generate", self.generate, methods=["POST"])
 39 | 
 40 |     async def health(self) -> Response:
 41 |         return Response(status_code=200)
 42 | 
 43 |     async def generate(self, request: Request) -> Response:
 44 |         """ Generate completion for the request.
 45 | 
 46 |         The request should be a JSON object with the following fields:
 47 |         - prompt: the prompt to use for the generation.
 48 |         - image_base64: the image to use for the generation.
 49 |         """
 50 |         request_dict = await request.json()
 51 | 
 52 |         prompt = request_dict.pop("prompt", "")
 53 |         logging.info(f"request prompt: {prompt}")
 54 |         image_base64 = request_dict.pop("image_base64", "")
 55 |         image = await decode_image(image_base64)
 56 | 
 57 |         try:
 58 |             output_texts = self.runner.run([prompt], [image], self.runner.args.max_new_tokens)  # honor --max_new_tokens
 59 |             output_texts = [texts[0] for texts in output_texts]
 60 |             return JSONResponse({"text": output_texts[0]})
 61 |         except RequestError as e:
 62 |             return JSONResponse(content=str(e),
 63 |                                 status_code=HTTPStatus.BAD_REQUEST)
 64 |         except CppExecutorError:
 65 |             # If internal executor error is raised, shutdown the server
 66 |             signal.raise_signal(signal.SIGINT)
 67 | 
 68 |     async def __call__(self, host, port):
 69 |         config = uvicorn.Config(self.app,
 70 |                                 host=host,
 71 |                                 port=port,
 72 |                                 log_level="info",
 73 |                                 timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
 74 |         await uvicorn.Server(config).serve()
 75 | 
 76 | 
 77 | @click.command()
 78 | @click.option("--hf_model_dir", type=str, required=True)
 79 | @click.option("--visual_engine_dir", type=str, required=True)
 80 | @click.option("--llm_engine_dir", type=str, required=True)
 81 | @click.option("--max_batch_size", type=int, default=16)
 82 | @click.option("--max_new_tokens", type=int, default=4024)
 83 | @click.option("--host", type=str, default=None)
 84 | @click.option("--port", type=int, default=8000)
 85 | def entrypoint(hf_model_dir: str,
 86 |                visual_engine_dir: str,
 87 |                llm_engine_dir: str,
 88 |                max_batch_size: int,
 89 |                max_new_tokens: int,
 90 |                host: Optional[str] = None,
 91 |                port: int = 8000):
 92 |     host = host or "0.0.0.0"
 93 |     port = port or 8000
 94 |     logging.info(f"Starting server at {host}:{port}")
 95 | 
 96 |     config = InferenceConfig(
 97 |         max_new_tokens=max_new_tokens,
 98 |         batch_size=max_batch_size,
 99 |         log_level="info",
100 |         hf_model_dir=hf_model_dir,
101 |         visual_engine_dir=visual_engine_dir,
102 |         llm_engine_dir=llm_engine_dir,
103 |     )
104 | 
105 |     dolphin_runner = DolphinRunner(config)
106 |     server = LlmServer(runner=dolphin_runner)
107 | 
108 |     asyncio.run(server(host, port))
109 | 
110 | 
111 | if __name__ == "__main__":
112 |     entrypoint()


--------------------------------------------------------------------------------
/deployment/tensorrt_llm/convert/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/deployment/tensorrt_llm/convert/__init__.py


--------------------------------------------------------------------------------
/deployment/tensorrt_llm/convert/build_visual_engine.py:
--------------------------------------------------------------------------------
 1 | # copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.2/examples/multimodal/build_visual_engine.py
 2 | 
 3 | import argparse
 4 | 
 5 | from tensorrt_llm.tools.multimodal_builder import (VisionEngineBuilder,
 6 |                                                    add_multimodal_arguments)
 7 | 
 8 | if __name__ == '__main__':
 9 |     parser = argparse.ArgumentParser()
10 |     parser = add_multimodal_arguments(parser)
11 |     args = parser.parse_args()
12 | 
13 |     builder = VisionEngineBuilder(args)
14 |     builder.build()
15 | 


--------------------------------------------------------------------------------
/deployment/tensorrt_llm/convert/helper.py:
--------------------------------------------------------------------------------
 1 | # copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.1/examples/enc_dec/helper.py
 2 | 
 3 | import typing
 4 | from typing import Union
 5 | 
 6 | import numpy as np
 7 | import torch  # pytype: disable=import-error
 8 | 
 9 | from tensorrt_llm._utils import str_dtype_to_torch
10 | 
11 | 
12 | def split(v: Union[np.ndarray, torch.Tensor],
13 |           tp_size: int,
14 |           tp_rank: int,
15 |           dim=0):
16 |     if tp_size == 1:
17 |         if isinstance(v, np.ndarray):
18 |             return np.ascontiguousarray(v.copy())
19 |         else:
20 |             return v.clone().detach()
21 |     assert len(v.shape) > 1 or dim == 0
22 |     if isinstance(v, np.ndarray):
23 |         return np.ascontiguousarray(
24 |             np.split(v, tp_size, axis=dim)[tp_rank].copy())
25 |     else:
 26 |         assert v.shape[dim] % tp_size == 0, \
 27 |             f'Unable to split: shape={v.shape} (dim={dim}) tp_size={tp_size}.'
28 |         split_size = v.shape[dim] // tp_size
29 |         return v.split(split_size, dim=dim)[tp_rank].clone().detach()
30 | 
31 | 
32 | def reshape(v: torch.Tensor, shape=None):
33 |     if shape is None:
34 |         return v.contiguous()
35 |     else:
36 |         return v.reshape(shape).contiguous()
37 | 
38 | 
39 | def fuse_qkv_one_layer(params, attn_module_name, trtllm_layer_name, tp_size,
40 |                        tp_rank, model_type, weight_shape, bias_shape):
41 | 
42 |     qkv_module_names = get_qkv_module_name(model_type)
43 | 
44 |     weight = {}
45 | 
46 |     # fuse weights of q, k, v
47 |     q_w = params[f'{attn_module_name}.{qkv_module_names["q"]}.weight']
48 |     k_w = params[f'{attn_module_name}.{qkv_module_names["k"]}.weight']
49 |     v_w = params[f'{attn_module_name}.{qkv_module_names["v"]}.weight']
50 | 
51 |     # fuse qkv weight
52 |     shape = q_w.shape  # (do, din)
53 |     qkv_w = torch.cat([q_w, k_w, v_w],
54 |                       dim=0).reshape([3, shape[0], shape[1]])  # (3, do, din)
55 |     qkv_w = split(qkv_w, tp_size, tp_rank, dim=1)
56 |     weight[f'{trtllm_layer_name}.qkv.weight'] = reshape(qkv_w,
57 |                                                         shape=weight_shape)
58 | 
59 |     # fuse qkv biases if present
60 |     if f'{attn_module_name}.{qkv_module_names["q"]}.bias' in params.keys(
61 |     ) and params[f'{attn_module_name}.{qkv_module_names["q"]}.bias'] is not None:
62 |         q_b = params[f'{attn_module_name}.{qkv_module_names["q"]}.bias']
63 |         k_b = params[f'{attn_module_name}.{qkv_module_names["k"]}.bias']
64 |         v_b = params[f'{attn_module_name}.{qkv_module_names["v"]}.bias']
65 |         shape = q_b.shape[0]  # (do,)
66 |         qkv_b = torch.cat([q_b, k_b, v_b], dim=0).reshape([3, shape])  # (3, do)
67 |         qkv_b = split(qkv_b, tp_size, tp_rank, dim=1)
68 |         weight[f'{trtllm_layer_name}.qkv.bias'] = reshape(qkv_b,
69 |                                                           shape=bias_shape)
70 |     return weight
71 | 
72 | 
73 | def get_qkv_module_name(model_type):
74 |     if model_type in ["t5", "blip2"]:
75 |         q = "q"
76 |         k = "k"
77 |         v = "v"
78 |     elif model_type == "bart" or model_type == "nmt":
79 |         q = "q_proj"
80 |         k = "k_proj"
81 |         v = "v_proj"
82 |     elif model_type == "pix2struct":
83 |         q = "query"
84 |         k = "key"
85 |         v = "value"
86 |     return {"q": q, "k": k, "v": v}
87 | 
88 | 
89 | def convert_weight_to_dtype(params: typing.Dict[str, torch.Tensor],
90 |                             dtype: typing.Optional[np.dtype] = None):
91 |     if dtype is not None:
92 |         assert isinstance(dtype,
93 |                           str), f"dtype must be str, but get type {type(dtype)}"
94 |         for name in params.keys():
95 |             params[name] = params[name].to(str_dtype_to_torch(dtype))
96 | 


--------------------------------------------------------------------------------
/deployment/tensorrt_llm/convert_dolphin.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | set -ex
 3 | 
 4 | ############################################################################################
 5 | # Reference: https://github.com/NVIDIA/TensorRT-LLM/tree/v0.18.2/examples/multimodal#nougat
 6 | ############################################################################################
 7 | 
 8 | export LD_LIBRARY_PATH=/usr/local/lib/python3.10/site-packages/tensorrt_libs/:/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:$LD_LIBRARY_PATH
 9 | 
10 | # 1. Download Huggingface weights
11 | export MODEL_NAME="Dolphin"
12 | git clone https://huggingface.co/Bytedance/${MODEL_NAME} tmp/hf_models/${MODEL_NAME}
13 | 
14 | 
15 | export MAX_BATCH_SIZE=16
16 | export MAX_SEQ_LEN=4096
17 | export MAX_INPUT_LEN=10
18 | export MAX_ENCODER_INPUT_LEN=784
19 | 
20 | # 2. Convert Huggingface weights into TRT-LLM checkpoints and build TRT engines using scripts in examples/enc_dec
21 | python ./convert/convert_checkpoint.py --model_type bart \
22 |     --model_dir tmp/hf_models/${MODEL_NAME} \
23 |     --output_dir tmp/trt_models/${MODEL_NAME}/bfloat16 \
24 |     --tp_size 1 \
25 |     --pp_size 1 \
26 |     --dtype bfloat16 \
27 |     --nougat
28 | 
29 | 
30 | trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/bfloat16/decoder \
31 |     --output_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16/decoder \
32 |     --paged_kv_cache disable \
33 |     --moe_plugin disable \
34 |     --gemm_plugin bfloat16 \
35 |     --bert_attention_plugin bfloat16 \
36 |     --gpt_attention_plugin bfloat16 \
37 |     --remove_input_padding enable \
38 |     --max_beam_width 1 \
39 |     --max_batch_size ${MAX_BATCH_SIZE} \
40 |     --max_seq_len ${MAX_SEQ_LEN} \
41 |     --max_input_len ${MAX_INPUT_LEN} \
 42 |     --max_encoder_input_len $((${MAX_BATCH_SIZE} * ${MAX_ENCODER_INPUT_LEN})) # max_batch_size * num_visual_features: 16 * 784 = 12544 with the defaults above
43 | 
44 | # 3. Generate TensorRT engines for visual components and combine everything into final pipeline.
45 | python ./convert/build_visual_engine.py --model_type nougat \
46 |     --model_path tmp/hf_models/${MODEL_NAME} \
47 |     --max_batch_size ${MAX_BATCH_SIZE}


--------------------------------------------------------------------------------
/deployment/tensorrt_llm/dolphin_runner.py:
--------------------------------------------------------------------------------
  1 | """
  2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
  3 | SPDX-License-Identifier: MIT
  4 | """
  5 | 
  6 | import json
  7 | import os
  8 | from typing import Optional
  9 | 
 10 | import tensorrt_llm
 11 | import tensorrt_llm.profiler as profiler
 12 | import torch
 13 | from PIL import Image
 14 | from pydantic import BaseModel, Field
 15 | from tensorrt_llm import logger
 16 | from tensorrt_llm import mpi_rank
 17 | from tensorrt_llm.runtime import MultimodalModelRunner
 18 | from transformers import AutoTokenizer, DonutProcessor
 19 | 
 20 | 
 21 | class InferenceConfig(BaseModel):
 22 |     max_new_tokens: int = Field(128, description="Maximum new tokens to generate")
 23 |     batch_size: int = Field(1, description="Batch size for inference")
 24 |     log_level: str = Field("info", description="Logging level")
 25 |     visual_engine_dir: Optional[str] = Field(None, description="Directory for visual engine files")
 26 |     visual_engine_name: str = Field("model.engine", description="Visual engine filename")
 27 |     llm_engine_dir: Optional[str] = Field(None, description="Directory for LLM engine files")
 28 |     hf_model_dir: Optional[str] = Field(None, description="Hugging Face model directory")
 29 |     input_text: Optional[str] = Field(None, description="Input text for inference")
 30 |     num_beams: int = Field(1, description="Number of beams for beam search")
 31 |     top_k: int = Field(1, description="Top-k sampling value")
 32 |     top_p: float = Field(0.0, description="Top-p (nucleus) sampling value")
 33 |     temperature: float = Field(1.0, description="Sampling temperature")
 34 |     repetition_penalty: float = Field(1.0, description="Repetition penalty factor")
 35 |     run_profiling: bool = Field(False, description="Enable profiling mode")
 36 |     profiling_iterations: int = Field(20, description="Number of profiling iterations")
 37 |     check_accuracy: bool = Field(False, description="Enable accuracy checking")
 38 |     video_path: Optional[str] = Field(None, description="Path to input video file")
 39 |     video_num_frames: Optional[int] = Field(None, description="Number of video frames to process")
 40 |     image_path: Optional[str] = Field(None, description="Path to input image file")
 41 |     path_sep: str = Field(",", description="Path separator character")
 42 |     prompt_sep: str = Field(",", description="Prompt separator character")
 43 |     enable_context_fmha_fp32_acc: Optional[bool] = Field(
 44 |         None,
 45 |         description="Enable FP32 accumulation for context FMHA"
 46 |     )
 47 |     enable_chunked_context: bool = Field(False, description="Enable chunked context processing")
 48 |     use_py_session: bool = Field(False, description="Use Python session instead of C++")
 49 |     kv_cache_free_gpu_memory_fraction: float = Field(
 50 |         0.9,
 51 |         description="Fraction of GPU memory free for KV cache",
 52 |         ge=0.0, le=1.0
 53 |     )
 54 |     cross_kv_cache_fraction: float = Field(
 55 |         0.5,
 56 |         description="Fraction of cross-attention KV cache",
 57 |         ge=0.0, le=1.0
 58 |     )
 59 |     multi_block_mode: bool = Field(True, description="Enable multi-block processing mode")
 60 | 
 61 | 
 62 | class DolphinRunner(MultimodalModelRunner):
 63 |     def __init__(self, args):
 64 |         self.args = args
 65 | 
 66 |         self.runtime_rank = mpi_rank()
 67 |         device_id = self.runtime_rank % torch.cuda.device_count()
 68 |         torch.cuda.set_device(device_id)
 69 |         self.device = "cuda:%d" % (device_id)
 70 | 
 71 |         self.stream = torch.cuda.Stream(torch.cuda.current_device())
 72 |         torch.cuda.set_stream(self.stream)
 73 | 
 74 |         # parse model type from visual engine config
 75 |         with open(os.path.join(self.args.visual_engine_dir, "config.json"),
 76 |                   "r") as f:
 77 |             config = json.load(f)
 78 |         self.model_type = config['builder_config']['model_type']
 79 |         self.vision_precision = config['builder_config']['precision']
 80 |         self.decoder_llm = not (
 81 |                 't5' in self.model_type
 82 |                 or self.model_type in ['nougat', 'pix2struct']
 83 |         )  # BLIP2-T5, pix2struct and Nougat use encoder-decoder models as LLMs
 84 | 
 85 |         if self.model_type == "mllama":
 86 |             self.vision_input_names = [
 87 |                 "pixel_values",
 88 |                 "aspect_ratio_ids",
 89 |                 "aspect_ratio_mask",
 90 |             ]
 91 |             self.vision_output_names = [
 92 |                 "output",
 93 |             ]
 94 |         else:
 95 |             self.vision_input_names = ["input"]
 96 |             self.vision_output_names = ["output"]
 97 | 
 98 |         self.use_py_session = True  # always use the Python session for Dolphin
 99 | 
100 |         self.init_image_encoder()
101 |         self.init_tokenizer()
102 |         self.init_processor()
103 |         self.init_llm()
104 | 
105 |     def init_tokenizer(self):
106 |         assert self.model_type == 'nougat'
107 |         self.tokenizer = AutoTokenizer.from_pretrained(self.args.hf_model_dir)
108 |         self.tokenizer.padding_side = "right"
109 | 
110 |     def init_processor(self):
111 |         assert self.model_type == 'nougat'
112 |         self.processor = DonutProcessor.from_pretrained(self.args.hf_model_dir, use_fast=True)
113 | 
114 |     def run(self, input_texts, input_images, max_new_tokens):
115 |         prompts = [f"<s>{text.strip()} <Answer/>" for text in input_texts]
116 |         images = self.processor(input_images, return_tensors="pt")['pixel_values'].to("cuda")
117 |         prompt_ids = self.tokenizer(prompts, add_special_tokens=False, return_tensors="pt").input_ids.to("cuda")
118 | 
119 |         # 🚨🚨🚨 Important! If the type of prompt_ids is not int32, the output will be wrong. 🚨🚨🚨
120 |         prompt_ids = prompt_ids.to(torch.int32)
121 | 
122 |         logger.info("---------------------------------------------------------")
123 |         logger.info(f"images size: {images.size()}")
124 |         logger.info(f"prompt_ids: {prompt_ids}, size: {prompt_ids.size()}, dtype: {prompt_ids.dtype}")
125 |         logger.info("---------------------------------------------------------")
126 | 
127 |         output_texts = self.generate(input_texts,
128 |                                      [None] * len(input_texts),
129 |                                      images,
130 |                                      prompt_ids,
131 |                                      max_new_tokens,
132 |                                      warmup=False,
133 |                                      )
134 | 
135 |         return output_texts
136 | 
137 |     def generate(self,
138 |                  pre_prompt,
139 |                  post_prompt,
140 |                  image,
141 |                  decoder_input_ids,
142 |                  max_new_tokens,
143 |                  warmup=False,
144 |                  other_vision_inputs={},
145 |                  other_decoder_inputs={}):
146 |         if not warmup:
147 |             profiler.start("Generate")
148 |         input_ids, input_lengths, ptuning_args, visual_features = self.preprocess(
149 |             warmup, pre_prompt, post_prompt, image, other_vision_inputs)
150 | 
151 |         if warmup: return None
152 | 
153 |         # use prompt tuning to pass multimodal features
154 |         # model.generate() expects the following params (see layers/embedding.py):
155 |         # args[0]: prompt embedding table, [batch_size, multimodal_len, hidden_size], later flattened to [batch_size * multimodal_len, hidden_size]
156 |         # args[1]: prompt task ids, [batch_size]. in multimodal case, arange(batch_size), i.e. in VILA batching mode 2, each image is treated separately in the batch instead of concated together (although the prompt embedding table has to be concated)
157 |         # args[2]: prompt task vocab size, [1]. assuming all table has the same length, which in multimodal case equals to multimodal_len
158 |         profiler.start("LLM")
159 |         if self.model_type in ['nougat', 'pix2struct']:
160 |             # Trim encoder input_ids to match visual features shape
161 |             ids_shape = (min(self.args.batch_size, len(pre_prompt)), visual_features.shape[1])
162 |             if self.model_type == 'nougat':
163 |                 input_ids = torch.zeros(ids_shape, dtype=torch.int32)
164 |             elif self.model_type == 'pix2struct':
165 |                 input_ids = torch.ones(ids_shape, dtype=torch.int32)
166 | 
167 |         output_ids = self.model.generate(
168 |             input_ids,
169 |             decoder_input_ids,
170 |             max_new_tokens,
171 |             num_beams=self.args.num_beams,
172 |             bos_token_id=self.tokenizer.bos_token_id,
173 |             pad_token_id=self.tokenizer.pad_token_id,
174 |             eos_token_id=self.tokenizer.eos_token_id,
175 |             debug_mode=False,
176 |             prompt_embedding_table=ptuning_args[0],
177 |             prompt_tasks=ptuning_args[1],
178 |             prompt_vocab_size=ptuning_args[2],
179 |         )
180 |         profiler.stop("LLM")
181 | 
182 |         if mpi_rank() == 0:
183 |             # Extract a list of tensors of shape beam_width x output_ids.
184 |             output_beams_list = [
185 |                 self.tokenizer.batch_decode(
186 |                     output_ids[batch_idx, :, decoder_input_ids.shape[1]:],
187 |                     skip_special_tokens=False) for batch_idx in range(
188 |                     min(self.args.batch_size, decoder_input_ids.shape[0]))
189 |             ]
190 | 
191 |             stripped_text = [[
192 |                 output_beams_list[batch_idx][beam_idx].replace("</s>", "").replace("<pad>", "").strip()
193 |                 for beam_idx in range(self.args.num_beams)
194 |             ] for batch_idx in range(
195 |                 min(self.args.batch_size, decoder_input_ids.shape[0]))]
196 |             profiler.stop("Generate")
197 |             return stripped_text
198 |         else:
199 |             profiler.stop("Generate")
200 |             return None
201 | 
202 | 
203 | if __name__ == "__main__":
204 |     config = InferenceConfig(
205 |         max_new_tokens=4024,
206 |         batch_size=16,
207 |         log_level="info",
208 |         hf_model_dir=f"./tmp/hf_models/Dolphin",
209 |         visual_engine_dir=f"./tmp/trt_engines/Dolphin/vision_encoder",
210 |         llm_engine_dir=f"./tmp/trt_engines/Dolphin/1-gpu/bfloat16",
211 |     )
212 | 
213 |     model = DolphinRunner(config)
214 | 
215 |     image_path = "../../demo/page_imgs/page_1.jpeg"
216 |     prompt = "Parse the reading order of this document."
217 |     image = Image.open(image_path).convert("RGB")
218 |     output_texts = model.run([prompt], [image], 4024)
219 |     output_texts = [texts[0] for texts in output_texts]
220 |     print(output_texts)
221 | 


--------------------------------------------------------------------------------
/deployment/tensorrt_llm/run_dolphin.py:
--------------------------------------------------------------------------------
  1 | """
  2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
  3 | SPDX-License-Identifier: MIT
  4 | """
  5 | 
  6 | import argparse
  7 | import os
  8 | 
  9 | import tensorrt_llm
 10 | import tensorrt_llm.profiler as profiler
 11 | from PIL import Image
 12 | from tensorrt_llm import logger
 13 | from tensorrt_llm import mpi_rank
 14 | from tensorrt_llm.runtime import MultimodalModelRunner
 15 | 
 16 | from dolphin_runner import DolphinRunner
 17 | from utils import add_common_args
 18 | 
 19 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
 20 | 
 21 | 
 22 | def print_result(model, input_text, output_text, args):
 23 |     logger.info("---------------------------------------------------------")
 24 |     logger.info(f"\n[Q] {input_text}")
 25 |     for i in range(len(output_text)):
 26 |         logger.info(f"\n[A]: {output_text[i]}")
 27 | 
 28 |     if args.num_beams == 1:
 29 |         output_ids = model.tokenizer(output_text[0][0],
 30 |                                      add_special_tokens=False)['input_ids']
 31 |         logger.info(f"Generated {len(output_ids)} tokens")
 32 | 
 33 |     if args.check_accuracy:
 34 |         if model.model_type != 'nougat':
 35 |             if model.model_type == "vila":
 36 |                 for i in range(len(args.image_path.split(args.path_sep))):
 37 |                     if i % 2 == 0:
 38 |                         assert output_text[i][0].lower(
 39 |                         ) == "the image captures a bustling city intersection teeming with life. from the perspective of a car's dashboard camera, we see"
 40 |                     else:
 41 |                         assert output_text[i][0].lower(
 42 |                         ) == "the image captures the iconic merlion statue in singapore, a renowned worldwide landmark. the merlion, a mythical"
 43 |             elif model.model_type == "llava":
 44 |                 for i in range(len(args.image_path.split(args.path_sep))):
 45 |                     assert output_text[i][0].lower() == 'singapore'
 46 |             elif model.model_type == 'fuyu':
 47 |                 assert output_text[0][0].lower() == '4'
 48 |             elif model.model_type == "pix2struct":
 49 |                 assert "characteristic | cat food, day | cat food, wet | cat treats" in output_text[
 50 |                     0][0].lower()
 51 |             elif model.model_type in [
 52 |                     'blip2', 'neva', 'phi-3-vision', 'llava_next'
 53 |             ]:
 54 |                 assert 'singapore' in output_text[0][0].lower()
 55 |             elif model.model_type == 'video-neva':
 56 |                 assert 'robot' in output_text[0][0].lower()
 57 |             elif model.model_type == 'kosmos-2':
 58 |                 assert 'snowman' in output_text[0][0].lower()
 59 |             elif model.model_type == "mllama":
 60 |                 if "If I had to write a haiku for this one" in input_text:
 61 |                     assert "it would be:.\\nPeter Rabbit is a rabbit.\\nHe lives in a" in output_text[
 62 |                         0][0] or "Here is a haiku for the image:\n\n" in output_text[
 63 |                             0][0], f"expected results: 'it would be:.\\nPeter Rabbit is a rabbit.\\nHe lives in a', generated results: '{output_text[0][0]}'"
 64 |                 elif "The key to life is" in input_text:
 65 |                     assert "to find your passion and pursue it with all your heart." in output_text[
 66 |                         0][0] or "not to be found in the external world," in output_text[
 67 |                             0][0], f"expected results: 'to find your passion and pursue it with all your heart.', generated results: '{output_text[0][0]}'"
 68 |             elif model.model_type == 'llava_onevision':
 69 |                 if args.video_path is None:
 70 |                     assert 'singapore' in output_text[0][0].lower()
 71 |                 else:
 72 |                     assert 'the video is funny because the child\'s actions are' in output_text[
 73 |                         0][0].lower()
 74 |             elif model.model_type == "qwen2_vl":
 75 |                 assert 'dog' in output_text[0][0].lower()
 76 |             else:
 77 |                 assert output_text[0][0].lower() == 'singapore'
 78 | 
 79 |     if args.run_profiling:
 80 |         msec_per_batch = lambda name: 1000 * profiler.elapsed_time_in_sec(
 81 |             name) / args.profiling_iterations
 82 |         logger.info('Latencies per batch (msec)')
 83 |         logger.info('TRT vision encoder: %.1f' % (msec_per_batch('Vision')))
 84 |         logger.info('TRTLLM LLM generate: %.1f' % (msec_per_batch('LLM')))
 85 |         logger.info('Multimodal generate: %.1f' % (msec_per_batch('Generate')))
 86 | 
 87 |     logger.info("---------------------------------------------------------")
 88 | 
 89 | 
 90 | if __name__ == '__main__':
 91 |     parser = argparse.ArgumentParser()
 92 |     parser = add_common_args(parser)
 93 |     args = parser.parse_args()
 94 |     logger.set_level(args.log_level)
 95 | 
 96 |     model = DolphinRunner(args)
 97 | 
 98 |     input_image = Image.open(args.image_path[0]).convert('RGB')
 99 |     num_iters = args.profiling_iterations if args.run_profiling else 1
100 | 
101 |     for _ in range(num_iters):
102 |         output_texts = model.run(args.input_text, [input_image], args.max_new_tokens)
103 | 
104 |     runtime_rank = tensorrt_llm.mpi_rank()
105 |     if runtime_rank == 0:
106 |         print_result(model, args.input_text, output_texts, args)
107 | 


--------------------------------------------------------------------------------
/deployment/tensorrt_llm/run_dolphin.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | set -ex
 3 | 
 4 | export MODEL_NAME="Dolphin"
 5 | 
 6 | python run_dolphin.py \
 7 |     --batch_size 1 \
 8 |     --hf_model_dir tmp/hf_models/${MODEL_NAME} \
 9 |     --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
10 |     --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
11 |     --max_new_tokens 4096 \
12 |     --repetition_penalty 1.0 \
13 |     --input_text "Parse the reading order of this document." \
14 |     --image_path "../../demo/page_imgs/page_1.jpeg"
15 | 
16 | 
17 | python run_dolphin.py \
18 |     --batch_size 1 \
19 |     --hf_model_dir tmp/hf_models/${MODEL_NAME} \
20 |     --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
21 |     --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
22 |     --max_new_tokens 4096 \
23 |     --repetition_penalty 1.0 \
24 |     --input_text "Read text in the image." \
25 |     --image_path "../../demo/element_imgs/block_formula.jpeg"
26 | 
27 | 
28 | python run_dolphin.py \
29 |     --batch_size 1 \
30 |     --hf_model_dir tmp/hf_models/${MODEL_NAME} \
31 |     --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
32 |     --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
33 |     --max_new_tokens 4096 \
34 |     --repetition_penalty 1.0 \
35 |     --input_text "Read text in the image." \
36 |     --image_path "../../demo/element_imgs/para_1.jpg"
37 | 
38 | 
39 | python run_dolphin.py \
40 |     --batch_size 1 \
41 |     --hf_model_dir tmp/hf_models/${MODEL_NAME} \
42 |     --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
43 |     --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
44 |     --max_new_tokens 4096 \
45 |     --repetition_penalty 1.0 \
46 |     --input_text "Parse the table in the image." \
47 |     --image_path "../../demo/element_imgs/table_1.jpeg"
48 | 


--------------------------------------------------------------------------------
/deployment/tensorrt_llm/start_dolphin_server.sh:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env bash
 2 | set -ex
 3 | 
 4 | export MODEL_NAME="Dolphin"
 5 | 
 6 | python api_server.py \
 7 |     --hf_model_dir tmp/hf_models/${MODEL_NAME} \
 8 |     --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \
 9 |     --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \
10 |     --max_batch_size 16


--------------------------------------------------------------------------------
/deployment/tensorrt_llm/utils.py:
--------------------------------------------------------------------------------
 1 | def add_common_args(parser):
 2 |     parser.add_argument('--max_new_tokens', type=int, default=128)
 3 |     parser.add_argument('--batch_size', type=int, default=1)
 4 |     parser.add_argument('--log_level', type=str, default='info')
 5 |     parser.add_argument('--visual_engine_dir',
 6 |                         type=str,
 7 |                         default=None,
 8 |                         help='Directory containing visual TRT engines')
 9 |     parser.add_argument('--visual_engine_name',
10 |                         type=str,
11 |                         default='model.engine',
12 |                         help='Name of visual TRT engine')
13 |     parser.add_argument('--llm_engine_dir',
14 |                         type=str,
15 |                         default=None,
16 |                         help='Directory containing TRT-LLM engines')
17 |     parser.add_argument('--hf_model_dir',
18 |                         type=str,
19 |                         default=None,
20 |                         help="Directory containing tokenizer")
21 |     parser.add_argument('--input_text',
22 |                         type=str,
23 |                         nargs='+',
24 |                         default=None,
25 |                         help='Text prompt to LLM')
26 |     parser.add_argument('--num_beams',
27 |                         type=int,
28 |                         help="Use beam search if num_beams >1",
29 |                         default=1)
30 |     parser.add_argument('--top_k', type=int, default=1)
31 |     parser.add_argument('--top_p', type=float, default=0.0)
32 |     parser.add_argument('--temperature', type=float, default=1.0)
33 |     parser.add_argument('--repetition_penalty', type=float, default=1.0)
34 |     parser.add_argument('--run_profiling',
35 |                         action='store_true',
36 |                         help='Profile runtime over several iterations')
37 |     parser.add_argument('--profiling_iterations',
38 |                         type=int,
39 |                         help="Number of iterations to run profiling",
40 |                         default=20)
41 |     parser.add_argument('--check_accuracy',
42 |                         action='store_true',
43 |                         help='Check correctness of text output')
44 |     parser.add_argument("--image_path",
45 |                         type=str,
46 |                         nargs='+',
47 |                         default=None,
48 |                         help='List of input image paths, separated by the path separator symbol')
49 |     parser.add_argument("--path_sep",
50 |                         type=str,
51 |                         default=",",
52 |                         help='Path separator symbol')
53 |     parser.add_argument("--prompt_sep",
54 |                         type=str,
55 |                         default=",",
56 |                         help="Prompt separator symbol")
57 |     parser.add_argument('--enable_context_fmha_fp32_acc',
58 |                         action='store_true',
59 |                         default=None,
60 |                         help="Enable FMHA runner FP32 accumulation.")
61 |     parser.add_argument(
62 |         '--enable_chunked_context',
63 |         action='store_true',
64 |         help='Enables chunked context (only available with cpp session).',
65 |     )
66 |     parser.add_argument(
67 |         '--use_py_session',
68 |         default=False,
69 |         action='store_true',
70 |         help=
71 |         "Whether or not to use Python runtime session. By default C++ runtime session is used for the LLM."
72 |     )
73 |     parser.add_argument(
74 |         '--kv_cache_free_gpu_memory_fraction',
75 |         default=0.9,
76 |         type=float,
77 |         help='Specify the free gpu memory fraction.',
78 |     )
79 |     parser.add_argument(
80 |         '--cross_kv_cache_fraction',
81 |         default=0.5,
82 |         type=float,
83 |         help=
84 |         'Specify the kv cache fraction reserved for cross attention. Only applicable for encoder-decoder models. By default 0.5 for self and 0.5 for cross.',
85 |     )
86 |     parser.add_argument(
87 |         '--multi_block_mode',
88 |         type=lambda s: s.lower() in
89 |         ("yes", "true", "t", "1"
90 |          ),  # custom boolean function to convert input string to boolean
91 |         default=True,
92 |         help=
93 |         "Distribute the work across multiple CUDA thread-blocks on the GPU for masked MHA kernel."
94 |     )
95 |     return parser
96 | 


--------------------------------------------------------------------------------
/deployment/vllm/ReadMe.md:
--------------------------------------------------------------------------------
 1 | <h1 align="center">
 2 | 🚀 Dolphin vLLM Demo
 3 | </h1>
 4 | 
 5 | ## ✅ Introduction
 6 | The Dolphin model employs a **Swin Encoder + MBart Decoder** architecture. In the HuggingFace Transformers [Config](https://huggingface.co/ByteDance/Dolphin/blob/main/config.json),
 7 | its `architectures` field is specified as "VisionEncoderDecoderModel", which vLLM does not natively support.
 8 | To enable vLLM deployment of the Dolphin model, we implemented two vLLM plugins: [vllm-dolphin](https://github.com/hanyd2010/vllm-dolphin)[![PyPI version](https://img.shields.io/pypi/v/vllm-dolphin)](https://pypi.org/project/vllm-dolphin/) and [vllm-mbart](https://github.com/hanyd2010/vllm-mbart)[![PyPI version](https://img.shields.io/pypi/v/vllm-mbart)](https://pypi.org/project/vllm-mbart/).
 9 | We also provide Dolphin vLLM demos for both offline inference and online deployment.
10 | 
11 | ## 🛠️ Installation
12 | 
13 | ```
14 | # Install vllm
15 | pip install "vllm>=0.9.0"
16 | 
17 | # Install vllm-dolphin
18 | pip install vllm-dolphin==0.1
19 | ```
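
The demo scripts load the plugin with an explicit `import vllm_dolphin` (see the top of `demo_vllm.py`), so a quick way to verify the install is:

```
python -c "import vllm_dolphin"
```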
20 | 
21 | ## ⚡ Offline Inference
22 | ```
23 | # predict elements reading order
24 | python deployment/vllm/demo_vllm.py --model ByteDance/Dolphin --image_path ./demo/page_imgs/page_1.jpeg --prompt "Parse the reading order of this document."
25 | 
26 | # recognize text/latex
27 | python deployment/vllm/demo_vllm.py --model ByteDance/Dolphin --image_path ./demo/element_imgs/block_formula.jpeg --prompt "Read text in the image."
28 | python deployment/vllm/demo_vllm.py --model ByteDance/Dolphin --image_path ./demo/element_imgs/para_1.jpg --prompt "Read text in the image."
29 | 
30 | # recognize table
31 | python deployment/vllm/demo_vllm.py --model ByteDance/Dolphin --image_path ./demo/element_imgs/table_1.jpeg --prompt "Parse the table in the image."
32 | ```
33 | 
34 | 
35 | ## ⚡ Online Inference
36 | ```
37 | # 1. Start Api Server
38 | python deployment/vllm/api_server.py --model="ByteDance/Dolphin" --hf-overrides "{\"architectures\": [\"DolphinForConditionalGeneration\"]}"
39 | 
40 | # 2. Predict
41 | # predict elements reading order
42 | python deployment/vllm/api_client.py --image_path ./demo/page_imgs/page_1.jpeg --prompt "Parse the reading order of this document."
43 | 
44 | # recognize text/latex
45 | python deployment/vllm/api_client.py --image_path ./demo/element_imgs/block_formula.jpeg --prompt "Read text in the image."
46 | python deployment/vllm/api_client.py --image_path ./demo/element_imgs/para_1.jpg --prompt "Read text in the image."
47 | 
48 | # recognize table
49 | python deployment/vllm/api_client.py --image_path ./demo/element_imgs/table_1.jpeg --prompt "Parse the table in the image."
50 | ```
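
For reference, the payload accepted by the demo server's `/generate` endpoint mirrors what `api_client.py` sends. Below is a minimal sketch using `requests`; the prompt and image path are illustrative:

```
import base64
import requests

# Encode the input image as base64, as the demo server expects
with open("./demo/page_imgs/page_1.jpeg", "rb") as f:
    image_base64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "encoder_prompt": "",  # empty: the server substitutes Dolphin's dummy encoder prompt
    "decoder_prompt": "Parse the reading order of this document.",
    "image_base64": image_base64,
    "temperature": 0.0,
    "max_tokens": 2048,
    "stream": False,
}
response = requests.post("http://localhost:8000/generate", json=payload)
print(response.json()["text"][0])
```

The non-streaming response is a JSON object of the form `{"text": [...]}`, matching `get_response` in `api_client.py`.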


--------------------------------------------------------------------------------
/deployment/vllm/api_client.py:
--------------------------------------------------------------------------------
  1 | # SPDX-License-Identifier: Apache-2.0
  2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
  3 | """Example Python client for `vllm.entrypoints.api_server`
  4 | Start the demo server:
  5 |     python -m vllm.entrypoints.api_server --model <model_name>
  6 | 
  7 | NOTE: The API server is used only for demonstration and simple performance
  8 | benchmarks. It is not intended for production use.
  9 | For production use, we recommend `vllm serve` and the OpenAI client API.
 10 | """
 11 | 
 12 | import argparse
 13 | import base64
 14 | import json
 15 | from argparse import Namespace
 16 | from collections.abc import Iterable
 17 | 
 18 | import requests
 19 | 
 20 | 
 21 | def clear_line(n: int = 1) -> None:
 22 |     LINE_UP = "\033[1A"
 23 |     LINE_CLEAR = "\x1b[2K"
 24 |     for _ in range(n):
 25 |         print(LINE_UP, end=LINE_CLEAR, flush=True)
 26 | 
 27 | 
 28 | def encode_image_base64(image_path: str) -> str:
 29 |     """Encode local image to base64 format."""
 30 | 
 31 |     with open(image_path, "rb") as f:
 32 |         image_data = f.read()
 33 |         result = base64.b64encode(image_data).decode("utf-8")
 34 | 
 35 |     return result
 36 | 
 37 | 
 38 | def post_http_request(
 39 |         prompt: str, image_path: str, api_url: str, stream: bool = False
 40 | ) -> requests.Response:
 41 |     headers = {"User-Agent": "Test Client"}
 42 |     pload = {
 43 |         "encoder_prompt": "",
 44 |         "decoder_prompt": prompt,
 45 |         "image_base64": encode_image_base64(image_path),
 46 |         "temperature": 0.0,
 47 |         "max_tokens": 2048,
 48 |         "stream": stream,
 49 |     }
 50 |     response = requests.post(api_url, headers=headers, json=pload, stream=stream)
 51 |     return response
 52 | 
 53 | 
 54 | def get_streaming_response(response: requests.Response) -> Iterable[list[str]]:
 55 |     for chunk in response.iter_lines(
 56 |             chunk_size=8192, decode_unicode=False, delimiter=b"\n"
 57 |     ):
 58 |         if chunk:
 59 |             data = json.loads(chunk.decode("utf-8"))
 60 |             output = data["text"]
 61 |             yield output
 62 | 
 63 | 
 64 | def get_response(response: requests.Response) -> list[str]:
 65 |     data = json.loads(response.content)
 66 |     output = data["text"]
 67 |     return output
 68 | 
 69 | 
 70 | def parse_args():
 71 |     parser = argparse.ArgumentParser()
 72 |     parser.add_argument("--host", type=str, default="localhost")
 73 |     parser.add_argument("--port", type=int, default=8000)
 74 |     parser.add_argument("--prompt", type=str, default="Parse the reading order of this document.")
 75 |     parser.add_argument("--image_path", type=str, default="./demo/page_imgs/page_1.jpeg")
 76 |     parser.add_argument("--stream", action="store_true")
 77 |     return parser.parse_args()
 78 | 
 79 | 
 80 | def main(args: Namespace):
 81 |     prompt = args.prompt
 82 |     image_path = args.image_path
 83 |     api_url = f"http://{args.host}:{args.port}/generate"
 84 |     stream = args.stream
 85 | 
 86 |     print(f"Prompt: {prompt!r}\n", flush=True)
 87 |     response = post_http_request(prompt, image_path, api_url, stream)
 88 | 
 89 |     if stream:
 90 |         num_printed_lines = 0
 91 |         for h in get_streaming_response(response):
 92 |             clear_line(num_printed_lines)
 93 |             num_printed_lines = 0
 94 |             for i, line in enumerate(h):
 95 |                 num_printed_lines += 1
 96 |                 print(f"Response {i}: {line!r}", flush=True)
 97 |     else:
 98 |         output = get_response(response)
 99 |         print(f"Response: {output[0]!r}", flush=True)
100 | 
101 | 
102 | if __name__ == "__main__":
103 |     args = parse_args()
104 |     main(args)
105 | 


--------------------------------------------------------------------------------
/deployment/vllm/api_server.py:
--------------------------------------------------------------------------------
  1 | # SPDX-License-Identifier: Apache-2.0
  2 | """
  3 | NOTE: This API server is used only for demonstrating usage of AsyncEngine
  4 | and simple performance benchmarks. It is not intended for production use.
  5 | For production use, we recommend using our OpenAI compatible server.
  6 | We are also not going to accept PRs modifying this file, please
  7 | change `vllm/entrypoints/openai/api_server.py` instead.
  8 | """
  9 | 
 10 | import asyncio
 11 | import base64
 12 | import json
 13 | import io
 14 | import ssl
 15 | from argparse import Namespace
 16 | from collections.abc import AsyncGenerator
 17 | from PIL import Image
 18 | from typing import Any, Optional
 19 | 
 20 | from fastapi import FastAPI, Request
 21 | from fastapi.responses import JSONResponse, Response, StreamingResponse
 22 | 
 23 | from vllm.engine.arg_utils import AsyncEngineArgs
 24 | from vllm.engine.async_llm_engine import AsyncLLMEngine
 25 | from vllm.entrypoints.launcher import serve_http
 26 | from vllm.entrypoints.utils import with_cancellation
 27 | from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
 28 | from vllm.logger import init_logger
 29 | from vllm.sampling_params import SamplingParams
 30 | from vllm.usage.usage_lib import UsageContext
 31 | from vllm.utils import FlexibleArgumentParser, random_uuid, set_ulimit
 32 | from vllm.version import __version__ as VLLM_VERSION
 33 | 
 34 | logger = init_logger("api_server")
 35 | 
 36 | TIMEOUT_KEEP_ALIVE = 5  # seconds.
 37 | app = FastAPI()
 38 | engine = None
 39 | 
 40 | 
 41 | @app.get("/health")
 42 | async def health() -> Response:
 43 |     """Health check."""
 44 |     return Response(status_code=200)
 45 | 
 46 | 
 47 | @app.post("/generate")
 48 | async def generate(request: Request) -> Response:
 49 |     """Generate completion for the request.
 50 | 
 51 |     The request should be a JSON object with the following fields:
 52 |     - encoder_prompt / decoder_prompt: the prompts to use for the generation (encoder_prompt may be empty for Dolphin).
 53 |     - image_base64: the input image, encoded as base64.
 54 |     - stream: whether to stream the results or not; remaining fields are sampling parameters (see `SamplingParams`).
 55 |     """
 56 |     request_dict = await request.json()
 57 |     return await _generate(request_dict, raw_request=request)
 58 | 
 59 | 
 60 | async def decode_image(image_base64: str) -> Image.Image:
 61 |     image_data = base64.b64decode(image_base64)
 62 |     image = Image.open(io.BytesIO(image_data))
 63 |     return image
 64 | 
 65 | 
 66 | async def custom_process_prompt(encoder_prompt: str, decoder_prompt: str,
 67 |                                 image_base64: str) -> ExplicitEncoderDecoderPrompt:
 68 |     assert engine is not None
 69 |     tokenizer = engine.engine.get_tokenizer_group().tokenizer
 70 |     image = await decode_image(image_base64)
 71 | 
 72 |     if encoder_prompt == "":
 73 |         encoder_prompt = "0" * 783  # Dolphin needs no real encoder prompt; a dummy one makes vLLM allocate the encoder KV cache correctly
 74 | 
 75 |     if decoder_prompt == "":
 76 |         decoder_prompt_ids = [tokenizer.bos_token_id]  # TokensPrompt expects a list of token ids
 77 |     else:
 78 |         decoder_prompt = f"<s>{decoder_prompt.strip()} <Answer/>"
 79 |         decoder_prompt_ids = tokenizer(decoder_prompt, add_special_tokens=False)["input_ids"]
 80 | 
 81 |     enc_dec_prompt = ExplicitEncoderDecoderPrompt(
 82 |         encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}),
 83 |         decoder_prompt=TokensPrompt(prompt_token_ids=decoder_prompt_ids),
 84 |     )
 85 | 
 86 |     return enc_dec_prompt
 87 | 
 88 | 
 89 | @with_cancellation
 90 | async def _generate(request_dict: dict, raw_request: Request) -> Response:
 91 |     encoder_prompt = request_dict.pop("encoder_prompt", "")
 92 |     decoder_prompt = request_dict.pop("decoder_prompt", "")
 93 |     image_base64 = request_dict.pop("image_base64", "")
 94 |     stream = request_dict.pop("stream", False)
 95 |     sampling_params = SamplingParams(**request_dict)
 96 |     request_id = random_uuid()
 97 | 
 98 |     assert engine is not None
 99 | 
100 |     enc_dec_prompt = await custom_process_prompt(encoder_prompt, decoder_prompt, image_base64)
101 |     results_generator = engine.generate(enc_dec_prompt, sampling_params, request_id)
102 | 
103 |     # Streaming case
104 |     async def stream_results() -> AsyncGenerator[bytes, None]:
105 |         async for request_output in results_generator:
106 |             prompt = request_output.prompt
107 |             assert prompt is not None
108 |             text_outputs = [
109 |                 prompt + output.text for output in request_output.outputs
110 |             ]
111 |             ret = {"text": text_outputs}
112 |             yield (json.dumps(ret) + "\n").encode("utf-8")
113 | 
114 |     if stream:
115 |         return StreamingResponse(stream_results())
116 | 
117 |     # Non-streaming case
118 |     final_output = None
119 |     try:
120 |         async for request_output in results_generator:
121 |             final_output = request_output
122 |     except asyncio.CancelledError:
123 |         return Response(status_code=499)
124 | 
125 |     assert final_output is not None
126 |     prompt = final_output.prompt
127 |     assert prompt is not None
128 |     text_outputs = [prompt + output.text.strip() for output in final_output.outputs]
129 |     ret = {"text": text_outputs}
130 |     return JSONResponse(ret)
131 | 
132 | 
133 | def build_app(args: Namespace) -> FastAPI:
134 |     global app
135 | 
136 |     app.root_path = args.root_path
137 |     return app
138 | 
139 | 
140 | async def init_app(
141 |         args: Namespace,
142 |         llm_engine: Optional[AsyncLLMEngine] = None,
143 | ) -> FastAPI:
144 |     app = build_app(args)
145 | 
146 |     global engine
147 | 
148 |     engine_args = AsyncEngineArgs.from_cli_args(args)
149 |     engine = (llm_engine
150 |               if llm_engine is not None else AsyncLLMEngine.from_engine_args(
151 |         engine_args, usage_context=UsageContext.API_SERVER))
152 |     app.state.engine_client = engine
153 |     return app
154 | 
155 | 
156 | async def run_server(args: Namespace,
157 |                      llm_engine: Optional[AsyncLLMEngine] = None,
158 |                      **uvicorn_kwargs: Any) -> None:
159 |     logger.info("vLLM API server version %s", VLLM_VERSION)
160 |     logger.info("args: %s", args)
161 | 
162 |     set_ulimit()
163 | 
164 |     app = await init_app(args, llm_engine)
165 |     assert engine is not None
166 | 
167 |     shutdown_task = await serve_http(
168 |         app,
169 |         sock=None,
170 |         enable_ssl_refresh=args.enable_ssl_refresh,
171 |         host=args.host,
172 |         port=args.port,
173 |         log_level=args.log_level,
174 |         timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
175 |         ssl_keyfile=args.ssl_keyfile,
176 |         ssl_certfile=args.ssl_certfile,
177 |         ssl_ca_certs=args.ssl_ca_certs,
178 |         ssl_cert_reqs=args.ssl_cert_reqs,
179 |         **uvicorn_kwargs,
180 |     )
181 | 
182 |     await shutdown_task
183 | 
184 | 
185 | if __name__ == "__main__":
186 |     parser = FlexibleArgumentParser()
187 |     parser.add_argument("--host", type=str, default=None)
188 |     parser.add_argument("--port", type=parser.check_port, default=8000)
189 |     parser.add_argument("--ssl-keyfile", type=str, default=None)
190 |     parser.add_argument("--ssl-certfile", type=str, default=None)
191 |     parser.add_argument("--ssl-ca-certs",
192 |                         type=str,
193 |                         default=None,
194 |                         help="The CA certificates file")
195 |     parser.add_argument(
196 |         "--enable-ssl-refresh",
197 |         action="store_true",
198 |         default=False,
199 |         help="Refresh SSL Context when SSL certificate files change")
200 |     parser.add_argument(
201 |         "--ssl-cert-reqs",
202 |         type=int,
203 |         default=int(ssl.CERT_NONE),
        help="Whether client certificate is required (see the stdlib ssl module's CERT_* constants)"
205 |     )
206 |     parser.add_argument(
207 |         "--root-path",
208 |         type=str,
209 |         default=None,
210 |         help="FastAPI root_path when app is behind a path based routing proxy")
211 |     parser.add_argument("--log-level", type=str, default="debug")
212 |     parser = AsyncEngineArgs.add_cli_args(parser)
213 |     args = parser.parse_args()
214 | 
215 |     asyncio.run(run_server(args))
216 | 


--------------------------------------------------------------------------------
/deployment/vllm/demo_vllm.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
 3 | SPDX-License-Identifier: MIT
 4 | """
 5 | 
 6 | import vllm_dolphin  # vllm_dolphin plugin
 7 | import argparse
 8 | from argparse import Namespace
 9 | from PIL import Image
10 | 
11 | from vllm import LLM, SamplingParams
12 | from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
13 | 
14 | import torch
15 | import os
16 | 
17 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
18 | 
19 | 
20 | def offline_inference(model_id: str, prompt: str, image_path: str, max_tokens: int = 2048):
21 |     dtype = "float16" if torch.cuda.is_available() else "float32"
22 |     # Create an encoder/decoder model instance
23 |     llm = LLM(
24 |         model=model_id,
25 |         dtype=dtype,
26 |         enforce_eager=True,
27 |         trust_remote_code=True,
28 |         max_num_seqs=8,
29 |         hf_overrides={"architectures": ["DolphinForConditionalGeneration"]},
30 |     )
31 | 
32 |     # Create a sampling params object.
33 |     sampling_params = SamplingParams(
34 |         temperature=0.0,
35 |         logprobs=0,
36 |         max_tokens=max_tokens,
37 |         prompt_logprobs=None,
38 |         skip_special_tokens=False,
39 |     )
40 | 
41 |     # process prompt
42 |     tokenizer = llm.llm_engine.get_tokenizer_group().tokenizer
43 | 
44 |     # The Dolphin model does not require an Encoder Prompt. To ensure vllm correctly allocates KV Cache,
45 |     # it is necessary to simulate an Encoder Prompt.
46 |     encoder_prompt = "0" * 783
47 |     decoder_prompt = f"<s>{prompt.strip()} <Answer/>"
48 | 
49 |     image = Image.open(image_path)
50 |     enc_dec_prompt = ExplicitEncoderDecoderPrompt(
51 |         encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}),
52 |         decoder_prompt=TokensPrompt(
53 |             prompt_token_ids=tokenizer(decoder_prompt, add_special_tokens=False)["input_ids"]
54 |         ),
55 |     )
56 | 
57 |     # Generate output tokens from the prompts. The output is a list of
58 |     # RequestOutput objects that contain the prompt, generated text, and other information.
59 |     outputs = llm.generate(enc_dec_prompt, sampling_params)
60 | 
61 |     print("------" * 8)
62 |     # Print the outputs.
63 |     for output in outputs:
64 |         decoder_prompt_tokens = tokenizer.batch_decode(output.prompt_token_ids, skip_special_tokens=True)
65 |         decoder_prompt = "".join(decoder_prompt_tokens)
66 |         generated_text = output.outputs[0].text.strip()
67 |         print(f"Decoder prompt: {decoder_prompt!r}, "
68 |               f"\nGenerated text: {generated_text!r}")
69 | 
70 |         print("------" * 8)
71 | 
72 | 
73 | def parse_args():
74 |     parser = argparse.ArgumentParser()
75 |     parser.add_argument("--model", type=str, default="ByteDance/Dolphin")
76 |     parser.add_argument("--image_path", type=str, default="./demo/page_imgs/page_1.jpeg")
77 |     parser.add_argument("--prompt", type=str, default="Parse the reading order of this document.")
78 |     return parser.parse_args()
79 | 
80 | 
81 | def main(args: Namespace):
82 |     model = args.model
83 |     prompt = args.prompt
84 |     image_path = args.image_path
85 | 
86 |     offline_inference(model, prompt, image_path)
87 | 
88 | 
89 | if __name__ == "__main__":
90 |     args = parse_args()
91 |     main(args)
92 | 


--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
 1 | [tool.black]
 2 | line-length = 120
 3 | include = '\.pyi?$'
 4 | exclude = '''
 5 | /(
 6 |     \.git
 7 |   | \.hg
 8 |   | \.mypy_cache
 9 |   | \.tox
10 |   | \.venv
11 |   | _build
12 |   | buck-out
13 |   | build
14 |   | dist
15 | )/
16 | '''
17 | 


--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | numpy==1.24.4
 2 | omegaconf==2.3.0
 3 | opencv-python==4.11.0.86
 4 | opencv-python-headless==4.5.5.64
 5 | pillow==9.3.0
 6 | timm==0.5.4
 7 | torch==2.1.0
 8 | torchvision==0.16.0
 9 | transformers==4.47.0
10 | accelerate==1.6.0
11 | pymupdf==1.26


--------------------------------------------------------------------------------
/utils/markdown_utils.py:
--------------------------------------------------------------------------------
  1 | """ 
  2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
  3 | SPDX-License-Identifier: MIT
  4 | """
  5 | 
  6 | import re
  7 | import base64
  8 | from typing import List, Dict, Any, Optional
  9 | 
 10 | 
 11 | """
 12 |     Example input:
 13 |     [
 14 |         {"label": "tab", "bbox": [0.176, 0.74, 0.824, 0.82], "text": "<table><tr><td></td><td>HellaSwag</td><td>Obqa</td><td>WinoGrande</td><td>ARC-c</td><td>ARC-e</td><td>boolq</td><td>piqa</td><td>Avg</td></tr><tr><td>OPT-1.3B</td><td>53.65</td><td>33.40</td><td>59.59</td><td>29.44</td><td>50.80</td><td>60.83</td><td>72.36</td><td>51.44</td></tr><tr><td>Pythia-1.0B</td><td>47.16</td><td>31.40</td><td>53.43</td><td>27.05</td><td>48.99</td><td>57.83</td><td>69.21</td><td>48.30</td></tr><tr><td>Pythia-1.4B</td><td>52.01</td><td>33.20</td><td>57.38</td><td>28.50</td><td>54.00</td><td>63.27</td><td>70.95</td><td>51.33</td></tr><tr><td>TinyLlama-1.1B</td><td>59.20</td><td>36.00</td><td>59.12</td><td>30.10</td><td>55.25</td><td>57.83</td><td>73.29</td><td>52.99</td></tr></table>", "reading_order": 6},
 15 |         {"label": "cap", "bbox": [0.28, 0.729, 0.711, 0.74], "text": "Table 2: Zero-shot performance on commonsense reasoning tasks", "reading_order": 7}, 
 16 |         {"label": "para", "bbox": [0.176, 0.848, 0.826, 0.873], "text": "We of performance during training We tracked the accuracy of TinyLlama on common-\nsense reasoning benchmarks during its pre-training, as shown in Fig. 2 . Generally, the performance of", "reading_order": 8}, 
 17 |         {"label": "fnote", "bbox": [0.176, 0.88, 0.824, 0.912], "text": "${ }^{4}$ Due to a bug in the config file, the learning rate did not decrease immediately after warmup and remained at\nthe maximum value for several steps before we fixed this.", "reading_order": 9}, 
 18 |         {"label": "foot", "bbox": [0.496, 0.939, 0.501, 0.95], "text": "14", "reading_order": 10}
 19 |     ]
 20 | """
 21 | 
 22 | 
 23 | def extract_table_from_html(html_string):
 24 |     """Extract and clean table tags from HTML string"""
 25 |     try:
 26 |         table_pattern = re.compile(r'<table.*?>.*?</table>', re.DOTALL)
 27 |         tables = table_pattern.findall(html_string)
 28 |         tables = [re.sub(r'<table[^>]*>', '<table>', table) for table in tables]
 29 |         return '\n'.join(tables)
 30 |     except Exception as e:
 31 |         print(f"extract_table_from_html error: {str(e)}")
 32 |         return f"<table><tr><td>Error extracting table: {str(e)}</td></tr></table>"
 33 | 
 34 | 
 35 | class MarkdownConverter:
 36 |     """Convert structured recognition results to Markdown format"""
 37 |     
 38 |     def __init__(self):
 39 |         # Define heading levels for different section types
 40 |         self.heading_levels = {
 41 |             'title': '#',
 42 |             'sec': '##',
 43 |             'sub_sec': '###'
 44 |         }
 45 |         
 46 |         # Define which labels need special handling
 47 |         self.special_labels = {
 48 |             'tab', 'fig', 'title', 'sec', 'sub_sec', 
 49 |             'list', 'formula', 'reference', 'alg'
 50 |         }
 51 |     
 52 |     def try_remove_newline(self, text: str) -> str:
 53 |         try:
 54 |             # Preprocess text to handle line breaks
 55 |             text = text.strip()
 56 |             text = text.replace('-\n', '')
 57 |             
 58 |             # Handle Chinese text line breaks
 59 |             def is_chinese(char):
 60 |                 return '\u4e00' <= char <= '\u9fff'
 61 | 
 62 |             lines = text.split('\n')
 63 |             processed_lines = []
 64 |             
 65 |             # Process all lines except the last one
 66 |             for i in range(len(lines)-1):
 67 |                 current_line = lines[i].strip()
 68 |                 next_line = lines[i+1].strip()
 69 |                 
 70 |                 # Always add the current line, but determine if we need a newline
 71 |                 if current_line:  # If current line is not empty
 72 |                     if next_line:  # If next line is not empty
 73 |                         # For Chinese text handling
 74 |                         if is_chinese(current_line[-1]) and is_chinese(next_line[0]):
 75 |                             processed_lines.append(current_line)
 76 |                         else:
 77 |                             processed_lines.append(current_line + ' ')
 78 |                     else:
 79 |                         # Next line is empty, add current line with newline
 80 |                         processed_lines.append(current_line + '\n')
 81 |                 else:
 82 |                     # Current line is empty, add an empty line
 83 |                     processed_lines.append('\n')
 84 |             
 85 |             # Add the last line
 86 |             if lines and lines[-1].strip():
 87 |                 processed_lines.append(lines[-1].strip())
 88 |             
 89 |             text = ''.join(processed_lines)
 90 |             
 91 |             return text
 92 |         except Exception as e:
 93 |             print(f"try_remove_newline error: {str(e)}")
 94 |             return text  # Return original text on error
 95 | 
 96 |     def _handle_text(self, text: str) -> str:
 97 |         """
 98 |         Process regular text content, preserving paragraph structure
 99 |         """
100 |         try:
101 |             if not text:
102 |                 return ""
103 |             
104 |             if text.strip().startswith("\\begin{array}") and text.strip().endswith("\\end{array}"):
105 |                 text = "$" + text + "$"
106 |             elif ("_{" in text or "^{" in text or "\\" in text or "_ {" in text or "^ {" in text) and ("$" not in text) and ("\\begin" not in text):
107 |                 text = "$" + text + "$"
108 | 
109 |             # Process formulas in text before handling other text processing
110 |             text = self._process_formulas_in_text(text)
111 |             
112 |             text = self.try_remove_newline(text)
113 |             
114 |             # Return processed text
115 |             return text
116 |         except Exception as e:
117 |             print(f"_handle_text error: {str(e)}")
118 |             return text  # Return original text on error
119 |     
120 |     def _process_formulas_in_text(self, text: str) -> str:
121 |         """
122 |         Process mathematical formulas in text by iteratively finding and replacing formulas.
123 |         - Identify inline and block formulas
124 |         - Replace newlines within formulas with \\
125 |         """
126 |         try:
127 |             # Define formula delimiters and their corresponding patterns
128 |             delimiters = [
129 |                 ('$$', '$$'),  # Block formula with $$
130 |                 ('\\[', '\\]'),  # Block formula with \[ \]
131 |                 ('$', '$'),  # Inline formula with $
132 |                 ('\\(', '\\)')  # Inline formula with \( \)
133 |             ]
134 |             
135 |             # Process the text by iterating through each delimiter type
136 |             result = text
137 |             
138 |             for start_delim, end_delim in delimiters:
139 |                 # Create a pattern that matches from start to end delimiter
140 |                 # Using a custom approach to avoid issues with nested delimiters
141 |                 current_pos = 0
142 |                 processed_parts = []
143 |                 
144 |                 while current_pos < len(result):
145 |                     # Find the next start delimiter
146 |                     start_pos = result.find(start_delim, current_pos)
147 |                     if start_pos == -1:
148 |                         # No more formulas of this type
149 |                         processed_parts.append(result[current_pos:])
150 |                         break
151 |                     
152 |                     # Add text before the formula
153 |                     processed_parts.append(result[current_pos:start_pos])
154 |                     
155 |                     # Find the matching end delimiter
156 |                     end_pos = result.find(end_delim, start_pos + len(start_delim))
157 |                     if end_pos == -1:
158 |                         # No matching end delimiter, treat as regular text
159 |                         processed_parts.append(result[start_pos:])
160 |                         break
161 |                     
162 |                     # Extract the formula content (without delimiters)
163 |                     formula_content = result[start_pos + len(start_delim):end_pos]
164 |                     
165 |                     # Process the formula content - replace newlines with \\
166 |                     processed_formula = formula_content.replace('\n', ' \\\\ ')
167 |                     
168 |                     # Add the processed formula with its delimiters
169 |                     processed_parts.append(f"{start_delim}{processed_formula}{end_delim}")
170 |                     
171 |                     # Move past this formula
172 |                     current_pos = end_pos + len(end_delim)
173 |                 
174 |                 # Update the result with processed text
175 |                 result = ''.join(processed_parts)
176 |             return result
177 |         except Exception as e:
178 |             print(f"_process_formulas_in_text error: {str(e)}")
179 |             return text  # Return original text on error
180 |     
181 |     def _remove_newline_in_heading(self, text: str) -> str:
182 |         """
183 |         Remove newline in heading
184 |         """
185 |         try:
186 |             # Handle Chinese text line breaks
187 |             def is_chinese(char):
188 |                 return '\u4e00' <= char <= '\u9fff'
189 |             
190 |             # Check if the text contains Chinese characters
191 |             if any(is_chinese(char) for char in text):
192 |                 return text.replace('\n', '')
193 |             else:
194 |                 return text.replace('\n', ' ')
195 | 
196 |         except Exception as e:
197 |             print(f"_remove_newline_in_heading error: {str(e)}")
198 |             return text
199 |     
200 |     def _handle_heading(self, text: str, label: str) -> str:
201 |         """
202 |         Convert section headings to appropriate markdown format
203 |         """
204 |         try:
205 |             level = self.heading_levels.get(label, '#')
206 |             text = text.strip()
207 |             text = self._remove_newline_in_heading(text)
208 |             text = self._handle_text(text)
209 |             return f"{level} {text}\n\n"
210 |         except Exception as e:
211 |             print(f"_handle_heading error: {str(e)}")
212 |             return f"# Error processing heading: {text}\n\n"
213 |     
214 |     def _handle_list_item(self, text: str) -> str:
215 |         """
216 |         Convert list items to markdown list format
217 |         """
218 |         try:
219 |             return f"- {text.strip()}\n"
220 |         except Exception as e:
221 |             print(f"_handle_list_item error: {str(e)}")
222 |             return f"- Error processing list item: {text}\n"
223 |     
224 |     def _handle_figure(self, text: str, section_count: int) -> str:
225 |         """
226 |         Handle figure content
227 |         """
228 |         try:
229 |             # Check if it's a file path starting with "figures/"
230 |             if text.startswith("figures/"):
231 |                 # Convert to relative path from markdown directory to figures directory
232 |                 relative_path = f"../{text}"
233 |                 return f"![Figure {section_count}]({relative_path})\n\n"
234 | 
235 |             # Check if it's already a markdown format image link
236 |             if text.startswith("!["):
237 |                 # Already in markdown format, return directly
238 |                 return f"{text}\n\n"
239 | 
240 |             # If it's still base64 format, maintain original logic
241 |             if text.startswith("data:image/"):
242 |                 return f"![Figure {section_count}]({text})\n\n"
243 |             elif ";" in text and "," in text:
244 |                 return f"![Figure {section_count}]({text})\n\n"
245 |             else:
246 |                 # Assume it's raw base64, convert to data URI
247 |                 img_format = "png"
248 |                 data_uri = f"data:image/{img_format};base64,{text}"
249 |                 return f"![Figure {section_count}]({data_uri})\n\n"
250 |                 
251 |         except Exception as e:
252 |             print(f"_handle_figure error: {str(e)}")
253 |             return f"*[Error processing figure: {str(e)}]*\n\n"
254 | 
255 |     def _handle_table(self, text: str) -> str:
256 |         """
257 |         Convert table content to markdown format
258 |         """
259 |         try:
260 |             markdown_content = []
261 |             if '<table' in text.lower() or '<tr' in text.lower():
262 |                 markdown_table = extract_table_from_html(text)
263 |                 markdown_content.append(markdown_table + "\n")
264 |             else:
265 |                 table_lines = text.split('\n')
266 |                 if table_lines:
267 |                     col_count = len(table_lines[0].split()) if table_lines[0] else 1
268 |                     header = '| ' + ' | '.join(table_lines[0].split()) + ' |'
269 |                     markdown_content.append(header)
270 |                     markdown_content.append('| ' + ' | '.join(['---'] * col_count) + ' |')
271 |                     for line in table_lines[1:]:
272 |                         cells = line.split()
273 |                         while len(cells) < col_count:
274 |                             cells.append('')
275 |                         markdown_content.append('| ' + ' | '.join(cells) + ' |')
276 |             return '\n'.join(markdown_content) + '\n\n'
277 |         except Exception as e:
278 |             print(f"_handle_table error: {str(e)}")
279 |             return f"*[Error processing table: {str(e)}]*\n\n"
280 | 
281 |     def _handle_algorithm(self, text: str) -> str:
282 |         """
283 |         Process algorithm blocks with proper formatting
284 |         """
285 |         try:
286 |             # Remove algorithm environment tags if present
287 |             text = re.sub(r'\\begin\{algorithm\}(.*?)\\end\{algorithm\}', r'\1', text, flags=re.DOTALL)
288 |             text = text.replace('\\begin{algorithm}', '').replace('\\end{algorithm}', '')
289 |             text = text.replace('\\begin{algorithmic}', '').replace('\\end{algorithmic}', '')
290 |             
291 |             # Process the algorithm text
292 |             lines = text.strip().split('\n')
293 |             
294 |             # Check if there's a caption or label
295 |             caption = ""
296 |             algorithm_text = []
297 |             
298 |             for line in lines:
299 |                 if '\\caption' in line:
300 |                     # Extract caption text
301 |                     caption_match = re.search(r'\\caption\{(.*?)\}', line)
302 |                     if caption_match:
303 |                         caption = f"**{caption_match.group(1)}**\n\n"
304 |                     continue
305 |                 elif '\\label' in line:
306 |                     continue  # Skip label lines
307 |                 else:
308 |                     algorithm_text.append(line)
309 |             
310 |             # Join the algorithm text and wrap in code block
311 |             formatted_text = '\n'.join(algorithm_text)
312 |             
313 |             # Return the formatted algorithm with caption
314 |             return f"{caption}```\n{formatted_text}\n```\n\n"
315 |         except Exception as e:
316 |             print(f"_handle_algorithm error: {str(e)}")
317 |             return f"*[Error processing algorithm: {str(e)}]*\n\n{text}\n\n"
318 | 
319 |     def _handle_formula(self, text: str) -> str:
320 |         """
321 |         Handle formula-specific content
322 |         """
323 |         try:
324 |             # Process the formula content
325 |             processed_text = self._process_formulas_in_text(text)
326 |             
327 |             # For formula blocks, ensure they're properly formatted in markdown
328 |             if '$' not in processed_text and '\\[' not in processed_text:
329 |                 # If no block formula delimiters are present, wrap in $ for block formula
330 |                 processed_text = f'${processed_text}$'
331 |             
332 |             return f"{processed_text}\n\n"
333 |         except Exception as e:
334 |             print(f"_handle_formula error: {str(e)}")
335 |             return f"*[Error processing formula: {str(e)}]*\n\n"
336 | 
337 |     def convert(self, recognition_results: List[Dict[str, Any]]) -> str:
338 |         """
339 |         Convert recognition results to markdown format
340 |         """
341 |         try:
342 |             markdown_content = []
343 |             
344 |             for section_count, result in enumerate(recognition_results):
345 |                 try:
346 |                     label = result.get('label', '')
347 |                     text = result.get('text', '').strip()
348 |                     
349 |                     # Skip empty text
350 |                     if not text:
351 |                         continue
352 |                         
353 |                     # Handle different content types
354 |                     if label in {'title', 'sec', 'sub_sec'}:
355 |                         markdown_content.append(self._handle_heading(text, label))
356 |                     elif label == 'list':
357 |                         markdown_content.append(self._handle_list_item(text))
358 |                     elif label == 'fig':
359 |                         markdown_content.append(self._handle_figure(text, section_count))
360 |                     elif label == 'tab':
361 |                         markdown_content.append(self._handle_table(text))
362 |                     elif label == 'alg':
363 |                         markdown_content.append(self._handle_algorithm(text))
364 |                     elif label == 'formula':
365 |                         markdown_content.append(self._handle_formula(text))
366 |                     elif label not in self.special_labels:
367 |                         # Handle regular text (paragraphs, etc.)
368 |                         processed_text = self._handle_text(text)
369 |                         markdown_content.append(f"{processed_text}\n\n")
370 |                 except Exception as e:
371 |                     print(f"Error processing item {section_count}: {str(e)}")
372 |                     # Add a placeholder for the failed item
373 |                     markdown_content.append("*[Error processing content]*\n\n")
374 |             
375 |             # Join all content and apply post-processing
376 |             result = ''.join(markdown_content)
377 |             return self._post_process(result)
378 |         except Exception as e:
379 |             print(f"convert error: {str(e)}")
380 |             return f"Error generating markdown content: {str(e)}"
381 | 
382 |     def _post_process(self, markdown_content: str) -> str:
383 |         """
384 |         Apply post-processing fixes to the generated markdown content
385 |         """
386 |         try:
387 |             # Handle author information
388 |             author_pattern = re.compile(r'\\author\{(.*?)\}', re.DOTALL)
389 |             
390 |             def process_author_match(match):
391 |                 # Extract author content
392 |                 author_content = match.group(1)
393 |                 # Process the author content
394 |                 return self._handle_text(author_content)
395 |             
396 |             # Replace \author{...} with processed content
397 |             markdown_content = author_pattern.sub(process_author_match, markdown_content)
398 |             
399 |             # Handle special case where author is inside math environment
400 |             math_author_pattern = re.compile(r'\$(\\author\{.*?\})\$', re.DOTALL)
401 |             match = math_author_pattern.search(markdown_content)
402 |             if match:
403 |                 # Extract the author command
404 |                 author_cmd = match.group(1)
405 |                 # Extract content from author command
406 |                 author_content_match = re.search(r'\\author\{(.*?)\}', author_cmd, re.DOTALL)
407 |                 if author_content_match:
408 |                     # Get author content and process it
409 |                     author_content = author_content_match.group(1)
410 |                     processed_content = self._handle_text(author_content)
411 |                     # Replace the entire $\author{...}$ block with processed content
412 |                     markdown_content = markdown_content.replace(match.group(0), processed_content)
413 |             
414 |             # Replace LaTeX abstract environment with plain text
415 |             markdown_content = re.sub(r'\\begin\{abstract\}(.*?)\\end\{abstract\}', 
416 |                                      r'**Abstract** \1', 
417 |                                      markdown_content, 
418 |                                      flags=re.DOTALL)
419 |             
420 |             # Replace standalone \begin{abstract} (without matching end)
421 |             markdown_content = re.sub(r'\\begin\{abstract\}', 
422 |                                      r'**Abstract**', 
423 |                                      markdown_content)
424 |             
425 |             # Replace LaTeX equation numbers with tag format, handling cases with extra backslashes
426 |             markdown_content = re.sub(r'\\eqno\{\((.*?)\)\}',
427 |                                     r'\\tag{\1}',
428 |                                     markdown_content)
429 | 
430 |             # Find the starting tag of the formula
431 |             markdown_content = markdown_content.replace("\\[ \\\\", "$ \\\\")
432 | 
433 |             # Find the ending tag of the formula (ensure this is the only ending tag)
434 |             markdown_content = markdown_content.replace("\\\\ \\]", "\\\\ $")
435 | 
436 |             # Fix other common LaTeX issues
437 |             replacements = [
438 |                 # Fix spacing issues in subscripts and superscripts
439 |                 (r'_ {', r'_{'),
440 |                 (r'^ {', r'^{'),
441 |                 
442 |                 # Fix potential issues with multiple consecutive newlines
443 |                 (r'\n{3,}', r'\n\n')
444 |             ]
445 |             
446 |             for old, new in replacements:
447 |                 markdown_content = re.sub(old, new, markdown_content)
448 |             
449 |             return markdown_content
450 |         except Exception as e:
451 |             print(f"_post_process error: {str(e)}")
452 |             return markdown_content  # Return original content if post-processing fails
453 | 


--------------------------------------------------------------------------------
/utils/model.py:
--------------------------------------------------------------------------------
  1 | """
  2 | Copyright (c) 2022-present NAVER Corp.
  3 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
  4 | MIT License
  5 | This file has been modified by [ByteDance Ltd. and/or its affiliates] on 20250118.
  6 | The original file available at https://github.com/clovaai/donut/blob/master/donut/model.py was released under the MIT license.
  7 | This modified file is released under the same license.
  8 | """
  9 | 
 10 | import logging
 11 | from collections import defaultdict
 12 | from typing import List, Optional
 13 | 
 14 | import torch
 15 | import torch.nn.functional as F
 16 | from PIL import Image
 17 | from timm.models.swin_transformer import SwinTransformer
 18 | from torch import nn
 19 | from transformers import (
 20 |     MBartConfig,
 21 |     MBartForCausalLM,
 22 |     StoppingCriteria,
 23 |     StoppingCriteriaList,
 24 | )
 25 | from transformers.file_utils import ModelOutput
 26 | from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
 27 | 
 28 | 
 29 | class SwinEncoder(nn.Module):
 30 |     r"""
 31 |     Encoder based on SwinTransformer
 32 |     Set the initial weights and configuration with a pretrained SwinTransformer and then
 33 |     modify the detailed configurations
 34 | 
 35 |     Args:
 36 |         input_size: Input image size (width, height)
 37 |         align_long_axis: Whether to rotate image if height is greater than width
 38 |         window_size: Window size(=patch size) of SwinTransformer
 39 |         encoder_layer: Number of layers of SwinTransformer encoder
 40 |         name_or_path: Name or path of a pretrained model, either registered on huggingface.co or saved locally;
 41 |                       otherwise, `swin_base_patch4_window12_384` will be used (via `timm`).
 42 |     """
 43 | 
 44 |     def __init__(
 45 |         self,
 46 |         input_size,
 47 |         align_long_axis: bool = False,
 48 |         window_size: int = 7,
 49 |         encoder_layer: List[int] = [2, 2, 14, 2],
 50 |         patch_size: List[int] = [4, 4],
 51 |         embed_dim: int = 128,
 52 |         num_heads: List[int] = [4, 8, 16, 32],
 53 |     ):
 54 |         super().__init__()
 55 |         if isinstance(input_size, int):
 56 |             input_size = [input_size, input_size]
 57 |         self.input_size = input_size
 58 |         self.align_long_axis = align_long_axis
 59 |         self.window_size = window_size
 60 |         self.encoder_layer = encoder_layer
 61 |         self.patch_size = patch_size
 62 |         self.embed_dim = embed_dim
 63 |         self.num_heads = num_heads
 64 | 
 65 |         self.model = SwinTransformer(
 66 |             img_size=self.input_size,
 67 |             depths=self.encoder_layer,
 68 |             window_size=self.window_size,
 69 |             patch_size=self.patch_size,
 70 |             embed_dim=self.embed_dim,
 71 |             num_heads=self.num_heads,
 72 |             num_classes=0,
 73 |         )
 74 | 
 75 |     def forward(self, x: torch.Tensor, text_embedding: Optional[torch.Tensor] = None) -> torch.Tensor:
 76 |         """
 77 |         Args:
 78 |             x: (batch_size, num_channels, height, width)
 79 |         """
 80 |         x = self.model.patch_embed(x)
 81 |         x = self.model.pos_drop(x)
 82 |         x = self.model.layers(x)
 83 |         return x
 84 | 
 85 | 
 86 | class LayerNorm(nn.LayerNorm):
 87 |     """Subclass torch's LayerNorm to handle fp16."""
 88 | 
 89 |     def _set_dtype(self, dtype):
 90 |         self._dtype = dtype
 91 | 
 92 |     def forward(self, x: torch.Tensor):
 93 |         orig_type = x.dtype
 94 |         ret = super().forward(x.type(dtype=getattr(self, "_dtype", x.dtype)))  # fall back to the input dtype if _set_dtype was never called
 95 |         return ret.type(orig_type)
 96 | 
 97 | 
 98 | class BARTDecoder(nn.Module):
 99 |     """
100 |     Decoder based on Multilingual BART
101 |     Set the initial weights and configuration with a pretrained multilingual BART model,
102 |     and modify the detailed configurations as a Donut decoder
103 | 
104 |     Args:
105 |         decoder_layer:
106 |             Number of layers of BARTDecoder
107 |         max_position_embeddings:
108 |             The maximum sequence length to be trained
109 |         name_or_path:
110 |             Name of a pretrained model, either registered on huggingface.co or saved locally;
111 |             otherwise, `facebook/mbart-large-50` is used (via `transformers`)
112 |     """
113 | 
114 |     def __init__(
115 |         self,
116 |         tokenizer,
117 |         decoder_layer: int,
118 |         max_position_embeddings: int,
119 |         hidden_dimension: int = 1024,
120 |         **kwargs,
121 |     ):
122 |         super().__init__()
123 |         self.decoder_layer = decoder_layer
124 |         self.max_position_embeddings = max_position_embeddings
125 |         self.hidden_dimension = hidden_dimension
126 | 
127 |         self.tokenizer = tokenizer
128 | 
129 |         self.model = MBartForCausalLM(
130 |             config=MBartConfig(
131 |                 tie_word_embeddings=True,
132 |                 is_decoder=True,
133 |                 is_encoder_decoder=False,
134 |                 add_cross_attention=True,
135 |                 decoder_layers=self.decoder_layer,
136 |                 max_position_embeddings=self.max_position_embeddings,
137 |                 vocab_size=len(self.tokenizer),
138 |                 scale_embedding=True,
139 |                 add_final_layer_norm=True,
140 |                 d_model=self.hidden_dimension,
141 |             )
142 |         )
143 |         # self.model.config.is_encoder_decoder = True  # to get cross-attention
144 |         self.model.model.decoder.embed_tokens.padding_idx = self.tokenizer.pad_token_id
145 |         self.model.prepare_inputs_for_generation = self.prepare_inputs_for_inference
146 | 
147 |     def add_special_tokens(self, list_of_tokens: List[str]):
148 |         """
149 |         Add special tokens to tokenizer and resize the token embeddings
150 |         """
151 |         newly_added_num = self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set(list_of_tokens))})
152 |         if newly_added_num > 0:
153 |             self.model.resize_token_embeddings(len(self.tokenizer))
154 | 
155 |     def add_tokens(self, list_of_tokens: List[str]):
156 |         """
157 |         Add regular (non-special) tokens to the tokenizer and resize the token embeddings
158 |         """
159 |         newly_added_num = self.tokenizer.add_tokens(sorted(set(list_of_tokens)))
160 |         if newly_added_num > 0:
161 |             self.model.resize_token_embeddings(len(self.tokenizer))
162 | 
163 |     def prepare_inputs_for_inference(
164 |         self,
165 |         input_ids: torch.Tensor,
166 |         encoder_outputs: torch.Tensor,
167 |         past=None,
168 |         past_key_values=None,
169 |         use_cache: bool = None,
170 |         attention_mask: torch.Tensor = None,
171 |         **kwargs,
172 |     ):
173 |         """
174 |         Args:
175 |             input_ids: (batch_size, sequence_length)
176 | 
177 |         Returns:
178 |             input_ids: (batch_size, sequence_length)
179 |             attention_mask: (batch_size, sequence_length)
180 |             encoder_hidden_states: (batch_size, sequence_length, embedding_dim)
181 |         """
182 |         attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
183 |         past = past or past_key_values
184 |         if past is not None:
185 |             input_ids = input_ids[:, -1:]
186 |         output = {
187 |             "input_ids": input_ids,
188 |             "attention_mask": attention_mask,
189 |             "past_key_values": past,
190 |             "use_cache": use_cache,
191 |             "encoder_hidden_states": encoder_outputs.last_hidden_state,
192 |         }
193 |         return output
194 | 
195 |     def forward(
196 |         self,
197 |         input_ids: torch.LongTensor = None,
198 |         attention_mask: Optional[torch.Tensor] = None,
199 |         encoder_hidden_states: Optional[torch.Tensor] = None,
200 |         past_key_values: Optional[torch.Tensor] = None,
201 |         inputs_embeds: Optional[torch.FloatTensor] = None,
202 |         labels: Optional[torch.Tensor] = None,
203 |         use_cache: bool = None,
204 |         output_attentions: Optional[torch.Tensor] = None,
205 |         output_hidden_states: Optional[torch.Tensor] = None,
206 |         return_dict: bool = None,
207 |     ):
208 |         return self.model.forward(
209 |             input_ids=input_ids,
210 |             attention_mask=attention_mask,
211 |             labels=labels,
212 |             encoder_hidden_states=encoder_hidden_states,
213 |             past_key_values=past_key_values,
214 |             inputs_embeds=inputs_embeds,
215 |             use_cache=use_cache,
216 |             output_attentions=output_attentions,
217 |             output_hidden_states=output_hidden_states,
218 |             return_dict=return_dict,
219 |         )
220 | 
221 |     @staticmethod
222 |     def resize_bart_abs_pos_emb(weight: torch.Tensor, max_length: int) -> torch.Tensor:
223 |         """
224 |         Resize position embeddings
225 |         Truncate if sequence length of MBart backbone is greater than given max_length,
226 |         else interpolate to max_length
227 |         """
228 |         if weight.shape[0] > max_length:
229 |             weight = weight[:max_length, ...]
230 |         else:
231 |             weight = (
232 |                 F.interpolate(
233 |                     weight.permute(1, 0).unsqueeze(0),
234 |                     size=max_length,
235 |                     mode="linear",
236 |                     align_corners=False,
237 |                 )
238 |                 .squeeze(0)
239 |                 .permute(1, 0)
240 |             )
241 |         return weight
242 | 
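    # Usage sketch for resize_bart_abs_pos_emb (illustrative values): mbart-large-50
    # ships a (1026, 1024) learned position table; extending the decoder beyond that
    # length interpolates the table along the sequence axis:
    #
    #   pos = torch.randn(1026, 1024)
    #   resized = BARTDecoder.resize_bart_abs_pos_emb(pos, 4096 + 2)
    #   assert resized.shape == (4098, 1024)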
243 | 
244 | class DonutConfig(PretrainedConfig):
245 | 
246 |     def __init__(
247 |         self,
248 |         decoder_layer: int = 10,
249 |         max_position_embeddings: Optional[int] = None,
250 |         max_length: int = 4096,
251 |         hidden_dimension: int = 1024,
252 |         **kwargs,
253 |     ):
254 |         super().__init__()
255 |         self.decoder_layer = decoder_layer
256 |         self.max_position_embeddings = max_length if max_position_embeddings is None else max_position_embeddings
257 |         self.max_length = max_length
258 |         self.hidden_dimension = hidden_dimension
259 | 
260 | 
261 | class RunningVarTorch:
262 |     def __init__(self, L=15, norm=False):
263 |         self.values = None
264 |         self.L = L
265 |         self.norm = norm
266 | 
267 |     def push(self, x: torch.Tensor):
268 |         assert x.dim() == 1
269 |         if self.values is None:
270 |             self.values = x[:, None]
271 |         elif self.values.shape[1] < self.L:
272 |             self.values = torch.cat((self.values, x[:, None]), 1)
273 |         else:
274 |             self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)
275 | 
276 |     def variance(self):
277 |         if self.values is None:
278 |             return
279 |         if self.norm:
280 |             return torch.var(self.values, 1) / self.values.shape[1]
281 |         else:
282 |             return torch.var(self.values, 1)
283 | 
284 | 
285 | class StoppingCriteriaScores(StoppingCriteria):
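    """Heuristic early stopping: track the running variance of the per-step max
    logit; once at least `window_size` steps have been generated, a batch element
    whose variance-of-variance falls below `threshold` is treated as stuck in a
    repetitive loop, and generation stops once every element has been flagged."""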
286 |     def __init__(self, threshold: float = 0.015, window_size: int = 200):
287 |         super().__init__()
288 |         self.threshold = threshold
289 |         self.vars = RunningVarTorch(norm=True)
290 |         self.varvars = RunningVarTorch(L=window_size)
291 |         self.stop_inds = defaultdict(int)
292 |         self.stopped = defaultdict(bool)
293 |         self.size = 0
294 |         self.window_size = window_size
295 | 
296 |     @torch.no_grad()
297 |     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
298 |         last_scores = scores[-1]
299 |         self.vars.push(last_scores.max(1)[0].float().cpu())
300 |         self.varvars.push(self.vars.variance())
301 |         self.size += 1
302 |         if self.size < self.window_size:
303 |             return False
304 | 
305 |         varvar = self.varvars.variance()
306 |         for b in range(len(last_scores)):
307 |             if varvar[b] < self.threshold:
308 |                 if self.stop_inds[b] > 0 and not self.stopped[b]:
309 |                     self.stopped[b] = self.stop_inds[b] >= self.size
310 |                 else:
311 |                     self.stop_inds[b] = int(min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095))
312 |             else:
313 |                 self.stop_inds[b] = 0
314 |                 self.stopped[b] = False
315 |         return all(self.stopped.values()) and len(self.stopped) > 0
316 | 
317 | 
318 | def batch(l, b=15):
319 |     subs = []
320 |     for i in range(len(l) - b + 1):  # + 1 so the final window is included
321 |         subs.append(l[i : i + b])
322 |     return subs
323 | 
324 | 
325 | def subdiv(l, b=10):
326 |     subs = []
327 |     for i in range(len(l) - b + 1):  # + 1 so the full-length prefix is included
328 |         subs.append(l[: i + b])
329 |     return subs
330 | 
331 | 
332 | class DonutModel(PreTrainedModel):
333 |     config_class = DonutConfig
334 |     base_model_prefix = "donut"
335 | 
336 |     def __init__(self, config: DonutConfig, vision_tower=None, tokenizer=None):
337 |         super().__init__(config)
338 |         self.config = config
339 | 
340 |         self.tokenizer = tokenizer
341 |         self.vpm = vision_tower
342 | 
343 |         # build language model
344 |         self.llm = BARTDecoder(
345 |             tokenizer=tokenizer,
346 |             decoder_layer=self.config.decoder_layer,
347 |             max_position_embeddings=self.config.max_position_embeddings,
348 |             hidden_dimension=self.config.hidden_dimension,
349 |         )
350 |         self.ids_to_tokens = {id: content for content, id in self.llm.tokenizer.vocab.items()}
351 | 
352 |     def get_input_embeddings(self, tensor):
353 |         return self.llm.model.get_input_embeddings()(tensor)
354 | 
355 |     def forward(
356 |         self,
357 |         inputs: dict,
358 |     ):
359 |         image_tensors = inputs["pixel_values"]
360 |         input_ids = inputs["input_ids"].contiguous()
361 |         attention_mask = inputs["attention_mask"]
362 |         labels = inputs["labels"].contiguous()
363 | 
364 |         encoder_outputs = self.vpm(
365 |             image_tensors,
366 |             text_embedding=self.llm.model.get_input_embeddings()(input_ids),
367 |         )
368 | 
369 |         decoder_outputs = self.llm(
370 |             input_ids=input_ids,
371 |             encoder_hidden_states=encoder_outputs,
372 |             attention_mask=attention_mask,
373 |             labels=labels,
374 |         )
375 |         return decoder_outputs
376 | 
377 |     def get_hidden_states_during_inference(
378 |         self,
379 |         prompt_ids: torch.Tensor,
380 |         image: Image.Image = None,
381 |         image_tensors: Optional[torch.Tensor] = None,
382 |     ):
383 |         if image_tensors is None:
384 |             image_tensors = self.vpm.prepare_input(image).unsqueeze(0)
385 | 
386 |         if self.device.type != "mps":
387 |             image_tensors = image_tensors.to(next(self.parameters()).dtype)
388 | 
389 |         image_tensors = image_tensors.to(self.device)
390 |         prompt_ids = prompt_ids.to(self.device)
391 |         all_hidden_states = self.vpm.forward_features(
392 |             image_tensors, text_embedding=self.get_input_embeddings(prompt_ids)
393 |         )
394 |         return all_hidden_states
395 | 
396 |     def get_attn_weights_during_inference(
397 |         self,
398 |         prompt_ids: torch.Tensor,
399 |         image: Image.Image = None,
400 |         image_tensors: Optional[torch.Tensor] = None,
401 |     ):
402 |         if image_tensors is None:
403 |             image_tensors = self.vpm.prepare_input(image).unsqueeze(0)
404 | 
405 |         if self.device.type != "mps":
406 |             image_tensors = image_tensors.to(next(self.parameters()).dtype)
407 | 
408 |         image_tensors = image_tensors.to(self.device)
409 |         prompt_ids = prompt_ids.to(self.device)
410 |         last_attn_score = self.vpm.get_last_layer_cross_attn_score(
411 |             image_tensors, text_embedding=self.get_input_embeddings(prompt_ids)
412 |         )
413 |         return last_attn_score
414 | 
415 |     def inference(
416 |         self,
417 |         prompt_ids: torch.Tensor,
418 |         image: Image.Image = None,
419 |         image_tensors: Optional[torch.Tensor] = None,
420 |         return_attentions: bool = False,
421 |         early_stopping: bool = True,
422 |     ):
423 |         """
424 |         Generate a token sequence in an auto-regressive manner.
425 | 
426 |         Args:
427 |             image: input document image (PIL.Image)
428 |             image_tensors: (1, num_channels, height, width)
429 |                 computed from `image` if not provided
430 |         """
431 |         output = {
432 |             "predictions": list(),
433 |             "sequences": list(),
434 |             "repeats": list(),
435 |             "repetitions": list(),
436 |         }
437 |         if image is None and image_tensors is None:
438 |             logging.warning("Image not found")
439 |             return output
440 | 
441 |         if image_tensors is None:
442 |             image_tensors = self.vpm.prepare_input(image).unsqueeze(0)
443 | 
444 |         if self.device.type != "mps":
445 |             image_tensors = image_tensors.to(next(self.parameters()).dtype)
446 | 
447 |         image_tensors = image_tensors.to(self.device)
448 |         prompt_ids = prompt_ids.to(self.device)
449 |         last_hidden_state = self.vpm(image_tensors, text_embedding=self.get_input_embeddings(prompt_ids))
450 | 
451 |         encoder_outputs = ModelOutput(last_hidden_state=last_hidden_state, attentions=None)
452 |         if len(encoder_outputs.last_hidden_state.size()) == 1:
453 |             encoder_outputs.last_hidden_state = encoder_outputs.last_hidden_state.unsqueeze(0)
454 | 
455 |         # get decoder output
456 |         decoder_output = self.llm.model.generate(
457 |             input_ids=prompt_ids,
458 |             encoder_outputs=encoder_outputs,
459 |             min_length=1,
460 |             max_length=self.config.max_length,
461 |             pad_token_id=self.llm.tokenizer.pad_token_id,
462 |             eos_token_id=self.llm.tokenizer.eos_token_id,
463 |             use_cache=True,
464 |             return_dict_in_generate=True,
465 |             output_scores=True,
466 |             output_attentions=return_attentions,
467 |             do_sample=False,
468 |             num_beams=1,
469 |             stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()] if early_stopping else []),
470 |         )
471 | 
472 |         output["repetitions"] = decoder_output.sequences.clone()
473 |         output["sequences"] = decoder_output.sequences.clone()
474 |         output["scores"] = torch.stack(decoder_output.scores, 1).softmax(-1).cpu().max(-1)[0]
475 | 
476 |         output["repetitions"] = self.llm.tokenizer.batch_decode(output["repetitions"], skip_special_tokens=False)
477 |         return output
478 | 

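Note: a minimal end-to-end sketch of `DonutModel.inference` (names are illustrative: it assumes `model` is an already-loaded DonutModel and `processor` is the DolphinProcessor from utils/processor.py; the exact prompt strings and model-loading code live in chat.py and the demo scripts):

    from PIL import Image

    image = Image.open("demo/page_imgs/page_1.jpeg").convert("RGB")
    prompt_ids = processor.process_prompt_for_inference("Parse the reading order of this document.")
    image_tensors = processor.process_image_for_inference(image)

    out = model.inference(prompt_ids=prompt_ids, image_tensors=image_tensors)
    print(out["repetitions"][0])  # decoded sequence, special tokens included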

--------------------------------------------------------------------------------
/utils/processor.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
 3 | SPDX-License-Identifier: MIT
 4 | """
 5 | 
 6 | import numpy as np
 7 | import torch
 8 | from PIL import ImageOps
 9 | from torchvision import transforms
10 | from torchvision.transforms.functional import resize
11 | 
12 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
13 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
14 | 
15 | 
16 | class DolphinProcessor:
17 |     def __init__(
18 |         self,
19 |         dp_config,
20 |         tokenizer,
21 |         **kwargs,
22 |     ) -> None:
23 | 
24 |         self.tokenizer = tokenizer
25 |         transform_args = kwargs.get("transform_args", {})
26 |         self.max_length = transform_args.get("max_length", 2048)
27 |         self.input_size = transform_args.get("input_size", [896, 896])  # height, width
28 |         if isinstance(self.input_size, int):
29 |             self.input_size = [self.input_size, self.input_size]
30 | 
31 |         try:
32 |             self.answer_start_token = self.tokenizer._prompt_end_token
33 |         except AttributeError:
34 |             print('No _prompt_end_token found on tokenizer, using "" as answer_start_token')
35 |             self.answer_start_token = ""
36 | 
37 |         self.prefix_answer_space_flag = dp_config.get("prefix_answer_space_flag", True)
38 |         self.suffix_prompt_space_flag = dp_config.get("suffix_prompt_space_flag", True)
39 | 
40 |         self.transform = transforms.Compose(
41 |             [transforms.ToTensor(), transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)]
42 |         )
43 | 
44 |     def process_prompt_for_inference(self, prompt):
45 |         prompt = prompt.replace("<image>\n", "")
46 |         if not prompt.startswith("<s>"):
47 |             prompt = "<s>" + prompt
48 |         message_ids = [self.tokenizer.encode(prompt, add_special_tokens=False)]
49 |         ids = torch.from_numpy(np.hstack(message_ids, dtype=np.int32))
50 |         return ids.unsqueeze(0)
51 | 
52 |     def process_image_for_inference(self, image, return_img_size=False):
53 |         image = resize(image, min(self.input_size))
54 | 
55 |         image.thumbnail((self.input_size[1], self.input_size[0]))
56 |         origin_w, origin_h = image.size
57 | 
58 |         delta_width = self.input_size[1] - image.width
59 |         delta_height = self.input_size[0] - image.height
60 |         pad_width = delta_width // 2
61 |         pad_height = delta_height // 2
62 |         padding = (
63 |             pad_width,
64 |             pad_height,
65 |             delta_width - pad_width,
66 |             delta_height - pad_height,
67 |         )
68 |         image = ImageOps.expand(image, padding)
69 |         if return_img_size:
70 |             return self.transform(image).unsqueeze(0), (origin_w, origin_h)
71 |         return self.transform(image).unsqueeze(0)
72 | 

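Note: the resize-and-pad math in `process_image_for_inference` is easiest to verify on a dummy input. A minimal sketch with the default 896x896 input size (`tokenizer=None` is fine here because only the image path is exercised):

    from PIL import Image
    from utils.processor import DolphinProcessor

    proc = DolphinProcessor({}, tokenizer=None)
    img = Image.new("RGB", (1200, 400), "white")
    tensor, (w, h) = proc.process_image_for_inference(img, return_img_size=True)
    print(tensor.shape)  # torch.Size([1, 3, 896, 896]) after resize + center padding
    print((w, h))        # pre-padding thumbnail size, roughly (896, 298)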

--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
  1 | """
  2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
  3 | SPDX-License-Identifier: MIT
  4 | """
  5 | 
  6 | import copy
  7 | import io
  8 | import json
  9 | import os
 10 | import re
 11 | from dataclasses import dataclass
 12 | from typing import List, Tuple
 13 | 
 14 | import cv2
 15 | import numpy as np
 16 | import pymupdf
 17 | from PIL import Image
 18 | 
 19 | from utils.markdown_utils import MarkdownConverter
 20 | 
 21 | 
 22 | def save_figure_to_local(pil_crop, save_dir, image_name, reading_order):
 23 |     """Save cropped figure to local file system
 24 | 
 25 |     Args:
 26 |         pil_crop: PIL Image object of the cropped figure
 27 |         save_dir: Base directory to save results
 28 |         image_name: Name of the source image/document
 29 |         reading_order: Reading order of the figure in the document
 30 | 
 31 |     Returns:
 32 |         str: Filename of the saved figure
 33 |     """
 34 |     try:
 35 |         # Create figures directory if it doesn't exist
 36 |         figures_dir = os.path.join(save_dir, "markdown", "figures")
 37 |         os.makedirs(figures_dir, exist_ok=True)
 38 | 
 39 |         # Generate figure filename
 40 |         figure_filename = f"{image_name}_figure_{reading_order:03d}.png"
 41 |         figure_path = os.path.join(figures_dir, figure_filename)
 42 | 
 43 |         # Save the figure
 44 |         pil_crop.save(figure_path, format="PNG")  # PNG is lossless; the JPEG-style quality option does not apply
 45 | 
 46 |         # print(f"Saved figure: {figure_filename}")
 47 |         return figure_filename
 48 | 
 49 |     except Exception as e:
 50 |         print(f"Error saving figure: {str(e)}")
 51 |         # Return a fallback filename
 52 |         return f"{image_name}_figure_{reading_order:03d}_error.png"
 53 | 
 54 | 
 55 | def convert_pdf_to_images(pdf_path, target_size=896):
 56 |     """Convert PDF pages to images
 57 | 
 58 |     Args:
 59 |         pdf_path: Path to PDF file
 60 |         target_size: Target size for the longest dimension
 61 | 
 62 |     Returns:
 63 |         List of PIL Images
 64 |     """
 65 |     images = []
 66 |     try:
 67 |         doc = pymupdf.open(pdf_path)
 68 | 
 69 |         for page_num in range(len(doc)):
 70 |             page = doc[page_num]
 71 | 
 72 |             # Calculate scale to make longest dimension equal to target_size
 73 |             rect = page.rect
 74 |             scale = target_size / max(rect.width, rect.height)
 75 | 
 76 |             # Render page as image
 77 |             mat = pymupdf.Matrix(scale, scale)
 78 |             pix = page.get_pixmap(matrix=mat)
 79 | 
 80 |             # Convert to PIL Image
 81 |             img_data = pix.tobytes("png")
 82 |             pil_image = Image.open(io.BytesIO(img_data))
 83 |             images.append(pil_image)
 84 | 
 85 |         doc.close()
 86 |         print(f"Successfully converted {len(images)} pages from PDF")
 87 |         return images
 88 | 
 89 |     except Exception as e:
 90 |         print(f"Error converting PDF to images: {str(e)}")
 91 |         return []
 92 | 
 93 | 
 94 | def is_pdf_file(file_path):
 95 |     """Check if file is a PDF"""
 96 |     return file_path.lower().endswith(".pdf")
 97 | 
 98 | 
 99 | def save_combined_pdf_results(all_page_results, pdf_path, save_dir):
100 |     """Save combined results for multi-page PDF with both JSON and Markdown
101 | 
102 |     Args:
103 |         all_page_results: List of results for all pages
104 |         pdf_path: Path to original PDF file
105 |         save_dir: Directory to save results
106 | 
107 |     Returns:
108 |         Path to saved combined JSON file
109 |     """
110 |     # Create output filename based on PDF name
111 |     base_name = os.path.splitext(os.path.basename(pdf_path))[0]
112 | 
113 |     # Prepare combined results
114 |     combined_results = {"source_file": pdf_path, "total_pages": len(all_page_results), "pages": all_page_results}
115 | 
116 |     # Save combined JSON results
117 |     json_filename = f"{base_name}.json"
118 |     json_path = os.path.join(save_dir, "recognition_json", json_filename)
119 |     os.makedirs(os.path.dirname(json_path), exist_ok=True)
120 | 
121 |     with open(json_path, "w", encoding="utf-8") as f:
122 |         json.dump(combined_results, f, indent=2, ensure_ascii=False)
123 | 
124 |     # Generate and save combined markdown
125 |     try:
126 |         markdown_converter = MarkdownConverter()
127 | 
128 |         # Combine all page results into a single list for markdown conversion
129 |         all_elements = []
130 |         for page_data in all_page_results:
131 |             page_elements = page_data.get("elements", [])
132 |             if page_elements:
133 |                 # Add page separator if not the first page
134 |                 if all_elements:
135 |                     all_elements.append(
136 |                         {"label": "page_separator", "text": "\n\n---\n\n", "reading_order": len(all_elements)}
137 |                     )
138 |                 all_elements.extend(page_elements)
139 | 
140 |         # Generate markdown content
141 |         markdown_content = markdown_converter.convert(all_elements)
142 | 
143 |         # Save markdown file
144 |         markdown_filename = f"{base_name}.md"
145 |         markdown_path = os.path.join(save_dir, "markdown", markdown_filename)
146 |         os.makedirs(os.path.dirname(markdown_path), exist_ok=True)
147 | 
148 |         with open(markdown_path, "w", encoding="utf-8") as f:
149 |             f.write(markdown_content)
150 | 
151 |         # print(f"Combined markdown saved to: {markdown_path}")
152 | 
153 |     except ImportError:
154 |         print("MarkdownConverter not available, skipping markdown generation")
155 |     except Exception as e:
156 |         print(f"Error generating markdown: {e}")
157 | 
158 |     # print(f"Combined JSON results saved to: {json_path}")
159 |     return json_path
160 | 
161 | 
162 | def check_coord_valid(x1, y1, x2, y2, image_size=None, abs_coord=True):
163 |     # print(f"check_coord_valid: {x1}, {y1}, {x2}, {y2}, {image_size}, {abs_coord}")
164 |     if x2 <= x1 or y2 <= y1:
165 |         return False, f"[{x1}, {y1}, {x2}, {y2}]"
166 |     if x1 < 0 or y1 < 0:
167 |         return False, f"[{x1}, {y1}, {x2}, {y2}]"
168 |     if not abs_coord:
169 |         if x2 > 1 or y2 > 1:
170 |             return False, f"[{x1}, {y1}, {x2}, {y2}]"
171 |     elif image_size is not None:  # has image size
172 |         if x2 > image_size[0] or y2 > image_size[1]:
173 |             return False, f"[{x1}, {y1}, {x2}, {y2}]"
174 |     return True, None
175 | 
176 | 
177 | def adjust_box_edges(image, boxes: List[List[float]], max_pixels=15, threshold=0.2):
178 |     """Widen boxes so that each edge lies on clean background.
179 |     image: cv2 image (BGR array) or path to an image file
180 |     boxes: list of boxes [[x1, y1, x2, y2]] in absolute pixel coordinates
181 |     """
182 |     if isinstance(image, str):
183 |         image = cv2.imread(image)
184 |     img_h, img_w = image.shape[:2]
185 |     new_boxes = []
186 |     for box in boxes:
187 |         best_box = copy.deepcopy(box)
188 | 
189 |         def check_edge(img, current_box, i, is_vertical):
190 |             edge = current_box[i]
191 |             gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
192 |             _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
193 | 
194 |             if is_vertical:
195 |                 line = binary[current_box[1] : current_box[3] + 1, edge]
196 |             else:
197 |                 line = binary[edge, current_box[0] : current_box[2] + 1]
198 | 
199 |             transitions = np.abs(np.diff(line))
200 |             return np.sum(transitions) / len(transitions)
201 | 
202 |         # Only widen the box
203 |         edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)]
204 | 
205 |         current_box = copy.deepcopy(box)
206 |         # make sure the box is within the image
207 |         current_box[0] = min(max(current_box[0], 0), img_w - 1)
208 |         current_box[1] = min(max(current_box[1], 0), img_h - 1)
209 |         current_box[2] = min(max(current_box[2], 0), img_w - 1)
210 |         current_box[3] = min(max(current_box[3], 0), img_h - 1)
211 | 
212 |         for i, direction, is_vertical in edges:
213 |             best_score = check_edge(image, current_box, i, is_vertical)
214 |             if best_score <= threshold:
215 |                 continue
216 |             for step in range(max_pixels):
217 |                 current_box[i] += direction
218 |                 if i == 0 or i == 2:
219 |                     current_box[i] = min(max(current_box[i], 0), img_w - 1)
220 |                 else:
221 |                     current_box[i] = min(max(current_box[i], 0), img_h - 1)
222 |                 score = check_edge(image, current_box, i, is_vertical)
223 | 
224 |                 if score < best_score:
225 |                     best_score = score
226 |                     best_box = copy.deepcopy(current_box)
227 | 
228 |                 if score <= threshold:
229 |                     break
230 |         new_boxes.append(best_box)
231 | 
232 |     return new_boxes
233 | 
234 | 
235 | def parse_layout_string(bbox_str):
236 |     """Parse layout string using regular expressions"""
237 |     pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)"
238 |     matches = re.finditer(pattern, bbox_str)
239 | 
240 |     parsed_results = []
241 |     for match in matches:
242 |         coords = [float(match.group(i)) for i in range(1, 5)]
243 |         label = match.group(5).strip()
244 |         parsed_results.append((coords, label))
245 | 
246 |     return parsed_results
247 | 
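# Example: the regex above turns a flat "[x1, y1, x2, y2] label ..." string into
# (coords, label) pairs, e.g.
#
#   parse_layout_string("[0.07, 0.05, 0.93, 0.11] title [0.07, 0.12, 0.93, 0.35] para")
#   -> [([0.07, 0.05, 0.93, 0.11], 'title'), ([0.07, 0.12, 0.93, 0.35], 'para')]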
248 | 
249 | @dataclass
250 | class ImageDimensions:
251 |     """Class to store image dimensions"""
252 | 
253 |     original_w: int
254 |     original_h: int
255 |     padded_w: int
256 |     padded_h: int
257 | 
258 | 
259 | def map_to_original_coordinates(x1, y1, x2, y2, dims: ImageDimensions) -> Tuple[int, int, int, int]:
260 |     """Map coordinates from padded image back to original image
261 | 
262 |     Args:
263 |         x1, y1, x2, y2: Coordinates in padded image
264 |         dims: Image dimensions object
265 | 
266 |     Returns:
267 |         tuple: (x1, y1, x2, y2) coordinates in original image
268 |     """
269 |     try:
270 |         # Calculate padding offsets
271 |         top = (dims.padded_h - dims.original_h) // 2
272 |         left = (dims.padded_w - dims.original_w) // 2
273 | 
274 |         # Map back to original coordinates
275 |         orig_x1 = max(0, x1 - left)
276 |         orig_y1 = max(0, y1 - top)
277 |         orig_x2 = min(dims.original_w, x2 - left)
278 |         orig_y2 = min(dims.original_h, y2 - top)
279 | 
280 |         # Ensure we have a valid box (width and height > 0)
281 |         if orig_x2 <= orig_x1:
282 |             orig_x2 = min(orig_x1 + 1, dims.original_w)
283 |         if orig_y2 <= orig_y1:
284 |             orig_y2 = min(orig_y1 + 1, dims.original_h)
285 | 
286 |         return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2)
287 |     except Exception as e:
288 |         print(f"map_to_original_coordinates error: {str(e)}")
289 |         # Return safe coordinates
290 |         return 0, 0, min(100, dims.original_w), min(100, dims.original_h)
291 | 
292 | 
293 | def map_to_relevant_coordinates(abs_coords, dims: ImageDimensions):
294 |     """
295 |     From absolute coordinates to relative (normalized) coordinates
296 |     e.g. on a 1000x1000 image, [100, 100, 200, 200] -> (0.1, 0.1, 0.2, 0.2)
297 |     """
298 |     try:
299 |         x1, y1, x2, y2 = abs_coords
300 |         return (
301 |             round(x1 / dims.original_w, 3),
302 |             round(y1 / dims.original_h, 3),
303 |             round(x2 / dims.original_w, 3),
304 |             round(y2 / dims.original_h, 3),
305 |         )
306 |     except Exception as e:
307 |         print(f"map_to_relevant_coordinates error: {str(e)}")
308 |         return 0.0, 0.0, 1.0, 1.0  # Return full image coordinates
309 | 
310 | 
311 | def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None):
312 |     """Process and adjust coordinates
313 | 
314 |     Args:
315 |         coords: Normalized coordinates [x1, y1, x2, y2]
316 |         padded_image: Padded image
317 |         dims: Image dimensions object
318 |         previous_box: Previous box coordinates for overlap adjustment
319 | 
320 |     Returns:
321 |         tuple: (x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box)
322 |     """
323 |     try:
324 |         # Convert normalized coordinates to absolute coordinates
325 |         x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h)
326 |         x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h)
327 | 
328 |         # Ensure coordinates are within image bounds before adjustment
329 |         x1 = max(0, min(x1, dims.padded_w - 1))
330 |         y1 = max(0, min(y1, dims.padded_h - 1))
331 |         x2 = max(0, min(x2, dims.padded_w))
332 |         y2 = max(0, min(y2, dims.padded_h))
333 | 
334 |         # Ensure width and height are at least 1 pixel
335 |         if x2 <= x1:
336 |             x2 = min(x1 + 1, dims.padded_w)
337 |         if y2 <= y1:
338 |             y2 = min(y1 + 1, dims.padded_h)
339 | 
340 |         # Extend box boundaries
341 |         new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, y2]])
342 |         x1, y1, x2, y2 = new_boxes[0]
343 | 
344 |         # Ensure coordinates are still within image bounds after adjustment
345 |         x1 = max(0, min(x1, dims.padded_w - 1))
346 |         y1 = max(0, min(y1, dims.padded_h - 1))
347 |         x2 = max(0, min(x2, dims.padded_w))
348 |         y2 = max(0, min(y2, dims.padded_h))
349 | 
350 |         # Ensure width and height are at least 1 pixel after adjustment
351 |         if x2 <= x1:
352 |             x2 = min(x1 + 1, dims.padded_w)
353 |         if y2 <= y1:
354 |             y2 = min(y1 + 1, dims.padded_h)
355 | 
356 |         # Check for overlap with previous box and adjust
357 |         if previous_box is not None:
358 |             prev_x1, prev_y1, prev_x2, prev_y2 = previous_box
359 |             if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1):
360 |                 y1 = prev_y2
361 |                 # Ensure y1 is still valid
362 |                 y1 = min(y1, dims.padded_h - 1)
363 |                 # Make sure y2 is still greater than y1
364 |                 if y2 <= y1:
365 |                     y2 = min(y1 + 1, dims.padded_h)
366 | 
367 |         # Update previous box
368 |         new_previous_box = [x1, y1, x2, y2]
369 | 
370 |         # Map to original coordinates
371 |         orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates(x1, y1, x2, y2, dims)
372 | 
373 |         return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box
374 |     except Exception as e:
375 |         print(f"process_coordinates error: {str(e)}")
376 |         # Return safe values
377 |         orig_x1, orig_y1, orig_x2, orig_y2 = 0, 0, min(100, dims.original_w), min(100, dims.original_h)
378 |         return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100]
379 | 
380 | 
381 | def prepare_image(image) -> Tuple[np.ndarray, ImageDimensions]:
382 |     """Load and prepare image with padding while maintaining aspect ratio
383 | 
384 |     Args:
385 |         image: PIL image
386 | 
387 |     Returns:
388 |         tuple: (padded_image, image_dimensions)
389 |     """
390 |     try:
391 |         # Convert PIL image to OpenCV format
392 |         image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
393 |         original_h, original_w = image.shape[:2]
394 | 
395 |         # Calculate padding to make square image
396 |         max_size = max(original_h, original_w)
397 |         top = (max_size - original_h) // 2
398 |         bottom = max_size - original_h - top
399 |         left = (max_size - original_w) // 2
400 |         right = max_size - original_w - left
401 | 
402 |         # Apply padding
403 |         padded_image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0))
404 | 
405 |         padded_h, padded_w = padded_image.shape[:2]
406 | 
407 |         dimensions = ImageDimensions(original_w=original_w, original_h=original_h, padded_w=padded_w, padded_h=padded_h)
408 | 
409 |         return padded_image, dimensions
410 |     except Exception as e:
411 |         print(f"prepare_image error: {str(e)}")
412 |         # Create a minimal valid image and dimensions
413 |         h, w = image.height, image.width
414 |         dimensions = ImageDimensions(original_w=w, original_h=h, padded_w=w, padded_h=h)
415 |         # Return a black image of the same size
416 |         return np.zeros((h, w, 3), dtype=np.uint8), dimensions
417 | 
418 | 
419 | def setup_output_dirs(save_dir):
420 |     """Create necessary output directories"""
421 |     os.makedirs(save_dir, exist_ok=True)
422 |     os.makedirs(os.path.join(save_dir, "markdown"), exist_ok=True)
423 |     os.makedirs(os.path.join(save_dir, "recognition_json"), exist_ok=True)
424 |     os.makedirs(os.path.join(save_dir, "markdown", "figures"), exist_ok=True)
425 | 
426 | 
427 | def save_outputs(recognition_results, image_path, save_dir):
428 |     """Save JSON and markdown outputs"""
429 |     basename = os.path.splitext(os.path.basename(image_path))[0]
430 | 
431 |     # Save JSON file
432 |     json_path = os.path.join(save_dir, "recognition_json", f"{basename}.json")
433 |     with open(json_path, "w", encoding="utf-8") as f:
434 |         json.dump(recognition_results, f, ensure_ascii=False, indent=2)
435 | 
436 |     # Generate and save markdown file
437 |     markdown_converter = MarkdownConverter()
438 |     markdown_content = markdown_converter.convert(recognition_results)
439 |     markdown_path = os.path.join(save_dir, "markdown", f"{basename}.md")
440 |     with open(markdown_path, "w", encoding="utf-8") as f:
441 |         f.write(markdown_content)
442 | 
443 |     return json_path
444 | 
445 | 
446 | def crop_margin(img: Image.Image) -> Image.Image:
447 |     """Crop margins from image"""
448 |     try:
449 |         width, height = img.size
450 |         if width == 0 or height == 0:
451 |             print("Warning: Image has zero width or height")
452 |             return img
453 | 
454 |         data = np.array(img.convert("L"))
455 |         data = data.astype(np.uint8)
456 |         max_val = data.max()
457 |         min_val = data.min()
458 |         if max_val == min_val:
459 |             return img
460 |         data = (data - min_val) / (max_val - min_val) * 255
461 |         gray = 255 * (data < 200).astype(np.uint8)
462 | 
463 |         coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
464 |         if coords is None:
465 |             return img
466 |         a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
467 | 
468 |         # Ensure crop coordinates are within image bounds
469 |         a = max(0, a)
470 |         b = max(0, b)
471 |         w = min(w, width - a)
472 |         h = min(h, height - b)
473 | 
474 |         # Only crop if we have a valid region
475 |         if w > 0 and h > 0:
476 |             return img.crop((a, b, a + w, b + h))
477 |         return img
478 |     except Exception as e:
479 |         print(f"crop_margin error: {str(e)}")
480 |         return img  # Return original image on error
481 | 

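Note: the coordinate helpers above are designed to be chained. A minimal sketch of the per-page flow (the layout string is illustrative; in the demos it comes from the model's layout prediction):

    from PIL import Image
    from utils.utils import parse_layout_string, prepare_image, process_coordinates

    image = Image.open("demo/page_imgs/page_1.jpeg").convert("RGB")
    padded_image, dims = prepare_image(image)

    previous_box = None
    for coords, label in parse_layout_string("[0.07, 0.05, 0.93, 0.11] title"):
        x1, y1, x2, y2, ox1, oy1, ox2, oy2, previous_box = process_coordinates(
            coords, padded_image, dims, previous_box
        )
        print(label, (ox1, oy1, ox2, oy2))  # element box in original-image pixels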

--------------------------------------------------------------------------------