├── .flake8 ├── .github └── workflows │ └── pre-commit.yml ├── .gitignore ├── .isort.cfg ├── .pre-commit-config.yaml ├── LICENSE ├── Makefile ├── README.md ├── docquery_example.ipynb ├── pyproject.toml ├── setup.py ├── src └── docquery │ ├── __init__.py │ ├── cmd │ ├── __init__.py │ ├── __main__.py │ └── scan.py │ ├── config.py │ ├── document.py │ ├── ext │ ├── __init__.py │ ├── functools.py │ ├── itertools.py │ ├── model.py │ ├── pipeline_document_classification.py │ ├── pipeline_document_question_answering.py │ └── qa_helpers.py │ ├── find_leaf_nodes.js │ ├── ocr_reader.py │ ├── transformers_patch.py │ ├── version.py │ └── web.py └── tests ├── test_classification_end_to_end.py ├── test_end_to_end.py ├── test_ocr_reader.py └── test_web_driver.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 119 3 | ignore = E402, E203, E501, W503 4 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | # https://pre-commit.com 2 | # This GitHub Action assumes that the repo contains a valid .pre-commit-config.yaml file. 3 | # Using pre-commit.ci is even better that using GitHub Actions for pre-commit. 4 | name: pre-commit 5 | on: 6 | pull_request: 7 | branches: [main] 8 | push: 9 | branches: [main] 10 | workflow_dispatch: 11 | jobs: 12 | pre-commit: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v3 16 | - uses: actions/setup-python@v4 17 | with: 18 | python-version: 3.x 19 | - run: pip install pre-commit 20 | - run: pre-commit --version 21 | - run: pre-commit install 22 | - run: pre-commit run --all-files 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | venv 2 | *.swp 3 | *.swo 4 | *.pyc 5 | .DS_Store 6 | __pycache__ 7 | dist 8 | src/docquery.egg-info 9 | docs 10 | .vscode/settings.json 11 | build 12 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | line_length=119 3 | multi_line_output=3 4 | use_parentheses=true 5 | lines_after_imports=2 6 | include_trailing_comma=True 7 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: "https://github.com/pre-commit/pre-commit-hooks" 3 | rev: v4.3.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - repo: "https://github.com/psf/black" 9 | rev: 22.6.0 10 | hooks: 11 | - id: black 12 | files: ./ 13 | - repo: "https://github.com/PyCQA/isort" 14 | rev: 5.10.1 15 | hooks: 16 | - id: isort 17 | args: 18 | - --settings-path 19 | - .isort.cfg 20 | files: ./ 21 | - repo: https://github.com/codespell-project/codespell 22 | rev: v2.2.1 23 | hooks: 24 | - id: codespell 25 | 26 | - repo: https://github.com/pre-commit/mirrors-prettier 27 | rev: v2.7.1 28 | hooks: 29 | - id: prettier 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 impira 4 | 5 | Permission is hereby granted, free of charge, 
to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: build 2 | 3 | VERSION=$(shell python -c 'from src.docquery.version import VERSION; print(VERSION)') 4 | 5 | .PHONY: build publish clean 6 | build: 7 | python3 -m build 8 | 9 | publish: build 10 | python3 -m twine upload dist/docquery-${VERSION}* 11 | 12 | clean: 13 | rm -rf dist/* 14 | 15 | 16 | VENV_INITIALIZED := venv/.initialized 17 | 18 | ${VENV_INITIALIZED}: 19 | rm -rf venv && python3 -m venv venv 20 | @touch ${VENV_INITIALIZED} 21 | 22 | VENV_PYTHON_PACKAGES := venv/.python_packages 23 | 24 | ${VENV_PYTHON_PACKAGES}: ${VENV_INITIALIZED} setup.py 25 | bash -c 'source venv/bin/activate && python -m pip install --upgrade pip setuptools' 26 | bash -c 'source venv/bin/activate && python -m pip install -e .[dev]' 27 | @touch $@ 28 | 29 | VENV_PRE_COMMIT := venv/.pre_commit 30 | 31 | ${VENV_PRE_COMMIT}: ${VENV_PYTHON_PACKAGES} 32 | bash -c 'source venv/bin/activate && pre-commit install' 33 | @touch $@ 34 | 35 | .PHONY: develop fixup test 36 | develop: ${VENV_PRE_COMMIT} 37 | @echo 'Run "source venv/bin/activate" to enter development mode' 38 | 39 | fixup: 40 | pre-commit run --all-files 41 | 42 | test: 43 | python -m pytest -s -v ./tests/ 44 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | NOTE: DocQuery is not actively maintained anymore. We still welcome contributions and discussions among the community! 4 | 5 | # DocQuery: Document Query Engine Powered by Large Language Models 6 | 7 | [![Demo](https://img.shields.io/badge/Demo-Gradio-brightgreen)](https://huggingface.co/spaces/impira/docquery) 8 | [![Demo](https://img.shields.io/badge/Demo-Colab-orange)](https://github.com/impira/docquery/blob/main/docquery_example.ipynb) 9 | [![PyPI](https://img.shields.io/pypi/v/docquery?color=green&label=pip%20install%20docquery)](https://pypi.org/project/docquery/) 10 | [![Discord](https://img.shields.io/discord/1015684761471160402?label=Chat)](https://discord.gg/HucNfTtx7V) 11 | [![Downloads](https://static.pepy.tech/personalized-badge/docquery?period=total&units=international_system&left_color=grey&right_color=green&left_text=Downloads)](https://pepy.tech/project/docquery) 12 | 13 |
14 | 15 | DocQuery is a library and command-line tool that makes it easy to analyze semi-structured and unstructured documents (PDFs, scanned 16 | images, etc.) using large language models (LLMs). You simply point DocQuery at one or more documents and specify a 17 | question you want to ask. DocQuery is created by the team at [Impira](https://impira.com?utm_source=github&utm_medium=referral&utm_campaign=docquery). 18 | 19 | ## Quickstart (CLI) 20 | 21 | To install `docquery`, you can simply run `pip install docquery`. This will install the command line tool as well as the library. 22 | If you want to run OCR on images, then you must also install the [tesseract](https://github.com/tesseract-ocr/tesseract) library: 23 | 24 | - Mac OS X (using [Homebrew](https://brew.sh/)): 25 | 26 | ```sh 27 | brew install tesseract 28 | ``` 29 | 30 | - Ubuntu: 31 | 32 | ```sh 33 | apt install tesseract-ocr 34 | ``` 35 | 36 | `docquery scan` allows you to ask one or more questions of a single document or directory of files. For example, you can 37 | find the invoice number with: 38 | 39 | ```bash 40 | docquery scan "What is the invoice number?" https://templates.invoicehome.com/invoice-template-us-neat-750px.png 41 | ``` 42 | 43 | If you have a folder of documents on your machine, you can run something like 44 | 45 | ```bash 46 | docquery scan "What is the effective date?" /path/to/contracts/folder 47 | ``` 48 | 49 | to determine the effective date of every document in the folder. 50 | 51 | ## Quickstart (Library) 52 | 53 | DocQuery can also be used as a library. It contains two basic abstractions: (1) a `DocumentQuestionAnswering` pipeline 54 | that makes it simple to ask questions of documents and (2) a `Document` abstraction that can parse various types of documents 55 | to feed into the pipeline. 56 | 57 | ```python 58 | >>> from docquery import document, pipeline 59 | >>> p = pipeline('document-question-answering') 60 | >>> doc = document.load_document("/path/to/document.pdf") 61 | >>> for q in ["What is the invoice number?", "What is the invoice total?"]: 62 | ...     print(q, p(question=q, **doc.context)) 63 | ``` 64 | 65 | ## Use cases 66 | 67 | DocQuery excels at a number of use cases involving structured, semi-structured, or unstructured documents. You can ask questions about 68 | invoices, contracts, forms, emails, letters, receipts, and many more. You can also classify documents. We will continue evolving the model, 69 | offering more modeling options, and expanding the set of supported documents. We welcome feedback, requests, and of course contributions to 70 | help achieve this vision. 71 | 72 | ## How it works 73 | 74 | Under the hood, DocQuery uses a pre-trained zero-shot language model, based on [LayoutLM](https://arxiv.org/abs/1912.13318), that has been 75 | fine-tuned for a question-answering task. The model is trained using a combination of [SQuAD2.0](https://rajpurkar.github.io/SQuAD-explorer/) 76 | and [DocVQA](https://rrc.cvc.uab.es/?ch=17), which makes it particularly well suited for complex visual question answering tasks on 77 | a wide variety of documents. The underlying model is also published on HuggingFace as [impira/layoutlm-document-qa](https://huggingface.co/impira/layoutlm-document-qa), 78 | which you can access directly. 79 | 80 | ## Limitations 81 | 82 | DocQuery is intended to have a small install footprint and be simple to work with. As a result, it has some limitations: 83 | 84 | - Models must be pre-trained.
Although DocQuery uses a zero-shot model that can adapt based on the question you provide, it does not learn from your data. 85 | - Support for images and PDFs. Currently DocQuery supports images and PDFs, with or without embedded text. It does not support Word documents, emails, spreadsheets, etc. 86 | - Scalar text outputs. DocQuery only produces text outputs (answers). It does not support richer scalar types (i.e. it treats numbers and dates as strings) or tables. 87 | 88 | ## Advanced features 89 | 90 | ### Using Donut 🍩 91 | 92 | If you'd like to test `docquery` with [Donut](https://arxiv.org/abs/2111.15664), you must install the required extras: 93 | 94 | ```bash 95 | pip install docquery[donut] 96 | ``` 97 | 98 | You can then run 99 | 100 | ```bash 101 | docquery scan "What is the effective date?" /path/to/contracts/folder --checkpoint 'naver-clova-ix/donut-base-finetuned-docvqa' 102 | ``` 103 | 104 | ### Classifying documents 105 | 106 | To classify documents, you simply add the `--classify` argument to `scan`. You can specify any [image classification](https://huggingface.co/models?pipeline_tag=image-classification&sort=downloads) 107 | model on Hugging Face's hub. By default, the classification pipeline uses [Donut](https://huggingface.co/spaces/nielsr/donut-rvlcdip) (which requires 108 | following the installation instructions above): 109 | 110 | ```bash 111 | 112 | # Classify documents 113 | docquery scan --classify /path/to/contracts/folder --checkpoint 'naver-clova-ix/donut-base-finetuned-docvqa' 114 | 115 | # Classify documents and ask a question too 116 | docquery scan --classify "What is the effective date?" /path/to/contracts/folder --checkpoint 'naver-clova-ix/donut-base-finetuned-docvqa' 117 | ``` 118 | 119 | ### Scraping webpages 120 | 121 | DocQuery can read files through HTTP/HTTPS out of the box. However, if you want to read HTML documents, you can do that too by installing the 122 | `[web]` extension. The extension uses the [webdriver-manager](https://pypi.org/project/webdriver-manager/) library, which can install a Chrome 123 | driver on your system automatically, but you'll need to make sure Chrome is installed globally. 124 | 125 | ``` 126 | # Find the top post on hacker news 127 | docquery scan "What is the #1 post's title?" https://news.ycombinator.com 128 | ``` 129 | 130 | ## Where to go from here 131 | 132 | DocQuery is a Swiss Army knife tool for working with documents and experiencing the power of modern machine learning. You can use it 133 | just about anywhere, including behind a firewall on sensitive data, and test it with a wide variety of documents. Our hope is that 134 | DocQuery enables many creative use cases for document understanding by making it simple and easy to ask questions of your documents. 135 | 136 | When you run DocQuery for the first time, it will download some files (e.g. the models and some library code from HuggingFace). However, 137 | nothing leaves your computer -- the OCR is done locally, models run locally, etc. This comes with the benefit of security and privacy; 138 | however, it comes at the cost of runtime performance and some accuracy. 139 | 140 | If you find yourself wondering how to achieve higher accuracy, work with more file types, teach the model with your own data, have 141 | a human-in-the-loop workflow, or query the data you're extracting, then do not fear -- you are running into the challenges that 142 | every organization does while putting document AI into production.
The [Impira](https://www.impira.com/) platform is designed to 143 | solve these problems in an easy and intuitive way. Impira comes with a QA model that is additionally trained on proprietary datasets 144 | and can achieve 95%+ accuracy out-of-the-box for most use cases. It also has an intuitive UI that enables subject matter experts to label 145 | and improve the models, as well as an API that makes integration a breeze. Please [sign up for the product](https://www.impira.com/signup) or 146 | [reach out to us](mailto:info@impira.com) for more details. 147 | 148 | ## Status 149 | 150 | DocQuery is a new project. Although the underlying models are running in production, we've just recently released our code as open source 151 | and are actively working with the OSS community to upstream some of the changes we've made (e.g. [the model](https://github.com/huggingface/transformers/pull/18407) 152 | and [pipeline](https://github.com/huggingface/transformers/pull/18414)). DocQuery is rapidly changing, and we are likely to make breaking 153 | API changes. If you would like to run it in production, then we suggest pinning a version or commit hash. Either way, please get in touch 154 | with us at [oss@impira.com](mailto:oss@impira.com) with any questions or feedback. 155 | 156 | ## Acknowledgements 157 | 158 | DocQuery would not be possible without the contributions of many open source projects: 159 | 160 | - [pdfplumber](https://github.com/jsvine/pdfplumber) / [pdfminer.six](https://github.com/pdfminer/pdfminer.six) 161 | - [Pillow](https://pillow.readthedocs.io/en/stable/) 162 | - [pytorch](https://pytorch.org/) 163 | - [tesseract](https://github.com/tesseract-ocr/tesseract) / [pytesseract](https://pypi.org/project/pytesseract/) 164 | - [transformers](https://github.com/impira/transformers) 165 | 166 | and many others! 167 | 168 | ## License 169 | 170 | This project is licensed under the [MIT license](LICENSE). 171 | 172 | It contains code that is copied and adapted from transformers (https://github.com/huggingface/transformers), 173 | which is [Apache 2.0 licensed](http://www.apache.org/licenses/LICENSE-2.0). Files containing this code have 174 | been marked as such in their comments.
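## Appendix: classifying documents from the library

The "Classifying documents" section above only shows the CLI flow. Here is a minimal sketch of the same flow from Python, based on how `docquery scan --classify` wires the pipelines together in `src/docquery/cmd/scan.py`; it assumes the default Donut classification checkpoint, so the `[donut]` extra must be installed, and the document path is a placeholder.

```python
from docquery import document, pipeline

# "document-classification" is the same task name the CLI loads for --classify.
classify = pipeline("document-classification")

# Document.context carries the page image(s) and any word boxes the pipeline expects.
doc = document.load_document("/path/to/document.pdf")

# The pipeline returns a list of {"label": ...} predictions; take the first one.
top = classify(**doc.context)[0]
print(top["label"])
```

As with the question-answering quickstart, the pattern is the same: load a `Document`, then pass its `context` into the pipeline.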
175 | -------------------------------------------------------------------------------- /docquery_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "docquery_example.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyNxYAeRjZTeKNeMu6iJvYRw", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | } 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": { 35 | "id": "yS9UNjHnAAS9" 36 | }, 37 | "outputs": [], 38 | "source": [ 39 | "!git clone https://github.com/impira/docquery.git\n", 40 | "!sudo apt install tesseract-ocr\n", 41 | "!sudo apt-get install poppler-utils\n", 42 | "!cd docquery && pip install .[all]" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "source": [ 48 | "!docquery scan \"who authored this paper?\" https://arxiv.org/pdf/2101.07597.pdf" 49 | ], 50 | "metadata": { 51 | "id": "bKRRY5u2DV52" 52 | }, 53 | "execution_count": null, 54 | "outputs": [] 55 | } 56 | ] 57 | } 58 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 119 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import setuptools 4 | 5 | dir_name = os.path.abspath(os.path.dirname(__file__)) 6 | 7 | version_contents = {} 8 | with open(os.path.join(dir_name, "src", "docquery", "version.py"), encoding="utf-8") as f: 9 | exec(f.read(), version_contents) 10 | 11 | with open(os.path.join(dir_name, "README.md"), "r", encoding="utf-8") as f: 12 | long_description = f.read() 13 | 14 | install_requires = [ 15 | "torch >= 1.0", 16 | "pdf2image", 17 | "pdfplumber", 18 | "Pillow", 19 | "pydantic", 20 | "pytesseract", # TODO: Test what happens if the host machine does not have tesseract installed 21 | "requests", 22 | "easyocr", 23 | "transformers >= 4.23", 24 | ] 25 | extras_require = { 26 | "dev": [ 27 | "black", 28 | "build", 29 | "flake8", 30 | "flake8-isort", 31 | "isort==5.10.1", 32 | "pre-commit", 33 | "pytest", 34 | "twine", 35 | ], 36 | "donut": [ 37 | "sentencepiece", 38 | "protobuf<=3.20.1", 39 | ], 40 | "web": [ 41 | "selenium", 42 | "webdriver-manager", 43 | ], 44 | "cli": [], 45 | } 46 | extras_require["all"] = sorted({package for packages in extras_require.values() for package in packages}) 47 | 48 | setuptools.setup( 49 | name="docquery", 50 | version=version_contents["VERSION"], 51 | author="Impira Engineering", 52 | author_email="engineering@impira.com", 53 | description="DocQuery: An easy way to extract information from documents", 54 | long_description=long_description, 55 | long_description_content_type="text/markdown", 56 | url="https://github.com/impira/docquery", 57 | project_urls={ 58 | "Bug Tracker": "https://github.com/impira/docquery/issues", 59 | }, 60 | classifiers=[ 61 | "Programming Language :: Python :: 3", 62 | "License :: OSI Approved :: MIT License", 63 | "Operating System :: 
OS Independent", 64 | ], 65 | package_dir={"": "src"}, 66 | package_data={"": ["find_leaf_nodes.js"]}, 67 | packages=setuptools.find_packages(where="src"), 68 | python_requires=">=3.7.0", 69 | entry_points={ 70 | "console_scripts": ["docquery = docquery.cmd.__main__:main"], 71 | }, 72 | install_requires=install_requires, 73 | extras_require=extras_require, 74 | ) 75 | -------------------------------------------------------------------------------- /src/docquery/__init__.py: -------------------------------------------------------------------------------- 1 | # pipeline wraps transformers pipeline with extensions in DocQuery 2 | # we're simply re-exporting it here. 3 | import sys 4 | 5 | from transformers.utils import _LazyModule 6 | 7 | from .version import VERSION 8 | 9 | 10 | _import_structure = { 11 | "transformers_patch": ["pipeline"], 12 | } 13 | 14 | sys.modules[__name__] = _LazyModule( 15 | __name__, 16 | globals()["__file__"], 17 | _import_structure, 18 | module_spec=__spec__, 19 | extra_objects={"__version__": VERSION}, 20 | ) 21 | -------------------------------------------------------------------------------- /src/docquery/cmd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/impira/docquery/3744f08a22609c0df5a72f463911b47689eaa819/src/docquery/cmd/__init__.py -------------------------------------------------------------------------------- /src/docquery/cmd/__main__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 5 | 6 | import argparse 7 | import logging 8 | import sys 9 | import textwrap 10 | 11 | import transformers 12 | 13 | from ..transformers_patch import PIPELINE_DEFAULTS 14 | 15 | 16 | _module_not_found_error = None 17 | try: 18 | from . import scan 19 | except ModuleNotFoundError as e: 20 | _module_not_found_error = e 21 | 22 | if _module_not_found_error is not None: 23 | raise ModuleNotFoundError( 24 | textwrap.dedent( 25 | f"""\ 26 | At least one dependency not found: {str(_module_not_found_error)!r} 27 | It is possible that docquery was installed without the CLI dependencies. 
Run: 28 | 29 | pip install 'docquery[cli]' 30 | 31 | to install impira with the CLI dependencies.""" 32 | ) 33 | ) 34 | 35 | 36 | def main(args=None): 37 | """The main routine.""" 38 | if args is None: 39 | args = sys.argv[1:] 40 | 41 | parent_parser = argparse.ArgumentParser(add_help=False) 42 | parent_parser.add_argument("--verbose", "-v", default=False, action="store_true") 43 | parent_parser.add_argument( 44 | "--checkpoint", 45 | default=None, 46 | help=f"A custom model checkpoint to use (other than {PIPELINE_DEFAULTS['document-question-answering']})", 47 | ) 48 | 49 | parser = argparse.ArgumentParser(description="docquery is a cli tool to work with documents.") 50 | subparsers = parser.add_subparsers(help="sub-command help", dest="subcommand", required=True) 51 | 52 | for module in [scan]: 53 | module.build_parser(subparsers, parent_parser) 54 | 55 | args = parser.parse_args(args=args) 56 | level = logging.DEBUG if args.verbose else logging.INFO 57 | if not args.verbose: 58 | transformers.logging.set_verbosity_error() 59 | logging.basicConfig(format="%(asctime)s %(levelname)s: %(message)s", level=level) 60 | 61 | return args.func(args) 62 | 63 | 64 | if __name__ == "__main__": 65 | sys.exit(main()) 66 | -------------------------------------------------------------------------------- /src/docquery/cmd/scan.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | 4 | from .. import pipeline 5 | from ..config import get_logger 6 | from ..document import UnsupportedDocument, load_document 7 | from ..ocr_reader import OCR_MAPPING 8 | from ..transformers_patch import PIPELINE_DEFAULTS 9 | 10 | 11 | log = get_logger("scan") 12 | 13 | 14 | def build_parser(subparsers, parent_parser): 15 | parser = subparsers.add_parser( 16 | "scan", 17 | help="Scan a directory and ask one or more questions of the documents in it.", 18 | parents=[parent_parser], 19 | ) 20 | 21 | parser.add_argument( 22 | "questions", default=[], nargs="*", type=str, help="One or more questions to ask of the documents" 23 | ) 24 | 25 | parser.add_argument("path", type=str, help="The file or directory to scan") 26 | 27 | parser.add_argument( 28 | "--ocr", choices=list(OCR_MAPPING.keys()), default=None, help="The OCR engine you would like to use" 29 | ) 30 | parser.add_argument( 31 | "--ignore-embedded-text", 32 | dest="use_embedded_text", 33 | action="store_false", 34 | help="Do not try and extract embedded text from document types that might provide it (e.g. PDFs)", 35 | ) 36 | parser.add_argument( 37 | "--classify", 38 | default=False, 39 | action="store_true", 40 | help="Classify documents while scanning them", 41 | ) 42 | parser.add_argument( 43 | "--classify-checkpoint", 44 | default=None, 45 | help=f"A custom model checkpoint to use (other than {PIPELINE_DEFAULTS['document-classification']})", 46 | ) 47 | 48 | parser.set_defaults(func=main) 49 | return parser 50 | 51 | 52 | def main(args): 53 | paths = [] 54 | if pathlib.Path(args.path).is_dir(): 55 | for root, dirs, files in os.walk(args.path): 56 | for fname in files: 57 | if (pathlib.Path(root) / fname).is_dir(): 58 | continue 59 | paths.append(pathlib.Path(root) / fname) 60 | else: 61 | paths.append(args.path) 62 | 63 | docs = [] 64 | for p in paths: 65 | try: 66 | log.info(f"Loading {p}") 67 | docs.append((p, load_document(str(p), ocr_reader=args.ocr, use_embedded_text=args.use_embedded_text))) 68 | except UnsupportedDocument as e: 69 | log.warning(f"Cannot load {p}: {e}. 
Skipping...") 70 | 71 | log.info(f"Done loading {len(docs)} file(s).") 72 | if not docs: 73 | return 74 | 75 | log.info("Loading pipelines.") 76 | 77 | nlp = pipeline("document-question-answering", model=args.checkpoint) 78 | if args.classify: 79 | classify = pipeline("document-classification", model=args.classify_checkpoint) 80 | 81 | log.info("Ready to start evaluating!") 82 | 83 | max_fname_len = max(len(str(p)) for (p, _) in docs) 84 | max_question_len = max(len(q) for q in args.questions) if len(args.questions) > 0 else 0 85 | for i, (p, d) in enumerate(docs): 86 | if i > 0 and len(args.questions) > 1: 87 | print("") 88 | 89 | if args.classify: 90 | cls = classify(**d.context)[0] 91 | print(f"{str(p):<{max_fname_len}} Document Type: {cls['label']}") 92 | 93 | for q in args.questions: 94 | try: 95 | response = nlp(question=q, **d.context) 96 | if isinstance(response, list): 97 | response = response[0] if len(response) > 0 else None 98 | except Exception: 99 | log.error(f"Failed while processing {str(p)} on question: '{q}'") 100 | raise 101 | 102 | answer = response["answer"] if response is not None else "NULL" 103 | print(f"{str(p):<{max_fname_len}} {q:<{max_question_len}}: {answer}") 104 | -------------------------------------------------------------------------------- /src/docquery/config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from pydantic import validate_arguments 4 | 5 | 6 | @validate_arguments 7 | def get_logger(prefix: str): 8 | log = logging.getLogger(prefix) 9 | log.propagate = False 10 | ch = logging.StreamHandler() 11 | formatter = logging.Formatter("%(asctime)s %(levelname)s: %(message)s") 12 | ch.setFormatter(formatter) 13 | log.addHandler(ch) 14 | return log 15 | -------------------------------------------------------------------------------- /src/docquery/document.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import mimetypes 3 | import os 4 | from io import BytesIO 5 | from typing import Any, Dict, List, Optional, Tuple, Union 6 | 7 | import requests 8 | from PIL import Image, UnidentifiedImageError 9 | from pydantic import validate_arguments 10 | 11 | from .ext.functools import cached_property 12 | from .ocr_reader import NoOCRReaderFound, OCRReader, get_ocr_reader 13 | from .web import get_webdriver 14 | 15 | 16 | class UnsupportedDocument(Exception): 17 | def __init__(self, e): 18 | self.e = e 19 | 20 | def __str__(self): 21 | return f"unsupported file type: {self.e}" 22 | 23 | 24 | PDF_2_IMAGE = False 25 | PDF_PLUMBER = False 26 | 27 | try: 28 | import pdf2image 29 | 30 | PDF_2_IMAGE = True 31 | except ImportError: 32 | pass 33 | 34 | try: 35 | import pdfplumber 36 | 37 | PDF_PLUMBER = True 38 | except ImportError: 39 | pass 40 | 41 | 42 | def use_pdf2_image(): 43 | if not PDF_2_IMAGE: 44 | raise UnsupportedDocument("Unable to import pdf2image (OCR will be unavailable for pdfs)") 45 | 46 | 47 | def use_pdf_plumber(): 48 | if not PDF_PLUMBER: 49 | raise UnsupportedDocument("Unable to import pdfplumber (pdfs will be unavailable)") 50 | 51 | 52 | class Document(metaclass=abc.ABCMeta): 53 | @property 54 | @abc.abstractmethod 55 | def context(self) -> Tuple[(str, List[int])]: 56 | raise NotImplementedError 57 | 58 | @property 59 | @abc.abstractmethod 60 | def preview(self) -> "Image": 61 | raise NotImplementedError 62 | 63 | @staticmethod 64 | def _generate_document_output( 65 | images: List["Image.Image"], 66 | words_by_page: 
List[List[str]], 67 | boxes_by_page: List[List[List[int]]], 68 | dimensions_by_page: List[Tuple[int, int]], 69 | ) -> Dict[str, List[Tuple["Image.Image", List[Any]]]]: 70 | 71 | # pages_dimensions (width, height) 72 | assert len(images) == len(dimensions_by_page) 73 | assert len(images) == len(words_by_page) 74 | assert len(images) == len(boxes_by_page) 75 | processed_pages = [] 76 | for image, words, boxes, dimensions in zip(images, words_by_page, boxes_by_page, dimensions_by_page): 77 | width, height = dimensions 78 | 79 | """ 80 | box is [x1,y1,x2,y2] where x1,y1 are the top left corner of box and x2,y2 is the bottom right corner 81 | This function scales the distance between boxes to be on a fixed scale 82 | It is derived from the preprocessing code for LayoutLM 83 | """ 84 | normalized_boxes = [ 85 | [ 86 | max(min(c, 1000), 0) 87 | for c in [ 88 | int(1000 * (box[0] / width)), 89 | int(1000 * (box[1] / height)), 90 | int(1000 * (box[2] / width)), 91 | int(1000 * (box[3] / height)), 92 | ] 93 | ] 94 | for box in boxes 95 | ] 96 | assert len(words) == len(normalized_boxes), "Not as many words as there are bounding boxes" 97 | word_boxes = [x for x in zip(words, normalized_boxes)] 98 | processed_pages.append((image, word_boxes)) 99 | 100 | return {"image": processed_pages} 101 | 102 | 103 | class PDFDocument(Document): 104 | def __init__(self, b, ocr_reader, use_embedded_text, **kwargs): 105 | self.b = b 106 | self.ocr_reader = ocr_reader 107 | self.use_embedded_text = use_embedded_text 108 | 109 | super().__init__(**kwargs) 110 | 111 | @cached_property 112 | def context(self) -> Dict[str, List[Tuple["Image.Image", List[Any]]]]: 113 | pdf = self._pdf 114 | if pdf is None: 115 | return {} 116 | 117 | images = self._images 118 | 119 | if len(images) != len(pdf.pages): 120 | raise ValueError( 121 | f"Mismatch: pdfplumber() thinks there are {len(pdf.pages)} pages and" 122 | f" pdf2image thinks there are {len(images)}" 123 | ) 124 | 125 | words_by_page = [] 126 | boxes_by_page = [] 127 | dimensions_by_page = [] 128 | for i, page in enumerate(pdf.pages): 129 | extracted_words = page.extract_words() if self.use_embedded_text else [] 130 | 131 | if len(extracted_words) == 0: 132 | words, boxes = self.ocr_reader.apply_ocr(images[i]) 133 | words_by_page.append(words) 134 | boxes_by_page.append(boxes) 135 | dimensions_by_page.append((images[i].width, images[i].height)) 136 | 137 | else: 138 | words = [w["text"] for w in extracted_words] 139 | boxes = [[w["x0"], w["top"], w["x1"], w["bottom"]] for w in extracted_words] 140 | words_by_page.append(words) 141 | boxes_by_page.append(boxes) 142 | dimensions_by_page.append((page.width, page.height)) 143 | 144 | return self._generate_document_output(images, words_by_page, boxes_by_page, dimensions_by_page) 145 | 146 | @cached_property 147 | def preview(self) -> "Image": 148 | return self._images 149 | 150 | @cached_property 151 | def _images(self): 152 | # First, try to extract text directly 153 | # TODO: This library requires poppler, which is not present everywhere. 154 | # We should look into alternatives. 
We could also gracefully handle this 155 | # and simply fall back to _only_ extracted text 156 | return [x.convert("RGB") for x in pdf2image.convert_from_bytes(self.b)] 157 | 158 | @cached_property 159 | def _pdf(self): 160 | use_pdf_plumber() 161 | pdf = pdfplumber.open(BytesIO(self.b)) 162 | if len(pdf.pages) == 0: 163 | return None 164 | return pdf 165 | 166 | 167 | class ImageDocument(Document): 168 | def __init__(self, b, ocr_reader, **kwargs): 169 | self.b = b 170 | self.ocr_reader = ocr_reader 171 | 172 | super().__init__(**kwargs) 173 | 174 | @cached_property 175 | def preview(self) -> "Image": 176 | return [self.b.convert("RGB")] 177 | 178 | @cached_property 179 | def context(self) -> Dict[str, List[Tuple["Image.Image", List[Any]]]]: 180 | words, boxes = self.ocr_reader.apply_ocr(self.b) 181 | return self._generate_document_output([self.b], [words], [boxes], [(self.b.width, self.b.height)]) 182 | 183 | 184 | class WebDocument(Document): 185 | def __init__(self, url, **kwargs): 186 | if not (url.startswith("http://") or url.startswith("https://")): 187 | url = "file://" + url 188 | self.url = url 189 | 190 | # TODO: This is a singleton, which is not thread-safe. We may want to relax this 191 | # behavior to allow the user to pass in their own driver (which could either be a 192 | # singleton or a custom instance). 193 | self.driver = get_webdriver() 194 | 195 | super().__init__(**kwargs) 196 | 197 | def ensure_loaded(self): 198 | self.driver.get(self.url) 199 | 200 | @cached_property 201 | def page_screenshots(self): 202 | self.ensure_loaded() 203 | return self.driver.scroll_and_screenshot() 204 | 205 | @cached_property 206 | def preview(self) -> "Image": 207 | return [img.convert("RGB") for img in self.page_screenshots[1]] 208 | 209 | @cached_property 210 | def context(self) -> Dict[str, List[Tuple["Image.Image", List[Any]]]]: 211 | self.ensure_loaded() 212 | word_boxes = self.driver.find_word_boxes() 213 | 214 | tops, _ = self.page_screenshots 215 | 216 | n_pages = len(tops) 217 | page = 0 218 | offset = 0 219 | 220 | words = [[] for _ in range(n_pages)] 221 | boxes = [[] for _ in range(n_pages)] 222 | for word_box in word_boxes["word_boxes"]: 223 | box = word_box["box"] 224 | 225 | if page < len(tops) - 1 and box["top"] >= tops[page + 1]: 226 | page += 1 227 | offset = tops[page] 228 | 229 | words[page].append(word_box["text"]) 230 | boxes[page].append((box["left"], box["top"] - offset, box["right"], box["bottom"] - offset)) 231 | 232 | return self._generate_document_output( 233 | self.preview, words, boxes, [(word_boxes["vw"], word_boxes["vh"])] * n_pages 234 | ) 235 | 236 | 237 | @validate_arguments 238 | def load_document(fpath: str, ocr_reader: Optional[Union[str, OCRReader]] = None, use_embedded_text=True): 239 | base_path = os.path.basename(fpath).split("?")[0].strip() 240 | doc_type = mimetypes.guess_type(base_path)[0] 241 | if fpath.startswith("http://") or fpath.startswith("https://"): 242 | resp = requests.get(fpath, allow_redirects=True, stream=True) 243 | if not resp.ok: 244 | raise UnsupportedDocument(f"Failed to download: {resp.content}") 245 | 246 | if "Content-Type" in resp.headers: 247 | doc_type = resp.headers["Content-Type"].split(";")[0].strip() 248 | 249 | b = resp.raw 250 | else: 251 | b = open(fpath, "rb") 252 | 253 | if not ocr_reader or isinstance(ocr_reader, str): 254 | ocr_reader = get_ocr_reader(ocr_reader) 255 | elif not isinstance(ocr_reader, OCRReader): 256 | raise NoOCRReaderFound(f"{ocr_reader} is not a supported OCRReader class") 257 | 258 | if 
doc_type == "application/pdf": 259 | return PDFDocument(b.read(), ocr_reader=ocr_reader, use_embedded_text=use_embedded_text) 260 | elif doc_type == "text/html": 261 | return WebDocument(fpath) 262 | else: 263 | try: 264 | img = Image.open(b) 265 | except UnidentifiedImageError as e: 266 | raise UnsupportedDocument(e) 267 | return ImageDocument(img, ocr_reader=ocr_reader) 268 | -------------------------------------------------------------------------------- /src/docquery/ext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/impira/docquery/3744f08a22609c0df5a72f463911b47689eaa819/src/docquery/ext/__init__.py -------------------------------------------------------------------------------- /src/docquery/ext/functools.py: -------------------------------------------------------------------------------- 1 | try: 2 | from functools import cached_property as cached_property 3 | except ImportError: 4 | # for python 3.7 support fall back to just property 5 | cached_property = property 6 | -------------------------------------------------------------------------------- /src/docquery/ext/itertools.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | 4 | def unique_everseen(iterable, key=None): 5 | """ 6 | List unique elements, preserving order. Remember all elements ever seen [1]_. 7 | 8 | Examples 9 | -------- 10 | >>> list(unique_everseen("AAAABBBCCDAABBB")) 11 | ["A", "B", "C", "D"] 12 | >>> list(unique_everseen("ABBCcAD", str.lower)) 13 | ["A", "B", "C", "D"] 14 | 15 | References 16 | ---------- 17 | .. [1] https://docs.python.org/3/library/itertools.html 18 | """ 19 | seen = set() 20 | seen_add = seen.add 21 | if key is None: 22 | for element in itertools.filterfalse(seen.__contains__, iterable): 23 | seen_add(element) 24 | yield element 25 | else: 26 | for element in iterable: 27 | k = key(element) 28 | if k not in seen: 29 | seen_add(k) 30 | yield element 31 | -------------------------------------------------------------------------------- /src/docquery/ext/model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Tuple, Union 3 | 4 | import torch 5 | from torch import nn 6 | from transformers import LayoutLMModel, LayoutLMPreTrainedModel 7 | from transformers.modeling_outputs import QuestionAnsweringModelOutput as QuestionAnsweringModelOutputBase 8 | 9 | 10 | @dataclass 11 | class QuestionAnsweringModelOutput(QuestionAnsweringModelOutputBase): 12 | token_logits: Optional[torch.FloatTensor] = None 13 | 14 | 15 | # There are three additional config parameters that this model supports, which are not part of the 16 | # LayoutLMForQuestionAnswering in mainline transformers. These config parameters control the additional 17 | # token classifier head. 18 | # 19 | # token_classification (`bool, *optional*, defaults to False): 20 | # Whether to include an additional token classification head in question answering 21 | # token_classifier_reduction (`str`, *optional*, defaults to "mean") 22 | # Specifies the reduction to apply to the output of the cross entropy loss for the token classifier head during 23 | # training. Options are: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the weighted 24 | # mean of the output is taken, 'sum': the output will be summed. 
25 | # token_classifier_constant (`float`, *optional*, defaults to 1.0) 26 | # Coefficient for the token classifier head's contribution to the total loss. A larger value means that the model 27 | # will prioritize learning the token classifier head during training. 28 | class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel): 29 | def __init__(self, config, has_visual_segment_embedding=True): 30 | super().__init__(config) 31 | self.num_labels = config.num_labels 32 | 33 | self.layoutlm = LayoutLMModel(config) 34 | self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) 35 | 36 | # NOTE: We have to use getattr() here because we do not patch the LayoutLMConfig 37 | # class to have these extra attributes, so existing LayoutLM models may not have 38 | # them in their configuration. 39 | self.token_classifier_head = None 40 | if getattr(self.config, "token_classification", False): 41 | self.token_classifier_head = nn.Linear(config.hidden_size, 2) 42 | 43 | # Initialize weights and apply final processing 44 | self.post_init() 45 | 46 | def get_input_embeddings(self): 47 | return self.layoutlm.embeddings.word_embeddings 48 | 49 | def forward( 50 | self, 51 | input_ids: Optional[torch.LongTensor] = None, 52 | bbox: Optional[torch.LongTensor] = None, 53 | attention_mask: Optional[torch.FloatTensor] = None, 54 | token_type_ids: Optional[torch.LongTensor] = None, 55 | position_ids: Optional[torch.LongTensor] = None, 56 | head_mask: Optional[torch.FloatTensor] = None, 57 | inputs_embeds: Optional[torch.FloatTensor] = None, 58 | start_positions: Optional[torch.LongTensor] = None, 59 | end_positions: Optional[torch.LongTensor] = None, 60 | token_labels: Optional[torch.LongTensor] = None, 61 | output_attentions: Optional[bool] = None, 62 | output_hidden_states: Optional[bool] = None, 63 | return_dict: Optional[bool] = None, 64 | ) -> Union[Tuple, QuestionAnsweringModelOutput]: 65 | r""" 66 | start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 67 | Labels for position (index) of the start of the labelled span for computing the token classification loss. 68 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence 69 | are not taken into account for computing the loss. 70 | end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 71 | Labels for position (index) of the end of the labelled span for computing the token classification loss. 72 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence 73 | are not taken into account for computing the loss. 74 | 75 | Returns: 76 | 77 | Example: 78 | 79 | In this example below, we give the LayoutLMv2 model an image (of texts) and ask it a question. It will give us 80 | a prediction of what it thinks the answer is (the span of the answer within the texts parsed from the image). 81 | 82 | ```python 83 | >>> from transformers import AutoTokenizer, LayoutLMForQuestionAnswering 84 | >>> from datasets import load_dataset 85 | >>> import torch 86 | 87 | >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased", add_prefix_space=True) 88 | >>> model = LayoutLMForQuestionAnswering.from_pretrained("microsoft/layoutlm-base-uncased") 89 | 90 | >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train") 91 | >>> example = dataset[0] 92 | >>> question = "what's his name?" 93 | >>> words = example["tokens"] 94 | >>> boxes = example["bboxes"] 95 | 96 | >>> encoding = tokenizer( 97 | ... 
question.split(), words, is_split_into_words=True, return_token_type_ids=True, return_tensors="pt" 98 | ... ) 99 | >>> bbox = [] 100 | >>> for i, s, w in zip(encoding.input_ids[0], encoding.sequence_ids(0), encoding.word_ids(0)): 101 | ... if s == 1: 102 | ... bbox.append(boxes[w]) 103 | ... elif i == tokenizer.sep_token_id: 104 | ... bbox.append([1000] * 4) 105 | ... else: 106 | ... bbox.append([0] * 4) 107 | >>> encoding["bbox"] = torch.tensor([bbox]) 108 | 109 | >>> outputs = model(**encoding) 110 | >>> loss = outputs.loss 111 | >>> start_scores = outputs.start_logits 112 | >>> end_scores = outputs.end_logits 113 | ``` 114 | """ 115 | 116 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 117 | 118 | outputs = self.layoutlm( 119 | input_ids=input_ids, 120 | bbox=bbox, 121 | attention_mask=attention_mask, 122 | token_type_ids=token_type_ids, 123 | position_ids=position_ids, 124 | head_mask=head_mask, 125 | inputs_embeds=inputs_embeds, 126 | output_attentions=output_attentions, 127 | output_hidden_states=output_hidden_states, 128 | return_dict=return_dict, 129 | ) 130 | 131 | if input_ids is not None: 132 | input_shape = input_ids.size() 133 | else: 134 | input_shape = inputs_embeds.size()[:-1] 135 | 136 | seq_length = input_shape[1] 137 | # only take the text part of the output representations 138 | sequence_output = outputs[0][:, :seq_length] 139 | 140 | logits = self.qa_outputs(sequence_output) 141 | start_logits, end_logits = logits.split(1, dim=-1) 142 | start_logits = start_logits.squeeze(-1).contiguous() 143 | end_logits = end_logits.squeeze(-1).contiguous() 144 | 145 | total_loss = None 146 | if start_positions is not None and end_positions is not None: 147 | # If we are on multi-GPU, split add a dimension 148 | if len(start_positions.size()) > 1: 149 | start_positions = start_positions.squeeze(-1) 150 | if len(end_positions.size()) > 1: 151 | end_positions = end_positions.squeeze(-1) 152 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 153 | ignored_index = start_logits.size(1) 154 | start_positions = start_positions.clamp(0, ignored_index) 155 | end_positions = end_positions.clamp(0, ignored_index) 156 | 157 | loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) 158 | start_loss = loss_fct(start_logits, start_positions) 159 | end_loss = loss_fct(end_logits, end_positions) 160 | total_loss = (start_loss + end_loss) / 2 161 | 162 | token_logits = None 163 | if getattr(self.config, "token_classification", False): 164 | token_logits = self.token_classifier_head(sequence_output) 165 | 166 | if token_labels is not None: 167 | # Loss fn expects logits to be of shape (batch_size, num_labels, 512), but model 168 | # outputs (batch_size, 512, num_labels), so we need to move the dimensions around 169 | # Ref: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html 170 | token_logits_reshaped = torch.movedim(token_logits, source=token_logits.ndim - 1, destination=1) 171 | token_loss = nn.CrossEntropyLoss(reduction=self.config.token_classifier_reduction)( 172 | token_logits_reshaped, token_labels 173 | ) 174 | 175 | total_loss += self.config.token_classifier_constant * token_loss 176 | 177 | if not return_dict: 178 | output = (start_logits, end_logits) 179 | if self.token_classification: 180 | output = output + (token_logits,) 181 | 182 | output = output + outputs[2:] 183 | 184 | if total_loss is not None: 185 | output = (total_loss,) + output 186 | 187 | return output 188 | 189 | return 
QuestionAnsweringModelOutput( 190 | loss=total_loss, 191 | start_logits=start_logits, 192 | end_logits=end_logits, 193 | token_logits=token_logits, 194 | hidden_states=outputs.hidden_states, 195 | attentions=outputs.attentions, 196 | ) 197 | -------------------------------------------------------------------------------- /src/docquery/ext/pipeline_document_classification.py: -------------------------------------------------------------------------------- 1 | # This file is copied from transformers: 2 | # https://github.com/huggingface/transformers/blob/bb6f6d53386bf2340eead6a8f9320ce61add3e96/src/transformers/pipelines/image_classification.py 3 | # And has been modified to support Donut 4 | import re 5 | from typing import List, Optional, Tuple, Union 6 | 7 | import torch 8 | from transformers.pipelines.base import PIPELINE_INIT_ARGS, ChunkPipeline 9 | from transformers.pipelines.text_classification import ClassificationFunction, sigmoid, softmax 10 | from transformers.utils import ExplicitEnum, add_end_docstrings, logging 11 | 12 | from .pipeline_document_question_answering import ImageOrName, apply_tesseract 13 | from .qa_helpers import TESSERACT_LOADED, VISION_LOADED, load_image 14 | 15 | 16 | logger = logging.get_logger(__name__) 17 | 18 | 19 | class ModelType(ExplicitEnum): 20 | Standard = "standard" 21 | VisionEncoderDecoder = "vision_encoder_decoder" 22 | 23 | 24 | def donut_token2json(tokenizer, tokens, is_inner_value=False): 25 | """ 26 | Convert a (generated) token sequence into an ordered JSON format. 27 | """ 28 | output = dict() 29 | 30 | while tokens: 31 | start_token = re.search(r"", tokens, re.IGNORECASE) 32 | if start_token is None: 33 | break 34 | key = start_token.group(1) 35 | end_token = re.search(rf"", tokens, re.IGNORECASE) 36 | start_token = start_token.group() 37 | if end_token is None: 38 | tokens = tokens.replace(start_token, "") 39 | else: 40 | end_token = end_token.group() 41 | start_token_escaped = re.escape(start_token) 42 | end_token_escaped = re.escape(end_token) 43 | content = re.search(f"{start_token_escaped}(.*?){end_token_escaped}", tokens, re.IGNORECASE) 44 | if content is not None: 45 | content = content.group(1).strip() 46 | if r""): 55 | leaf = leaf.strip() 56 | if leaf in tokenizer.get_added_vocab() and leaf[0] == "<" and leaf[-2:] == "/>": 57 | leaf = leaf[1:-2] # for categorical special tokens 58 | output[key].append(leaf) 59 | if len(output[key]) == 1: 60 | output[key] = output[key][0] 61 | 62 | tokens = tokens[tokens.find(end_token) + len(end_token) :].strip() 63 | if tokens[:6] == r"": # non-leaf nodes 64 | return [output] + donut_token2json(tokenizer, tokens[6:], is_inner_value=True) 65 | 66 | if len(output): 67 | return [output] if is_inner_value else output 68 | else: 69 | return [] if is_inner_value else {"text_sequence": tokens} 70 | 71 | 72 | @add_end_docstrings(PIPELINE_INIT_ARGS) 73 | class DocumentClassificationPipeline(ChunkPipeline): 74 | """ 75 | Document classification pipeline using any `AutoModelForDocumentClassification`. This pipeline predicts the class of a 76 | document. 77 | 78 | This document classification pipeline can currently be loaded from [`pipeline`] using the following task identifier: 79 | `"document-classification"`. 80 | 81 | See the list of available models on 82 | [huggingface.co/models](https://huggingface.co/models?filter=document-classification). 
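    Example (a minimal sketch, assuming the default Donut checkpoint, which requires the `[donut]` extra, and a
    hypothetical local image file):

    ```python
    >>> from docquery import pipeline
    >>> classifier = pipeline("document-classification")
    >>> classifier("invoice.png")  # e.g. [{'label': 'invoice'}]
    ```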
83 | """ 84 | 85 | def __init__(self, *args, **kwargs): 86 | super().__init__(*args, **kwargs) 87 | 88 | if self.model.config.__class__.__name__ == "VisionEncoderDecoderConfig": 89 | self.model_type = ModelType.VisionEncoderDecoder 90 | else: 91 | self.model_type = ModelType.Standard 92 | 93 | def _sanitize_parameters( 94 | self, 95 | doc_stride=None, 96 | lang: Optional[str] = None, 97 | tesseract_config: Optional[str] = None, 98 | max_num_spans: Optional[int] = None, 99 | max_seq_len=None, 100 | function_to_apply=None, 101 | top_k=None, 102 | ): 103 | preprocess_params, postprocess_params = {}, {} 104 | if doc_stride is not None: 105 | preprocess_params["doc_stride"] = doc_stride 106 | if max_seq_len is not None: 107 | preprocess_params["max_seq_len"] = max_seq_len 108 | if lang is not None: 109 | preprocess_params["lang"] = lang 110 | if tesseract_config is not None: 111 | preprocess_params["tesseract_config"] = tesseract_config 112 | if max_num_spans is not None: 113 | preprocess_params["max_num_spans"] = max_num_spans 114 | 115 | if isinstance(function_to_apply, str): 116 | function_to_apply = ClassificationFunction[function_to_apply.upper()] 117 | 118 | if function_to_apply is not None: 119 | postprocess_params["function_to_apply"] = function_to_apply 120 | 121 | if top_k is not None: 122 | if top_k < 1: 123 | raise ValueError(f"top_k parameter should be >= 1 (got {top_k})") 124 | postprocess_params["top_k"] = top_k 125 | 126 | return preprocess_params, {}, postprocess_params 127 | 128 | def __call__(self, image: Union[ImageOrName, List[ImageOrName], List[Tuple]], **kwargs): 129 | """ 130 | Assign labels to the document(s) passed as inputs. 131 | 132 | # TODO 133 | """ 134 | if isinstance(image, list): 135 | normalized_images = (i if isinstance(i, (tuple, list)) else (i, None) for i in image) 136 | else: 137 | normalized_images = [(image, None)] 138 | 139 | return super().__call__({"pages": normalized_images}, **kwargs) 140 | 141 | def preprocess( 142 | self, 143 | input, 144 | doc_stride=None, 145 | max_seq_len=None, 146 | word_boxes: Tuple[str, List[float]] = None, 147 | lang=None, 148 | tesseract_config="", 149 | max_num_spans=1, 150 | ): 151 | # NOTE: This code mirrors the code in question answering and will be implemented in a follow up PR 152 | # to support documents with enough tokens that overflow the model's window 153 | if max_seq_len is None: 154 | # TODO: LayoutLM's stride is 512 by default. Is it ok to use that as the min 155 | # instead of 384 (which the QA model uses)? 156 | max_seq_len = min(self.tokenizer.model_max_length, 512) 157 | 158 | if doc_stride is None: 159 | doc_stride = min(max_seq_len // 2, 256) 160 | 161 | total_num_spans = 0 162 | 163 | for page_idx, (image, word_boxes) in enumerate(input["pages"]): 164 | image_features = {} 165 | if image is not None: 166 | if not VISION_LOADED: 167 | raise ValueError( 168 | "If you provide an image, then the pipeline will run process it with PIL (Pillow), but" 169 | " PIL is not available. Install it with pip install Pillow." 
170 | ) 171 | image = load_image(image) 172 | if self.feature_extractor is not None: 173 | image_features.update(self.feature_extractor(images=image, return_tensors=self.framework)) 174 | 175 | words, boxes = None, None 176 | if self.model_type != ModelType.VisionEncoderDecoder: 177 | if word_boxes is not None: 178 | words = [x[0] for x in word_boxes] 179 | boxes = [x[1] for x in word_boxes] 180 | elif "words" in image_features and "boxes" in image_features: 181 | words = image_features.pop("words")[0] 182 | boxes = image_features.pop("boxes")[0] 183 | elif image is not None: 184 | if not TESSERACT_LOADED: 185 | raise ValueError( 186 | "If you provide an image without word_boxes, then the pipeline will run OCR using" 187 | " Tesseract, but pytesseract is not available. Install it with pip install pytesseract." 188 | ) 189 | if TESSERACT_LOADED: 190 | words, boxes = apply_tesseract(image, lang=lang, tesseract_config=tesseract_config) 191 | else: 192 | raise ValueError( 193 | "You must provide an image or word_boxes. If you provide an image, the pipeline will" 194 | " automatically run OCR to derive words and boxes" 195 | ) 196 | 197 | if self.tokenizer.padding_side != "right": 198 | raise ValueError( 199 | "Document classification only supports tokenizers whose padding side is 'right', not" 200 | f" {self.tokenizer.padding_side}" 201 | ) 202 | 203 | if self.model_type == ModelType.VisionEncoderDecoder: 204 | encoding = { 205 | "inputs": image_features["pixel_values"], 206 | "max_length": self.model.decoder.config.max_position_embeddings, 207 | "decoder_input_ids": self.tokenizer( 208 | "", 209 | add_special_tokens=False, 210 | return_tensors=self.framework, 211 | ).input_ids, 212 | "return_dict_in_generate": True, 213 | } 214 | yield { 215 | **encoding, 216 | "page": None, 217 | } 218 | else: 219 | encoding = self.tokenizer( 220 | text=words, 221 | max_length=max_seq_len, 222 | stride=doc_stride, 223 | return_token_type_ids=True, 224 | is_split_into_words=True, 225 | truncation=True, 226 | return_overflowing_tokens=True, 227 | ) 228 | 229 | num_spans = len(encoding["input_ids"]) 230 | 231 | for span_idx in range(num_spans): 232 | if self.framework == "pt": 233 | span_encoding = {k: torch.tensor(v[span_idx : span_idx + 1]) for (k, v) in encoding.items()} 234 | span_encoding.update( 235 | {k: v for (k, v) in image_features.items()} 236 | ) # TODO: Verify cardinality is correct 237 | else: 238 | raise ValueError("Unsupported: Tensorflow preprocessing for DocumentClassification") 239 | 240 | # For each span, place a bounding box [0,0,0,0] for question and CLS tokens, [1000,1000,1000,1000] 241 | # for SEP tokens, and the word's bounding box for words in the original document. 
242 | bbox = [] 243 | for i, s, w in zip( 244 | encoding.input_ids[span_idx], 245 | encoding.sequence_ids(span_idx), 246 | encoding.word_ids(span_idx), 247 | ): 248 | if s == 0: 249 | bbox.append(boxes[w]) 250 | elif i == self.tokenizer.sep_token_id: 251 | bbox.append([1000] * 4) 252 | else: 253 | bbox.append([0] * 4) 254 | 255 | span_encoding["bbox"] = torch.tensor(bbox).unsqueeze(0) 256 | 257 | yield { 258 | **span_encoding, 259 | "page": page_idx, 260 | } 261 | 262 | total_num_spans += 1 263 | if total_num_spans >= max_num_spans: 264 | break 265 | 266 | def _forward(self, model_inputs): 267 | page = model_inputs.pop("page", None) 268 | 269 | if "overflow_to_sample_mapping" in model_inputs: 270 | model_inputs.pop("overflow_to_sample_mapping") 271 | 272 | if self.model_type == ModelType.VisionEncoderDecoder: 273 | model_outputs = self.model.generate(**model_inputs) 274 | else: 275 | model_outputs = self.model(**model_inputs) 276 | 277 | model_outputs["page"] = page 278 | model_outputs["attention_mask"] = model_inputs.get("attention_mask", None) 279 | return model_outputs 280 | 281 | def postprocess(self, model_outputs, function_to_apply=None, top_k=1, **kwargs): 282 | if function_to_apply is None: 283 | if self.model.config.num_labels == 1: 284 | function_to_apply = ClassificationFunction.SIGMOID 285 | elif self.model.config.num_labels > 1: 286 | function_to_apply = ClassificationFunction.SOFTMAX 287 | elif hasattr(self.model.config, "function_to_apply") and function_to_apply is None: 288 | function_to_apply = self.model.config.function_to_apply 289 | else: 290 | function_to_apply = ClassificationFunction.NONE 291 | 292 | if self.model_type == ModelType.VisionEncoderDecoder: 293 | answers = self.postprocess_encoder_decoder(model_outputs, top_k=top_k, **kwargs) 294 | else: 295 | answers = self.postprocess_standard( 296 | model_outputs, function_to_apply=function_to_apply, top_k=top_k, **kwargs 297 | ) 298 | 299 | answers = sorted(answers, key=lambda x: x.get("score", 0), reverse=True)[:top_k] 300 | return answers 301 | 302 | def postprocess_encoder_decoder(self, model_outputs, **kwargs): 303 | classes = set() 304 | for model_output in model_outputs: 305 | for sequence in self.tokenizer.batch_decode(model_output.sequences): 306 | sequence = sequence.replace(self.tokenizer.eos_token, "").replace(self.tokenizer.pad_token, "") 307 | sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token 308 | classes.add(donut_token2json(self.tokenizer, sequence)["class"]) 309 | 310 | # Return the first top_k unique classes we see 311 | return [{"label": v} for v in classes] 312 | 313 | def postprocess_standard(self, model_outputs, function_to_apply, **kwargs): 314 | # Average the score across pages 315 | sum_scores = {k: 0 for k in self.model.config.id2label.values()} 316 | for model_output in model_outputs: 317 | outputs = model_output["logits"][0] 318 | outputs = outputs.numpy() 319 | 320 | if function_to_apply == ClassificationFunction.SIGMOID: 321 | scores = sigmoid(outputs) 322 | elif function_to_apply == ClassificationFunction.SOFTMAX: 323 | scores = softmax(outputs) 324 | elif function_to_apply == ClassificationFunction.NONE: 325 | scores = outputs 326 | else: 327 | raise ValueError(f"Unrecognized `function_to_apply` argument: {function_to_apply}") 328 | 329 | for i, score in enumerate(scores): 330 | sum_scores[self.model.config.id2label[i]] += score.item() 331 | 332 | return [{"label": label, "score": score / len(model_outputs)} for (label, score) in 
sum_scores.items()] 333 | -------------------------------------------------------------------------------- /src/docquery/ext/pipeline_document_question_answering.py: -------------------------------------------------------------------------------- 1 | # NOTE: This code is currently under review for inclusion in the main 2 | # huggingface/transformers repository: 3 | # https://github.com/huggingface/transformers/pull/18414 4 | import re 5 | from typing import List, Optional, Tuple, Union 6 | 7 | import numpy as np 8 | from transformers.pipelines.base import PIPELINE_INIT_ARGS, ChunkPipeline 9 | from transformers.utils import ( 10 | ExplicitEnum, 11 | add_end_docstrings, 12 | is_pytesseract_available, 13 | is_torch_available, 14 | is_vision_available, 15 | logging, 16 | ) 17 | 18 | from .itertools import unique_everseen 19 | from .qa_helpers import TESSERACT_LOADED, VISION_LOADED, Image, load_image, pytesseract, select_starts_ends 20 | 21 | 22 | if is_torch_available(): 23 | import torch 24 | 25 | # We do not perform the check in this version of the pipeline code 26 | # from transformers.models.auto.modeling_auto import MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING 27 | 28 | logger = logging.get_logger(__name__) 29 | 30 | 31 | # normalize_bbox() and apply_tesseract() are derived from apply_tesseract in models/layoutlmv3/feature_extraction_layoutlmv3.py. 32 | # However, because the pipeline may evolve from what layoutlmv3 currently does, it's copied (vs. imported) to avoid creating an 33 | # unnecessary dependency. 34 | def normalize_box(box, width, height): 35 | return [ 36 | int(1000 * (box[0] / width)), 37 | int(1000 * (box[1] / height)), 38 | int(1000 * (box[2] / width)), 39 | int(1000 * (box[3] / height)), 40 | ] 41 | 42 | 43 | def apply_tesseract(image: "Image.Image", lang: Optional[str], tesseract_config: Optional[str]): 44 | """Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes.""" 45 | # apply OCR 46 | data = pytesseract.image_to_data(image, lang=lang, output_type="dict", config=tesseract_config) 47 | words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"] 48 | 49 | # filter empty words and corresponding coordinates 50 | irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()] 51 | words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices] 52 | left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices] 53 | top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices] 54 | width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices] 55 | height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices] 56 | 57 | # turn coordinates into (left, top, left+width, top+height) format 58 | actual_boxes = [] 59 | for x, y, w, h in zip(left, top, width, height): 60 | actual_box = [x, y, x + w, y + h] 61 | actual_boxes.append(actual_box) 62 | 63 | image_width, image_height = image.size 64 | 65 | # finally, normalize the bounding boxes 66 | normalized_boxes = [] 67 | for box in actual_boxes: 68 | normalized_boxes.append(normalize_box(box, image_width, image_height)) 69 | 70 | if len(words) != len(normalized_boxes): 71 | raise ValueError("Not as many words as there are bounding boxes") 72 | 73 | return words, normalized_boxes 74 | 75 | 76 | class ModelType(ExplicitEnum): 77 | LayoutLM = "layoutlm" 78 | LayoutLMv2andv3 = "layoutlmv2andv3" 79 | VisionEncoderDecoder = 
"vision_encoder_decoder" 80 | 81 | 82 | ImageOrName = Union["Image.Image", str] 83 | DEFAULT_MAX_ANSWER_LENGTH = 15 84 | 85 | 86 | @add_end_docstrings(PIPELINE_INIT_ARGS) 87 | class DocumentQuestionAnsweringPipeline(ChunkPipeline): 88 | # TODO: Update task_summary docs to include an example with document QA and then update the first sentence 89 | """ 90 | Document Question Answering pipeline using any `AutoModelForDocumentQuestionAnswering`. The inputs/outputs are 91 | similar to the (extractive) question answering pipeline; however, the pipeline takes an image (and optional OCR'd 92 | words/boxes) as input instead of text context. 93 | 94 | This document question answering pipeline can currently be loaded from [`pipeline`] using the following task 95 | identifier: `"document-question-answering"`. 96 | 97 | The models that this pipeline can use are models that have been fine-tuned on a document question answering task. 98 | See the up-to-date list of available models on 99 | [huggingface.co/models](https://huggingface.co/models?filter=document-question-answering). 100 | """ 101 | 102 | def __init__(self, *args, **kwargs): 103 | super().__init__(*args, **kwargs) 104 | if self.model.config.__class__.__name__ == "VisionEncoderDecoderConfig": 105 | self.model_type = ModelType.VisionEncoderDecoder 106 | if self.model.config.encoder.model_type != "donut-swin": 107 | raise ValueError("Currently, the only supported VisionEncoderDecoder model is Donut") 108 | else: 109 | # self.check_model_type(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING) 110 | if self.model.config.__class__.__name__ == "LayoutLMConfig": 111 | self.model_type = ModelType.LayoutLM 112 | else: 113 | self.model_type = ModelType.LayoutLMv2andv3 114 | 115 | def _sanitize_parameters( 116 | self, 117 | padding=None, 118 | doc_stride=None, 119 | max_question_len=None, 120 | lang: Optional[str] = None, 121 | tesseract_config: Optional[str] = None, 122 | max_answer_len=None, 123 | max_seq_len=None, 124 | top_k=None, 125 | handle_impossible_answer=None, 126 | **kwargs, 127 | ): 128 | preprocess_params, postprocess_params = {}, {} 129 | if padding is not None: 130 | preprocess_params["padding"] = padding 131 | if doc_stride is not None: 132 | preprocess_params["doc_stride"] = doc_stride 133 | if max_question_len is not None: 134 | preprocess_params["max_question_len"] = max_question_len 135 | if max_seq_len is not None: 136 | preprocess_params["max_seq_len"] = max_seq_len 137 | if lang is not None: 138 | preprocess_params["lang"] = lang 139 | if tesseract_config is not None: 140 | preprocess_params["tesseract_config"] = tesseract_config 141 | 142 | if top_k is not None: 143 | if top_k < 1: 144 | raise ValueError(f"top_k parameter should be >= 1 (got {top_k})") 145 | postprocess_params["top_k"] = top_k 146 | if max_answer_len is not None: 147 | if max_answer_len < 1: 148 | raise ValueError(f"max_answer_len parameter should be >= 1 (got {max_answer_len}") 149 | postprocess_params["max_answer_len"] = max_answer_len 150 | if handle_impossible_answer is not None: 151 | postprocess_params["handle_impossible_answer"] = handle_impossible_answer 152 | 153 | return preprocess_params, {}, postprocess_params 154 | 155 | def __call__( 156 | self, 157 | image: Union[ImageOrName, List[ImageOrName], List[Tuple]], 158 | question: Optional[str] = None, 159 | **kwargs, 160 | ): 161 | """ 162 | Answer the question(s) given as inputs by using the document(s). 
A document is defined as an image and an 163 | optional list of (word, box) tuples which represent the text in the document. If the `word_boxes` are not 164 | provided, it will use the Tesseract OCR engine (if available) to extract the words and boxes automatically for 165 | LayoutLM-like models which require them as input. For Donut, no OCR is run. 166 | 167 | You can invoke the pipeline several ways: 168 | 169 | - `pipeline(image=image, question=question)` 170 | - `pipeline(image=image, question=question, word_boxes=word_boxes)` 171 | - `pipeline([{"image": image, "question": question}])` 172 | - `pipeline([{"image": image, "question": question, "word_boxes": word_boxes}])` 173 | 174 | Args: 175 | image (`str` or `PIL.Image`): 176 | The pipeline handles three types of images: 177 | 178 | - A string containing a http link pointing to an image 179 | - A string containing a local path to an image 180 | - An image loaded in PIL directly 181 | 182 | The pipeline accepts either a single image or a batch of images. If given a single image, it can be 183 | broadcasted to multiple questions. 184 | question (`str`): 185 | A question to ask of the document. 186 | word_boxes (`List[str, Tuple[float, float, float, float]]`, *optional*): 187 | A list of words and bounding boxes (normalized 0->1000). If you provide this optional input, then the 188 | pipeline will use these words and boxes instead of running OCR on the image to derive them for models 189 | that need them (e.g. LayoutLM). This allows you to reuse OCR'd results across many invocations of the 190 | pipeline without having to re-run it each time. 191 | top_k (`int`, *optional*, defaults to 1): 192 | The number of answers to return (will be chosen by order of likelihood). Note that we return less than 193 | top_k answers if there are not enough options available within the context. 194 | doc_stride (`int`, *optional*, defaults to 128): 195 | If the words in the document are too long to fit with the question for the model, it will be split in 196 | several chunks with some overlap. This argument controls the size of that overlap. 197 | max_answer_len (`int`, *optional*, defaults to 15): 198 | The maximum length of predicted answers (e.g., only answers with a shorter length are considered). 199 | max_seq_len (`int`, *optional*, defaults to 384): 200 | The maximum length of the total sentence (context + question) in tokens of each chunk passed to the 201 | model. The context will be split in several chunks (using `doc_stride` as overlap) if needed. 202 | max_question_len (`int`, *optional*, defaults to 64): 203 | The maximum length of the question after tokenization. It will be truncated if needed. 204 | handle_impossible_answer (`bool`, *optional*, defaults to `False`): 205 | Whether or not we accept impossible as an answer. 206 | lang (`str`, *optional*): 207 | Language to use while running OCR. Defaults to english. 208 | tesseract_config (`str`, *optional*): 209 | Additional flags to pass to tesseract while running OCR. 210 | 211 | Return: 212 | A `dict` or a list of `dict`: Each result comes as a dictionary with the following keys: 213 | 214 | - **score** (`float`) -- The probability associated to the answer. 215 | - **start** (`int`) -- The start word index of the answer (in the OCR'd version of the input or provided 216 | `word_boxes`). 217 | - **end** (`int`) -- The end word index of the answer (in the OCR'd version of the input or provided 218 | `word_boxes`). 219 | - **answer** (`str`) -- The answer to the question. 
220 | - **words** (`list[int]`) -- The index of each word/box pair that is in the answer 221 | - **page** (`int`) -- The page of the answer 222 | """ 223 | if question is None: 224 | question = image["question"] 225 | image = image["image"] 226 | 227 | if isinstance(image, list): 228 | normalized_images = (i if isinstance(i, (tuple, list)) else (i, None) for i in image) 229 | else: 230 | normalized_images = [(image, None)] 231 | 232 | return super().__call__({"question": question, "pages": normalized_images}, **kwargs) 233 | 234 | def preprocess( 235 | self, 236 | input, 237 | padding="do_not_pad", 238 | doc_stride=None, 239 | max_question_len=64, 240 | max_seq_len=None, 241 | word_boxes: Tuple[str, List[float]] = None, 242 | lang=None, 243 | tesseract_config="", 244 | ): 245 | # NOTE: This code mirrors the code in question answering and will be implemented in a follow up PR 246 | # to support documents with enough tokens that overflow the model's window 247 | if max_seq_len is None: 248 | max_seq_len = min(self.tokenizer.model_max_length, 512) 249 | 250 | if doc_stride is None: 251 | doc_stride = min(max_seq_len // 2, 256) 252 | 253 | for page_idx, (image, word_boxes) in enumerate(input["pages"]): 254 | image_features = {} 255 | if image is not None: 256 | image = load_image(image) 257 | if self.feature_extractor is not None: 258 | image_features.update(self.feature_extractor(images=image, return_tensors=self.framework)) 259 | elif self.model_type == ModelType.VisionEncoderDecoder: 260 | raise ValueError( 261 | "If you are using a VisionEncoderDecoderModel, you must provide a feature extractor" 262 | ) 263 | 264 | words, boxes = None, None 265 | if not self.model_type == ModelType.VisionEncoderDecoder: 266 | if word_boxes is not None: 267 | words = [x[0] for x in word_boxes] 268 | boxes = [x[1] for x in word_boxes] 269 | elif "words" in image_features and "boxes" in image_features: 270 | words = image_features.pop("words")[0] 271 | boxes = image_features.pop("boxes")[0] 272 | elif image is not None: 273 | if not TESSERACT_LOADED: 274 | raise ValueError( 275 | "If you provide an image without word_boxes, then the pipeline will run OCR using" 276 | " Tesseract, but pytesseract is not available. Install it with pip install pytesseract." 277 | ) 278 | if TESSERACT_LOADED: 279 | words, boxes = apply_tesseract(image, lang=lang, tesseract_config=tesseract_config) 280 | else: 281 | raise ValueError( 282 | "You must provide an image or word_boxes. 
If you provide an image, the pipeline will"
283 |                         " automatically run OCR to derive words and boxes"
284 |                     )
285 | 
286 |             if self.tokenizer.padding_side != "right":
287 |                 raise ValueError(
288 |                     "Document question answering only supports tokenizers whose padding side is 'right', not"
289 |                     f" {self.tokenizer.padding_side}"
290 |                 )
291 | 
292 |             if self.model_type == ModelType.VisionEncoderDecoder:
293 |                 task_prompt = f'<s_docvqa><s_question>{input["question"]}</s_question><s_answer>'
294 |                 # Adapted from https://huggingface.co/spaces/nielsr/donut-docvqa/blob/main/app.py
295 |                 encoding = {
296 |                     "inputs": image_features["pixel_values"],
297 |                     "decoder_input_ids": self.tokenizer(
298 |                         task_prompt, add_special_tokens=False, return_tensors=self.framework
299 |                     ).input_ids,
300 |                     "return_dict_in_generate": True,
301 |                 }
302 | 
303 |                 yield {
304 |                     **encoding,
305 |                     "p_mask": None,
306 |                     "word_ids": None,
307 |                     "words": None,
308 |                     "page": None,
309 |                     "output_attentions": True,
310 |                 }
311 |             else:
312 |                 tokenizer_kwargs = {}
313 |                 if self.model_type == ModelType.LayoutLM:
314 |                     tokenizer_kwargs["text"] = input["question"].split()
315 |                     tokenizer_kwargs["text_pair"] = words
316 |                     tokenizer_kwargs["is_split_into_words"] = True
317 |                 else:
318 |                     tokenizer_kwargs["text"] = [input["question"]]
319 |                     tokenizer_kwargs["text_pair"] = [words]
320 |                     tokenizer_kwargs["boxes"] = [boxes]
321 | 
322 |                 encoding = self.tokenizer(
323 |                     padding=padding,
324 |                     max_length=max_seq_len,
325 |                     stride=doc_stride,
326 |                     truncation="only_second",
327 |                     return_overflowing_tokens=True,
328 |                     **tokenizer_kwargs,
329 |                 )
330 | 
331 |                 if "pixel_values" in image_features:
332 |                     encoding["image"] = image_features.pop("pixel_values")
333 | 
334 |                 num_spans = len(encoding["input_ids"])
335 | 
336 |                 # p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens which can be in an answer)
337 |                 # We put 0 on the tokens from the context and 1 everywhere else (question and special tokens)
338 |                 # This logic mirrors the logic in the question_answering pipeline
339 |                 p_mask = [[tok != 1 for tok in encoding.sequence_ids(span_id)] for span_id in range(num_spans)]
340 | 
341 |                 for span_idx in range(num_spans):
342 |                     if self.framework == "pt":
343 |                         span_encoding = {k: torch.tensor(v[span_idx : span_idx + 1]) for (k, v) in encoding.items()}
344 |                         span_encoding.update(
345 |                             {k: v for (k, v) in image_features.items()}
346 |                         )  # TODO: Verify cardinality is correct
347 |                     else:
348 |                         raise ValueError("Unsupported: Tensorflow preprocessing for DocumentQuestionAnsweringPipeline")
349 | 
350 |                     input_ids_span_idx = encoding["input_ids"][span_idx]
351 |                     # keep the cls_token unmasked (some models use it to indicate unanswerable questions)
352 |                     if self.tokenizer.cls_token_id is not None:
353 |                         cls_indices = np.nonzero(np.array(input_ids_span_idx) == self.tokenizer.cls_token_id)[0]
354 |                         for cls_index in cls_indices:
355 |                             p_mask[span_idx][cls_index] = 0
356 | 
357 |                     # For each span, place a bounding box [0,0,0,0] for question and CLS tokens, [1000,1000,1000,1000]
358 |                     # for SEP tokens, and the word's bounding box for words in the original document.
359 | if "boxes" not in tokenizer_kwargs: 360 | bbox = [] 361 | 362 | for input_id, sequence_id, word_id in zip( 363 | encoding.input_ids[span_idx], 364 | encoding.sequence_ids(span_idx), 365 | encoding.word_ids(span_idx), 366 | ): 367 | if sequence_id == 1: 368 | bbox.append(boxes[word_id]) 369 | elif input_id == self.tokenizer.sep_token_id: 370 | bbox.append([1000] * 4) 371 | else: 372 | bbox.append([0] * 4) 373 | 374 | if self.framework == "pt": 375 | span_encoding["bbox"] = torch.tensor(bbox).unsqueeze(0) 376 | elif self.framework == "tf": 377 | raise ValueError( 378 | "Unsupported: Tensorflow preprocessing for DocumentQuestionAnsweringPipeline" 379 | ) 380 | 381 | yield { 382 | **span_encoding, 383 | "p_mask": p_mask[span_idx], 384 | "word_ids": encoding.word_ids(span_idx), 385 | "words": words, 386 | "page": page_idx, 387 | } 388 | 389 | def _forward(self, model_inputs): 390 | p_mask = model_inputs.pop("p_mask", None) 391 | word_ids = model_inputs.pop("word_ids", None) 392 | words = model_inputs.pop("words", None) 393 | page = model_inputs.pop("page", None) 394 | 395 | if "overflow_to_sample_mapping" in model_inputs: 396 | model_inputs.pop("overflow_to_sample_mapping") 397 | 398 | if self.model_type == ModelType.VisionEncoderDecoder: 399 | model_outputs = self.model.generate(**model_inputs) 400 | else: 401 | model_outputs = self.model(**model_inputs) 402 | 403 | model_outputs["p_mask"] = p_mask 404 | model_outputs["word_ids"] = word_ids 405 | model_outputs["words"] = words 406 | model_outputs["page"] = page 407 | model_outputs["attention_mask"] = model_inputs.get("attention_mask", None) 408 | return model_outputs 409 | 410 | def postprocess(self, model_outputs, top_k=1, **kwargs): 411 | if self.model_type == ModelType.VisionEncoderDecoder: 412 | answers = [self.postprocess_encoder_decoder_single(o) for o in model_outputs] 413 | else: 414 | answers = self.postprocess_extractive_qa(model_outputs, top_k=top_k, **kwargs) 415 | 416 | answers = sorted(answers, key=lambda x: x.get("score", 0), reverse=True)[:top_k] 417 | return answers 418 | 419 | def postprocess_encoder_decoder_single(self, model_outputs, **kwargs): 420 | sequence = self.tokenizer.batch_decode(model_outputs.sequences)[0] 421 | 422 | # TODO: A lot of this logic is specific to Donut and should probably be handled in the tokenizer 423 | # (see https://github.com/huggingface/transformers/pull/18414/files#r961747408 for more context). 
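        # Illustrative only (not an actual model output): after batch_decode, a Donut sequence looks roughly like
        #     "<s_docvqa><s_question> what is the invoice number?</s_question><s_answer> us-001</s_answer></s>"
        # The code below strips the EOS/PAD tokens and the leading task-start token, then extracts whatever falls
        # inside the <s_answer>...</s_answer> span as the answer.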
424 |         sequence = sequence.replace(self.tokenizer.eos_token, "").replace(self.tokenizer.pad_token, "")
425 |         sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
426 |         ret = {
427 |             "answer": None,
428 |         }
429 | 
430 |         answer = re.search(r"<s_answer>(.*)</s_answer>", sequence)
431 |         if answer is not None:
432 |             ret["answer"] = answer.group(1).strip()
433 |         return ret
434 | 
435 |     def postprocess_extractive_qa(
436 |         self, model_outputs, top_k=1, handle_impossible_answer=False, max_answer_len=None, **kwargs
437 |     ):
438 |         min_null_score = 1000000  # large and positive
439 |         answers = []
440 | 
441 |         if max_answer_len is None:
442 |             if hasattr(self.model.config, "token_classification") and self.model.config.token_classification:
443 |                 # If this model has token classification, then use a much longer max answer length and
444 |                 # let the classifier remove things
445 |                 max_answer_len = self.tokenizer.model_max_length
446 |             else:
447 |                 max_answer_len = DEFAULT_MAX_ANSWER_LENGTH
448 | 
449 |         for output in model_outputs:
450 |             words = output["words"]
451 | 
452 |             starts, ends, scores, min_null_score = select_starts_ends(
453 |                 output["start_logits"],
454 |                 output["end_logits"],
455 |                 output["p_mask"],
456 |                 output["attention_mask"].numpy() if output.get("attention_mask", None) is not None else None,
457 |                 min_null_score,
458 |                 top_k,
459 |                 handle_impossible_answer,
460 |                 max_answer_len,
461 |             )
462 |             word_ids = output["word_ids"]
463 |             for start, end, score in zip(starts, ends, scores):
464 |                 if "token_logits" in output:
465 |                     predicted_token_classes = (
466 |                         output["token_logits"][
467 |                             0,
468 |                             start : end + 1,
469 |                         ]
470 |                         .argmax(axis=1)
471 |                         .cpu()
472 |                         .numpy()
473 |                     )
474 |                     assert np.setdiff1d(predicted_token_classes, [0, 1]).shape == (0,)
475 |                     token_indices = np.flatnonzero(predicted_token_classes) + start
476 |                 else:
477 |                     token_indices = range(start, end + 1)
478 | 
479 |                 answer_word_ids = list(unique_everseen([word_ids[i] for i in token_indices]))
480 |                 if len(answer_word_ids) > 0 and answer_word_ids[0] is not None and answer_word_ids[-1] is not None:
481 |                     answers.append(
482 |                         {
483 |                             "score": float(score),
484 |                             "answer": " ".join(words[i] for i in answer_word_ids),
485 |                             "word_ids": answer_word_ids,
486 |                             "page": output["page"],
487 |                         }
488 |                     )
489 | 
490 |         if handle_impossible_answer:
491 |             answers.append({"score": min_null_score, "answer": "", "start": 0, "end": 0})
492 | 
493 |         return answers
494 | 
--------------------------------------------------------------------------------
/src/docquery/ext/qa_helpers.py:
--------------------------------------------------------------------------------
1 | # NOTE: This code is currently under review for inclusion in the main
2 | # huggingface/transformers repository:
3 | # https://github.com/huggingface/transformers/pull/18414
4 | 
5 | import warnings
6 | from collections.abc import Iterable
7 | from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8 | 
9 | import numpy as np
10 | from transformers.utils import is_pytesseract_available, is_vision_available
11 | 
12 | 
13 | VISION_LOADED = False
14 | if is_vision_available():
15 |     from PIL import Image
16 |     from transformers.image_utils import load_image
17 | 
18 |     VISION_LOADED = True
19 | else:
20 |     Image = None
21 |     load_image = None
22 | 
23 | 
24 | TESSERACT_LOADED = False
25 | if is_pytesseract_available():
26 |     import pytesseract
27 | 
28 |     TESSERACT_LOADED = True
29 | else:
30 |     pytesseract = None
31 | 
32 | 
33 | def decode_spans(
34 |     start: np.ndarray, end: 
np.ndarray, topk: int, max_answer_len: int, undesired_tokens: np.ndarray 35 | ) -> Tuple: 36 | """ 37 | Take the output of any `ModelForQuestionAnswering` and will generate probabilities for each span to be the actual 38 | answer. 39 | 40 | In addition, it filters out some unwanted/impossible cases like answer len being greater than max_answer_len or 41 | answer end position being before the starting position. The method supports output the k-best answer through the 42 | topk argument. 43 | 44 | Args: 45 | start (`np.ndarray`): Individual start probabilities for each token. 46 | end (`np.ndarray`): Individual end probabilities for each token. 47 | topk (`int`): Indicates how many possible answer span(s) to extract from the model output. 48 | max_answer_len (`int`): Maximum size of the answer to extract from the model's output. 49 | undesired_tokens (`np.ndarray`): Mask determining tokens that can be part of the answer 50 | """ 51 | # Ensure we have batch axis 52 | if start.ndim == 1: 53 | start = start[None] 54 | 55 | if end.ndim == 1: 56 | end = end[None] 57 | 58 | # Compute the score of each tuple(start, end) to be the real answer 59 | outer = np.matmul(np.expand_dims(start, -1), np.expand_dims(end, 1)) 60 | 61 | # Remove candidate with end < start and end - start > max_answer_len 62 | candidates = np.tril(np.triu(outer), max_answer_len - 1) 63 | 64 | # Inspired by Chen & al. (https://github.com/facebookresearch/DrQA) 65 | scores_flat = candidates.flatten() 66 | if topk == 1: 67 | idx_sort = [np.argmax(scores_flat)] 68 | elif len(scores_flat) < topk: 69 | idx_sort = np.argsort(-scores_flat) 70 | else: 71 | idx = np.argpartition(-scores_flat, topk)[0:topk] 72 | idx_sort = idx[np.argsort(-scores_flat[idx])] 73 | 74 | starts, ends = np.unravel_index(idx_sort, candidates.shape)[1:] 75 | desired_spans = np.isin(starts, undesired_tokens.nonzero()) & np.isin(ends, undesired_tokens.nonzero()) 76 | starts = starts[desired_spans] 77 | ends = ends[desired_spans] 78 | scores = candidates[0, starts, ends] 79 | 80 | return starts, ends, scores 81 | 82 | 83 | def select_starts_ends( 84 | start, 85 | end, 86 | p_mask, 87 | attention_mask, 88 | min_null_score=1000000, 89 | top_k=1, 90 | handle_impossible_answer=False, 91 | max_answer_len=15, 92 | ): 93 | """ 94 | Takes the raw output of any `ModelForQuestionAnswering` and first normalizes its outputs and then uses 95 | `decode_spans()` to generate probabilities for each span to be the actual answer. 96 | 97 | Args: 98 | start (`np.ndarray`): Individual start probabilities for each token. 99 | end (`np.ndarray`): Individual end probabilities for each token. 100 | p_mask (`np.ndarray`): A mask with 1 for values that cannot be in the answer 101 | attention_mask (`np.ndarray`): The attention mask generated by the tokenizer 102 | min_null_score(`float`): The minimum null (empty) answer score seen so far. 103 | topk (`int`): Indicates how many possible answer span(s) to extract from the model output. 104 | handle_impossible_answer(`bool`): Whether to allow null (empty) answers 105 | max_answer_len (`int`): Maximum size of the answer to extract from the model's output. 106 | """ 107 | # Ensure padded tokens & question tokens cannot belong to the set of candidate answers. 
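    # For example (illustrative values only): with p_mask = [1, 1, 0, 0, 1] and attention_mask = [1, 1, 1, 1, 0],
    # undesired_tokens works out to [0, 0, 1, 1, 0], so only the two context tokens can start or end a span; the
    # logits everywhere else are pushed to -10000 below and contribute essentially zero probability after
    # normalization.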
108 | undesired_tokens = np.abs(np.array(p_mask) - 1) 109 | 110 | if attention_mask is not None: 111 | undesired_tokens = undesired_tokens & attention_mask 112 | 113 | # Generate mask 114 | undesired_tokens_mask = undesired_tokens == 0.0 115 | 116 | # Make sure non-context indexes in the tensor cannot contribute to the softmax 117 | start = np.where(undesired_tokens_mask, -10000.0, start) 118 | end = np.where(undesired_tokens_mask, -10000.0, end) 119 | 120 | # Normalize logits and spans to retrieve the answer 121 | start = np.exp(start - start.max(axis=-1, keepdims=True)) 122 | start = start / start.sum() 123 | 124 | end = np.exp(end - end.max(axis=-1, keepdims=True)) 125 | end = end / end.sum() 126 | 127 | if handle_impossible_answer: 128 | min_null_score = min(min_null_score, (start[0, 0] * end[0, 0]).item()) 129 | 130 | # Mask CLS 131 | start[0, 0] = end[0, 0] = 0.0 132 | 133 | starts, ends, scores = decode_spans(start, end, top_k, max_answer_len, undesired_tokens) 134 | return starts, ends, scores, min_null_score 135 | -------------------------------------------------------------------------------- /src/docquery/find_leaf_nodes.js: -------------------------------------------------------------------------------- 1 | // This is adapted from code generated by GPT-3, using the following prompt: 2 | // Please write the code for a javascript program that can traverse the DOM and return each text node and its bounding box. If the text has multiple bounding boxes, it should combine them into one. The output should be Javascript code that can run in Chrome. 3 | 4 | function computeViewport() { 5 | return { 6 | vw: Math.max( 7 | document.documentElement.clientWidth || 0, 8 | window.innerWidth || 0 9 | ), 10 | vh: Math.max( 11 | document.documentElement.clientHeight || 0, 12 | window.innerHeight || 0 13 | ), 14 | 15 | // https://stackoverflow.com/questions/1145850/how-to-get-height-of-entire-document-with-javascript 16 | dh: Math.max( 17 | document.body.scrollHeight, 18 | document.body.offsetHeight, 19 | document.documentElement.clientHeight, 20 | document.documentElement.scrollHeight, 21 | document.documentElement.offsetHeight 22 | ), 23 | }; 24 | } 25 | 26 | function findLeafNodes(node) { 27 | var textNodes = []; 28 | var walk = document.createTreeWalker( 29 | document.body, 30 | NodeFilter.SHOW_TEXT, 31 | null, 32 | false 33 | ); 34 | 35 | while (walk.nextNode()) { 36 | var node = walk.currentNode; 37 | var range = document.createRange(); 38 | range.selectNodeContents(node); 39 | var rects = Array.from(range.getClientRects()); 40 | if (rects.length > 0) { 41 | var box = rects.reduce(function (previousValue, currentValue) { 42 | return { 43 | top: Math.min(previousValue.top, currentValue.top), 44 | right: Math.max(previousValue.right, currentValue.right), 45 | bottom: Math.max(previousValue.bottom, currentValue.bottom), 46 | left: Math.min(previousValue.left, currentValue.left), 47 | }; 48 | }); 49 | 50 | textNodes.push({ 51 | text: node.textContent, 52 | box, 53 | }); 54 | } 55 | } 56 | 57 | return textNodes; 58 | } 59 | 60 | function computeBoundingBoxes(node) { 61 | const word_boxes = findLeafNodes(node); 62 | return { 63 | ...computeViewport(), 64 | word_boxes, 65 | }; 66 | } 67 | -------------------------------------------------------------------------------- /src/docquery/ocr_reader.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import logging 3 | from typing import Any, List, Optional, Tuple 4 | 5 | import numpy as np 6 | from 
pydantic.fields import ModelField 7 | 8 | 9 | class NoOCRReaderFound(Exception): 10 | def __init__(self, e): 11 | self.e = e 12 | 13 | def __str__(self): 14 | return f"Could not load OCR Reader: {self.e}" 15 | 16 | 17 | OCR_AVAILABLE = { 18 | "tesseract": False, 19 | "easyocr": False, 20 | "dummy": True, 21 | } 22 | 23 | try: 24 | import pytesseract # noqa 25 | 26 | pytesseract.get_tesseract_version() 27 | OCR_AVAILABLE["tesseract"] = True 28 | except ImportError: 29 | pass 30 | except pytesseract.TesseractNotFoundError as e: 31 | logging.warning("Unable to find tesseract: %s." % (e)) 32 | pass 33 | 34 | try: 35 | import easyocr # noqa 36 | 37 | OCR_AVAILABLE["easyocr"] = True 38 | except ImportError: 39 | pass 40 | 41 | 42 | class SingletonMeta(type): 43 | _instances = {} 44 | 45 | def __call__(cls, *args, **kwargs): 46 | if cls not in cls._instances: 47 | instance = super().__call__(*args, **kwargs) 48 | cls._instances[cls] = instance 49 | return cls._instances[cls] 50 | 51 | 52 | class OCRReader(metaclass=SingletonMeta): 53 | def __init__(self): 54 | # TODO: add device here 55 | self._check_if_available() 56 | 57 | @classmethod 58 | def __get_validators__(cls): 59 | yield cls.validate 60 | 61 | @classmethod 62 | def validate(cls, v, field: ModelField): 63 | if not isinstance(v, cls): 64 | raise TypeError("Invalid value") 65 | return v 66 | 67 | @abc.abstractmethod 68 | def apply_ocr(self, image: "Image.Image") -> Tuple[List[Any], List[List[int]]]: 69 | raise NotImplementedError 70 | 71 | @staticmethod 72 | @abc.abstractmethod 73 | def _check_if_available(): 74 | raise NotImplementedError 75 | 76 | 77 | class TesseractReader(OCRReader): 78 | def __init__(self): 79 | super().__init__() 80 | 81 | def apply_ocr(self, image: "Image.Image") -> Tuple[List[str], List[List[int]]]: 82 | """ 83 | Applies Tesseract on a document image, and returns recognized words + normalized bounding boxes. 84 | This was derived from LayoutLM preprocessing code in Huggingface's Transformers library. 85 | """ 86 | data = pytesseract.image_to_data(image, output_type="dict") 87 | words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"] 88 | 89 | # filter empty words and corresponding coordinates 90 | irrelevant_indices = set(idx for idx, word in enumerate(words) if not word.strip()) 91 | words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices] 92 | left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices] 93 | top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices] 94 | width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices] 95 | height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices] 96 | 97 | # turn coordinates into (left, top, left+width, top+height) format 98 | actual_boxes = [[x, y, x + w, y + h] for x, y, w, h in zip(left, top, width, height)] 99 | 100 | return words, actual_boxes 101 | 102 | @staticmethod 103 | def _check_if_available(): 104 | if not OCR_AVAILABLE["tesseract"]: 105 | raise NoOCRReaderFound( 106 | "Unable to use pytesseract (OCR will be unavailable). Install tesseract to process images with OCR." 
107 | ) 108 | 109 | 110 | class EasyOCRReader(OCRReader): 111 | def __init__(self): 112 | super().__init__() 113 | self.reader = None 114 | 115 | def apply_ocr(self, image: "Image.Image") -> Tuple[List[str], List[List[int]]]: 116 | """Applies Easy OCR on a document image, and returns recognized words + normalized bounding boxes.""" 117 | if not self.reader: 118 | # TODO: expose language currently setting to english 119 | self.reader = easyocr.Reader(["en"]) # TODO: device here example: gpu=self.device > -1) 120 | 121 | # apply OCR 122 | data = self.reader.readtext(np.array(image)) 123 | boxes, words, acc = list(map(list, zip(*data))) 124 | 125 | # filter empty words and corresponding coordinates 126 | irrelevant_indices = set(idx for idx, word in enumerate(words) if not word.strip()) 127 | words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices] 128 | boxes = [coords for idx, coords in enumerate(boxes) if idx not in irrelevant_indices] 129 | 130 | # turn coordinates into (left, top, left+width, top+height) format 131 | actual_boxes = [tl + br for tl, tr, br, bl in boxes] 132 | 133 | return words, actual_boxes 134 | 135 | @staticmethod 136 | def _check_if_available(): 137 | if not OCR_AVAILABLE["easyocr"]: 138 | raise NoOCRReaderFound( 139 | "Unable to use easyocr (OCR will be unavailable). Install easyocr to process images with OCR." 140 | ) 141 | 142 | 143 | class DummyOCRReader(OCRReader): 144 | def __init__(self): 145 | super().__init__() 146 | self.reader = None 147 | 148 | def apply_ocr(self, image: "Image.Image") -> Tuple[(List[str], List[List[int]])]: 149 | raise NoOCRReaderFound("Unable to find any OCR reader and OCR extraction was requested") 150 | 151 | @staticmethod 152 | def _check_if_available(): 153 | logging.warning("Falling back to a dummy OCR reader since none were found.") 154 | 155 | 156 | OCR_MAPPING = { 157 | "tesseract": TesseractReader, 158 | "easyocr": EasyOCRReader, 159 | "dummy": DummyOCRReader, 160 | } 161 | 162 | 163 | def get_ocr_reader(ocr_reader_name: Optional[str] = None): 164 | if not ocr_reader_name: 165 | for name, reader in OCR_MAPPING.items(): 166 | if OCR_AVAILABLE[name]: 167 | return reader() 168 | 169 | if ocr_reader_name in OCR_MAPPING.keys(): 170 | if OCR_AVAILABLE[ocr_reader_name]: 171 | return OCR_MAPPING[ocr_reader_name]() 172 | else: 173 | raise NoOCRReaderFound(f"Failed to load: {ocr_reader_name} Please make sure its installed correctly.") 174 | else: 175 | raise NoOCRReaderFound( 176 | f"Failed to find: {ocr_reader_name} in the available ocr libraries. The choices are: {list(OCR_MAPPING.keys())}" 177 | ) 178 | -------------------------------------------------------------------------------- /src/docquery/transformers_patch.py: -------------------------------------------------------------------------------- 1 | # This file contains extensions to transformers that have not yet been upstreamed. Importantly, since docquery 2 | # is designed to be easy to install via PyPI, we must extend anything that is not part of an official release, 3 | # since libraries on pypi are not permitted to install specific git commits. 
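# A minimal usage sketch (it mirrors the tests in this repository; "invoice.pdf" and the question are
# placeholders):
#
#     from docquery import pipeline
#     from docquery.document import load_document
#
#     doc = load_document("invoice.pdf")
#     qa = pipeline("document-question-answering")
#     print(qa(question="What is the invoice number?", **doc.context, top_k=1))
#
# This module registers the document pipelines with transformers at import time, and the pipeline() wrapper
# defined below resolves a default checkpoint and pinned revision (PIPELINE_DEFAULTS / DEFAULT_REVISIONS)
# before deferring to transformers.pipeline.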
4 | 5 | from collections import OrderedDict 6 | from typing import Optional, Union 7 | 8 | import torch 9 | from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast 10 | from transformers import pipeline as transformers_pipeline 11 | from transformers.models.auto.auto_factory import _BaseAutoModelClass, _LazyAutoMapping 12 | from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES 13 | from transformers.pipelines import PIPELINE_REGISTRY 14 | 15 | from .ext.model import LayoutLMForQuestionAnswering 16 | from .ext.pipeline_document_classification import DocumentClassificationPipeline 17 | from .ext.pipeline_document_question_answering import DocumentQuestionAnsweringPipeline 18 | 19 | 20 | PIPELINE_DEFAULTS = { 21 | "document-question-answering": "impira/layoutlm-document-qa", 22 | "document-classification": "impira/layoutlm-document-classifier", 23 | } 24 | 25 | # These revisions are pinned so that the "default" experience in DocQuery is both fast (does not 26 | # need to check network for updates) and versioned (we can be sure that the model changes 27 | # result in new versions of DocQuery). This may eventually change. 28 | DEFAULT_REVISIONS = { 29 | "impira/layoutlm-document-qa": "ff904df", 30 | "impira/layoutlm-invoices": "783b0c2", 31 | "naver-clova-ix/donut-base-finetuned-rvlcdip": "5998e9f", 32 | # XXX add impira-document-classifier 33 | } 34 | 35 | MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( 36 | [ 37 | ("layoutlm", "LayoutLMForQuestionAnswering"), 38 | ("donut-swin", "DonutSwinModel"), 39 | ] 40 | ) 41 | 42 | MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( 43 | CONFIG_MAPPING_NAMES, 44 | MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES, 45 | ) 46 | 47 | 48 | MODEL_FOR_DOCUMENT_CLASSIFICATION_MAPPING_NAMES = OrderedDict( 49 | [ 50 | ("layoutlm", "LayoutLMForSequenceClassification"), 51 | ] 52 | ) 53 | 54 | MODEL_FOR_DOCUMENT_CLASSIFICATION_MAPPING = _LazyAutoMapping( 55 | CONFIG_MAPPING_NAMES, 56 | MODEL_FOR_DOCUMENT_CLASSIFICATION_MAPPING_NAMES, 57 | ) 58 | 59 | 60 | class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass): 61 | _model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING 62 | 63 | 64 | class AutoModelForDocumentClassification(_BaseAutoModelClass): 65 | _model_mapping = MODEL_FOR_DOCUMENT_CLASSIFICATION_MAPPING 66 | 67 | 68 | PIPELINE_REGISTRY.register_pipeline( 69 | "document-question-answering", 70 | pipeline_class=DocumentQuestionAnsweringPipeline, 71 | pt_model=AutoModelForDocumentQuestionAnswering, 72 | ) 73 | 74 | PIPELINE_REGISTRY.register_pipeline( 75 | "document-classification", 76 | pipeline_class=DocumentClassificationPipeline, 77 | pt_model=AutoModelForDocumentClassification, 78 | ) 79 | 80 | 81 | def pipeline( 82 | task: str = None, 83 | model: Optional = None, 84 | tokenizer: Optional[Union[str, PreTrainedTokenizer, PreTrainedTokenizerFast]] = None, 85 | revision: Optional[str] = None, 86 | device: Optional[Union[int, str, "torch.device"]] = None, 87 | **pipeline_kwargs 88 | ): 89 | 90 | if model is None and task is not None: 91 | model = PIPELINE_DEFAULTS.get(task) 92 | 93 | if revision is None and model is not None: 94 | revision = DEFAULT_REVISIONS.get(model) 95 | 96 | # We need to explicitly check for the impira/layoutlm-document-qa model because of challenges with 97 | # registering an existing model "flavor" (layoutlm) within transformers after the fact. There may 98 | # be a clever way to get around this. 
Either way, we should be able to remove it once 99 | # https://github.com/huggingface/transformers/commit/5c4c869014f5839d04c1fd28133045df0c91fd84 100 | # is officially released. 101 | config = AutoConfig.from_pretrained(model, revision=revision, **{**pipeline_kwargs}) 102 | 103 | if tokenizer is None: 104 | tokenizer = AutoTokenizer.from_pretrained( 105 | model, 106 | revision=revision, 107 | config=config, 108 | **pipeline_kwargs, 109 | ) 110 | 111 | if any(a == "LayoutLMForQuestionAnswering" for a in config.architectures): 112 | model = LayoutLMForQuestionAnswering.from_pretrained( 113 | model, config=config, revision=revision, **{**pipeline_kwargs} 114 | ) 115 | 116 | if config.model_type == "vision-encoder-decoder": 117 | # This _should_ be a feature of transformers -- deriving the feature_extractor automatically -- 118 | # but is not at the time of writing, so we do it explicitly. 119 | pipeline_kwargs["feature_extractor"] = model 120 | 121 | if device is None: 122 | # This trick merely simplifies the device argument, so that cuda is used by default if it's 123 | # available, which at the time of writing is not a feature of transformers 124 | device = 0 if torch.cuda.is_available() else -1 125 | 126 | return transformers_pipeline( 127 | task, 128 | revision=revision, 129 | model=model, 130 | tokenizer=tokenizer, 131 | device=device, 132 | **pipeline_kwargs, 133 | ) 134 | -------------------------------------------------------------------------------- /src/docquery/version.py: -------------------------------------------------------------------------------- 1 | VERSION = "0.0.7" 2 | -------------------------------------------------------------------------------- /src/docquery/web.py: -------------------------------------------------------------------------------- 1 | import os 2 | from io import BytesIO 3 | from pathlib import Path 4 | 5 | from PIL import Image 6 | 7 | from .config import get_logger 8 | from .ext.functools import cached_property 9 | 10 | 11 | log = get_logger("web") 12 | 13 | try: 14 | from selenium import webdriver 15 | from selenium.common import exceptions 16 | from selenium.webdriver.chrome.options import Options 17 | from webdriver_manager.chrome import ChromeDriverManager 18 | from webdriver_manager.core.utils import ChromeType 19 | 20 | WEB_AVAILABLE = True 21 | except ImportError: 22 | WEB_AVAILABLE = False 23 | 24 | 25 | FIND_LEAF_NODES_JS = None 26 | WEB_DRIVER = None 27 | dir_path = Path(os.path.dirname(os.path.realpath(__file__))) 28 | 29 | 30 | class WebDriver: 31 | def __init__(self): 32 | if not WEB_AVAILABLE: 33 | raise ValueError( 34 | "Web imports are unavailable. You must install the [web] extra and chrome or" " chromium system-wide." 
35 | ) 36 | 37 | self._reinit_driver() 38 | 39 | def _reinit_driver(self): 40 | options = Options() 41 | options.headless = True 42 | options.add_argument("--window-size=1920,1200") 43 | if os.geteuid() == 0: 44 | options.add_argument("--no-sandbox") 45 | 46 | self.driver = webdriver.Chrome( 47 | options=options, executable_path=ChromeDriverManager(chrome_type=ChromeType.CHROMIUM).install() 48 | ) 49 | 50 | def get(self, page, retry=True): 51 | try: 52 | self.driver.get(page) 53 | except exceptions.InvalidSessionIdException: 54 | if retry: 55 | # Forgive an invalid session once and try again 56 | self._reinit_driver() 57 | return self.get(page, retry=False) 58 | else: 59 | raise 60 | 61 | def get_html(self, html): 62 | # https://stackoverflow.com/questions/22538457/put-a-string-with-html-javascript-into-selenium-webdriver 63 | self.get("data:text/html;charset=utf-8," + html) 64 | 65 | def find_word_boxes(self): 66 | # Assumes the driver has been pointed at the right website already 67 | return self.driver.execute_script( 68 | self.lib_js 69 | + """ 70 | return computeBoundingBoxes(document.body); 71 | """ 72 | ) 73 | 74 | # TODO: Handle horizontal scrolling 75 | def scroll_and_screenshot(self): 76 | tops = [] 77 | images = [] 78 | dims = self.driver.execute_script( 79 | self.lib_js 80 | + """ 81 | return computeViewport() 82 | """ 83 | ) 84 | 85 | view_height = dims["vh"] 86 | doc_height = dims["dh"] 87 | 88 | try: 89 | self.driver.execute_script("window.scroll(0, 0)") 90 | curr = self.driver.execute_script("return window.scrollY") 91 | 92 | while True: 93 | tops.append(curr) 94 | images.append(Image.open(BytesIO(self.driver.get_screenshot_as_png()))) 95 | if curr + view_height < doc_height: 96 | self.driver.execute_script(f"window.scroll(0, {curr+view_height})") 97 | 98 | curr = self.driver.execute_script("return window.scrollY") 99 | if curr <= tops[-1]: 100 | break 101 | finally: 102 | # Reset scroll to the top of the page 103 | self.driver.execute_script("window.scroll(0, 0)") 104 | 105 | if len(tops) >= 2: 106 | _, second_last_height = images[-2].size 107 | if tops[-1] - tops[-2] < second_last_height: 108 | # This occurs when the last screenshot should be "clipped". 
Adjust the last "top" 109 | # to correspond to the right view_height and clip the screenshot accordingly 110 | delta = second_last_height - (tops[-1] - tops[-2]) 111 | tops[-1] += delta 112 | 113 | last_img = images[-1] 114 | last_width, last_height = last_img.size 115 | images[-1] = last_img.crop((0, delta, last_width, last_height)) 116 | 117 | return tops, images 118 | 119 | @cached_property 120 | def lib_js(self): 121 | with open(dir_path / "find_leaf_nodes.js", "r") as f: 122 | return f.read() 123 | 124 | 125 | def get_webdriver(): 126 | global WEB_DRIVER 127 | if WEB_DRIVER is None: 128 | WEB_DRIVER = WebDriver() 129 | return WEB_DRIVER 130 | -------------------------------------------------------------------------------- /tests/test_classification_end_to_end.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import pytest 4 | from pydantic import BaseModel 5 | from transformers.testing_utils import nested_simplify 6 | 7 | from docquery import pipeline 8 | from docquery.document import load_document 9 | from docquery.ocr_reader import TesseractReader 10 | 11 | 12 | CHECKPOINTS = { 13 | "Donut": "naver-clova-ix/donut-base-finetuned-rvlcdip", 14 | } 15 | 16 | 17 | class Example(BaseModel): 18 | name: str 19 | path: str 20 | classes: Dict[str, List[str]] 21 | 22 | 23 | # Use the examples from the DocQuery space (this also solves for hosting) 24 | EXAMPLES = [ 25 | Example( 26 | name="contract", 27 | path="https://huggingface.co/spaces/impira/docquery/resolve/2f6c96314dc84dfda62d40de9da55f2f5165d403/contract.jpeg", 28 | classes={"Donut": ["scientific_report"]}, 29 | ), 30 | Example( 31 | name="invoice", 32 | path="https://huggingface.co/spaces/impira/docquery/resolve/2f6c96314dc84dfda62d40de9da55f2f5165d403/invoice.png", 33 | classes={"Donut": ["invoice"]}, 34 | ), 35 | Example( 36 | name="statement", 37 | path="https://huggingface.co/spaces/impira/docquery/resolve/2f6c96314dc84dfda62d40de9da55f2f5165d403/statement.png", 38 | classes={"Donut": ["budget"]}, 39 | ), 40 | ] 41 | 42 | 43 | @pytest.mark.parametrize("example", EXAMPLES) 44 | @pytest.mark.parametrize("model", CHECKPOINTS.keys()) 45 | def test_impira_dataset(example, model): 46 | document = load_document(example.path) 47 | pipe = pipeline("document-classification", model=CHECKPOINTS[model]) 48 | resp = pipe(top_k=1, **document.context) 49 | assert resp == [{"label": x} for x in example.classes[model]] 50 | -------------------------------------------------------------------------------- /tests/test_end_to_end.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import pytest 4 | from pydantic import BaseModel 5 | from transformers.testing_utils import nested_simplify 6 | 7 | from docquery import pipeline 8 | from docquery.document import load_document 9 | from docquery.ocr_reader import TesseractReader 10 | 11 | 12 | CHECKPOINTS = { 13 | "LayoutLMv1": "impira/layoutlm-document-qa", 14 | "LayoutLMv1-Invoices": "impira/layoutlm-invoices", 15 | "Donut": "naver-clova-ix/donut-base-finetuned-docvqa", 16 | } 17 | 18 | 19 | class QAPair(BaseModel): 20 | question: str 21 | answers: Dict[str, List[Dict]] 22 | 23 | 24 | class Example(BaseModel): 25 | name: str 26 | path: str 27 | qa_pairs: List[QAPair] 28 | 29 | 30 | # Use the examples from the DocQuery space (this also solves for hosting) 31 | EXAMPLES = [ 32 | Example( 33 | name="contract", 34 | 
path="https://huggingface.co/spaces/impira/docquery/resolve/2f6c96314dc84dfda62d40de9da55f2f5165d403/contract.jpeg", 35 | qa_pairs=[ 36 | { 37 | "question": "What is the purchase amount?", 38 | "answers": { 39 | "LayoutLMv1": [{"score": 0.9999, "answer": "$1,000,000,000", "word_ids": [97], "page": 0}], 40 | "LayoutLMv1-Invoices": [ 41 | {"score": 0.9997, "answer": "$1,000,000,000", "word_ids": [97], "page": 0} 42 | ], 43 | "Donut": [{"answer": "$1,0000,000,00"}], 44 | }, 45 | } 46 | ], 47 | ), 48 | Example( 49 | name="invoice", 50 | path="https://huggingface.co/spaces/impira/docquery/resolve/2f6c96314dc84dfda62d40de9da55f2f5165d403/invoice.png", 51 | qa_pairs=[ 52 | { 53 | "question": "What is the invoice number?", 54 | "answers": { 55 | "LayoutLMv1": [{"score": 0.9997, "answer": "us-001", "word_ids": [15], "page": 0}], 56 | "LayoutLMv1-Invoices": [{"score": 0.9999, "answer": "us-001", "word_ids": [15], "page": 0}], 57 | "Donut": [{"answer": "us-001"}], 58 | }, 59 | } 60 | ], 61 | ), 62 | Example( 63 | name="statement", 64 | path="https://huggingface.co/spaces/impira/docquery/resolve/2f6c96314dc84dfda62d40de9da55f2f5165d403/statement.pdf", 65 | qa_pairs=[ 66 | { 67 | "question": "What are net sales for 2020?", 68 | "answers": { 69 | "LayoutLMv1": [{"score": 0.9429, "answer": "$ 3,750\n", "word_ids": [15, 16], "page": 0}], 70 | # (The answer with `use_embedded_text=False` relies entirely on Tesseract, and it is incorrect because it 71 | # misses 3,750 altogether.) 72 | "LayoutLMv1__use_embedded_text=False": [ 73 | {"score": 0.3078, "answer": "$ 3,980", "word_ids": [11, 12], "page": 0} 74 | ], 75 | "LayoutLMv1-Invoices": [{"score": 0.9956, "answer": "$ 3,750\n", "word_ids": [15, 16], "page": 0}], 76 | "Donut": [{"answer": "$ 3,750"}], 77 | }, 78 | } 79 | ], 80 | ), 81 | Example( 82 | name="readme", 83 | path="https://github.com/impira/docquery/blob/ef73fa7e8069773ace03efae2254f3a510a814ef/README.md", 84 | qa_pairs=[ 85 | { 86 | "question": "What are the use cases for DocQuery?", 87 | "answers": { 88 | # These examples demonstrate the fact that the "word_boxes" are way too coarse in the web document implementation 89 | "LayoutLMv1": [ 90 | { 91 | "score": 0.9921, 92 | "answer": "DocQuery is a swiss army knife tool for working with documents and experiencing the power of modern machine learning. You can use it\njust about anywhere, including behind a firewall on sensitive data, and test it with a wide variety of documents. Our hope is that\nDocQuery enables many creative use cases for document understanding by making it simple and easy to ask questions from your documents.", 93 | "word_ids": [37], 94 | "page": 2, 95 | } 96 | ], 97 | "LayoutLMv1-Invoices": [ 98 | { 99 | "score": 0.9939, 100 | "answer": "DocQuery is a library and command-line tool that makes it easy to analyze semi-structured and unstructured documents (PDFs, scanned\nimages, etc.) using large language models (LLMs). You simply point DocQuery at one or more documents and specify a\nquestion you want to ask. 
DocQuery is created by the team at ", 101 | "word_ids": [98], 102 | "page": 0, 103 | } 104 | ], 105 | "Donut": [{"answer": "engine Powered by large language"}], 106 | }, 107 | } 108 | ], 109 | ), 110 | ] 111 | 112 | 113 | @pytest.mark.parametrize("example", EXAMPLES) 114 | @pytest.mark.parametrize("model", CHECKPOINTS.keys()) 115 | def test_impira_dataset(example, model): 116 | document = load_document(example.path) 117 | pipe = pipeline("document-question-answering", model=CHECKPOINTS[model]) 118 | for qa in example.qa_pairs: 119 | resp = pipe(question=qa.question, **document.context, top_k=1) 120 | assert nested_simplify(resp, decimals=4) == qa.answers[model] 121 | 122 | 123 | def test_run_with_choosen_OCR_str(): 124 | example = EXAMPLES[0] 125 | document = load_document(example.path, "tesseract") 126 | pipe = pipeline("document-question-answering", model=CHECKPOINTS["LayoutLMv1"]) 127 | for qa in example.qa_pairs: 128 | resp = pipe(question=qa.question, **document.context, top_k=1) 129 | assert nested_simplify(resp, decimals=4) == qa.answers["LayoutLMv1"] 130 | 131 | 132 | def test_run_with_choosen_OCR_instance(): 133 | example = EXAMPLES[0] 134 | reader = TesseractReader() 135 | document = load_document(example.path, reader) 136 | pipe = pipeline("document-question-answering", model=CHECKPOINTS["LayoutLMv1"]) 137 | for qa in example.qa_pairs: 138 | resp = pipe(question=qa.question, **document.context, top_k=1) 139 | assert nested_simplify(resp, decimals=4) == qa.answers["LayoutLMv1"] 140 | 141 | 142 | def test_run_with_ignore_embedded_text(): 143 | example = EXAMPLES[2] 144 | document = load_document(example.path, use_embedded_text=False) 145 | pipe = pipeline("document-question-answering", model=CHECKPOINTS["LayoutLMv1"]) 146 | for qa in example.qa_pairs: 147 | resp = pipe(question=qa.question, **document.context, top_k=1) 148 | assert nested_simplify(resp, decimals=4) == qa.answers["LayoutLMv1__use_embedded_text=False"] 149 | -------------------------------------------------------------------------------- /tests/test_ocr_reader.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | 3 | import pytest 4 | 5 | from docquery.ocr_reader import DummyOCRReader, EasyOCRReader, NoOCRReaderFound, TesseractReader, get_ocr_reader 6 | 7 | 8 | READER_PERMUTATIONS = [ 9 | {"name": "tesseract", "reader_class": TesseractReader}, 10 | {"name": "easyocr", "reader_class": EasyOCRReader}, 11 | {"name": "dummy", "reader_class": DummyOCRReader}, 12 | ] 13 | 14 | 15 | @pytest.mark.parametrize("reader_permutations", READER_PERMUTATIONS) 16 | @patch( 17 | "docquery.ocr_reader.OCR_AVAILABLE", 18 | { 19 | "tesseract": True, 20 | "easyocr": True, 21 | "dummy": True, 22 | }, 23 | ) 24 | def test_get_ocr_reader(reader_permutations): 25 | reader = get_ocr_reader(reader_permutations["name"]) 26 | assert isinstance(reader, reader_permutations["reader_class"]) 27 | 28 | 29 | @patch( 30 | "docquery.ocr_reader.OCR_AVAILABLE", 31 | { 32 | "tesseract": True, 33 | "easyocr": True, 34 | "dummy": True, 35 | }, 36 | ) 37 | def test_wrong_string_ocr_reader(): 38 | with pytest.raises(Exception) as e: 39 | reader = get_ocr_reader("FAKE_OCR") 40 | assert ( 41 | "Failed to find: FAKE_OCR in the available ocr libraries. 
The choices are: ['tesseract', 'easyocr', 'dummy']" 42 | in str(e.value) 43 | ) 44 | assert e.type == NoOCRReaderFound 45 | 46 | 47 | @patch( 48 | "docquery.ocr_reader.OCR_AVAILABLE", 49 | { 50 | "tesseract": False, 51 | "easyocr": True, 52 | "dummy": True, 53 | }, 54 | ) 55 | def test_choosing_unavailable_ocr_reader(): 56 | with pytest.raises(Exception) as e: 57 | reader = get_ocr_reader("tesseract") 58 | assert f"Failed to load: tesseract Please make sure its installed correctly." in str(e.value) 59 | assert e.type == NoOCRReaderFound 60 | 61 | 62 | @patch( 63 | "docquery.ocr_reader.OCR_AVAILABLE", 64 | { 65 | "tesseract": False, 66 | "easyocr": True, 67 | "dummy": True, 68 | }, 69 | ) 70 | def test_assert_fallback(): 71 | reader = get_ocr_reader() 72 | assert isinstance(reader, EasyOCRReader) 73 | 74 | 75 | @patch( 76 | "docquery.ocr_reader.OCR_AVAILABLE", 77 | { 78 | "tesseract": False, 79 | "easyocr": False, 80 | "dummy": True, 81 | }, 82 | ) 83 | def test_assert_fallback_to_dummy(): 84 | reader = get_ocr_reader() 85 | assert isinstance(reader, DummyOCRReader) 86 | 87 | 88 | @patch( 89 | "docquery.ocr_reader.OCR_AVAILABLE", 90 | { 91 | "tesseract": False, 92 | "easyocr": False, 93 | "dummy": False, 94 | }, 95 | ) 96 | def test_fail_to_load_if_called_directly_when_ocr_unavailable(): 97 | EasyOCRReader._instances = {} 98 | with pytest.raises(Exception) as e: 99 | reader = EasyOCRReader() 100 | assert "Unable to use easyocr (OCR will be unavailable). Install easyocr to process images with OCR." in str( 101 | e.value 102 | ) 103 | assert e.type == NoOCRReaderFound 104 | 105 | 106 | def test_ocr_reader_are_singletons(): 107 | reader_a = DummyOCRReader() 108 | reader_b = DummyOCRReader() 109 | reader_c = DummyOCRReader() 110 | assert reader_a is reader_b 111 | assert reader_a is reader_c 112 | -------------------------------------------------------------------------------- /tests/test_web_driver.py: -------------------------------------------------------------------------------- 1 | from docquery.web import get_webdriver 2 | 3 | 4 | def test_singleton(): 5 | d1 = get_webdriver() 6 | d2 = get_webdriver() 7 | assert d1 is d2, "Both webdrivers should map to the same instance" 8 | 9 | 10 | def test_readme_file(): 11 | driver = get_webdriver() 12 | driver.get("https://github.com/impira/docquery/blob/ef73fa7e8069773ace03efae2254f3a510a814ef/README.md") 13 | word_boxes = driver.find_word_boxes() 14 | 15 | # This sanity checks the logic that merges word boxes 16 | assert len(word_boxes["word_boxes"]) > 20, "Expect multiple word boxes" 17 | 18 | # Make sure the last screenshot is shorter than the previous ones 19 | _, screenshots = driver.scroll_and_screenshot() 20 | assert len(screenshots) > 1, "Expect multiple pages" 21 | assert ( 22 | screenshots[0].size[1] - screenshots[-1].size[1] > 10 23 | ), "Expect the last page to be shorter than the first several" 24 | --------------------------------------------------------------------------------
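A minimal end-to-end sketch that mirrors the tests above (the file path and the question are placeholders; the
checkpoints are the ones the tests pin):

    from docquery import pipeline
    from docquery.document import load_document
    from docquery.ocr_reader import get_ocr_reader

    # Pick an OCR backend explicitly; get_ocr_reader() with no argument falls back through
    # tesseract -> easyocr -> dummy depending on what is installed.
    reader = get_ocr_reader("tesseract")

    doc = load_document("statement.pdf", reader)  # "statement.pdf" is a placeholder path

    qa = pipeline("document-question-answering", model="impira/layoutlm-document-qa")
    print(qa(question="What are net sales for 2020?", **doc.context, top_k=1))

    clf = pipeline("document-classification", model="naver-clova-ix/donut-base-finetuned-rvlcdip")
    print(clf(top_k=1, **doc.context))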