├── .github
│   ├── FUNDING.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── 01-feature_request.md
│   │   ├── 02-bug.md
│   │   └── 03-blank.md
│   └── workflows
│       ├── SyncToGitee.yml
│       └── publish_whl.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── cliff.toml
├── demo.py
├── docs
│   └── doc_whl_rapid_table.md
├── rapid_table
│   ├── __init__.py
│   ├── default_models.yaml
│   ├── engine_cfg.yaml
│   ├── inference_engine
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── onnxruntime
│   │   │   ├── __init__.py
│   │   │   ├── main.py
│   │   │   └── provider_config.py
│   │   └── torch.py
│   ├── main.py
│   ├── model_processor
│   │   ├── __init__.py
│   │   └── main.py
│   ├── models
│   │   └── .gitkeep
│   ├── table_matcher
│   │   ├── __init__.py
│   │   ├── main.py
│   │   └── utils.py
│   ├── table_structure
│   │   ├── __init__.py
│   │   ├── pp_structure
│   │   │   ├── __init__.py
│   │   │   ├── main.py
│   │   │   ├── post_process.py
│   │   │   └── pre_process.py
│   │   ├── unitable
│   │   │   ├── __init__.py
│   │   │   ├── consts.py
│   │   │   ├── main.py
│   │   │   ├── post_process.py
│   │   │   ├── pre_process.py
│   │   │   └── unitable_modules.py
│   │   └── utils.py
│   └── utils
│       ├── __init__.py
│       ├── download_file.py
│       ├── load_image.py
│       ├── logger.py
│       ├── typings.py
│       ├── utils.py
│       └── vis.py
├── requirements.txt
├── setup.py
└── tests
    ├── test_files
    │   ├── table.jpg
    │   └── table_without_txt.jpg
    ├── test_main.py
    └── test_table_matcher.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
4 | patreon: # Replace with a single Patreon username
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: # Replace with a single Ko-fi username
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # Replace with a single IssueHunt username
11 | otechie: # Replace with a single Otechie username
12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
13 | custom: https://raw.githubusercontent.com/RapidAI/.github/6db6b6b9273f3151094a462a61fbc8e88564562c/assets/Sponsor.png
14 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/01-feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature Request
3 | about: requests for new RapidTable features
4 | title: 'Feature Request'
5 | labels: 'Feature Request'
6 | assignees: ''
7 |
8 | ---
9 |
10 | 请您详细描述想要添加的新功能或者是新特性
11 | (Please describe in detail the new function or new feature you want to add)
12 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/02-bug.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug
3 | about: Bug
4 | title: 'Bug'
5 | labels: 'Bug'
6 | assignees: ''
7 |
8 | ---
9 |
10 | 请提供下述完整信息以便快速定位问题
11 | (Please provide the following information to quickly locate the problem)
12 | - **系统环境/System Environment**:
13 | - **使用的是哪门语言的程序/Which programming language**:
14 | - **所使用语言相关版本信息/Version**:
15 | - **OnnxRuntime版本/OnnxRuntime Version**:
16 | - **可复现问题的demo/Demo of reproducible problems**:
17 | - **完整报错/Complete Error Message**:
18 | - **可能的解决方案/Possible solutions**:
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/03-blank.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Blank Template
3 | about: Blank Template
4 | title: 'Blank Template'
5 | labels: 'Blank Template'
6 | assignees: ''
7 |
8 | ---
--------------------------------------------------------------------------------
/.github/workflows/SyncToGitee.yml:
--------------------------------------------------------------------------------
1 | name: SyncToGitee
2 | on:
3 | push:
4 | branches:
5 | - main
6 | jobs:
7 | repo-sync:
8 | runs-on: ubuntu-latest
9 | steps:
10 | - name: Checkout source codes
11 | uses: actions/checkout@v3
12 |
13 | - name: Mirror the Github organization repos to Gitee.
14 | uses: Yikun/hub-mirror-action@v1.4
15 | with:
16 | src: 'github/RapidAI'
17 | dst: 'gitee/RapidAI'
18 | dst_key: ${{ secrets.GITEE_PRIVATE_KEY }}
19 | dst_token: ${{ secrets.GITEE_TOKEN }}
20 | force_update: true
21 | static_list: "RapidTable"
22 | debug: true
23 |
--------------------------------------------------------------------------------
/.github/workflows/publish_whl.yml:
--------------------------------------------------------------------------------
1 | name: Push rapidocr_table to pypi
2 |
3 | on:
4 | push:
5 | tags:
6 | - v*
7 |
8 | env:
9 | DEFAULT_MODEL: https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/slanet-plus.onnx
10 |
11 | jobs:
12 | UnitTesting:
13 | runs-on: ubuntu-latest
14 | steps:
15 | - name: Pull latest code
16 | uses: actions/checkout@v3
17 |
18 | - name: Set up Python 3.10
19 | uses: actions/setup-python@v4
20 | with:
21 | python-version: '3.10'
22 | architecture: 'x64'
23 |
24 | - name: Display Python version
25 | run: python -c "import sys; print(sys.version)"
26 |
27 | - name: Unit testing
28 | run: |
29 | wget $DEFAULT_MODEL -P rapid_table/models
30 |
31 | pip install -r requirements.txt
32 | pip install rapidocr onnxruntime torch torchvision tokenizers pytest
33 | pytest tests/*.py
34 |
35 | GenerateWHL_PushPyPi:
36 | needs: UnitTesting
37 | runs-on: ubuntu-latest
38 |
39 | steps:
40 | - uses: actions/checkout@v3
41 |
42 | - name: Set up Python 3.10
43 | uses: actions/setup-python@v4
44 | with:
45 | python-version: '3.10'
46 | architecture: 'x64'
47 |
48 | - name: Run setup
49 | run: |
50 | pip install -r requirements.txt
51 | python -m pip install --upgrade pip
52 | pip install wheel get_pypi_latest_version
53 |
54 | wget $DEFAULT_MODEL -P rapid_table/models
55 | python setup.py bdist_wheel ${{ github.ref_name }}
56 |
57 | - name: Publish distribution 📦 to PyPI
58 | uses: pypa/gh-action-pypi-publish@v1.5.0
59 | with:
60 | password: ${{ secrets.RAPID_TABLE }}
61 | packages_dir: dist/
62 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | outputs/
2 | *.json
3 |
4 | # Created by .ignore support plugin (hsz.mobi)
5 | ### Python template
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 | .pytest_cache
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | pip-wheel-metadata/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | # *.manifest
40 | # *.spec
41 | *.res
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .nox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | *.py,cover
58 | .hypothesis/
59 | .pytest_cache/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 | db.sqlite3-journal
70 |
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 |
75 | # Scrapy stuff:
76 | .scrapy
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 |
81 | # PyBuilder
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
102 | __pypackages__/
103 |
104 | # Celery stuff
105 | celerybeat-schedule
106 | celerybeat.pid
107 |
108 | # SageMath parsed files
109 | *.sage.py
110 |
111 | # Environments
112 | .env
113 | .venv
114 | env/
115 | venv/
116 | ENV/
117 | env.bak/
118 | venv.bak/
119 |
120 | # Spyder project settings
121 | .spyderproject
122 | .spyproject
123 |
124 | # Rope project settings
125 | .ropeproject
126 |
127 | # mkdocs documentation
128 | /site
129 |
130 | # mypy
131 | .mypy_cache/
132 | .dmypy.json
133 | dmypy.json
134 |
135 | # Pyre type checker
136 | .pyre/
137 |
138 | #idea
139 | .vs
140 | .vscode
141 | .idea
142 | /images
143 | /models
144 |
145 | #models
146 | *.onnx
147 |
148 | *.ttf
149 | *.ttc
150 |
151 | long1.jpg
152 |
153 | *.bin
154 | *.mapping
155 | *.xml
156 |
157 | *.pdiparams
158 | *.pdiparams.info
159 | *.pdmodel
160 |
161 | .DS_Store
162 | *.pth
163 | /rapid_table_torch/models/*.pth
164 | /rapid_table_torch/models/*.json
165 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://gitee.com/SWHL/autoflake
3 | rev: v2.1.1
4 | hooks:
5 | - id: autoflake
6 | args:
7 | [
8 | "--recursive",
9 | "--in-place",
10 | "--remove-all-unused-imports",
11 | "--ignore-init-module-imports",
12 | ]
13 | files: \.py$
14 | - repo: https://gitee.com/SWHL/black
15 | rev: 23.1.0
16 | hooks:
17 | - id: black
18 | files: \.py$
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2025 RapidAI
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | <div align="center">
2 |   <h1>📊 Rapid Table</h1>
3 | </div>
4 |
17 | ### 🌟 Introduction
18 |
19 | The RapidTable library focuses on restoring the table structure of tables in document images. Its table-structure models are all sequence-prediction methods; combined with RapidOCR, it converts tables in a given image into the corresponding HTML.
20 |
21 | slanet_plus is an upgraded version of the SLANet model built into PaddleX, with a large improvement in accuracy.
22 |
23 | unitable is a Transformer model from UniTable with the highest accuracy. It currently supports PyTorch inference only (GPU acceleration is supported); its trained weights come from the [OhMyTable project](https://github.com/Sanster/OhMyTable).
24 |
25 | ### 📅 Recent updates
26 |
27 | 2025-08-29 update: released v3.0.0 with batch inference support; the return value changed, so debug a single image to inspect the new format \
28 | 2025-06-22 update: released v2.x, adapted to rapidocr v3.x \
29 | 2025-01-09 update: released v1.x with a brand-new interface \
30 | 2024-12-30 update: added table recognition with the Unitable model, using the PyTorch framework \
31 | 2024-11-24 update: added GPU inference, adapted RapidOCR single-character recognition matching, and added logical-coordinate output and visualization \
32 | 2024-10-13 update: added the latest paddlex-SLANet-plus model (ONNX not yet supported due to paddle2onnx)
33 |
34 | ### 📸 Demo
35 |
36 | (demo animation omitted)
37 |
40 | ### 🖥️ Supported devices
41 |
42 | Supported through the ONNXRuntime inference engine (`rapid_table>=2.0.0`):
43 |
44 | - DirectML
45 | - Ascend NPU
46 |
47 | Usage:
48 |
49 | 1. Install (uninstall any other onnxruntime packages first):
50 |
51 | ```bash
52 | # DirectML
53 | pip install onnxruntime-directml
54 |
55 | # Ascend NPU
56 | pip install onnxruntime-cann
57 | ```
58 |
59 | 2. Use:
60 |
61 | ```python
62 | from rapidocr import RapidOCR
63 |
64 | from rapid_table import ModelType, RapidTable, RapidTableInput
65 |
66 | # DirectML
67 | ocr_engine = RapidOCR(params={"EngineConfig.onnxruntime.use_dml": True})
68 | input_args = RapidTableInput(
69 | model_type=ModelType.SLANETPLUS, engine_cfg={"use_dml": True}
70 | )
71 |
72 | # Ascend NPU
73 | ocr_engine = RapidOCR(params={"EngineConfig.onnxruntime.use_cann": True})
74 |
75 | input_args = RapidTableInput(
76 | model_type=ModelType.SLANETPLUS,
77 | engine_cfg={"use_cann": True, "cann_ep_cfg.gpu_id": 1},
78 | )
79 |
80 | table_engine = RapidTable(input_args)
81 |
82 | img_path = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"
83 |
84 | ori_ocr_res = ocr_engine(img_path)
85 | ocr_results = [ori_ocr_res.boxes, ori_ocr_res.txts, ori_ocr_res.scores]
86 |
87 | results = table_engine(img_path, ocr_results=ocr_results)
88 | results.vis(save_dir="outputs", save_name="vis")
89 | ```
90 |
91 | ### 🧩 Model list
92 |
93 | | `model_type` | Model name | Inference engine | Model size | Inference time (single 60 KB image) |
94 | |:--------------|:--------------------------------------| :------: |:------ |:------ |
95 | | `ppstructure_en` | `en_ppstructure_mobile_v2_SLANet.onnx` | onnxruntime |7.3M |0.15s |
96 | | `ppstructure_zh` | `ch_ppstructure_mobile_v2_SLANet.onnx` | onnxruntime |7.4M |0.15s |
97 | | `slanet_plus` | `slanet-plus.onnx` | onnxruntime |6.8M |0.15s |
98 | | `unitable` | `unitable(encoder.pth,decoder.pth)` | pytorch |500M |cpu(6s) gpu-4090(1.5s)|
99 |
100 | Model sources\
101 | [PaddleOCR table recognition](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/table/README_ch.md)\
102 | [PaddleX-SLANetPlus table recognition](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/module_usage/tutorials/ocr_modules/table_structure_recognition.md)\
103 | [Unitable](https://github.com/poloclub/unitable?tab=readme-ov-file)
104 |
105 | Model downloads: [link](https://www.modelscope.cn/models/RapidAI/RapidTable/files)
106 |
107 | ### 🛠️ Installation
108 |
109 | Version compatibility:
110 |
111 | |`rapid_table`|OCR|
112 | |:---:|:---|
113 | |v2.x & v3.x |`rapidocr>=3.0.0`|
114 | |v1.0.x|`rapidocr>=2.0.0,<3.0.0`|
115 | |v0.x|`rapidocr_onnxruntime`|
116 |
117 | Because it is small, the slanet-plus table recognition model (`slanet-plus.onnx`) is bundled into the whl package. The other models are downloaded automatically, based on `model_type`, into the `models` directory of the installed package when the `RapidTable` class is initialized.
118 |
119 | You can also specify your own model path with `RapidTableInput(model_path='')` (in `v1.0.x` the parameter is named `model_path`; in `v2.x` it was renamed to `model_dir_or_path`). Note that this only works for the `model_type`s we currently support.
120 |
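   | A minimal sketch of pointing at a locally downloaded model; the file path here is hypothetical and assumes you fetched the model yourself (e.g. from the ModelScope link above):
   |
   | ```python
   | from rapid_table import ModelType, RapidTable, RapidTableInput
   |
   | # hypothetical local path to a model downloaded from ModelScope
   | input_args = RapidTableInput(
   |     model_type=ModelType.SLANETPLUS,
   |     model_dir_or_path="models/slanet-plus.onnx",
   | )
   | table_engine = RapidTable(input_args)
   | ```
   |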
121 | > ⚠️ Note: since `rapid_table>=v1.0.0`, the `rapidocr` dependency is no longer bundled into `rapid_table`. Install the `rapidocr` package yourself before use.
122 | >
123 | > ⚠️ Note: for `rapid_table>=v0.1.0,<1.0.0`, the `rapidocr` dependency is no longer bundled. Install the `rapidocr_onnxruntime` package yourself before use.
124 |
125 | ```bash
126 | pip install rapidocr
127 | pip install rapid_table
128 |
129 | # unitable inference based on torch
130 | pip install rapid_table[torch] # for unitable inference
131 |
132 | # onnxruntime-gpu inference
133 | pip uninstall onnxruntime
134 | pip install onnxruntime-gpu # for onnx gpu inference
135 | ```
136 |
137 | ### 🚀 Usage
138 |
139 | #### 🐍 Run as a Python script
140 |
141 | > ⚠️ Note: since `rapid_table>=1.0.0`, model inputs are wrapped in dataclasses to simplify parameter passing and keep it compatible. The input and output are defined as follows:
142 |
143 | ModelType supports the four existing models ([source](./rapid_table/utils/typings.py)):
144 |
145 | ```python
146 | class ModelType(Enum):
147 | PPSTRUCTURE_EN = "ppstructure_en" # onnxruntime
148 | PPSTRUCTURE_ZH = "ppstructure_zh" # onnxruntime
149 | SLANETPLUS = "slanet_plus" # onnxruntime
150 | UNITABLE = "unitable" # torch inference engine
151 | ```
152 |
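   | `RapidTableInput` wraps all input parameters. Below is a sketch of its fields as they are used elsewhere in this repo (see `rapid_table/main.py`); the defaults shown here are illustrative, so check the source for the authoritative definition:
   |
   | ```python
   | from dataclasses import dataclass, field
   | from pathlib import Path
   | from typing import Any, Dict, Optional, Union
   |
   | @dataclass
   | class RapidTableInput:
   |     model_type: Optional[ModelType] = ModelType.SLANETPLUS
   |     model_dir_or_path: Union[str, Path, None] = None  # auto-downloaded when None
   |     engine_cfg: Dict[str, Any] = field(default_factory=dict)  # see engine_cfg.yaml
   |     use_ocr: bool = True  # run OCR internally when no ocr_results are passed in
   |     ocr_params: Dict[Any, Any] = field(default_factory=dict)  # forwarded to RapidOCR
   | ```
   |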
153 | #### Batch inference
154 |
155 | ```python
156 | from pathlib import Path
157 |
158 | from rapid_table import ModelType, RapidTable, RapidTableInput
159 |
160 | input_args = RapidTableInput(model_type=ModelType.PPSTRUCTURE_ZH)
161 | table_engine = RapidTable(input_args)
162 |
163 | img_list = list(Path("images").iterdir())
164 | results = table_engine(img_list, batch_size=3)  # batch_size defaults to 1
165 |
166 | # indexes: the image indexes to visualize. Defaults to 0
167 | results.vis(save_dir="outputs", save_name="vis", indexes=(0, 1, 2))
168 | ```
169 |
170 | ##### CPU usage
171 |
172 | ```python
173 | from rapid_table import ModelType, RapidTable, RapidTableInput
174 |
175 | input_args = RapidTableInput(model_type=ModelType.UNITABLE)
176 | table_engine = RapidTable(input_args)
177 |
178 | img_path = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"
180 | results = table_engine(img_path)
181 | results.vis(save_dir="outputs", save_name="vis")
182 | ```
183 |
184 | ##### GPU usage
185 |
186 | > The parameters in `engine_cfg` correspond to [`engine_cfg.yaml`](https://github.com/RapidAI/RapidTable/blob/6da3974a35ac5da8a5cf58194eab00b6886212e8/rapid_table/engine_cfg.yaml).
187 |
188 | ```python
189 | from rapid_table import ModelType, RapidTable, RapidTableInput
190 |
191 | # onnxruntime-gpu
192 | input_args = RapidTableInput(
193 | model_type=ModelType.SLANETPLUS,
194 | engine_cfg={"use_cuda": True, "cuda_ep_cfg.gpu_id": 1}
195 | )
196 |
197 | # torch gpu
198 | # input_args = RapidTableInput(
199 | # model_type=ModelType.UNITABLE,
200 | # engine_cfg={"use_cuda": True, "gpu_id": 1},
201 | # )
202 |
203 | table_engine = RapidTable(input_args)
204 |
205 | img_path = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"
206 | results = table_engine(img_path)
207 | results.vis(save_dir="outputs", save_name="vis")
208 | ```
209 |
210 | #### 📦 Run from the terminal
211 |
212 | ```bash
213 | rapid_table https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg -v
214 | ```
215 |
216 | ### 📝 Results
217 |
218 | #### 📎 Returned result
219 |
220 |
221 |
222 | ```html
223 | <html>
224 | <body>
225 | <table>
226 | <tr>
227 | <td>Methods</td>
228 | <td></td>
229 | <td></td>
230 | <td></td>
231 | <td>FPS</td>
232 | </tr>
233 | <tr>
234 | <td>SegLink [26]</td>
235 | <td>70.0</td>
236 | <td>86d></td>
237 | <td>77.0</td>
238 | <td>8.9</td>
239 | </tr>
240 | <tr>
241 | <td>PixelLink [4]</td>
242 | <td>73.2</td>
243 | <td>83.0</td>
244 | <td>77.8</td>
245 | <td></td>
246 | </tr>
247 | <tr>
248 | <td>TextSnake [18]</td>
249 | <td>73.9</td>
250 | <td>83.2</td>
251 | <td>78.3</td>
252 | <td>1.1</td>
253 | </tr>
254 | <tr>
255 | <td>TextField [37]</td>
256 | <td>75.9</td>
257 | <td>87.4</td>
258 | <td>81.3</td>
259 | <td>5.2</td>
260 | </tr>
261 | <tr>
262 | <td>MSR[38]</td>
263 | <td>76.7</td>
264 | <td>87.87.4</td>
265 | <td>81.7</td>
266 | <td></td>
267 | </tr>
268 | <tr>
269 | <td>FTSN [3]</td>
270 | <td>77.1</td>
271 | <td>87.6</td>
272 | <td>82.0</td>
273 | <td></td>
274 | </tr>
275 | <tr>
276 | <td>LSE[30]</td>
277 | <td>81.7</td>
278 | <td>84.2</td>
279 | <td>82.9</td>
280 | <td><></td>
281 | </tr>
282 | <tr>
283 | <td>CRAFT [2]</td>
284 | <td>78.2</td>
285 | <td>88.2</td>
286 | <td>82.9</td>
287 | <td>8.6</td>
288 | </tr>
289 | <tr>
290 | <td>MCN[16]</td>
291 | <td>79</td>
292 | <td>88</td>
293 | <td>83</td>
294 | <td></td>
295 | </tr>
296 | <tr>
297 | <td>ATRR>[35]</td>
298 | <td>82.1</td>
299 | <td>85.2</td>
300 | <td>83.6</td>
301 | <td></td>
302 | </tr>
303 | <tr>
304 | <td>PAN [34]</td>
305 | <td>83.8</td>
306 | <td>84.4</td>
307 | <td>84.1</td>
308 | <td>30.2</td>
309 | </tr>
310 | <tr>
311 | <td>DB[12]</td>
312 | <td>79.2</td>
313 | <td>91.5</td>
314 | <td>84.9</td>
315 | <td>32.0</td>
316 | </tr>
317 | <tr>
318 | <td>DRRG[41]</td>
319 | <td>82.30</td>
320 | <td>88.05</td>
321 | <td>85.08</td>
322 | <td></td>
323 | </tr>
324 | <tr>
325 | <td>Ours (SynText)</td>
326 | <td>80.68</td>
327 | <td>85</td>
328 | <td>82.97</td>
329 | <td>12.68</td>
330 | <td></td>
331 | </tr>
332 | <tr>
333 | <td>Ours (MLT-17)</td>
334 | <td>84.54</td>
335 | <td>86.62</td>
336 | <td>85.57</td>
337 | <td>12.31</td>
338 | </tr>
339 | </table>
340 | </body>
341 | </html>
347 | ```
348 |
349 |
350 |
351 | #### 🖼️ Visualization
352 |
353 |
354 |
(rendered table omitted)
355 |
356 |
357 |
358 | ### 🔄 Relationship with [TableStructureRec](https://github.com/RapidAI/TableStructureRec)
359 |
360 | TableStructureRec is a collection of table recognition algorithms. It currently provides inference packages for `wired_table_rec` (wired-table recognition) and `lineless_table_rec` (lineless-table recognition).
361 |
362 | RapidTable was organized from the table-recognition part of PP-Structure. Because PP-Structure came earlier, this library ended up named `rapid_table`.
363 |
364 | In short, RapidTable and TableStructureRec are both table recognition repositories. Try them and use whichever works better for you. Since the algorithms differ quite a bit, we do not plan to unify them for now.
365 |
366 | For a comparison of table recognition algorithms, see the [TableStructureRec benchmark](https://github.com/RapidAI/TableStructureRec#指标结果)
367 |
368 | ### 📌 Changelog ([more](https://github.com/RapidAI/RapidTable/releases))
369 |
370 |
371 |
372 | #### 2024-12-30 update
373 |
374 | - Added table recognition with the Unitable model, using the PyTorch framework
375 |
376 | #### 2024-11-24 update
377 |
378 | - Added GPU inference, adapted RapidOCR single-character recognition matching, and added logical-coordinate output and visualization
379 |
380 | #### 2024-10-13 update
381 |
382 | - Added the latest paddlex-SLANet-plus model (ONNX not yet supported due to paddle2onnx)
383 |
384 | #### 2023-12-29 v0.1.3 update
385 |
386 | - Improved the visualization of results
387 |
388 | #### 2023-12-27 v0.1.2 update
389 |
390 | - Added a parameter for returning cell bounding boxes
391 | - Improved the visualization function
392 |
393 | #### 2023-07-17 v0.1.0 update
394 |
395 | - Decoupled the `rapidocr_onnxruntime` part from `rapid_table`, making the dependency optional and more flexible.
396 |
397 | - Added the interface input parameter `ocr_result`:
398 | - If `ocr_result` is provided in the call, OCR is not run again. Its format must match the return value of `rapidocr_onnxruntime`.
399 | - If `ocr_result` is not provided but the `rapidocr_onnxruntime` library is installed, that library is invoked automatically for recognition.
400 | - If `ocr_result` is not provided and `rapidocr_onnxruntime` is not installed, an error is raised. One of the two conditions must be met.
401 |
402 | #### 2023-07-10 v0.0.13 update
403 |
404 | - Changed the OCR instance interface used in table restoration; other OCR instances can be passed in, provided their interface matches that of `rapidocr_onnxruntime`
405 |
406 | #### 2023-07-06 v0.0.12 update
407 |
408 | - Removed the `` element from the returned table HTML string for easier downstream unification.
409 | - Formatted the code with Black
410 |
411 |
412 |
--------------------------------------------------------------------------------
/cliff.toml:
--------------------------------------------------------------------------------
1 | # git-cliff ~ configuration file
2 | # https://git-cliff.org/docs/configuration
3 |
4 | [changelog]
5 | # A Tera template to be rendered as the changelog's header.
6 | # See https://keats.github.io/tera/docs/#introduction
7 | # header = """
8 | # # Changelog\n
9 | # All notable changes to this project will be documented in this file. See [conventional commits](https://www.conventionalcommits.org/) for commit guidelines.\n
10 | # """
11 | # A Tera template to be rendered for each release in the changelog.
12 | # See https://keats.github.io/tera/docs/#introduction
13 | body = """
14 | {% for group, commits in commits | group_by(attribute="group") %}
15 | ### {{ group | striptags | trim | upper_first }}
16 | {% for commit in commits
17 | | filter(attribute="scope")
18 | | sort(attribute="scope") %}
19 | - **({{commit.scope}})**{% if commit.breaking %} [**breaking**]{% endif %} \
20 | {{ commit.message }} by [@{{ commit.author.name }}](https://github.com/{{ commit.author.name }}) in [{{ commit.id | truncate(length=7, end="") }}]($REPO/commit/{{ commit.id }})
21 | {%- endfor -%}
22 | {% raw %}\n{% endraw %}\
23 | {%- for commit in commits %}
24 | {%- if commit.scope -%}
25 | {% else -%}
26 | - {% if commit.breaking %} [**breaking**]{% endif %}\
27 | {{ commit.message }} by [@{{ commit.author.name }}](https://github.com/{{ commit.author.name }}) in [{{ commit.id | truncate(length=7, end="") }}]($REPO/commit/{{ commit.id }})
28 | {% endif -%}
29 | {% endfor -%}
30 | {% endfor %}
31 |
32 |
33 | {% if github.contributors | length > 0 %}
34 | ### 🎉 Contributors
35 |
36 | {% for contributor in github.contributors %}
37 | - [@{{ contributor.username }}](https://github.com/{{ contributor.username }})
38 | {%- endfor -%}
39 | {% endif %}
40 |
41 |
42 | {% if version %}
43 | {% if previous.version %}\
44 | **Full Changelog**: [{{ version | trim_start_matches(pat="v") }}]($REPO/compare/{{ previous.version }}..{{ version }})
45 | {% else %}\
46 | **Full Changelog**: [{{ version | trim_start_matches(pat="v") }}]
47 | {% endif %}\
48 | {% else %}\
49 | ## [unreleased]
50 | {% endif %}
51 | """
52 | # A Tera template to be rendered as the changelog's footer.
53 | # See https://keats.github.io/tera/docs/#introduction
54 |
55 | footer = """
56 |
57 | """
58 |
59 | # Remove leading and trailing whitespaces from the changelog's body.
60 | trim = true
61 | # postprocessors
62 | postprocessors = [
63 | # Replace the placeholder `$REPO` with a URL.
64 | { pattern = '\$REPO', replace = "https://github.com/RapidAI/RapidTable" }, # replace repository URL
65 | ]
66 |
67 | [git]
68 | # Parse commits according to the conventional commits specification.
69 | # See https://www.conventionalcommits.org
70 | conventional_commits = true
71 | # Exclude commits that do not match the conventional commits specification.
72 | filter_unconventional = true
73 | # Split commits on newlines, treating each line as an individual commit.
74 | split_commits = false
75 | # An array of regex based parsers to modify commit messages prior to further processing.
76 | commit_preprocessors = [
77 | # Replace issue numbers with link templates to be updated in `changelog.postprocessors`.
78 | #{ pattern = '\((\w+\s)?#([0-9]+)\)', replace = "([#${2}](https://github.com/orhun/git-cliff/issues/${2}))"},
79 | ]
80 | # An array of regex based parsers for extracting data from the commit message.
81 | # Assigns commits to groups.
82 | # Optionally sets the commit's scope and can decide to exclude commits from further processing.
83 | commit_parsers = [
84 | { message = "^feat", group = "🚀 Features" },
85 | { message = "^fix", group = "🐛 Bug Fixes" },
86 | { message = "^doc", group = "📚 Documentation" },
87 | { message = "^perf", group = "⚡ Performance" },
88 | { message = "^refactor", group = "🚜 Refactor" },
89 | { message = "^style", group = "🎨 Styling" },
90 | { message = "^test", group = "🧪 Testing" },
91 | { message = "^chore\\(release\\): prepare for", skip = true },
92 | { message = "^chore\\(deps.*\\)", skip = true },
93 | { message = "^chore\\(pr\\)", skip = true },
94 | { message = "^chore\\(pull\\)", skip = true },
95 | { message = "^chore|^ci", group = "⚙️ Miscellaneous Tasks" },
96 | { body = ".*security", group = "🛡️ Security" },
97 | { message = "^revert", group = "◀️ Revert" },
98 | { message = ".*", group = "💼 Other" },
99 | ]
100 | # Exclude commits that are not matched by any commit parser.
101 | filter_commits = false
102 | # Order releases topologically instead of chronologically.
103 | topo_order = false
104 | # Order of commits in each group/release within the changelog.
105 | # Allowed values: newest, oldest
106 | sort_commits = "newest"
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from pathlib import Path
5 |
6 | from rapid_table import ModelType, RapidTable, RapidTableInput
7 |
8 | # input_args = RapidTableInput(
9 | # model_type=ModelType.UNITABLE,
10 | # engine_cfg={"use_cuda": True, "gpu_id": 1},
11 | # )
12 | input_args = RapidTableInput(model_type=ModelType.PPSTRUCTURE_ZH)
13 | table_engine = RapidTable(input_args)
14 |
15 | img_list = list(Path("images").iterdir())
16 | results = table_engine(img_list, batch_size=3)
17 | results.vis(save_dir="outputs", save_name="vis", indexes=(0, 1, 2))
18 |
--------------------------------------------------------------------------------
/docs/doc_whl_rapid_table.md:
--------------------------------------------------------------------------------
1 | ### For details, see [Rapid Table](https://github.com/RapidAI/RapidTable)
2 |
--------------------------------------------------------------------------------
/rapid_table/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from .main import RapidTable, RapidTableInput
5 | from .utils import EngineType, ModelType, VisTable
6 |
--------------------------------------------------------------------------------
/rapid_table/default_models.yaml:
--------------------------------------------------------------------------------
1 | ppstructure_en:
2 | model_dir_or_path: https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/v2.0.0/en_ppstructure_mobile_v2_SLANet.onnx
3 | SHA256: 2cae17d16a16f9df7229e21665fe3fbe06f3ca85b2024772ee3e3142e955aa60
4 |
5 | ppstructure_zh:
6 | model_dir_or_path: https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/v2.0.0/ch_ppstructure_mobile_v2_SLANet.onnx
7 | SHA256: ddfc6c97ee4db2a5e9de4de8b6a14508a39d42d228503219fdfebfac364885e3
8 |
9 | slanet_plus:
10 | model_dir_or_path: https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/v2.0.0/slanet-plus.onnx
11 | SHA256: d57a942af6a2f57d6a4a0372573c696a2379bf5857c45e2ac69993f3b334514b
12 |
13 | unitable:
14 | model_dir_or_path: https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/v2.0.0/unitable
15 | SHA256:
16 | encoder.pth: 2c66b3c6a3d1c86a00985bab2cd79412fc2b668ff39d338bc3c63d383b08684d
17 | decoder.pth: fa342ef3de259576a01a5545ede804208ef35a124935e30df4768e6708dcb6cb
18 | vocab.json: 05037d02c48d106639bc90284aa847e5e2151d4746b3f5efe1628599efbd668a
19 |
20 |
--------------------------------------------------------------------------------
/rapid_table/engine_cfg.yaml:
--------------------------------------------------------------------------------
1 | onnxruntime:
2 | intra_op_num_threads: -1
3 | inter_op_num_threads: -1
4 | enable_cpu_mem_arena: false
5 |
6 | cpu_ep_cfg:
7 | arena_extend_strategy: "kSameAsRequested"
8 |
9 | use_cuda: false
10 | cuda_ep_cfg:
11 | gpu_id: 0
12 | arena_extend_strategy: "kNextPowerOfTwo"
13 | cudnn_conv_algo_search: "EXHAUSTIVE"
14 | do_copy_in_default_stream: true
15 |
16 | use_dml: false
17 | dm_ep_cfg: null
18 |
19 | use_cann: false
20 | cann_ep_cfg:
21 | gpu_id: 0
22 | arena_extend_strategy: "kNextPowerOfTwo"
23 | npu_mem_limit: 21474836480 # 20 * 1024 * 1024 * 1024
24 | op_select_impl_mode: "high_performance"
25 | optypelist_for_implmode: "Gelu"
26 | enable_cann_graph: true
27 |
28 | openvino:
29 | inference_num_threads: -1
30 |
31 | paddle:
32 | cpu_math_library_num_threads: -1
33 | use_cuda: false
34 | gpu_id: 0
35 | gpu_mem: 500
36 |
37 | torch:
38 | use_cuda: false
39 | gpu_id: 0
40 |
41 |
--------------------------------------------------------------------------------
/rapid_table/inference_engine/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 |
--------------------------------------------------------------------------------
/rapid_table/inference_engine/base.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | import abc
5 | from pathlib import Path
6 | from typing import Any, Dict, Union
7 |
8 | import numpy as np
9 | from omegaconf import DictConfig, OmegaConf
10 |
11 | from ..utils import EngineType, Logger, import_package, read_yaml
12 |
13 | logger = Logger(logger_name=__name__).get_log()
14 |
15 |
16 | class InferSession(abc.ABC):
17 | cur_dir = Path(__file__).resolve().parent.parent
18 | ENGINE_CFG_PATH = cur_dir / "engine_cfg.yaml"
19 | engine_cfg = read_yaml(ENGINE_CFG_PATH)
20 |
21 | @abc.abstractmethod
22 | def __init__(self, config):
23 | pass
24 |
25 | @abc.abstractmethod
26 | def __call__(self, input_content: np.ndarray) -> np.ndarray:
27 | pass
28 |
29 | @staticmethod
30 | def _verify_model(model_path: Union[str, Path, None]):
31 | if model_path is None:
32 | raise ValueError("model_path is None!")
33 |
34 | model_path = Path(model_path)
35 | if not model_path.exists():
36 | raise FileNotFoundError(f"{model_path} does not exist.")
37 |
38 | if not model_path.is_file():
39 | raise FileExistsError(f"{model_path} is not a file.")
40 |
41 | @abc.abstractmethod
42 | def have_key(self, key: str = "character") -> bool:
43 | pass
44 |
45 | @staticmethod
46 | def update_params(cfg: DictConfig, params: Dict[str, Any]):
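   | # keys in params may be dot paths (e.g. "cuda_ep_cfg.gpu_id");
   | # OmegaConf.update writes each one into the nested config field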
47 | for k, v in params.items():
48 | OmegaConf.update(cfg, k, v)
49 | return cfg
50 |
51 |
52 | def get_engine(engine_type: EngineType):
53 | logger.info("Using engine_name: %s", engine_type.value)
54 |
55 | if engine_type == EngineType.ONNXRUNTIME:
56 | if not import_package(engine_type.value):
57 | raise ImportError(f"{engine_type.value} is not installed.")
58 |
59 | from .onnxruntime import OrtInferSession
60 |
61 | return OrtInferSession
62 |
63 | if engine_type == EngineType.TORCH:
64 | if not import_package(engine_type.value):
65 | raise ImportError(f"{engine_type.value} is not installed")
66 |
67 | from .torch import TorchInferSession
68 |
69 | return TorchInferSession
70 |
71 | raise ValueError(f"Unsupported engine: {engine_type.value}")
72 |
--------------------------------------------------------------------------------
/rapid_table/inference_engine/onnxruntime/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from .main import OrtInferSession
5 |
--------------------------------------------------------------------------------
/rapid_table/inference_engine/onnxruntime/main.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | import os
5 | import traceback
6 | from pathlib import Path
7 | from typing import Any, Dict, List, Optional
8 |
9 | import numpy as np
10 | from omegaconf import DictConfig
11 | from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
12 |
13 | from ...utils.logger import Logger
14 | from ..base import InferSession
15 | from .provider_config import ProviderConfig
16 |
17 |
18 | class OrtInferSession(InferSession):
19 | def __init__(self, cfg: Optional[Dict[str, Any]] = None):
20 | self.logger = Logger(logger_name=__name__).get_log()
21 | cfg = {} if cfg is None else cfg  # cfg defaults to None but is read below
22 | # support custom session (PR #451)
23 | session = cfg.get("session", None)
24 | if session is not None:
25 | if not isinstance(session, InferenceSession):
26 | raise TypeError(
27 | f"Expected session to be an InferenceSession, got {type(session)}"
28 | )
29 |
30 | self.logger.debug("Using the provided InferenceSession for inference.")
31 | self.session = session
32 | return
33 |
34 | model_path = cfg.get("model_dir_or_path", None)
35 | self.logger.info(f"Using {model_path}")
36 | self._verify_model(model_path)  # check for None/missing before Path() conversion
37 | model_path = Path(model_path)
38 |
39 | engine_cfg = self.update_params(
40 | self.engine_cfg[cfg["engine_type"].value], cfg["engine_cfg"]
41 | )
42 |
43 | sess_opt = self._init_sess_opts(engine_cfg)
44 | provider_cfg = ProviderConfig(engine_cfg=engine_cfg)
45 | self.session = InferenceSession(
46 | model_path,
47 | sess_options=sess_opt,
48 | providers=provider_cfg.get_ep_list(),
49 | )
50 | provider_cfg.verify_providers(self.session.get_providers())
51 |
52 | @staticmethod
53 | def _init_sess_opts(cfg: DictConfig) -> SessionOptions:
54 | sess_opt = SessionOptions()
55 | sess_opt.log_severity_level = 4
56 | sess_opt.enable_cpu_mem_arena = cfg.get("enable_cpu_mem_arena", False)
57 | sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
58 |
59 | cpu_nums = os.cpu_count()
60 | intra_op_num_threads = cfg.get("intra_op_num_threads", -1)
61 | if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums:
62 | sess_opt.intra_op_num_threads = intra_op_num_threads
63 |
64 | inter_op_num_threads = cfg.get("inter_op_num_threads", -1)
65 | if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums:
66 | sess_opt.inter_op_num_threads = inter_op_num_threads
67 |
68 | return sess_opt
69 |
70 | def __call__(self, input_content: np.ndarray) -> np.ndarray:
71 | input_dict = dict(zip(self.get_input_names(), [input_content]))
72 | try:
73 | return self.session.run(self.get_output_names(), input_dict)
74 | except Exception as e:
75 | error_info = traceback.format_exc()
76 | raise ONNXRuntimeError(error_info) from e
77 |
78 | def get_input_names(self) -> List[str]:
79 | return [v.name for v in self.session.get_inputs()]
80 |
81 | def get_output_names(self) -> List[str]:
82 | return [v.name for v in self.session.get_outputs()]
83 |
84 | def get_character_list(self, key: str = "character") -> List[str]:
85 | meta_dict = self.session.get_modelmeta().custom_metadata_map
86 | return meta_dict[key].splitlines()
87 |
88 | def have_key(self, key: str = "character") -> bool:
89 | meta_dict = self.session.get_modelmeta().custom_metadata_map
90 | if key in meta_dict.keys():
91 | return True
92 | return False
93 |
94 |
95 | class ONNXRuntimeError(Exception):
96 | pass
97 |
--------------------------------------------------------------------------------
/rapid_table/inference_engine/onnxruntime/provider_config.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | import platform
5 | from enum import Enum
6 | from typing import Any, Dict, List, Sequence, Tuple
7 |
8 | from omegaconf import DictConfig
9 | from onnxruntime import get_available_providers, get_device
10 |
11 | from ...utils.logger import Logger
12 |
13 |
14 | class EP(Enum):
15 | CPU_EP = "CPUExecutionProvider"
16 | CUDA_EP = "CUDAExecutionProvider"
17 | DIRECTML_EP = "DmlExecutionProvider"
18 | CANN_EP = "CANNExecutionProvider"
19 |
20 |
21 | class ProviderConfig:
22 | def __init__(self, engine_cfg: DictConfig):
23 | self.logger = Logger(logger_name=__name__).get_log()
24 |
25 | self.had_providers: List[str] = get_available_providers()
26 | self.default_provider = self.had_providers[0]
27 |
28 | self.cfg_use_cuda = engine_cfg.get("use_cuda", False)
29 | self.cfg_use_dml = engine_cfg.get("use_dml", False)
30 | self.cfg_use_cann = engine_cfg.get("use_cann", False)
31 |
32 | self.cfg = engine_cfg
33 |
34 | def get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
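   | # each available EP is inserted at index 0 in check order (CUDA, then DML, then CANN),
   | # so the last one available ends up with top priority; CPU stays as the fallback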
35 | results = [(EP.CPU_EP.value, self.cpu_ep_cfg())]
36 |
37 | if self.is_cuda_available():
38 | results.insert(0, (EP.CUDA_EP.value, self.cuda_ep_cfg()))
39 |
40 | if self.is_dml_available():
41 | self.logger.info(
42 | "Windows 10 or above detected, try to use DirectML as primary provider"
43 | )
44 | results.insert(0, (EP.DIRECTML_EP.value, self.dml_ep_cfg()))
45 |
46 | if self.is_cann_available():
47 | self.logger.info("Try to use CANNExecutionProvider to infer")
48 | results.insert(0, (EP.CANN_EP.value, self.cann_ep_cfg()))
49 |
50 | return results
51 |
52 | def cpu_ep_cfg(self) -> Dict[str, Any]:
53 | return dict(self.cfg.cpu_ep_cfg)
54 |
55 | def cuda_ep_cfg(self) -> Dict[str, Any]:
56 | return dict(self.cfg.cuda_ep_cfg)
57 |
58 | def dml_ep_cfg(self) -> Dict[str, Any]:
59 | if self.cfg.dm_ep_cfg is not None:
60 | return self.cfg.dm_ep_cfg
61 |
62 | if self.is_cuda_available():
63 | return self.cuda_ep_cfg()
64 | return self.cpu_ep_cfg()
65 |
66 | def cann_ep_cfg(self) -> Dict[str, Any]:
67 | return dict(self.cfg.cann_ep_cfg)
68 |
69 | def verify_providers(self, session_providers: Sequence[str]):
70 | if not session_providers:
71 | raise ValueError("Session Providers is empty")
72 |
73 | first_provider = session_providers[0]
74 |
75 | providers_to_check = {
76 | EP.CUDA_EP: self.is_cuda_available,
77 | EP.DIRECTML_EP: self.is_dml_available,
78 | EP.CANN_EP: self.is_cann_available,
79 | }
80 |
81 | for ep, check_func in providers_to_check.items():
82 | if check_func() and first_provider != ep.value:
83 | self.logger.warning(
84 | f"{ep.value} is available, but the inference part is automatically shifted to be executed under {first_provider}. "
85 | )
86 | self.logger.warning(f"The available lists are {session_providers}")
87 |
88 | def is_cuda_available(self) -> bool:
89 | if not self.cfg_use_cuda:
90 | return False
91 |
92 | CUDA_EP = EP.CUDA_EP.value
93 | if get_device() == "GPU" and CUDA_EP in self.had_providers:
94 | return True
95 |
96 | self.logger.warning(
97 | f"{CUDA_EP} is not in available providers ({self.had_providers}). Use {self.default_provider} inference by default."
98 | )
99 | install_instructions = [
100 | f"If you want to use {CUDA_EP} acceleration, you must do:"
101 | "(For reference only) If you want to use GPU acceleration, you must do:",
102 | "First, uninstall all onnxruntime packages in current environment.",
103 | "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`.",
104 | "Note the onnxruntime-gpu version must match your cuda and cudnn version.",
105 | "You can refer this link: https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
106 | f"Third, ensure {CUDA_EP} is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']",
107 | ]
108 | self.print_log(install_instructions)
109 | return False
110 |
111 | def is_dml_available(self) -> bool:
112 | if not self.cfg_use_dml:
113 | return False
114 |
115 | cur_os = platform.system()
116 | if cur_os != "Windows":
117 | self.logger.warning(
118 | f"DirectML is only supported in Windows OS. The current OS is {cur_os}. Use {self.default_provider} inference by default.",
119 | )
120 | return False
121 |
122 | window_build_number_str = platform.version().split(".")[-1]
123 | window_build_number = (
124 | int(window_build_number_str) if window_build_number_str.isdigit() else 0
125 | )
126 | if window_build_number < 18362:
127 | self.logger.warning(
128 | f"DirectML is only supported in Windows 10 Build 18362 and above OS. The current Windows Build is {window_build_number}. Use {self.default_provider} inference by default.",
129 | )
130 | return False
131 |
132 | DML_EP = EP.DIRECTML_EP.value
133 | if DML_EP in self.had_providers:
134 | return True
135 |
136 | self.logger.warning(
137 | f"{DML_EP} is not in available providers ({self.had_providers}). Use {self.default_provider} inference by default."
138 | )
139 | install_instructions = [
140 | "If you want to use DirectML acceleration, you must do:",
141 | "First, uninstall all onnxruntime packages in current environment.",
142 | "Second, install onnxruntime-directml by `pip install onnxruntime-directml`",
143 | f"Third, ensure {DML_EP} is in available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']",
144 | ]
145 | self.print_log(install_instructions)
146 | return False
147 |
148 | def is_cann_available(self) -> bool:
149 | if not self.cfg_use_cann:
150 | return False
151 |
152 | CANN_EP = EP.CANN_EP.value
153 | if CANN_EP in self.had_providers:
154 | return True
155 |
156 | self.logger.warning(
157 | f"{CANN_EP} is not in available providers ({self.had_providers}). Use {self.default_provider} inference by default."
158 | )
159 | install_instructions = [
160 | "If you want to use CANN acceleration, you must do:",
161 | "First, ensure you have installed Huawei Ascend software stack.",
162 | "Second, install onnxruntime with CANN support by following the instructions at:",
163 | "\thttps://onnxruntime.ai/docs/execution-providers/community-maintained/CANN-ExecutionProvider.html",
164 | f"Third, ensure {CANN_EP} is in available providers list. e.g. ['CANNExecutionProvider', 'CPUExecutionProvider']",
165 | ]
166 | self.print_log(install_instructions)
167 | return False
168 |
169 | def print_log(self, log_list: List[str]):
170 | for log_info in log_list:
171 | self.logger.info(log_info)
172 |
--------------------------------------------------------------------------------
/rapid_table/inference_engine/torch.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from pathlib import Path
5 | from typing import Union
6 |
7 | import numpy as np
8 | import torch
9 | from tokenizers import Tokenizer
10 |
11 | from ..table_structure.unitable.unitable_modules import Encoder, GPTFastDecoder
12 | from ..utils.logger import Logger
13 | from .base import InferSession
14 |
15 | root_dir = Path(__file__).resolve().parent.parent
16 |
17 |
18 | class TorchInferSession(InferSession):
19 | def __init__(self, cfg) -> None:
20 | self.logger = Logger(logger_name=__name__).get_log()
21 |
22 | engine_cfg = self.update_params(
23 | self.engine_cfg[cfg["engine_type"].value], cfg["engine_cfg"]
24 | )
25 |
26 | self.device = torch.device(
27 | f"cuda:{engine_cfg.gpu_id}"
28 | if torch.cuda.is_available() and engine_cfg.use_cuda
29 | else "cpu"
30 | )
31 |
32 | model_info = cfg["model_dir_or_path"]
33 | self.encoder = self._init_model(model_info["encoder"], Encoder)
34 | self.decoder = self._init_model(model_info["decoder"], GPTFastDecoder)
35 | self.vocab = self._init_vocab(model_info["vocab"])
36 |
37 | def _init_model(self, model_path, model_class):
38 | model = model_class()
39 | model.load_state_dict(torch.load(model_path, map_location=self.device))
40 | model.eval().to(self.device)
41 | return model
42 |
43 | def _init_vocab(self, vocab_path: Union[str, Path]):
44 | return Tokenizer.from_file(str(vocab_path))
45 |
46 | def __call__(self, img: np.ndarray):
47 | raise NotImplementedError(
48 | "Inference logic is not implemented for TorchInferSession."
49 | )
50 |
51 | def have_key(self, key: str = "character") -> bool:
52 | return False
53 |
54 |
55 | class TorchInferError(Exception):
56 | pass
57 |
--------------------------------------------------------------------------------
/rapid_table/main.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | import argparse
5 | import time
6 | from dataclasses import asdict
7 | from pathlib import Path
8 | from typing import Any, Dict, List, Optional, Tuple, Union, get_args
9 |
10 | import numpy as np
11 | from tqdm import tqdm
12 |
13 | from .model_processor.main import ModelProcessor
14 | from .table_matcher import TableMatch
15 | from .utils import (
16 | InputType,
17 | LoadImage,
18 | Logger,
19 | ModelType,
20 | RapidTableInput,
21 | RapidTableOutput,
22 | format_ocr_results,
23 | import_package,
24 | is_url,
25 | )
26 |
27 | logger = Logger(logger_name=__name__).get_log()
28 | root_dir = Path(__file__).resolve().parent
29 |
30 |
31 | class RapidTable:
32 | def __init__(self, cfg: Optional[RapidTableInput] = None):
33 | if cfg is None:
34 | cfg = RapidTableInput()
35 |
36 | if not cfg.model_dir_or_path and cfg.model_type is not None:
37 | cfg.model_dir_or_path = ModelProcessor.get_model_path(cfg.model_type)
38 |
39 | self.cfg = cfg
40 |         self.table_structure = self._init_table_structurer()
41 |
42 | self.ocr_engine = None
43 | if cfg.use_ocr:
44 | self.ocr_engine = self._init_ocr_engine(self.cfg.ocr_params)
45 |
46 | self.table_matcher = TableMatch()
47 | self.load_img = LoadImage()
48 |
49 | def _init_ocr_engine(self, params: Dict[Any, Any]):
50 | rapidocr_ = import_package("rapidocr")
51 | if rapidocr_ is None:
52 |             logger.warning("rapidocr package is not installed; only table structure recognition is available")
53 | return None
54 |
55 | if not params:
56 | return rapidocr_.RapidOCR()
57 | return rapidocr_.RapidOCR(params=params)
58 |
59 |     def _init_table_structurer(self):
60 | if self.cfg.model_type == ModelType.UNITABLE:
61 | from .table_structure.unitable import UniTableStructure
62 |
63 | return UniTableStructure(asdict(self.cfg))
64 |
65 | from .table_structure.pp_structure import PPTableStructurer
66 |
67 | return PPTableStructurer(asdict(self.cfg))
68 |
69 | def __call__(
70 | self,
71 | img_contents: Union[List[InputType], InputType],
72 |         ocr_results: Optional[List[Tuple[np.ndarray, Tuple[str], Tuple[float]]]] = None,
73 | batch_size: int = 1,
74 | ) -> RapidTableOutput:
75 | s = time.perf_counter()
76 |
77 | if not isinstance(img_contents, list):
78 | img_contents = [img_contents]
79 |
80 | for img_content in img_contents:
81 | if not isinstance(img_content, get_args(InputType)):
82 | type_names = ", ".join([t.__name__ for t in get_args(InputType)])
83 | actual_type = (
84 | type(img_content).__name__ if img_content is not None else "None"
85 | )
86 | raise TypeError(
87 | f"Type Error: Expected input of type [{type_names}], but received type {actual_type}."
88 | )
89 |
90 | results = RapidTableOutput()
91 |
92 | total_nums = len(img_contents)
93 | for start_i in tqdm(range(0, total_nums, batch_size), desc="BatchRec"):
94 | end_i = min(total_nums, start_i + batch_size)
95 |
96 | imgs = self._load_imgs(img_contents[start_i:end_i])
97 |
98 | pred_structures, cell_bboxes = self.table_structure(imgs)
99 | logic_points = self.table_matcher.decode_logic_points(pred_structures)
100 |
101 | if not self.cfg.use_ocr:
102 | results.imgs.extend(imgs)
103 | results.cell_bboxes.extend(cell_bboxes)
104 | results.logic_points.extend(logic_points)
105 | continue
106 |
107 | dt_boxes, rec_res = self.get_ocr_results(imgs, start_i, end_i, ocr_results)
108 | pred_htmls = self.table_matcher(
109 | pred_structures, cell_bboxes, dt_boxes, rec_res
110 | )
111 |
112 | results.imgs.extend(imgs)
113 | results.pred_htmls.extend(pred_htmls)
114 | results.cell_bboxes.extend(cell_bboxes)
115 | results.logic_points.extend(logic_points)
116 |
117 | elapse = time.perf_counter() - s
118 | results.elapse = elapse / total_nums
119 | return results
120 |
121 | def _load_imgs(
122 | self, img_content: Union[List[InputType], InputType]
123 | ) -> List[np.ndarray]:
124 | img_contents = img_content if isinstance(img_content, list) else [img_content]
125 | return [self.load_img(img) for img in img_contents]
126 |
127 | def get_ocr_results(
128 | self,
129 | imgs: List[np.ndarray],
130 | start_i: int,
131 | end_i: int,
132 | ocr_results: Optional[List[Tuple[np.ndarray, Tuple[str], Tuple[float]]]] = None,
133 | ) -> Any:
134 | batch_dt_boxes, batch_rec_res = [], []
135 |
136 | if ocr_results is not None:
137 | ocr_results_batch = ocr_results[start_i:end_i]
138 | if len(ocr_results_batch) != len(imgs):
139 | raise ValueError(
140 | f"Batch size mismatch: {len(imgs)} images but {len(ocr_results_batch)} OCR results "
141 | f"(indices {start_i}:{end_i})."
142 | )
143 |
144 | for img, ocr_result in zip(imgs, ocr_results_batch):
145 | img_h, img_w = img.shape[:2]
146 | dt_boxes, rec_res = format_ocr_results(ocr_result, img_h, img_w)
147 | batch_dt_boxes.append(dt_boxes)
148 | batch_rec_res.append(rec_res)
149 | return batch_dt_boxes, batch_rec_res
150 |
151 | for img in tqdm(imgs, desc="OCR"):
152 | if img is None:
153 | continue
154 |
155 | ori_ocr_res = self.ocr_engine(img)
156 | if ori_ocr_res.boxes is None:
157 | logger.warning("OCR Result is empty")
158 | batch_dt_boxes.append(None)
159 | batch_rec_res.append(None)
160 | continue
161 |
162 | img_h, img_w = img.shape[:2]
163 |
164 | ocr_result = [ori_ocr_res.boxes, ori_ocr_res.txts, ori_ocr_res.scores]
165 | dt_boxes, rec_res = format_ocr_results(ocr_result, img_h, img_w)
166 | batch_dt_boxes.append(dt_boxes)
167 | batch_rec_res.append(rec_res)
168 |
169 | return batch_dt_boxes, batch_rec_res
170 |
171 |
172 | def parse_args(arg_list: Optional[List[str]] = None):
173 | parser = argparse.ArgumentParser()
174 | parser.add_argument("img_path", type=str, help="the image path or URL of the table")
175 | parser.add_argument(
176 | "-m",
177 | "--model_type",
178 | type=str,
179 | default=ModelType.SLANETPLUS.value,
180 | choices=[v.value for v in ModelType],
181 | help="Supported table rec models",
182 | )
183 | parser.add_argument(
184 | "-v",
185 | "--vis",
186 | action="store_true",
187 | default=False,
188 |         help="Whether to visualize the table recognition results.",
189 | )
190 | args = parser.parse_args(arg_list)
191 | return args
192 |
193 |
194 | def main(arg_list: Optional[List[str]] = None):
195 | args = parse_args(arg_list)
196 | img_path = args.img_path
197 |
198 | input_args = RapidTableInput(model_type=ModelType(args.model_type))
199 | table_engine = RapidTable(input_args)
200 |
201 | table_results = table_engine(img_path)
202 | print(table_results.pred_htmls)
203 |
204 | if args.vis:
205 | save_dir = Path(".") if is_url(img_path) else Path(img_path).resolve().parent
206 | table_results.vis(save_dir, save_name=Path(img_path).stem)
207 |
208 |
209 | if __name__ == "__main__":
210 | main()
211 |
--------------------------------------------------------------------------------
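A minimal usage sketch of the RapidTable pipeline defined above, assuming the top-level package re-exports RapidTable, RapidTableInput, and ModelType (as main() and the repo's demo.py suggest), and that the optional rapidocr package is installed so HTML can be filled in; the image path is the repo's own test fixture:

from rapid_table import ModelType, RapidTable, RapidTableInput

cfg = RapidTableInput(model_type=ModelType.SLANETPLUS)
table_engine = RapidTable(cfg)

# A single path, URL, or ndarray also works; lists are processed in batches.
results = table_engine("tests/test_files/table.jpg")
print(results.pred_htmls[0])         # reconstructed <table> HTML
print(results.cell_bboxes[0].shape)  # physical cell boxes for the first image
print(results.logic_points[0])       # [row_start, row_end, col_start, col_end] per cell
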
/rapid_table/model_processor/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 |
--------------------------------------------------------------------------------
/rapid_table/model_processor/main.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from pathlib import Path
5 | from typing import Dict, Union
6 |
7 | from ..utils import DownloadFile, DownloadFileInput, Logger, ModelType, mkdir, read_yaml
8 |
9 |
10 | class ModelProcessor:
11 | logger = Logger(logger_name=__name__).get_log()
12 |
13 | cur_dir = Path(__file__).resolve().parent
14 | root_dir = cur_dir.parent
15 | DEFAULT_MODEL_PATH = root_dir / "default_models.yaml"
16 |
17 | DEFAULT_MODEL_DIR = root_dir / "models"
18 | mkdir(DEFAULT_MODEL_DIR)
19 |
20 | model_map = read_yaml(DEFAULT_MODEL_PATH)
21 |
22 | @classmethod
23 | def get_model_path(cls, model_type: ModelType) -> Union[str, Dict[str, str]]:
24 | if model_type == ModelType.UNITABLE:
25 | return cls.get_multi_models_dict(model_type)
26 | return cls.get_single_model_path(model_type)
27 |
28 | @classmethod
29 | def get_single_model_path(cls, model_type: ModelType) -> str:
30 | model_info = cls.model_map[model_type.value]
31 | save_model_path = (
32 | cls.DEFAULT_MODEL_DIR / Path(model_info["model_dir_or_path"]).name
33 | )
34 | download_params = DownloadFileInput(
35 | file_url=model_info["model_dir_or_path"],
36 | sha256=model_info["SHA256"],
37 | save_path=save_model_path,
38 | logger=cls.logger,
39 | )
40 | DownloadFile.run(download_params)
41 |
42 | return str(save_model_path)
43 |
44 | @classmethod
45 | def get_multi_models_dict(cls, model_type: ModelType) -> Dict[str, str]:
46 | model_info = cls.model_map[model_type.value]
47 |
48 | results = {}
49 |
50 | model_root_dir = model_info["model_dir_or_path"]
51 | save_model_dir = cls.DEFAULT_MODEL_DIR / Path(model_root_dir).name
52 | for file_name, sha256 in model_info["SHA256"].items():
53 | save_path = save_model_dir / file_name
54 |
55 | download_params = DownloadFileInput(
56 | file_url=f"{model_root_dir}/{file_name}",
57 | sha256=sha256,
58 | save_path=save_path,
59 | logger=cls.logger,
60 | )
61 | DownloadFile.run(download_params)
62 | results[Path(file_name).stem] = str(save_path)
63 |
64 | return results
65 |
--------------------------------------------------------------------------------
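A sketch of how ModelProcessor resolves weights: single-file models yield one path, while UNITABLE yields a dict keyed by file stem (encoder/decoder/vocab, as TorchInferSession expects). The actual URLs and SHA256 values come from default_models.yaml, which is not shown in this listing:

from rapid_table.model_processor.main import ModelProcessor
from rapid_table.utils import ModelType

onnx_path = ModelProcessor.get_model_path(ModelType.SLANETPLUS)  # str
unitable = ModelProcessor.get_model_path(ModelType.UNITABLE)     # Dict[str, str]
# unitable["encoder"], unitable["decoder"], unitable["vocab"]
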
/rapid_table/models/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidTable/b0ba540a2a5ac06a61bcd6506795246a5046a291/rapid_table/models/.gitkeep
--------------------------------------------------------------------------------
/rapid_table/table_matcher/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from .main import TableMatch
5 |
--------------------------------------------------------------------------------
/rapid_table/table_matcher/main.py:
--------------------------------------------------------------------------------
1 | # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # -*- encoding: utf-8 -*-
15 | from typing import Dict, List, Optional, Tuple
16 |
17 | import numpy as np
18 |
19 | from .utils import compute_iou, distance
20 |
21 |
22 | class TableMatch:
23 | def __init__(self):
24 | pass
25 |
26 | def __call__(
27 | self,
28 | pred_structures: List[List[List[str]]],
29 | cell_bboxes: List[np.ndarray],
30 | dt_boxes: List[np.ndarray],
31 | rec_reses: List[List[Tuple[str, float]]],
32 | ) -> List[Optional[str]]:
33 | results = []
34 | for item in zip(pred_structures, cell_bboxes, dt_boxes, rec_reses):
35 | pred_struct, cell_bbox, dt_box, rec_res = item
36 | if dt_box is None or rec_res is None:
37 | results.append(None)
38 | continue
39 |
40 | one_result = self.process_one(pred_struct, cell_bbox, dt_box, rec_res)
41 | results.append(one_result)
42 | return results
43 |
44 | def process_one(
45 | self,
46 | pred_struct: List[List[str]],
47 | cell_bboxes: np.ndarray,
48 | dt_boxes: np.ndarray,
49 | rec_res: List[Tuple[str, float]],
50 | ) -> str:
51 | dt_boxes, rec_res = self.filter_ocr_result(cell_bboxes, dt_boxes, rec_res)
52 | matched_index = self.match_result(cell_bboxes, dt_boxes)
53 | pred_html, pred = self.get_pred_html(pred_struct[0], matched_index, rec_res)
54 | return pred_html
55 |
56 | def filter_ocr_result(
57 | self,
58 | cell_bboxes: np.ndarray,
59 | dt_boxes: np.ndarray,
60 | rec_res: List[Tuple[str, float]],
61 | ) -> Tuple[np.ndarray, List[Tuple[str, float]]]:
62 | y1 = cell_bboxes[:, 1::2].min()
63 |
64 | new_dt_boxes, new_rec_res = [], []
65 | for box, rec in zip(dt_boxes, rec_res):
66 | if np.max(box[1::2]) < y1:
67 | continue
68 |
69 | new_dt_boxes.append(box)
70 | new_rec_res.append(rec)
71 | return np.array(new_dt_boxes), new_rec_res
72 |
73 | def match_result(
74 | self, cell_bboxes: np.ndarray, dt_boxes: np.ndarray, min_iou: float = 0.1**8
75 | ) -> Dict[int, List[int]]:
76 | matched = {}
77 | for i, gt_box in enumerate(dt_boxes):
78 | distances = []
79 | for j, pred_box in enumerate(cell_bboxes):
80 | if len(pred_box) == 8:
81 | pred_box = [
82 | np.min(pred_box[0::2]),
83 | np.min(pred_box[1::2]),
84 | np.max(pred_box[0::2]),
85 | np.max(pred_box[1::2]),
86 | ]
87 | distances.append(
88 | (distance(gt_box, pred_box), 1.0 - compute_iou(gt_box, pred_box))
89 | ) # compute iou and l1 distance
90 | sorted_distances = distances.copy()
91 | # select det box by iou and l1 distance
92 | sorted_distances = sorted(
93 | sorted_distances, key=lambda item: (item[1], item[0])
94 | )
95 | # must > min_iou
96 | if sorted_distances[0][1] >= 1 - min_iou:
97 | continue
98 |
99 | if distances.index(sorted_distances[0]) not in matched:
100 | matched[distances.index(sorted_distances[0])] = [i]
101 | else:
102 | matched[distances.index(sorted_distances[0])].append(i)
103 | return matched
104 |
105 | def get_pred_html(
106 | self,
107 | pred_structures: List[str],
108 | matched_index: Dict[int, List[int]],
109 | ocr_contents: List[Tuple[str, float]],
110 | ):
111 | end_html = []
112 | td_index = 0
113 | for tag in pred_structures:
114 |             if "</td>" not in tag:
115 |                 end_html.append(tag)
116 |                 continue
117 | 
118 |             if "<td></td>" == tag:
119 |                 end_html.extend("<td>")
120 | 
121 |             if td_index in matched_index.keys():
122 |                 b_with = False
123 |                 if (
124 |                     "<b>" in ocr_contents[matched_index[td_index][0]]
125 |                     and len(matched_index[td_index]) > 1
126 |                 ):
127 |                     b_with = True
128 |                     end_html.extend("<b>")
129 | 
130 |                 for i, td_index_index in enumerate(matched_index[td_index]):
131 |                     content = ocr_contents[td_index_index][0]
132 |                     if len(matched_index[td_index]) > 1:
133 |                         if len(content) == 0:
134 |                             continue
135 | 
136 |                         if content[0] == " ":
137 |                             content = content[1:]
138 | 
139 |                         if "<b>" in content:
140 |                             content = content[3:]
141 | 
142 |                         if "</b>" in content:
143 |                             content = content[:-4]
144 | 
145 |                         if len(content) == 0:
146 |                             continue
147 | 
148 |                         if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
149 |                             content += " "
150 |                     end_html.extend(content)
151 | 
152 |                 if b_with:
153 |                     end_html.extend("</b>")
154 | 
155 |             if "<td></td>" == tag:
156 |                 end_html.append("</td>")
157 |             else:
158 |                 end_html.append(tag)
159 | 
160 |             td_index += 1
161 | 
162 |         # Filter out the bare <thead></thead><tbody></tbody> wrapper elements
163 |         filter_elements = ["<thead>", "</thead>", "<tbody>", "</tbody>"]
164 | end_html = [v for v in end_html if v not in filter_elements]
165 | return "".join(end_html), end_html
166 |
167 | def decode_logic_points(self, pred_structures: List[List[List[str]]]):
168 | results = []
169 | for pred_struct in pred_structures:
170 | decode_result = self.decode_one_logic_points(pred_struct[0])
171 | results.append(np.array(decode_result))
172 | return results
173 |
174 | def decode_one_logic_points(self, pred_structures):
175 | logic_points = []
176 | current_row = 0
177 | current_col = 0
178 | max_rows = 0
179 | max_cols = 0
180 |         occupied_cells = {}  # track cells already occupied by earlier rowspan/colspan cells
181 |
182 | def is_occupied(row, col):
183 | return (row, col) in occupied_cells
184 |
185 | def mark_occupied(row, col, rowspan, colspan):
186 | for r in range(row, row + rowspan):
187 | for c in range(col, col + colspan):
188 | occupied_cells[(r, c)] = True
189 |
190 | i = 0
191 | while i < len(pred_structures):
192 | token = pred_structures[i]
193 |
194 |             if token == "<tr>":
195 |                 current_col = 0  # reset the column index at the start of each row
196 |             elif token == "</tr>":
197 |                 current_row += 1  # the row has ended, move to the next row index
198 |             elif token.startswith("<td"):
199 |                 colspan = 1
200 |                 rowspan = 1
201 |                 j = i
202 |                 if token == "<td":
203 |                     j += 1
204 |                     # extract the colspan and rowspan attributes
205 |                     while j < len(pred_structures) and not pred_structures[
206 |                         j
207 |                     ].startswith(">"):
208 |                         if "colspan=" in pred_structures[j]:
209 |                             colspan = int(pred_structures[j].split("=")[1].strip("\"'"))
210 |                         elif "rowspan=" in pred_structures[j]:
211 |                             rowspan = int(pred_structures[j].split("=")[1].strip("\"'"))
212 |                         j += 1
213 | 
214 |                     # skip the attribute tokens that were just consumed
215 |                     i = j
216 | 
217 |                 # find the next column that is not yet occupied
218 |                 while is_occupied(current_row, current_col):
219 |                     current_col += 1
220 | 
221 |                 # compute the logical coordinates
222 |                 r_start = current_row
223 |                 r_end = current_row + rowspan - 1
224 |                 col_start = current_col
225 |                 col_end = current_col + colspan - 1
226 | 
227 |                 # record the logical coordinates
228 |                 logic_points.append([r_start, r_end, col_start, col_end])
229 | 
230 |                 # mark the occupied cells
231 |                 mark_occupied(r_start, col_start, rowspan, colspan)
232 | 
233 |                 # advance the current column index
234 |                 current_col += colspan
235 | 
236 |                 # update the maximum row and column counts
237 |                 max_rows = max(max_rows, r_end + 1)
238 |                 max_cols = max(max_cols, col_end + 1)
239 |
240 | i += 1
241 |
242 | return logic_points
243 |
--------------------------------------------------------------------------------
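A worked example of the logic-point decoding above, on a hypothetical token stream for a 2x2 table whose first cell spans both rows (the token layout mirrors what the structure models emit):

from rapid_table.table_matcher import TableMatch

tokens = [
    "<tr>", "<td", ' rowspan="2"', ">", "</td>", "<td></td>", "</tr>",
    "<tr>", "<td></td>", "</tr>",
]
print(TableMatch().decode_one_logic_points(tokens))
# [[0, 1, 0, 0], [0, 0, 1, 1], [1, 1, 1, 1]]
# One [row_start, row_end, col_start, col_end] entry per cell; the second row's
# cell lands in column 1 because (1, 0) is occupied by the rowspan.
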
/rapid_table/table_matcher/utils.py:
--------------------------------------------------------------------------------
1 | # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # -*- encoding: utf-8 -*-
15 | # @Author: SWHL
16 | # @Contact: liekkaskono@163.com
17 | import copy
18 | import re
19 |
20 |
21 | def deal_isolate_span(thead_part):
22 |     """
23 |     Deal with isolate span cases in this function.
24 |     It is caused by wrong prediction in the structure recognition model.
25 |     eg. predict <td rowspan="2"></td> to <td></td> rowspan="2"></b></td>.
26 |     :param thead_part:
27 |     :return:
28 |     """
29 |     # 1. find out isolate span tokens.
30 |     isolate_pattern = (
31 |         r'<td></td> rowspan="(\d)+" colspan="(\d)+"></b></td>|'
32 |         r'<td></td> colspan="(\d)+" rowspan="(\d)+"></b></td>|'
33 |         r'<td></td> rowspan="(\d)+"></b></td>|'
34 |         r'<td></td> colspan="(\d)+"></b></td>'
35 |     )
36 |     isolate_iter = re.finditer(isolate_pattern, thead_part)
37 |     isolate_list = [i.group() for i in isolate_iter]
38 | 
39 |     # 2. find out span number, by step 1 results.
40 |     span_pattern = (
41 |         r' rowspan="(\d)+" colspan="(\d)+"|'
42 |         r' colspan="(\d)+" rowspan="(\d)+"|'
43 |         r' rowspan="(\d)+"|'
44 |         r' colspan="(\d)+"'
45 |     )
46 |     corrected_list = []
47 |     for isolate_item in isolate_list:
48 |         span_part = re.search(span_pattern, isolate_item)
49 |         spanStr_in_isolateItem = span_part.group()
50 |         # 3. merge the span number into the span token format string.
51 |         if spanStr_in_isolateItem is not None:
52 |             corrected_item = f"<td{spanStr_in_isolateItem}></td>"
53 | corrected_list.append(corrected_item)
54 | else:
55 | corrected_list.append(None)
56 |
57 | # 4. replace original isolated token.
58 | for corrected_item, isolate_item in zip(corrected_list, isolate_list):
59 | if corrected_item is not None:
60 | thead_part = thead_part.replace(isolate_item, corrected_item)
61 | else:
62 | pass
63 | return thead_part
64 |
65 |
66 | def deal_duplicate_bb(thead_part):
67 |     """
68 |     Deal duplicate <b> or </b> after replace.
69 |     Keep one <b></b> in a <td></td> token.
70 |     :param thead_part:
71 |     :return:
72 |     """
73 |     # 1. find out <td></td> tokens in <thead></thead>.
74 |     td_pattern = (
75 |         r'<td rowspan="(\d)+" colspan="(\d)+">(.+?)</td>|'
76 |         r'<td colspan="(\d)+" rowspan="(\d)+">(.+?)</td>|'
77 |         r'<td rowspan="(\d)+">(.+?)</td>|'
78 |         r'<td colspan="(\d)+">(.+?)</td>|'
79 |         r"<td>(.*?)</td>"
80 |     )
81 |     td_iter = re.finditer(td_pattern, thead_part)
82 |     td_list = [t.group() for t in td_iter]
83 | 
84 |     # 2. does <b></b> occur multiple times in a <td></td> or not?
85 |     new_td_list = []
86 |     for td_item in td_list:
87 |         if td_item.count("<b>") > 1 or td_item.count("</b>") > 1:
88 |             # multiple <b></b> in one <td></td> case.
89 |             # 1. remove all <b> and </b>
90 |             td_item = td_item.replace("<b>", "").replace("</b>", "")
91 |             # 2. replace <td> -> <td><b>, </td> -> </b></td>.
92 |             td_item = td_item.replace("<td>", "<td><b>").replace("</td>", "</b></td>")
93 |             new_td_list.append(td_item)
94 |         else:
95 |             new_td_list.append(td_item)
96 |
97 | # 3. replace original thead part.
98 | for td_item, new_td_item in zip(td_list, new_td_list):
99 | thead_part = thead_part.replace(td_item, new_td_item)
100 | return thead_part
101 |
102 |
103 | def deal_bb(result_token):
104 |     """
105 |     In our opinion, <b></b> always occurs in the context of <thead></thead> text.
106 |     This function finds all tokens in <thead></thead> and inserts <b></b> manually.
107 |     :param result_token:
108 |     :return:
109 |     """
110 |     # find out <thead></thead> parts.
111 |     thead_pattern = "<thead>(.*?)</thead>"
112 |     if re.search(thead_pattern, result_token) is None:
113 |         return result_token
114 |     thead_part = re.search(thead_pattern, result_token).group()
115 |     origin_thead_part = copy.deepcopy(thead_part)
116 | 
117 |     # check whether "rowspan" or "colspan" occurs in the <thead></thead> parts.
118 |     span_pattern = r'<td rowspan="(\d)+" colspan="(\d)+">|<td colspan="(\d)+" rowspan="(\d)+">|<td rowspan="(\d)+">|<td colspan="(\d)+">'
119 |     span_iter = re.finditer(span_pattern, thead_part)
120 |     span_list = [s.group() for s in span_iter]
121 |     has_span_in_head = True if len(span_list) > 0 else False
122 | 
123 |     if not has_span_in_head:
124 |         # <thead></thead> does not include "rowspan" or "colspan", branch 1.
125 |         # 1. replace <td> with <td><b>, and </td> with </b></td>
126 |         # 2. the text-line recognition may already predict <b> or </b>,
127 |         #    so we replace <b><b> with <b>, and </b></b> with </b>
128 |         thead_part = (
129 |             thead_part.replace("<td>", "<td><b>")
130 |             .replace("</td>", "</b></td>")
131 |             .replace("<b><b>", "<b>")
132 |             .replace("</b></b>", "</b>")
133 |         )
134 |     else:
135 |         # <thead></thead> includes "rowspan" or "colspan", branch 2.
136 |         # Firstly, we deal with the rowspan or colspan cases.
137 |         # 1. replace > with ><b>
138 |         # 2. replace </td> with </b></td>
139 |         # 3. the text-line recognition may already predict <b> or </b>,
140 |         #    so we replace <b><b> with <b>, and </b><b> with </b>
141 | 
142 |         # Secondly, deal with ordinary cases like branch 1
143 | 
144 |         # replace ">" with "><b>"
145 |         replaced_span_list = []
146 |         for sp in span_list:
147 |             replaced_span_list.append(sp.replace(">", "><b>"))
148 |         for sp, rsp in zip(span_list, replaced_span_list):
149 |             thead_part = thead_part.replace(sp, rsp)
150 | 
151 |         # replace "</td>" with "</b></td>"
152 |         thead_part = thead_part.replace("</td>", "</b></td>")
153 | 
154 |         # remove duplicated <b> by re.sub
155 |         mb_pattern = "(<b>)+"
156 |         single_b_string = "<b>"
157 |         thead_part = re.sub(mb_pattern, single_b_string, thead_part)
158 | 
159 |         mgb_pattern = "(</b>)+"
160 |         single_gb_string = "</b>"
161 |         thead_part = re.sub(mgb_pattern, single_gb_string, thead_part)
162 | 
163 |         # ordinary cases like branch 1
164 |         thead_part = thead_part.replace("<td>", "<td><b>").replace("<b><b>", "<b>")
165 | 
166 |     # convert <td><b></b></td> back to <td></td>; an empty cell has no <b></b>,
167 |     # but a space cell (<td> </td>) is suitable for <td><b> </b></td>
168 |     thead_part = thead_part.replace("<td><b></b></td>", "<td></td>")
169 |     # deal with duplicated <b></b>
170 |     thead_part = deal_duplicate_bb(thead_part)
171 |     # deal with isolate span tokens, which are caused by wrong predictions of the
172 |     # structure model. eg. PMC5994107_011_00.png
173 |     thead_part = deal_isolate_span(thead_part)
174 |     # replace the original result with the new thead part.
175 |     result_token = result_token.replace(origin_thead_part, thead_part)
176 |     return result_token
177 |
178 |
179 | def deal_eb_token(master_token):
180 |     """
181 |     post process with <eb></eb>, <eb1></eb1>, ...
182 |     emptyBboxTokenDict = {
183 |         "[]": '<eb></eb>',
184 |         "[' ']": '<eb1></eb1>',
185 |         "['<b>', ' ', '</b>']": '<eb2></eb2>',
186 |         "['\\u2028', '\\u2028']": '<eb3></eb3>',
187 |         "['<sup>', ' ', '</sup>']": '<eb4></eb4>',
188 |         "['<b>', '</b>']": '<eb5></eb5>',
189 |         "['<i>', ' ', '</i>']": '<eb6></eb6>',
190 |         "['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
191 |         "['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
192 |         "['<i>', '</i>']": '<eb9></eb9>',
193 |         "['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']": '<eb10></eb10>',
194 |     }
195 |     :param master_token:
196 |     :return:
197 |     """
198 |     master_token = master_token.replace("<eb></eb>", "<td></td>")
199 |     master_token = master_token.replace("<eb1></eb1>", "<td> </td>")
200 |     master_token = master_token.replace("<eb2></eb2>", "<td><b> </b></td>")
201 |     master_token = master_token.replace("<eb3></eb3>", "<td>\u2028\u2028</td>")
202 |     master_token = master_token.replace("<eb4></eb4>", "<td><sup> </sup></td>")
203 |     master_token = master_token.replace("<eb5></eb5>", "<td><b></b></td>")
204 |     master_token = master_token.replace("<eb6></eb6>", "<td><i> </i></td>")
205 |     master_token = master_token.replace("<eb7></eb7>", "<td><b><i></i></b></td>")
206 |     master_token = master_token.replace("<eb8></eb8>", "<td><b><i> </i></b></td>")
207 |     master_token = master_token.replace("<eb9></eb9>", "<td><i></i></td>")
208 |     master_token = master_token.replace(
209 |         "<eb10></eb10>", "<td><b> \u2028 \u2028 </b></td>"
210 |     )
211 |     return master_token
212 |
213 |
214 | def distance(box_1, box_2):
215 | x1, y1, x2, y2 = box_1
216 | x3, y3, x4, y4 = box_2
217 | dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
218 | dis_2 = abs(x3 - x1) + abs(y3 - y1)
219 | dis_3 = abs(x4 - x2) + abs(y4 - y2)
220 | return dis + min(dis_2, dis_3)
221 |
222 |
223 | def compute_iou(rec1, rec2):
224 | """
225 | computing IoU
226 | :param rec1: (y0, x0, y1, x1), which reflects
227 | (top, left, bottom, right)
228 | :param rec2: (y0, x0, y1, x1)
229 | :return: scala value of IoU
230 | """
231 | # computing area of each rectangles
232 | S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
233 | S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
234 |
235 | # computing the sum_area
236 | sum_area = S_rec1 + S_rec2
237 |
238 | # find the each edge of intersect rectangle
239 | left_line = max(rec1[1], rec2[1])
240 | right_line = min(rec1[3], rec2[3])
241 | top_line = max(rec1[0], rec2[0])
242 | bottom_line = min(rec1[2], rec2[2])
243 |
244 | # judge if there is an intersect
245 | if left_line >= right_line or top_line >= bottom_line:
246 | return 0.0
247 |
248 | intersect = (right_line - left_line) * (bottom_line - top_line)
249 | return (intersect / (sum_area - intersect)) * 1.0
250 |
--------------------------------------------------------------------------------
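A quick numeric check of the two matching metrics used by TableMatch.match_result (boxes are (x1, y1, x2, y2); the IoU code is symmetric in x and y, so axis order does not matter for this example):

from rapid_table.table_matcher.utils import compute_iou, distance

a = [0, 0, 2, 2]  # a 2x2 box at the origin
b = [1, 0, 3, 2]  # shifted right by 1, overlapping half of `a`
print(compute_iou(a, b))  # intersection 2 / union (4 + 4 - 2) -> 0.333...
print(distance(a, b))     # corner L1 sum 2 + min(top-left 1, bottom-right 1) -> 3
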
/rapid_table/table_structure/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 |
5 |
--------------------------------------------------------------------------------
/rapid_table/table_structure/pp_structure/__init__.py:
--------------------------------------------------------------------------------
1 | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from .main import PPTableStructurer
15 |
--------------------------------------------------------------------------------
/rapid_table/table_structure/pp_structure/main.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Any, Dict, List, Tuple
15 |
16 | import numpy as np
17 |
18 | from ...inference_engine.base import get_engine
19 | from ...utils.typings import EngineType
20 | from .post_process import TableLabelDecode
21 | from .pre_process import TablePreprocess
22 |
23 |
24 | class PPTableStructurer:
25 | def __init__(self, cfg: Dict[str, Any]):
26 | if cfg["engine_type"] is None:
27 | cfg["engine_type"] = EngineType.ONNXRUNTIME
28 |
29 | self.session = get_engine(cfg["engine_type"])(cfg)
30 | self.cfg = cfg
31 |
32 | self.preprocess_op = TablePreprocess()
33 |
34 | character = self.session.get_character_list()
35 | self.postprocess_op = TableLabelDecode(character, cfg)
36 |
37 | def __call__(
38 | self, ori_imgs: List[np.ndarray]
39 | ) -> Tuple[List[str], List[np.ndarray]]:
40 | imgs, shape_lists = self.preprocess_op(ori_imgs)
41 |
42 | bbox_preds, struct_probs = self.session(imgs.copy())
43 |
44 | table_structs, cell_bboxes = self.postprocess_op(
45 | bbox_preds, struct_probs, shape_lists, ori_imgs
46 | )
47 | return table_structs, cell_bboxes
48 |
--------------------------------------------------------------------------------
/rapid_table/table_structure/pp_structure/post_process.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from typing import List, Tuple
5 |
6 | import numpy as np
7 |
8 | from ...utils.typings import ModelType
9 | from ..utils import wrap_with_html_struct
10 |
11 |
12 | class TableLabelDecode:
13 | def __init__(self, dict_character, cfg, merge_no_span_structure=True):
14 | if merge_no_span_structure:
15 |             if "<td></td>" not in dict_character:
16 |                 dict_character.append("<td></td>")
17 |             if "<td>" in dict_character:
18 |                 dict_character.remove("<td>")
19 |
20 | dict_character = self.add_special_char(dict_character)
21 | self.char_to_index = {}
22 | for i, char in enumerate(dict_character):
23 | self.char_to_index[char] = i
24 |
25 | self.character = dict_character
26 |         self.td_token = ["<td>", "<td", "<td></td>"]
27 | self.cfg = cfg
28 |
29 | def __call__(
30 | self,
31 | bbox_preds: np.ndarray,
32 | structure_probs: np.ndarray,
33 | shape_list: np.ndarray,
34 | ori_imgs: np.ndarray,
35 | ):
36 | result = self.decode(bbox_preds, structure_probs, shape_list, ori_imgs)
37 | return result
38 |
39 | def decode(
40 | self,
41 | bbox_preds: np.ndarray,
42 | structure_probs: np.ndarray,
43 | shape_list: np.ndarray,
44 | ori_imgs: np.ndarray,
45 | ) -> Tuple[List[Tuple[List[str], float]], List[np.ndarray]]:
46 | """convert text-label into text-index."""
47 | ignored_tokens = self.get_ignored_tokens()
48 | end_idx = self.char_to_index[self.end_str]
49 |
50 | structure_idx = structure_probs.argmax(axis=2)
51 | structure_probs = structure_probs.max(axis=2)
52 |
53 | table_structs, cell_bboxes = [], []
54 | batch_size = len(structure_idx)
55 | for batch_idx in range(batch_size):
56 | structure_list, bbox_list, score_list = [], [], []
57 | for idx in range(len(structure_idx[batch_idx])):
58 | char_idx = int(structure_idx[batch_idx][idx])
59 | if idx > 0 and char_idx == end_idx:
60 | break
61 |
62 | if char_idx in ignored_tokens:
63 | continue
64 |
65 | text = self.character[char_idx]
66 | if text in self.td_token:
67 | bbox = bbox_preds[batch_idx, idx]
68 | bbox = self._bbox_decode(bbox, shape_list[batch_idx])
69 | bbox_list.append(bbox)
70 |
71 | structure_list.append(text)
72 | score_list.append(structure_probs[batch_idx, idx])
73 |
74 | bboxes = self.normalize_bboxes(bbox_list, ori_imgs[batch_idx])
75 | cell_bboxes.append(bboxes)
76 |
77 | table_structs.append(
78 | (wrap_with_html_struct(structure_list), float(np.mean(score_list)))
79 | )
80 | return table_structs, cell_bboxes
81 |
82 | def _bbox_decode(self, bbox, shape):
83 | h, w = shape[:2]
84 | bbox[0::2] *= w
85 | bbox[1::2] *= h
86 | return bbox
87 |
88 | def get_ignored_tokens(self):
89 | beg_idx = self.get_beg_end_flag_idx("beg")
90 | end_idx = self.get_beg_end_flag_idx("end")
91 | return [beg_idx, end_idx]
92 |
93 | def get_beg_end_flag_idx(self, beg_or_end):
94 | if beg_or_end == "beg":
95 | return np.array(self.char_to_index[self.beg_str])
96 |
97 | if beg_or_end == "end":
98 | return np.array(self.char_to_index[self.end_str])
99 |
100 | raise TypeError(f"unsupport type {beg_or_end} in get_beg_end_flag_idx")
101 |
102 | def add_special_char(self, dict_character):
103 | self.beg_str = "sos"
104 | self.end_str = "eos"
105 | dict_character = [self.beg_str] + dict_character + [self.end_str]
106 | return dict_character
107 |
108 | def normalize_bboxes(self, bbox_list, ori_imgs):
109 | cell_bboxes = np.array(bbox_list)
110 | if self.cfg["model_type"] == ModelType.SLANETPLUS:
111 | cell_bboxes = self.rescale_cell_bboxes(ori_imgs, cell_bboxes)
112 | cell_bboxes = self.filter_blank_bbox(cell_bboxes)
113 | return cell_bboxes
114 |
115 | def rescale_cell_bboxes(
116 | self, img: np.ndarray, cell_bboxes: np.ndarray
117 | ) -> np.ndarray:
118 | h, w = img.shape[:2]
119 | resized = 488
120 | ratio = min(resized / h, resized / w)
121 | w_ratio = resized / (w * ratio)
122 | h_ratio = resized / (h * ratio)
123 | cell_bboxes[:, 0::2] *= w_ratio
124 | cell_bboxes[:, 1::2] *= h_ratio
125 | return cell_bboxes
126 |
127 | @staticmethod
128 | def filter_blank_bbox(cell_bboxes: np.ndarray) -> np.ndarray:
129 | # 过滤掉占位的bbox
130 | mask = ~np.all(cell_bboxes == 0, axis=1)
131 | return cell_bboxes[mask]
132 |
--------------------------------------------------------------------------------
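rescale_cell_bboxes above compensates for SLANet-plus's 488x488 resize-and-pad: the correction factor is 1.0 along the axis that filled the 488 extent and greater than 1 along the padded axis. The arithmetic for a hypothetical 1000x600 (h x w) image:

h, w, resized = 1000, 600, 488
ratio = min(resized / h, resized / w)  # 0.488, the resize factor from pre-processing
w_ratio = resized / (w * ratio)        # 488 / 292.8 ~= 1.667 (w was the padded axis)
h_ratio = resized / (h * ratio)        # 488 / 488.0  = 1.0   (h filled the full extent)
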
/rapid_table/table_structure/pp_structure/pre_process.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from typing import List, Tuple
5 |
6 | import cv2
7 | import numpy as np
8 |
9 |
10 | class TablePreprocess:
11 | def __init__(self, max_len: int = 488):
12 | self.max_len = max_len
13 |
14 | self.std = np.array([0.229, 0.224, 0.225])
15 | self.mean = np.array([0.485, 0.456, 0.406])
16 | self.scale = 1 / 255.0
17 |
18 | def __call__(
19 | self, img_list: List[np.ndarray]
20 | ) -> Tuple[List[np.ndarray], np.ndarray]:
21 | if isinstance(img_list, np.ndarray):
22 | img_list = [img_list]
23 |
24 | processed_imgs, shape_lists = [], []
25 | for img in img_list:
26 | if img is None:
27 | continue
28 |
29 | img_processed, shape_list = self.resize_image(img)
30 | img_processed = self.normalize(img_processed)
31 | img_processed, shape_list = self.pad_img(img_processed, shape_list)
32 | img_processed = self.to_chw(img_processed)
33 |
34 | processed_imgs.append(img_processed)
35 | shape_lists.append(shape_list)
36 |
37 | return processed_imgs, np.array(shape_lists)
38 |
39 | def resize_image(self, img: np.ndarray) -> Tuple[np.ndarray, List[float]]:
40 | h, w = img.shape[:2]
41 | ratio = self.max_len / (max(h, w) * 1.0)
42 | resize_h, resize_w = int(h * ratio), int(w * ratio)
43 |
44 | resize_img = cv2.resize(img, (resize_w, resize_h))
45 | return resize_img, [h, w, ratio, ratio]
46 |
47 | def normalize(self, img: np.ndarray) -> np.ndarray:
48 | return (img.astype("float32") * self.scale - self.mean) / self.std
49 |
50 | def pad_img(
51 | self, img: np.ndarray, shape: List[float]
52 | ) -> Tuple[np.ndarray, List[float]]:
53 | padding_img = np.zeros((self.max_len, self.max_len, 3), dtype=np.float32)
54 | h, w = img.shape[:2]
55 | padding_img[:h, :w, :] = img.copy()
56 | shape.extend([self.max_len, self.max_len])
57 | return padding_img, shape
58 |
59 | def to_chw(self, img: np.ndarray) -> np.ndarray:
60 | return img.transpose((2, 0, 1))
61 |
--------------------------------------------------------------------------------
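A shape walk-through of TablePreprocess for a hypothetical 1000x600 (h x w) image: resize by 488/1000 to 488x292, normalize, zero-pad to 488x488, then transpose to CHW:

import numpy as np
from rapid_table.table_structure.pp_structure.pre_process import TablePreprocess

pre = TablePreprocess()
img = np.zeros((1000, 600, 3), dtype=np.uint8)
imgs, shape_lists = pre([img])
print(imgs[0].shape)   # (3, 488, 488) after the CHW transpose
print(shape_lists[0])  # approximately [1000, 600, 0.488, 0.488, 488, 488]
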
/rapid_table/table_structure/unitable/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from .main import UniTableStructure
5 |
--------------------------------------------------------------------------------
/rapid_table/table_structure/unitable/consts.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | IMG_SIZE = 448
5 | MAX_SEQ_LEN = 1024
6 | EOS_TOKEN = "<eos>"
7 | BBOX_TOKENS = [f"bbox-{i}" for i in range(IMG_SIZE + 1)]
8 | 
9 | HTML_BBOX_HTML_TOKENS = [
10 |     "<td></td>",
11 |     "<td>[",
12 |     "]</td>",
13 |     "<td",
14 |     ">[",
15 |     "]</td>",
16 |     "<tr>",
17 |     "</tr>",
18 |     "<tbody>",
19 |     "</tbody>",
20 |     "<thead>",
21 |     "</thead>",
23 | ' rowspan="3"',
24 | ' rowspan="4"',
25 | ' rowspan="5"',
26 | ' rowspan="6"',
27 | ' rowspan="7"',
28 | ' rowspan="8"',
29 | ' rowspan="9"',
30 | ' rowspan="10"',
31 | ' rowspan="11"',
32 | ' rowspan="12"',
33 | ' rowspan="13"',
34 | ' rowspan="14"',
35 | ' rowspan="15"',
36 | ' rowspan="16"',
37 | ' rowspan="17"',
38 | ' rowspan="18"',
39 | ' rowspan="19"',
40 | ' colspan="2"',
41 | ' colspan="3"',
42 | ' colspan="4"',
43 | ' colspan="5"',
44 | ' colspan="6"',
45 | ' colspan="7"',
46 | ' colspan="8"',
47 | ' colspan="9"',
48 | ' colspan="10"',
49 | ' colspan="11"',
50 | ' colspan="12"',
51 | ' colspan="13"',
52 | ' colspan="14"',
53 | ' colspan="15"',
54 | ' colspan="16"',
55 | ' colspan="17"',
56 | ' colspan="18"',
57 | ' colspan="19"',
58 | ' colspan="25"',
59 | ]
60 |
61 | VALID_HTML_BBOX_TOKENS = [EOS_TOKEN] + HTML_BBOX_HTML_TOKENS + BBOX_TOKENS
62 | TASK_TOKENS = [
63 | "[table]",
64 | "[html]",
65 | "[cell]",
66 | "[bbox]",
67 | "[cell+bbox]",
68 | "[html+bbox]",
69 | ]
70 |
--------------------------------------------------------------------------------
/rapid_table/table_structure/unitable/main.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | import re
3 | from typing import Any, Dict, List
4 |
5 | import numpy as np
6 | import torch
7 |
8 | from ...inference_engine.base import get_engine
9 | from ...utils import EngineType
10 | from ..utils import wrap_with_html_struct
11 | from .consts import (
12 | BBOX_TOKENS,
13 | EOS_TOKEN,
14 | MAX_SEQ_LEN,
15 | TASK_TOKENS,
16 | VALID_HTML_BBOX_TOKENS,
17 | )
18 | from .post_process import rescale_bboxes
19 | from .pre_process import TablePreprocess
20 |
21 |
22 | class UniTableStructure:
23 | def __init__(self, cfg: Dict[str, Any]):
24 | if cfg["engine_type"] is None:
25 | cfg["engine_type"] = EngineType.TORCH
26 | self.model = get_engine(cfg["engine_type"])(cfg)
27 |
28 | self.encoder = self.model.encoder
29 | self.device = self.model.device
30 |
31 | self.vocab = self.model.vocab
32 |
33 | self.token_white_list = [
34 | self.vocab.token_to_id(i) for i in VALID_HTML_BBOX_TOKENS
35 | ]
36 |
37 | self.bbox_token_ids = set(self.vocab.token_to_id(i) for i in BBOX_TOKENS)
38 | self.bbox_close_html_token = self.vocab.token_to_id("]")
39 |
40 | self.prefix_token_id = self.vocab.token_to_id("[html+bbox]")
41 |
42 | self.eos_id = self.vocab.token_to_id(EOS_TOKEN)
43 |
44 | self.context = (
45 | torch.tensor([self.prefix_token_id], dtype=torch.int32)
46 | .repeat(1, 1)
47 | .to(self.device)
48 | )
49 | self.eos_id_tensor = torch.tensor(self.eos_id, dtype=torch.int32).to(
50 | self.device
51 | )
52 |
53 | self.max_seq_len = MAX_SEQ_LEN
54 |
55 | self.decoder = self.model.decoder
56 |
57 | self.preprocess_op = TablePreprocess(self.device)
58 |
59 | def __call__(self, imgs: List[np.ndarray]):
60 | img_batch, ori_shapes = self.preprocess_op(imgs)
61 | memory_batch = self.encoder(img_batch)
62 |
63 | struct_list, total_bboxes = [], []
64 | for i, memory in enumerate(memory_batch):
65 | self.decoder.setup_caches(
66 | max_batch_size=1,
67 | max_seq_length=self.max_seq_len,
68 | dtype=torch.float32,
69 | device=self.device,
70 | )
71 |
72 | context = self.loop_decode(
73 | self.context, self.eos_id_tensor, memory[None, ...]
74 | )
75 | bboxes, html_tokens = self.decode_tokens(context)
76 |
77 | ori_h, ori_w = ori_shapes[i]
78 | one_bboxes = rescale_bboxes(ori_h, ori_w, bboxes)
79 | total_bboxes.append(one_bboxes)
80 |
81 | one_struct = wrap_with_html_struct(html_tokens)
82 | struct_list.append((one_struct, 1.0))
83 | return struct_list, total_bboxes
84 |
85 | def loop_decode(self, context, eos_id_tensor, memory):
86 | box_token_count = 0
87 | for _ in range(self.max_seq_len):
88 | eos_flag = (context == eos_id_tensor).any(dim=1)
89 | if torch.all(eos_flag):
90 | break
91 |
92 | next_tokens = self.decoder(memory, context)
93 | if next_tokens[0] in self.bbox_token_ids:
94 | box_token_count += 1
95 | if box_token_count > 4:
96 |                     next_tokens = torch.tensor(
97 |                         [[self.bbox_close_html_token]], dtype=torch.int32
98 |                     ).to(context.device)
99 | box_token_count = 0
100 | context = torch.cat([context, next_tokens], dim=1)
101 | return context
102 |
103 | def decode_tokens(self, context):
104 | pred_html = context[0]
105 | pred_html = pred_html.detach().cpu().numpy()
106 | pred_html = self.vocab.decode(pred_html, skip_special_tokens=False)
107 |         pred_html = pred_html.split("<eos>")[0]
108 |         token_black_list = ["<eos>", "<pad>", *TASK_TOKENS]
109 |         for i in token_black_list:
110 |             pred_html = pred_html.replace(i, "")
111 |
112 |         tr_pattern = re.compile(r"<tr>(.*?)</tr>", re.DOTALL)
113 |         td_pattern = re.compile(r"<td(.*?)>(.*?)</td>", re.DOTALL)
114 | bbox_pattern = re.compile(r"\[ bbox-(\d+) bbox-(\d+) bbox-(\d+) bbox-(\d+) \]")
115 |
116 | decoded_list, bbox_coords = [], []
117 |
118 |         # find all <tr> tags
119 |         for tr_match in tr_pattern.finditer(pred_html):
120 |             tr_content = tr_match.group(1)
121 |             decoded_list.append("<tr>")
122 | 
123 |             # find all <td> tags
124 |             for td_match in td_pattern.finditer(tr_content):
125 |                 td_attrs = td_match.group(1).strip()
126 |                 td_content = td_match.group(2).strip()
127 |                 if td_attrs:
128 |                     decoded_list.append("<td")
129 |                     attrs_list = td_attrs.split()
130 |                     for attr in attrs_list:
131 |                         decoded_list.append(" " + attr)
132 |                     decoded_list.append(">")
133 | 
134 |                     decoded_list.append("</td>")
135 |                 else:
136 |                     decoded_list.append("<td></td>")
137 | 
138 |                 # find the bbox coordinates
139 |                 bbox_match = bbox_pattern.search(td_content)
140 |                 if bbox_match:
141 |                     xmin, ymin, xmax, ymax = map(int, bbox_match.groups())
142 |                     # convert to a 4-point box, clockwise from the top-left to the bottom-left corner
143 |                     coords = np.array([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax])
144 |                     bbox_coords.append(coords)
145 |                 else:
146 |                     # pad with a placeholder bbox to keep the downstream flow uniform
147 |                     bbox_coords.append(np.array([0, 0, 0, 0, 0, 0, 0, 0]))
148 |             decoded_list.append("</tr>")
149 |
150 | bbox_coords_array = np.array(bbox_coords).astype(np.float32)
151 | return bbox_coords_array, decoded_list
152 |
--------------------------------------------------------------------------------
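The bbox grammar that decode_tokens parses, checked in isolation (the cell string is a hypothetical model output; coordinates stay in the 448-pixel model space until rescale_bboxes runs):

import re
import numpy as np

bbox_pattern = re.compile(r"\[ bbox-(\d+) bbox-(\d+) bbox-(\d+) bbox-(\d+) \]")
m = bbox_pattern.search("<td>[ bbox-10 bbox-20 bbox-110 bbox-60 ]</td>")
xmin, ymin, xmax, ymax = map(int, m.groups())
print(np.array([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]))
# [ 10  20 110  20 110  60  10  60], a clockwise 4-point box
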
/rapid_table/table_structure/unitable/post_process.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | import numpy as np
5 |
6 | from .consts import IMG_SIZE
7 |
8 |
9 | def rescale_bboxes(ori_h, ori_w, bboxes):
10 | scale_h = ori_h / IMG_SIZE
11 | scale_w = ori_w / IMG_SIZE
12 | bboxes[:, 0::2] *= scale_w
13 | bboxes[:, 1::2] *= scale_h
14 | bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, ori_w - 1)
15 | bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, ori_h - 1)
16 | return bboxes
17 |
--------------------------------------------------------------------------------
/rapid_table/table_structure/unitable/pre_process.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from typing import List
5 |
6 | import cv2
7 | import numpy as np
8 | import torch
9 | from PIL import Image
10 | from torchvision import transforms
11 |
12 | from .consts import IMG_SIZE
13 |
14 |
15 | class TablePreprocess:
16 | def __init__(self, device: str):
17 | self.img_size = IMG_SIZE
18 | self.transform = transforms.Compose(
19 | [
20 | transforms.Resize((448, 448)),
21 | transforms.ToTensor(),
22 | transforms.Normalize(
23 | mean=[0.86597056, 0.88463002, 0.87491087],
24 | std=[0.20686628, 0.18201602, 0.18485524],
25 | ),
26 | ]
27 | )
28 |
29 | self.device = device
30 |
31 | def __call__(self, imgs: List[np.ndarray]):
32 | processed_imgs, ori_shapes = [], []
33 | for img in imgs:
34 | if img is None:
35 | continue
36 |
37 | ori_h, ori_w = img.shape[:2]
38 | ori_shapes.append((ori_h, ori_w))
39 |
40 | processed_img = self.preprocess_img(img)
41 | processed_imgs.append(processed_img)
42 | return torch.concatenate(processed_imgs), ori_shapes
43 |
44 | def preprocess_img(self, ori_image: np.ndarray) -> torch.Tensor:
45 | image = cv2.cvtColor(ori_image, cv2.COLOR_BGR2RGB)
46 | image = Image.fromarray(image)
47 | image = self.transform(image).unsqueeze(0).to(self.device)
48 | return image
49 |
--------------------------------------------------------------------------------
/rapid_table/table_structure/unitable/unitable_modules.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from functools import partial
3 | from typing import Optional
4 |
5 | import torch
6 | from torch import Tensor, nn
7 | from torch.nn import functional as F
8 | from torch.nn.attention import SDPBackend, sdpa_kernel
9 | from torch.nn.modules.transformer import _get_activation_fn
10 |
11 | TOKEN_WHITE_LIST = [
12 | 1,
13 | 12,
14 | 13,
15 | 14,
16 | 15,
17 | 16,
18 | 17,
19 | 18,
20 | 19,
21 | 20,
22 | 21,
23 | 22,
24 | 23,
25 | 24,
26 | 25,
27 | 26,
28 | 27,
29 | 28,
30 | 29,
31 | 30,
32 | 31,
33 | 32,
34 | 33,
35 | 34,
36 | 35,
37 | 36,
38 | 37,
39 | 38,
40 | 39,
41 | 40,
42 | 41,
43 | 42,
44 | 43,
45 | 44,
46 | 45,
47 | 46,
48 | 47,
49 | 48,
50 | 49,
51 | 50,
52 | 51,
53 | 52,
54 | 53,
55 | 54,
56 | 55,
57 | 56,
58 | 57,
59 | 58,
60 | 59,
61 | 60,
62 | 61,
63 | 62,
64 | 63,
65 | 64,
66 | 65,
67 | 66,
68 | 67,
69 | 68,
70 | 69,
71 | 70,
72 | 71,
73 | 72,
74 | 73,
75 | 74,
76 | 75,
77 | 76,
78 | 77,
79 | 78,
80 | 79,
81 | 80,
82 | 81,
83 | 82,
84 | 83,
85 | 84,
86 | 85,
87 | 86,
88 | 87,
89 | 88,
90 | 89,
91 | 90,
92 | 91,
93 | 92,
94 | 93,
95 | 94,
96 | 95,
97 | 96,
98 | 97,
99 | 98,
100 | 99,
101 | 100,
102 | 101,
103 | 102,
104 | 103,
105 | 104,
106 | 105,
107 | 106,
108 | 107,
109 | 108,
110 | 109,
111 | 110,
112 | 111,
113 | 112,
114 | 113,
115 | 114,
116 | 115,
117 | 116,
118 | 117,
119 | 118,
120 | 119,
121 | 120,
122 | 121,
123 | 122,
124 | 123,
125 | 124,
126 | 125,
127 | 126,
128 | 127,
129 | 128,
130 | 129,
131 | 130,
132 | 131,
133 | 132,
134 | 133,
135 | 134,
136 | 135,
137 | 136,
138 | 137,
139 | 138,
140 | 139,
141 | 140,
142 | 141,
143 | 142,
144 | 143,
145 | 144,
146 | 145,
147 | 146,
148 | 147,
149 | 148,
150 | 149,
151 | 150,
152 | 151,
153 | 152,
154 | 153,
155 | 154,
156 | 155,
157 | 156,
158 | 157,
159 | 158,
160 | 159,
161 | 160,
162 | 161,
163 | 162,
164 | 163,
165 | 164,
166 | 165,
167 | 166,
168 | 167,
169 | 168,
170 | 169,
171 | 170,
172 | 171,
173 | 172,
174 | 173,
175 | 174,
176 | 175,
177 | 176,
178 | 177,
179 | 178,
180 | 179,
181 | 180,
182 | 181,
183 | 182,
184 | 183,
185 | 184,
186 | 185,
187 | 186,
188 | 187,
189 | 188,
190 | 189,
191 | 190,
192 | 191,
193 | 192,
194 | 193,
195 | 194,
196 | 195,
197 | 196,
198 | 197,
199 | 198,
200 | 199,
201 | 200,
202 | 201,
203 | 202,
204 | 203,
205 | 204,
206 | 205,
207 | 206,
208 | 207,
209 | 208,
210 | 209,
211 | 210,
212 | 211,
213 | 212,
214 | 213,
215 | 214,
216 | 215,
217 | 216,
218 | 217,
219 | 218,
220 | 219,
221 | 220,
222 | 221,
223 | 222,
224 | 223,
225 | 224,
226 | 225,
227 | 226,
228 | 227,
229 | 228,
230 | 229,
231 | 230,
232 | 231,
233 | 232,
234 | 233,
235 | 234,
236 | 235,
237 | 236,
238 | 237,
239 | 238,
240 | 239,
241 | 240,
242 | 241,
243 | 242,
244 | 243,
245 | 244,
246 | 245,
247 | 246,
248 | 247,
249 | 248,
250 | 249,
251 | 250,
252 | 251,
253 | 252,
254 | 253,
255 | 254,
256 | 255,
257 | 256,
258 | 257,
259 | 258,
260 | 259,
261 | 260,
262 | 261,
263 | 262,
264 | 263,
265 | 264,
266 | 265,
267 | 266,
268 | 267,
269 | 268,
270 | 269,
271 | 270,
272 | 271,
273 | 272,
274 | 273,
275 | 274,
276 | 275,
277 | 276,
278 | 277,
279 | 278,
280 | 279,
281 | 280,
282 | 281,
283 | 282,
284 | 283,
285 | 284,
286 | 285,
287 | 286,
288 | 287,
289 | 288,
290 | 289,
291 | 290,
292 | 291,
293 | 292,
294 | 293,
295 | 294,
296 | 295,
297 | 296,
298 | 297,
299 | 298,
300 | 299,
301 | 300,
302 | 301,
303 | 302,
304 | 303,
305 | 304,
306 | 305,
307 | 306,
308 | 307,
309 | 308,
310 | 309,
311 | 310,
312 | 311,
313 | 312,
314 | 313,
315 | 314,
316 | 315,
317 | 316,
318 | 317,
319 | 318,
320 | 319,
321 | 320,
322 | 321,
323 | 322,
324 | 323,
325 | 324,
326 | 325,
327 | 326,
328 | 327,
329 | 328,
330 | 329,
331 | 330,
332 | 331,
333 | 332,
334 | 333,
335 | 334,
336 | 335,
337 | 336,
338 | 337,
339 | 338,
340 | 339,
341 | 340,
342 | 341,
343 | 342,
344 | 343,
345 | 344,
346 | 345,
347 | 346,
348 | 347,
349 | 348,
350 | 349,
351 | 350,
352 | 351,
353 | 352,
354 | 353,
355 | 354,
356 | 355,
357 | 356,
358 | 357,
359 | 358,
360 | 359,
361 | 360,
362 | 361,
363 | 362,
364 | 363,
365 | 364,
366 | 365,
367 | 366,
368 | 367,
369 | 368,
370 | 369,
371 | 370,
372 | 371,
373 | 372,
374 | 373,
375 | 374,
376 | 375,
377 | 376,
378 | 377,
379 | 378,
380 | 379,
381 | 380,
382 | 381,
383 | 382,
384 | 383,
385 | 384,
386 | 385,
387 | 386,
388 | 387,
389 | 388,
390 | 389,
391 | 390,
392 | 391,
393 | 392,
394 | 393,
395 | 394,
396 | 395,
397 | 396,
398 | 397,
399 | 398,
400 | 399,
401 | 400,
402 | 401,
403 | 402,
404 | 403,
405 | 404,
406 | 405,
407 | 406,
408 | 407,
409 | 408,
410 | 409,
411 | 410,
412 | 411,
413 | 412,
414 | 413,
415 | 414,
416 | 415,
417 | 416,
418 | 417,
419 | 418,
420 | 419,
421 | 420,
422 | 421,
423 | 422,
424 | 423,
425 | 424,
426 | 425,
427 | 426,
428 | 427,
429 | 428,
430 | 429,
431 | 430,
432 | 431,
433 | 432,
434 | 433,
435 | 434,
436 | 435,
437 | 436,
438 | 437,
439 | 438,
440 | 439,
441 | 440,
442 | 441,
443 | 442,
444 | 443,
445 | 444,
446 | 445,
447 | 446,
448 | 447,
449 | 448,
450 | 449,
451 | 450,
452 | 451,
453 | 452,
454 | 453,
455 | 454,
456 | 455,
457 | 456,
458 | 457,
459 | 458,
460 | 459,
461 | 460,
462 | 461,
463 | 462,
464 | 463,
465 | 464,
466 | 465,
467 | 466,
468 | 467,
469 | 468,
470 | 469,
471 | 470,
472 | 471,
473 | 472,
474 | 473,
475 | 474,
476 | 475,
477 | 476,
478 | 477,
479 | 478,
480 | 479,
481 | 480,
482 | 481,
483 | 482,
484 | 483,
485 | 484,
486 | 485,
487 | 486,
488 | 487,
489 | 488,
490 | 489,
491 | 490,
492 | 491,
493 | 492,
494 | 493,
495 | 494,
496 | 495,
497 | 496,
498 | 497,
499 | 498,
500 | 499,
501 | 500,
502 | 501,
503 | 502,
504 | 503,
505 | 504,
506 | 505,
507 | 506,
508 | 507,
509 | 508,
510 | 509,
511 | ]
512 |
513 |
514 | class ImgLinearBackbone(nn.Module):
515 | def __init__(
516 | self,
517 | d_model: int,
518 | patch_size: int,
519 | in_chan: int = 3,
520 | ) -> None:
521 | super().__init__()
522 |
523 | self.conv_proj = nn.Conv2d(
524 | in_chan,
525 | out_channels=d_model,
526 | kernel_size=patch_size,
527 | stride=patch_size,
528 | )
529 | self.d_model = d_model
530 |
531 | def forward(self, x: Tensor) -> Tensor:
532 | x = self.conv_proj(x)
533 | x = x.flatten(start_dim=-2).transpose(1, 2)
534 | return x
535 |
536 |
537 | class Encoder(nn.Module):
538 | def __init__(self) -> None:
539 | super().__init__()
540 |
541 | self.patch_size = 16
542 | self.d_model = 768
543 | self.dropout = 0
544 | self.activation = "gelu"
545 | self.norm_first = True
546 | self.ff_ratio = 4
547 | self.nhead = 12
548 | self.max_seq_len = 1024
549 | self.n_encoder_layer = 12
550 | encoder_layer = nn.TransformerEncoderLayer(
551 | self.d_model,
552 | nhead=self.nhead,
553 | dim_feedforward=self.ff_ratio * self.d_model,
554 | dropout=self.dropout,
555 | activation=self.activation,
556 | batch_first=True,
557 | norm_first=self.norm_first,
558 | )
559 | norm_layer = partial(nn.LayerNorm, eps=1e-6)
560 | self.norm = norm_layer(self.d_model)
561 | self.backbone = ImgLinearBackbone(
562 | d_model=self.d_model, patch_size=self.patch_size
563 | )
564 | self.pos_embed = PositionEmbedding(
565 | max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout
566 | )
567 | self.encoder = nn.TransformerEncoder(
568 | encoder_layer, num_layers=self.n_encoder_layer, enable_nested_tensor=False
569 | )
570 |
571 | def forward(self, x: Tensor) -> Tensor:
572 | src_feature = self.backbone(x)
573 | src_feature = self.pos_embed(src_feature)
574 | memory = self.encoder(src_feature)
575 | memory = self.norm(memory)
576 | return memory
577 |
578 |
579 | class PositionEmbedding(nn.Module):
580 | def __init__(self, max_seq_len: int, d_model: int, dropout: float) -> None:
581 | super().__init__()
582 | self.embedding = nn.Embedding(max_seq_len, d_model)
583 | self.dropout = nn.Dropout(dropout)
584 |
585 | def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
586 | # assume x is batch first
587 | if input_pos is None:
588 | _pos = torch.arange(x.shape[1], device=x.device)
589 | else:
590 | _pos = input_pos
591 | out = self.embedding(_pos)
592 | return self.dropout(out + x)
593 |
594 |
595 | class TokenEmbedding(nn.Module):
596 | def __init__(
597 | self,
598 | vocab_size: int,
599 | d_model: int,
600 | padding_idx: int,
601 | ) -> None:
602 | super().__init__()
603 | assert vocab_size > 0
604 | self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
605 |
606 | def forward(self, x: Tensor) -> Tensor:
607 | return self.embedding(x)
608 |
609 |
610 | def find_multiple(n: int, k: int) -> int:
611 | if n % k == 0:
612 | return n
613 | return n + k - (n % k)
614 |
615 |
616 | @dataclass
617 | class ModelArgs:
618 | n_layer: int = 4
619 | n_head: int = 12
620 | dim: int = 768
621 | intermediate_size: int = None
622 | head_dim: int = 64
623 | activation: str = "gelu"
624 | norm_first: bool = True
625 |
626 | def __post_init__(self):
627 | if self.intermediate_size is None:
628 | hidden_dim = 4 * self.dim
629 | n_hidden = int(2 * hidden_dim / 3)
630 | self.intermediate_size = find_multiple(n_hidden, 256)
631 | self.head_dim = self.dim // self.n_head
632 |
633 |
634 | class KVCache(nn.Module):
635 | def __init__(
636 | self,
637 | max_batch_size,
638 | max_seq_length,
639 | n_heads,
640 | head_dim,
641 | dtype=torch.bfloat16,
642 | device="cpu",
643 | ):
644 | super().__init__()
645 | cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
646 | self.register_buffer(
647 | "k_cache",
648 | torch.zeros(cache_shape, dtype=dtype, device=device),
649 | persistent=False,
650 | )
651 | self.register_buffer(
652 | "v_cache",
653 | torch.zeros(cache_shape, dtype=dtype, device=device),
654 | persistent=False,
655 | )
656 |
657 | def update(self, input_pos, k_val, v_val):
658 | bs = k_val.shape[0]
659 | k_out = self.k_cache
660 | v_out = self.v_cache
661 | k_out[:bs, :, input_pos] = k_val
662 | v_out[:bs, :, input_pos] = v_val
663 |
664 | return k_out[:bs], v_out[:bs]
665 |
666 |
667 | class GPTFastDecoder(nn.Module):
668 | def __init__(self) -> None:
669 | super().__init__()
670 |
671 | self.vocab_size = 960
672 | self.padding_idx = 2
673 | self.prefix_token_id = 11
674 | self.eos_id = 1
675 | self.max_seq_len = 1024
676 | self.dropout = 0
677 | self.d_model = 768
678 | self.nhead = 12
679 | self.activation = "gelu"
680 | self.norm_first = True
681 | self.n_decoder_layer = 4
682 | config = ModelArgs(
683 | n_layer=self.n_decoder_layer,
684 | n_head=self.nhead,
685 | dim=self.d_model,
686 | intermediate_size=self.d_model * 4,
687 | activation=self.activation,
688 | norm_first=self.norm_first,
689 | )
690 | self.config = config
691 | self.layers = nn.ModuleList(
692 | TransformerBlock(config) for _ in range(config.n_layer)
693 | )
694 | self.token_embed = TokenEmbedding(
695 | vocab_size=self.vocab_size,
696 | d_model=self.d_model,
697 | padding_idx=self.padding_idx,
698 | )
699 | self.pos_embed = PositionEmbedding(
700 | max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout
701 | )
702 | self.generator = nn.Linear(self.d_model, self.vocab_size)
703 | self.token_white_list = TOKEN_WHITE_LIST
704 | self.mask_cache: Optional[Tensor] = None
705 | self.max_batch_size = -1
706 | self.max_seq_length = -1
707 |
708 | def setup_caches(self, max_batch_size, max_seq_length, dtype, device):
709 | for b in self.layers:
710 | b.multihead_attn.k_cache = None
711 | b.multihead_attn.v_cache = None
712 |
713 | if (
714 | self.max_seq_length >= max_seq_length
715 | and self.max_batch_size >= max_batch_size
716 | ):
717 | return
718 | head_dim = self.config.dim // self.config.n_head
719 | max_seq_length = find_multiple(max_seq_length, 8)
720 | self.max_seq_length = max_seq_length
721 | self.max_batch_size = max_batch_size
722 |
723 | for b in self.layers:
724 | b.self_attn.kv_cache = KVCache(
725 | max_batch_size,
726 | max_seq_length,
727 | self.config.n_head,
728 | head_dim,
729 | dtype,
730 | device,
731 | )
732 | b.multihead_attn.k_cache = None
733 | b.multihead_attn.v_cache = None
734 |
735 | self.causal_mask = torch.tril(
736 | torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
737 | ).to(device)
738 |
739 | def forward(self, memory: Tensor, tgt: Tensor) -> Tensor:
740 | input_pos = torch.tensor([tgt.shape[1] - 1], device=tgt.device, dtype=torch.int)
741 | tgt = tgt[:, -1:]
742 | tgt_feature = self.pos_embed(self.token_embed(tgt), input_pos=input_pos)
743 |
744 | with sdpa_kernel(SDPBackend.MATH):
745 | logits = tgt_feature
746 | tgt_mask = self.causal_mask[None, None, input_pos]
747 | for layer in self.layers:
748 | logits = layer(logits, memory, input_pos=input_pos, tgt_mask=tgt_mask)
749 |
750 | logits = self.generator(logits)[:, -1, :]
751 | total = set(range(logits.shape[-1]))
752 | black_list = list(total.difference(set(self.token_white_list)))
753 | logits[..., black_list] = -1e9
754 | probs = F.softmax(logits, dim=-1)
755 | _, next_tokens = probs.topk(1)
756 | return next_tokens
757 |
758 |
759 | class TransformerBlock(nn.Module):
760 | def __init__(self, config: ModelArgs) -> None:
761 | super().__init__()
762 | self.self_attn = Attention(config)
763 | self.multihead_attn = CrossAttention(config)
764 |
765 | layer_norm_eps = 1e-5
766 |
767 | d_model = config.dim
768 | dim_feedforward = config.intermediate_size
769 |
770 | self.linear1 = nn.Linear(d_model, dim_feedforward)
771 | self.linear2 = nn.Linear(dim_feedforward, d_model)
772 |
773 | self.norm_first = config.norm_first
774 | self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
775 | self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
776 | self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
777 |
778 | self.activation = _get_activation_fn(config.activation)
779 |
780 | def forward(
781 | self,
782 | tgt: Tensor,
783 | memory: Tensor,
784 | tgt_mask: Tensor,
785 | input_pos: Tensor,
786 | ) -> Tensor:
787 | if self.norm_first:
788 | x = tgt
789 | x = x + self.self_attn(self.norm1(x), tgt_mask, input_pos)
790 | x = x + self.multihead_attn(self.norm2(x), memory)
791 | x = x + self._ff_block(self.norm3(x))
792 | else:
793 | x = tgt
794 | x = self.norm1(x + self.self_attn(x, tgt_mask, input_pos))
795 | x = self.norm2(x + self.multihead_attn(x, memory))
796 | x = self.norm3(x + self._ff_block(x))
797 | return x
798 |
799 | def _ff_block(self, x: Tensor) -> Tensor:
800 | x = self.linear2(self.activation(self.linear1(x)))
801 | return x
802 |
803 |
804 | class Attention(nn.Module):
805 | def __init__(self, config: ModelArgs):
806 | super().__init__()
807 | assert config.dim % config.n_head == 0
808 |
809 | # key, query, value projections for all heads, but in a batch
810 | self.wqkv = nn.Linear(config.dim, 3 * config.dim)
811 | self.wo = nn.Linear(config.dim, config.dim)
812 |
813 | self.kv_cache: Optional[KVCache] = None
814 |
815 | self.n_head = config.n_head
816 | self.head_dim = config.head_dim
817 | self.dim = config.dim
818 |
819 | def forward(
820 | self,
821 | x: Tensor,
822 | mask: Tensor,
823 | input_pos: Optional[Tensor] = None,
824 | ) -> Tensor:
825 | bsz, seqlen, _ = x.shape
826 |
827 | kv_size = self.n_head * self.head_dim
828 | q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
829 |
830 | q = q.view(bsz, seqlen, self.n_head, self.head_dim)
831 | k = k.view(bsz, seqlen, self.n_head, self.head_dim)
832 | v = v.view(bsz, seqlen, self.n_head, self.head_dim)
833 |
834 | q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
835 |
836 | if self.kv_cache is not None:
837 | k, v = self.kv_cache.update(input_pos, k, v)
838 |
839 | y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
840 |
841 | y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
842 |
843 | y = self.wo(y)
844 | return y
845 |
846 |
847 | class CrossAttention(nn.Module):
848 | def __init__(self, config: ModelArgs):
849 | super().__init__()
850 | assert config.dim % config.n_head == 0
851 |
852 | self.query = nn.Linear(config.dim, config.dim)
853 | self.key = nn.Linear(config.dim, config.dim)
854 | self.value = nn.Linear(config.dim, config.dim)
855 | self.out = nn.Linear(config.dim, config.dim)
856 |
857 | self.k_cache = None
858 | self.v_cache = None
859 |
860 | self.n_head = config.n_head
861 | self.head_dim = config.head_dim
862 |
863 | def get_kv(self, xa: torch.Tensor):
864 | if self.k_cache is not None and self.v_cache is not None:
865 | return self.k_cache, self.v_cache
866 |
867 | k = self.key(xa)
868 | v = self.value(xa)
869 |
870 | # Reshape for correct format
871 | batch_size, source_seq_len, _ = k.shape
872 | k = k.view(batch_size, source_seq_len, self.n_head, self.head_dim)
873 | v = v.view(batch_size, source_seq_len, self.n_head, self.head_dim)
874 |
875 | if self.k_cache is None:
876 | self.k_cache = k
877 |
878 | if self.v_cache is None:
879 | self.v_cache = v
880 |
881 | return k, v
882 |
883 | def forward(
884 | self,
885 | x: Tensor,
886 | xa: Tensor,
887 | ):
888 | q = self.query(x)
889 | batch_size, target_seq_len, _ = q.shape
890 | q = q.view(batch_size, target_seq_len, self.n_head, self.head_dim)
891 | k, v = self.get_kv(xa)
892 |
893 | q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
894 |
895 | wv = F.scaled_dot_product_attention(
896 | query=q,
897 | key=k,
898 | value=v,
899 | is_causal=False,
900 | )
901 | wv = wv.transpose(1, 2).reshape(
902 | batch_size,
903 | target_seq_len,
904 | self.n_head * self.head_dim,
905 | )
906 |
907 | return self.out(wv)
908 |
--------------------------------------------------------------------------------
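A minimal driver sketch for the fast decoder above, assuming it runs inside this module so GPTFastDecoder and its helpers are in scope; the encoder output `memory` and its 196-token length are illustrative stand-ins, not values taken from this file:

import torch

# Set up the KV caches once, then feed the growing token sequence step by step.
# forward() embeds only the last token and returns the greedy next-token ids.
decoder = GPTFastDecoder()
decoder.eval()

memory = torch.randn(1, 196, 768)  # stand-in for the UniTable encoder output
tokens = torch.full((1, 1), decoder.prefix_token_id, dtype=torch.long)

with torch.no_grad():
    decoder.setup_caches(
        max_batch_size=1,
        max_seq_length=decoder.max_seq_len,
        dtype=torch.float32,
        device="cpu",
    )
    for _ in range(decoder.max_seq_len - 1):
        next_token = decoder(memory, tokens)  # shape [1, 1], greedy pick
        tokens = torch.cat([tokens, next_token], dim=1)
        if next_token.item() == decoder.eos_id:
            break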
/rapid_table/table_structure/utils.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from typing import List
5 |
6 |
7 | def wrap_with_html_struct(structure_str_list: List[str]) -> List[str]:
8 | structure_str_list = (
9 | ["", "", ""]
10 | + structure_str_list
11 | + ["
", "", ""]
12 | )
13 | return structure_str_list
14 |
--------------------------------------------------------------------------------
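Joined, the wrapped token list forms a complete HTML document; a small illustration (the cell tokens are made up):

tokens = ["<tr>", "<td>", "</td>", "</tr>"]
html = "".join(wrap_with_html_struct(tokens))
# '<html><body><table><tr><td></td></tr></table></body></html>'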
/rapid_table/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from .download_file import DownloadFile, DownloadFileInput
5 | from .load_image import InputType, LoadImage
6 | from .logger import Logger
7 | from .typings import EngineType, ModelType, RapidTableInput, RapidTableOutput
8 | from .utils import format_ocr_results, import_package, is_url, mkdir, read_yaml
9 | from .vis import VisTable
10 |
--------------------------------------------------------------------------------
/rapid_table/utils/download_file.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | import logging
5 | import sys
6 | from dataclasses import dataclass
7 | from pathlib import Path
8 | from typing import Optional, Union
9 |
10 | import requests
11 | from tqdm import tqdm
12 |
13 | from .utils import get_file_sha256
14 |
15 |
16 | @dataclass
17 | class DownloadFileInput:
18 | file_url: str
19 | save_path: Union[str, Path]
20 | logger: logging.Logger
21 | sha256: Optional[str] = None
22 |
23 |
24 | class DownloadFile:
25 | BLOCK_SIZE = 1024 # 1 KiB
26 | REQUEST_TIMEOUT = 60
27 |
28 | @classmethod
29 | def run(cls, input_params: DownloadFileInput):
30 | save_path = Path(input_params.save_path)
31 |
32 | logger = input_params.logger
33 | cls._ensure_parent_dir_exists(save_path)
34 | if cls._should_skip_download(save_path, input_params.sha256, logger):
35 | return
36 |
37 | response = cls._make_http_request(input_params.file_url, logger)
38 | cls._save_response_with_progress(response, save_path, logger)
39 |
40 | @staticmethod
41 | def _ensure_parent_dir_exists(path: Path):
42 | path.parent.mkdir(parents=True, exist_ok=True)
43 |
44 | @classmethod
45 | def _should_skip_download(
46 | cls, path: Path, expected_sha256: Optional[str], logger: logging.Logger
47 | ) -> bool:
48 | if not path.exists():
49 | return False
50 |
51 | if expected_sha256 is None:
52 | logger.info("File exists (no checksum verification): %s", path)
53 | return True
54 |
55 | if cls.check_file_sha256(path, expected_sha256):
56 | logger.info("File exists and is valid: %s", path)
57 | return True
58 |
59 | logger.warning("File exists but is invalid, redownloading: %s", path)
60 | return False
61 |
62 | @classmethod
63 | def _make_http_request(cls, url: str, logger: logging.Logger) -> requests.Response:
64 | logger.info("Initiating download: %s", url)
65 | try:
66 | response = requests.get(url, stream=True, timeout=cls.REQUEST_TIMEOUT)
67 | response.raise_for_status() # Raises HTTPError for 4XX/5XX
68 | return response
69 | except requests.RequestException as e:
70 | logger.error("Download failed: %s", url)
71 | raise DownloadFileException(f"Failed to download {url}") from e
72 |
73 | @classmethod
74 | def _save_response_with_progress(
75 | cls, response: requests.Response, save_path: Path, logger: logging.Logger
76 | ) -> None:
77 | total_size = int(response.headers.get("content-length", 0))
78 | logger.info("Download size: %.2fMB", total_size / 1024 / 1024)
79 |
80 | with tqdm(
81 | total=total_size,
82 | unit="iB",
83 | unit_scale=True,
84 | disable=not cls.check_is_atty(),
85 | ) as progress_bar:
86 | with open(save_path, "wb") as output_file:
87 | for chunk in response.iter_content(chunk_size=cls.BLOCK_SIZE):
88 | progress_bar.update(len(chunk))
89 | output_file.write(chunk)
90 |
91 | logger.info("Successfully saved to: %s", save_path)
92 |
93 | @staticmethod
94 | def check_file_sha256(file_path: Union[str, Path], gt_sha256: str) -> bool:
95 | return get_file_sha256(file_path) == gt_sha256
96 |
97 | @staticmethod
98 | def check_is_atty() -> bool:
99 | try:
100 | is_interactive = sys.stderr.isatty()
101 | except AttributeError:
102 | return False
103 | return is_interactive
104 |
105 |
106 | class DownloadFileException(Exception):
107 | pass
108 |
--------------------------------------------------------------------------------
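A minimal usage sketch for the downloader; the URL, path, and checksum below are placeholders rather than real model artifacts:

import logging

params = DownloadFileInput(
    file_url="https://example.com/models/slanet_plus.onnx",  # placeholder URL
    save_path="models/slanet_plus.onnx",                     # placeholder path
    logger=logging.getLogger(__name__),
    sha256=None,  # supply the expected digest to enable checksum verification
)
DownloadFile.run(params)  # skips the download when a valid local copy exists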
/rapid_table/utils/load_image.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from io import BytesIO
5 | from pathlib import Path
6 | from typing import Any, Union
7 |
8 | import cv2
9 | import numpy as np
10 | import requests
11 | from PIL import Image, UnidentifiedImageError
12 |
13 | from .utils import is_url
14 |
15 | root_dir = Path(__file__).resolve().parent
16 | InputType = Union[str, np.ndarray, bytes, Path, Image.Image]
17 |
18 |
19 | class LoadImage:
20 | def __init__(self):
21 | pass
22 |
23 | def __call__(self, img: InputType) -> np.ndarray:
24 | if not isinstance(img, InputType.__args__):
25 | raise LoadImageError(
26 | f"The img type {type(img)} does not in {InputType.__args__}"
27 | )
28 | origin_img_type = type(img)
29 | img = self.load_img(img)
30 | img = self.convert_img(img, origin_img_type)
31 | return img
32 |
33 | def load_img(self, img: InputType) -> np.ndarray:
34 | if isinstance(img, (str, Path)):
35 | if is_url(img):
36 | img = Image.open(requests.get(img, stream=True, timeout=60).raw)
37 | else:
38 | self.verify_exist(img)
39 | img = Image.open(img)
40 |
41 | try:
42 | img = self.img_to_ndarray(img)
43 | except UnidentifiedImageError as e:
44 | raise LoadImageError(f"cannot identify image file {img}") from e
45 | return img
46 |
47 | if isinstance(img, bytes):
48 | img = self.img_to_ndarray(Image.open(BytesIO(img)))
49 | return img
50 |
51 | if isinstance(img, np.ndarray):
52 | return img
53 |
54 | if isinstance(img, Image.Image):
55 | return self.img_to_ndarray(img)
56 |
57 | raise LoadImageError(f"{type(img)} is not supported!")
58 |
59 | def img_to_ndarray(self, img: Image.Image) -> np.ndarray:
60 | if img.mode == "1":
61 | img = img.convert("L")
62 | # a single conversion handles both the binary and the other modes
63 | return np.array(img)
64 |
65 | def convert_img(self, img: np.ndarray, origin_img_type: Any) -> np.ndarray:
66 | if img.ndim == 2:
67 | return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
68 |
69 | if img.ndim == 3:
70 | channel = img.shape[2]
71 | if channel == 1:
72 | return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
73 |
74 | if channel == 2:
75 | return self.cvt_two_to_three(img)
76 |
77 | if channel == 3:
78 | if issubclass(origin_img_type, (str, Path, bytes, Image.Image)):
79 | return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
80 | return img
81 |
82 | if channel == 4:
83 | return self.cvt_four_to_three(img)
84 |
85 | raise LoadImageError(
86 | f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
87 | )
88 |
89 | raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")
90 |
91 | @staticmethod
92 | def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
93 | """gray + alpha → BGR"""
94 | img_gray = img[..., 0]
95 | img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)
96 |
97 | img_alpha = img[..., 1]
98 | not_a = cv2.bitwise_not(img_alpha)
99 | not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
100 |
101 | new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
102 | new_img = cv2.add(new_img, not_a)
103 | return new_img
104 |
105 | @staticmethod
106 | def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
107 | """RGBA → BGR"""
108 | r, g, b, a = cv2.split(img)
109 | new_img = cv2.merge((b, g, r))
110 |
111 | not_a = cv2.bitwise_not(a)
112 | not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
113 |
114 | new_img = cv2.bitwise_and(new_img, new_img, mask=a)
115 |
116 | mean_color = np.mean(new_img)
117 | if mean_color <= 0.0:
118 | new_img = cv2.add(new_img, not_a)
119 | else:
120 | new_img = cv2.bitwise_not(new_img)
121 | return new_img
122 |
123 | @staticmethod
124 | def verify_exist(file_path: Union[str, Path]):
125 | if not Path(file_path).exists():
126 | raise LoadImageError(f"{file_path} does not exist.")
127 |
128 |
129 | class LoadImageError(Exception):
130 | pass
131 |
--------------------------------------------------------------------------------
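Whatever the input type, LoadImage normalizes to a 3-channel BGR ndarray; a quick sketch (the grayscale array is synthetic, and "table.jpg" is assumed to exist locally):

import numpy as np

load_img = LoadImage()

img_from_path = load_img("table.jpg")  # str / Path / bytes / PIL all work
img_from_array = load_img(np.zeros((4, 4), dtype=np.uint8))  # grayscale -> BGR
assert img_from_array.shape == (4, 4, 3)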
/rapid_table/utils/logger.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: Jocker1212
3 | # @Contact: xinyijianggo@gmail.com
4 | import logging
5 |
6 | import colorlog
7 |
8 |
9 | class Logger:
10 | def __init__(self, log_level=logging.DEBUG, logger_name=None):
11 | self.logger = logging.getLogger(logger_name)
12 | self.logger.setLevel(log_level)
13 | self.logger.propagate = False
14 |
15 | formatter = colorlog.ColoredFormatter(
16 | "%(log_color)s[%(levelname)s] %(asctime)s [RapidTable] %(filename)s:%(lineno)d: %(message)s",
17 | log_colors={
18 | "DEBUG": "cyan",
19 | "INFO": "green",
20 | "WARNING": "yellow",
21 | "ERROR": "red",
22 | "CRITICAL": "red,bg_white",
23 | },
24 | )
25 |
26 | if not self.logger.handlers:
27 | console_handler = logging.StreamHandler()
28 | console_handler.setFormatter(formatter)
29 |
30 | for handler in self.logger.handlers:
31 | self.logger.removeHandler(handler)
32 |
33 | console_handler.setLevel(log_level)
34 | self.logger.addHandler(console_handler)
35 |
36 | def get_log(self):
37 | return self.logger
38 |
--------------------------------------------------------------------------------
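Typical use is one colored console logger per module:

logger = Logger(logger_name=__name__).get_log()
logger.info("engine initialized")  # green "[INFO] ... [RapidTable] ..." line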
/rapid_table/utils/typings.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from dataclasses import dataclass, field
5 | from enum import Enum
6 | from pathlib import Path
7 | from typing import Dict, List, Optional, Tuple, Union
8 |
9 | import numpy as np
10 |
11 | from .utils import mkdir
12 | from .vis import VisTable
13 |
14 |
15 | class EngineType(Enum):
16 | ONNXRUNTIME = "onnxruntime"
17 | TORCH = "torch"
18 |
19 |
20 | class ModelType(Enum):
21 | PPSTRUCTURE_EN = "ppstructure_en"
22 | PPSTRUCTURE_ZH = "ppstructure_zh"
23 | SLANETPLUS = "slanet_plus"
24 | UNITABLE = "unitable"
25 |
26 |
27 | @dataclass
28 | class RapidTableInput:
29 | model_type: Optional[ModelType] = ModelType.SLANETPLUS
30 | model_dir_or_path: Union[str, Path, None, Dict[str, str]] = None
31 |
32 | engine_type: Optional[EngineType] = None
33 | engine_cfg: dict = field(default_factory=dict)
34 |
35 | use_ocr: bool = True
36 | ocr_params: dict = field(default_factory=dict)
37 |
38 |
39 | @dataclass
40 | class RapidTableOutput:
41 | imgs: List[np.ndarray] = field(default_factory=list)
42 | pred_htmls: List[str] = field(default_factory=list)
43 | cell_bboxes: List[np.ndarray] = field(default_factory=list)
44 | logic_points: List[np.ndarray] = field(default_factory=list)
45 | elapse: float = 0.0
46 |
47 | def vis(
48 | self,
49 | save_dir: Union[str, Path],
50 | save_name: str,
51 | indexes: Tuple[int, ...] = (0,),
52 | ) -> List[np.ndarray]:
53 | vis = VisTable()
54 |
55 | save_dir = Path(save_dir)
56 | mkdir(save_dir)
57 |
58 | results = []
59 | for idx in indexes:
60 | save_one_dir = save_dir / str(idx)
61 | mkdir(save_one_dir)
62 |
63 | save_html_path = save_one_dir / f"{save_name}.html"
64 | save_drawed_path = save_one_dir / f"{save_name}_vis.jpg"
65 | save_logic_points_path = save_one_dir / f"{save_name}_col_row_vis.jpg"
66 |
67 | vis_img = vis(
68 | self.imgs[idx],
69 | self.pred_htmls[idx],
70 | self.cell_bboxes[idx],
71 | self.logic_points[idx],
72 | save_html_path,
73 | save_drawed_path,
74 | save_logic_points_path,
75 | )
76 | results.append(vis_img)
77 | return results
78 |
--------------------------------------------------------------------------------
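A sketch of how these dataclasses are wired together with the engine exported from the package root; the image path and output directory are placeholders:

from rapid_table import RapidTable

input_args = RapidTableInput(
    model_type=ModelType.SLANETPLUS,
    engine_type=EngineType.ONNXRUNTIME,
)
engine = RapidTable(input_args)

results: RapidTableOutput = engine("table.jpg")  # placeholder image path
results.vis(save_dir="outputs", save_name="table", indexes=(0,))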
/rapid_table/utils/utils.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | import hashlib
5 | import importlib
6 | from pathlib import Path
7 | from typing import List, Tuple, Union
8 | from urllib.parse import urlparse
9 |
10 | import cv2
11 | import numpy as np
12 | from omegaconf import DictConfig, OmegaConf
13 |
14 |
15 | def format_ocr_results(
16 | ocr_results: Tuple[np.ndarray, Tuple[str], Tuple[float]], img_h: int, img_w: int
17 | ) -> Tuple[np.ndarray, List[Tuple[str, float]]]:
18 | rec_res = list(zip(ocr_results[1], ocr_results[2]))
19 |
20 | bboxes = np.array(ocr_results[0])
21 | min_coords = bboxes[..., :2].min(axis=1)
22 | max_coords = bboxes[..., :2].max(axis=1)
23 |
24 | min_coords = np.maximum(min_coords, 0)
25 | max_coords = np.minimum(max_coords, [img_w, img_h])
26 | dt_boxes = np.hstack([min_coords, max_coords])
27 | return dt_boxes, rec_res
28 |
29 |
30 | def save_img(save_path: Union[str, Path], img: np.ndarray):
31 | cv2.imwrite(str(save_path), img)
32 |
33 |
34 | def save_txt(save_path: Union[str, Path], txt: str):
35 | with open(save_path, "w", encoding="utf-8") as f:
36 | f.write(txt)
37 |
38 |
39 | def import_package(name, package=None):
40 | try:
41 | module = importlib.import_module(name, package=package)
42 | return module
43 | except ModuleNotFoundError:
44 | return None
45 |
46 |
47 | def mkdir(dir_path):
48 | Path(dir_path).mkdir(parents=True, exist_ok=True)
49 |
50 |
51 | def read_yaml(file_path: Union[str, Path]) -> DictConfig:
52 | return OmegaConf.load(file_path)
53 |
54 |
55 | def get_file_sha256(file_path: Union[str, Path], chunk_size: int = 65536) -> str:
56 | with open(file_path, "rb") as file:
57 | sha_signature = hashlib.sha256()
58 | while True:
59 | chunk = file.read(chunk_size)
60 | if not chunk:
61 | break
62 | sha_signature.update(chunk)
63 |
64 | return sha_signature.hexdigest()
65 |
66 |
67 | def is_url(url: str) -> bool:
68 | try:
69 | result = urlparse(url)
70 | return all([result.scheme, result.netloc])
71 | except Exception:
72 | return False
73 |
--------------------------------------------------------------------------------
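For instance, format_ocr_results collapses each 4-point OCR quad to an axis-aligned box clipped to the image bounds and zips texts with scores; a toy example:

import numpy as np

boxes = np.array([[[10, 20], [90, 20], [90, 40], [10, 40]]])  # one quad
dt_boxes, rec_res = format_ocr_results(
    (boxes, ("cell",), (0.99,)), img_h=100, img_w=80
)
# dt_boxes -> [[10. 20. 80. 40.]] (x clipped to img_w); rec_res -> [("cell", 0.99)]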
/rapid_table/utils/vis.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | from pathlib import Path
5 | from typing import Union
6 |
7 | import cv2
8 | import numpy as np
9 |
10 | from .logger import Logger
11 | from .utils import save_img, save_txt
12 |
13 |
14 | class VisTable:
15 | def __init__(self):
16 | self.logger = Logger(logger_name=__name__).get_log()
17 |
18 | def __call__(
19 | self,
20 | img: np.ndarray,
21 | pred_html: str,
22 | cell_bboxes: np.ndarray,
23 | logic_points: np.ndarray,
24 | save_html_path: Union[str, Path, None] = None,
25 | save_drawed_path: Union[str, Path, None] = None,
26 | save_logic_path: Union[str, Path, None] = None,
27 | ):
28 | if pred_html and save_html_path:
29 | html_with_border = self.insert_border_style(pred_html)
30 | save_txt(save_html_path, html_with_border)
31 | self.logger.info(f"Save HTML to {save_html_path}")
32 |
33 | if cell_bboxes is None:
34 | return None
35 |
36 | drawed_img = self.draw(img, cell_bboxes)
37 | if save_drawed_path:
38 | save_img(save_drawed_path, drawed_img)
39 | self.logger.info(f"Saved table struacter result to {save_drawed_path}")
40 |
41 | if save_logic_path and logic_points.size > 0:
42 | self.plot_rec_box_with_logic_info(
43 | img, save_logic_path, logic_points, cell_bboxes
44 | )
45 | self.logger.info(f"Saved rec and box result to {save_logic_path}")
46 | return drawed_img
47 |
48 | def insert_border_style(self, table_html_str: str) -> str:
49 | style_res = """<meta charset="UTF-8"><style>
50 | table {
51 | border-collapse: collapse;
52 | width: 100%;
53 | }
54 | th, td {
55 | border: 1px solid black;
56 | padding: 8px;
57 | text-align: center;
58 | }
59 | th {
60 | background-color: #f2f2f2;
61 | }
62 | </style>"""
63 |
64 | prefix_table, suffix_table = table_html_str.split("<body>")
65 | html_with_border = f"{prefix_table}<body>{style_res}{suffix_table}"
66 | return html_with_border
67 |
68 | def draw(self, img: np.ndarray, cell_bboxes: np.ndarray) -> np.ndarray:
69 | dims_bboxes = cell_bboxes.shape[1]
70 | if dims_bboxes == 4:
71 | return self.draw_rectangle(img, cell_bboxes)
72 |
73 | if dims_bboxes == 8:
74 | return self.draw_polylines(img, cell_bboxes)
75 |
76 | raise ValueError("Shape of table bounding boxes is neither 4 nor 8.")
77 |
78 | @staticmethod
79 | def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
80 | img_copy = img.copy()
81 | for box in boxes.astype(int):
82 | x1, y1, x2, y2 = box
83 | cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2)
84 | return img_copy
85 |
86 | @staticmethod
87 | def draw_polylines(img: np.ndarray, points) -> np.ndarray:
88 | img_copy = img.copy()
89 | for point in points.astype(int):
90 | point = point.reshape(4, 2)
91 | cv2.polylines(img_copy, [point.astype(int)], True, (255, 0, 0), 2)
92 | return img_copy
93 |
94 | def plot_rec_box_with_logic_info(
95 | self, img: np.ndarray, output_path, logic_points, cell_bboxes
96 | ):
97 | img = cv2.copyMakeBorder(
98 | img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
99 | )
100 |
101 | polygons = [[box[0], box[1], box[4], box[5]] for box in cell_bboxes]
102 | for idx, polygon in enumerate(polygons):
103 | x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3]
104 | x0 = round(x0)
105 | y0 = round(y0)
106 | x1 = round(x1)
107 | y1 = round(y1)
108 | cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1)
109 |
110 | # Larger font size and line width for readability
111 | font_scale = 0.9  # previously 0.5
112 | thickness = 1  # previously 1
113 | logic_point = logic_points[idx]
114 | cv2.putText(
115 | img,
116 | f"row: {logic_point[0]}-{logic_point[1]}",
117 | (x0 + 3, y0 + 8),
118 | cv2.FONT_HERSHEY_PLAIN,
119 | font_scale,
120 | (0, 0, 255),
121 | thickness,
122 | )
123 | cv2.putText(
124 | img,
125 | f"col: {logic_point[2]}-{logic_point[3]}",
126 | (x0 + 3, y0 + 18),
127 | cv2.FONT_HERSHEY_PLAIN,
128 | font_scale,
129 | (0, 0, 255),
130 | thickness,
131 | )
132 |
133 | save_img(output_path, img)
134 |
--------------------------------------------------------------------------------
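VisTable can also be called directly when the prediction pieces are already at hand; the arrays and save path below are stand-ins:

import numpy as np

vis = VisTable()
img = np.full((200, 200, 3), 255, dtype=np.uint8)    # blank stand-in image
cell_bboxes = np.array([[10.0, 10.0, 90.0, 40.0]])   # one 4-dim box per cell
logic_points = np.array([[0, 0, 0, 0]])              # row/col spans per cell

drawed = vis(
    img,
    pred_html="<html><body><table><tr><td>1</td></tr></table></body></html>",
    cell_bboxes=cell_bboxes,
    logic_points=logic_points,
    save_drawed_path="outputs/table_vis.jpg",  # placeholder path
)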
/requirements.txt:
--------------------------------------------------------------------------------
1 | onnxruntime>1.17.0
2 | opencv_python>=4.5.1.48
3 | numpy>=1.21.6
4 | Pillow
5 | requests
6 | colorlog
7 | omegaconf
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | import sys
5 | from pathlib import Path
6 | from typing import List, Union
7 |
8 | import setuptools
9 | from get_pypi_latest_version import GetPyPiLatestVersion
10 |
11 |
12 | def read_txt(txt_path: Union[Path, str]) -> List[str]:
13 | with open(txt_path, "r", encoding="utf-8") as f:
14 | data = [v.rstrip("\n") for v in f]
15 | return data
16 |
17 |
18 | def get_readme():
19 | root_dir = Path(__file__).resolve().parent
20 | readme_path = str(root_dir / "docs" / "doc_whl_rapid_table.md")
21 | with open(readme_path, "r", encoding="utf-8") as f:
22 | readme = f.read()
23 | return readme
24 |
25 |
26 | MODULE_NAME = "rapid_table"
27 | obtainer = GetPyPiLatestVersion()
28 | try:
29 | latest_version = obtainer(MODULE_NAME)
30 | except Exception:
31 | latest_version = "0.0.0"
32 | VERSION_NUM = obtainer.version_add_one(latest_version)
33 |
34 | if len(sys.argv) > 2:
35 | match_str = " ".join(sys.argv[2:])
36 | matched_versions = obtainer.extract_version(match_str)
37 | if matched_versions:
38 | VERSION_NUM = matched_versions
39 | sys.argv = sys.argv[:2]
40 |
41 | setuptools.setup(
42 | name=MODULE_NAME,
43 | version=VERSION_NUM,
44 | platforms="Any",
45 | long_description=get_readme(),
46 | long_description_content_type="text/markdown",
47 | description="Table Recognition",
48 | author="SWHL",
49 | author_email="liekkaskono@163.com",
50 | url="https://github.com/RapidAI/RapidTable",
51 | license="Apache-2.0",
52 | include_package_data=True,
53 | install_requires=read_txt("requirements.txt"),
54 | packages=setuptools.find_packages(),
55 | package_data={"": ["*.onnx", "*.yaml"]},
56 | keywords=["ppstructure", "table", "rapidocr", "rapid_table"],
57 | classifiers=[
58 | "Programming Language :: Python :: 3.6",
59 | "Programming Language :: Python :: 3.7",
60 | "Programming Language :: Python :: 3.8",
61 | "Programming Language :: Python :: 3.9",
62 | "Programming Language :: Python :: 3.10",
63 | "Programming Language :: Python :: 3.11",
64 | "Programming Language :: Python :: 3.12",
65 | "Programming Language :: Python :: 3.13",
66 | ],
67 | python_requires=">=3.6,<4",
68 | entry_points={"console_scripts": [f"{MODULE_NAME}={MODULE_NAME}.main:main"]},
69 | extras_require={"torch": ["torch", "torchvision", "tokenizers"]},
70 | )
71 |
--------------------------------------------------------------------------------
/tests/test_files/table.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidTable/b0ba540a2a5ac06a61bcd6506795246a5046a291/tests/test_files/table.jpg
--------------------------------------------------------------------------------
/tests/test_files/table_without_txt.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RapidAI/RapidTable/b0ba540a2a5ac06a61bcd6506795246a5046a291/tests/test_files/table_without_txt.jpg
--------------------------------------------------------------------------------
/tests/test_main.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | # @Author: SWHL
3 | # @Contact: liekkaskono@163.com
4 | import shlex
5 | import sys
6 | from ast import literal_eval
7 | from pathlib import Path
8 |
9 | import pytest
10 | from rapidocr import RapidOCR
11 |
12 | cur_dir = Path(__file__).resolve().parent
13 | root_dir = cur_dir.parent
14 |
15 | sys.path.append(str(root_dir))
16 | from rapid_table import EngineType, ModelType, RapidTable, RapidTableInput
17 | from rapid_table.main import main
18 |
19 | ocr_engine = RapidOCR()
20 | table_engine = RapidTable()
21 |
22 | test_file_dir = cur_dir / "test_files"
23 | img_path = str(test_file_dir / "table.jpg")
24 | img_url = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg"
25 |
26 |
27 | def test_only_table():
28 | img_path = test_file_dir / "table_without_txt.jpg"
29 | table_engine = RapidTable(RapidTableInput(use_ocr=False))
30 | results = table_engine(img_path)
31 |
32 | assert len(results.pred_htmls) == 0
33 | assert results.cell_bboxes[0].shape == (16, 8)
34 |
35 |
36 | def test_without_txt_table():
37 | img_path = test_file_dir / "table_without_txt.jpg"
38 | results = table_engine(img_path)
39 |
40 | assert results.pred_htmls[0] is None
41 | assert results.cell_bboxes[0].shape == (16, 8)
42 |
43 |
44 | @pytest.mark.parametrize(
45 | "command, expected_output",
46 | [
47 | (f"{img_path} --model_type slanet_plus", 1274),
48 | (f"{img_url} --model_type slanet_plus", 1274),
49 | ],
50 | )
51 | def test_main_cli(capsys, command, expected_output):
52 | main(shlex.split(command))
53 | output = capsys.readouterr().out.rstrip()
54 | assert len(literal_eval(output)[0]) == expected_output
55 |
56 |
57 | @pytest.mark.parametrize(
58 | "model_type,engine_type",
59 | [
60 | (ModelType.SLANETPLUS, EngineType.ONNXRUNTIME),
61 | (ModelType.UNITABLE, EngineType.TORCH),
62 | ],
63 | )
64 | def test_ocr_input(model_type, engine_type):
65 | ori_ocr_res = ocr_engine(img_path)
66 | ocr_results = [ori_ocr_res.boxes, ori_ocr_res.txts, ori_ocr_res.scores]
67 |
68 | input_args = RapidTableInput(model_type=model_type, engine_type=engine_type)
69 | table_engine = RapidTable(input_args)
70 | table_results = table_engine(img_path, ocr_results=[ocr_results])
71 | assert table_results.pred_htmls[0].count("<tr>") == 16
72 |
73 |
74 | @pytest.mark.parametrize(
75 | "model_type,engine_type",
76 | [
77 | (ModelType.SLANETPLUS, EngineType.ONNXRUNTIME),
78 | (ModelType.UNITABLE, EngineType.TORCH),
79 | ],
80 | )
81 | def test_input_ocr_none(model_type, engine_type):
82 | input_args = RapidTableInput(model_type=model_type, engine_type=engine_type)
83 | table_engine = RapidTable(input_args)
84 | table_results = table_engine(img_path)
85 | assert table_results.pred_htmls[0].count("<tr>") == 16
86 | assert len(table_results.cell_bboxes) == len(table_results.logic_points)
87 |
--------------------------------------------------------------------------------
/tests/test_table_matcher.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | import sys
3 | import warnings
4 |
5 | import pytest
6 |
7 |
8 | @pytest.mark.skipif(
9 | sys.version_info.major == 3 and sys.version_info.minor < 12,
10 | reason="Only run on Python >= 3.12",
11 | )
12 | def test_regex_syntax_warning():
13 | """测试捕获正则表达式中无效转义序列产生的 SyntaxWarning"""
14 |
15 | with warnings.catch_warnings(record=True) as w:
16 | warnings.simplefilter("always")
17 |
18 | # Compile code that contains invalid escape sequences; this triggers a SyntaxWarning
19 | code_with_invalid_escape = """
20 | import re
21 | thead_part = '<td></td> rowspan="2"></b></td>'
22 | isolate_pattern = (
23 | '<td></td> rowspan="(\d)+" colspan="(\d)+"></b></td>|'
24 | '<td></td> colspan="(\d)+" rowspan="(\d)+"></b></td>|'
25 | '<td></td> rowspan="(\d)+"></b></td>|'
26 | '<td></td> colspan="(\d)+"></b></td>'
27 | )
28 | re.finditer(isolate_pattern, thead_part)
29 | """
30 |
31 | # Compiling the snippet emits the SyntaxWarning
32 | compile(code_with_invalid_escape, "<string>", "exec")
33 |
34 | # Check whether a SyntaxWarning was captured
35 | syntax_warnings = [
36 | warn for warn in w if issubclass(warn.category, SyntaxWarning)
37 | ]
38 | assert (
39 | len(syntax_warnings) > 0
40 | ), f"No SyntaxWarning captured: {[str(warn.message) for warn in w]}"
41 |
42 | # Each warning should mention the invalid escape sequence
43 | for warning in syntax_warnings:
44 | assert "invalid escape sequence" in str(warning.message)
45 |
46 |
47 | @pytest.mark.skipif(
48 | sys.version_info.major == 3 and sys.version_info.minor < 12,
49 | reason="Only run on Python >= 3.12",
50 | )
51 | def test_correct_regex_pattern():
52 | with warnings.catch_warnings(record=True) as w:
53 | warnings.simplefilter("always")
54 |
55 | # Raw strings keep the backslashes literal, so this does not trigger a SyntaxWarning
56 | code_with_raw_escape = """
57 | import re
58 | thead_part = '<td></td> rowspan="2"></b></td>'
59 | isolate_pattern_raw = (
60 | r'<td></td> rowspan="(\d)+" colspan="(\d)+"></b></td>|'
61 | r'<td></td> colspan="(\d)+" rowspan="(\d)+"></b></td>|'
62 | r'<td></td> rowspan="(\d)+"></b></td>|'
63 | r'<td></td> colspan="(\d)+"></b></td>'
64 | )
65 | re.finditer(isolate_pattern_raw, thead_part)
66 | """
67 | compile(code_with_raw_escape, "<string>", "exec")
68 |
69 | # Check whether a SyntaxWarning was captured
70 | syntax_warnings = [
71 | warn for warn in w if issubclass(warn.category, SyntaxWarning)
72 | ]
73 | assert (
74 | len(syntax_warnings) == 0
75 | ), f"Raw-string version unexpectedly produced a SyntaxWarning: {[str(warn.message) for warn in w]}"
76 |
--------------------------------------------------------------------------------