├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yml │ ├── config.yml │ └── feature_request.yml ├── dependabot.yml ├── release.yml ├── verify_pr_labels.py └── workflows │ ├── builds.yml │ ├── clear_caches.yml │ ├── demo.yml │ ├── doc-status.yml │ ├── docker.yml │ ├── docs.yml │ ├── main.yml │ ├── pr-labels.yml │ ├── public_docker_images.yml │ ├── publish.yml │ ├── pull_requests.yml │ ├── references.yml │ ├── scripts.yml │ └── style.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── api ├── .gitignore ├── Dockerfile ├── Makefile ├── README.md ├── app │ ├── config.py │ ├── main.py │ ├── routes │ │ ├── detection.py │ │ ├── kie.py │ │ ├── ocr.py │ │ └── recognition.py │ ├── schemas.py │ ├── utils.py │ └── vision.py ├── docker-compose.yml ├── pyproject.toml └── tests │ ├── conftest.py │ ├── routes │ ├── test_detection.py │ ├── test_kie.py │ ├── test_ocr.py │ └── test_recognition.py │ └── utils │ ├── test_utils.py │ └── test_vision.py ├── demo ├── README.md ├── app.py ├── backend │ └── pytorch.py ├── packages.txt └── pt-requirements.txt ├── docs ├── Makefile ├── README.md ├── build.sh ├── images │ ├── Logo_doctr.gif │ ├── demo_illustration_mini.png │ ├── demo_update.png │ ├── doctr-need-help.png │ ├── doctr_demo_app.png │ ├── doctr_example_script.gif │ ├── ocr.png │ └── synthesized_sample.png └── source │ ├── _static │ ├── css │ │ └── mindee.css │ ├── images │ │ ├── Logo-docTR-white.png │ │ └── favicon.ico │ └── js │ │ └── custom.js │ ├── changelog.rst │ ├── community │ ├── resources.rst │ └── tools.rst │ ├── conf.py │ ├── contributing │ ├── code_of_conduct.md │ └── contributing.md │ ├── getting_started │ └── installing.rst │ ├── index.rst │ ├── modules │ ├── contrib.rst │ ├── datasets.rst │ ├── io.rst │ ├── models.rst │ ├── transforms.rst │ └── utils.rst │ ├── notebooks.rst │ └── using_doctr │ ├── custom_models_training.rst │ ├── running_on_aws.rst │ ├── sharing_models.rst │ ├── using_contrib_modules.rst │ ├── using_datasets.rst │ ├── using_model_export.rst │ └── using_models.rst ├── doctr ├── __init__.py ├── contrib │ ├── __init__.py │ ├── artefacts.py │ └── base.py ├── datasets │ ├── __init__.py │ ├── coco_text.py │ ├── cord.py │ ├── datasets │ │ ├── __init__.py │ │ ├── base.py │ │ └── pytorch.py │ ├── detection.py │ ├── doc_artefacts.py │ ├── funsd.py │ ├── generator │ │ ├── __init__.py │ │ ├── base.py │ │ └── pytorch.py │ ├── ic03.py │ ├── ic13.py │ ├── iiit5k.py │ ├── iiithws.py │ ├── imgur5k.py │ ├── mjsynth.py │ ├── ocr.py │ ├── orientation.py │ ├── recognition.py │ ├── sroie.py │ ├── svhn.py │ ├── svt.py │ ├── synthtext.py │ ├── utils.py │ ├── vocabs.py │ └── wildreceipt.py ├── file_utils.py ├── io │ ├── __init__.py │ ├── elements.py │ ├── html.py │ ├── image │ │ ├── __init__.py │ │ ├── base.py │ │ └── pytorch.py │ ├── pdf.py │ └── reader.py ├── models │ ├── __init__.py │ ├── _utils.py │ ├── builder.py │ ├── classification │ │ ├── __init__.py │ │ ├── magc_resnet │ │ │ ├── __init__.py │ │ │ └── pytorch.py │ │ ├── mobilenet │ │ │ ├── __init__.py │ │ │ └── pytorch.py │ │ ├── predictor │ │ │ ├── __init__.py │ │ │ └── pytorch.py │ │ ├── resnet │ │ │ ├── __init__.py │ │ │ └── pytorch.py │ │ ├── textnet │ │ │ ├── __init__.py │ │ │ └── pytorch.py │ │ ├── vgg │ │ │ ├── __init__.py │ │ │ └── pytorch.py │ │ ├── vip │ │ │ ├── __init__.py │ │ │ ├── layers │ │ │ │ ├── __init__.py │ │ │ │ └── pytorch.py │ │ │ └── pytorch.py │ │ ├── vit │ │ │ ├── __init__.py │ │ │ └── pytorch.py │ │ └── zoo.py │ ├── core.py 
│ ├── detection │ │ ├── __init__.py │ │ ├── _utils │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ └── pytorch.py │ │ ├── core.py │ │ ├── differentiable_binarization │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ └── pytorch.py │ │ ├── fast │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ └── pytorch.py │ │ ├── linknet │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ └── pytorch.py │ │ ├── predictor │ │ │ ├── __init__.py │ │ │ └── pytorch.py │ │ └── zoo.py │ ├── factory │ │ ├── __init__.py │ │ └── hub.py │ ├── kie_predictor │ │ ├── __init__.py │ │ ├── base.py │ │ └── pytorch.py │ ├── modules │ │ ├── __init__.py │ │ ├── layers │ │ │ ├── __init__.py │ │ │ └── pytorch.py │ │ ├── transformer │ │ │ ├── __init__.py │ │ │ └── pytorch.py │ │ └── vision_transformer │ │ │ ├── __init__.py │ │ │ └── pytorch.py │ ├── predictor │ │ ├── __init__.py │ │ ├── base.py │ │ └── pytorch.py │ ├── preprocessor │ │ ├── __init__.py │ │ └── pytorch.py │ ├── recognition │ │ ├── __init__.py │ │ ├── core.py │ │ ├── crnn │ │ │ ├── __init__.py │ │ │ └── pytorch.py │ │ ├── master │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ └── pytorch.py │ │ ├── parseq │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ └── pytorch.py │ │ ├── predictor │ │ │ ├── __init__.py │ │ │ ├── _utils.py │ │ │ └── pytorch.py │ │ ├── sar │ │ │ ├── __init__.py │ │ │ └── pytorch.py │ │ ├── utils.py │ │ ├── viptr │ │ │ ├── __init__.py │ │ │ └── pytorch.py │ │ ├── vitstr │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ └── pytorch.py │ │ └── zoo.py │ ├── utils │ │ ├── __init__.py │ │ └── pytorch.py │ └── zoo.py ├── py.typed ├── transforms │ ├── __init__.py │ ├── functional │ │ ├── __init__.py │ │ ├── base.py │ │ └── pytorch.py │ └── modules │ │ ├── __init__.py │ │ ├── base.py │ │ └── pytorch.py └── utils │ ├── __init__.py │ ├── common_types.py │ ├── data.py │ ├── fonts.py │ ├── geometry.py │ ├── metrics.py │ ├── multithreading.py │ ├── reconstitution.py │ ├── repr.py │ └── visualization.py ├── notebooks └── README.rst ├── pyproject.toml ├── references ├── classification │ ├── README.md │ ├── latency.py │ ├── train_character.py │ ├── train_orientation.py │ └── utils.py ├── detection │ ├── README.md │ ├── evaluate.py │ ├── latency.py │ ├── train.py │ └── utils.py ├── recognition │ ├── README.md │ ├── evaluate.py │ ├── latency.py │ ├── train.py │ └── utils.py └── requirements.txt ├── scripts ├── analyze.py ├── collect_env.py ├── detect_text.py ├── evaluate.py └── evaluate_kie.py ├── setup.py └── tests ├── common ├── test_contrib.py ├── test_core.py ├── test_datasets.py ├── test_datasets_utils.py ├── test_datasets_vocabs.py ├── test_headers.py ├── test_io.py ├── test_io_elements.py ├── test_models.py ├── test_models_builder.py ├── test_models_detection.py ├── test_models_detection_utils.py ├── test_models_recognition_predictor.py ├── test_models_recognition_utils.py ├── test_transforms.py ├── test_utils_data.py ├── test_utils_fonts.py ├── test_utils_geometry.py ├── test_utils_metrics.py ├── test_utils_multithreading.py ├── test_utils_reconstitution.py └── test_utils_visualization.py ├── conftest.py └── pytorch ├── test_datasets_pt.py ├── test_io_image_pt.py ├── test_models_classification_pt.py ├── test_models_detection_pt.py ├── test_models_factory.py ├── test_models_preprocessor_pt.py ├── test_models_recognition_pt.py ├── test_models_utils_pt.py ├── test_models_zoo_pt.py └── test_transforms_pt.py /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: 🐛 Bug report 2 | description: Create a report to 
help us improve the library 3 | labels: 'type: bug' 4 | 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: > 9 | #### Before reporting a bug, please check that the issue hasn't already been addressed in [the existing and past issues](https://github.com/mindee/doctr/issues?q=is%3Aissue). 10 | - type: textarea 11 | attributes: 12 | label: Bug description 13 | description: | 14 | A clear and concise description of what the bug is. 15 | 16 | Please explain the result you observed and the behavior you were expecting. 17 | placeholder: | 18 | A clear and concise description of what the bug is. 19 | validations: 20 | required: true 21 | 22 | - type: textarea 23 | attributes: 24 | label: Code snippet to reproduce the bug 25 | description: | 26 | Sample code to reproduce the problem. 27 | 28 | Please wrap your code snippet with ```` ```triple quotes blocks``` ```` for readability. 29 | placeholder: | 30 | ```python 31 | Sample code to reproduce the problem 32 | ``` 33 | validations: 34 | required: true 35 | - type: textarea 36 | attributes: 37 | label: Error traceback 38 | description: | 39 | The error message you received running the code snippet, with the full traceback. 40 | 41 | Please wrap your error message with ```` ```triple quotes blocks``` ```` for readability. 42 | placeholder: | 43 | ``` 44 | The error message you got, with the full traceback. 45 | ``` 46 | validations: 47 | required: true 48 | - type: textarea 49 | attributes: 50 | label: Environment 51 | description: | 52 | Please run the following command and paste the output below. 53 | ```sh 54 | wget https://raw.githubusercontent.com/mindee/doctr/main/scripts/collect_env.py 55 | # For security purposes, please check the contents of collect_env.py before running it. 56 | python collect_env.py 57 | ``` 58 | validations: 59 | required: true 60 | - type: markdown 61 | attributes: 62 | value: > 63 | Thanks for helping us improve the library! 64 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | contact_links: 3 | - name: Usage questions 4 | url: https://github.com/mindee/doctr/discussions 5 | about: Ask questions and discuss with other docTR community members 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: 🚀 Feature request 2 | description: Submit a proposal/request for a new feature for docTR 3 | labels: 'type: enhancement' 4 | 5 | body: 6 | - type: textarea 7 | attributes: 8 | label: 🚀 The feature 9 | description: > 10 | A clear and concise description of the feature proposal 11 | validations: 12 | required: true 13 | - type: textarea 14 | attributes: 15 | label: Motivation, pitch 16 | description: > 17 | Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too. 18 | validations: 19 | required: true 20 | - type: textarea 21 | attributes: 22 | label: Alternatives 23 | description: > 24 | A description of any alternative solutions or features you've considered, if any. 
25 | - type: textarea 26 | attributes: 27 | label: Additional context 28 | description: > 29 | Add any other context or screenshots about the feature request. 30 | - type: markdown 31 | attributes: 32 | value: > 33 | Thanks for contributing 🎉 34 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" 4 | directory: "/" 5 | open-pull-requests-limit: 10 6 | target-branch: "main" 7 | labels: ["topic: build"] 8 | schedule: 9 | interval: weekly 10 | day: sunday 11 | reviewers: 12 | - "charlesmindee" 13 | - "felixdittrich92" 14 | - "odulcy-mindee" 15 | - package-ecosystem: "github-actions" 16 | directory: "/" 17 | open-pull-requests-limit: 10 18 | target-branch: "main" 19 | labels: ["topic: ci"] 20 | schedule: 21 | interval: weekly 22 | day: sunday 23 | reviewers: 24 | - "charlesmindee" 25 | - "felixdittrich92" 26 | - "odulcy-mindee" 27 | groups: 28 | github-actions: 29 | patterns: 30 | - "*" 31 | -------------------------------------------------------------------------------- /.github/release.yml: -------------------------------------------------------------------------------- 1 | changelog: 2 | exclude: 3 | labels: 4 | - ignore-for-release 5 | categories: 6 | - title: Breaking Changes 🛠 7 | labels: 8 | - "type: breaking change" 9 | # NEW FEATURES 10 | - title: New Features 11 | labels: 12 | - "type: new feature" 13 | # BUG FIXES 14 | - title: Bug Fixes 15 | labels: 16 | - "type: bug" 17 | # IMPROVEMENTS 18 | - title: Improvements 19 | labels: 20 | - "type: enhancement" 21 | # MISC 22 | - title: Miscellaneous 23 | labels: 24 | - "type: misc" 25 | -------------------------------------------------------------------------------- /.github/verify_pr_labels.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | """Borrowed & adapted from https://github.com/pytorch/vision/blob/main/.github/process_commit.py 7 | This script finds the merger responsible for labeling a PR by a commit SHA. It is used by the workflow in 8 | '.github/workflows/pr-labels.yml'. If there exists no PR associated with the commit or the PR is properly labeled, 9 | this script is a no-op. 10 | Note: we ping the merger only, not the reviewers, as the reviewers can sometimes be external to torchvision 11 | with no labeling responsibility, so we don't want to bother them. 
12 | """ 13 | 14 | from typing import Any 15 | 16 | import requests 17 | 18 | # For a PR to be properly labeled it should have one primary label and one secondary label 19 | 20 | # Should specify the type of change 21 | PRIMARY_LABELS = { 22 | "type: new feature", 23 | "type: bug", 24 | "type: enhancement", 25 | "type: misc", 26 | } 27 | 28 | # Should specify what has been modified 29 | SECONDARY_LABELS = { 30 | "topic: documentation", 31 | "module: datasets", 32 | "module: io", 33 | "module: models", 34 | "module: transforms", 35 | "module: utils", 36 | "ext: api", 37 | "ext: demo", 38 | "ext: docs", 39 | "ext: notebooks", 40 | "ext: references", 41 | "ext: scripts", 42 | "ext: tests", 43 | "topic: build", 44 | "topic: ci", 45 | "topic: docker", 46 | } 47 | 48 | GH_ORG = "mindee" 49 | GH_REPO = "doctr" 50 | 51 | 52 | def query_repo(cmd: str, *, accept) -> Any: 53 | response = requests.get(f"https://api.github.com/repos/{GH_ORG}/{GH_REPO}/{cmd}", headers=dict(Accept=accept)) 54 | return response.json() 55 | 56 | 57 | def get_pr_merger_and_labels(pr_number: int) -> tuple[str, set[str]]: 58 | # See https://docs.github.com/en/rest/reference/pulls#get-a-pull-request 59 | data = query_repo(f"pulls/{pr_number}", accept="application/vnd.github.v3+json") 60 | merger = data.get("merged_by", {}).get("login") 61 | labels = {label["name"] for label in data["labels"]} 62 | return merger, labels 63 | 64 | 65 | def main(args): 66 | merger, labels = get_pr_merger_and_labels(args.pr) 67 | is_properly_labeled = bool(PRIMARY_LABELS.intersection(labels) and SECONDARY_LABELS.intersection(labels)) 68 | if isinstance(merger, str) and not is_properly_labeled: 69 | print(f"@{merger}") 70 | 71 | 72 | def parse_args(): 73 | import argparse 74 | 75 | parser = argparse.ArgumentParser( 76 | description="PR label checker", formatter_class=argparse.ArgumentDefaultsHelpFormatter 77 | ) 78 | 79 | parser.add_argument("pr", type=int, help="PR number") 80 | args = parser.parse_args() 81 | 82 | return args 83 | 84 | 85 | if __name__ == "__main__": 86 | args = parse_args() 87 | main(args) 88 | -------------------------------------------------------------------------------- /.github/workflows/builds.yml: -------------------------------------------------------------------------------- 1 | name: builds 2 | 3 | on: 4 | push: 5 | branches: main 6 | pull_request: 7 | branches: main 8 | 9 | jobs: 10 | build: 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | os: [ubuntu-latest, macos-latest, windows-latest] 16 | python: ["3.10", "3.11"] 17 | steps: 18 | - uses: actions/checkout@v4 19 | - if: matrix.os == 'macos-latest' 20 | name: Install MacOS prerequisites 21 | run: brew install cairo pango gdk-pixbuf libffi 22 | - name: Set up Python 23 | uses: actions/setup-python@v5 24 | with: 25 | # MacOS issue ref.: https://github.com/actions/setup-python/issues/855 & https://github.com/actions/setup-python/issues/865 26 | python-version: ${{ matrix.os == 'macos-latest' && matrix.python == '3.10' && '3.11' || matrix.python }} 27 | architecture: x64 28 | - name: Cache python modules 29 | uses: actions/cache@v4 30 | with: 31 | path: ~/.cache/pip 32 | key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }} 33 | - name: Install package 34 | run: | 35 | python -m pip install --upgrade pip 36 | if [ "${{ runner.os }}" = "Windows" ]; then 37 | pip install -e .[viz] --upgrade 38 | else 39 | pip install -e .[viz,html] --upgrade 40 | fi 41 | shell: bash # Ensures shell is consistent across 
OSes 42 | - name: Import package 43 | run: python -c "import doctr; print(doctr.__version__)" 44 | -------------------------------------------------------------------------------- /.github/workflows/clear_caches.yml: -------------------------------------------------------------------------------- 1 | name: Clear GitHub runner caches 2 | 3 | on: 4 | workflow_dispatch: 5 | schedule: 6 | - cron: '0 0 * * *' # Runs once a day 7 | 8 | jobs: 9 | clear: 10 | name: Clear caches 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: MyAlbum/purge-cache@v2 14 | with: 15 | max-age: 172800 # Caches older than 2 days are deleted 16 | -------------------------------------------------------------------------------- /.github/workflows/demo.yml: -------------------------------------------------------------------------------- 1 | name: demo 2 | 3 | on: 4 | # Run 'test-demo' on every pull request to the main branch 5 | pull_request: 6 | branches: [main] 7 | 8 | # Run 'test-demo' on every push to the main branch or both jobs when a new version tag is pushed 9 | push: 10 | branches: 11 | - main 12 | tags: 13 | - 'v*' 14 | 15 | # Run 'sync-to-hub' on a scheduled cron job 16 | schedule: 17 | - cron: '0 2 10 * *' # At 02:00 on day-of-month 10 (every month) 18 | 19 | # Allow manual triggering of the workflow 20 | workflow_dispatch: 21 | 22 | jobs: 23 | test-demo: 24 | runs-on: ${{ matrix.os }} 25 | strategy: 26 | fail-fast: false 27 | matrix: 28 | os: [ubuntu-latest] 29 | python: ["3.10"] 30 | steps: 31 | - if: matrix.os == 'macos-latest' 32 | name: Install MacOS prerequisites 33 | run: brew install cairo pango gdk-pixbuf libffi 34 | - uses: actions/checkout@v4 35 | - name: Set up Python 36 | uses: actions/setup-python@v5 37 | with: 38 | python-version: ${{ matrix.python }} 39 | architecture: x64 40 | - name: Cache python modules 41 | uses: actions/cache@v4 42 | with: 43 | path: ~/.cache/pip 44 | key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('demo/pt-requirements.txt') }} 45 | - name: Install dependencies 46 | run: | 47 | python -m pip install --upgrade pip 48 | pip install -e .[viz,html] --upgrade 49 | pip install -r demo/pt-requirements.txt 50 | - name: Run demo 51 | run: | 52 | streamlit --version 53 | screen -dm streamlit run demo/app.py 54 | sleep 10 55 | curl http://localhost:8501/docs 56 | 57 | # This job only runs when a new version tag is pushed or during the cron job or when manually triggered 58 | sync-to-hub: 59 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' 60 | needs: test-demo 61 | runs-on: ${{ matrix.os }} 62 | strategy: 63 | fail-fast: false 64 | matrix: 65 | os: [ubuntu-latest] 66 | python: ["3.10"] 67 | steps: 68 | - uses: actions/checkout@v4 69 | with: 70 | fetch-depth: 0 71 | - name: Set up Python 72 | uses: actions/setup-python@v5 73 | with: 74 | python-version: ${{ matrix.python }} 75 | architecture: x64 76 | - name: Install huggingface_hub 77 | run: pip install huggingface-hub 78 | - name: Upload folder to Hugging Face 79 | # Only keep the requirements.txt file for the demo (PyTorch) 80 | run: | 81 | mv demo/pt-requirements.txt demo/requirements.txt 82 | 83 | python -c " 84 | from huggingface_hub import HfApi 85 | api = HfApi(token='${{ secrets.HF_TOKEN }}') 86 | repo_id = 'mindee/doctr' 87 | api.upload_folder(repo_id=repo_id, repo_type='space', folder_path='demo/') 88 | api.restart_space(repo_id=repo_id, factory_reboot=True) 89 | 
" 90 | -------------------------------------------------------------------------------- /.github/workflows/doc-status.yml: -------------------------------------------------------------------------------- 1 | name: doc-status 2 | on: 3 | page_build 4 | 5 | jobs: 6 | see-page-build-payload: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - name: Set up Python 10 | uses: actions/setup-python@v5 11 | with: 12 | python-version: "3.10" 13 | architecture: x64 14 | - name: check status 15 | run: | 16 | import os 17 | status, errormsg = os.getenv('STATUS'), os.getenv('ERROR') 18 | if status != 'built': raise AssertionError(f"There was an error building the page on GitHub pages.\n\nStatus: {status}\n\nError messsage: {errormsg}") 19 | shell: python 20 | env: 21 | STATUS: ${{ github.event.build.status }} 22 | ERROR: ${{ github.event.build.error.message }} 23 | -------------------------------------------------------------------------------- /.github/workflows/docker.yml: -------------------------------------------------------------------------------- 1 | name: docker 2 | 3 | on: 4 | push: 5 | branches: main 6 | pull_request: 7 | branches: main 8 | 9 | jobs: 10 | docker-package: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: Build docker image 15 | run: docker build -t doctr-py3.10-slim --build-arg SYSTEM=cpu . 16 | - name: Run docker container 17 | run: docker run doctr-py3.10-slim python3 -c 'import doctr' 18 | 19 | pytest-api: 20 | runs-on: ${{ matrix.os }} 21 | strategy: 22 | matrix: 23 | os: [ubuntu-latest] 24 | python: ["3.10"] 25 | steps: 26 | - uses: actions/checkout@v4 27 | - uses: actions/setup-python@v5 28 | with: 29 | python-version: ${{ matrix.python }} 30 | architecture: x64 31 | - name: Build & run docker 32 | run: cd api && make lock && make run 33 | - name: Ping server 34 | run: wget --spider --tries=12 http://localhost:8080/docs 35 | - name: Run docker test 36 | run: cd api && make test 37 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: docs 2 | on: 3 | push: 4 | branches: main 5 | 6 | jobs: 7 | docs-deploy: 8 | runs-on: ${{ matrix.os }} 9 | strategy: 10 | matrix: 11 | os: [ubuntu-latest] 12 | python: ["3.10"] 13 | steps: 14 | - uses: actions/checkout@v4 15 | with: 16 | persist-credentials: false 17 | - name: Set up Python 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: ${{ matrix.python }} 21 | architecture: x64 22 | - name: Cache python modules 23 | uses: actions/cache@v4 24 | with: 25 | path: ~/.cache/pip 26 | key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }}-docs 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | pip install -e .[viz,html] --upgrade 31 | pip install -e .[docs] 32 | 33 | - name: Build documentation 34 | run: cd docs && bash build.sh 35 | 36 | - name: Documentation sanity check 37 | run: test -e docs/build/index.html || exit 38 | 39 | - name: Install SSH Client 🔑 40 | uses: webfactory/ssh-agent@v0.9.1 41 | with: 42 | ssh-private-key: ${{ secrets.SSH_DEPLOY_KEY }} 43 | 44 | - name: Deploy to Github Pages 45 | uses: JamesIves/github-pages-deploy-action@v4.7.3 46 | with: 47 | BRANCH: gh-pages 48 | FOLDER: 'docs/build' 49 | COMMIT_MESSAGE: '[skip ci] Documentation updates' 50 | CLEAN: true 51 | SSH: true 52 | -------------------------------------------------------------------------------- 
/.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: main 6 | pull_request: 7 | branches: main 8 | 9 | jobs: 10 | pytest-common: 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | matrix: 14 | os: [ubuntu-latest] 15 | python: ["3.10"] 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: ${{ matrix.python }} 22 | architecture: x64 23 | - name: Cache python modules 24 | uses: actions/cache@v4 25 | with: 26 | path: ~/.cache/pip 27 | key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }}-tests 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install -e .[viz,html] --upgrade 32 | pip install -e .[testing] 33 | - name: Run unittests 34 | run: | 35 | coverage run -m pytest tests/common/ -rs 36 | coverage xml -o coverage-common.xml 37 | - uses: actions/upload-artifact@v4 38 | with: 39 | name: coverage-common 40 | path: ./coverage-common.xml 41 | if-no-files-found: error 42 | 43 | 44 | pytest-torch: 45 | runs-on: ${{ matrix.os }} 46 | strategy: 47 | matrix: 48 | os: [ubuntu-latest] 49 | python: ["3.10"] 50 | steps: 51 | - uses: actions/checkout@v4 52 | - name: Set up Python 53 | uses: actions/setup-python@v5 54 | with: 55 | python-version: ${{ matrix.python }} 56 | architecture: x64 57 | - name: Cache python modules 58 | uses: actions/cache@v4 59 | with: 60 | path: ~/.cache/pip 61 | key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }}-tests 62 | - name: Install dependencies 63 | run: | 64 | python -m pip install --upgrade pip 65 | pip install -e .[viz,html] --upgrade 66 | pip install -e .[testing] 67 | 68 | - name: Run unittests 69 | run: | 70 | coverage run -m pytest tests/pytorch/ -rs 71 | coverage xml -o coverage-pt.xml 72 | 73 | - uses: actions/upload-artifact@v4 74 | with: 75 | name: coverage-pytorch 76 | path: ./coverage-pt.xml 77 | if-no-files-found: error 78 | 79 | codecov-upload: 80 | runs-on: ubuntu-latest 81 | needs: [ pytest-common, pytest-torch ] 82 | steps: 83 | - uses: actions/checkout@v4 84 | - uses: actions/download-artifact@v4 85 | - name: Upload coverage to Codecov 86 | uses: codecov/codecov-action@v5 87 | with: 88 | flags: unittests 89 | fail_ci_if_error: true 90 | token: ${{ secrets.CODECOV_TOKEN }} 91 | -------------------------------------------------------------------------------- /.github/workflows/pr-labels.yml: -------------------------------------------------------------------------------- 1 | name: pr-labels 2 | 3 | on: 4 | pull_request: 5 | branches: main 6 | types: closed 7 | 8 | jobs: 9 | is-properly-labeled: 10 | if: github.event.pull_request.merged == true 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout repository 14 | uses: actions/checkout@v4 15 | - name: Set up python 16 | uses: actions/setup-python@v5 17 | - name: Install requests 18 | run: pip install requests 19 | - name: Process commit and find merger responsible for labeling 20 | id: commit 21 | run: echo "::set-output name=merger::$(python .github/verify_pr_labels.py ${{ github.event.pull_request.number }})" 22 | - name: 'Comment PR' 23 | uses: actions/github-script@v7.0.1 24 | if: ${{ steps.commit.outputs.merger != '' }} 25 | with: 26 | github-token: ${{ secrets.GITHUB_TOKEN }} 27 | script: | 28 | const { issue: { number: issue_number }, repo: { owner, repo } } = context; 29 | 
github.rest.issues.createComment({ issue_number, owner, repo, body: 'Hey ${{ steps.commit.outputs.merger }} 👋\nYou merged this PR, but it is not correctly labeled. The list of valid labels is available at https://github.com/mindee/doctr/blob/main/.github/verify_pr_labels.py' }); 30 | -------------------------------------------------------------------------------- /.github/workflows/public_docker_images.yml: -------------------------------------------------------------------------------- 1 | # https://docs.github.com/en/actions/publishing-packages/publishing-docker-images#publishing-images-to-github-packages 2 | # 3 | name: Docker image on ghcr.io 4 | 5 | on: 6 | push: 7 | tags: 8 | - 'v*' 9 | pull_request: 10 | branches: main 11 | schedule: 12 | - cron: '0 2 1 */3 *' # At 02:00 on the 1st day of every 3rd month 13 | 14 | env: 15 | REGISTRY: ghcr.io 16 | 17 | jobs: 18 | build-and-push-image: 19 | runs-on: ubuntu-latest 20 | 21 | strategy: 22 | fail-fast: false 23 | matrix: 24 | # Must match version at https://www.python.org/ftp/python/ 25 | python: ["3.10.13", "3.11.8", "3.12.7"] 26 | # NOTE: Since docTR 1.0.0 torch doesn't exist as a seperate install option it's only to keep the naming convention 27 | framework: ["torch", "torch,viz,html,contrib"] 28 | 29 | # Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job. 30 | permissions: 31 | contents: read 32 | packages: write 33 | 34 | steps: 35 | - name: Checkout repository 36 | uses: actions/checkout@v4 37 | 38 | - name: Log in to the Container registry 39 | uses: docker/login-action@v3 40 | with: 41 | registry: ${{ env.REGISTRY }} 42 | username: ${{ github.actor }} 43 | password: ${{ secrets.GITHUB_TOKEN }} 44 | 45 | - name: Sanitize docker tag 46 | run: | 47 | PREFIX_DOCKER_TAG="${{ matrix.framework }}-py${{ matrix.python }}-" 48 | PREFIX_DOCKER_TAG=$(echo ${PREFIX_DOCKER_TAG}|sed 's/,/-/g') 49 | echo PREFIX_DOCKER_TAG=${PREFIX_DOCKER_TAG} >> $GITHUB_ENV 50 | echo $PREFIX_DOCKER_TAG 51 | 52 | - name: Extract metadata (tags, labels) for Docker 53 | id: meta 54 | uses: docker/metadata-action@v5 55 | with: 56 | images: ${{ env.REGISTRY }}/${{ github.repository }} 57 | tags: | 58 | # used only on schedule event 59 | type=schedule,pattern={{date 'YYYY-MM'}},prefix=${{ env.PREFIX_DOCKER_TAG }} 60 | # used only if a tag following semver is published 61 | type=semver,pattern={{raw}},prefix=${{ env.PREFIX_DOCKER_TAG }} 62 | 63 | - name: Build Docker image 64 | id: build 65 | uses: docker/build-push-action@v6 66 | with: 67 | context: . 68 | build-args: | 69 | FRAMEWORK=${{ matrix.framework }} 70 | PYTHON_VERSION=${{ matrix.python }} 71 | DOCTR_REPO=${{ github.repository }} 72 | DOCTR_VERSION=${{ github.sha }} 73 | push: false # push only if `import doctr` works 74 | tags: ${{ steps.meta.outputs.tags }} 75 | 76 | - name: Check if `import doctr` works 77 | run: docker run ${{ steps.build.outputs.imageid }} python3 -c 'import doctr' 78 | 79 | - name: Push Docker image 80 | # Push only if the CI is not triggered by "PR on main" 81 | if: ${{ (github.ref == 'refs/heads/main' && github.event_name != 'pull_request') || (startsWith(github.ref, 'refs/tags') && github.event_name == 'push') }} 82 | uses: docker/build-push-action@v6 83 | with: 84 | context: . 
85 | build-args: | 86 | FRAMEWORK=${{ matrix.framework }} 87 | PYTHON_VERSION=${{ matrix.python }} 88 | DOCTR_REPO=${{ github.repository }} 89 | DOCTR_VERSION=${{ github.sha }} 90 | push: true 91 | tags: ${{ steps.meta.outputs.tags }} 92 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: publish 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | pypi: 9 | if: "!github.event.release.prerelease" 10 | strategy: 11 | fail-fast: false 12 | matrix: 13 | os: [ubuntu-latest] 14 | python: ["3.10"] 15 | runs-on: ${{ matrix.os }} 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: ${{ matrix.python }} 22 | architecture: x64 23 | - name: Cache python modules 24 | uses: actions/cache@v4 25 | with: 26 | path: ~/.cache/pip 27 | key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }} 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install setuptools wheel twine --upgrade 32 | - name: Get release tag 33 | id: release_tag 34 | run: echo "VERSION=${GITHUB_REF/refs\/tags\//}" >> $GITHUB_ENV 35 | - name: Build and publish 36 | env: 37 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 38 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 39 | VERSION: ${{ env.VERSION }} 40 | run: | 41 | BUILD_VERSION=$VERSION python setup.py sdist bdist_wheel 42 | twine check dist/* 43 | twine upload dist/* 44 | 45 | pypi-check: 46 | needs: pypi 47 | if: "!github.event.release.prerelease" 48 | strategy: 49 | fail-fast: false 50 | matrix: 51 | os: [ubuntu-latest] 52 | python: ["3.10"] 53 | runs-on: ${{ matrix.os }} 54 | steps: 55 | - uses: actions/checkout@v4 56 | - name: Set up Python 57 | uses: actions/setup-python@v5 58 | with: 59 | python-version: ${{ matrix.python }} 60 | architecture: x64 61 | - name: Install package 62 | run: | 63 | python -m pip install --upgrade pip 64 | pip install python-doctr 65 | python -c "from importlib.metadata import version; print(version('python-doctr'))" 66 | -------------------------------------------------------------------------------- /.github/workflows/pull_requests.yml: -------------------------------------------------------------------------------- 1 | name: pull_requests 2 | 3 | on: 4 | pull_request: 5 | branches: main 6 | 7 | jobs: 8 | docs-build: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Set up Python 13 | uses: actions/setup-python@v5 14 | with: 15 | python-version: "3.10" 16 | architecture: x64 17 | - name: Cache python modules 18 | uses: actions/cache@v4 19 | with: 20 | path: ~/.cache/pip 21 | key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }}-docs 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install -e .[viz,html] --upgrade 26 | pip install -e .[docs] 27 | 28 | - name: Build documentation 29 | run: cd docs && bash build.sh 30 | 31 | - name: Documentation sanity check 32 | run: test -e docs/build/index.html || exit 33 | -------------------------------------------------------------------------------- /.github/workflows/style.yml: -------------------------------------------------------------------------------- 1 | name: style 2 | 3 | on: 4 | push: 5 | branches: main 6 | pull_request: 7 | branches: main 8 | 9 | jobs: 10 | ruff: 11 | 
runs-on: ${{ matrix.os }} 12 | strategy: 13 | matrix: 14 | os: [ubuntu-latest] 15 | python: ["3.10"] 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: ${{ matrix.python }} 22 | architecture: x64 23 | - name: Run ruff 24 | run: | 25 | pip install ruff --upgrade 26 | ruff --version 27 | ruff check --diff . 28 | 29 | mypy: 30 | runs-on: ${{ matrix.os }} 31 | strategy: 32 | matrix: 33 | os: [ubuntu-latest] 34 | python: ["3.10"] 35 | steps: 36 | - uses: actions/checkout@v4 37 | - name: Set up Python 38 | uses: actions/setup-python@v5 39 | with: 40 | python-version: ${{ matrix.python }} 41 | architecture: x64 42 | - name: Cache python modules 43 | uses: actions/cache@v4 44 | with: 45 | path: ~/.cache/pip 46 | key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }}-style 47 | - name: Install dependencies 48 | run: | 49 | python -m pip install --upgrade pip 50 | pip install -e .[dev] --upgrade 51 | pip install mypy --upgrade 52 | - name: Run mypy 53 | run: | 54 | mypy --version 55 | mypy 56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Temp files 132 | doctr/version.py 133 | logs/ 134 | wandb/ 135 | .idea/ 136 | 137 | # Checkpoints 138 | *.pt 139 | *.pb 140 | *.index 141 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: check-ast 6 | - id: check-yaml 7 | exclude: .conda 8 | - id: check-toml 9 | - id: check-json 10 | - id: check-added-large-files 11 | exclude: docs/images/ 12 | - id: end-of-file-fixer 13 | - id: trailing-whitespace 14 | - id: debug-statements 15 | - id: check-merge-conflict 16 | - id: no-commit-to-branch 17 | args: ['--branch', 'main'] 18 | - repo: https://github.com/astral-sh/ruff-pre-commit 19 | rev: v0.12.0 20 | hooks: 21 | - id: ruff 22 | args: [ --fix ] 23 | - id: ruff-format 24 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.2.0-base-ubuntu22.04 2 | 3 | ENV DEBIAN_FRONTEND=noninteractive 4 | ENV LANG=C.UTF-8 5 | ENV PYTHONUNBUFFERED=1 6 | ENV PYTHONDONTWRITEBYTECODE=1 7 | 8 | 9 | RUN apt-get update && apt-get install -y --no-install-recommends \ 10 | # - Other packages 11 | build-essential \ 12 | pkg-config \ 13 | curl \ 14 | wget \ 15 | software-properties-common \ 16 | unzip \ 17 | git \ 18 | # - Packages to build Python 19 | tar make gcc zlib1g-dev libffi-dev libssl-dev liblzma-dev libbz2-dev libsqlite3-dev \ 20 | # - Packages for docTR 21 | libgl1-mesa-dev libsm6 libxext6 libxrender-dev libpangocairo-1.0-0 \ 22 | && apt-get clean \ 23 | && rm -rf /var/lib/apt/lists/* 24 | 25 | # Install Python 26 | ARG PYTHON_VERSION=3.10.13 27 | 28 | RUN wget http://www.python.org/ftp/python/$PYTHON_VERSION/Python-$PYTHON_VERSION.tgz && \ 29 | tar -zxf Python-$PYTHON_VERSION.tgz && \ 30 | cd Python-$PYTHON_VERSION && \ 31 | mkdir /opt/python/ && \ 32 | ./configure --prefix=/opt/python && \ 33 | make && \ 34 | make install && \ 35 | cd .. && \ 36 | rm Python-$PYTHON_VERSION.tgz && \ 37 | rm -r Python-$PYTHON_VERSION 38 | 39 | ENV PATH=/opt/python/bin:$PATH 40 | 41 | # Install docTR 42 | ARG FRAMEWORK=torch 43 | ARG DOCTR_REPO='mindee/doctr' 44 | ARG DOCTR_VERSION=main 45 | RUN pip3 install -U pip setuptools wheel && \ 46 | pip3 install "python-doctr[$FRAMEWORK]@git+https://github.com/$DOCTR_REPO.git@$DOCTR_VERSION" 47 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: quality style test test-common test-tf test-torch docs-single-version docs 2 | # this target runs checks on all files 3 | quality: 4 | ruff check . 
5 | mypy doctr/ 6 | 7 | # this target runs checks on all files and potentially modifies some of them 8 | style: 9 | ruff format . 10 | ruff check --fix . 11 | 12 | # Run tests for the library 13 | test: 14 | coverage run -m pytest tests/common/ -rs 15 | coverage run -m pytest tests/pytorch/ -rs 16 | 17 | test-common: 18 | coverage run -m pytest tests/common/ -rs 19 | 20 | test-torch: 21 | coverage run -m pytest tests/pytorch/ -rs 22 | 23 | # Check that docs can build 24 | docs-single-version: 25 | sphinx-build docs/source docs/_build -a 26 | 27 | # Check that docs can build 28 | docs: 29 | cd docs && bash build.sh 30 | -------------------------------------------------------------------------------- /api/.gitignore: -------------------------------------------------------------------------------- 1 | poetry.lock 2 | requirements* 3 | -------------------------------------------------------------------------------- /api/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tiangolo/uvicorn-gunicorn-fastapi:python3.10-slim 2 | 3 | WORKDIR /app 4 | 5 | # set environment variables 6 | ENV PYTHONDONTWRITEBYTECODE 1 7 | ENV PYTHONUNBUFFERED 1 8 | ENV PYTHONPATH "${PYTHONPATH}:/app" 9 | 10 | RUN apt-get update \ 11 | && apt-get install --no-install-recommends git ffmpeg libsm6 libxext6 make -y \ 12 | && apt-get autoremove -y \ 13 | && rm -rf /var/lib/apt/lists/* 14 | 15 | COPY pyproject.toml /app/pyproject.toml 16 | COPY Makefile /app/Makefile 17 | 18 | RUN pip install --upgrade pip setuptools wheel \ 19 | && make lock \ 20 | && pip install -r /app/requirements.txt \ 21 | && pip cache purge \ 22 | && rm -rf /root/.cache/pip 23 | 24 | # copy project 25 | COPY app /app/app 26 | -------------------------------------------------------------------------------- /api/Makefile: -------------------------------------------------------------------------------- 1 | # api setup is borrowed from https://github.com/frgfm/Holocron/blob/main/api 2 | 3 | .PHONY: lock run stop test 4 | # Pin the dependencies 5 | lock: 6 | pip install poetry>=1.0 poetry-plugin-export 7 | poetry lock 8 | poetry export -f requirements.txt --without-hashes --output requirements.txt 9 | poetry export -f requirements.txt --without-hashes --with dev --output requirements-dev.txt 10 | 11 | # Run the docker 12 | run: 13 | docker compose up -d --build 14 | 15 | # Run the docker 16 | stop: 17 | docker compose down 18 | 19 | # Run tests for the library 20 | test: 21 | docker compose up -d --build 22 | docker cp requirements-dev.txt api_web:/app/requirements-dev.txt 23 | docker compose exec -T web pip install -r requirements-dev.txt 24 | docker cp tests api_web:/app/tests 25 | docker compose exec -T web pytest tests/ -vv 26 | docker compose down 27 | -------------------------------------------------------------------------------- /api/app/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | import os 7 | 8 | import doctr 9 | 10 | PROJECT_NAME: str = "docTR API template" 11 | PROJECT_DESCRIPTION: str = "Template API for Optical Character Recognition" 12 | VERSION: str = doctr.__version__ 13 | DEBUG: bool = os.environ.get("DEBUG", "") != "False" 14 | -------------------------------------------------------------------------------- /api/app/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import time 7 | 8 | from fastapi import FastAPI, Request 9 | from fastapi.openapi.utils import get_openapi 10 | 11 | from app import config as cfg 12 | from app.routes import detection, kie, ocr, recognition 13 | 14 | app = FastAPI(title=cfg.PROJECT_NAME, description=cfg.PROJECT_DESCRIPTION, debug=cfg.DEBUG, version=cfg.VERSION) 15 | 16 | 17 | # Routing 18 | app.include_router(recognition.router, prefix="/recognition", tags=["recognition"]) 19 | app.include_router(detection.router, prefix="/detection", tags=["detection"]) 20 | app.include_router(ocr.router, prefix="/ocr", tags=["ocr"]) 21 | app.include_router(kie.router, prefix="/kie", tags=["kie"]) 22 | 23 | 24 | # Middleware 25 | @app.middleware("http") 26 | async def add_process_time_header(request: Request, call_next): 27 | start_time = time.time() 28 | response = await call_next(request) 29 | process_time = time.time() - start_time 30 | response.headers["X-Process-Time"] = str(process_time) 31 | return response 32 | 33 | 34 | # Docs 35 | def custom_openapi(): 36 | if app.openapi_schema: 37 | return app.openapi_schema 38 | openapi_schema = get_openapi( 39 | title=cfg.PROJECT_NAME, 40 | version=cfg.VERSION, 41 | description=cfg.PROJECT_DESCRIPTION, 42 | routes=app.routes, 43 | ) 44 | app.openapi_schema = openapi_schema 45 | return app.openapi_schema 46 | 47 | 48 | app.openapi = custom_openapi 49 | -------------------------------------------------------------------------------- /api/app/routes/detection.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | 7 | from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status 8 | 9 | from app.schemas import DetectionIn, DetectionOut 10 | from app.utils import get_documents, resolve_geometry 11 | from app.vision import init_predictor 12 | from doctr.file_utils import CLASS_NAME 13 | 14 | router = APIRouter() 15 | 16 | 17 | @router.post("/", response_model=list[DetectionOut], status_code=status.HTTP_200_OK, summary="Perform text detection") 18 | async def text_detection(request: DetectionIn = Depends(), files: list[UploadFile] = [File(...)]): 19 | """Runs docTR text detection model to analyze the input image""" 20 | try: 21 | predictor = init_predictor(request) 22 | content, filenames = await get_documents(files) 23 | except ValueError as e: 24 | raise HTTPException(status_code=400, detail=str(e)) 25 | 26 | return [ 27 | DetectionOut( 28 | name=filename, 29 | geometries=[ 30 | geom[:-1].tolist() if geom.shape == (5,) else resolve_geometry(geom[:4].tolist()) 31 | for geom in doc[CLASS_NAME] 32 | ], 33 | ) 34 | for doc, filename in zip(predictor(content), filenames) 35 | ] 36 | -------------------------------------------------------------------------------- /api/app/routes/kie.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | 7 | from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status 8 | 9 | from app.schemas import KIEElement, KIEIn, KIEOut 10 | from app.utils import get_documents, resolve_geometry 11 | from app.vision import init_predictor 12 | 13 | router = APIRouter() 14 | 15 | 16 | @router.post("/", response_model=list[KIEOut], status_code=status.HTTP_200_OK, summary="Perform KIE") 17 | async def perform_kie(request: KIEIn = Depends(), files: list[UploadFile] = [File(...)]): 18 | """Runs docTR KIE model to analyze the input image""" 19 | try: 20 | predictor = init_predictor(request) 21 | content, filenames = await get_documents(files) 22 | except ValueError as e: 23 | raise HTTPException(status_code=400, detail=str(e)) 24 | 25 | out = predictor(content) 26 | 27 | results = [ 28 | KIEOut( 29 | name=filenames[i], 30 | orientation=page.orientation, 31 | language=page.language, 32 | dimensions=page.dimensions, 33 | predictions=[ 34 | KIEElement( 35 | class_name=class_name, 36 | items=[ 37 | dict( 38 | value=prediction.value, 39 | geometry=resolve_geometry(prediction.geometry), 40 | objectness_score=round(prediction.objectness_score, 2), 41 | confidence=round(prediction.confidence, 2), 42 | crop_orientation=prediction.crop_orientation, 43 | ) 44 | for prediction in page.predictions[class_name] 45 | ], 46 | ) 47 | for class_name in page.predictions.keys() 48 | ], 49 | ) 50 | for i, page in enumerate(out.pages) 51 | ] 52 | 53 | return results 54 | -------------------------------------------------------------------------------- /api/app/routes/ocr.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | 7 | from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status 8 | 9 | from app.schemas import OCRBlock, OCRIn, OCRLine, OCROut, OCRPage, OCRWord 10 | from app.utils import get_documents, resolve_geometry 11 | from app.vision import init_predictor 12 | 13 | router = APIRouter() 14 | 15 | 16 | @router.post("/", response_model=list[OCROut], status_code=status.HTTP_200_OK, summary="Perform OCR") 17 | async def perform_ocr(request: OCRIn = Depends(), files: list[UploadFile] = [File(...)]): 18 | """Runs docTR OCR model to analyze the input image""" 19 | try: 20 | # generator object to list 21 | content, filenames = await get_documents(files) 22 | predictor = init_predictor(request) 23 | except ValueError as e: 24 | raise HTTPException(status_code=400, detail=str(e)) 25 | 26 | out = predictor(content) 27 | 28 | results = [ 29 | OCROut( 30 | name=filenames[i], 31 | orientation=page.orientation, 32 | language=page.language, 33 | dimensions=page.dimensions, 34 | items=[ 35 | OCRPage( 36 | blocks=[ 37 | OCRBlock( 38 | geometry=resolve_geometry(block.geometry), 39 | objectness_score=round(block.objectness_score, 2), 40 | lines=[ 41 | OCRLine( 42 | geometry=resolve_geometry(line.geometry), 43 | objectness_score=round(line.objectness_score, 2), 44 | words=[ 45 | OCRWord( 46 | value=word.value, 47 | geometry=resolve_geometry(word.geometry), 48 | objectness_score=round(word.objectness_score, 2), 49 | confidence=round(word.confidence, 2), 50 | crop_orientation=word.crop_orientation, 51 | ) 52 | for word in line.words 53 | ], 54 | ) 55 | for line in block.lines 56 | ], 57 | ) 58 | for block in page.blocks 59 | ] 60 | ) 61 | ], 62 | ) 63 | for i, page in enumerate(out.pages) 64 | ] 65 | 66 | return results 67 | -------------------------------------------------------------------------------- /api/app/routes/recognition.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | 7 | from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status 8 | 9 | from app.schemas import RecognitionIn, RecognitionOut 10 | from app.utils import get_documents 11 | from app.vision import init_predictor 12 | 13 | router = APIRouter() 14 | 15 | 16 | @router.post( 17 | "/", response_model=list[RecognitionOut], status_code=status.HTTP_200_OK, summary="Perform text recognition" 18 | ) 19 | async def text_recognition(request: RecognitionIn = Depends(), files: list[UploadFile] = [File(...)]): 20 | """Runs docTR text recognition model to analyze the input image""" 21 | try: 22 | predictor = init_predictor(request) 23 | content, filenames = await get_documents(files) 24 | except ValueError as e: 25 | raise HTTPException(status_code=400, detail=str(e)) 26 | return [ 27 | RecognitionOut(name=filename, value=res[0], confidence=round(res[1], 2)) 28 | for res, filename in zip(predictor(content), filenames) 29 | ] 30 | -------------------------------------------------------------------------------- /api/app/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | 7 | from typing import Any 8 | 9 | import numpy as np 10 | from fastapi import UploadFile 11 | 12 | from doctr.io import DocumentFile 13 | 14 | 15 | def resolve_geometry( 16 | geom: Any, 17 | ) -> tuple[float, float, float, float] | tuple[float, float, float, float, float, float, float, float]: 18 | if len(geom) == 4: 19 | return (*geom[0], *geom[1], *geom[2], *geom[3]) 20 | return (*geom[0], *geom[1]) 21 | 22 | 23 | async def get_documents(files: list[UploadFile]) -> tuple[list[np.ndarray], list[str]]: # pragma: no cover 24 | """Convert a list of UploadFile objects to lists of numpy arrays and their corresponding filenames 25 | 26 | Args: 27 | files: list of UploadFile objects 28 | 29 | Returns: 30 | tuple[list[np.ndarray], list[str]]: list of numpy arrays and their corresponding filenames 31 | 32 | """ 33 | filenames = [] 34 | docs = [] 35 | for file in files: 36 | mime_type = file.content_type 37 | if mime_type in ["image/jpeg", "image/png"]: 38 | docs.extend(DocumentFile.from_images([await file.read()])) 39 | filenames.append(file.filename or "") 40 | elif mime_type == "application/pdf": 41 | pdf_content = DocumentFile.from_pdf(await file.read()) 42 | docs.extend(pdf_content) 43 | filenames.extend([file.filename] * len(pdf_content) or [""] * len(pdf_content)) 44 | else: 45 | raise ValueError(f"Unsupported file format: {mime_type} for file {file.filename}") 46 | 47 | return docs, filenames 48 | -------------------------------------------------------------------------------- /api/app/vision.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | 7 | from collections.abc import Callable 8 | 9 | import torch 10 | 11 | from doctr.models import kie_predictor, ocr_predictor 12 | 13 | from .schemas import DetectionIn, KIEIn, OCRIn, RecognitionIn 14 | 15 | 16 | def _move_to_device(predictor: Callable) -> Callable: 17 | """Move the predictor to the desired device 18 | 19 | Args: 20 | predictor: the predictor to move 21 | 22 | Returns: 23 | Callable: the predictor moved to the desired device 24 | """ 25 | return predictor.to(torch.device("cuda" if torch.cuda.is_available() else "cpu")) 26 | 27 | 28 | def init_predictor(request: KIEIn | OCRIn | RecognitionIn | DetectionIn) -> Callable: 29 | """Initialize the predictor based on the request 30 | 31 | Args: 32 | request: input request 33 | 34 | Returns: 35 | Callable: the predictor 36 | """ 37 | params = request.model_dump() 38 | bin_thresh = params.pop("bin_thresh", None) 39 | box_thresh = params.pop("box_thresh", None) 40 | if isinstance(request, (OCRIn, RecognitionIn, DetectionIn)): 41 | predictor = ocr_predictor(pretrained=True, **params) 42 | predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh 43 | predictor.det_predictor.model.postprocessor.box_thresh = box_thresh 44 | if isinstance(request, DetectionIn): 45 | return _move_to_device(predictor.det_predictor) 46 | elif isinstance(request, RecognitionIn): 47 | return _move_to_device(predictor.reco_predictor) 48 | return _move_to_device(predictor) 49 | elif isinstance(request, KIEIn): 50 | predictor = kie_predictor(pretrained=True, **params) 51 | predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh 52 | predictor.det_predictor.model.postprocessor.box_thresh = box_thresh 53 | return _move_to_device(predictor) 54 | 
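Taken together, `app/main.py`, the route modules and `app/vision.py` above define a small FastAPI service. The following client sketch is not part of the repository; it assumes the stack has been started with `make run` from `api/` and is listening on port 8080 as configured in `docker-compose.yml`, and the image path is hypothetical.

```python
# Minimal client sketch for the OCR route defined above (hypothetical usage, not
# shipped with the repository). Assumes the API is reachable on localhost:8080
# and that sample.jpg is a local image file.
import requests

params = {"det_arch": "db_resnet50", "reco_arch": "crnn_vgg16_bn"}
with open("sample.jpg", "rb") as f:
    files = [("files", ("sample.jpg", f, "image/jpeg"))]
    response = requests.post("http://localhost:8080/ocr", params=params, files=files)

response.raise_for_status()
for page in response.json():  # one OCROut entry per uploaded page
    print(page["name"], page["dimensions"])
```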
-------------------------------------------------------------------------------- /api/docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | web: 3 | container_name: api_web 4 | build: 5 | context: . 6 | dockerfile: Dockerfile 7 | command: uvicorn app.main:app --reload --workers 1 --host 0.0.0.0 --port 8080 8 | ports: 9 | - 8080:8080 10 | -------------------------------------------------------------------------------- /api/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["poetry>=1.0"] 3 | build-backend = "poetry.masonry.api" 4 | 5 | [tool.poetry] 6 | name = "doctr-api" 7 | version = "1.0.1a0" 8 | description = "Backend template for your OCR API with docTR" 9 | authors = ["Mindee "] 10 | license = "Apache-2.0" 11 | 12 | [tool.poetry.dependencies] 13 | python = ">=3.10,<3.13" 14 | python-doctr = {git = "https://github.com/mindee/doctr.git", branch = "main" } 15 | # Fastapi: minimum version required to avoid pydantic error 16 | # cf. https://github.com/tiangolo/fastapi/issues/4168 17 | fastapi = ">=0.73.0" 18 | uvicorn = ">=0.11.1" 19 | python-multipart = ">=0.0.5" 20 | 21 | [tool.poetry.dev-dependencies] 22 | pytest = ">=5.3.2" 23 | pytest-asyncio = ">=0.14.0" 24 | httpx = ">=0.23.0" 25 | requests = ">=2.20.0" 26 | -------------------------------------------------------------------------------- /api/tests/routes/test_detection.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | 5 | def common_test(json_response, expected_response): 6 | assert isinstance(json_response, list) and len(json_response) == 2 7 | first_pred = json_response[0] # it's enough to test for the first file because the same image is used twice 8 | 9 | assert isinstance(first_pred["name"], str) 10 | np.testing.assert_allclose(first_pred["geometries"], expected_response["geometries"], rtol=1e-2) 11 | 12 | 13 | @pytest.mark.asyncio 14 | async def test_text_detection_box(test_app_asyncio, mock_detection_image, mock_detection_response): 15 | headers = { 16 | "accept": "application/json", 17 | } 18 | params = {"det_arch": "db_resnet50"} 19 | files = [ 20 | ("files", ("test.jpg", mock_detection_image, "image/jpeg")), 21 | ("files", ("test2.jpg", mock_detection_image, "image/jpeg")), 22 | ] 23 | response = await test_app_asyncio.post("/detection", params=params, files=files, headers=headers) 24 | assert response.status_code == 200 25 | json_response = response.json() 26 | 27 | expected_box_response = mock_detection_response["box"] 28 | common_test(json_response, expected_box_response) 29 | 30 | 31 | @pytest.mark.asyncio 32 | async def test_text_detection_poly(test_app_asyncio, mock_detection_image, mock_detection_response): 33 | headers = { 34 | "accept": "application/json", 35 | } 36 | params = {"det_arch": "db_resnet50", "assume_straight_pages": False} 37 | files = [ 38 | ("files", ("test.jpg", mock_detection_image, "image/jpeg")), 39 | ("files", ("test2.jpg", mock_detection_image, "image/jpeg")), 40 | ] 41 | response = await test_app_asyncio.post("/detection", params=params, files=files, headers=headers) 42 | assert response.status_code == 200 43 | json_response = response.json() 44 | 45 | expected_poly_response = mock_detection_response["poly"] 46 | common_test(json_response, expected_poly_response) 47 | 48 | 49 | @pytest.mark.asyncio 50 | async def test_text_detection_invalid_file(test_app_asyncio, 
mock_txt_file): 51 | headers = { 52 | "accept": "application/json", 53 | } 54 | files = [ 55 | ("files", ("test.txt", mock_txt_file)), 56 | ] 57 | response = await test_app_asyncio.post("/detection", files=files, headers=headers) 58 | assert response.status_code == 400 59 | -------------------------------------------------------------------------------- /api/tests/routes/test_kie.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | 5 | def common_test(json_response, expected_response): 6 | first_pred = json_response[0] # it's enough to test for the first file because the same image is used twice 7 | assert isinstance(first_pred["name"], str) 8 | assert ( 9 | isinstance(first_pred["dimensions"], (tuple, list)) 10 | and len(first_pred["dimensions"]) == 2 11 | and all(isinstance(dim, int) for dim in first_pred["dimensions"]) 12 | ) 13 | assert isinstance(first_pred["predictions"], list) 14 | assert isinstance(expected_response["predictions"], list) 15 | 16 | for pred, expected_pred in zip(first_pred["predictions"], expected_response["predictions"]): 17 | assert pred["class_name"] == expected_pred["class_name"] 18 | assert isinstance(pred["items"], list) 19 | assert isinstance(expected_pred["items"], list) 20 | 21 | for pred_item, expected_pred_item in zip(pred["items"], expected_pred["items"]): 22 | assert isinstance(pred_item["value"], str) and pred_item["value"] == expected_pred_item["value"] 23 | assert isinstance(pred_item["confidence"], (int, float)) 24 | np.testing.assert_allclose(pred_item["geometry"], expected_pred_item["geometry"], rtol=1e-2) 25 | assert isinstance(pred_item["objectness_score"], (int, float)) 26 | assert isinstance(pred_item["crop_orientation"], dict) 27 | assert isinstance(pred_item["crop_orientation"]["value"], int) and isinstance( 28 | pred_item["crop_orientation"]["confidence"], (float, int, type(None)) 29 | ) 30 | 31 | 32 | @pytest.mark.asyncio 33 | async def test_kie_box(test_app_asyncio, mock_detection_image, mock_kie_response): 34 | headers = { 35 | "accept": "application/json", 36 | } 37 | params = {"det_arch": "db_resnet50", "reco_arch": "crnn_vgg16_bn"} 38 | files = [ 39 | ("files", ("test.jpg", mock_detection_image, "image/jpeg")), 40 | ("files", ("test2.jpg", mock_detection_image, "image/jpeg")), 41 | ] 42 | response = await test_app_asyncio.post("/kie", params=params, files=files, headers=headers) 43 | assert response.status_code == 200 44 | json_response = response.json() 45 | 46 | expected_box_response = mock_kie_response["box"] 47 | assert isinstance(json_response, list) and len(json_response) == 2 48 | common_test(json_response, expected_box_response) 49 | 50 | 51 | @pytest.mark.asyncio 52 | async def test_kie_poly(test_app_asyncio, mock_detection_image, mock_kie_response): 53 | headers = { 54 | "accept": "application/json", 55 | } 56 | params = {"det_arch": "db_resnet50", "reco_arch": "crnn_vgg16_bn", "assume_straight_pages": False} 57 | files = [ 58 | ("files", ("test.jpg", mock_detection_image, "image/jpeg")), 59 | ("files", ("test2.jpg", mock_detection_image, "image/jpeg")), 60 | ] 61 | response = await test_app_asyncio.post("/kie", params=params, files=files, headers=headers) 62 | assert response.status_code == 200 63 | json_response = response.json() 64 | 65 | expected_poly_response = mock_kie_response["poly"] 66 | assert isinstance(json_response, list) and len(json_response) == 2 67 | common_test(json_response, expected_poly_response) 68 | 69 | 70 | 
@pytest.mark.asyncio 71 | async def test_kie_invalid_file(test_app_asyncio, mock_txt_file): 72 | headers = { 73 | "accept": "application/json", 74 | } 75 | files = [ 76 | ("files", ("test.txt", mock_txt_file)), 77 | ] 78 | response = await test_app_asyncio.post("/kie", files=files, headers=headers) 79 | assert response.status_code == 400 80 | -------------------------------------------------------------------------------- /api/tests/routes/test_recognition.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.mark.asyncio 5 | async def test_text_recognition(test_app_asyncio, mock_recognition_image, mock_txt_file): 6 | headers = { 7 | "accept": "application/json", 8 | } 9 | params = {"reco_arch": "crnn_vgg16_bn"} 10 | files = [ 11 | ("files", ("test.jpg", mock_recognition_image, "image/jpeg")), 12 | ("files", ("test2.jpg", mock_recognition_image, "image/jpeg")), 13 | ] 14 | response = await test_app_asyncio.post("/recognition", params=params, files=files, headers=headers) 15 | assert response.status_code == 200 16 | json_response = response.json() 17 | assert isinstance(json_response, list) and len(json_response) == 2 18 | for item in json_response: 19 | assert isinstance(item["name"], str) 20 | assert isinstance(item["value"], str) and item["value"] == "invite" 21 | assert isinstance(item["confidence"], (int, float)) and item["confidence"] >= 0.8 22 | 23 | headers = { 24 | "accept": "application/json", 25 | } 26 | files = [ 27 | ("files", ("test.txt", mock_txt_file)), 28 | ] 29 | response = await test_app_asyncio.post("/recognition", files=files, headers=headers) 30 | assert response.status_code == 400 31 | -------------------------------------------------------------------------------- /api/tests/utils/test_utils.py: -------------------------------------------------------------------------------- 1 | from app.utils import resolve_geometry 2 | 3 | 4 | def test_resolve_geometry(): 5 | dummy_box = [(0.0, 0.0), (1.0, 0.0)] 6 | dummy_polygon = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)] 7 | 8 | assert resolve_geometry(dummy_box) == (0.0, 0.0, 1.0, 0.0) 9 | assert resolve_geometry(dummy_polygon) == (0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0) 10 | -------------------------------------------------------------------------------- /api/tests/utils/test_vision.py: -------------------------------------------------------------------------------- 1 | from app.schemas import DetectionIn, KIEIn, OCRIn, RecognitionIn 2 | from app.vision import init_predictor 3 | from doctr.models.detection.predictor import DetectionPredictor 4 | from doctr.models.kie_predictor import KIEPredictor 5 | from doctr.models.predictor import OCRPredictor 6 | from doctr.models.recognition.predictor import RecognitionPredictor 7 | 8 | 9 | def test_vision(): 10 | assert isinstance(init_predictor(OCRIn()), OCRPredictor) 11 | assert isinstance(init_predictor(DetectionIn()), DetectionPredictor) 12 | assert isinstance(init_predictor(RecognitionIn()), RecognitionPredictor) 13 | assert isinstance(init_predictor(KIEIn()), KIEPredictor) 14 | -------------------------------------------------------------------------------- /demo/README.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: docTR 3 | emoji: 📑 4 | colorFrom: purple 5 | colorTo: pink 6 | sdk: streamlit 7 | sdk_version: 1.39.0 8 | app_file: app.py 9 | pinned: false 10 | license: apache-2.0 11 | --- 12 | 13 | ## Configuration 14 | 15 | `title`: _string_ 16 | 
Display title for the Space 17 | 18 | `emoji`: _string_ 19 | Space emoji (emoji-only character allowed) 20 | 21 | `colorFrom`: _string_ 22 | Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray) 23 | 24 | `colorTo`: _string_ 25 | Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray) 26 | 27 | `sdk`: _string_ 28 | Can be either `gradio` or `streamlit` 29 | 30 | `sdk_version` : _string_ 31 | Only applicable for `streamlit` SDK. 32 | See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions. 33 | 34 | `app_file`: _string_ 35 | Path to your main application file (which contains either `gradio` or `streamlit` Python code). 36 | Path is relative to the root of the repository. 37 | 38 | `pinned`: _boolean_ 39 | Whether the Space stays on top of your list. 40 | 41 | ## Run the demo locally 42 | 43 | ```bash 44 | cd demo 45 | pip install -r pt-requirements.txt 46 | streamlit run app.py 47 | ``` 48 | -------------------------------------------------------------------------------- /demo/backend/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from doctr.models import ocr_predictor 10 | from doctr.models.predictor import OCRPredictor 11 | 12 | DET_ARCHS = [ 13 | "fast_base", 14 | "fast_small", 15 | "fast_tiny", 16 | "db_resnet50", 17 | "db_resnet34", 18 | "db_mobilenet_v3_large", 19 | "linknet_resnet18", 20 | "linknet_resnet34", 21 | "linknet_resnet50", 22 | ] 23 | RECO_ARCHS = [ 24 | "crnn_vgg16_bn", 25 | "crnn_mobilenet_v3_small", 26 | "crnn_mobilenet_v3_large", 27 | "master", 28 | "sar_resnet31", 29 | "vitstr_small", 30 | "vitstr_base", 31 | "parseq", 32 | "viptr_tiny", 33 | ] 34 | 35 | 36 | def load_predictor( 37 | det_arch: str, 38 | reco_arch: str, 39 | assume_straight_pages: bool, 40 | straighten_pages: bool, 41 | export_as_straight_boxes: bool, 42 | disable_page_orientation: bool, 43 | disable_crop_orientation: bool, 44 | bin_thresh: float, 45 | box_thresh: float, 46 | device: torch.device, 47 | ) -> OCRPredictor: 48 | """Load a predictor from doctr.models 49 | 50 | Args: 51 | det_arch: detection architecture 52 | reco_arch: recognition architecture 53 | assume_straight_pages: whether to assume straight pages or not 54 | straighten_pages: whether to straighten rotated pages or not 55 | export_as_straight_boxes: whether to export boxes as straight or not 56 | disable_page_orientation: whether to disable page orientation or not 57 | disable_crop_orientation: whether to disable crop orientation or not 58 | bin_thresh: binarization threshold for the segmentation map 59 | box_thresh: minimal objectness score to consider a box 60 | device: torch.device, the device to load the predictor on 61 | 62 | Returns: 63 | instance of OCRPredictor 64 | """ 65 | predictor = ocr_predictor( 66 | det_arch, 67 | reco_arch, 68 | pretrained=True, 69 | assume_straight_pages=assume_straight_pages, 70 | straighten_pages=straighten_pages, 71 | export_as_straight_boxes=export_as_straight_boxes, 72 | detect_orientation=not assume_straight_pages, 73 | disable_page_orientation=disable_page_orientation, 74 | disable_crop_orientation=disable_crop_orientation, 75 | ).to(device) 76 | predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh 77 | 
predictor.det_predictor.model.postprocessor.box_thresh = box_thresh 78 | return predictor 79 | 80 | 81 | def forward_image(predictor: OCRPredictor, image: np.ndarray, device: torch.device) -> np.ndarray: 82 | """Forward an image through the predictor 83 | 84 | Args: 85 | predictor: instance of OCRPredictor 86 | image: image to process 87 | device: torch.device, the device to process the image on 88 | 89 | Returns: 90 | segmentation map 91 | """ 92 | with torch.no_grad(): 93 | processed_batches = predictor.det_predictor.pre_processor([image]) 94 | out = predictor.det_predictor.model(processed_batches[0].to(device), return_model_output=True) 95 | seg_map = out["out_map"].to("cpu").numpy() 96 | 97 | return seg_map 98 | -------------------------------------------------------------------------------- /demo/packages.txt: -------------------------------------------------------------------------------- 1 | python3-opencv 2 | fonts-freefont-ttf 3 | -------------------------------------------------------------------------------- /demo/pt-requirements.txt: -------------------------------------------------------------------------------- 1 | -e git+https://github.com/mindee/doctr.git#egg=python-doctr[viz] 2 | streamlit>=1.0.0 3 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Contribute to Documentation 2 | 3 | Please have a look at our [contribution guide](../CONTRIBUTING.md) to see how to install 4 | the development environment and how to generate the documentation. 5 | 6 | To install only the `docs` environment, you can do: 7 | 8 | ```bash 9 | # Make sure you are at the root of the repository before executing these commands 10 | python -m pip install --upgrade pip 11 | pip install -e .[viz,html] 12 | pip install -e .[docs] 13 | ``` 14 | -------------------------------------------------------------------------------- /docs/build.sh: -------------------------------------------------------------------------------- 1 | function deploy_doc(){ 2 | if [ ! -z "$1" ] 3 | then 4 | git checkout $1 5 | fi 6 | COMMIT=$(git rev-parse --short HEAD) 7 | echo "Creating doc at commit" $COMMIT "and pushing to folder $2" 8 | pip install -U .. 9 | if [ ! 
-z "$2" ] 10 | then 11 | if [ "$2" == "latest" ]; then 12 | echo "Pushing main" 13 | sphinx-build source _build -a && mkdir build && mkdir build/$2 && cp -a _build/* build/$2/ 14 | elif [ -d build/$2 ]; then 15 | echo "Directory" $2 "already exists" 16 | else 17 | echo "Pushing version" $2 18 | cp -r _static source/ && cp _conf.py source/conf.py 19 | sphinx-build source _build -a 20 | mkdir build/$2 && cp -a _build/* build/$2/ && git checkout source/ && git clean -f source/ 21 | fi 22 | else 23 | echo "Pushing stable" 24 | cp -r _static source/ && cp _conf.py source/conf.py 25 | sphinx-build source build -a && git checkout source/ && git clean -f source/ 26 | fi 27 | } 28 | 29 | # You can find the commit for each tag on https://github.com/mindee/doctr/tags 30 | if [ -d build ]; then rm -Rf build; fi 31 | cp -r source/_static . 32 | cp source/conf.py _conf.py 33 | git fetch --all --tags --unshallow 34 | deploy_doc "" latest 35 | deploy_doc "1c9ce92" v0.11.0 36 | deploy_doc "97d4006" v0.12.0 37 | deploy_doc "7dabbe1" # v1.0.0 Latest stable release 38 | rm -rf _build _static _conf.py 39 | -------------------------------------------------------------------------------- /docs/images/Logo_doctr.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mindee/doctr/849d19566b03cd37c8a831ae02673a1c2f265399/docs/images/Logo_doctr.gif -------------------------------------------------------------------------------- /docs/images/demo_illustration_mini.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mindee/doctr/849d19566b03cd37c8a831ae02673a1c2f265399/docs/images/demo_illustration_mini.png -------------------------------------------------------------------------------- /docs/images/demo_update.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mindee/doctr/849d19566b03cd37c8a831ae02673a1c2f265399/docs/images/demo_update.png -------------------------------------------------------------------------------- /docs/images/doctr-need-help.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mindee/doctr/849d19566b03cd37c8a831ae02673a1c2f265399/docs/images/doctr-need-help.png -------------------------------------------------------------------------------- /docs/images/doctr_demo_app.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mindee/doctr/849d19566b03cd37c8a831ae02673a1c2f265399/docs/images/doctr_demo_app.png -------------------------------------------------------------------------------- /docs/images/doctr_example_script.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mindee/doctr/849d19566b03cd37c8a831ae02673a1c2f265399/docs/images/doctr_example_script.gif -------------------------------------------------------------------------------- /docs/images/ocr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mindee/doctr/849d19566b03cd37c8a831ae02673a1c2f265399/docs/images/ocr.png -------------------------------------------------------------------------------- /docs/images/synthesized_sample.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mindee/doctr/849d19566b03cd37c8a831ae02673a1c2f265399/docs/images/synthesized_sample.png -------------------------------------------------------------------------------- /docs/source/_static/css/mindee.css: -------------------------------------------------------------------------------- 1 | /* Version control */ 2 | 3 | .version-button { 4 | color: white; 5 | border: none; 6 | padding: 5px; 7 | font-size: 15px; 8 | cursor: pointer; 9 | } 10 | 11 | .version-button:hover, .version-button:focus { 12 | background-color: #5eb2e6; 13 | } 14 | 15 | .version-dropdown { 16 | display: none; 17 | min-width: 160px; 18 | overflow: auto; 19 | font-size: 15px; 20 | } 21 | 22 | .version-dropdown a { 23 | color: white; 24 | padding: 3px 4px; 25 | text-decoration: none; 26 | display: block; 27 | } 28 | 29 | .version-dropdown a:hover { 30 | background-color: #5eb2e6; 31 | } 32 | 33 | .version-show { 34 | display: block; 35 | } 36 | 37 | h1 { 38 | font-family: "Helvetica Neue", Arial, sans-serif; 39 | /* style fix for headline that it fits into one line */ 40 | font-size: 240%; 41 | } 42 | 43 | h1, h2, h3, h4, h5, .caption-text { 44 | font-family: "Helvetica Neue", Arial, sans-serif; 45 | } 46 | 47 | /* Github button */ 48 | 49 | .github-repo { 50 | display: flex; 51 | justify-content: center; 52 | } 53 | -------------------------------------------------------------------------------- /docs/source/_static/images/Logo-docTR-white.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mindee/doctr/849d19566b03cd37c8a831ae02673a1c2f265399/docs/source/_static/images/Logo-docTR-white.png -------------------------------------------------------------------------------- /docs/source/_static/images/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mindee/doctr/849d19566b03cd37c8a831ae02673a1c2f265399/docs/source/_static/images/favicon.ico -------------------------------------------------------------------------------- /docs/source/changelog.rst: -------------------------------------------------------------------------------- 1 | Changelog 2 | ========= 3 | 4 | v1.0.0 (2025-07-09) 5 | ------------------- 6 | Release note: `v1.0.0 `_ 7 | 8 | v0.12.0 (2025-06-20) 9 | -------------------- 10 | Release note: `v0.12.0 `_ 11 | 12 | v0.11.0 (2025-01-30) 13 | -------------------- 14 | Release note: `v0.11.0 `_ 15 | 16 | v0.10.0 (2024-10-21) 17 | -------------------- 18 | Release note: `v0.10.0 `_ 19 | 20 | v0.9.0 (2024-08-08) 21 | ------------------- 22 | Release note: `v0.9.0 `_ 23 | 24 | v0.8.1 (2024-03-04) 25 | ------------------- 26 | Release note: `v0.8.1 `_ 27 | 28 | v0.8.0 (2024-02-28) 29 | ------------------- 30 | Release note: `v0.8.0 `_ 31 | 32 | v0.7.0 (2023-09-09) 33 | ------------------- 34 | Release note: `v0.7.0 `_ 35 | 36 | v0.6.0 (2022-09-29) 37 | ------------------- 38 | Release note: `v0.6.0 `_ 39 | 40 | v0.5.1 (2022-03-22) 41 | ------------------- 42 | Release note: `v0.5.1 `_ 43 | 44 | v0.5.0 (2021-12-31) 45 | ------------------- 46 | Release note: `v0.5.0 `_ 47 | 48 | v0.4.1 (2021-11-22) 49 | ------------------- 50 | Release note: `v0.4.1 `_ 51 | 52 | v0.4.0 (2021-10-01) 53 | ------------------- 54 | Release note: `v0.4.0 `_ 55 | 56 | v0.3.1 (2021-08-27) 57 | ------------------- 58 | Release note: `v0.3.1 `_ 59 | 60 | v0.3.0 (2021-07-02) 61 | ------------------- 62 | Release note: `v0.3.0 `_ 63 | 64 | v0.2.1 (2021-05-28) 65 | 
------------------- 66 | Release note: `v0.2.1 `_ 67 | 68 | v0.2.0 (2021-05-11) 69 | ------------------- 70 | Release note: `v0.2.0 `_ 71 | 72 | v0.1.1 (2021-03-18) 73 | ------------------- 74 | Release note: `v0.1.1 `_ 75 | 76 | v0.1.0 (2021-03-05) 77 | ------------------- 78 | Release note: `v0.1.0 `_ 79 | -------------------------------------------------------------------------------- /docs/source/community/resources.rst: -------------------------------------------------------------------------------- 1 | Community Resources 2 | =================== 3 | 4 | This section contains some cool resources created by the docTR community. 5 | 6 | 7 | * |:book:| Fine-tuning OCR works really well: the Statistical Abstracts of the United States: 8 | `Article `_ created by: `Christian Moscardi `_. 9 | 10 | * |:video_camera:| Mindee docTR - Probably the Best Open-Source OCR: 11 | `Video `_ created by: `Andrej Baranovskij `_. 12 | 13 | * |:book:| Extract Text from images in Python with docTR: 14 | `Article `_ created by: `Netraj Patil`. 15 | 16 | * |:book:| How to Detect Text in Images with OCR (Roboflow integration): 17 | `Article `_ created by: `James Gallagher`. 18 | 19 | * |:book:| Our search for the best OCR tool in 2023, and what we found: 20 | `Article `_ created by: `Sanjin Ibrahimovic`. 21 | 22 | * |:book:| Real-time information extraction from documents with docTR: 23 | `Article `_ created by: `Yugesh Verma`. 24 | -------------------------------------------------------------------------------- /docs/source/contributing/code_of_conduct.md: -------------------------------------------------------------------------------- 1 | ../../../CODE_OF_CONDUCT.md -------------------------------------------------------------------------------- /docs/source/contributing/contributing.md: -------------------------------------------------------------------------------- 1 | ../../../CONTRIBUTING.md -------------------------------------------------------------------------------- /docs/source/getting_started/installing.rst: -------------------------------------------------------------------------------- 1 | 2 | ************ 3 | Installation 4 | ************ 5 | 6 | This library requires `Python `_ 3.10 or higher. 7 | 8 | 9 | Via Python Package 10 | ================== 11 | 12 | Install the last stable release of the package using `pip `_: 13 | 14 | .. code:: bash 15 | 16 | pip install python-doctr 17 | 18 | 19 | We strive towards reducing framework-specific dependencies to a minimum, but some necessary features are developed by third-parties for specific frameworks. To avoid missing some dependencies for a specific framework, you can install specific builds as follows: 20 | 21 | .. code:: bash 22 | 23 | pip install python-doctr 24 | # or with preinstalled packages for visualization & html & contrib module support 25 | pip install "python-doctr[viz,html,contrib]" 26 | 27 | 28 | Via Git 29 | ======= 30 | 31 | Install the library in developer mode: 32 | 33 | 34 | .. code:: bash 35 | 36 | git clone https://github.com/mindee/doctr.git 37 | pip install -e doctr/. 38 | -------------------------------------------------------------------------------- /docs/source/modules/contrib.rst: -------------------------------------------------------------------------------- 1 | doctr.contrib 2 | ============= 3 | 4 | .. currentmodule:: doctr.contrib 5 | 6 | This module contains all the available contribution modules for docTR. 
7 | 8 | 9 | Supported contribution modules 10 | ------------------------------ 11 | Here are all the available contribution modules: 12 | 13 | .. autoclass:: ArtefactDetector 14 | -------------------------------------------------------------------------------- /docs/source/modules/io.rst: -------------------------------------------------------------------------------- 1 | doctr.io 2 | ======== 3 | 4 | 5 | .. currentmodule:: doctr.io 6 | 7 | The io module enables users to easily access content from documents and export analysis 8 | results to structured formats. 9 | 10 | .. _document_structure: 11 | 12 | Document structure 13 | ------------------ 14 | 15 | Structural organization of the documents. 16 | 17 | Word 18 | ^^^^ 19 | A Word is an uninterrupted sequence of characters. 20 | 21 | .. autoclass:: Word 22 | 23 | Line 24 | ^^^^ 25 | A Line is a collection of Words aligned spatially and meant to be read together (on a two-column page, on the same horizontal, we will consider that there are two Lines). 26 | 27 | .. autoclass:: Line 28 | 29 | Artefact 30 | ^^^^^^^^ 31 | 32 | An Artefact is a non-textual element (e.g. QR code, picture, chart, signature, logo, etc.). 33 | 34 | .. autoclass:: Artefact 35 | 36 | Block 37 | ^^^^^ 38 | A Block is a collection of Lines (e.g. an address written on several lines) and Artefacts (e.g. a graph with its title underneath). 39 | 40 | .. autoclass:: Block 41 | 42 | Page 43 | ^^^^ 44 | 45 | A Page is a collection of Blocks that were on the same physical page. 46 | 47 | .. autoclass:: Page 48 | 49 | .. automethod:: show 50 | 51 | 52 | Document 53 | ^^^^^^^^ 54 | 55 | A Document is a collection of Pages. 56 | 57 | .. autoclass:: Document 58 | 59 | .. automethod:: show 60 | 61 | 62 | File reading 63 | ------------ 64 | 65 | High-performance file reading and conversion to processable structured data. 66 | 67 | .. autofunction:: read_pdf 68 | 69 | .. autofunction:: read_img_as_numpy 70 | 71 | .. autofunction:: read_img_as_tensor 72 | 73 | .. autofunction:: decode_img_as_tensor 74 | 75 | .. autofunction:: read_html 76 | 77 | 78 | .. autoclass:: DocumentFile 79 | 80 | .. automethod:: from_pdf 81 | 82 | .. automethod:: from_url 83 | 84 | .. automethod:: from_images 85 | -------------------------------------------------------------------------------- /docs/source/modules/models.rst: -------------------------------------------------------------------------------- 1 | doctr.models 2 | ============ 3 | 4 | .. currentmodule:: doctr.models 5 | 6 | 7 | doctr.models.classification 8 | --------------------------- 9 | 10 | .. autofunction:: doctr.models.classification.vgg16_bn_r 11 | 12 | .. autofunction:: doctr.models.classification.resnet18 13 | 14 | .. autofunction:: doctr.models.classification.resnet34 15 | 16 | .. autofunction:: doctr.models.classification.resnet50 17 | 18 | .. autofunction:: doctr.models.classification.resnet31 19 | 20 | .. autofunction:: doctr.models.classification.mobilenet_v3_small 21 | 22 | .. autofunction:: doctr.models.classification.mobilenet_v3_large 23 | 24 | .. autofunction:: doctr.models.classification.mobilenet_v3_small_r 25 | 26 | .. autofunction:: doctr.models.classification.mobilenet_v3_large_r 27 | 28 | .. autofunction:: doctr.models.classification.mobilenet_v3_small_crop_orientation 29 | 30 | .. autofunction:: doctr.models.classification.mobilenet_v3_small_page_orientation 31 | 32 | .. autofunction:: doctr.models.classification.magc_resnet31 33 | 34 | .. autofunction:: doctr.models.classification.vit_s 35 | 36 | .. 
autofunction:: doctr.models.classification.vit_b 37 | 38 | .. autofunction:: doctr.models.classification.textnet_tiny 39 | 40 | .. autofunction:: doctr.models.classification.textnet_small 41 | 42 | .. autofunction:: doctr.models.classification.textnet_base 43 | 44 | .. autofunction:: doctr.models.classification.vip_tiny 45 | 46 | .. autofunction:: doctr.models.classification.vip_base 47 | 48 | .. autofunction:: doctr.models.classification.crop_orientation_predictor 49 | 50 | .. autofunction:: doctr.models.classification.page_orientation_predictor 51 | 52 | 53 | doctr.models.detection 54 | ---------------------- 55 | 56 | .. autofunction:: doctr.models.detection.linknet_resnet18 57 | 58 | .. autofunction:: doctr.models.detection.linknet_resnet34 59 | 60 | .. autofunction:: doctr.models.detection.linknet_resnet50 61 | 62 | .. autofunction:: doctr.models.detection.db_resnet50 63 | 64 | .. autofunction:: doctr.models.detection.db_mobilenet_v3_large 65 | 66 | .. autofunction:: doctr.models.detection.fast_tiny 67 | 68 | .. autofunction:: doctr.models.detection.fast_small 69 | 70 | .. autofunction:: doctr.models.detection.fast_base 71 | 72 | .. autofunction:: doctr.models.detection.detection_predictor 73 | 74 | 75 | doctr.models.recognition 76 | ------------------------ 77 | 78 | .. autofunction:: doctr.models.recognition.crnn_vgg16_bn 79 | 80 | .. autofunction:: doctr.models.recognition.crnn_mobilenet_v3_small 81 | 82 | .. autofunction:: doctr.models.recognition.crnn_mobilenet_v3_large 83 | 84 | .. autofunction:: doctr.models.recognition.sar_resnet31 85 | 86 | .. autofunction:: doctr.models.recognition.master 87 | 88 | .. autofunction:: doctr.models.recognition.vitstr_small 89 | 90 | .. autofunction:: doctr.models.recognition.vitstr_base 91 | 92 | .. autofunction:: doctr.models.recognition.parseq 93 | 94 | .. autofunction:: doctr.models.recognition.viptr_tiny 95 | 96 | .. autofunction:: doctr.models.recognition.recognition_predictor 97 | 98 | 99 | doctr.models.zoo 100 | ---------------- 101 | 102 | .. autofunction:: doctr.models.ocr_predictor 103 | 104 | .. autofunction:: doctr.models.kie_predictor 105 | 106 | 107 | doctr.models.factory 108 | -------------------- 109 | 110 | .. autofunction:: doctr.models.factory.login_to_hub 111 | 112 | .. autofunction:: doctr.models.factory.from_hub 113 | 114 | .. autofunction:: doctr.models.factory.push_to_hf_hub 115 | -------------------------------------------------------------------------------- /docs/source/modules/transforms.rst: -------------------------------------------------------------------------------- 1 | doctr.transforms 2 | ================ 3 | 4 | .. currentmodule:: doctr.transforms 5 | 6 | Data transformations are part of both training and inference procedure. Drawing inspiration from the design of `torchvision `_, we express transformations as composable modules. 7 | 8 | 9 | Supported transformations 10 | ------------------------- 11 | Here are all transformations that are available through docTR: 12 | 13 | .. currentmodule:: doctr.transforms.modules 14 | 15 | .. autoclass:: Resize 16 | .. autoclass:: GaussianNoise 17 | .. autoclass:: ChannelShuffle 18 | .. autoclass:: RandomHorizontalFlip 19 | .. autoclass:: RandomShadow 20 | .. autoclass:: RandomResize 21 | 22 | 23 | Composing transformations 24 | --------------------------------------------- 25 | It is common to require several transformations to be performed consecutively. 26 | 27 | .. autoclass:: SampleCompose 28 | .. autoclass:: ImageTransform 29 | .. autoclass:: ColorInversion 30 | .. 
autoclass:: OneOf 31 | .. autoclass:: RandomApply 32 | .. autoclass:: RandomRotate 33 | .. autoclass:: RandomCrop 34 | -------------------------------------------------------------------------------- /docs/source/modules/utils.rst: -------------------------------------------------------------------------------- 1 | doctr.utils 2 | =========== 3 | 4 | This module regroups non-core features that are complementary to the rest of the package. 5 | 6 | .. currentmodule:: doctr.utils 7 | 8 | 9 | Visualization 10 | ------------- 11 | Easy-to-use functions to make sense of your model's predictions. 12 | 13 | .. currentmodule:: doctr.utils.visualization 14 | 15 | .. autofunction:: visualize_page 16 | 17 | Reconstitution 18 | --------------- 19 | 20 | .. currentmodule:: doctr.utils.reconstitution 21 | 22 | .. autofunction:: synthesize_page 23 | 24 | 25 | .. _metrics: 26 | 27 | Task evaluation 28 | --------------- 29 | Implementations of task-specific metrics to easily assess your model performances. 30 | 31 | .. currentmodule:: doctr.utils.metrics 32 | 33 | .. autoclass:: TextMatch 34 | 35 | .. automethod:: update 36 | .. automethod:: summary 37 | 38 | .. autoclass:: LocalizationConfusion 39 | 40 | .. automethod:: update 41 | .. automethod:: summary 42 | 43 | .. autoclass:: OCRMetric 44 | 45 | .. automethod:: update 46 | .. automethod:: summary 47 | 48 | .. autoclass:: DetectionMetric 49 | 50 | .. automethod:: update 51 | .. automethod:: summary 52 | -------------------------------------------------------------------------------- /docs/source/notebooks.rst: -------------------------------------------------------------------------------- 1 | ../../notebooks/README.rst -------------------------------------------------------------------------------- /docs/source/using_doctr/running_on_aws.rst: -------------------------------------------------------------------------------- 1 | AWS Lambda 2 | ========== 3 | 4 | The security policy of `AWS Lambda `_ restricts writing outside the ``/tmp`` directory. 5 | 6 | To make docTR work on Lambda, you need to perform the following two steps: 7 | 8 | 1. Disable the usage of the ``multiprocessing`` package by setting the ``DOCTR_MULTIPROCESSING_DISABLE`` environment variable to ``TRUE``. This step is necessary because the package uses the ``/dev/shm`` directory for shared memory. 9 | 10 | 2. Change the caching directory used by docTR for models. By default, it is set to ``~/.cache/doctr``, which is outside the ``/tmp`` directory on AWS Lambda. You can modify this by setting the ``DOCTR_CACHE_DIR`` environment variable. 11 | -------------------------------------------------------------------------------- /docs/source/using_doctr/using_contrib_modules.rst: -------------------------------------------------------------------------------- 1 | Integrate contributions into your pipeline 2 | ========================================== 3 | 4 | The `contrib` module provides a collection of additional features which could be relevant for your document analysis pipeline. 5 | The following sections will give you an overview of the available modules and features. 6 | 7 | .. currentmodule:: doctr.contrib 8 | 9 | 10 | Available contribution modules 11 | ------------------------------ 12 | 13 | **NOTE:** To use the contrib module, you need to install the `onnxruntime` package. You can install it using the following command: 14 | 15 | .. 
code:: bash 16 | 17 | pip install python-doctr[contrib] 18 | # Or 19 | pip install onnxruntime # pip install onnxruntime-gpu 20 | 21 | Here are all contribution modules that are available through docTR: 22 | 23 | ArtefactDetection 24 | ^^^^^^^^^^^^^^^^^ 25 | 26 | The ArtefactDetection module provides a set of functions to detect artefacts in the document images, such as logos, QR codes, bar codes, etc. 27 | It is based on the YOLOv8 architecture, which is a state-of-the-art object detection model. 28 | 29 | .. code:: python3 30 | 31 | from doctr.io import DocumentFile 32 | from doctr.contrib.artefacts import ArtefactDetection 33 | 34 | # Load the document 35 | doc = DocumentFile.from_images(["path/to/your/image"]) 36 | detector = ArtefactDetection(batch_size=2, conf_threshold=0.5, iou_threshold=0.5) 37 | artefacts = detector(doc) 38 | 39 | # Visualize the detected artefacts 40 | detector.show() 41 | 42 | You can also use your custom trained YOLOv8 model to detect artefacts or anything else you need. 43 | Reference: `YOLOv8 `_ 44 | 45 | **NOTE:** The YOLOv8 model (no Oriented Bounding Box (OBB) inference supported yet) needs to be provided as onnx exported model with a dynamic batch size. 46 | 47 | .. code:: python3 48 | 49 | from doctr.contrib import ArtefactDetection 50 | 51 | detector = ArtefactDetection(model_path="path/to/your/model.onnx", labels=["table", "figure"]) 52 | -------------------------------------------------------------------------------- /doctr/__init__.py: -------------------------------------------------------------------------------- 1 | from . import io, models, datasets, contrib, transforms, utils 2 | from .version import __version__ # noqa: F401 3 | -------------------------------------------------------------------------------- /doctr/contrib/__init__.py: -------------------------------------------------------------------------------- 1 | from .artefacts import ArtefactDetector 2 | -------------------------------------------------------------------------------- /doctr/contrib/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | from typing import Any 7 | 8 | import numpy as np 9 | 10 | from doctr.file_utils import requires_package 11 | from doctr.utils.data import download_from_url 12 | 13 | 14 | class _BasePredictor: 15 | """ 16 | Base class for all predictors 17 | 18 | Args: 19 | batch_size: the batch size to use 20 | url: the url to use to download a model if needed 21 | model_path: the path to the model to use 22 | **kwargs: additional arguments to be passed to `download_from_url` 23 | """ 24 | 25 | def __init__(self, batch_size: int, url: str | None = None, model_path: str | None = None, **kwargs) -> None: 26 | self.batch_size = batch_size 27 | self.session = self._init_model(url, model_path, **kwargs) 28 | 29 | self._inputs: list[np.ndarray] = [] 30 | self._results: list[Any] = [] 31 | 32 | def _init_model(self, url: str | None = None, model_path: str | None = None, **kwargs: Any) -> Any: 33 | """ 34 | Download the model from the given url if needed 35 | 36 | Args: 37 | url: the url to use 38 | model_path: the path to the model to use 39 | **kwargs: additional arguments to be passed to `download_from_url` 40 | 41 | Returns: 42 | Any: the ONNX loaded model 43 | """ 44 | requires_package("onnxruntime", "`.contrib` module requires `onnxruntime` to be installed.") 45 | import onnxruntime as ort 46 | 47 | if not url and not model_path: 48 | raise ValueError("You must provide either a url or a model_path") 49 | onnx_model_path = model_path if model_path else str(download_from_url(url, cache_subdir="models", **kwargs)) # type: ignore[arg-type] 50 | return ort.InferenceSession(onnx_model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]) 51 | 52 | def preprocess(self, img: np.ndarray) -> np.ndarray: 53 | """ 54 | Preprocess the input image 55 | 56 | Args: 57 | img: the input image to preprocess 58 | 59 | Returns: 60 | np.ndarray: the preprocessed image 61 | """ 62 | raise NotImplementedError 63 | 64 | def postprocess(self, output: list[np.ndarray], input_images: list[list[np.ndarray]]) -> Any: 65 | """ 66 | Postprocess the model output 67 | 68 | Args: 69 | output: the model output to postprocess 70 | input_images: the input images used to generate the output 71 | 72 | Returns: 73 | Any: the postprocessed output 74 | """ 75 | raise NotImplementedError 76 | 77 | def __call__(self, inputs: list[np.ndarray]) -> Any: 78 | """ 79 | Call the model on the given inputs 80 | 81 | Args: 82 | inputs: the inputs to use 83 | 84 | Returns: 85 | Any: the postprocessed output 86 | """ 87 | self._inputs = inputs 88 | model_inputs = self.session.get_inputs() 89 | 90 | batched_inputs = [inputs[i : i + self.batch_size] for i in range(0, len(inputs), self.batch_size)] 91 | processed_batches = [ 92 | np.array([self.preprocess(img) for img in batch], dtype=np.float32) for batch in batched_inputs 93 | ] 94 | 95 | outputs = [self.session.run(None, {model_inputs[0].name: batch}) for batch in processed_batches] 96 | return self.postprocess(outputs, batched_inputs) 97 | -------------------------------------------------------------------------------- /doctr/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .generator import * 2 | from .coco_text import * 3 | from .cord import * 4 | from .detection import * 5 | from .doc_artefacts import * 6 | from .funsd import * 7 | from .ic03 import * 8 | from .ic13 import * 9 | from .iiit5k import * 10 | from .iiithws import * 11 | from .imgur5k import * 12 | from .mjsynth import * 13 | from .ocr import * 14 | from 
.recognition import * 15 | from .orientation import * 16 | from .sroie import * 17 | from .svhn import * 18 | from .svt import * 19 | from .synthtext import * 20 | from .utils import * 21 | from .vocabs import * 22 | from .wildreceipt import * 23 | -------------------------------------------------------------------------------- /doctr/datasets/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/datasets/datasets/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import os 7 | from copy import deepcopy 8 | from typing import Any 9 | 10 | import numpy as np 11 | import torch 12 | 13 | from doctr.io import read_img_as_tensor, tensor_from_numpy 14 | 15 | from .base import _AbstractDataset, _VisionDataset 16 | 17 | __all__ = ["AbstractDataset", "VisionDataset"] 18 | 19 | 20 | class AbstractDataset(_AbstractDataset): 21 | """Abstract class for all datasets""" 22 | 23 | def _read_sample(self, index: int) -> tuple[torch.Tensor, Any]: 24 | img_name, target = self.data[index] 25 | 26 | # Check target 27 | if isinstance(target, dict): 28 | assert "boxes" in target, "Target should contain 'boxes' key" 29 | assert "labels" in target, "Target should contain 'labels' key" 30 | elif isinstance(target, tuple): 31 | assert len(target) == 2 32 | assert isinstance(target[0], str) or isinstance(target[0], np.ndarray), ( 33 | "first element of the tuple should be a string or a numpy array" 34 | ) 35 | assert isinstance(target[1], list), "second element of the tuple should be a list" 36 | else: 37 | assert isinstance(target, str) or isinstance(target, np.ndarray), ( 38 | "Target should be a string or a numpy array" 39 | ) 40 | 41 | # Read image 42 | img = ( 43 | tensor_from_numpy(img_name, dtype=torch.float32) 44 | if isinstance(img_name, np.ndarray) 45 | else read_img_as_tensor(os.path.join(self.root, img_name), dtype=torch.float32) 46 | ) 47 | 48 | return img, deepcopy(target) 49 | 50 | @staticmethod 51 | def collate_fn(samples: list[tuple[torch.Tensor, Any]]) -> tuple[torch.Tensor, list[Any]]: 52 | images, targets = zip(*samples) 53 | images = torch.stack(images, dim=0) # type: ignore[assignment] 54 | 55 | return images, list(targets) # type: ignore[return-value] 56 | 57 | 58 | class VisionDataset(AbstractDataset, _VisionDataset): # noqa: D101 59 | pass 60 | -------------------------------------------------------------------------------- /doctr/datasets/doc_artefacts.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import json 7 | import os 8 | from typing import Any 9 | 10 | import numpy as np 11 | 12 | from .datasets import VisionDataset 13 | 14 | __all__ = ["DocArtefacts"] 15 | 16 | 17 | class DocArtefacts(VisionDataset): 18 | """Object detection dataset for non-textual elements in documents. 19 | The dataset includes a variety of synthetic document pages with non-textual elements. 20 | 21 | .. 
image:: https://doctr-static.mindee.com/models?id=v0.5.0/artefacts-grid.png&src=0 22 | :align: center 23 | 24 | >>> from doctr.datasets import DocArtefacts 25 | >>> train_set = DocArtefacts(train=True, download=True) 26 | >>> img, target = train_set[0] 27 | 28 | Args: 29 | train: whether the subset should be the training one 30 | use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) 31 | **kwargs: keyword arguments from `VisionDataset`. 32 | """ 33 | 34 | URL = "https://doctr-static.mindee.com/models?id=v0.4.0/artefact_detection-13fab8ce.zip&src=0" 35 | SHA256 = "13fab8ced7f84583d9dccd0c634f046c3417e62a11fe1dea6efbbaba5052471b" 36 | CLASSES = ["background", "qr_code", "bar_code", "logo", "photo"] 37 | 38 | def __init__( 39 | self, 40 | train: bool = True, 41 | use_polygons: bool = False, 42 | **kwargs: Any, 43 | ) -> None: 44 | super().__init__(self.URL, None, self.SHA256, True, **kwargs) 45 | self.train = train 46 | 47 | # Update root 48 | self.root = os.path.join(self.root, "train" if train else "val") 49 | # List images 50 | tmp_root = os.path.join(self.root, "images") 51 | with open(os.path.join(self.root, "labels.json"), "rb") as f: 52 | labels = json.load(f) 53 | self.data: list[tuple[str, dict[str, Any]]] = [] 54 | img_list = os.listdir(tmp_root) 55 | if len(labels) != len(img_list): 56 | raise AssertionError("the number of images and labels do not match") 57 | np_dtype = np.float32 58 | for img_name, label in labels.items(): 59 | # File existence check 60 | if not os.path.exists(os.path.join(tmp_root, img_name)): 61 | raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_name)}") 62 | 63 | # xmin, ymin, xmax, ymax 64 | boxes: np.ndarray = np.asarray([obj["geometry"] for obj in label], dtype=np_dtype) 65 | classes: np.ndarray = np.asarray([self.CLASSES.index(obj["label"]) for obj in label], dtype=np.int64) 66 | if use_polygons: 67 | # (x, y) coordinates of top left, top right, bottom right, bottom left corners 68 | boxes = np.stack( 69 | [ 70 | np.stack([boxes[:, 0], boxes[:, 1]], axis=-1), 71 | np.stack([boxes[:, 2], boxes[:, 1]], axis=-1), 72 | np.stack([boxes[:, 2], boxes[:, 3]], axis=-1), 73 | np.stack([boxes[:, 0], boxes[:, 3]], axis=-1), 74 | ], 75 | axis=1, 76 | ) 77 | self.data.append((img_name, dict(boxes=boxes, labels=classes))) 78 | self.root = tmp_root 79 | 80 | def extra_repr(self) -> str: 81 | return f"train={self.train}" 82 | -------------------------------------------------------------------------------- /doctr/datasets/generator/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/datasets/generator/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | from torch.utils.data._utils.collate import default_collate 7 | 8 | from .base import _CharacterGenerator, _WordGenerator 9 | 10 | __all__ = ["CharacterGenerator", "WordGenerator"] 11 | 12 | 13 | class CharacterGenerator(_CharacterGenerator): 14 | """Implements a character image generation dataset 15 | 16 | >>> from doctr.datasets import CharacterGenerator 17 | >>> ds = CharacterGenerator(vocab='abdef', num_samples=100) 18 | >>> img, target = ds[0] 19 | 20 | Args: 21 | vocab: vocabulary to take the character from 22 | num_samples: number of samples that will be generated iterating over the dataset 23 | cache_samples: whether generated images should be cached firsthand 24 | font_family: font to use to generate the text images 25 | img_transforms: composable transformations that will be applied to each image 26 | sample_transforms: composable transformations that will be applied to both the image and the target 27 | """ 28 | 29 | def __init__(self, *args, **kwargs) -> None: 30 | super().__init__(*args, **kwargs) 31 | setattr(self, "collate_fn", default_collate) 32 | 33 | 34 | class WordGenerator(_WordGenerator): 35 | """Implements a character image generation dataset 36 | 37 | >>> from doctr.datasets import WordGenerator 38 | >>> ds = WordGenerator(vocab='abdef', min_chars=1, max_chars=32, num_samples=100) 39 | >>> img, target = ds[0] 40 | 41 | Args: 42 | vocab: vocabulary to take the character from 43 | min_chars: minimum number of characters in a word 44 | max_chars: maximum number of characters in a word 45 | num_samples: number of samples that will be generated iterating over the dataset 46 | cache_samples: whether generated images should be cached firsthand 47 | font_family: font to use to generate the text images 48 | img_transforms: composable transformations that will be applied to each image 49 | sample_transforms: composable transformations that will be applied to both the image and the target 50 | """ 51 | 52 | pass 53 | -------------------------------------------------------------------------------- /doctr/datasets/iiithws.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import os 7 | from random import sample 8 | from typing import Any 9 | 10 | from tqdm import tqdm 11 | 12 | from .datasets import AbstractDataset 13 | 14 | __all__ = ["IIITHWS"] 15 | 16 | 17 | class IIITHWS(AbstractDataset): 18 | """IIITHWS dataset from `"Generating Synthetic Data for Text Recognition" 19 | `_ | `"repository" `_ | 20 | `"website" `_. 21 | 22 | >>> # NOTE: This is a pure recognition dataset without bounding box labels. 23 | >>> # NOTE: You need to download the dataset. 24 | >>> from doctr.datasets import IIITHWS 25 | >>> train_set = IIITHWS(img_folder="/path/to/iiit-hws/Images_90K_Normalized", 26 | >>> label_path="/path/to/IIIT-HWS-90K.txt", 27 | >>> train=True) 28 | >>> img, target = train_set[0] 29 | >>> test_set = IIITHWS(img_folder="/path/to/iiit-hws/Images_90K_Normalized", 30 | >>> label_path="/path/to/IIIT-HWS-90K.txt") 31 | >>> train=False) 32 | >>> img, target = test_set[0] 33 | 34 | Args: 35 | img_folder: folder with all the images of the dataset 36 | label_path: path to the file with the labels 37 | train: whether the subset should be the training one 38 | **kwargs: keyword arguments from `AbstractDataset`. 
39 | """ 40 | 41 | def __init__( 42 | self, 43 | img_folder: str, 44 | label_path: str, 45 | train: bool = True, 46 | **kwargs: Any, 47 | ) -> None: 48 | super().__init__(img_folder, **kwargs) 49 | 50 | # File existence check 51 | if not os.path.exists(label_path) or not os.path.exists(img_folder): 52 | raise FileNotFoundError(f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}") 53 | 54 | self.data: list[tuple[str, str]] = [] 55 | self.train = train 56 | 57 | with open(label_path) as f: 58 | annotations = f.readlines() 59 | 60 | # Shuffle the dataset otherwise the test set will contain the same labels n times 61 | annotations = sample(annotations, len(annotations)) 62 | train_samples = int(len(annotations) * 0.9) 63 | set_slice = slice(train_samples) if self.train else slice(train_samples, None) 64 | 65 | for annotation in tqdm( 66 | iterable=annotations[set_slice], desc="Preparing and Loading IIITHWS", total=len(annotations[set_slice]) 67 | ): 68 | img_path, label = annotation.split()[0:2] 69 | img_path = os.path.join(img_folder, img_path) 70 | 71 | self.data.append((img_path, label)) 72 | 73 | def extra_repr(self) -> str: 74 | return f"train={self.train}" 75 | -------------------------------------------------------------------------------- /doctr/datasets/ocr.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import json 7 | import os 8 | from pathlib import Path 9 | from typing import Any 10 | 11 | import numpy as np 12 | 13 | from .datasets import AbstractDataset 14 | 15 | __all__ = ["OCRDataset"] 16 | 17 | 18 | class OCRDataset(AbstractDataset): 19 | """Implements an OCR dataset 20 | 21 | >>> from doctr.datasets import OCRDataset 22 | >>> train_set = OCRDataset(img_folder="/path/to/images", 23 | >>> label_file="/path/to/labels.json") 24 | >>> img, target = train_set[0] 25 | 26 | Args: 27 | img_folder: local path to image folder (all jpg at the root) 28 | label_file: local path to the label file 29 | use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) 30 | **kwargs: keyword arguments from `AbstractDataset`. 
31 | """ 32 | 33 | def __init__( 34 | self, 35 | img_folder: str, 36 | label_file: str, 37 | use_polygons: bool = False, 38 | **kwargs: Any, 39 | ) -> None: 40 | super().__init__(img_folder, **kwargs) 41 | 42 | # List images 43 | self.data: list[tuple[Path, dict[str, Any]]] = [] 44 | np_dtype = np.float32 45 | with open(label_file, "rb") as f: 46 | data = json.load(f) 47 | 48 | for img_name, annotations in data.items(): 49 | # Get image path 50 | img_name = Path(img_name) 51 | # File existence check 52 | if not os.path.exists(os.path.join(self.root, img_name)): 53 | raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}") 54 | 55 | # handle empty images 56 | if len(annotations["typed_words"]) == 0: 57 | self.data.append((img_name, dict(boxes=np.zeros((0, 4), dtype=np_dtype), labels=[]))) 58 | continue 59 | # Unpack the straight boxes (xmin, ymin, xmax, ymax) 60 | geoms = [list(map(float, obj["geometry"][:4])) for obj in annotations["typed_words"]] 61 | if use_polygons: 62 | # (x, y) coordinates of top left, top right, bottom right, bottom left corners 63 | geoms = [ 64 | [geom[:2], [geom[2], geom[1]], geom[2:], [geom[0], geom[3]]] # type: ignore[list-item] 65 | for geom in geoms 66 | ] 67 | 68 | text_targets = [obj["value"] for obj in annotations["typed_words"]] 69 | 70 | self.data.append((img_name, dict(boxes=np.asarray(geoms, dtype=np_dtype), labels=text_targets))) 71 | -------------------------------------------------------------------------------- /doctr/datasets/orientation.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import os 7 | from typing import Any 8 | 9 | import numpy as np 10 | 11 | from .datasets import AbstractDataset 12 | 13 | __all__ = ["OrientationDataset"] 14 | 15 | 16 | class OrientationDataset(AbstractDataset): 17 | """Implements a basic image dataset where targets are filled with zeros. 18 | 19 | >>> from doctr.datasets import OrientationDataset 20 | >>> train_set = OrientationDataset(img_folder="/path/to/images") 21 | >>> img, target = train_set[0] 22 | 23 | Args: 24 | img_folder: folder with all the images of the dataset 25 | **kwargs: keyword arguments from `AbstractDataset`. 26 | """ 27 | 28 | def __init__( 29 | self, 30 | img_folder: str, 31 | **kwargs: Any, 32 | ) -> None: 33 | super().__init__( 34 | img_folder, 35 | **kwargs, 36 | ) 37 | 38 | # initialize dataset with 0 degree rotation targets 39 | self.data: list[tuple[str, np.ndarray]] = [(img_name, np.array([0])) for img_name in os.listdir(self.root)] 40 | -------------------------------------------------------------------------------- /doctr/datasets/recognition.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | import json 7 | import os 8 | from pathlib import Path 9 | from typing import Any 10 | 11 | from .datasets import AbstractDataset 12 | 13 | __all__ = ["RecognitionDataset"] 14 | 15 | 16 | class RecognitionDataset(AbstractDataset): 17 | """Dataset implementation for text recognition tasks 18 | 19 | >>> from doctr.datasets import RecognitionDataset 20 | >>> train_set = RecognitionDataset(img_folder="/path/to/images", 21 | >>> labels_path="/path/to/labels.json") 22 | >>> img, target = train_set[0] 23 | 24 | Args: 25 | img_folder: path to the images folder 26 | labels_path: path to the json file containing all labels (character sequences) 27 | **kwargs: keyword arguments from `AbstractDataset`. 28 | """ 29 | 30 | def __init__( 31 | self, 32 | img_folder: str, 33 | labels_path: str, 34 | **kwargs: Any, 35 | ) -> None: 36 | super().__init__(img_folder, **kwargs) 37 | 38 | self.data: list[tuple[str, str]] = [] 39 | with open(labels_path, encoding="utf-8") as f: 40 | labels = json.load(f) 41 | 42 | for img_name, label in labels.items(): 43 | if not os.path.exists(os.path.join(self.root, img_name)): 44 | raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}") 45 | 46 | self.data.append((img_name, label)) 47 | 48 | def merge_dataset(self, ds: AbstractDataset) -> None: 49 | # Update data with new root for self 50 | self.data = [(str(Path(self.root).joinpath(img_path)), label) for img_path, label in self.data] 51 | # Define new root 52 | self.root = Path("/") 53 | # Merge with ds data 54 | for img_path, label in ds.data: 55 | self.data.append((str(Path(ds.root).joinpath(img_path)), label)) 56 | -------------------------------------------------------------------------------- /doctr/file_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import importlib.metadata 7 | import logging 8 | 9 | __all__ = ["requires_package", "CLASS_NAME"] 10 | 11 | CLASS_NAME: str = "words" 12 | ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} 13 | 14 | 15 | def requires_package(name: str, extra_message: str | None = None) -> None: # pragma: no cover 16 | """ 17 | package requirement helper 18 | 19 | Args: 20 | name: name of the package 21 | extra_message: additional message to display if the package is not found 22 | """ 23 | try: 24 | _pkg_version = importlib.metadata.version(name) 25 | logging.info(f"{name} version {_pkg_version} available.") 26 | except importlib.metadata.PackageNotFoundError: 27 | raise ImportError( 28 | f"\n\n{extra_message if extra_message is not None else ''} " 29 | f"\nPlease install it with the following command: pip install {name}\n" 30 | ) 31 | -------------------------------------------------------------------------------- /doctr/io/__init__.py: -------------------------------------------------------------------------------- 1 | from .elements import * 2 | from .html import * 3 | from .image import * 4 | from .pdf import * 5 | from .reader import * 6 | -------------------------------------------------------------------------------- /doctr/io/html.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | from typing import Any 7 | 8 | __all__ = ["read_html"] 9 | 10 | 11 | def read_html(url: str, **kwargs: Any) -> bytes: 12 | """Read a PDF file and convert it into an image in numpy format 13 | 14 | >>> from doctr.io import read_html 15 | >>> doc = read_html("https://www.yoursite.com") 16 | 17 | Args: 18 | url: URL of the target web page 19 | **kwargs: keyword arguments from `weasyprint.HTML` 20 | 21 | Returns: 22 | decoded PDF file as a bytes stream 23 | """ 24 | from weasyprint import HTML 25 | 26 | return HTML(url, **kwargs).write_pdf() 27 | -------------------------------------------------------------------------------- /doctr/io/image/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .pytorch import * 3 | -------------------------------------------------------------------------------- /doctr/io/image/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from pathlib import Path 7 | 8 | import cv2 9 | import numpy as np 10 | 11 | from doctr.utils.common_types import AbstractFile 12 | 13 | __all__ = ["read_img_as_numpy"] 14 | 15 | 16 | def read_img_as_numpy( 17 | file: AbstractFile, 18 | output_size: tuple[int, int] | None = None, 19 | rgb_output: bool = True, 20 | ) -> np.ndarray: 21 | """Read an image file into numpy format 22 | 23 | >>> from doctr.io import read_img_as_numpy 24 | >>> page = read_img_as_numpy("path/to/your/doc.jpg") 25 | 26 | Args: 27 | file: the path to the image file 28 | output_size: the expected output size of each page in format H x W 29 | rgb_output: whether the output ndarray channel order should be RGB instead of BGR. 30 | 31 | Returns: 32 | the page decoded as numpy ndarray of shape H x W x 3 33 | """ 34 | if isinstance(file, (str, Path)): 35 | if not Path(file).is_file(): 36 | raise FileNotFoundError(f"unable to access {file}") 37 | img = cv2.imread(str(file), cv2.IMREAD_COLOR) 38 | elif isinstance(file, bytes): 39 | _file: np.ndarray = np.frombuffer(file, np.uint8) 40 | img = cv2.imdecode(_file, cv2.IMREAD_COLOR) 41 | else: 42 | raise TypeError("unsupported object type for argument 'file'") 43 | 44 | # Validity check 45 | if img is None: 46 | raise ValueError("unable to read file.") 47 | # Resizing 48 | if isinstance(output_size, tuple): 49 | img = cv2.resize(img, output_size[::-1], interpolation=cv2.INTER_LINEAR) 50 | # Switch the channel order 51 | if rgb_output: 52 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 53 | return img 54 | -------------------------------------------------------------------------------- /doctr/io/image/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
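# Hedged usage sketch for the helpers defined below: read an image file either as a
# float32 CHW tensor scaled to [0, 1] or as a raw uint8 tensor. The file path is
# illustrative.
import torch

from doctr.io.image import read_img_as_tensor

img_f32 = read_img_as_tensor("sample.jpg", dtype=torch.float32)  # (3, H, W), values in [0, 1]
img_u8 = read_img_as_tensor("sample.jpg", dtype=torch.uint8)     # (3, H, W), values in [0, 255]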
5 | 6 | from io import BytesIO 7 | 8 | import numpy as np 9 | import torch 10 | from PIL import Image 11 | from torchvision.transforms.functional import to_tensor 12 | 13 | from doctr.utils.common_types import AbstractPath 14 | 15 | __all__ = ["tensor_from_pil", "read_img_as_tensor", "decode_img_as_tensor", "tensor_from_numpy", "get_img_shape"] 16 | 17 | 18 | def tensor_from_pil(pil_img: Image.Image, dtype: torch.dtype = torch.float32) -> torch.Tensor: 19 | """Convert a PIL Image to a PyTorch tensor 20 | 21 | Args: 22 | pil_img: a PIL image 23 | dtype: the output tensor data type 24 | 25 | Returns: 26 | decoded image as tensor 27 | """ 28 | if dtype == torch.float32: 29 | img = to_tensor(pil_img) 30 | else: 31 | img = tensor_from_numpy(np.array(pil_img, np.uint8, copy=True), dtype) 32 | 33 | return img 34 | 35 | 36 | def read_img_as_tensor(img_path: AbstractPath, dtype: torch.dtype = torch.float32) -> torch.Tensor: 37 | """Read an image file as a PyTorch tensor 38 | 39 | Args: 40 | img_path: location of the image file 41 | dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. 42 | 43 | Returns: 44 | decoded image as a tensor 45 | """ 46 | if dtype not in (torch.uint8, torch.float16, torch.float32): 47 | raise ValueError("insupported value for dtype") 48 | 49 | with Image.open(img_path, mode="r") as pil_img: 50 | return tensor_from_pil(pil_img.convert("RGB"), dtype) 51 | 52 | 53 | def decode_img_as_tensor(img_content: bytes, dtype: torch.dtype = torch.float32) -> torch.Tensor: 54 | """Read a byte stream as a PyTorch tensor 55 | 56 | Args: 57 | img_content: bytes of a decoded image 58 | dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. 59 | 60 | Returns: 61 | decoded image as a tensor 62 | """ 63 | if dtype not in (torch.uint8, torch.float16, torch.float32): 64 | raise ValueError("insupported value for dtype") 65 | 66 | with Image.open(BytesIO(img_content), mode="r") as pil_img: 67 | return tensor_from_pil(pil_img.convert("RGB"), dtype) 68 | 69 | 70 | def tensor_from_numpy(npy_img: np.ndarray, dtype: torch.dtype = torch.float32) -> torch.Tensor: 71 | """Read an image file as a PyTorch tensor 72 | 73 | Args: 74 | npy_img: image encoded as a numpy array of shape (H, W, C) in np.uint8 75 | dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. 76 | 77 | Returns: 78 | same image as a tensor of shape (C, H, W) 79 | """ 80 | if dtype not in (torch.uint8, torch.float16, torch.float32): 81 | raise ValueError("insupported value for dtype") 82 | 83 | if dtype == torch.float32: 84 | img = to_tensor(npy_img) 85 | else: 86 | img = torch.from_numpy(npy_img) 87 | # put it from HWC to CHW format 88 | img = img.permute((2, 0, 1)).contiguous() 89 | if dtype == torch.float16: 90 | # Switch to FP16 91 | img = img.to(dtype=torch.float16).div(255) 92 | 93 | return img 94 | 95 | 96 | def get_img_shape(img: torch.Tensor) -> tuple[int, int]: 97 | """Get the shape of an image""" 98 | return img.shape[-2:] # type: ignore[return-value] 99 | -------------------------------------------------------------------------------- /doctr/io/pdf.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
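# Hedged usage sketch for `read_pdf` defined below: pages are rasterised with
# pypdfium2, so scale=2 renders at roughly 144 dpi; the path and password are
# illustrative.
from doctr.io import read_pdf

pages = read_pdf("statement.pdf", scale=2, password="secret")
print(len(pages), pages[0].shape)  # number of pages, (H, W, 3)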
5 | 6 | from typing import Any 7 | 8 | import numpy as np 9 | import pypdfium2 as pdfium 10 | 11 | from doctr.utils.common_types import AbstractFile 12 | 13 | __all__ = ["read_pdf"] 14 | 15 | 16 | def read_pdf( 17 | file: AbstractFile, 18 | scale: int = 2, 19 | rgb_mode: bool = True, 20 | password: str | None = None, 21 | **kwargs: Any, 22 | ) -> list[np.ndarray]: 23 | """Read a PDF file and convert it into an image in numpy format 24 | 25 | >>> from doctr.io import read_pdf 26 | >>> doc = read_pdf("path/to/your/doc.pdf") 27 | 28 | Args: 29 | file: the path to the PDF file 30 | scale: rendering scale (1 corresponds to 72dpi) 31 | rgb_mode: if True, the output will be RGB, otherwise BGR 32 | password: a password to unlock the document, if encrypted 33 | **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render` 34 | 35 | Returns: 36 | the list of pages decoded as numpy ndarray of shape H x W x C 37 | """ 38 | # Rasterise pages to numpy ndarrays with pypdfium2 39 | pdf = pdfium.PdfDocument(file, password=password) 40 | try: 41 | return [page.render(scale=scale, rev_byteorder=rgb_mode, **kwargs).to_numpy() for page in pdf] 42 | finally: 43 | pdf.close() 44 | -------------------------------------------------------------------------------- /doctr/io/reader.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from collections.abc import Sequence 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | 11 | from doctr.file_utils import requires_package 12 | from doctr.utils.common_types import AbstractFile 13 | 14 | from .html import read_html 15 | from .image import read_img_as_numpy 16 | from .pdf import read_pdf 17 | 18 | __all__ = ["DocumentFile"] 19 | 20 | 21 | class DocumentFile: 22 | """Read a document from multiple extensions""" 23 | 24 | @classmethod 25 | def from_pdf(cls, file: AbstractFile, **kwargs) -> list[np.ndarray]: 26 | """Read a PDF file 27 | 28 | >>> from doctr.io import DocumentFile 29 | >>> doc = DocumentFile.from_pdf("path/to/your/doc.pdf") 30 | 31 | Args: 32 | file: the path to the PDF file or a binary stream 33 | **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render` 34 | 35 | Returns: 36 | the list of pages decoded as numpy ndarray of shape H x W x 3 37 | """ 38 | return read_pdf(file, **kwargs) 39 | 40 | @classmethod 41 | def from_url(cls, url: str, **kwargs) -> list[np.ndarray]: 42 | """Interpret a web page as a PDF document 43 | 44 | >>> from doctr.io import DocumentFile 45 | >>> doc = DocumentFile.from_url("https://www.yoursite.com") 46 | 47 | Args: 48 | url: the URL of the target web page 49 | **kwargs: additional parameters to :meth:`pypdfium2.PdfPage.render` 50 | 51 | Returns: 52 | the list of pages decoded as numpy ndarray of shape H x W x 3 53 | """ 54 | requires_package( 55 | "weasyprint", 56 | "`.from_url` requires weasyprint installed.\n" 57 | + "Installation instructions: https://doc.courtbouillon.org/weasyprint/stable/first_steps.html#installation", 58 | ) 59 | pdf_stream = read_html(url) 60 | return cls.from_pdf(pdf_stream, **kwargs) 61 | 62 | @classmethod 63 | def from_images(cls, files: Sequence[AbstractFile] | AbstractFile, **kwargs) -> list[np.ndarray]: 64 | """Read an image file (or a collection of image files) and convert it into an image in numpy format 65 | 66 | >>> from doctr.io import DocumentFile 67 | >>> pages = 
DocumentFile.from_images(["path/to/your/page1.png", "path/to/your/page2.png"]) 68 | 69 | Args: 70 | files: the path to the image file or a binary stream, or a collection of those 71 | **kwargs: additional parameters to :meth:`doctr.io.image.read_img_as_numpy` 72 | 73 | Returns: 74 | the list of pages decoded as numpy ndarray of shape H x W x 3 75 | """ 76 | if isinstance(files, (str, Path, bytes)): 77 | files = [files] 78 | 79 | return [read_img_as_numpy(file, **kwargs) for file in files] 80 | -------------------------------------------------------------------------------- /doctr/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .classification import * 2 | from .detection import * 3 | from .recognition import * 4 | from .zoo import * 5 | from .factory import * 6 | -------------------------------------------------------------------------------- /doctr/models/classification/__init__.py: -------------------------------------------------------------------------------- 1 | from .mobilenet import * 2 | from .resnet import * 3 | from .vgg import * 4 | from .magc_resnet import * 5 | from .vit import * 6 | from .textnet import * 7 | from .vip import * 8 | from .zoo import * 9 | -------------------------------------------------------------------------------- /doctr/models/classification/magc_resnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/classification/mobilenet/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/classification/predictor/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/classification/predictor/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from doctr.models.preprocessor import PreProcessor 12 | from doctr.models.utils import set_device_and_dtype 13 | 14 | __all__ = ["OrientationPredictor"] 15 | 16 | 17 | class OrientationPredictor(nn.Module): 18 | """Implements an object able to detect the reading direction of a text box or a page. 19 | 4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise. 
20 | 21 | Args: 22 | pre_processor: transform inputs for easier batched model inference 23 | model: core classification architecture (backbone + classification head) 24 | """ 25 | 26 | def __init__( 27 | self, 28 | pre_processor: PreProcessor | None, 29 | model: nn.Module | None, 30 | ) -> None: 31 | super().__init__() 32 | self.pre_processor = pre_processor if isinstance(pre_processor, PreProcessor) else None 33 | self.model = model.eval() if isinstance(model, nn.Module) else None 34 | 35 | @torch.inference_mode() 36 | def forward( 37 | self, 38 | inputs: list[np.ndarray], 39 | ) -> list[list[int] | list[float]]: 40 | # Dimension check 41 | if any(input.ndim != 3 for input in inputs): 42 | raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.") 43 | 44 | if self.model is None or self.pre_processor is None: 45 | # predictor is disabled 46 | return [[0] * len(inputs), [0] * len(inputs), [1.0] * len(inputs)] 47 | 48 | processed_batches = self.pre_processor(inputs) 49 | _params = next(self.model.parameters()) 50 | self.model, processed_batches = set_device_and_dtype( 51 | self.model, processed_batches, _params.device, _params.dtype 52 | ) 53 | predicted_batches = [self.model(batch) for batch in processed_batches] 54 | # confidence 55 | probs = [ 56 | torch.max(torch.softmax(batch, dim=1), dim=1).values.cpu().detach().numpy() for batch in predicted_batches 57 | ] 58 | # Postprocess predictions 59 | predicted_batches = [out_batch.argmax(dim=1).cpu().detach().numpy() for out_batch in predicted_batches] 60 | 61 | class_idxs = [int(pred) for batch in predicted_batches for pred in batch] 62 | classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs] # type: ignore 63 | confs = [round(float(p), 2) for prob in probs for p in prob] 64 | 65 | return [class_idxs, classes, confs] 66 | -------------------------------------------------------------------------------- /doctr/models/classification/resnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/classification/textnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/classification/vgg/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/classification/vip/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/classification/vip/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/classification/vit/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/core.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 
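# Hedged usage sketch for the OrientationPredictor above: it returns three parallel
# lists (class indices, angles in degrees, confidences), one entry per input image.
# The zoo helper, pretrained flag and dummy crop are assumptions for illustration.
import numpy as np

from doctr.models.classification import crop_orientation_predictor

predictor = crop_orientation_predictor(pretrained=True)
crop = (255 * np.random.rand(64, 256, 3)).astype(np.uint8)
class_idxs, angles, confs = predictor([crop])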
2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | 7 | from typing import Any 8 | 9 | from doctr.utils.repr import NestedObject 10 | 11 | __all__ = ["BaseModel"] 12 | 13 | 14 | class BaseModel(NestedObject): 15 | """Implements abstract DetectionModel class""" 16 | 17 | def __init__(self, cfg: dict[str, Any] | None = None) -> None: 18 | super().__init__() 19 | self.cfg = cfg 20 | -------------------------------------------------------------------------------- /doctr/models/detection/__init__.py: -------------------------------------------------------------------------------- 1 | from .differentiable_binarization import * 2 | from .linknet import * 3 | from .fast import * 4 | from .zoo import * 5 | -------------------------------------------------------------------------------- /doctr/models/detection/_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .pytorch import * 3 | -------------------------------------------------------------------------------- /doctr/models/detection/_utils/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | 7 | import numpy as np 8 | 9 | __all__ = ["_remove_padding"] 10 | 11 | 12 | def _remove_padding( 13 | pages: list[np.ndarray], 14 | loc_preds: list[dict[str, np.ndarray]], 15 | preserve_aspect_ratio: bool, 16 | symmetric_pad: bool, 17 | assume_straight_pages: bool, 18 | ) -> list[dict[str, np.ndarray]]: 19 | """Remove padding from the localization predictions 20 | 21 | Args: 22 | pages: list of pages 23 | loc_preds: list of localization predictions 24 | preserve_aspect_ratio: whether the aspect ratio was preserved during padding 25 | symmetric_pad: whether the padding was symmetric 26 | assume_straight_pages: whether the pages are assumed to be straight 27 | 28 | Returns: 29 | list of unpaded localization predictions 30 | """ 31 | if preserve_aspect_ratio: 32 | # Rectify loc_preds to remove padding 33 | rectified_preds = [] 34 | for page, dict_loc_preds in zip(pages, loc_preds): 35 | for k, loc_pred in dict_loc_preds.items(): 36 | h, w = page.shape[0], page.shape[1] 37 | if h > w: 38 | # y unchanged, dilate x coord 39 | if symmetric_pad: 40 | if assume_straight_pages: 41 | loc_pred[:, [0, 2]] = (loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5 42 | else: 43 | loc_pred[:, :, 0] = (loc_pred[:, :, 0] - 0.5) * h / w + 0.5 44 | else: 45 | if assume_straight_pages: 46 | loc_pred[:, [0, 2]] *= h / w 47 | else: 48 | loc_pred[:, :, 0] *= h / w 49 | elif w > h: 50 | # x unchanged, dilate y coord 51 | if symmetric_pad: 52 | if assume_straight_pages: 53 | loc_pred[:, [1, 3]] = (loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5 54 | else: 55 | loc_pred[:, :, 1] = (loc_pred[:, :, 1] - 0.5) * w / h + 0.5 56 | else: 57 | if assume_straight_pages: 58 | loc_pred[:, [1, 3]] *= w / h 59 | else: 60 | loc_pred[:, :, 1] *= w / h 61 | rectified_preds.append({k: np.clip(loc_pred, 0, 1)}) 62 | return rectified_preds 63 | return loc_preds 64 | -------------------------------------------------------------------------------- /doctr/models/detection/_utils/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 
4 | # See LICENSE or go to for full license details. 5 | 6 | from torch import Tensor 7 | from torch.nn.functional import max_pool2d 8 | 9 | __all__ = ["erode", "dilate"] 10 | 11 | 12 | def erode(x: Tensor, kernel_size: int) -> Tensor: 13 | """Performs erosion on a given tensor 14 | 15 | Args: 16 | x: boolean tensor of shape (N, C, H, W) 17 | kernel_size: the size of the kernel to use for erosion 18 | 19 | Returns: 20 | the eroded tensor 21 | """ 22 | _pad = (kernel_size - 1) // 2 23 | 24 | return 1 - max_pool2d(1 - x, kernel_size, stride=1, padding=_pad) 25 | 26 | 27 | def dilate(x: Tensor, kernel_size: int) -> Tensor: 28 | """Performs dilation on a given tensor 29 | 30 | Args: 31 | x: boolean tensor of shape (N, C, H, W) 32 | kernel_size: the size of the kernel to use for dilation 33 | 34 | Returns: 35 | the dilated tensor 36 | """ 37 | _pad = (kernel_size - 1) // 2 38 | 39 | return max_pool2d(x, kernel_size, stride=1, padding=_pad) 40 | -------------------------------------------------------------------------------- /doctr/models/detection/differentiable_binarization/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/detection/fast/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/detection/linknet/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/detection/predictor/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/detection/predictor/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
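# Hedged sketch for the morphology helpers above: erosion shrinks foreground regions
# and dilation grows them, both implemented with max pooling on an (N, C, H, W) mask.
# The toy mask is illustrative.
import torch

from doctr.models.detection._utils import dilate, erode

mask = torch.zeros(1, 1, 5, 5)
mask[0, 0, 1:4, 1:4] = 1.0           # 3x3 square of foreground
print(erode(mask, 3).sum().item())   # 1.0  -> only the centre pixel survives
print(dilate(mask, 3).sum().item())  # 25.0 -> the square grows to cover the map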
5 | 6 | from typing import Any 7 | 8 | import numpy as np 9 | import torch 10 | from torch import nn 11 | 12 | from doctr.models.detection._utils import _remove_padding 13 | from doctr.models.preprocessor import PreProcessor 14 | from doctr.models.utils import set_device_and_dtype 15 | 16 | __all__ = ["DetectionPredictor"] 17 | 18 | 19 | class DetectionPredictor(nn.Module): 20 | """Implements an object able to localize text elements in a document 21 | 22 | Args: 23 | pre_processor: transform inputs for easier batched model inference 24 | model: core detection architecture 25 | """ 26 | 27 | def __init__( 28 | self, 29 | pre_processor: PreProcessor, 30 | model: nn.Module, 31 | ) -> None: 32 | super().__init__() 33 | self.pre_processor = pre_processor 34 | self.model = model.eval() 35 | 36 | @torch.inference_mode() 37 | def forward( 38 | self, 39 | pages: list[np.ndarray], 40 | return_maps: bool = False, 41 | **kwargs: Any, 42 | ) -> list[dict[str, np.ndarray]] | tuple[list[dict[str, np.ndarray]], list[np.ndarray]]: 43 | # Extract parameters from the preprocessor 44 | preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio 45 | symmetric_pad = self.pre_processor.resize.symmetric_pad 46 | assume_straight_pages = self.model.assume_straight_pages 47 | 48 | # Dimension check 49 | if any(page.ndim != 3 for page in pages): 50 | raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") 51 | 52 | processed_batches = self.pre_processor(pages) 53 | _params = next(self.model.parameters()) 54 | self.model, processed_batches = set_device_and_dtype( 55 | self.model, processed_batches, _params.device, _params.dtype 56 | ) 57 | predicted_batches = [ 58 | self.model(batch, return_preds=True, return_model_output=True, **kwargs) for batch in processed_batches 59 | ] 60 | # Remove padding from loc predictions 61 | preds = _remove_padding( 62 | pages, 63 | [pred for batch in predicted_batches for pred in batch["preds"]], 64 | preserve_aspect_ratio=preserve_aspect_ratio, 65 | symmetric_pad=symmetric_pad, 66 | assume_straight_pages=assume_straight_pages, # type: ignore[arg-type] 67 | ) 68 | 69 | if return_maps: 70 | seg_maps = [ 71 | pred.permute(1, 2, 0).detach().cpu().numpy() for batch in predicted_batches for pred in batch["out_map"] 72 | ] 73 | return preds, seg_maps 74 | return preds 75 | -------------------------------------------------------------------------------- /doctr/models/factory/__init__.py: -------------------------------------------------------------------------------- 1 | from .hub import * 2 | -------------------------------------------------------------------------------- /doctr/models/kie_predictor/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/kie_predictor/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
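# Hedged usage sketch for the DetectionPredictor above, built through the detection
# zoo helper; the architecture name, random page and output layout are assumptions
# for illustration (straight boxes come back as (num_boxes, 5): xmin, ymin, xmax,
# ymax, score, under the "words" class).
import numpy as np

from doctr.models import detection_predictor

det_predictor = detection_predictor("db_resnet50", pretrained=True)
page = (255 * np.random.rand(1024, 768, 3)).astype(np.uint8)
out = det_predictor([page])       # one dict of box arrays per page
print(out[0]["words"].shape)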
5 | 6 | from typing import Any 7 | 8 | from doctr.models.builder import KIEDocumentBuilder 9 | 10 | from ..classification.predictor import OrientationPredictor 11 | from ..predictor.base import _OCRPredictor 12 | 13 | __all__ = ["_KIEPredictor"] 14 | 15 | 16 | class _KIEPredictor(_OCRPredictor): 17 | """Implements an object able to localize and identify text elements in a set of documents 18 | 19 | Args: 20 | assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages 21 | without rotated textual elements. 22 | straighten_pages: if True, estimates the page general orientation based on the median line orientation. 23 | Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped 24 | accordingly. Doing so will improve performances for documents with page-uniform rotations. 25 | preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding) 26 | symmetric_pad: if True and preserve_aspect_ratio is True, pas the image symmetrically. 27 | detect_orientation: if True, the estimated general page orientation will be added to the predictions for each 28 | page. Doing so will slightly deteriorate the overall latency. 29 | kwargs: keyword args of `DocumentBuilder` 30 | """ 31 | 32 | crop_orientation_predictor: OrientationPredictor | None 33 | page_orientation_predictor: OrientationPredictor | None 34 | 35 | def __init__( 36 | self, 37 | assume_straight_pages: bool = True, 38 | straighten_pages: bool = False, 39 | preserve_aspect_ratio: bool = True, 40 | symmetric_pad: bool = True, 41 | detect_orientation: bool = False, 42 | **kwargs: Any, 43 | ) -> None: 44 | super().__init__( 45 | assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, detect_orientation, **kwargs 46 | ) 47 | 48 | # Remove the following arguments from kwargs after initialization of the parent class 49 | kwargs.pop("disable_page_orientation", None) 50 | kwargs.pop("disable_crop_orientation", None) 51 | 52 | self.doc_builder: KIEDocumentBuilder = KIEDocumentBuilder(**kwargs) 53 | -------------------------------------------------------------------------------- /doctr/models/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers import * 2 | from .transformer import * 3 | from .vision_transformer import * 4 | -------------------------------------------------------------------------------- /doctr/models/modules/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/modules/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/modules/vision_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/predictor/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/preprocessor/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/recognition/__init__.py: -------------------------------------------------------------------------------- 1 | from .crnn import * 2 | from .master import * 3 | from .sar import * 4 | from .vitstr import * 5 | from .parseq import * 6 | from .viptr import * 7 | from .zoo import * 8 | -------------------------------------------------------------------------------- /doctr/models/recognition/core.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | 7 | import numpy as np 8 | 9 | from doctr.datasets import encode_sequences 10 | from doctr.utils.repr import NestedObject 11 | 12 | __all__ = ["RecognitionPostProcessor", "RecognitionModel"] 13 | 14 | 15 | class RecognitionModel(NestedObject): 16 | """Implements abstract RecognitionModel class""" 17 | 18 | vocab: str 19 | max_length: int 20 | 21 | def build_target( 22 | self, 23 | gts: list[str], 24 | ) -> tuple[np.ndarray, list[int]]: 25 | """Encode a list of gts sequences into a np array and gives the corresponding* 26 | sequence lengths. 27 | 28 | Args: 29 | gts: list of ground-truth labels 30 | 31 | Returns: 32 | A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch) 33 | """ 34 | encoded = encode_sequences(sequences=gts, vocab=self.vocab, target_size=self.max_length, eos=len(self.vocab)) 35 | seq_len = [len(word) for word in gts] 36 | return encoded, seq_len 37 | 38 | 39 | class RecognitionPostProcessor(NestedObject): 40 | """Abstract class to postprocess the raw output of the model 41 | 42 | Args: 43 | vocab: string containing the ordered sequence of supported characters 44 | """ 45 | 46 | def __init__( 47 | self, 48 | vocab: str, 49 | ) -> None: 50 | self.vocab = vocab 51 | self._embedding = list(self.vocab) + [""] 52 | 53 | def extra_repr(self) -> str: 54 | return f"vocab_size={len(self.vocab)}" 55 | -------------------------------------------------------------------------------- /doctr/models/recognition/crnn/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/recognition/master/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/recognition/master/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | 7 | import numpy as np 8 | 9 | from ....datasets import encode_sequences 10 | from ..core import RecognitionPostProcessor 11 | 12 | 13 | class _MASTER: 14 | vocab: str 15 | max_length: int 16 | 17 | def build_target( 18 | self, 19 | gts: list[str], 20 | ) -> tuple[np.ndarray, list[int]]: 21 | """Encode a list of gts sequences into a np array and gives the corresponding* 22 | sequence lengths. 
23 | 24 | Args: 25 | gts: list of ground-truth labels 26 | 27 | Returns: 28 | A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch) 29 | """ 30 | encoded = encode_sequences( 31 | sequences=gts, 32 | vocab=self.vocab, 33 | target_size=self.max_length, 34 | eos=len(self.vocab), 35 | sos=len(self.vocab) + 1, 36 | pad=len(self.vocab) + 2, 37 | ) 38 | seq_len = [len(word) for word in gts] 39 | return encoded, seq_len 40 | 41 | 42 | class _MASTERPostProcessor(RecognitionPostProcessor): 43 | """Abstract class to postprocess the raw output of the model 44 | 45 | Args: 46 | vocab: string containing the ordered sequence of supported characters 47 | """ 48 | 49 | def __init__( 50 | self, 51 | vocab: str, 52 | ) -> None: 53 | super().__init__(vocab) 54 | self._embedding = list(vocab) + [""] + [""] + [""] 55 | -------------------------------------------------------------------------------- /doctr/models/recognition/parseq/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/recognition/parseq/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | 7 | import numpy as np 8 | 9 | from ....datasets import encode_sequences 10 | from ..core import RecognitionPostProcessor 11 | 12 | 13 | class _PARSeq: 14 | vocab: str 15 | max_length: int 16 | 17 | def build_target( 18 | self, 19 | gts: list[str], 20 | ) -> tuple[np.ndarray, list[int]]: 21 | """Encode a list of gts sequences into a np array and gives the corresponding* 22 | sequence lengths. 23 | 24 | Args: 25 | gts: list of ground-truth labels 26 | 27 | Returns: 28 | A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch) 29 | """ 30 | encoded = encode_sequences( 31 | sequences=gts, 32 | vocab=self.vocab, 33 | target_size=self.max_length, 34 | eos=len(self.vocab), 35 | sos=len(self.vocab) + 1, 36 | pad=len(self.vocab) + 2, 37 | ) 38 | seq_len = [len(word) for word in gts] 39 | return encoded, seq_len 40 | 41 | 42 | class _PARSeqPostProcessor(RecognitionPostProcessor): 43 | """Abstract class to postprocess the raw output of the model 44 | 45 | Args: 46 | vocab: string containing the ordered sequence of supported characters 47 | """ 48 | 49 | def __init__( 50 | self, 51 | vocab: str, 52 | ) -> None: 53 | super().__init__(vocab) 54 | self._embedding = list(vocab) + ["", "", ""] 55 | -------------------------------------------------------------------------------- /doctr/models/recognition/predictor/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/recognition/predictor/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
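# Hedged sketch of the target encoding used in the base classes above: characters are
# mapped to vocab indices, wrapped with <sos>/<eos> markers and padded to a fixed
# length. The tiny vocab, word and expected output are illustrative.
from doctr.datasets import encode_sequences

vocab = "abc"
encoded = encode_sequences(
    sequences=["cab"],
    vocab=vocab,
    target_size=6,
    eos=len(vocab),      # 3
    sos=len(vocab) + 1,  # 4
    pad=len(vocab) + 2,  # 5
)
print(encoded[0])  # e.g. [4, 2, 0, 1, 3, 5] -> <sos> c a b <eos> <pad>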
5 | 6 | from collections.abc import Sequence 7 | from typing import Any 8 | 9 | import numpy as np 10 | import torch 11 | from torch import nn 12 | 13 | from doctr.models.preprocessor import PreProcessor 14 | from doctr.models.utils import set_device_and_dtype 15 | 16 | from ._utils import remap_preds, split_crops 17 | 18 | __all__ = ["RecognitionPredictor"] 19 | 20 | 21 | class RecognitionPredictor(nn.Module): 22 | """Implements an object able to identify character sequences in images 23 | 24 | Args: 25 | pre_processor: transform inputs for easier batched model inference 26 | model: core detection architecture 27 | split_wide_crops: wether to use crop splitting for high aspect ratio crops 28 | """ 29 | 30 | def __init__( 31 | self, 32 | pre_processor: PreProcessor, 33 | model: nn.Module, 34 | split_wide_crops: bool = True, 35 | ) -> None: 36 | super().__init__() 37 | self.pre_processor = pre_processor 38 | self.model = model.eval() 39 | self.split_wide_crops = split_wide_crops 40 | self.critical_ar = 8 # Critical aspect ratio 41 | self.overlap_ratio = 0.5 # Ratio of overlap between neighboring crops 42 | self.target_ar = 6 # Target aspect ratio 43 | 44 | @torch.inference_mode() 45 | def forward( 46 | self, 47 | crops: Sequence[np.ndarray], 48 | **kwargs: Any, 49 | ) -> list[tuple[str, float]]: 50 | if len(crops) == 0: 51 | return [] 52 | # Dimension check 53 | if any(crop.ndim != 3 for crop in crops): 54 | raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.") 55 | 56 | # Split crops that are too wide 57 | remapped = False 58 | if self.split_wide_crops: 59 | new_crops, crop_map, remapped = split_crops( 60 | crops, # type: ignore[arg-type] 61 | self.critical_ar, 62 | self.target_ar, 63 | self.overlap_ratio, 64 | ) 65 | if remapped: 66 | crops = new_crops 67 | 68 | # Resize & batch them 69 | processed_batches = self.pre_processor(crops) # type: ignore[arg-type] 70 | 71 | # Forward it 72 | _params = next(self.model.parameters()) 73 | self.model, processed_batches = set_device_and_dtype( 74 | self.model, processed_batches, _params.device, _params.dtype 75 | ) 76 | raw = [self.model(batch, return_preds=True, **kwargs)["preds"] for batch in processed_batches] 77 | 78 | # Process outputs 79 | out = [charseq for batch in raw for charseq in batch] 80 | 81 | # Remap crops 82 | if self.split_wide_crops and remapped: 83 | out = remap_preds(out, crop_map, self.overlap_ratio) 84 | 85 | return out 86 | -------------------------------------------------------------------------------- /doctr/models/recognition/sar/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/recognition/viptr/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/recognition/vitstr/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/models/recognition/vitstr/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
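# Hedged sketch of the wide-crop splitting used by RecognitionPredictor above: crops
# whose width/height ratio exceeds `critical_ar` (8) are cut into overlapping
# sub-crops close to `target_ar` (6) before recognition, and the partial predictions
# are merged back afterwards. The dummy crop is illustrative.
import numpy as np

crop = np.zeros((32, 400, 3), dtype=np.uint8)  # aspect ratio 400 / 32 = 12.5 > 8
aspect_ratio = crop.shape[1] / crop.shape[0]
needs_split = aspect_ratio > 8                 # True -> split_crops would be applied
print(aspect_ratio, needs_split)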
5 | 6 | 7 | import numpy as np 8 | 9 | from ....datasets import encode_sequences 10 | from ..core import RecognitionPostProcessor 11 | 12 | 13 | class _ViTSTR: 14 | vocab: str 15 | max_length: int 16 | 17 | def build_target( 18 | self, 19 | gts: list[str], 20 | ) -> tuple[np.ndarray, list[int]]: 21 | """Encode a list of gts sequences into a np array and gives the corresponding* 22 | sequence lengths. 23 | 24 | Args: 25 | gts: list of ground-truth labels 26 | 27 | Returns: 28 | A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch) 29 | """ 30 | encoded = encode_sequences( 31 | sequences=gts, 32 | vocab=self.vocab, 33 | target_size=self.max_length, 34 | eos=len(self.vocab), 35 | sos=len(self.vocab) + 1, 36 | ) 37 | seq_len = [len(word) for word in gts] 38 | return encoded, seq_len 39 | 40 | 41 | class _ViTSTRPostProcessor(RecognitionPostProcessor): 42 | """Abstract class to postprocess the raw output of the model 43 | 44 | Args: 45 | vocab: string containing the ordered sequence of supported characters 46 | """ 47 | 48 | def __init__( 49 | self, 50 | vocab: str, 51 | ) -> None: 52 | super().__init__(vocab) 53 | self._embedding = list(vocab) + ["", ""] 54 | -------------------------------------------------------------------------------- /doctr/models/recognition/zoo.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import Any 7 | 8 | from doctr.models.preprocessor import PreProcessor 9 | from doctr.models.utils import _CompiledModule 10 | 11 | from .. import recognition 12 | from .predictor import RecognitionPredictor 13 | 14 | __all__ = ["recognition_predictor"] 15 | 16 | 17 | ARCHS: list[str] = [ 18 | "crnn_vgg16_bn", 19 | "crnn_mobilenet_v3_small", 20 | "crnn_mobilenet_v3_large", 21 | "sar_resnet31", 22 | "master", 23 | "vitstr_small", 24 | "vitstr_base", 25 | "parseq", 26 | "viptr_tiny", 27 | ] 28 | 29 | 30 | def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredictor: 31 | if isinstance(arch, str): 32 | if arch not in ARCHS: 33 | raise ValueError(f"unknown architecture '{arch}'") 34 | 35 | _model = recognition.__dict__[arch]( 36 | pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True) 37 | ) 38 | else: 39 | # Adding the type for torch compiled models to the allowed architectures 40 | allowed_archs = [ 41 | recognition.CRNN, 42 | recognition.SAR, 43 | recognition.MASTER, 44 | recognition.ViTSTR, 45 | recognition.PARSeq, 46 | recognition.VIPTR, 47 | _CompiledModule, 48 | ] 49 | 50 | if not isinstance(arch, tuple(allowed_archs)): 51 | raise ValueError(f"unknown architecture: {type(arch)}") 52 | _model = arch 53 | 54 | kwargs.pop("pretrained_backbone", None) 55 | 56 | kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) 57 | kwargs["std"] = kwargs.get("std", _model.cfg["std"]) 58 | kwargs["batch_size"] = kwargs.get("batch_size", 128) 59 | input_shape = _model.cfg["input_shape"][-2:] 60 | predictor = RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs), _model) 61 | 62 | return predictor 63 | 64 | 65 | def recognition_predictor( 66 | arch: Any = "crnn_vgg16_bn", 67 | pretrained: bool = False, 68 | symmetric_pad: bool = False, 69 | batch_size: int = 128, 70 | **kwargs: Any, 71 | ) -> RecognitionPredictor: 72 | """Text recognition architecture. 
73 | 74 | Example:: 75 | >>> import numpy as np 76 | >>> from doctr.models import recognition_predictor 77 | >>> model = recognition_predictor(pretrained=True) 78 | >>> input_page = (255 * np.random.rand(32, 128, 3)).astype(np.uint8) 79 | >>> out = model([input_page]) 80 | 81 | Args: 82 | arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn') 83 | pretrained: If True, returns a model pre-trained on our text recognition dataset 84 | symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right 85 | batch_size: number of samples the model processes in parallel 86 | **kwargs: optional parameters to be passed to the architecture 87 | 88 | Returns: 89 | Recognition predictor 90 | """ 91 | return _predictor(arch=arch, pretrained=pretrained, symmetric_pad=symmetric_pad, batch_size=batch_size, **kwargs) 92 | -------------------------------------------------------------------------------- /doctr/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mindee/doctr/849d19566b03cd37c8a831ae02673a1c2f265399/doctr/py.typed -------------------------------------------------------------------------------- /doctr/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | -------------------------------------------------------------------------------- /doctr/transforms/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import * 2 | -------------------------------------------------------------------------------- /doctr/transforms/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .pytorch import * 3 | -------------------------------------------------------------------------------- /doctr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .common_types import * 2 | from .data import * 3 | from .geometry import * 4 | from .metrics import * 5 | -------------------------------------------------------------------------------- /doctr/utils/common_types.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from pathlib import Path 7 | 8 | __all__ = ["Point2D", "BoundingBox", "Polygon4P", "Polygon", "Bbox"] 9 | 10 | 11 | Point2D = tuple[float, float] 12 | BoundingBox = tuple[Point2D, Point2D] 13 | Polygon4P = tuple[Point2D, Point2D, Point2D, Point2D] 14 | Polygon = list[Point2D] 15 | AbstractPath = str | Path 16 | AbstractFile = AbstractPath | bytes 17 | Bbox = tuple[float, float, float, float] 18 | -------------------------------------------------------------------------------- /doctr/utils/fonts.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
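# Hedged sketch for the recognition zoo above: the predictor can be built either from
# an architecture name listed in ARCHS or from an already-instantiated recognition
# model. The architecture choice and batch size are illustrative.
from doctr.models import recognition, recognition_predictor

model = recognition.crnn_vgg16_bn(pretrained=True)
reco = recognition_predictor(arch=model, batch_size=64)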
5 | 6 | import logging 7 | import platform 8 | 9 | from PIL import ImageFont 10 | 11 | __all__ = ["get_font"] 12 | 13 | 14 | def get_font(font_family: str | None = None, font_size: int = 13) -> ImageFont.FreeTypeFont | ImageFont.ImageFont: 15 | """Resolves a compatible ImageFont for the system 16 | 17 | Args: 18 | font_family: the font family to use 19 | font_size: the size of the font upon rendering 20 | 21 | Returns: 22 | the Pillow font 23 | """ 24 | # Font selection 25 | if font_family is None: 26 | try: 27 | font = ImageFont.truetype("FreeMono.ttf" if platform.system() == "Linux" else "Arial.ttf", font_size) 28 | except OSError: # pragma: no cover 29 | font = ImageFont.load_default() # type: ignore[assignment] 30 | logging.warning( 31 | "unable to load recommended font family. Loading default PIL font," 32 | "font size issues may be expected." 33 | "To prevent this, it is recommended to specify the value of 'font_family'." 34 | ) 35 | else: # pragma: no cover 36 | font = ImageFont.truetype(font_family, font_size) 37 | 38 | return font 39 | -------------------------------------------------------------------------------- /doctr/utils/multithreading.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | 7 | import multiprocessing as mp 8 | import os 9 | from collections.abc import Callable, Iterable, Iterator 10 | from multiprocessing.pool import ThreadPool 11 | from typing import Any 12 | 13 | from doctr.file_utils import ENV_VARS_TRUE_VALUES 14 | 15 | __all__ = ["multithread_exec"] 16 | 17 | 18 | def multithread_exec(func: Callable[[Any], Any], seq: Iterable[Any], threads: int | None = None) -> Iterator[Any]: 19 | """Execute a given function in parallel for each element of a given sequence 20 | 21 | >>> from doctr.utils.multithreading import multithread_exec 22 | >>> entries = [1, 4, 8] 23 | >>> results = multithread_exec(lambda x: x ** 2, entries) 24 | 25 | Args: 26 | func: function to be executed on each element of the iterable 27 | seq: iterable 28 | threads: number of workers to be used for multiprocessing 29 | 30 | Returns: 31 | iterator of the function's results using the iterable as inputs 32 | 33 | Notes: 34 | This function uses ThreadPool from multiprocessing package, which uses `/dev/shm` directory for shared memory. 35 | If you do not have write permissions for this directory (if you run `doctr` on AWS Lambda for instance), 36 | you might want to disable multiprocessing. To achieve that, set 'DOCTR_MULTIPROCESSING_DISABLE' to 'TRUE'. 37 | """ 38 | threads = threads if isinstance(threads, int) else min(16, mp.cpu_count()) 39 | # Single-thread 40 | if threads < 2 or os.environ.get("DOCTR_MULTIPROCESSING_DISABLE", "").upper() in ENV_VARS_TRUE_VALUES: 41 | results = map(func, seq) 42 | # Multi-threading 43 | else: 44 | with ThreadPool(threads) as tp: 45 | # ThreadPool's map function returns a list, but seq could be of a different type 46 | # That's why wrapping result in map to return iterator 47 | results = map(lambda x: x, tp.map(func, seq)) # noqa: C417 48 | return results 49 | -------------------------------------------------------------------------------- /doctr/utils/repr.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 
4 | # See LICENSE or go to for full license details. 5 | 6 | # Adapted from https://github.com/pytorch/torch/blob/master/torch/nn/modules/module.py 7 | 8 | 9 | __all__ = ["NestedObject"] 10 | 11 | 12 | def _addindent(s_, num_spaces): 13 | s = s_.split("\n") 14 | # don't do anything for single-line stuff 15 | if len(s) == 1: 16 | return s_ 17 | first = s.pop(0) 18 | s = [(num_spaces * " ") + line for line in s] 19 | s = "\n".join(s) 20 | s = first + "\n" + s 21 | return s 22 | 23 | 24 | class NestedObject: 25 | """Base class for all nested objects in doctr""" 26 | 27 | _children_names: list[str] 28 | 29 | def extra_repr(self) -> str: 30 | return "" 31 | 32 | def __repr__(self): 33 | # We treat the extra repr like the sub-object, one item per line 34 | extra_lines = [] 35 | extra_repr = self.extra_repr() 36 | # empty string will be split into list [''] 37 | if extra_repr: 38 | extra_lines = extra_repr.split("\n") 39 | child_lines = [] 40 | if hasattr(self, "_children_names"): 41 | for key in self._children_names: 42 | child = getattr(self, key) 43 | if isinstance(child, list) and len(child) > 0: 44 | child_str = ",\n".join([repr(subchild) for subchild in child]) 45 | if len(child) > 1: 46 | child_str = _addindent(f"\n{child_str},", 2) + "\n" 47 | child_str = f"[{child_str}]" 48 | else: 49 | child_str = repr(child) 50 | child_str = _addindent(child_str, 2) 51 | child_lines.append("(" + key + "): " + child_str) 52 | lines = extra_lines + child_lines 53 | 54 | main_str = self.__class__.__name__ + "(" 55 | if lines: 56 | # simple one-liner info, which most builtin Modules will use 57 | if len(extra_lines) == 1 and not child_lines: 58 | main_str += extra_lines[0] 59 | else: 60 | main_str += "\n " + "\n ".join(lines) + "\n" 61 | 62 | main_str += ")" 63 | return main_str 64 | -------------------------------------------------------------------------------- /references/classification/README.md: -------------------------------------------------------------------------------- 1 | # Character classification 2 | 3 | The sample training scripts was made to train a character classification model or a orientation classifier with docTR. 4 | 5 | ## Setup 6 | 7 | First, you need to install `doctr` (with pip, for instance) 8 | 9 | ```shell 10 | pip install -e . --upgrade 11 | pip install -r references/requirements.txt 12 | ``` 13 | 14 | ## Usage character classification 15 | 16 | You can start your training in PyTorch: 17 | 18 | ```shell 19 | python references/classification/train_character.py mobilenet_v3_large --epochs 5 --device 0 20 | ``` 21 | 22 | ## Usage orientation classification 23 | 24 | You can start your training in PyTorch: 25 | 26 | ```shell 27 | python references/classification/train_orientation.py resnet18 --type page --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 28 | ``` 29 | 30 | The type can be either `page` for document images or `crop` for word crops. 31 | 32 | ## Data format 33 | 34 | You need to provide both `train_path` and `val_path` arguments to start training. 35 | Each path must lead to a folder where the images are stored. For example: 36 | 37 | ```shell 38 | images 39 | ├── sample_img_01.png 40 | ├── sample_img_02.png 41 | ├── sample_img_03.png 42 | └── ... 
43 | ``` 44 | 45 | ## Slack Logging with tqdm 46 | 47 | To enable Slack logging using `tqdm`, you need to set the following environment variables: 48 | 49 | - `TQDM_SLACK_TOKEN`: the Slack Bot Token 50 | - `TQDM_SLACK_CHANNEL`: you can retrieve it using `Right Click on Channel > Copy > Copy link`. You should get something like `https://xxxxxx.slack.com/archives/yyyyyyyy`. Keep only the `yyyyyyyy` part. 51 | 52 | You can follow this page on [how to create a Slack App](https://api.slack.com/quickstart). 53 | 54 | ## Advanced options 55 | 56 | Feel free to inspect the multiple script option to customize your training to your own needs! 57 | 58 | Character classification: 59 | 60 | ```shell 61 | python references/classification/train_character.py --help 62 | ``` 63 | 64 | Orientation classification: 65 | 66 | ```shell 67 | python references/classification/train_orientation.py --help 68 | ``` 69 | -------------------------------------------------------------------------------- /references/classification/latency.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | """Image classification latency benchmark""" 7 | 8 | import argparse 9 | import time 10 | 11 | import numpy as np 12 | import torch 13 | 14 | from doctr.models import classification 15 | 16 | 17 | @torch.inference_mode() 18 | def main(args): 19 | device = torch.device("cuda:0" if args.gpu else "cpu") 20 | 21 | # Pretrained imagenet model 22 | model = ( 23 | classification.__dict__[args.arch]( 24 | pretrained=args.pretrained, 25 | ) 26 | .eval() 27 | .to(device=device) 28 | ) 29 | 30 | # Input 31 | img_tensor = torch.rand((args.batch_size, 3, args.size, args.size)).to(device=device) 32 | 33 | # Warmup 34 | for _ in range(10): 35 | _ = model(img_tensor) 36 | 37 | timings = [] 38 | 39 | # Evaluation runs 40 | for _ in range(args.it): 41 | start_ts = time.perf_counter() 42 | _ = model(img_tensor) 43 | timings.append(time.perf_counter() - start_ts) 44 | 45 | _timings = np.array(timings) 46 | print(f"{args.arch} ({args.it} runs on ({args.size}, {args.size}) inputs in batches of {args.batch_size})") 47 | print(f"mean {1000 * _timings.mean():.2f}ms, std {1000 * _timings.std():.2f}ms") 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser( 52 | description="docTR latency benchmark for image classification (PyTorch)", 53 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 54 | ) 55 | parser.add_argument("arch", type=str, help="Architecture to use") 56 | parser.add_argument("--size", type=int, default=32, help="The image input size") 57 | parser.add_argument("--batch-size", "-b", type=int, default=64, help="The batch_size") 58 | parser.add_argument("--gpu", dest="gpu", help="Should the benchmark be performed on GPU", action="store_true") 59 | parser.add_argument("--it", type=int, default=100, help="Number of iterations to run") 60 | parser.add_argument( 61 | "--pretrained", dest="pretrained", help="Use pre-trained models from the modelzoo", action="store_true" 62 | ) 63 | args = parser.parse_args() 64 | 65 | main(args) 66 | -------------------------------------------------------------------------------- /references/classification/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 
4 | # See LICENSE or go to for full license details. 5 | 6 | import math 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | 11 | 12 | def plot_samples(images, targets): 13 | # Unnormalize image 14 | num_samples = min(len(images), 12) 15 | num_cols = min(len(images), 8) 16 | num_rows = int(math.ceil(num_samples / num_cols)) 17 | _, axes = plt.subplots(num_rows, num_cols, figsize=(20, 5)) 18 | for idx in range(num_samples): 19 | img = (255 * images[idx].numpy()).round().clip(0, 255).astype(np.uint8) 20 | if img.shape[0] == 3 and img.shape[2] != 3: 21 | img = img.transpose(1, 2, 0) 22 | 23 | row_idx = idx // num_cols 24 | col_idx = idx % num_cols 25 | 26 | ax = axes[row_idx] if num_rows > 1 else axes 27 | ax = ax[col_idx] if num_cols > 1 else ax 28 | 29 | ax.imshow(img) 30 | ax.set_title(targets[idx]) 31 | # Disable axis 32 | for ax in axes.ravel(): 33 | ax.axis("off") 34 | plt.show() 35 | 36 | 37 | def plot_recorder(lr_recorder, loss_recorder, beta: float = 0.95, **kwargs) -> None: 38 | """Display the results of the LR grid search. 39 | Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py 40 | 41 | Args: 42 | lr_recorder: list of LR values 43 | loss_recorder: list of loss values 44 | beta (float, optional): smoothing factor 45 | **kwargs: keyword arguments from `matplotlib.pyplot.show` 46 | """ 47 | if len(lr_recorder) != len(loss_recorder) or len(lr_recorder) == 0: 48 | raise AssertionError("Both `lr_recorder` and `loss_recorder` should have the same length") 49 | 50 | # Exp moving average of loss 51 | smoothed_losses = [] 52 | avg_loss = 0.0 53 | for idx, loss in enumerate(loss_recorder): 54 | avg_loss = beta * avg_loss + (1 - beta) * loss 55 | smoothed_losses.append(avg_loss / (1 - beta ** (idx + 1))) 56 | 57 | # Properly rescale Y-axis 58 | data_slice = slice( 59 | min(len(loss_recorder) // 10, 10), 60 | -min(len(loss_recorder) // 20, 5) if len(loss_recorder) >= 20 else len(loss_recorder), 61 | ) 62 | vals = np.array(smoothed_losses[data_slice]) 63 | min_idx = vals.argmin() 64 | max_val = vals.max() if min_idx is None else vals[: min_idx + 1].max() # type: ignore[misc] 65 | delta = max_val - vals[min_idx] 66 | 67 | plt.plot(lr_recorder[data_slice], smoothed_losses[data_slice]) 68 | plt.xscale("log") 69 | plt.xlabel("Learning Rate") 70 | plt.ylabel("Training loss") 71 | plt.ylim(vals[min_idx] - 0.1 * delta, max_val + 0.2 * delta) 72 | plt.grid(True, linestyle="--", axis="x") 73 | plt.show(**kwargs) 74 | 75 | 76 | class EarlyStopper: 77 | def __init__(self, patience: int = 5, min_delta: float = 0.01): 78 | self.patience = patience 79 | self.min_delta = min_delta 80 | self.counter = 0 81 | self.min_validation_loss = float("inf") 82 | 83 | def early_stop(self, validation_loss: float) -> bool: 84 | if validation_loss < self.min_validation_loss: 85 | self.min_validation_loss = validation_loss 86 | self.counter = 0 87 | elif validation_loss > (self.min_validation_loss + self.min_delta): 88 | self.counter += 1 89 | if self.counter >= self.patience: 90 | return True 91 | return False 92 | -------------------------------------------------------------------------------- /references/detection/latency.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
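# Hedged sketch of how the EarlyStopper above is typically wired into a training loop:
# stop once validation loss has not improved by `min_delta` for `patience` consecutive
# epochs. The import path assumes the repository root is on PYTHONPATH and the loss
# values are illustrative.
from references.classification.utils import EarlyStopper

early_stopper = EarlyStopper(patience=3, min_delta=0.01)
for epoch, val_loss in enumerate([0.90, 0.85, 0.86, 0.87, 0.88, 0.89]):
    if early_stopper.early_stop(val_loss):
        print(f"early stopping at epoch {epoch}")
        break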
5 | 6 | """Text detection latency benchmark""" 7 | 8 | import argparse 9 | import time 10 | 11 | import numpy as np 12 | import torch 13 | 14 | from doctr.models import detection 15 | 16 | 17 | @torch.inference_mode() 18 | def main(args): 19 | device = torch.device("cuda:0" if args.gpu else "cpu") 20 | 21 | # Pretrained imagenet model 22 | model = ( 23 | detection.__dict__[args.arch](pretrained=args.pretrained, pretrained_backbone=False).eval().to(device=device) 24 | ) 25 | 26 | # Input 27 | img_tensor = torch.rand((1, 3, args.size, args.size)).to(device=device) 28 | 29 | # Warmup 30 | for _ in range(10): 31 | _ = model(img_tensor) 32 | 33 | timings = [] 34 | 35 | # Evaluation runs 36 | for _ in range(args.it): 37 | start_ts = time.perf_counter() 38 | _ = model(img_tensor) 39 | timings.append(time.perf_counter() - start_ts) 40 | 41 | _timings = np.array(timings) 42 | print(f"{args.arch} ({args.it} runs on ({args.size}, {args.size}) inputs)") 43 | print(f"mean {1000 * _timings.mean():.2f}ms, std {1000 * _timings.std():.2f}ms") 44 | 45 | 46 | if __name__ == "__main__": 47 | parser = argparse.ArgumentParser( 48 | description="docTR latency benchmark for text detection (PyTorch)", 49 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 50 | ) 51 | parser.add_argument("arch", type=str, help="Architecture to use") 52 | parser.add_argument("--size", type=int, default=1024, help="The image input size") 53 | parser.add_argument("--gpu", dest="gpu", help="Should the benchmark be performed on GPU", action="store_true") 54 | parser.add_argument("--it", type=int, default=100, help="Number of iterations to run") 55 | parser.add_argument( 56 | "--pretrained", dest="pretrained", help="Use pre-trained models from the modelzoo", action="store_true" 57 | ) 58 | args = parser.parse_args() 59 | 60 | main(args) 61 | -------------------------------------------------------------------------------- /references/recognition/latency.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | """Text recognition latency benchmark""" 7 | 8 | import argparse 9 | import time 10 | 11 | import numpy as np 12 | import torch 13 | 14 | from doctr.models import recognition 15 | 16 | 17 | @torch.inference_mode() 18 | def main(args): 19 | device = torch.device("cuda:0" if args.gpu else "cpu") 20 | 21 | # Pretrained imagenet model 22 | model = ( 23 | recognition.__dict__[args.arch]( 24 | pretrained=args.pretrained, 25 | pretrained_backbone=False, 26 | ) 27 | .eval() 28 | .to(device=device) 29 | ) 30 | 31 | # Input 32 | img_tensor = torch.rand((args.batch_size, 3, args.size, 4 * args.size)).to(device=device) 33 | 34 | # Warmup 35 | for _ in range(10): 36 | _ = model(img_tensor) 37 | 38 | timings = [] 39 | 40 | # Evaluation runs 41 | for _ in range(args.it): 42 | start_ts = time.perf_counter() 43 | _ = model(img_tensor) 44 | timings.append(time.perf_counter() - start_ts) 45 | 46 | _timings = np.array(timings) 47 | print(f"{args.arch} ({args.it} runs on ({args.size}, {4 * args.size}) inputs in batches of {args.batch_size})") 48 | print(f"mean {1000 * _timings.mean():.2f}ms, std {1000 * _timings.std():.2f}ms") 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser( 53 | description="docTR latency benchmark for text recognition (PyTorch)", 54 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 55 | ) 56 | parser.add_argument("arch", type=str, help="Architecture to use") 57 | parser.add_argument("--batch-size", "-b", type=int, default=64, help="The batch_size") 58 | parser.add_argument("--size", type=int, default=32, help="The image input size") 59 | parser.add_argument("--gpu", dest="gpu", help="Should the benchmark be performed on GPU", action="store_true") 60 | parser.add_argument("--it", type=int, default=100, help="Number of iterations to run") 61 | parser.add_argument( 62 | "--pretrained", dest="pretrained", help="Use pre-trained models from the modelzoo", action="store_true" 63 | ) 64 | args = parser.parse_args() 65 | 66 | main(args) 67 | -------------------------------------------------------------------------------- /references/recognition/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import math 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | 11 | 12 | def plot_samples(images, targets): 13 | # Unnormalize image 14 | num_samples = min(len(images), 12) 15 | num_cols = min(len(images), 4) 16 | num_rows = int(math.ceil(num_samples / num_cols)) 17 | _, axes = plt.subplots(num_rows, num_cols, figsize=(20, 5)) 18 | for idx in range(num_samples): 19 | img = (255 * images[idx].numpy()).round().clip(0, 255).astype(np.uint8) 20 | if img.shape[0] == 3 and img.shape[2] != 3: 21 | img = img.transpose(1, 2, 0) 22 | 23 | row_idx = idx // num_cols 24 | col_idx = idx % num_cols 25 | ax = axes[row_idx] if num_rows > 1 else axes 26 | ax = ax[col_idx] if num_cols > 1 else ax 27 | 28 | ax.imshow(img) 29 | ax.set_title(targets[idx]) 30 | # Disable axis 31 | for ax in axes.ravel(): 32 | ax.axis("off") 33 | 34 | plt.show() 35 | 36 | 37 | def plot_recorder(lr_recorder, loss_recorder, beta: float = 0.95, **kwargs) -> None: 38 | """Display the results of the LR grid search. 39 | Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py. 
40 | 41 | Args: 42 | lr_recorder: list of LR values 43 | loss_recorder: list of loss values 44 | beta (float, optional): smoothing factor 45 | **kwargs: keyword arguments from `matplotlib.pyplot.show`. 46 | """ 47 | if len(lr_recorder) != len(loss_recorder) or len(lr_recorder) == 0: 48 | raise AssertionError("Both `lr_recorder` and `loss_recorder` should have the same length") 49 | 50 | # Exp moving average of loss 51 | smoothed_losses = [] 52 | avg_loss = 0.0 53 | for idx, loss in enumerate(loss_recorder): 54 | avg_loss = beta * avg_loss + (1 - beta) * loss 55 | smoothed_losses.append(avg_loss / (1 - beta ** (idx + 1))) 56 | 57 | # Properly rescale Y-axis 58 | data_slice = slice( 59 | min(len(loss_recorder) // 10, 10), 60 | -min(len(loss_recorder) // 20, 5) if len(loss_recorder) >= 20 else len(loss_recorder), 61 | ) 62 | vals = np.array(smoothed_losses[data_slice]) 63 | min_idx = vals.argmin() 64 | max_val = vals.max() if min_idx is None else vals[: min_idx + 1].max() # type: ignore[misc] 65 | delta = max_val - vals[min_idx] 66 | 67 | plt.plot(lr_recorder[data_slice], smoothed_losses[data_slice]) 68 | plt.xscale("log") 69 | plt.xlabel("Learning Rate") 70 | plt.ylabel("Training loss") 71 | plt.ylim(vals[min_idx] - 0.1 * delta, max_val + 0.2 * delta) 72 | plt.grid(True, linestyle="--", axis="x") 73 | plt.show(**kwargs) 74 | 75 | 76 | class EarlyStopper: 77 | def __init__(self, patience: int = 5, min_delta: float = 0.01): 78 | self.patience = patience 79 | self.min_delta = min_delta 80 | self.counter = 0 81 | self.min_validation_loss = float("inf") 82 | 83 | def early_stop(self, validation_loss: float) -> bool: 84 | if validation_loss < self.min_validation_loss: 85 | self.min_validation_loss = validation_loss 86 | self.counter = 0 87 | elif validation_loss > (self.min_validation_loss + self.min_delta): 88 | self.counter += 1 89 | if self.counter >= self.patience: 90 | return True 91 | return False 92 | -------------------------------------------------------------------------------- /references/requirements.txt: -------------------------------------------------------------------------------- 1 | -e . 2 | tqdm 3 | slack-sdk 4 | wandb>=0.10.31 5 | clearml>=1.11.1 6 | matplotlib>=3.1.0 7 | -------------------------------------------------------------------------------- /scripts/analyze.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | 7 | from doctr.io import DocumentFile 8 | from doctr.models import ocr_predictor 9 | 10 | 11 | def main(args): 12 | model = ocr_predictor(args.detection, args.recognition, pretrained=True) 13 | 14 | if args.path.lower().endswith(".pdf"): 15 | doc = DocumentFile.from_pdf(args.path) 16 | else: 17 | doc = DocumentFile.from_images(args.path) 18 | 19 | out = model(doc) 20 | 21 | for page in out.pages: 22 | page.show(block=not args.noblock, interactive=not args.static) 23 | 24 | 25 | def parse_args(): 26 | import argparse 27 | 28 | parser = argparse.ArgumentParser( 29 | description="DocTR end-to-end analysis", formatter_class=argparse.ArgumentDefaultsHelpFormatter 30 | ) 31 | 32 | parser.add_argument("path", type=str, help="Path to the input document (PDF or image)") 33 | parser.add_argument("--detection", type=str, default="fast_base", help="Text detection model to use for analysis") 34 | parser.add_argument( 35 | "--recognition", type=str, default="crnn_vgg16_bn", help="Text recognition model to use for analysis" 36 | ) 37 | parser.add_argument( 38 | "--noblock", dest="noblock", help="Disables blocking visualization. Used only for CI.", action="store_true" 39 | ) 40 | parser.add_argument("--static", dest="static", help="Switches to static visualization", action="store_true") 41 | args = parser.parse_args() 42 | 43 | return args 44 | 45 | 46 | if __name__ == "__main__": 47 | parsed_args = parse_args() 48 | main(parsed_args) 49 | -------------------------------------------------------------------------------- /scripts/detect_text.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import argparse 7 | import json 8 | import os 9 | from pathlib import Path 10 | 11 | from tqdm import tqdm 12 | 13 | from doctr.io import DocumentFile 14 | from doctr.models import detection, ocr_predictor 15 | 16 | IMAGE_FILE_EXTENSIONS = [".jpeg", ".jpg", ".png", ".tif", ".tiff", ".bmp"] 17 | OTHER_EXTENSIONS = [".pdf"] 18 | 19 | 20 | def _process_file(model, file_path: Path, out_format: str) -> None: 21 | if out_format not in ["txt", "json", "xml"]: 22 | raise ValueError(f"Unsupported output format: {out_format}") 23 | 24 | if os.path.splitext(file_path)[1] in IMAGE_FILE_EXTENSIONS: 25 | doc = DocumentFile.from_images([file_path]) 26 | elif os.path.splitext(file_path)[1] in OTHER_EXTENSIONS: 27 | doc = DocumentFile.from_pdf(file_path) 28 | else: 29 | print(f"Skip unsupported file type: {file_path}") 30 | 31 | out = model(doc) 32 | 33 | if out_format == "json": 34 | output = json.dumps(out.export(), indent=2) 35 | elif out_format == "txt": 36 | output = out.render() 37 | elif out_format == "xml": 38 | output = out.export_as_xml() 39 | 40 | path = Path("output").joinpath(file_path.stem + "." + out_format) 41 | if out_format == "xml": 42 | for i, (xml_bytes, xml_tree) in enumerate(output): 43 | path = Path("output").joinpath(file_path.stem + f"_{i}." 
+ out_format) 44 | xml_tree.write(path, encoding="utf-8", xml_declaration=True) 45 | else: 46 | with open(path, "w") as f: 47 | f.write(output) 48 | 49 | 50 | def main(args): 51 | detection_model = detection.__dict__[args.detection]( 52 | pretrained=True, 53 | bin_thresh=args.bin_thresh, 54 | box_thresh=args.box_thresh, 55 | ) 56 | model = ocr_predictor(detection_model, args.recognition, pretrained=True) 57 | path = Path(args.path) 58 | 59 | os.makedirs(name="output", exist_ok=True) 60 | 61 | if path.is_dir(): 62 | to_process = [ 63 | f for f in path.iterdir() if str(f).lower().endswith(tuple(IMAGE_FILE_EXTENSIONS + OTHER_EXTENSIONS)) 64 | ] 65 | for file_path in tqdm(to_process): 66 | _process_file(model, file_path, args.format) 67 | else: 68 | _process_file(model, path, args.format) 69 | 70 | 71 | def parse_args(): 72 | parser = argparse.ArgumentParser( 73 | description="DocTR text detection", 74 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 75 | ) 76 | parser.add_argument("path", type=str, help="Path to process: PDF, image, directory") 77 | parser.add_argument("--detection", type=str, default="fast_base", help="Text detection model to use for analysis") 78 | parser.add_argument("--bin-thresh", type=float, default=0.3, help="Binarization threshold for the detection model.") 79 | parser.add_argument("--box-thresh", type=float, default=0.1, help="Threshold for the detection boxes.") 80 | parser.add_argument( 81 | "--recognition", type=str, default="crnn_vgg16_bn", help="Text recognition model to use for analysis" 82 | ) 83 | parser.add_argument("-f", "--format", choices=["txt", "json", "xml"], default="txt", help="Output format") 84 | return parser.parse_args() 85 | 86 | 87 | if __name__ == "__main__": 88 | parsed_args = parse_args() 89 | main(parsed_args) 90 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2025, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | import os 7 | from pathlib import Path 8 | 9 | from setuptools import setup 10 | 11 | PKG_NAME = "python-doctr" 12 | VERSION = os.getenv("BUILD_VERSION", "1.0.1a0") 13 | 14 | 15 | if __name__ == "__main__": 16 | print(f"Building wheel {PKG_NAME}-{VERSION}") 17 | 18 | # Dynamically set the __version__ attribute 19 | cwd = Path(__file__).parent.absolute() 20 | with open(cwd.joinpath("doctr", "version.py"), "w", encoding="utf-8") as f: 21 | f.write(f"__version__ = '{VERSION}'\n") 22 | 23 | setup(name=PKG_NAME, version=VERSION) 24 | -------------------------------------------------------------------------------- /tests/common/test_contrib.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from doctr.contrib import artefacts 5 | from doctr.contrib.base import _BasePredictor 6 | from doctr.io import DocumentFile 7 | 8 | 9 | def test_base_predictor(): 10 | # check that we need to provide either a url or a model_path 11 | with pytest.raises(ValueError): 12 | _ = _BasePredictor(batch_size=2) 13 | 14 | predictor = _BasePredictor(batch_size=2, url=artefacts.default_cfgs["yolov8_artefact"]["url"]) 15 | # check that we need to implement preprocess and postprocess 16 | with pytest.raises(NotImplementedError): 17 | predictor.preprocess(np.zeros((10, 10, 3))) 18 | with pytest.raises(NotImplementedError): 19 | predictor.postprocess([np.zeros((10, 10, 3))], [[np.zeros((10, 10, 3))]]) 20 | 21 | 22 | def test_artefact_detector(mock_artefact_image_stream): 23 | doc = DocumentFile.from_images([mock_artefact_image_stream]) 24 | detector = artefacts.ArtefactDetector(batch_size=2, conf_threshold=0.5, iou_threshold=0.5) 25 | results = detector(doc) 26 | assert isinstance(results, list) and len(results) == 1 and isinstance(results[0], list) 27 | assert all(isinstance(artefact, dict) for artefact in results[0]) 28 | # check result keys 29 | assert all(key in results[0][0] for key in ["label", "confidence", "box"]) 30 | assert all(len(artefact["box"]) == 4 for artefact in results[0]) 31 | assert all(isinstance(coord, int) for box in results[0] for coord in box["box"]) 32 | assert all(isinstance(artefact["confidence"], float) for artefact in results[0]) 33 | assert all(isinstance(artefact["label"], str) for artefact in results[0]) 34 | # check results for the mock image are 9 artefacts 35 | assert len(results[0]) == 9 36 | # test visualization non-blocking for tests 37 | detector.show(block=False) 38 | -------------------------------------------------------------------------------- /tests/common/test_core.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import doctr 4 | from doctr.file_utils import requires_package 5 | 6 | 7 | def test_version(): 8 | assert len(doctr.__version__.split(".")) == 3 9 | 10 | 11 | def test_requires_package(): 12 | requires_package("numpy") # available 13 | with pytest.raises(ImportError): # not available 14 | requires_package("non_existent_package") 15 | -------------------------------------------------------------------------------- /tests/common/test_datasets.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from doctr import datasets 7 | 8 | 9 | def test_visiondataset(): 10 | url = "https://github.com/mindee/doctr/releases/download/v0.6.0/mnist.zip" 11 | with pytest.raises(ValueError): 12 | 
datasets.datasets.VisionDataset(url, download=False) 13 | 14 | dataset = datasets.datasets.VisionDataset(url, download=True, extract_archive=True) 15 | assert len(dataset) == 0 16 | assert repr(dataset) == "VisionDataset()" 17 | 18 | 19 | def test_abstractdataset(mock_image_path): 20 | with pytest.raises(ValueError): 21 | datasets.datasets.AbstractDataset("my/fantasy/folder") 22 | 23 | # Check transforms 24 | path = Path(mock_image_path) 25 | ds = datasets.datasets.AbstractDataset(path.parent) 26 | # Check target format 27 | with pytest.raises(AssertionError): 28 | ds.data = [(path.name, 0)] 29 | img, target = ds[0] 30 | with pytest.raises(AssertionError): 31 | ds.data = [(path.name, dict(boxes=np.array([[0, 0, 1, 1]])))] 32 | img, target = ds[0] 33 | with pytest.raises(AssertionError): 34 | ds.data = [(ds.data[0][0], {"label": "A"})] 35 | img, target = ds[0] 36 | 37 | # Patch some data 38 | ds.data = [(path.name, np.array([0]))] 39 | 40 | # Fetch the img 41 | img, target = ds[0] 42 | assert isinstance(target, np.ndarray) and target == np.array([0]) 43 | 44 | # Check img_transforms 45 | ds.img_transforms = lambda x: 1 - x 46 | img2, target2 = ds[0] 47 | assert np.all(img2.numpy() == 1 - img.numpy()) 48 | assert target == target2 49 | 50 | # Check sample_transforms 51 | ds.img_transforms = None 52 | ds.sample_transforms = lambda x, y: (x, y + 1) 53 | img3, target3 = ds[0] 54 | assert np.all(img3.numpy() == img.numpy()) and (target3 == (target + 1)) 55 | 56 | # Check inplace modifications 57 | ds.data = [(ds.data[0][0], "A")] 58 | 59 | def inplace_transfo(x, target): 60 | target += "B" 61 | return x, target 62 | 63 | ds.sample_transforms = inplace_transfo 64 | _, t = ds[0] 65 | _, t = ds[0] 66 | assert t == "AB" 67 | -------------------------------------------------------------------------------- /tests/common/test_datasets_vocabs.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | from doctr.datasets import VOCABS 4 | 5 | 6 | def test_vocabs_duplicates(): 7 | for key, vocab in VOCABS.items(): 8 | assert isinstance(vocab, str) 9 | 10 | duplicates = [char for char, count in Counter(vocab).items() if count > 1] 11 | assert not duplicates, f"Duplicate characters in {key} vocab: {duplicates}" 12 | -------------------------------------------------------------------------------- /tests/common/test_headers.py: -------------------------------------------------------------------------------- 1 | """Test for python files copyright headers.""" 2 | 3 | from datetime import datetime 4 | from pathlib import Path 5 | 6 | 7 | def test_copyright_header(): 8 | copyright_header = "".join([ 9 | f"# Copyright (C) {2021}-{datetime.now().year}, Mindee.\n\n", 10 | "# This program is licensed under the Apache License 2.0.\n", 11 | "# See LICENSE or go to for full license details.\n", 12 | ]) 13 | excluded_files = ["__init__.py", "version.py"] 14 | invalid_files = [] 15 | locations = [".github", "api/app", "demo", "docs", "doctr", "references", "scripts"] 16 | 17 | for location in locations: 18 | for source_path in Path(__file__).parent.parent.parent.joinpath(location).rglob("*.py"): 19 | if source_path.name not in excluded_files: 20 | source_path_content = source_path.read_text() 21 | if copyright_header not in source_path_content: 22 | invalid_files.append(source_path) 23 | assert len(invalid_files) == 0, f"Invalid copyright header in the following files: {invalid_files}" 24 | 
-------------------------------------------------------------------------------- /tests/common/test_io.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import pytest 6 | import requests 7 | 8 | from doctr import io 9 | 10 | 11 | def _check_doc_content(doc_tensors, num_pages): 12 | # 1 doc of 8 pages 13 | assert len(doc_tensors) == num_pages 14 | assert all(isinstance(page, np.ndarray) for page in doc_tensors) 15 | assert all(page.dtype == np.uint8 for page in doc_tensors) 16 | 17 | 18 | def test_read_pdf(mock_pdf): 19 | doc = io.read_pdf(mock_pdf) 20 | _check_doc_content(doc, 2) 21 | 22 | # Test with Path 23 | doc = io.read_pdf(Path(mock_pdf)) 24 | _check_doc_content(doc, 2) 25 | 26 | with open(mock_pdf, "rb") as f: 27 | doc = io.read_pdf(f.read()) 28 | _check_doc_content(doc, 2) 29 | 30 | # Wrong input type 31 | with pytest.raises(TypeError): 32 | _ = io.read_pdf(123) 33 | 34 | # Wrong path 35 | with pytest.raises(FileNotFoundError): 36 | _ = io.read_pdf("my_imaginary_file.pdf") 37 | 38 | 39 | def test_read_img_as_numpy(tmpdir_factory, mock_pdf): 40 | # Wrong input type 41 | with pytest.raises(TypeError): 42 | _ = io.read_img_as_numpy(123) 43 | 44 | # Non-existing file 45 | with pytest.raises(FileNotFoundError): 46 | io.read_img_as_numpy("my_imaginary_file.jpg") 47 | 48 | # Invalid image 49 | with pytest.raises(ValueError): 50 | io.read_img_as_numpy(str(mock_pdf)) 51 | 52 | # From path 53 | url = "https://doctr-static.mindee.com/models?id=v0.2.1/Grace_Hopper.jpg&src=0" 54 | file = BytesIO(requests.get(url).content) 55 | tmp_path = str(tmpdir_factory.mktemp("data").join("mock_img_file.jpg")) 56 | with open(tmp_path, "wb") as f: 57 | f.write(file.getbuffer()) 58 | 59 | # Path & stream 60 | with open(tmp_path, "rb") as f: 61 | page_stream = io.read_img_as_numpy(f.read()) 62 | 63 | for page in (io.read_img_as_numpy(tmp_path), page_stream): 64 | # Data type 65 | assert isinstance(page, np.ndarray) 66 | assert page.dtype == np.uint8 67 | # Shape 68 | assert page.shape == (606, 517, 3) 69 | 70 | # RGB 71 | bgr_page = io.read_img_as_numpy(tmp_path, rgb_output=False) 72 | assert np.all(page == bgr_page[..., ::-1]) 73 | 74 | # Resize 75 | target_size = (200, 150) 76 | resized_page = io.read_img_as_numpy(tmp_path, target_size) 77 | assert resized_page.shape[:2] == target_size 78 | 79 | 80 | def test_read_html(): 81 | url = "https://www.google.com" 82 | pdf_stream = io.read_html(url) 83 | assert isinstance(pdf_stream, bytes) 84 | 85 | 86 | def test_document_file(mock_pdf, mock_image_stream): 87 | pages = io.DocumentFile.from_images(mock_image_stream) 88 | _check_doc_content(pages, 1) 89 | 90 | assert isinstance(io.DocumentFile.from_pdf(mock_pdf), list) 91 | assert isinstance(io.DocumentFile.from_url("https://www.google.com"), list) 92 | 93 | 94 | def test_pdf(mock_pdf): 95 | pages = io.DocumentFile.from_pdf(mock_pdf) 96 | 97 | # As images 98 | num_pages = 2 99 | _check_doc_content(pages, num_pages) 100 | -------------------------------------------------------------------------------- /tests/common/test_models.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import cv2 4 | import numpy as np 5 | import pytest 6 | import requests 7 | 8 | from doctr.io import reader 9 | from doctr.models._utils import estimate_orientation, get_language, invert_data_structure 10 | from doctr.utils import geometry 11 | 12 | 13 | 
@pytest.fixture(scope="function") 14 | def mock_image(tmpdir_factory): 15 | url = "https://doctr-static.mindee.com/models?id=v0.2.1/bitmap30.png&src=0" 16 | file = BytesIO(requests.get(url).content) 17 | tmp_path = str(tmpdir_factory.mktemp("data").join("mock_bitmap.jpg")) 18 | with open(tmp_path, "wb") as f: 19 | f.write(file.getbuffer()) 20 | image = reader.read_img_as_numpy(tmp_path) 21 | return image 22 | 23 | 24 | @pytest.fixture(scope="function") 25 | def mock_bitmap(mock_image): 26 | bitmap = np.squeeze(cv2.cvtColor(mock_image, cv2.COLOR_BGR2GRAY) / 255.0) 27 | bitmap = np.expand_dims(bitmap, axis=-1) 28 | return bitmap 29 | 30 | 31 | def test_estimate_orientation(mock_image, mock_bitmap, mock_tilted_payslip): 32 | assert estimate_orientation(mock_image * 0) == 0 33 | 34 | # test binarized image 35 | angle = estimate_orientation(mock_bitmap) 36 | assert abs(angle) - 30 < 1.0 37 | 38 | angle = estimate_orientation(mock_bitmap * 255) 39 | assert abs(angle) - 30.0 < 1.0 40 | 41 | angle = estimate_orientation(mock_image) 42 | assert abs(angle) - 30.0 < 1.0 43 | 44 | rotated = geometry.rotate_image(mock_image, angle) 45 | angle_rotated = estimate_orientation(rotated) 46 | assert abs(angle_rotated) == 0 47 | 48 | mock_tilted_payslip = reader.read_img_as_numpy(mock_tilted_payslip) 49 | assert estimate_orientation(mock_tilted_payslip) == -30 50 | 51 | rotated = geometry.rotate_image(mock_tilted_payslip, -30, expand=True) 52 | angle_rotated = estimate_orientation(rotated) 53 | assert abs(angle_rotated) < 1.0 54 | 55 | with pytest.raises(AssertionError): 56 | estimate_orientation(np.ones((10, 10, 10))) 57 | 58 | # test with general_page_orientation 59 | assert estimate_orientation(mock_bitmap, (90, 0.9)) in range(140, 160) 60 | 61 | rotated = geometry.rotate_image(mock_tilted_payslip, -30) 62 | assert estimate_orientation(rotated, (0, 0.9)) in range(-10, 10) 63 | 64 | assert estimate_orientation(mock_image, (0, 0.9)) - 30 < 1.0 65 | 66 | 67 | def test_get_lang(): 68 | sentence = "This is a test sentence." 
69 | expected_lang = "en" 70 | threshold_prob = 0.99 71 | 72 | lang = get_language(sentence) 73 | 74 | assert lang[0] == expected_lang 75 | assert lang[1] > threshold_prob 76 | 77 | lang = get_language("a") 78 | assert lang[0] == "unknown" 79 | assert lang[1] == 0.0 80 | 81 | 82 | def test_convert_list_dict(): 83 | dic = {"k1": [[0], [0], [0]], "k2": [[1], [1], [1]]} 84 | tar_dict = [{"k1": [0], "k2": [1]}, {"k1": [0], "k2": [1]}, {"k1": [0], "k2": [1]}] 85 | 86 | converted_dic = invert_data_structure(dic) 87 | converted_list = invert_data_structure(tar_dict) 88 | 89 | assert converted_dic == tar_dict 90 | assert converted_list == dic 91 | -------------------------------------------------------------------------------- /tests/common/test_models_detection_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from doctr.models.detection._utils import _remove_padding 5 | 6 | 7 | @pytest.mark.parametrize("pages", [[np.zeros((1000, 1000))], [np.zeros((1000, 2000))], [np.zeros((2000, 1000))]]) 8 | @pytest.mark.parametrize("preserve_aspect_ratio", [True, False]) 9 | @pytest.mark.parametrize("symmetric_pad", [True, False]) 10 | @pytest.mark.parametrize("assume_straight_pages", [True, False]) 11 | def test_remove_padding(pages, preserve_aspect_ratio, symmetric_pad, assume_straight_pages): 12 | h, w = pages[0].shape 13 | # straight pages test cases 14 | if assume_straight_pages: 15 | loc_preds = [{"words": np.array([[0.7, 0.1, 0.7, 0.2]])}] 16 | if h == w or not preserve_aspect_ratio: 17 | expected = loc_preds 18 | else: 19 | if symmetric_pad: 20 | if h > w: 21 | expected = [{"words": np.array([[0.9, 0.1, 0.9, 0.2]])}] 22 | else: 23 | expected = [{"words": np.array([[0.7, 0.0, 0.7, 0.0]])}] 24 | else: 25 | if h > w: 26 | expected = [{"words": np.array([[1.0, 0.1, 1.0, 0.2]])}] 27 | else: 28 | expected = [{"words": np.array([[0.7, 0.2, 0.7, 0.4]])}] 29 | # non-straight pages test cases 30 | else: 31 | loc_preds = [{"words": np.array([[[0.9, 0.1], [0.9, 0.2], [0.8, 0.2], [0.8, 0.2]]])}] 32 | if h == w or not preserve_aspect_ratio: 33 | expected = loc_preds 34 | else: 35 | if symmetric_pad: 36 | if h > w: 37 | expected = [{"words": np.array([[[1.0, 0.1], [1.0, 0.2], [1.0, 0.2], [1.0, 0.2]]])}] 38 | else: 39 | expected = [{"words": np.array([[[0.9, 0.0], [0.9, 0.0], [0.8, 0.0], [0.8, 0.0]]])}] 40 | else: 41 | if h > w: 42 | expected = [{"words": np.array([[[1.0, 0.1], [1.0, 0.2], [1.0, 0.2], [1.0, 0.2]]])}] 43 | else: 44 | expected = [{"words": np.array([[[0.9, 0.2], [0.9, 0.4], [0.8, 0.4], [0.8, 0.4]]])}] 45 | 46 | result = _remove_padding(pages, loc_preds, preserve_aspect_ratio, symmetric_pad, assume_straight_pages) 47 | for res, exp in zip(result, expected): 48 | assert np.allclose(res["words"], exp["words"]) 49 | -------------------------------------------------------------------------------- /tests/common/test_models_recognition_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from doctr.models.recognition.utils import merge_multi_strings, merge_strings 4 | 5 | 6 | @pytest.mark.parametrize( 7 | "a, b, overlap_ratio, merged", 8 | [ 9 | # Last character of first string and first of last string will be cropped when merging - indicated by X 10 | ("abcX", "Xdef", 0.5, "abcdef"), 11 | ("abcdX", "Xdef", 0.75, "abcdef"), 12 | ("abcdeX", "Xdef", 0.9, "abcdef"), 13 | ("abcdefX", "Xdef", 0.9, "abcdef"), 14 | # Long repetition - four of seven characters in 
the second string are in the estimated overlap 15 | # X-chars will be cropped during merge, because they might be cut off during splitting of corresponding image 16 | ("abccccX", "Xcccccc", 4 / 7, "abcccccccc"), 17 | ("abc", "", 0.5, "abc"), 18 | ("", "abc", 0.5, "abc"), 19 | ("a", "b", 0.5, "ab"), 20 | # No overlap of input strings after crop 21 | ("abcdX", "Xefghi", 0.33, "abcdefghi"), 22 | # No overlap of input strings after crop with shorter inputs 23 | ("bcdX", "Xefgh", 0.4, "bcdefgh"), 24 | # No overlap of input strings after crop with even shorter inputs 25 | ("cdX", "Xefg", 0.5, "cdefg"), 26 | # Full overlap of input strings 27 | ("abcdX", "Xbcde", 1.0, "abcde"), 28 | # One repetition within inputs 29 | ("ababX", "Xabde", 0.8, "ababde"), 30 | # Multiple repetitions within inputs 31 | ("ababX", "Xabab", 0.8, "ababab"), 32 | # Multiple repetitions within inputs with shorter input strings 33 | ("abaX", "Xbab", 1.0, "abab"), 34 | # Longer multiple repetitions within inputs with half overlap 35 | ("cabababX", "Xabababc", 0.5, "cabababababc"), 36 | # Longer multiple repetitions within inputs with full overlap 37 | ("ababaX", "Xbabab", 1.0, "ababab"), 38 | # One different letter in overlap 39 | ("one_differon", "ferent_letter", 0.5, "one_differont_letter"), 40 | # First string empty after crop 41 | ("-", "test", 0.9, "-test"), 42 | # Second string empty after crop 43 | ("test", "-", 0.9, "test-"), 44 | ], 45 | ) 46 | def test_merge_strings(a, b, overlap_ratio, merged): 47 | assert merged == merge_strings(a, b, overlap_ratio) 48 | 49 | 50 | @pytest.mark.parametrize( 51 | "seq_list, overlap_ratio, last_overlap_ratio, merged", 52 | [ 53 | # One character at each conjunction point will be cropped when merging - indicated by X 54 | (["abcX", "Xdef"], 0.5, 0.5, "abcdef"), 55 | (["abcdX", "XdefX", "XefghX", "Xijk"], 0.5, 0.5, "abcdefghijk"), 56 | (["abcdX", "XdefX", "XefghiX", "Xaijk"], 0.5, 0.8, "abcdefghijk"), 57 | (["aaaa", "aaab", "aabc"], 0.8, 0.3, "aaaabc"), 58 | # Handle empty input 59 | ([], 0.5, 0.4, ""), 60 | ], 61 | ) 62 | def test_merge_multi_strings(seq_list, overlap_ratio, last_overlap_ratio, merged): 63 | assert merged == merge_multi_strings(seq_list, overlap_ratio, last_overlap_ratio) 64 | -------------------------------------------------------------------------------- /tests/common/test_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from doctr.transforms import modules as T 5 | from doctr.transforms.functional.base import expand_line 6 | 7 | 8 | def test_imagetransform(): 9 | transfo = T.ImageTransform(lambda x: 1 - x) 10 | assert transfo(0, 1) == (1, 1) 11 | 12 | 13 | def test_samplecompose(): 14 | transfos = [lambda x, y: (1 - x, y), lambda x, y: (x, 2 * y)] 15 | transfo = T.SampleCompose(transfos) 16 | assert transfo(0, 1) == (1, 2) 17 | 18 | 19 | def test_oneof(): 20 | transfos = [lambda x: 1 - x, lambda x: x + 10] 21 | transfo = T.OneOf(transfos) 22 | out = transfo(1) 23 | assert out == 0 or out == 11 24 | # test with target 25 | transfos = [lambda x, y: (1 - x, y), lambda x, y: (x + 10, y)] 26 | transfo = T.OneOf(transfos) 27 | out = transfo(1, np.array([2])) 28 | assert out == (0, 2) or out == (11, 2) and isinstance(out[1], np.ndarray) 29 | 30 | 31 | def test_randomapply(): 32 | transfo = T.RandomApply(lambda x: 1 - x) 33 | out = transfo(1) 34 | assert out == 0 or out == 1 35 | transfo = T.RandomApply(lambda x, y: (1 - x, 2 * y)) 36 | out = transfo(1, np.array([2])) 37 | 
assert out == (0, 4) or out == (1, 2) and isinstance(out[1], np.ndarray) 38 | assert repr(transfo).endswith(", p=0.5)") 39 | 40 | 41 | @pytest.mark.parametrize( 42 | "line", 43 | [ 44 | # Horizontal 45 | np.array([[63, 1], [42, 1]]).astype(np.int32), 46 | # Vertical 47 | np.array([[1, 63], [1, 42]]).astype(np.int32), 48 | # Normal 49 | np.array([[1, 63], [12, 42]]).astype(np.int32), 50 | ], 51 | ) 52 | def test_expand_line(line): 53 | out = expand_line(line, (100, 100)) 54 | assert isinstance(out, tuple) 55 | assert all(isinstance(val, (float, int, np.int32, np.float64)) and val >= 0 for val in out) 56 | -------------------------------------------------------------------------------- /tests/common/test_utils_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import PosixPath 3 | from unittest.mock import patch 4 | 5 | import pytest 6 | 7 | from doctr.utils.data import download_from_url 8 | 9 | 10 | @patch("doctr.utils.data._urlretrieve") 11 | @patch("pathlib.Path.mkdir") 12 | @patch.dict(os.environ, {"HOME": "/"}, clear=True) 13 | def test_download_from_url(mkdir_mock, urlretrieve_mock): 14 | download_from_url("test_url") 15 | urlretrieve_mock.assert_called_with("test_url", PosixPath("/.cache/doctr/test_url")) 16 | 17 | 18 | @patch.dict(os.environ, {"DOCTR_CACHE_DIR": "/test"}, clear=True) 19 | @patch("doctr.utils.data._urlretrieve") 20 | @patch("pathlib.Path.mkdir") 21 | def test_download_from_url_customizing_cache_dir(mkdir_mock, urlretrieve_mock): 22 | download_from_url("test_url") 23 | urlretrieve_mock.assert_called_with("test_url", PosixPath("/test/test_url")) 24 | 25 | 26 | @patch.dict(os.environ, {"HOME": "/"}, clear=True) 27 | @patch("pathlib.Path.mkdir", side_effect=OSError) 28 | @patch("logging.error") 29 | def test_download_from_url_error_creating_directory(logging_mock, mkdir_mock): 30 | with pytest.raises(OSError): 31 | download_from_url("test_url") 32 | logging_mock.assert_called_with( 33 | "Failed creating cache directory at /.cache/doctr." 34 | " You can change default cache directory using 'DOCTR_CACHE_DIR' environment variable if needed." 35 | ) 36 | 37 | 38 | @patch.dict(os.environ, {"HOME": "/", "DOCTR_CACHE_DIR": "/test"}, clear=True) 39 | @patch("pathlib.Path.mkdir", side_effect=OSError) 40 | @patch("logging.error") 41 | def test_download_from_url_error_creating_directory_with_env_var(logging_mock, mkdir_mock): 42 | with pytest.raises(OSError): 43 | download_from_url("test_url") 44 | logging_mock.assert_called_with( 45 | "Failed creating cache directory at /test using path from 'DOCTR_CACHE_DIR' environment variable." 
46 | ) 47 | -------------------------------------------------------------------------------- /tests/common/test_utils_fonts.py: -------------------------------------------------------------------------------- 1 | from PIL.ImageFont import FreeTypeFont, ImageFont 2 | 3 | from doctr.utils.fonts import get_font 4 | 5 | 6 | def test_get_font(): 7 | # Attempts to load recommended OS font 8 | font = get_font() 9 | 10 | assert isinstance(font, (ImageFont, FreeTypeFont)) 11 | -------------------------------------------------------------------------------- /tests/common/test_utils_multithreading.py: -------------------------------------------------------------------------------- 1 | import os 2 | from multiprocessing.pool import ThreadPool 3 | from unittest.mock import patch 4 | 5 | import pytest 6 | 7 | from doctr.utils.multithreading import multithread_exec 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "input_seq, func, output_seq", 12 | [ 13 | [[1, 2, 3], lambda x: 2 * x, [2, 4, 6]], 14 | [[1, 2, 3], lambda x: x**2, [1, 4, 9]], 15 | [ 16 | ["this is", "show me", "I know"], 17 | lambda x: x + " the way", 18 | ["this is the way", "show me the way", "I know the way"], 19 | ], 20 | ], 21 | ) 22 | def test_multithread_exec(input_seq, func, output_seq): 23 | assert list(multithread_exec(func, input_seq)) == output_seq 24 | assert list(multithread_exec(func, input_seq, 0)) == output_seq 25 | 26 | 27 | @patch.dict(os.environ, {"DOCTR_MULTIPROCESSING_DISABLE": "TRUE"}, clear=True) 28 | def test_multithread_exec_multiprocessing_disable(): 29 | with patch.object(ThreadPool, "map") as mock_tp_map: 30 | multithread_exec(lambda x: x, [1, 2]) 31 | assert not mock_tp_map.called 32 | -------------------------------------------------------------------------------- /tests/common/test_utils_reconstitution.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from test_io_elements import _mock_kie_pages, _mock_pages 3 | 4 | from doctr.utils import reconstitution 5 | 6 | 7 | def test_synthesize_page(): 8 | pages = _mock_pages() 9 | # Test without probability rendering 10 | render_no_proba = reconstitution.synthesize_page(pages[0].export(), draw_proba=False) 11 | assert isinstance(render_no_proba, np.ndarray) 12 | assert render_no_proba.shape == (*pages[0].dimensions, 3) 13 | 14 | # Test with probability rendering 15 | render_with_proba = reconstitution.synthesize_page(pages[0].export(), draw_proba=True) 16 | assert isinstance(render_with_proba, np.ndarray) 17 | assert render_with_proba.shape == (*pages[0].dimensions, 3) 18 | 19 | # Test with only one line 20 | pages_one_line = pages[0].export() 21 | pages_one_line["blocks"][0]["lines"] = [pages_one_line["blocks"][0]["lines"][0]] 22 | render_one_line = reconstitution.synthesize_page(pages_one_line, draw_proba=True) 23 | assert isinstance(render_one_line, np.ndarray) 24 | assert render_one_line.shape == (*pages[0].dimensions, 3) 25 | 26 | # Test with polygons 27 | pages_poly = pages[0].export() 28 | pages_poly["blocks"][0]["lines"][0]["geometry"] = [(0, 0), (0, 1), (1, 1), (1, 0)] 29 | render_poly = reconstitution.synthesize_page(pages_poly, draw_proba=True) 30 | assert isinstance(render_poly, np.ndarray) 31 | assert render_poly.shape == (*pages[0].dimensions, 3) 32 | 33 | 34 | def test_synthesize_kie_page(): 35 | pages = _mock_kie_pages() 36 | # Test without probability rendering 37 | render_no_proba = reconstitution.synthesize_kie_page(pages[0].export(), draw_proba=False) 38 | assert isinstance(render_no_proba, 
np.ndarray) 39 | assert render_no_proba.shape == (*pages[0].dimensions, 3) 40 | 41 | # Test with probability rendering 42 | render_with_proba = reconstitution.synthesize_kie_page(pages[0].export(), draw_proba=True) 43 | assert isinstance(render_with_proba, np.ndarray) 44 | assert render_with_proba.shape == (*pages[0].dimensions, 3) 45 | -------------------------------------------------------------------------------- /tests/common/test_utils_visualization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from test_io_elements import _mock_pages 4 | 5 | from doctr.utils import visualization 6 | 7 | 8 | def test_visualize_page(): 9 | pages = _mock_pages() 10 | image = np.ones((300, 200, 3)) 11 | visualization.visualize_page(pages[0].export(), image, words_only=False) 12 | visualization.visualize_page(pages[0].export(), image, words_only=True, interactive=False) 13 | # geometry checks 14 | with pytest.raises(ValueError): 15 | visualization.create_obj_patch([1, 2], (100, 100)) 16 | 17 | with pytest.raises(ValueError): 18 | visualization.create_obj_patch((1, 2), (100, 100)) 19 | 20 | with pytest.raises(ValueError): 21 | visualization.create_obj_patch((1, 2, 3, 4, 5), (100, 100)) 22 | 23 | 24 | def test_draw_boxes(): 25 | image = np.ones((256, 256, 3), dtype=np.float32) 26 | boxes = [ 27 | [0.1, 0.1, 0.2, 0.2], 28 | [0.15, 0.15, 0.19, 0.2], # to suppress 29 | [0.5, 0.5, 0.6, 0.55], 30 | [0.55, 0.5, 0.7, 0.55], # to suppress 31 | ] 32 | visualization.draw_boxes(boxes=np.array(boxes), image=image, block=False) 33 | -------------------------------------------------------------------------------- /tests/pytorch/test_io_image_pt.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | from doctr.io import decode_img_as_tensor, read_img_as_tensor, tensor_from_numpy 6 | 7 | 8 | def test_read_img_as_tensor(mock_image_path): 9 | img = read_img_as_tensor(mock_image_path) 10 | 11 | assert isinstance(img, torch.Tensor) 12 | assert img.dtype == torch.float32 13 | assert img.shape == (3, 900, 1200) 14 | 15 | img = read_img_as_tensor(mock_image_path, dtype=torch.float16) 16 | assert img.dtype == torch.float16 17 | img = read_img_as_tensor(mock_image_path, dtype=torch.uint8) 18 | assert img.dtype == torch.uint8 19 | 20 | with pytest.raises(ValueError): 21 | read_img_as_tensor(mock_image_path, dtype=torch.float64) 22 | 23 | 24 | def test_decode_img_as_tensor(mock_image_stream): 25 | img = decode_img_as_tensor(mock_image_stream) 26 | 27 | assert isinstance(img, torch.Tensor) 28 | assert img.dtype == torch.float32 29 | assert img.shape == (3, 900, 1200) 30 | 31 | img = decode_img_as_tensor(mock_image_stream, dtype=torch.float16) 32 | assert img.dtype == torch.float16 33 | img = decode_img_as_tensor(mock_image_stream, dtype=torch.uint8) 34 | assert img.dtype == torch.uint8 35 | 36 | with pytest.raises(ValueError): 37 | decode_img_as_tensor(mock_image_stream, dtype=torch.float64) 38 | 39 | 40 | def test_tensor_from_numpy(mock_image_stream): 41 | with pytest.raises(ValueError): 42 | tensor_from_numpy(np.zeros((256, 256, 3)), torch.int64) 43 | 44 | out = tensor_from_numpy(np.zeros((256, 256, 3), dtype=np.uint8)) 45 | 46 | assert isinstance(out, torch.Tensor) 47 | assert out.dtype == torch.float32 48 | assert out.shape == (3, 256, 256) 49 | 50 | out = tensor_from_numpy(np.zeros((256, 256, 3), dtype=np.uint8), dtype=torch.float16) 51 | assert 
out.dtype == torch.float16 52 | out = tensor_from_numpy(np.zeros((256, 256, 3), dtype=np.uint8), dtype=torch.uint8) 53 | assert out.dtype == torch.uint8 54 | -------------------------------------------------------------------------------- /tests/pytorch/test_models_preprocessor_pt.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | from doctr.models.preprocessor import PreProcessor 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "batch_size, output_size, input_tensor, expected_batches, expected_value", 10 | [ 11 | [2, (128, 128), np.full((3, 256, 128, 3), 255, dtype=np.uint8), 1, 0.5], # numpy uint8 12 | [2, (128, 128), np.ones((3, 256, 128, 3), dtype=np.float32), 1, 0.5], # numpy fp32 13 | [2, (128, 128), np.ones((3, 256, 128, 3), dtype=np.float16), 1, 0.5], # numpy fp16 14 | [2, (128, 128), [np.full((256, 128, 3), 255, dtype=np.uint8)] * 3, 2, 0.5], # list of numpy uint8 15 | [2, (128, 128), [np.ones((256, 128, 3), dtype=np.float32)] * 3, 2, 0.5], # list of numpy fp32 16 | [2, (128, 128), [np.ones((256, 128, 3), dtype=np.float16)] * 3, 2, 0.5], # list of numpy fp16 17 | ], 18 | ) 19 | def test_preprocessor(batch_size, output_size, input_tensor, expected_batches, expected_value): 20 | processor = PreProcessor(output_size, batch_size) 21 | 22 | # Invalid input type 23 | with pytest.raises(TypeError): 24 | processor(42) 25 | # 4D check 26 | with pytest.raises(AssertionError): 27 | processor(np.full((256, 128, 3), 255, dtype=np.uint8)) 28 | with pytest.raises(TypeError): 29 | processor(np.full((1, 256, 128, 3), 255, dtype=np.int32)) 30 | # 3D check 31 | with pytest.raises(AssertionError): 32 | processor([np.full((3, 256, 128, 3), 255, dtype=np.uint8)]) 33 | with pytest.raises(TypeError): 34 | processor([np.full((256, 128, 3), 255, dtype=np.int32)]) 35 | 36 | with torch.no_grad(): 37 | out = processor(input_tensor) 38 | assert isinstance(out, list) and len(out) == expected_batches 39 | assert all(isinstance(b, torch.Tensor) for b in out) 40 | assert all(b.dtype == torch.float32 for b in out) 41 | assert all(b.shape[-2:] == output_size for b in out) 42 | assert all(torch.all(b == expected_value) for b in out) 43 | assert len(repr(processor).split("\n")) == 4 44 | -------------------------------------------------------------------------------- /tests/pytorch/test_models_utils_pt.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import torch 5 | from torch import nn 6 | 7 | from doctr.models.utils import ( 8 | _bf16_to_float32, 9 | _copy_tensor, 10 | conv_sequence_pt, 11 | load_pretrained_params, 12 | set_device_and_dtype, 13 | ) 14 | 15 | 16 | def test_copy_tensor(): 17 | x = torch.rand(8) 18 | m = _copy_tensor(x) 19 | assert m.device == x.device and m.dtype == x.dtype and m.shape == x.shape and torch.allclose(m, x) 20 | 21 | 22 | def test_bf16_to_float32(): 23 | x = torch.randn([2, 2], dtype=torch.bfloat16) 24 | converted_x = _bf16_to_float32(x) 25 | assert x.dtype == torch.bfloat16 and converted_x.dtype == torch.float32 and torch.equal(converted_x, x.float()) 26 | 27 | 28 | def test_load_pretrained_params(tmpdir_factory): 29 | model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4)) 30 | # Retrieve this URL 31 | url = "https://github.com/mindee/doctr/releases/download/v0.2.1/tmp_checkpoint-6f0ce0e6.pt" 32 | # Temp cache dir 33 | cache_dir = tmpdir_factory.mktemp("cache") 34 | # Pass an incorrect hash 35 | with 
pytest.raises(ValueError): 36 | load_pretrained_params(model, url, "mywronghash", cache_dir=str(cache_dir)) 37 | # Let it resolve the hash from the file name 38 | load_pretrained_params(model, url, cache_dir=str(cache_dir)) 39 | # Check that the file was downloaded & the archive extracted 40 | assert os.path.exists(cache_dir.join("models").join(url.rpartition("/")[-1].split("&")[0])) 41 | # Default initialization 42 | load_pretrained_params(model, None) 43 | # Check ignore keys 44 | load_pretrained_params(model, url, cache_dir=str(cache_dir), ignore_keys=["2.weight"]) 45 | # non matching keys 46 | model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 1)) 47 | with pytest.raises(ValueError): 48 | load_pretrained_params(model, url, cache_dir=str(cache_dir), ignore_keys=["2.weight"]) 49 | 50 | 51 | def test_conv_sequence(): 52 | assert len(conv_sequence_pt(3, 8, kernel_size=3)) == 1 53 | assert len(conv_sequence_pt(3, 8, True, kernel_size=3)) == 2 54 | assert len(conv_sequence_pt(3, 8, False, True, kernel_size=3)) == 2 55 | assert len(conv_sequence_pt(3, 8, True, True, kernel_size=3)) == 3 56 | 57 | 58 | def test_set_device_and_dtype(): 59 | model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4)) 60 | batches = [torch.rand(8) for _ in range(2)] 61 | model, batches = set_device_and_dtype(model, batches, device="cpu", dtype=torch.float32) 62 | assert model[0].weight.device == torch.device("cpu") 63 | assert model[0].weight.dtype == torch.float32 64 | assert batches[0].device == torch.device("cpu") 65 | assert batches[0].dtype == torch.float32 66 | model, batches = set_device_and_dtype(model, batches, device="cpu", dtype=torch.float16) 67 | assert model[0].weight.dtype == torch.float16 68 | assert batches[0].dtype == torch.float16 69 | --------------------------------------------------------------------------------