├── .gitignore ├── LICENSE ├── README.md ├── docs ├── TENSORRT_GETTING_STARTED.md └── imgs │ └── output.png ├── requirements.txt ├── setup.py ├── struct_eqtable ├── __init__.py ├── internvl │ ├── __init__.py │ ├── conversation.py │ ├── internvl.py │ └── internvl_lmdeploy.py └── pix2s │ ├── __init__.py │ ├── pix2s.py │ └── pix2s_trt.py └── tools ├── demo ├── demo.png ├── demo.py └── demo.tex ├── scripts └── build_tensorrt.sh └── tensorrt_utils ├── build_visual_engine.py ├── convert_checkpoint.py └── helper.py /.gitignore: -------------------------------------------------------------------------------- 1 | dist/ 2 | build/ 3 | **.egg-info/ 4 | **__pycache__/ 5 | **.cache 6 | ckpts/ 7 | **version.py 8 | 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

StructEqTable-Deploy: A High-efficiency Open-source Toolkit for Table-to-Latex Transformation

3 | 4 | 5 | [[ Paper ]](https://arxiv.org/abs/2406.11633) [[ Website ]](https://unimodal4reasoning.github.io/DocGenome_page/) [[ Dataset🤗 ]](https://huggingface.co/datasets/U4R/DocGenome/tree/main) [[ Models🤗 ]](https://huggingface.co/U4R/StructTable-InternVL2-1B/tree/main) [[ Demo💬 ]](https://www.modelscope.cn/studios/HongbinZhou/StructEqTable-Demo/) 6 | 7 | 8 |
9 | 10 | Welcome to the official repository of StructEqTable-Deploy, a solution that converts images of Table into LaTeX/HTML/MarkDown, powered by scalable data from [DocGenome benchmark](https://unimodal4reasoning.github.io/DocGenome_page/). 11 | 12 | 13 | ## Overview 14 | Table is an effective way to represent structured data in scientific publications, financial statements, invoices, web pages, and many other scenarios. Extracting tabular data from a visual table image and performing the downstream reasoning tasks according to the extracted data is challenging, mainly due to that tables often present complicated column and row headers with spanning cell operation. To address these challenges, we present TableX, a large-scale multi-modal table benchmark extracted from [DocGenome benchmark](https://unimodal4reasoning.github.io/DocGenome_page/) for table pre-training, comprising more than 2 million high-quality Image-LaTeX pair data covering 156 disciplinary classes. Besides, benefiting from such large-scale data, we train an end-to-end model, StructEqTable, which provides the capability to precisely obtain the corresponding LaTeX description from a visual table image and perform multiple table-related reasoning tasks, including structural extraction and question answering, broadening its application scope and potential. 15 | 16 | ## Changelog 17 | - [2024/12/12] 🔥 We have released latest model **[StructTable-InternVL2-1B v0.2](https://huggingface.co/U4R/StructTable-InternVL2-1B/tree/main)** with enhanced recognition stability for HTML and Markdown formats! 18 | 19 | - [2024/10/19] We have released our latest model StructTable-InternVL2-1B! 20 | 21 | Thanks to IntenrVL2 powerful foundational capabilities, and through fine-tuning on the synthetic tabular data and DocGenome dataset, StructTable can convert table image into various common table formats including LaTeX, HTML, and Markdown. Moreover, inference speed has been significantly improved compared to the v0.2 version. 22 | - [2024/8/22] We have released our StructTable-base-v0.2, fine-tuned on the DocGenome dataset. This version features improved inference speed and robustness, achieved through data augmentation and reduced image token num. 23 | - [2024/8/08] We have released the TensorRT accelerated version, which only takes about 1 second for most images on GPU A100. Please follow the tutorial to install the environment and compile the model weights. 24 | - [2024/7/30] We have released the first version of StructEqTable. 25 | 26 | ## TODO 27 | 28 | - [x] Release inference code and checkpoints of StructEqTable. 29 | - [x] Support Chinese version of StructEqTable. 30 | - [x] Accelerated version of StructEqTable using TensorRT-LLM. 31 | - [x] Expand more domains of table image to improve the model's general capabilities. 32 | - [x] Efficient inference of StructTable-InternVL2-1B by [LMDeploy](https://github.com/InternLM/lmdeploy) Tookit. 33 | - [ ] Release our table pre-training and fine-tuning code 34 | 35 | 36 | ## Installation 37 | ``` bash 38 | conda create -n structeqtable python>=3.10 39 | conda activate structeqtable 40 | 41 | # Install from Source code (Suggested) 42 | git clone https://github.com/UniModal4Reasoning/StructEqTable-Deploy.git 43 | cd StructEqTable-Deploy 44 | pip install -r requirements.txt 45 | python setup develop 46 | 47 | # or Install from Github repo 48 | pip install "git+https://github.com/UniModal4Reasoning/StructEqTable-Deploy.git" 49 | 50 | # or Install from PyPI 51 | pip install struct-eqtable --upgrade 52 | ``` 53 | 54 | ## Model Zoo 55 | 56 | | Base Model | Model Size | Training Data | Data Augmentation | LMDeploy | TensorRT | HuggingFace | 57 | |---------------------|------------|------------------|-------------------|----------|----------|-------------------| 58 | | InternVL2-1B | ~1B | DocGenome and Synthetic Data | ✔ | ✔ | | [StructTable-InternVL2-1B v0.2](https://huggingface.co/U4R/StructTable-InternVL2-1B/tree/main) | 59 | | InternVL2-1B | ~1B | DocGenome and Synthetic Data | ✔ | ✔ | | [StructTable-InternVL2-1B v0.1](https://huggingface.co/U4R/StructTable-InternVL2-1B/tree/v0.1) | 60 | | Pix2Struct-base | ~300M | DocGenome | ✔ | | ✔ | [StructTable-base v0.2](https://huggingface.co/U4R/StructTable-base/tree/v0.2) | 61 | | Pix2Struct-base | ~300M | DocGenome | | | ✔ | [StructTable-base v0.1](https://huggingface.co/U4R/StructTable-base/tree/v0.1) | 62 | 63 | 64 | 65 | ## Quick Demo 66 | - Run the demo/demo.py 67 | ```shell script 68 | cd tools/demo 69 | 70 | python demo.py \ 71 | --image_path ./demo.png \ 72 | --ckpt_path U4R/StructTable-InternVL2-1B \ 73 | --output_format latex 74 | ``` 75 | 76 | - HTML or Markdown format output (Only Supported by StructTable-InternVL2-1B) 77 | 78 | ```shell script 79 | python demo.py \ 80 | --image_path ./demo.png \ 81 | --ckpt_path U4R/StructTable-InternVL2-1B \ 82 | --output_format html markdown 83 | ``` 84 | 85 | ## Efficient Inference 86 | - Install LMDeploy Tookit 87 | ```shell script 88 | pip install lmdeploy 89 | ``` 90 | 91 | - Run the demo/demo.py 92 | ```shell script 93 | cd tools/demo 94 | 95 | python demo.py \ 96 | --image_path ./demo.png \ 97 | --ckpt_path U4R/StructTable-InternVL2-1B \ 98 | --output_format latex \ 99 | --lmdeploy 100 | ``` 101 | 102 | 103 | - Visualization Result 104 | 105 | You can copy the output LaTeX code into [demo.tex](../tools/demo/demo.tex), then use [Overleaf](https://www.overleaf.com/project) for table visualization. 106 | ![](docs/imgs/output.png) 107 | 108 | 109 | ## Acknowledgements 110 | - [DocGenome](https://github.com/UniModal4Reasoning/DocGenome). An Open Large-scale Scientific Document Benchmark for Training and Testing Multi-modal Large Models. 111 | - [ChartVLM](https://github.com/UniModal4Reasoning/ChartVLM). A Versatile Benchmark and Foundation Model for Complicated Chart Reasoning. 112 | - [Pix2Struct](https://github.com/google-research/pix2struct). Screenshot Parsing as Pretraining for Visual Language Understanding. 113 | - [InternVL Family](https://github.com/OpenGVLab/InternVL). A Series of Powerful Foundational Vision-Language Models. 114 | - [LMDeploy](https://github.com/InternLM/lmdeploy). A toolkit for compressing, deploying, and serving LLM and MLLM. 115 | - [UniMERNet](https://github.com/opendatalab/UniMERNet). A Universal Network for Real-World Mathematical Expression Recognition. 116 | - [Donut](https://huggingface.co/naver-clova-ix/donut-base). The UniMERNet's Transformer Encoder-Decoder are referenced from Donut. 117 | - [Nougat](https://github.com/facebookresearch/nougat). Data Augmentation follows Nougat. 118 | - [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM). Model inference acceleration uses TensorRT-LLM. 119 | 120 | 121 | ## License 122 | StructEqTable is released under the [Apache License 2.0](LICENSE) 123 | 124 | ## Citation 125 | If you find our models / code / papers useful in your research, please consider giving ⭐ and citations 📝, thx :) 126 | ```bibtex 127 | @article{xia2024docgenome, 128 | title={DocGenome: An Open Large-scale Scientific Document Benchmark for Training and Testing Multi-modal Large Language Models}, 129 | author={Xia, Renqiu and Mao, Song and Yan, Xiangchao and Zhou, Hongbin and Zhang, Bo and Peng, Haoyang and Pi, Jiahao and Fu, Daocheng and Wu, Wenjie and Ye, Hancheng and others}, 130 | journal={arXiv preprint arXiv:2406.11633}, 131 | year={2024} 132 | } 133 | ``` 134 | 135 | ## Contact Us 136 | If you encounter any issues or have questions, please feel free to contact us via zhouhongbin@pjlab.org.cn. 137 | -------------------------------------------------------------------------------- /docs/TENSORRT_GETTING_STARTED.md: -------------------------------------------------------------------------------- 1 | # Getting Started 2 | [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) is used for model inference speeding up. 3 | 4 | All the codes are successfully tested in the following enviroments: 5 | * Linux (18.04, 20.04, 22.04) 6 | * Python 3.10 7 | * Pytorch 2.0 or higher 8 | * CUDA 12.1 or higher 9 | * TensorRT-LLM 0.11.0 (stable version) 10 | 11 | ### 1. Conda or Python Environment Preparation 12 | 13 | 14 | * Please follow the step 1, 2 from the [official tutorial](https://nvidia.github.io/TensorRT-LLM/installation/linux.html) of TensorRT-LLM to install the environment. 15 | 16 | Note we used the TensorRT-LLM **stable version `0.11.0`**. 17 | ``` bash 18 | # Installing on Linux 19 | Step 1. Retrieve and launch the docker container (optional). 20 | 21 | You can pre-install the environment using the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit) to avoid manual environment configuration. 22 | 23 | ```bash 24 | # Obtain and start the basic docker image environment (optional). 25 | docker run --rm --ipc=host --runtime=nvidia --gpus all --entrypoint /bin/bash -it nvidia/cuda:12.4.1-devel-ubuntu22.04 26 | ``` 27 | Note: please make sure to set `--ipc=host` as a docker run argument to avoid `Bus error (core dumped)`. 28 | 29 | Step 2. Install TensorRT-LLM. 30 | 31 | ```bash 32 | # Install dependencies, TensorRT-LLM requires Python 3.10 33 | apt-get update && apt-get -y install python3.10 python3-pip openmpi-bin libopenmpi-dev git git-lfs 34 | 35 | # Install the latest preview version (corresponding to the main branch) of TensorRT-LLM. 36 | # If you want to install the stable version (corresponding to the release branch), please 37 | # remove the `--pre` option. 38 | pip3 install tensorrt_llm==0.11.0 --extra-index-url https://pypi.nvidia.com 39 | 40 | # Check installation 41 | python3 -c "import tensorrt_llm" 42 | ``` 43 | 44 | Please note that TensorRT-LLM depends on TensorRT. In earlier versions that include TensorRT 8, 45 | overwriting an upgraded to a new version may require explicitly running `pip uninstall tensorrt` 46 | to uninstall the old version. 47 | ``` 48 | * Once you successfully execute `python3 -c "import tensorrt_llm"`, it means that you have completed Environment Preparation. 49 | 50 | Tips: If you want to install the environment manually, please note that the version of Python require >= 3.10 51 | 52 | 53 | ### 2. Model Compilation 54 | You can refer to the [official tutorial](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html) to complete the model compilation, or follow our instructions and use the provided scripts to implement it. 55 | 56 | #### 2.1 Download [StructEqTable checkpoints](https://huggingface.co/U4R/StructTable-base/tree/v0.2) 57 | ``` 58 | cd StructEqTable-Deploy 59 | 60 | # using huggingface-cli download checkpoint 61 | huggingface-cli download --resume-download --local-dir-use-symlinks False U4R/StructTable-base --local-dir ckpts/StructTable-base 62 | 63 | ``` 64 | After above steps, the files to directory of StructEqTable-Deploy as follows: 65 | ``` 66 | StructEqTable-Deploy 67 | ├── ckpts 68 | │ ├── StructTable-base 69 | ├── docs 70 | ├── struct_eqtable 71 | ├── tools 72 | ``` 73 | 74 | #### 2.2 Convert Checkpoint and Build Engine 75 | We provide a script to help users quickly implement model compilation. 76 | 77 | ``` bash 78 | cd StructEqTable-Deploy/tools 79 | # execute the script to quickly compile the model. 80 | bash scripts/build_tensorrt.sh 81 | ``` 82 | After the script runs successfully, the built models can be found in `ckpts/StructTable-base-TensorRT`. 83 | The file structure in the path `ckpts/StructTable-base-TensorRT` should be as follows: 84 | ``` 85 | ckpts 86 | ├── StructTable-base 87 | ├── StructTable-base-TensorRT 88 | │ ├── trt_engines 89 | │ ├── trt_models 90 | │ ├── visual_engiens 91 | ``` 92 | 93 | #### 2.3 Run Quickly Demo 94 | Run the demo/demo.py with TensorRT mode. 95 | 96 | ``` bash 97 | cd StructEqTable-Deploy/tools/demo 98 | 99 | python demo.py \ 100 | --image_path ./demo.png \ 101 | --ckpt_path ../../ckpts/StructTable-base \ 102 | --output_format latex 103 | --tensorrt ../../ckpts/StructTable-base-TensorRT 104 | ``` 105 | 106 | You may get output as follows: 107 | ``` 108 | total cost time: 0.88s 109 | Table 0 LATEX format output: 110 | \begin{tabular}{|c|c|c|c|} 111 | \hline 112 | Quantity $\backslash$ Unit System & International System SI (kg-m-s) & Traditional aeronautical (lb-ft-s) & Traditional structural (lb-inch-s) \\ 113 | \hline 114 | Mass (translational inertia), $m$ & kilogram mass (kg) & slug = lb-s$^2$/f & lb-s$^2$/inch \\ 115 | \hline 116 | Length, translational motion & meter (m) & foot (ft) & inch (in.) \\ 117 | \hline 118 | Time, $t$ & second (s) & second (s) & second (s) \\ 119 | \hline 120 | Force, translational action & newton (N) = kg-m/s$^2$ & pound force (lb) & pound force (lb) \\ 121 | \hline 122 | Translational stiffness constant, $k$ & N/m & lb/ft & lb/inch \\ 123 | \hline 124 | Translational damping constant, $c$ & N/(m/s) = N-s/m & lb/(ft/s) = lb-s/ft & lb/(inch/s) = lb-s/inch \\ 125 | \hline 126 | Angle, rotational motion & radial (rad), which is dimensionless & radial (rad), which is dimensionless & radial (rad), which is dimensionless \\ 127 | \hline 128 | Rotational inertia, $J$ & kg-m$^2$ & slug-ft$^2$ = lb-s$^2$ - ft & lb-s$^2$ - inch \\ 129 | \hline 130 | Moment or torque, rotational action & N-m & lb-ft & lb-inch \\ 131 | \hline 132 | Rotational stiffness constant, $k_\theta$ & (N-m)/rad = N-m & (lb-ft)/rad = lb-ft & (lb-inch)/rad = lb-inch \\ 133 | \hline 134 | Rotational damping constant, $c_\theta$ & (N-m)/(rad/s) = N-m-s & (lb-ft)/(rad/s) = lb-ft-s & (lb-inch)/(rad/s) = lb-inch-s \\ 135 | \hline 136 | \end{tabular} 137 | ``` 138 | 139 | 140 | ### 3. Table Visualization 141 | You can copy the output LaTeX code into [demo.tex](../tools/demo/demo.tex), then use [Overleaf](https://www.overleaf.com/project) or Visual Studio Code LaTeX Workshop Extension for table visualization. 142 | 143 | ![](./imgs/demo.png) -------------------------------------------------------------------------------- /docs/imgs/output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-Innovator/StructEqTable-Deploy/55649befcf880bc4fbf229a7b685c09c962d0dea/docs/imgs/output.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | transformers<=4.47 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from setuptools import find_packages, setup 3 | 4 | 5 | def write_version_to_file(version, target_file): 6 | with open(target_file, 'w') as f: 7 | print('__version__ = "%s"' % version, file=f) 8 | 9 | if __name__ == '__main__': 10 | version = '0.3.3' 11 | write_version_to_file(version, 'struct_eqtable/version.py') 12 | with Path(Path(__file__).parent, 13 | 'README.md').open(encoding='utf-8') as file: 14 | long_description = file.read() 15 | setup( 16 | name='struct_eqtable', 17 | version=version, 18 | description='A High-efficiency Open-source Toolkit for Table-to-Latex Transformation', 19 | long_description=long_description, 20 | long_description_content_type="text/markdown", 21 | install_requires=[ 22 | 'torch', 23 | 'transformers<=4.47', 24 | ], 25 | python_requires=">=3.9", 26 | author='Hongbin Zhou, Xiangchao Yan, Bo Zhang', 27 | author_email='zhangbo@pjlab.org.cn', 28 | url="https://github.com/UniModal4Reasoning/StructEqTable-Deploy", 29 | license='Apache License 2.0', 30 | packages=find_packages(exclude=['demo']), 31 | ) 32 | -------------------------------------------------------------------------------- /struct_eqtable/__init__.py: -------------------------------------------------------------------------------- 1 | from .pix2s import Pix2Struct, Pix2StructTensorRT 2 | from .internvl import InternVL, InternVL_LMDeploy 3 | 4 | from transformers import AutoConfig 5 | 6 | 7 | __ALL_MODELS__ = { 8 | 'Pix2Struct': Pix2Struct, 9 | 'Pix2StructTensorRT': Pix2StructTensorRT, 10 | 'InternVL': InternVL, 11 | 'InternVL_LMDeploy': InternVL_LMDeploy, 12 | } 13 | 14 | 15 | def get_model_name(model_path): 16 | model_config = AutoConfig.from_pretrained( 17 | model_path, 18 | trust_remote_code=True, 19 | ) 20 | 21 | if 'Pix2Struct' in model_config.architectures[0]: 22 | model_name = 'Pix2Struct' 23 | elif 'InternVL' in model_config.architectures[0]: 24 | model_name = 'InternVL' 25 | else: 26 | raise ValueError(f"Unsupported model type: {model_config.architectures[0]}") 27 | 28 | return model_name 29 | 30 | 31 | def build_model(model_ckpt='U4R/StructTable-InternVL2-1B', **kwargs): 32 | model_name = get_model_name(model_ckpt) 33 | if model_name == 'InternVL' and kwargs.get('lmdeploy', False): 34 | model_name = 'InternVL_LMDeploy' 35 | elif model_name == 'Pix2Struct' and kwargs.get('tensorrt_path', None): 36 | model_name = 'Pix2StructTensorRT' 37 | 38 | model = __ALL_MODELS__[model_name]( 39 | model_ckpt, 40 | **kwargs 41 | ) 42 | 43 | return model -------------------------------------------------------------------------------- /struct_eqtable/internvl/__init__.py: -------------------------------------------------------------------------------- 1 | from .internvl import InternVL 2 | from .internvl_lmdeploy import InternVL_LMDeploy -------------------------------------------------------------------------------- /struct_eqtable/internvl/conversation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Conversation prompt templates. 3 | 4 | We kindly request that you import fastchat instead of copying this file if you wish to use it. 5 | If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates. 6 | """ 7 | 8 | import dataclasses 9 | from enum import IntEnum, auto 10 | from typing import Any, Dict, List, Tuple, Union 11 | 12 | 13 | class SeparatorStyle(IntEnum): 14 | """Separator styles.""" 15 | 16 | ADD_COLON_SINGLE = auto() 17 | ADD_COLON_TWO = auto() 18 | ADD_COLON_SPACE_SINGLE = auto() 19 | NO_COLON_SINGLE = auto() 20 | NO_COLON_TWO = auto() 21 | ADD_NEW_LINE_SINGLE = auto() 22 | LLAMA2 = auto() 23 | CHATGLM = auto() 24 | CHATML = auto() 25 | CHATINTERN = auto() 26 | DOLLY = auto() 27 | RWKV = auto() 28 | PHOENIX = auto() 29 | ROBIN = auto() 30 | FALCON_CHAT = auto() 31 | CHATGLM3 = auto() 32 | INTERNVL_ZH = auto() 33 | MPT = auto() 34 | 35 | 36 | @dataclasses.dataclass 37 | class Conversation: 38 | """A class that manages prompt templates and keeps all conversation history.""" 39 | 40 | # The name of this template 41 | name: str 42 | # The template of the system prompt 43 | system_template: str = '{system_message}' 44 | # The system message 45 | system_message: str = '' 46 | # The names of two roles 47 | roles: Tuple[str] = ('USER', 'ASSISTANT') 48 | # All messages. Each item is (role, message). 49 | messages: List[List[str]] = () 50 | # The number of few shot examples 51 | offset: int = 0 52 | # The separator style and configurations 53 | sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE 54 | sep: str = '\n' 55 | sep2: str = None 56 | # Stop criteria (the default one is EOS token) 57 | stop_str: Union[str, List[str]] = None 58 | # Stops generation if meeting any token in this list 59 | stop_token_ids: List[int] = None 60 | 61 | def get_prompt(self) -> str: 62 | """Get the prompt for generation.""" 63 | system_prompt = self.system_template.format(system_message=self.system_message) 64 | if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: 65 | ret = system_prompt + self.sep 66 | for role, message in self.messages: 67 | if message: 68 | ret += role + ': ' + message + self.sep 69 | else: 70 | ret += role + ':' 71 | return ret 72 | elif self.sep_style == SeparatorStyle.ADD_COLON_TWO: 73 | seps = [self.sep, self.sep2] 74 | ret = system_prompt + seps[0] 75 | for i, (role, message) in enumerate(self.messages): 76 | if message: 77 | ret += role + ': ' + message + seps[i % 2] 78 | else: 79 | ret += role + ':' 80 | return ret 81 | elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: 82 | ret = system_prompt + self.sep 83 | for role, message in self.messages: 84 | if message: 85 | ret += role + ': ' + message + self.sep 86 | else: 87 | ret += role + ': ' # must be end with a space 88 | return ret 89 | elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: 90 | ret = '' if system_prompt == '' else system_prompt + self.sep 91 | for role, message in self.messages: 92 | if message: 93 | ret += role + '\n' + message + self.sep 94 | else: 95 | ret += role + '\n' 96 | return ret 97 | elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE: 98 | ret = system_prompt 99 | for role, message in self.messages: 100 | if message: 101 | ret += role + message + self.sep 102 | else: 103 | ret += role 104 | return ret 105 | elif self.sep_style == SeparatorStyle.NO_COLON_TWO: 106 | seps = [self.sep, self.sep2] 107 | ret = system_prompt 108 | for i, (role, message) in enumerate(self.messages): 109 | if message: 110 | ret += role + message + seps[i % 2] 111 | else: 112 | ret += role 113 | return ret 114 | elif self.sep_style == SeparatorStyle.RWKV: 115 | ret = system_prompt 116 | for i, (role, message) in enumerate(self.messages): 117 | if message: 118 | ret += ( 119 | role 120 | + ': ' 121 | + message.replace('\r\n', '\n').replace('\n\n', '\n') 122 | ) 123 | ret += '\n\n' 124 | else: 125 | ret += role + ':' 126 | return ret 127 | elif self.sep_style == SeparatorStyle.LLAMA2: 128 | seps = [self.sep, self.sep2] 129 | if self.system_message: 130 | ret = system_prompt 131 | else: 132 | ret = '[INST] ' 133 | for i, (role, message) in enumerate(self.messages): 134 | tag = self.roles[i % 2] 135 | if message: 136 | if i == 0: 137 | ret += message + ' ' 138 | else: 139 | ret += tag + ' ' + message + seps[i % 2] 140 | else: 141 | ret += tag 142 | return ret 143 | elif self.sep_style == SeparatorStyle.CHATGLM: 144 | # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 145 | # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 146 | round_add_n = 1 if self.name == 'chatglm2' else 0 147 | if system_prompt: 148 | ret = system_prompt + self.sep 149 | else: 150 | ret = '' 151 | 152 | for i, (role, message) in enumerate(self.messages): 153 | if i % 2 == 0: 154 | ret += f'[Round {i//2 + round_add_n}]{self.sep}' 155 | 156 | if message: 157 | ret += f'{role}:{message}{self.sep}' 158 | else: 159 | ret += f'{role}:' 160 | return ret 161 | elif self.sep_style == SeparatorStyle.CHATML: 162 | ret = '' if system_prompt == '' else system_prompt + self.sep + '\n' 163 | for role, message in self.messages: 164 | if message: 165 | ret += role + '\n' + message + self.sep + '\n' 166 | else: 167 | ret += role + '\n' 168 | return ret 169 | elif self.sep_style == SeparatorStyle.CHATGLM3: 170 | ret = '' 171 | if self.system_message: 172 | ret += system_prompt 173 | for role, message in self.messages: 174 | if message: 175 | ret += role + '\n' + ' ' + message 176 | else: 177 | ret += role 178 | return ret 179 | elif self.sep_style == SeparatorStyle.CHATINTERN: 180 | # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771 181 | seps = [self.sep, self.sep2] 182 | ret = system_prompt 183 | for i, (role, message) in enumerate(self.messages): 184 | # if i % 2 == 0: 185 | # ret += "" 186 | if message: 187 | ret += role + ':' + message + seps[i % 2] + '\n' 188 | else: 189 | ret += role + ':' 190 | return ret 191 | elif self.sep_style == SeparatorStyle.DOLLY: 192 | seps = [self.sep, self.sep2] 193 | ret = system_prompt 194 | for i, (role, message) in enumerate(self.messages): 195 | if message: 196 | ret += role + ':\n' + message + seps[i % 2] 197 | if i % 2 == 1: 198 | ret += '\n\n' 199 | else: 200 | ret += role + ':\n' 201 | return ret 202 | elif self.sep_style == SeparatorStyle.PHOENIX: 203 | ret = system_prompt 204 | for role, message in self.messages: 205 | if message: 206 | ret += role + ': ' + '' + message + '' 207 | else: 208 | ret += role + ': ' + '' 209 | return ret 210 | elif self.sep_style == SeparatorStyle.ROBIN: 211 | ret = system_prompt + self.sep 212 | for role, message in self.messages: 213 | if message: 214 | ret += role + ':\n' + message + self.sep 215 | else: 216 | ret += role + ':\n' 217 | return ret 218 | elif self.sep_style == SeparatorStyle.FALCON_CHAT: 219 | ret = '' 220 | if self.system_message: 221 | ret += system_prompt + self.sep 222 | for role, message in self.messages: 223 | if message: 224 | ret += role + ': ' + message + self.sep 225 | else: 226 | ret += role + ':' 227 | 228 | return ret 229 | elif self.sep_style == SeparatorStyle.INTERNVL_ZH: 230 | seps = [self.sep, self.sep2] 231 | ret = self.system_message + seps[0] 232 | for i, (role, message) in enumerate(self.messages): 233 | if message: 234 | ret += role + ': ' + message + seps[i % 2] 235 | else: 236 | ret += role + ':' 237 | return ret 238 | elif self.sep_style == SeparatorStyle.MPT: 239 | ret = system_prompt + self.sep 240 | for role, message in self.messages: 241 | if message: 242 | if type(message) is tuple: 243 | message, _, _ = message 244 | ret += role + message + self.sep 245 | else: 246 | ret += role 247 | return ret 248 | else: 249 | raise ValueError(f'Invalid style: {self.sep_style}') 250 | 251 | def set_system_message(self, system_message: str): 252 | """Set the system message.""" 253 | self.system_message = system_message 254 | 255 | def append_message(self, role: str, message: str): 256 | """Append a new message.""" 257 | self.messages.append([role, message]) 258 | 259 | def update_last_message(self, message: str): 260 | """Update the last output. 261 | 262 | The last message is typically set to be None when constructing the prompt, 263 | so we need to update it in-place after getting the response from a model. 264 | """ 265 | self.messages[-1][1] = message 266 | 267 | def to_gradio_chatbot(self): 268 | """Convert the conversation to gradio chatbot format.""" 269 | ret = [] 270 | for i, (role, msg) in enumerate(self.messages[self.offset :]): 271 | if i % 2 == 0: 272 | ret.append([msg, None]) 273 | else: 274 | ret[-1][-1] = msg 275 | return ret 276 | 277 | def to_openai_api_messages(self): 278 | """Convert the conversation to OpenAI chat completion format.""" 279 | ret = [{'role': 'system', 'content': self.system_message}] 280 | 281 | for i, (_, msg) in enumerate(self.messages[self.offset :]): 282 | if i % 2 == 0: 283 | ret.append({'role': 'user', 'content': msg}) 284 | else: 285 | if msg is not None: 286 | ret.append({'role': 'assistant', 'content': msg}) 287 | return ret 288 | 289 | def copy(self): 290 | return Conversation( 291 | name=self.name, 292 | system_template=self.system_template, 293 | system_message=self.system_message, 294 | roles=self.roles, 295 | messages=[[x, y] for x, y in self.messages], 296 | offset=self.offset, 297 | sep_style=self.sep_style, 298 | sep=self.sep, 299 | sep2=self.sep2, 300 | stop_str=self.stop_str, 301 | stop_token_ids=self.stop_token_ids, 302 | ) 303 | 304 | def dict(self): 305 | return { 306 | 'template_name': self.name, 307 | 'system_message': self.system_message, 308 | 'roles': self.roles, 309 | 'messages': self.messages, 310 | 'offset': self.offset, 311 | } 312 | 313 | 314 | # A global registry for all conversation templates 315 | conv_templates: Dict[str, Conversation] = {} 316 | 317 | 318 | def register_conv_template(template: Conversation, override: bool = False): 319 | """Register a new conversation template.""" 320 | if not override: 321 | assert ( 322 | template.name not in conv_templates 323 | ), f'{template.name} has been registered.' 324 | 325 | conv_templates[template.name] = template 326 | 327 | 328 | def get_conv_template(name: str) -> Conversation: 329 | """Get a conversation template.""" 330 | return conv_templates[name].copy() 331 | 332 | 333 | # Both Hermes-2 and internlm2-chat are chatml-format conversation templates. The difference 334 | # is that during training, the preprocessing function for the Hermes-2 template doesn't add 335 | # at the beginning of the tokenized sequence, while the internlm2-chat template does. 336 | # Therefore, they are completely equivalent during inference. 337 | register_conv_template( 338 | Conversation( 339 | name='Hermes-2', 340 | system_template='<|im_start|>system\n{system_message}', 341 | # note: The new system prompt was not used here to avoid changes in benchmark performance. 342 | # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。', 343 | # system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。', 344 | system_message='You are a Table Image to LaTeX/Markdown/HMTL Code converter.', 345 | roles=('<|im_start|>user\n', '<|im_start|>assistant\n'), 346 | sep_style=SeparatorStyle.MPT, 347 | sep='<|im_end|>', 348 | stop_token_ids=[ 349 | 2, 350 | 6, 351 | 7, 352 | 8, 353 | ], 354 | stop_str='<|endoftext|>', 355 | ) 356 | ) 357 | 358 | 359 | register_conv_template( 360 | Conversation( 361 | name='internlm2-chat', 362 | system_template='<|im_start|>system\n{system_message}', 363 | # note: The new system prompt was not used here to avoid changes in benchmark performance. 364 | # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。', 365 | system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。', 366 | roles=('<|im_start|>user\n', '<|im_start|>assistant\n'), 367 | sep_style=SeparatorStyle.MPT, 368 | sep='<|im_end|>', 369 | stop_token_ids=[ 370 | 2, 371 | 92543, 372 | 92542 373 | ] 374 | ) 375 | ) 376 | 377 | 378 | register_conv_template( 379 | Conversation( 380 | name='phi3-chat', 381 | system_template='<|system|>\n{system_message}', 382 | # note: The new system prompt was not used here to avoid changes in benchmark performance. 383 | # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。', 384 | system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。', 385 | roles=('<|user|>\n', '<|assistant|>\n'), 386 | sep_style=SeparatorStyle.MPT, 387 | sep='<|end|>', 388 | stop_token_ids=[ 389 | 2, 390 | 32000, 391 | 32007 392 | ] 393 | ) 394 | ) 395 | -------------------------------------------------------------------------------- /struct_eqtable/internvl/internvl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import nn 4 | from transformers import AutoModel, AutoTokenizer, AutoImageProcessor, GenerationConfig 5 | 6 | from .conversation import get_conv_template 7 | 8 | class InternVL(nn.Module): 9 | def __init__(self, model_path='U4R/StructTable-InternVL2-1B', max_new_tokens=1024, max_time=30, flash_attn=True, **kwargs): 10 | super().__init__() 11 | self.model_path = model_path 12 | self.max_new_tokens = max_new_tokens 13 | self.max_generate_time = max_time 14 | self.flash_attn = flash_attn 15 | 16 | # init model and image processor from ckpt path 17 | self.init_tokenizer(model_path) 18 | self.init_image_processor(model_path) 19 | self.init_model(model_path) 20 | 21 | self.prompt_template = { 22 | 'latex': '', 23 | 'html': '', 24 | 'markdown': '', 25 | } 26 | # support output format 27 | self.supported_output_format = ['latex', 'html', 'markdown'] 28 | 29 | def init_model(self, model_path): 30 | self.model = AutoModel.from_pretrained( 31 | model_path, 32 | trust_remote_code=True, 33 | torch_dtype=torch.bfloat16, 34 | low_cpu_mem_usage=True, 35 | use_flash_attn=self.flash_attn, 36 | ) 37 | self.model.eval() 38 | 39 | def init_image_processor(self, image_processor_path): 40 | self.image_processor = AutoImageProcessor.from_pretrained( 41 | image_processor_path, 42 | trust_remote_code=True, 43 | ) 44 | 45 | def init_tokenizer(self, tokenizer_path): 46 | self.tokenizer = AutoTokenizer.from_pretrained( 47 | tokenizer_path, 48 | trust_remote_code=True, 49 | use_fast=False, 50 | ) 51 | 52 | self.image_context_token = '' 53 | self.image_token_num = 256 54 | self.image_start_token = '' 55 | self.image_end_token = '' 56 | self.img_context_token_id = self.tokenizer.convert_tokens_to_ids(self.image_context_token) 57 | 58 | def format_image_tokens(self, path_num): 59 | return f'{self.image_start_token}{self.image_context_token* self.image_token_num * path_num}{self.image_end_token}' 60 | 61 | def forward(self, images, output_format='latex', **kwargs): 62 | # process image to tokens 63 | if not isinstance(images, list): 64 | images = [images] 65 | 66 | pixel_values_list = [] 67 | for image in images: 68 | path_images = self.dynamic_preprocess( 69 | image, image_size=448, max_num=12 70 | ) 71 | pixel_values = self.image_processor( 72 | path_images, 73 | return_tensors='pt' 74 | )['pixel_values'].to(torch.bfloat16) 75 | pixel_values_list.append(pixel_values) 76 | 77 | batch_size = len(pixel_values_list) 78 | conversation_list = [] 79 | for bs_idx in range(batch_size): 80 | pixel_values= pixel_values_list[bs_idx].to(torch.bfloat16) 81 | 82 | image_tokens = self.format_image_tokens(pixel_values.shape[0]) 83 | question = '\n' + self.prompt_template[output_format] 84 | answer = None 85 | 86 | template = get_conv_template(self.model.config.template) 87 | template.append_message(template.roles[0], question) 88 | template.append_message(template.roles[1], answer) 89 | conversation = template.get_prompt() 90 | conversation = conversation.replace('', image_tokens, 1) 91 | conversation_list.append(conversation) 92 | 93 | device = next(self.parameters()).device 94 | self.tokenizer.padding_side = 'left' 95 | model_inputs = self.tokenizer( 96 | conversation_list, 97 | return_tensors='pt', 98 | padding=True, 99 | max_length=self.tokenizer.model_max_length, 100 | truncation=True, 101 | ).to(device) 102 | pixel_values = torch.cat(pixel_values_list, axis=0).to(device) 103 | 104 | # generation config 105 | generation_config = dict( 106 | max_new_tokens=self.max_new_tokens, 107 | max_time=self.max_generate_time, 108 | img_context_token_id=self.img_context_token_id, 109 | pad_token_id=self.tokenizer.pad_token_id, 110 | eos_token_id=self.tokenizer.eos_token_id, 111 | do_sample=False, 112 | no_repeat_ngram_size=20, 113 | ) 114 | 115 | # generate text from image tokens 116 | model_output = self.model.generate( 117 | pixel_values=pixel_values, 118 | input_ids=model_inputs.input_ids, 119 | attention_mask=model_inputs.attention_mask, 120 | **generation_config, 121 | # **kwargs 122 | ) 123 | 124 | batch_decode_texts = self.tokenizer.batch_decode( 125 | model_output, 126 | skip_special_tokens=True 127 | ) 128 | return batch_decode_texts 129 | 130 | def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size): 131 | best_ratio_diff = float('inf') 132 | best_ratio = (1, 1) 133 | area = width * height 134 | for ratio in target_ratios: 135 | target_aspect_ratio = ratio[0] / ratio[1] 136 | ratio_diff = abs(aspect_ratio - target_aspect_ratio) 137 | if ratio_diff < best_ratio_diff: 138 | best_ratio_diff = ratio_diff 139 | best_ratio = ratio 140 | elif ratio_diff == best_ratio_diff: 141 | if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: 142 | best_ratio = ratio 143 | return best_ratio 144 | 145 | def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=True): 146 | orig_width, orig_height = image.size 147 | aspect_ratio = orig_width / orig_height 148 | 149 | # calculate the existing image aspect ratio 150 | target_ratios = set( 151 | (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if 152 | i * j <= max_num and i * j >= min_num) 153 | target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) 154 | 155 | # find the closest aspect ratio to the target 156 | target_aspect_ratio = self.find_closest_aspect_ratio( 157 | aspect_ratio, target_ratios, orig_width, orig_height, image_size) 158 | 159 | # calculate the target width and height 160 | target_width = image_size * target_aspect_ratio[0] 161 | target_height = image_size * target_aspect_ratio[1] 162 | blocks = target_aspect_ratio[0] * target_aspect_ratio[1] 163 | 164 | # resize the image 165 | resized_img = image.resize((target_width, target_height)) 166 | processed_images = [] 167 | for i in range(blocks): 168 | box = ( 169 | (i % (target_width // image_size)) * image_size, 170 | (i // (target_width // image_size)) * image_size, 171 | ((i % (target_width // image_size)) + 1) * image_size, 172 | ((i // (target_width // image_size)) + 1) * image_size 173 | ) 174 | # split the image 175 | split_img = resized_img.crop(box) 176 | processed_images.append(split_img) 177 | assert len(processed_images) == blocks 178 | if use_thumbnail and len(processed_images) != 1: 179 | thumbnail_img = image.resize((image_size, image_size)) 180 | processed_images.append(thumbnail_img) 181 | return processed_images 182 | -------------------------------------------------------------------------------- /struct_eqtable/internvl/internvl_lmdeploy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from transformers import AutoTokenizer 5 | try: 6 | from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig, ChatTemplateConfig 7 | except: 8 | print("\033[93mimport lmdeploy failed, if do not use lmdeploy, ignore this message\033[0m") 9 | 10 | 11 | class InternVL_LMDeploy(nn.Module): 12 | def __init__(self, model_path='U4R/StructTable-InternVL2-1B', max_new_tokens=1024, batch_size=4, **kwargs): 13 | super().__init__() 14 | self.model_path = model_path 15 | self.max_new_tokens = max_new_tokens 16 | self.max_batch_size = batch_size 17 | 18 | # init model and tokenizer from ckpt path 19 | self.init_tokenizer(model_path) 20 | self.init_model(model_path) 21 | 22 | self.prompt_template = { 23 | 'latex': '', 24 | 'html': '', 25 | 'markdown': '', 26 | } 27 | # support output format 28 | self.supported_output_format = ['latex', 'html', 'markdown'] 29 | 30 | def init_tokenizer(self, tokenizer_path): 31 | self.tokenizer = AutoTokenizer.from_pretrained( 32 | tokenizer_path, 33 | trust_remote_code=True, 34 | use_fast=False, 35 | ) 36 | 37 | def init_model(self, model_path): 38 | engine_config = PytorchEngineConfig( 39 | dtype='bfloat16', 40 | max_batch_size=self.max_batch_size, 41 | cache_max_entry_count=0.1 42 | ) 43 | self.pipeline = pipeline( 44 | model_path, 45 | backend_config=engine_config, 46 | chat_template_config=ChatTemplateConfig(model_name='internvl2-internlm2') 47 | ) 48 | 49 | def forward(self, images, output_format='latex', **kwargs): 50 | # process image to tokens 51 | if not isinstance(images, list): 52 | images = [images] 53 | 54 | prompts = [self.prompt_template[output_format]] * len(images) 55 | generation_config = GenerationConfig( 56 | max_new_tokens=self.max_new_tokens, 57 | do_sample=False, 58 | temperature=1.0, 59 | stop_token_ids=[self.tokenizer.eos_token_id], 60 | ) 61 | 62 | responses = self.pipeline( 63 | [(x, y) for x, y in zip(prompts, images)], 64 | gen_config=generation_config, 65 | ) 66 | batch_decode_texts = [responce.text for responce in responses] 67 | return batch_decode_texts 68 | 69 | 70 | -------------------------------------------------------------------------------- /struct_eqtable/pix2s/__init__.py: -------------------------------------------------------------------------------- 1 | from .pix2s import Pix2Struct 2 | from .pix2s_trt import Pix2StructTensorRT 3 | -------------------------------------------------------------------------------- /struct_eqtable/pix2s/pix2s.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import nn 4 | from transformers import AutoModelForVision2Seq, AutoProcessor 5 | 6 | 7 | class Pix2Struct(nn.Module): 8 | def __init__(self, model_path='U4R/StructTable-base', max_new_tokens=1024, max_time=30, **kwargs): 9 | super().__init__() 10 | self.model_path = model_path 11 | self.max_new_tokens = max_new_tokens 12 | self.max_generate_time = max_time 13 | 14 | # init model and image processor from ckpt path 15 | self.init_image_processor(model_path) 16 | self.init_model(model_path) 17 | 18 | self.special_str_list = ['\\midrule', '\\hline'] 19 | self.supported_output_format = ['latex'] 20 | 21 | def postprocess_latex_code(self, code): 22 | for special_str in self.special_str_list: 23 | code = code.replace(special_str, special_str + ' ') 24 | return code 25 | 26 | def init_model(self, model_path): 27 | self.model = AutoModelForVision2Seq.from_pretrained(model_path) 28 | self.model.eval() 29 | 30 | def init_image_processor(self, image_processor_path): 31 | self.data_processor = AutoProcessor.from_pretrained(image_processor_path) 32 | 33 | def forward(self, image, **kwargs): 34 | # process image to tokens 35 | image_tokens = self.data_processor.image_processor( 36 | images=image, 37 | return_tensors='pt', 38 | ) 39 | 40 | device = next(self.parameters()).device 41 | for k, v in image_tokens.items(): 42 | image_tokens[k] = v.to(device) 43 | 44 | # generate text from image tokens 45 | model_output = self.model.generate( 46 | flattened_patches=image_tokens['flattened_patches'], 47 | attention_mask=image_tokens['attention_mask'], 48 | max_new_tokens=self.max_new_tokens, 49 | max_time=self.max_generate_time, 50 | no_repeat_ngram_size=20, 51 | ) 52 | 53 | latex_codes = self.data_processor.batch_decode(model_output, skip_special_tokens=True) 54 | # postprocess 55 | for i, code in enumerate(latex_codes): 56 | latex_codes[i] = self.postprocess_latex_code(code) 57 | 58 | return latex_codes 59 | -------------------------------------------------------------------------------- /struct_eqtable/pix2s/pix2s_trt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import json 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | try: 9 | import tensorrt_llm 10 | import tensorrt as trt 11 | import tensorrt_llm.profiler as profiler 12 | 13 | from tensorrt_llm._utils import str_dtype_to_trt, torch_to_numpy 14 | from tensorrt_llm.lora_manager import LoraManager 15 | from tensorrt_llm.runtime import Session, TensorInfo, ModelConfig, SamplingConfig 16 | except: 17 | print("\033[93mimport tensorrt_llm failed, if do not use tensorrt, ignore this message\033[0m") 18 | 19 | from typing import List 20 | from transformers import AutoProcessor, AutoTokenizer, AutoConfig 21 | 22 | 23 | def trt_dtype_to_torch(dtype): 24 | if dtype == trt.float16: 25 | return torch.float16 26 | elif dtype == trt.float32: 27 | return torch.float32 28 | elif dtype == trt.int32: 29 | return torch.int32 30 | elif dtype == trt.bfloat16: 31 | return torch.bfloat16 32 | else: 33 | raise TypeError("%s is not supported" % dtype) 34 | 35 | 36 | class Pix2StructTensorRT(nn.Module): 37 | 38 | def __init__(self, model_path, tensorrt_path, batch_size=1, max_new_tokens=4096, **kwargs): 39 | 40 | self.model_ckpt_path = model_path 41 | self.tensorrt_path = tensorrt_path 42 | self.batch_size = batch_size 43 | self.max_new_tokens = max_new_tokens 44 | 45 | self.llm_engine_path = os.path.join(tensorrt_path, 'llm_engines') 46 | self.visual_engine_path = os.path.join(tensorrt_path, 'visual_engines') 47 | 48 | device_id = torch.cuda.current_device() % torch.cuda.device_count() 49 | self.device_id = device_id 50 | self.device = "cuda:%d" % (device_id) 51 | 52 | self.stream = torch.cuda.Stream(torch.cuda.current_device()) 53 | torch.cuda.set_stream(self.stream) 54 | 55 | # parse model type from visual engine config 56 | with open(os.path.join(self.visual_engine_path, "config.json"), 57 | "r") as f: 58 | config = json.load(f) 59 | self.model_type = config['builder_config']['model_type'] 60 | self.vision_precision = config['builder_config']['precision'] 61 | 62 | self.vision_precision = 'float16' 63 | self.decoder_llm = not ( 64 | 't5' in self.model_type 65 | or self.model_type in ['nougat', 'pix2struct', 'StructEqTable'] 66 | ) # BLIP2-T5, pix2struct and Nougat are using encoder-decoder models as LLMs 67 | 68 | self.profiling_iterations = 20 69 | 70 | self.init_image_encoder() 71 | self.init_tokenizer() 72 | self.init_llm() 73 | self.init_image_processor() 74 | 75 | self.special_str_list = ['\\midrule', '\\hline'] 76 | self.supported_output_format = ['latex'] 77 | 78 | def postprocess_latex_code(self, code): 79 | for special_str in self.special_str_list: 80 | code = code.replace(special_str, special_str + ' ') 81 | return code 82 | 83 | def init_image_processor(self): 84 | self.data_processor = AutoProcessor.from_pretrained( 85 | self.model_ckpt_path) 86 | 87 | def init_tokenizer(self): 88 | self.tokenizer = AutoTokenizer.from_pretrained( 89 | self.model_ckpt_path, use_fast=True, use_legacy=False) 90 | # self.tokenizer.padding_side = "right" 91 | 92 | def init_image_encoder(self): 93 | vision_encoder_path = os.path.join(self.visual_engine_path, 94 | 'visual_encoder.engine') 95 | with open(vision_encoder_path, 'rb') as f: 96 | engine_buffer = f.read() 97 | self.visual_encoder_session = Session.from_serialized_engine( 98 | engine_buffer) 99 | 100 | def init_llm(self): 101 | 102 | self.model = TRTLLMEncDecModel.from_engine( 103 | os.path.basename(self.model_ckpt_path), 104 | self.llm_engine_path, 105 | skip_encoder=self.model_type in ['nougat', 'pix2struct', 'StructEqTable'], 106 | debug_mode=False, 107 | stream=self.stream) 108 | 109 | self.model_config = self.model.decoder_model_config 110 | self.runtime_mapping = self.model.decoder_runtime_mapping 111 | 112 | def __call__(self, image, **kwargs): 113 | # process image to tokens 114 | image_tokens = self.data_processor.image_processor( 115 | images=image, 116 | return_tensors='pt', 117 | ) 118 | 119 | for k, v in image_tokens.items(): 120 | image_tokens[k] = v.cuda() 121 | 122 | model_output = self.run( 123 | flattened_patches=image_tokens['flattened_patches'], 124 | attention_mask=image_tokens['attention_mask'], 125 | max_new_tokens=self.max_new_tokens 126 | ) 127 | 128 | # postprocess 129 | latex_codes = [] 130 | for i, code in enumerate(model_output): 131 | latex_codes.append(self.postprocess_latex_code(code[0])) 132 | 133 | return latex_codes 134 | 135 | def preprocess(self, warmup, pre_prompt, post_prompt, image, 136 | attention_mask): 137 | if not warmup: 138 | profiler.start("Vision") 139 | 140 | visual_features, visual_atts = self.get_visual_features( 141 | torch.stack(image['image_patches'], dim=0) 142 | if self.model_type == 'fuyu' else image, attention_mask) 143 | 144 | if not warmup: 145 | profiler.stop("Vision") 146 | 147 | pre_input_ids = self.tokenizer(pre_prompt, 148 | return_tensors="pt", 149 | padding=True).input_ids 150 | if post_prompt[0] is not None: 151 | post_input_ids = self.tokenizer(post_prompt, 152 | return_tensors="pt", 153 | padding=True).input_ids 154 | length = pre_input_ids.shape[1] + post_input_ids.shape[ 155 | 1] + visual_atts.shape[1] 156 | else: 157 | post_input_ids = None 158 | length = pre_input_ids.shape[1] + visual_atts.shape[1] 159 | 160 | input_lengths = torch.IntTensor([length] * 1).to( 161 | torch.int32) 162 | 163 | input_ids, ptuning_args = self.setup_fake_prompts( 164 | visual_features, pre_input_ids, post_input_ids, input_lengths) 165 | 166 | return input_ids, input_lengths, ptuning_args, visual_features 167 | 168 | def generate(self, pre_prompt, post_prompt, image, decoder_input_ids, 169 | max_new_tokens, attention_mask, warmup): 170 | if not warmup: 171 | profiler.start("Generate") 172 | 173 | input_ids, input_lengths, ptuning_args, visual_features = self.preprocess( 174 | warmup, pre_prompt, post_prompt, image, attention_mask) 175 | 176 | if warmup: return None 177 | 178 | profiler.start("LLM") 179 | 180 | # Trim encoder input_ids to match visual features shape 181 | ids_shape = (self.batch_size, visual_features.shape[1]) 182 | 183 | input_ids = torch.ones(ids_shape, dtype=torch.int32) 184 | 185 | output_ids = self.model.generate( 186 | input_ids, 187 | decoder_input_ids, 188 | max_new_tokens, 189 | num_beams=1, 190 | bos_token_id=self.tokenizer.bos_token_id, 191 | pad_token_id=self.tokenizer.pad_token_id, 192 | eos_token_id=self.tokenizer.eos_token_id, 193 | debug_mode=False, 194 | prompt_embedding_table=ptuning_args[0], 195 | prompt_tasks=ptuning_args[1], 196 | prompt_vocab_size=ptuning_args[2], 197 | attention_mask=attention_mask) 198 | 199 | # Reset input_lengths to match decoder_input_ids 200 | input_lengths = torch.ones(input_lengths.shape, 201 | dtype=input_lengths.dtype) 202 | profiler.stop("LLM") 203 | 204 | if tensorrt_llm.mpi_rank() == 0: 205 | # Extract a list of tensors of shape beam_width x output_ids. 206 | output_beams_list = [ 207 | self.tokenizer.batch_decode( 208 | output_ids[batch_idx, :, input_lengths[batch_idx]:], 209 | skip_special_tokens=True) 210 | for batch_idx in range(self.batch_size) 211 | ] 212 | 213 | stripped_text = [[ 214 | output_beams_list[batch_idx][beam_idx].strip() 215 | for beam_idx in range(1) 216 | ] for batch_idx in range(self.batch_size)] 217 | profiler.stop("Generate") 218 | return stripped_text 219 | else: 220 | profiler.stop("Generate") 221 | return None 222 | 223 | def get_visual_features(self, image, attention_mask): 224 | visual_features = { 225 | 'input': 226 | image.to( 227 | tensorrt_llm._utils.str_dtype_to_torch(self.vision_precision)) 228 | } 229 | if attention_mask is not None: 230 | visual_features['attention_mask'] = attention_mask 231 | tensor_info = [ 232 | TensorInfo('input', str_dtype_to_trt(self.vision_precision), 233 | image.shape) 234 | ] 235 | if attention_mask is not None: 236 | tensor_info.append( 237 | TensorInfo('attention_mask', trt.DataType.INT32, 238 | attention_mask.shape)) 239 | visual_output_info = self.visual_encoder_session.infer_shapes( 240 | tensor_info) 241 | visual_outputs = { 242 | t.name: torch.empty(tuple(t.shape), 243 | dtype=trt_dtype_to_torch(t.dtype), 244 | device=image.device) 245 | for t in visual_output_info 246 | } 247 | 248 | ok = self.visual_encoder_session.run(visual_features, visual_outputs, 249 | self.stream.cuda_stream) 250 | assert ok, "Runtime execution failed for vision encoder session" 251 | self.stream.synchronize() 252 | 253 | image_embeds = visual_outputs['output'] 254 | image_atts = torch.ones(image_embeds.size()[:-1], 255 | dtype=torch.long).to(image.device) 256 | 257 | return image_embeds, image_atts 258 | 259 | def setup_fake_prompts(self, visual_features, pre_input_ids, post_input_ids, 260 | input_lengths): 261 | # Assemble fake prompts which points to image embedding actually 262 | fake_prompt_id = torch.arange( 263 | self.model_config.vocab_size, self.model_config.vocab_size + 264 | visual_features.shape[0] * visual_features.shape[1]) 265 | fake_prompt_id = fake_prompt_id.reshape(visual_features.shape[0], 266 | visual_features.shape[1]) 267 | 268 | if post_input_ids is not None: 269 | input_ids = [pre_input_ids, fake_prompt_id, post_input_ids] 270 | else: 271 | input_ids = [fake_prompt_id, pre_input_ids] 272 | 273 | input_ids = torch.cat(input_ids, dim=1).contiguous().to(torch.int32) 274 | 275 | if self.decoder_llm or self.runtime_mapping.is_first_pp_rank(): 276 | ptuning_args = self.ptuning_setup(visual_features, input_ids, 277 | input_lengths) 278 | else: 279 | ptuning_args = [None, None, None] 280 | 281 | return input_ids, ptuning_args 282 | 283 | def ptuning_setup(self, prompt_table, input_ids, input_lengths): 284 | hidden_size = self.model_config.hidden_size * self.runtime_mapping.tp_size 285 | if prompt_table is not None: 286 | task_vocab_size = torch.tensor( 287 | [prompt_table.shape[1]], 288 | dtype=torch.int32, 289 | ).cuda() 290 | prompt_table = prompt_table.view( 291 | (prompt_table.shape[0] * prompt_table.shape[1], 292 | prompt_table.shape[2])) 293 | assert prompt_table.shape[ 294 | 1] == hidden_size, "Prompt table dimensions do not match hidden size" 295 | 296 | prompt_table = prompt_table.cuda().to( 297 | dtype=tensorrt_llm._utils.str_dtype_to_torch( 298 | self.model_config.dtype)) 299 | else: 300 | prompt_table = torch.empty([1, hidden_size]).cuda() 301 | task_vocab_size = torch.zeros([1]).cuda() 302 | 303 | if self.model_config.remove_input_padding: 304 | tasks = torch.zeros([torch.sum(input_lengths)], 305 | dtype=torch.int32).cuda() 306 | if self.decoder_llm: tasks = tasks.unsqueeze(0) 307 | else: 308 | tasks = torch.zeros(input_ids.shape, dtype=torch.int32).cuda() 309 | 310 | return [prompt_table, tasks, task_vocab_size] 311 | 312 | def setup_inputs(self, input_text, raw_image): 313 | attention_mask = None 314 | 315 | image_processor = AutoProcessor.from_pretrained(self.model_ckpt_path) 316 | if input_text is None: 317 | input_text = "" 318 | inputs = image_processor( 319 | images=raw_image, 320 | text=input_text, 321 | return_tensors="pt", 322 | ) 323 | image = inputs['flattened_patches'] 324 | image = image.expand(self.batch_size, -1, -1).contiguous() 325 | attention_mask = inputs['attention_mask'].to(self.device).to( 326 | torch.int) 327 | attention_mask = attention_mask.expand(self.batch_size, 328 | -1).contiguous() 329 | pre_prompt = "" 330 | post_prompt = None 331 | 332 | # Repeat inputs to match batch size 333 | pre_prompt = [pre_prompt] * self.batch_size 334 | post_prompt = [post_prompt] * self.batch_size 335 | image = image.to(self.device) 336 | 337 | # Generate decoder_input_ids for enc-dec models 338 | # Custom prompts can be added as: 339 | # decoder_input_ids = model.tokenizer(decoder_prompt).input_ids 340 | if self.decoder_llm: 341 | decoder_input_ids = None 342 | else: 343 | config = AutoConfig.from_pretrained(self.model_ckpt_path) 344 | decoder_start_id = config.decoder_start_token_id # T5 345 | if decoder_start_id is None: 346 | decoder_start_id = config.decoder.bos_token_id # Nougat 347 | 348 | decoder_input_ids = torch.IntTensor([[decoder_start_id]]) 349 | decoder_input_ids = decoder_input_ids.repeat((self.batch_size, 1)) 350 | 351 | return input_text, pre_prompt, post_prompt, image, decoder_input_ids, attention_mask 352 | 353 | def run(self, flattened_patches, attention_mask, max_new_tokens): 354 | # input_text, pre_prompt, post_prompt, processed_image, decoder_input_ids, attention_mask = self.setup_inputs( 355 | # None, raw_image) 356 | pre_prompt = [""] * self.batch_size 357 | post_prompt = [None] * self.batch_size 358 | config = AutoConfig.from_pretrained(self.model_ckpt_path) 359 | decoder_start_id = config.decoder_start_token_id # T5 360 | decoder_input_ids = torch.IntTensor([[decoder_start_id]]) 361 | decoder_input_ids = decoder_input_ids.repeat((self.batch_size, 1)) 362 | 363 | processed_image = flattened_patches.expand(self.batch_size, -1, -1).contiguous() 364 | attention_mask = attention_mask.to(self.device).to(torch.int) 365 | attention_mask = attention_mask.expand(self.batch_size,-1).contiguous() 366 | 367 | self.generate(pre_prompt, 368 | post_prompt, 369 | processed_image, 370 | decoder_input_ids, 371 | max_new_tokens, 372 | attention_mask=attention_mask, 373 | warmup=True) 374 | # num_iters = self.profiling_iterations if self.args.run_profiling else 1 375 | num_iters = 1 376 | # print(num_iters) 377 | for _ in range(num_iters): 378 | output_text = self.generate(pre_prompt, 379 | post_prompt, 380 | processed_image, 381 | decoder_input_ids, 382 | max_new_tokens, 383 | attention_mask=attention_mask, 384 | warmup=False) 385 | # if self.runtime_rank == 0: 386 | # self.print_result(input_text, output_text) 387 | return output_text 388 | 389 | 390 | def read_config(config_path): 391 | with open(config_path, "r") as f: 392 | config = json.load(f) 393 | 394 | builder_config = config['build_config'] 395 | plugin_config = builder_config['plugin_config'] 396 | pretrained_config = config['pretrained_config'] 397 | lora_config = builder_config['lora_config'] 398 | auto_parallel_config = builder_config['auto_parallel_config'] 399 | use_gpt_attention_plugin = plugin_config["gpt_attention_plugin"] 400 | remove_input_padding = plugin_config["remove_input_padding"] 401 | use_lora_plugin = plugin_config["lora_plugin"] 402 | tp_size = pretrained_config['mapping']['tp_size'] 403 | pp_size = pretrained_config['mapping']['pp_size'] 404 | gpus_per_node = auto_parallel_config['gpus_per_node'] 405 | world_size = tp_size * pp_size 406 | assert world_size == tensorrt_llm.mpi_world_size(), \ 407 | f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})' 408 | num_heads = pretrained_config["num_attention_heads"] 409 | hidden_size = pretrained_config["hidden_size"] 410 | head_size = pretrained_config["head_size"] 411 | vocab_size = pretrained_config["vocab_size"] 412 | max_batch_size = builder_config["max_batch_size"] 413 | max_beam_width = builder_config["max_beam_width"] 414 | num_layers = pretrained_config["num_hidden_layers"] 415 | num_kv_heads = pretrained_config.get('num_kv_heads', num_heads) 416 | 417 | assert (num_heads % tp_size) == 0 418 | num_heads = num_heads // tp_size 419 | hidden_size = hidden_size // tp_size 420 | num_kv_heads = (num_kv_heads + tp_size - 1) // tp_size 421 | 422 | cross_attention = pretrained_config["architecture"] == "DecoderModel" 423 | skip_cross_qkv = pretrained_config.get('skip_cross_qkv', False) 424 | has_position_embedding = pretrained_config["has_position_embedding"] 425 | has_token_type_embedding = hasattr(pretrained_config, "type_vocab_size") 426 | use_custom_all_reduce = plugin_config.get('use_custom_all_reduce', False) 427 | dtype = pretrained_config["dtype"] 428 | 429 | paged_kv_cache = plugin_config['paged_kv_cache'] 430 | tokens_per_block = plugin_config['tokens_per_block'] 431 | 432 | gather_context_logits = builder_config.get('gather_context_logits', False) 433 | gather_generation_logits = builder_config.get('gather_generation_logits', 434 | False) 435 | max_prompt_embedding_table_size = builder_config.get( 436 | 'max_prompt_embedding_table_size', 0) 437 | 438 | model_config = ModelConfig( 439 | num_heads=num_heads, 440 | num_kv_heads=num_kv_heads, 441 | hidden_size=hidden_size, 442 | head_size=head_size, 443 | max_batch_size=max_batch_size, 444 | max_beam_width=max_beam_width, 445 | vocab_size=vocab_size, 446 | num_layers=num_layers, 447 | gpt_attention_plugin=use_gpt_attention_plugin, 448 | remove_input_padding=remove_input_padding, 449 | paged_kv_cache=paged_kv_cache, 450 | tokens_per_block=tokens_per_block, 451 | cross_attention=cross_attention, 452 | has_position_embedding=has_position_embedding, 453 | has_token_type_embedding=has_token_type_embedding, 454 | use_custom_all_reduce=use_custom_all_reduce, 455 | dtype=dtype, 456 | gather_context_logits=gather_context_logits, 457 | gather_generation_logits=gather_generation_logits, 458 | max_prompt_embedding_table_size=max_prompt_embedding_table_size, 459 | lora_plugin=use_lora_plugin, 460 | lora_target_modules=lora_config.get('lora_target_modules'), 461 | trtllm_modules_to_hf_modules=lora_config.get( 462 | 'trtllm_modules_to_hf_modules'), 463 | skip_cross_qkv=skip_cross_qkv, 464 | ) 465 | 466 | return model_config, tp_size, pp_size, gpus_per_node, dtype 467 | 468 | 469 | class Mapping(object): 470 | def __init__( 471 | self, 472 | world_size=1, 473 | rank=0, 474 | gpus_per_node=8, 475 | tp_size=1, 476 | pp_size=1, 477 | moe_tp_size=-1, # -1 means no moe 478 | moe_ep_size=-1): # -1 means no moe 479 | # set default values for non-moe cases 480 | if moe_tp_size == -1: 481 | moe_tp_size = tp_size 482 | moe_ep_size = 1 483 | 484 | if pp_size * tp_size != world_size: 485 | raise ValueError( 486 | f"world_size must equal to pp_size * tp_size, but got {world_size} != {pp_size} * {tp_size}" 487 | ) 488 | 489 | moe_tp_ep_size = moe_tp_size * moe_ep_size 490 | if moe_tp_ep_size != tp_size: 491 | raise ValueError( 492 | f"tp_size must equal to moe_tp_size * moe_ep_size, but got {tp_size} != {moe_tp_size} * {moe_ep_size}" 493 | ) 494 | 495 | self.tp_size = tp_size 496 | self.pp_size = pp_size 497 | self.moe_tp_size = moe_tp_size 498 | self.moe_ep_size = moe_ep_size 499 | self.world_size = world_size 500 | self.rank = rank 501 | self.gpus_per_node = gpus_per_node 502 | 503 | self.pp_groups = [] 504 | self.tp_groups = [] 505 | self.moe_tp_groups = [] 506 | self.moe_ep_groups = [] 507 | 508 | # init pp group 509 | for i in range(tp_size): 510 | ranks = range(i+ self.rank, world_size+ self.rank, tp_size) 511 | self.pp_groups.append(list(ranks)) 512 | 513 | # init tp group 514 | for i in range(pp_size): 515 | ranks = range(i * tp_size + self.rank, (i + 1) * tp_size + self.rank) 516 | self.tp_groups.append(list(ranks)) 517 | 518 | # init moe tp group 519 | for i in range(pp_size): 520 | for j in range(moe_ep_size): 521 | ranks = range(i * moe_tp_ep_size + j, (i + 1) * moe_tp_ep_size, 522 | moe_ep_size) 523 | self.moe_tp_groups.append(list(ranks)) 524 | 525 | # init moe ep group 526 | for i in range(pp_size): 527 | for j in range(moe_tp_size): 528 | ranks = range(i * moe_tp_ep_size + j * moe_ep_size, 529 | i * moe_tp_ep_size + (j + 1) * moe_ep_size) 530 | self.moe_ep_groups.append(list(ranks)) 531 | 532 | # self.pp_rank = self.rank // self.tp_size 533 | # self.tp_rank = self.rank % self.tp_size 534 | self.pp_rank = 0 535 | self.tp_rank = 0 536 | self.moe_tp_rank = self.tp_rank // self.moe_ep_size 537 | self.moe_ep_rank = self.tp_rank % self.moe_ep_size 538 | 539 | # self.tp_group = self.tp_groups[self.pp_rank] 540 | # self.pp_group = self.pp_groups[self.tp_rank] 541 | self.moe_tp_group = self.moe_tp_groups[self.pp_rank * moe_ep_size + 542 | self.moe_ep_rank] 543 | self.moe_ep_group = self.moe_ep_groups[self.pp_rank * moe_tp_size + 544 | self.moe_tp_rank] 545 | 546 | self.node_rank = self.rank // self.gpus_per_node 547 | self.local_rank = self.rank % self.gpus_per_node 548 | 549 | def get_node_rank(self, rank: int): 550 | return rank // self.gpus_per_node 551 | 552 | def get_local_rank(self, rank: int): 553 | return rank % self.gpus_per_node 554 | 555 | def has_tp(self): 556 | return self.tp_size > 1 557 | 558 | def is_last_pp_rank(self): 559 | return self.pp_rank == self.pp_size - 1 560 | 561 | def is_first_pp_rank(self): 562 | return self.pp_rank == 0 563 | 564 | def has_pp(self): 565 | return self.pp_size > 1 566 | 567 | def prev_pp_rank(self): 568 | p = self.rank - self.tp_size 569 | if p < 0: 570 | p = p + self.world_size 571 | return p 572 | 573 | def next_pp_rank(self): 574 | p = self.rank + self.tp_size 575 | if p >= self.world_size: 576 | p = p - self.world_size 577 | return p 578 | 579 | def has_moe_tp(self): 580 | return self.moe_tp_size > 1 581 | 582 | def has_moe_ep(self): 583 | return self.moe_ep_size > 1 584 | 585 | def pp_layers(self, num_layers: int) -> List[int]: 586 | layers_per_pipeline_stage = num_layers // self.pp_size 587 | layers_range = range(self.pp_rank * layers_per_pipeline_stage, 588 | (self.pp_rank + 1) * layers_per_pipeline_stage) 589 | return list(layers_range) 590 | 591 | def ep_experts(self, num_experts: int) -> List[int]: 592 | experts_per_rank = num_experts // self.moe_ep_size 593 | experts_range = range(self.moe_ep_rank * experts_per_rank, 594 | (self.moe_ep_rank + 1) * experts_per_rank) 595 | return list(experts_range) 596 | 597 | 598 | def get_engine_name(rank): 599 | return 'rank{}.engine'.format(rank) 600 | 601 | class TRTLLMEncDecModel: 602 | 603 | def __init__( 604 | self, 605 | engine_name, 606 | engine_dir, 607 | lora_dir=None, 608 | lora_task_uids=None, 609 | debug_mode=False, 610 | skip_encoder=False, 611 | stream: torch.cuda.Stream = None, 612 | ): 613 | # in multi-node setup, it's important to set_device at the very beginning so .to('cuda') refers to current device 614 | # accordingly, all input & output tensors should be moved to current device 615 | # otherwise, it's default to 'cuda:0' 616 | 617 | # self.runtime_rank = tensorrt_llm.mpi_rank() 618 | self.device_id = torch.cuda.current_device() 619 | # torch.cuda.set_device(device_id) 620 | self.device = torch.cuda.current_device() 621 | self.skip_encoder = skip_encoder 622 | self.lora_task_uids = lora_task_uids 623 | 624 | # when enc-dec runs by itself, stream can be None and we create new stream here 625 | # when enc-dec has to run as a component in a bigger workflow (e.g., multimodal), earlier components in the workflow may have results in its stream, which we should pass that stream in to avoid unnecessary stream sync 626 | self.stream = stream 627 | if self.stream is None: 628 | self.stream = torch.cuda.Stream(self.device) 629 | torch.cuda.set_stream(self.stream) 630 | 631 | def engine_setup(component): 632 | # model config 633 | config_path = os.path.join(engine_dir, component, "config.json") 634 | model_config, tp_size, pp_size, gpus_per_node, dtype = read_config( 635 | config_path) 636 | 637 | # MGMN config 638 | world_size = tp_size * pp_size 639 | # runtime_rank = tensorrt_llm.mpi_rank() 640 | runtime_rank = torch.cuda.current_device() 641 | # assert runtime_rank < world_size, "Runtime GPU rank exceeds MPI world size. Did you launch more MPI processes than required?" 642 | # runtime_mapping = tensorrt_llm.Mapping(world_size, 643 | # runtime_rank, 644 | # tp_size=tp_size, 645 | # pp_size=pp_size, 646 | # gpus_per_node=gpus_per_node) 647 | # tensorrt_llm.Mapping 648 | runtime_mapping = Mapping(world_size, 649 | runtime_rank, 650 | tp_size=tp_size, 651 | pp_size=pp_size, 652 | gpus_per_node=gpus_per_node) 653 | # load engine 654 | # engine_fname = get_engine_name(runtime_rank) 655 | engine_fname = get_engine_name(0) 656 | with open(os.path.join(engine_dir, component, engine_fname), "rb") as f: 657 | engine_buffer = f.read() 658 | 659 | return model_config, runtime_mapping, engine_buffer 660 | 661 | # Note: encoder and decoder doesn't necessarily have the same TP & PP config 662 | 663 | if not skip_encoder: 664 | self.encoder_model_config, self.encoder_runtime_mapping, encoder_engine_buffer = engine_setup( 665 | component='encoder') 666 | 667 | self.nccl_comm = None 668 | if self.encoder_runtime_mapping.has_pp(): 669 | # for Pipeline Parallelism in encoder 670 | self.nccl_comm = torch.classes.trtllm.NcclCommunicatorOp( 671 | self.encoder_runtime_mapping.tp_size, 672 | self.encoder_runtime_mapping.pp_size, 673 | self.encoder_runtime_mapping.rank) 674 | 675 | # session setup 676 | self.encoder_session = tensorrt_llm.runtime.Session.from_serialized_engine( 677 | encoder_engine_buffer) 678 | 679 | # encoder lora manager setup 680 | if self.encoder_model_config.lora_plugin: 681 | self.encoder_lora_manager = LoraManager() 682 | # TODO: this is only for bart 683 | self.encoder_lora_manager.load_from_hf( 684 | model_dirs=lora_dir, 685 | model_config=self.encoder_model_config, 686 | runtime_mapping=self.encoder_runtime_mapping, 687 | component='encoder', 688 | ) 689 | else: 690 | self.encoder_lora_manager = None 691 | else: 692 | self.encoder_model_config, self.encoder_runtime_mapping, encoder_engine_buffer = None, None, None 693 | self.nccl_comm, self.encoder_session = None, None 694 | 695 | self.decoder_model_config, self.decoder_runtime_mapping, decoder_engine_buffer = engine_setup( 696 | component='decoder') 697 | 698 | self.decoder_session = tensorrt_llm.runtime.GenerationSession( 699 | self.decoder_model_config, 700 | decoder_engine_buffer, 701 | self.decoder_runtime_mapping, 702 | debug_mode=debug_mode) 703 | 704 | # decoder lora manager setup 705 | if self.decoder_model_config.lora_plugin: 706 | self.decoder_lora_manager = LoraManager() 707 | # TODO: this is only for bart 708 | self.decoder_lora_manager.load_from_hf( 709 | model_dirs=lora_dir, 710 | model_config=self.decoder_model_config, 711 | runtime_mapping=self.decoder_runtime_mapping, 712 | component='decoder', 713 | ) 714 | else: 715 | self.decoder_lora_manager = None 716 | 717 | @classmethod 718 | def from_engine(cls, 719 | engine_name, 720 | engine_dir, 721 | lora_dir=None, 722 | lora_task_uids=None, 723 | debug_mode=False, 724 | skip_encoder=False, 725 | stream=None): 726 | return cls(engine_name, 727 | engine_dir, 728 | lora_dir, 729 | lora_task_uids, 730 | debug_mode=debug_mode, 731 | skip_encoder=skip_encoder, 732 | stream=stream) 733 | 734 | def process_input(self, 735 | input_ids, 736 | remove_input_padding=False, 737 | pad_token_id=0, 738 | prompt_tasks=None): 739 | if remove_input_padding: 740 | # in remove padding mode --> flatten input, calculate actual length and max length 741 | # Note: 1st token should never be removed, even if it is pad_token_id 742 | first_ids = input_ids[:, 0] 743 | input_ids = input_ids[:, 1:] 744 | input_lengths = 1 + (input_ids != pad_token_id).sum(dim=1).type( 745 | torch.IntTensor).to(self.device) # [batch_size] 746 | new_ids = [] 747 | for i in range(len(input_ids)): 748 | row = input_ids[i, :] 749 | row = row[row != pad_token_id] 750 | new_ids.append( 751 | torch.cat( 752 | (torch.IntTensor([first_ids[i]]).to(self.device), row))) 753 | input_ids = torch.cat(new_ids) # [num_tokens] 754 | if prompt_tasks is not None: 755 | prompt_tasks = prompt_tasks[:input_ids.shape[0]] 756 | else: 757 | # in padding mode --> keep input, just calculate actual length and max length 758 | # Note: 1st token should always count, even if it is pad_token_id. e.g., decoder start id in enc-dec models could be a single pad_token_id, we should count 759 | input_lengths = torch.tensor( 760 | 1 + (input_ids[:, 1:] != pad_token_id).sum(dim=1).type( 761 | torch.IntTensor).to(self.device), 762 | dtype=torch.int32, 763 | device=self.device) 764 | max_input_length = torch.max(input_lengths).item() 765 | return input_ids, input_lengths, max_input_length, prompt_tasks 766 | 767 | def encoder_run(self, 768 | input_ids, 769 | input_lengths, 770 | max_input_length, 771 | position_ids=None, 772 | token_type_ids=None, 773 | debug_mode=False, 774 | prompt_embedding_table=None, 775 | prompt_tasks=None, 776 | prompt_vocab_size=None, 777 | attention_mask=None): 778 | 779 | # each engine has hidden_dim/TP, don't forget to multiply TP 780 | hidden_size = self.encoder_model_config.hidden_size * self.encoder_runtime_mapping.tp_size 781 | if input_ids.dim() == 1: 782 | hidden_states_shape = (input_ids.shape[0], hidden_size 783 | ) # [num_tokens,D] 784 | else: 785 | hidden_states_shape = (input_ids.shape[0], input_ids.shape[1], 786 | hidden_size) # [BS,seqlen,D] 787 | hidden_states_dtype = lambda name: trt_dtype_to_torch( 788 | self.encoder_session.engine.get_tensor_dtype(name)) 789 | 790 | # input tensors. only first PP rank has id input, others are hidden_states input 791 | inputs = {} 792 | if self.encoder_runtime_mapping.is_first_pp_rank(): 793 | inputs['input_ids'] = input_ids.contiguous() 794 | if self.encoder_model_config.has_position_embedding: 795 | if position_ids is None: 796 | if self.encoder_model_config.remove_input_padding: 797 | position_ids = [ 798 | torch.arange(sample_length, 799 | dtype=torch.int32, 800 | device=input_ids.device) 801 | for sample_length in torch_to_numpy(input_lengths) 802 | ] 803 | position_ids = torch.cat(position_ids) 804 | else: 805 | bsz, seq_len = input_ids.shape[:2] 806 | position_ids = torch.arange( 807 | seq_len, dtype=torch.int32, 808 | device=input_ids.device).expand(bsz, -1) 809 | inputs['position_ids'] = position_ids.contiguous() 810 | if self.encoder_model_config.has_token_type_embedding: 811 | inputs['token_type_ids'] = token_type_ids.contiguous() 812 | 813 | if self.encoder_model_config.max_prompt_embedding_table_size > 0: 814 | inputs[ 815 | 'prompt_embedding_table'] = prompt_embedding_table.contiguous( 816 | ) 817 | inputs['tasks'] = prompt_tasks.contiguous() 818 | inputs['prompt_vocab_size'] = prompt_vocab_size.contiguous() 819 | else: 820 | # just need a placeholder, engine will call NCCL to recv and fill data from previous rank 821 | inputs['hidden_states_input'] = torch.empty( 822 | hidden_states_shape, 823 | dtype=hidden_states_dtype('hidden_states_input'), 824 | device=self.device).contiguous() 825 | if attention_mask is not None and not self.encoder_model_config.gpt_attention_plugin: 826 | inputs['attention_mask'] = attention_mask.contiguous() 827 | 828 | inputs['input_lengths'] = input_lengths 829 | # use shape info to pass max length info in remove padding mode 830 | inputs['max_input_length'] = torch.empty( 831 | (max_input_length, ), 832 | dtype=hidden_states_dtype('max_input_length'), 833 | device=self.device).contiguous() 834 | batch_size = input_lengths.size(0) 835 | inputs['host_request_types'] = torch.IntTensor([0] * 836 | batch_size).to('cpu') 837 | if self.encoder_model_config.remove_input_padding: 838 | inputs['host_context_lengths'] = input_lengths.to('cpu') 839 | 840 | if self.encoder_model_config.lora_plugin and self.encoder_lora_manager is not None: 841 | inputs.update( 842 | self.encoder_lora_manager.input_buffers( 843 | self.lora_task_uids, 844 | self.encoder_runtime_mapping, 845 | self.encoder_model_config.num_layers, 846 | )) 847 | 848 | # Note: runtime.Session's run() method will set input/output tensor address, here we only need to provide tensor shape 849 | self.encoder_session.set_shapes(inputs) 850 | 851 | # output tensors. only last PP rank final encoder output, others are intermediate hidden_states output. Need broadcast later 852 | outputs = {} 853 | if self.encoder_runtime_mapping.is_last_pp_rank(): 854 | outputs['encoder_output'] = torch.empty( 855 | hidden_states_shape, 856 | dtype=hidden_states_dtype('encoder_output'), 857 | device=self.device).contiguous() 858 | else: 859 | outputs['hidden_states_output'] = torch.empty( 860 | hidden_states_shape, 861 | dtype=hidden_states_dtype('hidden_states_output'), 862 | device=self.device).contiguous() 863 | 864 | # ------------------------------------------- 865 | if debug_mode: 866 | engine = self.encoder_session.engine 867 | context = self.encoder_session.context 868 | # setup debugging buffer for the encoder 869 | for i in range(self.encoder_session.engine.num_io_tensors): 870 | name = engine.get_tensor_name(i) 871 | if engine.get_tensor_mode( 872 | name 873 | ) == trt.TensorIOMode.OUTPUT and name not in outputs.keys(): 874 | dtype = engine.get_tensor_dtype(name) 875 | shape = context.get_tensor_shape(name) 876 | outputs[name] = torch.zeros(tuple(shape), 877 | dtype=trt_dtype_to_torch(dtype), 878 | device=self.device) 879 | context.set_tensor_address(name, outputs[name].data_ptr()) 880 | # ------------------------------------------- 881 | 882 | # TRT session run 883 | # Note: need cuda stream ID, not a torch Stream 884 | ok = self.encoder_session.run(inputs, outputs, self.stream.cuda_stream) 885 | assert ok, "Runtime execution failed" 886 | self.stream.synchronize() 887 | 888 | # Tensor Parallelism is handled by model/engine definition 889 | # But we need to broadcast among PP group at the end of encoder's Pipeline Parallelism 890 | # After this, all ranks should recv the encoder output, and world might be re-configured using decoder's TP-PP config 891 | def pp_communicate_encoder_output(encoder_output): 892 | if self.encoder_runtime_mapping.is_last_pp_rank(): 893 | for pp_rank in self.encoder_runtime_mapping.pp_group: 894 | if pp_rank != self.encoder_runtime_mapping.rank: 895 | self.nccl_comm.send(encoder_output, pp_rank) 896 | return encoder_output 897 | else: 898 | self.nccl_comm.recv(encoder_output, 899 | self.encoder_runtime_mapping.pp_group[-1]) 900 | return encoder_output 901 | 902 | if self.encoder_runtime_mapping.has_pp(): 903 | # use hidden_states output buffer to receive output as the shapes are same 904 | encoder_output_buf = outputs[ 905 | 'encoder_output'] if self.encoder_runtime_mapping.is_last_pp_rank( 906 | ) else outputs['hidden_states_output'] 907 | encoder_output = pp_communicate_encoder_output(encoder_output_buf) 908 | else: 909 | encoder_output = outputs['encoder_output'] 910 | 911 | return encoder_output 912 | 913 | def generate(self, 914 | encoder_input_ids, 915 | decoder_input_ids, 916 | max_new_tokens, 917 | num_beams=1, 918 | pad_token_id=None, 919 | eos_token_id=None, 920 | bos_token_id=None, 921 | debug_mode=False, 922 | return_dict=False, 923 | prompt_embedding_table=None, 924 | prompt_tasks=None, 925 | prompt_vocab_size=None, 926 | attention_mask=None, 927 | time_encoder=False, 928 | return_encoder_output=False): 929 | ## ensure all externally provided tensors are on the correct device. 930 | encoder_input_ids = encoder_input_ids.to(self.device) 931 | decoder_input_ids = decoder_input_ids.to(self.device) 932 | 933 | if attention_mask is not None: 934 | attention_mask = torch.tensor(attention_mask, 935 | dtype=torch.int32, 936 | device=self.device) 937 | 938 | ## encoder run 939 | encoder_remove_input_padding = self.encoder_model_config.remove_input_padding if self.encoder_model_config else self.decoder_model_config.remove_input_padding 940 | 941 | encoder_input_ids, encoder_input_lengths, encoder_max_input_length, prompt_tasks = self.process_input( 942 | encoder_input_ids, encoder_remove_input_padding, pad_token_id, 943 | prompt_tasks) 944 | 945 | if not self.skip_encoder: 946 | #logger.info(f"Rank {self.runtime_rank} Running encoder engine ...") 947 | if time_encoder: 948 | tik = time.time() 949 | encoder_output = self.encoder_run( 950 | encoder_input_ids, 951 | encoder_input_lengths, 952 | encoder_max_input_length, 953 | debug_mode=debug_mode, 954 | prompt_embedding_table=prompt_embedding_table, 955 | prompt_tasks=prompt_tasks, 956 | prompt_vocab_size=prompt_vocab_size, 957 | attention_mask=attention_mask) 958 | if time_encoder: 959 | tok = time.time() 960 | print(f"TRT-LLM Encoder time {(tok-tik)*1000}ms") 961 | else: 962 | encoder_output = prompt_embedding_table 963 | if encoder_input_ids.dim() > 1: 964 | encoder_output = encoder_output.unsqueeze(0) 965 | 966 | ## decoder run 967 | # logger.info(f"Rank {self.runtime_rank} Running decoder engine ...") 968 | decoder_input_ids, decoder_input_lengths, decoder_max_input_length, _ = self.process_input( 969 | decoder_input_ids, self.decoder_model_config.remove_input_padding, 970 | pad_token_id) 971 | 972 | # `cross_attention_mask` in context phase [batch_size, query_len, encoder_input_len] 973 | # where query_len happens to be 1 in current cases, but not necessarily always, and 974 | # `cross_attention_mask` in generation phase [batch_size, 1, encoder_input_len] where 975 | # the query_len is always 1 since we have kv cache. 976 | cross_attention_mask = None 977 | if attention_mask is not None: 978 | cross_attention_mask = torch.tensor(attention_mask, 979 | dtype=torch.int32, 980 | device=self.device).reshape( 981 | attention_mask.shape[0], 1, 982 | attention_mask.shape[1]) 983 | 984 | # generation config 985 | sampling_config = SamplingConfig(end_id=eos_token_id, 986 | pad_id=pad_token_id, 987 | num_beams=num_beams, 988 | min_length=1, 989 | return_dict=return_dict) 990 | sampling_config.update(output_cum_log_probs=return_dict, 991 | output_log_probs=return_dict) 992 | 993 | # decoder autoregressive generation 994 | self.decoder_session.setup( 995 | decoder_input_lengths.size(0), 996 | decoder_max_input_length, 997 | max_new_tokens, 998 | num_beams, 999 | max_attention_window_size=None, 1000 | encoder_max_input_length=encoder_max_input_length, 1001 | lora_manager=self.decoder_lora_manager, 1002 | lora_uids=self.lora_task_uids, 1003 | ) 1004 | 1005 | output = self.decoder_session.decode( 1006 | decoder_input_ids, 1007 | decoder_input_lengths, 1008 | sampling_config, 1009 | encoder_output=encoder_output, 1010 | encoder_input_lengths=encoder_input_lengths, 1011 | return_dict=return_dict, 1012 | cross_attention_mask=cross_attention_mask) 1013 | 1014 | if return_dict and return_encoder_output: 1015 | output['encoder_output'] = encoder_output 1016 | 1017 | return output 1018 | -------------------------------------------------------------------------------- /tools/demo/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-Innovator/StructEqTable-Deploy/55649befcf880bc4fbf229a7b685c09c962d0dea/tools/demo/demo.png -------------------------------------------------------------------------------- /tools/demo/demo.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import argparse 4 | 5 | from PIL import Image 6 | from struct_eqtable import build_model 7 | 8 | 9 | def parse_config(): 10 | parser = argparse.ArgumentParser(description='arg parser') 11 | parser.add_argument('--image_path', type=str, default='demo.png', help='data path for table image') 12 | parser.add_argument('--ckpt_path', type=str, default='U4R/StructTable-InternVL2-1B', help='ckpt path for table model, which can be downloaded from huggingface') 13 | parser.add_argument('--max_new_tokens', type=int, default=1024, help='maximum output tokens of model inference') 14 | parser.add_argument('-t', '--max_waiting_time', type=int, default=60, help='maximum waiting time of model inference') 15 | parser.add_argument('-f', '--output_format', type=str, nargs='+', default=['latex'], 16 | help='The model outputs LaTeX format code by default. Simple structured table LaTeX code can be converted to HTML or Markdown format using pypandoc.') 17 | parser.add_argument('--tensorrt_path', type=str, default=None, help='enable tensorrt for model acceleration') 18 | parser.add_argument('--lmdeploy', action='store_true', help='use lmdepoly to accelerate model inference') 19 | parser.add_argument('--disable_flash_attn', action='store_true', help='disable flash attention for non ampere gpu') 20 | args = parser.parse_args() 21 | return args 22 | 23 | def main(): 24 | args = parse_config() 25 | 26 | # build model 27 | model = build_model( 28 | args.ckpt_path, 29 | max_new_tokens=args.max_new_tokens, 30 | max_time=args.max_waiting_time, 31 | tensorrt_path=args.tensorrt_path, 32 | lmdeploy=args.lmdeploy, 33 | flash_attn=not args.disable_flash_attn 34 | ) 35 | 36 | assert torch.cuda.is_available(), "Our model current only support with gpu" 37 | if not args.tensorrt_path: 38 | model = model.cuda() 39 | 40 | # process output format 41 | output_formats = list(set(args.output_format) & set(model.supported_output_format)) 42 | print(f"Supported output format: {' '.join(output_formats)}") 43 | 44 | # model inference 45 | raw_image = Image.open(args.image_path) 46 | 47 | output_list = [] 48 | start_time = time.time() 49 | 50 | with torch.no_grad(): 51 | for tgt_fmt in output_formats: 52 | output = model(raw_image, output_format=tgt_fmt) 53 | output_list.append(output) 54 | 55 | # show output latex code of table 56 | cost_time = time.time() - start_time 57 | print(f"total cost time: {cost_time:.2f}s") 58 | 59 | if cost_time >= args.max_waiting_time: 60 | warn_log = f"\033[93mThe model inference time exceeds the maximum waiting time {args.max_waiting_time} seconds, the result may be incomplete.\n" \ 61 | "Please increase the maximum waiting time with argument --max_waiting_time or Model may not support the type of input table image \033[0m" 62 | print(warn_log) 63 | 64 | for i, tgt_fmt in enumerate(output_formats): 65 | for j, output in enumerate(output_list[i]): 66 | print(f"Table {j} {tgt_fmt.upper()} format output:\n{output}") 67 | 68 | 69 | if __name__ == '__main__': 70 | main() 71 | -------------------------------------------------------------------------------- /tools/demo/demo.tex: -------------------------------------------------------------------------------- 1 | 2 | \documentclass[border=20pt]{standalone} 3 | \usepackage{blindtext}% 4 | \usepackage{subcaption} 5 | \usepackage{url} 6 | \usepackage{graphicx} 7 | \usepackage{caption} 8 | \usepackage{multirow} 9 | \usepackage{booktabs} 10 | \usepackage{color} 11 | \usepackage{colortbl} 12 | \usepackage{xcolor,soul,framed} 13 | \usepackage{xeCJK} 14 | %\usepackage{fontspec} 15 | %\usepackage[margin=1in]{geometry} 16 | \usepackage{printlen} 17 | \usepackage{amsmath,amssymb,mathtools,bm,mathrsfs,textcomp} 18 | \setlength{\parindent}{0pt} 19 | 20 | \begin{document} 21 | 22 | \begin{tabular}{|c|c|c|c|} 23 | \hline 24 | Quantity $\backslash$ Unit System & International System SI (kg-m-s) & Traditional aeronautical (lb-ft-s) & Traditional structural (lb-inch-s) \\ 25 | \hline 26 | Mass (translational inertia), $m$ & kilogram mass (kg) & slug = lb-s$^2$/f & lb-s$^2$/inch \\ 27 | \hline 28 | Length, translational motion & meter (m) & foot (ft) & inch (in.) \\ 29 | \hline 30 | Time, $t$ & second (s) & second (s) & second (s) \\ 31 | \hline 32 | Force, translational action & newton (N) = kg-m/s$^2$ & pound force (lb) & pound force (lb) \\ 33 | \hline 34 | Translational stiffness constant, $k$ & N/m & lb/ft & lb/inch \\ 35 | \hline 36 | Translational damping constant, $c$ & N/(m/s) = N-s/m & lb/(ft/s) = lb-s/ft & lb/(inch/s) = lb-s/inch \\ 37 | \hline 38 | Angle, rotational motion & radial (rad), which is dimensionless & radial (rad), which is dimensionless & radial (rad), which is dimensionless \\ 39 | \hline 40 | Rotational inertia, $J$ & kg-m$^2$ & slug-ft$^2$ = lb-s$^2$ - ft & lb-s$^2$ - inch \\ 41 | \hline 42 | Moment or torque, rotational action & N-m & lb-ft & lb-inch \\ 43 | \hline 44 | Rotational stiffness constant, $k_\theta$ & (N-m)/rad = N-m & (lb-ft)/rad = lb-ft & (lb-inch)/rad = lb-inch \\ 45 | \hline 46 | Rotational damping constant, $c_\theta$ & (N-m)/(rad/s) = N-m-s & (lb-ft)/(rad/s) = lb-ft-s & (lb-inch)/(rad/s) = lb-inch-s \\ 47 | \hline 48 | \end{tabular} 49 | 50 | \end{document} -------------------------------------------------------------------------------- /tools/scripts/build_tensorrt.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | HF_CKPT_PATH=${1:-"../ckpts/StructTable-base"} 4 | MODEL_OUTPUT=${2:-"../ckpts/StructTable-base-TensorRT"} 5 | MAX_IMAGE_TOKEN_NUM=${3:-2048} 6 | MAX_OUPTPUT_TOKEN_NUM=${4:-2048} 7 | MODEL_TYPE=${5:-"StructEqTable"} 8 | 9 | if [ ! -d $MODEL_OUTPUT ]; then 10 | mkdir -p $MODEL_OUTPUT 11 | fi 12 | 13 | # Step1 Convert the model into TensorrtLLM checkpoint format 14 | echo "Step1 Convert the model into TensorrtLLM checkpoint format" 15 | 16 | python tensorrt_utils/convert_checkpoint.py --model_type $MODEL_TYPE \ 17 | --model_dir $HF_CKPT_PATH \ 18 | --output_dir $MODEL_OUTPUT/trt_models/float16 \ 19 | --tp_size 1 \ 20 | --pp_size 1 \ 21 | --workers 1 \ 22 | --dtype float16 23 | 24 | # Step2 Compile the model 25 | echo "Step2 build LLM Engine" 26 | 27 | trtllm-build --checkpoint_dir $MODEL_OUTPUT/trt_models/float16/decoder \ 28 | --output_dir $MODEL_OUTPUT/llm_engines/decoder \ 29 | --paged_kv_cache disable \ 30 | --moe_plugin disable \ 31 | --enable_xqa disable \ 32 | --use_custom_all_reduce disable \ 33 | --gemm_plugin float16 \ 34 | --bert_attention_plugin float16 \ 35 | --gpt_attention_plugin float16 \ 36 | --remove_input_padding enable \ 37 | --context_fmha disable \ 38 | --max_beam_width 1 \ 39 | --max_batch_size 1 \ 40 | --max_seq_len $MAX_OUPTPUT_TOKEN_NUM \ 41 | --max_encoder_input_len $MAX_IMAGE_TOKEN_NUM \ 42 | --max_input_len 1 43 | 44 | # Step3 build visual engine 45 | echo "Step3 Build Visual Engine" 46 | 47 | python tensorrt_utils/build_visual_engine.py --model_type $MODEL_TYPE \ 48 | --model_path $HF_CKPT_PATH \ 49 | --output_dir $MODEL_OUTPUT/visual_engines \ 50 | --max_batch_size 1 51 | 52 | if [ -f './model.cache' ]; then 53 | rm ./model.cache 54 | fi 55 | 56 | echo "Build TensorRT model and Visual Engine Successfully" -------------------------------------------------------------------------------- /tools/tensorrt_utils/build_visual_engine.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import sys 5 | import tarfile 6 | from time import time 7 | 8 | import yaml 9 | 10 | # isort: off 11 | import torch 12 | import tensorrt as trt 13 | from tensorrt_llm.builder import Builder 14 | # isort: on 15 | import json 16 | import math 17 | 18 | import torch.nn.functional as F 19 | from PIL import Image 20 | from safetensors.torch import save_file 21 | from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM, 22 | AutoModelForVision2Seq, AutoProcessor, 23 | Blip2ForConditionalGeneration, Blip2Processor, 24 | FuyuForCausalLM, FuyuProcessor, 25 | LlavaForConditionalGeneration, NougatProcessor, 26 | Pix2StructForConditionalGeneration, 27 | VisionEncoderDecoderModel) 28 | 29 | 30 | def parse_arguments(): 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--model_type', 33 | type=str, 34 | default=None, 35 | choices=[ 36 | 'opt-2.7b', 'opt-6.7b', 'flan-t5-xl', 'flan-t5-xxl', 37 | 'llava', 'vila', 'nougat', 'cogvlm', 'fuyu', 'pix2struct', 38 | 'StructEqTable', 'neva', 'kosmos-2', 'video-neva', 39 | 'phi-3-vision' 40 | ], 41 | help="Model type") 42 | parser.add_argument( 43 | '--model_path', 44 | type=str, 45 | default=None, 46 | help= 47 | "Huggingface repo, local directory with weights or path to checkpoint file" 48 | ) 49 | parser.add_argument('--vila_path', 50 | type=str, 51 | default=None, 52 | help="Path to VILA source code directory") 53 | parser.add_argument('--output_dir', 54 | type=str, 55 | default=None, 56 | help="Directory where visual TRT engines are saved") 57 | parser.add_argument('--max_batch_size', 58 | type=int, 59 | default=4, 60 | help="Maximum batch size for input images") 61 | return parser.parse_args() 62 | 63 | 64 | class VisionEngineBuilder: 65 | 66 | def __init__(self, args): 67 | args.device = torch.device( 68 | "cuda") if torch.cuda.is_available() else "cpu" 69 | if args.output_dir is None: 70 | args.output_dir = 'visual_engines/%s' % ( 71 | args.model_path.split('/')[-1] if args.vila_path is not None 72 | else args.model_path.split('/')[-1]) 73 | if not os.path.exists(args.output_dir): 74 | os.makedirs(args.output_dir) 75 | 76 | self.args = args 77 | 78 | def build(self): 79 | args = self.args 80 | if 'opt' in args.model_type or 't5' in args.model_type: 81 | build_blip2_engine(args) 82 | elif args.model_type == 'pix2struct': 83 | build_pix2struct_engine(args) 84 | elif args.model_type == 'StructEqTable': 85 | build_StructEqTable_engine(args) 86 | elif args.model_type == 'llava': 87 | build_llava_engine(args) 88 | elif args.model_type == 'vila': 89 | assert args.vila_path is not None, "Please clone and provide VILA source code path" 90 | build_vila_engine(args) 91 | elif args.model_type == 'nougat': 92 | build_nougat_engine(args) 93 | elif args.model_type == 'cogvlm': 94 | build_cogvlm_engine(args) 95 | elif args.model_type == 'fuyu': 96 | build_fuyu_engine(args) 97 | elif args.model_type == 'neva': 98 | build_neva_engine(args) 99 | elif args.model_type == 'video-neva': 100 | build_video_neva_engine(args) 101 | elif args.model_type == 'kosmos-2': 102 | build_kosmos_engine(args) 103 | elif args.model_type == 'phi-3-vision': 104 | build_phi_engine(args) 105 | else: 106 | raise RuntimeError(f"Invalid model type {args.model_type}") 107 | 108 | 109 | def export_visual_wrapper_onnx(visual_wrapper, 110 | input, 111 | output_dir, 112 | input_names=['input'], 113 | dynamic_axes={'input': { 114 | 0: 'batch' 115 | }}): 116 | logger.log(trt.Logger.INFO, "Exporting onnx") 117 | os.makedirs(f'{output_dir}/onnx', exist_ok=True) 118 | torch.onnx.export(visual_wrapper, 119 | input, 120 | f'{output_dir}/onnx/visual_encoder.onnx', 121 | opset_version=17, 122 | input_names=input_names, 123 | output_names=['output'], 124 | dynamic_axes=dynamic_axes) 125 | 126 | 127 | def build_trt_engine(model_type, 128 | input_sizes, 129 | output_dir, 130 | max_batch_size, 131 | dtype=torch.float16, 132 | num_frames=None): 133 | part_name = 'visual_encoder' 134 | onnx_file = '%s/onnx/%s.onnx' % (output_dir, part_name) 135 | engine_file = '%s/%s.engine' % (output_dir, part_name) 136 | config_file = '%s/%s' % (output_dir, "config.json") 137 | logger.log(trt.Logger.INFO, "Building TRT engine for %s" % part_name) 138 | 139 | builder = trt.Builder(logger) 140 | network = builder.create_network( 141 | 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) 142 | profile = builder.create_optimization_profile() 143 | 144 | config_args = { 145 | "precision": str(dtype).split('.')[-1], 146 | "model_type": model_type 147 | } 148 | if num_frames is not None: 149 | config_args["num_frames"] = num_frames 150 | 151 | config_wrapper = Builder().create_builder_config(**config_args) 152 | config = config_wrapper.trt_builder_config 153 | 154 | parser = trt.OnnxParser(network, logger) 155 | 156 | with open(onnx_file, 'rb') as model: 157 | if not parser.parse(model.read(), os.path.abspath(onnx_file)): 158 | logger.log(trt.Logger.ERROR, "Failed parsing %s" % onnx_file) 159 | for error in range(parser.num_errors): 160 | logger.log(trt.Logger.ERROR, parser.get_error(error)) 161 | logger.log(trt.Logger.INFO, "Succeeded parsing %s" % onnx_file) 162 | 163 | # Delete onnx files since we don't need them now 164 | shutil.rmtree(f'{output_dir}/onnx') 165 | 166 | nBS = -1 167 | nMinBS = 1 168 | nOptBS = max(nMinBS, int(max_batch_size / 2)) 169 | nMaxBS = max_batch_size 170 | 171 | inputT = network.get_input(0) 172 | 173 | # input sizes can be a list of ints (e.g., [3, H, W]) when inputs are images, 174 | # or a list of three int lists (e.g., [[1, 1, 2700], [1, 500, 2700], [1, 4096, 2700]]). 175 | assert isinstance(input_sizes, list), "input_sizes must be a list" 176 | if isinstance(input_sizes[0], int): 177 | logger.log(trt.Logger.INFO, f"Processed input sizes {input_sizes}") 178 | inputT.shape = [nBS, *input_sizes] 179 | min_size = opt_size = max_size = input_sizes 180 | elif len(input_sizes) == 3 and isinstance(input_sizes[0], list): 181 | min_size, opt_size, max_size = input_sizes 182 | logger.log( 183 | trt.Logger.INFO, 184 | f"Processed min/opt/max input sizes {min_size}/{opt_size}/{max_size}" 185 | ) 186 | else: 187 | raise ValueError(f"invalid input sizes: {input_sizes}") 188 | 189 | profile.set_shape(inputT.name, [nMinBS, *min_size], [nOptBS, *opt_size], 190 | [nMaxBS, *max_size]) 191 | if model_type == "pix2struct" or model_type == "StructEqTable" : 192 | inputT = network.get_input(1) 193 | P = input_sizes[0] # Number of patches 194 | inputT.shape = [nBS, P] 195 | profile.set_shape(inputT.name, [nMinBS, P], [nOptBS, P], [nMaxBS, P]) 196 | config.add_optimization_profile(profile) 197 | 198 | t0 = time() 199 | engine_string = builder.build_serialized_network(network, config) 200 | t1 = time() 201 | if engine_string is None: 202 | raise RuntimeError("Failed building %s" % (engine_file)) 203 | else: 204 | logger.log(trt.Logger.INFO, 205 | "Succeeded building %s in %d s" % (engine_file, t1 - t0)) 206 | with open(engine_file, 'wb') as f: 207 | f.write(engine_string) 208 | 209 | Builder.save_config(config_wrapper, config_file) 210 | 211 | 212 | def build_blip2_engine(args): 213 | model_type = 'Salesforce/blip2-' + args.model_type 214 | processor = Blip2Processor.from_pretrained(model_type) 215 | 216 | raw_image = Image.new('RGB', [10, 10]) # dummy image 217 | prompt = "Question: what is this? Answer:" 218 | inputs = processor(raw_image, prompt, 219 | return_tensors="pt").to(args.device, torch.float16) 220 | image = inputs['pixel_values'] 221 | 222 | class Blip2VisionWrapper(torch.nn.Module): 223 | 224 | def __init__(self, vision_model, qformer, projector, query_tokens): 225 | super().__init__() 226 | self.vision_model = vision_model 227 | self.qformer = qformer 228 | self.projector = projector 229 | self.query_tokens = query_tokens 230 | 231 | def forward(self, image): 232 | features = self.vision_model(image)[0] 233 | qformer_output = self.qformer(query_embeds=self.query_tokens, 234 | encoder_hidden_states=features, 235 | return_dict=True) 236 | return self.projector(qformer_output.last_hidden_state) 237 | 238 | model = Blip2ForConditionalGeneration.from_pretrained( 239 | model_type, torch_dtype=torch.float16) 240 | wrapper = Blip2VisionWrapper(model.vision_model, model.qformer, 241 | model.language_projection, model.query_tokens) 242 | wrapper.to(args.device) 243 | 244 | export_visual_wrapper_onnx(wrapper, image, args.output_dir) 245 | build_trt_engine( 246 | model_type, 247 | [image.shape[1], image.shape[2], image.shape[3]], # [3, H, W] 248 | args.output_dir, 249 | args.max_batch_size) 250 | 251 | 252 | def build_pix2struct_engine(args): 253 | processor = AutoProcessor.from_pretrained(args.model_path) 254 | raw_image = Image.new('RGB', [10, 10]) # dummy image 255 | dtype = torch.float16 256 | inputs = processor(text="dummy", images=raw_image, return_tensors="pt", max_patches=processor.image_processor.max_patches) 257 | image = inputs['flattened_patches'].to(args.device, dtype) 258 | attention_mask = inputs['attention_mask'].to(args.device, torch.int) 259 | class pix2structVisionWrapper(torch.nn.Module): 260 | 261 | def __init__(self, encoder): 262 | super().__init__() 263 | self.encoder = encoder 264 | 265 | def forward(self, image, attention_mask): 266 | vision_x = self.encoder.embeddings(image) 267 | img_features = self.encoder.encoder(vision_x, 268 | attention_mask=attention_mask) 269 | img_features = self.encoder.layernorm(img_features[0]) 270 | return img_features 271 | 272 | model = Pix2StructForConditionalGeneration.from_pretrained( 273 | args.model_path, torch_dtype=dtype) 274 | 275 | wrapper = pix2structVisionWrapper(model.encoder.to(args.device)) 276 | # input shape: batch size, number of patches, hidden dimension 277 | # attention mask shape: batch size, number of patches 278 | # The number of image patches can vary depending on the image size, but it typically 279 | # falls within a relatively narrow range. To improve performance, we can avoid using 280 | # dynamic axis for the input patches and instead use a fixed number of patches along 281 | # with an attention mask. 282 | export_visual_wrapper_onnx(wrapper, (image, attention_mask), 283 | args.output_dir, 284 | input_names=['input', 'attention_mask'], 285 | dynamic_axes={ 286 | 'input': { 287 | 0: 'batch' 288 | }, 289 | 'attention_mask': { 290 | 0: 'batch' 291 | } 292 | }) 293 | build_trt_engine( 294 | args.model_type, 295 | [image.shape[1], image.shape[2]], # Number of Patches, Hidden Dimension 296 | args.output_dir, 297 | args.max_batch_size, 298 | torch.bfloat16) 299 | 300 | 301 | def build_StructEqTable_engine(args): 302 | processor = AutoProcessor.from_pretrained(args.model_path) 303 | raw_image = Image.new('RGB', [10, 10]) # dummy image 304 | dtype = torch.float16 305 | inputs = processor(text="dummy", images=raw_image, return_tensors="pt", max_patches=processor.image_processor.max_patches) 306 | image = inputs['flattened_patches'].to(args.device, dtype) 307 | attention_mask = inputs['attention_mask'].to(args.device, torch.int) 308 | class StructEqTableVisionWrapper(torch.nn.Module): 309 | 310 | def __init__(self, encoder): 311 | super().__init__() 312 | self.encoder = encoder 313 | 314 | def forward(self, image, attention_mask): 315 | vision_x = self.encoder.embeddings(image) 316 | img_features = self.encoder.encoder(vision_x, 317 | attention_mask=attention_mask) 318 | img_features = self.encoder.layernorm(img_features[0]) 319 | return img_features 320 | 321 | model = AutoModelForVision2Seq.from_pretrained( 322 | args.model_path, torch_dtype=dtype) 323 | 324 | wrapper = StructEqTableVisionWrapper(model.encoder.to(args.device)) 325 | # input shape: batch size, number of patches, hidden dimension 326 | # attention mask shape: batch size, number of patches 327 | # The number of image patches can vary depending on the image size, but it typically 328 | # falls within a relatively narrow range. To improve performance, we can avoid using 329 | # dynamic axis for the input patches and instead use a fixed number of patches along 330 | # with an attention mask. 331 | export_visual_wrapper_onnx(wrapper, (image, attention_mask), 332 | args.output_dir, 333 | input_names=['input', 'attention_mask'], 334 | dynamic_axes={ 335 | 'input': { 336 | 0: 'batch' 337 | }, 338 | 'attention_mask': { 339 | 0: 'batch' 340 | } 341 | }) 342 | build_trt_engine( 343 | args.model_type, 344 | [image.shape[1], image.shape[2]], # Number of Patches, Hidden Dimension 345 | args.output_dir, 346 | args.max_batch_size, 347 | torch.bfloat16) 348 | 349 | 350 | def build_llava_engine(args): 351 | processor = AutoProcessor.from_pretrained(args.model_path) 352 | raw_image = Image.new('RGB', [10, 10]) # dummy image 353 | image = processor(text="dummy", images=raw_image, 354 | return_tensors="pt")['pixel_values'].to( 355 | args.device, torch.float16) 356 | 357 | class LlavaVisionWrapper(torch.nn.Module): 358 | 359 | def __init__(self, tower, projector, feature_layer): 360 | super().__init__() 361 | self.tower = tower 362 | self.projector = projector 363 | self.feature_layer = feature_layer 364 | 365 | def forward(self, image): 366 | all_hidden_states = self.tower( 367 | image, output_hidden_states=True).hidden_states 368 | features = all_hidden_states[self.feature_layer][:, 1:] 369 | return self.projector(features) 370 | 371 | model = LlavaForConditionalGeneration.from_pretrained( 372 | args.model_path, torch_dtype=torch.float16) 373 | wrapper = LlavaVisionWrapper(model.vision_tower.to(args.device), 374 | model.multi_modal_projector.to(args.device), 375 | model.config.vision_feature_layer) 376 | 377 | export_visual_wrapper_onnx(wrapper, image, args.output_dir) 378 | build_trt_engine( 379 | args.model_type, 380 | [image.shape[1], image.shape[2], image.shape[3]], # [3, H, W] 381 | args.output_dir, 382 | args.max_batch_size) 383 | 384 | 385 | def build_vila_engine(args): 386 | # Note: VILA model is not in public HF model zoo yet. We need to explicitly import from the git repo 387 | sys.path.append(args.vila_path) 388 | from llava.model import LlavaLlamaConfig, LlavaLlamaModel # noqa 389 | from transformers import AutoModel 390 | model = AutoModel.from_pretrained( 391 | args.model_path, 392 | device_map='auto', 393 | ) 394 | 395 | vision_tower = model.get_vision_tower() 396 | image_processor = vision_tower.image_processor 397 | raw_image = Image.new('RGB', [10, 10]) # dummy image 398 | image = image_processor(images=raw_image, 399 | return_tensors="pt")['pixel_values'] 400 | if isinstance(image, list): 401 | image = image[0].unsqueeze(0) 402 | image = image.to(args.device, torch.float16) 403 | 404 | class VilaVisionWrapper(torch.nn.Module): 405 | 406 | def __init__(self, tower, projector): 407 | super().__init__() 408 | self.tower = tower 409 | self.projector = projector 410 | 411 | def forward(self, image): 412 | features = self.tower(image) 413 | return self.projector(features) 414 | 415 | model = AutoModel.from_pretrained( 416 | args.model_path, 417 | device_map='auto', 418 | ) 419 | wrapper = VilaVisionWrapper(model.get_vision_tower().to(args.device), 420 | model.mm_projector.to(args.device)) 421 | export_visual_wrapper_onnx(wrapper, image, args.output_dir) 422 | build_trt_engine( 423 | args.model_type, 424 | [image.shape[1], image.shape[2], image.shape[3]], # [3, H, W] 425 | args.output_dir, 426 | args.max_batch_size) 427 | 428 | 429 | def build_nougat_engine(args): 430 | processor = NougatProcessor.from_pretrained(args.model_path) 431 | raw_image = Image.new('RGB', [10, 10]) # dummy image 432 | image = processor(raw_image, return_tensors="pt")['pixel_values'].to( 433 | args.device, torch.float16) 434 | 435 | class SwinEncoderWrapper(torch.nn.Module): 436 | 437 | def __init__(self, encoder): 438 | super().__init__() 439 | self.encoder = encoder 440 | 441 | def forward(self, image): 442 | return self.encoder(image).last_hidden_state 443 | 444 | model = VisionEncoderDecoderModel.from_pretrained(args.model_path, 445 | torch_dtype=torch.float16) 446 | swin_encoder = model.get_encoder().to(args.device) 447 | wrapper = SwinEncoderWrapper(swin_encoder) 448 | 449 | export_visual_wrapper_onnx(wrapper, image, args.output_dir) 450 | build_trt_engine( 451 | args.model_type, 452 | [image.shape[1], image.shape[2], image.shape[3]], # [3, H, W] 453 | args.output_dir, 454 | args.max_batch_size) 455 | 456 | 457 | def build_cogvlm_engine(args): 458 | hf_config = AutoConfig.from_pretrained(args.model_path, 459 | trust_remote_code=True) 460 | image_size = hf_config.vision_config['image_size'] 461 | dtype = hf_config.torch_dtype 462 | image = torch.empty(1, 463 | 3, 464 | image_size, 465 | image_size, 466 | dtype=dtype, 467 | device=args.device) # dummy image 468 | 469 | class CogVlmVisionWrapper(torch.nn.Module): 470 | 471 | def __init__(self, encoder): 472 | super().__init__() 473 | self.encoder = encoder 474 | 475 | def forward(self, image): 476 | return self.encoder(image) 477 | 478 | cogvlm = AutoModelForCausalLM.from_pretrained(args.model_path, 479 | torch_dtype=dtype, 480 | trust_remote_code=True) 481 | vit_encoder = cogvlm.model.vision.to(args.device).eval() 482 | 483 | wrapper = CogVlmVisionWrapper(vit_encoder) 484 | export_visual_wrapper_onnx(wrapper, image, args.output_dir) 485 | build_trt_engine( 486 | args.model_type, 487 | [image.shape[1], image.shape[2], image.shape[3]], # [3, H, W] 488 | args.output_dir, 489 | args.max_batch_size, 490 | dtype) 491 | 492 | 493 | def build_fuyu_engine(args): 494 | processor = FuyuProcessor.from_pretrained(args.model_path) 495 | raw_image = Image.new('RGB', [10, 10]) 496 | image = processor(text="dummy", images=raw_image, 497 | return_tensors="pt")['image_patches'][0].to( 498 | args.device, torch.float16).unsqueeze(0) 499 | 500 | class FuyuEncoderWrapper(torch.nn.Module): 501 | 502 | def __init__(self, linear): 503 | super().__init__() 504 | self.linear = linear.to(torch.float16) 505 | 506 | def forward(self, patches): 507 | return self.linear(patches).flatten(0, 1) 508 | 509 | model = FuyuForCausalLM.from_pretrained(args.model_path, 510 | torch_dtype=torch.float16) 511 | 512 | vision_encoder = model.vision_embed_tokens 513 | wrapper = FuyuEncoderWrapper(vision_encoder).to(args.device) 514 | 515 | export_visual_wrapper_onnx(wrapper, 516 | image, 517 | args.output_dir, 518 | dynamic_axes={'input': { 519 | 0: 'batch', 520 | 2: 'patch' 521 | }}) 522 | build_trt_engine( 523 | args.model_type, 524 | # [nImgs, nImgPatches, nDims] 525 | # nImgs is always one since each query has exactly one image 526 | # nImgPatches depends on image size (patch size: 30x30) 527 | # nDims is 30x30x3=2700 (patch size x color channels) 528 | [[1, 1, 2700], [1, 500, 2700], [1, 4096, 2700]], 529 | args.output_dir, 530 | args.max_batch_size) 531 | 532 | 533 | def build_neva_engine(args): 534 | # extract NeMo checkpoint 535 | with tarfile.open(args.model_path) as tar: 536 | nemo_config = yaml.safe_load(tar.extractfile("./model_config.yaml")) 537 | try: 538 | # trained without TP 539 | mp0_weights = torch.load(tar.extractfile("./model_weights.ckpt"), 540 | map_location=args.device) 541 | except KeyError: 542 | # trained with TP 543 | mp0_weights = torch.load( 544 | tar.extractfile("./mp_rank_00/model_weights.ckpt"), 545 | map_location=args.device) 546 | 547 | vision_config = nemo_config["mm_cfg"]["vision_encoder"] 548 | 549 | class VisionEncoderWrapper(torch.nn.Module): 550 | 551 | def __init__(self, encoder, connector): 552 | super().__init__() 553 | self.encoder = encoder 554 | self.connector = connector 555 | 556 | def forward(self, images): 557 | vision_x = self.encoder(pixel_values=images, 558 | output_hidden_states=True) 559 | vision_x = vision_x.hidden_states[-2] 560 | vision_x = vision_x[:, 1:] 561 | vision_x = self.connector(vision_x) 562 | return vision_x 563 | 564 | encoder = AutoModel.from_pretrained(vision_config["from_pretrained"], 565 | torch_dtype=torch.bfloat16, 566 | trust_remote_code=True) 567 | vision_encoder = encoder.vision_model 568 | hf_config = encoder.config 569 | dtype = hf_config.torch_dtype 570 | 571 | # connector 572 | assert nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp2x_gelu" 573 | vision_connector = torch.nn.Sequential( 574 | torch.nn.Linear(vision_config["hidden_size"], 575 | nemo_config["hidden_size"], 576 | bias=True), torch.nn.GELU(), 577 | torch.nn.Linear(nemo_config["hidden_size"], 578 | nemo_config["hidden_size"], 579 | bias=True)).to(dtype=dtype) 580 | 581 | key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector" 582 | for layer in range(0, 3, 2): 583 | vision_connector[layer].load_state_dict({ 584 | 'weight': 585 | mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype), 586 | 'bias': 587 | mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype), 588 | }) 589 | 590 | # export the whole wrapper 591 | wrapper = VisionEncoderWrapper(vision_encoder, 592 | vision_connector).to(args.device, dtype) 593 | image_size = hf_config.vision_config.image_size 594 | dummy_image = torch.empty( 595 | 1, 3, image_size, image_size, dtype=dtype, 596 | device=args.device) # dummy image shape [B, C, H, W] 597 | export_visual_wrapper_onnx(wrapper, dummy_image, args.output_dir) 598 | build_trt_engine( 599 | args.model_type, 600 | [3, image_size, image_size], # [3, H, W] 601 | args.output_dir, 602 | args.max_batch_size, 603 | dtype) 604 | 605 | 606 | def build_video_neva_engine(args): 607 | # extract NeMo checkpoint 608 | with tarfile.open(args.model_path) as tar: 609 | nemo_config = yaml.safe_load(tar.extractfile("./model_config.yaml")) 610 | try: 611 | # trained without TP 612 | mp0_weights = torch.load(tar.extractfile("./model_weights.ckpt"), 613 | map_location=args.device) 614 | except KeyError: 615 | # trained with TP 616 | mp0_weights = torch.load( 617 | tar.extractfile("./mp_rank_00/model_weights.ckpt"), 618 | map_location=args.device) 619 | 620 | vision_config = nemo_config["mm_cfg"]["vision_encoder"] 621 | 622 | class VisionEncoderWrapper(torch.nn.Module): 623 | 624 | def __init__(self, encoder, connector): 625 | super().__init__() 626 | self.encoder = encoder 627 | self.connector = connector 628 | 629 | def forward(self, images): 630 | b, num_frames, c, h, w = images.shape 631 | images = images.view(b * num_frames, c, h, w) 632 | vision_x = self.encoder( 633 | pixel_values=images, #[(B num_frames), C, H, W] 634 | output_hidden_states=True) 635 | vision_x = vision_x.hidden_states[-2] 636 | vision_x = vision_x[:, 1:] 637 | 638 | # reshape back to [B, num_frames, img_size, hidden_size] 639 | vision_x = vision_x.view(b, num_frames, -1, vision_x.shape[-1]) 640 | 641 | vision_x = self.connector(vision_x) 642 | return vision_x 643 | 644 | encoder = AutoModel.from_pretrained(vision_config["from_pretrained"], 645 | torch_dtype=torch.bfloat16, 646 | trust_remote_code=True) 647 | vision_encoder = encoder.vision_model 648 | hf_config = encoder.config 649 | dtype = hf_config.torch_dtype 650 | 651 | # connector 652 | assert nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "linear" 653 | vision_connector = torch.nn.Linear(vision_config["hidden_size"], 654 | nemo_config["hidden_size"], 655 | bias=True) 656 | 657 | key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector" 658 | vision_connector.load_state_dict({ 659 | 'weight': 660 | mp0_weights[f"{key_prefix}.weight"].to(dtype), 661 | 'bias': 662 | mp0_weights[f"{key_prefix}.bias"].to(dtype), 663 | }) 664 | 665 | # export the whole wrapper 666 | wrapper = VisionEncoderWrapper(vision_encoder, 667 | vision_connector).to(args.device, dtype) 668 | image_size = hf_config.vision_config.image_size 669 | num_frames = nemo_config['data']['num_frames'] 670 | dummy_video = torch.empty(1, 671 | num_frames, 672 | 3, 673 | image_size, 674 | image_size, 675 | dtype=dtype, 676 | device=args.device) # dummy image 677 | export_visual_wrapper_onnx(wrapper, dummy_video, args.output_dir) 678 | build_trt_engine( 679 | args.model_type, 680 | [num_frames, 3, image_size, image_size], # [num_frames, 3, H, W] 681 | args.output_dir, 682 | args.max_batch_size, 683 | dtype, 684 | num_frames=num_frames) 685 | 686 | 687 | def build_kosmos_engine(args): 688 | processor = AutoProcessor.from_pretrained(args.model_path) 689 | raw_image = Image.new('RGB', [10, 10]) # dummy image 690 | image = processor(text="dummy", images=raw_image, 691 | return_tensors="pt")['pixel_values'].to( 692 | args.device, torch.float16) 693 | 694 | class VisionEncoderWrapper(torch.nn.Module): 695 | 696 | def __init__(self, encoder, connector): 697 | super().__init__() 698 | self.encoder = encoder 699 | self.connector = connector 700 | 701 | def forward(self, images): 702 | vision_x = self.encoder(images, output_hidden_states=True) 703 | img_features = self.encoder.model.post_layernorm( 704 | vision_x.last_hidden_state) 705 | img_features = F.normalize(img_features, dim=-1) 706 | img_features, _ = self.connector(img_features) 707 | return img_features 708 | 709 | model = AutoModelForVision2Seq.from_pretrained(args.model_path, 710 | torch_dtype=torch.float16) 711 | wrapper = VisionEncoderWrapper( 712 | model.vision_model.to(args.device), 713 | model.image_to_text_projection.to(args.device)) 714 | 715 | export_visual_wrapper_onnx(wrapper, image, args.output_dir) 716 | build_trt_engine( 717 | args.model_type, 718 | [image.shape[1], image.shape[2], image.shape[3]], # [3, H, W] 719 | args.output_dir, 720 | args.max_batch_size) 721 | 722 | 723 | def build_phi_engine(args): 724 | processor = AutoProcessor.from_pretrained(args.model_path, 725 | trust_remote_code=True) 726 | raw_image = Image.new('RGB', [10, 10]) # dummy image 727 | image = processor(text="<|image_1|>\ndummy", 728 | images=raw_image, 729 | return_tensors="pt")['pixel_values'].to( 730 | args.device, torch.float16) 731 | try: 732 | with open(f"{args.model_path}/preprocessor_config.json", "r") as file: 733 | config = file.read() 734 | config_dict = json.loads(config) 735 | num_crops = config_dict.get("num_crops") 736 | except: 737 | num_crops = 16 738 | 739 | class Phi3VisionWrapper(torch.nn.Module): 740 | 741 | def __init__(self, img_processor, img_projection, layer_idx, 742 | image_dim_out): 743 | super().__init__() 744 | self.img_processor = img_processor 745 | self.img_projection = img_projection 746 | self.layer_idx = layer_idx 747 | self.image_dim_out = image_dim_out 748 | 749 | def get_img_features( 750 | self, img_embeds: torch.FloatTensor) -> torch.FloatTensor: 751 | LAYER_IDX = self.layer_idx 752 | 753 | img_processor_output = self.img_processor(img_embeds, 754 | output_hidden_states=True) 755 | img_feature = img_processor_output.hidden_states[LAYER_IDX] 756 | 757 | patch_feature = img_feature[:, 1:] 758 | return patch_feature 759 | 760 | def forward(self, image): 761 | img_features = self.get_img_features(image) 762 | base_feat_height = int(math.sqrt(img_features.shape[1])) 763 | C = self.image_dim_out 764 | H = base_feat_height 765 | img_features = img_features.reshape(-1, H, H, C).reshape( 766 | -1, H // 2, 2, H // 2, 2, 767 | C).contiguous().permute(0, 1, 3, 2, 4, 768 | 5).reshape(-1, H // 2, H // 2, 769 | 4 * C).contiguous() 770 | return self.apply_img_projection(img_features) 771 | 772 | def apply_img_projection(self, input): 773 | return self.img_projection(input) 774 | 775 | model = AutoModelForCausalLM.from_pretrained(args.model_path, 776 | torch_dtype=torch.float16, 777 | trust_remote_code=True).to( 778 | args.device) 779 | 780 | wrapper = Phi3VisionWrapper(model.model.vision_embed_tokens.img_processor, 781 | model.model.vision_embed_tokens.img_projection, 782 | model.model.vision_embed_tokens.layer_idx, 783 | model.model.vision_embed_tokens.image_dim_out) 784 | image = image.flatten(0, 1) 785 | glb_GN = wrapper.apply_img_projection( 786 | model.model.vision_embed_tokens.glb_GN) 787 | sub_GN = wrapper.apply_img_projection( 788 | model.model.vision_embed_tokens.sub_GN) 789 | tensors = {"glb_GN": glb_GN, "sub_GN": sub_GN} 790 | save_file(tensors, args.output_dir + "/image_newlines.safetensors") 791 | export_visual_wrapper_onnx(wrapper, image, args.output_dir) 792 | build_trt_engine( 793 | args.model_type, 794 | [image.shape[1], image.shape[2], image.shape[3]], args.output_dir, 795 | args.max_batch_size * (num_crops + 1)) #TODO: Take input from config 796 | 797 | 798 | if __name__ == '__main__': 799 | logger = trt.Logger(trt.Logger.INFO) 800 | args = parse_arguments() 801 | builder = VisionEngineBuilder(args) 802 | builder.build() 803 | -------------------------------------------------------------------------------- /tools/tensorrt_utils/helper.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from typing import Union 3 | 4 | import numpy as np 5 | import torch # pytype: disable=import-error 6 | 7 | from tensorrt_llm._utils import str_dtype_to_torch 8 | 9 | 10 | def split(v: Union[np.ndarray, torch.Tensor], 11 | tp_size: int, 12 | tp_rank: int, 13 | dim=0): 14 | if tp_size == 1: 15 | if isinstance(v, np.ndarray): 16 | return np.ascontiguousarray(v.copy()) 17 | else: 18 | return v.clone().detach() 19 | assert len(v.shape) > 1 or dim == 0 20 | if isinstance(v, np.ndarray): 21 | return np.ascontiguousarray( 22 | np.split(v, tp_size, axis=dim)[tp_rank].copy()) 23 | else: 24 | assert v.shape[dim] % tp_size == 0, \ 25 | 'Unable to split: shape={v.shape} (dim={dim}) tp_size={tp_size}.' 26 | split_size = v.shape[dim] // tp_size 27 | return v.split(split_size, dim=dim)[tp_rank].clone().detach() 28 | 29 | 30 | def reshape(v: torch.Tensor, shape=None): 31 | if shape is None: 32 | return v.contiguous() 33 | else: 34 | return v.reshape(shape).contiguous() 35 | 36 | 37 | def fuse_qkv_one_layer(params, attn_module_name, trtllm_layer_name, tp_size, 38 | tp_rank, model_type, weight_shape, bias_shape): 39 | 40 | qkv_module_names = get_qkv_module_name(model_type) 41 | 42 | weight = {} 43 | 44 | # fuse weights of q, k, v 45 | q_w = params[f'{attn_module_name}.{qkv_module_names["q"]}.weight'] 46 | k_w = params[f'{attn_module_name}.{qkv_module_names["k"]}.weight'] 47 | v_w = params[f'{attn_module_name}.{qkv_module_names["v"]}.weight'] 48 | 49 | # fuse qkv weight 50 | shape = q_w.shape # (do, din) 51 | qkv_w = torch.cat([q_w, k_w, v_w], 52 | dim=0).reshape([3, shape[0], shape[1]]) # (3, do, din) 53 | qkv_w = split(qkv_w, tp_size, tp_rank, dim=1) 54 | weight[f'{trtllm_layer_name}.qkv.weight'] = reshape(qkv_w, 55 | shape=weight_shape) 56 | 57 | # fuse qkv biases if present 58 | if f'{attn_module_name}.{qkv_module_names["q"]}.bias' in params.keys( 59 | ) and params[f'{attn_module_name}.{qkv_module_names["q"]}.bias'] is not None: 60 | q_b = params[f'{attn_module_name}.{qkv_module_names["q"]}.bias'] 61 | k_b = params[f'{attn_module_name}.{qkv_module_names["k"]}.bias'] 62 | v_b = params[f'{attn_module_name}.{qkv_module_names["v"]}.bias'] 63 | shape = q_b.shape[0] # (do,) 64 | qkv_b = torch.cat([q_b, k_b, v_b], dim=0).reshape([3, shape]) # (3, do) 65 | qkv_b = split(qkv_b, tp_size, tp_rank, dim=1) 66 | weight[f'{trtllm_layer_name}.qkv.bias'] = reshape(qkv_b, 67 | shape=bias_shape) 68 | return weight 69 | 70 | 71 | def get_qkv_module_name(model_type): 72 | if model_type == "t5": 73 | q = "q" 74 | k = "k" 75 | v = "v" 76 | elif model_type == "bart" or model_type == "nmt": 77 | q = "q_proj" 78 | k = "k_proj" 79 | v = "v_proj" 80 | elif model_type == "pix2struct": 81 | q = "query" 82 | k = "key" 83 | v = "value" 84 | elif model_type == "StructEqTable": 85 | q = "query" 86 | k = "key" 87 | v = "value" 88 | return {"q": q, "k": k, "v": v} 89 | 90 | 91 | def convert_weight_to_dtype(params: typing.Dict[str, torch.Tensor], 92 | dtype: typing.Optional[np.dtype] = None): 93 | if dtype is not None: 94 | assert isinstance(dtype, 95 | str), f"dtype must be str, but get type {type(dtype)}" 96 | for name in params.keys(): 97 | params[name] = params[name].to(str_dtype_to_torch(dtype)) 98 | --------------------------------------------------------------------------------