├── .dockerignore ├── .github ├── FUNDING.yml ├── dependabot.yml └── workflows │ ├── push_docker_image.yml │ └── test.yml ├── .gitignore ├── .well-known └── funding-manifest-urls ├── Dockerfile ├── Dockerfile.ollama ├── LICENSE ├── Makefile ├── README.md ├── dev-requirements.txt ├── docker-compose-gpu.yml ├── docker-compose.yml ├── fine_tuning_lightgbm_models.ipynb ├── images ├── vgtexample1.png ├── vgtexample2.png ├── vgtexample3.png └── vgtexample4.png ├── justfile ├── pyproject.toml ├── requirements.txt ├── src ├── adapters │ ├── __init__.py │ ├── infrastructure │ │ ├── __init__.py │ │ ├── format_conversion_service_adapter.py │ │ ├── format_converters │ │ │ ├── __init__.py │ │ │ ├── convert_formula_to_latex.py │ │ │ └── convert_table_to_html.py │ │ ├── html_conversion_service_adapter.py │ │ ├── markdown_conversion_service_adapter.py │ │ ├── markup_conversion │ │ │ ├── ExtractedImage.py │ │ │ ├── Link.py │ │ │ ├── OutputFormat.py │ │ │ ├── __init__.py │ │ │ └── pdf_to_markup_service_adapter.py │ │ ├── ocr │ │ │ ├── __init__.py │ │ │ └── languages.py │ │ ├── ocr_service_adapter.py │ │ ├── pdf_analysis_service_adapter.py │ │ ├── text_extraction_adapter.py │ │ ├── toc │ │ │ ├── MergeTwoSegmentsTitles.py │ │ │ ├── PdfSegmentation.py │ │ │ ├── TOCExtractor.py │ │ │ ├── TitleFeatures.py │ │ │ ├── __init__.py │ │ │ ├── data │ │ │ │ ├── TOCItem.py │ │ │ │ └── __init__.py │ │ │ ├── extract_table_of_contents.py │ │ │ └── methods │ │ │ │ ├── __init__.py │ │ │ │ └── two_models_v3_segments_context_2 │ │ │ │ ├── Modes.py │ │ │ │ └── __init__.py │ │ ├── toc_service_adapter.py │ │ ├── translation │ │ │ ├── decode_html_content.py │ │ │ ├── decode_markdown_content.py │ │ │ ├── download_translation_model.py │ │ │ ├── encode_html_content.py │ │ │ ├── encode_markdown_content.py │ │ │ ├── ollama_container_manager.py │ │ │ └── translate_markup_document.py │ │ └── visualization_service_adapter.py │ ├── ml │ │ ├── __init__.py │ │ ├── fast_trainer │ │ │ ├── Paragraph.py │ │ │ ├── 
ParagraphExtractorTrainer.py │ │ │ ├── __init__.py │ │ │ └── model_configuration.py │ │ ├── fast_trainer_adapter.py │ │ ├── pdf_tokens_type_trainer │ │ │ ├── ModelConfiguration.py │ │ │ ├── PdfTrainer.py │ │ │ ├── TokenFeatures.py │ │ │ ├── TokenTypeTrainer.py │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ ├── download_models.py │ │ │ ├── get_paths.py │ │ │ └── tests │ │ │ │ ├── __init__.py │ │ │ │ └── test_trainer.py │ │ ├── vgt │ │ │ ├── __init__.py │ │ │ ├── bros │ │ │ │ ├── __init__.py │ │ │ │ ├── configuration_bros.py │ │ │ │ ├── modeling_bros.py │ │ │ │ ├── tokenization_bros.py │ │ │ │ └── tokenization_bros_fast.py │ │ │ ├── create_word_grid.py │ │ │ ├── ditod │ │ │ │ ├── FeatureMerge.py │ │ │ │ ├── VGT.py │ │ │ │ ├── VGTTrainer.py │ │ │ │ ├── VGTbackbone.py │ │ │ │ ├── VGTbeit.py │ │ │ │ ├── VGTcheckpointer.py │ │ │ │ ├── Wordnn_embedding.py │ │ │ │ ├── __init__.py │ │ │ │ ├── config.py │ │ │ │ ├── dataset_mapper.py │ │ │ │ ├── tokenization_bros.py │ │ │ │ └── utils.py │ │ │ ├── get_json_annotations.py │ │ │ ├── get_model_configuration.py │ │ │ ├── get_most_probable_pdf_segments.py │ │ │ ├── get_reading_orders.py │ │ │ └── model_configuration │ │ │ │ ├── Base-RCNN-FPN.yaml │ │ │ │ ├── doclaynet_VGT_cascade_PTM.yaml │ │ │ │ └── doclaynet_configuration.pickle │ │ └── vgt_model_adapter.py │ ├── storage │ │ ├── __init__.py │ │ └── file_system_repository.py │ └── web │ │ ├── __init__.py │ │ └── fastapi_controllers.py ├── app.py ├── catch_exceptions.py ├── configuration.py ├── domain │ ├── PdfImages.py │ ├── PdfSegment.py │ ├── Prediction.py │ └── SegmentBox.py ├── download_models.py ├── drivers │ ├── __init__.py │ └── web │ │ ├── __init__.py │ │ ├── dependency_injection.py │ │ └── fastapi_app.py ├── ports │ ├── __init__.py │ ├── repositories │ │ ├── __init__.py │ │ └── file_repository.py │ └── services │ │ ├── __init__.py │ │ ├── format_conversion_service.py │ │ ├── html_conversion_service.py │ │ ├── markdown_conversion_service.py │ │ ├── ml_model_service.py │ │ 
├── ocr_service.py │ │ ├── pdf_analysis_service.py │ │ ├── text_extraction_service.py │ │ ├── toc_service.py │ │ └── visualization_service.py ├── tests │ ├── __init__.py │ └── test_end_to_end.py └── use_cases │ ├── __init__.py │ ├── html_conversion │ ├── __init__.py │ └── convert_to_html_use_case.py │ ├── markdown_conversion │ ├── __init__.py │ └── convert_to_markdown_use_case.py │ ├── ocr │ ├── __init__.py │ └── process_ocr_use_case.py │ ├── pdf_analysis │ ├── __init__.py │ └── analyze_pdf_use_case.py │ ├── text_extraction │ ├── __init__.py │ └── extract_text_use_case.py │ ├── toc_extraction │ ├── __init__.py │ └── extract_toc_use_case.py │ └── visualization │ ├── __init__.py │ └── create_visualization_use_case.py ├── start.sh └── test_pdfs ├── blank.pdf ├── chinese.pdf ├── error.pdf ├── formula.pdf ├── image.pdf ├── korean.pdf ├── not_a_pdf.pdf ├── ocr-sample-already-ocred.pdf ├── ocr-sample-english.pdf ├── ocr-sample-french.pdf ├── ocr_pdf.pdf ├── regular.pdf ├── some_empty_pages.pdf ├── table.pdf ├── test.pdf └── toc-test.pdf /.dockerignore: -------------------------------------------------------------------------------- 1 | /venv/ 2 | /.venv/ 3 | .git 4 | /detectron2/ 5 | /images/ -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | custom: ["https://huridocs.org/donate/"] 2 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" 4 | directory: "/" 5 | schedule: 6 | interval: "daily" 7 | open-pull-requests-limit: 5 8 | labels: 9 | - "dependencies" 10 | - package-ecosystem: "github-actions" 11 | directory: "/" 12 | schedule: 13 | interval: "daily" 14 | - package-ecosystem: "docker" 15 | directory: "/" 16 | schedule: 17 | interval: "daily" 
18 | -------------------------------------------------------------------------------- /.github/workflows/push_docker_image.yml: -------------------------------------------------------------------------------- 1 | name: Create and publish Docker image 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | 8 | env: 9 | REGISTRY: ghcr.io 10 | IMAGE_NAME: huridocs/pdf-document-layout-analysis 11 | 12 | jobs: 13 | build-and-push-image: 14 | runs-on: ubuntu-latest 15 | permissions: 16 | contents: read 17 | packages: write 18 | steps: 19 | - name: Checkout repository 20 | uses: actions/checkout@v4 21 | 22 | - name: Install dependencies 23 | run: sudo apt-get install -y just 24 | 25 | - name: Log in to the Container registry 26 | uses: docker/login-action@v3 27 | with: 28 | registry: ${{ env.REGISTRY }} 29 | username: ${{ github.actor }} 30 | password: ${{ secrets.GITHUB_TOKEN }} 31 | 32 | - name: Extract metadata (tags, labels) for Docker 33 | id: meta 34 | uses: docker/metadata-action@v5 35 | with: 36 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 37 | tags: | 38 | type=ref,event=branch 39 | type=ref,event=pr 40 | type=semver,pattern={{version}} 41 | type=semver,pattern={{major}}.{{minor}} 42 | 43 | - name: Create folder models 44 | run: mkdir -p models 45 | 46 | - name: Build and push 47 | uses: docker/build-push-action@v6 48 | with: 49 | context: . 
50 | file: Dockerfile 51 | push: ${{ github.event_name != 'pull_request' }} 52 | tags: ${{ steps.meta.outputs.tags }} 53 | labels: ${{ steps.meta.outputs.labels }} 54 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Test 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | branches: [ main ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | - name: Set up Python 3.11 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: '3.11' 23 | 24 | - name: Install dependencies 25 | run: sudo apt-get update; sudo apt-get install -y pdftohtml qpdf just 26 | 27 | - name: Free up space 28 | run: just free_up_space 29 | 30 | - name: Install venv 31 | run: just install_venv 32 | 33 | - name: Lint with black 34 | run: just check_format 35 | 36 | - name: Start service 37 | run: just start_detached 38 | 39 | - name: Check API ready 40 | uses: emilioschepis/wait-for-endpoint@v1.0.3 41 | with: 42 | url: http://localhost:5060 43 | method: GET 44 | expected-status: 200 45 | timeout: 120000 46 | interval: 500 47 | 48 | - name: Test with unittest 49 | run: just test 50 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | 
parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | .idea/ 161 | /models/ 162 | /word_grids/ 163 | /jsons/ 164 | /model_output/ 165 | /pdf_outputs/ 166 | /detectron2/ 167 | /ocr/ 168 | -------------------------------------------------------------------------------- /.well-known/funding-manifest-urls: -------------------------------------------------------------------------------- 1 | https://huridocs.org/funding/funding.json 2 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.4.0-cuda11.8-cudnn9-runtime 2 | COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ 3 | 4 | RUN apt-get update 5 | RUN apt-get install --fix-missing -y -q --no-install-recommends libgomp1 ffmpeg libsm6 pdftohtml libxext6 git ninja-build g++ qpdf pandoc curl 6 | 7 | 8 | RUN apt-get install -y ocrmypdf 9 | RUN apt-get install -y tesseract-ocr-fra 10 | RUN apt-get install -y tesseract-ocr-spa 11 | RUN apt-get install -y tesseract-ocr-deu 12 | RUN apt-get install -y tesseract-ocr-ara 13 | RUN apt-get install -y tesseract-ocr-mya 14 | RUN apt-get install -y tesseract-ocr-hin 15 | RUN apt-get install -y tesseract-ocr-tam 16 | RUN apt-get install -y tesseract-ocr-tha 17 | RUN apt-get install -y tesseract-ocr-chi-sim 18 | RUN apt-get install -y tesseract-ocr-tur 19 | RUN apt-get install -y tesseract-ocr-ukr 20 | RUN apt-get install -y tesseract-ocr-ell 21 | RUN apt-get install -y tesseract-ocr-rus 22 | RUN apt-get install -y tesseract-ocr-kor 23 | RUN apt-get install -y tesseract-ocr-kor-vert 24 | 25 | 26 | RUN mkdir -p /app/src 27 | RUN mkdir -p /app/models 28 | 29 | RUN addgroup --system python && adduser --system --group python 30 | RUN chown -R python:python /app 31 | USER python 32 | 33 | ENV VIRTUAL_ENV=/app/.venv 34 | RUN python -m venv $VIRTUAL_ENV 35 | ENV PATH="$VIRTUAL_ENV/bin:$PATH" 36 | 37 | COPY requirements.txt requirements.txt 38 | RUN uv pip install --upgrade pip 39 | RUN uv 
pip install -r requirements.txt 40 | 41 | WORKDIR /app 42 | 43 | RUN cd src; git clone https://github.com/facebookresearch/detectron2; 44 | RUN cd src/detectron2; git checkout 70f454304e1a38378200459dd2dbca0f0f4a5ab4; python setup.py build develop 45 | RUN uv pip install pycocotools==2.0.8 46 | 47 | COPY ./start.sh ./start.sh 48 | COPY ./src/. ./src 49 | COPY ./models/. ./models/ 50 | RUN python src/download_models.py 51 | 52 | ENV PYTHONPATH "${PYTHONPATH}:/app/src" 53 | ENV TRANSFORMERS_VERBOSITY=error 54 | ENV TRANSFORMERS_NO_ADVISORY_WARNINGS=1 55 | -------------------------------------------------------------------------------- /Dockerfile.ollama: -------------------------------------------------------------------------------- 1 | FROM ollama/ollama:latest 2 | 3 | RUN apt-get update && apt-get install -y curl && rm -rf /var/lib/apt/lists/* 4 | 5 | ENV OLLAMA_HOST=0.0.0.0:11434 6 | 7 | EXPOSE 11434 8 | 9 | HEALTHCHECK --interval=30s --timeout=10s --start-period=30s --retries=3 \ 10 | CMD curl -f http://localhost:11434/api/tags || exit 1 11 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | pytest==8.2.2 3 | black==24.4.2 4 | pip-upgrader==1.4.15 -------------------------------------------------------------------------------- /docker-compose-gpu.yml: -------------------------------------------------------------------------------- 1 | services: 2 | ollama-gpu: 3 | extends: 4 | file: docker-compose.yml 5 | service: ollama 6 | container_name: ollama-service-gpu 7 | deploy: 8 | resources: 9 | reservations: 10 | devices: 11 | - driver: nvidia 12 | count: 1 13 | capabilities: [ gpu ] 14 | environment: 15 | - NVIDIA_VISIBLE_DEVICES=all 16 | 17 | pdf-document-layout-analysis-gpu: 18 | container_name: pdf-document-layout-analysis-gpu 19 | entrypoint: [ "gunicorn", "-k", "uvicorn.workers.UvicornWorker", 
"--chdir", "./src", "app:app", "--bind", "0.0.0.0:5060", "--timeout", "10000"] 20 | init: true 21 | restart: unless-stopped 22 | build: 23 | context: . 24 | dockerfile: Dockerfile 25 | ports: 26 | - "5060:5060" 27 | deploy: 28 | resources: 29 | reservations: 30 | devices: 31 | - driver: nvidia 32 | count: 1 33 | capabilities: [ gpu ] 34 | environment: 35 | - RESTART_IF_NO_GPU=$RESTART_IF_NO_GPU 36 | - OLLAMA_HOST=http://localhost:11434 37 | 38 | pdf-document-layout-analysis-gpu-translation: 39 | container_name: pdf-document-layout-analysis-gpu-translation 40 | entrypoint: [ "gunicorn", "-k", "uvicorn.workers.UvicornWorker", "--chdir", "./src", "app:app", "--bind", "0.0.0.0:5060", "--timeout", "10000"] 41 | init: true 42 | restart: unless-stopped 43 | build: 44 | context: . 45 | dockerfile: Dockerfile 46 | ports: 47 | - "5060:5060" 48 | depends_on: 49 | ollama-gpu: 50 | condition: service_healthy 51 | deploy: 52 | resources: 53 | reservations: 54 | devices: 55 | - driver: nvidia 56 | count: 1 57 | capabilities: [ gpu ] 58 | environment: 59 | - RESTART_IF_NO_GPU=$RESTART_IF_NO_GPU 60 | - OLLAMA_HOST=http://ollama-gpu:11434 61 | networks: 62 | - pdf-analysis-network 63 | 64 | networks: 65 | pdf-analysis-network: 66 | driver: bridge -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | ollama: 3 | container_name: ollama-service 4 | build: 5 | context: . 
6 | dockerfile: Dockerfile.ollama 7 | restart: unless-stopped 8 | ports: 9 | - "11434:11434" 10 | healthcheck: 11 | test: ["CMD", "curl", "-f", "http://localhost:11434/api/tags"] 12 | interval: 30s 13 | timeout: 10s 14 | retries: 3 15 | start_period: 30s 16 | networks: 17 | - pdf-analysis-network 18 | 19 | pdf-document-layout-analysis: 20 | container_name: pdf-document-layout-analysis 21 | entrypoint: [ "gunicorn", "-k", "uvicorn.workers.UvicornWorker", "--chdir", "./src", "app:app", "--bind", "0.0.0.0:5060", "--timeout", "10000"] 22 | init: true 23 | restart: unless-stopped 24 | build: 25 | context: . 26 | dockerfile: Dockerfile 27 | ports: 28 | - "5060:5060" 29 | environment: 30 | - OLLAMA_HOST=http://localhost:11434 31 | 32 | pdf-document-layout-analysis-translation: 33 | container_name: pdf-document-layout-analysis-translation 34 | entrypoint: [ "gunicorn", "-k", "uvicorn.workers.UvicornWorker", "--chdir", "./src", "app:app", "--bind", "0.0.0.0:5060", "--timeout", "10000"] 35 | init: true 36 | restart: unless-stopped 37 | build: 38 | context: . 
39 | dockerfile: Dockerfile 40 | ports: 41 | - "5060:5060" 42 | depends_on: 43 | ollama: 44 | condition: service_healthy 45 | environment: 46 | - OLLAMA_HOST=http://ollama:11434 47 | networks: 48 | - pdf-analysis-network 49 | 50 | networks: 51 | pdf-analysis-network: 52 | driver: bridge 53 | -------------------------------------------------------------------------------- /images/vgtexample1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/images/vgtexample1.png -------------------------------------------------------------------------------- /images/vgtexample2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/images/vgtexample2.png -------------------------------------------------------------------------------- /images/vgtexample3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/images/vgtexample3.png -------------------------------------------------------------------------------- /images/vgtexample4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/images/vgtexample4.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "pdf-document-layout-analysis" 3 | version = "2025.03.18.03" 4 | description = "This tool is for PDF document layout analysis" 5 | license = { file = "LICENSE" } 6 | authors = [{ name = "HURIDOCS" }] 7 | 
requires-python = ">= 3.10" 8 | dependencies = [ 9 | "fastapi==0.111.1", 10 | "python-multipart==0.0.9", 11 | "uvicorn==0.30.3", 12 | "gunicorn==22.0.0", 13 | "requests==2.32.3", 14 | "torch==2.4.0", 15 | "torchvision==0.19.0", 16 | "timm==1.0.8", 17 | "Pillow==10.4.0", 18 | "pdf-annotate==0.12.0", 19 | "scipy==1.14.0", 20 | "opencv-python==4.10.0.84", 21 | "Shapely==2.0.5", 22 | "transformers==4.40.2", 23 | "huggingface_hub==0.23.5", 24 | "pdf2image==1.17.0", 25 | "lxml==5.2.2", 26 | "lightgbm==4.5.0", 27 | "setuptools==75.4.0", 28 | "roman==4.2", 29 | "hydra-core==1.3.2", 30 | "pypandoc==1.13", 31 | "rapid-latex-ocr==0.0.9", 32 | "struct_eqtable @ git+https://github.com/UniModal4Reasoning/StructEqTable-Deploy.git@fd06078bfa9364849eb39330c075dd63cbed73ff" 33 | ] 34 | 35 | [project.urls] 36 | HURIDOCS = "https://huridocs.org" 37 | GitHub = "https://github.com/huridocs/pdf-document-layout-analysis" 38 | HuggingFace = "https://huggingface.co/HURIDOCS/pdf-document-layout-analysis" 39 | DockerHub = "https://hub.docker.com/r/huridocs/pdf-document-layout-analysis" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi==0.111.1 2 | pydantic==2.11.0 3 | python-multipart==0.0.9 4 | uvicorn==0.30.3 5 | gunicorn==22.0.0 6 | requests==2.32.3 7 | torch==2.4.0 8 | torchvision==0.19.0 9 | Pillow==10.4.0 10 | pdf-annotate==0.12.0 11 | scipy==1.14.0 12 | opencv-python==4.10.0.84 13 | Shapely==2.0.5 14 | transformers==4.40.2 15 | huggingface_hub==0.23.5 16 | pdf2image==1.17.0 17 | lightgbm==4.5.0 18 | setuptools==75.4.0 19 | roman==4.2 20 | hydra-core==1.3.2 21 | pypandoc==1.13 22 | rapid-table==2.0.3 23 | rapidocr==3.2.0 24 | pix2tex==0.1.4 25 | latex2mathml==3.78.0 26 | PyMuPDF==1.25.5 27 | ollama==0.6.0 28 | git+https://github.com/huridocs/pdf-features.git@2025.10.1.1 
# --------------------------------------------------------------------------------
# /src/adapters/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/__init__.py
# --------------------------------------------------------------------------------
# /src/adapters/infrastructure/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/infrastructure/__init__.py
# --------------------------------------------------------------------------------
# /src/adapters/infrastructure/format_conversion_service_adapter.py:
# --------------------------------------------------------------------------------
from domain.PdfImages import PdfImages
from domain.PdfSegment import PdfSegment
from ports.services.format_conversion_service import FormatConversionService
from adapters.infrastructure.format_converters.convert_table_to_html import extract_table_format
from adapters.infrastructure.format_converters.convert_formula_to_latex import extract_formula_format


class FormatConversionServiceAdapter(FormatConversionService):
    """FormatConversionService implemented by the in-repo format converters.

    Both methods delegate to a converter function that rewrites the ``text_content``
    of matching segments in place; nothing is returned.
    """

    def convert_table_to_html(self, pdf_images: PdfImages, segments: list[PdfSegment]) -> None:
        """Delegate to ``extract_table_format`` (TABLE segments -> HTML text)."""
        extract_table_format(pdf_images=pdf_images, predicted_segments=segments)

    def convert_formula_to_latex(self, pdf_images: PdfImages, segments: list[PdfSegment]) -> None:
        """Delegate to ``extract_formula_format`` (FORMULA segments -> ``$$LaTeX$$`` text)."""
        extract_formula_format(pdf_images=pdf_images, predicted_segments=segments)


# --------------------------------------------------------------------------------
# /src/adapters/infrastructure/format_converters/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/infrastructure/format_converters/__init__.py
# --------------------------------------------------------------------------------
# /src/adapters/infrastructure/format_converters/convert_formula_to_latex.py:
# --------------------------------------------------------------------------------
from PIL.Image import Image
from pix2tex.cli import LatexOCR
from domain.PdfImages import PdfImages
from domain.PdfSegment import PdfSegment
from pdf_token_type_labels import TokenType
import latex2mathml.converter

# Inclusive code-point ranges treated as Arabic script. Besides the base blocks, this covers
# the Extended blocks and the Presentation Forms blocks, since PDF text layers may encode
# Arabic with contextual presentation-form code points instead of the base letters.
_ARABIC_RANGES = (
    ("\u0600", "\u06FF"),  # Arabic
    ("\u0750", "\u077F"),  # Arabic Supplement
    ("\u0870", "\u089F"),  # Arabic Extended-B
    ("\u08A0", "\u08FF"),  # Arabic Extended-A
    ("\uFB50", "\uFDFF"),  # Arabic Presentation Forms-A
    ("\uFE70", "\uFEFC"),  # Arabic Presentation Forms-B (stops before U+FEFF, the BOM)
)


def has_arabic(text: str) -> bool:
    """Return True if ``text`` contains at least one Arabic-script character.

    :param text: text to inspect; ``None`` and ``""`` both yield False.
    :return: whether any character falls inside one of ``_ARABIC_RANGES``.
    """
    if not text:
        return False
    return any(low <= char <= high for char in text for low, high in _ARABIC_RANGES)


def is_valid_latex(formula: str) -> bool:
    """Return True if ``latex2mathml`` can convert ``formula`` without raising."""
    try:
        latex2mathml.converter.convert(formula)
        return True
    except Exception:
        return False


def extract_formula_format(pdf_images: PdfImages, predicted_segments: list[PdfSegment]):
    """Replace the text of FORMULA segments with LaTeX recognised from the page image.

    Segments are modified in place. For each FORMULA segment whose text has no Arabic
    characters, its bounding box is cropped out of the page image and run through pix2tex's
    ``LatexOCR``. If the output converts cleanly with latex2mathml, it is stored as
    ``$$...$$`` in ``text_content``. Otherwise the original text is kept, as it is for
    segments whose pixel box is empty.

    :param pdf_images: rendered pages; ``pdf_images.pdf_images`` is indexed by
        ``page_number - 1`` and ``pdf_images.dpi`` is the render resolution.
    :param predicted_segments: all segments of the document; only FORMULA ones are touched.
    """
    formula_segments = [segment for segment in predicted_segments if segment.segment_type == TokenType.FORMULA]
    if not formula_segments:
        return

    # The model is only constructed when there is at least one FORMULA segment.
    model = LatexOCR()
    # Near-zero sampling temperature, presumably to make decoding effectively deterministic.
    model.args.temperature = 1e-8

    for formula_segment in formula_segments:
        # Segments containing Arabic text are left untouched rather than OCR'd as math.
        if has_arabic(formula_segment.text_content):
            continue
        page_image: Image = pdf_images.pdf_images[formula_segment.page_number - 1]
        # Box coordinates are scaled by dpi / 72, i.e. treated as PDF points (1/72 inch).
        left, top = formula_segment.bounding_box.left, formula_segment.bounding_box.top
        right, bottom = formula_segment.bounding_box.right, formula_segment.bounding_box.bottom
        left = int(left * pdf_images.dpi / 72)
        top = int(top * pdf_images.dpi / 72)
        right = int(right * pdf_images.dpi / 72)
        bottom = int(bottom * pdf_images.dpi / 72)
        if right <= left or bottom <= top:
            # PIL raises ValueError when right < left (or lower < upper), and equal
            # coordinates give a zero-size image; skip instead of aborting the document.
            continue
        formula_image = page_image.crop((left, top, right, bottom))
        formula_result = model(formula_image)
        if not is_valid_latex(formula_result):
            continue
        formula_segment.text_content = f"$${formula_result}$$"


# --------------------------------------------------------------------------------
# /src/adapters/infrastructure/format_converters/convert_table_to_html.py:
# --------------------------------------------------------------------------------
from PIL import Image
from domain.PdfImages import PdfImages
from domain.PdfSegment import PdfSegment
from pdf_token_type_labels import TokenType
from rapidocr import RapidOCR
from rapid_table import ModelType, RapidTable, RapidTableInput


def extract_table_format(pdf_images: PdfImages, predicted_segments: list[PdfSegment]):
    """Replace the text of TABLE segments with HTML recognised from the page image.

    Segments are modified in place. Each TABLE segment's bounding box is cropped out of its
    page image, OCR'd with RapidOCR and, if any text is found, passed with those OCR results
    to RapidTable (SLANETPLUS model). The predicted HTML is stored in ``text_content``.
    Segments with an empty pixel box or no OCR text keep their original text.

    :param pdf_images: rendered pages; ``pdf_images.pdf_images`` is indexed by
        ``page_number - 1`` and ``pdf_images.dpi`` is the render resolution.
    :param predicted_segments: all segments of the document; only TABLE ones are touched.
    """
    table_segments = [segment for segment in predicted_segments if segment.segment_type == TokenType.TABLE]
    if not table_segments:
        return

    input_args = RapidTableInput(model_type=ModelType["SLANETPLUS"])

    # Engines are only constructed when there is at least one TABLE segment.
    ocr_engine = RapidOCR()
    table_engine = RapidTable(input_args)

    for table_segment in table_segments:
        # ``Image`` is the PIL module in this file, so the image class is ``Image.Image``.
        page_image: Image.Image = pdf_images.pdf_images[table_segment.page_number - 1]
        # Box coordinates are scaled by dpi / 72, i.e. treated as PDF points (1/72 inch).
        left, top = table_segment.bounding_box.left, table_segment.bounding_box.top
        right, bottom = table_segment.bounding_box.right, table_segment.bounding_box.bottom
        left = int(left * pdf_images.dpi / 72)
        top = int(top * pdf_images.dpi / 72)
        right = int(right * pdf_images.dpi / 72)
        bottom = int(bottom * pdf_images.dpi / 72)
        if right <= left or bottom <= top:
            # PIL raises ValueError when right < left (or lower < upper), and equal
            # coordinates give a zero-size image; skip instead of aborting the document.
            continue
        table_image = page_image.crop((left, top, right, bottom))
        ori_ocr_res = ocr_engine(table_image)
        if not ori_ocr_res.txts:
            continue
        ocr_results = [ori_ocr_res.boxes, ori_ocr_res.txts, ori_ocr_res.scores]
        table_result = table_engine(table_image, ocr_results=ocr_results)
        table_segment.text_content = table_result.pred_html


