├── .dockerignore
├── .github
├── FUNDING.yml
├── dependabot.yml
└── workflows
│ ├── push_docker_image.yml
│ └── test.yml
├── .gitignore
├── .well-known
└── funding-manifest-urls
├── Dockerfile
├── Dockerfile.ollama
├── LICENSE
├── Makefile
├── README.md
├── dev-requirements.txt
├── docker-compose-gpu.yml
├── docker-compose.yml
├── fine_tuning_lightgbm_models.ipynb
├── images
├── vgtexample1.png
├── vgtexample2.png
├── vgtexample3.png
└── vgtexample4.png
├── justfile
├── pyproject.toml
├── requirements.txt
├── src
├── adapters
│ ├── __init__.py
│ ├── infrastructure
│ │ ├── __init__.py
│ │ ├── format_conversion_service_adapter.py
│ │ ├── format_converters
│ │ │ ├── __init__.py
│ │ │ ├── convert_formula_to_latex.py
│ │ │ └── convert_table_to_html.py
│ │ ├── html_conversion_service_adapter.py
│ │ ├── markdown_conversion_service_adapter.py
│ │ ├── markup_conversion
│ │ │ ├── ExtractedImage.py
│ │ │ ├── Link.py
│ │ │ ├── OutputFormat.py
│ │ │ ├── __init__.py
│ │ │ └── pdf_to_markup_service_adapter.py
│ │ ├── ocr
│ │ │ ├── __init__.py
│ │ │ └── languages.py
│ │ ├── ocr_service_adapter.py
│ │ ├── pdf_analysis_service_adapter.py
│ │ ├── text_extraction_adapter.py
│ │ ├── toc
│ │ │ ├── MergeTwoSegmentsTitles.py
│ │ │ ├── PdfSegmentation.py
│ │ │ ├── TOCExtractor.py
│ │ │ ├── TitleFeatures.py
│ │ │ ├── __init__.py
│ │ │ ├── data
│ │ │ │ ├── TOCItem.py
│ │ │ │ └── __init__.py
│ │ │ ├── extract_table_of_contents.py
│ │ │ └── methods
│ │ │ │ ├── __init__.py
│ │ │ │ └── two_models_v3_segments_context_2
│ │ │ │ ├── Modes.py
│ │ │ │ └── __init__.py
│ │ ├── toc_service_adapter.py
│ │ ├── translation
│ │ │ ├── decode_html_content.py
│ │ │ ├── decode_markdown_content.py
│ │ │ ├── download_translation_model.py
│ │ │ ├── encode_html_content.py
│ │ │ ├── encode_markdown_content.py
│ │ │ ├── ollama_container_manager.py
│ │ │ └── translate_markup_document.py
│ │ └── visualization_service_adapter.py
│ ├── ml
│ │ ├── __init__.py
│ │ ├── fast_trainer
│ │ │ ├── Paragraph.py
│ │ │ ├── ParagraphExtractorTrainer.py
│ │ │ ├── __init__.py
│ │ │ └── model_configuration.py
│ │ ├── fast_trainer_adapter.py
│ │ ├── pdf_tokens_type_trainer
│ │ │ ├── ModelConfiguration.py
│ │ │ ├── PdfTrainer.py
│ │ │ ├── TokenFeatures.py
│ │ │ ├── TokenTypeTrainer.py
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ ├── download_models.py
│ │ │ ├── get_paths.py
│ │ │ └── tests
│ │ │ │ ├── __init__.py
│ │ │ │ └── test_trainer.py
│ │ ├── vgt
│ │ │ ├── __init__.py
│ │ │ ├── bros
│ │ │ │ ├── __init__.py
│ │ │ │ ├── configuration_bros.py
│ │ │ │ ├── modeling_bros.py
│ │ │ │ ├── tokenization_bros.py
│ │ │ │ └── tokenization_bros_fast.py
│ │ │ ├── create_word_grid.py
│ │ │ ├── ditod
│ │ │ │ ├── FeatureMerge.py
│ │ │ │ ├── VGT.py
│ │ │ │ ├── VGTTrainer.py
│ │ │ │ ├── VGTbackbone.py
│ │ │ │ ├── VGTbeit.py
│ │ │ │ ├── VGTcheckpointer.py
│ │ │ │ ├── Wordnn_embedding.py
│ │ │ │ ├── __init__.py
│ │ │ │ ├── config.py
│ │ │ │ ├── dataset_mapper.py
│ │ │ │ ├── tokenization_bros.py
│ │ │ │ └── utils.py
│ │ │ ├── get_json_annotations.py
│ │ │ ├── get_model_configuration.py
│ │ │ ├── get_most_probable_pdf_segments.py
│ │ │ ├── get_reading_orders.py
│ │ │ └── model_configuration
│ │ │ │ ├── Base-RCNN-FPN.yaml
│ │ │ │ ├── doclaynet_VGT_cascade_PTM.yaml
│ │ │ │ └── doclaynet_configuration.pickle
│ │ └── vgt_model_adapter.py
│ ├── storage
│ │ ├── __init__.py
│ │ └── file_system_repository.py
│ └── web
│ │ ├── __init__.py
│ │ └── fastapi_controllers.py
├── app.py
├── catch_exceptions.py
├── configuration.py
├── domain
│ ├── PdfImages.py
│ ├── PdfSegment.py
│ ├── Prediction.py
│ └── SegmentBox.py
├── download_models.py
├── drivers
│ ├── __init__.py
│ └── web
│ │ ├── __init__.py
│ │ ├── dependency_injection.py
│ │ └── fastapi_app.py
├── ports
│ ├── __init__.py
│ ├── repositories
│ │ ├── __init__.py
│ │ └── file_repository.py
│ └── services
│ │ ├── __init__.py
│ │ ├── format_conversion_service.py
│ │ ├── html_conversion_service.py
│ │ ├── markdown_conversion_service.py
│ │ ├── ml_model_service.py
│ │ ├── ocr_service.py
│ │ ├── pdf_analysis_service.py
│ │ ├── text_extraction_service.py
│ │ ├── toc_service.py
│ │ └── visualization_service.py
├── tests
│ ├── __init__.py
│ └── test_end_to_end.py
└── use_cases
│ ├── __init__.py
│ ├── html_conversion
│ ├── __init__.py
│ └── convert_to_html_use_case.py
│ ├── markdown_conversion
│ ├── __init__.py
│ └── convert_to_markdown_use_case.py
│ ├── ocr
│ ├── __init__.py
│ └── process_ocr_use_case.py
│ ├── pdf_analysis
│ ├── __init__.py
│ └── analyze_pdf_use_case.py
│ ├── text_extraction
│ ├── __init__.py
│ └── extract_text_use_case.py
│ ├── toc_extraction
│ ├── __init__.py
│ └── extract_toc_use_case.py
│ └── visualization
│ ├── __init__.py
│ └── create_visualization_use_case.py
├── start.sh
└── test_pdfs
├── blank.pdf
├── chinese.pdf
├── error.pdf
├── formula.pdf
├── image.pdf
├── korean.pdf
├── not_a_pdf.pdf
├── ocr-sample-already-ocred.pdf
├── ocr-sample-english.pdf
├── ocr-sample-french.pdf
├── ocr_pdf.pdf
├── regular.pdf
├── some_empty_pages.pdf
├── table.pdf
├── test.pdf
└── toc-test.pdf
/.dockerignore:
--------------------------------------------------------------------------------
1 | /venv/
2 | /.venv/
3 | .git
4 | /detectron2/
5 | /images/
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | custom: ["https://huridocs.org/donate/"]
2 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: "pip"
4 | directory: "/"
5 | schedule:
6 | interval: "daily"
7 | open-pull-requests-limit: 5
8 | labels:
9 | - "dependencies"
10 | - package-ecosystem: "github-actions"
11 | directory: "/"
12 | schedule:
13 | interval: "daily"
14 | - package-ecosystem: "docker"
15 | directory: "/"
16 | schedule:
17 | interval: "daily"
18 |
--------------------------------------------------------------------------------
/.github/workflows/push_docker_image.yml:
--------------------------------------------------------------------------------
1 | name: Create and publish Docker image
2 |
3 | on:
4 | push:
5 | tags:
6 | - 'v*'
7 |
8 | env:
9 | REGISTRY: ghcr.io
10 | IMAGE_NAME: huridocs/pdf-document-layout-analysis
11 |
12 | jobs:
13 | build-and-push-image:
14 | runs-on: ubuntu-latest
15 | permissions:
16 | contents: read
17 | packages: write
18 | steps:
19 | - name: Checkout repository
20 | uses: actions/checkout@v4
21 |
22 | - name: Install dependencies
23 | run: sudo apt-get install -y just
24 |
25 | - name: Log in to the Container registry
26 | uses: docker/login-action@v3
27 | with:
28 | registry: ${{ env.REGISTRY }}
29 | username: ${{ github.actor }}
30 | password: ${{ secrets.GITHUB_TOKEN }}
31 |
32 | - name: Extract metadata (tags, labels) for Docker
33 | id: meta
34 | uses: docker/metadata-action@v5
35 | with:
36 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
37 | tags: |
38 | type=ref,event=branch
39 | type=ref,event=pr
40 | type=semver,pattern={{version}}
41 | type=semver,pattern={{major}}.{{minor}}
42 |
43 | - name: Create folder models
44 | run: mkdir -p models
45 |
46 | - name: Build and push
47 | uses: docker/build-push-action@v6
48 | with:
49 | context: .
50 | file: Dockerfile
51 | push: ${{ github.event_name != 'pull_request' }}
52 | tags: ${{ steps.meta.outputs.tags }}
53 | labels: ${{ steps.meta.outputs.labels }}
54 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Test
5 |
6 | on:
7 | push:
8 | branches: [ main ]
9 | pull_request:
10 | branches: [ main ]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 |
17 | steps:
18 | - uses: actions/checkout@v4
19 | - name: Set up Python 3.11
20 | uses: actions/setup-python@v5
21 | with:
22 | python-version: '3.11'
23 |
24 | - name: Install dependencies
25 | run: sudo apt-get update; sudo apt-get install -y pdftohtml qpdf just
26 |
27 | - name: Free up space
28 | run: just free_up_space
29 |
30 | - name: Install venv
31 | run: just install_venv
32 |
33 | - name: Lint with black
34 | run: just check_format
35 |
36 | - name: Start service
37 | run: just start_detached
38 |
39 | - name: Check API ready
40 | uses: emilioschepis/wait-for-endpoint@v1.0.3
41 | with:
42 | url: http://localhost:5060
43 | method: GET
44 | expected-status: 200
45 | timeout: 120000
46 | interval: 500
47 |
48 | - name: Test with unittest
49 | run: just test
50 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 | /models/
162 | /word_grids/
163 | /jsons/
164 | /model_output/
165 | /pdf_outputs/
166 | /detectron2/
167 | /ocr/
168 |
--------------------------------------------------------------------------------
/.well-known/funding-manifest-urls:
--------------------------------------------------------------------------------
1 | https://huridocs.org/funding/funding.json
2 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:2.4.0-cuda11.8-cudnn9-runtime
2 | COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
3 |
4 | RUN apt-get update
5 | RUN apt-get install --fix-missing -y -q --no-install-recommends libgomp1 ffmpeg libsm6 pdftohtml libxext6 git ninja-build g++ qpdf pandoc curl
6 |
7 |
8 | RUN apt-get install -y ocrmypdf
9 | RUN apt-get install -y tesseract-ocr-fra
10 | RUN apt-get install -y tesseract-ocr-spa
11 | RUN apt-get install -y tesseract-ocr-deu
12 | RUN apt-get install -y tesseract-ocr-ara
13 | RUN apt-get install -y tesseract-ocr-mya
14 | RUN apt-get install -y tesseract-ocr-hin
15 | RUN apt-get install -y tesseract-ocr-tam
16 | RUN apt-get install -y tesseract-ocr-tha
17 | RUN apt-get install -y tesseract-ocr-chi-sim
18 | RUN apt-get install -y tesseract-ocr-tur
19 | RUN apt-get install -y tesseract-ocr-ukr
20 | RUN apt-get install -y tesseract-ocr-ell
21 | RUN apt-get install -y tesseract-ocr-rus
22 | RUN apt-get install -y tesseract-ocr-kor
23 | RUN apt-get install -y tesseract-ocr-kor-vert
24 |
25 |
26 | RUN mkdir -p /app/src
27 | RUN mkdir -p /app/models
28 |
29 | RUN addgroup --system python && adduser --system --group python
30 | RUN chown -R python:python /app
31 | USER python
32 |
33 | ENV VIRTUAL_ENV=/app/.venv
34 | RUN python -m venv $VIRTUAL_ENV
35 | ENV PATH="$VIRTUAL_ENV/bin:$PATH"
36 |
37 | COPY requirements.txt requirements.txt
38 | RUN uv pip install --upgrade pip
39 | RUN uv pip install -r requirements.txt
40 |
41 | WORKDIR /app
42 |
43 | RUN cd src; git clone https://github.com/facebookresearch/detectron2;
44 | RUN cd src/detectron2; git checkout 70f454304e1a38378200459dd2dbca0f0f4a5ab4; python setup.py build develop
45 | RUN uv pip install pycocotools==2.0.8
46 |
47 | COPY ./start.sh ./start.sh
48 | COPY ./src/. ./src
49 | COPY ./models/. ./models/
50 | RUN python src/download_models.py
51 |
52 | ENV PYTHONPATH="${PYTHONPATH}:/app/src"
53 | ENV TRANSFORMERS_VERBOSITY=error
54 | ENV TRANSFORMERS_NO_ADVISORY_WARNINGS=1
55 |
--------------------------------------------------------------------------------
/Dockerfile.ollama:
--------------------------------------------------------------------------------
1 | FROM ollama/ollama:latest
2 |
3 | RUN apt-get update && apt-get install -y curl && rm -rf /var/lib/apt/lists/*
4 |
5 | ENV OLLAMA_HOST=0.0.0.0:11434
6 |
7 | EXPOSE 11434
8 |
9 | HEALTHCHECK --interval=30s --timeout=10s --start-period=30s --retries=3 \
10 | CMD curl -f http://localhost:11434/api/tags || exit 1
11 |
--------------------------------------------------------------------------------
/dev-requirements.txt:
--------------------------------------------------------------------------------
1 | -r requirements.txt
2 | pytest==8.2.2
3 | black==24.4.2
4 | pip-upgrader==1.4.15
--------------------------------------------------------------------------------
/docker-compose-gpu.yml:
--------------------------------------------------------------------------------
1 | services:
2 | ollama-gpu:
3 | extends:
4 | file: docker-compose.yml
5 | service: ollama
6 | container_name: ollama-service-gpu
7 | deploy:
8 | resources:
9 | reservations:
10 | devices:
11 | - driver: nvidia
12 | count: 1
13 | capabilities: [ gpu ]
14 | environment:
15 | - NVIDIA_VISIBLE_DEVICES=all
16 |
17 | pdf-document-layout-analysis-gpu:
18 | container_name: pdf-document-layout-analysis-gpu
19 | entrypoint: [ "gunicorn", "-k", "uvicorn.workers.UvicornWorker", "--chdir", "./src", "app:app", "--bind", "0.0.0.0:5060", "--timeout", "10000"]
20 | init: true
21 | restart: unless-stopped
22 | build:
23 | context: .
24 | dockerfile: Dockerfile
25 | ports:
26 | - "5060:5060"
27 | deploy:
28 | resources:
29 | reservations:
30 | devices:
31 | - driver: nvidia
32 | count: 1
33 | capabilities: [ gpu ]
34 | environment:
35 | - RESTART_IF_NO_GPU=$RESTART_IF_NO_GPU
36 |       - OLLAMA_HOST=http://localhost:11434  # NOTE(review): localhost here is this container, not the ollama-gpu service, and this service joins no network that reaches it — confirm intended
37 |
38 | pdf-document-layout-analysis-gpu-translation:
39 | container_name: pdf-document-layout-analysis-gpu-translation
40 | entrypoint: [ "gunicorn", "-k", "uvicorn.workers.UvicornWorker", "--chdir", "./src", "app:app", "--bind", "0.0.0.0:5060", "--timeout", "10000"]
41 | init: true
42 | restart: unless-stopped
43 | build:
44 | context: .
45 | dockerfile: Dockerfile
46 | ports:
47 | - "5060:5060"
48 | depends_on:
49 | ollama-gpu:
50 | condition: service_healthy
51 | deploy:
52 | resources:
53 | reservations:
54 | devices:
55 | - driver: nvidia
56 | count: 1
57 | capabilities: [ gpu ]
58 | environment:
59 | - RESTART_IF_NO_GPU=$RESTART_IF_NO_GPU
60 | - OLLAMA_HOST=http://ollama-gpu:11434
61 | networks:
62 | - pdf-analysis-network
63 |
64 | networks:
65 | pdf-analysis-network:
66 | driver: bridge
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | services:
2 | ollama:
3 | container_name: ollama-service
4 | build:
5 | context: .
6 | dockerfile: Dockerfile.ollama
7 | restart: unless-stopped
8 | ports:
9 | - "11434:11434"
10 | healthcheck:
11 | test: ["CMD", "curl", "-f", "http://localhost:11434/api/tags"]
12 | interval: 30s
13 | timeout: 10s
14 | retries: 3
15 | start_period: 30s
16 | networks:
17 | - pdf-analysis-network
18 |
19 | pdf-document-layout-analysis:
20 | container_name: pdf-document-layout-analysis
21 | entrypoint: [ "gunicorn", "-k", "uvicorn.workers.UvicornWorker", "--chdir", "./src", "app:app", "--bind", "0.0.0.0:5060", "--timeout", "10000"]
22 | init: true
23 | restart: unless-stopped
24 | build:
25 | context: .
26 | dockerfile: Dockerfile
27 | ports:
28 | - "5060:5060"
29 | environment:
30 |       - OLLAMA_HOST=http://localhost:11434  # NOTE(review): localhost here is this container, not the ollama service, and this service is not on pdf-analysis-network — confirm intended
31 |
32 | pdf-document-layout-analysis-translation:
33 | container_name: pdf-document-layout-analysis-translation
34 | entrypoint: [ "gunicorn", "-k", "uvicorn.workers.UvicornWorker", "--chdir", "./src", "app:app", "--bind", "0.0.0.0:5060", "--timeout", "10000"]
35 | init: true
36 | restart: unless-stopped
37 | build:
38 | context: .
39 | dockerfile: Dockerfile
40 | ports:
41 | - "5060:5060"
42 | depends_on:
43 | ollama:
44 | condition: service_healthy
45 | environment:
46 | - OLLAMA_HOST=http://ollama:11434
47 | networks:
48 | - pdf-analysis-network
49 |
50 | networks:
51 | pdf-analysis-network:
52 | driver: bridge
53 |
--------------------------------------------------------------------------------
/images/vgtexample1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/images/vgtexample1.png
--------------------------------------------------------------------------------
/images/vgtexample2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/images/vgtexample2.png
--------------------------------------------------------------------------------
/images/vgtexample3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/images/vgtexample3.png
--------------------------------------------------------------------------------
/images/vgtexample4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/images/vgtexample4.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "pdf-document-layout-analysis"
3 | version = "2025.03.18.03"
4 | description = "This tool is for PDF document layout analysis"
5 | license = { file = "LICENSE" }
6 | authors = [{ name = "HURIDOCS" }]
7 | requires-python = ">= 3.10"
8 | dependencies = [
9 | "fastapi==0.111.1",
10 | "python-multipart==0.0.9",
11 | "uvicorn==0.30.3",
12 | "gunicorn==22.0.0",
13 | "requests==2.32.3",
14 | "torch==2.4.0",
15 | "torchvision==0.19.0",
16 | "timm==1.0.8",
17 | "Pillow==10.4.0",
18 | "pdf-annotate==0.12.0",
19 | "scipy==1.14.0",
20 | "opencv-python==4.10.0.84",
21 | "Shapely==2.0.5",
22 | "transformers==4.40.2",
23 | "huggingface_hub==0.23.5",
24 | "pdf2image==1.17.0",
25 | "lxml==5.2.2",
26 | "lightgbm==4.5.0",
27 | "setuptools==75.4.0",
28 | "roman==4.2",
29 | "hydra-core==1.3.2",
30 | "pypandoc==1.13",
31 | "rapid-latex-ocr==0.0.9",
32 | "struct_eqtable @ git+https://github.com/UniModal4Reasoning/StructEqTable-Deploy.git@fd06078bfa9364849eb39330c075dd63cbed73ff"
33 | ]
34 |
35 | [project.urls]
36 | HURIDOCS = "https://huridocs.org"
37 | GitHub = "https://github.com/huridocs/pdf-document-layout-analysis"
38 | HuggingFace = "https://huggingface.co/HURIDOCS/pdf-document-layout-analysis"
39 | DockerHub = "https://hub.docker.com/r/huridocs/pdf-document-layout-analysis"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | fastapi==0.111.1
2 | pydantic==2.11.0
3 | python-multipart==0.0.9
4 | uvicorn==0.30.3
5 | gunicorn==22.0.0
6 | requests==2.32.3
7 | torch==2.4.0
8 | torchvision==0.19.0
9 | Pillow==10.4.0
10 | pdf-annotate==0.12.0
11 | scipy==1.14.0
12 | opencv-python==4.10.0.84
13 | Shapely==2.0.5
14 | transformers==4.40.2
15 | huggingface_hub==0.23.5
16 | pdf2image==1.17.0
17 | lightgbm==4.5.0
18 | setuptools==75.4.0
19 | roman==4.2
20 | hydra-core==1.3.2
21 | pypandoc==1.13
22 | rapid-table==2.0.3
23 | rapidocr==3.2.0
24 | pix2tex==0.1.4
25 | latex2mathml==3.78.0
26 | PyMuPDF==1.25.5
27 | ollama==0.6.0
28 | git+https://github.com/huridocs/pdf-features.git@2025.10.1.1
--------------------------------------------------------------------------------
/src/adapters/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/__init__.py
--------------------------------------------------------------------------------
/src/adapters/infrastructure/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/infrastructure/__init__.py
--------------------------------------------------------------------------------
/src/adapters/infrastructure/format_conversion_service_adapter.py:
--------------------------------------------------------------------------------
1 | from domain.PdfImages import PdfImages
2 | from domain.PdfSegment import PdfSegment
3 | from ports.services.format_conversion_service import FormatConversionService
4 | from adapters.infrastructure.format_converters.convert_table_to_html import extract_table_format
5 | from adapters.infrastructure.format_converters.convert_formula_to_latex import extract_formula_format
6 |
7 |
8 | class FormatConversionServiceAdapter(FormatConversionService):
9 | def convert_table_to_html(self, pdf_images: PdfImages, segments: list[PdfSegment]) -> None:
10 | extract_table_format(pdf_images, segments)
11 |
12 | def convert_formula_to_latex(self, pdf_images: PdfImages, segments: list[PdfSegment]) -> None:
13 | extract_formula_format(pdf_images, segments)
14 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/format_converters/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/infrastructure/format_converters/__init__.py
--------------------------------------------------------------------------------
/src/adapters/infrastructure/format_converters/convert_formula_to_latex.py:
--------------------------------------------------------------------------------
1 | from PIL.Image import Image
2 | from pix2tex.cli import LatexOCR
3 | from domain.PdfImages import PdfImages
4 | from domain.PdfSegment import PdfSegment
5 | from pdf_token_type_labels import TokenType
6 | import latex2mathml.converter
7 |
8 |
9 | def has_arabic(text: str) -> bool:
10 | return any("\u0600" <= char <= "\u06FF" or "\u0750" <= char <= "\u077F" for char in text)
11 |
12 |
13 | def is_valid_latex(formula: str) -> bool:
14 | try:
15 | latex2mathml.converter.convert(formula)
16 | return True
17 | except Exception:
18 | return False
19 |
20 |
21 | def extract_formula_format(pdf_images: PdfImages, predicted_segments: list[PdfSegment]):
22 | formula_segments = [segment for segment in predicted_segments if segment.segment_type == TokenType.FORMULA]
23 | if not formula_segments:
24 | return
25 |
26 | model = LatexOCR()
27 | model.args.temperature = 1e-8
28 |
29 | for formula_segment in formula_segments:
30 | if has_arabic(formula_segment.text_content):
31 | continue
32 | page_image: Image = pdf_images.pdf_images[formula_segment.page_number - 1]
33 | left, top = formula_segment.bounding_box.left, formula_segment.bounding_box.top
34 | right, bottom = formula_segment.bounding_box.right, formula_segment.bounding_box.bottom
35 | left = int(left * pdf_images.dpi / 72)
36 | top = int(top * pdf_images.dpi / 72)
37 | right = int(right * pdf_images.dpi / 72)
38 | bottom = int(bottom * pdf_images.dpi / 72)
39 | formula_image = page_image.crop((left, top, right, bottom))
40 | formula_result = model(formula_image)
41 | if not is_valid_latex(formula_result):
42 | continue
43 | formula_segment.text_content = f"$${formula_result}$$"
44 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/format_converters/convert_table_to_html.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from domain.PdfImages import PdfImages
3 | from domain.PdfSegment import PdfSegment
4 | from pdf_token_type_labels import TokenType
5 | from rapidocr import RapidOCR
6 | from rapid_table import ModelType, RapidTable, RapidTableInput
7 |
8 |
9 | def extract_table_format(pdf_images: PdfImages, predicted_segments: list[PdfSegment]):
10 | table_segments = [segment for segment in predicted_segments if segment.segment_type == TokenType.TABLE]
11 | if not table_segments:
12 | return
13 |
14 | input_args = RapidTableInput(model_type=ModelType["SLANETPLUS"])
15 |
16 | ocr_engine = RapidOCR()
17 | table_engine = RapidTable(input_args)
18 |
19 | for table_segment in table_segments:
20 | page_image: Image = pdf_images.pdf_images[table_segment.page_number - 1]
21 | left, top = table_segment.bounding_box.left, table_segment.bounding_box.top
22 | right, bottom = table_segment.bounding_box.right, table_segment.bounding_box.bottom
23 | left = int(left * pdf_images.dpi / 72)
24 | top = int(top * pdf_images.dpi / 72)
25 | right = int(right * pdf_images.dpi / 72)
26 | bottom = int(bottom * pdf_images.dpi / 72)
27 | table_image = page_image.crop((left, top, right, bottom))
28 | ori_ocr_res = ocr_engine(table_image)
29 | if not ori_ocr_res.txts:
30 | continue
31 | ocr_results = [ori_ocr_res.boxes, ori_ocr_res.txts, ori_ocr_res.scores]
32 | table_result = table_engine(table_image, ocr_results=ocr_results)
33 | table_segment.text_content = table_result.pred_html
34 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/html_conversion_service_adapter.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 | from starlette.responses import Response
3 |
4 | from domain.SegmentBox import SegmentBox
5 | from ports.services.html_conversion_service import HtmlConversionService
6 | from adapters.infrastructure.markup_conversion.pdf_to_markup_service_adapter import PdfToMarkupServiceAdapter
7 | from adapters.infrastructure.markup_conversion.OutputFormat import OutputFormat
8 |
9 |
10 | class HtmlConversionServiceAdapter(HtmlConversionService, PdfToMarkupServiceAdapter):
11 |
12 | def __init__(self):
13 | PdfToMarkupServiceAdapter.__init__(self, OutputFormat.HTML)
14 |
15 | def convert_to_html(
16 | self,
17 | pdf_content: bytes,
18 | segments: list[SegmentBox],
19 | extract_toc: bool = False,
20 | dpi: int = 120,
21 | output_file: Optional[str] = None,
22 | target_languages: Optional[list[str]] = None,
23 | translation_model: str = "gpt-oss",
24 | ) -> Union[str, Response]:
25 | return self.convert_to_format(
26 | pdf_content, segments, extract_toc, dpi, output_file, target_languages, translation_model
27 | )
28 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/markdown_conversion_service_adapter.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 | from starlette.responses import Response
3 |
4 | from domain.SegmentBox import SegmentBox
5 | from ports.services.markdown_conversion_service import MarkdownConversionService
6 | from adapters.infrastructure.markup_conversion.pdf_to_markup_service_adapter import PdfToMarkupServiceAdapter
7 | from adapters.infrastructure.markup_conversion.OutputFormat import OutputFormat
8 |
9 |
10 | class MarkdownConversionServiceAdapter(MarkdownConversionService, PdfToMarkupServiceAdapter):
11 |
12 | def __init__(self):
13 | PdfToMarkupServiceAdapter.__init__(self, OutputFormat.MARKDOWN)
14 |
15 | def convert_to_markdown(
16 | self,
17 | pdf_content: bytes,
18 | segments: list[SegmentBox],
19 | extract_toc: bool = False,
20 | dpi: int = 120,
21 | output_file: Optional[str] = None,
22 | target_languages: Optional[list[str]] = None,
23 | translation_model: str = "gpt-oss",
24 | ) -> Union[str, Response]:
25 | return self.convert_to_format(
26 | pdf_content, segments, extract_toc, dpi, output_file, target_languages, translation_model
27 | )
28 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/markup_conversion/ExtractedImage.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 |
3 |
class ExtractedImage(BaseModel):
    """An image pulled out of a PDF during markup conversion: raw bytes plus a filename."""

    # Raw binary content of the image; format (PNG/JPEG/...) is not enforced here.
    image_data: bytes
    # File name (with extension) under which the image is saved/referenced —
    # presumably used by the generated markup; confirm against the converter.
    filename: str
7 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/markup_conversion/Link.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 | from domain.SegmentBox import SegmentBox
3 |
4 |
class Link(BaseModel):
    """An internal document link between two segments, with its anchor text."""

    # Segment where the link appears (the clickable text's segment).
    source_segment: SegmentBox
    # Segment the link points to (the jump target).
    destination_segment: SegmentBox
    # Visible anchor text of the link.
    text: str
9 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/markup_conversion/OutputFormat.py:
--------------------------------------------------------------------------------
1 | from enum import StrEnum
2 |
3 |
class OutputFormat(StrEnum):
    """Markup formats the PDF-to-markup converter can emit; values are plain strings."""

    HTML = "html"
    MARKDOWN = "markdown"
7 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/markup_conversion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/infrastructure/markup_conversion/__init__.py
--------------------------------------------------------------------------------
/src/adapters/infrastructure/ocr/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/infrastructure/ocr/__init__.py
--------------------------------------------------------------------------------
/src/adapters/infrastructure/ocr/languages.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
# Maps ISO 639-1 / BCP-47-style language codes to Tesseract traineddata pack names.
# Entries whose key equals the value (e.g. "ceb", "script-*", "*-vert", "osd") are
# Tesseract-specific codes with no distinct ISO form and map to themselves.
# NOTE(review): some vertical values use hyphens ("chi-sim-vert") while Tesseract
# ships underscore names ("chi_sim_vert", like "kor_vert" below) — verify these
# hyphenated entries actually match installed packs.
iso_to_tesseract = {
    "af": "afr",  # Afrikaans
    "all": "all",  # Allar
    "am": "amh",  # Amharic
    "ar": "ara",  # Arabic
    "as": "asm",  # Assamese
    "az": "aze",  # Azerbaijani
    "aze-cyrl": "aze-cyrl",  # Azerbaijani (Cyrillic)
    "be": "bel",  # Belarusian
    "bn": "ben",  # Bangla
    "bo": "bod",  # Tibetan
    "bs": "bos",  # Bosnian
    "br": "bre",  # Breton
    "bg": "bul",  # Bulgarian
    "ca": "cat",  # Catalan
    "ceb": "ceb",  # Cebuano
    "cs": "ces",  # Czech
    "zh-Hans": "chi_sim",  # Chinese (Simplified)
    "chi-sim-vert": "chi-sim-vert",  # Chinese (Simplified) vertical
    "zh-Hant": "chi_tra",  # Chinese (Traditional)
    "chi-tra-vert": "chi-tra-vert",  # Chinese (Traditional) vertical
    "chr": "chr",  # Cherokee
    "co": "cos",  # Corsican
    "cy": "cym",  # Welsh
    "da": "dan",  # Danish
    "de": "deu",  # German
    "dv": "div",  # Divehi
    "dz": "dzo",  # Dzongkha
    "el": "ell",  # Greek
    "en": "eng",  # English
    "enm": "enm",  # Middle English
    "eo": "epo",  # Esperanto
    "et": "est",  # Estonian
    "eu": "eus",  # Basque
    "fo": "fao",  # Faroese
    "fa": "fas",  # Persian
    "fil": "fil",  # Filipino
    "fi": "fin",  # Finnish
    "fr": "fra",  # French
    "frk": "frk",  # Frankish
    "frm": "frm",  # Middle French
    "fy": "fry",  # Western Frisian
    "gd": "gla",  # Scottish Gaelic
    "ga": "gle",  # Irish
    "gl": "glg",  # Galician
    "grc": "grc",  # Ancient Greek
    "gu": "guj",  # Gujarati
    "ht": "hat",  # Haitian Creole
    "he": "heb",  # Hebrew
    "hi": "hin",  # Hindi
    "hr": "hrv",  # Croatian
    "hu": "hun",  # Hungarian
    "hy": "hye",  # Armenian
    "iu": "iku",  # Inuktitut
    "id": "ind",  # Indonesian
    "is": "isl",  # Icelandic
    "it": "ita",  # Italian
    "ita-old": "ita-old",  # Old Italian
    "jv": "jav",  # Javanese
    "ja": "jpn",  # Japanese
    "jpn-vert": "jpn-vert",  # Japanese vertical
    "kn": "kan",  # Kannada
    "ka": "kat",  # Georgian
    "kat-old": "kat-old",  # Old Georgian
    "kk": "kaz",  # Kazakh
    "km": "khm",  # Khmer
    "ky": "kir",  # Kyrgyz
    "kmr": "kmr",  # Northern Kurdish
    "ko": "kor",  # Korean
    "kor-vert": "kor_vert",  # Korean vertical
    "lo": "lao",  # Lao
    "la": "lat",  # Latin
    "lv": "lav",  # Latvian
    "lt": "lit",  # Lithuanian
    "lb": "ltz",  # Luxembourgish
    "ml": "mal",  # Malayalam
    "mr": "mar",  # Marathi
    "mk": "mkd",  # Macedonian
    "mt": "mlt",  # Maltese
    "mn": "mon",  # Mongolian
    "mi": "mri",  # Māori
    "ms": "msa",  # Malay
    "my": "mya",  # Burmese
    "ne": "nep",  # Nepali
    "nl": "nld",  # Dutch
    "no": "nor",  # Norwegian
    "oc": "oci",  # Occitan
    "or": "ori",  # Odia
    "osd": "osd",  # Unknown language [osd]
    "pa": "pan",  # Punjabi
    "pl": "pol",  # Polish
    "pt": "por",  # Portuguese
    "ps": "pus",  # Pashto
    "qu": "que",  # Quechua
    "ro": "ron",  # Romanian
    "ru": "rus",  # Russian
    "sa": "san",  # Sanskrit
    "script-arab": "script-arab",  # Arabic script
    "script-armn": "script-armn",  # Armenian script
    "script-beng": "script-beng",  # Bengali script
    "script-cans": "script-cans",  # Canadian Aboriginal script
    "script-cher": "script-cher",  # Cherokee script
    "script-cyrl": "script-cyrl",  # Cyrillic script
    "script-deva": "script-deva",  # Devanagari script
    "script-ethi": "script-ethi",  # Ethiopic script
    "script-frak": "script-frak",  # Frankish script
    "script-geor": "script-geor",  # Georgian script
    "script-grek": "script-grek",  # Greek script
    "script-gujr": "script-gujr",  # Gujarati script
    "script-guru": "script-guru",  # Gurmukhi script
    "script-hang": "script-hang",  # Hangul script
    "script-hang-vert": "script-hang-vert",  # Hangul script vertical
    "script-hans": "script-hans",  # Han (Simplified) script
    "script-hans-vert": "script-hans-vert",  # Han (Simplified) script vertical
    "script-hant": "script-hant",  # Han (Traditional) script
    "script-hant-vert": "script-hant-vert",  # Han (Traditional) script vertical
    "script-hebr": "script-hebr",  # Hebrew script
    "script-jpan": "script-jpan",  # Japanese script
    "script-jpan-vert": "script-jpan-vert",  # Japanese script vertical
    "script-khmr": "script-khmr",  # Khmer script
    "script-knda": "script-knda",  # Kannada script
    "script-laoo": "script-laoo",  # Lao script
    "script-latn": "script-latn",  # Latin script
    "script-mlym": "script-mlym",  # Malayalam script
    "script-mymr": "script-mymr",  # Myanmar script
    "script-orya": "script-orya",  # Odia script
    "script-sinh": "script-sinh",  # Sinhala script
    "script-syrc": "script-syrc",  # Syriac script
    "script-taml": "script-taml",  # Tamil script
    "script-telu": "script-telu",  # Telugu script
    "script-thaa": "script-thaa",  # Thaana script
    "script-thai": "script-thai",  # Thai script
    "script-tibt": "script-tibt",  # Tibetan script
    "script-viet": "script-viet",  # Vietnamese script
    "si": "sin",  # Sinhala
    "sk": "slk",  # Slovak
    "sl": "slv",  # Slovenian
    "sd": "snd",  # Sindhi
    "es": "spa",  # Spanish
    "spa-old": "spa-old",  # Old Spanish
    "sq": "sqi",  # Albanian
    "sr": "srp",  # Serbian
    "srp-latn": "srp-latn",  # Serbian (Latin)
    "su": "sun",  # Sundanese
    "sw": "swa",  # Swahili
    "sv": "swe",  # Swedish
    "syr": "syr",  # Syriac
    "ta": "tam",  # Tamil
    "tt": "tat",  # Tatar
    "te": "tel",  # Telugu
    "tg": "tgk",  # Tajik
    "th": "tha",  # Thai
    "ti": "tir",  # Tigrinya
    "to": "ton",  # Tongan
    "tr": "tur",  # Turkish
    "ug": "uig",  # Uyghur
    "uk": "ukr",  # Ukrainian
    "ur": "urd",  # Urdu
    "uz": "uzb",  # Uzbek
    "uzb-cyrl": "uzb-cyrl",  # Uzbek (Cyrillic)
    "vi": "vie",  # Vietnamese
    "yi": "yid",  # Yiddish
    "yo": "yor",  # Yoruba
}
167 |
168 |
def supported_languages() -> list[str]:
    """Return the ISO codes of the Tesseract language packs installed on this machine.

    Runs ``tesseract --list-langs``, drops the header line and the ``osd``
    (orientation/script detection) pseudo-language, then maps each remaining
    Tesseract code back to its ISO code via ``iso_to_tesseract``.

    Fixes over the previous implementation:
    - no longer shells out through ``/bin/bash`` + ``grep`` + ``awk`` (non-portable,
      errors silently ignored); the output is parsed in Python instead;
    - installed packs that are absent from ``iso_to_tesseract`` are skipped rather
      than raising ``KeyError``.
    """
    result = subprocess.run(
        ["tesseract", "--list-langs"],
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
    )
    # First line of the output is a header ("List of available languages ..."), skip it.
    lines = result.stdout.decode("utf-8").splitlines()[1:]
    tesseract_langs = [line.strip() for line in lines if line.strip() and "osd" not in line]
    inverted_iso_dict = {v: k for k, v in iso_to_tesseract.items()}
    # Preserve tesseract's ordering; unknown packs are dropped instead of crashing.
    return [inverted_iso_dict[lang] for lang in tesseract_langs if lang in inverted_iso_dict]
175 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/ocr_service_adapter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import subprocess
4 | from pathlib import Path
5 | from ports.services.ocr_service import OCRService
6 | from configuration import OCR_SOURCE, OCR_OUTPUT, OCR_FAILED
7 | from adapters.infrastructure.ocr.languages import iso_to_tesseract, supported_languages
8 |
9 |
class OCRServiceAdapter(OCRService):
    """Runs ocrmypdf over PDFs laid out under the configured OCR source/output/failed roots."""

    def process_pdf_ocr(self, filename: str, namespace: str, language: str = "en") -> Path | bool:
        """OCR one PDF from the namespace's source directory.

        :param filename: PDF file name inside the namespace's source directory.
        :param namespace: sub-directory grouping files under OCR_SOURCE/OCR_OUTPUT/OCR_FAILED.
        :param language: ISO code, translated to a Tesseract pack via ``iso_to_tesseract``.
        :return: path of the OCRed PDF on success; ``False`` on failure (kept for
                 backward compatibility with callers that truth-test the result),
                 after moving the source file to the failed directory.
        """
        source_pdf_filepath, processed_pdf_filepath, failed_pdf_filepath = self._get_paths(namespace, filename)
        os.makedirs(processed_pdf_filepath.parent, exist_ok=True)

        # --force-ocr rasterizes and re-OCRs every page, even pages that already have text.
        result = subprocess.run(
            [
                "ocrmypdf",
                "-l",
                iso_to_tesseract[language],
                source_pdf_filepath,
                processed_pdf_filepath,
                "--force-ocr",
            ]
        )

        if result.returncode == 0:
            return processed_pdf_filepath

        os.makedirs(failed_pdf_filepath.parent, exist_ok=True)
        shutil.move(source_pdf_filepath, failed_pdf_filepath)
        return False

    def get_supported_languages(self) -> list[str]:
        """ISO codes of the Tesseract language packs available on this machine."""
        return supported_languages()

    def _get_paths(self, namespace: str, pdf_file_name: str) -> tuple[Path, Path, Path]:
        """Resolve (source, output, failed) paths for a PDF within a namespace.

        Bug fixed: the stem was previously built with ``"".join(...)``, which
        collapsed interior dots ("my.report.pdf" -> "myreport.pdf"); ``".".join``
        keeps them ("my.report.pdf" -> "my.report.pdf").
        """
        file_name = ".".join(pdf_file_name.split(".")[:-1]) if "." in pdf_file_name else pdf_file_name
        source_pdf_filepath = Path(OCR_SOURCE, namespace, pdf_file_name)
        processed_pdf_filepath = Path(OCR_OUTPUT, namespace, f"{file_name}.pdf")
        failed_pdf_filepath = Path(OCR_FAILED, namespace, pdf_file_name)
        return source_pdf_filepath, processed_pdf_filepath, failed_pdf_filepath
42 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/pdf_analysis_service_adapter.py:
--------------------------------------------------------------------------------
1 | from typing import AnyStr
2 | from domain.PdfImages import PdfImages
3 | from domain.SegmentBox import SegmentBox
4 | from ports.services.pdf_analysis_service import PDFAnalysisService
5 | from ports.services.ml_model_service import MLModelService
6 | from ports.services.format_conversion_service import FormatConversionService
7 | from ports.repositories.file_repository import FileRepository
8 | from configuration import service_logger
9 |
10 |
class PDFAnalysisServiceAdapter(PDFAnalysisService):
    """Coordinates PDF layout analysis with either the VGT model or a fast model,
    optionally converting formulas to LaTeX and tables to HTML from 200-dpi renders."""

    def __init__(
        self,
        vgt_model_service: MLModelService,
        fast_model_service: MLModelService,
        format_conversion_service: FormatConversionService,
        file_repository: FileRepository,
    ):
        # Two interchangeable layout predictors: VGT (accurate) and a fast variant.
        self.vgt_model_service = vgt_model_service
        self.fast_model_service = fast_model_service
        self.format_conversion_service = format_conversion_service
        self.file_repository = file_repository

    def analyze_pdf_layout(
        self, pdf_content: AnyStr, xml_filename: str = "", parse_tables_and_math: bool = False, keep_pdf: bool = False
    ) -> list[dict]:
        """Analyze layout with the VGT model; returns one SegmentBox dict per segment."""
        return self._run_analysis(
            pdf_content,
            xml_filename,
            parse_tables_and_math,
            keep_pdf,
            self.vgt_model_service.predict_document_layout,
            "Creating PDF images",
        )

    def analyze_pdf_layout_fast(
        self, pdf_content: AnyStr, xml_filename: str = "", parse_tables_and_math: bool = False, keep_pdf: bool = False
    ) -> list[dict]:
        """Analyze layout with the fast model; returns one SegmentBox dict per segment."""
        return self._run_analysis(
            pdf_content,
            xml_filename,
            parse_tables_and_math,
            keep_pdf,
            self.fast_model_service.predict_layout_fast,
            "Creating PDF images for fast analysis",
        )

    def _run_analysis(
        self,
        pdf_content: AnyStr,
        xml_filename: str,
        parse_tables_and_math: bool,
        keep_pdf: bool,
        predict,
        log_message: str,
    ) -> list[dict]:
        """Shared pipeline: save PDF, render images, predict, optionally convert
        formulas/tables, clean up, and serialize segments.

        ``predict`` is a callable taking ``list[PdfImages]`` and returning predicted segments.
        Bug fixed: the fast path previously passed the default-dpi images to
        ``convert_table_to_html`` while formula conversion used the 200-dpi render;
        both now consistently use the 200-dpi images, matching the VGT path.
        """
        pdf_path = self.file_repository.save_pdf(pdf_content)
        service_logger.info(log_message)

        pdf_images_list: list[PdfImages] = [PdfImages.from_pdf_path(pdf_path, "", xml_filename)]

        predicted_segments = predict(pdf_images_list)

        if parse_tables_and_math:
            # Higher-resolution render for formula/table recognition.
            pdf_images_200_dpi = PdfImages.from_pdf_path(pdf_path, "", xml_filename, dpi=200)
            self.format_conversion_service.convert_formula_to_latex(pdf_images_200_dpi, predicted_segments)
            self.format_conversion_service.convert_table_to_html(pdf_images_200_dpi, predicted_segments)

        if not keep_pdf:
            self.file_repository.delete_file(pdf_path)

        return [
            SegmentBox.from_pdf_segment(pdf_segment, pdf_images_list[0].pdf_features.pages).to_dict()
            for pdf_segment in predicted_segments
        ]
69 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/text_extraction_adapter.py:
--------------------------------------------------------------------------------
1 | from pdf_token_type_labels import TokenType
2 | from ports.services.text_extraction_service import TextExtractionService
3 | from configuration import service_logger
4 |
5 |
class TextExtractionAdapter(TextExtractionService):
    """Extracts plain text from predicted segment boxes, filtered by token type."""

    def extract_text_by_types(self, segment_boxes: list[dict], token_types: list[TokenType]) -> str:
        """Join the text of every segment whose type is in ``token_types``, newline-separated.

        Fix: the return annotation previously said ``dict``, but a newline-joined
        ``str`` is (and always was) what gets returned.
        """
        service_logger.info(f"Extracted types: {[t.name for t in token_types]}")
        text = "\n".join(
            segment_box["text"]
            for segment_box in segment_boxes
            # Incoming types use spaces (e.g. "Section header"); TokenType names use underscores.
            if TokenType.from_text(segment_box["type"].replace(" ", "_")) in token_types
        )
        return text

    def extract_all_text(self, segment_boxes: list[dict]) -> str:
        """Extract the text of every segment regardless of its type."""
        return self.extract_text_by_types(segment_boxes, list(TokenType))
21 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/toc/MergeTwoSegmentsTitles.py:
--------------------------------------------------------------------------------
1 | from adapters.infrastructure.toc.TitleFeatures import TitleFeatures
2 | from adapters.infrastructure.toc.PdfSegmentation import PdfSegmentation
3 |
4 |
class MergeTwoSegmentsTitles:
    """Collapses consecutive title fragments that visually belong to one title.

    After construction, ``titles_merged`` holds the titles in reading order with
    adjacent continuations folded together.
    """

    def __init__(self, pdf_segmentation: PdfSegmentation):
        self.title_features_list: list[TitleFeatures] = TitleFeatures.from_pdf_segmentation(pdf_segmentation)
        self.titles_merged: list[TitleFeatures] = list()
        self.merge()

    def merge(self):
        """Single left-to-right pass: a title that should merge with its successor is
        folded into the successor; otherwise it is emitted as-is."""
        total = len(self.title_features_list)
        for position in range(total):
            current = self.title_features_list[position]
            if position == total - 1:
                # Last title has no successor to merge into.
                self.titles_merged.append(current)
                break
            successor = self.title_features_list[position + 1]
            if self.should_merge(current, successor):
                # Fold the current fragment into its successor and keep scanning.
                self.title_features_list[position + 1] = successor.append(current)
            else:
                self.titles_merged.append(current)

    @staticmethod
    def should_merge(title: TitleFeatures, other_title: TitleFeatures):
        """True when ``other_title`` looks like a visual continuation of ``title``."""
        if other_title.pdf_segment.page_number != title.pdf_segment.page_number:
            return False
        # Must be vertically adjacent (within 15 units).
        if abs(other_title.top - title.bottom) > 15:
            return False
        # And roughly horizontally aligned on both edges.
        misaligned = abs(other_title.left - title.right) > 15 or abs(other_title.right - title.left) > 15
        if misaligned:
            return False
        # Two independently numbered titles should stay separate.
        numbering_types = [1, 2, 3]
        if title.first_characters_type in numbering_types and other_title.first_characters_type in numbering_types:
            return False
        # Likewise two independently bulleted titles.
        if title.bullet_points_type and other_title.bullet_points_type:
            return False
        return title.get_features_to_merge() == other_title.get_features_to_merge()
49 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/toc/PdfSegmentation.py:
--------------------------------------------------------------------------------
1 | from domain.PdfSegment import PdfSegment
2 | from pdf_features import PdfFeatures
3 | from pdf_features import PdfToken
4 |
5 |
class PdfSegmentation:
    """Assigns every PDF token to the segment it overlaps the most, per page."""

    def __init__(self, pdf_features: PdfFeatures, pdf_segments: list[PdfSegment]):
        self.pdf_features: PdfFeatures = pdf_features
        self.pdf_segments: list[PdfSegment] = pdf_segments
        self.tokens_by_segments: dict[PdfSegment, list[PdfToken]] = self.find_tokens_by_segments()

    @staticmethod
    def find_segment_for_token(token: PdfToken, segments: list[PdfSegment], tokens_by_segments):
        """Attach ``token`` to the segment with the highest intersection percentage.

        Tokens that overlap no segment at all are left unassigned.
        """
        winner: PdfSegment | None = None
        winner_score: float = 0
        for candidate in segments:
            overlap = token.bounding_box.get_intersection_percentage(candidate.bounding_box)
            if overlap <= winner_score:
                continue
            winner, winner_score = candidate, overlap
            # Near-total containment: no later candidate can do meaningfully better.
            if winner_score >= 99:
                break
        if winner:
            tokens_by_segments.setdefault(winner, list()).append(token)

    def find_tokens_by_segments(self):
        """Build the segment -> tokens mapping by matching tokens page by page."""
        tokens_by_segments: dict[PdfSegment, list[PdfToken]] = {}
        for page in self.pdf_features.pages:
            segments_on_page = [s for s in self.pdf_segments if s.page_number == page.page_number]
            for token in page.tokens:
                self.find_segment_for_token(token, segments_on_page, tokens_by_segments)
        return tokens_by_segments
33 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/toc/TOCExtractor.py:
--------------------------------------------------------------------------------
1 | from adapters.infrastructure.toc.MergeTwoSegmentsTitles import MergeTwoSegmentsTitles
2 | from adapters.infrastructure.toc.TitleFeatures import TitleFeatures
3 | from adapters.infrastructure.toc.data.TOCItem import TOCItem
4 | from adapters.infrastructure.toc.PdfSegmentation import PdfSegmentation
5 |
6 |
class TOCExtractor:
    """Builds a table of contents, with nesting levels, from merged title segments.

    Indentation is decided incrementally: each title either reuses the level of an
    earlier, still-open title with a matching style, or nests one level deeper than
    its predecessor. ``get_indentation`` mutates ``point_closed`` flags as it goes,
    so the order of ``set_toc`` iterations matters.
    """

    def __init__(self, pdf_segmentation: PdfSegmentation):
        self.pdf_segmentation = pdf_segmentation
        # Titles in reading order, with visually-continuous fragments already merged.
        self.titles_features_sorted = MergeTwoSegmentsTitles(self.pdf_segmentation).titles_merged
        self.toc: list[TOCItem] = list()
        self.set_toc()

    def set_toc(self):
        # Assign an indentation level to each title in order; each decision can
        # close previously-open deeper items (see get_indentation/close_toc_items).
        for index, title_features in enumerate(self.titles_features_sorted):
            indentation = self.get_indentation(index, title_features)
            self.toc.append(title_features.to_toc_item(indentation))

    def __str__(self):
        # Debug rendering: one leading space per indentation level.
        return "\n".join([f'{" " * x.indentation} * {x.label}' for x in self.toc])

    def get_indentation(self, title_index: int, title_features: TitleFeatures):
        """Pick the nesting level for the title at ``title_index``.

        Scans earlier titles from nearest to farthest, skipping closed ones; the
        first open title with the same style determines the level (and closes
        everything nested deeper). With no match, nest one level below the
        immediately preceding title.
        """
        if title_index == 0:
            return 0

        for index in reversed(range(title_index)):
            if self.toc[index].point_closed:
                continue

            if self.same_indentation(self.titles_features_sorted[index], title_features):
                # Side effect: deeper open items can no longer be matched by later titles.
                self.close_toc_items(self.toc[index].indentation)
                return self.toc[index].indentation

        return self.toc[title_index - 1].indentation + 1

    def close_toc_items(self, indentation):
        # Mark every item nested deeper than `indentation` as closed.
        for toc in self.toc:
            if toc.indentation > indentation:
                toc.point_closed = True

    @staticmethod
    def same_indentation(previous_title_features: TitleFeatures, title_features: TitleFeatures):
        """Two titles share a level if one continues the other's numbering sequence
        or they have identical style features."""
        if previous_title_features.first_characters in title_features.get_possible_previous_point():
            return True

        if previous_title_features.get_features_toc() == title_features.get_features_toc():
            return True

        return False

    def to_dict(self):
        """Serialize the TOC to a list of dicts with label, indentation and bounding box.

        Note: the page number is serialized as a string (``str(...)``) while the
        box coordinates are ints.
        """
        toc: list[dict[str, any]] = list()

        for toc_item in self.toc:
            toc_element_dict = dict()
            toc_element_dict["indentation"] = toc_item.indentation
            toc_element_dict["label"] = toc_item.label
            rectangle = dict()
            rectangle["left"] = int(toc_item.selection_rectangle.left)
            rectangle["top"] = int(toc_item.selection_rectangle.top)
            rectangle["width"] = int(toc_item.selection_rectangle.width)
            rectangle["height"] = int(toc_item.selection_rectangle.height)
            rectangle["page"] = str(toc_item.selection_rectangle.page_number)
            toc_element_dict["bounding_box"] = rectangle
            toc.append(toc_element_dict)

        return toc
68 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/toc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/infrastructure/toc/__init__.py
--------------------------------------------------------------------------------
/src/adapters/infrastructure/toc/data/TOCItem.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 |
3 | from domain.SegmentBox import SegmentBox
4 |
5 |
class TOCItem(BaseModel):
    """One table-of-contents entry with its nesting level and source rectangle."""

    # Nesting depth of this entry (0 = top level).
    indentation: int
    # Human-readable title text of the entry.
    label: str = ""
    # Region of the page the title was extracted from.
    selection_rectangle: SegmentBox
    # Internal flag used while computing indentation: a "closed" entry can no
    # longer be matched as the indentation anchor of later titles.
    point_closed: bool = False
11 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/toc/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/infrastructure/toc/data/__init__.py
--------------------------------------------------------------------------------
/src/adapters/infrastructure/toc/extract_table_of_contents.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 | import uuid
3 | from os.path import join
4 | from pathlib import Path
5 | from typing import AnyStr
6 | from domain.PdfSegment import PdfSegment
7 | from pdf_features import PdfFeatures
8 | from pdf_features import Rectangle
9 | from pdf_token_type_labels import TokenType
10 | from adapters.infrastructure.toc.TOCExtractor import TOCExtractor
11 | from configuration import service_logger
12 | from adapters.infrastructure.toc.PdfSegmentation import PdfSegmentation
13 |
# Segment types that can become table-of-contents entries.
TITLE_TYPES = {TokenType.TITLE, TokenType.SECTION_HEADER}
# Types tolerated at the top of the document while skipping the document's own name.
SKIP_TYPES = {TokenType.TITLE, TokenType.SECTION_HEADER, TokenType.PAGE_HEADER, TokenType.PICTURE}
16 |
17 |
def get_file_path(file_name, extension):
    """Build a path in the system temp directory for ``file_name.extension``."""
    return join(tempfile.gettempdir(), ".".join([file_name, extension]))
20 |
21 |
def pdf_content_to_pdf_path(file_content):
    """Persist raw PDF bytes to a uniquely-named temp file and return its Path."""
    # uuid1 gives a time/host-based unique name for the temp file.
    pdf_path = Path(get_file_path(str(uuid.uuid1()), "pdf"))
    pdf_path.write_bytes(file_content)
    return pdf_path
29 |
30 |
def skip_name_of_the_document(pdf_segments: list[PdfSegment], title_segments: list[PdfSegment]):
    """Remove the document's own name/title block from ``title_segments`` (in place).

    Walks ``pdf_segments`` from the top of the document while segments still look
    like part of the opening name block, collecting the title-typed ones; those
    are then removed from ``title_segments`` so they don't appear in the TOC.
    ``pdf_segments`` itself is not modified.
    """
    segments_to_remove = []
    last_segment = None
    for segment in pdf_segments:
        # First segment that cannot belong to the name block ends the walk.
        if segment.segment_type not in SKIP_TYPES:
            break
        # Page headers and pictures are ignored but do not end the block.
        if segment.segment_type == TokenType.PAGE_HEADER or segment.segment_type == TokenType.PICTURE:
            continue
        if not last_segment:
            last_segment = segment
        else:
            # NOTE(review): stops when the next title ends clearly left of ~2/3 of
            # the previous title's width — presumably detecting where the wide
            # document name ends and narrower section titles begin; confirm.
            if segment.bounding_box.right < last_segment.bounding_box.left + last_segment.bounding_box.width * 0.66:
                break
            last_segment = segment
        if segment.segment_type in TITLE_TYPES:
            segments_to_remove.append(segment)
    for segment in segments_to_remove:
        title_segments.remove(segment)
49 |
50 |
def get_pdf_segments_from_segment_boxes(pdf_features: PdfFeatures, segment_boxes: list[dict]) -> list[PdfSegment]:
    """Translate raw segment-box dicts into PdfSegment objects for this PDF."""
    pdf_name = pdf_features.file_name
    segments: list[PdfSegment] = []
    for box in segment_boxes:
        rectangle = Rectangle.from_width_height(box["left"], box["top"], box["width"], box["height"])
        token_type = TokenType.from_value(box["type"])
        segments.append(PdfSegment(box["page_number"], rectangle, box["text"], token_type, pdf_name))
    return segments
61 |
62 |
def extract_table_of_contents(file: AnyStr, segment_boxes: list[dict], skip_document_name=False):
    """End-to-end TOC extraction: parse PDF features, keep title segments,
    optionally drop the document's own name, and serialize the resulting TOC."""
    service_logger.info("Getting TOC")
    features: PdfFeatures = PdfFeatures.from_pdf_path(pdf_content_to_pdf_path(file))
    segments: list[PdfSegment] = get_pdf_segments_from_segment_boxes(features, segment_boxes)
    titles = [segment for segment in segments if segment.segment_type in TITLE_TYPES]
    if skip_document_name:
        skip_name_of_the_document(segments, titles)
    return TOCExtractor(PdfSegmentation(features, titles)).to_dict()
74 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/toc/methods/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/infrastructure/toc/methods/__init__.py
--------------------------------------------------------------------------------
/src/adapters/infrastructure/toc/methods/two_models_v3_segments_context_2/Modes.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import hashlib
3 | from statistics import mode
4 |
5 | from pdf_features import PdfFeatures
6 |
7 |
8 | @dataclasses.dataclass
9 | class Modes:
10 | lines_space_mode: float
11 | left_space_mode: float
12 | right_space_mode: float
13 | font_size_mode: float
14 | font_family_name_mode: str
15 | font_family_mode: int
16 | font_family_mode_normalized: float
17 | pdf_features: PdfFeatures
18 |
19 | def __init__(self, pdf_features: PdfFeatures):
20 | self.pdf_features = pdf_features
21 | self.set_modes()
22 |
23 | def set_modes(self):
24 | line_spaces, right_spaces, left_spaces = [0], [0], [0]
25 | for page, token in self.pdf_features.loop_tokens():
26 | right_spaces.append(self.pdf_features.pages[0].page_width - token.bounding_box.right)
27 | left_spaces.append(token.bounding_box.left)
28 | line_spaces.append(token.bounding_box.bottom)
29 |
30 | self.lines_space_mode = mode(line_spaces)
31 | self.left_space_mode = mode(left_spaces)
32 | self.right_space_mode = mode(right_spaces)
33 |
34 | font_sizes = [token.font.font_size for page, token in self.pdf_features.loop_tokens() if token.font]
35 | self.font_size_mode = mode(font_sizes) if font_sizes else 0
36 | font_ids = [token.font.font_id for page, token in self.pdf_features.loop_tokens() if token.font]
37 | self.font_family_name_mode = mode(font_ids) if font_ids else ""
38 | self.font_family_mode = abs(
39 | int(
40 | str(hashlib.sha256(self.font_family_name_mode.encode("utf-8")).hexdigest())[:8],
41 | 16,
42 | )
43 | )
44 | self.font_family_mode_normalized = float(f"{str(self.font_family_mode)[0]}.{str(self.font_family_mode)[1:]}")
45 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/toc/methods/two_models_v3_segments_context_2/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/infrastructure/toc/methods/two_models_v3_segments_context_2/__init__.py
--------------------------------------------------------------------------------
/src/adapters/infrastructure/toc_service_adapter.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 | import uuid
3 | from os.path import join
4 | from pathlib import Path
5 | from typing import AnyStr
6 | from domain.PdfSegment import PdfSegment
7 | from pdf_features import PdfFeatures, Rectangle
8 | from pdf_token_type_labels import TokenType
9 | from ports.services.toc_service import TOCService
10 | from configuration import service_logger
11 | from adapters.infrastructure.toc.TOCExtractor import TOCExtractor
12 | from adapters.infrastructure.toc.PdfSegmentation import PdfSegmentation
13 |
# Segment types that can become table-of-contents entries.
TITLE_TYPES = {TokenType.TITLE, TokenType.SECTION_HEADER}
# Types tolerated at the top of the document while skipping the document's own name.
SKIP_TYPES = {TokenType.TITLE, TokenType.SECTION_HEADER, TokenType.PAGE_HEADER, TokenType.PICTURE}
16 |
17 |
class TOCServiceAdapter(TOCService):
    """Extracts a table of contents from a PDF's title/section-header segments and
    can reformat it for Uwazi consumption."""

    def extract_table_of_contents(
        self, pdf_content: AnyStr, segment_boxes: list[dict], skip_document_name=False
    ) -> list[dict]:
        """Build the TOC from raw PDF bytes plus already-predicted segment boxes.

        :param pdf_content: raw bytes of the PDF.
        :param segment_boxes: dicts with left/top/width/height/page_number/text/type.
        :param skip_document_name: drop the leading title block (the document's own name).
        :return: list of dicts with ``indentation``, ``label`` and ``bounding_box``.
        """
        service_logger.info("Getting TOC")
        pdf_path = self._pdf_content_to_pdf_path(pdf_content)
        pdf_features: PdfFeatures = PdfFeatures.from_pdf_path(pdf_path)
        pdf_segments: list[PdfSegment] = self._get_pdf_segments_from_segment_boxes(pdf_features, segment_boxes)
        title_segments = [segment for segment in pdf_segments if segment.segment_type in TITLE_TYPES]
        if skip_document_name:
            self._skip_name_of_the_document(pdf_segments, title_segments)
        pdf_segmentation: PdfSegmentation = PdfSegmentation(pdf_features, title_segments)
        toc_instance: TOCExtractor = TOCExtractor(pdf_segmentation)
        return toc_instance.to_dict()

    def format_toc_for_uwazi(self, toc_items: list[dict]) -> list[dict]:
        """Convert TOC items to Uwazi's ``selectionRectangles`` shape.

        Coordinates are rescaled by 1/0.75 (NOTE(review): presumably converting
        between 72 and 96 dpi coordinate spaces — confirm against Uwazi).

        Bug fixed: the previous shallow ``toc_item.copy()`` shared the nested
        "bounding_box" dict with the caller's input, so the in-place rescaling
        corrupted the original items; the input is now left untouched.
        """
        toc_compatible = []
        for toc_item in toc_items:
            formatted_item = {key: value for key, value in toc_item.items() if key != "bounding_box"}
            rectangle = dict(toc_item["bounding_box"])
            for side in ("left", "top", "width", "height"):
                rectangle[side] = int(rectangle[side] / 0.75)
            formatted_item["selectionRectangles"] = [rectangle]
            toc_compatible.append(formatted_item)
        return toc_compatible

    def _get_file_path(self, file_name: str, extension: str) -> str:
        """Build a path in the system temp directory for ``file_name.extension``."""
        return join(tempfile.gettempdir(), file_name + "." + extension)

    def _pdf_content_to_pdf_path(self, file_content: AnyStr) -> Path:
        """Persist raw PDF bytes to a uniquely-named temp file and return its Path."""
        file_id = str(uuid.uuid1())
        pdf_path = Path(self._get_file_path(file_id, "pdf"))
        pdf_path.write_bytes(file_content)
        return pdf_path

    def _skip_name_of_the_document(self, pdf_segments: list[PdfSegment], title_segments: list[PdfSegment]) -> None:
        """Remove the document's own name/title block from ``title_segments`` (in place)."""
        segments_to_remove = []
        last_segment = None
        for segment in pdf_segments:
            # First segment that cannot belong to the opening name block ends the walk.
            if segment.segment_type not in SKIP_TYPES:
                break
            # Page headers and pictures are ignored but do not end the block.
            if segment.segment_type == TokenType.PAGE_HEADER or segment.segment_type == TokenType.PICTURE:
                continue
            if not last_segment:
                last_segment = segment
            else:
                # Stop once the next title ends clearly left of ~2/3 of the previous title's width.
                if segment.bounding_box.right < last_segment.bounding_box.left + last_segment.bounding_box.width * 0.66:
                    break
                last_segment = segment
            if segment.segment_type in TITLE_TYPES:
                segments_to_remove.append(segment)
        for segment in segments_to_remove:
            title_segments.remove(segment)

    def _get_pdf_segments_from_segment_boxes(self, pdf_features: PdfFeatures, segment_boxes: list[dict]) -> list[PdfSegment]:
        """Translate raw segment-box dicts into PdfSegment objects for this PDF."""
        pdf_segments: list[PdfSegment] = []
        for segment_box in segment_boxes:
            left, top, width, height = segment_box["left"], segment_box["top"], segment_box["width"], segment_box["height"]
            bounding_box = Rectangle.from_width_height(left, top, width, height)
            segment_type = TokenType.from_value(segment_box["type"])
            pdf_name = pdf_features.file_name
            segment = PdfSegment(segment_box["page_number"], bounding_box, segment_box["text"], segment_type, pdf_name)
            pdf_segments.append(segment)
        return pdf_segments
84 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/translation/decode_html_content.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 |
def decode_html(text, link_map, doc_ref_map):
    """Reverse the placeholder encoding applied by encode_html after translation.

    Placeholders of the form [BI0]..[BI0], [B0]..[B0], [IT0]..[IT0],
    [LINK0]..[LINK0] and [DOCREF0] are replaced using the captured maps;
    out-of-range indices are left untouched.
    NOTE(review): in this view the decoders emit only the captured text, with no
    surrounding HTML tags — confirm against the original file whether the
    bold/italic/anchor tags were stripped by an earlier processing step.
    """
    # 1. Decode bold+italic first
    def bold_italic_decoder(match):
        return f"{match.group(2)}"

    text = re.sub(r"\[BI(\d+)\](.*?)\[BI\1\]", bold_italic_decoder, text)

    # 2. Decode bold
    def bold_decoder(match):
        return f"{match.group(2)}"

    text = re.sub(r"\[B(\d+)\](.*?)\[B\1\]", bold_decoder, text)

    # 3. Decode italic
    def italic_decoder(match):
        return f"{match.group(2)}"

    text = re.sub(r"\[IT(\d+)\](.*?)\[IT\1\]", italic_decoder, text)

    # 4. Decode links
    def link_decoder(match):
        idx = int(match.group(1))
        if idx < len(link_map):
            # NOTE(review): `label`/`url` are unpacked but not visibly used in the
            # replacement here — likely a stripped anchor tag; verify upstream.
            label, url = link_map[idx]
            return f'{match.group(2)}'
        else:
            # Return original text if index is out of range
            return match.group(0)

    text = re.sub(r"\[LINK(\d+)\](.*?)\[LINK\1\]", link_decoder, text)

    # 5. Decode doc refs (same as markdown since they're custom)
    def doc_ref_decoder(match):
        idx = int(match.group(1))
        if idx < len(doc_ref_map):
            return doc_ref_map[idx]
        else:
            # Return original text if index is out of range
            return match.group(0)

    text = re.sub(r"\[DOCREF(\d+)\]", doc_ref_decoder, text)
    # Repair spacing artifacts around page anchors, then collapse whitespace runs.
    text = text.replace("] (#page", "](#page")
    text = " ".join(text.split())

    return text
49 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/translation/decode_markdown_content.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 |
def decode_markdown(text, link_map, doc_ref_map):
    """Reverse the placeholder encoding produced by encode_markdown.

    Restores **bold**, _italic_, **_bold italic_**, [label](url) links and
    document references from their indexed placeholders. Placeholders whose
    index falls outside the provided maps are left untouched.
    """

    def wrap_with(prefix, suffix):
        # Build a re.sub callback that surrounds the captured text with markers.
        def restore(match):
            return f"{prefix}{match.group(2)}{suffix}"

        return restore

    # Order matters: bold+italic must be restored before plain bold/italic.
    text = re.sub(r"\[BI(\d+)\](.*?)\[BI\1\]", wrap_with("**_", "_**"), text)
    text = re.sub(r"\[B(\d+)\](.*?)\[B\1\]", wrap_with("**", "**"), text)
    text = re.sub(r"\[IT(\d+)\](.*?)\[IT\1\]", wrap_with("_", "_"), text)

    def restore_link(match):
        position = int(match.group(1))
        if position >= len(link_map):
            return match.group(0)
        _, url = link_map[position]
        # The translated text between the markers becomes the new label.
        return f"[{match.group(2)}]({url})"

    text = re.sub(r"\[LINK(\d+)\](.*?)\[LINK\1\]", restore_link, text)

    def restore_doc_ref(match):
        position = int(match.group(1))
        return doc_ref_map[position] if position < len(doc_ref_map) else match.group(0)

    text = re.sub(r"\[DOCREF(\d+)\]", restore_doc_ref, text)
    # Repair spacing around page anchors, then collapse whitespace runs.
    text = text.replace("] (#page", "](#page")
    return " ".join(text.split())
47 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/translation/download_translation_model.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import time
3 | from configuration import service_logger
4 |
5 | OLLAMA_NOT_RUNNING_MSG = "could not connect to ollama server"
6 |
7 |
def is_ollama_running():
    """Return True when a local Ollama server responds to `ollama ls`."""
    try:
        probe = subprocess.run(["ollama", "ls"], capture_output=True, text=True)
    except FileNotFoundError:
        service_logger.error("Ollama is not installed or not in PATH.")
        return False
    combined_output = probe.stderr.lower() + probe.stdout.lower()
    return probe.returncode == 0 and OLLAMA_NOT_RUNNING_MSG not in combined_output
16 |
17 |
def start_ollama():
    """Launch `ollama serve` in the background; return True when the launch succeeded."""
    serve_command = ["ollama", "serve"]
    try:
        subprocess.Popen(serve_command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        service_logger.info("Starting Ollama server...")
        # Give the server a moment to bind before callers probe it.
        time.sleep(5)
        return True
    except Exception as launch_error:
        service_logger.error(f"Failed to start Ollama server: {launch_error}")
        return False
27 |
28 |
def model_name_variants(name):
    """Return every name Ollama may list this model under: base, base:latest and the given name."""
    base_name, _, _ = name.partition(":")
    return {base_name, base_name + ":latest", name}
32 |
33 |
def ensure_ollama_model(model_name):
    """Make sure `model_name` is available in a running Ollama server.

    Starts the server when it is not running, then pulls the model unless one
    of its name variants is already listed by `ollama ls`.

    :param model_name: model identifier, optionally with a tag (e.g. "llama3:8b")
    :return: True when the model is available, False on any failure
    """
    if not is_ollama_running():
        service_logger.info("Ollama server is not running. Attempting to start it...")
        if not start_ollama():
            service_logger.error("Could not start Ollama server. Exiting.")
            return False
        # Poll for the server to come up; the for/else fires when it never did.
        for _ in range(5):
            if is_ollama_running():
                break
            time.sleep(2)
        else:
            service_logger.error("Ollama server did not start in time.")
            return False

    try:
        result = subprocess.run(["ollama", "ls"], capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        service_logger.error(f"Error running 'ollama ls': {e}")
        return False

    # Parse the first column (model name) of each listing row.
    # Fix: the previous comprehension indexed line.split()[0] guarded only by
    # `if line`, which raised IndexError on whitespace-only lines.
    model_lines = []
    for line in result.stdout.splitlines():
        columns = line.split()
        if not columns or line.startswith("NAME"):  # skip blanks and the header row
            continue
        model_lines.append(columns[0])
    available_models = set(model_lines)
    variants = model_name_variants(model_name)

    if available_models & variants:
        service_logger.info(f"Model '{model_name}' already exists in Ollama.")
        return True

    service_logger.info(f"Model '{model_name}' not found. Pulling...")
    try:
        subprocess.run(["ollama", "pull", model_name], check=True)
        service_logger.info(f"Model '{model_name}' pulled successfully.")
        return True
    except subprocess.CalledProcessError as e:
        service_logger.error(f"Failed to pull model '{model_name}': {e}")
        return False
70 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/translation/encode_html_content.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 |
def encode_html(text):
    """Replace HTML formatting with translation-safe indexed placeholders.

    Returns a tuple (encoded_text, link_map, doc_ref_map): link_map holds
    (label, href) pairs and doc_ref_map the original reference strings, so
    decode_html can restore them after translation.
    NOTE(review): several literals in this view appear mangled by markup
    stripping (the anchor-tag regex and tag strings are missing) — confirm the
    exact tag literals against the original file before editing behavior.
    """
    # NOTE(review): the replace() below looks like it should map a non-breaking
    # space (U+00A0) to a regular space — confirm the first argument's byte.
    text = text.replace(" ", " ")

    link_map = []
    doc_ref_map = []
    bold_map = []
    italic_map = []
    bold_italic_map = []

    # Helper to encode doc refs
    def doc_ref_replacer(match):
        idx = len(doc_ref_map)
        doc_ref_map.append(match.group(0))
        return f"[DOCREF{idx}]"

    # Helper to encode bold+italic
    def bold_italic_replacer(match):
        text_content = match.group(1)
        idx = len(bold_italic_map)
        bold_italic_map.append(text_content)
        return f"[BI{idx}]{text_content}[BI{idx}]"

    # Helper to encode bold
    def bold_replacer(match):
        text_content = match.group(1)
        idx = len(bold_map)
        bold_map.append(text_content)
        return f"[B{idx}]{text_content}[B{idx}]"

    # 1. Encode document references FIRST - Updated patterns to match your actual format
    # Handle patterns like [\[9,](#page-5-9), [10,](#page-5-10), [11\]](#page-5-11), [\[12\]](#page-5-12), [\[13\]](#page-5-13)
    text = re.sub(r"(\[\\?\[?\*?\d+[,\.]?\\?\]?\]\(#page-\d+-\d+\))", doc_ref_replacer, text)

    # Also handle the original patterns from markdown version
    text = re.sub(r"(\[(?:\\?\d+[,\.]? ?)+\]\(#page-\d+-\d+\))", doc_ref_replacer, text)
    text = re.sub(r"(\[\d+[,\.]?\]\(#page-\d+-\d+\))", doc_ref_replacer, text)
    text = re.sub(r"(\[\*\d+[,\.]?\]\(#page-\d+-\d+\))", doc_ref_replacer, text)

    # 2. Encode links BEFORE formatting - handle complex nested structures
    def find_and_replace_links(text):
        offset = 0  # Track how much the text has shifted due to replacements

        # Find all ]*>'
        # NOTE(review): `pattern` is not defined in this view — the anchor-tag
        # regex assignment appears lost; confirm against the source file.
        matches = list(re.finditer(pattern, text, re.IGNORECASE))

        # Process from left to right
        for match in matches:
            start = match.start() + offset
            href = match.group(1)

            # Find the matching closing tag
            tag_count = 1
            pos = match.end() + offset
            content_start = pos

            while pos < len(text) and tag_count > 0:
                # Look for opening tags
                # NOTE(review): the find() argument below appears truncated in
                # this view (likely an opening-anchor literal) — confirm.
                next_open = text.find(" tags
                next_close = text.find("", pos)
                next_close_case = text.find("", pos)
                if next_close_case != -1 and (next_close == -1 or next_close_case < next_close):
                    next_close = next_close_case

                if next_close == -1:
                    break

                if next_open != -1 and next_open < next_close:
                    # Found nested opening tag
                    tag_count += 1
                    pos = next_open + 3
                else:
                    # Found closing tag
                    tag_count -= 1
                    if tag_count == 0:
                        # This is our matching closing tag
                        content = text[content_start:next_close]
                        end = next_close + 4  # +4 for

                        # Replace the entire link
                        idx = len(link_map)
                        link_map.append((content, href))
                        replacement = f"[LINK{idx}]{content}[LINK{idx}]"

                        original_length = end - start
                        new_length = len(replacement)

                        text = text[:start] + replacement + text[end:]
                        offset += new_length - original_length
                        break
                    else:
                        pos = next_close + 4

        return text

    text = find_and_replace_links(text)

    # 3. Encode bold+italic combinations BEFORE individual bold/italic
    # Handle text and text - only simple cases without nested links
    text = re.sub(r"([^<]+)", bold_italic_replacer, text, flags=re.IGNORECASE)
    text = re.sub(r"([^<]+)", bold_italic_replacer, text, flags=re.IGNORECASE)

    # 4. Encode bold (text) - only simple cases without nested tags
    text = re.sub(r"([^<]+)", bold_replacer, text, flags=re.IGNORECASE)

    # 5. Encode italic (text) - handle cases that might contain encoded links
    def italic_with_links_replacer(match):
        content = match.group(1)
        idx = len(italic_map)
        italic_map.append(content)
        return f"[IT{idx}]{content}[IT{idx}]"

    # Handle italic tags that might contain encoded links
    text = re.sub(
        r"([^<]*(?:\[LINK\d+\][^\[]*\[LINK\d+\][^<]*)*)", italic_with_links_replacer, text, flags=re.IGNORECASE
    )
    # Handle simple italic tags
    text = re.sub(r"([^<]+)", italic_with_links_replacer, text, flags=re.IGNORECASE)

    return text, link_map, doc_ref_map
129 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/translation/encode_markdown_content.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 |
def encode_markdown(text):
    """Replace markdown formatting with translation-safe indexed placeholders.

    Bold, italic, bold+italic, links and document references are substituted
    with indexed markers ([B0], [IT0], [BI0], [LINK0]..[LINK0], [DOCREF0]) so a
    machine-translation step cannot corrupt the markup; decode_markdown
    reverses the transformation.

    :param text: markdown source text
    :return: tuple (encoded_text, link_map, doc_ref_map) where link_map holds
        (label, url) pairs and doc_ref_map holds the original reference strings
    """
    text = text.replace("_ _", " ")

    link_map = []
    doc_ref_map = []
    bold_map = []
    italic_map = []
    bold_italic_map = []

    # Helper to encode links with nested brackets in label
    def link_replacer(match):
        label = match.group(1)[1:-1]  # Remove outer []
        url = match.group(3)
        idx = len(link_map)
        link_map.append((label, url))
        return f"[LINK{idx}]{label}[LINK{idx}]"

    # Helper to encode doc refs - sequential numbering
    def doc_ref_replacer(match):
        idx = len(doc_ref_map)
        doc_ref_map.append(match.group(0))
        return f"[DOCREF{idx}]"

    # Helper to encode bold+italic
    def bold_italic_replacer(match):
        text_content = match.group(1)
        idx = len(bold_italic_map)
        bold_italic_map.append(text_content)
        return f"[BI{idx}]{text_content}[BI{idx}]"

    # Helper to encode bold
    def bold_replacer(match):
        text_content = match.group(1)
        idx = len(bold_map)
        bold_map.append(text_content)
        return f"[B{idx}]{text_content}[B{idx}]"

    # Helper to encode italic
    def italic_replacer(match):
        text_content = match.group(1)
        idx = len(italic_map)
        italic_map.append(text_content)
        return f"[IT{idx}]{text_content}[IT{idx}]"

    # 1. Encode ALL document references in ONE PASS for sequential numbering
    doc_ref_patterns = [
        r"\[\\?\[\d+,\]\(#page-\d+-\d+\)",  # [\[9,](#page-5-9)
        r"\[\\?\[\d+\\?\]\]\(#page-\d+-\d+\)",  # [\[12\]](#page-5-12)
        r"\[\*\d+\]\(#page-\d+-\d+\)",  # [*1221](#page-3-7)
        r"\[\d+,\]\(#page-\d+-\d+\)",  # [10,](#page-5-10)
        r"\[\d+\\?\]\]\(#page-\d+-\d+\)",  # [11\]](#page-5-11)
        r"\[\d+\]\(#page-\d+-\d+\)",  # [1221](#page-3-7)
    ]

    # Combine all patterns with alternation (|)
    combined_pattern = "(" + "|".join(doc_ref_patterns) + ")"
    text = re.sub(combined_pattern, doc_ref_replacer, text)

    # 2. Encode links BEFORE formatting to avoid matching underscores in URLs.
    # Repeat until a pass makes no replacement, so labels containing bracketed
    # text are handled.
    link_pattern = re.compile(r"(\[((?:[^\[\]]+|\[[^\[\]]*\])*)\])\((https?://[^\)]+)\)")
    while True:
        # Pass the replacer directly; the previous lambda wrapper was redundant.
        new_text = link_pattern.sub(link_replacer, text)
        if new_text == text:
            break
        text = new_text

    # 3. Encode bold+italic BEFORE individual bold/italic
    text = re.sub(r"\*\*\_([^\*_]+)\_\*\*", bold_italic_replacer, text)
    text = re.sub(r"_\*([^\*_]+)\*_", bold_italic_replacer, text)

    # 4. Encode bold
    text = re.sub(r"\*\*([^\*]+)\*\*", bold_replacer, text)

    # 5. Encode italic
    text = re.sub(r"\_([^\_]+)\_", italic_replacer, text)

    return text, link_map, doc_ref_map
81 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/translation/ollama_container_manager.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import requests
4 | import json
5 | from typing import Optional, Any
6 | from configuration import service_logger
7 |
8 |
class OllamaContainerManager:
    """Client for a containerized Ollama server: availability checks, model pulls and chat calls."""

    def __init__(self, ollama_host: Optional[str] = None):
        # Default to the docker-compose service name unless OLLAMA_HOST overrides it.
        self.ollama_host = ollama_host or os.getenv("OLLAMA_HOST", "http://ollama:11434")
        self.api_base_url = f"{self.ollama_host}/api"
        # Per-request timeout (seconds) and retry budget used by chat_with_timeout.
        self.timeout = 600
        self.max_retries = 5

    def is_ollama_available(self) -> bool:
        """Return True when the Ollama HTTP API answers /api/tags with 200."""
        try:
            response = requests.get(f"{self.api_base_url}/tags", timeout=10)
            return response.status_code == 200
        except Exception as e:
            service_logger.debug(f"Ollama availability check failed: {e}")
            return False

    def ensure_model_available(self, model_name: str) -> bool:
        """Return True when the model is present, downloading it first when necessary."""
        try:
            if self._is_model_available(model_name):
                service_logger.info(f"\033[92mModel '{model_name}' is available\033[0m")
                return True

            service_logger.info(f"\033[93mModel '{model_name}' not found. Downloading...\033[0m")
            return self._download_model(model_name)

        except Exception as e:
            service_logger.error(f"Error ensuring model availability: {e}")
            return False

    def _is_model_available(self, model_name: str) -> bool:
        """Check /api/tags for the model under any of its name variants."""
        try:
            response = requests.get(f"{self.api_base_url}/tags", timeout=10)
            if response.status_code != 200:
                return False

            models_data = response.json()
            available_models = [model["name"] for model in models_data.get("models", [])]

            # Ollama may list "name", "name:latest" or just the base name.
            model_variants = {model_name, f"{model_name}:latest", model_name.split(":")[0]}
            return any(variant in available_models for variant in model_variants)

        except Exception as e:
            service_logger.error(f"Error checking model availability: {e}")
            return False

    def _download_model(self, model_name: str) -> bool:
        """Stream /api/pull progress and return True on success.

        Progress lines are JSON objects; malformed lines are skipped.
        """
        try:
            response = requests.post(f"{self.api_base_url}/pull", json={"name": model_name}, stream=True)

            if response.status_code != 200:
                service_logger.error(f"Failed to start model download: {response.text}")
                return False

            for idx, line in enumerate(response.iter_lines()):
                if line:
                    try:
                        data = json.loads(line)
                        # Log roughly every 100th status line to avoid flooding the log.
                        if "status" in data and idx % 100 == 0:
                            service_logger.info(f"Model download: {data['status']}")
                        if data.get("status") == "success":
                            service_logger.info(f"Model '{model_name}' downloaded successfully")
                            return True
                    except json.JSONDecodeError:
                        continue

            # NOTE(review): reaching the end of the stream without a "success"
            # status is still reported as success — confirm this is intended.
            return True

        except Exception as e:
            service_logger.error(f"Error downloading model '{model_name}': {e}")
            return False

    def chat_with_timeout(
        self, model: str, messages: list[dict], source_markup: str, timeout: Optional[int] = None
    ) -> dict[str, Any] | str:
        """Send a chat request with retries; return the response JSON or the untouched markup.

        :param source_markup: fallback value returned verbatim when every
            attempt fails, so callers always receive usable content.
        :param timeout: per-request timeout in seconds (defaults to self.timeout).
        """
        timeout = timeout or self.timeout

        for attempt in range(self.max_retries + 1):
            try:
                if attempt > 0:
                    service_logger.info(f"Retrying chat request (attempt {attempt + 1}/{self.max_retries + 1})")
                    time.sleep(10)

                return self._make_chat_request(
                    model,
                    messages,
                    timeout,
                )

            except requests.exceptions.Timeout:
                service_logger.warning(f"Chat request timed out after {timeout} seconds (attempt {attempt})")
                if attempt < self.max_retries:
                    continue
                else:
                    service_logger.error(f"Chat request failed after {self.max_retries} attempts due to timeout")
                    return source_markup

            except Exception as e:
                service_logger.error(f"Chat request failed (attempt {attempt}): {e}")
                if attempt < self.max_retries:
                    continue
                else:
                    service_logger.error(f"Chat request failed after {self.max_retries} attempts")
                    return source_markup

        return source_markup

    def _make_chat_request(self, model: str, messages: list, timeout: int) -> dict[str, Any]:
        """POST to /api/chat (non-streaming) and return the parsed JSON body.

        :raises Exception: when the server answers with a non-200 status.
        """
        payload = {"model": model, "messages": messages, "stream": False}

        response = requests.post(f"{self.api_base_url}/chat", json=payload, timeout=timeout)

        if response.status_code != 200:
            raise Exception(f"Chat request failed with status {response.status_code}: {response.text}")

        return response.json()

    def ensure_service_ready(self, model_name: str) -> bool:
        """Verify the server is reachable and the model is available (pulling it if needed)."""
        try:
            if not self.is_ollama_available():
                service_logger.error("Ollama service is not available. Make sure the Ollama container is running.")
                return False

            return self.ensure_model_available(model_name)

        except Exception as e:
            service_logger.error(f"Error ensuring service readiness: {e}")
            return False
136 |
--------------------------------------------------------------------------------
/src/adapters/infrastructure/visualization_service_adapter.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from os import makedirs
3 | from os.path import join
4 | from pdf_annotate import PdfAnnotator, Location, Appearance
5 | from starlette.responses import FileResponse
6 | from ports.services.visualization_service import VisualizationService
7 | from configuration import ROOT_PATH
8 |
# Hex annotation color for each DocLayNet segment type drawn on the PDF.
DOCLAYNET_COLOR_BY_TYPE = {
    "Caption": "#FFC300",
    "Footnote": "#581845",
    "Formula": "#FF5733",
    "List item": "#008B8B",
    "Page footer": "#FF5733",
    "Page header": "#581845",
    "Picture": "#C70039",
    "Section header": "#C70039",
    "Table": "#FF8C00",
    "Text": "#606060",
    "Title": "#EED400",
}
22 |
23 |
class VisualizationServiceAdapter(VisualizationService):
    """Draws predicted segment boxes and labels onto a PDF using pdf-annotate."""

    def create_pdf_visualization(self, pdf_path: Path, segment_boxes: list[dict]) -> Path:
        """Annotate pdf_path in place with one rectangle + label per segment box.

        Segment numbering restarts at 1 on every page; the input file is
        overwritten with the annotated version and its path is returned.
        """
        # Kept for its side effect: make sure the shared output directory exists.
        pdf_outputs_path = join(ROOT_PATH, "pdf_outputs")
        makedirs(pdf_outputs_path, exist_ok=True)
        annotator = PdfAnnotator(str(pdf_path))
        segment_index = 0
        current_page = 1

        for segment_box in segment_boxes:
            if int(segment_box["page_number"]) != current_page:
                segment_index = 0
                # Fix: jump straight to this segment's page. The previous
                # `current_page += 1` assumed consecutive pages, so a page with
                # no segments desynchronized the counter and wrongly reset the
                # numbering on every later segment of the same page.
                current_page = int(segment_box["page_number"])
            page_height = int(segment_box["page_height"])
            self._add_prediction_annotation(annotator, segment_box, segment_index, page_height)
            segment_index += 1

        annotator.write(str(pdf_path))
        return pdf_path

    def get_visualization_response(self, pdf_path: Path) -> FileResponse:
        """Wrap the annotated PDF in a FileResponse for download."""
        return FileResponse(path=pdf_path, media_type="application/pdf", filename=pdf_path.name)

    def _hex_color_to_rgb(self, color: str) -> tuple:
        """Convert '#RRGGBB' to an (r, g, b, alpha) tuple with components in [0, 1]."""
        r, g, b = int(color[1:3], 16), int(color[3:5], 16), int(color[5:7], 16)
        alpha = 1
        return r / 255, g / 255, b / 255, alpha

    def _add_prediction_annotation(
        self, annotator: PdfAnnotator, segment_box: dict, segment_index: int, page_height: int
    ) -> None:
        """Draw the segment outline, a filled label background and the label text."""
        predicted_type = segment_box["type"]
        color = DOCLAYNET_COLOR_BY_TYPE[predicted_type]
        # PDF coordinates grow upward while segment boxes grow downward: flip on page height.
        left, top, right, bottom = (
            segment_box["left"],
            page_height - segment_box["top"],
            segment_box["left"] + segment_box["width"],
            page_height - (segment_box["top"] + segment_box["height"]),
        )
        text_box_size = len(predicted_type) * 8 + 8  # rough pixel width for an 8pt label

        annotator.add_annotation(
            "square",
            Location(x1=left, y1=bottom, x2=right, y2=top, page=int(segment_box["page_number"]) - 1),
            Appearance(stroke_color=self._hex_color_to_rgb(color)),
        )

        annotator.add_annotation(
            "square",
            Location(x1=left, y1=top, x2=left + text_box_size, y2=top + 10, page=int(segment_box["page_number"]) - 1),
            Appearance(fill=self._hex_color_to_rgb(color)),
        )

        content = predicted_type.capitalize() + f" [{str(segment_index+1)}]"
        annotator.add_annotation(
            "text",
            Location(x1=left, y1=top, x2=left + text_box_size, y2=top + 10, page=int(segment_box["page_number"]) - 1),
            Appearance(content=content, font_size=8, fill=(1, 1, 1), stroke_width=3),
        )
82 |
--------------------------------------------------------------------------------
/src/adapters/ml/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/ml/__init__.py
--------------------------------------------------------------------------------
/src/adapters/ml/fast_trainer/Paragraph.py:
--------------------------------------------------------------------------------
1 | from pdf_features import PdfToken
2 |
3 |
class Paragraph:
    """A group of consecutive PDF tokens that belong to the same paragraph."""

    def __init__(self, tokens: list[PdfToken], pdf_name: str = ""):
        # Tokens are kept in reading order; more can be appended via add_token.
        self.tokens = tokens
        self.pdf_name = pdf_name

    def add_token(self, token: PdfToken) -> None:
        """Append a token to the paragraph (mutates the stored token list)."""
        self.tokens.append(token)
11 |
--------------------------------------------------------------------------------
/src/adapters/ml/fast_trainer/ParagraphExtractorTrainer.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from adapters.ml.fast_trainer.Paragraph import Paragraph
4 | from domain.PdfSegment import PdfSegment
5 | from pdf_features import PdfToken
6 | from pdf_token_type_labels import TokenType
7 | from adapters.ml.pdf_tokens_type_trainer.TokenFeatures import TokenFeatures
8 | from adapters.ml.pdf_tokens_type_trainer.TokenTypeTrainer import TokenTypeTrainer
9 |
10 |
class ParagraphExtractorTrainer(TokenTypeTrainer):
    """LightGBM trainer that predicts whether consecutive tokens belong to the same paragraph."""

    def get_context_features(self, token_features: TokenFeatures, page_tokens: list[PdfToken], token_index: int):
        """Build the feature row for a token pair from its surrounding context window.

        For each of context_size * 2 neighbouring token pairs, the base token
        features are concatenated with one-hot token-type features.
        NOTE(review): indexing assumes page_tokens is padded so that
        token_index - context_size .. token_index + context_size + 1 stays in
        range — confirm padding is applied by the caller.
        """
        token_row_features = list()
        first_token_from_context = token_index - self.model_configuration.context_size
        for i in range(self.model_configuration.context_size * 2):
            first_token = page_tokens[first_token_from_context + i]
            second_token = page_tokens[first_token_from_context + i + 1]
            features = token_features.get_features(first_token, second_token, page_tokens)
            features += self.get_paragraph_extraction_features(first_token, second_token)
            token_row_features.extend(features)

        return token_row_features

    @staticmethod
    def get_paragraph_extraction_features(first_token: PdfToken, second_token: PdfToken) -> list[int]:
        """One-hot encode both tokens' types over TokenType and concatenate the vectors."""
        one_hot_token_type_1 = [1 if token_type == first_token.token_type else 0 for token_type in TokenType]
        one_hot_token_type_2 = [1 if token_type == second_token.token_type else 0 for token_type in TokenType]
        return one_hot_token_type_1 + one_hot_token_type_2

    def loop_token_next_token(self):
        """Yield (page, token, next_token) pairs for every page of every PDF.

        A single-token page yields the token paired with itself so the page is
        still visited downstream.
        """
        for pdf_features in self.pdfs_features:
            for page in pdf_features.pages:
                if not page.tokens:
                    continue
                if len(page.tokens) == 1:
                    yield page, page.tokens[0], page.tokens[0]
                for token, next_token in zip(page.tokens, page.tokens[1:]):
                    yield page, token, next_token

    def get_pdf_segments(self, paragraph_extractor_model_path: str | Path) -> list[PdfSegment]:
        """Predict paragraphs with the given model and convert each one into a PdfSegment."""
        paragraphs = self.get_paragraphs(paragraph_extractor_model_path)
        pdf_segments = [PdfSegment.from_pdf_tokens(paragraph.tokens, paragraph.pdf_name) for paragraph in paragraphs]

        return pdf_segments

    def get_paragraphs(self, paragraph_extractor_model_path) -> list[Paragraph]:
        """Group tokens into paragraphs using the trained pair classifier.

        A truthy token.prediction means the token and next_token stay in the
        same paragraph; otherwise next_token opens a new one. Each new page
        starts a new paragraph.
        """
        self.predict(paragraph_extractor_model_path)
        paragraphs: list[Paragraph] = []
        last_page = None
        for page, token, next_token in self.loop_token_next_token():
            if last_page != page:
                last_page = page
                paragraphs.append(Paragraph([token], page.pdf_name))
            if token == next_token:
                # Single-token page sentinel: nothing to merge.
                continue
            if token.prediction:
                paragraphs[-1].add_token(next_token)
                continue
            paragraphs.append(Paragraph([next_token], page.pdf_name))

        return paragraphs
62 |
--------------------------------------------------------------------------------
/src/adapters/ml/fast_trainer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/ml/fast_trainer/__init__.py
--------------------------------------------------------------------------------
/src/adapters/ml/fast_trainer/model_configuration.py:
--------------------------------------------------------------------------------
1 | from adapters.ml.pdf_tokens_type_trainer.ModelConfiguration import ModelConfiguration
2 |
# LightGBM hyperparameters for the paragraph-extraction (token-pair) model.
# num_class=2: binary decision — same paragraph vs. new paragraph.
# seed/deterministic pin the training for reproducible models.
config_json = {
    "boosting_type": "gbdt",
    "verbose": -1,
    "learning_rate": 0.1,
    "num_class": 2,
    "context_size": 1,
    "num_boost_round": 400,
    "num_leaves": 191,
    "bagging_fraction": 0.9166599392739231,
    "bagging_freq": 7,
    "feature_fraction": 0.3116707710163228,
    "lambda_l1": 0.0006901861637621734,
    "lambda_l2": 1.1886914989632197e-05,
    "min_data_in_leaf": 50,
    "feature_pre_filter": True,
    "seed": 22,
    "deterministic": True,
}

# Ready-to-use configuration instance shared by the fast trainer.
MODEL_CONFIGURATION = ModelConfiguration(**config_json)

if __name__ == "__main__":
    print(MODEL_CONFIGURATION)
26 |
--------------------------------------------------------------------------------
/src/adapters/ml/fast_trainer_adapter.py:
--------------------------------------------------------------------------------
1 | from os.path import join
2 | from domain.PdfImages import PdfImages
3 | from domain.PdfSegment import PdfSegment
4 | from ports.services.ml_model_service import MLModelService
5 | from adapters.ml.fast_trainer.ParagraphExtractorTrainer import ParagraphExtractorTrainer
6 | from adapters.ml.fast_trainer.model_configuration import MODEL_CONFIGURATION as PARAGRAPH_EXTRACTION_CONFIGURATION
7 | from adapters.ml.pdf_tokens_type_trainer.TokenTypeTrainer import TokenTypeTrainer
8 | from adapters.ml.pdf_tokens_type_trainer.ModelConfiguration import ModelConfiguration
9 | from configuration import ROOT_PATH, service_logger
10 |
11 |
class FastTrainerAdapter(MLModelService):
    """ML model service backed by the fast LightGBM token/paragraph pipeline."""

    def predict_document_layout(self, pdf_images: list[PdfImages]) -> list[PdfSegment]:
        """Delegate layout prediction to the fast path."""
        return self.predict_layout_fast(pdf_images)

    def predict_layout_fast(self, pdf_images: list[PdfImages]) -> list[PdfSegment]:
        """Run token typing then paragraph extraction on the first PdfImages entry."""
        service_logger.info("Creating Paragraph Tokens [fast]")

        first_pdf_images = pdf_images[0]

        # Assign a token type to every token with the pretrained LightGBM model.
        token_typer = TokenTypeTrainer([first_pdf_images.pdf_features], ModelConfiguration())
        token_typer.set_token_types(join(ROOT_PATH, "models", "token_type_lightgbm.model"))

        # Group the typed tokens into paragraph segments.
        paragraph_trainer = ParagraphExtractorTrainer(
            pdfs_features=[first_pdf_images.pdf_features], model_configuration=PARAGRAPH_EXTRACTION_CONFIGURATION
        )
        pdf_segments = paragraph_trainer.get_pdf_segments(join(ROOT_PATH, "models", "paragraph_extraction_lightgbm.model"))

        first_pdf_images.remove_images()

        return pdf_segments
32 |
--------------------------------------------------------------------------------
/src/adapters/ml/pdf_tokens_type_trainer/ModelConfiguration.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, asdict
2 |
3 | from pdf_token_type_labels import TokenType
4 |
5 |
@dataclass
class ModelConfiguration:
    """Hyperparameters for the LightGBM token-type model (tuned defaults)."""

    context_size: int = 4  # neighbouring tokens on each side used as context
    num_boost_round: int = 700
    num_leaves: int = 127
    bagging_fraction: float = 0.6810645192499981
    lambda_l1: float = 1.1533558410486358e-08
    lambda_l2: float = 4.91211684620458
    feature_fraction: float = 0.7087268965467017
    bagging_freq: int = 10
    min_data_in_leaf: int = 47
    feature_pre_filter: bool = False
    boosting_type: str = "gbdt"
    objective: str = "multiclass"
    metric: str = "multi_logloss"
    learning_rate: float = 0.1
    seed: int = 22  # fixed seed for reproducible training
    num_class: int = len(TokenType)  # one class per token type label
    verbose: int = -1
    deterministic: bool = True
    resume_training: bool = False  # when True, PdfTrainer refits an existing model file
    early_stopping_rounds: int | None = None  # disabled unless explicitly set

    def dict(self):
        """Return the configuration as a plain dict (suitable as lightgbm params)."""
        return asdict(self)
31 |
--------------------------------------------------------------------------------
/src/adapters/ml/pdf_tokens_type_trainer/PdfTrainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import exists, join
3 | from pathlib import Path
4 |
5 | import lightgbm as lgb
6 | import numpy as np
7 |
8 | from pdf_features import PdfFeatures, PdfTokenStyle
9 | from pdf_features import PdfFont
10 | from pdf_features import PdfToken
11 | from pdf_features import Rectangle
12 | from pdf_token_type_labels import TokenType
13 | from adapters.ml.pdf_tokens_type_trainer.ModelConfiguration import ModelConfiguration
14 | from adapters.ml.pdf_tokens_type_trainer.download_models import pdf_tokens_type_model
15 |
16 |
class PdfTrainer:
    """Base class for LightGBM trainers operating on PdfFeatures token data."""

    def __init__(self, pdfs_features: list[PdfFeatures], model_configuration: ModelConfiguration = None):
        self.pdfs_features = pdfs_features
        # Fall back to the default configuration when none is supplied.
        self.model_configuration = model_configuration if model_configuration else ModelConfiguration()

    def get_model_input(self) -> np.ndarray:
        """Build the feature matrix. Base implementation is a stub returning None;
        subclasses must override for train/predict to work."""
        pass

    @staticmethod
    def features_rows_to_x(features_rows):
        """Stack per-token feature rows into a 2-D numpy array (empty input -> 0x0 array)."""
        if not features_rows:
            return np.zeros((0, 0))

        x = np.zeros(((len(features_rows)), len(features_rows[0])))
        for i, v in enumerate(features_rows):
            x[i] = v
        return x

    def train(self, model_path: str | Path, labels: list[int]):
        """Train (or, with resume_training, refit) a LightGBM model and save it to model_path."""
        print("Getting model input")
        x_train = self.get_model_input()

        if not x_train.any():
            print("No data for training")
            return

        lgb_train = lgb.Dataset(x_train, labels)
        # NOTE(review): the eval set is the training set itself, so the "valid"
        # score reported during training is a training score — confirm intended.
        lgb_eval = lgb.Dataset(x_train, labels, reference=lgb_train)
        print("Training")

        if self.model_configuration.resume_training and exists(model_path):
            model = lgb.Booster(model_file=model_path)
            gbm = model.refit(x_train, labels)
        else:
            gbm = lgb.train(params=self.model_configuration.dict(), train_set=lgb_train, valid_sets=[lgb_eval])

        print("Saving")
        gbm.save_model(model_path, num_iteration=gbm.best_iteration)

    def loop_tokens(self):
        """Yield every token of every PDF in order."""
        for pdf_features in self.pdfs_features:
            for page, token in pdf_features.loop_tokens():
                yield token

    @staticmethod
    def get_padding_token(segment_number: int, page_number: int):
        """Create an empty placeholder token used to pad context windows at page boundaries."""
        return PdfToken(
            page_number=page_number,
            id="pad_token",
            content="",
            font=PdfFont(font_id="pad_font_id", font_size=0, bold=False, italics=False, color="black"),
            reading_order_no=segment_number,
            bounding_box=Rectangle.from_coordinates(0, 0, 0, 0),
            token_type=TokenType.TEXT,
            token_style=PdfTokenStyle(
                font=PdfFont(font_id="pad_font_id", font_size=0, bold=False, italics=False, color="black")
            ),
        )

    def predict(self, model_path: str | Path = None):
        """Run the LightGBM model over the feature matrix.

        NOTE(review): returns the raw prediction array normally, but returns
        self.pdfs_features when there is no input — callers must handle both
        return shapes.
        """
        model_path = model_path if model_path else pdf_tokens_type_model
        x = self.get_model_input()

        if not x.any():
            return self.pdfs_features

        lightgbm_model = lgb.Booster(model_file=model_path)
        return lightgbm_model.predict(x)

    def save_training_data(self, save_folder_path: str | Path, labels: list[int]):
        """Dump the feature matrix and labels as x.npy / y.npy under save_folder_path."""
        os.makedirs(save_folder_path, exist_ok=True)

        x = self.get_model_input()

        np.save(join(str(save_folder_path), "x.npy"), x)
        np.save(join(str(save_folder_path), "y.npy"), labels)
93 |
--------------------------------------------------------------------------------
/src/adapters/ml/pdf_tokens_type_trainer/TokenFeatures.py:
--------------------------------------------------------------------------------
1 | import string
2 | import unicodedata
3 |
4 | from pdf_features import PdfFeatures
5 | from pdf_features import PdfToken
6 | from adapters.ml.pdf_tokens_type_trainer.config import CHARACTER_TYPE
7 |
8 |
class TokenFeatures:
    """Builds pairwise feature vectors for consecutive tokens of one PDF.

    The order of the values returned by get_features is part of the trained
    model's contract — do not reorder or remove entries.
    """

    def __init__(self, pdfs_features: PdfFeatures):
        self.pdfs_features = pdfs_features

    def get_features(self, token_1: PdfToken, token_2: PdfToken, page_tokens: list[PdfToken]):
        """Feature vector for an ordered token pair: content stats, layout, unicode one-hots."""
        same_font = token_1.font.font_id == token_2.font.font_id

        return (
            [
                same_font,
                self.pdfs_features.pdf_modes.font_size_mode / 100,
                len(token_1.content),
                len(token_2.content),
                token_1.content.count(" "),
                token_2.content.count(" "),
                sum(character in string.punctuation for character in token_1.content),
                sum(character in string.punctuation for character in token_2.content),
            ]
            + self.get_position_features(token_1, token_2, page_tokens)
            + self.get_unicode_categories(token_1)
            + self.get_unicode_categories(token_2)
        )

    def get_position_features(self, token_1: PdfToken, token_2: PdfToken, page_tokens):
        """Layout features: distances, gaps and alignment between the two tokens' boxes."""
        left_1 = token_1.bounding_box.left
        right_1 = token_1.bounding_box.right
        height_1 = token_1.bounding_box.height
        width_1 = token_1.bounding_box.width

        left_2 = token_2.bounding_box.left
        right_2 = token_2.bounding_box.right
        height_2 = token_2.bounding_box.height
        width_2 = token_2.bounding_box.width

        # Horizontal gaps to the immediate neighbours recorded in each token's context.
        right_gap_1, left_gap_2 = (
            token_1.pdf_token_context.left_of_token_on_the_right - right_1,
            left_2 - token_2.pdf_token_context.right_of_token_on_the_left,
        )

        # Extents of the line each token belongs to, taking neighbours into account.
        absolute_right_1 = max(right_1, token_1.pdf_token_context.right_of_token_on_the_right)
        absolute_right_2 = max(right_2, token_2.pdf_token_context.right_of_token_on_the_right)

        absolute_left_1 = min(left_1, token_1.pdf_token_context.left_of_token_on_the_left)
        absolute_left_2 = min(left_2, token_2.pdf_token_context.left_of_token_on_the_left)

        right_distance, left_distance, height_difference = left_2 - left_1 - width_1, left_1 - left_2, height_1 - height_2

        top_distance = token_2.bounding_box.top - token_1.bounding_box.top - height_1
        top_distance_gaps = self.get_top_distance_gap(token_1, token_2, page_tokens)

        start_lines_differences = absolute_left_1 - absolute_left_2
        end_lines_difference = abs(absolute_right_1 - absolute_right_2)

        return [
            absolute_right_1,
            token_1.bounding_box.top,
            right_1,
            width_1,
            height_1,
            token_2.bounding_box.top,
            right_2,
            width_2,
            height_2,
            right_distance,
            left_distance,
            right_gap_1,
            left_gap_2,
            height_difference,
            top_distance,
            top_distance - self.pdfs_features.pdf_modes.lines_space_mode,
            top_distance_gaps,
            top_distance - height_1,
            end_lines_difference,
            start_lines_differences,
            self.pdfs_features.pdf_modes.lines_space_mode - top_distance_gaps,
            self.pdfs_features.pdf_modes.right_space_mode - absolute_right_1,
        ]

    @staticmethod
    def get_top_distance_gap(token_1: PdfToken, token_2: PdfToken, page_tokens):
        """Vertical distance between the two tokens, discounting tokens lying between them."""
        top_distance = token_2.bounding_box.top - token_1.bounding_box.top - token_1.bounding_box.height
        tokens_in_the_middle = [
            token
            for token in page_tokens
            if token_1.bounding_box.bottom <= token.bounding_box.top < token_2.bounding_box.top
        ]

        gap_middle_bottom = 0
        gap_middle_top = 0

        if tokens_in_the_middle:
            tokens_in_the_middle_top = min([token.bounding_box.top for token in tokens_in_the_middle])
            tokens_in_the_middle_bottom = max([token.bounding_box.bottom for token in tokens_in_the_middle])
            gap_middle_top = tokens_in_the_middle_top - token_1.bounding_box.top - token_1.bounding_box.height
            gap_middle_bottom = token_2.bounding_box.top - tokens_in_the_middle_bottom

        top_distance_gaps = top_distance - (gap_middle_bottom - gap_middle_top)
        return top_distance_gaps

    @staticmethod
    def get_unicode_categories(token: PdfToken):
        """One-hot encode the unicode categories of the first two and last two characters.

        Padding tokens return a vector of -1s of the same length.
        """
        if token.id == "pad_token":
            return [-1] * len(CHARACTER_TYPE) * 4

        # For contents shorter than 4 characters the slices overlap, so some
        # characters are encoded twice; missing slots become "no_category".
        categories = [unicodedata.category(letter) for letter in token.content[:2] + token.content[-2:]]
        categories += ["no_category"] * (4 - len(categories))

        categories_one_hot_encoding = []

        for category in categories:
            one_hot = [0] * len(CHARACTER_TYPE)
            if category in CHARACTER_TYPE:
                one_hot[CHARACTER_TYPE.index(category)] = 1
            categories_one_hot_encoding.extend(one_hot)

        return categories_one_hot_encoding
127 |
--------------------------------------------------------------------------------
/src/adapters/ml/pdf_tokens_type_trainer/TokenTypeTrainer.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import numpy as np
4 | from tqdm import tqdm
5 |
6 | from pdf_features import PdfToken
7 | from pdf_token_type_labels import TokenType
8 | from adapters.ml.pdf_tokens_type_trainer.PdfTrainer import PdfTrainer
9 | from adapters.ml.pdf_tokens_type_trainer.TokenFeatures import TokenFeatures
10 |
11 |
class TokenTypeTrainer(PdfTrainer):
    """Trainer/predictor that assigns a TokenType to every PDF token.

    Each token's features are built from its neighbours inside a sliding context
    window; pages are padded with sentinel tokens at both ends.
    """

    def get_model_input(self) -> np.ndarray:
        """Build the feature matrix: one row per real token, with context features."""
        features_rows = []

        context_size = self.model_configuration.context_size
        for token_features, page in self.loop_token_features():
            # Pad both ends of the page so every real token has a full context window.
            # The +/-999999 reading-order numbers keep padding clearly outside the real range.
            page_tokens = [
                self.get_padding_token(segment_number=i - 999999, page_number=page.page_number)
                for i in range(context_size)
            ]
            page_tokens += page.tokens
            page_tokens += [
                self.get_padding_token(segment_number=999999 + i, page_number=page.page_number)
                for i in range(context_size)
            ]

            # Only indexes of real (non-padding) tokens produce feature rows.
            tokens_indexes = range(context_size, len(page_tokens) - context_size)
            page_features = [self.get_context_features(token_features, page_tokens, i) for i in tokens_indexes]
            features_rows.extend(page_features)

        return self.features_rows_to_x(features_rows)

    def loop_token_features(self):
        """Yield a (TokenFeatures, page) pair for every non-empty page of every PDF."""
        for pdf_features in tqdm(self.pdfs_features):
            token_features = TokenFeatures(pdf_features)

            for page in pdf_features.pages:
                if not page.tokens:
                    continue

                yield token_features, page

    def get_context_features(self, token_features: TokenFeatures, page_tokens: list[PdfToken], token_index: int):
        """Concatenate pairwise features for consecutive token pairs in the window around token_index."""
        token_row_features = []
        first_token_from_context = token_index - self.model_configuration.context_size
        for i in range(self.model_configuration.context_size * 2):
            first_token = page_tokens[first_token_from_context + i]
            second_token = page_tokens[first_token_from_context + i + 1]
            token_row_features.extend(token_features.get_features(first_token, second_token, page_tokens))

        return token_row_features

    def predict(self, model_path: str | Path | None = None):
        """Run the model and store the predicted class index on each token's `prediction`."""
        predictions = super().predict(model_path)
        predictions_assigned = 0
        for token_features, page in self.loop_token_features():
            # Pages are visited in the same order as in get_model_input, so the
            # prediction rows line up with each page's tokens.
            for token, prediction in zip(
                page.tokens, predictions[predictions_assigned : predictions_assigned + len(page.tokens)]
            ):
                token.prediction = int(np.argmax(prediction))

            predictions_assigned += len(page.tokens)

    def set_token_types(self, model_path: str | Path | None = None):
        """Predict, then translate each token's prediction index into a TokenType."""
        self.predict(model_path)
        for token in self.loop_tokens():
            token.token_type = TokenType.from_index(token.prediction)
67 |
--------------------------------------------------------------------------------
/src/adapters/ml/pdf_tokens_type_trainer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/ml/pdf_tokens_type_trainer/__init__.py
--------------------------------------------------------------------------------
/src/adapters/ml/pdf_tokens_type_trainer/config.py:
--------------------------------------------------------------------------------
from os.path import join
from pathlib import Path

# Repository root: five levels up from this file.
ROOT_PATH = Path(__file__).parent.parent.parent.parent.parent.absolute()
# The pdf-labeled-data project is expected to live next to the repository checkout.
PDF_LABELED_DATA_ROOT_PATH = Path(join(ROOT_PATH.parent.absolute(), "pdf-labeled-data"))
TOKEN_TYPE_LABEL_PATH = Path(join(PDF_LABELED_DATA_ROOT_PATH, "labeled_data", "token_type"))

TRAINED_MODEL_PATH = join(ROOT_PATH, "model", "pdf_tokens_type.model")
TOKEN_TYPE_RELATIVE_PATH = join("labeled_data", "token_type")
MISTAKES_RELATIVE_PATH = join("labeled_data", "task_mistakes")

XML_NAME = "etree.xml"
LABELS_FILE_NAME = "labels.json"
STATUS_FILE_NAME = "status.txt"

# Unicode general categories used for the one-hot feature encoding.
# Order defines feature indices consumed by the trained model — do not reorder.
CHARACTER_TYPE = (
    "Lt Lo Sk Lm Sm Cf Nl Pe Po Pd Me Sc Ll Pf Mc Lu Zs Cn Cc No Co Ps Nd Mn Pi So Pc"
).split()

if __name__ == "__main__":
    print(ROOT_PATH)
    print(PDF_LABELED_DATA_ROOT_PATH)
49 |
--------------------------------------------------------------------------------
/src/adapters/ml/pdf_tokens_type_trainer/download_models.py:
--------------------------------------------------------------------------------
from huggingface_hub import hf_hub_download

# Both artifacts are fetched from the Hugging Face hub at import time (the hub
# client caches them locally), pinned to exact revisions for reproducibility.

# Pretrained LightGBM model that assigns a TokenType to each PDF token.
pdf_tokens_type_model = hf_hub_download(
    repo_id="HURIDOCS/pdf-segmentation",
    filename="pdf_tokens_type.model",
    revision="c71f833500707201db9f3649a6d2010d3ce9d4c9",
)

# Companion configuration file for the token-type finding model.
# NOTE(review): the file's contents are not visible here — semantics assumed from its name.
token_type_finding_config_path = hf_hub_download(
    repo_id="HURIDOCS/pdf-segmentation",
    filename="tag_type_finding_model_config.txt",
    revision="7d98776dd34acb2fe3a06495c82e64b9c84bdc16",
)
14 |
--------------------------------------------------------------------------------
/src/adapters/ml/pdf_tokens_type_trainer/get_paths.py:
--------------------------------------------------------------------------------
1 | from os.path import join
2 | from pathlib import Path
3 |
4 |
def get_xml_path(pdf_labeled_data_project_path: str):
    """Return the 'pdfs' directory inside the given pdf-labeled-data project as a Path."""
    return Path(pdf_labeled_data_project_path) / "pdfs"
7 |
--------------------------------------------------------------------------------
/src/adapters/ml/pdf_tokens_type_trainer/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/ml/pdf_tokens_type_trainer/tests/__init__.py
--------------------------------------------------------------------------------
/src/adapters/ml/pdf_tokens_type_trainer/tests/test_trainer.py:
--------------------------------------------------------------------------------
1 | from os.path import join, exists
2 | from unittest import TestCase
3 |
4 | from pdf_token_type_labels import TokenType
5 | from adapters.ml.pdf_tokens_type_trainer.TokenTypeTrainer import TokenTypeTrainer
6 |
7 | from pdf_features import PdfFeatures
8 |
9 | from configuration import ROOT_PATH
10 |
11 |
class TestTrainer(TestCase):
    """Integration tests for TokenTypeTrainer.

    NOTE(review): these tests read PDFs from ROOT_PATH/test_pdfs and rely on the
    pretrained token-type model being downloadable — verify fixtures exist in CI.
    """

    def test_train_blank_pdf(self):
        # A blank PDF yields no feature rows, so train() must return early
        # without writing a model file.
        pdf_features = PdfFeatures.from_pdf_path(join(ROOT_PATH, "test_pdfs", "blank.pdf"))
        model_path = join(ROOT_PATH, "model", "blank.model")
        trainer = TokenTypeTrainer([pdf_features])
        trainer.train(model_path, [])
        self.assertFalse(exists(model_path))

    def test_predict_blank_pdf(self):
        # Predicting on a blank PDF should be a no-op: no tokens, no predictions.
        pdf_features = PdfFeatures.from_pdf_path(join(ROOT_PATH, "test_pdfs", "blank.pdf"))
        trainer = TokenTypeTrainer([pdf_features])
        trainer.set_token_types()
        self.assertEqual([], pdf_features.pages[0].tokens)

    def test_predict(self):
        # End-to-end check of the pretrained model's output on a known test PDF.
        pdf_features = PdfFeatures.from_pdf_path(join(ROOT_PATH, "test_pdfs", "test.pdf"))
        trainer = TokenTypeTrainer([pdf_features])
        trainer.set_token_types()
        tokens = pdf_features.pages[0].tokens
        self.assertEqual(TokenType.TITLE, tokens[0].token_type)
        self.assertEqual("Document Big Centered Title", tokens[0].content)
        self.assertEqual(TokenType.TEXT, tokens[1].token_type)
        self.assertEqual("List Title", tokens[10].content)
35 |
--------------------------------------------------------------------------------
/src/adapters/ml/vgt/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/ml/vgt/__init__.py
--------------------------------------------------------------------------------
/src/adapters/ml/vgt/bros/__init__.py:
--------------------------------------------------------------------------------
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 NAVER CLOVA Team. All rights reserved.
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from transformers.file_utils import (
    _LazyModule,
    is_tokenizers_available,
    is_torch_available,
)

# Maps submodule name -> public names it exports; consumed by _LazyModule below.
_import_structure = {
    "configuration_bros": ["BROS_PRETRAINED_CONFIG_ARCHIVE_MAP", "BrosConfig"],
    "tokenization_bros": ["BrosTokenizer"],
}

# Optional exports, registered only when the backing libraries are installed.
if is_tokenizers_available():
    _import_structure["tokenization_bros_fast"] = ["BrosTokenizerFast"]

if is_torch_available():
    _import_structure["modeling_bros"] = [
        "BROS_PRETRAINED_MODEL_ARCHIVE_LIST",
        "BrosForMaskedLM",
        "BrosForPreTraining",
        "BrosForSequenceClassification",
        "BrosForTokenClassification",
        "BrosModel",
        "BrosLMHeadModel",
        "BrosPreTrainedModel",
    ]

if TYPE_CHECKING:
    # Real imports for static type checkers only; at runtime the lazy proxy
    # below performs the actual imports on first attribute access.
    from .configuration_bros import BROS_PRETRAINED_CONFIG_ARCHIVE_MAP, BrosConfig
    from .tokenization_bros import BrosTokenizer

    if is_tokenizers_available():
        from .tokenization_bros_fast import BrosTokenizerFast

    if is_torch_available():
        from .modeling_bros import (
            BROS_PRETRAINED_MODEL_ARCHIVE_LIST,
            BrosForMaskedLM,
            BrosForPreTraining,
            BrosForSequenceClassification,
            BrosForTokenClassification,
            BrosLMHeadModel,
            BrosModel,
            BrosPreTrainedModel,
        )

else:
    import sys

    # Replace this module in sys.modules with a lazy proxy that imports each
    # submodule only when one of its names is first accessed.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
71 |
--------------------------------------------------------------------------------
/src/adapters/ml/vgt/bros/tokenization_bros.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022-present NAVER Corp, The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes for BROS."""
16 |
17 |
18 | import collections
19 |
20 | from transformers.models.bert.tokenization_bert import BertTokenizer
21 | from transformers.utils import logging
22 |
logger = logging.get_logger(__name__)

# File names expected inside a pretrained tokenizer directory.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

# Download locations of the vocabularies for the published BROS checkpoints.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "naver-clova-ocr/bros-base-uncased": "https://huggingface.co/naver-clova-ocr/bros-base-uncased/resolve/main/vocab.txt",
        "naver-clova-ocr/bros-large-uncased": "https://huggingface.co/naver-clova-ocr/bros-large-uncased/resolve/main/vocab.txt",
    }
}

# Maximum input length (in positions) supported by each checkpoint.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "naver-clova-ocr/bros-base-uncased": 512,
    "naver-clova-ocr/bros-large-uncased": 512,
}

# Default tokenizer init kwargs per checkpoint (both checkpoints are uncased).
PRETRAINED_INIT_CONFIGURATION = {
    "naver-clova-ocr/bros-base-uncased": {"do_lower_case": True},
    "naver-clova-ocr/bros-large-uncased": {"do_lower_case": True},
}
43 |
44 |
def load_vocab(vocab_file):
    """Read a vocabulary file (one token per line) into an ordered token -> index mapping."""
    with open(vocab_file, "r", encoding="utf-8") as reader:
        return collections.OrderedDict((line.rstrip("\n"), index) for index, line in enumerate(reader))
54 |
55 |
def whitespace_tokenize(text):
    """Trim the text, then split it on runs of whitespace; blank input yields []."""
    stripped = text.strip()
    return stripped.split() if stripped else []
63 |
64 |
class BrosTokenizer(BertTokenizer):
    r"""
    Construct a BROS tokenizer, reusing BERT's WordPiece tokenization.

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
    Users should refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (:obj:`str`):
            File containing the vocabulary.
        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to lowercase the input when tokenizing.
        do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to do basic tokenization before WordPiece.
        never_split (:obj:`Iterable`, `optional`):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            :obj:`do_basic_tokenize=True`
        unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to tokenize Chinese characters.

            This should likely be deactivated for Japanese (see
            https://github.com/huggingface/transformers/issues/328).
        strip_accents: (:obj:`bool`, `optional`):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for :obj:`lowercase` (as in the original BERT).
    """

    # Class-level configuration consumed by PreTrainedTokenizer.from_pretrained.
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
111 |
--------------------------------------------------------------------------------
/src/adapters/ml/vgt/bros/tokenization_bros_fast.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022-present NAVER Corp, The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team..
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Fast Tokenization classes for BROS."""
16 |
17 | from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast
18 | from transformers.utils import logging
19 |
20 | from .tokenization_bros import BrosTokenizer
21 |
logger = logging.get_logger(__name__)

# File names expected inside a pretrained tokenizer directory (fast tokenizers
# additionally ship a serialized tokenizer.json).
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}

# Download locations of vocabulary/tokenizer files for the published BROS checkpoints.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "naver-clova-ocr/bros-base-uncased": "https://huggingface.co/naver-clova-ocr/bros-base-uncased/resolve/main/vocab.txt",
        "naver-clova-ocr/bros-large-uncased": "https://huggingface.co/naver-clova-ocr/bros-large-uncased/resolve/main/vocab.txt",
    },
    "tokenizer_file": {
        "naver-clova-ocr/bros-base-uncased": "https://huggingface.co/naver-clova-ocr/bros-base-uncased/resolve/main/tokenizer.json",
        "naver-clova-ocr/bros-large-uncased": "https://huggingface.co/naver-clova-ocr/bros-large-uncased/resolve/main/tokenizer.json",
    },
}

# Maximum input length (in positions) supported by each checkpoint.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "naver-clova-ocr/bros-base-uncased": 512,
    "naver-clova-ocr/bros-large-uncased": 512,
}

# Default tokenizer init kwargs per checkpoint (both checkpoints are uncased).
PRETRAINED_INIT_CONFIGURATION = {
    "naver-clova-ocr/bros-base-uncased": {"do_lower_case": True},
    "naver-clova-ocr/bros-large-uncased": {"do_lower_case": True},
}
46 |
47 |
class BrosTokenizerFast(BertTokenizerFast):
    r"""
    Construct a "fast" BROS tokenizer (backed by HuggingFace's `tokenizers` library), based on WordPiece.

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main
    methods. Users should refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (:obj:`str`):
            File containing the vocabulary.
        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to lowercase the input when tokenizing.
        unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        clean_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to clean the text before tokenization by removing any control characters and replacing all
            whitespaces by the classic one.
        tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see
            https://github.com/huggingface/transformers/issues/328).
        strip_accents: (:obj:`bool`, `optional`):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for :obj:`lowercase` (as in the original BERT).
        wordpieces_prefix: (:obj:`str`, `optional`, defaults to :obj:`"##"`):
            The prefix for subwords.
    """

    # Class-level configuration consumed by PreTrainedTokenizerFast.from_pretrained.
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    slow_tokenizer_class = BrosTokenizer
93 |
--------------------------------------------------------------------------------
/src/adapters/ml/vgt/create_word_grid.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import shutil
3 |
4 | import numpy as np
5 | from os import makedirs
6 | from os.path import join, exists
7 | from pdf_features import PdfToken
8 | from pdf_features import Rectangle
9 | from pdf_features import PdfFeatures
10 |
11 | from adapters.ml.vgt.bros.tokenization_bros import BrosTokenizer
12 | from configuration import WORD_GRIDS_PATH
13 |
# Module-level singleton, loaded once at import time; the pretrained files are
# fetched from the Hugging Face hub (and cached locally) on first use.
tokenizer = BrosTokenizer.from_pretrained("naver-clova-ocr/bros-base-uncased")
15 |
16 |
def rectangle_to_bbox(rectangle: Rectangle):
    """Convert a Rectangle into [left, top, width, height] list form."""
    return [getattr(rectangle, attribute) for attribute in ("left", "top", "width", "height")]
19 |
20 |
def get_words_positions(text: str, rectangle: Rectangle):
    """Split text into whitespace-separated words and estimate a bounding box for each.

    Every character is assumed to occupy an equal share of the rectangle's width.
    Returns a (words, words_bboxes) pair; blank or whitespace-only text returns ([], []).
    """
    text = text.strip()
    text_len = len(text)

    if not text_len:
        # Guard: blank content (e.g. padding tokens with content "") would
        # otherwise cause a ZeroDivisionError below.
        return [], []

    width_per_letter = rectangle.width / text_len

    # Start with a zero-width box at the left edge; it grows as letters are consumed.
    words_bboxes = [Rectangle.from_coordinates(rectangle.left, rectangle.top, rectangle.left + 5, rectangle.bottom)]
    words_bboxes[-1].width = 0
    words_bboxes[-1].right = words_bboxes[-1].left

    for letter in text:
        if letter == " ":
            # A space closes the current word: open a fresh zero-width box after the gap.
            # NOTE(review): consecutive inner spaces add one box per space while
            # text.split() collapses them, so words/bboxes lengths can differ — confirm.
            left = words_bboxes[-1].right + width_per_letter
            words_bboxes.append(Rectangle.from_coordinates(left, words_bboxes[-1].top, left + 5, words_bboxes[-1].bottom))
            words_bboxes[-1].width = 0
            words_bboxes[-1].right = words_bboxes[-1].left
        else:
            words_bboxes[-1].right = words_bboxes[-1].right + width_per_letter
            words_bboxes[-1].width = words_bboxes[-1].width + width_per_letter

    words = text.split()
    return words, words_bboxes
43 |
44 |
def get_subwords_positions(word: str, rectangle: Rectangle):
    # Split a word into wordpiece sub-tokens and spread the word's bounding box
    # across them, proportionally to each sub-token's character count.
    width_per_letter = rectangle.width / len(word)
    # Strip the wordpiece continuation prefix ("##"). NOTE(review): replace("#", "")
    # also removes literal '#' characters inside the word — confirm that is intended.
    word_tokens = [x.replace("#", "") for x in tokenizer.tokenize(word)]

    if not word_tokens:
        return [], []

    # Encode each sub-token separately; x[-2] presumably selects the sub-token's
    # own id, i.e. the one just before the trailing special token — TODO confirm
    # against BrosTokenizer's output format.
    ids = [x[-2] for x in tokenizer(word_tokens)["input_ids"]]

    # The first sub-token starts at the word's left edge.
    right = rectangle.left + len(word_tokens[0]) * width_per_letter
    bboxes = [Rectangle.from_coordinates(rectangle.left, rectangle.top, right, rectangle.bottom)]

    # Each subsequent sub-token starts where the previous one ended.
    for subword in word_tokens[1:]:
        right = bboxes[-1].right + len(subword) * width_per_letter
        bboxes.append(Rectangle.from_coordinates(bboxes[-1].right, rectangle.top, right, rectangle.bottom))

    return ids, bboxes
62 |
63 |
def get_grid_words_dict(tokens: list[PdfToken]):
    """Build the word-grid dictionary for one page: word texts/boxes plus subword ids/boxes."""
    texts = []
    bbox_texts_list = []
    inputs_ids = []
    bbox_subword_list = []

    for token in tokens:
        words, word_boxes = get_words_positions(token.content, token.bounding_box)
        texts.extend(words)
        bbox_texts_list.extend(rectangle_to_bbox(word_box) for word_box in word_boxes)
        for word, word_box in zip(words, word_boxes):
            subword_ids, subword_boxes = get_subwords_positions(word, word_box)
            inputs_ids.extend(subword_ids)
            bbox_subword_list.extend(rectangle_to_bbox(subword_box) for subword_box in subword_boxes)

    return {
        "input_ids": np.array(inputs_ids),
        "bbox_subword_list": np.array(bbox_subword_list),
        "texts": texts,
        "bbox_texts_list": np.array(bbox_texts_list),
    }
81 |
82 |
def create_word_grid(pdf_features_list: list[PdfFeatures]):
    """Pickle one word-grid dictionary per page into WORD_GRIDS_PATH, skipping existing files."""
    makedirs(WORD_GRIDS_PATH, exist_ok=True)

    for pdf_features in pdf_features_list:
        for page in pdf_features.pages:
            # Image ids are zero-based, page numbers one-based.
            image_id = f"{pdf_features.file_name}_{page.page_number - 1}"
            grid_path = join(WORD_GRIDS_PATH, f"{image_id}.pkl")
            if exists(grid_path):
                continue
            with open(grid_path, mode="wb") as file:
                pickle.dump(get_grid_words_dict(page.tokens), file)
94 |
95 |
def remove_word_grids():
    # Best-effort cleanup of the cached word grids; ignore_errors covers the
    # case where the directory was never created.
    shutil.rmtree(WORD_GRIDS_PATH, ignore_errors=True)
98 |
--------------------------------------------------------------------------------
/src/adapters/ml/vgt/ditod/FeatureMerge.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
class FeatureMerge(nn.Module):
    """Multimodal feature fusion used in VSR.

    Fuses per-level visual feature maps with equally shaped textual feature
    maps using one of three strategies:

    - ``"Sum"``: element-wise addition.
    - ``"Concat"``: channel-wise concatenation followed by a linear
      projection back to the visual dimension.
    - ``"Weighted"``: a sigmoid gate computed from both modalities selects
      between them (or, with ``shortcut``, adds the gated textual part).
    """

    def __init__(
        self,
        feature_names,
        visual_dim,
        semantic_dim,
        merge_type="Sum",
        dropout_ratio=0.1,
        with_extra_fc=True,
        shortcut=False,
    ):
        """Multimodal feature merge used in VSR.
        Args:
            feature_names (list[str]): names of the feature levels to fuse, e.g. ['p2']
            visual_dim (list): the dim of visual features, e.g. [256]
            semantic_dim (list): the dim of semantic features, e.g. [256]
            merge_type (str): fusion type, e.g. 'Sum', 'Concat', 'Weighted'
            dropout_ratio (float): dropout ratio of fusion features
            with_extra_fc (bool): whether add extra fc layers for adaptation
            shortcut (bool): whether add shortcut connection
        Raises:
            ValueError: if ``merge_type`` is not 'Sum', 'Concat' or 'Weighted'.
        """
        super().__init__()

        # merge param
        self.feature_names = feature_names
        self.merge_type = merge_type
        self.visual_dim = visual_dim
        self.textual_dim = semantic_dim
        self.with_extra_fc = with_extra_fc
        self.shortcut = shortcut
        self.relu = nn.ReLU(inplace=True)

        if self.merge_type == "Sum":
            assert len(self.visual_dim) == len(self.textual_dim)
        elif self.merge_type == "Concat":
            assert len(self.visual_dim) == len(self.textual_dim)

            self.vis_proj = nn.ModuleList()
            self.text_proj = nn.ModuleList()
            self.alpha_proj = nn.ModuleList()

            for idx in range(len(self.visual_dim)):
                if self.with_extra_fc:
                    self.vis_proj.append(nn.Linear(self.visual_dim[idx], self.visual_dim[idx]))
                    self.text_proj.append(nn.Linear(self.textual_dim[idx], self.textual_dim[idx]))
                # Bug fix: forward() indexes alpha_proj unconditionally, so the
                # projection must exist even when with_extra_fc is False
                # (previously it was only created inside the if-branch above).
                self.alpha_proj.append(nn.Linear(self.visual_dim[idx] + self.textual_dim[idx], self.visual_dim[idx]))

        elif self.merge_type == "Weighted":
            assert len(self.visual_dim) == len(self.textual_dim)
            self.total_num = len(self.visual_dim)

            # vis projection
            self.vis_proj = nn.ModuleList()
            self.vis_proj_relu = nn.ModuleList()

            # text projection
            self.text_proj = nn.ModuleList()
            self.text_proj_relu = nn.ModuleList()

            self.alpha_proj = nn.ModuleList()
            for idx in range(self.total_num):
                if self.with_extra_fc:
                    self.vis_proj.append(nn.Linear(self.visual_dim[idx], self.visual_dim[idx]))
                    self.text_proj.append(nn.Linear(self.textual_dim[idx], self.textual_dim[idx]))
                # Bug fix: same as the Concat branch — the gate projection is
                # always used by forward(), so always create it.
                self.alpha_proj.append(nn.Linear(self.visual_dim[idx] + self.textual_dim[idx], self.visual_dim[idx]))

        else:
            # Bug fix: ``raise "<str>"`` is itself a TypeError in Python 3;
            # raise a real exception carrying the offending value instead.
            raise ValueError(f"Unknown merge type {self.merge_type}")

        self.dropout = nn.Dropout(dropout_ratio)

    def forward(self, visual_feat=None, textual_feat=None):
        """Forward computation
        Args:
            visual_feat (dict[str, Tensor]): visual feature maps keyed by
                feature name, each in shape of B x C x H x W
            textual_feat (dict[str, Tensor]): textual feature maps with the
                same keys and shapes as ``visual_feat``
        Returns:
            dict[str, Tensor]: fused feature maps, same keys and shapes
        """
        assert len(visual_feat) == len(textual_feat)

        # feature merge
        merged_feat = {}
        if self.merge_type == "Sum":
            for name in self.feature_names:
                merged_feat[name] = visual_feat[name] + textual_feat[name]
        elif self.merge_type == "Concat":
            for idx, name in enumerate(self.feature_names):
                # Work channels-last so the Linear layers apply per spatial position.
                per_vis = visual_feat[name].permute(0, 2, 3, 1)
                per_text = textual_feat[name].permute(0, 2, 3, 1)
                if self.with_extra_fc:
                    per_vis = self.relu(self.vis_proj[idx](per_vis))
                    per_text = self.relu(self.text_proj[idx](per_text))
                x_sentence = self.alpha_proj[idx](torch.cat((per_vis, per_text), -1))
                x_sentence = x_sentence.permute(0, 3, 1, 2).contiguous()
                merged_feat[name] = x_sentence
        else:
            # "Weighted" merge (the constructor rejects any other type).
            assert self.total_num == len(visual_feat) or self.total_num == 1
            for idx, name in enumerate(self.feature_names):
                per_vis = visual_feat[name].permute(0, 2, 3, 1)
                per_text = textual_feat[name].permute(0, 2, 3, 1)
                if self.with_extra_fc:
                    per_vis = self.relu(self.vis_proj[idx](per_vis))
                    per_text = self.relu(self.text_proj[idx](per_text))

                alpha = torch.sigmoid(self.alpha_proj[idx](torch.cat((per_vis, per_text), -1)))
                if self.shortcut:
                    # shortcut: keep the visual stream and add gated text
                    x_sentence = per_vis + alpha * per_text
                else:
                    # selection: convex combination of the two modalities
                    x_sentence = alpha * per_vis + (1 - alpha) * per_text

                x_sentence = x_sentence.permute(0, 3, 1, 2).contiguous()
                merged_feat[name] = x_sentence

        return merged_feat
130 |
--------------------------------------------------------------------------------
/src/adapters/ml/vgt/ditod/Wordnn_embedding.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from .tokenization_bros import BrosTokenizer
5 |
6 |
7 | def _init_weights(m):
8 | if isinstance(m, nn.Linear):
9 | # we use xavier_uniform following official JAX ViT:
10 | torch.nn.init.xavier_uniform_(m.weight)
11 | if isinstance(m, nn.Linear) and m.bias is not None:
12 | nn.init.constant_(m.bias, 0)
13 | elif isinstance(m, nn.LayerNorm):
14 | nn.init.constant_(m.bias, 0)
15 | nn.init.constant_(m.weight, 1.0)
16 |
17 |
class WordnnEmbedding(nn.Module):
    """Generate chargrid embedding feature map.

    Paints each word's token id over its bounding-box area of a page-sized
    grid, then embeds and projects the grid to ``embedding_dim`` channels.
    """

    def __init__(
        self,
        vocab_size=30552,
        hidden_size=768,
        embedding_dim=64,
        bros_embedding_path="/bros-base-uncased/",
        use_pretrain_weight=True,
        use_UNK_text=False,
    ):
        """
        Args:
            vocab_size (int): size of vocabulary.
            hidden_size (int): width of the word-embedding table (must match
                the pretrained checkpoint when ``use_pretrain_weight``).
            embedding_dim (int): dim of the projected output features.
            bros_embedding_path (str): directory containing ``pytorch_model.bin``
                of a bert / bros / layoutlm checkpoint.
            use_pretrain_weight (bool): replace the table with checkpoint weights.
            use_UNK_text (bool): paint a fixed id (100) instead of real token ids.
        """
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.embedding_proj = nn.Linear(hidden_size, embedding_dim, bias=False)
        # self.tokenizer = BrosTokenizer.from_pretrained(bros_embedding_path)
        self.use_pretrain_weight = use_pretrain_weight
        self.use_UNK_text = use_UNK_text

        self.init_weights(bros_embedding_path)
        self.apply(_init_weights)

    def init_weights(self, bros_embedding_path):
        """Load pretrained word embeddings into ``self.embedding``, if enabled.

        Raises:
            ValueError: if the path does not name a known model family
                (bert / bros / layoutlm).
        """
        if not self.use_pretrain_weight:
            return
        # NOTE(review): torch.load unpickles arbitrary code — only load
        # trusted checkpoint files.
        state_dict = torch.load(bros_embedding_path + "pytorch_model.bin", map_location="cpu")
        if "bert" in bros_embedding_path:
            word_embs = state_dict["bert.embeddings.word_embeddings.weight"]
        elif "bros" in bros_embedding_path:
            word_embs = state_dict["embeddings.word_embeddings.weight"]
        elif "layoutlm" in bros_embedding_path:
            word_embs = state_dict["layoutlm.embeddings.word_embeddings.weight"]
        else:
            # Bug fix: this used to print a warning and then crash with a
            # NameError on the undefined ``word_embs``; fail loudly instead.
            raise ValueError(f"Wrong bros_embedding_path: {bros_embedding_path}")
        self.embedding = nn.Embedding.from_pretrained(word_embs)
        print("use_pretrain_weight: load model from:", bros_embedding_path)

    def forward(self, img, batched_inputs, stride=1):
        """Forward computation
        Args:
            img (Tensor): in shape of [B x 3 x H x W]
            batched_inputs (list[dict]): one dict per image with "input_ids"
                and "bbox" entries; bboxes are assumed to be numpy arrays of
                (x0, y0, x1, y1) pixel coordinates — TODO confirm against the
                dataset mapper.
            stride (int): downscale factor between the image and the output grid.
        Returns:
            Tensor: in shape of [B x D x H/stride x W/stride], where D is the
                embedding_dim (channels-first; the former docstring's
                [B x N x L x D] did not match the returned permute).
        """
        device = img.device
        batch_b, _, batch_h, batch_w = img.size()

        # One token id per grid cell; 0 acts as the background id.
        chargrid_map = torch.zeros((batch_b, batch_h // stride, batch_w // stride), dtype=torch.int64).to(device)

        for iter_b in range(batch_b):
            per_input_ids = batched_inputs[iter_b]["input_ids"]
            per_input_bbox = batched_inputs[iter_b]["bbox"]

            # Guard against mismatched id/bbox lists: use the shorter length.
            short_length_w = min(len(per_input_ids), len(per_input_bbox))

            if short_length_w > 0:
                for word_idx in range(short_length_w):
                    per_id = per_input_ids[word_idx]

                    bbox = per_input_bbox[word_idx] / stride
                    w_start, h_start, w_end, h_end = bbox.round().astype(int).tolist()

                    if self.use_UNK_text:
                        chargrid_map[iter_b, h_start:h_end, w_start:w_end] = 100
                    else:
                        chargrid_map[iter_b, h_start:h_end, w_start:w_end] = per_id

        chargrid_map = self.embedding(chargrid_map)
        chargrid_map = self.embedding_proj(chargrid_map)

        return chargrid_map.permute(0, 3, 1, 2).contiguous()
95 |
--------------------------------------------------------------------------------
/src/adapters/ml/vgt/ditod/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------------------------------
2 | # MPViT: Multi-Path Vision Transformer for Dense Prediction
3 | # Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
4 | # All Rights Reserved.
5 | # Written by Youngwan Lee
6 | # This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the
7 | # LICENSE file in the root directory of this source tree.
8 | # --------------------------------------------------------------------------------
9 |
10 | from .config import add_vit_config
11 | from .VGTbackbone import build_VGT_fpn_backbone
12 | from .dataset_mapper import DetrDatasetMapper
13 | from .VGTTrainer import VGTTrainer
14 | from .VGT import VGT
15 |
16 | from .utils import eval_and_show, load_gt_from_json, pub_load_gt_from_json
17 |
--------------------------------------------------------------------------------
/src/adapters/ml/vgt/ditod/config.py:
--------------------------------------------------------------------------------
1 | from detectron2.config import CfgNode as CN
2 |
3 |
def add_vit_config(cfg):
    """
    Add config for VIT.

    Extends ``cfg`` in place with the ViT backbone options, AdamW solver
    extras, the DETR-style augmentation flag and the word-grid embedding
    options used by VGT.
    """
    vit = CN()
    vit.NAME = ""  # CoaT model name.
    vit.OUT_FEATURES = ["layer3", "layer5", "layer7", "layer11"]  # Output features from CoaT backbone.
    vit.IMG_SIZE = [224, 224]
    vit.POS_TYPE = "shared_rel"
    vit.MERGE_TYPE = "Sum"
    vit.DROP_PATH = 0.0
    vit.MODEL_KWARGS = "{}"
    cfg.MODEL.VIT = vit

    cfg.SOLVER.OPTIMIZER = "ADAMW"
    cfg.SOLVER.BACKBONE_MULTIPLIER = 1.0

    cfg.AUG = CN()
    cfg.AUG.DETR = False

    wordgrid = CN()
    wordgrid.VOCAB_SIZE = 30552
    wordgrid.EMBEDDING_DIM = 64
    wordgrid.MODEL_PATH = ""
    wordgrid.HIDDEN_SIZE = 768
    wordgrid.USE_PRETRAIN_WEIGHT = True
    wordgrid.USE_UNK_TEXT = False
    cfg.MODEL.WORDGRID = wordgrid
49 |
--------------------------------------------------------------------------------
/src/adapters/ml/vgt/ditod/tokenization_bros.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 |
18 | import collections
19 |
20 | from transformers.models.bert.tokenization_bert import BertTokenizer
21 | from transformers.utils import logging
22 |
logger = logging.get_logger(__name__)

# File names expected inside a pretrained tokenizer directory.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

# Download URLs for the published BROS vocabularies on the Hugging Face hub.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "naver-clova-ocr/bros-base-uncased": "https://huggingface.co/naver-clova-ocr/bros-base-uncased/resolve/main/vocab.txt",
        "naver-clova-ocr/bros-large-uncased": "https://huggingface.co/naver-clova-ocr/bros-large-uncased/resolve/main/vocab.txt",
    }
}

# Maximum input sequence length supported by each pretrained BROS checkpoint.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "naver-clova-ocr/bros-base-uncased": 512,
    "naver-clova-ocr/bros-large-uncased": 512,
}

# Default tokenizer constructor kwargs per pretrained checkpoint.
PRETRAINED_INIT_CONFIGURATION = {
    "naver-clova-ocr/bros-base-uncased": {"do_lower_case": True},
    "naver-clova-ocr/bros-large-uncased": {"do_lower_case": True},
}
43 |
44 |
def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary mapping token -> line index."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        for index, line in enumerate(reader.readlines()):
            vocab[line.rstrip("\n")] = index
    return vocab
54 |
55 |
def convert_to_unicode(text):
    """Converts `text` to Unicode (if it's not already), assuming utf-8 input.

    Args:
        text: a ``str`` (returned unchanged) or utf-8 ``bytes`` (decoded,
            with undecodable sequences ignored).
    Returns:
        str: the unicode text.
    Raises:
        ValueError: if ``text`` is neither ``str`` nor ``bytes``.
    """
    # Bug fix: the original dispatched on six.PY2 / six.PY3 but never imported
    # ``six`` (and referenced the Py2-only ``unicode``), so every call raised
    # NameError. This is the Python 3 branch, which is all we run on.
    if isinstance(text, str):
        return text
    if isinstance(text, bytes):
        return text.decode("utf-8", "ignore")
    raise ValueError("Unsupported string type: %s" % (type(text)))
74 |
75 |
def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    stripped = text.strip()
    return stripped.split() if stripped else []
83 |
84 |
class BrosTokenizer(BertTokenizer):
    r"""
    Construct a BERT tokenizer. Based on WordPiece.

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
    Users should refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (:obj:`str`):
            File containing the vocabulary.
        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to lowercase the input when tokenizing.
        do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to do basic tokenization before WordPiece.
        never_split (:obj:`Iterable`, `optional`):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            :obj:`do_basic_tokenize=True`
        unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to tokenize Chinese characters.

            This should likely be deactivated for Japanese (see this `issue
            <https://github.com/huggingface/transformers/issues/328>`__).
        strip_accents: (:obj:`bool`, `optional`):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for :obj:`lowercase` (as in the original BERT).
    """

    # Expected vocabulary file names inside a pretrained tokenizer directory.
    vocab_files_names = VOCAB_FILES_NAMES
    # Hub download URLs for the pretrained BROS vocabularies.
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    # Default constructor kwargs per pretrained checkpoint.
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    # Maximum input length per pretrained checkpoint.
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
--------------------------------------------------------------------------------
/src/adapters/ml/vgt/get_json_annotations.py:
--------------------------------------------------------------------------------
1 | import json
2 | from os import makedirs
3 | from pdf_features import PdfToken
4 | from domain.PdfImages import PdfImages
5 | from configuration import DOCLAYNET_TYPE_BY_ID
6 | from configuration import JSONS_ROOT_PATH, JSON_TEST_FILE_PATH
7 |
8 |
def save_annotations_json(annotations: list, width_height: list, images: list):
    """Write a COCO-format annotations file to JSON_TEST_FILE_PATH.

    Args:
        annotations: COCO annotation dicts (see ``get_annotation``).
        width_height: (width, height) per image, parallel to ``images``.
        images: image ids, in the same order as ``width_height``.
    """
    # ``width_height`` is parallel to ``images``, so index with the enumerate
    # counter directly. The former ``width_height[images.index(image_id)]``
    # was an O(n^2) scan and picked the wrong entry whenever an image id
    # appeared more than once.
    images_dict = [
        {
            "id": i,
            "file_name": image_id + ".jpg",
            "width": width_height[i][0],
            "height": width_height[i][1],
        }
        for i, image_id in enumerate(images)
    ]

    categories_dict = [{"id": key, "name": value} for key, value in DOCLAYNET_TYPE_BY_ID.items()]

    info_dict = {
        "description": "PDF Document Layout Analysis Dataset",
        "url": "",
        "version": "1.0",
        "year": 2025,
        "contributor": "",
        "date_created": "2025-01-01",
    }

    coco_dict = {"info": info_dict, "images": images_dict, "categories": categories_dict, "annotations": annotations}

    JSON_TEST_FILE_PATH.write_text(json.dumps(coco_dict))
34 |
35 |
def get_annotation(index: int, image_id: str, token: PdfToken):
    """Build one COCO annotation dict for a single PDF token."""
    box = token.bounding_box
    return {
        "area": 1,
        "iscrowd": 0,
        "score": 1,
        "image_id": image_id,
        "bbox": [box.left, box.top, box.width, box.height],
        "category_id": token.token_type.get_index(),
        "id": index,
    }
46 |
47 |
def get_annotations_for_document(annotations, images, index, pdf_images, width_height):
    """Append one image entry per page and one annotation per token.

    ``index`` numbers annotations sequentially; the caller advances it
    between documents.
    """
    for page_index, page in enumerate(pdf_images.pdf_features.pages):
        image_id = f"{pdf_images.pdf_features.file_name}_{page.page_number - 1}"
        images.append(image_id)
        page_image = pdf_images.pdf_images[page_index]
        width_height.append((page_image.width, page_image.height))
        for token in page.tokens:
            annotations.append(get_annotation(index, image_id, token))
            index += 1
57 |
58 |
def get_annotations(pdf_images_list: list[PdfImages]):
    """Build and save the COCO test json covering every page of every PDF."""
    makedirs(JSONS_ROOT_PATH, exist_ok=True)

    annotations, images, width_height = [], [], []
    index = 0
    for pdf_images in pdf_images_list:
        get_annotations_for_document(annotations, images, index, pdf_images, width_height)
        index += sum(len(page.tokens) for page in pdf_images.pdf_features.pages)

    save_annotations_json(annotations, width_height, images)
72 |
--------------------------------------------------------------------------------
/src/adapters/ml/vgt/get_model_configuration.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from os.path import join
3 | from detectron2.config import get_cfg
4 | from detectron2.engine import default_setup, default_argument_parser
5 | from configuration import service_logger, SRC_PATH, ROOT_PATH
6 | from adapters.ml.vgt.ditod import add_vit_config
7 |
8 |
def is_gpu_available():
    """Return True when the system's CUDA devices together have more than
    3000 MB of free memory; logs a per-device memory breakdown."""
    total_free_memory_in_system: float = 0.0
    if not torch.cuda.is_available():
        service_logger.info("No CUDA-compatible GPU detected. Switching to CPU.")
        return total_free_memory_in_system > 3000
    for i in range(torch.cuda.device_count()):
        total_memory = torch.cuda.get_device_properties(i).total_memory / 1024**2
        allocated_memory = torch.cuda.memory_allocated(i) / 1024**2
        cached_memory = torch.cuda.memory_reserved(i) / 1024**2
        service_logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
        service_logger.info(f"  Total Memory: {total_memory} MB")
        service_logger.info(f"  Allocated Memory: {allocated_memory} MB")
        service_logger.info(f"  Cached Memory: {cached_memory} MB")
        total_free_memory_in_system += total_memory - allocated_memory - cached_memory
    if total_free_memory_in_system < 3000:
        service_logger.info(f"Total free GPU memory is {total_free_memory_in_system} < 3000 MB. Switching to CPU.")
        service_logger.info("The process is probably going to be 15 times slower.")
    return total_free_memory_in_system > 3000
27 |
28 |
def get_model_configuration():
    """Build the frozen detectron2 configuration for the doclaynet VGT model.

    Parses known CLI args (ignoring the rest), points the config at the
    bundled YAML and weights, picks cuda/cpu from available GPU memory and
    returns the frozen configuration.
    """
    args, _ = default_argument_parser().parse_known_args()
    args.config_file = join(SRC_PATH, "adapters", "ml", "vgt", "model_configuration", "doclaynet_VGT_cascade_PTM.yaml")
    args.eval_only = True
    args.num_gpus = 1
    args.opts = [
        "MODEL.WEIGHTS",
        join(ROOT_PATH, "models", "doclaynet_VGT_model.pth"),
        "OUTPUT_DIR",
        join(ROOT_PATH, "model_output_doclaynet"),
    ]
    args.debug = False

    configuration = get_cfg()
    add_vit_config(configuration)
    configuration.merge_from_file(args.config_file)
    configuration.merge_from_list(args.opts)
    configuration.MODEL.DEVICE = "cuda" if is_gpu_available() else "cpu"
    configuration.freeze()
    default_setup(configuration, args)

    return configuration
52 |
--------------------------------------------------------------------------------
/src/adapters/ml/vgt/get_most_probable_pdf_segments.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pickle
3 | from os.path import join
4 | from pathlib import Path
5 | from statistics import mode
6 |
7 | from domain.PdfSegment import PdfSegment
8 | from pdf_features import PdfFeatures
9 | from pdf_features import PdfToken
10 | from pdf_features import Rectangle
11 | from pdf_token_type_labels import TokenType
12 | from domain.PdfImages import PdfImages
13 | from configuration import ROOT_PATH, DOCLAYNET_TYPE_BY_ID
14 | from domain.Prediction import Prediction
15 |
16 |
def get_prediction_from_annotation(annotation, images_names, vgt_predictions_dict):
    """Convert one COCO result row into a Prediction and append it to the
    per-page dictionary (keyed by the image file stem)."""
    pdf_name = images_names[annotation["image_id"]][:-4]  # drop the ".jpg" suffix
    left, top, width, height = (int(value) for value in annotation["bbox"])
    prediction = Prediction(
        bounding_box=Rectangle.from_width_height(left=left, top=top, width=width, height=height),
        category_id=annotation["category_id"],
        score=round(float(annotation["score"]) * 100, 2),
    )
    vgt_predictions_dict.setdefault(pdf_name, list()).append(prediction)
31 |
32 |
def get_vgt_predictions(model_name: str) -> dict[str, list[Prediction]]:
    """Read the model's COCO inference output plus the test ground-truth json
    and return the predictions grouped per page image name."""
    results_path = join(str(ROOT_PATH), f"model_output_{model_name}", "inference", "coco_instances_results.json")
    annotations = json.loads(Path(results_path).read_text())

    coco_truth = json.loads(Path(join(str(ROOT_PATH), "jsons", "test.json")).read_text())
    images_names = {image["id"]: image["file_name"] for image in coco_truth["images"]}

    vgt_predictions_dict: dict = {}
    for annotation in annotations:
        get_prediction_from_annotation(annotation, images_names, vgt_predictions_dict)

    return vgt_predictions_dict
48 |
49 |
def find_best_prediction_for_token(page_pdf_name, token, vgt_predictions_dict, most_probable_tokens_by_predictions):
    """Assign ``token`` to the highest-scoring prediction that overlaps it.

    Tokens with no overlapping prediction get a fresh dummy prediction of
    score 0 built from their own bounding box (category_id 10 — presumably
    plain text; confirm against DOCLAYNET_TYPE_BY_ID).
    """
    best_prediction: Prediction | None = None
    best_score: float = 0
    for candidate in vgt_predictions_dict[page_pdf_name]:
        if candidate.score > best_score and candidate.bounding_box.get_intersection_percentage(token.bounding_box):
            best_prediction, best_score = candidate, candidate.score
            # Good enough — stop scanning once a near-certain match is found.
            if best_score >= 99:
                break
    if best_prediction is None:
        best_prediction = Prediction(bounding_box=token.bounding_box, category_id=10, score=0.0)
    most_probable_tokens_by_predictions.setdefault(best_prediction, list()).append(token)
64 |
65 |
def get_merged_prediction_type(to_merge: list[Prediction]):
    """Pick the category for a merged prediction: Table (9) wins outright;
    otherwise the most common category, ties broken toward higher scores."""
    if any(prediction.category_id == 9 for prediction in to_merge):
        return 9
    ranked = sorted(to_merge, key=lambda prediction: -prediction.score)
    return mode(prediction.category_id for prediction in ranked)
71 |
72 |
def merge_colliding_predictions(predictions: list[Prediction]):
    """Repeatedly merge overlapping predictions until none collide.

    Low-confidence predictions (score < 20) are dropped first. Each pass pops
    predictions one at a time, absorbs every remaining prediction whose
    bounding box intersects the popped one, and replaces the group with one
    prediction covering the merged rectangle; passes repeat until a full pass
    performs no merge.

    NOTE(review): ``p1`` is mutated in place, so the caller's Prediction
    objects are modified by this function.
    """
    # Keep predictions with score >= 20 (written as a double negative, which
    # also keeps entries whose score does not compare, e.g. NaN).
    predictions = [p for p in predictions if not p.score < 20]
    while True:
        new_predictions, merged = [], False
        while predictions:
            p1 = predictions.pop(0)
            # Everything still in the list that touches p1 gets absorbed into it.
            to_merge = [p for p in predictions if p1.bounding_box.get_intersection_percentage(p.bounding_box) > 0]
            for prediction in to_merge:
                predictions.remove(prediction)
            if to_merge:
                to_merge.append(p1)
                p1.bounding_box = Rectangle.merge_rectangles([prediction.bounding_box for prediction in to_merge])
                p1.category_id = get_merged_prediction_type(to_merge)
                merged = True
            new_predictions.append(p1)
        if not merged:
            return new_predictions
        # A merge may have created new collisions; run another pass.
        predictions = new_predictions
91 |
92 |
def get_pdf_segments_for_page(page, pdf_name, page_pdf_name, vgt_predictions_dict):
    """Turn the page's predictions into PdfSegments.

    Predictions are first de-overlapped, then every token on the page is
    assigned to its best prediction. One segment is produced per prediction
    with tokens, plus one empty-text segment per prediction left without any
    token.
    """
    vgt_predictions_dict[page_pdf_name] = merge_colliding_predictions(vgt_predictions_dict[page_pdf_name])

    tokens_by_prediction: dict[Prediction, list[PdfToken]] = {}
    for token in page.tokens:
        find_best_prediction_for_token(page_pdf_name, token, vgt_predictions_dict, tokens_by_prediction)

    segments: list[PdfSegment] = []
    for prediction, tokens in tokens_by_prediction.items():
        segment = PdfSegment.from_pdf_tokens(tokens, pdf_name)
        segment.bounding_box = prediction.bounding_box
        segment.segment_type = TokenType.from_text(DOCLAYNET_TYPE_BY_ID[prediction.category_id])
        segments.append(segment)

    for prediction in vgt_predictions_dict[page_pdf_name]:
        if prediction in tokens_by_prediction:
            continue
        segment_type = TokenType.from_text(DOCLAYNET_TYPE_BY_ID[prediction.category_id])
        segments.append(PdfSegment(page.page_number, prediction.bounding_box, "", segment_type, pdf_name))

    return segments
120 |
121 |
def prediction_exists_for_page(page_pdf_name, vgt_predictions_dict):
    """True when the page has an entry in the predictions dictionary."""
    return page_pdf_name in vgt_predictions_dict.keys()
124 |
125 |
def get_most_probable_pdf_segments(model_name: str, pdf_images_list: list[PdfImages], save_output: bool = False):
    """Convert VGT predictions for every page of every PDF into PdfSegments.

    Optionally pickles the resulting segment list under the model output
    directory.
    """
    segments: list[PdfSegment] = []
    vgt_predictions_dict = get_vgt_predictions(model_name)
    for pdf_images in pdf_images_list:
        pdf_features = pdf_images.pdf_features
        for page in pdf_features.pages:
            page_pdf_name = f"{pdf_features.file_name}_{page.page_number - 1}"
            if not prediction_exists_for_page(page_pdf_name, vgt_predictions_dict):
                continue
            segments.extend(get_pdf_segments_for_page(page, pdf_features.file_name, page_pdf_name, vgt_predictions_dict))
    if save_output:
        save_path = join(ROOT_PATH, f"model_output_{model_name}", "predicted_segments.pickle")
        with open(save_path, mode="wb") as file:
            pickle.dump(segments, file)
    return segments
142 |
--------------------------------------------------------------------------------
/src/adapters/ml/vgt/get_reading_orders.py:
--------------------------------------------------------------------------------
1 | from domain.PdfSegment import PdfSegment
2 | from pdf_features import PdfPage
3 | from pdf_features import PdfToken
4 | from pdf_token_type_labels import TokenType
5 |
6 | from domain.PdfImages import PdfImages
7 |
8 |
def find_segment_for_token(token: PdfToken, segments: list[PdfSegment], tokens_by_segments):
    """Attach ``token`` to the segment it overlaps the most, if any.

    Stops scanning early once a segment covers at least 99% of the token;
    tokens overlapping no segment at all are left unassigned.
    """
    best_segment: PdfSegment | None = None
    best_score: float = 0
    for segment in segments:
        overlap = token.bounding_box.get_intersection_percentage(segment.bounding_box)
        if overlap > best_score:
            best_segment, best_score = segment, overlap
            if best_score >= 99:
                break
    if best_segment is not None:
        tokens_by_segments.setdefault(best_segment, list()).append(token)
21 |
22 |
def get_average_reading_order_for_segment(page: PdfPage, tokens_for_segment: list[PdfToken]):
    """Mean position of the segment's tokens within ``page.tokens``."""
    positions = [page.tokens.index(token) for token in tokens_for_segment]
    return sum(positions) / len(positions)
26 |
27 |
def get_distance_between_segments(segment1: PdfSegment, segment2: PdfSegment):
    """Euclidean distance between the centers of two segments' bounding boxes."""
    box1, box2 = segment1.bounding_box, segment2.bounding_box
    dx = (box1.left + box1.right) / 2 - (box2.left + box2.right) / 2
    dy = (box1.top + box1.bottom) / 2 - (box2.top + box2.bottom) / 2
    return (dx**2 + dy**2) ** 0.5
34 |
35 |
def add_no_token_segments(segments, no_token_segments):
    """Insert segments that own no tokens into the ordered segment list.

    Each one is placed right before (or after, when it lies lower on the
    page) the ordered segment whose center is closest. With no ordered
    segments at all, they are simply appended left-to-right, top-to-bottom.
    """
    if not segments:
        segments.extend(sorted(no_token_segments, key=lambda s: (s.bounding_box.left, s.bounding_box.top)))
        return
    for orphan in no_token_segments:
        closest = min(segments, key=lambda seg: get_distance_between_segments(orphan, seg))
        position = segments.index(closest)
        if closest.bounding_box.top < orphan.bounding_box.top:
            position += 1
        segments.insert(position, orphan)
48 |
49 |
def filter_and_sort_segments(page, tokens_by_segments, types):
    """Return the segments whose type is in ``types``, ordered by the average
    reading position of their tokens on the page."""
    def reading_order(segment):
        return get_average_reading_order_for_segment(page, tokens_by_segments[segment])

    matching = [segment for segment in tokens_by_segments if segment.segment_type in types]
    return sorted(matching, key=reading_order)
54 |
55 |
def get_ordered_segments_for_page(segments_for_page: list[PdfSegment], page: PdfPage):
    """Order the page's segments for reading: headers, then body, then footers.

    A short (< 5 characters) bottom-most segment is treated as a page number
    and pushed to the very end; segments that received no tokens are slotted
    in next to their nearest neighbor afterwards.
    """
    tokens_by_segments: dict[PdfSegment, list[PdfToken]] = {}
    for token in page.tokens:
        find_segment_for_token(token, segments_for_page, tokens_by_segments)

    page_number_segment: None | PdfSegment = None
    if tokens_by_segments:
        bottom_segment = max(tokens_by_segments.keys(), key=lambda seg: seg.bounding_box.top)
        if bottom_segment.text_content and len(bottom_segment.text_content) < 5:
            page_number_segment = bottom_segment
            del tokens_by_segments[bottom_segment]

    headers: list[PdfSegment] = filter_and_sort_segments(page, tokens_by_segments, {TokenType.PAGE_HEADER})
    body_types = {t for t in TokenType if t.name not in {"PAGE_HEADER", "PAGE_FOOTER", "FOOTNOTE"}}
    body = filter_and_sort_segments(page, tokens_by_segments, body_types)
    footers = filter_and_sort_segments(page, tokens_by_segments, {TokenType.PAGE_FOOTER, TokenType.FOOTNOTE})
    if page_number_segment:
        footers.append(page_number_segment)

    ordered_segments = headers + body + footers
    leftovers = [segment for segment in segments_for_page if segment not in ordered_segments]
    add_no_token_segments(ordered_segments, leftovers)
    return ordered_segments
78 |
79 |
def get_reading_orders(pdf_images_list: list[PdfImages], predicted_segments: list[PdfSegment]):
    """Return all predicted segments, reordered page by page into reading order."""
    ordered: list[PdfSegment] = []
    for pdf_images in pdf_images_list:
        file_name = pdf_images.pdf_features.file_name
        segments_for_file = [segment for segment in predicted_segments if segment.pdf_name == file_name]
        for page in pdf_images.pdf_features.pages:
            segments_for_page = [segment for segment in segments_for_file if segment.page_number == page.page_number]
            ordered.extend(get_ordered_segments_for_page(segments_for_page, page))
    return ordered
89 |
--------------------------------------------------------------------------------
/src/adapters/ml/vgt/model_configuration/Base-RCNN-FPN.yaml:
--------------------------------------------------------------------------------
# Shared Detectron2 base configuration: a Mask R-CNN over a ViT+FPN backbone.
# Model-specific configs (e.g. doclaynet_VGT_cascade_PTM.yaml) extend this via _BASE_.
MODEL:
  MASK_ON: True
  META_ARCHITECTURE: "GeneralizedRCNN"
  # Per-channel pixel normalization statistics (RGB order, see INPUT.FORMAT).
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
  BACKBONE:
    NAME: "build_vit_fpn_backbone"
  VIT:
    # Intermediate transformer layers exposed to the FPN.
    OUT_FEATURES: ["layer3", "layer5", "layer7", "layer11"]
    DROP_PATH: 0.1
    IMG_SIZE: [224,224]
    POS_TYPE: "abs"
  FPN:
    IN_FEATURES: ["layer3", "layer5", "layer7", "layer11"]
  ANCHOR_GENERATOR:
    SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map
    ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps)
  RPN:
    IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
    PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level
    PRE_NMS_TOPK_TEST: 1000 # Per FPN level
    # Detectron1 uses 2000 proposals per-batch,
    # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
    # which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
    POST_NMS_TOPK_TRAIN: 1000
    POST_NMS_TOPK_TEST: 1000
  ROI_HEADS:
    NAME: "StandardROIHeads"
    IN_FEATURES: ["p2", "p3", "p4", "p5"]
    NUM_CLASSES: 5
  ROI_BOX_HEAD:
    NAME: "FastRCNNConvFCHead"
    NUM_FC: 2
    POOLER_RESOLUTION: 7
  ROI_MASK_HEAD:
    NAME: "MaskRCNNConvUpsampleHead"
    NUM_CONV: 4
    POOLER_RESOLUTION: 14
# Default DocBank datasets; overridden by extending configs.
DATASETS:
  TRAIN: ("docbank_train",)
  TEST: ("docbank_val",)
# Training schedule: AdamW + cosine LR with warmup and full-model grad clipping.
SOLVER:
  LR_SCHEDULER_NAME: "WarmupCosineLR"
  AMP:
    ENABLED: True
  OPTIMIZER: "ADAMW"
  BACKBONE_MULTIPLIER: 1.0
  CLIP_GRADIENTS:
    ENABLED: True
    CLIP_TYPE: "full_model"
    CLIP_VALUE: 1.0
    NORM_TYPE: 2.0
  WARMUP_FACTOR: 0.01
  BASE_LR: 0.0004
  WEIGHT_DECAY: 0.05
  IMS_PER_BATCH: 32
# Input pipeline: random absolute-range crops plus multi-scale shorter-side resizing.
INPUT:
  CROP:
    ENABLED: True
    TYPE: "absolute_range"
    SIZE: (384, 600)
  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
  FORMAT: "RGB"
DATALOADER:
  NUM_WORKERS: 6
  FILTER_EMPTY_ANNOTATIONS: False
VERSION: 2
AUG:
  DETR: True
SEED: 42
--------------------------------------------------------------------------------
/src/adapters/ml/vgt/model_configuration/doclaynet_VGT_cascade_PTM.yaml:
--------------------------------------------------------------------------------
# VGT cascade configuration used for DocLayNet layout inference.
# Everything not set here is inherited from Base-RCNN-FPN.yaml (see _BASE_ below).
DATASETS:
  TEST: ("predict_data",)
  TRAIN: ("train_data",)
MODEL:
  BACKBONE:
    NAME: build_VGT_fpn_backbone
  MASK_ON: false
  META_ARCHITECTURE: VGT
  PIXEL_MEAN:
  - 127.5
  - 127.5
  - 127.5
  PIXEL_STD:
  - 127.5
  - 127.5
  - 127.5
  ROI_BOX_HEAD:
    CLS_AGNOSTIC_BBOX_REG: true
  ROI_HEADS:
    NAME: CascadeROIHeads
    # 11 classes: the DocLayNet label set (cf. DOCLAYNET_TYPE_BY_ID in src/configuration.py).
    NUM_CLASSES: 11
  RPN:
    POST_NMS_TOPK_TRAIN: 2000
  VIT:
    MERGE_TYPE: Sum
    NAME: VGT_dit_base_patch16
  # DiT-base checkpoint used to initialize the visual backbone.
  WEIGHTS: https://layoutlm.blob.core.windows.net/dit/dit-pts/dit-base-224-p16-500k-62d53a.pth
  WORDGRID:
    EMBEDDING_DIM: 64
    MODEL_PATH: ../models/layoutlm-base-uncased/
    USE_PRETRAIN_WEIGHT: true
    VOCAB_SIZE: 30552
SOLVER:
  BASE_LR: 0.0002
  IMS_PER_BATCH: 12
  MAX_ITER: 10000
  STEPS: (6000, 8000)
  WARMUP_ITERS: 100
TEST:
  EVAL_PERIOD: 2000
_BASE_: ./Base-RCNN-FPN.yaml
42 |
--------------------------------------------------------------------------------
/src/adapters/ml/vgt/model_configuration/doclaynet_configuration.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/ml/vgt/model_configuration/doclaynet_configuration.pickle
--------------------------------------------------------------------------------
/src/adapters/ml/vgt_model_adapter.py:
--------------------------------------------------------------------------------
1 | from domain.PdfImages import PdfImages
2 | from domain.PdfSegment import PdfSegment
3 | from ports.services.ml_model_service import MLModelService
4 | from adapters.ml.vgt.ditod import VGTTrainer
5 | from adapters.ml.vgt.get_model_configuration import get_model_configuration
6 | from adapters.ml.vgt.get_most_probable_pdf_segments import get_most_probable_pdf_segments
7 | from adapters.ml.vgt.get_reading_orders import get_reading_orders
8 | from adapters.ml.vgt.get_json_annotations import get_annotations
9 | from adapters.ml.vgt.create_word_grid import create_word_grid, remove_word_grids
10 | from detectron2.checkpoint import DetectionCheckpointer
11 | from detectron2.data.datasets import register_coco_instances
12 | from detectron2.data import DatasetCatalog
13 | from configuration import JSON_TEST_FILE_PATH, IMAGES_ROOT_PATH
14 |
# Build the VGT model once at import time and load its weights; resume=True lets
# the checkpointer pick up the latest checkpoint from OUTPUT_DIR when present.
configuration = get_model_configuration()
model = VGTTrainer.build_model(configuration)
DetectionCheckpointer(model, save_dir=configuration.OUTPUT_DIR).resume_or_load(configuration.MODEL.WEIGHTS, resume=True)
18 |
19 |
class VGTModelAdapter(MLModelService):
    """MLModelService implementation backed by the VGT detector built above."""

    def _register_data(self) -> None:
        """(Re-)register the COCO prediction dataset used by VGTTrainer.test.

        Detectron2's DatasetCatalog rejects duplicate names, so any registration
        left over from a previous request is removed first.
        """
        try:
            DatasetCatalog.remove("predict_data")
        except KeyError:
            pass

        register_coco_instances("predict_data", {}, JSON_TEST_FILE_PATH, IMAGES_ROOT_PATH)

    def predict_document_layout(self, pdf_images: list[PdfImages]) -> list[PdfSegment]:
        """Run VGT inference over the given PDFs and return segments in reading order."""
        create_word_grid([pdf_images_obj.pdf_features for pdf_images_obj in pdf_images])
        get_annotations(pdf_images)

        try:
            self._register_data()
            VGTTrainer.test(configuration, model)
            predicted_segments = get_most_probable_pdf_segments("doclaynet", pdf_images, False)
        finally:
            # Fix: clean the temporary page images and word grids even when
            # inference raises, so stale artifacts cannot leak into later requests.
            PdfImages.remove_images()
            remove_word_grids()

        return get_reading_orders(pdf_images, predicted_segments)

    def predict_layout_fast(self, pdf_images: list[PdfImages]) -> list[PdfSegment]:
        # Fast (LightGBM) prediction is implemented by FastTrainerAdapter, not here.
        raise NotImplementedError("Fast prediction should be handled by FastTrainerAdapter")
46 |
--------------------------------------------------------------------------------
/src/adapters/storage/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/storage/__init__.py
--------------------------------------------------------------------------------
/src/adapters/storage/file_system_repository.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 | import uuid
3 | from pathlib import Path
4 | from typing import AnyStr
5 | from ports.repositories.file_repository import FileRepository
6 | from configuration import XMLS_PATH
7 |
8 |
class FileSystemRepository(FileRepository):
    """FileRepository implementation that stores files on the local file system."""

    def save_pdf(self, content: AnyStr, filename: str = "") -> Path:
        """Persist PDF bytes in the system temp directory and return the path.

        A UUID-based name is generated when no filename is supplied.
        """
        if not filename:
            filename = str(uuid.uuid1())

        # Fix: the path was built from the literal "(unknown).pdf", discarding the
        # filename entirely (every PDF overwrote the same file); interpolate it.
        if not filename.endswith(".pdf"):
            filename = f"{filename}.pdf"
        pdf_path = Path(tempfile.gettempdir(), filename)
        pdf_path.write_bytes(content)
        return pdf_path

    def save_xml(self, content: str, filename: str) -> Path:
        """Write XML text under XMLS_PATH, ensuring a .xml extension."""
        if not filename.endswith(".xml"):
            # Fix: append the extension to the actual filename instead of
            # replacing the whole name with the literal "(unknown).xml".
            filename = f"{filename}.xml"

        xml_path = Path(XMLS_PATH, filename)
        xml_path.parent.mkdir(parents=True, exist_ok=True)
        xml_path.write_text(content)
        return xml_path

    def get_xml(self, filename: str) -> str:
        """Read a previously saved XML file.

        Raises FileNotFoundError when the file does not exist under XMLS_PATH.
        """
        if not filename.endswith(".xml"):
            filename = f"{filename}.xml"

        xml_path = Path(XMLS_PATH, filename)
        if not xml_path.exists():
            # Fix: include the requested filename in the error message.
            raise FileNotFoundError(f"XML file {filename} not found")

        return xml_path.read_text()

    def delete_file(self, filepath: Path) -> None:
        """Delete the file; missing files are ignored."""
        filepath.unlink(missing_ok=True)

    def cleanup_temp_files(self) -> None:
        # Temp PDFs live in the OS temp directory and are reclaimed by the system.
        pass

    def save_pdf_to_directory(self, content: AnyStr, filename: str, directory: Path, namespace: str = "") -> Path:
        """Write PDF bytes under `directory`, optionally inside a namespace subfolder."""
        if namespace:
            target_path = Path(directory, namespace, filename)
        else:
            target_path = Path(directory, filename)

        target_path.parent.mkdir(parents=True, exist_ok=True)
        target_path.write_bytes(content)
        return target_path

    def save_markdown(self, content: str, filepath: Path) -> Path:
        """Write markdown text to `filepath`, creating parent directories as needed."""
        filepath.parent.mkdir(parents=True, exist_ok=True)
        filepath.write_text(content, encoding="utf-8")
        return filepath
57 |
--------------------------------------------------------------------------------
/src/adapters/web/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/adapters/web/__init__.py
--------------------------------------------------------------------------------
/src/adapters/web/fastapi_controllers.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import subprocess
3 | from fastapi import UploadFile, File, Form
4 | from typing import Optional, Union
5 | from starlette.responses import Response
6 | from starlette.concurrency import run_in_threadpool
7 | from use_cases.pdf_analysis.analyze_pdf_use_case import AnalyzePDFUseCase
8 | from use_cases.text_extraction.extract_text_use_case import ExtractTextUseCase
9 | from use_cases.toc_extraction.extract_toc_use_case import ExtractTOCUseCase
10 | from use_cases.visualization.create_visualization_use_case import CreateVisualizationUseCase
11 | from use_cases.ocr.process_ocr_use_case import ProcessOCRUseCase
12 | from use_cases.markdown_conversion.convert_to_markdown_use_case import ConvertToMarkdownUseCase
13 | from use_cases.html_conversion.convert_to_html_use_case import ConvertToHtmlUseCase
14 | from adapters.storage.file_system_repository import FileSystemRepository
15 |
16 |
class FastAPIControllers:
    """Thin HTTP layer: each endpoint normalizes its inputs and delegates the
    heavy work to a use case executed in a thread pool, keeping the event loop free."""

    def __init__(
        self,
        analyze_pdf_use_case: AnalyzePDFUseCase,
        extract_text_use_case: ExtractTextUseCase,
        extract_toc_use_case: ExtractTOCUseCase,
        create_visualization_use_case: CreateVisualizationUseCase,
        process_ocr_use_case: ProcessOCRUseCase,
        convert_to_markdown_use_case: ConvertToMarkdownUseCase,
        convert_to_html_use_case: ConvertToHtmlUseCase,
        file_repository: FileSystemRepository,
    ):
        self.analyze_pdf_use_case = analyze_pdf_use_case
        self.extract_text_use_case = extract_text_use_case
        self.extract_toc_use_case = extract_toc_use_case
        self.create_visualization_use_case = create_visualization_use_case
        self.process_ocr_use_case = process_ocr_use_case
        self.convert_to_markdown_use_case = convert_to_markdown_use_case
        self.convert_to_html_use_case = convert_to_html_use_case
        self.file_repository = file_repository

    @staticmethod
    def _parse_target_languages(target_languages: Optional[str]) -> Optional[list[str]]:
        # "es, fr" -> ["es", "fr"]; None/empty input stays None.
        if not target_languages:
            return None
        return [lang.strip() for lang in target_languages.split(",") if lang.strip()]

    @staticmethod
    def _ensure_xml_extension(xml_file_name: Optional[str]) -> Optional[str]:
        # Fix: guard against None before calling .endswith — the parameter is
        # declared Optional and previously raised AttributeError when missing.
        if xml_file_name and not xml_file_name.endswith(".xml"):
            return f"{xml_file_name}.xml"
        return xml_file_name

    async def root(self):
        import torch

        return sys.version + " Using GPU: " + str(torch.cuda.is_available())

    async def info(self):
        """Report interpreter/tool versions and the supported OCR languages."""
        return {
            "sys": sys.version,
            "tesseract_version": subprocess.run("tesseract --version", shell=True, text=True, capture_output=True).stdout,
            "ocrmypdf_version": subprocess.run("ocrmypdf --version", shell=True, text=True, capture_output=True).stdout,
            "supported_languages": self.process_ocr_use_case.get_supported_languages(),
        }

    async def error(self):
        # Deliberately raises so error handling can be exercised end-to-end.
        raise FileNotFoundError("This is a test error from the error endpoint")

    async def analyze_pdf(
        self, file: UploadFile = File(...), fast: bool = Form(False), parse_tables_and_math: bool = Form(False)
    ):
        return await run_in_threadpool(
            self.analyze_pdf_use_case.execute, file.file.read(), "", parse_tables_and_math, fast, False
        )

    async def analyze_and_save_xml(
        self, file: UploadFile = File(...), xml_file_name: str | None = None, fast: bool = Form(False)
    ):
        xml_file_name = self._ensure_xml_extension(xml_file_name)
        return await run_in_threadpool(self.analyze_pdf_use_case.execute_and_save_xml, file.file.read(), xml_file_name, fast)

    async def get_xml_by_name(self, xml_file_name: str):
        xml_file_name = self._ensure_xml_extension(xml_file_name)
        return await run_in_threadpool(self.file_repository.get_xml, xml_file_name)

    async def get_toc_endpoint(self, file: UploadFile = File(...), fast: bool = Form(False)):
        return await run_in_threadpool(self.extract_toc_use_case.execute, file, fast)

    async def toc_legacy_uwazi_compatible(self, file: UploadFile = File(...)):
        return await run_in_threadpool(self.extract_toc_use_case.execute_uwazi_compatible, file)

    async def get_text_endpoint(self, file: UploadFile = File(...), fast: bool = Form(False), types: str = Form("all")):
        return await run_in_threadpool(self.extract_text_use_case.execute, file, fast, types)

    async def get_visualization_endpoint(self, file: UploadFile = File(...), fast: bool = Form(False)):
        return await run_in_threadpool(self.create_visualization_use_case.execute, file, fast)

    async def ocr_pdf_sync(self, file: UploadFile = File(...), language: str = Form("en")):
        return await run_in_threadpool(self.process_ocr_use_case.execute, file, language)

    async def convert_to_markdown_endpoint(
        self,
        file: UploadFile = File(...),
        fast: bool = Form(False),
        extract_toc: bool = Form(False),
        dpi: int = Form(120),
        output_file: Optional[str] = Form(None),
        target_languages: Optional[str] = Form(None),
        translation_model: str = Form("gpt-oss"),
    ) -> Union[str, Response]:
        return await run_in_threadpool(
            self.convert_to_markdown_use_case.execute,
            file.file.read(),
            fast,
            extract_toc,
            dpi,
            output_file,
            self._parse_target_languages(target_languages),
            translation_model,
        )

    async def convert_to_html_endpoint(
        self,
        file: UploadFile = File(...),
        fast: bool = Form(False),
        extract_toc: bool = Form(False),
        dpi: int = Form(120),
        output_file: Optional[str] = Form(None),
        target_languages: Optional[str] = Form(None),
        translation_model: str = Form("gpt-oss"),
    ) -> Union[str, Response]:
        return await run_in_threadpool(
            self.convert_to_html_use_case.execute,
            file.file.read(),
            fast,
            extract_toc,
            dpi,
            output_file,
            self._parse_target_languages(target_languages),
            translation_model,
        )
137 |
--------------------------------------------------------------------------------
/src/app.py:
--------------------------------------------------------------------------------
from configuration import RESTART_IF_NO_GPU
from drivers.web.fastapi_app import create_app
from drivers.web.dependency_injection import setup_dependencies
import torch

# Fail fast when GPU enforcement is enabled but CUDA is unavailable, so the
# orchestrator restarts the container instead of serving on CPU.
if RESTART_IF_NO_GPU and not torch.cuda.is_available():
    raise RuntimeError("No GPU available. Restarting the service is required.")

controllers = setup_dependencies()
app = create_app(controllers)
13 |
--------------------------------------------------------------------------------
/src/catch_exceptions.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 | from fastapi import HTTPException
3 |
4 | from configuration import service_logger
5 |
6 |
def catch_exceptions(func):
    """Decorator for async FastAPI endpoints: logs the call and maps failures to HTTP errors.

    FileNotFoundError -> 404; any other exception -> 422 with the traceback logged.
    HTTPExceptions raised deliberately by the endpoint are re-raised untouched.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            service_logger.info(f"Calling endpoint: {func.__name__}")
            if kwargs and "file" in kwargs:
                service_logger.info(f"Processing file: {kwargs['file'].filename}")
            if kwargs and "xml_file_name" in kwargs:
                service_logger.info(f"Asking for file: {kwargs['xml_file_name']}")
            return await func(*args, **kwargs)
        except HTTPException:
            # Fix: do not re-wrap deliberate HTTP errors as a generic 422.
            raise
        except FileNotFoundError:
            raise HTTPException(status_code=404, detail="No xml file")
        except Exception:
            service_logger.error("Error see traceback", exc_info=True)
            raise HTTPException(status_code=422, detail="Error see traceback")

    return wrapper
24 |
--------------------------------------------------------------------------------
/src/configuration.py:
--------------------------------------------------------------------------------
import logging
import os
from pathlib import Path


# Absolute paths to the source tree and the repository root.
SRC_PATH = Path(__file__).parent.absolute()
ROOT_PATH = Path(__file__).parent.parent.absolute()

# Route all logging to the console with a single timestamped format.
handlers = [logging.StreamHandler()]
logging.root.handlers = []
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=handlers)
service_logger = logging.getLogger(__name__)

# When true the service refuses to start without CUDA (checked in src/app.py).
RESTART_IF_NO_GPU = os.environ.get("RESTART_IF_NO_GPU", "false").lower().strip() == "true"
# Working directories for intermediate artifacts (rendered page images, word
# grids, COCO jsons, OCR stages, saved XMLs) and for downloaded model weights.
IMAGES_ROOT_PATH = Path(ROOT_PATH, "images")
WORD_GRIDS_PATH = Path(ROOT_PATH, "word_grids")
JSONS_ROOT_PATH = Path(ROOT_PATH, "jsons")
OCR_SOURCE = Path(ROOT_PATH, "ocr", "source")
OCR_OUTPUT = Path(ROOT_PATH, "ocr", "output")
OCR_FAILED = Path(ROOT_PATH, "ocr", "failed")
JSON_TEST_FILE_PATH = Path(JSONS_ROOT_PATH, "test.json")
MODELS_PATH = Path(ROOT_PATH, "models")
XMLS_PATH = Path(ROOT_PATH, "xmls")

# DocLayNet category id -> label name (the 11 layout classes).
DOCLAYNET_TYPE_BY_ID = {
    1: "Caption",
    2: "Footnote",
    3: "Formula",
    4: "List_Item",
    5: "Page_Footer",
    6: "Page_Header",
    7: "Picture",
    8: "Section_Header",
    9: "Table",
    10: "Text",
    11: "Title",
}
38 |
--------------------------------------------------------------------------------
/src/domain/PdfImages.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | import cv2
5 | import numpy as np
6 | from os import makedirs
7 | from os.path import join
8 | from pathlib import Path
9 | from PIL import Image
10 | from pdf2image import convert_from_path
11 | from pdf_features import PdfFeatures
12 |
13 | from src.configuration import IMAGES_ROOT_PATH, XMLS_PATH
14 |
15 |
class PdfImages:
    """Holds a parsed PDF (PdfFeatures) together with its rendered page images."""

    def __init__(self, pdf_features: PdfFeatures, pdf_images: list[Image.Image], dpi: int = 72):
        self.pdf_features: PdfFeatures = pdf_features
        self.pdf_images: list[Image.Image] = pdf_images
        self.dpi: int = dpi
        # Pages are written to disk immediately; downstream code reads them
        # from IMAGES_ROOT_PATH.
        self.save_images()

    def show_images(self, next_image_delay: int = 2):
        """Debug helper: display each page in an OpenCV window for a few seconds."""
        for image_index, image in enumerate(self.pdf_images):
            image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            cv2.imshow(f"Page: {image_index + 1}", image_np)
            cv2.waitKey(next_image_delay * 1000)
        cv2.destroyAllWindows()

    def save_images(self):
        """Write every page as <file_name>_<page_index>.jpg under IMAGES_ROOT_PATH."""
        makedirs(IMAGES_ROOT_PATH, exist_ok=True)
        for image_index, image in enumerate(self.pdf_images):
            image_name = f"{self.pdf_features.file_name}_{image_index}.jpg"
            image.save(join(IMAGES_ROOT_PATH, image_name))

    @staticmethod
    def remove_images():
        """Delete the rendered page images directory (no-op when it does not exist)."""
        # Fix: tolerate a missing directory instead of raising FileNotFoundError.
        shutil.rmtree(IMAGES_ROOT_PATH, ignore_errors=True)

    @staticmethod
    def from_pdf_path(pdf_path: str | Path, pdf_name: str = "", xml_file_name: str = "", dpi: int = 72):
        """Parse the PDF, render its pages at `dpi`, and return a PdfImages.

        When pdf_name is empty it is derived from the path: the parent directory
        name for generic "document.pdf" files, otherwise the file stem.
        """
        xml_path = None if not xml_file_name else Path(XMLS_PATH, xml_file_name)

        if xml_path and not xml_path.parent.exists():
            os.makedirs(xml_path.parent, exist_ok=True)

        pdf_features: PdfFeatures = PdfFeatures.from_pdf_path(pdf_path, xml_path)

        if pdf_name:
            pdf_features.file_name = pdf_name
        else:
            pdf_name = Path(pdf_path).parent.name if Path(pdf_path).name == "document.pdf" else Path(pdf_path).stem
            pdf_features.file_name = pdf_name
        pdf_images = convert_from_path(pdf_path, dpi=dpi)
        return PdfImages(pdf_features, pdf_images, dpi)
56 |
--------------------------------------------------------------------------------
/src/domain/PdfSegment.py:
--------------------------------------------------------------------------------
1 | from statistics import mode
2 | from pdf_features import PdfToken
3 | from pdf_features import Rectangle
4 | from pdf_token_type_labels import TokenType
5 |
6 |
class PdfSegment:
    """A typed region of a PDF page: merged bounding box plus joined text."""

    def __init__(
        self, page_number: int, bounding_box: Rectangle, text_content: str, segment_type: TokenType, pdf_name: str = ""
    ):
        self.page_number = page_number
        self.bounding_box = bounding_box
        self.text_content = text_content
        self.segment_type = segment_type
        self.pdf_name = pdf_name

    @staticmethod
    def from_pdf_tokens(pdf_tokens: list[PdfToken], pdf_name: str = ""):
        """Build one segment covering the tokens; its type is the tokens' modal type."""
        contents = [token.content for token in pdf_tokens]
        boxes = [token.bounding_box for token in pdf_tokens]
        majority_type = mode([token.token_type for token in pdf_tokens])
        merged_box = Rectangle.merge_rectangles(boxes)
        return PdfSegment(pdf_tokens[0].page_number, merged_box, " ".join(contents), majority_type, pdf_name)
25 |
--------------------------------------------------------------------------------
/src/domain/Prediction.py:
--------------------------------------------------------------------------------
1 | from pdf_features import Rectangle
2 |
3 |
class Prediction:
    """A single detector output: a bounding box with its category id and confidence score."""

    def __init__(self, bounding_box: Rectangle, category_id: int, score: float):
        self.bounding_box: Rectangle = bounding_box
        self.category_id: int = category_id
        self.score: float = score
9 |
--------------------------------------------------------------------------------
/src/domain/SegmentBox.py:
--------------------------------------------------------------------------------
1 | from domain.PdfSegment import PdfSegment
2 | from pdf_features import PdfPage
3 | from pdf_token_type_labels import TokenType
4 | from pydantic import BaseModel
5 |
6 |
class SegmentBox(BaseModel):
    """Serializable segment geometry (page coordinates) plus its text and type."""

    left: float
    top: float
    width: float
    height: float
    page_number: int
    page_width: int
    page_height: int
    text: str = ""
    type: TokenType = TokenType.TEXT
    id: str = ""

    def __hash__(self):
        # Hash over every field so equal boxes collide deterministically.
        field_values = (
            self.left,
            self.top,
            self.width,
            self.height,
            self.page_number,
            self.page_width,
            self.page_height,
            self.text,
            self.type,
            self.id,
        )
        return hash(field_values)

    def to_dict(self):
        """Plain-dict view of the box; `type` is exported as its value and `id` is omitted."""
        return dict(
            left=self.left,
            top=self.top,
            width=self.width,
            height=self.height,
            page_number=self.page_number,
            page_width=self.page_width,
            page_height=self.page_height,
            text=self.text,
            type=self.type.value,
        )

    @staticmethod
    def from_pdf_segment(pdf_segment: PdfSegment, pdf_pages: list[PdfPage]):
        """Build a SegmentBox from a PdfSegment, taking page size from the matching page."""
        box = pdf_segment.bounding_box
        page = pdf_pages[pdf_segment.page_number - 1]
        return SegmentBox(
            left=box.left,
            top=box.top,
            width=box.width,
            height=box.height,
            page_number=pdf_segment.page_number,
            page_width=page.page_width,
            page_height=page.page_height,
            text=pdf_segment.text_content,
            type=pdf_segment.segment_type,
        )
61 |
62 |
if __name__ == "__main__":
    # Ad-hoc manual check of TokenType serialization; not used by the service.
    a = TokenType.TEXT
    print(a.value)
66 |
--------------------------------------------------------------------------------
/src/download_models.py:
--------------------------------------------------------------------------------
1 | import math
2 | from os import makedirs
3 | from os.path import join, exists
4 | from pathlib import Path
5 | from urllib.request import urlretrieve
6 | from huggingface_hub import snapshot_download, hf_hub_download
7 |
8 | from configuration import service_logger, MODELS_PATH
9 |
10 |
def download_progress(count, block_size, total_size):
    """urlretrieve reporthook: log download progress in roughly 20% steps.

    count: blocks transferred so far; block_size: bytes per block;
    total_size: total bytes reported by the server.
    """
    total_counts = total_size // block_size
    show_counts_percentages = total_counts // 5
    # Fix: for small downloads (fewer than 5 blocks) the logging step is 0 and
    # `count % step` raised ZeroDivisionError; skip progress logging then.
    if show_counts_percentages == 0:
        return
    percent = count * block_size * 100 / total_size
    if count % show_counts_percentages == 0:
        service_logger.info(f"Downloaded {math.ceil(percent)}%")
17 |
18 |
def download_vgt_model(model_name: str):
    """Download the pretrained VGT weights for `model_name` into MODELS_PATH (idempotent)."""
    model_path = join(MODELS_PATH, f"{model_name}_VGT_model.pth")
    if exists(model_path):
        return
    # Fix: only announce the download once we know it is actually needed,
    # instead of logging "Downloading" and then skipping.
    service_logger.info(f"Downloading {model_name} model")
    download_link = f"https://github.com/AlibabaResearch/AdvancedLiterateMachinery/releases/download/v1.3.0-VGT-release/{model_name}_VGT_model.pth"
    urlretrieve(download_link, model_path, reporthook=download_progress)
26 |
27 |
def download_embedding_model():
    """Fetch the layoutlm-base-uncased snapshot into MODELS_PATH (skips if present)."""
    target_dir = join(MODELS_PATH, "layoutlm-base-uncased")
    if exists(target_dir):
        return
    makedirs(target_dir, exist_ok=True)
    service_logger.info("Embedding model is being downloaded")
    snapshot_download(repo_id="microsoft/layoutlm-base-uncased", local_dir=target_dir, local_dir_use_symlinks=False)
35 |
36 |
def download_from_hf_hub(path: Path):
    """Download `path.name` from the HURIDOCS HF repo into path's directory (skips if present)."""
    if path.exists():
        return

    makedirs(path.parent, exist_ok=True)
    hf_hub_download(
        repo_id="HURIDOCS/pdf-document-layout-analysis",
        filename=path.name,
        local_dir=path.parent,
        local_dir_use_symlinks=False,
    )
45 |
46 |
def download_lightgbm_models():
    """Fetch the fast-pipeline LightGBM models and their config file."""
    for artifact in ("token_type_lightgbm.model", "paragraph_extraction_lightgbm.model", "config.json"):
        download_from_hf_hub(Path(MODELS_PATH, artifact))
51 |
52 |
def download_models(model_name: str):
    """Download everything the selected pipeline needs.

    "fast" pulls the LightGBM models; any other name pulls that VGT checkpoint
    plus the layout embedding model.
    """
    makedirs(MODELS_PATH, exist_ok=True)
    if model_name == "fast":
        download_lightgbm_models()
    else:
        download_vgt_model(model_name)
        download_embedding_model()
60 |
61 |
if __name__ == "__main__":
    # Pre-fetch all model weights (VGT pipeline and fast pipeline) when run as a script.
    download_models("doclaynet")
    download_models("fast")
65 |
--------------------------------------------------------------------------------
/src/drivers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/drivers/__init__.py
--------------------------------------------------------------------------------
/src/drivers/web/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/drivers/web/__init__.py
--------------------------------------------------------------------------------
/src/drivers/web/dependency_injection.py:
--------------------------------------------------------------------------------
1 | from adapters.storage.file_system_repository import FileSystemRepository
2 | from adapters.ml.vgt_model_adapter import VGTModelAdapter
3 | from adapters.ml.fast_trainer_adapter import FastTrainerAdapter
4 | from adapters.infrastructure.pdf_analysis_service_adapter import PDFAnalysisServiceAdapter
5 | from adapters.infrastructure.text_extraction_adapter import TextExtractionAdapter
6 | from adapters.infrastructure.toc_service_adapter import TOCServiceAdapter
7 | from adapters.infrastructure.visualization_service_adapter import VisualizationServiceAdapter
8 | from adapters.infrastructure.ocr_service_adapter import OCRServiceAdapter
9 | from adapters.infrastructure.format_conversion_service_adapter import FormatConversionServiceAdapter
10 | from adapters.infrastructure.markdown_conversion_service_adapter import MarkdownConversionServiceAdapter
11 | from adapters.infrastructure.html_conversion_service_adapter import HtmlConversionServiceAdapter
12 | from adapters.web.fastapi_controllers import FastAPIControllers
13 | from use_cases.pdf_analysis.analyze_pdf_use_case import AnalyzePDFUseCase
14 | from use_cases.text_extraction.extract_text_use_case import ExtractTextUseCase
15 | from use_cases.toc_extraction.extract_toc_use_case import ExtractTOCUseCase
16 | from use_cases.visualization.create_visualization_use_case import CreateVisualizationUseCase
17 | from use_cases.ocr.process_ocr_use_case import ProcessOCRUseCase
18 | from use_cases.markdown_conversion.convert_to_markdown_use_case import ConvertToMarkdownUseCase
19 | from use_cases.html_conversion.convert_to_html_use_case import ConvertToHtmlUseCase
20 |
21 |
def setup_dependencies():
    """Composition root: build every adapter and use case and wire them into the controllers.

    Returns the fully assembled FastAPIControllers instance used by the app factory.
    Note: importing/instantiating VGTModelAdapter triggers model construction, so
    the construction order here is kept deliberately explicit.
    """
    file_repository = FileSystemRepository()

    # ML model services (VGT for accurate analysis, LightGBM for fast analysis).
    vgt_model_service = VGTModelAdapter()
    fast_model_service = FastTrainerAdapter()

    # Infrastructure services shared by the use cases.
    format_conversion_service = FormatConversionServiceAdapter()
    markdown_conversion_service = MarkdownConversionServiceAdapter()
    html_conversion_service = HtmlConversionServiceAdapter()
    text_extraction_service = TextExtractionAdapter()
    toc_service = TOCServiceAdapter()
    visualization_service = VisualizationServiceAdapter()
    ocr_service = OCRServiceAdapter()

    pdf_analysis_service = PDFAnalysisServiceAdapter(
        vgt_model_service=vgt_model_service,
        fast_model_service=fast_model_service,
        format_conversion_service=format_conversion_service,
        file_repository=file_repository,
    )

    # Use cases, one per endpoint family.
    analyze_pdf_use_case = AnalyzePDFUseCase(pdf_analysis_service=pdf_analysis_service, ml_model_service=vgt_model_service)

    extract_text_use_case = ExtractTextUseCase(
        pdf_analysis_service=pdf_analysis_service, text_extraction_service=text_extraction_service
    )

    extract_toc_use_case = ExtractTOCUseCase(pdf_analysis_service=pdf_analysis_service, toc_service=toc_service)

    create_visualization_use_case = CreateVisualizationUseCase(
        pdf_analysis_service=pdf_analysis_service, visualization_service=visualization_service
    )

    process_ocr_use_case = ProcessOCRUseCase(ocr_service=ocr_service, file_repository=file_repository)

    convert_to_markdown_use_case = ConvertToMarkdownUseCase(
        pdf_analysis_service=pdf_analysis_service, markdown_conversion_service=markdown_conversion_service
    )

    convert_to_html_use_case = ConvertToHtmlUseCase(
        pdf_analysis_service=pdf_analysis_service, html_conversion_service=html_conversion_service
    )

    controllers = FastAPIControllers(
        analyze_pdf_use_case=analyze_pdf_use_case,
        extract_text_use_case=extract_text_use_case,
        extract_toc_use_case=extract_toc_use_case,
        create_visualization_use_case=create_visualization_use_case,
        process_ocr_use_case=process_ocr_use_case,
        convert_to_markdown_use_case=convert_to_markdown_use_case,
        convert_to_html_use_case=convert_to_html_use_case,
        file_repository=file_repository,
    )

    return controllers
77 |
--------------------------------------------------------------------------------
/src/drivers/web/fastapi_app.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from fastapi import FastAPI
3 | from fastapi.responses import PlainTextResponse
4 | from adapters.web.fastapi_controllers import FastAPIControllers
5 | from catch_exceptions import catch_exceptions
6 | from configuration import service_logger
7 |
8 |
def create_app(controllers: FastAPIControllers) -> FastAPI:
    """Create the FastAPI application and register every service route."""
    service_logger.info(f"Is PyTorch using GPU: {torch.cuda.is_available()}")

    app = FastAPI()

    # Diagnostic endpoints are registered without the exception wrapper.
    app.get("/")(controllers.root)
    app.get("/info")(controllers.info)
    app.get("/error")(controllers.error)

    # XML retrieval returns plain text rather than JSON.
    app.get("/get_xml/{xml_file_name}", response_class=PlainTextResponse)(catch_exceptions(controllers.get_xml_by_name))

    # Conversion endpoints may return either str or Response, so response_model is disabled.
    app.post("/markdown", response_model=None)(catch_exceptions(controllers.convert_to_markdown_endpoint))
    app.post("/html", response_model=None)(catch_exceptions(controllers.convert_to_html_endpoint))

    simple_post_routes = {
        "/": controllers.analyze_pdf,
        "/save_xml/{xml_file_name}": controllers.analyze_and_save_xml,
        "/toc": controllers.get_toc_endpoint,
        "/toc_legacy_uwazi_compatible": controllers.toc_legacy_uwazi_compatible,
        "/text": controllers.get_text_endpoint,
        "/visualize": controllers.get_visualization_endpoint,
        "/ocr": controllers.ocr_pdf_sync,
    }
    for route, handler in simple_post_routes.items():
        app.post(route)(catch_exceptions(handler))

    return app
32 |
--------------------------------------------------------------------------------
/src/ports/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/ports/__init__.py
--------------------------------------------------------------------------------
/src/ports/repositories/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/ports/repositories/__init__.py
--------------------------------------------------------------------------------
/src/ports/repositories/file_repository.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from pathlib import Path
3 | from typing import AnyStr
4 |
5 |
class FileRepository(ABC):
    """Port (abstract interface) for storing and retrieving the files this
    service works with: uploaded PDFs, generated XML, and markdown output."""

    @abstractmethod
    def save_pdf(self, content: AnyStr, filename: str = "") -> Path:
        """Persist PDF content (bytes or str); return the path it was written to."""
        pass

    @abstractmethod
    def save_xml(self, content: str, filename: str) -> Path:
        """Persist XML text under the given filename; return the written path."""
        pass

    @abstractmethod
    def get_xml(self, filename: str) -> str:
        """Load and return previously saved XML text by filename."""
        pass

    @abstractmethod
    def delete_file(self, filepath: Path) -> None:
        """Remove a single stored file."""
        pass

    @abstractmethod
    def cleanup_temp_files(self) -> None:
        """Remove temporary files created by this repository."""
        pass

    @abstractmethod
    def save_pdf_to_directory(self, content: AnyStr, filename: str, directory: Path, namespace: str = "") -> Path:
        """Persist PDF content under the given directory; return the written path.
        How `namespace` partitions the directory is implementation-defined —
        presumably a subfolder; confirm against the adapter."""
        pass

    @abstractmethod
    def save_markdown(self, content: str, filepath: Path) -> Path:
        """Write markdown text to filepath; return the written path."""
        pass
34 |
--------------------------------------------------------------------------------
/src/ports/services/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/ports/services/__init__.py
--------------------------------------------------------------------------------
/src/ports/services/format_conversion_service.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from domain.PdfImages import PdfImages
3 | from domain.PdfSegment import PdfSegment
4 |
5 |
class FormatConversionService(ABC):
    """Port for enriching detected segments with markup representations
    (HTML for tables, LaTeX for formulas)."""

    @abstractmethod
    def convert_table_to_html(self, pdf_images: PdfImages, segments: list[PdfSegment]) -> None:
        """Render table segments as HTML. Returns None, so implementations are
        expected to update the given segments in place."""
        pass

    @abstractmethod
    def convert_formula_to_latex(self, pdf_images: PdfImages, segments: list[PdfSegment]) -> None:
        """Render formula segments as LaTeX. Returns None, so implementations
        are expected to update the given segments in place."""
        pass
15 |
--------------------------------------------------------------------------------
/src/ports/services/html_conversion_service.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Optional, Union
3 | from starlette.responses import Response
4 | from domain.SegmentBox import SegmentBox
5 |
6 |
class HtmlConversionService(ABC):
    """Port for turning an analyzed PDF (raw content + segment boxes) into HTML."""

    @abstractmethod
    def convert_to_html(
        self,
        pdf_content: bytes,
        segments: list[SegmentBox],
        extract_toc: bool = False,
        dpi: int = 120,
        output_file: Optional[str] = None,
        target_languages: Optional[list[str]] = None,
        translation_model: str = "gpt-oss",
    ) -> Union[str, Response]:
        """Convert the PDF to an HTML document.

        :param pdf_content: raw bytes of the source PDF
        :param segments: layout segments produced by the analysis step
        :param extract_toc: whether to include a table of contents
        :param dpi: rendering resolution — presumably for extracted images; confirm with adapter
        :param output_file: optional output filename
        :param target_languages: optional translation target languages
        :param translation_model: model name used for translation (default "gpt-oss")
        :return: the HTML as a str, or a starlette Response (adapter-dependent)
        """
        pass
21 |
--------------------------------------------------------------------------------
/src/ports/services/markdown_conversion_service.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Optional, Union
3 | from starlette.responses import Response
4 | from domain.SegmentBox import SegmentBox
5 |
6 |
class MarkdownConversionService(ABC):
    """Port for turning an analyzed PDF (raw content + segment boxes) into markdown."""

    @abstractmethod
    def convert_to_markdown(
        self,
        pdf_content: bytes,
        segments: list[SegmentBox],
        extract_toc: bool = False,
        dpi: int = 120,
        output_file: Optional[str] = None,
        target_languages: Optional[list[str]] = None,
        translation_model: str = "gpt-oss",
    ) -> Union[str, Response]:
        """Convert the PDF to a markdown document.

        :param pdf_content: raw bytes of the source PDF
        :param segments: layout segments produced by the analysis step
        :param extract_toc: whether to include a table of contents
        :param dpi: rendering resolution — presumably for extracted images; confirm with adapter
        :param output_file: optional output filename
        :param target_languages: optional translation target languages
        :param translation_model: model name used for translation (default "gpt-oss")
        :return: the markdown as a str, or a starlette Response (adapter-dependent)
        """
        pass
21 |
--------------------------------------------------------------------------------
/src/ports/services/ml_model_service.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from domain.PdfImages import PdfImages
3 | from domain.PdfSegment import PdfSegment
4 |
5 |
class MLModelService(ABC):
    """Port for the document-layout prediction models."""

    @abstractmethod
    def predict_document_layout(self, pdf_images: list[PdfImages]) -> list[PdfSegment]:
        """Predict layout segments for the given page images (full model)."""
        pass

    @abstractmethod
    def predict_layout_fast(self, pdf_images: list[PdfImages]) -> list[PdfSegment]:
        """Predict layout segments using the fast model variant."""
        pass
14 |
--------------------------------------------------------------------------------
/src/ports/services/ocr_service.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from pathlib import Path
3 |
4 |
class OCRService(ABC):
    """Port for running OCR over stored PDFs."""

    @abstractmethod
    def process_pdf_ocr(self, filename: str, namespace: str, language: str = "en") -> Path:
        """OCR the PDF identified by filename within namespace; return the path
        of the processed PDF. `language` is an OCR language code (default "en")."""
        pass

    @abstractmethod
    def get_supported_languages(self) -> list[str]:
        """Return the language codes the OCR engine supports."""
        pass
13 |
--------------------------------------------------------------------------------
/src/ports/services/pdf_analysis_service.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import AnyStr
3 |
4 |
class PDFAnalysisService(ABC):
    """Port for extracting layout segments from raw PDF content."""

    @abstractmethod
    def analyze_pdf_layout(
        self, pdf_content: AnyStr, xml_filename: str = "", parse_tables_and_math: bool = False, keep_pdf: bool = False
    ) -> list[dict]:
        """Analyze the PDF layout with the full model.

        :param pdf_content: raw PDF content (bytes or str)
        :param xml_filename: presumably the name to save analysis XML under when non-empty — confirm with adapter
        :param parse_tables_and_math: also convert tables/formulas to markup
        :param keep_pdf: keep the temporary PDF on disk after analysis
        :return: one dict per detected segment box
        """
        pass

    @abstractmethod
    def analyze_pdf_layout_fast(
        self, pdf_content: AnyStr, xml_filename: str = "", parse_tables_and_math: bool = False, keep_pdf: bool = False
    ) -> list[dict]:
        """Same contract as analyze_pdf_layout, using the fast model variant."""
        pass
17 |
--------------------------------------------------------------------------------
/src/ports/services/text_extraction_service.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from pdf_token_type_labels import TokenType
3 |
4 |
class TextExtractionService(ABC):
    """Port for pulling text out of analyzed segment boxes."""

    @abstractmethod
    def extract_text_by_types(self, segment_boxes: list[dict], token_types: list[TokenType]) -> dict:
        """Return text taken from segments whose type is in token_types."""
        pass

    @abstractmethod
    def extract_all_text(self, segment_boxes: list[dict]) -> dict:
        """Return text taken from all segments, regardless of type."""
        pass
13 |
--------------------------------------------------------------------------------
/src/ports/services/toc_service.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import AnyStr
3 |
4 |
class TOCService(ABC):
    """Port for building a table of contents from analyzed segments."""

    @abstractmethod
    def extract_table_of_contents(self, pdf_content: AnyStr, segment_boxes: list[dict]) -> list[dict]:
        """Derive TOC entries (one dict each) from the PDF and its segment boxes."""
        pass

    @abstractmethod
    def format_toc_for_uwazi(self, toc_items: list[dict]) -> list[dict]:
        """Reshape TOC entries into the structure expected by Uwazi."""
        pass
13 |
--------------------------------------------------------------------------------
/src/ports/services/visualization_service.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from pathlib import Path
3 | from starlette.responses import FileResponse
4 |
5 |
class VisualizationService(ABC):
    """Port for rendering detected segment boxes onto a PDF."""

    @abstractmethod
    def create_pdf_visualization(self, pdf_path: Path, segment_boxes: list[dict]) -> Path:
        """Create a PDF with the segment boxes drawn on it; return the new file's path."""
        pass

    @abstractmethod
    def get_visualization_response(self, pdf_path: Path) -> FileResponse:
        """Wrap the visualization file in a starlette FileResponse for download."""
        pass
14 |
--------------------------------------------------------------------------------
/src/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/tests/__init__.py
--------------------------------------------------------------------------------
/src/use_cases/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/use_cases/__init__.py
--------------------------------------------------------------------------------
/src/use_cases/html_conversion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/use_cases/html_conversion/__init__.py
--------------------------------------------------------------------------------
/src/use_cases/html_conversion/convert_to_html_use_case.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 | from starlette.responses import Response
3 | from ports.services.html_conversion_service import HtmlConversionService
4 | from ports.services.pdf_analysis_service import PDFAnalysisService
5 | from domain.SegmentBox import SegmentBox
6 |
7 |
class ConvertToHtmlUseCase:
    """Runs layout analysis on a PDF and converts the result to HTML."""

    def __init__(
        self,
        pdf_analysis_service: PDFAnalysisService,
        html_conversion_service: HtmlConversionService,
    ):
        self.pdf_analysis_service = pdf_analysis_service
        self.html_conversion_service = html_conversion_service

    @staticmethod
    def _to_segment_boxes(analysis_result: list) -> list[SegmentBox]:
        """Normalize raw analysis items into SegmentBox objects.

        Dicts are converted field-by-field (missing fields fall back to the
        same defaults the analysis output uses); SegmentBox instances pass
        through unchanged; any other item type is silently dropped, preserving
        the original best-effort behavior.
        """
        segments: list[SegmentBox] = []
        for item in analysis_result:
            if isinstance(item, dict):
                segments.append(
                    SegmentBox(
                        left=item.get("left", 0),
                        top=item.get("top", 0),
                        width=item.get("width", 0),
                        height=item.get("height", 0),
                        page_number=item.get("page_number", 1),
                        page_width=item.get("page_width", 0),
                        page_height=item.get("page_height", 0),
                        text=item.get("text", ""),
                        type=item.get("type", "TEXT"),
                    )
                )
            elif isinstance(item, SegmentBox):
                segments.append(item)
        return segments

    def execute(
        self,
        pdf_content: bytes,
        use_fast_mode: bool = False,
        extract_toc: bool = False,
        dpi: int = 120,
        output_file: Optional[str] = None,
        target_languages: Optional[list[str]] = None,
        translation_model: str = "gpt-oss",
    ) -> Union[str, Response]:
        """Analyze the PDF (fast or full model) and return it converted to HTML.

        :param pdf_content: raw bytes of the source PDF
        :param use_fast_mode: choose the fast analysis model over the full one
        :param extract_toc: forwarded to the conversion service
        :param dpi: forwarded rendering resolution
        :param output_file: optional output filename, forwarded
        :param target_languages: optional translation targets, forwarded
        :param translation_model: translation model name, forwarded
        :return: HTML str or a starlette Response, per the conversion service
        """
        analyze = (
            self.pdf_analysis_service.analyze_pdf_layout_fast
            if use_fast_mode
            else self.pdf_analysis_service.analyze_pdf_layout
        )
        # parse_tables_and_math=True so tables/formulas become markup; keep_pdf=False.
        analysis_result = analyze(pdf_content, "", True, False)

        segments = self._to_segment_boxes(analysis_result)

        return self.html_conversion_service.convert_to_html(
            pdf_content, segments, extract_toc, dpi, output_file, target_languages, translation_model
        )
53 |
--------------------------------------------------------------------------------
/src/use_cases/markdown_conversion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/use_cases/markdown_conversion/__init__.py
--------------------------------------------------------------------------------
/src/use_cases/markdown_conversion/convert_to_markdown_use_case.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 | from starlette.responses import Response
3 | from ports.services.markdown_conversion_service import MarkdownConversionService
4 | from ports.services.pdf_analysis_service import PDFAnalysisService
5 | from domain.SegmentBox import SegmentBox
6 |
7 |
class ConvertToMarkdownUseCase:
    """Runs layout analysis on a PDF and converts the result to markdown."""

    def __init__(
        self,
        pdf_analysis_service: PDFAnalysisService,
        markdown_conversion_service: MarkdownConversionService,
    ):
        self.pdf_analysis_service = pdf_analysis_service
        self.markdown_conversion_service = markdown_conversion_service

    @staticmethod
    def _to_segment_boxes(analysis_result: list) -> list[SegmentBox]:
        """Normalize raw analysis items into SegmentBox objects.

        Dicts are converted field-by-field (missing fields fall back to the
        same defaults the analysis output uses); SegmentBox instances pass
        through unchanged; any other item type is silently dropped, preserving
        the original best-effort behavior.
        """
        segments: list[SegmentBox] = []
        for item in analysis_result:
            if isinstance(item, dict):
                segments.append(
                    SegmentBox(
                        left=item.get("left", 0),
                        top=item.get("top", 0),
                        width=item.get("width", 0),
                        height=item.get("height", 0),
                        page_number=item.get("page_number", 1),
                        page_width=item.get("page_width", 0),
                        page_height=item.get("page_height", 0),
                        text=item.get("text", ""),
                        type=item.get("type", "TEXT"),
                    )
                )
            elif isinstance(item, SegmentBox):
                segments.append(item)
        return segments

    def execute(
        self,
        pdf_content: bytes,
        use_fast_mode: bool = False,
        extract_toc: bool = False,
        dpi: int = 120,
        output_file: Optional[str] = None,
        target_languages: Optional[list[str]] = None,
        translation_model: str = "gpt-oss",
    ) -> Union[str, Response]:
        """Analyze the PDF (fast or full model) and return it converted to markdown.

        :param pdf_content: raw bytes of the source PDF
        :param use_fast_mode: choose the fast analysis model over the full one
        :param extract_toc: forwarded to the conversion service
        :param dpi: forwarded rendering resolution
        :param output_file: optional output filename, forwarded
        :param target_languages: optional translation targets, forwarded
        :param translation_model: translation model name, forwarded
        :return: markdown str or a starlette Response, per the conversion service
        """
        analyze = (
            self.pdf_analysis_service.analyze_pdf_layout_fast
            if use_fast_mode
            else self.pdf_analysis_service.analyze_pdf_layout
        )
        # parse_tables_and_math=True so tables/formulas become markup; keep_pdf=False.
        analysis_result = analyze(pdf_content, "", True, False)

        segments = self._to_segment_boxes(analysis_result)

        return self.markdown_conversion_service.convert_to_markdown(
            pdf_content, segments, extract_toc, dpi, output_file, target_languages, translation_model
        )
53 |
--------------------------------------------------------------------------------
/src/use_cases/ocr/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/use_cases/ocr/__init__.py
--------------------------------------------------------------------------------
/src/use_cases/ocr/process_ocr_use_case.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from fastapi import UploadFile
3 | from starlette.responses import FileResponse
4 | from ports.services.ocr_service import OCRService
5 | from ports.repositories.file_repository import FileRepository
6 | from configuration import OCR_SOURCE
7 |
8 |
class ProcessOCRUseCase:
    """Coordinates storing an uploaded PDF and running OCR over it."""

    def __init__(self, ocr_service: OCRService, file_repository: FileRepository):
        self.ocr_service = ocr_service
        self.file_repository = file_repository

    def execute(self, file: UploadFile, language: str = "en") -> FileResponse:
        """Save the upload into the OCR source directory, run OCR, and return
        the processed file as a PDF response."""
        namespace = "sync_pdfs"
        pdf_bytes = file.file.read()

        self.file_repository.save_pdf_to_directory(
            content=pdf_bytes, filename=file.filename, directory=Path(OCR_SOURCE), namespace=namespace
        )

        ocr_output_path = self.ocr_service.process_pdf_ocr(file.filename, namespace, language)

        return FileResponse(path=ocr_output_path, media_type="application/pdf")

    def get_supported_languages(self) -> list:
        """Expose the OCR engine's supported language codes."""
        return self.ocr_service.get_supported_languages()
27 |
--------------------------------------------------------------------------------
/src/use_cases/pdf_analysis/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/use_cases/pdf_analysis/__init__.py
--------------------------------------------------------------------------------
/src/use_cases/pdf_analysis/analyze_pdf_use_case.py:
--------------------------------------------------------------------------------
1 | from typing import AnyStr
2 | from ports.services.pdf_analysis_service import PDFAnalysisService
3 | from ports.services.ml_model_service import MLModelService
4 |
5 |
class AnalyzePDFUseCase:
    """Use case that delegates PDF layout analysis to the analysis service."""

    def __init__(
        self,
        pdf_analysis_service: PDFAnalysisService,
        ml_model_service: MLModelService,
    ):
        self.pdf_analysis_service = pdf_analysis_service
        # NOTE(review): ml_model_service is stored but never read by the
        # methods in this class — confirm it is still needed before removing.
        self.ml_model_service = ml_model_service

    def execute(
        self,
        pdf_content: AnyStr,
        xml_filename: str = "",
        parse_tables_and_math: bool = False,
        use_fast_mode: bool = False,
        keep_pdf: bool = False,
    ) -> list[dict]:
        """Run layout analysis, selecting the fast or full model path."""
        analyze = (
            self.pdf_analysis_service.analyze_pdf_layout_fast
            if use_fast_mode
            else self.pdf_analysis_service.analyze_pdf_layout
        )
        return analyze(pdf_content, xml_filename, parse_tables_and_math, keep_pdf)

    def execute_and_save_xml(self, pdf_content: AnyStr, xml_filename: str, use_fast_mode: bool = False) -> list[dict]:
        """Analyze with tables/math parsing disabled, forwarding xml_filename."""
        return self.execute(pdf_content, xml_filename, False, use_fast_mode, keep_pdf=False)
33 |
--------------------------------------------------------------------------------
/src/use_cases/text_extraction/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/use_cases/text_extraction/__init__.py
--------------------------------------------------------------------------------
/src/use_cases/text_extraction/extract_text_use_case.py:
--------------------------------------------------------------------------------
1 | from fastapi import UploadFile
2 | from pdf_token_type_labels import TokenType
3 | from ports.services.pdf_analysis_service import PDFAnalysisService
4 | from ports.services.text_extraction_service import TextExtractionService
5 |
6 |
class ExtractTextUseCase:
    """Extracts text from a PDF, optionally filtered by segment token types."""

    def __init__(self, pdf_analysis_service: PDFAnalysisService, text_extraction_service: TextExtractionService):
        self.pdf_analysis_service = pdf_analysis_service
        self.text_extraction_service = text_extraction_service

    def execute(self, file: UploadFile, use_fast_mode: bool = False, types: str = "all") -> dict:
        """Analyze the uploaded PDF and return text for the requested token types.

        `types` is either "all" or a comma-separated list of type names
        (spaces inside a name are normalized to underscores; duplicates are
        collapsed).
        """
        pdf_bytes = file.file.read()

        if types == "all":
            token_types: list[TokenType] = list(TokenType)
        else:
            requested = {TokenType.from_text(name.strip().replace(" ", "_")) for name in types.split(",")}
            token_types = list(requested)

        if use_fast_mode:
            segment_boxes = self.pdf_analysis_service.analyze_pdf_layout_fast(pdf_bytes)
        else:
            segment_boxes = self.pdf_analysis_service.analyze_pdf_layout(pdf_bytes, "")

        return self.text_extraction_service.extract_text_by_types(segment_boxes, token_types)
26 |
--------------------------------------------------------------------------------
/src/use_cases/toc_extraction/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/use_cases/toc_extraction/__init__.py
--------------------------------------------------------------------------------
/src/use_cases/toc_extraction/extract_toc_use_case.py:
--------------------------------------------------------------------------------
1 | from fastapi import UploadFile
2 | from ports.services.pdf_analysis_service import PDFAnalysisService
3 | from ports.services.toc_service import TOCService
4 |
5 |
class ExtractTOCUseCase:
    """Builds a table of contents from a PDF's analyzed layout."""

    def __init__(self, pdf_analysis_service: PDFAnalysisService, toc_service: TOCService):
        self.pdf_analysis_service = pdf_analysis_service
        self.toc_service = toc_service

    def execute(self, file: UploadFile, use_fast_mode: bool = False) -> list[dict]:
        """Analyze the uploaded PDF and extract its table of contents."""
        pdf_bytes = file.file.read()

        if use_fast_mode:
            segment_boxes = self.pdf_analysis_service.analyze_pdf_layout_fast(pdf_bytes)
        else:
            segment_boxes = self.pdf_analysis_service.analyze_pdf_layout(pdf_bytes, "")

        return self.toc_service.extract_table_of_contents(pdf_bytes, segment_boxes)

    def execute_uwazi_compatible(self, file: UploadFile) -> list[dict]:
        """Extract the TOC (always fast mode) and reshape it for Uwazi."""
        return self.toc_service.format_toc_for_uwazi(self.execute(file, use_fast_mode=True))
24 |
--------------------------------------------------------------------------------
/src/use_cases/visualization/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/src/use_cases/visualization/__init__.py
--------------------------------------------------------------------------------
/src/use_cases/visualization/create_visualization_use_case.py:
--------------------------------------------------------------------------------
1 | from fastapi import UploadFile
2 | from starlette.responses import FileResponse
3 | from ports.services.pdf_analysis_service import PDFAnalysisService
4 | from ports.services.visualization_service import VisualizationService
5 | from glob import glob
6 | from os.path import getctime, join
7 | from tempfile import gettempdir
8 | from pathlib import Path
9 |
10 |
class CreateVisualizationUseCase:
    """Renders an uploaded PDF with its detected layout segments drawn on top."""

    def __init__(self, pdf_analysis_service: PDFAnalysisService, visualization_service: VisualizationService):
        self.pdf_analysis_service = pdf_analysis_service
        self.visualization_service = visualization_service

    def execute(self, file: UploadFile, use_fast_mode: bool = False) -> FileResponse:
        """Analyze the uploaded PDF and return a visualization file response.

        :param file: the uploaded PDF
        :param use_fast_mode: choose the fast analysis model over the full one
        :raises ValueError: if no *.pdf exists in the temp directory afterwards
        """
        file_content = file.file.read()

        # keep_pdf=True so the analysis service leaves the temp PDF on disk for
        # us to draw on. Pass an explicit False for parse_tables_and_math: the
        # previous code passed "" positionally, silently relying on the empty
        # string being falsy where a bool is expected.
        if use_fast_mode:
            segment_boxes = self.pdf_analysis_service.analyze_pdf_layout_fast(
                file_content, xml_filename="", parse_tables_and_math=False, keep_pdf=True
            )
        else:
            segment_boxes = self.pdf_analysis_service.analyze_pdf_layout(
                file_content, xml_filename="", parse_tables_and_math=False, keep_pdf=True
            )

        # NOTE(review): locates "our" PDF by taking the newest *.pdf in the temp
        # directory; this is racy if several analyses run concurrently — confirm
        # this is acceptable for the deployment model.
        pdf_path = Path(max(glob(join(gettempdir(), "*.pdf")), key=getctime))
        visualization_path = self.visualization_service.create_pdf_visualization(pdf_path, segment_boxes)

        return self.visualization_service.get_visualization_response(visualization_path)
28 |
--------------------------------------------------------------------------------
/start.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Launch the FastAPI application (src/app.py, object `app`) under gunicorn
# with uvicorn workers, listening on all interfaces at port 5060.
# --timeout 10000: very long worker timeout — presumably to accommodate slow
# layout analysis of large PDFs; confirm before lowering.
gunicorn -k uvicorn.workers.UvicornWorker --chdir ./src app:app --bind 0.0.0.0:5060 --timeout 10000
--------------------------------------------------------------------------------
/test_pdfs/blank.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/blank.pdf
--------------------------------------------------------------------------------
/test_pdfs/chinese.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/chinese.pdf
--------------------------------------------------------------------------------
/test_pdfs/error.pdf:
--------------------------------------------------------------------------------
1 | error
--------------------------------------------------------------------------------
/test_pdfs/formula.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/formula.pdf
--------------------------------------------------------------------------------
/test_pdfs/image.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/image.pdf
--------------------------------------------------------------------------------
/test_pdfs/korean.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/korean.pdf
--------------------------------------------------------------------------------
/test_pdfs/not_a_pdf.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/not_a_pdf.pdf
--------------------------------------------------------------------------------
/test_pdfs/ocr-sample-already-ocred.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/ocr-sample-already-ocred.pdf
--------------------------------------------------------------------------------
/test_pdfs/ocr-sample-english.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/ocr-sample-english.pdf
--------------------------------------------------------------------------------
/test_pdfs/ocr-sample-french.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/ocr-sample-french.pdf
--------------------------------------------------------------------------------
/test_pdfs/ocr_pdf.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/ocr_pdf.pdf
--------------------------------------------------------------------------------
/test_pdfs/regular.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/regular.pdf
--------------------------------------------------------------------------------
/test_pdfs/some_empty_pages.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/some_empty_pages.pdf
--------------------------------------------------------------------------------
/test_pdfs/table.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/table.pdf
--------------------------------------------------------------------------------
/test_pdfs/test.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/test.pdf
--------------------------------------------------------------------------------
/test_pdfs/toc-test.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huridocs/pdf-document-layout-analysis/2caad114f12386af06603b6c3af2018c5a59702b/test_pdfs/toc-test.pdf
--------------------------------------------------------------------------------