├── tests ├── __init__.py ├── test_data │ ├── samples │ │ ├── empty_iocr.png │ │ ├── PHM.2013.page_30.png │ │ ├── page_with_list.png │ │ ├── page_with_table.png │ │ ├── ADS.2007.page_123.png │ │ └── empty_iocr.png.json │ ├── code_formula │ │ ├── images │ │ │ ├── code.png │ │ │ └── formula.png │ │ └── gt │ │ │ ├── formula.txt │ │ │ └── code.txt │ └── figure_classifier │ │ └── images │ │ ├── map.jpg │ │ └── bar_chart.jpg ├── test_listitem_marker_model.py ├── test_document_figure_classifier.py ├── test_common.py ├── test_layout_predictor.py ├── test_code_formula_predictor.py └── test_reading_order.py ├── docling_ibm_models ├── py.typed ├── __init__.py ├── layoutmodel │ ├── __init__.py │ ├── labels.py │ └── layout_predictor.py ├── tableformer │ ├── __init__.py │ ├── utils │ │ ├── __init__.py │ │ ├── mem_monitor.py │ │ ├── app_profiler.py │ │ └── utils.py │ ├── models │ │ ├── __init__.py │ │ ├── common │ │ │ ├── __init__.py │ │ │ └── base_model.py │ │ └── table04_rs │ │ │ ├── __init__.py │ │ │ ├── encoder04_rs.py │ │ │ ├── bbox_decoder_rs.py │ │ │ ├── transformer_rs.py │ │ │ └── tablemodel04_rs.py │ ├── data_management │ │ ├── __init__.py │ │ ├── transforms.py │ │ └── functional.py │ ├── settings.py │ └── common.py ├── reading_order │ └── __init__.py ├── code_formula_model │ ├── __init__.py │ ├── models │ │ ├── __init__.py │ │ ├── sam_opt_image_processor.py │ │ └── sam_opt.py │ └── code_formula_predictor.py ├── list_item_normalizer │ └── __init__.py └── document_figure_classifier_model │ ├── __init__.py │ └── document_figure_classifier_predictor.py ├── .github ├── dco.yml ├── workflows │ ├── ci.yml │ ├── pypi.yml │ ├── discord-release.yml │ ├── cd.yml │ ├── checks.yml │ └── dco-advisor.yml ├── mergify.yml ├── codecov.yml ├── PULL_REQUEST_TEMPLATE.md └── scripts │ └── release.sh ├── docs ├── tbm04.png └── tablemodel_overview_color.png ├── MAINTAINERS.md ├── .pre-commit-config.yaml ├── LICENSE ├── .gitignore ├── CONTRIBUTING.md ├── demo ├── demo_code_formula_predictor.py ├── demo_document_figure_classifier_predictor.py └── demo_layout_predictor.py ├── pyproject.toml ├── README.md └── CODE_OF_CONDUCT.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docling_ibm_models/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docling_ibm_models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docling_ibm_models/layoutmodel/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docling_ibm_models/tableformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docling_ibm_models/reading_order/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docling_ibm_models/tableformer/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /docling_ibm_models/code_formula_model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docling_ibm_models/list_item_normalizer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docling_ibm_models/tableformer/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docling_ibm_models/code_formula_model/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docling_ibm_models/tableformer/data_management/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docling_ibm_models/tableformer/models/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docling_ibm_models/document_figure_classifier_model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docling_ibm_models/tableformer/models/table04_rs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/dco.yml: -------------------------------------------------------------------------------- 1 | allowRemediationCommits: 2 | individual: true 3 | -------------------------------------------------------------------------------- /docs/tbm04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docling-project/docling-ibm-models/HEAD/docs/tbm04.png -------------------------------------------------------------------------------- /docs/tablemodel_overview_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docling-project/docling-ibm-models/HEAD/docs/tablemodel_overview_color.png -------------------------------------------------------------------------------- /tests/test_data/samples/empty_iocr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docling-project/docling-ibm-models/HEAD/tests/test_data/samples/empty_iocr.png -------------------------------------------------------------------------------- /tests/test_data/code_formula/images/code.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docling-project/docling-ibm-models/HEAD/tests/test_data/code_formula/images/code.png -------------------------------------------------------------------------------- /tests/test_data/samples/PHM.2013.page_30.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/docling-project/docling-ibm-models/HEAD/tests/test_data/samples/PHM.2013.page_30.png -------------------------------------------------------------------------------- /tests/test_data/samples/page_with_list.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docling-project/docling-ibm-models/HEAD/tests/test_data/samples/page_with_list.png -------------------------------------------------------------------------------- /tests/test_data/samples/page_with_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docling-project/docling-ibm-models/HEAD/tests/test_data/samples/page_with_table.png -------------------------------------------------------------------------------- /tests/test_data/samples/ADS.2007.page_123.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docling-project/docling-ibm-models/HEAD/tests/test_data/samples/ADS.2007.page_123.png -------------------------------------------------------------------------------- /tests/test_data/code_formula/images/formula.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docling-project/docling-ibm-models/HEAD/tests/test_data/code_formula/images/formula.png -------------------------------------------------------------------------------- /tests/test_data/figure_classifier/images/map.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docling-project/docling-ibm-models/HEAD/tests/test_data/figure_classifier/images/map.jpg -------------------------------------------------------------------------------- /tests/test_data/figure_classifier/images/bar_chart.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/docling-project/docling-ibm-models/HEAD/tests/test_data/figure_classifier/images/bar_chart.jpg -------------------------------------------------------------------------------- /tests/test_data/code_formula/gt/formula.txt: -------------------------------------------------------------------------------- 1 | E _ { n l } ( t ) = \frac { 1 } { 8 } \int _ { \varepsilon } ^ { 1 } d \rho ( K _ { 1 } \Psi ) ( K _ { 2 } \Psi ) - \frac { 1 } { 8 } \int _ { \varepsilon } ^ { 1 } d \rho ( K _ { 3 } \Psi ) ( K _ { 4 } \Psi ), -------------------------------------------------------------------------------- /tests/test_data/code_formula/gt/code.txt: -------------------------------------------------------------------------------- 1 | <_C++_> #include 2 | using namespace std; 3 | 4 | int main(){ 5 | 6 | int n; 7 | 8 | while(cin>>n, n){ 9 | int cnt=0; 10 | 11 | n=1000-n; 12 | cnt+=n/500; 13 | n%=500; 14 | cnt+=n/100; 15 | n%=100; 16 | cnt+=n/50; 17 | n%=50; 18 | cnt+=n/10; 19 | n%=10; 20 | cnt+=n/5; 21 | n%=5; 22 | 23 | cout<= 2" 17 | -------------------------------------------------------------------------------- /.github/codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | # https://docs.codecov.io/docs/comparing-commits 3 | allow_coverage_offsets: true 4 | coverage: 5 | status: 6 | project: 7 | default: 8 | informational: true 9 | target: auto # auto compares coverage to the previous base commit 10 | if_ci_failed: success 11 | flags: 12 | - docling-ibm-models 13 | 
comment: 14 | layout: "reach, diff, flags, files" 15 | behavior: default 16 | require_changes: false # if true: only post the comment if coverage changes 17 | branches: # branch names that can post comment 18 | - "main" 19 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 8 | 9 | 13 | 14 | **Checklist:** 15 | 16 | - [ ] Documentation has been updated, if necessary. 17 | - [ ] Examples have been added, if necessary. 18 | - [ ] Tests have been added, if necessary. 19 | 20 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | fail_fast: true 2 | repos: 3 | - repo: local 4 | hooks: 5 | - id: black 6 | name: Black 7 | entry: uv run --no-sync black docling_ibm_models 8 | pass_filenames: false 9 | language: system 10 | files: '\.py$' 11 | - id: isort 12 | name: isort 13 | entry: uv run --no-sync isort docling_ibm_models 14 | pass_filenames: false 15 | language: system 16 | files: '\.py$' 17 | - id: system 18 | name: MyPy 19 | entry: uv run --no-sync mypy docling_ibm_models 20 | pass_filenames: false 21 | language: system 22 | files: '\.py$' 23 | - repo: https://github.com/astral-sh/uv-pre-commit 24 | rev: 0.7.8 25 | hooks: 26 | - id: uv-lock 27 | -------------------------------------------------------------------------------- /.github/workflows/pypi.yml: -------------------------------------------------------------------------------- 1 | name: "Build and publish package" 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | env: 8 | UV_FROZEN: "1" 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | build-and-publish: 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: ['3.12'] 19 | environment: 20 | name: pypi 21 | url: https://pypi.org/p/docling-ibm-models 22 | permissions: 23 | id-token: write # IMPORTANT: mandatory for trusted publishing 24 | steps: 25 | - uses: actions/checkout@v4 26 | - name: Install uv and set the python version 27 | uses: astral-sh/setup-uv@v5 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | enable-cache: true 31 | - name: Install dependencies 32 | run: uv sync 33 | - name: Build package 34 | run: uv build 35 | - name: Publish distribution 📦 to PyPI 36 | uses: pypa/gh-action-pypi-publish@release/v1 37 | with: 38 | attestations: true 39 | -------------------------------------------------------------------------------- /docling_ibm_models/code_formula_model/models/sam_opt_image_processor.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp. 
2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | from PIL import Image 6 | from torchvision.transforms import functional as F 7 | from transformers import AutoImageProcessor 8 | from transformers.image_processing_utils import ImageProcessingMixin 9 | 10 | 11 | class SamOptImageProcessor(ImageProcessingMixin): 12 | 13 | def __init__(self, size=(1024, 1024), mean=None, std=None, **kwargs): 14 | super().__init__(**kwargs) 15 | self.size = size 16 | self.mean = mean 17 | self.std = std 18 | 19 | def __call__(self, image): 20 | if not isinstance(image, Image.Image): 21 | raise ValueError("Input must be a PIL Image") 22 | 23 | image = F.resize(image, self.size) 24 | image = F.to_tensor(image) 25 | 26 | image = F.normalize(image, mean=self.mean, std=self.std) 27 | 28 | return image 29 | 30 | 31 | AutoImageProcessor.register( 32 | config_class="SamOptImageProcessor", 33 | slow_image_processor_class=SamOptImageProcessor, 34 | ) 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 International Business Machines 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # Tmp files and directories 9 | stderr.* 10 | stdout.* 11 | *.tar 12 | test.sh 13 | OutputDecoder 14 | jobs.txt 15 | _std*.* 16 | tests/tmp/* 17 | runs/* 18 | *.onnx 19 | .DS_Store 20 | viz/ 21 | 22 | # VSCode 23 | .vscode 24 | 25 | # VIM 26 | *.swp 27 | *.swo 28 | *.bak 29 | 30 | # Environments 31 | .env 32 | .venv 33 | _venv/ 34 | env/ 35 | venv/ 36 | ENV/ 37 | env.bak/ 38 | venv.bak/ 39 | venv 40 | 41 | # Distribution / packaging 42 | .Python 43 | build/ 44 | develop-eggs/ 45 | dist/ 46 | downloads/ 47 | eggs/ 48 | .eggs/ 49 | lib64/ 50 | parts/ 51 | sdist/ 52 | var/ 53 | wheels/ 54 | *.egg-info/ 55 | .installed.cfg 56 | *.egg 57 | MANIFEST 58 | 59 | # checkpoint file for testing 60 | tests/test_data/model_artifacts/*.check 61 | tests/test_data/model_artifacts/*.json 62 | tests/test_data/model_artifacts/*.pt 63 | 64 | # test results 65 | tests/test_data/viz/ 66 | 67 | # Unit test / coverage reports 68 | htmlcov/ 69 | .tox/ 70 | .nox/ 71 | .coverage 72 | .coverage.* 73 | .cache 74 | nosetests.xml 75 | coverage.xml 76 | *.cover 77 | *.py,cover 78 | .hypothesis/ 79 | .pytest_cache/ 80 | cover/ 81 | -------------------------------------------------------------------------------- /.github/scripts/release.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e # trigger failure on error - do not remove! 4 | set -x # display command on output 5 | 6 | if [ -z "${TARGET_VERSION}" ]; then 7 | >&2 echo "No TARGET_VERSION specified" 8 | exit 1 9 | fi 10 | CHGLOG_FILE="${CHGLOG_FILE:-CHANGELOG.md}" 11 | 12 | # update package version 13 | uvx --from=toml-cli toml set --toml-path=pyproject.toml project.version "${TARGET_VERSION}" 14 | UV_FROZEN=0 uv lock --upgrade-package docling-ibm-models 15 | 16 | # collect release notes 17 | REL_NOTES=$(mktemp) 18 | uv run --no-sync semantic-release changelog --unreleased >> "${REL_NOTES}" 19 | 20 | # update changelog 21 | TMP_CHGLOG=$(mktemp) 22 | TARGET_TAG_NAME="v${TARGET_VERSION}" 23 | RELEASE_URL="$(gh repo view --json url -q ".url")/releases/tag/${TARGET_TAG_NAME}" 24 | printf "## [${TARGET_TAG_NAME}](${RELEASE_URL}) - $(date -Idate)\n\n" >> "${TMP_CHGLOG}" 25 | cat "${REL_NOTES}" >> "${TMP_CHGLOG}" 26 | if [ -f "${CHGLOG_FILE}" ]; then 27 | printf "\n" | cat - "${CHGLOG_FILE}" >> "${TMP_CHGLOG}" 28 | fi 29 | mv "${TMP_CHGLOG}" "${CHGLOG_FILE}" 30 | 31 | # push changes 32 | git config --global user.name 'github-actions[bot]' 33 | git config --global user.email 'github-actions[bot]@users.noreply.github.com' 34 | git add pyproject.toml uv.lock "${CHGLOG_FILE}" 35 | COMMIT_MSG="chore: bump version to ${TARGET_VERSION} [skip ci]" 36 | git commit -m "${COMMIT_MSG}" 37 | git push origin main 38 | 39 | # create GitHub release (incl. 
Git tag) 40 | gh release create "${TARGET_TAG_NAME}" -F "${REL_NOTES}" 41 | -------------------------------------------------------------------------------- /.github/workflows/discord-release.yml: -------------------------------------------------------------------------------- 1 | # .github/workflows/discord-release.yml 2 | name: Notify Discord on Release 3 | 4 | on: 5 | release: 6 | types: [published] 7 | 8 | jobs: 9 | discord: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Send release info to Discord 13 | env: 14 | DISCORD_WEBHOOK: ${{ secrets.RELEASES_DISCORD_WEBHOOK }} 15 | run: | 16 | REPO_NAME=${{ github.repository }} 17 | RELEASE_TAG=${{ github.event.release.tag_name }} 18 | RELEASE_NAME="${{ github.event.release.name }}" 19 | RELEASE_URL=${{ github.event.release.html_url }} 20 | 21 | # Capture the body safely (handles backticks, $, ", etc.) 22 | RELEASE_BODY=$(cat <<'EOF' 23 | ${{ github.event.release.body }} 24 | EOF 25 | ) 26 | 27 | # Fallback if release name is empty 28 | if [ -z "$RELEASE_NAME" ]; then 29 | RELEASE_NAME=$RELEASE_TAG 30 | fi 31 | 32 | PAYLOAD=$(jq -n \ 33 | --arg title "🚀 New Release: $RELEASE_NAME" \ 34 | --arg url "$RELEASE_URL" \ 35 | --arg desc "$RELEASE_BODY" \ 36 | --arg author_name "$REPO_NAME" \ 37 | --arg author_icon "https://github.com/docling-project.png" \ 38 | '{embeds: [{title: $title, url: $url, description: $desc, color: 5814783, author: {name: $author_name, icon_url: $author_icon}}]}') 39 | 40 | curl -H "Content-Type: application/json" \ 41 | -d "$PAYLOAD" \ 42 | "$DISCORD_WEBHOOK" 43 | -------------------------------------------------------------------------------- /tests/test_data/samples/empty_iocr.png.json: -------------------------------------------------------------------------------- 1 | { 2 | "doc_source_type": {}, 3 | "font_dist_info": {}, 4 | "info": { 5 | "histogram": { 6 | "mean-char-height": {}, 7 | "mean-char-width": {}, 8 | "number-of-chars": {} 9 | }, 10 | "styles": [] 11 | }, 12 | "title": "", 13 | "metadata": { 14 | "numPages": 1 15 | }, 16 | "pages": [ 17 | { 18 | "blocks": [], 19 | "cells": [], 20 | "height": 1612, 21 | "width": 1237, 22 | "dimensions": { 23 | "bbox": [ 24 | 0.0, 25 | 0.0, 26 | 1237, 27 | 1612 28 | ], 29 | "height": 1612, 30 | "origin": "TopLeft", 31 | "width": 1237 32 | }, 33 | "fonts": [], 34 | "links": [], 35 | "rotation": 0.0, 36 | "rectangles": [], 37 | "textPositions": [], 38 | "text_lines": [], 39 | "tokens": [], 40 | "localized_image_locations": [], 41 | "scanned_elements": [], 42 | "paths": [], 43 | "pageNumber": 1, 44 | "page_image": {}, 45 | "lang": [ 46 | "en", 47 | "pt", 48 | "fr", 49 | "it", 50 | "es", 51 | "fi" 52 | ] 53 | } 54 | ], 55 | "settings": {}, 56 | "passedHeadersFooters": { 57 | "headerFooters": { 58 | "1": { 59 | "headerHeight": 0, 60 | "footerHeight": 0 61 | } 62 | }, 63 | "headerFound": false, 64 | "footerFound": false 65 | }, 66 | "styles": [] 67 | } -------------------------------------------------------------------------------- /docling_ibm_models/layoutmodel/labels.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | 4 | class LayoutLabels: 5 | r"""Single point of reference for the layout labels""" 6 | 7 | def __init__(self) -> None: 8 | r""" """ 9 | # Canonical classes originating in DLNv2 10 | self._canonical: Dict[int, str] = { 11 | # DLNv1 and DLNv2 12 | 0: "Caption", 13 | 1: "Footnote", 14 | 2: "Formula", 15 | 3: "List-item", 16 | 4: "Page-footer", 17 | 5: "Page-header", 18 | 6: "Picture", 19 | 7: 
"Section-header", 20 | 8: "Table", 21 | 9: "Text", 22 | 10: "Title", 23 | # DLNv2 only 24 | 11: "Document Index", 25 | 12: "Code", 26 | 13: "Checkbox-Selected", 27 | 14: "Checkbox-Unselected", 28 | 15: "Form", 29 | 16: "Key-Value Region", 30 | } 31 | self._inverse_canonical: Dict[str, int] = { 32 | label: class_id for class_id, label in self._canonical.items() 33 | } 34 | 35 | # Shifted canonical classes with background in 0 36 | self._shifted_canonical: Dict[int, str] = {0: "Background"} 37 | for k, v in self._canonical.items(): 38 | self._shifted_canonical[k + 1] = v 39 | self._inverse_shifted_canonical: Dict[str, int] = { 40 | label: class_id for class_id, label in self._shifted_canonical.items() 41 | } 42 | 43 | def canonical_categories(self) -> Dict[int, str]: 44 | return self._canonical 45 | 46 | def canonical_to_int(self) -> Dict[str, int]: 47 | return self._inverse_canonical 48 | 49 | def shifted_canonical_categories(self) -> Dict[int, str]: 50 | return self._shifted_canonical 51 | 52 | def shifted_canonical_to_int(self) -> Dict[str, int]: 53 | return self._inverse_shifted_canonical 54 | -------------------------------------------------------------------------------- /.github/workflows/cd.yml: -------------------------------------------------------------------------------- 1 | name: "Run CD" 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | env: 7 | UV_FROZEN: "1" 8 | 9 | jobs: 10 | code-checks: 11 | uses: ./.github/workflows/checks.yml 12 | with: 13 | push_coverage: false 14 | pre-release-check: 15 | runs-on: ubuntu-latest 16 | outputs: 17 | TARGET_TAG_V: ${{ steps.version_check.outputs.TRGT_VERSION }} 18 | steps: 19 | - uses: actions/checkout@v4 20 | with: 21 | fetch-depth: 0 # for fetching tags, required for semantic-release 22 | - name: Install uv and set the python version 23 | uses: astral-sh/setup-uv@v5 24 | with: 25 | enable-cache: true 26 | - name: Install dependencies 27 | run: uv sync --only-dev 28 | - name: Check version of potential release 29 | id: version_check 30 | run: | 31 | TRGT_VERSION=$(uv run --no-sync semantic-release print-version) 32 | echo "TRGT_VERSION=${TRGT_VERSION}" >> "$GITHUB_OUTPUT" 33 | echo "${TRGT_VERSION}" 34 | - name: Check notes of potential release 35 | run: uv run --no-sync semantic-release changelog --unreleased 36 | release: 37 | needs: [code-checks, pre-release-check] 38 | if: needs.pre-release-check.outputs.TARGET_TAG_V != '' 39 | environment: auto-release 40 | runs-on: ubuntu-latest 41 | concurrency: release 42 | steps: 43 | - uses: actions/create-github-app-token@v1 44 | id: app-token 45 | with: 46 | app-id: ${{ vars.CI_APP_ID }} 47 | private-key: ${{ secrets.CI_PRIVATE_KEY }} 48 | - uses: actions/checkout@v4 49 | with: 50 | token: ${{ steps.app-token.outputs.token }} 51 | fetch-depth: 0 # for fetching tags, required for semantic-release 52 | - name: Install uv and set the python version 53 | uses: astral-sh/setup-uv@v5 54 | with: 55 | enable-cache: true 56 | - name: Install dependencies 57 | run: uv sync --only-dev 58 | - name: Run release script 59 | env: 60 | GH_TOKEN: ${{ steps.app-token.outputs.token }} 61 | TARGET_VERSION: ${{ needs.pre-release-check.outputs.TARGET_TAG_V }} 62 | CHGLOG_FILE: CHANGELOG.md 63 | run: ./.github/scripts/release.sh 64 | shell: bash 65 | -------------------------------------------------------------------------------- /docling_ibm_models/tableformer/models/table04_rs/encoder04_rs.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp. 
2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | import logging 6 | 7 | import torch.nn as nn 8 | import torchvision 9 | 10 | import docling_ibm_models.tableformer.settings as s 11 | 12 | LOG_LEVEL = logging.INFO 13 | # LOG_LEVEL = logging.DEBUG 14 | 15 | 16 | class Encoder04(nn.Module): 17 | """ 18 | Encoder based on resnet-18 19 | """ 20 | 21 | def __init__(self, enc_image_size, enc_dim=512): 22 | r""" 23 | Parameters 24 | ---------- 25 | enc_image_size : int 26 | Assuming that the encoded image is a square, this is the length of the image side 27 | """ 28 | 29 | super(Encoder04, self).__init__() 30 | self.enc_image_size = enc_image_size 31 | self._encoder_dim = enc_dim 32 | 33 | resnet = torchvision.models.resnet18() 34 | modules = list(resnet.children())[:-3] 35 | 36 | self._resnet = nn.Sequential(*modules) 37 | self._adaptive_pool = nn.AdaptiveAvgPool2d( 38 | (self.enc_image_size, self.enc_image_size) 39 | ) 40 | 41 | def _log(self): 42 | # Setup a custom logger 43 | return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL) 44 | 45 | def get_encoder_dim(self): 46 | return self._encoder_dim 47 | 48 | def forward(self, images): 49 | """ 50 | Forward propagation 51 | The encoder_dim 512 is decided by the structure of the image network (modified resnet-19) 52 | 53 | Parameters 54 | ---------- 55 | images : tensor (batch_size, image_channels, resized_image, resized_image) 56 | images input 57 | 58 | Returns 59 | ------- 60 | tensor : (batch_size, enc_image_size, enc_image_size, 256) 61 | encoded images 62 | """ 63 | out = self._resnet(images) # (batch_size, 256, 28, 28) 64 | self._log().debug("forward: resnet out: {}".format(out.size())) 65 | out = self._adaptive_pool(out) 66 | out = out.permute( 67 | 0, 2, 3, 1 68 | ) # (batch_size, enc_image_size, enc_image_size, 256) 69 | 70 | self._log().debug("enc forward: final out: {}".format(out.size())) 71 | 72 | return out 73 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Contributing In General 2 | Our project welcomes external contributions. If you have an itch, please feel 3 | free to scratch it. 4 | 5 | For more details on the contributing guidelines head to the Docling Project [community repository](https://github.com/docling-project/community). 6 | 7 | ## Developing 8 | 9 | ### Usage of uv 10 | 11 | We use [uv](https://docs.astral.sh/uv/) as package and project manager. 12 | 13 | #### Installation 14 | 15 | To install `uv`, check the documentation on [Installing uv](https://docs.astral.sh/uv/getting-started/installation/). 16 | 17 | #### Create an environment and sync it 18 | 19 | You can use the `uv sync` to create a project virtual environment (if it does not already exist) and sync 20 | the project's dependencies with the environment. 21 | 22 | ```bash 23 | uv sync 24 | ``` 25 | 26 | #### Use a specific Python version (optional) 27 | 28 | If you need to work with a specific version of Python, you can create a new virtual environment for that version 29 | and run the sync command: 30 | 31 | ```bash 32 | uv venv --python 3.12 33 | uv sync 34 | ``` 35 | 36 | More detailed options are described on the [Using Python environments](https://docs.astral.sh/uv/pip/environments/) documentation. 37 | 38 | #### Add a new dependency 39 | 40 | Simply use the `uv add` command. The `pyproject.toml` and `uv.lock` files will be updated. 
41 | 42 | ```bash 43 | uv add [OPTIONS] <PACKAGES> 44 | ``` 45 | 46 | ### Code style guidelines 47 | 48 | We use the following tools to enforce code style: 49 | 50 | - isort, to sort imports 51 | - Black, to format code 52 | - [MyPy](https://mypy.readthedocs.io), as a static type checker 53 | 54 | A set of styling checks, as well as regression tests, are defined and managed through the [pre-commit](https://pre-commit.com/) framework. To ensure that those scripts run automatically before a commit is finalized, install `pre-commit` on your local repository: 55 | 56 | ```bash 57 | uv run pre-commit install 58 | ``` 59 | 60 | To run the checks on-demand, type: 61 | 62 | ```bash 63 | uv run pre-commit run --all-files 64 | ``` 65 | 66 | Note: Checks like `Black` and `isort` will _fail_ if they modify files. This is because `pre-commit` doesn't like to see files modified by their hooks. In these cases, `git add` the modified files and `git commit` again. 67 | 68 | -------------------------------------------------------------------------------- /tests/test_listitem_marker_model.py: -------------------------------------------------------------------------------- 1 | from docling_core.types.doc.document import DoclingDocument, ListItem, ProvenanceItem 2 | from docling_core.types.doc.base import BoundingBox, CoordOrigin 3 | 4 | from docling_core.types.doc.labels import DocItemLabel 5 | 6 | from docling_ibm_models.list_item_normalizer.list_marker_processor import ListItemMarkerProcessor 7 | 8 | # Example usage and testing 9 | def test_listitem_marker_model(): 10 | """Example of how to use the ListItemMarkerProcessor.""" 11 | 12 | # Create a sample document 13 | doc = DoclingDocument(name="Sample Document") 14 | 15 | doc.add_text( 16 | label=DocItemLabel.TEXT, 17 | text="• Second item with bullet and content", # Marker and content together 18 | prov=ProvenanceItem( 19 | page_no=0, 20 | bbox=BoundingBox(l=0, t=15, r=200, b=25, coord_origin=CoordOrigin.TOPLEFT), 21 | charspan=(0, 37) 22 | ) 23 | ) 24 | 25 | doc.add_list_item( 26 | text="• Third item with bullet and content", # Marker and content together 27 | prov=ProvenanceItem( 28 | page_no=0, 29 | bbox=BoundingBox(l=0, t=15, r=200, b=25, coord_origin=CoordOrigin.TOPLEFT), 30 | charspan=(0, 37) 31 | ) 32 | ) 33 | 34 | # Add some sample text items that should be converted to list items 35 | doc.add_text( 36 | label=DocItemLabel.TEXT, 37 | text="1.", # Marker only 38 | prov=ProvenanceItem( 39 | page_no=0, 40 | bbox=BoundingBox(l=0, t=0, r=10, b=10, coord_origin=CoordOrigin.TOPLEFT), 41 | charspan=(0, 2) 42 | ) 43 | ) 44 | 45 | doc.add_text( 46 | label=DocItemLabel.TEXT, 47 | text="First item content", # Content only 48 | prov=ProvenanceItem( 49 | page_no=0, 50 | bbox=BoundingBox(l=15, t=0, r=100, b=10, coord_origin=CoordOrigin.TOPLEFT), 51 | charspan=(0, 18) 52 | ) 53 | ) 54 | 55 | # Process the document 56 | processor = ListItemMarkerProcessor() 57 | processed_doc = processor.process_document(doc, merge_items=True) 58 | 59 | # print(" ---------- document: \n", processed_doc.export_to_markdown(), "\n ---------- \n") 60 | 61 | assert len(processed_doc.texts)==3, "len(processed_doc.texts)==3" 62 | 63 | assert processed_doc.texts[0].text=="• Second item with bullet and content" 64 | 65 | assert isinstance(processed_doc.texts[1], ListItem) 66 | assert processed_doc.texts[1].text=="Third item with bullet and content" 67 | assert processed_doc.texts[1].marker=="•" 68 | assert not processed_doc.texts[1].enumerated 69 | 70 | assert isinstance(processed_doc.texts[2],
ListItem) 71 | assert processed_doc.texts[2].label==DocItemLabel.LIST_ITEM 72 | assert processed_doc.texts[2].text=="First item content" 73 | assert processed_doc.texts[2].marker=="1." 74 | assert processed_doc.texts[2].enumerated 75 | -------------------------------------------------------------------------------- /tests/test_document_figure_classifier.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp. 2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | import os 6 | import numpy as np 7 | import pytest 8 | from PIL import Image 9 | 10 | from docling_ibm_models.document_figure_classifier_model.document_figure_classifier_predictor import ( 11 | DocumentFigureClassifierPredictor, 12 | ) 13 | 14 | from huggingface_hub import snapshot_download 15 | 16 | 17 | @pytest.fixture(scope="module") 18 | def init() -> dict: 19 | r""" 20 | Initialize the testing environment 21 | """ 22 | init = { 23 | "num_threads": 1, 24 | "test_imgs": [ 25 | { 26 | "label": "bar_chart", 27 | "image_path": "tests/test_data/figure_classifier/images/bar_chart.jpg", 28 | }, 29 | { 30 | "label": "map", 31 | "image_path": "tests/test_data/figure_classifier/images/map.jpg", 32 | }, 33 | ], 34 | "info": { 35 | "device": "auto", 36 | }, 37 | } 38 | 39 | # Download models from HF 40 | init["artifact_path"] = snapshot_download( 41 | repo_id="ds4sd/DocumentFigureClassifier", revision="v1.0.0" 42 | ) 43 | 44 | return init 45 | 46 | 47 | def test_figure_classifier(init: dict): 48 | r""" 49 | Unit test for the DocumentFigureClassifierPredictor 50 | """ 51 | device = "cpu" 52 | num_threads = 2 53 | 54 | # Initialize DocumentFigureClassifierPredictor 55 | figure_classifier = DocumentFigureClassifierPredictor( 56 | init["artifact_path"], device=device, num_threads=num_threads 57 | ) 58 | 59 | # Check info 60 | info = figure_classifier.info() 61 | assert info["device"] == device, "Wrongly set device" 62 | assert info["num_threads"] == num_threads, "Wrongly set number of threads" 63 | 64 | # Unsupported input image 65 | is_exception = False 66 | try: 67 | for _ in figure_classifier.predict(["wrong"]): 68 | pass 69 | except TypeError: 70 | is_exception = True 71 | assert is_exception 72 | 73 | # Predict on test images, not batched 74 | for d in init["test_imgs"]: 75 | label = d["label"] 76 | img_path = d["image_path"] 77 | 78 | with Image.open(img_path) as img: 79 | 80 | output = figure_classifier.predict([img]) 81 | predicted_class = output[0][0][0] 82 | 83 | assert predicted_class == label 84 | 85 | # Load images as numpy arrays 86 | np_arr = np.asarray(img) 87 | output = figure_classifier.predict([np_arr]) 88 | predicted_class = output[0][0][0] 89 | 90 | assert predicted_class == label 91 | 92 | # Predict on test images, batched 93 | labels = [d['label'] for d in init["test_imgs"]] 94 | images = [Image.open(d["image_path"]) for d in init["test_imgs"]] 95 | 96 | outputs = figure_classifier.predict(images) 97 | outputs = [output[0][0] for output in outputs] 98 | assert outputs == labels 99 | -------------------------------------------------------------------------------- /tests/test_common.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp.
2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | import json 6 | import tempfile 7 | 8 | import docling_ibm_models.tableformer.common as c 9 | 10 | 11 | test_config_a = { 12 | "base_dir": "./tests/test_data/", 13 | "curr_dir": "./tests/test_data/test_common/", 14 | "data_top_dir": "./tests/test_data/", 15 | "dataset": { 16 | "name": ["PhysRevB"], 17 | "limit": 10, 18 | "split": {"test": 0.2, "train": 0.5, "evaluate": 0.3}, 19 | }, 20 | "features": { 21 | "name": "Data2Features03b", 22 | "parameters": { 23 | "normalize_features": True, 24 | "normalize_features_method": "Z-Score", 25 | }, 26 | }, 27 | } 28 | 29 | 30 | test_config_b = {"preparation": {"max_tag_len": 300}, "model": {"seq_len": 30}} 31 | 32 | test_config_c = {"preparation": {"max_tag_len": 300}, "model": {"seq_len": 302}} 33 | 34 | test_config_d = {"preparation": {"max_tag_len": 300}, "model": {"seq_len": 303}} 35 | 36 | 37 | def test_safe_get_parameters(): 38 | val = c.safe_get_parameter(None, None, 10) 39 | assert val == 10, "Failed with null objects" 40 | 41 | index_path = ["features", "parameters", "normalize_features_method"] 42 | val = c.safe_get_parameter(test_config_a, index_path, None) 43 | assert val == "Z-Score", "Cannot find existing parameter" 44 | 45 | index_path = ["features", "parameters", "wrong"] 46 | val = c.safe_get_parameter(test_config_a, index_path, "hello") 47 | assert val == "hello", "Default value should be here" 48 | 49 | index_path = ["features", "wrong", "normalize_features_method"] 50 | val = c.safe_get_parameter(test_config_a, index_path, 10) 51 | assert val == 10, "Default value should be here" 52 | 53 | index_path = ["model", "parameters", "normalize_features_method"] 54 | val = c.safe_get_parameter(test_config_a, index_path, "hello") 55 | assert val == "hello", "Default value should be here" 56 | 57 | # Test exception throwing 58 | exRaised = False 59 | try: 60 | index_path = ["missing"] 61 | val = c.safe_get_parameter(test_config_a, index_path, required=True) 62 | except ValueError: 63 | exRaised = True 64 | assert exRaised, "Exception should had been raised here" 65 | 66 | 67 | def test_config_validation(): 68 | configs = [test_config_b, test_config_c, test_config_d] 69 | 70 | for i, config in enumerate(configs): 71 | try: 72 | val = c.validate_config(config) 73 | if i == 0 or i == 1: 74 | assert val, "Valid configuration didn't pass the validation test" 75 | except AssertionError: 76 | assert i == 2, "Configuration validation error" 77 | 78 | def test_read_config(): 79 | r""" 80 | Testing the read_config() function 81 | """ 82 | with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as fp: 83 | # Write a tmp file 84 | json.dump(test_config_a, fp) 85 | fp.close() 86 | 87 | # Read the tmp file and extract the config 88 | config = c.read_config(fp.name) 89 | assert isinstance(config, dict) 90 | -------------------------------------------------------------------------------- /docling_ibm_models/tableformer/data_management/transforms.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp. 2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | from __future__ import division 6 | 7 | import collections 8 | import numbers 9 | import random 10 | 11 | import torch 12 | 13 | from docling_ibm_models.tableformer.data_management import functional as F 14 | 15 | 16 | class Normalize(object): 17 | """Normalize a tensor image with mean and standard deviation. 
18 | Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform 19 | will normalize each channel of the input ``torch.*Tensor`` i.e. 20 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 21 | Args: 22 | mean (sequence): Sequence of means for each channel. 23 | std (sequence): Sequence of standard deviations for each channel. 24 | """ 25 | 26 | def __init__(self, mean, std): 27 | self.mean = mean 28 | self.std = std 29 | 30 | def __call__(self, tensor, target=None): 31 | """ 32 | Args: 33 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 34 | Returns: 35 | Tensor: Normalized Tensor image. 36 | """ 37 | return F.normalize(tensor, self.mean, self.std), target 38 | 39 | def __repr__(self): 40 | return self.__class__.__name__ + "(mean={0}, std={1})".format( 41 | self.mean, self.std 42 | ) 43 | 44 | 45 | class Resize(object): 46 | """Resize the input PIL Image to the given size. 47 | Args: 48 | size (sequence or int): Desired output size. If size is a sequence like 49 | (h, w), output size will be matched to this. If size is an int, 50 | smaller edge of the image will be matched to this number. 51 | i.e, if height > width, then image will be rescaled to 52 | (size * height / width, size) 53 | interpolation (int, optional): Desired interpolation. Default is 54 | ``BILINEAR`` 55 | """ 56 | 57 | def __init__(self, size, interpolation="BILINEAR"): 58 | self.size = size 59 | self.interpolation = interpolation 60 | 61 | def __call__(self, img, target=None): 62 | """ 63 | Args: 64 | img (np.ndarray): Image to be scaled. 65 | Returns: 66 | np.ndarray: Rescaled image. 67 | """ 68 | # Resize bboxes (in pixels) 69 | x_scale = 0 70 | y_scale = 0 71 | 72 | if img.shape[1] > 0: 73 | x_scale = self.size[0] / img.shape[1] 74 | if img.shape[0] > 0: 75 | y_scale = self.size[1] / img.shape[0] 76 | 77 | # loop over bboxes 78 | if target is not None: 79 | if target["boxes"] is not None: 80 | target_ = target.copy() 81 | target_["boxes"][:, 0] = x_scale * target_["boxes"][:, 0] 82 | target_["boxes"][:, 1] = y_scale * target_["boxes"][:, 1] 83 | target_["boxes"][:, 2] = x_scale * target_["boxes"][:, 2] 84 | target_["boxes"][:, 3] = y_scale * target_["boxes"][:, 3] 85 | return F.resize(img, self.size, self.interpolation), target 86 | 87 | def __repr__(self): 88 | interpolate_str = self.interpolation 89 | return self.__class__.__name__ + "(size={0}, interpolation={1})".format( 90 | self.size, interpolate_str 91 | ) 92 | -------------------------------------------------------------------------------- /docling_ibm_models/tableformer/settings.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp. 2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | import logging 6 | import sys 7 | 8 | 9 | def get_custom_logger(logger_name, level, stream=sys.stdout): 10 | r""" 11 | Create a custom logger with a standard formatting 12 | 13 | Inputs: 14 | - logger_name: Name of the logger. You can get the class name as self.__class__.__name__ 15 | - level: logging level (e.g. logging.INFO, logging.DEBUG, etc.) 
16 | - stream: One of sys.stdout or sys.stderr 17 | 18 | Outputs: 19 | logger 20 | """ 21 | logger = logging.getLogger(logger_name) 22 | logger.setLevel(level) 23 | 24 | # Set the handler 25 | if not logger.hasHandlers(): 26 | handler = logging.StreamHandler(stream) 27 | formatter = logging.Formatter( 28 | "%(asctime)s %(name)-12s %(levelname)-8s %(message)s" 29 | ) 30 | handler.setFormatter(formatter) 31 | logger.addHandler(handler) 32 | 33 | return logger 34 | 35 | 36 | ################################################################################### 37 | # System constants 38 | # 39 | 40 | r""" 41 | This is a "generic" logger available to all scripts. 42 | It is encouraged that each class has it's own custom logger with the name of the class. 43 | You can use the "get_custom_logger" function to build a custom logger with a standard format. 44 | """ 45 | LOGGER = get_custom_logger("docling-pm", logging.INFO) 46 | 47 | # Supported dataset types 48 | supported_datasets = ["TF_prepared"] # TF prepared dataset 49 | 50 | # Split names 51 | TRAIN_SPLIT = "train" 52 | VAL_SPLIT = "val" 53 | TEST_SPLIT = "test" 54 | 55 | # Prepared data parts and filename templates 56 | PREPARED_DATA_PARTS = { 57 | # Array with the bboxes (x1y1x2y2) for all cells of the images across all splits. 58 | # The bboxes are indexed with the filename. 59 | # Notices: 60 | # - The bboxes are NOT transformed. 61 | # - If the image filenames are the same across splits, there will be one one entry in the file 62 | "BBOXES": "BBOXES.json", 63 | # Image filenames used for train and val 64 | "IMAGES": "IMAGES.json", 65 | # Mean, std, variance as arrays of 3 (for each color) 66 | "STATISTICS": "STATISTICS_.json", # PRECOMPUTED 67 | # Bboxes of the cells in the form [1, x1, x2, y1, y2] or [0, 0, 0, 0, 0] in case of no box. 68 | "TRAIN_CELLBBOXES": "TRAIN_CELLBBOXES_.json", # NOT USED. 69 | # Array with arrays of the length + 2 of the original cells per image. 70 | "TRAIN_CELLLENS": "TRAIN_CELLLENS_.json", 71 | # Indices of the cells between and at the end. 72 | "TRAIN_CELLS": "TRAIN_CELLS_.json", 73 | # Array with the length + 2 of the original tags per image. 74 | "TRAIN_TAGLENS": "TRAIN_TAGLENS_.json", 75 | # Indices of the tags between and at the end. 76 | "TRAIN_TAGS": "TRAIN_TAGS_.json", 77 | # Ground truth for the evaluation dataset per eval image. 78 | "VAL": "VAL.json", 79 | # Vocabulary: Indices of the word_map_cells and word_map_tags 80 | "WORDMAP": "WORDMAP_.json", # PRECOMPUTED 81 | } 82 | 83 | # Purposes 84 | TRAIN_PURPOSE = "train" 85 | VAL_PURPOSE = "val" 86 | TEST_PURPOSE = "test" 87 | PREDICT_PURPOSE = "predict" 88 | 89 | # The DDP world size when we train in CPU with DDP enabled 90 | DDP_CPU_WORLD_SIZE = 2 91 | -------------------------------------------------------------------------------- /.github/workflows/checks.yml: -------------------------------------------------------------------------------- 1 | on: 2 | workflow_call: 3 | inputs: 4 | push_coverage: 5 | type: boolean 6 | description: "If true, the coverage results are pushed to codecov.io." 
7 | default: true 8 | secrets: 9 | CODECOV_TOKEN: 10 | required: false 11 | 12 | env: 13 | HF_HUB_DOWNLOAD_TIMEOUT: "60" 14 | HF_HUB_ETAG_TIMEOUT: "60" 15 | UV_FROZEN: "1" 16 | 17 | jobs: 18 | run-checks: 19 | runs-on: ubuntu-latest 20 | strategy: 21 | matrix: 22 | python-version: ['3.9', '3.10', '3.11', '3.12', '3.13', '3.14'] 23 | steps: 24 | - uses: actions/checkout@v4 25 | - name: Cache Hugging Face models 26 | uses: actions/cache@v4 27 | with: 28 | path: ~/.cache/huggingface 29 | key: huggingface-cache-py${{ matrix.python-version }} 30 | - name: Install uv and set the python version 31 | uses: astral-sh/setup-uv@v5 32 | with: 33 | python-version: ${{ matrix.python-version }} 34 | enable-cache: true 35 | - name: pre-commit cache key 36 | run: echo "PY=$(python -VV | sha256sum | cut -d' ' -f1)" >> "$GITHUB_ENV" 37 | - uses: actions/cache@v4 38 | with: 39 | path: ~/.cache/pre-commit 40 | key: pre-commit|${{ env.PY }}|${{ hashFiles('.pre-commit-config.yaml') }} 41 | - name: Install dependencies 42 | run: uv sync --frozen 43 | - name: Check code quality and consistency 44 | run: pre-commit run --all-files 45 | - name: Run tests 46 | run: | 47 | uv run --no-sync pytest -v --cov=docling_ibm_models --cov-report=xml tests 48 | - name: Upload coverage to Codecov 49 | if: inputs.push_coverage 50 | uses: codecov/codecov-action@v5 51 | with: 52 | token: ${{ secrets.CODECOV_TOKEN }} 53 | files: ./coverage.xml 54 | 55 | build-package: 56 | runs-on: ubuntu-latest 57 | strategy: 58 | matrix: 59 | python-version: ['3.12'] 60 | steps: 61 | - uses: actions/checkout@v4 62 | - name: Install uv and set the python version 63 | uses: astral-sh/setup-uv@v5 64 | with: 65 | python-version: ${{ matrix.python-version }} 66 | enable-cache: true 67 | - name: Install dependencies 68 | run: uv sync 69 | - name: Build package 70 | run: uv build 71 | - name: Check content of wheel 72 | run: unzip -l dist/*.whl 73 | - name: Store the distribution packages 74 | uses: actions/upload-artifact@v4 75 | with: 76 | name: python-package-distributions 77 | path: dist/ 78 | 79 | test-package: 80 | needs: 81 | - build-package 82 | runs-on: ubuntu-latest 83 | strategy: 84 | matrix: 85 | python-version: ['3.12'] 86 | steps: 87 | - name: Download all the dists 88 | uses: actions/download-artifact@v4 89 | with: 90 | name: python-package-distributions 91 | path: dist/ 92 | - name: Install uv and set the python version 93 | uses: astral-sh/setup-uv@v5 94 | with: 95 | python-version: ${{ matrix.python-version }} 96 | enable-cache: true 97 | - name: Install package 98 | run: uv pip install dist/*.whl 99 | - name: Test a simple import 100 | run: python -c 'from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor' 101 | -------------------------------------------------------------------------------- /docling_ibm_models/tableformer/data_management/functional.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp. 
2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | import numbers 6 | from collections.abc import Iterable, Sequence 7 | 8 | import cv2 9 | import numpy as np 10 | import torch 11 | from torchvision.transforms import functional 12 | 13 | cv2.setNumThreads(0) 14 | cv2.ocl.setUseOpenCL(False) 15 | 16 | INTER_MODE = { 17 | "NEAREST": cv2.INTER_NEAREST, 18 | "BILINEAR": cv2.INTER_LINEAR, 19 | "BICUBIC": cv2.INTER_CUBIC, 20 | } 21 | 22 | PAD_MOD = { 23 | "constant": cv2.BORDER_CONSTANT, 24 | "edge": cv2.BORDER_REPLICATE, 25 | "reflect": cv2.BORDER_DEFAULT, 26 | "symmetric": cv2.BORDER_REFLECT, 27 | } 28 | 29 | 30 | def _is_tensor_image(img): 31 | return torch.is_tensor(img) and img.ndimension() == 3 32 | 33 | 34 | def _is_numpy_image(img): 35 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) 36 | 37 | 38 | def normalize(tensor, mean, std): 39 | """Normalize a tensor image with mean and standard deviation. 40 | See ``Normalize`` for more details. 41 | Args: 42 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 43 | mean (sequence): Sequence of means for each channel. 44 | std (sequence): Sequence of standard deviations for each channely. 45 | Returns: 46 | Tensor: Normalized Tensor image. 47 | """ 48 | if _is_tensor_image(tensor): 49 | for t, m, s in zip(tensor, mean, std, strict=False): 50 | t.sub_(m).div_(s) 51 | return tensor 52 | elif _is_numpy_image(tensor): 53 | return (tensor.astype(np.float32) - 255.0 * np.array(mean)) / np.array(std) 54 | else: 55 | raise RuntimeError("Undefined type") 56 | 57 | 58 | def resize(img, size, interpolation="BILINEAR"): 59 | """Resize the input CV Image to the given size. 60 | Args: 61 | img (np.ndarray): Image to be resized. 62 | size (tuple or int): Desired output size. If size is a sequence like 63 | (h, w), the output size will be matched to this. If size is an int, 64 | the smaller edge of the image will be matched to this number maintaing 65 | the aspect ratio. i.e, if height > width, then image will be rescaled to 66 | (size * height / width, size) 67 | interpolation (str, optional): Desired interpolation. Default is ``BILINEAR`` 68 | Returns: 69 | cv Image: Resized image. 70 | """ 71 | if not _is_numpy_image(img): 72 | raise TypeError("img should be CV Image. Got {}".format(type(img))) 73 | if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)): 74 | raise TypeError("Got inappropriate size arg: {}".format(size)) 75 | 76 | # TODO(Nikos): Try to remove the opencv dependency 77 | if isinstance(size, int): 78 | h, w, c = img.shape 79 | if (w <= h and w == size) or (h <= w and h == size): 80 | return img 81 | if w < h: 82 | ow = size 83 | oh = int(size * h / w) 84 | return cv2.resize( 85 | img, dsize=(ow, oh), interpolation=INTER_MODE[interpolation] 86 | ) 87 | else: 88 | oh = size 89 | ow = int(size * w / h) 90 | return cv2.resize( 91 | img, dsize=(ow, oh), interpolation=INTER_MODE[interpolation] 92 | ) 93 | else: 94 | oh, ow = size 95 | return cv2.resize( 96 | img, dsize=(int(ow), int(oh)), interpolation=INTER_MODE[interpolation] 97 | ) 98 | -------------------------------------------------------------------------------- /tests/test_layout_predictor.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp. 
2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | import os 6 | import json 7 | 8 | import numpy as np 9 | import pytest 10 | from huggingface_hub import snapshot_download 11 | from PIL import Image 12 | 13 | from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor 14 | 15 | 16 | @pytest.fixture(scope="module") 17 | def init() -> dict: 18 | r""" 19 | Initialize the testing environment 20 | """ 21 | # This config is missing the keys: "artifact_path", "info1.torch_file", "info2.torch_file" 22 | init = { 23 | "num_threads": 1, 24 | "test_imgs": [ 25 | "tests/test_data/samples/ADS.2007.page_123.png", 26 | ], 27 | "info1": { 28 | "device": "cpu", 29 | "image_size": 640, 30 | "threshold": 0.6, 31 | }, 32 | "pred_bboxes": 9, 33 | } 34 | 35 | # Download models from HF 36 | artifact_path = snapshot_download(repo_id="ds4sd/docling-layout-old") 37 | 38 | # Add the missing config keys 39 | init["artifact_path"] = artifact_path 40 | 41 | return init 42 | 43 | 44 | def test_layoutpredictor(init: dict): 45 | r""" 46 | Unit test for the LayoutPredictor 47 | """ 48 | device = "cpu" 49 | num_threads = 2 50 | 51 | # Initialize LayoutPredictor 52 | lpredictor = LayoutPredictor( 53 | init["artifact_path"], device=device, num_threads=num_threads 54 | ) 55 | 56 | # Check info 57 | info = lpredictor.info() 58 | assert info["device"] == device, "Wronly set device" 59 | assert info["num_threads"] == num_threads, "Wronly set number of threads" 60 | 61 | # Unsupported input image 62 | is_exception = False 63 | try: 64 | for pred in lpredictor.predict("wrong"): 65 | pass 66 | except TypeError: 67 | is_exception = True 68 | assert is_exception 69 | 70 | # Predict on the test image 71 | for img_fn in init["test_imgs"]: 72 | 73 | true_layout_fn = img_fn+".json" 74 | with Image.open(img_fn) as img: 75 | 76 | w, h = img.size 77 | 78 | # Load images as PIL objects 79 | for i, pred in enumerate(lpredictor.predict(img)): 80 | print("PIL pred: {}".format(pred)) 81 | assert pred["l"] >= 0 and pred["l"] <= w 82 | assert pred["t"] >= 0 and pred["t"] <= h 83 | assert pred["r"] >= 0 and pred["r"] <= w 84 | assert pred["b"] >= 0 and pred["b"] <= h 85 | 86 | assert i + 1 == init["pred_bboxes"] 87 | 88 | if os.path.exists(true_layout_fn): 89 | with open(true_layout_fn, "r") as fr: 90 | true_layout = json.load(fr) 91 | 92 | """ 93 | # FIXME: write a simple test to check all objects are found 94 | else: 95 | with open(true_layout_fn, "w") as fw: 96 | fw.write(json.dumps(pred_layout, indent=4)) 97 | """ 98 | 99 | # Load images as numpy arrays 100 | np_arr = np.asarray(img) 101 | for i, pred in enumerate(lpredictor.predict(np_arr)): 102 | print("numpy pred: {}".format(pred)) 103 | assert pred["l"] >= 0 and pred["l"] <= w 104 | assert pred["t"] >= 0 and pred["t"] <= h 105 | assert pred["r"] >= 0 and pred["r"] <= w 106 | assert pred["b"] >= 0 and pred["b"] <= h 107 | assert i + 1 == init["pred_bboxes"] 108 | -------------------------------------------------------------------------------- /demo/demo_code_formula_predictor.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp. 
2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | import argparse 6 | import logging 7 | import os 8 | import sys 9 | import time 10 | from pathlib import Path 11 | 12 | from huggingface_hub import snapshot_download 13 | from PIL import Image 14 | 15 | from docling_ibm_models.code_formula_model.code_formula_predictor import CodeFormulaPredictor 16 | 17 | 18 | def demo( 19 | logger: logging.Logger, 20 | artifact_path: str, 21 | device: str, 22 | num_threads: int, 23 | image_dir: str, 24 | viz_dir: str, 25 | ): 26 | r""" 27 | Apply LayoutPredictor on the input image directory 28 | 29 | If you want to load from PDF: 30 | pdf_image = pyvips.Image.new_from_file("test_data/ADS.2007.page_123.pdf", page=0) 31 | """ 32 | # Create the layout predictor 33 | code_formula_predictor = CodeFormulaPredictor(artifact_path, device=device, num_threads=num_threads) 34 | 35 | image_dir = Path(image_dir) 36 | images = [] 37 | image_names = os.listdir(image_dir) 38 | image_names.sort() 39 | for image_name in image_names: 40 | image = Image.open(image_dir / image_name) 41 | images.append(image) 42 | 43 | t0 = time.perf_counter() 44 | outputs = code_formula_predictor.predict(images, ['code', 'formula'], temperature=0) 45 | total_ms = 1000 * (time.perf_counter() - t0) 46 | avg_ms = (total_ms / len(image_names)) if len(image_names) > 0 else 0 47 | logger.info( 48 | "For {} images(ms): [total|avg] = [{:.1f}|{:.1f}]".format( 49 | len(image_names), total_ms, avg_ms 50 | ) 51 | ) 52 | 53 | for i, output in enumerate(outputs): 54 | logger.info(f"\nOutput {i}:\n{output}\n\n") 55 | 56 | 57 | def main(args): 58 | num_threads = int(args.num_threads) if args.num_threads is not None else None 59 | device = args.device.lower() 60 | image_dir = args.image_dir 61 | viz_dir = args.viz_dir 62 | 63 | # Initialize logger 64 | logging.basicConfig(level=logging.DEBUG) 65 | logger = logging.getLogger("CodeFormulaPredictor") 66 | logger.setLevel(logging.DEBUG) 67 | if not logger.hasHandlers(): 68 | handler = logging.StreamHandler(sys.stdout) 69 | formatter = logging.Formatter( 70 | "%(asctime)s %(name)-12s %(levelname)-8s %(message)s" 71 | ) 72 | handler.setFormatter(formatter) 73 | logger.addHandler(handler) 74 | 75 | # Ensure the viz dir 76 | Path(viz_dir).mkdir(parents=True, exist_ok=True) 77 | 78 | # Download models from HF 79 | download_path = snapshot_download(repo_id="ds4sd/CodeFormula", revision="v1.0.0") 80 | 81 | # Test the Code+Equation model 82 | demo(logger, download_path, device, num_threads, image_dir, viz_dir) 83 | 84 | 85 | if __name__ == "__main__": 86 | r""" 87 | python -m demo.demo_code_formula_predictor -i 88 | """ 89 | parser = argparse.ArgumentParser(description="Test the CodeFormulaPredictor") 90 | parser.add_argument( 91 | "-d", "--device", required=False, default="cpu", help="One of [cpu, cuda, mps]" 92 | ) 93 | parser.add_argument( 94 | "-n", "--num_threads", required=False, default=4, help="Number of threads" 95 | ) 96 | parser.add_argument( 97 | "-i", 98 | "--image_dir", 99 | required=True, 100 | help="PNG images input directory", 101 | ) 102 | parser.add_argument( 103 | "-v", 104 | "--viz_dir", 105 | required=False, 106 | default="viz/", 107 | help="Directory to save prediction visualizations", 108 | ) 109 | 110 | args = parser.parse_args() 111 | main(args) 112 | -------------------------------------------------------------------------------- /docling_ibm_models/tableformer/common.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp. 
2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | import argparse 6 | import json 7 | import logging 8 | import os 9 | 10 | import torch 11 | 12 | import docling_ibm_models.tableformer.settings as s 13 | from docling_ibm_models.tableformer.models.common.base_model import BaseModel 14 | 15 | LOG_LEVEL = logging.DEBUG 16 | logger = s.get_custom_logger("common", LOG_LEVEL) 17 | 18 | 19 | def validate_config(config): 20 | r""" 21 | Validate the provided configuration file. 22 | A ValueError exception will be thrown in case the config file is invalid 23 | 24 | Parameters 25 | ---------- 26 | config : dictionary 27 | Configuration for the tablemodel 28 | 29 | Returns 30 | ------- 31 | bool : True on success 32 | """ 33 | if "model" not in config: 34 | return True 35 | if "preparation" not in config: 36 | return True 37 | assert ( 38 | "max_tag_len" in config["preparation"] 39 | ), "Config error: 'preparation.max_tag_len' parameter is missing" 40 | if "seq_len" in config["model"]: 41 | assert ( 42 | config["model"]["seq_len"] > 0 43 | ), "Config error: 'model.seq_len' should be positive" 44 | assert config["model"]["seq_len"] <= ( 45 | config["preparation"]["max_tag_len"] + 2 46 | ), "Config error: 'model.seq_len' should be up to 'preparation.max_tag_len' + 2" 47 | 48 | return True 49 | 50 | 51 | def read_config(config_filename): 52 | with open(config_filename, "r") as fd: 53 | config = json.load(fd) 54 | 55 | # Validate the config file 56 | validate_config(config) 57 | 58 | return config 59 | 60 | 61 | def safe_get_parameter(input_dict, index_path, default=None, required=False): 62 | r""" 63 | Safe get parameter from a nested dictionary. 64 | 65 | Provide a nested dictionary (dictionary of dictionaries) and a list of indices: 66 | - If the whole index path exists the value pointed by it is returned 67 | - Otherwise the default value is returned. 68 | 69 | Input: 70 | input_dict: Data structure of nested dictionaries. 71 | index_path: List with the indices path to follow inside the input_dict. 72 | default: Default value to return if the indices path is broken. 73 | required: If true a ValueError exception will be raised in case the parameter does not exist 74 | Output: 75 | The value pointed by the index path or "default". 76 | """ 77 | if input_dict is None or index_path is None: 78 | return default 79 | 80 | d = input_dict 81 | for i in index_path[:-1]: 82 | if i not in d: 83 | if required: 84 | raise ValueError("Missing parameter: {}".format(i)) 85 | return default 86 | d = d[i] 87 | 88 | last_index = index_path[-1] 89 | if last_index not in d: 90 | if required: 91 | raise ValueError("Missing parameter: {}".format(last_index)) 92 | return default 93 | 94 | return d[last_index] 95 | 96 | 97 | def get_prepared_data_filename(prepared_data_part, dataset_name): 98 | r""" 99 | Build the full filename of the prepared data part 100 | 101 | Parameters 102 | ---------- 103 | prepared_data_part : string 104 | Part of the prepared data 105 | dataset_name : string 106 | Name of the dataset 107 | 108 | Returns 109 | ------- 110 | string 111 | The full filename for the prepared file 112 | """ 113 | template = s.PREPARED_DATA_PARTS[prepared_data_part] 114 | if "" in template: 115 | template = template.replace("", dataset_name) 116 | return template 117 | -------------------------------------------------------------------------------- /demo/demo_document_figure_classifier_predictor.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp. 
2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | import argparse 6 | import logging 7 | import os 8 | import sys 9 | import time 10 | from pathlib import Path 11 | 12 | from huggingface_hub import snapshot_download 13 | from PIL import Image 14 | 15 | from docling_ibm_models.document_figure_classifier_model.document_figure_classifier_predictor import DocumentFigureClassifierPredictor 16 | 17 | 18 | def demo( 19 | logger: logging.Logger, 20 | artifact_path: str, 21 | device: str, 22 | num_threads: int, 23 | image_dir: str, 24 | viz_dir: str, 25 | ): 26 | r""" 27 | Apply DocumentFigureClassifierPredictor on the input image directory 28 | """ 29 | # Create the layout predictor 30 | document_figure_classifier_predictor = DocumentFigureClassifierPredictor(artifact_path, device=device, num_threads=num_threads) 31 | 32 | image_dir = Path(image_dir) 33 | images = [] 34 | image_names = os.listdir(image_dir) 35 | image_names.sort() 36 | for image_name in image_names: 37 | image = Image.open(image_dir / image_name) 38 | images.append(image) 39 | 40 | t0 = time.perf_counter() 41 | outputs = document_figure_classifier_predictor.predict(images) 42 | total_ms = 1000 * (time.perf_counter() - t0) 43 | avg_ms = (total_ms / len(image_names)) if len(image_names) > 0 else 0 44 | logger.info( 45 | "For {} images(ms): [total|avg] = [{:.1f}|{:.1f}]".format( 46 | len(image_names), total_ms, avg_ms 47 | ) 48 | ) 49 | 50 | for i, output in enumerate(outputs): 51 | image_name = image_names[i] 52 | logger.info(f"Predictions for: '{image_name}':") 53 | for pred in output: 54 | logger.info(f" Class '{pred[0]}' has probability {pred[1]}") 55 | 56 | 57 | def main(args): 58 | num_threads = int(args.num_threads) if args.num_threads is not None else None 59 | device = args.device.lower() 60 | image_dir = args.image_dir 61 | viz_dir = args.viz_dir 62 | 63 | # Initialize logger 64 | logging.basicConfig(level=logging.DEBUG) 65 | logger = logging.getLogger("DocumentFigureClassifierPredictor") 66 | logger.setLevel(logging.DEBUG) 67 | if not logger.hasHandlers(): 68 | handler = logging.StreamHandler(sys.stdout) 69 | formatter = logging.Formatter( 70 | "%(asctime)s %(name)-12s %(levelname)-8s %(message)s" 71 | ) 72 | handler.setFormatter(formatter) 73 | logger.addHandler(handler) 74 | 75 | # Ensure the viz dir 76 | Path(viz_dir).mkdir(parents=True, exist_ok=True) 77 | 78 | # Download models from HF 79 | download_path = snapshot_download(repo_id="ds4sd/DocumentFigureClassifier", revision="v1.0.0") 80 | 81 | # Test the figure classifier model 82 | demo(logger, download_path, device, num_threads, image_dir, viz_dir) 83 | 84 | 85 | if __name__ == "__main__": 86 | r""" 87 | python -m demo.demo_document_figure_classifier_predictor -i 88 | """ 89 | parser = argparse.ArgumentParser(description="Test the DocumentFigureClassifierPredictor") 90 | parser.add_argument( 91 | "-d", "--device", required=False, default="cpu", help="One of [cpu, cuda, mps]" 92 | ) 93 | parser.add_argument( 94 | "-n", "--num_threads", required=False, default=4, help="Number of threads" 95 | ) 96 | parser.add_argument( 97 | "-i", 98 | "--image_dir", 99 | required=True, 100 | help="PNG images input directory", 101 | ) 102 | parser.add_argument( 103 | "-v", 104 | "--viz_dir", 105 | required=False, 106 | default="viz/", 107 | help="Directory to save prediction visualizations", 108 | ) 109 | 110 | args = parser.parse_args() 111 | main(args) -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [project] 2 | name = "docling-ibm-models" 3 | version = "3.10.3" # DO NOT EDIT, updated automatically 4 | description = "This package contains the AI models used by the Docling PDF conversion package" 5 | license = "MIT" 6 | keywords = ["docling", "convert", "document", "pdf", "layout model", "segmentation", "table structure", "table former"] 7 | readme = "README.md" 8 | authors = [ 9 | { name = "Nikos Livathinos", email = "nli@zurich.ibm.com" }, 10 | { name = "Maxim Lysak", email = "mly@zurich.ibm.com" }, 11 | { name = "Ahmed Nassar", email = "ahn@zurich.ibm.com" }, 12 | { name = "Christoph Auer", email = "cau@zurich.ibm.com" }, 13 | { name = "Michele Dolfi", email = "dol@zurich.ibm.com" }, 14 | { name = "Peter Staar", email = "taa@zurich.ibm.com" }, 15 | ] 16 | classifiers = [ 17 | "Operating System :: MacOS :: MacOS X", 18 | "Operating System :: POSIX :: Linux", 19 | "Operating System :: Microsoft :: Windows", 20 | "Development Status :: 5 - Production/Stable", 21 | "Intended Audience :: Developers", 22 | "Intended Audience :: Science/Research", 23 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 24 | "Programming Language :: Python :: 3", 25 | "Programming Language :: Python :: 3.9", 26 | "Programming Language :: Python :: 3.10", 27 | "Programming Language :: Python :: 3.11", 28 | "Programming Language :: Python :: 3.12", 29 | "Programming Language :: Python :: 3.13", 30 | "Programming Language :: Python :: 3.14", 31 | ] 32 | requires-python = '>=3.9,<4.0' 33 | dependencies = [ 34 | 'torch (>=2.2.2,<3.0.0)', 35 | 'torchvision (>=0,<1)', 36 | 'jsonlines (>=3.1.0,<5.0.0)', 37 | 'Pillow (>=10.0.0,<13.0.0)', 38 | 'tqdm (>=4.64.0,<5.0.0)', 39 | 'huggingface_hub (>=0.23,<1)', 40 | 'safetensors[torch] (>=0.4.3,<1)', 41 | 'pydantic (>=2.0.0,<3.0.0)', 42 | 'docling-core (>=2.19.0,<3.0.0)', 43 | 'transformers (>=4.42.0,<5.0.0)', 44 | 'numpy (>=1.24.4,<3.0.0)', 45 | "rtree>=1.0.0", 46 | 'accelerate (>=1.2.1,<2.0.0)', 47 | ] 48 | 49 | [project.optional-dependencies] 50 | opencv-python-headless = ['opencv-python-headless (>=4.6.0.66,<5.0.0.0)'] 51 | opencv-python = ['opencv-python (>=4.6.0.66,<5.0.0.0)'] 52 | 53 | [project.urls] 54 | homepage = "https://github.com/docling-project/docling-ibm-models" 55 | repository = "https://github.com/docling-project/docling-ibm-models" 56 | issues = "https://github.com/docling-project/docling-ibm-models/issues" 57 | changelog = "https://github.com/docling-project/docling-ibm-models/blob/main/CHANGELOG.md" 58 | 59 | [dependency-groups] 60 | dev = [ 61 | "opencv-python-headless (>=4.6.0.66,<5.0.0.0)", 62 | "pre-commit~=3.7", 63 | "mypy~=1.10", 64 | "black~=24.4", 65 | "isort~=5.10", 66 | "autoflake~=2.0", 67 | "flake8~=7.1", 68 | "flake8-docstrings~=1.6", 69 | "types-setuptools~=70.3", 70 | "pandas-stubs~=2.1", 71 | "types-requests~=2.31", 72 | "coverage~=7.6", 73 | "pytest~=8.3", 74 | "pytest-cov>=6.1.1", 75 | "pytest-dependency~=0.6", 76 | "pytest-xdist~=3.3", 77 | "python-semantic-release~=7.32", 78 | 'datasets~=3.2; python_version < "3.14"', 79 | ] 80 | 81 | [tool.uv] 82 | package = true 83 | conflicts = [ 84 | [ 85 | { extra = "opencv-python-headless" }, 86 | { extra = "opencv-python" }, 87 | ] 88 | ] 89 | 90 | [tool.setuptools.packages.find] 91 | include = ["docling_ibm_models*"] 92 | 93 | [tool.black] 94 | line-length = 88 95 | target-version = ["py39"] 96 | include = '\.pyi?$' 97 | 98 | [tool.isort] 99 | profile = "black" 100 | line_length = 88 101 | 
py_version = 39 102 | 103 | [tool.semantic_release] 104 | # for default values check: 105 | # https://github.com/python-semantic-release/python-semantic-release/blob/v7.32.2/semantic_release/defaults.cfg 106 | 107 | version_source = "tag_only" 108 | branch = "main" 109 | 110 | # configure types which should trigger minor and patch version bumps respectively 111 | # (note that they must be a subset of the configured allowed types): 112 | parser_angular_allowed_types = "build,chore,ci,docs,feat,fix,perf,style,refactor,test" 113 | parser_angular_minor_types = "feat" 114 | parser_angular_patch_types = "fix,perf" 115 | 116 | 117 | [tool.mypy] 118 | pretty = true 119 | no_implicit_optional = true 120 | python_version = "3.10" 121 | 122 | [[tool.mypy.overrides]] 123 | module = ["torchvision.*", "transformers.*"] 124 | ignore_missing_imports = true 125 | -------------------------------------------------------------------------------- /tests/test_code_formula_predictor.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp. 2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | import os 6 | import numpy as np 7 | import pytest 8 | from PIL import Image 9 | 10 | from docling_ibm_models.code_formula_model.code_formula_predictor import CodeFormulaPredictor 11 | 12 | from huggingface_hub import snapshot_download 13 | 14 | @pytest.fixture(scope="module") 15 | def init() -> dict: 16 | r""" 17 | Initialize the testing environment 18 | """ 19 | init = { 20 | "num_threads": 1, 21 | "test_imgs": [ 22 | { 23 | "label": "code", 24 | "image_path": "tests/test_data/code_formula/images/code.png", 25 | "gt_path": "tests/test_data/code_formula/gt/code.txt", 26 | }, 27 | { 28 | "label": "formula", 29 | "image_path": "tests/test_data/code_formula/images/formula.png", 30 | "gt_path": "tests/test_data/code_formula/gt/formula.txt", 31 | }, 32 | ], 33 | "info": { 34 | "device": "auto", 35 | "temperature": 0, 36 | }, 37 | } 38 | 39 | # Download models from HF 40 | artifact_path = snapshot_download(repo_id="ds4sd/CodeFormula", revision="v1.0.1") 41 | 42 | init["artifact_path"] = artifact_path 43 | 44 | return init 45 | 46 | 47 | def test_code_formula_predictor(init: dict): 48 | r""" 49 | Unit test for the CodeFormulaPredictor 50 | """ 51 | device = "cpu" 52 | num_threads = 2 53 | 54 | # Initialize LayoutPredictor 55 | code_formula_predictor = CodeFormulaPredictor( 56 | init["artifact_path"], device=device, num_threads=num_threads 57 | ) 58 | 59 | # Check info 60 | info = code_formula_predictor.info() 61 | assert info["device"] == device, "Wronly set device" 62 | assert info["num_threads"] == num_threads, "Wronly set number of threads" 63 | 64 | # Unsupported input image 65 | is_exception = False 66 | try: 67 | for _ in code_formula_predictor.predict(["wrong"], ['label']): 68 | pass 69 | except TypeError: 70 | is_exception = True 71 | assert is_exception 72 | 73 | # wrong type for temperature 74 | is_exception = False 75 | try: 76 | dummy_image = Image.new(mode="RGB", size=(100, 100), color=(255, 255, 255)) 77 | for _ in code_formula_predictor.predict([dummy_image], ['label'], "0.1"): 78 | pass 79 | except Exception: 80 | is_exception = True 81 | assert is_exception 82 | 83 | # wrong value for temperature 84 | is_exception = False 85 | try: 86 | dummy_image = Image.new(mode="RGB", size=(100, 100), color=(255, 255, 255)) 87 | for _ in code_formula_predictor.predict([dummy_image], ['label'], -0.1): 88 | pass 89 | except Exception: 90 | is_exception = True 91 | 
assert is_exception 92 | 93 | # missing temperature (None) 94 | is_exception = False 95 | try: 96 | dummy_image = Image.new(mode="RGB", size=(100, 100), color=(255, 255, 255)) 97 | for _ in code_formula_predictor.predict([dummy_image], ["label"], None): 98 | pass 99 | except Exception: 100 | is_exception = True 101 | assert is_exception 102 | 103 | # mismatched number of images and labels 104 | is_exception = False 105 | try: 106 | dummy_image = Image.new(mode="RGB", size=(100, 100), color=(255, 255, 255)) 107 | for _ in code_formula_predictor.predict([dummy_image], ['label', 'label']): 108 | pass 109 | except Exception: 110 | is_exception = True 111 | assert is_exception 112 | 113 | # Predict on test images, not batched 114 | temperature = init['info']['temperature'] 115 | for d in init["test_imgs"]: 116 | label = d['label'] 117 | img_path = d['image_path'] 118 | gt_path = d['gt_path'] 119 | 120 | with Image.open(img_path) as img, open(gt_path, 'r') as gt_fp: 121 | gt = gt_fp.read() 122 | 123 | output = code_formula_predictor.predict([img], [label], temperature) 124 | output = output[0] 125 | 126 | assert output == gt 127 | 128 | # Load images as numpy arrays 129 | np_arr = np.asarray(img) 130 | output = code_formula_predictor.predict([np_arr], [label], temperature) 131 | output = output[0] 132 | 133 | assert output == gt 134 | 135 | # Predict on test images, batched 136 | labels = [d['label'] for d in init["test_imgs"]] 137 | images = [Image.open(d['image_path']) for d in init["test_imgs"]] 138 | gts = [open(d['gt_path'], 'r').read() for d in init["test_imgs"]] 139 | 140 | outputs = code_formula_predictor.predict(images, labels, temperature) 141 | assert outputs == gts 142 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![PyPI version](https://img.shields.io/pypi/v/docling-ibm-models)](https://pypi.org/project/docling-ibm-models/) 2 | [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/docling-ibm-models)](https://pypi.org/project/docling-ibm-models/) 3 | [![uv](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/uv/main/assets/badge/v0.json)](https://github.com/astral-sh/uv) 4 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 5 | [![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) 6 | [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit) 7 | [![Models on Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue)](https://huggingface.co/ds4sd/docling-models/) 8 | [![License MIT](https://img.shields.io/github/license/ds4sd/deepsearch-toolkit)](https://opensource.org/licenses/MIT) 9 | 10 | # Docling IBM models 11 | 12 | AI modules to support the Docling PDF document conversion project. 13 | 14 | - TableFormer is an AI module that recognizes the structure of a table and the bounding boxes of the table content. 15 | - Layout model is an AI model that provides, among other things, the ability to detect tables on the page. This package contains the inference code for the Layout model. 16 | 17 | 18 | ## Install 19 | 20 | The package provides two variants which allow you to seamlessly switch between `opencv-python` and `opencv-python-headless`.
21 | 22 | ```sh 23 | # Option 1: with opencv-python-headless 24 | pip install "docling-ibm-models[opencv-python-headless]" 25 | 26 | # Option 2: with opencv-python 27 | pip install "docling-ibm-models[opencv-python]" 28 | ``` 29 | 30 | ## Pipeline Overview 31 | ![Architecture](docs/tablemodel_overview_color.png) 32 | 33 | ## Datasets 34 | Below we list the datasets used, with their description, source, and ***"TableFormer Format"***. The TableFormer Format is our processed version of the original format: it works with the dataloader out of the box and augments the dataset where necessary to add missing ground truth (bounding boxes for empty cells). 35 | 36 | 37 | | Name | Description | URL | 38 | | ------------- |:-------------:|----| 39 | | PubTabNet | PubTabNet contains heterogeneous tables in both image and HTML format, with 516k+ tables from the PubMed Central Open Access Subset | [PubTabNet](https://developer.ibm.com/exchanges/data/all/pubtabnet/) | 40 | | FinTabNet| A dataset for Financial Report Tables with corresponding ground truth location and structure. 112k+ tables included.| [FinTabNet](https://developer.ibm.com/exchanges/data/all/fintabnet/) | 41 | | TableBank| TableBank is a new image-based table detection and recognition dataset built with novel weak supervision from Word and LaTeX documents on the internet; it contains 417K high-quality labeled tables. | [TableBank](https://github.com/doc-analysis/TableBank) | 42 | 43 | ## Models 44 | 45 | ### TableModel04: 46 | ![TableModel04](docs/tbm04.png) 47 | **TableModel04rs (OTSL)** is our SOTA method that uses transformers to predict table structure and bounding boxes. 48 | 49 | 50 | ## Configuration file 51 | 52 | An example configuration can be found inside the test `tests/test_tf_predictor.py`. 53 | These are the main sections of the configuration file: 54 | 55 | - `dataset`: The directory for prepared data and the parameters used during the data loading. 56 | - `model`: The type, name and hyperparameters of the model. Also the directory to save/load the 57 | trained checkpoint files. 58 | - `train`: Parameters for the training of the model. 59 | - `predict`: Parameters for the evaluation of the model. 60 | - `dataset_wordmap`: Very important part that contains the token maps. 61 | 62 | 63 | ## Model weights 64 | 65 | You can download the model weights and config files from the links: 66 | 67 | - [TableFormer Checkpoint](https://huggingface.co/ds4sd/docling-models/tree/main/model_artifacts/tableformer) 68 | - [beehive_v0.0.5](https://huggingface.co/ds4sd/docling-models/tree/main/model_artifacts/layout/beehive_v0.0.5) 69 | 70 | 71 | ## Inference Tests 72 | 73 | You can run the inference tests for the models with: 74 | 75 | ``` 76 | python -m pytest tests/ 77 | ``` 78 | 79 | This will also generate prediction and matching visualizations that can be found here: 80 | `tests/test_data/viz/` 81 | 82 | Visualization outlines: 83 | - `Light Pink`: border of recognized table 84 | - `Grey`: OCR cells 85 | - `Green`: prediction bboxes 86 | - `Red`: OCR cells matched with prediction 87 | - `Blue`: Post processed, match 88 | - `Bold Blue`: column header 89 | - `Bold Magenta`: row header 90 | - `Bold Brown`: section row (if the table has one) 91 | 92 | 93 | ## Demo 94 | 95 | A demo application allows you to apply the `LayoutPredictor` on a directory `<img_dir>` that contains 96 | `png` images and visualize the predictions inside another directory `<viz_dir>`.
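For programmatic use outside the demo script, the minimal sketch below shows how the layout model can be called directly from Python. It mirrors the calls used in `tests/test_layout_predictor.py` and `demo/demo_layout_predictor.py`; the `ds4sd/docling-layout-old` weights and the sample image path are simply the ones used by the tests, and can be swapped for any of the supported layout repos and your own images.

```python
from huggingface_hub import snapshot_download
from PIL import Image

from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor

# Download the layout model weights from Hugging Face
artifact_path = snapshot_download(repo_id="ds4sd/docling-layout-old")

# Create the predictor (device can be "cpu", "cuda" or "mps", as in the demo)
predictor = LayoutPredictor(artifact_path, device="cpu", num_threads=4)

with Image.open("tests/test_data/samples/ADS.2007.page_123.png") as image:
    for pred in predictor.predict(image):
        # Each prediction is a dict with "l", "t", "r", "b", "label" and "confidence"
        print(pred["label"], pred["confidence"], (pred["l"], pred["t"], pred["r"], pred["b"]))
```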
97 | 98 | First download the model weights (see above), then run: 99 | ``` 100 | python -m demo.demo_layout_predictor -i -v 101 | ``` 102 | 103 | e.g. 104 | ``` 105 | python -m demo.demo_layout_predictor -i tests/test_data/samples -v viz/ 106 | ``` 107 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement using 63 | [deepsearch-core@zurich.ibm.com](mailto:deepsearch-core@zurich.ibm.com). 64 | 65 | All complaints will be reviewed and investigated promptly and fairly. 
66 | 67 | All community leaders are obligated to respect the privacy and security of the 68 | reporter of any incident. 69 | 70 | ## Enforcement Guidelines 71 | 72 | Community leaders will follow these Community Impact Guidelines in determining 73 | the consequences for any action they deem in violation of this Code of Conduct: 74 | 75 | ### 1. Correction 76 | 77 | **Community Impact**: Use of inappropriate language or other behavior deemed 78 | unprofessional or unwelcome in the community. 79 | 80 | **Consequence**: A private, written warning from community leaders, providing 81 | clarity around the nature of the violation and an explanation of why the 82 | behavior was inappropriate. A public apology may be requested. 83 | 84 | ### 2. Warning 85 | 86 | **Community Impact**: A violation through a single incident or series 87 | of actions. 88 | 89 | **Consequence**: A warning with consequences for continued behavior. No 90 | interaction with the people involved, including unsolicited interaction with 91 | those enforcing the Code of Conduct, for a specified period of time. This 92 | includes avoiding interactions in community spaces as well as external channels 93 | like social media. Violating these terms may lead to a temporary or 94 | permanent ban. 95 | 96 | ### 3. Temporary Ban 97 | 98 | **Community Impact**: A serious violation of community standards, including 99 | sustained inappropriate behavior. 100 | 101 | **Consequence**: A temporary ban from any sort of interaction or public 102 | communication with the community for a specified period of time. No public or 103 | private interaction with the people involved, including unsolicited interaction 104 | with those enforcing the Code of Conduct, is allowed during this period. 105 | Violating these terms may lead to a permanent ban. 106 | 107 | ### 4. Permanent Ban 108 | 109 | **Community Impact**: Demonstrating a pattern of violation of community 110 | standards, including sustained inappropriate behavior, harassment of an 111 | individual, or aggression toward or disparagement of classes of individuals. 112 | 113 | **Consequence**: A permanent ban from any sort of public interaction within 114 | the community. 115 | 116 | ## Attribution 117 | 118 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 119 | version 2.0, available at 120 | [https://www.contributor-covenant.org/version/2/0/code_of_conduct.html](https://www.contributor-covenant.org/version/2/0/code_of_conduct.html). 121 | 122 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 123 | enforcement ladder](https://github.com/mozilla/diversity). 124 | 125 | Homepage: [https://www.contributor-covenant.org](https://www.contributor-covenant.org) 126 | 127 | For answers to common questions about this code of conduct, see the FAQ at 128 | [https://www.contributor-covenant.org/faq](https://www.contributor-covenant.org/faq). Translations are available at 129 | [https://www.contributor-covenant.org/translations](https://www.contributor-covenant.org/translations). -------------------------------------------------------------------------------- /demo/demo_layout_predictor.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp. 
2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | import argparse 6 | import logging 7 | import os 8 | import sys 9 | import time 10 | from pathlib import Path 11 | from typing import Any, Dict, List 12 | import numpy as np 13 | import torch 14 | from huggingface_hub import snapshot_download 15 | from PIL import Image, ImageDraw, ImageFont 16 | 17 | from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor 18 | 19 | 20 | def save_predictions( 21 | prefix: str, viz_dir: str, img_fn: Path, img, predictions: List[Dict[str, Any]] 22 | ): 23 | img_path = Path(img_fn) 24 | 25 | image = img.copy() 26 | draw = ImageDraw.Draw(image) 27 | 28 | predictions_filename = f"{prefix}_{img_path.stem}.txt" 29 | predictions_fn = os.path.join(viz_dir, predictions_filename) 30 | with open(predictions_fn, "w") as fd: 31 | for pred in predictions: 32 | bbox = [ 33 | round(pred["l"], 2), 34 | round(pred["t"], 2), 35 | round(pred["r"], 2), 36 | round(pred["b"], 2), 37 | ] 38 | label = pred["label"] 39 | confidence = round(pred["confidence"], 3) 40 | 41 | # Save the predictions in txt file 42 | pred_txt = f"{prefix} {str(img_fn)}: {label} - {bbox} - {confidence}\n" 43 | fd.write(pred_txt) 44 | 45 | # Draw the bbox and label 46 | draw.rectangle(bbox, outline="orange") 47 | txt = f"{label}: {confidence}" 48 | draw.text( 49 | (bbox[0], bbox[1]), text=txt, font=ImageFont.load_default(), fill="blue" 50 | ) 51 | 52 | draw_filename = f"{prefix}_{img_path.name}" 53 | draw_fn = os.path.join(viz_dir, draw_filename) 54 | image.save(draw_fn) 55 | 56 | 57 | def demo( 58 | logger: logging.Logger, 59 | artifact_path: str, 60 | device: str, 61 | num_threads: int, 62 | img_dir: str, 63 | viz_dir: str, 64 | threshold: float, 65 | ): 66 | r""" 67 | Apply LayoutPredictor on the input image directory 68 | 69 | If you want to load from PDF: 70 | pdf_image = pyvips.Image.new_from_file("test_data/ADS.2007.page_123.pdf", page=0) 71 | """ 72 | # Create the layout predictor 73 | predictor = LayoutPredictor(artifact_path, device=device, num_threads=num_threads, base_threshold=threshold) 74 | 75 | # Predict all test png images 76 | t0 = time.perf_counter() 77 | img_counter = 0 78 | for img_fn in Path(img_dir).rglob("*.png"): 79 | img_counter += 1 80 | logger.info("Predicting '%s'...", img_fn) 81 | 82 | with Image.open(img_fn) as image: 83 | # Predict layout 84 | img_t0 = time.perf_counter() 85 | preds: List[Dict[str, Any]] = list(predictor.predict(image)) 86 | img_ms = 1000 * (time.perf_counter() - img_t0) 87 | logger.debug("Prediction(ms): {:.2f}".format(img_ms)) 88 | 89 | # Save predictions 90 | logger.info("Saving prediction visualization in: '%s'", viz_dir) 91 | save_predictions("ST", viz_dir, img_fn, image, preds) 92 | total_ms = 1000 * (time.perf_counter() - t0) 93 | avg_ms = (total_ms / img_counter) if img_counter > 0 else 0 94 | logger.info( 95 | "For {} images(ms): [total|avg] = [{:.1f}|{:.1f}]".format( 96 | img_counter, total_ms, avg_ms 97 | ) 98 | ) 99 | 100 | 101 | def main(args): 102 | r""" """ 103 | num_threads = int(args.num_threads) if args.num_threads is not None else 4 104 | device = args.device.lower() 105 | img_dir = args.img_dir 106 | viz_dir = args.viz_dir 107 | hugging_face_repo = args.hugging_face_repo 108 | threshold = float(args.threshold) 109 | 110 | # Initialize logger 111 | logging.basicConfig(level=logging.DEBUG) 112 | logger = logging.getLogger("LayoutPredictor") 113 | logger.setLevel(logging.DEBUG) 114 | if not logger.hasHandlers(): 115 | handler = logging.StreamHandler(sys.stdout) 116 | 
formatter = logging.Formatter( 117 | "%(asctime)s %(name)-12s %(levelname)-8s %(message)s" 118 | ) 119 | handler.setFormatter(formatter) 120 | logger.addHandler(handler) 121 | 122 | # Ensure the viz dir 123 | Path(viz_dir).mkdir(parents=True, exist_ok=True) 124 | 125 | # Download models from HF 126 | download_path = snapshot_download(repo_id=hugging_face_repo) 127 | 128 | # Test the LayoutPredictor 129 | demo(logger, download_path, device, num_threads, img_dir, viz_dir, threshold) 130 | 131 | 132 | if __name__ == "__main__": 133 | r""" 134 | python -m demo.demo_layout_predictor -i 135 | """ 136 | parser = argparse.ArgumentParser(description="Test the LayoutPredictor") 137 | 138 | supported_hf_repos = [ 139 | "ds4sd/docling-layout-old", 140 | "ds4sd/docling-layout-heron", 141 | "ds4sd/docling-layout-heron-101", 142 | "ds4sd/docling-layout-egret-medium", 143 | "ds4sd/docling-layout-egret-large", 144 | "ds4sd/docling-layout-egret-xlarge", 145 | ] 146 | parser.add_argument( 147 | "-r", 148 | "--hugging-face-repo", 149 | required=False, 150 | default="ds4sd/docling-layout-old", 151 | help=f"The hugging face repo id: [{', '.join(supported_hf_repos)}]", 152 | ) 153 | parser.add_argument( 154 | "-t", "--threshold", required=False, default=0.3, help="Threshold for the LayoutPredictor" 155 | ) 156 | parser.add_argument( 157 | "-d", "--device", required=False, default="cpu", help="One of [cpu, cuda, mps]" 158 | ) 159 | parser.add_argument( 160 | "-n", "--num_threads", required=False, default=4, help="Number of threads" 161 | ) 162 | parser.add_argument( 163 | "-i", 164 | "--img_dir", 165 | required=True, 166 | help="PNG images input directory", 167 | ) 168 | parser.add_argument( 169 | "-v", 170 | "--viz_dir", 171 | required=False, 172 | default="viz/", 173 | help="Directory to save prediction visualizations", 174 | ) 175 | 176 | args = parser.parse_args() 177 | main(args) 178 | -------------------------------------------------------------------------------- /docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp. 2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | import logging 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | import docling_ibm_models.tableformer.settings as s 11 | import docling_ibm_models.tableformer.utils.utils as u 12 | 13 | # from scipy.optimize import linear_sum_assignment 14 | 15 | LOG_LEVEL = logging.INFO 16 | 17 | 18 | class CellAttention(nn.Module): 19 | """ 20 | Attention Network. 
21 | """ 22 | 23 | def __init__(self, encoder_dim, tag_decoder_dim, language_dim, attention_dim): 24 | """ 25 | :param encoder_dim: feature size of encoded images 26 | :param tag_decoder_dim: size of tag decoder's RNN 27 | :param language_dim: size of language model's RNN 28 | :param attention_dim: size of the attention network 29 | """ 30 | super(CellAttention, self).__init__() 31 | # linear layer to transform encoded image 32 | self._encoder_att = nn.Linear(encoder_dim, attention_dim) 33 | # linear layer to transform tag decoder output 34 | self._tag_decoder_att = nn.Linear(tag_decoder_dim, attention_dim) 35 | # linear layer to transform language models output 36 | self._language_att = nn.Linear(language_dim, attention_dim) 37 | # linear layer to calculate values to be softmax-ed 38 | self._full_att = nn.Linear(attention_dim, 1) 39 | self._relu = nn.ReLU() 40 | self._softmax = nn.Softmax(dim=1) # softmax layer to calculate weights 41 | 42 | def _log(self): 43 | # Setup a custom logger 44 | return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL) 45 | 46 | def forward(self, encoder_out, decoder_hidden, language_out): 47 | """ 48 | Forward propagation. 49 | :param encoder_out: encoded images, a tensor of dimension (1, num_pixels, encoder_dim) 50 | :param decoder_hidden: tag decoder output, a tensor of dimension [(num_cells, 51 | tag_decoder_dim)] 52 | :param language_out: language model output, a tensor of dimension (num_cells, 53 | language_dim) 54 | :return: attention weighted encoding, weights 55 | """ 56 | att1 = self._encoder_att(encoder_out) # (1, num_pixels, attention_dim) 57 | att2 = self._tag_decoder_att(decoder_hidden) # (num_cells, tag_decoder_dim) 58 | att3 = self._language_att(language_out) # (num_cells, attention_dim) 59 | att = self._full_att( 60 | self._relu(att1 + att2.unsqueeze(1) + att3.unsqueeze(1)) 61 | ).squeeze(2) 62 | alpha = self._softmax(att) # (num_cells, num_pixels) 63 | # (num_cells, encoder_dim) 64 | attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1) 65 | return attention_weighted_encoding, alpha 66 | 67 | 68 | class BBoxDecoder(nn.Module): 69 | """ 70 | CellDecoder generates cell content 71 | """ 72 | 73 | def __init__( 74 | self, 75 | device, 76 | attention_dim, 77 | embed_dim, 78 | tag_decoder_dim, 79 | decoder_dim, 80 | num_classes, 81 | encoder_dim=512, 82 | dropout=0.5, 83 | cnn_layer_stride=1, 84 | ): 85 | """ 86 | :param attention_dim: size of attention network 87 | :param embed_dim: embedding size 88 | :param tag_decoder_dim: size of tag decoder's RNN 89 | :param decoder_dim: size of decoder's RNN 90 | :param vocab_size: size of vocabulary 91 | :param encoder_dim: feature size of encoded images 92 | :param dropout: dropout 93 | :param mini_batch_size: batch size of cells to reduce GPU memory usage 94 | """ 95 | super(BBoxDecoder, self).__init__() 96 | self._device = device 97 | self._encoder_dim = encoder_dim 98 | self._attention_dim = attention_dim 99 | self._embed_dim = embed_dim 100 | self._decoder_dim = decoder_dim 101 | self._dropout = dropout 102 | self._num_classes = num_classes 103 | 104 | if cnn_layer_stride is not None: 105 | self._input_filter = u.resnet_block(stride=cnn_layer_stride) 106 | # attention network 107 | self._attention = CellAttention( 108 | encoder_dim, tag_decoder_dim, decoder_dim, attention_dim 109 | ) 110 | # decoder LSTMCell 111 | self._init_h = nn.Linear(encoder_dim, decoder_dim) 112 | 113 | # linear layer to create a sigmoid-activated gate 114 | self._f_beta = nn.Linear(decoder_dim, 
encoder_dim) 115 | self._sigmoid = nn.Sigmoid() 116 | self._dropout = nn.Dropout(p=self._dropout) 117 | self._class_embed = nn.Linear(512, self._num_classes + 1) 118 | self._bbox_embed = u.MLP(512, 256, 4, 3) 119 | 120 | def _init_hidden_state(self, encoder_out, batch_size): 121 | mean_encoder_out = encoder_out.mean(dim=1) 122 | h = self._init_h(mean_encoder_out).expand(batch_size, -1) 123 | return h 124 | 125 | def _log(self): 126 | # Setup a custom logger 127 | return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL) 128 | 129 | def inference(self, encoder_out, tag_H): 130 | """ 131 | Inference on test images with beam search 132 | """ 133 | if hasattr(self, "_input_filter"): 134 | encoder_out = self._input_filter(encoder_out.permute(0, 3, 1, 2)).permute( 135 | 0, 2, 3, 1 136 | ) 137 | 138 | encoder_dim = encoder_out.size(3) 139 | 140 | # Flatten encoding (1, num_pixels, encoder_dim) 141 | encoder_out = encoder_out.view(1, -1, encoder_dim) 142 | 143 | num_cells = len(tag_H) 144 | predictions_bboxes = [] 145 | predictions_classes = [] 146 | 147 | for c_id in range(num_cells): 148 | # Start decoding 149 | h = self._init_hidden_state(encoder_out, 1) 150 | cell_tag_H = tag_H[c_id] 151 | awe, _ = self._attention(encoder_out, cell_tag_H, h) 152 | gate = self._sigmoid(self._f_beta(h)) 153 | awe = gate * awe 154 | h = awe * h 155 | 156 | predictions_bboxes.append(self._bbox_embed(h).sigmoid()) 157 | predictions_classes.append(self._class_embed(h)) 158 | if len(predictions_bboxes) > 0: 159 | predictions_bboxes = torch.stack([x[0] for x in predictions_bboxes]) 160 | else: 161 | predictions_bboxes = torch.empty(0) 162 | 163 | if len(predictions_classes) > 0: 164 | predictions_classes = torch.stack([x[0] for x in predictions_classes]) 165 | else: 166 | predictions_classes = torch.empty(0) 167 | 168 | return predictions_classes, predictions_bboxes 169 | -------------------------------------------------------------------------------- /docling_ibm_models/document_figure_classifier_model/document_figure_classifier_predictor.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp. 2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | import logging 6 | import threading 7 | from typing import List, Tuple, Union 8 | 9 | import numpy as np 10 | import torch 11 | import torchvision.transforms as transforms 12 | from PIL import Image 13 | from transformers import AutoConfig, AutoModelForImageClassification 14 | 15 | _log = logging.getLogger(__name__) 16 | 17 | # Global lock for model initialization to prevent threading issues 18 | _model_init_lock = threading.Lock() 19 | 20 | 21 | class DocumentFigureClassifierPredictor: 22 | r""" 23 | Model for classifying document figures. 24 | 25 | Classifies figures as 1 out of 16 possible classes. 26 | 27 | The classes are: 28 | 1. "bar_chart" 29 | 2. "bar_code" 30 | 3. "chemistry_markush_structure" 31 | 4. "chemistry_molecular_structure" 32 | 5. "flow_chart" 33 | 6. "icon" 34 | 7. "line_chart" 35 | 8. "logo" 36 | 9. "map" 37 | 10. "other" 38 | 11. "pie_chart" 39 | 12. "qr_code" 40 | 13. "remote_sensing" 41 | 14. "screenshot" 42 | 15. "signature" 43 | 16. "stamp" 44 | 45 | Attributes 46 | ---------- 47 | _device : str 48 | The device on which the model is loaded (e.g., 'cpu' or 'cuda'). 49 | _num_threads : int 50 | Number of threads used for inference when running on CPU. 51 | _model : EfficientNetForImageClassification 52 | Pretrained EfficientNetb0 model. 
53 | _image_processor : EfficientNetImageProcessor 54 | Processor for normalizing and preparing input images. 55 | _classes: List[str]: 56 | The classes used by the model. 57 | 58 | Methods 59 | ------- 60 | __init__(artifacts_path, device, num_threads) 61 | Initializes the DocumentFigureClassifierPredictor with the specified parameters. 62 | info() -> dict: 63 | Retrieves configuration details of the DocumentFigureClassifierPredictor instance. 64 | predict(images) -> List[List[float]] 65 | The confidence scores for the classification of each image. 66 | """ 67 | 68 | def __init__( 69 | self, 70 | artifacts_path: str, 71 | device: str = "cpu", 72 | num_threads: int = 4, 73 | ): 74 | r""" 75 | Initializes the DocumentFigureClassifierPredictor. 76 | 77 | Parameters 78 | ---------- 79 | artifacts_path : str 80 | Path to the directory containing the pretrained model files. 81 | device : str, optional 82 | Device to run the inference on ('cpu' or 'cuda'), by default "cpu". 83 | num_threads : int, optional 84 | Number of threads for CPU inference, by default 4. 85 | """ 86 | self._device = device 87 | self._num_threads = num_threads 88 | 89 | if device == "cpu": 90 | torch.set_num_threads(self._num_threads) 91 | 92 | with _model_init_lock: 93 | self._model = AutoModelForImageClassification.from_pretrained( 94 | artifacts_path, device_map=device 95 | ) 96 | self._model.eval() 97 | 98 | self._image_processor = transforms.Compose( 99 | [ 100 | transforms.Resize((224, 224)), 101 | transforms.ToTensor(), 102 | transforms.Normalize( 103 | mean=[0.485, 0.456, 0.406], 104 | std=[0.47853944, 0.4732864, 0.47434163], 105 | ), 106 | ] 107 | ) 108 | 109 | config = AutoConfig.from_pretrained(artifacts_path) 110 | 111 | self._classes = list(config.id2label.values()) 112 | self._classes.sort() 113 | 114 | _log.debug("CodeFormulaModel settings: {}".format(self.info())) 115 | 116 | def info(self) -> dict: 117 | """ 118 | Retrieves configuration details of the DocumentFigureClassifierPredictor instance. 119 | 120 | Returns 121 | ------- 122 | dict 123 | A dictionary containing configuration details such as the device, 124 | the number of threads used and the classe sused by the model. 125 | """ 126 | info = { 127 | "device": self._device, 128 | "num_threads": self._num_threads, 129 | "classes": self._classes, 130 | } 131 | return info 132 | 133 | def predict( 134 | self, images: List[Union[Image.Image, np.ndarray]] 135 | ) -> List[List[Tuple[str, float]]]: 136 | r""" 137 | Performs inference on a batch of figures. 138 | 139 | Parameters 140 | ---------- 141 | images : List[Union[Image.Image, np.ndarray]] 142 | A list of input images for inference. Each image can either be a 143 | PIL.Image.Image object or a NumPy array representing an image. 144 | 145 | Returns 146 | ------- 147 | List[List[Tuple[str, float]]] 148 | A list of predictions for each input image. Each prediction is a list of 149 | tuples representing the predicted class and confidence score: 150 | - str: The predicted class name for the image. 151 | - float: The confidence score associated with the predicted class, 152 | ranging from 0 to 1. 153 | 154 | The predictions for each image are sorted in descending order of confidence. 
155 | """ 156 | rgb_images = [] 157 | for image in images: 158 | if isinstance(image, Image.Image): 159 | rgb_images.append(image.convert("RGB")) 160 | elif isinstance(image, np.ndarray): 161 | rgb_images.append(Image.fromarray(image).convert("RGB")) 162 | else: 163 | raise TypeError( 164 | "Supported input formats are PIL.Image.Image or numpy.ndarray." 165 | ) 166 | 167 | # (batch_size, 3, 224, 224) 168 | processed_images = [self._image_processor(image) for image in rgb_images] 169 | torch_images = torch.stack(processed_images).to(self._device) 170 | 171 | with torch.no_grad(): 172 | logits = self._model(torch_images).logits # (batch_size, num_classes) 173 | probs_batch = logits.softmax(dim=1) # (batch_size, num_classes) 174 | probs_batch = probs_batch.cpu().numpy().tolist() 175 | 176 | predictions_batch = [] 177 | for probs_image in probs_batch: 178 | preds = [(self._classes[i], prob) for i, prob in enumerate(probs_image)] 179 | preds.sort(key=lambda t: t[1], reverse=True) 180 | predictions_batch.append(preds) 181 | 182 | return predictions_batch 183 | -------------------------------------------------------------------------------- /docling_ibm_models/tableformer/utils/mem_monitor.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp. 2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | import os 6 | import platform 7 | import re 8 | from typing import Dict, Union 9 | 10 | 11 | class MemMonitor: 12 | r""" 13 | Memory monitor for Linux 14 | 15 | It supports 2 approaches for extracting memory information: 16 | - linux-native: It parse the `/proc` pseudo-files. It is available only for Linux 17 | - psutil: Use the `psutil` library 18 | 19 | ## Linux-Native approach 20 | 21 | The linux-native approach implements 2 methods to extract the memory fields: 22 | 23 | 1. The `get_memory()` method: 24 | 25 | - It is very fast 26 | - It parses the `/proc//statm` pseudo-file 27 | - It Contains the following fields: 28 | size (1) total program size 29 | (same as VmSize in /proc/[pid]/status) 30 | resident (2) resident set size 31 | (same as VmRSS in /proc/[pid]/status) 32 | shared (3) number of resident shared pages (i.e., backed by a file) 33 | (same as RssFile+RssShmem in /proc/[pid]/status) 34 | text (4) text (code) 35 | lib (5) library (unused since Linux 2.6; always 0) 36 | data (6) data + stack 37 | dt (7) dirty pages (unused since Linux 2.6; always 0) 38 | 39 | 40 | 2. The `get_memory_full()` method: 41 | 42 | - It is slower to parse but contains more detailed information 43 | - It uses regex to parse the `/proc//status` pseudo-file 44 | - It contains the following fields: 45 | VmPeak: Peak virtual memory size. 46 | VmSize: Virtual memory size. 47 | VmLck: Locked memory size (see mlock(2)). 48 | VmPin: Pinned memory size (since Linux 3.2). These are pages that can't be moved because 49 | something needs to directly access physical memory. 50 | VmHWM: Peak resident set size ("high water mark"). 51 | VmRSS: Resident set size. Note that the value here is the sum of RssAnon, RssFile, and 52 | RssShmem. 53 | RssAnon: Size of resident anonymous memory. (since Linux 4.5). 54 | RssFile: Size of resident file mappings. (since Linux 4.5). 55 | RssShmem: Size of resident shared memory (includes System V shared memory, mappings from 56 | tmpfs(5), and shared anonymous mappings). (since Linux 4.5). 57 | VmData, VmStk, VmExe: Size of data, stack, and text segments. 58 | VmLib: Shared library code size. 
59 | VmPTE: Page table entries size (since Linux 2.6.10). 60 | VmPMD: Size of second-level page tables (added in Linux 4.0; removed in Linux 4.15). 61 | VmSwap: Swapped-out virtual memory size by anonymous private pages; shmem swap usage is 62 | not included (since Linux 2.6.34). 63 | 64 | 65 | ## The psutil library 66 | 67 | - Apparently the psutil library parses the `/proc//statm` 68 | - The memory_info() function returns the fields: rss, vms, shared, text, lib, data, dirty 69 | 70 | 71 | ## Field mappings 72 | 73 | These are the fields returned by psutil memory_info() and their mapping in the /proc files: 74 | (I put ? when I am not 100% about the mapping) 75 | 76 | | psutil | /proc/$$/status | /proc/$$/statm | 77 | |---------|--------------------|----------------| 78 | | rss | VmRSS | resident | 79 | | vms | VmSize | size | 80 | | shared | RssFile + RssShmem | shared | 81 | | text | VmExe ? | text | 82 | | lib | RssShmem ? | lib | 83 | | data | VmData + VmStk | data | 84 | | dirty | VmSwap ? | dt | 85 | 86 | """ 87 | 88 | def __init__(self, enable=True): 89 | self._enable = enable 90 | self._pid = os.getpid() 91 | 92 | # Create regex for each memory field of the /proc/status pseudo-file 93 | self._status_fields = [ 94 | "VmPeak", 95 | "VmSize", 96 | "VmLck", 97 | "VmPin", 98 | "VmHWM", 99 | "VmRSS", 100 | "RssAnon", 101 | "RssFile", 102 | "RssShmem", 103 | "VmData", 104 | "VmStk", 105 | "VmExe", 106 | "VmLib", 107 | "VmPTE", 108 | "VmPMD", 109 | "VmSwap", 110 | ] 111 | self._status_regex = {} 112 | for mem_field in self._status_fields: 113 | regex_str = r"({}:)(\s+)(\d*)(.*)".format(mem_field) 114 | self._status_regex[mem_field] = re.compile(regex_str) 115 | 116 | def get_memory_full(self) -> Union[Dict, int]: 117 | r""" 118 | - Parse /proc/status to get all memory info. 119 | - The method returns a dict with the fields self._status_fields 120 | - This method is SLOW. Unless you need the full memory info, better to use `get_memory` 121 | 122 | The returned values are in kB 123 | """ 124 | if not self._enable: 125 | return -2 126 | if platform.system() != "Linux": 127 | return -1 128 | pid_fn = "/proc/{}/status".format(self._pid) 129 | 130 | # Dict to collect all memory fields 131 | memory = {} 132 | with open(pid_fn, "r") as fn: 133 | for ll in fn: 134 | for mem_field in self._status_fields: 135 | regex = self._status_regex[mem_field] 136 | m = regex.match(ll) 137 | if m is not None: 138 | memory[mem_field] = int(m.group(3)) 139 | if len(memory) == len(self._status_fields): 140 | break 141 | 142 | return memory 143 | 144 | def get_memory(self) -> Union[Dict, int]: 145 | r""" 146 | - Parse /proc/statm to get the most important memory fields 147 | - This is a fast implementation. 
148 | - The method returns a dict with the fields: 149 | "size", "resident", "shared", "text", "lib", "data", "dt" 150 | - Check the documentation at the top for a mapping across the various fields 151 | 152 | The returned values are in kB 153 | """ 154 | if not self._enable: 155 | return -2 156 | if platform.system() != "Linux": 157 | return -1 158 | pid_fn = "/proc/{}/statm".format(self._pid) 159 | 160 | # Dict to collect all memory fields 161 | memory = {} 162 | with open(pid_fn, "r") as fn: 163 | ll = fn.read() 164 | # The values are in pages 165 | # Each page is 4096 bytes (4kB) 166 | data = [int(x) << 2 for x in ll.split(" ")] 167 | memory = { 168 | "size": data[0], 169 | "resident": data[1], 170 | "shared": data[2], 171 | "text": data[3], 172 | "lib": data[4], 173 | "data": data[5], 174 | "dt": data[6], 175 | } 176 | return memory 177 | -------------------------------------------------------------------------------- /docling_ibm_models/tableformer/models/table04_rs/transformer_rs.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp. 2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | 6 | 7 | import logging 8 | import math 9 | from typing import Optional 10 | 11 | import torch 12 | from torch import Tensor, nn 13 | 14 | import docling_ibm_models.tableformer.utils.utils as u 15 | 16 | LOG_LEVEL = logging.INFO 17 | # LOG_LEVEL = logging.DEBUG 18 | 19 | 20 | class PositionalEncoding(nn.Module): 21 | def __init__(self, d_model, dropout=0.1, max_len=1024): 22 | super(PositionalEncoding, self).__init__() 23 | self.dropout = nn.Dropout(p=dropout) 24 | 25 | pe = torch.zeros(max_len, d_model) 26 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 27 | div_term = torch.exp( 28 | torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) 29 | ) 30 | pe[:, 0::2] = torch.sin(position * div_term) 31 | pe[:, 1::2] = torch.cos(position * div_term) 32 | pe = pe.unsqueeze(0).transpose(0, 1) 33 | self.register_buffer("pe", pe) 34 | 35 | def forward(self, x): 36 | x = x + self.pe[: x.size(0), :] 37 | return self.dropout(x) 38 | 39 | 40 | class TMTransformerDecoder(nn.TransformerDecoder): 41 | def forward( # type: ignore 42 | self, 43 | tgt: Tensor, 44 | memory: Optional[Tensor] = None, 45 | cache: Optional[Tensor] = None, 46 | memory_mask: Optional[Tensor] = None, 47 | tgt_key_padding_mask: Optional[Tensor] = None, 48 | memory_key_padding_mask: Optional[Tensor] = None, 49 | ) -> Tensor: 50 | """ 51 | Args: 52 | tgt (Tensor): encoded tags. (tags_len,bsz,hidden_dim) 53 | memory (Tensor): encoded image (enc_image_size,bsz,hidden_dim) 54 | cache (Optional[Tensor]): None during training, only used during inference. 
55 | Returns: 56 | output (Tensor): (tags_len,bsz,hidden_dim) 57 | """ 58 | 59 | output = tgt 60 | 61 | # cache 62 | tag_cache = [] 63 | for i, mod in enumerate(self.layers): 64 | output = mod(output, memory) 65 | tag_cache.append(output) 66 | if cache is not None: 67 | output = torch.cat([cache[i], output], dim=0) 68 | 69 | if cache is not None: 70 | out_cache = torch.cat([cache, torch.stack(tag_cache, dim=0)], dim=1) 71 | else: 72 | out_cache = torch.stack(tag_cache, dim=0) 73 | 74 | return output, out_cache # type: ignore 75 | 76 | 77 | class TMTransformerDecoderLayer(nn.TransformerDecoderLayer): 78 | def forward( # type: ignore 79 | self, 80 | tgt: Tensor, 81 | memory: Optional[Tensor] = None, 82 | memory_mask: Optional[Tensor] = None, 83 | tgt_key_padding_mask: Optional[Tensor] = None, 84 | memory_key_padding_mask: Optional[Tensor] = None, 85 | ) -> Tensor: 86 | """ 87 | Args: 88 | same as TMTransformerDecoder 89 | Returns: 90 | Tensor: 91 | During training (seq_len,bsz,hidden_dim) 92 | If eval mode: embedding of last tag: (1,bsz,hidden_dim) 93 | """ 94 | 95 | # From PyTorch but modified to only use the last tag 96 | tgt_last_tok = tgt[-1:, :, :] 97 | 98 | tmp_tgt = self.self_attn( 99 | tgt_last_tok, 100 | tgt, 101 | tgt, 102 | attn_mask=None, # None, because we only care about the last tag 103 | key_padding_mask=tgt_key_padding_mask, 104 | need_weights=False, # Optimization: Don't compute attention weights 105 | )[0] 106 | tgt_last_tok = tgt_last_tok + self.dropout1(tmp_tgt) 107 | tgt_last_tok = self.norm1(tgt_last_tok) 108 | 109 | if memory is not None: 110 | tmp_tgt = self.multihead_attn( 111 | tgt_last_tok, 112 | memory, 113 | memory, 114 | attn_mask=memory_mask, 115 | key_padding_mask=memory_key_padding_mask, 116 | need_weights=False, # Optimization: Don't compute attention weights 117 | )[0] 118 | tgt_last_tok = tgt_last_tok + self.dropout2(tmp_tgt) 119 | tgt_last_tok = self.norm2(tgt_last_tok) 120 | 121 | tmp_tgt = self.linear2( 122 | self.dropout(self.activation(self.linear1(tgt_last_tok))) 123 | ) 124 | tgt_last_tok = tgt_last_tok + self.dropout3(tmp_tgt) 125 | tgt_last_tok = self.norm3(tgt_last_tok) 126 | return tgt_last_tok 127 | 128 | 129 | class Tag_Transformer(nn.Module): 130 | """ 131 | "Attention Is All You Need" - https://arxiv.org/abs/1706.03762 132 | """ 133 | 134 | def __init__( 135 | self, 136 | device, 137 | vocab_size, 138 | td_encode, 139 | embed_dim, 140 | encoder_layers, 141 | decoder_layers, 142 | enc_image_size, 143 | dropout=0.1, 144 | n_heads=4, 145 | dim_ff=1024, 146 | ): 147 | 148 | super(Tag_Transformer, self).__init__() 149 | 150 | self._device = device 151 | self._n_heads = n_heads 152 | self._embedding = nn.Embedding(vocab_size, embed_dim) 153 | self._positional_encoding = PositionalEncoding(embed_dim) 154 | self._td_encode = td_encode 155 | 156 | encoder_layer = nn.TransformerEncoderLayer( 157 | d_model=embed_dim, nhead=n_heads, dim_feedforward=dim_ff 158 | ) 159 | self._encoder = nn.TransformerEncoder( 160 | encoder_layer, num_layers=encoder_layers, enable_nested_tensor=False 161 | ) 162 | 163 | self._decoder = TMTransformerDecoder( 164 | TMTransformerDecoderLayer( 165 | d_model=embed_dim, 166 | nhead=n_heads, 167 | dim_feedforward=dim_ff, 168 | ), 169 | num_layers=decoder_layers, 170 | ) 171 | 172 | self._decoder_dim = embed_dim 173 | self._enc_image_size = enc_image_size 174 | self._input_filter = u.resnet_block(stride=1) 175 | self._fc = nn.Linear(embed_dim, vocab_size) 176 | 177 | def inference(self, enc_inputs, tags, tag_lens, num_cells): 178 
| # CNN backbone image encoding 179 | enc_inputs = self._input_filter(enc_inputs.permute(0, 3, 1, 2)).permute( 180 | 0, 2, 3, 1 181 | ) 182 | 183 | batch_size = enc_inputs.size(0) 184 | encoder_dim = enc_inputs.size(-1) 185 | 186 | enc_inputs = enc_inputs.view(batch_size, -1, encoder_dim).to(self._device) 187 | 188 | enc_inputs = enc_inputs.permute(1, 0, 2) 189 | positions = enc_inputs.shape[0] 190 | # Transformer Encoder Encoded Image mask need to check if its useful 191 | encoder_mask = torch.zeros( 192 | (batch_size * self._n_heads, positions, positions), device=self._device 193 | ) == torch.ones( 194 | (batch_size * self._n_heads, positions, positions), device=self._device 195 | ) 196 | 197 | # Transformer Encoder 198 | encoder_out = self._encoder(enc_inputs, mask=encoder_mask) 199 | 200 | decode_lengths = (tag_lens - 1).tolist() 201 | 202 | tgt = self._positional_encoding(self._embedding(tags).permute(1, 0, 2)) 203 | 204 | decoded = self._decoder(tgt, memory=encoder_out) 205 | decoded = decoded.permute(1, 0, 2) 206 | predictions = self._fc(decoded) 207 | return predictions, decode_lengths 208 | -------------------------------------------------------------------------------- /.github/workflows/dco-advisor.yml: -------------------------------------------------------------------------------- 1 | name: DCO Advisor Bot 2 | 3 | on: 4 | pull_request_target: 5 | types: [opened, reopened, synchronize] 6 | 7 | permissions: 8 | pull-requests: write 9 | issues: write 10 | 11 | jobs: 12 | dco_advisor: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Handle DCO check result 16 | uses: actions/github-script@v7 17 | with: 18 | github-token: ${{ secrets.GITHUB_TOKEN }} 19 | script: | 20 | const pr = context.payload.pull_request || context.payload.check_run?.pull_requests?.[0]; 21 | if (!pr) return; 22 | 23 | const prNumber = pr.number; 24 | const baseRef = pr.base.ref; 25 | const headSha = 26 | context.payload.check_run?.head_sha || 27 | pr.head?.sha; 28 | const username = pr.user.login; 29 | 30 | console.log("HEAD SHA:", headSha); 31 | 32 | const sleep = ms => new Promise(resolve => setTimeout(resolve, ms)); 33 | 34 | // Poll until DCO check has a conclusion (max 6 attempts, 30s) 35 | let dcoCheck = null; 36 | for (let attempt = 0; attempt < 6; attempt++) { 37 | const { data: checks } = await github.rest.checks.listForRef({ 38 | owner: context.repo.owner, 39 | repo: context.repo.repo, 40 | ref: headSha 41 | }); 42 | 43 | 44 | console.log("All check runs:"); 45 | checks.check_runs.forEach(run => { 46 | console.log(`- ${run.name} (${run.status}/${run.conclusion}) @ ${run.head_sha}`); 47 | }); 48 | 49 | dcoCheck = checks.check_runs.find(run => 50 | run.name.toLowerCase().includes("dco") && 51 | !run.name.toLowerCase().includes("dco_advisor") && 52 | run.head_sha === headSha 53 | ); 54 | 55 | 56 | if (dcoCheck?.conclusion) break; 57 | console.log(`Waiting for DCO check... (${attempt + 1})`); 58 | await sleep(5000); // wait 5 seconds 59 | } 60 | 61 | if (!dcoCheck || !dcoCheck.conclusion) { 62 | console.log("DCO check did not complete in time."); 63 | return; 64 | } 65 | 66 | const isFailure = ["failure", "action_required"].includes(dcoCheck.conclusion); 67 | console.log(`DCO check conclusion for ${headSha}: ${dcoCheck.conclusion} (treated as ${isFailure ? 
"failure" : "success"})`); 68 | 69 | // Parse DCO output for commit SHAs and author 70 | let badCommits = []; 71 | let authorName = ""; 72 | let authorEmail = ""; 73 | let moreInfo = `More info: [DCO check report](${dcoCheck?.html_url})`; 74 | 75 | if (isFailure) { 76 | const { data: commits } = await github.rest.pulls.listCommits({ 77 | owner: context.repo.owner, 78 | repo: context.repo.repo, 79 | pull_number: prNumber, 80 | }); 81 | 82 | for (const commit of commits) { 83 | const commitMessage = commit.commit.message; 84 | const signoffMatch = commitMessage.match(/^Signed-off-by:\s+.+<.+>$/m); 85 | if (!signoffMatch) { 86 | console.log(`Bad commit found ${commit.sha}`) 87 | badCommits.push({ 88 | sha: commit.sha, 89 | authorName: commit.commit.author.name, 90 | authorEmail: commit.commit.author.email, 91 | }); 92 | } 93 | } 94 | } 95 | 96 | // If multiple authors are present, you could adapt the message accordingly 97 | // For now, we'll just use the first one 98 | if (badCommits.length > 0) { 99 | authorName = badCommits[0].authorName; 100 | authorEmail = badCommits[0].authorEmail; 101 | } 102 | 103 | // Generate remediation commit message if needed 104 | let remediationSnippet = ""; 105 | if (badCommits.length && authorEmail) { 106 | remediationSnippet = `git commit --allow-empty -s -m "DCO Remediation Commit for ${authorName} <${authorEmail}>\n\n` + 107 | badCommits.map(c => `I, ${c.authorName} <${c.authorEmail}>, hereby add my Signed-off-by to this commit: ${c.sha}`).join('\n') + 108 | `"`; 109 | } else { 110 | remediationSnippet = "# Unable to auto-generate remediation message. Please check the DCO check details."; 111 | } 112 | 113 | // Build comment 114 | const commentHeader = ''; 115 | let body = ""; 116 | 117 | if (isFailure) { 118 | body = [ 119 | commentHeader, 120 | '❌ **DCO Check Failed**', 121 | '', 122 | `Hi @${username}, your pull request has failed the Developer Certificate of Origin (DCO) check.`, 123 | '', 124 | 'This repository supports **remediation commits**, so you can fix this without rewriting history — but you must follow the required message format.', 125 | '', 126 | '---', 127 | '', 128 | '### 🛠 Quick Fix: Add a remediation commit', 129 | 'Run this command:', 130 | '', 131 | '```bash', 132 | remediationSnippet, 133 | 'git push', 134 | '```', 135 | '', 136 | '---', 137 | '', 138 | '
<details>', 139 | '<summary>🔧 Advanced: Sign off each commit directly</summary>', 140 | '', 141 | '**For the latest commit:**', 142 | '```bash', 143 | 'git commit --amend --signoff', 144 | 'git push --force-with-lease', 145 | '```', 146 | '', 147 | '**For multiple commits:**', 148 | '```bash', 149 | `git rebase --signoff origin/${baseRef}`, 150 | 'git push --force-with-lease', 151 | '```', 152 | '', 153 | '</details>
', 154 | '', 155 | moreInfo 156 | ].join('\n'); 157 | } else { 158 | body = [ 159 | commentHeader, 160 | '✅ **DCO Check Passed**', 161 | '', 162 | `Thanks @${username}, all your commits are properly signed off. 🎉` 163 | ].join('\n'); 164 | } 165 | 166 | // Get existing comments on the PR 167 | const { data: comments } = await github.rest.issues.listComments({ 168 | owner: context.repo.owner, 169 | repo: context.repo.repo, 170 | issue_number: prNumber 171 | }); 172 | 173 | // Look for a previous bot comment 174 | const existingComment = comments.find(c => 175 | c.body.includes("") 176 | ); 177 | 178 | if (existingComment) { 179 | await github.rest.issues.updateComment({ 180 | owner: context.repo.owner, 181 | repo: context.repo.repo, 182 | comment_id: existingComment.id, 183 | body: body 184 | }); 185 | } else { 186 | await github.rest.issues.createComment({ 187 | owner: context.repo.owner, 188 | repo: context.repo.repo, 189 | issue_number: prNumber, 190 | body: body 191 | }); 192 | } 193 | -------------------------------------------------------------------------------- /docling_ibm_models/tableformer/utils/app_profiler.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp. 2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | import time 6 | from collections import deque 7 | from statistics import mean, median 8 | 9 | from docling_ibm_models.tableformer.utils.mem_monitor import MemMonitor 10 | 11 | 12 | class SingletonClass(type): 13 | r""" 14 | Generic singleton metaclass 15 | """ 16 | 17 | def __init__(self, name, bases, dic): 18 | self._instance = None 19 | super().__init__(name, bases, dic) 20 | 21 | def __call__(cls, *args, **kwargs): 22 | # Create a singleton if needed 23 | if cls._instance is None: 24 | singleton = cls.__new__(cls) 25 | singleton.__init__(*args, **kwargs) 26 | cls._instance = singleton 27 | return cls._instance 28 | 29 | 30 | class Profiler: 31 | r""" 32 | Application specific profiler 33 | Decompose the application into "sections". Each section is a label. 
34 | The total time a section consumes is split into "intervals" 35 | Use the `begin`, `end` methods to mark the begining and end of an interval for 36 | a certain section 37 | """ 38 | 39 | def __init__(self): 40 | self._section_dts = {} # section name -> sum(section intervals) 41 | self._section_calls = {} # section name -> number of invocations 42 | self._section_kB = {} # section name -> max kB of used heap (resident set size) 43 | 44 | # section name -> beginning of the last interval 45 | self._last_begin = {} 46 | 47 | self._mem_monitor = MemMonitor() 48 | 49 | def begin(self, section_name, enable=True): 50 | r""" 51 | Mark the beginning of an interval 52 | 53 | Parameters 54 | ---------- 55 | section_name : string 56 | Name of the section 57 | enable : bool 58 | The actual interval entry takes place only if enable is true 59 | 60 | Return 61 | ------ 62 | True if the interval has actuall begun 63 | """ 64 | if not enable: 65 | return False 66 | self._last_begin[section_name] = time.time() 67 | return True 68 | 69 | def end(self, section_name, enable=True): 70 | r""" 71 | Mark the end of an interval for a certain section 72 | 73 | Parameters 74 | ---------- 75 | section_name : string 76 | Name of the section 77 | enable : bool 78 | The actual interval entry takes place only if enable is true 79 | 80 | Return 81 | ------ 82 | True if the section name is valid and an interval for this section has already begun 83 | False otherwise 84 | """ 85 | if not enable: 86 | return False 87 | if section_name not in self._last_begin: 88 | return False 89 | 90 | # Get memory 91 | kB = self._mem_monitor.get_memory() 92 | if isinstance(kB, dict): 93 | kB = kB["resident"] 94 | 95 | dt = time.time() - self._last_begin[section_name] 96 | if section_name not in self._section_dts: 97 | self._section_dts[section_name] = dt 98 | self._section_calls[section_name] = 1 99 | self._section_kB[section_name] = kB 100 | else: 101 | self._section_dts[section_name] += dt 102 | self._section_calls[section_name] += 1 103 | self._section_kB[section_name] = max(kB, self._section_kB[section_name]) 104 | 105 | return True 106 | 107 | def get_data(self, section_names=None): 108 | r""" 109 | Return a dict with profiling data for the specified sections. 
110 | 111 | Parameter 112 | --------- 113 | section_names : list of string 114 | List with the section names to get their accumulative dt 115 | If it is None, all sections are returned 116 | 117 | Return 118 | ------ 119 | dict of dicts 120 | Outer key: section name 121 | Inner keys: "dt": Accumulative time for that section, "cells": Number of calls 122 | """ 123 | # Filter the section names to apply 124 | filtered_names = list( 125 | filter(lambda x: x in section_names, self._section_dts.keys()) 126 | if section_names is not None 127 | else self._section_dts.keys() 128 | ) 129 | data = {} 130 | for section_name in filtered_names: 131 | data[section_name] = { 132 | "dt": self._section_dts[section_name], 133 | "calls": self._section_calls[section_name], 134 | "kB": self._section_kB[section_name], 135 | } 136 | return data 137 | 138 | 139 | class AppProfiler(Profiler, metaclass=SingletonClass): 140 | r""" 141 | AppProfiler is a singleton of the Profiler for application wide usage 142 | """ 143 | 144 | def __init__(self): 145 | super(AppProfiler, self).__init__() 146 | 147 | 148 | class AggProfiler(metaclass=SingletonClass): 149 | r""" 150 | Generic wrapper of Profiler that enables aggregation of profiling statistics around Cycles 151 | 152 | - When a new cycle begins a new Profiler is created to keep the profiling data per section 153 | - Keep the last n cycles in a sliding window manner 154 | - At every time we can get profiling data about the last cycle and statistics over the last n 155 | cycles 156 | """ 157 | 158 | def __init__(self, window_size=20): 159 | self._window_size = window_size 160 | # deque with up to the last "window_size" Profilers. The newest at index 0 161 | self._cycles = deque() 162 | 163 | def start_agg(self, enable=True): 164 | r""" 165 | Returns 166 | ------- 167 | 0: not enabled 168 | 1: a new scope has started 169 | """ 170 | if not enable: 171 | return 0 172 | 173 | # Add a new profiler 174 | self._cycles.appendleft(Profiler()) 175 | # In case the deque has grown too much, remove the oldest Profiler 176 | if len(self._cycles) > self._window_size: 177 | self._cycles.pop() 178 | return 1 179 | 180 | def begin(self, section_name, enable=True): 181 | if not enable: 182 | return False 183 | if len(self._cycles) == 0: 184 | print("AggProfiler begin | Start Aggregator not initialized.") 185 | return False 186 | profiler = self._cycles[0] 187 | return profiler.begin(section_name) 188 | 189 | def end(self, section_name, enable=True): 190 | if not enable: 191 | return False 192 | if len(self._cycles) == 0: 193 | print("AggProfiler end | Start Aggregator not initialized.") 194 | return False 195 | profiler = self._cycles[0] 196 | return profiler.end(section_name) 197 | 198 | def get_data(self): 199 | r""" 200 | Get profiling data for: 201 | - The last cycle 202 | - Aggragated statistics (avg, median) per section and per metric across all cycles 203 | - The dt numbers for the mean/median is the average time for each section ACROSS the cycle 204 | - There is NO need to compute average by yourself. 
205 | 206 | Returns 207 | ------- 208 | dict with the structure: 209 | - window: int with the size of the time sliding window 210 | - last: dict with the metrics for the last cycle (as provided by the Profiler) 211 | - mean: dict with the mean metrics per section across the cycle 212 | - section_name 213 | - metric_name: mean of the metric values 214 | - median: dict with the median metrics per section across the cycle 215 | - section_name 216 | - metric_name: median of the metric values 217 | """ 218 | last_data = self._cycles[0].get_data() 219 | data = { 220 | "window": len(self._cycles), 221 | "last": last_data, 222 | "mean": {}, 223 | "median": {}, 224 | } 225 | 226 | # Section -> metric -> [values] 227 | section_metric_values = {} 228 | 229 | # Collect the metrics 230 | for i, p in enumerate(self._cycles): 231 | p_data = p.get_data() 232 | for section_name, m_dict in p_data.items(): 233 | for m_name, m_val in m_dict.items(): 234 | if section_name not in section_metric_values: 235 | section_metric_values[section_name] = {} 236 | s_metrics = section_metric_values[section_name] 237 | if m_name not in s_metrics: 238 | s_metrics[m_name] = [] 239 | s_metrics[m_name].append(m_val) 240 | 241 | # Aggregate the metrics 242 | for section_name, m_dict in section_metric_values.items(): 243 | for m_name, m_values in m_dict.items(): 244 | if section_name not in data["mean"]: 245 | data["mean"][section_name] = {} 246 | if section_name not in data["median"]: 247 | data["median"][section_name] = {} 248 | 249 | mean_v = mean(m_values) 250 | median_v = median(m_values) 251 | data["mean"][section_name][m_name] = mean_v 252 | data["median"][section_name][m_name] = median_v 253 | 254 | return data 255 | -------------------------------------------------------------------------------- /docling_ibm_models/code_formula_model/models/sam_opt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # This file is part of the Vary project, originally located at: 4 | # https://github.com/Ucas-HaoranWei/Vary-toy/blob/main/Vary-master/vary/model/vary_opt.py 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
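A minimal usage sketch of the `AggProfiler` defined in app_profiler.py above; the call pattern is inferred from the code, and the section name "predict" is purely illustrative:

```python
from docling_ibm_models.tableformer.utils.app_profiler import AggProfiler

profiler = AggProfiler()       # singleton; by default the last 20 cycles are kept
for _ in range(3):             # e.g. one cycle per processed page
    profiler.start_agg()       # open a new cycle
    profiler.begin("predict")  # start an interval for the "predict" section
    ...                        # the work being measured goes here
    profiler.end("predict")    # close the interval, recording time and resident memory

stats = profiler.get_data()
print(stats["window"])                   # number of cycles in the sliding window (3)
print(stats["mean"]["predict"]["dt"])    # mean accumulated time per cycle
print(stats["median"]["predict"]["kB"])  # median resident-set size in kB
```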
17 | 18 | 19 | from typing import List, Optional, Tuple, Union 20 | 21 | import torch 22 | import torch.nn as nn 23 | from transformers import ( 24 | AutoConfig, 25 | AutoModelForCausalLM, 26 | OPTConfig, 27 | OPTForCausalLM, 28 | OPTModel, 29 | ) 30 | from transformers.modeling_outputs import ( 31 | BaseModelOutputWithPast, 32 | CausalLMOutputWithPast, 33 | ) 34 | 35 | from docling_ibm_models.code_formula_model.models.sam import build_sam_vit_b 36 | 37 | 38 | class SamOptConfig(OPTConfig): 39 | model_type = "sam_opt" 40 | 41 | def __init__( 42 | self, 43 | sam_image_size=1024, 44 | sam_mm_projector_in=1024, 45 | sam_mm_projector_out=768, 46 | **kwargs, 47 | ): 48 | super().__init__(**kwargs) 49 | self.sam_image_size = sam_image_size 50 | self.sam_mm_projector_in = sam_mm_projector_in 51 | self.sam_mm_projector_out = sam_mm_projector_out 52 | 53 | 54 | class SamOPTModel(OPTModel): 55 | config_class = SamOptConfig # type: ignore 56 | 57 | def __init__(self, config: OPTConfig): 58 | super(SamOPTModel, self).__init__(config) 59 | self.vision_tower = build_sam_vit_b(image_size=config.sam_image_size) 60 | 61 | self.mm_projector = nn.Linear( 62 | config.sam_mm_projector_in, config.sam_mm_projector_out 63 | ) 64 | 65 | def embed_tokens(self, x): 66 | return self.get_input_embeddings()(x) 67 | 68 | def forward( 69 | self, 70 | input_ids: torch.LongTensor, 71 | attention_mask: Optional[torch.Tensor] = None, 72 | past_key_values: Optional[List[torch.FloatTensor]] = None, 73 | inputs_embeds: Optional[torch.FloatTensor] = None, 74 | use_cache: Optional[bool] = None, 75 | output_attentions: Optional[bool] = None, 76 | output_hidden_states: Optional[bool] = None, 77 | images: Optional[torch.FloatTensor] = None, 78 | return_dict: Optional[bool] = None, 79 | ) -> Union[Tuple, BaseModelOutputWithPast]: 80 | 81 | if inputs_embeds is None: 82 | inputs_embeds = self.embed_tokens(input_ids) 83 | 84 | vision_tower = getattr(self, "vision_tower", None) 85 | im_start_token = getattr(self.config, "im_start_token", -1) # type: ignore 86 | 87 | if input_ids.shape[1] != 1 or self.training: # type: ignore 88 | with torch.set_grad_enabled(self.training): # type: ignore 89 | assert vision_tower is not None 90 | image_features = vision_tower(images) 91 | image_features = image_features.flatten(2).permute(0, 2, 1) 92 | image_features = self.mm_projector(image_features) 93 | 94 | new_input_embeds = [] 95 | for cur_input_ids, cur_input_embeds, cur_image_features in zip( 96 | input_ids, inputs_embeds, image_features 97 | ): 98 | image_start_token_position = int( 99 | torch.where(cur_input_ids == im_start_token)[0].item() 100 | ) # cast to int for mypy 101 | 102 | cur_image_features = cur_image_features.to( 103 | device=cur_input_embeds.device 104 | ) 105 | num_patches = cur_image_features.shape[0] 106 | cur_input_embeds = torch.cat( 107 | ( 108 | cur_input_embeds[: image_start_token_position + 1], 109 | cur_image_features, 110 | cur_input_embeds[ 111 | image_start_token_position + num_patches + 1 : 112 | ], 113 | ), 114 | dim=0, 115 | ) 116 | 117 | new_input_embeds.append(cur_input_embeds) 118 | 119 | inputs_embeds = torch.stack(new_input_embeds, dim=0) # type: ignore 120 | 121 | return super(SamOPTModel, self).forward( # type: ignore 122 | input_ids=None, 123 | attention_mask=attention_mask, 124 | past_key_values=past_key_values, 125 | inputs_embeds=inputs_embeds, 126 | use_cache=use_cache, 127 | output_attentions=output_attentions, 128 | output_hidden_states=output_hidden_states, 129 | return_dict=return_dict, 130 | ) 131 
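The forward pass above splices the projected SAM patch features into the token embedding sequence, replacing the placeholder span that follows the image-start token. A self-contained sketch of that splice; the helper name and tensor sizes below are illustrative assumptions, not part of the model:

```python
import torch

def splice_image_features(input_embeds, image_feats, im_start_pos):
    """Insert image patch embeddings right after the image-start token.

    input_embeds: (seq_len, hidden) text embeddings with a reserved placeholder span
    image_feats:  (num_patches, hidden) projected vision-tower output
    """
    num_patches = image_feats.shape[0]
    return torch.cat(
        (
            input_embeds[: im_start_pos + 1],                # prompt incl. start token
            image_feats,                                     # image patch embeddings
            input_embeds[im_start_pos + 1 + num_patches :],  # remainder of the prompt
        ),
        dim=0,
    )

# 256 placeholder positions were reserved after the start token at index 10
emb = torch.zeros(300, 768)
img = torch.randn(256, 768)
out = splice_image_features(emb, img, im_start_pos=10)
assert out.shape == emb.shape  # sequence length is unchanged
```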
| 132 | 133 | class SamOPTForCausalLM(OPTForCausalLM): 134 | config_class = SamOptConfig # type: ignore 135 | 136 | def __init__(self, config): 137 | super(OPTForCausalLM, self).__init__(config) 138 | self.model = SamOPTModel(config) 139 | 140 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 141 | 142 | self.post_init() 143 | 144 | def get_model(self): 145 | return self.model 146 | 147 | def forward( 148 | self, 149 | input_ids: Optional[torch.LongTensor] = None, 150 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 151 | attention_mask: Optional[torch.FloatTensor] = None, 152 | token_type_ids: Optional[torch.LongTensor] = None, 153 | position_ids: Optional[torch.LongTensor] = None, 154 | head_mask: Optional[torch.FloatTensor] = None, 155 | inputs_embeds: Optional[torch.FloatTensor] = None, 156 | encoder_hidden_states: Optional[torch.Tensor] = None, 157 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 158 | labels: Optional[torch.LongTensor] = None, 159 | use_cache: Optional[bool] = None, 160 | output_attentions: Optional[bool] = None, 161 | output_hidden_states: Optional[bool] = None, 162 | images: Optional[torch.FloatTensor] = None, 163 | return_dict: Optional[bool] = None, 164 | ) -> Union[Tuple, CausalLMOutputWithPast]: 165 | output_attentions = ( 166 | output_attentions 167 | if output_attentions is not None 168 | else self.config.output_attentions # type: ignore 169 | ) 170 | output_hidden_states = ( 171 | output_hidden_states 172 | if output_hidden_states is not None 173 | else self.config.output_hidden_states # type: ignore 174 | ) 175 | 176 | outputs = self.model( 177 | input_ids=input_ids, 178 | past_key_values=past_key_values, 179 | attention_mask=attention_mask, 180 | inputs_embeds=inputs_embeds, 181 | use_cache=use_cache, 182 | output_attentions=output_attentions, 183 | output_hidden_states=output_hidden_states, 184 | images=images, 185 | return_dict=return_dict, 186 | ) 187 | 188 | hidden_states = outputs[0] 189 | logits = self.lm_head(hidden_states).contiguous() 190 | 191 | return CausalLMOutputWithPast( 192 | loss=None, 193 | logits=logits, 194 | past_key_values=outputs.past_key_values, 195 | hidden_states=outputs.hidden_states, 196 | attentions=outputs.attentions, 197 | ) 198 | 199 | def prepare_inputs_for_generation( 200 | self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs 201 | ): 202 | token_type_ids = kwargs.get("token_type_ids", None) 203 | if past_key_values: 204 | input_ids = input_ids[:, -1].unsqueeze(-1) 205 | if token_type_ids is not None: 206 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1) 207 | 208 | attention_mask = kwargs.get("attention_mask", None) 209 | position_ids = kwargs.get("position_ids", None) 210 | 211 | if attention_mask is not None and position_ids is None: 212 | position_ids = attention_mask.long().cumsum(-1) - 1 213 | position_ids.masked_fill_(attention_mask == 0, 1) 214 | if past_key_values: 215 | position_ids = position_ids[:, -1].unsqueeze(-1) 216 | else: 217 | position_ids = None 218 | 219 | if inputs_embeds is not None and past_key_values is None: 220 | model_inputs = {"inputs_embeds": inputs_embeds} 221 | else: 222 | model_inputs = {"input_ids": input_ids} 223 | 224 | model_inputs.update( 225 | { 226 | "past_key_values": past_key_values, 227 | "use_cache": kwargs.get("use_cache"), 228 | "position_ids": position_ids, 229 | "attention_mask": attention_mask, 230 | "token_type_ids": token_type_ids, 231 | "images": kwargs.get("images", None), 232 | } 233 | ) 234 | 
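        # model_inputs now carries either the full prompt (first step) or only the
        # newest token id (once past_key_values exist), plus the image tensor that
        # is re-passed on every generation step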
return model_inputs 235 | 236 | 237 | AutoConfig.register("sam_opt", SamOptConfig) 238 | AutoModelForCausalLM.register(SamOptConfig, SamOPTForCausalLM) 239 | -------------------------------------------------------------------------------- /docling_ibm_models/layoutmodel/layout_predictor.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp. 2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | import logging 6 | import os 7 | import threading 8 | from collections.abc import Iterable 9 | from typing import Dict, List, Set, Union 10 | 11 | import numpy as np 12 | import torch 13 | from PIL import Image 14 | from torch import Tensor 15 | from transformers import AutoModelForObjectDetection, RTDetrImageProcessor 16 | 17 | from docling_ibm_models.layoutmodel.labels import LayoutLabels 18 | 19 | _log = logging.getLogger(__name__) 20 | 21 | # Global lock for model initialization to prevent threading issues 22 | _model_init_lock = threading.Lock() 23 | 24 | 25 | class LayoutPredictor: 26 | """ 27 | Document layout prediction using safe tensors 28 | """ 29 | 30 | def __init__( 31 | self, 32 | artifact_path: str, 33 | device: str = "cpu", 34 | num_threads: int = 4, 35 | base_threshold: float = 0.3, 36 | blacklist_classes: Set[str] = set(), 37 | ): 38 | """ 39 | Provide the artifact path that contains the LayoutModel file 40 | 41 | Parameters 42 | ---------- 43 | artifact_path: Path for the model torch file. 44 | device: (Optional) device to run the inference. 45 | num_threads: (Optional) Number of threads to run the inference if device = 'cpu' 46 | 47 | Raises 48 | ------ 49 | FileNotFoundError when the model's torch file is missing 50 | """ 51 | # Blacklisted classes 52 | self._black_classes = blacklist_classes # set(["Form", "Key-Value Region"]) 53 | 54 | # Canonical classes 55 | self._labels = LayoutLabels() 56 | 57 | # Set basic params 58 | self._threshold = base_threshold # Score threshold 59 | 60 | # Set number of threads for CPU 61 | self._device = torch.device(device) 62 | self._num_threads = num_threads 63 | if device == "cpu": 64 | torch.set_num_threads(self._num_threads) 65 | 66 | # Load model file and configurations 67 | self._processor_config = os.path.join(artifact_path, "preprocessor_config.json") 68 | self._model_config = os.path.join(artifact_path, "config.json") 69 | self._st_fn = os.path.join(artifact_path, "model.safetensors") 70 | if not os.path.isfile(self._st_fn): 71 | raise FileNotFoundError("Missing safe tensors file: {}".format(self._st_fn)) 72 | if not os.path.isfile(self._processor_config): 73 | raise FileNotFoundError( 74 | f"Missing processor config file: {self._processor_config}" 75 | ) 76 | if not os.path.isfile(self._model_config): 77 | raise FileNotFoundError(f"Missing model config file: {self._model_config}") 78 | 79 | # Load model and move to device 80 | self._image_processor = RTDetrImageProcessor.from_json_file( 81 | self._processor_config 82 | ) 83 | 84 | # Use lock to prevent threading issues during model initialization 85 | with _model_init_lock: 86 | self._model = AutoModelForObjectDetection.from_pretrained( 87 | artifact_path, config=self._model_config, device_map=self._device 88 | ) 89 | self._model.eval() 90 | 91 | # Set classes map 92 | self._model_name = type(self._model).__name__ 93 | if self._model_name == "RTDetrForObjectDetection": 94 | self._classes_map = self._labels.shifted_canonical_categories() 95 | self._label_offset = 1 96 | else: 97 | self._classes_map = 
self._labels.canonical_categories() 98 | self._label_offset = 0 99 | 100 | _log.debug("LayoutPredictor settings: {}".format(self.info())) 101 | 102 | def info(self) -> dict: 103 | """ 104 | Get information about the configuration of LayoutPredictor 105 | """ 106 | info = { 107 | "model_name": self._model_name, 108 | "safe_tensors_file": self._st_fn, 109 | "device": self._device.type, 110 | "num_threads": self._num_threads, 111 | "image_size": self._image_processor.size, 112 | "threshold": self._threshold, 113 | } 114 | return info 115 | 116 | @torch.inference_mode() 117 | def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]: 118 | """ 119 | Predict bounding boxes for a given image. 120 | The origin (0, 0) is the top-left corner and the predicted bbox coords are provided as: 121 | [left, top, right, bottom] 122 | 123 | Parameter 124 | --------- 125 | origin_img: Image to be predicted as a PIL Image object or numpy array. 126 | 127 | Yield 128 | ----- 129 | Bounding box as a dict with the keys: "label", "confidence", "l", "t", "r", "b" 130 | 131 | Raises 132 | ------ 133 | TypeError when the input image is not supported 134 | """ 135 | # Convert image format 136 | if isinstance(orig_img, Image.Image): 137 | page_img = orig_img.convert("RGB") 138 | elif isinstance(orig_img, np.ndarray): 139 | page_img = Image.fromarray(orig_img).convert("RGB") 140 | else: 141 | raise TypeError("Not supported input image format") 142 | 143 | target_sizes = torch.tensor([page_img.size[::-1]]) 144 | inputs = self._image_processor(images=[page_img], return_tensors="pt").to( 145 | self._device 146 | ) 147 | outputs = self._model(**inputs) 148 | results: List[Dict[str, Tensor]] = ( 149 | self._image_processor.post_process_object_detection( 150 | outputs, 151 | target_sizes=target_sizes, 152 | threshold=self._threshold, 153 | ) 154 | ) 155 | 156 | w, h = page_img.size 157 | result = results[0] 158 | for score, label_id, box in zip( 159 | result["scores"], result["labels"], result["boxes"] 160 | ): 161 | score = float(score.item()) 162 | 163 | label_id = int(label_id.item()) + self._label_offset 164 | label_str = self._classes_map[label_id] 165 | 166 | # Filter out blacklisted classes 167 | if label_str in self._black_classes: 168 | continue 169 | 170 | bbox_float = [float(b.item()) for b in box] 171 | l = min(w, max(0, bbox_float[0])) 172 | t = min(h, max(0, bbox_float[1])) 173 | r = min(w, max(0, bbox_float[2])) 174 | b = min(h, max(0, bbox_float[3])) 175 | yield { 176 | "l": l, 177 | "t": t, 178 | "r": r, 179 | "b": b, 180 | "label": label_str, 181 | "confidence": score, 182 | } 183 | 184 | @torch.inference_mode() 185 | def predict_batch( 186 | self, images: List[Union[Image.Image, np.ndarray]] 187 | ) -> List[List[dict]]: 188 | """ 189 | Batch prediction for multiple images - more efficient than calling predict() multiple times. 190 | 191 | Parameters 192 | ---------- 193 | images : List[Union[Image.Image, np.ndarray]] 194 | List of images to process in a single batch 195 | 196 | Returns 197 | ------- 198 | List[List[dict]] 199 | List of prediction lists, one per input image. 
Each prediction dict contains: 200 | "label", "confidence", "l", "t", "r", "b" 201 | """ 202 | if not images: 203 | return [] 204 | 205 | # Convert all images to RGB PIL format 206 | pil_images = [] 207 | for img in images: 208 | if isinstance(img, Image.Image): 209 | pil_images.append(img.convert("RGB")) 210 | elif isinstance(img, np.ndarray): 211 | pil_images.append(Image.fromarray(img).convert("RGB")) 212 | else: 213 | raise TypeError("Not supported input image format") 214 | 215 | # Get target sizes for all images 216 | target_sizes = torch.tensor([img.size[::-1] for img in pil_images]) 217 | 218 | # Process all images in a single batch 219 | inputs = self._image_processor(images=pil_images, return_tensors="pt").to( 220 | self._device 221 | ) 222 | outputs = self._model(**inputs) 223 | 224 | # Post-process all results at once 225 | results_list: List[Dict[str, Tensor]] = ( 226 | self._image_processor.post_process_object_detection( 227 | outputs, 228 | target_sizes=target_sizes, 229 | threshold=self._threshold, 230 | ) 231 | ) 232 | 233 | # Convert results to standard format for each image 234 | all_predictions = [] 235 | 236 | for img, results in zip(pil_images, results_list): 237 | w, h = img.size 238 | predictions = [] 239 | 240 | for score, label_id, box in zip( 241 | results["scores"], results["labels"], results["boxes"] 242 | ): 243 | score = float(score.item()) 244 | label_id = int(label_id.item()) + self._label_offset 245 | label_str = self._classes_map[label_id] 246 | 247 | # Filter out blacklisted classes 248 | if label_str in self._black_classes: 249 | continue 250 | 251 | bbox_float = [float(b.item()) for b in box] 252 | l = min(w, max(0, bbox_float[0])) 253 | t = min(h, max(0, bbox_float[1])) 254 | r = min(w, max(0, bbox_float[2])) 255 | b = min(h, max(0, bbox_float[3])) 256 | 257 | predictions.append( 258 | { 259 | "l": l, 260 | "t": t, 261 | "r": r, 262 | "b": b, 263 | "label": label_str, 264 | "confidence": score, 265 | } 266 | ) 267 | 268 | all_predictions.append(predictions) 269 | 270 | return all_predictions 271 | -------------------------------------------------------------------------------- /docling_ibm_models/code_formula_model/code_formula_predictor.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp. 
2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | import logging 6 | import threading 7 | from typing import List, Optional, Union 8 | 9 | import numpy as np 10 | import torch 11 | from PIL import Image 12 | from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList 13 | 14 | from docling_ibm_models.code_formula_model.models.sam_opt import SamOPTForCausalLM 15 | from docling_ibm_models.code_formula_model.models.sam_opt_image_processor import ( 16 | SamOptImageProcessor, 17 | ) 18 | 19 | _log = logging.getLogger(__name__) 20 | 21 | # Global lock for model initialization to prevent threading issues 22 | _model_init_lock = threading.Lock() 23 | 24 | 25 | class StopOnString(StoppingCriteria): 26 | def __init__(self, tokenizer, stop_string): 27 | self.stop_token_ids = tokenizer.encode(stop_string, add_special_tokens=False) 28 | 29 | def __call__(self, input_ids, scores, **kwargs): 30 | for sequence in input_ids: 31 | sequence_list = sequence.tolist() 32 | for i in range(len(sequence_list) - len(self.stop_token_ids) + 1): 33 | if ( 34 | sequence_list[i : i + len(self.stop_token_ids)] 35 | == self.stop_token_ids 36 | ): 37 | return True 38 | return False 39 | 40 | 41 | class CodeFormulaPredictor: 42 | """ 43 | Code and Formula Predictor using a multi-modal vision-language model. 44 | 45 | This class enables the prediction of code or LaTeX representations 46 | from input images of code snippets or mathematical formulas. 47 | 48 | Attributes 49 | ---------- 50 | _device : str 51 | The device on which the model is loaded (e.g., 'cpu' or 'cuda'). 52 | _num_threads : int 53 | Number of threads used for inference when running on CPU. 54 | _tokenizer : transformers.PreTrainedTokenizer 55 | Tokenizer for processing textual inputs to the model. 56 | _model : transformers.PreTrainedModel 57 | Pretrained multi-modal vision-language model. 58 | _image_processor : transformers.ImageProcessor 59 | Processor for normalizing and preparing input images. 60 | _temperature : float 61 | Sampling temperature for generation; controls randomness in predictions. 62 | """ 63 | 64 | def __init__( 65 | self, 66 | artifacts_path: str, 67 | device: str = "cpu", 68 | num_threads: int = 4, 69 | ): 70 | """ 71 | Initializes the CodeFormulaPredictor with the specified model artifacts. 72 | 73 | Parameters 74 | ---------- 75 | artifacts_path : str 76 | Path to the directory containing the pretrained model files. 77 | device : str, optional 78 | Device to run the inference on ('cpu' or 'cuda'), by default "cpu". 79 | num_threads : int, optional 80 | Number of threads for CPU inference, by default 4. 81 | """ 82 | self._device = device 83 | self._num_threads = num_threads 84 | if device == "cpu": 85 | torch.set_num_threads(self._num_threads) 86 | 87 | # Use lock to prevent threading issues during model initialization 88 | with _model_init_lock: 89 | self._tokenizer = AutoTokenizer.from_pretrained( 90 | artifacts_path, use_fast=True, padding_side="left" 91 | ) 92 | self._model = SamOPTForCausalLM.from_pretrained( 93 | artifacts_path, device_map=self._device 94 | ) 95 | self._model.eval() 96 | 97 | self._image_processor = SamOptImageProcessor.from_pretrained(artifacts_path) 98 | 99 | _log.debug("CodeFormulaModel settings: {}".format(self.info())) 100 | 101 | def info(self) -> dict: 102 | """ 103 | Retrieves configuration details of the CodeFormulaPredictor instance. 
104 | 105 | Returns 106 | ------- 107 | dict 108 | A dictionary containing configuration details such as the device and 109 | the number of threads used. 110 | """ 111 | info = { 112 | "device": self._device, 113 | "num_threads": self._num_threads, 114 | } 115 | return info 116 | 117 | def _get_prompt(self, label: str) -> str: 118 | """ 119 | Constructs the prompt for the model based on the input label. 120 | 121 | Parameters 122 | ---------- 123 | label : str 124 | The type of input, either 'code' or 'formula'. 125 | 126 | Returns 127 | ------- 128 | str 129 | The constructed prompt including necessary tokens and query. 130 | 131 | Raises 132 | ------ 133 | NotImplementedError 134 | If the label is not 'code' or 'formula'. 135 | """ 136 | if label == "code": 137 | query = "" 138 | elif label == "formula": 139 | query = "" 140 | else: 141 | raise NotImplementedError("Label must be either code or formula") 142 | 143 | prompt = ( 144 | "A chat between a curious user and an artificial intelligence" 145 | " assistant. The assistant gives helpful, detailed, and polite answers to" 146 | " the user's questions. USER: " 147 | ) 148 | prompt += ( 149 | "" + "" * 256 + "" + "\n" + " ASSISTANT:" + "\n" + query 150 | ) 151 | 152 | return prompt 153 | 154 | def _strip(self, text: str): 155 | """ 156 | Removes any occurrences of the substrings in remove_list from the end of text. 157 | 158 | Parameters 159 | ---------- 160 | text : str 161 | The original string. 162 | 163 | Returns 164 | ------- 165 | str 166 | The trimmed string. 167 | """ 168 | remove_list = [r"\quad", r"\\", r"\,", " c c c c", " l l l l l"] 169 | changed = True 170 | while changed: 171 | changed = False 172 | for substr in remove_list: 173 | if text.endswith(substr): 174 | text = text[: -len(substr)] 175 | changed = True 176 | 177 | return text.strip() 178 | 179 | @torch.inference_mode() 180 | def predict( 181 | self, 182 | images: List[Union[Image.Image, np.ndarray]], 183 | labels: List[str], 184 | temperature: Optional[float] = 0.0, 185 | ) -> List[str]: 186 | """ 187 | Predicts the textual representation of input images (code or LaTeX). 188 | 189 | Parameters 190 | ---------- 191 | images : List[Union[Image.Image, np.ndarray]] 192 | List of images to be processed, provided as PIL Image objects or numpy arrays. 193 | labels : List[str] 194 | List of labels indicating the type of each image ('code' or 'formula'). 195 | temperature : Optional[float] 196 | Sampling temperature for generation, by default set to 0.0. 197 | 198 | Returns 199 | ------- 200 | List[str] 201 | List of predicted textual outputs for each input image in the given input 202 | order. 203 | 204 | Raises 205 | ------ 206 | TypeError 207 | If any of the input images is not of a supported type (PIL Image or numpy array). 208 | Excpetion 209 | In case the temperature is an invalid number. 210 | """ 211 | if ( 212 | temperature is None 213 | or not (isinstance(temperature, float) or isinstance(temperature, int)) 214 | or temperature < 0 215 | ): 216 | raise Exception("Temperature must be a number greater or equal to 0.") 217 | 218 | do_sample = True 219 | if temperature == 0: 220 | do_sample = False 221 | temperature = None 222 | 223 | if len(labels) != len(images): 224 | raise Exception( 225 | "The number of images must be the same as the number of labels." 
226 | ) 227 | 228 | images_tmp = [] 229 | for image in images: 230 | if isinstance(image, Image.Image): 231 | image = image.convert("RGB") 232 | elif isinstance(image, np.ndarray): 233 | image = Image.fromarray(image).convert("RGB") 234 | else: 235 | raise TypeError("Not supported input image format") 236 | images_tmp.append(image) 237 | 238 | images_tensor = torch.stack( 239 | [self._image_processor(img) for img in images_tmp] 240 | ).to(self._device) 241 | 242 | prompts = [self._get_prompt(label) for label in labels] 243 | 244 | tokenized = self._tokenizer(prompts, padding=True, return_tensors="pt") 245 | tokenized = {k: v.to(self._device) for k, v in tokenized.items()} 246 | 247 | prompt_ids = tokenized["input_ids"] 248 | attention_mask = tokenized["attention_mask"] 249 | 250 | stopping_criteria = StoppingCriteriaList( 251 | [ 252 | StopOnString(self._tokenizer, r" \quad \quad \quad \quad"), 253 | StopOnString(self._tokenizer, r" \\ \\ \\ \\"), 254 | StopOnString(self._tokenizer, r" \, \, \, \,"), 255 | StopOnString(self._tokenizer, r" c c c c c c c c c c c c c c c c"), 256 | StopOnString(self._tokenizer, r" l l l l l l l l l l l l l l l l l"), 257 | ] 258 | ) 259 | 260 | if self._device == "cpu": 261 | output_ids_list = self._model.generate( 262 | input_ids=prompt_ids, 263 | attention_mask=attention_mask, 264 | images=images_tensor, 265 | do_sample=do_sample, 266 | temperature=temperature, 267 | max_new_tokens=4096 - prompt_ids.shape[1], 268 | use_cache=True, 269 | no_repeat_ngram_size=200, 270 | stopping_criteria=stopping_criteria, 271 | ) 272 | else: 273 | with torch.autocast(device_type=self._device, dtype=torch.bfloat16): 274 | output_ids_list = self._model.generate( 275 | prompt_ids, 276 | images=images_tensor, 277 | do_sample=do_sample, 278 | temperature=temperature, 279 | max_new_tokens=4096 - prompt_ids.shape[1], 280 | use_cache=True, 281 | no_repeat_ngram_size=200, 282 | stopping_criteria=stopping_criteria, 283 | ) 284 | 285 | outputs = self._tokenizer.batch_decode( 286 | output_ids_list[:, prompt_ids.shape[1] :], skip_special_tokens=True 287 | ) 288 | outputs = [self._strip(output) for output in outputs] 289 | 290 | return outputs 291 | -------------------------------------------------------------------------------- /docling_ibm_models/tableformer/models/common/base_model.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp. 2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | import glob 6 | import logging 7 | import os 8 | import time 9 | from abc import ABC, abstractmethod 10 | from pathlib import Path 11 | 12 | import torch 13 | 14 | import docling_ibm_models.tableformer.settings as s 15 | 16 | LOG_LEVEL = logging.INFO 17 | # LOG_LEVEL = logging.DEBUG 18 | 19 | 20 | class BaseModel(ABC): 21 | r""" 22 | BaseModel provides some common functionality for all models: 23 | - Saves checkpoint files for each epoch 24 | - Loads the model from the best available checkpoint 25 | - Save repository branch and commit 26 | """ 27 | 28 | def __init__(self, config, init_data, device): 29 | r""" 30 | Inputs: 31 | config: The configuration file 32 | init_data: Dictionary with initialization data. 
This dictionary can be used to pass any 33 | kind of initialization data for the models 34 | device: The device used to move the tensors of the model 35 | """ 36 | super(BaseModel, self).__init__() 37 | 38 | # Set config and device 39 | self._config = config 40 | self._init_data = init_data 41 | 42 | self._device = device 43 | 44 | self._save_dir = config["model"]["save_dir"] 45 | self._load_checkpoint = None 46 | if "load_checkpoint" in config["model"]: 47 | self._load_checkpoint = config["model"]["load_checkpoint"] 48 | 49 | self._branch_name = "dev/next" 50 | self._commit_sha = "1" 51 | 52 | # Keep a dictionary with the starting times per epoch. 53 | # NOTICE: Epochs start from 0 54 | self._epoch_start_ts = {0: time.time()} 55 | 56 | def _log(self): 57 | # Setup a custom logger 58 | return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL) 59 | 60 | @abstractmethod 61 | def predict(self, img, max_steps, beam_size, return_attention): 62 | pass 63 | 64 | def count_parameters(self): 65 | r"""Counts the number of trainable parameters of this model 66 | 67 | Output: 68 | num_parameters: number of trainable parameters 69 | """ 70 | num_parameters = sum(p.numel() for p in self.parameters() if p.requires_grad) 71 | 72 | return num_parameters 73 | 74 | def get_code_version(self): 75 | r"""Gets the source control version of this model code 76 | 77 | Returns 78 | ------- 79 | branch_name : str 80 | The name of the Git branch of this model code 81 | commit_sha : str 82 | The unique identifier of the Git commit of this model code 83 | """ 84 | 85 | return self._branch_name, self._commit_sha 86 | 87 | def get_save_directory(self): 88 | r""" 89 | Return the save directory 90 | """ 91 | return self._save_dir 92 | 93 | def is_saved(self): 94 | r""" 95 | This method returns True if both conditions are met: 96 | 1. There is a checkpoint file for the model. 97 | 2. The checkpoint file corresponds to the last training epoch set in the configuration file. 98 | """ 99 | # Get the saved_model 100 | saved_model, _ = self._load_best_checkpoint() 101 | 102 | if saved_model is None: 103 | return False 104 | 105 | epochs = self._config["train"]["epochs"] 106 | self._log().debug( 107 | "Best epoch in saved model: {}; Number of epochs in config: {}".format( 108 | saved_model["epoch"], epochs 109 | ) 110 | ) 111 | if epochs == saved_model["epoch"] + 1: 112 | return True 113 | 114 | return False 115 | 116 | def save(self, epoch=None, optimizers=None, losses=None, model_parameters=None): 117 | r""" 118 | Save the model data to the disk as a pickle file. 119 | 120 | Parameters 121 | ---------- 122 | epoch: Training epoch 123 | optimizers: Dictionary with the optimizers. The key specifies what the optimizer is 124 | used for. The 'state_dict' of each optimizer will be saved in the 125 | checkpoint file. 126 | losses: Dictionary with the losses. The key specifies what the loss is used for. Each 127 | value is a list 128 | model_parameters: Dictionary with model specific parameters that we need to save in the 129 | checkpoint file. 
130 | Returns 131 | ------- 132 | True if success, False otherwise 133 | """ 134 | # Get the checkpoint_filename 135 | c_filename = self._build_checkpoint_filename(epoch) 136 | self._log().debug("Trying to save checkpoint file: {}".format(c_filename)) 137 | 138 | # Prepare a dictionary with all data we want to save 139 | optimizers_state_dict = None 140 | if optimizers is not None: 141 | optimizers_state_dict = {k: v.state_dict() for k, v in optimizers.items()} 142 | 143 | model_data = { 144 | "model_state_dict": self.state_dict(), 145 | "epoch": epoch, 146 | "optimizers": optimizers_state_dict, 147 | "losses": losses, 148 | "model_parameters": model_parameters, 149 | } 150 | 151 | # Add the processing time per epoch 152 | now = time.time() 153 | self._epoch_start_ts[epoch + 1] = now 154 | if epoch in self._epoch_start_ts: 155 | dt = now - self._epoch_start_ts[epoch] 156 | model_data["epoch_start_ts"] = self._epoch_start_ts[epoch] 157 | model_data["epoch_dt"] = dt 158 | 159 | # Create the save directory 160 | Path(self._save_dir).mkdir(parents=True, exist_ok=True) 161 | 162 | # Save the model 163 | torch.save(model_data, c_filename) 164 | 165 | # Return true if file is present, otherwise false 166 | if not os.path.isfile(c_filename): 167 | self._log().error("Cannot find the file to save: " + c_filename) 168 | return False 169 | 170 | # store code branch name and commit 171 | version_file = os.path.join(self._save_dir, "_version") 172 | with open(version_file, "w") as text_file: 173 | print("Model is using code [commit:branch]", file=text_file) 174 | print("{}:{}".format(self._commit_sha, self._branch_name), file=text_file) 175 | 176 | return True 177 | 178 | def load(self, optimizers=None): 179 | r""" 180 | Load the model data from the disk. 181 | The method will iterate over all *.check files and try to load the one from the highest 182 | epoch. 183 | 184 | Input: 185 | -optimizers: Dictionary with optimizers. If it is not null the keys will be used to 186 | associate the corresponding state_dicts from the checkpoint file and update 187 | the internal states of the provided optimizers. 
188 | 189 | Output: 190 | - Success: True/ False 191 | - epoch: Loaded epoch or -1 if there are no checkpoint files 192 | - optimizers: Dictionary with loaded optimizers or empty dictionary of there is no 193 | checkpoint file 194 | - losses: Dictionary with loaded losses or empty dictionary of there is no checkpoint 195 | file 196 | - model_parameters: Dictionary with the model parameters or empty dictionary if there 197 | are no checkpoint files 198 | """ 199 | # Get the saved_model 200 | saved_model, _ = self._load_best_checkpoint() 201 | 202 | # Restore the model 203 | if saved_model is None: 204 | self._log().debug("No saved model checkpoint found") 205 | return False, -1, optimizers, {}, {} 206 | 207 | self._log().debug("Loading model from checkpoint file") 208 | self.load_state_dict(saved_model["model_state_dict"]) 209 | 210 | epoch = 0 211 | if "epoch" in saved_model: 212 | epoch = saved_model["epoch"] 213 | losses = {} 214 | if "losses" in saved_model: 215 | losses = saved_model["losses"] 216 | model_parameters = saved_model["model_parameters"] 217 | 218 | if optimizers is not None: 219 | for key, optimizer_state_dict in saved_model["optimizers"].items(): 220 | optimizers[key].load_state_dict(optimizer_state_dict) 221 | 222 | # Reset the start_ts of the next epoch 223 | self._epoch_start_ts[epoch + 1] = time.time() 224 | 225 | return True, epoch, optimizers, losses, model_parameters 226 | 227 | def _load_best_checkpoint(self): 228 | r""" 229 | If a "load_checkpoint" file has been provided, load this one. 230 | Otherwise use the "save_dir" and load the one with the most advanced epoch 231 | 232 | Returns 233 | ------- 234 | saved_model : dictionary 235 | Checkpoint file contents generated by torch.load, or None 236 | checkpoint_file : string 237 | Filename of the loaded checkpoint, or None 238 | """ 239 | checkpoint_files = [] 240 | # If a "load_checkpoint" file is provided, try to load it 241 | if self._load_checkpoint is not None: 242 | if not os.path.exists(self._load_checkpoint): 243 | self._log().error( 244 | "Cannot load the checkpoint: {}".format(self._load_checkpoint) 245 | ) 246 | return None, None 247 | checkpoint_files.append(self._load_checkpoint) 248 | else: 249 | # Iterate over all check files from the directory by reverse alphabetical order 250 | # This will get the biggest epoch first 251 | checkpoint_files = glob.glob(os.path.join(self._save_dir, "*.check")) 252 | checkpoint_files.sort(reverse=True) 253 | 254 | for checkpoint_file in checkpoint_files: 255 | try: 256 | # Try to load the file 257 | self._log().info( 258 | "Loading model checkpoint file: {}".format(checkpoint_file) 259 | ) 260 | saved_model = torch.load( 261 | checkpoint_file, map_location=self._device, weights_only=False 262 | ) 263 | return saved_model, checkpoint_file 264 | except RuntimeError: 265 | self._log().error("Cannot load file: {}".format(checkpoint_file)) 266 | 267 | return None, None 268 | 269 | def _build_checkpoint_filename(self, epoch): 270 | r""" 271 | Construct the full path for the filename of this checkpoint 272 | """ 273 | dataset_name = self._config["dataset"]["name"] 274 | model_type = self._config["model"]["type"] 275 | model_name = self._config["model"]["name"] 276 | filename = "{}_{}_{}_{:03}.check".format( 277 | model_type, model_name, dataset_name, epoch 278 | ) 279 | c_filename = os.path.join(self._save_dir, filename) 280 | 281 | return c_filename 282 | -------------------------------------------------------------------------------- /tests/test_reading_order.py: 
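One detail of the `BaseModel` checkpointing above worth spelling out: because `_build_checkpoint_filename` zero-pads the epoch, sorting the `*.check` filenames in reverse lexicographic order puts the highest epoch first, which is what `_load_best_checkpoint` relies on. A minimal sketch with made-up model and dataset names:

```python
import os

def build_checkpoint_filename(save_dir, model_type, model_name, dataset_name, epoch):
    # mirrors the "{}_{}_{}_{:03}.check" pattern from _build_checkpoint_filename
    filename = "{}_{}_{}_{:03}.check".format(model_type, model_name, dataset_name, epoch)
    return os.path.join(save_dir, filename)

files = [
    build_checkpoint_filename("/tmp/ckpt", "TableModel04", "rs", "PubTabNet", epoch)
    for epoch in (2, 10, 7)
]
files.sort(reverse=True)
print(files[0])  # /tmp/ckpt/TableModel04_rs_PubTabNet_010.check -> epoch 10 is loaded first
```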
-------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import logging 4 | 5 | import sys 6 | 7 | import pytest 8 | 9 | from typing import List, Dict 10 | import random 11 | 12 | from docling_ibm_models.reading_order.reading_order_rb import PageElement, ReadingOrderPredictor 13 | 14 | from docling_core.types.doc.document import DoclingDocument, DocItem, TextItem, ContentLayer 15 | 16 | # Configure logging 17 | logging.basicConfig( 18 | level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" 19 | ) 20 | 21 | def rank_array(arr): 22 | """Compute ranks, resolving ties by averaging.""" 23 | sorted_indices = sorted(range(len(arr)), key=lambda i: arr[i]) # Sort indices 24 | ranks = [0] * len(arr) # Initialize ranks 25 | 26 | i = 0 27 | while i < len(arr): 28 | start = i 29 | while i + 1 < len(arr) and arr[sorted_indices[i]] == arr[sorted_indices[i + 1]]: 30 | i += 1 # Handle ties 31 | avg_rank = sum(range(start + 1, i + 2)) / (i - start + 1) # Average rank for ties 32 | for j in range(start, i + 1): 33 | ranks[sorted_indices[j]] = avg_rank 34 | i += 1 35 | return ranks 36 | 37 | def spearman_rank_correlation(arr1, arr2): 38 | assert len(arr1) == len(arr2), "Arrays must have the same length" 39 | 40 | # Compute ranks 41 | rank1 = rank_array(arr1) 42 | rank2 = rank_array(arr2) 43 | 44 | # Compute rank differences and apply formula 45 | d = [rank1[i] - rank2[i] for i in range(len(arr1))] 46 | d_squared_sum = sum(d_i ** 2 for d_i in d) 47 | 48 | n = len(arr1) 49 | if n > 1: 50 | rho = 1 - (6 * d_squared_sum) / (n * (n**2 - 1)) 51 | else: 52 | rho = 0 53 | return rho 54 | 55 | def test_readingorder(): 56 | if sys.version_info >= (3, 14): 57 | pytest.skip("Pyarrow is not yet available for Python 3.14, hence we cannot load the dataset.") 58 | 59 | from datasets import load_dataset 60 | 61 | 62 | ro_scores, caption_scores, footnote_scores = [], [], [] 63 | 64 | # Init the reading-order model 65 | romodel = ReadingOrderPredictor() 66 | 67 | ds = load_dataset("ds4sd/docling-dpbench") 68 | for row in ds["test"]: 69 | true_doc = DoclingDocument.model_validate_json(row["GroundTruthDocument"]) 70 | 71 | true_elements: List[PageElement] = [] 72 | pred_elements: List[PageElement] = [] 73 | 74 | to_ref: Dict[int, str] = {} 75 | from_ref: Dict[str, int] = {} 76 | 77 | for item, level in true_doc.iterate_items(included_content_layers={ContentLayer.BODY, ContentLayer.FURNITURE}): 78 | if isinstance(item, DocItem): 79 | for prov in item.prov: 80 | 81 | page_height = true_doc.pages[prov.page_no].size.height 82 | bbox = prov.bbox.to_bottom_left_origin(page_height=page_height) 83 | 84 | text = "" 85 | if isinstance(item, TextItem): 86 | text = item.text 87 | 88 | true_elements.append( 89 | PageElement( 90 | cid=len(true_elements), 91 | ref=item.get_ref(), 92 | text = text, 93 | page_no=prov.page_no, 94 | page_size = true_doc.pages[prov.page_no].size, 95 | label=item.label, 96 | l = bbox.l, 97 | r = bbox.r, 98 | b = bbox.b, 99 | t = bbox.t, 100 | coord_origin = bbox.coord_origin 101 | ) 102 | ) 103 | 104 | to_ref[true_elements[-1].cid] = item.get_ref().cref 105 | from_ref[item.get_ref().cref] = true_elements[-1].cid 106 | 107 | rand_elements = copy.deepcopy(true_elements) 108 | random.shuffle(rand_elements) 109 | 110 | """ 111 | print(f"reading {os.path.basename(filename)}") 112 | for true_elem, rand_elem in zip(true_elements, rand_elements): 113 | print("true: ", str(true_elem), ", rand: ", str(rand_elem)) 114 | """ 115 | 116 | pred_elements = 
romodel.predict_reading_order(page_elements=rand_elements) 117 | #pred_elements = romodel.predict_page(page_elements=rand_elements) 118 | 119 | assert len(pred_elements)==len(true_elements), f"{len(pred_elements)}!={len(true_elements)}" 120 | 121 | true_cids, pred_cids = [], [] 122 | for true_elem, pred_elem, rand_elem in zip(true_elements, 123 | pred_elements, 124 | rand_elements): 125 | true_cids.append(true_elem.cid) 126 | pred_cids.append(pred_elem.cid) 127 | 128 | score = spearman_rank_correlation(true_cids, pred_cids) 129 | ro_scores.append(score) 130 | 131 | filename = row["document_id"] 132 | 133 | if score == 0: 134 | continue 135 | # Identify special cases ... 136 | if filename in ["doc_906d54a21ef3c7bfac03f4bb613b0c79ef32fdf81b362450c79e98a96f88708a_page_000001.png", # 0.720588 137 | "doc_2cd17a32ee330a239e19c915738df0c27e8ec3635a60a7e16e2a0cf3868d4af3_page_000001.png", # 0.64920 138 | "doc_bcb3dafc35b5e7476fd1b9cd6eccf5eeef936cd5b13ad846a4943f1e7797f4e9_page_000001.png", # 0.65 139 | "doc_a0edae1fa147c7bb78ebc493743a68ba4372b5ead31f2a2b146c35119462379e_page_000001.png", # 0.82857 140 | "doc_94ba5468fcb6277721947697048846dc0d0551296be3b45f5918ab857d21dcc7_page_000001.png", # 0.857142 141 | # "doc_cbb4a13ffd01d9f777fdb939451d6a21cea1b869ee50d79581451e3601df9ec8_page_000001.png", 142 | 143 | "doc_e2b604a3fb1541b82b6af8caca05682dff0c7735e0a3a4fa7c6a68246fb60e57_page_000001.png", # 0.657142 144 | "doc_827d21de372a2c26237ee1db526460851ae71c1867761776583535f532432e32_page_000001.png", # 0.8922077 145 | "doc_b862cd0d6f06c06ee5ab7729ed4e8ce58e6964eb0f1ab98b3865b57a4808216f_page_000001.png"]: # 0.642857 146 | # print(f"{os.path.basename(filename)}: {score}") 147 | assert score>=0.60, f"reading-order score={score}>0.60" 148 | else: 149 | assert score>=0.90, f"reading-order score={score}>0.90 for {filename}" 150 | 151 | 152 | true_to_captions: Dict[int, List[int]] = {} 153 | true_to_footnotes: Dict[int, List[int]] = {} 154 | 155 | total_caption_links = 0 156 | total_footnote_links = 0 157 | 158 | for table in true_doc.tables: 159 | table_cid = from_ref[table.get_ref().cref] 160 | 161 | true_to_captions[table_cid] = [] 162 | for caption in table.captions: 163 | caption_cid = from_ref[caption.get_ref().cref] 164 | true_to_captions[table_cid].append(caption_cid) 165 | 166 | total_caption_links += 1 167 | 168 | true_to_footnotes[table_cid] = [] 169 | for footnote in table.footnotes: 170 | footnote_cid = from_ref[footnote.get_ref().cref] 171 | true_to_footnotes[table_cid].append(footnote_cid) 172 | 173 | total_footnote_links += 1 174 | 175 | for picture in true_doc.pictures: 176 | picture_cid = from_ref[picture.get_ref().cref] 177 | 178 | true_to_captions[picture_cid] = [] 179 | for caption in picture.captions: 180 | caption_cid = from_ref[caption.get_ref().cref] 181 | true_to_captions[picture_cid].append(caption_cid) 182 | 183 | total_caption_links += 1 184 | 185 | true_to_footnotes[picture_cid] = [] 186 | for footnote in picture.footnotes: 187 | footnote_cid = from_ref[footnote.get_ref().cref] 188 | true_to_footnotes[picture_cid].append(footnote_cid) 189 | 190 | total_footnote_links += 1 191 | 192 | if total_caption_links>0: 193 | #print(" *********** ") 194 | 195 | pred_to_captions = romodel.predict_to_captions(sorted_elements=pred_elements) 196 | 197 | """ 198 | for key,val in pred_to_captions.items(): 199 | print(f"pred {key}: {val}") 200 | """ 201 | 202 | score, total = 0.0, 0.0 203 | for key,val in true_to_captions.items(): 204 | # print(f"true {key}: {val}") 205 | 206 | total += 1.0 207 
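                # count it as a hit only when the predicted caption list matches the
                # ground truth exactly (same cids, same order)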
| if key in pred_to_captions and pred_to_captions[key]==val: 208 | score += 1.0 209 | 210 | # print(f"to_captions: {score/total}") 211 | caption_scores.append(score/total) 212 | 213 | if total_footnote_links>0: 214 | # print(" *********** ") 215 | 216 | pred_to_footnotes = romodel.predict_to_footnotes(sorted_elements=pred_elements) 217 | 218 | """ 219 | for key,val in pred_to_footnotes.items(): 220 | print(f"pred {key}: {val}") 221 | """ 222 | 223 | score, total = 0.0, 0.0 224 | for key,val in true_to_footnotes.items(): 225 | # print(f"true {key}: {val}") 226 | 227 | total += 1.0 228 | if key in pred_to_footnotes and pred_to_footnotes[key]==val: 229 | score += 1.0 230 | 231 | # print(f"to_footnotes: {score/total}") 232 | footnote_scores.append(score/total) 233 | 234 | pred_merges = romodel.predict_merges(sorted_elements=pred_elements) 235 | # print("merges: ", pred_merges) 236 | 237 | 238 | mean_ro_score = sum(ro_scores)/len(ro_scores) 239 | mean_cp_score = sum(caption_scores)/len(caption_scores) 240 | mean_ft_score = sum(footnote_scores)/len(footnote_scores) 241 | 242 | assert mean_ro_score>0.95, "mean_ro_score>0.95" 243 | assert mean_cp_score>0.85, "mean_cp_score>0.85" 244 | assert mean_ft_score>0.90, "mean_ft_score>0.90" 245 | 246 | print("\n score(reading): ", mean_ro_score) 247 | print(" score(caption): ", mean_cp_score) 248 | print("score(footnotes): ", mean_ft_score) 249 | 250 | 251 | """ 252 | def test_readingorder_multipage(): 253 | 254 | filename = Path("") 255 | 256 | # Init the reading-order model 257 | romodel = ReadingOrderPredictor() 258 | 259 | true_elements: List[PageElement] = [] 260 | pred_elements: List[PageElement] = [] 261 | 262 | with open(filename, "r") as fr: 263 | data = json.load(fr) 264 | true_elements = [PageElement.model_validate(item) for item in data] 265 | 266 | pred_elements = romodel.predict_reading_order(page_elements=true_elements) 267 | for true_elem, pred_elem in zip(true_elements, pred_elements): 268 | print("true: ", str(true_elem), ", pred: ", str(pred_elem)) 269 | """ -------------------------------------------------------------------------------- /docling_ibm_models/tableformer/utils/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp. 
2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from PIL import Image 10 | from torchvision.models.resnet import BasicBlock, conv1x1 11 | from torchvision.ops.boxes import box_area 12 | 13 | 14 | def remove_padding(seq): 15 | r""" 16 | Remove the trailing zeros from the provided input 17 | 18 | Parameters 19 | ---------- 20 | seq : list of int 21 | Predicted sequence 22 | 23 | Returns 24 | ------- 25 | (list of int, int) 26 | The part of the input before the zero padding, and the number of trailing zeros removed 27 | 28 | """ 29 | pad_len = 0 30 | for x in reversed(seq): 31 | if x != 0: 32 | break 33 | pad_len += 1 34 | if pad_len == 0: 35 | return seq, 0 36 | 37 | un_padded = seq[:-pad_len] 38 | return un_padded, pad_len 39 | 40 | 41 | def probabilities_to_predictions(probabilities): 42 | r""" 43 | Convert probabilities to predictions 44 | 45 | Parameters 46 | ---------- 47 | probabilities : Tensor[batch_size, vocab_size, seq_len] 48 | All log probabilities coming out at the last stage of the decoder 49 | 50 | Returns 51 | ------- 52 | predictions : tensor [batch_size, output_sequence_length] 53 | The predicted tags 54 | 55 | """ 56 | # max_idx: [batch_size, seq_len] 57 | max_idx = torch.argmax(probabilities, dim=1) 58 | return max_idx 59 | 60 | 61 | def print_target_predict(target, predictions, filenames=None, batch_idx=0): 62 | r""" 63 | For the Tags, print the target and predicted tensors for the specified batch index 64 | 65 | We expect to have the batch size as the first dimension. 66 | Only the specified batch is extracted and the remaining dimensions are flattened. 67 | The results are printed as 2 lists with the target on top and the predictions underneath 68 | 69 | Parameters 70 | ---------- 71 | target : tensor [batch_size, output_sequence_length] 72 | The ground truth tags 73 | 74 | predictions : tensor [batch_size, output_sequence_length] 75 | The predicted tags 76 | 77 | filenames : list of string 78 | The actual filenames that provide the data 79 | 80 | batch_idx : int 81 | Which index in the batch dimension will be printed 82 | """ 83 | target_flat = target[batch_idx].flatten() 84 | predictions_flat = predictions[batch_idx].flatten() 85 | target_label = "target" 86 | predict_label = "predict" 87 | if filenames is not None: 88 | target_label = filenames[batch_idx] 89 | label_len = max(len(target_label), len(predict_label)) 90 | print("{}: {}".format(target_label.ljust(label_len, " "), target_flat.tolist())) 91 | print( 92 | "{}: {}".format(predict_label.ljust(label_len, " "), predictions_flat.tolist()) 93 | ) 94 | 95 | 96 | def load_image(full_fn): 97 | r""" 98 | Load an image from the disk as a numpy array 99 | 100 | Parameters 101 | ---------- 102 | full_fn : string 103 | The full path filename of the image 104 | 105 | Returns 106 | ------- 107 | img : numpy array: (channels, height, width) 108 | The loaded image as a numpy array 109 | """ 110 | with Image.open(full_fn) as f: 111 | img = np.asarray(f) # (height, width, channels) 112 | img = img.transpose(2, 0, 1) # (channels, height, width) 113 | return img 114 | 115 | 116 | def resnet_block(stride=1): 117 | layers = [] 118 | downsample = nn.Sequential( 119 | conv1x1(256, 512, stride), 120 | nn.BatchNorm2d(512), 121 | ) 122 | layers.append(BasicBlock(256, 512, stride, downsample)) 123 | layers.append(BasicBlock(512, 512, 1)) 124 | return nn.Sequential(*layers) 125 | 126 | 127 | def repackage_hidden(h): 128 | r""" 129 | Wraps hidden
states in new Tensors, to detach them from their history. 130 | """ 131 | if isinstance(h, torch.Tensor): 132 | return h.detach() 133 | else: 134 | return tuple(repackage_hidden(v) for v in h) 135 | 136 | 137 | def accuracy(scores, targets, k): 138 | """ 139 | Computes top-k accuracy, from predicted and true labels. 140 | 141 | :param scores: scores from the model 142 | :param targets: true labels 143 | :param k: k in top-k accuracy 144 | :return: top-k accuracy 145 | """ 146 | 147 | batch_size = targets.size(0) 148 | _, ind = scores.topk(k, 1, True, True) 149 | correct = ind.eq(targets.view(-1, 1).expand_as(ind)) 150 | correct_total = correct.view(-1).float().sum() # 0D tensor 151 | return correct_total.item() * (100.0 / batch_size) 152 | 153 | 154 | def clip_gradient(optimizer, grad_clip): 155 | """ 156 | Clips gradients computed during backpropagation to avoid explosion of gradients. 157 | 158 | :param optimizer: optimizer with the gradients to be clipped 159 | :param grad_clip: clip value 160 | """ 161 | for group in optimizer.param_groups: 162 | for param in group["params"]: 163 | if param.grad is not None: 164 | param.grad.data.clamp_(-grad_clip, grad_clip) 165 | 166 | 167 | class AverageMeter(object): 168 | """ 169 | Keeps track of most recent, average, sum, and count of a metric. 170 | """ 171 | 172 | def __init__(self): 173 | self.reset() 174 | 175 | def reset(self): 176 | self.val = 0 177 | self.avg = 0 178 | self.sum = 0 179 | self.count = 0 180 | 181 | def update(self, val, n=1): 182 | self.val = val 183 | self.sum += val * n 184 | self.count += n 185 | self.avg = self.sum / self.count 186 | 187 | 188 | @torch.no_grad() 189 | def bip_accuracy(output, target, topk=(1,)): 190 | """Computes the precision@k for the specified values of k""" 191 | if target.numel() == 0: 192 | return [torch.zeros([], device=output.device)] 193 | maxk = max(topk) 194 | batch_size = target.size(0) 195 | 196 | _, pred = output.topk(maxk, 1, True, True) 197 | pred = pred.t() 198 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 199 | 200 | res = [] 201 | for k in topk: 202 | correct_k = correct[:k].view(-1).float().sum(0) 203 | res.append(correct_k.mul_(100.0 / batch_size)) 204 | return res 205 | 206 | 207 | def box_cxcywh_to_xyxy(x): 208 | x_c, y_c, w, h = x.unbind(-1) 209 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] 210 | return torch.stack(b, dim=-1) 211 | 212 | 213 | def box_xyxy_to_cxcywh(x): 214 | x0, y0, x1, y1 = x.unbind(-1) 215 | b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)] 216 | return torch.stack(b, dim=-1) 217 | 218 | 219 | # modified from torchvision to also return the union 220 | def box_iou(boxes1, boxes2): 221 | area1 = box_area(boxes1) 222 | area2 = box_area(boxes2) 223 | 224 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 225 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 226 | 227 | wh = (rb - lt).clamp(min=0) # [N,M,2] 228 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 229 | 230 | union = area1[:, None] + area2 - inter 231 | 232 | iou = inter / union 233 | return iou, union 234 | 235 | 236 | def generalized_box_iou(boxes1, boxes2): 237 | """ 238 | Generalized IoU from https://giou.stanford.edu/ 239 | 240 | The boxes should be in [x0, y0, x1, y1] format 241 | 242 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 243 | and M = len(boxes2) 244 | """ 245 | # degenerate boxes gives inf / nan results 246 | # so do an early check 247 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 248 | assert (boxes2[:, 2:] 
>= boxes2[:, :2]).all() 249 | iou, union = box_iou(boxes1, boxes2) 250 | 251 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 252 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 253 | 254 | wh = (rb - lt).clamp(min=0) # [N,M,2] 255 | area = wh[:, :, 0] * wh[:, :, 1] 256 | 257 | return iou - (area - union) / area 258 | 259 | 260 | class MLP(nn.Module): 261 | """Very simple multi-layer perceptron (also called FFN)""" 262 | 263 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 264 | super().__init__() 265 | self.num_layers = num_layers 266 | h = [hidden_dim] * (num_layers - 1) 267 | self.layers = nn.ModuleList( 268 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 269 | ) 270 | 271 | def forward(self, x): 272 | for i, layer in enumerate(self.layers): 273 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 274 | return x 275 | 276 | 277 | def generate_square_subsequent_mask(sz: int, device: str = "cpu") -> torch.Tensor: 278 | """Generate the attention mask for causal decoding""" 279 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 280 | mask = ( 281 | mask.float() 282 | .masked_fill(mask == 0, float("-inf")) 283 | .masked_fill(mask == 1, float(0.0)) 284 | ).to(device=device) 285 | return mask 286 | 287 | 288 | class EarlyStopping: 289 | """Early stops the training if validation loss doesn't improve after a given patience. 290 | Source from: https://github.com/Bjarten/early-stopping-pytorch 291 | """ 292 | 293 | def __init__(self, patience=2, verbose=False, delta=0, trace_func=print): 294 | """ 295 | Args: 296 | patience (int): How long to wait after last time validation loss improved. 297 | Default: 7 298 | verbose (bool): If True, prints a message for each validation loss improvement. 299 | Default: False 300 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 301 | Default: 0 302 | path (str): Path for the checkpoint to be saved to. 303 | Default: 'checkpoint.pt' 304 | trace_func (function): trace print function. 305 | Default: print 306 | """ 307 | self._patience = patience 308 | self._verbose = verbose 309 | self._counter = 0 310 | self._best_score = None 311 | self._early_stop = False 312 | self._val_loss_min = np.Inf 313 | self._delta = delta 314 | self._trace_func = trace_func 315 | 316 | def __call__(self, val_loss): 317 | score = -val_loss 318 | save_checkpoint = True 319 | if self._best_score is None: 320 | self._best_score = score 321 | save_checkpoint = True 322 | if self._verbose: 323 | verb = f"Validation loss decreased ({self._val_loss_min:.6f} --> {val_loss:.6f})." 324 | self._trace_func(verb) 325 | self._val_loss_min = val_loss 326 | elif score < self._best_score + self._delta: 327 | self._counter += 1 328 | self._trace_func( 329 | f"EarlyStopping counter: {self._counter} out of {self._patience}" 330 | ) 331 | if self._counter >= self._patience: 332 | self._early_stop = True 333 | save_checkpoint = False 334 | else: 335 | self._best_score = score 336 | save_checkpoint = True 337 | self._counter = 0 338 | if self._verbose: 339 | verb = f"Validation loss decreased ({self._val_loss_min:.6f} --> {val_loss:.6f})." 
340 | self._trace_func(verb) 341 | self._val_loss_min = val_loss 342 | return save_checkpoint 343 | 344 | 345 | def print_dict(m: dict): 346 | r""" 347 | Print dict elements in separate lines sorted by keys 348 | """ 349 | if len(m) == 0: 350 | return 351 | 352 | # Check if the key is a stringified integer 353 | first_key = next(iter(m)) 354 | is_numeric = isinstance(first_key, str) and first_key.isnumeric() 355 | if is_numeric: 356 | keys = sorted([int(k) for k in m.keys()]) 357 | else: 358 | keys = sorted([k for k in m.keys()]) 359 | 360 | for k in keys: 361 | if is_numeric: 362 | v = m[str(k)] 363 | else: 364 | v = m[k] 365 | print("{}: {}".format(k, v)) 366 | 367 | 368 | def print_list(lst: list): 369 | r""" 370 | Print list elements in separate lines 371 | """ 372 | for i, elm in enumerate(lst): 373 | if isinstance(elm, list): 374 | print("{}: ({}) - {}".format(i, len(elm), elm)) 375 | else: 376 | print("{}: {}".format(i, elm)) 377 | -------------------------------------------------------------------------------- /docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright IBM Corp. 2024 - 2024 3 | # SPDX-License-Identifier: MIT 4 | # 5 | import logging 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | import docling_ibm_models.tableformer.settings as s 11 | from docling_ibm_models.tableformer.models.common.base_model import BaseModel 12 | from docling_ibm_models.tableformer.models.table04_rs.bbox_decoder_rs import BBoxDecoder 13 | from docling_ibm_models.tableformer.models.table04_rs.encoder04_rs import Encoder04 14 | from docling_ibm_models.tableformer.models.table04_rs.transformer_rs import ( 15 | Tag_Transformer, 16 | ) 17 | from docling_ibm_models.tableformer.utils.app_profiler import AggProfiler 18 | 19 | LOG_LEVEL = logging.WARN 20 | # LOG_LEVEL = logging.INFO 21 | # LOG_LEVEL = logging.DEBUG 22 | 23 | 24 | class TableModel04_rs(BaseModel, nn.Module): 25 | r""" 26 | TableNet04Model encoder, dual-decoder model with OTSL+ support 27 | """ 28 | 29 | def __init__(self, config, init_data, device): 30 | super(TableModel04_rs, self).__init__(config, init_data, device) 31 | 32 | self._prof = config["predict"].get("profiling", False) 33 | self._device = device 34 | # Extract the word_map from the init_data 35 | word_map = init_data["word_map"] 36 | 37 | # Encoder 38 | self._enc_image_size = config["model"]["enc_image_size"] 39 | self._encoder_dim = config["model"]["hidden_dim"] 40 | self._encoder = Encoder04(self._enc_image_size, self._encoder_dim).to(device) 41 | 42 | tag_vocab_size = len(word_map["word_map_tag"]) 43 | 44 | td_encode = [] 45 | for t in ["ecel", "fcel", "ched", "rhed", "srow"]: 46 | if t in word_map["word_map_tag"]: 47 | td_encode.append(word_map["word_map_tag"][t]) 48 | self._log().debug("td_encode length: {}".format(len(td_encode))) 49 | self._log().debug("td_encode: {}".format(td_encode)) 50 | 51 | self._tag_attention_dim = config["model"]["tag_attention_dim"] 52 | self._tag_embed_dim = config["model"]["tag_embed_dim"] 53 | self._tag_decoder_dim = config["model"]["tag_decoder_dim"] 54 | self._decoder_dim = config["model"]["hidden_dim"] 55 | self._dropout = config["model"]["dropout"] 56 | 57 | self._bbox = config["train"]["bbox"] 58 | self._bbox_attention_dim = config["model"]["bbox_attention_dim"] 59 | self._bbox_embed_dim = config["model"]["bbox_embed_dim"] 60 | self._bbox_decoder_dim = config["model"]["hidden_dim"] 61 | 62 | self._enc_layers = 
config["model"]["enc_layers"] 63 | self._dec_layers = config["model"]["dec_layers"] 64 | self._n_heads = config["model"]["nheads"] 65 | 66 | self._num_classes = config["model"]["bbox_classes"] 67 | self._enc_image_size = config["model"]["enc_image_size"] 68 | 69 | self._max_pred_len = config["predict"]["max_steps"] 70 | 71 | self._tag_transformer = Tag_Transformer( 72 | device, 73 | tag_vocab_size, 74 | td_encode, 75 | self._decoder_dim, 76 | self._enc_layers, 77 | self._dec_layers, 78 | self._enc_image_size, 79 | n_heads=self._n_heads, 80 | ).to(device) 81 | 82 | self._bbox_decoder = BBoxDecoder( 83 | device, 84 | self._bbox_attention_dim, 85 | self._bbox_embed_dim, 86 | self._tag_decoder_dim, 87 | self._bbox_decoder_dim, 88 | self._num_classes, 89 | self._encoder_dim, 90 | self._dropout, 91 | ).to(device) 92 | 93 | def _log(self): 94 | # Setup a custom logger 95 | return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL) 96 | 97 | def mergebboxes(self, bbox1, bbox2): 98 | new_w = (bbox2[0] + bbox2[2] / 2) - (bbox1[0] - bbox1[2] / 2) 99 | new_h = (bbox2[1] + bbox2[3] / 2) - (bbox1[1] - bbox1[3] / 2) 100 | 101 | new_left = bbox1[0] - bbox1[2] / 2 102 | new_top = min((bbox2[1] - bbox2[3] / 2), (bbox1[1] - bbox1[3] / 2)) 103 | 104 | new_cx = new_left + new_w / 2 105 | new_cy = new_top + new_h / 2 106 | 107 | bboxm = torch.tensor([new_cx, new_cy, new_w, new_h]) 108 | return bboxm 109 | 110 | def predict(self, imgs, max_steps, k, return_attention=False): 111 | r""" 112 | Inference. 113 | The input image must be preprocessed and transformed. 114 | 115 | Parameters 116 | ---------- 117 | img : tensor FloatTensor - torch.Size([1, 3, 448, 448]) 118 | Input image for the inference 119 | 120 | Returns 121 | ------- 122 | seq : list 123 | Predictions for the tags as indices over the word_map 124 | outputs_class : tensor(x, 3) 125 | Classes of predicted bboxes. x is the number of bboxes. There are 3 bbox classes 126 | 127 | outputs_coord : tensor(x, 4) 128 | Coords of predicted bboxes. x is the number of bboxes. 
Each bbox is in [cxcywh] format 129 | """ 130 | AggProfiler().begin("predict_total", self._prof) 131 | 132 | # Invoke encoder 133 | self._tag_transformer.eval() 134 | enc_out = self._encoder(imgs) 135 | AggProfiler().end("model_encoder", self._prof) 136 | 137 | word_map = self._init_data["word_map"]["word_map_tag"] 138 | n_heads = self._tag_transformer._n_heads 139 | # [1, 28, 28, 512] 140 | encoder_out = self._tag_transformer._input_filter( 141 | enc_out.permute(0, 3, 1, 2) 142 | ).permute(0, 2, 3, 1) 143 | 144 | batch_size = encoder_out.size(0) 145 | encoder_dim = encoder_out.size(-1) 146 | enc_inputs = encoder_out.view(batch_size, -1, encoder_dim).to(self._device) 147 | enc_inputs = enc_inputs.permute(1, 0, 2) 148 | positions = enc_inputs.shape[0] 149 | 150 | encoder_mask = torch.zeros( 151 | (batch_size * n_heads, positions, positions), device=self._device 152 | ) == torch.ones( 153 | (batch_size * n_heads, positions, positions), device=self._device 154 | ) 155 | 156 | # Invoking tag transformer encoder before the loop to save time 157 | AggProfiler().begin("model_tag_transformer_encoder", self._prof) 158 | encoder_out = self._tag_transformer._encoder(enc_inputs, mask=encoder_mask) 159 | AggProfiler().end("model_tag_transformer_encoder", self._prof) 160 | 161 | decoded_tags = ( 162 | torch.LongTensor([word_map["<start>"]]).to(self._device).unsqueeze(1) 163 | ) 164 | output_tags = [] 165 | cache = None 166 | tag_H_buf = [] 167 | 168 | skip_next_tag = True 169 | prev_tag_ucel = False 170 | line_num = 0 171 | 172 | # Populate bboxes_to_merge, indexes of first lcel, and last cell in a span 173 | first_lcel = True 174 | bboxes_to_merge = {} 175 | cur_bbox_ind = -1 176 | bbox_ind = 0 177 | 178 | # i = 0 179 | while len(output_tags) < self._max_pred_len: 180 | decoded_embedding = self._tag_transformer._embedding(decoded_tags) 181 | decoded_embedding = self._tag_transformer._positional_encoding( 182 | decoded_embedding 183 | ) 184 | AggProfiler().begin("model_tag_transformer_decoder", self._prof) 185 | decoded, cache = self._tag_transformer._decoder( 186 | decoded_embedding, 187 | encoder_out, 188 | cache, 189 | memory_key_padding_mask=encoder_mask, 190 | ) 191 | AggProfiler().end("model_tag_transformer_decoder", self._prof) 192 | # Grab last feature to produce token 193 | AggProfiler().begin("model_tag_transformer_fc", self._prof) 194 | logits = self._tag_transformer._fc(decoded[-1, :, :]) # 1, vocab_size 195 | AggProfiler().end("model_tag_transformer_fc", self._prof) 196 | new_tag = logits.argmax(1).item() 197 | 198 | # STRUCTURE ERROR CORRECTION 199 | # Correction for first line xcel... 200 | if line_num == 0: 201 | if new_tag == word_map["xcel"]: 202 | new_tag = word_map["lcel"] 203 | 204 | # Correction for ucel, lcel sequence...
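# Note on the correction below (rationale inferred from OTSL tag semantics): "lcel"
# continues a horizontal span from the cell on its left, "ucel" continues a vertical
# span from the cell above, and a cell continuing a span in both directions is
# written as "xcel". An "lcel" predicted immediately after a "ucel" is therefore
# structurally inconsistent, which is presumably why it is downgraded to a plain "fcel".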
205 | if prev_tag_ucel: 206 | if new_tag == word_map["lcel"]: 207 | new_tag = word_map["fcel"] 208 | 209 | # End of generation 210 | if new_tag == word_map["<end>"]: 211 | output_tags.append(new_tag) 212 | decoded_tags = torch.cat( 213 | [ 214 | decoded_tags, 215 | torch.LongTensor([new_tag]).unsqueeze(1).to(self._device), 216 | ], 217 | dim=0, 218 | ) # current_output_len, 1 219 | break 220 | output_tags.append(new_tag) 221 | 222 | # BBOX PREDICTION 223 | 224 | # MAKE SURE TO SYNC NUMBER OF CELLS WITH NUMBER OF BBOXes 225 | if not skip_next_tag: 226 | if new_tag in [ 227 | word_map["fcel"], 228 | word_map["ecel"], 229 | word_map["ched"], 230 | word_map["rhed"], 231 | word_map["srow"], 232 | word_map["nl"], 233 | word_map["ucel"], 234 | ]: 235 | # GENERATE BBOX HERE TOO (All other cases)... 236 | tag_H_buf.append(decoded[-1, :, :]) 237 | if not first_lcel: 238 | # Mark end index for horizontal cell bbox merge 239 | bboxes_to_merge[cur_bbox_ind] = bbox_ind 240 | bbox_ind += 1 241 | 242 | # Treat horizontal span bboxes... 243 | if new_tag != word_map["lcel"]: 244 | first_lcel = True 245 | else: 246 | if first_lcel: 247 | # GENERATE BBOX HERE (Beginning of horizontal span)... 248 | tag_H_buf.append(decoded[-1, :, :]) 249 | first_lcel = False 250 | # Mark start index for cell bbox merge 251 | cur_bbox_ind = bbox_ind 252 | bboxes_to_merge[cur_bbox_ind] = -1 253 | bbox_ind += 1 254 | 255 | if new_tag in [word_map["nl"], word_map["ucel"], word_map["xcel"]]: 256 | skip_next_tag = True 257 | else: 258 | skip_next_tag = False 259 | 260 | # Register ucel in sequence... 261 | if new_tag == word_map["ucel"]: 262 | prev_tag_ucel = True 263 | else: 264 | prev_tag_ucel = False 265 | 266 | decoded_tags = torch.cat( 267 | [ 268 | decoded_tags, 269 | torch.LongTensor([new_tag]).unsqueeze(1).to(self._device), 270 | ], 271 | dim=0, 272 | ) # current_output_len, 1 273 | seq = decoded_tags.squeeze().tolist() 274 | 275 | if self._bbox: 276 | AggProfiler().begin("model_bbox_decoder", self._prof) 277 | outputs_class, outputs_coord = self._bbox_decoder.inference( 278 | enc_out, tag_H_buf 279 | ) 280 | AggProfiler().end("model_bbox_decoder", self._prof) 281 | else: 282 | outputs_class, outputs_coord = None, None 283 | 284 | outputs_class.to(self._device) 285 | outputs_coord.to(self._device) 286 | 287 | ######################################################################################## 288 | # Merge First and Last predicted BBOX for each span, according to bboxes_to_merge 289 | ######################################################################################## 290 | 291 | outputs_class1 = [] 292 | outputs_coord1 = [] 293 | boxes_to_skip = [] 294 | 295 | for box_ind in range(len(outputs_coord)): 296 | box1 = outputs_coord[box_ind].to(self._device) 297 | cls1 = outputs_class[box_ind].to(self._device) 298 | if box_ind in bboxes_to_merge: 299 | box2 = outputs_coord[bboxes_to_merge[box_ind]].to(self._device) 300 | boxes_to_skip.append(bboxes_to_merge[box_ind]) 301 | boxm = self.mergebboxes(box1, box2).to(self._device) 302 | outputs_coord1.append(boxm) 303 | outputs_class1.append(cls1) 304 | else: 305 | if box_ind not in boxes_to_skip: 306 | outputs_coord1.append(box1) 307 | outputs_class1.append(cls1) 308 | 309 | if len(outputs_coord1) > 0: 310 | outputs_coord1 = torch.stack(outputs_coord1) 311 | else: 312 | outputs_coord1 = torch.empty(0) 313 | if len(outputs_class1) > 0: 314 | outputs_class1 = torch.stack(outputs_class1) 315 | else: 316 | outputs_class1 = torch.empty(0) 317 | 318 | outputs_class =
outputs_class1 319 | outputs_coord = outputs_coord1 320 | 321 | # Do the rest of the steps... 322 | AggProfiler().end("predict_total", self._prof) 323 | num_tab_cells = seq.count(4) + seq.count(5) 324 | num_rows = seq.count(9) 325 | self._log().info( 326 | "OTSL predicted table cells#: {}; rows#: {}".format(num_tab_cells, num_rows) 327 | ) 328 | return seq, outputs_class, outputs_coord 329 | --------------------------------------------------------------------------------
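As a quick reference for the span handling above: predict() only buffers the decoder state of the first and last cell of each horizontal span, and mergebboxes then stretches a single [cx, cy, w, h] box across the whole span. The following minimal sketch mirrors that arithmetic outside the model class; the helper name merge_span_bboxes and the numeric boxes are illustrative only and not part of the repository.

import torch

def merge_span_bboxes(bbox1: torch.Tensor, bbox2: torch.Tensor) -> torch.Tensor:
    # Same arithmetic as TableModel04_rs.mergebboxes, for boxes in [cx, cy, w, h] format:
    # the merged width runs from the left edge of the first box to the right edge of the
    # last box, and the merged height from the first box's top to the last box's bottom.
    new_w = (bbox2[0] + bbox2[2] / 2) - (bbox1[0] - bbox1[2] / 2)
    new_h = (bbox2[1] + bbox2[3] / 2) - (bbox1[1] - bbox1[3] / 2)
    new_left = bbox1[0] - bbox1[2] / 2
    new_top = min(bbox2[1] - bbox2[3] / 2, bbox1[1] - bbox1[3] / 2)
    return torch.stack([new_left + new_w / 2, new_top + new_h / 2, new_w, new_h])

# Illustrative input: first and last cell boxes of one horizontally merged table cell.
first_cell = torch.tensor([0.20, 0.10, 0.10, 0.05])
last_cell = torch.tensor([0.40, 0.10, 0.10, 0.05])
print(merge_span_bboxes(first_cell, last_cell))  # roughly [0.30, 0.10, 0.30, 0.05]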