├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.yaml
│   │   └── config.yml
│   └── workflows
│       ├── rapid_table_det.yml
│       └── rapid_table_det_paddle.yml
├── .gitignore
├── .ipynb_checkpoints
│   └── onnx_transform-checkpoint.ipynb
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── README_en.md
├── __init__.py
├── demo_onnx.py
├── demo_paddle.py
├── rapid_table_det
│   ├── __init__.py
│   ├── inference.py
│   ├── models
│   │   └── .gitkeep
│   ├── predictor.py
│   ├── requirments.txt
│   └── utils
│       ├── __init__.py
│       ├── download_model.py
│       ├── infer_engine.py
│       ├── load_image.py
│       ├── logger.py
│       ├── transform.py
│       └── visuallize.py
├── rapid_table_det_paddle
│   ├── __init__.py
│   ├── inference.py
│   ├── models
│   │   └── .gitkeep
│   ├── predictor.py
│   ├── requirments.txt
│   └── utils.py
├── readme_resource
│   ├── res_show.jpg
│   ├── res_show2.jpg
│   └── structure.png
├── requirements.txt
├── setup_rapid_table_det.py
├── setup_rapid_table_det_paddle.py
├── tests
│   ├── __init__.py
│   ├── test_files
│   │   ├── chip.jpg
│   │   ├── chip2.jpg
│   │   └── doc.png
│   ├── test_table_det.py
│   └── test_table_det_paddle.py
└── tools
    ├── __init__.py
    ├── fix_onnx.py
    └── onnx_transform.ipynb

/.github/ISSUE_TEMPLATE/bug_report.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | name: 🐞 Bug
3 | about: Bug
4 | title: 'Bug'
5 | labels: 'Bug'
6 | assignees: ''
7 | 
8 | ---
9 | 
10 | 请提供下述完整信息以便快速定位问题
11 | (Please provide the following information to quickly locate the problem)
12 | - **系统环境/System Environment**:
13 | - **使用的是哪门语言的程序/Which programming language**:
14 | - **使用当前库的版本/Version in use**:
15 | - **可复现问题的demo和文件/Demo of reproducible problems**:
16 | - **完整报错/Complete Error Message**:
17 | - **可能的解决方案/Possible solutions**:
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 |   - name: ❓ Questions
4 |     url: https://github.com/RapidAI/TableStructureRec/discussions/categories/q-a
5 |     about: Please use the community forum for help and questions regarding RapidTableDetection
6 |   - name: 💡 Feature requests and ideas
7 |     url: https://github.com/RapidAI/TableStructureRec/discussions/new?category=feature-requests
8 |     about: Please vote for and post new feature ideas in the community forum
9 |   - name: 📖 Documentation
10 |     url: https://rapidai.github.io/TableStructureRec/docs/
11 |     about: A great place to find instructions and answers about the RapidAI table projects.
--------------------------------------------------------------------------------
/.github/workflows/rapid_table_det.yml:
--------------------------------------------------------------------------------
1 | name: Push rapid_table_det_v to pypi
2 | 
3 | on:
4 |   push:
5 |     tags:
6 |       - rapid_table_det_v*
7 | 
8 | jobs:
9 |   UnitTesting:
10 |     runs-on: ubuntu-latest
11 |     steps:
12 |       - name: Pull latest code
13 |         uses: actions/checkout@v3
14 | 
15 |       - name: Set up Python 3.10
16 |         uses: actions/setup-python@v4
17 |         with:
18 |           python-version: '3.10'
19 |           architecture: 'x64'
20 | 
21 |       - name: Display Python version
22 |         run: python -c "import sys; print(sys.version)"
23 | 
24 |       - name: Unit testing
25 |         run: |
26 |           pip install -r requirements.txt
27 |           pip install pytest
28 |           pytest tests/test_table_det.py
29 | 
30 |   GenerateWHL_PushPyPi:
31 |     needs: UnitTesting
32 |     runs-on: ubuntu-latest
33 | 
34 |     steps:
35 |       - uses: actions/checkout@v3
36 | 
37 |       - name: Set up Python 3.10
38 |         uses: actions/setup-python@v4
39 |         with:
40 |           python-version: '3.10'
41 |           architecture: 'x64'
42 | 
43 |       - name: Run setup.py
44 |         run: |
45 |           pip install -r requirements.txt
46 |           python -m pip install --upgrade pip
47 |           pip install wheel get_pypi_latest_version
48 |           python setup_rapid_table_det.py bdist_wheel "${{ github.ref_name }}"
49 | 
50 |       - name: Publish distribution 📦 to PyPI
51 |         uses: pypa/gh-action-pypi-publish@v1.5.0
52 |         with:
53 |           password: ${{ secrets.PYPI_API_TOKEN }}
54 |           packages_dir: dist/
55 | 
--------------------------------------------------------------------------------
/.github/workflows/rapid_table_det_paddle.yml:
--------------------------------------------------------------------------------
1 | name: Push rapid_table_det_paddle_v to pypi
2 | 
3 | on:
4 |   push:
5 |     tags:
6 |       - rapid_table_det_paddle_v*
7 | 
8 | jobs:
9 |   UnitTesting:
10 |     runs-on: ubuntu-latest
11 |     steps:
12 |       - name: Pull latest code
13 |         uses: actions/checkout@v3
14 | 
15 |       - name: Set up Python 3.10
16 |         uses: actions/setup-python@v4
17 |         with:
18 |           python-version: '3.10'
19 |           architecture: 'x64'
20 | 
21 |       - name: Display Python version
22 |         run: python -c "import sys; print(sys.version)"
23 | 
24 |       - name: Unit testing
25 |         run: |
26 |           pip install -r requirements.txt
27 |           pip install paddlepaddle-gpu
28 |           pip install pytest
29 | 
30 |           wget https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/cls_det_paddle.zip
31 |           wget https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/obj_det_paddle.zip
32 |           wget https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/edge_det_paddle.zip
33 |           unzip cls_det_paddle.zip
34 |           unzip obj_det_paddle.zip
35 |           unzip edge_det_paddle.zip
36 |           mv *.pd* rapid_table_det_paddle/models/
37 | 
38 |           pytest tests/test_table_det_paddle.py
39 | 
40 |   GenerateWHL_PushPyPi:
41 |     needs: UnitTesting
42 |     runs-on: ubuntu-latest
43 | 
44 |     steps:
45 |       - uses: actions/checkout@v3
46 | 
47 |       - name: Set up Python 3.10
48 |         uses: actions/setup-python@v4
49 |         with:
50 |           python-version: '3.10'
51 |           architecture: 'x64'
52 | 
53 |       - name: Run setup.py
54 |         run: |
55 |           pip install -r requirements.txt
56 |           pip install paddlepaddle-gpu
57 |           python -m pip install --upgrade pip
58 |           pip install wheel get_pypi_latest_version
59 | 
60 |           wget https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/cls_det_paddle.zip
61 |           wget https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/obj_det_paddle.zip
62 |           wget 
https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/edge_det_paddle.zip 63 | unzip cls_det_paddle.zip 64 | unzip obj_det_paddle.zip 65 | unzip edge_det_paddle.zip 66 | mv *.pd* rapid_table_det_paddle/models/ 67 | 68 | python setup_rapid_table_det_paddle.py bdist_wheel "${{ github.ref_name }}" 69 | 70 | - name: Publish distribution 📦 to PyPI 71 | uses: pypa/gh-action-pypi-publish@v1.5.0 72 | with: 73 | password: ${{ secrets.PYPI_API_TOKEN }} 74 | packages_dir: dist/ 75 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /models/ 2 | /images/ 3 | /outputs/ 4 | /rapid_table_det_paddle/models/*.pd* 5 | /rapid_table_det_paddle/outputs/ 6 | /rapid_table_det/outputs/ 7 | /rapid_table_det/models/*onnx 8 | /tools/.ipynb_checkpoints/ 9 | /.ipynb_checkpoints/ 10 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/onnx_transform-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": {}, 7 | "source": [ 8 | "!pip install paddle2onnx onnxruntime onnxslim onnxruntime-tools onnx -i https://pypi.tuna.tsinghua.edu.cn/simple" 9 | ], 10 | "outputs": [] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": { 16 | "collapsed": false, 17 | "jupyter": { 18 | "is_executing": true, 19 | "outputs_hidden": false 20 | }, 21 | "pycharm": { 22 | "name": "#%%\n" 23 | } 24 | }, 25 | "source": [ 26 | "!paddle2onnx --model_dir rapid_table_det_paddle/models/obj_det --model_filename model.pdmodel --params_filename model.pdiparams --save_file rapid_table_det/models/obj_det.onnx --opset_version 16 --enable_onnx_checker True\n", 27 | "!paddle2onnx --model_dir rapid_table_det_paddle/models/db_net --model_filename model.pdmodel --params_filename model.pdiparams --save_file rapid_table_det/models/edge_det.onnx --opset_version 16 --enable_onnx_checker True\n", 28 | "!paddle2onnx --model_dir rapid_table_det_paddle/models/pplcnet --model_filename model.pdmodel --params_filename model.pdiparams --save_file rapid_table_det/models/cls_det.onnx --opset_version 16 --enable_onnx_checker True\n", 29 | "\n", 30 | "!onnxslim rapid_table_det/models/obj_det.onnx rapid_table_det/models/obj_det.onnx\n", 31 | "!onnxslim rapid_table_det/models/edge_det.onnx rapid_table_det/models/edge_det.onnx\n", 32 | "!onnxslim rapid_table_det/models/cls_det.onnx rapid_table_det/models/cls_det.onnx" 33 | ], 34 | "outputs": [] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 1, 39 | "metadata": { 40 | "ExecuteTime": { 41 | "end_time": "2024-10-15T12:12:35.265576Z", 42 | "start_time": "2024-10-15T12:12:34.281134Z" 43 | }, 44 | "collapsed": false, 45 | "jupyter": { 46 | "outputs_hidden": false 47 | }, 48 | "pycharm": { 49 | "name": "#%%\n" 50 | } 51 | }, 52 | "source": [ 53 | "from pathlib import Path\n", 54 | "import onnx\n", 55 | "from onnxruntime.quantization import quantize_dynamic, QuantType, quant_pre_process\n", 56 | "def quantize_model(root_dir_str, model_dir, pre_fix):\n", 57 | "\n", 58 | " original_model_path = f\"{pre_fix}.onnx\"\n", 59 | " # quantized_model_path = f\"{pre_fix}_quantized.onnx\"\n", 60 | " quantized_model_path = original_model_path\n", 61 | " original_model_path = f\"{root_dir_str}/{model_dir}/{original_model_path}\"\n", 62 | " quantized_model_path = 
f\"{root_dir_str}/{model_dir}/{quantized_model_path}\"\n", 63 | " quant_pre_process(original_model_path, quantized_model_path, auto_merge=True)\n", 64 | " # 进行动态量化\n", 65 | " quantize_dynamic(\n", 66 | " model_input=quantized_model_path,\n", 67 | " model_output=quantized_model_path,\n", 68 | " weight_type=QuantType.QUInt8\n", 69 | " )\n", 70 | "\n", 71 | " # 验证量化后的模型\n", 72 | " quantized_model = onnx.load(quantized_model_path)\n", 73 | " onnx.checker.check_model(quantized_model)\n", 74 | " print(\"Quantized model is valid.\")" 75 | ], 76 | "outputs": [] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 10, 81 | "metadata": { 82 | "collapsed": false, 83 | "jupyter": { 84 | "outputs_hidden": false 85 | }, 86 | "pycharm": { 87 | "name": "#%%\n" 88 | } 89 | }, 90 | "source": [ 91 | "root_dir_str = \".\"\n", 92 | "model_dir = f\"rapid_table_det/models\"\n", 93 | "quantize_model(root_dir_str, model_dir, \"obj_det\")\n", 94 | "quantize_model(root_dir_str, model_dir, \"edge_det\")\n", 95 | "# quantize_model(root_dir_str, model_dir, \"cls_det\")" 96 | ], 97 | "outputs": [] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": {}, 103 | "source": [], 104 | "outputs": [] 105 | } 106 | ], 107 | "metadata": { 108 | "kernelspec": { 109 | "display_name": "Python 3 (ipykernel)", 110 | "language": "python", 111 | "name": "python3" 112 | }, 113 | "language_info": { 114 | "codemirror_mode": { 115 | "name": "ipython", 116 | "version": 3 117 | }, 118 | "file_extension": ".py", 119 | "mimetype": "text/x-python", 120 | "name": "python", 121 | "nbconvert_exporter": "python", 122 | "pygments_lexer": "ipython3", 123 | "version": "3.10.14" 124 | } 125 | }, 126 | "nbformat": 4, 127 | "nbformat_minor": 4 128 | } 129 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://gitee.com/SWHL/autoflake 3 | rev: v2.1.1 4 | hooks: 5 | - id: autoflake 6 | args: 7 | [ 8 | "--recursive", 9 | "--in-place", 10 | "--remove-all-unused-imports", 11 | "--remove-unused-variable", 12 | "--ignore-init-module-imports", 13 | ] 14 | files: \.py$ 15 | - repo: https://gitee.com/SWHL/black 16 | rev: 23.1.0 17 | hooks: 18 | - id: black 19 | files: \.py$ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
13 | 14 | ### 最近更新 15 | 16 | - **2024.10.15** 17 | - 完成初版代码,包含目标检测,语义分割,角点方向识别三个模块 18 | - **2024.11.2** 19 | - 补充新训练yolo11的目标检测模型和边缘检测模型 20 | - 增加自动下载,轻量化包体积 21 | - 补充onnx-gpu推理支持,给出benchmark测试结果 22 | - 补充在线示例使用 23 | 24 | ### 简介 25 | 26 | 💡✨ 强大且高效的表格检测,支持论文、期刊、杂志、发票、收据、签到单等各种表格。 27 | 28 | 🚀 支持来源于paddle和yolo的版本,默认模型组合单图 CPU 推理仅需 1.2 秒,onnx-GPU(V100) 最小组合仅需 0.4 秒,paddle-gpu版0.2s 29 | 🛠️ 支持三个模块自由组合,独立训练调优,提供 ONNX 转换脚本和微调训练方案。 30 | 31 | 🌟 whl 包轻松集成使用,为下游 OCR、表格识别和数据采集提供强力支撑。 32 | 33 | 📚参考项目 [百度表格检测大赛第2名方案](https://aistudio.baidu.com/projectdetail/5398861?searchKeyword=%E8%A1%A8%E6%A0%BC%E6%A3%80%E6%B5%8B%E5%A4%A7%E8%B5%9B&searchTab=ALL) 34 | 的实现方案,补充大量真实场景数据再训练 35 |  36 | 👇🏻训练数据集在致谢, 作者天天上班摸鱼搞开源,希望大家点个⭐️支持一下 37 | 38 | ### 使用建议 39 | 40 | 📚 文档场景: 无透视旋转,只使用目标检测\ 41 | 📷 拍照场景小角度旋转(-90~90): 默认左上角,不使用角点方向识别\ 42 | 🔍 使用在线体验找到适合你场景的模型组合 43 | 44 | ### 在线体验 45 | [modelscope](https://www.modelscope.cn/studios/jockerK/RapidTableDetDemo) [huggingface](https://huggingface.co/spaces/Joker1212/RapidTableDetection) 46 | ### 效果展示 47 | 48 |  49 | 50 | ### 安装 51 | 52 | 🪜模型会自动下载,也可以自己去仓库下载 [modescope模型仓](https://www.modelscope.cn/models/jockerK/TableExtractor) 53 | 54 | ``` python {linenos=table} 55 | # 建议使用清华源安装 https://pypi.tuna.tsinghua.edu.cn/simple 56 | pip install rapid-table-det 57 | ``` 58 | 59 | #### 参数说明 60 | 61 | 默认值 62 | use_cuda: False : 启用gpu加速推理 \ 63 | obj_model_type="yolo_obj_det", \ 64 | edge_model_type= "yolo_edge_det", \ 65 | cls_model_type= "paddle_cls_det" 66 | 67 | 由于onnx使用gpu加速效果有限,还是建议直接使用yolox或安装paddle来执行模型会快很多(有需要我再补充整体流程) 68 | paddle的s模型由于量化导致反而速度降低和精度降低,但是模型大小减少很多 69 | 70 | | `model_type` | 任务类型 | 训练来源 | 大小 | 单表格耗时(v100-16G,cuda12,cudnn9,ubuntu) | 71 | |:---------------------|:-------|:-------------------------------------|:-------|:-------------------------------------| 72 | | **yolo_obj_det** | 表格目标检测 | `yolo11-l` | `100m` | `cpu:570ms, gpu:400ms` | 73 | | `paddle_obj_det` | 表格目标检测 | `paddle yoloe-plus-x` | `380m` | `cpu:1000ms, gpu:300ms` | 74 | | `paddle_obj_det_s` | 表格目标检测 | `paddle yoloe-plus-x + quantization` | `95m` | `cpu:1200ms, gpu:1000ms` | 75 | | **yolo_edge_det** | 语义分割 | `yolo11-l-segment` | `108m` | `cpu:570ms, gpu:200ms` | 76 | | `yolo_edge_det_s` | 语义分割 | `yolo11-s-segment` | `11m` | `cpu:260ms, gpu:200ms` | 77 | | `paddle_edge_det` | 语义分割 | `paddle-dbnet` | `99m` | `cpu:1200ms, gpu:120ms` | 78 | | `paddle_edge_det_s` | 语义分割 | `paddle-dbnet + quantization` | `25m` | `cpu:860ms, gpu:760ms` | 79 | | **paddle_cls_det** | 方向分类 | `paddle pplcnet` | `6.5m` | `cpu:70ms, gpu:60ms` | 80 | 81 | 82 | 执行参数 83 | det_accuracy=0.7, 84 | use_obj_det=True, 85 | use_edge_det=True, 86 | use_cls_det=True, 87 | 88 | ### 快速使用 89 | 90 | ``` python {linenos=table} 91 | from rapid_table_det.inference import TableDetector 92 | 93 | img_path = f"tests/test_files/chip.jpg" 94 | table_det = TableDetector() 95 | 96 | result, elapse = table_det(img_path) 97 | obj_det_elapse, edge_elapse, rotate_det_elapse = elapse 98 | print( 99 | f"obj_det_elapse:{obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}" 100 | ) 101 | # 输出可视化 102 | # import os 103 | # import cv2 104 | # from rapid_table_det.utils.visuallize import img_loader, visuallize, extract_table_img 105 | # 106 | # img = img_loader(img_path) 107 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 108 | # file_name_with_ext = os.path.basename(img_path) 109 | # file_name, file_ext = os.path.splitext(file_name_with_ext) 110 | # out_dir = "rapid_table_det/outputs" 111 | # if not os.path.exists(out_dir): 112 | # 
os.makedirs(out_dir) 113 | # extract_img = img.copy() 114 | # for i, res in enumerate(result): 115 | # box = res["box"] 116 | # lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"] 117 | # # 带识别框和左上角方向位置 118 | # img = visuallize(img, box, lt, rt, rb, lb) 119 | # # 透视变换提取表格图片 120 | # wrapped_img = extract_table_img(extract_img.copy(), lt, rt, rb, lb) 121 | # cv2.imwrite(f"{out_dir}/{file_name}-extract-{i}.jpg", wrapped_img) 122 | # cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img) 123 | 124 | ``` 125 | ### paddle版本使用 126 | 必须下载模型,指定模型位置! 127 | ``` python {linenos=table} 128 | # 建议使用清华源安装 https://pypi.tuna.tsinghua.edu.cn/simple 129 | pip install rapid-table-det-paddle (默认安装gpu版本,可以自行覆盖安装cpu版本paddlepaddle) 130 | ``` 131 | ```python 132 | from rapid_table_det_paddle.inference import TableDetector 133 | 134 | img_path = f"tests/test_files/chip.jpg" 135 | 136 | table_det = TableDetector( 137 | obj_model_path="models/obj_det_paddle", 138 | edge_model_path="models/edge_det_paddle", 139 | cls_model_path="models/cls_det_paddle", 140 | use_obj_det=True, 141 | use_edge_det=True, 142 | use_cls_det=True, 143 | ) 144 | result, elapse = table_det(img_path) 145 | obj_det_elapse, edge_elapse, rotate_det_elapse = elapse 146 | print( 147 | f"obj_det_elapse:{obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}" 148 | ) 149 | # 一张图片中可能有多个表格 150 | # img = img_loader(img_path) 151 | # file_name_with_ext = os.path.basename(img_path) 152 | # file_name, file_ext = os.path.splitext(file_name_with_ext) 153 | # out_dir = "rapid_table_det_paddle/outputs" 154 | # if not os.path.exists(out_dir): 155 | # os.makedirs(out_dir) 156 | # extract_img = img.copy() 157 | # for i, res in enumerate(result): 158 | # box = res["box"] 159 | # lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"] 160 | # # 带识别框和左上角方向位置 161 | # img = visuallize(img, box, lt, rt, rb, lb) 162 | # # 透视变换提取表格图片 163 | # wrapped_img = extract_table_img(extract_img.copy(), lt, rt, rb, lb) 164 | # cv2.imwrite(f"{out_dir}/{file_name}-extract-{i}.jpg", wrapped_img) 165 | # cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img) 166 | 167 | ``` 168 | 169 | ## FAQ (Frequently Asked Questions) 170 | 171 | 1. **问:如何微调模型适应特定场景?** 172 | - 答:直接参考这个项目,有非常详细的可视化操作步骤,数据集也在里面,可以得到paddle的推理模型 [百度表格检测大赛](https://aistudio.baidu.com/projectdetail/5398861?searchKeyword=%E8%A1%A8%E6%A0%BC%E6%A3%80%E6%B5%8B%E5%A4%A7%E8%B5%9B&searchTab=ALL), 173 | - yolo11的训练使用官方脚本足够简单,按官方指导转换为coco格式训练即可 174 | 2. **问:如何导出onnx** 175 | - 答:paddle模型需要在本项目tools下,有onnx_transform.ipynb文件 176 | yolo11的话,直接参照官方的方式一行搞定 177 | 3. 
**问:图片有扭曲可以修正吗?** 178 | - 答:本项目只解决旋转和透视场景的表格提取,对于扭曲的场景,需要先进行扭曲修正 179 | 180 | ### 致谢 181 | 182 | [百度表格检测大赛第2名方案](https://aistudio.baidu.com/projectdetail/5398861?searchKeyword=%E8%A1%A8%E6%A0%BC%E6%A3%80%E6%B5%8B%E5%A4%A7%E8%B5%9B&searchTab=ALL) \ 183 | [WTW 自然场景表格数据集](https://tianchi.aliyun.com/dataset/108587) \ 184 | [FinTabNet PDF文档表格数据集](https://developer.ibm.com/exchanges/data/all/fintabnet/) \ 185 | [TableBank 表格数据集](https://doc-analysis.github.io/tablebank-page/) \ 186 | [TableGeneration 表格自动生成工具](https://github.com/WenmuZhou/TableGeneration) 187 | 188 | ### 贡献指南 189 | 190 | 欢迎提交请求。对于重大更改,请先打开issue讨论您想要改变的内容。 191 | 192 | 有其他的好建议和集成场景,作者也会积极响应支持 193 | 194 | ### 开源许可证 195 | 196 | 该项目采用[Apache 2.0](https://github.com/RapidAI/TableStructureRec/blob/c41bbd23898cb27a957ed962b0ffee3c74dfeff1/LICENSE) 197 | 开源许可证。 198 | 199 | -------------------------------------------------------------------------------- /README_en.md: -------------------------------------------------------------------------------- 1 | 11 | 12 | ### Recent Updates 13 | 14 | - **2024.10.15** 15 | - Completed the initial version of the code, including three modules: object detection, semantic segmentation, and corner direction recognition. 16 | - **2024.11.2** 17 | - Added new YOLOv11 object detection models and edge detection models. 18 | - Increased automatic downloading and reduced package size. 19 | - Added ONNX-GPU inference support and provided benchmark test results. 20 | - Added online example usage. 21 | 22 | ### Introduction 23 | 24 | 💡✨ RapidTableDetection is a powerful and efficient table detection system that supports various types of tables, including those in papers, journals, magazines, invoices, receipts, and sign-in sheets. 25 | 26 | 🚀 It supports versions derived from PaddlePaddle and YOLO, with the default model combination requiring only 1.2 seconds for single-image CPU inference, and 0.4 seconds for the smallest ONNX-GPU (V100) combination, or 0.2 seconds for the PaddlePaddle-GPU version. 27 | 28 | 🛠️ It supports free combination and independent training optimization of three modules, providing ONNX conversion scripts and fine-tuning training solutions. 29 | 30 | 🌟 The whl package is easy to integrate and use, providing strong support for downstream OCR, table recognition, and data collection. 31 | 32 | Refer to the implementation solution of the [2nd place in the Baidu Table Detection Competition](https://aistudio.baidu.com/projectdetail/5398861?searchKeyword=%E8%A1%A8%E6%A0%BC%E6%A3%80%E6%B5%8B%E5%A4%A7%E8%B5%9B&searchTab=ALL), and retrain with a large amount of real-world scenario data. 33 |  \ 34 | The training dataset is acknowledged. The author works on open-source projects during spare time, please support by giving a star. 35 | 36 | 37 | ### Usage Recommendations 38 | 39 | - Document scenarios: No perspective rotation, use only object detection. 40 | - Photography scenarios with small angle rotation (-90~90): Default top-left corner, do not use corner direction recognition. 41 | - Use the online experience to find the suitable model combination for your scenario. 42 | 43 | ### Online Experience 44 | [modelscope](https://www.modelscope.cn/studios/jockerK/RapidTableDetDemo) [huggingface](https://huggingface.co/spaces/Joker1212/RapidTableDetection) 45 | 46 | ### Effect Demonstration 47 | 48 |  49 | 50 | ### Installation 51 | 52 | Models will be automatically downloaded, or you can download them from the repository [modelscope model warehouse](https://www.modelscope.cn/models/jockerK/TableExtractor). 
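Either way, `TableDetector` also accepts explicit local model paths and then skips the automatic download. A minimal sketch, assuming you have already placed the three ONNX files in a local `models/` directory (the file names below are placeholders, not fixed API values):

```python
from rapid_table_det.inference import TableDetector

# Passing explicit *_model_path arguments bypasses the auto-download step;
# adjust the (assumed) paths to wherever you saved the models.
table_det = TableDetector(
    obj_model_path="models/yolo_obj_det.onnx",
    edge_model_path="models/yolo_edge_det.onnx",
    cls_model_path="models/paddle_cls_det.onnx",
)
```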
53 | 54 | ``` python {linenos=table} 55 | pip install rapid-table-det 56 | ``` 57 | 58 | #### Parameter Explanation 59 | 60 | Default values: 61 | - `use_cuda: False`: Enable GPU acceleration for inference. 62 | - `obj_model_type="yolo_obj_det"`: Object detection model type. 63 | - `edge_model_type="yolo_edge_det"`: Edge detection model type. 64 | - `cls_model_type="paddle_cls_det"`: Corner direction classification model type. 65 | 66 | 67 | Since ONNX has limited GPU acceleration, it is still recommended to directly use YOLOX or install PaddlePaddle for faster model execution (I can provide the entire process if needed). 68 | The PaddlePaddle S model, due to quantization, actually slows down and reduces accuracy, but significantly reduces model size. 69 | 70 | 71 | | `model_type` | Task Type | Training Source | Size | Single Table Inference Time (V100-16G, cuda12, cudnn9, ubuntu) | 72 | |:---------------------|:---------|:-------------------------------------|:-------|:-------------------------------------| 73 | | **yolo_obj_det** | Table Object Detection | `yolo11-l` | `100m` | `cpu:570ms, gpu:400ms` | 74 | | `paddle_obj_det` | Table Object Detection | `paddle yoloe-plus-x` | `380m` | `cpu:1000ms, gpu:300ms` | 75 | | `paddle_obj_det_s` | Table Object Detection | `paddle yoloe-plus-x + quantization` | `95m` | `cpu:1200ms, gpu:1000ms` | 76 | | **yolo_edge_det** | Semantic Segmentation | `yolo11-l-segment` | `108m` | `cpu:570ms, gpu:200ms` | 77 | | `yolo_edge_det_s` | Semantic Segmentation | `yolo11-s-segment` | `11m` | `cpu:260ms, gpu:200ms` | 78 | | `paddle_edge_det` | Semantic Segmentation | `paddle-dbnet` | `99m` | `cpu:1200ms, gpu:120ms` | 79 | | `paddle_edge_det_s` | Semantic Segmentation | `paddle-dbnet + quantization` | `25m` | `cpu:860ms, gpu:760ms` | 80 | | **paddle_cls_det** | Direction Classification | `paddle pplcnet` | `6.5m` | `cpu:70ms, gpu:60ms` | 81 | 82 | Execution parameters: 83 | - `det_accuracy=0.7` 84 | - `use_obj_det=True` 85 | - `use_edge_det=True` 86 | - `use_cls_det=True` 87 | 88 | ### Quick Start 89 | 90 | ``` python {linenos=table} 91 | from rapid_table_det.inference import TableDetector 92 | 93 | img_path = f"tests/test_files/chip.jpg" 94 | table_det = TableDetector() 95 | 96 | result, elapse = table_det(img_path) 97 | obj_det_elapse, edge_elapse, rotate_det_elapse = elapse 98 | print( 99 | f"obj_det_elapse:{obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}" 100 | ) 101 | # Output visualization 102 | # import os 103 | # import cv2 104 | # from rapid_table_det.utils.visuallize import img_loader, visuallize, extract_table_img 105 | # 106 | # img = img_loader(img_path) 107 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 108 | # file_name_with_ext = os.path.basename(img_path) 109 | # file_name, file_ext = os.path.splitext(file_name_with_ext) 110 | # out_dir = "rapid_table_det/outputs" 111 | # if not os.path.exists(out_dir): 112 | # os.makedirs(out_dir) 113 | # extract_img = img.copy() 114 | # for i, res in enumerate(result): 115 | # box = res["box"] 116 | # lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"] 117 | # # With detection box and top-left corner position 118 | # img = visuallize(img, box, lt, rt, rb, lb) 119 | # # Perspective transformation to extract table image 120 | # wrapped_img = extract_table_img(extract_img.copy(), lt, rt, rb, lb) 121 | # cv2.imwrite(f"{out_dir}/{file_name}-extract-{i}.jpg", wrapped_img) 122 | # cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img) 123 | 124 | ``` 125 | ### Using 
PaddlePaddle Version 126 | You must download the models and specify their locations! 127 | ``` python {linenos=table} 128 | #(default installation is GPU version, you can override with CPU version paddlepaddle) 129 | pip install rapid-table-det-paddle 130 | ``` 131 | ```python 132 | from rapid_table_det_paddle.inference import TableDetector 133 | 134 | img_path = f"tests/test_files/chip.jpg" 135 | 136 | table_det = TableDetector( 137 | obj_model_path="models/obj_det_paddle", 138 | edge_model_path="models/edge_det_paddle", 139 | cls_model_path="models/cls_det_paddle", 140 | use_obj_det=True, 141 | use_edge_det=True, 142 | use_cls_det=True, 143 | ) 144 | result, elapse = table_det(img_path) 145 | obj_det_elapse, edge_elapse, rotate_det_elapse = elapse 146 | print( 147 | f"obj_det_elapse:{obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}" 148 | ) 149 | # more than one table in one image 150 | # img = img_loader(img_path) 151 | # file_name_with_ext = os.path.basename(img_path) 152 | # file_name, file_ext = os.path.splitext(file_name_with_ext) 153 | # out_dir = "rapid_table_det_paddle/outputs" 154 | # if not os.path.exists(out_dir): 155 | # os.makedirs(out_dir) 156 | # extract_img = img.copy() 157 | # for i, res in enumerate(result): 158 | # box = res["box"] 159 | # lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"] 160 | # # With detection box and top-left corner position 161 | # img = visuallize(img, box, lt, rt, rb, lb) 162 | # # Perspective transformation to extract table image 163 | # wrapped_img = extract_table_img(extract_img.copy(), lt, rt, rb, lb) 164 | # cv2.imwrite(f"{out_dir}/{file_name}-extract-{i}.jpg", wrapped_img) 165 | # cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img) 166 | 167 | ``` 168 | 169 | ## FAQ (Frequently Asked Questions) 170 | 171 | 1. **Q: How to fine-tune the model for specific scenarios?** 172 | - A: Refer to this project, which provides detailed visualization steps and datasets. You can get the PaddlePaddle inference model from [Baidu Table Detection Competition](https://aistudio.baidu.com/projectdetail/5398861?searchKeyword=%E8%A1%A8%E6%A0%BC%E6%A3%80%E6%B5%8B%E5%A4%A7%E8%B5%9B&searchTab=ALL). For YOLOv11, use the official script, which is simple enough, and convert the data to COCO format for training as per the official guidelines. 173 | 2. **Q: How to export ONNX?** 174 | - A: For PaddlePaddle models, use the `onnx_transform.ipynb` file in the `tools` directory of this project. For YOLOv11, follow the official method, which can be done in one line. 175 | 3. **Q: Can distorted images be corrected?** 176 | - A: This project only handles rotation and perspective scenarios for table extraction. For distorted images, you need to correct the distortion first. 177 | 178 | ### Acknowledgments 179 | 180 | - [2nd Place Solution in Baidu Table Detection Competition](https://aistudio.baidu.com/projectdetail/5398861?searchKeyword=%E8%A1%A8%E6%A0%BC%E6%A3%80%E6%B5%8B%E5%A4%A7%E8%B5%9B&searchTab=ALL) 181 | - [WTW Natural Scene Table Dataset](https://tianchi.aliyun.com/dataset/108587) 182 | - [FinTabNet PDF Document Table Dataset](https://developer.ibm.com/exchanges/data/all/fintabnet/) 183 | - [TableBank Table Dataset](https://doc-analysis.github.io/tablebank-page/) 184 | - [TableGeneration Table Auto-Generation Tool](https://github.com/WenmuZhou/TableGeneration) 185 | 186 | ### Contribution Guidelines 187 | 188 | Pull requests are welcome. For major changes, please open an issue to discuss what you would like to change. 
189 | 
190 | If you have other good suggestions and integration scenarios, the author will actively respond and support them.
191 | 
192 | ### Open Source License
193 | 
194 | This project is licensed under the [Apache 2.0](https://github.com/RapidAI/TableStructureRec/blob/c41bbd23898cb27a957ed962b0ffee3c74dfeff1/LICENSE) open source license.
195 | 
196 | 
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidTableDetection/ce3e4887df65b6629a8ac7b6ce9983783fc48347/__init__.py
--------------------------------------------------------------------------------
/demo_onnx.py:
--------------------------------------------------------------------------------
1 | from rapid_table_det.inference import TableDetector
2 | 
3 | img_path = f"images/0c35d6430193babb29c6a94711742531-1_rot2_noise.jpg"
4 | table_det = TableDetector(
5 |     edge_model_type="yolo_edge_det", obj_model_type="yolo_obj_det"
6 | )
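# Note (illustrative, not part of the original demo): the call below also
# accepts tuning flags, mirroring TableDetector.__call__ and its defaults:
#   result, elapse = table_det(
#       img_path, det_accuracy=0.7,
#       use_obj_det=True, use_edge_det=True, use_cls_det=True,
#   )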
7 | 
8 | result, elapse = table_det(img_path)
9 | obj_det_elapse, edge_elapse, rotate_det_elapse = elapse
10 | print(
11 |     f"obj_det_elapse:{obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}"
12 | )
13 | # Output visualization
14 | import os
15 | import cv2
16 | from rapid_table_det.utils.visuallize import img_loader, visuallize, extract_table_img
17 | 
18 | img = img_loader(img_path)
19 | file_name_with_ext = os.path.basename(img_path)
20 | file_name, file_ext = os.path.splitext(file_name_with_ext)
21 | out_dir = "rapid_table_det/outputs"
22 | if not os.path.exists(out_dir):
23 |     os.makedirs(out_dir)
24 | extract_img = img.copy()
25 | for i, res in enumerate(result):
26 |     box = res["box"]
27 |     lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"]
28 |     # Draw the detection box and mark the top-left corner direction
29 |     img = visuallize(img, box, lt, rt, rb, lb)
30 |     # Extract the table image via perspective transform
31 |     wrapped_img = extract_table_img(extract_img.copy(), lt, rt, rb, lb)
32 |     cv2.imwrite(f"{out_dir}/{file_name}-extract-{i}.jpg", wrapped_img)
33 | cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img)
34 | 
--------------------------------------------------------------------------------
/demo_paddle.py:
--------------------------------------------------------------------------------
1 | from rapid_table_det_paddle.inference import TableDetector
2 | 
3 | img_path = f"tests/test_files/chip.jpg"
4 | 
5 | table_det = TableDetector(
6 |     obj_model_path="rapid_table_det_paddle/models/obj_det_paddle",
7 |     edge_model_path="rapid_table_det_paddle/models/edge_det_paddle",
8 |     cls_model_path="rapid_table_det_paddle/models/cls_det_paddle",
9 |     use_obj_det=True,
10 |     use_edge_det=True,
11 |     use_cls_det=True,
12 | )
13 | result, elapse = table_det(img_path)
14 | obj_det_elapse, edge_elapse, rotate_det_elapse = elapse
15 | print(
16 |     f"obj_det_elapse:{obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}"
17 | )
18 | # A single image may contain multiple tables
19 | # img = img_loader(img_path)
20 | # file_name_with_ext = os.path.basename(img_path)
21 | # file_name, file_ext = os.path.splitext(file_name_with_ext)
22 | # out_dir = "rapid_table_det_paddle/outputs"
23 | # if not os.path.exists(out_dir):
24 | #     os.makedirs(out_dir)
25 | # extract_img = img.copy()
26 | # for i, res in enumerate(result):
27 | #     box = res["box"]
28 | #     lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"]
29 | #     # Draw the detection box and mark the top-left corner direction
30 | #     img = visuallize(img, box, lt, rt, rb, lb)
31 | #     # Extract the table image via perspective transform
32 | #     wrapped_img = extract_table_img(extract_img.copy(), lt, rt, rb, lb)
33 | #     cv2.imwrite(f"{out_dir}/{file_name}-extract-{i}.jpg", wrapped_img)
34 | # cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img)
35 | 
--------------------------------------------------------------------------------
/rapid_table_det/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: Jocker1212
3 | # @Contact: xinyijianggo@gmail.com
4 | from .inference import TableDetector
5 | 
--------------------------------------------------------------------------------
/rapid_table_det/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from typing import Union
4 | 
5 | import cv2
6 | import numpy as np
7 | 
8 | from .predictor import DbNet, PaddleYoloEDet, PPLCNet, YoloSeg, YoloDet
9 | from .utils.download_model import DownloadModel
10 | 
11 | from .utils.logger import get_logger
12 | from .utils.load_image import LoadImage
13 | 
14 | root_dir = Path(__file__).resolve().parent
15 | model_dir = os.path.join(root_dir, "models")
16 | 
17 | ROOT_DIR = Path(__file__).resolve().parent
18 | logger = get_logger("rapid_table_det")
19 | 
20 | ROOT_URL = "https://www.modelscope.cn/models/jockerK/TableExtractor/resolve/master/rapid_table_det/models/"
21 | KEY_TO_MODEL_URL = {
22 |     "yolo_obj_det": f"{ROOT_URL}/yolo_obj_det.onnx",
23 |     "yolo_edge_det": f"{ROOT_URL}/yolo_edge_det.onnx",
24 |     "yolo_edge_det_s": f"{ROOT_URL}/yolo_edge_det_s.onnx",
25 |     "paddle_obj_det": f"{ROOT_URL}/paddle_obj_det.onnx",
26 |     "paddle_obj_det_s": f"{ROOT_URL}/paddle_obj_det_s.onnx",
27 |     "paddle_edge_det": f"{ROOT_URL}/paddle_edge_det.onnx",
28 |     "paddle_edge_det_s": f"{ROOT_URL}/paddle_edge_det_s.onnx",
29 |     "paddle_cls_det": f"{ROOT_URL}/paddle_cls_det.onnx",
30 | }
31 | 
32 | 
33 | class TableDetector:
34 |     def __init__(
35 |         self,
36 |         use_cuda=False,
37 |         use_dml=False,
38 |         obj_model_path=None,
39 |         edge_model_path=None,
40 |         cls_model_path=None,
41 |         obj_model_type="yolo_obj_det",
42 |         edge_model_type="yolo_edge_det",
43 |         cls_model_type="paddle_cls_det",
44 |     ):
45 |         self.img_loader = LoadImage()
46 |         obj_det_config = {
47 |             "model_path": self.get_model_path(obj_model_type, obj_model_path),
48 |             "use_cuda": use_cuda,
49 |             "use_dml": use_dml,
50 |         }
51 |         edge_det_config = {
52 |             "model_path": self.get_model_path(edge_model_type, edge_model_path),
53 |             "use_cuda": use_cuda,
54 |             "use_dml": use_dml,
55 |         }
56 |         cls_det_config = {
57 |             "model_path": self.get_model_path(cls_model_type, cls_model_path),
58 |             "use_cuda": use_cuda,
59 |             "use_dml": use_dml,
60 |         }
61 |         if "yolo" in obj_model_type:
62 |             self.obj_detector = YoloDet(obj_det_config)
63 |         else:
64 |             self.obj_detector = PaddleYoloEDet(obj_det_config)
65 |         if "yolo" in edge_model_type:
66 |             self.dbnet = YoloSeg(edge_det_config)
67 |         else:
68 |             self.dbnet = DbNet(edge_det_config)
69 |         # Only a Paddle-based classifier model is currently provided (see
70 |         # KEY_TO_MODEL_URL), so every cls_model_type resolves to the same
71 |         # PPLCNet predictor.
72 |         self.pplcnet = PPLCNet(cls_det_config)
73 | 
74 |     def __call__(
75 |         self,
76 |         img,
77 |         det_accuracy=0.7,
78 |         use_obj_det=True,
79 |         use_edge_det=True,
80 |         use_cls_det=True,
81 |     ):
82 |         img = self.img_loader(img)
83 |         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
84 |         img_mask = img.copy()
85 |         h, w = img.shape[:-1]
86 |         obj_det_res, pred_label = self.init_default_output(h, w)
87 |         result = []
88 |         obj_det_elapse, edge_elapse, rotate_det_elapse = 0, 0, 0
89 |         if use_obj_det:
90 |             
obj_det_res, obj_det_elapse = self.obj_detector(img, score=det_accuracy) 91 | for i in range(len(obj_det_res)): 92 | det_res = obj_det_res[i] 93 | score, box = det_res 94 | xmin, ymin, xmax, ymax = box 95 | edge_box = box.reshape([-1, 2]) 96 | lb, lt, rb, rt = self.get_box_points(box) 97 | if use_edge_det: 98 | xmin_edge, ymin_edge, xmax_edge, ymax_edge = self.pad_box_points( 99 | h, w, xmax, xmin, ymax, ymin, 10 100 | ) 101 | crop_img = img_mask[ymin_edge:ymax_edge, xmin_edge:xmax_edge, :] 102 | edge_box, lt, lb, rt, rb, tmp_edge_elapse = self.dbnet(crop_img) 103 | edge_elapse += tmp_edge_elapse 104 | if edge_box is None: 105 | continue 106 | lb, lt, rb, rt = self.adjust_edge_points_axis( 107 | edge_box, lb, lt, rb, rt, xmin_edge, ymin_edge 108 | ) 109 | if use_cls_det: 110 | xmin_cls, ymin_cls, xmax_cls, ymax_cls = self.pad_box_points( 111 | h, w, xmax, xmin, ymax, ymin, 5 112 | ) 113 | cls_img = img_mask[ymin_cls:ymax_cls, xmin_cls:xmax_cls, :] 114 | # 增加先验信息 115 | self.add_pre_info_for_cls(cls_img, edge_box, xmin_cls, ymin_cls) 116 | pred_label, tmp_rotate_det_elapse = self.pplcnet(cls_img) 117 | rotate_det_elapse += tmp_rotate_det_elapse 118 | lb1, lt1, rb1, rt1 = self.get_real_rotated_points( 119 | lb, lt, pred_label, rb, rt 120 | ) 121 | result.append( 122 | { 123 | "box": [int(xmin), int(ymin), int(xmax), int(ymax)], 124 | "lb": [int(lb1[0]), int(lb1[1])], 125 | "lt": [int(lt1[0]), int(lt1[1])], 126 | "rt": [int(rt1[0]), int(rt1[1])], 127 | "rb": [int(rb1[0]), int(rb1[1])], 128 | } 129 | ) 130 | elapse = [obj_det_elapse, edge_elapse, rotate_det_elapse] 131 | return result, elapse 132 | 133 | def init_default_output(self, h, w): 134 | img_box = np.array([0, 0, w, h]) 135 | # 初始化默认值 136 | obj_det_res, edge_box, pred_label = ( 137 | [[1.0, img_box]], 138 | img_box.reshape([-1, 2]), 139 | 0, 140 | ) 141 | return obj_det_res, pred_label 142 | 143 | def add_pre_info_for_cls(self, cls_img, edge_box, xmin_cls, ymin_cls): 144 | """ 145 | Args: 146 | cls_img: 147 | edge_box: 148 | xmin_cls: 149 | ymin_cls: 150 | 151 | Returns: 带边缘划线的图片,给方向分类提供先验信息 152 | 153 | """ 154 | cls_box = edge_box.copy() 155 | cls_box[:, 0] = cls_box[:, 0] - xmin_cls 156 | cls_box[:, 1] = cls_box[:, 1] - ymin_cls 157 | # 画框增加先验信息,辅助方向label识别 158 | cv2.polylines( 159 | cls_img, 160 | [np.array(cls_box).astype(np.int32).reshape((-1, 1, 2))], 161 | True, 162 | color=(255, 0, 255), 163 | thickness=5, 164 | ) 165 | 166 | def adjust_edge_points_axis(self, edge_box, lb, lt, rb, rt, xmin_edge, ymin_edge): 167 | edge_box[:, 0] += xmin_edge 168 | edge_box[:, 1] += ymin_edge 169 | lt, lb, rt, rb = ( 170 | lt + [xmin_edge, ymin_edge], 171 | lb + [xmin_edge, ymin_edge], 172 | rt + [xmin_edge, ymin_edge], 173 | rb + [xmin_edge, ymin_edge], 174 | ) 175 | return lb, lt, rb, rt 176 | 177 | def get_box_points(self, img_box): 178 | x1, y1, x2, y2 = img_box 179 | lt = np.array([x1, y1]) # 左上角 180 | rt = np.array([x2, y1]) # 右上角 181 | rb = np.array([x2, y2]) # 右下角 182 | lb = np.array([x1, y2]) # 左下角 183 | return lb, lt, rb, rt 184 | 185 | def get_real_rotated_points(self, lb, lt, pred_label, rb, rt): 186 | if pred_label == 0: 187 | lt1 = lt 188 | rt1 = rt 189 | rb1 = rb 190 | lb1 = lb 191 | elif pred_label == 1: 192 | lt1 = rt 193 | rt1 = rb 194 | rb1 = lb 195 | lb1 = lt 196 | elif pred_label == 2: 197 | lt1 = rb 198 | rt1 = lb 199 | rb1 = lt 200 | lb1 = rt 201 | elif pred_label == 3: 202 | lt1 = lb 203 | rt1 = lt 204 | rb1 = rt 205 | lb1 = rb 206 | else: 207 | lt1 = lt 208 | rt1 = rt 209 | rb1 = rb 210 | lb1 = lb 211 | return lb1, lt1, 
rb1, rt1 212 | 213 | def pad_box_points(self, h, w, xmax, xmin, ymax, ymin, pad): 214 | ymin_edge = max(ymin - pad, 0) 215 | xmin_edge = max(xmin - pad, 0) 216 | ymax_edge = min(ymax + pad, h) 217 | xmax_edge = min(xmax + pad, w) 218 | return xmin_edge, ymin_edge, xmax_edge, ymax_edge 219 | 220 | @staticmethod 221 | def get_model_path(model_type: str, model_path: Union[str, Path, None]) -> str: 222 | if model_path is not None: 223 | return model_path 224 | 225 | model_url = KEY_TO_MODEL_URL.get(model_type, None) 226 | if model_url: 227 | model_path = DownloadModel.download(model_url) 228 | return model_path 229 | 230 | logger.info( 231 | "model url is None, using the default download model %s", model_path 232 | ) 233 | return model_path 234 | -------------------------------------------------------------------------------- /rapid_table_det/models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RapidAI/RapidTableDetection/ce3e4887df65b6629a8ac7b6ce9983783fc48347/rapid_table_det/models/.gitkeep -------------------------------------------------------------------------------- /rapid_table_det/predictor.py: -------------------------------------------------------------------------------- 1 | import time 2 | from pathlib import Path 3 | 4 | import cv2 5 | import numpy as np 6 | from typing import Dict, Any 7 | 8 | from .utils.infer_engine import OrtInferSession 9 | from .utils.load_image import LoadImage 10 | from .utils.transform import ( 11 | custom_NMSBoxes, 12 | resize, 13 | pad, 14 | ResizePad, 15 | sigmoid, 16 | get_max_adjacent_bbox, 17 | ) 18 | 19 | MODEL_STAGES_PATTERN = { 20 | "PPLCNet": ["blocks2", "blocks3", "blocks4", "blocks5", "blocks6"] 21 | } 22 | root_dir = Path(__file__).resolve().parent 23 | root_dir_str = str(root_dir) 24 | 25 | 26 | class PaddleYoloEDet: 27 | model_key = "obj_det" 28 | 29 | def __init__(self, config: Dict[str, Any]): 30 | self.model = OrtInferSession(config) 31 | self.img_loader = LoadImage() 32 | self.resize_shape = [928, 928] 33 | 34 | def __call__(self, img, **kwargs): 35 | start = time.time() 36 | score = kwargs.get("score", 0.4) 37 | img = self.img_loader(img) 38 | ori_h, ori_w = img.shape[:-1] 39 | img, im_shape, factor = self.img_preprocess(img, self.resize_shape) 40 | pre = self.model([img, factor]) 41 | result = self.img_postprocess(ori_h, ori_w, pre, score) 42 | return result, time.time() - start 43 | 44 | def img_postprocess(self, ori_h, ori_w, pre, score): 45 | result = [] 46 | for item in pre[0]: 47 | cls, value, xmin, ymin, xmax, ymax = list(item) 48 | if value < score: 49 | continue 50 | cls, xmin, ymin, xmax, ymax = [ 51 | int(x) for x in [cls, xmin, ymin, xmax, ymax] 52 | ] 53 | xmin = max(xmin, 0) 54 | ymin = max(ymin, 0) 55 | xmax = min(xmax, ori_w) 56 | ymax = min(ymax, ori_h) 57 | result.append([value, np.array([xmin, ymin, xmax, ymax])]) 58 | return result 59 | 60 | def img_preprocess(self, img, resize_shape=[928, 928]): 61 | im_info = { 62 | "scale_factor": np.array([1.0, 1.0], dtype=np.float32), 63 | "im_shape": np.array(img.shape[:2], dtype=np.float32), 64 | } 65 | im, im_info = resize(img, im_info, resize_shape, False) 66 | im, im_info = pad(im, im_info, resize_shape) 67 | im = im / 255.0 68 | im = im.transpose((2, 0, 1)).copy() 69 | im = im[None, :] 70 | factor = im_info["scale_factor"].reshape((1, 2)) 71 | im_shape = im_info["im_shape"].reshape((1, 2)) 72 | return im, im_shape, factor 73 | 74 | 75 | class YoloDet: 76 | def __init__(self, config: Dict[str, 
Any]): 77 | self.model = OrtInferSession(config) 78 | self.img_loader = LoadImage() 79 | self.resize_shape = [928, 928] 80 | 81 | def __call__(self, img, **kwargs): 82 | start = time.time() 83 | score = kwargs.get("score", 0.4) 84 | img = self.img_loader(img) 85 | ori_h, ori_w = img.shape[:-1] 86 | img, new_w, new_h, left, top = self.img_preprocess(img, self.resize_shape) 87 | pre = self.model([img]) 88 | result = self.img_postprocess( 89 | pre, ori_w / new_w, ori_h / new_h, left, top, score 90 | ) 91 | return result, time.time() - start 92 | 93 | def img_preprocess(self, img, resize_shape=[928, 928]): 94 | im, new_w, new_h, left, top = ResizePad(img, resize_shape[0]) 95 | im = im / 255.0 96 | im = im.transpose((2, 0, 1)).copy() 97 | im = im[None, :].astype("float32") 98 | return im, new_w, new_h, left, top 99 | 100 | def img_postprocess(self, predict_maps, x_factor, y_factor, left, top, score): 101 | result = [] 102 | # 转置和压缩输出以匹配预期的形状 103 | outputs = np.transpose(np.squeeze(predict_maps[0])) 104 | # 获取输出数组的行数 105 | rows = outputs.shape[0] 106 | # 用于存储检测的边界框、得分和类别ID的列表 107 | boxes = [] 108 | scores = [] 109 | # 遍历输出数组的每一行 110 | for i in range(rows): 111 | # 找到类别得分中的最大得分 112 | max_score = outputs[i][4] 113 | # 如果最大得分高于置信度阈值 114 | if max_score >= score: 115 | # 从当前行提取边界框坐标 116 | x, y, w, h = outputs[i][0], outputs[i][1], outputs[i][2], outputs[i][3] 117 | # 计算边界框的缩放坐标 118 | xmin = max(int((x - w / 2 - left) * x_factor), 0) 119 | ymin = max(int((y - h / 2 - top) * y_factor), 0) 120 | xmax = xmin + int(w * x_factor) 121 | ymax = ymin + int(h * y_factor) 122 | # 将类别ID、得分和框坐标添加到各自的列表中 123 | boxes.append([xmin, ymin, xmax, ymax]) 124 | scores.append(max_score) 125 | # 应用非最大抑制过滤重叠的边界框 126 | indices = custom_NMSBoxes(boxes, scores) 127 | for i in indices: 128 | result.append([scores[i], np.array(boxes[i])]) 129 | return result 130 | 131 | 132 | class DbNet: 133 | model_key = "edge_det" 134 | 135 | def __init__(self, config: Dict[str, Any]): 136 | self.model = OrtInferSession(config) 137 | self.img_loader = LoadImage() 138 | self.resize_shape = [800, 800] 139 | 140 | def __call__(self, img, **kwargs): 141 | start = time.time() 142 | img = self.img_loader(img) 143 | destHeight, destWidth = img.shape[:-1] 144 | img, resize_h, resize_w, left, top = self.img_preprocess(img, self.resize_shape) 145 | # with paddle.no_grad(): 146 | predict_maps = self.model([img]) 147 | pred = self.img_postprocess(predict_maps) 148 | if pred is None: 149 | return None, None, None, None, None, time.time() - start 150 | segmentation = pred > 0.8 151 | mask = np.array(segmentation).astype(np.uint8) 152 | # 找到最佳边缘box shape(4, 2) 153 | box = get_max_adjacent_bbox(mask) 154 | # todo 注意还有crop的偏移 155 | if box is not None: 156 | # 根据缩放调整坐标适配输入的img大小 157 | adjusted_box = self.adjust_coordinates( 158 | box, left, top, resize_w, resize_h, destWidth, destHeight 159 | ) 160 | # 排序并裁剪负值 161 | lt, lb, rt, rb = self.sort_and_clip_coordinates(adjusted_box) 162 | return box, lt, lb, rt, rb, time.time() - start 163 | else: 164 | return None, None, None, None, None, time.time() - start 165 | 166 | def img_postprocess(self, predict_maps): 167 | pred = np.squeeze(predict_maps[0]) 168 | return pred 169 | 170 | def adjust_coordinates( 171 | self, box, left, top, resize_w, resize_h, destWidth, destHeight 172 | ): 173 | """ 174 | 调整边界框坐标,确保它们在合理范围内。 175 | 176 | 参数: 177 | box (numpy.ndarray): 原始边界框坐标 (shape: (4, 2)) 178 | left (int): 左侧偏移量 179 | top (int): 顶部偏移量 180 | resize_w (int): 缩放宽度 181 | resize_h (int): 缩放高度 182 | destWidth (int): 目标宽度 183 | 
destHeight (int): 目标高度 184 | xmin_a (int): 目标左上角横坐标 185 | ymin_a (int): 目标左上角纵坐标 186 | 187 | 返回: 188 | numpy.ndarray: 调整后的边界框坐标 189 | """ 190 | # 调整横坐标 191 | box[:, 0] = np.clip( 192 | (np.round(box[:, 0] - left) / resize_w * destWidth), 0, destWidth 193 | ) 194 | 195 | # 调整纵坐标 196 | box[:, 1] = np.clip( 197 | (np.round(box[:, 1] - top) / resize_h * destHeight), 0, destHeight 198 | ) 199 | return box 200 | 201 | def sort_and_clip_coordinates(self, box): 202 | """ 203 | 对边界框坐标进行排序并裁剪负值。 204 | 205 | 参数: 206 | box (numpy.ndarray): 边界框坐标 (shape: (4, 2)) 207 | 208 | 返回: 209 | tuple: 左上角、左下角、右上角、右下角坐标 210 | """ 211 | # 按横坐标排序 212 | x = box[:, 0] 213 | l_idx = x.argsort() 214 | l_box = np.array([box[l_idx[0]], box[l_idx[1]]]) 215 | r_box = np.array([box[l_idx[2]], box[l_idx[3]]]) 216 | 217 | # 左侧坐标按纵坐标排序 218 | l_idx_1 = np.array(l_box[:, 1]).argsort() 219 | lt = l_box[l_idx_1[0]] 220 | lb = l_box[l_idx_1[1]] 221 | 222 | # 右侧坐标按纵坐标排序 223 | r_idx_1 = np.array(r_box[:, 1]).argsort() 224 | rt = r_box[r_idx_1[0]] 225 | rb = r_box[r_idx_1[1]] 226 | 227 | # 裁剪负值 228 | lt[lt < 0] = 0 229 | lb[lb < 0] = 0 230 | rt[rt < 0] = 0 231 | rb[rb < 0] = 0 232 | 233 | return lt, lb, rt, rb 234 | 235 | def img_preprocess(self, img, resize_shape=[800, 800]): 236 | im, new_w, new_h, left, top = ResizePad(img, resize_shape[0]) 237 | im = im / 255.0 238 | im = im.transpose((2, 0, 1)).copy() 239 | im = im[None, :].astype("float32") 240 | return im, new_h, new_w, left, top 241 | 242 | 243 | class YoloSeg(DbNet): 244 | model_key = "edge_det" 245 | 246 | def img_postprocess(self, predict_maps): 247 | box_output = predict_maps[0] 248 | mask_output = predict_maps[1] 249 | predictions = np.squeeze(box_output).T 250 | # Filter out object confidence scores below threshold 251 | scores = predictions[:, 4] 252 | # 获取得分最高的索引 253 | highest_score_index = scores.argmax() 254 | # 获取得分最高的预测结果 255 | highest_score_prediction = predictions[highest_score_index] 256 | x, y, w, h = highest_score_prediction[:4] 257 | highest_score = highest_score_prediction[4] 258 | if highest_score < 0.7: 259 | return None 260 | mask_predictions = highest_score_prediction[5:] 261 | mask_predictions = np.expand_dims(mask_predictions, axis=0) 262 | mask_output = np.squeeze(mask_output) 263 | # Calculate the mask maps for each box 264 | num_mask, mask_height, mask_width = mask_output.shape # CHW 265 | masks = sigmoid(mask_predictions @ mask_output.reshape((num_mask, -1))) 266 | masks = masks.reshape((-1, mask_height, mask_width)) 267 | # 提取第一个通道 268 | mask = masks[0] 269 | 270 | # 计算缩小后的区域边界 271 | small_w = 200 272 | small_h = 200 273 | small_x_min = max(0, int((x - w / 2) * small_w / 800)) 274 | small_x_max = min(small_w, int((x + w / 2) * small_w / 800)) 275 | small_y_min = max(0, int((y - h / 2) * small_h / 800)) 276 | small_y_max = min(small_h, int((y + h / 2) * small_h / 800)) 277 | 278 | # 创建一个全零的掩码 279 | filtered_mask = np.zeros((small_h, small_w), dtype=np.float32) 280 | 281 | # 将区域内的值复制到过滤后的掩码中 282 | filtered_mask[small_y_min:small_y_max, small_x_min:small_x_max] = mask[ 283 | small_y_min:small_y_max, small_x_min:small_x_max 284 | ] 285 | 286 | # 使用 OpenCV 进行放大,保持边缘清晰 287 | resized_mask = cv2.resize( 288 | filtered_mask, (800, 800), interpolation=cv2.INTER_CUBIC 289 | ) 290 | return resized_mask 291 | 292 | 293 | class PPLCNet: 294 | model_key = "cls_det" 295 | 296 | def __init__(self, config: Dict[str, Any]): 297 | self.model = OrtInferSession(config) 298 | self.img_loader = LoadImage() 299 | self.resize_shape = [624, 624] 300 | 301 | def 
__call__(self, img, **kwargs): 302 | start = time.time() 303 | img = self.img_loader(img) 304 | img = self.img_preprocess(img, self.resize_shape) 305 | label = self.model([img])[0] 306 | label = label[None, :] 307 | mini_batch_result = np.argsort(label) 308 | mini_batch_result = mini_batch_result[0][-1] # 把这些列标拿出来 309 | mini_batch_result = mini_batch_result.flatten() # 拉平了,只吐出一个 array 310 | mini_batch_result = mini_batch_result[::-1] # 逆序 311 | pred_label = mini_batch_result[0] 312 | return pred_label, time.time() - start 313 | 314 | def img_preprocess(self, img, resize_shape=[624, 624]): 315 | im, new_w, new_h, left, top = ResizePad(img, resize_shape[0]) 316 | im = np.array(im).transpose((2, 0, 1)) / 255.0 317 | return im[None, :].astype("float32") 318 | -------------------------------------------------------------------------------- /rapid_table_det/requirments.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | Pillow 3 | opencv-python 4 | onnxruntime 5 | requests 6 | -------------------------------------------------------------------------------- /rapid_table_det/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RapidAI/RapidTableDetection/ce3e4887df65b6629a8ac7b6ce9983783fc48347/rapid_table_det/utils/__init__.py -------------------------------------------------------------------------------- /rapid_table_det/utils/download_model.py: -------------------------------------------------------------------------------- 1 | import io 2 | from pathlib import Path 3 | from typing import Optional, Union 4 | 5 | import requests 6 | from tqdm import tqdm 7 | 8 | from .logger import get_logger 9 | 10 | logger = get_logger("DownloadModel") 11 | CUR_DIR = Path(__file__).resolve() 12 | PROJECT_DIR = CUR_DIR.parent.parent 13 | 14 | 15 | class DownloadModel: 16 | cur_dir = PROJECT_DIR 17 | 18 | @classmethod 19 | def download(cls, model_full_url: Union[str, Path]) -> str: 20 | save_dir = cls.cur_dir / "models" 21 | save_dir.mkdir(parents=True, exist_ok=True) 22 | 23 | model_name = Path(model_full_url).name 24 | save_file_path = save_dir / model_name 25 | if save_file_path.exists(): 26 | logger.debug("%s already exists", save_file_path) 27 | return str(save_file_path) 28 | 29 | try: 30 | logger.info("Download %s to %s", model_full_url, save_dir) 31 | file = cls.download_as_bytes_with_progress(model_full_url, model_name) 32 | cls.save_file(save_file_path, file) 33 | except Exception as exc: 34 | raise DownloadModelError from exc 35 | return str(save_file_path) 36 | 37 | @staticmethod 38 | def download_as_bytes_with_progress( 39 | url: Union[str, Path], name: Optional[str] = None 40 | ) -> bytes: 41 | resp = requests.get(str(url), stream=True, allow_redirects=True, timeout=180) 42 | total = int(resp.headers.get("content-length", 0)) 43 | bio = io.BytesIO() 44 | with tqdm( 45 | desc=name, total=total, unit="b", unit_scale=True, unit_divisor=1024 46 | ) as pbar: 47 | for chunk in resp.iter_content(chunk_size=65536): 48 | pbar.update(len(chunk)) 49 | bio.write(chunk) 50 | return bio.getvalue() 51 | 52 | @staticmethod 53 | def save_file(save_path: Union[str, Path], file: bytes): 54 | with open(save_path, "wb") as f: 55 | f.write(file) 56 | 57 | 58 | class DownloadModelError(Exception): 59 | pass 60 | -------------------------------------------------------------------------------- /rapid_table_det/utils/infer_engine.py: 
-------------------------------------------------------------------------------- 1 | from .logger import get_logger 2 | import os 3 | import platform 4 | import traceback 5 | from enum import Enum 6 | from pathlib import Path 7 | from typing import Any, Dict, List, Tuple, Union 8 | 9 | import numpy as np 10 | from onnxruntime import ( 11 | GraphOptimizationLevel, 12 | InferenceSession, 13 | SessionOptions, 14 | get_available_providers, 15 | get_device, 16 | ) 17 | 18 | 19 | class EP(Enum): 20 | CPU_EP = "CPUExecutionProvider" 21 | CUDA_EP = "CUDAExecutionProvider" 22 | DIRECTML_EP = "DmlExecutionProvider" 23 | 24 | 25 | class OrtInferSession: 26 | def __init__(self, config: Dict[str, Any]): 27 | self.logger = get_logger("OrtInferSession") 28 | 29 | model_path = config.get("model_path", None) 30 | self._verify_model(model_path) 31 | 32 | self.cfg_use_cuda = config.get("use_cuda", None) 33 | self.cfg_use_dml = config.get("use_dml", None) 34 | 35 | self.had_providers: List[str] = get_available_providers() 36 | EP_list = self._get_ep_list() 37 | 38 | sess_opt = self._init_sess_opts(config) 39 | self.session = InferenceSession( 40 | model_path, 41 | sess_options=sess_opt, 42 | providers=EP_list, 43 | ) 44 | self._verify_providers() 45 | 46 | @staticmethod 47 | def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions: 48 | sess_opt = SessionOptions() 49 | sess_opt.log_severity_level = 4 50 | sess_opt.enable_cpu_mem_arena = False 51 | sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL 52 | 53 | cpu_nums = os.cpu_count() 54 | intra_op_num_threads = config.get("intra_op_num_threads", -1) 55 | if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums: 56 | sess_opt.intra_op_num_threads = intra_op_num_threads 57 | 58 | inter_op_num_threads = config.get("inter_op_num_threads", -1) 59 | if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums: 60 | sess_opt.inter_op_num_threads = inter_op_num_threads 61 | 62 | return sess_opt 63 | 64 | def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]: 65 | cpu_provider_opts = { 66 | "arena_extend_strategy": "kSameAsRequested", 67 | } 68 | EP_list = [(EP.CPU_EP.value, cpu_provider_opts)] 69 | 70 | cuda_provider_opts = { 71 | "device_id": 0, 72 | "arena_extend_strategy": "kNextPowerOfTwo", 73 | "cudnn_conv_algo_search": "EXHAUSTIVE", 74 | "do_copy_in_default_stream": True, 75 | } 76 | self.use_cuda = self._check_cuda() 77 | if self.use_cuda: 78 | EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts)) 79 | 80 | self.use_directml = self._check_dml() 81 | if self.use_directml: 82 | self.logger.info( 83 | "Windows 10 or above detected, try to use DirectML as primary provider" 84 | ) 85 | directml_options = ( 86 | cuda_provider_opts if self.use_cuda else cpu_provider_opts 87 | ) 88 | EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options)) 89 | return EP_list 90 | 91 | def _check_cuda(self) -> bool: 92 | if not self.cfg_use_cuda: 93 | return False 94 | 95 | cur_device = get_device() 96 | if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers: 97 | return True 98 | 99 | self.logger.warning( 100 | "%s is not in available providers (%s). 
Use %s inference by default.", 101 | EP.CUDA_EP.value, 102 | self.had_providers, 103 | self.had_providers[0], 104 | ) 105 | self.logger.info("!!!Recommend to use rapidocr_paddle for inference on GPU.") 106 | self.logger.info( 107 | "(For reference only) If you want to use GPU acceleration, you must do:" 108 | ) 109 | self.logger.info( 110 | "First, uninstall all onnxruntime pakcages in current environment." 111 | ) 112 | self.logger.info( 113 | "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`." 114 | ) 115 | self.logger.info( 116 | "\tNote the onnxruntime-gpu version must match your cuda and cudnn version." 117 | ) 118 | self.logger.info( 119 | "\tYou can refer this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html" 120 | ) 121 | self.logger.info( 122 | "Third, ensure %s is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']", 123 | EP.CUDA_EP.value, 124 | ) 125 | return False 126 | 127 | def _check_dml(self) -> bool: 128 | if not self.cfg_use_dml: 129 | return False 130 | 131 | cur_os = platform.system() 132 | if cur_os != "Windows": 133 | self.logger.warning( 134 | "DirectML is only supported in Windows OS. The current OS is %s. Use %s inference by default.", 135 | cur_os, 136 | self.had_providers[0], 137 | ) 138 | return False 139 | 140 | cur_window_version = int(platform.release().split(".")[0]) 141 | if cur_window_version < 10: 142 | self.logger.warning( 143 | "DirectML is only supported in Windows 10 and above OS. The current Windows version is %s. Use %s inference by default.", 144 | cur_window_version, 145 | self.had_providers[0], 146 | ) 147 | return False 148 | 149 | if EP.DIRECTML_EP.value in self.had_providers: 150 | return True 151 | 152 | self.logger.warning( 153 | "%s is not in available providers (%s). Use %s inference by default.", 154 | EP.DIRECTML_EP.value, 155 | self.had_providers, 156 | self.had_providers[0], 157 | ) 158 | self.logger.info("If you want to use DirectML acceleration, you must do:") 159 | self.logger.info( 160 | "First, uninstall all onnxruntime pakcages in current environment." 161 | ) 162 | self.logger.info( 163 | "Second, install onnxruntime-directml by `pip install onnxruntime-directml`" 164 | ) 165 | self.logger.info( 166 | "Third, ensure %s is in available providers list. e.g. 
['DmlExecutionProvider', 'CPUExecutionProvider']", 167 | EP.DIRECTML_EP.value, 168 | ) 169 | return False 170 | 171 | def _verify_providers(self): 172 | session_providers = self.session.get_providers() 173 | first_provider = session_providers[0] 174 | 175 | if self.use_cuda and first_provider != EP.CUDA_EP.value: 176 | self.logger.warning( 177 | "%s is not avaiable for current env, the inference part is automatically shifted to be executed under %s.", 178 | EP.CUDA_EP.value, 179 | first_provider, 180 | ) 181 | 182 | if self.use_directml and first_provider != EP.DIRECTML_EP.value: 183 | self.logger.warning( 184 | "%s is not available for current env, the inference part is automatically shifted to be executed under %s.", 185 | EP.DIRECTML_EP.value, 186 | first_provider, 187 | ) 188 | 189 | def __call__(self, input_content: List[np.ndarray]) -> np.ndarray: 190 | input_dict = dict(zip(self.get_input_names(), input_content)) 191 | try: 192 | return self.session.run(None, input_dict) 193 | except Exception as e: 194 | error_info = traceback.format_exc() 195 | raise ONNXRuntimeError(error_info) from e 196 | 197 | def get_input_names(self) -> List[str]: 198 | return [v.name for v in self.session.get_inputs()] 199 | 200 | def get_output_names(self) -> List[str]: 201 | return [v.name for v in self.session.get_outputs()] 202 | 203 | def get_character_list(self, key: str = "character") -> List[str]: 204 | meta_dict = self.session.get_modelmeta().custom_metadata_map 205 | return meta_dict[key].splitlines() 206 | 207 | def have_key(self, key: str = "character") -> bool: 208 | meta_dict = self.session.get_modelmeta().custom_metadata_map 209 | if key in meta_dict.keys(): 210 | return True 211 | return False 212 | 213 | @staticmethod 214 | def _verify_model(model_path: Union[str, Path, None]): 215 | if model_path is None: 216 | raise ValueError("model_path is None!") 217 | 218 | model_path = Path(model_path) 219 | if not model_path.exists(): 220 | raise FileNotFoundError(f"{model_path} does not exists.") 221 | 222 | if not model_path.is_file(): 223 | raise FileExistsError(f"{model_path} is not a file.") 224 | 225 | 226 | class ONNXRuntimeError(Exception): 227 | pass 228 | -------------------------------------------------------------------------------- /rapid_table_det/utils/load_image.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: SWHL 3 | # @Contact: liekkaskono@163.com 4 | from io import BytesIO 5 | from pathlib import Path 6 | from typing import Any, Union 7 | 8 | import cv2 9 | import numpy as np 10 | from PIL import Image, UnidentifiedImageError 11 | 12 | root_dir = Path(__file__).resolve().parent 13 | InputType = Union[str, np.ndarray, bytes, Path, Image.Image] 14 | 15 | 16 | class LoadImage: 17 | def __init__(self): 18 | pass 19 | 20 | def __call__(self, img: InputType) -> np.ndarray: 21 | if not isinstance(img, InputType.__args__): 22 | raise LoadImageError( 23 | f"The img type {type(img)} does not in {InputType.__args__}" 24 | ) 25 | 26 | origin_img_type = type(img) 27 | img = self.load_img(img) 28 | img = self.convert_img(img, origin_img_type) 29 | return img 30 | 31 | def load_img(self, img: InputType) -> np.ndarray: 32 | if isinstance(img, (str, Path)): 33 | self.verify_exist(img) 34 | try: 35 | img = self.img_to_ndarray(Image.open(img)) 36 | except UnidentifiedImageError as e: 37 | raise LoadImageError(f"cannot identify image file {img}") from e 38 | return img 39 | 40 | if isinstance(img, bytes): 41 | img = 
self.img_to_ndarray(Image.open(BytesIO(img))) 42 | return img 43 | 44 | if isinstance(img, np.ndarray): 45 | return img 46 | 47 | if isinstance(img, Image.Image): 48 | return self.img_to_ndarray(img) 49 | 50 | raise LoadImageError(f"{type(img)} is not supported!") 51 | 52 | def img_to_ndarray(self, img: Image.Image) -> np.ndarray: 53 | if img.mode == "1": 54 | img = img.convert("L") 55 | return np.array(img) 56 | return np.array(img) 57 | 58 | def convert_img(self, img: np.ndarray, origin_img_type: Any) -> np.ndarray: 59 | if img.ndim == 2: 60 | return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 61 | 62 | if img.ndim == 3: 63 | channel = img.shape[2] 64 | if channel == 1: 65 | return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 66 | 67 | if channel == 2: 68 | return self.cvt_two_to_three(img) 69 | 70 | if channel == 3: 71 | if issubclass(origin_img_type, (str, Path, bytes, Image.Image)): 72 | return cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 73 | return img 74 | 75 | if channel == 4: 76 | return self.cvt_four_to_three(img) 77 | 78 | raise LoadImageError( 79 | f"The channel({channel}) of the img is not in [1, 2, 3, 4]" 80 | ) 81 | 82 | raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]") 83 | 84 | @staticmethod 85 | def cvt_two_to_three(img: np.ndarray) -> np.ndarray: 86 | """gray + alpha → BGR""" 87 | img_gray = img[..., 0] 88 | img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR) 89 | 90 | img_alpha = img[..., 1] 91 | not_a = cv2.bitwise_not(img_alpha) 92 | not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) 93 | 94 | new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha) 95 | new_img = cv2.add(new_img, not_a) 96 | return new_img 97 | 98 | @staticmethod 99 | def cvt_four_to_three(img: np.ndarray) -> np.ndarray: 100 | """RGBA → BGR""" 101 | r, g, b, a = cv2.split(img) 102 | new_img = cv2.merge((b, g, r)) 103 | 104 | not_a = cv2.bitwise_not(a) 105 | not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) 106 | 107 | new_img = cv2.bitwise_and(new_img, new_img, mask=a) 108 | 109 | mean_color = np.mean(new_img) 110 | if mean_color <= 0.0: 111 | new_img = cv2.add(new_img, not_a) 112 | else: 113 | new_img = cv2.bitwise_not(new_img) 114 | return new_img 115 | 116 | @staticmethod 117 | def verify_exist(file_path: Union[str, Path]): 118 | if not Path(file_path).exists(): 119 | raise LoadImageError(f"{file_path} does not exist.") 120 | 121 | 122 | class LoadImageError(Exception): 123 | pass 124 | -------------------------------------------------------------------------------- /rapid_table_det/utils/logger.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: Jocker1212 3 | # @Contact: xinyijianggo@gmail.com 4 | import logging 5 | from functools import lru_cache 6 | 7 | 8 | @lru_cache(maxsize=32) 9 | def get_logger(name: str) -> logging.Logger: 10 | logger = logging.getLogger(name) 11 | logger.setLevel(logging.DEBUG) 12 | 13 | fmt = "%(asctime)s - %(name)s - %(levelname)s: %(message)s" 14 | format_str = logging.Formatter(fmt) 15 | 16 | sh = logging.StreamHandler() 17 | sh.setLevel(logging.DEBUG) 18 | 19 | logger.addHandler(sh) 20 | sh.setFormatter(format_str) 21 | return logger 22 | -------------------------------------------------------------------------------- /rapid_table_det/utils/transform.py: -------------------------------------------------------------------------------- 1 | import math 2 | import itertools 3 | import cv2 4 | import numpy as np 5 | 6 | 7 | def generate_scale(im, resize_shape, keep_ratio): 8 | """ 9 | Args: 
10 | im (np.ndarray): image (np.ndarray) 11 | Returns: 12 | im_scale_x: the resize ratio of X 13 | im_scale_y: the resize ratio of Y 14 | """ 15 | target_size = (resize_shape[0], resize_shape[1]) 16 | # target_size = (800, 1333) 17 | origin_shape = im.shape[:2] 18 | 19 | if keep_ratio: 20 | im_size_min = np.min(origin_shape) 21 | im_size_max = np.max(origin_shape) 22 | target_size_min = np.min(target_size) 23 | target_size_max = np.max(target_size) 24 | im_scale = float(target_size_min) / float(im_size_min) 25 | if np.round(im_scale * im_size_max) > target_size_max: 26 | im_scale = float(target_size_max) / float(im_size_max) 27 | im_scale_x = im_scale 28 | im_scale_y = im_scale 29 | else: 30 | resize_h, resize_w = target_size 31 | im_scale_y = resize_h / float(origin_shape[0]) 32 | im_scale_x = resize_w / float(origin_shape[1]) 33 | return im_scale_y, im_scale_x 34 | 35 | 36 | def resize(im, im_info, resize_shape, keep_ratio, interp=2): 37 | im_scale_y, im_scale_x = generate_scale(im, resize_shape, keep_ratio) 38 | im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=interp) 39 | im_info["im_shape"] = np.array(im.shape[:2]).astype("float32") 40 | im_info["scale_factor"] = np.array([im_scale_y, im_scale_x]).astype("float32") 41 | 42 | return im, im_info 43 | 44 | 45 | def pad(im, im_info, resize_shape): 46 | im_h, im_w = im.shape[:2] 47 | fill_value = [114.0, 114.0, 114.0] 48 | h, w = resize_shape[0], resize_shape[1] 49 | if h == im_h and w == im_w: 50 | im = im.astype(np.float32) 51 | return im, im_info 52 | 53 | canvas = np.ones((h, w, 3), dtype=np.float32) 54 | canvas *= np.array(fill_value, dtype=np.float32) 55 | canvas[0:im_h, 0:im_w, :] = im.astype(np.float32) 56 | im = canvas 57 | return im, im_info 58 | 59 | 60 | def ResizePad(img, target_size): 61 | h, w = img.shape[:2] 62 | m = max(h, w) 63 | ratio = target_size / m 64 | new_w, new_h = int(ratio * w), int(ratio * h) 65 | img = cv2.resize(img, (new_w, new_h), cv2.INTER_LINEAR) 66 | top = (target_size - new_h) // 2 67 | bottom = (target_size - new_h) - top 68 | left = (target_size - new_w) // 2 69 | right = (target_size - new_w) - left 70 | img1 = cv2.copyMakeBorder( 71 | img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114) 72 | ) 73 | return img1, new_w, new_h, left, top 74 | 75 | 76 | def get_mini_boxes(contour): 77 | bounding_box = cv2.minAreaRect(contour) 78 | points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) 79 | 80 | index_1, index_2, index_3, index_4 = 0, 1, 2, 3 81 | if points[1][1] > points[0][1]: 82 | index_1 = 0 83 | index_4 = 1 84 | else: 85 | index_1 = 1 86 | index_4 = 0 87 | if points[3][1] > points[2][1]: 88 | index_2 = 2 89 | index_3 = 3 90 | else: 91 | index_2 = 3 92 | index_3 = 2 93 | 94 | box = [points[index_1], points[index_2], points[index_3], points[index_4]] 95 | return box, min(bounding_box[1]) 96 | 97 | 98 | def minboundquad(hull): 99 | len_hull = len(hull) 100 | xy = np.array(hull).reshape([-1, 2]) 101 | idx = np.arange(0, len_hull) 102 | idx_roll = np.roll(idx, -1, axis=0) 103 | edges = np.array([idx, idx_roll]).reshape([2, -1]) 104 | edges = np.transpose(edges, [1, 0]) 105 | edgeangles1 = [] 106 | for i in range(len_hull): 107 | y = xy[edges[i, 1], 1] - xy[edges[i, 0], 1] 108 | x = xy[edges[i, 1], 0] - xy[edges[i, 0], 0] 109 | angle = math.atan2(y, x) 110 | if angle < 0: 111 | angle = angle + 2 * math.pi 112 | edgeangles1.append([angle, i]) 113 | edgeangles1_idx = sorted(list(edgeangles1), key=lambda x: x[0]) 114 | edges1 = [] 115 | 
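# What follows re-orders the convex-hull edges by their polar angle, drops
# near-parallel duplicates, then tries every 4-edge combination whose
# directions span more than pi, intersecting adjacent edges (via a shared
# hull point, or a 2x2 linear solve when they do not touch) and keeping the
# quadrilateral with the smallest area.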
edgeangle1 = [] 116 | for item in edgeangles1_idx: 117 | idx = item[1] 118 | edges1.append(edges[idx, :]) 119 | edgeangle1.append(item[0]) 120 | edgeangles = np.array(edgeangle1) 121 | edges = np.array(edges1) 122 | eps = 2.2204e-16 123 | angletol = eps * 100 124 | 125 | k = np.diff(edgeangles) < angletol 126 | idx = np.where(k == 1) 127 | edges = np.delete(edges, idx, 0) 128 | edgeangles = np.delete(edgeangles, idx, 0) 129 | nedges = edges.shape[0] 130 | edgelist = np.array(nchoosek(0, nedges - 1, 1, 4)) 131 | k = edgeangles[edgelist[:, 3]] - edgeangles[edgelist[:, 0]] <= math.pi 132 | k_idx = np.where(k == 1) 133 | edgelist = np.delete(edgelist, k_idx, 0) 134 | 135 | nquads = edgelist.shape[0] 136 | quadareas = math.inf 137 | qxi = np.zeros([5]) 138 | qyi = np.zeros([5]) 139 | cnt = np.zeros([4, 1, 2]) 140 | edgelist = list(edgelist) 141 | edges = list(edges) 142 | xy = list(xy) 143 | 144 | for i in range(nquads): 145 | edgeind = list(edgelist[i]) 146 | edgeind.append(edgelist[i][0]) 147 | edgesi = [] 148 | edgeang = [] 149 | for idx in edgeind: 150 | edgesi.append(edges[idx]) 151 | edgeang.append(edgeangles[idx]) 152 | is_continue = False 153 | for idx in range(len(edgeang) - 1): 154 | diff = edgeang[idx + 1] - edgeang[idx] 155 | if diff > math.pi: 156 | is_continue = True 157 | if is_continue: 158 | continue 159 | for j in range(4): 160 | jplus1 = j + 1 161 | shared = np.intersect1d(edgesi[j], edgesi[jplus1]) 162 | if shared.size != 0: 163 | qxi[j] = xy[shared[0]][0] 164 | qyi[j] = xy[shared[0]][1] 165 | else: 166 | A = xy[edgesi[j][0]] 167 | B = xy[edgesi[j][1]] 168 | C = xy[edgesi[jplus1][0]] 169 | D = xy[edgesi[jplus1][1]] 170 | concat = np.hstack(((A - B).reshape([2, -1]), (D - C).reshape([2, -1]))) 171 | div = (A - C).reshape([2, -1]) 172 | inv_result = get_inv(concat) 173 | a = inv_result[0, 0] 174 | b = inv_result[0, 1] 175 | c = inv_result[1, 0] 176 | d = inv_result[1, 1] 177 | e = div[0, 0] 178 | f = div[1, 0] 179 | ts1 = [a * e + b * f, c * e + d * f] 180 | Q = A + (B - A) * ts1[0] 181 | qxi[j] = Q[0] 182 | qyi[j] = Q[1] 183 | 184 | contour = np.array([qxi[:4], qyi[:4]]).astype(np.int32) 185 | contour = np.transpose(contour, [1, 0]) 186 | contour = contour[:, np.newaxis, :] 187 | A_i = cv2.contourArea(contour) 188 | # break 189 | 190 | if A_i < quadareas: 191 | quadareas = A_i 192 | cnt = contour 193 | return cnt 194 | 195 | 196 | def nchoosek(startnum, endnum, step=1, n=1): 197 | c = [] 198 | for i in itertools.combinations(range(startnum, endnum + 1, step), n): 199 | c.append(list(i)) 200 | return c 201 | 202 | 203 | def get_inv(concat): 204 | a = concat[0][0] 205 | b = concat[0][1] 206 | c = concat[1][0] 207 | d = concat[1][1] 208 | det_concat = a * d - b * c 209 | inv_result = np.array( 210 | [[d / det_concat, -b / det_concat], [-c / det_concat, a / det_concat]] 211 | ) 212 | return inv_result 213 | 214 | 215 | def get_max_adjacent_bbox(mask): 216 | contours, _ = cv2.findContours( 217 | (mask * 255).astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE 218 | ) 219 | max_size = 0 220 | cnt_save = None 221 | # 找到最大边缘邻接矩形 222 | for cont in contours: 223 | points, sside = get_mini_boxes(cont) 224 | if sside > max_size: 225 | max_size = sside 226 | cnt_save = cont 227 | if cnt_save is not None: 228 | epsilon = 0.01 * cv2.arcLength(cnt_save, True) 229 | box = cv2.approxPolyDP(cnt_save, epsilon, True) 230 | hull = cv2.convexHull(box) 231 | points, sside = get_mini_boxes(cnt_save) 232 | len_hull = len(hull) 233 | 234 | if len_hull == 4: 235 | target_box = np.array(hull) 236 | 
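# Fallback chain for the edge polygon: a 4-point hull is used as-is, a
# larger hull is shrunk to its minimum-area bounding quad via minboundquad,
# and anything degenerate falls back to the minAreaRect corner points.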
elif len_hull > 4: 237 | target_box = minboundquad(hull) 238 | else: 239 | target_box = np.array(points) 240 | 241 | return np.array(target_box).reshape([-1, 2]) 242 | 243 | 244 | def sigmoid(x): 245 | return 1 / (1 + np.exp(-x)) 246 | 247 | 248 | def calculate_iou(box, other_boxes): 249 | """ 250 | 计算给定边界框与一组其他边界框之间的交并比(IoU)。 251 | 252 | 参数: 253 | - box: 单个边界框,格式为 [x1, y1, width, height]。 254 | - other_boxes: 其他边界框的数组,每个边界框的格式也为 [x1, y1, width, height]。 255 | 256 | 返回值: 257 | - iou: 一个数组,包含给定边界框与每个其他边界框的IoU值。 258 | """ 259 | 260 | # 计算交集的左上角坐标 261 | x1 = np.maximum(box[0], np.array(other_boxes)[:, 0]) 262 | y1 = np.maximum(box[1], np.array(other_boxes)[:, 1]) 263 | # 计算交集的右下角坐标 264 | x2 = np.minimum(box[2], np.array(other_boxes)[:, 2]) 265 | y2 = np.minimum(box[3], np.array(other_boxes)[:, 3]) 266 | # 计算交集区域的面积 267 | intersection_area = np.maximum(0, x2 - x1) * np.maximum(0, y2 - y1) 268 | # 计算给定边界框的面积 269 | box_area = (box[2] - box[0]) * (box[3] - box[1]) 270 | # 计算其他边界框的面积 271 | other_boxes_area = np.array(other_boxes[:, 2] - other_boxes[:, 0]) * np.array( 272 | other_boxes[:, 3] - other_boxes[:, 1] 273 | ) 274 | # 计算IoU值 275 | iou = intersection_area / (box_area + other_boxes_area - intersection_area) 276 | return iou 277 | 278 | 279 | def custom_NMSBoxes(boxes, scores, iou_threshold=0.4): 280 | # 如果没有边界框,则直接返回空列表 281 | if len(boxes) == 0: 282 | return [] 283 | # 将得分和边界框转换为NumPy数组 284 | scores = np.array(scores) 285 | boxes = np.array(boxes) 286 | # 根据置信度阈值过滤边界框 287 | # filtered_boxes = boxes[mask] 288 | # filtered_scores = scores[mask] 289 | # 如果过滤后没有边界框,则返回空列表 290 | if len(boxes) == 0: 291 | return [] 292 | # 根据置信度得分对边界框进行排序 293 | sorted_indices = np.argsort(scores)[::-1] 294 | # 初始化一个空列表来存储选择的边界框索引 295 | indices = [] 296 | # 当还有未处理的边界框时,循环继续 297 | while len(sorted_indices) > 0: 298 | # 选择得分最高的边界框索引 299 | current_index = sorted_indices[0] 300 | indices.append(current_index) 301 | # 如果只剩一个边界框,则结束循环 302 | if len(sorted_indices) == 1: 303 | break 304 | # 获取当前边界框和其他边界框 305 | current_box = boxes[current_index] 306 | other_boxes = boxes[sorted_indices[1:]] 307 | # 计算当前边界框与其他边界框的IoU 308 | iou = calculate_iou(current_box, other_boxes) 309 | # 找到IoU低于阈值的边界框,即与当前边界框不重叠的边界框 310 | non_overlapping_indices = np.where(iou <= iou_threshold)[0] 311 | # 更新sorted_indices以仅包含不重叠的边界框 312 | sorted_indices = sorted_indices[non_overlapping_indices + 1] 313 | # 返回选择的边界框索引 314 | return indices 315 | -------------------------------------------------------------------------------- /rapid_table_det/utils/visuallize.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | from rapid_table_det.utils.load_image import LoadImage 4 | import numpy as np 5 | 6 | img_loader = LoadImage() 7 | 8 | 9 | def visuallize(img, box, lt, rt, rb, lb): 10 | xmin, ymin, xmax, ymax = box 11 | draw_box = np.array([lt, rt, rb, lb]).reshape([-1, 2]) 12 | cv2.circle(img, (int(lt[0]), int(lt[1])), 50, (255, 0, 0), 10) 13 | cv2.rectangle(img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 0, 0), 10) 14 | cv2.polylines( 15 | img, 16 | [np.array(draw_box).astype(np.int32).reshape((-1, 1, 2))], 17 | True, 18 | color=(255, 0, 255), 19 | thickness=6, 20 | ) 21 | return img 22 | 23 | 24 | def extract_table_img(img, lt, rt, rb, lb): 25 | """ 26 | 根据四个角点进行透视变换,并提取出角点区域的图片。 27 | 28 | 参数: 29 | img (numpy.ndarray): 输入图像 30 | lt (numpy.ndarray): 左上角坐标 31 | rt (numpy.ndarray): 右上角坐标 32 | lb (numpy.ndarray): 左下角坐标 33 | rb (numpy.ndarray): 右下角坐标 34 | 35 | 返回: 36 | numpy.ndarray: 提取出的角点区域图片 37 | 
""" 38 | # 源点坐标 39 | src_points = np.float32([lt, rt, lb, rb]) 40 | 41 | # 目标点坐标 42 | width_a = np.sqrt(((rb[0] - lb[0]) ** 2) + ((rb[1] - lb[1]) ** 2)) 43 | width_b = np.sqrt(((rt[0] - lt[0]) ** 2) + ((rt[1] - lt[1]) ** 2)) 44 | max_width = max(int(width_a), int(width_b)) 45 | 46 | height_a = np.sqrt(((rt[0] - rb[0]) ** 2) + ((rt[1] - rb[1]) ** 2)) 47 | height_b = np.sqrt(((lt[0] - lb[0]) ** 2) + ((lt[1] - lb[1]) ** 2)) 48 | max_height = max(int(height_a), int(height_b)) 49 | 50 | dst_points = np.float32( 51 | [ 52 | [0, 0], 53 | [max_width - 1, 0], 54 | [0, max_height - 1], 55 | [max_width - 1, max_height - 1], 56 | ] 57 | ) 58 | 59 | # 获取透视变换矩阵 60 | M = cv2.getPerspectiveTransform(src_points, dst_points) 61 | 62 | # 应用透视变换 63 | warped = cv2.warpPerspective(img, M, (max_width, max_height)) 64 | 65 | return warped 66 | -------------------------------------------------------------------------------- /rapid_table_det_paddle/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: Jocker1212 3 | # @Contact: xinyijianggo@gmail.com 4 | from .inference import TableDetector 5 | from .utils import img_loader, visuallize, extract_table_img 6 | 7 | __all__ = ["TableDetector", "img_loader", "visuallize", "extract_table_img"] 8 | -------------------------------------------------------------------------------- /rapid_table_det_paddle/inference.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from rapid_table_det_paddle.predictor import DbNet, ObjectDetector, PPLCNet 5 | from rapid_table_det_paddle.utils import LoadImage 6 | 7 | 8 | class TableDetector: 9 | def __init__( 10 | self, 11 | edge_model_path=None, 12 | obj_model_path=None, 13 | cls_model_path=None, 14 | use_obj_det=True, 15 | use_edge_det=True, 16 | use_cls_det=True, 17 | ): 18 | self.use_obj_det = use_obj_det 19 | self.use_edge_det = use_edge_det 20 | self.use_cls_det = use_cls_det 21 | self.img_loader = LoadImage() 22 | if self.use_obj_det: 23 | self.obj_detector = ObjectDetector(obj_model_path) 24 | if self.use_edge_det: 25 | self.dbnet = DbNet(edge_model_path) 26 | if self.use_cls_det: 27 | self.pplcnet = PPLCNet(cls_model_path) 28 | 29 | def __call__(self, img, det_accuracy=0.7): 30 | img = self.img_loader(img) 31 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 32 | img_mask = img.copy() 33 | h, w = img.shape[:-1] 34 | img_box = np.array([0, 0, w, h]) 35 | lb, lt, rb, rt = self.get_box_points(img_box) 36 | # 初始化默认值 37 | obj_det_res, edge_box, pred_label = ( 38 | [[1.0, img_box]], 39 | img_box.reshape([-1, 2]), 40 | 0, 41 | ) 42 | result = [] 43 | obj_det_elapse, edge_elapse, rotate_det_elapse = 0, 0, 0 44 | if self.use_obj_det: 45 | obj_det_res, obj_det_elapse = self.obj_detector(img, score=det_accuracy) 46 | for i in range(len(obj_det_res)): 47 | det_res = obj_det_res[i] 48 | score, box = det_res 49 | xmin, ymin, xmax, ymax = box 50 | edge_box = box.reshape([-1, 2]) 51 | lb, lt, rb, rt = self.get_box_points(box) 52 | if self.use_edge_det: 53 | xmin_edge, ymin_edge, xmax_edge, ymax_edge = self.pad_box_points( 54 | h, w, xmax, xmin, ymax, ymin, 10 55 | ) 56 | crop_img = img_mask[ymin_edge:ymax_edge, xmin_edge:xmax_edge, :] 57 | edge_box, lt, lb, rt, rb, tmp_edge_elapse = self.dbnet(crop_img) 58 | edge_elapse += tmp_edge_elapse 59 | if edge_box is None: 60 | continue 61 | edge_box[:, 0] += xmin_edge 62 | edge_box[:, 1] += ymin_edge 63 | lt, lb, rt, rb = ( 64 | lt + [xmin_edge, 
ymin_edge], 65 | lb + [xmin_edge, ymin_edge], 66 | rt + [xmin_edge, ymin_edge], 67 | rb + [xmin_edge, ymin_edge], 68 | ) 69 | if self.use_cls_det: 70 | xmin_cls, ymin_cls, xmax_cls, ymax_cls = self.pad_box_points( 71 | h, w, xmax, xmin, ymax, ymin, 5 72 | ) 73 | cls_box = edge_box.copy() 74 | cls_img = img_mask[ymin_cls:ymax_cls, xmin_cls:xmax_cls, :] 75 | cls_box[:, 0] = cls_box[:, 0] - xmin_cls 76 | cls_box[:, 1] = cls_box[:, 1] - ymin_cls 77 | # 画框增加先验信息,辅助方向label识别 78 | cv2.polylines( 79 | cls_img, 80 | [np.array(cls_box).astype(np.int32).reshape((-1, 1, 2))], 81 | True, 82 | color=(255, 0, 255), 83 | thickness=5, 84 | ) 85 | pred_label, tmp_rotate_det_elapse = self.pplcnet(cls_img) 86 | rotate_det_elapse += tmp_rotate_det_elapse 87 | lb1, lt1, rb1, rt1 = self.get_real_rotated_points( 88 | lb, lt, pred_label, rb, rt 89 | ) 90 | result.append( 91 | { 92 | "box": [int(xmin), int(ymin), int(xmax), int(ymax)], 93 | "lb": [int(lb1[0]), int(lb1[1])], 94 | "lt": [int(lt1[0]), int(lt1[1])], 95 | "rt": [int(rt1[0]), int(rt1[1])], 96 | "rb": [int(rb1[0]), int(rb1[1])], 97 | } 98 | ) 99 | elapse = [obj_det_elapse, edge_elapse, rotate_det_elapse] 100 | return result, elapse 101 | 102 | def get_box_points(self, img_box): 103 | x1, y1, x2, y2 = img_box 104 | lt = np.array([x1, y1]) # 左上角 105 | rt = np.array([x2, y1]) # 右上角 106 | rb = np.array([x2, y2]) # 右下角 107 | lb = np.array([x1, y2]) # 左下角 108 | return lb, lt, rb, rt 109 | 110 | def get_real_rotated_points(self, lb, lt, pred_label, rb, rt): 111 | if pred_label == 0: 112 | lt1 = lt 113 | rt1 = rt 114 | rb1 = rb 115 | lb1 = lb 116 | elif pred_label == 1: 117 | lt1 = rt 118 | rt1 = rb 119 | rb1 = lb 120 | lb1 = lt 121 | elif pred_label == 2: 122 | lt1 = rb 123 | rt1 = lb 124 | rb1 = lt 125 | lb1 = rt 126 | elif pred_label == 3: 127 | lt1 = lb 128 | rt1 = lt 129 | rb1 = rt 130 | lb1 = rb 131 | else: 132 | lt1 = lt 133 | rt1 = rt 134 | rb1 = rb 135 | lb1 = lb 136 | return lb1, lt1, rb1, rt1 137 | 138 | def pad_box_points(self, h, w, xmax, xmin, ymax, ymin, pad): 139 | ymin_edge = max(ymin - pad, 0) 140 | xmin_edge = max(xmin - pad, 0) 141 | ymax_edge = min(ymax + pad, h) 142 | xmax_edge = min(xmax + pad, w) 143 | return xmin_edge, ymin_edge, xmax_edge, ymax_edge 144 | -------------------------------------------------------------------------------- /rapid_table_det_paddle/models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RapidAI/RapidTableDetection/ce3e4887df65b6629a8ac7b6ce9983783fc48347/rapid_table_det_paddle/models/.gitkeep -------------------------------------------------------------------------------- /rapid_table_det_paddle/predictor.py: -------------------------------------------------------------------------------- 1 | import time 2 | import paddle 3 | from rapid_table_det_paddle.utils import * 4 | 5 | MODEL_STAGES_PATTERN = { 6 | "PPLCNet": ["blocks2", "blocks3", "blocks4", "blocks5", "blocks6"] 7 | } 8 | 9 | 10 | class ObjectDetector: 11 | model_key = "obj_det_paddle" 12 | 13 | def __init__(self, model_path, **kwargs): 14 | self.model = paddle.jit.load(model_path) 15 | self.img_loader = LoadImage() 16 | self.resize_shape = [928, 928] 17 | 18 | def __call__(self, img, **kwargs): 19 | start = time.time() 20 | score = kwargs.get("score", 0.4) 21 | img = self.img_loader(img) 22 | ori_h, ori_w = img.shape[:-1] 23 | img, im_shape, factor = self.img_preprocess(img, self.resize_shape) 24 | pre = self.model(img, factor) 25 | result = [] 26 | for item in pre[0].numpy(): 
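# Each output row of the exported Paddle detector appears to be
# (class_id, score, x1, y1, x2, y2), with NMS already applied inside the
# exported graph (the ONNX path, by contrast, runs custom_NMSBoxes
# explicitly), so the rows below are only score-filtered and clipped to the
# original image bounds.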
27 | cls, value, xmin, ymin, xmax, ymax = list(item) 28 | if value < score: 29 | continue 30 | cls, xmin, ymin, xmax, ymax = [ 31 | int(x) for x in [cls, xmin, ymin, xmax, ymax] 32 | ] 33 | xmin = max(xmin, 0) 34 | ymin = max(ymin, 0) 35 | xmax = min(xmax, ori_w) 36 | ymax = min(ymax, ori_h) 37 | result.append([value, np.array([xmin, ymin, xmax, ymax])]) 38 | return result, time.time() - start 39 | 40 | def img_preprocess(self, img, resize_shape=[928, 928]): 41 | im_info = { 42 | "scale_factor": np.array([1.0, 1.0], dtype=np.float32), 43 | "im_shape": np.array(img.shape[:2], dtype=np.float32), 44 | } 45 | im, im_info = resize(img, im_info, resize_shape, False) 46 | im, im_info = pad(im, im_info, resize_shape) 47 | im = im / 255.0 48 | im = im.transpose((2, 0, 1)).copy() 49 | im = paddle.to_tensor(im, dtype="float32") 50 | im = im.unsqueeze(0) 51 | factor = ( 52 | paddle.to_tensor(im_info["scale_factor"]).reshape((1, 2)).astype("float32") 53 | ) 54 | im_shape = paddle.to_tensor( 55 | im_info["im_shape"].reshape((1, 2)), dtype="float32" 56 | ) 57 | return im, im_shape, factor 58 | 59 | 60 | class DbNet: 61 | model_key = "edge_det_paddle" 62 | 63 | def __init__(self, model_path, **kwargs): 64 | self.model = paddle.jit.load(model_path) 65 | self.img_loader = LoadImage() 66 | self.resize_shape = [800, 800] 67 | 68 | def __call__(self, img, **kwargs): 69 | start = time.time() 70 | img = self.img_loader(img) 71 | destHeight, destWidth = img.shape[:-1] 72 | img, resize_h, resize_w, left, top = self.img_preprocess(img, self.resize_shape) 73 | with paddle.no_grad(): 74 | predicts = self.model(img) 75 | predict_maps = predicts.cpu() 76 | pred = predict_maps[0, 0].numpy() 77 | segmentation = pred > 0.7 78 | mask = np.array(segmentation).astype(np.uint8) 79 | # 找到最佳边缘box shape(4, 2) 80 | box = get_max_adjacent_bbox(mask) 81 | # todo 注意还有crop的偏移 82 | if box is not None: 83 | # 根据缩放调整坐标适配输入的img大小 84 | adjusted_box = self.adjust_coordinates( 85 | box, left, top, resize_w, resize_h, destWidth, destHeight 86 | ) 87 | # 排序并裁剪负值 88 | lt, lb, rt, rb = self.sort_and_clip_coordinates(adjusted_box) 89 | return box, lt, lb, rt, rb, time.time() - start 90 | else: 91 | return None, None, None, None, None, time.time() - start 92 | 93 | def adjust_coordinates( 94 | self, box, left, top, resize_w, resize_h, destWidth, destHeight 95 | ): 96 | """ 97 | 调整边界框坐标,确保它们在合理范围内。 98 | 99 | 参数: 100 | box (numpy.ndarray): 原始边界框坐标 (shape: (4, 2)) 101 | left (int): 左侧偏移量 102 | top (int): 顶部偏移量 103 | resize_w (int): 缩放宽度 104 | resize_h (int): 缩放高度 105 | destWidth (int): 目标宽度 106 | destHeight (int): 目标高度 107 | xmin_a (int): 目标左上角横坐标 108 | ymin_a (int): 目标左上角纵坐标 109 | 110 | 返回: 111 | numpy.ndarray: 调整后的边界框坐标 112 | """ 113 | # 调整横坐标 114 | box[:, 0] = np.clip( 115 | (np.round(box[:, 0] - left) / resize_w * destWidth), 0, destWidth 116 | ) 117 | 118 | # 调整纵坐标 119 | box[:, 1] = np.clip( 120 | (np.round(box[:, 1] - top) / resize_h * destHeight), 0, destHeight 121 | ) 122 | return box 123 | 124 | def sort_and_clip_coordinates(self, box): 125 | """ 126 | 对边界框坐标进行排序并裁剪负值。 127 | 128 | 参数: 129 | box (numpy.ndarray): 边界框坐标 (shape: (4, 2)) 130 | 131 | 返回: 132 | tuple: 左上角、左下角、右上角、右下角坐标 133 | """ 134 | # 按横坐标排序 135 | x = box[:, 0] 136 | l_idx = x.argsort() 137 | l_box = np.array([box[l_idx[0]], box[l_idx[1]]]) 138 | r_box = np.array([box[l_idx[2]], box[l_idx[3]]]) 139 | 140 | # 左侧坐标按纵坐标排序 141 | l_idx_1 = np.array(l_box[:, 1]).argsort() 142 | lt = l_box[l_idx_1[0]] 143 | lb = l_box[l_idx_1[1]] 144 | 145 | # 右侧坐标按纵坐标排序 146 | r_idx_1 = 
np.array(r_box[:, 1]).argsort() 147 | rt = r_box[r_idx_1[0]] 148 | rb = r_box[r_idx_1[1]] 149 | 150 | # 裁剪负值 151 | lt[lt < 0] = 0 152 | lb[lb < 0] = 0 153 | rt[rt < 0] = 0 154 | rb[rb < 0] = 0 155 | 156 | return lt, lb, rt, rb 157 | 158 | def img_preprocess(self, img, resize_shape=[800, 800]): 159 | im, new_w, new_h, left, top = ResizePad(img, resize_shape[0]) 160 | im = im / 255.0 161 | im = im.transpose((2, 0, 1)).copy() 162 | im = paddle.to_tensor(im, dtype="float32") 163 | im = im.unsqueeze(0) 164 | return im, new_h, new_w, left, top 165 | 166 | 167 | class PPLCNet: 168 | model_key = "cls_det_paddle" 169 | 170 | def __init__(self, model_path, **kwargs): 171 | self.model = paddle.jit.load(model_path) 172 | self.img_loader = LoadImage() 173 | self.resize_shape = [624, 624] 174 | 175 | def __call__(self, img, **kwargs): 176 | start = time.time() 177 | img = self.img_loader(img) 178 | img = self.img_preprocess(img, self.resize_shape) 179 | with paddle.no_grad(): 180 | label = self.model(img) 181 | label = label.unsqueeze(0).numpy() 182 | mini_batch_result = np.argsort(label) 183 | mini_batch_result = mini_batch_result[0][-1] # 把这些列标拿出来 184 | mini_batch_result = mini_batch_result.flatten() # 拉平了,只吐出一个 array 185 | mini_batch_result = mini_batch_result[::-1] # 逆序 186 | pred_label = mini_batch_result[0] 187 | return pred_label, time.time() - start 188 | 189 | def img_preprocess(self, img, resize_shape=[624, 624]): 190 | im, new_w, new_h, left, top = ResizePad(img, resize_shape[0]) 191 | im = np.array(im).astype("float32").transpose((2, 0, 1)) / 255.0 192 | im = paddle.to_tensor(im) 193 | return im.unsqueeze(0) 194 | -------------------------------------------------------------------------------- /rapid_table_det_paddle/requirments.txt: -------------------------------------------------------------------------------- 1 | paddlepaddle 2 | numpy 3 | Pillow 4 | opencv-python 5 | onnxruntime 6 | requests 7 | -------------------------------------------------------------------------------- /rapid_table_det_paddle/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from io import BytesIO 3 | from pathlib import Path 4 | from typing import Union 5 | import itertools 6 | import cv2 7 | import numpy as np 8 | from PIL import Image, UnidentifiedImageError 9 | 10 | root_dir = Path(__file__).resolve().parent 11 | InputType = Union[str, np.ndarray, bytes, Path] 12 | 13 | 14 | class LoadImage: 15 | def __init__( 16 | self, 17 | ): 18 | pass 19 | 20 | def __call__(self, img: InputType) -> np.ndarray: 21 | if not isinstance(img, InputType.__args__): 22 | raise LoadImageError( 23 | f"The img type {type(img)} does not in {InputType.__args__}" 24 | ) 25 | 26 | img = self.load_img(img) 27 | img = self.convert_img(img) 28 | return img 29 | 30 | def load_img(self, img: InputType) -> np.ndarray: 31 | if isinstance(img, (str, Path)): 32 | self.verify_exist(img) 33 | try: 34 | img = np.array(Image.open(img)) 35 | except UnidentifiedImageError as e: 36 | raise LoadImageError(f"cannot identify image file {img}") from e 37 | return img 38 | 39 | if isinstance(img, bytes): 40 | img = np.array(Image.open(BytesIO(img))) 41 | return img 42 | 43 | if isinstance(img, np.ndarray): 44 | return img 45 | 46 | raise LoadImageError(f"{type(img)} is not supported!") 47 | 48 | def convert_img(self, img: np.ndarray): 49 | if img.ndim == 2: 50 | return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 51 | 52 | if img.ndim == 3: 53 | channel = img.shape[2] 54 | if channel == 1: 55 | return 
cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 56 | 57 | if channel == 2: 58 | return self.cvt_two_to_three(img) 59 | 60 | if channel == 4: 61 | return self.cvt_four_to_three(img) 62 | 63 | if channel == 3: 64 | return cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 65 | 66 | raise LoadImageError( 67 | f"The channel({channel}) of the img is not in [1, 2, 3, 4]" 68 | ) 69 | 70 | raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]") 71 | 72 | @staticmethod 73 | def cvt_four_to_three(img: np.ndarray) -> np.ndarray: 74 | """RGBA → BGR""" 75 | r, g, b, a = cv2.split(img) 76 | new_img = cv2.merge((b, g, r)) 77 | 78 | not_a = cv2.bitwise_not(a) 79 | not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) 80 | 81 | new_img = cv2.bitwise_and(new_img, new_img, mask=a) 82 | new_img = cv2.add(new_img, not_a) 83 | return new_img 84 | 85 | @staticmethod 86 | def cvt_two_to_three(img: np.ndarray) -> np.ndarray: 87 | """gray + alpha → BGR""" 88 | img_gray = img[..., 0] 89 | img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR) 90 | 91 | img_alpha = img[..., 1] 92 | not_a = cv2.bitwise_not(img_alpha) 93 | not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) 94 | 95 | new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha) 96 | new_img = cv2.add(new_img, not_a) 97 | return new_img 98 | 99 | @staticmethod 100 | def verify_exist(file_path: Union[str, Path]): 101 | if not Path(file_path).exists(): 102 | raise LoadImageError(f"{file_path} does not exist.") 103 | 104 | 105 | img_loader = LoadImage() 106 | 107 | 108 | class LoadImageError(Exception): 109 | pass 110 | 111 | 112 | def generate_scale(im, resize_shape, keep_ratio): 113 | """ 114 | Args: 115 | im (np.ndarray): image (np.ndarray) 116 | Returns: 117 | im_scale_x: the resize ratio of X 118 | im_scale_y: the resize ratio of Y 119 | """ 120 | target_size = (resize_shape[0], resize_shape[1]) 121 | # target_size = (800, 1333) 122 | origin_shape = im.shape[:2] 123 | 124 | if keep_ratio: 125 | im_size_min = np.min(origin_shape) 126 | im_size_max = np.max(origin_shape) 127 | target_size_min = np.min(target_size) 128 | target_size_max = np.max(target_size) 129 | im_scale = float(target_size_min) / float(im_size_min) 130 | if np.round(im_scale * im_size_max) > target_size_max: 131 | im_scale = float(target_size_max) / float(im_size_max) 132 | im_scale_x = im_scale 133 | im_scale_y = im_scale 134 | else: 135 | resize_h, resize_w = target_size 136 | im_scale_y = resize_h / float(origin_shape[0]) 137 | im_scale_x = resize_w / float(origin_shape[1]) 138 | return im_scale_y, im_scale_x 139 | 140 | 141 | def resize(im, im_info, resize_shape, keep_ratio, interp=2): 142 | im_scale_y, im_scale_x = generate_scale(im, resize_shape, keep_ratio) 143 | im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=interp) 144 | im_info["im_shape"] = np.array(im.shape[:2]).astype("float32") 145 | im_info["scale_factor"] = np.array([im_scale_y, im_scale_x]).astype("float32") 146 | 147 | return im, im_info 148 | 149 | 150 | def pad(im, im_info, resize_shape): 151 | im_h, im_w = im.shape[:2] 152 | fill_value = [114.0, 114.0, 114.0] 153 | h, w = resize_shape[0], resize_shape[1] 154 | if h == im_h and w == im_w: 155 | im = im.astype(np.float32) 156 | return im, im_info 157 | 158 | canvas = np.ones((h, w, 3), dtype=np.float32) 159 | canvas *= np.array(fill_value, dtype=np.float32) 160 | canvas[0:im_h, 0:im_w, :] = im.astype(np.float32) 161 | im = canvas 162 | return im, im_info 163 | 164 | 165 | def ResizePad(img, target_size): 166 | h, w = img.shape[:2] 167 | m = max(h, w) 168 | 
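# Letterbox resize: scale the longer side to target_size, keep the aspect
# ratio, and pad symmetrically, returning the left/top offsets needed to map
# detections back to the original image. Note this paddle variant pads with
# white (255, 255, 255), whereas rapid_table_det/utils/transform.py pads
# with gray (114, 114, 114).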
ratio = target_size / m 169 | new_w, new_h = int(ratio * w), int(ratio * h) 170 | img = cv2.resize(img, (new_w, new_h), cv2.INTER_LINEAR) 171 | top = (target_size - new_h) // 2 172 | bottom = (target_size - new_h) - top 173 | left = (target_size - new_w) // 2 174 | right = (target_size - new_w) - left 175 | img1 = cv2.copyMakeBorder( 176 | img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(255, 255, 255) 177 | ) 178 | return img1, new_w, new_h, left, top 179 | 180 | 181 | def get_mini_boxes(contour): 182 | bounding_box = cv2.minAreaRect(contour) 183 | points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) 184 | 185 | index_1, index_2, index_3, index_4 = 0, 1, 2, 3 186 | if points[1][1] > points[0][1]: 187 | index_1 = 0 188 | index_4 = 1 189 | else: 190 | index_1 = 1 191 | index_4 = 0 192 | if points[3][1] > points[2][1]: 193 | index_2 = 2 194 | index_3 = 3 195 | else: 196 | index_2 = 3 197 | index_3 = 2 198 | 199 | box = [points[index_1], points[index_2], points[index_3], points[index_4]] 200 | return box, min(bounding_box[1]) 201 | 202 | 203 | def minboundquad(hull): 204 | len_hull = len(hull) 205 | xy = np.array(hull).reshape([-1, 2]) 206 | idx = np.arange(0, len_hull) 207 | idx_roll = np.roll(idx, -1, axis=0) 208 | edges = np.array([idx, idx_roll]).reshape([2, -1]) 209 | edges = np.transpose(edges, [1, 0]) 210 | edgeangles1 = [] 211 | for i in range(len_hull): 212 | y = xy[edges[i, 1], 1] - xy[edges[i, 0], 1] 213 | x = xy[edges[i, 1], 0] - xy[edges[i, 0], 0] 214 | angle = math.atan2(y, x) 215 | if angle < 0: 216 | angle = angle + 2 * math.pi 217 | edgeangles1.append([angle, i]) 218 | edgeangles1_idx = sorted(list(edgeangles1), key=lambda x: x[0]) 219 | edges1 = [] 220 | edgeangle1 = [] 221 | for item in edgeangles1_idx: 222 | idx = item[1] 223 | edges1.append(edges[idx, :]) 224 | edgeangle1.append(item[0]) 225 | edgeangles = np.array(edgeangle1) 226 | edges = np.array(edges1) 227 | eps = 2.2204e-16 228 | angletol = eps * 100 229 | 230 | k = np.diff(edgeangles) < angletol 231 | idx = np.where(k == 1) 232 | edges = np.delete(edges, idx, 0) 233 | edgeangles = np.delete(edgeangles, idx, 0) 234 | nedges = edges.shape[0] 235 | edgelist = np.array(nchoosek(0, nedges - 1, 1, 4)) 236 | k = edgeangles[edgelist[:, 3]] - edgeangles[edgelist[:, 0]] <= math.pi 237 | k_idx = np.where(k == 1) 238 | edgelist = np.delete(edgelist, k_idx, 0) 239 | 240 | nquads = edgelist.shape[0] 241 | quadareas = math.inf 242 | qxi = np.zeros([5]) 243 | qyi = np.zeros([5]) 244 | cnt = np.zeros([4, 1, 2]) 245 | edgelist = list(edgelist) 246 | edges = list(edges) 247 | xy = list(xy) 248 | 249 | for i in range(nquads): 250 | edgeind = list(edgelist[i]) 251 | edgeind.append(edgelist[i][0]) 252 | edgesi = [] 253 | edgeang = [] 254 | for idx in edgeind: 255 | edgesi.append(edges[idx]) 256 | edgeang.append(edgeangles[idx]) 257 | is_continue = False 258 | for idx in range(len(edgeang) - 1): 259 | diff = edgeang[idx + 1] - edgeang[idx] 260 | if diff > math.pi: 261 | is_continue = True 262 | if is_continue: 263 | continue 264 | for j in range(4): 265 | jplus1 = j + 1 266 | shared = np.intersect1d(edgesi[j], edgesi[jplus1]) 267 | if shared.size != 0: 268 | qxi[j] = xy[shared[0]][0] 269 | qyi[j] = xy[shared[0]][1] 270 | else: 271 | A = xy[edgesi[j][0]] 272 | B = xy[edgesi[j][1]] 273 | C = xy[edgesi[jplus1][0]] 274 | D = xy[edgesi[jplus1][1]] 275 | concat = np.hstack(((A - B).reshape([2, -1]), (D - C).reshape([2, -1]))) 276 | div = (A - C).reshape([2, -1]) 277 | inv_result = get_inv(concat) 278 | a = 
inv_result[0, 0] 279 | b = inv_result[0, 1] 280 | c = inv_result[1, 0] 281 | d = inv_result[1, 1] 282 | e = div[0, 0] 283 | f = div[1, 0] 284 | ts1 = [a * e + b * f, c * e + d * f] 285 | Q = A + (B - A) * ts1[0] 286 | qxi[j] = Q[0] 287 | qyi[j] = Q[1] 288 | 289 | contour = np.array([qxi[:4], qyi[:4]]).astype(np.int32) 290 | contour = np.transpose(contour, [1, 0]) 291 | contour = contour[:, np.newaxis, :] 292 | A_i = cv2.contourArea(contour) 293 | # break 294 | 295 | if A_i < quadareas: 296 | quadareas = A_i 297 | cnt = contour 298 | return cnt 299 | 300 | 301 | def nchoosek(startnum, endnum, step=1, n=1): 302 | c = [] 303 | for i in itertools.combinations(range(startnum, endnum + 1, step), n): 304 | c.append(list(i)) 305 | return c 306 | 307 | 308 | def get_inv(concat): 309 | a = concat[0][0] 310 | b = concat[0][1] 311 | c = concat[1][0] 312 | d = concat[1][1] 313 | det_concat = a * d - b * c 314 | inv_result = np.array( 315 | [[d / det_concat, -b / det_concat], [-c / det_concat, a / det_concat]] 316 | ) 317 | return inv_result 318 | 319 | 320 | def get_max_adjacent_bbox(mask): 321 | contours, _ = cv2.findContours( 322 | (mask * 255).astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE 323 | ) 324 | max_size = 0 325 | cnt_save = None 326 | # 找到最大边缘邻接矩形 327 | for cont in contours: 328 | points, sside = get_mini_boxes(cont) 329 | if sside > max_size: 330 | max_size = sside 331 | cnt_save = cont 332 | if cnt_save is not None: 333 | epsilon = 0.01 * cv2.arcLength(cnt_save, True) 334 | box = cv2.approxPolyDP(cnt_save, epsilon, True) 335 | hull = cv2.convexHull(box) 336 | points, sside = get_mini_boxes(cnt_save) 337 | len_hull = len(hull) 338 | 339 | if len_hull == 4: 340 | target_box = np.array(hull) 341 | elif len_hull > 4: 342 | target_box = minboundquad(hull) 343 | else: 344 | target_box = np.array(points) 345 | 346 | return np.array(target_box).reshape([-1, 2]) 347 | 348 | 349 | def visuallize(img, box, lt, rt, rb, lb): 350 | xmin, ymin, xmax, ymax = box 351 | draw_box = np.array([lt, rt, rb, lb]).reshape([-1, 2]) 352 | cv2.circle(img, (int(lt[0]), int(lt[1])), 50, (255, 0, 0), 10) 353 | cv2.rectangle(img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 0, 0), 10) 354 | cv2.polylines( 355 | img, 356 | [np.array(draw_box).astype(np.int32).reshape((-1, 1, 2))], 357 | True, 358 | color=(255, 0, 255), 359 | thickness=6, 360 | ) 361 | return img 362 | 363 | 364 | def extract_table_img(img, lt, rt, rb, lb): 365 | """ 366 | 根据四个角点进行透视变换,并提取出角点区域的图片。 367 | 368 | 参数: 369 | img (numpy.ndarray): 输入图像 370 | lt (numpy.ndarray): 左上角坐标 371 | rt (numpy.ndarray): 右上角坐标 372 | lb (numpy.ndarray): 左下角坐标 373 | rb (numpy.ndarray): 右下角坐标 374 | 375 | 返回: 376 | numpy.ndarray: 提取出的角点区域图片 377 | """ 378 | # 源点坐标 379 | src_points = np.float32([lt, rt, lb, rb]) 380 | 381 | # 目标点坐标 382 | width_a = np.sqrt(((rb[0] - lb[0]) ** 2) + ((rb[1] - lb[1]) ** 2)) 383 | width_b = np.sqrt(((rt[0] - lt[0]) ** 2) + ((rt[1] - lt[1]) ** 2)) 384 | max_width = max(int(width_a), int(width_b)) 385 | 386 | height_a = np.sqrt(((rt[0] - rb[0]) ** 2) + ((rt[1] - rb[1]) ** 2)) 387 | height_b = np.sqrt(((lt[0] - lb[0]) ** 2) + ((lt[1] - lb[1]) ** 2)) 388 | max_height = max(int(height_a), int(height_b)) 389 | 390 | dst_points = np.float32( 391 | [ 392 | [0, 0], 393 | [max_width - 1, 0], 394 | [0, max_height - 1], 395 | [max_width - 1, max_height - 1], 396 | ] 397 | ) 398 | 399 | # 获取透视变换矩阵 400 | M = cv2.getPerspectiveTransform(src_points, dst_points) 401 | 402 | # 应用透视变换 403 | warped = cv2.warpPerspective(img, M, (max_width, max_height)) 
404 | return warped 405 | -------------------------------------------------------------------------------- /readme_resource/res_show.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RapidAI/RapidTableDetection/ce3e4887df65b6629a8ac7b6ce9983783fc48347/readme_resource/res_show.jpg -------------------------------------------------------------------------------- /readme_resource/res_show2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RapidAI/RapidTableDetection/ce3e4887df65b6629a8ac7b6ce9983783fc48347/readme_resource/res_show2.jpg -------------------------------------------------------------------------------- /readme_resource/structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RapidAI/RapidTableDetection/ce3e4887df65b6629a8ac7b6ce9983783fc48347/readme_resource/structure.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | tqdm 3 | Pillow 4 | opencv-python 5 | onnxruntime 6 | requests 7 | -------------------------------------------------------------------------------- /setup_rapid_table_det.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: Jocker1212 3 | # @Contact: xinyijianggo@gmail.com 4 | import sys 5 | from typing import List, Union 6 | from pathlib import Path 7 | from get_pypi_latest_version import GetPyPiLatestVersion 8 | 9 | import setuptools 10 | 11 | 12 | def read_txt(txt_path: Union[Path, str]) -> List[str]: 13 | with open(txt_path, "r", encoding="utf-8") as f: 14 | data = [v.rstrip("\n") for v in f] 15 | return data 16 | 17 | 18 | MODULE_NAME = "rapid_table_det" 19 | 20 | obtainer = GetPyPiLatestVersion() 21 | try: 22 | latest_version = obtainer(MODULE_NAME) 23 | except Exception: 24 | latest_version = "0.0.0" 25 | 26 | VERSION_NUM = obtainer.version_add_one(latest_version) 27 | 28 | if len(sys.argv) > 2: 29 | match_str = " ".join(sys.argv[2:]) 30 | matched_versions = obtainer.extract_version(match_str) 31 | if matched_versions: 32 | VERSION_NUM = matched_versions 33 | sys.argv = sys.argv[:2] 34 | 35 | setuptools.setup( 36 | name=MODULE_NAME, 37 | version=VERSION_NUM, 38 | platforms="Any", 39 | description="table detection with onnx model", 40 | long_description="table detection with onnx model", 41 | author="jockerK", 42 | author_email="xinyijianggo@gmail.com", 43 | url="https://github.com/Joker1212/RapidTableDetection", 44 | license="Apache-2.0", 45 | install_requires=read_txt("requirements.txt"), 46 | include_package_data=False, 47 | packages=[MODULE_NAME, f"{MODULE_NAME}.models", f"{MODULE_NAME}.utils"], 48 | package_data={"": [".gitkeep"]}, 49 | keywords=["obj detection,ocr,table-recognition"], 50 | classifiers=[ 51 | "Programming Language :: Python :: 3.8", 52 | "Programming Language :: Python :: 3.9", 53 | "Programming Language :: Python :: 3.10", 54 | "Programming Language :: Python :: 3.11", 55 | ], 56 | python_requires=">=3.8,<3.13", 57 | ) 58 | -------------------------------------------------------------------------------- /setup_rapid_table_det_paddle.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | # @Author: Jocker1212 3 | # @Contact: xinyijianggo@gmail.com 
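# This script mirrors setup_rapid_table_det.py: the version is derived from
# the latest PyPI release plus one, unless a version string is found in the
# extra command-line arguments (the release workflows pass the pushed git
# tag, matching rapid_table_det_paddle_v*, from which the version is parsed).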
--------------------------------------------------------------------------------
/setup_rapid_table_det_paddle.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: Jocker1212
3 | # @Contact: xinyijianggo@gmail.com
4 | import sys
5 | from typing import List, Union
6 | from pathlib import Path
7 | from get_pypi_latest_version import GetPyPiLatestVersion
8 | 
9 | import setuptools
10 | 
11 | 
12 | def read_txt(txt_path: Union[Path, str]) -> List[str]:
13 |     with open(txt_path, "r", encoding="utf-8") as f:
14 |         data = [v.rstrip("\n") for v in f]
15 |     return data
16 | 
17 | 
18 | MODULE_NAME = "rapid_table_det_paddle"
19 | 
20 | obtainer = GetPyPiLatestVersion()
21 | try:
22 |     latest_version = obtainer(MODULE_NAME)
23 | except Exception:
24 |     latest_version = "0.0.0"
25 | 
26 | VERSION_NUM = obtainer.version_add_one(latest_version)
27 | 
28 | if len(sys.argv) > 2:
29 |     match_str = " ".join(sys.argv[2:])
30 |     matched_versions = obtainer.extract_version(match_str)
31 |     if matched_versions:
32 |         VERSION_NUM = matched_versions
33 | sys.argv = sys.argv[:2]
34 | 
35 | setuptools.setup(
36 |     name=MODULE_NAME,
37 |     version=VERSION_NUM,
38 |     platforms="Any",
39 |     description="Table detection with the original Paddle models",
40 |     long_description="Table detection with the original Paddle models",
41 |     author="jockerK",
42 |     author_email="xinyijianggo@gmail.com",
43 |     url="https://github.com/Joker1212/RapidTableDetection",
44 |     license="Apache-2.0",
45 |     install_requires=read_txt("requirements.txt"),
46 |     include_package_data=True,
47 |     packages=[MODULE_NAME, f"{MODULE_NAME}.models"],
48 |     keywords=["object-detection", "ocr", "table-recognition"],
49 |     classifiers=[
50 |         "Programming Language :: Python :: 3.8",
51 |         "Programming Language :: Python :: 3.9",
52 |         "Programming Language :: Python :: 3.10",
53 |         "Programming Language :: Python :: 3.11",
54 |     ],
55 |     python_requires=">=3.8,<3.13",
56 | )
57 | 
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidTableDetection/ce3e4887df65b6629a8ac7b6ce9983783fc48347/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_files/chip.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidTableDetection/ce3e4887df65b6629a8ac7b6ce9983783fc48347/tests/test_files/chip.jpg
--------------------------------------------------------------------------------
/tests/test_files/chip2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidTableDetection/ce3e4887df65b6629a8ac7b6ce9983783fc48347/tests/test_files/chip2.jpg
--------------------------------------------------------------------------------
/tests/test_files/doc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidTableDetection/ce3e4887df65b6629a8ac7b6ce9983783fc48347/tests/test_files/doc.png
--------------------------------------------------------------------------------
/tests/test_table_det.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 | 
4 | import pytest
5 | 
6 | 
7 | cur_dir = Path(__file__).resolve().parent
8 | root_dir = cur_dir.parent
9 | 
10 | sys.path.append(str(root_dir))
11 | test_file_dir = cur_dir / "test_files"
12 | 
13 | 
14 | @pytest.mark.parametrize(
15 |     "img_path, expected",
16 |     [("chip.jpg", 1), ("doc.png", 2)],
17 | )
18 | def test_input_normal(img_path, expected):
19 |     from rapid_table_det import TableDetector
20 | 
21 |     table_det = TableDetector()
22 |     img_path = test_file_dir / img_path
23 |     result, elapse = table_det(img_path)
24 |     assert len(result) == expected
25 | 
--------------------------------------------------------------------------------
/tests/test_table_det_paddle.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 | 
4 | import pytest
5 | 
6 | cur_dir = Path(__file__).resolve().parent
7 | root_dir = cur_dir.parent
8 | 
9 | sys.path.append(str(root_dir))
10 | test_file_dir = cur_dir / "test_files"
11 | 
12 | 
13 | @pytest.mark.parametrize(
14 |     "img_path, expected",
15 |     [("chip.jpg", 1), ("doc.png", 2)],
16 | )
17 | def test_input_normal(img_path, expected):
18 |     from rapid_table_det_paddle.inference import TableDetector
19 | 
20 |     table_det = TableDetector(
21 |         obj_model_path=f"{root_dir}/rapid_table_det_paddle/models/obj_det_paddle",
22 |         edge_model_path=f"{root_dir}/rapid_table_det_paddle/models/edge_det_paddle",
23 |         cls_model_path=f"{root_dir}/rapid_table_det_paddle/models/cls_det_paddle",
24 |         use_obj_det=True,
25 |         use_edge_det=True,
26 |         use_cls_det=True,
27 |     )
28 |     img_path = test_file_dir / img_path
29 |     result, elapse = table_det(img_path)
30 |     assert len(result) == expected
31 | 
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidTableDetection/ce3e4887df65b6629a8ac7b6ce9983783fc48347/tools/__init__.py
--------------------------------------------------------------------------------
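Both test modules pin the same contract: the detector is called directly on an image path and returns a result list plus an elapse value, and the fixtures expect one table in chip.jpg and two in doc.png (only the result count is asserted). A minimal usage sketch distilled from the ONNX-backed test above (the image path is illustrative):

from rapid_table_det import TableDetector

detector = TableDetector()  # default construction, as in tests/test_table_det.py
result, elapse = detector("tests/test_files/doc.png")
print(f"{len(result)} table(s) detected in {elapse}")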
--------------------------------------------------------------------------------
/tools/fix_onnx.py:
--------------------------------------------------------------------------------
1 | import onnx
2 | from onnx import helper
3 | 
4 | 
5 | def create_constant_tensor(name, value):
6 |     tensor = helper.make_tensor(
7 |         name=name, data_type=onnx.TensorProto.INT64, dims=[len(value)], vals=value
8 |     )
9 |     return tensor
10 | 
11 | 
12 | # Create a new Squeeze node; its axes are supplied as a constant input tensor
13 | def create_squeeze_node(name, input_name, axes):
14 |     const_tensor = create_constant_tensor(name + "_constant", axes)
15 |     squeeze_node = helper.make_node(
16 |         "Squeeze",
17 |         inputs=[input_name, name + "_constant"],
18 |         outputs=[input_name + "_squeezed"],
19 |         name=name + "_Squeeze",
20 |     )
21 |     return const_tensor, squeeze_node
22 | 
23 | 
24 | def fix_onnx(model_path):
25 |     model = onnx.load(model_path)
26 | 
27 |     # Delete the specified broken Squeeze nodes
28 |     deleted_nodes = ["p2o.Squeeze.3", "p2o.Squeeze.5"]
29 |     nodes_to_delete = [node for node in model.graph.node if node.name in deleted_nodes]
30 |     for node in nodes_to_delete:
31 |         model.graph.node.remove(node)
32 |     # Outputs of p2o.Gather.8 and p2o.Gather.10, resolved below
33 |     gather8_output = None
34 |     gather10_output = None
35 |     # Record the outputs of the 'p2o.Gather.0' and 'p2o.Gather.2' nodes
36 |     for node in model.graph.node:
37 |         if node.name == "p2o.Gather.0":
38 |             gather_output_name = node.output[0]
39 |             # break
40 |         if node.name == "p2o.Gather.2":
41 |             gather_output_name1 = node.output[0]
42 |             # break
43 | 
44 |     for node in model.graph.node:
45 |         if node.name == "p2o.Gather.8":
46 |             node.input[0] = gather_output_name
47 |             gather8_output = node.output[0]
48 |         elif node.name == "p2o.Gather.10":
49 |             node.input[0] = gather_output_name1
50 |             gather10_output = node.output[0]
51 | 
52 |     if gather8_output:
53 |         new_squeeze_components_8 = create_squeeze_node(
54 |             "p2o.Gather.8", gather8_output, [1]
55 |         )
56 |         model.graph.initializer.append(new_squeeze_components_8[0])  # register the axes constant
57 |         model.graph.node.append(new_squeeze_components_8[1])  # append the Squeeze node
58 | 
59 |     if gather10_output:
60 |         new_squeeze_components_10 = create_squeeze_node(
61 |             "p2o.Gather.10", gather10_output, [1]
62 |         )
63 |         model.graph.initializer.append(new_squeeze_components_10[0])  # register the axes constant
64 |         model.graph.node.append(new_squeeze_components_10[1])  # append the Squeeze node
65 | 
66 |     # Update the inputs of the nodes that depend on Gather.8 and Gather.10
67 |     for node in model.graph.node:
68 |         if node.name == "p2o.Cast.0":
69 |             node.input[0] = new_squeeze_components_8[1].output[0]
70 | 
71 |         if node.name == "p2o.Gather.12":
72 |             node.input[1] = new_squeeze_components_10[1].output[0]
73 | 
74 |     # Save the modified model in place
75 |     onnx.save(model, model_path)
76 | 
--------------------------------------------------------------------------------
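fix_onnx patches the graph and writes it back to the same path, so a structural check afterwards is cheap insurance against a silent failure. A minimal post-patch sanity check, sketched under the notebook's path layout below (the path is illustrative; onnx.checker raises ValidationError on a malformed graph):

import onnx

from fix_onnx import fix_onnx

model_path = "../rapid_table_det/models/obj_det.onnx"  # illustrative path
fix_onnx(model_path)
# Raises if the Squeeze rewrite left dangling inputs or outputs behind
onnx.checker.check_model(onnx.load(model_path))
print("patched graph is structurally valid")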
--------------------------------------------------------------------------------
/tools/onnx_transform.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 3,
6 |    "metadata": {},
7 |    "source": [
8 |     "!pip install paddle2onnx onnxruntime onnxslim onnxruntime-tools onnx pickleshare -i https://pypi.tuna.tsinghua.edu.cn/simple"
9 |    ],
10 |    "outputs": []
11 |   },
12 |   {
13 |    "cell_type": "code",
14 |    "metadata": {
15 |     "collapsed": false,
16 |     "jupyter": {
17 |      "outputs_hidden": false
18 |     },
19 |     "pycharm": {
20 |      "name": "#%%\n"
21 |     },
22 |     "ExecuteTime": {
23 |      "end_time": "2024-10-19T13:58:51.289307Z",
24 |      "start_time": "2024-10-19T13:58:31.510101Z"
25 |     }
26 |    },
27 |    "source": [
28 |     "!paddle2onnx --model_dir ../rapid_table_det_paddle/models --model_filename obj_det_paddle.pdmodel --params_filename obj_det_paddle.pdiparams --save_file ../rapid_table_det/models/obj_det.onnx --opset_version 16 --enable_onnx_checker True\n",
29 |     "!paddle2onnx --model_dir ../rapid_table_det_paddle/models --model_filename edge_det_paddle.pdmodel --params_filename edge_det_paddle.pdiparams --save_file ../rapid_table_det/models/edge_det.onnx --opset_version 16 --enable_onnx_checker True\n",
30 |     "!paddle2onnx --model_dir ../rapid_table_det_paddle/models --model_filename cls_det_paddle.pdmodel --params_filename cls_det_paddle.pdiparams --save_file ../rapid_table_det/models/cls_det.onnx --opset_version 16 --enable_onnx_checker True\n",
31 |     "\n",
32 |     "!onnxslim ../rapid_table_det/models/obj_det.onnx ../rapid_table_det/models/obj_det.onnx\n",
33 |     "!onnxslim ../rapid_table_det/models/edge_det.onnx ../rapid_table_det/models/edge_det.onnx\n",
34 |     "!onnxslim ../rapid_table_det/models/cls_det.onnx ../rapid_table_det/models/cls_det.onnx"
35 |    ],
36 |    "execution_count": 1,
37 |    "outputs": []
38 |   },
39 |   {
40 |    "cell_type": "code",
41 |    "metadata": {
42 |     "collapsed": false,
43 |     "jupyter": {
44 |      "outputs_hidden": false
45 |     },
46 |     "pycharm": {
47 |      "name": "#%%\n"
48 |     },
49 |     "ExecuteTime": {
50 |      "end_time": "2024-10-19T13:58:56.174983Z",
51 |      "start_time": "2024-10-19T13:58:55.580038Z"
52 |     }
53 |    },
54 |    "source": [
55 |     "from pathlib import Path\n",
56 |     "import onnx\n",
57 |     "from onnxruntime.quantization import quantize_dynamic, QuantType, quant_pre_process\n",
58 |     "def quantize_model(root_dir_str, model_dir, pre_fix):\n",
59 |     "\n",
60 |     "    original_model_path = f\"{pre_fix}.onnx\"\n",
61 |     "    quantized_model_path = f\"{pre_fix}_quantized.onnx\"\n",
62 |     "    # quantized_model_path = original_model_path\n",
63 |     "    original_model_path = f\"{root_dir_str}/{model_dir}/{original_model_path}\"\n",
64 |     "    quantized_model_path = f\"{root_dir_str}/{model_dir}/{quantized_model_path}\"\n",
65 |     "    quant_pre_process(original_model_path, quantized_model_path, auto_merge=True)\n",
66 |     "    # Run dynamic quantization\n",
67 |     "    quantize_dynamic(\n",
68 |     "        model_input=quantized_model_path,\n",
69 |     "        model_output=quantized_model_path,\n",
70 |     "        weight_type=QuantType.QUInt8\n",
71 |     "    )\n",
72 |     "\n",
73 |     "    # Validate the quantized model\n",
74 |     "    quantized_model = onnx.load(quantized_model_path)\n",
75 |     "    onnx.checker.check_model(quantized_model)\n",
76 |     "    print(\"Quantized model is valid.\")"
77 |    ],
78 |    "execution_count": 2,
79 |    "outputs": []
80 |   },
81 |   {
82 |    "cell_type": "code",
83 |    "metadata": {
84 |     "collapsed": false,
85 |     "jupyter": {
86 |      "outputs_hidden": false
87 |     },
88 |     "pycharm": {
89 |      "name": "#%%\n"
90 |     },
91 |     "ExecuteTime": {
92 |      "end_time": "2024-10-19T13:59:14.149803Z",
93 |      "start_time": "2024-10-19T13:58:59.542092Z"
94 |     }
95 |    },
96 |    "source": [
97 |     "root_dir_str = \"..\"\n",
98 |     "model_dir = \"rapid_table_det/models\"\n",
99 |     "quantize_model(root_dir_str, model_dir, \"obj_det\")\n",
100 |     "quantize_model(root_dir_str, model_dir, \"edge_det\")"
101 |    ],
102 |    "execution_count": 3,
103 |    "outputs": []
104 |   },
105 |   {
106 |    "cell_type": "code",
107 |    "metadata": {
108 |     "ExecuteTime": {
109 |      "end_time": "2024-10-19T13:59:19.984452Z",
110 |      "start_time": "2024-10-19T13:59:18.181521Z"
111 |     }
112 |    },
113 |    "source": [
114 |     "from fix_onnx import fix_onnx\n",
115 |     "import os\n",
116 |     "# Directory holding the exported models\n",
117 |     "model_dir = \"../rapid_table_det/models\"\n",
118 |     "# Load the existing ONNX models and patch them in place\n",
119 |     "model_path = os.path.join(model_dir, \"obj_det.onnx\")\n",
120 |     "fix_onnx(model_path)\n",
121 |     "model_path = os.path.join(model_dir, \"obj_det_quantized.onnx\")\n",
122 |     "fix_onnx(model_path)"
123 |    ],
124 |    "execution_count": 4,
125 |    "outputs": []
126 |   },
127 |   {
128 |    "cell_type": "code",
129 |    "execution_count": null,
130 |    "metadata": {},
131 |    "source": [],
132 |    "outputs": []
133 |   }
134 |  ],
135 |  "metadata": {
136 |   "kernelspec": {
137 |    "display_name": "Python 3 (ipykernel)",
138 |    "language": "python",
139 |    "name": "python3"
140 |   },
141 |   "language_info": {
142 |    "codemirror_mode": {
143 |     "name": "ipython",
144 |     "version": 3
145 |    },
146 |    "file_extension": ".py",
147 |    "mimetype": "text/x-python",
148 |    "name": "python",
149 |    "nbconvert_exporter": "python",
150 |    "pygments_lexer": "ipython3",
151 |    "version": "3.10.14"
152 |   }
153 |  },
154 |  "nbformat": 4,
155 |  "nbformat_minor": 4
156 | }
157 | 
--------------------------------------------------------------------------------
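As a final smoke test after the paddle2onnx conversion, onnxslim pass, quantization, and fix_onnx patching, each artifact should at least load into an onnxruntime session. A short sketch under the notebook's relative path layout (paths and the model list are illustrative; only session creation and input names are checked):

import onnxruntime as ort

# cls_det is converted but never quantized above, so it keeps its original name
for name in ("obj_det_quantized", "edge_det_quantized", "cls_det"):
    path = f"../rapid_table_det/models/{name}.onnx"
    sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    print(name, "->", [inp.name for inp in sess.get_inputs()])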