├── .flake8
├── .github
│   └── workflows
│       └── pre-commit.yml
├── .gitignore
├── .isort.cfg
├── .pre-commit-config.yaml
├── LICENSE
├── Makefile
├── README.md
├── docquery_example.ipynb
├── pyproject.toml
├── setup.py
├── src
│   └── docquery
│       ├── __init__.py
│       ├── cmd
│       │   ├── __init__.py
│       │   ├── __main__.py
│       │   └── scan.py
│       ├── config.py
│       ├── document.py
│       ├── ext
│       │   ├── __init__.py
│       │   ├── functools.py
│       │   ├── itertools.py
│       │   ├── model.py
│       │   ├── pipeline_document_classification.py
│       │   ├── pipeline_document_question_answering.py
│       │   └── qa_helpers.py
│       ├── find_leaf_nodes.js
│       ├── ocr_reader.py
│       ├── transformers_patch.py
│       ├── version.py
│       └── web.py
└── tests
    ├── test_classification_end_to_end.py
    ├── test_end_to_end.py
    ├── test_ocr_reader.py
    └── test_web_driver.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 119
3 | ignore = E402, E203, E501, W503
4 |
--------------------------------------------------------------------------------
/.github/workflows/pre-commit.yml:
--------------------------------------------------------------------------------
1 | # https://pre-commit.com
2 | # This GitHub Action assumes that the repo contains a valid .pre-commit-config.yaml file.
3 | # Using pre-commit.ci is even better than using GitHub Actions for pre-commit.
4 | name: pre-commit
5 | on:
6 | pull_request:
7 | branches: [main]
8 | push:
9 | branches: [main]
10 | workflow_dispatch:
11 | jobs:
12 | pre-commit:
13 | runs-on: ubuntu-latest
14 | steps:
15 | - uses: actions/checkout@v3
16 | - uses: actions/setup-python@v4
17 | with:
18 | python-version: 3.x
19 | - run: pip install pre-commit
20 | - run: pre-commit --version
21 | - run: pre-commit install
22 | - run: pre-commit run --all-files
23 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | venv
2 | *.swp
3 | *.swo
4 | *.pyc
5 | .DS_Store
6 | __pycache__
7 | dist
8 | src/docquery.egg-info
9 | docs
10 | .vscode/settings.json
11 | build
12 |
--------------------------------------------------------------------------------
/.isort.cfg:
--------------------------------------------------------------------------------
1 | [settings]
2 | line_length=119
3 | multi_line_output=3
4 | use_parentheses=true
5 | lines_after_imports=2
6 | include_trailing_comma=True
7 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: "https://github.com/pre-commit/pre-commit-hooks"
3 | rev: v4.3.0
4 | hooks:
5 | - id: check-yaml
6 | - id: end-of-file-fixer
7 | - id: trailing-whitespace
8 | - repo: "https://github.com/psf/black"
9 | rev: 22.6.0
10 | hooks:
11 | - id: black
12 | files: ./
13 | - repo: "https://github.com/PyCQA/isort"
14 | rev: 5.10.1
15 | hooks:
16 | - id: isort
17 | args:
18 | - --settings-path
19 | - .isort.cfg
20 | files: ./
21 | - repo: https://github.com/codespell-project/codespell
22 | rev: v2.2.1
23 | hooks:
24 | - id: codespell
25 |
26 | - repo: https://github.com/pre-commit/mirrors-prettier
27 | rev: v2.7.1
28 | hooks:
29 | - id: prettier
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 impira
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | all: build
2 |
3 | VERSION=$(shell python -c 'from src.docquery.version import VERSION; print(VERSION)')
4 |
5 | .PHONY: build publish clean
6 | build:
7 | python3 -m build
8 |
9 | publish: build
10 | python3 -m twine upload dist/docquery-${VERSION}*
11 |
12 | clean:
13 | rm -rf dist/*
14 |
15 |
16 | VENV_INITIALIZED := venv/.initialized
17 |
18 | ${VENV_INITIALIZED}:
19 | rm -rf venv && python3 -m venv venv
20 | @touch ${VENV_INITIALIZED}
21 |
22 | VENV_PYTHON_PACKAGES := venv/.python_packages
23 |
24 | ${VENV_PYTHON_PACKAGES}: ${VENV_INITIALIZED} setup.py
25 | bash -c 'source venv/bin/activate && python -m pip install --upgrade pip setuptools'
26 | bash -c 'source venv/bin/activate && python -m pip install -e .[dev]'
27 | @touch $@
28 |
29 | VENV_PRE_COMMIT := venv/.pre_commit
30 |
31 | ${VENV_PRE_COMMIT}: ${VENV_PYTHON_PACKAGES}
32 | bash -c 'source venv/bin/activate && pre-commit install'
33 | @touch $@
34 |
35 | .PHONY: develop fixup test
36 | develop: ${VENV_PRE_COMMIT}
37 | @echo 'Run "source venv/bin/activate" to enter development mode'
38 |
39 | fixup:
40 | pre-commit run --all-files
41 |
42 | test:
43 | python -m pytest -s -v ./tests/
44 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | NOTE: DocQuery is no longer actively maintained. We still welcome contributions and discussions from the community!
4 |
5 | # DocQuery: Document Query Engine Powered by Large Language Models
6 |
7 | [Demo on Hugging Face Spaces](https://huggingface.co/spaces/impira/docquery)
8 | [Example notebook (Colab)](https://github.com/impira/docquery/blob/main/docquery_example.ipynb)
9 | [PyPI](https://pypi.org/project/docquery/)
10 | [Discord](https://discord.gg/HucNfTtx7V)
11 | [Downloads](https://pepy.tech/project/docquery)
12 |
13 |
14 |
15 | DocQuery is a library and command-line tool that makes it easy to analyze semi-structured and unstructured documents (PDFs, scanned
16 | images, etc.) using large language models (LLMs). You simply point DocQuery at one or more documents and specify a
17 | question you want to ask. DocQuery is created by the team at [Impira](https://impira.com?utm_source=github&utm_medium=referral&utm_campaign=docquery).
18 |
19 | ## Quickstart (CLI)
20 |
21 | To install `docquery`, you can simply run `pip install docquery`. This will install the command line tool as well as the library.
22 | If you want to run OCR on images, then you must also install the [tesseract](https://github.com/tesseract-ocr/tesseract) library:
23 |
24 | - Mac OS X (using [Homebrew](https://brew.sh/)):
25 |
26 | ```sh
27 | brew install tesseract
28 | ```
29 |
30 | - Ubuntu:
31 |
32 | ```sh
33 | apt install tesseract-ocr
34 | ```
35 |
36 | `docquery scan` allows you to ask one or more questions of a single document or a directory of files. For example, you can
37 | find the invoice number with:
38 |
39 | ```bash
40 | docquery scan "What is the invoice number?" https://templates.invoicehome.com/invoice-template-us-neat-750px.png
41 | ```
42 |
43 | If you have a folder of documents on your machine, you can run something like
44 |
45 | ```bash
46 | docquery scan "What is the effective date?" /path/to/contracts/folder
47 | ```
48 |
49 | to determine the effective date of every document in the folder.
50 |
51 | ## Quickstart (Library)
52 |
53 | DocQuery can also be used as a library. It contains two basic abstractions: (1) a `DocumentQuestionAnswering` pipeline
54 | that makes it simple to ask questions of documents and (2) a `Document` abstraction that can parse various types of documents
55 | to feed into the pipeline.
56 |
57 | ```python
58 | >>> from docquery import document, pipeline
59 | >>> p = pipeline('document-question-answering')
60 | >>> doc = document.load_document("/path/to/document.pdf")
61 | >>> for q in ["What is the invoice number?", "What is the invoice total?"]:
62 | ... print(q, p(question=q, **doc.context))
63 | ```
64 |
65 | ## Use cases
66 |
67 | DocQuery excels at a number of use cases involving structured, semi-structured, or unstructured documents. You can ask questions about
68 | invoices, contracts, forms, emails, letters, receipts, and many more. You can also classify documents. We will continue evolving the model,
69 | offering more modeling options, and expanding the set of supported documents. We welcome feedback, requests, and of course contributions to
70 | help achieve this vision.
71 |
72 | ## How it works
73 |
74 | Under the hood, docquery uses a pre-trained zero-shot language model, based on [LayoutLM](https://arxiv.org/abs/1912.13318), that has been
75 | fine-tuned for a question-answering task. The model is trained using a combination of [SQuAD2.0](https://rajpurkar.github.io/SQuAD-explorer/)
76 | and [DocVQA](https://rrc.cvc.uab.es/?ch=17), which makes it particularly well suited for complex visual question answering tasks on
77 | a wide variety of documents. The underlying model is also published on HuggingFace as [impira/layoutlm-document-qa](https://huggingface.co/impira/layoutlm-document-qa)
78 | which you can access directly.
79 |
80 | ## Limitations
81 |
82 | DocQuery is intended to have a small install footprint and be simple to work with. As a result, it has some limitations:
83 |
84 | - Models must be pre-trained. Although DocQuery uses a zero-shot model that can adapt based on the question you provide, it does not learn from your data.
85 | - Images and PDFs only. DocQuery currently supports images and PDFs, with or without embedded text. It does not support Word documents, emails, spreadsheets, etc.
86 | - Scalar text outputs. DocQuery only produces text outputs (answers). It does not support richer scalar types (i.e. it treats numbers and dates as strings) or tables.
87 |
88 | ## Advanced features
89 |
90 | ### Using Donut 🍩
91 |
92 | If you'd like to test `docquery` with [Donut](https://arxiv.org/abs/2111.15664), you must install the required extras:
93 |
94 | ```bash
95 | pip install docquery[donut]
96 | ```
97 |
98 | You can then run
99 |
100 | ```bash
101 | docquery scan "What is the effective date?" /path/to/contracts/folder --checkpoint 'naver-clova-ix/donut-base-finetuned-docvqa'
102 | ```
103 |
104 | ### Classifying documents
105 |
106 | To classify documents, you simply add the `--classify` argument to `scan`. You can specify any [image classification](https://huggingface.co/models?pipeline_tag=image-classification&sort=downloads)
107 | model on Hugging Face's hub. By default, the classification pipeline uses [Donut](https://huggingface.co/spaces/nielsr/donut-rvlcdip) (which requires
108 | following the `[donut]` installation steps above):
109 |
110 | ```bash
111 |
112 | # Classify documents
113 | docquery scan --classify /path/to/contracts/folder --checkpoint 'naver-clova-ix/donut-base-finetuned-docvqa'
114 |
115 | # Classify documents and ask a question too
116 | docquery scan --classify "What is the effective date?" /path/to/contracts/folder --checkpoint 'naver-clova-ix/donut-base-finetuned-docvqa'
117 | ```
118 |
119 | ### Scraping webpages
120 |
121 | DocQuery can read files over HTTP/HTTPS out of the box. However, if you want to read HTML documents, you can do that too by installing the
122 | `[web]` extra (`pip install docquery[web]`). The extra uses the [webdriver-manager](https://pypi.org/project/webdriver-manager/) library, which can install a Chrome
123 | driver on your system automatically, but you'll need to make sure Chrome is installed globally.
124 |
125 | ```bash
126 | # Find the top post on hacker news
127 | docquery scan "What is the #1 post's title?" https://news.ycombinator.com
128 | ```
129 |
130 | ## Where to go from here
131 |
132 | DocQuery is a Swiss Army knife for working with documents and experiencing the power of modern machine learning. You can use it
133 | just about anywhere, including behind a firewall on sensitive data, and test it with a wide variety of documents. Our hope is that
134 | DocQuery enables many creative use cases for document understanding by making it simple to ask questions of your documents.
135 |
136 | When you run DocQuery for the first time, it will download some files (e.g. the models and some library code from HuggingFace). However,
137 | nothing leaves your computer -- the OCR is done locally, models run locally, etc. This comes with the benefit of security and privacy;
138 | however, it comes at the cost of runtime performance and some accuracy.
139 |
140 | If you find yourself wondering how to achieve higher accuracy, work with more file types, teach the model with your own data, have
141 | a human-in-the-loop workflow, or query the data you're extracting, then do not fear -- you are running into the challenges that
142 | every organization faces when putting document AI into production. The [Impira](https://www.impira.com/) platform is designed to
143 | solve these problems in an easy and intuitive way. Impira comes with a QA model that is additionally trained on proprietary datasets
144 | and can achieve 95%+ accuracy out-of-the-box for most use cases. It also has an intuitive UI that enables subject matter experts to label
145 | and improve the models, as well as an API that makes integration a breeze. Please [sign up for the product](https://www.impira.com/signup) or
146 | [reach out to us](mailto:info@impira.com) for more details.
147 |
148 | ## Status
149 |
150 | DocQuery is a new project. Although the underlying models are running in production, we've only recently released the code as open source
151 | and are actively working with the OSS community to upstream some of the changes we've made (e.g. [the model](https://github.com/huggingface/transformers/pull/18407)
152 | and [pipeline](https://github.com/huggingface/transformers/pull/18414)). DocQuery is rapidly changing, and we are likely to make breaking
153 | API changes. If you would like to run it in production, then we suggest pinning a version or commit hash. Either way, please get in touch
154 | with us at [oss@impira.com](mailto:oss@impira.com) with any questions or feedback.
155 |
156 | ## Acknowledgements
157 |
158 | DocQuery would not be possible without the contributions of many open source projects:
159 |
160 | - [pdfplumber](https://github.com/jsvine/pdfplumber) / [pdfminer.six](https://github.com/pdfminer/pdfminer.six)
161 | - [Pillow](https://pillow.readthedocs.io/en/stable/)
162 | - [pytorch](https://pytorch.org/)
163 | - [tesseract](https://github.com/tesseract-ocr/tesseract) / [pytesseract](https://pypi.org/project/pytesseract/)
164 | - [transformers](https://github.com/impira/transformers)
165 |
166 | and many others!
167 |
168 | ## License
169 |
170 | This project is licensed under the [MIT license](LICENSE).
171 |
172 | It contains code that is copied and adapted from transformers (https://github.com/huggingface/transformers),
173 | which is [Apache 2.0 licensed](http://www.apache.org/licenses/LICENSE-2.0). Files containing this code have
174 | been marked as such in their comments.
175 |
--------------------------------------------------------------------------------
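Note on the README above: a minimal library-level sketch (not part of the repo) of the same steps `docquery scan --classify` performs, using the pipeline and document abstractions described in the quickstart. The file path and questions are placeholders.

```python
# Sketch: library equivalent of `docquery scan --classify`, per the README above.
from docquery import pipeline
from docquery.document import load_document

doc = load_document("/path/to/invoice.pdf")  # placeholder path

# Optional document classification, as in the "Classifying documents" section.
classifier = pipeline("document-classification")
print("Document type:", classifier(**doc.context)[0]["label"])

# Document question answering, as in the library quickstart.
qa = pipeline("document-question-answering")
for question in ["What is the invoice number?", "What is the invoice total?"]:
    response = qa(question=question, **doc.context)
    if isinstance(response, list):  # the pipeline may return a ranked list of answers
        response = response[0] if response else None
    print(question, "->", response["answer"] if response else "NULL")
```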
/docquery_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "docquery_example.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "authorship_tag": "ABX9TyNxYAeRjZTeKNeMu6iJvYRw",
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | },
16 | "language_info": {
17 | "name": "python"
18 | }
19 | },
20 | "cells": [
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {
24 | "id": "view-in-github",
25 | "colab_type": "text"
26 | },
27 | "source": [
28 | "
"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {
35 | "id": "yS9UNjHnAAS9"
36 | },
37 | "outputs": [],
38 | "source": [
39 | "!git clone https://github.com/impira/docquery.git\n",
40 | "!sudo apt install tesseract-ocr\n",
41 | "!sudo apt-get install poppler-utils\n",
42 | "!cd docquery && pip install .[all]"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "source": [
48 | "!docquery scan \"who authored this paper?\" https://arxiv.org/pdf/2101.07597.pdf"
49 | ],
50 | "metadata": {
51 | "id": "bKRRY5u2DV52"
52 | },
53 | "execution_count": null,
54 | "outputs": []
55 | }
56 | ]
57 | }
58 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 119
3 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import setuptools
4 |
5 | dir_name = os.path.abspath(os.path.dirname(__file__))
6 |
7 | version_contents = {}
8 | with open(os.path.join(dir_name, "src", "docquery", "version.py"), encoding="utf-8") as f:
9 | exec(f.read(), version_contents)
10 |
11 | with open(os.path.join(dir_name, "README.md"), "r", encoding="utf-8") as f:
12 | long_description = f.read()
13 |
14 | install_requires = [
15 | "torch >= 1.0",
16 | "pdf2image",
17 | "pdfplumber",
18 | "Pillow",
19 | "pydantic",
20 | "pytesseract", # TODO: Test what happens if the host machine does not have tesseract installed
21 | "requests",
22 | "easyocr",
23 | "transformers >= 4.23",
24 | ]
25 | extras_require = {
26 | "dev": [
27 | "black",
28 | "build",
29 | "flake8",
30 | "flake8-isort",
31 | "isort==5.10.1",
32 | "pre-commit",
33 | "pytest",
34 | "twine",
35 | ],
36 | "donut": [
37 | "sentencepiece",
38 | "protobuf<=3.20.1",
39 | ],
40 | "web": [
41 | "selenium",
42 | "webdriver-manager",
43 | ],
44 | "cli": [],
45 | }
46 | extras_require["all"] = sorted({package for packages in extras_require.values() for package in packages})
47 |
48 | setuptools.setup(
49 | name="docquery",
50 | version=version_contents["VERSION"],
51 | author="Impira Engineering",
52 | author_email="engineering@impira.com",
53 | description="DocQuery: An easy way to extract information from documents",
54 | long_description=long_description,
55 | long_description_content_type="text/markdown",
56 | url="https://github.com/impira/docquery",
57 | project_urls={
58 | "Bug Tracker": "https://github.com/impira/docquery/issues",
59 | },
60 | classifiers=[
61 | "Programming Language :: Python :: 3",
62 | "License :: OSI Approved :: MIT License",
63 | "Operating System :: OS Independent",
64 | ],
65 | package_dir={"": "src"},
66 | package_data={"": ["find_leaf_nodes.js"]},
67 | packages=setuptools.find_packages(where="src"),
68 | python_requires=">=3.7.0",
69 | entry_points={
70 | "console_scripts": ["docquery = docquery.cmd.__main__:main"],
71 | },
72 | install_requires=install_requires,
73 | extras_require=extras_require,
74 | )
75 |
--------------------------------------------------------------------------------
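Note on setup.py above: the `extras_require["all"]` expression flattens every extra into one sorted list, so `docquery[all]` pulls in the dev, donut, and web dependencies together. A small self-contained sketch of what it computes, using a trimmed-down dict for brevity:

```python
# Worked sketch of setup.py's extras_require["all"] comprehension (trimmed dict).
extras_require = {
    "dev": ["black", "pytest"],
    "donut": ["sentencepiece", "protobuf<=3.20.1"],
    "web": ["selenium", "webdriver-manager"],
    "cli": [],
}
# Union of every extra's packages, de-duplicated by the set, then sorted.
extras_require["all"] = sorted({package for packages in extras_require.values() for package in packages})
print(extras_require["all"])
# ['black', 'protobuf<=3.20.1', 'pytest', 'selenium', 'sentencepiece', 'webdriver-manager']
```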
/src/docquery/__init__.py:
--------------------------------------------------------------------------------
1 | # pipeline wraps transformers pipeline with extensions in DocQuery
2 | # we're simply re-exporting it here.
3 | import sys
4 |
5 | from transformers.utils import _LazyModule
6 |
7 | from .version import VERSION
8 |
9 |
10 | _import_structure = {
11 | "transformers_patch": ["pipeline"],
12 | }
13 |
14 | sys.modules[__name__] = _LazyModule(
15 | __name__,
16 | globals()["__file__"],
17 | _import_structure,
18 | module_spec=__spec__,
19 | extra_objects={"__version__": VERSION},
20 | )
21 |
--------------------------------------------------------------------------------
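Note on src/docquery/__init__.py above: the `_LazyModule` indirection means `pipeline` is only imported from `transformers_patch` on first attribute access, while `__version__` is exposed directly. A minimal usage sketch:

```python
# Sketch of the lazy exports set up in src/docquery/__init__.py.
import docquery

print(docquery.__version__)  # provided via extra_objects={"__version__": VERSION}

# Accessing `pipeline` triggers the lazy import of docquery.transformers_patch.
qa = docquery.pipeline("document-question-answering")
```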
/src/docquery/cmd/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/impira/docquery/3744f08a22609c0df5a72f463911b47689eaa819/src/docquery/cmd/__init__.py
--------------------------------------------------------------------------------
/src/docquery/cmd/__main__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
5 |
6 | import argparse
7 | import logging
8 | import sys
9 | import textwrap
10 |
11 | import transformers
12 |
13 | from ..transformers_patch import PIPELINE_DEFAULTS
14 |
15 |
16 | _module_not_found_error = None
17 | try:
18 | from . import scan
19 | except ModuleNotFoundError as e:
20 | _module_not_found_error = e
21 |
22 | if _module_not_found_error is not None:
23 | raise ModuleNotFoundError(
24 | textwrap.dedent(
25 | f"""\
26 | At least one dependency not found: {str(_module_not_found_error)!r}
27 | It is possible that docquery was installed without the CLI dependencies. Run:
28 |
29 | pip install 'docquery[cli]'
30 |
31 | to install docquery with the CLI dependencies."""
32 | )
33 | )
34 |
35 |
36 | def main(args=None):
37 | """The main routine."""
38 | if args is None:
39 | args = sys.argv[1:]
40 |
41 | parent_parser = argparse.ArgumentParser(add_help=False)
42 | parent_parser.add_argument("--verbose", "-v", default=False, action="store_true")
43 | parent_parser.add_argument(
44 | "--checkpoint",
45 | default=None,
46 | help=f"A custom model checkpoint to use (other than {PIPELINE_DEFAULTS['document-question-answering']})",
47 | )
48 |
49 | parser = argparse.ArgumentParser(description="docquery is a cli tool to work with documents.")
50 | subparsers = parser.add_subparsers(help="sub-command help", dest="subcommand", required=True)
51 |
52 | for module in [scan]:
53 | module.build_parser(subparsers, parent_parser)
54 |
55 | args = parser.parse_args(args=args)
56 | level = logging.DEBUG if args.verbose else logging.INFO
57 | if not args.verbose:
58 | transformers.logging.set_verbosity_error()
59 | logging.basicConfig(format="%(asctime)s %(levelname)s: %(message)s", level=level)
60 |
61 | return args.func(args)
62 |
63 |
64 | if __name__ == "__main__":
65 | sys.exit(main())
66 |
--------------------------------------------------------------------------------
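Note on cmd/__main__.py above: because `main()` accepts an explicit argument list, the CLI can also be driven programmatically (handy in tests). A small sketch with a placeholder path:

```python
# Sketch: invoking the CLI entry point with an explicit argv list instead of sys.argv.
from docquery.cmd.__main__ import main

# Equivalent to: docquery scan "What is the invoice number?" /path/to/invoice.pdf
main(["scan", "What is the invoice number?", "/path/to/invoice.pdf"])
```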
/src/docquery/cmd/scan.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 |
4 | from .. import pipeline
5 | from ..config import get_logger
6 | from ..document import UnsupportedDocument, load_document
7 | from ..ocr_reader import OCR_MAPPING
8 | from ..transformers_patch import PIPELINE_DEFAULTS
9 |
10 |
11 | log = get_logger("scan")
12 |
13 |
14 | def build_parser(subparsers, parent_parser):
15 | parser = subparsers.add_parser(
16 | "scan",
17 | help="Scan a directory and ask one or more questions of the documents in it.",
18 | parents=[parent_parser],
19 | )
20 |
21 | parser.add_argument(
22 | "questions", default=[], nargs="*", type=str, help="One or more questions to ask of the documents"
23 | )
24 |
25 | parser.add_argument("path", type=str, help="The file or directory to scan")
26 |
27 | parser.add_argument(
28 | "--ocr", choices=list(OCR_MAPPING.keys()), default=None, help="The OCR engine you would like to use"
29 | )
30 | parser.add_argument(
31 | "--ignore-embedded-text",
32 | dest="use_embedded_text",
33 | action="store_false",
34 | help="Do not try and extract embedded text from document types that might provide it (e.g. PDFs)",
35 | )
36 | parser.add_argument(
37 | "--classify",
38 | default=False,
39 | action="store_true",
40 | help="Classify documents while scanning them",
41 | )
42 | parser.add_argument(
43 | "--classify-checkpoint",
44 | default=None,
45 | help=f"A custom model checkpoint to use (other than {PIPELINE_DEFAULTS['document-classification']})",
46 | )
47 |
48 | parser.set_defaults(func=main)
49 | return parser
50 |
51 |
52 | def main(args):
53 | paths = []
54 | if pathlib.Path(args.path).is_dir():
55 | for root, dirs, files in os.walk(args.path):
56 | for fname in files:
57 | if (pathlib.Path(root) / fname).is_dir():
58 | continue
59 | paths.append(pathlib.Path(root) / fname)
60 | else:
61 | paths.append(args.path)
62 |
63 | docs = []
64 | for p in paths:
65 | try:
66 | log.info(f"Loading {p}")
67 | docs.append((p, load_document(str(p), ocr_reader=args.ocr, use_embedded_text=args.use_embedded_text)))
68 | except UnsupportedDocument as e:
69 | log.warning(f"Cannot load {p}: {e}. Skipping...")
70 |
71 | log.info(f"Done loading {len(docs)} file(s).")
72 | if not docs:
73 | return
74 |
75 | log.info("Loading pipelines.")
76 |
77 | nlp = pipeline("document-question-answering", model=args.checkpoint)
78 | if args.classify:
79 | classify = pipeline("document-classification", model=args.classify_checkpoint)
80 |
81 | log.info("Ready to start evaluating!")
82 |
83 | max_fname_len = max(len(str(p)) for (p, _) in docs)
84 | max_question_len = max(len(q) for q in args.questions) if len(args.questions) > 0 else 0
85 | for i, (p, d) in enumerate(docs):
86 | if i > 0 and len(args.questions) > 1:
87 | print("")
88 |
89 | if args.classify:
90 | cls = classify(**d.context)[0]
91 | print(f"{str(p):<{max_fname_len}} Document Type: {cls['label']}")
92 |
93 | for q in args.questions:
94 | try:
95 | response = nlp(question=q, **d.context)
96 | if isinstance(response, list):
97 | response = response[0] if len(response) > 0 else None
98 | except Exception:
99 | log.error(f"Failed while processing {str(p)} on question: '{q}'")
100 | raise
101 |
102 | answer = response["answer"] if response is not None else "NULL"
103 | print(f"{str(p):<{max_fname_len}} {q:<{max_question_len}}: {answer}")
104 |
--------------------------------------------------------------------------------
/src/docquery/config.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from pydantic import validate_arguments
4 |
5 |
6 | @validate_arguments
7 | def get_logger(prefix: str):
8 | log = logging.getLogger(prefix)
9 | log.propagate = False
10 | ch = logging.StreamHandler()
11 | formatter = logging.Formatter("%(asctime)s %(levelname)s: %(message)s")
12 | ch.setFormatter(formatter)
13 | log.addHandler(ch)
14 | return log
15 |
--------------------------------------------------------------------------------
/src/docquery/document.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import mimetypes
3 | import os
4 | from io import BytesIO
5 | from typing import Any, Dict, List, Optional, Tuple, Union
6 |
7 | import requests
8 | from PIL import Image, UnidentifiedImageError
9 | from pydantic import validate_arguments
10 |
11 | from .ext.functools import cached_property
12 | from .ocr_reader import NoOCRReaderFound, OCRReader, get_ocr_reader
13 | from .web import get_webdriver
14 |
15 |
16 | class UnsupportedDocument(Exception):
17 | def __init__(self, e):
18 | self.e = e
19 |
20 | def __str__(self):
21 | return f"unsupported file type: {self.e}"
22 |
23 |
24 | PDF_2_IMAGE = False
25 | PDF_PLUMBER = False
26 |
27 | try:
28 | import pdf2image
29 |
30 | PDF_2_IMAGE = True
31 | except ImportError:
32 | pass
33 |
34 | try:
35 | import pdfplumber
36 |
37 | PDF_PLUMBER = True
38 | except ImportError:
39 | pass
40 |
41 |
42 | def use_pdf2_image():
43 | if not PDF_2_IMAGE:
44 | raise UnsupportedDocument("Unable to import pdf2image (OCR will be unavailable for pdfs)")
45 |
46 |
47 | def use_pdf_plumber():
48 | if not PDF_PLUMBER:
49 | raise UnsupportedDocument("Unable to import pdfplumber (pdfs will be unavailable)")
50 |
51 |
52 | class Document(metaclass=abc.ABCMeta):
53 | @property
54 | @abc.abstractmethod
55 | def context(self) -> Tuple[(str, List[int])]:
56 | raise NotImplementedError
57 |
58 | @property
59 | @abc.abstractmethod
60 | def preview(self) -> "Image":
61 | raise NotImplementedError
62 |
63 | @staticmethod
64 | def _generate_document_output(
65 | images: List["Image.Image"],
66 | words_by_page: List[List[str]],
67 | boxes_by_page: List[List[List[int]]],
68 | dimensions_by_page: List[Tuple[int, int]],
69 | ) -> Dict[str, List[Tuple["Image.Image", List[Any]]]]:
70 |
71 | # pages_dimensions (width, height)
72 | assert len(images) == len(dimensions_by_page)
73 | assert len(images) == len(words_by_page)
74 | assert len(images) == len(boxes_by_page)
75 | processed_pages = []
76 | for image, words, boxes, dimensions in zip(images, words_by_page, boxes_by_page, dimensions_by_page):
77 | width, height = dimensions
78 |
79 | """
80 | box is [x1,y1,x2,y2] where x1,y1 are the top left corner of box and x2,y2 is the bottom right corner
81 | This function scales the distance between boxes to be on a fixed scale
82 | It is derived from the preprocessing code for LayoutLM
83 | """
84 | normalized_boxes = [
85 | [
86 | max(min(c, 1000), 0)
87 | for c in [
88 | int(1000 * (box[0] / width)),
89 | int(1000 * (box[1] / height)),
90 | int(1000 * (box[2] / width)),
91 | int(1000 * (box[3] / height)),
92 | ]
93 | ]
94 | for box in boxes
95 | ]
96 | assert len(words) == len(normalized_boxes), "Not as many words as there are bounding boxes"
97 | word_boxes = [x for x in zip(words, normalized_boxes)]
98 | processed_pages.append((image, word_boxes))
99 |
100 | return {"image": processed_pages}
101 |
102 |
103 | class PDFDocument(Document):
104 | def __init__(self, b, ocr_reader, use_embedded_text, **kwargs):
105 | self.b = b
106 | self.ocr_reader = ocr_reader
107 | self.use_embedded_text = use_embedded_text
108 |
109 | super().__init__(**kwargs)
110 |
111 | @cached_property
112 | def context(self) -> Dict[str, List[Tuple["Image.Image", List[Any]]]]:
113 | pdf = self._pdf
114 | if pdf is None:
115 | return {}
116 |
117 | images = self._images
118 |
119 | if len(images) != len(pdf.pages):
120 | raise ValueError(
121 | f"Mismatch: pdfplumber() thinks there are {len(pdf.pages)} pages and"
122 | f" pdf2image thinks there are {len(images)}"
123 | )
124 |
125 | words_by_page = []
126 | boxes_by_page = []
127 | dimensions_by_page = []
128 | for i, page in enumerate(pdf.pages):
129 | extracted_words = page.extract_words() if self.use_embedded_text else []
130 |
131 | if len(extracted_words) == 0:
132 | words, boxes = self.ocr_reader.apply_ocr(images[i])
133 | words_by_page.append(words)
134 | boxes_by_page.append(boxes)
135 | dimensions_by_page.append((images[i].width, images[i].height))
136 |
137 | else:
138 | words = [w["text"] for w in extracted_words]
139 | boxes = [[w["x0"], w["top"], w["x1"], w["bottom"]] for w in extracted_words]
140 | words_by_page.append(words)
141 | boxes_by_page.append(boxes)
142 | dimensions_by_page.append((page.width, page.height))
143 |
144 | return self._generate_document_output(images, words_by_page, boxes_by_page, dimensions_by_page)
145 |
146 | @cached_property
147 | def preview(self) -> "Image":
148 | return self._images
149 |
150 | @cached_property
151 | def _images(self):
152 | # First, try to extract text directly
153 | # TODO: This library requires poppler, which is not present everywhere.
154 | # We should look into alternatives. We could also gracefully handle this
155 | # and simply fall back to _only_ extracted text
156 | return [x.convert("RGB") for x in pdf2image.convert_from_bytes(self.b)]
157 |
158 | @cached_property
159 | def _pdf(self):
160 | use_pdf_plumber()
161 | pdf = pdfplumber.open(BytesIO(self.b))
162 | if len(pdf.pages) == 0:
163 | return None
164 | return pdf
165 |
166 |
167 | class ImageDocument(Document):
168 | def __init__(self, b, ocr_reader, **kwargs):
169 | self.b = b
170 | self.ocr_reader = ocr_reader
171 |
172 | super().__init__(**kwargs)
173 |
174 | @cached_property
175 | def preview(self) -> "Image":
176 | return [self.b.convert("RGB")]
177 |
178 | @cached_property
179 | def context(self) -> Dict[str, List[Tuple["Image.Image", List[Any]]]]:
180 | words, boxes = self.ocr_reader.apply_ocr(self.b)
181 | return self._generate_document_output([self.b], [words], [boxes], [(self.b.width, self.b.height)])
182 |
183 |
184 | class WebDocument(Document):
185 | def __init__(self, url, **kwargs):
186 | if not (url.startswith("http://") or url.startswith("https://")):
187 | url = "file://" + url
188 | self.url = url
189 |
190 | # TODO: This is a singleton, which is not thread-safe. We may want to relax this
191 | # behavior to allow the user to pass in their own driver (which could either be a
192 | # singleton or a custom instance).
193 | self.driver = get_webdriver()
194 |
195 | super().__init__(**kwargs)
196 |
197 | def ensure_loaded(self):
198 | self.driver.get(self.url)
199 |
200 | @cached_property
201 | def page_screenshots(self):
202 | self.ensure_loaded()
203 | return self.driver.scroll_and_screenshot()
204 |
205 | @cached_property
206 | def preview(self) -> "Image":
207 | return [img.convert("RGB") for img in self.page_screenshots[1]]
208 |
209 | @cached_property
210 | def context(self) -> Dict[str, List[Tuple["Image.Image", List[Any]]]]:
211 | self.ensure_loaded()
212 | word_boxes = self.driver.find_word_boxes()
213 |
214 | tops, _ = self.page_screenshots
215 |
216 | n_pages = len(tops)
217 | page = 0
218 | offset = 0
219 |
220 | words = [[] for _ in range(n_pages)]
221 | boxes = [[] for _ in range(n_pages)]
222 | for word_box in word_boxes["word_boxes"]:
223 | box = word_box["box"]
224 |
225 | if page < len(tops) - 1 and box["top"] >= tops[page + 1]:
226 | page += 1
227 | offset = tops[page]
228 |
229 | words[page].append(word_box["text"])
230 | boxes[page].append((box["left"], box["top"] - offset, box["right"], box["bottom"] - offset))
231 |
232 | return self._generate_document_output(
233 | self.preview, words, boxes, [(word_boxes["vw"], word_boxes["vh"])] * n_pages
234 | )
235 |
236 |
237 | @validate_arguments
238 | def load_document(fpath: str, ocr_reader: Optional[Union[str, OCRReader]] = None, use_embedded_text=True):
239 | base_path = os.path.basename(fpath).split("?")[0].strip()
240 | doc_type = mimetypes.guess_type(base_path)[0]
241 | if fpath.startswith("http://") or fpath.startswith("https://"):
242 | resp = requests.get(fpath, allow_redirects=True, stream=True)
243 | if not resp.ok:
244 | raise UnsupportedDocument(f"Failed to download: {resp.content}")
245 |
246 | if "Content-Type" in resp.headers:
247 | doc_type = resp.headers["Content-Type"].split(";")[0].strip()
248 |
249 | b = resp.raw
250 | else:
251 | b = open(fpath, "rb")
252 |
253 | if not ocr_reader or isinstance(ocr_reader, str):
254 | ocr_reader = get_ocr_reader(ocr_reader)
255 | elif not isinstance(ocr_reader, OCRReader):
256 | raise NoOCRReaderFound(f"{ocr_reader} is not a supported OCRReader class")
257 |
258 | if doc_type == "application/pdf":
259 | return PDFDocument(b.read(), ocr_reader=ocr_reader, use_embedded_text=use_embedded_text)
260 | elif doc_type == "text/html":
261 | return WebDocument(fpath)
262 | else:
263 | try:
264 | img = Image.open(b)
265 | except UnidentifiedImageError as e:
266 | raise UnsupportedDocument(e)
267 | return ImageDocument(img, ocr_reader=ocr_reader)
268 |
--------------------------------------------------------------------------------
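Note on document.py above: a minimal sketch of the structure `load_document` produces. `doc.context["image"]` holds one `(image, word_boxes)` pair per page, where each word box is a `(word, [x1, y1, x2, y2])` tuple with coordinates scaled to 0-1000 by `_generate_document_output`. The path is a placeholder.

```python
# Sketch: inspecting the context structure produced by load_document / Document.context.
from docquery.document import load_document

doc = load_document("/path/to/invoice.pdf")  # placeholder path
for page_image, word_boxes in doc.context["image"]:
    print(page_image.size, "-", len(word_boxes), "words on this page")
    for word, (x1, y1, x2, y2) in word_boxes[:5]:  # first few (word, normalized box) pairs
        print(f"  {word!r} at [{x1}, {y1}, {x2}, {y2}]")
```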
/src/docquery/ext/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/impira/docquery/3744f08a22609c0df5a72f463911b47689eaa819/src/docquery/ext/__init__.py
--------------------------------------------------------------------------------
/src/docquery/ext/functools.py:
--------------------------------------------------------------------------------
1 | try:
2 | from functools import cached_property as cached_property
3 | except ImportError:
4 | # for python 3.7 support fall back to just property
5 | cached_property = property
6 |
--------------------------------------------------------------------------------
/src/docquery/ext/itertools.py:
--------------------------------------------------------------------------------
1 | import itertools
2 |
3 |
4 | def unique_everseen(iterable, key=None):
5 | """
6 | List unique elements, preserving order. Remember all elements ever seen [1]_.
7 |
8 | Examples
9 | --------
10 | >>> list(unique_everseen("AAAABBBCCDAABBB"))
11 | ['A', 'B', 'C', 'D']
12 | >>> list(unique_everseen("ABBCcAD", str.lower))
13 | ['A', 'B', 'C', 'D']
14 |
15 | References
16 | ----------
17 | .. [1] https://docs.python.org/3/library/itertools.html
18 | """
19 | seen = set()
20 | seen_add = seen.add
21 | if key is None:
22 | for element in itertools.filterfalse(seen.__contains__, iterable):
23 | seen_add(element)
24 | yield element
25 | else:
26 | for element in iterable:
27 | k = key(element)
28 | if k not in seen:
29 | seen_add(k)
30 | yield element
31 |
--------------------------------------------------------------------------------
/src/docquery/ext/model.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional, Tuple, Union
3 |
4 | import torch
5 | from torch import nn
6 | from transformers import LayoutLMModel, LayoutLMPreTrainedModel
7 | from transformers.modeling_outputs import QuestionAnsweringModelOutput as QuestionAnsweringModelOutputBase
8 |
9 |
10 | @dataclass
11 | class QuestionAnsweringModelOutput(QuestionAnsweringModelOutputBase):
12 | token_logits: Optional[torch.FloatTensor] = None
13 |
14 |
15 | # There are three additional config parameters that this model supports, which are not part of the
16 | # LayoutLMForQuestionAnswering in mainline transformers. These config parameters control the additional
17 | # token classifier head.
18 | #
19 | # token_classification (`bool`, *optional*, defaults to `False`):
20 | # Whether to include an additional token classification head in question answering
21 | # token_classifier_reduction (`str`, *optional*, defaults to "mean")
22 | # Specifies the reduction to apply to the output of the cross entropy loss for the token classifier head during
23 | # training. Options are: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the weighted
24 | # mean of the output is taken, 'sum': the output will be summed.
25 | # token_classifier_constant (`float`, *optional*, defaults to 1.0)
26 | # Coefficient for the token classifier head's contribution to the total loss. A larger value means that the model
27 | # will prioritize learning the token classifier head during training.
28 | class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
29 | def __init__(self, config, has_visual_segment_embedding=True):
30 | super().__init__(config)
31 | self.num_labels = config.num_labels
32 |
33 | self.layoutlm = LayoutLMModel(config)
34 | self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
35 |
36 | # NOTE: We have to use getattr() here because we do not patch the LayoutLMConfig
37 | # class to have these extra attributes, so existing LayoutLM models may not have
38 | # them in their configuration.
39 | self.token_classifier_head = None
40 | if getattr(self.config, "token_classification", False):
41 | self.token_classifier_head = nn.Linear(config.hidden_size, 2)
42 |
43 | # Initialize weights and apply final processing
44 | self.post_init()
45 |
46 | def get_input_embeddings(self):
47 | return self.layoutlm.embeddings.word_embeddings
48 |
49 | def forward(
50 | self,
51 | input_ids: Optional[torch.LongTensor] = None,
52 | bbox: Optional[torch.LongTensor] = None,
53 | attention_mask: Optional[torch.FloatTensor] = None,
54 | token_type_ids: Optional[torch.LongTensor] = None,
55 | position_ids: Optional[torch.LongTensor] = None,
56 | head_mask: Optional[torch.FloatTensor] = None,
57 | inputs_embeds: Optional[torch.FloatTensor] = None,
58 | start_positions: Optional[torch.LongTensor] = None,
59 | end_positions: Optional[torch.LongTensor] = None,
60 | token_labels: Optional[torch.LongTensor] = None,
61 | output_attentions: Optional[bool] = None,
62 | output_hidden_states: Optional[bool] = None,
63 | return_dict: Optional[bool] = None,
64 | ) -> Union[Tuple, QuestionAnsweringModelOutput]:
65 | r"""
66 | start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
67 | Labels for position (index) of the start of the labelled span for computing the token classification loss.
68 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
69 | are not taken into account for computing the loss.
70 | end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
71 | Labels for position (index) of the end of the labelled span for computing the token classification loss.
72 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
73 | are not taken into account for computing the loss.
74 |
75 | Returns:
76 |
77 | Example:
78 |
79 | In the example below, we give the LayoutLM model an image (of text) and ask it a question. It will give us
80 | a prediction of what it thinks the answer is (the span of the answer within the texts parsed from the image).
81 |
82 | ```python
83 | >>> from transformers import AutoTokenizer, LayoutLMForQuestionAnswering
84 | >>> from datasets import load_dataset
85 | >>> import torch
86 |
87 | >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased", add_prefix_space=True)
88 | >>> model = LayoutLMForQuestionAnswering.from_pretrained("microsoft/layoutlm-base-uncased")
89 |
90 | >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
91 | >>> example = dataset[0]
92 | >>> question = "what's his name?"
93 | >>> words = example["tokens"]
94 | >>> boxes = example["bboxes"]
95 |
96 | >>> encoding = tokenizer(
97 | ... question.split(), words, is_split_into_words=True, return_token_type_ids=True, return_tensors="pt"
98 | ... )
99 | >>> bbox = []
100 | >>> for i, s, w in zip(encoding.input_ids[0], encoding.sequence_ids(0), encoding.word_ids(0)):
101 | ... if s == 1:
102 | ... bbox.append(boxes[w])
103 | ... elif i == tokenizer.sep_token_id:
104 | ... bbox.append([1000] * 4)
105 | ... else:
106 | ... bbox.append([0] * 4)
107 | >>> encoding["bbox"] = torch.tensor([bbox])
108 |
109 | >>> outputs = model(**encoding)
110 | >>> loss = outputs.loss
111 | >>> start_scores = outputs.start_logits
112 | >>> end_scores = outputs.end_logits
113 | ```
114 | """
115 |
116 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
117 |
118 | outputs = self.layoutlm(
119 | input_ids=input_ids,
120 | bbox=bbox,
121 | attention_mask=attention_mask,
122 | token_type_ids=token_type_ids,
123 | position_ids=position_ids,
124 | head_mask=head_mask,
125 | inputs_embeds=inputs_embeds,
126 | output_attentions=output_attentions,
127 | output_hidden_states=output_hidden_states,
128 | return_dict=return_dict,
129 | )
130 |
131 | if input_ids is not None:
132 | input_shape = input_ids.size()
133 | else:
134 | input_shape = inputs_embeds.size()[:-1]
135 |
136 | seq_length = input_shape[1]
137 | # only take the text part of the output representations
138 | sequence_output = outputs[0][:, :seq_length]
139 |
140 | logits = self.qa_outputs(sequence_output)
141 | start_logits, end_logits = logits.split(1, dim=-1)
142 | start_logits = start_logits.squeeze(-1).contiguous()
143 | end_logits = end_logits.squeeze(-1).contiguous()
144 |
145 | total_loss = None
146 | if start_positions is not None and end_positions is not None:
147 | # If we are on multi-GPU, split add a dimension
148 | if len(start_positions.size()) > 1:
149 | start_positions = start_positions.squeeze(-1)
150 | if len(end_positions.size()) > 1:
151 | end_positions = end_positions.squeeze(-1)
152 | # sometimes the start/end positions are outside our model inputs, we ignore these terms
153 | ignored_index = start_logits.size(1)
154 | start_positions = start_positions.clamp(0, ignored_index)
155 | end_positions = end_positions.clamp(0, ignored_index)
156 |
157 | loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
158 | start_loss = loss_fct(start_logits, start_positions)
159 | end_loss = loss_fct(end_logits, end_positions)
160 | total_loss = (start_loss + end_loss) / 2
161 |
162 | token_logits = None
163 | if getattr(self.config, "token_classification", False):
164 | token_logits = self.token_classifier_head(sequence_output)
165 |
166 | if token_labels is not None:
167 | # Loss fn expects logits to be of shape (batch_size, num_labels, 512), but model
168 | # outputs (batch_size, 512, num_labels), so we need to move the dimensions around
169 | # Ref: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
170 | token_logits_reshaped = torch.movedim(token_logits, source=token_logits.ndim - 1, destination=1)
171 | token_loss = nn.CrossEntropyLoss(reduction=self.config.token_classifier_reduction)(
172 | token_logits_reshaped, token_labels
173 | )
174 |
175 | total_loss += self.config.token_classifier_constant * token_loss
176 |
177 | if not return_dict:
178 | output = (start_logits, end_logits)
179 | if getattr(self.config, "token_classification", False):
180 | output = output + (token_logits,)
181 |
182 | output = output + outputs[2:]
183 |
184 | if total_loss is not None:
185 | output = (total_loss,) + output
186 |
187 | return output
188 |
189 | return QuestionAnsweringModelOutput(
190 | loss=total_loss,
191 | start_logits=start_logits,
192 | end_logits=end_logits,
193 | token_logits=token_logits,
194 | hidden_states=outputs.hidden_states,
195 | attentions=outputs.attentions,
196 | )
197 |
--------------------------------------------------------------------------------
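Note on ext/model.py above: a minimal sketch (assumed usage, not taken from the repo) of the three extra config switches documented in the comment block, which enable the optional token classification head. The checkpoint name is illustrative.

```python
# Sketch: constructing LayoutLMForQuestionAnswering with the extra config attributes
# described in ext/model.py. Checkpoint name is illustrative.
from transformers import LayoutLMConfig

from docquery.ext.model import LayoutLMForQuestionAnswering

config = LayoutLMConfig.from_pretrained("microsoft/layoutlm-base-uncased")
config.token_classification = True          # adds the 2-way token classifier head
config.token_classifier_reduction = "mean"  # reduction for that head's cross-entropy loss
config.token_classifier_constant = 1.0      # weight of the head's loss in the total loss

model = LayoutLMForQuestionAnswering(config)
# forward() now also populates `token_logits` (batch_size, seq_len, 2) in its output.
```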
/src/docquery/ext/pipeline_document_classification.py:
--------------------------------------------------------------------------------
1 | # This file is copied from transformers:
2 | # https://github.com/huggingface/transformers/blob/bb6f6d53386bf2340eead6a8f9320ce61add3e96/src/transformers/pipelines/image_classification.py
3 | # And has been modified to support Donut
4 | import re
5 | from typing import List, Optional, Tuple, Union
6 |
7 | import torch
8 | from transformers.pipelines.base import PIPELINE_INIT_ARGS, ChunkPipeline
9 | from transformers.pipelines.text_classification import ClassificationFunction, sigmoid, softmax
10 | from transformers.utils import ExplicitEnum, add_end_docstrings, logging
11 |
12 | from .pipeline_document_question_answering import ImageOrName, apply_tesseract
13 | from .qa_helpers import TESSERACT_LOADED, VISION_LOADED, load_image
14 |
15 |
16 | logger = logging.get_logger(__name__)
17 |
18 |
19 | class ModelType(ExplicitEnum):
20 | Standard = "standard"
21 | VisionEncoderDecoder = "vision_encoder_decoder"
22 |
23 |
24 | def donut_token2json(tokenizer, tokens, is_inner_value=False):
25 | """
26 | Convert a (generated) token sequence into an ordered JSON format.
27 | """
28 | output = dict()
29 |
30 | while tokens:
31 | start_token = re.search(r"<s_(.*?)>", tokens, re.IGNORECASE)
32 | if start_token is None:
33 | break
34 | key = start_token.group(1)
35 | end_token = re.search(rf"</s_{key}>", tokens, re.IGNORECASE)
36 | start_token = start_token.group()
37 | if end_token is None:
38 | tokens = tokens.replace(start_token, "")
39 | else:
40 | end_token = end_token.group()
41 | start_token_escaped = re.escape(start_token)
42 | end_token_escaped = re.escape(end_token)
43 | content = re.search(f"{start_token_escaped}(.*?){end_token_escaped}", tokens, re.IGNORECASE)
44 | if content is not None:
45 | content = content.group(1).strip()
46 | if r""):
55 | leaf = leaf.strip()
56 | if leaf in tokenizer.get_added_vocab() and leaf[0] == "<" and leaf[-2:] == "/>":
57 | leaf = leaf[1:-2] # for categorical special tokens
58 | output[key].append(leaf)
59 | if len(output[key]) == 1:
60 | output[key] = output[key][0]
61 |
62 | tokens = tokens[tokens.find(end_token) + len(end_token) :].strip()
63 | if tokens[:6] == r"<sep/>": # non-leaf nodes
64 | return [output] + donut_token2json(tokenizer, tokens[6:], is_inner_value=True)
65 |
66 | if len(output):
67 | return [output] if is_inner_value else output
68 | else:
69 | return [] if is_inner_value else {"text_sequence": tokens}
70 |
71 |
72 | @add_end_docstrings(PIPELINE_INIT_ARGS)
73 | class DocumentClassificationPipeline(ChunkPipeline):
74 | """
75 | Document classification pipeline using any `AutoModelForDocumentClassification`. This pipeline predicts the class of a
76 | document.
77 |
78 | This document classification pipeline can currently be loaded from [`pipeline`] using the following task identifier:
79 | `"document-classification"`.
80 |
81 | See the list of available models on
82 | [huggingface.co/models](https://huggingface.co/models?filter=document-classification).
83 | """
84 |
85 | def __init__(self, *args, **kwargs):
86 | super().__init__(*args, **kwargs)
87 |
88 | if self.model.config.__class__.__name__ == "VisionEncoderDecoderConfig":
89 | self.model_type = ModelType.VisionEncoderDecoder
90 | else:
91 | self.model_type = ModelType.Standard
92 |
93 | def _sanitize_parameters(
94 | self,
95 | doc_stride=None,
96 | lang: Optional[str] = None,
97 | tesseract_config: Optional[str] = None,
98 | max_num_spans: Optional[int] = None,
99 | max_seq_len=None,
100 | function_to_apply=None,
101 | top_k=None,
102 | ):
103 | preprocess_params, postprocess_params = {}, {}
104 | if doc_stride is not None:
105 | preprocess_params["doc_stride"] = doc_stride
106 | if max_seq_len is not None:
107 | preprocess_params["max_seq_len"] = max_seq_len
108 | if lang is not None:
109 | preprocess_params["lang"] = lang
110 | if tesseract_config is not None:
111 | preprocess_params["tesseract_config"] = tesseract_config
112 | if max_num_spans is not None:
113 | preprocess_params["max_num_spans"] = max_num_spans
114 |
115 | if isinstance(function_to_apply, str):
116 | function_to_apply = ClassificationFunction[function_to_apply.upper()]
117 |
118 | if function_to_apply is not None:
119 | postprocess_params["function_to_apply"] = function_to_apply
120 |
121 | if top_k is not None:
122 | if top_k < 1:
123 | raise ValueError(f"top_k parameter should be >= 1 (got {top_k})")
124 | postprocess_params["top_k"] = top_k
125 |
126 | return preprocess_params, {}, postprocess_params
127 |
128 | def __call__(self, image: Union[ImageOrName, List[ImageOrName], List[Tuple]], **kwargs):
129 | """
130 | Assign labels to the document(s) passed as inputs.
131 |
132 | # TODO
133 | """
134 | if isinstance(image, list):
135 | normalized_images = (i if isinstance(i, (tuple, list)) else (i, None) for i in image)
136 | else:
137 | normalized_images = [(image, None)]
138 |
139 | return super().__call__({"pages": normalized_images}, **kwargs)
140 |
141 | def preprocess(
142 | self,
143 | input,
144 | doc_stride=None,
145 | max_seq_len=None,
146 | word_boxes: Tuple[str, List[float]] = None,
147 | lang=None,
148 | tesseract_config="",
149 | max_num_spans=1,
150 | ):
151 | # NOTE: This code mirrors the code in question answering and will be implemented in a follow up PR
152 | # to support documents with enough tokens that overflow the model's window
153 | if max_seq_len is None:
154 | # TODO: LayoutLM's stride is 512 by default. Is it ok to use that as the min
155 | # instead of 384 (which the QA model uses)?
156 | max_seq_len = min(self.tokenizer.model_max_length, 512)
157 |
158 | if doc_stride is None:
159 | doc_stride = min(max_seq_len // 2, 256)
160 |
161 | total_num_spans = 0
162 |
163 | for page_idx, (image, word_boxes) in enumerate(input["pages"]):
164 | image_features = {}
165 | if image is not None:
166 | if not VISION_LOADED:
167 | raise ValueError(
168 | "If you provide an image, then the pipeline will run process it with PIL (Pillow), but"
169 | " PIL is not available. Install it with pip install Pillow."
170 | )
171 | image = load_image(image)
172 | if self.feature_extractor is not None:
173 | image_features.update(self.feature_extractor(images=image, return_tensors=self.framework))
174 |
175 | words, boxes = None, None
176 | if self.model_type != ModelType.VisionEncoderDecoder:
177 | if word_boxes is not None:
178 | words = [x[0] for x in word_boxes]
179 | boxes = [x[1] for x in word_boxes]
180 | elif "words" in image_features and "boxes" in image_features:
181 | words = image_features.pop("words")[0]
182 | boxes = image_features.pop("boxes")[0]
183 | elif image is not None:
184 | if not TESSERACT_LOADED:
185 | raise ValueError(
186 | "If you provide an image without word_boxes, then the pipeline will run OCR using"
187 | " Tesseract, but pytesseract is not available. Install it with pip install pytesseract."
188 | )
189 | if TESSERACT_LOADED:
190 | words, boxes = apply_tesseract(image, lang=lang, tesseract_config=tesseract_config)
191 | else:
192 | raise ValueError(
193 | "You must provide an image or word_boxes. If you provide an image, the pipeline will"
194 | " automatically run OCR to derive words and boxes"
195 | )
196 |
197 | if self.tokenizer.padding_side != "right":
198 | raise ValueError(
199 | "Document classification only supports tokenizers whose padding side is 'right', not"
200 | f" {self.tokenizer.padding_side}"
201 | )
202 |
203 | if self.model_type == ModelType.VisionEncoderDecoder:
204 | encoding = {
205 | "inputs": image_features["pixel_values"],
206 | "max_length": self.model.decoder.config.max_position_embeddings,
207 | "decoder_input_ids": self.tokenizer(
208 | "",
209 | add_special_tokens=False,
210 | return_tensors=self.framework,
211 | ).input_ids,
212 | "return_dict_in_generate": True,
213 | }
214 | yield {
215 | **encoding,
216 | "page": None,
217 | }
218 | else:
219 | encoding = self.tokenizer(
220 | text=words,
221 | max_length=max_seq_len,
222 | stride=doc_stride,
223 | return_token_type_ids=True,
224 | is_split_into_words=True,
225 | truncation=True,
226 | return_overflowing_tokens=True,
227 | )
228 |
229 | num_spans = len(encoding["input_ids"])
230 |
231 | for span_idx in range(num_spans):
232 | if self.framework == "pt":
233 | span_encoding = {k: torch.tensor(v[span_idx : span_idx + 1]) for (k, v) in encoding.items()}
234 | span_encoding.update(
235 | {k: v for (k, v) in image_features.items()}
236 | ) # TODO: Verify cardinality is correct
237 | else:
238 | raise ValueError("Unsupported: Tensorflow preprocessing for DocumentClassification")
239 |
240 | # For each span, place a bounding box [0,0,0,0] for question and CLS tokens, [1000,1000,1000,1000]
241 | # for SEP tokens, and the word's bounding box for words in the original document.
242 | bbox = []
243 | for i, s, w in zip(
244 | encoding.input_ids[span_idx],
245 | encoding.sequence_ids(span_idx),
246 | encoding.word_ids(span_idx),
247 | ):
248 | if s == 0:
249 | bbox.append(boxes[w])
250 | elif i == self.tokenizer.sep_token_id:
251 | bbox.append([1000] * 4)
252 | else:
253 | bbox.append([0] * 4)
254 |
255 | span_encoding["bbox"] = torch.tensor(bbox).unsqueeze(0)
256 |
257 | yield {
258 | **span_encoding,
259 | "page": page_idx,
260 | }
261 |
262 | total_num_spans += 1
263 | if total_num_spans >= max_num_spans:
264 | break
265 |
266 | def _forward(self, model_inputs):
267 | page = model_inputs.pop("page", None)
268 |
269 | if "overflow_to_sample_mapping" in model_inputs:
270 | model_inputs.pop("overflow_to_sample_mapping")
271 |
272 | if self.model_type == ModelType.VisionEncoderDecoder:
273 | model_outputs = self.model.generate(**model_inputs)
274 | else:
275 | model_outputs = self.model(**model_inputs)
276 |
277 | model_outputs["page"] = page
278 | model_outputs["attention_mask"] = model_inputs.get("attention_mask", None)
279 | return model_outputs
280 |
281 | def postprocess(self, model_outputs, function_to_apply=None, top_k=1, **kwargs):
282 | if function_to_apply is None:
283 | if self.model.config.num_labels == 1:
284 | function_to_apply = ClassificationFunction.SIGMOID
285 | elif self.model.config.num_labels > 1:
286 | function_to_apply = ClassificationFunction.SOFTMAX
287 | elif hasattr(self.model.config, "function_to_apply") and function_to_apply is None:
288 | function_to_apply = self.model.config.function_to_apply
289 | else:
290 | function_to_apply = ClassificationFunction.NONE
291 |
292 | if self.model_type == ModelType.VisionEncoderDecoder:
293 | answers = self.postprocess_encoder_decoder(model_outputs, top_k=top_k, **kwargs)
294 | else:
295 | answers = self.postprocess_standard(
296 | model_outputs, function_to_apply=function_to_apply, top_k=top_k, **kwargs
297 | )
298 |
299 | answers = sorted(answers, key=lambda x: x.get("score", 0), reverse=True)[:top_k]
300 | return answers
301 |
302 | def postprocess_encoder_decoder(self, model_outputs, **kwargs):
303 | classes = set()
304 | for model_output in model_outputs:
305 | for sequence in self.tokenizer.batch_decode(model_output.sequences):
306 | sequence = sequence.replace(self.tokenizer.eos_token, "").replace(self.tokenizer.pad_token, "")
307 | sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
308 | classes.add(donut_token2json(self.tokenizer, sequence)["class"])
309 |
310 | # Return the first top_k unique classes we see
311 | return [{"label": v} for v in classes]
312 |
313 | def postprocess_standard(self, model_outputs, function_to_apply, **kwargs):
314 | # Average the score across pages
315 | sum_scores = {k: 0 for k in self.model.config.id2label.values()}
316 | for model_output in model_outputs:
317 | outputs = model_output["logits"][0]
318 | outputs = outputs.numpy()
319 |
320 | if function_to_apply == ClassificationFunction.SIGMOID:
321 | scores = sigmoid(outputs)
322 | elif function_to_apply == ClassificationFunction.SOFTMAX:
323 | scores = softmax(outputs)
324 | elif function_to_apply == ClassificationFunction.NONE:
325 | scores = outputs
326 | else:
327 | raise ValueError(f"Unrecognized `function_to_apply` argument: {function_to_apply}")
328 |
329 | for i, score in enumerate(scores):
330 | sum_scores[self.model.config.id2label[i]] += score.item()
331 |
332 | return [{"label": label, "score": score / len(model_outputs)} for (label, score) in sum_scores.items()]
333 |
--------------------------------------------------------------------------------
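For orientation, a minimal usage sketch of the document-classification pipeline defined in src/docquery/ext/pipeline_document_classification.py above, mirroring tests/test_classification_end_to_end.py. The local path `letter.png` is a hypothetical placeholder; the Donut checkpoint is the one exercised in the test suite.

```python
# Hedged usage sketch (not part of the package); "letter.png" is a hypothetical local image.
from docquery import pipeline
from docquery.document import load_document

doc = load_document("letter.png")
classifier = pipeline("document-classification", model="naver-clova-ix/donut-base-finetuned-rvlcdip")
print(classifier(top_k=1, **doc.context))  # e.g. [{"label": "letter"}]
```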
/src/docquery/ext/pipeline_document_question_answering.py:
--------------------------------------------------------------------------------
1 | # NOTE: This code is currently under review for inclusion in the main
2 | # huggingface/transformers repository:
3 | # https://github.com/huggingface/transformers/pull/18414
4 | import re
5 | from typing import List, Optional, Tuple, Union
6 |
7 | import numpy as np
8 | from transformers.pipelines.base import PIPELINE_INIT_ARGS, ChunkPipeline
9 | from transformers.utils import (
10 | ExplicitEnum,
11 | add_end_docstrings,
12 | is_pytesseract_available,
13 | is_torch_available,
14 | is_vision_available,
15 | logging,
16 | )
17 |
18 | from .itertools import unique_everseen
19 | from .qa_helpers import TESSERACT_LOADED, VISION_LOADED, Image, load_image, pytesseract, select_starts_ends
20 |
21 |
22 | if is_torch_available():
23 | import torch
24 |
25 | # We do not perform the check in this version of the pipeline code
26 | # from transformers.models.auto.modeling_auto import MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING
27 |
28 | logger = logging.get_logger(__name__)
29 |
30 |
31 | # normalize_bbox() and apply_tesseract() are derived from apply_tesseract in models/layoutlmv3/feature_extraction_layoutlmv3.py.
32 | # However, because the pipeline may evolve from what layoutlmv3 currently does, it's copied (vs. imported) to avoid creating an
33 | # unnecessary dependency.
34 | def normalize_box(box, width, height):
35 | return [
36 | int(1000 * (box[0] / width)),
37 | int(1000 * (box[1] / height)),
38 | int(1000 * (box[2] / width)),
39 | int(1000 * (box[3] / height)),
40 | ]
41 |
42 |
43 | def apply_tesseract(image: "Image.Image", lang: Optional[str], tesseract_config: Optional[str]):
44 | """Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes."""
45 | # apply OCR
46 | data = pytesseract.image_to_data(image, lang=lang, output_type="dict", config=tesseract_config)
47 | words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
48 |
49 | # filter empty words and corresponding coordinates
50 | irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()]
51 | words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices]
52 | left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices]
53 | top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices]
54 | width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices]
55 | height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices]
56 |
57 | # turn coordinates into (left, top, left+width, top+height) format
58 | actual_boxes = []
59 | for x, y, w, h in zip(left, top, width, height):
60 | actual_box = [x, y, x + w, y + h]
61 | actual_boxes.append(actual_box)
62 |
63 | image_width, image_height = image.size
64 |
65 | # finally, normalize the bounding boxes
66 | normalized_boxes = []
67 | for box in actual_boxes:
68 | normalized_boxes.append(normalize_box(box, image_width, image_height))
69 |
70 | if len(words) != len(normalized_boxes):
71 |         raise ValueError("The number of words does not match the number of bounding boxes")
72 |
73 | return words, normalized_boxes
74 |
75 |
76 | class ModelType(ExplicitEnum):
77 | LayoutLM = "layoutlm"
78 | LayoutLMv2andv3 = "layoutlmv2andv3"
79 | VisionEncoderDecoder = "vision_encoder_decoder"
80 |
81 |
82 | ImageOrName = Union["Image.Image", str]
83 | DEFAULT_MAX_ANSWER_LENGTH = 15
84 |
85 |
86 | @add_end_docstrings(PIPELINE_INIT_ARGS)
87 | class DocumentQuestionAnsweringPipeline(ChunkPipeline):
88 | # TODO: Update task_summary docs to include an example with document QA and then update the first sentence
89 | """
90 | Document Question Answering pipeline using any `AutoModelForDocumentQuestionAnswering`. The inputs/outputs are
91 | similar to the (extractive) question answering pipeline; however, the pipeline takes an image (and optional OCR'd
92 | words/boxes) as input instead of text context.
93 |
94 | This document question answering pipeline can currently be loaded from [`pipeline`] using the following task
95 | identifier: `"document-question-answering"`.
96 |
97 | The models that this pipeline can use are models that have been fine-tuned on a document question answering task.
98 | See the up-to-date list of available models on
99 | [huggingface.co/models](https://huggingface.co/models?filter=document-question-answering).
100 | """
101 |
102 | def __init__(self, *args, **kwargs):
103 | super().__init__(*args, **kwargs)
104 | if self.model.config.__class__.__name__ == "VisionEncoderDecoderConfig":
105 | self.model_type = ModelType.VisionEncoderDecoder
106 | if self.model.config.encoder.model_type != "donut-swin":
107 | raise ValueError("Currently, the only supported VisionEncoderDecoder model is Donut")
108 | else:
109 | # self.check_model_type(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING)
110 | if self.model.config.__class__.__name__ == "LayoutLMConfig":
111 | self.model_type = ModelType.LayoutLM
112 | else:
113 | self.model_type = ModelType.LayoutLMv2andv3
114 |
115 | def _sanitize_parameters(
116 | self,
117 | padding=None,
118 | doc_stride=None,
119 | max_question_len=None,
120 | lang: Optional[str] = None,
121 | tesseract_config: Optional[str] = None,
122 | max_answer_len=None,
123 | max_seq_len=None,
124 | top_k=None,
125 | handle_impossible_answer=None,
126 | **kwargs,
127 | ):
128 | preprocess_params, postprocess_params = {}, {}
129 | if padding is not None:
130 | preprocess_params["padding"] = padding
131 | if doc_stride is not None:
132 | preprocess_params["doc_stride"] = doc_stride
133 | if max_question_len is not None:
134 | preprocess_params["max_question_len"] = max_question_len
135 | if max_seq_len is not None:
136 | preprocess_params["max_seq_len"] = max_seq_len
137 | if lang is not None:
138 | preprocess_params["lang"] = lang
139 | if tesseract_config is not None:
140 | preprocess_params["tesseract_config"] = tesseract_config
141 |
142 | if top_k is not None:
143 | if top_k < 1:
144 | raise ValueError(f"top_k parameter should be >= 1 (got {top_k})")
145 | postprocess_params["top_k"] = top_k
146 | if max_answer_len is not None:
147 | if max_answer_len < 1:
148 |                 raise ValueError(f"max_answer_len parameter should be >= 1 (got {max_answer_len})")
149 | postprocess_params["max_answer_len"] = max_answer_len
150 | if handle_impossible_answer is not None:
151 | postprocess_params["handle_impossible_answer"] = handle_impossible_answer
152 |
153 | return preprocess_params, {}, postprocess_params
154 |
155 | def __call__(
156 | self,
157 | image: Union[ImageOrName, List[ImageOrName], List[Tuple]],
158 | question: Optional[str] = None,
159 | **kwargs,
160 | ):
161 | """
162 | Answer the question(s) given as inputs by using the document(s). A document is defined as an image and an
163 | optional list of (word, box) tuples which represent the text in the document. If the `word_boxes` are not
164 | provided, it will use the Tesseract OCR engine (if available) to extract the words and boxes automatically for
165 | LayoutLM-like models which require them as input. For Donut, no OCR is run.
166 |
167 | You can invoke the pipeline several ways:
168 |
169 | - `pipeline(image=image, question=question)`
170 | - `pipeline(image=image, question=question, word_boxes=word_boxes)`
171 | - `pipeline([{"image": image, "question": question}])`
172 | - `pipeline([{"image": image, "question": question, "word_boxes": word_boxes}])`
173 |
174 | Args:
175 | image (`str` or `PIL.Image`):
176 | The pipeline handles three types of images:
177 |
178 |             - A string containing an HTTP link pointing to an image
179 | - A string containing a local path to an image
180 | - An image loaded in PIL directly
181 |
182 | The pipeline accepts either a single image or a batch of images. If given a single image, it can be
183 | broadcasted to multiple questions.
184 | question (`str`):
185 | A question to ask of the document.
186 |             word_boxes (`List[Tuple[str, Tuple[float, float, float, float]]]`, *optional*):
187 | A list of words and bounding boxes (normalized 0->1000). If you provide this optional input, then the
188 | pipeline will use these words and boxes instead of running OCR on the image to derive them for models
189 | that need them (e.g. LayoutLM). This allows you to reuse OCR'd results across many invocations of the
190 | pipeline without having to re-run it each time.
191 |             top_k (`int`, *optional*, defaults to 1):
192 |                 The number of answers to return (will be chosen by order of likelihood). Note that we return fewer than
193 |                 top_k answers if there are not enough options available within the context.
194 |             doc_stride (`int`, *optional*, defaults to 128):
195 |                 If the words in the document are too long to fit with the question for the model, the document will be
196 |                 split into several chunks with some overlap. This argument controls the size of that overlap.
197 | max_answer_len (`int`, *optional*, defaults to 15):
198 | The maximum length of predicted answers (e.g., only answers with a shorter length are considered).
199 | max_seq_len (`int`, *optional*, defaults to 384):
200 | The maximum length of the total sentence (context + question) in tokens of each chunk passed to the
201 |                 model. The context will be split into several chunks (using `doc_stride` as overlap) if needed.
202 | max_question_len (`int`, *optional*, defaults to 64):
203 | The maximum length of the question after tokenization. It will be truncated if needed.
204 |             handle_impossible_answer (`bool`, *optional*, defaults to `False`):
205 |                 Whether or not we accept an impossible (empty) answer.
206 |             lang (`str`, *optional*):
207 |                 Language to use while running OCR. Defaults to English.
208 | tesseract_config (`str`, *optional*):
209 | Additional flags to pass to tesseract while running OCR.
210 |
211 | Return:
212 | A `dict` or a list of `dict`: Each result comes as a dictionary with the following keys:
213 |
214 |             - **score** (`float`) -- The probability associated with the answer.
215 | - **start** (`int`) -- The start word index of the answer (in the OCR'd version of the input or provided
216 | `word_boxes`).
217 | - **end** (`int`) -- The end word index of the answer (in the OCR'd version of the input or provided
218 | `word_boxes`).
219 | - **answer** (`str`) -- The answer to the question.
220 |             - **words** (`list[int]`) -- The index of each word/box pair that is in the answer.
221 |             - **page** (`int`) -- The page of the answer.
222 | """
223 | if question is None:
224 | question = image["question"]
225 | image = image["image"]
226 |
227 | if isinstance(image, list):
228 | normalized_images = (i if isinstance(i, (tuple, list)) else (i, None) for i in image)
229 | else:
230 | normalized_images = [(image, None)]
231 |
232 | return super().__call__({"question": question, "pages": normalized_images}, **kwargs)
233 |
234 | def preprocess(
235 | self,
236 | input,
237 | padding="do_not_pad",
238 | doc_stride=None,
239 | max_question_len=64,
240 | max_seq_len=None,
241 | word_boxes: Tuple[str, List[float]] = None,
242 | lang=None,
243 | tesseract_config="",
244 | ):
245 | # NOTE: This code mirrors the code in question answering and will be implemented in a follow up PR
246 | # to support documents with enough tokens that overflow the model's window
247 | if max_seq_len is None:
248 | max_seq_len = min(self.tokenizer.model_max_length, 512)
249 |
250 | if doc_stride is None:
251 | doc_stride = min(max_seq_len // 2, 256)
252 |
253 | for page_idx, (image, word_boxes) in enumerate(input["pages"]):
254 | image_features = {}
255 | if image is not None:
256 | image = load_image(image)
257 | if self.feature_extractor is not None:
258 | image_features.update(self.feature_extractor(images=image, return_tensors=self.framework))
259 | elif self.model_type == ModelType.VisionEncoderDecoder:
260 | raise ValueError(
261 | "If you are using a VisionEncoderDecoderModel, you must provide a feature extractor"
262 | )
263 |
264 | words, boxes = None, None
265 | if not self.model_type == ModelType.VisionEncoderDecoder:
266 | if word_boxes is not None:
267 | words = [x[0] for x in word_boxes]
268 | boxes = [x[1] for x in word_boxes]
269 | elif "words" in image_features and "boxes" in image_features:
270 | words = image_features.pop("words")[0]
271 | boxes = image_features.pop("boxes")[0]
272 | elif image is not None:
273 | if not TESSERACT_LOADED:
274 | raise ValueError(
275 | "If you provide an image without word_boxes, then the pipeline will run OCR using"
276 | " Tesseract, but pytesseract is not available. Install it with pip install pytesseract."
277 | )
278 | if TESSERACT_LOADED:
279 | words, boxes = apply_tesseract(image, lang=lang, tesseract_config=tesseract_config)
280 | else:
281 | raise ValueError(
282 | "You must provide an image or word_boxes. If you provide an image, the pipeline will"
283 | " automatically run OCR to derive words and boxes"
284 | )
285 |
286 | if self.tokenizer.padding_side != "right":
287 | raise ValueError(
288 | "Document question answering only supports tokenizers whose padding side is 'right', not"
289 | f" {self.tokenizer.padding_side}"
290 | )
291 |
292 | if self.model_type == ModelType.VisionEncoderDecoder:
293 |                 task_prompt = f'<s_docvqa><s_question>{input["question"]}</s_question><s_answer>'
294 | # Adapted from https://huggingface.co/spaces/nielsr/donut-docvqa/blob/main/app.py
295 | encoding = {
296 | "inputs": image_features["pixel_values"],
297 | "decoder_input_ids": self.tokenizer(
298 | task_prompt, add_special_tokens=False, return_tensors=self.framework
299 | ).input_ids,
300 | "return_dict_in_generate": True,
301 | }
302 |
303 | yield {
304 | **encoding,
305 | "p_mask": None,
306 | "word_ids": None,
307 | "words": None,
308 | "page": None,
309 | "output_attentions": True,
310 | }
311 | else:
312 | tokenizer_kwargs = {}
313 | if self.model_type == ModelType.LayoutLM:
314 | tokenizer_kwargs["text"] = input["question"].split()
315 | tokenizer_kwargs["text_pair"] = words
316 | tokenizer_kwargs["is_split_into_words"] = True
317 | else:
318 | tokenizer_kwargs["text"] = [input["question"]]
319 | tokenizer_kwargs["text_pair"] = [words]
320 | tokenizer_kwargs["boxes"] = [boxes]
321 |
322 | encoding = self.tokenizer(
323 | padding=padding,
324 | max_length=max_seq_len,
325 | stride=doc_stride,
326 | truncation="only_second",
327 | return_overflowing_tokens=True,
328 | **tokenizer_kwargs,
329 | )
330 |
331 | if "pixel_values" in image_features:
332 | encoding["image"] = image_features.pop("pixel_values")
333 |
334 | num_spans = len(encoding["input_ids"])
335 |
336 |                 # p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens which can be in an answer)
337 | # We put 0 on the tokens from the context and 1 everywhere else (question and special tokens)
338 | # This logic mirrors the logic in the question_answering pipeline
339 | p_mask = [[tok != 1 for tok in encoding.sequence_ids(span_id)] for span_id in range(num_spans)]
340 |
341 | for span_idx in range(num_spans):
342 | if self.framework == "pt":
343 | span_encoding = {k: torch.tensor(v[span_idx : span_idx + 1]) for (k, v) in encoding.items()}
344 | span_encoding.update(
345 | {k: v for (k, v) in image_features.items()}
346 | ) # TODO: Verify cardinality is correct
347 | else:
348 | raise ValueError("Unsupported: Tensorflow preprocessing for DocumentQuestionAnsweringPipeline")
349 |
350 | input_ids_span_idx = encoding["input_ids"][span_idx]
351 | # keep the cls_token unmasked (some models use it to indicate unanswerable questions)
352 | if self.tokenizer.cls_token_id is not None:
353 | cls_indices = np.nonzero(np.array(input_ids_span_idx) == self.tokenizer.cls_token_id)[0]
354 | for cls_index in cls_indices:
355 | p_mask[span_idx][cls_index] = 0
356 |
357 | # For each span, place a bounding box [0,0,0,0] for question and CLS tokens, [1000,1000,1000,1000]
358 | # for SEP tokens, and the word's bounding box for words in the original document.
359 | if "boxes" not in tokenizer_kwargs:
360 | bbox = []
361 |
362 | for input_id, sequence_id, word_id in zip(
363 | encoding.input_ids[span_idx],
364 | encoding.sequence_ids(span_idx),
365 | encoding.word_ids(span_idx),
366 | ):
367 | if sequence_id == 1:
368 | bbox.append(boxes[word_id])
369 | elif input_id == self.tokenizer.sep_token_id:
370 | bbox.append([1000] * 4)
371 | else:
372 | bbox.append([0] * 4)
373 |
374 | if self.framework == "pt":
375 | span_encoding["bbox"] = torch.tensor(bbox).unsqueeze(0)
376 | elif self.framework == "tf":
377 | raise ValueError(
378 | "Unsupported: Tensorflow preprocessing for DocumentQuestionAnsweringPipeline"
379 | )
380 |
381 | yield {
382 | **span_encoding,
383 | "p_mask": p_mask[span_idx],
384 | "word_ids": encoding.word_ids(span_idx),
385 | "words": words,
386 | "page": page_idx,
387 | }
388 |
389 | def _forward(self, model_inputs):
390 | p_mask = model_inputs.pop("p_mask", None)
391 | word_ids = model_inputs.pop("word_ids", None)
392 | words = model_inputs.pop("words", None)
393 | page = model_inputs.pop("page", None)
394 |
395 | if "overflow_to_sample_mapping" in model_inputs:
396 | model_inputs.pop("overflow_to_sample_mapping")
397 |
398 | if self.model_type == ModelType.VisionEncoderDecoder:
399 | model_outputs = self.model.generate(**model_inputs)
400 | else:
401 | model_outputs = self.model(**model_inputs)
402 |
403 | model_outputs["p_mask"] = p_mask
404 | model_outputs["word_ids"] = word_ids
405 | model_outputs["words"] = words
406 | model_outputs["page"] = page
407 | model_outputs["attention_mask"] = model_inputs.get("attention_mask", None)
408 | return model_outputs
409 |
410 | def postprocess(self, model_outputs, top_k=1, **kwargs):
411 | if self.model_type == ModelType.VisionEncoderDecoder:
412 | answers = [self.postprocess_encoder_decoder_single(o) for o in model_outputs]
413 | else:
414 | answers = self.postprocess_extractive_qa(model_outputs, top_k=top_k, **kwargs)
415 |
416 | answers = sorted(answers, key=lambda x: x.get("score", 0), reverse=True)[:top_k]
417 | return answers
418 |
419 | def postprocess_encoder_decoder_single(self, model_outputs, **kwargs):
420 | sequence = self.tokenizer.batch_decode(model_outputs.sequences)[0]
421 |
422 | # TODO: A lot of this logic is specific to Donut and should probably be handled in the tokenizer
423 | # (see https://github.com/huggingface/transformers/pull/18414/files#r961747408 for more context).
424 | sequence = sequence.replace(self.tokenizer.eos_token, "").replace(self.tokenizer.pad_token, "")
425 | sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
426 | ret = {
427 | "answer": None,
428 | }
429 |
430 |         answer = re.search(r"<s_answer>(.*)</s_answer>", sequence)
431 | if answer is not None:
432 | ret["answer"] = answer.group(1).strip()
433 | return ret
434 |
435 | def postprocess_extractive_qa(
436 | self, model_outputs, top_k=1, handle_impossible_answer=False, max_answer_len=None, **kwargs
437 | ):
438 | min_null_score = 1000000 # large and positive
439 | answers = []
440 |
441 | if max_answer_len is None:
442 | if hasattr(self.model.config, "token_classification") and self.model.config.token_classification:
443 | # If this model has token classification, then use a much longer max answer length and
444 | # let the classifier remove things
445 | max_answer_len = self.tokenizer.model_max_length
446 | else:
447 | max_answer_len = DEFAULT_MAX_ANSWER_LENGTH
448 |
449 | for output in model_outputs:
450 | words = output["words"]
451 |
452 | starts, ends, scores, min_null_score = select_starts_ends(
453 | output["start_logits"],
454 | output["end_logits"],
455 | output["p_mask"],
456 | output["attention_mask"].numpy() if output.get("attention_mask", None) is not None else None,
457 | min_null_score,
458 | top_k,
459 | handle_impossible_answer,
460 | max_answer_len,
461 | )
462 | word_ids = output["word_ids"]
463 | for start, end, score in zip(starts, ends, scores):
464 | if "token_logits" in output:
465 | predicted_token_classes = (
466 | output["token_logits"][
467 | 0,
468 | start : end + 1,
469 | ]
470 | .argmax(axis=1)
471 | .cpu()
472 | .numpy()
473 | )
474 | assert np.setdiff1d(predicted_token_classes, [0, 1]).shape == (0,)
475 | token_indices = np.flatnonzero(predicted_token_classes) + start
476 | else:
477 | token_indices = range(start, end + 1)
478 |
479 | answer_word_ids = list(unique_everseen([word_ids[i] for i in token_indices]))
480 | if len(answer_word_ids) > 0 and answer_word_ids[0] is not None and answer_word_ids[-1] is not None:
481 | answers.append(
482 | {
483 | "score": float(score),
484 | "answer": " ".join(words[i] for i in answer_word_ids),
485 | "word_ids": answer_word_ids,
486 | "page": output["page"],
487 | }
488 | )
489 |
490 | if handle_impossible_answer:
491 | answers.append({"score": min_null_score, "answer": "", "start": 0, "end": 0})
492 |
493 | return answers
494 |
--------------------------------------------------------------------------------
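For orientation, a minimal usage sketch of the pipeline defined in src/docquery/ext/pipeline_document_question_answering.py above, following the invocation forms described in the `__call__` docstring and mirroring tests/test_end_to_end.py. The path `invoice.png` is a hypothetical placeholder; the checkpoint name is the repository default.

```python
# Hedged usage sketch (not part of the package); "invoice.png" is a hypothetical local path.
from docquery import pipeline
from docquery.document import load_document

qa = pipeline("document-question-answering", model="impira/layoutlm-document-qa")

# Pass an image and a question directly; the pipeline runs OCR itself if needed:
print(qa(image="invoice.png", question="What is the invoice number?", top_k=1))

# Or load a document (PDF, image, or web page) and reuse its pre-extracted context:
doc = load_document("invoice.png")
print(qa(question="What is the invoice number?", **doc.context, top_k=1))
# e.g. [{"score": 0.9997, "answer": "us-001", "word_ids": [15], "page": 0}]
```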
/src/docquery/ext/qa_helpers.py:
--------------------------------------------------------------------------------
1 | # NOTE: This code is currently under review for inclusion in the main
2 | # huggingface/transformers repository:
3 | # https://github.com/huggingface/transformers/pull/18414
4 |
5 | import warnings
6 | from collections.abc import Iterable
7 | from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8 |
9 | import numpy as np
10 | from transformers.utils import is_pytesseract_available, is_vision_available
11 |
12 |
13 | VISION_LOADED = False
14 | if is_vision_available():
15 | from PIL import Image
16 | from transformers.image_utils import load_image
17 |
18 | VISION_LOADED = True
19 | else:
20 | Image = None
21 | load_image = None
22 |
23 |
24 | TESSERACT_LOADED = False
25 | if is_pytesseract_available():
26 | import pytesseract
27 |
28 | TESSERACT_LOADED = True
29 | else:
30 | pytesseract = None
31 |
32 |
33 | def decode_spans(
34 | start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int, undesired_tokens: np.ndarray
35 | ) -> Tuple:
36 | """
37 |     Takes the output of any `ModelForQuestionAnswering` and generates probabilities for each span to be the actual
38 | answer.
39 |
40 | In addition, it filters out some unwanted/impossible cases like answer len being greater than max_answer_len or
41 |     answer end position being before the starting position. The method supports outputting the k-best answers through
42 |     the topk argument.
43 |
44 | Args:
45 | start (`np.ndarray`): Individual start probabilities for each token.
46 | end (`np.ndarray`): Individual end probabilities for each token.
47 | topk (`int`): Indicates how many possible answer span(s) to extract from the model output.
48 | max_answer_len (`int`): Maximum size of the answer to extract from the model's output.
49 | undesired_tokens (`np.ndarray`): Mask determining tokens that can be part of the answer
50 | """
51 | # Ensure we have batch axis
52 | if start.ndim == 1:
53 | start = start[None]
54 |
55 | if end.ndim == 1:
56 | end = end[None]
57 |
58 | # Compute the score of each tuple(start, end) to be the real answer
59 | outer = np.matmul(np.expand_dims(start, -1), np.expand_dims(end, 1))
60 |
61 | # Remove candidate with end < start and end - start > max_answer_len
62 | candidates = np.tril(np.triu(outer), max_answer_len - 1)
63 |
64 | # Inspired by Chen & al. (https://github.com/facebookresearch/DrQA)
65 | scores_flat = candidates.flatten()
66 | if topk == 1:
67 | idx_sort = [np.argmax(scores_flat)]
68 | elif len(scores_flat) < topk:
69 | idx_sort = np.argsort(-scores_flat)
70 | else:
71 | idx = np.argpartition(-scores_flat, topk)[0:topk]
72 | idx_sort = idx[np.argsort(-scores_flat[idx])]
73 |
74 | starts, ends = np.unravel_index(idx_sort, candidates.shape)[1:]
75 | desired_spans = np.isin(starts, undesired_tokens.nonzero()) & np.isin(ends, undesired_tokens.nonzero())
76 | starts = starts[desired_spans]
77 | ends = ends[desired_spans]
78 | scores = candidates[0, starts, ends]
79 |
80 | return starts, ends, scores
81 |
82 |
83 | def select_starts_ends(
84 | start,
85 | end,
86 | p_mask,
87 | attention_mask,
88 | min_null_score=1000000,
89 | top_k=1,
90 | handle_impossible_answer=False,
91 | max_answer_len=15,
92 | ):
93 | """
94 | Takes the raw output of any `ModelForQuestionAnswering` and first normalizes its outputs and then uses
95 | `decode_spans()` to generate probabilities for each span to be the actual answer.
96 |
97 | Args:
98 | start (`np.ndarray`): Individual start probabilities for each token.
99 | end (`np.ndarray`): Individual end probabilities for each token.
100 | p_mask (`np.ndarray`): A mask with 1 for values that cannot be in the answer
101 | attention_mask (`np.ndarray`): The attention mask generated by the tokenizer
102 |         min_null_score (`float`): The minimum null (empty) answer score seen so far.
103 |         top_k (`int`): Indicates how many possible answer span(s) to extract from the model output.
104 |         handle_impossible_answer (`bool`): Whether to allow null (empty) answers.
105 | max_answer_len (`int`): Maximum size of the answer to extract from the model's output.
106 | """
107 | # Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
108 | undesired_tokens = np.abs(np.array(p_mask) - 1)
109 |
110 | if attention_mask is not None:
111 | undesired_tokens = undesired_tokens & attention_mask
112 |
113 | # Generate mask
114 | undesired_tokens_mask = undesired_tokens == 0.0
115 |
116 | # Make sure non-context indexes in the tensor cannot contribute to the softmax
117 | start = np.where(undesired_tokens_mask, -10000.0, start)
118 | end = np.where(undesired_tokens_mask, -10000.0, end)
119 |
120 | # Normalize logits and spans to retrieve the answer
121 | start = np.exp(start - start.max(axis=-1, keepdims=True))
122 | start = start / start.sum()
123 |
124 | end = np.exp(end - end.max(axis=-1, keepdims=True))
125 | end = end / end.sum()
126 |
127 | if handle_impossible_answer:
128 | min_null_score = min(min_null_score, (start[0, 0] * end[0, 0]).item())
129 |
130 | # Mask CLS
131 | start[0, 0] = end[0, 0] = 0.0
132 |
133 | starts, ends, scores = decode_spans(start, end, top_k, max_answer_len, undesired_tokens)
134 | return starts, ends, scores, min_null_score
135 |
--------------------------------------------------------------------------------
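To make the span-decoding logic in src/docquery/ext/qa_helpers.py above concrete, here is a small worked example of `decode_spans()`. The probabilities are invented for illustration; only the function and its signature come from this file.

```python
# Hedged toy example (not part of the package); the numbers below are made up.
import numpy as np

from docquery.ext.qa_helpers import decode_spans

start = np.array([0.05, 0.10, 0.60, 0.20, 0.05])  # start probability per token
end = np.array([0.05, 0.05, 0.20, 0.60, 0.10])    # end probability per token
undesired_tokens = np.array([0, 1, 1, 1, 0])      # 1 = token may be part of the answer

starts, ends, scores = decode_spans(start, end, topk=1, max_answer_len=3, undesired_tokens=undesired_tokens)
# starts -> [2], ends -> [3], scores -> [0.36] (0.6 * 0.6): the best span covers tokens 2..3
```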
/src/docquery/find_leaf_nodes.js:
--------------------------------------------------------------------------------
1 | // This is adapted from code generated by GPT-3, using the following prompt:
2 | // Please write the code for a javascript program that can traverse the DOM and return each text node and its bounding box. If the text has multiple bounding boxes, it should combine them into one. The output should be Javascript code that can run in Chrome.
3 |
4 | function computeViewport() {
5 | return {
6 | vw: Math.max(
7 | document.documentElement.clientWidth || 0,
8 | window.innerWidth || 0
9 | ),
10 | vh: Math.max(
11 | document.documentElement.clientHeight || 0,
12 | window.innerHeight || 0
13 | ),
14 |
15 | // https://stackoverflow.com/questions/1145850/how-to-get-height-of-entire-document-with-javascript
16 | dh: Math.max(
17 | document.body.scrollHeight,
18 | document.body.offsetHeight,
19 | document.documentElement.clientHeight,
20 | document.documentElement.scrollHeight,
21 | document.documentElement.offsetHeight
22 | ),
23 | };
24 | }
25 |
26 | function findLeafNodes(node) {
27 | var textNodes = [];
28 | var walk = document.createTreeWalker(
29 | document.body,
30 | NodeFilter.SHOW_TEXT,
31 | null,
32 | false
33 | );
34 |
35 | while (walk.nextNode()) {
36 | var node = walk.currentNode;
37 | var range = document.createRange();
38 | range.selectNodeContents(node);
39 | var rects = Array.from(range.getClientRects());
40 | if (rects.length > 0) {
41 | var box = rects.reduce(function (previousValue, currentValue) {
42 | return {
43 | top: Math.min(previousValue.top, currentValue.top),
44 | right: Math.max(previousValue.right, currentValue.right),
45 | bottom: Math.max(previousValue.bottom, currentValue.bottom),
46 | left: Math.min(previousValue.left, currentValue.left),
47 | };
48 | });
49 |
50 | textNodes.push({
51 | text: node.textContent,
52 | box,
53 | });
54 | }
55 | }
56 |
57 | return textNodes;
58 | }
59 |
60 | function computeBoundingBoxes(node) {
61 | const word_boxes = findLeafNodes(node);
62 | return {
63 | ...computeViewport(),
64 | word_boxes,
65 | };
66 | }
67 |
--------------------------------------------------------------------------------
/src/docquery/ocr_reader.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import logging
3 | from typing import Any, List, Optional, Tuple
4 |
5 | import numpy as np
6 | from pydantic.fields import ModelField
7 |
8 |
9 | class NoOCRReaderFound(Exception):
10 | def __init__(self, e):
11 | self.e = e
12 |
13 | def __str__(self):
14 | return f"Could not load OCR Reader: {self.e}"
15 |
16 |
17 | OCR_AVAILABLE = {
18 | "tesseract": False,
19 | "easyocr": False,
20 | "dummy": True,
21 | }
22 |
23 | try:
24 | import pytesseract # noqa
25 |
26 | pytesseract.get_tesseract_version()
27 | OCR_AVAILABLE["tesseract"] = True
28 | except ImportError:
29 | pass
30 | except pytesseract.TesseractNotFoundError as e:
31 | logging.warning("Unable to find tesseract: %s." % (e))
32 | pass
33 |
34 | try:
35 | import easyocr # noqa
36 |
37 | OCR_AVAILABLE["easyocr"] = True
38 | except ImportError:
39 | pass
40 |
41 |
42 | class SingletonMeta(type):
43 | _instances = {}
44 |
45 | def __call__(cls, *args, **kwargs):
46 | if cls not in cls._instances:
47 | instance = super().__call__(*args, **kwargs)
48 | cls._instances[cls] = instance
49 | return cls._instances[cls]
50 |
51 |
52 | class OCRReader(metaclass=SingletonMeta):
53 | def __init__(self):
54 | # TODO: add device here
55 | self._check_if_available()
56 |
57 | @classmethod
58 | def __get_validators__(cls):
59 | yield cls.validate
60 |
61 | @classmethod
62 | def validate(cls, v, field: ModelField):
63 | if not isinstance(v, cls):
64 | raise TypeError("Invalid value")
65 | return v
66 |
67 | @abc.abstractmethod
68 | def apply_ocr(self, image: "Image.Image") -> Tuple[List[Any], List[List[int]]]:
69 | raise NotImplementedError
70 |
71 | @staticmethod
72 | @abc.abstractmethod
73 | def _check_if_available():
74 | raise NotImplementedError
75 |
76 |
77 | class TesseractReader(OCRReader):
78 | def __init__(self):
79 | super().__init__()
80 |
81 | def apply_ocr(self, image: "Image.Image") -> Tuple[List[str], List[List[int]]]:
82 | """
83 |         Applies Tesseract on a document image, and returns recognized words + bounding boxes in pixel coordinates.
84 |         This was derived from LayoutLM preprocessing code in Hugging Face's Transformers library.
85 | """
86 | data = pytesseract.image_to_data(image, output_type="dict")
87 | words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"]
88 |
89 | # filter empty words and corresponding coordinates
90 | irrelevant_indices = set(idx for idx, word in enumerate(words) if not word.strip())
91 | words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices]
92 | left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices]
93 | top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices]
94 | width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices]
95 | height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices]
96 |
97 | # turn coordinates into (left, top, left+width, top+height) format
98 | actual_boxes = [[x, y, x + w, y + h] for x, y, w, h in zip(left, top, width, height)]
99 |
100 | return words, actual_boxes
101 |
102 | @staticmethod
103 | def _check_if_available():
104 | if not OCR_AVAILABLE["tesseract"]:
105 | raise NoOCRReaderFound(
106 | "Unable to use pytesseract (OCR will be unavailable). Install tesseract to process images with OCR."
107 | )
108 |
109 |
110 | class EasyOCRReader(OCRReader):
111 | def __init__(self):
112 | super().__init__()
113 | self.reader = None
114 |
115 | def apply_ocr(self, image: "Image.Image") -> Tuple[List[str], List[List[int]]]:
116 |         """Applies EasyOCR on a document image, and returns recognized words + bounding boxes in pixel coordinates."""
117 | if not self.reader:
118 | # TODO: expose language currently setting to english
119 | self.reader = easyocr.Reader(["en"]) # TODO: device here example: gpu=self.device > -1)
120 |
121 | # apply OCR
122 | data = self.reader.readtext(np.array(image))
123 | boxes, words, acc = list(map(list, zip(*data)))
124 |
125 | # filter empty words and corresponding coordinates
126 | irrelevant_indices = set(idx for idx, word in enumerate(words) if not word.strip())
127 | words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices]
128 | boxes = [coords for idx, coords in enumerate(boxes) if idx not in irrelevant_indices]
129 |
130 |         # turn easyocr corner points into (left, top, right, bottom) format
131 | actual_boxes = [tl + br for tl, tr, br, bl in boxes]
132 |
133 | return words, actual_boxes
134 |
135 | @staticmethod
136 | def _check_if_available():
137 | if not OCR_AVAILABLE["easyocr"]:
138 | raise NoOCRReaderFound(
139 | "Unable to use easyocr (OCR will be unavailable). Install easyocr to process images with OCR."
140 | )
141 |
142 |
143 | class DummyOCRReader(OCRReader):
144 | def __init__(self):
145 | super().__init__()
146 | self.reader = None
147 |
148 |     def apply_ocr(self, image: "Image.Image") -> Tuple[List[str], List[List[int]]]:
149 | raise NoOCRReaderFound("Unable to find any OCR reader and OCR extraction was requested")
150 |
151 | @staticmethod
152 | def _check_if_available():
153 | logging.warning("Falling back to a dummy OCR reader since none were found.")
154 |
155 |
156 | OCR_MAPPING = {
157 | "tesseract": TesseractReader,
158 | "easyocr": EasyOCRReader,
159 | "dummy": DummyOCRReader,
160 | }
161 |
162 |
163 | def get_ocr_reader(ocr_reader_name: Optional[str] = None):
164 | if not ocr_reader_name:
165 | for name, reader in OCR_MAPPING.items():
166 | if OCR_AVAILABLE[name]:
167 | return reader()
168 |
169 | if ocr_reader_name in OCR_MAPPING.keys():
170 | if OCR_AVAILABLE[ocr_reader_name]:
171 | return OCR_MAPPING[ocr_reader_name]()
172 | else:
173 |             raise NoOCRReaderFound(f"Failed to load: {ocr_reader_name}. Please make sure it is installed correctly.")
174 | else:
175 | raise NoOCRReaderFound(
176 | f"Failed to find: {ocr_reader_name} in the available ocr libraries. The choices are: {list(OCR_MAPPING.keys())}"
177 | )
178 |
--------------------------------------------------------------------------------
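A hedged sketch of picking and using an OCR backend through `get_ocr_reader()` from src/docquery/ocr_reader.py above; `scan.png` is a hypothetical image path.

```python
# Hedged usage sketch (not part of the package); "scan.png" is a hypothetical image path.
from PIL import Image

from docquery.ocr_reader import get_ocr_reader

reader = get_ocr_reader()              # first available backend: tesseract, then easyocr, then dummy
# reader = get_ocr_reader("easyocr")   # or request a specific backend by name

image = Image.open("scan.png")
words, boxes = reader.apply_ocr(image)  # boxes are [left, top, right, bottom] in pixel coordinates
```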
/src/docquery/transformers_patch.py:
--------------------------------------------------------------------------------
1 | # This file contains extensions to transformers that have not yet been upstreamed. Importantly, since docquery
2 | # is designed to be easy to install via PyPI, we must extend anything that is not part of an official release,
3 | # since libraries on PyPI are not permitted to install specific git commits.
4 |
5 | from collections import OrderedDict
6 | from typing import Optional, Union
7 |
8 | import torch
9 | from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
10 | from transformers import pipeline as transformers_pipeline
11 | from transformers.models.auto.auto_factory import _BaseAutoModelClass, _LazyAutoMapping
12 | from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
13 | from transformers.pipelines import PIPELINE_REGISTRY
14 |
15 | from .ext.model import LayoutLMForQuestionAnswering
16 | from .ext.pipeline_document_classification import DocumentClassificationPipeline
17 | from .ext.pipeline_document_question_answering import DocumentQuestionAnsweringPipeline
18 |
19 |
20 | PIPELINE_DEFAULTS = {
21 | "document-question-answering": "impira/layoutlm-document-qa",
22 | "document-classification": "impira/layoutlm-document-classifier",
23 | }
24 |
25 | # These revisions are pinned so that the "default" experience in DocQuery is both fast (does not
26 | # need to check network for updates) and versioned (we can be sure that the model changes
27 | # result in new versions of DocQuery). This may eventually change.
28 | DEFAULT_REVISIONS = {
29 | "impira/layoutlm-document-qa": "ff904df",
30 | "impira/layoutlm-invoices": "783b0c2",
31 | "naver-clova-ix/donut-base-finetuned-rvlcdip": "5998e9f",
32 | # XXX add impira-document-classifier
33 | }
34 |
35 | MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
36 | [
37 | ("layoutlm", "LayoutLMForQuestionAnswering"),
38 | ("donut-swin", "DonutSwinModel"),
39 | ]
40 | )
41 |
42 | MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
43 | CONFIG_MAPPING_NAMES,
44 | MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
45 | )
46 |
47 |
48 | MODEL_FOR_DOCUMENT_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
49 | [
50 | ("layoutlm", "LayoutLMForSequenceClassification"),
51 | ]
52 | )
53 |
54 | MODEL_FOR_DOCUMENT_CLASSIFICATION_MAPPING = _LazyAutoMapping(
55 | CONFIG_MAPPING_NAMES,
56 | MODEL_FOR_DOCUMENT_CLASSIFICATION_MAPPING_NAMES,
57 | )
58 |
59 |
60 | class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
61 | _model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING
62 |
63 |
64 | class AutoModelForDocumentClassification(_BaseAutoModelClass):
65 | _model_mapping = MODEL_FOR_DOCUMENT_CLASSIFICATION_MAPPING
66 |
67 |
68 | PIPELINE_REGISTRY.register_pipeline(
69 | "document-question-answering",
70 | pipeline_class=DocumentQuestionAnsweringPipeline,
71 | pt_model=AutoModelForDocumentQuestionAnswering,
72 | )
73 |
74 | PIPELINE_REGISTRY.register_pipeline(
75 | "document-classification",
76 | pipeline_class=DocumentClassificationPipeline,
77 | pt_model=AutoModelForDocumentClassification,
78 | )
79 |
80 |
81 | def pipeline(
82 | task: str = None,
83 | model: Optional = None,
84 | tokenizer: Optional[Union[str, PreTrainedTokenizer, PreTrainedTokenizerFast]] = None,
85 | revision: Optional[str] = None,
86 | device: Optional[Union[int, str, "torch.device"]] = None,
87 | **pipeline_kwargs
88 | ):
89 |
90 | if model is None and task is not None:
91 | model = PIPELINE_DEFAULTS.get(task)
92 |
93 | if revision is None and model is not None:
94 | revision = DEFAULT_REVISIONS.get(model)
95 |
96 | # We need to explicitly check for the impira/layoutlm-document-qa model because of challenges with
97 | # registering an existing model "flavor" (layoutlm) within transformers after the fact. There may
98 | # be a clever way to get around this. Either way, we should be able to remove it once
99 | # https://github.com/huggingface/transformers/commit/5c4c869014f5839d04c1fd28133045df0c91fd84
100 | # is officially released.
101 | config = AutoConfig.from_pretrained(model, revision=revision, **{**pipeline_kwargs})
102 |
103 | if tokenizer is None:
104 | tokenizer = AutoTokenizer.from_pretrained(
105 | model,
106 | revision=revision,
107 | config=config,
108 | **pipeline_kwargs,
109 | )
110 |
111 | if any(a == "LayoutLMForQuestionAnswering" for a in config.architectures):
112 | model = LayoutLMForQuestionAnswering.from_pretrained(
113 | model, config=config, revision=revision, **{**pipeline_kwargs}
114 | )
115 |
116 | if config.model_type == "vision-encoder-decoder":
117 | # This _should_ be a feature of transformers -- deriving the feature_extractor automatically --
118 | # but is not at the time of writing, so we do it explicitly.
119 | pipeline_kwargs["feature_extractor"] = model
120 |
121 | if device is None:
122 | # This trick merely simplifies the device argument, so that cuda is used by default if it's
123 | # available, which at the time of writing is not a feature of transformers
124 | device = 0 if torch.cuda.is_available() else -1
125 |
126 | return transformers_pipeline(
127 | task,
128 | revision=revision,
129 | model=model,
130 | tokenizer=tokenizer,
131 | device=device,
132 | **pipeline_kwargs,
133 | )
134 |
--------------------------------------------------------------------------------
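A hedged sketch of what the `pipeline()` wrapper in src/docquery/transformers_patch.py above fills in when only a task is given: the default checkpoint, its pinned revision, and a device (CUDA if available, otherwise CPU).

```python
# Hedged sketch (not part of the package): defaults are resolved from the tables at the top of the file.
from docquery.transformers_patch import pipeline

nlp = pipeline("document-question-answering")
# resolves roughly to:
#   pipeline("document-question-answering", model="impira/layoutlm-document-qa", revision="ff904df")
```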
/src/docquery/version.py:
--------------------------------------------------------------------------------
1 | VERSION = "0.0.7"
2 |
--------------------------------------------------------------------------------
/src/docquery/web.py:
--------------------------------------------------------------------------------
1 | import os
2 | from io import BytesIO
3 | from pathlib import Path
4 |
5 | from PIL import Image
6 |
7 | from .config import get_logger
8 | from .ext.functools import cached_property
9 |
10 |
11 | log = get_logger("web")
12 |
13 | try:
14 | from selenium import webdriver
15 | from selenium.common import exceptions
16 | from selenium.webdriver.chrome.options import Options
17 | from webdriver_manager.chrome import ChromeDriverManager
18 | from webdriver_manager.core.utils import ChromeType
19 |
20 | WEB_AVAILABLE = True
21 | except ImportError:
22 | WEB_AVAILABLE = False
23 |
24 |
25 | FIND_LEAF_NODES_JS = None
26 | WEB_DRIVER = None
27 | dir_path = Path(os.path.dirname(os.path.realpath(__file__)))
28 |
29 |
30 | class WebDriver:
31 | def __init__(self):
32 | if not WEB_AVAILABLE:
33 | raise ValueError(
34 | "Web imports are unavailable. You must install the [web] extra and chrome or" " chromium system-wide."
35 | )
36 |
37 | self._reinit_driver()
38 |
39 | def _reinit_driver(self):
40 | options = Options()
41 | options.headless = True
42 | options.add_argument("--window-size=1920,1200")
43 | if os.geteuid() == 0:
44 | options.add_argument("--no-sandbox")
45 |
46 | self.driver = webdriver.Chrome(
47 | options=options, executable_path=ChromeDriverManager(chrome_type=ChromeType.CHROMIUM).install()
48 | )
49 |
50 | def get(self, page, retry=True):
51 | try:
52 | self.driver.get(page)
53 | except exceptions.InvalidSessionIdException:
54 | if retry:
55 | # Forgive an invalid session once and try again
56 | self._reinit_driver()
57 | return self.get(page, retry=False)
58 | else:
59 | raise
60 |
61 | def get_html(self, html):
62 | # https://stackoverflow.com/questions/22538457/put-a-string-with-html-javascript-into-selenium-webdriver
63 | self.get("data:text/html;charset=utf-8," + html)
64 |
65 | def find_word_boxes(self):
66 | # Assumes the driver has been pointed at the right website already
67 | return self.driver.execute_script(
68 | self.lib_js
69 | + """
70 | return computeBoundingBoxes(document.body);
71 | """
72 | )
73 |
74 | # TODO: Handle horizontal scrolling
75 | def scroll_and_screenshot(self):
76 | tops = []
77 | images = []
78 | dims = self.driver.execute_script(
79 | self.lib_js
80 | + """
81 | return computeViewport()
82 | """
83 | )
84 |
85 | view_height = dims["vh"]
86 | doc_height = dims["dh"]
87 |
88 | try:
89 | self.driver.execute_script("window.scroll(0, 0)")
90 | curr = self.driver.execute_script("return window.scrollY")
91 |
92 | while True:
93 | tops.append(curr)
94 | images.append(Image.open(BytesIO(self.driver.get_screenshot_as_png())))
95 | if curr + view_height < doc_height:
96 | self.driver.execute_script(f"window.scroll(0, {curr+view_height})")
97 |
98 | curr = self.driver.execute_script("return window.scrollY")
99 | if curr <= tops[-1]:
100 | break
101 | finally:
102 | # Reset scroll to the top of the page
103 | self.driver.execute_script("window.scroll(0, 0)")
104 |
105 | if len(tops) >= 2:
106 | _, second_last_height = images[-2].size
107 | if tops[-1] - tops[-2] < second_last_height:
108 | # This occurs when the last screenshot should be "clipped". Adjust the last "top"
109 | # to correspond to the right view_height and clip the screenshot accordingly
110 | delta = second_last_height - (tops[-1] - tops[-2])
111 | tops[-1] += delta
112 |
113 | last_img = images[-1]
114 | last_width, last_height = last_img.size
115 | images[-1] = last_img.crop((0, delta, last_width, last_height))
116 |
117 | return tops, images
118 |
119 | @cached_property
120 | def lib_js(self):
121 | with open(dir_path / "find_leaf_nodes.js", "r") as f:
122 | return f.read()
123 |
124 |
125 | def get_webdriver():
126 | global WEB_DRIVER
127 | if WEB_DRIVER is None:
128 | WEB_DRIVER = WebDriver()
129 | return WEB_DRIVER
130 |
--------------------------------------------------------------------------------
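A hedged sketch of driving the headless-browser helpers in src/docquery/web.py above; it requires the [web] extra plus Chrome or Chromium, and the URL mirrors the one used in tests/test_web_driver.py.

```python
# Hedged usage sketch (not part of the package); requires the [web] extra and Chrome/Chromium.
from docquery.web import get_webdriver

driver = get_webdriver()  # singleton WebDriver instance
driver.get("https://github.com/impira/docquery/blob/ef73fa7e8069773ace03efae2254f3a510a814ef/README.md")

word_boxes = driver.find_word_boxes()               # {"vw": ..., "vh": ..., "dh": ..., "word_boxes": [...]}
tops, screenshots = driver.scroll_and_screenshot()  # scroll offsets + one PIL screenshot per viewport
```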
/tests/test_classification_end_to_end.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 |
3 | import pytest
4 | from pydantic import BaseModel
5 | from transformers.testing_utils import nested_simplify
6 |
7 | from docquery import pipeline
8 | from docquery.document import load_document
9 | from docquery.ocr_reader import TesseractReader
10 |
11 |
12 | CHECKPOINTS = {
13 | "Donut": "naver-clova-ix/donut-base-finetuned-rvlcdip",
14 | }
15 |
16 |
17 | class Example(BaseModel):
18 | name: str
19 | path: str
20 | classes: Dict[str, List[str]]
21 |
22 |
23 | # Use the examples from the DocQuery space (this also solves for hosting)
24 | EXAMPLES = [
25 | Example(
26 | name="contract",
27 | path="https://huggingface.co/spaces/impira/docquery/resolve/2f6c96314dc84dfda62d40de9da55f2f5165d403/contract.jpeg",
28 | classes={"Donut": ["scientific_report"]},
29 | ),
30 | Example(
31 | name="invoice",
32 | path="https://huggingface.co/spaces/impira/docquery/resolve/2f6c96314dc84dfda62d40de9da55f2f5165d403/invoice.png",
33 | classes={"Donut": ["invoice"]},
34 | ),
35 | Example(
36 | name="statement",
37 | path="https://huggingface.co/spaces/impira/docquery/resolve/2f6c96314dc84dfda62d40de9da55f2f5165d403/statement.png",
38 | classes={"Donut": ["budget"]},
39 | ),
40 | ]
41 |
42 |
43 | @pytest.mark.parametrize("example", EXAMPLES)
44 | @pytest.mark.parametrize("model", CHECKPOINTS.keys())
45 | def test_impira_dataset(example, model):
46 | document = load_document(example.path)
47 | pipe = pipeline("document-classification", model=CHECKPOINTS[model])
48 | resp = pipe(top_k=1, **document.context)
49 | assert resp == [{"label": x} for x in example.classes[model]]
50 |
--------------------------------------------------------------------------------
/tests/test_end_to_end.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 |
3 | import pytest
4 | from pydantic import BaseModel
5 | from transformers.testing_utils import nested_simplify
6 |
7 | from docquery import pipeline
8 | from docquery.document import load_document
9 | from docquery.ocr_reader import TesseractReader
10 |
11 |
12 | CHECKPOINTS = {
13 | "LayoutLMv1": "impira/layoutlm-document-qa",
14 | "LayoutLMv1-Invoices": "impira/layoutlm-invoices",
15 | "Donut": "naver-clova-ix/donut-base-finetuned-docvqa",
16 | }
17 |
18 |
19 | class QAPair(BaseModel):
20 | question: str
21 | answers: Dict[str, List[Dict]]
22 |
23 |
24 | class Example(BaseModel):
25 | name: str
26 | path: str
27 | qa_pairs: List[QAPair]
28 |
29 |
30 | # Use the examples from the DocQuery space (this also solves for hosting)
31 | EXAMPLES = [
32 | Example(
33 | name="contract",
34 | path="https://huggingface.co/spaces/impira/docquery/resolve/2f6c96314dc84dfda62d40de9da55f2f5165d403/contract.jpeg",
35 | qa_pairs=[
36 | {
37 | "question": "What is the purchase amount?",
38 | "answers": {
39 | "LayoutLMv1": [{"score": 0.9999, "answer": "$1,000,000,000", "word_ids": [97], "page": 0}],
40 | "LayoutLMv1-Invoices": [
41 | {"score": 0.9997, "answer": "$1,000,000,000", "word_ids": [97], "page": 0}
42 | ],
43 | "Donut": [{"answer": "$1,0000,000,00"}],
44 | },
45 | }
46 | ],
47 | ),
48 | Example(
49 | name="invoice",
50 | path="https://huggingface.co/spaces/impira/docquery/resolve/2f6c96314dc84dfda62d40de9da55f2f5165d403/invoice.png",
51 | qa_pairs=[
52 | {
53 | "question": "What is the invoice number?",
54 | "answers": {
55 | "LayoutLMv1": [{"score": 0.9997, "answer": "us-001", "word_ids": [15], "page": 0}],
56 | "LayoutLMv1-Invoices": [{"score": 0.9999, "answer": "us-001", "word_ids": [15], "page": 0}],
57 | "Donut": [{"answer": "us-001"}],
58 | },
59 | }
60 | ],
61 | ),
62 | Example(
63 | name="statement",
64 | path="https://huggingface.co/spaces/impira/docquery/resolve/2f6c96314dc84dfda62d40de9da55f2f5165d403/statement.pdf",
65 | qa_pairs=[
66 | {
67 | "question": "What are net sales for 2020?",
68 | "answers": {
69 | "LayoutLMv1": [{"score": 0.9429, "answer": "$ 3,750\n", "word_ids": [15, 16], "page": 0}],
70 | # (The answer with `use_embedded_text=False` relies entirely on Tesseract, and it is incorrect because it
71 | # misses 3,750 altogether.)
72 | "LayoutLMv1__use_embedded_text=False": [
73 | {"score": 0.3078, "answer": "$ 3,980", "word_ids": [11, 12], "page": 0}
74 | ],
75 | "LayoutLMv1-Invoices": [{"score": 0.9956, "answer": "$ 3,750\n", "word_ids": [15, 16], "page": 0}],
76 | "Donut": [{"answer": "$ 3,750"}],
77 | },
78 | }
79 | ],
80 | ),
81 | Example(
82 | name="readme",
83 | path="https://github.com/impira/docquery/blob/ef73fa7e8069773ace03efae2254f3a510a814ef/README.md",
84 | qa_pairs=[
85 | {
86 | "question": "What are the use cases for DocQuery?",
87 | "answers": {
88 | # These examples demonstrate the fact that the "word_boxes" are way too coarse in the web document implementation
89 | "LayoutLMv1": [
90 | {
91 | "score": 0.9921,
92 | "answer": "DocQuery is a swiss army knife tool for working with documents and experiencing the power of modern machine learning. You can use it\njust about anywhere, including behind a firewall on sensitive data, and test it with a wide variety of documents. Our hope is that\nDocQuery enables many creative use cases for document understanding by making it simple and easy to ask questions from your documents.",
93 | "word_ids": [37],
94 | "page": 2,
95 | }
96 | ],
97 | "LayoutLMv1-Invoices": [
98 | {
99 | "score": 0.9939,
100 | "answer": "DocQuery is a library and command-line tool that makes it easy to analyze semi-structured and unstructured documents (PDFs, scanned\nimages, etc.) using large language models (LLMs). You simply point DocQuery at one or more documents and specify a\nquestion you want to ask. DocQuery is created by the team at ",
101 | "word_ids": [98],
102 | "page": 0,
103 | }
104 | ],
105 | "Donut": [{"answer": "engine Powered by large language"}],
106 | },
107 | }
108 | ],
109 | ),
110 | ]
111 |
112 |
113 | @pytest.mark.parametrize("example", EXAMPLES)
114 | @pytest.mark.parametrize("model", CHECKPOINTS.keys())
115 | def test_impira_dataset(example, model):
116 | document = load_document(example.path)
117 | pipe = pipeline("document-question-answering", model=CHECKPOINTS[model])
118 | for qa in example.qa_pairs:
119 | resp = pipe(question=qa.question, **document.context, top_k=1)
120 | assert nested_simplify(resp, decimals=4) == qa.answers[model]
121 |
122 |
123 | def test_run_with_chosen_OCR_str():
124 | example = EXAMPLES[0]
125 | document = load_document(example.path, "tesseract")
126 | pipe = pipeline("document-question-answering", model=CHECKPOINTS["LayoutLMv1"])
127 | for qa in example.qa_pairs:
128 | resp = pipe(question=qa.question, **document.context, top_k=1)
129 | assert nested_simplify(resp, decimals=4) == qa.answers["LayoutLMv1"]
130 |
131 |
132 | def test_run_with_chosen_OCR_instance():
133 | example = EXAMPLES[0]
134 | reader = TesseractReader()
135 | document = load_document(example.path, reader)
136 | pipe = pipeline("document-question-answering", model=CHECKPOINTS["LayoutLMv1"])
137 | for qa in example.qa_pairs:
138 | resp = pipe(question=qa.question, **document.context, top_k=1)
139 | assert nested_simplify(resp, decimals=4) == qa.answers["LayoutLMv1"]
140 |
141 |
142 | def test_run_with_ignore_embedded_text():
143 | example = EXAMPLES[2]
144 | document = load_document(example.path, use_embedded_text=False)
145 | pipe = pipeline("document-question-answering", model=CHECKPOINTS["LayoutLMv1"])
146 | for qa in example.qa_pairs:
147 | resp = pipe(question=qa.question, **document.context, top_k=1)
148 | assert nested_simplify(resp, decimals=4) == qa.answers["LayoutLMv1__use_embedded_text=False"]
149 |
--------------------------------------------------------------------------------
/tests/test_ocr_reader.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import patch
2 |
3 | import pytest
4 |
5 | from docquery.ocr_reader import DummyOCRReader, EasyOCRReader, NoOCRReaderFound, TesseractReader, get_ocr_reader
6 |
7 |
8 | READER_PERMUTATIONS = [
9 | {"name": "tesseract", "reader_class": TesseractReader},
10 | {"name": "easyocr", "reader_class": EasyOCRReader},
11 | {"name": "dummy", "reader_class": DummyOCRReader},
12 | ]
13 |
14 |
15 | @pytest.mark.parametrize("reader_permutations", READER_PERMUTATIONS)
16 | @patch(
17 | "docquery.ocr_reader.OCR_AVAILABLE",
18 | {
19 | "tesseract": True,
20 | "easyocr": True,
21 | "dummy": True,
22 | },
23 | )
24 | def test_get_ocr_reader(reader_permutations):
25 | reader = get_ocr_reader(reader_permutations["name"])
26 | assert isinstance(reader, reader_permutations["reader_class"])
27 |
28 |
29 | @patch(
30 | "docquery.ocr_reader.OCR_AVAILABLE",
31 | {
32 | "tesseract": True,
33 | "easyocr": True,
34 | "dummy": True,
35 | },
36 | )
37 | def test_wrong_string_ocr_reader():
38 | with pytest.raises(Exception) as e:
39 | reader = get_ocr_reader("FAKE_OCR")
40 | assert (
41 | "Failed to find: FAKE_OCR in the available ocr libraries. The choices are: ['tesseract', 'easyocr', 'dummy']"
42 | in str(e.value)
43 | )
44 | assert e.type == NoOCRReaderFound
45 |
46 |
47 | @patch(
48 | "docquery.ocr_reader.OCR_AVAILABLE",
49 | {
50 | "tesseract": False,
51 | "easyocr": True,
52 | "dummy": True,
53 | },
54 | )
55 | def test_choosing_unavailable_ocr_reader():
56 | with pytest.raises(Exception) as e:
57 | reader = get_ocr_reader("tesseract")
58 |     assert "Failed to load: tesseract. Please make sure it is installed correctly." in str(e.value)
59 | assert e.type == NoOCRReaderFound
60 |
61 |
62 | @patch(
63 | "docquery.ocr_reader.OCR_AVAILABLE",
64 | {
65 | "tesseract": False,
66 | "easyocr": True,
67 | "dummy": True,
68 | },
69 | )
70 | def test_assert_fallback():
71 | reader = get_ocr_reader()
72 | assert isinstance(reader, EasyOCRReader)
73 |
74 |
75 | @patch(
76 | "docquery.ocr_reader.OCR_AVAILABLE",
77 | {
78 | "tesseract": False,
79 | "easyocr": False,
80 | "dummy": True,
81 | },
82 | )
83 | def test_assert_fallback_to_dummy():
84 | reader = get_ocr_reader()
85 | assert isinstance(reader, DummyOCRReader)
86 |
87 |
88 | @patch(
89 | "docquery.ocr_reader.OCR_AVAILABLE",
90 | {
91 | "tesseract": False,
92 | "easyocr": False,
93 | "dummy": False,
94 | },
95 | )
96 | def test_fail_to_load_if_called_directly_when_ocr_unavailable():
97 | EasyOCRReader._instances = {}
98 | with pytest.raises(Exception) as e:
99 | reader = EasyOCRReader()
100 | assert "Unable to use easyocr (OCR will be unavailable). Install easyocr to process images with OCR." in str(
101 | e.value
102 | )
103 | assert e.type == NoOCRReaderFound
104 |
105 |
106 | def test_ocr_reader_are_singletons():
107 | reader_a = DummyOCRReader()
108 | reader_b = DummyOCRReader()
109 | reader_c = DummyOCRReader()
110 | assert reader_a is reader_b
111 | assert reader_a is reader_c
112 |
--------------------------------------------------------------------------------
/tests/test_web_driver.py:
--------------------------------------------------------------------------------
1 | from docquery.web import get_webdriver
2 |
3 |
4 | def test_singleton():
5 | d1 = get_webdriver()
6 | d2 = get_webdriver()
7 | assert d1 is d2, "Both webdrivers should map to the same instance"
8 |
9 |
10 | def test_readme_file():
11 | driver = get_webdriver()
12 | driver.get("https://github.com/impira/docquery/blob/ef73fa7e8069773ace03efae2254f3a510a814ef/README.md")
13 | word_boxes = driver.find_word_boxes()
14 |
15 | # This sanity checks the logic that merges word boxes
16 | assert len(word_boxes["word_boxes"]) > 20, "Expect multiple word boxes"
17 |
18 | # Make sure the last screenshot is shorter than the previous ones
19 | _, screenshots = driver.scroll_and_screenshot()
20 | assert len(screenshots) > 1, "Expect multiple pages"
21 | assert (
22 | screenshots[0].size[1] - screenshots[-1].size[1] > 10
23 | ), "Expect the last page to be shorter than the first several"
24 |
--------------------------------------------------------------------------------