# --------------------------------------------------------------------------------
# /src/adapters/infrastructure/html_conversion_service_adapter.py:
# --------------------------------------------------------------------------------
from typing import Optional, Union
from starlette.responses import Response

from domain.SegmentBox import SegmentBox
from ports.services.html_conversion_service import HtmlConversionService
from adapters.infrastructure.markup_conversion.pdf_to_markup_service_adapter import PdfToMarkupServiceAdapter
from adapters.infrastructure.markup_conversion.OutputFormat import OutputFormat


class HtmlConversionServiceAdapter(HtmlConversionService, PdfToMarkupServiceAdapter):
    """HtmlConversionService implemented by the shared PDF-to-markup pipeline in HTML mode."""

    def __init__(self):
        # Only the pipeline base is initialised explicitly, with HTML as the output format.
        PdfToMarkupServiceAdapter.__init__(self, OutputFormat.HTML)

    def convert_to_html(
        self,
        pdf_content: bytes,
        segments: list[SegmentBox],
        extract_toc: bool = False,
        dpi: int = 120,
        output_file: Optional[str] = None,
        target_languages: Optional[list[str]] = None,
        translation_model: str = "gpt-oss",
    ) -> Union[str, Response]:
        """Convert ``pdf_content`` to HTML.

        All arguments are forwarded positionally and unchanged to the inherited
        ``convert_to_format``; see ``PdfToMarkupServiceAdapter`` for their meaning.
        NOTE(review): presumably a ``Response`` is returned when ``output_file`` is set
        and a plain string otherwise; confirm in ``convert_to_format``.
        """
        return self.convert_to_format(
            pdf_content, segments, extract_toc, dpi, output_file, target_languages, translation_model
        )


# --------------------------------------------------------------------------------
# /src/adapters/infrastructure/markdown_conversion_service_adapter.py:
# --------------------------------------------------------------------------------
from typing import Optional, Union
from starlette.responses import Response

from domain.SegmentBox import SegmentBox
from ports.services.markdown_conversion_service import MarkdownConversionService
from adapters.infrastructure.markup_conversion.pdf_to_markup_service_adapter import PdfToMarkupServiceAdapter
from adapters.infrastructure.markup_conversion.OutputFormat import OutputFormat


class MarkdownConversionServiceAdapter(MarkdownConversionService, PdfToMarkupServiceAdapter):
    """MarkdownConversionService implemented by the shared PDF-to-markup pipeline in Markdown mode."""

    def __init__(self):
        # Only the pipeline base is initialised explicitly, with Markdown as the output format.
        PdfToMarkupServiceAdapter.__init__(self, OutputFormat.MARKDOWN)

    def convert_to_markdown(
        self,
        pdf_content: bytes,
        segments: list[SegmentBox],
        extract_toc: bool = False,
        dpi: int = 120,
        output_file: Optional[str] = None,
        target_languages: Optional[list[str]] = None,
        translation_model: str = "gpt-oss",
    ) -> Union[str, Response]:
        """Convert ``pdf_content`` to Markdown.

        All arguments are forwarded positionally and unchanged to the inherited
        ``convert_to_format``; see ``PdfToMarkupServiceAdapter`` for their meaning.
        NOTE(review): presumably a ``Response`` is returned when ``output_file`` is set
        and a plain string otherwise; confirm in ``convert_to_format``.
        """
        return self.convert_to_format(
            pdf_content, segments, extract_toc, dpi, output_file, target_languages, translation_model
        )


# --------------------------------------------------------------------------------
# /src/adapters/infrastructure/markup_conversion/ExtractedImage.py:
# --------------------------------------------------------------------------------
from pydantic import BaseModel


class ExtractedImage(BaseModel):
    """An image extracted during markup conversion, with the filename it is stored under."""

    # Raw image bytes.
    image_data: bytes
    # Filename associated with the image data.
    filename: str


# --------------------------------------------------------------------------------
# /src/adapters/infrastructure/markup_conversion/Link.py:
# --------------------------------------------------------------------------------
from pydantic import BaseModel
from domain.SegmentBox import SegmentBox


class Link(BaseModel):
    """A link from one segment to another, together with its text."""

    # Segment the link appears in.
    source_segment: SegmentBox
    # Segment the link points to.
    destination_segment: SegmentBox
    # Text of the link.
    text: str


# --------------------------------------------------------------------------------
# /src/adapters/infrastructure/markup_conversion/OutputFormat.py:
# --------------------------------------------------------------------------------
try:
    from enum import StrEnum
except ImportError:  # Python < 3.11; pyproject.toml declares requires-python >= 3.10
    from enum import Enum

    class StrEnum(str, Enum):
        """Minimal stand-in for ``enum.StrEnum`` (added in Python 3.11).

        Members are ``str`` instances, and ``str()`` / ``format()`` return the member's
        value, matching ``enum.StrEnum``; that is all ``OutputFormat`` relies on.
        """

        def __str__(self) -> str:
            return str(self.value)

        def __format__(self, format_spec: str) -> str:
            return str.__format__(str(self.value), format_spec)


class OutputFormat(StrEnum):
    """Markup format produced by the PDF-to-markup pipeline."""

    HTML = "html"
    MARKDOWN = "markdown"


# --------------------------------------------------------------------------------
# /src/adapters/infrastructure/markup_conversion/__init__.py:
# --------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# /src/adapters/infrastructure/ocr/__init__.py:
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/infrastructure/ocr/__init__.py
# ------------------------------------------------------------------------------
# /src/adapters/infrastructure/ocr/languages.py:
# ------------------------------------------------------------------------------
import subprocess

# Maps the language codes accepted by the service (mostly ISO 639-1, plus some
# tesseract-specific names) to tesseract traineddata names.
# NOTE(review): tesseract names variant packs with underscores (e.g. "chi_sim_vert",
# "aze_cyrl"; compare "kor_vert" below). Values written with hyphens will never match
# `tesseract --list-langs` output. Confirm against the installed tessdata and align.
iso_to_tesseract = {
    "af": "afr",  # Afrikaans
    "all": "all",  # Allar
    "am": "amh",  # Amharic
    "ar": "ara",  # Arabic
    "as": "asm",  # Assamese
    "az": "aze",  # Azerbaijani
    "aze-cyrl": "aze-cyrl",  # Azerbaijani (Cyrillic)
    "be": "bel",  # Belarusian
    "bn": "ben",  # Bangla
    "bo": "bod",  # Tibetan
    "bs": "bos",  # Bosnian
    "br": "bre",  # Breton
    "bg": "bul",  # Bulgarian
    "ca": "cat",  # Catalan
    "ceb": "ceb",  # Cebuano
    "cs": "ces",  # Czech
    "zh-Hans": "chi_sim",  # Chinese (Simplified)
    "chi-sim-vert": "chi-sim-vert",  # Chinese (Simplified) vertical
    "zh-Hant": "chi_tra",  # Chinese (Traditional)
    "chi-tra-vert": "chi-tra-vert",  # Chinese (Traditional) vertical
    "chr": "chr",  # Cherokee
    "co": "cos",  # Corsican
    "cy": "cym",  # Welsh
    "da": "dan",  # Danish
    "de": "deu",  # German
    "dv": "div",  # Divehi
    "dz": "dzo",  # Dzongkha
    "el": "ell",  # Greek
    "en": "eng",  # English
    "enm": "enm",  # Middle English
    "eo": "epo",  # Esperanto
    "et": "est",  # Estonian
    "eu": "eus",  # Basque
    "fo": "fao",  # Faroese
    "fa": "fas",  # Persian
    "fil": "fil",  # Filipino
    "fi": "fin",  # Finnish
    "fr": "fra",  # French
    "frk": "frk",  # German Fraktur
    "frm": "frm",  # Middle French
    "fy": "fry",  # Western Frisian
    "gd": "gla",  # Scottish Gaelic
    "ga": "gle",  # Irish
    "gl": "glg",  # Galician
    "grc": "grc",  # Ancient Greek
    "gu": "guj",  # Gujarati
    "ht": "hat",  # Haitian Creole
    "he": "heb",  # Hebrew
    "hi": "hin",  # Hindi
    "hr": "hrv",  # Croatian
    "hu": "hun",  # Hungarian
    "hy": "hye",  # Armenian
    "iu": "iku",  # Inuktitut
    "id": "ind",  # Indonesian
    "is": "isl",  # Icelandic
    "it": "ita",  # Italian
    "ita-old": "ita-old",  # Old Italian
    "jv": "jav",  # Javanese
    "ja": "jpn",  # Japanese
    "jpn-vert": "jpn-vert",  # Japanese vertical
    "kn": "kan",  # Kannada
    "ka": "kat",  # Georgian
    "kat-old": "kat-old",  # Old Georgian
    "kk": "kaz",  # Kazakh
    "km": "khm",  # Khmer
    "ky": "kir",  # Kyrgyz
    "kmr": "kmr",  # Northern Kurdish
    "ko": "kor",  # Korean
    "kor-vert": "kor_vert",  # Korean vertical
    "lo": "lao",  # Lao
    "la": "lat",  # Latin
    "lv": "lav",  # Latvian
    "lt": "lit",  # Lithuanian
    "lb": "ltz",  # Luxembourgish
    "ml": "mal",  # Malayalam
    "mr": "mar",  # Marathi
    "mk": "mkd",  # Macedonian
    "mt": "mlt",  # Maltese
    "mn": "mon",  # Mongolian
    "mi": "mri",  # Māori
    "ms": "msa",  # Malay
    "my": "mya",  # Burmese
    "ne": "nep",  # Nepali
    "nl": "nld",  # Dutch
    "no": "nor",  # Norwegian
    "oc": "oci",  # Occitan
    "or": "ori",  # Odia
    "osd": "osd",  # Orientation and script detection data (not a language)
    "pa": "pan",  # Punjabi
    "pl": "pol",  # Polish
    "pt": "por",  # Portuguese
    "ps": "pus",  # Pashto
    "qu": "que",  # Quechua
    "ro": "ron",  # Romanian
    "ru": "rus",  # Russian
    "sa": "san",  # Sanskrit
    "script-arab": "script-arab",  # Arabic script
    "script-armn": "script-armn",  # Armenian script
    "script-beng": "script-beng",  # Bengali script
    "script-cans": "script-cans",  # Canadian Aboriginal script
    "script-cher": "script-cher",  # Cherokee script
    "script-cyrl": "script-cyrl",  # Cyrillic script
    "script-deva": "script-deva",  # Devanagari script
    "script-ethi": "script-ethi",  # Ethiopic script
    "script-frak": "script-frak",  # Fraktur script
    "script-geor": "script-geor",  # Georgian script
    "script-grek": "script-grek",  # Greek script
    "script-gujr": "script-gujr",  # Gujarati script
    "script-guru": "script-guru",  # Gurmukhi script
    "script-hang": "script-hang",  # Hangul script
    "script-hang-vert": "script-hang-vert",  # Hangul script vertical
    "script-hans": "script-hans",  # Han (Simplified) script
    "script-hans-vert": "script-hans-vert",  # Han (Simplified) script vertical
    "script-hant": "script-hant",  # Han (Traditional) script
    "script-hant-vert": "script-hant-vert",  # Han (Traditional) script vertical
    "script-hebr": "script-hebr",  # Hebrew script
    "script-jpan": "script-jpan",  # Japanese script
    "script-jpan-vert": "script-jpan-vert",  # Japanese script vertical
    "script-khmr": "script-khmr",  # Khmer script
    "script-knda": "script-knda",  # Kannada script
    "script-laoo": "script-laoo",  # Lao script
    "script-latn": "script-latn",  # Latin script
    "script-mlym": "script-mlym",  # Malayalam script
    "script-mymr": "script-mymr",  # Myanmar script
    "script-orya": "script-orya",  # Odia script
    "script-sinh": "script-sinh",  # Sinhala script
    "script-syrc": "script-syrc",  # Syriac script
    "script-taml": "script-taml",  # Tamil script
    "script-telu": "script-telu",  # Telugu script
    "script-thaa": "script-thaa",  # Thaana script
    "script-thai": "script-thai",  # Thai script
    "script-tibt": "script-tibt",  # Tibetan script
    "script-viet": "script-viet",  # Vietnamese script
    "si": "sin",  # Sinhala
    "sk": "slk",  # Slovak
    "sl": "slv",  # Slovenian
    "sd": "snd",  # Sindhi
    "es": "spa",  # Spanish
    "spa-old": "spa-old",  # Old Spanish
    "sq": "sqi",  # Albanian
    "sr": "srp",  # Serbian
    "srp-latn": "srp-latn",  # Serbian (Latin)
    "su": "sun",  # Sundanese
    "sw": "swa",  # Swahili
    "sv": "swe",  # Swedish
    "syr": "syr",  # Syriac
    "ta": "tam",  # Tamil
    "tt": "tat",  # Tatar
    "te": "tel",  # Telugu
    "tg": "tgk",  # Tajik
    "th": "tha",  # Thai
    "ti": "tir",  # Tigrinya
    "to": "ton",  # Tongan
    "tr": "tur",  # Turkish
    "ug": "uig",  # Uyghur
    "uk": "ukr",  # Ukrainian
    "ur": "urd",  # Urdu
    "uz": "uzb",  # Uzbek
    "uzb-cyrl": "uzb-cyrl",  # Uzbek (Cyrillic)
    "vi": "vie",  # Vietnamese
    "yi": "yid",  # Yiddish
    "yo": "yor",  # Yoruba
}


def supported_languages() -> list[str]:
    """Return the service language codes whose tesseract data is installed.

    Runs ``tesseract --list-langs`` through bash, dropping the header line and any
    line containing ``osd``. Each remaining tesseract name is translated back to
    its key in ``iso_to_tesseract``. Installed languages with no entry in the
    mapping are skipped; previously they raised ``KeyError``.
    """
    cmd = "tesseract --list-langs | grep -v osd | awk '{if(NR>1)print}'"
    # subprocess.run waits for the shell and releases its pipe; a bare Popen did neither.
    completed = subprocess.run(["/bin/bash", "-c", cmd], stdout=subprocess.PIPE)
    tesseract_langs = [line.strip().decode("utf-8") for line in completed.stdout.splitlines()]
    inverted_iso_dict = {v: k for k, v in iso_to_tesseract.items()}
    # dict.fromkeys de-duplicates while keeping tesseract's listing order.
    return list(
        dict.fromkeys(
            inverted_iso_dict[tesseract_key] for tesseract_key in tesseract_langs if tesseract_key in inverted_iso_dict
        )
    )


# ------------------------------------------------------------------------------
# /src/adapters/infrastructure/ocr_service_adapter.py:
# ------------------------------------------------------------------------------
import os
import shutil
import subprocess
from pathlib import Path
from ports.services.ocr_service import OCRService
from configuration import OCR_SOURCE, OCR_OUTPUT, OCR_FAILED
from adapters.infrastructure.ocr.languages import iso_to_tesseract, supported_languages


class OCRServiceAdapter(OCRService):
    """OCR port implementation that shells out to ``ocrmypdf``."""

    def process_pdf_ocr(self, filename: str, namespace: str, language: str = "en") -> Path | bool:
        """OCR the PDF stored at ``OCR_SOURCE/namespace/filename``.

        Args:
            filename: Name of the source PDF inside the namespace folder.
            namespace: Sub-folder used under the source, output and failed roots.
            language: Key of ``iso_to_tesseract``; an unknown key raises ``KeyError``.

        Returns:
            The path of the OCR'd PDF under ``OCR_OUTPUT`` on success. On failure the
            source PDF is moved under ``OCR_FAILED`` and ``False`` is returned.
        """
        source_pdf_filepath, processed_pdf_filepath, failed_pdf_filepath = self._get_paths(namespace, filename)
        os.makedirs(processed_pdf_filepath.parent, exist_ok=True)

        result = subprocess.run(
            [
                "ocrmypdf",
                "-l",
                iso_to_tesseract[language],
                source_pdf_filepath,
                processed_pdf_filepath,
                # OCR every page even if it already contains text.
                "--force-ocr",
            ]
        )

        if result.returncode == 0:
            return processed_pdf_filepath

        os.makedirs(failed_pdf_filepath.parent, exist_ok=True)
        shutil.move(source_pdf_filepath, failed_pdf_filepath)
        return False

    def get_supported_languages(self) -> list[str]:
        """Language codes usable as ``language`` in ``process_pdf_ocr`` on this host."""
        return supported_languages()

    def _get_paths(self, namespace: str, pdf_file_name: str) -> tuple[Path, Path, Path]:
        """Return ``(source, processed, failed)`` paths for ``pdf_file_name``.

        The processed file is named after the input with only its final extension
        replaced by ``.pdf``.
        """
        # Strip only the last extension so interior dots survive ("a.b.pdf" -> "a.b").
        # The old "".join(name.split(".")[:-1]) turned "a.b.pdf" into "ab".
        file_name = pdf_file_name.rsplit(".", 1)[0] if "." in pdf_file_name else pdf_file_name
        source_pdf_filepath = Path(OCR_SOURCE, namespace, pdf_file_name)
        processed_pdf_filepath = Path(OCR_OUTPUT, namespace, f"{file_name}.pdf")
        failed_pdf_filepath = Path(OCR_FAILED, namespace, pdf_file_name)
        return source_pdf_filepath, processed_pdf_filepath, failed_pdf_filepath


# ------------------------------------------------------------------------------
# /src/adapters/infrastructure/pdf_analysis_service_adapter.py:
# ------------------------------------------------------------------------------
from typing import AnyStr
from domain.PdfImages import PdfImages
from domain.SegmentBox import SegmentBox
from ports.services.pdf_analysis_service import PDFAnalysisService
from ports.services.ml_model_service import MLModelService
from ports.services.format_conversion_service import FormatConversionService
from ports.repositories.file_repository import FileRepository
from configuration import service_logger


class PDFAnalysisServiceAdapter(PDFAnalysisService):
    """Layout-analysis port implementation backed by a VGT model and a fast model."""

    def __init__(
        self,
        vgt_model_service: MLModelService,
        fast_model_service: MLModelService,
        format_conversion_service: FormatConversionService,
        file_repository: FileRepository,
    ):
        self.vgt_model_service = vgt_model_service
        self.fast_model_service = fast_model_service
        self.format_conversion_service = format_conversion_service
        self.file_repository = file_repository

    def analyze_pdf_layout(
        self, pdf_content: AnyStr, xml_filename: str = "", parse_tables_and_math: bool = False, keep_pdf: bool = False
    ) -> list[dict]:
        """Detect layout segments in a PDF with the VGT model.

        Args:
            pdf_content: Raw PDF content, persisted via ``file_repository.save_pdf``.
            xml_filename: Forwarded to ``PdfImages.from_pdf_path``.
            parse_tables_and_math: Also run formula-to-LaTeX and table-to-HTML
                conversion on 200-dpi renders of the pages.
            keep_pdf: When false, the saved PDF is deleted, also if an error is raised.

        Returns:
            One ``SegmentBox.to_dict()`` per predicted segment.
        """
        pdf_path = self.file_repository.save_pdf(pdf_content)
        try:
            service_logger.info("Creating PDF images")
            pdf_images_list: list[PdfImages] = [PdfImages.from_pdf_path(pdf_path, "", xml_filename)]
            predicted_segments = self.vgt_model_service.predict_document_layout(pdf_images_list)

            if parse_tables_and_math:
                pdf_images_200_dpi = PdfImages.from_pdf_path(pdf_path, "", xml_filename, dpi=200)
                self.format_conversion_service.convert_formula_to_latex(pdf_images_200_dpi, predicted_segments)
                self.format_conversion_service.convert_table_to_html(pdf_images_200_dpi, predicted_segments)
        finally:
            # Previously the file leaked whenever image creation or prediction raised.
            if not keep_pdf:
                self.file_repository.delete_file(pdf_path)

        return self._to_segment_dicts(predicted_segments, pdf_images_list[0])

    def analyze_pdf_layout_fast(
        self, pdf_content: AnyStr, xml_filename: str = "", parse_tables_and_math: bool = False, keep_pdf: bool = False
    ) -> list[dict]:
        """Same contract as ``analyze_pdf_layout``, but predicts with the fast model."""
        pdf_path = self.file_repository.save_pdf(pdf_content)
        try:
            service_logger.info("Creating PDF images for fast analysis")
            pdf_images_list: list[PdfImages] = [PdfImages.from_pdf_path(pdf_path, "", xml_filename)]
            predicted_segments = self.fast_model_service.predict_layout_fast(pdf_images_list)

            if parse_tables_and_math:
                pdf_images_200_dpi = PdfImages.from_pdf_path(pdf_path, "", xml_filename, dpi=200)
                self.format_conversion_service.convert_formula_to_latex(pdf_images_200_dpi, predicted_segments)
                # NOTE(review): unlike analyze_pdf_layout, tables here get the images created
                # without dpi=200, not pdf_images_200_dpi. Confirm this is intentional.
                self.format_conversion_service.convert_table_to_html(pdf_images_list[0], predicted_segments)
        finally:
            if not keep_pdf:
                self.file_repository.delete_file(pdf_path)

        return self._to_segment_dicts(predicted_segments, pdf_images_list[0])

    @staticmethod
    def _to_segment_dicts(predicted_segments, pdf_images: PdfImages) -> list[dict]:
        """Serialise predicted segments using the page list of ``pdf_images``."""
        pages = pdf_images.pdf_features.pages
        return [SegmentBox.from_pdf_segment(pdf_segment, pages).to_dict() for pdf_segment in predicted_segments]
# ------------------------------------------------------------------------------
# /src/adapters/infrastructure/text_extraction_adapter.py:
# ------------------------------------------------------------------------------
from pdf_token_type_labels import TokenType
from ports.services.text_extraction_service import TextExtractionService
from configuration import service_logger


class TextExtractionAdapter(TextExtractionService):
    """Concatenates segment texts, optionally filtered by segment type."""

    def extract_text_by_types(self, segment_boxes: list[dict], token_types: list[TokenType]) -> str:
        """Join the ``"text"`` of every segment whose type is in ``token_types``.

        Args:
            segment_boxes: Segment dicts with at least ``"text"`` and ``"type"`` keys;
                spaces in ``"type"`` are replaced by underscores before
                ``TokenType.from_text``.
            token_types: Types to keep.

        Returns:
            The kept texts in input order, newline-separated (a ``str``, not a ``dict``).
        """
        service_logger.info(f"Extracted types: {[t.name for t in token_types]}")
        return "\n".join(
            segment_box["text"]
            for segment_box in segment_boxes
            if TokenType.from_text(segment_box["type"].replace(" ", "_")) in token_types
        )

    def extract_all_text(self, segment_boxes: list[dict]) -> str:
        """Join the text of all segments, whatever their type."""
        return self.extract_text_by_types(segment_boxes, list(TokenType))


# ------------------------------------------------------------------------------
# /src/adapters/infrastructure/toc/MergeTwoSegmentsTitles.py:
# ------------------------------------------------------------------------------
from adapters.infrastructure.toc.TitleFeatures import TitleFeatures
from adapters.infrastructure.toc.PdfSegmentation import PdfSegmentation


class MergeTwoSegmentsTitles:
    """Builds title features for a segmentation and merges adjacent titles that belong together."""

    def __init__(self, pdf_segmentation: PdfSegmentation):
        self.title_features_list: list[TitleFeatures] = TitleFeatures.from_pdf_segmentation(pdf_segmentation)
        # Result of merge(): the titles after folding mergeable neighbours together.
        self.titles_merged: list[TitleFeatures] = list()
        self.merge()

    def merge(self):
        """Walk titles in order, folding a title into its successor when ``should_merge`` allows.

        A merged title is not emitted itself. Its successor is replaced by
        ``successor.append(title)`` and examined on the next iteration.
        """
        index = 0
        while index < len(self.title_features_list):
            if index == len(self.title_features_list) - 1:
                self.titles_merged.append(self.title_features_list[index])
                break

            if not self.should_merge(self.title_features_list[index], self.title_features_list[index + 1]):
                self.titles_merged.append(self.title_features_list[index])
                index += 1
                continue

            self.title_features_list[index + 1] = self.title_features_list[index + 1].append(self.title_features_list[index])
            index += 1

    @staticmethod
    def should_merge(title: TitleFeatures, other_title: TitleFeatures):
        """Return True if ``other_title`` directly continues ``title`` and should be merged into it."""
        same_page = other_title.pdf_segment.page_number == title.pdf_segment.page_number

        if not same_page:
            return False

        # other_title must start within 15 units below title's bottom edge.
        if abs(other_title.top - title.bottom) > 15:
            return False

        # NOTE(review): this passes only when BOTH other.left ~ title.right AND
        # other.right ~ title.left (within 15). That requires the two widths to sum to
        # roughly 30 or less, so ordinary stacked title lines are never merged. It may
        # have been meant to compare left-to-left / right-to-right. Confirm before changing.
        if abs(other_title.left - title.right) > 15 or abs(other_title.right - title.left) > 15:
            return False

        if title.first_characters_type in [1, 2, 3] and other_title.first_characters_type in [1, 2, 3]:
            return False

        if title.bullet_points_type and other_title.bullet_points_type:
            return False

        if title.get_features_to_merge() != other_title.get_features_to_merge():
            return False

        return True


# ------------------------------------------------------------------------------
# /src/adapters/infrastructure/toc/PdfSegmentation.py:
# ------------------------------------------------------------------------------
from domain.PdfSegment import PdfSegment
from pdf_features import PdfFeatures
from pdf_features import PdfToken


class PdfSegmentation:
    """Pairs PDF features with segments and assigns each token to its best-overlapping segment."""

    def __init__(self, pdf_features: PdfFeatures, pdf_segments: list[PdfSegment]):
        self.pdf_features: PdfFeatures = pdf_features
        self.pdf_segments: list[PdfSegment] = pdf_segments
        self.tokens_by_segments: dict[PdfSegment, list[PdfToken]] = self.find_tokens_by_segments()

    @staticmethod
    def find_segment_for_token(token: PdfToken, segments: list[PdfSegment], tokens_by_segments):
        """Append ``token`` to the segment it overlaps most, mutating ``tokens_by_segments``.

        Tokens with no positive overlap with any segment are dropped. The scan stops
        early once a score of at least 99 (presumably a percentage) is reached.
        """
        best_score: float = 0
        most_probable_segment: PdfSegment | None = None
        for segment in segments:
            intersection_percentage = token.bounding_box.get_intersection_percentage(segment.bounding_box)
            if intersection_percentage > best_score:
                best_score = intersection_percentage
                most_probable_segment = segment
                if best_score >= 99:
                    break
        if most_probable_segment:
            tokens_by_segments.setdefault(most_probable_segment, list()).append(token)

    def find_tokens_by_segments(self):
        """Map each segment to its tokens, matching tokens only against segments on the same page."""
        tokens_by_segments: dict[PdfSegment, list[PdfToken]] = {}
        for page in self.pdf_features.pages:
            page_segments = [segment for segment in self.pdf_segments if segment.page_number == page.page_number]
            for token in page.tokens:
                self.find_segment_for_token(token, page_segments, tokens_by_segments)
        return tokens_by_segments


# ------------------------------------------------------------------------------
# /src/adapters/infrastructure/toc/TOCExtractor.py:
# ------------------------------------------------------------------------------
from typing import Any

from adapters.infrastructure.toc.MergeTwoSegmentsTitles import MergeTwoSegmentsTitles
from adapters.infrastructure.toc.TitleFeatures import TitleFeatures
from adapters.infrastructure.toc.data.TOCItem import TOCItem
from adapters.infrastructure.toc.PdfSegmentation import PdfSegmentation


