├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── README_CN.md
├── assets
│   ├── demo.gif
│   ├── dolphin.png
│   └── framework.png
├── chat.py
├── config
│   └── Dolphin.yaml
├── demo
│   ├── element_imgs
│   │   ├── block_formula.jpeg
│   │   ├── line_formula.jpeg
│   │   ├── para_1.jpg
│   │   ├── para_2.jpg
│   │   ├── para_3.jpeg
│   │   ├── table_1.jpeg
│   │   └── table_2.jpeg
│   └── page_imgs
│       ├── page_1.jpeg
│       ├── page_2.jpeg
│       ├── page_3.jpeg
│       ├── page_4.png
│       ├── page_5.jpg
│       ├── page_6.pdf
│       └── page_7.jpeg
├── demo_element.py
├── demo_element_hf.py
├── demo_page.py
├── demo_page_hf.py
├── deployment
│   ├── ReadMe.md
│   ├── tensorrt_llm
│   │   ├── ReadMe.md
│   │   ├── api_client.py
│   │   ├── api_server.py
│   │   ├── convert
│   │   │   ├── __init__.py
│   │   │   ├── build_visual_engine.py
│   │   │   ├── convert_checkpoint.py
│   │   │   └── helper.py
│   │   ├── convert_dolphin.sh
│   │   ├── dolphin_runner.py
│   │   ├── run_dolphin.py
│   │   ├── run_dolphin.sh
│   │   ├── start_dolphin_server.sh
│   │   └── utils.py
│   └── vllm
│       ├── ReadMe.md
│       ├── api_client.py
│       ├── api_server.py
│       └── demo_vllm.py
├── pyproject.toml
├── requirements.txt
└── utils
    ├── markdown_utils.py
    ├── model.py
    ├── processor.py
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .nox/
42 | .coverage
43 | *.cover
44 | *.py,cover
45 | .hypothesis/
46 | .pytest_cache/
47 | coverage.xml
48 | *.mo
49 | *.pot
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 | db.sqlite3-journal
60 |
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 |
65 | # Scrapy stuff:
66 | .scrapy
67 |
68 | # Sphinx documentation
69 | docs/_build/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 |
77 | # IPython
78 | profile_default/
79 | ipython_config.py
80 |
81 | # pyenv
82 | .python-version
83 |
84 | # pipenv
85 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
86 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
87 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
88 | # install all needed dependencies.
89 | #Pipfile.lock
90 |
91 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
92 | __pypackages__/
93 |
94 | # Celery stuff
95 | celerybeat-schedule
96 | celerybeat.pid
97 |
98 | # SageMath parsed files
99 | *.sage.py
100 |
101 | # Environments
102 | .env
103 | .venv
104 | env/
105 | venv/
106 | ENV/
107 | env.bak/
108 | venv.bak/
109 |
110 | # Spyder project settings
111 | .spyderproject
112 | .spyproject
113 |
114 | # Rope project settings
115 | .ropeproject
116 |
117 | # mkdocs documentation
118 | /site
119 |
120 | # mypy
121 | .mypy_cache/
122 | .dmypy.json
123 | dmypy.json
124 |
125 | # Pyre type checker
126 | .pyre/
127 |
128 | # pytype static type analyzer
129 | .pytype/
130 |
131 | # Cython debug symbols
132 | cython_debug/
133 |
134 | # PyCharm
135 | .idea/
136 | *.iml
137 |
138 | # VS Code
139 | .vscode/
140 | !.vscode/settings.json
141 | !.vscode/tasks.json
142 | !.vscode/launch.json
143 | !.vscode/extensions.json
144 |
145 | # macOS
146 | .DS_Store
147 |
148 | # Windows
149 | Thumbs.db
150 | ehthumbs.db
151 | Desktop.ini
152 |
153 | fusion_result.json
154 | kernel_meta/
155 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   # 1. isort - automatically sort Python imports
3 |   - repo: https://github.com/pycqa/isort
4 |     rev: 6.0.1  # pin to a fixed version
5 |     hooks:
6 |       - id: isort
7 |         name: isort (python)
8 |         args: [--profile=black]  # configuration compatible with Black
9 |         language: python
10 |
11 |   # 2. Black - automatically format Python code
12 |   - repo: https://github.com/psf/black
13 |     rev: 25.1.0  # pin to a fixed version
14 |     hooks:
15 |       - id: black
16 |         language: python
17 |
18 |   # 3. flake8 - static checks for Python
19 |   - repo: https://github.com/pycqa/flake8
20 |     rev: 7.2.0
21 |     hooks:
22 |       - id: flake8
23 |         args: [--max-line-length=120, --ignore=E203]  # set max line length to 120
24 |         additional_dependencies: [flake8-bugbear==24.12.12]  # optional: extra checks
25 |
26 |   # 4. pre-commit-hooks - general-purpose Git hooks
27 |   - repo: https://github.com/pre-commit/pre-commit-hooks
28 |     rev: v5.0.0
29 |     hooks:
30 |       - id: trailing-whitespace  # remove trailing whitespace
31 |       - id: end-of-file-fixer  # ensure files end with a newline
32 |       - id: check-yaml  # validate YAML file syntax
33 |       - id: check-added-large-files  # block commits of large files
34 |         args: ["--maxkb=512"]
35 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright 2025 ByteDance Ltd. and/or its affiliates
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
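The `.pre-commit-config.yaml` above wires isort, Black, flake8, and the standard pre-commit utility hooks into the repository. As a minimal usage sketch (assuming the `pre-commit` framework is installed from PyPI and the commands are run from the repository root; it is not shown as a pinned dependency here), the hooks can be enabled and exercised locally with:

```bash
# Install the pre-commit framework (assumed to come from PyPI)
pip install pre-commit

# Register the hooks from .pre-commit-config.yaml as a Git pre-commit hook
pre-commit install

# Optionally run every configured hook once against the whole codebase
pre-commit run --all-files
```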
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | <div align="center"> 2 | <img src="./assets/dolphin.png" width="300"> 3 | </div> 4 | 5 | <div align="center"> 6 | <a href="https://arxiv.org/abs/2505.14059"> 7 | <img src="https://img.shields.io/badge/Paper-arXiv-red"> 8 | </a> 9 | <a href="https://huggingface.co/ByteDance/Dolphin"> 10 | <img src="https://img.shields.io/badge/HuggingFace-Dolphin-yellow"> 11 | </a> 12 | <a href="https://modelscope.cn/models/ByteDance/Dolphin"> 13 | <img src="https://img.shields.io/badge/ModelScope-Dolphin-purple"> 14 | </a> 15 | <a href="http://115.190.42.15:8888/dolphin/"> 16 | <img src="https://img.shields.io/badge/Demo-Dolphin-blue"> 17 | </a> 18 | <a href="https://github.com/bytedance/Dolphin"> 19 | <img src="https://img.shields.io/badge/Code-Github-green"> 20 | </a> 21 | <a href="https://opensource.org/licenses/MIT"> 22 | <img src="https://img.shields.io/badge/License-MIT-lightgray"> 23 | </a> 24 | <br> 25 | </div> 26 | 27 | <br> 28 | 29 | <div align="center"> 30 | <img src="./assets/demo.gif" width="800"> 31 | </div> 32 | 33 | # Dolphin: Document Image Parsing via Heterogeneous Anchor Prompting 34 | 35 | Dolphin (**Do**cument Image **P**arsing via **H**eterogeneous Anchor Prompt**in**g) is a novel multimodal document image parsing model following an analyze-then-parse paradigm. This repository contains the demo code and pre-trained models for Dolphin. 36 | 37 | ## 📑 Overview 38 | 39 | Document image parsing is challenging due to its complexly intertwined elements such as text paragraphs, figures, formulas, and tables. Dolphin addresses these challenges through a two-stage approach: 40 | 41 | 1. **🔍 Stage 1**: Comprehensive page-level layout analysis by generating element sequence in natural reading order 42 | 2. **🧩 Stage 2**: Efficient parallel parsing of document elements using heterogeneous anchors and task-specific prompts 43 | 44 | <div align="center"> 45 | <img src="./assets/framework.png" width="680"> 46 | </div> 47 | 48 | Dolphin achieves promising performance across diverse page-level and element-level parsing tasks while ensuring superior efficiency through its lightweight architecture and parallel parsing mechanism. 49 | 50 | ## 🚀 Demo 51 | Try our demo on [Demo-Dolphin](http://115.190.42.15:8888/dolphin/). 52 | 53 | ## 📅 Changelog 54 | - 🔥 **2025.07.10** Released the *Fox-Page Benchmark*, a manually refined subset of the original [Fox dataset](https://github.com/ucaslcl/Fox). Download via: [Baidu Yun](https://pan.baidu.com/share/init?surl=t746ULp6iU5bUraVrPlMSw&pwd=fox1) | [Google Drive](https://drive.google.com/file/d/1yZQZqI34QCqvhB4Tmdl3X_XEvYvQyP0q/view?usp=sharing). 55 | - 🔥 **2025.06.30** Added [TensorRT-LLM support](https://github.com/bytedance/Dolphin/blob/master/deployment/tensorrt_llm/ReadMe.md) for accelerated inference! 56 | - 🔥 **2025.06.27** Added [vLLM support](https://github.com/bytedance/Dolphin/blob/master/deployment/vllm/ReadMe.md) for accelerated inference! 57 | - 🔥 **2025.06.13** Added multi-page PDF document parsing capability. 58 | - 🔥 **2025.05.21** Our demo is released at [link](http://115.190.42.15:8888/dolphin/). Check it out! 59 | - 🔥 **2025.05.20** The pretrained model and inference code of Dolphin are released. 60 | - 🔥 **2025.05.16** Our paper has been accepted by ACL 2025. Paper link: [arXiv](https://arxiv.org/abs/2505.14059). 61 | 62 | ## 🛠️ Installation 63 | 64 | 1. 
Clone the repository: 65 | ```bash 66 | git clone https://github.com/ByteDance/Dolphin.git 67 | cd Dolphin 68 | ``` 69 | 70 | 2. Install the dependencies: 71 | ```bash 72 | pip install -r requirements.txt 73 | ``` 74 | 75 | 3. Download the pre-trained models using one of the following options: 76 | 77 | **Option A: Original Model Format (config-based)** 78 | 79 | Download from [Baidu Yun](https://pan.baidu.com/s/15zcARoX0CTOHKbW8bFZovQ?pwd=9rpx) or [Google Drive](https://drive.google.com/drive/folders/1PQJ3UutepXvunizZEw-uGaQ0BCzf-mie?usp=sharing) and put them in the `./checkpoints` folder. 80 | 81 | **Option B: Hugging Face Model Format** 82 | 83 | Visit our Huggingface [model card](https://huggingface.co/ByteDance/Dolphin), or download model by: 84 | 85 | ```bash 86 | # Download the model from Hugging Face Hub 87 | git lfs install 88 | git clone https://huggingface.co/ByteDance/Dolphin ./hf_model 89 | # Or use the Hugging Face CLI 90 | pip install huggingface_hub 91 | huggingface-cli download ByteDance/Dolphin --local-dir ./hf_model 92 | ``` 93 | 94 | ## ⚡ Inference 95 | 96 | Dolphin provides two inference frameworks with support for two parsing granularities: 97 | - **Page-level Parsing**: Parse the entire document page into a structured JSON and Markdown format 98 | - **Element-level Parsing**: Parse individual document elements (text, table, formula) 99 | 100 | ### 📄 Page-level Parsing 101 | 102 | #### Using Original Framework (config-based) 103 | 104 | ```bash 105 | # Process a single document image 106 | python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs/page_1.jpeg --save_dir ./results 107 | 108 | # Process a single document pdf 109 | python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs/page_6.pdf --save_dir ./results 110 | 111 | # Process all documents in a directory 112 | python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs --save_dir ./results 113 | 114 | # Process with custom batch size for parallel element decoding 115 | python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs --save_dir ./results --max_batch_size 8 116 | ``` 117 | 118 | #### Using Hugging Face Framework 119 | 120 | ```bash 121 | # Process a single document image 122 | python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs/page_1.jpeg --save_dir ./results 123 | 124 | # Process a single document pdf 125 | python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs/page_6.pdf --save_dir ./results 126 | 127 | # Process all documents in a directory 128 | python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs --save_dir ./results 129 | 130 | # Process with custom batch size for parallel element decoding 131 | python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs --save_dir ./results --max_batch_size 16 132 | ``` 133 | 134 | ### 🧩 Element-level Parsing 135 | 136 | #### Using Original Framework (config-based) 137 | 138 | ```bash 139 | # Process a single table image 140 | python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/table_1.jpeg --element_type table 141 | 142 | # Process a single formula image 143 | python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/line_formula.jpeg --element_type formula 144 | 145 | # Process a single text paragraph image 146 | python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/para_1.jpg 
--element_type text 147 | ``` 148 | 149 | #### Using Hugging Face Framework 150 | 151 | ```bash 152 | # Process a single table image 153 | python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/table_1.jpeg --element_type table 154 | 155 | # Process a single formula image 156 | python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/line_formula.jpeg --element_type formula 157 | 158 | # Process a single text paragraph image 159 | python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/para_1.jpg --element_type text 160 | ``` 161 | 162 | ## 🌟 Key Features 163 | 164 | - 🔄 Two-stage analyze-then-parse approach based on a single VLM 165 | - 📊 Promising performance on document parsing tasks 166 | - 🔍 Natural reading order element sequence generation 167 | - 🧩 Heterogeneous anchor prompting for different document elements 168 | - ⏱️ Efficient parallel parsing mechanism 169 | - 🤗 Support for Hugging Face Transformers for easier integration 170 | 171 | 172 | ## 📮 Notice 173 | **Call for Bad Cases:** If you have encountered any cases where the model performs poorly, we would greatly appreciate it if you could share them in the issue. We are continuously working to optimize and improve the model. 174 | 175 | ## 💖 Acknowledgement 176 | 177 | We would like to acknowledge the following open-source projects that provided inspiration and reference for this work: 178 | - [Donut](https://github.com/clovaai/donut/) 179 | - [Nougat](https://github.com/facebookresearch/nougat) 180 | - [GOT](https://github.com/Ucas-HaoranWei/GOT-OCR2.0) 181 | - [MinerU](https://github.com/opendatalab/MinerU/tree/master) 182 | - [Swin](https://github.com/microsoft/Swin-Transformer) 183 | - [Hugging Face Transformers](https://github.com/huggingface/transformers) 184 | 185 | ## 📝 Citation 186 | 187 | If you find this code useful for your research, please use the following BibTeX entry. 
188 | 189 | ```bibtex 190 | @article{feng2025dolphin, 191 | title={Dolphin: Document Image Parsing via Heterogeneous Anchor Prompting}, 192 | author={Feng, Hao and Wei, Shu and Fei, Xiang and Shi, Wei and Han, Yingdong and Liao, Lei and Lu, Jinghui and Wu, Binghong and Liu, Qi and Lin, Chunhui and others}, 193 | journal={arXiv preprint arXiv:2505.14059}, 194 | year={2025} 195 | } 196 | ``` 197 | 198 | ## Star History 199 | 200 | [](https://www.star-history.com/#bytedance/Dolphin&Date) 201 | -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | <div align="center"> 2 | <img src="./assets/dolphin.png" width="300"> 3 | </div> 4 | 5 | <div align="center"> 6 | <a href="https://arxiv.org/abs/2505.14059"> 7 | <img src="https://img.shields.io/badge/论文-arXiv-red"> 8 | </a> 9 | <a href="https://huggingface.co/ByteDance/Dolphin"> 10 | <img src="https://img.shields.io/badge/HuggingFace-Dolphin-yellow"> 11 | </a> 12 | <a href="https://modelscope.cn/models/ByteDance/Dolphin"> 13 | <img src="https://img.shields.io/badge/ModelScope-Dolphin-purple"> 14 | </a> 15 | <a href="https://huggingface.co/spaces/ByteDance/Dolphin"> 16 | <img src="https://img.shields.io/badge/演示-Dolphin-blue"> 17 | </a> 18 | <a href="https://github.com/bytedance/Dolphin"> 19 | <img src="https://img.shields.io/badge/代码-Github-green"> 20 | </a> 21 | <a href="https://opensource.org/licenses/MIT"> 22 | <img src="https://img.shields.io/badge/许可证-MIT-lightgray"> 23 | </a> 24 | <br> 25 | </div> 26 | 27 | <br> 28 | 29 | <div align="center"> 30 | <img src="./assets/demo.gif" width="800"> 31 | </div> 32 | 33 | # Dolphin: 基于异构锚点提示的文档图像解析 34 | 35 | Dolphin(**Do**cument Image **P**arsing via **H**eterogeneous Anchor Prompt**in**g)是一个创新的多模态文档图像解析模型,采用"分析-解析"的两阶段范式。本仓库包含Dolphin的演示代码和预训练模型。 36 | 37 | ## 📑 概述 38 | 39 | 由于文档图像中文本段落、图表、公式和表格等元素的复杂交织,文档图像解析具有挑战性。Dolphin通过两阶段方法解决这些挑战: 40 | 41 | 1. **🔍 第一阶段**:通过按自然阅读顺序生成元素序列进行全面的页面级布局分析 42 | 2. **🧩 第二阶段**:使用异构锚点和任务特定提示高效并行解析文档元素 43 | 44 | <div align="center"> 45 | <img src="./assets/framework.png" width="680"> 46 | </div> 47 | 48 | Dolphin在多样化的页面级和元素级解析任务中取得了优异的性能,同时通过其轻量级架构和并行解析机制确保了卓越的效率。 49 | 50 | ## 🚀 演示 51 | 在 [Demo-Dolphin](http://115.190.42.15:8888/dolphin/) 上试用我们的演示。 52 | 53 | ## 📅 更新日志 54 | - 🔥 **2025.06.30** 新增[TensorRT-LLM](https://github.com/bytedance/Dolphin/blob/master/deployment/tensorrt_llm/ReadMe.md)支持,提升推理速度! 55 | - 🔥 **2025.06.27** 新增[vLLM](https://github.com/bytedance/Dolphin/blob/master/deployment/vllm/ReadMe.md)支持,提升推理速度! 56 | - 🔥 **2025.06.13** 新增多页PDF文档解析功能。 57 | - 🔥 **2025.05.21** 我们的演示已在 [链接](http://115.190.42.15:8888/dolphin/) 发布。快来体验吧! 58 | - 🔥 **2025.05.20** Dolphin的预训练模型和推理代码已发布。 59 | - 🔥 **2025.05.16** 我们的论文已被ACL 2025接收。论文链接:[arXiv](https://arxiv.org/abs/2505.14059)。 60 | 61 | ## 🛠️ 安装 62 | 63 | 1. 克隆仓库: 64 | ```bash 65 | git clone https://github.com/ByteDance/Dolphin.git 66 | cd Dolphin 67 | ``` 68 | 69 | 2. 安装依赖: 70 | ```bash 71 | pip install -r requirements.txt 72 | ``` 73 | 74 | 3. 
使用以下选项之一下载预训练模型: 75 | 76 | **选项A:原始模型格式(基于配置文件)** 77 | 78 | 从 [百度网盘](https://pan.baidu.com/s/15zcARoX0CTOHKbW8bFZovQ?pwd=9rpx) 或 [Google Drive](https://drive.google.com/drive/folders/1PQJ3UutepXvunizZEw-uGaQ0BCzf-mie?usp=sharing) 下载,并将其放在 `./checkpoints` 文件夹中。 79 | 80 | **选项B:Hugging Face模型格式** 81 | 82 | 访问我们的Huggingface [模型卡片](https://huggingface.co/ByteDance/Dolphin),或通过以下方式下载模型: 83 | 84 | ```bash 85 | # 从Hugging Face Hub下载模型 86 | git lfs install 87 | git clone https://huggingface.co/ByteDance/Dolphin ./hf_model 88 | # 或使用Hugging Face CLI 89 | pip install huggingface_hub 90 | huggingface-cli download ByteDance/Dolphin --local-dir ./hf_model 91 | ``` 92 | 93 | ## ⚡ 推理 94 | 95 | Dolphin提供两个推理框架,支持两种解析粒度: 96 | - **页面级解析**:将整个文档页面解析为结构化的JSON和Markdown格式 97 | - **元素级解析**:解析单个文档元素(文本、表格、公式) 98 | 99 | ### 📄 页面级解析 100 | 101 | #### 使用原始框架(基于配置文件) 102 | 103 | ```bash 104 | # 处理单个文档图像 105 | python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs/page_1.jpeg --save_dir ./results 106 | 107 | # 处理单个文档PDF 108 | python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs/page_6.pdf --save_dir ./results 109 | 110 | # 处理目录中的所有文档 111 | python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs --save_dir ./results 112 | 113 | # 使用自定义批次大小进行并行元素解码 114 | python demo_page.py --config ./config/Dolphin.yaml --input_path ./demo/page_imgs --save_dir ./results --max_batch_size 8 115 | ``` 116 | 117 | #### 使用Hugging Face框架 118 | 119 | ```bash 120 | # 处理单个文档图像 121 | python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs/page_1.jpeg --save_dir ./results 122 | 123 | # 处理单个文档PDF 124 | python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs/page_6.pdf --save_dir ./results 125 | 126 | # 处理目录中的所有文档 127 | python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs --save_dir ./results 128 | 129 | # 使用自定义批次大小进行并行元素解码 130 | python demo_page_hf.py --model_path ./hf_model --input_path ./demo/page_imgs --save_dir ./results --max_batch_size 16 131 | ``` 132 | 133 | ### 🧩 元素级解析 134 | 135 | #### 使用原始框架(基于配置文件) 136 | 137 | ```bash 138 | # 处理单个表格图像 139 | python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/table_1.jpeg --element_type table 140 | 141 | # 处理单个公式图像 142 | python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/line_formula.jpeg --element_type formula 143 | 144 | # 处理单个文本段落图像 145 | python demo_element.py --config ./config/Dolphin.yaml --input_path ./demo/element_imgs/para_1.jpg --element_type text 146 | ``` 147 | 148 | #### 使用Hugging Face框架 149 | 150 | ```bash 151 | # 处理单个表格图像 152 | python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/table_1.jpeg --element_type table 153 | 154 | # 处理单个公式图像 155 | python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/line_formula.jpeg --element_type formula 156 | 157 | # 处理单个文本段落图像 158 | python demo_element_hf.py --model_path ./hf_model --input_path ./demo/element_imgs/para_1.jpg --element_type text 159 | ``` 160 | 161 | ## 🌟 主要特性 162 | 163 | - 🔄 基于单一VLM的两阶段分析-解析方法 164 | - 📊 在文档解析任务上的优异性能 165 | - 🔍 自然阅读顺序元素序列生成 166 | - 🧩 针对不同文档元素的异构锚点提示 167 | - ⏱️ 高效的并行解析机制 168 | - 🤗 支持Hugging Face Transformers,便于集成 169 | 170 | ## 📮 通知 171 | **征集不良案例:** 如果您遇到模型表现不佳的案例,我们非常欢迎您在issue中分享。我们正在持续优化和改进模型。 172 | 173 | ## 💖 致谢 174 | 175 | 我们要感谢以下开源项目为本工作提供的灵感和参考: 176 | - [Donut](https://github.com/clovaai/donut/) 177 | - 
[Nougat](https://github.com/facebookresearch/nougat) 178 | - [GOT](https://github.com/Ucas-HaoranWei/GOT-OCR2.0) 179 | - [MinerU](https://github.com/opendatalab/MinerU/tree/master) 180 | - [Swin](https://github.com/microsoft/Swin-Transformer) 181 | - [Hugging Face Transformers](https://github.com/huggingface/transformers) 182 | 183 | ## 📝 引用 184 | 185 | 如果您在研究中发现此代码有用,请使用以下BibTeX条目。 186 | 187 | ```bibtex 188 | @article{feng2025dolphin, 189 | title={Dolphin: Document Image Parsing via Heterogeneous Anchor Prompting}, 190 | author={Feng, Hao and Wei, Shu and Fei, Xiang and Shi, Wei and Han, Yingdong and Liao, Lei and Lu, Jinghui and Wu, Binghong and Liu, Qi and Lin, Chunhui and others}, 191 | journal={arXiv preprint arXiv:2505.14059}, 192 | year={2025} 193 | } 194 | ``` 195 | 196 | ## 星标历史 197 | 198 | [](https://www.star-history.com/#bytedance/Dolphin&Date) 199 | -------------------------------------------------------------------------------- /assets/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/assets/demo.gif -------------------------------------------------------------------------------- /assets/dolphin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/assets/dolphin.png -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/assets/framework.png -------------------------------------------------------------------------------- /chat.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates 3 | SPDX-License-Identifier: MIT 4 | """ 5 | 6 | import os 7 | import warnings 8 | from collections import OrderedDict 9 | 10 | from omegaconf import ListConfig 11 | 12 | warnings.filterwarnings("ignore", category=UserWarning) 13 | warnings.filterwarnings("ignore", category=FutureWarning) 14 | os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python") 15 | 16 | import torch 17 | from PIL import Image 18 | from transformers import PreTrainedTokenizerFast 19 | 20 | from utils.model import DonutConfig, DonutModel, SwinEncoder 21 | from utils.processor import DolphinProcessor 22 | 23 | 24 | def try_rename_lagacy_weights(ckpt, output_path=""): 25 | if "state_dict" in ckpt.keys(): 26 | ckpt = ckpt["state_dict"] 27 | if "module" in ckpt.keys(): 28 | ckpt = ckpt["module"] 29 | new_ckpt = OrderedDict() 30 | for k, v in ckpt.items(): 31 | if k.startswith("model."): 32 | k = k[len("model.") :] 33 | if k.startswith("encoder"): 34 | new_ckpt["vpm" + k[len("encoder") :]] = v 35 | elif k.startswith("decoder"): 36 | new_ckpt["llm" + k[len("encoder") :]] = v 37 | else: 38 | new_ckpt[k] = v 39 | if output_path: 40 | torch.save(new_ckpt, output_path) 41 | return new_ckpt 42 | 43 | 44 | def convert_listconfig_to_list(config): 45 | new_config = {} 46 | for k, v in config.items(): 47 | if isinstance(v, ListConfig): 48 | new_config[k] = list(v) 49 | else: 50 | new_config[k] = v 51 | return new_config 52 | 53 | 54 | class DOLPHIN: 55 | def __init__(self, config, ckpt_path="") -> None: 56 | self.model_args = config.model 57 | self.swin_args = config.model.pop("swin_args") 58 | self.swin_args = convert_listconfig_to_list(self.swin_args) 59 | 60 | vision_tower = SwinEncoder( 61 | input_size=self.swin_args["img_size"], 62 | patch_size=self.swin_args["patch_size"], 63 | embed_dim=self.swin_args["embed_dim"], 64 | window_size=self.swin_args["window_size"], 65 | encoder_layer=self.swin_args["encoder_layer"], 66 | num_heads=self.swin_args["num_heads"], 67 | align_long_axis=self.swin_args["align_long_axis"], 68 | ) 69 | 70 | self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=self.model_args.tokenizer_path) 71 | self.tokenizer.pad_token = "<pad>" 72 | self.tokenizer.bos_token = "<s>" 73 | self.tokenizer.eos_token = "</s>" 74 | self.tokenizer.unk_token = "<unk>" 75 | 76 | if self.model_args.get("extra_answer_tokens", False): 77 | # print("Allowing multitask training: adding <Answer/> to the tokenizer.") 78 | prompt_end_token = " <Answer/>" 79 | self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set([prompt_end_token]))}) 80 | self.tokenizer._prompt_end_token = prompt_end_token 81 | self.tokenizer._prompt_end_token_id = self.tokenizer.convert_tokens_to_ids(prompt_end_token) 82 | 83 | donut_config = DonutConfig( 84 | decoder_layer=self.model_args.decoder_layer, 85 | max_length=self.model_args.max_length, 86 | max_position_embeddings=self.model_args.max_position_embeddings, 87 | hidden_dimension=self.model_args.hidden_dimension, 88 | ) 89 | 90 | self.model = DonutModel(config=donut_config, vision_tower=vision_tower, tokenizer=self.tokenizer) 91 | if self.model_args.model_name_or_path: 92 | ckpt = torch.load(self.model_args.model_name_or_path) 93 | ckpt = try_rename_lagacy_weights(ckpt) 94 | self.model.load_state_dict(ckpt, strict=True) 95 | 96 | device = "cuda" if torch.cuda.is_available() else "cpu" 97 | self.model.to(device) 98 | self.model.eval() 99 | transform_args = { 100 | "input_size": self.swin_args["img_size"], 101 | "max_length": 
self.model_args.max_length, 102 | } 103 | self.processor = DolphinProcessor({}, self.tokenizer, transform_args=transform_args) 104 | 105 | def chat( 106 | self, 107 | question, 108 | image, 109 | return_raw=False, 110 | return_score=False, 111 | return_img_size=False, 112 | only_return_img_size=False, 113 | max_batch_size=16, 114 | ): 115 | 116 | def _preprocess_image(image): 117 | if isinstance(image, str): 118 | image = Image.open(image).convert("RGB") 119 | if return_img_size or only_return_img_size: 120 | image_tensor, ori_size = self.processor.process_image_for_inference(image, return_img_size=True) 121 | else: 122 | image_tensor = self.processor.process_image_for_inference(image, return_img_size=False) 123 | ori_size = None 124 | return image_tensor, ori_size 125 | 126 | def _preprocess_prompt(question): 127 | if self.model_args.get("extra_answer_tokens", False): 128 | if self.tokenizer._prompt_end_token not in question: 129 | question = question + self.tokenizer._prompt_end_token 130 | prompt_ids = self.processor.process_prompt_for_inference(question) 131 | return prompt_ids 132 | 133 | def _preprocess_prompt_batch(question): 134 | if self.model_args.get("extra_answer_tokens", False): 135 | for i in range(len(question)): 136 | if self.tokenizer._prompt_end_token not in question[i]: 137 | question[i] = question[i] + self.tokenizer._prompt_end_token 138 | if not question[i].startswith("<s>"): 139 | question[i] = "<s>" + question[i] 140 | return question 141 | 142 | def _postprocess(output, question): 143 | output = output.replace("<s>", "").replace(question, "").replace("</s>", "").replace("<pad>", "") 144 | if self.model_args.get("extra_answer_tokens", False): 145 | output = output.split(self.tokenizer._prompt_end_token)[-1] 146 | return output 147 | 148 | if isinstance(question, list): 149 | image_tensor_list = [] 150 | for i in image: 151 | image_tensor, ori_size = _preprocess_image(i) 152 | image_tensor_list.append(image_tensor) 153 | image_tensor = torch.cat(image_tensor_list, dim=0) 154 | 155 | question = _preprocess_prompt_batch(question) 156 | self.processor.tokenizer.padding_side = "left" 157 | prompt_ids = self.processor.tokenizer( 158 | question, add_special_tokens=False, return_tensors="pt", padding=True 159 | ).input_ids 160 | else: 161 | image_tensor, ori_size = _preprocess_image(image) 162 | prompt_ids = _preprocess_prompt(question) 163 | 164 | if only_return_img_size: 165 | return ori_size 166 | 167 | model_output_batch = [] 168 | for i in range(0, image_tensor.shape[0], max_batch_size): 169 | image_tensor_batch = image_tensor[i : i + max_batch_size] 170 | prompt_ids_batch = prompt_ids[i : i + max_batch_size] 171 | model_output = self.model.inference(image_tensors=image_tensor_batch, prompt_ids=prompt_ids_batch) 172 | model_output_batch.append(model_output) 173 | model_output = {} 174 | for k, v in model_output_batch[0].items(): 175 | if isinstance(v, torch.Tensor): 176 | model_output[k] = sum( 177 | [v_batch[k].cpu().numpy().tolist() for v_batch in model_output_batch], 178 | [], 179 | ) 180 | else: 181 | model_output[k] = sum([v_batch[k] for v_batch in model_output_batch], []) 182 | 183 | if return_raw: 184 | if return_img_size: 185 | return model_output, ori_size 186 | return model_output 187 | else: 188 | if isinstance(question, list): 189 | output = [_postprocess(model_output["repetitions"][i], question[i]) for i in range(len(question))] 190 | score = model_output["scores"] 191 | else: 192 | output = _postprocess(model_output["repetitions"][0], question) 193 | 
score = model_output["scores"][0] 194 | if return_score: 195 | return output, score 196 | if return_img_size: 197 | return output, ori_size 198 | return output 199 | -------------------------------------------------------------------------------- /config/Dolphin.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | model_name_or_path: "./checkpoints/dolphin_model.bin" 3 | tokenizer_path: "./checkpoints/dolphin_tokenizer.json" 4 | extra_answer_tokens: True # add <Answer/> token 5 | max_length: 4096 6 | decoder_layer: 10 7 | max_position_embeddings: 4096 8 | hidden_dimension: 1024 9 | swin_args: 10 | name: 'swin' 11 | img_size: [896, 896] 12 | patch_size: 4 13 | embed_dim: 128 14 | align_long_axis: False 15 | window_size: 7 16 | encoder_layer: [2, 2, 14, 2] 17 | num_heads: [4, 8, 16, 32] 18 | -------------------------------------------------------------------------------- /demo/element_imgs/block_formula.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/element_imgs/block_formula.jpeg -------------------------------------------------------------------------------- /demo/element_imgs/line_formula.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/element_imgs/line_formula.jpeg -------------------------------------------------------------------------------- /demo/element_imgs/para_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/element_imgs/para_1.jpg -------------------------------------------------------------------------------- /demo/element_imgs/para_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/element_imgs/para_2.jpg -------------------------------------------------------------------------------- /demo/element_imgs/para_3.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/element_imgs/para_3.jpeg -------------------------------------------------------------------------------- /demo/element_imgs/table_1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/element_imgs/table_1.jpeg -------------------------------------------------------------------------------- /demo/element_imgs/table_2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/element_imgs/table_2.jpeg -------------------------------------------------------------------------------- /demo/page_imgs/page_1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/page_imgs/page_1.jpeg -------------------------------------------------------------------------------- /demo/page_imgs/page_2.jpeg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/page_imgs/page_2.jpeg -------------------------------------------------------------------------------- /demo/page_imgs/page_3.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/page_imgs/page_3.jpeg -------------------------------------------------------------------------------- /demo/page_imgs/page_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/page_imgs/page_4.png -------------------------------------------------------------------------------- /demo/page_imgs/page_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/page_imgs/page_5.jpg -------------------------------------------------------------------------------- /demo/page_imgs/page_6.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/page_imgs/page_6.pdf -------------------------------------------------------------------------------- /demo/page_imgs/page_7.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/demo/page_imgs/page_7.jpeg -------------------------------------------------------------------------------- /demo_element.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 3 | SPDX-License-Identifier: MIT 4 | """ 5 | 6 | import argparse 7 | import glob 8 | import os 9 | 10 | from omegaconf import OmegaConf 11 | from PIL import Image 12 | 13 | from chat import DOLPHIN 14 | from utils.utils import * 15 | 16 | 17 | def process_element(image_path, model, element_type, save_dir=None): 18 | """Process a single element image (text, table, formula) 19 | 20 | Args: 21 | image_path: Path to the element image 22 | model: DOLPHIN model instance 23 | element_type: Type of element ('text', 'table', 'formula') 24 | save_dir: Directory to save results (default: same as input directory) 25 | 26 | Returns: 27 | Parsed content of the element and recognition results 28 | """ 29 | # Load and prepare image 30 | pil_image = Image.open(image_path).convert("RGB") 31 | pil_image = crop_margin(pil_image) 32 | 33 | # Select appropriate prompt based on element type 34 | if element_type == "table": 35 | prompt = "Parse the table in the image." 36 | label = "tab" 37 | elif element_type == "formula": 38 | prompt = "Read text in the image." 39 | label = "formula" 40 | else: # Default to text 41 | prompt = "Read text in the image." 
42 | label = "text" 43 | 44 | # Process the element 45 | result = model.chat(prompt, pil_image) 46 | 47 | # Create recognition result in the same format as the document parser 48 | recognition_result = [ 49 | { 50 | "label": label, 51 | "text": result.strip(), 52 | } 53 | ] 54 | 55 | # Save results if save_dir is provided 56 | if save_dir: 57 | save_outputs(recognition_result, image_path, save_dir) 58 | print(f"Results saved to {save_dir}") 59 | 60 | return result, recognition_result 61 | 62 | 63 | def main(): 64 | parser = argparse.ArgumentParser(description="Element-level processing using DOLPHIN model") 65 | parser.add_argument("--config", default="./config/Dolphin.yaml", help="Path to configuration file") 66 | parser.add_argument("--input_path", type=str, required=True, help="Path to input image or directory of images") 67 | parser.add_argument( 68 | "--element_type", 69 | type=str, 70 | choices=["text", "table", "formula"], 71 | default="text", 72 | help="Type of element to process (text, table, formula)", 73 | ) 74 | parser.add_argument( 75 | "--save_dir", 76 | type=str, 77 | default=None, 78 | help="Directory to save parsing results (default: same as input directory)", 79 | ) 80 | parser.add_argument("--print_results", action="store_true", help="Print recognition results to console") 81 | args = parser.parse_args() 82 | 83 | # Load Model 84 | config = OmegaConf.load(args.config) 85 | model = DOLPHIN(config) 86 | 87 | # Set save directory 88 | save_dir = args.save_dir or ( 89 | args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path) 90 | ) 91 | setup_output_dirs(save_dir) 92 | 93 | # Collect Images 94 | if os.path.isdir(args.input_path): 95 | image_files = [] 96 | for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]: 97 | image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}"))) 98 | image_files = sorted(image_files) 99 | else: 100 | if not os.path.exists(args.input_path): 101 | raise FileNotFoundError(f"Input path {args.input_path} does not exist") 102 | image_files = [args.input_path] 103 | 104 | total_samples = len(image_files) 105 | print(f"\nTotal samples to process: {total_samples}") 106 | 107 | # Process images one by one 108 | for image_path in image_files: 109 | print(f"\nProcessing {image_path}") 110 | try: 111 | result, recognition_result = process_element( 112 | image_path=image_path, 113 | model=model, 114 | element_type=args.element_type, 115 | save_dir=save_dir, 116 | ) 117 | 118 | if args.print_results: 119 | print("\nRecognition result:") 120 | print(result) 121 | print("-" * 40) 122 | 123 | except Exception as e: 124 | print(f"Error processing {image_path}: {str(e)}") 125 | continue 126 | 127 | 128 | if __name__ == "__main__": 129 | main() 130 | -------------------------------------------------------------------------------- /demo_element_hf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates 3 | SPDX-License-Identifier: MIT 4 | """ 5 | 6 | import argparse 7 | import glob 8 | import os 9 | 10 | import torch 11 | from PIL import Image 12 | from transformers import AutoProcessor, VisionEncoderDecoderModel 13 | 14 | from utils.utils import * 15 | 16 | 17 | class DOLPHIN: 18 | def __init__(self, model_id_or_path): 19 | """Initialize the Hugging Face model 20 | 21 | Args: 22 | model_id_or_path: Path to local model or Hugging Face model ID 23 | """ 24 | # Load model from local path or Hugging Face hub 25 | self.processor = AutoProcessor.from_pretrained(model_id_or_path) 26 | self.model = VisionEncoderDecoderModel.from_pretrained(model_id_or_path) 27 | self.model.eval() 28 | 29 | # Set device and precision 30 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 31 | self.model.to(self.device) 32 | self.model = self.model.half() # Always use half precision by default 33 | 34 | # set tokenizer 35 | self.tokenizer = self.processor.tokenizer 36 | 37 | def chat(self, prompt, image): 38 | """Process an image with the given prompt 39 | 40 | Args: 41 | prompt: Text prompt to guide the model 42 | image: PIL Image to process 43 | 44 | Returns: 45 | Generated text from the model 46 | """ 47 | # Prepare image 48 | pixel_values = self.processor(image, return_tensors="pt").pixel_values 49 | pixel_values = pixel_values.half() 50 | 51 | # Prepare prompt 52 | prompt = f"<s>{prompt} <Answer/>" 53 | prompt_ids = self.tokenizer( 54 | prompt, 55 | add_special_tokens=False, 56 | return_tensors="pt" 57 | ).input_ids.to(self.device) 58 | 59 | decoder_attention_mask = torch.ones_like(prompt_ids) 60 | 61 | # Generate text 62 | outputs = self.model.generate( 63 | pixel_values=pixel_values.to(self.device), 64 | decoder_input_ids=prompt_ids, 65 | decoder_attention_mask=decoder_attention_mask, 66 | min_length=1, 67 | max_length=4096, 68 | pad_token_id=self.tokenizer.pad_token_id, 69 | eos_token_id=self.tokenizer.eos_token_id, 70 | use_cache=True, 71 | bad_words_ids=[[self.tokenizer.unk_token_id]], 72 | return_dict_in_generate=True, 73 | do_sample=False, 74 | num_beams=1, 75 | repetition_penalty=1.1, 76 | temperature=1.0 77 | ) 78 | 79 | # Process the output 80 | sequence = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0] 81 | sequence = sequence.replace(prompt, "").replace("<pad>", "").replace("</s>", "").strip() 82 | 83 | return sequence 84 | 85 | def process_element(image_path, model, element_type, save_dir=None): 86 | """Process a single element image (text, table, formula) 87 | 88 | Args: 89 | image_path: Path to the element image 90 | model: HFModel model instance 91 | element_type: Type of element ('text', 'table', 'formula') 92 | save_dir: Directory to save results (default: same as input directory) 93 | 94 | Returns: 95 | Parsed content of the element and recognition results 96 | """ 97 | # Load and prepare image 98 | pil_image = Image.open(image_path).convert("RGB") 99 | pil_image = crop_margin(pil_image) 100 | 101 | # Select appropriate prompt based on element type 102 | if element_type == "table": 103 | prompt = "Parse the table in the image." 104 | label = "tab" 105 | elif element_type == "formula": 106 | prompt = "Read text in the image." 107 | label = "formula" 108 | else: # Default to text 109 | prompt = "Read text in the image." 
110 | label = "text" 111 | 112 | # Process the element 113 | result = model.chat(prompt, pil_image) 114 | 115 | # Create recognition result in the same format as the document parser 116 | recognition_result = [ 117 | { 118 | "label": label, 119 | "text": result.strip(), 120 | } 121 | ] 122 | 123 | # Save results if save_dir is provided 124 | if save_dir: 125 | save_outputs(recognition_result, image_path, save_dir) 126 | print(f"Results saved to {save_dir}") 127 | 128 | return result, recognition_result 129 | 130 | 131 | def main(): 132 | parser = argparse.ArgumentParser(description="Element-level processing using DOLPHIN model") 133 | parser.add_argument("--model_path", default="./hf_model", help="Path to Hugging Face model") 134 | parser.add_argument("--input_path", type=str, required=True, help="Path to input image or directory of images") 135 | parser.add_argument( 136 | "--element_type", 137 | type=str, 138 | choices=["text", "table", "formula"], 139 | default="text", 140 | help="Type of element to process (text, table, formula)", 141 | ) 142 | parser.add_argument( 143 | "--save_dir", 144 | type=str, 145 | default=None, 146 | help="Directory to save parsing results (default: same as input directory)", 147 | ) 148 | parser.add_argument("--print_results", action="store_true", help="Print recognition results to console") 149 | args = parser.parse_args() 150 | 151 | # Load Model 152 | model = DOLPHIN(args.model_path) 153 | 154 | # Set save directory 155 | save_dir = args.save_dir or ( 156 | args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path) 157 | ) 158 | setup_output_dirs(save_dir) 159 | 160 | # Collect Images 161 | if os.path.isdir(args.input_path): 162 | image_files = [] 163 | for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]: 164 | image_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}"))) 165 | image_files = sorted(image_files) 166 | else: 167 | if not os.path.exists(args.input_path): 168 | raise FileNotFoundError(f"Input path {args.input_path} does not exist") 169 | image_files = [args.input_path] 170 | 171 | total_samples = len(image_files) 172 | print(f"\nTotal samples to process: {total_samples}") 173 | 174 | # Process images one by one 175 | for image_path in image_files: 176 | print(f"\nProcessing {image_path}") 177 | try: 178 | result, recognition_result = process_element( 179 | image_path=image_path, 180 | model=model, 181 | element_type=args.element_type, 182 | save_dir=save_dir, 183 | ) 184 | 185 | if args.print_results: 186 | print("\nRecognition result:") 187 | print(result) 188 | print("-" * 40) 189 | except Exception as e: 190 | print(f"Error processing {image_path}: {str(e)}") 191 | continue 192 | 193 | 194 | if __name__ == "__main__": 195 | main() 196 | -------------------------------------------------------------------------------- /demo_page.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates 3 | SPDX-License-Identifier: MIT 4 | """ 5 | 6 | import argparse 7 | import glob 8 | import os 9 | 10 | import cv2 11 | from omegaconf import OmegaConf 12 | from PIL import Image 13 | 14 | from chat import DOLPHIN 15 | from utils.utils import * 16 | 17 | 18 | def process_document(document_path, model, save_dir, max_batch_size): 19 | """Parse documents - Handles both images and PDFs""" 20 | file_ext = os.path.splitext(document_path)[1].lower() 21 | 22 | if file_ext == '.pdf': 23 | # Process PDF file 24 | # Convert PDF to images 25 | images = convert_pdf_to_images(document_path) 26 | if not images: 27 | raise Exception(f"Failed to convert PDF {document_path} to images") 28 | 29 | all_results = [] 30 | 31 | # Process each page 32 | for page_idx, pil_image in enumerate(images): 33 | print(f"Processing page {page_idx + 1}/{len(images)}") 34 | 35 | # Generate output name for this page 36 | base_name = os.path.splitext(os.path.basename(document_path))[0] 37 | page_name = f"{base_name}_page_{page_idx + 1:03d}" 38 | 39 | # Process this page (don't save individual page results) 40 | json_path, recognition_results = process_single_image( 41 | pil_image, model, save_dir, page_name, max_batch_size, save_individual=False 42 | ) 43 | 44 | # Add page information to results 45 | page_results = { 46 | "page_number": page_idx + 1, 47 | "elements": recognition_results 48 | } 49 | all_results.append(page_results) 50 | 51 | # Save combined results for multi-page PDF 52 | combined_json_path = save_combined_pdf_results(all_results, document_path, save_dir) 53 | 54 | return combined_json_path, all_results 55 | 56 | else: 57 | # Process regular image file 58 | pil_image = Image.open(document_path).convert("RGB") 59 | base_name = os.path.splitext(os.path.basename(document_path))[0] 60 | return process_single_image(pil_image, model, save_dir, base_name, max_batch_size) 61 | 62 | 63 | def process_single_image(image, model, save_dir, image_name, max_batch_size, save_individual=True): 64 | """Process a single image (either from file or converted from PDF page) 65 | 66 | Args: 67 | image: PIL Image object 68 | model: DOLPHIN model instance 69 | save_dir: Directory to save results 70 | image_name: Name for the output file 71 | max_batch_size: Maximum batch size for processing 72 | save_individual: Whether to save individual results (False for PDF pages) 73 | 74 | Returns: 75 | Tuple of (json_path, recognition_results) 76 | """ 77 | # Stage 1: Page-level layout and reading order parsing 78 | layout_output = model.chat("Parse the reading order of this document.", image) 79 | 80 | # Stage 2: Element-level content parsing 81 | padded_image, dims = prepare_image(image) 82 | recognition_results = process_elements(layout_output, padded_image, dims, model, max_batch_size, save_dir, image_name) 83 | 84 | # Save outputs only if requested (skip for PDF pages) 85 | json_path = None 86 | if save_individual: 87 | # Create a dummy image path for save_outputs function 88 | dummy_image_path = f"{image_name}.jpg" # Extension doesn't matter, only basename is used 89 | json_path = save_outputs(recognition_results, dummy_image_path, save_dir) 90 | 91 | return json_path, recognition_results 92 | 93 | 94 | def process_elements(layout_results, padded_image, dims, model, max_batch_size, save_dir=None, image_name=None): 95 | """Parse all document elements with parallel decoding""" 96 | layout_results = parse_layout_string(layout_results) 97 | 98 | text_table_elements = [] # Elements that need processing 99 | figure_results = 
[] # Figure elements (no processing needed) 100 | previous_box = None 101 | reading_order = 0 102 | 103 | # Collect elements for processing 104 | for bbox, label in layout_results: 105 | try: 106 | # Adjust coordinates 107 | x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates( 108 | bbox, padded_image, dims, previous_box 109 | ) 110 | 111 | # Crop and parse element 112 | cropped = padded_image[y1:y2, x1:x2] 113 | if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3: 114 | if label == "fig": 115 | pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)) 116 | 117 | figure_filename = save_figure_to_local(pil_crop, save_dir, image_name, reading_order) 118 | 119 | # For figure regions, store relative path instead of base64 120 | figure_results.append( 121 | { 122 | "label": label, 123 | "text": f"", 124 | "figure_path": f"figures/{figure_filename}", 125 | "bbox": [orig_x1, orig_y1, orig_x2, orig_y2], 126 | "reading_order": reading_order, 127 | } 128 | ) 129 | else: 130 | # For text or table regions, prepare for parsing 131 | pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)) 132 | prompt = "Parse the table in the image." if label == "tab" else "Read text in the image." 133 | text_table_elements.append( 134 | { 135 | "crop": pil_crop, 136 | "prompt": prompt, 137 | "label": label, 138 | "bbox": [orig_x1, orig_y1, orig_x2, orig_y2], 139 | "reading_order": reading_order, 140 | } 141 | ) 142 | 143 | reading_order += 1 144 | 145 | except Exception as e: 146 | print(f"Error processing bbox with label {label}: {str(e)}") 147 | continue 148 | 149 | # Parse text/table elements in parallel 150 | recognition_results = figure_results 151 | if text_table_elements: 152 | crops_list = [elem["crop"] for elem in text_table_elements] 153 | prompts_list = [elem["prompt"] for elem in text_table_elements] 154 | 155 | # Inference in batch 156 | batch_results = model.chat(prompts_list, crops_list, max_batch_size=max_batch_size) 157 | 158 | # Add batch results to recognition_results 159 | for i, result in enumerate(batch_results): 160 | elem = text_table_elements[i] 161 | recognition_results.append( 162 | { 163 | "label": elem["label"], 164 | "bbox": elem["bbox"], 165 | "text": result.strip(), 166 | "reading_order": elem["reading_order"], 167 | } 168 | ) 169 | 170 | # Sort elements by reading order 171 | recognition_results.sort(key=lambda x: x.get("reading_order", 0)) 172 | 173 | return recognition_results 174 | 175 | 176 | def main(): 177 | parser = argparse.ArgumentParser(description="Document parsing based on DOLPHIN") 178 | parser.add_argument("--config", default="./config/Dolphin.yaml", help="Path to configuration file") 179 | parser.add_argument("--input_path", type=str, default="./demo", help="Path to input image/PDF or directory of files") 180 | parser.add_argument( 181 | "--save_dir", 182 | type=str, 183 | default=None, 184 | help="Directory to save parsing results (default: same as input directory)", 185 | ) 186 | parser.add_argument( 187 | "--max_batch_size", 188 | type=int, 189 | default=4, 190 | help="Maximum number of document elements to parse in a single batch (default: 4)", 191 | ) 192 | args = parser.parse_args() 193 | 194 | # Load Model 195 | config = OmegaConf.load(args.config) 196 | model = DOLPHIN(config) 197 | 198 | # Collect Document Files (images and PDFs) 199 | if os.path.isdir(args.input_path): 200 | # Support both image and PDF files 201 | file_extensions = [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG", 
".pdf", ".PDF"] 202 | 203 | document_files = [] 204 | for ext in file_extensions: 205 | document_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}"))) 206 | document_files = sorted(document_files) 207 | else: 208 | if not os.path.exists(args.input_path): 209 | raise FileNotFoundError(f"Input path {args.input_path} does not exist") 210 | 211 | # Check if it's a supported file type 212 | file_ext = os.path.splitext(args.input_path)[1].lower() 213 | supported_exts = ['.jpg', '.jpeg', '.png', '.pdf'] 214 | 215 | if file_ext not in supported_exts: 216 | raise ValueError(f"Unsupported file type: {file_ext}. Supported types: {supported_exts}") 217 | 218 | document_files = [args.input_path] 219 | 220 | save_dir = args.save_dir or ( 221 | args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path) 222 | ) 223 | setup_output_dirs(save_dir) 224 | 225 | total_samples = len(document_files) 226 | print(f"\nTotal files to process: {total_samples}") 227 | 228 | # Process All Document Files 229 | for file_path in document_files: 230 | print(f"\nProcessing {file_path}") 231 | try: 232 | json_path, recognition_results = process_document( 233 | document_path=file_path, 234 | model=model, 235 | save_dir=save_dir, 236 | max_batch_size=args.max_batch_size, 237 | ) 238 | 239 | print(f"Processing completed. Results saved to {save_dir}") 240 | 241 | except Exception as e: 242 | print(f"Error processing {file_path}: {str(e)}") 243 | continue 244 | 245 | 246 | if __name__ == "__main__": 247 | main() 248 | -------------------------------------------------------------------------------- /demo_page_hf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 3 | SPDX-License-Identifier: MIT 4 | """ 5 | 6 | import argparse 7 | import glob 8 | import os 9 | 10 | import cv2 11 | import torch 12 | from PIL import Image 13 | from transformers import AutoProcessor, VisionEncoderDecoderModel 14 | 15 | from utils.utils import * 16 | 17 | 18 | class DOLPHIN: 19 | def __init__(self, model_id_or_path): 20 | """Initialize the Hugging Face model 21 | 22 | Args: 23 | model_id_or_path: Path to local model or Hugging Face model ID 24 | """ 25 | # Load model from local path or Hugging Face hub 26 | self.processor = AutoProcessor.from_pretrained(model_id_or_path) 27 | self.model = VisionEncoderDecoderModel.from_pretrained(model_id_or_path) 28 | self.model.eval() 29 | 30 | # Set device and precision 31 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 32 | self.model.to(self.device) 33 | self.model = self.model.half() # Always use half precision by default 34 | 35 | # set tokenizer 36 | self.tokenizer = self.processor.tokenizer 37 | 38 | def chat(self, prompt, image): 39 | """Process an image or batch of images with the given prompt(s) 40 | 41 | Args: 42 | prompt: Text prompt or list of prompts to guide the model 43 | image: PIL Image or list of PIL Images to process 44 | 45 | Returns: 46 | Generated text or list of texts from the model 47 | """ 48 | # Check if we're dealing with a batch 49 | is_batch = isinstance(image, list) 50 | 51 | if not is_batch: 52 | # Single image, wrap it in a list for consistent processing 53 | images = [image] 54 | prompts = [prompt] 55 | else: 56 | # Batch of images 57 | images = image 58 | prompts = prompt if isinstance(prompt, list) else [prompt] * len(images) 59 | 60 | # Prepare image 61 | batch_inputs = self.processor(images, return_tensors="pt", padding=True) 
62 | batch_pixel_values = batch_inputs.pixel_values.half().to(self.device) 63 | 64 | # Prepare prompt 65 | prompts = [f"<s>{p} <Answer/>" for p in prompts] 66 | batch_prompt_inputs = self.tokenizer( 67 | prompts, 68 | add_special_tokens=False, 69 | return_tensors="pt" 70 | ) 71 | 72 | batch_prompt_ids = batch_prompt_inputs.input_ids.to(self.device) 73 | batch_attention_mask = batch_prompt_inputs.attention_mask.to(self.device) 74 | 75 | # Generate text 76 | outputs = self.model.generate( 77 | pixel_values=batch_pixel_values, 78 | decoder_input_ids=batch_prompt_ids, 79 | decoder_attention_mask=batch_attention_mask, 80 | min_length=1, 81 | max_length=4096, 82 | pad_token_id=self.tokenizer.pad_token_id, 83 | eos_token_id=self.tokenizer.eos_token_id, 84 | use_cache=True, 85 | bad_words_ids=[[self.tokenizer.unk_token_id]], 86 | return_dict_in_generate=True, 87 | do_sample=False, 88 | num_beams=1, 89 | repetition_penalty=1.1, 90 | temperature=1.0 91 | ) 92 | 93 | # Process output 94 | sequences = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False) 95 | 96 | # Clean prompt text from output 97 | results = [] 98 | for i, sequence in enumerate(sequences): 99 | cleaned = sequence.replace(prompts[i], "").replace("<pad>", "").replace("</s>", "").strip() 100 | results.append(cleaned) 101 | 102 | # Return a single result for single image input 103 | if not is_batch: 104 | return results[0] 105 | return results 106 | 107 | 108 | def process_document(document_path, model, save_dir, max_batch_size=None): 109 | """Parse documents with two stages - Handles both images and PDFs""" 110 | file_ext = os.path.splitext(document_path)[1].lower() 111 | 112 | if file_ext == '.pdf': 113 | # Process PDF file 114 | # Convert PDF to images 115 | images = convert_pdf_to_images(document_path) 116 | if not images: 117 | raise Exception(f"Failed to convert PDF {document_path} to images") 118 | 119 | all_results = [] 120 | 121 | # Process each page 122 | for page_idx, pil_image in enumerate(images): 123 | print(f"Processing page {page_idx + 1}/{len(images)}") 124 | 125 | # Generate output name for this page 126 | base_name = os.path.splitext(os.path.basename(document_path))[0] 127 | page_name = f"{base_name}_page_{page_idx + 1:03d}" 128 | 129 | # Process this page (don't save individual page results) 130 | json_path, recognition_results = process_single_image( 131 | pil_image, model, save_dir, page_name, max_batch_size, save_individual=False 132 | ) 133 | 134 | # Add page information to results 135 | page_results = { 136 | "page_number": page_idx + 1, 137 | "elements": recognition_results 138 | } 139 | all_results.append(page_results) 140 | 141 | # Save combined results for multi-page PDF 142 | combined_json_path = save_combined_pdf_results(all_results, document_path, save_dir) 143 | 144 | return combined_json_path, all_results 145 | 146 | else: 147 | # Process regular image file 148 | pil_image = Image.open(document_path).convert("RGB") 149 | base_name = os.path.splitext(os.path.basename(document_path))[0] 150 | return process_single_image(pil_image, model, save_dir, base_name, max_batch_size) 151 | 152 | 153 | def process_single_image(image, model, save_dir, image_name, max_batch_size=None, save_individual=True): 154 | """Process a single image (either from file or converted from PDF page) 155 | 156 | Args: 157 | image: PIL Image object 158 | model: DOLPHIN model instance 159 | save_dir: Directory to save results 160 | image_name: Name for the output file 161 | max_batch_size: Maximum batch size for 
processing 162 | save_individual: Whether to save individual results (False for PDF pages) 163 | 164 | Returns: 165 | Tuple of (json_path, recognition_results) 166 | """ 167 | # Stage 1: Page-level layout and reading order parsing 168 | layout_output = model.chat("Parse the reading order of this document.", image) 169 | 170 | # Stage 2: Element-level content parsing 171 | padded_image, dims = prepare_image(image) 172 | recognition_results = process_elements(layout_output, padded_image, dims, model, max_batch_size, save_dir, image_name) 173 | 174 | # Save outputs only if requested (skip for PDF pages) 175 | json_path = None 176 | if save_individual: 177 | # Create a dummy image path for save_outputs function 178 | dummy_image_path = f"{image_name}.jpg" # Extension doesn't matter, only basename is used 179 | json_path = save_outputs(recognition_results, dummy_image_path, save_dir) 180 | 181 | return json_path, recognition_results 182 | 183 | 184 | def process_elements(layout_results, padded_image, dims, model, max_batch_size, save_dir=None, image_name=None): 185 | """Parse all document elements with parallel decoding""" 186 | layout_results = parse_layout_string(layout_results) 187 | 188 | # Store text and table elements separately 189 | text_elements = [] # Text elements 190 | table_elements = [] # Table elements 191 | figure_results = [] # Image elements (no processing needed) 192 | previous_box = None 193 | reading_order = 0 194 | 195 | # Collect elements to process and group by type 196 | for bbox, label in layout_results: 197 | try: 198 | # Adjust coordinates 199 | x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = process_coordinates( 200 | bbox, padded_image, dims, previous_box 201 | ) 202 | 203 | # Crop and parse element 204 | cropped = padded_image[y1:y2, x1:x2] 205 | if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3: 206 | if label == "fig": 207 | pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)) 208 | 209 | figure_filename = save_figure_to_local(pil_crop, save_dir, image_name, reading_order) 210 | 211 | # For figure regions, store relative path instead of base64 212 | figure_results.append( 213 | { 214 | "label": label, 215 | "text": f"", 216 | "figure_path": f"figures/{figure_filename}", 217 | "bbox": [orig_x1, orig_y1, orig_x2, orig_y2], 218 | "reading_order": reading_order, 219 | } 220 | ) 221 | else: 222 | # Prepare element for parsing 223 | pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)) 224 | element_info = { 225 | "crop": pil_crop, 226 | "label": label, 227 | "bbox": [orig_x1, orig_y1, orig_x2, orig_y2], 228 | "reading_order": reading_order, 229 | } 230 | 231 | # Group by type 232 | if label == "tab": 233 | table_elements.append(element_info) 234 | else: # Text elements 235 | text_elements.append(element_info) 236 | 237 | reading_order += 1 238 | 239 | except Exception as e: 240 | print(f"Error processing bbox with label {label}: {str(e)}") 241 | continue 242 | 243 | # Initialize results list 244 | recognition_results = figure_results.copy() 245 | 246 | # Process text elements (in batches) 247 | if text_elements: 248 | text_results = process_element_batch(text_elements, model, "Read text in the image.", max_batch_size) 249 | recognition_results.extend(text_results) 250 | 251 | # Process table elements (in batches) 252 | if table_elements: 253 | table_results = process_element_batch(table_elements, model, "Parse the table in the image.", max_batch_size) 254 | recognition_results.extend(table_results) 
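    # recognition_results now contains the figure placeholders plus the parsed text and table elements.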
255 | 256 | # Sort elements by reading order 257 | recognition_results.sort(key=lambda x: x.get("reading_order", 0)) 258 | 259 | return recognition_results 260 | 261 | 262 | def process_element_batch(elements, model, prompt, max_batch_size=None): 263 | """Process elements of the same type in batches""" 264 | results = [] 265 | 266 | # Determine batch size 267 | batch_size = len(elements) 268 | if max_batch_size is not None and max_batch_size > 0: 269 | batch_size = min(batch_size, max_batch_size) 270 | 271 | # Process in batches 272 | for i in range(0, len(elements), batch_size): 273 | batch_elements = elements[i:i+batch_size] 274 | crops_list = [elem["crop"] for elem in batch_elements] 275 | 276 | # Use the same prompt for all elements in the batch 277 | prompts_list = [prompt] * len(crops_list) 278 | 279 | # Batch inference 280 | batch_results = model.chat(prompts_list, crops_list) 281 | 282 | # Add results 283 | for j, result in enumerate(batch_results): 284 | elem = batch_elements[j] 285 | results.append({ 286 | "label": elem["label"], 287 | "bbox": elem["bbox"], 288 | "text": result.strip(), 289 | "reading_order": elem["reading_order"], 290 | }) 291 | 292 | return results 293 | 294 | 295 | def main(): 296 | parser = argparse.ArgumentParser(description="Document parsing based on DOLPHIN") 297 | parser.add_argument("--model_path", default="./hf_model", help="Path to Hugging Face model") 298 | parser.add_argument("--input_path", type=str, default="./demo", help="Path to input image/PDF or directory of files") 299 | parser.add_argument( 300 | "--save_dir", 301 | type=str, 302 | default=None, 303 | help="Directory to save parsing results (default: same as input directory)", 304 | ) 305 | parser.add_argument( 306 | "--max_batch_size", 307 | type=int, 308 | default=16, 309 | help="Maximum number of document elements to parse in a single batch (default: 16)", 310 | ) 311 | args = parser.parse_args() 312 | 313 | # Load Model 314 | model = DOLPHIN(args.model_path) 315 | 316 | # Collect Document Files (images and PDFs) 317 | if os.path.isdir(args.input_path): 318 | # Support both image and PDF files 319 | file_extensions = [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG", ".pdf", ".PDF"] 320 | 321 | document_files = [] 322 | for ext in file_extensions: 323 | document_files.extend(glob.glob(os.path.join(args.input_path, f"*{ext}"))) 324 | document_files = sorted(document_files) 325 | else: 326 | if not os.path.exists(args.input_path): 327 | raise FileNotFoundError(f"Input path {args.input_path} does not exist") 328 | 329 | # Check if it's a supported file type 330 | file_ext = os.path.splitext(args.input_path)[1].lower() 331 | supported_exts = ['.jpg', '.jpeg', '.png', '.pdf'] 332 | 333 | if file_ext not in supported_exts: 334 | raise ValueError(f"Unsupported file type: {file_ext}. Supported types: {supported_exts}") 335 | 336 | document_files = [args.input_path] 337 | 338 | save_dir = args.save_dir or ( 339 | args.input_path if os.path.isdir(args.input_path) else os.path.dirname(args.input_path) 340 | ) 341 | setup_output_dirs(save_dir) 342 | 343 | total_samples = len(document_files) 344 | print(f"\nTotal files to process: {total_samples}") 345 | 346 | # Process All Document Files 347 | for file_path in document_files: 348 | print(f"\nProcessing {file_path}") 349 | try: 350 | json_path, recognition_results = process_document( 351 | document_path=file_path, 352 | model=model, 353 | save_dir=save_dir, 354 | max_batch_size=args.max_batch_size, 355 | ) 356 | 357 | print(f"Processing completed. 
Results saved to {save_dir}") 358 | 359 | except Exception as e: 360 | print(f"Error processing {file_path}: {str(e)}") 361 | continue 362 | 363 | 364 | if __name__ == "__main__": 365 | main() 366 | -------------------------------------------------------------------------------- /deployment/ReadMe.md: -------------------------------------------------------------------------------- 1 | <h1 align="center"> 2 | 🚀 Dolphin Inference/Serving 3 | </h1> 4 | 5 | ## vLLM 6 | > [Doc](./vllm/ReadMe.md) 7 | 8 | ## TensorRT-LLM 9 | > [Doc](./tensorrt_llm/ReadMe.md) 10 | 11 | ## Others 12 | 13 | -------------------------------------------------------------------------------- /deployment/tensorrt_llm/ReadMe.md: -------------------------------------------------------------------------------- 1 | <h1 align="center"> 2 | 🚀 Dolphin TensorRT-LLM Demo 3 | </h1> 4 | 5 | ## ✅ Introduction 6 | The Dolphin model employs a **Swin Encoder + MBart Decoder** architecture. In the HuggingFace Transformers [Config](https://huggingface.co/ByteDance/Dolphin/blob/main/config.json), 7 | its `architectures` field is specified as "VisionEncoderDecoderModel". **Dolphin**, **[Nougat](https://huggingface.co/docs/transformers/model_doc/nougat)**, and **[Donut](https://huggingface.co/docs/transformers/model_doc/donut)** share the same model architecture. TensorRT-LLM already supports the Nougat model. 8 | Following Nougat's conversion script, we have successfully implemented Dolphin on TensorRT-LLM. 9 | 10 | **Note:** [prompt_ids](./dolphin_runner.py#L120) MUST be of **int32** type, otherwise TensorRT-LLM will produce incorrect results. 11 | 12 | ## 🛠️ Installation 13 | > We have only tested TensorRT-LLM 0.18.1 on Linux. 14 | 15 | https://nvidia.github.io/TensorRT-LLM/0.18.1/installation/linux.html 16 | 17 | 18 | ## ⚡ Offline Inference 19 | ``` 20 | export MODEL_NAME="Dolphin" 21 | 22 | # predict elements reading order 23 | python run_dolphin.py \ 24 | --batch_size 1 \ 25 | --hf_model_dir tmp/hf_models/${MODEL_NAME} \ 26 | --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ 27 | --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \ 28 | --max_new_tokens 4096 \ 29 | --repetition_penalty 1.0 \ 30 | --input_text "Parse the reading order of this document." \ 31 | --image_path "../../demo/page_imgs/page_1.jpeg" 32 | 33 | # recognize text/latex 34 | python run_dolphin.py \ 35 | --batch_size 1 \ 36 | --hf_model_dir tmp/hf_models/${MODEL_NAME} \ 37 | --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ 38 | --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \ 39 | --max_new_tokens 4096 \ 40 | --repetition_penalty 1.0 \ 41 | --input_text "Read text in the image." \ 42 | --image_path "../../demo/element_imgs/block_formula.jpeg" 43 | 44 | 45 | python run_dolphin.py \ 46 | --batch_size 1 \ 47 | --hf_model_dir tmp/hf_models/${MODEL_NAME} \ 48 | --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ 49 | --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \ 50 | --max_new_tokens 4096 \ 51 | --repetition_penalty 1.0 \ 52 | --input_text "Read text in the image."
\ 53 | --image_path "../../demo/element_imgs/para_1.jpg" 54 | 55 | # recognize table 56 | python run_dolphin.py \ 57 | --batch_size 1 \ 58 | --hf_model_dir tmp/hf_models/${MODEL_NAME} \ 59 | --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ 60 | --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \ 61 | --max_new_tokens 4096 \ 62 | --repetition_penalty 1.0 \ 63 | --input_text "Parse the table in the image." \ 64 | --image_path "../../demo/element_imgs/table_1.jpeg" 65 | ``` 66 | 67 | 68 | ## ⚡ Online Inference 69 | ``` 70 | # 1. Start Api Server 71 | export MODEL_NAME="Dolphin" 72 | 73 | python api_server.py \ 74 | --hf_model_dir tmp/hf_models/${MODEL_NAME} \ 75 | --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ 76 | --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \ 77 | --max_batch_size 16 78 | 79 | # 2. Predict 80 | # predict elements reading order 81 | python deployment/tensorrt_llm/api_client.py --image_path ./demo/page_imgs/page_1.jpeg --prompt "Parse the reading order of this document." 82 | 83 | # recognize text/latex 84 | python deployment/tensorrt_llm/api_client.py --image_path ./demo/element_imgs/block_formula.jpeg --prompt "Read text in the image." 85 | python deployment/tensorrt_llm/api_client.py --image_path ./demo/element_imgs/para_1.jpg --prompt "Read text in the image." 86 | 87 | # recognize table 88 | python deployment/tensorrt_llm/api_client.py --image_path ./demo/element_imgs/table_1.jpeg --prompt "Parse the table in the image." 89 | ``` -------------------------------------------------------------------------------- /deployment/tensorrt_llm/api_client.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project 3 | """Example Python client for `vllm.entrypoints.api_server` 4 | Start the demo server: 5 | python -m vllm.entrypoints.api_server --model <model_name> 6 | 7 | NOTE: The API server is used only for demonstration and simple performance 8 | benchmarks. It is not intended for production use. 9 | For production use, we recommend `vllm serve` and the OpenAI client API. 
10 | """ 11 | 12 | import argparse 13 | import base64 14 | import json 15 | from argparse import Namespace 16 | from collections.abc import Iterable 17 | 18 | import requests 19 | 20 | 21 | def clear_line(n: int = 1) -> None: 22 | LINE_UP = "\033[1A" 23 | LINE_CLEAR = "\x1b[2K" 24 | for _ in range(n): 25 | print(LINE_UP, end=LINE_CLEAR, flush=True) 26 | 27 | 28 | def encode_image_base64(image_path: str) -> str: 29 | """Encode local image to base64 format.""" 30 | 31 | with open(image_path, "rb") as f: 32 | image_data = f.read() 33 | result = base64.b64encode(image_data).decode("utf-8") 34 | 35 | return result 36 | 37 | 38 | def post_http_request( 39 | prompt: str, image_path: str, api_url: str, stream: bool = False 40 | ) -> requests.Response: 41 | headers = {"User-Agent": "Test Client"} 42 | pload = { 43 | "prompt": prompt, 44 | "image_base64": encode_image_base64(image_path), 45 | } 46 | response = requests.post(api_url, headers=headers, json=pload, stream=stream) 47 | return response 48 | 49 | 50 | def get_streaming_response(response: requests.Response) -> Iterable[list[str]]: 51 | for chunk in response.iter_lines( 52 | chunk_size=8192, decode_unicode=False, delimiter=b"\n" 53 | ): 54 | if chunk: 55 | data = json.loads(chunk.decode("utf-8")) 56 | output = data["text"] 57 | yield output 58 | 59 | 60 | def get_response(response: requests.Response) -> list[str]: 61 | data = json.loads(response.content) 62 | output = data["text"] 63 | return output 64 | 65 | 66 | def parse_args(): 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument("--host", type=str, default="localhost") 69 | parser.add_argument("--port", type=int, default=8000) 70 | parser.add_argument("--prompt", type=str, default="Parse the reading order of this document.") 71 | parser.add_argument("--image_path", type=str, default="./demo/page_imgs/page_1.jpeg") 72 | parser.add_argument("--stream", action="store_true") 73 | return parser.parse_args() 74 | 75 | 76 | def main(args: Namespace): 77 | prompt = args.prompt 78 | image_path = args.image_path 79 | api_url = f"http://{args.host}:{args.port}/generate" 80 | stream = args.stream 81 | 82 | print(f"Prompt: {prompt!r}\n", flush=True) 83 | response = post_http_request(prompt, image_path, api_url, stream) 84 | 85 | if stream: 86 | num_printed_lines = 0 87 | for h in get_streaming_response(response): 88 | clear_line(num_printed_lines) 89 | num_printed_lines = 0 90 | for i, line in enumerate(h): 91 | num_printed_lines += 1 92 | print(f"Response {i}: {line!r}", flush=True) 93 | else: 94 | output = get_response(response) 95 | print(f"Response: {output!r}", flush=True) 96 | 97 | 98 | if __name__ == "__main__": 99 | args = parse_args() 100 | main(args) 101 | -------------------------------------------------------------------------------- /deployment/tensorrt_llm/api_server.py: -------------------------------------------------------------------------------- 1 | # copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.1/examples/apps/fastapi_server.py 2 | 3 | #!/usr/bin/env python 4 | import asyncio 5 | import base64 6 | import io 7 | import logging 8 | import signal 9 | from http import HTTPStatus 10 | from PIL import Image 11 | from typing import Optional 12 | 13 | import click 14 | import uvicorn 15 | from fastapi import FastAPI, Request 16 | from fastapi.responses import JSONResponse, Response 17 | 18 | from tensorrt_llm.executor import CppExecutorError, RequestError 19 | from dolphin_runner import DolphinRunner, InferenceConfig 20 | 21 | TIMEOUT_KEEP_ALIVE = 5 # 
seconds. 22 | 23 | 24 | async def decode_image(image_base64: str) -> Image.Image: 25 | image_data = base64.b64decode(image_base64) 26 | image = Image.open(io.BytesIO(image_data)) 27 | return image 28 | 29 | 30 | class LlmServer: 31 | def __init__(self, runner: DolphinRunner): 32 | self.runner = runner 33 | self.app = FastAPI() 34 | self.register_routes() 35 | 36 | def register_routes(self): 37 | self.app.add_api_route("/health", self.health, methods=["GET"]) 38 | self.app.add_api_route("/generate", self.generate, methods=["POST"]) 39 | 40 | async def health(self) -> Response: 41 | return Response(status_code=200) 42 | 43 | async def generate(self, request: Request) -> Response: 44 | """ Generate completion for the request. 45 | 46 | The request should be a JSON object with the following fields: 47 | - prompt: the prompt to use for the generation. 48 | - image_base64: the image to use for the generation. 49 | """ 50 | request_dict = await request.json() 51 | 52 | prompt = request_dict.pop("prompt", "") 53 | logging.info(f"request prompt: {prompt}") 54 | image_base64 = request_dict.pop("image_base64", "") 55 | image = await decode_image(image_base64) 56 | 57 | try: 58 | output_texts = self.runner.run([prompt], [image], 4024) 59 | output_texts = [texts[0] for texts in output_texts] 60 | return JSONResponse({"text": output_texts[0]}) 61 | except RequestError as e: 62 | return JSONResponse(content=str(e), 63 | status_code=HTTPStatus.BAD_REQUEST) 64 | except CppExecutorError: 65 | # If internal executor error is raised, shutdown the server 66 | signal.raise_signal(signal.SIGINT) 67 | 68 | async def __call__(self, host, port): 69 | config = uvicorn.Config(self.app, 70 | host=host, 71 | port=port, 72 | log_level="info", 73 | timeout_keep_alive=TIMEOUT_KEEP_ALIVE) 74 | await uvicorn.Server(config).serve() 75 | 76 | 77 | @click.command() 78 | @click.option("--hf_model_dir", type=str, required=True) 79 | @click.option("--visual_engine_dir", type=str, required=True) 80 | @click.option("--llm_engine_dir", type=str, required=True) 81 | @click.option("--max_batch_size", type=int, default=16) 82 | @click.option("--max_new_tokens", type=int, default=4024) 83 | @click.option("--host", type=str, default=None) 84 | @click.option("--port", type=int, default=8000) 85 | def entrypoint(hf_model_dir: str, 86 | visual_engine_dir: str, 87 | llm_engine_dir: str, 88 | max_batch_size: int, 89 | max_new_tokens: int, 90 | host: Optional[str] = None, 91 | port: int = 8000): 92 | host = host or "0.0.0.0" 93 | port = port or 8000 94 | logging.info(f"Starting server at {host}:{port}") 95 | 96 | config = InferenceConfig( 97 | max_new_tokens=max_new_tokens, 98 | batch_size=max_batch_size, 99 | log_level="info", 100 | hf_model_dir=hf_model_dir, 101 | visual_engine_dir=visual_engine_dir, 102 | llm_engine_dir=llm_engine_dir, 103 | ) 104 | 105 | dolphin_runner = DolphinRunner(config) 106 | server = LlmServer(runner=dolphin_runner) 107 | 108 | asyncio.run(server(host, port)) 109 | 110 | 111 | if __name__ == "__main__": 112 | entrypoint() -------------------------------------------------------------------------------- /deployment/tensorrt_llm/convert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/Dolphin/169a1a70294790b5ee039486589c9c97b6f89d58/deployment/tensorrt_llm/convert/__init__.py -------------------------------------------------------------------------------- /deployment/tensorrt_llm/convert/build_visual_engine.py: 
-------------------------------------------------------------------------------- 1 | # copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.2/examples/multimodal/build_visual_engine.py 2 | 3 | import argparse 4 | 5 | from tensorrt_llm.tools.multimodal_builder import (VisionEngineBuilder, 6 | add_multimodal_arguments) 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser() 10 | parser = add_multimodal_arguments(parser) 11 | args = parser.parse_args() 12 | 13 | builder = VisionEngineBuilder(args) 14 | builder.build() 15 | -------------------------------------------------------------------------------- /deployment/tensorrt_llm/convert/helper.py: -------------------------------------------------------------------------------- 1 | # copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.1/examples/enc_dec/helper.py 2 | 3 | import typing 4 | from typing import Union 5 | 6 | import numpy as np 7 | import torch # pytype: disable=import-error 8 | 9 | from tensorrt_llm._utils import str_dtype_to_torch 10 | 11 | 12 | def split(v: Union[np.ndarray, torch.Tensor], 13 | tp_size: int, 14 | tp_rank: int, 15 | dim=0): 16 | if tp_size == 1: 17 | if isinstance(v, np.ndarray): 18 | return np.ascontiguousarray(v.copy()) 19 | else: 20 | return v.clone().detach() 21 | assert len(v.shape) > 1 or dim == 0 22 | if isinstance(v, np.ndarray): 23 | return np.ascontiguousarray( 24 | np.split(v, tp_size, axis=dim)[tp_rank].copy()) 25 | else: 26 | assert v.shape[dim] % tp_size == 0, \ 27 | 'Unable to split: shape={v.shape} (dim={dim}) tp_size={tp_size}.' 28 | split_size = v.shape[dim] // tp_size 29 | return v.split(split_size, dim=dim)[tp_rank].clone().detach() 30 | 31 | 32 | def reshape(v: torch.Tensor, shape=None): 33 | if shape is None: 34 | return v.contiguous() 35 | else: 36 | return v.reshape(shape).contiguous() 37 | 38 | 39 | def fuse_qkv_one_layer(params, attn_module_name, trtllm_layer_name, tp_size, 40 | tp_rank, model_type, weight_shape, bias_shape): 41 | 42 | qkv_module_names = get_qkv_module_name(model_type) 43 | 44 | weight = {} 45 | 46 | # fuse weights of q, k, v 47 | q_w = params[f'{attn_module_name}.{qkv_module_names["q"]}.weight'] 48 | k_w = params[f'{attn_module_name}.{qkv_module_names["k"]}.weight'] 49 | v_w = params[f'{attn_module_name}.{qkv_module_names["v"]}.weight'] 50 | 51 | # fuse qkv weight 52 | shape = q_w.shape # (do, din) 53 | qkv_w = torch.cat([q_w, k_w, v_w], 54 | dim=0).reshape([3, shape[0], shape[1]]) # (3, do, din) 55 | qkv_w = split(qkv_w, tp_size, tp_rank, dim=1) 56 | weight[f'{trtllm_layer_name}.qkv.weight'] = reshape(qkv_w, 57 | shape=weight_shape) 58 | 59 | # fuse qkv biases if present 60 | if f'{attn_module_name}.{qkv_module_names["q"]}.bias' in params.keys( 61 | ) and params[f'{attn_module_name}.{qkv_module_names["q"]}.bias'] is not None: 62 | q_b = params[f'{attn_module_name}.{qkv_module_names["q"]}.bias'] 63 | k_b = params[f'{attn_module_name}.{qkv_module_names["k"]}.bias'] 64 | v_b = params[f'{attn_module_name}.{qkv_module_names["v"]}.bias'] 65 | shape = q_b.shape[0] # (do,) 66 | qkv_b = torch.cat([q_b, k_b, v_b], dim=0).reshape([3, shape]) # (3, do) 67 | qkv_b = split(qkv_b, tp_size, tp_rank, dim=1) 68 | weight[f'{trtllm_layer_name}.qkv.bias'] = reshape(qkv_b, 69 | shape=bias_shape) 70 | return weight 71 | 72 | 73 | def get_qkv_module_name(model_type): 74 | if model_type in ["t5", "blip2"]: 75 | q = "q" 76 | k = "k" 77 | v = "v" 78 | elif model_type == "bart" or model_type == "nmt": 79 | q = "q_proj" 80 | k = "k_proj" 81 | v = "v_proj" 82 | 
elif model_type == "pix2struct": 83 | q = "query" 84 | k = "key" 85 | v = "value" 86 | return {"q": q, "k": k, "v": v} 87 | 88 | 89 | def convert_weight_to_dtype(params: typing.Dict[str, torch.Tensor], 90 | dtype: typing.Optional[np.dtype] = None): 91 | if dtype is not None: 92 | assert isinstance(dtype, 93 | str), f"dtype must be str, but get type {type(dtype)}" 94 | for name in params.keys(): 95 | params[name] = params[name].to(str_dtype_to_torch(dtype)) 96 | -------------------------------------------------------------------------------- /deployment/tensorrt_llm/convert_dolphin.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -ex 3 | 4 | ############################################################################################ 5 | # Reference: https://github.com/NVIDIA/TensorRT-LLM/tree/v0.18.2/examples/multimodal#nougat 6 | ############################################################################################ 7 | 8 | export LD_LIBRARY_PATH=/usr/local/lib/python3.10/site-packages/tensorrt_libs/:/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:$LD_LIBRARY_PATH 9 | 10 | # 1. Download Huggingface weights 11 | export MODEL_NAME="Dolphin" 12 | git clone https://huggingface.co/Bytedance/${MODEL_NAME} tmp/hf_models/${MODEL_NAME} 13 | 14 | 15 | export MAX_BATCH_SIZE=16 16 | export MAX_SEQ_LEN=4096 17 | export MAX_INPUT_LEN=10 18 | export MAX_ENCODER_INPUT_LEN=784 19 | 20 | # 2. Convert Huggingface weights into TRT-LLM checkpoints and build TRT engines using scripts in examples/enc_dec 21 | python ./convert/convert_checkpoint.py --model_type bart \ 22 | --model_dir tmp/hf_models/${MODEL_NAME} \ 23 | --output_dir tmp/trt_models/${MODEL_NAME}/bfloat16 \ 24 | --tp_size 1 \ 25 | --pp_size 1 \ 26 | --dtype bfloat16 \ 27 | --nougat 28 | 29 | 30 | trtllm-build --checkpoint_dir tmp/trt_models/${MODEL_NAME}/bfloat16/decoder \ 31 | --output_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16/decoder \ 32 | --paged_kv_cache disable \ 33 | --moe_plugin disable \ 34 | --gemm_plugin bfloat16 \ 35 | --bert_attention_plugin bfloat16 \ 36 | --gpt_attention_plugin bfloat16 \ 37 | --remove_input_padding enable \ 38 | --max_beam_width 1 \ 39 | --max_batch_size ${MAX_BATCH_SIZE} \ 40 | --max_seq_len ${MAX_SEQ_LEN} \ 41 | --max_input_len ${MAX_INPUT_LEN} \ 42 | --max_encoder_input_len $((${MAX_BATCH_SIZE} * ${MAX_ENCODER_INPUT_LEN})) # MAX_BATCH_SIZE (max_batch_size) * MAX_ENCODER_INPUT_LEN (num_visual_features) 43 | 44 | # 3. Generate TensorRT engines for visual components and combine everything into final pipeline. 45 | python ./convert/build_visual_engine.py --model_type nougat \ 46 | --model_path tmp/hf_models/${MODEL_NAME} \ 47 | --max_batch_size ${MAX_BATCH_SIZE} -------------------------------------------------------------------------------- /deployment/tensorrt_llm/dolphin_runner.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates 3 | SPDX-License-Identifier: MIT 4 | """ 5 | 6 | import json 7 | import os 8 | from typing import Optional 9 | 10 | import tensorrt_llm 11 | import tensorrt_llm.profiler as profiler 12 | import torch 13 | from PIL import Image 14 | from pydantic import BaseModel, Field 15 | from tensorrt_llm import logger 16 | from tensorrt_llm import mpi_rank 17 | from tensorrt_llm.runtime import MultimodalModelRunner 18 | from transformers import AutoTokenizer, DonutProcessor 19 | 20 | 21 | class InferenceConfig(BaseModel): 22 | max_new_tokens: int = Field(128, description="Maximum new tokens to generate") 23 | batch_size: int = Field(1, description="Batch size for inference") 24 | log_level: str = Field("info", description="Logging level") 25 | visual_engine_dir: Optional[str] = Field(None, description="Directory for visual engine files") 26 | visual_engine_name: str = Field("model.engine", description="Visual engine filename") 27 | llm_engine_dir: Optional[str] = Field(None, description="Directory for LLM engine files") 28 | hf_model_dir: Optional[str] = Field(None, description="Hugging Face model directory") 29 | input_text: Optional[str] = Field(None, description="Input text for inference") 30 | num_beams: int = Field(1, description="Number of beams for beam search") 31 | top_k: int = Field(1, description="Top-k sampling value") 32 | top_p: float = Field(0.0, description="Top-p (nucleus) sampling value") 33 | temperature: float = Field(1.0, description="Sampling temperature") 34 | repetition_penalty: float = Field(1.0, description="Repetition penalty factor") 35 | run_profiling: bool = Field(False, description="Enable profiling mode") 36 | profiling_iterations: int = Field(20, description="Number of profiling iterations") 37 | check_accuracy: bool = Field(False, description="Enable accuracy checking") 38 | video_path: Optional[str] = Field(None, description="Path to input video file") 39 | video_num_frames: Optional[int] = Field(None, description="Number of video frames to process") 40 | image_path: Optional[str] = Field(None, description="Path to input image file") 41 | path_sep: str = Field(",", description="Path separator character") 42 | prompt_sep: str = Field(",", description="Prompt separator character") 43 | enable_context_fmha_fp32_acc: Optional[bool] = Field( 44 | None, 45 | description="Enable FP32 accumulation for context FMHA" 46 | ) 47 | enable_chunked_context: bool = Field(False, description="Enable chunked context processing") 48 | use_py_session: bool = Field(False, description="Use Python session instead of C++") 49 | kv_cache_free_gpu_memory_fraction: float = Field( 50 | 0.9, 51 | description="Fraction of GPU memory free for KV cache", 52 | ge=0.0, le=1.0 53 | ) 54 | cross_kv_cache_fraction: float = Field( 55 | 0.5, 56 | description="Fraction of cross-attention KV cache", 57 | ge=0.0, le=1.0 58 | ) 59 | multi_block_mode: bool = Field(True, description="Enable multi-block processing mode") 60 | 61 | 62 | class DolphinRunner(MultimodalModelRunner): 63 | def __init__(self, args): 64 | self.args = args 65 | 66 | self.runtime_rank = mpi_rank() 67 | device_id = self.runtime_rank % torch.cuda.device_count() 68 | torch.cuda.set_device(device_id) 69 | self.device = "cuda:%d" % (device_id) 70 | 71 | self.stream = torch.cuda.Stream(torch.cuda.current_device()) 72 | torch.cuda.set_stream(self.stream) 73 | 74 | # parse model type from visual engine config 75 | with open(os.path.join(self.args.visual_engine_dir, "config.json"), 76 | "r") as f: 77 | config = json.load(f) 78 
| self.model_type = config['builder_config']['model_type'] 79 | self.vision_precision = config['builder_config']['precision'] 80 | self.decoder_llm = not ( 81 | 't5' in self.model_type 82 | or self.model_type in ['nougat', 'pix2struct'] 83 | ) # BLIP2-T5, pix2struct and Nougat are using encoder-decoder models as LLMs 84 | 85 | if self.model_type == "mllama": 86 | self.vision_input_names = [ 87 | "pixel_values", 88 | "aspect_ratio_ids", 89 | "aspect_ratio_mask", 90 | ] 91 | self.vision_output_names = [ 92 | "output", 93 | ] 94 | else: 95 | self.vision_input_names = ["input"] 96 | self.vision_output_names = ["output"] 97 | 98 | self.use_py_session = True 99 | 100 | self.init_image_encoder() 101 | self.init_tokenizer() 102 | self.init_processor() 103 | self.init_llm() 104 | 105 | def init_tokenizer(self): 106 | assert self.model_type == 'nougat' 107 | self.tokenizer = AutoTokenizer.from_pretrained(self.args.hf_model_dir) 108 | self.tokenizer.padding_side = "right" 109 | 110 | def init_processor(self): 111 | assert self.model_type == 'nougat' 112 | self.processor = DonutProcessor.from_pretrained(self.args.hf_model_dir, use_fast=True) 113 | 114 | def run(self, input_texts, input_images, max_new_tokens): 115 | prompts = [f"<s>{text.strip()} <Answer/>" for text in input_texts] 116 | images = self.processor(input_images, return_tensors="pt")['pixel_values'].to("cuda") 117 | prompt_ids = self.tokenizer(prompts, add_special_tokens=False, return_tensors="pt").input_ids.to("cuda") 118 | 119 | # 🚨🚨🚨 Important! If the type of prompt_ids is not int32, the output will be wrong. 🚨🚨🚨 120 | prompt_ids = prompt_ids.to(torch.int32) 121 | 122 | logger.info("---------------------------------------------------------") 123 | logger.info(f"images size: {images.size()}") 124 | logger.info(f"prompt_ids: {prompt_ids}, size: {prompt_ids.size()}, dtype: {prompt_ids.dtype}") 125 | logger.info("---------------------------------------------------------") 126 | 127 | output_texts = self.generate(input_texts, 128 | [None] * len(input_texts), 129 | images, 130 | prompt_ids, 131 | max_new_tokens, 132 | warmup=False, 133 | ) 134 | 135 | return output_texts 136 | 137 | def generate(self, 138 | pre_prompt, 139 | post_prompt, 140 | image, 141 | decoder_input_ids, 142 | max_new_tokens, 143 | warmup=False, 144 | other_vision_inputs={}, 145 | other_decoder_inputs={}): 146 | if not warmup: 147 | profiler.start("Generate") 148 | input_ids, input_lengths, ptuning_args, visual_features = self.preprocess( 149 | warmup, pre_prompt, post_prompt, image, other_vision_inputs) 150 | 151 | if warmup: return None 152 | 153 | # use prompt tuning to pass multimodal features 154 | # model.generate() expects the following params (see layers/embedding.py): 155 | # args[0]: prompt embedding table, [batch_size, multimodal_len, hidden_size], later flattened to [batch_size * multimodal_len, hidden_size] 156 | # args[1]: prompt task ids, [batch_size]. in multimodal case, arange(batch_size), i.e. in VILA batching mode 2, each image is treated separately in the batch instead of concated together (although the prompt embedding table has to be concated) 157 | # args[2]: prompt task vocab size, [1]. 
assuming all table has the same length, which in multimodal case equals to multimodal_len 158 | profiler.start("LLM") 159 | if self.model_type in ['nougat', 'pix2struct']: 160 | # Trim encoder input_ids to match visual features shape 161 | ids_shape = (min(self.args.batch_size, len(pre_prompt)), visual_features.shape[1]) 162 | if self.model_type == 'nougat': 163 | input_ids = torch.zeros(ids_shape, dtype=torch.int32) 164 | elif self.model_type == 'pix2struct': 165 | input_ids = torch.ones(ids_shape, dtype=torch.int32) 166 | 167 | output_ids = self.model.generate( 168 | input_ids, 169 | decoder_input_ids, 170 | max_new_tokens, 171 | num_beams=self.args.num_beams, 172 | bos_token_id=self.tokenizer.bos_token_id, 173 | pad_token_id=self.tokenizer.pad_token_id, 174 | eos_token_id=self.tokenizer.eos_token_id, 175 | debug_mode=False, 176 | prompt_embedding_table=ptuning_args[0], 177 | prompt_tasks=ptuning_args[1], 178 | prompt_vocab_size=ptuning_args[2], 179 | ) 180 | profiler.stop("LLM") 181 | 182 | if mpi_rank() == 0: 183 | # Extract a list of tensors of shape beam_width x output_ids. 184 | output_beams_list = [ 185 | self.tokenizer.batch_decode( 186 | output_ids[batch_idx, :, decoder_input_ids.shape[1]:], 187 | skip_special_tokens=False) for batch_idx in range( 188 | min(self.args.batch_size, decoder_input_ids.shape[0])) 189 | ] 190 | 191 | stripped_text = [[ 192 | output_beams_list[batch_idx][beam_idx].replace("</s>", "").replace("<pad>", "").strip() 193 | for beam_idx in range(self.args.num_beams) 194 | ] for batch_idx in range( 195 | min(self.args.batch_size, decoder_input_ids.shape[0]))] 196 | profiler.stop("Generate") 197 | return stripped_text 198 | else: 199 | profiler.stop("Generate") 200 | return None 201 | 202 | 203 | if __name__ == "__main__": 204 | config = InferenceConfig( 205 | max_new_tokens=4024, 206 | batch_size=16, 207 | log_level="info", 208 | hf_model_dir=f"./tmp/hf_models/Dolphin", 209 | visual_engine_dir=f"./tmp/trt_engines/Dolphin/vision_encoder", 210 | llm_engine_dir=f"./tmp/trt_engines/Dolphin/1-gpu/bfloat16", 211 | ) 212 | 213 | model = DolphinRunner(config) 214 | 215 | image_path = "../../demo/page_imgs/page_1.jpeg" 216 | prompt = "Parse the reading order of this document." 217 | image = Image.open(image_path).convert("RGB") 218 | output_texts = model.run([prompt], [image], 4024) 219 | output_texts = [texts[0] for texts in output_texts] 220 | print(output_texts) 221 | -------------------------------------------------------------------------------- /deployment/tensorrt_llm/run_dolphin.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates 3 | SPDX-License-Identifier: MIT 4 | """ 5 | 6 | import argparse 7 | import os 8 | 9 | import tensorrt_llm 10 | import tensorrt_llm.profiler as profiler 11 | from PIL import Image 12 | from tensorrt_llm import logger 13 | from tensorrt_llm import mpi_rank 14 | from tensorrt_llm.runtime import MultimodalModelRunner 15 | 16 | from dolphin_runner import DolphinRunner 17 | from utils import add_common_args 18 | 19 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 20 | 21 | 22 | def print_result(model, input_text, output_text, args): 23 | logger.info("---------------------------------------------------------") 24 | logger.info(f"\n[Q] {input_text}") 25 | for i in range(len(output_text)): 26 | logger.info(f"\n[A]: {output_text[i]}") 27 | 28 | if args.num_beams == 1: 29 | output_ids = model.tokenizer(output_text[0][0], 30 | add_special_tokens=False)['input_ids'] 31 | logger.info(f"Generated {len(output_ids)} tokens") 32 | 33 | if args.check_accuracy: 34 | if model.model_type != 'nougat': 35 | if model.model_type == "vila": 36 | for i in range(len(args.image_path.split(args.path_sep))): 37 | if i % 2 == 0: 38 | assert output_text[i][0].lower( 39 | ) == "the image captures a bustling city intersection teeming with life. from the perspective of a car's dashboard camera, we see" 40 | else: 41 | assert output_text[i][0].lower( 42 | ) == "the image captures the iconic merlion statue in singapore, a renowned worldwide landmark. the merlion, a mythical" 43 | elif model.model_type == "llava": 44 | for i in range(len(args.image_path.split(args.path_sep))): 45 | assert output_text[i][0].lower() == 'singapore' 46 | elif model.model_type == 'fuyu': 47 | assert output_text[0][0].lower() == '4' 48 | elif model.model_type == "pix2struct": 49 | assert "characteristic | cat food, day | cat food, wet | cat treats" in output_text[ 50 | 0][0].lower() 51 | elif model.model_type in [ 52 | 'blip2', 'neva', 'phi-3-vision', 'llava_next' 53 | ]: 54 | assert 'singapore' in output_text[0][0].lower() 55 | elif model.model_type == 'video-neva': 56 | assert 'robot' in output_text[0][0].lower() 57 | elif model.model_type == 'kosmos-2': 58 | assert 'snowman' in output_text[0][0].lower() 59 | elif model.model_type == "mllama": 60 | if "If I had to write a haiku for this one" in input_text: 61 | assert "it would be:.\\nPeter Rabbit is a rabbit.\\nHe lives in a" in output_text[ 62 | 0][0] or "Here is a haiku for the image:\n\n" in output_text[ 63 | 0][0], f"expected results: 'it would be:.\\nPeter Rabbit is a rabbit.\\nHe lives in a', generated results: '{output_text[0][0]}'" 64 | elif "The key to life is" in input_text: 65 | assert "to find your passion and pursue it with all your heart." 
in output_text[ 66 | 0][0] or "not to be found in the external world," in output_text[ 67 | 0][0], f"expected results: 'to find your passion and pursue it with all your heart.', generated results: '{output_text[0][0]}'" 68 | elif model.model_type == 'llava_onevision': 69 | if args.video_path is None: 70 | assert 'singapore' in output_text[0][0].lower() 71 | else: 72 | assert 'the video is funny because the child\'s actions are' in output_text[ 73 | 0][0].lower() 74 | elif model.model_type == "qwen2_vl": 75 | assert 'dog' in output_text[0][0].lower() 76 | else: 77 | assert output_text[0][0].lower() == 'singapore' 78 | 79 | if args.run_profiling: 80 | msec_per_batch = lambda name: 1000 * profiler.elapsed_time_in_sec( 81 | name) / args.profiling_iterations 82 | logger.info('Latencies per batch (msec)') 83 | logger.info('TRT vision encoder: %.1f' % (msec_per_batch('Vision'))) 84 | logger.info('TRTLLM LLM generate: %.1f' % (msec_per_batch('LLM'))) 85 | logger.info('Multimodal generate: %.1f' % (msec_per_batch('Generate'))) 86 | 87 | logger.info("---------------------------------------------------------") 88 | 89 | 90 | if __name__ == '__main__': 91 | parser = argparse.ArgumentParser() 92 | parser = add_common_args(parser) 93 | args = parser.parse_args() 94 | logger.set_level(args.log_level) 95 | 96 | model = DolphinRunner(args) 97 | 98 | input_image = Image.open(args.image_path[0]).convert('RGB') 99 | num_iters = args.profiling_iterations if args.run_profiling else 1 100 | 101 | for _ in range(num_iters): 102 | output_texts = model.run(args.input_text, [input_image], args.max_new_tokens) 103 | 104 | runtime_rank = tensorrt_llm.mpi_rank() 105 | if runtime_rank == 0: 106 | print_result(model, args.input_text, output_texts, args) 107 | -------------------------------------------------------------------------------- /deployment/tensorrt_llm/run_dolphin.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -ex 3 | 4 | export MODEL_NAME="Dolphin" 5 | 6 | python run_dolphin.py \ 7 | --batch_size 1 \ 8 | --hf_model_dir tmp/hf_models/${MODEL_NAME} \ 9 | --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ 10 | --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \ 11 | --max_new_tokens 4096 \ 12 | --repetition_penalty 1.0 \ 13 | --input_text "Parse the reading order of this document." \ 14 | --image_path "../../demo/page_imgs/page_1.jpeg" 15 | 16 | 17 | python run_dolphin.py \ 18 | --batch_size 1 \ 19 | --hf_model_dir tmp/hf_models/${MODEL_NAME} \ 20 | --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ 21 | --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \ 22 | --max_new_tokens 4096 \ 23 | --repetition_penalty 1.0 \ 24 | --input_text "Read text in the image." \ 25 | --image_path "../../demo/element_imgs/block_formula.jpeg" 26 | 27 | 28 | python run_dolphin.py \ 29 | --batch_size 1 \ 30 | --hf_model_dir tmp/hf_models/${MODEL_NAME} \ 31 | --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ 32 | --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \ 33 | --max_new_tokens 4096 \ 34 | --repetition_penalty 1.0 \ 35 | --input_text "Read text in the image." 
\ 36 | --image_path "../../demo/element_imgs/para_1.jpg" 37 | 38 | 39 | python run_dolphin.py \ 40 | --batch_size 1 \ 41 | --hf_model_dir tmp/hf_models/${MODEL_NAME} \ 42 | --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ 43 | --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \ 44 | --max_new_tokens 4096 \ 45 | --repetition_penalty 1.0 \ 46 | --input_text "Parse the table in the image." \ 47 | --image_path "../../demo/element_imgs/table_1.jpeg" 48 | -------------------------------------------------------------------------------- /deployment/tensorrt_llm/start_dolphin_server.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -ex 3 | 4 | export MODEL_NAME="Dolphin" 5 | 6 | python api_server.py \ 7 | --hf_model_dir tmp/hf_models/${MODEL_NAME} \ 8 | --visual_engine_dir tmp/trt_engines/${MODEL_NAME}/vision_encoder \ 9 | --llm_engine_dir tmp/trt_engines/${MODEL_NAME}/1-gpu/bfloat16 \ 10 | --max_batch_size 16 -------------------------------------------------------------------------------- /deployment/tensorrt_llm/utils.py: -------------------------------------------------------------------------------- 1 | def add_common_args(parser): 2 | parser.add_argument('--max_new_tokens', type=int, default=128) 3 | parser.add_argument('--batch_size', type=int, default=1) 4 | parser.add_argument('--log_level', type=str, default='info') 5 | parser.add_argument('--visual_engine_dir', 6 | type=str, 7 | default=None, 8 | help='Directory containing visual TRT engines') 9 | parser.add_argument('--visual_engine_name', 10 | type=str, 11 | default='model.engine', 12 | help='Name of visual TRT engine') 13 | parser.add_argument('--llm_engine_dir', 14 | type=str, 15 | default=None, 16 | help='Directory containing TRT-LLM engines') 17 | parser.add_argument('--hf_model_dir', 18 | type=str, 19 | default=None, 20 | help="Directory containing tokenizer") 21 | parser.add_argument('--input_text', 22 | type=str, 23 | nargs='+', 24 | default=None, 25 | help='Text prompt to LLM') 26 | parser.add_argument('--num_beams', 27 | type=int, 28 | help="Use beam search if num_beams >1", 29 | default=1) 30 | parser.add_argument('--top_k', type=int, default=1) 31 | parser.add_argument('--top_p', type=float, default=0.0) 32 | parser.add_argument('--temperature', type=float, default=1.0) 33 | parser.add_argument('--repetition_penalty', type=float, default=1.0) 34 | parser.add_argument('--run_profiling', 35 | action='store_true', 36 | help='Profile runtime over several iterations') 37 | parser.add_argument('--profiling_iterations', 38 | type=int, 39 | help="Number of iterations to run profiling", 40 | default=20) 41 | parser.add_argument('--check_accuracy', 42 | action='store_true', 43 | help='Check correctness of text output') 44 | parser.add_argument("--image_path", 45 | type=str, 46 | nargs='+', 47 | default=None, 48 | help='List of input image paths, separated by symbol') 49 | parser.add_argument("--path_sep", 50 | type=str, 51 | default=",", 52 | help='Path separator symbol') 53 | parser.add_argument("--prompt_sep", 54 | type=str, 55 | default=",", 56 | help="Prompt separator symbol") 57 | parser.add_argument('--enable_context_fmha_fp32_acc', 58 | action='store_true', 59 | default=None, 60 | help="Enable FMHA runner FP32 accumulation.") 61 | parser.add_argument( 62 | '--enable_chunked_context', 63 | action='store_true', 64 | help='Enables chunked context (only available with cpp session).', 65 | ) 66 | parser.add_argument( 67 | 
'--use_py_session', 68 | default=False, 69 | action='store_true', 70 | help= 71 | "Whether or not to use Python runtime session. By default C++ runtime session is used for the LLM." 72 | ) 73 | parser.add_argument( 74 | '--kv_cache_free_gpu_memory_fraction', 75 | default=0.9, 76 | type=float, 77 | help='Specify the free gpu memory fraction.', 78 | ) 79 | parser.add_argument( 80 | '--cross_kv_cache_fraction', 81 | default=0.5, 82 | type=float, 83 | help= 84 | 'Specify the kv cache fraction reserved for cross attention. Only applicable for encoder-decoder models. By default 0.5 for self and 0.5 for cross.', 85 | ) 86 | parser.add_argument( 87 | '--multi_block_mode', 88 | type=lambda s: s.lower() in 89 | ("yes", "true", "t", "1" 90 | ), # custom boolean function to convert input string to boolean 91 | default=True, 92 | help= 93 | "Distribute the work across multiple CUDA thread-blocks on the GPU for masked MHA kernel." 94 | ) 95 | return parser 96 | -------------------------------------------------------------------------------- /deployment/vllm/ReadMe.md: -------------------------------------------------------------------------------- 1 | <h1 align="center"> 2 | 🚀 Dolphin vLLM Demo 3 | </h1> 4 | 5 | ## ✅ Introduction 6 | The Dolphin model employs a **Swin Encoder + MBart Decoder** architecture. In the HuggingFace Transformers [Config](https://huggingface.co/ByteDance/Dolphin/blob/main/config.json), 7 | its `architectures` field is specified as "VisionEncoderDecoderModel". vLLM does not natively support this architecture. 8 | To enable vLLM deployment of the Dolphin model, we implemented two vLLM plugins: [vllm-dolphin](https://github.com/hanyd2010/vllm-dolphin)[](https://pypi.org/project/vllm-dolphin/) and [vllm-mbart](https://github.com/hanyd2010/vllm-mbart)[](https://pypi.org/project/vllm-mbart/). 9 | We also provide Dolphin vLLM demos for both offline inference and online deployment. 10 | 11 | ## 🛠️ Installation 12 | 13 | ``` 14 | # Install vLLM 15 | pip install "vllm>=0.9.0" 16 | 17 | # Install vllm-dolphin 18 | pip install vllm-dolphin==0.1 19 | ``` 20 | 21 | ## ⚡ Offline Inference 22 | ``` 23 | # predict elements reading order 24 | python deployment/vllm/demo_vllm.py --model ByteDance/Dolphin --image_path ./demo/page_imgs/page_1.jpeg --prompt "Parse the reading order of this document." 25 | 26 | # recognize text/latex 27 | python deployment/vllm/demo_vllm.py --model ByteDance/Dolphin --image_path ./demo/element_imgs/block_formula.jpeg --prompt "Read text in the image." 28 | python deployment/vllm/demo_vllm.py --model ByteDance/Dolphin --image_path ./demo/element_imgs/para_1.jpg --prompt "Read text in the image." 29 | 30 | # recognize table 31 | python deployment/vllm/demo_vllm.py --model ByteDance/Dolphin --image_path ./demo/element_imgs/table_1.jpeg --prompt "Parse the table in the image." 32 | ``` 33 | 34 | 35 | ## ⚡ Online Inference 36 | ``` 37 | # 1. Start Api Server 38 | python deployment/vllm/api_server.py --model="ByteDance/Dolphin" --hf-overrides "{\"architectures\": [\"DolphinForConditionalGeneration\"]}" 39 | 40 | # 2. Predict 41 | # predict elements reading order 42 | python deployment/vllm/api_client.py --image_path ./demo/page_imgs/page_1.jpeg --prompt "Parse the reading order of this document." 43 | 44 | # recognize text/latex 45 | python deployment/vllm/api_client.py --image_path ./demo/element_imgs/block_formula.jpeg --prompt "Read text in the image." 46 | python deployment/vllm/api_client.py --image_path ./demo/element_imgs/para_1.jpg --prompt "Read text in the image."
47 | 48 | # recognize table 49 | python deployment/vllm/api_client.py --image_path ./demo/element_imgs/table_1.jpeg --prompt "Parse the table in the image." 50 | ``` -------------------------------------------------------------------------------- /deployment/vllm/api_client.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project 3 | """Example Python client for `vllm.entrypoints.api_server` 4 | Start the demo server: 5 | python -m vllm.entrypoints.api_server --model <model_name> 6 | 7 | NOTE: The API server is used only for demonstration and simple performance 8 | benchmarks. It is not intended for production use. 9 | For production use, we recommend `vllm serve` and the OpenAI client API. 10 | """ 11 | 12 | import argparse 13 | import base64 14 | import json 15 | from argparse import Namespace 16 | from collections.abc import Iterable 17 | 18 | import requests 19 | 20 | 21 | def clear_line(n: int = 1) -> None: 22 | LINE_UP = "\033[1A" 23 | LINE_CLEAR = "\x1b[2K" 24 | for _ in range(n): 25 | print(LINE_UP, end=LINE_CLEAR, flush=True) 26 | 27 | 28 | def encode_image_base64(image_path: str) -> str: 29 | """Encode local image to base64 format.""" 30 | 31 | with open(image_path, "rb") as f: 32 | image_data = f.read() 33 | result = base64.b64encode(image_data).decode("utf-8") 34 | 35 | return result 36 | 37 | 38 | def post_http_request( 39 | prompt: str, image_path: str, api_url: str, stream: bool = False 40 | ) -> requests.Response: 41 | headers = {"User-Agent": "Test Client"} 42 | pload = { 43 | "encoder_prompt": "", 44 | "decoder_prompt": prompt, 45 | "image_base64": encode_image_base64(image_path), 46 | "temperature": 0.0, 47 | "max_tokens": 2048, 48 | "stream": stream, 49 | } 50 | response = requests.post(api_url, headers=headers, json=pload, stream=stream) 51 | return response 52 | 53 | 54 | def get_streaming_response(response: requests.Response) -> Iterable[list[str]]: 55 | for chunk in response.iter_lines( 56 | chunk_size=8192, decode_unicode=False, delimiter=b"\n" 57 | ): 58 | if chunk: 59 | data = json.loads(chunk.decode("utf-8")) 60 | output = data["text"] 61 | yield output 62 | 63 | 64 | def get_response(response: requests.Response) -> list[str]: 65 | data = json.loads(response.content) 66 | output = data["text"] 67 | return output 68 | 69 | 70 | def parse_args(): 71 | parser = argparse.ArgumentParser() 72 | parser.add_argument("--host", type=str, default="localhost") 73 | parser.add_argument("--port", type=int, default=8000) 74 | parser.add_argument("--prompt", type=str, default="Parse the reading order of this document.") 75 | parser.add_argument("--image_path", type=str, default="./demo/page_imgs/page_1.jpeg") 76 | parser.add_argument("--stream", action="store_true") 77 | return parser.parse_args() 78 | 79 | 80 | def main(args: Namespace): 81 | prompt = args.prompt 82 | image_path = args.image_path 83 | api_url = f"http://{args.host}:{args.port}/generate" 84 | stream = args.stream 85 | 86 | print(f"Prompt: {prompt!r}\n", flush=True) 87 | response = post_http_request(prompt, image_path, api_url, stream) 88 | 89 | if stream: 90 | num_printed_lines = 0 91 | for h in get_streaming_response(response): 92 | clear_line(num_printed_lines) 93 | num_printed_lines = 0 94 | for i, line in enumerate(h): 95 | num_printed_lines += 1 96 | print(f"Response {i}: {line!r}", flush=True) 97 | else: 98 | output = get_response(response) 99 | print(f"Response: 
{output[0]!r}", flush=True) 100 | 101 | 102 | if __name__ == "__main__": 103 | args = parse_args() 104 | main(args) 105 | -------------------------------------------------------------------------------- /deployment/vllm/api_server.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | """ 3 | NOTE: This API server is used only for demonstrating usage of AsyncEngine 4 | and simple performance benchmarks. It is not intended for production use. 5 | For production use, we recommend using our OpenAI compatible server. 6 | We are also not going to accept PRs modifying this file, please 7 | change `vllm/entrypoints/openai/api_server.py` instead. 8 | """ 9 | 10 | import asyncio 11 | import base64 12 | import json 13 | import io 14 | import ssl 15 | from argparse import Namespace 16 | from collections.abc import AsyncGenerator 17 | from PIL import Image 18 | from typing import Any, Optional 19 | 20 | from fastapi import FastAPI, Request 21 | from fastapi.responses import JSONResponse, Response, StreamingResponse 22 | 23 | from vllm.engine.arg_utils import AsyncEngineArgs 24 | from vllm.engine.async_llm_engine import AsyncLLMEngine 25 | from vllm.entrypoints.launcher import serve_http 26 | from vllm.entrypoints.utils import with_cancellation 27 | from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt 28 | from vllm.logger import init_logger 29 | from vllm.sampling_params import SamplingParams 30 | from vllm.usage.usage_lib import UsageContext 31 | from vllm.utils import FlexibleArgumentParser, random_uuid, set_ulimit 32 | from vllm.version import __version__ as VLLM_VERSION 33 | 34 | logger = init_logger("api_server") 35 | 36 | TIMEOUT_KEEP_ALIVE = 5 # seconds. 37 | app = FastAPI() 38 | engine = None 39 | 40 | 41 | @app.get("/health") 42 | async def health() -> Response: 43 | """Health check.""" 44 | return Response(status_code=200) 45 | 46 | 47 | @app.post("/generate") 48 | async def generate(request: Request) -> Response: 49 | """Generate completion for the request. 50 | 51 | The request should be a JSON object with the following fields: 52 | - prompt: the prompt to use for the generation. 53 | - stream: whether to stream the results or not. 54 | - other fields: the sampling parameters (See `SamplingParams` for details). 
55 | """ 56 | request_dict = await request.json() 57 | return await _generate(request_dict, raw_request=request) 58 | 59 | 60 | async def decode_image(image_base64: str) -> Image.Image: 61 | image_data = base64.b64decode(image_base64) 62 | image = Image.open(io.BytesIO(image_data)) 63 | return image 64 | 65 | 66 | async def custom_process_prompt(encoder_prompt: str, decoder_prompt: str, 67 | image_base64: str) -> ExplicitEncoderDecoderPrompt: 68 | assert engine is not None 69 | tokenizer = engine.engine.get_tokenizer_group().tokenizer 70 | image = await decode_image(image_base64) 71 | 72 | if encoder_prompt == "": 73 | encoder_prompt = "0" * 783 # For Dolphin 74 | 75 | if decoder_prompt == "": 76 | decoder_prompt_ids = tokenizer.bos_token_id 77 | else: 78 | decoder_prompt = f"<s>{decoder_prompt.strip()} <Answer/>" 79 | decoder_prompt_ids = tokenizer(decoder_prompt, add_special_tokens=False)["input_ids"] 80 | 81 | enc_dec_prompt = ExplicitEncoderDecoderPrompt( 82 | encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}), 83 | decoder_prompt=TokensPrompt(prompt_token_ids=decoder_prompt_ids), 84 | ) 85 | 86 | return enc_dec_prompt 87 | 88 | 89 | @with_cancellation 90 | async def _generate(request_dict: dict, raw_request: Request) -> Response: 91 | encoder_prompt = request_dict.pop("encoder_prompt", "") 92 | decoder_prompt = request_dict.pop("decoder_prompt", "") 93 | image_base64 = request_dict.pop("image_base64", "") 94 | stream = request_dict.pop("stream", False) 95 | sampling_params = SamplingParams(**request_dict) 96 | request_id = random_uuid() 97 | 98 | assert engine is not None 99 | 100 | enc_dec_prompt = await custom_process_prompt(encoder_prompt, decoder_prompt, image_base64) 101 | results_generator = engine.generate(enc_dec_prompt, sampling_params, request_id) 102 | 103 | # Streaming case 104 | async def stream_results() -> AsyncGenerator[bytes, None]: 105 | async for request_output in results_generator: 106 | prompt = request_output.prompt 107 | assert prompt is not None 108 | text_outputs = [ 109 | prompt + output.text for output in request_output.outputs 110 | ] 111 | ret = {"text": text_outputs} 112 | yield (json.dumps(ret) + "\n").encode("utf-8") 113 | 114 | if stream: 115 | return StreamingResponse(stream_results()) 116 | 117 | # Non-streaming case 118 | final_output = None 119 | try: 120 | async for request_output in results_generator: 121 | final_output = request_output 122 | except asyncio.CancelledError: 123 | return Response(status_code=499) 124 | 125 | assert final_output is not None 126 | prompt = final_output.prompt 127 | assert prompt is not None 128 | text_outputs = [prompt + output.text.strip() for output in final_output.outputs] 129 | ret = {"text": text_outputs} 130 | return JSONResponse(ret) 131 | 132 | 133 | def build_app(args: Namespace) -> FastAPI: 134 | global app 135 | 136 | app.root_path = args.root_path 137 | return app 138 | 139 | 140 | async def init_app( 141 | args: Namespace, 142 | llm_engine: Optional[AsyncLLMEngine] = None, 143 | ) -> FastAPI: 144 | app = build_app(args) 145 | 146 | global engine 147 | 148 | engine_args = AsyncEngineArgs.from_cli_args(args) 149 | engine = (llm_engine 150 | if llm_engine is not None else AsyncLLMEngine.from_engine_args( 151 | engine_args, usage_context=UsageContext.API_SERVER)) 152 | app.state.engine_client = engine 153 | return app 154 | 155 | 156 | async def run_server(args: Namespace, 157 | llm_engine: Optional[AsyncLLMEngine] = None, 158 | **uvicorn_kwargs: Any) -> None: 159 | 
logger.info("vLLM API server version %s", VLLM_VERSION) 160 | logger.info("args: %s", args) 161 | 162 | set_ulimit() 163 | 164 | app = await init_app(args, llm_engine) 165 | assert engine is not None 166 | 167 | shutdown_task = await serve_http( 168 | app, 169 | sock=None, 170 | enable_ssl_refresh=args.enable_ssl_refresh, 171 | host=args.host, 172 | port=args.port, 173 | log_level=args.log_level, 174 | timeout_keep_alive=TIMEOUT_KEEP_ALIVE, 175 | ssl_keyfile=args.ssl_keyfile, 176 | ssl_certfile=args.ssl_certfile, 177 | ssl_ca_certs=args.ssl_ca_certs, 178 | ssl_cert_reqs=args.ssl_cert_reqs, 179 | **uvicorn_kwargs, 180 | ) 181 | 182 | await shutdown_task 183 | 184 | 185 | if __name__ == "__main__": 186 | parser = FlexibleArgumentParser() 187 | parser.add_argument("--host", type=str, default=None) 188 | parser.add_argument("--port", type=parser.check_port, default=8000) 189 | parser.add_argument("--ssl-keyfile", type=str, default=None) 190 | parser.add_argument("--ssl-certfile", type=str, default=None) 191 | parser.add_argument("--ssl-ca-certs", 192 | type=str, 193 | default=None, 194 | help="The CA certificates file") 195 | parser.add_argument( 196 | "--enable-ssl-refresh", 197 | action="store_true", 198 | default=False, 199 | help="Refresh SSL Context when SSL certificate files change") 200 | parser.add_argument( 201 | "--ssl-cert-reqs", 202 | type=int, 203 | default=int(ssl.CERT_NONE), 204 | help="Whether client certificate is required (see stdlib ssl module's)" 205 | ) 206 | parser.add_argument( 207 | "--root-path", 208 | type=str, 209 | default=None, 210 | help="FastAPI root_path when app is behind a path based routing proxy") 211 | parser.add_argument("--log-level", type=str, default="debug") 212 | parser = AsyncEngineArgs.add_cli_args(parser) 213 | args = parser.parse_args() 214 | 215 | asyncio.run(run_server(args)) 216 | -------------------------------------------------------------------------------- /deployment/vllm/demo_vllm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 3 | SPDX-License-Identifier: MIT 4 | """ 5 | 6 | import vllm_dolphin # vllm_dolphin plugin 7 | import argparse 8 | from argparse import Namespace 9 | from PIL import Image 10 | 11 | from vllm import LLM, SamplingParams 12 | from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt 13 | 14 | import torch 15 | import os 16 | 17 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 18 | 19 | 20 | def offline_inference(model_id: str, prompt: str, image_path: str, max_tokens: int = 2048): 21 | dtype = "float16" if torch.cuda.is_available() else "float32" 22 | # Create an encoder/decoder model instance 23 | llm = LLM( 24 | model=model_id, 25 | dtype=dtype, 26 | enforce_eager=True, 27 | trust_remote_code=True, 28 | max_num_seqs=8, 29 | hf_overrides={"architectures": ["DolphinForConditionalGeneration"]}, 30 | ) 31 | 32 | # Create a sampling params object. 33 | sampling_params = SamplingParams( 34 | temperature=0.0, 35 | logprobs=0, 36 | max_tokens=max_tokens, 37 | prompt_logprobs=None, 38 | skip_special_tokens=False, 39 | ) 40 | 41 | # process prompt 42 | tokenizer = llm.llm_engine.get_tokenizer_group().tokenizer 43 | 44 | # The Dolphin model does not require an Encoder Prompt. To ensure vllm correctly allocates KV Cache, 45 | # it is necessary to simulate an Encoder Prompt. 
46 | encoder_prompt = "0" * 783 47 | decoder_prompt = f"<s>{prompt.strip()} <Answer/>" 48 | 49 | image = Image.open(image_path) 50 | enc_dec_prompt = ExplicitEncoderDecoderPrompt( 51 | encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}), 52 | decoder_prompt=TokensPrompt( 53 | prompt_token_ids=tokenizer(decoder_prompt, add_special_tokens=False)["input_ids"] 54 | ), 55 | ) 56 | 57 | # Generate output tokens from the prompts. The output is a list of 58 | # RequestOutput objects that contain the prompt, generated text, and other information. 59 | outputs = llm.generate(enc_dec_prompt, sampling_params) 60 | 61 | print("------" * 8) 62 | # Print the outputs. 63 | for output in outputs: 64 | decoder_prompt_tokens = tokenizer.batch_decode(output.prompt_token_ids, skip_special_tokens=True) 65 | decoder_prompt = "".join(decoder_prompt_tokens) 66 | generated_text = output.outputs[0].text.strip() 67 | print(f"Decoder prompt: {decoder_prompt!r}, " 68 | f"\nGenerated text: {generated_text!r}") 69 | 70 | print("------" * 8) 71 | 72 | 73 | def parse_args(): 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument("--model", type=str, default="ByteDance/Dolphin") 76 | parser.add_argument("--image_path", type=str, default="./demo/page_imgs/page_1.jpeg") 77 | parser.add_argument("--prompt", type=str, default="Parse the reading order of this document.") 78 | return parser.parse_args() 79 | 80 | 81 | def main(args: Namespace): 82 | model = args.model 83 | prompt = args.prompt 84 | image_path = args.image_path 85 | 86 | offline_inference(model, prompt, image_path) 87 | 88 | 89 | if __name__ == "__main__": 90 | args = parse_args() 91 | main(args) 92 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | include = '\.pyi?#39; 4 | exclude = ''' 5 | /( 6 | \.git 7 | | \.hg 8 | | \.mypy_cache 9 | | \.tox 10 | | \.venv 11 | | _build 12 | | buck-out 13 | | build 14 | | dist 15 | )/ 16 | ''' 17 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.24.4 2 | omegaconf==2.3.0 3 | opencv-python==4.11.0.86 4 | opencv-python-headless==4.5.5.64 5 | pillow==9.3.0 6 | timm==0.5.4 7 | torch==2.1.0 8 | torchvision==0.16.0 9 | transformers==4.47.0 10 | accelerate==1.6.0 11 | pymupdf==1.26 -------------------------------------------------------------------------------- /utils/markdown_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates 3 | SPDX-License-Identifier: MIT 4 | """ 5 | 6 | import re 7 | import base64 8 | from typing import List, Dict, Any, Optional 9 | 10 | 11 | """ 12 | Example input: 13 | [ 14 | {"label": "tab", "bbox": [0.176, 0.74, 0.824, 0.82], "text": "<table><tr><td></td><td>HellaSwag</td><td>Obqa</td><td>WinoGrande</td><td>ARC-c</td><td>ARC-e</td><td>boolq</td><td>piqa</td><td>Avg</td></tr><tr><td>OPT-1.3B</td><td>53.65</td><td>33.40</td><td>59.59</td><td>29.44</td><td>50.80</td><td>60.83</td><td>72.36</td><td>51.44</td></tr><tr><td>Pythia-1.0B</td><td>47.16</td><td>31.40</td><td>53.43</td><td>27.05</td><td>48.99</td><td>57.83</td><td>69.21</td><td>48.30</td></tr><tr><td>Pythia-1.4B</td><td>52.01</td><td>33.20</td><td>57.38</td><td>28.50</td><td>54.00</td><td>63.27</td><td>70.95</td><td>51.33</td></tr><tr><td>TinyLlama-1.1B</td><td>59.20</td><td>36.00</td><td>59.12</td><td>30.10</td><td>55.25</td><td>57.83</td><td>73.29</td><td>52.99</td></tr></table>", "reading_order": 6}, 15 | {"label": "cap", "bbox": [0.28, 0.729, 0.711, 0.74], "text": "Table 2: Zero-shot performance on commonsense reasoning tasks", "reading_order": 7}, 16 | {"label": "para", "bbox": [0.176, 0.848, 0.826, 0.873], "text": "We of performance during training We tracked the accuracy of TinyLlama on common-\nsense reasoning benchmarks during its pre-training, as shown in Fig. 2 . Generally, the performance of", "reading_order": 8}, 17 | {"label": "fnote", "bbox": [0.176, 0.88, 0.824, 0.912], "text": "${ }^{4}$ Due to a bug in the config file, the learning rate did not decrease immediately after warmup and remained at\nthe maximum value for several steps before we fixed this.", "reading_order": 9}, 18 | {"label": "foot", "bbox": [0.496, 0.939, 0.501, 0.95], "text": "14", "reading_order": 10} 19 | ] 20 | """ 21 | 22 | 23 | def extract_table_from_html(html_string): 24 | """Extract and clean table tags from HTML string""" 25 | try: 26 | table_pattern = re.compile(r'<table.*?>.*?</table>', re.DOTALL) 27 | tables = table_pattern.findall(html_string) 28 | tables = [re.sub(r'<table[^>]*>', '<table>', table) for table in tables] 29 | return '\n'.join(tables) 30 | except Exception as e: 31 | print(f"extract_table_from_html error: {str(e)}") 32 | return f"<table><tr><td>Error extracting table: {str(e)}</td></tr></table>" 33 | 34 | 35 | class MarkdownConverter: 36 | """Convert structured recognition results to Markdown format""" 37 | 38 | def __init__(self): 39 | # Define heading levels for different section types 40 | self.heading_levels = { 41 | 'title': '#', 42 | 'sec': '##', 43 | 'sub_sec': '###' 44 | } 45 | 46 | # Define which labels need special handling 47 | self.special_labels = { 48 | 'tab', 'fig', 'title', 'sec', 'sub_sec', 49 | 'list', 'formula', 'reference', 'alg' 50 | } 51 | 52 | def try_remove_newline(self, text: str) -> str: 53 | try: 54 | # Preprocess text to handle line breaks 55 | text = text.strip() 56 | text = text.replace('-\n', '') 57 | 58 | # Handle Chinese text line breaks 59 | def is_chinese(char): 60 | return '\u4e00' <= char <= '\u9fff' 61 | 62 | lines = text.split('\n') 63 | processed_lines = [] 64 | 65 | # Process all lines except the last one 66 | for i in range(len(lines)-1): 67 | current_line = lines[i].strip() 68 | next_line = lines[i+1].strip() 69 | 70 | # Always add the current line, but determine if we need a newline 71 | if current_line: # If current line is not empty 72 | if next_line: # If next line is not empty 73 | # For Chinese text handling 74 | if is_chinese(current_line[-1]) and 
is_chinese(next_line[0]):
75 |                             processed_lines.append(current_line)
76 |                         else:
77 |                             processed_lines.append(current_line + ' ')
78 |                     else:
79 |                         # Next line is empty, add current line with newline
80 |                         processed_lines.append(current_line + '\n')
81 |                 else:
82 |                     # Current line is empty, add an empty line
83 |                     processed_lines.append('\n')
84 | 
85 |             # Add the last line
86 |             if lines and lines[-1].strip():
87 |                 processed_lines.append(lines[-1].strip())
88 | 
89 |             text = ''.join(processed_lines)
90 | 
91 |             return text
92 |         except Exception as e:
93 |             print(f"try_remove_newline error: {str(e)}")
94 |             return text  # Return original text on error
95 | 
96 |     def _handle_text(self, text: str) -> str:
97 |         """
98 |         Process regular text content, preserving paragraph structure
99 |         """
100 |         try:
101 |             if not text:
102 |                 return ""
103 | 
104 |             if text.strip().startswith("\\begin{array}") and text.strip().endswith("\\end{array}"):
105 |                 text = "$$" + text + "$$"
106 |             elif ("_{" in text or "^{" in text or "\\" in text or "_ {" in text or "^ {" in text) and ("$" not in text) and ("\\begin" not in text):
107 |                 text = "$" + text + "$"
108 | 
109 |             # Process formulas in text before handling other text processing
110 |             text = self._process_formulas_in_text(text)
111 | 
112 |             text = self.try_remove_newline(text)
113 | 
114 |             # Return processed text
115 |             return text
116 |         except Exception as e:
117 |             print(f"_handle_text error: {str(e)}")
118 |             return text  # Return original text on error
119 | 
120 |     def _process_formulas_in_text(self, text: str) -> str:
121 |         """
122 |         Process mathematical formulas in text by iteratively finding and replacing formulas.
123 |         - Identify inline and block formulas
124 |         - Replace newlines within formulas with \\
125 |         """
126 |         try:
127 |             # Define formula delimiters and their corresponding patterns
128 |             delimiters = [
129 |                 ('$$', '$$'),    # Block formula with $$
130 |                 ('\\[', '\\]'),  # Block formula with \[ \]
131 |                 ('$', '$'),      # Inline formula with $
132 |                 ('\\(', '\\)')   # Inline formula with \( \)
133 |             ]
134 | 
135 |             # Process the text by iterating through each delimiter type
136 |             result = text
137 | 
138 |             for start_delim, end_delim in delimiters:
139 |                 # Create a pattern that matches from start to end delimiter
140 |                 # Using a custom approach to avoid issues with nested delimiters
141 |                 current_pos = 0
142 |                 processed_parts = []
143 | 
144 |                 while current_pos < len(result):
145 |                     # Find the next start delimiter
146 |                     start_pos = result.find(start_delim, current_pos)
147 |                     if start_pos == -1:
148 |                         # No more formulas of this type
149 |                         processed_parts.append(result[current_pos:])
150 |                         break
151 | 
152 |                     # Add text before the formula
153 |                     processed_parts.append(result[current_pos:start_pos])
154 | 
155 |                     # Find the matching end delimiter
156 |                     end_pos = result.find(end_delim, start_pos + len(start_delim))
157 |                     if end_pos == -1:
158 |                         # No matching end delimiter, treat as regular text
159 |                         processed_parts.append(result[start_pos:])
160 |                         break
161 | 
162 |                     # Extract the formula content (without delimiters)
163 |                     formula_content = result[start_pos + len(start_delim):end_pos]
164 | 
165 |                     # Process the formula content - replace newlines with \\
166 |                     processed_formula = formula_content.replace('\n', ' \\\\ ')
167 | 
168 |                     # Add the processed formula with its delimiters
169 |                     processed_parts.append(f"{start_delim}{processed_formula}{end_delim}")
170 | 
171 |                     # Move past this formula
172 |                     current_pos = end_pos + len(end_delim)
173 | 
174 |                 # Update the result with processed text
175 |                 result = ''.join(processed_parts)
176 |             return result
177 |         except Exception as e:
178 |             print(f"_process_formulas_in_text error: {str(e)}")
179 |             return text  # Return original text on error
180 | 
181 |     def _remove_newline_in_heading(self, text: str) -> str:
182 |         """
183 |         Remove newline in heading
184 |         """
185 |         try:
186 |             # Handle Chinese text line breaks
187 |             def is_chinese(char):
188 |                 return '\u4e00' <= char <= '\u9fff'
189 | 
190 |             # Check if the text contains Chinese characters
191 |             if any(is_chinese(char) for char in text):
192 |                 return text.replace('\n', '')
193 |             else:
194 |                 return text.replace('\n', ' ')
195 | 
196 |         except Exception as e:
197 |             print(f"_remove_newline_in_heading error: {str(e)}")
198 |             return text
199 | 
200 |     def _handle_heading(self, text: str, label: str) -> str:
201 |         """
202 |         Convert section headings to appropriate markdown format
203 |         """
204 |         try:
205 |             level = self.heading_levels.get(label, '#')
206 |             text = text.strip()
207 |             text = self._remove_newline_in_heading(text)
208 |             text = self._handle_text(text)
209 |             return f"{level} {text}\n\n"
210 |         except Exception as e:
211 |             print(f"_handle_heading error: {str(e)}")
212 |             return f"# Error processing heading: {text}\n\n"
213 | 
214 |     def _handle_list_item(self, text: str) -> str:
215 |         """
216 |         Convert list items to markdown list format
217 |         """
218 |         try:
219 |             return f"- {text.strip()}\n"
220 |         except Exception as e:
221 |             print(f"_handle_list_item error: {str(e)}")
222 |             return f"- Error processing list item: {text}\n"
223 | 
224 |     def _handle_figure(self, text: str, section_count: int) -> str:
225 |         """
226 |         Handle figure content
227 |         """
228 |         try:
229 |             # Check if it's a file path starting with "figures/"
230 |             if text.startswith("figures/"):
231 |                 # Convert to relative path from markdown directory to figures directory
232 |                 relative_path = f"../{text}"
233 |                 return f"![Figure {section_count}]({relative_path})\n\n"
234 | 
235 |             # Check if it's already a markdown format image link
236 |             if text.startswith("!["):
237 |                 return f"{text}\n\n"
238 | 
239 |             # Check if it's base64 image data
240 |             if text.startswith("data:image"):
241 |                 # Already in data URI format
242 |                 return f"![Figure {section_count}]({text})\n\n"
243 |             elif ";" in text and "," in text:
244 |                 return f"![Figure {section_count}]({text})\n\n"
245 |             else:
246 |                 # Assume it's raw base64, convert to data URI
247 |                 img_format = "png"
248 |                 data_uri = f"data:image/{img_format};base64,{text}"
249 |                 return f"![Figure {section_count}]({data_uri})\n\n"
250 | 
251 |         except Exception as e:
252 |             print(f"_handle_figure error: {str(e)}")
253 |             return f"*[Error processing figure: {str(e)}]*\n\n"
254 | 
255 |     def _handle_table(self, text: str) -> str:
256 |         """
257 |         Convert table content to markdown format
258 |         """
259 |         try:
260 |             markdown_content = []
261 |             if '<table' in text.lower() or '<tr' in text.lower():
262 |                 markdown_table = extract_table_from_html(text)
263 |                 markdown_content.append(markdown_table + "\n")
264 |             else:
265 |                 table_lines = text.split('\n')
266 |                 if table_lines:
267 |                     col_count = len(table_lines[0].split()) if table_lines[0] else 1
268 |                     header = '| ' + ' | '.join(table_lines[0].split()) + ' |'
269 |                     markdown_content.append(header)
270 |                     markdown_content.append('| ' + ' | '.join(['---'] * col_count) + ' |')
271 |                     for line in table_lines[1:]:
272 |                         cells = line.split()
273 |                         while len(cells) < col_count:
274 |                             cells.append('')
275 |                         markdown_content.append('| ' + ' | '.join(cells) + ' |')
276 |             return '\n'.join(markdown_content) + '\n\n'
277 |         except Exception as e:
278 |             print(f"_handle_table error: {str(e)}")
279 |             return f"*[Error processing table: {str(e)}]*\n\n"
280 | 
281 |     def _handle_algorithm(self, text: str) -> str:
282 |         """
283 |         Process algorithm blocks with proper formatting
284 |         """
285 |         try:
286 |             # Remove algorithm environment tags if present
287 |             text = re.sub(r'\\begin\{algorithm\}(.*?)\\end\{algorithm\}', r'\1', text, flags=re.DOTALL)
288 |             text = text.replace('\\begin{algorithm}', '').replace('\\end{algorithm}', '')
289 |             text = text.replace('\\begin{algorithmic}', '').replace('\\end{algorithmic}', '')
290 | 
291 |             # Process the algorithm text
292 |             lines = text.strip().split('\n')
293 | 
294 |             # Check if there's a caption or label
295 |             caption = ""
296 |             algorithm_text = []
297 | 
298 |             for line in lines:
299 |                 if '\\caption' in line:
300 |                     # Extract caption text
301 |                     caption_match = re.search(r'\\caption\{(.*?)\}', line)
302 |                     if caption_match:
303 |                         caption = f"**{caption_match.group(1)}**\n\n"
304 |                     continue
305 |                 elif '\\label' in line:
306 |                     continue  # Skip label lines
307 |                 else:
308 |                     algorithm_text.append(line)
309 | 
310 |             # Join the algorithm text and wrap in code block
311 |             formatted_text = '\n'.join(algorithm_text)
312 | 
313 |             # Return the formatted algorithm with caption
314 |             return f"{caption}```\n{formatted_text}\n```\n\n"
315 |         except Exception as e:
316 |             print(f"_handle_algorithm error: {str(e)}")
317 |             return f"*[Error processing algorithm: {str(e)}]*\n\n{text}\n\n"
318 | 
319 |     def _handle_formula(self, text: str) -> str:
320 |         """
321 |         Handle formula-specific content
322 |         """
323 |         try:
324 |             # Process the formula content
325 |             processed_text = self._process_formulas_in_text(text)
326 | 
327 |             # For formula blocks, ensure they're properly formatted in markdown
328 |             if '$$' not in processed_text and '\\[' not in processed_text:
329 |                 # If no block formula delimiters are present, wrap in $$ for block formula
330 |                 processed_text = f'$${processed_text}$$'
331 | 
332 |             return f"{processed_text}\n\n"
333 |         except Exception as e:
334 |             print(f"_handle_formula error: {str(e)}")
335 |             return f"*[Error processing formula: {str(e)}]*\n\n"
336 | 
337 |     def convert(self, recognition_results: List[Dict[str, Any]]) -> str:
338 |         """
339 |         Convert recognition results to markdown format
340 |         """
341 |         try:
342 |             markdown_content = []
343 | 
344 |             for section_count, result in enumerate(recognition_results):
345 |                 try:
346 |                     label = result.get('label', '')
347 |                     text = result.get('text', '').strip()
348 | 
349 |                     # Skip empty text
350 |                     if not text:
351 |                         continue
352 | 
353 |                     # Handle different content types
354 |                     if label in {'title', 'sec', 'sub_sec'}:
355 |                         markdown_content.append(self._handle_heading(text, label))
356 |                     elif label == 'list':
357 |                         markdown_content.append(self._handle_list_item(text))
358 |                     elif label == 'fig':
359 |                         markdown_content.append(self._handle_figure(text, section_count))
360 |                     elif label == 'tab':
361 |                         markdown_content.append(self._handle_table(text))
362 |                     elif label == 'alg':
363 |                         markdown_content.append(self._handle_algorithm(text))
364 |                     elif label == 'formula':
365 |                         markdown_content.append(self._handle_formula(text))
366 |                     elif label not in self.special_labels:
367 |                         # Handle regular text (paragraphs, etc.)
368 |                         processed_text = self._handle_text(text)
369 |                         markdown_content.append(f"{processed_text}\n\n")
370 |                 except Exception as e:
371 |                     print(f"Error processing item {section_count}: {str(e)}")
372 |                     # Add a placeholder for the failed item
373 |                     markdown_content.append(f"*[Error processing content]*\n\n")
374 | 
375 |             # Join all content and apply post-processing
376 |             result = ''.join(markdown_content)
377 |             return self._post_process(result)
378 |         except Exception as e:
379 |             print(f"convert error: {str(e)}")
380 |             return f"Error generating markdown content: {str(e)}"
381 | 
382 |     def _post_process(self, markdown_content: str) -> str:
383 |         """
384 |         Apply post-processing fixes to the generated markdown content
385 |         """
386 |         try:
387 |             # Handle author information
388 |             author_pattern = re.compile(r'\\author\{(.*?)\}', re.DOTALL)
389 | 
390 |             def process_author_match(match):
391 |                 # Extract author content
392 |                 author_content = match.group(1)
393 |                 # Process the author content
394 |                 return self._handle_text(author_content)
395 | 
396 |             # Replace \author{...} with processed content
397 |             markdown_content = author_pattern.sub(process_author_match, markdown_content)
398 | 
399 |             # Handle special case where author is inside math environment
400 |             math_author_pattern = re.compile(r'\$(\\author\{.*?\})\$', re.DOTALL)
401 |             match = math_author_pattern.search(markdown_content)
402 |             if match:
403 |                 # Extract the author command
404 |                 author_cmd = match.group(1)
405 |                 # Extract content from author command
406 |                 author_content_match = re.search(r'\\author\{(.*?)\}', author_cmd, re.DOTALL)
407 |                 if author_content_match:
408 |                     # Get author content and process it
409 |                     author_content = author_content_match.group(1)
410 |                     processed_content = self._handle_text(author_content)
411 |                     # Replace the entire $\author{...}$ block with processed content
412 |                     markdown_content = markdown_content.replace(match.group(0), processed_content)
413 | 
414 |             # Replace LaTeX abstract environment with plain text
415 |             markdown_content = re.sub(r'\\begin\{abstract\}(.*?)\\end\{abstract\}',
416 |                                       r'**Abstract** \1',
417 |                                       markdown_content,
418 |                                       flags=re.DOTALL)
419 | 
420 |             # Replace standalone \begin{abstract} (without matching end)
421 |             markdown_content = re.sub(r'\\begin\{abstract\}',
422 |                                       r'**Abstract**',
423 |                                       markdown_content)
424 | 
425 |             # Replace LaTeX equation numbers with tag format, handling cases with extra backslashes
426 |             markdown_content = re.sub(r'\\eqno\{\((.*?)\)\}',
427 |                                       r'\\tag{\1}',
428 |                                       markdown_content)
429 | 
430 |             # Find the starting tag of the formula
431 |             markdown_content = markdown_content.replace("\[ \\\\", "$$ \\\\")
432 | 
433 |             # Find the ending tag of the formula (ensure this is the only ending tag)
434 |             markdown_content = markdown_content.replace("\\\\ \]", "\\\\ $$")
435 | 
436 |             # Fix other common LaTeX issues
437 |             replacements = [
438 |                 # Fix spacing issues in subscripts and superscripts
439 |                 (r'_ {', r'_{'),
440 |                 (r'^ {', r'^{'),
441 | 
442 |                 # Fix potential issues with multiple consecutive newlines
443 |                 (r'\n{3,}', r'\n\n')
444 |             ]
445 | 
446 |             for old, new in replacements:
447 |                 markdown_content = re.sub(old, new, markdown_content)
448 | 
449 |             return markdown_content
450 |         except Exception as e:
451 |             print(f"_post_process error: {str(e)}")
452 |             return markdown_content  # Return original content if post-processing fails
453 | 
--------------------------------------------------------------------------------
/utils/model.py:
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022-present NAVER Corp. 3 | Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. 4 | MIT License 5 | This file has been modified by [ByteDance Ltd. and/or its affiliates] on 20250118. 6 | The original file available at https://github.com/clovaai/donut/blob/master/donut/model.py was released under the MIT license. 7 | This modified file is released under the same license. 8 | """ 9 | 10 | import logging 11 | from collections import defaultdict 12 | from typing import List, Optional 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | from PIL import Image 17 | from timm.models.swin_transformer import SwinTransformer 18 | from torch import nn 19 | from transformers import ( 20 | MBartConfig, 21 | MBartForCausalLM, 22 | StoppingCriteria, 23 | StoppingCriteriaList, 24 | ) 25 | from transformers.file_utils import ModelOutput 26 | from transformers.modeling_utils import PretrainedConfig, PreTrainedModel 27 | 28 | 29 | class SwinEncoder(nn.Module): 30 | r""" 31 | Encoder based on SwinTransformer 32 | Set the initial weights and configuration with a pretrained SwinTransformer and then 33 | modify the detailed configurations 34 | 35 | Args: 36 | input_size: Input image size (width, height) 37 | align_long_axis: Whether to rotate image if height is greater than width 38 | window_size: Window size(=patch size) of SwinTransformer 39 | encoder_layer: Number of layers of SwinTransformer encoder 40 | name_or_path: Name of a pretrained model name either registered in huggingface.co. or saved in local. 41 | otherwise, `swin_base_patch4_window12_384` will be set (using `timm`). 42 | """ 43 | 44 | def __init__( 45 | self, 46 | input_size, 47 | align_long_axis: bool = False, 48 | window_size: int = 7, 49 | encoder_layer: List[int] = [2, 2, 14, 2], 50 | patch_size: int = [4, 4], 51 | embed_dim: int = 128, 52 | num_heads: List[int] = [4, 8, 16, 32], 53 | ): 54 | super().__init__() 55 | if isinstance(input_size, int): 56 | input_size = [input_size, input_size] 57 | self.input_size = input_size 58 | self.align_long_axis = align_long_axis 59 | self.window_size = window_size 60 | self.encoder_layer = encoder_layer 61 | self.patch_size = patch_size 62 | self.embed_dim = embed_dim 63 | self.num_heads = num_heads 64 | 65 | self.model = SwinTransformer( 66 | img_size=self.input_size, 67 | depths=self.encoder_layer, 68 | window_size=self.window_size, 69 | patch_size=self.patch_size, 70 | embed_dim=self.embed_dim, 71 | num_heads=self.num_heads, 72 | num_classes=0, 73 | ) 74 | 75 | def forward(self, x: torch.Tensor, text_embedding: Optional[torch.Tensor] = None) -> torch.Tensor: 76 | """ 77 | Args: 78 | x: (batch_size, num_channels, height, width) 79 | """ 80 | x = self.model.patch_embed(x) 81 | x = self.model.pos_drop(x) 82 | x = self.model.layers(x) 83 | return x 84 | 85 | 86 | class LayerNorm(nn.LayerNorm): 87 | """Subclass torch's LayerNorm to handle fp16.""" 88 | 89 | def _set_dtype(self, dtype): 90 | self._dtype = dtype 91 | 92 | def forward(self, x: torch.Tensor): 93 | orig_type = x.dtype 94 | ret = super().forward(x.type(dtype=self._dtype)) 95 | return ret.type(orig_type) 96 | 97 | 98 | class BARTDecoder(nn.Module): 99 | """ 100 | Decoder based on Multilingual BART 101 | Set the initial weights and configuration with a pretrained multilingual BART model, 102 | and modify the detailed configurations as a Donut decoder 103 | 104 | Args: 105 | decoder_layer: 106 | Number of layers of BARTDecoder 
107 | max_position_embeddings: 108 | The maximum sequence length to be trained 109 | name_or_path: 110 | Name of a pretrained model name either registered in huggingface.co. or saved in local, 111 | otherwise, `facebook/mbart-large-50` will be set (using `transformers`) 112 | """ 113 | 114 | def __init__( 115 | self, 116 | tokenizer, 117 | decoder_layer: int, 118 | max_position_embeddings: int, 119 | hidden_dimension: int = 1024, 120 | **kwargs, 121 | ): 122 | super().__init__() 123 | self.decoder_layer = decoder_layer 124 | self.max_position_embeddings = max_position_embeddings 125 | self.hidden_dimension = hidden_dimension 126 | 127 | self.tokenizer = tokenizer 128 | 129 | self.model = MBartForCausalLM( 130 | config=MBartConfig( 131 | tie_word_embeddings=True, 132 | is_decoder=True, 133 | is_encoder_decoder=False, 134 | add_cross_attention=True, 135 | decoder_layers=self.decoder_layer, 136 | max_position_embeddings=self.max_position_embeddings, 137 | vocab_size=len(self.tokenizer), 138 | scale_embedding=True, 139 | add_final_layer_norm=True, 140 | d_model=self.hidden_dimension, 141 | ) 142 | ) 143 | # self.model.config.is_encoder_decoder = True # to get cross-attention 144 | self.model.model.decoder.embed_tokens.padding_idx = self.tokenizer.pad_token_id 145 | self.model.prepare_inputs_for_generation = self.prepare_inputs_for_inference 146 | 147 | def add_special_tokens(self, list_of_tokens: List[str]): 148 | """ 149 | Add special tokens to tokenizer and resize the token embeddings 150 | """ 151 | newly_added_num = self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set(list_of_tokens))}) 152 | if newly_added_num > 0: 153 | self.model.resize_token_embeddings(len(self.tokenizer)) 154 | 155 | def add_tokens(self, list_of_tokens: List[str]): 156 | """ 157 | Add special tokens to tokenizer and resize the token embeddings 158 | """ 159 | newly_added_num = self.tokenizer.add_tokens(sorted(set(list_of_tokens))) 160 | if newly_added_num > 0: 161 | self.model.resize_token_embeddings(len(self.tokenizer)) 162 | 163 | def prepare_inputs_for_inference( 164 | self, 165 | input_ids: torch.Tensor, 166 | encoder_outputs: torch.Tensor, 167 | past=None, 168 | past_key_values=None, 169 | use_cache: bool = None, 170 | attention_mask: torch.Tensor = None, 171 | **kwargs, 172 | ): 173 | """ 174 | Args: 175 | input_ids: (batch_size, sequence_length) 176 | 177 | Returns: 178 | input_ids: (batch_size, sequence_length) 179 | attention_mask: (batch_size, sequence_length) 180 | encoder_hidden_states: (batch_size, sequence_length, embedding_dim) 181 | """ 182 | attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long() 183 | past = past or past_key_values 184 | if past is not None: 185 | input_ids = input_ids[:, -1:] 186 | output = { 187 | "input_ids": input_ids, 188 | "attention_mask": attention_mask, 189 | "past_key_values": past, 190 | "use_cache": use_cache, 191 | "encoder_hidden_states": encoder_outputs.last_hidden_state, 192 | } 193 | return output 194 | 195 | def forward( 196 | self, 197 | input_ids: torch.LongTensor = None, 198 | attention_mask: Optional[torch.Tensor] = None, 199 | encoder_hidden_states: Optional[torch.Tensor] = None, 200 | past_key_values: Optional[torch.Tensor] = None, 201 | inputs_embeds: Optional[torch.FloatTensor] = None, 202 | labels: Optional[torch.Tensor] = None, 203 | use_cache: bool = None, 204 | output_attentions: Optional[torch.Tensor] = None, 205 | output_hidden_states: Optional[torch.Tensor] = None, 206 | return_dict: bool = None, 207 | ): 208 | return 
self.model.forward( 209 | input_ids=input_ids, 210 | attention_mask=attention_mask, 211 | labels=labels, 212 | encoder_hidden_states=encoder_hidden_states, 213 | past_key_values=past_key_values, 214 | inputs_embeds=inputs_embeds, 215 | use_cache=use_cache, 216 | output_attentions=output_attentions, 217 | output_hidden_states=output_hidden_states, 218 | return_dict=return_dict, 219 | ) 220 | 221 | @staticmethod 222 | def resize_bart_abs_pos_emb(weight: torch.Tensor, max_length: int) -> torch.Tensor: 223 | """ 224 | Resize position embeddings 225 | Truncate if sequence length of MBart backbone is greater than given max_length, 226 | else interpolate to max_length 227 | """ 228 | if weight.shape[0] > max_length: 229 | weight = weight[:max_length, ...] 230 | else: 231 | weight = ( 232 | F.interpolate( 233 | weight.permute(1, 0).unsqueeze(0), 234 | size=max_length, 235 | mode="linear", 236 | align_corners=False, 237 | ) 238 | .squeeze(0) 239 | .permute(1, 0) 240 | ) 241 | return weight 242 | 243 | 244 | class DonutConfig(PretrainedConfig): 245 | 246 | def __init__( 247 | self, 248 | decoder_layer: int = 10, 249 | max_position_embeddings: int = None, 250 | max_length: int = 4096, 251 | hidden_dimension: int = 1024, 252 | **kwargs, 253 | ): 254 | super().__init__() 255 | self.decoder_layer = decoder_layer 256 | self.max_position_embeddings = max_length if max_position_embeddings is None else max_position_embeddings 257 | self.max_length = max_length 258 | self.hidden_dimension = hidden_dimension 259 | 260 | 261 | class RunningVarTorch: 262 | def __init__(self, L=15, norm=False): 263 | self.values = None 264 | self.L = L 265 | self.norm = norm 266 | 267 | def push(self, x: torch.Tensor): 268 | assert x.dim() == 1 269 | if self.values is None: 270 | self.values = x[:, None] 271 | elif self.values.shape[1] < self.L: 272 | self.values = torch.cat((self.values, x[:, None]), 1) 273 | else: 274 | self.values = torch.cat((self.values[:, 1:], x[:, None]), 1) 275 | 276 | def variance(self): 277 | if self.values is None: 278 | return 279 | if self.norm: 280 | return torch.var(self.values, 1) / self.values.shape[1] 281 | else: 282 | return torch.var(self.values, 1) 283 | 284 | 285 | class StoppingCriteriaScores(StoppingCriteria): 286 | def __init__(self, threshold: float = 0.015, window_size: int = 200): 287 | super().__init__() 288 | self.threshold = threshold 289 | self.vars = RunningVarTorch(norm=True) 290 | self.varvars = RunningVarTorch(L=window_size) 291 | self.stop_inds = defaultdict(int) 292 | self.stopped = defaultdict(bool) 293 | self.size = 0 294 | self.window_size = window_size 295 | 296 | @torch.no_grad() 297 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): 298 | last_scores = scores[-1] 299 | self.vars.push(last_scores.max(1)[0].float().cpu()) 300 | self.varvars.push(self.vars.variance()) 301 | self.size += 1 302 | if self.size < self.window_size: 303 | return False 304 | 305 | varvar = self.varvars.variance() 306 | for b in range(len(last_scores)): 307 | if varvar[b] < self.threshold: 308 | if self.stop_inds[b] > 0 and not self.stopped[b]: 309 | self.stopped[b] = self.stop_inds[b] >= self.size 310 | else: 311 | self.stop_inds[b] = int(min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095)) 312 | else: 313 | self.stop_inds[b] = 0 314 | self.stopped[b] = False 315 | return all(self.stopped.values()) and len(self.stopped) > 0 316 | 317 | 318 | def batch(l, b=15): 319 | subs = [] 320 | for i in range(len(l) - b): 321 | subs.append(l[i : i + b]) 322 | return subs 
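# List-slicing helpers: batch() above returns the sliding windows l[i:i+b] of length b,
# while subdiv() below returns the growing prefixes l[:b], l[:b+1], ..., l[:len(l)-1].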
323 | 324 | 325 | def subdiv(l, b=10): 326 | subs = [] 327 | for i in range(len(l) - b): 328 | subs.append(l[: i + b]) 329 | return subs 330 | 331 | 332 | class DonutModel(PreTrainedModel): 333 | config_class = DonutConfig 334 | base_model_prefix = "donut" 335 | 336 | def __init__(self, config: DonutConfig, vision_tower=None, tokenizer=None): 337 | super().__init__(config) 338 | self.config = config 339 | 340 | self.tokenizer = tokenizer 341 | self.vpm = vision_tower 342 | 343 | # build language model 344 | self.llm = BARTDecoder( 345 | tokenizer=tokenizer, 346 | decoder_layer=self.config.decoder_layer, 347 | max_position_embeddings=self.config.max_position_embeddings, 348 | hidden_dimension=self.config.hidden_dimension, 349 | ) 350 | self.ids_to_tokens = {id: content for content, id in self.llm.tokenizer.vocab.items()} 351 | 352 | def get_input_embeddings(self, tensor): 353 | return self.llm.model.get_input_embeddings()(tensor) 354 | 355 | def forward( 356 | self, 357 | inputs: dict, 358 | ): 359 | image_tensors = inputs["pixel_values"] 360 | input_ids = inputs["input_ids"].contiguous() 361 | attention_mask = inputs["attention_mask"] 362 | labels = inputs["labels"].contiguous() 363 | 364 | encoder_outputs = self.vpm( 365 | image_tensors, 366 | text_embedding=self.llm.model.get_input_embeddings()(input_ids), 367 | ) 368 | 369 | decoder_outputs = self.llm( 370 | input_ids=input_ids, 371 | encoder_hidden_states=encoder_outputs, 372 | attention_mask=attention_mask, 373 | labels=labels, 374 | ) 375 | return decoder_outputs 376 | 377 | def get_hidden_states_during_inference( 378 | self, 379 | prompt_ids: torch.Tensor, 380 | image: Image.Image = None, 381 | image_tensors: Optional[torch.Tensor] = None, 382 | ): 383 | if image_tensors is None: 384 | image_tensors = self.vpm.prepare_input(image).unsqueeze(0) 385 | 386 | if self.device.type != "mps": 387 | image_tensors = image_tensors.to(next(self.parameters()).dtype) 388 | 389 | image_tensors = image_tensors.to(self.device) 390 | prompt_ids = prompt_ids.to(self.device) 391 | all_hidden_states = self.vpm.forward_features( 392 | image_tensors, text_embedding=self.get_input_embeddings(prompt_ids) 393 | ) 394 | return all_hidden_states 395 | 396 | def get_attn_weights_during_inference( 397 | self, 398 | prompt_ids: torch.Tensor, 399 | image: Image.Image = None, 400 | image_tensors: Optional[torch.Tensor] = None, 401 | ): 402 | if image_tensors is None: 403 | image_tensors = self.vpm.prepare_input(image).unsqueeze(0) 404 | 405 | if self.device.type != "mps": 406 | image_tensors = image_tensors.to(next(self.parameters()).dtype) 407 | 408 | image_tensors = image_tensors.to(self.device) 409 | prompt_ids = prompt_ids.to(self.device) 410 | last_attn_score = self.vpm.get_last_layer_cross_attn_score( 411 | image_tensors, text_embedding=self.get_input_embeddings(prompt_ids) 412 | ) 413 | return last_attn_score 414 | 415 | def inference( 416 | self, 417 | prompt_ids: torch.Tensor, 418 | image: Image.Image = None, 419 | image_tensors: Optional[torch.Tensor] = None, 420 | return_attentions: bool = False, 421 | early_stopping: bool = True, 422 | ): 423 | """ 424 | Generate a token sequence in an auto-regressive manner. 
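        Returns a dict whose "sequences" holds the generated token ids, "repetitions" their
        decoded strings, and "scores" the per-step max softmax probabilities; the
        "predictions" and "repeats" entries are initialized but left empty by this method.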
425 | 426 | Args: 427 | image: input document image (PIL.Image) 428 | image_tensors: (1, num_channels, height, width) 429 | convert prompt to tensor if image_tensor is not fed 430 | """ 431 | output = { 432 | "predictions": list(), 433 | "sequences": list(), 434 | "repeats": list(), 435 | "repetitions": list(), 436 | } 437 | if image is None and image_tensors is None: 438 | logging.warn("Image not found") 439 | return output 440 | 441 | if image_tensors is None: 442 | image_tensors = self.vpm.prepare_input(image).unsqueeze(0) 443 | 444 | if self.device.type != "mps": 445 | image_tensors = image_tensors.to(next(self.parameters()).dtype) 446 | 447 | image_tensors = image_tensors.to(self.device) 448 | prompt_ids = prompt_ids.to(self.device) 449 | last_hidden_state = self.vpm(image_tensors, text_embedding=self.get_input_embeddings(prompt_ids)) 450 | 451 | encoder_outputs = ModelOutput(last_hidden_state=last_hidden_state, attentions=None) 452 | if len(encoder_outputs.last_hidden_state.size()) == 1: 453 | encoder_outputs.last_hidden_state = encoder_outputs.last_hidden_state.unsqueeze(0) 454 | 455 | # get decoder output 456 | decoder_output = self.llm.model.generate( 457 | input_ids=prompt_ids, 458 | encoder_outputs=encoder_outputs, 459 | min_length=1, 460 | max_length=self.config.max_length, 461 | pad_token_id=self.llm.tokenizer.pad_token_id, 462 | eos_token_id=self.llm.tokenizer.eos_token_id, 463 | use_cache=True, 464 | return_dict_in_generate=True, 465 | output_scores=True, 466 | output_attentions=return_attentions, 467 | do_sample=False, 468 | num_beams=1, 469 | stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()] if early_stopping else []), 470 | ) 471 | 472 | output["repetitions"] = decoder_output.sequences.clone() 473 | output["sequences"] = decoder_output.sequences.clone() 474 | output["scores"] = torch.stack(decoder_output.scores, 1).softmax(-1).cpu().max(-1)[0] 475 | 476 | output["repetitions"] = self.llm.tokenizer.batch_decode(output["repetitions"], skip_special_tokens=False) 477 | return output 478 | -------------------------------------------------------------------------------- /utils/processor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates 3 | SPDX-License-Identifier: MIT 4 | """ 5 | 6 | import numpy as np 7 | import torch 8 | from PIL import ImageOps 9 | from torchvision import transforms 10 | from torchvision.transforms.functional import resize 11 | 12 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 13 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 14 | 15 | 16 | class DolphinProcessor: 17 | def __init__( 18 | self, 19 | dp_config, 20 | tokenizer, 21 | **kwargs, 22 | ) -> None: 23 | 24 | self.tokenizer = tokenizer 25 | transform_args = kwargs.get("transform_args", {}) 26 | self.max_length = transform_args.get("max_length", 2048) 27 | self.input_size = transform_args.get("input_size", [896, 896]) # height, width 28 | if isinstance(self.input_size, int): 29 | self.input_size = [self.input_size, self.input_size] 30 | 31 | try: 32 | self.answer_start_token = self.tokenizer._prompt_end_token 33 | except AttributeError as err: 34 | print('No answer_start_token found, use "" instead') 35 | self.answer_start_token = "" 36 | 37 | self.prefix_answer_space_flag = dp_config.get("prefix_answer_space_flag", True) 38 | self.suffix_prompt_space_flag = dp_config.get("suffix_prompt_space_flag", True) 39 | 40 | self.transform = transforms.Compose( 41 | [transforms.ToTensor(), transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)] 42 | ) 43 | 44 | def process_prompt_for_inference(self, prompt): 45 | prompt = prompt.replace("<image>\n", "") 46 | if not prompt.startswith("<s>"): 47 | prompt = "<s>" + prompt 48 | message_ids = [self.tokenizer.encode(prompt, add_special_tokens=False)] 49 | ids = torch.from_numpy(np.hstack(message_ids, dtype=np.int32)) 50 | return ids.unsqueeze(0) 51 | 52 | def process_image_for_inference(self, image, return_img_size=False): 53 | image = resize(image, min(self.input_size)) 54 | 55 | image.thumbnail((self.input_size[1], self.input_size[0])) 56 | origin_w, origin_h = image.size 57 | 58 | delta_width = self.input_size[1] - image.width 59 | delta_height = self.input_size[0] - image.height 60 | pad_width = delta_width // 2 61 | pad_height = delta_height // 2 62 | padding = ( 63 | pad_width, 64 | pad_height, 65 | delta_width - pad_width, 66 | delta_height - pad_height, 67 | ) 68 | image = ImageOps.expand(image, padding) 69 | if return_img_size: 70 | return self.transform(image).unsqueeze(0), (origin_w, origin_h) 71 | return self.transform(image).unsqueeze(0) 72 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates 3 | SPDX-License-Identifier: MIT 4 | """ 5 | 6 | import copy 7 | import io 8 | import json 9 | import os 10 | import re 11 | from dataclasses import dataclass 12 | from typing import List, Tuple 13 | 14 | import cv2 15 | import numpy as np 16 | import pymupdf 17 | from PIL import Image 18 | 19 | from utils.markdown_utils import MarkdownConverter 20 | 21 | 22 | def save_figure_to_local(pil_crop, save_dir, image_name, reading_order): 23 | """Save cropped figure to local file system 24 | 25 | Args: 26 | pil_crop: PIL Image object of the cropped figure 27 | save_dir: Base directory to save results 28 | image_name: Name of the source image/document 29 | reading_order: Reading order of the figure in the document 30 | 31 | Returns: 32 | str: Filename of the saved figure 33 | """ 34 | try: 35 | # Create figures directory if it doesn't exist 36 | figures_dir = os.path.join(save_dir, "markdown", "figures") 37 | # os.makedirs(figures_dir, exist_ok=True) 38 | 39 | # Generate figure filename 40 | figure_filename = f"{image_name}_figure_{reading_order:03d}.png" 41 | figure_path = os.path.join(figures_dir, figure_filename) 42 | 43 | # Save the figure 44 | pil_crop.save(figure_path, format="PNG", quality=95) 45 | 46 | # print(f"Saved figure: {figure_filename}") 47 | return figure_filename 48 | 49 | except Exception as e: 50 | print(f"Error saving figure: {str(e)}") 51 | # Return a fallback filename 52 | return f"{image_name}_figure_{reading_order:03d}_error.png" 53 | 54 | 55 | def convert_pdf_to_images(pdf_path, target_size=896): 56 | """Convert PDF pages to images 57 | 58 | Args: 59 | pdf_path: Path to PDF file 60 | target_size: Target size for the longest dimension 61 | 62 | Returns: 63 | List of PIL Images 64 | """ 65 | images = [] 66 | try: 67 | doc = pymupdf.open(pdf_path) 68 | 69 | for page_num in range(len(doc)): 70 | page = doc[page_num] 71 | 72 | # Calculate scale to make longest dimension equal to target_size 73 | rect = page.rect 74 | scale = target_size / max(rect.width, rect.height) 75 | 76 | # Render page as image 77 | mat = pymupdf.Matrix(scale, scale) 78 | pix = page.get_pixmap(matrix=mat) 79 | 80 | # Convert to PIL Image 81 | img_data = pix.tobytes("png") 82 | pil_image = Image.open(io.BytesIO(img_data)) 83 | images.append(pil_image) 84 | 85 | doc.close() 86 | print(f"Successfully converted {len(images)} pages from PDF") 87 | return images 88 | 89 | except Exception as e: 90 | print(f"Error converting PDF to images: {str(e)}") 91 | return [] 92 | 93 | 94 | def is_pdf_file(file_path): 95 | """Check if file is a PDF""" 96 | return file_path.lower().endswith(".pdf") 97 | 98 | 99 | def save_combined_pdf_results(all_page_results, pdf_path, save_dir): 100 | """Save combined results for multi-page PDF with both JSON and Markdown 101 | 102 | Args: 103 | all_page_results: List of results for all pages 104 | pdf_path: Path to original PDF file 105 | save_dir: Directory to save results 106 | 107 | Returns: 108 | Path to saved combined JSON file 109 | """ 110 | # Create output filename based on PDF name 111 | base_name = os.path.splitext(os.path.basename(pdf_path))[0] 112 | 113 | # Prepare combined results 114 | combined_results = {"source_file": pdf_path, "total_pages": len(all_page_results), "pages": all_page_results} 115 | 116 | # Save combined JSON results 117 | json_filename = f"{base_name}.json" 118 | json_path = os.path.join(save_dir, "recognition_json", json_filename) 119 | os.makedirs(os.path.dirname(json_path), exist_ok=True) 120 | 121 | with open(json_path, "w", 
encoding="utf-8") as f: 122 | json.dump(combined_results, f, indent=2, ensure_ascii=False) 123 | 124 | # Generate and save combined markdown 125 | try: 126 | markdown_converter = MarkdownConverter() 127 | 128 | # Combine all page results into a single list for markdown conversion 129 | all_elements = [] 130 | for page_data in all_page_results: 131 | page_elements = page_data.get("elements", []) 132 | if page_elements: 133 | # Add page separator if not the first page 134 | if all_elements: 135 | all_elements.append( 136 | {"label": "page_separator", "text": f"\n\n---\n\n", "reading_order": len(all_elements)} 137 | ) 138 | all_elements.extend(page_elements) 139 | 140 | # Generate markdown content 141 | markdown_content = markdown_converter.convert(all_elements) 142 | 143 | # Save markdown file 144 | markdown_filename = f"{base_name}.md" 145 | markdown_path = os.path.join(save_dir, "markdown", markdown_filename) 146 | os.makedirs(os.path.dirname(markdown_path), exist_ok=True) 147 | 148 | with open(markdown_path, "w", encoding="utf-8") as f: 149 | f.write(markdown_content) 150 | 151 | # print(f"Combined markdown saved to: {markdown_path}") 152 | 153 | except ImportError: 154 | print("MarkdownConverter not available, skipping markdown generation") 155 | except Exception as e: 156 | print(f"Error generating markdown: {e}") 157 | 158 | # print(f"Combined JSON results saved to: {json_path}") 159 | return json_path 160 | 161 | 162 | def check_coord_valid(x1, y1, x2, y2, image_size=None, abs_coord=True): 163 | # print(f"check_coord_valid: {x1}, {y1}, {x2}, {y2}, {image_size}, {abs_coord}") 164 | if x2 <= x1 or y2 <= y1: 165 | return False, f"[{x1}, {y1}, {x2}, {y2}]" 166 | if x1 < 0 or y1 < 0: 167 | return False, f"[{x1}, {y1}, {x2}, {y2}]" 168 | if not abs_coord: 169 | if x2 > 1 or y2 > 1: 170 | return False, f"[{x1}, {y1}, {x2}, {y2}]" 171 | elif image_size is not None: # has image size 172 | if x2 > image_size[0] or y2 > image_size[1]: 173 | return False, f"[{x1}, {y1}, {x2}, {y2}]" 174 | return True, None 175 | 176 | 177 | def adjust_box_edges(image, boxes: List[List[float]], max_pixels=15, threshold=0.2): 178 | """ 179 | Image: cv2.image object, or Path 180 | Input: boxes: list of boxes [[x1, y1, x2, y2]]. Using absolute coordinates. 
181 | """ 182 | if isinstance(image, str): 183 | image = cv2.imread(image) 184 | img_h, img_w = image.shape[:2] 185 | new_boxes = [] 186 | for box in boxes: 187 | best_box = copy.deepcopy(box) 188 | 189 | def check_edge(img, current_box, i, is_vertical): 190 | edge = current_box[i] 191 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 192 | _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) 193 | 194 | if is_vertical: 195 | line = binary[current_box[1] : current_box[3] + 1, edge] 196 | else: 197 | line = binary[edge, current_box[0] : current_box[2] + 1] 198 | 199 | transitions = np.abs(np.diff(line)) 200 | return np.sum(transitions) / len(transitions) 201 | 202 | # Only widen the box 203 | edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)] 204 | 205 | current_box = copy.deepcopy(box) 206 | # make sure the box is within the image 207 | current_box[0] = min(max(current_box[0], 0), img_w - 1) 208 | current_box[1] = min(max(current_box[1], 0), img_h - 1) 209 | current_box[2] = min(max(current_box[2], 0), img_w - 1) 210 | current_box[3] = min(max(current_box[3], 0), img_h - 1) 211 | 212 | for i, direction, is_vertical in edges: 213 | best_score = check_edge(image, current_box, i, is_vertical) 214 | if best_score <= threshold: 215 | continue 216 | for step in range(max_pixels): 217 | current_box[i] += direction 218 | if i == 0 or i == 2: 219 | current_box[i] = min(max(current_box[i], 0), img_w - 1) 220 | else: 221 | current_box[i] = min(max(current_box[i], 0), img_h - 1) 222 | score = check_edge(image, current_box, i, is_vertical) 223 | 224 | if score < best_score: 225 | best_score = score 226 | best_box = copy.deepcopy(current_box) 227 | 228 | if score <= threshold: 229 | break 230 | new_boxes.append(best_box) 231 | 232 | return new_boxes 233 | 234 | 235 | def parse_layout_string(bbox_str): 236 | """Parse layout string using regular expressions""" 237 | pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)" 238 | matches = re.finditer(pattern, bbox_str) 239 | 240 | parsed_results = [] 241 | for match in matches: 242 | coords = [float(match.group(i)) for i in range(1, 5)] 243 | label = match.group(5).strip() 244 | parsed_results.append((coords, label)) 245 | 246 | return parsed_results 247 | 248 | 249 | @dataclass 250 | class ImageDimensions: 251 | """Class to store image dimensions""" 252 | 253 | original_w: int 254 | original_h: int 255 | padded_w: int 256 | padded_h: int 257 | 258 | 259 | def map_to_original_coordinates(x1, y1, x2, y2, dims: ImageDimensions) -> Tuple[int, int, int, int]: 260 | """Map coordinates from padded image back to original image 261 | 262 | Args: 263 | x1, y1, x2, y2: Coordinates in padded image 264 | dims: Image dimensions object 265 | 266 | Returns: 267 | tuple: (x1, y1, x2, y2) coordinates in original image 268 | """ 269 | try: 270 | # Calculate padding offsets 271 | top = (dims.padded_h - dims.original_h) // 2 272 | left = (dims.padded_w - dims.original_w) // 2 273 | 274 | # Map back to original coordinates 275 | orig_x1 = max(0, x1 - left) 276 | orig_y1 = max(0, y1 - top) 277 | orig_x2 = min(dims.original_w, x2 - left) 278 | orig_y2 = min(dims.original_h, y2 - top) 279 | 280 | # Ensure we have a valid box (width and height > 0) 281 | if orig_x2 <= orig_x1: 282 | orig_x2 = min(orig_x1 + 1, dims.original_w) 283 | if orig_y2 <= orig_y1: 284 | orig_y2 = min(orig_y1 + 1, dims.original_h) 285 | 286 | return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2) 287 | except Exception as e: 288 | 
print(f"map_to_original_coordinates error: {str(e)}") 289 | # Return safe coordinates 290 | return 0, 0, min(100, dims.original_w), min(100, dims.original_h) 291 | 292 | 293 | def map_to_relevant_coordinates(abs_coords, dims: ImageDimensions): 294 | """ 295 | From absolute coordinates to relevant coordinates 296 | e.g. [100, 100, 200, 200] -> [0.1, 0.2, 0.3, 0.4] 297 | """ 298 | try: 299 | x1, y1, x2, y2 = abs_coords 300 | return ( 301 | round(x1 / dims.original_w, 3), 302 | round(y1 / dims.original_h, 3), 303 | round(x2 / dims.original_w, 3), 304 | round(y2 / dims.original_h, 3), 305 | ) 306 | except Exception as e: 307 | print(f"map_to_relevant_coordinates error: {str(e)}") 308 | return 0.0, 0.0, 1.0, 1.0 # Return full image coordinates 309 | 310 | 311 | def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None): 312 | """Process and adjust coordinates 313 | 314 | Args: 315 | coords: Normalized coordinates [x1, y1, x2, y2] 316 | padded_image: Padded image 317 | dims: Image dimensions object 318 | previous_box: Previous box coordinates for overlap adjustment 319 | 320 | Returns: 321 | tuple: (x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box) 322 | """ 323 | try: 324 | # Convert normalized coordinates to absolute coordinates 325 | x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h) 326 | x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h) 327 | 328 | # Ensure coordinates are within image bounds before adjustment 329 | x1 = max(0, min(x1, dims.padded_w - 1)) 330 | y1 = max(0, min(y1, dims.padded_h - 1)) 331 | x2 = max(0, min(x2, dims.padded_w)) 332 | y2 = max(0, min(y2, dims.padded_h)) 333 | 334 | # Ensure width and height are at least 1 pixel 335 | if x2 <= x1: 336 | x2 = min(x1 + 1, dims.padded_w) 337 | if y2 <= y1: 338 | y2 = min(y1 + 1, dims.padded_h) 339 | 340 | # Extend box boundaries 341 | new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, y2]]) 342 | x1, y1, x2, y2 = new_boxes[0] 343 | 344 | # Ensure coordinates are still within image bounds after adjustment 345 | x1 = max(0, min(x1, dims.padded_w - 1)) 346 | y1 = max(0, min(y1, dims.padded_h - 1)) 347 | x2 = max(0, min(x2, dims.padded_w)) 348 | y2 = max(0, min(y2, dims.padded_h)) 349 | 350 | # Ensure width and height are at least 1 pixel after adjustment 351 | if x2 <= x1: 352 | x2 = min(x1 + 1, dims.padded_w) 353 | if y2 <= y1: 354 | y2 = min(y1 + 1, dims.padded_h) 355 | 356 | # Check for overlap with previous box and adjust 357 | if previous_box is not None: 358 | prev_x1, prev_y1, prev_x2, prev_y2 = previous_box 359 | if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1): 360 | y1 = prev_y2 361 | # Ensure y1 is still valid 362 | y1 = min(y1, dims.padded_h - 1) 363 | # Make sure y2 is still greater than y1 364 | if y2 <= y1: 365 | y2 = min(y1 + 1, dims.padded_h) 366 | 367 | # Update previous box 368 | new_previous_box = [x1, y1, x2, y2] 369 | 370 | # Map to original coordinates 371 | orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates(x1, y1, x2, y2, dims) 372 | 373 | return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box 374 | except Exception as e: 375 | print(f"process_coordinates error: {str(e)}") 376 | # Return safe values 377 | orig_x1, orig_y1, orig_x2, orig_y2 = 0, 0, min(100, dims.original_w), min(100, dims.original_h) 378 | return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100] 379 | 380 | 381 | def prepare_image(image) -> Tuple[np.ndarray, 
381 | def prepare_image(image) -> Tuple[np.ndarray, ImageDimensions]:
382 |     """Load and prepare image with padding while maintaining aspect ratio
383 | 
384 |     Args:
385 |         image: PIL image
386 | 
387 |     Returns:
388 |         tuple: (padded_image, image_dimensions)
389 |     """
390 |     try:
391 |         # Convert PIL image to OpenCV format
392 |         image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
393 |         original_h, original_w = image.shape[:2]
394 | 
395 |         # Calculate padding to make square image
396 |         max_size = max(original_h, original_w)
397 |         top = (max_size - original_h) // 2
398 |         bottom = max_size - original_h - top
399 |         left = (max_size - original_w) // 2
400 |         right = max_size - original_w - left
401 | 
402 |         # Apply padding
403 |         padded_image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0))
404 | 
405 |         padded_h, padded_w = padded_image.shape[:2]
406 | 
407 |         dimensions = ImageDimensions(original_w=original_w, original_h=original_h, padded_w=padded_w, padded_h=padded_h)
408 | 
409 |         return padded_image, dimensions
410 |     except Exception as e:
411 |         print(f"prepare_image error: {str(e)}")
412 |         # Create a minimal valid image and dimensions
413 |         h, w = image.height, image.width
414 |         dimensions = ImageDimensions(original_w=w, original_h=h, padded_w=w, padded_h=h)
415 |         # Return a black image of the same size
416 |         return np.zeros((h, w, 3), dtype=np.uint8), dimensions
417 | 
418 | 
419 | def setup_output_dirs(save_dir):
420 |     """Create necessary output directories"""
421 |     os.makedirs(save_dir, exist_ok=True)
422 |     os.makedirs(os.path.join(save_dir, "markdown"), exist_ok=True)
423 |     os.makedirs(os.path.join(save_dir, "recognition_json"), exist_ok=True)
424 |     os.makedirs(os.path.join(save_dir, "markdown", "figures"), exist_ok=True)
425 | 
426 | 
427 | def save_outputs(recognition_results, image_path, save_dir):
428 |     """Save JSON and markdown outputs"""
429 |     basename = os.path.splitext(os.path.basename(image_path))[0]
430 | 
431 |     # Save JSON file
432 |     json_path = os.path.join(save_dir, "recognition_json", f"{basename}.json")
433 |     with open(json_path, "w", encoding="utf-8") as f:
434 |         json.dump(recognition_results, f, ensure_ascii=False, indent=2)
435 | 
436 |     # Generate and save markdown file
437 |     markdown_converter = MarkdownConverter()
438 |     markdown_content = markdown_converter.convert(recognition_results)
439 |     markdown_path = os.path.join(save_dir, "markdown", f"{basename}.md")
440 |     with open(markdown_path, "w", encoding="utf-8") as f:
441 |         f.write(markdown_content)
442 | 
443 |     return json_path
444 | 
445 | 
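# Worked example (hypothetical size, added as commentary): a 600x800 PIL page (width 600,
# height 800) becomes an 800x800 padded image with 100 px of black added on the left and
# right, and prepare_image returns dims = ImageDimensions(original_w=600, original_h=800,
# padded_w=800, padded_h=800); map_to_original_coordinates above undoes exactly this padding.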
446 | def crop_margin(img: Image.Image) -> Image.Image:
447 |     """Crop margins from image"""
448 |     try:
449 |         width, height = img.size
450 |         if width == 0 or height == 0:
451 |             print("Warning: Image has zero width or height")
452 |             return img
453 | 
454 |         data = np.array(img.convert("L"))
455 |         data = data.astype(np.uint8)
456 |         max_val = data.max()
457 |         min_val = data.min()
458 |         if max_val == min_val:
459 |             return img
460 |         data = (data - min_val) / (max_val - min_val) * 255
461 |         gray = 255 * (data < 200).astype(np.uint8)
462 | 
463 |         coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
464 |         if coords is None:
465 |             return img
466 |         a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
467 | 
468 |         # Ensure crop coordinates are within image bounds
469 |         a = max(0, a)
470 |         b = max(0, b)
471 |         w = min(w, width - a)
472 |         h = min(h, height - b)
473 | 
474 |         # Only crop if we have a valid region
475 |         if w > 0 and h > 0:
476 |             return img.crop((a, b, a + w, b + h))
477 |         return img
478 |     except Exception as e:
479 |         print(f"crop_margin error: {str(e)}")
480 |         return img  # Return original image on error
481 | 
--------------------------------------------------------------------------------
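Taken together, these helpers cover the pre- and post-processing around page parsing: pad the page to a square, parse the predicted layout string, snap and clip each element box, and map it back to original-page coordinates. The following is a minimal usage sketch, not code from the repository itself; it assumes the helpers are importable as utils.utils, and the image path and layout string are placeholders for caller-supplied inputs.

from PIL import Image

from utils.utils import parse_layout_string, prepare_image, process_coordinates

# Placeholder inputs: any page image and a layout string in the format parsed above.
page = Image.open("page.jpeg").convert("RGB")
layout_str = "[0.1, 0.1, 0.9, 0.3] text [0.1, 0.35, 0.9, 0.8] table"

padded_image, dims = prepare_image(page)

previous_box = None
for coords, label in parse_layout_string(layout_str):
    x1, y1, x2, y2, ox1, oy1, ox2, oy2, previous_box = process_coordinates(
        coords, padded_image, dims, previous_box
    )
    crop = padded_image[y1:y2, x1:x2]  # element crop from the padded page
    print(label, (ox1, oy1, ox2, oy2), crop.shape)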