├── .bumpversion.cfg ├── .editorconfig ├── .github ├── ISSUE_TEMPLATE.md └── workflows │ ├── dev.yml │ ├── docs.yml │ ├── preview.yml │ └── release.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CHANGELOG.md ├── CNAME ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── benchmarks ├── classification.ipynb └── ocr.ipynb ├── docprompt ├── __init__.py ├── _decorators.py ├── _exec │ ├── __init__.py │ ├── ghostscript.py │ └── tesseract.py ├── _pdfium.py ├── contrib │ ├── __init__.py │ └── parser_bot │ │ └── __init__.py ├── provenance │ ├── __init__.py │ ├── search.py │ ├── source.py │ └── util.py ├── rasterize.py ├── schema │ ├── __init__.py │ ├── document.py │ ├── layout.py │ └── pipeline │ │ ├── __init__.py │ │ ├── metadata.py │ │ ├── node │ │ ├── __init__.py │ │ ├── base.py │ │ ├── collection.py │ │ ├── document.py │ │ ├── image.py │ │ ├── page.py │ │ └── typing.py │ │ └── rasterizer.py ├── storage.py ├── tasks │ ├── __init__.py │ ├── base.py │ ├── capabilities.py │ ├── classification │ │ ├── __init__.py │ │ ├── anthropic.py │ │ └── base.py │ ├── credentials.py │ ├── factory.py │ ├── markerize │ │ ├── __init__.py │ │ ├── anthropic.py │ │ └── base.py │ ├── message.py │ ├── ocr │ │ ├── __init__.py │ │ ├── amazon.py │ │ ├── base.py │ │ ├── gcp.py │ │ ├── result.py │ │ └── tesseract.py │ ├── parser.py │ ├── result.py │ ├── table_extraction │ │ ├── __init__.py │ │ ├── anthropic.py │ │ ├── base.py │ │ └── schema.py │ └── util.py └── utils │ ├── __init__.py │ ├── compressor.py │ ├── date_extraction.py │ ├── inference.py │ ├── layout.py │ ├── masking │ ├── __init__.py │ └── image.py │ ├── splitter.py │ └── util.py ├── docs ├── CNAME ├── assets │ └── static │ │ └── img │ │ ├── logo.png │ │ ├── logo.svg │ │ └── old-logo.png ├── blog │ └── index.md ├── community │ ├── contributing.md │ └── versioning.md ├── concepts │ ├── nodes.md │ ├── primatives.md │ ├── provenance.md │ └── providers.md ├── enterprise.md ├── gen_ref_pages.py ├── guide │ ├── classify │ │ ├── binary.md │ │ ├── multi.md │ │ └── single.md │ ├── ocr │ │ ├── advanced_search.md │ │ ├── advanced_workflows.md │ │ ├── basic_usage.md │ │ └── provider_config.md │ └── table_extraction │ │ └── extract_tables.md └── index.md ├── makefile ├── mkdocs.yml ├── pdm.lock ├── pyproject.toml ├── setup.cfg └── tests ├── __init__.py ├── _run_tests_with_coverage.py ├── conftest.py ├── fixtures.py ├── fixtures ├── 1.pdf ├── 1_ocr.json └── 2.pdf ├── schema ├── __init__.py ├── pipeline │ ├── __init__.py │ ├── test_imagenode.py │ ├── test_layoutaware.py │ ├── test_metadata.py │ └── test_rasterizer.py ├── test_document.py └── test_layout_models.py ├── tasks ├── __init__.py ├── classification │ ├── __init__.py │ ├── test_anthropic.py │ └── test_base.py ├── markerize │ ├── __init__.py │ ├── test_anthropic.py │ └── test_base.py ├── table_extraction │ ├── __init__.py │ ├── test_anthropic.py │ └── test_base.py ├── test_credentials.py ├── test_factory.py ├── test_result.py └── test_task_provider.py ├── test_date_extraction.py ├── test_decorators.py ├── test_search.py ├── test_storage.py ├── test_threadpool.py └── util.py /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.1.0 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:pyproject.toml] 7 | search = version = "{current_version}" 8 | replace = version = "{new_version}" 9 | 10 | [bumpversion:file:docprompt/__init__.py] 11 | search = __version__ = '{current_version}' 12 | replace = __version__ = '{new_version}' 13 | 
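Note: the release flow in CONTRIBUTING.md drives this file via bump2version, which applies the search/replace pairs above. A sketch of one patch bump (version numbers are illustrative; the tool reads `current_version` from this config):

```
$ pdm run bump2version patch
# 0.1.0 -> 0.1.1 (illustrative): rewrites version = "0.1.0" in pyproject.toml
# and __version__ = '0.1.0' in docprompt/__init__.py, then creates a commit
# and a v-prefixed tag (commit = True, tag = True)
```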
-------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [LICENSE] 18 | insert_final_newline = false 19 | 20 | [Makefile] 21 | indent_style = tab 22 | 23 | [*.{yml, yaml}] 24 | indent_size = 2 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | * Docprompt version: 2 | * Python version: 3 | * Operating System: 4 | 5 | ### Description 6 | 7 | Describe what you were trying to get done. 8 | Tell us what happened, what went wrong, and what you expected to happen. 9 | 10 | ### What I Did 11 | 12 | ``` 13 | Paste the command(s) you ran and the output. 14 | If there was a crash, please include the traceback here. 15 | ``` 16 | -------------------------------------------------------------------------------- /.github/workflows/dev.yml: -------------------------------------------------------------------------------- 1 | # This is a basic workflow to help you get started with Actions 2 | 3 | name: dev workflow 4 | 5 | # Controls when the action will run. 6 | on: 7 | # Triggers the workflow on push or pull request events but only for the master branch 8 | push: 9 | branches: [ main ] 10 | pull_request: 11 | branches: [ main ] 12 | 13 | # Allows you to run this workflow manually from the Actions tab 14 | workflow_dispatch: 15 | 16 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 17 | jobs: 18 | # This workflow contains a single job called "test" 19 | test: 20 | # The type of runner that the job will run on 21 | strategy: 22 | matrix: 23 | python: ['3.8', '3.9', '3.10', '3.11', '3.12'] 24 | runs-on: ubuntu-latest 25 | 26 | # Steps represent a sequence of tasks that will be executed as part of the job 27 | steps: 28 | # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it 29 | - uses: actions/checkout@v4 30 | - uses: actions/setup-python@v5 31 | with: 32 | python-version: ${{ matrix.python }} 33 | cache: 'pip' 34 | 35 | - name: Install dependencies 36 | run: | 37 | sudo apt update && sudo apt install -y ghostscript 38 | python -m pip install --upgrade pip 39 | pip install -U pdm 40 | pdm install -G:all 41 | - name: Lint with Ruff 42 | run: | 43 | pdm run ruff check --output-format=github . 44 | - name: Test with pytest 45 | run: | 46 | pdm run pytest -sxv --cov --cov-report=xml . 
47 | - name: Upload coverage reports to Codecov 48 | uses: codecov/codecov-action@v4.0.1 49 | with: 50 | token: ${{ secrets.CODECOV_TOKEN }} 51 | slug: Page-Leaf/Docprompt 52 | files: ./coverage.xml 53 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Deploy MkDocs to GitHub Pages 2 | 3 | on: 4 | push: 5 | branches: 6 | - main # Set this to your default branch 7 | 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | 14 | - name: Set up Python 15 | uses: actions/setup-python@v4 16 | with: 17 | python-version: '3.x' 18 | 19 | - name: Install PDM 20 | run: pip install pdm 21 | 22 | - name: Install dependencies 23 | run: pdm install -d 24 | 25 | - name: Deploy MkDocs 26 | run: pdm run mkdocs gh-deploy --force 27 | -------------------------------------------------------------------------------- /.github/workflows/preview.yml: -------------------------------------------------------------------------------- 1 | # This is a basic workflow to help you get started with Actions 2 | 3 | name: stage & preview workflow 4 | 5 | # Controls when the action will run. 6 | on: 7 | # Triggers the workflow on push events but only for the main branch 8 | push: 9 | branches: [ main ] 10 | 11 | # Allows you to run this workflow manually from the Actions tab 12 | workflow_dispatch: 13 | 14 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel 15 | jobs: 16 | publish_dev_build: 17 | runs-on: ubuntu-latest 18 | 19 | strategy: 20 | matrix: 21 | python-versions: [3.9] 22 | 23 | permissions: 24 | id-token: write 25 | contents: read 26 | steps: 27 | - uses: actions/checkout@v4 28 | - uses: actions/setup-python@v5 29 | with: 30 | python-version: ${{ matrix.python-versions }} 31 | cache: 'pip' 32 | 33 | - name: Install dependencies 34 | run: | 35 | sudo apt update && sudo apt install -y ghostscript 36 | python -m pip install --upgrade pip 37 | pip install pdm 38 | pdm install -G:all 39 | 40 | - name: Run Pytest 41 | run: 42 | pdm run pytest -sxv 43 | 44 | - name: Build wheels and source tarball 45 | run: | 46 | pdm build 47 | 48 | - name: Publish package distributions to TestPyPI 49 | uses: pypa/gh-action-pypi-publish@release/v1 50 | with: 51 | repository-url: https://test.pypi.org/legacy/ 52 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: release & publish workflow 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | workflow_dispatch: 8 | 9 | jobs: 10 | release: 11 | name: Create Release 12 | runs-on: ubuntu-latest 13 | 14 | strategy: 15 | matrix: 16 | python-versions: [3.11] 17 | 18 | permissions: 19 | contents: write # This allows creating releases 20 | 21 | steps: 22 | - name: Get version from tag 23 | id: tag_name 24 | run: | 25 | echo "current_version=${GITHUB_REF#refs/tags/v}" >> "$GITHUB_OUTPUT" 26 | shell: bash 27 | 28 | - uses: actions/checkout@v4 29 | - uses: actions/setup-python@v5 30 | with: 31 | python-version: ${{ matrix.python-versions }} 32 | 33 | - name: Install dependencies 34 | run: | 35 | python -m pip install --upgrade pip 36 | pip install pdm 37 | pdm install -G:all 38 | 39 | - name: Build wheels and source tarball 40 | run: >- 41 | pdm build 42 | 43 | - name: show temporary files 44 | run: >- 45 | ls
-l 46 | 47 | - name: create github release 48 | id: create_release 49 | uses: softprops/action-gh-release@v2 50 | env: 51 | GITHUB_TOKEN: ${{ secrets.GH_TOKEN }} 52 | with: 53 | files: dist/*.whl 54 | draft: false 55 | prerelease: false 56 | 57 | - name: publish to PyPI 58 | uses: pypa/gh-action-pypi-publish@release/v1 59 | with: 60 | user: __token__ 61 | password: ${{ secrets.PYPI_API_TOKEN }} 62 | skip_existing: true 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # macOS 2 | .DS_Store 3 | .AppleDouble 4 | .LSOverride 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | 15 | # Distribution / packaging 16 | .Python 17 | env/ 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # dotenv 90 | .env 91 | 92 | # virtualenv 93 | .venv 94 | venv/ 95 | ENV/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | 110 | # IDE settings 111 | .vscode/ 112 | .idea/ 113 | 114 | # mkdocs build dir 115 | site/ 116 | 117 | .creds/ 118 | 119 | data/ 120 | 121 | .pdm-python 122 | .aider* 123 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/Lucas-C/pre-commit-hooks 3 | rev: v1.1.9 4 | hooks: 5 | - id: forbid-crlf 6 | - id: remove-crlf 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v3.4.0 9 | hooks: 10 | - id: trailing-whitespace 11 | - id: end-of-file-fixer 12 | - id: check-merge-conflict 13 | - id: check-yaml 14 | args: [ --unsafe ] 15 | - repo: https://github.com/astral-sh/ruff-pre-commit 16 | # Ruff version. 17 | rev: v0.3.2 18 | hooks: 19 | # Run the linter. 20 | - id: ruff 21 | args: [ --fix ] 22 | # Run the formatter. 
23 | - id: ruff-format 24 | - repo: https://github.com/pappasam/toml-sort 25 | rev: v0.23.1 26 | hooks: 27 | - id: toml-sort 28 | - id: toml-sort-fix 29 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## 0.1.0 (2023-10-18) 4 | 5 | * First release on PyPI. 6 | -------------------------------------------------------------------------------- /CNAME: -------------------------------------------------------------------------------- 1 | docs.docprompt.io 2 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Contributions are welcome, and they are greatly appreciated! Every little bit 4 | helps, and credit will always be given. 5 | 6 | You can contribute in many ways: 7 | 8 | ## Types of Contributions 9 | 10 | ### Report Bugs 11 | 12 | Report bugs at https://github.com/psu3d0/docprompt/issues. 13 | 14 | If you are reporting a bug, please include: 15 | 16 | * Your operating system name and version. 17 | * Any details about your local setup that might be helpful in troubleshooting. 18 | * Detailed steps to reproduce the bug. 19 | 20 | ### Fix Bugs 21 | 22 | Look through the GitHub issues for bugs. Anything tagged with "bug" and "help 23 | wanted" is open to whoever wants to implement it. 24 | 25 | ### Implement Features 26 | 27 | Look through the GitHub issues for features. Anything tagged with "enhancement" 28 | and "help wanted" is open to whoever wants to implement it. 29 | 30 | ### Write Documentation 31 | 32 | Docprompt could always use more documentation, whether as part of the 33 | official Docprompt docs, in docstrings, or even on the web in blog posts, 34 | articles, and such. 35 | 36 | ### Submit Feedback 37 | 38 | The best way to send feedback is to file an issue at https://github.com/psu3d0/docprompt/issues. 39 | 40 | If you are proposing a feature: 41 | 42 | * Explain in detail how it would work. 43 | * Keep the scope as narrow as possible, to make it easier to implement. 44 | * Remember that this is a volunteer-driven project, and that contributions 45 | are welcome :) 46 | 47 | ## Get Started! 48 | 49 | Ready to contribute? Here's how to set up `docprompt` for local development. 50 | 51 | 1. Fork the `docprompt` repo on GitHub. 52 | 2. Clone your fork locally 53 | 54 | ``` 55 | $ git clone git@github.com:your_name_here/docprompt.git 56 | ``` 57 | 58 | 3. Ensure [pdm](https://pdm-project.org/en/latest/) is installed. 59 | 4. Install dependencies and start your virtualenv: 60 | 61 | ``` 62 | $ pdm install -d 63 | ``` 64 | 65 | 5. Create a branch for local development: 66 | 67 | ``` 68 | $ git checkout -b name-of-your-bugfix-or-feature 69 | ``` 70 | 71 | Now you can make your changes locally. 72 | 73 | 6. When you're done making changes, check that your changes pass the 74 | tests, including testing other Python versions, with tox: 75 | 76 | ``` 77 | $ pdm run tox 78 | ``` 79 | 80 | 7. Commit your changes and push your branch to GitHub: 81 | 82 | ``` 83 | $ git add . 84 | $ git commit -m "Your detailed description of your changes." 85 | $ git push origin name-of-your-bugfix-or-feature 86 | ``` 87 | 88 | 8. Submit a pull request through the GitHub website. 
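You can mirror the checks CI runs (see `.github/workflows/dev.yml`) before pushing; these are the same commands the workflow uses, minus the CI-only reporting flags:

```
$ pdm run ruff check .
$ pdm run pytest -sxv
```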
89 | 90 | ## Pull Request Guidelines 91 | 92 | Before you submit a pull request, check that it meets these guidelines: 93 | 94 | 1. The pull request should include tests. 95 | 2. If the pull request adds functionality, the docs should be updated. Put 96 | your new functionality into a function with a docstring, and add the 97 | feature to the list in README.md. 98 | 3. The pull request should work for Python 3.8, 3.9, 3.10, 3.11 and 3.12. Check 99 | https://github.com/psu3d0/docprompt/actions 100 | and make sure that the tests pass for all supported Python versions. 101 | 102 | ## Tips 103 | 104 | ``` 105 | $ pdm run pytest tests/test_search.py 106 | ``` 107 | 108 | To run a subset of tests. 109 | 110 | 111 | ## Deploying 112 | 113 | A reminder for the maintainers on how to deploy. 114 | Make sure all your changes are committed (including an entry in CHANGELOG.md). 115 | Then run: 116 | 117 | ``` 118 | $ pdm run bump2version patch # possible: major / minor / patch 119 | $ git push 120 | $ git push --tags 121 | ``` 122 | 123 | GitHub Actions will then deploy to PyPI if tests pass. 124 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache Software License 2.0 2 | 3 | Copyright (c) 2023, Frankie Colson 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License.
16 | -------------------------------------------------------------------------------- /docprompt/__init__.py: -------------------------------------------------------------------------------- 1 | """Top-level package for Docprompt.""" 2 | 3 | __author__ = """Frankie Colson""" 4 | __email__ = "frank@pageleaf.io" 5 | __version__ = "0.8.7" 6 | 7 | from docprompt.rasterize import ProviderResizeRatios 8 | from docprompt.schema.document import Document, PdfDocument # noqa 9 | from docprompt.schema.layout import NormBBox, TextBlock # noqa 10 | from docprompt.schema.pipeline import DocumentCollection, DocumentNode, PageNode # noqa 11 | from docprompt.tasks.ocr.result import OcrPageResult # noqa 12 | from docprompt.utils import ( # noqa 13 | hash_from_bytes, 14 | load_document, 15 | load_document_node, 16 | load_documents, 17 | load_pdf_document, 18 | load_pdf_documents, 19 | ) 20 | 21 | # PdfDocument.model_rebuild() 22 | DocumentNode.model_rebuild() 23 | 24 | 25 | __all__ = [ 26 | "Document", 27 | "PdfDocument", 28 | "DocumentCollection", 29 | "DocumentNode", 30 | "NormBBox", 31 | "PageNode", 32 | "TextBlock", 33 | "load_document", 34 | "load_documents", 35 | "hash_from_bytes", 36 | "ProviderResizeRatios", 37 | "load_pdf_document", 38 | ] 39 | -------------------------------------------------------------------------------- /docprompt/_decorators.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import sys 3 | from functools import partial, update_wrapper, wraps 4 | from typing import Callable, Optional, Set, Tuple, Type 5 | 6 | if sys.version_info >= (3, 9): 7 | to_thread = asyncio.to_thread 8 | else: 9 | 10 | def to_thread(func, /, *args, **kwargs): 11 | @wraps(func) 12 | async def wrapper(): 13 | try: 14 | loop = asyncio.get_running_loop() 15 | except RuntimeError: 16 | # If there's no running event loop, create a new one 17 | loop = asyncio.new_event_loop() 18 | asyncio.set_event_loop(loop) 19 | pfunc = partial(func, *args, **kwargs) 20 | return await loop.run_in_executor(None, pfunc) 21 | 22 | return wrapper() 23 | 24 | 25 | def get_closest_attr(cls: Type, attr_name: str) -> Tuple[Type, Optional[Callable], int]: 26 | closest_cls = cls 27 | attr = getattr(cls.__dict__, attr_name, None) 28 | depth = 0 29 | 30 | if attr and hasattr(attr, "_original"): 31 | attr = None 32 | elif attr: 33 | return (cls, attr, 0) 34 | 35 | for idx, base in enumerate(cls.__mro__, start=1): 36 | if not attr and attr_name in base.__dict__: 37 | if not hasattr(base.__dict__[attr_name], "_original"): 38 | closest_cls = base 39 | attr = base.__dict__[attr_name] 40 | depth = idx 41 | 42 | if attr: 43 | break 44 | 45 | return (closest_cls, attr, depth) 46 | 47 | 48 | def validate_method(cls, name: str, method: Callable, expected_async: bool): 49 | if method is None: 50 | return None 51 | is_async = asyncio.iscoroutinefunction(method) 52 | if is_async != expected_async: 53 | return f"Method '{name}' in {cls.__name__} should be {'async' if expected_async else 'sync'}, but it's {'async' if is_async else 'sync'}" 54 | 55 | return None 56 | 57 | 58 | def apply_dual_methods_to_cls(cls: Type, method_group: Tuple[str, str]): 59 | errors = [] 60 | 61 | sync_name, async_name = method_group 62 | 63 | sync_trace = get_closest_attr(cls, sync_name) 64 | async_trace = get_closest_attr(cls, async_name) 65 | 66 | sync_cls, sync_method, sync_depth = sync_trace 67 | async_cls, async_method, async_depth = async_trace 68 | 69 | if sync_method: 70 | sync_error = validate_method(cls, 
sync_name, sync_method, False) 71 | if sync_error: 72 | errors.append(sync_error) 73 | 74 | if async_method: 75 | async_error = validate_method(cls, async_name, async_method, True) 76 | if async_error: 77 | errors.append(async_error) 78 | 79 | if ( 80 | sync_method is None 81 | and async_method is None 82 | and not getattr(getattr(cls, "Meta", None), "abstract", False) 83 | ): 84 | return [ 85 | f"{cls.__name__} must implement at least one of these methods: {sync_name}, {async_name}" 86 | ] 87 | 88 | if sync_cls is cls and async_cls is cls and sync_method and async_method: 89 | return errors # Both methods are already in the same class 90 | 91 | if async_cls is cls and async_method: 92 | 93 | def sync_wrapper(*args, **kwargs): 94 | return asyncio.run(async_method(*args, **kwargs)) 95 | 96 | update_wrapper(sync_wrapper, async_method) 97 | 98 | sync_wrapper._original = async_method 99 | 100 | setattr(cls, sync_name, sync_wrapper) 101 | elif sync_cls is cls and sync_method: 102 | 103 | async def async_wrapper(*args, **kwargs): 104 | if hasattr(sync_method, "__func__"): 105 | return await to_thread(sync_method, *args, **kwargs) 106 | return await to_thread(sync_method, *args, **kwargs) 107 | 108 | update_wrapper(async_wrapper, sync_method) 109 | 110 | async_wrapper._original = sync_method 111 | 112 | setattr(cls, async_name, async_wrapper) 113 | else: 114 | if async_depth < sync_depth: 115 | 116 | def sync_wrapper(*args, **kwargs): 117 | return asyncio.run(async_method(*args, **kwargs)) 118 | 119 | update_wrapper(sync_wrapper, async_method) 120 | 121 | sync_wrapper._original = async_method 122 | 123 | setattr(cls, sync_name, sync_wrapper) 124 | else: 125 | 126 | async def async_wrapper(*args, **kwargs): 127 | return await to_thread(sync_method, *args, **kwargs) 128 | 129 | update_wrapper(async_wrapper, sync_method) 130 | 131 | async_wrapper._original = sync_method 132 | 133 | setattr(cls, async_name, async_wrapper) 134 | 135 | return errors 136 | 137 | 138 | def get_flexible_method_configs(cls: Type) -> Set[Tuple[str, str]]: 139 | all = set() 140 | for base in cls.__mro__: 141 | all.update(getattr(base, "__flexible_methods__", set())) 142 | 143 | return all 144 | 145 | 146 | def flexible_methods(*method_groups: Tuple[str, str]): 147 | def decorator(cls: Type): 148 | if not hasattr(cls, "__flexible_methods__"): 149 | setattr(cls, "__flexible_methods__", set()) 150 | 151 | for base in cls.__bases__: 152 | if hasattr(base, "__flexible_methods__"): 153 | cls.__flexible_methods__.update(base.__flexible_methods__) 154 | 155 | cls.__flexible_methods__.update(method_groups) 156 | 157 | def apply_flexible_methods(cls: Type): 158 | errors = [] 159 | 160 | for group in get_flexible_method_configs(cls): 161 | if len(group) != 2: 162 | errors.append( 163 | f"Invalid method group {group}. Each group must be a tuple of exactly two method names." 
164 | ) 165 | continue 166 | 167 | errors.extend(apply_dual_methods_to_cls(cls, group)) 168 | 169 | if errors: 170 | raise TypeError("\n".join(errors)) 171 | 172 | apply_flexible_methods(cls) 173 | 174 | original_init_subclass = cls.__init_subclass__ 175 | 176 | @classmethod 177 | def new_init_subclass(cls, **kwargs): 178 | original_init_subclass(**kwargs) 179 | apply_flexible_methods(cls) 180 | 181 | cls.__init_subclass__ = new_init_subclass 182 | 183 | return cls 184 | 185 | return decorator 186 | -------------------------------------------------------------------------------- /docprompt/_exec/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docprompt/_exec/__init__.py -------------------------------------------------------------------------------- /docprompt/_exec/ghostscript.py: -------------------------------------------------------------------------------- 1 | from os import PathLike 2 | from pathlib import Path 3 | from subprocess import PIPE, CompletedProcess, run 4 | from typing import Literal, Union 5 | 6 | GS = "gs" 7 | 8 | 9 | class GhostscriptError(Exception): 10 | def __init__(self, message: str, process: CompletedProcess) -> None: 11 | self.process = process 12 | super().__init__(message) 13 | 14 | 15 | def compress_pdf( 16 | fp: Union[PathLike, str], # Ghostscript insists on a file instead of bytes 17 | output_path: str, 18 | *, 19 | compression: Literal["jpeg", "lossless"] = "jpeg", 20 | ): 21 | compression_args = [] 22 | if compression == "jpeg": 23 | compression_args = [ 24 | "-dAutoFilterColorImages=false", 25 | "-dColorImageFilter=/DCTEncode", 26 | "-dAutoFilterGrayImages=false", 27 | "-dGrayImageFilter=/DCTEncode", 28 | ] 29 | elif compression == "lossless": 30 | compression_args = [ 31 | "-dAutoFilterColorImages=false", 32 | "-dColorImageFilter=/FlateEncode", 33 | "-dAutoFilterGrayImages=false", 34 | "-dGrayImageFilter=/FlateEncode", 35 | ] 36 | else: 37 | compression_args = [ 38 | "-dAutoFilterColorImages=true", 39 | "-dAutoFilterGrayImages=true", 40 | ] 41 | 42 | args_gs = ( 43 | [ 44 | GS, 45 | "-q", 46 | "-dBATCH", 47 | "-dNOPAUSE", 48 | "-dSAFER", 49 | "-dCompatibilityLevel=1.5", 50 | "-sDEVICE=pdfwrite", 51 | "-dAutoRotatePages=/None", 52 | "-sColorConversionStrategy=LeaveColorUnchanged", 53 | ] 54 | + compression_args 55 | + [ 56 | "-dJPEGQ=95", 57 | "-dPDFA=2", 58 | "-dPDFACompatibilityPolicy=1", 59 | "-sOutputFile=" + output_path, 60 | str(fp), 61 | ] 62 | ) 63 | 64 | result = run(args_gs, stdout=PIPE, stderr=PIPE, check=False) 65 | 66 | if result.returncode != 0: 67 | raise GhostscriptError("Ghostscript failed to compress the document", result) 68 | 69 | return result 70 | 71 | 72 | def compress_pdf_to_bytes( 73 | fp: Union[PathLike, str], *, compression: Literal["jpeg", "lossless"] = "jpeg" 74 | ) -> bytes: 75 | result = compress_pdf(fp, output_path="%stdout", compression=compression) 76 | 77 | return result.stdout 78 | 79 | 80 | def compress_pdf_to_path( 81 | fp: Union[PathLike, str], 82 | output_path: PathLike, 83 | *, 84 | compression: Literal["jpeg", "lossless"] = "jpeg", 85 | ) -> Path: 86 | compress_pdf(fp, output_path=str(output_path), compression=compression) 87 | 88 | return Path(output_path) 89 | -------------------------------------------------------------------------------- /docprompt/_exec/tesseract.py: -------------------------------------------------------------------------------- 1 | import io 2 | 
import re 3 | import xml.etree.ElementTree as ET 4 | from os import PathLike 5 | from subprocess import PIPE, CompletedProcess, run 6 | from typing import List, TypedDict, Union 7 | 8 | from PIL import Image 9 | 10 | TESSERACT = "tesseract" 11 | 12 | 13 | def check_tesseract_installed() -> bool: 14 | result = run( 15 | [TESSERACT, "--version"], stdout=PIPE, stderr=PIPE, check=False, text=True 16 | ) 17 | return result.returncode == 0 18 | 19 | 20 | class TesseractError(Exception): 21 | def __init__(self, message: str, process: CompletedProcess) -> None: 22 | self.process = process 23 | super().__init__(message) 24 | 25 | 26 | class BoundingBox(TypedDict): 27 | x: int 28 | y: int 29 | width: int 30 | height: int 31 | 32 | 33 | class Word(TypedDict): 34 | id: str 35 | bbox: BoundingBox 36 | text: str 37 | line_id: str 38 | block_id: str 39 | 40 | 41 | class Line(TypedDict): 42 | id: str 43 | bbox: BoundingBox 44 | text: str 45 | words: List[str] 46 | block_id: str 47 | 48 | 49 | class Block(TypedDict): 50 | id: str 51 | bbox: BoundingBox 52 | text: str 53 | 54 | 55 | class OCRResult(TypedDict): 56 | blocks: List[Block] 57 | lines: List[Line] 58 | words: List[Word] 59 | language: str 60 | 61 | 62 | def process_image( 63 | fp: Union[PathLike, str], 64 | *, 65 | lang: str = "eng", 66 | config: List[str] = None, 67 | ) -> str: 68 | args_tesseract = [ 69 | TESSERACT, 70 | str(fp), 71 | "stdout", 72 | "-l", 73 | lang, 74 | ] 75 | 76 | if config is None: 77 | config = [] 78 | 79 | # Add HOCR output format to get word bounding boxes 80 | config.extend(["-c", "tessedit_create_hocr=1"]) 81 | 82 | args_tesseract.extend(config) 83 | 84 | result = run(args_tesseract, stdout=PIPE, stderr=PIPE, check=False, text=True) 85 | 86 | if result.returncode != 0: 87 | raise TesseractError( 88 | f"Tesseract failed to process the image, {result.stderr}", result 89 | ) 90 | 91 | return result.stdout 92 | 93 | 94 | def get_bbox(element) -> BoundingBox: 95 | title = element.get("title") 96 | bbox = [int(x) for x in title.split(";")[0].split(" ")[1:]] 97 | return BoundingBox( 98 | x=bbox[0], y=bbox[1], width=bbox[2] - bbox[0], height=bbox[3] - bbox[1] 99 | ) 100 | 101 | 102 | def clean_text(text: str) -> str: 103 | # Remove extra whitespace 104 | text = re.sub(r" \n ", " ", text) 105 | text = re.sub(r"\n+", "\n", text) 106 | text = re.sub(r" +", " ", text).strip() 107 | return text.strip() 108 | 109 | 110 | def process_image_to_dict( 111 | fp: Union[PathLike, str], 112 | *, 113 | lang: str = "eng", 114 | config: List[str] = None, 115 | ) -> OCRResult: 116 | hocr_content = process_image(fp, lang=lang, config=config) 117 | 118 | # Use StringIO to create a file-like object from the string 119 | hocr_file = io.StringIO(hocr_content) 120 | root = ET.parse(hocr_file).getroot() 121 | 122 | # Get image dimensions 123 | image = Image.open(fp) 124 | img_width, img_height = image.size 125 | image.close() 126 | 127 | blocks: List[Block] = [] 128 | lines: List[Line] = [] 129 | words: List[Word] = [] 130 | 131 | block_id = 0 132 | line_id = 0 133 | word_id = 0 134 | 135 | def normalize_bbox(bbox: BoundingBox) -> BoundingBox: 136 | return BoundingBox( 137 | x=bbox["x"] / img_width, 138 | y=bbox["y"] / img_height, 139 | width=bbox["width"] / img_width, 140 | height=bbox["height"] / img_height, 141 | ) 142 | 143 | for block in root.findall(".//*[@class='ocr_carea']"): 144 | block_bbox = normalize_bbox(get_bbox(block)) 145 | block_text = clean_text(" ".join(block.itertext())) 146 | blocks.append(Block(id=f"block_{block_id}", 
bbox=block_bbox, text=block_text)) 147 | 148 | for line in block.findall(".//*[@class='ocr_line']"): 149 | line_bbox = normalize_bbox(get_bbox(line)) 150 | line_words: List[str] = [] 151 | line_text: List[str] = [] 152 | 153 | for word in line.findall(".//*[@class='ocrx_word']"): 154 | word_bbox = normalize_bbox(get_bbox(word)) 155 | word_text = clean_text(word.text) if word.text else "" 156 | words.append( 157 | Word( 158 | id=f"word_{word_id}", 159 | bbox=word_bbox, 160 | text=word_text, 161 | line_id=f"line_{line_id}", 162 | block_id=f"block_{block_id}", 163 | ) 164 | ) 165 | line_words.append(f"word_{word_id}") 166 | line_text.append(word_text) 167 | word_id += 1 168 | 169 | lines.append( 170 | Line( 171 | id=f"line_{line_id}", 172 | bbox=line_bbox, 173 | text=" ".join(line_text), 174 | words=line_words, 175 | block_id=f"block_{block_id}", 176 | ) 177 | ) 178 | line_id += 1 179 | 180 | block_id += 1 181 | 182 | return OCRResult( 183 | blocks=blocks, 184 | lines=lines, 185 | words=words, 186 | language=lang, 187 | ) 188 | -------------------------------------------------------------------------------- /docprompt/contrib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docprompt/contrib/__init__.py -------------------------------------------------------------------------------- /docprompt/contrib/parser_bot/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docprompt/contrib/parser_bot/__init__.py -------------------------------------------------------------------------------- /docprompt/provenance/__init__.py: -------------------------------------------------------------------------------- 1 | from .source import PageTextLocation, ProvenanceSource 2 | 3 | __all__ = ["ProvenanceSource", "PageTextLocation"] 4 | -------------------------------------------------------------------------------- /docprompt/provenance/source.py: -------------------------------------------------------------------------------- 1 | from typing import List, Literal, Optional 2 | 3 | from pydantic import BaseModel, Field, PositiveInt, computed_field 4 | 5 | from docprompt.schema.layout import TextBlock 6 | 7 | 8 | class PageTextLocation(BaseModel): 9 | """ 10 | Specifies the location of a piece of text in a page 11 | """ 12 | 13 | source_blocks: List[TextBlock] = Field( 14 | description="The source text blocks", repr=False 15 | ) 16 | text: str # Sometimes the source text is less than the textblock's text. 17 | score: float 18 | granularity: Literal["word", "line", "block"] = "block" 19 | 20 | merged_source_block: Optional[TextBlock] = Field(default=None) 21 | 22 | 23 | class ProvenanceSource(BaseModel): 24 | """ 25 | Bundled with some data, specifies exactly where a piece of verbatim text came from 26 | in a document. 
27 | """ 28 | 29 | document_name: str 30 | page_number: PositiveInt 31 | text_location: Optional[PageTextLocation] = None 32 | 33 | @computed_field # type: ignore 34 | @property 35 | def source_block(self) -> Optional[TextBlock]: 36 | if self.text_location: 37 | if self.text_location.merged_source_block: 38 | return self.text_location.merged_source_block 39 | if self.text_location.source_blocks: 40 | return self.text_location.source_blocks[0] 41 | 42 | return None 43 | 44 | @property 45 | def text(self) -> str: 46 | if self.text_location: 47 | return "\n".join([block.text for block in self.text_location.source_blocks]) 48 | 49 | return "" 50 | -------------------------------------------------------------------------------- /docprompt/provenance/util.py: -------------------------------------------------------------------------------- 1 | import re 2 | from collections import defaultdict 3 | from typing import Any, Iterable, List, Optional 4 | 5 | from rapidfuzz import fuzz 6 | from rapidfuzz.utils import default_process 7 | 8 | from docprompt.schema.layout import NormBBox, TextBlock 9 | 10 | try: 11 | import tantivy 12 | except ImportError: 13 | raise ImportError("Could not import tantivy. Install with `docprompt[search]`") 14 | 15 | try: 16 | import networkx 17 | except ImportError: 18 | raise ImportError("Could not import networkx. Install with `docprompt[search]`") 19 | 20 | 21 | _prefix_regexs = [ 22 | re.compile(r"^\d+\.\s+"), 23 | re.compile(r"^\d+\.\d+\s+"), 24 | re.compile(r"^\*+\s+"), 25 | re.compile(r"^-+\s+"), 26 | ] 27 | 28 | 29 | def preprocess_query_text(text: str) -> str: 30 | """ 31 | Improve matching ability by applying some preprocessing to the query text. 32 | """ 33 | for regex in _prefix_regexs: 34 | text = regex.sub("", text) 35 | 36 | text = text.strip() 37 | 38 | text = text.replace('"', "") 39 | 40 | return text 41 | 42 | 43 | def word_tokenize(text: str) -> List[str]: 44 | """ 45 | Tokenize a string into words. 46 | """ 47 | return re.split(r"\s+", text) 48 | 49 | 50 | def create_tantivy_document_wise_block_index(): 51 | schema_builder = tantivy.SchemaBuilder() 52 | 53 | schema_builder.add_integer_field( 54 | "page_number", stored=True, indexed=True, fast=True 55 | ) 56 | schema_builder.add_text_field("block_type", stored=True) 57 | schema_builder.add_integer_field("block_page_idx", stored=True) 58 | schema_builder.add_text_field("content", stored=True) 59 | 60 | schema = schema_builder.build() 61 | 62 | index = tantivy.Index(schema=schema) 63 | 64 | return index 65 | 66 | 67 | def construct_valid_rtree_tuple(bbox: NormBBox): 68 | # For some reason sometimes the bounding box is invalid (top > bottom, x0 > x1 69 | # This function is to ensure that the bounding box is valid for the rtree index 70 | 71 | true_top = min(bbox.top, bbox.bottom) 72 | true_bottom = max(bbox.top, bbox.bottom) 73 | 74 | true_x0 = min(bbox.x0, bbox.x1) 75 | true_x1 = max(bbox.x0, bbox.x1) 76 | 77 | return (true_x0, true_top, true_x1, true_bottom) 78 | 79 | 80 | def insert_generator(bboxes: List[NormBBox], data: Optional[Iterable[Any]] = None): 81 | """ 82 | Make an iterator that yields tuples of (id, bbox, data) for insertion into an RTree index 83 | which improves performance massively. 
84 | """ 85 | data = data or [None] * len(bboxes) 86 | 87 | for idx, (bbox, data_item) in enumerate(zip(bboxes, data)): 88 | yield (idx, construct_valid_rtree_tuple(bbox), data_item) 89 | 90 | 91 | def refine_block_to_word_level( 92 | source_block: TextBlock, 93 | intersecting_word_level_blocks: List[TextBlock], 94 | query: str, 95 | ): 96 | """ 97 | Create a new text block by merging the intersecting word level blocks that 98 | match the query. 99 | 100 | """ 101 | intersecting_word_level_blocks.sort( 102 | key=lambda x: (x.bounding_box.top, x.bounding_box.x0) 103 | ) 104 | 105 | tokenized_query = word_tokenize(query) 106 | 107 | if len(tokenized_query) == 1: 108 | fuzzified = default_process(tokenized_query[0]) 109 | for word_level_block in intersecting_word_level_blocks: 110 | if fuzz.ratio(fuzzified, default_process(word_level_block.text)) > 87.5: 111 | return word_level_block, [word_level_block] 112 | else: 113 | fuzzified_word_level_texts = [ 114 | default_process(word_level_block.text) 115 | for word_level_block in intersecting_word_level_blocks 116 | ] 117 | 118 | # Populate the block mapping 119 | token_block_mapping = defaultdict(set) 120 | 121 | first_word = tokenized_query[0] 122 | last_word = tokenized_query[-1] 123 | 124 | for token in tokenized_query: 125 | fuzzified_token = default_process(token) 126 | for i, word_level_block in enumerate(intersecting_word_level_blocks): 127 | if fuzz.ratio(fuzzified_token, fuzzified_word_level_texts[i]) > 87.5: 128 | token_block_mapping[token].add(i) 129 | 130 | graph = networkx.DiGraph() 131 | prev = tokenized_query[0] 132 | 133 | for i in token_block_mapping[prev]: 134 | graph.add_node(i) 135 | 136 | for token in tokenized_query[1:]: 137 | for prev_block in token_block_mapping[prev]: 138 | for block in sorted(token_block_mapping[token]): 139 | if block > prev_block: 140 | weight = ( 141 | (block - prev_block) ** 2 142 | ) # Square the distance to penalize large jumps, which encourages reading order 143 | graph.add_edge(prev_block, block, weight=weight) 144 | 145 | prev = token 146 | 147 | # Get every combination of first and last word 148 | first_word_blocks = token_block_mapping[first_word] 149 | last_word_blocks = token_block_mapping[last_word] 150 | 151 | combinations = sorted( 152 | [(x, y) for x in first_word_blocks for y in last_word_blocks if x < y], 153 | key=lambda x: abs(x[1] - x[0]), 154 | ) 155 | 156 | for start, end in combinations: 157 | try: 158 | path = networkx.shortest_path(graph, start, end, weight="weight") 159 | except networkx.NetworkXNoPath: 160 | continue 161 | except Exception: 162 | continue 163 | 164 | matching_blocks = [intersecting_word_level_blocks[i] for i in path] 165 | 166 | merged_bbox = NormBBox.combine( 167 | *[word_level_block.bounding_box for word_level_block in matching_blocks] 168 | ) 169 | 170 | merged_text = "" 171 | 172 | for word_level_block in matching_blocks: 173 | merged_text += word_level_block.text 174 | if not word_level_block.text.endswith(" "): 175 | merged_text += " " # Ensure there is a space between words 176 | 177 | return ( 178 | TextBlock( 179 | text=merged_text, 180 | type="block", 181 | bounding_box=merged_bbox, 182 | metadata=source_block.metadata, 183 | ), 184 | matching_blocks, 185 | ) 186 | -------------------------------------------------------------------------------- /docprompt/schema/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docprompt/schema/__init__.py -------------------------------------------------------------------------------- /docprompt/schema/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from .metadata import BaseMetadata 2 | from .node import DocumentCollection, DocumentNode, ImageNode, PageNode 3 | 4 | __all__ = [ 5 | "DocumentCollection", 6 | "DocumentNode", 7 | "PageNode", 8 | "ImageNode", 9 | "BaseMetadata", 10 | ] 11 | -------------------------------------------------------------------------------- /docprompt/schema/pipeline/node/__init__.py: -------------------------------------------------------------------------------- 1 | from .collection import DocumentCollection 2 | from .document import DocumentNode 3 | from .image import ImageNode 4 | from .page import PageNode 5 | from .typing import DocumentNodeMetadata, PageNodeMetadata 6 | 7 | __all__ = [ 8 | "DocumentNode", 9 | "PageNode", 10 | "ImageNode", 11 | "DocumentNodeMetadata", 12 | "PageNodeMetadata", 13 | "DocumentCollection", 14 | ] 15 | -------------------------------------------------------------------------------- /docprompt/schema/pipeline/node/base.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class BaseNode(BaseModel): 5 | """The base node class is utilized for defining a basic yet flexible interface""" 6 | -------------------------------------------------------------------------------- /docprompt/schema/pipeline/node/collection.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Generic, List 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | from .typing import DocumentCollectionMetadata, DocumentNodeMetadata, PageNodeMetadata 6 | 7 | if TYPE_CHECKING: 8 | from .document import DocumentNode 9 | 10 | 11 | class DocumentCollection( 12 | BaseModel, 13 | Generic[DocumentCollectionMetadata, DocumentNodeMetadata, PageNodeMetadata], 14 | ): 15 | """ 16 | Represents a collection of documents with some common metadata 17 | """ 18 | 19 | document_nodes: List["DocumentNode[DocumentNodeMetadata, PageNodeMetadata]"] 20 | metadata: DocumentCollectionMetadata = Field(..., default_factory=dict) 21 | -------------------------------------------------------------------------------- /docprompt/schema/pipeline/node/image.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from typing import Generic 3 | 4 | from pydantic import Field, field_serializer, field_validator 5 | 6 | from docprompt.schema.pipeline.metadata import BaseMetadata 7 | 8 | from .base import BaseNode 9 | from .typing import ImageNodeMetadata 10 | 11 | 12 | class ImageNode(BaseNode, Generic[ImageNodeMetadata]): 13 | """ 14 | Represents a single image of any kind 15 | """ 16 | 17 | image: bytes 18 | 19 | metadata: ImageNodeMetadata = Field( 20 | description="Application-specific metadata for the image", 21 | default_factory=BaseMetadata, 22 | ) 23 | 24 | @field_serializer("image") 25 | def serialize_image(self, value): 26 | return base64.b64encode(value).decode("utf-8") 27 | 28 | @field_validator("image") 29 | @classmethod 30 | def validate_image(cls, value): 31 | if isinstance(value, bytes): 32 | return value 33 | 34 | return base64.b64decode(value) 35 | 36 | @property 37 | def pil_image(self): 38 | from io import BytesIO 
39 | 40 | from PIL import Image 41 | 42 | return Image.open(BytesIO(self.image)) 43 | 44 | @property 45 | def cv2_image(self): 46 | try: 47 | import cv2 48 | except ImportError: 49 | raise ImportError("OpenCV is required to use this property") 50 | 51 | try: 52 | import numpy as np 53 | except ImportError: 54 | raise ImportError("Numpy is required to use this property") 55 | 56 | return cv2.imdecode(np.frombuffer(self.image, np.uint8), cv2.IMREAD_COLOR) 57 | -------------------------------------------------------------------------------- /docprompt/schema/pipeline/node/page.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Union 2 | 3 | from pydantic import Field, PositiveInt 4 | 5 | from docprompt.schema.pipeline.metadata import BaseMetadata 6 | from docprompt.schema.pipeline.rasterizer import PageRasterizer 7 | 8 | from .base import BaseNode 9 | from .typing import PageNodeMetadata 10 | 11 | if TYPE_CHECKING: 12 | from docprompt.tasks.ocr.result import OcrPageResult 13 | 14 | from .document import DocumentNode 15 | 16 | 17 | class SimplePageNodeMetadata(BaseMetadata): 18 | """ 19 | A simple metadata class for a page node 20 | """ 21 | 22 | ocr_results: Optional["OcrPageResult"] = Field( 23 | None, description="The OCR results for the page" 24 | ) 25 | 26 | 27 | class PageNode(BaseNode, Generic[PageNodeMetadata]): 28 | """ 29 | Represents a single page in a document, with some metadata 30 | """ 31 | 32 | document: "DocumentNode" = Field(exclude=True, repr=False) 33 | page_number: PositiveInt = Field(description="The page number") 34 | metadata: Union[PageNodeMetadata, SimplePageNodeMetadata] = Field( 35 | description="Application-specific metadata for the page", 36 | default_factory=SimplePageNodeMetadata, 37 | ) 38 | extra: Dict[str, Any] = Field( 39 | description="Extra data that can be stored on the page node", 40 | default_factory=dict, 41 | ) 42 | 43 | @property 44 | def rasterizer(self): 45 | return PageRasterizer(self) 46 | 47 | @property 48 | def ocr_results(self) -> Optional["OcrPageResult"]: 49 | from docprompt.tasks.ocr.result import OcrPageResult 50 | 51 | return self.metadata.find_by_type(OcrPageResult) 52 | 53 | @ocr_results.setter 54 | def ocr_results(self, value): 55 | if not hasattr(self.metadata, "ocr_results"): 56 | raise AttributeError( 57 | "Page metadata does not have an `ocr_results` attribute" 58 | ) 59 | 60 | self.metadata.ocr_results = value 61 | 62 | def search( 63 | self, query: str, refine_to_words: bool = True, require_exact_match: bool = True 64 | ): 65 | return self.document.locator.search( 66 | query, 67 | page_number=self.page_number, 68 | refine_to_word=refine_to_words, 69 | require_exact_match=require_exact_match, 70 | ) 71 | 72 | def get_layout_aware_text(self, **kwargs) -> str: 73 | if not self.ocr_results: 74 | raise ValueError("Calculate OCR results before calling layout_aware_text") 75 | 76 | from docprompt.utils.layout import build_layout_aware_page_representation 77 | 78 | word_blocks = self.ocr_results.word_level_blocks 79 | 80 | line_blocks = self.ocr_results.line_level_blocks 81 | 82 | if not len(line_blocks): 83 | line_blocks = None 84 | 85 | return build_layout_aware_page_representation( 86 | word_blocks, line_blocks=line_blocks, **kwargs 87 | ) 88 | 89 | @property 90 | def layout_aware_text(self): 91 | return self.get_layout_aware_text() 92 | -------------------------------------------------------------------------------- 
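For orientation, here is a minimal sketch of how the node classes above fit together. Every call appears elsewhere in this repository (`load_document_node` from the package exports, `page_nodes` and `rasterizer.rasterize("default")` from `tasks/markerize/base.py`, `search` from `node/page.py`); the file path and query string are illustrative assumptions.

```python
from docprompt import load_document_node

# Assumed: the loader accepts a local PDF path such as the test fixture below
document_node = load_document_node("tests/fixtures/1.pdf")

# Rasterize a single page the same way the task providers do
first_page = document_node.page_nodes[0]
image_bytes = first_page.rasterizer.rasterize("default")

# Once OCR results are attached to the page metadata, provenance search is
# available per page (the query string here is hypothetical)
matches = first_page.search("Total Amount Due", refine_to_words=True)
```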
/docprompt/schema/pipeline/node/typing.py: -------------------------------------------------------------------------------- 1 | from typing import TypeVar 2 | 3 | from ..metadata import BaseMetadata 4 | 5 | ImageNodeMetadata = TypeVar("ImageNodeMetadata", bound=BaseMetadata) 6 | PageNodeMetadata = TypeVar("PageNodeMetadata", bound=BaseMetadata) 7 | DocumentNodeMetadata = TypeVar("DocumentNodeMetadata", bound=BaseMetadata) 8 | DocumentCollectionMetadata = TypeVar("DocumentCollectionMetadata", bound=BaseMetadata) 9 | -------------------------------------------------------------------------------- /docprompt/tasks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docprompt/tasks/__init__.py -------------------------------------------------------------------------------- /docprompt/tasks/capabilities.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class PageLevelCapabilities(str, Enum): 5 | """ 6 | Represents a capability that a provider can fulfill 7 | """ 8 | 9 | PAGE_RASTERIZATION = "page-rasterization" 10 | PAGE_LAYOUT_OCR = "page-layout-ocr" 11 | PAGE_TEXT_OCR = "page-text-ocr" 12 | PAGE_CLASSIFICATION = "page-classification" 13 | PAGE_MARKERIZATION = "page-markerization" 14 | PAGE_SEGMENTATION = "page-segmentation" 15 | PAGE_VQA = "page-vqa" 16 | PAGE_TABLE_IDENTIFICATION = "page-table-identification" 17 | PAGE_TABLE_EXTRACTION = "page-table-extraction" 18 | 19 | 20 | class DocumentLevelCapabilities(str, Enum): 21 | DOCUMENT_VQA = "multi-page-document-vqa" 22 | -------------------------------------------------------------------------------- /docprompt/tasks/classification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docprompt/tasks/classification/__init__.py -------------------------------------------------------------------------------- /docprompt/tasks/classification/anthropic.py: -------------------------------------------------------------------------------- 1 | """The anthropic implementation of page level classification.""" 2 | 3 | import re 4 | from typing import Iterable, List 5 | 6 | from pydantic import Field 7 | 8 | from docprompt.tasks.message import OpenAIComplexContent, OpenAIImageURL, OpenAIMessage 9 | from docprompt.utils import inference 10 | 11 | from .base import ( 12 | BaseClassificationProvider, 13 | BasePageClassificationOutputParser, 14 | ClassificationConfig, 15 | ClassificationOutput, 16 | ) 17 | 18 | 19 | def get_classification_system_prompt(input: ClassificationConfig) -> str: 20 | prompt_parts = [ 21 | "You are a classification expert. 
You are given a single page to perform a classification task on.\n" 22 | ] 23 | 24 | if input.instructions: 25 | prompt_parts.append(f"Task Instructions:\n{input.instructions}\n\n") 26 | 27 | if input.type == "binary": 28 | prompt_parts.append( 29 | 'You must classify the page with a binary label:\n"YES"/"NO"\n' 30 | ) 31 | else: 32 | classification_task = ( 33 | "all labels that apply" 34 | if input.type == "multi_label" 35 | else "one of the following" 36 | ) 37 | prompt_parts.append(f"Classify the page as {classification_task}:\n") 38 | for label in input.formatted_labels: 39 | prompt_parts.append(f"- {label}\n") 40 | prompt_parts.append( 41 | "\nThese are the only label values you may use when providing your classifications!\n" 42 | ) 43 | 44 | prompt_parts.append( 45 | "\nIt is crucial that your response is accurate and provides a valid answer using " 46 | ) 47 | if input.type == "multi_label": 48 | prompt_parts.append("the labels ") 49 | else: 50 | prompt_parts.append("one of the labels ") 51 | prompt_parts.append( 52 | "above. There are consequences for providing INVALID or INACCURATE labels.\n\n" 53 | ) 54 | 55 | prompt_parts.append( 56 | "Answer in the following format:\n\nReasoning: { your reasoning and analysis }\n" 57 | ) 58 | 59 | if input.type == "binary": 60 | prompt_parts.append('Answer: { "YES" or "NO" }\n') 61 | elif input.type == "single_label": 62 | prompt_parts.append('Answer: { "label-value" }\n') 63 | else: 64 | prompt_parts.append('Answer: { "label-value", "label-value", ... }\n') 65 | 66 | if input.confidence: 67 | prompt_parts.append("Confidence: { low, medium, high }\n") 68 | 69 | prompt_parts.append( 70 | "\nYou MUST ONLY use the labels provided and described above. Do not use ANY additional labels.\n" 71 | ) 72 | 73 | return "".join(prompt_parts).strip() 74 | 75 | 76 | class AnthropicPageClassificationOutputParser(BasePageClassificationOutputParser): 77 | """The output parser for the page classification system.""" 78 | 79 | def parse(self, text: str) -> ClassificationOutput: 80 | """Parse the results of the classification task.""" 81 | pattern = re.compile(r"Answer:\s*(?:['\"`]?)(.+?)(?:['\"`]?)\s*$", re.MULTILINE) 82 | match = pattern.search(text) 83 | 84 | result = self.resolve_match(match) 85 | 86 | if self.confidence: 87 | conf_pattern = re.compile(r"Confidence: (.+)") 88 | conf_match = conf_pattern.search(text) 89 | conf_result = self.resolve_confidence(conf_match) 90 | 91 | return ClassificationOutput( 92 | type=self.type, 93 | labels=result, 94 | score=conf_result, 95 | provider_name=self.name, 96 | ) 97 | 98 | return ClassificationOutput( 99 | type=self.type, labels=result, provider_name=self.name 100 | ) 101 | 102 | 103 | def _prepare_messages( 104 | document_images: Iterable[bytes], 105 | config: ClassificationConfig, 106 | ): 107 | messages = [] 108 | 109 | for image_bytes in document_images: 110 | messages.append( 111 | [ 112 | OpenAIMessage( 113 | role="user", 114 | content=[ 115 | OpenAIComplexContent( 116 | type="image_url", 117 | image_url=OpenAIImageURL(url=image_bytes), 118 | ), 119 | OpenAIComplexContent( 120 | type="text", 121 | text=get_classification_system_prompt(config), 122 | ), 123 | ], 124 | ), 125 | ] 126 | ) 127 | 128 | return messages 129 | 130 | 131 | class AnthropicClassificationProvider(BaseClassificationProvider): 132 | """The Anthropic implementation of unscored page classification.""" 133 | 134 | name = "anthropic" 135 | 136 | anthropic_model_name: str = Field("claude-3-haiku-20240307") 137 | 138 | async def _ainvoke( 139 | 
self, input: Iterable[bytes], config: ClassificationConfig = None, **kwargs 140 | ) -> List[ClassificationOutput]: 141 | messages = _prepare_messages(input, config) 142 | 143 | parser = AnthropicPageClassificationOutputParser.from_task_input( 144 | config, provider_name=self.name 145 | ) 146 | 147 | model_name = kwargs.pop("model_name", self.anthropic_model_name) 148 | completions = await inference.run_batch_inference_anthropic( 149 | model_name, messages, **kwargs 150 | ) 151 | return [parser.parse(res) for res in completions] 152 | -------------------------------------------------------------------------------- /docprompt/tasks/credentials.py: -------------------------------------------------------------------------------- 1 | """The credentials module defines a simple model schema for storing credentials.""" 2 | 3 | import os 4 | from typing import Dict, Mapping, Optional 5 | 6 | from pydantic import BaseModel, Field, HttpUrl, SecretStr, model_validator 7 | from typing_extensions import Self 8 | 9 | 10 | class BaseCredentials(BaseModel): 11 | """The base credentials model.""" 12 | 13 | @property 14 | def kwargs(self) -> Dict[str, str]: 15 | """Return the credentials as a dictionary with secrets exposed.""" 16 | data = self.model_dump(exclude_none=True) 17 | for key, value in data.items(): 18 | if isinstance(value, SecretStr): 19 | data[key] = value.get_secret_value() 20 | return data 21 | 22 | 23 | class APIKeyCredential(BaseCredentials): 24 | """The API key credential model.""" 25 | 26 | api_key: SecretStr 27 | 28 | def __init__(self, environ_path: Optional[str] = None, **data): 29 | api_key = data.get("api_key", None) 30 | if api_key is None and environ_path: 31 | api_key = os.environ.get(environ_path, None) 32 | data["api_key"] = api_key 33 | 34 | super().__init__(**data) 35 | 36 | 37 | class GenericOpenAICredentials(APIKeyCredential): 38 | """Credentials that are common for OpenAI API requests.""" 39 | 40 | base_url: Optional[HttpUrl] = Field(None) 41 | timeout: Optional[int] = Field(None) 42 | max_retries: Optional[int] = Field(None) 43 | 44 | default_headers: Optional[Mapping[str, str]] = Field(None) 45 | default_query_params: Optional[Mapping[str, object]] = Field(None) 46 | 47 | 48 | class AWSCredentials(BaseCredentials): 49 | """The AWS credentials model.""" 50 | 51 | aws_access_key_id: Optional[SecretStr] = Field(None) 52 | aws_secret_access_key: Optional[SecretStr] = Field(None) 53 | aws_session_token: Optional[SecretStr] = Field(None) 54 | aws_region: Optional[str] = Field(None) 55 | 56 | def __init__(self, **data): 57 | aws_access_key_id = data.get( 58 | "aws_access_key_id", os.environ.get("AWS_ACCESS_KEY_ID", None) 59 | ) 60 | aws_secret_access_key = data.get( 61 | "aws_secret_access_key", os.environ.get("AWS_SECRET_ACCESS_KEY", None) 62 | ) 63 | aws_session_token = data.get( 64 | "aws_session_token", os.environ.get("AWS_SESSION_TOKEN", None) 65 | ) 66 | aws_region = data.get("aws_region", os.environ.get("AWS_DEFAULT_REGION", None)) 67 | 68 | super().__init__( 69 | aws_access_key_id=aws_access_key_id, 70 | aws_secret_access_key=aws_secret_access_key, 71 | aws_session_token=aws_session_token, 72 | aws_region=aws_region, 73 | ) 74 | 75 | @model_validator(mode="after") 76 | def _validate_aws_credentials(self) -> Self: 77 | """Ensure the provided AWS credentials are valid.""" 78 | 79 | key_pair_is_set = self.aws_access_key_id and self.aws_secret_access_key 80 | 81 | if not key_pair_is_set and not self.aws_session_token: 82 | raise ValueError( 83 | "You must provide either an 
AWS session token or an access key and secret key." 84 | ) 85 | 86 | if key_pair_is_set and not self.aws_region: 87 | raise ValueError( 88 | "You must provide an AWS region when using an access key and secret key." 89 | ) 90 | 91 | if key_pair_is_set and self.aws_session_token: 92 | raise ValueError( 93 | "You cannot provide both an AWS session token and an access key and secret key." 94 | ) 95 | 96 | return self 97 | 98 | 99 | class GCPServiceFileCredentials(BaseCredentials): 100 | """The GCP service account credentials model.""" 101 | 102 | service_account_info: Optional[Dict[str, str]] = Field(None) 103 | service_account_file: Optional[str] = Field(None) 104 | 105 | def __init__(self, **data): 106 | service_account_info = data.get("service_account_info", None) 107 | service_account_file = data.get( 108 | "service_account_file", os.environ.get("GCP_SERVICE_ACCOUNT_FILE", None) 109 | ) 110 | 111 | super().__init__( 112 | service_account_info=service_account_info, 113 | service_account_file=service_account_file, 114 | ) 115 | 116 | @model_validator(mode="after") 117 | def _validate_gcp_credentials(self) -> Self: 118 | """Ensure the provided GCP credentials are valid.""" 119 | if self.service_account_info is None and self.service_account_file is None: 120 | raise ValueError( 121 | "You must provide either service_account_info or service_account_file. You may set the `GCP_SERVICE_ACCOUNT_FILE` environment variable to the path of the service account file." 122 | ) 123 | if ( 124 | self.service_account_info is not None 125 | and self.service_account_file is not None 126 | ): 127 | raise ValueError( 128 | "You must provide either service_account_info or service_account_file, not both" 129 | ) 130 | return self 131 | -------------------------------------------------------------------------------- /docprompt/tasks/markerize/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docprompt/tasks/markerize/__init__.py -------------------------------------------------------------------------------- /docprompt/tasks/markerize/anthropic.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, List, Optional 2 | 3 | from bs4 import BeautifulSoup 4 | from pydantic import Field 5 | 6 | from docprompt.tasks.message import OpenAIComplexContent, OpenAIImageURL, OpenAIMessage 7 | from docprompt.utils import inference 8 | 9 | from .base import BaseMarkerizeProvider, MarkerizeResult 10 | 11 | _HUMAN_MESSAGE_PROMPT = """ 12 | Convert the image into markdown, preserving the overall layout and style of the page. \ 13 | Use the appropriate headings for different sections. Preserve bolded and italicized text. \ 14 | Include ALL the text on the page. 15 | 16 | You ALWAYS respond by wrapping the markdown in <md> tags. 
17 | """.strip() 18 | 19 | 20 | def ensure_single_root(xml_data: str) -> str: 21 | """Ensure the XML data has a single root element.""" 22 | if not xml_data.strip().startswith("<md>") and not xml_data.strip().endswith( 23 | "</md>" 24 | ): 25 | return f"<md>{xml_data}</md>" 26 | return xml_data 27 | 28 | 29 | def _parse_result(raw_markdown: str) -> Optional[str]: 30 | raw_markdown = ensure_single_root(raw_markdown) 31 | soup = BeautifulSoup(raw_markdown, "html.parser") 32 | 33 | md = soup.find("md") 34 | 35 | return md.text.strip() if md else "" # TODO Fix bad extractions 36 | 37 | 38 | def _prepare_messages( 39 | document_images: Iterable[bytes], 40 | start: Optional[int] = None, 41 | stop: Optional[int] = None, 42 | ): 43 | messages = [] 44 | 45 | for image_bytes in document_images: 46 | messages.append( 47 | [ 48 | OpenAIMessage( 49 | role="user", 50 | content=[ 51 | OpenAIComplexContent( 52 | type="image_url", 53 | image_url=OpenAIImageURL(url=image_bytes), 54 | ), 55 | OpenAIComplexContent(type="text", text=_HUMAN_MESSAGE_PROMPT), 56 | ], 57 | ), 58 | ] 59 | ) 60 | 61 | return messages 62 | 63 | 64 | class AnthropicMarkerizeProvider(BaseMarkerizeProvider): 65 | name = "anthropic" 66 | 67 | anthropic_model_name: str = Field("claude-3-haiku-20240307") 68 | 69 | async def _ainvoke( 70 | self, input: Iterable[bytes], config: Optional[None] = None, **kwargs 71 | ) -> List[MarkerizeResult]: 72 | messages = _prepare_messages(input) 73 | 74 | model_name = kwargs.pop("model_name", self.anthropic_model_name) 75 | completions = await inference.run_batch_inference_anthropic( 76 | model_name, messages, **kwargs 77 | ) 78 | 79 | return [ 80 | MarkerizeResult(raw_markdown=_parse_result(x), provider_name=self.name) 81 | for x in completions 82 | ] 83 | -------------------------------------------------------------------------------- /docprompt/tasks/markerize/base.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from docprompt.schema.pipeline.node.document import DocumentNode 4 | from docprompt.tasks.base import AbstractPageTaskProvider, BasePageResult 5 | 6 | from ..capabilities import PageLevelCapabilities 7 | 8 | 9 | class MarkerizeResult(BasePageResult): 10 | task_name = "markerize" 11 | raw_markdown: str 12 | 13 | 14 | class BaseMarkerizeProvider(AbstractPageTaskProvider[bytes, None, MarkerizeResult]): 15 | capabilities = [PageLevelCapabilities.PAGE_MARKERIZATION] 16 | 17 | class Meta: 18 | abstract = True 19 | 20 | def process_document_node( 21 | self, 22 | document_node: "DocumentNode", 23 | task_config: Optional[None] = None, 24 | start: Optional[int] = None, 25 | stop: Optional[int] = None, 26 | contribute_to_document: bool = True, 27 | **kwargs, 28 | ): 29 | raster_bytes = [] 30 | for page_number in range(start or 1, (stop or len(document_node)) + 1): 31 | image_bytes = document_node.page_nodes[ 32 | page_number - 1 33 | ].rasterizer.rasterize("default") 34 | raster_bytes.append(image_bytes) 35 | 36 | # TODO: This is a somewhat dangerous way of requiring these kwargs to be drilled 37 | # through, potentially a decorator solution to be had here 38 | kwargs = {**self._default_invoke_kwargs, **kwargs} 39 | results = self._invoke(raster_bytes, config=task_config, **kwargs) 40 | 41 | return { 42 | i: res 43 | for i, res in zip( 44 | range(start or 1, (stop or len(document_node)) + 1), results 45 | ) 46 | } 47 | -------------------------------------------------------------------------------- /docprompt/tasks/message.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | The core primitives for any language model interfacing. Docprompt uses these for the prompt garden, but 3 | supports free conversion to and from these types from other libraries. 4 | """ 5 | 6 | from typing import List, Literal, Optional, Union 7 | 8 | from pydantic import BaseModel, model_validator 9 | 10 | 11 | def _ensure_png_base64_prefix(base64_string: str): 12 | prefix = "data:image/png;base64," 13 | if base64_string.startswith(prefix): 14 | return base64_string 15 | else: 16 | return prefix + base64_string 17 | 18 | 19 | def _strip_png_base64_prefix(base64_string: str): 20 | prefix = "data:image/png;base64," 21 | if base64_string.startswith(prefix): 22 | return base64_string[len(prefix) :] 23 | else: 24 | return base64_string 25 | 26 | 27 | class OpenAIImageURL(BaseModel): 28 | url: str 29 | 30 | 31 | class OpenAIComplexContent(BaseModel): 32 | type: Literal["text", "image_url"] 33 | text: Optional[str] = None 34 | image_url: Optional[OpenAIImageURL] = None 35 | 36 | @model_validator(mode="after") 37 | def validate_content(cls, v): 38 | if v.type == "text" and v.text is None: 39 | raise ValueError("Text content must be provided when type is 'text'") 40 | if v.type == "image_url" and v.image_url is None: 41 | raise ValueError( 42 | "Image URL content must be provided when type is 'image_url'" 43 | ) 44 | 45 | if v.text is not None and v.image_url is not None: 46 | raise ValueError("Only one of text or image_url can be provided") 47 | 48 | return v 49 | 50 | def to_anthropic_message(self): 51 | if self.type == "text": 52 | return {"type": "text", "text": self.text} 53 | elif self.type == "image_url": 54 | return { 55 | "type": "image", 56 | "source": { 57 | "data": _strip_png_base64_prefix(self.image_url.url), 58 | "media_type": "image/png", 59 | "type": "base64", 60 | }, 61 | } 62 | else: 63 | raise ValueError(f"Invalid content type: {self.type}") 64 | 65 | 66 | class OpenAIMessage(BaseModel): 67 | role: Literal["system", "user", "assistant"] 68 | content: Union[str, List[OpenAIComplexContent]] 69 | 70 | def to_langchain_message(self): 71 | try: 72 | from langchain.schema import AIMessage, HumanMessage, SystemMessage 73 | except ImportError: 74 | raise ImportError( 75 | "Could not import langchain.schema. Install with `docprompt[langchain]`" 76 | ) 77 | 78 | role_mapping = { 79 | "system": SystemMessage, 80 | "user": HumanMessage, 81 | "assistant": AIMessage, 82 | } 83 | 84 | dumped = self.model_dump(mode="json", exclude_unset=True, exclude_none=True) 85 | 86 | return role_mapping[self.role](content=dumped["content"]) 87 | 88 | def to_openai(self): 89 | return self.model_dump(mode="json", exclude_unset=True, exclude_none=True) 90 | 91 | def to_llamaindex_chat_message(self): 92 | try: 93 | from llama_index.core.base.llms.types import ChatMessage, MessageRole 94 | except ImportError: 95 | raise ImportError( 96 | "Could not import llama_index.core. Install with `docprompt[llamaindex]`" 97 | ) 98 | 99 | role_mapping = { 100 | "system": MessageRole.SYSTEM, 101 | "user": MessageRole.USER, 102 | "assistant": MessageRole.ASSISTANT, 103 | } 104 | 105 | dumped = self.model_dump(mode="json", exclude_unset=True, exclude_none=True) 106 | 107 | return ChatMessage.from_str( 108 | content=dumped["content"], role=role_mapping[self.role] 109 | ) 110 | 111 | @classmethod 112 | def from_image_uri(cls, image_uri: str) -> "OpenAIMessage": 113 | """Create an image message from a URI. 
114 | 115 | Args: 116 | image_uri: The URI of the image. The resulting message always 117 | uses the "user" role. 118 | """ 119 | image_url = OpenAIImageURL(url=image_uri) 120 | content = OpenAIComplexContent(type="image_url", image_url=image_url) 121 | message = cls(role="user", content=[content]) 122 | return message 123 | -------------------------------------------------------------------------------- /docprompt/tasks/ocr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docprompt/tasks/ocr/__init__.py -------------------------------------------------------------------------------- /docprompt/tasks/ocr/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import TYPE_CHECKING, Dict, Optional, Union 3 | 4 | from docprompt.schema.document import PdfDocument 5 | from docprompt.schema.pipeline import DocumentNode 6 | from docprompt.tasks.base import AbstractPageTaskProvider 7 | from docprompt.tasks.ocr.result import OcrPageResult 8 | 9 | if TYPE_CHECKING: 10 | from docprompt.schema.pipeline import DocumentNode 11 | 12 | 13 | ImageBytes = bytes 14 | 15 | 16 | class BaseOCRProvider( 17 | AbstractPageTaskProvider[Union[PdfDocument, ImageBytes], None, OcrPageResult] 18 | ): 19 | def _populate_ocr_results( 20 | self, 21 | document_node: "DocumentNode", 22 | results: Dict[int, OcrPageResult], 23 | add_images_to_raster_cache: bool = False, 24 | raster_cache_key: str = "default", 25 | ) -> None: 26 | for page_number, result in results.items(): 27 | result.contribute_to_document_node( 28 | document_node, 29 | page_number=page_number, 30 | add_images_to_raster_cache=add_images_to_raster_cache, 31 | raster_cache_key=raster_cache_key, 32 | ) 33 | 34 | @abstractmethod 35 | def process_document_node( 36 | self, 37 | document_node: "DocumentNode", 38 | task_config: Optional[None] = None, 39 | start: Optional[int] = None, 40 | stop: Optional[int] = None, 41 | contribute_to_document: bool = True, 42 | **kwargs, 43 | ) -> Dict[int, OcrPageResult]: ... 
44 | -------------------------------------------------------------------------------- /docprompt/tasks/ocr/result.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | from typing import Any, Dict, List, Optional 3 | 4 | from pydantic import Field 5 | 6 | from docprompt.schema.layout import TextBlock 7 | from docprompt.schema.pipeline.node.document import DocumentNode 8 | from docprompt.tasks.base import BasePageResult 9 | 10 | 11 | class OcrPageResult(BasePageResult): 12 | page_text: str = Field(description="The text for the entire page in reading order") 13 | 14 | word_level_blocks: List[TextBlock] = Field( 15 | default_factory=list, 16 | description="The provider-sourced words for the page", 17 | repr=False, 18 | ) 19 | line_level_blocks: List[TextBlock] = Field( 20 | default_factory=list, 21 | description="The provider-sourced lines for the page", 22 | repr=False, 23 | ) 24 | block_level_blocks: List[TextBlock] = Field( 25 | default_factory=list, 26 | description="The provider-sourced blocks for the page", 27 | repr=False, 28 | ) 29 | 30 | raster_image: Optional[bytes] = Field( 31 | default=None, 32 | description="The rasterized image of the page used in OCR", 33 | repr=False, 34 | ) 35 | 36 | extra: Optional[Dict[str, Any]] = Field(default_factory=dict) 37 | 38 | task_name = "ocr" 39 | 40 | @property 41 | def pil_image(self): 42 | if not self.raster_image: 43 | return None 44 | from PIL import Image 45 | 46 | return Image.open(BytesIO(self.raster_image)) 47 | 48 | @property 49 | def words(self): 50 | return self.word_level_blocks 51 | 52 | @property 53 | def lines(self): 54 | return self.line_level_blocks 55 | 56 | @property 57 | def blocks(self): 58 | return self.block_level_blocks 59 | 60 | def contribute_to_document_node( 61 | self, 62 | document_node: DocumentNode, 63 | page_number: Optional[int] = None, 64 | add_images_to_raster_cache: bool = False, 65 | raster_cache_key: str = "default", 66 | **kwargs, 67 | ) -> None: 68 | if not page_number: 69 | raise ValueError("Page number must be provided for page level results") 70 | 71 | page_node = document_node.page_nodes[page_number - 1] 72 | if hasattr(page_node.metadata, "ocr_results"): 73 | page_node.metadata.ocr_results = self.model_copy( 74 | update={"raster_image": None} 75 | ) 76 | else: 77 | super().contribute_to_document_node(document_node, page_number=page_number) 78 | 79 | if self.raster_image is not None and add_images_to_raster_cache: 80 | document_node.rasterizer.cache.set_image_for_page( 81 | key=raster_cache_key, 82 | page_number=page_number, 83 | image_bytes=self.raster_image, 84 | ) 85 | -------------------------------------------------------------------------------- /docprompt/tasks/parser.py: -------------------------------------------------------------------------------- 1 | """The base output parser that seeks to mimic the langchain implementation.""" 2 | 3 | from abc import abstractmethod 4 | from typing import TypeVar 5 | 6 | from pydantic import BaseModel 7 | from typing_extensions import Generic 8 | 9 | TTaskInput = TypeVar("TTaskInput", bound=BaseModel) 10 | TTaskOutput = TypeVar("TTaskOutput", bound=BaseModel) 11 | 12 | 13 | class BaseOutputParser(BaseModel, Generic[TTaskInput, TTaskOutput]): 14 | """The abstract base class for task-specific output parsers.""" 15 | 16 | @abstractmethod 17 | def from_task_input( 18 | cls, task_input: TTaskInput 19 | ) -> "BaseOutputParser[TTaskInput, TTaskOutput]": 20 | """Create an output parser from the task input.""" 
21 | 22 | @abstractmethod 23 | def parse(self, text: str) -> TTaskOutput: 24 | """Parse the results of the classification task.""" 25 | -------------------------------------------------------------------------------- /docprompt/tasks/result.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from collections.abc import MutableMapping 3 | from datetime import datetime 4 | from typing import TYPE_CHECKING, ClassVar, Dict, Generic, Optional, TypeVar 5 | 6 | from pydantic import BaseModel, Field 7 | 8 | if TYPE_CHECKING: 9 | from docprompt.schema.pipeline import DocumentNode 10 | 11 | 12 | class BaseResult(BaseModel): 13 | provider_name: str = Field( 14 | description="The name of the provider which produced the result" 15 | ) 16 | when: datetime = Field( 17 | default_factory=datetime.now, description="The time the result was produced" 18 | ) 19 | 20 | task_name: ClassVar[str] 21 | 22 | @property 23 | def task_key(self): 24 | return f"{self.provider_name}_{self.task_name}" 25 | 26 | @abstractmethod 27 | def contribute_to_document_node( 28 | self, document_node: "DocumentNode", **kwargs 29 | ) -> None: 30 | """ 31 | Contribute this task result to the document node or a specific page node. 32 | 33 | :param document_node: The DocumentNode to contribute to 34 | :param page_number: If provided, contribute to a specific page. If None, contribute to the document. 35 | """ 36 | 37 | 38 | class BaseDocumentResult(BaseResult): 39 | def contribute_to_document_node( 40 | self, document_node: "DocumentNode", **kwargs 41 | ) -> None: 42 | document_node.metadata.task_results[self.task_key] = self 43 | 44 | 45 | class BasePageResult(BaseResult): 46 | def contribute_to_document_node( 47 | self, document_node: "DocumentNode", page_number: Optional[int] = None, **kwargs 48 | ) -> None: 49 | assert ( 50 | page_number is not None 51 | ), "Page number must be provided for page level results" 52 | assert ( 53 | 0 < page_number <= len(document_node) 54 | ), "Page number must be less than or equal to the number of pages in the document" 55 | 56 | page_node = document_node.page_nodes[page_number - 1] 57 | page_node.metadata.task_results[self.task_key] = self.model_copy() 58 | 59 | 60 | TTaskInput = TypeVar("TTaskInput") # What invoke requires 61 | TTaskConfig = TypeVar("TTaskConfig") # Task specific config like classification labels 62 | PageTaskResult = TypeVar("PageTaskResult", bound=BasePageResult) 63 | DocumentTaskResult = TypeVar("DocumentTaskResult", bound=BaseDocumentResult) 64 | PageOrDocumentTaskResult = TypeVar("PageOrDocumentTaskResult", bound=BaseResult) 65 | 66 | 67 | class ResultContainer(BaseModel, MutableMapping, Generic[PageOrDocumentTaskResult]): 68 | results: Dict[str, PageOrDocumentTaskResult] = Field( 69 | description="The results of the task", default_factory=dict 70 | ) 71 | 72 | @property 73 | def result(self): 74 | return next(iter(self.results.values()), None) 75 | 76 | def __setitem__(self, key, value): 77 | if key in self.results: 78 | raise ValueError(f"Result with key {key} already exists") 79 | 80 | self.results[key] = value 81 | 82 | def __delitem__(self, key): 83 | del self.results[key] 84 | 85 | def __getitem__(self, key): 86 | return self.results[key] 87 | 88 | def __iter__(self): 89 | return iter(self.results) 90 | 91 | def __len__(self): 92 | return len(self.results) 93 | 94 | def __contains__(self, item): 95 | return item in self.results 96 | 97 | def __bool__(self): 98 | return bool(self.results) 99 | 
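100 | 101 | # --- Illustrative sketch (added commentary; not part of the library source) --- 102 | # A hypothetical BasePageResult subclass shows how the pieces above fit together: 103 | # `task_name` plus `provider_name` form the `task_key` under which a result is 104 | # stored in a page node's ResultContainer. 105 | # 106 | #     class WordCountResult(BasePageResult): 107 | #         task_name = "word_count" 108 | #         word_count: int 109 | # 110 | #     result = WordCountResult(provider_name="example", word_count=42) 111 | #     result.contribute_to_document_node(document_node, page_number=1) 112 | #     # The page node's container then holds it under the key "example_word_count" 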
-------------------------------------------------------------------------------- /docprompt/tasks/table_extraction/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docprompt/tasks/table_extraction/__init__.py -------------------------------------------------------------------------------- /docprompt/tasks/table_extraction/anthropic.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Iterable, List, Optional, Union 3 | 4 | from bs4 import BeautifulSoup, Tag 5 | from pydantic import Field 6 | 7 | from docprompt.tasks.message import OpenAIComplexContent, OpenAIImageURL, OpenAIMessage 8 | from docprompt.utils import inference 9 | 10 | from .base import BaseTableExtractionProvider 11 | from .schema import ( 12 | ExtractedTable, 13 | TableCell, 14 | TableExtractionPageResult, 15 | TableHeader, 16 | TableRow, 17 | ) 18 | 19 | SYSTEM_PROMPT = """ 20 | You are given an image. Identify and extract all tables from the document. 21 | 22 | For each table, respond in the following format: 23 | 24 | <table> 25 | <title>(value)</title> 26 | <headers> 27 | <header>(value)</header> 28 | ... 29 | </headers> 30 | <rows> 31 | <row> 32 | <column>(value)</column> 33 | ... 34 | </row> 35 | ... 36 | </rows> 37 | </table> 
38 | """.strip() 39 | 40 | 41 | def _title_from_tree(tree: Tag) -> Union[str, None]: 42 | title = tree.find("title") 43 | if title is not None: 44 | return title.text 45 | return None 46 | 47 | 48 | def _headers_from_tree(tree: Tag) -> List[TableHeader]: 49 | headers = tree.find("headers") 50 | if headers is not None: 51 | return [ 52 | TableHeader(text=header.text or "") for header in headers.find_all("header") 53 | ] 54 | return [] 55 | 56 | 57 | def _rows_from_tree(tree: Tag) -> List[TableRow]: 58 | rows = tree.find("rows") 59 | if rows is not None: 60 | return [ 61 | TableRow( 62 | cells=[ 63 | TableCell(text=cell.text or "") for cell in row.find_all("column") 64 | ] 65 | ) 66 | for row in rows.find_all("row") 67 | ] 68 | return [] 69 | 70 | 71 | def _find_start_indices(s: str, sub: str) -> List[int]: 72 | return [m.start() for m in re.finditer(sub, s)] 73 | 74 | 75 | def _find_end_indices(s: str, sub: str) -> List[int]: 76 | return [m.end() for m in re.finditer(sub, s)] 77 | 78 | 79 | def parse_response(response: str, **kwargs) -> TableExtractionPageResult: 80 | table_start_indices = _find_start_indices(response, "<table>") 81 | table_end_indices = _find_end_indices(response, "</table>") 
82 | 83 | tables: List[ExtractedTable] = [] 84 | provider_name = kwargs.pop("provider_name", "anthropic") 85 | 86 | for table_start, table_end in zip(table_start_indices, table_end_indices): 87 | table_str = response[table_start:table_end] 88 | 89 | soup = BeautifulSoup(table_str, "html.parser") 90 | 91 | table_element = soup.find("table") 92 | 93 | title = _title_from_tree(table_element) 94 | headers = _headers_from_tree(table_element) 95 | rows = _rows_from_tree(table_element) 96 | 97 | tables.append(ExtractedTable(title=title, headers=headers, rows=rows)) 98 | 99 | result = TableExtractionPageResult(tables=tables, provider_name=provider_name) 100 | return result 101 | 102 | 103 | def _prepare_messages( 104 | document_images: Iterable[bytes], 105 | start: Optional[int] = None, 106 | stop: Optional[int] = None, 107 | ): 108 | messages = [] 109 | 110 | for image_bytes in document_images: 111 | messages.append( 112 | [ 113 | OpenAIMessage( 114 | role="user", 115 | content=[ 116 | OpenAIComplexContent( 117 | type="image_url", 118 | image_url=OpenAIImageURL(url=image_bytes), 119 | ), 120 | OpenAIComplexContent(type="text", text=SYSTEM_PROMPT), 121 | ], 122 | ), 123 | ] 124 | ) 125 | 126 | return messages 127 | 128 | 129 | class AnthropicTableExtractionProvider(BaseTableExtractionProvider): 130 | name = "anthropic" 131 | 132 | anthropic_model_name: str = Field("claude-3-haiku-20240307") 133 | 134 | async def _ainvoke( 135 | self, input: Iterable[bytes], config: Optional[None] = None, **kwargs 136 | ) -> List[TableExtractionPageResult]: 137 | messages = _prepare_messages(input) 138 | 139 | model_name = kwargs.pop("model_name", self.anthropic_model_name) 140 | completions = await inference.run_batch_inference_anthropic( 141 | model_name, messages, **kwargs 142 | ) 143 | 144 | return [parse_response(x, provider_name=self.name) for x in completions] 145 | -------------------------------------------------------------------------------- /docprompt/tasks/table_extraction/base.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from docprompt import DocumentNode 4 | from docprompt.tasks.base import AbstractPageTaskProvider 5 | from docprompt.tasks.capabilities import PageLevelCapabilities 6 | 7 | from .schema import TableExtractionPageResult 8 | 9 | 10 | class BaseTableExtractionProvider( 11 | AbstractPageTaskProvider[bytes, None, TableExtractionPageResult] 12 | ): 13 | capabilities = [ 14 | PageLevelCapabilities.PAGE_TABLE_EXTRACTION, 15 | PageLevelCapabilities.PAGE_TABLE_IDENTIFICATION, 16 | ] 17 | 18 | class Meta: 19 | abstract = True 20 | 21 | def process_document_node( 22 | self, 23 | document_node: DocumentNode, 24 | task_config: Optional[None] = None, 25 | start: Optional[int] = None, 26 | stop: Optional[int] = None, 27 | contribute_to_document: bool = True, 28 | **kwargs, 29 | ): 30 | raster_bytes = [] 31 | for page_number in range(start or 1, (stop or len(document_node)) + 1): 32 | image_bytes = document_node.page_nodes[ 33 | page_number - 1 34 | ].rasterizer.rasterize("default") 35 | raster_bytes.append(image_bytes) 36 | 37 | # This will be a list of extracted tables?? 
38 | results = self._invoke(raster_bytes, config=task_config, **kwargs) 39 | 40 | return { 41 | i: res 42 | for i, res in zip( 43 | range(start or 1, (stop or len(document_node)) + 1), results 44 | ) 45 | } 46 | -------------------------------------------------------------------------------- /docprompt/tasks/table_extraction/schema.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | from docprompt.schema.layout import NormBBox 6 | from docprompt.tasks.base import BasePageResult 7 | 8 | 9 | class TableHeader(BaseModel): 10 | text: str 11 | bbox: Optional[NormBBox] = None 12 | 13 | 14 | class TableCell(BaseModel): 15 | text: str 16 | bbox: Optional[NormBBox] = None 17 | 18 | 19 | class TableRow(BaseModel): 20 | cells: List[TableCell] = Field(default_factory=list) 21 | bbox: Optional[NormBBox] = None 22 | 23 | 24 | class ExtractedTable(BaseModel): 25 | title: Optional[str] = None 26 | bbox: Optional[NormBBox] = None 27 | 28 | headers: List[TableHeader] = Field(default_factory=list) 29 | rows: List[TableRow] = Field(default_factory=list) 30 | 31 | def to_markdown_string(self) -> str: 32 | markdown = "" 33 | 34 | # Add title if present 35 | if self.title: 36 | markdown += f"# {self.title}\n\n" 37 | 38 | # Create header row 39 | header_row = "|" + "|".join(header.text for header in self.headers) + "|\n" 40 | markdown += header_row 41 | 42 | # Create separator row 43 | separator_row = "|" + "|".join("---" for _ in self.headers) + "|\n" 44 | markdown += separator_row 45 | 46 | # Create data rows 47 | for row in self.rows: 48 | data_row = "|" + "|".join(cell.text for cell in row.cells) + "|\n" 49 | markdown += data_row 50 | 51 | return markdown.strip() 52 | 53 | 54 | class TableExtractionPageResult(BasePageResult): 55 | tables: List[ExtractedTable] = Field(default_factory=list) 56 | 57 | task_name = "table_extraction" 58 | -------------------------------------------------------------------------------- /docprompt/tasks/util.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from contextvars import ContextVar 3 | from typing import Any, Dict, Iterator 4 | 5 | _init_context_var = ContextVar("_init_context_var", default=None) 6 | 7 | 8 | @contextmanager 9 | def init_context(value: Dict[str, Any]) -> Iterator[None]: 10 | token = _init_context_var.set(value) 11 | try: 12 | yield 13 | finally: 14 | _init_context_var.reset(token) 15 | -------------------------------------------------------------------------------- /docprompt/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .date_extraction import extract_dates_from_text 2 | from .util import ( 3 | get_page_count, 4 | hash_from_bytes, 5 | is_pdf, 6 | load_document, 7 | load_document_node, 8 | load_documents, 9 | load_pdf_document, 10 | load_pdf_documents, 11 | ) 12 | 13 | __all__ = [ 14 | "get_page_count", 15 | "is_pdf", 16 | "load_pdf_document", 17 | "load_pdf_documents", 18 | "load_document", 19 | "load_documents", 20 | "hash_from_bytes", 21 | "extract_dates_from_text", 22 | "load_document_node", 23 | ] 24 | -------------------------------------------------------------------------------- /docprompt/utils/compressor.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | from typing import Literal 3 | 4 | from docprompt._exec.ghostscript 
import compress_pdf_to_bytes 5 | 6 | 7 | def compress_pdf_bytes( 8 | file_bytes: bytes, *, compression: Literal["jpeg", "lossless"] = "jpeg" 9 | ) -> bytes: 10 | with tempfile.NamedTemporaryFile(suffix=".pdf") as temp_file: 11 | temp_file.write(file_bytes) 12 | temp_file.flush() 13 | 14 | return compress_pdf_to_bytes(temp_file.name, compression=compression) 15 | -------------------------------------------------------------------------------- /docprompt/utils/date_extraction.py: -------------------------------------------------------------------------------- 1 | import re 2 | from datetime import date, datetime 3 | from typing import List, Tuple 4 | 5 | DateFormatsType = List[Tuple[re.Pattern, str]] 6 | 7 | default_date_formats = [ 8 | # Pre-compile regex patterns for efficiency 9 | # YYYY-MM-DD 10 | ( 11 | re.compile(r"\b((19|20)\d\d[-](0?[1-9]|1[012])[-](0?[1-9]|[12][0-9]|3[01]))\b"), 12 | "%Y-%m-%d", 13 | ), 14 | # YY-MM-DD 15 | ( 16 | re.compile(r"\b((\d\d)[-](0?[1-9]|1[012])[-](0?[1-9]|[12][0-9]|3[01]))\b"), 17 | "%y-%m-%d", 18 | ), 19 | # MM-DD-YYYY 20 | ( 21 | re.compile(r"\b((0?[1-9]|1[012])[-](0?[1-9]|[12][0-9]|3[01])[-](19|20)\d\d)\b"), 22 | "%m-%d-%Y", 23 | ), 24 | # MM-DD-YY 25 | ( 26 | re.compile(r"\b((0?[1-9]|1[012])[-](0?[1-9]|[12][0-9]|3[01])[-](\d\d))\b"), 27 | "%m-%d-%y", 28 | ), 29 | # DD-MM-YYYY 30 | ( 31 | re.compile(r"\b((0?[1-9]|[12][0-9]|3[01])[-](0?[1-9]|1[012])[-](19|20)\d\d)\b"), 32 | "%d-%m-%Y", 33 | ), 34 | # DD-MM-YY 35 | ( 36 | re.compile(r"\b((0?[1-9]|[12][0-9]|3[01])[-](0?[1-9]|1[012])[-](\d\d))\b"), 37 | "%d-%m-%y", 38 | ), 39 | # YYYY/MM/DD 40 | ( 41 | re.compile(r"\b((19|20)\d\d[/](0?[1-9]|1[012])[/](0?[1-9]|[12][0-9]|3[01]))\b"), 42 | "%Y/%m/%d", 43 | ), 44 | # YY/MM/DD 45 | ( 46 | re.compile(r"\b((\d\d)[/](0?[1-9]|1[012])[/](0?[1-9]|[12][0-9]|3[01]))\b"), 47 | "%y/%m/%d", 48 | ), 49 | # MM/DD/YYYY 50 | ( 51 | re.compile(r"\b((0?[1-9]|1[012])[/](0?[1-9]|[12][0-9]|3[01])[/](19|20)\d\d)\b"), 52 | "%m/%d/%Y", 53 | ), 54 | # MM/DD/YY 55 | ( 56 | re.compile(r"\b((0?[1-9]|1[012])[/](0?[1-9]|[12][0-9]|3[01])[/](\d\d))\b"), 57 | "%m/%d/%y", 58 | ), 59 | # DD/MM/YYYY 60 | ( 61 | re.compile(r"\b((0?[1-9]|[12][0-9]|3[01])[/](0?[1-9]|1[012])[/](19|20)\d\d)\b"), 62 | "%d/%m/%Y", 63 | ), 64 | # DD/MM/YY 65 | ( 66 | re.compile(r"\b((0?[1-9]|[12][0-9]|3[01])[/](0?[1-9]|1[012])[/](\d\d))\b"), 67 | "%d/%m/%y", 68 | ), 69 | # YYYY.MM.DD 70 | ( 71 | re.compile(r"\b((19|20)\d\d[.](0?[1-9]|1[012])[.](0?[1-9]|[12][0-9]|3[01]))\b"), 72 | "%Y.%m.%d", 73 | ), 74 | # YY.MM.DD 75 | ( 76 | re.compile(r"\b((\d\d)[.](0?[1-9]|1[012])[.](0?[1-9]|[12][0-9]|3[01]))\b"), 77 | "%y.%m.%d", 78 | ), 79 | # MM.DD.YYYY 80 | ( 81 | re.compile(r"\b((0?[1-9]|1[012])[.](0?[1-9]|[12][0-9]|3[01])[.](19|20)\d\d)\b"), 82 | "%m.%d.%Y", 83 | ), 84 | # MM.DD.YY 85 | ( 86 | re.compile(r"\b((0?[1-9]|1[012])[.](0?[1-9]|[12][0-9]|3[01])[.](\d\d))\b"), 87 | "%m.%d.%y", 88 | ), 89 | # DD.MM.YYYY 90 | ( 91 | re.compile(r"\b((0?[1-9]|[12][0-9]|3[01])[.](0?[1-9]|1[012])[.](19|20)\d\d)\b"), 92 | "%d.%m.%Y", 93 | ), 94 | # DD.MM.YY 95 | ( 96 | re.compile(r"\b((0?[1-9]|[12][0-9]|3[01])[.](0?[1-9]|1[012])[.](\d\d))\b"), 97 | "%d.%m.%y", 98 | ), 99 | # MMMM DDth, YYYY - November 4th, 2023 100 | ( 101 | re.compile( 102 | r"\b((January|February|March|April|May|June|July|August|September|October|November|December)\s{1,6}\d{1,2}(st|nd|rd|th)\s{0,2},\s{1,6}\d{4})\b" 103 | ), 104 | "%B %d, %Y", 105 | ), 106 | # MMMM DD, YYYY - November 4, 2023 107 | ( 108 | re.compile( 109 | 
r"\b((January|February|March|April|May|June|July|August|September|October|November|December)\s{1,6}\d{1,2}\s{0,2},\s{1,6}\d{4})\b" 110 | ), 111 | "%B %d, %Y", 112 | ), 113 | # MMM DDth, YYYY - Nov 4th, 2023 114 | ( 115 | re.compile( 116 | r"\b((Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s{1,6}\d{1,2}(st|nd|rd|th)\s{0,2},\s{1,6}\d{4})\b" 117 | ), 118 | "%b %d, %Y", 119 | ), 120 | # MMM DD, YYYY - Nov 4, 2023 121 | ( 122 | re.compile( 123 | r"\b((Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s{1,6}\d{1,2}\s{0,2},\s{1,6}\d{4})\b" 124 | ), 125 | "%b %d, %Y", 126 | ), 127 | ] 128 | 129 | 130 | def extract_dates_from_text( 131 | input_string: str, *, date_formats: DateFormatsType = default_date_formats 132 | ) -> List[Tuple[date, str]]: 133 | """ 134 | Extract dates from a string using a set of predefined regex patterns. 135 | 136 | Returns a list of tuples, where the first element is the date object and the second is the full date string. 137 | """ 138 | extracted_dates = [] 139 | 140 | for regex, date_format in date_formats: 141 | matches = regex.findall(input_string) 142 | 143 | for match_obj in matches: 144 | # Extract the full date from the match 145 | full_date = match_obj[0] # First group captures the entire date 146 | 147 | if "%d" in date_format: 148 | parse_date = re.sub(r"(st|nd|rd|th)", "", full_date) 149 | else: 150 | parse_date = full_date 151 | 152 | parse_date = re.sub(r"\s+", " ", parse_date).strip() 153 | parse_date = re.sub( 154 | r"\s{1,},", ",", parse_date 155 | ).strip() # Commas shouldnt have spaces before them 156 | 157 | # Convert to datetime object 158 | try: 159 | date_obj = datetime.strptime(parse_date, date_format) 160 | except ValueError as e: 161 | print(f"Error parsing date '{full_date}': {e}") 162 | continue 163 | 164 | extracted_dates.append((date_obj.date(), full_date)) 165 | 166 | return extracted_dates 167 | -------------------------------------------------------------------------------- /docprompt/utils/inference.py: -------------------------------------------------------------------------------- 1 | """A utility file for running inference with various LLM providers.""" 2 | 3 | import asyncio 4 | import os 5 | from typing import List 6 | 7 | from tenacity import ( 8 | retry, 9 | retry_if_exception_type, 10 | stop_after_attempt, 11 | wait_random_exponential, 12 | ) 13 | from tqdm.asyncio import tqdm 14 | 15 | from docprompt.tasks.message import OpenAIComplexContent, OpenAIMessage 16 | 17 | 18 | def get_anthropic_retry_decorator(): 19 | import anthropic 20 | 21 | return retry( 22 | wait=wait_random_exponential(multiplier=0.5, max=60), 23 | stop=stop_after_attempt(14), 24 | retry=retry_if_exception_type(anthropic.RateLimitError) 25 | | retry_if_exception_type(anthropic.InternalServerError) 26 | | retry_if_exception_type(anthropic.APITimeoutError), 27 | reraise=True, 28 | ) 29 | 30 | 31 | def get_openai_retry_decorator(): 32 | import openai 33 | 34 | return retry( 35 | wait=wait_random_exponential(multiplier=0.5, max=60), 36 | stop=stop_after_attempt(14), 37 | retry=retry_if_exception_type(openai.RateLimitError) 38 | | retry_if_exception_type(openai.InternalServerError) 39 | | retry_if_exception_type(openai.APITimeoutError), 40 | reraise=True, 41 | ) 42 | 43 | 44 | async def run_inference_anthropic( 45 | model_name: str, messages: List[OpenAIMessage], **kwargs 46 | ) -> str: 47 | """Run inference using an Anthropic model asynchronously.""" 48 | from anthropic import AsyncAnthropic 49 | 50 | api_key = kwargs.pop("api_key", 
os.environ.get("ANTHROPIC_API_KEY")) 51 | base_url = kwargs.pop("base_url", os.environ.get("ANTHROPIC_BASE_URL")) 52 | client = AsyncAnthropic(api_key=api_key, base_url=base_url) 53 | 54 | system = None 55 | if messages and messages[0].role == "system": 56 | system = messages[0].content 57 | messages = messages[1:] 58 | 59 | processed_messages = [] 60 | for msg in messages: 61 | if isinstance(msg.content, list): 62 | processed_content = [] 63 | for content in msg.content: 64 | if isinstance(content, OpenAIComplexContent): 65 | content = content.to_anthropic_message() 66 | processed_content.append(content) 67 | else: 68 | pass 69 | # raise ValueError(f"Invalid content type: {type(content)} Expected OpenAIComplexContent") 70 | 71 | dumped = msg.model_dump() 72 | dumped["content"] = processed_content 73 | processed_messages.append(dumped) 74 | else: 75 | processed_messages.append(msg) 76 | 77 | client_kwargs = { 78 | "model": model_name, 79 | "max_tokens": 2048, 80 | "messages": processed_messages, 81 | **kwargs, 82 | } 83 | 84 | if system: 85 | client_kwargs["system"] = system 86 | 87 | response = await client.messages.create(**client_kwargs) 88 | 89 | content = response.content[0].text 90 | 91 | return content 92 | 93 | 94 | async def run_batch_inference_anthropic( 95 | model_name: str, messages: List[List[OpenAIMessage]], **kwargs 96 | ) -> List[str]: 97 | """Run batch inference using an Anthropic model asynchronously.""" 98 | retry_decorator = get_anthropic_retry_decorator() 99 | 100 | @retry_decorator 101 | async def process_message_set(msg_set, index: int): 102 | return await run_inference_anthropic(model_name, msg_set, **kwargs), index 103 | 104 | tasks = [process_message_set(msg_set, i) for i, msg_set in enumerate(messages)] 105 | 106 | # TODO: Need cleaner implementation to ensure message ordering is perserved 107 | responses: List[str] = [] 108 | for f in tqdm(asyncio.as_completed(tasks), desc="Processing messages"): 109 | response, index = await f 110 | responses.append((response, index)) 111 | 112 | # Sort and extract the responses 113 | responses.sort(key=lambda x: x[1]) 114 | responses = [r[0] for r in responses] 115 | 116 | return responses 117 | -------------------------------------------------------------------------------- /docprompt/utils/masking/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docprompt/utils/masking/__init__.py -------------------------------------------------------------------------------- /docprompt/utils/masking/image.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | from PIL import Image 4 | 5 | from docprompt.schema.layout import NormBBox 6 | 7 | ImageMaskModes = Literal["color", "average", "alpha"] 8 | 9 | 10 | def mask_image_from_bounding_boxes( 11 | image: Image.Image, 12 | *bounding_boxes: NormBBox, 13 | mask_color: str = "#000000", 14 | ): 15 | """ 16 | Create a copy of the image with the positions of the bounding boxes masked. 
17 | """ 18 | 19 | width, height = image.size 20 | 21 | mask = Image.new("RGBA", (width, height), (0, 0, 0, 0)) 22 | 23 | for bbox in bounding_boxes: 24 | mask.paste( 25 | Image.new("RGBA", (bbox.width, bbox.height), mask_color), 26 | (int(bbox.x0 * width), int(bbox.top * height)), 27 | ) 28 | 29 | return Image.alpha_composite(image, mask) 30 | -------------------------------------------------------------------------------- /docprompt/utils/util.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | from concurrent.futures import ThreadPoolExecutor, as_completed 3 | from io import BytesIO 4 | from os import PathLike 5 | from pathlib import Path 6 | from typing import TYPE_CHECKING, List, Optional, Union 7 | from urllib.parse import unquote, urlparse 8 | 9 | import filetype 10 | import fsspec 11 | 12 | from docprompt._pdfium import get_pdfium_document 13 | from docprompt.schema.document import PdfDocument 14 | 15 | if TYPE_CHECKING: 16 | from docprompt.schema.pipeline.node.document import DocumentNode 17 | 18 | 19 | def is_pdf(fd: Union[Path, PathLike, bytes]) -> bool: 20 | """ 21 | Determines if a file is a PDF 22 | """ 23 | if isinstance(fd, (bytes, str)): 24 | mime = filetype.guess_mime(fd) 25 | else: 26 | with open(fd, "rb") as f: 27 | # We only need the first 1024 bytes to determine if it's a PDF 28 | mime = filetype.guess_mime(f.read(1024)) 29 | 30 | return mime == "application/pdf" 31 | 32 | 33 | def get_page_count(fd: Union[Path, PathLike, bytes]) -> int: 34 | """ 35 | Determines the number of pages in a PDF 36 | """ 37 | if not isinstance(fd, bytes): 38 | with open(fd, "rb") as f: 39 | fd = f.read() 40 | 41 | with get_pdfium_document(fd) as pdf: 42 | return len(pdf) 43 | 44 | 45 | def name_from_path(path: Union[Path, PathLike]) -> str: 46 | if not isinstance(path, Path): 47 | path = Path(path) 48 | 49 | file_name = path.name 50 | 51 | parsed = urlparse(file_name) 52 | 53 | return unquote(parsed.path) 54 | 55 | 56 | def read_pdf_bytes_from_path(path: Union[Path, PathLike], **kwargs) -> bytes: 57 | with fsspec.open(urlpath=str(path), mode="rb", **kwargs) as f: 58 | return f.read() 59 | 60 | 61 | def determine_pdf_name_from_bytes(file_bytes: bytes) -> str: 62 | """ 63 | Attempts to determine the name of a PDF by exaimining metadata 64 | """ 65 | with get_pdfium_document(file_bytes) as pdf: 66 | metadata_dict = pdf.get_metadata_dict(skip_empty=True) 67 | 68 | name = None 69 | 70 | if metadata_dict: 71 | name = ( 72 | metadata_dict.get("Title") 73 | or metadata_dict.get("Subject") 74 | or metadata_dict.get("Author") 75 | ) 76 | 77 | if name: 78 | return f"{name.strip()}.pdf" 79 | 80 | return f"document-{hash_from_bytes(file_bytes)}.pdf" 81 | 82 | 83 | def load_pdf_document( 84 | fp: Union[Path, PathLike, bytes], 85 | *, 86 | file_name: Optional[str] = None, 87 | password: Optional[str] = None, 88 | ) -> PdfDocument: 89 | """ 90 | Loads a document from a file path 91 | """ 92 | if isinstance(fp, bytes): 93 | file_bytes = fp 94 | file_name = file_name or determine_pdf_name_from_bytes(file_bytes) 95 | file_path = None 96 | else: 97 | file_name = name_from_path(fp) if file_name is None else file_name 98 | file_path = str(fp) 99 | file_bytes = read_pdf_bytes_from_path(fp) 100 | 101 | if not is_pdf(file_bytes): 102 | raise ValueError("File is not a PDF") 103 | 104 | return PdfDocument( 105 | name=unquote(file_name), 106 | file_path=file_path, 107 | file_bytes=file_bytes, 108 | password=password, 109 | ) 110 | 111 | 112 | def 
load_pdf_documents( 113 | fps: List[Union[Path, PathLike, bytes]], 114 | *, 115 | max_threads: int = 12, 116 | passwords: Optional[List[str]] = None, 117 | ): 118 | """ 119 | Loads multiple documents from file paths, using a thread pool 120 | """ 121 | futures = [] 122 | 123 | thread_count = min(max_threads, len(fps)) 124 | 125 | with ThreadPoolExecutor(max_workers=thread_count) as executor: 126 | for fp in fps: 127 | futures.append(executor.submit(load_document, fp)) 128 | 129 | results = [] 130 | 131 | for future in as_completed(futures): 132 | results.append(future.result()) 133 | 134 | return results 135 | 136 | 137 | def load_document_node( 138 | fp: Union[Path, PathLike, bytes], 139 | *, 140 | file_name: Optional[str] = None, 141 | password: Optional[str] = None, 142 | ) -> "DocumentNode": 143 | from docprompt.schema.pipeline.node.document import DocumentNode 144 | 145 | document = load_pdf_document(fp, file_name=file_name, password=password) 146 | 147 | return DocumentNode.from_document(document) 148 | 149 | 150 | load_document = load_pdf_document 151 | load_documents = load_pdf_documents 152 | 153 | 154 | def hash_from_bytes( 155 | byte_data: bytes, hash_func=hashlib.md5, threshold=1024 * 1024 * 128 156 | ) -> str: 157 | """ 158 | Gets a hash from bytes. If the bytes are larger than the threshold, the hash is computed in chunks 159 | to avoid memory issues. The default hash function is MD5 with a threshold of 128MB which is optimal 160 | for most machines and use cases. 161 | """ 162 | if len(byte_data) < 1024 * 1024 * 10: # 10MB 163 | return hashlib.md5(byte_data).hexdigest() 164 | 165 | hash = hash_func() 166 | 167 | if len(byte_data) > threshold: 168 | stream = BytesIO(byte_data) 169 | b = bytearray(128 * 1024) 170 | mv = memoryview(b) 171 | 172 | while n := stream.readinto(mv): 173 | hash.update(mv[:n]) 174 | else: 175 | hash.update(byte_data) 176 | 177 | return hash.hexdigest() 178 | -------------------------------------------------------------------------------- /docs/CNAME: -------------------------------------------------------------------------------- 1 | docs.docprompt.io 2 | -------------------------------------------------------------------------------- /docs/assets/static/img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docs/assets/static/img/logo.png -------------------------------------------------------------------------------- /docs/assets/static/img/old-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docs/assets/static/img/old-logo.png -------------------------------------------------------------------------------- /docs/blog/index.md: -------------------------------------------------------------------------------- 1 | # Blog 2 | -------------------------------------------------------------------------------- /docs/community/contributing.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docs/community/contributing.md -------------------------------------------------------------------------------- /docs/community/versioning.md: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docs/community/versioning.md -------------------------------------------------------------------------------- /docs/concepts/nodes.md: -------------------------------------------------------------------------------- 1 | # Nodes in Docprompt 2 | 3 | ## Overview 4 | 5 | In Docprompt, nodes are fundamental structures used to represent and manage documents and their pages. They provide a way to store state and metadata associated with documents and individual pages, enabling advanced document analysis and processing capabilities. 6 | 7 | ## Key Concepts 8 | 9 | ### DocumentNode 10 | 11 | A `DocumentNode` represents a single document within the Docprompt system. It serves as a container for document-level metadata and provides access to individual pages through `PageNode` instances. 12 | 13 | ```python 14 | class DocumentNode(BaseModel, Generic[DocumentNodeMetadata, PageNodeMetadata]): 15 | document: Document 16 | page_nodes: List[PageNode[PageNodeMetadata]] 17 | metadata: Optional[DocumentNodeMetadata] 18 | ``` 19 | 20 | Key features: 21 | - Stores a reference to the underlying `Document` object 22 | - Maintains a list of `PageNode` instances representing individual pages 23 | - Allows for custom document-level metadata 24 | - Provides access to a `DocumentProvenanceLocator` for efficient text search within the document 25 | 26 | ### PageNode 27 | 28 | A `PageNode` represents a single page within a document. It stores page-specific information and provides access to various analysis results, such as OCR data. 29 | 30 | ```python 31 | class PageNode(BaseModel, Generic[PageNodeMetadata]): 32 | document: "DocumentNode" 33 | page_number: PositiveInt 34 | metadata: Optional[PageNodeMetadata] 35 | extra: Dict[str, Any] 36 | ocr_results: ResultContainer[OcrPageResult] 37 | ``` 38 | 39 | Key features: 40 | - References the parent `DocumentNode` 41 | - Stores the page number 42 | - Allows for custom page-level metadata 43 | - Provides a flexible `extra` field for additional data storage 44 | - Stores OCR results in a `ResultContainer` 45 | 46 | ## Usage 47 | 48 | ### Creating a DocumentNode 49 | 50 | You can create a `DocumentNode` from a `Document` instance: 51 | 52 | ```python 53 | from docprompt import load_document, DocumentNode 54 | 55 | document = load_document("path/to/my.pdf") 56 | document_node = DocumentNode.from_document(document) 57 | ``` 58 | 59 | ### Working with OCR Results 60 | 61 | After processing a document with an OCR provider, you can access the results through the `DocumentNode` and `PageNode` structures: 62 | 63 | ```python 64 | from docprompt.tasks.ocr.gcp import GoogleOcrProvider 65 | 66 | provider = GoogleOcrProvider.from_service_account_file( 67 | project_id=my_project_id, 68 | processor_id=my_processor_id, 69 | service_account_file=path_to_service_file 70 | ) 71 | 72 | provider.process_document_node(document_node) 73 | 74 | # Access OCR results for a specific page 75 | ocr_result = document_node.page_nodes[0].ocr_results 76 | ``` 77 | 78 | ### Using DocumentProvenanceLocator 79 | 80 | The `DocumentProvenanceLocator` is a powerful tool for searching text within a document: 81 | 82 | ```python 83 | # Search for text across the entire document 84 | results = document_node.locator.search("John Doe") 85 | 86 | # Search for text on a specific page 87 | page_results = document_node.locator.search("Jane Doe", page_number=4) 88 | ``` 89 | 90 | ## Benefits of Using Nodes 91 | 92 | 1. 
**Separation of Concerns**: Nodes allow you to separate the core PDF functionality (handled by the `Document` class) from additional metadata and analysis results. 93 | 94 | 2. **Flexible Metadata**: Both `DocumentNode` and `PageNode` support generic metadata types, allowing you to add custom, type-safe metadata to your documents and pages. 95 | 96 | 3. **Result Caching**: Nodes provide a convenient way to cache and access results from various analysis tasks, such as OCR. 97 | 98 | 4. **Efficient Text Search**: The `DocumentProvenanceLocator` enables fast text search capabilities, leveraging OCR results for improved performance. 99 | 100 | 5. **Extensibility**: The node structure allows for easy integration of new analysis tools and result types in the future. 101 | 102 | By using the node structure in Docprompt, you can build powerful document analysis workflows that combine the core PDF functionality with advanced processing and search capabilities. 103 | -------------------------------------------------------------------------------- /docs/concepts/primatives.md: -------------------------------------------------------------------------------- 1 | # Docprompt Primitives 2 | 3 | Docprompt uses several primitive objects that are fundamental to its operation. These primitives are used throughout the library and are essential for understanding how Docprompt processes and represents documents. 4 | 5 | ## PdfDocument 6 | 7 | The `PdfDocument` class is a core primitive in Docprompt, representing a PDF document with various utilities for manipulation and analysis. 8 | 9 | ```python 10 | class PdfDocument(BaseModel): 11 | name: str 12 | file_bytes: bytes 13 | file_path: Optional[str] = None 14 | ``` 15 | 16 | ### Key Features 17 | 18 | 1. **Document Properties** 19 | - `name`: The name of the document 20 | - `file_bytes`: The raw bytes of the PDF file 21 | - `file_path`: Optional path to the PDF file on disk 22 | - `page_count`: The number of pages in the document (computed field) 23 | - `document_hash`: A unique hash of the document (computed field) 24 | 25 | 2. **Utility Methods** 26 | - `from_path(file_path)`: Create a PdfDocument from a file path 27 | - `from_bytes(file_bytes, name)`: Create a PdfDocument from bytes 28 | - `get_page_render_size(page_number, dpi)`: Get the render size of a specific page 29 | - `to_compressed_bytes()`: Compress the PDF using Ghostscript 30 | - `rasterize_page(page_number, ...)`: Rasterize a specific page with various options 31 | - `rasterize_pdf(...)`: Rasterize the entire PDF 32 | - `split(start, stop)`: Split the PDF into a new document 33 | - `as_tempfile()`: Create a temporary file from the PDF 34 | - `write_to_path(path)`: Write the PDF to a specific path 35 | 36 | ### Usage Example 37 | 38 | ```python 39 | from docprompt import PdfDocument 40 | 41 | # Load a PDF 42 | pdf = PdfDocument.from_path("path/to/document.pdf") 43 | 44 | # Get document properties 45 | print(f"Document name: {pdf.name}") 46 | print(f"Page count: {pdf.page_count}") 47 | 48 | # Rasterize a page 49 | page_image = pdf.rasterize_page(1, dpi=300) 50 | 51 | # Split the document 52 | new_pdf = pdf.split(start=5, stop=10) 53 | ``` 54 | 55 | ## Layout Primitives 56 | 57 | Docprompt uses several layout primitives to represent the structure and content of documents. 58 | 59 | ### NormBBox 60 | 61 | `NormBBox` represents a normalized bounding box with values between 0 and 1. 
62 | 63 | ```python 64 | class NormBBox(BaseModel): 65 | x0: BoundedFloat 66 | top: BoundedFloat 67 | x1: BoundedFloat 68 | bottom: BoundedFloat 69 | ``` 70 | 71 | Key features: 72 | - Intersection operations (`__and__`) 73 | - Union operations (`__add__`) 74 | - Intersection over Union (IoU) calculation 75 | - Area and centroid properties 76 | 77 | ### TextBlock 78 | 79 | `TextBlock` represents a block of text within a document, including its bounding box and metadata. 80 | 81 | ```python 82 | class TextBlock(BaseModel): 83 | text: str 84 | type: SegmentLevels 85 | source: TextblockSource 86 | bounding_box: NormBBox 87 | bounding_poly: Optional[BoundingPoly] 88 | text_spans: Optional[List[TextSpan]] 89 | metadata: Optional[TextBlockMetadata] 90 | ``` 91 | 92 | ### Point and BoundingPoly 93 | 94 | `Point` and `BoundingPoly` are used to represent more complex shapes within a document. 95 | 96 | ```python 97 | class Point(BaseModel): 98 | x: BoundedFloat 99 | y: BoundedFloat 100 | 101 | class BoundingPoly(BaseModel): 102 | normalized_vertices: List[Point] 103 | ``` 104 | 105 | ### TextSpan 106 | 107 | `TextSpan` represents a span of text within a document or page. 108 | 109 | ```python 110 | class TextSpan(BaseModel): 111 | start_index: int 112 | end_index: int 113 | level: Literal["page", "document"] 114 | ``` 115 | 116 | ### Usage Example 117 | 118 | ```python 119 | from docprompt.schema.layout import NormBBox, TextBlock, TextBlockMetadata 120 | 121 | # Create a bounding box 122 | bbox = NormBBox(x0=0.1, top=0.1, x1=0.9, bottom=0.2) 123 | 124 | # Create a text block 125 | text_block = TextBlock( 126 | text="Example text", 127 | type="block", 128 | source="ocr", 129 | bounding_box=bbox, 130 | metadata=TextBlockMetadata(confidence=0.95) 131 | ) 132 | 133 | # Use the text block 134 | print(f"Text: {text_block.text}") 135 | print(f"Bounding box: {text_block.bounding_box}") 136 | print(f"Confidence: {text_block.confidence}") 137 | ``` 138 | 139 | These primitives form the foundation of Docprompt's document processing capabilities, allowing for precise representation and manipulation of document content and structure. 140 | -------------------------------------------------------------------------------- /docs/concepts/provenance.md: -------------------------------------------------------------------------------- 1 | # Provenance in Docprompt 2 | 3 | ## Overview 4 | 5 | Provenance in Docprompt refers to the ability to trace and locate specific pieces of text within a document. The `DocumentProvenanceLocator` class is a powerful tool that enables efficient text search, spatial queries, and fine-grained text location within documents that have been processed with OCR. 6 | 7 | ## Key Concepts 8 | 9 | ### DocumentProvenanceLocator 10 | 11 | The `DocumentProvenanceLocator` is a class that provides advanced search capabilities for documents in Docprompt. It combines full-text search with spatial indexing to offer fast and accurate text location services. 
12 | 13 | ```python 14 | @dataclass 15 | class DocumentProvenanceLocator: 16 | document_name: str 17 | search_index: "tantivy.Index" 18 | block_mapping: Dict[int, OcrPageResult] 19 | geo_index: DocumentProvenanceGeoMap 20 | ``` 21 | 22 | Key features: 23 | - Full-text search using the Tantivy search engine 24 | - Spatial indexing using R-tree for efficient bounding box queries 25 | - Support for different granularity levels (word, line, block) 26 | - Ability to refine search results to word-level precision 27 | 28 | ## Main Functionalities 29 | 30 | ### 1. Text Search 31 | 32 | The `search` method allows you to find specific text within a document: 33 | 34 | ```python 35 | def search( 36 | self, 37 | query: str, 38 | page_number: Optional[int] = None, 39 | *, 40 | refine_to_word: bool = True, 41 | require_exact_match: bool = True 42 | ) -> List[ProvenanceSource]: 43 | # ... implementation ... 44 | ``` 45 | 46 | This method returns a list of `ProvenanceSource` objects, which contain detailed information about where the text was found, including page number, bounding box, and the surrounding context. 47 | 48 | ### 2. Spatial Queries 49 | 50 | The `DocumentProvenanceLocator` supports spatial queries to find text blocks based on their location on the page: 51 | 52 | ```python 53 | def get_k_nearest_blocks( 54 | self, 55 | bbox: NormBBox, 56 | page_number: int, 57 | k: int, 58 | granularity: BlockGranularity = "block" 59 | ) -> List[TextBlock]: 60 | # ... implementation ... 61 | 62 | def get_overlapping_blocks( 63 | self, 64 | bbox: NormBBox, 65 | page_number: int, 66 | granularity: BlockGranularity = "block" 67 | ) -> List[TextBlock]: 68 | # ... implementation ... 69 | ``` 70 | 71 | These methods allow you to find text blocks that are near or overlapping with a given bounding box on a specific page. 72 | 73 | ## Usage 74 | 75 | ### Recommended Usage: Through DocumentNode 76 | 77 | The recommended way to use the `DocumentProvenanceLocator` is through the `DocumentNode` class. The `DocumentNode` provides two methods for working with the locator: 78 | 79 | 1. `locator` property: Lazily creates and returns the `DocumentProvenanceLocator`. 80 | 2. `refresh_locator()` method: Explicitly refreshes the locator for the document node. 81 | 82 | Here's how to use these methods: 83 | 84 | ```python 85 | from docprompt import load_document, DocumentNode 86 | from docprompt.tasks.ocr.gcp import GoogleOcrProvider 87 | 88 | # Load and process the document 89 | document = load_document("path/to/my.pdf") 90 | document_node = DocumentNode.from_document(document) 91 | 92 | # Process the document with OCR 93 | provider = GoogleOcrProvider.from_service_account_file(...) 94 | provider.process_document_node(document_node) 95 | 96 | # Access the locator (creates it if it doesn't exist) 97 | locator = document_node.locator 98 | 99 | # Perform a search 100 | results = locator.search("Docprompt") 101 | 102 | # If you need to refresh the locator (e.g., after updating OCR results) 103 | document_node.refresh_locator() 104 | ``` 105 | 106 | Note: Attempting to access the locator before OCR results are available will raise a `ValueError`. 
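107 | 108 | For instance, a minimal guard (a sketch; `provider` is assumed to be the OCR provider configured in the example above): 109 | 110 | ```python 111 | try: 112 |     locator = document_node.locator 113 | except ValueError: 114 |     # No OCR results yet -- run OCR first, then build the locator 115 |     provider.process_document_node(document_node) 116 |     locator = document_node.locator 117 | ``` 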
107 | 
108 | ### Alternative: Standalone Usage
109 | 
110 | While the recommended approach is to use the locator through `DocumentNode`, you can also create and use a `DocumentProvenanceLocator` independently:
111 | 
112 | ```python
113 | from docprompt.provenance.search import DocumentProvenanceLocator
114 | 
115 | # Assuming you have a processed DocumentNode
116 | locator = DocumentProvenanceLocator.from_document_node(document_node)
117 | 
118 | # Now you can use the locator directly
119 | results = locator.search("Docprompt")
120 | ```
121 | 
122 | ### Searching for Text
123 | 
124 | To search for text within the document:
125 | 
126 | ```python
127 | results = locator.search("Docprompt")
128 | for result in results:
129 |     print(f"Found on page {result.page_number}, bbox: {result.text_location.merged_source_block.bounding_box}")
130 | ```
131 | 
132 | ### Performing Spatial Queries
133 | 
134 | To find text blocks near a specific location:
135 | 
136 | ```python
137 | bbox = NormBBox(x0=0.1, top=0.1, x1=0.2, bottom=0.2)
138 | nearby_blocks = locator.get_k_nearest_blocks(bbox, page_number=1, k=5)
139 | ```
140 | 
141 | ## Benefits of Using Provenance
142 | 
143 | 1. **Accurate Text Location**: Quickly find the exact location of text within a document, including page number and bounding box.
144 | 2. **Efficient Searching**: Combine full-text search with spatial indexing for fast and accurate results.
145 | 3. **Flexible Granularity**: Search and retrieve results at different levels of granularity (word, line, block).
146 | 4. **Integration with OCR**: Seamlessly works with OCR results to provide comprehensive document analysis capabilities.
147 | 5. **Support for Complex Queries**: Perform spatial queries to find text based on location within pages.
148 | 6. **Easy Access**: Conveniently access the locator through the `DocumentNode` class, ensuring it's always available when needed.
149 | 
150 | By leveraging the provenance functionality in Docprompt, you can build sophisticated document analysis workflows that require precise text location and contextual information retrieval.
151 | 
--------------------------------------------------------------------------------
/docs/concepts/providers.md:
--------------------------------------------------------------------------------
1 | # Providers in Docprompt
2 | 
3 | ## Overview
4 | 
5 | Providers in Docprompt are abstract interfaces that define how to add data to document nodes. They encapsulate various tasks such as OCR, classification, and more. The provider system is designed to be extensible, allowing users to create custom providers to add new functionality to Docprompt.
6 | 
7 | ## Key Concepts
8 | 
9 | ### AbstractTaskProvider
10 | 
11 | The `AbstractTaskProvider` is the base class for all providers in Docprompt. It defines the interface that all task providers must implement.
12 | 13 | ```python 14 | class AbstractTaskProvider(Generic[PageTaskResult]): 15 | name: str 16 | capabilities: List[str] 17 | 18 | def process_document_pages( 19 | self, 20 | document: Document, 21 | start: Optional[int] = None, 22 | stop: Optional[int] = None, 23 | **kwargs, 24 | ) -> Dict[int, PageTaskResult]: 25 | raise NotImplementedError 26 | 27 | def contribute_to_document_node( 28 | self, 29 | document_node: "DocumentNode", 30 | results: Dict[int, PageTaskResult], 31 | ) -> None: 32 | pass 33 | 34 | def process_document_node( 35 | self, 36 | document_node: "DocumentNode", 37 | start: Optional[int] = None, 38 | stop: Optional[int] = None, 39 | contribute_to_document: bool = True, 40 | **kwargs, 41 | ) -> Dict[int, PageTaskResult]: 42 | # ... implementation ... 43 | ``` 44 | 45 | Key features: 46 | - Generic type `PageTaskResult` allows for type-safe results 47 | - `capabilities` list defines what the provider can do 48 | - `process_document_pages` method processes pages of a document 49 | - `contribute_to_document_node` method adds results to a `DocumentNode` 50 | - `process_document_node` method combines processing and contributing results 51 | 52 | ### CAPABILITIES 53 | 54 | The `CAPABILITIES` enum defines the various capabilities that a provider can have: 55 | 56 | ```python 57 | class CAPABILITIES(Enum): 58 | PAGE_RASTERIZATION = "page-rasterization" 59 | PAGE_LAYOUT_OCR = "page-layout-ocr" 60 | PAGE_TEXT_OCR = "page-text-ocr" 61 | PAGE_CLASSIFICATION = "page-classification" 62 | PAGE_SEGMENTATION = "page-segmentation" 63 | PAGE_VQA = "page-vqa" 64 | PAGE_TABLE_IDENTIFICATION = "page-table-identification" 65 | PAGE_TABLE_EXTRACTION = "page-table-extraction" 66 | ``` 67 | 68 | ### ResultContainer 69 | 70 | The `ResultContainer` is a generic class that holds the results of a task: 71 | 72 | ```python 73 | class ResultContainer(BaseModel, Generic[PageOrDocumentTaskResult]): 74 | results: Dict[str, PageOrDocumentTaskResult] = Field( 75 | description="The results of the task, keyed by provider", default_factory=dict 76 | ) 77 | 78 | @property 79 | def result(self): 80 | return next(iter(self.results.values()), None) 81 | ``` 82 | 83 | ## Creating Custom Providers 84 | 85 | To extend Docprompt's functionality, you can create custom providers. 
Here's a shortened example of the built-in GCP OCR provider:
86 | 
87 | ```python
88 | from typing import Dict, List, Optional
89 | 
90 | from pydantic import Field
91 | 
92 | from docprompt.schema.layout import TextBlock
93 | from docprompt.tasks.base import AbstractTaskProvider, CAPABILITIES
94 | 
95 | # (Document and BasePageResult imports elided in this shortened example)
96 | 
97 | class OcrPageResult(BasePageResult):
98 |     page_text: str = Field(description="The text for the entire page in reading order")
99 |     word_level_blocks: List[TextBlock] = Field(default_factory=list)
100 |     line_level_blocks: List[TextBlock] = Field(default_factory=list)
101 |     block_level_blocks: List[TextBlock] = Field(default_factory=list)
102 |     raster_image: Optional[bytes] = Field(default=None)
103 | 
104 | class GoogleOcrProvider(AbstractTaskProvider[OcrPageResult]):
105 |     name = "Google Document AI"
106 |     capabilities = [
107 |         CAPABILITIES.PAGE_TEXT_OCR.value,
108 |         CAPABILITIES.PAGE_LAYOUT_OCR.value,
109 |         CAPABILITIES.PAGE_RASTERIZATION.value,
110 |     ]
111 | 
112 |     def process_document_pages(
113 |         self,
114 |         document: Document,
115 |         start: Optional[int] = None,
116 |         stop: Optional[int] = None,
117 |         **kwargs,
118 |     ) -> Dict[int, OcrPageResult]:
119 |         # Implement OCR logic here
120 |         pass
121 | 
122 |     def contribute_to_document_node(
123 |         self,
124 |         document_node: "DocumentNode",
125 |         results: Dict[int, OcrPageResult],
126 |     ) -> None:
127 |         # Add OCR results to document node
128 |         pass
129 | ```
130 | 
131 | ## Usage
132 | 
133 | Here's how you can use a provider in your Docprompt workflow:
134 | 
135 | ```python
136 | from docprompt import load_document, DocumentNode
137 | from docprompt.tasks.ocr.gcp import GoogleOcrProvider
138 | 
139 | # Load a document
140 | document = load_document("path/to/my.pdf")
141 | document_node = DocumentNode.from_document(document)
142 | 
143 | # Create and use the OCR provider
144 | ocr_provider = GoogleOcrProvider(...)
145 | ocr_results = ocr_provider.process_document_node(document_node)
146 | 
147 | # Access OCR results
148 | for page_number, result in ocr_results.items():
149 |     print(f"Page {page_number} text: {result.page_text[:100]}...")
150 | ```
151 | 
152 | ## Benefits of Using Providers
153 | 
154 | 1. **Extensibility**: Easily add new functionality to Docprompt by creating custom providers.
155 | 2. **Modularity**: Each provider encapsulates a specific task, making the codebase more organized and maintainable.
156 | 3. **Type Safety**: Generic types ensure that providers produce and consume the correct types of results.
157 | 4. **Standardized Interface**: All providers follow the same interface, making it easy to switch between different implementations.
158 | 5. **Capability-based Design**: Providers declare their capabilities, allowing for dynamic feature discovery and usage.
159 | 
160 | By leveraging the provider system in Docprompt, you can create flexible and powerful document processing pipelines that can be easily extended and customized to meet your specific needs.
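Because every provider declares its capabilities, you can also select a provider dynamically at runtime. Here's a minimal sketch of that pattern; the `pick_provider` helper and the `my_providers` list are hypothetical illustrations, not part of the Docprompt API:

```python
from docprompt.tasks.base import CAPABILITIES

def pick_provider(providers, capability: CAPABILITIES):
    """Return the first configured provider that declares the requested capability."""
    for provider in providers:
        if capability.value in provider.capabilities:
            return provider
    raise ValueError(f"No provider supports {capability.value}")

# e.g. choose whichever configured provider can do layout-aware OCR
ocr_provider = pick_provider(my_providers, CAPABILITIES.PAGE_LAYOUT_OCR)
```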
156 | 
--------------------------------------------------------------------------------
/docs/enterprise.md:
--------------------------------------------------------------------------------
1 | # Enterprise
2 | For companies looking to unlock data, build custom language models, or get general professional support.
3 | 
4 | [Talk to founders](https://calendly.com/pageleaf/meet-with-the-founders)
5 | 
6 | This covers:
7 | 
8 | * ✅ **Assistance with PDF-optimized prompt engineering for Document AI tasks**
9 | * ✅ **Feature Prioritization**
10 | * ✅ **Custom Integrations**
11 | * ✅ **Professional Support - Dedicated Discord + Slack**
12 | 
13 | 
14 | ### What topics does Professional support cover?
15 | 
16 | The expertise we've developed during our time building medical processing pipelines has equipped us with the tools and know-how needed to perform highly accurate information extraction in document-heavy domains. We offer consulting services that leverage this expertise to assist with prompt engineering, deployments, and general ML-ops in the Document AI space.
--------------------------------------------------------------------------------
/docs/gen_ref_pages.py:
--------------------------------------------------------------------------------
1 | """Generate the code reference pages and navigation."""
2 | 
3 | import logging
4 | from pathlib import Path
5 | 
6 | import mkdocs_gen_files
7 | 
8 | # Set up logging
9 | logging.basicConfig(level=logging.DEBUG)
10 | logger = logging.getLogger(__name__)
11 | 
12 | nav = mkdocs_gen_files.Nav()
13 | 
14 | root_path = Path("docprompt")
15 | logger.debug(f"Searching for Python files in: {root_path.absolute()}")
16 | 
17 | python_files = list(root_path.rglob("*.py"))
18 | logger.debug(f"Found {len(python_files)} Python files")
19 | 
20 | if not python_files:
21 |     logger.warning(
22 |         "No Python files found. Ensure the 'docprompt' directory exists and contains .py files."
23 |     )
24 |     nav["No modules found"] = "no_modules.md"
25 |     with mkdocs_gen_files.open("reference/no_modules.md", "w") as fd:
26 |         fd.write(
27 |             "# No Modules Found\n\nNo Python modules were found in the 'docprompt' directory."
28 |         )
29 | else:
30 |     for path in sorted(python_files):
31 |         module_path = path.relative_to(root_path).with_suffix("")
32 |         doc_path = path.relative_to(root_path).with_suffix(".md")
33 |         full_doc_path = Path("reference", doc_path)
34 | 
35 |         parts = tuple(module_path.parts)
36 | 
37 |         if parts[-1] == "__init__":
38 |             parts = parts[:-1]
39 |             doc_path = doc_path.with_name("index.md")
40 |             full_doc_path = full_doc_path.with_name("index.md")
41 |         elif parts[-1] == "__main__":
42 |             continue
43 | 
44 |         # Handle empty parts
45 |         if not parts:
46 |             logger.warning(f"Empty parts for file: {path}. 
Skipping this file.") 47 | continue 48 | 49 | nav[parts] = doc_path.as_posix() 50 | 51 | with mkdocs_gen_files.open(full_doc_path, "w") as fd: 52 | ident = ".".join(parts) 53 | fd.write(f"::: docprompt.{ident}") 54 | 55 | mkdocs_gen_files.set_edit_path(full_doc_path, path) 56 | 57 | logger.debug("Navigation structure:") 58 | logger.debug(nav.build_literate_nav()) 59 | 60 | with mkdocs_gen_files.open("reference/SUMMARY.md", "w") as nav_file: 61 | nav_file.writelines(nav.build_literate_nav()) 62 | -------------------------------------------------------------------------------- /docs/guide/classify/binary.md: -------------------------------------------------------------------------------- 1 | # Binary Classification with DocPrompt 2 | 3 | Binary classification is a fundamental task in document analysis, where you categorize pages into one of two classes. This guide will walk you through performing binary classification using DocPrompt with the Anthropic provider. 4 | 5 | ## Setting Up 6 | 7 | First, let's import the necessary modules and set up our environment: 8 | 9 | ```python 10 | from docprompt import load_document, DocumentNode 11 | from docprompt.tasks.factory import AnthropicTaskProviderFactory 12 | from docprompt.tasks.classification import ClassificationInput, ClassificationTypes 13 | 14 | # Initialize the Anthropic factory 15 | # Make sure you have set the ANTHROPIC_API_KEY environment variable 16 | factory = AnthropicTaskProviderFactory() 17 | 18 | # Create the classification provider 19 | classification_provider = factory.get_page_classification_provider() 20 | ``` 21 | 22 | ## Preparing the Document 23 | 24 | Load your document and create a DocumentNode: 25 | 26 | ```python 27 | document = load_document("path/to/your/document.pdf") 28 | document_node = DocumentNode.from_document(document) 29 | ``` 30 | 31 | ## Configuring the Classification Task 32 | 33 | For binary classification, we need to create a `ClassificationInput` object. This object acts as a prompt for the model, guiding its classification decision. Here's how to set it up: 34 | 35 | ```python 36 | classification_input = ClassificationInput( 37 | type=ClassificationTypes.BINARY, 38 | instructions="Determine if the page contains information about financial transactions.", 39 | confidence=True # Optional: request confidence scores 40 | ) 41 | ``` 42 | 43 | Let's break down the `ClassificationInput`: 44 | 45 | - `type`: Set to `ClassificationTypes.BINARY` for binary classification. 46 | - `instructions`: This is crucial for binary classification. Provide clear, specific instructions for what the model should look for. 47 | - `confidence`: Set to `True` if you want confidence scores with the results. 48 | 49 | Note: For binary classification, you don't need to specify labels. The model will automatically use "YES" and "NO" as labels. 50 | 51 | ## Performing Classification 52 | 53 | Now, let's run the classification task: 54 | 55 | ```python 56 | results = classification_provider.process_document_node( 57 | document_node, 58 | classification_input 59 | ) 60 | ``` 61 | 62 | ## Interpreting Results 63 | 64 | The `results` dictionary contains the classification output for each page. 
Let's examine the results: 65 | 66 | ```python 67 | for page_number, result in results.items(): 68 | label = result.labels 69 | confidence = result.score 70 | print(f"Page {page_number}:") 71 | print(f"\tClassification: {label} ({confidence})") 72 | print('---') 73 | ``` 74 | 75 | This will print the classification (YES or NO) for each page, along with the confidence level if requested. 76 | 77 | ## Tips for Effective Binary Classification 78 | 79 | 1. **Clear Instructions**: The `instructions` field in `ClassificationInput` is crucial. Be specific about what constitutes a "YES" classification. 80 | 81 | 2. **Consider Page Context**: Remember that the model analyzes each page independently. If your classification requires context from multiple pages, you may need to adjust your approach. 82 | 83 | 3. **Confidence Scores**: Use the `confidence` option to get an idea of how certain the model is about its classifications. This can be helpful for identifying pages that might need human review. 84 | 85 | 4. **Iterative Refinement**: If you're not getting the desired results, try refining your instructions. You might need to be more specific or provide examples of what constitutes a positive classification. 86 | 87 | ## Conclusion 88 | 89 | Binary classification with DocPrompt allows you to quickly categorize pages in your documents. By leveraging the Anthropic provider and carefully crafting your `ClassificationInput`, you can achieve accurate and efficient document analysis. 90 | 91 | Remember that the quality of your results heavily depends on the clarity and specificity of your instructions. Experiment with different phrasings to find what works best for your specific use case. 92 | -------------------------------------------------------------------------------- /docs/guide/classify/multi.md: -------------------------------------------------------------------------------- 1 | # Multi-Label Classification with DocPrompt: Academic Research Papers 2 | 3 | In this guide, we'll use DocPrompt to classify academic research papers in a PDF document into multiple relevant fields. We'll use multi-label classification, allowing each page (paper) to be assigned to one or more categories based on its content. 
4 | 5 | ## Setting Up 6 | 7 | First, let's import the necessary modules and set up our environment: 8 | 9 | ```python 10 | from docprompt import load_document, DocumentNode 11 | from docprompt.tasks.factory import AnthropicTaskProviderFactory 12 | from docprompt.tasks.classification import ClassificationInput, ClassificationTypes 13 | 14 | # Initialize the Anthropic factory 15 | # Ensure you have set the ANTHROPIC_API_KEY environment variable 16 | factory = AnthropicTaskProviderFactory() 17 | 18 | # Create the classification provider 19 | classification_provider = factory.get_page_classification_provider() 20 | ``` 21 | 22 | ## Preparing the Document 23 | 24 | Load your collection of research papers and create a DocumentNode: 25 | 26 | ```python 27 | document = load_document("path/to/your/research_papers.pdf") 28 | document_node = DocumentNode.from_document(document) 29 | ``` 30 | 31 | ## Configuring the Classification Task 32 | 33 | For multi-label classification, we'll create a `ClassificationInput` object that specifies our labels and provides instructions for the model: 34 | 35 | ```python 36 | classification_input = ClassificationInput( 37 | type=ClassificationTypes.MULTI_LABEL, 38 | instructions=( 39 | "Classify the research paper on this page into one or more relevant fields " 40 | "based on its title, abstract, methodology, and key findings. A paper may " 41 | "belong to multiple categories if it spans multiple disciplines." 42 | ), 43 | labels=[ 44 | "Machine Learning", 45 | "Natural Language Processing", 46 | "Computer Vision", 47 | "Robotics", 48 | "Data Mining", 49 | "Cybersecurity", 50 | "Bioinformatics", 51 | "Quantum Computing" 52 | ], 53 | descriptions=[ 54 | "Algorithms and statistical models that computer systems use to perform tasks without explicit instructions", 55 | "Processing and analyzing natural language data", 56 | "Enabling computers to derive meaningful information from digital images, videos and other visual inputs", 57 | "Design, construction, operation, and use of robots", 58 | "Process of discovering patterns in large data sets", 59 | "Protection of computer systems from theft or damage to their hardware, software, or electronic data", 60 | "Application of computational techniques to analyze biological data", 61 | "Computation based on the principles of quantum theory" 62 | ], 63 | confidence=True # Request confidence scores 64 | ) 65 | ``` 66 | 67 | Let's break down the `ClassificationInput`: 68 | 69 | - `type`: Set to `ClassificationTypes.MULTI_LABEL` for multi-label classification. 70 | - `labels`: List of possible categories for our research papers. 71 | - `instructions`: Clear directions for the model on how to classify the papers, emphasizing that multiple labels can be assigned. 72 | - `descriptions`: Provide additional context for each label to improve classification accuracy. 73 | - `confidence`: Set to `True` to get confidence scores with the results. 
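Because `labels` and `descriptions` are parallel lists, it's easy for them to drift apart as your taxonomy grows. A small helper can keep them aligned; this `build_multi_label_input` function is a hypothetical sketch, not part of the Docprompt API:

```python
from docprompt.tasks.classification import ClassificationInput, ClassificationTypes

def build_multi_label_input(instructions: str, label_map: dict, confidence: bool = True) -> ClassificationInput:
    """Build a multi-label input from a {label: description} mapping so the lists cannot drift apart."""
    return ClassificationInput(
        type=ClassificationTypes.MULTI_LABEL,
        instructions=instructions,
        labels=list(label_map.keys()),
        descriptions=list(label_map.values()),
        confidence=confidence,
    )
```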
74 | 
75 | ## Performing Classification
76 | 
77 | Now, let's run the classification task on our collection of research papers:
78 | 
79 | ```python
80 | results = classification_provider.process_document_node(
81 |     document_node,
82 |     classification_input
83 | )
84 | ```
85 | 
86 | ## Interpreting Results
87 | 
88 | Let's examine the classification results for each research paper:
89 | 
90 | ```python
91 | for page_number, result in results.items():
92 |     categories = result.labels
93 |     confidence = result.score
94 |     print(f"Research Paper on Page {page_number}:")
95 |     print(f"\tCategories: {', '.join(categories)} ({confidence})")
96 |     print('---')
97 | ```
98 | 
99 | This will print the assigned categories for each research paper, along with the confidence level.
100 | 
101 | ## Tips for Effective Multi-Label Classification
102 | 
103 | 1. **Comprehensive Label Set**: Ensure your label set covers the main topics in your domain but isn't so large that it becomes unwieldy.
104 | 
105 | 2. **Clear Instructions**: Emphasize in your instructions that multiple labels can and should be assigned when appropriate.
106 | 
107 | 3. **Use Descriptions**: The `descriptions` field helps the model understand the nuances of each category, which is especially important for interdisciplinary papers.
108 | 
109 | 4. **Consider Confidence Scores**: In multi-label classification, confidence scores can indicate how strongly a paper fits into each assigned category.
110 | 
111 | 5. **Analyze Label Co-occurrences**: Look for patterns in which labels frequently appear together to gain insights into interdisciplinary trends.
112 | 
113 | 6. **Handle Outliers**: If a paper doesn't fit well into any category, consider adding a catch-all category like "Other" or "Interdisciplinary" in future iterations.
114 | 
115 | ## Advanced Usage: Increasing the Power
116 | 
117 | For more control over the classification process, you can specify a more powerful Anthropic model to boost reasoning power. This can be done when setting up the task provider or at inference time, allowing for easy, fine-grained control over provider defaults and runtime overrides.
118 | 
119 | ```python
120 | classification_provider = factory.get_page_classification_provider(
121 |     model_name="claude-3-5-sonnet-20240620"  # set up the task provider with sonnet-3.5
122 | )
123 | 
124 | results = classification_provider.process_document_node(
125 |     document_node,
126 |     classification_input,
127 |     model_name="claude-3-5-sonnet-20240620"  # or declare the model name at inference time
128 | )
129 | ```
130 | 
131 | ## Conclusion
132 | 
133 | Multi-label classification with DocPrompt provides a powerful way to categorize complex documents like research papers that often span multiple disciplines. By carefully crafting your `ClassificationInput` with clear labels, instructions, and descriptions, you can achieve nuanced and informative document analysis.
134 | 
135 | Remember that the quality of your results depends on the clarity of your instructions, the comprehensiveness of your label set, and the appropriateness of your descriptions. Experiment with different configurations to find what works best for your specific use case.
136 | 
137 | This approach can be adapted to other multi-label classification tasks, such as categorizing news articles by multiple topics, classifying products by multiple features, or tagging images with multiple attributes.
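To follow up on tip 5, here's a minimal sketch of a label co-occurrence analysis over the `results` mapping produced above (it assumes `result.labels` is the list of assigned labels, as in the multi-label output shown earlier):

```python
from collections import Counter
from itertools import combinations

co_occurrences = Counter()

for result in results.values():
    # Count each unordered pair of labels assigned to the same paper
    for pair in combinations(sorted(result.labels), 2):
        co_occurrences[pair] += 1

for (label_a, label_b), count in co_occurrences.most_common(5):
    print(f"{label_a} + {label_b}: {count} papers")
```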
138 | -------------------------------------------------------------------------------- /docs/guide/classify/single.md: -------------------------------------------------------------------------------- 1 | # Single-Label Classification with DocPrompt: Recipe Categories 2 | 3 | In this guide, we'll use DocPrompt to classify recipes in a PDF document into distinct meal categories. We'll use single-label classification, meaning each page (recipe) will be assigned to one category: Breakfast, Lunch, Dinner, or Dessert. 4 | 5 | ## Setting Up 6 | 7 | First, let's import the necessary modules and set up our environment: 8 | 9 | ```python 10 | from docprompt import load_document, DocumentNode 11 | from docprompt.tasks.factory import AnthropicTaskProviderFactory 12 | from docprompt.tasks.classification import ClassificationInput, ClassificationTypes 13 | 14 | # Initialize the Anthropic factory 15 | # Ensure you have set the ANTHROPIC_API_KEY environment variable 16 | factory = AnthropicTaskProviderFactory() 17 | 18 | # Create the classification provider 19 | classification_provider = factory.get_page_classification_provider() 20 | ``` 21 | 22 | ## Preparing the Document 23 | 24 | Load your recipe book PDF and create a DocumentNode: 25 | 26 | ```python 27 | document = load_document("path/to/your/recipe_book.pdf") 28 | document_node = DocumentNode.from_document(document) 29 | ``` 30 | 31 | ## Configuring the Classification Task 32 | 33 | For single-label classification, we'll create a `ClassificationInput` object that specifies our labels and provides instructions for the model: 34 | 35 | ```python 36 | classification_input = ClassificationInput( 37 | type=ClassificationTypes.SINGLE_LABEL, 38 | labels=["Breakfast", "Lunch", "Dinner", "Dessert"], 39 | instructions="Classify the recipe on this page into one of the given meal categories based on its ingredients, cooking methods, and typical serving time.", 40 | descriptions=[ 41 | "Morning meals, often including eggs, cereals, or pastries", 42 | "Midday meals, typically lighter fare like sandwiches or salads", 43 | "Evening meals, often the most substantial meal of the day", 44 | "Sweet treats typically served after a meal or as a snack" 45 | ], 46 | confidence=True # Request confidence scores 47 | ) 48 | ``` 49 | 50 | Let's break down the `ClassificationInput`: 51 | 52 | - `type`: Set to `ClassificationTypes.SINGLE_LABEL` for single-label classification. 53 | - `labels`: List of possible categories for our recipes. 54 | - `instructions`: Clear directions for the model on how to classify the recipes. 55 | - `descriptions`: (Optional) Provide additional context for each label to improve classification accuracy. 56 | - _Note that the `descriptions` array must be the same length as the `labels` array._ 57 | - `confidence`: Set to `True` to get confidence scores with the results. 
58 | 
59 | ## Performing Classification
60 | 
61 | Now, let's run the classification task on our recipe book:
62 | 
63 | ```python
64 | results = classification_provider.process_document_node(
65 |     document_node,
66 |     classification_input
67 | )
68 | ```
69 | 
70 | ## Interpreting Results
71 | 
72 | Let's examine the classification results for each recipe:
73 | 
74 | ```python
75 | for page_number, result in results.items():
76 |     category = result.labels
77 |     confidence = result.score
78 | 
79 |     print(f"Recipe on Page {page_number}:")
80 |     print(f"\tCategory: {category} ({confidence})")
81 |     print('---')
82 | ```
83 | 
84 | This will print the assigned category (Breakfast, Lunch, Dinner, or Dessert) for each recipe, along with the confidence level.
85 | 
86 | ## Tips for Effective Single-Label Classification
87 | 
88 | 1. **Comprehensive Labels**: Ensure your label set covers all possible categories without overlap.
89 | 
90 | 2. **Clear Instructions**: Provide specific criteria for each category in your instructions. For recipes, mention considering ingredients, cooking methods, and typical serving times.
91 | 
92 | 3. **Use Descriptions**: The `descriptions` field can help the model understand nuances between categories, especially for edge cases like brunch recipes.
93 | 
94 | 4. **Consider Confidence Scores**: Low confidence scores might indicate recipes that don't clearly fit into one category, such as versatile dishes that could be served at multiple meals.
95 | 
96 | 5. **Handling Edge Cases**: If you encounter many low-confidence classifications, you might need to refine your categories or instructions. For example, you might add an "Anytime" category for versatile recipes.
97 | 
98 | ## Advanced Usage: Customizing the Model
99 | 
100 | Depending on the complexity of your task, you may want to experiment with different LLM models by passing the `model_name` parameter to the classification provider:
101 | 
102 | ```python
103 | haiku_classification_provider = factory.get_page_classification_provider(
104 |     model_name="claude-3-haiku-20240307"
105 | )
106 | 
107 | sonnet_classification_provider = factory.get_page_classification_provider(
108 |     model_name="claude-3-5-sonnet-20240620"
109 | )
110 | ```
111 | 
112 | ## Conclusion
113 | 
114 | Single-label classification with DocPrompt provides a powerful way to categorize pages in your documents, such as recipes in a cookbook. By carefully crafting your `ClassificationInput` with clear labels, instructions, and descriptions, you can achieve accurate and efficient document analysis.
115 | 
116 | Remember that the quality of your results depends on the clarity of your instructions and the appropriateness of your label set. Experiment with different phrasings and label combinations to find what works best for your specific use case.
117 | 
118 | This approach can be easily adapted to other single-label classification tasks, such as categorizing scientific papers by field, sorting legal documents by type, or classifying news articles by topic.
119 | 
--------------------------------------------------------------------------------
/docs/guide/ocr/advanced_search.md:
--------------------------------------------------------------------------------
1 | # Lightning-Fast Document Search 🔥🚀
2 | 
3 | Ever wished you could search through OCR-processed documents at the speed of light? Look no further! DocPrompt's Provenance Locator, powered by Rust, offers blazingly fast text search capabilities that will revolutionize your document processing workflows.
4 | 
5 | ## The Power of Rust-Powered Search
6 | 
7 | DocPrompt's `DocumentProvenanceLocator` is not your average search tool. Built on the Rust-powered `tantivy` search engine and an `rtree` spatial index, it provides:
8 | 
9 | - ⚡ Lightning-fast full-text search
10 | - 🎯 Precise text location within documents
11 | - 🧠 Smart granularity refinement (word, line, block)
12 | - 🗺️ Spatial querying capabilities
13 | 
14 | Let's dive into how you can harness this power!
15 | 
16 | ## Setting Up the Locator
17 | 
18 | First, let's create a `DocumentProvenanceLocator` from a processed `DocumentNode`:
19 | 
20 | ```python
21 | from docprompt import load_document, DocumentNode
22 | from docprompt.tasks.factory import GCPTaskProviderFactory
23 | 
24 | # Load and process the document
25 | document = load_document("path/to/your/document.pdf")
26 | document_node = DocumentNode.from_document(document)
27 | 
28 | # Process with OCR (assuming you've set up the GCP factory)
29 | gcp_factory = GCPTaskProviderFactory(service_account_file="path/to/credentials.json")
30 | ocr_provider = gcp_factory.get_page_ocr_provider(
31 |     project_id="your-project-id",
32 |     processor_id="your-processor-id"
33 | )
34 | ocr_results = ocr_provider.process_document_node(document_node)
35 | 
36 | # Create the locator
37 | locator = document_node.locator
38 | ```
39 | 
40 | ## Searching at the Speed of Rust 🔥
41 | 
42 | Now that we have our locator, let's see it in action:
43 | 
44 | ```python
45 | # Perform a simple search
46 | results = locator.search("DocPrompt")
47 | 
48 | for result in results:
49 |     print(f"Found on page {result.page_number}")
50 |     print(f"Text: {result.text}")
51 |     print(f"Bounding box: {result.text_location.merged_source_block.bounding_box}")
52 |     print("---")
53 | ```
54 | 
55 | This search operation happens in milliseconds, even for large documents, thanks to the Rust-powered backend!
56 | 
57 | ## Advanced Search Capabilities
58 | 
59 | ### Refining to Word Level
60 | 
61 | DocPrompt can automatically refine search results to the word level:
62 | 
63 | ```python
64 | refined_results = locator.search("DocPrompt", refine_to_word=True)
65 | ```
66 | 
67 | This gives you pinpoint accuracy in locating text within your document.
68 | 
69 | ### Page-Specific Search
70 | 
71 | Need to search on a specific page? No problem:
72 | 
73 | ```python
74 | page_5_results = locator.search("DocPrompt", page_number=5)
75 | ```
76 | 
77 | ### Best Match Search
78 | 
79 | Find the best matches based on different criteria:
80 | 
81 | ```python
82 | best_short_matches = locator.search_n_best("DocPrompt", n=3, mode="shortest_text")
83 | 
84 | best_long_matches = locator.search_n_best("DocPrompt", n=3, mode="longest_text")
85 | 
86 | best_overall_matches = locator.search_n_best("DocPrompt", n=3, mode="highest_score")
87 | ```
88 | 
89 | ## Spatial Queries: Beyond Text Search 🗺️
90 | 
91 | DocPrompt's locator isn't just fast—it's spatially aware! You can perform queries based on document layout:
92 | 
93 | ```python
94 | from docprompt.schema.layout import NormBBox
95 | 
96 | # Get blocks near a specific area on page 1
97 | bbox = NormBBox(x0=0.1, top=0.1, x1=0.2, bottom=0.2)
98 | nearby_blocks = locator.get_k_nearest_blocks(bbox, page_number=1, k=5)
99 | 
100 | # Get overlapping blocks
101 | overlapping_blocks = locator.get_overlapping_blocks(bbox, page_number=1)
102 | ```
103 | 
104 | This spatial awareness opens up possibilities for advanced document analysis and data extraction!
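These building blocks compose naturally. For example, you can search for a field label and then read the text laid out around it. Here's a minimal sketch, assuming the label "Total:" actually appears in your document:

```python
# Find where a field label appears, then inspect the text next to it
matches = locator.search("Total:", refine_to_word=True)

if matches:
    hit = matches[0]
    bbox = hit.text_location.merged_source_block.bounding_box

    # Pull the closest surrounding blocks -- the field's value is usually among them
    neighbors = locator.get_k_nearest_blocks(bbox, page_number=hit.page_number, k=3)
    for block in neighbors:
        print(block.text)
```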
105 | 
106 | ## Conclusion: Search at the Speed of Thought 🧠💨
107 | 
108 | DocPrompt's `DocumentProvenanceLocator` brings unprecedented speed and precision to document search and analysis. By leveraging the power of Rust, it offers:
109 | 
110 | 1. Lightning-fast full-text search
111 | 2. Precise text location within documents
112 | 3. Advanced spatial querying capabilities
113 | 4. Scalability for large documents and datasets
114 | 
115 | Whether you're building a document analysis pipeline, a search system, or any text-based application, DocPrompt's Provenance Locator offers the speed and accuracy you need to stay ahead of the game.
116 | 
--------------------------------------------------------------------------------
/docs/guide/ocr/advanced_workflows.md:
--------------------------------------------------------------------------------
1 | # Advanced Document Analysis
2 | 
3 | 
4 | ### Detecting Potential Conflicts of Interest in 10-K Reports
5 | In this guide, we'll demonstrate how to use DocPrompt's powerful OCR and search capabilities to analyze 10-K reports for potential conflicts of interest. We'll search for mentions of company names and executive names, then identify instances where they appear in close proximity within the document.
6 | 
7 | ## Setup
8 | 
9 | First, let's set up our environment and process the document:
10 | 
11 | ```python
12 | from docprompt import load_document, DocumentNode
13 | from docprompt.tasks.factory import GCPTaskProviderFactory
14 | from docprompt.schema.layout import NormBBox
15 | from itertools import product
16 | 
17 | # Load and process the document
18 | document = load_document("path/to/10k_report.pdf")
19 | document_node = DocumentNode.from_document(document)
20 | 
21 | # Perform OCR
22 | gcp_factory = GCPTaskProviderFactory(service_account_file="path/to/credentials.json")
23 | ocr_provider = gcp_factory.get_page_ocr_provider(project_id="your-project-id", processor_id="your-processor-id")
24 | ocr_provider.process_document_node(document_node)
25 | 
26 | # Create the locator
27 | locator = document_node.locator
28 | 
29 | # Define entities to search for
30 | company_names = ["SubsidiaryA", "PartnerB", "CompetitorC"]
31 | executive_names = ["John Doe", "Jane Smith", "Alice Johnson"]
32 | ```
33 | 
34 | ## Searching for Entities
35 | 
36 | Now, let's use DocPrompt's fast search capabilities to find all mentions of companies and executives. By leveraging the speed of the Rust-powered locator, along with Python's built-in comprehensions, we can execute our set of queries over a several-hundred-page document in a matter of milliseconds.
37 | 
38 | ```python
39 | company_results = {
40 |     company: locator.search(company)
41 |     for company in company_names
42 | }
43 | 
44 | executive_results = {
45 |     executive: locator.search(executive)
46 |     for executive in executive_names
47 | }
48 | ```
49 | 
50 | ## Detecting Proximity
51 | 
52 | Next, we'll check for instances where company names and executive names appear in close proximity. Since `locator.search` returns a list of matches for each query, we compare every company match against every executive match:
53 | 
54 | ```python
55 | def check_proximity(bbox1, bbox2, threshold=0.1):
56 |     left_collision = abs(bbox1.x0 - bbox2.x0) < threshold
57 |     top_collision = abs(bbox1.top - bbox2.top) < threshold
58 | 
59 |     return left_collision and top_collision
60 | 
61 | potential_conflicts = []
62 | 
63 | for company, exec_name in product(company_names, executive_names):
64 |     # search() returns a list of results, so compare every pair of matches
65 |     for c_result, e_result in product(company_results[company], executive_results[exec_name]):
66 | 
67 |         # Check if the two results appear on the same page
68 |         if c_result.page_number != e_result.page_number:
69 |             continue
70 | 
71 |         c_bbox = c_result.text_location.merged_source_block.bounding_box
72 |         e_bbox = e_result.text_location.merged_source_block.bounding_box
73 | 
74 |         # If they do, check if the bounding boxes break our threshold
75 |         if check_proximity(c_bbox, e_bbox):
76 |             potential_conflicts.append({
77 |                 'company': company,
78 |                 'executive': exec_name,
79 |                 'page': c_result.page_number,
80 |                 'company_bbox': c_bbox,
81 |                 'exec_bbox': e_bbox
82 |             })
83 | ```
84 | 
85 | ## Analyzing Results
86 | 
87 | Finally, let's analyze and display our results:
88 | 
89 | ```python
90 | print(f"Found {len(potential_conflicts)} potential conflicts of interest:")
91 | 
92 | for conflict in potential_conflicts:
93 |     print(f"\nPotential conflict on page {conflict['page']}:")
94 |     print(f"  Company: {conflict['company']}")
95 |     print(f"  Executive: {conflict['executive']}")
96 | 
97 |     # Get surrounding context
98 |     context_bbox = NormBBox(
99 |         x0=min(conflict['company_bbox'].x0, conflict['exec_bbox'].x0) - 0.05,
100 |         top=min(conflict['company_bbox'].top, conflict['exec_bbox'].top) - 0.05,
101 |         x1=max(conflict['company_bbox'].x1, conflict['exec_bbox'].x1) + 0.05,
102 |         bottom=max(conflict['company_bbox'].bottom, conflict['exec_bbox'].bottom) + 0.05
103 |     )
104 | 
105 |     context_blocks = locator.get_overlapping_blocks(context_bbox, conflict['page'])
106 | 
107 |     print("  Context:")
108 |     for block in context_blocks:
109 |         print(f"    {block.text}")
110 | ```
111 | 
112 | This approach demonstrates several key features of DocPrompt:
113 | 
114 | 1. **Fast and Accurate Search**: We use the `DocumentProvenanceLocator` to quickly find all mentions of companies and executives across the entire document.
115 | 
116 | 2. **Spatial Analysis**: By leveraging the bounding box information, we can determine when two entities are mentioned in close proximity on the page.
117 | 
118 | 3. **Contextual Information**: We use spatial queries to extract the surrounding text, providing context for each potential conflict of interest.
119 | 
120 | 4. **Scalability**: This approach can easily handle multiple companies and executives, making it suitable for analyzing large, complex documents.
121 | 
122 | By combining these capabilities, DocPrompt enables efficient and thorough analysis of 10-K reports, helping to identify potential conflicts of interest that might otherwise be overlooked in manual review processes.
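As an optional post-processing step, you might aggregate the hits to see which relationships recur most often, which can help prioritize manual review. A minimal sketch over the `potential_conflicts` list built above:

```python
from collections import Counter

# Count how many pages each company/executive pairing appears on
pair_counts = Counter(
    (conflict['company'], conflict['executive']) for conflict in potential_conflicts
)

print("Most frequent company/executive co-mentions:")
for (company, executive), count in pair_counts.most_common(3):
    print(f"  {company} & {executive}: {count} page(s)")
```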
121 | -------------------------------------------------------------------------------- /docs/guide/ocr/basic_usage.md: -------------------------------------------------------------------------------- 1 | # Basic OCR Usage with Docprompt 2 | 3 | This guide will walk you through the basics of performing Optical Character Recognition (OCR) using Docprompt. You'll learn how to set up the OCR provider, process a document, and access the results. 4 | 5 | ## Prerequisites 6 | 7 | Before you begin, ensure you have: 8 | 9 | 1. Installed Docprompt with OCR support: `pip install "docprompt[google]"` 10 | 2. A Google Cloud Platform account with Document AI API enabled 11 | 3. A GCP service account key file 12 | 13 | ## Setting Up the OCR Provider 14 | 15 | First, let's set up the Google OCR provider: 16 | 17 | ```python 18 | from docprompt.tasks.factory import GCPTaskProviderFactory 19 | 20 | # Initialize the GCP Task Provider Factory 21 | gcp_factory = GCPTaskProviderFactory( 22 | service_account_file="path/to/your/service_account_key.json" 23 | ) 24 | 25 | # Create the OCR provider 26 | ocr_provider = gcp_factory.get_page_ocr_provider( 27 | project_id="your-gcp-project-id", 28 | processor_id="your-document-ai-processor-id" 29 | ) 30 | ``` 31 | 32 | ## Loading and Processing a Document 33 | 34 | Now, let's load a document and process it using OCR: 35 | 36 | ```python 37 | from docprompt import load_document, DocumentNode 38 | 39 | # Load the document 40 | document = load_document("path/to/your/document.pdf") 41 | document_node = DocumentNode.from_document(document) 42 | 43 | # Process the document 44 | ocr_results = ocr_provider.process_document_node(document_node) 45 | ``` 46 | 47 | ## Accessing OCR Results 48 | 49 | After processing, you can access the OCR results in various ways: 50 | 51 | ### 1. Page-level Text 52 | 53 | To get the full text of a specific page: 54 | 55 | ```python 56 | page_number = 1 # Pages are 1-indexed 57 | page_text = ocr_results[page_number].page_text 58 | print(f"Text on page {page_number}:\n{page_text[:500]}...") # Print first 500 characters 59 | ``` 60 | 61 | ### 2. Words, Lines, and Blocks 62 | 63 | Docprompt provides access to words, lines, and blocks (paragraphs) extracted from the document: 64 | 65 | ```python 66 | # Get the first page's result 67 | first_page_result = ocr_results[1] 68 | 69 | # Print the first 5 words on the page 70 | print("First 5 words:") 71 | for word in first_page_result.word_level_blocks[:5]: 72 | print(f"Word: {word.text}, Confidence: {word.metadata.confidence}") 73 | 74 | # Print the first line on the page 75 | print("\nFirst line:") 76 | if first_page_result.line_level_blocks: 77 | first_line = first_page_result.line_level_blocks[0] 78 | print(f"Line: {first_line.text}") 79 | 80 | # Print the first block (paragraph) on the page 81 | print("\nFirst block:") 82 | if first_page_result.block_level_blocks: 83 | first_block = first_page_result.block_level_blocks[0] 84 | print(f"Block: {first_block.text[:100]}...") # Print first 100 characters 85 | ``` 86 | 87 | ### 3. 
Bounding Boxes 88 | 89 | Each word, line, and block comes with bounding box information: 90 | 91 | ```python 92 | # Get bounding box for the first word 93 | if first_page_result.word_level_blocks: 94 | first_word = first_page_result.word_level_blocks[0] 95 | bbox = first_word.bounding_box 96 | print(f"\nBounding box for '{first_word.text}':") 97 | print(f"Top-left: ({bbox.x0}, {bbox.top})") 98 | print(f"Bottom-right: ({bbox.x1}, {bbox.bottom})") 99 | ``` 100 | 101 | ## Conclusion 102 | 103 | You've now learned the basics of performing OCR with Docprompt. This includes setting up the OCR provider, processing a document, and accessing the results at different levels of granularity. 104 | 105 | For more advanced usage, including customizing OCR settings, using other providers, and leveraging the powerful search capabilities, check out our other guides: 106 | 107 | - Customizing OCR Providers and Settings 108 | - Advanced Text Search with Provenance Locators 109 | - Building OCR-based Workflows 110 | -------------------------------------------------------------------------------- /docs/guide/ocr/provider_config.md: -------------------------------------------------------------------------------- 1 | # Customizing OCR Providers 2 | 3 | Docprompt uses a factory pattern to manage credentials and create task providers efficiently. This guide will demonstrate how to configure and customize OCR providers, focusing on Amazon Textract and Google Cloud Platform (GCP) as examples. 4 | 5 | ## Understanding the Factory Pattern 6 | 7 | Docprompt uses task provider factories to manage credentials and create providers for various tasks. This approach allows for: 8 | 9 | 1. Centralized credential management 10 | 2. Easy creation of multiple task providers from a single backend 11 | 3. Separation of provider-specific and task-specific configurations 12 | 13 | Here's a simplified example of how the factory pattern works: 14 | 15 | ```python 16 | from docprompt.tasks.factory import GCPTaskProviderFactory, AmazonTaskProviderFactory 17 | 18 | # Create a GCP factory 19 | gcp_factory = GCPTaskProviderFactory( 20 | service_account_file="path/to/service_account.json" 21 | ) 22 | 23 | # Create an Amazon factory 24 | amazon_factory = AmazonTaskProviderFactory( 25 | aws_access_key_id="YOUR_ACCESS_KEY", 26 | aws_secret_access_key="YOUR_SECRET_KEY", 27 | region_name="us-west-2" 28 | ) 29 | ``` 30 | 31 | ## Creating OCR Providers 32 | 33 | Once you have a factory, you can create OCR providers with task-specific configurations: 34 | 35 | ```python 36 | # Create a GCP OCR provider 37 | gcp_ocr_provider = gcp_factory.get_page_ocr_provider( 38 | project_id="YOUR_PROJECT_ID", 39 | processor_id="YOUR_PROCESSOR_ID", 40 | max_workers=4, 41 | return_images=True 42 | ) 43 | 44 | # Create an Amazon Textract provider 45 | amazon_ocr_provider = amazon_factory.get_page_ocr_provider( 46 | max_workers=4, 47 | exclude_bounding_poly=True 48 | ) 49 | ``` 50 | 51 | ## Understanding Provider Configuration 52 | 53 | When configuring OCR providers, you'll encounter two types of parameters: 54 | 55 | 1. **Docprompt generic parameters**: These are common across different providers and control Docprompt's behavior. 56 | - `max_workers`: Controls concurrency for processing large documents 57 | - `exclude_bounding_poly`: Reduces memory usage by excluding detailed polygon data 58 | 59 | 2. **Provider-specific parameters**: These are unique to each backend and control provider-specific features. 
For example, if using GCP as an OCR provider, you must specify `project_id` and `processor_id`, and you may optionally set `return_image_quality_scores`.
60 | 
61 | ## Provider-Specific Features and Limitations
62 | 
63 | ### Google Cloud Platform (GCP)
64 | - Offers advanced layout analysis and image quality scoring
65 | - Supports returning rasterized images of processed pages
66 | - Requires GCP-specific project and processor IDs
67 | 
68 | Example configuration:
69 | ```python
70 | gcp_ocr_provider = gcp_factory.get_page_ocr_provider(
71 |     project_id="YOUR_PROJECT_ID",
72 |     processor_id="YOUR_PROCESSOR_ID",
73 |     max_workers=4,  # Docprompt generic
74 |     return_images=True,  # GCP-specific
75 |     return_image_quality_scores=True  # GCP-specific
76 | )
77 | ```
78 | 
79 | ### Amazon Textract
80 | - Focuses on text extraction and layout analysis
81 | - Provides confidence scores for extracted text
82 | - Does not support returning rasterized images
83 | 
84 | Example configuration:
85 | ```python
86 | amazon_ocr_provider = amazon_factory.get_page_ocr_provider(
87 |     max_workers=4,  # Docprompt generic
88 |     exclude_bounding_poly=True  # Docprompt generic
89 | )
90 | ```
91 | 
92 | ## Best Practices
93 | 
94 | 1. **Use factories for credential management**: This centralizes authentication and makes it easier to switch between providers.
95 | 
96 | 2. **Consult provider documentation**: Always refer to the latest documentation from AWS or GCP for the most up-to-date information on their OCR services.
97 | 
98 | 3. **Check Docprompt API reference**: Review Docprompt's API documentation for each provider to understand available configurations.
99 | 
100 | 4. **Optimize for your use case**: Configure providers based on your specific needs, balancing performance and feature requirements.
101 | 
102 | ## Conclusion
103 | 
104 | Understanding the factory pattern and the distinction between Docprompt generic and provider-specific parameters is key to effectively configuring OCR providers in Docprompt. While this guide provides an overview using Amazon Textract and GCP as examples, the principles apply to other providers as well. Always consult the specific provider's documentation and Docprompt's API reference for the most current and detailed information.
105 | 
--------------------------------------------------------------------------------
/docs/guide/table_extraction/extract_tables.md:
--------------------------------------------------------------------------------
1 | # Table Extraction with DocPrompt: Invoice Parsing
2 | 
3 | DocPrompt can be used to extract tables from documents with high accuracy using visual large language models, such as GPT-4 Vision or Anthropic's Claude 3. In this guide, we'll demonstrate how to extract tables from invoices using DocPrompt.
4 | 
5 | ## Setting Up
6 | 
7 | First, let's import the necessary modules and set up our environment:
8 | 
9 | ```python
10 | from docprompt import load_document_node
11 | from docprompt.tasks.factory import AnthropicTaskProviderFactory
12 | 
13 | # Initialize the Anthropic factory
14 | # Ensure you have set the ANTHROPIC_API_KEY environment variable
15 | factory = AnthropicTaskProviderFactory()
16 | 
17 | # Create the table extraction provider
18 | table_extraction_provider = factory.get_page_table_extraction_provider()
19 | ```
20 | 
21 | ## Preparing the Document
22 | 
23 | Load a `DocumentNode` from a path:
24 | 
25 | ```python
26 | document_node = load_document_node("path/to/your/invoice.pdf")
27 | ```
28 | 
29 | ## Performing Table Extraction
30 | 
31 | Now, let's run the table extraction task on our invoice:
32 | 
33 | ```python
34 | results = table_extraction_provider.process_document_node(document_node)  # Sync
35 | 
36 | async_results = await table_extraction_provider.aprocess_document_node(document_node)  # Async
37 | ```
38 | 
39 | As shown above, table extraction can be run either synchronously or asynchronously via `aprocess_document_node`.
40 | 
41 | ## Interpreting Results
42 | 
43 | Let's examine the extracted tables from a sample invoice:
44 | 
45 | ```python
46 | for page_number, result in results.items():
47 |     print(f"Tables extracted from Page {page_number}:")
48 |     for i, table in enumerate(result.tables, 1):
49 |         print(f"\nTable {i}:")
50 |         print(f"Title: {table.title}")
51 |         print("Headers:")
52 |         print(", ".join(header.text for header in table.headers))
53 |         print("Rows:")
54 |         for row in table.rows:
55 |             print(", ".join(cell.text for cell in row.cells))
56 |     print('---')
57 | ```
58 | 
59 | This will print the extracted tables, including headers and rows, for each page of the invoice.
60 | 
61 | ## Increasing Accuracy
62 | 
63 | In Anthropic's case, the default is `"claude-3-haiku-20240307"`. This performs with high accuracy and is over 5x cheaper than table extraction using Azure Document Intelligence.
64 | 
65 | In use cases where accuracy is paramount, however, it may be worthwhile to set the provider to a more powerful model.
66 | 
67 | ```python
68 | table_extraction_provider = factory.get_page_table_extraction_provider(
69 |     model_name="claude-3-5-sonnet-20240620"  # set up the task provider with Sonnet 3.5
70 | )
71 | 
72 | results = table_extraction_provider.process_document_node(
73 |     document_node,
74 |     model_name="claude-3-5-sonnet-20240620"  # or declare the model name at inference time
75 | )
76 | ```
77 | 
78 | As large language models steadily get cheaper and more capable, your inference costs will inevitably drop. The beauty of progress!
79 | 
80 | 
81 | ## Resolving Bounding Boxes
82 | 
83 | **Coming Soon**
84 | 
85 | In some scenarios, you may want the exact bounding boxes of the various rows, columns, and cells. If you've processed OCR results through Docprompt, this is possible by specifying an additional argument in `process_document_node`:
86 | 
87 | ```python
88 | results = table_extraction_provider.process_document_node(
89 |     document_node,
90 |     model_name="claude-3-5-sonnet-20240620",
91 |     resolve_bounding_boxes=True
92 | )
93 | ```
94 | 
95 | If you've collected and stored OCR results on the DocumentNode, this will use word-level bounding boxes coupled with the Docprompt search engine to determine the bounding boxes of the resulting tables, where possible.
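Once extracted, tables are easy to convert into plain Python structures for downstream use. Here's a minimal sketch; the `table_to_records` helper is hypothetical and relies only on the `headers`/`rows`/`cells` fields shown above:

```python
import csv

def table_to_records(table):
    """Convert an extracted table into a list of {header: cell_text} dicts."""
    headers = [header.text for header in table.headers]
    return [
        dict(zip(headers, (cell.text for cell in row.cells)))
        for row in table.rows
    ]

# e.g. dump the first table found on page 1 to a CSV file
records = table_to_records(results[1].tables[0])
with open("invoice_table.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=records[0].keys())
    writer.writeheader()
    writer.writerows(records)
```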
99 | 
100 | ## Conclusion
101 | 
102 | Table extraction with DocPrompt provides a powerful way to automatically parse structured data from documents containing tabular information, in just a few lines of code.
103 | 
104 | The quality of your results depends on the model and the complexity of the table layouts. Experiment with different configurations and post-processing steps to find what works best for your specific use case.
105 | 
106 | When combined with other tasks such as classification, layout analysis, and markerization, you can build powerful document processing pipelines in just a few steps.
107 | 
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # Docprompt - Getting Started
2 | 
3 | ## Supercharged Document Analysis
4 | 
5 | * Common utilities for interacting with PDFs
6 |     * PDF loading and serialization
7 |     * PDF byte compression using Ghostscript :ghost:
8 |     * Fast rasterization :fire: :rocket:
9 |     * Page splitting, re-export with PDFium
10 | * Support for most OCR providers with batched inference
11 |     * Google :white_check_mark:
12 |     * Azure Document Intelligence :red_circle:
13 |     * Amazon Textract :red_circle:
14 |     * Tesseract :red_circle:
15 | 
16 | 
17 | 
18 | ### Installation
19 | 
20 | Base installation
21 | 
22 | ```bash
23 | pip install docprompt
24 | ```
25 | 
26 | With an OCR provider
27 | 
28 | ```bash
29 | pip install "docprompt[google]"
30 | ```
31 | 
32 | ## Usage
33 | 
34 | 
35 | ### Simple Operations
36 | ```python
37 | from docprompt import load_document
38 | 
39 | # Load a document
40 | document = load_document("path/to/my.pdf")
41 | 
42 | # Rasterize a single page using Ghostscript
43 | page_number = 5
44 | rastered = document.rasterize_page(page_number, dpi=120)
45 | 
46 | # Split a pdf based on a page range
47 | document_2 = document.split(start=125, stop=130)
48 | ```
49 | 
50 | ### Performing OCR
51 | ```python
52 | from docprompt import load_document, DocumentNode
53 | from docprompt.tasks.ocr.gcp import GoogleOcrProvider
54 | 
55 | provider = GoogleOcrProvider.from_service_account_file(
56 |     project_id=my_project_id,
57 |     processor_id=my_processor_id,
58 |     service_account_file=path_to_service_file
59 | )
60 | 
61 | document = load_document("path/to/my.pdf")
62 | 
63 | # A container holds derived data for a document, like OCR or classification results
64 | document_node = DocumentNode.from_document(document)
65 | 
66 | provider.process_document_node(document_node) # Caches results on the document_node
67 | 
68 | document_node[0].ocr_result # Access OCR results
69 | ```
70 | 
71 | 
72 | ### Document Search
73 | 
74 | When a large language model returns a result, we might want to highlight that result for our users. However, language models return results as **text**, while what we need to show our users requires a page number and a bounding box.
75 | 
76 | After extracting text from a PDF, we can support this pattern using `DocumentProvenanceLocator`, which lives on a `DocumentNode`:
77 | 
78 | ```python
79 | from docprompt import load_document, DocumentNode
80 | from docprompt.tasks.ocr.gcp import GoogleOcrProvider
81 | 
82 | provider = GoogleOcrProvider.from_service_account_file(
83 |     project_id=my_project_id,
84 |     processor_id=my_processor_id,
85 |     service_account_file=path_to_service_file
86 | )
87 | 
88 | document = load_document("path/to/my.pdf")
89 | 
90 | # A container holds derived data for a document, like OCR or classification results
91 | document_node = DocumentNode.from_document(document)
92 | 
93 | provider.process_document_node(document_node) # Caches results on the document_node
94 | 
95 | # With OCR results available, we can now instantiate a locator and search through documents.
96 | 
97 | document_node.locator.search("John Doe") # Returns a list of all matches across the document that contain "John Doe"
98 | document_node.locator.search("Jane Doe", page_number=4) # Returns only matching results from page 4
99 | ```
100 | 
101 | This functionality uses a combination of `rtree` and the Rust library `tantivy`, allowing you to perform thousands of searches in **seconds** :fire: :rocket:
102 | 
--------------------------------------------------------------------------------
/makefile:
--------------------------------------------------------------------------------
1 | sources = docprompt
2 | 
3 | .PHONY: test format lint unittest coverage pre-commit clean
4 | test: format lint unittest
5 | 
6 | format:
7 | 	isort $(sources) tests
8 | 	black $(sources) tests
9 | 
10 | lint:
11 | 	flake8 $(sources) tests
12 | 	mypy $(sources) tests
13 | 
14 | unittest:
15 | 	pytest
16 | 
17 | coverage:
18 | 	pytest --cov=$(sources) --cov-branch --cov-report=term-missing tests
19 | 
20 | pre-commit:
21 | 	pre-commit run --all-files
22 | 
23 | clean:
24 | 	rm -rf .mypy_cache .pytest_cache
25 | 	rm -rf *.egg-info
26 | 	rm -rf .tox dist site
27 | 	rm -rf coverage.xml .coverage
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: Docprompt
2 | site_url: https://docs.docprompt.io
3 | repo_url: https://github.com/docprompt/Docprompt
4 | theme:
5 |   name: material
6 |   logo: assets/static/img/logo.png
7 |   favicon: assets/static/img/logo.png
8 |   features:
9 |     - navigation.instant
10 |     - navigation.instant.prefetch
11 | plugins:
12 |   - search
13 |   - blog
14 |   - gen-files:
15 |       scripts:
16 |         - docs/gen_ref_pages.py
17 |   - literate-nav:
18 |       nav_file: SUMMARY.md
19 |   - mkdocstrings:
20 |       handlers:
21 |         python:
22 |           paths: [docprompt]
23 |           options:
24 |             docstring_style: google
25 |             show_source: true
26 |             show_submodules: true
27 | 
28 | nav:
29 |   - Getting Started: index.md
30 |   - How-to Guides:
31 |       - Perform OCR:
32 |           - Basic OCR Usage: guide/ocr/basic_usage.md
33 |           - Customizing OCR Providers: guide/ocr/provider_config.md
34 |           - Lightning-Fast Doc Search: guide/ocr/advanced_search.md
35 |           - OCR-based Workflows: guide/ocr/advanced_workflows.md
36 |       - Classify Pages:
37 |           - Binary Classification: guide/classify/binary.md
38 |           - Single-Label Classification: guide/classify/single.md
39 |           - Multi-Label Classification: guide/classify/multi.md
40 |       - Extract Tables: guide/table_extraction/extract_tables.md
41 |   - Concepts:
42 |       - Primitives: concepts/primatives.md
43 |       - Nodes: concepts/nodes.md
44 |       - Providers: concepts/providers.md
concepts/providers.md 45 | - Provenance: concepts/provenance.md 46 | - Cloud: enterprise.md 47 | - Blog: 48 | - blog/index.md 49 | - API Reference: 50 | - Docprompt SDK: reference/ 51 | - Enterprise API: enterprise/ 52 | - Community: 53 | - Contributing: community/contributing.md 54 | - Versioning: community/versioning.md 55 | 56 | markdown_extensions: 57 | - pymdownx.highlight: 58 | anchor_linenums: true 59 | - pymdownx.superfences 60 | - pymdownx.emoji: 61 | emoji_index: !!python/name:materialx.emoji.twemoji 62 | emoji_generator: !!python/name:materialx.emoji.to_svg 63 |
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["pdm-backend"] 3 | build-backend = "pdm.backend" 4 | 5 | [project] 6 | name = "docprompt" 7 | version = "0.8.7" 8 | description = "Documents and large language models." 9 | authors = [ 10 | {name = "Frank Colson", email = "frank@pageleaf.io"} 11 | ] 12 | dependencies = [ 13 | "pillow>=9.0.1", 14 | "tqdm>=4.50.2", 15 | "fsspec>=2022.11.0", 16 | "pydantic>=2.1.0", 17 | "tenacity>=7.0.0", 18 | "pypdfium2<5.0.0,>=4.28.0", 19 | "filetype>=1.2.0", 20 | "beautifulsoup4>=4.12.3", 21 | "pypdf>=5.0.0" 22 | ] 23 | requires-python = "<3.13,>=3.8.6" 24 | readme = "README.md" 25 | license = {text = "Apache-2.0"} 26 | classifiers = ["Development Status :: 2 - Pre-Alpha", "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", "Natural Language :: English", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12"] 27 | 28 | [project.optional-dependencies] 29 | google = ["google-cloud-documentai>=2.20.0"] 30 | azure = ["azure-ai-formrecognizer>=3.3.0"] 31 | search = [ 32 | "tantivy<1.0.0,>=0.21.0", 33 | "rtree<3.0.0,>=1.2.0", 34 | "networkx<4.0.0,>=2.5.0", 35 | "rapidfuzz>=3.0.0" 36 | ] 37 | anthropic = [ 38 | "anthropic>=0.26.0" 39 | ] 40 | openai = [ 41 | "openai>=1.0.1" 42 | ] 43 | aws = [ 44 | "aioboto3>=13.1.0", 45 | "boto3>=1.18.0" 46 | ] 47 | 48 | [project.scripts] 49 | docprompt = "docprompt.cli:main" 50 | 51 | [project.urls] 52 | homepage = "https://github.com/Docprompt/docprompt" 53 | 54 | [tool.black] 55 | line-length = 120 56 | skip-string-normalization = true 57 | target-version = ['py39', 'py310', 'py311'] 58 | include = '\.pyi?$' 59 | exclude = ''' 60 | /( 61 | \.eggs 62 | | \.git 63 | | \.hg 64 | | \.mypy_cache 65 | | \.tox 66 | | \.venv 67 | | _build 68 | | buck-out 69 | | build 70 | | dist 71 | )/ 72 | ''' 73 | 74 | [tool.flake8] 75 | ignore = [ 76 | "E501" 77 | ] 78 | 79 | [tool.isort] 80 | multi_line_output = 3 81 | include_trailing_comma = true 82 | force_grid_wrap = 0 83 | use_parentheses = true 84 | ensure_newline_before_comments = true 85 | line_length = 120 86 | skip_gitignore = true 87 | 88 | [tool.pdm] 89 | distribution = true 90 | 91 | [tool.pdm.build] 92 | includes = ["docprompt", "tests"] 93 | 94 | [tool.pdm.dev-dependencies] 95 | test = [ 96 | "isort<6.0.0,>=5.12.0", 97 | "flake8<7.0.0,>=6.1.0", 98 | "flake8-docstrings<2.0.0,>=1.7.0", 99 | "mypy<2.0.0,>=1.6.1", 100 | "pytest<8.0.0,>=7.4.2", 101 | "pytest-cov<5.0.0,>=4.1.0", 102 | "ruff<1.0.0,>=0.3.3", 103 | "pytest-asyncio>=0.23.7" 104 | ] 105 | dev = [ 106 | "tox<4.0.0,>=3.20.1", 107 | "virtualenv<25.0.0,>=20.2.2", 108 | "pip<21.0.0,>=20.3.1", 109 | "twine<4.0.0,>=3.3.0", 110 | "pre-commit>=3.5.0", 111 | "toml<1.0.0,>=0.10.2", 112 | "bump2version<2.0.0,>=1.0.1", 113 | "ipython>=7.12.0", 114 | "python-dotenv>=1.0.1", 115 | "ipykernel>=6.29.4", 116 | "pytest-asyncio>=0.23.7" 117 | ] 118 | docs = [ 119 | "mkdocs>=1.6.0", 120 | "mkdocs-material>=9.5.27", 121 | "mkdocstrings[python]>=0.25.1", 122 | "mkdocs-blog-plugin>=0.25", 123 | "mkdocs-gen-files>=0.5.0", 124 | "mkdocs-literate-nav>=0.6.1" 125 | ] 126 | 127 | [tool.pdm.scripts] 128 | docs = "mkdocs serve" 129 | lint = "pre-commit run --all-files" 130 | cov = {shell = "python tests/_run_tests_with_coverage.py {args}"} 131 | 132 | [tool.ruff] 133 | target-version = "py38" 134 | 135 | [tool.ruff.lint] 136 | # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. 137 | # Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or 138 | # McCabe complexity (`C901`) by default. 139 | select = ["E4", "E7", "E9", "F"] 140 | extend-select = ["I"] 141 | ignore = ["D212"] 142 | # Allow fix for all enabled rules (when `--fix` is provided). 143 | fixable = ["ALL"] 144 | unfixable = [] 145 | # Allow unused variables when underscore-prefixed. 146 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 147 |
-------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | max-complexity = 18 4 | ignore = E203, E266, W503 5 | docstring-convention = google 6 | per-file-ignores = __init__.py:F401 7 | exclude = .git, 8 | __pycache__, 9 | setup.py, 10 | build, 11 | dist, 12 | docs, 13 | releases, 14 | .venv, 15 | .tox, 16 | .mypy_cache, 17 | .pytest_cache, 18 | .vscode, 19 | .github, 20 | # By default, test code will be linted. 21 | # tests 22 | 23 | [mypy] 24 | ignore_missing_imports = True 25 | 26 | [coverage:run] 27 | # uncomment the following to omit files during running 28 | #omit = 29 | [coverage:report] 30 | exclude_lines = 31 | pragma: no cover 32 | def __repr__ 33 | if self.debug: 34 | if settings.DEBUG 35 | raise AssertionError 36 | raise NotImplementedError 37 | if 0: 38 | if __name__ == .__main__.: 39 | def main 40 | 41 | [tox:tox] 42 | isolated_build = true 43 | envlist = py39, py310, py311, py312, format, lint, build 44 | 45 | [gh-actions] 46 | python = 47 | 3.11: py311, format, lint, build 48 | 3.10: py310 49 | 3.9: py39 50 | 51 | [testenv] 52 | allowlist_externals = pytest 53 | extras = 54 | test 55 | passenv = * 56 | setenv = 57 | PYTHONPATH = {toxinidir} 58 | PYTHONWARNINGS = ignore 59 | commands = 60 | pytest --cov=docprompt --cov-branch --cov-report=xml --cov-report=term-missing tests 61 | 62 | [testenv:format] 63 | allowlist_externals = 64 | isort 65 | black 66 | extras = 67 | test 68 | commands = 69 | isort docprompt 70 | black docprompt tests 71 | 72 | [testenv:lint] 73 | allowlist_externals = 74 | flake8 75 | mypy 76 | extras = 77 | test 78 | commands = 79 | flake8 docprompt tests 80 | mypy docprompt tests 81 | 82 | [testenv:build] 83 | allowlist_externals = 84 | pdm 85 | mkdocs 86 | twine 87 | extras = 88 | doc 89 | dev 90 | commands = 91 | pdm build 92 | mkdocs build 93 | twine check dist/* 94 |
-------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Unit test package for docprompt.""" 2 |
-------------------------------------------------------------------------------- /tests/_run_tests_with_coverage.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import sys 3 | 4 | 5 | def main(): 6 | args = sys.argv[1:] 7 | module = "docprompt" # Default module 8 | pytest_args = [] 9 | 10 | # Parse arguments 11 | for arg in args: 12 | if arg.startswith("--mod="): 13 | module = arg.split("=")[1] 14 | else: 15 | pytest_args.append(arg) 16 | 17 | # Construct the pytest command 18 | command = [ 19 | "pytest", 20 | f"--cov={module}", 21 | "--cov-report=term-missing", 22 | ] + pytest_args 23 | 24 | # Run the command 25 | subprocess.run(command) 26 | 27 | 28 | if __name__ == "__main__": 29 | main() 30 |
-------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | def pytest_configure(config): 2 | config.addinivalue_line( 3 | "filterwarnings", "ignore:The NumPy module was reloaded:UserWarning" 4 | ) 5 |
-------------------------------------------------------------------------------- /tests/fixtures.py: -------------------------------------------------------------------------------- 1 | from .util import PdfFixture 2 | 3 | PDF_FIXTURES = [ 4 | PdfFixture( 5 | name="1.pdf", 6 | page_count=6, 7 | file_hash="121ffed4336e6129e97ee3c4cb747864", 8 | ocr_name="1_ocr.json", 9 | ), 10 | PdfFixture( 11 | name="2.pdf", 12 | page_count=23, 13 | file_hash="bd2fa4f101b305e4001acf9137ce78cf", 14 | ocr_name="1_ocr.json", 15 | ), 16 | ] 17 |
-------------------------------------------------------------------------------- /tests/fixtures/1.pdf: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/tests/fixtures/1.pdf -------------------------------------------------------------------------------- /tests/fixtures/2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/tests/fixtures/2.pdf -------------------------------------------------------------------------------- /tests/schema/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/tests/schema/__init__.py -------------------------------------------------------------------------------- /tests/schema/pipeline/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/tests/schema/pipeline/__init__.py -------------------------------------------------------------------------------- /tests/schema/pipeline/test_imagenode.py: -------------------------------------------------------------------------------- 1 | from docprompt.schema.pipeline import ImageNode 2 | 3 | 4 | def test_imagenode(): 5 | ImageNode(image=b"test", metadata={}) 6 | -------------------------------------------------------------------------------- /tests/schema/pipeline/test_layoutaware.py: -------------------------------------------------------------------------------- 1 | from tests.fixtures import PDF_FIXTURES 2 | 3 | 4 | def test_direct__page_node_layout_aware_text(): 5 | # Create a sample PageNode with some TextBlocks 6 | fixture = PDF_FIXTURES[0] 7 | 8 | document = fixture.get_document_node() 9 | 10 | page = document.page_nodes[0] 11 | 12 | assert page.ocr_results, "The OCR results should be populated" 13 | 14 | layout_text__property = page.layout_aware_text 15 | 16 | layout_len = 4786 17 | 18 | assert ( 19 | len(layout_text__property) == layout_len 20 | ), f"The layout-aware text should be {layout_len} characters long" 21 | 22 | layout_text__direct = page.get_layout_aware_text() 23 | 24 | assert ( 25 | layout_text__property == layout_text__direct 26 | ), "The layout-aware text should be the same" 27 | -------------------------------------------------------------------------------- /tests/schema/test_layout_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from docprompt import NormBBox 4 | from docprompt.schema.layout import BoundingPoly, Point 5 | 6 | 7 | def test_normbbox_utilities(): 8 | bbox = NormBBox(x0=0, top=0, x1=1, bottom=1) 9 | 10 | # Test simple properties 11 | assert bbox[0] == 0 12 | assert bbox[1] == 0 13 | assert bbox[2] == 1 14 | assert bbox[3] == 1 15 | 16 | assert bbox.width == 1 17 | assert bbox.height == 1 18 | assert bbox.area == 1 19 | assert bbox.centroid == (0.5, 0.5) 20 | assert bbox.y_center == 0.5 21 | assert bbox.x_center == 0.5 22 | 23 | # Test equality 24 | assert bbox == NormBBox(x0=0, top=0, x1=1, bottom=1) 25 | 26 | # Test out of bounds 27 | 28 | with pytest.raises(ValueError): 29 | NormBBox(x0=0, top=0, x1=1, bottom=2) 30 | 31 | # Add two bboxes 32 | 33 | bbox_2 = NormBBox(x0=0.5, top=0.5, x1=1, bottom=1.0) 34 | combined_bbox = bbox + bbox_2 35 | assert combined_bbox == NormBBox(x0=0, top=0, x1=1.0, bottom=1.0) 36 | 37 | # Add two bboxes via combine 38 | 39 | combined_bbox = 
NormBBox.combine(bbox, bbox_2) 40 | assert combined_bbox == NormBBox(x0=0, top=0, x1=1.0, bottom=1.0) 41 | 42 | # Test from bounding poly 43 | bounding_poly = BoundingPoly( 44 | normalized_vertices=[ 45 | Point(x=0, y=0), 46 | Point(x=1, y=0), 47 | Point(x=1, y=1), 48 | Point(x=0, y=1), 49 | ] 50 | ) 51 | 52 | bbox = NormBBox.from_bounding_poly(bounding_poly) 53 | 54 | assert bbox == NormBBox(x0=0, top=0, x1=1, bottom=1) 55 | 56 | # Test contains 57 | 58 | small_bbox = NormBBox(x0=0.25, top=0.25, x1=0.75, bottom=0.75) 59 | big_bbox = NormBBox(x0=0, top=0, x1=1, bottom=1) 60 | 61 | assert small_bbox in big_bbox 62 | assert big_bbox not in small_bbox 63 | 64 | # Test Overlap 65 | 66 | assert small_bbox.x_overlap(big_bbox) == 0.5 67 | assert small_bbox.y_overlap(big_bbox) == 0.5 68 | 69 | # Should be commutative 70 | assert big_bbox.x_overlap(small_bbox) == 0.5 71 | assert big_bbox.y_overlap(small_bbox) == 0.5 72 | -------------------------------------------------------------------------------- /tests/tasks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/tests/tasks/__init__.py -------------------------------------------------------------------------------- /tests/tasks/classification/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the functionality of the classification task provider. 3 | """ 4 | -------------------------------------------------------------------------------- /tests/tasks/classification/test_anthropic.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import AsyncMock, patch 2 | 3 | import pytest 4 | 5 | from docprompt.tasks.classification.anthropic import ( 6 | AnthropicClassificationProvider, 7 | AnthropicPageClassificationOutputParser, 8 | _prepare_messages, 9 | ) 10 | from docprompt.tasks.classification.base import ( 11 | ClassificationConfig, 12 | ClassificationOutput, 13 | ClassificationTypes, 14 | ConfidenceLevel, 15 | ) 16 | 17 | 18 | class TestAnthropicPageClassificationOutputParser: 19 | @pytest.fixture 20 | def parser(self): 21 | return AnthropicPageClassificationOutputParser( 22 | name="anthropic", 23 | type=ClassificationTypes.SINGLE_LABEL, 24 | labels=["A", "B", "C"], 25 | confidence=True, 26 | ) 27 | 28 | def test_parse_single_label(self, parser): 29 | text = "Reasoning: This is a test.\nAnswer: B\nConfidence: high" 30 | result = parser.parse(text) 31 | 32 | assert isinstance(result, ClassificationOutput) 33 | assert result.type == ClassificationTypes.SINGLE_LABEL 34 | assert result.labels == "B" 35 | assert result.score == ConfidenceLevel.HIGH 36 | assert result.provider_name == "anthropic" 37 | 38 | def test_parse_multi_label(self): 39 | parser = AnthropicPageClassificationOutputParser( 40 | name="anthropic", 41 | type=ClassificationTypes.MULTI_LABEL, 42 | labels=["X", "Y", "Z"], 43 | confidence=True, 44 | ) 45 | text = ( 46 | "Reasoning: This is a multi-label test.\nAnswer: X, Z\nConfidence: medium" 47 | ) 48 | result = parser.parse(text) 49 | 50 | assert isinstance(result, ClassificationOutput) 51 | assert result.type == ClassificationTypes.MULTI_LABEL 52 | assert result.labels == ["X", "Z"] 53 | assert result.score == ConfidenceLevel.MEDIUM 54 | assert result.provider_name == "anthropic" 55 | 56 | def test_parse_binary(self): 57 | parser = AnthropicPageClassificationOutputParser( 58 | 
name="anthropic", 59 | type=ClassificationTypes.BINARY, 60 | labels=["YES", "NO"], 61 | confidence=False, 62 | ) 63 | text = "Reasoning: This is a binary test.\nAnswer: YES" 64 | result = parser.parse(text) 65 | 66 | assert isinstance(result, ClassificationOutput) 67 | assert result.type == ClassificationTypes.BINARY 68 | assert result.labels == "YES" 69 | assert result.score is None 70 | assert result.provider_name == "anthropic" 71 | 72 | def test_parse_invalid_answer(self, parser): 73 | text = "Reasoning: This is an invalid test.\nAnswer: D\nConfidence: low" 74 | with pytest.raises(ValueError, match="Invalid label: D"): 75 | parser.parse(text) 76 | 77 | 78 | class TestAnthropicClassificationProvider: 79 | @pytest.fixture 80 | def provider(self): 81 | return AnthropicClassificationProvider() 82 | 83 | @pytest.fixture 84 | def mock_config(self): 85 | return ClassificationConfig( 86 | type=ClassificationTypes.SINGLE_LABEL, 87 | labels=["A", "B", "C"], 88 | confidence=True, 89 | ) 90 | 91 | @pytest.mark.asyncio() 92 | async def test_ainvoke(self, provider, mock_config): 93 | mock_input = [b"image1", b"image2"] 94 | mock_completions = [ 95 | "Reasoning: Test 1\nAnswer: A\nConfidence: high", 96 | "Reasoning: Test 2\nAnswer: B\nConfidence: medium", 97 | ] 98 | 99 | with patch( 100 | "docprompt.tasks.classification.anthropic._prepare_messages" 101 | ) as mock_prepare: 102 | mock_prepare.return_value = "mock_messages" 103 | 104 | with patch( 105 | "docprompt.utils.inference.run_batch_inference_anthropic", 106 | new_callable=AsyncMock, 107 | ) as mock_inference: 108 | mock_inference.return_value = mock_completions 109 | 110 | test_kwargs = { 111 | "test": "test" 112 | } # Test that kwargs are passed through to inference 113 | results = await provider._ainvoke( 114 | mock_input, mock_config, **test_kwargs 115 | ) 116 | 117 | assert len(results) == 2 118 | assert all(isinstance(result, ClassificationOutput) for result in results) 119 | assert results[0].labels == "A" 120 | assert results[0].score == ConfidenceLevel.HIGH 121 | assert results[1].labels == "B" 122 | assert results[1].score == ConfidenceLevel.MEDIUM 123 | 124 | mock_prepare.assert_called_once_with(mock_input, mock_config) 125 | mock_inference.assert_called_once_with( 126 | "claude-3-haiku-20240307", "mock_messages", **test_kwargs 127 | ) 128 | 129 | @pytest.mark.asyncio() 130 | async def test_ainvoke_with_error(self, provider, mock_config): 131 | mock_input = [b"image1"] 132 | mock_completions = ["Reasoning: Error test\nAnswer: Invalid\nConfidence: low"] 133 | 134 | with patch( 135 | "docprompt.tasks.classification.anthropic._prepare_messages" 136 | ) as mock_prepare: 137 | mock_prepare.return_value = "mock_messages" 138 | 139 | with patch( 140 | "docprompt.utils.inference.run_batch_inference_anthropic", 141 | new_callable=AsyncMock, 142 | ) as mock_inference: 143 | mock_inference.return_value = mock_completions 144 | 145 | with pytest.raises(ValueError, match="Invalid label: Invalid"): 146 | await provider._ainvoke(mock_input, mock_config) 147 | 148 | def test_prepare_messages(self, mock_config): 149 | imgs = [b"image1", b"image2"] 150 | config = mock_config 151 | 152 | result = _prepare_messages(imgs, config) 153 | 154 | assert len(result) == 2 155 | for msg_group in result: 156 | assert len(msg_group) == 1 157 | msg = msg_group[0] 158 | assert msg.role == "user" 159 | assert len(msg.content) == 2 160 | 161 | types = set([content.type for content in msg.content]) 162 | assert types == set(["image_url", "text"]) 163 | 
-------------------------------------------------------------------------------- /tests/tasks/markerize/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/tests/tasks/markerize/__init__.py -------------------------------------------------------------------------------- /tests/tasks/markerize/test_anthropic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the Anthropic implementation of the markerize task. 3 | """ 4 | 5 | from unittest.mock import patch 6 | 7 | import pytest 8 | 9 | from docprompt.tasks.markerize.anthropic import ( 10 | AnthropicMarkerizeProvider, 11 | _parse_result, 12 | _prepare_messages, 13 | ) 14 | from docprompt.tasks.markerize.base import MarkerizeResult 15 | from docprompt.tasks.message import OpenAIComplexContent, OpenAIImageURL, OpenAIMessage 16 | 17 | 18 | @pytest.fixture 19 | def mock_image_bytes(): 20 | return b"mock_image_bytes" 21 | 22 | 23 | class TestAnthropicMarkerizeProvider: 24 | @pytest.fixture 25 | def provider(self): 26 | return AnthropicMarkerizeProvider() 27 | 28 | def test_provider_name(self, provider): 29 | assert provider.name == "anthropic" 30 | 31 | @pytest.mark.asyncio 32 | async def test_ainvoke(self, provider, mock_image_bytes): 33 | mock_completions = ["# Test Markdown", "## Another Test"] 34 | 35 | with patch( 36 | "docprompt.tasks.markerize.anthropic._prepare_messages" 37 | ) as mock_prepare: 38 | with patch( 39 | "docprompt.utils.inference.run_batch_inference_anthropic" 40 | ) as mock_inference: 41 | mock_prepare.return_value = "mock_messages" 42 | mock_inference.return_value = mock_completions 43 | 44 | test_kwargs = { 45 | "test": "test" 46 | } # Test that kwargs are passed through to inference 47 | result = await provider._ainvoke( 48 | [mock_image_bytes, mock_image_bytes], **test_kwargs 49 | ) 50 | 51 | assert len(result) == 2 52 | assert all(isinstance(r, MarkerizeResult) for r in result) 53 | assert result[0].raw_markdown == "# Test Markdown" 54 | assert result[1].raw_markdown == "## Another Test" 55 | assert all(r.provider_name == "anthropic" for r in result) 56 | 57 | mock_prepare.assert_called_once_with( 58 | [mock_image_bytes, mock_image_bytes] 59 | ) 60 | mock_inference.assert_called_once_with( 61 | "claude-3-haiku-20240307", "mock_messages", **test_kwargs 62 | ) 63 | 64 | 65 | def test_prepare_messages(mock_image_bytes): 66 | messages = _prepare_messages([mock_image_bytes]) 67 | 68 | assert len(messages) == 1 69 | assert len(messages[0]) == 1 70 | assert isinstance(messages[0][0], OpenAIMessage) 71 | assert messages[0][0].role == "user" 72 | assert len(messages[0][0].content) == 2 73 | assert isinstance(messages[0][0].content[0], OpenAIComplexContent) 74 | assert messages[0][0].content[0].type == "image_url" 75 | assert isinstance(messages[0][0].content[0].image_url, OpenAIImageURL) 76 | assert messages[0][0].content[0].image_url.url == mock_image_bytes.decode("utf-8") 77 | assert isinstance(messages[0][0].content[1], OpenAIComplexContent) 78 | assert messages[0][0].content[1].type == "text" 79 | assert "Convert the image into markdown" in messages[0][0].content[1].text 80 | 81 | 82 | @pytest.mark.parametrize( 83 | "raw_markdown,expected", 84 | [ 85 | ("# Test Markdown", "# Test Markdown"), 86 | ("## Another Test", "## Another Test"), 87 | ("Invalid markdown", ""), 88 | (" Trimmed ", "Trimmed"), 89 | ], 90 | ) 91 | def 
test_parse_result(raw_markdown, expected): 92 | assert _parse_result(raw_markdown) == expected 93 | -------------------------------------------------------------------------------- /tests/tasks/markerize/test_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the basic functionality of the atomic components for the markerize task. 3 | """ 4 | 5 | from unittest.mock import MagicMock, patch 6 | 7 | import pytest 8 | 9 | from docprompt.schema.pipeline.node.document import DocumentNode 10 | from docprompt.schema.pipeline.node.page import PageNode 11 | from docprompt.tasks.markerize.base import BaseMarkerizeProvider, MarkerizeResult 12 | 13 | 14 | class TestBaseMarkerizeProvider: 15 | """ 16 | Test the base markerize provider to ensure that the task is properly handled. 17 | """ 18 | 19 | @pytest.fixture 20 | def mock_document_node(self): 21 | mock_node = MagicMock(spec=DocumentNode) 22 | mock_node.page_nodes = [MagicMock(spec=PageNode) for _ in range(5)] 23 | for pnode in mock_node.page_nodes: 24 | pnode.rasterizer.rasterize.return_value = b"image" 25 | mock_node.__len__.return_value = len(mock_node.page_nodes) 26 | return mock_node 27 | 28 | @pytest.mark.parametrize( 29 | "start,stop,expected_keys,expected_results", 30 | [ 31 | (2, 4, [2, 3, 4], {2: "RESULT-0", 3: "RESULT-1", 4: "RESULT-2"}), 32 | (3, None, [3, 4, 5], {3: "RESULT-0", 4: "RESULT-1", 5: "RESULT-2"}), 33 | (None, 2, [1, 2], {1: "RESULT-0", 2: "RESULT-1"}), 34 | ( 35 | None, 36 | None, 37 | [1, 2, 3, 4, 5], 38 | { 39 | 1: "RESULT-0", 40 | 2: "RESULT-1", 41 | 3: "RESULT-2", 42 | 4: "RESULT-3", 43 | 5: "RESULT-4", 44 | }, 45 | ), 46 | ], 47 | ) 48 | def test_process_document_node_with_start_stop( 49 | self, mock_document_node, start, stop, expected_keys, expected_results 50 | ): 51 | class TestProvider(BaseMarkerizeProvider): 52 | name = "test" 53 | 54 | def _invoke(self, input, config, **kwargs): 55 | return [ 56 | MarkerizeResult(raw_markdown=f"RESULT-{i}", provider_name="test") 57 | for i in range(len(input)) 58 | ] 59 | 60 | provider = TestProvider() 61 | result = provider.process_document_node( 62 | mock_document_node, start=start, stop=stop 63 | ) 64 | 65 | assert list(result.keys()) == expected_keys 66 | assert all(isinstance(v, MarkerizeResult) for v in result.values()) 67 | assert {k: v.raw_markdown for k, v in result.items()} == expected_results 68 | 69 | with patch.object(provider, "_invoke") as mock_invoke: 70 | provider.process_document_node(mock_document_node, start=start, stop=stop) 71 | mock_invoke.assert_called_once() 72 | expected_invoke_length = len(expected_keys) 73 | assert len(mock_invoke.call_args[0][0]) == expected_invoke_length 74 | 75 | def test_process_document_node_rasterization(self, mock_document_node): 76 | class TestProvider(BaseMarkerizeProvider): 77 | name = "test" 78 | 79 | def _invoke(self, input, config, **kwargs): 80 | return [ 81 | MarkerizeResult(raw_markdown=f"RESULT-{i}", provider_name="test") 82 | for i in range(len(input)) 83 | ] 84 | 85 | provider = TestProvider() 86 | provider.process_document_node(mock_document_node) 87 | 88 | for page_node in mock_document_node.page_nodes: 89 | page_node.rasterizer.rasterize.assert_called_once_with("default") 90 | -------------------------------------------------------------------------------- /tests/tasks/table_extraction/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/tests/tasks/table_extraction/__init__.py
-------------------------------------------------------------------------------- /tests/tasks/table_extraction/test_anthropic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the Anthropic implementation of the table extraction task. 3 | """ 4 | 5 | from unittest.mock import patch 6 | 7 | import pytest 8 | from bs4 import BeautifulSoup 9 | 10 | from docprompt.tasks.message import OpenAIComplexContent, OpenAIImageURL, OpenAIMessage 11 | from docprompt.tasks.table_extraction.anthropic import ( 12 | AnthropicTableExtractionProvider, 13 | _headers_from_tree, 14 | _prepare_messages, 15 | _rows_from_tree, 16 | _title_from_tree, 17 | parse_response, 18 | ) 19 | from docprompt.tasks.table_extraction.schema import ( 20 | TableCell, 21 | TableExtractionPageResult, 22 | TableHeader, 23 | TableRow, 24 | ) 25 | 26 | 27 | @pytest.fixture 28 | def mock_image_bytes(): 29 | return b"mock_image_bytes" 30 | 31 | 32 | class TestAnthropicTableExtractionProvider: 33 | @pytest.fixture 34 | def provider(self): 35 | return AnthropicTableExtractionProvider() 36 | 37 | def test_provider_name(self, provider): 38 | assert provider.name == "anthropic" 39 | 40 | @pytest.mark.asyncio 41 | async def test_ainvoke(self, provider, mock_image_bytes): 42 | mock_completions = [ 43 | "<table><title>Test Table</title><headers><header>Col1</header></headers><rows><row><column>Data1</column></row></rows></table>", 44 | "<table><headers><header>Col2</header></headers><rows><row><column>Data2</column></row></rows></table>", 45 | ] 46 | 47 | with patch( 48 | "docprompt.tasks.table_extraction.anthropic._prepare_messages" 49 | ) as mock_prepare: 50 | with patch( 51 | "docprompt.utils.inference.run_batch_inference_anthropic" 52 | ) as mock_inference: 53 | mock_prepare.return_value = "mock_messages" 54 | mock_inference.return_value = mock_completions 55 | 56 | result = await provider._ainvoke([mock_image_bytes, mock_image_bytes]) 57 | 58 | assert len(result) == 2 59 | assert all(isinstance(r, TableExtractionPageResult) for r in result) 60 | assert result[0].tables[0].title == "Test Table" 61 | assert result[1].tables[0].title is None 62 | assert all(r.provider_name == "anthropic" for r in result) 63 | 64 | mock_prepare.assert_called_once_with( 65 | [mock_image_bytes, mock_image_bytes] 66 | ) 67 | mock_inference.assert_called_once_with( 68 | "claude-3-haiku-20240307", "mock_messages" 69 | ) 70 | 71 | 72 | def test_prepare_messages(mock_image_bytes): 73 | messages = _prepare_messages([mock_image_bytes]) 74 | 75 | assert len(messages) == 1 76 | assert len(messages[0]) == 1 77 | assert isinstance(messages[0][0], OpenAIMessage) 78 | assert messages[0][0].role == "user" 79 | assert len(messages[0][0].content) == 2 80 | assert isinstance(messages[0][0].content[0], OpenAIComplexContent) 81 | assert messages[0][0].content[0].type == "image_url" 82 | assert isinstance(messages[0][0].content[0].image_url, OpenAIImageURL) 83 | assert messages[0][0].content[0].image_url.url == mock_image_bytes.decode() 84 | assert isinstance(messages[0][0].content[1], OpenAIComplexContent) 85 | assert messages[0][0].content[1].type == "text" 86 | assert ( 87 | "Identify and extract all tables from the document" 88 | in messages[0][0].content[1].text 89 | ) 90 | 91 | 92 | def test_parse_response(): 93 | response = """ 94 | <table> 95 | <title>Test Table</title> 96 | <headers> 97 | <header>Col1</header> 98 | <header>Col2</header> 99 | </headers> 100 | <rows> 101 | <row> 102 | <column>Data1</column> 103 | <column>Data2</column> 104 | </row> 105 | </rows> 106 | </table> 107 | """ 108 | result = parse_response(response) 109 | 110 | assert isinstance(result, TableExtractionPageResult) 111 | assert len(result.tables) == 1 112 | assert result.tables[0].title == "Test Table" 113 | assert len(result.tables[0].headers) == 2 114 | assert result.tables[0].headers[0].text == "Col1" 115 | assert len(result.tables[0].rows) == 1 116 | assert result.tables[0].rows[0].cells[0].text == "Data1" 117 | 118 | 119 | def test_title_from_tree(): 120 | soup = BeautifulSoup("<table><title>Test Title</title></table>") 121 | assert _title_from_tree(soup.table) == "Test Title" 122 | 123 | soup = BeautifulSoup("<table></table>") 124 | assert _title_from_tree(soup.table) is None 125 | 126 | 127 | def test_headers_from_tree(): 128 | soup = BeautifulSoup( 129 | "<table><headers><header>Col1</header><header>Col2</header></headers></table>", 130 | ) 131 | headers = _headers_from_tree(soup.table) 132 | assert len(headers) == 2 133 | assert all(isinstance(h, TableHeader) for h in headers) 134 | assert headers[0].text == "Col1" 135 | 136 | soup = BeautifulSoup("<table></table>") 137 | assert _headers_from_tree(soup.table) == [] 138 | 139 | 140 | def test_rows_from_tree(): 141 | soup = BeautifulSoup( 142 | "<table><rows><row><column>Data1</column><column>Data2</column></row></rows></table>", 143 | ) 144 | rows = _rows_from_tree(soup.table) 145 | assert len(rows) == 1 146 | assert isinstance(rows[0], TableRow) 147 | assert len(rows[0].cells) == 2 148 | assert all(isinstance(c, TableCell) for c in rows[0].cells) 149 | assert rows[0].cells[0].text == "Data1" 150 | 151 | soup = BeautifulSoup("<table></table>") 152 | assert _rows_from_tree(soup.table) == [] 153 | 154 | 155 | @pytest.mark.parametrize( 156 | "input_str,sub_str,expected", 157 | [ 158 | ("abc<table>def</table>ghi<table>jkl</table>", "<table>", [3, 24]), 159 | ("notables", "<table>", []), 160 | ("<table><table><table>", "<table>", [0, 7, 14]), 161 | ], 162 | ) 163 | def test_find_start_indices(input_str, sub_str, expected): 164 | from docprompt.tasks.table_extraction.anthropic import _find_start_indices 165 | 166 | assert _find_start_indices(input_str, sub_str) == expected 167 | 168 | 169 | @pytest.mark.parametrize( 170 | "input_str,sub_str,expected", 171 | [ 172 | ("abc</table>def</table>ghi", "</table>", [11, 22]), 173 | ("notables", "</table>", []), 174 | ("</table></table></table>", "</table>", [8, 16, 24]), 175 | ], 176 | ) 177 | def test_find_end_indices(input_str, sub_str, expected): 178 | from docprompt.tasks.table_extraction.anthropic import _find_end_indices 179 | 180 | assert _find_end_indices(input_str, sub_str) == expected 181 |
-------------------------------------------------------------------------------- /tests/tasks/test_factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for the factory module to ensure task providers are instantiated correctly. 3 | """ 4 | 5 | import pytest 6 | 7 | from docprompt.tasks.factory import ( 8 | AmazonTaskProviderFactory, 9 | AnthropicTaskProviderFactory, 10 | GCPTaskProviderFactory, 11 | ) 12 | 13 | 14 | class TestAnthropicTaskProviderFactory: 15 | @pytest.fixture 16 | def api_key(self): 17 | return "test-api-key" 18 | 19 | @pytest.fixture 20 | def environ_key(self, monkeypatch, api_key): 21 | monkeypatch.setenv("ANTHROPIC_API_KEY", api_key) 22 | 23 | yield api_key 24 | 25 | monkeypatch.delenv("ANTHROPIC_API_KEY") 26 | 27 | def test_validate_provider_w_explicit_key(self, api_key): 28 | factory = AnthropicTaskProviderFactory(api_key=api_key) 29 | assert factory._credentials.kwargs == {"api_key": api_key} 30 | 31 | def test_validate_provider_w_environ_key(self, environ_key): 32 | factory = AnthropicTaskProviderFactory() 33 | assert factory._credentials.kwargs == {"api_key": environ_key} 34 | 35 | def test_get_page_classification_provider(self, environ_key): 36 | factory = AnthropicTaskProviderFactory() 37 | provider = factory.get_page_classification_provider() 38 | 39 | assert provider._default_invoke_kwargs == {"api_key": environ_key} 40 | assert provider.name == "anthropic" 41 | 42 | def test_get_page_table_extraction_provider(self, environ_key): 43 | factory = AnthropicTaskProviderFactory() 44 | provider = factory.get_page_table_extraction_provider() 45 | 46 | assert provider._default_invoke_kwargs == {"api_key": environ_key} 47 | assert provider.name == "anthropic" 48 | 49 | def test_get_page_markerization_provider(self, environ_key): 50 | factory = AnthropicTaskProviderFactory() 51 | provider = factory.get_page_markerization_provider() 52 | 53 | assert provider._default_invoke_kwargs == {"api_key": environ_key} 54 | assert provider.name == "anthropic" 55 | 56 | 57 | class TestGoogleTaskProviderFactory: 58 | @pytest.fixture 59 | def sa_file(self): 60 | return "/path/to/file" 61 | 62 | @pytest.fixture 63 | def environ_file(self, monkeypatch, sa_file): 64 | monkeypatch.setenv("GCP_SERVICE_ACCOUNT_FILE", sa_file) 65 | 66 | yield sa_file 67 | 68 | monkeypatch.delenv("GCP_SERVICE_ACCOUNT_FILE") 69 | 70 | def test_validate_provider_w_explicit_sa_file(self, sa_file): 71 | factory = GCPTaskProviderFactory(service_account_file=sa_file) 72 | assert factory._credentials.kwargs == {"service_account_file": sa_file} 73 | 74 | def test_validate_provider_w_environ_sa_file(self, environ_file): 75 | factory = GCPTaskProviderFactory() 76 | assert factory._credentials.kwargs == {"service_account_file": environ_file} 77 | 78 | def test_get_page_ocr_provider(self, environ_file): 79 | factory = GCPTaskProviderFactory() 80 | 81 | project_id = "project-id" 82 | processor_id = "processor-id" 83 | 84 | provider = factory.get_page_ocr_provider(project_id, processor_id) 85 | 86 | assert provider._default_invoke_kwargs == {"service_account_file": environ_file} 87 | assert provider.name == "gcp_documentai" 88 | 89 |
90 | 91 | class TestAmazonTaskProviderFactory: 92 | @pytest.fixture 93 | def aws_creds(self): 94 | return { 95 | "aws_access_key_id": "test_access_key", 96 | "aws_secret_access_key": "test_secret_key", 97 | "aws_region": "us-west-2", 98 | } 99 | 100 | @pytest.fixture 101 | def environ_creds(self, monkeypatch, aws_creds): 102 | monkeypatch.setenv("AWS_ACCESS_KEY_ID", aws_creds["aws_access_key_id"]) 103 | monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", aws_creds["aws_secret_access_key"]) 104 | monkeypatch.setenv("AWS_DEFAULT_REGION", aws_creds["aws_region"]) 105 | 106 | yield aws_creds 107 | 108 | monkeypatch.delenv("AWS_ACCESS_KEY_ID") 109 | monkeypatch.delenv("AWS_SECRET_ACCESS_KEY") 110 | monkeypatch.delenv("AWS_DEFAULT_REGION") 111 | 112 | def test_validate_provider_w_explicit_creds(self, aws_creds): 113 | factory = AmazonTaskProviderFactory(**aws_creds) 114 | assert factory._credentials.kwargs == aws_creds 115 | 116 | def test_validate_provider_w_environ_creds(self, environ_creds): 117 | factory = AmazonTaskProviderFactory() 118 | assert factory._credentials.kwargs == environ_creds 119 | 120 | def test_get_page_ocr_provider(self, environ_creds): 121 | factory = AmazonTaskProviderFactory() 122 | provider = factory.get_page_ocr_provider() 123 | 124 | assert provider._default_invoke_kwargs == environ_creds 125 | assert provider.name == "aws_textract" 126 |
-------------------------------------------------------------------------------- /tests/tasks/test_result.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the Page Level and Document Level task results operate as expected. 3 | """ 4 | 5 | from unittest.mock import MagicMock 6 | 7 | from docprompt import DocumentNode 8 | from docprompt.schema.pipeline import BaseMetadata 9 | from docprompt.tasks.result import BaseDocumentResult, BasePageResult, BaseResult 10 | 11 | 12 | def test_task_key(): 13 | class TestResult(BaseResult): 14 | task_name = "test" 15 | 16 | def contribute_to_document_node(self, document_node, page_number=None): 17 | pass 18 | 19 | result = TestResult(provider_name="test") 20 | assert result.task_key == "test_test" 21 | 22 | 23 | def test_base_document_result_contribution(): 24 | class TestDocumentResult(BaseDocumentResult): 25 | task_name = "test" 26 | 27 | result = TestDocumentResult( 28 | provider_name="test", document_name="test", file_hash="test" 29 | ) 30 | 31 | mock_meta = MagicMock(spec=BaseMetadata) 32 | mock_meta.task_results = {} 33 | mock_node = MagicMock(spec=DocumentNode) 34 | mock_node.metadata = mock_meta 35 | 36 | result.contribute_to_document_node(mock_node) 37 | 38 | assert mock_meta.task_results["test_test"] == result 39 | 40 | 41 | def test_base_page_result_contribution(): 42 | class TestPageResult(BasePageResult): 43 | task_name = "test" 44 | 45 | result = TestPageResult(provider_name="test") 46 | 47 | num_pages = 3 48 | mock_meta = MagicMock(spec=BaseMetadata) 49 | mock_meta.task_results = {} 50 | mock_node = MagicMock(spec=DocumentNode) 51 | mock_node.page_nodes = [MagicMock() for _ in range(num_pages)] 52 | 53 | # Test contributing to a specific page 54 | mock_node.page_nodes[0].metadata = mock_meta 55 | 56 | mock_node.__len__.return_value = num_pages 57 | 58 | result.contribute_to_document_node(mock_node, page_number=1) 59 | 60 | assert mock_node.page_nodes[0].metadata.task_results["test_test"] == result 61 |
-------------------------------------------------------------------------------- /tests/tasks/test_task_provider.py: -------------------------------------------------------------------------------- 1 | """The test suite for the base task provider seeks to ensure that all of the 2 | built-in functionality of the AbstractTaskProvider is properly implemented. 3 | """ 4 | 5 | import pytest 6 | 7 | from docprompt.tasks.base import AbstractTaskProvider 8 | 9 | 10 | class TestAbstractTaskProviderBaseFunctionality: 11 | """ 12 | Test that the AbstractTaskProvider interface provides the correct expected basic 13 | functionality to be inherited by all subclasses. 14 | 15 | This includes: 16 | - the model validation asserts `name` and `capabilities` are required 17 | - the initialization of the model properly sets invoke kwargs 18 | - the `ainvoke` method calling the `_ainvoke` method 19 | - the `invoke` method calling the `_invoke` method 20 | """ 21 | 22 | def test_model_validator_raises_error_on_missing_name(self): 23 | class BadTaskProvider(AbstractTaskProvider): 24 | capabilities = [] 25 | 26 | with pytest.raises(ValueError): 27 | BadTaskProvider.validate_class_vars({}) 28 | 29 | def test_model_validator_raises_error_on_missing_capabilities(self): 30 | class BadTaskProvider(AbstractTaskProvider): 31 | name = "BadTaskProvider" 32 | 33 | with pytest.raises(ValueError): 34 | BadTaskProvider.validate_class_vars({}) 35 | 36 | def test_model_validator_raises_error_on_empty_capabilities(self): 37 | class BadTaskProvider(AbstractTaskProvider): 38 | name = "BadTaskProvider" 39 | capabilities = [] 40 | 41 | with pytest.raises(ValueError): 42 | BadTaskProvider.validate_class_vars({}) 43 | 44 | def test_init_no_invoke_kwargs(self): 45 | class TestTaskProvider(AbstractTaskProvider): 46 | name = "TestTaskProvider" 47 | capabilities = ["test"] 48 | 49 | provider = TestTaskProvider() 50 | 51 | assert provider._default_invoke_kwargs == {} 52 | 53 | def test_init_with_invoke_kwargs(self): 54 | class TestTaskProvider(AbstractTaskProvider): 55 | name = "TestTaskProvider" 56 | capabilities = ["test"] 57 | 58 | kwargs = {"test": "test"} 59 | provider = TestTaskProvider(invoke_kwargs=kwargs) 60 | 61 | assert provider._default_invoke_kwargs == kwargs 62 | 63 | def test_init_with_fields_and_invoke_kwargs(self): 64 | class TestTaskProvider(AbstractTaskProvider): 65 | name = "TestTaskProvider" 66 | capabilities = ["test"] 67 | 68 | foo: str 69 | 70 | kwargs = {"test": "test"} 71 | provider = TestTaskProvider(foo="bar", invoke_kwargs=kwargs) 72 | 73 | assert provider._default_invoke_kwargs == kwargs 74 | assert provider.foo == "bar" 75 | 76 | @pytest.mark.asyncio 77 | async def test_ainvoke_calls__ainvoke(self): 78 | class TestTaskProvider(AbstractTaskProvider): 79 | name = "TestTaskProvider" 80 | capabilities = ["test"] 81 | 82 | async def _ainvoke(self, input, config=None, **kwargs): 83 | return input 84 | 85 | provider = TestTaskProvider() 86 | 87 | assert await provider.ainvoke([1, 2, 3]) == [1, 2, 3] 88 | 89 | def test_invoke_calls__invoke(self): 90 | class TestTaskProvider(AbstractTaskProvider): 91 | name = "TestTaskProvider" 92 | capabilities = ["test"] 93 | 94 | def _invoke(self, input, config=None, **kwargs): 95 | return input 96 | 97 | provider = TestTaskProvider() 98 | 99 | assert provider.invoke([1, 2, 3]) == [1, 2, 3] 100 |
-------------------------------------------------------------------------------- /tests/test_date_extraction.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | 3 | from docprompt.utils.date_extraction import extract_dates_from_text 4 | 5 | STRING_A
= """ 6 | There was a meeting on 2021-01-01 and another on 2021-01-02. 7 | 8 | Meanwhile on September 1, 2021, there was a third meeting. 9 | 10 | The final meeting was on 4/5/2021. 11 | """ 12 | 13 | 14 | def test_date_extraction(): 15 | dates = extract_dates_from_text(STRING_A) 16 | 17 | assert len(dates) == 5 18 | 19 | dates.sort(key=lambda x: x[0]) 20 | 21 | assert dates[0][0] == date(2021, 1, 1) 22 | assert dates[0][1] == "2021-01-01" 23 | 24 | assert dates[1][0] == date(2021, 1, 2) 25 | assert dates[1][1] == "2021-01-02" 26 | 27 | assert dates[2][0] == date(2021, 4, 5) 28 | assert dates[2][1] == "4/5/2021" 29 | 30 | assert dates[3][0] == date(2021, 5, 4) 31 | assert dates[3][1] == "4/5/2021" 32 | 33 | assert dates[4][0] == date(2021, 9, 1) 34 | assert dates[4][1] == "September 1, 2021" 35 | -------------------------------------------------------------------------------- /tests/test_search.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | from pytest import raises 4 | 5 | from docprompt import DocumentNode, load_document 6 | 7 | from .fixtures import PDF_FIXTURES 8 | 9 | 10 | def test_search(): 11 | document = load_document(PDF_FIXTURES[0].get_full_path()) 12 | document_node = DocumentNode.from_document(document) 13 | 14 | ocr_results = PDF_FIXTURES[0].get_ocr_results() 15 | 16 | with raises(ValueError): 17 | document_node.refresh_locator() 18 | 19 | for page_num, ocr_result in ocr_results.items(): 20 | ocr_result.contribute_to_document_node(document_node, page_number=page_num) 21 | 22 | print(document_node[0].metadata.task_results) 23 | 24 | assert document_node._locator is None # Ensure the locator is not set 25 | 26 | # Need to make sure an ocr_key is set to avoid ValueError 27 | locator = document_node.locator 28 | 29 | result = locator.search("word that doesn't exist") 30 | 31 | assert len(result) == 0 32 | 33 | result_all_pages = locator.search("and") 34 | 35 | assert len(result_all_pages) == 50 36 | 37 | result_page_1 = locator.search("rooted", page_number=1) 38 | 39 | assert len(result_page_1) == 1 40 | 41 | result_multiple_words = locator.search("MMAX2 system", page_number=1) 42 | 43 | assert len(result_multiple_words) == 1 44 | 45 | sources = result_multiple_words[0].text_location.source_blocks 46 | 47 | assert len(sources) == 2 48 | 49 | result_multiple_words = locator.search( 50 | "MMAX2 system", page_number=1, refine_to_word=False 51 | ) 52 | 53 | assert len(result_multiple_words) == 1 54 | 55 | sources = result_multiple_words[0].text_location.source_blocks 56 | 57 | assert len(sources) == 1 58 | 59 | n_best = locator.search_n_best("and", n=3) 60 | 61 | assert len(n_best) == 3 62 | 63 | raw_search = locator.search_raw('content:"rooted"') 64 | 65 | assert len(raw_search) == 1 66 | 67 | 68 | def test_pickling__removes_locator_document_basis(): 69 | document = load_document(PDF_FIXTURES[0].get_full_path()) 70 | document_node = DocumentNode.from_document(document) 71 | 72 | ocr_results = PDF_FIXTURES[0].get_ocr_results() 73 | 74 | for page_num, ocr_result in ocr_results.items(): 75 | ocr_result.contribute_to_document_node(document_node, page_number=page_num) 76 | 77 | result_page_1 = document_node.locator.search("rooted", page_number=1) 78 | 79 | assert len(result_page_1) == 1 80 | 81 | dumped = pickle.dumps(document_node) 82 | 83 | loaded = pickle.loads(dumped) 84 | 85 | assert loaded._locator is None 86 | 87 | 88 | def test_pickling__removes_locator_page_basis(): 89 | document = 
load_document(PDF_FIXTURES[0].get_full_path()) 90 | document_node = DocumentNode.from_document(document) 91 | 92 | ocr_results = PDF_FIXTURES[0].get_ocr_results() 93 | 94 | for page_num, ocr_result in ocr_results.items(): 95 | ocr_result.contribute_to_document_node(document_node, page_number=page_num) 96 | 97 | page = document_node.page_nodes[0] 98 | 99 | result_page_1 = page.search("rooted") 100 | 101 | assert len(result_page_1) == 1 102 | 103 | dumped = pickle.dumps(page) 104 | 105 | loaded = pickle.loads(dumped) 106 | 107 | assert loaded.document._locator is None 108 |
-------------------------------------------------------------------------------- /tests/test_threadpool.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ThreadPoolExecutor, as_completed 2 | from pathlib import Path 3 | 4 | import pytest 5 | 6 | from docprompt import load_document 7 | from docprompt.utils.splitter import pdf_split_iter_with_max_bytes 8 | 9 | 10 | def do_split(document): 11 | return list( 12 | pdf_split_iter_with_max_bytes(document.file_bytes, 15, 1024 * 1024 * 15) 13 | ) 14 | 15 | 16 | @pytest.mark.skip(reason="Fixtures are missing for this test") 17 | def test_document_split_in_threadpool__does_not_hang(): 18 | source_dir = Path(__file__).parent.parent / "data" / "threadpool_test" 19 | 20 | documents = [load_document(file) for file in source_dir.iterdir()] 21 | 22 | futures = [] 23 | 24 | with ThreadPoolExecutor() as executor: 25 | for document in documents: 26 | future = executor.submit(do_split, document) 27 | 28 | futures.append(future) 29 | 30 | for future in as_completed(futures): 31 | assert future.result() is not None 32 |
-------------------------------------------------------------------------------- /tests/util.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Dict, Optional 3 | 4 | from pydantic import BaseModel, Field, PositiveInt, TypeAdapter 5 | 6 | from docprompt.tasks.ocr.result import OcrPageResult 7 | 8 | FIXTURE_PATH = Path(__file__).parent / "fixtures" 9 | 10 | 11 | OCR_ADAPTER = TypeAdapter(Dict[int, OcrPageResult]) 12 | 13 | 14 | class PdfFixture(BaseModel): 15 | name: str = Field(description="The name of the fixture") 16 | page_count: PositiveInt = Field(description="The number of pages in the fixture") 17 | file_hash: str = Field(description="The expected hash of the fixture") 18 | ocr_name: Optional[str] = Field( 19 | description="The path to the OCR results for the fixture", default=None 20 | ) 21 | 22 | def get_full_path(self): 23 | return FIXTURE_PATH / self.name 24 | 25 | def get_bytes(self): 26 | return self.get_full_path().read_bytes() 27 | 28 | def get_ocr_results(self): 29 | if not self.ocr_name: 30 | return None 31 | 32 | ocr_path = FIXTURE_PATH / self.ocr_name 33 | 34 | return OCR_ADAPTER.validate_json(ocr_path.read_text()) 35 | 36 | def get_document_node(self): 37 | from docprompt import load_document_node 38 | 39 | document = load_document_node(self.get_bytes()) 40 | 41 | ocr_results = self.get_ocr_results() 42 | 43 | if ocr_results: 44 | for page_number, ocr_result in self.get_ocr_results().items(): 45 | page = document.page_nodes[page_number - 1] 46 | page.metadata.ocr_results = ocr_result 47 | 48 | return document 49 | --------------------------------------------------------------------------------