├── .github ├── dependabot.yml └── workflows │ ├── ci.yml │ └── create_issue.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CHANGELOG.md ├── Dockerfile ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── examples └── ocr │ ├── engine.py │ ├── output │ └── .gitignore │ ├── requirements.txt │ └── validate_ocr_performance.py ├── img └── unstructured_logo.png ├── logger_config.yaml ├── pyproject.toml ├── pytest.ini ├── requirements ├── base.in ├── base.txt ├── constraints.in ├── dev.in ├── dev.txt ├── test.in └── test.txt ├── sample-docs ├── 2023-Jan-economic-outlook.pdf ├── IRS-form-1987.pdf ├── RGBA_image.png ├── Silent-Giant.pdf ├── design-thinking.pdf ├── easy_table.jpg ├── embedded-images.pdf ├── empty-document.pdf ├── example_table.jpg ├── ilpa-example-1.jpg ├── layout-parser-paper-fast.jpg ├── layout-parser-paper-fast.pdf ├── layout-parser-paper.pdf ├── loremipsum-flat.pdf ├── loremipsum.jpg ├── loremipsum.pdf ├── loremipsum.png ├── loremipsum.tiff ├── loremipsum_multipage.pdf ├── non-embedded.pdf ├── password.pdf ├── patent-1p.pdf ├── patent.pdf ├── pdf2image-memory-error-test-400p.pdf ├── recalibrating-risk-report.pdf ├── receipt-sample.jpg ├── table-multi-row-column-cells.png └── test-image.jpg ├── scripts ├── docker-build.sh ├── shellcheck.sh ├── test-unstructured-ingest-helper.sh └── version-sync.sh ├── setup.cfg ├── setup.py ├── test_unstructured_inference ├── conftest.py ├── inference │ ├── test_layout.py │ └── test_layout_element.py ├── models │ ├── test_detectron2onnx.py │ ├── test_eval.py │ ├── test_model.py │ ├── test_tables.py │ └── test_yolox.py ├── test_config.py ├── test_elements.py ├── test_logger.py ├── test_math.py ├── test_utils.py └── test_visualization.py └── unstructured_inference ├── __init__.py ├── __version__.py ├── config.py ├── constants.py ├── inference ├── __init__.py ├── elements.py ├── layout.py └── layoutelement.py ├── logger.py ├── math.py ├── models ├── __init__.py ├── base.py ├── detectron2onnx.py ├── eval.py ├── table_postprocess.py ├── tables.py ├── unstructuredmodel.py └── yolox.py ├── utils.py └── visualize.py /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" 4 | directory: "/requirements" 5 | schedule: 6 | interval: "monthly" 7 | 8 | - package-ecosystem: "github-actions" 9 | # NOTE(robinson) - Workflow files stored in the 10 | # default location of `.github/workflows` 11 | directory: "/" 12 | schedule: 13 | interval: "monthly" 14 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ main, robinson/initial-repo-setup ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | setup: 11 | strategy: 12 | matrix: 13 | python-version: ["3.10","3.11", "3.12"] 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | - uses: actions/cache@v4 18 | id: virtualenv-cache 19 | with: 20 | path: | 21 | .venv 22 | key: ${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('requirements/*.txt') }} 23 | lookup-only: true 24 | - name: Set up Python ${{ matrix.python-version }} 25 | uses: actions/setup-python@v5 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | - name: Install Poppler 29 | run: | 30 | sudo apt-get update 31 | sudo apt-get -y install poppler-utils 32 | - name: Setup virtual environment (no cache 
hit) 33 | if: steps.virtualenv-cache.outputs.cache-hit != 'true' 34 | run: | 35 | python${{ matrix.python-version }} -m venv .venv 36 | source .venv/bin/activate 37 | make install-ci 38 | 39 | lint: 40 | strategy: 41 | matrix: 42 | python-version: ["3.10","3.11", "3.12"] 43 | runs-on: ubuntu-latest 44 | needs: setup 45 | steps: 46 | - uses: actions/checkout@v4 47 | - uses: actions/cache/restore@v4 48 | id: virtualenv-cache 49 | with: 50 | path: .venv 51 | key: ${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('requirements/*.txt') }} 52 | # NOTE(robinson) - This is a fallback in case the lint job does not find the cache. 53 | # We can take this out when we implement the fix in CORE-99 54 | - name: Setup virtual environment (no cache hit) 55 | if: steps.virtualenv-cache.outputs.cache-hit != 'true' 56 | run: | 57 | python${{ matrix.python-version }} -m venv .venv 58 | - name: Lint 59 | run: | 60 | source .venv/bin/activate 61 | make install-ci 62 | make check 63 | 64 | shellcheck: 65 | runs-on: ubuntu-latest 66 | steps: 67 | - uses: actions/checkout@v4 68 | - name: ShellCheck 69 | uses: ludeeus/action-shellcheck@master 70 | 71 | test: 72 | strategy: 73 | matrix: 74 | python-version: ["3.10","3.11", "3.12"] 75 | runs-on: ubuntu-latest 76 | needs: [setup, lint] 77 | steps: 78 | - uses: actions/checkout@v4 79 | - uses: actions/cache/restore@v4 80 | id: virtualenv-cache 81 | with: 82 | path: | 83 | .venv 84 | key: ${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('requirements/*.txt') }} 85 | # NOTE(robinson) - This is a fallback in case the lint job does not find the cache. 86 | # We can take this out when we implement the fix in CORE-99 87 | - name: Setup virtual environment (no cache hit) 88 | if: steps.virtualenv-cache.outputs.cache-hit != 'true' 89 | run: | 90 | python${{ matrix.python-version }} -m venv .venv 91 | - name: Install Poppler 92 | run: | 93 | sudo apt-get update 94 | sudo apt-get -y install poppler-utils tesseract-ocr 95 | - name: Configure AWS credentials 96 | uses: aws-actions/configure-aws-credentials@v4 97 | with: 98 | aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} 99 | aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 100 | aws-region: us-east-2 101 | - name: Test 102 | env: 103 | UNSTRUCTURED_HF_TOKEN: ${{ secrets.HF_TOKEN }} 104 | run: | 105 | source .venv/bin/activate 106 | make install-ci 107 | aws s3 cp s3://utic-dev-models/ci_test_model/test_ci_model.onnx test_unstructured_inference/models/ 108 | CI=true make test 109 | make check-coverage 110 | 111 | # NOTE(robinson) - disabling ingest tests for now, as of 5/22/2024 they seem to have been 112 | # broken for the past six months 113 | # test_ingest: 114 | # strategy: 115 | # matrix: 116 | # python-version: ["3.9","3.10"] 117 | # runs-on: ubuntu-latest 118 | # env: 119 | # NLTK_DATA: ${{ github.workspace }}/nltk_data 120 | # needs: lint 121 | # steps: 122 | # - name: Checkout unstructured repo for integration testing 123 | # uses: actions/checkout@v4 124 | # with: 125 | # repository: 'Unstructured-IO/unstructured' 126 | # - name: Checkout this repo 127 | # uses: actions/checkout@v4 128 | # with: 129 | # path: inference 130 | # - name: Set up Python ${{ matrix.python-version }} 131 | # uses: actions/setup-python@v4 132 | # with: 133 | # python-version: ${{ matrix.python-version }} 134 | # - name: Test 135 | # env: 136 | # GH_READ_ONLY_ACCESS_TOKEN: ${{ secrets.GH_READ_ONLY_ACCESS_TOKEN }} 137 | # SLACK_TOKEN: ${{ secrets.SLACK_TOKEN }} 138 | # DISCORD_TOKEN: ${{ secrets.DISCORD_TOKEN 
}} 139 | # run: | 140 | # python${{ matrix.python-version }} -m venv .venv 141 | # source .venv/bin/activate 142 | # [ ! -d "$NLTK_DATA" ] && mkdir "$NLTK_DATA" 143 | # make install-ci 144 | # pip install -e inference/ 145 | # sudo apt-get update 146 | # sudo apt-get install -y libmagic-dev poppler-utils libreoffice pandoc 147 | # sudo add-apt-repository -y ppa:alex-p/tesseract-ocr5 148 | # sudo apt-get install -y tesseract-ocr 149 | # sudo apt-get install -y tesseract-ocr-kor 150 | # sudo apt-get install -y diffstat 151 | # tesseract --version 152 | # make install-all-ingest 153 | # # only run ingest tests that check expected output diffs. 154 | # bash inference/scripts/test-unstructured-ingest-helper.sh 155 | 156 | changelog: 157 | runs-on: ubuntu-latest 158 | steps: 159 | - uses: actions/checkout@v4 160 | - if: github.ref != 'refs/heads/main' 161 | uses: dorny/paths-filter@v2 162 | id: changes 163 | with: 164 | filters: | 165 | src: 166 | - 'unstructured_inference/**' 167 | 168 | - if: steps.changes.outputs.src == 'true' && github.ref != 'refs/heads/main' 169 | uses: dangoslen/changelog-enforcer@v3 -------------------------------------------------------------------------------- /.github/workflows/create_issue.yml: -------------------------------------------------------------------------------- 1 | name: create_jira_issue 2 | 3 | on: 4 | issues: 5 | types: 6 | - opened 7 | 8 | jobs: 9 | create: 10 | runs-on: ubuntu-latest 11 | name: Create JIRA Issue 12 | steps: 13 | 14 | - name: Login to Jira 15 | uses: atlassian/gajira-login@v3 16 | env: 17 | JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} 18 | JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} 19 | JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} 20 | 21 | - name: Create Jira issue 22 | id: create 23 | uses: atlassian/gajira-create@v3 24 | with: 25 | project: CORE 26 | issuetype: Task 27 | summary: ${{ github.event.issue.title }} 28 | description: | 29 | Created from github issue: ${{ github.event.issue.html_url }} 30 | ---- 31 | ${{ github.event.issue.body }} 32 | fields: '{ "labels": ["github-issue"] }' 33 | 34 | - name: Log created issue 35 | run: echo "Issue ${{ steps.create.outputs.issue }} was created" 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | nbs/ 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # Pycharm 122 | .idea/ 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # Model artifacts 136 | .models/* 137 | !.models/.gitkeep 138 | 139 | # Mac stuff 140 | .DS_Store 141 | 142 | # VSCode 143 | .vscode/ 144 | 145 | sample-docs/*_images 146 | examples/**/output 147 | figures 148 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: "v4.3.0" 4 | hooks: 5 | - id: check-added-large-files 6 | - id: check-toml 7 | - id: check-yaml 8 | - id: check-json 9 | - id: check-xml 10 | - id: end-of-file-fixer 11 | exclude: \.json$ 12 | files: \.py$ 13 | - id: trailing-whitespace 14 | - id: mixed-line-ending 15 | 16 | - repo: https://github.com/psf/black 17 | rev: 22.10.0 18 | hooks: 19 | - id: black 20 | args: ["--line-length=100"] 21 | language_version: python3 22 | 23 | - repo: https://github.com/charliermarsh/ruff-pre-commit 24 | rev: "v0.0.230" 25 | hooks: 26 | - id: ruff 27 | args: 28 | [ 29 | "--fix", 30 | "--select=I,UP015,UP032,UP034,UP018,COM,C4,PT,SIM,PLR0402", 31 | "--ignore=PT011,PT012,SIM117", 32 | ] 33 | 34 | - repo: https://github.com/pycqa/flake8 35 | rev: 4.0.1 36 | hooks: 37 | - id: flake8 38 | language_version: python3 39 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:experimental 2 | FROM quay.io/unstructured-io/base-images:rocky8.7-3 as base 3 | 4 | ARG PIP_VERSION 5 | 6 | # Set up environment 7 | ENV HOME /home/ 8 | WORKDIR ${HOME} 9 | RUN
mkdir ${HOME}/.ssh && chmod go-rwx ${HOME}/.ssh \ 10 | && ssh-keyscan -t rsa github.com >> /home/.ssh/known_hosts 11 | ENV PYTHONPATH="${PYTHONPATH}:${HOME}" 12 | ENV PATH="/home/usr/.local/bin:${PATH}" 13 | 14 | FROM base as deps 15 | # Copy and install Unstructured 16 | COPY requirements requirements 17 | 18 | RUN python3.8 -m pip install pip==${PIP_VERSION} && \ 19 | dnf -y groupinstall "Development Tools" && \ 20 | pip install --no-cache -r requirements/base.txt && \ 21 | pip install --no-cache -r requirements/test.txt && \ 22 | pip install --no-cache -r requirements/dev.txt && \ 23 | dnf -y groupremove "Development Tools" && \ 24 | dnf clean all 25 | 26 | FROM deps as code 27 | ARG PACKAGE_NAME=unstructured_inference 28 | COPY unstructured_inference unstructured_inference 29 | 30 | #CMD ["pytest -m \"not slow\" test_${PACKAGE_NAME} --cov=${PACKAGE_NAME} --cov-report term-missing"] 31 | CMD ["/bin/bash"] 32 | #CMD ["bash -c pytest test_unstructured_inference"] 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements/base.in 2 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PACKAGE_NAME := unstructured_inference 2 | PIP_VERSION := 23.2.1 3 | CURRENT_DIR := $(shell pwd) 4 | 5 | 6 | .PHONY: help 7 | help: Makefile 8 | @sed -n 's/^\(## \)\([a-zA-Z]\)/\2/p' $< 9 | 10 | 11 | ########### 12 | # Install # 13 | ########### 14 | 15 | ## install-base: installs core requirements needed for text processing bricks 16 | .PHONY: install-base 17 | install-base: install-base-pip-packages 18 | python3 -m pip install -r requirements/base.txt 19 | 20 | ## install: installs all test, dev, and experimental requirements 21 | .PHONY: install 22 | install: install-base-pip-packages install-dev 23 | 24 | .PHONY: install-ci 25 | install-ci: install-base-pip-packages install-test 26 | 27 | .PHONY: install-base-pip-packages 28 | install-base-pip-packages: 29 | python3 -m pip install pip==${PIP_VERSION} 30 | 31 | .PHONY: install-test 32 | install-test: install-base 33 | python3 -m pip install -r requirements/test.txt 34 | 35 | .PHONY: install-dev 36 | install-dev: install-test 37 | python3 -m pip install -r requirements/dev.txt 38 | 39 | ## pip-compile: compiles all base/dev/test requirements 40 | .PHONY: pip-compile 41 | pip-compile: 42 | pip-compile --upgrade requirements/base.in 43 | pip-compile --upgrade requirements/test.in 44 | pip-compile --upgrade requirements/dev.in 45 | 46 | ################# 47 | # Test and Lint # 48 | ################# 49 | 50 | export CI ?= false 51 | 52 | ## test: runs all unittests 53 | .PHONY: test 54 | test: 55 | PYTHONPATH=. CI=$(CI) pytest -m "not slow" test_${PACKAGE_NAME} --cov=${PACKAGE_NAME} --cov-report term-missing 56 | 57 | .PHONY: test-slow 58 | test-slow: 59 | PYTHONPATH=. 
CI=$(CI) pytest test_${PACKAGE_NAME} --cov=${PACKAGE_NAME} --cov-report term-missing 60 | 61 | ## check: runs linters (includes tests) 62 | .PHONY: check 63 | check: check-src check-tests check-version 64 | 65 | ## check-src: runs linters (source only, no tests) 66 | .PHONY: check-src 67 | check-src: 68 | ruff check ${PACKAGE_NAME} --line-length 100 --select C4,COM,E,F,I,PLR0402,PT,SIM,UP015,UP018,UP032,UP034 --ignore COM812,PT011,PT012,SIM117 69 | python -m black --line-length 100 ${PACKAGE_NAME} --check 70 | python -m flake8 ${PACKAGE_NAME} 71 | python -m mypy ${PACKAGE_NAME} --ignore-missing-imports 72 | 73 | .PHONY: check-tests 74 | check-tests: 75 | python -m black --line-length 100 test_${PACKAGE_NAME} --check 76 | python -m flake8 test_${PACKAGE_NAME} 77 | 78 | ## check-scripts: run shellcheck 79 | .PHONY: check-scripts 80 | check-scripts: 81 | # Fail if any of these files have warnings 82 | scripts/shellcheck.sh 83 | 84 | ## check-version: run check to ensure version in CHANGELOG.md matches version in package 85 | .PHONY: check-version 86 | check-version: 87 | # Fail if syncing version would produce changes 88 | scripts/version-sync.sh -c \ 89 | -s CHANGELOG.md \ 90 | -f unstructured_inference/__version__.py semver 91 | 92 | ## tidy: run black 93 | .PHONY: tidy 94 | tidy: 95 | ruff check ${PACKAGE_NAME} --fix --line-length 100 --select C4,COM,E,F,I,PLR0402,PT,SIM,UP015,UP018,UP032,UP034 --ignore COM812,PT011,PT012,SIM117 96 | black --line-length 100 ${PACKAGE_NAME} 97 | black --line-length 100 test_${PACKAGE_NAME} 98 | 99 | ## version-sync: update __version__.py with most recent version from CHANGELOG.md 100 | .PHONY: version-sync 101 | version-sync: 102 | scripts/version-sync.sh \ 103 | -s CHANGELOG.md \ 104 | -f unstructured_inference/__version__.py semver 105 | 106 | .PHONY: check-coverage 107 | check-coverage: 108 | python -m coverage report --fail-under=95 109 | 110 | ########## 111 | # Docker # 112 | ########## 113 | 114 | # Docker targets are provided for convenience only and are not required in a standard development environment 115 | 116 | DOCKER_IMAGE ?= unstructured-inference:dev 117 | 118 | .PHONY: docker-build 119 | docker-build: 120 | PIP_VERSION=${PIP_VERSION} DOCKER_IMAGE_NAME=${DOCKER_IMAGE} ./scripts/docker-build.sh 121 | 122 | .PHONY: docker-test 123 | docker-test: docker-build 124 | docker run --rm \ 125 | -v ${CURRENT_DIR}/test_unstructured_inference:/home/test_unstructured_inference \ 126 | -v ${CURRENT_DIR}/sample-docs:/home/sample-docs \ 127 | $(DOCKER_IMAGE) \ 128 | bash -c "pytest $(if $(TEST_NAME),-k $(TEST_NAME),) test_unstructured_inference" 129 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
<h3 align="center">
2 |   <img
3 |     src="img/unstructured_logo.png"
4 |     height="200"
5 |   >
6 | </h3>
7 | 
8 | 
9 | <h3 align="center">
10 |   <p>Open-Source Pre-Processing Tools for Unstructured Data</p>
11 | </h3>
12 | 13 | The `unstructured-inference` repo contains hosted model inference code for layout parsing models. 14 | These models are invoked via API as part of the partitioning bricks in the `unstructured` package. 15 | 16 | ## Installation 17 | 18 | ### Package 19 | 20 | Run `pip install unstructured-inference`. 21 | 22 | ### Detectron2 23 | 24 | [Detectron2](https://github.com/facebookresearch/detectron2) is required for using models from the [layoutparser model zoo](#using-models-from-the-layoutparser-model-zoo) 25 | but is not automatically installed with this package. 26 | For macOS and Linux, build from source with: 27 | ```shell 28 | pip install 'git+https://github.com/facebookresearch/detectron2.git@57bdb21249d5418c130d54e2ebdc94dda7a4c01a' 29 | ``` 30 | Other install options can be found in the 31 | [Detectron2 installation guide](https://detectron2.readthedocs.io/en/latest/tutorials/install.html). 32 | 33 | Windows is not officially supported by Detectron2, but some users are able to install it anyway. 34 | See discussion [here](https://layout-parser.github.io/tutorials/installation#for-windows-users) for 35 | tips on installing Detectron2 on Windows. 36 | 37 | ### Repository 38 | 39 | To install the repository for development, clone the repo and run `make install` to install dependencies. 40 | Run `make help` for a full list of install options. 41 | 42 | ## Getting Started 43 | 44 | To get started with the layout parsing model, use the following commands: 45 | 46 | ```python 47 | from unstructured_inference.inference.layout import DocumentLayout 48 | 49 | layout = DocumentLayout.from_file("sample-docs/loremipsum.pdf") 50 | 51 | print(layout.pages[0].elements) 52 | ``` 53 | 54 | Once the model has detected the layout and OCR'd the document, the text extracted from the first 55 | page of the sample document will be displayed. 56 | You can convert a given element to a `dict` by running the `.to_dict()` method. 57 | 58 | ## Models 59 | 60 | The inference pipeline operates by finding text elements in a document page using a detection model, then extracting the contents of the elements using direct extraction (if available), OCR, and optionally table inference models. 61 | 62 | We offer several detection models including [Detectron2](https://github.com/facebookresearch/detectron2) and [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX). 63 | 64 | ### Using a non-default model 65 | 66 | When doing inference, an alternate model can be used by passing the model object to the ingestion method via the `detection_model` parameter. The `get_model` function can be used to construct one of our out-of-the-box models from a keyword, e.g.: 67 | ```python 68 | from unstructured_inference.models.base import get_model 69 | from unstructured_inference.inference.layout import DocumentLayout 70 | 71 | model = get_model("yolox") 72 | layout = DocumentLayout.from_file("sample-docs/layout-parser-paper.pdf", detection_model=model) 73 | ``` 74 | 75 | ### Using your own model 76 | 77 | Any detection model can be used in the `unstructured_inference` pipeline by wrapping the model in the `UnstructuredObjectDetectionModel` class. To integrate with the `DocumentLayout` class, a subclass of `UnstructuredObjectDetectionModel` must have a `predict` method that accepts a `PIL.Image.Image` and returns a list of `LayoutElement`s, and an `initialize` method, which loads the model and prepares it for inference.
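
A minimal sketch of such a wrapper is shown below. Here `load_my_model` and the fields on each raw detection are hypothetical stand-ins for your own model's API, and the exact `LayoutElement` construction may vary between versions, so adapt the details to the version of `unstructured_inference` you are using:

```python
from typing import List

from PIL import Image

from unstructured_inference.inference.layoutelement import LayoutElement
from unstructured_inference.models.unstructuredmodel import UnstructuredObjectDetectionModel


class MyDetectionModel(UnstructuredObjectDetectionModel):
    """Illustrative wrapper around a custom object detection model."""

    def initialize(self, model_path: str):
        # Load the weights and prepare the model for inference.
        self.model = load_my_model(model_path)  # hypothetical loader for your own model

    def predict(self, image: Image.Image) -> List[LayoutElement]:
        # Run the wrapped model on a page image.
        detections = self.model(image)  # hypothetical inference call
        # Convert raw detections into LayoutElement objects; the attributes
        # used here (pixel coordinates and a label) are assumptions about
        # your model's output format.
        return [
            LayoutElement.from_coords(
                x1=d.x1, y1=d.y1, x2=d.x2, y2=d.y2, text=None, type=d.label
            )
            for d in detections
        ]
```

An instance of the wrapper can then be initialized and passed to `DocumentLayout.from_file` via the `detection_model` parameter, just like the out-of-the-box models above.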
78 | 79 | ## Security Policy 80 | 81 | See our [security policy](https://github.com/Unstructured-IO/unstructured-inference/security/policy) for 82 | information on how to report security vulnerabilities. 83 | 84 | ## Learn more 85 | 86 | | Section | Description | 87 | |-|-| 88 | | [Unstructured Community Github](https://github.com/Unstructured-IO/community) | Information about Unstructured.io community projects | 89 | | [Unstructured Github](https://github.com/Unstructured-IO) | Unstructured.io open source repositories | 90 | | [Company Website](https://unstructured.io) | Unstructured.io product and company info | 91 | -------------------------------------------------------------------------------- /examples/ocr/engine.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | from typing import List, cast 5 | 6 | import cv2 7 | import numpy as np 8 | import pytesseract 9 | from pytesseract import Output 10 | 11 | from unstructured_inference.inference import layout 12 | from unstructured_inference.inference.elements import Rectangle, TextRegion 13 | 14 | 15 | def remove_non_printable(s): 16 | dst_str = re.sub(r'[^\x20-\x7E]', ' ', s) 17 | return ' '.join(dst_str.split()) 18 | 19 | 20 | def run_ocr_with_layout_detection( 21 | images, 22 | detection_model=None, 23 | element_extraction_model=None, 24 | mode="individual_blocks", 25 | output_dir="", 26 | drawable=True, 27 | printable=True, 28 | ): 29 | total_text_extraction_infer_time = 0 30 | total_extracted_text = {} 31 | for i, image in enumerate(images): 32 | page_num = i + 1 33 | page_num_str = f"page{page_num}" 34 | 35 | page = layout.PageLayout( 36 | number=i+1, 37 | image=image, 38 | layout=None, 39 | detection_model=detection_model, 40 | element_extraction_model=element_extraction_model, 41 | ) 42 | 43 | inferred_layout: List[TextRegion] = cast(List[TextRegion], page.detection_model(page.image)) 44 | 45 | cv_img = np.array(image) 46 | 47 | if mode == "individual_blocks": 48 | # OCR'ing individual blocks (current approach) 49 | text_extraction_start_time = time.time() 50 | 51 | elements = page.get_elements_from_layout(inferred_layout) 52 | 53 | text_extraction_infer_time = time.time() - text_extraction_start_time 54 | 55 | total_text_extraction_infer_time += text_extraction_infer_time 56 | 57 | page_text = "" 58 | for el in elements: 59 | page_text += el.text 60 | filtered_page_text = remove_non_printable(page_text) 61 | total_extracted_text[page_num_str] = filtered_page_text 62 | elif mode == "entire_page": 63 | # OCR'ing entire page (new approach to implement) 64 | text_extraction_start_time = time.time() 65 | 66 | ocr_data = pytesseract.image_to_data(image, lang='eng', output_type=Output.DICT) 67 | boxes = ocr_data['level'] 68 | extracted_text_list = [] 69 | for k in range(len(boxes)): 70 | (x, y, w, h) = ocr_data['left'][k], ocr_data['top'][k], ocr_data['width'][k], ocr_data['height'][k] 71 | extracted_text = ocr_data['text'][k] 72 | if not extracted_text: 73 | continue 74 | 75 | extracted_region = Rectangle(x1=x, y1=y, x2=x+w, y2=y+h) 76 | 77 | extracted_is_subregion_of_inferred = False 78 | for inferred_region in inferred_layout: 79 | extracted_is_subregion_of_inferred = extracted_region.is_almost_subregion_of( 80 | inferred_region.pad(12), 81 | subregion_threshold=0.75, 82 | ) 83 | if extracted_is_subregion_of_inferred: 84 | break 85 | 86 | if extracted_is_subregion_of_inferred: 87 | extracted_text_list.append(extracted_text) 88 | 89 | if drawable: 90 | if 
extracted_is_subregion_of_inferred: 91 | cv2.rectangle(cv_img, (x, y), (x + w, y + h), (0, 255, 0), 2, None) 92 | else: 93 | cv2.rectangle(cv_img, (x, y), (x + w, y + h), (255, 0, 0), 2, None) 94 | 95 | text_extraction_infer_time = time.time() - text_extraction_start_time 96 | total_text_extraction_infer_time += text_extraction_infer_time 97 | 98 | page_text = " ".join(extracted_text_list) 99 | filtered_page_text = remove_non_printable(page_text) 100 | total_extracted_text[page_num_str] = filtered_page_text 101 | else: 102 | raise ValueError("Invalid mode") 103 | 104 | if drawable: 105 | for el in inferred_layout: 106 | pt1 = [int(el.x1), int(el.y1)] 107 | pt2 = [int(el.x2), int(el.y2)] 108 | cv2.rectangle( 109 | img=cv_img, 110 | pt1=pt1, pt2=pt2, 111 | color=(0, 0, 255), 112 | thickness=4, 113 | lineType=None, 114 | ) 115 | 116 | f_path = os.path.join(output_dir, f"ocr_{mode}_{page_num_str}.jpg") 117 | cv2.imwrite(f_path, cv_img) 118 | 119 | if printable: 120 | print(f"page: {i + 1} - n_layout_elements: {len(inferred_layout)} - " 121 | f"text_extraction_infer_time: {text_extraction_infer_time}") 122 | 123 | return total_text_extraction_infer_time, total_extracted_text 124 | 125 | 126 | def run_ocr( 127 | images, 128 | printable=True, 129 | ): 130 | total_text_extraction_infer_time = 0 131 | total_text = "" 132 | for i, image in enumerate(images): 133 | text_extraction_start_time = time.time() 134 | 135 | page_text = pytesseract.image_to_string(image) 136 | 137 | text_extraction_infer_time = time.time() - text_extraction_start_time 138 | 139 | if printable: 140 | print(f"page: {i + 1} - text_extraction_infer_time: {text_extraction_infer_time}") 141 | 142 | total_text_extraction_infer_time += text_extraction_infer_time 143 | total_text += page_text 144 | 145 | return total_text_extraction_infer_time, total_text 146 | -------------------------------------------------------------------------------- /examples/ocr/output/.gitignore: -------------------------------------------------------------------------------- 1 | * -------------------------------------------------------------------------------- /examples/ocr/requirements.txt: -------------------------------------------------------------------------------- 1 | unstructured[local-inference] 2 | nltk -------------------------------------------------------------------------------- /examples/ocr/validate_ocr_performance.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | from datetime import datetime 5 | from difflib import SequenceMatcher 6 | 7 | import nltk 8 | import pdf2image 9 | 10 | from unstructured_inference.inference.layout import ( 11 | DocumentLayout, 12 | create_image_output_dir, 13 | process_file_with_model, 14 | ) 15 | 16 | # Download the required resources (run this once) 17 | nltk.download('punkt') 18 | 19 | 20 | def validate_performance( 21 | f_name, 22 | validation_mode, 23 | is_image_file=False, 24 | ): 25 | print(f">>> Start performance comparison - filename: {f_name} - validation_mode: {validation_mode}" 26 | f" - is_image_file: {is_image_file}") 27 | 28 | now_dt = datetime.utcnow() 29 | now_str = now_dt.strftime("%Y_%m_%d-%H_%M_%S") 30 | 31 | f_path = os.path.join(example_docs_dir, f_name) 32 | 33 | image_f_paths = [] 34 | if validation_mode == "pdf": 35 | pdf_info = pdf2image.pdfinfo_from_path(f_path) 36 | n_pages = pdf_info["Pages"] 37 | elif validation_mode == "image": 38 | if is_image_file: 39 | image_f_paths.append(f_path) 40 | else: 41 
| image_output_dir = create_image_output_dir(f_path) 42 | images = pdf2image.convert_from_path(f_path, output_folder=image_output_dir) 43 | image_f_paths = [image.filename for image in images] 44 | n_pages = len(image_f_paths) 45 | else: 46 | n_pages = 0 47 | 48 | processing_result = {} 49 | for ocr_mode in ["individual_blocks", "entire_page"]: 50 | start_time = time.time() 51 | 52 | if validation_mode == "pdf": 53 | layout = process_file_with_model( 54 | f_path, 55 | model_name=None, 56 | ocr_mode=ocr_mode, 57 | ) 58 | elif validation_mode == "image": 59 | pages = [] 60 | for image_f_path in image_f_paths: 61 | _layout = process_file_with_model( 62 | image_f_path, 63 | model_name=None, 64 | ocr_mode=ocr_mode, 65 | is_image=True, 66 | ) 67 | pages += _layout.pages 68 | for i, page in enumerate(pages): 69 | page.number = i + 1 70 | layout = DocumentLayout.from_pages(pages) 71 | else: 72 | layout = None 73 | 74 | infer_time = time.time() - start_time 75 | 76 | if layout is None: 77 | print("Layout is None") 78 | return 79 | 80 | full_text = str(layout) 81 | page_text = {} 82 | for page in layout.pages: 83 | page_text[page.number] = str(page) 84 | 85 | processing_result[ocr_mode] = { 86 | "infer_time": infer_time, 87 | "full_text": full_text, 88 | "page_text": page_text, 89 | } 90 | 91 | individual_mode_page_text = processing_result["individual_blocks"]["page_text"] 92 | entire_mode_page_text = processing_result["entire_page"]["page_text"] 93 | individual_mode_full_text = processing_result["individual_blocks"]["full_text"] 94 | entire_mode_full_text = processing_result["entire_page"]["full_text"] 95 | 96 | compare_result = compare_processed_text(individual_mode_full_text, entire_mode_full_text) 97 | 98 | report = { 99 | "validation_mode": validation_mode, 100 | "file_info": { 101 | "filename": f_name, 102 | "n_pages": n_pages, 103 | }, 104 | "processing_time": { 105 | "individual_blocks": processing_result["individual_blocks"]["infer_time"], 106 | "entire_page": processing_result["entire_page"]["infer_time"], 107 | }, 108 | "text_similarity": compare_result, 109 | "extracted_text": { 110 | "individual_blocks": { 111 | "page_text": individual_mode_page_text, 112 | "full_text": individual_mode_full_text, 113 | }, 114 | "entire_page": { 115 | "page_text": entire_mode_page_text, 116 | "full_text": entire_mode_full_text, 117 | }, 118 | }, 119 | } 120 | 121 | write_report(report, now_str, validation_mode) 122 | 123 | print("<<< End performance comparison", f_name) 124 | 125 | 126 | def compare_processed_text(individual_mode_full_text, entire_mode_full_text, delimiter=" "): 127 | # Calculate similarity ratio 128 | similarity_ratio = SequenceMatcher(None, individual_mode_full_text, entire_mode_full_text).ratio() 129 | 130 | print(f"similarity_ratio: {similarity_ratio}") 131 | 132 | # Tokenize the text into words 133 | word_list_individual = nltk.word_tokenize(individual_mode_full_text) 134 | n_word_list_individual = len(word_list_individual) 135 | print("n_word_list_in_text_individual:", n_word_list_individual) 136 | word_sets_individual = set(word_list_individual) 137 | n_word_sets_individual = len(word_sets_individual) 138 | print(f"n_word_sets_in_text_individual: {n_word_sets_individual}") 139 | # print("word_sets_individual:", word_sets_individual) 140 | 141 | word_list_entire = nltk.word_tokenize(entire_mode_full_text) 142 | n_word_list_entire = len(word_list_entire) 143 | print("n_word_list_in_text_entire:", n_word_list_entire) 144 | word_sets_entire = set(word_list_entire) 145 | n_word_sets_entire =
len(word_sets_entire) 146 | print(f"n_word_sets_in_text_entire: {n_word_sets_entire}") 147 | # print("word_sets_entire:", word_sets_entire) 148 | 149 | # Find unique elements using difference 150 | print("diff_elements:") 151 | unique_words_individual = word_sets_individual - word_sets_entire 152 | unique_words_entire = word_sets_entire - word_sets_individual 153 | print(f"unique_words_in_text_individual: {unique_words_individual}\n") 154 | print(f"unique_words_in_text_entire: {unique_words_entire}") 155 | 156 | return { 157 | "similarity_ratio": similarity_ratio, 158 | "individual_blocks": { 159 | "n_word_list": n_word_list_individual, 160 | "n_word_sets": n_word_sets_individual, 161 | "unique_words": delimiter.join(list(unique_words_individual)), 162 | }, 163 | "entire_page": { 164 | "n_word_list": n_word_list_entire, 165 | "n_word_sets": n_word_sets_entire, 166 | "unique_words": delimiter.join(list(unique_words_entire)), 167 | }, 168 | } 169 | 170 | 171 | def write_report(report, now_str, validation_mode): 172 | report_f_name = f"validate-ocr-{validation_mode}-{now_str}.json" 173 | report_f_path = os.path.join(output_dir, report_f_name) 174 | with open(report_f_path, "w", encoding="utf-8-sig") as f: 175 | json.dump(report, f, indent=4) 176 | 177 | 178 | def run(): 179 | test_files = [ 180 | {"name": "layout-parser-paper-fast.pdf", "mode": "image", "is_image_file": False}, 181 | {"name": "loremipsum_multipage.pdf", "mode": "image", "is_image_file": False}, 182 | {"name": "2023-Jan-economic-outlook.pdf", "mode": "image", "is_image_file": False}, 183 | {"name": "recalibrating-risk-report.pdf", "mode": "image", "is_image_file": False}, 184 | {"name": "Silent-Giant.pdf", "mode": "image", "is_image_file": False}, 185 | ] 186 | 187 | for test_file in test_files: 188 | f_name = test_file["name"] 189 | validation_mode = test_file["mode"] 190 | is_image_file = test_file["is_image_file"] 191 | 192 | validate_performance(f_name, validation_mode, is_image_file) 193 | 194 | 195 | if __name__ == '__main__': 196 | cur_dir = os.getcwd() 197 | base_dir = os.path.join(cur_dir, os.pardir, os.pardir) 198 | example_docs_dir = os.path.join(base_dir, "sample-docs") 199 | 200 | # folder path to save temporary outputs 201 | output_dir = os.path.join(cur_dir, "output") 202 | os.makedirs(output_dir, exist_ok=True) 203 | 204 | run() 205 | -------------------------------------------------------------------------------- /img/unstructured_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/img/unstructured_logo.png -------------------------------------------------------------------------------- /logger_config.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | disable_existing_loggers: False 3 | formatters: 4 | default_format: 5 | "()": uvicorn.logging.DefaultFormatter 6 | format: '%(asctime)s %(name)s %(levelname)s %(message)s' 7 | access: 8 | "()": uvicorn.logging.AccessFormatter 9 | format: '%(asctime)s %(client_addr)s %(request_line)s - %(status_code)s' 10 | handlers: 11 | access_handler: 12 | formatter: access 13 | class: logging.StreamHandler 14 | stream: ext://sys.stderr 15 | standard_handler: 16 | formatter: default_format 17 | class: logging.StreamHandler 18 | stream: ext://sys.stderr 19 | loggers: 20 | uvicorn.error: 21 | level: INFO 22 | handlers: 23 | - standard_handler 24 | propagate: no 25 | #
disable logging for uvicorn.error by not having a handler 26 | uvicorn.access: 27 | level: INFO 28 | handlers: 29 | - access_handler 30 | propagate: no 31 | # disable logging for uvicorn.access by not having a handler 32 | unstructured: 33 | level: INFO 34 | handlers: 35 | - standard_handler 36 | propagate: no 37 | unstructured_inference: 38 | level: DEBUG 39 | handlers: 40 | - standard_handler 41 | propagate: no 42 | 43 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 100 3 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | slow: marks tests as slow (deselect with '-m "not slow"') 4 | -------------------------------------------------------------------------------- /requirements/base.in: -------------------------------------------------------------------------------- 1 | -c constraints.in 2 | python-multipart 3 | huggingface-hub 4 | numpy 5 | opencv-python!=4.7.0.68 6 | onnx 7 | onnxruntime>=1.18.0 8 | matplotlib 9 | torch 10 | timm 11 | # NOTE(alan): Pinned because this is when the most recent module we import appeared 12 | transformers>=4.25.1 13 | accelerate 14 | rapidfuzz 15 | pandas 16 | scipy 17 | pypdfium2 18 | pdfminer-six 19 | -------------------------------------------------------------------------------- /requirements/base.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.12 3 | # by the following command: 4 | # 5 | # pip-compile requirements/base.in 6 | # 7 | accelerate==1.7.0 8 | # via -r requirements/base.in 9 | certifi==2025.4.26 10 | # via requests 11 | cffi==1.17.1 12 | # via cryptography 13 | charset-normalizer==3.4.2 14 | # via 15 | # pdfminer-six 16 | # requests 17 | coloredlogs==15.0.1 18 | # via onnxruntime 19 | contourpy==1.3.2 20 | # via matplotlib 21 | cryptography==44.0.3 22 | # via pdfminer-six 23 | cycler==0.12.1 24 | # via matplotlib 25 | filelock==3.18.0 26 | # via 27 | # huggingface-hub 28 | # torch 29 | # transformers 30 | flatbuffers==25.2.10 31 | # via onnxruntime 32 | fonttools==4.58.0 33 | # via matplotlib 34 | fsspec==2025.3.2 35 | # via 36 | # huggingface-hub 37 | # torch 38 | huggingface-hub==0.31.2 39 | # via 40 | # -r requirements/base.in 41 | # accelerate 42 | # timm 43 | # tokenizers 44 | # transformers 45 | humanfriendly==10.0 46 | # via coloredlogs 47 | idna==3.10 48 | # via requests 49 | jinja2==3.1.6 50 | # via torch 51 | kiwisolver==1.4.8 52 | # via matplotlib 53 | markupsafe==3.0.2 54 | # via jinja2 55 | matplotlib==3.10.3 56 | # via -r requirements/base.in 57 | mpmath==1.3.0 58 | # via sympy 59 | networkx==3.4.2 60 | # via torch 61 | numpy==2.2.5 62 | # via 63 | # -r requirements/base.in 64 | # accelerate 65 | # contourpy 66 | # matplotlib 67 | # onnx 68 | # onnxruntime 69 | # opencv-python 70 | # pandas 71 | # scipy 72 | # torchvision 73 | # transformers 74 | onnx==1.18.0 75 | # via -r requirements/base.in 76 | onnxruntime==1.22.0 77 | # via -r requirements/base.in 78 | opencv-python==4.11.0.86 79 | # via -r requirements/base.in 80 | packaging==25.0 81 | # via 82 | # accelerate 83 | # huggingface-hub 84 | # matplotlib 85 | # onnxruntime 86 | # transformers 87 | pandas==2.2.3 88 | # via -r requirements/base.in 89 |
pdfminer-six==20250506 90 | # via -r requirements/base.in 91 | pillow==11.2.1 92 | # via 93 | # matplotlib 94 | # torchvision 95 | protobuf==6.31.0 96 | # via 97 | # onnx 98 | # onnxruntime 99 | psutil==7.0.0 100 | # via accelerate 101 | pycparser==2.22 102 | # via cffi 103 | pyparsing==3.2.3 104 | # via matplotlib 105 | pypdfium2==4.30.1 106 | # via -r requirements/base.in 107 | python-dateutil==2.9.0.post0 108 | # via 109 | # matplotlib 110 | # pandas 111 | python-multipart==0.0.20 112 | # via -r requirements/base.in 113 | pytz==2025.2 114 | # via pandas 115 | pyyaml==6.0.2 116 | # via 117 | # accelerate 118 | # huggingface-hub 119 | # timm 120 | # transformers 121 | rapidfuzz==3.13.0 122 | # via -r requirements/base.in 123 | regex==2024.11.6 124 | # via transformers 125 | requests==2.32.3 126 | # via 127 | # huggingface-hub 128 | # transformers 129 | safetensors==0.5.3 130 | # via 131 | # accelerate 132 | # timm 133 | # transformers 134 | scipy==1.15.3 135 | # via -r requirements/base.in 136 | six==1.17.0 137 | # via python-dateutil 138 | sympy==1.14.0 139 | # via 140 | # onnxruntime 141 | # torch 142 | timm==1.0.15 143 | # via -r requirements/base.in 144 | tokenizers==0.21.1 145 | # via transformers 146 | torch==2.7.0 147 | # via 148 | # -r requirements/base.in 149 | # accelerate 150 | # timm 151 | # torchvision 152 | torchvision==0.22.0 153 | # via timm 154 | tqdm==4.67.1 155 | # via 156 | # huggingface-hub 157 | # transformers 158 | transformers==4.51.3 159 | # via -r requirements/base.in 160 | typing-extensions==4.13.2 161 | # via 162 | # huggingface-hub 163 | # onnx 164 | # torch 165 | tzdata==2025.2 166 | # via pandas 167 | urllib3==2.4.0 168 | # via requests 169 | 170 | # The following packages are considered to be unsafe in a requirements file: 171 | # setuptools 172 | -------------------------------------------------------------------------------- /requirements/constraints.in: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/requirements/constraints.in -------------------------------------------------------------------------------- /requirements/dev.in: -------------------------------------------------------------------------------- 1 | -c constraints.in 2 | -c base.txt 3 | -c test.txt 4 | jupyter 5 | ipython 6 | pip-tools 7 | matplotlib -------------------------------------------------------------------------------- /requirements/dev.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.12 3 | # by the following command: 4 | # 5 | # pip-compile requirements/dev.in 6 | # 7 | anyio==4.9.0 8 | # via 9 | # -c requirements/test.txt 10 | # httpx 11 | # jupyter-server 12 | appnope==0.1.4 13 | # via ipykernel 14 | argon2-cffi==23.1.0 15 | # via jupyter-server 16 | argon2-cffi-bindings==21.2.0 17 | # via argon2-cffi 18 | arrow==1.3.0 19 | # via isoduration 20 | asttokens==3.0.0 21 | # via stack-data 22 | async-lru==2.0.5 23 | # via jupyterlab 24 | attrs==25.3.0 25 | # via 26 | # jsonschema 27 | # referencing 28 | babel==2.17.0 29 | # via jupyterlab-server 30 | beautifulsoup4==4.13.4 31 | # via nbconvert 32 | bleach[css]==6.2.0 33 | # via nbconvert 34 | build==1.2.2.post1 35 | # via pip-tools 36 | certifi==2025.4.26 37 | # via 38 | # -c requirements/base.txt 39 | # -c requirements/test.txt 40 | # httpcore 41 | # httpx 42 | # requests 43 | 
cffi==1.17.1 44 | # via 45 | # -c requirements/base.txt 46 | # argon2-cffi-bindings 47 | charset-normalizer==3.4.2 48 | # via 49 | # -c requirements/base.txt 50 | # -c requirements/test.txt 51 | # requests 52 | click==8.2.0 53 | # via 54 | # -c requirements/test.txt 55 | # pip-tools 56 | comm==0.2.2 57 | # via 58 | # ipykernel 59 | # ipywidgets 60 | contourpy==1.3.2 61 | # via 62 | # -c requirements/base.txt 63 | # matplotlib 64 | cycler==0.12.1 65 | # via 66 | # -c requirements/base.txt 67 | # matplotlib 68 | debugpy==1.8.14 69 | # via ipykernel 70 | decorator==5.2.1 71 | # via ipython 72 | defusedxml==0.7.1 73 | # via nbconvert 74 | executing==2.2.0 75 | # via stack-data 76 | fastjsonschema==2.21.1 77 | # via nbformat 78 | fonttools==4.58.0 79 | # via 80 | # -c requirements/base.txt 81 | # matplotlib 82 | fqdn==1.5.1 83 | # via jsonschema 84 | h11==0.16.0 85 | # via 86 | # -c requirements/test.txt 87 | # httpcore 88 | httpcore==1.0.9 89 | # via 90 | # -c requirements/test.txt 91 | # httpx 92 | httpx==0.28.1 93 | # via 94 | # -c requirements/test.txt 95 | # jupyterlab 96 | idna==3.10 97 | # via 98 | # -c requirements/base.txt 99 | # -c requirements/test.txt 100 | # anyio 101 | # httpx 102 | # jsonschema 103 | # requests 104 | ipykernel==6.29.5 105 | # via 106 | # jupyter 107 | # jupyter-console 108 | # jupyterlab 109 | ipython==9.2.0 110 | # via 111 | # -r requirements/dev.in 112 | # ipykernel 113 | # ipywidgets 114 | # jupyter-console 115 | ipython-pygments-lexers==1.1.1 116 | # via ipython 117 | ipywidgets==8.1.7 118 | # via jupyter 119 | isoduration==20.11.0 120 | # via jsonschema 121 | jedi==0.19.2 122 | # via ipython 123 | jinja2==3.1.6 124 | # via 125 | # -c requirements/base.txt 126 | # jupyter-server 127 | # jupyterlab 128 | # jupyterlab-server 129 | # nbconvert 130 | json5==0.12.0 131 | # via jupyterlab-server 132 | jsonpointer==3.0.0 133 | # via jsonschema 134 | jsonschema[format-nongpl]==4.23.0 135 | # via 136 | # jupyter-events 137 | # jupyterlab-server 138 | # nbformat 139 | jsonschema-specifications==2025.4.1 140 | # via jsonschema 141 | jupyter==1.1.1 142 | # via -r requirements/dev.in 143 | jupyter-client==8.6.3 144 | # via 145 | # ipykernel 146 | # jupyter-console 147 | # jupyter-server 148 | # nbclient 149 | jupyter-console==6.6.3 150 | # via jupyter 151 | jupyter-core==5.7.2 152 | # via 153 | # ipykernel 154 | # jupyter-client 155 | # jupyter-console 156 | # jupyter-server 157 | # jupyterlab 158 | # nbclient 159 | # nbconvert 160 | # nbformat 161 | jupyter-events==0.12.0 162 | # via jupyter-server 163 | jupyter-lsp==2.2.5 164 | # via jupyterlab 165 | jupyter-server==2.16.0 166 | # via 167 | # jupyter-lsp 168 | # jupyterlab 169 | # jupyterlab-server 170 | # notebook 171 | # notebook-shim 172 | jupyter-server-terminals==0.5.3 173 | # via jupyter-server 174 | jupyterlab==4.4.2 175 | # via 176 | # jupyter 177 | # notebook 178 | jupyterlab-pygments==0.3.0 179 | # via nbconvert 180 | jupyterlab-server==2.27.3 181 | # via 182 | # jupyterlab 183 | # notebook 184 | jupyterlab-widgets==3.0.15 185 | # via ipywidgets 186 | kiwisolver==1.4.8 187 | # via 188 | # -c requirements/base.txt 189 | # matplotlib 190 | markupsafe==3.0.2 191 | # via 192 | # -c requirements/base.txt 193 | # jinja2 194 | # nbconvert 195 | matplotlib==3.10.3 196 | # via 197 | # -c requirements/base.txt 198 | # -r requirements/dev.in 199 | matplotlib-inline==0.1.7 200 | # via 201 | # ipykernel 202 | # ipython 203 | mistune==3.1.3 204 | # via nbconvert 205 | nbclient==0.10.2 206 | # via nbconvert 207 | 
nbconvert==7.16.6 208 | # via 209 | # jupyter 210 | # jupyter-server 211 | nbformat==5.10.4 212 | # via 213 | # jupyter-server 214 | # nbclient 215 | # nbconvert 216 | nest-asyncio==1.6.0 217 | # via ipykernel 218 | notebook==7.4.2 219 | # via jupyter 220 | notebook-shim==0.2.4 221 | # via 222 | # jupyterlab 223 | # notebook 224 | numpy==2.2.5 225 | # via 226 | # -c requirements/base.txt 227 | # contourpy 228 | # matplotlib 229 | overrides==7.7.0 230 | # via jupyter-server 231 | packaging==25.0 232 | # via 233 | # -c requirements/base.txt 234 | # -c requirements/test.txt 235 | # build 236 | # ipykernel 237 | # jupyter-events 238 | # jupyter-server 239 | # jupyterlab 240 | # jupyterlab-server 241 | # matplotlib 242 | # nbconvert 243 | pandocfilters==1.5.1 244 | # via nbconvert 245 | parso==0.8.4 246 | # via jedi 247 | pexpect==4.9.0 248 | # via ipython 249 | pillow==11.2.1 250 | # via 251 | # -c requirements/base.txt 252 | # -c requirements/test.txt 253 | # matplotlib 254 | pip-tools==7.4.1 255 | # via -r requirements/dev.in 256 | platformdirs==4.3.8 257 | # via 258 | # -c requirements/test.txt 259 | # jupyter-core 260 | prometheus-client==0.21.1 261 | # via jupyter-server 262 | prompt-toolkit==3.0.51 263 | # via 264 | # ipython 265 | # jupyter-console 266 | psutil==7.0.0 267 | # via 268 | # -c requirements/base.txt 269 | # ipykernel 270 | ptyprocess==0.7.0 271 | # via 272 | # pexpect 273 | # terminado 274 | pure-eval==0.2.3 275 | # via stack-data 276 | pycparser==2.22 277 | # via 278 | # -c requirements/base.txt 279 | # cffi 280 | pygments==2.19.1 281 | # via 282 | # ipython 283 | # ipython-pygments-lexers 284 | # jupyter-console 285 | # nbconvert 286 | pyparsing==3.2.3 287 | # via 288 | # -c requirements/base.txt 289 | # matplotlib 290 | pyproject-hooks==1.2.0 291 | # via 292 | # build 293 | # pip-tools 294 | python-dateutil==2.9.0.post0 295 | # via 296 | # -c requirements/base.txt 297 | # arrow 298 | # jupyter-client 299 | # matplotlib 300 | python-json-logger==3.3.0 301 | # via jupyter-events 302 | pyyaml==6.0.2 303 | # via 304 | # -c requirements/base.txt 305 | # -c requirements/test.txt 306 | # jupyter-events 307 | pyzmq==26.4.0 308 | # via 309 | # ipykernel 310 | # jupyter-client 311 | # jupyter-console 312 | # jupyter-server 313 | referencing==0.36.2 314 | # via 315 | # jsonschema 316 | # jsonschema-specifications 317 | # jupyter-events 318 | requests==2.32.3 319 | # via 320 | # -c requirements/base.txt 321 | # -c requirements/test.txt 322 | # jupyterlab-server 323 | rfc3339-validator==0.1.4 324 | # via 325 | # jsonschema 326 | # jupyter-events 327 | rfc3986-validator==0.1.1 328 | # via 329 | # jsonschema 330 | # jupyter-events 331 | rpds-py==0.25.0 332 | # via 333 | # jsonschema 334 | # referencing 335 | send2trash==1.8.3 336 | # via jupyter-server 337 | six==1.17.0 338 | # via 339 | # -c requirements/base.txt 340 | # python-dateutil 341 | # rfc3339-validator 342 | sniffio==1.3.1 343 | # via 344 | # -c requirements/test.txt 345 | # anyio 346 | soupsieve==2.7 347 | # via beautifulsoup4 348 | stack-data==0.6.3 349 | # via ipython 350 | terminado==0.18.1 351 | # via 352 | # jupyter-server 353 | # jupyter-server-terminals 354 | tinycss2==1.4.0 355 | # via bleach 356 | tornado==6.5 357 | # via 358 | # ipykernel 359 | # jupyter-client 360 | # jupyter-server 361 | # jupyterlab 362 | # notebook 363 | # terminado 364 | traitlets==5.14.3 365 | # via 366 | # comm 367 | # ipykernel 368 | # ipython 369 | # ipywidgets 370 | # jupyter-client 371 | # jupyter-console 372 | # jupyter-core 373 | # 
jupyter-events 374 | # jupyter-server 375 | # jupyterlab 376 | # matplotlib-inline 377 | # nbclient 378 | # nbconvert 379 | # nbformat 380 | types-python-dateutil==2.9.0.20241206 381 | # via arrow 382 | typing-extensions==4.13.2 383 | # via 384 | # -c requirements/base.txt 385 | # -c requirements/test.txt 386 | # anyio 387 | # beautifulsoup4 388 | # referencing 389 | uri-template==1.3.0 390 | # via jsonschema 391 | urllib3==2.4.0 392 | # via 393 | # -c requirements/base.txt 394 | # -c requirements/test.txt 395 | # requests 396 | wcwidth==0.2.13 397 | # via prompt-toolkit 398 | webcolors==24.11.1 399 | # via jsonschema 400 | webencodings==0.5.1 401 | # via 402 | # bleach 403 | # tinycss2 404 | websocket-client==1.8.0 405 | # via jupyter-server 406 | wheel==0.45.1 407 | # via pip-tools 408 | widgetsnbextension==4.0.14 409 | # via ipywidgets 410 | 411 | # The following packages are considered to be unsafe in a requirements file: 412 | # pip 413 | # setuptools 414 | -------------------------------------------------------------------------------- /requirements/test.in: -------------------------------------------------------------------------------- 1 | -c constraints.in 2 | -c base.txt 3 | black>=22.3.0 4 | coverage 5 | # NOTE(mrobinson) - Pinning click due to a unicode issue in black 6 | # can remove after black drops support for Python 3.6 7 | # ref: https://github.com/psf/black/issues/2964 8 | click>=8.1 9 | # NOTE(alan) - Added to cover the fact that it isn't specified in 10 | # starlette even though it's required for TestClient 11 | httpx 12 | flake8 13 | flake8-docstrings 14 | mypy 15 | pytest-cov 16 | pytest-mock 17 | pdf2image>=1.16.2 18 | huggingface_hub>=0.11.1 19 | ruff 20 | types-pyyaml 21 | -------------------------------------------------------------------------------- /requirements/test.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.12 3 | # by the following command: 4 | # 5 | # pip-compile requirements/test.in 6 | # 7 | anyio==4.9.0 8 | # via httpx 9 | black==25.1.0 10 | # via -r requirements/test.in 11 | certifi==2025.4.26 12 | # via 13 | # -c requirements/base.txt 14 | # httpcore 15 | # httpx 16 | # requests 17 | charset-normalizer==3.4.2 18 | # via 19 | # -c requirements/base.txt 20 | # requests 21 | click==8.2.0 22 | # via 23 | # -r requirements/test.in 24 | # black 25 | coverage[toml]==7.8.0 26 | # via 27 | # -r requirements/test.in 28 | # pytest-cov 29 | filelock==3.18.0 30 | # via 31 | # -c requirements/base.txt 32 | # huggingface-hub 33 | flake8==7.2.0 34 | # via 35 | # -r requirements/test.in 36 | # flake8-docstrings 37 | flake8-docstrings==1.7.0 38 | # via -r requirements/test.in 39 | fsspec==2025.3.2 40 | # via 41 | # -c requirements/base.txt 42 | # huggingface-hub 43 | h11==0.16.0 44 | # via httpcore 45 | httpcore==1.0.9 46 | # via httpx 47 | httpx==0.28.1 48 | # via -r requirements/test.in 49 | huggingface-hub==0.31.2 50 | # via 51 | # -c requirements/base.txt 52 | # -r requirements/test.in 53 | idna==3.10 54 | # via 55 | # -c requirements/base.txt 56 | # anyio 57 | # httpx 58 | # requests 59 | iniconfig==2.1.0 60 | # via pytest 61 | mccabe==0.7.0 62 | # via flake8 63 | mypy==1.15.0 64 | # via -r requirements/test.in 65 | mypy-extensions==1.1.0 66 | # via 67 | # black 68 | # mypy 69 | packaging==25.0 70 | # via 71 | # -c requirements/base.txt 72 | # black 73 | # huggingface-hub 74 | # pytest 75 | pathspec==0.12.1 76 | # via black 77 | pdf2image==1.17.0 78 | # 
via -r requirements/test.in 79 | pillow==11.2.1 80 | # via 81 | # -c requirements/base.txt 82 | # pdf2image 83 | platformdirs==4.3.8 84 | # via black 85 | pluggy==1.6.0 86 | # via pytest 87 | pycodestyle==2.13.0 88 | # via flake8 89 | pydocstyle==6.3.0 90 | # via flake8-docstrings 91 | pyflakes==3.3.2 92 | # via flake8 93 | pytest==8.3.5 94 | # via 95 | # pytest-cov 96 | # pytest-mock 97 | pytest-cov==6.1.1 98 | # via -r requirements/test.in 99 | pytest-mock==3.14.0 100 | # via -r requirements/test.in 101 | pyyaml==6.0.2 102 | # via 103 | # -c requirements/base.txt 104 | # huggingface-hub 105 | requests==2.32.3 106 | # via 107 | # -c requirements/base.txt 108 | # huggingface-hub 109 | ruff==0.11.10 110 | # via -r requirements/test.in 111 | sniffio==1.3.1 112 | # via anyio 113 | snowballstemmer==3.0.1 114 | # via pydocstyle 115 | tqdm==4.67.1 116 | # via 117 | # -c requirements/base.txt 118 | # huggingface-hub 119 | types-pyyaml==6.0.12.20250402 120 | # via -r requirements/test.in 121 | typing-extensions==4.13.2 122 | # via 123 | # -c requirements/base.txt 124 | # anyio 125 | # huggingface-hub 126 | # mypy 127 | urllib3==2.4.0 128 | # via 129 | # -c requirements/base.txt 130 | # requests 131 | -------------------------------------------------------------------------------- /sample-docs/2023-Jan-economic-outlook.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/2023-Jan-economic-outlook.pdf -------------------------------------------------------------------------------- /sample-docs/IRS-form-1987.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/IRS-form-1987.pdf -------------------------------------------------------------------------------- /sample-docs/RGBA_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/RGBA_image.png -------------------------------------------------------------------------------- /sample-docs/Silent-Giant.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/Silent-Giant.pdf -------------------------------------------------------------------------------- /sample-docs/design-thinking.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/design-thinking.pdf -------------------------------------------------------------------------------- /sample-docs/easy_table.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/easy_table.jpg -------------------------------------------------------------------------------- /sample-docs/embedded-images.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/embedded-images.pdf -------------------------------------------------------------------------------- /sample-docs/empty-document.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/empty-document.pdf -------------------------------------------------------------------------------- /sample-docs/example_table.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/example_table.jpg -------------------------------------------------------------------------------- /sample-docs/ilpa-example-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/ilpa-example-1.jpg -------------------------------------------------------------------------------- /sample-docs/layout-parser-paper-fast.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/layout-parser-paper-fast.jpg -------------------------------------------------------------------------------- /sample-docs/layout-parser-paper-fast.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/layout-parser-paper-fast.pdf -------------------------------------------------------------------------------- /sample-docs/layout-parser-paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/layout-parser-paper.pdf -------------------------------------------------------------------------------- /sample-docs/loremipsum-flat.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/loremipsum-flat.pdf -------------------------------------------------------------------------------- /sample-docs/loremipsum.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/loremipsum.jpg -------------------------------------------------------------------------------- /sample-docs/loremipsum.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/loremipsum.pdf -------------------------------------------------------------------------------- /sample-docs/loremipsum.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/loremipsum.png -------------------------------------------------------------------------------- /sample-docs/loremipsum.tiff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/loremipsum.tiff -------------------------------------------------------------------------------- /sample-docs/loremipsum_multipage.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/loremipsum_multipage.pdf -------------------------------------------------------------------------------- /sample-docs/non-embedded.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/non-embedded.pdf -------------------------------------------------------------------------------- /sample-docs/password.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/password.pdf -------------------------------------------------------------------------------- /sample-docs/patent-1p.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/patent-1p.pdf -------------------------------------------------------------------------------- /sample-docs/patent.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/patent.pdf -------------------------------------------------------------------------------- /sample-docs/pdf2image-memory-error-test-400p.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/pdf2image-memory-error-test-400p.pdf -------------------------------------------------------------------------------- /sample-docs/recalibrating-risk-report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/recalibrating-risk-report.pdf -------------------------------------------------------------------------------- /sample-docs/receipt-sample.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/receipt-sample.jpg -------------------------------------------------------------------------------- /sample-docs/table-multi-row-column-cells.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/table-multi-row-column-cells.png -------------------------------------------------------------------------------- /sample-docs/test-image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/sample-docs/test-image.jpg -------------------------------------------------------------------------------- /scripts/docker-build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -euo pipefail 4 | PIP_VERSION="${PIP_VERSION:-23.1.2}" 5 | DOCKER_IMAGE="unstructured-inference:dev" 6 | 7 | DOCKER_BUILD_CMD=(docker buildx build --load -f Dockerfile \ 8 | --build-arg PIP_VERSION="$PIP_VERSION" \ 9 | --build-arg BUILDKIT_INLINE_CACHE=1 \ 10 | --progress plain \ 11 | -t "$DOCKER_IMAGE" .) 12 | 13 | DOCKER_BUILDKIT=1 "${DOCKER_BUILD_CMD[@]}" -------------------------------------------------------------------------------- /scripts/shellcheck.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | find scripts -name "*.sh" -exec shellcheck {} + 4 | 5 | -------------------------------------------------------------------------------- /scripts/test-unstructured-ingest-helper.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This is intended to be run from an unstructured checkout, not in this repo 4 | # The goal here is to see what changes the current branch would introduce to unstructured 5 | # fixtures 6 | 7 | INGEST_COMMANDS=( 8 | test_unstructured_ingest/src/azure.sh 9 | test_unstructured_ingest/src/biomed-api.sh 10 | test_unstructured_ingest/src/biomed-path.sh 11 | test_unstructured_ingest/src/box.sh 12 | test_unstructured_ingest/src/dropbox.sh 13 | test_unstructured_ingest/src/gcs.sh 14 | test_unstructured_ingest/src/onedrive.sh 15 | test_unstructured_ingest/src/s3.sh 16 | ) 17 | 18 | EXIT_STATUSES=() 19 | 20 | # Run each command and capture its exit status 21 | for INGEST_COMMAND in "${INGEST_COMMANDS[@]}"; do 22 | "$INGEST_COMMAND" 23 | EXIT_STATUSES+=($?) 24 | done 25 | 26 | # Check for failures 27 | for STATUS in "${EXIT_STATUSES[@]}"; do 28 | if [[ $STATUS -ne 0 ]]; then 29 | echo "At least one ingest command failed! 
Scroll up to see which" 30 | exit 1 31 | fi 32 | done 33 | 34 | echo "No diffs resulted from any ingest commands" 35 | -------------------------------------------------------------------------------- /scripts/version-sync.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | function usage { 3 | echo "Usage: $(basename "$0") [-c] -f FILE_TO_CHANGE REPLACEMENT_FORMAT [-f FILE_TO_CHANGE REPLACEMENT_FORMAT ...]" 4 | echo 'Synchronize files to latest version in source file' 5 | echo ' -s Specifies source file for version (default is CHANGELOG.md)' 6 | echo ' -f Specifies a file to change and the format for searching and replacing versions' 7 | echo ' FILE_TO_CHANGE is the file to be updated/checked for updates' 8 | echo ' REPLACEMENT_FORMAT is one of (semver, release, api-release)' 9 | echo ' semver indicates to look for a full semver version and replace with the latest full version' 10 | echo ' release indicates to look for a release semver version (x.x.x) and replace with the latest release version' 11 | echo ' api-release indicates to look for a release semver version in the context of an api route and replace with the latest release version' 12 | echo ' -c Compare versions and output proposed changes without changing anything.' 13 | } 14 | 15 | function getopts-extra () { 16 | declare -i i=1 17 | # if the next argument is not an option, then append it to array OPTARG 18 | while [[ ${OPTIND} -le $# && ${!OPTIND:0:1} != '-' ]]; do 19 | OPTARG[i]=${!OPTIND} 20 | ((i += 1)) 21 | ((OPTIND += 1)) 22 | done 23 | } 24 | 25 | # Parse input options 26 | declare CHECK=0 27 | declare SOURCE_FILE="CHANGELOG.md" 28 | declare -a FILES_TO_CHECK=() 29 | declare -a REPLACEMENT_FORMATS=() 30 | declare args 31 | declare OPTIND OPTARG opt 32 | while getopts ":hcs:f:" opt; do 33 | case $opt in 34 | h) 35 | usage 36 | exit 0 37 | ;; 38 | c) 39 | CHECK=1 40 | ;; 41 | s) 42 | SOURCE_FILE="$OPTARG" 43 | ;; 44 | f) 45 | getopts-extra "$@" 46 | args=( "${OPTARG[@]}" ) 47 | # validate length of args, should be 2 48 | if [ ${#args[@]} -eq 2 ]; then 49 | FILES_TO_CHECK+=( "${args[0]}" ) 50 | REPLACEMENT_FORMATS+=( "${args[1]}" ) 51 | else 52 | echo "Exactly 2 arguments must follow -f option." >&2 53 | exit 1 54 | fi 55 | ;; 56 | \?) 57 | echo "Invalid option: -$OPTARG." >&2 58 | usage 59 | exit 1 60 | ;; 61 | esac 62 | done 63 | 64 | # Parse REPLACEMENT_FORMATS 65 | RE_SEMVER_FULL="(0|[1-9][0-9]*)\.(0|[1-9][0-9]*)\.(0|[1-9][0-9]*)(-((0|[1-9][0-9]*|[0-9]*[a-zA-Z-][0-9a-zA-Z-]*)(\.(0|[1-9][0-9]*|[0-9]*[a-zA-Z-][0-9a-zA-Z-]*))*))?(\+([0-9a-zA-Z-]+(\.[0-9a-zA-Z-]+)*))?" 66 | RE_RELEASE="(0|[1-9][0-9]*)\.(0|[1-9][0-9]*)\.(0|[1-9][0-9]*)" 67 | RE_API_RELEASE="v(0|[1-9][0-9]*)\.(0|[1-9][0-9]*)\.(0|[1-9][0-9]*)" 68 | # Pull out semver appearing earliest in SOURCE_FILE. 
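# For illustration (hypothetical CHANGELOG contents, not the repo's actual
# entries): if SOURCE_FILE began with a heading like "## 1.2.3-dev0" followed
# later by "## 1.2.2", the greps below would yield LAST_VERSION=1.2.3-dev0
# (first full semver, pre-release suffix included), LAST_RELEASE=1.2.2 (first
# bare x.y.z not followed by a pre-release/build suffix), and
# LAST_API_RELEASE=v1.2.2.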
69 | LAST_VERSION=$(grep -o -m 1 -E "${RE_SEMVER_FULL}" "$SOURCE_FILE") 70 | LAST_RELEASE=$(grep -o -m 1 -E "${RE_RELEASE}($|[^-+])" "$SOURCE_FILE" | grep -o -m 1 -E "${RE_RELEASE}") 71 | LAST_API_RELEASE="v$(grep -o -m 1 -E "${RE_RELEASE}($|[^-+])$" "$SOURCE_FILE" | grep -o -m 1 -E "${RE_RELEASE}")" 72 | declare -a RE_SEMVERS=() 73 | declare -a UPDATED_VERSIONS=() 74 | for i in "${!REPLACEMENT_FORMATS[@]}"; do 75 | REPLACEMENT_FORMAT=${REPLACEMENT_FORMATS[$i]} 76 | case $REPLACEMENT_FORMAT in 77 | semver) 78 | RE_SEMVERS+=( "$RE_SEMVER_FULL" ) 79 | UPDATED_VERSIONS+=( "$LAST_VERSION" ) 80 | ;; 81 | release) 82 | RE_SEMVERS+=( "$RE_RELEASE" ) 83 | UPDATED_VERSIONS+=( "$LAST_RELEASE" ) 84 | ;; 85 | api-release) 86 | RE_SEMVERS+=( "$RE_API_RELEASE" ) 87 | UPDATED_VERSIONS+=( "$LAST_API_RELEASE" ) 88 | ;; 89 | *) 90 | echo "Invalid replacement format: \"${REPLACEMENT_FORMAT}\". Use semver, release, or api-release" >&2 91 | exit 1 92 | ;; 93 | esac 94 | done 95 | 96 | if [ -z "$LAST_VERSION" ]; 97 | then 98 | # No match to semver regex in SOURCE_FILE, so no version to go from. 99 | printf "Error: Unable to find latest version from %s.\n" "$SOURCE_FILE" 100 | exit 1 101 | fi 102 | 103 | # Search files in FILES_TO_CHECK and change (or get diffs) 104 | declare FAILED_CHECK=0 105 | 106 | for i in "${!FILES_TO_CHECK[@]}"; do 107 | FILE_TO_CHANGE=${FILES_TO_CHECK[$i]} 108 | RE_SEMVER=${RE_SEMVERS[$i]} 109 | UPDATED_VERSION=${UPDATED_VERSIONS[$i]} 110 | FILE_VERSION=$(grep -o -m 1 -E "${RE_SEMVER}" "$FILE_TO_CHANGE") 111 | if [ -z "$FILE_VERSION" ]; 112 | then 113 | # No match to semver regex in VERSIONFILE, so nothing to replace 114 | printf "Error: No semver version found in file %s.\n" "$FILE_TO_CHANGE" 115 | exit 1 116 | else 117 | # Replace semver in VERSIONFILE with semver obtained from SOURCE_FILE 118 | TMPFILE=$(mktemp /tmp/new_version.XXXXXX) 119 | # Check sed version, exit if version < 4.3 120 | if ! sed --version > /dev/null 2>&1; then 121 | CURRENT_VERSION=1.archaic 122 | else 123 | CURRENT_VERSION=$(sed --version | head -n1 | cut -d" " -f4) 124 | fi 125 | REQUIRED_VERSION="4.3" 126 | if [ "$(printf '%s\n' "$REQUIRED_VERSION" "$CURRENT_VERSION" | sort -V | head -n1)" != "$REQUIRED_VERSION" ]; then 127 | echo "sed version must be >= ${REQUIRED_VERSION}" && exit 1 128 | fi 129 | sed -E -r "s/$RE_SEMVER/$UPDATED_VERSION/" "$FILE_TO_CHANGE" > "$TMPFILE" 130 | if [ $CHECK == 1 ]; 131 | then 132 | DIFF=$(diff "$FILE_TO_CHANGE" "$TMPFILE" ) 133 | if [ -z "$DIFF" ]; 134 | then 135 | printf "version sync would make no changes to %s.\n" "$FILE_TO_CHANGE" 136 | rm "$TMPFILE" 137 | else 138 | FAILED_CHECK=1 139 | printf "version sync would make the following changes to %s:\n%s\n" "$FILE_TO_CHANGE" "$DIFF" 140 | rm "$TMPFILE" 141 | fi 142 | else 143 | cp "$TMPFILE" "$FILE_TO_CHANGE" 144 | rm "$TMPFILE" 145 | fi 146 | fi 147 | done 148 | 149 | # Exit with code determined by whether changes were needed in a check. 
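# (For illustration, a hypothetical check-only run — the file argument here is
# illustrative, the real targets are wired up elsewhere, e.g. in the Makefile:
#   scripts/version-sync.sh -c -f unstructured_inference/__version__.py semver
# prints the diff the sync would apply to __version__.py and leaves
# FAILED_CHECK=1 so the block below exits non-zero; without -c the new version
# is written in place instead.)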
150 | if [ ${FAILED_CHECK} -ne 0 ]; then 151 | exit 1 152 | else 153 | exit 0 154 | fi 155 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | license_files = LICENSE.md 3 | 4 | [flake8] 5 | max-line-length = 100 6 | extend-ignore = D100, D101, D104, D105, D107, D2, D4 7 | per-file-ignores = 8 | test_*/**: D 9 | 10 | [tool:pytest] 11 | filterwarnings = 12 | ignore::DeprecationWarning 13 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | setup.py 3 | 4 | unstructured_inference - Tools to utilize trained models 5 | 6 | Copyright 2022 Unstructured Technologies, Inc. 7 | 8 | Licensed under the Apache License, Version 2.0 (the "License"); 9 | you may not use this file except in compliance with the License. 10 | You may obtain a copy of the License at 11 | 12 | http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | Unless required by applicable law or agreed to in writing, software 15 | distributed under the License is distributed on an "AS IS" BASIS, 16 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | See the License for the specific language governing permissions and 18 | limitations under the License. 19 | """ 20 | from typing import List, Optional, Union 21 | 22 | from setuptools import find_packages, setup 23 | 24 | from unstructured_inference.__version__ import __version__ 25 | 26 | 27 | def load_requirements(file_list: Optional[Union[str, List[str]]] = None): 28 | """Loads the requirements from a .in file or list of .in files.""" 29 | if file_list is None: 30 | file_list = ["requirements/base.in"] 31 | if isinstance(file_list, str): 32 | file_list = [file_list] 33 | requirements: List[str] = [] 34 | for file in file_list: 35 | with open(file, encoding="utf-8") as f: 36 | requirements.extend(f.readlines()) 37 | requirements = [ 38 | req for req in requirements if not req.startswith("#") and not req.startswith("-") 39 | ] 40 | return requirements 41 | 42 | 43 | def load_text_from_file(filename: str): 44 | """Retrieves text from a file.""" 45 | with open(filename, encoding="utf-8") as fp: 46 | description = fp.read() 47 | return description 48 | 49 | 50 | setup( 51 | name="unstructured_inference", 52 | description="A library for performing inference using trained models.", 53 | long_description=load_text_from_file("README.md"), 54 | long_description_content_type="text/markdown", 55 | keywords="NLP PDF HTML CV XML parsing preprocessing", 56 | url="https://github.com/Unstructured-IO/unstructured-inference", 57 | python_requires=">=3.7.0", 58 | classifiers=[ 59 | "Development Status :: 4 - Beta", 60 | "Intended Audience :: Developers", 61 | "Intended Audience :: Education", 62 | "Intended Audience :: Science/Research", 63 | "License :: OSI Approved :: Apache Software License", 64 | "Operating System :: OS Independent", 65 | "Programming Language :: Python :: 3", 66 | "Programming Language :: Python :: 3.8", 67 | "Programming Language :: Python :: 3.9", 68 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 69 | ], 70 | author="Unstructured Technologies", 71 | author_email="devops@unstructuredai.io", 72 | license="Apache-2.0", 73 | packages=find_packages(), 74 | version=__version__, 75 | entry_points={}, 76 | install_requires=load_requirements(), 77 | ) 78 | 
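A minimal, self-contained sketch of the filtering that load_requirements applies; the in-memory lines below are illustrative stand-ins for a real requirements/*.in file, not the repo's actual pins:

    # stand-in for open("requirements/base.in", encoding="utf-8").readlines()
    lines = ["-c constraints.in\n", "# deep learning stack\n", "torch\n"]
    # same filter as in setup.py: drop comment lines and "-"-prefixed directives
    requirements = [req for req in lines if not req.startswith("#") and not req.startswith("-")]
    assert requirements == ["torch\n"]  # readlines() keeps each entry's trailing newline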
-------------------------------------------------------------------------------- /test_unstructured_inference/conftest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from PIL import Image 4 | 5 | from unstructured_inference.inference.elements import ( 6 | EmbeddedTextRegion, 7 | Rectangle, 8 | TextRegion, 9 | ) 10 | from unstructured_inference.inference.layoutelement import LayoutElement 11 | 12 | 13 | @pytest.fixture() 14 | def mock_pil_image(): 15 | return Image.new("RGB", (50, 50)) 16 | 17 | 18 | @pytest.fixture() 19 | def mock_numpy_image(): 20 | return np.zeros((50, 50, 3), np.uint8) 21 | 22 | 23 | @pytest.fixture() 24 | def mock_rectangle(): 25 | return Rectangle(100, 100, 300, 300) 26 | 27 | 28 | @pytest.fixture() 29 | def mock_text_region(): 30 | return TextRegion.from_coords(100, 100, 300, 300, text="Sample text") 31 | 32 | 33 | @pytest.fixture() 34 | def mock_layout_element(): 35 | return LayoutElement.from_coords( 36 | 100, 37 | 100, 38 | 300, 39 | 300, 40 | text="Sample text", 41 | source=None, 42 | type="Text", 43 | ) 44 | 45 | 46 | @pytest.fixture() 47 | def mock_embedded_text_regions(): 48 | return [ 49 | EmbeddedTextRegion.from_coords( 50 | x1=453.00277777777774, 51 | y1=317.319341111111, 52 | x2=711.5338541666665, 53 | y2=358.28571222222206, 54 | text="LayoutParser:", 55 | ), 56 | EmbeddedTextRegion.from_coords( 57 | x1=726.4778125, 58 | y1=317.319341111111, 59 | x2=760.3308594444444, 60 | y2=357.1698966666667, 61 | text="A", 62 | ), 63 | EmbeddedTextRegion.from_coords( 64 | x1=775.2748177777777, 65 | y1=317.319341111111, 66 | x2=917.3579885555555, 67 | y2=357.1698966666667, 68 | text="Unified", 69 | ), 70 | EmbeddedTextRegion.from_coords( 71 | x1=932.3019468888888, 72 | y1=317.319341111111, 73 | x2=1071.8426522222221, 74 | y2=357.1698966666667, 75 | text="Toolkit", 76 | ), 77 | EmbeddedTextRegion.from_coords( 78 | x1=1086.7866105555556, 79 | y1=317.319341111111, 80 | x2=1141.2105142777777, 81 | y2=357.1698966666667, 82 | text="for", 83 | ), 84 | EmbeddedTextRegion.from_coords( 85 | x1=1156.154472611111, 86 | y1=317.319341111111, 87 | x2=1256.334784222222, 88 | y2=357.1698966666667, 89 | text="Deep", 90 | ), 91 | EmbeddedTextRegion.from_coords( 92 | x1=437.83888888888885, 93 | y1=367.13322999999986, 94 | x2=610.0171992222222, 95 | y2=406.9837855555556, 96 | text="Learning", 97 | ), 98 | EmbeddedTextRegion.from_coords( 99 | x1=624.9611575555555, 100 | y1=367.13322999999986, 101 | x2=741.6754646666665, 102 | y2=406.9837855555556, 103 | text="Based", 104 | ), 105 | EmbeddedTextRegion.from_coords( 106 | x1=756.619423, 107 | y1=367.13322999999986, 108 | x2=958.3867708333332, 109 | y2=406.9837855555556, 110 | text="Document", 111 | ), 112 | EmbeddedTextRegion.from_coords( 113 | x1=973.3307291666665, 114 | y1=367.13322999999986, 115 | x2=1092.0535042777776, 116 | y2=406.9837855555556, 117 | text="Image", 118 | ), 119 | ] 120 | 121 | 122 | # TODO(alan): Make a better test layout 123 | @pytest.fixture() 124 | def mock_layout(mock_embedded_text_regions): 125 | return [ 126 | LayoutElement(text=r.text, type="UncategorizedText", bbox=r.bbox) 127 | for r in mock_embedded_text_regions 128 | ] 129 | 130 | 131 | @pytest.fixture() 132 | def example_table_cells(): 133 | cells = [ 134 | {"cell text": "Disability Category", "row_nums": [0, 1], "column_nums": [0]}, 135 | {"cell text": "Participants", "row_nums": [0, 1], "column_nums": [1]}, 136 | {"cell text": "Ballots Completed", "row_nums": [0, 1], 
"column_nums": [2]}, 137 | {"cell text": "Ballots Incomplete/Terminated", "row_nums": [0, 1], "column_nums": [3]}, 138 | {"cell text": "Results", "row_nums": [0], "column_nums": [4, 5]}, 139 | {"cell text": "Accuracy", "row_nums": [1], "column_nums": [4]}, 140 | {"cell text": "Time to complete", "row_nums": [1], "column_nums": [5]}, 141 | {"cell text": "Blind", "row_nums": [2], "column_nums": [0]}, 142 | {"cell text": "Low Vision", "row_nums": [3], "column_nums": [0]}, 143 | {"cell text": "Dexterity", "row_nums": [4], "column_nums": [0]}, 144 | {"cell text": "Mobility", "row_nums": [5], "column_nums": [0]}, 145 | {"cell text": "5", "row_nums": [2], "column_nums": [1]}, 146 | {"cell text": "5", "row_nums": [3], "column_nums": [1]}, 147 | {"cell text": "5", "row_nums": [4], "column_nums": [1]}, 148 | {"cell text": "3", "row_nums": [5], "column_nums": [1]}, 149 | {"cell text": "1", "row_nums": [2], "column_nums": [2]}, 150 | {"cell text": "2", "row_nums": [3], "column_nums": [2]}, 151 | {"cell text": "4", "row_nums": [4], "column_nums": [2]}, 152 | {"cell text": "3", "row_nums": [5], "column_nums": [2]}, 153 | {"cell text": "4", "row_nums": [2], "column_nums": [3]}, 154 | {"cell text": "3", "row_nums": [3], "column_nums": [3]}, 155 | {"cell text": "1", "row_nums": [4], "column_nums": [3]}, 156 | {"cell text": "0", "row_nums": [5], "column_nums": [3]}, 157 | {"cell text": "34.5%, n=1", "row_nums": [2], "column_nums": [4]}, 158 | {"cell text": "98.3% n=2 (97.7%, n=3)", "row_nums": [3], "column_nums": [4]}, 159 | {"cell text": "98.3%, n=4", "row_nums": [4], "column_nums": [4]}, 160 | {"cell text": "95.4%, n=3", "row_nums": [5], "column_nums": [4]}, 161 | {"cell text": "1199 sec, n=1", "row_nums": [2], "column_nums": [5]}, 162 | {"cell text": "1716 sec, n=3 (1934 sec, n=2)", "row_nums": [3], "column_nums": [5]}, 163 | {"cell text": "1672.1 sec, n=4", "row_nums": [4], "column_nums": [5]}, 164 | {"cell text": "1416 sec, n=3", "row_nums": [5], "column_nums": [5]}, 165 | ] 166 | for i in range(len(cells)): 167 | cells[i]["column header"] = False 168 | return [cells] 169 | -------------------------------------------------------------------------------- /test_unstructured_inference/inference/test_layout_element.py: -------------------------------------------------------------------------------- 1 | from unstructured_inference.inference.layoutelement import LayoutElement, TextRegion 2 | 3 | 4 | def test_layout_element_do_dict(mock_layout_element): 5 | expected = { 6 | "coordinates": ((100, 100), (100, 300), (300, 300), (300, 100)), 7 | "text": "Sample text", 8 | "type": "Text", 9 | "prob": None, 10 | "source": None, 11 | } 12 | 13 | assert mock_layout_element.to_dict() == expected 14 | 15 | 16 | def test_layout_element_from_region(mock_rectangle): 17 | expected = LayoutElement.from_coords(100, 100, 300, 300) 18 | region = TextRegion(bbox=mock_rectangle) 19 | 20 | assert LayoutElement.from_region(region) == expected 21 | -------------------------------------------------------------------------------- /test_unstructured_inference/models/test_detectron2onnx.py: -------------------------------------------------------------------------------- 1 | import os 2 | from unittest.mock import patch 3 | 4 | import pytest 5 | from PIL import Image 6 | 7 | import unstructured_inference.models.base as models 8 | import unstructured_inference.models.detectron2onnx as detectron2 9 | 10 | 11 | class MockDetectron2ONNXLayoutModel: 12 | def __init__(self, *args, **kwargs): 13 | self.args = args 14 | self.kwargs = kwargs 15 
| 16 | def run(self, *args): 17 | return ([(1, 2, 3, 4)], [0], [(4, 5)], [0.818]) 18 | 19 | def get_inputs(self): 20 | class input_thing: 21 | name = "Bernard" 22 | 23 | return [input_thing()] 24 | 25 | 26 | def test_load_default_model(monkeypatch): 27 | monkeypatch.setattr(models, "models", {}) 28 | with patch.object( 29 | detectron2.onnxruntime, 30 | "InferenceSession", 31 | new=MockDetectron2ONNXLayoutModel, 32 | ): 33 | model = models.get_model("detectron2_mask_rcnn") 34 | 35 | assert isinstance(model.model, MockDetectron2ONNXLayoutModel) 36 | 37 | 38 | @pytest.mark.parametrize(("model_path", "label_map"), [("asdf", "diufs"), ("dfaw", "hfhfhfh")]) 39 | def test_load_model(model_path, label_map): 40 | with patch.object(detectron2.onnxruntime, "InferenceSession", return_value=True): 41 | model = detectron2.UnstructuredDetectronONNXModel() 42 | model.initialize(model_path=model_path, label_map=label_map) 43 | args, _ = detectron2.onnxruntime.InferenceSession.call_args 44 | assert args == (model_path,) 45 | assert label_map == model.label_map 46 | 47 | 48 | def test_unstructured_detectron_model(): 49 | model = detectron2.UnstructuredDetectronONNXModel() 50 | model.model = 1 51 | with patch.object(detectron2.UnstructuredDetectronONNXModel, "predict", return_value=[]): 52 | result = model(None) 53 | assert isinstance(result, list) 54 | assert len(result) == 0 55 | 56 | 57 | def test_inference(): 58 | with patch.object( 59 | detectron2.onnxruntime, 60 | "InferenceSession", 61 | return_value=MockDetectron2ONNXLayoutModel(), 62 | ): 63 | model = detectron2.UnstructuredDetectronONNXModel() 64 | model.initialize(model_path="test_path", label_map={0: "test_class"}) 65 | assert isinstance(model.model, MockDetectron2ONNXLayoutModel) 66 | with open(os.path.join("sample-docs", "receipt-sample.jpg"), mode="rb") as fp: 67 | image = Image.open(fp) 68 | image.load() 69 | elements = model(image) 70 | assert len(elements) == 1 71 | element = elements[0] 72 | (x1, y1), _, (x2, y2), _ = element.bbox.coordinates 73 | assert hasattr( 74 | element, 75 | "prob", 76 | ) # NOTE(pravin) New Assertion to Make Sure element has probabilities 77 | assert isinstance( 78 | element.prob, 79 | float, 80 | ) # NOTE(pravin) New Assertion to Make Sure Populated Probability is Float 81 | # NOTE(alan): The bbox coordinates get resized, so check their relative proportions 82 | assert x2 / x1 == pytest.approx(3.0) # x1 == 1, x2 == 3 before scaling 83 | assert y2 / y1 == pytest.approx(2.0) # y1 == 2, y2 == 4 before scaling 84 | assert element.type == "test_class" 85 | -------------------------------------------------------------------------------- /test_unstructured_inference/models/test_eval.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from unstructured_inference.inference.layoutelement import table_cells_to_dataframe 4 | from unstructured_inference.models.eval import compare_contents_as_df, default_tokenizer 5 | 6 | 7 | @pytest.fixture() 8 | def actual_cells(): 9 | return [ 10 | { 11 | "column_nums": [0], 12 | "row_nums": [0, 1], 13 | "column header": True, 14 | "cell text": "Disability Category", 15 | }, 16 | { 17 | "column_nums": [1], 18 | "row_nums": [0, 1], 19 | "column header": True, 20 | "cell text": "Participants", 21 | }, 22 | { 23 | "column_nums": [2], 24 | "row_nums": [0, 1], 25 | "column header": True, 26 | "cell text": "Ballots Completed", 27 | }, 28 | { 29 | "column_nums": [3], 30 | "row_nums": [0, 1], 31 | "column header": True, 32 | "cell text": 
"Ballots Incomplete/Terminated", 33 | }, 34 | {"column_nums": [4, 5], "row_nums": [0], "column header": True, "cell text": "Results"}, 35 | {"column_nums": [4], "row_nums": [1], "column header": False, "cell text": "Accuracy"}, 36 | { 37 | "column_nums": [5], 38 | "row_nums": [1], 39 | "column header": False, 40 | "cell text": "Time to complete", 41 | }, 42 | {"column_nums": [0], "row_nums": [2], "column header": False, "cell text": "Blind"}, 43 | {"column_nums": [0], "row_nums": [3], "column header": False, "cell text": "Low Vision"}, 44 | {"column_nums": [0], "row_nums": [4], "column header": False, "cell text": "Dexterity"}, 45 | {"column_nums": [0], "row_nums": [5], "column header": False, "cell text": "Mobility"}, 46 | {"column_nums": [1], "row_nums": [2], "column header": False, "cell text": "5"}, 47 | {"column_nums": [1], "row_nums": [3], "column header": False, "cell text": "5"}, 48 | {"column_nums": [1], "row_nums": [4], "column header": False, "cell text": "5"}, 49 | {"column_nums": [1], "row_nums": [5], "column header": False, "cell text": "3"}, 50 | {"column_nums": [2], "row_nums": [2], "column header": False, "cell text": "1"}, 51 | {"column_nums": [2], "row_nums": [3], "column header": False, "cell text": "2"}, 52 | {"column_nums": [2], "row_nums": [4], "column header": False, "cell text": "4"}, 53 | {"column_nums": [2], "row_nums": [5], "column header": False, "cell text": "3"}, 54 | {"column_nums": [3], "row_nums": [2], "column header": False, "cell text": "4"}, 55 | {"column_nums": [3], "row_nums": [3], "column header": False, "cell text": "3"}, 56 | {"column_nums": [3], "row_nums": [4], "column header": False, "cell text": "1"}, 57 | {"column_nums": [3], "row_nums": [5], "column header": False, "cell text": "0"}, 58 | {"column_nums": [4], "row_nums": [2], "column header": False, "cell text": "34.5%, n=1"}, 59 | { 60 | "column_nums": [4], 61 | "row_nums": [3], 62 | "column header": False, 63 | "cell text": "98.3% n=2 (97.7%, n=3)", 64 | }, 65 | {"column_nums": [4], "row_nums": [4], "column header": False, "cell text": "98.3%, n=4"}, 66 | {"column_nums": [4], "row_nums": [5], "column header": False, "cell text": "95.4%, n=3"}, 67 | {"column_nums": [5], "row_nums": [2], "column header": False, "cell text": "1199 sec, n=1"}, 68 | { 69 | "column_nums": [5], 70 | "row_nums": [3], 71 | "column header": False, 72 | "cell text": "1716 sec, n=3 (1934 sec, n=2)", 73 | }, 74 | { 75 | "column_nums": [5], 76 | "row_nums": [4], 77 | "column header": False, 78 | "cell text": "1672.1 sec, n=4", 79 | }, 80 | {"column_nums": [5], "row_nums": [5], "column header": False, "cell text": "1416 sec, n=3"}, 81 | ] 82 | 83 | 84 | @pytest.fixture() 85 | def pred_cells(): 86 | return [ 87 | {"column_nums": [0], "row_nums": [2], "column header": False, "cell text": "Blind"}, 88 | {"column_nums": [0], "row_nums": [3], "column header": False, "cell text": "Low Vision"}, 89 | {"column_nums": [0], "row_nums": [4], "column header": False, "cell text": "Dexterity"}, 90 | {"column_nums": [0], "row_nums": [5], "column header": False, "cell text": "Mobility"}, 91 | {"column_nums": [1], "row_nums": [2], "column header": False, "cell text": "5"}, 92 | {"column_nums": [1], "row_nums": [3], "column header": False, "cell text": "5"}, 93 | {"column_nums": [1], "row_nums": [4], "column header": False, "cell text": "5"}, 94 | {"column_nums": [1], "row_nums": [5], "column header": False, "cell text": "3"}, 95 | {"column_nums": [2], "row_nums": [2], "column header": False, "cell text": "1"}, 96 | {"column_nums": [2], 
"row_nums": [3], "column header": False, "cell text": "2"}, 97 | {"column_nums": [2], "row_nums": [4], "column header": False, "cell text": "4"}, 98 | {"column_nums": [2], "row_nums": [5], "column header": False, "cell text": "3"}, 99 | {"column_nums": [3], "row_nums": [2], "column header": False, "cell text": "4"}, 100 | {"column_nums": [3], "row_nums": [3], "column header": False, "cell text": "3"}, 101 | {"column_nums": [3], "row_nums": [4], "column header": False, "cell text": "1"}, 102 | {"column_nums": [3], "row_nums": [5], "column header": False, "cell text": "0"}, 103 | {"column_nums": [4], "row_nums": [1], "column header": False, "cell text": "Accuracy"}, 104 | {"column_nums": [4], "row_nums": [2], "column header": False, "cell text": "34.5%, n=1"}, 105 | { 106 | "column_nums": [4], 107 | "row_nums": [3], 108 | "column header": False, 109 | "cell text": "98.3% n=2 (97.7%, n=3)", 110 | }, 111 | {"column_nums": [4], "row_nums": [4], "column header": False, "cell text": "98.3%, n=4"}, 112 | {"column_nums": [4], "row_nums": [5], "column header": False, "cell text": "95.4%, n=3"}, 113 | { 114 | "column_nums": [5], 115 | "row_nums": [1], 116 | "column header": False, 117 | "cell text": "Time to complete", 118 | }, 119 | {"column_nums": [5], "row_nums": [2], "column header": False, "cell text": "1199 sec, n=1"}, 120 | { 121 | "column_nums": [5], 122 | "row_nums": [3], 123 | "column header": False, 124 | "cell text": "1716 sec, n=3 | (1934 sec, n=2)", 125 | }, 126 | { 127 | "column_nums": [5], 128 | "row_nums": [4], 129 | "column header": False, 130 | "cell text": "1672.1 sec, n=4", 131 | }, 132 | {"column_nums": [5], "row_nums": [5], "column header": False, "cell text": "1416 sec, n=3"}, 133 | { 134 | "column_nums": [0], 135 | "row_nums": [0, 1], 136 | "column header": True, 137 | "cell text": "soa etealeiliay Category", 138 | }, 139 | {"column_nums": [4, 5], "row_nums": [0], "column header": True, "cell text": "Results"}, 140 | { 141 | "column_nums": [1], 142 | "row_nums": [0, 1], 143 | "column header": True, 144 | "cell text": "Participants P", 145 | }, 146 | { 147 | "column_nums": [2], 148 | "row_nums": [0, 1], 149 | "column header": True, 150 | "cell text": "pallets Completed", 151 | }, 152 | { 153 | "column_nums": [3], 154 | "row_nums": [0, 1], 155 | "column header": True, 156 | "cell text": "Ballot: incom lete/ Ne Terminated", 157 | }, 158 | ] 159 | 160 | 161 | @pytest.fixture() 162 | def actual_df(actual_cells): 163 | return table_cells_to_dataframe(actual_cells).fillna("") 164 | 165 | 166 | @pytest.fixture() 167 | def pred_df(pred_cells): 168 | return table_cells_to_dataframe(pred_cells).fillna("") 169 | 170 | 171 | @pytest.mark.parametrize( 172 | ("eval_func", "processor"), 173 | [ 174 | ("token_ratio", default_tokenizer), 175 | ("token_ratio", None), 176 | ("partial_token_ratio", default_tokenizer), 177 | ("ratio", None), 178 | ("ratio", default_tokenizer), 179 | ("partial_ratio", default_tokenizer), 180 | ], 181 | ) 182 | def test_compare_content_as_df(actual_df, pred_df, eval_func, processor): 183 | results = compare_contents_as_df(actual_df, pred_df, eval_func=eval_func, processor=processor) 184 | assert 0 < results.get(f"by_col_{eval_func}") < 100 185 | 186 | 187 | def test_compare_content_as_df_with_invalid_input(actual_df, pred_df): 188 | with pytest.raises(ValueError, match="eval_func must be one of"): 189 | compare_contents_as_df(actual_df, pred_df, eval_func="foo") 190 | -------------------------------------------------------------------------------- 
/test_unstructured_inference/models/test_model.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any 3 | from unittest import mock 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | import unstructured_inference.models.base as models 9 | from unstructured_inference.inference.layoutelement import LayoutElement, LayoutElements 10 | from unstructured_inference.models.unstructuredmodel import ( 11 | ModelNotInitializedError, 12 | UnstructuredObjectDetectionModel, 13 | ) 14 | 15 | 16 | class MockModel(UnstructuredObjectDetectionModel): 17 | call_count = 0 18 | 19 | def __init__(self): 20 | self.initializer = mock.MagicMock() 21 | super().__init__() 22 | 23 | def initialize(self, *args, **kwargs): 24 | return self.initializer(self, *args, **kwargs) 25 | 26 | def predict(self, x: Any) -> Any: 27 | return LayoutElements(element_coords=np.array([])) 28 | 29 | 30 | MOCK_MODEL_TYPES = { 31 | "foo": { 32 | "input_shape": (640, 640), 33 | }, 34 | } 35 | 36 | 37 | def test_get_model(monkeypatch): 38 | monkeypatch.setattr(models, "models", {}) 39 | with mock.patch.dict(models.model_class_map, {"yolox": MockModel}): 40 | assert isinstance(models.get_model("yolox"), MockModel) 41 | 42 | 43 | def test_register_new_model(): 44 | assert "foo" not in models.model_class_map 45 | assert "foo" not in models.model_config_map 46 | models.register_new_model(MOCK_MODEL_TYPES, MockModel) 47 | assert "foo" in models.model_class_map 48 | assert "foo" in models.model_config_map 49 | model = models.get_model("foo") 50 | assert len(model.initializer.mock_calls) == 1 51 | assert model.initializer.mock_calls[0][-1] == MOCK_MODEL_TYPES["foo"] 52 | assert isinstance(model, MockModel) 53 | # unregister the new model by resetting to the default mappings 54 | models.model_class_map, models.model_config_map = models.get_default_model_mappings() 55 | assert "foo" not in models.model_class_map 56 | assert "foo" not in models.model_config_map 57 | 58 | 59 | def test_raises_invalid_model(): 60 | with pytest.raises(models.UnknownModelException): 61 | models.get_model("fake_model") 62 | 63 | 64 | def test_raises_uninitialized(): 65 | with pytest.raises(ModelNotInitializedError): 66 | models.UnstructuredDetectronONNXModel().predict(None) 67 | 68 | 69 | def test_model_initializes_once(): 70 | from unstructured_inference.inference import layout 71 | 72 | with mock.patch.dict(models.model_class_map, {"yolox": MockModel}), mock.patch.object( 73 | models, 74 | "models", 75 | {}, 76 | ): 77 | doc = layout.DocumentLayout.from_file("sample-docs/loremipsum.pdf") 78 | doc.pages[0].detection_model.initializer.assert_called_once() 79 | 80 | 81 | def test_deduplicate_detected_elements(): 82 | import numpy as np 83 | 84 | from unstructured_inference.inference.elements import intersections 85 | from unstructured_inference.inference.layout import DocumentLayout 86 | from unstructured_inference.models.base import get_model 87 | 88 | model = get_model("yolox_quantized") 89 | # model.confidence_threshold=0.5 90 | file = "sample-docs/example_table.jpg" 91 | doc = DocumentLayout.from_image_file( 92 | file, 93 | model, 94 | ) 95 | known_elements = [e.bbox for e in doc.pages[0].elements if e.type != "UncategorizedText"] 96 | # Compute intersection matrix 97 | intersections_mtx = intersections(*known_elements) 98 | # Get rid of the diagonal (because an element always intersects itself) 99 | np.fill_diagonal(intersections_mtx, False) 100 | # Now all the entries should be False, because no intersection 
remains 101 | assert not intersections_mtx.any() 102 | 103 | 104 | def test_enhance_regions(): 105 | from unstructured_inference.inference.elements import Rectangle 106 | from unstructured_inference.models.base import get_model 107 | 108 | elements = [ 109 | LayoutElement(bbox=Rectangle(0, 0, 1, 1)), 110 | LayoutElement(bbox=Rectangle(0.01, 0.01, 1.01, 1.01)), 111 | LayoutElement(bbox=Rectangle(0.02, 0.02, 1.02, 1.02)), 112 | LayoutElement(bbox=Rectangle(0.03, 0.03, 1.03, 1.03)), 113 | LayoutElement(bbox=Rectangle(0.04, 0.04, 1.04, 1.04)), 114 | LayoutElement(bbox=Rectangle(0.05, 0.05, 1.05, 1.05)), 115 | LayoutElement(bbox=Rectangle(0.06, 0.06, 1.06, 1.06)), 116 | LayoutElement(bbox=Rectangle(0.07, 0.07, 1.07, 1.07)), 117 | LayoutElement(bbox=Rectangle(0.08, 0.08, 1.08, 1.08)), 118 | LayoutElement(bbox=Rectangle(0.09, 0.09, 1.09, 1.09)), 119 | LayoutElement(bbox=Rectangle(0.10, 0.10, 1.10, 1.10)), 120 | ] 121 | model = get_model("yolox_tiny") 122 | elements = model.enhance_regions(elements, 0.5) 123 | assert len(elements) == 1 124 | assert ( 125 | elements[0].bbox.x1, 126 | elements[0].bbox.y1, 127 | elements[0].bbox.x2, 128 | elements[0].bbox.x2, 129 | ) == ( 130 | 0, 131 | 0, 132 | 1.10, 133 | 1.10, 134 | ) 135 | 136 | 137 | def test_clean_type(): 138 | from unstructured_inference.inference.layout import LayoutElement 139 | from unstructured_inference.models.base import get_model 140 | 141 | elements = [ 142 | LayoutElement.from_coords( 143 | 0.6, 144 | 0.6, 145 | 0.65, 146 | 0.65, 147 | type="Table", 148 | ), # One little table nested inside all the others 149 | LayoutElement.from_coords(0.5, 0.5, 0.7, 0.7, type="Table"), # One nested table 150 | LayoutElement.from_coords(0, 0, 1, 1, type="Table"), # Big table 151 | LayoutElement.from_coords(0.01, 0.01, 1.01, 1.01), 152 | LayoutElement.from_coords(0.02, 0.02, 1.02, 1.02), 153 | LayoutElement.from_coords(0.03, 0.03, 1.03, 1.03), 154 | LayoutElement.from_coords(0.04, 0.04, 1.04, 1.04), 155 | LayoutElement.from_coords(0.05, 0.05, 1.05, 1.05), 156 | ] 157 | model = get_model("yolox_tiny") 158 | elements = model.clean_type(elements, type_to_clean="Table") 159 | assert len(elements) == 1 160 | assert ( 161 | elements[0].bbox.x1, 162 | elements[0].bbox.y1, 163 | elements[0].bbox.x2, 164 | elements[0].bbox.x2, 165 | ) == (0, 0, 1, 1) 166 | 167 | 168 | def test_env_variables_override_default_model(monkeypatch): 169 | # When an environment variable specifies a different default model and we call get_model with no 170 | # args, we should get back the model the env var calls for 171 | monkeypatch.setattr(models, "models", {}) 172 | with mock.patch.dict( 173 | models.os.environ, 174 | {"UNSTRUCTURED_DEFAULT_MODEL_NAME": "yolox"}, 175 | ), mock.patch.dict(models.model_class_map, {"yolox": MockModel}): 176 | model = models.get_model() 177 | assert isinstance(model, MockModel) 178 | 179 | 180 | def test_env_variables_override_initialization_params(monkeypatch): 181 | # When initialization params are specified in an environment variable, and we call get_model, we 182 | # should see that the model was initialized with those params 183 | monkeypatch.setattr(models, "models", {}) 184 | fake_label_map = {"1": "label1", "2": "label2"} 185 | with mock.patch.dict( 186 | models.os.environ, 187 | {"UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH": "fake_json.json"}, 188 | ), mock.patch.object(models, "DEFAULT_MODEL", "fake"), mock.patch.dict( 189 | models.model_class_map, 190 | {"fake": mock.MagicMock()}, 191 | ), mock.patch( 192 | "builtins.open", 193 | 
mock.mock_open( 194 | read_data='{"model_path": "fakepath", "label_map": ' + json.dumps(fake_label_map) + "}", 195 | ), 196 | ): 197 | model = models.get_model() 198 | model.initialize.assert_called_once_with( 199 | model_path="fakepath", 200 | label_map={1: "label1", 2: "label2"}, 201 | ) 202 | -------------------------------------------------------------------------------- /test_unstructured_inference/models/test_yolox.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | from unstructured_inference.inference.layout import process_file_with_model 6 | 7 | 8 | @pytest.mark.slow() 9 | def test_layout_yolox_local_parsing_image(): 10 | filename = os.path.join("sample-docs", "test-image.jpg") 11 | # NOTE(benjamin) keep_output = True creates a file for each image in 12 | # local storage for visualization of the result 13 | document_layout = process_file_with_model(filename, model_name="yolox", is_image=True) 14 | # NOTE(benjamin) The example image should produce a one-page result 15 | assert len(document_layout.pages) == 1 16 | # NOTE(benjamin) The example sent to the test contains 13 detections 17 | types_known = ["Text", "Section-header", "Page-header"] 18 | elements = document_layout.pages[0].elements_array 19 | known_regions = [ 20 | e for e in elements.element_class_ids if elements.element_class_id_map[e] in types_known 21 | ] 22 | assert len(known_regions) == 13 23 | # NOTE(pravin) New Assertion to Make Sure LayoutElement has probabilities 24 | assert hasattr(elements, "element_probs") 25 | assert isinstance( 26 | elements.element_probs[0], 27 | float, 28 | ) # NOTE(pravin) New Assertion to Make Sure Populated Probability is Float 29 | 30 | 31 | @pytest.mark.slow() 32 | def test_layout_yolox_local_parsing_pdf(): 33 | filename = os.path.join("sample-docs", "loremipsum.pdf") 34 | document_layout = process_file_with_model(filename, model_name="yolox") 35 | assert len(document_layout.pages) == 1 36 | # NOTE(benjamin) The example sent to the test contains 5 text detections 37 | text_elements = [e for e in document_layout.pages[0].elements if e.type == "Text"] 38 | assert len(text_elements) == 5 39 | assert hasattr( 40 | document_layout.pages[0].elements[0], 41 | "prob", 42 | ) # NOTE(pravin) New Assertion to Make Sure LayoutElement has probabilities 43 | assert isinstance( 44 | document_layout.pages[0].elements[0].prob, 45 | float, 46 | ) # NOTE(pravin) New Assertion to Make Sure Populated Probability is Float 47 | 48 | 49 | @pytest.mark.slow() 50 | def test_layout_yolox_local_parsing_empty_pdf(): 51 | filename = os.path.join("sample-docs", "empty-document.pdf") 52 | document_layout = process_file_with_model(filename, model_name="yolox") 53 | assert len(document_layout.pages) == 1 54 | # NOTE(benjamin) The example sent to the test contains 0 detections 55 | assert len(document_layout.pages[0].elements) == 0 56 | 57 | 58 | ######################## 59 | # ONLY SHORT TESTS BELOW 60 | ######################## 61 | 62 | 63 | def test_layout_yolox_local_parsing_image_soft(): 64 | filename = os.path.join("sample-docs", "example_table.jpg") 65 | # NOTE(benjamin) keep_output = True creates a file for each image in 66 | # local storage for visualization of the result 67 | document_layout = process_file_with_model(filename, model_name="yolox_quantized", is_image=True) 68 | # NOTE(benjamin) The example image should produce a one-page result 69 | assert len(document_layout.pages) == 1 70 | # NOTE(benjamin) Soft version of the test, run 
make test-long in order to run with full model 71 | assert len(document_layout.pages[0].elements) > 0 72 | assert hasattr( 73 | document_layout.pages[0].elements[0], 74 | "prob", 75 | ) # NOTE(pravin) New Assertion to Make Sure LayoutElement has probabilities 76 | assert isinstance( 77 | document_layout.pages[0].elements[0].prob, 78 | float, 79 | ) # NOTE(pravin) New Assertion to Make Sure Populated Probability is Float 80 | 81 | 82 | def test_layout_yolox_local_parsing_pdf_soft(): 83 | filename = os.path.join("sample-docs", "loremipsum.pdf") 84 | document_layout = process_file_with_model(filename, model_name="yolox_tiny") 85 | assert len(document_layout.pages) == 1 86 | # NOTE(benjamin) Soft version of the test, run make test-long in order to run with full model 87 | assert len(document_layout.pages[0].elements) > 0 88 | assert hasattr( 89 | document_layout.pages[0].elements[0], 90 | "prob", 91 | ) # NOTE(pravin) New Assertion to Make Sure LayoutElement has probabilities 92 | 93 | 94 | def test_layout_yolox_local_parsing_empty_pdf_soft(): 95 | filename = os.path.join("sample-docs", "empty-document.pdf") 96 | document_layout = process_file_with_model(filename, model_name="yolox_tiny") 97 | assert len(document_layout.pages) == 1 98 | # NOTE(benjamin) The example sent to the test contains 0 detections 99 | text_elements_page_1 = [el for el in document_layout.pages[0].elements if el.type != "Image"] 100 | assert len(text_elements_page_1) == 0 101 | -------------------------------------------------------------------------------- /test_unstructured_inference/test_config.py: -------------------------------------------------------------------------------- 1 | def test_default_config(): 2 | from unstructured_inference.config import inference_config 3 | 4 | assert inference_config.TT_TABLE_CONF == 0.5 5 | 6 | 7 | def test_env_override(monkeypatch): 8 | monkeypatch.setenv("TT_TABLE_CONF", 1) 9 | from unstructured_inference.config import inference_config 10 | 11 | assert inference_config.TT_TABLE_CONF == 1 12 | -------------------------------------------------------------------------------- /test_unstructured_inference/test_logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pytest 4 | 5 | from unstructured_inference import logger 6 | 7 | 8 | @pytest.mark.parametrize("level", range(50)) 9 | def test_translate_log_level(level): 10 | level_name = logging.getLevelName(level) 11 | if level_name in ["WARNING", "INFO", "DEBUG", "NOTSET", "WARN"]: 12 | expected = 4 13 | elif level_name in ["ERROR", "CRITICAL"]: 14 | expected = 3 15 | else: 16 | expected = 0 17 | assert logger.translate_log_level(level) == expected 18 | -------------------------------------------------------------------------------- /test_unstructured_inference/test_math.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from unstructured_inference.math import FLOAT_EPSILON, safe_division 5 | 6 | 7 | @pytest.mark.parametrize( 8 | ("a", "b", "expected"), 9 | [(0, 0, 0), (0, 1, 0), (1, 0, np.round(1 / FLOAT_EPSILON, 1)), (2, 3, 0.7)], 10 | ) 11 | def test_safe_division(a, b, expected): 12 | assert np.round(safe_division(a, b), 1) == expected 13 | -------------------------------------------------------------------------------- /test_unstructured_inference/test_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 
3 | 4 | from unstructured_inference.inference.layout import DocumentLayout 5 | from unstructured_inference.utils import ( 6 | LazyDict, 7 | LazyEvaluateInfo, 8 | pad_image_with_background_color, 9 | strip_tags, 10 | ) 11 | 12 | 13 | # Mocking the DocumentLayout and Page classes 14 | class MockPageLayout: 15 | def annotate(self, annotation_data): 16 | return "mock_image" 17 | 18 | 19 | class MockDocumentLayout(DocumentLayout): 20 | @property 21 | def pages(self): 22 | return [MockPageLayout(), MockPageLayout()] 23 | 24 | 25 | def test_dict_same(): 26 | d = {"a": 1, "b": 2, "c": 3} 27 | ld = LazyDict(**d) 28 | assert all(kd == kld for kd, kld in zip(d, ld)) 29 | assert all(d[k] == ld[k] for k in d) 30 | assert len(ld) == len(d) 31 | 32 | 33 | def test_lazy_evaluate(): 34 | called = 0 35 | 36 | def func(x): 37 | nonlocal called 38 | called += 1 39 | return x 40 | 41 | lei = LazyEvaluateInfo(func, 3) 42 | assert called == 0 43 | ld = LazyDict(a=lei) 44 | assert called == 0 45 | assert ld["a"] == 3 46 | assert called == 1 47 | 48 | 49 | @pytest.mark.parametrize(("cache", "expected"), [(True, 1), (False, 2)]) 50 | def test_caches(cache, expected): 51 | called = 0 52 | 53 | def func(x): 54 | nonlocal called 55 | called += 1 56 | return x 57 | 58 | lei = LazyEvaluateInfo(func, 3) 59 | assert called == 0 60 | ld = LazyDict(cache=cache, a=lei) 61 | assert called == 0 62 | assert ld["a"] == 3 63 | assert ld["a"] == 3 64 | assert called == expected 65 | 66 | 67 | def test_pad_image_with_background_color(mock_pil_image): 68 | pad = 10 69 | height, width = mock_pil_image.size 70 | padded = pad_image_with_background_color(mock_pil_image, pad, "black") 71 | assert padded.size == (height + 2 * pad, width + 2 * pad) 72 | np.testing.assert_array_almost_equal( 73 | np.array(padded.crop((pad, pad, width + pad, height + pad))), 74 | np.array(mock_pil_image), 75 | ) 76 | assert padded.getpixel((1, 1)) == (0, 0, 0) 77 | 78 | 79 | def test_pad_image_with_invalid_input(mock_pil_image): 80 | with pytest.raises(ValueError, match="Can not pad an image with negative space!"): 81 | pad_image_with_background_color(mock_pil_image, -1) 82 | 83 | 84 | @pytest.mark.parametrize( 85 | ("html", "text"), 86 | [ 87 | ("Table
", "Table"), 88 | # test escaped character 89 | ("y<x, x>z
", "yz"), 90 | # test tag with parameters 91 | ("Table", "Table"), 92 | ], 93 | ) 94 | def test_strip_tags(html, text): 95 | assert strip_tags(html) == text 96 | -------------------------------------------------------------------------------- /test_unstructured_inference/test_visualization.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock, patch 2 | 3 | import numpy as np 4 | import pytest 5 | from PIL import Image 6 | 7 | from unstructured_inference.inference.elements import TextRegion 8 | from unstructured_inference.visualize import draw_bbox, show_plot 9 | 10 | 11 | def test_draw_bbox(): 12 | test_image_arr = np.ones((100, 100, 3), dtype="uint8") 13 | image = Image.fromarray(test_image_arr) 14 | x1, y1, x2, y2 = (1, 10, 7, 11) 15 | rect = TextRegion.from_coords(x1, y1, x2, y2) 16 | annotated_image = draw_bbox(image=image, element=rect, details=False) 17 | annotated_array = np.array(annotated_image) 18 | # Make sure the pixels on the edge of the box are red 19 | for i, expected in zip(range(3), [255, 0, 0]): 20 | assert all(annotated_array[y1, x1:x2, i] == expected) 21 | assert all(annotated_array[y2, x1:x2, i] == expected) 22 | assert all(annotated_array[y1:y2, x1, i] == expected) 23 | assert all(annotated_array[y1:y2, x2, i] == expected) 24 | # Make sure almost all the pixels are not changed 25 | assert ((annotated_array[:, :, 0] == 1).mean()) > 0.995 26 | assert ((annotated_array[:, :, 1] == 1).mean()) > 0.995 27 | assert ((annotated_array[:, :, 2] == 1).mean()) > 0.995 28 | 29 | 30 | def test_show_plot_with_pil_image(mock_pil_image): 31 | mock_fig = MagicMock() 32 | mock_ax = MagicMock() 33 | 34 | with patch( 35 | "matplotlib.pyplot.subplots", 36 | return_value=(mock_fig, mock_ax), 37 | ) as mock_subplots, patch("matplotlib.pyplot.show") as mock_show, patch.object( 38 | mock_ax, 39 | "imshow", 40 | ) as mock_imshow: 41 | show_plot(mock_pil_image, desired_width=100) 42 | 43 | mock_subplots.assert_called() 44 | mock_imshow.assert_called_with(mock_pil_image) 45 | mock_show.assert_called() 46 | 47 | 48 | def test_show_plot_with_numpy_image(mock_numpy_image): 49 | mock_fig = MagicMock() 50 | mock_ax = MagicMock() 51 | 52 | with patch( 53 | "matplotlib.pyplot.subplots", 54 | return_value=(mock_fig, mock_ax), 55 | ) as mock_subplots, patch("matplotlib.pyplot.show") as mock_show, patch.object( 56 | mock_ax, 57 | "imshow", 58 | ) as mock_imshow: 59 | show_plot(mock_numpy_image) 60 | 61 | mock_subplots.assert_called() 62 | mock_imshow.assert_called_with(mock_numpy_image) 63 | mock_show.assert_called() 64 | 65 | 66 | def test_show_plot_with_unsupported_image_type(): 67 | with pytest.raises(ValueError) as exec_info: 68 | show_plot("unsupported_image_type") 69 | 70 | assert "Unsupported Image Type" in str(exec_info.value) 71 | -------------------------------------------------------------------------------- /unstructured_inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/unstructured_inference/__init__.py -------------------------------------------------------------------------------- /unstructured_inference/__version__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.6" # pragma: no cover 2 | -------------------------------------------------------------------------------- 
/unstructured_inference/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains variables that are permitted to be tweaked via the system environment. For 3 | example, model parameters that change the output of an inference call. Constants do NOT belong in 4 | this module. Constants are values that are usually names for common options (e.g., color names) or 5 | settings that should not be altered without making a code change (e.g., the definition of 1GB of memory 6 | in bytes). Constants should go into `./constants.py` 7 | """ 8 | 9 | import os 10 | from dataclasses import dataclass 11 | 12 | 13 | @dataclass 14 | class InferenceConfig: 15 | """class for configuring inference parameters""" 16 | 17 | def _get_string(self, var: str, default_value: str = "") -> str: 18 | """attempt to get the value of var from the os environment; if not present return the 19 | default_value""" 20 | return os.environ.get(var, default_value) 21 | 22 | def _get_int(self, var: str, default_value: int) -> int: 23 | if value := self._get_string(var): 24 | return int(value) 25 | return default_value 26 | 27 | def _get_float(self, var: str, default_value: float) -> float: 28 | if value := self._get_string(var): 29 | return float(value) 30 | return default_value 31 | 32 | @property 33 | def TABLE_IMAGE_BACKGROUND_PAD(self) -> int: 34 | """number of pixels of white background color to pad around a table image 35 | 36 | The padding adds NO image data around an identified table bounding box; it simply adds white 37 | background around the image 38 | """ 39 | return self._get_int("TABLE_IMAGE_BACKGROUND_PAD", 20) 40 | 41 | @property 42 | def TT_TABLE_CONF(self) -> float: 43 | """confidence threshold for a table identified by the table transformer""" 44 | return self._get_float("TT_TABLE_CONF", 0.5) 45 | 46 | @property 47 | def TABLE_COLUMN_CONF(self) -> float: 48 | """confidence threshold for a column identified by the table transformer""" 49 | return self._get_float("TABLE_COLUMN_CONF", 0.5) 50 | 51 | @property 52 | def TABLE_ROW_CONF(self) -> float: 53 | """confidence threshold for a row identified by the table transformer""" 54 | return self._get_float("TABLE_ROW_CONF", 0.5) 55 | 56 | @property 57 | def TABLE_COLUMN_HEADER_CONF(self) -> float: 58 | """confidence threshold for a column header identified by the table transformer""" 59 | return self._get_float("TABLE_COLUMN_HEADER_CONF", 0.5) 60 | 61 | @property 62 | def TABLE_PROJECTED_ROW_HEADER_CONF(self) -> float: 63 | """confidence threshold for a projected row header identified by the table transformer""" 64 | return self._get_float("TABLE_PROJECTED_ROW_HEADER_CONF", 0.5) 65 | 66 | @property 67 | def TABLE_SPANNING_CELL_CONF(self) -> float: 68 | """confidence threshold for table spanning cells identified by the table transformer""" 69 | return self._get_float("TABLE_SPANNING_CELL_CONF", 0.5) 70 | 71 | @property 72 | def TABLE_IOB_THRESHOLD(self) -> float: 73 | """minimum intersection over box area ratio for a box to be considered part of a larger box 74 | it intersects""" 75 | return self._get_float("TABLE_IOB_THRESHOLD", 0.5) 76 | 77 | @property 78 | def LAYOUT_SAME_REGION_THRESHOLD(self) -> float: 79 | """threshold for two layout elements' bounding boxes to be considered the same region 80 | 81 | When the intersection area over union area of the two is larger than this threshold, the two 82 | boxes are considered the same region 83 | """ 84 | return self._get_float("LAYOUT_SAME_REGION_THRESHOLD", 0.75) 85 | 86 | @property 87 | def
LAYOUT_SUBREGION_THRESHOLD(self) -> float: 88 | """threshold for one bounding box to be considered as a sub-region of another bounding box 89 | 90 | When the intersection region area divided by self area is larger than this threshold self is 91 | considered a subregion of the other 92 | """ 93 | return self._get_float("LAYOUT_SUBREGION_THRESHOLD", 0.75) 94 | 95 | @property 96 | def ELEMENTS_H_PADDING_COEF(self) -> float: 97 | """When extending the boundaries of a PDF object for the purpose of determining which other 98 | elements should be considered in the same text region, we use a relative distance based on 99 | some fraction of the block height (typically character height). This is the fraction used 100 | for the horizontal extension applied to the left and right sides. 101 | """ 102 | return self._get_float("ELEMENTS_H_PADDING_COEF", 0.4) 103 | 104 | @property 105 | def ELEMENTS_V_PADDING_COEF(self) -> float: 106 | """Same as ELEMENTS_H_PADDING_COEF but the vertical extension.""" 107 | return self._get_float("ELEMENTS_V_PADDING_COEF", 0.3) 108 | 109 | @property 110 | def IMG_PROCESSOR_LONGEST_EDGE(self) -> int: 111 | """configuration for DetrImageProcessor to scale images""" 112 | return self._get_int("IMG_PROCESSOR_LONGEST_EDGE", 1333) 113 | 114 | @property 115 | def IMG_PROCESSOR_SHORTEST_EDGE(self) -> int: 116 | """configuration for DetrImageProcessor to scale images""" 117 | return self._get_int("IMG_PROCESSOR_SHORTEST_EDGE", 800) 118 | 119 | 120 | inference_config = InferenceConfig() 121 | -------------------------------------------------------------------------------- /unstructured_inference/constants.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class Source(Enum): 5 | YOLOX = "yolox" 6 | DETECTRON2_ONNX = "detectron2_onnx" 7 | DETECTRON2_LP = "detectron2_lp" 8 | MERGED = "merged" 9 | 10 | 11 | class ElementType: 12 | PARAGRAPH = "Paragraph" 13 | IMAGE = "Image" 14 | PARAGRAPH_IN_IMAGE = "ParagraphInImage" 15 | FIGURE = "Figure" 16 | PICTURE = "Picture" 17 | TABLE = "Table" 18 | PARAGRAPH_IN_TABLE = "ParagraphInTable" 19 | LIST = "List" 20 | FORM = "Form" 21 | PARAGRAPH_IN_FORM = "ParagraphInForm" 22 | CHECK_BOX_CHECKED = "CheckBoxChecked" 23 | CHECK_BOX_UNCHECKED = "CheckBoxUnchecked" 24 | RADIO_BUTTON_CHECKED = "RadioButtonChecked" 25 | RADIO_BUTTON_UNCHECKED = "RadioButtonUnchecked" 26 | LIST_ITEM = "List-item" 27 | FORMULA = "Formula" 28 | CAPTION = "Caption" 29 | PAGE_HEADER = "Page-header" 30 | SECTION_HEADER = "Section-header" 31 | PAGE_FOOTER = "Page-footer" 32 | FOOTNOTE = "Footnote" 33 | TITLE = "Title" 34 | TEXT = "Text" 35 | UNCATEGORIZED_TEXT = "UncategorizedText" 36 | PAGE_BREAK = "PageBreak" 37 | CODE_SNIPPET = "CodeSnippet" 38 | PAGE_NUMBER = "PageNumber" 39 | OTHER = "Other" 40 | 41 | 42 | FULL_PAGE_REGION_THRESHOLD = 0.99 43 | 44 | # this field is defined by pytesseract/unstructured.pytesseract 45 | TESSERACT_TEXT_HEIGHT = "height" 46 | -------------------------------------------------------------------------------- /unstructured_inference/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/unstructured_inference/inference/__init__.py -------------------------------------------------------------------------------- /unstructured_inference/inference/elements.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from copy import deepcopy 4 | from dataclasses import dataclass, field 5 | from functools import cached_property 6 | from typing import Optional, Union 7 | 8 | import numpy as np 9 | 10 | from unstructured_inference.constants import Source 11 | from unstructured_inference.math import safe_division 12 | 13 | 14 | @dataclass 15 | class Rectangle: 16 | x1: Union[int, float] 17 | y1: Union[int, float] 18 | x2: Union[int, float] 19 | y2: Union[int, float] 20 | 21 | def pad(self, padding: Union[int, float]): 22 | """Increases (or decreases, if padding is negative) the size of the rectangle by extending 23 | the boundary outward (resp. inward).""" 24 | out_object = self.hpad(padding).vpad(padding) 25 | return out_object 26 | 27 | def hpad(self, padding: Union[int, float]): 28 | """Increases (or decreases, if padding is negative) the size of the rectangle by extending 29 | the left and right sides of the boundary outward (resp. inward).""" 30 | out_object = deepcopy(self) 31 | out_object.x1 -= padding 32 | out_object.x2 += padding 33 | return out_object 34 | 35 | def vpad(self, padding: Union[int, float]): 36 | """Increases (or decreases, if padding is negative) the size of the rectangle by extending 37 | the top and bottom of the boundary outward (resp. inward).""" 38 | out_object = deepcopy(self) 39 | out_object.y1 -= padding 40 | out_object.y2 += padding 41 | return out_object 42 | 43 | @property 44 | def width(self) -> Union[int, float]: 45 | """Width of rectangle""" 46 | return self.x2 - self.x1 47 | 48 | @property 49 | def height(self) -> Union[int, float]: 50 | """Height of rectangle""" 51 | return self.y2 - self.y1 52 | 53 | @property 54 | def x_midpoint(self) -> Union[int, float]: 55 | """Finds the horizontal midpoint of the object.""" 56 | return (self.x2 + self.x1) / 2 57 | 58 | @property 59 | def y_midpoint(self) -> Union[int, float]: 60 | """Finds the vertical midpoint of the object.""" 61 | return (self.y2 + self.y1) / 2 62 | 63 | def is_disjoint(self, other: Rectangle) -> bool: 64 | """Checks whether this rectangle is disjoint from another rectangle.""" 65 | return not self.intersects(other) 66 | 67 | def intersects(self, other: Rectangle) -> bool: 68 | """Checks whether this rectangle intersects another rectangle.""" 69 | if self._has_none() or other._has_none(): 70 | return False 71 | return intersections(self, other)[0, 1] 72 | 73 | def is_in(self, other: Rectangle, error_margin: Optional[Union[int, float]] = None) -> bool: 74 | """Checks whether this rectangle is contained within another rectangle.""" 75 | padded_other = other.pad(error_margin) if error_margin is not None else other 76 | return all( 77 | [ 78 | (self.x1 >= padded_other.x1), 79 | (self.x2 <= padded_other.x2), 80 | (self.y1 >= padded_other.y1), 81 | (self.y2 <= padded_other.y2), 82 | ], 83 | ) 84 | 85 | def _has_none(self) -> bool: 86 | """return false when one of the coord is nan""" 87 | return any((self.x1 is None, self.x2 is None, self.y1 is None, self.y2 is None)) 88 | 89 | @property 90 | def coordinates(self): 91 | """Gets coordinates of the rectangle""" 92 | return ((self.x1, self.y1), (self.x1, self.y2), (self.x2, self.y2), (self.x2, self.y1)) 93 | 94 | def intersection(self, other: Rectangle) -> Optional[Rectangle]: 95 | """Gives the rectangle that is the intersection of two rectangles, or None if the 96 | rectangles are disjoint.""" 97 | if self._has_none() or other._has_none(): 
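# NOTE - _has_none() is True when any coordinate is None; the intersection is then undefined, so treat the rectangles as disjoint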
98 | return None 99 | x1 = max(self.x1, other.x1) 100 | x2 = min(self.x2, other.x2) 101 | y1 = max(self.y1, other.y1) 102 | y2 = min(self.y2, other.y2) 103 | if x1 > x2 or y1 > y2: 104 | return None 105 | return Rectangle(x1, y1, x2, y2) 106 | 107 | @property 108 | def area(self) -> float: 109 | """Gives the area of the rectangle.""" 110 | return self.width * self.height 111 | 112 | def intersection_over_union(self, other: Rectangle) -> float: 113 | """Gives the intersection-over-union of two rectangles. This tends to be a good metric of 114 | how similar the regions are. Returns 0 for disjoint rectangles, 1 for two identical 115 | rectangles -- area of intersection / area of union.""" 116 | intersection = self.intersection(other) 117 | intersection_area = 0.0 if intersection is None else intersection.area 118 | union_area = self.area + other.area - intersection_area 119 | return safe_division(intersection_area, union_area) 120 | 121 | def intersection_over_minimum(self, other: Rectangle) -> float: 122 | """Gives the area-of-intersection over the minimum of the areas of the rectangles. Useful 123 | for identifying when one rectangle is almost-a-subset of the other. Returns 0 for disjoint 124 | rectangles, 1 when either is a subset of the other.""" 125 | intersection = self.intersection(other) 126 | intersection_area = 0.0 if intersection is None else intersection.area 127 | min_area = min(self.area, other.area) 128 | return safe_division(intersection_area, min_area) 129 | 130 | def is_almost_subregion_of(self, other: Rectangle, subregion_threshold: float = 0.75) -> bool: 131 | """Returns whether this region is almost a subregion of other. This is determined by 132 | comparing the intersection area over self area to some threshold, and checking whether self 133 | is the smaller rectangle.""" 134 | intersection = self.intersection(other) 135 | intersection_area = 0.0 if intersection is None else intersection.area 136 | return (subregion_threshold < safe_division(intersection_area, self.area)) and ( 137 | self.area <= other.area 138 | ) 139 | 140 | 141 | def minimal_containing_region(*regions: Rectangle) -> Rectangle: 142 | """Returns the smallest rectangular region that contains all regions passed""" 143 | x1 = min(region.x1 for region in regions) 144 | y1 = min(region.y1 for region in regions) 145 | x2 = max(region.x2 for region in regions) 146 | y2 = max(region.y2 for region in regions) 147 | 148 | return Rectangle(x1, y1, x2, y2) 149 | 150 | 151 | def intersections(*rects: Rectangle): 152 | """Returns a square boolean matrix of intersections of an arbitrary number of rectangles, i.e. 153 | the ijth entry of the matrix is True if and only if the ith Rectangle and jth Rectangle 154 | intersect.""" 155 | # NOTE(alan): Rewrite using line scan 156 | coords = np.array([[r.x1, r.y1, r.x2, r.y2] for r in rects]) 157 | return coords_intersections(coords) 158 | 159 | 160 | def coords_intersections(coords: np.ndarray) -> np.ndarray: 161 | """Returns a square boolean matrix of intersections of given stack of coords, i.e. 162 | the ijth entry of the matrix is True if and only if the ith coords and jth coords 163 | intersect.""" 164 | x1s, y1s, x2s, y2s = coords[:, 0], coords[:, 1], coords[:, 2], coords[:, 3] 165 | 166 | # Use broadcasting to get comparison matrices. 
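# x1s[None] has shape (1, N) and x2s[..., None] has shape (N, 1), so each comparison below broadcasts to an (N, N) boolean matrix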
167 | # For Rectangles r1 and r2, any of the following conditions makes the rectangles disjoint: 168 | # r1.x1 > r2.x2 169 | # r1.y1 > r2.y2 170 | # r2.x1 > r1.x2 171 | # r2.y1 > r1.y2 172 | # Then we take the complement (~) of the disjointness matrix to get the intersection matrix. 173 | intersections = ~( 174 | (x1s[None] > x2s[..., None]) 175 | | (y1s[None] > y2s[..., None]) 176 | | (x1s[None] > x2s[..., None]).T 177 | | (y1s[None] > y2s[..., None]).T 178 | ) 179 | 180 | return intersections 181 | 182 | 183 | @dataclass 184 | class TextRegion: 185 | bbox: Rectangle 186 | text: Optional[str] = None 187 | source: Optional[Source] = None 188 | 189 | def __str__(self) -> str: 190 | return str(self.text) 191 | 192 | @classmethod 193 | def from_coords( 194 | cls, 195 | x1: Union[int, float], 196 | y1: Union[int, float], 197 | x2: Union[int, float], 198 | y2: Union[int, float], 199 | text: Optional[str] = None, 200 | source: Optional[Source] = None, 201 | **kwargs, 202 | ) -> TextRegion: 203 | """Constructs a region from coordinates.""" 204 | bbox = Rectangle(x1, y1, x2, y2) 205 | 206 | return cls(text=text, source=source, bbox=bbox, **kwargs) 207 | 208 | 209 | @dataclass 210 | class TextRegions: 211 | element_coords: np.ndarray 212 | texts: np.ndarray = field(default_factory=lambda: np.array([])) 213 | sources: np.ndarray = field(default_factory=lambda: np.array([])) 214 | source: Source | None = None 215 | 216 | def __post_init__(self): 217 | if self.texts.size == 0 and self.element_coords.size > 0: 218 | self.texts = np.array([None] * self.element_coords.shape[0]) 219 | 220 | # for backward compatibility; also allow to use one value to set sources for all regions 221 | if self.sources.size == 0 and self.element_coords.size > 0: 222 | self.sources = np.array([self.source] * self.element_coords.shape[0]) 223 | elif self.source is None and self.sources.size: 224 | self.source = self.sources[0] 225 | 226 | # we convert to float so data type is more consistent (e.g., None will be np.nan) 227 | self.element_coords = self.element_coords.astype(float) 228 | 229 | def __getitem__(self, indices) -> TextRegions: 230 | return self.slice(indices) 231 | 232 | def slice(self, indices) -> TextRegions: 233 | """slice text regions based on indices""" 234 | return TextRegions( 235 | element_coords=self.element_coords[indices], 236 | texts=self.texts[indices], 237 | sources=self.sources[indices], 238 | ) 239 | 240 | def iter_elements(self): 241 | """iter text regions as one TextRegion per iteration; this returns a generator and has less 242 | memory impact than the as_list method""" 243 | for (x1, y1, x2, y2), text, source in zip( 244 | self.element_coords, 245 | self.texts, 246 | self.sources, 247 | ): 248 | yield TextRegion.from_coords(x1, y1, x2, y2, text, source) 249 | 250 | def as_list(self): 251 | """return a list of LayoutElement for backward compatibility""" 252 | return list(self.iter_elements()) 253 | 254 | @classmethod 255 | def from_list(cls, regions: list): 256 | """create TextRegions from a list of TextRegion objects; the objects must have the same 257 | source""" 258 | coords, texts, sources = [], [], [] 259 | for region in regions: 260 | coords.append((region.bbox.x1, region.bbox.y1, region.bbox.x2, region.bbox.y2)) 261 | texts.append(region.text) 262 | sources.append(region.source) 263 | return cls( 264 | element_coords=np.array(coords), 265 | texts=np.array(texts), 266 | sources=np.array(sources), 267 | ) 268 | 269 | def __len__(self): 270 | return self.element_coords.shape[0] 271 | 272 | 
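# NOTE - a minimal illustrative sketch (not exercised by the module itself): keeping all
# boxes in one (N, 4) float array lets slicing and area computation stay vectorized, e.g.
#   regions = TextRegions.from_list(
#       [TextRegion.from_coords(0, 0, 10, 10, "a"), TextRegion.from_coords(5, 5, 20, 20, "b")]
#   )
#   regions[:1]    # still a TextRegions, sliced along the first axis
#   regions.areas  # array([100., 225.]), one numpy expression instead of a Python loop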
@property 273 | def x1(self): 274 | """left coordinate""" 275 | return self.element_coords[:, 0] 276 | 277 | @property 278 | def y1(self): 279 | """top coordinate""" 280 | return self.element_coords[:, 1] 281 | 282 | @property 283 | def x2(self): 284 | """right coordinate""" 285 | return self.element_coords[:, 2] 286 | 287 | @property 288 | def y2(self): 289 | """bottom coordinate""" 290 | return self.element_coords[:, 3] 291 | 292 | @cached_property 293 | def areas(self) -> np.ndarray: 294 | """areas of each region; only compute it when it is needed""" 295 | return (self.x2 - self.x1) * (self.y2 - self.y1) 296 | 297 | 298 | class EmbeddedTextRegion(TextRegion): 299 | pass 300 | 301 | 302 | class ImageTextRegion(TextRegion): 303 | pass 304 | 305 | 306 | def region_bounding_boxes_are_almost_the_same( 307 | region1: Rectangle, 308 | region2: Rectangle, 309 | same_region_threshold: float = 0.75, 310 | ) -> bool: 311 | """Returns whether bounding boxes are almost the same. This is determined by checking if the 312 | intersection over union is above some threshold.""" 313 | return region1.intersection_over_union(region2) > same_region_threshold 314 | 315 | 316 | def grow_region_to_match_region(region_to_grow: Rectangle, region_to_match: Rectangle): 317 | """Grows a region to the minimum size necessary to contain both regions.""" 318 | (new_x1, new_y1), _, (new_x2, new_y2), _ = minimal_containing_region( 319 | region_to_grow, 320 | region_to_match, 321 | ).coordinates 322 | region_to_grow.x1, region_to_grow.y1, region_to_grow.x2, region_to_grow.y2 = ( 323 | new_x1, 324 | new_y1, 325 | new_x2, 326 | new_y2, 327 | ) 328 | -------------------------------------------------------------------------------- /unstructured_inference/inference/layout.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import tempfile 5 | from functools import cached_property 6 | from pathlib import PurePath 7 | from typing import Any, BinaryIO, Collection, List, Optional, Union, cast 8 | 9 | import numpy as np 10 | import pdf2image 11 | from PIL import Image, ImageSequence 12 | 13 | from unstructured_inference.inference.elements import ( 14 | TextRegion, 15 | ) 16 | from unstructured_inference.inference.layoutelement import LayoutElement, LayoutElements 17 | from unstructured_inference.logger import logger 18 | from unstructured_inference.models.base import get_model 19 | from unstructured_inference.models.unstructuredmodel import ( 20 | UnstructuredElementExtractionModel, 21 | UnstructuredObjectDetectionModel, 22 | ) 23 | from unstructured_inference.visualize import draw_bbox 24 | 25 | 26 | class DocumentLayout: 27 | """Class for handling documents that are saved as .pdf files. 
For .pdf files, a 28 | document image analysis (DIA) model detects the layout of the page prior to extracting 29 | element.""" 30 | 31 | def __init__(self, pages=None): 32 | self._pages = pages 33 | 34 | def __str__(self) -> str: 35 | return "\n\n".join([str(page) for page in self.pages]) 36 | 37 | @property 38 | def pages(self) -> List[PageLayout]: 39 | """Gets all elements from pages in sequential order.""" 40 | return self._pages 41 | 42 | @classmethod 43 | def from_pages(cls, pages: List[PageLayout]) -> DocumentLayout: 44 | """Generates a new instance of the class from a list of `PageLayouts`s""" 45 | doc_layout = cls() 46 | doc_layout._pages = pages 47 | return doc_layout 48 | 49 | @classmethod 50 | def from_file( 51 | cls, 52 | filename: str, 53 | fixed_layouts: Optional[List[Optional[List[TextRegion]]]] = None, 54 | pdf_image_dpi: int = 200, 55 | password: Optional[str] = None, 56 | **kwargs, 57 | ) -> DocumentLayout: 58 | """Creates a DocumentLayout from a pdf file.""" 59 | logger.info(f"Reading PDF for file: {filename} ...") 60 | 61 | with tempfile.TemporaryDirectory() as temp_dir: 62 | _image_paths = convert_pdf_to_image( 63 | filename, 64 | pdf_image_dpi, 65 | output_folder=temp_dir, 66 | path_only=True, 67 | password=password, 68 | ) 69 | image_paths = cast(List[str], _image_paths) 70 | number_of_pages = len(image_paths) 71 | pages: List[PageLayout] = [] 72 | if fixed_layouts is None: 73 | fixed_layouts = [None for _ in range(0, number_of_pages)] 74 | for i, (image_path, fixed_layout) in enumerate(zip(image_paths, fixed_layouts)): 75 | # NOTE(robinson) - In the future, maybe we detect the page number and default 76 | # to the index if it is not detected 77 | with Image.open(image_path) as image: 78 | page = PageLayout.from_image( 79 | image, 80 | number=i + 1, 81 | document_filename=filename, 82 | fixed_layout=fixed_layout, 83 | **kwargs, 84 | ) 85 | pages.append(page) 86 | return cls.from_pages(pages) 87 | 88 | @classmethod 89 | def from_image_file( 90 | cls, 91 | filename: str, 92 | detection_model: Optional[UnstructuredObjectDetectionModel] = None, 93 | element_extraction_model: Optional[UnstructuredElementExtractionModel] = None, 94 | fixed_layout: Optional[List[TextRegion]] = None, 95 | **kwargs, 96 | ) -> DocumentLayout: 97 | """Creates a DocumentLayout from an image file.""" 98 | logger.info(f"Reading image file: {filename} ...") 99 | try: 100 | image = Image.open(filename) 101 | format = image.format 102 | images: list[Image.Image] = [] 103 | for i, im in enumerate(ImageSequence.Iterator(image)): 104 | im = im.convert("RGB") 105 | im.format = format 106 | images.append(im) 107 | except Exception as e: 108 | if os.path.isdir(filename) or os.path.isfile(filename): 109 | raise e 110 | else: 111 | raise FileNotFoundError(f'File "{filename}" not found!') from e 112 | pages = [] 113 | for i, image in enumerate(images): # type: ignore 114 | page = PageLayout.from_image( 115 | image, 116 | image_path=filename, 117 | number=i, 118 | detection_model=detection_model, 119 | element_extraction_model=element_extraction_model, 120 | fixed_layout=fixed_layout, 121 | **kwargs, 122 | ) 123 | pages.append(page) 124 | return cls.from_pages(pages) 125 | 126 | 127 | class PageLayout: 128 | """Class for an individual PDF page.""" 129 | 130 | def __init__( 131 | self, 132 | number: int, 133 | image: Image.Image, 134 | image_metadata: Optional[dict] = None, 135 | image_path: Optional[Union[str, PurePath]] = None, # TODO: Deprecate 136 | document_filename: Optional[Union[str, PurePath]] = None, 
137 | detection_model: Optional[UnstructuredObjectDetectionModel] = None, 138 | element_extraction_model: Optional[UnstructuredElementExtractionModel] = None, 139 | password: Optional[str] = None, 140 | ): 141 | if detection_model is not None and element_extraction_model is not None: 142 | raise ValueError("Only one of detection_model and extraction_model should be passed.") 143 | self.image: Optional[Image.Image] = image 144 | if image_metadata is None: 145 | image_metadata = {} 146 | self.image_metadata = image_metadata 147 | self.image_path = image_path 148 | self.image_array: Union[np.ndarray[Any, Any], None] = None 149 | self.document_filename = document_filename 150 | self.number = number 151 | self.detection_model = detection_model 152 | self.element_extraction_model = element_extraction_model 153 | self.elements_array: LayoutElements | None = None 154 | self.password = password 155 | # NOTE(alan): Dropped LocationlessLayoutElement that was created for chipper - chipper has 156 | # locations now and if we need to support LayoutElements without bounding boxes we can make 157 | # the bbox property optional 158 | 159 | def __str__(self) -> str: 160 | return "\n\n".join([str(element) for element in self.elements]) 161 | 162 | @cached_property 163 | def elements(self) -> Collection[LayoutElement]: 164 | """return a list of layout elements from the array data structure; intended for backward 165 | compatibility""" 166 | if self.elements_array is None: 167 | return [] 168 | return self.elements_array.as_list() 169 | 170 | def get_elements_using_image_extraction( 171 | self, 172 | inplace=True, 173 | ) -> Optional[list[LayoutElement]]: 174 | """Uses end-to-end text element extraction model to extract the elements on the page.""" 175 | if self.element_extraction_model is None: 176 | raise ValueError( 177 | "Cannot get elements using image extraction, no image extraction model defined", 178 | ) 179 | assert self.image is not None 180 | elements = self.element_extraction_model(self.image) 181 | if inplace: 182 | self.elements = elements 183 | return None 184 | return elements 185 | 186 | def get_elements_with_detection_model( 187 | self, 188 | inplace: bool = True, 189 | ) -> Optional[LayoutElements]: 190 | """Uses specified model to detect the elements on the page.""" 191 | if self.detection_model is None: 192 | model = get_model() 193 | if isinstance(model, UnstructuredObjectDetectionModel): 194 | self.detection_model = model 195 | else: 196 | raise NotImplementedError("Default model should be a detection model") 197 | 198 | # NOTE(mrobinson) - We'll want make this model inference step some kind of 199 | # remote call in the future. 
200 | assert self.image is not None 201 | inferred_layout: LayoutElements = self.detection_model(self.image) 202 | inferred_layout = self.detection_model.deduplicate_detected_elements( 203 | inferred_layout, 204 | ) 205 | 206 | if inplace: 207 | self.elements_array = inferred_layout 208 | return None 209 | 210 | return inferred_layout 211 | 212 | def _get_image_array(self) -> Union[np.ndarray[Any, Any], None]: 213 | """Converts the raw image into a numpy array.""" 214 | if self.image_array is None: 215 | if self.image: 216 | self.image_array = np.array(self.image) 217 | else: 218 | image = Image.open(self.image_path) # type: ignore 219 | self.image_array = np.array(image) 220 | return self.image_array 221 | 222 | def annotate( 223 | self, 224 | colors: Optional[Union[List[str], str]] = None, 225 | image_dpi: int = 200, 226 | annotation_data: Optional[dict[str, dict]] = None, 227 | add_details: bool = False, 228 | sources: Optional[List[str]] = None, 229 | ) -> Image.Image: 230 | """Annotates the elements on the page image. 231 | if add_details is True, and the elements contain type and source attributes, then 232 | the type and source will be added to the image. 233 | sources is a list of sources to annotate. If sources is ["all"], then all sources will be 234 | annotated. Current sources allowed are "yolox","detectron2_onnx" and "detectron2_lp" """ 235 | if colors is None: 236 | colors = ["red" for _ in self.elements] 237 | if isinstance(colors, str): 238 | colors = [colors] 239 | # If there aren't enough colors, just cycle through the colors a few times 240 | if len(colors) < len(self.elements): 241 | n_copies = (len(self.elements) // len(colors)) + 1 242 | colors = colors * n_copies 243 | 244 | # Hotload image if it hasn't been loaded yet 245 | if self.image: 246 | img = self.image.copy() 247 | elif self.image_path: 248 | img = Image.open(self.image_path) 249 | else: 250 | img = self._get_image(self.document_filename, self.number, image_dpi) 251 | 252 | if annotation_data is None: 253 | for el, color in zip(self.elements, colors): 254 | if sources is None or el.source in sources: 255 | img = draw_bbox(img, el, color=color, details=add_details) 256 | else: 257 | for attribute, style in annotation_data.items(): 258 | if hasattr(self, attribute) and getattr(self, attribute): 259 | color = style["color"] 260 | width = style["width"] 261 | for region in getattr(self, attribute): 262 | required_source = getattr(region, "source", None) 263 | if (sources is None) or (required_source in sources): 264 | img = draw_bbox( 265 | img, 266 | region, 267 | color=color, 268 | width=width, 269 | details=add_details, 270 | ) 271 | 272 | return img 273 | 274 | def _get_image(self, filename, page_number, pdf_image_dpi: int = 200) -> Image.Image: 275 | """Hotloads a page image from a pdf file.""" 276 | 277 | with tempfile.TemporaryDirectory() as temp_dir: 278 | _image_paths = pdf2image.convert_from_path( 279 | filename, 280 | dpi=pdf_image_dpi, 281 | output_folder=temp_dir, 282 | paths_only=True, 283 | ) 284 | image_paths = cast(List[str], _image_paths) 285 | if page_number > len(image_paths): 286 | raise ValueError( 287 | f"Page number {page_number} is greater than the number of pages in the PDF.", 288 | ) 289 | 290 | with Image.open(image_paths[page_number - 1]) as image: 291 | return image.copy() 292 | 293 | @classmethod 294 | def from_image( 295 | cls, 296 | image: Image.Image, 297 | image_path: Optional[Union[str, PurePath]] = None, 298 | document_filename: Optional[Union[str, PurePath]] = None, 299 | 
number: int = 1, 300 | detection_model: Optional[UnstructuredObjectDetectionModel] = None, 301 | element_extraction_model: Optional[UnstructuredElementExtractionModel] = None, 302 | fixed_layout: Optional[List[TextRegion]] = None, 303 | ): 304 | """Creates a PageLayout from an already-loaded PIL Image.""" 305 | 306 | page = cls( 307 | number=number, 308 | image=image, 309 | detection_model=detection_model, 310 | element_extraction_model=element_extraction_model, 311 | ) 312 | # FIXME (yao): refactor the other methods so they all return elements like the third route 313 | if page.element_extraction_model is not None: 314 | page.get_elements_using_image_extraction() 315 | elif fixed_layout is None: 316 | page.get_elements_with_detection_model() 317 | else: 318 | page.elements = [] 319 | 320 | page.image_metadata = { 321 | "format": page.image.format if page.image else None, 322 | "width": page.image.width if page.image else None, 323 | "height": page.image.height if page.image else None, 324 | } 325 | page.image_path = os.path.abspath(image_path) if image_path else None 326 | page.document_filename = os.path.abspath(document_filename) if document_filename else None 327 | 328 | # Clear the image to save memory 329 | page.image = None 330 | 331 | return page 332 | 333 | 334 | def process_data_with_model( 335 | data: BinaryIO, 336 | model_name: Optional[str], 337 | password: Optional[str] = None, 338 | **kwargs: Any, 339 | ) -> DocumentLayout: 340 | """Process PDF as file-like object `data` into a `DocumentLayout`. 341 | 342 | Uses the model identified by `model_name`. 343 | """ 344 | with tempfile.TemporaryDirectory() as tmp_dir_path: 345 | file_path = os.path.join(tmp_dir_path, "document.pdf") 346 | with open(file_path, "wb") as f: 347 | f.write(data.read()) 348 | f.flush() 349 | layout = process_file_with_model( 350 | file_path, 351 | model_name, 352 | password=password, 353 | **kwargs, 354 | ) 355 | 356 | return layout 357 | 358 | 359 | def process_file_with_model( 360 | filename: str, 361 | model_name: Optional[str], 362 | is_image: bool = False, 363 | fixed_layouts: Optional[List[Optional[List[TextRegion]]]] = None, 364 | pdf_image_dpi: int = 200, 365 | password: Optional[str] = None, 366 | **kwargs: Any, 367 | ) -> DocumentLayout: 368 | """Processes pdf file with name filename into a DocumentLayout by using a model identified by 369 | model_name.""" 370 | 371 | model = get_model(model_name, **kwargs) 372 | if isinstance(model, UnstructuredObjectDetectionModel): 373 | detection_model = model 374 | element_extraction_model = None 375 | elif isinstance(model, UnstructuredElementExtractionModel): 376 | detection_model = None 377 | element_extraction_model = model 378 | else: 379 | raise ValueError(f"Unsupported model type: {type(model)}") 380 | layout = ( 381 | DocumentLayout.from_image_file( 382 | filename, 383 | detection_model=detection_model, 384 | element_extraction_model=element_extraction_model, 385 | **kwargs, 386 | ) 387 | if is_image 388 | else DocumentLayout.from_file( 389 | filename, 390 | detection_model=detection_model, 391 | element_extraction_model=element_extraction_model, 392 | fixed_layouts=fixed_layouts, 393 | pdf_image_dpi=pdf_image_dpi, 394 | password=password, 395 | **kwargs, 396 | ) 397 | ) 398 | return layout 399 | 400 | 401 | def convert_pdf_to_image( 402 | filename: str, 403 | dpi: int = 200, 404 | output_folder: Optional[Union[str, PurePath]] = None, 405 | path_only: bool = False, 406 | password: Optional[str] = None, 407 | ) -> Union[List[Image.Image], 
List[str]]: 408 | """Get the image renderings of the pdf pages using pdf2image""" 409 | 410 | if path_only and not output_folder: 411 | raise ValueError("output_folder must be specified if path_only is True") 412 | 413 | if output_folder is not None: 414 | images = pdf2image.convert_from_path( 415 | filename, 416 | dpi=dpi, 417 | output_folder=output_folder, 418 | paths_only=path_only, 419 | userpw=password or "", 420 | ) 421 | else: 422 | images = pdf2image.convert_from_path( 423 | filename, 424 | dpi=dpi, 425 | paths_only=path_only, 426 | userpw=password or "", 427 | ) 428 | 429 | return images 430 | -------------------------------------------------------------------------------- /unstructured_inference/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def translate_log_level(level: int) -> int: 5 | """Translate the Python logging level to the ONNX runtime severity level. 6 | Blank-page errors are reported at ONNX level 3, so that level is reserved for 7 | ERROR/CRITICAL, while level 4 is the normal behavior""" 8 | level_name = logging.getLevelName(level) 9 | onnx_level = 0 10 | if level_name in ["NOTSET", "DEBUG", "INFO", "WARNING"]: 11 | onnx_level = 4 12 | elif level_name in ["ERROR", "CRITICAL"]: 13 | onnx_level = 3 14 | 15 | return onnx_level 16 | 17 | 18 | logger = logging.getLogger("unstructured_inference") 19 | 20 | logger_onnx = logging.getLogger("unstructured_inference_onnxruntime") 21 | logger_onnx.setLevel(translate_log_level(logger.getEffectiveLevel())) 22 | -------------------------------------------------------------------------------- /unstructured_inference/math.py: -------------------------------------------------------------------------------- 1 | """a lightweight module that provides helpers to common math operations""" 2 | 3 | import numpy as np 4 | 5 | FLOAT_EPSILON = np.finfo(float).eps 6 | 7 | 8 | def safe_division(a, b) -> float: 9 | """a safer division that avoids division by zero when b == 0 10 | 11 | returns a/b, or a/FLOAT_EPSILON (FLOAT_EPSILON is about 2.2E-16) when b == 0 12 | 13 | Parameters: 14 | - a (int/float): a in a/b 15 | - b (int/float): b in a/b 16 | 17 | Returns: 18 | float: a/b, or a/FLOAT_EPSILON (FLOAT_EPSILON is about 2.2E-16) when b == 0 19 | """ 20 | return a / max(b, FLOAT_EPSILON) 21 | -------------------------------------------------------------------------------- /unstructured_inference/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Unstructured-IO/unstructured-inference/a0df407c3f7143fa66a9ad8bb40a6ee06907ce5a/unstructured_inference/models/__init__.py -------------------------------------------------------------------------------- /unstructured_inference/models/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | import os 5 | import threading 6 | from typing import Dict, Optional, Tuple, Type 7 | 8 | from unstructured_inference.models.detectron2onnx import ( 9 | MODEL_TYPES as DETECTRON2_ONNX_MODEL_TYPES, 10 | ) 11 | from unstructured_inference.models.detectron2onnx import UnstructuredDetectronONNXModel 12 | from unstructured_inference.models.unstructuredmodel import UnstructuredModel 13 | from unstructured_inference.models.yolox import MODEL_TYPES as YOLOX_MODEL_TYPES 14 | from unstructured_inference.models.yolox import UnstructuredYoloXModel 15 | from unstructured_inference.utils import LazyDict 16 | 17 | DEFAULT_MODEL = "yolox" 18 |
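# NOTE - a minimal usage sketch (illustrative, not part of the original module); get_model
# below initializes a model once and caches it in the Models singleton, after which the
# returned instance is callable on PIL images:
#   model = get_model("yolox")  # falls back to DEFAULT_MODEL when no name is given
#   layout = model(image)       # __call__ dispatches to model.predict(image)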
19 | 20 | class Models(object): 21 | _instance = None 22 | _lock = threading.Lock() 23 | 24 | def __new__(cls): 25 | """return the existing instance if one exists; otherwise create a new instance""" 26 | if cls._instance is None: 27 | with cls._lock: 28 | if cls._instance is None: 29 | cls._instance = super(Models, cls).__new__(cls) 30 | cls.models: Dict[str, UnstructuredModel] = {} 31 | return cls._instance 32 | 33 | def __contains__(self, key): 34 | return key in self.models 35 | 36 | def __getitem__(self, key: str): 37 | return self.models.__getitem__(key) 38 | 39 | def __setitem__(self, key: str, value: UnstructuredModel): 40 | self.models[key] = value 41 | 42 | 43 | models: Models = Models() 44 | 45 | 46 | def get_default_model_mappings() -> Tuple[ 47 | Dict[str, Type[UnstructuredModel]], 48 | Dict[str, dict | LazyDict], 49 | ]: 50 | """default model mappings for models that are in the `unstructured_inference` repo""" 51 | return { 52 | **dict.fromkeys(DETECTRON2_ONNX_MODEL_TYPES, UnstructuredDetectronONNXModel), 53 | **dict.fromkeys(YOLOX_MODEL_TYPES, UnstructuredYoloXModel), 54 | }, {**DETECTRON2_ONNX_MODEL_TYPES, **YOLOX_MODEL_TYPES} 55 | 56 | 57 | model_class_map, model_config_map = get_default_model_mappings() 58 | 59 | 60 | def register_new_model(model_config: dict, model_class: Type[UnstructuredModel]): 61 | """Register this model in model_config_map and model_class_map. 62 | 63 | Those maps are updated with the new model class information. 64 | """ 65 | model_config_map.update(model_config) 66 | model_class_map.update(dict.fromkeys(model_config, model_class)) 67 | 68 | 69 | def get_model(model_name: Optional[str] = None) -> UnstructuredModel: 70 | """Gets the model object by model name.""" 71 | # TODO(alan): These cases are similar enough that we can probably do them all together with 72 | # importlib 73 | 74 | if model_name is None: 75 | default_name_from_env = os.environ.get("UNSTRUCTURED_DEFAULT_MODEL_NAME") 76 | model_name = default_name_from_env if default_name_from_env is not None else DEFAULT_MODEL 77 | 78 | if model_name in models: 79 | return models[model_name] 80 | 81 | initialize_param_json = os.environ.get("UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH") 82 | if initialize_param_json is not None: 83 | with open(initialize_param_json) as fp: 84 | initialize_params = json.load(fp) 85 | label_map_int_keys = { 86 | int(key): value for key, value in initialize_params["label_map"].items() 87 | } 88 | initialize_params["label_map"] = label_map_int_keys 89 | else: 90 | if model_name in model_config_map: 91 | initialize_params = model_config_map[model_name] 92 | else: 93 | raise UnknownModelException(f"Unknown model type: {model_name}") 94 | 95 | model: UnstructuredModel = model_class_map[model_name]() 96 | 97 | model.initialize(**initialize_params) 98 | models[model_name] = model 99 | return model 100 | 101 | 102 | class UnknownModelException(Exception): 103 | """A model was requested with an unrecognized identifier.""" 104 | 105 | pass 106 | -------------------------------------------------------------------------------- /unstructured_inference/models/detectron2onnx.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, Final, List, Optional, Union, cast 3 | 4 | import cv2 5 | import numpy as np 6 | import onnxruntime 7 | from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE 8 | from onnxruntime.capi import _pybind_state as C 9 | from onnxruntime.quantization import QuantType,
quantize_dynamic 10 | from PIL import Image 11 | 12 | from unstructured_inference.constants import Source 13 | from unstructured_inference.inference.layoutelement import LayoutElement 14 | from unstructured_inference.logger import logger, logger_onnx 15 | from unstructured_inference.models.unstructuredmodel import ( 16 | UnstructuredObjectDetectionModel, 17 | ) 18 | from unstructured_inference.utils import ( 19 | LazyDict, 20 | LazyEvaluateInfo, 21 | download_if_needed_and_get_local_path, 22 | ) 23 | 24 | onnxruntime.set_default_logger_severity(logger_onnx.getEffectiveLevel()) 25 | 26 | DEFAULT_LABEL_MAP: Final[Dict[int, str]] = { 27 | 0: "Text", 28 | 1: "Title", 29 | 2: "List", 30 | 3: "Table", 31 | 4: "Figure", 32 | } 33 | 34 | 35 | # NOTE(alan): Entries are implemented as LazyDicts so that models aren't downloaded until they are 36 | # needed. 37 | MODEL_TYPES: Dict[str, Union[LazyDict, dict]] = { 38 | "detectron2_onnx": LazyDict( 39 | model_path=LazyEvaluateInfo( 40 | download_if_needed_and_get_local_path, 41 | "unstructuredio/detectron2_faster_rcnn_R_50_FPN_3x", 42 | "model.onnx", 43 | ), 44 | label_map=DEFAULT_LABEL_MAP, 45 | confidence_threshold=0.8, 46 | ), 47 | "detectron2_quantized": { 48 | "model_path": os.path.join( 49 | HUGGINGFACE_HUB_CACHE, 50 | "detectron2_quantized", 51 | "detectrin2_quantized.onnx", 52 | ), 53 | "label_map": DEFAULT_LABEL_MAP, 54 | "confidence_threshold": 0.8, 55 | }, 56 | "detectron2_mask_rcnn": LazyDict( 57 | model_path=LazyEvaluateInfo( 58 | download_if_needed_and_get_local_path, 59 | "unstructuredio/detectron2_mask_rcnn_X_101_32x8d_FPN_3x", 60 | "model.onnx", 61 | ), 62 | label_map=DEFAULT_LABEL_MAP, 63 | confidence_threshold=0.8, 64 | ), 65 | } 66 | 67 | 68 | class UnstructuredDetectronONNXModel(UnstructuredObjectDetectionModel): 69 | """Unstructured model wrapper for detectron2 ONNX model.""" 70 | 71 | # The model was trained and exported with this shape 72 | required_w = 800 73 | required_h = 1035 74 | 75 | def predict(self, image: Image.Image) -> List[LayoutElement]: 76 | """Makes a prediction using detectron2 model.""" 77 | super().predict(image) 78 | 79 | prepared_input = self.preprocess(image) 80 | try: 81 | result = self.model.run(None, prepared_input) 82 | bboxes = result[0] 83 | labels = result[1] 84 | # Previous model detectron2_onnx stored confidence scores at index 2, 85 | # bigger model stores it at index 3 86 | confidence_scores = result[2] if "R_50" in self.model_path else result[3] 87 | except onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: 88 | logger_onnx.debug( 89 | "Ignoring runtime error from onnx (likely due to encountering blank page).", 90 | ) 91 | return [] 92 | input_w, input_h = image.size 93 | regions = self.postprocess(bboxes, labels, confidence_scores, input_w, input_h) 94 | 95 | return regions 96 | 97 | def initialize( 98 | self, 99 | model_path: str, 100 | label_map: Dict[int, str], 101 | confidence_threshold: Optional[float] = None, 102 | ): 103 | """Loads the detectron2 model using the specified parameters""" 104 | if not os.path.exists(model_path) and "detectron2_quantized" in model_path: 105 | logger.info("Quantized model don't currently exists, quantizing now...") 106 | os.mkdir("".join(os.path.split(model_path)[:-1])) 107 | source_path = MODEL_TYPES["detectron2_onnx"]["model_path"] 108 | quantize_dynamic(source_path, model_path, weight_type=QuantType.QUInt8) 109 | 110 | available_providers = C.get_available_providers() 111 | ordered_providers = [ 112 | "TensorrtExecutionProvider", 113 | 
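# NOTE - this list is in priority order; the comprehension below keeps only the providers available in the installed onnxruntime build, and InferenceSession tries them in the order given (TensorRT, then CUDA, then CPU)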
"CUDAExecutionProvider", 114 | "CPUExecutionProvider", 115 | ] 116 | providers = [provider for provider in ordered_providers if provider in available_providers] 117 | 118 | self.model = onnxruntime.InferenceSession( 119 | model_path, 120 | providers=providers, 121 | ) 122 | self.model_path = model_path 123 | self.label_map = label_map 124 | if confidence_threshold is None: 125 | confidence_threshold = 0.5 126 | self.confidence_threshold = confidence_threshold 127 | 128 | def preprocess(self, image: Image.Image) -> Dict[str, np.ndarray]: 129 | """Process input image into required format for ingestion into the Detectron2 ONNX binary. 130 | This involves resizing to a fixed shape and converting to a specific numpy format. 131 | """ 132 | # TODO (benjamin): check other shapes for inference 133 | img = np.array(image) 134 | # TODO (benjamin): We should use models.get_model() but currenly returns Detectron model 135 | session = self.model 136 | # onnx input expected 137 | # [3,1035,800] 138 | img = cv2.resize( 139 | img, 140 | (self.required_w, self.required_h), 141 | interpolation=cv2.INTER_LINEAR, 142 | ).astype(np.float32) 143 | img = img.transpose(2, 0, 1) 144 | ort_inputs = {session.get_inputs()[0].name: img} 145 | return ort_inputs 146 | 147 | def postprocess( 148 | self, 149 | bboxes: np.ndarray, 150 | labels: np.ndarray, 151 | confidence_scores: np.ndarray, 152 | input_w: float, 153 | input_h: float, 154 | ) -> List[LayoutElement]: 155 | """Process output into Unstructured class. Bounding box coordinates are converted to 156 | original image resolution.""" 157 | regions = [] 158 | width_conversion = input_w / self.required_w 159 | height_conversion = input_h / self.required_h 160 | for (x1, y1, x2, y2), label, conf in zip(bboxes, labels, confidence_scores): 161 | detected_class = self.label_map[int(label)] 162 | if conf >= self.confidence_threshold: 163 | region = LayoutElement.from_coords( 164 | x1 * width_conversion, 165 | y1 * height_conversion, 166 | x2 * width_conversion, 167 | y2 * height_conversion, 168 | text=None, 169 | type=detected_class, 170 | prob=conf, 171 | source=Source.DETECTRON2_ONNX, 172 | ) 173 | 174 | regions.append(region) 175 | 176 | regions.sort(key=lambda element: element.bbox.y1) 177 | return cast(List[LayoutElement], regions) 178 | -------------------------------------------------------------------------------- /unstructured_inference/models/eval.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Callable, Dict, List, Optional 3 | 4 | import pandas as pd 5 | from rapidfuzz import fuzz 6 | 7 | EVAL_FUNCTIONS = { 8 | "token_ratio": fuzz.token_ratio, 9 | "ratio": fuzz.ratio, 10 | "partial_token_ratio": fuzz.partial_token_ratio, 11 | "partial_ratio": fuzz.partial_ratio, 12 | } 13 | 14 | 15 | def _join_df_content(df, tab_token="\t", row_break_token="\n") -> str: 16 | """joining dataframe's table content as one long string""" 17 | return row_break_token.join([tab_token.join(row) for row in df.values]) 18 | 19 | 20 | def default_tokenizer(text: str) -> List[str]: 21 | """a simple tokenizer that splits text by white space""" 22 | return text.split() 23 | 24 | 25 | def compare_contents_as_df( 26 | actual_df: pd.DataFrame, 27 | pred_df: pd.DataFrame, 28 | eval_func: str = "token_ratio", 29 | processor: Optional[Callable] = None, 30 | tab_token: str = "\t", 31 | row_break_token: str = "\n", 32 | ) -> Dict[str, float]: 33 | r"""ravel the table as string then use text distance to 
--------------------------------------------------------------------------------
/unstructured_inference/models/unstructuredmodel.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | from abc import ABC, abstractmethod
4 | from typing import Any, List, cast
5 | 
6 | import numpy as np
7 | from PIL.Image import Image
8 | 
9 | from unstructured_inference.constants import ElementType
10 | from unstructured_inference.inference.elements import (
11 |     grow_region_to_match_region,
12 |     intersections,
13 | )
14 | from unstructured_inference.inference.layoutelement import (
15 |     LayoutElement,
16 |     LayoutElements,
17 |     clean_layoutelements,
18 |     partition_groups_from_regions,
19 |     separate,
20 | )
21 | 
22 | 
23 | class UnstructuredModel(ABC):
24 |     """Wrapper class for the various models used by unstructured."""
25 | 
26 |     def __init__(self):
27 |         """The wrapped model should support inference of some sort, either by being callable
28 |         or via some method. UnstructuredModel doesn't provide any training interface; it's
29 |         assumed the model is already trained.
30 |         """
31 |         self.model = None
32 | 
33 |     @abstractmethod
34 |     def predict(self, x: Any) -> Any:
35 |         """Do inference using the wrapped model."""
36 |         if self.model is None:
37 |             raise ModelNotInitializedError(
38 |                 "Model has not been initialized. Please call the initialize method with the "
39 |                 "appropriate arguments for loading the model.",
40 |             )
41 |         pass  # pragma: no cover
42 | 
43 |     def __call__(self, x: Any) -> Any:
44 |         """Inference using function call interface."""
45 |         return self.predict(x)
46 | 
47 |     @abstractmethod
48 |     def initialize(self, *args, **kwargs):
49 |         """Load the model for inference."""
50 |         pass  # pragma: no cover
51 | 
52 | 
53 | class UnstructuredObjectDetectionModel(UnstructuredModel):
54 |     """Wrapper class for object detection models used by unstructured."""
55 | 
56 |     @abstractmethod
57 |     def predict(self, x: Image) -> LayoutElements | list[LayoutElement]:
58 |         """Do inference using the wrapped model."""
59 |         super().predict(x)
60 |         return []
61 | 
62 |     def __call__(self, x: Image) -> LayoutElements:
63 |         """Inference using function call interface."""
64 |         return super().__call__(x)
65 | 
66 |     @staticmethod
67 |     def enhance_regions(
68 |         elements: List[LayoutElement],
69 |         iom_to_merge: float = 0.3,
70 |     ) -> List[LayoutElement]:
71 |         """Traverses all the elements and either deletes nested elements, or merges or
72 |         splits them depending on the IOM (intersection over minimum) score of both regions"""
73 |         rects = [el.bbox for el in elements]
74 |         intersections_mtx = intersections(*rects)
75 | 
76 |         for i, row in enumerate(intersections_mtx):
77 |             first = elements[i]
78 |             if first:
79 |                 # We get only the elements which intersected
80 |                 indices_to_check = np.where(row)[0]
81 |                 # Delete the first element, since it will always intersect with itself
82 |                 indices_to_check = indices_to_check[indices_to_check != i]
83 |                 if len(indices_to_check) == 0:
84 |                     continue
85 |                 if len(indices_to_check) > 1:  # sort by iom
86 |                     iom_to_check = [
87 |                         (j, first.bbox.intersection_over_minimum(elements[j].bbox))
88 |                         for j in indices_to_check
89 |                         if elements[j] is not None
90 |                     ]
91 |                     iom_to_check.sort(
92 |                         key=lambda x: x[1],
93 |                         reverse=True,
94 |                     )  # sort elements by iom, so we first check the greatest
95 |                     indices_to_check = [x[0] for x in iom_to_check if x[0] != i]  # type:ignore
96 |                 for j in indices_to_check:
97 |                     if elements[j] is None or elements[i] is None:
98 |                         continue
99 |                     second = elements[j]
100 |                     intersection = first.bbox.intersection(
101 |                         second.bbox,
102 |                     )  # we know it does, but need the region
103 |                     first_inside_second = first.bbox.is_in(second.bbox)
104 |                     second_inside_first = second.bbox.is_in(first.bbox)
105 | 
106 |                     if first_inside_second and not second_inside_first:
107 |                         elements[i] = None  # type:ignore
108 |                     elif second_inside_first and not first_inside_second:
109 |                         # delete second element
110 |                         elements[j] = None  # type:ignore
111 |                     elif intersection:
112 |                         iom = first.bbox.intersection_over_minimum(second.bbox)
113 |                         if iom < iom_to_merge:  # small
114 |                             separate(first.bbox, second.bbox)
115 |                             # The rectangle could become too small after separation;
116 |                             # should it be deleted in that case?
117 |                         else:  # big
118 |                             # merge
119 |                             if first.bbox.area > second.bbox.area:
120 |                                 grow_region_to_match_region(first.bbox, second.bbox)
121 |                                 elements[j] = None  # type:ignore
122 |                             else:
123 |                                 grow_region_to_match_region(second.bbox, first.bbox)
124 |                                 elements[i] = None  # type:ignore
125 | 
126 |         elements = [e for e in elements if e is not None]
127 |         return elements
128 | 
129 |     @staticmethod
130 |     def clean_type(
131 |         elements: list[LayoutElement],
132 |         type_to_clean=ElementType.TABLE,
133 |     ) -> List[LayoutElement]:
134 |         """After this function, the list of elements will not contain any element nested
135 |         inside an element of the specified type"""
136 |         target_elements = [e for e in elements if e.type == type_to_clean]
137 |         other_elements = [e for e in elements if e.type != type_to_clean]
138 |         if len(target_elements) == 0 or len(other_elements) == 0:
139 |             return elements
140 | 
141 |         # Sort elements from biggest to smallest
142 |         target_elements.sort(key=lambda e: e.bbox.area, reverse=True)
143 |         other_elements.sort(key=lambda e: e.bbox.area, reverse=True)
144 | 
145 |         # First check if targets contain each other
146 |         for element in target_elements:  # Just handles containment or little overlap
147 |             contains = [
148 |                 e
149 |                 for e in target_elements
150 |                 if e.bbox.is_almost_subregion_of(element.bbox) and e != element
151 |             ]
152 |             for contained in contains:
153 |                 target_elements.remove(contained)
154 |         # Then check if remaining elements intersect with targets
155 |         other_elements = filter(
156 |             lambda e: not any(
157 |                 e.bbox.is_almost_subregion_of(target.bbox) for target in target_elements
158 |             ),
159 |             other_elements,
160 |         )  # type:ignore
161 | 
162 |         final_elements = list(other_elements)
163 |         final_elements.extend(target_elements)
164 |         # Note(benjamin): could use bisect.insort,
165 |         # but need to add < operator for
166 |         # LayoutElement in python <3.10
167 |         final_elements.sort(key=lambda e: e.bbox.y1)
168 |         return final_elements
169 | 
170 |     def deduplicate_detected_elements(
171 |         self,
172 |         elements: LayoutElements,
173 |         min_text_size: int = 15,
174 |     ) -> LayoutElements:
175 |         """Deletes overlapping elements in a list of elements."""
176 | 
177 |         if len(elements) <= 1:
178 |             return elements
179 | 
180 |         cleaned_elements = []
181 |         # TODO: Delete nested elements with low or None probability
182 |         # TODO: Keep most confident
183 |         # TODO: Better to grow horizontally than vertically?
184 |         groups = cast(list[LayoutElements], partition_groups_from_regions(elements))
185 |         for group in groups:
186 |             cleaned_elements.append(clean_layoutelements(group))
187 |         return LayoutElements.concatenate(cleaned_elements)
188 | 
189 | 
190 | class UnstructuredElementExtractionModel(UnstructuredModel):
191 |     """Wrapper class for object extraction models used by unstructured."""
192 | 
193 |     @abstractmethod
194 |     def predict(self, x: Image) -> List[LayoutElement]:
195 |         """Do inference using the wrapped model."""
196 |         super().predict(x)
197 |         return []  # pragma: no cover
198 | 
199 |     def __call__(self, x: Image) -> List[LayoutElement]:
200 |         """Inference using function call interface."""
201 |         return super().__call__(x)
202 | 
203 | 
204 | class ModelNotInitializedError(Exception):
205 |     pass
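206 | 
207 | 
208 | # NOTE: a small sketch of the subclassing contract; EchoModel is a toy stand-in,
209 | # not a real detector. The super().predict(x) hook raises
210 | # ModelNotInitializedError whenever initialize() has not been called yet.
211 | if __name__ == "__main__":
212 |     class EchoModel(UnstructuredModel):
213 |         def initialize(self, model):
214 |             self.model = model
215 | 
216 |         def predict(self, x):
217 |             super().predict(x)  # raises if initialize was never called
218 |             return self.model(x)
219 | 
220 |     echo = EchoModel()
221 |     echo.initialize(lambda x: [x])
222 |     print(echo.predict("page"))  # -> ['page']
223 | 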
--------------------------------------------------------------------------------
/unstructured_inference/models/yolox.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Megvii, Inc. and its affiliates.
2 | # Unstructured modified the original source code found at:
3 | # https://github.com/Megvii-BaseDetection/YOLOX/blob/237e943ac64aa32eb32f875faa93ebb18512d41d/yolox/data/data_augment.py
4 | # https://github.com/Megvii-BaseDetection/YOLOX/blob/ac379df3c97d1835ebd319afad0c031c36d03f36/yolox/utils/demo_utils.py
5 | 
6 | import cv2
7 | import numpy as np
8 | import onnxruntime
9 | from onnxruntime.capi import _pybind_state as C
10 | from PIL import Image as PILImage
11 | 
12 | from unstructured_inference.constants import ElementType, Source
13 | from unstructured_inference.inference.layoutelement import LayoutElements
14 | from unstructured_inference.models.unstructuredmodel import (
15 |     UnstructuredObjectDetectionModel,
16 | )
17 | from unstructured_inference.utils import (
18 |     LazyDict,
19 |     LazyEvaluateInfo,
20 |     download_if_needed_and_get_local_path,
21 | )
22 | 
23 | YOLOX_LABEL_MAP = {
24 |     0: ElementType.CAPTION,
25 |     1: ElementType.FOOTNOTE,
26 |     2: ElementType.FORMULA,
27 |     3: ElementType.LIST_ITEM,
28 |     4: ElementType.PAGE_FOOTER,
29 |     5: ElementType.PAGE_HEADER,
30 |     6: ElementType.PICTURE,
31 |     7: ElementType.SECTION_HEADER,
32 |     8: ElementType.TABLE,
33 |     9: ElementType.TEXT,
34 |     10: ElementType.TITLE,
35 | }
36 | 
37 | MODEL_TYPES = {
38 |     "yolox": LazyDict(
39 |         model_path=LazyEvaluateInfo(
40 |             download_if_needed_and_get_local_path,
41 |             "unstructuredio/yolo_x_layout",
42 |             "yolox_l0.05.onnx",
43 |         ),
44 |         label_map=YOLOX_LABEL_MAP,
45 |     ),
46 |     "yolox_tiny": LazyDict(
47 |         model_path=LazyEvaluateInfo(
48 |             download_if_needed_and_get_local_path,
49 |             "unstructuredio/yolo_x_layout",
50 |             "yolox_tiny.onnx",
51 |         ),
52 |         label_map=YOLOX_LABEL_MAP,
53 |     ),
54 |     "yolox_quantized": LazyDict(
55 |         model_path=LazyEvaluateInfo(
56 |             download_if_needed_and_get_local_path,
57 |             "unstructuredio/yolo_x_layout",
58 |             "yolox_l0.05_quantized.onnx",
59 |         ),
60 |         label_map=YOLOX_LABEL_MAP,
61 |     ),
62 | }
63 | 
64 | 
65 | class UnstructuredYoloXModel(UnstructuredObjectDetectionModel):
66 |     def predict(self, x: PILImage.Image):
67 |         """Predict using YoloX model."""
68 |         super().predict(x)
69 |         return self.image_processing(x)
70 | 
71 |     def initialize(self, model_path: str, label_map: dict):
72 |         """Start inference session for YoloX model."""
73 |         self.model_path = model_path
74 | 
75 |         available_providers = C.get_available_providers()
76 |         ordered_providers = [
77 |             "TensorrtExecutionProvider",
78 |             "CUDAExecutionProvider",
79 |             "CPUExecutionProvider",
80 |         ]
81 |         providers = [provider for provider in ordered_providers if provider in available_providers]
82 | 
83 |         self.model = onnxruntime.InferenceSession(
84 |             model_path,
85 |             providers=providers,
86 |         )
87 | 
88 |         self.layout_classes = label_map
89 | 
90 |     def image_processing(
91 |         self,
92 |         image: PILImage.Image,
93 |     ) -> LayoutElements:
94 |         """Run YoloX layout detection on an image.
95 | 
96 |         Parameters
97 |         ----------
98 |         image
99 |             Image object to process with the YoloX model
100 | 
101 |         Returns
102 |         -------
103 |         LayoutElements
104 |             The detected layout regions, with coordinates, probabilities, and class labels
105 |         """
106 |         # The model was trained and exported with this shape
107 |         # TODO (benjamin): check other shapes for inference
108 |         input_shape = (1024, 768)
109 |         origin_img = np.array(image)
110 |         img, ratio = preprocess(origin_img, input_shape)
111 |         session = self.model
112 | 
113 |         ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
114 |         output = session.run(None, ort_inputs)
115 |         # TODO(benjamin): check for p6
116 |         predictions = demo_postprocess(output[0], input_shape, p6=False)[0]
117 | 
118 |         boxes = predictions[:, :4]
119 |         scores = predictions[:, 4:5] * predictions[:, 5:]
120 | 
121 |         boxes_xyxy = np.ones_like(boxes)
122 |         boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.0
123 |         boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.0
124 |         boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.0
125 |         boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.0
126 |         boxes_xyxy /= ratio
127 | 
128 |         # Note (Benjamin): The quantized and original models require distinct
129 |         # threshold levels
130 |         if "quantized" in self.model_path:
131 |             dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.0, score_thr=0.07)
132 |         else:
133 |             dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.1, score_thr=0.25)
134 | 
135 |         order = np.argsort(dets[:, 1])
136 |         sorted_dets = dets[order]
137 | 
138 |         return LayoutElements(
139 |             element_coords=sorted_dets[:, :4].astype(float),
140 |             element_probs=sorted_dets[:, 4].astype(float),
141 |             element_class_ids=sorted_dets[:, 5].astype(int),
142 |             element_class_id_map=self.layout_classes,
143 |             sources=np.array([Source.YOLOX] * sorted_dets.shape[0]),
144 |         )
145 | 
146 | 
147 | # Note: preprocess function was named preproc on original source
148 | 
149 | 
150 | def preprocess(img, input_size, swap=(2, 0, 1)):
151 |     """Preprocess image data before YoloX inference."""
152 |     if len(img.shape) == 3:
153 |         padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
154 |     else:
155 |         padded_img = np.ones(input_size, dtype=np.uint8) * 114
156 | 
157 |     r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
158 |     resized_img = cv2.resize(
159 |         img,
160 |         (int(img.shape[1] * r), int(img.shape[0] * r)),
161 |         interpolation=cv2.INTER_LINEAR,
162 |     ).astype(np.uint8)
163 |     padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
164 | 
165 |     padded_img = padded_img.transpose(swap)
166 |     padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
167 |     return padded_img, r
168 | 
169 | 
170 | def demo_postprocess(outputs, img_size, p6=False):
171 |     """Postprocessing for YoloX model."""
172 |     grids = []
173 |     expanded_strides = []
174 | 
175 |     strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
176 | 
177 |     hsizes = [img_size[0] // stride for stride in strides]
178 |     wsizes = [img_size[1] // stride for stride in strides]
179 | 
180 |     for hsize, wsize, stride in zip(hsizes, wsizes, strides):
181 |         xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
182 |         grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
183 |         grids.append(grid)
184 |         shape = grid.shape[:2]
185 |         expanded_strides.append(np.full((*shape, 1), stride))
186 | 
187 |     grids = np.concatenate(grids, 1)
188 |     expanded_strides = np.concatenate(expanded_strides, 1)
189 |     outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
190 |     outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
191 | 
192 |     return outputs
193 | 
194 | 
195 | def multiclass_nms(boxes, scores, nms_thr, score_thr, class_agnostic=True):
196 |     """Multiclass NMS implemented in Numpy"""
197 |     # TODO(benjamin): check for non-class agnostic
198 |     # if class_agnostic:
199 |     nms_method = multiclass_nms_class_agnostic
200 |     # else:
201 |     #     nms_method = multiclass_nms_class_aware
202 |     return nms_method(boxes, scores, nms_thr, score_thr)
203 | 
204 | 
205 | def multiclass_nms_class_agnostic(boxes, scores, nms_thr, score_thr):
206 |     """Multiclass NMS implemented in Numpy. Class-agnostic version."""
207 |     cls_inds = scores.argmax(1)
208 |     cls_scores = scores[np.arange(len(cls_inds)), cls_inds]
209 | 
210 |     valid_score_mask = cls_scores > score_thr
211 |     valid_scores = cls_scores[valid_score_mask]
212 |     valid_boxes = boxes[valid_score_mask]
213 |     valid_cls_inds = cls_inds[valid_score_mask]
214 |     keep = nms(valid_boxes, valid_scores, nms_thr)
215 |     dets = np.concatenate(
216 |         [valid_boxes[keep], valid_scores[keep, None], valid_cls_inds[keep, None]],
217 |         1,
218 |     )
219 |     return dets
220 | 
221 | 
222 | def nms(boxes, scores, nms_thr):
223 |     """Single class NMS implemented in Numpy."""
224 |     x1 = boxes[:, 0]
225 |     y1 = boxes[:, 1]
226 |     x2 = boxes[:, 2]
227 |     y2 = boxes[:, 3]
228 | 
229 |     areas = (x2 - x1 + 1) * (y2 - y1 + 1)
230 |     order = scores.argsort()[::-1]
231 | 
232 |     keep = []
233 |     while order.size > 0:
234 |         i = order[0]
235 |         keep.append(i)
236 |         xx1 = np.maximum(x1[i], x1[order[1:]])
237 |         yy1 = np.maximum(y1[i], y1[order[1:]])
238 |         xx2 = np.minimum(x2[i], x2[order[1:]])
239 |         yy2 = np.minimum(y2[i], y2[order[1:]])
240 | 
241 |         w = np.maximum(0.0, xx2 - xx1 + 1)
242 |         h = np.maximum(0.0, yy2 - yy1 + 1)
243 |         inter = w * h
244 |         ovr = inter / (areas[i] + areas[order[1:]] - inter)
245 | 
246 |         inds = np.where(ovr <= nms_thr)[0]
247 |         order = order[inds + 1]
248 | 
249 |     return keep
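250 | 
251 | 
252 | # NOTE: a toy sketch of the NMS helpers with hand-made boxes: the two heavily
253 | # overlapping boxes collapse into the higher-scoring one and the distant box
254 | # survives, giving rows of [x1, y1, x2, y2, score, class_id].
255 | if __name__ == "__main__":
256 |     boxes = np.array(
257 |         [[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 10.0, 10.0], [50.0, 50.0, 60.0, 60.0]],
258 |     )
259 |     scores = np.array([[0.9, 0.1], [0.8, 0.1], [0.2, 0.7]])
260 |     print(multiclass_nms(boxes, scores, nms_thr=0.5, score_thr=0.3))
261 | 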
--------------------------------------------------------------------------------
/unstructured_inference/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections.abc import Mapping
3 | from html.parser import HTMLParser
4 | from io import StringIO
5 | from typing import Any, Callable, Hashable, Iterable, Iterator, Union
6 | 
7 | from huggingface_hub import hf_hub_download
8 | from PIL import Image
9 | 
10 | from unstructured_inference.inference.layoutelement import LayoutElement
11 | 
12 | 
13 | class LazyEvaluateInfo:
14 |     """Class that stores the information needed to lazily evaluate a function with given arguments.
15 |     The object stores the information needed for evaluation as a function and its arguments.
16 |     """
17 | 
18 |     def __init__(self, evaluate: Callable, *args, **kwargs):
19 |         self.evaluate = evaluate
20 |         self.info = (args, kwargs)
21 | 
22 | 
23 | class LazyDict(Mapping):
24 |     """Class that wraps a dict and only evaluates keys of the dict when the key is accessed. Keys
25 |     that should be evaluated lazily should use LazyEvaluateInfo objects as values. By default when
26 |     a value is computed from a LazyEvaluateInfo object, it is converted to the raw value in the
27 |     internal dict, so subsequent accessing of the key will produce the same value. Set cache=False
28 |     to avoid storing the raw value.
29 |     """
30 | 
31 |     def __init__(self, *args, cache=True, **kwargs):
32 |         self.cache = cache
33 |         self._raw_dict = dict(*args, **kwargs)
34 | 
35 |     def __getitem__(self, key: Hashable) -> Union[LazyEvaluateInfo, Any]:
36 |         value = self._raw_dict.__getitem__(key)
37 |         if isinstance(value, LazyEvaluateInfo):
38 |             evaluate = value.evaluate
39 |             args, kwargs = value.info
40 |             value = evaluate(*args, **kwargs)
41 |             if self.cache:
42 |                 self._raw_dict[key] = value
43 |         return value
44 | 
45 |     def __iter__(self) -> Iterator:
46 |         return iter(self._raw_dict)
47 | 
48 |     def __len__(self) -> int:
49 |         return len(self._raw_dict)
50 | 
51 | 
52 | def tag(elements: Iterable[LayoutElement]):
53 |     """Assign a numeric id to the elements in the list.
54 |     Useful for debugging."""
55 |     colors = ["red", "blue", "green", "magenta", "brown"]
56 |     for i, e in enumerate(elements):
57 |         e.text = f"-{i}-:{e.text}"
58 |         # currently not a property
59 |         e.id = i  # type:ignore
60 |         e.color = colors[i % len(colors)]  # type:ignore
61 | 
62 | 
63 | def pad_image_with_background_color(
64 |     image: Image.Image,
65 |     pad: int = 10,
66 |     background_color: str = "white",
67 | ) -> Image.Image:
68 |     """Pads an input image with the given background color by `pad` pixels on all 4 sides.
69 | 
70 |     The original image is kept intact and a new image is returned with padding added.
71 |     """
72 |     width, height = image.size
73 |     if pad < 0:
74 |         raise ValueError(
75 |             "Cannot pad an image with negative space! Please use a positive value for `pad`.",
76 |         )
77 |     new = Image.new(image.mode, (width + pad * 2, height + pad * 2), background_color)
78 |     new.paste(image, (pad, pad))
79 |     return new
80 | 
81 | 
82 | class MLStripper(HTMLParser):
83 |     """Simple markup language stripper that strips tags from a string."""
84 | 
85 |     def __init__(self):
86 |         super().__init__()
87 |         self.reset()
88 |         self.strict = True
89 |         self.convert_charrefs = True
90 |         self.text = StringIO()
91 | 
92 |     def handle_data(self, d):
93 |         """Collect the text content from the parsed input."""
94 |         self.text.write(d)
95 | 
96 |     def get_data(self):
97 |         """Return the stripped text collected from the parsed input."""
98 |         return self.text.getvalue()
99 | 
100 | 
101 | def strip_tags(html: str) -> str:
102 |     """Strips HTML tags from the input string and returns the text without tags."""
103 |     s = MLStripper()
104 |     s.feed(html)
105 |     return s.get_data()
106 | 
107 | 
108 | def download_if_needed_and_get_local_path(path_or_repo: str, filename: str, **kwargs) -> str:
109 |     """Returns path to local file if it exists, otherwise treats it as a huggingface repo and
110 |     attempts to download."""
111 |     full_path = os.path.join(path_or_repo, filename)
112 |     if os.path.exists(full_path):
113 |         return full_path
114 |     else:
115 |         return hf_hub_download(path_or_repo, filename, **kwargs)
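116 | 
117 | 
118 | # NOTE: a small sketch of the lazy-evaluation helpers; expensive() runs only on
119 | # the first access of "demo" and its result is cached for later lookups.
120 | if __name__ == "__main__":
121 | 
122 |     def expensive(name):
123 |         print(f"evaluating {name}")
124 |         return name.upper()
125 | 
126 |     models = LazyDict(demo=LazyEvaluateInfo(expensive, "demo-model"))
127 |     print(models["demo"])  # evaluates: prints "evaluating demo-model", then "DEMO-MODEL"
128 |     print(models["demo"])  # cached: prints "DEMO-MODEL" only
129 |     print(strip_tags("<table><tr><td>cell</td></tr></table>"))  # -> cell
130 | 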
--------------------------------------------------------------------------------
/unstructured_inference/visualize.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Megvii Inc. All rights reserved.
2 | # Unstructured modified the original source code found at
3 | # https://github.com/Megvii-BaseDetection/YOLOX/blob/ac379df3c97d1835ebd319afad0c031c36d03f36/yolox/utils/visualize.py
4 | import typing
5 | from typing import Optional, Union
6 | 
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 | from PIL import ImageFont
10 | from PIL.Image import Image
11 | from PIL.ImageDraw import ImageDraw
12 | 
13 | from unstructured_inference.inference.elements import TextRegion
14 | 
15 | 
16 | @typing.no_type_check
17 | def draw_bbox(
18 |     image: Image,
19 |     element: TextRegion,
20 |     color: str = "red",
21 |     width=1,
22 |     details: bool = False,
23 | ) -> Image:
24 |     """Draws bounding box in image"""
25 |     try:
26 |         img = image.copy()
27 |         draw = ImageDraw(img)
28 |         topleft, _, bottomright, _ = element.bbox.coordinates
29 |         c = getattr(element, "color", color)
30 |         if details:
31 |             source = getattr(element, "source", "Unknown")
32 |             type = getattr(element, "type", "")
33 |             kbd = ImageFont.truetype("Keyboard.ttf", 20)
34 |             draw.text(topleft, text=f"{type} {source}", fill=c, font=kbd)
35 |         draw.rectangle((topleft, bottomright), outline=c, width=width)
36 |     except OSError:
37 |         print("Failed to find font file. Skipping details.")
Skipping details.") 38 | img = draw_bbox(image, element, color, width) 39 | except Exception as e: 40 | print(f"Failed to draw bounding box: {e}") 41 | return img 42 | 43 | 44 | def show_plot( 45 | image: Union[Image, np.ndarray], 46 | desired_width: Optional[int] = None, 47 | ): 48 | """ 49 | Display an image using matplotlib with an optional desired width while maintaining the aspect 50 | ratio. 51 | 52 | Parameters: 53 | - image (Union[Image, np.ndarray]): An image in PIL Image format or a numpy ndarray format. 54 | - desired_width (Optional[int]): Desired width for the display size of the image. 55 | If provided, the height is calculated based on the original aspect ratio. 56 | If not provided, the image will be displayed with its original dimensions. 57 | 58 | Raises: 59 | - ValueError: If the provided image type is neither PIL Image nor numpy ndarray. 60 | 61 | Returns: 62 | - None: The function displays the image using matplotlib but does not return any value. 63 | """ 64 | if isinstance(image, Image): 65 | image_width, image_height = image.size 66 | elif isinstance(image, np.ndarray): 67 | image_height, image_width, _ = image.shape 68 | else: 69 | raise ValueError("Unsupported Image Type") 70 | 71 | if desired_width: 72 | # Calculate the desired height based on the original aspect ratio 73 | aspect_ratio = image_width / image_height 74 | desired_height = desired_width / aspect_ratio 75 | 76 | # Create a figure with the desired size and aspect ratio 77 | fig, ax = plt.subplots(figsize=(desired_width, desired_height)) 78 | else: 79 | # Create figure and axes 80 | fig, ax = plt.subplots() 81 | # Display the image 82 | ax.imshow(image) 83 | plt.show() 84 | 85 | 86 | _COLORS = np.array( 87 | [ 88 | [0.000, 0.447, 0.741], 89 | [0.850, 0.325, 0.098], 90 | [0.929, 0.694, 0.125], 91 | [0.494, 0.184, 0.556], 92 | [0.466, 0.674, 0.188], 93 | [0.301, 0.745, 0.933], 94 | [0.635, 0.078, 0.184], 95 | [0.300, 0.300, 0.300], 96 | [0.600, 0.600, 0.600], 97 | [1.000, 0.000, 0.000], 98 | [1.000, 0.500, 0.000], 99 | [0.749, 0.749, 0.000], 100 | [0.000, 1.000, 0.000], 101 | [0.000, 0.000, 1.000], 102 | [0.667, 0.000, 1.000], 103 | [0.333, 0.333, 0.000], 104 | [0.333, 0.667, 0.000], 105 | [0.333, 1.000, 0.000], 106 | [0.667, 0.333, 0.000], 107 | [0.667, 0.667, 0.000], 108 | [0.667, 1.000, 0.000], 109 | [1.000, 0.333, 0.000], 110 | [1.000, 0.667, 0.000], 111 | [1.000, 1.000, 0.000], 112 | [0.000, 0.333, 0.500], 113 | [0.000, 0.667, 0.500], 114 | [0.000, 1.000, 0.500], 115 | [0.333, 0.000, 0.500], 116 | [0.333, 0.333, 0.500], 117 | [0.333, 0.667, 0.500], 118 | [0.333, 1.000, 0.500], 119 | [0.667, 0.000, 0.500], 120 | [0.667, 0.333, 0.500], 121 | [0.667, 0.667, 0.500], 122 | [0.667, 1.000, 0.500], 123 | [1.000, 0.000, 0.500], 124 | [1.000, 0.333, 0.500], 125 | [1.000, 0.667, 0.500], 126 | [1.000, 1.000, 0.500], 127 | [0.000, 0.333, 1.000], 128 | [0.000, 0.667, 1.000], 129 | [0.000, 1.000, 1.000], 130 | [0.333, 0.000, 1.000], 131 | [0.333, 0.333, 1.000], 132 | [0.333, 0.667, 1.000], 133 | [0.333, 1.000, 1.000], 134 | [0.667, 0.000, 1.000], 135 | [0.667, 0.333, 1.000], 136 | [0.667, 0.667, 1.000], 137 | [0.667, 1.000, 1.000], 138 | [1.000, 0.000, 1.000], 139 | [1.000, 0.333, 1.000], 140 | [1.000, 0.667, 1.000], 141 | [0.333, 0.000, 0.000], 142 | [0.500, 0.000, 0.000], 143 | [0.667, 0.000, 0.000], 144 | [0.833, 0.000, 0.000], 145 | [1.000, 0.000, 0.000], 146 | [0.000, 0.167, 0.000], 147 | [0.000, 0.333, 0.000], 148 | [0.000, 0.500, 0.000], 149 | [0.000, 0.667, 0.000], 150 | [0.000, 0.833, 0.000], 151 | [0.000, 
--------------------------------------------------------------------------------