class TOCExtractor:
    """Builds an indented table of contents from the title segments of a segmentation."""

    def __init__(self, pdf_segmentation: PdfSegmentation):
        self.pdf_segmentation = pdf_segmentation
        self.titles_features_sorted = MergeTwoSegmentsTitles(self.pdf_segmentation).titles_merged
        self.toc: list[TOCItem] = list()
        self.set_toc()

    def set_toc(self):
        """Append one ``TOCItem`` per merged title, computing its indentation in order."""
        for index, title_features in enumerate(self.titles_features_sorted):
            indentation = self.get_indentation(index, title_features)
            self.toc.append(title_features.to_toc_item(indentation))

    def __str__(self):
        return "\n".join([f'{" " * x.indentation} * {x.label}' for x in self.toc])

    def get_indentation(self, title_index: int, title_features: TitleFeatures):
        """Indentation level for the title at ``title_index``.

        Reuses the level of the nearest earlier, still-open item judged to be at the
        same level, closing every deeper item. Otherwise returns one level deeper than
        the immediately preceding item.
        """
        if title_index == 0:
            return 0

        for index in reversed(range(title_index)):
            if self.toc[index].point_closed:
                continue

            if self.same_indentation(self.titles_features_sorted[index], title_features):
                self.close_toc_items(self.toc[index].indentation)
                return self.toc[index].indentation

        return self.toc[title_index - 1].indentation + 1

    def close_toc_items(self, indentation):
        """Mark every item deeper than ``indentation`` as closed so later titles skip it."""
        for toc in self.toc:
            if toc.indentation > indentation:
                toc.point_closed = True

    @staticmethod
    def same_indentation(previous_title_features: TitleFeatures, title_features: TitleFeatures):
        """True if the previous title's leading characters are among the current title's
        ``get_possible_previous_point()``, or both titles have equal ``get_features_toc()``."""
        if previous_title_features.first_characters in title_features.get_possible_previous_point():
            return True

        if previous_title_features.get_features_toc() == title_features.get_features_toc():
            return True

        return False

    def to_dict(self):
        """Serialise the TOC as dicts with ``indentation``, ``label`` and ``bounding_box``.

        Box coordinates are truncated with ``int``; ``page`` is rendered as a string.
        """
        toc: list[dict[str, Any]] = []
        for toc_item in self.toc:
            rectangle = toc_item.selection_rectangle
            toc.append(
                {
                    "indentation": toc_item.indentation,
                    "label": toc_item.label,
                    "bounding_box": {
                        "left": int(rectangle.left),
                        "top": int(rectangle.top),
                        "width": int(rectangle.width),
                        "height": int(rectangle.height),
                        "page": str(rectangle.page_number),
                    },
                }
            )
        return toc


