├── .github ├── FUNDING.yml ├── ISSUE_TEMPLATE │ ├── 01-feature_request.md │ ├── 02-bug.md │ └── 03-blank.md └── workflows │ ├── SyncToGitee.yml │ └── publish_whl.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── cliff.toml ├── demo.py ├── docs └── doc_whl_rapid_table.md ├── rapid_table ├── __init__.py ├── default_models.yaml ├── engine_cfg.yaml ├── inference_engine │ ├── __init__.py │ ├── base.py │ ├── onnxruntime │ │ ├── __init__.py │ │ ├── main.py │ │ └── provider_config.py │ └── torch.py ├── main.py ├── model_processor │ ├── __init__.py │ └── main.py ├── models │ └── .gitkeep ├── table_matcher │ ├── __init__.py │ ├── main.py │ └── utils.py ├── table_structure │ ├── __init__.py │ ├── pp_structure │ │ ├── __init__.py │ │ ├── main.py │ │ ├── post_process.py │ │ └── pre_process.py │ ├── unitable │ │ ├── __init__.py │ │ ├── consts.py │ │ ├── main.py │ │ ├── post_process.py │ │ ├── pre_process.py │ │ └── unitable_modules.py │ └── utils.py └── utils │ ├── __init__.py │ ├── download_file.py │ ├── load_image.py │ ├── logger.py │ ├── typings.py │ ├── utils.py │ └── vis.py ├── requirements.txt ├── setup.py └── tests ├── test_files ├── table.jpg └── table_without_txt.jpg ├── test_main.py └── test_table_matcher.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: https://raw.githubusercontent.com/RapidAI/.github/6db6b6b9273f3151094a462a61fbc8e88564562c/assets/Sponsor.png 14 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/01-feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature Request 3 | about: requests for new RapidOCR features 4 | title: 'Feature Request' 5 | labels: 'Feature Request' 6 | assignees: '' 7 | 8 | --- 9 | 10 | 请您详细描述想要添加的新功能或者是新特性 11 | (Please describe in detail the new function or new feature you want to add) 12 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/02-bug.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug 3 | about: Bug 4 | title: 'Bug' 5 | labels: 'Bug' 6 | assignees: '' 7 | 8 | --- 9 | 10 | 请提供下述完整信息以便快速定位问题 11 | (Please provide the following information to quickly locate the problem) 12 | - **系统环境/System Environment**: 13 | - **使用的是哪门语言的程序/Which programing language**: 14 | - **所使用语言相关版本信息/Version**: 15 | - **OnnxRuntime版本/OnnxRuntime Version**: 16 | - **可复现问题的demo/Demo of reproducible problems**: 17 | - **完整报错/Complete Error Message**: 18 | - **可能的解决方案/Possible solutions**: -------------------------------------------------------------------------------- 
/.github/ISSUE_TEMPLATE/03-blank.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Blank Template 3 | about: Blank Template 4 | title: 'Blank Template' 5 | labels: 'Blank Template' 6 | assignees: '' 7 | 8 | --- -------------------------------------------------------------------------------- /.github/workflows/SyncToGitee.yml: -------------------------------------------------------------------------------- 1 | name: SyncToGitee 2 | on: 3 | push: 4 | branches: 5 | - main 6 | jobs: 7 | repo-sync: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - name: Checkout source codes 11 | uses: actions/checkout@v3 12 | 13 | - name: Mirror the Github organization repos to Gitee. 14 | uses: Yikun/hub-mirror-action@v1.4 15 | with: 16 | src: 'github/RapidAI' 17 | dst: 'gitee/RapidAI' 18 | dst_key: ${{ secrets.GITEE_PRIVATE_KEY }} 19 | dst_token: ${{ secrets.GITEE_TOKEN }} 20 | force_update: true 21 | static_list: "RapidTable" 22 | debug: true 23 | -------------------------------------------------------------------------------- /.github/workflows/publish_whl.yml: -------------------------------------------------------------------------------- 1 | name: Push rapidocr_table to pypi 2 | 3 | on: 4 | push: 5 | tags: 6 | - v* 7 | 8 | env: 9 | DEFAULT_MODEL: https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/slanet-plus.onnx 10 | 11 | jobs: 12 | UnitTesting: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Pull latest code 16 | uses: actions/checkout@v3 17 | 18 | - name: Set up Python 3.10 19 | uses: actions/setup-python@v4 20 | with: 21 | python-version: '3.10' 22 | architecture: 'x64' 23 | 24 | - name: Display Python version 25 | run: python -c "import sys; print(sys.version)" 26 | 27 | - name: Unit testings 28 | run: | 29 | wget $DEFAULT_MODEL -P rapid_table/models 30 | 31 | pip install -r requirements.txt 32 | pip install rapidocr onnxruntime torch torchvision tokenizers pytest 33 | pytest tests/*.py 34 | 35 | GenerateWHL_PushPyPi: 36 | needs: UnitTesting 37 | runs-on: ubuntu-latest 38 | 39 | steps: 40 | - uses: actions/checkout@v3 41 | 42 | - name: Set up Python 3.10 43 | uses: actions/setup-python@v4 44 | with: 45 | python-version: '3.10' 46 | architecture: 'x64' 47 | 48 | - name: Run setup 49 | run: | 50 | pip install -r requirements.txt 51 | python -m pip install --upgrade pip 52 | pip install wheel get_pypi_latest_version 53 | 54 | wget $DEFAULT_MODEL -P rapid_table/models 55 | python setup.py bdist_wheel ${{ github.ref_name }} 56 | 57 | - name: Publish distribution 📦 to PyPI 58 | uses: pypa/gh-action-pypi-publish@v1.5.0 59 | with: 60 | password: ${{ secrets.RAPID_TABLE }} 61 | packages_dir: dist/ 62 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | outputs/ 2 | *.json 3 | 4 | # Created by .ignore support plugin (hsz.mobi) 5 | ### Python template 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | .pytest_cache 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before 
PyInstaller builds the exe, so as to inject date/other infos into it. 39 | # *.manifest 40 | # *.spec 41 | *.res 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | 138 | #idea 139 | .vs 140 | .vscode 141 | .idea 142 | /images 143 | /models 144 | 145 | #models 146 | *.onnx 147 | 148 | *.ttf 149 | *.ttc 150 | 151 | long1.jpg 152 | 153 | *.bin 154 | *.mapping 155 | *.xml 156 | 157 | *.pdiparams 158 | *.pdiparams.info 159 | *.pdmodel 160 | 161 | .DS_Store 162 | *.pth 163 | /rapid_table_torch/models/*.pth 164 | /rapid_table_torch/models/*.json 165 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://gitee.com/SWHL/autoflake 3 | rev: v2.1.1 4 | hooks: 5 | - id: autoflake 6 | args: 7 | [ 8 | "--recursive", 9 | "--in-place", 10 | "--remove-all-unused-imports", 11 | "--ignore-init-module-imports", 12 | ] 13 | files: \.py$ 14 | - repo: https://gitee.com/SWHL/black 15 | rev: 23.1.0 16 | hooks: 17 | - id: black 18 | files: \.py$ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2025 RapidAI 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
# 📊 Rapid Table

(header badges omitted; alt texts: "PyPI", "SemVer2.0")
### 🌟 Introduction

The RapidTable library is dedicated to restoring the table structure in document images. All of its table-structure models are sequence-prediction methods; combined with RapidOCR, it converts the tables in a given image into the corresponding HTML.

slanet_plus is an upgraded version of the SLANet model built into PaddleX, with a substantial accuracy improvement.

unitable is a transformer model from UniTable with the highest accuracy. It currently supports PyTorch inference only, with GPU acceleration; its trained weights come from the [OhMyTable project](https://github.com/Sanster/OhMyTable).

### 📅 Recent Updates

2025-08-29 update: Released v3.0.0 with batch inference support; the return-value fields changed, so debug a single image to inspect the new structure \
2025-06-22 update: Released v2.x, adapted to rapidocr v3.x \
2025-01-09 update: Released v1.x with a fully reworked interface \
2024.12.30 update: Added table recognition with the Unitable model, using the PyTorch framework \
2024.11.24 update: Added GPU inference, adapted RapidOCR single-character recognition matching, and added logical-coordinate output and visualization \
2024.10.13 update: Added the latest paddlex-SLANet-plus model (ONNX not yet supported due to paddle2onnx)

### 📸 Showcase
(demo GIF omitted; alt text: "Demo")
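A minimal end-to-end sketch (it mirrors the examples in the sections below and assumes both `rapid_table` and `rapidocr` are installed):

```python
from rapid_table import ModelType, RapidTable, RapidTableInput

# slanet_plus is bundled inside the wheel, so no extra model download is needed
input_args = RapidTableInput(model_type=ModelType.SLANETPLUS)
table_engine = RapidTable(input_args)

img_path = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"
results = table_engine(img_path)

print(results.pred_htmls)  # one HTML string per input image
results.vis(save_dir="outputs", save_name="vis")
```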
### 🖥️ Supported Devices

Supported through the ONNXRuntime inference engine (`rapid_table>=2.0.0`):

- DirectML
- Ascend NPU

How to use them:

1. Installation (any other onnxruntime package must be uninstalled first):

```bash
# DirectML
pip install onnxruntime-directml

# Ascend NPU
pip install onnxruntime-cann
```

2. Usage:

```python
from rapidocr import RapidOCR

from rapid_table import ModelType, RapidTable, RapidTableInput

# DirectML
ocr_engine = RapidOCR(params={"EngineConfig.onnxruntime.use_dml": True})
input_args = RapidTableInput(
    model_type=ModelType.SLANETPLUS, engine_cfg={"use_dml": True}
)

# Ascend NPU
ocr_engine = RapidOCR(params={"EngineConfig.onnxruntime.use_cann": True})

input_args = RapidTableInput(
    model_type=ModelType.SLANETPLUS,
    engine_cfg={"use_cann": True, "cann_ep_cfg.gpu_id": 1},
)

table_engine = RapidTable(input_args)

img_path = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"

ori_ocr_res = ocr_engine(img_path)
ocr_results = [ori_ocr_res.boxes, ori_ocr_res.txts, ori_ocr_res.scores]

results = table_engine(img_path, ocr_results=ocr_results)
results.vis(save_dir="outputs", save_name="vis")
```

### 🧩 Model List

| `model_type` | Model name | Inference engine | Model size | Inference time (single image, 60 KB) |
|:--------------|:--------------------------------------| :------: |:------ |:------ |
| `ppstructure_en` | `en_ppstructure_mobile_v2_SLANet.onnx` | onnxruntime |7.3M |0.15s |
| `ppstructure_zh` | `ch_ppstructure_mobile_v2_SLANet.onnx` | onnxruntime |7.4M |0.15s |
| `slanet_plus` | `slanet-plus.onnx` | onnxruntime |6.8M |0.15s |
| `unitable` | `unitable(encoder.pth,decoder.pth)` | pytorch |500M |cpu(6s) gpu-4090(1.5s)|

Model sources\
[PaddleOCR table recognition](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/table/README_ch.md)\
[PaddleX-SlaNetPlus table recognition](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/module_usage/tutorials/ocr_modules/table_structure_recognition.md)\
[Unitable](https://github.com/poloclub/unitable?tab=readme-ov-file)

Model downloads: [link](https://www.modelscope.cn/models/RapidAI/RapidTable/files)

### 🛠️ Installation

The version compatibility is as follows:

|`rapid_table`|OCR|
|:---:|:---|
|v2.x & v3.x |`rapidocr>=3.0.0`|
|v1.0.x|`rapidocr>=2.0.0,<3.0.0`|
|v0.x|`rapidocr_onnxruntime`|

Because the model is small, the slanet-plus table recognition model (`slanet-plus.onnx`) is bundled directly into the whl package. When the `RapidTable` class is initialized, the other models are downloaded automatically, according to `model_type`, into the `models` directory of the installed package.

You can also point to your own model path via `RapidTableInput(model_path='')` (in `v1.0.x` the parameter is named `model_path`; in `v2.x` it was renamed to `model_dir_or_path`). Note that this only applies to the `model_type` values currently supported.

> ⚠️ Note: since `rapid_table>=v1.0.0`, the `rapidocr` dependency is no longer bundled into `rapid_table`. You need to install the `rapidocr` package yourself before use.
>
> ⚠️ Note: for `rapid_table>=v0.1.0,<1.0.0`, the `rapidocr` dependency is no longer bundled into `rapid_table`. You need to install the `rapidocr_onnxruntime` package yourself before use.

```bash
pip install rapidocr
pip install rapid_table

# To run the unitable model with torch
pip install rapid_table[torch]  # for unitable inference

# onnxruntime-gpu inference
pip uninstall onnxruntime
pip install onnxruntime-gpu  # for onnx gpu inference
```

### 🚀 Usage

#### 🐍 Running as a Python script

> ⚠️ Note: since `rapid_table>=1.0.0`, model inputs are wrapped in dataclasses to simplify parameter passing while staying compatible. The input and output definitions are as follows:

`ModelType` supports the four existing models 
([source](./rapid_table/utils/typings.py)):

```python
class ModelType(Enum):
    PPSTRUCTURE_EN = "ppstructure_en"  # onnxruntime
    PPSTRUCTURE_ZH = "ppstructure_zh"  # onnxruntime
    SLANETPLUS = "slanet_plus"  # onnxruntime
    UNITABLE = "unitable"  # torch inference engine
```

#### Batch inference

```python
from pathlib import Path

from rapid_table import ModelType, RapidTable, RapidTableInput

input_args = RapidTableInput(model_type=ModelType.PPSTRUCTURE_ZH)
table_engine = RapidTable(input_args)

img_list = list(Path("images").iterdir())
results = table_engine(img_list, batch_size=3)  # batch_size defaults to 1

# indexes: which image results to visualize. Defaults to 0
results.vis(save_dir="outputs", save_name="vis", indexes=(0, 1, 2))
```

##### CPU usage

```python
from rapid_table import ModelType, RapidTable, RapidTableInput

input_args = RapidTableInput(model_type=ModelType.UNITABLE)
table_engine = RapidTable(input_args)

img_path = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"
results = table_engine(img_path)
results.vis(save_dir="outputs", save_name="vis")
```

##### GPU usage

> The keys in `engine_cfg` correspond to [`engine_cfg.yaml`](https://github.com/RapidAI/RapidTable/blob/6da3974a35ac5da8a5cf58194eab00b6886212e8/rapid_table/engine_cfg.yaml).

```python
from rapid_table import ModelType, RapidTable, RapidTableInput

# onnxruntime-gpu
input_args = RapidTableInput(
    model_type=ModelType.SLANETPLUS,
    engine_cfg={"use_cuda": True, "cuda_ep_cfg.gpu_id": 1}
)

# torch gpu
# input_args = RapidTableInput(
#     model_type=ModelType.UNITABLE,
#     engine_cfg={"use_cuda": True, "gpu_id": 1},
# )

table_engine = RapidTable(input_args)

img_path = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"
results = table_engine(img_path)
results.vis(save_dir="outputs", save_name="vis")
```

#### 📦 Running from the terminal

```bash
rapid_table https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg -v
```

### 📝 Results

#### 📎 Returned result
221 | 222 | ```html 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | <> 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | 291 | 292 | 293 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | 302 | 303 | 304 | 305 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 316 | 318 | 319 | 320 | 321 | 322 | 323 | 324 | 325 | 326 | 327 | 328 | 329 | 330 | 331 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 |
MethodsFPS
SegLink [26]70.086d> 237 | 239 | 77.08.9
PixelLink [4]73.283.077.8
TextSnake [18]73.983.278.31.1
TextField [37]75.987.481.35.2
MSR[38]76.787.87.481.7
FTSN [3]77.187.682.0
LSE[30]81.784.282.9
CRAFT [2]78.288.282.98.6
MCN[16]798883
ATRR[35]82.185.283.6
PAN [34]83.884.484.130.2
DB[12]79.2 317 | 91.584.932.0
DRRG[41]82.3088.0585.08
Ours (SynText)80.6885 332 | 334 | 82.9712.68
Ours (MLT-17)84.5486.6285.5712.31
345 | 346 | 347 | ``` 348 | 349 |
350 | 351 | #### 🖼️ 可视化结果 352 | 353 |
354 | <>
MethodsFPS
SegLink [26]70.086d>77.08.9
PixelLink [4]73.283.077.8
TextSnake [18]73.983.278.31.1
TextField [37]75.987.481.35.2
MSR[38]76.787.87.481.7
FTSN [3]77.187.682.0
LSE[30]81.784.282.9
CRAFT [2]78.288.282.98.6
MCN[16]798883
ATRR[35]82.185.283.6
PAN [34]83.884.484.130.2
DB[12]79.291.584.932.0
DRRG[41]82.3088.0585.08
Ours (SynText)80.688582.9712.68
Ours (MLT-17)84.5486.6285.5712.31
355 | 356 |
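The strings in `pred_htmls` can be post-processed with standard HTML tooling. A minimal sketch, assuming `pandas` plus an HTML parser such as `lxml` are installed (neither is required by RapidTable itself):

```python
from io import StringIO

import pandas as pd

# results.pred_htmls[0] is the HTML string for the first input image
html = results.pred_htmls[0]

# read_html returns one DataFrame per <table> element in the input
df = pd.read_html(StringIO(html))[0]
print(df.shape)
print(df.head())
```

Note that OCR artifacts in the recognized table (such as the `86d>` cell above) come through as plain strings and may need cleaning downstream.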
### 🔄 Relationship with [TableStructureRec](https://github.com/RapidAI/TableStructureRec)

The TableStructureRec library is a collection of table recognition algorithms. It currently provides inference packages for `wired_table_rec` (wired-table recognition) and `lineless_table_rec` (lineless-table recognition).

RapidTable was extracted from the table recognition part of PP-Structure. Because PP-Structure came first, this library ended up with the name `rapid_table`.

In short, both RapidTable and TableStructureRec are table recognition repositories. Feel free to try both and use whichever works better for you. Since the algorithms differ quite a bit, there is no plan to unify them for now.

For a comparison of table recognition algorithms, see the [TableStructureRec benchmark](https://github.com/RapidAI/TableStructureRec#指标结果).

### 📌 Changelog ([more](https://github.com/RapidAI/RapidTable/releases))
#### 2024.12.30 update

- Added table recognition with the Unitable model, using the PyTorch framework

#### 2024.11.24 update

- Added GPU inference, adapted RapidOCR single-character recognition matching, and added logical-coordinate output and visualization

#### 2024.10.13 update

- Added the latest paddlex-SLANet-plus model (ONNX not yet supported due to paddle2onnx)

#### 2023-12-29 v0.1.3 update

- Improved the visualization of results

#### 2023-12-27 v0.1.2 update

- Added a parameter to return cell bounding boxes
- Improved the visualization function

#### 2023-07-17 v0.1.0 update

- Decoupled the `rapidocr_onnxruntime` part from `rapid_table`, making the dependency optional and the package more flexible.

- Added the interface input parameter `ocr_result`:
  - If `ocr_result` is provided when calling the function, OCR is not run again. The format of `ocr_result` must match the return value of `rapidocr_onnxruntime`.
  - If `ocr_result` is not provided but the `rapidocr_onnxruntime` library is installed, that library is called automatically for recognition.
  - If `ocr_result` is not provided and `rapidocr_onnxruntime` is not installed, an error is raised. One of the two conditions must be met.

#### 2023-07-10 v0.0.13 update

- Changed the OCR instance interface passed into table restoration, so other OCR instances can be passed in, provided they match the `rapidocr_onnxruntime` interface

#### 2023-07-06 v0.0.12 update

- Removed the `<html><body>` elements from the returned table HTML string for easier downstream unification.
- Reformatted the code with the Black tool
412 | -------------------------------------------------------------------------------- /cliff.toml: -------------------------------------------------------------------------------- 1 | # git-cliff ~ configuration file 2 | # https://git-cliff.org/docs/configuration 3 | 4 | [changelog] 5 | # A Tera template to be rendered as the changelog's footer. 6 | # See https://keats.github.io/tera/docs/#introduction 7 | # header = """ 8 | # # Changelog\n 9 | # All notable changes to this project will be documented in this file. See [conventional commits](https://www.conventionalcommits.org/) for commit guidelines.\n 10 | # """ 11 | # A Tera template to be rendered for each release in the changelog. 12 | # See https://keats.github.io/tera/docs/#introduction 13 | body = """ 14 | {% for group, commits in commits | group_by(attribute="group") %} 15 | ### {{ group | striptags | trim | upper_first }} 16 | {% for commit in commits 17 | | filter(attribute="scope") 18 | | sort(attribute="scope") %} 19 | - **({{commit.scope}})**{% if commit.breaking %} [**breaking**]{% endif %} \ 20 | {{ commit.message }} by [@{{ commit.author.name }}](https://github.com/{{ commit.author.name }}) in [{{ commit.id | truncate(length=7, end="") }}]($REPO/commit/{{ commit.id }}) 21 | {%- endfor -%} 22 | {% raw %}\n{% endraw %}\ 23 | {%- for commit in commits %} 24 | {%- if commit.scope -%} 25 | {% else -%} 26 | - {% if commit.breaking %} [**breaking**]{% endif %}\ 27 | {{ commit.message }} by [@{{ commit.author.name }}](https://github.com/{{ commit.author.name }}) in [{{ commit.id | truncate(length=7, end="") }}]($REPO/commit/{{ commit.id }}) 28 | {% endif -%} 29 | {% endfor -%} 30 | {% endfor %} 31 | 32 | 33 | {% if github.contributors | length > 0 %} 34 | ### 🎉 Contributors 35 | 36 | {% for contributor in github.contributors %} 37 | - [@{{ contributor.username }}](https://github.com/{{ contributor.username }}) 38 | {%- endfor -%} 39 | {% endif %} 40 | 41 | 42 | {% if version %} 43 | {% if previous.version %}\ 44 | **Full Changelog**: [{{ version | trim_start_matches(pat="v") }}]($REPO/compare/{{ previous.version }}..{{ version }}) 45 | {% else %}\ 46 | **Full Changelog**: [{{ version | trim_start_matches(pat="v") }}] 47 | {% endif %}\ 48 | {% else %}\ 49 | ## [unreleased] 50 | {% endif %} 51 | """ 52 | # A Tera template to be rendered as the changelog's footer. 53 | # See https://keats.github.io/tera/docs/#introduction 54 | 55 | footer = """ 56 | 57 | """ 58 | 59 | # Remove leading and trailing whitespaces from the changelog's body. 60 | trim = true 61 | # postprocessors 62 | postprocessors = [ 63 | # Replace the placeholder `` with a URL. 64 | { pattern = '\$REPO', replace = "https://github.com/RapidAI/RapidTable" }, # replace repository URL 65 | ] 66 | 67 | [git] 68 | # Parse commits according to the conventional commits specification. 69 | # See https://www.conventionalcommits.org 70 | conventional_commits = true 71 | # Exclude commits that do not match the conventional commits specification. 72 | filter_unconventional = true 73 | # Split commits on newlines, treating each line as an individual commit. 74 | split_commits = false 75 | # An array of regex based parsers to modify commit messages prior to further processing. 76 | commit_preprocessors = [ 77 | # Replace issue numbers with link templates to be updated in `changelog.postprocessors`. 
78 | #{ pattern = '\((\w+\s)?#([0-9]+)\)', replace = "([#${2}](https://github.com/orhun/git-cliff/issues/${2}))"}, 79 | ] 80 | # An array of regex based parsers for extracting data from the commit message. 81 | # Assigns commits to groups. 82 | # Optionally sets the commit's scope and can decide to exclude commits from further processing. 83 | commit_parsers = [ 84 | { message = "^feat", group = "🚀 Features" }, 85 | { message = "^fix", group = "🐛 Bug Fixes" }, 86 | { message = "^doc", group = "📚 Documentation" }, 87 | { message = "^perf", group = "⚡ Performance" }, 88 | { message = "^refactor", group = "🚜 Refactor" }, 89 | { message = "^style", group = "🎨 Styling" }, 90 | { message = "^test", group = "🧪 Testing" }, 91 | { message = "^chore\\(release\\): prepare for", skip = true }, 92 | { message = "^chore\\(deps.*\\)", skip = true }, 93 | { message = "^chore\\(pr\\)", skip = true }, 94 | { message = "^chore\\(pull\\)", skip = true }, 95 | { message = "^chore|^ci", group = "⚙️ Miscellaneous Tasks" }, 96 | { body = ".*security", group = "🛡️ Security" }, 97 | { message = "^revert", group = "◀️ Revert" }, 98 | { message = ".*", group = "💼 Other" }, 99 | ] 100 | # Exclude commits that are not matched by any commit parser. 101 | filter_commits = false 102 | # Order releases topologically instead of chronologically. 103 | topo_order = false 104 | # Order of commits in each group/release within the changelog. 105 | # Allowed values: newest, oldest 106 | sort_commits = "newest" -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | from pathlib import Path 5 | 6 | from rapid_table import ModelType, RapidTable, RapidTableInput 7 | 8 | # input_args = RapidTableInput( 9 | # model_type=ModelType.UNITABLE, 10 | # engine_cfg={"use_cuda": True, "gpu_id": 1}, 11 | # ) 12 | input_args = RapidTableInput(model_type=ModelType.PPSTRUCTURE_ZH) 13 | table_engine = RapidTable(input_args) 14 | 15 | img_list = list(Path("images").iterdir()) 16 | results = table_engine(img_list, batch_size=3) 17 | results.vis(save_dir="outputs", save_name="vis", indexes=(0, 1, 2)) 18 | -------------------------------------------------------------------------------- /docs/doc_whl_rapid_table.md: -------------------------------------------------------------------------------- 1 | ### For details, see [Rapid Table](https://github.com/RapidAI/RapidTable) 2 | -------------------------------------------------------------------------------- /rapid_table/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | from .main import RapidTable, RapidTableInput 5 | from .utils import EngineType, ModelType, VisTable 6 | -------------------------------------------------------------------------------- /rapid_table/default_models.yaml: -------------------------------------------------------------------------------- 1 | ppstructure_en: 2 | model_dir_or_path: https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/v2.0.0/en_ppstructure_mobile_v2_SLANet.onnx 3 | SHA256: 2cae17d16a16f9df7229e21665fe3fbe06f3ca85b2024772ee3e3142e955aa60 4 | 5 | ppstructure_zh: 6 | model_dir_or_path: https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/v2.0.0/ch_ppstructure_mobile_v2_SLANet.onnx 7 | SHA256: 
ddfc6c97ee4db2a5e9de4de8b6a14508a39d42d228503219fdfebfac364885e3 8 | 9 | slanet_plus: 10 | model_dir_or_path: https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/v2.0.0/slanet-plus.onnx 11 | SHA256: d57a942af6a2f57d6a4a0372573c696a2379bf5857c45e2ac69993f3b334514b 12 | 13 | unitable: 14 | model_dir_or_path: https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/v2.0.0/unitable 15 | SHA256: 16 | encoder.pth: 2c66b3c6a3d1c86a00985bab2cd79412fc2b668ff39d338bc3c63d383b08684d 17 | decoder.pth: fa342ef3de259576a01a5545ede804208ef35a124935e30df4768e6708dcb6cb 18 | vocab.json: 05037d02c48d106639bc90284aa847e5e2151d4746b3f5efe1628599efbd668a 19 | 20 | -------------------------------------------------------------------------------- /rapid_table/engine_cfg.yaml: -------------------------------------------------------------------------------- 1 | onnxruntime: 2 | intra_op_num_threads: -1 3 | inter_op_num_threads: -1 4 | enable_cpu_mem_arena: false 5 | 6 | cpu_ep_cfg: 7 | arena_extend_strategy: "kSameAsRequested" 8 | 9 | use_cuda: false 10 | cuda_ep_cfg: 11 | gpu_id: 0 12 | arena_extend_strategy: "kNextPowerOfTwo" 13 | cudnn_conv_algo_search: "EXHAUSTIVE" 14 | do_copy_in_default_stream: true 15 | 16 | use_dml: false 17 | dm_ep_cfg: null 18 | 19 | use_cann: false 20 | cann_ep_cfg: 21 | gpu_id: 0 22 | arena_extend_strategy: "kNextPowerOfTwo" 23 | npu_mem_limit: 21474836480 # 20 * 1024 * 1024 * 1024 24 | op_select_impl_mode: "high_performance" 25 | optypelist_for_implmode: "Gelu" 26 | enable_cann_graph: true 27 | 28 | openvino: 29 | inference_num_threads: -1 30 | 31 | paddle: 32 | cpu_math_library_num_threads: -1 33 | use_cuda: false 34 | gpu_id: 0 35 | gpu_mem: 500 36 | 37 | torch: 38 | use_cuda: false 39 | gpu_id: 0 40 | 41 | -------------------------------------------------------------------------------- /rapid_table/inference_engine/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | -------------------------------------------------------------------------------- /rapid_table/inference_engine/base.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | import abc 5 | from pathlib import Path 6 | from typing import Any, Dict, Union 7 | 8 | import numpy as np 9 | from omegaconf import DictConfig, OmegaConf 10 | 11 | from ..utils import EngineType, Logger, import_package, read_yaml 12 | 13 | logger = Logger(logger_name=__name__).get_log() 14 | 15 | 16 | class InferSession(abc.ABC): 17 | cur_dir = Path(__file__).resolve().parent.parent 18 | ENGINE_CFG_PATH = cur_dir / "engine_cfg.yaml" 19 | engine_cfg = read_yaml(ENGINE_CFG_PATH) 20 | 21 | @abc.abstractmethod 22 | def __init__(self, config): 23 | pass 24 | 25 | @abc.abstractmethod 26 | def __call__(self, input_content: np.ndarray) -> np.ndarray: 27 | pass 28 | 29 | @staticmethod 30 | def _verify_model(model_path: Union[str, Path, None]): 31 | if model_path is None: 32 | raise ValueError("model_path is None!") 33 | 34 | model_path = Path(model_path) 35 | if not model_path.exists(): 36 | raise FileNotFoundError(f"{model_path} does not exists.") 37 | 38 | if not model_path.is_file(): 39 | raise FileExistsError(f"{model_path} is not a file.") 40 | 41 | @abc.abstractmethod 42 | def have_key(self, key: str = "character") -> bool: 43 | pass 44 | 45 | @staticmethod 46 | def update_params(cfg: DictConfig, 
params: Dict[str, Any]): 47 | for k, v in params.items(): 48 | OmegaConf.update(cfg, k, v) 49 | return cfg 50 | 51 | 52 | def get_engine(engine_type: EngineType): 53 | logger.info("Using engine_name: %s", engine_type.value) 54 | 55 | if engine_type == EngineType.ONNXRUNTIME: 56 | if not import_package(engine_type.value): 57 | raise ImportError(f"{engine_type.value} is not installed.") 58 | 59 | from .onnxruntime import OrtInferSession 60 | 61 | return OrtInferSession 62 | 63 | if engine_type == EngineType.TORCH: 64 | if not import_package(engine_type.value): 65 | raise ImportError(f"{engine_type.value} is not installed") 66 | 67 | from .torch import TorchInferSession 68 | 69 | return TorchInferSession 70 | 71 | raise ValueError(f"Unsupported engine: {engine_type.value}") 72 | -------------------------------------------------------------------------------- /rapid_table/inference_engine/onnxruntime/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | from .main import OrtInferSession 5 | -------------------------------------------------------------------------------- /rapid_table/inference_engine/onnxruntime/main.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | import os 5 | import traceback 6 | from pathlib import Path 7 | from typing import Any, Dict, List, Optional 8 | 9 | import numpy as np 10 | from omegaconf import DictConfig 11 | from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions 12 | 13 | from ...utils.logger import Logger 14 | from ..base import InferSession 15 | from .provider_config import ProviderConfig 16 | 17 | 18 | class OrtInferSession(InferSession): 19 | def __init__(self, cfg: Optional[Dict[str, Any]] = None): 20 | self.logger = Logger(logger_name=__name__).get_log() 21 | 22 | # support custom session (PR #451) 23 | session = cfg.get("session", None) 24 | if session is not None: 25 | if not isinstance(session, InferenceSession): 26 | raise TypeError( 27 | f"Expected session to be an InferenceSession, got {type(session)}" 28 | ) 29 | 30 | self.logger.debug("Using the provided InferenceSession for inference.") 31 | self.session = session 32 | return 33 | 34 | model_path = cfg.get("model_dir_or_path", None) 35 | self.logger.info(f"Using {model_path}") 36 | model_path = Path(model_path) 37 | self._verify_model(model_path) 38 | 39 | engine_cfg = self.update_params( 40 | self.engine_cfg[cfg["engine_type"].value], cfg["engine_cfg"] 41 | ) 42 | 43 | sess_opt = self._init_sess_opts(engine_cfg) 44 | provider_cfg = ProviderConfig(engine_cfg=engine_cfg) 45 | self.session = InferenceSession( 46 | model_path, 47 | sess_options=sess_opt, 48 | providers=provider_cfg.get_ep_list(), 49 | ) 50 | provider_cfg.verify_providers(self.session.get_providers()) 51 | 52 | @staticmethod 53 | def _init_sess_opts(cfg: DictConfig) -> SessionOptions: 54 | sess_opt = SessionOptions() 55 | sess_opt.log_severity_level = 4 56 | sess_opt.enable_cpu_mem_arena = cfg.get("enable_cpu_mem_arena", False) 57 | sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL 58 | 59 | cpu_nums = os.cpu_count() 60 | intra_op_num_threads = cfg.get("intra_op_num_threads", -1) 61 | if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums: 62 | sess_opt.intra_op_num_threads = intra_op_num_threads 63 | 64 | 
inter_op_num_threads = cfg.get("inter_op_num_threads", -1) 65 | if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums: 66 | sess_opt.inter_op_num_threads = inter_op_num_threads 67 | 68 | return sess_opt 69 | 70 | def __call__(self, input_content: np.ndarray) -> np.ndarray: 71 | input_dict = dict(zip(self.get_input_names(), [input_content])) 72 | try: 73 | return self.session.run(self.get_output_names(), input_dict) 74 | except Exception as e: 75 | error_info = traceback.format_exc() 76 | raise ONNXRuntimeError(error_info) from e 77 | 78 | def get_input_names(self) -> List[str]: 79 | return [v.name for v in self.session.get_inputs()] 80 | 81 | def get_output_names(self) -> List[str]: 82 | return [v.name for v in self.session.get_outputs()] 83 | 84 | def get_character_list(self, key: str = "character") -> List[str]: 85 | meta_dict = self.session.get_modelmeta().custom_metadata_map 86 | return meta_dict[key].splitlines() 87 | 88 | def have_key(self, key: str = "character") -> bool: 89 | meta_dict = self.session.get_modelmeta().custom_metadata_map 90 | if key in meta_dict.keys(): 91 | return True 92 | return False 93 | 94 | 95 | class ONNXRuntimeError(Exception): 96 | pass 97 | -------------------------------------------------------------------------------- /rapid_table/inference_engine/onnxruntime/provider_config.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | import platform 5 | from enum import Enum 6 | from typing import Any, Dict, List, Sequence, Tuple 7 | 8 | from omegaconf import DictConfig 9 | from onnxruntime import get_available_providers, get_device 10 | 11 | from ...utils.logger import Logger 12 | 13 | 14 | class EP(Enum): 15 | CPU_EP = "CPUExecutionProvider" 16 | CUDA_EP = "CUDAExecutionProvider" 17 | DIRECTML_EP = "DmlExecutionProvider" 18 | CANN_EP = "CANNExecutionProvider" 19 | 20 | 21 | class ProviderConfig: 22 | def __init__(self, engine_cfg: DictConfig): 23 | self.logger = Logger(logger_name=__name__).get_log() 24 | 25 | self.had_providers: List[str] = get_available_providers() 26 | self.default_provider = self.had_providers[0] 27 | 28 | self.cfg_use_cuda = engine_cfg.get("use_cuda", False) 29 | self.cfg_use_dml = engine_cfg.get("use_dml", False) 30 | self.cfg_use_cann = engine_cfg.get("use_cann", False) 31 | 32 | self.cfg = engine_cfg 33 | 34 | def get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]: 35 | results = [(EP.CPU_EP.value, self.cpu_ep_cfg())] 36 | 37 | if self.is_cuda_available(): 38 | results.insert(0, (EP.CUDA_EP.value, self.cuda_ep_cfg())) 39 | 40 | if self.is_dml_available(): 41 | self.logger.info( 42 | "Windows 10 or above detected, try to use DirectML as primary provider" 43 | ) 44 | results.insert(0, (EP.DIRECTML_EP.value, self.dml_ep_cfg())) 45 | 46 | if self.is_cann_available(): 47 | self.logger.info("Try to use CANNExecutionProvider to infer") 48 | results.insert(0, (EP.CANN_EP.value, self.cann_ep_cfg())) 49 | 50 | return results 51 | 52 | def cpu_ep_cfg(self) -> Dict[str, Any]: 53 | return dict(self.cfg.cpu_ep_cfg) 54 | 55 | def cuda_ep_cfg(self) -> Dict[str, Any]: 56 | return dict(self.cfg.cuda_ep_cfg) 57 | 58 | def dml_ep_cfg(self) -> Dict[str, Any]: 59 | if self.cfg.dm_ep_cfg is not None: 60 | return self.cfg.dm_ep_cfg 61 | 62 | if self.is_cuda_available(): 63 | return self.cuda_ep_cfg() 64 | return self.cpu_ep_cfg() 65 | 66 | def cann_ep_cfg(self) -> Dict[str, Any]: 67 | return dict(self.cfg.cann_ep_cfg) 
68 | 
69 |     def verify_providers(self, session_providers: Sequence[str]):
70 |         if not session_providers:
71 |             raise ValueError("Session Providers is empty")
72 | 
73 |         first_provider = session_providers[0]
74 | 
75 |         providers_to_check = {
76 |             EP.CUDA_EP: self.is_cuda_available,
77 |             EP.DIRECTML_EP: self.is_dml_available,
78 |             EP.CANN_EP: self.is_cann_available,
79 |         }
80 | 
81 |         for ep, check_func in providers_to_check.items():
82 |             if check_func() and first_provider != ep.value:
83 |                 self.logger.warning(
84 |                     f"{ep.value} is available, but the inference part is automatically shifted to be executed under {first_provider}. "
85 |                 )
86 |                 self.logger.warning(f"The available lists are {session_providers}")
87 | 
88 |     def is_cuda_available(self) -> bool:
89 |         if not self.cfg_use_cuda:
90 |             return False
91 | 
92 |         CUDA_EP = EP.CUDA_EP.value
93 |         if get_device() == "GPU" and CUDA_EP in self.had_providers:
94 |             return True
95 | 
96 |         self.logger.warning(
97 |             f"{CUDA_EP} is not in available providers ({self.had_providers}). Use {self.default_provider} inference by default."
98 |         )
99 |         install_instructions = [
100 |             f"(For reference only) If you want to use {CUDA_EP} "
101 |             "acceleration, you must do:",
102 |             "First, uninstall all onnxruntime packages in current environment.",
103 |             "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`.",
104 |             "Note the onnxruntime-gpu version must match your cuda and cudnn version.",
105 |             "You can refer to this link: https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
106 |             f"Third, ensure {CUDA_EP} is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']",
107 |         ]
108 |         self.print_log(install_instructions)
109 |         return False
110 | 
111 |     def is_dml_available(self) -> bool:
112 |         if not self.cfg_use_dml:
113 |             return False
114 | 
115 |         cur_os = platform.system()
116 |         if cur_os != "Windows":
117 |             self.logger.warning(
118 |                 f"DirectML is only supported in Windows OS. The current OS is {cur_os}. Use {self.default_provider} inference by default.",
119 |             )
120 |             return False
121 | 
122 |         window_build_number_str = platform.version().split(".")[-1]
123 |         window_build_number = (
124 |             int(window_build_number_str) if window_build_number_str.isdigit() else 0
125 |         )
126 |         if window_build_number < 18362:
127 |             self.logger.warning(
128 |                 f"DirectML is only supported in Windows 10 Build 18362 and above OS. The current Windows Build is {window_build_number}. Use {self.default_provider} inference by default.",
129 |             )
130 |             return False
131 | 
132 |         DML_EP = EP.DIRECTML_EP.value
133 |         if DML_EP in self.had_providers:
134 |             return True
135 | 
136 |         self.logger.warning(
137 |             f"{DML_EP} is not in available providers ({self.had_providers}). Use {self.default_provider} inference by default."
138 |         )
139 |         install_instructions = [
140 |             "If you want to use DirectML acceleration, you must do:",
141 |             "First, uninstall all onnxruntime packages in current environment.",
142 |             "Second, install onnxruntime-directml by `pip install onnxruntime-directml`",
143 |             f"Third, ensure {DML_EP} is in available providers list. e.g. 
['DmlExecutionProvider', 'CPUExecutionProvider']", 144 | ] 145 | self.print_log(install_instructions) 146 | return False 147 | 148 | def is_cann_available(self) -> bool: 149 | if not self.cfg_use_cann: 150 | return False 151 | 152 | CANN_EP = EP.CANN_EP.value 153 | if CANN_EP in self.had_providers: 154 | return True 155 | 156 | self.logger.warning( 157 | f"{CANN_EP} is not in available providers ({self.had_providers}). Use {self.default_provider} inference by default." 158 | ) 159 | install_instructions = [ 160 | "If you want to use CANN acceleration, you must do:", 161 | "First, ensure you have installed Huawei Ascend software stack.", 162 | "Second, install onnxruntime with CANN support by following the instructions at:", 163 | "\thttps://onnxruntime.ai/docs/execution-providers/community-maintained/CANN-ExecutionProvider.html", 164 | f"Third, ensure {CANN_EP} is in available providers list. e.g. ['CANNExecutionProvider', 'CPUExecutionProvider']", 165 | ] 166 | self.print_log(install_instructions) 167 | return False 168 | 169 | def print_log(self, log_list: List[str]): 170 | for log_info in log_list: 171 | self.logger.info(log_info) 172 | -------------------------------------------------------------------------------- /rapid_table/inference_engine/torch.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | from pathlib import Path 5 | from typing import Union 6 | 7 | import numpy as np 8 | import torch 9 | from tokenizers import Tokenizer 10 | 11 | from ..table_structure.unitable.unitable_modules import Encoder, GPTFastDecoder 12 | from ..utils.logger import Logger 13 | from .base import InferSession 14 | 15 | root_dir = Path(__file__).resolve().parent.parent 16 | 17 | 18 | class TorchInferSession(InferSession): 19 | def __init__(self, cfg) -> None: 20 | self.logger = Logger(logger_name=__name__).get_log() 21 | 22 | engine_cfg = self.update_params( 23 | self.engine_cfg[cfg["engine_type"].value], cfg["engine_cfg"] 24 | ) 25 | 26 | self.device = torch.device( 27 | f"cuda:{engine_cfg.gpu_id}" 28 | if torch.cuda.is_available() and engine_cfg.use_cuda 29 | else "cpu" 30 | ) 31 | 32 | model_info = cfg["model_dir_or_path"] 33 | self.encoder = self._init_model(model_info["encoder"], Encoder) 34 | self.decoder = self._init_model(model_info["decoder"], GPTFastDecoder) 35 | self.vocab = self._init_vocab(model_info["vocab"]) 36 | 37 | def _init_model(self, model_path, model_class): 38 | model = model_class() 39 | model.load_state_dict(torch.load(model_path, map_location=self.device)) 40 | model.eval().to(self.device) 41 | return model 42 | 43 | def _init_vocab(self, vocab_path: Union[str, Path]): 44 | return Tokenizer.from_file(str(vocab_path)) 45 | 46 | def __call__(self, img: np.ndarray): 47 | raise NotImplementedError( 48 | "Inference logic is not implemented for TorchInferSession." 
49 | ) 50 | 51 | def have_key(self, key: str = "character") -> bool: 52 | return False 53 | 54 | 55 | class TorchInferError(Exception): 56 | pass 57 | -------------------------------------------------------------------------------- /rapid_table/main.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | import argparse 5 | import time 6 | from dataclasses import asdict 7 | from pathlib import Path 8 | from typing import Any, Dict, List, Optional, Tuple, Union, get_args 9 | 10 | import numpy as np 11 | from tqdm import tqdm 12 | 13 | from .model_processor.main import ModelProcessor 14 | from .table_matcher import TableMatch 15 | from .utils import ( 16 | InputType, 17 | LoadImage, 18 | Logger, 19 | ModelType, 20 | RapidTableInput, 21 | RapidTableOutput, 22 | format_ocr_results, 23 | import_package, 24 | is_url, 25 | ) 26 | 27 | logger = Logger(logger_name=__name__).get_log() 28 | root_dir = Path(__file__).resolve().parent 29 | 30 | 31 | class RapidTable: 32 | def __init__(self, cfg: Optional[RapidTableInput] = None): 33 | if cfg is None: 34 | cfg = RapidTableInput() 35 | 36 | if not cfg.model_dir_or_path and cfg.model_type is not None: 37 | cfg.model_dir_or_path = ModelProcessor.get_model_path(cfg.model_type) 38 | 39 | self.cfg = cfg 40 | self.table_structure = self._init_table_structer() 41 | 42 | self.ocr_engine = None 43 | if cfg.use_ocr: 44 | self.ocr_engine = self._init_ocr_engine(self.cfg.ocr_params) 45 | 46 | self.table_matcher = TableMatch() 47 | self.load_img = LoadImage() 48 | 49 | def _init_ocr_engine(self, params: Dict[Any, Any]): 50 | rapidocr_ = import_package("rapidocr") 51 | if rapidocr_ is None: 52 | logger.warning("rapidocr package is not installed, only table rec") 53 | return None 54 | 55 | if not params: 56 | return rapidocr_.RapidOCR() 57 | return rapidocr_.RapidOCR(params=params) 58 | 59 | def _init_table_structer(self): 60 | if self.cfg.model_type == ModelType.UNITABLE: 61 | from .table_structure.unitable import UniTableStructure 62 | 63 | return UniTableStructure(asdict(self.cfg)) 64 | 65 | from .table_structure.pp_structure import PPTableStructurer 66 | 67 | return PPTableStructurer(asdict(self.cfg)) 68 | 69 | def __call__( 70 | self, 71 | img_contents: Union[List[InputType], InputType], 72 | ocr_results: Optional[Tuple[np.ndarray, Tuple[str], Tuple[float]]] = None, 73 | batch_size: int = 1, 74 | ) -> RapidTableOutput: 75 | s = time.perf_counter() 76 | 77 | if not isinstance(img_contents, list): 78 | img_contents = [img_contents] 79 | 80 | for img_content in img_contents: 81 | if not isinstance(img_content, get_args(InputType)): 82 | type_names = ", ".join([t.__name__ for t in get_args(InputType)]) 83 | actual_type = ( 84 | type(img_content).__name__ if img_content is not None else "None" 85 | ) 86 | raise TypeError( 87 | f"Type Error: Expected input of type [{type_names}], but received type {actual_type}." 
88 | ) 89 | 90 | results = RapidTableOutput() 91 | 92 | total_nums = len(img_contents) 93 | for start_i in tqdm(range(0, total_nums, batch_size), desc="BatchRec"): 94 | end_i = min(total_nums, start_i + batch_size) 95 | 96 | imgs = self._load_imgs(img_contents[start_i:end_i]) 97 | 98 | pred_structures, cell_bboxes = self.table_structure(imgs) 99 | logic_points = self.table_matcher.decode_logic_points(pred_structures) 100 | 101 | if not self.cfg.use_ocr: 102 | results.imgs.extend(imgs) 103 | results.cell_bboxes.extend(cell_bboxes) 104 | results.logic_points.extend(logic_points) 105 | continue 106 | 107 | dt_boxes, rec_res = self.get_ocr_results(imgs, start_i, end_i, ocr_results) 108 | pred_htmls = self.table_matcher( 109 | pred_structures, cell_bboxes, dt_boxes, rec_res 110 | ) 111 | 112 | results.imgs.extend(imgs) 113 | results.pred_htmls.extend(pred_htmls) 114 | results.cell_bboxes.extend(cell_bboxes) 115 | results.logic_points.extend(logic_points) 116 | 117 | elapse = time.perf_counter() - s 118 | results.elapse = elapse / total_nums 119 | return results 120 | 121 | def _load_imgs( 122 | self, img_content: Union[List[InputType], InputType] 123 | ) -> List[np.ndarray]: 124 | img_contents = img_content if isinstance(img_content, list) else [img_content] 125 | return [self.load_img(img) for img in img_contents] 126 | 127 | def get_ocr_results( 128 | self, 129 | imgs: List[np.ndarray], 130 | start_i: int, 131 | end_i: int, 132 | ocr_results: Optional[List[Tuple[np.ndarray, Tuple[str], Tuple[float]]]] = None, 133 | ) -> Any: 134 | batch_dt_boxes, batch_rec_res = [], [] 135 | 136 | if ocr_results is not None: 137 | ocr_results_batch = ocr_results[start_i:end_i] 138 | if len(ocr_results_batch) != len(imgs): 139 | raise ValueError( 140 | f"Batch size mismatch: {len(imgs)} images but {len(ocr_results_batch)} OCR results " 141 | f"(indices {start_i}:{end_i})." 
142 |                 )
143 | 
144 |             for img, ocr_result in zip(imgs, ocr_results_batch):
145 |                 img_h, img_w = img.shape[:2]
146 |                 dt_boxes, rec_res = format_ocr_results(ocr_result, img_h, img_w)
147 |                 batch_dt_boxes.append(dt_boxes)
148 |                 batch_rec_res.append(rec_res)
149 |             return batch_dt_boxes, batch_rec_res
150 | 
151 |         for img in tqdm(imgs, desc="OCR"):
152 |             if img is None:
153 |                 continue
154 | 
155 |             ori_ocr_res = self.ocr_engine(img)
156 |             if ori_ocr_res.boxes is None:
157 |                 logger.warning("OCR result is empty")
158 |                 batch_dt_boxes.append(None)
159 |                 batch_rec_res.append(None)
160 |                 continue
161 | 
162 |             img_h, img_w = img.shape[:2]
163 | 
164 |             ocr_result = [ori_ocr_res.boxes, ori_ocr_res.txts, ori_ocr_res.scores]
165 |             dt_boxes, rec_res = format_ocr_results(ocr_result, img_h, img_w)
166 |             batch_dt_boxes.append(dt_boxes)
167 |             batch_rec_res.append(rec_res)
168 | 
169 |         return batch_dt_boxes, batch_rec_res
170 | 
171 | 
172 | def parse_args(arg_list: Optional[List[str]] = None):
173 |     parser = argparse.ArgumentParser()
174 |     parser.add_argument("img_path", type=str, help="the image path or URL of the table")
175 |     parser.add_argument(
176 |         "-m",
177 |         "--model_type",
178 |         type=str,
179 |         default=ModelType.SLANETPLUS.value,
180 |         choices=[v.value for v in ModelType],
181 |         help="Supported table recognition models",
182 |     )
183 |     parser.add_argument(
184 |         "-v",
185 |         "--vis",
186 |         action="store_true",
187 |         default=False,
188 |         help="Whether to visualize the table results.",
189 |     )
190 |     args = parser.parse_args(arg_list)
191 |     return args
192 | 
193 | 
194 | def main(arg_list: Optional[List[str]] = None):
195 |     args = parse_args(arg_list)
196 |     img_path = args.img_path
197 | 
198 |     input_args = RapidTableInput(model_type=ModelType(args.model_type))
199 |     table_engine = RapidTable(input_args)
200 | 
201 |     table_results = table_engine(img_path)
202 |     print(table_results.pred_htmls)
203 | 
204 |     if args.vis:
205 |         save_dir = Path(".") if is_url(img_path) else Path(img_path).resolve().parent
206 |         table_results.vis(save_dir, save_name=Path(img_path).stem)
207 | 
208 | 
209 | if __name__ == "__main__":
210 |     main()
211 | 
--------------------------------------------------------------------------------
/rapid_table/model_processor/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
--------------------------------------------------------------------------------
/rapid_table/model_processor/main.py:
--------------------------------------------------------------------------------
 1 | # -*- encoding: utf-8 -*-
 2 | # @Author: SWHL
 3 | # @Contact: liekkaskono@163.com
 4 | from pathlib import Path
 5 | from typing import Dict, Union
 6 | 
 7 | from ..utils import DownloadFile, DownloadFileInput, Logger, ModelType, mkdir, read_yaml
 8 | 
 9 | 
10 | class ModelProcessor:
11 |     logger = Logger(logger_name=__name__).get_log()
12 | 
13 |     cur_dir = Path(__file__).resolve().parent
14 |     root_dir = cur_dir.parent
15 |     DEFAULT_MODEL_PATH = root_dir / "default_models.yaml"
16 | 
17 |     DEFAULT_MODEL_DIR = root_dir / "models"
18 |     mkdir(DEFAULT_MODEL_DIR)
19 | 
20 |     model_map = read_yaml(DEFAULT_MODEL_PATH)
21 | 
22 |     @classmethod
23 |     def get_model_path(cls, model_type: ModelType) -> Union[str, Dict[str, str]]:
24 |         if model_type == ModelType.UNITABLE:
25 |             return cls.get_multi_models_dict(model_type)
26 |         return cls.get_single_model_path(model_type)
27 | 
28 |     @classmethod
29 |     def get_single_model_path(cls, model_type:
ModelType) -> str: 30 | model_info = cls.model_map[model_type.value] 31 | save_model_path = ( 32 | cls.DEFAULT_MODEL_DIR / Path(model_info["model_dir_or_path"]).name 33 | ) 34 | download_params = DownloadFileInput( 35 | file_url=model_info["model_dir_or_path"], 36 | sha256=model_info["SHA256"], 37 | save_path=save_model_path, 38 | logger=cls.logger, 39 | ) 40 | DownloadFile.run(download_params) 41 | 42 | return str(save_model_path) 43 | 44 | @classmethod 45 | def get_multi_models_dict(cls, model_type: ModelType) -> Dict[str, str]: 46 | model_info = cls.model_map[model_type.value] 47 | 48 | results = {} 49 | 50 | model_root_dir = model_info["model_dir_or_path"] 51 | save_model_dir = cls.DEFAULT_MODEL_DIR / Path(model_root_dir).name 52 | for file_name, sha256 in model_info["SHA256"].items(): 53 | save_path = save_model_dir / file_name 54 | 55 | download_params = DownloadFileInput( 56 | file_url=f"{model_root_dir}/{file_name}", 57 | sha256=sha256, 58 | save_path=save_path, 59 | logger=cls.logger, 60 | ) 61 | DownloadFile.run(download_params) 62 | results[Path(file_name).stem] = str(save_path) 63 | 64 | return results 65 | -------------------------------------------------------------------------------- /rapid_table/models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RapidAI/RapidTable/b0ba540a2a5ac06a61bcd6506795246a5046a291/rapid_table/models/.gitkeep -------------------------------------------------------------------------------- /rapid_table/table_matcher/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | from .main import TableMatch 5 | -------------------------------------------------------------------------------- /rapid_table/table_matcher/main.py: -------------------------------------------------------------------------------- 1 | # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
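# NOTE: a minimal usage sketch for the matcher defined below; the boxes and
# strings are illustrative placeholder values, not data shipped with this
# repository:
#
#   import numpy as np
#   matcher = TableMatch()
#   pred_structures = [(["<tr>", "<td></td>", "</tr>"], 0.99)]  # (tokens, score) per image
#   cell_bboxes = [np.array([[10.0, 10.0, 90.0, 40.0]])]        # predicted cell boxes
#   dt_boxes = [np.array([[12.0, 12.0, 88.0, 38.0]])]           # OCR detection boxes
#   rec_reses = [[("cell text", 0.98)]]                         # OCR (text, score) pairs
#   htmls = matcher(pred_structures, cell_bboxes, dt_boxes, rec_reses)
#   # -> ["<tr><td>cell text</td></tr>"]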
14 | # -*- encoding: utf-8 -*-
15 | from typing import Dict, List, Optional, Tuple
16 | 
17 | import numpy as np
18 | 
19 | from .utils import compute_iou, distance
20 | 
21 | 
22 | class TableMatch:
23 |     def __init__(self):
24 |         pass
25 | 
26 |     def __call__(
27 |         self,
28 |         pred_structures: List[List[List[str]]],
29 |         cell_bboxes: List[np.ndarray],
30 |         dt_boxes: List[np.ndarray],
31 |         rec_reses: List[List[Tuple[str, float]]],
32 |     ) -> List[Optional[str]]:
33 |         results = []
34 |         for item in zip(pred_structures, cell_bboxes, dt_boxes, rec_reses):
35 |             pred_struct, cell_bbox, dt_box, rec_res = item
36 |             if dt_box is None or rec_res is None:
37 |                 results.append(None)
38 |                 continue
39 | 
40 |             one_result = self.process_one(pred_struct, cell_bbox, dt_box, rec_res)
41 |             results.append(one_result)
42 |         return results
43 | 
44 |     def process_one(
45 |         self,
46 |         pred_struct: List[List[str]],
47 |         cell_bboxes: np.ndarray,
48 |         dt_boxes: np.ndarray,
49 |         rec_res: List[Tuple[str, float]],
50 |     ) -> str:
51 |         dt_boxes, rec_res = self.filter_ocr_result(cell_bboxes, dt_boxes, rec_res)
52 |         matched_index = self.match_result(cell_bboxes, dt_boxes)
53 |         pred_html, pred = self.get_pred_html(pred_struct[0], matched_index, rec_res)
54 |         return pred_html
55 | 
56 |     def filter_ocr_result(
57 |         self,
58 |         cell_bboxes: np.ndarray,
59 |         dt_boxes: np.ndarray,
60 |         rec_res: List[Tuple[str, float]],
61 |     ) -> Tuple[np.ndarray, List[Tuple[str, float]]]:
62 |         y1 = cell_bboxes[:, 1::2].min()
63 | 
64 |         new_dt_boxes, new_rec_res = [], []
65 |         for box, rec in zip(dt_boxes, rec_res):
66 |             if np.max(box[1::2]) < y1:
67 |                 continue
68 | 
69 |             new_dt_boxes.append(box)
70 |             new_rec_res.append(rec)
71 |         return np.array(new_dt_boxes), new_rec_res
72 | 
73 |     def match_result(
74 |         self, cell_bboxes: np.ndarray, dt_boxes: np.ndarray, min_iou: float = 0.1**8
75 |     ) -> Dict[int, List[int]]:
76 |         matched = {}
77 |         for i, gt_box in enumerate(dt_boxes):
78 |             distances = []
79 |             for j, pred_box in enumerate(cell_bboxes):
80 |                 if len(pred_box) == 8:
81 |                     pred_box = [
82 |                         np.min(pred_box[0::2]),
83 |                         np.min(pred_box[1::2]),
84 |                         np.max(pred_box[0::2]),
85 |                         np.max(pred_box[1::2]),
86 |                     ]
87 |                 distances.append(
88 |                     (distance(gt_box, pred_box), 1.0 - compute_iou(gt_box, pred_box))
89 |                 )  # store the L1 distance and the IoU complement
90 |             sorted_distances = distances.copy()
91 |             # pick the best cell by IoU first, then by L1 distance
92 |             sorted_distances = sorted(
93 |                 sorted_distances, key=lambda item: (item[1], item[0])
94 |             )
95 |             # the IoU must be greater than min_iou
96 |             if sorted_distances[0][1] >= 1 - min_iou:
97 |                 continue
98 | 
99 |             if distances.index(sorted_distances[0]) not in matched:
100 |                 matched[distances.index(sorted_distances[0])] = [i]
101 |             else:
102 |                 matched[distances.index(sorted_distances[0])].append(i)
103 |         return matched
104 | 
105 |     def get_pred_html(
106 |         self,
107 |         pred_structures: List[str],
108 |         matched_index: Dict[int, List[int]],
109 |         ocr_contents: List[Tuple[str, float]],
110 |     ):
111 |         end_html = []
112 |         td_index = 0
113 |         for tag in pred_structures:
114 |             if "</td>" not in tag:
115 |                 end_html.append(tag)
116 |                 continue
117 | 
118 |             if "<td></td>" == tag:
119 |                 end_html.extend("<td>")
120 | 
121 |             if td_index in matched_index.keys():
122 |                 b_with = False
123 |                 if (
124 |                     "<b>" in ocr_contents[matched_index[td_index][0]]
125 |                     and len(matched_index[td_index]) > 1
126 |                 ):
127 |                     b_with = True
128 |                     end_html.extend("<b>")
129 | 
130 |                 for i, td_index_index in enumerate(matched_index[td_index]):
131 |                     content = ocr_contents[td_index_index][0]
132 |                     if len(matched_index[td_index]) > 1:
133 |                         if len(content) == 0:
134 |                             continue
135 | 
136 |                         if content[0] == " ":
137 |                             content = content[1:]
138 | 
139 |                         if "<b>" in content:
140 |                             content = content[3:]
141 | 
142 |                         if "</b>" in content:
143 |                             content = content[:-4]
144 | 
145 |                         if len(content) == 0:
146 |                             continue
147 | 
148 |                         if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
149 |                             content += " "
150 |                     end_html.extend(content)
151 | 
152 |                 if b_with:
153 |                     end_html.extend("</b>")
154 | 
155 |             if "<td></td>" == tag:
156 |                 end_html.append("</td>")
157 |             else:
158 |                 end_html.append(tag)
159 | 
160 |             td_index += 1
161 | 
162 |         # filter out the thead/tbody wrapper tokens
163 |         filter_elements = ["<thead>", "</thead>", "<tbody>", "</tbody>"]
164 |         end_html = [v for v in end_html if v not in filter_elements]
165 |         return "".join(end_html), end_html
166 | 
167 |     def decode_logic_points(self, pred_structures: List[List[List[str]]]):
168 |         results = []
169 |         for pred_struct in pred_structures:
170 |             decode_result = self.decode_one_logic_points(pred_struct[0])
171 |             results.append(np.array(decode_result))
172 |         return results
173 | 
174 |     def decode_one_logic_points(self, pred_structures):
175 |         logic_points = []
176 |         current_row = 0
177 |         current_col = 0
178 |         max_rows = 0
179 |         max_cols = 0
180 |         occupied_cells = {}  # records cells already occupied by row/col spans
181 | 
182 |         def is_occupied(row, col):
183 |             return (row, col) in occupied_cells
184 | 
185 |         def mark_occupied(row, col, rowspan, colspan):
186 |             for r in range(row, row + rowspan):
187 |                 for c in range(col, col + colspan):
188 |                     occupied_cells[(r, c)] = True
189 | 
190 |         i = 0
191 |         while i < len(pred_structures):
192 |             token = pred_structures[i]
193 | 
194 |             if token == "<tr>":
195 |                 current_col = 0  # reset the column index at the start of each row
196 |             elif token == "</tr>":
197 |                 current_row += 1  # the row ends, so move to the next row index
198 |             elif token.startswith("<td"):
199 |                 colspan = 1
200 |                 rowspan = 1
201 |                 j = i
202 |                 if token == "<td":
203 |                     j = i + 1
204 |                     # collect this cell's colspan/rowspan attribute tokens
205 |                     while j < len(pred_structures) and not pred_structures[
206 |                         j
207 |                     ].startswith(">"):
208 |                         if "colspan=" in pred_structures[j]:
209 |                             colspan = int(pred_structures[j].split("=")[1].strip("\"'"))
210 |                         elif "rowspan=" in pred_structures[j]:
211 |                             rowspan = int(pred_structures[j].split("=")[1].strip("\"'"))
212 |                         j += 1
213 | 
214 |                 # skip the attribute tokens that have just been consumed
215 |                 i = j
216 | 
217 |                 # find the next column that is not occupied by an earlier span
218 |                 while is_occupied(current_row, current_col):
219 |                     current_col += 1
220 | 
221 |                 # compute the logical coordinates
222 |                 r_start = current_row
223 |                 r_end = current_row + rowspan - 1
224 |                 col_start = current_col
225 |                 col_end = current_col + colspan - 1
226 | 
227 |                 # record the logical coordinates
228 |                 logic_points.append([r_start, r_end, col_start, col_end])
229 | 
230 |                 # mark the occupied cells
231 |                 mark_occupied(r_start, col_start, rowspan, colspan)
232 | 
233 |                 # advance the current column index
234 |                 current_col += colspan
235 | 
236 |                 # update the running maximum row and column counts
237 |                 max_rows = max(max_rows, r_end + 1)
238 |                 max_cols = max(max_cols, col_end + 1)
239 | 
240 |             i += 1
241 | 
242 |         return logic_points
243 | 
--------------------------------------------------------------------------------
/rapid_table/table_matcher/utils.py:
--------------------------------------------------------------------------------
 1 | # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | # http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
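# NOTE: hand-traced example for TableMatch.decode_one_logic_points above
# (illustrative token stream, not repository data):
#
#   tokens = ["<tr>", "<td", ' colspan="2"', ">", "</td>", "<td></td>", "</tr>",
#             "<tr>", "<td></td>", "<td></td>", "<td></td>", "</tr>"]
#   -> logic points as [row_start, row_end, col_start, col_end]:
#      [0, 0, 0, 1], [0, 0, 2, 2], [1, 1, 0, 0], [1, 1, 1, 1], [1, 1, 2, 2]
#   The colspan="2" cell occupies columns 0-1 of row 0, so the plain cell
#   that follows it lands in column 2.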
14 | # -*- encoding: utf-8 -*-
15 | # @Author: SWHL
16 | # @Contact: liekkaskono@163.com
17 | import copy
18 | import re
19 | 
20 | 
21 | def deal_isolate_span(thead_part):
22 |     """
23 |     Deal with isolate span cases in this function.
24 |     They are caused by wrong predictions of the structure recognition model,
25 |     e.g. predicting <td rowspan="2"></td> as <td></td> rowspan="2">.
26 |     :param thead_part:
27 |     :return:
28 |     """
29 |     # 1. find out isolate span tokens.
30 |     isolate_pattern = (
31 |         r'<td></td> rowspan="(\d)+" colspan="(\d)+"></b></td>|'
32 |         r'<td></td> colspan="(\d)+" rowspan="(\d)+"></b></td>|'
33 |         r'<td></td> rowspan="(\d)+"></b></td>|'
34 |         r'<td></td> colspan="(\d)+"></b></td>'
35 |     )
36 |     isolate_iter = re.finditer(isolate_pattern, thead_part)
37 |     isolate_list = [i.group() for i in isolate_iter]
38 | 
39 |     # 2. find out the span numbers from the step 1 results.
40 |     span_pattern = (
41 |         r' rowspan="(\d)+" colspan="(\d)+"|'
42 |         r' colspan="(\d)+" rowspan="(\d)+"|'
43 |         r' rowspan="(\d)+"|'
44 |         r' colspan="(\d)+"'
45 |     )
46 |     corrected_list = []
47 |     for isolate_item in isolate_list:
48 |         span_part = re.search(span_pattern, isolate_item)
49 |         spanStr_in_isolateItem = span_part.group()
50 |         # 3. merge the span number into the span token format string.
51 |         if spanStr_in_isolateItem is not None:
52 |             corrected_item = f"<td{spanStr_in_isolateItem}></td>"
53 |             corrected_list.append(corrected_item)
54 |         else:
55 |             corrected_list.append(None)
56 | 
57 |     # 4. replace the original isolated tokens.
58 |     for corrected_item, isolate_item in zip(corrected_list, isolate_list):
59 |         if corrected_item is not None:
60 |             thead_part = thead_part.replace(isolate_item, corrected_item)
61 |         else:
62 |             pass
63 |     return thead_part
64 | 
65 | 
66 | def deal_duplicate_bb(thead_part):
67 |     """
68 |     Deal with duplicate <b> or </b> after the replacement.
69 |     Keep only one <b></b> pair in a <td></td> token.
70 |     :param thead_part:
71 |     :return:
72 |     """
73 |     # 1. find out <td></td> in <thead></thead>.
74 |     td_pattern = (
75 |         r'<td rowspan="(\d)+" colspan="(\d)+">(.+?)</td>|'
76 |         r'<td colspan="(\d)+" rowspan="(\d)+">(.+?)</td>|'
77 |         r'<td rowspan="(\d)+">(.+?)</td>|'
78 |         r'<td colspan="(\d)+">(.+?)</td>|'
79 |         r"<td>(.*?)</td>"
80 |     )
81 |     td_iter = re.finditer(td_pattern, thead_part)
82 |     td_list = [t.group() for t in td_iter]
83 | 
84 |     # 2. does the <td></td> contain more than one <b></b> pair?
85 |     new_td_list = []
86 |     for td_item in td_list:
87 |         if td_item.count("<b>") > 1 or td_item.count("</b>") > 1:
88 |             # multiple <b></b> in one <td></td> case.
89 |             # 1. remove all <b> and </b>
90 |             td_item = td_item.replace("<b>", "").replace("</b>", "")
91 |             # 2. replace <td> -> <td><b>, </td> -> </b></td>.
92 |             td_item = td_item.replace("<td>", "<td><b>").replace("</td>", "</b></td>")
93 |             new_td_list.append(td_item)
94 |         else:
95 |             new_td_list.append(td_item)
96 | 
97 |     # 3. replace the original thead part.
98 |     for td_item, new_td_item in zip(td_list, new_td_list):
99 |         thead_part = thead_part.replace(td_item, new_td_item)
100 |     return thead_part
101 | 
102 | 
103 | def deal_bb(result_token):
104 |     """
105 |     In our opinion, <b></b> always occurs in the context of <thead></thead> text.
106 |     This function finds all tokens inside <thead></thead> and inserts <b></b> manually.
107 |     :param result_token:
108 |     :return:
109 |     """
110 |     # find out the <thead></thead> parts.
111 |     thead_pattern = "<thead>(.*?)</thead>"
112 |     if re.search(thead_pattern, result_token) is None:
113 |         return result_token
114 |     thead_part = re.search(thead_pattern, result_token).group()
115 |     origin_thead_part = copy.deepcopy(thead_part)
116 | 
117 |     # check whether "rowspan" or "colspan" occurs in the <thead></thead> parts.
118 |     span_pattern = r'<td rowspan="(\d)+" colspan="(\d)+">|<td colspan="(\d)+" rowspan="(\d)+">|<td rowspan="(\d)+">|<td colspan="(\d)+">'
119 |     span_iter = re.finditer(span_pattern, thead_part)
120 |     span_list = [s.group() for s in span_iter]
121 |     has_span_in_head = True if len(span_list) > 0 else False
122 | 
123 |     if not has_span_in_head:
124 |         # <thead></thead> does not include "rowspan" or "colspan": branch 1.
125 |         # 1. replace <td> with <td><b>, and </td> with </b></td>
126 |         # 2. the text-line recognition output may already include <b> or </b>,
127 |         #    so we replace <b><b> with <b>, and </b></b> with </b>
128 |         thead_part = (
129 |             thead_part.replace("<td>", "<td><b>")
130 |             .replace("</td>", "</b></td>")
131 |             .replace("<b><b>", "<b>")
132 |             .replace("</b></b>", "</b>")
133 |         )
134 |     else:
135 |         # <thead></thead> includes "rowspan" or "colspan": branch 2.
136 |         # Firstly, deal with the rowspan or colspan cases.
137 |         # 1. replace > with ><b>
138 |         # 2. replace </td> with </b></td>
139 |         # 3. the text-line recognition output may already include <b> or </b>,
140 |         #    so we replace <b><b> with <b>, and </b></b> with </b>
141 | 
142 |         # Secondly, deal with ordinary cases like branch 1
143 | 
144 |         # replace ">" with "><b>"
145 |         replaced_span_list = []
146 |         for sp in span_list:
147 |             replaced_span_list.append(sp.replace(">", "><b>"))
148 |         for sp, rsp in zip(span_list, replaced_span_list):
149 |             thead_part = thead_part.replace(sp, rsp)
150 | 
151 |         # replace "</td>" with "</b></td>"
152 |         thead_part = thead_part.replace("</td>", "</b></td>")
153 | 
154 |         # remove duplicated <b> by re.sub
155 |         mb_pattern = "(<b>)+"
156 |         single_b_string = "<b>"
157 |         thead_part = re.sub(mb_pattern, single_b_string, thead_part)
158 | 
159 |         mgb_pattern = "(</b>)+"
160 |         single_gb_string = "</b>"
161 |         thead_part = re.sub(mgb_pattern, single_gb_string, thead_part)
162 | 
163 |         # ordinary cases like branch 1
164 |         thead_part = thead_part.replace("<td>", "<td><b>").replace("<b><b>", "<b>")
165 | 
166 |     # convert <td><b></b></td> back to <td></td>; an empty cell has no <b></b>,
167 |     # but a space cell (<td> </td>) is suitable for <td><b> </b></td>
168 |     thead_part = thead_part.replace("<td><b></b></td>", "<td></td>")
169 |     # deal with duplicated <b></b>
170 |     thead_part = deal_duplicate_bb(thead_part)
171 |     # deal with isolate span tokens, which are caused by wrong structure predictions,
172 |     # e.g. PMC5994107_011_00.png
173 |     thead_part = deal_isolate_span(thead_part)
174 |     # replace the original result with the new thead part.
175 |     result_token = result_token.replace(origin_thead_part, thead_part)
176 |     return result_token
177 | 
178 | 
179 | def deal_eb_token(master_token):
180 |     """
181 |     Post-process the empty-bbox tokens <eb></eb>, <eb1></eb1>, ...
182 |     emptyBboxTokenDict = {
183 |         "[]": '<eb></eb>',
184 |         "[' ']": '<eb1></eb1>',
185 |         "['<b>', ' ', '</b>']": '<eb2></eb2>',
186 |         "['\\u2028', '\\u2028']": '<eb3></eb3>',
187 |         "['<sup>', ' ', '</sup>']": '<eb4></eb4>',
188 |         "['<b>', '</b>']": '<eb5></eb5>',
189 |         "['<i>', ' ', '</i>']": '<eb6></eb6>',
190 |         "['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
191 |         "['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
192 |         "['<i>', '</i>']": '<eb9></eb9>',
193 |         "['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']": '<eb10></eb10>',
194 |     }
195 |     :param master_token:
196 |     :return:
197 |     """
198 |     master_token = master_token.replace("<eb></eb>", "<td></td>")
199 |     master_token = master_token.replace("<eb1></eb1>", "<td> </td>")
200 |     master_token = master_token.replace("<eb2></eb2>", "<td><b> </b></td>")
201 |     master_token = master_token.replace("<eb3></eb3>", "<td>\u2028\u2028</td>")
202 |     master_token = master_token.replace("<eb4></eb4>", "<td><sup> </sup></td>")
203 |     master_token = master_token.replace("<eb5></eb5>", "<td><b></b></td>")
204 |     master_token = master_token.replace("<eb6></eb6>", "<td><i> </i></td>")
205 |     master_token = master_token.replace("<eb7></eb7>", "<td><b><i></i></b></td>")
206 |     master_token = master_token.replace("<eb8></eb8>", "<td><b><i> </i></b></td>")
207 |     master_token = master_token.replace("<eb9></eb9>", "<td><i></i></td>")
208 |     master_token = master_token.replace(
209 |         "<eb10></eb10>", "<td><b> \u2028 \u2028 </b></td>"
210 |     )
211 |     return master_token
212 | 
213 | 
214 | def distance(box_1, box_2):
215 |     x1, y1, x2, y2 = box_1
216 |     x3, y3, x4, y4 = box_2
217 |     dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
218 |     dis_2 = abs(x3 - x1) + abs(y3 - y1)
219 |     dis_3 = abs(x4 - x2) + abs(y4 - y2)
220 |     return dis + min(dis_2, dis_3)
221 | 
222 | 
223 | def compute_iou(rec1, rec2):
224 |     """
225 |     computing IoU
226 |     :param rec1: (y0, x0, y1, x1), which reflects
227 |             (top, left, bottom, right)
228 |     :param rec2: (y0, x0, y1, x1)
229 |     :return: scalar value of IoU
230 |     """
231 |     # computing the area of each rectangle
232 |     S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
233 |     S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
234 | 
235 |     # computing the sum of the areas
236 |     sum_area = S_rec1 + S_rec2
237 | 
238 |     # find each edge of the intersecting rectangle
239 |     left_line = max(rec1[1], rec2[1])
240 |     right_line = min(rec1[3], rec2[3])
241 |     top_line = max(rec1[0], rec2[0])
242 |     bottom_line = min(rec1[2], rec2[2])
243 | 
244 |     # judge whether there is an intersection
245 |     if left_line >= right_line or top_line >= bottom_line:
246 |         return 0.0
247 | 
248 |     intersect = (right_line - left_line) * (bottom_line - top_line)
249 |     return (intersect / (sum_area - intersect)) * 1.0
--------------------------------------------------------------------------------
/rapid_table/table_structure/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | 
5 | 
--------------------------------------------------------------------------------
/rapid_table/table_structure/pp_structure/__init__.py:
--------------------------------------------------------------------------------
 1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | # http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
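# NOTE: hand-computed sanity check for the box utilities in
# table_matcher/utils.py above (illustrative values):
#
#   rec1 = (0, 0, 10, 10)   # (top, left, bottom, right)
#   rec2 = (5, 5, 15, 15)
#   intersection = 5 * 5 = 25; union = 100 + 100 - 25 = 175
#   compute_iou(rec1, rec2) == 25 / 175, roughly 0.143
#   distance(rec1, rec2) == 20 + min(10, 10) == 30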
14 | from .main import PPTableStructurer
--------------------------------------------------------------------------------
/rapid_table/table_structure/pp_structure/main.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | # http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Any, Dict, List, Tuple
15 | 
16 | import numpy as np
17 | 
18 | from ...inference_engine.base import get_engine
19 | from ...utils.typings import EngineType
20 | from .post_process import TableLabelDecode
21 | from .pre_process import TablePreprocess
22 | 
23 | 
24 | class PPTableStructurer:
25 |     def __init__(self, cfg: Dict[str, Any]):
26 |         if cfg["engine_type"] is None:
27 |             cfg["engine_type"] = EngineType.ONNXRUNTIME
28 | 
29 |         self.session = get_engine(cfg["engine_type"])(cfg)
30 |         self.cfg = cfg
31 | 
32 |         self.preprocess_op = TablePreprocess()
33 | 
34 |         character = self.session.get_character_list()
35 |         self.postprocess_op = TableLabelDecode(character, cfg)
36 | 
37 |     def __call__(
38 |         self, ori_imgs: List[np.ndarray]
39 |     ) -> Tuple[List[str], List[np.ndarray]]:
40 |         imgs, shape_lists = self.preprocess_op(ori_imgs)
41 | 
42 |         bbox_preds, struct_probs = self.session(imgs.copy())
43 | 
44 |         table_structs, cell_bboxes = self.postprocess_op(
45 |             bbox_preds, struct_probs, shape_lists, ori_imgs
46 |         )
47 |         return table_structs, cell_bboxes
48 | 
--------------------------------------------------------------------------------
/rapid_table/table_structure/pp_structure/post_process.py:
--------------------------------------------------------------------------------
  1 | # -*- encoding: utf-8 -*-
  2 | # @Author: SWHL
  3 | # @Contact: liekkaskono@163.com
  4 | from typing import List, Tuple
  5 | 
  6 | import numpy as np
  7 | 
  8 | from ...utils.typings import ModelType
  9 | from ..utils import wrap_with_html_struct
 10 | 
 11 | 
 12 | class TableLabelDecode:
 13 |     def __init__(self, dict_character, cfg, merge_no_span_structure=True):
 14 |         if merge_no_span_structure:
 15 |             if "<td></td>" not in dict_character:
 16 |                 dict_character.append("<td></td>")
 17 |             if "<td>" in dict_character:
 18 |                 dict_character.remove("<td>")
 19 | 
 20 |         dict_character = self.add_special_char(dict_character)
 21 |         self.char_to_index = {}
 22 |         for i, char in enumerate(dict_character):
 23 |             self.char_to_index[char] = i
 24 | 
 25 |         self.character = dict_character
 26 |         self.td_token = ["<td>", "<td></td>"]
 27 |         self.cfg = cfg
 28 | 
 29 |     def __call__(
 30 |         self,
 31 |         bbox_preds: np.ndarray,
 32 |         structure_probs: np.ndarray,
 33 |         shape_list: np.ndarray,
 34 |         ori_imgs: np.ndarray,
 35 |     ):
 36 |         result = self.decode(bbox_preds, structure_probs, shape_list, ori_imgs)
 37 |         return result
 38 | 
 39 |     def decode(
 40 |         self,
 41 |         bbox_preds: np.ndarray,
 42 |         structure_probs: np.ndarray,
 43 |         shape_list: np.ndarray,
 44 |         ori_imgs: np.ndarray,
 45 |     ) -> Tuple[List[Tuple[List[str], float]], List[np.ndarray]]:
 46 |         """Convert text labels into text indices."""
 47 |         ignored_tokens = self.get_ignored_tokens()
 48 |         end_idx = self.char_to_index[self.end_str]
 49 | 
 50 |         structure_idx = structure_probs.argmax(axis=2)
 51 |         structure_probs = structure_probs.max(axis=2)
 52 | 
 53 |         table_structs, cell_bboxes = [], []
 54 |         batch_size = len(structure_idx)
 55 |         for batch_idx in range(batch_size):
 56 |             structure_list, bbox_list, score_list = [], [], []
 57 |             for idx in range(len(structure_idx[batch_idx])):
 58 |                 char_idx = int(structure_idx[batch_idx][idx])
 59 |                 if idx > 0 and char_idx == end_idx:
 60 |                     break
 61 | 
 62 |                 if char_idx in ignored_tokens:
 63 |                     continue
 64 | 
 65 |                 text = self.character[char_idx]
 66 |                 if text in self.td_token:
 67 |                     bbox = bbox_preds[batch_idx, idx]
 68 |                     bbox = self._bbox_decode(bbox, shape_list[batch_idx])
 69 |                     bbox_list.append(bbox)
 70 | 
 71 |                 structure_list.append(text)
 72 |                 score_list.append(structure_probs[batch_idx, idx])
 73 | 
 74 |             bboxes = self.normalize_bboxes(bbox_list, ori_imgs[batch_idx])
 75 |             cell_bboxes.append(bboxes)
 76 | 
 77 |             table_structs.append(
 78 |                 (wrap_with_html_struct(structure_list), float(np.mean(score_list)))
 79 |             )
 80 |         return table_structs, cell_bboxes
 81 | 
 82 |     def _bbox_decode(self, bbox, shape):
 83 |         h, w = shape[:2]
 84 |         bbox[0::2] *= w
 85 |         bbox[1::2] *= h
 86 |         return bbox
 87 | 
 88 |     def get_ignored_tokens(self):
 89 |         beg_idx = self.get_beg_end_flag_idx("beg")
 90 |         end_idx = self.get_beg_end_flag_idx("end")
 91 |         return [beg_idx, end_idx]
 92 | 
 93 |     def get_beg_end_flag_idx(self, beg_or_end):
 94 |         if beg_or_end == "beg":
 95 |             return np.array(self.char_to_index[self.beg_str])
 96 | 
 97 |         if beg_or_end == "end":
 98 |             return np.array(self.char_to_index[self.end_str])
 99 | 
100 |         raise TypeError(f"unsupported type {beg_or_end} in get_beg_end_flag_idx")
101 | 
102 |     def add_special_char(self, dict_character):
103 |         self.beg_str = "sos"
104 |         self.end_str = "eos"
105 |         dict_character = [self.beg_str] + dict_character + [self.end_str]
106 |         return dict_character
107 | 
108 |     def normalize_bboxes(self, bbox_list, ori_imgs):
109 |         cell_bboxes = np.array(bbox_list)
110 |         if self.cfg["model_type"] == ModelType.SLANETPLUS:
111 |             cell_bboxes = self.rescale_cell_bboxes(ori_imgs, cell_bboxes)
112 |         cell_bboxes = self.filter_blank_bbox(cell_bboxes)
113 |         return cell_bboxes
114 | 
115 |     def rescale_cell_bboxes(
116 |         self, img: np.ndarray, cell_bboxes: np.ndarray
117 |     ) -> np.ndarray:
118 |         h, w = img.shape[:2]
119 |         resized = 488
120 |         ratio = min(resized / h, resized / w)
121 |         w_ratio = resized / (w * ratio)
122 |         h_ratio = resized / (h * ratio)
123 |         cell_bboxes[:, 0::2] *= w_ratio
124 |         cell_bboxes[:, 1::2] *= h_ratio
125 |         return cell_bboxes
126 | 
127 |     @staticmethod
128 |     def filter_blank_bbox(cell_bboxes: np.ndarray) -> np.ndarray:
129 |         # filter out placeholder bboxes
130 |         mask = ~np.all(cell_bboxes == 0, axis=1)
131 |         return cell_bboxes[mask]
--------------------------------------------------------------------------------
/rapid_table/table_structure/pp_structure/pre_process.py:
--------------------------------------------------------------------------------
 1 | # -*- encoding: utf-8 -*-
 2 | # @Author: SWHL
 3 | # @Contact: liekkaskono@163.com
 4 | from typing import List, Tuple
 5 | 
 6 | import cv2
 7 | import numpy as np
 8 | 
 9 | 
10 | class TablePreprocess:
11 |     def __init__(self, max_len: int = 488):
12 |         self.max_len = max_len
13 | 
14 |         self.std = np.array([0.229, 0.224, 0.225])
15 |         self.mean = np.array([0.485, 0.456, 0.406])
16 |         self.scale = 1 / 255.0
17 | 
18 |     def __call__(
19 |         self, img_list: List[np.ndarray]
20 |     ) -> Tuple[List[np.ndarray], np.ndarray]:
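        # Shape bookkeeping for the steps below: each image is resized so that
        # max(h, w) == self.max_len (488), normalized, zero-padded to a
        # 488 x 488 canvas, and transposed to CHW; each per-image shape_list
        # ends up as [ori_h, ori_w, ratio_h, ratio_w, pad_h, pad_w].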
21 |         if isinstance(img_list, np.ndarray):
22 |             img_list = [img_list]
23 | 
24 |         processed_imgs, shape_lists = [], []
25 |         for img in img_list:
26 |             if img is None:
27 |                 continue
28 | 
29 |             img_processed, shape_list = self.resize_image(img)
30 |             img_processed = self.normalize(img_processed)
31 |             img_processed, shape_list = self.pad_img(img_processed, shape_list)
32 |             img_processed = self.to_chw(img_processed)
33 | 
34 |             processed_imgs.append(img_processed)
35 |             shape_lists.append(shape_list)
36 | 
37 |         return processed_imgs, np.array(shape_lists)
38 | 
39 |     def resize_image(self, img: np.ndarray) -> Tuple[np.ndarray, List[float]]:
40 |         h, w = img.shape[:2]
41 |         ratio = self.max_len / (max(h, w) * 1.0)
42 |         resize_h, resize_w = int(h * ratio), int(w * ratio)
43 | 
44 |         resize_img = cv2.resize(img, (resize_w, resize_h))
45 |         return resize_img, [h, w, ratio, ratio]
46 | 
47 |     def normalize(self, img: np.ndarray) -> np.ndarray:
48 |         return (img.astype("float32") * self.scale - self.mean) / self.std
49 | 
50 |     def pad_img(
51 |         self, img: np.ndarray, shape: List[float]
52 |     ) -> Tuple[np.ndarray, List[float]]:
53 |         padding_img = np.zeros((self.max_len, self.max_len, 3), dtype=np.float32)
54 |         h, w = img.shape[:2]
55 |         padding_img[:h, :w, :] = img.copy()
56 |         shape.extend([self.max_len, self.max_len])
57 |         return padding_img, shape
58 | 
59 |     def to_chw(self, img: np.ndarray) -> np.ndarray:
60 |         return img.transpose((2, 0, 1))
61 | 
--------------------------------------------------------------------------------
/rapid_table/table_structure/unitable/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from .main import UniTableStructure
--------------------------------------------------------------------------------
/rapid_table/table_structure/unitable/consts.py:
--------------------------------------------------------------------------------
 1 | # -*- encoding: utf-8 -*-
 2 | # @Author: SWHL
 3 | # @Contact: liekkaskono@163.com
 4 | IMG_SIZE = 448
 5 | MAX_SEQ_LEN = 1024
 6 | EOS_TOKEN = "<eos>"
 7 | BBOX_TOKENS = [f"bbox-{i}" for i in range(IMG_SIZE + 1)]
 8 | 
 9 | HTML_BBOX_HTML_TOKENS = [
10 |     "<td></td>",
11 |     "<td>[",
12 |     "]</td>",
13 |     "<td>[</td>",
14 |     "<td",
15 |     ">",
16 |     "<tr>",
17 |     "</tr>",
18 |     "<tbody>",
19 |     "</tbody>",
20 |     "<thead>",
21 |     "</thead>",
22 |     ' rowspan="2"',
23 |     ' rowspan="3"',
24 |     ' rowspan="4"',
25 |     ' rowspan="5"',
26 |     ' rowspan="6"',
27 |     ' rowspan="7"',
28 |     ' rowspan="8"',
29 |     ' rowspan="9"',
30 |     ' rowspan="10"',
31 |     ' rowspan="11"',
32 |     ' rowspan="12"',
33 |     ' rowspan="13"',
34 |     ' rowspan="14"',
35 |     ' rowspan="15"',
36 |     ' rowspan="16"',
37 |     ' rowspan="17"',
38 |     ' rowspan="18"',
39 |     ' rowspan="19"',
40 |     ' colspan="2"',
41 |     ' colspan="3"',
42 |     ' colspan="4"',
43 |     ' colspan="5"',
44 |     ' colspan="6"',
45 |     ' colspan="7"',
46 |     ' colspan="8"',
47 |     ' colspan="9"',
48 |     ' colspan="10"',
49 |     ' colspan="11"',
50 |     ' colspan="12"',
51 |     ' colspan="13"',
52 |     ' colspan="14"',
53 |     ' colspan="15"',
54 |     ' colspan="16"',
55 |     ' colspan="17"',
56 |     ' colspan="18"',
57 |     ' colspan="19"',
58 |     ' colspan="25"',
59 | ]
60 | 
61 | VALID_HTML_BBOX_TOKENS = [EOS_TOKEN] + HTML_BBOX_HTML_TOKENS + BBOX_TOKENS
62 | TASK_TOKENS = [
63 |     "[table]",
64 |     "[html]",
65 |     "[cell]",
66 |     "[bbox]",
67 |     "[cell+bbox]",
68 |     "[html+bbox]",
69 | ]
--------------------------------------------------------------------------------
/rapid_table/table_structure/unitable/main.py:
--------------------------------------------------------------------------------
  1 | # -*- encoding: utf-8 -*-
  2 | import re
  3 | from typing import Any, Dict, List
  4 | 
  5 | import numpy as np
  6 | import torch
  7 | 
  8 | from ...inference_engine.base import get_engine
  9 | from ...utils import EngineType
 10 | from ..utils import wrap_with_html_struct
 11 | from .consts import (
 12 |     BBOX_TOKENS,
 13 |     EOS_TOKEN,
 14 |     MAX_SEQ_LEN,
 15 |     TASK_TOKENS,
 16 |     VALID_HTML_BBOX_TOKENS,
 17 | )
 18 | from .post_process import rescale_bboxes
 19 | from .pre_process import TablePreprocess
 20 | 
 21 | 
 22 | class UniTableStructure:
 23 |     def __init__(self, cfg: Dict[str, Any]):
 24 |         if cfg["engine_type"] is None:
 25 |             cfg["engine_type"] = EngineType.TORCH
 26 |         self.model = get_engine(cfg["engine_type"])(cfg)
 27 | 
 28 |         self.encoder = self.model.encoder
 29 |         self.device = self.model.device
 30 | 
 31 |         self.vocab = self.model.vocab
 32 | 
 33 |         self.token_white_list = [
 34 |             self.vocab.token_to_id(i) for i in VALID_HTML_BBOX_TOKENS
 35 |         ]
 36 | 
 37 |         self.bbox_token_ids = set(self.vocab.token_to_id(i) for i in BBOX_TOKENS)
 38 |         self.bbox_close_html_token = self.vocab.token_to_id("]</td>")
 39 | 
 40 |         self.prefix_token_id = self.vocab.token_to_id("[html+bbox]")
 41 | 
 42 |         self.eos_id = self.vocab.token_to_id(EOS_TOKEN)
 43 | 
 44 |         self.context = (
 45 |             torch.tensor([self.prefix_token_id], dtype=torch.int32)
 46 |             .repeat(1, 1)
 47 |             .to(self.device)
 48 |         )
 49 |         self.eos_id_tensor = torch.tensor(self.eos_id, dtype=torch.int32).to(
 50 |             self.device
 51 |         )
 52 | 
 53 |         self.max_seq_len = MAX_SEQ_LEN
 54 | 
 55 |         self.decoder = self.model.decoder
 56 | 
 57 |         self.preprocess_op = TablePreprocess(self.device)
 58 | 
 59 |     def __call__(self, imgs: List[np.ndarray]):
 60 |         img_batch, ori_shapes = self.preprocess_op(imgs)
 61 |         memory_batch = self.encoder(img_batch)
 62 | 
 63 |         struct_list, total_bboxes = [], []
 64 |         for i, memory in enumerate(memory_batch):
 65 |             self.decoder.setup_caches(
 66 |                 max_batch_size=1,
 67 |                 max_seq_length=self.max_seq_len,
 68 |                 dtype=torch.float32,
 69 |                 device=self.device,
 70 |             )
 71 | 
 72 |             context = self.loop_decode(
 73 |                 self.context, self.eos_id_tensor, memory[None, ...]
 74 |             )
 75 |             bboxes, html_tokens = self.decode_tokens(context)
 76 | 
 77 |             ori_h, ori_w = ori_shapes[i]
 78 |             one_bboxes = rescale_bboxes(ori_h, ori_w, bboxes)
 79 |             total_bboxes.append(one_bboxes)
 80 | 
 81 |             one_struct = wrap_with_html_struct(html_tokens)
 82 |             struct_list.append((one_struct, 1.0))
 83 |         return struct_list, total_bboxes
 84 | 
 85 |     def loop_decode(self, context, eos_id_tensor, memory):
 86 |         box_token_count = 0
 87 |         for _ in range(self.max_seq_len):
 88 |             eos_flag = (context == eos_id_tensor).any(dim=1)
 89 |             if torch.all(eos_flag):
 90 |                 break
 91 | 
 92 |             next_tokens = self.decoder(memory, context)
 93 |             if next_tokens[0] in self.bbox_token_ids:
 94 |                 box_token_count += 1
 95 |                 if box_token_count > 4:
 96 |                     next_tokens = torch.tensor(
 97 |                         [self.bbox_close_html_token], dtype=torch.int32
 98 |                     )
 99 |                     box_token_count = 0
100 |             context = torch.cat([context, next_tokens], dim=1)
101 |         return context
102 | 
103 |     def decode_tokens(self, context):
104 |         pred_html = context[0]
105 |         pred_html = pred_html.detach().cpu().numpy()
106 |         pred_html = self.vocab.decode(pred_html, skip_special_tokens=False)
107 |         seq = pred_html.split("<eos>")[0]
108 |         token_black_list = ["<eos>", "<pad>", *TASK_TOKENS]
109 |         for i in token_black_list:
110 |             seq = seq.replace(i, "")
111 | 
112 |         tr_pattern = re.compile(r"<tr>(.*?)</tr>", re.DOTALL)
113 |         td_pattern = re.compile(r"<td(.*?)>(.*?)</td>", re.DOTALL)
114 |         bbox_pattern = re.compile(r"\[ bbox-(\d+) bbox-(\d+) bbox-(\d+) bbox-(\d+) \]")
115 | 
116 |         decoded_list, bbox_coords = [], []
117 | 
118 |         # find all <tr> tags
119 |         for tr_match in tr_pattern.finditer(pred_html):
120 |             tr_content = tr_match.group(1)
121 |             decoded_list.append("<tr>")
122 | 
123 |             # find all <td> tags inside this row
124 |             for td_match in td_pattern.finditer(tr_content):
125 |                 td_attrs = td_match.group(1).strip()
126 |                 td_content = td_match.group(2).strip()
127 |                 if td_attrs:
128 |                     decoded_list.append("<td")
129 |                     attrs_list = td_attrs.split()
130 |                     for attr in attrs_list:
131 |                         decoded_list.append(" " + attr)
132 | 
133 |                     decoded_list.append(">")
134 |                     decoded_list.append("</td>")
135 |                 else:
136 |                     decoded_list.append("<td></td>")
137 | 
138 |                 # extract the bbox coordinates
139 |                 bbox_match = bbox_pattern.search(td_content)
140 |                 if bbox_match:
141 |                     xmin, ymin, xmax, ymax = map(int, bbox_match.groups())
142 |                     # convert to the four corner points, clockwise from the top-left
143 |                     coords = np.array([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax])
144 |                     bbox_coords.append(coords)
145 |                 else:
146 |                     # pad with a placeholder bbox to keep the downstream flow uniform
147 |                     bbox_coords.append(np.array([0, 0, 0, 0, 0, 0, 0, 0]))
148 |             decoded_list.append("</tr>")
149 | 
150 |         bbox_coords_array = np.array(bbox_coords).astype(np.float32)
151 |         return bbox_coords_array, decoded_list
--------------------------------------------------------------------------------
/rapid_table/table_structure/unitable/post_process.py:
--------------------------------------------------------------------------------
 1 | # -*- encoding: utf-8 -*-
 2 | # @Author: SWHL
 3 | # @Contact: liekkaskono@163.com
 4 | import numpy as np
 5 | 
 6 | from .consts import IMG_SIZE
 7 | 
 8 | 
 9 | def rescale_bboxes(ori_h, ori_w, bboxes):
10 |     scale_h = ori_h / IMG_SIZE
11 |     scale_w = ori_w / IMG_SIZE
12 |     bboxes[:, 0::2] *= scale_w
13 |     bboxes[:, 1::2] *= scale_h
14 |     bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, ori_w - 1)
15 |     bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, ori_h - 1)
16 |     return bboxes
--------------------------------------------------------------------------------
/rapid_table/table_structure/unitable/pre_process.py:
--------------------------------------------------------------------------------
 1 | # -*- encoding: utf-8 -*-
 2 | # @Author: SWHL
 3 | # @Contact: liekkaskono@163.com
 4 | from typing import List
 5 | 
 6 | import cv2
 7 | import numpy as np
 8 | import torch
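# NOTE: the Normalize mean/std in TablePreprocess below are not the usual
# ImageNet statistics; they appear to be the table-image statistics shipped
# with the upstream UniTable weights.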
 9 | from PIL import Image
10 | from torchvision import transforms
11 | 
12 | from .consts import IMG_SIZE
13 | 
14 | 
15 | class TablePreprocess:
16 |     def __init__(self, device: str):
17 |         self.img_size = IMG_SIZE
18 |         self.transform = transforms.Compose(
19 |             [
20 |                 transforms.Resize((448, 448)),
21 |                 transforms.ToTensor(),
22 |                 transforms.Normalize(
23 |                     mean=[0.86597056, 0.88463002, 0.87491087],
24 |                     std=[0.20686628, 0.18201602, 0.18485524],
25 |                 ),
26 |             ]
27 |         )
28 | 
29 |         self.device = device
30 | 
31 |     def __call__(self, imgs: List[np.ndarray]):
32 |         processed_imgs, ori_shapes = [], []
33 |         for img in imgs:
34 |             if img is None:
35 |                 continue
36 | 
37 |             ori_h, ori_w = img.shape[:2]
38 |             ori_shapes.append((ori_h, ori_w))
39 | 
40 |             processed_img = self.preprocess_img(img)
41 |             processed_imgs.append(processed_img)
42 |         return torch.concatenate(processed_imgs), ori_shapes
43 | 
44 |     def preprocess_img(self, ori_image: np.ndarray) -> torch.Tensor:
45 |         image = cv2.cvtColor(ori_image, cv2.COLOR_BGR2RGB)
46 |         image = Image.fromarray(image)
47 |         image = self.transform(image).unsqueeze(0).to(self.device)
48 |         return image
--------------------------------------------------------------------------------
/rapid_table/table_structure/unitable/unitable_modules.py:
--------------------------------------------------------------------------------
  1 | from dataclasses import dataclass
  2 | from functools import partial
  3 | from typing import Optional
  4 | 
  5 | import torch
  6 | from torch import Tensor, nn
  7 | from torch.nn import functional as F
  8 | from torch.nn.attention import SDPBackend, sdpa_kernel
  9 | from torch.nn.modules.transformer import _get_activation_fn
 10 | 
 11 | # Token ids allowed during decoding: id 1 (the decoder's eos id) plus ids
 12 | # 12-509; equivalent to the 499-entry literal list this replaces.
 13 | TOKEN_WHITE_LIST = [1] + list(range(12, 510))
513 | 
514 | class ImgLinearBackbone(nn.Module):
515 |     def __init__(
516 | self, 517 | d_model: int, 518 | patch_size: int, 519 | in_chan: int = 3, 520 | ) -> None: 521 | super().__init__() 522 | 523 | self.conv_proj = nn.Conv2d( 524 | in_chan, 525 | out_channels=d_model, 526 | kernel_size=patch_size, 527 | stride=patch_size, 528 | ) 529 | self.d_model = d_model 530 | 531 | def forward(self, x: Tensor) -> Tensor: 532 | x = self.conv_proj(x) 533 | x = x.flatten(start_dim=-2).transpose(1, 2) 534 | return x 535 | 536 | 537 | class Encoder(nn.Module): 538 | def __init__(self) -> None: 539 | super().__init__() 540 | 541 | self.patch_size = 16 542 | self.d_model = 768 543 | self.dropout = 0 544 | self.activation = "gelu" 545 | self.norm_first = True 546 | self.ff_ratio = 4 547 | self.nhead = 12 548 | self.max_seq_len = 1024 549 | self.n_encoder_layer = 12 550 | encoder_layer = nn.TransformerEncoderLayer( 551 | self.d_model, 552 | nhead=self.nhead, 553 | dim_feedforward=self.ff_ratio * self.d_model, 554 | dropout=self.dropout, 555 | activation=self.activation, 556 | batch_first=True, 557 | norm_first=self.norm_first, 558 | ) 559 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 560 | self.norm = norm_layer(self.d_model) 561 | self.backbone = ImgLinearBackbone( 562 | d_model=self.d_model, patch_size=self.patch_size 563 | ) 564 | self.pos_embed = PositionEmbedding( 565 | max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout 566 | ) 567 | self.encoder = nn.TransformerEncoder( 568 | encoder_layer, num_layers=self.n_encoder_layer, enable_nested_tensor=False 569 | ) 570 | 571 | def forward(self, x: Tensor) -> Tensor: 572 | src_feature = self.backbone(x) 573 | src_feature = self.pos_embed(src_feature) 574 | memory = self.encoder(src_feature) 575 | memory = self.norm(memory) 576 | return memory 577 | 578 | 579 | class PositionEmbedding(nn.Module): 580 | def __init__(self, max_seq_len: int, d_model: int, dropout: float) -> None: 581 | super().__init__() 582 | self.embedding = nn.Embedding(max_seq_len, d_model) 583 | self.dropout = nn.Dropout(dropout) 584 | 585 | def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: 586 | # assume x is batch first 587 | if input_pos is None: 588 | _pos = torch.arange(x.shape[1], device=x.device) 589 | else: 590 | _pos = input_pos 591 | out = self.embedding(_pos) 592 | return self.dropout(out + x) 593 | 594 | 595 | class TokenEmbedding(nn.Module): 596 | def __init__( 597 | self, 598 | vocab_size: int, 599 | d_model: int, 600 | padding_idx: int, 601 | ) -> None: 602 | super().__init__() 603 | assert vocab_size > 0 604 | self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx) 605 | 606 | def forward(self, x: Tensor) -> Tensor: 607 | return self.embedding(x) 608 | 609 | 610 | def find_multiple(n: int, k: int) -> int: 611 | if n % k == 0: 612 | return n 613 | return n + k - (n % k) 614 | 615 | 616 | @dataclass 617 | class ModelArgs: 618 | n_layer: int = 4 619 | n_head: int = 12 620 | dim: int = 768 621 | intermediate_size: int = None 622 | head_dim: int = 64 623 | activation: str = "gelu" 624 | norm_first: bool = True 625 | 626 | def __post_init__(self): 627 | if self.intermediate_size is None: 628 | hidden_dim = 4 * self.dim 629 | n_hidden = int(2 * hidden_dim / 3) 630 | self.intermediate_size = find_multiple(n_hidden, 256) 631 | self.head_dim = self.dim // self.n_head 632 | 633 | 634 | class KVCache(nn.Module): 635 | def __init__( 636 | self, 637 | max_batch_size, 638 | max_seq_length, 639 | n_heads, 640 | head_dim, 641 | dtype=torch.bfloat16, 642 | device="cpu", 643 | ): 644 | 
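):
        # Pre-allocated KV cache of shape (max_batch_size, n_heads,
        # max_seq_length, head_dim); registered via register_buffer with
        # persistent=False below so the buffers stay out of state_dict and
        # are simply overwritten as decoding advances.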
super().__init__() 645 | cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) 646 | self.register_buffer( 647 | "k_cache", 648 | torch.zeros(cache_shape, dtype=dtype, device=device), 649 | persistent=False, 650 | ) 651 | self.register_buffer( 652 | "v_cache", 653 | torch.zeros(cache_shape, dtype=dtype, device=device), 654 | persistent=False, 655 | ) 656 | 657 | def update(self, input_pos, k_val, v_val): 658 | bs = k_val.shape[0] 659 | k_out = self.k_cache 660 | v_out = self.v_cache 661 | k_out[:bs, :, input_pos] = k_val 662 | v_out[:bs, :, input_pos] = v_val 663 | 664 | return k_out[:bs], v_out[:bs] 665 | 666 | 667 | class GPTFastDecoder(nn.Module): 668 | def __init__(self) -> None: 669 | super().__init__() 670 | 671 | self.vocab_size = 960 672 | self.padding_idx = 2 673 | self.prefix_token_id = 11 674 | self.eos_id = 1 675 | self.max_seq_len = 1024 676 | self.dropout = 0 677 | self.d_model = 768 678 | self.nhead = 12 679 | self.activation = "gelu" 680 | self.norm_first = True 681 | self.n_decoder_layer = 4 682 | config = ModelArgs( 683 | n_layer=self.n_decoder_layer, 684 | n_head=self.nhead, 685 | dim=self.d_model, 686 | intermediate_size=self.d_model * 4, 687 | activation=self.activation, 688 | norm_first=self.norm_first, 689 | ) 690 | self.config = config 691 | self.layers = nn.ModuleList( 692 | TransformerBlock(config) for _ in range(config.n_layer) 693 | ) 694 | self.token_embed = TokenEmbedding( 695 | vocab_size=self.vocab_size, 696 | d_model=self.d_model, 697 | padding_idx=self.padding_idx, 698 | ) 699 | self.pos_embed = PositionEmbedding( 700 | max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout 701 | ) 702 | self.generator = nn.Linear(self.d_model, self.vocab_size) 703 | self.token_white_list = TOKEN_WHITE_LIST 704 | self.mask_cache: Optional[Tensor] = None 705 | self.max_batch_size = -1 706 | self.max_seq_length = -1 707 | 708 | def setup_caches(self, max_batch_size, max_seq_length, dtype, device): 709 | for b in self.layers: 710 | b.multihead_attn.k_cache = None 711 | b.multihead_attn.v_cache = None 712 | 713 | if ( 714 | self.max_seq_length >= max_seq_length 715 | and self.max_batch_size >= max_batch_size 716 | ): 717 | return 718 | head_dim = self.config.dim // self.config.n_head 719 | max_seq_length = find_multiple(max_seq_length, 8) 720 | self.max_seq_length = max_seq_length 721 | self.max_batch_size = max_batch_size 722 | 723 | for b in self.layers: 724 | b.self_attn.kv_cache = KVCache( 725 | max_batch_size, 726 | max_seq_length, 727 | self.config.n_head, 728 | head_dim, 729 | dtype, 730 | device, 731 | ) 732 | b.multihead_attn.k_cache = None 733 | b.multihead_attn.v_cache = None 734 | 735 | self.causal_mask = torch.tril( 736 | torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool) 737 | ).to(device) 738 | 739 | def forward(self, memory: Tensor, tgt: Tensor) -> Tensor: 740 | input_pos = torch.tensor([tgt.shape[1] - 1], device=tgt.device, dtype=torch.int) 741 | tgt = tgt[:, -1:] 742 | tgt_feature = self.pos_embed(self.token_embed(tgt), input_pos=input_pos) 743 | 744 | with sdpa_kernel(SDPBackend.MATH): 745 | logits = tgt_feature 746 | tgt_mask = self.causal_mask[None, None, input_pos] 747 | for layer in self.layers: 748 | logits = layer(logits, memory, input_pos=input_pos, tgt_mask=tgt_mask) 749 | 750 | logits = self.generator(logits)[:, -1, :] 751 | total = set(range(logits.shape[-1])) 752 | black_list = list(total.difference(set(self.token_white_list))) 753 | logits[..., black_list] = -1e9 754 | probs = F.softmax(logits, 
dim=-1) 755 | _, next_tokens = probs.topk(1) 756 | return next_tokens 757 | 758 | 759 | class TransformerBlock(nn.Module): 760 | def __init__(self, config: ModelArgs) -> None: 761 | super().__init__() 762 | self.self_attn = Attention(config) 763 | self.multihead_attn = CrossAttention(config) 764 | 765 | layer_norm_eps = 1e-5 766 | 767 | d_model = config.dim 768 | dim_feedforward = config.intermediate_size 769 | 770 | self.linear1 = nn.Linear(d_model, dim_feedforward) 771 | self.linear2 = nn.Linear(dim_feedforward, d_model) 772 | 773 | self.norm_first = config.norm_first 774 | self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) 775 | self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) 776 | self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps) 777 | 778 | self.activation = _get_activation_fn(config.activation) 779 | 780 | def forward( 781 | self, 782 | tgt: Tensor, 783 | memory: Tensor, 784 | tgt_mask: Tensor, 785 | input_pos: Tensor, 786 | ) -> Tensor: 787 | if self.norm_first: 788 | x = tgt 789 | x = x + self.self_attn(self.norm1(x), tgt_mask, input_pos) 790 | x = x + self.multihead_attn(self.norm2(x), memory) 791 | x = x + self._ff_block(self.norm3(x)) 792 | else: 793 | x = tgt 794 | x = self.norm1(x + self.self_attn(x, tgt_mask, input_pos)) 795 | x = self.norm2(x + self.multihead_attn(x, memory)) 796 | x = self.norm3(x + self._ff_block(x)) 797 | return x 798 | 799 | def _ff_block(self, x: Tensor) -> Tensor: 800 | x = self.linear2(self.activation(self.linear1(x))) 801 | return x 802 | 803 | 804 | class Attention(nn.Module): 805 | def __init__(self, config: ModelArgs): 806 | super().__init__() 807 | assert config.dim % config.n_head == 0 808 | 809 | # key, query, value projections for all heads, but in a batch 810 | self.wqkv = nn.Linear(config.dim, 3 * config.dim) 811 | self.wo = nn.Linear(config.dim, config.dim) 812 | 813 | self.kv_cache: Optional[KVCache] = None 814 | 815 | self.n_head = config.n_head 816 | self.head_dim = config.head_dim 817 | self.dim = config.dim 818 | 819 | def forward( 820 | self, 821 | x: Tensor, 822 | mask: Tensor, 823 | input_pos: Optional[Tensor] = None, 824 | ) -> Tensor: 825 | bsz, seqlen, _ = x.shape 826 | 827 | kv_size = self.n_head * self.head_dim 828 | q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) 829 | 830 | q = q.view(bsz, seqlen, self.n_head, self.head_dim) 831 | k = k.view(bsz, seqlen, self.n_head, self.head_dim) 832 | v = v.view(bsz, seqlen, self.n_head, self.head_dim) 833 | 834 | q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) 835 | 836 | if self.kv_cache is not None: 837 | k, v = self.kv_cache.update(input_pos, k, v) 838 | 839 | y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) 840 | 841 | y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) 842 | 843 | y = self.wo(y) 844 | return y 845 | 846 | 847 | class CrossAttention(nn.Module): 848 | def __init__(self, config: ModelArgs): 849 | super().__init__() 850 | assert config.dim % config.n_head == 0 851 | 852 | self.query = nn.Linear(config.dim, config.dim) 853 | self.key = nn.Linear(config.dim, config.dim) 854 | self.value = nn.Linear(config.dim, config.dim) 855 | self.out = nn.Linear(config.dim, config.dim) 856 | 857 | self.k_cache = None 858 | self.v_cache = None 859 | 860 | self.n_head = config.n_head 861 | self.head_dim = config.head_dim 862 | 863 | def get_kv(self, xa: torch.Tensor): 864 | if self.k_cache is not None and self.v_cache is not None: 865 | return self.k_cache, self.v_cache 866 | 867 | k = self.key(xa) 868 | v = 
self.value(xa)
869 | 
870 |         # Reshape for correct format
871 |         batch_size, source_seq_len, _ = k.shape
872 |         k = k.view(batch_size, source_seq_len, self.n_head, self.head_dim)
873 |         v = v.view(batch_size, source_seq_len, self.n_head, self.head_dim)
874 | 
875 |         if self.k_cache is None:  # encoder memory is fixed, so compute K/V once and reuse every decode step
876 |             self.k_cache = k
877 | 
878 |         if self.v_cache is None:
879 |             self.v_cache = v
880 | 
881 |         return k, v
882 | 
883 |     def forward(
884 |         self,
885 |         x: Tensor,
886 |         xa: Tensor,
887 |     ):
888 |         q = self.query(x)
889 |         batch_size, target_seq_len, _ = q.shape
890 |         q = q.view(batch_size, target_seq_len, self.n_head, self.head_dim)
891 |         k, v = self.get_kv(xa)
892 | 
893 |         q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
894 | 
895 |         wv = F.scaled_dot_product_attention(
896 |             query=q,
897 |             key=k,
898 |             value=v,
899 |             is_causal=False,
900 |         )
901 |         wv = wv.transpose(1, 2).reshape(
902 |             batch_size,
903 |             target_seq_len,
904 |             self.n_head * self.head_dim,
905 |         )
906 | 
907 |         return self.out(wv)
908 | 
--------------------------------------------------------------------------------
/rapid_table/table_structure/utils.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from typing import List
5 | 
6 | 
7 | def wrap_with_html_struct(structure_str_list: List[str]) -> List[str]:
8 |     structure_str_list = (
9 |         ["<html>", "<body>", "<table>"]
10 |         + structure_str_list
11 |         + ["</table>", "</body>", "</html>"]
12 |     )
13 |     return structure_str_list
14 | 
--------------------------------------------------------------------------------
/rapid_table/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from .download_file import DownloadFile, DownloadFileInput
5 | from .load_image import InputType, LoadImage
6 | from .logger import Logger
7 | from .typings import EngineType, ModelType, RapidTableInput, RapidTableOutput
8 | from .utils import format_ocr_results, import_package, is_url, mkdir, read_yaml
9 | from .vis import VisTable
10 | 
--------------------------------------------------------------------------------
/rapid_table/utils/download_file.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | import logging
5 | import sys
6 | from dataclasses import dataclass
7 | from pathlib import Path
8 | from typing import Optional, Union
9 | 
10 | import requests
11 | from tqdm import tqdm
12 | 
13 | from .utils import get_file_sha256
14 | 
15 | 
16 | @dataclass
17 | class DownloadFileInput:
18 |     file_url: str
19 |     save_path: Union[str, Path]
20 |     logger: logging.Logger
21 |     sha256: Optional[str] = None
22 | 
23 | 
24 | class DownloadFile:
25 |     BLOCK_SIZE = 1024  # 1 KiB
26 |     REQUEST_TIMEOUT = 60
27 | 
28 |     @classmethod
29 |     def run(cls, input_params: DownloadFileInput):
30 |         save_path = Path(input_params.save_path)
31 | 
32 |         logger = input_params.logger
33 |         cls._ensure_parent_dir_exists(save_path)
34 |         if cls._should_skip_download(save_path, input_params.sha256, logger):
35 |             return
36 | 
37 |         response = cls._make_http_request(input_params.file_url, logger)
38 |         cls._save_response_with_progress(response, save_path, logger)
39 | 
40 |     @staticmethod
41 |     def _ensure_parent_dir_exists(path: Path):
42 |         path.parent.mkdir(parents=True, exist_ok=True)
43 | 
44 |     @classmethod
45 |     def _should_skip_download(
46 |         cls, path: Path, expected_sha256: Optional[str], logger: logging.Logger
47 |     ) -> bool:
48 |         if not path.exists():
49 |             return False
50 | 
51 |         if expected_sha256 is None:
52 |             logger.info("File exists (no checksum verification): %s", path)
53 |             return True
54 | 
55 |         if cls.check_file_sha256(path, expected_sha256):
56 |             logger.info("File exists and is valid: %s", path)
57 |             return True
58 | 
59 |         logger.warning("File exists but is invalid, redownloading: %s", path)
60 |         return False
61 | 
62 |     @classmethod
63 |     def _make_http_request(cls, url: str, logger: logging.Logger) -> requests.Response:
64 |         logger.info("Initiating download: %s", url)
65 |         try:
66 |             response = requests.get(url, stream=True, timeout=cls.REQUEST_TIMEOUT)
67 |             response.raise_for_status()  # Raises HTTPError for 4XX/5XX
68 |             return response
69 |         except requests.RequestException as e:
70 |             logger.error("Download failed: %s", url)
71 |             raise DownloadFileException(f"Failed to download {url}") from e
72 | 
73 |     @classmethod
74 |     def _save_response_with_progress(
75 |         cls, response: requests.Response, save_path: Path, logger: logging.Logger
76 |     ) -> None:
77 |         total_size = int(response.headers.get("content-length", 0))
78 |         logger.info("Download size: %.2fMB", total_size / 1024 / 1024)
79 | 
80 |         with tqdm(
81 |             total=total_size,
82 |             unit="iB",
83 |             unit_scale=True,
84 |             disable=not cls.check_is_atty(),
85 |         ) as progress_bar:
86 |             with open(save_path, "wb") as output_file:
87 |                 for chunk in response.iter_content(chunk_size=cls.BLOCK_SIZE):
88 |                     progress_bar.update(len(chunk))
89 |                     output_file.write(chunk)
90 | 
91 |         logger.info("Successfully saved to: %s", save_path)
92 | 
93 |     @staticmethod
94 |     def check_file_sha256(file_path: Union[str, Path], gt_sha256: str) -> bool:
95 |         return get_file_sha256(file_path) == gt_sha256
96 | 
97 |     @staticmethod
98 |     def check_is_atty() -> bool:
99 |         try:
100 |             is_interactive = sys.stderr.isatty()
101 |         except AttributeError:
102 |             return False
103 |         return is_interactive
104 | 
105 | 
106 | class DownloadFileException(Exception):
107 |     pass
108 | 
--------------------------------------------------------------------------------
/rapid_table/utils/load_image.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from io import BytesIO
5 | from pathlib import Path
6 | from typing import Any, Union
7 | 
8 | import cv2
9 | import numpy as np
10 | import requests
11 | from PIL import Image, UnidentifiedImageError
12 | 
13 | from .utils import is_url
14 | 
15 | root_dir = Path(__file__).resolve().parent
16 | InputType = Union[str, np.ndarray, bytes, Path, Image.Image]
17 | 
18 | 
19 | class LoadImage:
20 |     def __init__(self):
21 |         pass
22 | 
23 |     def __call__(self, img: InputType) -> np.ndarray:
24 |         if not isinstance(img, InputType.__args__):
25 |             raise LoadImageError(
26 |                 f"The img type {type(img)} is not in {InputType.__args__}"
27 |             )
28 |         origin_img_type = type(img)
29 |         img = self.load_img(img)
30 |         img = self.convert_img(img, origin_img_type)
31 |         return img
32 | 
33 |     def load_img(self, img: InputType) -> np.ndarray:
34 |         if isinstance(img, (str, Path)):
35 |             if is_url(img):
36 |                 img = Image.open(requests.get(img, stream=True, timeout=60).raw)
37 |             else:
38 |                 self.verify_exist(img)
39 |                 img = Image.open(img)
40 | 
41 |             try:
42 |                 img = self.img_to_ndarray(img)
43 |             except UnidentifiedImageError as e:
44 |                 raise LoadImageError(f"cannot identify image file {img}") from e
45 |             return img
46 | 
47 |         if isinstance(img, bytes):
48 |             img = self.img_to_ndarray(Image.open(BytesIO(img)))
49 |             return img
50 | 
51 |         if isinstance(img, np.ndarray):
52 |             return img
53 | 
54 |         if isinstance(img, Image.Image):
55 |             return self.img_to_ndarray(img)
56 | 
57 |         raise LoadImageError(f"{type(img)} is not supported!")
58 | 
59 |     def img_to_ndarray(self, img: Image.Image) -> np.ndarray:
60 |         if img.mode == "1":
61 |             img = img.convert("L")
62 |             return np.array(img)
63 |         return np.array(img)
64 | 
65 |     def convert_img(self, img: np.ndarray, origin_img_type: Any) -> np.ndarray:
66 |         if img.ndim == 2:
67 |             return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
68 | 
69 |         if img.ndim == 3:
70 |             channel = img.shape[2]
71 |             if channel == 1:
72 |                 return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
73 | 
74 |             if channel == 2:
75 |                 return self.cvt_two_to_three(img)
76 | 
77 |             if channel == 3:
78 |                 if issubclass(origin_img_type, (str, Path, bytes, Image.Image)):
79 |                     return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
80 |                 return img
81 | 
82 |             if channel == 4:
83 |                 return self.cvt_four_to_three(img)
84 | 
85 |             raise LoadImageError(
86 |                 f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
87 |             )
88 | 
89 |         raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")
90 | 
91 |     @staticmethod
92 |     def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
93 |         """gray + alpha → BGR"""
94 |         img_gray = img[..., 0]
95 |         img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)
96 | 
97 |         img_alpha
= img[..., 1] 98 | not_a = cv2.bitwise_not(img_alpha) 99 | not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) 100 | 101 | new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha) 102 | new_img = cv2.add(new_img, not_a) 103 | return new_img 104 | 105 | @staticmethod 106 | def cvt_four_to_three(img: np.ndarray) -> np.ndarray: 107 | """RGBA → BGR""" 108 | r, g, b, a = cv2.split(img) 109 | new_img = cv2.merge((b, g, r)) 110 | 111 | not_a = cv2.bitwise_not(a) 112 | not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) 113 | 114 | new_img = cv2.bitwise_and(new_img, new_img, mask=a) 115 | 116 | mean_color = np.mean(new_img) 117 | if mean_color <= 0.0: 118 | new_img = cv2.add(new_img, not_a) 119 | else: 120 | new_img = cv2.bitwise_not(new_img) 121 | return new_img 122 | 123 | @staticmethod 124 | def verify_exist(file_path: Union[str, Path]): 125 | if not Path(file_path).exists(): 126 | raise LoadImageError(f"{file_path} does not exist.") 127 | 128 | 129 | class LoadImageError(Exception): 130 | pass 131 | -------------------------------------------------------------------------------- /rapid_table/utils/logger.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: Jocker1212 3 | # @Contact: xinyijianggo@gmail.com 4 | import logging 5 | 6 | import colorlog 7 | 8 | 9 | class Logger: 10 | def __init__(self, log_level=logging.DEBUG, logger_name=None): 11 | self.logger = logging.getLogger(logger_name) 12 | self.logger.setLevel(log_level) 13 | self.logger.propagate = False 14 | 15 | formatter = colorlog.ColoredFormatter( 16 | "%(log_color)s[%(levelname)s] %(asctime)s [RapidTable] %(filename)s:%(lineno)d: %(message)s", 17 | log_colors={ 18 | "DEBUG": "cyan", 19 | "INFO": "green", 20 | "WARNING": "yellow", 21 | "ERROR": "red", 22 | "CRITICAL": "red,bg_white", 23 | }, 24 | ) 25 | 26 | if not self.logger.handlers: 27 | console_handler = logging.StreamHandler() 28 | console_handler.setFormatter(formatter) 29 | 30 | for handler in self.logger.handlers: 31 | self.logger.removeHandler(handler) 32 | 33 | console_handler.setLevel(log_level) 34 | self.logger.addHandler(console_handler) 35 | 36 | def get_log(self): 37 | return self.logger 38 | -------------------------------------------------------------------------------- /rapid_table/utils/typings.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | from dataclasses import dataclass, field 5 | from enum import Enum 6 | from pathlib import Path 7 | from typing import Dict, List, Optional, Tuple, Union 8 | 9 | import numpy as np 10 | 11 | from .utils import mkdir 12 | from .vis import VisTable 13 | 14 | 15 | class EngineType(Enum): 16 | ONNXRUNTIME = "onnxruntime" 17 | TORCH = "torch" 18 | 19 | 20 | class ModelType(Enum): 21 | PPSTRUCTURE_EN = "ppstructure_en" 22 | PPSTRUCTURE_ZH = "ppstructure_zh" 23 | SLANETPLUS = "slanet_plus" 24 | UNITABLE = "unitable" 25 | 26 | 27 | @dataclass 28 | class RapidTableInput: 29 | model_type: Optional[ModelType] = ModelType.SLANETPLUS 30 | model_dir_or_path: Union[str, Path, None, Dict[str, str]] = None 31 | 32 | engine_type: Optional[EngineType] = None 33 | engine_cfg: dict = field(default_factory=dict) 34 | 35 | use_ocr: bool = True 36 | ocr_params: dict = field(default_factory=dict) 37 | 38 | 39 | @dataclass 40 | class RapidTableOutput: 41 | imgs: List[np.ndarray] = field(default_factory=list) 42 | pred_htmls: List[str] = 
field(default_factory=list) 43 | cell_bboxes: List[np.ndarray] = field(default_factory=list) 44 | logic_points: List[np.ndarray] = field(default_factory=list) 45 | elapse: float = 0.0 46 | 47 | def vis( 48 | self, 49 | save_dir: Union[str, Path], 50 | save_name: str, 51 | indexes: Tuple[int, ...] = (0,), 52 | ) -> List[np.ndarray]: 53 | vis = VisTable() 54 | 55 | save_dir = Path(save_dir) 56 | mkdir(save_dir) 57 | 58 | results = [] 59 | for idx in indexes: 60 | save_one_dir = save_dir / str(idx) 61 | mkdir(save_one_dir) 62 | 63 | save_html_path = save_one_dir / f"{save_name}.html" 64 | save_drawed_path = save_one_dir / f"{save_name}_vis.jpg" 65 | save_logic_points_path = save_one_dir / f"{save_name}_col_row_vis.jpg" 66 | 67 | vis_img = vis( 68 | self.imgs[idx], 69 | self.pred_htmls[idx], 70 | self.cell_bboxes[idx], 71 | self.logic_points[idx], 72 | save_html_path, 73 | save_drawed_path, 74 | save_logic_points_path, 75 | ) 76 | results.append(vis_img) 77 | return results 78 | -------------------------------------------------------------------------------- /rapid_table/utils/utils.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | import hashlib 5 | import importlib 6 | from pathlib import Path 7 | from typing import List, Tuple, Union 8 | from urllib.parse import urlparse 9 | 10 | import cv2 11 | import numpy as np 12 | from omegaconf import DictConfig, OmegaConf 13 | 14 | 15 | def format_ocr_results( 16 | ocr_results: Tuple[np.ndarray, Tuple[str], Tuple[float]], img_h: int, img_w: int 17 | ) -> Tuple[np.ndarray, List[Tuple[str, float]]]: 18 | rec_res = list(zip(ocr_results[1], ocr_results[2])) 19 | 20 | bboxes = np.array(ocr_results[0]) 21 | min_coords = bboxes[..., :2].min(axis=1) 22 | max_coords = bboxes[..., :2].max(axis=1) 23 | 24 | min_coords = np.maximum(min_coords, 0) 25 | max_coords = np.minimum(max_coords, [img_w, img_h]) 26 | dt_boxes = np.hstack([min_coords, max_coords]) 27 | return dt_boxes, rec_res 28 | 29 | 30 | def save_img(save_path: Union[str, Path], img: np.ndarray): 31 | cv2.imwrite(str(save_path), img) 32 | 33 | 34 | def save_txt(save_path: Union[str, Path], txt: str): 35 | with open(save_path, "w", encoding="utf-8") as f: 36 | f.write(txt) 37 | 38 | 39 | def import_package(name, package=None): 40 | try: 41 | module = importlib.import_module(name, package=package) 42 | return module 43 | except ModuleNotFoundError: 44 | return None 45 | 46 | 47 | def mkdir(dir_path): 48 | Path(dir_path).mkdir(parents=True, exist_ok=True) 49 | 50 | 51 | def read_yaml(file_path: Union[str, Path]) -> DictConfig: 52 | return OmegaConf.load(file_path) 53 | 54 | 55 | def get_file_sha256(file_path: Union[str, Path], chunk_size: int = 65536) -> str: 56 | with open(file_path, "rb") as file: 57 | sha_signature = hashlib.sha256() 58 | while True: 59 | chunk = file.read(chunk_size) 60 | if not chunk: 61 | break 62 | sha_signature.update(chunk) 63 | 64 | return sha_signature.hexdigest() 65 | 66 | 67 | def is_url(url: str) -> bool: 68 | try: 69 | result = urlparse(url) 70 | return all([result.scheme, result.netloc]) 71 | except Exception: 72 | return False 73 | -------------------------------------------------------------------------------- /rapid_table/utils/vis.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | from pathlib import Path 5 | from typing 
import Union
6 | 
7 | import cv2
8 | import numpy as np
9 | 
10 | from .logger import Logger
11 | from .utils import save_img, save_txt
12 | 
13 | 
14 | class VisTable:
15 |     def __init__(self):
16 |         self.logger = Logger(logger_name=__name__).get_log()
17 | 
18 |     def __call__(
19 |         self,
20 |         img: np.ndarray,
21 |         pred_html: str,
22 |         cell_bboxes: np.ndarray,
23 |         logic_points: np.ndarray,
24 |         save_html_path: Union[str, Path, None] = None,
25 |         save_drawed_path: Union[str, Path, None] = None,
26 |         save_logic_path: Union[str, Path, None] = None,
27 |     ):
28 |         if pred_html and save_html_path:
29 |             html_with_border = self.insert_border_style(pred_html)
30 |             save_txt(save_html_path, html_with_border)
31 |             self.logger.info(f"Saved HTML to {save_html_path}")
32 | 
33 |         if cell_bboxes is None:
34 |             return None
35 | 
36 |         drawed_img = self.draw(img, cell_bboxes)
37 |         if save_drawed_path:
38 |             save_img(save_drawed_path, drawed_img)
39 |             self.logger.info(f"Saved table structure result to {save_drawed_path}")
40 | 
41 |         if save_logic_path and logic_points.size > 0:
42 |             self.plot_rec_box_with_logic_info(
43 |                 img, save_logic_path, logic_points, cell_bboxes
44 |             )
45 |             self.logger.info(f"Saved rec and box result to {save_logic_path}")
46 |         return drawed_img
47 | 
48 |     def insert_border_style(self, table_html_str: str) -> str:
49 |         style_res = """<meta charset="UTF-8"><style>
50 |         table {
51 |             border-collapse: collapse;
52 |             width: 100%;
53 |         }
54 |         th, td {
55 |             border: 1px solid black;
56 |             padding: 8px;
57 |             text-align: center;
58 |         }
59 |         th {
60 |             background-color: #f2f2f2;
61 |         }
62 |         </style>"""
63 | 
64 |         prefix_table, suffix_table = table_html_str.split("<body>")  # splice the style block in place of the bare <body>
65 |         html_with_border = f"{prefix_table}{style_res}{suffix_table}"
66 |         return html_with_border
67 | 
68 |     def draw(self, img: np.ndarray, cell_bboxes: np.ndarray) -> np.ndarray:
69 |         dims_bboxes = cell_bboxes.shape[1]
70 |         if dims_bboxes == 4:
71 |             return self.draw_rectangle(img, cell_bboxes)
72 | 
73 |         if dims_bboxes == 8:
74 |             return self.draw_polylines(img, cell_bboxes)
75 | 
76 |         raise ValueError("Shape of table bounding boxes must be 4 or 8.")
77 | 
78 |     @staticmethod
79 |     def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
80 |         img_copy = img.copy()
81 |         for box in boxes.astype(int):
82 |             x1, y1, x2, y2 = box
83 |             cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2)
84 |         return img_copy
85 | 
86 |     @staticmethod
87 |     def draw_polylines(img: np.ndarray, points) -> np.ndarray:
88 |         img_copy = img.copy()
89 |         for point in points.astype(int):
90 |             point = point.reshape(4, 2)
91 |             cv2.polylines(img_copy, [point], True, (255, 0, 0), 2)
92 |         return img_copy
93 | 
94 |     def plot_rec_box_with_logic_info(
95 |         self, img: np.ndarray, output_path, logic_points, cell_bboxes
96 |     ):
97 |         img = cv2.copyMakeBorder(
98 |             img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
99 |         )
100 | 
101 |         polygons = [[box[0], box[1], box[4], box[5]] for box in cell_bboxes]
102 |         for idx, polygon in enumerate(polygons):
103 |             x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3]
104 |             x0 = round(x0)
105 |             y0 = round(y0)
106 |             x1 = round(x1)
107 |             y1 = round(y1)
108 |             cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1)
109 | 
110 |             # Enlarge the font and line width for readability
111 |             font_scale = 0.9  # previously 0.5
112 |             thickness = 1  # previously 1
113 |             logic_point = logic_points[idx]
114 |             cv2.putText(
115 |                 img,
116 |                 f"row: {logic_point[0]}-{logic_point[1]}",
117 |                 (x0 + 3, y0 + 8),
118 |                 cv2.FONT_HERSHEY_PLAIN,
119 |                 font_scale,
120 |                 (0, 0, 255),
121 |                 thickness,
122 |             )
123 |             cv2.putText(
124 |                 img,
125 |                 f"col: {logic_point[2]}-{logic_point[3]}",
126 |                 (x0 + 3, y0 + 18),
127 |                 cv2.FONT_HERSHEY_PLAIN,
128 |                 font_scale,
129 |                 (0, 0, 255),
130 |                 thickness,
131 |             )
132 | 
133 |         save_img(output_path, img)
134 | 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | onnxruntime>1.17.0 2 | opencv_python>=4.5.1.48 3 | numpy>=1.21.6 4 | Pillow 5 | requests 6 | colorlog 7 | omegaconf -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | import sys 5 | from pathlib import Path 6 | from typing import List, Union 7 | 8 | import setuptools 9 | from get_pypi_latest_version import GetPyPiLatestVersion 10 | 11 | 12 | def read_txt(txt_path: Union[Path, str]) -> List[str]: 13 | with open(txt_path, "r", encoding="utf-8") as f: 14 | data = [v.rstrip("\n") for v in f] 15 | return data 16 | 17 | 18 | def get_readme(): 19 | root_dir = Path(__file__).resolve().parent 20 | readme_path = str(root_dir / "docs" / "doc_whl_rapid_table.md") 21 | with open(readme_path, "r", encoding="utf-8") as f: 22 | readme = f.read() 23 | return readme 24 | 25 | 26 | MODULE_NAME = "rapid_table" 27 | obtainer = GetPyPiLatestVersion() 28 | try: 29 | latest_version = obtainer(MODULE_NAME) 30 | except Exception: 31 | latest_version = "0.0.0" 32 | VERSION_NUM = obtainer.version_add_one(latest_version) 33 | 34 | if len(sys.argv) > 2: 35 | match_str = " ".join(sys.argv[2:]) 36 | matched_versions = obtainer.extract_version(match_str) 37 | if matched_versions: 38 | VERSION_NUM = matched_versions 39 | sys.argv = sys.argv[:2] 40 | 41 | setuptools.setup( 42 | name=MODULE_NAME, 43 | version=VERSION_NUM, 44 | platforms="Any", 45 | long_description=get_readme(), 46 | long_description_content_type="text/markdown", 47 | description="Table Recognition", 48 | author="SWHL", 49 | author_email="liekkaskono@163.com", 50 | url="https://github.com/RapidAI/RapidTable", 51 | license="Apache-2.0", 52 | include_package_data=True, 53 | install_requires=read_txt("requirements.txt"), 54 | packages=setuptools.find_packages(), 55 | package_data={"": ["*.onnx", "*.yaml"]}, 56 | keywords=["ppstructure,table,rapidocr,rapid_table"], 57 | classifiers=[ 58 | "Programming Language :: Python :: 3.6", 59 | "Programming Language :: Python :: 3.7", 60 | "Programming Language :: Python :: 3.8", 61 | "Programming Language :: Python :: 3.9", 62 | "Programming Language :: Python :: 3.10", 63 | "Programming Language :: Python :: 3.11", 64 | "Programming Language :: Python :: 3.12", 65 | "Programming Language :: Python :: 3.13", 66 | ], 67 | python_requires=">=3.6,<4", 68 | entry_points={"console_scripts": [f"{MODULE_NAME}={MODULE_NAME}.main:main"]}, 69 | extras_require={"torch": ["torch", "torchvision", "tokenizers"]}, 70 | ) 71 | -------------------------------------------------------------------------------- /tests/test_files/table.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RapidAI/RapidTable/b0ba540a2a5ac06a61bcd6506795246a5046a291/tests/test_files/table.jpg -------------------------------------------------------------------------------- /tests/test_files/table_without_txt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RapidAI/RapidTable/b0ba540a2a5ac06a61bcd6506795246a5046a291/tests/test_files/table_without_txt.jpg -------------------------------------------------------------------------------- 
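Before the test suite, a minimal, self-contained usage sketch of the public API collected above. This is a sketch only: the call pattern mirrors tests/test_main.py below, and the input image path ("table.jpg") and output directory ("outputs") are illustrative placeholders rather than files shipped with the package.

    from rapid_table import EngineType, ModelType, RapidTable, RapidTableInput

    # Default configuration: SLANet-plus structure model on ONNXRuntime.
    input_args = RapidTableInput(
        model_type=ModelType.SLANETPLUS,
        engine_type=EngineType.ONNXRUNTIME,  # pair ModelType.UNITABLE with EngineType.TORCH
    )
    table_engine = RapidTable(input_args)

    # LoadImage accepts a path, URL, bytes, np.ndarray, or PIL image ("table.jpg" is a placeholder).
    results = table_engine("table.jpg")
    print(results.pred_htmls[0])  # reconstructed table HTML for the first image
    results.vis(save_dir="outputs", save_name="table")  # writes the HTML plus overlay images per index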
/tests/test_main.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | import shlex
5 | import sys
6 | from ast import literal_eval
7 | from pathlib import Path
8 | 
9 | import pytest
10 | from rapidocr import RapidOCR
11 | 
12 | cur_dir = Path(__file__).resolve().parent
13 | root_dir = cur_dir.parent
14 | 
15 | sys.path.append(str(root_dir))
16 | from rapid_table import EngineType, ModelType, RapidTable, RapidTableInput
17 | from rapid_table.main import main
18 | 
19 | ocr_engine = RapidOCR()
20 | table_engine = RapidTable()
21 | 
22 | test_file_dir = cur_dir / "test_files"
23 | img_path = str(test_file_dir / "table.jpg")
24 | img_url = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"
25 | 
26 | 
27 | def test_only_table():
28 |     img_path = test_file_dir / "table_without_txt.jpg"
29 |     table_engine = RapidTable(RapidTableInput(use_ocr=False))
30 |     results = table_engine(img_path)
31 | 
32 |     assert len(results.pred_htmls) == 0
33 |     assert results.cell_bboxes[0].shape == (16, 8)
34 | 
35 | 
36 | def test_without_txt_table():
37 |     img_path = test_file_dir / "table_without_txt.jpg"
38 |     results = table_engine(img_path)
39 | 
40 |     assert results.pred_htmls[0] is None
41 |     assert results.cell_bboxes[0].shape == (16, 8)
42 | 
43 | 
44 | @pytest.mark.parametrize(
45 |     "command, expected_output",
46 |     [
47 |         (f"{img_path} --model_type slanet_plus", 1274),
48 |         (f"{img_url} --model_type slanet_plus", 1274),
49 |     ],
50 | )
51 | def test_main_cli(capsys, command, expected_output):
52 |     main(shlex.split(command))
53 |     output = capsys.readouterr().out.rstrip()
54 |     assert len(literal_eval(output)[0]) == expected_output
55 | 
56 | 
57 | @pytest.mark.parametrize(
58 |     "model_type,engine_type",
59 |     [
60 |         (ModelType.SLANETPLUS, EngineType.ONNXRUNTIME),
61 |         (ModelType.UNITABLE, EngineType.TORCH),
62 |     ],
63 | )
64 | def test_ocr_input(model_type, engine_type):
65 |     ori_ocr_res = ocr_engine(img_path)
66 |     ocr_results = [ori_ocr_res.boxes, ori_ocr_res.txts, ori_ocr_res.scores]
67 | 
68 |     input_args = RapidTableInput(model_type=model_type, engine_type=engine_type)
69 |     table_engine = RapidTable(input_args)
70 |     table_results = table_engine(img_path, ocr_results=[ocr_results])
71 |     assert table_results.pred_htmls[0].count("<tr>") == 16
72 | 
73 | 
74 | @pytest.mark.parametrize(
75 |     "model_type,engine_type",
76 |     [
77 |         (ModelType.SLANETPLUS, EngineType.ONNXRUNTIME),
78 |         (ModelType.UNITABLE, EngineType.TORCH),
79 |     ],
80 | )
81 | def test_input_ocr_none(model_type, engine_type):
82 |     input_args = RapidTableInput(model_type=model_type, engine_type=engine_type)
83 |     table_engine = RapidTable(input_args)
84 |     table_results = table_engine(img_path)
85 |     assert table_results.pred_htmls[0].count("<tr>") == 16
86 |     assert len(table_results.cell_bboxes) == len(table_results.logic_points)
87 | 
--------------------------------------------------------------------------------
/tests/test_table_matcher.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | import sys
3 | import warnings
4 | 
5 | import pytest
6 | 
7 | 
8 | @pytest.mark.skipif(
9 |     sys.version_info.major == 3 and sys.version_info.minor < 12,
10 |     reason="Only run on Python >= 3.12",
11 | )
12 | def test_regex_syntax_warning():
13 |     """Capture the SyntaxWarning raised by invalid escape sequences in a regex pattern."""
14 | 
15 |     with warnings.catch_warnings(record=True) as w:
16 |         warnings.simplefilter("always")
17 | 
18 |         # Compile code that contains invalid escape sequences; this triggers a SyntaxWarning
19 |         code_with_invalid_escape = """
20 | import re
21 | thead_part = '<td rowspan="2"></td>'
22 | isolate_pattern = (
23 |     '<td rowspan="(\d)+" colspan="(\d)+"></td>|'
24 |     '<td colspan="(\d)+" rowspan="(\d)+"></td>|'
25 |     '<td rowspan="(\d)+"></td>|'
26 |     '<td colspan="(\d)+"></td>'
27 | )
28 | re.finditer(isolate_pattern, thead_part)
29 | """
30 | 
31 |         # Compiling the code emits the SyntaxWarning
32 |         compile(code_with_invalid_escape, "<string>", "exec")
33 | 
34 |         # Check that a SyntaxWarning was captured
35 |         syntax_warnings = [
36 |             warn for warn in w if issubclass(warn.category, SyntaxWarning)
37 |         ]
38 |         assert (
39 |             len(syntax_warnings) > 0
40 |         ), f"No SyntaxWarning was captured: {[str(warn.message) for warn in w]}"
41 | 
42 |         # The warnings should report the invalid escape sequences
43 |         for warning in syntax_warnings:
44 |             assert "invalid escape sequence" in str(warning.message)
45 | 
46 | 
47 | @pytest.mark.skipif(
48 |     sys.version_info.major == 3 and sys.version_info.minor < 12,
49 |     reason="Only run on Python >= 3.12",
50 | )
51 | def test_correct_regex_pattern():
52 |     with warnings.catch_warnings(record=True) as w:
53 |         warnings.simplefilter("always")
54 | 
55 |         # Raw strings do not trigger a SyntaxWarning
56 |         code_with_raw_strings = """
57 | import re
58 | thead_part = '<td rowspan="2"></td>'
59 | isolate_pattern_raw = (
60 |     r'<td rowspan="(\d)+" colspan="(\d)+"></td>|'
61 |     r'<td colspan="(\d)+" rowspan="(\d)+"></td>|'
62 |     r'<td rowspan="(\d)+"></td>|'
63 |     r'<td colspan="(\d)+"></td>'
64 | )
65 | re.finditer(isolate_pattern_raw, thead_part)
66 | """
67 |         compile(code_with_raw_strings, "<string>", "exec")
68 | 
69 |         # Check that no SyntaxWarning was captured
70 |         syntax_warnings = [
71 |             warn for warn in w if issubclass(warn.category, SyntaxWarning)
72 |         ]
73 |         assert (
74 |             len(syntax_warnings) == 0
75 |         ), f"Raw-string patterns still raised a SyntaxWarning: {[str(warn.message) for warn in w]}"
76 | 
--------------------------------------------------------------------------------
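To close, a short sketch of driving the DownloadFile helper above directly. It is a sketch under stated assumptions: the URL and local save path are illustrative placeholders, not endpoints defined by this repository, and sha256 is optional (when omitted, an existing file is kept without checksum verification).

    from pathlib import Path

    from rapid_table.utils import DownloadFile, DownloadFileInput, Logger

    logger = Logger(logger_name=__name__).get_log()

    DownloadFile.run(
        DownloadFileInput(
            file_url="https://example.com/slanet-plus.onnx",  # placeholder URL
            save_path=Path("rapid_table/models/slanet-plus.onnx"),  # placeholder local path
            logger=logger,
            sha256=None,  # pass an expected digest to skip re-downloading a valid file
        )
    )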