├── .bumpversion.cfg
├── .editorconfig
├── .github
│   ├── ISSUE_TEMPLATE.md
│   └── workflows
│       ├── dev.yml
│       ├── docs.yml
│       ├── preview.yml
│       └── release.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CHANGELOG.md
├── CNAME
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── benchmarks
│   ├── classification.ipynb
│   └── ocr.ipynb
├── docprompt
│   ├── __init__.py
│   ├── _decorators.py
│   ├── _exec
│   │   ├── __init__.py
│   │   ├── ghostscript.py
│   │   └── tesseract.py
│   ├── _pdfium.py
│   ├── contrib
│   │   ├── __init__.py
│   │   └── parser_bot
│   │       └── __init__.py
│   ├── provenance
│   │   ├── __init__.py
│   │   ├── search.py
│   │   ├── source.py
│   │   └── util.py
│   ├── rasterize.py
│   ├── schema
│   │   ├── __init__.py
│   │   ├── document.py
│   │   ├── layout.py
│   │   └── pipeline
│   │       ├── __init__.py
│   │       ├── metadata.py
│   │       ├── node
│   │       │   ├── __init__.py
│   │       │   ├── base.py
│   │       │   ├── collection.py
│   │       │   ├── document.py
│   │       │   ├── image.py
│   │       │   ├── page.py
│   │       │   └── typing.py
│   │       └── rasterizer.py
│   ├── storage.py
│   ├── tasks
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── capabilities.py
│   │   ├── classification
│   │   │   ├── __init__.py
│   │   │   ├── anthropic.py
│   │   │   └── base.py
│   │   ├── credentials.py
│   │   ├── factory.py
│   │   ├── markerize
│   │   │   ├── __init__.py
│   │   │   ├── anthropic.py
│   │   │   └── base.py
│   │   ├── message.py
│   │   ├── ocr
│   │   │   ├── __init__.py
│   │   │   ├── amazon.py
│   │   │   ├── base.py
│   │   │   ├── gcp.py
│   │   │   ├── result.py
│   │   │   └── tesseract.py
│   │   ├── parser.py
│   │   ├── result.py
│   │   ├── table_extraction
│   │   │   ├── __init__.py
│   │   │   ├── anthropic.py
│   │   │   ├── base.py
│   │   │   └── schema.py
│   │   └── util.py
│   └── utils
│       ├── __init__.py
│       ├── compressor.py
│       ├── date_extraction.py
│       ├── inference.py
│       ├── layout.py
│       ├── masking
│       │   ├── __init__.py
│       │   └── image.py
│       ├── splitter.py
│       └── util.py
├── docs
│   ├── CNAME
│   ├── assets
│   │   └── static
│   │       └── img
│   │           ├── logo.png
│   │           ├── logo.svg
│   │           └── old-logo.png
│   ├── blog
│   │   └── index.md
│   ├── community
│   │   ├── contributing.md
│   │   └── versioning.md
│   ├── concepts
│   │   ├── nodes.md
│   │   ├── primatives.md
│   │   ├── provenance.md
│   │   └── providers.md
│   ├── enterprise.md
│   ├── gen_ref_pages.py
│   ├── guide
│   │   ├── classify
│   │   │   ├── binary.md
│   │   │   ├── multi.md
│   │   │   └── single.md
│   │   ├── ocr
│   │   │   ├── advanced_search.md
│   │   │   ├── advanced_workflows.md
│   │   │   ├── basic_usage.md
│   │   │   └── provider_config.md
│   │   └── table_extraction
│   │       └── extract_tables.md
│   └── index.md
├── makefile
├── mkdocs.yml
├── pdm.lock
├── pyproject.toml
├── setup.cfg
└── tests
    ├── __init__.py
    ├── _run_tests_with_coverage.py
    ├── conftest.py
    ├── fixtures.py
    ├── fixtures
    │   ├── 1.pdf
    │   ├── 1_ocr.json
    │   └── 2.pdf
    ├── schema
    │   ├── __init__.py
    │   ├── pipeline
    │   │   ├── __init__.py
    │   │   ├── test_imagenode.py
    │   │   ├── test_layoutaware.py
    │   │   ├── test_metadata.py
    │   │   └── test_rasterizer.py
    │   ├── test_document.py
    │   └── test_layout_models.py
    ├── tasks
    │   ├── __init__.py
    │   ├── classification
    │   │   ├── __init__.py
    │   │   ├── test_anthropic.py
    │   │   └── test_base.py
    │   ├── markerize
    │   │   ├── __init__.py
    │   │   ├── test_anthropic.py
    │   │   └── test_base.py
    │   ├── table_extraction
    │   │   ├── __init__.py
    │   │   ├── test_anthropic.py
    │   │   └── test_base.py
    │   ├── test_credentials.py
    │   ├── test_factory.py
    │   ├── test_result.py
    │   └── test_task_provider.py
    ├── test_date_extraction.py
    ├── test_decorators.py
    ├── test_search.py
    ├── test_storage.py
    ├── test_threadpool.py
    └── util.py
/.bumpversion.cfg:
--------------------------------------------------------------------------------
1 | [bumpversion]
2 | current_version = 0.8.7
3 | commit = True
4 | tag = True
5 |
6 | [bumpversion:file:pyproject.toml]
7 | search = version = "{current_version}"
8 | replace = version = "{new_version}"
9 |
10 | [bumpversion:file:docprompt/__init__.py]
11 | search = __version__ = "{current_version}"
12 | replace = __version__ = "{new_version}"
13 |
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | # http://editorconfig.org
2 |
3 | root = true
4 |
5 | [*]
6 | indent_style = space
7 | indent_size = 4
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 | charset = utf-8
11 | end_of_line = lf
12 |
13 | [*.bat]
14 | indent_style = tab
15 | end_of_line = crlf
16 |
17 | [LICENSE]
18 | insert_final_newline = false
19 |
20 | [Makefile]
21 | indent_style = tab
22 |
23 | [*.{yml,yaml}]
24 | indent_size = 2
25 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | * Docprompt version:
2 | * Python version:
3 | * Operating System:
4 |
5 | ### Description
6 |
7 | Describe what you were trying to get done.
8 | Tell us what happened, what went wrong, and what you expected to happen.
9 |
10 | ### What I Did
11 |
12 | ```
13 | Paste the command(s) you ran and the output.
14 | If there was a crash, please include the traceback here.
15 | ```
16 |
--------------------------------------------------------------------------------
/.github/workflows/dev.yml:
--------------------------------------------------------------------------------
1 | # This is a basic workflow to help you get started with Actions
2 |
3 | name: dev workflow
4 |
5 | # Controls when the action will run.
6 | on:
7 |   # Triggers the workflow on push or pull request events but only for the main branch
8 | push:
9 | branches: [ main ]
10 | pull_request:
11 | branches: [ main ]
12 |
13 | # Allows you to run this workflow manually from the Actions tab
14 | workflow_dispatch:
15 |
16 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel
17 | jobs:
18 | # This workflow contains a single job called "test"
19 | test:
20 | # The type of runner that the job will run on
21 | strategy:
22 | matrix:
23 | python: ['3.8', '3.9', '3.10', '3.11', '3.12']
24 | runs-on: ubuntu-latest
25 |
26 | # Steps represent a sequence of tasks that will be executed as part of the job
27 | steps:
28 | # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
29 | - uses: actions/checkout@v4
30 | - uses: actions/setup-python@v5
31 | with:
32 | python-version: ${{ matrix.python }}
33 | cache: 'pip'
34 |
35 | - name: Install dependencies
36 | run: |
37 | sudo apt update && sudo apt install -y ghostscript
38 | python -m pip install --upgrade pip
39 | pip install -U pdm
40 | pdm install -G:all
41 | - name: Lint with Ruff
42 | run: |
43 | pdm run ruff check --output-format=github .
44 | - name: Test with pytest
45 | run: |
46 | pdm run pytest -sxv --cov --cov-report=xml .
47 | - name: Upload coverage reports to Codecov
48 | uses: codecov/codecov-action@v4.0.1
49 | with:
50 | token: ${{ secrets.CODECOV_TOKEN }}
51 | slug: Page-Leaf/Docprompt
52 | files: ./coverage.xml
53 |
--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | name: Deploy MkDocs to GitHub Pages
2 |
3 | on:
4 | push:
5 | branches:
6 | - main # Set this to your default branch
7 |
8 | jobs:
9 | deploy:
10 | runs-on: ubuntu-latest
11 | steps:
12 |       - uses: actions/checkout@v4
13 |
14 | - name: Set up Python
15 |         uses: actions/setup-python@v5
16 | with:
17 | python-version: '3.x'
18 |
19 | - name: Install PDM
20 | run: pip install pdm
21 |
22 | - name: Install dependencies
23 | run: pdm install -d
24 |
25 | - name: Deploy MkDocs
26 | run: pdm run mkdocs gh-deploy --force
27 |
--------------------------------------------------------------------------------
/.github/workflows/preview.yml:
--------------------------------------------------------------------------------
1 | # This is a basic workflow to help you get started with Actions
2 |
3 | name: stage & preview workflow
4 |
5 | # Controls when the action will run.
6 | on:
7 |   # Triggers the workflow on push events but only for the main branch
8 | push:
9 | branches: [ main ]
10 |
11 | # Allows you to run this workflow manually from the Actions tab
12 | workflow_dispatch:
13 |
14 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel
15 | jobs:
16 | publish_dev_build:
17 | runs-on: ubuntu-latest
18 |
19 | strategy:
20 | matrix:
21 | python-versions: [3.9]
22 |
23 | permissions:
24 | id-token: write
25 | contents: read
26 | steps:
27 | - uses: actions/checkout@v4
28 | - uses: actions/setup-python@v5
29 | with:
30 | python-version: ${{ matrix.python-versions }}
31 | cache: 'pip'
32 |
33 | - name: Install dependencies
34 | run: |
35 | sudo apt update && sudo apt install -y ghostscript
36 | python -m pip install --upgrade pip
37 | pip install pdm
38 | pdm install -G:all
39 |
40 | - name: Run Pytest
41 | run:
42 | pdm run pytest -sxv
43 |
44 | - name: Build wheels and source tarball
45 | run: |
46 | pdm build
47 |
48 | - name: Publish package distributions to TestPyPI
49 | uses: pypa/gh-action-pypi-publish@release/v1
50 | with:
51 | repository-url: https://test.pypi.org/legacy/
52 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: release & publish workflow
2 |
3 | on:
4 | push:
5 | tags:
6 | - 'v*'
7 | workflow_dispatch:
8 |
9 | jobs:
10 | release:
11 | name: Create Release
12 | runs-on: ubuntu-latest
13 |
14 | strategy:
15 | matrix:
16 | python-versions: [3.11]
17 |
18 | permissions:
19 | contents: write # This allows creating releases
20 |
21 | steps:
22 | - name: Get version from tag
23 | id: tag_name
24 | run: |
25 |           echo "current_version=${GITHUB_REF#refs/tags/v}" >> "$GITHUB_OUTPUT"
26 | shell: bash
27 |
28 | - uses: actions/checkout@v4
29 | - uses: actions/setup-python@v5
30 | with:
31 | python-version: ${{ matrix.python-versions }}
32 |
33 | - name: Install dependencies
34 | run: |
35 | python -m pip install --upgrade pip
36 | pip install pdm
37 | pdm install -G:all
38 |
39 | - name: Build wheels and source tarball
40 | run: >-
41 | pdm build
42 |
43 | - name: show temporary files
44 | run: >-
45 | ls -l
46 |
47 | - name: create github release
48 | id: create_release
49 | uses: softprops/action-gh-release@v2
50 | env:
51 | GITHUB_TOKEN: ${{ secrets.GH_TOKEN }}
52 | with:
53 | files: dist/*.whl
54 | draft: false
55 | prerelease: false
56 |
57 | - name: publish to PyPI
58 | uses: pypa/gh-action-pypi-publish@release/v1
59 | with:
60 | user: __token__
61 | password: ${{ secrets.PYPI_API_TOKEN }}
62 |           skip-existing: true
63 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # macOS
2 | .DS_Store
3 | .AppleDouble
4 | .LSOverride
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 |
15 | # Distribution / packaging
16 | .Python
17 | env/
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # pyenv
81 | .python-version
82 |
83 | # celery beat schedule file
84 | celerybeat-schedule
85 |
86 | # SageMath parsed files
87 | *.sage.py
88 |
89 | # dotenv
90 | .env
91 |
92 | # virtualenv
93 | .venv
94 | venv/
95 | ENV/
96 |
97 | # Spyder project settings
98 | .spyderproject
99 | .spyproject
100 |
101 | # Rope project settings
102 | .ropeproject
103 |
104 | # mkdocs documentation
105 | /site
106 |
107 | # mypy
108 | .mypy_cache/
109 |
110 | # IDE settings
111 | .vscode/
112 | .idea/
113 |
114 | # mkdocs build dir
115 | site/
116 |
117 | .creds/
118 |
119 | data/
120 |
121 | .pdm-python
122 | .aider*
123 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/Lucas-C/pre-commit-hooks
3 | rev: v1.1.9
4 | hooks:
5 | - id: forbid-crlf
6 | - id: remove-crlf
7 | - repo: https://github.com/pre-commit/pre-commit-hooks
8 | rev: v3.4.0
9 | hooks:
10 | - id: trailing-whitespace
11 | - id: end-of-file-fixer
12 | - id: check-merge-conflict
13 | - id: check-yaml
14 | args: [ --unsafe ]
15 | - repo: https://github.com/astral-sh/ruff-pre-commit
16 | # Ruff version.
17 | rev: v0.3.2
18 | hooks:
19 | # Run the linter.
20 | - id: ruff
21 | args: [ --fix ]
22 | # Run the formatter.
23 | - id: ruff-format
24 | - repo: https://github.com/pappasam/toml-sort
25 | rev: v0.23.1
26 | hooks:
27 | - id: toml-sort
28 | - id: toml-sort-fix
29 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | ## 0.1.0 (2023-10-18)
4 |
5 | * First release on PyPI.
6 |
--------------------------------------------------------------------------------
/CNAME:
--------------------------------------------------------------------------------
1 | docs.docprompt.io
2 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | Contributions are welcome, and they are greatly appreciated! Every little bit
4 | helps, and credit will always be given.
5 |
6 | You can contribute in many ways:
7 |
8 | ## Types of Contributions
9 |
10 | ### Report Bugs
11 |
12 | Report bugs at https://github.com/psu3d0/docprompt/issues.
13 |
14 | If you are reporting a bug, please include:
15 |
16 | * Your operating system name and version.
17 | * Any details about your local setup that might be helpful in troubleshooting.
18 | * Detailed steps to reproduce the bug.
19 |
20 | ### Fix Bugs
21 |
22 | Look through the GitHub issues for bugs. Anything tagged with "bug" and "help
23 | wanted" is open to whoever wants to implement it.
24 |
25 | ### Implement Features
26 |
27 | Look through the GitHub issues for features. Anything tagged with "enhancement"
28 | and "help wanted" is open to whoever wants to implement it.
29 |
30 | ### Write Documentation
31 |
32 | Docprompt could always use more documentation, whether as part of the
33 | official Docprompt docs, in docstrings, or even on the web in blog posts,
34 | articles, and such.
35 |
36 | ### Submit Feedback
37 |
38 | The best way to send feedback is to file an issue at https://github.com/psu3d0/docprompt/issues.
39 |
40 | If you are proposing a feature:
41 |
42 | * Explain in detail how it would work.
43 | * Keep the scope as narrow as possible, to make it easier to implement.
44 | * Remember that this is a volunteer-driven project, and that contributions
45 | are welcome :)
46 |
47 | ## Get Started!
48 |
49 | Ready to contribute? Here's how to set up `docprompt` for local development.
50 |
51 | 1. Fork the `docprompt` repo on GitHub.
52 | 2. Clone your fork locally
53 |
54 | ```
55 | $ git clone git@github.com:your_name_here/docprompt.git
56 | ```
57 |
58 | 3. Ensure [pdm](https://pdm-project.org/en/latest/) is installed.
59 | 4. Install dependencies and start your virtualenv:
60 |
61 | ```
62 | $ pdm install -d
63 | ```
64 |
65 | 5. Create a branch for local development:
66 |
67 | ```
68 | $ git checkout -b name-of-your-bugfix-or-feature
69 | ```
70 |
71 | Now you can make your changes locally.
72 |
73 | 6. When you're done making changes, check that your changes pass the
74 | tests, including testing other Python versions, with tox:
75 |
76 | ```
77 | $ pdm run tox
78 | ```
79 |
80 | 7. Commit your changes and push your branch to GitHub:
81 |
82 | ```
83 | $ git add .
84 | $ git commit -m "Your detailed description of your changes."
85 | $ git push origin name-of-your-bugfix-or-feature
86 | ```
87 |
88 | 8. Submit a pull request through the GitHub website.
89 |
90 | ## Pull Request Guidelines
91 |
92 | Before you submit a pull request, check that it meets these guidelines:
93 |
94 | 1. The pull request should include tests.
95 | 2. If the pull request adds functionality, the docs should be updated. Put
96 | your new functionality into a function with a docstring, and add the
97 | feature to the list in README.md.
98 | 3. The pull request should work for Python 3.8, 3.9, 3.10, 3.11 and 3.12. Check
99 | https://github.com/psu3d0/docprompt/actions
100 | and make sure that the tests pass for all supported Python versions.
101 |
102 | ## Tips
103 |
104 | To run a subset of tests:
105 |
106 | ```
107 | $ pdm run pytest tests/test_docprompt.py
108 | ```
109 |
110 |
111 | ## Deploying
112 |
113 | A reminder for the maintainers on how to deploy.
114 | Make sure all your changes are committed (including an entry in CHANGELOG.md).
115 | Then run:
116 |
117 | ```
118 | $ pdm run bump2version patch # possible: major / minor / patch
119 | $ git push
120 | $ git push --tags
121 | ```
122 |
123 | GitHub Actions will then deploy to PyPI if tests pass.
124 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache Software License 2.0
2 |
3 | Copyright (c) 2023, Frankie Colson
4 |
5 | Licensed under the Apache License, Version 2.0 (the "License");
6 | you may not use this file except in compliance with the License.
7 | You may obtain a copy of the License at
8 |
9 | http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | Unless required by applicable law or agreed to in writing, software
12 | distributed under the License is distributed on an "AS IS" BASIS,
13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | See the License for the specific language governing permissions and
15 | limitations under the License.
16 |
--------------------------------------------------------------------------------
/docprompt/__init__.py:
--------------------------------------------------------------------------------
1 | """Top-level package for Docprompt."""
2 |
3 | __author__ = """Frankie Colson"""
4 | __email__ = "frank@pageleaf.io"
5 | __version__ = "0.8.7"
6 |
7 | from docprompt.rasterize import ProviderResizeRatios
8 | from docprompt.schema.document import Document, PdfDocument # noqa
9 | from docprompt.schema.layout import NormBBox, TextBlock # noqa
10 | from docprompt.schema.pipeline import DocumentCollection, DocumentNode, PageNode # noqa
11 | from docprompt.tasks.ocr.result import OcrPageResult # noqa
12 | from docprompt.utils import ( # noqa
13 | hash_from_bytes,
14 | load_document,
15 | load_document_node,
16 | load_documents,
17 | load_pdf_document,
18 | load_pdf_documents,
19 | )
20 |
21 | # PdfDocument.model_rebuild()
22 | DocumentNode.model_rebuild()
23 |
24 |
25 | __all__ = [
26 | "Document",
27 | "PdfDocument",
28 | "DocumentCollection",
29 | "DocumentNode",
30 | "NormBBox",
31 | "PageNode",
32 | "TextBlock",
33 | "load_document",
34 | "load_documents",
35 | "hash_from_bytes",
36 | "ProviderResizeRatios",
37 | "load_pdf_document",
38 | ]
39 |
--------------------------------------------------------------------------------
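
Usage sketch for `docprompt/__init__.py`: a minimal load via the re-exported helpers. The definition of `load_pdf_document` lives in `docprompt/utils` (elided above), so the single path argument is an assumption:

```python
from docprompt import PdfDocument, load_pdf_document

# Assumption: load_pdf_document accepts a local file path; "example.pdf"
# is a placeholder.
document = load_pdf_document("example.pdf")
assert isinstance(document, PdfDocument)
```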
/docprompt/_decorators.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import sys
3 | from functools import partial, update_wrapper, wraps
4 | from typing import Callable, Optional, Set, Tuple, Type
5 |
6 | if sys.version_info >= (3, 9):
7 | to_thread = asyncio.to_thread
8 | else:
9 |
10 | def to_thread(func, /, *args, **kwargs):
11 | @wraps(func)
12 | async def wrapper():
13 | try:
14 | loop = asyncio.get_running_loop()
15 | except RuntimeError:
16 | # If there's no running event loop, create a new one
17 | loop = asyncio.new_event_loop()
18 | asyncio.set_event_loop(loop)
19 | pfunc = partial(func, *args, **kwargs)
20 | return await loop.run_in_executor(None, pfunc)
21 |
22 | return wrapper()
23 |
24 |
25 | def get_closest_attr(cls: Type, attr_name: str) -> Tuple[Type, Optional[Callable], int]:
26 | closest_cls = cls
27 |     attr = cls.__dict__.get(attr_name, None)
28 | depth = 0
29 |
30 | if attr and hasattr(attr, "_original"):
31 | attr = None
32 | elif attr:
33 | return (cls, attr, 0)
34 |
35 | for idx, base in enumerate(cls.__mro__, start=1):
36 | if not attr and attr_name in base.__dict__:
37 | if not hasattr(base.__dict__[attr_name], "_original"):
38 | closest_cls = base
39 | attr = base.__dict__[attr_name]
40 | depth = idx
41 |
42 | if attr:
43 | break
44 |
45 | return (closest_cls, attr, depth)
46 |
47 |
48 | def validate_method(cls, name: str, method: Callable, expected_async: bool):
49 | if method is None:
50 | return None
51 | is_async = asyncio.iscoroutinefunction(method)
52 | if is_async != expected_async:
53 | return f"Method '{name}' in {cls.__name__} should be {'async' if expected_async else 'sync'}, but it's {'async' if is_async else 'sync'}"
54 |
55 | return None
56 |
57 |
58 | def apply_dual_methods_to_cls(cls: Type, method_group: Tuple[str, str]):
59 | errors = []
60 |
61 | sync_name, async_name = method_group
62 |
63 | sync_trace = get_closest_attr(cls, sync_name)
64 | async_trace = get_closest_attr(cls, async_name)
65 |
66 | sync_cls, sync_method, sync_depth = sync_trace
67 | async_cls, async_method, async_depth = async_trace
68 |
69 | if sync_method:
70 | sync_error = validate_method(cls, sync_name, sync_method, False)
71 | if sync_error:
72 | errors.append(sync_error)
73 |
74 | if async_method:
75 | async_error = validate_method(cls, async_name, async_method, True)
76 | if async_error:
77 | errors.append(async_error)
78 |
79 | if (
80 | sync_method is None
81 | and async_method is None
82 | and not getattr(getattr(cls, "Meta", None), "abstract", False)
83 | ):
84 | return [
85 | f"{cls.__name__} must implement at least one of these methods: {sync_name}, {async_name}"
86 | ]
87 |
88 | if sync_cls is cls and async_cls is cls and sync_method and async_method:
89 | return errors # Both methods are already in the same class
90 |
91 | if async_cls is cls and async_method:
92 |
93 | def sync_wrapper(*args, **kwargs):
94 | return asyncio.run(async_method(*args, **kwargs))
95 |
96 | update_wrapper(sync_wrapper, async_method)
97 |
98 | sync_wrapper._original = async_method
99 |
100 | setattr(cls, sync_name, sync_wrapper)
101 | elif sync_cls is cls and sync_method:
102 |
103 | async def async_wrapper(*args, **kwargs):
104 | if hasattr(sync_method, "__func__"):
105 | return await to_thread(sync_method, *args, **kwargs)
106 | return await to_thread(sync_method, *args, **kwargs)
107 |
108 | update_wrapper(async_wrapper, sync_method)
109 |
110 | async_wrapper._original = sync_method
111 |
112 | setattr(cls, async_name, async_wrapper)
113 | else:
114 | if async_depth < sync_depth:
115 |
116 | def sync_wrapper(*args, **kwargs):
117 | return asyncio.run(async_method(*args, **kwargs))
118 |
119 | update_wrapper(sync_wrapper, async_method)
120 |
121 | sync_wrapper._original = async_method
122 |
123 | setattr(cls, sync_name, sync_wrapper)
124 | else:
125 |
126 | async def async_wrapper(*args, **kwargs):
127 | return await to_thread(sync_method, *args, **kwargs)
128 |
129 | update_wrapper(async_wrapper, sync_method)
130 |
131 | async_wrapper._original = sync_method
132 |
133 | setattr(cls, async_name, async_wrapper)
134 |
135 | return errors
136 |
137 |
138 | def get_flexible_method_configs(cls: Type) -> Set[Tuple[str, str]]:
139 |     configs = set()
140 |     for base in cls.__mro__:
141 |         configs.update(getattr(base, "__flexible_methods__", set()))
142 |
143 |     return configs
144 |
145 |
146 | def flexible_methods(*method_groups: Tuple[str, str]):
147 | def decorator(cls: Type):
148 | if not hasattr(cls, "__flexible_methods__"):
149 | setattr(cls, "__flexible_methods__", set())
150 |
151 | for base in cls.__bases__:
152 | if hasattr(base, "__flexible_methods__"):
153 | cls.__flexible_methods__.update(base.__flexible_methods__)
154 |
155 | cls.__flexible_methods__.update(method_groups)
156 |
157 | def apply_flexible_methods(cls: Type):
158 | errors = []
159 |
160 | for group in get_flexible_method_configs(cls):
161 | if len(group) != 2:
162 | errors.append(
163 | f"Invalid method group {group}. Each group must be a tuple of exactly two method names."
164 | )
165 | continue
166 |
167 | errors.extend(apply_dual_methods_to_cls(cls, group))
168 |
169 | if errors:
170 | raise TypeError("\n".join(errors))
171 |
172 | apply_flexible_methods(cls)
173 |
174 | original_init_subclass = cls.__init_subclass__
175 |
176 | @classmethod
177 | def new_init_subclass(cls, **kwargs):
178 | original_init_subclass(**kwargs)
179 | apply_flexible_methods(cls)
180 |
181 | cls.__init_subclass__ = new_init_subclass
182 |
183 | return cls
184 |
185 | return decorator
186 |
--------------------------------------------------------------------------------
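
Usage sketch for `docprompt/_decorators.py`: implement either half of a declared sync/async pair and the decorator synthesizes the other (`Processor`, `process`, and `aprocess` are hypothetical names for illustration):

```python
import asyncio

from docprompt._decorators import flexible_methods


@flexible_methods(("process", "aprocess"))
class Processor:
    # Only the async half is implemented; the decorator generates a sync
    # `process` wrapper that drives it with asyncio.run.
    async def aprocess(self, x: int) -> int:
        return x * 2


print(Processor().process(21))                # 42, via the generated wrapper
print(asyncio.run(Processor().aprocess(21)))  # 42, calling the async half directly
```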
/docprompt/_exec/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docprompt/_exec/__init__.py
--------------------------------------------------------------------------------
/docprompt/_exec/ghostscript.py:
--------------------------------------------------------------------------------
1 | from os import PathLike
2 | from pathlib import Path
3 | from subprocess import PIPE, CompletedProcess, run
4 | from typing import Literal, Union
5 |
6 | GS = "gs"
7 |
8 |
9 | class GhostscriptError(Exception):
10 | def __init__(self, message: str, process: CompletedProcess) -> None:
11 | self.process = process
12 | super().__init__(message)
13 |
14 |
15 | def compress_pdf(
16 | fp: Union[PathLike, str], # Ghostscript insists on a file instead of bytes
17 | output_path: str,
18 | *,
19 | compression: Literal["jpeg", "lossless"] = "jpeg",
20 | ):
21 | compression_args = []
22 | if compression == "jpeg":
23 | compression_args = [
24 | "-dAutoFilterColorImages=false",
25 | "-dColorImageFilter=/DCTEncode",
26 | "-dAutoFilterGrayImages=false",
27 | "-dGrayImageFilter=/DCTEncode",
28 | ]
29 | elif compression == "lossless":
30 | compression_args = [
31 | "-dAutoFilterColorImages=false",
32 | "-dColorImageFilter=/FlateEncode",
33 | "-dAutoFilterGrayImages=false",
34 | "-dGrayImageFilter=/FlateEncode",
35 | ]
36 | else:
37 | compression_args = [
38 | "-dAutoFilterColorImages=true",
39 | "-dAutoFilterGrayImages=true",
40 | ]
41 |
42 | args_gs = (
43 | [
44 | GS,
45 | "-q",
46 | "-dBATCH",
47 | "-dNOPAUSE",
48 | "-dSAFER",
49 | "-dCompatibilityLevel=1.5",
50 | "-sDEVICE=pdfwrite",
51 | "-dAutoRotatePages=/None",
52 | "-sColorConversionStrategy=LeaveColorUnchanged",
53 | ]
54 | + compression_args
55 | + [
56 | "-dJPEGQ=95",
57 | "-dPDFA=2",
58 | "-dPDFACompatibilityPolicy=1",
59 | "-sOutputFile=" + output_path,
60 | str(fp),
61 | ]
62 | )
63 |
64 | result = run(args_gs, stdout=PIPE, stderr=PIPE, check=False)
65 |
66 | if result.returncode != 0:
67 | raise GhostscriptError("Ghostscript failed to compress the document", result)
68 |
69 | return result
70 |
71 |
72 | def compress_pdf_to_bytes(
73 | fp: Union[PathLike, str], *, compression: Literal["jpeg", "lossless"] = "jpeg"
74 | ) -> bytes:
75 | result = compress_pdf(fp, output_path="%stdout", compression=compression)
76 |
77 | return result.stdout
78 |
79 |
80 | def compress_pdf_to_path(
81 | fp: Union[PathLike, str],
82 | output_path: PathLike,
83 | *,
84 | compression: Literal["jpeg", "lossless"] = "jpeg",
85 | ) -> Path:
86 | compress_pdf(fp, output_path=str(output_path), compression=compression)
87 |
88 | return Path(output_path)
89 |
--------------------------------------------------------------------------------
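
Usage sketch for `docprompt/_exec/ghostscript.py`, assuming the `gs` binary is on PATH (file paths are placeholders):

```python
from docprompt._exec.ghostscript import compress_pdf_to_bytes, compress_pdf_to_path

# Write the recompressed PDF to disk...
out_path = compress_pdf_to_path("input.pdf", "compressed.pdf", compression="jpeg")

# ...or capture it in memory via Ghostscript's %stdout output device.
pdf_bytes = compress_pdf_to_bytes("input.pdf", compression="lossless")
print(out_path, len(pdf_bytes))
```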
/docprompt/_exec/tesseract.py:
--------------------------------------------------------------------------------
1 | import io
2 | import re
3 | import xml.etree.ElementTree as ET
4 | from os import PathLike
5 | from subprocess import PIPE, CompletedProcess, run
6 | from typing import List, Optional, TypedDict, Union
7 |
8 | from PIL import Image
9 |
10 | TESSERACT = "tesseract"
11 |
12 |
13 | def check_tesseract_installed() -> bool:
14 | result = run(
15 | [TESSERACT, "--version"], stdout=PIPE, stderr=PIPE, check=False, text=True
16 | )
17 | return result.returncode == 0
18 |
19 |
20 | class TesseractError(Exception):
21 | def __init__(self, message: str, process: CompletedProcess) -> None:
22 | self.process = process
23 | super().__init__(message)
24 |
25 |
26 | class BoundingBox(TypedDict):
27 | x: int
28 | y: int
29 | width: int
30 | height: int
31 |
32 |
33 | class Word(TypedDict):
34 | id: str
35 | bbox: BoundingBox
36 | text: str
37 | line_id: str
38 | block_id: str
39 |
40 |
41 | class Line(TypedDict):
42 | id: str
43 | bbox: BoundingBox
44 | text: str
45 | words: List[str]
46 | block_id: str
47 |
48 |
49 | class Block(TypedDict):
50 | id: str
51 | bbox: BoundingBox
52 | text: str
53 |
54 |
55 | class OCRResult(TypedDict):
56 | blocks: List[Block]
57 | lines: List[Line]
58 | words: List[Word]
59 | language: str
60 |
61 |
62 | def process_image(
63 | fp: Union[PathLike, str],
64 | *,
65 | lang: str = "eng",
66 |     config: Optional[List[str]] = None,
67 | ) -> str:
68 | args_tesseract = [
69 | TESSERACT,
70 | str(fp),
71 | "stdout",
72 | "-l",
73 | lang,
74 | ]
75 |
76 |     # Copy so the caller's list isn't mutated between calls
77 |     config = list(config or [])
78 |
79 |     # Add HOCR output format to get word bounding boxes
80 |     config.extend(["-c", "tessedit_create_hocr=1"])
81 |
82 | args_tesseract.extend(config)
83 |
84 | result = run(args_tesseract, stdout=PIPE, stderr=PIPE, check=False, text=True)
85 |
86 | if result.returncode != 0:
87 | raise TesseractError(
88 | f"Tesseract failed to process the image, {result.stderr}", result
89 | )
90 |
91 | return result.stdout
92 |
93 |
94 | def get_bbox(element) -> BoundingBox:
95 | title = element.get("title")
96 | bbox = [int(x) for x in title.split(";")[0].split(" ")[1:]]
97 | return BoundingBox(
98 | x=bbox[0], y=bbox[1], width=bbox[2] - bbox[0], height=bbox[3] - bbox[1]
99 | )
100 |
101 |
102 | def clean_text(text: str) -> str:
103 | # Remove extra whitespace
104 | text = re.sub(r" \n ", " ", text)
105 | text = re.sub(r"\n+", "\n", text)
106 | text = re.sub(r" +", " ", text).strip()
107 | return text.strip()
108 |
109 |
110 | def process_image_to_dict(
111 | fp: Union[PathLike, str],
112 | *,
113 | lang: str = "eng",
114 |     config: Optional[List[str]] = None,
115 | ) -> OCRResult:
116 | hocr_content = process_image(fp, lang=lang, config=config)
117 |
118 | # Use StringIO to create a file-like object from the string
119 | hocr_file = io.StringIO(hocr_content)
120 | root = ET.parse(hocr_file).getroot()
121 |
122 | # Get image dimensions
123 | image = Image.open(fp)
124 | img_width, img_height = image.size
125 | image.close()
126 |
127 | blocks: List[Block] = []
128 | lines: List[Line] = []
129 | words: List[Word] = []
130 |
131 | block_id = 0
132 | line_id = 0
133 | word_id = 0
134 |
135 | def normalize_bbox(bbox: BoundingBox) -> BoundingBox:
136 | return BoundingBox(
137 | x=bbox["x"] / img_width,
138 | y=bbox["y"] / img_height,
139 | width=bbox["width"] / img_width,
140 | height=bbox["height"] / img_height,
141 | )
142 |
143 | for block in root.findall(".//*[@class='ocr_carea']"):
144 | block_bbox = normalize_bbox(get_bbox(block))
145 | block_text = clean_text(" ".join(block.itertext()))
146 | blocks.append(Block(id=f"block_{block_id}", bbox=block_bbox, text=block_text))
147 |
148 | for line in block.findall(".//*[@class='ocr_line']"):
149 | line_bbox = normalize_bbox(get_bbox(line))
150 | line_words: List[str] = []
151 | line_text: List[str] = []
152 |
153 | for word in line.findall(".//*[@class='ocrx_word']"):
154 | word_bbox = normalize_bbox(get_bbox(word))
155 | word_text = clean_text(word.text) if word.text else ""
156 | words.append(
157 | Word(
158 | id=f"word_{word_id}",
159 | bbox=word_bbox,
160 | text=word_text,
161 | line_id=f"line_{line_id}",
162 | block_id=f"block_{block_id}",
163 | )
164 | )
165 | line_words.append(f"word_{word_id}")
166 | line_text.append(word_text)
167 | word_id += 1
168 |
169 | lines.append(
170 | Line(
171 | id=f"line_{line_id}",
172 | bbox=line_bbox,
173 | text=" ".join(line_text),
174 | words=line_words,
175 | block_id=f"block_{block_id}",
176 | )
177 | )
178 | line_id += 1
179 |
180 | block_id += 1
181 |
182 | return OCRResult(
183 | blocks=blocks,
184 | lines=lines,
185 | words=words,
186 | language=lang,
187 | )
188 |
--------------------------------------------------------------------------------
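
Usage sketch for `docprompt/_exec/tesseract.py`, assuming the `tesseract` CLI is installed ("page.png" is a placeholder image path):

```python
from docprompt._exec.tesseract import check_tesseract_installed, process_image_to_dict

if check_tesseract_installed():
    result = process_image_to_dict("page.png", lang="eng")
    # Bounding boxes come back normalized to the 0-1 range against the image size.
    for word in result["words"][:5]:
        print(word["text"], word["bbox"])
```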
/docprompt/contrib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docprompt/contrib/__init__.py
--------------------------------------------------------------------------------
/docprompt/contrib/parser_bot/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docprompt/contrib/parser_bot/__init__.py
--------------------------------------------------------------------------------
/docprompt/provenance/__init__.py:
--------------------------------------------------------------------------------
1 | from .source import PageTextLocation, ProvenanceSource
2 |
3 | __all__ = ["ProvenanceSource", "PageTextLocation"]
4 |
--------------------------------------------------------------------------------
/docprompt/provenance/source.py:
--------------------------------------------------------------------------------
1 | from typing import List, Literal, Optional
2 |
3 | from pydantic import BaseModel, Field, PositiveInt, computed_field
4 |
5 | from docprompt.schema.layout import TextBlock
6 |
7 |
8 | class PageTextLocation(BaseModel):
9 | """
10 | Specifies the location of a piece of text in a page
11 | """
12 |
13 | source_blocks: List[TextBlock] = Field(
14 | description="The source text blocks", repr=False
15 | )
16 | text: str # Sometimes the source text is less than the textblock's text.
17 | score: float
18 | granularity: Literal["word", "line", "block"] = "block"
19 |
20 | merged_source_block: Optional[TextBlock] = Field(default=None)
21 |
22 |
23 | class ProvenanceSource(BaseModel):
24 | """
25 |     Specifies exactly where a piece of verbatim text came from in a document;
26 |     typically bundled alongside the data it supports.
27 | """
28 |
29 | document_name: str
30 | page_number: PositiveInt
31 | text_location: Optional[PageTextLocation] = None
32 |
33 | @computed_field # type: ignore
34 | @property
35 | def source_block(self) -> Optional[TextBlock]:
36 | if self.text_location:
37 | if self.text_location.merged_source_block:
38 | return self.text_location.merged_source_block
39 | if self.text_location.source_blocks:
40 | return self.text_location.source_blocks[0]
41 |
42 | return None
43 |
44 | @property
45 | def text(self) -> str:
46 | if self.text_location:
47 | return "\n".join([block.text for block in self.text_location.source_blocks])
48 |
49 | return ""
50 |
--------------------------------------------------------------------------------
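
Construction sketch for `docprompt/provenance/source.py`; the `TextBlock` and `NormBBox` field names are inferred from their usage elsewhere in this package, and all values are illustrative:

```python
from docprompt.provenance.source import PageTextLocation, ProvenanceSource
from docprompt.schema.layout import NormBBox, TextBlock

block = TextBlock(
    text="Total: $42.00",
    type="block",
    bounding_box=NormBBox(x0=0.1, top=0.2, x1=0.5, bottom=0.25),
)
source = ProvenanceSource(
    document_name="invoice.pdf",  # placeholder
    page_number=1,
    text_location=PageTextLocation(
        source_blocks=[block], text="Total: $42.00", score=1.0
    ),
)
# With no merged block set, source_block falls back to the first source block.
assert source.source_block is block
assert source.text == "Total: $42.00"
```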
/docprompt/provenance/util.py:
--------------------------------------------------------------------------------
1 | import re
2 | from collections import defaultdict
3 | from typing import Any, Iterable, List, Optional
4 |
5 | from rapidfuzz import fuzz
6 | from rapidfuzz.utils import default_process
7 |
8 | from docprompt.schema.layout import NormBBox, TextBlock
9 |
10 | try:
11 | import tantivy
12 | except ImportError:
13 | raise ImportError("Could not import tantivy. Install with `docprompt[search]`")
14 |
15 | try:
16 | import networkx
17 | except ImportError:
18 | raise ImportError("Could not import networkx. Install with `docprompt[search]`")
19 |
20 |
21 | _prefix_regexs = [
22 | re.compile(r"^\d+\.\s+"),
23 | re.compile(r"^\d+\.\d+\s+"),
24 | re.compile(r"^\*+\s+"),
25 | re.compile(r"^-+\s+"),
26 | ]
27 |
28 |
29 | def preprocess_query_text(text: str) -> str:
30 | """
31 | Improve matching ability by applying some preprocessing to the query text.
32 | """
33 | for regex in _prefix_regexs:
34 | text = regex.sub("", text)
35 |
36 | text = text.strip()
37 |
38 | text = text.replace('"', "")
39 |
40 | return text
41 |
42 |
43 | def word_tokenize(text: str) -> List[str]:
44 | """
45 | Tokenize a string into words.
46 | """
47 | return re.split(r"\s+", text)
48 |
49 |
50 | def create_tantivy_document_wise_block_index():
51 | schema_builder = tantivy.SchemaBuilder()
52 |
53 | schema_builder.add_integer_field(
54 | "page_number", stored=True, indexed=True, fast=True
55 | )
56 | schema_builder.add_text_field("block_type", stored=True)
57 | schema_builder.add_integer_field("block_page_idx", stored=True)
58 | schema_builder.add_text_field("content", stored=True)
59 |
60 | schema = schema_builder.build()
61 |
62 | index = tantivy.Index(schema=schema)
63 |
64 | return index
65 |
66 |
67 | def construct_valid_rtree_tuple(bbox: NormBBox):
68 |     # For some reason the bounding box is sometimes invalid (top > bottom, or x0 > x1).
69 |     # This function ensures the bounding box is valid for the rtree index.
70 |
71 | true_top = min(bbox.top, bbox.bottom)
72 | true_bottom = max(bbox.top, bbox.bottom)
73 |
74 | true_x0 = min(bbox.x0, bbox.x1)
75 | true_x1 = max(bbox.x0, bbox.x1)
76 |
77 | return (true_x0, true_top, true_x1, true_bottom)
78 |
79 |
80 | def insert_generator(bboxes: List[NormBBox], data: Optional[Iterable[Any]] = None):
81 | """
82 | Make an iterator that yields tuples of (id, bbox, data) for insertion into an RTree index
83 | which improves performance massively.
84 | """
85 | data = data or [None] * len(bboxes)
86 |
87 | for idx, (bbox, data_item) in enumerate(zip(bboxes, data)):
88 | yield (idx, construct_valid_rtree_tuple(bbox), data_item)
89 |
90 |
91 | def refine_block_to_word_level(
92 | source_block: TextBlock,
93 | intersecting_word_level_blocks: List[TextBlock],
94 | query: str,
95 | ):
96 | """
97 | Create a new text block by merging the intersecting word level blocks that
98 | match the query.
99 |
100 | """
101 | intersecting_word_level_blocks.sort(
102 | key=lambda x: (x.bounding_box.top, x.bounding_box.x0)
103 | )
104 |
105 | tokenized_query = word_tokenize(query)
106 |
107 | if len(tokenized_query) == 1:
108 | fuzzified = default_process(tokenized_query[0])
109 | for word_level_block in intersecting_word_level_blocks:
110 | if fuzz.ratio(fuzzified, default_process(word_level_block.text)) > 87.5:
111 | return word_level_block, [word_level_block]
112 | else:
113 | fuzzified_word_level_texts = [
114 | default_process(word_level_block.text)
115 | for word_level_block in intersecting_word_level_blocks
116 | ]
117 |
118 | # Populate the block mapping
119 | token_block_mapping = defaultdict(set)
120 |
121 | first_word = tokenized_query[0]
122 | last_word = tokenized_query[-1]
123 |
124 | for token in tokenized_query:
125 | fuzzified_token = default_process(token)
126 | for i, word_level_block in enumerate(intersecting_word_level_blocks):
127 | if fuzz.ratio(fuzzified_token, fuzzified_word_level_texts[i]) > 87.5:
128 | token_block_mapping[token].add(i)
129 |
130 | graph = networkx.DiGraph()
131 | prev = tokenized_query[0]
132 |
133 | for i in token_block_mapping[prev]:
134 | graph.add_node(i)
135 |
136 | for token in tokenized_query[1:]:
137 | for prev_block in token_block_mapping[prev]:
138 | for block in sorted(token_block_mapping[token]):
139 | if block > prev_block:
140 | weight = (
141 | (block - prev_block) ** 2
142 | ) # Square the distance to penalize large jumps, which encourages reading order
143 | graph.add_edge(prev_block, block, weight=weight)
144 |
145 | prev = token
146 |
147 | # Get every combination of first and last word
148 | first_word_blocks = token_block_mapping[first_word]
149 | last_word_blocks = token_block_mapping[last_word]
150 |
151 | combinations = sorted(
152 | [(x, y) for x in first_word_blocks for y in last_word_blocks if x < y],
153 | key=lambda x: abs(x[1] - x[0]),
154 | )
155 |
156 | for start, end in combinations:
157 | try:
158 | path = networkx.shortest_path(graph, start, end, weight="weight")
159 | except networkx.NetworkXNoPath:
160 | continue
161 | except Exception:
162 | continue
163 |
164 | matching_blocks = [intersecting_word_level_blocks[i] for i in path]
165 |
166 | merged_bbox = NormBBox.combine(
167 | *[word_level_block.bounding_box for word_level_block in matching_blocks]
168 | )
169 |
170 | merged_text = ""
171 |
172 | for word_level_block in matching_blocks:
173 | merged_text += word_level_block.text
174 | if not word_level_block.text.endswith(" "):
175 | merged_text += " " # Ensure there is a space between words
176 |
177 | return (
178 | TextBlock(
179 | text=merged_text,
180 | type="block",
181 | bounding_box=merged_bbox,
182 | metadata=source_block.metadata,
183 | ),
184 | matching_blocks,
185 | )
186 |
--------------------------------------------------------------------------------
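
Sketch of the bulk-load pattern `insert_generator` is written for, using the third-party `rtree` package (assumed to accompany the `docprompt[search]` extra alongside tantivy and networkx; bbox values are illustrative):

```python
from rtree import index

from docprompt.provenance.util import insert_generator
from docprompt.schema.layout import NormBBox

bboxes = [NormBBox(x0=0.1, top=0.1, x1=0.2, bottom=0.2)]

# Streaming (id, coords, obj) tuples into the Index constructor bulk-loads
# the tree, which is far faster than inserting boxes one at a time.
idx = index.Index(insert_generator(bboxes, data=["block-0"]))
print(list(idx.intersection((0.0, 0.0, 1.0, 1.0))))  # [0]
```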
/docprompt/schema/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docprompt/schema/__init__.py
--------------------------------------------------------------------------------
/docprompt/schema/pipeline/__init__.py:
--------------------------------------------------------------------------------
1 | from .metadata import BaseMetadata
2 | from .node import DocumentCollection, DocumentNode, ImageNode, PageNode
3 |
4 | __all__ = [
5 | "DocumentCollection",
6 | "DocumentNode",
7 | "PageNode",
8 | "ImageNode",
9 | "BaseMetadata",
10 | ]
11 |
--------------------------------------------------------------------------------
/docprompt/schema/pipeline/node/__init__.py:
--------------------------------------------------------------------------------
1 | from .collection import DocumentCollection
2 | from .document import DocumentNode
3 | from .image import ImageNode
4 | from .page import PageNode
5 | from .typing import DocumentNodeMetadata, PageNodeMetadata
6 |
7 | __all__ = [
8 | "DocumentNode",
9 | "PageNode",
10 | "ImageNode",
11 | "DocumentNodeMetadata",
12 | "PageNodeMetadata",
13 | "DocumentCollection",
14 | ]
15 |
--------------------------------------------------------------------------------
/docprompt/schema/pipeline/node/base.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 |
3 |
4 | class BaseNode(BaseModel):
5 | """The base node class is utilized for defining a basic yet flexible interface"""
6 |
--------------------------------------------------------------------------------
/docprompt/schema/pipeline/node/collection.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING, Generic, List
2 |
3 | from pydantic import BaseModel, Field
4 |
5 | from .typing import DocumentCollectionMetadata, DocumentNodeMetadata, PageNodeMetadata
6 |
7 | if TYPE_CHECKING:
8 | from .document import DocumentNode
9 |
10 |
11 | class DocumentCollection(
12 | BaseModel,
13 | Generic[DocumentCollectionMetadata, DocumentNodeMetadata, PageNodeMetadata],
14 | ):
15 | """
16 | Represents a collection of documents with some common metadata
17 | """
18 |
19 | document_nodes: List["DocumentNode[DocumentNodeMetadata, PageNodeMetadata]"]
20 |     metadata: DocumentCollectionMetadata = Field(default_factory=dict)
21 |
--------------------------------------------------------------------------------
/docprompt/schema/pipeline/node/image.py:
--------------------------------------------------------------------------------
1 | import base64
2 | from typing import Generic
3 |
4 | from pydantic import Field, field_serializer, field_validator
5 |
6 | from docprompt.schema.pipeline.metadata import BaseMetadata
7 |
8 | from .base import BaseNode
9 | from .typing import ImageNodeMetadata
10 |
11 |
12 | class ImageNode(BaseNode, Generic[ImageNodeMetadata]):
13 | """
14 | Represents a single image of any kind
15 | """
16 |
17 | image: bytes
18 |
19 | metadata: ImageNodeMetadata = Field(
20 | description="Application-specific metadata for the image",
21 | default_factory=BaseMetadata,
22 | )
23 |
24 | @field_serializer("image")
25 | def serialize_image(self, value):
26 | return base64.b64encode(value).decode("utf-8")
27 |
28 | @field_validator("image")
29 | @classmethod
30 | def validate_image(cls, value):
31 | if isinstance(value, bytes):
32 | return value
33 |
34 | return base64.b64decode(value)
35 |
36 | @property
37 | def pil_image(self):
38 | from io import BytesIO
39 |
40 | from PIL import Image
41 |
42 | return Image.open(BytesIO(self.image))
43 |
44 | @property
45 | def cv2_image(self):
46 | try:
47 | import cv2
48 | except ImportError:
49 | raise ImportError("OpenCV is required to use this property")
50 |
51 | try:
52 | import numpy as np
53 | except ImportError:
54 | raise ImportError("Numpy is required to use this property")
55 |
56 | return cv2.imdecode(np.frombuffer(self.image, np.uint8), cv2.IMREAD_COLOR)
57 |
--------------------------------------------------------------------------------
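
Round-trip sketch for `docprompt/schema/pipeline/node/image.py` (placeholder bytes, not a real image, so `pil_image` is never touched; the base64 branch of the validator only sees string input because it runs in `mode="before"`):

```python
import base64

from docprompt.schema.pipeline.node.image import ImageNode

payload = b"\x89PNG\r\n"  # placeholder bytes
node = ImageNode(image=payload)

# Serialization emits base64 text; validation accepts it back as raw bytes.
dumped = node.model_dump()
assert dumped["image"] == base64.b64encode(payload).decode("utf-8")
assert ImageNode(image=dumped["image"]).image == payload
```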
/docprompt/schema/pipeline/node/page.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Union
2 |
3 | from pydantic import Field, PositiveInt
4 |
5 | from docprompt.schema.pipeline.metadata import BaseMetadata
6 | from docprompt.schema.pipeline.rasterizer import PageRasterizer
7 |
8 | from .base import BaseNode
9 | from .typing import PageNodeMetadata
10 |
11 | if TYPE_CHECKING:
12 | from docprompt.tasks.ocr.result import OcrPageResult
13 |
14 | from .document import DocumentNode
15 |
16 |
17 | class SimplePageNodeMetadata(BaseMetadata):
18 | """
19 | A simple metadata class for a page node
20 | """
21 |
22 | ocr_results: Optional["OcrPageResult"] = Field(
23 | None, description="The OCR results for the page"
24 | )
25 |
26 |
27 | class PageNode(BaseNode, Generic[PageNodeMetadata]):
28 | """
29 | Represents a single page in a document, with some metadata
30 | """
31 |
32 | document: "DocumentNode" = Field(exclude=True, repr=False)
33 | page_number: PositiveInt = Field(description="The page number")
34 | metadata: Union[PageNodeMetadata, SimplePageNodeMetadata] = Field(
35 | description="Application-specific metadata for the page",
36 | default_factory=SimplePageNodeMetadata,
37 | )
38 | extra: Dict[str, Any] = Field(
39 | description="Extra data that can be stored on the page node",
40 | default_factory=dict,
41 | )
42 |
43 | @property
44 | def rasterizer(self):
45 | return PageRasterizer(self)
46 |
47 | @property
48 | def ocr_results(self) -> Optional["OcrPageResult"]:
49 | from docprompt.tasks.ocr.result import OcrPageResult
50 |
51 | return self.metadata.find_by_type(OcrPageResult)
52 |
53 | @ocr_results.setter
54 | def ocr_results(self, value):
55 | if not hasattr(self.metadata, "ocr_results"):
56 | raise AttributeError(
57 | "Page metadata does not have an `ocr_results` attribute"
58 | )
59 |
60 | self.metadata.ocr_results = value
61 |
62 | def search(
63 | self, query: str, refine_to_words: bool = True, require_exact_match: bool = True
64 | ):
65 | return self.document.locator.search(
66 | query,
67 | page_number=self.page_number,
68 | refine_to_word=refine_to_words,
69 | require_exact_match=require_exact_match,
70 | )
71 |
72 | def get_layout_aware_text(self, **kwargs) -> str:
73 | if not self.ocr_results:
74 | raise ValueError("Calculate OCR results before calling layout_aware_text")
75 |
76 | from docprompt.utils.layout import build_layout_aware_page_representation
77 |
78 | word_blocks = self.ocr_results.word_level_blocks
79 |
80 | line_blocks = self.ocr_results.line_level_blocks
81 |
82 | if not len(line_blocks):
83 | line_blocks = None
84 |
85 | return build_layout_aware_page_representation(
86 | word_blocks, line_blocks=line_blocks, **kwargs
87 | )
88 |
89 | @property
90 | def layout_aware_text(self):
91 | return self.get_layout_aware_text()
92 |
--------------------------------------------------------------------------------
/docprompt/schema/pipeline/node/typing.py:
--------------------------------------------------------------------------------
1 | from typing import TypeVar
2 |
3 | from ..metadata import BaseMetadata
4 |
5 | ImageNodeMetadata = TypeVar("ImageNodeMetadata", bound=BaseMetadata)
6 | PageNodeMetadata = TypeVar("PageNodeMetadata", bound=BaseMetadata)
7 | DocumentNodeMetadata = TypeVar("DocumentNodeMetadata", bound=BaseMetadata)
8 | DocumentCollectionMetadata = TypeVar("DocumentCollectionMetadata", bound=BaseMetadata)
9 |
--------------------------------------------------------------------------------
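
Sketch of how these TypeVars parametrize the generic nodes, assuming `BaseMetadata` subclasses like an ordinary pydantic model (`CaptionMetadata` is hypothetical):

```python
from docprompt.schema.pipeline.metadata import BaseMetadata
from docprompt.schema.pipeline.node.image import ImageNode


class CaptionMetadata(BaseMetadata):
    caption: str = ""


# Parametrizing the generic binds the node's `metadata` field to the subclass.
node = ImageNode[CaptionMetadata](
    image=b"...",  # placeholder bytes
    metadata=CaptionMetadata(caption="a scanned receipt"),
)
print(node.metadata.caption)
```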
/docprompt/tasks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docprompt/tasks/__init__.py
--------------------------------------------------------------------------------
/docprompt/tasks/capabilities.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
4 | class PageLevelCapabilities(str, Enum):
5 | """
6 | Represents a capability that a provider can fulfill
7 | """
8 |
9 | PAGE_RASTERIZATION = "page-rasterization"
10 | PAGE_LAYOUT_OCR = "page-layout-ocr"
11 | PAGE_TEXT_OCR = "page-text-ocr"
12 | PAGE_CLASSIFICATION = "page-classification"
13 | PAGE_MARKERIZATION = "page-markerization"
14 | PAGE_SEGMENTATION = "page-segmentation"
15 | PAGE_VQA = "page-vqa"
16 | PAGE_TABLE_IDENTIFICATION = "page-table-identification"
17 | PAGE_TABLE_EXTRACTION = "page-table-extraction"
18 |
19 |
20 | class DocumentLevelCapabilities(str, Enum):
21 | DOCUMENT_VQA = "multi-page-document-vqa"
22 |
--------------------------------------------------------------------------------
/docprompt/tasks/classification/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docprompt/tasks/classification/__init__.py
--------------------------------------------------------------------------------
/docprompt/tasks/classification/anthropic.py:
--------------------------------------------------------------------------------
1 | """The anthropic implementation of page level classification."""
2 |
3 | import re
4 | from typing import Iterable, List, Optional
5 |
6 | from pydantic import Field
7 |
8 | from docprompt.tasks.message import OpenAIComplexContent, OpenAIImageURL, OpenAIMessage
9 | from docprompt.utils import inference
10 |
11 | from .base import (
12 | BaseClassificationProvider,
13 | BasePageClassificationOutputParser,
14 | ClassificationConfig,
15 | ClassificationOutput,
16 | )
17 |
18 |
19 | def get_classification_system_prompt(input: ClassificationConfig) -> str:
20 | prompt_parts = [
21 | "You are a classification expert. You are given a single page to perform a classification task on.\n"
22 | ]
23 |
24 | if input.instructions:
25 | prompt_parts.append(f"Task Instructions:\n{input.instructions}\n\n")
26 |
27 | if input.type == "binary":
28 | prompt_parts.append(
29 | 'You must classify the page with a binary label:\n"YES"/"NO"\n'
30 | )
31 | else:
32 | classification_task = (
33 | "all labels that apply"
34 | if input.type == "multi_label"
35 | else "one of the following"
36 | )
37 | prompt_parts.append(f"Classify the page as {classification_task}:\n")
38 | for label in input.formatted_labels:
39 | prompt_parts.append(f"- {label}\n")
40 | prompt_parts.append(
41 | "\nThese are the only label values you may use when providing your classifications!\n"
42 | )
43 |
44 | prompt_parts.append(
45 | "\nIt is crucial that your response is accurate and provides a valid answer using "
46 | )
47 | if input.type == "multi_label":
48 | prompt_parts.append("the labels ")
49 | else:
50 | prompt_parts.append("one of the labels ")
51 | prompt_parts.append(
52 | "above. There are consequences for providing INVALID or INACCURATE labels.\n\n"
53 | )
54 |
55 | prompt_parts.append(
56 | "Answer in the following format:\n\nReasoning: { your reasoning and analysis }\n"
57 | )
58 |
59 | if input.type == "binary":
60 | prompt_parts.append('Answer: { "YES" or "NO" }\n')
61 | elif input.type == "single_label":
62 | prompt_parts.append('Answer: { "label-value" }\n')
63 | else:
64 | prompt_parts.append('Answer: { "label-value", "label-value", ... }\n')
65 |
66 | if input.confidence:
67 | prompt_parts.append("Confidence: { low, medium, high }\n")
68 |
69 | prompt_parts.append(
70 | "\nYou MUST ONLY use the labels provided and described above. Do not use ANY additional labels.\n"
71 | )
72 |
73 | return "".join(prompt_parts).strip()
74 |
75 |
76 | class AnthropicPageClassificationOutputParser(BasePageClassificationOutputParser):
77 | """The output parser for the page classification system."""
78 |
79 | def parse(self, text: str) -> ClassificationOutput:
80 | """Parse the results of the classification task."""
81 | pattern = re.compile(r"Answer:\s*(?:['\"`]?)(.+?)(?:['\"`]?)\s*$", re.MULTILINE)
82 | match = pattern.search(text)
83 |
84 | result = self.resolve_match(match)
85 |
86 | if self.confidence:
87 | conf_pattern = re.compile(r"Confidence: (.+)")
88 | conf_match = conf_pattern.search(text)
89 | conf_result = self.resolve_confidence(conf_match)
90 |
91 | return ClassificationOutput(
92 | type=self.type,
93 | labels=result,
94 | score=conf_result,
95 | provider_name=self.name,
96 | )
97 |
98 | return ClassificationOutput(
99 | type=self.type, labels=result, provider_name=self.name
100 | )
101 |
102 |
103 | def _prepare_messages(
104 | document_images: Iterable[bytes],
105 | config: ClassificationConfig,
106 | ):
107 | messages = []
108 |
109 | for image_bytes in document_images:
110 | messages.append(
111 | [
112 | OpenAIMessage(
113 | role="user",
114 | content=[
115 | OpenAIComplexContent(
116 | type="image_url",
117 | image_url=OpenAIImageURL(url=image_bytes),
118 | ),
119 | OpenAIComplexContent(
120 | type="text",
121 | text=get_classification_system_prompt(config),
122 | ),
123 | ],
124 | ),
125 | ]
126 | )
127 |
128 | return messages
129 |
130 |
131 | class AnthropicClassificationProvider(BaseClassificationProvider):
132 | """The Anthropic implementation of unscored page classification."""
133 |
134 | name = "anthropic"
135 |
136 | anthropic_model_name: str = Field("claude-3-haiku-20240307")
137 |
138 | async def _ainvoke(
139 |         self, input: Iterable[bytes], config: Optional[ClassificationConfig] = None, **kwargs
140 | ) -> List[ClassificationOutput]:
141 | messages = _prepare_messages(input, config)
142 |
143 | parser = AnthropicPageClassificationOutputParser.from_task_input(
144 | config, provider_name=self.name
145 | )
146 |
147 | model_name = kwargs.pop("model_name", self.anthropic_model_name)
148 | completions = await inference.run_batch_inference_anthropic(
149 | model_name, messages, **kwargs
150 | )
151 | return [parser.parse(res) for res in completions]
152 |
--------------------------------------------------------------------------------
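
Sketch of the prompt builder in isolation. `ClassificationConfig` is defined in `.base` (elided above), so a `SimpleNamespace` carrying the four attributes the function reads stands in for it:

```python
from types import SimpleNamespace

from docprompt.tasks.classification.anthropic import get_classification_system_prompt

config = SimpleNamespace(
    instructions="Classify each page of this filing.",
    type="single_label",  # "binary" | "single_label" | "multi_label"
    formatted_labels=["invoice", "receipt", "contract"],
    confidence=True,
)
print(get_classification_system_prompt(config))
```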
/docprompt/tasks/credentials.py:
--------------------------------------------------------------------------------
1 | """The credentials module defines a simple model schema for storing credentials."""
2 |
3 | import os
4 | from typing import Dict, Mapping, Optional
5 |
6 | from pydantic import BaseModel, Field, HttpUrl, SecretStr, model_validator
7 | from typing_extensions import Self
8 |
9 |
10 | class BaseCredentials(BaseModel):
11 | """The base credentials model."""
12 |
13 | @property
14 | def kwargs(self) -> Dict[str, str]:
15 | """Return the credentials as a dictionary with secrets exposed."""
16 | data = self.model_dump(exclude_none=True)
17 | for key, value in data.items():
18 | if isinstance(value, SecretStr):
19 | data[key] = value.get_secret_value()
20 | return data
21 |
22 |
23 | class APIKeyCredential(BaseCredentials):
24 | """The API key credential model."""
25 |
26 | api_key: SecretStr
27 |
28 | def __init__(self, environ_path: Optional[str] = None, **data):
29 | api_key = data.get("api_key", None)
30 | if api_key is None and environ_path:
31 | api_key = os.environ.get(environ_path, None)
32 | data["api_key"] = api_key
33 |
34 | super().__init__(**data)
35 |
36 |
37 | class GenericOpenAICredentials(APIKeyCredential):
38 | """Credentials that are common for OpenAI API requests."""
39 |
40 | base_url: Optional[HttpUrl] = Field(None)
41 | timeout: Optional[int] = Field(None)
42 | max_retries: Optional[int] = Field(None)
43 |
44 | default_headers: Optional[Mapping[str, str]] = Field(None)
45 | default_query_params: Optional[Mapping[str, object]] = Field(None)
46 |
47 |
48 | class AWSCredentials(BaseCredentials):
49 | """The AWS credentials model."""
50 |
51 | aws_access_key_id: Optional[SecretStr] = Field(None)
52 | aws_secret_access_key: Optional[SecretStr] = Field(None)
53 | aws_session_token: Optional[SecretStr] = Field(None)
54 | aws_region: Optional[str] = Field(None)
55 |
56 | def __init__(self, **data):
57 | aws_access_key_id = data.get(
58 | "aws_access_key_id", os.environ.get("AWS_ACCESS_KEY_ID", None)
59 | )
60 | aws_secret_access_key = data.get(
61 | "aws_secret_access_key", os.environ.get("AWS_SECRET_ACCESS_KEY", None)
62 | )
63 | aws_session_token = data.get(
64 | "aws_session_token", os.environ.get("AWS_SESSION_TOKEN", None)
65 | )
66 | aws_region = data.get("aws_region", os.environ.get("AWS_DEFAULT_REGION", None))
67 |
68 | super().__init__(
69 | aws_access_key_id=aws_access_key_id,
70 | aws_secret_access_key=aws_secret_access_key,
71 | aws_session_token=aws_session_token,
72 | aws_region=aws_region,
73 | )
74 |
75 | @model_validator(mode="after")
76 | def _validate_aws_credentials(self) -> Self:
77 | """Ensure the provided AWS credentials are valid."""
78 |
79 | key_pair_is_set = self.aws_access_key_id and self.aws_secret_access_key
80 |
81 | if not key_pair_is_set and not self.aws_session_token:
82 | raise ValueError(
83 | "You must provide either an AWS session token or an access key and secret key."
84 | )
85 |
86 | if key_pair_is_set and not self.aws_region:
87 | raise ValueError(
88 | "You must provide an AWS region when using an access key and secret key."
89 | )
90 |
91 | if key_pair_is_set and self.aws_session_token:
92 | raise ValueError(
93 | "You cannot provide both an AWS session token and an access key and secret key."
94 | )
95 |
96 | return self
97 |
98 |
99 | class GCPServiceFileCredentials(BaseCredentials):
100 | """The GCP service account credentials model."""
101 |
102 | service_account_info: Optional[Dict[str, str]] = Field(None)
103 | service_account_file: Optional[str] = Field(None)
104 |
105 | def __init__(self, **data):
106 | service_account_info = data.get("service_account_info", None)
107 | service_account_file = data.get(
108 | "service_account_file", os.environ.get("GCP_SERVICE_ACCOUNT_FILE", None)
109 | )
110 |
111 | super().__init__(
112 | service_account_info=service_account_info,
113 | service_account_file=service_account_file,
114 | )
115 |
116 | @model_validator(mode="after")
117 | def _validate_gcp_credentials(self) -> Self:
118 | """Ensure the provided GCP credentials are valid."""
119 | if self.service_account_info is None and self.service_account_file is None:
120 | raise ValueError(
121 | "You must provide either service_account_info or service_account_file. You may set the `GCP_SERVICE_ACCOUNT_FILE` environment variable to the path of the service account file."
122 | )
123 | if (
124 | self.service_account_info is not None
125 | and self.service_account_file is not None
126 | ):
127 | raise ValueError(
128 | "You must provide either service_account_info or service_account_file, not both"
129 | )
130 | return self
131 |
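132 |
133 | if __name__ == "__main__":
134 |     # A minimal sketch of how these models behave. The key value and the
135 |     # environment variable name below are illustrative, not part of the library.
136 |     creds = APIKeyCredential(api_key="sk-example")
137 |     print(creds.kwargs)  # {'api_key': 'sk-example'} -- secret exposed as a plain string
138 |
139 |     # Or pull the key from an environment variable of your choosing:
140 |     # creds = APIKeyCredential(environ_path="MY_PROVIDER_API_KEY")
141 |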
--------------------------------------------------------------------------------
/docprompt/tasks/markerize/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docprompt/tasks/markerize/__init__.py
--------------------------------------------------------------------------------
/docprompt/tasks/markerize/anthropic.py:
--------------------------------------------------------------------------------
1 | from typing import Iterable, List, Optional
2 |
3 | from bs4 import BeautifulSoup
4 | from pydantic import Field
5 |
6 | from docprompt.tasks.message import OpenAIComplexContent, OpenAIImageURL, OpenAIMessage
7 | from docprompt.utils import inference
8 |
9 | from .base import BaseMarkerizeProvider, MarkerizeResult
10 |
11 | _HUMAN_MESSAGE_PROMPT = """
12 | Convert the image into markdown, preserving the overall layout and style of the page. \
13 | Use the appropriate headings for different sections. Preserve bolded and italicized text. \
14 | Include ALL the text on the page.
15 |
16 | You ALWAYS respond by wrapping the markdown in <md> </md> tags.
17 | """.strip()
18 |
19 |
20 | def ensure_single_root(xml_data: str) -> str:
21 | """Ensure the XML data has a single root element."""
22 |     if not xml_data.strip().startswith("<md>") and not xml_data.strip().endswith(
23 |         "</md>"
24 |     ):
25 |         return f"<md>{xml_data}</md>"
26 | return xml_data
27 |
28 |
29 | def _parse_result(raw_markdown: str) -> Optional[str]:
30 | raw_markdown = ensure_single_root(raw_markdown)
31 | soup = BeautifulSoup(raw_markdown, "html.parser")
32 |
33 | md = soup.find("md")
34 |
35 | return md.text.strip() if md else "" # TODO Fix bad extractions
36 |
37 |
38 | def _prepare_messages(
39 | document_images: Iterable[bytes],
40 | start: Optional[int] = None,
41 | stop: Optional[int] = None,
42 | ):
43 | messages = []
44 |
45 | for image_bytes in document_images:
46 | messages.append(
47 | [
48 | OpenAIMessage(
49 | role="user",
50 | content=[
51 | OpenAIComplexContent(
52 | type="image_url",
53 | image_url=OpenAIImageURL(url=image_bytes),
54 | ),
55 | OpenAIComplexContent(type="text", text=_HUMAN_MESSAGE_PROMPT),
56 | ],
57 | ),
58 | ]
59 | )
60 |
61 | return messages
62 |
63 |
64 | class AnthropicMarkerizeProvider(BaseMarkerizeProvider):
65 | name = "anthropic"
66 |
67 | anthropic_model_name: str = Field("claude-3-haiku-20240307")
68 |
69 | async def _ainvoke(
70 | self, input: Iterable[bytes], config: Optional[None] = None, **kwargs
71 | ) -> List[MarkerizeResult]:
72 | messages = _prepare_messages(input)
73 |
74 | model_name = kwargs.pop("model_name", self.anthropic_model_name)
75 | completions = await inference.run_batch_inference_anthropic(
76 | model_name, messages, **kwargs
77 | )
78 |
79 | return [
80 | MarkerizeResult(raw_markdown=_parse_result(x), provider_name=self.name)
81 | for x in completions
82 | ]
83 |
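84 |
85 | if __name__ == "__main__":
86 |     # A minimal sketch of the parsing step in isolation; the completion text
87 |     # below is illustrative of what the prompt asks the model to produce.
88 |     completion = "<md># Invoice\n\n**Total:** $42.00</md>"
89 |     print(_parse_result(completion))
90 |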
--------------------------------------------------------------------------------
/docprompt/tasks/markerize/base.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from docprompt.schema.pipeline.node.document import DocumentNode
4 | from docprompt.tasks.base import AbstractPageTaskProvider, BasePageResult
5 |
6 | from ..capabilities import PageLevelCapabilities
7 |
8 |
9 | class MarkerizeResult(BasePageResult):
10 | task_name = "markerize"
11 | raw_markdown: str
12 |
13 |
14 | class BaseMarkerizeProvider(AbstractPageTaskProvider[bytes, None, MarkerizeResult]):
15 | capabilities = [PageLevelCapabilities.PAGE_MARKERIZATION]
16 |
17 | class Meta:
18 | abstract = True
19 |
20 | def process_document_node(
21 | self,
22 | document_node: "DocumentNode",
23 | task_config: Optional[None] = None,
24 | start: Optional[int] = None,
25 | stop: Optional[int] = None,
26 | contribute_to_document: bool = True,
27 | **kwargs,
28 | ):
29 | raster_bytes = []
30 | for page_number in range(start or 1, (stop or len(document_node)) + 1):
31 | image_bytes = document_node.page_nodes[
32 | page_number - 1
33 | ].rasterizer.rasterize("default")
34 | raster_bytes.append(image_bytes)
35 |
36 | # TODO: This is a somewhat dangerous way of requiring these kwargs to be drilled
37 | # through, potentially a decorator solution to be had here
38 | kwargs = {**self._default_invoke_kwargs, **kwargs}
39 | results = self._invoke(raster_bytes, config=task_config, **kwargs)
40 |
41 | return {
42 | i: res
43 | for i, res in zip(
44 | range(start or 1, (stop or len(document_node)) + 1), results
45 | )
46 | }
47 |
--------------------------------------------------------------------------------
/docprompt/tasks/message.py:
--------------------------------------------------------------------------------
1 | """
2 | The core primitives for any language model interfacing. Docprompt uses these for the prompt garden, but
3 | supports free conversion to and from these types from other libraries.
4 | """
5 |
6 | from typing import List, Literal, Optional, Union
7 |
8 | from pydantic import BaseModel, model_validator
9 |
10 |
11 | def _ensure_png_base64_prefix(base64_string: str):
12 | prefix = "data:image/png;base64,"
13 | if base64_string.startswith(prefix):
14 | return base64_string
15 | else:
16 | return prefix + base64_string
17 |
18 |
19 | def _strip_png_base64_prefix(base64_string: str):
20 | prefix = "data:image/png;base64,"
21 | if base64_string.startswith(prefix):
22 | return base64_string[len(prefix) :]
23 | else:
24 | return base64_string
25 |
26 |
27 | class OpenAIImageURL(BaseModel):
28 | url: str
29 |
30 |
31 | class OpenAIComplexContent(BaseModel):
32 | type: Literal["text", "image_url"]
33 | text: Optional[str] = None
34 | image_url: Optional[OpenAIImageURL] = None
35 |
36 | @model_validator(mode="after")
37 | def validate_content(cls, v):
38 | if v.type == "text" and v.text is None:
39 | raise ValueError("Text content must be provided when type is 'text'")
40 | if v.type == "image_url" and v.image_url is None:
41 | raise ValueError(
42 | "Image URL content must be provided when type is 'image_url'"
43 | )
44 |
45 | if v.text is not None and v.image_url is not None:
46 | raise ValueError("Only one of text or image_url can be provided")
47 |
48 | return v
49 |
50 | def to_anthropic_message(self):
51 | if self.type == "text":
52 | return {"type": "text", "text": self.text}
53 | elif self.type == "image_url":
54 | return {
55 | "type": "image",
56 | "source": {
57 | "data": _strip_png_base64_prefix(self.image_url.url),
58 | "media_type": "image/png",
59 | "type": "base64",
60 | },
61 | }
62 | else:
63 | raise ValueError(f"Invalid content type: {self.type}")
64 |
65 |
66 | class OpenAIMessage(BaseModel):
67 | role: Literal["system", "user", "assistant"]
68 | content: Union[str, List[OpenAIComplexContent]]
69 |
70 | def to_langchain_message(self):
71 | try:
72 | from langchain.schema import AIMessage, HumanMessage, SystemMessage
73 | except ImportError:
74 | raise ImportError(
75 | "Could not import langchain.schema. Install with `docprompt[langchain]`"
76 | )
77 |
78 | role_mapping = {
79 | "system": SystemMessage,
80 | "user": HumanMessage,
81 | "assistant": AIMessage,
82 | }
83 |
84 | dumped = self.model_dump(mode="json", exclude_unset=True, exclude_none=True)
85 |
86 | return role_mapping[self.role](content=dumped["content"])
87 |
88 | def to_openai(self):
89 | return self.model_dump(mode="json", exclude_unset=True, exclude_none=True)
90 |
91 | def to_llamaindex_chat_message(self):
92 | try:
93 | from llama_index.core.base.llms.types import ChatMessage, MessageRole
94 | except ImportError:
95 | raise ImportError(
96 | "Could not import llama_index.core. Install with `docprompt[llamaindex]`"
97 | )
98 |
99 | role_mapping = {
100 | "system": MessageRole.SYSTEM,
101 | "user": MessageRole.USER,
102 | "assistant": MessageRole.ASSISTANT,
103 | }
104 |
105 | dumped = self.model_dump(mode="json", exclude_unset=True, exclude_none=True)
106 |
107 | return ChatMessage.from_str(
108 | content=dumped["content"], role=role_mapping[self.role]
109 | )
110 |
111 | @classmethod
112 | def from_image_uri(cls, image_uri: str) -> "OpenAIMessage":
113 | """Create an image message from a URI.
114 |
115 | Args:
116 | role: The role of the message.
117 | image_uri: The URI of the image.
118 | """
119 | image_url = OpenAIImageURL(url=image_uri)
120 | content = OpenAIComplexContent(type="image_url", image_url=image_url)
121 | message = cls(role="user", content=[content])
122 | return message
123 |
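124 |
125 | if __name__ == "__main__":
126 |     # A minimal sketch of message conversion; the data URI is a truncated
127 |     # placeholder, not a real image.
128 |     message = OpenAIMessage.from_image_uri("data:image/png;base64,iVBORw0KGgo=")
129 |     print(message.to_openai())  # plain dict in the OpenAI chat format
130 |     # Individual content parts convert to Anthropic's block format:
131 |     print(message.content[0].to_anthropic_message())
132 |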
--------------------------------------------------------------------------------
/docprompt/tasks/ocr/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docprompt/tasks/ocr/__init__.py
--------------------------------------------------------------------------------
/docprompt/tasks/ocr/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import TYPE_CHECKING, Dict, Optional, Union
3 |
4 | from docprompt.schema.document import PdfDocument
5 | from docprompt.schema.pipeline import DocumentNode
6 | from docprompt.tasks.base import AbstractPageTaskProvider
7 | from docprompt.tasks.ocr.result import OcrPageResult
8 |
9 | if TYPE_CHECKING:
10 | from docprompt.schema.pipeline import DocumentNode
11 |
12 |
13 | ImageBytes = bytes
14 |
15 |
16 | class BaseOCRProvider(
17 | AbstractPageTaskProvider[Union[PdfDocument, ImageBytes], None, OcrPageResult]
18 | ):
19 | def _populate_ocr_results(
20 | self,
21 | document_node: "DocumentNode",
22 | results: Dict[int, OcrPageResult],
23 | add_images_to_raster_cache: bool = False,
24 | raster_cache_key: str = "default",
25 | ) -> None:
26 | for page_number, result in results.items():
27 | result.contribute_to_document_node(
28 | document_node,
29 | page_number=page_number,
30 | add_images_to_raster_cache=add_images_to_raster_cache,
31 | raster_cache_key=raster_cache_key,
32 | )
33 |
34 | @abstractmethod
35 | def process_document_node(
36 | self,
37 | document_node: "DocumentNode",
38 | task_config: Optional[None] = None,
39 | start: Optional[int] = None,
40 | stop: Optional[int] = None,
41 | contribute_to_document: bool = True,
42 | **kwargs,
43 | ) -> Dict[int, OcrPageResult]: ...
44 |
--------------------------------------------------------------------------------
/docprompt/tasks/ocr/result.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 | from typing import Any, Dict, List, Optional
3 |
4 | from pydantic import Field
5 |
6 | from docprompt.schema.layout import TextBlock
7 | from docprompt.schema.pipeline.node.document import DocumentNode
8 | from docprompt.tasks.base import BasePageResult
9 |
10 |
11 | class OcrPageResult(BasePageResult):
12 | page_text: str = Field(description="The text for the entire page in reading order")
13 |
14 | word_level_blocks: List[TextBlock] = Field(
15 | default_factory=list,
16 | description="The provider-sourced words for the page",
17 | repr=False,
18 | )
19 | line_level_blocks: List[TextBlock] = Field(
20 | default_factory=list,
21 | description="The provider-sourced lines for the page",
22 | repr=False,
23 | )
24 | block_level_blocks: List[TextBlock] = Field(
25 | default_factory=list,
26 | description="The provider-sourced blocks for the page",
27 | repr=False,
28 | )
29 |
30 | raster_image: Optional[bytes] = Field(
31 | default=None,
32 | description="The rasterized image of the page used in OCR",
33 | repr=False,
34 | )
35 |
36 | extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
37 |
38 | task_name = "ocr"
39 |
40 | @property
41 | def pil_image(self):
42 | if not self.raster_image:
43 | return None
44 | from PIL import Image
45 |
46 | return Image.open(BytesIO(self.raster_image))
47 |
48 | @property
49 | def words(self):
50 | return self.word_level_blocks
51 |
52 | @property
53 | def lines(self):
54 | return self.line_level_blocks
55 |
56 | @property
57 | def blocks(self):
58 | return self.block_level_blocks
59 |
60 | def contribute_to_document_node(
61 | self,
62 | document_node: DocumentNode,
63 | page_number: Optional[int] = None,
64 | add_images_to_raster_cache: bool = False,
65 | raster_cache_key: str = "default",
66 | **kwargs,
67 | ) -> None:
68 | if not page_number:
69 | raise ValueError("Page number must be provided for page level results")
70 |
71 | page_node = document_node.page_nodes[page_number - 1]
72 | if hasattr(page_node.metadata, "ocr_results"):
73 | page_node.metadata.ocr_results = self.model_copy(
74 | update={"raster_image": None}
75 | )
76 | else:
77 | super().contribute_to_document_node(document_node, page_number=page_number)
78 |
79 | if self.raster_image is not None and add_images_to_raster_cache:
80 | document_node.rasterizer.cache.set_image_for_page(
81 | key=raster_cache_key,
82 | page_number=page_number,
83 | image_bytes=self.raster_image,
84 | )
85 |
--------------------------------------------------------------------------------
/docprompt/tasks/parser.py:
--------------------------------------------------------------------------------
1 | """The base output parser that seeks to mimic the langhain implementation."""
2 |
3 | from abc import abstractmethod
4 | from typing import TypeVar
5 |
6 | from pydantic import BaseModel
7 | from typing_extensions import Generic
8 |
9 | TTaskInput = TypeVar("TTaskInput", bound=BaseModel)
10 | TTaskOutput = TypeVar("TTaskOutput", bound=BaseModel)
11 |
12 |
13 | class BaseOutputParser(BaseModel, Generic[TTaskInput, TTaskOutput]):
14 | """The output parser for the page classification system."""
15 |
16 | @abstractmethod
17 | def from_task_input(
18 | cls, task_input: TTaskInput
19 | ) -> "BaseOutputParser[TTaskInput, TTaskOutput]":
20 | """Create an output parser from the task input."""
21 |
22 | @abstractmethod
23 | def parse(self, text: str) -> TTaskOutput:
24 | """Parse the results of the classification task."""
25 |
--------------------------------------------------------------------------------
/docprompt/tasks/result.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from collections.abc import MutableMapping
3 | from datetime import datetime
4 | from typing import TYPE_CHECKING, ClassVar, Dict, Generic, Optional, TypeVar
5 |
6 | from pydantic import BaseModel, Field
7 |
8 | if TYPE_CHECKING:
9 | from docprompt.schema.pipeline import DocumentNode
10 |
11 |
12 | class BaseResult(BaseModel):
13 | provider_name: str = Field(
14 | description="The name of the provider which produced the result"
15 | )
16 | when: datetime = Field(
17 | default_factory=datetime.now, description="The time the result was produced"
18 | )
19 |
20 | task_name: ClassVar[str]
21 |
22 | @property
23 | def task_key(self):
24 | return f"{self.provider_name}_{self.task_name}"
25 |
26 | @abstractmethod
27 | def contribute_to_document_node(
28 | self, document_node: "DocumentNode", **kwargs
29 | ) -> None:
30 | """
31 | Contribute this task result to the document node or a specific page node.
32 |
33 | :param document_node: The DocumentNode to contribute to
34 | :param page_number: If provided, contribute to a specific page. If None, contribute to the document.
35 | """
36 |
37 |
38 | class BaseDocumentResult(BaseResult):
39 | def contribute_to_document_node(
40 | self, document_node: "DocumentNode", **kwargs
41 | ) -> None:
42 | document_node.metadata.task_results[self.task_key] = self
43 |
44 |
45 | class BasePageResult(BaseResult):
46 | def contribute_to_document_node(
47 | self, document_node: "DocumentNode", page_number: Optional[int] = None, **kwargs
48 | ) -> None:
49 | assert (
50 | page_number is not None
51 | ), "Page number must be provided for page level results"
52 | assert (
53 | 0 < page_number <= len(document_node)
54 | ), "Page number must be less than or equal to the number of pages in the document"
55 |
56 | page_node = document_node.page_nodes[page_number - 1]
57 | page_node.metadata.task_results[self.task_key] = self.model_copy()
58 |
59 |
60 | TTaskInput = TypeVar("TTaskInput") # What invoke requires
61 | TTaskConfig = TypeVar("TTaskConfig") # Task specific config like classification labels
62 | PageTaskResult = TypeVar("PageTaskResult", bound=BasePageResult)
63 | DocumentTaskResult = TypeVar("DocumentTaskResult", bound=BaseDocumentResult)
64 | PageOrDocumentTaskResult = TypeVar("PageOrDocumentTaskResult", bound=BaseResult)
65 |
66 |
67 | class ResultContainer(BaseModel, MutableMapping, Generic[PageOrDocumentTaskResult]):
68 | results: Dict[str, PageOrDocumentTaskResult] = Field(
69 | description="The results of the task", default_factory=dict
70 | )
71 |
72 | @property
73 | def result(self):
74 | return next(iter(self.results.values()), None)
75 |
76 | def __setitem__(self, key, value):
77 | if key in self.results:
78 | raise ValueError(f"Result with key {key} already exists")
79 |
80 | self.results[key] = value
81 |
82 | def __delitem__(self, key):
83 | del self.results[key]
84 |
85 | def __getitem__(self, key):
86 | return self.results[key]
87 |
88 | def __iter__(self):
89 | return iter(self.results)
90 |
91 | def __len__(self):
92 | return len(self.results)
93 |
94 | def __contains__(self, item):
95 | return item in self.results
96 |
97 | def __bool__(self):
98 | return bool(self.results)
99 |
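100 |
101 | if __name__ == "__main__":
102 |     # A minimal sketch of the container semantics. ToyResult is illustrative
103 |     # and not part of the library.
104 |     class ToyResult(BasePageResult):
105 |         task_name = "toy"
106 |
107 |     container = ResultContainer[ToyResult]()
108 |     result = ToyResult(provider_name="example")
109 |     container[result.task_key] = result
110 |     print(container.result.task_key)  # example_toy
111 |     # Inserting again under the same key raises ValueError by design.
112 |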
--------------------------------------------------------------------------------
/docprompt/tasks/table_extraction/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docprompt/tasks/table_extraction/__init__.py
--------------------------------------------------------------------------------
/docprompt/tasks/table_extraction/anthropic.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Iterable, List, Optional, Union
3 |
4 | from bs4 import BeautifulSoup, Tag
5 | from pydantic import Field
6 |
7 | from docprompt.tasks.message import OpenAIComplexContent, OpenAIImageURL, OpenAIMessage
8 | from docprompt.utils import inference
9 |
10 | from .base import BaseTableExtractionProvider
11 | from .schema import (
12 | ExtractedTable,
13 | TableCell,
14 | TableExtractionPageResult,
15 | TableHeader,
16 | TableRow,
17 | )
18 |
19 | SYSTEM_PROMPT = """
20 | You are given an image. Identify and extract all tables from the document.
21 |
22 | For each table, respond in the following format:
23 |
24 | <table>
25 | <title>(value)</title>
26 | <headers>
27 | <header>(value)</header>
28 | ...
29 | </headers>
30 | <rows>
31 | <row>
32 | <column>(value)</column>
33 | ...
34 | </row>
35 | ...
36 | </rows>
37 | </table>
38 | """.strip()
39 |
40 | def _title_from_tree(tree: Tag) -> Union[str, None]:
41 | title = tree.find("title")
42 | if title is not None:
43 | return title.text
44 | return None
45 |
46 |
47 | def _headers_from_tree(tree: Tag) -> List[TableHeader]:
48 | headers = tree.find("headers")
49 | if headers is not None:
50 | return [
51 | TableHeader(text=header.text or "") for header in headers.find_all("header")
52 | ]
53 | return []
54 |
55 |
56 | def _rows_from_tree(tree: Tag) -> List[TableRow]:
57 | rows = tree.find("rows")
58 | if rows is not None:
59 | return [
60 | TableRow(
61 | cells=[
62 | TableCell(text=cell.text or "") for cell in row.find_all("column")
63 | ]
64 | )
65 | for row in rows.find_all("row")
66 | ]
67 | return []
68 |
69 |
70 | def _find_start_indices(s: str, sub: str) -> List[int]:
71 | return [m.start() for m in re.finditer(sub, s)]
72 |
73 |
74 | def _find_end_indices(s: str, sub: str) -> List[int]:
75 | return [m.end() for m in re.finditer(sub, s)]
76 |
77 |
78 | def parse_response(response: str, **kwargs) -> TableExtractionPageResult:
79 |     table_start_indices = _find_start_indices(response, "<table>")
80 |     table_end_indices = _find_end_indices(response, "</table>")
81 |
82 | tables: List[ExtractedTable] = []
83 | provider_name = kwargs.pop("provider_name", "anthropic")
84 |
85 | for table_start, table_end in zip(table_start_indices, table_end_indices):
86 | table_str = response[table_start:table_end]
87 |
88 | soup = BeautifulSoup(table_str, "html.parser")
89 |
90 | table_element = soup.find("table")
91 |
92 | title = _title_from_tree(table_element)
93 | headers = _headers_from_tree(table_element)
94 | rows = _rows_from_tree(table_element)
95 |
96 | tables.append(ExtractedTable(title=title, headers=headers, rows=rows))
97 |
98 | result = TableExtractionPageResult(tables=tables, provider_name=provider_name)
99 | return result
100 |
101 |
102 | def _prepare_messages(
103 | document_images: Iterable[bytes],
104 | start: Optional[int] = None,
105 | stop: Optional[int] = None,
106 | ):
107 | messages = []
108 |
109 | for image_bytes in document_images:
110 | messages.append(
111 | [
112 | OpenAIMessage(
113 | role="user",
114 | content=[
115 | OpenAIComplexContent(
116 | type="image_url",
117 | image_url=OpenAIImageURL(url=image_bytes),
118 | ),
119 | OpenAIComplexContent(type="text", text=SYSTEM_PROMPT),
120 | ],
121 | ),
122 | ]
123 | )
124 |
125 | return messages
126 |
127 |
128 | class AnthropicTableExtractionProvider(BaseTableExtractionProvider):
129 | name = "anthropic"
130 |
131 | anthropic_model_name: str = Field("claude-3-haiku-20240307")
132 |
133 | async def _ainvoke(
134 | self, input: Iterable[bytes], config: Optional[None] = None, **kwargs
135 | ) -> List[TableExtractionPageResult]:
136 | messages = _prepare_messages(input)
137 |
138 | model_name = kwargs.pop("model_name", self.anthropic_model_name)
139 | completions = await inference.run_batch_inference_anthropic(
140 | model_name, messages, **kwargs
141 | )
142 |
143 | return [parse_response(x, provider_name=self.name) for x in completions]
144 |
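145 |
146 | if __name__ == "__main__":
147 |     # A minimal sketch of the parsing step in isolation; the completion below
148 |     # follows the format requested by SYSTEM_PROMPT.
149 |     completion = (
150 |         "<table><title>Totals</title>"
151 |         "<headers><header>Item</header><header>Cost</header></headers>"
152 |         "<rows><row><column>Widget</column><column>$5</column></row></rows>"
153 |         "</table>"
154 |     )
155 |     result = parse_response(completion)
156 |     print(result.tables[0].to_markdown_string())
157 |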
--------------------------------------------------------------------------------
/docprompt/tasks/table_extraction/base.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from docprompt import DocumentNode
4 | from docprompt.tasks.base import AbstractPageTaskProvider
5 | from docprompt.tasks.capabilities import PageLevelCapabilities
6 |
7 | from .schema import TableExtractionPageResult
8 |
9 |
10 | class BaseTableExtractionProvider(
11 | AbstractPageTaskProvider[bytes, None, TableExtractionPageResult]
12 | ):
13 | capabilities = [
14 | PageLevelCapabilities.PAGE_TABLE_EXTRACTION,
15 | PageLevelCapabilities.PAGE_TABLE_IDENTIFICATION,
16 | ]
17 |
18 | class Meta:
19 | abstract = True
20 |
21 | def process_document_node(
22 | self,
23 | document_node: DocumentNode,
24 | task_config: Optional[None] = None,
25 | start: Optional[int] = None,
26 | stop: Optional[int] = None,
27 | contribute_to_document: bool = True,
28 | **kwargs,
29 | ):
30 | raster_bytes = []
31 | for page_number in range(start or 1, (stop or len(document_node)) + 1):
32 | image_bytes = document_node.page_nodes[
33 | page_number - 1
34 | ].rasterizer.rasterize("default")
35 | raster_bytes.append(image_bytes)
36 |
37 |         # Each result is a TableExtractionPageResult for the corresponding page
38 | results = self._invoke(raster_bytes, config=task_config, **kwargs)
39 |
40 | return {
41 | i: res
42 | for i, res in zip(
43 | range(start or 1, (stop or len(document_node)) + 1), results
44 | )
45 | }
46 |
--------------------------------------------------------------------------------
/docprompt/tasks/table_extraction/schema.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | from pydantic import BaseModel, Field
4 |
5 | from docprompt.schema.layout import NormBBox
6 | from docprompt.tasks.base import BasePageResult
7 |
8 |
9 | class TableHeader(BaseModel):
10 | text: str
11 | bbox: Optional[NormBBox] = None
12 |
13 |
14 | class TableCell(BaseModel):
15 | text: str
16 | bbox: Optional[NormBBox] = None
17 |
18 |
19 | class TableRow(BaseModel):
20 | cells: List[TableCell] = Field(default_factory=list)
21 | bbox: Optional[NormBBox] = None
22 |
23 |
24 | class ExtractedTable(BaseModel):
25 | title: Optional[str] = None
26 | bbox: Optional[NormBBox] = None
27 |
28 | headers: List[TableHeader] = Field(default_factory=list)
29 | rows: List[TableRow] = Field(default_factory=list)
30 |
31 | def to_markdown_string(self) -> str:
32 | markdown = ""
33 |
34 | # Add title if present
35 | if self.title:
36 | markdown += f"# {self.title}\n\n"
37 |
38 | # Create header row
39 | header_row = "|" + "|".join(header.text for header in self.headers) + "|\n"
40 | markdown += header_row
41 |
42 | # Create separator row
43 | separator_row = "|" + "|".join("---" for _ in self.headers) + "|\n"
44 | markdown += separator_row
45 |
46 | # Create data rows
47 | for row in self.rows:
48 | data_row = "|" + "|".join(cell.text for cell in row.cells) + "|\n"
49 | markdown += data_row
50 |
51 | return markdown.strip()
52 |
53 |
54 | class TableExtractionPageResult(BasePageResult):
55 | tables: List[ExtractedTable] = Field(default_factory=list)
56 |
57 | task_name = "table_extraction"
58 |
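59 |
60 | if __name__ == "__main__":
61 |     # A minimal sketch of the markdown rendering; the values are illustrative.
62 |     table = ExtractedTable(
63 |         title="Quarterly Revenue",
64 |         headers=[TableHeader(text="Quarter"), TableHeader(text="Revenue")],
65 |         rows=[TableRow(cells=[TableCell(text="Q1"), TableCell(text="$10k")])],
66 |     )
67 |     print(table.to_markdown_string())
68 |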
--------------------------------------------------------------------------------
/docprompt/tasks/util.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | from contextvars import ContextVar
3 | from typing import Any, Dict, Iterator
4 |
5 | _init_context_var = ContextVar("_init_context_var", default=None)
6 |
7 |
8 | @contextmanager
9 | def init_context(value: Dict[str, Any]) -> Iterator[None]:
10 | token = _init_context_var.set(value)
11 | try:
12 | yield
13 | finally:
14 | _init_context_var.reset(token)
15 |
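16 |
17 | if __name__ == "__main__":
18 |     # A minimal sketch: while the context manager is active, code that reads
19 |     # `_init_context_var` (e.g. model validators) sees the supplied mapping.
20 |     # The key below is illustrative.
21 |     with init_context({"page_count": 10}):
22 |         print(_init_context_var.get())  # {'page_count': 10}
23 |     print(_init_context_var.get())  # None -- reset on exit
24 |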
--------------------------------------------------------------------------------
/docprompt/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .date_extraction import extract_dates_from_text
2 | from .util import (
3 | get_page_count,
4 | hash_from_bytes,
5 | is_pdf,
6 | load_document,
7 | load_document_node,
8 | load_documents,
9 | load_pdf_document,
10 | load_pdf_documents,
11 | )
12 |
13 | __all__ = [
14 | "get_page_count",
15 | "is_pdf",
16 | "load_pdf_document",
17 | "load_pdf_documents",
18 | "load_document",
19 | "load_documents",
20 | "hash_from_bytes",
21 | "extract_dates_from_text",
22 | "load_document_node",
23 | ]
24 |
--------------------------------------------------------------------------------
/docprompt/utils/compressor.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 | from typing import Literal
3 |
4 | from docprompt._exec.ghostscript import compress_pdf_to_bytes
5 |
6 |
7 | def compress_pdf_bytes(
8 | file_bytes: bytes, *, compression: Literal["jpeg", "lossless"] = "jpeg"
9 | ) -> bytes:
10 | with tempfile.NamedTemporaryFile(suffix=".pdf") as temp_file:
11 | temp_file.write(file_bytes)
12 | temp_file.flush()
13 |
14 | return compress_pdf_to_bytes(temp_file.name, compression=compression)
15 |
--------------------------------------------------------------------------------
/docprompt/utils/date_extraction.py:
--------------------------------------------------------------------------------
1 | import re
2 | from datetime import date, datetime
3 | from typing import List, Tuple
4 |
5 | DateFormatsType = List[Tuple[re.Pattern, str]]
6 |
7 | default_date_formats = [
8 | # Pre-compile regex patterns for efficiency
9 | # YYYY-MM-DD
10 | (
11 | re.compile(r"\b((19|20)\d\d[-](0?[1-9]|1[012])[-](0?[1-9]|[12][0-9]|3[01]))\b"),
12 | "%Y-%m-%d",
13 | ),
14 | # YY-MM-DD
15 | (
16 | re.compile(r"\b((\d\d)[-](0?[1-9]|1[012])[-](0?[1-9]|[12][0-9]|3[01]))\b"),
17 | "%y-%m-%d",
18 | ),
19 | # MM-DD-YYYY
20 | (
21 | re.compile(r"\b((0?[1-9]|1[012])[-](0?[1-9]|[12][0-9]|3[01])[-](19|20)\d\d)\b"),
22 | "%m-%d-%Y",
23 | ),
24 | # MM-DD-YY
25 | (
26 | re.compile(r"\b((0?[1-9]|1[012])[-](0?[1-9]|[12][0-9]|3[01])[-](\d\d))\b"),
27 | "%m-%d-%y",
28 | ),
29 | # DD-MM-YYYY
30 | (
31 | re.compile(r"\b((0?[1-9]|[12][0-9]|3[01])[-](0?[1-9]|1[012])[-](19|20)\d\d)\b"),
32 | "%d-%m-%Y",
33 | ),
34 | # DD-MM-YY
35 | (
36 | re.compile(r"\b((0?[1-9]|[12][0-9]|3[01])[-](0?[1-9]|1[012])[-](\d\d))\b"),
37 | "%d-%m-%y",
38 | ),
39 | # YYYY/MM/DD
40 | (
41 | re.compile(r"\b((19|20)\d\d[/](0?[1-9]|1[012])[/](0?[1-9]|[12][0-9]|3[01]))\b"),
42 | "%Y/%m/%d",
43 | ),
44 | # YY/MM/DD
45 | (
46 | re.compile(r"\b((\d\d)[/](0?[1-9]|1[012])[/](0?[1-9]|[12][0-9]|3[01]))\b"),
47 | "%y/%m/%d",
48 | ),
49 | # MM/DD/YYYY
50 | (
51 | re.compile(r"\b((0?[1-9]|1[012])[/](0?[1-9]|[12][0-9]|3[01])[/](19|20)\d\d)\b"),
52 | "%m/%d/%Y",
53 | ),
54 | # MM/DD/YY
55 | (
56 | re.compile(r"\b((0?[1-9]|1[012])[/](0?[1-9]|[12][0-9]|3[01])[/](\d\d))\b"),
57 | "%m/%d/%y",
58 | ),
59 | # DD/MM/YYYY
60 | (
61 | re.compile(r"\b((0?[1-9]|[12][0-9]|3[01])[/](0?[1-9]|1[012])[/](19|20)\d\d)\b"),
62 | "%d/%m/%Y",
63 | ),
64 | # DD/MM/YY
65 | (
66 | re.compile(r"\b((0?[1-9]|[12][0-9]|3[01])[/](0?[1-9]|1[012])[/](\d\d))\b"),
67 | "%d/%m/%y",
68 | ),
69 | # YYYY.MM.DD
70 | (
71 | re.compile(r"\b((19|20)\d\d[.](0?[1-9]|1[012])[.](0?[1-9]|[12][0-9]|3[01]))\b"),
72 | "%Y.%m.%d",
73 | ),
74 | # YY.MM.DD
75 | (
76 | re.compile(r"\b((\d\d)[.](0?[1-9]|1[012])[.](0?[1-9]|[12][0-9]|3[01]))\b"),
77 | "%y.%m.%d",
78 | ),
79 | # MM.DD.YYYY
80 | (
81 | re.compile(r"\b((0?[1-9]|1[012])[.](0?[1-9]|[12][0-9]|3[01])[.](19|20)\d\d)\b"),
82 | "%m.%d.%Y",
83 | ),
84 | # MM.DD.YY
85 | (
86 | re.compile(r"\b((0?[1-9]|1[012])[.](0?[1-9]|[12][0-9]|3[01])[.](\d\d))\b"),
87 | "%m.%d.%y",
88 | ),
89 | # DD.MM.YYYY
90 | (
91 | re.compile(r"\b((0?[1-9]|[12][0-9]|3[01])[.](0?[1-9]|1[012])[.](19|20)\d\d)\b"),
92 | "%d.%m.%Y",
93 | ),
94 | # DD.MM.YY
95 | (
96 | re.compile(r"\b((0?[1-9]|[12][0-9]|3[01])[.](0?[1-9]|1[012])[.](\d\d))\b"),
97 | "%d.%m.%y",
98 | ),
99 | # MMMM DDth, YYYY - November 4th, 2023
100 | (
101 | re.compile(
102 | r"\b((January|February|March|April|May|June|July|August|September|October|November|December)\s{1,6}\d{1,2}(st|nd|rd|th)\s{0,2},\s{1,6}\d{4})\b"
103 | ),
104 | "%B %d, %Y",
105 | ),
106 | # MMMM DD, YYYY - November 4, 2023
107 | (
108 | re.compile(
109 | r"\b((January|February|March|April|May|June|July|August|September|October|November|December)\s{1,6}\d{1,2}\s{0,2},\s{1,6}\d{4})\b"
110 | ),
111 | "%B %d, %Y",
112 | ),
113 | # MMM DDth, YYYY - Nov 4th, 2023
114 | (
115 | re.compile(
116 | r"\b((Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s{1,6}\d{1,2}(st|nd|rd|th)\s{0,2},\s{1,6}\d{4})\b"
117 | ),
118 | "%b %d, %Y",
119 | ),
120 | # MMM DD, YYYY - Nov 4, 2023
121 | (
122 | re.compile(
123 | r"\b((Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s{1,6}\d{1,2}\s{0,2},\s{1,6}\d{4})\b"
124 | ),
125 | "%b %d, %Y",
126 | ),
127 | ]
128 |
129 |
130 | def extract_dates_from_text(
131 | input_string: str, *, date_formats: DateFormatsType = default_date_formats
132 | ) -> List[Tuple[date, str]]:
133 | """
134 | Extract dates from a string using a set of predefined regex patterns.
135 |
136 | Returns a list of tuples, where the first element is the date object and the second is the full date string.
137 | """
138 | extracted_dates = []
139 |
140 | for regex, date_format in date_formats:
141 | matches = regex.findall(input_string)
142 |
143 | for match_obj in matches:
144 | # Extract the full date from the match
145 | full_date = match_obj[0] # First group captures the entire date
146 |
147 | if "%d" in date_format:
148 | parse_date = re.sub(r"(st|nd|rd|th)", "", full_date)
149 | else:
150 | parse_date = full_date
151 |
152 | parse_date = re.sub(r"\s+", " ", parse_date).strip()
153 | parse_date = re.sub(
154 | r"\s{1,},", ",", parse_date
155 |         ).strip()  # Commas shouldn't have spaces before them
156 |
157 | # Convert to datetime object
158 | try:
159 | date_obj = datetime.strptime(parse_date, date_format)
160 | except ValueError as e:
161 | print(f"Error parsing date '{full_date}': {e}")
162 | continue
163 |
164 | extracted_dates.append((date_obj.date(), full_date))
165 |
166 | return extracted_dates
167 |
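168 |
169 | if __name__ == "__main__":
170 |     # A minimal sketch; the sample sentence is illustrative.
171 |     text = "The invoice was issued on November 4th, 2023 and paid 2023-11-20."
172 |     for date_obj, raw in extract_dates_from_text(text):
173 |         print(date_obj.isoformat(), "<-", raw)
174 |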
--------------------------------------------------------------------------------
/docprompt/utils/inference.py:
--------------------------------------------------------------------------------
1 | """A utility file for running inference with various LLM providers."""
2 |
3 | import asyncio
4 | import os
5 | from typing import List, Tuple
6 |
7 | from tenacity import (
8 | retry,
9 | retry_if_exception_type,
10 | stop_after_attempt,
11 | wait_random_exponential,
12 | )
13 | from tqdm.asyncio import tqdm
14 |
15 | from docprompt.tasks.message import OpenAIComplexContent, OpenAIMessage
16 |
17 |
18 | def get_anthropic_retry_decorator():
19 | import anthropic
20 |
21 | return retry(
22 | wait=wait_random_exponential(multiplier=0.5, max=60),
23 | stop=stop_after_attempt(14),
24 | retry=retry_if_exception_type(anthropic.RateLimitError)
25 | | retry_if_exception_type(anthropic.InternalServerError)
26 | | retry_if_exception_type(anthropic.APITimeoutError),
27 | reraise=True,
28 | )
29 |
30 |
31 | def get_openai_retry_decorator():
32 | import openai
33 |
34 | return retry(
35 | wait=wait_random_exponential(multiplier=0.5, max=60),
36 | stop=stop_after_attempt(14),
37 | retry=retry_if_exception_type(openai.RateLimitError)
38 | | retry_if_exception_type(openai.InternalServerError)
39 | | retry_if_exception_type(openai.APITimeoutError),
40 | reraise=True,
41 | )
42 |
43 |
44 | async def run_inference_anthropic(
45 | model_name: str, messages: List[OpenAIMessage], **kwargs
46 | ) -> str:
47 | """Run inference using an Anthropic model asynchronously."""
48 | from anthropic import AsyncAnthropic
49 |
50 | api_key = kwargs.pop("api_key", os.environ.get("ANTHROPIC_API_KEY"))
51 | base_url = kwargs.pop("base_url", os.environ.get("ANTHROPIC_BASE_URL"))
52 | client = AsyncAnthropic(api_key=api_key, base_url=base_url)
53 |
54 | system = None
55 | if messages and messages[0].role == "system":
56 | system = messages[0].content
57 | messages = messages[1:]
58 |
59 | processed_messages = []
60 | for msg in messages:
61 | if isinstance(msg.content, list):
62 | processed_content = []
63 | for content in msg.content:
64 | if isinstance(content, OpenAIComplexContent):
65 | content = content.to_anthropic_message()
66 | processed_content.append(content)
67 | else:
68 | pass
69 | # raise ValueError(f"Invalid content type: {type(content)} Expected OpenAIComplexContent")
70 |
71 | dumped = msg.model_dump()
72 | dumped["content"] = processed_content
73 | processed_messages.append(dumped)
74 |         else:
75 |             processed_messages.append(msg.model_dump())
76 |
77 | client_kwargs = {
78 | "model": model_name,
79 | "max_tokens": 2048,
80 | "messages": processed_messages,
81 | **kwargs,
82 | }
83 |
84 | if system:
85 | client_kwargs["system"] = system
86 |
87 | response = await client.messages.create(**client_kwargs)
88 |
89 | content = response.content[0].text
90 |
91 | return content
92 |
93 |
94 | async def run_batch_inference_anthropic(
95 | model_name: str, messages: List[List[OpenAIMessage]], **kwargs
96 | ) -> List[str]:
97 | """Run batch inference using an Anthropic model asynchronously."""
98 | retry_decorator = get_anthropic_retry_decorator()
99 |
100 | @retry_decorator
101 | async def process_message_set(msg_set, index: int):
102 | return await run_inference_anthropic(model_name, msg_set, **kwargs), index
103 |
104 | tasks = [process_message_set(msg_set, i) for i, msg_set in enumerate(messages)]
105 |
106 |     # TODO: Need cleaner implementation to ensure message ordering is preserved
107 |     indexed_responses: List[Tuple[str, int]] = []
108 |     for f in tqdm(asyncio.as_completed(tasks), desc="Processing messages"):
109 |         response, index = await f
110 |         indexed_responses.append((response, index))
111 |
112 |     # Sort by original index to restore the input ordering
113 |     indexed_responses.sort(key=lambda x: x[1])
114 |     responses: List[str] = [r[0] for r in indexed_responses]
115 |
116 | return responses
117 |
--------------------------------------------------------------------------------
/docprompt/utils/masking/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docprompt/utils/masking/__init__.py
--------------------------------------------------------------------------------
/docprompt/utils/masking/image.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | from PIL import Image
4 |
5 | from docprompt.schema.layout import NormBBox
6 |
7 | ImageMaskModes = Literal["color", "average", "alpha"]
8 |
9 |
10 | def mask_image_from_bounding_boxes(
11 | image: Image.Image,
12 | *bounding_boxes: NormBBox,
13 | mask_color: str = "#000000",
14 | ):
15 | """
16 | Create a copy of the image with the positions of the bounding boxes masked.
17 | """
18 |
19 | width, height = image.size
20 |
21 | mask = Image.new("RGBA", (width, height), (0, 0, 0, 0))
22 |
23 |     for bbox in bounding_boxes:
24 |         mask.paste(
25 |             # NormBBox values are normalized to [0, 1]; scale to pixel dimensions
26 |             Image.new(
27 |                 "RGBA",
28 |                 (int(bbox.width * width), int(bbox.height * height)),
29 |                 mask_color,
30 |             ),
31 |             (int(bbox.x0 * width), int(bbox.top * height)),
32 |         )
33 |
34 |     # alpha_composite requires both inputs to be RGBA
35 |     return Image.alpha_composite(image.convert("RGBA"), mask)
36 |
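37 | # A minimal usage sketch (the path and bounding boxes are illustrative):
38 | #
39 | #     from PIL import Image
40 | #     page = Image.open("page.png")
41 | #     masked = mask_image_from_bounding_boxes(page, *boxes, mask_color="#FFFFFF")
42 | #     masked.save("masked.png")
43 |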
--------------------------------------------------------------------------------
/docprompt/utils/util.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | from concurrent.futures import ThreadPoolExecutor, as_completed
3 | from io import BytesIO
4 | from os import PathLike
5 | from pathlib import Path
6 | from typing import TYPE_CHECKING, List, Optional, Union
7 | from urllib.parse import unquote, urlparse
8 |
9 | import filetype
10 | import fsspec
11 |
12 | from docprompt._pdfium import get_pdfium_document
13 | from docprompt.schema.document import PdfDocument
14 |
15 | if TYPE_CHECKING:
16 | from docprompt.schema.pipeline.node.document import DocumentNode
17 |
18 |
19 | def is_pdf(fd: Union[Path, PathLike, bytes]) -> bool:
20 | """
21 | Determines if a file is a PDF
22 | """
23 | if isinstance(fd, (bytes, str)):
24 | mime = filetype.guess_mime(fd)
25 | else:
26 | with open(fd, "rb") as f:
27 | # We only need the first 1024 bytes to determine if it's a PDF
28 | mime = filetype.guess_mime(f.read(1024))
29 |
30 | return mime == "application/pdf"
31 |
32 |
33 | def get_page_count(fd: Union[Path, PathLike, bytes]) -> int:
34 | """
35 | Determines the number of pages in a PDF
36 | """
37 | if not isinstance(fd, bytes):
38 | with open(fd, "rb") as f:
39 | fd = f.read()
40 |
41 | with get_pdfium_document(fd) as pdf:
42 | return len(pdf)
43 |
44 |
45 | def name_from_path(path: Union[Path, PathLike]) -> str:
46 | if not isinstance(path, Path):
47 | path = Path(path)
48 |
49 | file_name = path.name
50 |
51 | parsed = urlparse(file_name)
52 |
53 | return unquote(parsed.path)
54 |
55 |
56 | def read_pdf_bytes_from_path(path: Union[Path, PathLike], **kwargs) -> bytes:
57 | with fsspec.open(urlpath=str(path), mode="rb", **kwargs) as f:
58 | return f.read()
59 |
60 |
61 | def determine_pdf_name_from_bytes(file_bytes: bytes) -> str:
62 | """
63 |     Attempts to determine the name of a PDF by examining metadata
64 | """
65 | with get_pdfium_document(file_bytes) as pdf:
66 | metadata_dict = pdf.get_metadata_dict(skip_empty=True)
67 |
68 | name = None
69 |
70 | if metadata_dict:
71 | name = (
72 | metadata_dict.get("Title")
73 | or metadata_dict.get("Subject")
74 | or metadata_dict.get("Author")
75 | )
76 |
77 | if name:
78 | return f"{name.strip()}.pdf"
79 |
80 | return f"document-{hash_from_bytes(file_bytes)}.pdf"
81 |
82 |
83 | def load_pdf_document(
84 | fp: Union[Path, PathLike, bytes],
85 | *,
86 | file_name: Optional[str] = None,
87 | password: Optional[str] = None,
88 | ) -> PdfDocument:
89 | """
90 | Loads a document from a file path
91 | """
92 | if isinstance(fp, bytes):
93 | file_bytes = fp
94 | file_name = file_name or determine_pdf_name_from_bytes(file_bytes)
95 | file_path = None
96 | else:
97 | file_name = name_from_path(fp) if file_name is None else file_name
98 | file_path = str(fp)
99 | file_bytes = read_pdf_bytes_from_path(fp)
100 |
101 | if not is_pdf(file_bytes):
102 | raise ValueError("File is not a PDF")
103 |
104 | return PdfDocument(
105 | name=unquote(file_name),
106 | file_path=file_path,
107 | file_bytes=file_bytes,
108 | password=password,
109 | )
110 |
111 |
112 | def load_pdf_documents(
113 | fps: List[Union[Path, PathLike, bytes]],
114 | *,
115 | max_threads: int = 12,
116 | passwords: Optional[List[str]] = None,
117 | ):
118 | """
119 | Loads multiple documents from file paths, using a thread pool
120 | """
121 | futures = []
122 |
123 | thread_count = min(max_threads, len(fps))
124 |
125 | with ThreadPoolExecutor(max_workers=thread_count) as executor:
126 | for fp in fps:
127 | futures.append(executor.submit(load_document, fp))
128 |
129 | results = []
130 |
131 | for future in as_completed(futures):
132 | results.append(future.result())
133 |
134 | return results
135 |
136 |
137 | def load_document_node(
138 | fp: Union[Path, PathLike, bytes],
139 | *,
140 | file_name: Optional[str] = None,
141 | password: Optional[str] = None,
142 | ) -> "DocumentNode":
143 | from docprompt.schema.pipeline.node.document import DocumentNode
144 |
145 | document = load_pdf_document(fp, file_name=file_name, password=password)
146 |
147 | return DocumentNode.from_document(document)
148 |
149 |
150 | load_document = load_pdf_document
151 | load_documents = load_pdf_documents
152 |
153 |
154 | def hash_from_bytes(
155 | byte_data: bytes, hash_func=hashlib.md5, threshold=1024 * 1024 * 128
156 | ) -> str:
157 | """
158 |     Gets a hash from bytes. If the bytes are larger than the threshold, the hash is computed in chunks
159 |     to avoid memory issues. The default hash function is MD5 with a threshold of 128MB, a sensible
160 |     default for most machines and use cases.
161 |     """
162 |     if len(byte_data) < 1024 * 1024 * 10:  # 10MB -- small enough to hash in one pass
163 |         return hash_func(byte_data).hexdigest()
164 |
165 | hash = hash_func()
166 |
167 | if len(byte_data) > threshold:
168 | stream = BytesIO(byte_data)
169 | b = bytearray(128 * 1024)
170 | mv = memoryview(b)
171 |
172 | while n := stream.readinto(mv):
173 | hash.update(mv[:n])
174 | else:
175 | hash.update(byte_data)
176 |
177 | return hash.hexdigest()
178 |
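179 |
180 | if __name__ == "__main__":
181 |     # A minimal sketch of the hashing helper; the payload is illustrative.
182 |     payload = b"hello world"
183 |     print(hash_from_bytes(payload))  # md5 hex digest
184 |     print(hash_from_bytes(payload, hash_func=hashlib.sha256))
185 |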
--------------------------------------------------------------------------------
/docs/CNAME:
--------------------------------------------------------------------------------
1 | docs.docprompt.io
2 |
--------------------------------------------------------------------------------
/docs/assets/static/img/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docs/assets/static/img/logo.png
--------------------------------------------------------------------------------
/docs/assets/static/img/old-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docs/assets/static/img/old-logo.png
--------------------------------------------------------------------------------
/docs/blog/index.md:
--------------------------------------------------------------------------------
1 | # Blog
2 |
--------------------------------------------------------------------------------
/docs/community/contributing.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docs/community/contributing.md
--------------------------------------------------------------------------------
/docs/community/versioning.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/docs/community/versioning.md
--------------------------------------------------------------------------------
/docs/concepts/nodes.md:
--------------------------------------------------------------------------------
1 | # Nodes in Docprompt
2 |
3 | ## Overview
4 |
5 | In Docprompt, nodes are fundamental structures used to represent and manage documents and their pages. They provide a way to store state and metadata associated with documents and individual pages, enabling advanced document analysis and processing capabilities.
6 |
7 | ## Key Concepts
8 |
9 | ### DocumentNode
10 |
11 | A `DocumentNode` represents a single document within the Docprompt system. It serves as a container for document-level metadata and provides access to individual pages through `PageNode` instances.
12 |
13 | ```python
14 | class DocumentNode(BaseModel, Generic[DocumentNodeMetadata, PageNodeMetadata]):
15 | document: Document
16 | page_nodes: List[PageNode[PageNodeMetadata]]
17 | metadata: Optional[DocumentNodeMetadata]
18 | ```
19 |
20 | Key features:
21 | - Stores a reference to the underlying `Document` object
22 | - Maintains a list of `PageNode` instances representing individual pages
23 | - Allows for custom document-level metadata
24 | - Provides access to a `DocumentProvenanceLocator` for efficient text search within the document
25 |
26 | ### PageNode
27 |
28 | A `PageNode` represents a single page within a document. It stores page-specific information and provides access to various analysis results, such as OCR data.
29 |
30 | ```python
31 | class PageNode(BaseModel, Generic[PageNodeMetadata]):
32 | document: "DocumentNode"
33 | page_number: PositiveInt
34 | metadata: Optional[PageNodeMetadata]
35 | extra: Dict[str, Any]
36 | ocr_results: ResultContainer[OcrPageResult]
37 | ```
38 |
39 | Key features:
40 | - References the parent `DocumentNode`
41 | - Stores the page number
42 | - Allows for custom page-level metadata
43 | - Provides a flexible `extra` field for additional data storage
44 | - Stores OCR results in a `ResultContainer`
45 |
46 | ## Usage
47 |
48 | ### Creating a DocumentNode
49 |
50 | You can create a `DocumentNode` from a `Document` instance:
51 |
52 | ```python
53 | from docprompt import load_document, DocumentNode
54 |
55 | document = load_document("path/to/my.pdf")
56 | document_node = DocumentNode.from_document(document)
57 | ```
58 |
59 | ### Working with OCR Results
60 |
61 | After processing a document with an OCR provider, you can access the results through the `DocumentNode` and `PageNode` structures:
62 |
63 | ```python
64 | from docprompt.tasks.ocr.gcp import GoogleOcrProvider
65 |
66 | provider = GoogleOcrProvider.from_service_account_file(
67 | project_id=my_project_id,
68 | processor_id=my_processor_id,
69 | service_account_file=path_to_service_file
70 | )
71 |
72 | provider.process_document_node(document_node)
73 |
74 | # Access OCR results for a specific page
75 | ocr_result = document_node.page_nodes[0].ocr_results
76 | ```
77 |
78 | ### Using DocumentProvenanceLocator
79 |
80 | The `DocumentProvenanceLocator` is a powerful tool for searching text within a document:
81 |
82 | ```python
83 | # Search for text across the entire document
84 | results = document_node.locator.search("John Doe")
85 |
86 | # Search for text on a specific page
87 | page_results = document_node.locator.search("Jane Doe", page_number=4)
88 | ```
89 |
90 | ## Benefits of Using Nodes
91 |
92 | 1. **Separation of Concerns**: Nodes allow you to separate the core PDF functionality (handled by the `Document` class) from additional metadata and analysis results.
93 |
94 | 2. **Flexible Metadata**: Both `DocumentNode` and `PageNode` support generic metadata types, allowing you to add custom, type-safe metadata to your documents and pages.
95 |
96 | 3. **Result Caching**: Nodes provide a convenient way to cache and access results from various analysis tasks, such as OCR.
97 |
98 | 4. **Efficient Text Search**: The `DocumentProvenanceLocator` enables fast text search capabilities, leveraging OCR results for improved performance.
99 |
100 | 5. **Extensibility**: The node structure allows for easy integration of new analysis tools and result types in the future.
101 |
102 | By using the node structure in Docprompt, you can build powerful document analysis workflows that combine the core PDF functionality with advanced processing and search capabilities.
103 |
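104 | ## Example: Attaching State to Nodes
105 |
106 | A short sketch of the flexible per-page state described above (the key below is illustrative, not part of Docprompt):
107 |
108 | ```python
109 | from docprompt import load_document, DocumentNode
110 |
111 | document = load_document("path/to/my.pdf")
112 | document_node = DocumentNode.from_document(document)
113 |
114 | # Stash ad-hoc, per-page state in the flexible `extra` dict
115 | page_node = document_node.page_nodes[0]
116 | page_node.extra["needs_review"] = True
117 | ```
118 |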
--------------------------------------------------------------------------------
/docs/concepts/primatives.md:
--------------------------------------------------------------------------------
1 | # Docprompt Primitives
2 |
3 | Docprompt uses several primitive objects that are fundamental to its operation. These primitives are used throughout the library and are essential for understanding how Docprompt processes and represents documents.
4 |
5 | ## PdfDocument
6 |
7 | The `PdfDocument` class is a core primitive in Docprompt, representing a PDF document with various utilities for manipulation and analysis.
8 |
9 | ```python
10 | class PdfDocument(BaseModel):
11 | name: str
12 | file_bytes: bytes
13 | file_path: Optional[str] = None
14 | ```
15 |
16 | ### Key Features
17 |
18 | 1. **Document Properties**
19 | - `name`: The name of the document
20 | - `file_bytes`: The raw bytes of the PDF file
21 | - `file_path`: Optional path to the PDF file on disk
22 | - `page_count`: The number of pages in the document (computed field)
23 | - `document_hash`: A unique hash of the document (computed field)
24 |
25 | 2. **Utility Methods**
26 | - `from_path(file_path)`: Create a PdfDocument from a file path
27 | - `from_bytes(file_bytes, name)`: Create a PdfDocument from bytes
28 | - `get_page_render_size(page_number, dpi)`: Get the render size of a specific page
29 | - `to_compressed_bytes()`: Compress the PDF using Ghostscript
30 | - `rasterize_page(page_number, ...)`: Rasterize a specific page with various options
31 | - `rasterize_pdf(...)`: Rasterize the entire PDF
32 | - `split(start, stop)`: Split the PDF into a new document
33 | - `as_tempfile()`: Create a temporary file from the PDF
34 | - `write_to_path(path)`: Write the PDF to a specific path
35 |
36 | ### Usage Example
37 |
38 | ```python
39 | from docprompt import PdfDocument
40 |
41 | # Load a PDF
42 | pdf = PdfDocument.from_path("path/to/document.pdf")
43 |
44 | # Get document properties
45 | print(f"Document name: {pdf.name}")
46 | print(f"Page count: {pdf.page_count}")
47 |
48 | # Rasterize a page
49 | page_image = pdf.rasterize_page(1, dpi=300)
50 |
51 | # Split the document
52 | new_pdf = pdf.split(start=5, stop=10)
53 | ```
54 |
55 | ## Layout Primitives
56 |
57 | Docprompt uses several layout primitives to represent the structure and content of documents.
58 |
59 | ### NormBBox
60 |
61 | `NormBBox` represents a normalized bounding box with values between 0 and 1.
62 |
63 | ```python
64 | class NormBBox(BaseModel):
65 | x0: BoundedFloat
66 | top: BoundedFloat
67 | x1: BoundedFloat
68 | bottom: BoundedFloat
69 | ```
70 |
71 | Key features:
72 | - Intersection operations (`__and__`)
73 | - Union operations (`__add__`)
74 | - Intersection over Union (IoU) calculation
75 | - Area and centroid properties
76 |
77 | ### TextBlock
78 |
79 | `TextBlock` represents a block of text within a document, including its bounding box and metadata.
80 |
81 | ```python
82 | class TextBlock(BaseModel):
83 | text: str
84 | type: SegmentLevels
85 | source: TextblockSource
86 | bounding_box: NormBBox
87 | bounding_poly: Optional[BoundingPoly]
88 | text_spans: Optional[List[TextSpan]]
89 | metadata: Optional[TextBlockMetadata]
90 | ```
91 |
92 | ### Point and BoundingPoly
93 |
94 | `Point` and `BoundingPoly` are used to represent more complex shapes within a document.
95 |
96 | ```python
97 | class Point(BaseModel):
98 | x: BoundedFloat
99 | y: BoundedFloat
100 |
101 | class BoundingPoly(BaseModel):
102 | normalized_vertices: List[Point]
103 | ```
104 |
105 | ### TextSpan
106 |
107 | `TextSpan` represents a span of text within a document or page.
108 |
109 | ```python
110 | class TextSpan(BaseModel):
111 | start_index: int
112 | end_index: int
113 | level: Literal["page", "document"]
114 | ```
115 |
116 | ### Usage Example
117 |
118 | ```python
119 | from docprompt.schema.layout import NormBBox, TextBlock, TextBlockMetadata
120 |
121 | # Create a bounding box
122 | bbox = NormBBox(x0=0.1, top=0.1, x1=0.9, bottom=0.2)
123 |
124 | # Create a text block
125 | text_block = TextBlock(
126 | text="Example text",
127 | type="block",
128 | source="ocr",
129 | bounding_box=bbox,
130 | metadata=TextBlockMetadata(confidence=0.95)
131 | )
132 |
133 | # Use the text block
134 | print(f"Text: {text_block.text}")
135 | print(f"Bounding box: {text_block.bounding_box}")
136 | print(f"Confidence: {text_block.confidence}")
137 | ```
138 |
139 | These primitives form the foundation of Docprompt's document processing capabilities, allowing for precise representation and manipulation of document content and structure.
140 |
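141 | ## Bounding Box Operations Example
142 |
143 | A short sketch of the `NormBBox` set operations described above (coordinate values are illustrative):
144 |
145 | ```python
146 | from docprompt.schema.layout import NormBBox
147 |
148 | a = NormBBox(x0=0.1, top=0.1, x1=0.5, bottom=0.5)
149 | b = NormBBox(x0=0.3, top=0.3, x1=0.7, bottom=0.7)
150 |
151 | intersection = a & b  # overlapping region
152 | union = a + b         # smallest box containing both
153 | print(intersection, union)
154 | ```
155 |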
--------------------------------------------------------------------------------
/docs/concepts/provenance.md:
--------------------------------------------------------------------------------
1 | # Provenance in Docprompt
2 |
3 | ## Overview
4 |
5 | Provenance in Docprompt refers to the ability to trace and locate specific pieces of text within a document. The `DocumentProvenanceLocator` class is a powerful tool that enables efficient text search, spatial queries, and fine-grained text location within documents that have been processed with OCR.
6 |
7 | ## Key Concepts
8 |
9 | ### DocumentProvenanceLocator
10 |
11 | The `DocumentProvenanceLocator` is a class that provides advanced search capabilities for documents in Docprompt. It combines full-text search with spatial indexing to offer fast and accurate text location services.
12 |
13 | ```python
14 | @dataclass
15 | class DocumentProvenanceLocator:
16 | document_name: str
17 | search_index: "tantivy.Index"
18 | block_mapping: Dict[int, OcrPageResult]
19 | geo_index: DocumentProvenanceGeoMap
20 | ```
21 |
22 | Key features:
23 | - Full-text search using the Tantivy search engine
24 | - Spatial indexing using R-tree for efficient bounding box queries
25 | - Support for different granularity levels (word, line, block)
26 | - Ability to refine search results to word-level precision
27 |
28 | ## Main Functionalities
29 |
30 | ### 1. Text Search
31 |
32 | The `search` method allows you to find specific text within a document:
33 |
34 | ```python
35 | def search(
36 | self,
37 | query: str,
38 | page_number: Optional[int] = None,
39 | *,
40 | refine_to_word: bool = True,
41 | require_exact_match: bool = True
42 | ) -> List[ProvenanceSource]:
43 | # ... implementation ...
44 | ```
45 |
46 | This method returns a list of `ProvenanceSource` objects, which contain detailed information about where the text was found, including page number, bounding box, and the surrounding context.
47 |
48 | ### 2. Spatial Queries
49 |
50 | The `DocumentProvenanceLocator` supports spatial queries to find text blocks based on their location on the page:
51 |
52 | ```python
53 | def get_k_nearest_blocks(
54 | self,
55 | bbox: NormBBox,
56 | page_number: int,
57 | k: int,
58 | granularity: BlockGranularity = "block"
59 | ) -> List[TextBlock]:
60 | # ... implementation ...
61 |
62 | def get_overlapping_blocks(
63 | self,
64 | bbox: NormBBox,
65 | page_number: int,
66 | granularity: BlockGranularity = "block"
67 | ) -> List[TextBlock]:
68 | # ... implementation ...
69 | ```
70 |
71 | These methods allow you to find text blocks that are near or overlapping with a given bounding box on a specific page.
72 |
73 | ## Usage
74 |
75 | ### Recommended Usage: Through DocumentNode
76 |
77 | The recommended way to use the `DocumentProvenanceLocator` is through the `DocumentNode` class. The `DocumentNode` provides two methods for working with the locator:
78 |
79 | 1. `locator` property: Lazily creates and returns the `DocumentProvenanceLocator`.
80 | 2. `refresh_locator()` method: Explicitly refreshes the locator for the document node.
81 |
82 | Here's how to use these methods:
83 |
84 | ```python
85 | from docprompt import load_document, DocumentNode
86 | from docprompt.tasks.ocr.gcp import GoogleOcrProvider
87 |
88 | # Load and process the document
89 | document = load_document("path/to/my.pdf")
90 | document_node = DocumentNode.from_document(document)
91 |
92 | # Process the document with OCR
93 | provider = GoogleOcrProvider.from_service_account_file(...)
94 | provider.process_document_node(document_node)
95 |
96 | # Access the locator (creates it if it doesn't exist)
97 | locator = document_node.locator
98 |
99 | # Perform a search
100 | results = locator.search("Docprompt")
101 |
102 | # If you need to refresh the locator (e.g., after updating OCR results)
103 | document_node.refresh_locator()
104 | ```
105 |
106 | Note: Attempting to access the locator before OCR results are available will raise a `ValueError`.
107 |
108 | ### Alternative: Standalone Usage
109 |
110 | While the recommended approach is to use the locator through `DocumentNode`, you can also create and use a `DocumentProvenanceLocator` independently:
111 |
112 | ```python
113 | from docprompt.provenance.search import DocumentProvenanceLocator
114 |
115 | # Assuming you have a processed DocumentNode
116 | locator = DocumentProvenanceLocator.from_document_node(document_node)
117 |
118 | # Now you can use the locator directly
119 | results = locator.search("Docprompt")
120 | ```
121 |
122 | ### Searching for Text
123 |
124 | To search for text within the document:
125 |
126 | ```python
127 | results = locator.search("Docprompt")
128 | for result in results:
129 | print(f"Found on page {result.page_number}, bbox: {result.text_location.merged_source_block.bounding_box}")
130 | ```
131 |
132 | ### Performing Spatial Queries
133 |
134 | To find text blocks near a specific location:
135 |
136 | ```python
137 | bbox = NormBBox(x0=0.1, top=0.1, x1=0.2, bottom=0.2)
138 | nearby_blocks = locator.get_k_nearest_blocks(bbox, page_number=1, k=5)
139 | ```
140 |
141 | ## Benefits of Using Provenance
142 |
143 | 1. **Accurate Text Location**: Quickly find the exact location of text within a document, including page number and bounding box.
144 | 2. **Efficient Searching**: Combine full-text search with spatial indexing for fast and accurate results.
145 | 3. **Flexible Granularity**: Search and retrieve results at different levels of granularity (word, line, block).
146 | 4. **Integration with OCR**: Seamlessly works with OCR results to provide comprehensive document analysis capabilities.
147 | 5. **Support for Complex Queries**: Perform spatial queries to find text based on location within pages.
148 | 6. **Easy Access**: Conveniently access the locator through the `DocumentNode` class, ensuring it's always available when needed.
149 |
150 | By leveraging the provenance functionality in Docprompt, you can build sophisticated document analysis workflows that require precise text location and contextual information retrieval.
151 |
--------------------------------------------------------------------------------
/docs/concepts/providers.md:
--------------------------------------------------------------------------------
1 | # Providers in Docprompt
2 |
3 | ## Overview
4 |
5 | Providers in Docprompt are abstract interfaces that define how to add data to document nodes. They encapsulate various tasks such as OCR, classification, and more. The provider system is designed to be extensible, allowing users to create custom providers to add new functionality to Docprompt.
6 |
7 | ## Key Concepts
8 |
9 | ### AbstractTaskProvider
10 |
11 | The `AbstractTaskProvider` is the base class for all providers in Docprompt. It defines the interface that all task providers must implement.
12 |
13 | ```python
14 | class AbstractTaskProvider(Generic[PageTaskResult]):
15 | name: str
16 | capabilities: List[str]
17 |
18 | def process_document_pages(
19 | self,
20 | document: Document,
21 | start: Optional[int] = None,
22 | stop: Optional[int] = None,
23 | **kwargs,
24 | ) -> Dict[int, PageTaskResult]:
25 | raise NotImplementedError
26 |
27 | def contribute_to_document_node(
28 | self,
29 | document_node: "DocumentNode",
30 | results: Dict[int, PageTaskResult],
31 | ) -> None:
32 | pass
33 |
34 | def process_document_node(
35 | self,
36 | document_node: "DocumentNode",
37 | start: Optional[int] = None,
38 | stop: Optional[int] = None,
39 | contribute_to_document: bool = True,
40 | **kwargs,
41 | ) -> Dict[int, PageTaskResult]:
42 | # ... implementation ...
43 | ```
44 |
45 | Key features:
46 | - Generic type `PageTaskResult` allows for type-safe results
47 | - `capabilities` list defines what the provider can do
48 | - `process_document_pages` method processes pages of a document
49 | - `contribute_to_document_node` method adds results to a `DocumentNode`
50 | - `process_document_node` method combines processing and contributing results
51 |
52 | ### CAPABILITIES
53 |
54 | The `CAPABILITIES` enum defines the various capabilities that a provider can have:
55 |
56 | ```python
57 | class CAPABILITIES(Enum):
58 | PAGE_RASTERIZATION = "page-rasterization"
59 | PAGE_LAYOUT_OCR = "page-layout-ocr"
60 | PAGE_TEXT_OCR = "page-text-ocr"
61 | PAGE_CLASSIFICATION = "page-classification"
62 | PAGE_SEGMENTATION = "page-segmentation"
63 | PAGE_VQA = "page-vqa"
64 | PAGE_TABLE_IDENTIFICATION = "page-table-identification"
65 | PAGE_TABLE_EXTRACTION = "page-table-extraction"
66 | ```
67 |
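Because `capabilities` is a plain list of capability strings, you can check a provider's features at runtime. A minimal sketch (the helper function is illustrative, not part of Docprompt):

```python
from docprompt.tasks.base import CAPABILITIES


def supports_layout_ocr(provider) -> bool:
    # Capability strings enable dynamic feature discovery
    return CAPABILITIES.PAGE_LAYOUT_OCR.value in provider.capabilities
```
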
68 | ### ResultContainer
69 |
70 | The `ResultContainer` is a generic class that holds the results of a task:
71 |
72 | ```python
73 | class ResultContainer(BaseModel, Generic[PageOrDocumentTaskResult]):
74 | results: Dict[str, PageOrDocumentTaskResult] = Field(
75 | description="The results of the task, keyed by provider", default_factory=dict
76 | )
77 |
78 | @property
79 | def result(self):
80 | return next(iter(self.results.values()), None)
81 | ```
82 |
83 | ## Creating Custom Providers
84 |
85 | To extend Docprompt's functionality, you can create custom providers. Here's a shortened example of the built-in GCP OCR provider (types referenced without imports, such as `BasePageResult`, `Document`, and `DocumentNode`, come from Docprompt's task and schema modules and are elided for brevity):
86 |
87 | ```python
88 | from docprompt.tasks.base import AbstractTaskProvider, CAPABILITIES
89 | from docprompt.schema.layout import TextBlock
90 | from typing import Dict, List, Optional
91 | from pydantic import Field
92 | class OcrPageResult(BasePageResult):
93 | page_text: str = Field(description="The text for the entire page in reading order")
94 | word_level_blocks: List[TextBlock] = Field(default_factory=list)
95 | line_level_blocks: List[TextBlock] = Field(default_factory=list)
96 | block_level_blocks: List[TextBlock] = Field(default_factory=list)
97 | raster_image: Optional[bytes] = Field(default=None)
98 |
99 | class GoogleOcrProvider(AbstractTaskProvider[OcrPageResult]):
100 | name = "Google Document AI"
101 | capabilities = [
102 | CAPABILITIES.PAGE_TEXT_OCR.value,
103 | CAPABILITIES.PAGE_LAYOUT_OCR.value,
104 | CAPABILITIES.PAGE_RASTERIZATION.value,
105 | ]
106 |
107 | def process_document_pages(
108 | self,
109 | document: Document,
110 | start: Optional[int] = None,
111 | stop: Optional[int] = None,
112 | **kwargs,
113 | ) -> Dict[int, OcrPageResult]:
114 | # Implement OCR logic here
115 | pass
116 |
117 | def contribute_to_document_node(
118 | self,
119 | document_node: "DocumentNode",
120 | results: Dict[int, OcrPageResult],
121 | ) -> None:
122 | # Add OCR results to document node
123 | pass
124 | ```
125 |
126 | ## Usage
127 |
128 | Here's how you can use a provider in your Docprompt workflow:
129 |
130 | ```python
131 | from docprompt import load_document, DocumentNode
132 | from docprompt.tasks.ocr.gcp import GoogleOcrProvider
133 |
134 | # Load a document
135 | document = load_document("path/to/my.pdf")
136 | document_node = DocumentNode.from_document(document)
137 |
138 | # Create and use the OCR provider
139 | ocr_provider = GoogleOcrProvider(...)
140 | ocr_results = ocr_provider.process_document_node(document_node)
141 |
142 | # Access OCR results
143 | for page_number, result in ocr_results.items():
144 | print(f"Page {page_number} text: {result.page_text[:100]}...")
145 | ```
146 |
147 | ## Benefits of Using Providers
148 |
149 | 1. **Extensibility**: Easily add new functionality to Docprompt by creating custom providers.
150 | 2. **Modularity**: Each provider encapsulates a specific task, making the codebase more organized and maintainable.
151 | 3. **Type Safety**: Generic types ensure that providers produce and consume the correct types of results.
152 | 4. **Standardized Interface**: All providers follow the same interface, making it easy to switch between different implementations.
153 | 5. **Capability-based Design**: Providers declare their capabilities, allowing for dynamic feature discovery and usage.
154 |
155 | By leveraging the provider system in Docprompt, you can create flexible and powerful document processing pipelines that can be easily extended and customized to meet your specific needs.
156 |
--------------------------------------------------------------------------------
/docs/enterprise.md:
--------------------------------------------------------------------------------
1 | # Enterprise
2 | For companies looking to unlock data, build custom language models, or get general professional support.
3 |
4 | [Talk to founders](https://calendly.com/pageleaf/meet-with-the-founders)
5 |
6 | This covers:
7 |
8 | * ✅ **Assistance with PDF-optimized prompt engineering for Document AI tasks**
9 | * ✅ **Feature Prioritization**
10 | * ✅ **Custom Integrations**
11 | * ✅ **Professional Support - Dedicated Discord + Slack**
12 |
13 |
14 | ### What topics does Professional support cover?
15 |
16 | The expertise we've developed during our time building medical processing pipelines has equipped us with the tools and know-how needed to perform highly accurate information extraction in document-heavy domains. We offer consulting services that leverage this expertise to assist with prompt engineering, deployments, and general ML-ops in the Document AI space.
17 |
--------------------------------------------------------------------------------
/docs/gen_ref_pages.py:
--------------------------------------------------------------------------------
1 | """Generate the code reference pages and navigation."""
2 |
3 | import logging
4 | from pathlib import Path
5 |
6 | import mkdocs_gen_files
7 |
8 | # Set up logging
9 | logging.basicConfig(level=logging.DEBUG)
10 | logger = logging.getLogger(__name__)
11 |
12 | nav = mkdocs_gen_files.Nav()
13 |
14 | root_path = Path("docprompt")
15 | logger.debug(f"Searching for Python files in: {root_path.absolute()}")
16 |
17 | python_files = list(root_path.rglob("*.py"))
18 | logger.debug(f"Found {len(python_files)} Python files")
19 |
20 | if not python_files:
21 | logger.warning(
22 | "No Python files found. Ensure the 'docprompt' directory exists and contains .py files."
23 | )
24 | nav["No modules found"] = "no_modules.md"
25 | with mkdocs_gen_files.open("reference/no_modules.md", "w") as fd:
26 | fd.write(
27 | "# No Modules Found\n\nNo Python modules were found in the 'docprompt' directory."
28 | )
29 | else:
30 | for path in sorted(python_files):
31 | module_path = path.relative_to(root_path).with_suffix("")
32 | doc_path = path.relative_to(root_path).with_suffix(".md")
33 | full_doc_path = Path("reference", doc_path)
34 |
35 | parts = tuple(module_path.parts)
36 |
37 | if parts[-1] == "__init__":
38 | parts = parts[:-1]
39 | doc_path = doc_path.with_name("index.md")
40 | full_doc_path = full_doc_path.with_name("index.md")
41 | elif parts[-1] == "__main__":
42 | continue
43 |
44 | # Handle empty parts
45 | if not parts:
46 | logger.warning(f"Empty parts for file: {path}. Skipping this file.")
47 | continue
48 |
49 | nav[parts] = doc_path.as_posix()
50 |
51 | with mkdocs_gen_files.open(full_doc_path, "w") as fd:
52 | ident = ".".join(parts)
53 | fd.write(f"::: docprompt.{ident}")
54 |
55 | mkdocs_gen_files.set_edit_path(full_doc_path, path)
56 |
57 | logger.debug("Navigation structure:")
58 | logger.debug(nav.build_literate_nav())
59 |
60 | with mkdocs_gen_files.open("reference/SUMMARY.md", "w") as nav_file:
61 | nav_file.writelines(nav.build_literate_nav())
62 |
--------------------------------------------------------------------------------
/docs/guide/classify/binary.md:
--------------------------------------------------------------------------------
1 | # Binary Classification with DocPrompt
2 |
3 | Binary classification is a fundamental task in document analysis, where you categorize pages into one of two classes. This guide will walk you through performing binary classification using DocPrompt with the Anthropic provider.
4 |
5 | ## Setting Up
6 |
7 | First, let's import the necessary modules and set up our environment:
8 |
9 | ```python
10 | from docprompt import load_document, DocumentNode
11 | from docprompt.tasks.factory import AnthropicTaskProviderFactory
12 | from docprompt.tasks.classification import ClassificationInput, ClassificationTypes
13 |
14 | # Initialize the Anthropic factory
15 | # Make sure you have set the ANTHROPIC_API_KEY environment variable
16 | factory = AnthropicTaskProviderFactory()
17 |
18 | # Create the classification provider
19 | classification_provider = factory.get_page_classification_provider()
20 | ```
21 |
22 | ## Preparing the Document
23 |
24 | Load your document and create a DocumentNode:
25 |
26 | ```python
27 | document = load_document("path/to/your/document.pdf")
28 | document_node = DocumentNode.from_document(document)
29 | ```
30 |
31 | ## Configuring the Classification Task
32 |
33 | For binary classification, we need to create a `ClassificationInput` object. This object acts as a prompt for the model, guiding its classification decision. Here's how to set it up:
34 |
35 | ```python
36 | classification_input = ClassificationInput(
37 | type=ClassificationTypes.BINARY,
38 | instructions="Determine if the page contains information about financial transactions.",
39 | confidence=True # Optional: request confidence scores
40 | )
41 | ```
42 |
43 | Let's break down the `ClassificationInput`:
44 |
45 | - `type`: Set to `ClassificationTypes.BINARY` for binary classification.
46 | - `instructions`: This is crucial for binary classification. Provide clear, specific instructions for what the model should look for.
47 | - `confidence`: Set to `True` if you want confidence scores with the results.
48 |
49 | Note: For binary classification, you don't need to specify labels. The model will automatically use "YES" and "NO" as labels.
50 |
51 | ## Performing Classification
52 |
53 | Now, let's run the classification task:
54 |
55 | ```python
56 | results = classification_provider.process_document_node(
57 | document_node,
58 | classification_input
59 | )
60 | ```
61 |
62 | ## Interpreting Results
63 |
64 | The `results` dictionary contains the classification output for each page. Let's examine the results:
65 |
66 | ```python
67 | for page_number, result in results.items():
68 | label = result.labels
69 | confidence = result.score
70 | print(f"Page {page_number}:")
71 | print(f"\tClassification: {label} ({confidence})")
72 | print('---')
73 | ```
74 |
75 | This will print the classification (YES or NO) for each page, along with the confidence level if requested.
76 |
77 | ## Tips for Effective Binary Classification
78 |
79 | 1. **Clear Instructions**: The `instructions` field in `ClassificationInput` is crucial. Be specific about what constitutes a "YES" classification.
80 |
81 | 2. **Consider Page Context**: Remember that the model analyzes each page independently. If your classification requires context from multiple pages, you may need to adjust your approach.
82 |
83 | 3. **Confidence Scores**: Use the `confidence` option to get an idea of how certain the model is about its classifications. This can be helpful for identifying pages that might need human review (see the sketch after these tips).
84 |
85 | 4. **Iterative Refinement**: If you're not getting the desired results, try refining your instructions. You might need to be more specific or provide examples of what constitutes a positive classification.
86 |
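Building on tip 3, here is a minimal sketch for flagging low-confidence pages. It assumes the `score` field is numeric; if your provider returns categorical confidence values, adjust the comparison accordingly:

```python
REVIEW_THRESHOLD = 0.8  # illustrative cutoff; tune for your use case

needs_review = [
    page_number
    for page_number, result in results.items()
    if result.score is not None and result.score < REVIEW_THRESHOLD
]

print(f"Pages to review manually: {needs_review}")
```
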
87 | ## Conclusion
88 |
89 | Binary classification with DocPrompt allows you to quickly categorize pages in your documents. By leveraging the Anthropic provider and carefully crafting your `ClassificationInput`, you can achieve accurate and efficient document analysis.
90 |
91 | Remember that the quality of your results heavily depends on the clarity and specificity of your instructions. Experiment with different phrasings to find what works best for your specific use case.
92 |
--------------------------------------------------------------------------------
/docs/guide/classify/multi.md:
--------------------------------------------------------------------------------
1 | # Multi-Label Classification with DocPrompt: Academic Research Papers
2 |
3 | In this guide, we'll use DocPrompt to classify academic research papers in a PDF document into multiple relevant fields. We'll use multi-label classification, allowing each page (paper) to be assigned to one or more categories based on its content.
4 |
5 | ## Setting Up
6 |
7 | First, let's import the necessary modules and set up our environment:
8 |
9 | ```python
10 | from docprompt import load_document, DocumentNode
11 | from docprompt.tasks.factory import AnthropicTaskProviderFactory
12 | from docprompt.tasks.classification import ClassificationInput, ClassificationTypes
13 |
14 | # Initialize the Anthropic factory
15 | # Ensure you have set the ANTHROPIC_API_KEY environment variable
16 | factory = AnthropicTaskProviderFactory()
17 |
18 | # Create the classification provider
19 | classification_provider = factory.get_page_classification_provider()
20 | ```
21 |
22 | ## Preparing the Document
23 |
24 | Load your collection of research papers and create a DocumentNode:
25 |
26 | ```python
27 | document = load_document("path/to/your/research_papers.pdf")
28 | document_node = DocumentNode.from_document(document)
29 | ```
30 |
31 | ## Configuring the Classification Task
32 |
33 | For multi-label classification, we'll create a `ClassificationInput` object that specifies our labels and provides instructions for the model:
34 |
35 | ```python
36 | classification_input = ClassificationInput(
37 | type=ClassificationTypes.MULTI_LABEL,
38 | instructions=(
39 | "Classify the research paper on this page into one or more relevant fields "
40 | "based on its title, abstract, methodology, and key findings. A paper may "
41 | "belong to multiple categories if it spans multiple disciplines."
42 | ),
43 | labels=[
44 | "Machine Learning",
45 | "Natural Language Processing",
46 | "Computer Vision",
47 | "Robotics",
48 | "Data Mining",
49 | "Cybersecurity",
50 | "Bioinformatics",
51 | "Quantum Computing"
52 | ],
53 | descriptions=[
54 | "Algorithms and statistical models that computer systems use to perform tasks without explicit instructions",
55 | "Processing and analyzing natural language data",
56 | "Enabling computers to derive meaningful information from digital images, videos and other visual inputs",
57 | "Design, construction, operation, and use of robots",
58 | "Process of discovering patterns in large data sets",
59 | "Protection of computer systems from theft or damage to their hardware, software, or electronic data",
60 | "Application of computational techniques to analyze biological data",
61 | "Computation based on the principles of quantum theory"
62 | ],
63 | confidence=True # Request confidence scores
64 | )
65 | ```
66 |
67 | Let's break down the `ClassificationInput`:
68 |
69 | - `type`: Set to `ClassificationTypes.MULTI_LABEL` for multi-label classification.
70 | - `labels`: List of possible categories for our research papers.
71 | - `instructions`: Clear directions for the model on how to classify the papers, emphasizing that multiple labels can be assigned.
72 | - `descriptions`: Provide additional context for each label to improve classification accuracy.
73 | - `confidence`: Set to `True` to get confidence scores with the results.
74 |
75 | ## Performing Classification
76 |
77 | Now, let's run the classification task on our collection of research papers:
78 |
79 | ```python
80 | results = classification_provider.process_document_node(
81 | document_node,
82 | classification_input
83 | )
84 | ```
85 |
86 | ## Interpreting Results
87 |
88 | Let's examine the classification results for each research paper:
89 |
90 | ```python
91 | for page_number, result in results.items():
92 | categories = result.labels
93 | confidence = result.score
94 | print(f"Research Paper on Page {page_number}:")
95 | print(f"\tCategories: {', '.join(categories)} ({confidence})")
96 | print('---')
97 | ```
98 |
99 | This will print the assigned categories for each research paper, along with the confidence level.
100 |
101 | ## Tips for Effective Multi-Label Classification
102 |
103 | 1. **Comprehensive Label Set**: Ensure your label set covers the main topics in your domain but isn't so large that it becomes unwieldy.
104 |
105 | 2. **Clear Instructions**: Emphasize in your instructions that multiple labels can and should be assigned when appropriate.
106 |
107 | 3. **Use Descriptions**: The `descriptions` field helps the model understand the nuances of each category, which is especially important for interdisciplinary papers.
108 |
109 | 4. **Consider Confidence Scores**: In multi-label classification, confidence scores can indicate how strongly a paper fits into each assigned category.
110 |
111 | 5. **Analyze Label Co-occurrences**: Look for patterns in which labels frequently appear together to gain insights into interdisciplinary trends (see the sketch after these tips).
112 |
113 | 6. **Handle Outliers**: If a paper doesn't fit well into any category, consider adding a catch-all category like "Other" or "Interdisciplinary" in future iterations.
114 |
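As a follow-up to tip 5, here is a minimal sketch for counting label co-occurrences, assuming `result.labels` is a list of strings as shown in the results loop above:

```python
from collections import Counter
from itertools import combinations

co_occurrences = Counter()
for result in results.values():
    # Count each unordered pair of labels assigned to the same paper
    for pair in combinations(sorted(result.labels), 2):
        co_occurrences[pair] += 1

print(co_occurrences.most_common(5))
```
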
115 | ## Advanced Usage: Increasing the Power
116 |
117 | For more control over the classification process, you can specify a beefier Anthropic model to boost reasoning power. This can be done when setting up the task provider OR at inference time, allowing easy, fine-grained control over provider defaults and runtime overrides.
118 |
119 | ```python
120 | classification_provider = factory.get_page_classification_provider(
121 | model_name="claude-3-5-sonnet-20240620" # setup the task provider with sonnet-3.5
122 | )
123 |
124 | results = classification_provider.process_document_node(
125 | document_node,
126 | classification_input,
127 | model_name="claude-3-5-sonnet-20240620" # or you can declare model name at inference time as well
128 | )
129 | ```
130 |
131 | ## Conclusion
132 |
133 | Multi-label classification with DocPrompt provides a powerful way to categorize complex documents like research papers that often span multiple disciplines. By carefully crafting your `ClassificationInput` with clear labels, instructions, and descriptions, you can achieve nuanced and informative document analysis.
134 |
135 | Remember that the quality of your results depends on the clarity of your instructions, the comprehensiveness of your label set, and the appropriateness of your descriptions. Experiment with different configurations to find what works best for your specific use case.
136 |
137 | This approach can be adapted to other multi-label classification tasks, such as categorizing news articles by multiple topics, classifying products by multiple features, or tagging images with multiple attributes.
138 |
--------------------------------------------------------------------------------
/docs/guide/classify/single.md:
--------------------------------------------------------------------------------
1 | # Single-Label Classification with DocPrompt: Recipe Categories
2 |
3 | In this guide, we'll use DocPrompt to classify recipes in a PDF document into distinct meal categories. We'll use single-label classification, meaning each page (recipe) will be assigned to one category: Breakfast, Lunch, Dinner, or Dessert.
4 |
5 | ## Setting Up
6 |
7 | First, let's import the necessary modules and set up our environment:
8 |
9 | ```python
10 | from docprompt import load_document, DocumentNode
11 | from docprompt.tasks.factory import AnthropicTaskProviderFactory
12 | from docprompt.tasks.classification import ClassificationInput, ClassificationTypes
13 |
14 | # Initialize the Anthropic factory
15 | # Ensure you have set the ANTHROPIC_API_KEY environment variable
16 | factory = AnthropicTaskProviderFactory()
17 |
18 | # Create the classification provider
19 | classification_provider = factory.get_page_classification_provider()
20 | ```
21 |
22 | ## Preparing the Document
23 |
24 | Load your recipe book PDF and create a DocumentNode:
25 |
26 | ```python
27 | document = load_document("path/to/your/recipe_book.pdf")
28 | document_node = DocumentNode.from_document(document)
29 | ```
30 |
31 | ## Configuring the Classification Task
32 |
33 | For single-label classification, we'll create a `ClassificationInput` object that specifies our labels and provides instructions for the model:
34 |
35 | ```python
36 | classification_input = ClassificationInput(
37 | type=ClassificationTypes.SINGLE_LABEL,
38 | labels=["Breakfast", "Lunch", "Dinner", "Dessert"],
39 | instructions="Classify the recipe on this page into one of the given meal categories based on its ingredients, cooking methods, and typical serving time.",
40 | descriptions=[
41 | "Morning meals, often including eggs, cereals, or pastries",
42 | "Midday meals, typically lighter fare like sandwiches or salads",
43 | "Evening meals, often the most substantial meal of the day",
44 | "Sweet treats typically served after a meal or as a snack"
45 | ],
46 | confidence=True # Request confidence scores
47 | )
48 | ```
49 |
50 | Let's break down the `ClassificationInput`:
51 |
52 | - `type`: Set to `ClassificationTypes.SINGLE_LABEL` for single-label classification.
53 | - `labels`: List of possible categories for our recipes.
54 | - `instructions`: Clear directions for the model on how to classify the recipes.
55 | - `descriptions`: (Optional) Provide additional context for each label to improve classification accuracy.
56 | - _Note that the `descriptions` array must be the same length as the `labels` array._
57 | - `confidence`: Set to `True` to get confidence scores with the results.
58 |
59 | ## Performing Classification
60 |
61 | Now, let's run the classification task on our recipe book:
62 |
63 | ```python
64 | results = classification_provider.process_document_node(
65 | document_node,
66 | classification_input
67 | )
68 | ```
69 |
70 | ## Interpreting Results
71 |
72 | Let's examine the classification results for each recipe:
73 |
74 | ```python
75 | for page_number, result in results.items():
76 | category = result.labels
77 | confidence = result.score
78 |
79 | print(f"Recipe on Page {page_number}:")
80 | print(f"\tCategory: {category} ({confidence})")
81 | print('---')
82 | ```
83 |
84 | This will print the assigned category (Breakfast, Lunch, Dinner, or Dessert) for each recipe, along with the confidence level.
85 |
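If you want to build an index of recipes by category, a small follow-up sketch (assuming `labels` holds a single category string for single-label results):

```python
from collections import defaultdict

recipes_by_category = defaultdict(list)
for page_number, result in results.items():
    recipes_by_category[result.labels].append(page_number)

for category, pages in recipes_by_category.items():
    print(f"{category}: pages {pages}")
```
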
86 | ## Tips for Effective Single-Label Classification
87 |
88 | 1. **Comprehensive Labels**: Ensure your label set covers all possible categories without overlap.
89 |
90 | 2. **Clear Instructions**: Provide specific criteria for each category in your instructions. For recipes, mention considering ingredients, cooking methods, and typical serving times.
91 |
92 | 3. **Use Descriptions**: The `descriptions` field can help the model understand nuances between categories, especially for edge cases like brunch recipes.
93 |
94 | 4. **Consider Confidence Scores**: Low confidence scores might indicate recipes that don't clearly fit into one category, such as versatile dishes that could be served at multiple meals.
95 |
96 | 5. **Handling Edge Cases**: If you encounter many low-confidence classifications, you might need to refine your categories or instructions. For example, you might add an "Anytime" category for versatile recipes.
97 |
98 | ## Advanced Usage: Customizing the Model
99 |
100 | Depending on the complexity of your task, you may want to experiment with different models. You can do this by passing the `model_name` parameter to the classification provider:
101 |
102 | ```python
103 | haiku_classification_provider = factory.get_page_classification_provider(
104 | model_name="claude-3-haiku-20240307"
105 | )
106 |
107 | sonnet_classification_provider = factory.get_page_classification_provider(
108 | model_name="claude-3-5-sonnet-20240620"
109 | )
110 | ```
111 |
112 | ## Conclusion
113 |
114 | Single-label classification with DocPrompt provides a powerful way to categorize pages in your documents, such as recipes in a cookbook. By carefully crafting your `ClassificationInput` with clear labels, instructions, and descriptions, you can achieve accurate and efficient document analysis.
115 |
116 | Remember that the quality of your results depends on the clarity of your instructions and the appropriateness of your label set. Experiment with different phrasings and label combinations to find what works best for your specific use case.
117 |
118 | This approach can be easily adapted to other single-label classification tasks, such as categorizing scientific papers by field, sorting legal documents by type, or classifying news articles by topic.
119 |
--------------------------------------------------------------------------------
/docs/guide/ocr/advanced_search.md:
--------------------------------------------------------------------------------
1 | # Lightning-Fast Document Search 🔥🚀
2 |
3 | Ever wished you could search through OCR-processed documents at the speed of light? Look no further! DocPrompt's Provenance Locator, powered by Rust, offers blazingly fast text search capabilities that will revolutionize your document processing workflows.
4 |
5 | ## The Power of Rust-Powered Search
6 |
7 | DocPrompt's `DocumentProvenanceLocator` is not your average search tool. Built on the Rust-powered `tantivy` search engine and an `rtree` spatial index, it provides:
8 |
9 | - ⚡ Lightning-fast full-text search
10 | - 🎯 Precise text location within documents
11 | - 🧠 Smart granularity refinement (word, line, block)
12 | - 🗺️ Spatial querying capabilities
13 |
14 | Let's dive into how you can harness this power!
15 |
16 | ## Setting Up the Locator
17 |
18 | First, let's create a `DocumentProvenanceLocator` from a processed `DocumentNode`:
19 |
20 | ```python
21 | from docprompt import load_document, DocumentNode
22 | from docprompt.tasks.factory import GCPTaskProviderFactory
23 |
24 | # Load and process the document
25 | document = load_document("path/to/your/document.pdf")
26 | document_node = DocumentNode.from_document(document)
27 |
28 | # Process with OCR (assuming you've set up the GCP factory)
29 | gcp_factory = GCPTaskProviderFactory(service_account_file="path/to/credentials.json")
30 | ocr_provider = gcp_factory.get_page_ocr_provider(
31 | project_id="your-project-id",
32 | processor_id="your-processor-id"
33 | )
34 | ocr_results = ocr_provider.process_document_node(document_node)
35 |
36 | # Create the locator
37 | locator = document_node.locator
38 | ```
39 |
40 | ## Searching at the Speed of Rust 🔥
41 |
42 | Now that we have our locator, let's see it in action:
43 |
44 | ```python
45 | # Perform a simple search
46 | results = locator.search("DocPrompt")
47 |
48 | for result in results:
49 | print(f"Found on page {result.page_number}")
50 | print(f"Text: {result.text}")
51 | print(f"Bounding box: {result.text_location.merged_source_block.bounding_box}")
52 | print("---")
53 | ```
54 |
55 | This search operation happens in milliseconds, even for large documents, thanks to the Rust-powered backend!
56 |
57 | ## Advanced Search Capabilities
58 |
59 | ### Refining to Word Level
60 |
61 | DocPrompt can automatically refine search results to the word level:
62 |
63 | ```python
64 | refined_results = locator.search("DocPrompt", refine_to_word=True)
65 | ```
66 |
67 | This gives you pinpoint accuracy in locating text within your document.
68 |
69 | ### Page-Specific Search
70 |
71 | Need to search on a specific page? No problem:
72 |
73 | ```python
74 | page_5_results = locator.search("DocPrompt", page_number=5)
75 | ```
76 |
77 | ### Best Match Search
78 |
79 | Find the best matches based on different criteria:
80 |
81 | ```python
82 | best_short_matches = locator.search_n_best("DocPrompt", n=3, mode="shortest_text")
83 |
84 | best_long_matches = locator.search_n_best("DocPrompt", n=3, mode="longest_text")
85 |
86 | best_overall_matches = locator.search_n_best("DocPrompt", n=3, mode="highest_score")
87 | ```
88 |
89 | ## Spatial Queries: Beyond Text Search 🗺️
90 |
91 | DocPrompt's locator isn't just fast—it's spatially aware! You can perform queries based on document layout:
92 |
93 | ```python
94 | from docprompt.schema.layout import NormBBox
95 |
96 | # Get blocks near a specific area on page 1
97 | bbox = NormBBox(x0=0.1, top=0.1, x1=0.2, bottom=0.2)
98 | nearby_blocks = locator.get_k_nearest_blocks(bbox, page_number=1, k=5)
99 |
100 | # Get overlapping blocks
101 | overlapping_blocks = locator.get_overlapping_blocks(bbox, page_number=1)
102 | ```
103 |
104 | This spatial awareness opens up possibilities for advanced document analysis and data extraction!
105 |
106 | ## Conclusion: Search at the Speed of Thought 🧠💨
107 |
108 | DocPrompt's `DocumentProvenanceLocator` brings unprecedented speed and precision to document search and analysis. By leveraging the power of Rust, it offers:
109 |
110 | 1. Lightning-fast full-text search
111 | 2. Precise text location within documents
112 | 3. Advanced spatial querying capabilities
113 | 4. Scalability for large documents and datasets
114 |
115 | Whether you're building a document analysis pipeline, a search system, or any text-based application, DocPrompt's Provenance Locator offers the speed and accuracy you need to stay ahead of the game.
116 |
--------------------------------------------------------------------------------
/docs/guide/ocr/advanced_workflows.md:
--------------------------------------------------------------------------------
1 | # Advanced Document Analysis
2 | 
3 | ## Detecting Potential Conflicts of Interest in 10-K Reports
4 | 
5 | In this guide, we'll demonstrate how to use DocPrompt's powerful OCR and search capabilities to analyze 10-K reports for potential conflicts of interest. We'll search for mentions of company names and executive names, then identify instances where they appear in close proximity within the document.
6 |
7 | ## Setup
8 |
9 | First, let's set up our environment and process the document:
10 |
11 | ```python
12 | from docprompt import load_document, DocumentNode
13 | from docprompt.tasks.factory import GCPTaskProviderFactory
14 | from docprompt.schema.layout import NormBBox
15 | from itertools import product
16 |
17 | # Load and process the document
18 | document = load_document("path/to/10k_report.pdf")
19 | document_node = DocumentNode.from_document(document)
20 |
21 | # Perform OCR
22 | gcp_factory = GCPTaskProviderFactory(service_account_file="path/to/credentials.json")
23 | ocr_provider = gcp_factory.get_page_ocr_provider(project_id="your-project-id", processor_id="your-processor-id")
24 | ocr_provider.process_document_node(document_node)
25 |
26 | # Create the locator
27 | locator = document_node.locator
28 |
29 | # Define entities to search for
30 | company_names = ["SubsidiaryA", "PartnerB", "CompetitorC"]
31 | executive_names = ["John Doe", "Jane Smith", "Alice Johnson"]
32 | ```
33 |
34 | ## Searching for Entities
35 |
36 | Now, let's use DocPrompt's fast search capabilities to find all mentions of companies and executives. By leveraging the speed of the Rust-powered locator, along with Python's built-in comprehensions, we can execute our set of queries over a several-hundred-page document in a matter of milliseconds.
37 |
38 | ```python
39 | company_results = {
40 | company: locator.search(company)
41 | for company in company_names
42 | }
43 |
44 | executive_results = {
45 | executive: locator.search(executive)
46 | for executive in executive_names
47 | }
48 | ```
49 |
50 | ## Detecting Proximity
51 |
52 | Next, we'll check for instances where company names and executive names appear in close proximity:
53 |
54 | ```python
55 | def check_proximity(bbox1, bbox2, threshold=0.1):
56 |     left_collision = abs(bbox1.x0 - bbox2.x0) < threshold
57 |     top_collision = abs(bbox1.top - bbox2.top) < threshold
58 |     return left_collision and top_collision
59 | 
60 | potential_conflicts = []
61 | 
62 | # `locator.search` returns a list of matches, so compare every pairing
63 | for company, exec_name in product(company_names, executive_names):
64 |     result_pairs = product(company_results[company], executive_results[exec_name])
65 |     for c_result, e_result in result_pairs:
66 |         # Only consider matches that appear on the same page
67 |         if c_result.page_number != e_result.page_number:
68 |             continue
69 | 
70 |         c_bbox = c_result.text_location.merged_source_block.bounding_box
71 |         e_bbox = e_result.text_location.merged_source_block.bounding_box
72 |         # Flag the pair if the bounding boxes fall within our threshold
73 |         if check_proximity(c_bbox, e_bbox):
74 |             potential_conflicts.append({
75 |                 'company': company,
76 |                 'executive': exec_name,
77 |                 'page': c_result.page_number,
78 |                 'company_bbox': c_bbox,
79 |                 'exec_bbox': e_bbox
80 |             })
81 | ```
82 |
83 | ## Analyzing Results
84 |
85 | Finally, let's analyze and display our results:
86 |
87 | ```python
88 | print(f"Found {len(potential_conflicts)} potential conflicts of interest:")
89 |
90 | for conflict in potential_conflicts:
91 | print(f"\nPotential conflict on page {conflict['page']}:")
92 | print(f" Company: {conflict['company']}")
93 | print(f" Executive: {conflict['executive']}")
94 |
95 |     # Get surrounding context, clamping to [0, 1] so the NormBBox stays valid
96 |     context_bbox = NormBBox(
97 |         x0=max(0.0, min(conflict['company_bbox'].x0, conflict['exec_bbox'].x0) - 0.05),
98 |         top=max(0.0, min(conflict['company_bbox'].top, conflict['exec_bbox'].top) - 0.05),
99 |         x1=min(1.0, max(conflict['company_bbox'].x1, conflict['exec_bbox'].x1) + 0.05),
100 |         bottom=min(1.0, max(conflict['company_bbox'].bottom, conflict['exec_bbox'].bottom) + 0.05)
101 |     )
102 |
103 | context_blocks = locator.get_overlapping_blocks(context_bbox, conflict['page'])
104 |
105 | print(" Context:")
106 | for block in context_blocks:
107 | print(f" {block.text}")
108 | ```
109 |
110 | This refined approach demonstrates several key features of DocPrompt:
111 |
112 | 1. **Fast and Accurate Search**: We use the `DocumentProvenanceLocator` to quickly find all mentions of companies and executives across the entire document.
113 |
114 | 2. **Spatial Analysis**: By leveraging the bounding box information, we can determine when two entities are mentioned in close proximity on the page.
115 |
116 | 3. **Contextual Information**: We use spatial queries to extract the surrounding text, providing context for each potential conflict of interest.
117 |
118 | 4. **Scalability**: This approach can easily handle multiple companies and executives, making it suitable for analyzing large, complex documents.
119 |
120 | By combining these capabilities, DocPrompt enables efficient and thorough analysis of 10-K reports, helping to identify potential conflicts of interest that might otherwise be overlooked in manual review processes.
121 |
--------------------------------------------------------------------------------
/docs/guide/ocr/basic_usage.md:
--------------------------------------------------------------------------------
1 | # Basic OCR Usage with Docprompt
2 |
3 | This guide will walk you through the basics of performing Optical Character Recognition (OCR) using Docprompt. You'll learn how to set up the OCR provider, process a document, and access the results.
4 |
5 | ## Prerequisites
6 |
7 | Before you begin, ensure you have:
8 |
9 | 1. Installed Docprompt with OCR support: `pip install "docprompt[google]"`
10 | 2. A Google Cloud Platform account with Document AI API enabled
11 | 3. A GCP service account key file
12 |
13 | ## Setting Up the OCR Provider
14 |
15 | First, let's set up the Google OCR provider:
16 |
17 | ```python
18 | from docprompt.tasks.factory import GCPTaskProviderFactory
19 |
20 | # Initialize the GCP Task Provider Factory
21 | gcp_factory = GCPTaskProviderFactory(
22 | service_account_file="path/to/your/service_account_key.json"
23 | )
24 |
25 | # Create the OCR provider
26 | ocr_provider = gcp_factory.get_page_ocr_provider(
27 | project_id="your-gcp-project-id",
28 | processor_id="your-document-ai-processor-id"
29 | )
30 | ```
31 |
32 | ## Loading and Processing a Document
33 |
34 | Now, let's load a document and process it using OCR:
35 |
36 | ```python
37 | from docprompt import load_document, DocumentNode
38 |
39 | # Load the document
40 | document = load_document("path/to/your/document.pdf")
41 | document_node = DocumentNode.from_document(document)
42 |
43 | # Process the document
44 | ocr_results = ocr_provider.process_document_node(document_node)
45 | ```
46 |
47 | ## Accessing OCR Results
48 |
49 | After processing, you can access the OCR results in various ways:
50 |
51 | ### 1. Page-level Text
52 |
53 | To get the full text of a specific page:
54 |
55 | ```python
56 | page_number = 1 # Pages are 1-indexed
57 | page_text = ocr_results[page_number].page_text
58 | print(f"Text on page {page_number}:\n{page_text[:500]}...") # Print first 500 characters
59 | ```
60 |
61 | ### 2. Words, Lines, and Blocks
62 |
63 | Docprompt provides access to words, lines, and blocks (paragraphs) extracted from the document:
64 |
65 | ```python
66 | # Get the first page's result
67 | first_page_result = ocr_results[1]
68 |
69 | # Print the first 5 words on the page
70 | print("First 5 words:")
71 | for word in first_page_result.word_level_blocks[:5]:
72 | print(f"Word: {word.text}, Confidence: {word.metadata.confidence}")
73 |
74 | # Print the first line on the page
75 | print("\nFirst line:")
76 | if first_page_result.line_level_blocks:
77 | first_line = first_page_result.line_level_blocks[0]
78 | print(f"Line: {first_line.text}")
79 |
80 | # Print the first block (paragraph) on the page
81 | print("\nFirst block:")
82 | if first_page_result.block_level_blocks:
83 | first_block = first_page_result.block_level_blocks[0]
84 | print(f"Block: {first_block.text[:100]}...") # Print first 100 characters
85 | ```
86 |
87 | ### 3. Bounding Boxes
88 |
89 | Each word, line, and block comes with bounding box information:
90 |
91 | ```python
92 | # Get bounding box for the first word
93 | if first_page_result.word_level_blocks:
94 | first_word = first_page_result.word_level_blocks[0]
95 | bbox = first_word.bounding_box
96 | print(f"\nBounding box for '{first_word.text}':")
97 | print(f"Top-left: ({bbox.x0}, {bbox.top})")
98 | print(f"Bottom-right: ({bbox.x1}, {bbox.bottom})")
99 | ```
100 |
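Because these coordinates are normalized between 0 and 1, you can map them onto a rasterized page at any DPI. A minimal sketch, reusing `first_word` from the snippet above (it assumes `rasterize_page` returns encoded image bytes and that Pillow is installed; adjust if your version returns a PIL image directly):

```python
from io import BytesIO

from PIL import Image

# Rasterize page 1 and measure it in pixels
image = Image.open(BytesIO(document.rasterize_page(1, dpi=120)))
width, height = image.size

# Scale the normalized bbox of the first word into pixel coordinates
bbox = first_word.bounding_box
pixel_box = (
    int(bbox.x0 * width),
    int(bbox.top * height),
    int(bbox.x1 * width),
    int(bbox.bottom * height),
)
print(f"'{first_word.text}' occupies pixels {pixel_box}")
```
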
101 | ## Conclusion
102 |
103 | You've now learned the basics of performing OCR with Docprompt. This includes setting up the OCR provider, processing a document, and accessing the results at different levels of granularity.
104 |
105 | For more advanced usage, including customizing OCR settings, using other providers, and leveraging the powerful search capabilities, check out our other guides:
106 |
107 | - Customizing OCR Providers and Settings
108 | - Advanced Text Search with Provenance Locators
109 | - Building OCR-based Workflows
110 |
--------------------------------------------------------------------------------
/docs/guide/ocr/provider_config.md:
--------------------------------------------------------------------------------
1 | # Customizing OCR Providers
2 |
3 | Docprompt uses a factory pattern to manage credentials and create task providers efficiently. This guide will demonstrate how to configure and customize OCR providers, focusing on Amazon Textract and Google Cloud Platform (GCP) as examples.
4 |
5 | ## Understanding the Factory Pattern
6 |
7 | Docprompt uses task provider factories to manage credentials and create providers for various tasks. This approach allows for:
8 |
9 | 1. Centralized credential management
10 | 2. Easy creation of multiple task providers from a single backend
11 | 3. Separation of provider-specific and task-specific configurations
12 |
13 | Here's a simplified example of how the factory pattern works:
14 |
15 | ```python
16 | from docprompt.tasks.factory import GCPTaskProviderFactory, AmazonTaskProviderFactory
17 |
18 | # Create a GCP factory
19 | gcp_factory = GCPTaskProviderFactory(
20 | service_account_file="path/to/service_account.json"
21 | )
22 |
23 | # Create an Amazon factory
24 | amazon_factory = AmazonTaskProviderFactory(
25 | aws_access_key_id="YOUR_ACCESS_KEY",
26 | aws_secret_access_key="YOUR_SECRET_KEY",
27 | region_name="us-west-2"
28 | )
29 | ```
30 |
31 | ## Creating OCR Providers
32 |
33 | Once you have a factory, you can create OCR providers with task-specific configurations:
34 |
35 | ```python
36 | # Create a GCP OCR provider
37 | gcp_ocr_provider = gcp_factory.get_page_ocr_provider(
38 | project_id="YOUR_PROJECT_ID",
39 | processor_id="YOUR_PROCESSOR_ID",
40 | max_workers=4,
41 | return_images=True
42 | )
43 |
44 | # Create an Amazon Textract provider
45 | amazon_ocr_provider = amazon_factory.get_page_ocr_provider(
46 | max_workers=4,
47 | exclude_bounding_poly=True
48 | )
49 | ```
50 |
51 | ## Understanding Provider Configuration
52 |
53 | When configuring OCR providers, you'll encounter two types of parameters:
54 |
55 | 1. **Docprompt generic parameters**: These are common across different providers and control Docprompt's behavior.
56 | - `max_workers`: Controls concurrency for processing large documents
57 | - `exclude_bounding_poly`: Reduces memory usage by excluding detailed polygon data
58 |
59 | 2. **Provider-specific parameters**: These are unique to each backend and control provider-specific features. For example, when using GCP as an OCR provider, you must specify `project_id` and `processor_id`, and you may optionally set `return_image_quality_scores`.
60 |
61 | ## Provider-Specific Features and Limitations
62 |
63 | ### Google Cloud Platform (GCP)
64 | - Offers advanced layout analysis and image quality scoring
65 | - Supports returning rasterized images of processed pages
66 | - Requires GCP-specific project and processor IDs
67 |
68 | Example configuration:
69 | ```python
70 | gcp_ocr_provider = gcp_factory.get_page_ocr_provider(
71 | project_id="YOUR_PROJECT_ID",
72 | processor_id="YOUR_PROCESSOR_ID",
73 | max_workers=4, # Docprompt generic
74 | return_images=True, # GCP-specific
75 | return_image_quality_scores=True # GCP-specific
76 | )
77 | ```
78 |
79 | ### Amazon Textract
80 | - Focuses on text extraction and layout analysis
81 | - Provides confidence scores for extracted text
82 | - Does not support returning rasterized images
83 |
84 | Example configuration:
85 | ```python
86 | amazon_ocr_provider = amazon_factory.get_page_ocr_provider(
87 | max_workers=4, # Docprompt generic
88 | exclude_bounding_poly=True # Docprompt generic
89 | )
90 | ```
91 |
92 | ## Best Practices
93 |
94 | 1. **Use factories for credential management**: This centralizes authentication and makes it easier to switch between providers.
95 |
96 | 2. **Consult provider documentation**: Always refer to the latest documentation from AWS or GCP for the most up-to-date information on their OCR services.
97 |
98 | 3. **Check Docprompt API reference**: Review Docprompt's API documentation for each provider to understand available configurations.
99 |
100 | 4. **Optimize for your use case**: Configure providers based on your specific needs, balancing performance and feature requirements.
101 |
102 | ## Conclusion
103 |
104 | Understanding the factory pattern and the distinction between Docprompt generic and provider-specific parameters is key to effectively configuring OCR providers in Docprompt. While this guide provides an overview using Amazon Textract and GCP as examples, the principles apply to other providers as well. Always consult the specific provider's documentation and Docprompt's API reference for the most current and detailed information.
105 |
--------------------------------------------------------------------------------
/docs/guide/table_extraction/extract_tables.md:
--------------------------------------------------------------------------------
1 | # Table Extraction with DocPrompt: Invoice Parsing
2 |
3 | DocPrompt can be used to extract tables from documents with high accuracy using visual large language models, such as GPT-4 Vision or Anthropic's Claude 3. In this guide, we'll demonstrate how to extract tables from invoices using DocPrompt.
4 |
5 | ## Setting Up
6 |
7 | First, let's import the necessary modules and set up our environment:
8 |
9 | ```python
10 | from docprompt import load_document_node
11 | from docprompt.tasks.factory import AnthropicTaskProviderFactory
13 |
14 | # Initialize the Anthropic factory
15 | # Ensure you have set the ANTHROPIC_API_KEY environment variable
16 | factory = AnthropicTaskProviderFactory()
17 |
18 | # Create the table extraction provider
19 | table_extraction_provider = factory.get_page_table_extraction_provider()
20 | ```
21 |
22 | ## Preparing the Document
23 |
24 | Load a DocumentNode from a path
25 |
26 | ```python
27 | document_node = load_document_node("path/to/your/invoice.pdf")
28 | ```
29 |
30 | ## Performing Table Extraction
31 |
32 | Now, let's run the table extraction task on our invoice:
33 |
34 | ```python
36 | results = table_extraction_provider.process_document_node(document_node)  # Sync
37 | 
38 | # Alternatively, run the same task asynchronously
39 | async_results = await table_extraction_provider.aprocess_document_node(document_node)
40 | ```
41 | Both calls produce the same mapping of page numbers to results.
41 |
42 | ## Interpreting Results
43 |
44 | Let's examine the extracted tables from a pretend invoice:
45 |
46 | ```python
47 | for page_number, result in results.items():
48 | print(f"Tables extracted from Page {page_number}:")
49 | for i, table in enumerate(result.tables, 1):
50 | print(f"\nTable {i}:")
51 | print(f"Title: {table.title}")
52 | print("Headers:")
53 | print(", ".join(header.text for header in table.headers))
54 | print("Rows:")
55 | for row in table.rows:
56 | print(", ".join(cell.text for cell in row.cells))
57 | print('---')
58 | ```
59 |
60 | This will print the extracted tables, including headers and rows, for each page of the invoice.
61 |
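Since each table exposes `headers`, `rows`, and cell `text`, exporting to CSV takes only a few lines. A minimal sketch (it assumes page 1 yielded at least one table):

```python
import csv

first_table = results[1].tables[0]

with open("invoice_table.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow([header.text for header in first_table.headers])
    for row in first_table.rows:
        writer.writerow([cell.text for cell in row.cells])
```
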
62 | ## Increasing Accuracy
63 |
64 | In Anthropic's case, the default is `"claude-3-haiku-20240307"`. This performs with high accuracy, and is over 5x cheaper than table extraction using Azure Document Intelligence.
65 |
66 | In use-cases where accuracy is paramount however, it may be worthwhile to set the provider to a more powerful model.
67 |
68 | ```python
69 | table_extraction_provider = factory.get_page_table_extraction_provider(
70 |     model_name="claude-3-5-sonnet-20240620"  # set up the task provider with Sonnet 3.5
71 | )
72 |
73 | results = table_extraction_provider.process_document_node(
74 | document_node,
75 |     model_name="claude-3-5-sonnet-20240620"  # or declare the model name at inference time
77 | )
78 | ```
79 |
80 | As Large Language Models steadily get cheaper and more capable, your inference costs will inevitably drop. The beauty of progress!
81 |
82 |
83 | ## Resolving Bounding Boxes
84 |
85 | **Coming Soon**
86 |
87 | In some scenarios, you may want the exact bounding boxes of the various rows, columns, and cells. If you've processed OCR results through Docprompt, this is possible by specifying an additional argument in `process_document_node`
88 |
89 | ```python
90 | results = table_extraction_provider.process_document_node(
91 | document_node,
92 |     model_name="claude-3-5-sonnet-20240620",  # or declare the model name at inference time
94 | resolve_bounding_boxes=True
95 | )
96 | ```
97 |
98 | If you've collected and stored OCR results on the DocumentNode, this will use word-level bounding boxes coupled with the Docprompt search engine to determine the bounding boxes of the resulting tables, where possible.
99 |
100 | ## Conclusion
101 |
102 | Table extraction with DocPrompt provides a powerful way to automatically parse structured data from any documents containing tabular information in just a few lines of code.
103 |
104 | The quality of your results depends on the model and the complexity of the table layouts. Experiment with different configurations and post-processing steps to find what works best for your specific use case.
105 |
106 | When combining with other tasks such as classification, layout analysis and markerization, you can build powerful document processing pipelines in just a few steps.
107 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # Docprompt - Getting Started
2 |
3 | ## Supercharged Document Analysis
4 |
5 | * Common utilities for interacting with PDFs
6 | * PDF loading and serialization
7 | * PDF byte compression using Ghostscript :ghost:
8 | * Fast rasterization :fire: :rocket:
9 | * Page splitting, re-export with PDFium
10 | * Support for most OCR providers with batched inference
11 | * Google :white_check_mark:
12 | * Azure Document Intelligence :red_circle:
13 | * Amazon Textract :red_circle:
14 | * Tesseract :red_circle:
15 |
16 |
17 |
18 | ### Installation
19 |
20 | Base installation
21 |
22 | ```bash
23 | pip install docprompt
24 | ```
25 |
26 | With an OCR provider
27 |
28 | ```bash
29 | pip install "docprompt[google]"
30 | ```
31 |
32 | ## Usage
33 |
34 |
35 | ### Simple Operations
36 | ```python
37 | from docprompt import load_document
38 |
39 | # Load a document
40 | document = load_document("path/to/my.pdf")
41 |
42 | # Rasterize a single page using Ghostscript
43 | page_number = 5
44 | rastered = document.rasterize_page(page_number, dpi=120)
45 |
46 | # Split a pdf based on a page range
47 | document_2 = document.split(start=125, stop=130)
48 | ```
49 |
50 | ### Performing OCR
51 | ```python
52 | from docprompt import load_document, DocumentNode
53 | from docprompt.tasks.ocr.gcp import GoogleOcrProvider
54 |
55 | provider = GoogleOcrProvider.from_service_account_file(
56 | project_id=my_project_id,
57 | processor_id=my_processor_id,
58 | service_account_file=path_to_service_file
59 | )
60 |
61 | document = load_document("path/to/my.pdf")
62 |
63 | # A container holds derived data for a document, like OCR or classification results
64 | document_node = DocumentNode.from_document(document)
65 |
66 | provider.process_document_node(document_node) # Caches results on the document_node
67 |
68 | document_node[0].ocr_result # Access OCR results
69 | ```
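
With OCR results cached on the node, each page can also reconstruct layout-aware text from the stored geometry, which often reads more naturally than naive text extraction. A small follow-on sketch:

```python
page_node = document_node.page_nodes[0]

# Text reflowed using the word-level layout recovered from OCR
print(page_node.layout_aware_text)
```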
70 |
71 |
72 | ### Document Search
73 |
74 | When a large language model returns a result, we might want to highlight that result for our users. However, language models return results as **text**, while highlighting a result for a user requires a page number and a bounding box.
75 |
76 | After extracting text from a PDF, we can support this pattern using the `DocumentProvenanceLocator`, which lives on a `DocumentNode`:
77 |
78 | ```python
79 | from docprompt import load_document, DocumentNode
80 | from docprompt.tasks.ocr.gcp import GoogleOcrProvider
81 |
82 | provider = GoogleOcrProvider.from_service_account_file(
83 | project_id=my_project_id,
84 | processor_id=my_processor_id,
85 | service_account_file=path_to_service_file
86 | )
87 |
88 | document = load_document("path/to/my.pdf")
89 |
90 | # A container holds derived data for a document, like OCR or classification results
91 | document_node = DocumentNode.from_document(document)
92 |
93 | provider.process_document_node(document_node) # Caches results on the document_node
94 |
95 | # With OCR results available, we can now instantiate a locator and search through documents.
96 |
97 | document_node.locator.search("John Doe") # Returns a list of all matches across the document containing "John Doe"
98 | document_node.locator.search("Jane Doe", page_number=4) # Returns only the matching results from page 4
99 | ```
100 |
101 | This functionality uses a combination of `rtree` and the Rust library `tantivy`, allowing you to perform thousands of searches in **seconds** :fire: :rocket:
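
Each search result can also be traced back to the OCR geometry that produced it, which is what makes user-facing highlighting possible. A minimal sketch (the `bounding_box` attribute name on source blocks is an assumption; boxes follow the normalized `x0`/`top`/`x1`/`bottom` layout used by `NormBBox`):

```python
for result in document_node.locator.search("John Doe"):
    for block in result.text_location.source_blocks:
        bbox = block.bounding_box  # assumed accessor for the block's normalized bbox
        print(bbox.x0, bbox.top, bbox.x1, bbox.bottom)
```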
102 |
--------------------------------------------------------------------------------
/makefile:
--------------------------------------------------------------------------------
1 | sources = docprompt
2 |
3 | .PHONY: test format lint unittest coverage pre-commit clean
4 | test: format lint unittest
5 |
6 | format:
7 | isort $(sources) tests
8 | black $(sources) tests
9 |
10 | lint:
11 | flake8 $(sources) tests
12 | mypy $(sources) tests
13 |
14 | unittest:
15 | pytest
16 |
17 | coverage:
18 | pytest --cov=$(sources) --cov-branch --cov-report=term-missing tests
19 |
20 | pre-commit:
21 | pre-commit run --all-files
22 |
23 | clean:
24 | rm -rf .mypy_cache .pytest_cache
25 | rm -rf *.egg-info
26 | rm -rf .tox dist site
27 | rm -rf coverage.xml .coverage
28 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: Docprompt
2 | site_url: https://docs.docprompt.io
3 | repo_url: https://github.com/docprompt/Docprompt
4 | theme:
5 | name: material
6 | logo: assets/static/img/logo.png
7 | favicon: assets/static/img/logo.png
8 | features:
9 | - navigation.instant
10 | - navigation.instant.prefetch
11 | plugins:
12 | - search
13 | - blog
14 | - gen-files:
15 | scripts:
16 | - docs/gen_ref_pages.py
17 | - literate-nav:
18 | nav_file: SUMMARY.md
19 | - mkdocstrings:
20 | handlers:
21 | python:
22 | paths: [docprompt]
23 | options:
24 | docstring_style: google
25 | show_source: true
26 | show_submodules: true
27 |
28 | nav:
29 | - Getting Started: index.md
30 | - How-to Guides:
31 | - Perform OCR:
32 | - Basic OCR Usage: guide/ocr/basic_usage.md
33 | - Customizing OCR Providers: guide/ocr/provider_config.md
34 | - Lightning-Fast Doc Search: guide/ocr/advanced_search.md
35 | - OCR-based Workflows: guide/ocr/advanced_workflows.md
36 | - Classify Pages:
37 | - Binary Classification: guide/classify/binary.md
38 | - Single-Label Classification: guide/classify/single.md
39 | - Multi-Label Classification: guide/classify/multi.md
40 | - Extract Tables: guide/table_extraction/extract_tables.md
41 | - Concepts:
42 |     - Primitives: concepts/primatives.md
43 | - Nodes: concepts/nodes.md
44 | - Providers: concepts/providers.md
45 | - Provenance: concepts/provenance.md
46 | - Cloud: enterprise.md
47 | - Blog:
48 | - blog/index.md
49 | - API Reference:
50 | - Docprompt SDK: reference/
51 |     - Enterprise API: enterprise/
52 | - Community:
53 | - Contributing: community/contributing.md
54 | - Versioning: community/versioning.md
55 |
56 | markdown_extensions:
57 | - pymdownx.highlight:
58 | anchor_linenums: true
59 | - pymdownx.superfences
60 | - pymdownx.emoji:
61 | emoji_index: !!python/name:materialx.emoji.twemoji
62 | emoji_generator: !!python/name:materialx.emoji.to_svg
63 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["pdm-backend"]
3 | build-backend = "pdm.backend"
4 |
5 | [project]
6 | name = "docprompt"
7 | version = "0.8.7"
8 | description = "Documents and large language models."
9 | authors = [
10 | {name = "Frank Colson", email = "frank@pageleaf.io"}
11 | ]
12 | dependencies = [
13 | "pillow>=9.0.1",
14 | "tqdm>=4.50.2",
15 | "fsspec>=2022.11.0",
16 | "pydantic>=2.1.0",
17 | "tenacity>=7.0.0",
18 | "pypdfium2<5.0.0,>=4.28.0",
19 | "filetype>=1.2.0",
20 | "beautifulsoup4>=4.12.3",
21 | "pypdf>=5.0.0"
22 | ]
23 | requires-python = "<3.13,>=3.8.6"
24 | readme = "README.md"
25 | license = {text = "Apache-2.0"}
26 | classifiers = ["Development Status :: 2 - Pre-Alpha", "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", "Natural Language :: English", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12"]
27 |
28 | [project.optional-dependencies]
29 | google = ["google-cloud-documentai>=2.20.0"]
30 | azure = ["azure-ai-formrecognizer>=3.3.0"]
31 | search = [
32 | "tantivy<1.0.0,>=0.21.0",
33 | "rtree<3.0.0,>=1.2.0",
34 | "networkx<4.0.0,>=2.5.0",
35 | "rapidfuzz>=3.0.0"
36 | ]
37 | anthropic = [
38 | "anthropic>=0.26.0"
39 | ]
40 | openai = [
41 | "openai>=1.0.1"
42 | ]
43 | aws = [
44 | "aioboto3>=13.1.0",
45 | "boto3>=1.18.0"
46 | ]
47 |
48 | [project.scripts]
49 | docprompt = "docprompt.cli:main"
50 |
51 | [project.urls]
52 | homepage = "https://github.com/Docprompt/docprompt"
53 |
54 | [tool.black]
55 | line-length = 120
56 | skip-string-normalization = true
57 | target-version = ['py39', 'py310', 'py311']
58 | include = '\.pyi?$'
59 | exclude = '''
60 | /(
61 | \.eggs
62 | | \.git
63 | | \.hg
64 | | \.mypy_cache
65 | | \.tox
66 |   | \.venv
67 | | _build
68 | | buck-out
69 | | build
70 | | dist
71 | )/
72 | '''
73 |
74 | [tool.flake8]
75 | ignore = [
76 | "E501"
77 | ]
78 |
79 | [tool.isort]
80 | multi_line_output = 3
81 | include_trailing_comma = true
82 | force_grid_wrap = 0
83 | use_parentheses = true
84 | ensure_newline_before_comments = true
85 | line_length = 120
86 | skip_gitignore = true
87 |
88 | [tool.pdm]
89 | distribution = true
90 |
91 | [tool.pdm.build]
92 | includes = ["docprompt", "tests"]
93 |
94 | [tool.pdm.dev-dependencies]
95 | test = [
96 | "isort<6.0.0,>=5.12.0",
97 | "flake8<7.0.0,>=6.1.0",
98 | "flake8-docstrings<2.0.0,>=1.7.0",
99 | "mypy<2.0.0,>=1.6.1",
100 | "pytest<8.0.0,>=7.4.2",
101 | "pytest-cov<5.0.0,>=4.1.0",
102 | "ruff<1.0.0,>=0.3.3",
103 | "pytest-asyncio>=0.23.7"
104 | ]
105 | dev = [
106 | "tox<4.0.0,>=3.20.1",
107 | "virtualenv<25.0.0,>=20.2.2",
108 | "pip<21.0.0,>=20.3.1",
109 | "twine<4.0.0,>=3.3.0",
110 | "pre-commit>=3.5.0",
111 | "toml<1.0.0,>=0.10.2",
112 | "bump2version<2.0.0,>=1.0.1",
113 | "ipython>=7.12.0",
114 | "python-dotenv>=1.0.1",
115 | "ipykernel>=6.29.4",
116 | "pytest-asyncio>=0.23.7"
117 | ]
118 | docs = [
119 | "mkdocs>=1.6.0",
120 | "mkdocs-material>=9.5.27",
121 | "mkdocstrings[python]>=0.25.1",
122 | "mkdocs-blog-plugin>=0.25",
123 | "mkdocs-gen-files>=0.5.0",
124 | "mkdocs-literate-nav>=0.6.1"
125 | ]
126 |
127 | [tool.pdm.scripts]
128 | docs = "mkdocs serve"
129 | lint = "pre-commit run --all-files"
130 | cov = {shell = "python tests/_run_tests_with_coverage.py {args}"}
131 |
132 | [tool.ruff]
133 | target-version = "py38"
134 |
135 | [tool.ruff.lint]
136 | # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
137 | # Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
138 | # McCabe complexity (`C901`) by default.
139 | select = ["E4", "E7", "E9", "F"]
140 | extend-select = ["I"]
141 | ignore = ["D212"]
142 | # Allow fix for all enabled rules (when `--fix`) is provided.
143 | fixable = ["ALL"]
144 | unfixable = []
145 | # Allow unused variables when underscore-prefixed.
146 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
147 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 | max-complexity = 18
4 | ignore = E203, E266, W503
5 | docstring-convention = google
6 | per-file-ignores = __init__.py:F401
7 | exclude = .git,
8 | __pycache__,
9 | setup.py,
10 | build,
11 | dist,
12 | docs,
13 | releases,
14 | .venv,
15 | .tox,
16 | .mypy_cache,
17 | .pytest_cache,
18 | .vscode,
19 | .github,
20 | # By default test codes will be linted.
21 | # tests
22 |
23 | [mypy]
24 | ignore_missing_imports = True
25 |
26 | [coverage:run]
27 | # uncomment the following to omit files during running
28 | #omit =
29 | [coverage:report]
30 | exclude_lines =
31 | pragma: no cover
32 | def __repr__
33 | if self.debug:
34 | if settings.DEBUG
35 | raise AssertionError
36 | raise NotImplementedError
37 | if 0:
38 | if __name__ == .__main__.:
39 | def main
40 |
41 | [tox:tox]
42 | isolated_build = true
43 | envlist = py39, py310, py311, py312, format, lint, build
44 |
45 | [gh-actions]
46 | python =
47 | 3.11: py311, format, lint, build
48 | 3.10: py310
49 | 3.9: py39
50 |
51 | [testenv]
52 | allowlist_externals = pytest
53 | extras =
54 | test
55 | passenv = *
56 | setenv =
57 | PYTHONPATH = {toxinidir}
58 | PYTHONWARNINGS = ignore
59 | commands =
60 | pytest --cov=docprompt --cov-branch --cov-report=xml --cov-report=term-missing tests
61 |
62 | [testenv:format]
63 | allowlist_externals =
64 | isort
65 | black
66 | extras =
67 | test
68 | commands =
69 | isort docprompt
70 | black docprompt tests
71 |
72 | [testenv:lint]
73 | allowlist_externals =
74 | flake8
75 | mypy
76 | extras =
77 | test
78 | commands =
79 | flake8 docprompt tests
80 | mypy docprompt tests
81 |
82 | [testenv:build]
83 | allowlist_externals =
84 | pdm
85 | mkdocs
86 | twine
87 | extras =
88 | doc
89 | dev
90 | commands =
91 | pdm build
92 | mkdocs build
93 | twine check dist/*
94 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """Unit test package for docprompt."""
2 |
--------------------------------------------------------------------------------
/tests/_run_tests_with_coverage.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import sys
3 |
4 |
5 | def main():
6 | args = sys.argv[1:]
7 | module = "docprompt" # Default module
8 | pytest_args = []
9 |
10 | # Parse arguments
11 | for arg in args:
12 | if arg.startswith("--mod="):
13 | module = arg.split("=")[1]
14 | else:
15 | pytest_args.append(arg)
16 |
17 | # Construct the pytest command
18 | command = [
19 | "pytest",
20 | f"--cov={module}",
21 | "--cov-report=term-missing",
22 | ] + pytest_args
23 |
24 | # Run the command
25 | subprocess.run(command)
26 |
27 |
28 | if __name__ == "__main__":
29 | main()
30 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | def pytest_configure(config):
2 | config.addinivalue_line(
3 | "filterwarnings", "ignore:The NumPy module was reloaded:UserWarning"
4 | )
5 |
--------------------------------------------------------------------------------
/tests/fixtures.py:
--------------------------------------------------------------------------------
1 | from .util import PdfFixture
2 |
3 | PDF_FIXTURES = [
4 | PdfFixture(
5 | name="1.pdf",
6 | page_count=6,
7 | file_hash="121ffed4336e6129e97ee3c4cb747864",
8 | ocr_name="1_ocr.json",
9 | ),
10 | PdfFixture(
11 | name="2.pdf",
12 | page_count=23,
13 | file_hash="bd2fa4f101b305e4001acf9137ce78cf",
14 | ocr_name="1_ocr.json",
15 | ),
16 | ]
17 |
--------------------------------------------------------------------------------
/tests/fixtures/1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/tests/fixtures/1.pdf
--------------------------------------------------------------------------------
/tests/fixtures/2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/tests/fixtures/2.pdf
--------------------------------------------------------------------------------
/tests/schema/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/tests/schema/__init__.py
--------------------------------------------------------------------------------
/tests/schema/pipeline/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/tests/schema/pipeline/__init__.py
--------------------------------------------------------------------------------
/tests/schema/pipeline/test_imagenode.py:
--------------------------------------------------------------------------------
1 | from docprompt.schema.pipeline import ImageNode
2 |
3 |
4 | def test_imagenode():
5 | ImageNode(image=b"test", metadata={})
6 |
--------------------------------------------------------------------------------
/tests/schema/pipeline/test_layoutaware.py:
--------------------------------------------------------------------------------
1 | from tests.fixtures import PDF_FIXTURES
2 |
3 |
4 | def test_direct__page_node_layout_aware_text():
5 | # Create a sample PageNode with some TextBlocks
6 | fixture = PDF_FIXTURES[0]
7 |
8 | document = fixture.get_document_node()
9 |
10 | page = document.page_nodes[0]
11 |
12 | assert page.ocr_results, "The OCR results should be populated"
13 |
14 | layout_text__property = page.layout_aware_text
15 |
16 | layout_len = 4786
17 |
18 | assert (
19 | len(layout_text__property) == layout_len
20 | ), f"The layout-aware text should be {layout_len} characters long"
21 |
22 | layout_text__direct = page.get_layout_aware_text()
23 |
24 | assert (
25 | layout_text__property == layout_text__direct
26 | ), "The layout-aware text should be the same"
27 |
--------------------------------------------------------------------------------
/tests/schema/test_layout_models.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from docprompt import NormBBox
4 | from docprompt.schema.layout import BoundingPoly, Point
5 |
6 |
7 | def test_normbbox_utilities():
8 | bbox = NormBBox(x0=0, top=0, x1=1, bottom=1)
9 |
10 | # Test simple properties
11 | assert bbox[0] == 0
12 | assert bbox[1] == 0
13 | assert bbox[2] == 1
14 | assert bbox[3] == 1
15 |
16 | assert bbox.width == 1
17 | assert bbox.height == 1
18 | assert bbox.area == 1
19 | assert bbox.centroid == (0.5, 0.5)
20 | assert bbox.y_center == 0.5
21 | assert bbox.x_center == 0.5
22 |
23 | # Test equality
24 | assert bbox == NormBBox(x0=0, top=0, x1=1, bottom=1)
25 |
26 | # Test out of bounds
27 |
28 | with pytest.raises(ValueError):
29 | NormBBox(x0=0, top=0, x1=1, bottom=2)
30 |
31 | # Add two bboxes
32 |
33 | bbox_2 = NormBBox(x0=0.5, top=0.5, x1=1, bottom=1.0)
34 | combined_bbox = bbox + bbox_2
35 | assert combined_bbox == NormBBox(x0=0, top=0, x1=1.0, bottom=1.0)
36 |
37 | # Add two bboxes via combine
38 |
39 | combined_bbox = NormBBox.combine(bbox, bbox_2)
40 | assert combined_bbox == NormBBox(x0=0, top=0, x1=1.0, bottom=1.0)
41 |
42 | # Test from bounding poly
43 | bounding_poly = BoundingPoly(
44 | normalized_vertices=[
45 | Point(x=0, y=0),
46 | Point(x=1, y=0),
47 | Point(x=1, y=1),
48 | Point(x=0, y=1),
49 | ]
50 | )
51 |
52 | bbox = NormBBox.from_bounding_poly(bounding_poly)
53 |
54 | assert bbox == NormBBox(x0=0, top=0, x1=1, bottom=1)
55 |
56 | # Test contains
57 |
58 | small_bbox = NormBBox(x0=0.25, top=0.25, x1=0.75, bottom=0.75)
59 | big_bbox = NormBBox(x0=0, top=0, x1=1, bottom=1)
60 |
61 | assert small_bbox in big_bbox
62 | assert big_bbox not in small_bbox
63 |
64 | # Test Overlap
65 |
66 | assert small_bbox.x_overlap(big_bbox) == 0.5
67 | assert small_bbox.y_overlap(big_bbox) == 0.5
68 |
69 | # Should be commutative
70 | assert big_bbox.x_overlap(small_bbox) == 0.5
71 | assert big_bbox.y_overlap(small_bbox) == 0.5
72 |
--------------------------------------------------------------------------------
/tests/tasks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/tests/tasks/__init__.py
--------------------------------------------------------------------------------
/tests/tasks/classification/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Test the functionality of the classification task provider.
3 | """
4 |
--------------------------------------------------------------------------------
/tests/tasks/classification/test_anthropic.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import AsyncMock, patch
2 |
3 | import pytest
4 |
5 | from docprompt.tasks.classification.anthropic import (
6 | AnthropicClassificationProvider,
7 | AnthropicPageClassificationOutputParser,
8 | _prepare_messages,
9 | )
10 | from docprompt.tasks.classification.base import (
11 | ClassificationConfig,
12 | ClassificationOutput,
13 | ClassificationTypes,
14 | ConfidenceLevel,
15 | )
16 |
17 |
18 | class TestAnthropicPageClassificationOutputParser:
19 | @pytest.fixture
20 | def parser(self):
21 | return AnthropicPageClassificationOutputParser(
22 | name="anthropic",
23 | type=ClassificationTypes.SINGLE_LABEL,
24 | labels=["A", "B", "C"],
25 | confidence=True,
26 | )
27 |
28 | def test_parse_single_label(self, parser):
29 | text = "Reasoning: This is a test.\nAnswer: B\nConfidence: high"
30 | result = parser.parse(text)
31 |
32 | assert isinstance(result, ClassificationOutput)
33 | assert result.type == ClassificationTypes.SINGLE_LABEL
34 | assert result.labels == "B"
35 | assert result.score == ConfidenceLevel.HIGH
36 | assert result.provider_name == "anthropic"
37 |
38 | def test_parse_multi_label(self):
39 | parser = AnthropicPageClassificationOutputParser(
40 | name="anthropic",
41 | type=ClassificationTypes.MULTI_LABEL,
42 | labels=["X", "Y", "Z"],
43 | confidence=True,
44 | )
45 | text = (
46 | "Reasoning: This is a multi-label test.\nAnswer: X, Z\nConfidence: medium"
47 | )
48 | result = parser.parse(text)
49 |
50 | assert isinstance(result, ClassificationOutput)
51 | assert result.type == ClassificationTypes.MULTI_LABEL
52 | assert result.labels == ["X", "Z"]
53 | assert result.score == ConfidenceLevel.MEDIUM
54 | assert result.provider_name == "anthropic"
55 |
56 | def test_parse_binary(self):
57 | parser = AnthropicPageClassificationOutputParser(
58 | name="anthropic",
59 | type=ClassificationTypes.BINARY,
60 | labels=["YES", "NO"],
61 | confidence=False,
62 | )
63 | text = "Reasoning: This is a binary test.\nAnswer: YES"
64 | result = parser.parse(text)
65 |
66 | assert isinstance(result, ClassificationOutput)
67 | assert result.type == ClassificationTypes.BINARY
68 | assert result.labels == "YES"
69 | assert result.score is None
70 | assert result.provider_name == "anthropic"
71 |
72 | def test_parse_invalid_answer(self, parser):
73 | text = "Reasoning: This is an invalid test.\nAnswer: D\nConfidence: low"
74 | with pytest.raises(ValueError, match="Invalid label: D"):
75 | parser.parse(text)
76 |
77 |
78 | class TestAnthropicClassificationProvider:
79 | @pytest.fixture
80 | def provider(self):
81 | return AnthropicClassificationProvider()
82 |
83 | @pytest.fixture
84 | def mock_config(self):
85 | return ClassificationConfig(
86 | type=ClassificationTypes.SINGLE_LABEL,
87 | labels=["A", "B", "C"],
88 | confidence=True,
89 | )
90 |
91 | @pytest.mark.asyncio()
92 | async def test_ainvoke(self, provider, mock_config):
93 | mock_input = [b"image1", b"image2"]
94 | mock_completions = [
95 | "Reasoning: Test 1\nAnswer: A\nConfidence: high",
96 | "Reasoning: Test 2\nAnswer: B\nConfidence: medium",
97 | ]
98 |
99 | with patch(
100 | "docprompt.tasks.classification.anthropic._prepare_messages"
101 | ) as mock_prepare:
102 | mock_prepare.return_value = "mock_messages"
103 |
104 | with patch(
105 | "docprompt.utils.inference.run_batch_inference_anthropic",
106 | new_callable=AsyncMock,
107 | ) as mock_inference:
108 | mock_inference.return_value = mock_completions
109 |
110 | test_kwargs = {
111 | "test": "test"
112 | } # Test that kwargs are passed through to inference
113 | results = await provider._ainvoke(
114 | mock_input, mock_config, **test_kwargs
115 | )
116 |
117 | assert len(results) == 2
118 | assert all(isinstance(result, ClassificationOutput) for result in results)
119 | assert results[0].labels == "A"
120 | assert results[0].score == ConfidenceLevel.HIGH
121 | assert results[1].labels == "B"
122 | assert results[1].score == ConfidenceLevel.MEDIUM
123 |
124 | mock_prepare.assert_called_once_with(mock_input, mock_config)
125 | mock_inference.assert_called_once_with(
126 | "claude-3-haiku-20240307", "mock_messages", **test_kwargs
127 | )
128 |
129 | @pytest.mark.asyncio()
130 | async def test_ainvoke_with_error(self, provider, mock_config):
131 | mock_input = [b"image1"]
132 | mock_completions = ["Reasoning: Error test\nAnswer: Invalid\nConfidence: low"]
133 |
134 | with patch(
135 | "docprompt.tasks.classification.anthropic._prepare_messages"
136 | ) as mock_prepare:
137 | mock_prepare.return_value = "mock_messages"
138 |
139 | with patch(
140 | "docprompt.utils.inference.run_batch_inference_anthropic",
141 | new_callable=AsyncMock,
142 | ) as mock_inference:
143 | mock_inference.return_value = mock_completions
144 |
145 | with pytest.raises(ValueError, match="Invalid label: Invalid"):
146 | await provider._ainvoke(mock_input, mock_config)
147 |
148 | def test_prepare_messages(self, mock_config):
149 | imgs = [b"image1", b"image2"]
150 | config = mock_config
151 |
152 | result = _prepare_messages(imgs, config)
153 |
154 | assert len(result) == 2
155 | for msg_group in result:
156 | assert len(msg_group) == 1
157 | msg = msg_group[0]
158 | assert msg.role == "user"
159 | assert len(msg.content) == 2
160 |
161 | types = set([content.type for content in msg.content])
162 | assert types == set(["image_url", "text"])
163 |
--------------------------------------------------------------------------------
/tests/tasks/markerize/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/tests/tasks/markerize/__init__.py
--------------------------------------------------------------------------------
/tests/tasks/markerize/test_anthropic.py:
--------------------------------------------------------------------------------
1 | """
2 | Test the Anthropic implementation of the markerize task.
3 | """
4 |
5 | from unittest.mock import patch
6 |
7 | import pytest
8 |
9 | from docprompt.tasks.markerize.anthropic import (
10 | AnthropicMarkerizeProvider,
11 | _parse_result,
12 | _prepare_messages,
13 | )
14 | from docprompt.tasks.markerize.base import MarkerizeResult
15 | from docprompt.tasks.message import OpenAIComplexContent, OpenAIImageURL, OpenAIMessage
16 |
17 |
18 | @pytest.fixture
19 | def mock_image_bytes():
20 | return b"mock_image_bytes"
21 |
22 |
23 | class TestAnthropicMarkerizeProvider:
24 | @pytest.fixture
25 | def provider(self):
26 | return AnthropicMarkerizeProvider()
27 |
28 | def test_provider_name(self, provider):
29 | assert provider.name == "anthropic"
30 |
31 | @pytest.mark.asyncio
32 | async def test_ainvoke(self, provider, mock_image_bytes):
33 | mock_completions = ["# Test Markdown", "## Another Test"]
34 |
35 | with patch(
36 | "docprompt.tasks.markerize.anthropic._prepare_messages"
37 | ) as mock_prepare:
38 | with patch(
39 | "docprompt.utils.inference.run_batch_inference_anthropic"
40 | ) as mock_inference:
41 | mock_prepare.return_value = "mock_messages"
42 | mock_inference.return_value = mock_completions
43 |
44 | test_kwargs = {
45 | "test": "test"
46 | } # Test that kwargs are passed through to inference
47 | result = await provider._ainvoke(
48 | [mock_image_bytes, mock_image_bytes], **test_kwargs
49 | )
50 |
51 | assert len(result) == 2
52 | assert all(isinstance(r, MarkerizeResult) for r in result)
53 | assert result[0].raw_markdown == "# Test Markdown"
54 | assert result[1].raw_markdown == "## Another Test"
55 | assert all(r.provider_name == "anthropic" for r in result)
56 |
57 | mock_prepare.assert_called_once_with(
58 | [mock_image_bytes, mock_image_bytes]
59 | )
60 | mock_inference.assert_called_once_with(
61 | "claude-3-haiku-20240307", "mock_messages", **test_kwargs
62 | )
63 |
64 |
65 | def test_prepare_messages(mock_image_bytes):
66 | messages = _prepare_messages([mock_image_bytes])
67 |
68 | assert len(messages) == 1
69 | assert len(messages[0]) == 1
70 | assert isinstance(messages[0][0], OpenAIMessage)
71 | assert messages[0][0].role == "user"
72 | assert len(messages[0][0].content) == 2
73 | assert isinstance(messages[0][0].content[0], OpenAIComplexContent)
74 | assert messages[0][0].content[0].type == "image_url"
75 | assert isinstance(messages[0][0].content[0].image_url, OpenAIImageURL)
76 | assert messages[0][0].content[0].image_url.url == mock_image_bytes.decode("utf-8")
77 | assert isinstance(messages[0][0].content[1], OpenAIComplexContent)
78 | assert messages[0][0].content[1].type == "text"
79 | assert "Convert the image into markdown" in messages[0][0].content[1].text
80 |
81 |
82 | @pytest.mark.parametrize(
83 | "raw_markdown,expected",
84 | [
85 | ("# Test Markdown", "# Test Markdown"),
86 | ("## Another Test", "## Another Test"),
87 | ("Invalid markdown", ""),
88 | (" Trimmed ", "Trimmed"),
89 | ],
90 | )
91 | def test_parse_result(raw_markdown, expected):
92 | assert _parse_result(raw_markdown) == expected
93 |
--------------------------------------------------------------------------------
/tests/tasks/markerize/test_base.py:
--------------------------------------------------------------------------------
1 | """
2 | Test the basic functionality of the atomic components for the markerize task.
3 | """
4 |
5 | from unittest.mock import MagicMock, patch
6 |
7 | import pytest
8 |
9 | from docprompt.schema.pipeline.node.document import DocumentNode
10 | from docprompt.schema.pipeline.node.page import PageNode
11 | from docprompt.tasks.markerize.base import BaseMarkerizeProvider, MarkerizeResult
12 |
13 |
14 | class TestBaseMarkerizeProvider:
15 | """
16 | Test the base markerize provider to ensure that the task is properly handled.
17 | """
18 |
19 | @pytest.fixture
20 | def mock_document_node(self):
21 | mock_node = MagicMock(spec=DocumentNode)
22 | mock_node.page_nodes = [MagicMock(spec=PageNode) for _ in range(5)]
23 | for pnode in mock_node.page_nodes:
24 | pnode.rasterizer.rasterize.return_value = b"image"
25 | mock_node.__len__.return_value = len(mock_node.page_nodes)
26 | return mock_node
27 |
28 | @pytest.mark.parametrize(
29 | "start,stop,expected_keys,expected_results",
30 | [
31 | (2, 4, [2, 3, 4], {2: "RESULT-0", 3: "RESULT-1", 4: "RESULT-2"}),
32 | (3, None, [3, 4, 5], {3: "RESULT-0", 4: "RESULT-1", 5: "RESULT-2"}),
33 | (None, 2, [1, 2], {1: "RESULT-0", 2: "RESULT-1"}),
34 | (
35 | None,
36 | None,
37 | [1, 2, 3, 4, 5],
38 | {
39 | 1: "RESULT-0",
40 | 2: "RESULT-1",
41 | 3: "RESULT-2",
42 | 4: "RESULT-3",
43 | 5: "RESULT-4",
44 | },
45 | ),
46 | ],
47 | )
48 | def test_process_document_node_with_start_stop(
49 | self, mock_document_node, start, stop, expected_keys, expected_results
50 | ):
51 | class TestProvider(BaseMarkerizeProvider):
52 | name = "test"
53 |
54 | def _invoke(self, input, config, **kwargs):
55 | return [
56 | MarkerizeResult(raw_markdown=f"RESULT-{i}", provider_name="test")
57 | for i in range(len(input))
58 | ]
59 |
60 | provider = TestProvider()
61 | result = provider.process_document_node(
62 | mock_document_node, start=start, stop=stop
63 | )
64 |
65 | assert list(result.keys()) == expected_keys
66 | assert all(isinstance(v, MarkerizeResult) for v in result.values())
67 | assert {k: v.raw_markdown for k, v in result.items()} == expected_results
68 |
69 | with patch.object(provider, "_invoke") as mock_invoke:
70 | provider.process_document_node(mock_document_node, start=start, stop=stop)
71 | mock_invoke.assert_called_once()
72 | expected_invoke_length = len(expected_keys)
73 | assert len(mock_invoke.call_args[0][0]) == expected_invoke_length
74 |
75 | def test_process_document_node_rasterization(self, mock_document_node):
76 | class TestProvider(BaseMarkerizeProvider):
77 | name = "test"
78 |
79 | def _invoke(self, input, config, **kwargs):
80 | return [
81 | MarkerizeResult(raw_markdown=f"RESULT-{i}", provider_name="test")
82 | for i in range(len(input))
83 | ]
84 |
85 | provider = TestProvider()
86 | provider.process_document_node(mock_document_node)
87 |
88 | for page_node in mock_document_node.page_nodes:
89 | page_node.rasterizer.rasterize.assert_called_once_with("default")
90 |
--------------------------------------------------------------------------------
/tests/tasks/table_extraction/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/docprompt/Docprompt/4c489796743f78cf25c257e29c1794394b08f3c0/tests/tasks/table_extraction/__init__.py
--------------------------------------------------------------------------------
/tests/tasks/table_extraction/test_anthropic.py:
--------------------------------------------------------------------------------
1 | """
2 | Test the Anthropic implementation of the table extraction task.
3 | """
4 |
5 | from unittest.mock import patch
6 |
7 | import pytest
8 | from bs4 import BeautifulSoup
9 |
10 | from docprompt.tasks.message import OpenAIComplexContent, OpenAIImageURL, OpenAIMessage
11 | from docprompt.tasks.table_extraction.anthropic import (
12 | AnthropicTableExtractionProvider,
13 | _headers_from_tree,
14 | _prepare_messages,
15 | _rows_from_tree,
16 | _title_from_tree,
17 | parse_response,
18 | )
19 | from docprompt.tasks.table_extraction.schema import (
20 | TableCell,
21 | TableExtractionPageResult,
22 | TableHeader,
23 | TableRow,
24 | )
25 |
26 |
27 | @pytest.fixture
28 | def mock_image_bytes():
29 | return b"mock_image_bytes"
30 |
31 |
32 | class TestAnthropicTableExtractionProvider:
33 | @pytest.fixture
34 | def provider(self):
35 | return AnthropicTableExtractionProvider()
36 |
37 | def test_provider_name(self, provider):
38 | assert provider.name == "anthropic"
39 |
40 | @pytest.mark.asyncio
41 | async def test_ainvoke(self, provider, mock_image_bytes):
42 | mock_completions = [
43 |             "<table><title>Test Table</title><headers><header>Col1</header></headers><rows><row><cell>Data1</cell></row></rows></table>",
44 |             "<table><headers><header>Col1</header></headers><rows><row><cell>Data1</cell></row></rows></table>",
45 | ]
46 |
47 | with patch(
48 | "docprompt.tasks.table_extraction.anthropic._prepare_messages"
49 | ) as mock_prepare:
50 | with patch(
51 | "docprompt.utils.inference.run_batch_inference_anthropic"
52 | ) as mock_inference:
53 | mock_prepare.return_value = "mock_messages"
54 | mock_inference.return_value = mock_completions
55 |
56 | result = await provider._ainvoke([mock_image_bytes, mock_image_bytes])
57 |
58 | assert len(result) == 2
59 | assert all(isinstance(r, TableExtractionPageResult) for r in result)
60 | assert result[0].tables[0].title == "Test Table"
61 | assert result[1].tables[0].title is None
62 | assert all(r.provider_name == "anthropic" for r in result)
63 |
64 | mock_prepare.assert_called_once_with(
65 | [mock_image_bytes, mock_image_bytes]
66 | )
67 | mock_inference.assert_called_once_with(
68 | "claude-3-haiku-20240307", "mock_messages"
69 | )
70 |
71 |
72 | def test_prepare_messages(mock_image_bytes):
73 | messages = _prepare_messages([mock_image_bytes])
74 |
75 | assert len(messages) == 1
76 | assert len(messages[0]) == 1
77 | assert isinstance(messages[0][0], OpenAIMessage)
78 | assert messages[0][0].role == "user"
79 | assert len(messages[0][0].content) == 2
80 | assert isinstance(messages[0][0].content[0], OpenAIComplexContent)
81 | assert messages[0][0].content[0].type == "image_url"
82 | assert isinstance(messages[0][0].content[0].image_url, OpenAIImageURL)
83 | assert messages[0][0].content[0].image_url.url == mock_image_bytes.decode()
84 | assert isinstance(messages[0][0].content[1], OpenAIComplexContent)
85 | assert messages[0][0].content[1].type == "text"
86 | assert (
87 | "Identify and extract all tables from the document"
88 | in messages[0][0].content[1].text
89 | )
90 |
91 |
92 | def test_parse_response():
93 | response = """
94 | <table>
95 | <title>Test Table</title>
96 | <headers>
97 | <header>Col1</header>
98 | <header>Col2</header>
99 | </headers>
100 | <rows>
101 | <row>
102 | <cell>Data1</cell>
103 | <cell>Data2</cell>
104 | </row>
105 | </rows>
106 | </table>
107 | """
108 | result = parse_response(response)
109 |
110 | assert isinstance(result, TableExtractionPageResult)
111 | assert len(result.tables) == 1
112 | assert result.tables[0].title == "Test Table"
113 | assert len(result.tables[0].headers) == 2
114 | assert result.tables[0].headers[0].text == "Col1"
115 | assert len(result.tables[0].rows) == 1
116 | assert result.tables[0].rows[0].cells[0].text == "Data1"
117 |
118 |
119 | def test_title_from_tree():
120 |     soup = BeautifulSoup("<table><title>Test Title</title></table>")
121 | assert _title_from_tree(soup.table) == "Test Title"
122 |
123 |     soup = BeautifulSoup("<table></table>")
124 | assert _title_from_tree(soup.table) is None
125 |
126 |
127 | def test_headers_from_tree():
128 | soup = BeautifulSoup(
129 |         "<table><headers><header>Col1</header><header>Col2</header></headers></table>",
130 | )
131 | headers = _headers_from_tree(soup.table)
132 | assert len(headers) == 2
133 | assert all(isinstance(h, TableHeader) for h in headers)
134 | assert headers[0].text == "Col1"
135 |
136 |     soup = BeautifulSoup("<table></table>")
137 | assert _headers_from_tree(soup.table) == []
138 |
139 |
140 | def test_rows_from_tree():
141 | soup = BeautifulSoup(
142 |         "<table><rows><row><cell>Data1</cell><cell>Data2</cell></row></rows></table>",
143 | )
144 | rows = _rows_from_tree(soup.table)
145 | assert len(rows) == 1
146 | assert isinstance(rows[0], TableRow)
147 | assert len(rows[0].cells) == 2
148 | assert all(isinstance(c, TableCell) for c in rows[0].cells)
149 | assert rows[0].cells[0].text == "Data1"
150 |
151 |     soup = BeautifulSoup("<table></table>")
152 | assert _rows_from_tree(soup.table) == []
153 |
154 |
155 | @pytest.mark.parametrize(
156 | "input_str,sub_str,expected",
157 | [
158 |         ("abc<table>def</table>ghi<table>def</table>", "<table>", [3, 24]),
159 |         ("notables", "<table>", []),
160 |         ("<table><table><table>", "<table>", [0, 7, 14]),
161 | ],
162 | )
163 | def test_find_start_indices(input_str, sub_str, expected):
164 | from docprompt.tasks.table_extraction.anthropic import _find_start_indices
165 |
166 | assert _find_start_indices(input_str, sub_str) == expected
167 |
168 |
169 | @pytest.mark.parametrize(
170 | "input_str,sub_str,expected",
171 | [
172 |         ("abc</table>def</table>ghi", "</table>", [11, 22]),
173 |         ("notables", "</table>", []),
174 |         ("</table></table></table>", "</table>", [8, 16, 24]),
175 | ],
176 | )
177 | def test_find_end_indices(input_str, sub_str, expected):
178 | from docprompt.tasks.table_extraction.anthropic import _find_end_indices
179 |
180 | assert _find_end_indices(input_str, sub_str) == expected
181 |
--------------------------------------------------------------------------------
/tests/tasks/test_factory.py:
--------------------------------------------------------------------------------
1 | """
2 | Unit tests for the factory module to ensure task providers are instantiated correctly.
3 | """
4 |
5 | import pytest
6 |
7 | from docprompt.tasks.factory import (
8 | AmazonTaskProviderFactory,
9 | AnthropicTaskProviderFactory,
10 | GCPTaskProviderFactory,
11 | )
12 |
13 |
14 | class TestAnthropicTaskProviderFactory:
15 | @pytest.fixture
16 | def api_key(self):
17 | return "test-api-key"
18 |
19 | @pytest.fixture
20 | def environ_key(self, monkeypatch, api_key):
21 | monkeypatch.setenv("ANTHROPIC_API_KEY", api_key)
22 |
23 | yield api_key
24 |
25 | monkeypatch.delenv("ANTHROPIC_API_KEY")
26 |
27 | def test_validate_provider_w_explicit_key(self, api_key):
28 | api_key = api_key
29 | factory = AnthropicTaskProviderFactory(api_key=api_key)
30 | assert factory._credentials.kwargs == {"api_key": api_key}
31 |
32 | def test_validate_provider_w_environ_key(self, environ_key):
33 | factory = AnthropicTaskProviderFactory()
34 | assert factory._credentials.kwargs == {"api_key": environ_key}
35 |
36 | def test_get_page_classification_provider(self, environ_key):
37 | factory = AnthropicTaskProviderFactory()
38 | provider = factory.get_page_classification_provider()
39 |
40 | assert provider._default_invoke_kwargs == {"api_key": environ_key}
41 | assert provider.name == "anthropic"
42 |
43 | def test_get_page_table_extraction_provider(self, environ_key):
44 | factory = AnthropicTaskProviderFactory()
45 | provider = factory.get_page_table_extraction_provider()
46 |
47 | assert provider._default_invoke_kwargs == {"api_key": environ_key}
48 | assert provider.name == "anthropic"
49 |
50 | def test_get_page_markerization_provider(self, environ_key):
51 | factory = AnthropicTaskProviderFactory()
52 | provider = factory.get_page_markerization_provider()
53 |
54 | assert provider._default_invoke_kwargs == {"api_key": environ_key}
55 | assert provider.name == "anthropic"
56 |
57 |
58 | class TestGoogleTaskProviderFactory:
59 | @pytest.fixture
60 | def sa_file(self):
61 | return "/path/to/file"
62 |
63 | @pytest.fixture
64 | def environ_file(self, monkeypatch, sa_file):
65 | monkeypatch.setenv("GCP_SERVICE_ACCOUNT_FILE", sa_file)
66 |
67 | yield sa_file
68 |
69 | monkeypatch.delenv("GCP_SERVICE_ACCOUNT_FILE")
70 |
71 | def test_validate_provider_w_explicit_sa_file(self, sa_file):
72 | factory = GCPTaskProviderFactory(service_account_file=sa_file)
73 | assert factory._credentials.kwargs == {"service_account_file": sa_file}
74 |
75 | def test_validate_provider_w_environ_sa_file(self, environ_file):
76 | factory = GCPTaskProviderFactory()
77 | assert factory._credentials.kwargs == {"service_account_file": environ_file}
78 |
79 | def test_get_page_ocr_provider(self, environ_file):
80 | factory = GCPTaskProviderFactory()
81 |
82 | project_id = "project-id"
83 | processor_id = "processor-id"
84 |
85 | provider = factory.get_page_ocr_provider(project_id, processor_id)
86 |
87 | assert provider._default_invoke_kwargs == {"service_account_file": environ_file}
88 | assert provider.name == "gcp_documentai"
89 |
90 |
91 | class TestAmazonTaskProviderFactory:
92 | @pytest.fixture
93 | def aws_creds(self):
94 | return {
95 | "aws_access_key_id": "test_access_key",
96 | "aws_secret_access_key": "test_secret_key",
97 | "aws_region": "us-west-2",
98 | }
99 |
100 | @pytest.fixture
101 | def environ_creds(self, monkeypatch, aws_creds):
102 | monkeypatch.setenv("AWS_ACCESS_KEY_ID", aws_creds["aws_access_key_id"])
103 | monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", aws_creds["aws_secret_access_key"])
104 | monkeypatch.setenv("AWS_DEFAULT_REGION", aws_creds["aws_region"])
105 |
106 | yield aws_creds
107 |
108 | monkeypatch.delenv("AWS_ACCESS_KEY_ID")
109 | monkeypatch.delenv("AWS_SECRET_ACCESS_KEY")
110 | monkeypatch.delenv("AWS_DEFAULT_REGION")
111 |
112 |     def test_validate_provider_w_explicit_creds(self, aws_creds):
113 | factory = AmazonTaskProviderFactory(**aws_creds)
114 | assert factory._credentials.kwargs == aws_creds
115 |
116 | def test_validate_provider_w_environ_creds(self, environ_creds):
117 | factory = AmazonTaskProviderFactory()
118 | assert factory._credentials.kwargs == environ_creds
119 |
120 | def test_get_page_ocr_provider(self, environ_creds):
121 | factory = AmazonTaskProviderFactory()
122 | provider = factory.get_page_ocr_provider()
123 |
124 | assert provider._default_invoke_kwargs == environ_creds
125 | assert provider.name == "aws_textract"
126 |
--------------------------------------------------------------------------------
/tests/tasks/test_result.py:
--------------------------------------------------------------------------------
1 | """
2 | Test the Page Level and Document Level task results operate as expected.
3 | """
4 |
5 | from unittest.mock import MagicMock
6 |
7 | from docprompt import DocumentNode
8 | from docprompt.schema.pipeline import BaseMetadata
9 | from docprompt.tasks.result import BaseDocumentResult, BasePageResult, BaseResult
10 |
11 |
12 | def test_task_key():
13 | class TestResult(BaseResult):
14 | task_name = "test"
15 |
16 | def contribute_to_document_node(self, document_node, page_number=None):
17 | pass
18 |
19 | result = TestResult(provider_name="test")
20 | assert result.task_key == "test_test"
21 |
22 |
23 | def test_base_document_result_contribution():
24 | class TestDocumentResult(BaseDocumentResult):
25 | task_name = "test"
26 |
27 | result = TestDocumentResult(
28 | provider_name="test", document_name="test", file_hash="test"
29 | )
30 |
31 | mock_meta = MagicMock(spec=BaseMetadata)
32 | mock_meta.task_results = {}
33 | mock_node = MagicMock(spec=DocumentNode)
34 | mock_node.metadata = mock_meta
35 |
36 | result.contribute_to_document_node(mock_node)
37 |
38 | assert mock_meta.task_results["test_test"] == result
39 |
40 |
41 | def test_base_page_result_contribution():
42 | class TestPageResult(BasePageResult):
43 | task_name = "test"
44 |
45 | result = TestPageResult(provider_name="test")
46 |
47 | num_pages = 3
48 | mock_meta = MagicMock(spec=BaseMetadata)
49 | mock_meta.task_results = {}
50 | mock_node = MagicMock(spec=DocumentNode)
51 | mock_node.page_nodes = [MagicMock() for _ in range(num_pages)]
52 |
53 | # Test contributing to a specific page
54 | mock_node.page_nodes[0].metadata = mock_meta
55 |
56 | mock_node.__len__.return_value = num_pages
57 |
58 | result.contribute_to_document_node(mock_node, page_number=1)
59 |
60 | assert mock_node.page_nodes[0].metadata.task_results["test_test"] == result
61 |
--------------------------------------------------------------------------------
/tests/tasks/test_task_provider.py:
--------------------------------------------------------------------------------
1 | """The test suite for the base task provider seeks to ensure that all of the
2 | builtin functionality of the BaseTaskProvider is properly implemented.
3 | """
4 |
5 | import pytest
6 |
7 | from docprompt.tasks.base import AbstractTaskProvider
8 |
9 |
10 | class TestAbstractTaskProviderBaseFunctionality:
11 | """
12 | Test that the BaseTaskProvider interface provides the correct expected basic
13 | functionality to be inherited by all subclasses.
14 |
15 | This includes:
16 |     - the model validation asserts `name` and `capabilities` are required
17 |     - the initialization of the model properly sets invoke kwargs
18 | - the `ainvoke` method calling the `_ainvoke` method
19 | - the `invoke` method calling the `_invoke` method
20 | """
21 |
22 | def test_model_validator_raises_error_on_missing_name(self):
23 | class BadTaskProvider(AbstractTaskProvider):
24 | capabilities = []
25 |
26 | with pytest.raises(ValueError):
27 | BadTaskProvider.validate_class_vars({})
28 |
29 | def test_model_validator_raises_error_on_missing_capabilities(self):
30 | class BadTaskProvider(AbstractTaskProvider):
31 | name = "BadTaskProvider"
32 |
33 | with pytest.raises(ValueError):
34 | BadTaskProvider.validate_class_vars({})
35 |
36 | def test_model_validator_raises_error_on_empty_capabilities(self):
37 | class BadTaskProvider(AbstractTaskProvider):
38 | name = "BadTaskProvider"
39 | capabilities = []
40 |
41 | with pytest.raises(ValueError):
42 | BadTaskProvider.validate_class_vars({})
43 |
44 | def test_init_no_invoke_kwargs(self):
45 | class TestTaskProvider(AbstractTaskProvider):
46 | name = "TestTaskProvider"
47 | capabilities = ["test"]
48 |
49 | provider = TestTaskProvider()
50 |
51 | assert provider._default_invoke_kwargs == {}
52 |
53 | def test_init_with_invoke_kwargs(self):
54 | class TestTaskProvider(AbstractTaskProvider):
55 | name = "TestTaskProvider"
56 | capabilities = ["test"]
57 |
58 | kwargs = {"test": "test"}
59 | provider = TestTaskProvider(invoke_kwargs=kwargs)
60 |
61 | assert provider._default_invoke_kwargs == kwargs
62 |
63 | def test_init_with_fields_and_invoke_kwargs(self):
64 | class TestTaskProvider(AbstractTaskProvider):
65 | name = "TestTaskProvider"
66 | capabilities = ["test"]
67 |
68 | foo: str
69 |
70 | kwargs = {"test": "test"}
71 | provider = TestTaskProvider(foo="bar", invoke_kwargs=kwargs)
72 |
73 | assert provider._default_invoke_kwargs == kwargs
74 | assert provider.foo == "bar"
75 |
76 | @pytest.mark.asyncio
77 | async def test_ainvoke_calls__ainvoke(self):
78 | class TestTaskProvider(AbstractTaskProvider):
79 | name = "TestTaskProvider"
80 | capabilities = ["test"]
81 |
82 | async def _ainvoke(self, input, config=None, **kwargs):
83 | return input
84 |
85 | provider = TestTaskProvider()
86 |
87 | assert await provider.ainvoke([1, 2, 3]) == [1, 2, 3]
88 |
89 | def test_invoke_calls__invoke(self):
90 | class TestTaskProvider(AbstractTaskProvider):
91 | name = "TestTaskProvider"
92 | capabilities = ["test"]
93 |
94 | def _invoke(self, input, config=None, **kwargs):
95 | return input
96 |
97 | provider = TestTaskProvider()
98 |
99 | assert provider.invoke([1, 2, 3]) == [1, 2, 3]
100 |
--------------------------------------------------------------------------------
/tests/test_date_extraction.py:
--------------------------------------------------------------------------------
1 | from datetime import date
2 |
3 | from docprompt.utils.date_extraction import extract_dates_from_text
4 |
5 | STRING_A = """
6 | There was a meeting on 2021-01-01 and another on 2021-01-02.
7 |
8 | Meanwhile on September 1, 2021, there was a third meeting.
9 |
10 | The final meeting was on 4/5/2021.
11 | """
12 |
13 |
14 | def test_date_extraction():
15 | dates = extract_dates_from_text(STRING_A)
16 |
17 | assert len(dates) == 5
18 |
19 | dates.sort(key=lambda x: x[0])
20 |
21 | assert dates[0][0] == date(2021, 1, 1)
22 | assert dates[0][1] == "2021-01-01"
23 |
24 | assert dates[1][0] == date(2021, 1, 2)
25 | assert dates[1][1] == "2021-01-02"
26 |
27 | assert dates[2][0] == date(2021, 4, 5)
28 | assert dates[2][1] == "4/5/2021"
29 |
30 | assert dates[3][0] == date(2021, 5, 4)
31 | assert dates[3][1] == "4/5/2021"
32 |
33 | assert dates[4][0] == date(2021, 9, 1)
34 | assert dates[4][1] == "September 1, 2021"
35 |
--------------------------------------------------------------------------------
/tests/test_search.py:
--------------------------------------------------------------------------------
1 | import pickle
2 |
3 | from pytest import raises
4 |
5 | from docprompt import DocumentNode, load_document
6 |
7 | from .fixtures import PDF_FIXTURES
8 |
9 |
10 | def test_search():
11 | document = load_document(PDF_FIXTURES[0].get_full_path())
12 | document_node = DocumentNode.from_document(document)
13 |
14 | ocr_results = PDF_FIXTURES[0].get_ocr_results()
15 |
16 | with raises(ValueError):
17 | document_node.refresh_locator()
18 |
19 | for page_num, ocr_result in ocr_results.items():
20 | ocr_result.contribute_to_document_node(document_node, page_number=page_num)
21 |
22 | print(document_node[0].metadata.task_results)
23 |
24 | assert document_node._locator is None # Ensure the locator is not set
25 |
26 | # Need to make sure an ocr_key is set to avoid ValueError
27 | locator = document_node.locator
28 |
29 | result = locator.search("word that doesn't exist")
30 |
31 | assert len(result) == 0
32 |
33 | result_all_pages = locator.search("and")
34 |
35 | assert len(result_all_pages) == 50
36 |
37 | result_page_1 = locator.search("rooted", page_number=1)
38 |
39 | assert len(result_page_1) == 1
40 |
41 | result_multiple_words = locator.search("MMAX2 system", page_number=1)
42 |
43 | assert len(result_multiple_words) == 1
44 |
45 | sources = result_multiple_words[0].text_location.source_blocks
46 |
47 | assert len(sources) == 2
48 |
49 | result_multiple_words = locator.search(
50 | "MMAX2 system", page_number=1, refine_to_word=False
51 | )
52 |
53 | assert len(result_multiple_words) == 1
54 |
55 | sources = result_multiple_words[0].text_location.source_blocks
56 |
57 | assert len(sources) == 1
58 |
59 | n_best = locator.search_n_best("and", n=3)
60 |
61 | assert len(n_best) == 3
62 |
63 | raw_search = locator.search_raw('content:"rooted"')
64 |
65 | assert len(raw_search) == 1
66 |
67 |
68 | def test_pickling__removes_locator_document_basis():
69 | document = load_document(PDF_FIXTURES[0].get_full_path())
70 | document_node = DocumentNode.from_document(document)
71 |
72 | ocr_results = PDF_FIXTURES[0].get_ocr_results()
73 |
74 | for page_num, ocr_result in ocr_results.items():
75 | ocr_result.contribute_to_document_node(document_node, page_number=page_num)
76 |
77 | result_page_1 = document_node.locator.search("rooted", page_number=1)
78 |
79 | assert len(result_page_1) == 1
80 |
81 | dumped = pickle.dumps(document_node)
82 |
83 | loaded = pickle.loads(dumped)
84 |
85 | assert loaded._locator is None
86 |
87 |
88 | def test_pickling__removes_locator_page_basis():
89 | document = load_document(PDF_FIXTURES[0].get_full_path())
90 | document_node = DocumentNode.from_document(document)
91 |
92 | ocr_results = PDF_FIXTURES[0].get_ocr_results()
93 |
94 | for page_num, ocr_result in ocr_results.items():
95 | ocr_result.contribute_to_document_node(document_node, page_number=page_num)
96 |
97 | page = document_node.page_nodes[0]
98 |
99 | result_page_1 = page.search("rooted")
100 |
101 | assert len(result_page_1) == 1
102 |
103 | dumped = pickle.dumps(page)
104 |
105 | loaded = pickle.loads(dumped)
106 |
107 | assert loaded.document._locator is None
108 |
--------------------------------------------------------------------------------
/tests/test_threadpool.py:
--------------------------------------------------------------------------------
1 | from concurrent.futures import ThreadPoolExecutor, as_completed
2 | from pathlib import Path
3 |
4 | import pytest
5 |
6 | from docprompt import load_document
7 | from docprompt.utils.splitter import pdf_split_iter_with_max_bytes
8 |
9 |
10 | def do_split(document):
11 | return list(
12 | pdf_split_iter_with_max_bytes(document.file_bytes, 15, 1024 * 1024 * 15)
13 | )
14 |
15 |
16 | @pytest.mark.skip(reason="Fixures are missing for this test")
17 | def test_document_split_in_threadpool__does_not_hang():
18 | source_dir = Path(__file__).parent.parent / "data" / "threadpool_test"
19 |
20 | documents = [load_document(file) for file in source_dir.iterdir()]
21 |
22 | futures = []
23 |
24 | with ThreadPoolExecutor() as executor:
25 | for document in documents:
26 | future = executor.submit(do_split, document)
27 |
28 | futures.append(future)
29 |
30 | for future in as_completed(futures):
31 | assert future.result() is not None
32 |
--------------------------------------------------------------------------------
/tests/util.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Dict, Optional
3 |
4 | from pydantic import BaseModel, Field, PositiveInt, TypeAdapter
5 |
6 | from docprompt.tasks.ocr.result import OcrPageResult
7 |
8 | FIXTURE_PATH = Path(__file__).parent / "fixtures"
9 |
10 |
11 | OCR_ADAPTER = TypeAdapter(Dict[int, OcrPageResult])
12 |
13 |
14 | class PdfFixture(BaseModel):
15 | name: str = Field(description="The name of the fixture")
16 | page_count: PositiveInt = Field(description="The number of pages in the fixture")
17 | file_hash: str = Field(description="The expected hash of the fixture")
18 | ocr_name: Optional[str] = Field(
19 | description="The path to the OCR results for the fixture", default=None
20 | )
21 |
22 | def get_full_path(self):
23 | return FIXTURE_PATH / self.name
24 |
25 | def get_bytes(self):
26 | return self.get_full_path().read_bytes()
27 |
28 | def get_ocr_results(self):
29 | if not self.ocr_name:
30 | return None
31 |
32 | ocr_path = FIXTURE_PATH / self.ocr_name
33 |
34 | return OCR_ADAPTER.validate_json(ocr_path.read_text())
35 |
36 | def get_document_node(self):
37 | from docprompt import load_document_node
38 |
39 | document = load_document_node(self.get_bytes())
40 |
41 | ocr_results = self.get_ocr_results()
42 |
43 | if ocr_results:
44 | for page_number, ocr_result in self.get_ocr_results().items():
45 | page = document.page_nodes[page_number - 1]
46 | page.metadata.ocr_results = ocr_result
47 |
48 | return document
49 |
--------------------------------------------------------------------------------