# ------------------------------------------------------------------------------
# /src/adapters/infrastructure/toc/__init__.py:
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/infrastructure/toc/__init__.py
/src/adapters/infrastructure/toc/data/TOCItem.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | from domain.SegmentBox import SegmentBox 4 | 5 | 6 | class TOCItem(BaseModel): 7 | indentation: int 8 | label: str = "" 9 | selection_rectangle: SegmentBox 10 | point_closed: bool = False 11 | -------------------------------------------------------------------------------- /src/adapters/infrastructure/toc/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/infrastructure/toc/data/__init__.py -------------------------------------------------------------------------------- /src/adapters/infrastructure/toc/extract_table_of_contents.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | import uuid 3 | from os.path import join 4 | from pathlib import Path 5 | from typing import AnyStr 6 | from domain.PdfSegment import PdfSegment 7 | from pdf_features import PdfFeatures 8 | from pdf_features import Rectangle 9 | from pdf_token_type_labels import TokenType 10 | from adapters.infrastructure.toc.TOCExtractor import TOCExtractor 11 | from configuration import service_logger 12 | from adapters.infrastructure.toc.PdfSegmentation import PdfSegmentation 13 | 14 | TITLE_TYPES = {TokenType.TITLE, TokenType.SECTION_HEADER} 15 | SKIP_TYPES = {TokenType.TITLE, TokenType.SECTION_HEADER, TokenType.PAGE_HEADER, TokenType.PICTURE} 16 | 17 | 18 | def get_file_path(file_name, extension): 19 | return join(tempfile.gettempdir(), file_name + "." 
+ extension) 20 | 21 | 22 | def pdf_content_to_pdf_path(file_content): 23 | file_id = str(uuid.uuid1()) 24 | 25 | pdf_path = Path(get_file_path(file_id, "pdf")) 26 | pdf_path.write_bytes(file_content) 27 | 28 | return pdf_path 29 | 30 | 31 | def skip_name_of_the_document(pdf_segments: list[PdfSegment], title_segments: list[PdfSegment]): 32 | segments_to_remove = [] 33 | last_segment = None 34 | for segment in pdf_segments: 35 | if segment.segment_type not in SKIP_TYPES: 36 | break 37 | if segment.segment_type == TokenType.PAGE_HEADER or segment.segment_type == TokenType.PICTURE: 38 | continue 39 | if not last_segment: 40 | last_segment = segment 41 | else: 42 | if segment.bounding_box.right < last_segment.bounding_box.left + last_segment.bounding_box.width * 0.66: 43 | break 44 | last_segment = segment 45 | if segment.segment_type in TITLE_TYPES: 46 | segments_to_remove.append(segment) 47 | for segment in segments_to_remove: 48 | title_segments.remove(segment) 49 | 50 | 51 | def get_pdf_segments_from_segment_boxes(pdf_features: PdfFeatures, segment_boxes: list[dict]) -> list[PdfSegment]: 52 | pdf_segments: list[PdfSegment] = [] 53 | for segment_box in segment_boxes: 54 | left, top, width, height = segment_box["left"], segment_box["top"], segment_box["width"], segment_box["height"] 55 | bounding_box = Rectangle.from_width_height(left, top, width, height) 56 | segment_type = TokenType.from_value(segment_box["type"]) 57 | pdf_name = pdf_features.file_name 58 | segment = PdfSegment(segment_box["page_number"], bounding_box, segment_box["text"], segment_type, pdf_name) 59 | pdf_segments.append(segment) 60 | return pdf_segments 61 | 62 | 63 | def extract_table_of_contents(file: AnyStr, segment_boxes: list[dict], skip_document_name=False): 64 | service_logger.info("Getting TOC") 65 | pdf_path = pdf_content_to_pdf_path(file) 66 | pdf_features: PdfFeatures = PdfFeatures.from_pdf_path(pdf_path) 67 | pdf_segments: list[PdfSegment] = 
get_pdf_segments_from_segment_boxes(pdf_features, segment_boxes) 68 | title_segments = [segment for segment in pdf_segments if segment.segment_type in TITLE_TYPES] 69 | if skip_document_name: 70 | skip_name_of_the_document(pdf_segments, title_segments) 71 | pdf_segmentation: PdfSegmentation = PdfSegmentation(pdf_features, title_segments) 72 | toc_instance: TOCExtractor = TOCExtractor(pdf_segmentation) 73 | return toc_instance.to_dict() 74 | -------------------------------------------------------------------------------- /src/adapters/infrastructure/toc/methods/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/infrastructure/toc/methods/__init__.py -------------------------------------------------------------------------------- /src/adapters/infrastructure/toc/methods/two_models_v3_segments_context_2/Modes.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import hashlib 3 | from statistics import mode 4 | 5 | from pdf_features import PdfFeatures 6 | 7 | 8 | @dataclasses.dataclass 9 | class Modes: 10 | lines_space_mode: float 11 | left_space_mode: float 12 | right_space_mode: float 13 | font_size_mode: float 14 | font_family_name_mode: str 15 | font_family_mode: int 16 | font_family_mode_normalized: float 17 | pdf_features: PdfFeatures 18 | 19 | def __init__(self, pdf_features: PdfFeatures): 20 | self.pdf_features = pdf_features 21 | self.set_modes() 22 | 23 | def set_modes(self): 24 | line_spaces, right_spaces, left_spaces = [0], [0], [0] 25 | for page, token in self.pdf_features.loop_tokens(): 26 | right_spaces.append(self.pdf_features.pages[0].page_width - token.bounding_box.right) 27 | left_spaces.append(token.bounding_box.left) 28 | line_spaces.append(token.bounding_box.bottom) 29 | 30 | self.lines_space_mode = mode(line_spaces) 31 | 
self.left_space_mode = mode(left_spaces) 32 | self.right_space_mode = mode(right_spaces) 33 | 34 | font_sizes = [token.font.font_size for page, token in self.pdf_features.loop_tokens() if token.font] 35 | self.font_size_mode = mode(font_sizes) if font_sizes else 0 36 | font_ids = [token.font.font_id for page, token in self.pdf_features.loop_tokens() if token.font] 37 | self.font_family_name_mode = mode(font_ids) if font_ids else "" 38 | self.font_family_mode = abs( 39 | int( 40 | str(hashlib.sha256(self.font_family_name_mode.encode("utf-8")).hexdigest())[:8], 41 | 16, 42 | ) 43 | ) 44 | self.font_family_mode_normalized = float(f"{str(self.font_family_mode)[0]}.{str(self.font_family_mode)[1:]}") 45 | -------------------------------------------------------------------------------- /src/adapters/infrastructure/toc/methods/two_models_v3_segments_context_2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/infrastructure/toc/methods/two_models_v3_segments_context_2/__init__.py -------------------------------------------------------------------------------- /src/adapters/infrastructure/toc_service_adapter.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | import uuid 3 | from os.path import join 4 | from pathlib import Path 5 | from typing import AnyStr 6 | from domain.PdfSegment import PdfSegment 7 | from pdf_features import PdfFeatures, Rectangle 8 | from pdf_token_type_labels import TokenType 9 | from ports.services.toc_service import TOCService 10 | from configuration import service_logger 11 | from adapters.infrastructure.toc.TOCExtractor import TOCExtractor 12 | from adapters.infrastructure.toc.PdfSegmentation import PdfSegmentation 13 | 14 | TITLE_TYPES = {TokenType.TITLE, TokenType.SECTION_HEADER} 15 | SKIP_TYPES = {TokenType.TITLE, 
TokenType.SECTION_HEADER, TokenType.PAGE_HEADER, TokenType.PICTURE} 16 | 17 | 18 | class TOCServiceAdapter(TOCService): 19 | 20 | def extract_table_of_contents( 21 | self, pdf_content: AnyStr, segment_boxes: list[dict], skip_document_name=False 22 | ) -> list[dict]: 23 | service_logger.info("Getting TOC") 24 | pdf_path = self._pdf_content_to_pdf_path(pdf_content) 25 | pdf_features: PdfFeatures = PdfFeatures.from_pdf_path(pdf_path) 26 | pdf_segments: list[PdfSegment] = self._get_pdf_segments_from_segment_boxes(pdf_features, segment_boxes) 27 | title_segments = [segment for segment in pdf_segments if segment.segment_type in TITLE_TYPES] 28 | if skip_document_name: 29 | self._skip_name_of_the_document(pdf_segments, title_segments) 30 | pdf_segmentation: PdfSegmentation = PdfSegmentation(pdf_features, title_segments) 31 | toc_instance: TOCExtractor = TOCExtractor(pdf_segmentation) 32 | return toc_instance.to_dict() 33 | 34 | def format_toc_for_uwazi(self, toc_items: list[dict]) -> list[dict]: 35 | toc_compatible = [] 36 | for toc_item in toc_items: 37 | toc_compatible.append(toc_item.copy()) 38 | toc_compatible[-1]["bounding_box"]["left"] = int(toc_item["bounding_box"]["left"] / 0.75) 39 | toc_compatible[-1]["bounding_box"]["top"] = int(toc_item["bounding_box"]["top"] / 0.75) 40 | toc_compatible[-1]["bounding_box"]["width"] = int(toc_item["bounding_box"]["width"] / 0.75) 41 | toc_compatible[-1]["bounding_box"]["height"] = int(toc_item["bounding_box"]["height"] / 0.75) 42 | toc_compatible[-1]["selectionRectangles"] = [toc_compatible[-1]["bounding_box"]] 43 | del toc_compatible[-1]["bounding_box"] 44 | return toc_compatible 45 | 46 | def _get_file_path(self, file_name: str, extension: str) -> str: 47 | return join(tempfile.gettempdir(), file_name + "." 
+ extension) 48 | 49 | def _pdf_content_to_pdf_path(self, file_content: AnyStr) -> Path: 50 | file_id = str(uuid.uuid1()) 51 | pdf_path = Path(self._get_file_path(file_id, "pdf")) 52 | pdf_path.write_bytes(file_content) 53 | return pdf_path 54 | 55 | def _skip_name_of_the_document(self, pdf_segments: list[PdfSegment], title_segments: list[PdfSegment]) -> None: 56 | segments_to_remove = [] 57 | last_segment = None 58 | for segment in pdf_segments: 59 | if segment.segment_type not in SKIP_TYPES: 60 | break 61 | if segment.segment_type == TokenType.PAGE_HEADER or segment.segment_type == TokenType.PICTURE: 62 | continue 63 | if not last_segment: 64 | last_segment = segment 65 | else: 66 | if segment.bounding_box.right < last_segment.bounding_box.left + last_segment.bounding_box.width * 0.66: 67 | break 68 | last_segment = segment 69 | if segment.segment_type in TITLE_TYPES: 70 | segments_to_remove.append(segment) 71 | for segment in segments_to_remove: 72 | title_segments.remove(segment) 73 | 74 | def _get_pdf_segments_from_segment_boxes(self, pdf_features: PdfFeatures, segment_boxes: list[dict]) -> list[PdfSegment]: 75 | pdf_segments: list[PdfSegment] = [] 76 | for segment_box in segment_boxes: 77 | left, top, width, height = segment_box["left"], segment_box["top"], segment_box["width"], segment_box["height"] 78 | bounding_box = Rectangle.from_width_height(left, top, width, height) 79 | segment_type = TokenType.from_value(segment_box["type"]) 80 | pdf_name = pdf_features.file_name 81 | segment = PdfSegment(segment_box["page_number"], bounding_box, segment_box["text"], segment_type, pdf_name) 82 | pdf_segments.append(segment) 83 | return pdf_segments 84 | -------------------------------------------------------------------------------- /src/adapters/infrastructure/translation/decode_html_content.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def decode_html(text, link_map, doc_ref_map): 5 | # 1. 
Decode bold+italic first 6 | def bold_italic_decoder(match): 7 | return f"{match.group(2)}" 8 | 9 | text = re.sub(r"\[BI(\d+)\](.*?)\[BI\1\]", bold_italic_decoder, text) 10 | 11 | # 2. Decode bold 12 | def bold_decoder(match): 13 | return f"{match.group(2)}" 14 | 15 | text = re.sub(r"\[B(\d+)\](.*?)\[B\1\]", bold_decoder, text) 16 | 17 | # 3. Decode italic 18 | def italic_decoder(match): 19 | return f"{match.group(2)}" 20 | 21 | text = re.sub(r"\[IT(\d+)\](.*?)\[IT\1\]", italic_decoder, text) 22 | 23 | # 4. Decode links 24 | def link_decoder(match): 25 | idx = int(match.group(1)) 26 | if idx < len(link_map): 27 | label, url = link_map[idx] 28 | return f'{match.group(2)}' 29 | else: 30 | # Return original text if index is out of range 31 | return match.group(0) 32 | 33 | text = re.sub(r"\[LINK(\d+)\](.*?)\[LINK\1\]", link_decoder, text) 34 | 35 | # 5. Decode doc refs (same as markdown since they're custom) 36 | def doc_ref_decoder(match): 37 | idx = int(match.group(1)) 38 | if idx < len(doc_ref_map): 39 | return doc_ref_map[idx] 40 | else: 41 | # Return original text if index is out of range 42 | return match.group(0) 43 | 44 | text = re.sub(r"\[DOCREF(\d+)\]", doc_ref_decoder, text) 45 | text = text.replace("] (#page", "](#page") 46 | text = " ".join(text.split()) 47 | 48 | return text 49 | -------------------------------------------------------------------------------- /src/adapters/infrastructure/translation/decode_markdown_content.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def decode_markdown(text, link_map, doc_ref_map): 5 | # 1. Decode bold+italic first 6 | def bold_italic_decoder(match): 7 | return f"**_{match.group(2)}_**" 8 | 9 | text = re.sub(r"\[BI(\d+)\](.*?)\[BI\1\]", bold_italic_decoder, text) 10 | 11 | # 2. Decode bold 12 | def bold_decoder(match): 13 | return f"**{match.group(2)}**" 14 | 15 | text = re.sub(r"\[B(\d+)\](.*?)\[B\1\]", bold_decoder, text) 16 | 17 | # 3. 
Decode italic 18 | def italic_decoder(match): 19 | return f"_{match.group(2)}_" 20 | 21 | text = re.sub(r"\[IT(\d+)\](.*?)\[IT\1\]", italic_decoder, text) 22 | 23 | # 4. Decode links 24 | def link_decoder(match): 25 | idx = int(match.group(1)) 26 | if idx < len(link_map): 27 | label, url = link_map[idx] 28 | return f"[{match.group(2)}]({url})" 29 | else: 30 | return match.group(0) 31 | 32 | text = re.sub(r"\[LINK(\d+)\](.*?)\[LINK\1\]", link_decoder, text) 33 | 34 | # 5. Decode doc refs 35 | def doc_ref_decoder(match): 36 | idx = int(match.group(1)) 37 | if idx < len(doc_ref_map): 38 | return doc_ref_map[idx] 39 | else: 40 | return match.group(0) 41 | 42 | text = re.sub(r"\[DOCREF(\d+)\]", doc_ref_decoder, text) 43 | text = text.replace("] (#page", "](#page") 44 | text = " ".join(text.split()) 45 | 46 | return text 47 | -------------------------------------------------------------------------------- /src/adapters/infrastructure/translation/download_translation_model.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import time 3 | from configuration import service_logger 4 | 5 | OLLAMA_NOT_RUNNING_MSG = "could not connect to ollama server" 6 | 7 | 8 | def is_ollama_running(): 9 | try: 10 | result = subprocess.run(["ollama", "ls"], capture_output=True, text=True) 11 | msg = result.stderr.lower() + result.stdout.lower() 12 | return OLLAMA_NOT_RUNNING_MSG not in msg and result.returncode == 0 13 | except FileNotFoundError: 14 | service_logger.error("Ollama is not installed or not in PATH.") 15 | return False 16 | 17 | 18 | def start_ollama(): 19 | try: 20 | subprocess.Popen(["ollama", "serve"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) 21 | service_logger.info("Starting Ollama server...") 22 | time.sleep(5) 23 | return True 24 | except Exception as e: 25 | service_logger.error(f"Failed to start Ollama server: {e}") 26 | return False 27 | 28 | 29 | def model_name_variants(name): 30 | base = 
name.split(":")[0] 31 | return {base, f"{base}:latest", name} 32 | 33 | 34 | def ensure_ollama_model(model_name): 35 | if not is_ollama_running(): 36 | service_logger.info("Ollama server is not running. Attempting to start it...") 37 | if not start_ollama(): 38 | service_logger.error("Could not start Ollama server. Exiting.") 39 | return False 40 | for _ in range(5): 41 | if is_ollama_running(): 42 | break 43 | time.sleep(2) 44 | else: 45 | service_logger.error("Ollama server did not start in time.") 46 | return False 47 | 48 | try: 49 | result = subprocess.run(["ollama", "ls"], capture_output=True, text=True, check=True) 50 | except subprocess.CalledProcessError as e: 51 | service_logger.error(f"Error running 'ollama ls': {e}") 52 | return False 53 | 54 | model_lines = [line.split()[0] for line in result.stdout.splitlines() if line and not line.startswith("NAME")] 55 | available_models = set(model_lines) 56 | variants = model_name_variants(model_name) 57 | 58 | if available_models & variants: 59 | service_logger.info(f"Model '{model_name}' already exists in Ollama.") 60 | return True 61 | 62 | service_logger.info(f"Model '{model_name}' not found. 
Pulling...") 63 | try: 64 | subprocess.run(["ollama", "pull", model_name], check=True) 65 | service_logger.info(f"Model '{model_name}' pulled successfully.") 66 | return True 67 | except subprocess.CalledProcessError as e: 68 | service_logger.error(f"Failed to pull model '{model_name}': {e}") 69 | return False 70 | -------------------------------------------------------------------------------- /src/adapters/infrastructure/translation/encode_html_content.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def encode_html(text): 5 | text = text.replace(" ", " ") 6 | 7 | link_map = [] 8 | doc_ref_map = [] 9 | bold_map = [] 10 | italic_map = [] 11 | bold_italic_map = [] 12 | 13 | # Helper to encode doc refs 14 | def doc_ref_replacer(match): 15 | idx = len(doc_ref_map) 16 | doc_ref_map.append(match.group(0)) 17 | return f"[DOCREF{idx}]" 18 | 19 | # Helper to encode bold+italic 20 | def bold_italic_replacer(match): 21 | text_content = match.group(1) 22 | idx = len(bold_italic_map) 23 | bold_italic_map.append(text_content) 24 | return f"[BI{idx}]{text_content}[BI{idx}]" 25 | 26 | # Helper to encode bold 27 | def bold_replacer(match): 28 | text_content = match.group(1) 29 | idx = len(bold_map) 30 | bold_map.append(text_content) 31 | return f"[B{idx}]{text_content}[B{idx}]" 32 | 33 | # 1. Encode document references FIRST - Updated patterns to match your actual format 34 | # Handle patterns like [\[9,](#page-5-9), [10,](#page-5-10), [11\]](#page-5-11), [\[12\]](#page-5-12), [\[13\]](#page-5-13) 35 | text = re.sub(r"(\[\\?\[?\*?\d+[,\.]?\\?\]?\]\(#page-\d+-\d+\))", doc_ref_replacer, text) 36 | 37 | # Also handle the original patterns from markdown version 38 | text = re.sub(r"(\[(?:\\?\d+[,\.]? ?)+\]\(#page-\d+-\d+\))", doc_ref_replacer, text) 39 | text = re.sub(r"(\[\d+[,\.]?\]\(#page-\d+-\d+\))", doc_ref_replacer, text) 40 | text = re.sub(r"(\[\*\d+[,\.]?\]\(#page-\d+-\d+\))", doc_ref_replacer, text) 41 | 42 | # 2. 
Encode links BEFORE formatting - handle complex nested structures 43 | def find_and_replace_links(text): 44 | offset = 0 # Track how much the text has shifted due to replacements 45 | 46 | # Find all ]*>' 48 | matches = list(re.finditer(pattern, text, re.IGNORECASE)) 49 | 50 | # Process from left to right 51 | for match in matches: 52 | start = match.start() + offset 53 | href = match.group(1) 54 | 55 | # Find the matching closing tag 56 | tag_count = 1 57 | pos = match.end() + offset 58 | content_start = pos 59 | 60 | while pos < len(text) and tag_count > 0: 61 | # Look for opening tags 62 | next_open = text.find(" tags 68 | next_close = text.find("", pos) 69 | next_close_case = text.find("", pos) 70 | if next_close_case != -1 and (next_close == -1 or next_close_case < next_close): 71 | next_close = next_close_case 72 | 73 | if next_close == -1: 74 | break 75 | 76 | if next_open != -1 and next_open < next_close: 77 | # Found nested opening tag 78 | tag_count += 1 79 | pos = next_open + 3 80 | else: 81 | # Found closing tag 82 | tag_count -= 1 83 | if tag_count == 0: 84 | # This is our matching closing tag 85 | content = text[content_start:next_close] 86 | end = next_close + 4 # +4 for 87 | 88 | # Replace the entire link 89 | idx = len(link_map) 90 | link_map.append((content, href)) 91 | replacement = f"[LINK{idx}]{content}[LINK{idx}]" 92 | 93 | original_length = end - start 94 | new_length = len(replacement) 95 | 96 | text = text[:start] + replacement + text[end:] 97 | offset += new_length - original_length 98 | break 99 | else: 100 | pos = next_close + 4 101 | 102 | return text 103 | 104 | text = find_and_replace_links(text) 105 | 106 | # 3. Encode bold+italic combinations BEFORE individual bold/italic 107 | # Handle text and text - only simple cases without nested links 108 | text = re.sub(r"([^<]+)", bold_italic_replacer, text, flags=re.IGNORECASE) 109 | text = re.sub(r"([^<]+)", bold_italic_replacer, text, flags=re.IGNORECASE) 110 | 111 | # 4. 
Encode bold (text) - only simple cases without nested tags 112 | text = re.sub(r"([^<]+)", bold_replacer, text, flags=re.IGNORECASE) 113 | 114 | # 5. Encode italic (text) - handle cases that might contain encoded links 115 | def italic_with_links_replacer(match): 116 | content = match.group(1) 117 | idx = len(italic_map) 118 | italic_map.append(content) 119 | return f"[IT{idx}]{content}[IT{idx}]" 120 | 121 | # Handle italic tags that might contain encoded links 122 | text = re.sub( 123 | r"([^<]*(?:\[LINK\d+\][^\[]*\[LINK\d+\][^<]*)*)", italic_with_links_replacer, text, flags=re.IGNORECASE 124 | ) 125 | # Handle simple italic tags 126 | text = re.sub(r"([^<]+)", italic_with_links_replacer, text, flags=re.IGNORECASE) 127 | 128 | return text, link_map, doc_ref_map 129 | -------------------------------------------------------------------------------- /src/adapters/infrastructure/translation/encode_markdown_content.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def encode_markdown(text): 5 | text = text.replace("_ _", " ") 6 | 7 | link_map = [] 8 | doc_ref_map = [] 9 | bold_map = [] 10 | italic_map = [] 11 | bold_italic_map = [] 12 | 13 | # Helper to encode links with nested brackets in label 14 | def link_replacer(match): 15 | label = match.group(1)[1:-1] # Remove outer [] 16 | url = match.group(3) 17 | idx = len(link_map) 18 | link_map.append((label, url)) 19 | return f"[LINK{idx}]{label}[LINK{idx}]" 20 | 21 | # Helper to encode doc refs - sequential numbering 22 | def doc_ref_replacer(match): 23 | idx = len(doc_ref_map) 24 | doc_ref_map.append(match.group(0)) 25 | return f"[DOCREF{idx}]" 26 | 27 | # Helper to encode bold+italic 28 | def bold_italic_replacer(match): 29 | text_content = match.group(1) 30 | idx = len(bold_italic_map) 31 | bold_italic_map.append(text_content) 32 | return f"[BI{idx}]{text_content}[BI{idx}]" 33 | 34 | # Helper to encode bold 35 | def bold_replacer(match): 36 | text_content = 
match.group(1) 37 | idx = len(bold_map) 38 | bold_map.append(text_content) 39 | return f"[B{idx}]{text_content}[B{idx}]" 40 | 41 | # Helper to encode italic 42 | def italic_replacer(match): 43 | text_content = match.group(1) 44 | idx = len(italic_map) 45 | italic_map.append(text_content) 46 | return f"[IT{idx}]{text_content}[IT{idx}]" 47 | 48 | # 1. Encode ALL document references in ONE PASS for sequential numbering 49 | doc_ref_patterns = [ 50 | r"\[\\?\[\d+,\]\(#page-\d+-\d+\)", # [\[9,](#page-5-9) 51 | r"\[\\?\[\d+\\?\]\]\(#page-\d+-\d+\)", # [\[12\]](#page-5-12) 52 | r"\[\*\d+\]\(#page-\d+-\d+\)", # [*1221](#page-3-7) 53 | r"\[\d+,\]\(#page-\d+-\d+\)", # [10,](#page-5-10) 54 | r"\[\d+\\?\]\]\(#page-\d+-\d+\)", # [11\]](#page-5-11) 55 | r"\[\d+\]\(#page-\d+-\d+\)", # [1221](#page-3-7) 56 | ] 57 | 58 | # Combine all patterns with alternation (|) 59 | combined_pattern = "(" + "|".join(doc_ref_patterns) + ")" 60 | text = re.sub(combined_pattern, doc_ref_replacer, text) 61 | 62 | # 2. Encode links BEFORE formatting to avoid matching underscores in URLs 63 | link_pattern = re.compile(r"(\[((?:[^\[\]]+|\[[^\[\]]*\])*)\])\((https?://[^\)]+)\)") 64 | while True: 65 | new_text = link_pattern.sub(lambda m: link_replacer(m), text) 66 | if new_text == text: 67 | break 68 | text = new_text 69 | 70 | # 3. Encode bold+italic BEFORE individual bold/italic 71 | text = re.sub(r"\*\*\_([^\*_]+)\_\*\*", bold_italic_replacer, text) 72 | text = re.sub(r"_\*([^\*_]+)\*_", bold_italic_replacer, text) 73 | 74 | # 4. Encode bold 75 | text = re.sub(r"\*\*([^\*]+)\*\*", bold_replacer, text) 76 | 77 | # 5. 
Encode italic 78 | text = re.sub(r"\_([^\_]+)\_", italic_replacer, text) 79 | 80 | return text, link_map, doc_ref_map 81 | -------------------------------------------------------------------------------- /src/adapters/infrastructure/translation/ollama_container_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import requests 4 | import json 5 | from typing import Optional, Any 6 | from configuration import service_logger 7 | 8 | 9 | class OllamaContainerManager: 10 | 11 | def __init__(self, ollama_host: str = None): 12 | self.ollama_host = ollama_host or os.getenv("OLLAMA_HOST", "http://ollama:11434") 13 | self.api_base_url = f"{self.ollama_host}/api" 14 | self.timeout = 600 15 | self.max_retries = 5 16 | 17 | def is_ollama_available(self) -> bool: 18 | try: 19 | response = requests.get(f"{self.api_base_url}/tags", timeout=10) 20 | return response.status_code == 200 21 | except Exception as e: 22 | service_logger.debug(f"Ollama availability check failed: {e}") 23 | return False 24 | 25 | def ensure_model_available(self, model_name: str) -> bool: 26 | try: 27 | if self._is_model_available(model_name): 28 | service_logger.info(f"\033[92mModel '{model_name}' is available\033[0m") 29 | return True 30 | 31 | service_logger.info(f"\033[93mModel '{model_name}' not found. 
Downloading...\033[0m") 32 | return self._download_model(model_name) 33 | 34 | except Exception as e: 35 | service_logger.error(f"Error ensuring model availability: {e}") 36 | return False 37 | 38 | def _is_model_available(self, model_name: str) -> bool: 39 | try: 40 | response = requests.get(f"{self.api_base_url}/tags", timeout=10) 41 | if response.status_code != 200: 42 | return False 43 | 44 | models_data = response.json() 45 | available_models = [model["name"] for model in models_data.get("models", [])] 46 | 47 | model_variants = {model_name, f"{model_name}:latest", model_name.split(":")[0]} 48 | return any(variant in available_models for variant in model_variants) 49 | 50 | except Exception as e: 51 | service_logger.error(f"Error checking model availability: {e}") 52 | return False 53 | 54 | def _download_model(self, model_name: str) -> bool: 55 | try: 56 | response = requests.post(f"{self.api_base_url}/pull", json={"name": model_name}, stream=True) 57 | 58 | if response.status_code != 200: 59 | service_logger.error(f"Failed to start model download: {response.text}") 60 | return False 61 | 62 | for idx, line in enumerate(response.iter_lines()): 63 | if line: 64 | try: 65 | data = json.loads(line) 66 | if "status" in data and idx % 100 == 0: 67 | service_logger.info(f"Model download: {data['status']}") 68 | if data.get("status") == "success": 69 | service_logger.info(f"Model '{model_name}' downloaded successfully") 70 | return True 71 | except json.JSONDecodeError: 72 | continue 73 | 74 | return True 75 | 76 | except Exception as e: 77 | service_logger.error(f"Error downloading model '{model_name}': {e}") 78 | return False 79 | 80 | def chat_with_timeout( 81 | self, model: str, messages: list[dict], source_markup: str, timeout: Optional[int] = None 82 | ) -> dict[str, Any] | str: 83 | timeout = timeout or self.timeout 84 | 85 | for attempt in range(self.max_retries + 1): 86 | try: 87 | if attempt > 0: 88 | service_logger.info(f"Retrying chat request (attempt 
{attempt + 1}/{self.max_retries + 1})") 89 | time.sleep(10) 90 | 91 | return self._make_chat_request( 92 | model, 93 | messages, 94 | timeout, 95 | ) 96 | 97 | except requests.exceptions.Timeout: 98 | service_logger.warning(f"Chat request timed out after {timeout} seconds (attempt {attempt})") 99 | if attempt < self.max_retries: 100 | continue 101 | else: 102 | service_logger.error(f"Chat request failed after {self.max_retries} attempts due to timeout") 103 | return source_markup 104 | 105 | except Exception as e: 106 | service_logger.error(f"Chat request failed (attempt {attempt}): {e}") 107 | if attempt < self.max_retries: 108 | continue 109 | else: 110 | service_logger.error(f"Chat request failed after {self.max_retries} attempts") 111 | return source_markup 112 | 113 | return source_markup 114 | 115 | def _make_chat_request(self, model: str, messages: list, timeout: int) -> dict[str, Any]: 116 | payload = {"model": model, "messages": messages, "stream": False} 117 | 118 | response = requests.post(f"{self.api_base_url}/chat", json=payload, timeout=timeout) 119 | 120 | if response.status_code != 200: 121 | raise Exception(f"Chat request failed with status {response.status_code}: {response.text}") 122 | 123 | return response.json() 124 | 125 | def ensure_service_ready(self, model_name: str) -> bool: 126 | try: 127 | if not self.is_ollama_available(): 128 | service_logger.error("Ollama service is not available. 
Make sure the Ollama container is running.") 129 | return False 130 | 131 | return self.ensure_model_available(model_name) 132 | 133 | except Exception as e: 134 | service_logger.error(f"Error ensuring service readiness: {e}") 135 | return False 136 | -------------------------------------------------------------------------------- /src/adapters/infrastructure/visualization_service_adapter.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from os import makedirs 3 | from os.path import join 4 | from pdf_annotate import PdfAnnotator, Location, Appearance 5 | from starlette.responses import FileResponse 6 | from ports.services.visualization_service import VisualizationService 7 | from configuration import ROOT_PATH 8 | 9 | DOCLAYNET_COLOR_BY_TYPE = { 10 | "Caption": "#FFC300", 11 | "Footnote": "#581845", 12 | "Formula": "#FF5733", 13 | "List item": "#008B8B", 14 | "Page footer": "#FF5733", 15 | "Page header": "#581845", 16 | "Picture": "#C70039", 17 | "Section header": "#C70039", 18 | "Table": "#FF8C00", 19 | "Text": "#606060", 20 | "Title": "#EED400", 21 | } 22 | 23 | 24 | class VisualizationServiceAdapter(VisualizationService): 25 | def create_pdf_visualization(self, pdf_path: Path, segment_boxes: list[dict]) -> Path: 26 | pdf_outputs_path = join(ROOT_PATH, "pdf_outputs") 27 | makedirs(pdf_outputs_path, exist_ok=True) 28 | annotator = PdfAnnotator(str(pdf_path)) 29 | segment_index = 0 30 | current_page = 1 31 | 32 | for segment_box in segment_boxes: 33 | if int(segment_box["page_number"]) != current_page: 34 | segment_index = 0 35 | current_page += 1 36 | page_height = int(segment_box["page_height"]) 37 | self._add_prediction_annotation(annotator, segment_box, segment_index, page_height) 38 | segment_index += 1 39 | 40 | annotator.write(str(pdf_path)) 41 | return pdf_path 42 | 43 | def get_visualization_response(self, pdf_path: Path) -> FileResponse: 44 | return FileResponse(path=pdf_path, 
media_type="application/pdf", filename=pdf_path.name) 45 | 46 | def _hex_color_to_rgb(self, color: str) -> tuple: 47 | r, g, b = int(color[1:3], 16), int(color[3:5], 16), int(color[5:7], 16) 48 | alpha = 1 49 | return r / 255, g / 255, b / 255, alpha 50 | 51 | def _add_prediction_annotation( 52 | self, annotator: PdfAnnotator, segment_box: dict, segment_index: int, page_height: int 53 | ) -> None: 54 | predicted_type = segment_box["type"] 55 | color = DOCLAYNET_COLOR_BY_TYPE[predicted_type] 56 | left, top, right, bottom = ( 57 | segment_box["left"], 58 | page_height - segment_box["top"], 59 | segment_box["left"] + segment_box["width"], 60 | page_height - (segment_box["top"] + segment_box["height"]), 61 | ) 62 | text_box_size = len(predicted_type) * 8 + 8 63 | 64 | annotator.add_annotation( 65 | "square", 66 | Location(x1=left, y1=bottom, x2=right, y2=top, page=int(segment_box["page_number"]) - 1), 67 | Appearance(stroke_color=self._hex_color_to_rgb(color)), 68 | ) 69 | 70 | annotator.add_annotation( 71 | "square", 72 | Location(x1=left, y1=top, x2=left + text_box_size, y2=top + 10, page=int(segment_box["page_number"]) - 1), 73 | Appearance(fill=self._hex_color_to_rgb(color)), 74 | ) 75 | 76 | content = predicted_type.capitalize() + f" [{str(segment_index+1)}]" 77 | annotator.add_annotation( 78 | "text", 79 | Location(x1=left, y1=top, x2=left + text_box_size, y2=top + 10, page=int(segment_box["page_number"]) - 1), 80 | Appearance(content=content, font_size=8, fill=(1, 1, 1), stroke_width=3), 81 | ) 82 | -------------------------------------------------------------------------------- /src/adapters/ml/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/ml/__init__.py -------------------------------------------------------------------------------- /src/adapters/ml/fast_trainer/Paragraph.py: 
-------------------------------------------------------------------------------- 1 | from pdf_features import PdfToken 2 | 3 | 4 | class Paragraph: 5 | def __init__(self, tokens: list[PdfToken], pdf_name: str = ""): 6 | self.tokens = tokens 7 | self.pdf_name = pdf_name 8 | 9 | def add_token(self, token: PdfToken): 10 | self.tokens.append(token) 11 | -------------------------------------------------------------------------------- /src/adapters/ml/fast_trainer/ParagraphExtractorTrainer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from adapters.ml.fast_trainer.Paragraph import Paragraph 4 | from domain.PdfSegment import PdfSegment 5 | from pdf_features import PdfToken 6 | from pdf_token_type_labels import TokenType 7 | from adapters.ml.pdf_tokens_type_trainer.TokenFeatures import TokenFeatures 8 | from adapters.ml.pdf_tokens_type_trainer.TokenTypeTrainer import TokenTypeTrainer 9 | 10 | 11 | class ParagraphExtractorTrainer(TokenTypeTrainer): 12 | def get_context_features(self, token_features: TokenFeatures, page_tokens: list[PdfToken], token_index: int): 13 | token_row_features = list() 14 | first_token_from_context = token_index - self.model_configuration.context_size 15 | for i in range(self.model_configuration.context_size * 2): 16 | first_token = page_tokens[first_token_from_context + i] 17 | second_token = page_tokens[first_token_from_context + i + 1] 18 | features = token_features.get_features(first_token, second_token, page_tokens) 19 | features += self.get_paragraph_extraction_features(first_token, second_token) 20 | token_row_features.extend(features) 21 | 22 | return token_row_features 23 | 24 | @staticmethod 25 | def get_paragraph_extraction_features(first_token: PdfToken, second_token: PdfToken) -> list[int]: 26 | one_hot_token_type_1 = [1 if token_type == first_token.token_type else 0 for token_type in TokenType] 27 | one_hot_token_type_2 = [1 if token_type == 
second_token.token_type else 0 for token_type in TokenType] 28 | return one_hot_token_type_1 + one_hot_token_type_2 29 | 30 | def loop_token_next_token(self): 31 | for pdf_features in self.pdfs_features: 32 | for page in pdf_features.pages: 33 | if not page.tokens: 34 | continue 35 | if len(page.tokens) == 1: 36 | yield page, page.tokens[0], page.tokens[0] 37 | for token, next_token in zip(page.tokens, page.tokens[1:]): 38 | yield page, token, next_token 39 | 40 | def get_pdf_segments(self, paragraph_extractor_model_path: str | Path) -> list[PdfSegment]: 41 | paragraphs = self.get_paragraphs(paragraph_extractor_model_path) 42 | pdf_segments = [PdfSegment.from_pdf_tokens(paragraph.tokens, paragraph.pdf_name) for paragraph in paragraphs] 43 | 44 | return pdf_segments 45 | 46 | def get_paragraphs(self, paragraph_extractor_model_path) -> list[Paragraph]: 47 | self.predict(paragraph_extractor_model_path) 48 | paragraphs: list[Paragraph] = [] 49 | last_page = None 50 | for page, token, next_token in self.loop_token_next_token(): 51 | if last_page != page: 52 | last_page = page 53 | paragraphs.append(Paragraph([token], page.pdf_name)) 54 | if token == next_token: 55 | continue 56 | if token.prediction: 57 | paragraphs[-1].add_token(next_token) 58 | continue 59 | paragraphs.append(Paragraph([next_token], page.pdf_name)) 60 | 61 | return paragraphs 62 | -------------------------------------------------------------------------------- /src/adapters/ml/fast_trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/ml/fast_trainer/__init__.py -------------------------------------------------------------------------------- /src/adapters/ml/fast_trainer/model_configuration.py: -------------------------------------------------------------------------------- 1 | from 
adapters.ml.pdf_tokens_type_trainer.ModelConfiguration import ModelConfiguration 2 | 3 | config_json = { 4 | "boosting_type": "gbdt", 5 | "verbose": -1, 6 | "learning_rate": 0.1, 7 | "num_class": 2, 8 | "context_size": 1, 9 | "num_boost_round": 400, 10 | "num_leaves": 191, 11 | "bagging_fraction": 0.9166599392739231, 12 | "bagging_freq": 7, 13 | "feature_fraction": 0.3116707710163228, 14 | "lambda_l1": 0.0006901861637621734, 15 | "lambda_l2": 1.1886914989632197e-05, 16 | "min_data_in_leaf": 50, 17 | "feature_pre_filter": True, 18 | "seed": 22, 19 | "deterministic": True, 20 | } 21 | 22 | MODEL_CONFIGURATION = ModelConfiguration(**config_json) 23 | 24 | if __name__ == "__main__": 25 | print(MODEL_CONFIGURATION) 26 | -------------------------------------------------------------------------------- /src/adapters/ml/fast_trainer_adapter.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from domain.PdfImages import PdfImages 3 | from domain.PdfSegment import PdfSegment 4 | from ports.services.ml_model_service import MLModelService 5 | from adapters.ml.fast_trainer.ParagraphExtractorTrainer import ParagraphExtractorTrainer 6 | from adapters.ml.fast_trainer.model_configuration import MODEL_CONFIGURATION as PARAGRAPH_EXTRACTION_CONFIGURATION 7 | from adapters.ml.pdf_tokens_type_trainer.TokenTypeTrainer import TokenTypeTrainer 8 | from adapters.ml.pdf_tokens_type_trainer.ModelConfiguration import ModelConfiguration 9 | from configuration import ROOT_PATH, service_logger 10 | 11 | 12 | class FastTrainerAdapter(MLModelService): 13 | def predict_document_layout(self, pdf_images: list[PdfImages]) -> list[PdfSegment]: 14 | return self.predict_layout_fast(pdf_images) 15 | 16 | def predict_layout_fast(self, pdf_images: list[PdfImages]) -> list[PdfSegment]: 17 | service_logger.info("Creating Paragraph Tokens [fast]") 18 | 19 | pdf_images_obj = pdf_images[0] 20 | 21 | token_type_trainer = 
TokenTypeTrainer([pdf_images_obj.pdf_features], ModelConfiguration()) 22 | token_type_trainer.set_token_types(join(ROOT_PATH, "models", "token_type_lightgbm.model")) 23 | 24 | trainer = ParagraphExtractorTrainer( 25 | pdfs_features=[pdf_images_obj.pdf_features], model_configuration=PARAGRAPH_EXTRACTION_CONFIGURATION 26 | ) 27 | segments = trainer.get_pdf_segments(join(ROOT_PATH, "models", "paragraph_extraction_lightgbm.model")) 28 | 29 | pdf_images_obj.remove_images() 30 | 31 | return segments 32 | -------------------------------------------------------------------------------- /src/adapters/ml/pdf_tokens_type_trainer/ModelConfiguration.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, asdict 2 | 3 | from pdf_token_type_labels import TokenType 4 | 5 | 6 | @dataclass 7 | class ModelConfiguration: 8 | context_size: int = 4 9 | num_boost_round: int = 700 10 | num_leaves: int = 127 11 | bagging_fraction: float = 0.6810645192499981 12 | lambda_l1: float = 1.1533558410486358e-08 13 | lambda_l2: float = 4.91211684620458 14 | feature_fraction: float = 0.7087268965467017 15 | bagging_freq: int = 10 16 | min_data_in_leaf: int = 47 17 | feature_pre_filter: bool = False 18 | boosting_type: str = "gbdt" 19 | objective: str = "multiclass" 20 | metric: str = "multi_logloss" 21 | learning_rate: float = 0.1 22 | seed: int = 22 23 | num_class: int = len(TokenType) 24 | verbose: int = -1 25 | deterministic: bool = True 26 | resume_training: bool = False 27 | early_stopping_rounds: int = None 28 | 29 | def dict(self): 30 | return asdict(self) 31 | -------------------------------------------------------------------------------- /src/adapters/ml/pdf_tokens_type_trainer/PdfTrainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import exists, join 3 | from pathlib import Path 4 | 5 | import lightgbm as lgb 6 | import numpy as np 7 | 8 | from 
pdf_features import PdfFeatures, PdfTokenStyle 9 | from pdf_features import PdfFont 10 | from pdf_features import PdfToken 11 | from pdf_features import Rectangle 12 | from pdf_token_type_labels import TokenType 13 | from adapters.ml.pdf_tokens_type_trainer.ModelConfiguration import ModelConfiguration 14 | from adapters.ml.pdf_tokens_type_trainer.download_models import pdf_tokens_type_model 15 | 16 | 17 | class PdfTrainer: 18 | def __init__(self, pdfs_features: list[PdfFeatures], model_configuration: ModelConfiguration = None): 19 | self.pdfs_features = pdfs_features 20 | self.model_configuration = model_configuration if model_configuration else ModelConfiguration() 21 | 22 | def get_model_input(self) -> np.ndarray: 23 | pass 24 | 25 | @staticmethod 26 | def features_rows_to_x(features_rows): 27 | if not features_rows: 28 | return np.zeros((0, 0)) 29 | 30 | x = np.zeros(((len(features_rows)), len(features_rows[0]))) 31 | for i, v in enumerate(features_rows): 32 | x[i] = v 33 | return x 34 | 35 | def train(self, model_path: str | Path, labels: list[int]): 36 | print("Getting model input") 37 | x_train = self.get_model_input() 38 | 39 | if not x_train.any(): 40 | print("No data for training") 41 | return 42 | 43 | lgb_train = lgb.Dataset(x_train, labels) 44 | lgb_eval = lgb.Dataset(x_train, labels, reference=lgb_train) 45 | print("Training") 46 | 47 | if self.model_configuration.resume_training and exists(model_path): 48 | model = lgb.Booster(model_file=model_path) 49 | gbm = model.refit(x_train, labels) 50 | else: 51 | gbm = lgb.train(params=self.model_configuration.dict(), train_set=lgb_train, valid_sets=[lgb_eval]) 52 | 53 | print("Saving") 54 | gbm.save_model(model_path, num_iteration=gbm.best_iteration) 55 | 56 | def loop_tokens(self): 57 | for pdf_features in self.pdfs_features: 58 | for page, token in pdf_features.loop_tokens(): 59 | yield token 60 | 61 | @staticmethod 62 | def get_padding_token(segment_number: int, page_number: int): 63 | return PdfToken( 64 
| page_number=page_number, 65 | id="pad_token", 66 | content="", 67 | font=PdfFont(font_id="pad_font_id", font_size=0, bold=False, italics=False, color="black"), 68 | reading_order_no=segment_number, 69 | bounding_box=Rectangle.from_coordinates(0, 0, 0, 0), 70 | token_type=TokenType.TEXT, 71 | token_style=PdfTokenStyle( 72 | font=PdfFont(font_id="pad_font_id", font_size=0, bold=False, italics=False, color="black") 73 | ), 74 | ) 75 | 76 | def predict(self, model_path: str | Path = None): 77 | model_path = model_path if model_path else pdf_tokens_type_model 78 | x = self.get_model_input() 79 | 80 | if not x.any(): 81 | return self.pdfs_features 82 | 83 | lightgbm_model = lgb.Booster(model_file=model_path) 84 | return lightgbm_model.predict(x) 85 | 86 | def save_training_data(self, save_folder_path: str | Path, labels: list[int]): 87 | os.makedirs(save_folder_path, exist_ok=True) 88 | 89 | x = self.get_model_input() 90 | 91 | np.save(join(str(save_folder_path), "x.npy"), x) 92 | np.save(join(str(save_folder_path), "y.npy"), labels) 93 | -------------------------------------------------------------------------------- /src/adapters/ml/pdf_tokens_type_trainer/TokenFeatures.py: -------------------------------------------------------------------------------- 1 | import string 2 | import unicodedata 3 | 4 | from pdf_features import PdfFeatures 5 | from pdf_features import PdfToken 6 | from adapters.ml.pdf_tokens_type_trainer.config import CHARACTER_TYPE 7 | 8 | 9 | class TokenFeatures: 10 | def __init__(self, pdfs_features: PdfFeatures): 11 | self.pdfs_features = pdfs_features 12 | 13 | def get_features(self, token_1: PdfToken, token_2: PdfToken, page_tokens: list[PdfToken]): 14 | same_font = True if token_1.font.font_id == token_2.font.font_id else False 15 | 16 | return ( 17 | [ 18 | same_font, 19 | self.pdfs_features.pdf_modes.font_size_mode / 100, 20 | len(token_1.content), 21 | len(token_2.content), 22 | token_1.content.count(" "), 23 | token_2.content.count(" "), 24 
| sum(character in string.punctuation for character in token_1.content), 25 | sum(character in string.punctuation for character in token_2.content), 26 | ] 27 | + self.get_position_features(token_1, token_2, page_tokens) 28 | + self.get_unicode_categories(token_1) 29 | + self.get_unicode_categories(token_2) 30 | ) 31 | 32 | def get_position_features(self, token_1: PdfToken, token_2: PdfToken, page_tokens): 33 | left_1 = token_1.bounding_box.left 34 | right_1 = token_1.bounding_box.right 35 | height_1 = token_1.bounding_box.height 36 | width_1 = token_1.bounding_box.width 37 | 38 | left_2 = token_2.bounding_box.left 39 | right_2 = token_2.bounding_box.right 40 | height_2 = token_2.bounding_box.height 41 | width_2 = token_2.bounding_box.width 42 | 43 | right_gap_1, left_gap_2 = ( 44 | token_1.pdf_token_context.left_of_token_on_the_right - right_1, 45 | left_2 - token_2.pdf_token_context.right_of_token_on_the_left, 46 | ) 47 | 48 | absolute_right_1 = max(right_1, token_1.pdf_token_context.right_of_token_on_the_right) 49 | absolute_right_2 = max(right_2, token_2.pdf_token_context.right_of_token_on_the_right) 50 | 51 | absolute_left_1 = min(left_1, token_1.pdf_token_context.left_of_token_on_the_left) 52 | absolute_left_2 = min(left_2, token_2.pdf_token_context.left_of_token_on_the_left) 53 | 54 | right_distance, left_distance, height_difference = left_2 - left_1 - width_1, left_1 - left_2, height_1 - height_2 55 | 56 | top_distance = token_2.bounding_box.top - token_1.bounding_box.top - height_1 57 | top_distance_gaps = self.get_top_distance_gap(token_1, token_2, page_tokens) 58 | 59 | start_lines_differences = absolute_left_1 - absolute_left_2 60 | end_lines_difference = abs(absolute_right_1 - absolute_right_2) 61 | 62 | return [ 63 | absolute_right_1, 64 | token_1.bounding_box.top, 65 | right_1, 66 | width_1, 67 | height_1, 68 | token_2.bounding_box.top, 69 | right_2, 70 | width_2, 71 | height_2, 72 | right_distance, 73 | left_distance, 74 | right_gap_1, 75 | 
left_gap_2, 76 | height_difference, 77 | top_distance, 78 | top_distance - self.pdfs_features.pdf_modes.lines_space_mode, 79 | top_distance_gaps, 80 | top_distance - height_1, 81 | end_lines_difference, 82 | start_lines_differences, 83 | self.pdfs_features.pdf_modes.lines_space_mode - top_distance_gaps, 84 | self.pdfs_features.pdf_modes.right_space_mode - absolute_right_1, 85 | ] 86 | 87 | @staticmethod 88 | def get_top_distance_gap(token_1: PdfToken, token_2: PdfToken, page_tokens): 89 | top_distance = token_2.bounding_box.top - token_1.bounding_box.top - token_1.bounding_box.height 90 | tokens_in_the_middle = [ 91 | token 92 | for token in page_tokens 93 | if token_1.bounding_box.bottom <= token.bounding_box.top < token_2.bounding_box.top 94 | ] 95 | 96 | gap_middle_bottom = 0 97 | gap_middle_top = 0 98 | 99 | if tokens_in_the_middle: 100 | tokens_in_the_middle_top = min([token.bounding_box.top for token in tokens_in_the_middle]) 101 | tokens_in_the_middle_bottom = max([token.bounding_box.bottom for token in tokens_in_the_middle]) 102 | gap_middle_top = tokens_in_the_middle_top - token_1.bounding_box.top - token_1.bounding_box.height 103 | gap_middle_bottom = token_2.bounding_box.top - tokens_in_the_middle_bottom 104 | 105 | top_distance_gaps = top_distance - (gap_middle_bottom - gap_middle_top) 106 | return top_distance_gaps 107 | 108 | @staticmethod 109 | def get_unicode_categories(token: PdfToken): 110 | if token.id == "pad_token": 111 | return [-1] * len(CHARACTER_TYPE) * 4 112 | 113 | categories = [unicodedata.category(letter) for letter in token.content[:2] + token.content[-2:]] 114 | categories += ["no_category"] * (4 - len(categories)) 115 | 116 | categories_one_hot_encoding = list() 117 | 118 | for category in categories: 119 | categories_one_hot_encoding.extend([0] * len(CHARACTER_TYPE)) 120 | if category not in CHARACTER_TYPE: 121 | continue 122 | 123 | category_index = len(categories_one_hot_encoding) - len(CHARACTER_TYPE) + 
CHARACTER_TYPE.index(category) 124 | categories_one_hot_encoding[category_index] = 1 125 | 126 | return categories_one_hot_encoding 127 | -------------------------------------------------------------------------------- /src/adapters/ml/pdf_tokens_type_trainer/TokenTypeTrainer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | from pdf_features import PdfToken 7 | from pdf_token_type_labels import TokenType 8 | from adapters.ml.pdf_tokens_type_trainer.PdfTrainer import PdfTrainer 9 | from adapters.ml.pdf_tokens_type_trainer.TokenFeatures import TokenFeatures 10 | 11 | 12 | class TokenTypeTrainer(PdfTrainer): 13 | def get_model_input(self) -> np.ndarray: 14 | features_rows = [] 15 | 16 | contex_size = self.model_configuration.context_size 17 | for token_features, page in self.loop_token_features(): 18 | page_tokens = [ 19 | self.get_padding_token(segment_number=i - 999999, page_number=page.page_number) for i in range(contex_size) 20 | ] 21 | page_tokens += page.tokens 22 | page_tokens += [ 23 | self.get_padding_token(segment_number=999999 + i, page_number=page.page_number) for i in range(contex_size) 24 | ] 25 | 26 | tokens_indexes = range(contex_size, len(page_tokens) - contex_size) 27 | page_features = [self.get_context_features(token_features, page_tokens, i) for i in tokens_indexes] 28 | features_rows.extend(page_features) 29 | 30 | return self.features_rows_to_x(features_rows) 31 | 32 | def loop_token_features(self): 33 | for pdf_features in tqdm(self.pdfs_features): 34 | token_features = TokenFeatures(pdf_features) 35 | 36 | for page in pdf_features.pages: 37 | if not page.tokens: 38 | continue 39 | 40 | yield token_features, page 41 | 42 | def get_context_features(self, token_features: TokenFeatures, page_tokens: list[PdfToken], token_index: int): 43 | token_row_features = [] 44 | first_token_from_context = token_index - 
self.model_configuration.context_size 45 | for i in range(self.model_configuration.context_size * 2): 46 | first_token = page_tokens[first_token_from_context + i] 47 | second_token = page_tokens[first_token_from_context + i + 1] 48 | token_row_features.extend(token_features.get_features(first_token, second_token, page_tokens)) 49 | 50 | return token_row_features 51 | 52 | def predict(self, model_path: str | Path = None): 53 | predictions = super().predict(model_path) 54 | predictions_assigned = 0 55 | for token_features, page in self.loop_token_features(): 56 | for token, prediction in zip( 57 | page.tokens, predictions[predictions_assigned : predictions_assigned + len(page.tokens)] 58 | ): 59 | token.prediction = int(np.argmax(prediction)) 60 | 61 | predictions_assigned += len(page.tokens) 62 | 63 | def set_token_types(self, model_path: str | Path = None): 64 | self.predict(model_path) 65 | for token in self.loop_tokens(): 66 | token.token_type = TokenType.from_index(token.prediction) 67 | -------------------------------------------------------------------------------- /src/adapters/ml/pdf_tokens_type_trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/ml/pdf_tokens_type_trainer/__init__.py -------------------------------------------------------------------------------- /src/adapters/ml/pdf_tokens_type_trainer/config.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from pathlib import Path 3 | 4 | ROOT_PATH = Path(__file__).parent.parent.parent.parent.parent.absolute() 5 | PDF_LABELED_DATA_ROOT_PATH = Path(join(ROOT_PATH.parent.absolute(), "pdf-labeled-data")) 6 | TOKEN_TYPE_LABEL_PATH = Path(join(PDF_LABELED_DATA_ROOT_PATH, "labeled_data", "token_type")) 7 | 8 | TRAINED_MODEL_PATH = join(ROOT_PATH, "model", 
"pdf_tokens_type.model") 9 | TOKEN_TYPE_RELATIVE_PATH = join("labeled_data", "token_type") 10 | MISTAKES_RELATIVE_PATH = join("labeled_data", "task_mistakes") 11 | 12 | XML_NAME = "etree.xml" 13 | LABELS_FILE_NAME = "labels.json" 14 | STATUS_FILE_NAME = "status.txt" 15 | 16 | CHARACTER_TYPE = [ 17 | "Lt", 18 | "Lo", 19 | "Sk", 20 | "Lm", 21 | "Sm", 22 | "Cf", 23 | "Nl", 24 | "Pe", 25 | "Po", 26 | "Pd", 27 | "Me", 28 | "Sc", 29 | "Ll", 30 | "Pf", 31 | "Mc", 32 | "Lu", 33 | "Zs", 34 | "Cn", 35 | "Cc", 36 | "No", 37 | "Co", 38 | "Ps", 39 | "Nd", 40 | "Mn", 41 | "Pi", 42 | "So", 43 | "Pc", 44 | ] 45 | 46 | if __name__ == "__main__": 47 | print(ROOT_PATH) 48 | print(PDF_LABELED_DATA_ROOT_PATH) 49 | -------------------------------------------------------------------------------- /src/adapters/ml/pdf_tokens_type_trainer/download_models.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import hf_hub_download 2 | 3 | pdf_tokens_type_model = hf_hub_download( 4 | repo_id="HURIDOCS/pdf-segmentation", 5 | filename="pdf_tokens_type.model", 6 | revision="c71f833500707201db9f3649a6d2010d3ce9d4c9", 7 | ) 8 | 9 | token_type_finding_config_path = hf_hub_download( 10 | repo_id="HURIDOCS/pdf-segmentation", 11 | filename="tag_type_finding_model_config.txt", 12 | revision="7d98776dd34acb2fe3a06495c82e64b9c84bdc16", 13 | ) 14 | -------------------------------------------------------------------------------- /src/adapters/ml/pdf_tokens_type_trainer/get_paths.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from pathlib import Path 3 | 4 | 5 | def get_xml_path(pdf_labeled_data_project_path: str): 6 | return Path(join(pdf_labeled_data_project_path, "pdfs")) 7 | -------------------------------------------------------------------------------- /src/adapters/ml/pdf_tokens_type_trainer/tests/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/ml/pdf_tokens_type_trainer/tests/__init__.py -------------------------------------------------------------------------------- /src/adapters/ml/pdf_tokens_type_trainer/tests/test_trainer.py: -------------------------------------------------------------------------------- 1 | from os.path import join, exists 2 | from unittest import TestCase 3 | 4 | from pdf_token_type_labels import TokenType 5 | from adapters.ml.pdf_tokens_type_trainer.TokenTypeTrainer import TokenTypeTrainer 6 | 7 | from pdf_features import PdfFeatures 8 | 9 | from configuration import ROOT_PATH 10 | 11 | 12 | class TestTrainer(TestCase): 13 | def test_train_blank_pdf(self): 14 | pdf_features = PdfFeatures.from_pdf_path(join(ROOT_PATH, "test_pdfs", "blank.pdf")) 15 | model_path = join(ROOT_PATH, "model", "blank.model") 16 | trainer = TokenTypeTrainer([pdf_features]) 17 | trainer.train(model_path, []) 18 | self.assertFalse(exists(model_path)) 19 | 20 | def test_predict_blank_pdf(self): 21 | pdf_features = PdfFeatures.from_pdf_path(join(ROOT_PATH, "test_pdfs", "blank.pdf")) 22 | trainer = TokenTypeTrainer([pdf_features]) 23 | trainer.set_token_types() 24 | self.assertEqual([], pdf_features.pages[0].tokens) 25 | 26 | def test_predict(self): 27 | pdf_features = PdfFeatures.from_pdf_path(join(ROOT_PATH, "test_pdfs", "test.pdf")) 28 | trainer = TokenTypeTrainer([pdf_features]) 29 | trainer.set_token_types() 30 | tokens = pdf_features.pages[0].tokens 31 | self.assertEqual(TokenType.TITLE, tokens[0].token_type) 32 | self.assertEqual("Document Big Centered Title", tokens[0].content) 33 | self.assertEqual(TokenType.TEXT, tokens[1].token_type) 34 | self.assertEqual("List Title", tokens[10].content) 35 | -------------------------------------------------------------------------------- 
/src/adapters/ml/vgt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/ml/vgt/__init__.py -------------------------------------------------------------------------------- /src/adapters/ml/vgt/bros/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # Copyright 2021 NAVER CLOVA Team. All rights reserved. 6 | # Copyright 2020 The HuggingFace Team. All rights reserved. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
19 | 20 | from typing import TYPE_CHECKING 21 | 22 | from transformers.file_utils import ( 23 | _LazyModule, 24 | is_tokenizers_available, 25 | is_torch_available, 26 | ) 27 | 28 | _import_structure = { 29 | "configuration_bros": ["BROS_PRETRAINED_CONFIG_ARCHIVE_MAP", "BrosConfig"], 30 | "tokenization_bros": ["BrosTokenizer"], 31 | } 32 | 33 | if is_tokenizers_available(): 34 | _import_structure["tokenization_bros_fast"] = ["BrosTokenizerFast"] 35 | 36 | if is_torch_available(): 37 | _import_structure["modeling_bros"] = [ 38 | "BROS_PRETRAINED_MODEL_ARCHIVE_LIST", 39 | "BrosForMaskedLM", 40 | "BrosForPreTraining", 41 | "BrosForSequenceClassification", 42 | "BrosForTokenClassification", 43 | "BrosModel", 44 | "BrosLMHeadModel", 45 | "BrosPreTrainedModel", 46 | ] 47 | 48 | if TYPE_CHECKING: 49 | from .configuration_bros import BROS_PRETRAINED_CONFIG_ARCHIVE_MAP, BrosConfig 50 | from .tokenization_bros import BrosTokenizer 51 | 52 | if is_tokenizers_available(): 53 | from .tokenization_bros_fast import BrosTokenizerFast 54 | 55 | if is_torch_available(): 56 | from .modeling_bros import ( 57 | BROS_PRETRAINED_MODEL_ARCHIVE_LIST, 58 | BrosForMaskedLM, 59 | BrosForPreTraining, 60 | BrosForSequenceClassification, 61 | BrosForTokenClassification, 62 | BrosLMHeadModel, 63 | BrosModel, 64 | BrosPreTrainedModel, 65 | ) 66 | 67 | else: 68 | import sys 69 | 70 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) 71 | -------------------------------------------------------------------------------- /src/adapters/ml/vgt/bros/tokenization_bros.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022-present NAVER Corp, The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for BROS.""" 16 | 17 | 18 | import collections 19 | 20 | from transformers.models.bert.tokenization_bert import BertTokenizer 21 | from transformers.utils import logging 22 | 23 | logger = logging.get_logger(__name__) 24 | 25 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 26 | 27 | PRETRAINED_VOCAB_FILES_MAP = { 28 | "vocab_file": { 29 | "naver-clova-ocr/bros-base-uncased": "https://huggingface.co/naver-clova-ocr/bros-base-uncased/resolve/main/vocab.txt", 30 | "naver-clova-ocr/bros-large-uncased": "https://huggingface.co/naver-clova-ocr/bros-large-uncased/resolve/main/vocab.txt", 31 | } 32 | } 33 | 34 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 35 | "naver-clova-ocr/bros-base-uncased": 512, 36 | "naver-clova-ocr/bros-large-uncased": 512, 37 | } 38 | 39 | PRETRAINED_INIT_CONFIGURATION = { 40 | "naver-clova-ocr/bros-base-uncased": {"do_lower_case": True}, 41 | "naver-clova-ocr/bros-large-uncased": {"do_lower_case": True}, 42 | } 43 | 44 | 45 | def load_vocab(vocab_file): 46 | """Loads a vocabulary file into a dictionary.""" 47 | vocab = collections.OrderedDict() 48 | with open(vocab_file, "r", encoding="utf-8") as reader: 49 | tokens = reader.readlines() 50 | for index, token in enumerate(tokens): 51 | token = token.rstrip("\n") 52 | vocab[token] = index 53 | return vocab 54 | 55 | 56 | def whitespace_tokenize(text): 57 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 58 | text = text.strip() 59 | if not text: 60 | return [] 61 | tokens = text.split() 62 | return tokens 
63 | 64 | 65 | class BrosTokenizer(BertTokenizer): 66 | r""" 67 | Construct a BERT tokenizer. Based on WordPiece. 68 | 69 | This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods. 70 | Users should refer to this superclass for more information regarding those methods. 71 | 72 | Args: 73 | vocab_file (:obj:`str`): 74 | File containing the vocabulary. 75 | do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): 76 | Whether or not to lowercase the input when tokenizing. 77 | do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`): 78 | Whether or not to do basic tokenization before WordPiece. 79 | never_split (:obj:`Iterable`, `optional`): 80 | Collection of tokens which will never be split during tokenization. Only has an effect when 81 | :obj:`do_basic_tokenize=True` 82 | unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`): 83 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this 84 | token instead. 85 | sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`): 86 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for 87 | sequence classification or for a text and a question for question answering. It is also used as the last 88 | token of a sequence built with special tokens. 89 | pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`): 90 | The token used for padding, for example when batching sequences of different lengths. 91 | cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`): 92 | The classifier token which is used when doing sequence classification (classification of the whole sequence 93 | instead of per-token classification). It is the first token of the sequence when built with special tokens. 94 | mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`): 95 | The token used for masking values. 
This is the token used when training this model with masked language 96 | modeling. This is the token which the model will try to predict. 97 | tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): 98 | Whether or not to tokenize Chinese characters. 99 | 100 | This should likely be deactivated for Japanese (see this `issue 101 | `__). 102 | strip_accents: (:obj:`bool`, `optional`): 103 | Whether or not to strip all accents. If this option is not specified, then it will be determined by the 104 | value for :obj:`lowercase` (as in the original BERT). 105 | """ 106 | 107 | vocab_files_names = VOCAB_FILES_NAMES 108 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 109 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 110 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 111 | -------------------------------------------------------------------------------- /src/adapters/ml/vgt/bros/tokenization_bros_fast.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022-present NAVER Corp, The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team.. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Fast Tokenization classes for BROS.""" 16 | 17 | from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast 18 | from transformers.utils import logging 19 | 20 | from .tokenization_bros import BrosTokenizer 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"} 25 | 26 | PRETRAINED_VOCAB_FILES_MAP = { 27 | "vocab_file": { 28 | "naver-clova-ocr/bros-base-uncased": "https://huggingface.co/naver-clova-ocr/bros-base-uncased/resolve/main/vocab.txt", 29 | "naver-clova-ocr/bros-large-uncased": "https://huggingface.co/naver-clova-ocr/bros-large-uncased/resolve/main/vocab.txt", 30 | }, 31 | "tokenizer_file": { 32 | "naver-clova-ocr/bros-base-uncased": "https://huggingface.co/naver-clova-ocr/bros-base-uncased/resolve/main/tokenizer.json", 33 | "naver-clova-ocr/bros-large-uncased": "https://huggingface.co/naver-clova-ocr/bros-large-uncased/resolve/main/tokenizer.json", 34 | }, 35 | } 36 | 37 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 38 | "naver-clova-ocr/bros-base-uncased": 512, 39 | "naver-clova-ocr/bros-large-uncased": 512, 40 | } 41 | 42 | PRETRAINED_INIT_CONFIGURATION = { 43 | "naver-clova-ocr/bros-base-uncased": {"do_lower_case": True}, 44 | "naver-clova-ocr/bros-large-uncased": {"do_lower_case": True}, 45 | } 46 | 47 | 48 | class BrosTokenizerFast(BertTokenizerFast): 49 | r""" 50 | Construct a "fast" BERT tokenizer (backed by HuggingFace's `tokenizers` library). Based on WordPiece. 51 | 52 | This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main 53 | methods. Users should refer to this superclass for more information regarding those methods. 54 | 55 | Args: 56 | vocab_file (:obj:`str`): 57 | File containing the vocabulary. 58 | do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): 59 | Whether or not to lowercase the input when tokenizing. 
60 | unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`): 61 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this 62 | token instead. 63 | sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`): 64 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for 65 | sequence classification or for a text and a question for question answering. It is also used as the last 66 | token of a sequence built with special tokens. 67 | pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`): 68 | The token used for padding, for example when batching sequences of different lengths. 69 | cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`): 70 | The classifier token which is used when doing sequence classification (classification of the whole sequence 71 | instead of per-token classification). It is the first token of the sequence when built with special tokens. 72 | mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`): 73 | The token used for masking values. This is the token used when training this model with masked language 74 | modeling. This is the token which the model will try to predict. 75 | clean_text (:obj:`bool`, `optional`, defaults to :obj:`True`): 76 | Whether or not to clean the text before tokenization by removing any control characters and replacing all 77 | whitespaces by the classic one. 78 | tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): 79 | Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see `this 80 | issue `__). 81 | strip_accents: (:obj:`bool`, `optional`): 82 | Whether or not to strip all accents. If this option is not specified, then it will be determined by the 83 | value for :obj:`lowercase` (as in the original BERT). 84 | wordpieces_prefix: (:obj:`str`, `optional`, defaults to :obj:`"##"`): 85 | The prefix for subwords. 
86 | """ 87 | 88 | vocab_files_names = VOCAB_FILES_NAMES 89 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 90 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 91 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 92 | slow_tokenizer_class = BrosTokenizer 93 | -------------------------------------------------------------------------------- /src/adapters/ml/vgt/create_word_grid.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import shutil 3 | 4 | import numpy as np 5 | from os import makedirs 6 | from os.path import join, exists 7 | from pdf_features import PdfToken 8 | from pdf_features import Rectangle 9 | from pdf_features import PdfFeatures 10 | 11 | from adapters.ml.vgt.bros.tokenization_bros import BrosTokenizer 12 | from configuration import WORD_GRIDS_PATH 13 | 14 | tokenizer = BrosTokenizer.from_pretrained("naver-clova-ocr/bros-base-uncased") 15 | 16 | 17 | def rectangle_to_bbox(rectangle: Rectangle): 18 | return [rectangle.left, rectangle.top, rectangle.width, rectangle.height] 19 | 20 | 21 | def get_words_positions(text: str, rectangle: Rectangle): 22 | text = text.strip() 23 | text_len = len(text) 24 | 25 | width_per_letter = rectangle.width / text_len 26 | 27 | words_bboxes = [Rectangle.from_coordinates(rectangle.left, rectangle.top, rectangle.left + 5, rectangle.bottom)] 28 | words_bboxes[-1].width = 0 29 | words_bboxes[-1].right = words_bboxes[-1].left 30 | 31 | for letter in text: 32 | if letter == " ": 33 | left = words_bboxes[-1].right + width_per_letter 34 | words_bboxes.append(Rectangle.from_coordinates(left, words_bboxes[-1].top, left + 5, words_bboxes[-1].bottom)) 35 | words_bboxes[-1].width = 0 36 | words_bboxes[-1].right = words_bboxes[-1].left 37 | else: 38 | words_bboxes[-1].right = words_bboxes[-1].right + width_per_letter 39 | words_bboxes[-1].width = words_bboxes[-1].width + width_per_letter 40 | 41 | words = text.split() 42 | return words, 
words_bboxes 43 | 44 | 45 | def get_subwords_positions(word: str, rectangle: Rectangle): 46 | width_per_letter = rectangle.width / len(word) 47 | word_tokens = [x.replace("#", "") for x in tokenizer.tokenize(word)] 48 | 49 | if not word_tokens: 50 | return [], [] 51 | 52 | ids = [x[-2] for x in tokenizer(word_tokens)["input_ids"]] 53 | 54 | right = rectangle.left + len(word_tokens[0]) * width_per_letter 55 | bboxes = [Rectangle.from_coordinates(rectangle.left, rectangle.top, right, rectangle.bottom)] 56 | 57 | for subword in word_tokens[1:]: 58 | right = bboxes[-1].right + len(subword) * width_per_letter 59 | bboxes.append(Rectangle.from_coordinates(bboxes[-1].right, rectangle.top, right, rectangle.bottom)) 60 | 61 | return ids, bboxes 62 | 63 | 64 | def get_grid_words_dict(tokens: list[PdfToken]): 65 | texts, bbox_texts_list, inputs_ids, bbox_subword_list = [], [], [], [] 66 | for token in tokens: 67 | words, words_bboxes = get_words_positions(token.content, token.bounding_box) 68 | texts += words 69 | bbox_texts_list += [rectangle_to_bbox(r) for r in words_bboxes] 70 | for word, word_box in zip(words, words_bboxes): 71 | ids, subwords_bboxes = get_subwords_positions(word, word_box) 72 | inputs_ids += ids 73 | bbox_subword_list += [rectangle_to_bbox(r) for r in subwords_bboxes] 74 | 75 | return { 76 | "input_ids": np.array(inputs_ids), 77 | "bbox_subword_list": np.array(bbox_subword_list), 78 | "texts": texts, 79 | "bbox_texts_list": np.array(bbox_texts_list), 80 | } 81 | 82 | 83 | def create_word_grid(pdf_features_list: list[PdfFeatures]): 84 | makedirs(WORD_GRIDS_PATH, exist_ok=True) 85 | 86 | for pdf_features in pdf_features_list: 87 | for page in pdf_features.pages: 88 | image_id = f"{pdf_features.file_name}_{page.page_number - 1}" 89 | if exists(join(WORD_GRIDS_PATH, image_id + ".pkl")): 90 | continue 91 | grid_words_dict = get_grid_words_dict(page.tokens) 92 | with open(join(WORD_GRIDS_PATH, f"{image_id}.pkl"), mode="wb") as file: 93 | 
pickle.dump(grid_words_dict, file) 94 | 95 | 96 | def remove_word_grids(): 97 | shutil.rmtree(WORD_GRIDS_PATH, ignore_errors=True) 98 | -------------------------------------------------------------------------------- /src/adapters/ml/vgt/ditod/FeatureMerge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class FeatureMerge(nn.Module): 6 | """Multimodal feature fusion used in VSR.""" 7 | 8 | def __init__( 9 | self, 10 | feature_names, 11 | visual_dim, 12 | semantic_dim, 13 | merge_type="Sum", 14 | dropout_ratio=0.1, 15 | with_extra_fc=True, 16 | shortcut=False, 17 | ): 18 | """Multimodal feature merge used in VSR. 19 | Args: 20 | visual_dim (list): the dim of visual features, e.g. [256] 21 | semantic_dim (list): the dim of semantic features, e.g. [256] 22 | merge_type (str): fusion type, e.g. 'Sum', 'Concat', 'Weighted' 23 | dropout_ratio (float): dropout ratio of fusion features 24 | with_extra_fc (bool): whether add extra fc layers for adaptation 25 | shortcut (bool): whether add shortcut connection 26 | """ 27 | super().__init__() 28 | 29 | # merge param 30 | self.feature_names = feature_names 31 | self.merge_type = merge_type 32 | self.visual_dim = visual_dim 33 | self.textual_dim = semantic_dim 34 | self.with_extra_fc = with_extra_fc 35 | self.shortcut = shortcut 36 | self.relu = nn.ReLU(inplace=True) 37 | 38 | if self.merge_type == "Sum": 39 | assert len(self.visual_dim) == len(self.textual_dim) 40 | elif self.merge_type == "Concat": 41 | assert len(self.visual_dim) == len(self.textual_dim) 42 | # self.concat_proj = nn.ModuleList() 43 | 44 | self.vis_proj = nn.ModuleList() 45 | self.text_proj = nn.ModuleList() 46 | self.alpha_proj = nn.ModuleList() 47 | 48 | for idx in range(len(self.visual_dim)): 49 | # self.concat_proj.append(nn.Conv2d(self.visual_dim[idx] + self.textual_dim[idx], self.visual_dim[idx], kernel_size = (1,1), stride=1)) 50 | if self.with_extra_fc: 51 | 
self.vis_proj.append(nn.Linear(self.visual_dim[idx], self.visual_dim[idx])) 52 | self.text_proj.append(nn.Linear(self.textual_dim[idx], self.textual_dim[idx])) 53 | self.alpha_proj.append(nn.Linear(self.visual_dim[idx] + self.textual_dim[idx], self.visual_dim[idx])) 54 | 55 | elif self.merge_type == "Weighted": 56 | assert len(self.visual_dim) == len(self.textual_dim) 57 | self.total_num = len(self.visual_dim) 58 | 59 | # vis projection 60 | self.vis_proj = nn.ModuleList() 61 | self.vis_proj_relu = nn.ModuleList() 62 | 63 | # text projection 64 | self.text_proj = nn.ModuleList() 65 | self.text_proj_relu = nn.ModuleList() 66 | 67 | self.alpha_proj = nn.ModuleList() 68 | for idx in range(self.total_num): 69 | if self.with_extra_fc: 70 | self.vis_proj.append(nn.Linear(self.visual_dim[idx], self.visual_dim[idx])) 71 | self.text_proj.append(nn.Linear(self.textual_dim[idx], self.textual_dim[idx])) 72 | self.alpha_proj.append(nn.Linear(self.visual_dim[idx] + self.textual_dim[idx], self.visual_dim[idx])) 73 | 74 | else: 75 | raise "Unknown merge type {}".format(self.merge_type) 76 | 77 | self.dropout = nn.Dropout(dropout_ratio) 78 | 79 | # visual context 80 | # self.visual_ap = nn.AdaptiveAvgPool2d((1, 1)) 81 | 82 | def forward(self, visual_feat=None, textual_feat=None): 83 | """Forward computation 84 | Args: 85 | visual_feat (list(Tensor)): visual feature maps, in shape of [L x C x H x W] x B 86 | textual_feat (Tensor): textual feature maps, in shape of B x L x C 87 | Returns: 88 | Tensor: fused feature maps, in shape of [B x L x C] 89 | """ 90 | assert len(visual_feat) == len(textual_feat) 91 | 92 | # feature merge 93 | merged_feat = {} 94 | if self.merge_type == "Sum": 95 | for name in self.feature_names: 96 | merged_feat[name] = visual_feat[name] + textual_feat[name] 97 | elif self.merge_type == "Concat": 98 | for idx, name in enumerate(self.feature_names): 99 | # merged_feat[name] = self.concat_proj[idx](torch.cat((visual_feat[name],textual_feat[name]),1)) 100 | 
per_vis = visual_feat[name].permute(0, 2, 3, 1) 101 | per_text = textual_feat[name].permute(0, 2, 3, 1) 102 | if self.with_extra_fc: 103 | per_vis = self.relu(self.vis_proj[idx](per_vis)) 104 | per_text = self.relu(self.text_proj[idx](per_text)) 105 | x_sentence = self.alpha_proj[idx](torch.cat((per_vis, per_text), -1)) 106 | x_sentence = x_sentence.permute(0, 3, 1, 2).contiguous() 107 | merged_feat[name] = x_sentence 108 | else: 109 | assert self.total_num == len(visual_feat) or self.total_num == 1 110 | # for per_vis, per_text in zip(visual_feat, textual_feat): 111 | for idx, name in enumerate(self.feature_names): 112 | per_vis = visual_feat[name].permute(0, 2, 3, 1) 113 | per_text = textual_feat[name].permute(0, 2, 3, 1) 114 | if self.with_extra_fc: 115 | per_vis = self.relu(self.vis_proj[idx](per_vis)) 116 | per_text = self.relu(self.text_proj[idx](per_text)) 117 | 118 | alpha = torch.sigmoid(self.alpha_proj[idx](torch.cat((per_vis, per_text), -1))) 119 | if self.shortcut: 120 | # shortcut 121 | x_sentence = per_vis + alpha * per_text 122 | else: 123 | # selection 124 | x_sentence = alpha * per_vis + (1 - alpha) * per_text 125 | 126 | x_sentence = x_sentence.permute(0, 3, 1, 2).contiguous() 127 | merged_feat[name] = x_sentence 128 | 129 | return merged_feat 130 | -------------------------------------------------------------------------------- /src/adapters/ml/vgt/ditod/Wordnn_embedding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from .tokenization_bros import BrosTokenizer 5 | 6 | 7 | def _init_weights(m): 8 | if isinstance(m, nn.Linear): 9 | # we use xavier_uniform following official JAX ViT: 10 | torch.nn.init.xavier_uniform_(m.weight) 11 | if isinstance(m, nn.Linear) and m.bias is not None: 12 | nn.init.constant_(m.bias, 0) 13 | elif isinstance(m, nn.LayerNorm): 14 | nn.init.constant_(m.bias, 0) 15 | nn.init.constant_(m.weight, 1.0) 16 | 17 | 18 | 
class WordnnEmbedding(nn.Module): 19 | """Generate chargrid embedding feature map.""" 20 | 21 | def __init__( 22 | self, 23 | vocab_size=30552, 24 | hidden_size=768, 25 | embedding_dim=64, 26 | bros_embedding_path="/bros-base-uncased/", 27 | use_pretrain_weight=True, 28 | use_UNK_text=False, 29 | ): 30 | """ 31 | Args: 32 | vocab_size (int): size of vocabulary. 33 | embedding_dim (int): dim of input features 34 | """ 35 | super().__init__() 36 | 37 | self.embedding = nn.Embedding(vocab_size, hidden_size) 38 | self.embedding_proj = nn.Linear(hidden_size, embedding_dim, bias=False) 39 | # self.tokenizer = BrosTokenizer.from_pretrained(bros_embedding_path) 40 | self.use_pretrain_weight = use_pretrain_weight 41 | self.use_UNK_text = use_UNK_text 42 | 43 | self.init_weights(bros_embedding_path) 44 | self.apply(_init_weights) 45 | 46 | def init_weights(self, bros_embedding_path): 47 | if self.use_pretrain_weight: 48 | state_dict = torch.load(bros_embedding_path + "pytorch_model.bin", map_location="cpu") 49 | if "bert" in bros_embedding_path: 50 | word_embs = state_dict["bert.embeddings.word_embeddings.weight"] 51 | elif "bros" in bros_embedding_path: 52 | word_embs = state_dict["embeddings.word_embeddings.weight"] 53 | elif "layoutlm" in bros_embedding_path: 54 | word_embs = state_dict["layoutlm.embeddings.word_embeddings.weight"] 55 | else: 56 | print("Wrong bros_embedding_path!") 57 | self.embedding = nn.Embedding.from_pretrained(word_embs) 58 | print("use_pretrain_weight: load model from:", bros_embedding_path) 59 | 60 | def forward(self, img, batched_inputs, stride=1): 61 | """Forward computation 62 | Args: 63 | img (Tensor): in shape of [B x 3 x H x W] 64 | batched_inputs (list[dict]): 65 | Returns: 66 | Tensor: in shape of [B x N x L x D], where D is the embedding_dim. 
67 | """ 68 | device = img.device 69 | batch_b, _, batch_h, batch_w = img.size() 70 | 71 | chargrid_map = torch.zeros((batch_b, batch_h // stride, batch_w // stride), dtype=torch.int64).to(device) 72 | 73 | for iter_b in range(batch_b): 74 | per_input_ids = batched_inputs[iter_b]["input_ids"] 75 | per_input_bbox = batched_inputs[iter_b]["bbox"] 76 | 77 | short_length_w = min(len(per_input_ids), len(per_input_bbox)) 78 | 79 | if short_length_w > 0: 80 | for word_idx in range(short_length_w): 81 | per_id = per_input_ids[word_idx] 82 | 83 | bbox = per_input_bbox[word_idx] / stride 84 | w_start, h_start, w_end, h_end = bbox.round().astype(int).tolist() 85 | 86 | if self.use_UNK_text: 87 | chargrid_map[iter_b, h_start:h_end, w_start:w_end] = 100 88 | else: 89 | chargrid_map[iter_b, h_start:h_end, w_start:w_end] = per_id 90 | 91 | chargrid_map = self.embedding(chargrid_map) 92 | chargrid_map = self.embedding_proj(chargrid_map) 93 | 94 | return chargrid_map.permute(0, 3, 1, 2).contiguous() 95 | -------------------------------------------------------------------------------- /src/adapters/ml/vgt/ditod/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- 2 | # MPViT: Multi-Path Vision Transformer for Dense Prediction 3 | # Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI). 4 | # All Rights Reserved. 5 | # Written by Youngwan Lee 6 | # This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | # -------------------------------------------------------------------------------- 9 | 10 | from .config import add_vit_config 11 | from .VGTbackbone import build_VGT_fpn_backbone 12 | from .dataset_mapper import DetrDatasetMapper 13 | from .VGTTrainer import VGTTrainer 14 | from .VGT import VGT 15 | 16 | from .utils import eval_and_show, load_gt_from_json, pub_load_gt_from_json 17 | -------------------------------------------------------------------------------- /src/adapters/ml/vgt/ditod/config.py: -------------------------------------------------------------------------------- 1 | from detectron2.config import CfgNode as CN 2 | 3 | 4 | def add_vit_config(cfg): 5 | """ 6 | Add config for VIT. 7 | """ 8 | _C = cfg 9 | 10 | _C.MODEL.VIT = CN() 11 | 12 | # CoaT model name. 13 | _C.MODEL.VIT.NAME = "" 14 | 15 | # Output features from CoaT backbone. 16 | _C.MODEL.VIT.OUT_FEATURES = ["layer3", "layer5", "layer7", "layer11"] 17 | 18 | _C.MODEL.VIT.IMG_SIZE = [224, 224] 19 | 20 | _C.MODEL.VIT.POS_TYPE = "shared_rel" 21 | 22 | _C.MODEL.VIT.MERGE_TYPE = "Sum" 23 | 24 | _C.MODEL.VIT.DROP_PATH = 0.0 25 | 26 | _C.MODEL.VIT.MODEL_KWARGS = "{}" 27 | 28 | _C.SOLVER.OPTIMIZER = "ADAMW" 29 | 30 | _C.SOLVER.BACKBONE_MULTIPLIER = 1.0 31 | 32 | _C.AUG = CN() 33 | 34 | _C.AUG.DETR = False 35 | 36 | _C.MODEL.WORDGRID = CN() 37 | 38 | _C.MODEL.WORDGRID.VOCAB_SIZE = 30552 39 | 40 | _C.MODEL.WORDGRID.EMBEDDING_DIM = 64 41 | 42 | _C.MODEL.WORDGRID.MODEL_PATH = "" 43 | 44 | _C.MODEL.WORDGRID.HIDDEN_SIZE = 768 45 | 46 | _C.MODEL.WORDGRID.USE_PRETRAIN_WEIGHT = True 47 | 48 | _C.MODEL.WORDGRID.USE_UNK_TEXT = False 49 | -------------------------------------------------------------------------------- /src/adapters/ml/vgt/ditod/tokenization_bros.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | 18 | import collections 19 | 20 | from transformers.models.bert.tokenization_bert import BertTokenizer 21 | from transformers.utils import logging 22 | 23 | logger = logging.get_logger(__name__) 24 | 25 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 26 | 27 | PRETRAINED_VOCAB_FILES_MAP = { 28 | "vocab_file": { 29 | "naver-clova-ocr/bros-base-uncased": "https://huggingface.co/naver-clova-ocr/bros-base-uncased/resolve/main/vocab.txt", 30 | "naver-clova-ocr/bros-large-uncased": "https://huggingface.co/naver-clova-ocr/bros-large-uncased/resolve/main/vocab.txt", 31 | } 32 | } 33 | 34 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 35 | "naver-clova-ocr/bros-base-uncased": 512, 36 | "naver-clova-ocr/bros-large-uncased": 512, 37 | } 38 | 39 | PRETRAINED_INIT_CONFIGURATION = { 40 | "naver-clova-ocr/bros-base-uncased": {"do_lower_case": True}, 41 | "naver-clova-ocr/bros-large-uncased": {"do_lower_case": True}, 42 | } 43 | 44 | 45 | def load_vocab(vocab_file): 46 | """Loads a vocabulary file into a dictionary.""" 47 | vocab = collections.OrderedDict() 48 | with open(vocab_file, "r", encoding="utf-8") as reader: 49 | tokens = reader.readlines() 50 | for index, token in enumerate(tokens): 51 | token = token.rstrip("\n") 52 | vocab[token] = index 53 | return vocab 54 | 55 | 56 | def convert_to_unicode(text): 57 | """Converts `text` to Unicode (if it's 
not already), assuming utf-8 input.""" 58 | if six.PY3: 59 | if isinstance(text, str): 60 | return text 61 | elif isinstance(text, bytes): 62 | return text.decode("utf-8", "ignore") 63 | else: 64 | raise ValueError("Unsupported string type: %s" % (type(text))) 65 | elif six.PY2: 66 | if isinstance(text, str): 67 | return text.decode("utf-8", "ignore") 68 | elif isinstance(text, unicode): 69 | return text 70 | else: 71 | raise ValueError("Unsupported string type: %s" % (type(text))) 72 | else: 73 | raise ValueError("Not running on Python2 or Python 3?") 74 | 75 | 76 | def whitespace_tokenize(text): 77 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 78 | text = text.strip() 79 | if not text: 80 | return [] 81 | tokens = text.split() 82 | return tokens 83 | 84 | 85 | class BrosTokenizer(BertTokenizer): 86 | r""" 87 | Construct a BERT tokenizer. Based on WordPiece. 88 | 89 | This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods. 90 | Users should refer to this superclass for more information regarding those methods. 91 | 92 | Args: 93 | vocab_file (:obj:`str`): 94 | File containing the vocabulary. 95 | do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): 96 | Whether or not to lowercase the input when tokenizing. 97 | do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`): 98 | Whether or not to do basic tokenization before WordPiece. 99 | never_split (:obj:`Iterable`, `optional`): 100 | Collection of tokens which will never be split during tokenization. Only has an effect when 101 | :obj:`do_basic_tokenize=True` 102 | unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`): 103 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this 104 | token instead. 
105 | sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`): 106 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for 107 | sequence classification or for a text and a question for question answering. It is also used as the last 108 | token of a sequence built with special tokens. 109 | pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`): 110 | The token used for padding, for example when batching sequences of different lengths. 111 | cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`): 112 | The classifier token which is used when doing sequence classification (classification of the whole sequence 113 | instead of per-token classification). It is the first token of the sequence when built with special tokens. 114 | mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`): 115 | The token used for masking values. This is the token used when training this model with masked language 116 | modeling. This is the token which the model will try to predict. 117 | tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): 118 | Whether or not to tokenize Chinese characters. 119 | 120 | This should likely be deactivated for Japanese (see this `issue 121 | `__). 122 | strip_accents: (:obj:`bool`, `optional`): 123 | Whether or not to strip all accents. If this option is not specified, then it will be determined by the 124 | value for :obj:`lowercase` (as in the original BERT). 
125 | """ 126 | 127 | vocab_files_names = VOCAB_FILES_NAMES 128 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 129 | pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION 130 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 131 | -------------------------------------------------------------------------------- /src/adapters/ml/vgt/get_json_annotations.py: -------------------------------------------------------------------------------- 1 | import json 2 | from os import makedirs 3 | from pdf_features import PdfToken 4 | from domain.PdfImages import PdfImages 5 | from configuration import DOCLAYNET_TYPE_BY_ID 6 | from configuration import JSONS_ROOT_PATH, JSON_TEST_FILE_PATH 7 | 8 | 9 | def save_annotations_json(annotations: list, width_height: list, images: list): 10 | images_dict = [ 11 | { 12 | "id": i, 13 | "file_name": image_id + ".jpg", 14 | "width": width_height[images.index(image_id)][0], 15 | "height": width_height[images.index(image_id)][1], 16 | } 17 | for i, image_id in enumerate(images) 18 | ] 19 | 20 | categories_dict = [{"id": key, "name": value} for key, value in DOCLAYNET_TYPE_BY_ID.items()] 21 | 22 | info_dict = { 23 | "description": "PDF Document Layout Analysis Dataset", 24 | "url": "", 25 | "version": "1.0", 26 | "year": 2025, 27 | "contributor": "", 28 | "date_created": "2025-01-01", 29 | } 30 | 31 | coco_dict = {"info": info_dict, "images": images_dict, "categories": categories_dict, "annotations": annotations} 32 | 33 | JSON_TEST_FILE_PATH.write_text(json.dumps(coco_dict)) 34 | 35 | 36 | def get_annotation(index: int, image_id: str, token: PdfToken): 37 | return { 38 | "area": 1, 39 | "iscrowd": 0, 40 | "score": 1, 41 | "image_id": image_id, 42 | "bbox": [token.bounding_box.left, token.bounding_box.top, token.bounding_box.width, token.bounding_box.height], 43 | "category_id": token.token_type.get_index(), 44 | "id": index, 45 | } 46 | 47 | 48 | def get_annotations_for_document(annotations, images, index, 
pdf_images, width_height): 49 | for page_index, page in enumerate(pdf_images.pdf_features.pages): 50 | image_id = f"{pdf_images.pdf_features.file_name}_{page.page_number - 1}" 51 | images.append(image_id) 52 | width_height.append((pdf_images.pdf_images[page_index].width, pdf_images.pdf_images[page_index].height)) 53 | 54 | for token in page.tokens: 55 | annotations.append(get_annotation(index, image_id, token)) 56 | index += 1 57 | 58 | 59 | def get_annotations(pdf_images_list: list[PdfImages]): 60 | makedirs(JSONS_ROOT_PATH, exist_ok=True) 61 | 62 | annotations = list() 63 | images = list() 64 | width_height = list() 65 | index = 0 66 | 67 | for pdf_images in pdf_images_list: 68 | get_annotations_for_document(annotations, images, index, pdf_images, width_height) 69 | index += sum([len(page.tokens) for page in pdf_images.pdf_features.pages]) 70 | 71 | save_annotations_json(annotations, width_height, images) 72 | -------------------------------------------------------------------------------- /src/adapters/ml/vgt/get_model_configuration.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from os.path import join 3 | from detectron2.config import get_cfg 4 | from detectron2.engine import default_setup, default_argument_parser 5 | from configuration import service_logger, SRC_PATH, ROOT_PATH 6 | from adapters.ml.vgt.ditod import add_vit_config 7 | 8 | 9 | def is_gpu_available(): 10 | total_free_memory_in_system: float = 0.0 11 | if torch.cuda.is_available(): 12 | for i in range(torch.cuda.device_count()): 13 | total_memory = torch.cuda.get_device_properties(i).total_memory / 1024**2 14 | allocated_memory = torch.cuda.memory_allocated(i) / 1024**2 15 | cached_memory = torch.cuda.memory_reserved(i) / 1024**2 16 | service_logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}") 17 | service_logger.info(f" Total Memory: {total_memory} MB") 18 | service_logger.info(f" Allocated Memory: {allocated_memory} MB") 19 | 
service_logger.info(f" Cached Memory: {cached_memory} MB") 20 | total_free_memory_in_system += total_memory - allocated_memory - cached_memory 21 | if total_free_memory_in_system < 3000: 22 | service_logger.info(f"Total free GPU memory is {total_free_memory_in_system} < 3000 MB. Switching to CPU.") 23 | service_logger.info("The process is probably going to be 15 times slower.") 24 | else: 25 | service_logger.info("No CUDA-compatible GPU detected. Switching to CPU.") 26 | return total_free_memory_in_system > 3000 27 | 28 | 29 | def get_model_configuration(): 30 | parser = default_argument_parser() 31 | args, unknown = parser.parse_known_args() 32 | args.config_file = join(SRC_PATH, "adapters", "ml", "vgt", "model_configuration", "doclaynet_VGT_cascade_PTM.yaml") 33 | args.eval_only = True 34 | args.num_gpus = 1 35 | args.opts = [ 36 | "MODEL.WEIGHTS", 37 | join(ROOT_PATH, "models", "doclaynet_VGT_model.pth"), 38 | "OUTPUT_DIR", 39 | join(ROOT_PATH, "model_output_doclaynet"), 40 | ] 41 | args.debug = False 42 | 43 | configuration = get_cfg() 44 | add_vit_config(configuration) 45 | configuration.merge_from_file(args.config_file) 46 | configuration.merge_from_list(args.opts) 47 | configuration.MODEL.DEVICE = "cuda" if is_gpu_available() else "cpu" 48 | configuration.freeze() 49 | default_setup(configuration, args) 50 | 51 | return configuration 52 | -------------------------------------------------------------------------------- /src/adapters/ml/vgt/get_most_probable_pdf_segments.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | from os.path import join 4 | from pathlib import Path 5 | from statistics import mode 6 | 7 | from domain.PdfSegment import PdfSegment 8 | from pdf_features import PdfFeatures 9 | from pdf_features import PdfToken 10 | from pdf_features import Rectangle 11 | from pdf_token_type_labels import TokenType 12 | from domain.PdfImages import PdfImages 13 | from configuration 
import ROOT_PATH, DOCLAYNET_TYPE_BY_ID 14 | from domain.Prediction import Prediction 15 | 16 | 17 | def get_prediction_from_annotation(annotation, images_names, vgt_predictions_dict): 18 | pdf_name = images_names[annotation["image_id"]][:-4] 19 | category_id = annotation["category_id"] 20 | bounding_box = Rectangle.from_width_height( 21 | left=int(annotation["bbox"][0]), 22 | top=int(annotation["bbox"][1]), 23 | width=int(annotation["bbox"][2]), 24 | height=int(annotation["bbox"][3]), 25 | ) 26 | 27 | prediction = Prediction( 28 | bounding_box=bounding_box, category_id=category_id, score=round(float(annotation["score"]) * 100, 2) 29 | ) 30 | vgt_predictions_dict.setdefault(pdf_name, list()).append(prediction) 31 | 32 | 33 | def get_vgt_predictions(model_name: str) -> dict[str, list[Prediction]]: 34 | output_dir: str = f"model_output_{model_name}" 35 | model_output_json_path = join(str(ROOT_PATH), output_dir, "inference", "coco_instances_results.json") 36 | annotations = json.loads(Path(model_output_json_path).read_text()) 37 | 38 | test_json_path = join(str(ROOT_PATH), "jsons", "test.json") 39 | coco_truth = json.loads(Path(test_json_path).read_text()) 40 | 41 | images_names = {value["id"]: value["file_name"] for value in coco_truth["images"]} 42 | 43 | vgt_predictions_dict = dict() 44 | for annotation in annotations: 45 | get_prediction_from_annotation(annotation, images_names, vgt_predictions_dict) 46 | 47 | return vgt_predictions_dict 48 | 49 | 50 | def find_best_prediction_for_token(page_pdf_name, token, vgt_predictions_dict, most_probable_tokens_by_predictions): 51 | best_score: float = 0 52 | most_probable_prediction: Prediction | None = None 53 | for prediction in vgt_predictions_dict[page_pdf_name]: 54 | if prediction.score > best_score and prediction.bounding_box.get_intersection_percentage(token.bounding_box): 55 | best_score = prediction.score 56 | most_probable_prediction = prediction 57 | if best_score >= 99: 58 | break 59 | if 
most_probable_prediction: 60 | most_probable_tokens_by_predictions.setdefault(most_probable_prediction, list()).append(token) 61 | else: 62 | dummy_prediction = Prediction(bounding_box=token.bounding_box, category_id=10, score=0.0) 63 | most_probable_tokens_by_predictions.setdefault(dummy_prediction, list()).append(token) 64 | 65 | 66 | def get_merged_prediction_type(to_merge: list[Prediction]): 67 | table_exists = any([p.category_id == 9 for p in to_merge]) 68 | if not table_exists: 69 | return mode([p.category_id for p in sorted(to_merge, key=lambda x: -x.score)]) 70 | return 9 71 | 72 | 73 | def merge_colliding_predictions(predictions: list[Prediction]): 74 | predictions = [p for p in predictions if not p.score < 20] 75 | while True: 76 | new_predictions, merged = [], False 77 | while predictions: 78 | p1 = predictions.pop(0) 79 | to_merge = [p for p in predictions if p1.bounding_box.get_intersection_percentage(p.bounding_box) > 0] 80 | for prediction in to_merge: 81 | predictions.remove(prediction) 82 | if to_merge: 83 | to_merge.append(p1) 84 | p1.bounding_box = Rectangle.merge_rectangles([prediction.bounding_box for prediction in to_merge]) 85 | p1.category_id = get_merged_prediction_type(to_merge) 86 | merged = True 87 | new_predictions.append(p1) 88 | if not merged: 89 | return new_predictions 90 | predictions = new_predictions 91 | 92 | 93 | def get_pdf_segments_for_page(page, pdf_name, page_pdf_name, vgt_predictions_dict): 94 | most_probable_pdf_segments_for_page: list[PdfSegment] = [] 95 | most_probable_tokens_by_predictions: dict[Prediction, list[PdfToken]] = {} 96 | vgt_predictions_dict[page_pdf_name] = merge_colliding_predictions(vgt_predictions_dict[page_pdf_name]) 97 | 98 | for token in page.tokens: 99 | find_best_prediction_for_token(page_pdf_name, token, vgt_predictions_dict, most_probable_tokens_by_predictions) 100 | 101 | for prediction, tokens in most_probable_tokens_by_predictions.items(): 102 | new_segment = PdfSegment.from_pdf_tokens(tokens, 
pdf_name) 103 | new_segment.bounding_box = prediction.bounding_box 104 | new_segment.segment_type = TokenType.from_text(DOCLAYNET_TYPE_BY_ID[prediction.category_id]) 105 | most_probable_pdf_segments_for_page.append(new_segment) 106 | 107 | no_token_predictions = [ 108 | prediction 109 | for prediction in vgt_predictions_dict[page_pdf_name] 110 | if prediction not in most_probable_tokens_by_predictions 111 | ] 112 | 113 | for prediction in no_token_predictions: 114 | segment_type = TokenType.from_text(DOCLAYNET_TYPE_BY_ID[prediction.category_id]) 115 | page_number = page.page_number 116 | new_segment = PdfSegment(page_number, prediction.bounding_box, "", segment_type, pdf_name) 117 | most_probable_pdf_segments_for_page.append(new_segment) 118 | 119 | return most_probable_pdf_segments_for_page 120 | 121 | 122 | def prediction_exists_for_page(page_pdf_name, vgt_predictions_dict): 123 | return page_pdf_name in vgt_predictions_dict 124 | 125 | 126 | def get_most_probable_pdf_segments(model_name: str, pdf_images_list: list[PdfImages], save_output: bool = False): 127 | most_probable_pdf_segments: list[PdfSegment] = [] 128 | vgt_predictions_dict = get_vgt_predictions(model_name) 129 | pdf_features_list: list[PdfFeatures] = [pdf_images.pdf_features for pdf_images in pdf_images_list] 130 | for pdf_features in pdf_features_list: 131 | for page in pdf_features.pages: 132 | page_pdf_name = pdf_features.file_name + "_" + str(page.page_number - 1) 133 | if not prediction_exists_for_page(page_pdf_name, vgt_predictions_dict): 134 | continue 135 | page_segments = get_pdf_segments_for_page(page, pdf_features.file_name, page_pdf_name, vgt_predictions_dict) 136 | most_probable_pdf_segments.extend(page_segments) 137 | if save_output: 138 | save_path = join(ROOT_PATH, f"model_output_{model_name}", "predicted_segments.pickle") 139 | with open(save_path, mode="wb") as file: 140 | pickle.dump(most_probable_pdf_segments, file) 141 | return most_probable_pdf_segments 142 | 
-------------------------------------------------------------------------------- /src/adapters/ml/vgt/get_reading_orders.py: -------------------------------------------------------------------------------- 1 | from domain.PdfSegment import PdfSegment 2 | from pdf_features import PdfPage 3 | from pdf_features import PdfToken 4 | from pdf_token_type_labels import TokenType 5 | 6 | from domain.PdfImages import PdfImages 7 | 8 | 9 | def find_segment_for_token(token: PdfToken, segments: list[PdfSegment], tokens_by_segments): 10 | best_score: float = 0 11 | most_probable_segment: PdfSegment | None = None 12 | for segment in segments: 13 | intersection_percentage = token.bounding_box.get_intersection_percentage(segment.bounding_box) 14 | if intersection_percentage > best_score: 15 | best_score = intersection_percentage 16 | most_probable_segment = segment 17 | if best_score >= 99: 18 | break 19 | if most_probable_segment: 20 | tokens_by_segments.setdefault(most_probable_segment, list()).append(token) 21 | 22 | 23 | def get_average_reading_order_for_segment(page: PdfPage, tokens_for_segment: list[PdfToken]): 24 | reading_order_sum: int = sum(page.tokens.index(token) for token in tokens_for_segment) 25 | return reading_order_sum / len(tokens_for_segment) 26 | 27 | 28 | def get_distance_between_segments(segment1: PdfSegment, segment2: PdfSegment): 29 | center_1_x = (segment1.bounding_box.left + segment1.bounding_box.right) / 2 30 | center_1_y = (segment1.bounding_box.top + segment1.bounding_box.bottom) / 2 31 | center_2_x = (segment2.bounding_box.left + segment2.bounding_box.right) / 2 32 | center_2_y = (segment2.bounding_box.top + segment2.bounding_box.bottom) / 2 33 | return ((center_1_x - center_2_x) ** 2 + (center_1_y - center_2_y) ** 2) ** 0.5 34 | 35 | 36 | def add_no_token_segments(segments, no_token_segments): 37 | if segments: 38 | for no_token_segment in no_token_segments: 39 | closest_segment = sorted(segments, key=lambda seg: 
get_distance_between_segments(no_token_segment, seg))[0] 40 | closest_index = segments.index(closest_segment) 41 | if closest_segment.bounding_box.top < no_token_segment.bounding_box.top: 42 | segments.insert(closest_index + 1, no_token_segment) 43 | else: 44 | segments.insert(closest_index, no_token_segment) 45 | else: 46 | for segment in sorted(no_token_segments, key=lambda r: (r.bounding_box.left, r.bounding_box.top)): 47 | segments.append(segment) 48 | 49 | 50 | def filter_and_sort_segments(page, tokens_by_segments, types): 51 | filtered_segments = [seg for seg in tokens_by_segments.keys() if seg.segment_type in types] 52 | order = {seg: get_average_reading_order_for_segment(page, tokens_by_segments[seg]) for seg in filtered_segments} 53 | return sorted(filtered_segments, key=lambda seg: order[seg]) 54 | 55 | 56 | def get_ordered_segments_for_page(segments_for_page: list[PdfSegment], page: PdfPage): 57 | tokens_by_segments: dict[PdfSegment, list[PdfToken]] = {} 58 | for token in page.tokens: 59 | find_segment_for_token(token, segments_for_page, tokens_by_segments) 60 | 61 | page_number_segment: None | PdfSegment = None 62 | if tokens_by_segments: 63 | last_segment = max(tokens_by_segments.keys(), key=lambda seg: seg.bounding_box.top) 64 | if last_segment.text_content and len(last_segment.text_content) < 5: 65 | page_number_segment = last_segment 66 | del tokens_by_segments[last_segment] 67 | 68 | header_segments: list[PdfSegment] = filter_and_sort_segments(page, tokens_by_segments, {TokenType.PAGE_HEADER}) 69 | paragraph_types = {t for t in TokenType if t.name not in {"PAGE_HEADER", "PAGE_FOOTER", "FOOTNOTE"}} 70 | paragraph_segments = filter_and_sort_segments(page, tokens_by_segments, paragraph_types) 71 | footer_segments = filter_and_sort_segments(page, tokens_by_segments, {TokenType.PAGE_FOOTER, TokenType.FOOTNOTE}) 72 | if page_number_segment: 73 | footer_segments.append(page_number_segment) 74 | ordered_segments = header_segments + paragraph_segments + 
footer_segments 75 | no_token_segments = [segment for segment in segments_for_page if segment not in ordered_segments] 76 | add_no_token_segments(ordered_segments, no_token_segments) 77 | return ordered_segments 78 | 79 | 80 | def get_reading_orders(pdf_images_list: list[PdfImages], predicted_segments: list[PdfSegment]): 81 | ordered_segments: list[PdfSegment] = [] 82 | for pdf_images in pdf_images_list: 83 | pdf_name = pdf_images.pdf_features.file_name 84 | segments_for_file = [segment for segment in predicted_segments if segment.pdf_name == pdf_name] 85 | for page in pdf_images.pdf_features.pages: 86 | segments_for_page = [segment for segment in segments_for_file if segment.page_number == page.page_number] 87 | ordered_segments.extend(get_ordered_segments_for_page(segments_for_page, page)) 88 | return ordered_segments 89 | -------------------------------------------------------------------------------- /src/adapters/ml/vgt/model_configuration/Base-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | MASK_ON: True 3 | META_ARCHITECTURE: "GeneralizedRCNN" 4 | PIXEL_MEAN: [123.675, 116.280, 103.530] 5 | PIXEL_STD: [58.395, 57.120, 57.375] 6 | BACKBONE: 7 | NAME: "build_vit_fpn_backbone" 8 | VIT: 9 | OUT_FEATURES: ["layer3", "layer5", "layer7", "layer11"] 10 | DROP_PATH: 0.1 11 | IMG_SIZE: [224,224] 12 | POS_TYPE: "abs" 13 | FPN: 14 | IN_FEATURES: ["layer3", "layer5", "layer7", "layer11"] 15 | ANCHOR_GENERATOR: 16 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 17 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) 18 | RPN: 19 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 20 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 21 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 22 | # Detectron1 uses 2000 proposals per-batch, 23 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) 24 | # which is approximately 1000 proposals per-image since 
the default batch size for FPN is 2. 25 | POST_NMS_TOPK_TRAIN: 1000 26 | POST_NMS_TOPK_TEST: 1000 27 | ROI_HEADS: 28 | NAME: "StandardROIHeads" 29 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 30 | NUM_CLASSES: 5 31 | ROI_BOX_HEAD: 32 | NAME: "FastRCNNConvFCHead" 33 | NUM_FC: 2 34 | POOLER_RESOLUTION: 7 35 | ROI_MASK_HEAD: 36 | NAME: "MaskRCNNConvUpsampleHead" 37 | NUM_CONV: 4 38 | POOLER_RESOLUTION: 14 39 | DATASETS: 40 | TRAIN: ("docbank_train",) 41 | TEST: ("docbank_val",) 42 | SOLVER: 43 | LR_SCHEDULER_NAME: "WarmupCosineLR" 44 | AMP: 45 | ENABLED: True 46 | OPTIMIZER: "ADAMW" 47 | BACKBONE_MULTIPLIER: 1.0 48 | CLIP_GRADIENTS: 49 | ENABLED: True 50 | CLIP_TYPE: "full_model" 51 | CLIP_VALUE: 1.0 52 | NORM_TYPE: 2.0 53 | WARMUP_FACTOR: 0.01 54 | BASE_LR: 0.0004 55 | WEIGHT_DECAY: 0.05 56 | IMS_PER_BATCH: 32 57 | INPUT: 58 | CROP: 59 | ENABLED: True 60 | TYPE: "absolute_range" 61 | SIZE: (384, 600) 62 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) 63 | FORMAT: "RGB" 64 | DATALOADER: 65 | NUM_WORKERS: 6 66 | FILTER_EMPTY_ANNOTATIONS: False 67 | VERSION: 2 68 | AUG: 69 | DETR: True 70 | SEED: 42 -------------------------------------------------------------------------------- /src/adapters/ml/vgt/model_configuration/doclaynet_VGT_cascade_PTM.yaml: -------------------------------------------------------------------------------- 1 | DATASETS: 2 | TEST: ("predict_data",) 3 | TRAIN: ("train_data",) 4 | MODEL: 5 | BACKBONE: 6 | NAME: build_VGT_fpn_backbone 7 | MASK_ON: false 8 | META_ARCHITECTURE: VGT 9 | PIXEL_MEAN: 10 | - 127.5 11 | - 127.5 12 | - 127.5 13 | PIXEL_STD: 14 | - 127.5 15 | - 127.5 16 | - 127.5 17 | ROI_BOX_HEAD: 18 | CLS_AGNOSTIC_BBOX_REG: true 19 | ROI_HEADS: 20 | NAME: CascadeROIHeads 21 | NUM_CLASSES: 11 22 | RPN: 23 | POST_NMS_TOPK_TRAIN: 2000 24 | VIT: 25 | MERGE_TYPE: Sum 26 | NAME: VGT_dit_base_patch16 27 | WEIGHTS: https://layoutlm.blob.core.windows.net/dit/dit-pts/dit-base-224-p16-500k-62d53a.pth 28 | WORDGRID: 29 | 
EMBEDDING_DIM: 64
30 |     MODEL_PATH: ../models/layoutlm-base-uncased/
31 |     USE_PRETRAIN_WEIGHT: true
32 |     VOCAB_SIZE: 30552
33 | SOLVER:
34 |   BASE_LR: 0.0002
35 |   IMS_PER_BATCH: 12
36 |   MAX_ITER: 10000
37 |   STEPS: (6000, 8000)
38 |   WARMUP_ITERS: 100
39 | TEST:
40 |   EVAL_PERIOD: 2000
41 | _BASE_: ./Base-RCNN-FPN.yaml
42 |
--------------------------------------------------------------------------------
/src/adapters/ml/vgt/model_configuration/doclaynet_configuration.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/ml/vgt/model_configuration/doclaynet_configuration.pickle
--------------------------------------------------------------------------------
/src/adapters/ml/vgt_model_adapter.py:
--------------------------------------------------------------------------------
 1 | from domain.PdfImages import PdfImages
 2 | from domain.PdfSegment import PdfSegment
 3 | from ports.services.ml_model_service import MLModelService
 4 | from adapters.ml.vgt.ditod import VGTTrainer
 5 | from adapters.ml.vgt.get_model_configuration import get_model_configuration
 6 | from adapters.ml.vgt.get_most_probable_pdf_segments import get_most_probable_pdf_segments
 7 | from adapters.ml.vgt.get_reading_orders import get_reading_orders
 8 | from adapters.ml.vgt.get_json_annotations import get_annotations
 9 | from adapters.ml.vgt.create_word_grid import create_word_grid, remove_word_grids
10 | from detectron2.checkpoint import DetectionCheckpointer
11 | from detectron2.data.datasets import register_coco_instances
12 | from detectron2.data import DatasetCatalog
13 | from configuration import JSON_TEST_FILE_PATH, IMAGES_ROOT_PATH
14 |
15 | configuration = get_model_configuration()
16 | model = VGTTrainer.build_model(configuration)
17 | DetectionCheckpointer(model, save_dir=configuration.OUTPUT_DIR).resume_or_load(configuration.MODEL.WEIGHTS, resume=True)
18 |
19 |
20 | class VGTModelAdapter(MLModelService):
21 |
22 |     def _register_data(self) -> None:
23 |         try:
24 |             DatasetCatalog.remove("predict_data")
25 |         except KeyError:
26 |             pass
27 |
28 |         register_coco_instances("predict_data", {}, JSON_TEST_FILE_PATH, IMAGES_ROOT_PATH)
29 |
30 |     def predict_document_layout(self, pdf_images: list[PdfImages]) -> list[PdfSegment]:
31 |         create_word_grid([pdf_images_obj.pdf_features for pdf_images_obj in pdf_images])
32 |         get_annotations(pdf_images)
33 |
34 |         self._register_data()
35 |         VGTTrainer.test(configuration, model)
36 |
37 |         predicted_segments = get_most_probable_pdf_segments("doclaynet", pdf_images, False)
38 |
39 |         PdfImages.remove_images()
40 |         remove_word_grids()
41 |
42 |         return get_reading_orders(pdf_images, predicted_segments)
43 |
44 |     def predict_layout_fast(self, pdf_images: list[PdfImages]) -> list[PdfSegment]:
45 |         raise NotImplementedError("Fast prediction should be handled by FastTrainerAdapter")
46 |
--------------------------------------------------------------------------------
/src/adapters/storage/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/storage/__init__.py
--------------------------------------------------------------------------------
/src/adapters/storage/file_system_repository.py:
--------------------------------------------------------------------------------
 1 | import tempfile
 2 | import uuid
 3 | from pathlib import Path
 4 | from typing import AnyStr
 5 | from ports.repositories.file_repository import FileRepository
 6 | from configuration import XMLS_PATH
 7 |
 8 |
 9 | class FileSystemRepository(FileRepository):
10 |     def save_pdf(self, content: AnyStr, filename: str = "") -> Path:
11 |         if not filename:
12 |             filename = str(uuid.uuid1())
13 |
14 |         pdf_path = Path(tempfile.gettempdir(), f"{filename}.pdf")
15 |         pdf_path.write_bytes(content)
16 |         return pdf_path
17 |
18 |     def save_xml(self, content: str, filename: str) -> Path:
19 |         if not filename.endswith(".xml"):
20 |             filename = f"{filename}.xml"
21 |
22 |         xml_path = Path(XMLS_PATH, filename)
23 |         xml_path.parent.mkdir(parents=True, exist_ok=True)
24 |         xml_path.write_text(content, encoding="utf-8")
25 |         return xml_path
26 |
27 |     def get_xml(self, filename: str) -> str:
28 |         if not filename.endswith(".xml"):
29 |             filename = f"{filename}.xml"
30 |
31 |         xml_path = Path(XMLS_PATH, filename)
32 |         if not xml_path.exists():
33 |             raise FileNotFoundError(f"XML file {filename} not found")
34 |
35 |         return xml_path.read_text(encoding="utf-8")
36 |
37 |     def delete_file(self, filepath: Path) -> None:
38 |         filepath.unlink(missing_ok=True)
39 |
40 |     def cleanup_temp_files(self) -> None:
41 |         pass
42 |
43 |     def save_pdf_to_directory(self, content: AnyStr, filename: str, directory: Path, namespace: str = "") -> Path:
44 |         if namespace:
45 |             target_path = Path(directory, namespace, filename)
46 |         else:
47 |             target_path = Path(directory, filename)
48 |
49 |         target_path.parent.mkdir(parents=True, exist_ok=True)
50 |         target_path.write_bytes(content)
51 |         return target_path
52 |
53 |     def save_markdown(self, content: str, filepath: Path) -> Path:
54 |         filepath.parent.mkdir(parents=True, exist_ok=True)
55 |         filepath.write_text(content, encoding="utf-8")
56 |         return filepath
57 |
--------------------------------------------------------------------------------
/src/adapters/web/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/web/__init__.py
--------------------------------------------------------------------------------
/src/adapters/web/fastapi_controllers.py:
--------------------------------------------------------------------------------
  1 | import sys
  2 | import subprocess
  3 | from fastapi import UploadFile, File, Form
  4 | from typing import Optional, Union
  5 | from starlette.responses import Response
  6 | from starlette.concurrency import run_in_threadpool
  7 | from use_cases.pdf_analysis.analyze_pdf_use_case import AnalyzePDFUseCase
  8 | from use_cases.text_extraction.extract_text_use_case import ExtractTextUseCase
  9 | from use_cases.toc_extraction.extract_toc_use_case import ExtractTOCUseCase
 10 | from use_cases.visualization.create_visualization_use_case import CreateVisualizationUseCase
 11 | from use_cases.ocr.process_ocr_use_case import ProcessOCRUseCase
 12 | from use_cases.markdown_conversion.convert_to_markdown_use_case import ConvertToMarkdownUseCase
 13 | from use_cases.html_conversion.convert_to_html_use_case import ConvertToHtmlUseCase
 14 | from adapters.storage.file_system_repository import FileSystemRepository
 15 |
 16 |
 17 | class FastAPIControllers:
 18 |     def __init__(
 19 |         self,
 20 |         analyze_pdf_use_case: AnalyzePDFUseCase,
 21 |         extract_text_use_case: ExtractTextUseCase,
 22 |         extract_toc_use_case: ExtractTOCUseCase,
 23 |         create_visualization_use_case: CreateVisualizationUseCase,
 24 |         process_ocr_use_case: ProcessOCRUseCase,
 25 |         convert_to_markdown_use_case: ConvertToMarkdownUseCase,
 26 |         convert_to_html_use_case: ConvertToHtmlUseCase,
 27 |         file_repository: FileSystemRepository,
 28 |     ):
 29 |         self.analyze_pdf_use_case = analyze_pdf_use_case
 30 |         self.extract_text_use_case = extract_text_use_case
 31 |         self.extract_toc_use_case = extract_toc_use_case
 32 |         self.create_visualization_use_case = create_visualization_use_case
 33 |         self.process_ocr_use_case = process_ocr_use_case
 34 |         self.convert_to_markdown_use_case = convert_to_markdown_use_case
 35 |         self.convert_to_html_use_case = convert_to_html_use_case
 36 |         self.file_repository = file_repository
 37 |
 38 |     async def root(self):
 39 |         import torch
 40 |
 41 |         return sys.version + " Using GPU: " + str(torch.cuda.is_available())
 42 |
 43 |     async def info(self):
 44 |         return {
 45 |             "sys": sys.version,
 46 |             "tesseract_version": subprocess.run("tesseract --version", shell=True, text=True, capture_output=True).stdout,
 47 |             "ocrmypdf_version": subprocess.run("ocrmypdf --version", shell=True, text=True, capture_output=True).stdout,
 48 |             "supported_languages": self.process_ocr_use_case.get_supported_languages(),
 49 |         }
 50 |
 51 |     async def error(self):
 52 |         raise FileNotFoundError("This is a test error from the error endpoint")
 53 |
 54 |     async def analyze_pdf(
 55 |         self, file: UploadFile = File(...), fast: bool = Form(False), parse_tables_and_math: bool = Form(False)
 56 |     ):
 57 |         return await run_in_threadpool(
 58 |             self.analyze_pdf_use_case.execute, file.file.read(), "", parse_tables_and_math, fast, False
 59 |         )
 60 |
 61 |     async def analyze_and_save_xml(
 62 |         self, file: UploadFile = File(...), xml_file_name: str | None = None, fast: bool = Form(False)
 63 |     ):
 64 |         if not xml_file_name.endswith(".xml"):
 65 |             xml_file_name = f"{xml_file_name}.xml"
 66 |         return await run_in_threadpool(self.analyze_pdf_use_case.execute_and_save_xml, file.file.read(), xml_file_name, fast)
 67 |
 68 |     async def get_xml_by_name(self, xml_file_name: str):
 69 |         if not xml_file_name.endswith(".xml"):
 70 |             xml_file_name = f"{xml_file_name}.xml"
 71 |         return await run_in_threadpool(self.file_repository.get_xml, xml_file_name)
 72 |
 73 |     async def get_toc_endpoint(self, file: UploadFile = File(...), fast: bool = Form(False)):
 74 |         return await run_in_threadpool(self.extract_toc_use_case.execute, file, fast)
 75 |
 76 |     async def toc_legacy_uwazi_compatible(self, file: UploadFile = File(...)):
 77 |         return await run_in_threadpool(self.extract_toc_use_case.execute_uwazi_compatible, file)
 78 |
 79 |     async def get_text_endpoint(self, file: UploadFile = File(...), fast: bool = Form(False), types: str = Form("all")):
 80 |         return await run_in_threadpool(self.extract_text_use_case.execute, file, fast, types)
 81 |
 82 |     async def get_visualization_endpoint(self, file: UploadFile = File(...), fast: bool = Form(False)):
 83 |         return await run_in_threadpool(self.create_visualization_use_case.execute, file, fast)
 84 |
 85 |     async def ocr_pdf_sync(self, file: UploadFile = File(...), language: str = Form("en")):
 86 |         return await run_in_threadpool(self.process_ocr_use_case.execute, file, language)
 87 |
 88 |     async def convert_to_markdown_endpoint(
 89 |         self,
 90 |         file: UploadFile = File(...),
 91 |         fast: bool = Form(False),
 92 |         extract_toc: bool = Form(False),
 93 |         dpi: int = Form(120),
 94 |         output_file: Optional[str] = Form(None),
 95 |         target_languages: Optional[str] = Form(None),
 96 |         translation_model: str = Form("gpt-oss"),
 97 |     ) -> Union[str, Response]:
 98 |         target_languages_list = None
 99 |         if target_languages:
100 |             target_languages_list = [lang.strip() for lang in target_languages.split(",") if lang.strip()]
101 |
102 |         return await run_in_threadpool(
103 |             self.convert_to_markdown_use_case.execute,
104 |             file.file.read(),
105 |             fast,
106 |             extract_toc,
107 |             dpi,
108 |             output_file,
109 |             target_languages_list,
110 |             translation_model,
111 |         )
112 |
113 |     async def convert_to_html_endpoint(
114 |         self,
115 |         file: UploadFile = File(...),
116 |         fast: bool = Form(False),
117 |         extract_toc: bool = Form(False),
118 |         dpi: int = Form(120),
119 |         output_file: Optional[str] = Form(None),
120 |         target_languages: Optional[str] = Form(None),
121 |         translation_model: str = Form("gpt-oss"),
122 |     ) -> Union[str, Response]:
123 |         target_languages_list = None
124 |         if target_languages:
125 |             target_languages_list = [lang.strip() for lang in target_languages.split(",") if lang.strip()]
126 |
127 |         return await run_in_threadpool(
128 |             self.convert_to_html_use_case.execute,
129 |             file.file.read(),
130 |             fast,
131 |             extract_toc,
132 |             dpi,
133 |             output_file,
134 |             target_languages_list,
135 |             translation_model,
136 |         )
137 |
--------------------------------------------------------------------------------
/src/app.py:
--------------------------------------------------------------------------------
 1 | from configuration import RESTART_IF_NO_GPU
 2 | from drivers.web.fastapi_app import create_app
 3 | from drivers.web.dependency_injection import setup_dependencies
 4 | import torch
 5 |
 6 | if RESTART_IF_NO_GPU:
 7 |     if not torch.cuda.is_available():
 8 |         raise RuntimeError("No GPU available. Restarting the service is required.")
 9 |
10 | controllers = setup_dependencies()
11 |
12 | app = create_app(controllers)
13 |
--------------------------------------------------------------------------------
/src/catch_exceptions.py:
--------------------------------------------------------------------------------
 1 | from functools import wraps
 2 | from fastapi import HTTPException
 3 |
 4 | from configuration import service_logger
 5 |
 6 |
 7 | def catch_exceptions(func):
 8 |     @wraps(func)
 9 |     async def wrapper(*args, **kwargs):
10 |         try:
11 |             service_logger.info(f"Calling endpoint: {func.__name__}")
12 |             if kwargs and "file" in kwargs:
13 |                 service_logger.info(f"Processing file: {kwargs['file'].filename}")
14 |             if kwargs and "xml_file_name" in kwargs:
15 |                 service_logger.info(f"Asking for file: {kwargs['xml_file_name']}")
16 |             return await func(*args, **kwargs)
17 |         except FileNotFoundError:
18 |             raise HTTPException(status_code=404, detail="No xml file")
19 |         except Exception:
20 |             service_logger.error("Error see traceback", exc_info=1)
21 |             raise HTTPException(status_code=422, detail="Error see traceback")
22 |
23 |     return wrapper
24 |
--------------------------------------------------------------------------------
/src/configuration.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | import os
 3 | from pathlib import Path
 4 |
 5 |
 6 | SRC_PATH = Path(__file__).parent.absolute()
 7 | ROOT_PATH = Path(__file__).parent.parent.absolute()
 8 |
 9 | handlers = [logging.StreamHandler()]
10 | logging.root.handlers = []
11 | logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=handlers)
12 | service_logger = logging.getLogger(__name__)
13 |
14 | RESTART_IF_NO_GPU = os.environ.get("RESTART_IF_NO_GPU", "false").lower().strip() == "true"
15 | IMAGES_ROOT_PATH = Path(ROOT_PATH, "images")
16 | WORD_GRIDS_PATH = Path(ROOT_PATH, "word_grids")
17 | JSONS_ROOT_PATH = Path(ROOT_PATH, "jsons")
18 | OCR_SOURCE = Path(ROOT_PATH, "ocr", "source")
19 | OCR_OUTPUT = Path(ROOT_PATH, "ocr", "output")
20 | OCR_FAILED = Path(ROOT_PATH, "ocr", "failed")
21 | JSON_TEST_FILE_PATH = Path(JSONS_ROOT_PATH, "test.json")
22 | MODELS_PATH = Path(ROOT_PATH, "models")
23 | XMLS_PATH = Path(ROOT_PATH, "xmls")
24 |
25 | DOCLAYNET_TYPE_BY_ID = {
26 |     1: "Caption",
27 |     2: "Footnote",
28 |     3: "Formula",
29 |     4: "List_Item",
30 |     5: "Page_Footer",
31 |     6: "Page_Header",
32 |     7: "Picture",
33 |     8: "Section_Header",
34 |     9: "Table",
35 |     10: "Text",
36 |     11: "Title",
37 | }
38 |
--------------------------------------------------------------------------------
/src/domain/PdfImages.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import shutil
 3 |
 4 | import cv2
 5 | import numpy as np
 6 | from os import makedirs
 7 | from os.path import join
 8 | from pathlib import Path
 9 | from PIL import Image
10 | from pdf2image import convert_from_path
11 | from pdf_features import PdfFeatures
12 |
13 | from src.configuration import IMAGES_ROOT_PATH, XMLS_PATH
14 |
15 |
16 | class PdfImages:
17 |     def __init__(self, pdf_features: PdfFeatures, pdf_images: list[Image], dpi: int = 72):
18 |         self.pdf_features: PdfFeatures = pdf_features
19 |         self.pdf_images: list[Image] = pdf_images
20 |         self.dpi: int = dpi
21 |         self.save_images()
22 |
23 |     def show_images(self, next_image_delay: int = 2):
24 |         for image_index, image in enumerate(self.pdf_images):
25 |             image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
26 |             cv2.imshow(f"Page: {image_index + 1}", image_np)
27 |             cv2.waitKey(next_image_delay * 1000)
28 |             cv2.destroyAllWindows()
29 |
30 |     def save_images(self):
31 |         makedirs(IMAGES_ROOT_PATH, exist_ok=True)
32 |         for image_index, image in enumerate(self.pdf_images):
33 |             image_name = f"{self.pdf_features.file_name}_{image_index}.jpg"
34 |             image.save(join(IMAGES_ROOT_PATH, image_name))
35 |
36 |     @staticmethod
37 |     def remove_images():
38 |         # A missing directory would make rmtree raise FileNotFoundError, which catch_exceptions reports as a misleading 404.
39 |         if IMAGES_ROOT_PATH.exists():
40 |             shutil.rmtree(IMAGES_ROOT_PATH)
41 |
42 |     @staticmethod
43 |     def from_pdf_path(pdf_path: str | Path, pdf_name: str = "", xml_file_name: str = "", dpi: int = 72):
44 |         xml_path = None if not xml_file_name else Path(XMLS_PATH, xml_file_name)
45 |
46 |         if xml_path and not xml_path.parent.exists():
47 |             os.makedirs(xml_path.parent, exist_ok=True)
48 |
49 |         pdf_features: PdfFeatures = PdfFeatures.from_pdf_path(pdf_path, xml_path)
50 |
51 |         if pdf_name:
52 |             pdf_features.file_name = pdf_name
53 |         else:
54 |             pdf_name = Path(pdf_path).parent.name if Path(pdf_path).name == "document.pdf" else Path(pdf_path).stem
55 |             pdf_features.file_name = pdf_name
56 |         pdf_images = convert_from_path(pdf_path, dpi=dpi)
57 |         return PdfImages(pdf_features, pdf_images, dpi)
58 |
--------------------------------------------------------------------------------
/src/domain/PdfSegment.py:
--------------------------------------------------------------------------------
 1 | from statistics import mode
 2 | from pdf_features import PdfToken
 3 | from pdf_features import Rectangle
 4 | from pdf_token_type_labels import TokenType
 5 |
 6 |
 7 | class PdfSegment:
 8 |     def __init__(
 9 |         self, page_number: int, bounding_box: Rectangle, text_content: str, segment_type: TokenType, pdf_name: str = ""
10 |     ):
11 |         self.page_number = page_number
12 |         self.bounding_box = bounding_box
13 |         self.text_content = text_content
14 |         self.segment_type = segment_type
15 |         self.pdf_name = pdf_name
16 |
17 |     @staticmethod
18 |     def from_pdf_tokens(pdf_tokens: list[PdfToken], pdf_name: str = ""):
19 |         text: str = " ".join([pdf_token.content for pdf_token in pdf_tokens])
20 |         bounding_boxes = [pdf_token.bounding_box for pdf_token in pdf_tokens]
21 |         segment_type = mode([token.token_type for token in pdf_tokens])
22 |         return PdfSegment(
23 |             pdf_tokens[0].page_number, Rectangle.merge_rectangles(bounding_boxes), text, segment_type, pdf_name
24 |         )
25 |
--------------------------------------------------------------------------------
/src/domain/Prediction.py:
--------------------------------------------------------------------------------
1 | from pdf_features import Rectangle
2 |
3 |
4 | class Prediction:
5 |     def __init__(self, bounding_box: Rectangle, category_id: int, score: float):
6 |         self.bounding_box: Rectangle = bounding_box
7 |         self.category_id: int = category_id
8 |         self.score: float = score
9 |
--------------------------------------------------------------------------------
/src/domain/SegmentBox.py:
--------------------------------------------------------------------------------
 1 | from domain.PdfSegment import PdfSegment
 2 | from pdf_features import PdfPage
 3 | from pdf_token_type_labels import TokenType
 4 | from pydantic import BaseModel
 5 |
 6 |
 7 | class SegmentBox(BaseModel):
 8 |     left: float
 9 |     top: float
10 |     width: float
11 |     height: float
12 |     page_number: int
13 |     page_width: int
14 |     page_height: int
15 |     text: str = ""
16 |     type: TokenType = TokenType.TEXT
17 |     id: str = ""
18 |
19 |     def __hash__(self):
20 |         return hash(
21 |             (
22 |                 self.left,
23 |                 self.top,
24 |                 self.width,
25 |                 self.height,
26 |                 self.page_number,
27 |                 self.page_width,
28 |                 self.page_height,
29 |                 self.text,
30 |                 self.type,
31 |                 self.id,
32 |             )
33 |         )
34 |
35 |     def to_dict(self):
36 |         return {
37 |             "left": self.left,
38 |             "top": self.top,
39 |             "width": self.width,
40 |             "height": self.height,
41 |             "page_number": self.page_number,
42 |             "page_width": self.page_width,
43 |             "page_height": self.page_height,
44 |             "text": self.text,
45 |             "type": self.type.value,
46 |         }
47 |
48 |     @staticmethod
49 |     def from_pdf_segment(pdf_segment: PdfSegment, pdf_pages: list[PdfPage]):
50 |         return SegmentBox(
51 |             left=pdf_segment.bounding_box.left,
52 |             top=pdf_segment.bounding_box.top,
53 |             width=pdf_segment.bounding_box.width,
54 |             height=pdf_segment.bounding_box.height,
55 |             page_number=pdf_segment.page_number,
56 |             page_width=pdf_pages[pdf_segment.page_number - 1].page_width,
57 |             page_height=pdf_pages[pdf_segment.page_number - 1].page_height,
58 |             text=pdf_segment.text_content,
59 |             type=pdf_segment.segment_type,
60 |         )
61 |
62 |
63 | if __name__ == "__main__":
64 |     a = TokenType.TEXT
65 |     print(a.value)
66 |
--------------------------------------------------------------------------------
/src/download_models.py:
--------------------------------------------------------------------------------
 1 | import math
 2 | from os import makedirs
 3 | from os.path import join, exists
 4 | from pathlib import Path
 5 | from urllib.request import urlretrieve
 6 | from huggingface_hub import snapshot_download, hf_hub_download
 7 |
 8 | from configuration import service_logger, MODELS_PATH
 9 |
10 |
11 | def download_progress(count, block_size, total_size):
12 |     """``urlretrieve`` reporthook that logs download progress roughly every 20%.
13 |
14 |     ``count`` is the number of blocks transferred so far, ``block_size`` the block size in
15 |     bytes and ``total_size`` the size reported by the server (-1 when it is unknown).
16 |     """
17 |     if total_size <= 0 or block_size <= 0:
18 |         # Unknown size: no meaningful percentage can be computed.
19 |         return
20 |     total_counts = total_size // block_size
21 |     # At least 1, so files shorter than five blocks do not cause a modulo by zero below.
22 |     show_counts_percentages = max(total_counts // 5, 1)
23 |     percent = count * block_size * 100 / total_size
24 |     if count % show_counts_percentages == 0:
25 |         # The last block is usually partial, so the raw percentage can overshoot 100.
26 |         service_logger.info(f"Downloaded {min(math.ceil(percent), 100)}%")
27 |
28 |
29 | def download_vgt_model(model_name: str):
30 |     service_logger.info(f"Downloading {model_name} model")
31 |     model_path = join(MODELS_PATH, f"{model_name}_VGT_model.pth")
32 |     if exists(model_path):
33 |         return
34 |     download_link = f"https://github.com/AlibabaResearch/AdvancedLiterateMachinery/releases/download/v1.3.0-VGT-release/{model_name}_VGT_model.pth"
35 |     # Download under a temporary name and rename only on success, so an interrupted transfer
36 |     # is never mistaken for a complete model by the exists() check above on the next run.
37 |     partial_path = Path(f"{model_path}.part")
38 |     urlretrieve(download_link, partial_path, reporthook=download_progress)
39 |     partial_path.replace(model_path)
40 |
41 |
42 | def download_embedding_model():
43 |     model_path = join(MODELS_PATH, "layoutlm-base-uncased")
44 |     if exists(model_path):
45 |         return
46 |     makedirs(model_path, exist_ok=True)
47 |     service_logger.info("Embedding model is being downloaded")
48 |     snapshot_download(repo_id="microsoft/layoutlm-base-uncased", local_dir=model_path, local_dir_use_symlinks=False)
49 |
50 |
51 | def download_from_hf_hub(path: Path):
52 |     if path.exists():
53 |         return
54 |
55 |     file_name = path.name
56 |     makedirs(path.parent, exist_ok=True)
57 |     repo_id = "HURIDOCS/pdf-document-layout-analysis"
58 |     hf_hub_download(repo_id=repo_id, filename=file_name, local_dir=path.parent, local_dir_use_symlinks=False)
59 |
60 |
61 | def download_lightgbm_models():
62 |     download_from_hf_hub(Path(MODELS_PATH, "token_type_lightgbm.model"))
63 |     download_from_hf_hub(Path(MODELS_PATH, "paragraph_extraction_lightgbm.model"))
64 |     download_from_hf_hub(Path(MODELS_PATH, "config.json"))
65 |
66 |
67 | def download_models(model_name: str):
68 |     makedirs(MODELS_PATH, exist_ok=True)
69 |     if model_name == "fast":
70 |         download_lightgbm_models()
71 |         return
72 |     download_vgt_model(model_name)
73 |     download_embedding_model()
74 |
75 |
76 | if __name__ == "__main__":
77 |     download_models("doclaynet")
78 |     download_models("fast")
79 |
--------------------------------------------------------------------------------
/src/drivers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/drivers/__init__.py
--------------------------------------------------------------------------------
/src/drivers/web/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/drivers/web/__init__.py
--------------------------------------------------------------------------------
/src/drivers/web/dependency_injection.py:
--------------------------------------------------------------------------------
 1 | from adapters.storage.file_system_repository import FileSystemRepository
 2 | from adapters.ml.vgt_model_adapter import VGTModelAdapter
 3 | from adapters.ml.fast_trainer_adapter import FastTrainerAdapter
 4 | from adapters.infrastructure.pdf_analysis_service_adapter import PDFAnalysisServiceAdapter
 5 | from adapters.infrastructure.text_extraction_adapter import TextExtractionAdapter
 6 | from adapters.infrastructure.toc_service_adapter import TOCServiceAdapter
 7 | from adapters.infrastructure.visualization_service_adapter import VisualizationServiceAdapter
 8 | from adapters.infrastructure.ocr_service_adapter import OCRServiceAdapter
 9 | from adapters.infrastructure.format_conversion_service_adapter import FormatConversionServiceAdapter
10 | from adapters.infrastructure.markdown_conversion_service_adapter import MarkdownConversionServiceAdapter
11 | from adapters.infrastructure.html_conversion_service_adapter import HtmlConversionServiceAdapter
12 | from adapters.web.fastapi_controllers import FastAPIControllers
13 | from use_cases.pdf_analysis.analyze_pdf_use_case import AnalyzePDFUseCase
14 | from use_cases.text_extraction.extract_text_use_case import ExtractTextUseCase
15 | from use_cases.toc_extraction.extract_toc_use_case import ExtractTOCUseCase
16 | from use_cases.visualization.create_visualization_use_case import CreateVisualizationUseCase
17 | from use_cases.ocr.process_ocr_use_case import ProcessOCRUseCase
18 | from use_cases.markdown_conversion.convert_to_markdown_use_case import ConvertToMarkdownUseCase
19 | from use_cases.html_conversion.convert_to_html_use_case import ConvertToHtmlUseCase
20 |
21 |
22 | def setup_dependencies():
23 |     file_repository = FileSystemRepository()
24 |
25 |     vgt_model_service = VGTModelAdapter()
26 |     fast_model_service = FastTrainerAdapter()
27 |
28 |     format_conversion_service = FormatConversionServiceAdapter()
29 |     markdown_conversion_service = MarkdownConversionServiceAdapter()
30 |     html_conversion_service = HtmlConversionServiceAdapter()
31 |     text_extraction_service = TextExtractionAdapter()
32 |     toc_service = TOCServiceAdapter()
33 |     visualization_service = VisualizationServiceAdapter()
34 |     ocr_service = OCRServiceAdapter()
35 |
36 |     pdf_analysis_service = PDFAnalysisServiceAdapter(
37 |         vgt_model_service=vgt_model_service,
38 |         fast_model_service=fast_model_service,
39 |         format_conversion_service=format_conversion_service,
40 |         file_repository=file_repository,
41 |     )
42 |
43 |     analyze_pdf_use_case = AnalyzePDFUseCase(pdf_analysis_service=pdf_analysis_service, ml_model_service=vgt_model_service)
44 |
45 |     extract_text_use_case = ExtractTextUseCase(
46 |         pdf_analysis_service=pdf_analysis_service, text_extraction_service=text_extraction_service
47 |     )
48 |
49 |     extract_toc_use_case = ExtractTOCUseCase(pdf_analysis_service=pdf_analysis_service, toc_service=toc_service)
50 |
51 |     create_visualization_use_case = CreateVisualizationUseCase(
52 |         pdf_analysis_service=pdf_analysis_service, visualization_service=visualization_service
53 |     )
54 |
55 |     process_ocr_use_case = ProcessOCRUseCase(ocr_service=ocr_service, file_repository=file_repository)
56 |
57 |     convert_to_markdown_use_case = ConvertToMarkdownUseCase(
58 |         pdf_analysis_service=pdf_analysis_service, markdown_conversion_service=markdown_conversion_service
59 |     )
60 |
61 |     convert_to_html_use_case = ConvertToHtmlUseCase(
62 |         pdf_analysis_service=pdf_analysis_service, html_conversion_service=html_conversion_service
63 |     )
64 |
65 |     controllers = FastAPIControllers(
66 |         analyze_pdf_use_case=analyze_pdf_use_case,
67 |         extract_text_use_case=extract_text_use_case,
68 |         extract_toc_use_case=extract_toc_use_case,
69 |         create_visualization_use_case=create_visualization_use_case,
70 |         process_ocr_use_case=process_ocr_use_case,
71 |         convert_to_markdown_use_case=convert_to_markdown_use_case,
72 |         convert_to_html_use_case=convert_to_html_use_case,
73 |         file_repository=file_repository,
74 |     )
75 |
76 |     return controllers
77 |
--------------------------------------------------------------------------------
/src/drivers/web/fastapi_app.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | from fastapi import FastAPI
 3 | from fastapi.responses import PlainTextResponse
 4 | from adapters.web.fastapi_controllers import FastAPIControllers
 5 | from catch_exceptions import catch_exceptions
 6 | from configuration import service_logger
 7 |
 8 |
 9 | def create_app(controllers: FastAPIControllers) -> FastAPI:
10 |     """Build the FastAPI application and bind every HTTP route to its controller method.
11 |
12 |     The diagnostic routes are bound to the bare controller methods. Every document-processing
13 |     route is wrapped with ``catch_exceptions``. Routes are registered in the order listed.
14 |     """
15 |     gpu_available = torch.cuda.is_available()
16 |     service_logger.info(f"Is PyTorch using GPU: {gpu_available}")
17 |
18 |     app = FastAPI()
19 |
20 |     diagnostic_routes = (
21 |         ("/", controllers.root),
22 |         ("/info", controllers.info),
23 |         ("/error", controllers.error),
24 |     )
25 |     for path, handler in diagnostic_routes:
26 |         app.get(path)(handler)
27 |
28 |     # (route decorator factory, path, controller method, extra route options)
29 |     wrapped_routes = (
30 |         (app.post, "/", controllers.analyze_pdf, {}),
31 |         (app.post, "/save_xml/{xml_file_name}", controllers.analyze_and_save_xml, {}),
32 |         (app.get, "/get_xml/{xml_file_name}", controllers.get_xml_by_name, {"response_class": PlainTextResponse}),
33 |         (app.post, "/toc", controllers.get_toc_endpoint, {}),
34 |         (app.post, "/toc_legacy_uwazi_compatible", controllers.toc_legacy_uwazi_compatible, {}),
35 |         (app.post, "/text", controllers.get_text_endpoint, {}),
36 |         (app.post, "/visualize", controllers.get_visualization_endpoint, {}),
37 |         (app.post, "/markdown", controllers.convert_to_markdown_endpoint, {"response_model": None}),
38 |         (app.post, "/html", controllers.convert_to_html_endpoint, {"response_model": None}),
39 |         (app.post, "/ocr", controllers.ocr_pdf_sync, {}),
40 |     )
41 |     for route, path, handler, options in wrapped_routes:
42 |         route(path, **options)(catch_exceptions(handler))
43 |
44 |     return app
45 |
--------------------------------------------------------------------------------
/src/ports/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/ports/__init__.py
--------------------------------------------------------------------------------
/src/ports/repositories/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/ports/repositories/__init__.py
--------------------------------------------------------------------------------
/src/ports/repositories/file_repository.py:
--------------------------------------------------------------------------------
 1 | from abc import ABC, abstractmethod
 2 | from pathlib import Path
 3 | from typing import AnyStr
 4 |
 5 |
 6 | class FileRepository(ABC):
 7 |     """Port for the file persistence operations the services depend on."""
 8 |
 9 |     @abstractmethod
10 |     def save_pdf(self, content: AnyStr, filename: str = "") -> Path:
11 |         """Write raw PDF ``content`` and return the path it was stored at."""
12 |
13 |     @abstractmethod
14 |     def save_xml(self, content: str, filename: str) -> Path:
15 |         """Write XML ``content`` under ``filename`` and return the resulting path."""
16 |
17 |     @abstractmethod
18 |     def get_xml(self, filename: str) -> str:
19 |         """Return the text of a stored XML file; a missing file is signalled with FileNotFoundError (mapped to HTTP 404)."""
20 |
21 |     @abstractmethod
22 |     def delete_file(self, filepath: Path) -> None:
23 |         """Delete the file at ``filepath``."""
24 |
25 |     @abstractmethod
26 |     def cleanup_temp_files(self) -> None:
27 |         """Remove any temporary files the repository created (may be a no-op)."""
28 |
29 |     @abstractmethod
30 |     def save_pdf_to_directory(self, content: AnyStr, filename: str, directory: Path, namespace: str = "") -> Path:
31 |         """Write PDF ``content`` as ``filename`` inside ``directory`` (under ``namespace`` when given) and return its path."""
32 |
33 |     @abstractmethod
34 |     def save_markdown(self, content: str, filepath: Path) -> Path:
35 |         """Write Markdown ``content`` to ``filepath`` and return that path."""
36 |
--------------------------------------------------------------------------------
/src/ports/services/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/ports/services/__init__.py
--------------------------------------------------------------------------------
/src/ports/services/format_conversion_service.py:
--------------------------------------------------------------------------------
 1 | from abc import ABC, abstractmethod
 2 | from domain.PdfImages import PdfImages
 3 | from domain.PdfSegment import PdfSegment
 4 |
 5 |
 6 | class FormatConversionService(ABC):
 7 |     """Port for turning table and formula segments into structured markup."""
 8 |
 9 |     @abstractmethod
10 |     def convert_table_to_html(self, pdf_images: PdfImages, segments: list[PdfSegment]) -> None:
11 |         """Convert table segments to HTML; returns nothing (NOTE(review): presumably updates ``segments`` in place)."""
12 |
13 |     @abstractmethod
14 |     def convert_formula_to_latex(self, pdf_images: PdfImages, segments: list[PdfSegment]) -> None:
15 |         """Convert formula segments to LaTeX; returns nothing (NOTE(review): presumably updates ``segments`` in place)."""
16 |
--------------------------------------------------------------------------------
/src/ports/services/html_conversion_service.py:
--------------------------------------------------------------------------------
 1 | from abc import ABC, abstractmethod
 2 | from typing import Optional, Union
 3 | from starlette.responses import Response
 4 | from domain.SegmentBox import SegmentBox
 5 |
 6 |
 7 | class HtmlConversionService(ABC):
 8 |     """Port for rendering an analysed PDF as HTML."""
 9 |
10 |     @abstractmethod
11 |     def convert_to_html(
12 |         self,
13 |         pdf_content: bytes,
14 |         segments: list[SegmentBox],
15 |         extract_toc: bool = False,
16 |         dpi: int = 120,
17 |         output_file: Optional[str] = None,
18 |         target_languages: Optional[list[str]] = None,
19 |         translation_model: str = "gpt-oss",
20 |     ) -> Union[str, Response]:
21 |         """Build HTML from ``pdf_content`` and its layout ``segments``.
22 |
23 |         Returns either the HTML text or a ``Response`` (NOTE(review): likely chosen by ``output_file`` - confirm).
24 |         """
25 |
--------------------------------------------------------------------------------
/src/ports/services/markdown_conversion_service.py:
--------------------------------------------------------------------------------
 1 | from abc import ABC, abstractmethod
 2 | from typing import Optional, Union
 3 | from starlette.responses import Response
 4 | from domain.SegmentBox import SegmentBox
 5 |
 6 |
 7 | class MarkdownConversionService(ABC):
 8 |     """Port for rendering an analysed PDF as Markdown."""
 9 |
10 |     @abstractmethod
11 |     def convert_to_markdown(
12 |         self,
13 |         pdf_content: bytes,
14 |         segments: list[SegmentBox],
15 |         extract_toc: bool = False,
16 |         dpi: int = 120,
17 |         output_file: Optional[str] = None,
18 |         target_languages: Optional[list[str]] = None,
19 |         translation_model: str = "gpt-oss",
20 |     ) -> Union[str, Response]:
21 |         """Build Markdown from ``pdf_content`` and its layout ``segments``.
22 |
23 |         Returns either the Markdown text or a ``Response`` (NOTE(review): likely chosen by ``output_file`` - confirm).
24 |         """
25 |
--------------------------------------------------------------------------------
/src/ports/services/ml_model_service.py:
--------------------------------------------------------------------------------
 1 | from abc import ABC, abstractmethod
 2 | from domain.PdfImages import PdfImages
 3 | from domain.PdfSegment import PdfSegment
 4 |
 5 |
 6 | class MLModelService(ABC):
 7 |     """Port for layout-prediction models that turn page images into segments."""
 8 |
 9 |     @abstractmethod
10 |     def predict_document_layout(self, pdf_images: list[PdfImages]) -> list[PdfSegment]:
11 |         """Predict the layout segments of the given documents."""
12 |
13 |     @abstractmethod
14 |     def predict_layout_fast(self, pdf_images: list[PdfImages]) -> list[PdfSegment]:
15 |         """Fast-mode counterpart of ``predict_document_layout``; adapters without a fast path may raise NotImplementedError."""
16 |
--------------------------------------------------------------------------------
/src/ports/services/ocr_service.py:
--------------------------------------------------------------------------------
 1 | from abc import ABC, abstractmethod
 2 | from pathlib import Path
 3 |
 4 |
 5 | class OCRService(ABC):
 6 |     """Port for OCR processing of PDF files."""
 7 |
 8 |     @abstractmethod
 9 |     def process_pdf_ocr(self, filename: str, namespace: str, language: str = "en") -> Path:
10 |         """OCR the PDF ``filename`` within ``namespace`` using ``language`` and return the path of the result."""
11 |
12 |     @abstractmethod
13 |     def get_supported_languages(self) -> list[str]:
14 |         """Return the language codes accepted by ``process_pdf_ocr``."""
15 |
--------------------------------------------------------------------------------
/src/ports/services/pdf_analysis_service.py:
--------------------------------------------------------------------------------
 1 | from abc import ABC, abstractmethod
 2 | from typing import AnyStr
 3 |
 4 |
 5 | class PDFAnalysisService(ABC):
 6 |     """Port for running layout analysis over raw PDF content."""
 7 |
 8 |     @abstractmethod
 9 |     def analyze_pdf_layout(
10 |         self, pdf_content: AnyStr, xml_filename: str = "", parse_tables_and_math: bool = False, keep_pdf: bool = False
11 |     ) -> list[dict]:
12 |         """Analyse ``pdf_content`` with the full layout model and return the detected segments as dicts."""
13 |
14 |     @abstractmethod
15 |     def analyze_pdf_layout_fast(
16 |         self, pdf_content: AnyStr, xml_filename: str = "", parse_tables_and_math: bool = False, keep_pdf: bool = False
17 |     ) -> list[dict]:
18 |         """Same contract as ``analyze_pdf_layout`` but using the fast analysis path."""
19 |
--------------------------------------------------------------------------------
/src/ports/services/text_extraction_service.py:
--------------------------------------------------------------------------------
 1 | from abc import ABC, abstractmethod
 2 | from pdf_token_type_labels import TokenType
 3 |
 4 |
 5 | class TextExtractionService(ABC):
 6 |     """Port for pulling text out of analysed segment boxes."""
 7 |
 8 |     @abstractmethod
 9 |     def extract_text_by_types(self, segment_boxes: list[dict], token_types: list[TokenType]) -> dict:
10 |         """Extract the text of the segments whose type is in ``token_types``."""
11 |
12 |     @abstractmethod
13 |     def extract_all_text(self, segment_boxes: list[dict]) -> dict:
14 |         """Extract the text of every segment."""
15 |
--------------------------------------------------------------------------------
/src/ports/services/toc_service.py:
--------------------------------------------------------------------------------
 1 | from abc import ABC, abstractmethod
 2 | from typing import AnyStr
 3 |
 4 |
 5 | class TOCService(ABC):
 6 |     """Port for table-of-contents extraction."""
 7 |
 8 |     @abstractmethod
 9 |     def extract_table_of_contents(self, pdf_content: AnyStr, segment_boxes: list[dict]) -> list[dict]:
10 |         """Build the table of contents of ``pdf_content`` from its ``segment_boxes``."""
11 |
12 |     @abstractmethod
13 |     def format_toc_for_uwazi(self, toc_items: list[dict]) -> list[dict]:
14 |         """Reshape ``toc_items`` into the format expected by Uwazi."""
15 |
--------------------------------------------------------------------------------
/src/ports/services/visualization_service.py:
--------------------------------------------------------------------------------
 1 | from abc import ABC, abstractmethod
 2 | from pathlib import Path
 3 | from starlette.responses import FileResponse
 4 |
 5 |
 6 | class VisualizationService(ABC):
 7 |     """Port for producing annotated PDF visualisations of the detected segments."""
 8 |
 9 |     @abstractmethod
10 |     def create_pdf_visualization(self, pdf_path: Path, segment_boxes: list[dict]) -> Path:
11 |         """Draw ``segment_boxes`` onto the PDF at ``pdf_path`` and return the path of the output PDF."""
12 |
13 |     @abstractmethod
14 |     def get_visualization_response(self, pdf_path: Path) -> FileResponse:
15 |         """Wrap the visualisation PDF at ``pdf_path`` in a ``FileResponse``."""
16 |
--------------------------------------------------------------------------------
/src/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/tests/__init__.py -------------------------------------------------------------------------------- /src/use_cases/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/use_cases/__init__.py -------------------------------------------------------------------------------- /src/use_cases/html_conversion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/use_cases/html_conversion/__init__.py -------------------------------------------------------------------------------- /src/use_cases/html_conversion/convert_to_html_use_case.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | from starlette.responses import Response 3 | from ports.services.html_conversion_service import HtmlConversionService 4 | from ports.services.pdf_analysis_service import PDFAnalysisService 5 | from domain.SegmentBox import SegmentBox 6 | 7 | 8 | class ConvertToHtmlUseCase: 9 | def __init__( 10 | self, 11 | pdf_analysis_service: PDFAnalysisService, 12 | html_conversion_service: HtmlConversionService, 13 | ): 14 | self.pdf_analysis_service = pdf_analysis_service 15 | self.html_conversion_service = html_conversion_service 16 | 17 | def execute( 18 | self, 19 | pdf_content: bytes, 20 | use_fast_mode: bool = False, 21 | extract_toc: bool = False, 22 | dpi: int = 120, 23 | output_file: Optional[str] = None, 24 | target_languages: Optional[list[str]] = None, 25 | translation_model: str = "gpt-oss", 26 | ) -> Union[str, Response]: 27 | if use_fast_mode: 28 | analysis_result = 
self.pdf_analysis_service.analyze_pdf_layout_fast(pdf_content, "", True, False) 29 | else: 30 | analysis_result = self.pdf_analysis_service.analyze_pdf_layout(pdf_content, "", True, False) 31 | 32 | segments: list[SegmentBox] = [] 33 | for item in analysis_result: 34 | if isinstance(item, dict): 35 | segment = SegmentBox( 36 | left=item.get("left", 0), 37 | top=item.get("top", 0), 38 | width=item.get("width", 0), 39 | height=item.get("height", 0), 40 | page_number=item.get("page_number", 1), 41 | page_width=item.get("page_width", 0), 42 | page_height=item.get("page_height", 0), 43 | text=item.get("text", ""), 44 | type=item.get("type", "TEXT"), 45 | ) 46 | segments.append(segment) 47 | elif isinstance(item, SegmentBox): 48 | segments.append(item) 49 | 50 | return self.html_conversion_service.convert_to_html( 51 | pdf_content, segments, extract_toc, dpi, output_file, target_languages, translation_model 52 | ) 53 | -------------------------------------------------------------------------------- /src/use_cases/markdown_conversion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/use_cases/markdown_conversion/__init__.py -------------------------------------------------------------------------------- /src/use_cases/markdown_conversion/convert_to_markdown_use_case.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | from starlette.responses import Response 3 | from ports.services.markdown_conversion_service import MarkdownConversionService 4 | from ports.services.pdf_analysis_service import PDFAnalysisService 5 | from domain.SegmentBox import SegmentBox 6 | 7 | 8 | class ConvertToMarkdownUseCase: 9 | def __init__( 10 | self, 11 | pdf_analysis_service: PDFAnalysisService, 12 | markdown_conversion_service: MarkdownConversionService, 13 | ): 
14 | self.pdf_analysis_service = pdf_analysis_service 15 | self.markdown_conversion_service = markdown_conversion_service 16 | 17 | def execute( 18 | self, 19 | pdf_content: bytes, 20 | use_fast_mode: bool = False, 21 | extract_toc: bool = False, 22 | dpi: int = 120, 23 | output_file: Optional[str] = None, 24 | target_languages: Optional[list[str]] = None, 25 | translation_model: str = "gpt-oss", 26 | ) -> Union[str, Response]: 27 | if use_fast_mode: 28 | analysis_result = self.pdf_analysis_service.analyze_pdf_layout_fast(pdf_content, "", True, False) 29 | else: 30 | analysis_result = self.pdf_analysis_service.analyze_pdf_layout(pdf_content, "", True, False) 31 | 32 | segments: list[SegmentBox] = [] 33 | for item in analysis_result: 34 | if isinstance(item, dict): 35 | segment = SegmentBox( 36 | left=item.get("left", 0), 37 | top=item.get("top", 0), 38 | width=item.get("width", 0), 39 | height=item.get("height", 0), 40 | page_number=item.get("page_number", 1), 41 | page_width=item.get("page_width", 0), 42 | page_height=item.get("page_height", 0), 43 | text=item.get("text", ""), 44 | type=item.get("type", "TEXT"), 45 | ) 46 | segments.append(segment) 47 | elif isinstance(item, SegmentBox): 48 | segments.append(item) 49 | 50 | return self.markdown_conversion_service.convert_to_markdown( 51 | pdf_content, segments, extract_toc, dpi, output_file, target_languages, translation_model 52 | ) 53 | -------------------------------------------------------------------------------- /src/use_cases/ocr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/use_cases/ocr/__init__.py -------------------------------------------------------------------------------- /src/use_cases/ocr/process_ocr_use_case.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from 
fastapi import UploadFile 3 | from starlette.responses import FileResponse 4 | from ports.services.ocr_service import OCRService 5 | from ports.repositories.file_repository import FileRepository 6 | from configuration import OCR_SOURCE 7 | 8 | 9 | class ProcessOCRUseCase: 10 | def __init__(self, ocr_service: OCRService, file_repository: FileRepository): 11 | self.ocr_service = ocr_service 12 | self.file_repository = file_repository 13 | 14 | def execute(self, file: UploadFile, language: str = "en") -> FileResponse: 15 | namespace = "sync_pdfs" 16 | 17 | self.file_repository.save_pdf_to_directory( 18 | content=file.file.read(), filename=file.filename, directory=Path(OCR_SOURCE), namespace=namespace 19 | ) 20 | 21 | processed_pdf_filepath = self.ocr_service.process_pdf_ocr(file.filename, namespace, language) 22 | 23 | return FileResponse(path=processed_pdf_filepath, media_type="application/pdf") 24 | 25 | def get_supported_languages(self) -> list: 26 | return self.ocr_service.get_supported_languages() 27 | -------------------------------------------------------------------------------- /src/use_cases/pdf_analysis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/use_cases/pdf_analysis/__init__.py -------------------------------------------------------------------------------- /src/use_cases/pdf_analysis/analyze_pdf_use_case.py: -------------------------------------------------------------------------------- 1 | from typing import AnyStr 2 | from ports.services.pdf_analysis_service import PDFAnalysisService 3 | from ports.services.ml_model_service import MLModelService 4 | 5 | 6 | class AnalyzePDFUseCase: 7 | def __init__( 8 | self, 9 | pdf_analysis_service: PDFAnalysisService, 10 | ml_model_service: MLModelService, 11 | ): 12 | self.pdf_analysis_service = pdf_analysis_service 13 | self.ml_model_service = 
ml_model_service 14 | 15 | def execute( 16 | self, 17 | pdf_content: AnyStr, 18 | xml_filename: str = "", 19 | parse_tables_and_math: bool = False, 20 | use_fast_mode: bool = False, 21 | keep_pdf: bool = False, 22 | ) -> list[dict]: 23 | if use_fast_mode: 24 | return self.pdf_analysis_service.analyze_pdf_layout_fast( 25 | pdf_content, xml_filename, parse_tables_and_math, keep_pdf 26 | ) 27 | else: 28 | return self.pdf_analysis_service.analyze_pdf_layout(pdf_content, xml_filename, parse_tables_and_math, keep_pdf) 29 | 30 | def execute_and_save_xml(self, pdf_content: AnyStr, xml_filename: str, use_fast_mode: bool = False) -> list[dict]: 31 | result = self.execute(pdf_content, xml_filename, False, use_fast_mode, keep_pdf=False) 32 | return result 33 | -------------------------------------------------------------------------------- /src/use_cases/text_extraction/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/use_cases/text_extraction/__init__.py -------------------------------------------------------------------------------- /src/use_cases/text_extraction/extract_text_use_case.py: -------------------------------------------------------------------------------- 1 | from fastapi import UploadFile 2 | from pdf_token_type_labels import TokenType 3 | from ports.services.pdf_analysis_service import PDFAnalysisService 4 | from ports.services.text_extraction_service import TextExtractionService 5 | 6 | 7 | class ExtractTextUseCase: 8 | def __init__(self, pdf_analysis_service: PDFAnalysisService, text_extraction_service: TextExtractionService): 9 | self.pdf_analysis_service = pdf_analysis_service 10 | self.text_extraction_service = text_extraction_service 11 | 12 | def execute(self, file: UploadFile, use_fast_mode: bool = False, types: str = "all") -> dict: 13 | file_content = file.file.read() 14 | 15 | if types == 
"all": 16 | token_types: list[TokenType] = [t for t in TokenType] 17 | else: 18 | token_types = list(set([TokenType.from_text(t.strip().replace(" ", "_")) for t in types.split(",")])) 19 | 20 | if use_fast_mode: 21 | segment_boxes = self.pdf_analysis_service.analyze_pdf_layout_fast(file_content) 22 | else: 23 | segment_boxes = self.pdf_analysis_service.analyze_pdf_layout(file_content, "") 24 | 25 | return self.text_extraction_service.extract_text_by_types(segment_boxes, token_types) 26 | -------------------------------------------------------------------------------- /src/use_cases/toc_extraction/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/use_cases/toc_extraction/__init__.py -------------------------------------------------------------------------------- /src/use_cases/toc_extraction/extract_toc_use_case.py: -------------------------------------------------------------------------------- 1 | from fastapi import UploadFile 2 | from ports.services.pdf_analysis_service import PDFAnalysisService 3 | from ports.services.toc_service import TOCService 4 | 5 | 6 | class ExtractTOCUseCase: 7 | def __init__(self, pdf_analysis_service: PDFAnalysisService, toc_service: TOCService): 8 | self.pdf_analysis_service = pdf_analysis_service 9 | self.toc_service = toc_service 10 | 11 | def execute(self, file: UploadFile, use_fast_mode: bool = False) -> list[dict]: 12 | file_content = file.file.read() 13 | 14 | if use_fast_mode: 15 | segment_boxes = self.pdf_analysis_service.analyze_pdf_layout_fast(file_content) 16 | else: 17 | segment_boxes = self.pdf_analysis_service.analyze_pdf_layout(file_content, "") 18 | 19 | return self.toc_service.extract_table_of_contents(file_content, segment_boxes) 20 | 21 | def execute_uwazi_compatible(self, file: UploadFile) -> list[dict]: 22 | toc_items = self.execute(file, 
use_fast_mode=True) 23 | return self.toc_service.format_toc_for_uwazi(toc_items) 24 | -------------------------------------------------------------------------------- /src/use_cases/visualization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/use_cases/visualization/__init__.py -------------------------------------------------------------------------------- /src/use_cases/visualization/create_visualization_use_case.py: -------------------------------------------------------------------------------- 1 | from fastapi import UploadFile 2 | from starlette.responses import FileResponse 3 | from ports.services.pdf_analysis_service import PDFAnalysisService 4 | from ports.services.visualization_service import VisualizationService 5 | from glob import glob 6 | from os.path import getctime, join 7 | from tempfile import gettempdir 8 | from pathlib import Path 9 | 10 | 11 | class CreateVisualizationUseCase: 12 | def __init__(self, pdf_analysis_service: PDFAnalysisService, visualization_service: VisualizationService): 13 | self.pdf_analysis_service = pdf_analysis_service 14 | self.visualization_service = visualization_service 15 | 16 | def execute(self, file: UploadFile, use_fast_mode: bool = False) -> FileResponse: 17 | file_content = file.file.read() 18 | 19 | if use_fast_mode: 20 | segment_boxes = self.pdf_analysis_service.analyze_pdf_layout_fast(file_content, "", "", True) 21 | else: 22 | segment_boxes = self.pdf_analysis_service.analyze_pdf_layout(file_content, "", "", True) 23 | 24 | pdf_path = Path(max(glob(join(gettempdir(), "*.pdf")), key=getctime)) 25 | visualization_path = self.visualization_service.create_pdf_visualization(pdf_path, segment_boxes) 26 | 27 | return self.visualization_service.get_visualization_response(visualization_path) 28 | 
-------------------------------------------------------------------------------- /start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | gunicorn -k uvicorn.workers.UvicornWorker --chdir ./src app:app --bind 0.0.0.0:5060 --timeout 10000 -------------------------------------------------------------------------------- /test_pdfs/blank.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/blank.pdf -------------------------------------------------------------------------------- /test_pdfs/chinese.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/chinese.pdf -------------------------------------------------------------------------------- /test_pdfs/error.pdf: -------------------------------------------------------------------------------- 1 | error -------------------------------------------------------------------------------- /test_pdfs/formula.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/formula.pdf -------------------------------------------------------------------------------- /test_pdfs/image.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/image.pdf -------------------------------------------------------------------------------- /test_pdfs/korean.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/korean.pdf -------------------------------------------------------------------------------- /test_pdfs/not_a_pdf.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/not_a_pdf.pdf -------------------------------------------------------------------------------- /test_pdfs/ocr-sample-already-ocred.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/ocr-sample-already-ocred.pdf -------------------------------------------------------------------------------- /test_pdfs/ocr-sample-english.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/ocr-sample-english.pdf -------------------------------------------------------------------------------- /test_pdfs/ocr-sample-french.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/ocr-sample-french.pdf -------------------------------------------------------------------------------- /test_pdfs/ocr_pdf.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/ocr_pdf.pdf -------------------------------------------------------------------------------- /test_pdfs/regular.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/regular.pdf -------------------------------------------------------------------------------- /test_pdfs/some_empty_pages.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/some_empty_pages.pdf -------------------------------------------------------------------------------- /test_pdfs/table.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/table.pdf -------------------------------------------------------------------------------- /test_pdfs/test.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/test.pdf -------------------------------------------------------------------------------- /test_pdfs/toc-test.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/toc-test.pdf --------------------------------------------------------------------------------