├── .github └── workflows │ ├── benchmarks.yml │ ├── ci.yml │ ├── cla.yml │ ├── publish.yml │ └── scripts.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CITATION.cff ├── CLA.md ├── LICENSE ├── README.md ├── benchmark ├── detection.py ├── layout.py ├── ordering.py ├── recognition.py ├── table_recognition.py ├── texify.py └── utils │ ├── __init__.py │ ├── bbox.py │ ├── metrics.py │ ├── scoring.py │ ├── tatr.py │ ├── tesseract.py │ ├── textract.py │ └── verify_benchmark_scores.py ├── detect_layout.py ├── detect_text.py ├── ocr_app.py ├── ocr_latex.py ├── ocr_text.py ├── poetry.lock ├── pyproject.toml ├── pytest.ini ├── signatures └── version1 │ └── cla.json ├── static ├── fonts │ └── .gitignore └── images │ ├── arabic.jpg │ ├── arabic_layout.jpg │ ├── arabic_reading.jpg │ ├── arabic_text.jpg │ ├── benchmark_chart.png │ ├── benchmark_chart_small.png │ ├── benchmark_layout_chart.png │ ├── benchmark_rec_chart.png │ ├── benchmark_tablerec_acc.png │ ├── benchmark_tablerec_speed.png │ ├── chi_hind.jpg │ ├── chi_hind_layout.jpg │ ├── chi_hind_orig.jpg │ ├── chi_hind_reading.jpg │ ├── chi_hind_text.jpg │ ├── chinese.jpg │ ├── chinese_layout.jpg │ ├── chinese_reading.jpg │ ├── chinese_text.jpg │ ├── excerpt.png │ ├── excerpt_layout.png │ ├── excerpt_reading.jpg │ ├── excerpt_text.png │ ├── funsd.png │ ├── funsd_layout.jpg │ ├── funsd_reading.jpg │ ├── funsd_text.jpg │ ├── gcloud_full_langs.png │ ├── gcloud_rec_bench.png │ ├── hindi.jpg │ ├── hindi_layout.jpg │ ├── hindi_reading.jpg │ ├── hindi_text.jpg │ ├── japanese.jpg │ ├── japanese_layout.jpg │ ├── japanese_reading.jpg │ ├── japanese_tablerec.png │ ├── japanese_text.jpg │ ├── latex_ocr.png │ ├── nyt.jpg │ ├── nyt_layout.jpg │ ├── nyt_order.jpg │ ├── nyt_text.jpg │ ├── paper.jpg │ ├── paper_layout.jpg │ ├── paper_reading.jpg │ ├── paper_tablerec.png │ ├── paper_text.jpg │ ├── pres.png │ ├── pres_layout.jpg │ ├── pres_reading.jpg │ ├── pres_tablerec.png │ ├── pres_text.jpg │ ├── rec_acc_table.png │ ├── scanned.png │ ├── scanned_layout.jpg │ ├── scanned_reading.jpg │ ├── scanned_tablerec.png │ ├── scanned_tablerec2.png │ ├── scanned_text.jpg │ ├── surya_rec_perf.png │ ├── table_rec.png │ ├── textbook.jpg │ ├── textbook_layout.jpg │ ├── textbook_order.jpg │ └── textbook_text.jpg ├── surya ├── __init__.py ├── common │ ├── __init__.py │ ├── adetr │ │ └── decoder.py │ ├── donut │ │ ├── encoder.py │ │ └── processor.py │ ├── load.py │ ├── polygon.py │ ├── predictor.py │ ├── s3.py │ ├── surya │ │ ├── __init__.py │ │ ├── config.py │ │ ├── decoder │ │ │ ├── __init__.py │ │ │ └── config.py │ │ ├── embedder │ │ │ └── __init__.py │ │ ├── encoder │ │ │ ├── __init__.py │ │ │ └── config.py │ │ ├── flash_attn_utils.py │ │ ├── processor │ │ │ ├── __init__.py │ │ │ ├── schema.py │ │ │ └── tokenizer.py │ │ └── schema.py │ └── util.py ├── debug │ ├── draw.py │ ├── fonts.py │ ├── katex.js │ ├── render_html.py │ └── text.py ├── detection │ ├── __init__.py │ ├── affinity.py │ ├── heatmap.py │ ├── loader.py │ ├── model │ │ ├── __init__.py │ │ ├── config.py │ │ └── encoderdecoder.py │ ├── parallel.py │ ├── processor.py │ ├── schema.py │ └── util.py ├── input │ ├── load.py │ └── processing.py ├── layout │ ├── __init__.py │ ├── loader.py │ ├── model │ │ ├── __init__.py │ │ ├── config.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ └── encoderdecoder.py │ ├── schema.py │ ├── slicer.py │ └── util.py ├── logging.py ├── models.py ├── ocr_error │ ├── __init__.py │ ├── loader.py │ ├── model │ │ ├── __init__.py │ │ ├── config.py │ │ └── encoder.py │ ├── schema.py │ └── tokenizer.py 
├── recognition │ ├── __init__.py │ ├── cache.py │ ├── languages.py │ ├── loader.py │ ├── postprocessing.py │ ├── schema.py │ └── util.py ├── scripts │ ├── __init__.py │ ├── config.py │ ├── detect_layout.py │ ├── detect_text.py │ ├── hf_to_s3.py │ ├── ocr_latex.py │ ├── ocr_text.py │ ├── run_streamlit_app.py │ ├── run_texify_app.py │ ├── streamlit_app.py │ ├── table_recognition.py │ └── texify_app.py ├── settings.py └── table_rec │ ├── __init__.py │ ├── loader.py │ ├── model │ ├── __init__.py │ ├── config.py │ ├── decoder.py │ ├── encoder.py │ └── encoderdecoder.py │ ├── processor.py │ ├── schema.py │ └── shaper.py ├── table_recognition.py ├── tests ├── conftest.py ├── test_detection.py ├── test_latex_ocr.py ├── test_layout.py ├── test_ocr_errors.py ├── test_recognition.py └── test_table_rec.py └── texify_app.py /.github/workflows/benchmarks.yml: -------------------------------------------------------------------------------- 1 | name: Integration test 2 | 3 | on: [push] 4 | 5 | env: 6 | PYTHONIOENCODING: "utf-8" 7 | 8 | jobs: 9 | build: 10 | runs-on: ${{ matrix.os }} 11 | strategy: 12 | matrix: 13 | os: [ubuntu-latest, windows-latest] 14 | steps: 15 | - uses: actions/checkout@v3 16 | - name: Set up Python 3.11 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: 3.11 20 | - name: Install python dependencies 21 | run: | 22 | pip install poetry 23 | poetry install 24 | - name: Run detection benchmark test 25 | run: | 26 | poetry run python benchmark/detection.py --max_rows 2 27 | poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/det_bench/results.json --bench_type detection 28 | - name: Run recognition benchmark test 29 | run: | 30 | poetry run python benchmark/recognition.py --max_rows 2 31 | poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/rec_bench/results.json --bench_type recognition 32 | - name: Run layout benchmark test 33 | run: | 34 | poetry run python benchmark/layout.py --max_rows 5 35 | poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/layout_bench/results.json --bench_type layout 36 | - name: Run ordering benchmark 37 | run: | 38 | poetry run python benchmark/ordering.py --max_rows 5 39 | poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/order_bench/results.json --bench_type ordering 40 | - name: Run table recognition benchmark 41 | run: | 42 | poetry run python benchmark/table_recognition.py --max_rows 5 43 | poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/table_rec_bench/results.json --bench_type table_recognition 44 | - name: Run texify benchmark 45 | run: | 46 | poetry run python benchmark/texify.py --max_rows 5 47 | poetry run python benchmark/utils/verify_benchmark_scores.py results/benchmark/texify_bench/results.json --bench_type texify -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Unit tests 2 | 3 | on: [push] 4 | 5 | env: 6 | TORCH_DEVICE: "cpu" 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - name: Set up Python 3.11 14 | uses: actions/setup-python@v4 15 | with: 16 | python-version: 3.11 17 | - name: Install python dependencies 18 | run: | 19 | pip install poetry 20 | poetry install 21 | - name: Run tests 22 | run: poetry run pytest 
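Note: the workflows above exercise Surya through its test suite and CLI entry points (surya_detect, surya_ocr, and so on). For reference, the same functionality is available directly from Python; a minimal sketch, assuming only the DetectionPredictor interface used in benchmark/detection.py further below (the image path here is a placeholder, not a file from this repository):

from PIL import Image
from surya.detection import DetectionPredictor

det_predictor = DetectionPredictor()  # loads the text detection model
image = Image.open("page.png").convert("RGB")  # placeholder input image
predictions = det_predictor([image])  # one result object per input image
for line in predictions[0].bboxes:
    print(line.bbox, line.polygon)  # axis-aligned bbox and polygon per detected text line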
-------------------------------------------------------------------------------- /.github/workflows/cla.yml: -------------------------------------------------------------------------------- 1 | name: "Surya CLA Assistant" 2 | on: 3 | issue_comment: 4 | types: [created] 5 | pull_request_target: 6 | types: [opened,closed,synchronize] 7 | 8 | # explicitly configure permissions, in case your GITHUB_TOKEN workflow permissions are set to read-only in repository settings 9 | permissions: 10 | actions: write 11 | contents: write 12 | pull-requests: write 13 | statuses: write 14 | 15 | jobs: 16 | CLAAssistant: 17 | runs-on: ubuntu-latest 18 | steps: 19 | - name: "Surya CLA Assistant" 20 | if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target' 21 | uses: contributor-assistant/github-action@v2.3.0 22 | env: 23 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 24 | # the below token should have repo scope and must be manually added by you in the repository's secret 25 | # This token is required only if you have configured to store the signatures in a remote repository/organization 26 | PERSONAL_ACCESS_TOKEN: ${{ secrets.PERSONAL_ACCESS_TOKEN }} 27 | with: 28 | path-to-signatures: 'signatures/version1/cla.json' 29 | path-to-document: 'https://github.com/VikParuchuri/surya/blob/master/CLA.md' 30 | # branch should not be protected 31 | branch: 'master' 32 | allowlist: VikParuchuri -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | on: 3 | push: 4 | tags: 5 | - "v*.*.*" 6 | jobs: 7 | build: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v3 11 | - name: Set up Python 3.11 12 | uses: actions/setup-python@v4 13 | with: 14 | python-version: 3.11 15 | - name: Install python dependencies 16 | run: | 17 | pip install poetry 18 | poetry install 19 | - name: Build package 20 | run: | 21 | poetry build 22 | - name: Publish package 23 | env: 24 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} 25 | run: | 26 | poetry config pypi-token.pypi "$PYPI_TOKEN" 27 | poetry publish 28 | -------------------------------------------------------------------------------- /.github/workflows/scripts.yml: -------------------------------------------------------------------------------- 1 | name: Test CLI scripts 2 | 3 | on: [push] 4 | 5 | env: 6 | TORCH_DEVICE: "cpu" 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - name: Set up Python 3.11 14 | uses: actions/setup-python@v4 15 | with: 16 | python-version: 3.11 17 | - name: Install python dependencies 18 | run: | 19 | pip install poetry 20 | poetry install 21 | - name: Download benchmark data 22 | run: | 23 | wget -O benchmark_data.zip "https://drive.google.com/uc?export=download&id=1NHrdYatR1rtqs2gPVfdvO0BAvocH8CJi" 24 | unzip -o benchmark_data.zip 25 | - name: Test detection 26 | run: poetry run surya_detect benchmark_data/pdfs/switch_trans.pdf --page_range 0 27 | - name: Test OCR 28 | env: 29 | RECOGNITION_MAX_TOKENS: 25 30 | run: poetry run surya_ocr benchmark_data/pdfs/switch_trans.pdf --page_range 0 31 | - name: Test layout 32 | run: poetry run surya_layout benchmark_data/pdfs/switch_trans.pdf --page_range 0 33 | - name: Test table 34 | run: poetry run surya_table benchmark_data/pdfs/switch_trans.pdf --page_range 0 35 | - name: 
Test texify 36 | env: 37 | TEXIFY_MAX_TOKENS: 25 38 | run: poetry run surya_latex_ocr benchmark_data/pdfs/switch_trans.pdf --page_range 0 39 | - name: Test detection folder 40 | run: poetry run surya_detect benchmark_data/pdfs --page_range 0 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | private.py 2 | .DS_Store 3 | local.env 4 | experiments 5 | test_data 6 | training 7 | wandb 8 | notebooks 9 | results 10 | data 11 | slices 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | cover/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | local_settings.py 73 | db.sqlite3 74 | db.sqlite3-journal 75 | 76 | # Flask stuff: 77 | instance/ 78 | .webassets-cache 79 | 80 | # Scrapy stuff: 81 | .scrapy 82 | 83 | # Sphinx documentation 84 | docs/_build/ 85 | 86 | # PyBuilder 87 | .pybuilder/ 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # IPython 94 | profile_default/ 95 | ipython_config.py 96 | 97 | # pyenv 98 | # For a library or package, you might want to ignore these files since the code is 99 | # intended to run in multiple environments; otherwise, check them in: 100 | # .python-version 101 | 102 | # pipenv 103 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 104 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 105 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 106 | # install all needed dependencies. 107 | #Pipfile.lock 108 | 109 | # poetry 110 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 111 | # This is especially recommended for binary packages to ensure reproducibility, and is more 112 | # commonly ignored for libraries. 113 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 114 | #poetry.lock 115 | 116 | # pdm 117 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 118 | #pdm.lock 119 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 120 | # in version control. 121 | # https://pdm.fming.dev/#use-with-ide 122 | .pdm.toml 123 | 124 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 125 | __pypackages__/ 126 | 127 | # Celery stuff 128 | celerybeat-schedule 129 | celerybeat.pid 130 | 131 | # SageMath parsed files 132 | *.sage.py 133 | 134 | # Environments 135 | .env 136 | .venv 137 | env/ 138 | venv/ 139 | ENV/ 140 | env.bak/ 141 | venv.bak/ 142 | 143 | # Spyder project settings 144 | .spyderproject 145 | .spyproject 146 | 147 | # Rope project settings 148 | .ropeproject 149 | 150 | # mkdocs documentation 151 | /site 152 | 153 | # mypy 154 | .mypy_cache/ 155 | .dmypy.json 156 | dmypy.json 157 | 158 | # Pyre type checker 159 | .pyre/ 160 | 161 | # pytype static type analyzer 162 | .pytype/ 163 | 164 | # Cython debug symbols 165 | cython_debug/ 166 | 167 | # PyCharm 168 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 169 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 170 | # and can be added to the global gitignore or merged into this file. For a more nuclear 171 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 172 | .idea/ 173 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.9.10 5 | hooks: 6 | # Run the linter. 7 | - id: ruff 8 | types_or: [ python, pyi ] 9 | args: [ --fix ] 10 | # Run the formatter. 11 | - id: ruff-format 12 | types_or: [ python, pyi ] -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it using the following metadata." 3 | title: "Surya: A lightweight framework for analyzing documents and PDFs at scale" 4 | authors: 5 | - family-names: Paruchuri 6 | given-names: Vikas 7 | - name: Datalab Team 8 | date-released: 2025-05-13 9 | url: https://github.com/VikParuchuri/surya 10 | version: 0.14.0 11 | repository-code: https://github.com/VikParuchuri/surya -------------------------------------------------------------------------------- /CLA.md: -------------------------------------------------------------------------------- 1 | Surya Contributor Agreement 2 | 3 | This Surya Contributor Agreement ("SCA") applies to any contribution that you make to any product or project managed by us (the "project"), and sets out the intellectual property rights you grant to us in the contributed materials. The term "us" shall mean Endless Labs, Inc. The term "you" shall mean the person or entity identified below. 4 | 5 | If you agree to be bound by these terms, sign by writing "I have read the CLA document and I hereby sign the CLA" in response to the CLA bot Github comment. Read this agreement carefully before signing. These terms and conditions constitute a binding legal agreement. 6 | 7 | 1. The term 'contribution' or 'contributed materials' means any source code, object code, patch, tool, sample, graphic, specification, manual, documentation, or any other material posted or submitted by you to the project. 8 | 2. 
With respect to any worldwide copyrights, or copyright applications and registrations, in your contribution: 9 | - you hereby assign to us joint ownership, and to the extent that such assignment is or becomes invalid, ineffective or unenforceable, you hereby grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge, royalty free, unrestricted license to exercise all rights under those copyrights. This includes, at our option, the right to sublicense these same rights to third parties through multiple levels of sublicensees or other licensing arrangements, including dual-license structures for commercial customers; 10 | - you agree that each of us can do all things in relation to your contribution as if each of us were the sole owners, and if one of us makes a derivative work of your contribution, the one who makes the derivative work (or has it made) will be the sole owner of that derivative work; 11 | - you agree that you will not assert any moral rights in your contribution against us, our licensees or transferees; 12 | - you agree that we may register a copyright in your contribution and exercise all ownership rights associated with it; and 13 | - you agree that neither of us has any duty to consult with, obtain the consent of, pay or render an accounting to the other for any use or distribution of your contribution. 14 | 3. With respect to any patents you own, or that you can license without payment to any third party, you hereby grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge, royalty-free license to: 15 | - make, have made, use, sell, offer to sell, import, and otherwise transfer your contribution in whole or in part, alone or in combination with or included in any product, work or materials arising out of the project to which your contribution was submitted, and 16 | - at our option, to sublicense these same rights to third parties through multiple levels of sublicensees or other licensing arrangements. 17 | If you or your affiliates institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the contribution or any project it was submitted to constitutes direct or contributory patent infringement, then any patent licenses granted to you under this agreement for that contribution shall terminate as of the date such litigation is filed. 18 | 4. Except as set out above, you keep all right, title, and interest in your contribution. The rights that you grant to us under these terms are effective on the date you first submitted a contribution to us, even if your submission took place before the date you sign these terms. Any contribution we make available under any license will also be made available under a suitable FSF (Free Software Foundation) or OSI (Open Source Initiative) approved license. 19 | 5. You covenant, represent, warrant and agree that: 20 | - each contribution that you submit is and shall be an original work of authorship and you can legally grant the rights set out in this SCA; 21 | - to the best of your knowledge, each contribution will not violate any third party's copyrights, trademarks, patents, or other intellectual property rights; and 22 | - each contribution shall be in compliance with U.S. export control laws and other applicable export and import laws. 23 | You agree to notify us if you become aware of any circumstance which would make any of the foregoing representations inaccurate in any respect. Endless Labs, Inc.
may publicly disclose your participation in the project, including the fact that you have signed the SCA. 24 | 6. This SCA is governed by the laws of the State of California and applicable U.S. Federal law. Any choice of law rules will not apply. -------------------------------------------------------------------------------- /benchmark/detection.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import copy 4 | import json 5 | 6 | import click 7 | 8 | from benchmark.utils.bbox import get_pdf_lines 9 | from benchmark.utils.metrics import precision_recall 10 | from benchmark.utils.tesseract import tesseract_parallel 11 | from surya.input.processing import open_pdf, get_page_images, convert_if_not_rgb 12 | from surya.debug.draw import draw_polys_on_image 13 | from surya.common.util import rescale_bbox 14 | from surya.settings import settings 15 | from surya.detection import DetectionPredictor 16 | 17 | import os 18 | import time 19 | from tabulate import tabulate 20 | import datasets 21 | 22 | 23 | @click.command(help="Benchmark detection model.") 24 | @click.option("--pdf_path", type=str, help="Path to PDF to detect bboxes in.", default=None) 25 | @click.option("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark")) 26 | @click.option("--max_rows", type=int, help="Maximum number of pdf pages to OCR.", default=100) 27 | @click.option("--debug", is_flag=True, help="Enable debug mode.", default=False) 28 | @click.option("--tesseract", is_flag=True, help="Run tesseract as well.", default=False) 29 | def main(pdf_path: str, results_dir: str, max_rows: int, debug: bool, tesseract: bool): 30 | det_predictor = DetectionPredictor() 31 | 32 | if pdf_path is not None: 33 | pathname = pdf_path 34 | doc = open_pdf(pdf_path) 35 | page_count = len(doc) 36 | page_indices = list(range(page_count)) 37 | page_indices = page_indices[:max_rows] 38 | 39 | images = get_page_images(doc, page_indices) 40 | doc.close() 41 | 42 | image_sizes = [img.size for img in images] 43 | correct_boxes = get_pdf_lines(pdf_path, image_sizes) 44 | else: 45 | pathname = "det_bench" 46 | # These have already been shuffled randomly, so sampling from the start is fine 47 | dataset = datasets.load_dataset(settings.DETECTOR_BENCH_DATASET_NAME, split=f"train[:{max_rows}]") 48 | images = list(dataset["image"]) 49 | images = convert_if_not_rgb(images) 50 | correct_boxes = [] 51 | for i, boxes in enumerate(dataset["bboxes"]): 52 | img_size = images[i].size 53 | # 1000,1000 is bbox size for doclaynet 54 | correct_boxes.append([rescale_bbox(b, (1000, 1000), img_size) for b in boxes]) 55 | 56 | if settings.DETECTOR_STATIC_CACHE: 57 | # Run through one batch to compile the model 58 | det_predictor(images[:1]) 59 | 60 | start = time.time() 61 | predictions = det_predictor(images) 62 | surya_time = time.time() - start 63 | 64 | if tesseract: 65 | start = time.time() 66 | tess_predictions = tesseract_parallel(images) 67 | tess_time = time.time() - start 68 | else: 69 | tess_predictions = [None] * len(images) 70 | tess_time = None 71 | 72 | folder_name = os.path.basename(pathname).split(".")[0] 73 | result_path = os.path.join(results_dir, folder_name) 74 | os.makedirs(result_path, exist_ok=True) 75 | 76 | page_metrics = collections.OrderedDict() 77 | for idx, (tb, sb, cb) in enumerate(zip(tess_predictions, predictions, correct_boxes)): 78 | surya_boxes = [s.bbox for s in sb.bboxes] 79 | surya_polys 
= [s.polygon for s in sb.bboxes] 80 | 81 | surya_metrics = precision_recall(surya_boxes, cb) 82 | if tb is not None: 83 | tess_metrics = precision_recall(tb, cb) 84 | else: 85 | tess_metrics = None 86 | 87 | page_metrics[idx] = { 88 | "surya": surya_metrics, 89 | "tesseract": tess_metrics 90 | } 91 | 92 | if debug: 93 | bbox_image = draw_polys_on_image(surya_polys, copy.deepcopy(images[idx])) 94 | bbox_image.save(os.path.join(result_path, f"{idx}_bbox.png")) 95 | 96 | mean_metrics = {} 97 | metric_types = sorted(page_metrics[0]["surya"].keys()) 98 | models = ["surya"] 99 | if tesseract: 100 | models.append("tesseract") 101 | 102 | for k in models: 103 | for m in metric_types: 104 | metric = [] 105 | for page in page_metrics: 106 | metric.append(page_metrics[page][k][m]) 107 | if k not in mean_metrics: 108 | mean_metrics[k] = {} 109 | mean_metrics[k][m] = sum(metric) / len(metric) 110 | 111 | out_data = { 112 | "times": { 113 | "surya": surya_time, 114 | "tesseract": tess_time 115 | }, 116 | "metrics": mean_metrics, 117 | "page_metrics": page_metrics 118 | } 119 | 120 | with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: 121 | json.dump(out_data, f, indent=4) 122 | 123 | table_headers = ["Model", "Time (s)", "Time per page (s)"] + metric_types 124 | table_data = [ 125 | ["surya", surya_time, surya_time / len(images)] + [mean_metrics["surya"][m] for m in metric_types], 126 | ] 127 | if tesseract: 128 | table_data.append( 129 | ["tesseract", tess_time, tess_time / len(images)] + [mean_metrics["tesseract"][m] for m in metric_types] 130 | ) 131 | 132 | print(tabulate(table_data, headers=table_headers, tablefmt="github")) 133 | print("Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold. 
There is a precision penalty for multiple boxes overlapping reference lines.") 134 | print(f"Wrote results to {result_path}") 135 | 136 | 137 | if __name__ == "__main__": 138 | main() 139 | -------------------------------------------------------------------------------- /benchmark/layout.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import copy 3 | import json 4 | 5 | import click 6 | 7 | from benchmark.utils.metrics import precision_recall 8 | from surya.layout import LayoutPredictor 9 | from surya.input.processing import convert_if_not_rgb 10 | from surya.debug.draw import draw_bboxes_on_image 11 | from surya.settings import settings 12 | import os 13 | import time 14 | from tabulate import tabulate 15 | import datasets 16 | 17 | 18 | @click.command(help="Benchmark surya layout model.") 19 | @click.option("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark")) 20 | @click.option("--max_rows", type=int, help="Maximum number of images to run benchmark on.", default=100) 21 | @click.option("--debug", is_flag=True, help="Run in debug mode.", default=False) 22 | def main(results_dir: str, max_rows: int, debug: bool): 23 | layout_predictor = LayoutPredictor() 24 | 25 | pathname = "layout_bench" 26 | # These have already been shuffled randomly, so sampling from the start is fine 27 | dataset = datasets.load_dataset(settings.LAYOUT_BENCH_DATASET_NAME, split=f"train[:{max_rows}]") 28 | images = list(dataset["image"]) 29 | images = convert_if_not_rgb(images) 30 | 31 | if settings.LAYOUT_STATIC_CACHE: 32 | layout_predictor(images[:1]) 33 | 34 | start = time.time() 35 | layout_predictions = layout_predictor(images) 36 | surya_time = time.time() - start 37 | 38 | folder_name = os.path.basename(pathname).split(".")[0] 39 | result_path = os.path.join(results_dir, folder_name) 40 | os.makedirs(result_path, exist_ok=True) 41 | 42 | label_alignment = { # First is publaynet, second is surya 43 | "Image": [["Figure"], ["Picture", "Figure"]], 44 | "Table": [["Table"], ["Table", "Form", "TableOfContents"]], 45 | "Text": [["Text"], ["Text", "Formula", "Footnote", "Caption", "TextInlineMath", "Code", "Handwriting"]], 46 | "List": [["List"], ["ListItem"]], 47 | "Title": [["Title"], ["SectionHeader", "Title"]] 48 | } 49 | 50 | page_metrics = collections.OrderedDict() 51 | for idx, pred in enumerate(layout_predictions): 52 | row = dataset[idx] 53 | all_correct_bboxes = [] 54 | page_results = {} 55 | for label_name in label_alignment: 56 | correct_cats, surya_cats = label_alignment[label_name] 57 | correct_bboxes = [b for b, l in zip(row["bboxes"], row["labels"]) if l in correct_cats] 58 | all_correct_bboxes.extend(correct_bboxes) 59 | pred_bboxes = [b.bbox for b in pred.bboxes if b.label in surya_cats] 60 | 61 | metrics = precision_recall(pred_bboxes, correct_bboxes, penalize_double=False) 62 | weight = len(correct_bboxes) 63 | metrics["weight"] = weight 64 | page_results[label_name] = metrics 65 | 66 | page_metrics[idx] = page_results 67 | 68 | if debug: 69 | bbox_image = draw_bboxes_on_image(all_correct_bboxes, copy.deepcopy(images[idx])) 70 | bbox_image.save(os.path.join(result_path, f"{idx}_layout.png")) 71 | 72 | mean_metrics = collections.defaultdict(dict) 73 | layout_types = sorted(page_metrics[0].keys()) 74 | metric_types = sorted(page_metrics[0][layout_types[0]].keys()) 75 | metric_types.remove("weight") 76 | for l in layout_types: 77 | for m in metric_types: 78 | metric 
= [] 79 | total = 0 80 | for page in page_metrics: 81 | metric.append(page_metrics[page][l][m] * page_metrics[page][l]["weight"]) 82 | total += page_metrics[page][l]["weight"] 83 | 84 | value = sum(metric) 85 | if value > 0: 86 | value /= total 87 | mean_metrics[l][m] = value 88 | 89 | out_data = { 90 | "time": surya_time, 91 | "metrics": mean_metrics, 92 | "page_metrics": page_metrics 93 | } 94 | 95 | with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: 96 | json.dump(out_data, f, indent=4) 97 | 98 | table_headers = ["Layout Type", ] + metric_types 99 | table_data = [] 100 | for layout_type in layout_types: 101 | table_data.append([layout_type, ] + [f"{mean_metrics[layout_type][m]:.5f}" for m in metric_types]) 102 | 103 | print(tabulate(table_data, headers=table_headers, tablefmt="github")) 104 | print(f"Took {surya_time / len(images):.5f} seconds per image, and {surya_time:.5f} seconds total.") 105 | print("Precision and recall are over the mutual coverage of the detected boxes and the ground truth boxes at a .5 threshold.") 106 | print(f"Wrote results to {result_path}") 107 | 108 | 109 | if __name__ == "__main__": 110 | main() 111 | -------------------------------------------------------------------------------- /benchmark/ordering.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import json 3 | 4 | import click 5 | 6 | from surya.input.processing import convert_if_not_rgb 7 | from surya.layout import LayoutPredictor 8 | from surya.common.polygon import PolygonBox 9 | from surya.settings import settings 10 | from benchmark.utils.metrics import rank_accuracy 11 | import os 12 | import time 13 | import datasets 14 | 15 | 16 | @click.command(help="Benchmark surya layout for reading order.") 17 | @click.option("--results_dir", type=str, help="Path to JSON file with benchmark results.", default=os.path.join(settings.RESULT_DIR, "benchmark")) 18 | @click.option("--max_rows", type=int, help="Maximum number of images to run benchmark on.", default=None) 19 | def main(results_dir: str, max_rows: int): 20 | layout_predictor = LayoutPredictor() 21 | pathname = "order_bench" 22 | # These have already been shuffled randomly, so sampling from the start is fine 23 | split = "train" 24 | if max_rows is not None: 25 | split = f"train[:{max_rows}]" 26 | dataset = datasets.load_dataset(settings.ORDER_BENCH_DATASET_NAME, split=split) 27 | images = list(dataset["image"]) 28 | images = convert_if_not_rgb(images) 29 | 30 | start = time.time() 31 | layout_predictions = layout_predictor(images) 32 | surya_time = time.time() - start 33 | 34 | folder_name = os.path.basename(pathname).split(".")[0] 35 | result_path = os.path.join(results_dir, folder_name) 36 | os.makedirs(result_path, exist_ok=True) 37 | 38 | page_metrics = collections.OrderedDict() 39 | mean_accuracy = 0 40 | for idx, order_pred in enumerate(layout_predictions): 41 | row = dataset[idx] 42 | labels = row["labels"] 43 | bboxes = row["bboxes"] 44 | pred_positions = [] 45 | for label, bbox in zip(labels, bboxes): 46 | max_intersection = 0 47 | matching_idx = 0 48 | for pred_box in order_pred.bboxes: 49 | intersection = pred_box.intersection_pct(PolygonBox(polygon=bbox)) 50 | if intersection > max_intersection: 51 | max_intersection = intersection 52 | matching_idx = pred_box.position 53 | pred_positions.append(matching_idx) 54 | accuracy = rank_accuracy(pred_positions, labels) 55 | mean_accuracy += accuracy 56 | page_results = { 57 | "accuracy": accuracy, 58 
| "box_count": len(labels) 59 | } 60 | 61 | page_metrics[idx] = page_results 62 | 63 | mean_accuracy /= len(layout_predictions) 64 | 65 | out_data = { 66 | "time": surya_time, 67 | "mean_accuracy": mean_accuracy, 68 | "page_metrics": page_metrics 69 | } 70 | 71 | with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: 72 | json.dump(out_data, f, indent=4) 73 | 74 | print(f"Mean accuracy is {mean_accuracy:.2f}.") 75 | print(f"Took {surya_time / len(images):.2f} seconds per image, and {surya_time:.1f} seconds total.") 76 | print("Mean accuracy is the % of correct ranking pairs.") 77 | print(f"Wrote results to {result_path}") 78 | 79 | 80 | if __name__ == "__main__": 81 | main() 82 | -------------------------------------------------------------------------------- /benchmark/texify.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import re 3 | import time 4 | from pathlib import Path 5 | from typing import List 6 | 7 | import click 8 | import datasets 9 | from tabulate import tabulate 10 | from bs4 import BeautifulSoup 11 | 12 | from surya.common.surya.schema import TaskNames 13 | from surya.settings import settings 14 | from surya.recognition import RecognitionPredictor, OCRResult 15 | import json 16 | from rapidfuzz.distance import Levenshtein 17 | 18 | 19 | def normalize_text(text): 20 | soup = BeautifulSoup(text, "html.parser") 21 | # Unwrap math tags 22 | for tag in soup.find_all(): 23 | if tag.name == "math": 24 | tag.unwrap() 25 | text = soup.get_text() 26 | text = re.sub(r"\n", " ", text) 27 | text = re.sub(r"\s+", " ", text) 28 | return text.strip() 29 | 30 | 31 | def score_text(predictions, references): 32 | lev_dist = [] 33 | for p, r in zip(predictions, references): 34 | p = normalize_text(p) 35 | r = normalize_text(r) 36 | lev_dist.append(Levenshtein.normalized_distance(p, r)) 37 | 38 | return sum(lev_dist) / len(lev_dist) 39 | 40 | 41 | def inference_texify(source_data, predictor: RecognitionPredictor): 42 | images = [sd["image"] for sd in source_data] 43 | tasks = [TaskNames.block_without_boxes] * len(images) 44 | bboxes = [[[0, 0, image.width, image.height]] for image in images] 45 | texify_predictions: List[OCRResult] = predictor(images, tasks, bboxes=bboxes) 46 | out_data = [ 47 | { 48 | "text": texify_predictions[i].text_lines[0].text, 49 | "equation": source_data[i]["equation"], 50 | } 51 | for i in range(len(texify_predictions)) 52 | ] 53 | 54 | return out_data 55 | 56 | 57 | @click.command(help="Benchmark the performance of texify.") 58 | @click.option( 59 | "--ds_name", 60 | type=str, 61 | help="Path to dataset file with source images/equations.", 62 | default=settings.TEXIFY_BENCHMARK_DATASET, 63 | ) 64 | @click.option( 65 | "--results_dir", 66 | type=str, 67 | help="Path to JSON file with benchmark results.", 68 | default=os.path.join(settings.RESULT_DIR, "benchmark"), 69 | ) 70 | @click.option( 71 | "--max_rows", type=int, help="Maximum number of images to benchmark.", default=None 72 | ) 73 | def main(ds_name: str, results_dir: str, max_rows: int): 74 | predictor = RecognitionPredictor() 75 | ds = datasets.load_dataset(ds_name, split="train") 76 | 77 | if max_rows: 78 | ds = ds.filter(lambda x, idx: idx < max_rows, with_indices=True) 79 | 80 | start = time.time() 81 | predictions = inference_texify(ds, predictor) 82 | time_taken = time.time() - start 83 | 84 | text = [p["text"] for p in predictions] 85 | references = [p["equation"] for p in predictions] 86 | scores = 
score_text(text, references) 87 | 88 | write_data = { 89 | "scores": scores, 90 | "text": [{"prediction": p, "reference": r} for p, r in zip(text, references)], 91 | } 92 | 93 | score_table = [["texify", write_data["scores"], time_taken]] 94 | score_headers = ["edit", "time taken (s)"] 95 | score_dirs = ["⬇", "⬇"] 96 | 97 | score_headers = [f"{h} {d}" for h, d in zip(score_headers, score_dirs)] 98 | table = tabulate(score_table, headers=["Method", *score_headers]) 99 | print() 100 | print(table) 101 | 102 | result_path = Path(results_dir) / "texify_bench" 103 | result_path.mkdir(parents=True, exist_ok=True) 104 | with open(result_path / "results.json", "w", encoding="utf-8") as f: 105 | json.dump(write_data, f, indent=4) 106 | 107 | 108 | if __name__ == "__main__": 109 | main() 110 | -------------------------------------------------------------------------------- /benchmark/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/benchmark/utils/__init__.py -------------------------------------------------------------------------------- /benchmark/utils/bbox.py: -------------------------------------------------------------------------------- 1 | import fitz as pymupdf 2 | from surya.common.util import rescale_bbox 3 | 4 | 5 | def get_pdf_lines(pdf_path, img_sizes): 6 | doc = pymupdf.open(pdf_path) 7 | page_lines = [] 8 | for idx, img_size in enumerate(img_sizes): 9 | page = doc[idx] 10 | blocks = page.get_text("dict", sort=True, flags=pymupdf.TEXTFLAGS_DICT & ~pymupdf.TEXT_PRESERVE_LIGATURES & ~pymupdf.TEXT_PRESERVE_IMAGES)["blocks"] 11 | 12 | line_boxes = [] 13 | for block_idx, block in enumerate(blocks): 14 | for l in block["lines"]: 15 | line_boxes.append(list(l["bbox"])) 16 | 17 | page_box = page.bound() 18 | pwidth, pheight = page_box[2] - page_box[0], page_box[3] - page_box[1] 19 | line_boxes = [rescale_bbox(bbox, (pwidth, pheight), img_size) for bbox in line_boxes] 20 | page_lines.append(line_boxes) 21 | 22 | return page_lines 23 | 24 | def merge_boxes(box1, box2): 25 | return (min(box1[0], box2[0]), min(box1[1], box2[1]), max(box1[2], box2[2]), max(box1[3], box2[3])) 26 | 27 | 28 | def join_lines(bboxes, max_gap=5): 29 | to_merge = {} 30 | for i, box1 in enumerate(bboxes): 31 | for z, box2 in enumerate(bboxes[i + 1:]): 32 | j = i + z + 1 33 | if box1 == box2: 34 | continue 35 | 36 | if box1[0] <= box2[0] and box1[2] >= box2[2]: 37 | if abs(box1[1] - box2[3]) <= max_gap: 38 | if i not in to_merge: 39 | to_merge[i] = [] 40 | to_merge[i].append(j) 41 | 42 | merged_boxes = set() 43 | merged = [] 44 | for i, box in enumerate(bboxes): 45 | if i in merged_boxes: 46 | continue 47 | 48 | if i in to_merge: 49 | for j in to_merge[i]: 50 | box = merge_boxes(box, bboxes[j]) 51 | merged_boxes.add(j) 52 | 53 | merged.append(box) 54 | return merged 55 | -------------------------------------------------------------------------------- /benchmark/utils/scoring.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List 3 | 4 | from rapidfuzz import fuzz 5 | 6 | 7 | def overlap_score(pred_lines: List[str], reference_lines: List[str]): 8 | line_scores = [] 9 | line_weights = [] 10 | line_match = {} 11 | for i, pred_line in enumerate(pred_lines): 12 | max_score = 0 13 | line_weight = 1 14 | match = None 15 | for j, ref_line in enumerate(reference_lines): 16 | score = fuzz.ratio(pred_line, ref_line, score_cutoff=20) / 100 17 | if
score > max_score: 18 | max_score = score 19 | line_weight = math.sqrt(len(ref_line)) 20 | match = j 21 | line_scores.append(max_score) 22 | line_weights.append(line_weight) 23 | line_match[i] = match 24 | line_scores = [line_scores[i] * line_weights[i] for i in range(len(line_scores))] 25 | 26 | return line_scores, line_weights, line_match 27 | 28 | 29 | def overlap_score_exact(pred_lines: List[str], reference_lines: List[str]): 30 | line_scores = [] 31 | line_weights = [] 32 | assert len(pred_lines) == len(reference_lines) 33 | 34 | for i, (pred_line, ref_line) in enumerate(zip(pred_lines, reference_lines)): 35 | score = fuzz.ratio(pred_line, ref_line, score_cutoff=20) / 100 36 | weight = math.sqrt(len(ref_line)) 37 | line_scores.append(score * weight) 38 | line_weights.append(weight) 39 | 40 | return line_scores, line_weights 41 | -------------------------------------------------------------------------------- /benchmark/utils/tatr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForObjectDetection 3 | from surya.settings import settings 4 | import numpy as np 5 | 6 | 7 | class MaxResize(object): 8 | def __init__(self, max_size=800): 9 | self.max_size = max_size 10 | 11 | def __call__(self, image): 12 | width, height = image.size 13 | current_max_size = max(width, height) 14 | scale = self.max_size / current_max_size 15 | resized_image = image.resize((int(round(scale * width)), int(round(scale * height)))) 16 | 17 | return resized_image 18 | 19 | 20 | def to_tensor(image): 21 | # Convert PIL Image to NumPy array 22 | np_image = np.array(image).astype(np.float32) 23 | 24 | # Rearrange dimensions to [C, H, W] format 25 | np_image = np_image.transpose((2, 0, 1)) 26 | 27 | # Normalize to [0.0, 1.0] 28 | np_image /= 255.0 29 | 30 | return torch.from_numpy(np_image) 31 | 32 | 33 | def normalize(tensor, mean, std): 34 | for t, m, s in zip(tensor, mean, std): 35 | t.sub_(m).div_(s) 36 | return tensor 37 | 38 | 39 | def structure_transform(image): 40 | image = MaxResize(1000)(image) 41 | tensor = to_tensor(image) 42 | normalized_tensor = normalize(tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 43 | return normalized_tensor 44 | 45 | 46 | def box_cxcywh_to_xyxy(x): 47 | x_c, y_c, w, h = x.unbind(-1) 48 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] 49 | return torch.stack(b, dim=1) 50 | 51 | 52 | def rescale_bboxes(out_bbox, size): 53 | width, height = size 54 | boxes = box_cxcywh_to_xyxy(out_bbox) 55 | boxes = boxes * torch.tensor([width, height, width, height], dtype=torch.float32) 56 | return boxes 57 | 58 | 59 | def outputs_to_objects(outputs, img_sizes, id2label): 60 | m = outputs.logits.softmax(-1).max(-1) 61 | batch_labels = list(m.indices.detach().cpu().numpy()) 62 | batch_scores = list(m.values.detach().cpu().numpy()) 63 | batch_bboxes = outputs['pred_boxes'].detach().cpu() 64 | 65 | batch_objects = [] 66 | for i in range(len(img_sizes)): 67 | pred_bboxes = [elem.tolist() for elem in rescale_bboxes(batch_bboxes[i], img_sizes[i])] 68 | pred_scores = batch_scores[i] 69 | pred_labels = batch_labels[i] 70 | 71 | objects = [] 72 | for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes): 73 | class_label = id2label[int(label)] 74 | if not class_label == 'no object': 75 | objects.append({ 76 | 'label': class_label, 77 | 'score': float(score), 78 | 'bbox': [float(elem) for elem in bbox]} 79 | ) 80 | 81 | rows = [] 82 | cols = [] 83 | for cell in objects: 84 | 
if cell["label"] == "table column": 85 | cols.append(cell) 86 | 87 | if cell["label"] == "table row": 88 | rows.append(cell) 89 | batch_objects.append({ 90 | "rows": rows, 91 | "cols": cols 92 | }) 93 | 94 | return batch_objects 95 | 96 | 97 | def load_tatr(): 98 | return AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition-v1.1-all").to(settings.TORCH_DEVICE_MODEL) 99 | 100 | 101 | def batch_inference_tatr(model, images, batch_size): 102 | device = model.device 103 | rows_cols = [] 104 | for i in range(0, len(images), batch_size): 105 | batch_images = images[i:i + batch_size] 106 | pixel_values = torch.stack([structure_transform(img) for img in batch_images], dim=0).to(device) 107 | 108 | # forward pass 109 | with torch.no_grad(): 110 | outputs = model(pixel_values) 111 | 112 | id2label = model.config.id2label 113 | id2label[len(model.config.id2label)] = "no object" 114 | rows_cols.extend(outputs_to_objects(outputs, [img.size for img in batch_images], id2label)) 115 | return rows_cols -------------------------------------------------------------------------------- /benchmark/utils/tesseract.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | from surya.input.processing import slice_bboxes_from_image 7 | from surya.settings import settings 8 | import os 9 | from concurrent.futures import ProcessPoolExecutor 10 | from surya.recognition.languages import CODE_TO_LANGUAGE 11 | from surya.recognition import RecognitionPredictor 12 | from surya.detection import DetectionPredictor 13 | 14 | 15 | def surya_lang_to_tesseract(code: str) -> Optional[str]: 16 | lang_str = CODE_TO_LANGUAGE[code] 17 | try: 18 | tess_lang = TESS_LANGUAGE_TO_CODE[lang_str] 19 | except KeyError: 20 | return None 21 | return tess_lang 22 | 23 | 24 | def tesseract_ocr(img, bboxes, lang: str): 25 | import pytesseract 26 | line_imgs = slice_bboxes_from_image(img, bboxes) 27 | config = f'--tessdata-dir "{settings.TESSDATA_PREFIX}"' 28 | lines = [] 29 | for line_img in line_imgs: 30 | line = pytesseract.image_to_string(line_img, lang=lang, config=config) 31 | lines.append(line) 32 | return lines 33 | 34 | 35 | def tesseract_ocr_parallel(imgs, bboxes, langs: List[str], cpus=None): 36 | tess_parallel_cores = min(len(imgs), RecognitionPredictor.get_batch_size()) 37 | if not cpus: 38 | cpus = os.cpu_count() 39 | tess_parallel_cores = min(tess_parallel_cores, cpus) 40 | 41 | # Tesseract uses up to 4 processes per instance 42 | # Divide by 2 because tesseract doesn't seem to saturate all 4 cores with these small images 43 | tess_parallel = max(tess_parallel_cores // 2, 1) 44 | 45 | with ProcessPoolExecutor(max_workers=tess_parallel) as executor: 46 | tess_text = tqdm(executor.map(tesseract_ocr, imgs, bboxes, langs), total=len(imgs), desc="Running tesseract OCR") 47 | tess_text = list(tess_text) 48 | return tess_text 49 | 50 | 51 | def tesseract_bboxes(img): 52 | import pytesseract 53 | from pytesseract import Output 54 | arr_img = np.asarray(img, dtype=np.uint8) 55 | ocr = pytesseract.image_to_data(arr_img, output_type=Output.DICT) 56 | 57 | bboxes = [] 58 | n_boxes = len(ocr['level']) 59 | for i in range(n_boxes): 60 | # It is possible to merge by line here with line number, but it gives bad results. 
61 | _, x, y, w, h = ocr['text'][i], ocr['left'][i], ocr['top'][i], ocr['width'][i], ocr['height'][i] 62 | bbox = (x, y, x + w, y + h) 63 | bboxes.append(bbox) 64 | 65 | return bboxes 66 | 67 | 68 | def tesseract_parallel(imgs): 69 | # Tesseract uses 4 threads per instance 70 | tess_parallel_cores = min(len(imgs), DetectionPredictor.get_batch_size()) 71 | cpus = os.cpu_count() 72 | tess_parallel_cores = min(tess_parallel_cores, cpus) 73 | 74 | # Tesseract uses 4 threads per instance 75 | tess_parallel = max(tess_parallel_cores // 4, 1) 76 | 77 | with ProcessPoolExecutor(max_workers=tess_parallel) as executor: 78 | tess_bboxes = tqdm(executor.map(tesseract_bboxes, imgs), total=len(imgs), desc="Running tesseract bbox detection") 79 | tess_bboxes = list(tess_bboxes) 80 | return tess_bboxes 81 | 82 | 83 | TESS_CODE_TO_LANGUAGE = { 84 | "afr": "Afrikaans", 85 | "amh": "Amharic", 86 | "ara": "Arabic", 87 | "asm": "Assamese", 88 | "aze": "Azerbaijani", 89 | "bel": "Belarusian", 90 | "ben": "Bengali", 91 | "bod": "Tibetan", 92 | "bos": "Bosnian", 93 | "bre": "Breton", 94 | "bul": "Bulgarian", 95 | "cat": "Catalan", 96 | "ceb": "Cebuano", 97 | "ces": "Czech", 98 | "chi_sim": "Chinese", 99 | "chr": "Cherokee", 100 | "cym": "Welsh", 101 | "dan": "Danish", 102 | "deu": "German", 103 | "dzo": "Dzongkha", 104 | "ell": "Greek", 105 | "eng": "English", 106 | "epo": "Esperanto", 107 | "est": "Estonian", 108 | "eus": "Basque", 109 | "fas": "Persian", 110 | "fin": "Finnish", 111 | "fra": "French", 112 | "fry": "Western Frisian", 113 | "guj": "Gujarati", 114 | "gla": "Scottish Gaelic", 115 | "gle": "Irish", 116 | "glg": "Galician", 117 | "heb": "Hebrew", 118 | "hin": "Hindi", 119 | "hrv": "Croatian", 120 | "hun": "Hungarian", 121 | "hye": "Armenian", 122 | "iku": "Inuktitut", 123 | "ind": "Indonesian", 124 | "isl": "Icelandic", 125 | "ita": "Italian", 126 | "jav": "Javanese", 127 | "jpn": "Japanese", 128 | "kan": "Kannada", 129 | "kat": "Georgian", 130 | "kaz": "Kazakh", 131 | "khm": "Khmer", 132 | "kir": "Kyrgyz", 133 | "kor": "Korean", 134 | "lao": "Lao", 135 | "lat": "Latin", 136 | "lav": "Latvian", 137 | "lit": "Lithuanian", 138 | "mal": "Malayalam", 139 | "mar": "Marathi", 140 | "mkd": "Macedonian", 141 | "mlt": "Maltese", 142 | "mon": "Mongolian", 143 | "msa": "Malay", 144 | "mya": "Burmese", 145 | "nep": "Nepali", 146 | "nld": "Dutch", 147 | "nor": "Norwegian", 148 | "ori": "Oriya", 149 | "pan": "Punjabi", 150 | "pol": "Polish", 151 | "por": "Portuguese", 152 | "pus": "Pashto", 153 | "ron": "Romanian", 154 | "rus": "Russian", 155 | "san": "Sanskrit", 156 | "sin": "Sinhala", 157 | "slk": "Slovak", 158 | "slv": "Slovenian", 159 | "snd": "Sindhi", 160 | "spa": "Spanish", 161 | "sqi": "Albanian", 162 | "srp": "Serbian", 163 | "swa": "Swahili", 164 | "swe": "Swedish", 165 | "syr": "Syriac", 166 | "tam": "Tamil", 167 | "tel": "Telugu", 168 | "tgk": "Tajik", 169 | "tha": "Thai", 170 | "tir": "Tigrinya", 171 | "tur": "Turkish", 172 | "uig": "Uyghur", 173 | "ukr": "Ukrainian", 174 | "urd": "Urdu", 175 | "uzb": "Uzbek", 176 | "vie": "Vietnamese", 177 | "yid": "Yiddish" 178 | } 179 | 180 | TESS_LANGUAGE_TO_CODE = {v:k for k,v in TESS_CODE_TO_LANGUAGE.items()} 181 | -------------------------------------------------------------------------------- /benchmark/utils/textract.py: -------------------------------------------------------------------------------- 1 | import os 2 | from concurrent.futures import ThreadPoolExecutor 3 | from tqdm import tqdm 4 | import traceback 5 | 6 | from surya.input.processing import 
slice_bboxes_from_image 7 | from surya.recognition import RecognitionPredictor 8 | 9 | def textract_ocr(extractor, img): 10 | try: 11 | document = extractor.detect_document_text(file_source=img) 12 | return [line.text for line in document.lines] 13 | except: 14 | traceback.print_exc() 15 | return [None] 16 | 17 | def textract_ocr_parallel(imgs, cpus=None): 18 | from textractor import Textractor # Optional dependency 19 | 20 | extractor = Textractor(profile_name='default') 21 | parallel_cores = min(len(imgs), RecognitionPredictor().get_batch_size()) 22 | if not cpus: 23 | cpus = os.cpu_count() 24 | parallel_cores = min(parallel_cores, cpus) 25 | 26 | with ThreadPoolExecutor(max_workers=parallel_cores) as executor: 27 | textract_text = tqdm(executor.map(textract_ocr, [extractor]*len(imgs), imgs), total=len(imgs), desc="Running textract OCR") 28 | textract_text = list(textract_text) 29 | return textract_text -------------------------------------------------------------------------------- /benchmark/utils/verify_benchmark_scores.py: -------------------------------------------------------------------------------- 1 | import json 2 | import click 3 | 4 | 5 | def verify_layout(data): 6 | scores = data["metrics"] 7 | for layout_type, metrics in scores.items(): 8 | if layout_type == "List": # Skip lists since none appear early on 9 | continue 10 | 11 | if metrics["precision"] <= 0.6 or metrics["recall"] <= 0.6: 12 | raise ValueError("Scores do not meet the required threshold") 13 | 14 | 15 | def verify_det(data): 16 | scores = data["metrics"]["surya"] 17 | if scores["precision"] <= 0.9 or scores["recall"] <= 0.9: 18 | raise ValueError("Scores do not meet the required threshold") 19 | 20 | 21 | def verify_rec(data): 22 | scores = data["surya"] 23 | if scores["avg_score"] <= 0.9: 24 | raise ValueError("Scores do not meet the required threshold") 25 | 26 | 27 | def verify_order(data): 28 | score = data["mean_accuracy"] 29 | if score < 0.75: 30 | raise ValueError("Scores do not meet the required threshold") 31 | 32 | 33 | def verify_table_rec(data): 34 | row_score = data["surya"]["mean_row_iou"] 35 | col_score = data["surya"]["mean_col_iou"] 36 | 37 | if row_score < 0.75 or col_score < 0.75: 38 | raise ValueError("Scores do not meet the required threshold") 39 | 40 | 41 | def verify_texify(data): 42 | edit_dist = data["scores"] 43 | if edit_dist > 0.2: 44 | raise ValueError("Scores do not meet the required threshold") 45 | 46 | 47 | @click.command(help="Verify benchmark scores") 48 | @click.argument("file_path", type=str) 49 | @click.option( 50 | "--bench_type", type=str, help="Type of benchmark to verify", default="detection" 51 | ) 52 | def main(file_path, bench_type): 53 | with open(file_path, "r") as file: 54 | data = json.load(file) 55 | 56 | if bench_type == "detection": 57 | verify_det(data) 58 | elif bench_type == "recognition": 59 | verify_rec(data) 60 | elif bench_type == "layout": 61 | verify_layout(data) 62 | elif bench_type == "ordering": 63 | verify_order(data) 64 | elif bench_type == "table_recognition": 65 | verify_table_rec(data) 66 | elif bench_type == "texify": 67 | verify_texify(data) 68 | else: 69 | raise ValueError("Invalid benchmark type") 70 | 71 | 72 | if __name__ == "__main__": 73 | main() 74 | -------------------------------------------------------------------------------- /detect_layout.py: -------------------------------------------------------------------------------- 1 | from surya.scripts.detect_layout import detect_layout_cli 2 | 3 | if __name__ == "__main__": 4 | 
detect_layout_cli() 5 | -------------------------------------------------------------------------------- /detect_text.py: -------------------------------------------------------------------------------- 1 | from surya.scripts.detect_text import detect_text_cli 2 | 3 | if __name__ == "__main__": 4 | detect_text_cli() 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /ocr_app.py: -------------------------------------------------------------------------------- 1 | from surya.scripts.run_streamlit_app import streamlit_app_cli 2 | 3 | if __name__ == "__main__": 4 | streamlit_app_cli() -------------------------------------------------------------------------------- /ocr_latex.py: -------------------------------------------------------------------------------- 1 | from surya.scripts.ocr_latex import ocr_latex_cli 2 | 3 | if __name__ == "__main__": 4 | ocr_latex_cli() 5 | -------------------------------------------------------------------------------- /ocr_text.py: -------------------------------------------------------------------------------- 1 | from surya.scripts.ocr_text import ocr_text_cli 2 | 3 | if __name__ == "__main__": 4 | ocr_text_cli() 5 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "surya-ocr" 3 | version = "0.14.5" 4 | description = "OCR, layout, reading order, and table recognition in 90+ languages" 5 | authors = ["Vik Paruchuri "] 6 | readme = "README.md" 7 | license = "GPL-3.0-or-later" 8 | repository = "https://github.com/VikParuchuri/surya" 9 | keywords = ["ocr", "pdf", "text detection", "text recognition", "tables"] 10 | packages = [ 11 | {include = "surya"} 12 | ] 13 | 14 | [tool.poetry.dependencies] 15 | python = "^3.10" 16 | transformers = "^4.51.2" 17 | torch = "^2.7.0" 18 | pydantic = "^2.5.3" 19 | pydantic-settings = "^2.1.0" 20 | python-dotenv = "^1.0.0" 21 | pillow = "^10.2.0" 22 | pypdfium2 = "=4.30.0" 23 | filetype = "^1.2.0" 24 | click = "^8.1.8" 25 | platformdirs = "^4.3.6" 26 | opencv-python-headless = "^4.11.0.86" 27 | einops = "^0.8.1" 28 | pre-commit = "^4.2.0" 29 | 30 | [tool.poetry.group.dev.dependencies] 31 | jupyter = "^1.0.0" 32 | pytesseract = "^0.3.10" 33 | pymupdf = "^1.23.8" 34 | datasets = "^2.16.1" 35 | rapidfuzz = "^3.6.1" 36 | streamlit = "^1.31.0" 37 | pytest = "^8.3.4" 38 | pdftext = "^0.5.1" 39 | tabulate = "^0.9.0" 40 | 41 | [tool.poetry.scripts] 42 | surya_detect = "surya.scripts.detect_text:detect_text_cli" 43 | surya_ocr = "surya.scripts.ocr_text:ocr_text_cli" 44 | surya_layout = "surya.scripts.detect_layout:detect_layout_cli" 45 | surya_gui = "surya.scripts.run_streamlit_app:streamlit_app_cli" 46 | surya_table = "surya.scripts.table_recognition:table_recognition_cli" 47 | surya_latex_ocr = "surya.scripts.ocr_latex:ocr_latex_cli" 48 | texify_gui = "surya.scripts.run_texify_app:texify_app_cli" 49 | 50 | [build-system] 51 | requires = ["poetry-core"] 52 | build-backend = "poetry.core.masonry.api" 53 | 54 | [[tool.poetry.source]] 55 | name = "libtpu-releases" 56 | url = "https://storage.googleapis.com/libtpu-releases/index.html" 57 | priority = "supplemental" 58 | 59 | [[tool.poetry.source]] 60 | name = "libtpu-wheels" 61 | url = "https://storage.googleapis.com/libtpu-wheels/index.html" 62 | priority = "supplemental" 63 | 64 | [tool.poetry.group.xla] 65 | optional = true 66 | 67 | [tool.poetry.group.xla.dependencies] 
68 | torch-xla = {version = "^2.4.1", extras = ["tpu"]} 69 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths=tests 3 | pythonpath=. 4 | filterwarnings = 5 | ignore::UserWarning 6 | ignore::PendingDeprecationWarning 7 | ignore::DeprecationWarning -------------------------------------------------------------------------------- /signatures/version1/cla.json: -------------------------------------------------------------------------------- 1 | { 2 | "signedContributors": [ 3 | { 4 | "name": "rishiraj", 5 | "id": 44090649, 6 | "comment_id": 2170578748, 7 | "created_at": "2024-06-15T19:31:20Z", 8 | "repoId": 741297064, 9 | "pullRequestNo": 135 10 | }, 11 | { 12 | "name": "mmacvicar", 13 | "id": 59354, 14 | "comment_id": 2236493182, 15 | "created_at": "2024-07-18T13:17:43Z", 16 | "repoId": 741297064, 17 | "pullRequestNo": 152 18 | }, 19 | { 20 | "name": "jimexist", 21 | "id": 622789, 22 | "comment_id": 2255151376, 23 | "created_at": "2024-07-29T07:23:55Z", 24 | "repoId": 741297064, 25 | "pullRequestNo": 160 26 | }, 27 | { 28 | "name": "michaeldriscoll-avant", 29 | "id": 85255083, 30 | "comment_id": 2259143427, 31 | "created_at": "2024-07-30T20:21:33Z", 32 | "repoId": 741297064, 33 | "pullRequestNo": 161 34 | }, 35 | { 36 | "name": "EdoardoPona", 37 | "id": 29152472, 38 | "comment_id": 2271115922, 39 | "created_at": "2024-08-06T11:58:00Z", 40 | "repoId": 741297064, 41 | "pullRequestNo": 167 42 | }, 43 | { 44 | "name": "hidenori-endo", 45 | "id": 15546605, 46 | "comment_id": 2307217499, 47 | "created_at": "2024-08-23T14:31:17Z", 48 | "repoId": 741297064, 49 | "pullRequestNo": 182 50 | }, 51 | { 52 | "name": "dobosevych", 53 | "id": 12053536, 54 | "comment_id": 2430376828, 55 | "created_at": "2024-10-22T21:48:34Z", 56 | "repoId": 741297064, 57 | "pullRequestNo": 220 58 | }, 59 | { 60 | "name": "iammosespaulr", 61 | "id": 28682735, 62 | "comment_id": 2447941238, 63 | "created_at": "2024-10-30T17:55:23Z", 64 | "repoId": 741297064, 65 | "pullRequestNo": 235 66 | }, 67 | { 68 | "name": "ArthurMor4is", 69 | "id": 42987302, 70 | "comment_id": 2515315717, 71 | "created_at": "2024-12-03T18:37:45Z", 72 | "repoId": 741297064, 73 | "pullRequestNo": 255 74 | }, 75 | { 76 | "name": "tarun-menta", 77 | "id": 66506307, 78 | "comment_id": 2543457960, 79 | "created_at": "2024-12-15T05:43:33Z", 80 | "repoId": 741297064, 81 | "pullRequestNo": 261 82 | }, 83 | { 84 | "name": "jonaskahn", 85 | "id": 4338500, 86 | "comment_id": 2556622097, 87 | "created_at": "2024-12-20T09:36:20Z", 88 | "repoId": 741297064, 89 | "pullRequestNo": 269 90 | }, 91 | { 92 | "name": "kumsumit", 93 | "id": 95072784, 94 | "comment_id": 2574534622, 95 | "created_at": "2025-01-07T07:05:59Z", 96 | "repoId": 741297064, 97 | "pullRequestNo": 276 98 | }, 99 | { 100 | "name": "kevinhu", 101 | "id": 6051736, 102 | "comment_id": 2614135351, 103 | "created_at": "2025-01-25T23:34:12Z", 104 | "repoId": 741297064, 105 | "pullRequestNo": 291 106 | } 107 | ] 108 | } -------------------------------------------------------------------------------- /static/fonts/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /static/images/arabic.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/arabic.jpg -------------------------------------------------------------------------------- /static/images/arabic_layout.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/arabic_layout.jpg -------------------------------------------------------------------------------- /static/images/arabic_reading.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/arabic_reading.jpg -------------------------------------------------------------------------------- /static/images/arabic_text.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/arabic_text.jpg -------------------------------------------------------------------------------- /static/images/benchmark_chart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/benchmark_chart.png -------------------------------------------------------------------------------- /static/images/benchmark_chart_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/benchmark_chart_small.png -------------------------------------------------------------------------------- /static/images/benchmark_layout_chart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/benchmark_layout_chart.png -------------------------------------------------------------------------------- /static/images/benchmark_rec_chart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/benchmark_rec_chart.png -------------------------------------------------------------------------------- /static/images/benchmark_tablerec_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/benchmark_tablerec_acc.png -------------------------------------------------------------------------------- /static/images/benchmark_tablerec_speed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/benchmark_tablerec_speed.png -------------------------------------------------------------------------------- /static/images/chi_hind.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/chi_hind.jpg -------------------------------------------------------------------------------- /static/images/chi_hind_layout.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/chi_hind_layout.jpg -------------------------------------------------------------------------------- /static/images/chi_hind_orig.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/chi_hind_orig.jpg -------------------------------------------------------------------------------- /static/images/chi_hind_reading.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/chi_hind_reading.jpg -------------------------------------------------------------------------------- /static/images/chi_hind_text.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/chi_hind_text.jpg -------------------------------------------------------------------------------- /static/images/chinese.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/chinese.jpg -------------------------------------------------------------------------------- /static/images/chinese_layout.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/chinese_layout.jpg -------------------------------------------------------------------------------- /static/images/chinese_reading.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/chinese_reading.jpg -------------------------------------------------------------------------------- /static/images/chinese_text.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/chinese_text.jpg -------------------------------------------------------------------------------- /static/images/excerpt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/excerpt.png -------------------------------------------------------------------------------- /static/images/excerpt_layout.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/excerpt_layout.png -------------------------------------------------------------------------------- /static/images/excerpt_reading.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/excerpt_reading.jpg -------------------------------------------------------------------------------- /static/images/excerpt_text.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/excerpt_text.png -------------------------------------------------------------------------------- /static/images/funsd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/funsd.png -------------------------------------------------------------------------------- /static/images/funsd_layout.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/funsd_layout.jpg -------------------------------------------------------------------------------- /static/images/funsd_reading.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/funsd_reading.jpg -------------------------------------------------------------------------------- /static/images/funsd_text.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/funsd_text.jpg -------------------------------------------------------------------------------- /static/images/gcloud_full_langs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/gcloud_full_langs.png -------------------------------------------------------------------------------- /static/images/gcloud_rec_bench.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/gcloud_rec_bench.png -------------------------------------------------------------------------------- /static/images/hindi.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/hindi.jpg -------------------------------------------------------------------------------- /static/images/hindi_layout.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/hindi_layout.jpg -------------------------------------------------------------------------------- /static/images/hindi_reading.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/hindi_reading.jpg -------------------------------------------------------------------------------- /static/images/hindi_text.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/hindi_text.jpg -------------------------------------------------------------------------------- /static/images/japanese.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/japanese.jpg -------------------------------------------------------------------------------- /static/images/japanese_layout.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/japanese_layout.jpg -------------------------------------------------------------------------------- /static/images/japanese_reading.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/japanese_reading.jpg -------------------------------------------------------------------------------- /static/images/japanese_tablerec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/japanese_tablerec.png -------------------------------------------------------------------------------- /static/images/japanese_text.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/japanese_text.jpg -------------------------------------------------------------------------------- /static/images/latex_ocr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/latex_ocr.png -------------------------------------------------------------------------------- /static/images/nyt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/nyt.jpg -------------------------------------------------------------------------------- /static/images/nyt_layout.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/nyt_layout.jpg -------------------------------------------------------------------------------- /static/images/nyt_order.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/nyt_order.jpg -------------------------------------------------------------------------------- /static/images/nyt_text.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/nyt_text.jpg -------------------------------------------------------------------------------- /static/images/paper.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/paper.jpg -------------------------------------------------------------------------------- /static/images/paper_layout.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/paper_layout.jpg -------------------------------------------------------------------------------- /static/images/paper_reading.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/paper_reading.jpg -------------------------------------------------------------------------------- /static/images/paper_tablerec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/paper_tablerec.png -------------------------------------------------------------------------------- /static/images/paper_text.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/paper_text.jpg -------------------------------------------------------------------------------- /static/images/pres.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/pres.png -------------------------------------------------------------------------------- /static/images/pres_layout.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/pres_layout.jpg -------------------------------------------------------------------------------- /static/images/pres_reading.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/pres_reading.jpg -------------------------------------------------------------------------------- /static/images/pres_tablerec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/pres_tablerec.png -------------------------------------------------------------------------------- /static/images/pres_text.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/pres_text.jpg -------------------------------------------------------------------------------- /static/images/rec_acc_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/rec_acc_table.png -------------------------------------------------------------------------------- /static/images/scanned.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/scanned.png -------------------------------------------------------------------------------- /static/images/scanned_layout.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/scanned_layout.jpg -------------------------------------------------------------------------------- /static/images/scanned_reading.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/scanned_reading.jpg -------------------------------------------------------------------------------- /static/images/scanned_tablerec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/scanned_tablerec.png -------------------------------------------------------------------------------- /static/images/scanned_tablerec2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/scanned_tablerec2.png -------------------------------------------------------------------------------- /static/images/scanned_text.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/scanned_text.jpg -------------------------------------------------------------------------------- /static/images/surya_rec_perf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/surya_rec_perf.png -------------------------------------------------------------------------------- /static/images/table_rec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/table_rec.png -------------------------------------------------------------------------------- /static/images/textbook.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/textbook.jpg -------------------------------------------------------------------------------- /static/images/textbook_layout.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/textbook_layout.jpg -------------------------------------------------------------------------------- /static/images/textbook_order.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/textbook_order.jpg -------------------------------------------------------------------------------- /static/images/textbook_text.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/static/images/textbook_text.jpg -------------------------------------------------------------------------------- /surya/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/surya/__init__.py -------------------------------------------------------------------------------- /surya/common/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /surya/common/load.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Any 2 | 3 | import torch 4 | 5 | from surya.settings import settings 6 | 7 | 8 | class ModelLoader: 9 | def __init__(self, checkpoint: Optional[str] = None): 10 | self.checkpoint = checkpoint 11 | 12 | def model( 13 | self, 14 | device: torch.device | str | None = settings.TORCH_DEVICE_MODEL, 15 | dtype: Optional[torch.dtype | str] = settings.MODEL_DTYPE, 16 | ) -> Any: 17 | raise NotImplementedError() 18 | 19 | def processor( 20 | self, 21 | device: torch.device | str | None = settings.TORCH_DEVICE_MODEL, 22 | dtype: Optional[torch.dtype | str] = settings.MODEL_DTYPE, 23 | ) -> Any: 24 | raise NotImplementedError() 25 | -------------------------------------------------------------------------------- /surya/common/predictor.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from surya.common.load import ModelLoader 6 | from surya.settings import settings 7 | 8 | 9 | class BasePredictor: 10 | model_loader_cls = ModelLoader 11 | batch_size: Optional[int] = None 12 | default_batch_sizes = { 13 | "cpu": 1, 14 | "mps": 1, 15 | "cuda": 1 16 | } 17 | disable_tqdm: bool = settings.DISABLE_TQDM 18 | torch_dtype = settings.MODEL_DTYPE 19 | 20 | def __init__(self, checkpoint: Optional[str] = None, device: torch.device | str | None = settings.TORCH_DEVICE_MODEL, dtype: Optional[torch.dtype | str] = None): 21 | if dtype is None: 22 | dtype = self.torch_dtype 23 | 24 | self.model = None 25 | self.processor = None 26 | loader = self.model_loader_cls(checkpoint) 27 | 28 | self.model = loader.model(device, dtype) 29 | self.processor = loader.processor() 30 | 31 | def to(self, device_dtype: torch.device | str | None = None): 32 | if self.model: 33 | self.model.to(device_dtype) 34 | else: 35 | raise ValueError("Model not loaded") 36 | 37 | def get_batch_size(self): 38 | batch_size = self.batch_size 39 | if batch_size is None: 40 | batch_size = self.default_batch_sizes["cpu"] 41 | if settings.TORCH_DEVICE_MODEL in self.default_batch_sizes: 42 | batch_size = self.default_batch_sizes[settings.TORCH_DEVICE_MODEL] 43 | return batch_size 44 | 45 | @staticmethod 46 | def pad_to_batch_size(tensor: torch.Tensor, batch_size: int): 47 | current_batch_size = tensor.shape[0] 48 | if current_batch_size >= batch_size: 49 | return tensor 50 | 51 | pad_size = batch_size - current_batch_size 52 | padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size) 53 | 54 | return F.pad(tensor, padding, mode='constant', value=0) 55 | 56 | def __call__(self, *args, **kwargs): 57 | raise NotImplementedError() -------------------------------------------------------------------------------- /surya/common/surya/config.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | from surya.common.s3 import S3DownloaderMixin 4 | from 
surya.common.surya.encoder.config import SuryaEncoderConfig 5 | from surya.common.surya.decoder.config import SuryaDecoderConfig 6 | 7 | 8 | class SuryaModelConfig(S3DownloaderMixin, PretrainedConfig): 9 | model_type = "surya-multimodal-foundation" 10 | is_composition = True 11 | 12 | def __init__( 13 | self, 14 | vocab_size=65536, 15 | bbox_size=1025, 16 | blank_bbox_token_id=1025, 17 | bos_token_id=0, 18 | eos_token_id=1, 19 | pad_token_id=2, 20 | image_token_id=3, 21 | special_token_count=4, 22 | max_sequence_length=1536, 23 | special_ocr_tokens=None, 24 | vision_encoder=None, 25 | decoder=None, 26 | tasks: dict | None = None, 27 | bbox_embed_size: int = 64, 28 | register_token_ids=(4, 5, 6, 7), 29 | unmask_image: bool = False, 30 | num_register_tokens: int = 4, 31 | image_embed_encoding_size: int = 1024, 32 | image_embed_encoding_multiplier: int = 256, 33 | **kwargs, 34 | ): 35 | super().__init__(**kwargs) 36 | self.is_encoder_decoder = False 37 | self.vocab_size = vocab_size 38 | self.bbox_size = bbox_size 39 | self.blank_bbox_token_id = blank_bbox_token_id 40 | self.image_token_id = image_token_id 41 | self.bos_token_id = bos_token_id 42 | self.eos_token_id = eos_token_id 43 | self.pad_token_id = pad_token_id 44 | self.special_ocr_tokens = special_ocr_tokens 45 | self.special_token_count = special_token_count # pad, bos, etc, tokens 46 | self.max_sequence_length = max_sequence_length 47 | self.tasks = tasks 48 | self.tie_word_embeddings = True 49 | self.bbox_embed_size = bbox_embed_size 50 | self.unmask_image = unmask_image 51 | self.num_register_tokens = num_register_tokens 52 | self.register_token_ids = register_token_ids 53 | self.image_embed_encoding_size = image_embed_encoding_size 54 | self.image_embed_encoding_multiplier = image_embed_encoding_multiplier 55 | 56 | if isinstance(vision_encoder, dict): 57 | vision_encoder = SuryaEncoderConfig(**vision_encoder) 58 | elif vision_encoder is None: 59 | vision_encoder = SuryaEncoderConfig() 60 | self.vision_encoder = vision_encoder 61 | 62 | if isinstance(decoder, dict): 63 | decoder = SuryaDecoderConfig(**decoder) 64 | elif decoder is None: 65 | decoder = SuryaDecoderConfig() 66 | self.decoder = decoder 67 | 68 | self.hidden_size = self.decoder.hidden_size 69 | 70 | self.patch_size = self.vision_encoder.spatial_patch_size 71 | self.merge_size = self.vision_encoder.spatial_merge_size 72 | -------------------------------------------------------------------------------- /surya/common/surya/decoder/config.py: -------------------------------------------------------------------------------- 1 | from transformers.configuration_utils import PretrainedConfig 2 | from transformers.modeling_rope_utils import rope_config_validation 3 | from transformers.utils import logging 4 | 5 | logger = logging.get_logger(__name__) 6 | 7 | 8 | class SuryaDecoderConfig(PretrainedConfig): 9 | model_type = "qwen2" 10 | keys_to_ignore_at_inference = ["past_key_values"] 11 | 12 | # Default tensor parallel plan for base model `Qwen2` 13 | base_model_tp_plan = { 14 | "layers.*.self_attn.q_proj": "colwise", 15 | "layers.*.self_attn.k_proj": "colwise", 16 | "layers.*.self_attn.v_proj": "colwise", 17 | "layers.*.self_attn.o_proj": "rowwise", 18 | "layers.*.mlp.gate_proj": "colwise", 19 | "layers.*.mlp.up_proj": "colwise", 20 | "layers.*.mlp.down_proj": "rowwise", 21 | } 22 | base_model_pp_plan = { 23 | "embed_tokens": (["input_ids"], ["inputs_embeds"]), 24 | "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), 25 | "norm": (["hidden_states"], 
["hidden_states"]), 26 | } 27 | 28 | def __init__( 29 | self, 30 | vocab_size=151936, 31 | hidden_size=4096, 32 | intermediate_size=22016, 33 | num_hidden_layers=32, 34 | num_attention_heads=32, 35 | num_key_value_heads=32, 36 | hidden_act="silu", 37 | max_position_embeddings=32768, 38 | initializer_range=0.02, 39 | rms_norm_eps=1e-6, 40 | use_cache=True, 41 | tie_word_embeddings=False, 42 | rope_theta=10000.0, 43 | rope_scaling=None, 44 | use_sliding_window=False, 45 | sliding_window=4096, 46 | max_window_layers=28, 47 | attention_dropout=0.0, 48 | unmask_image: bool = True, 49 | **kwargs, 50 | ): 51 | self.vocab_size = vocab_size 52 | self.max_position_embeddings = max_position_embeddings 53 | self.hidden_size = hidden_size 54 | self.intermediate_size = intermediate_size 55 | self.num_hidden_layers = num_hidden_layers 56 | self.num_attention_heads = num_attention_heads 57 | self.use_sliding_window = False # Disable sliding window 58 | self.sliding_window = ( 59 | sliding_window # we check `use_sliding_window` in the modeling code 60 | ) 61 | self.max_window_layers = max_window_layers 62 | 63 | # for backward compatibility 64 | if num_key_value_heads is None: 65 | num_key_value_heads = num_attention_heads 66 | 67 | self.num_key_value_heads = num_key_value_heads 68 | self.hidden_act = hidden_act 69 | self.initializer_range = initializer_range 70 | self.rms_norm_eps = rms_norm_eps 71 | self.use_cache = use_cache 72 | self.rope_theta = rope_theta 73 | self.rope_scaling = rope_scaling 74 | self.attention_dropout = attention_dropout 75 | self.unmask_image = unmask_image 76 | # Validate the correctness of rotary position embeddings parameters 77 | # BC: if there is a 'type' field, move it to 'rope_type'. 78 | if self.rope_scaling is not None and "type" in self.rope_scaling: 79 | self.rope_scaling["rope_type"] = self.rope_scaling["type"] 80 | rope_config_validation(self) 81 | 82 | super().__init__( 83 | tie_word_embeddings=tie_word_embeddings, 84 | **kwargs, 85 | ) 86 | -------------------------------------------------------------------------------- /surya/common/surya/embedder/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SimpleTokenEmbedder(nn.Module): 6 | def __init__(self, config): 7 | super().__init__() 8 | self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size) 9 | 10 | def embed( 11 | self, 12 | input_tokens: torch.Tensor, 13 | ) -> torch.Tensor: 14 | return self.token_embed(input_tokens) 15 | -------------------------------------------------------------------------------- /surya/common/surya/encoder/config.py: -------------------------------------------------------------------------------- 1 | from transformers.configuration_utils import PretrainedConfig 2 | from transformers.utils import logging 3 | 4 | logger = logging.get_logger(__name__) 5 | 6 | 7 | class SuryaEncoderConfig(PretrainedConfig): 8 | model_type = "qwen2_5_vl" 9 | base_config_key = "vision_config" 10 | 11 | attribute_map = { 12 | "num_attention_heads": "num_heads", 13 | "num_hidden_layers": "depth", 14 | } 15 | 16 | def __init__( 17 | self, 18 | depth=8, 19 | hidden_size=1280, 20 | hidden_act="silu", 21 | intermediate_size=3420, 22 | num_heads=16, 23 | in_channels=3, 24 | patch_size=14, 25 | spatial_merge_size=2, 26 | spatial_patch_size=14, 27 | temporal_patch_size=1, 28 | tokens_per_second=4, 29 | window_size=112, 30 | out_hidden_size=1280, 31 | fullatt_block_indexes=(3, 7), 32 | 
initializer_range=0.02, 33 | image_size=4096, 34 | **kwargs, 35 | ): 36 | super().__init__(**kwargs) 37 | 38 | self.depth = depth 39 | self.hidden_size = hidden_size 40 | self.hidden_act = hidden_act 41 | self.intermediate_size = intermediate_size 42 | self.num_heads = num_heads 43 | self.in_channels = in_channels 44 | self.patch_size = patch_size 45 | self.spatial_merge_size = spatial_merge_size 46 | self.temporal_patch_size = temporal_patch_size 47 | self.tokens_per_second = tokens_per_second 48 | self.window_size = window_size 49 | self.fullatt_block_indexes = fullatt_block_indexes 50 | self.out_hidden_size = out_hidden_size 51 | self.initializer_range = initializer_range 52 | self.spatial_patch_size = spatial_patch_size 53 | self.image_size = image_size 54 | -------------------------------------------------------------------------------- /surya/common/surya/processor/schema.py: -------------------------------------------------------------------------------- 1 | from typing import TypedDict, Literal, List, Tuple 2 | 3 | import torch 4 | from PIL import Image 5 | 6 | 7 | class TaskDict(TypedDict): 8 | datasets: List[str] 9 | img_size: Tuple[int, int] 10 | 11 | 12 | class TasksDict(TypedDict): 13 | ocr_with_boxes: TaskDict 14 | ocr_without_boxes: TaskDict 15 | block_without_boxes: TaskDict 16 | 17 | 18 | class ProcessorInput(TypedDict): 19 | type: Literal["image", "ocr", "text", "empty_output"] 20 | 21 | 22 | class ImageInput(ProcessorInput): 23 | type: Literal["image"] 24 | image: Image.Image 25 | rotated: bool 26 | 27 | 28 | class TextInput(ProcessorInput): 29 | type: Literal["text"] 30 | text: str 31 | math: bool 32 | 33 | 34 | class ProcessorOutput(TypedDict): 35 | input_ids: List[int] 36 | image_tiles: torch.Tensor | None 37 | grid_thw: torch.Tensor | None 38 | -------------------------------------------------------------------------------- /surya/common/surya/schema.py: -------------------------------------------------------------------------------- 1 | class TaskNames: 2 | block_without_boxes = "block_without_boxes" 3 | ocr_with_boxes = "ocr_with_boxes" 4 | ocr_without_boxes = "ocr_without_boxes" 5 | 6 | 7 | TASK_NAMES = [ 8 | TaskNames.block_without_boxes, 9 | TaskNames.ocr_with_boxes, 10 | TaskNames.ocr_without_boxes, 11 | ] 12 | -------------------------------------------------------------------------------- /surya/common/util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import List 3 | import torch 4 | 5 | from surya.common.polygon import PolygonBox 6 | from surya.settings import settings 7 | 8 | 9 | def clean_boxes(boxes: List[PolygonBox]) -> List[PolygonBox]: 10 | new_boxes = [] 11 | for box_obj in boxes: 12 | xs = [point[0] for point in box_obj.polygon] 13 | ys = [point[1] for point in box_obj.polygon] 14 | if max(xs) == min(xs) or max(ys) == min(ys): 15 | continue 16 | 17 | box = box_obj.bbox 18 | contained = False 19 | for other_box_obj in boxes: 20 | if other_box_obj.polygon == box_obj.polygon: 21 | continue 22 | 23 | other_box = other_box_obj.bbox 24 | if box == other_box: 25 | continue 26 | if ( 27 | box[0] >= other_box[0] 28 | and box[1] >= other_box[1] 29 | and box[2] <= other_box[2] 30 | and box[3] <= other_box[3] 31 | ): 32 | contained = True 33 | break 34 | if not contained: 35 | new_boxes.append(box_obj) 36 | return new_boxes 37 | 38 | 39 | def rescale_bbox(bbox, processor_size, image_size): 40 | page_width, page_height = processor_size 41 | 42 | img_width, img_height = image_size 43 | 
width_scaler = img_width / page_width 44 | height_scaler = img_height / page_height 45 | 46 | new_bbox = copy.deepcopy(bbox) 47 | new_bbox[0] = int(new_bbox[0] * width_scaler) 48 | new_bbox[1] = int(new_bbox[1] * height_scaler) 49 | new_bbox[2] = int(new_bbox[2] * width_scaler) 50 | new_bbox[3] = int(new_bbox[3] * height_scaler) 51 | return new_bbox 52 | 53 | 54 | def expand_bbox(bbox, expansion_factor=0.01): 55 | expansion_low = 1 - expansion_factor 56 | expansion_high = 1 + expansion_factor 57 | return [ 58 | bbox[0] * expansion_low, 59 | bbox[1] * expansion_low, 60 | bbox[2] * expansion_high, 61 | bbox[3] * expansion_high, 62 | ] 63 | 64 | 65 | def is_flash_attn_2_supported(device: str | torch.device) -> bool: 66 | if not torch.cuda.is_available(): 67 | return False 68 | 69 | if "cuda" not in str(device): 70 | return False 71 | 72 | # Check CUDA version >= 12.0 73 | cuda_version_str = torch.version.cuda 74 | if cuda_version_str is None: 75 | return False 76 | cuda_version = tuple(map(int, cuda_version_str.split("."))) 77 | if cuda_version < (12, 0): 78 | return False 79 | 80 | # Check GPU compute capability (Ampere, Ada, Hopper GPUs) 81 | major, minor = torch.cuda.get_device_capability() 82 | compute_capability = major + minor / 10 83 | if compute_capability < 8.0: 84 | return False 85 | 86 | return True 87 | 88 | 89 | if settings.TORCH_DEVICE_MODEL == "xla": 90 | import torch_xla.core.xla_model as xm 91 | else: 92 | xm = None 93 | 94 | 95 | def mark_step(): 96 | if xm is not None: 97 | xm.mark_step() 98 | -------------------------------------------------------------------------------- /surya/debug/draw.py: -------------------------------------------------------------------------------- 1 | import re 2 | from PIL import ImageDraw, ImageFont 3 | 4 | from surya.debug.fonts import get_font_path 5 | from surya.debug.text import get_text_size 6 | 7 | 8 | def draw_bboxes_on_image(bboxes, image, labels=None, label_font_size=10, color: str | list = 'red'): 9 | polys = [] 10 | for bb in bboxes: 11 | # Clockwise polygon 12 | poly = [ 13 | [bb[0], bb[1]], 14 | [bb[2], bb[1]], 15 | [bb[2], bb[3]], 16 | [bb[0], bb[3]] 17 | ] 18 | polys.append(poly) 19 | 20 | return draw_polys_on_image(polys, image, labels, label_font_size=label_font_size, color=color) 21 | 22 | 23 | def draw_polys_on_image(corners, image, labels=None, box_padding=-1, label_offset=1, label_font_size=10, color: str | list = 'red'): 24 | draw = ImageDraw.Draw(image) 25 | font_path = get_font_path() 26 | label_font = ImageFont.truetype(font_path, label_font_size) 27 | 28 | for i in range(len(corners)): 29 | poly = corners[i] 30 | poly = [(int(p[0]), int(p[1])) for p in poly] 31 | draw.polygon(poly, outline=color[i] if isinstance(color, list) else color, width=1) 32 | 33 | if labels is not None: 34 | label = labels[i] 35 | text_position = ( 36 | min([p[0] for p in poly]) + label_offset, 37 | min([p[1] for p in poly]) + label_offset 38 | ) 39 | text_size = get_text_size(label, label_font) 40 | box_position = ( 41 | text_position[0] - box_padding + label_offset, 42 | text_position[1] - box_padding + label_offset, 43 | text_position[0] + text_size[0] + box_padding + label_offset, 44 | text_position[1] + text_size[1] + box_padding + label_offset 45 | ) 46 | draw.rectangle(box_position, fill="white") 47 | draw.text( 48 | text_position, 49 | label, 50 | fill=color[i] if isinstance(color, list) else color, 51 | font=label_font 52 | ) 53 | 54 | return image 55 | -------------------------------------------------------------------------------- 
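A minimal usage sketch for the debug drawing helper defined above (editor addition; the image path, boxes, and labels are illustrative assumptions, not repository fixtures):

from PIL import Image
from surya.debug.draw import draw_bboxes_on_image

# Assumed inputs: any local page image plus hand-written [x1, y1, x2, y2] boxes.
image = Image.open("page.png").convert("RGB")
boxes = [[40, 50, 320, 90], [40, 110, 500, 150]]
labels = ["0", "1"]

# Draws each box (and its label) directly onto the image and returns it.
annotated = draw_bboxes_on_image(boxes, image, labels=labels, label_font_size=12, color="blue")
annotated.save("page_boxes.png")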
/surya/debug/fonts.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | import os 3 | import requests 4 | 5 | from surya.settings import settings 6 | 7 | 8 | def get_font_path(langs: Optional[List[str]] = None) -> str: 9 | font_path = settings.RECOGNITION_RENDER_FONTS["all"] 10 | if langs is not None: 11 | for k in settings.RECOGNITION_RENDER_FONTS: 12 | if k in langs and len(langs) == 1: 13 | font_path = settings.RECOGNITION_RENDER_FONTS[k] 14 | break 15 | 16 | if not os.path.exists(font_path): 17 | os.makedirs(os.path.dirname(font_path), exist_ok=True) 18 | font_dl_path = f"{settings.RECOGNITION_FONT_DL_BASE}/{os.path.basename(font_path)}" 19 | with requests.get(font_dl_path, stream=True) as r, open(font_path, 'wb') as f: 20 | r.raise_for_status() 21 | for chunk in r.iter_content(chunk_size=8192): 22 | f.write(chunk) 23 | 24 | return font_path -------------------------------------------------------------------------------- /surya/debug/katex.js: -------------------------------------------------------------------------------- 1 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /surya/debug/render_html.py: -------------------------------------------------------------------------------- 1 | import html as htmllib 2 | import os.path 3 | import re 4 | 5 | filepath = os.path.abspath(__file__) 6 | 7 | def render_text_as_html( 8 | bboxes: list[list[int]], 9 | texts: list[str], 10 | image_size: tuple[int, int], 11 | base_font_size: int = 16, 12 | scaler: int = 2 13 | ): 14 | katex_path = os.path.join(os.path.dirname(filepath), "katex.js") 15 | with open(katex_path, "r") as f: 16 | katex_script = f.read() 17 | 18 | html_content = [] 19 | image_size = tuple([int(s * scaler) for s in image_size]) 20 | width, height = image_size 21 | 22 | 23 | html_content.append(f""" 24 | 25 | 26 | 27 | 50 | {katex_script} 51 | 52 | 53 | """) 54 | 55 | for i, (bbox, text) in enumerate(zip(bboxes, texts)): 56 | bbox = bbox.copy() 57 | bbox = [int(bb * scaler) for bb in bbox] 58 | x1, y1, x2, y2 = bbox 59 | width = x2 - x1 60 | height = y2 - y1 61 | min_dim = min(width, height) 62 | 63 | # Scale font size based on box height 64 | font_size = min(int(min_dim * 0.75), base_font_size) 65 | 66 | # Create div with absolute positioning 67 | div_style = ( 68 | f"left: {x1}px; " 69 | f"top: {y1}px; " 70 | f"width: {width}px; " 71 | f"height: {height}px; " 72 | f"font-size: {font_size}px;" 73 | ) 74 | 75 | class_ = "text-box" 76 | if height > width * 2: 77 | class_ += " vertical-text" 78 | 79 | # Determine if content is HTML/MathML or plain text 80 | if "<" in text and ">" in text and re.search(r"<(html|math|div|sub|sup|i|u|mark|small|del|b|br|code)\b", text.lower()): 81 | # Content is already HTML/MathML, include as-is 82 | html_content.append(f'{text}') 83 | else: 84 | # Plain text, escape it 85 | escaped_text = htmllib.escape(text) 86 | html_content.append(f'{escaped_text}') 87 | 88 | html_content.append("") 89 | 90 | return "\n".join(html_content), image_size -------------------------------------------------------------------------------- /surya/debug/text.py: -------------------------------------------------------------------------------- 1 | import re 2 | from io import BytesIO 3 | from typing import List, Tuple 4 | from PIL import Image, ImageDraw, ImageFont 5 | 6 | from surya.debug.fonts import get_font_path 7 | from surya.debug.render_html import render_text_as_html 8 | 9 | try: 10 | from 
playwright.sync_api import sync_playwright 11 | 12 | has_playwright = True 13 | except ImportError: 14 | has_playwright = False 15 | 16 | 17 | def strip_html_tags(html_text): 18 | pattern = re.compile(r"<[\w/][^>]*>") 19 | text_only = pattern.sub("", html_text) 20 | 21 | return text_only 22 | 23 | 24 | def get_text_size(text, font): 25 | im = Image.new(mode="P", size=(0, 0)) 26 | draw = ImageDraw.Draw(im) 27 | _, _, width, height = draw.textbbox((0, 0), text=text, font=font) 28 | return width, height 29 | 30 | 31 | def render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size): 32 | font = ImageFont.truetype(font_path, box_font_size) 33 | text_width, text_height = get_text_size(text, font) 34 | while (text_width > bbox_width or text_height > bbox_height) and box_font_size > 6: 35 | box_font_size = box_font_size - 1 36 | font = ImageFont.truetype(font_path, box_font_size) 37 | text_width, text_height = get_text_size(text, font) 38 | 39 | # Calculate text position (centered in bbox) 40 | text_width, text_height = get_text_size(text, font) 41 | x = s_bbox[0] 42 | y = s_bbox[1] + (bbox_height - text_height) / 2 43 | 44 | draw.text((x, y), text, fill="black", font=font) 45 | 46 | 47 | def draw_text_with_playwright( 48 | bboxes, texts: List[str], image_size: Tuple[int, int] 49 | ) -> Image.Image: 50 | html_content, image_size = render_text_as_html(bboxes, texts, image_size) 51 | if not has_playwright: 52 | raise ImportError( 53 | "Playwright is not installed. Please install it using `pip install playwright`" 54 | ) 55 | 56 | with sync_playwright() as p: 57 | browser = p.chromium.launch(headless=True) 58 | page = browser.new_page( 59 | viewport={"width": image_size[0], "height": image_size[1]} 60 | ) 61 | page.set_content(html_content) 62 | page.wait_for_timeout(1000) 63 | body = page.query_selector("body") 64 | image = body.screenshot() 65 | browser.close() 66 | 67 | pil_img = Image.open(BytesIO(image)) 68 | return pil_img 69 | 70 | 71 | def draw_text_on_image( 72 | bboxes, 73 | texts, 74 | image_size: Tuple[int, int], 75 | font_path=None, 76 | max_font_size=60, 77 | res_upscale=2, 78 | ) -> Image.Image: 79 | if has_playwright: 80 | return draw_text_with_playwright(bboxes, texts, image_size) 81 | 82 | texts = [strip_html_tags(text) for text in texts] 83 | if font_path is None: 84 | font_path = get_font_path() 85 | new_image_size = (image_size[0] * res_upscale, image_size[1] * res_upscale) 86 | image = Image.new("RGB", new_image_size, color="white") 87 | draw = ImageDraw.Draw(image) 88 | 89 | for bbox, text in zip(bboxes, texts): 90 | s_bbox = [int(coord * res_upscale) for coord in bbox] 91 | bbox_width = s_bbox[2] - s_bbox[0] 92 | bbox_height = s_bbox[3] - s_bbox[1] 93 | 94 | # Shrink the text to fit in the bbox if needed 95 | box_font_size = max(6, min(int(0.75 * bbox_height), max_font_size)) 96 | render_text( 97 | draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size 98 | ) 99 | 100 | return image 101 | -------------------------------------------------------------------------------- /surya/detection/loader.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | 5 | from surya.common.load import ModelLoader 6 | from surya.detection.processor import SegformerImageProcessor 7 | 8 | from surya.detection.model.config import EfficientViTConfig 9 | from surya.detection.model.encoderdecoder import EfficientViTForSemanticSegmentation 10 | from surya.logging import get_logger 11 
| from surya.settings import settings 12 | 13 | logger = get_logger() 14 | 15 | 16 | class DetectionModelLoader(ModelLoader): 17 | def __init__(self, checkpoint: Optional[str] = None): 18 | super().__init__(checkpoint) 19 | 20 | if self.checkpoint is None: 21 | self.checkpoint = settings.DETECTOR_MODEL_CHECKPOINT 22 | 23 | def model( 24 | self, 25 | device: Optional[torch.device | str] = None, 26 | dtype: Optional[torch.dtype | str] = None, 27 | ) -> EfficientViTForSemanticSegmentation: 28 | if device is None: 29 | device = settings.TORCH_DEVICE_MODEL 30 | if dtype is None: 31 | dtype = settings.MODEL_DTYPE 32 | 33 | config = EfficientViTConfig.from_pretrained(self.checkpoint) 34 | model = EfficientViTForSemanticSegmentation.from_pretrained( 35 | self.checkpoint, 36 | torch_dtype=dtype, 37 | config=config, 38 | ) 39 | model = model.to(device) 40 | model = model.eval() 41 | 42 | if settings.COMPILE_ALL or settings.COMPILE_DETECTOR: 43 | torch.set_float32_matmul_precision("high") 44 | torch._dynamo.config.cache_size_limit = 1 45 | torch._dynamo.config.suppress_errors = False 46 | 47 | logger.info( 48 | f"Compiling detection model {self.checkpoint} on device {device} with dtype {dtype}" 49 | ) 50 | compile_args = {"backend": "openxla"} if device == "xla" else {} 51 | model = torch.compile(model, **compile_args) 52 | 53 | logger.debug( 54 | f"Loaded detection model {self.checkpoint} from {EfficientViTForSemanticSegmentation.get_local_path(self.checkpoint)} onto device {device} with dtype {dtype}" 55 | ) 56 | return model 57 | 58 | def processor( 59 | self, 60 | device: Optional[torch.device | str] = None, 61 | dtype: Optional[torch.dtype | str] = None, 62 | ) -> SegformerImageProcessor: 63 | return SegformerImageProcessor.from_pretrained(self.checkpoint) 64 | -------------------------------------------------------------------------------- /surya/detection/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/surya/detection/model/__init__.py -------------------------------------------------------------------------------- /surya/detection/model/config.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | from surya.common.s3 import S3DownloaderMixin 4 | 5 | 6 | class EfficientViTConfig(S3DownloaderMixin, PretrainedConfig): 7 | r""" 8 | ```""" 9 | 10 | model_type = "efficientvit" 11 | 12 | def __init__( 13 | self, 14 | num_classes=2, 15 | num_channels=3, 16 | widths=(32, 64, 128, 256, 512), 17 | head_dim=32, 18 | num_stages=4, 19 | depths=(1, 1, 1, 6, 6), 20 | strides=(2, 2, 2, 2, 2), 21 | hidden_sizes=(32, 64, 160, 256), 22 | patch_size=(7, 7), 23 | hidden_dropout_prob=0.0, 24 | attention_probs_dropout_prob=0.0, 25 | classifier_dropout_prob=0.0, 26 | layer_norm_eps=1e-6, 27 | decoder_layer_hidden_size=128, 28 | decoder_hidden_size=512, 29 | semantic_loss_ignore_index=255, 30 | initializer_range=0.02, 31 | **kwargs, 32 | ): 33 | super().__init__(**kwargs) 34 | 35 | self.num_classes = num_classes 36 | self.widths = widths 37 | self.head_dim = head_dim 38 | 39 | self.num_channels = num_channels 40 | self.num_stages = num_stages 41 | self.depths = depths 42 | self.strides = strides 43 | self.hidden_sizes = hidden_sizes 44 | self.patch_size = patch_size 45 | self.hidden_dropout_prob = hidden_dropout_prob 46 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 47 | 
self.classifier_dropout_prob = classifier_dropout_prob 48 | self.layer_norm_eps = layer_norm_eps 49 | self.decoder_hidden_size = decoder_hidden_size 50 | self.decoder_layer_hidden_size = decoder_layer_hidden_size 51 | self.semantic_loss_ignore_index = semantic_loss_ignore_index 52 | 53 | self.initializer_range = initializer_range -------------------------------------------------------------------------------- /surya/detection/parallel.py: -------------------------------------------------------------------------------- 1 | class FakeFuture: 2 | def __init__(self, func, *args, **kwargs): 3 | self._result = func(*args, **kwargs) 4 | 5 | def result(self): 6 | return self._result 7 | 8 | class FakeExecutor: 9 | def __init__(self, **kwargs): 10 | pass 11 | 12 | def __enter__(self): 13 | return self 14 | 15 | def __exit__(self, *excinfo): 16 | pass 17 | 18 | def submit(self, fn, *args, **kwargs): 19 | return FakeFuture(fn, *args, **kwargs) 20 | -------------------------------------------------------------------------------- /surya/detection/schema.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Any 2 | 3 | from pydantic import BaseModel 4 | 5 | from surya.common.polygon import PolygonBox 6 | 7 | 8 | class ColumnLine(PolygonBox): 9 | vertical: bool 10 | horizontal: bool 11 | 12 | class TextDetectionResult(BaseModel): 13 | bboxes: List[PolygonBox] 14 | vertical_lines: List[ColumnLine] 15 | heatmap: Optional[Any] 16 | affinity_map: Optional[Any] 17 | image_bbox: List[float] 18 | -------------------------------------------------------------------------------- /surya/detection/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | from PIL import ImageOps 3 | 4 | from surya.settings import settings 5 | 6 | 7 | def get_total_splits(image_size, height): 8 | img_height = list(image_size)[1] 9 | max_height = settings.DETECTOR_IMAGE_CHUNK_HEIGHT 10 | if img_height > max_height: 11 | num_splits = math.ceil(img_height / height) 12 | return num_splits 13 | return 1 14 | 15 | 16 | def split_image(img, height): 17 | # This will not modify/return the original image - it will either crop, or copy the image 18 | img_height = list(img.size)[1] 19 | max_height = settings.DETECTOR_IMAGE_CHUNK_HEIGHT 20 | if img_height > max_height: 21 | num_splits = math.ceil(img_height / height) 22 | splits = [] 23 | split_heights = [] 24 | for i in range(num_splits): 25 | top = i * height 26 | bottom = (i + 1) * height 27 | if bottom > img_height: 28 | bottom = img_height 29 | cropped = img.crop((0, top, img.size[0], bottom)) 30 | chunk_height = bottom - top 31 | if chunk_height < height: 32 | cropped = ImageOps.pad(cropped, (img.size[0], height), color=255, centering=(0, 0)) 33 | splits.append(cropped) 34 | split_heights.append(chunk_height) 35 | return splits, split_heights 36 | return [img.copy()], [img_height] 37 | -------------------------------------------------------------------------------- /surya/input/load.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import PIL 3 | 4 | from surya.input.processing import open_pdf, get_page_images 5 | from surya.logging import get_logger 6 | from surya.settings import settings 7 | import os 8 | import filetype 9 | from PIL import Image 10 | import json 11 | 12 | logger = get_logger() 13 | 14 | 15 | def get_name_from_path(path): 16 | return os.path.basename(path).split(".")[0] 17 | 18 | 19 | 
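# Hedged sketch (editor addition, not repository code): split_image above only
# chunks a page when it is taller than settings.DETECTOR_IMAGE_CHUNK_HEIGHT, and it
# pads the final chunk up to the requested height. For example (sizes are assumed):
#
#   from PIL import Image
#   from surya.detection.util import split_image, get_total_splits
#
#   page = Image.new("RGB", (1000, 3000), "white")   # synthetic tall page
#   chunks, chunk_heights = split_image(page, height=1200)
#   assert len(chunks) == get_total_splits(page.size, 1200)
#
# Whether the page is actually split depends on the configured chunk height setting.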
def load_pdf(pdf_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI): 20 | doc = open_pdf(pdf_path) 21 | last_page = len(doc) 22 | 23 | if page_range: 24 | assert all([0 <= page < last_page for page in page_range]), ( 25 | f"Invalid page range: {page_range}" 26 | ) 27 | else: 28 | page_range = list(range(last_page)) 29 | 30 | images = get_page_images(doc, page_range, dpi=dpi) 31 | doc.close() 32 | names = [get_name_from_path(pdf_path) for _ in page_range] 33 | return images, names 34 | 35 | 36 | def load_image(image_path): 37 | image = Image.open(image_path).convert("RGB") 38 | name = get_name_from_path(image_path) 39 | return [image], [name] 40 | 41 | 42 | def load_from_file( 43 | input_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI 44 | ): 45 | input_type = filetype.guess(input_path) 46 | if input_type and input_type.extension == "pdf": 47 | return load_pdf(input_path, page_range, dpi=dpi) 48 | else: 49 | return load_image(input_path) 50 | 51 | 52 | def load_from_folder( 53 | folder_path, page_range: List[int] | None = None, dpi=settings.IMAGE_DPI 54 | ): 55 | image_paths = [ 56 | os.path.join(folder_path, image_name) 57 | for image_name in os.listdir(folder_path) 58 | if not image_name.startswith(".") 59 | ] 60 | image_paths = [ip for ip in image_paths if not os.path.isdir(ip)] 61 | 62 | images = [] 63 | names = [] 64 | for path in image_paths: 65 | extension = filetype.guess(path) 66 | if extension and extension.extension == "pdf": 67 | image, name = load_pdf(path, page_range, dpi=dpi) 68 | images.extend(image) 69 | names.extend(name) 70 | else: 71 | try: 72 | image, name = load_image(path) 73 | images.extend(image) 74 | names.extend(name) 75 | except PIL.UnidentifiedImageError: 76 | logger.warning(f"Could not load image {path}") 77 | continue 78 | return images, names 79 | 80 | 81 | def load_lang_file(lang_path, names): 82 | with open(lang_path, "r") as f: 83 | lang_dict = json.load(f) 84 | return [lang_dict[name].copy() for name in names] 85 | -------------------------------------------------------------------------------- /surya/input/processing.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import cv2 4 | import numpy as np 5 | import pypdfium2 6 | from PIL import Image 7 | 8 | from surya.logging import get_logger 9 | from surya.settings import settings 10 | 11 | logger = get_logger() 12 | 13 | 14 | def convert_if_not_rgb(images: List[Image.Image]) -> List[Image.Image]: 15 | new_images = [] 16 | for image in images: 17 | if image.mode != "RGB": 18 | image = image.convert("RGB") 19 | new_images.append(image) 20 | return new_images 21 | 22 | 23 | def open_pdf(pdf_filepath): 24 | return pypdfium2.PdfDocument(pdf_filepath) 25 | 26 | 27 | def get_page_images(doc, indices: List, dpi=settings.IMAGE_DPI): 28 | images = [ 29 | doc[i].render(scale=dpi / 72, draw_annots=False).to_pil() for i in indices 30 | ] 31 | images = [image.convert("RGB") for image in images] 32 | return images 33 | 34 | 35 | def slice_bboxes_from_image(image: np.ndarray, bboxes): 36 | lines = [] 37 | for bbox in bboxes: 38 | bbox = np.array(bbox, dtype=np.int32) 39 | bbox = np.clip(bbox, 0, None) # Ensure no negative indices 40 | # Ensure bbox is within the image bounds 41 | if bbox[3] <= bbox[1]: 42 | bbox[3] = bbox[1] + 1 43 | 44 | if bbox[2] <= bbox[0]: 45 | bbox[2] = bbox[0] + 1 46 | 47 | bbox[2] = min(bbox[2], image.shape[1]) 48 | bbox[3] = min(bbox[3], image.shape[0]) 49 | 50 | line = image[bbox[1] : 
bbox[3], bbox[0] : bbox[2]].copy() 51 | if line.size == 0: 52 | logger.warning(f"Warning: found an empty line with bbox {bbox}") 53 | lines.append(line) 54 | return lines 55 | 56 | 57 | def slice_polys_from_image(image: np.ndarray, polys): 58 | lines = [] 59 | for idx, poly in enumerate(polys): 60 | lines.append(slice_and_pad_poly(image, poly)) 61 | return lines 62 | 63 | 64 | def slice_and_pad_poly(image_array: np.array, coordinates): 65 | # Draw polygon onto mask 66 | coordinates = [(corner[0], corner[1]) for corner in coordinates] 67 | bbox = [ 68 | min([x[0] for x in coordinates]), 69 | min([x[1] for x in coordinates]), 70 | max([x[0] for x in coordinates]), 71 | max([x[1] for x in coordinates]), 72 | ] 73 | 74 | # We mask out anything not in the polygon 75 | cropped_polygon = image_array[bbox[1] : bbox[3], bbox[0] : bbox[2]].copy() 76 | height, width = cropped_polygon.shape[:2] 77 | 78 | coordinates = [(x - bbox[0], y - bbox[1]) for x, y in coordinates] 79 | 80 | # Validate the cropped area 81 | if any( 82 | [ 83 | bbox[3] <= bbox[1] or bbox[2] <= bbox[0], 84 | len(coordinates) < 3, 85 | height == 0, 86 | width == 0, 87 | ] 88 | ): 89 | return cropped_polygon 90 | 91 | # Pad the area outside the polygon with the pad value 92 | try: 93 | mask = np.zeros(cropped_polygon.shape[:2], dtype=np.uint8) 94 | cv2.fillPoly(mask, [np.int32(coordinates)], 1) 95 | mask = np.stack([mask] * 3, axis=-1) 96 | 97 | cropped_polygon[mask == 0] = settings.RECOGNITION_PAD_VALUE 98 | except cv2.error as e: 99 | logger.warning(f"Warning: issue while processing polygon: {e}") 100 | 101 | return cropped_polygon 102 | -------------------------------------------------------------------------------- /surya/layout/loader.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | 5 | from surya.common.donut.processor import SuryaEncoderImageProcessor 6 | from surya.common.load import ModelLoader 7 | from surya.layout.model.config import ( 8 | SuryaLayoutConfig, 9 | SuryaLayoutDecoderConfig, 10 | DonutSwinLayoutConfig, 11 | ) 12 | from surya.layout.model.encoderdecoder import SuryaLayoutModel 13 | from surya.logging import get_logger 14 | from surya.settings import settings 15 | 16 | logger = get_logger() 17 | 18 | 19 | class LayoutModelLoader(ModelLoader): 20 | def __init__(self, checkpoint: Optional[str] = None): 21 | super().__init__(checkpoint) 22 | 23 | if self.checkpoint is None: 24 | self.checkpoint = settings.LAYOUT_MODEL_CHECKPOINT 25 | 26 | def model( 27 | self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE 28 | ) -> SuryaLayoutModel: 29 | if device is None: 30 | device = settings.TORCH_DEVICE_MODEL 31 | if dtype is None: 32 | dtype = settings.MODEL_DTYPE 33 | 34 | config = SuryaLayoutConfig.from_pretrained(self.checkpoint) 35 | decoder_config = config.decoder 36 | decoder = SuryaLayoutDecoderConfig(**decoder_config) 37 | config.decoder = decoder 38 | 39 | encoder_config = config.encoder 40 | encoder = DonutSwinLayoutConfig(**encoder_config) 41 | config.encoder = encoder 42 | 43 | model = SuryaLayoutModel.from_pretrained( 44 | self.checkpoint, config=config, torch_dtype=dtype 45 | ) 46 | model = model.to(device) 47 | model = model.eval() 48 | 49 | if settings.COMPILE_ALL or settings.COMPILE_LAYOUT: 50 | torch.set_float32_matmul_precision("high") 51 | torch._dynamo.config.cache_size_limit = 16 52 | torch._dynamo.config.suppress_errors = False 53 | 54 | logger.info( 55 | f"Compiling layout model 
{self.checkpoint} on device {device} with dtype {dtype}" 56 | ) 57 | compile_args = {"backend": "openxla"} if device == "xla" else {} 58 | model.encoder = torch.compile(model.encoder, **compile_args) 59 | model.decoder = torch.compile(model.decoder, **compile_args) 60 | 61 | logger.debug( 62 | f"Loaded layout model {self.checkpoint} from {SuryaLayoutModel.get_local_path(self.checkpoint)} onto device {device} with dtype {dtype}" 63 | ) 64 | return model 65 | 66 | def processor( 67 | self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE 68 | ) -> SuryaEncoderImageProcessor: 69 | processor = SuryaEncoderImageProcessor(max_size=settings.LAYOUT_IMAGE_SIZE) 70 | return processor 71 | -------------------------------------------------------------------------------- /surya/layout/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/surya/layout/model/__init__.py -------------------------------------------------------------------------------- /surya/layout/model/decoder.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.utils.checkpoint 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | from surya.common.adetr.decoder import SuryaADETRDecoderModel, SuryaADETRDecoderPreTrainedModel 9 | from surya.layout.model.config import LayoutModelOutput 10 | from transformers.modeling_outputs import CausalLMOutput 11 | from surya.settings import settings 12 | 13 | 14 | class BboxEmbedding(nn.Module): 15 | def __init__(self, config): 16 | super().__init__() 17 | self.w_embed = nn.Embedding(config.vocab_size, config.hidden_size) 18 | self.h_embed = nn.Embedding(config.vocab_size, config.hidden_size) 19 | self.cx_embed = nn.Embedding(config.vocab_size, config.hidden_size) 20 | self.cy_embed = nn.Embedding(config.vocab_size, config.hidden_size) 21 | self.xskew_embed = nn.Embedding(config.vocab_size, config.hidden_size) 22 | self.yskew_embed = nn.Embedding(config.vocab_size, config.hidden_size) 23 | self.label_embed = nn.Embedding(config.label_count, config.hidden_size) 24 | 25 | self.x1_embed = nn.Embedding(config.vocab_size, config.hidden_size) 26 | self.y1_embed = nn.Embedding(config.vocab_size, config.hidden_size) 27 | self.x2_embed = nn.Embedding(config.vocab_size, config.hidden_size) 28 | self.y2_embed = nn.Embedding(config.vocab_size, config.hidden_size) 29 | self.x3_embed = nn.Embedding(config.vocab_size, config.hidden_size) 30 | self.y3_embed = nn.Embedding(config.vocab_size, config.hidden_size) 31 | self.x4_embed = nn.Embedding(config.vocab_size, config.hidden_size) 32 | self.y4_embed = nn.Embedding(config.vocab_size, config.hidden_size) 33 | 34 | self.config = config 35 | 36 | def forward(self, boxes: torch.LongTensor, input_box_counts: torch.LongTensor): 37 | cx, cy, w, h, xskew, yskew, label = boxes.to(torch.long).unbind(dim=-1) 38 | 39 | xskew_actual = ((xskew - self.config.bbox_size // 2) / 2).to(torch.long) 40 | yskew_actual = ((yskew - self.config.bbox_size // 2) / 2).to(torch.long) 41 | 42 | x1 = (cx - w // 2 - xskew_actual).clamp(0, self.config.bbox_size).to(torch.long) 43 | y1 = (cy - h // 2 - yskew_actual).clamp(0, self.config.bbox_size).to(torch.long) 44 | x2 = (cx + w // 2 - xskew_actual).clamp(0, self.config.bbox_size).to(torch.long) 45 | y2 = (cy + h // 2 + yskew_actual).clamp(0, 
self.config.bbox_size).to(torch.long) 46 | x3 = (cx + w // 2 + xskew_actual).clamp(0, self.config.bbox_size).to(torch.long) 47 | y3 = (cy + h // 2 + yskew_actual).clamp(0, self.config.bbox_size).to(torch.long) 48 | x4 = (cx - w // 2 + xskew_actual).clamp(0, self.config.bbox_size).to(torch.long) 49 | y4 = (cy - h // 2 - yskew_actual).clamp(0, self.config.bbox_size).to(torch.long) 50 | 51 | label_embeds = self.label_embed(label) 52 | size_embeds = self.w_embed(w) + self.h_embed(h) + self.cx_embed(cx) + self.cy_embed(cy) 53 | skew_embeds = self.xskew_embed(xskew) + self.yskew_embed(yskew) 54 | corner_embeds = self.x1_embed(x1) + self.y1_embed(y1) + self.x2_embed(x2) + self.y2_embed(y2) + self.x3_embed(x3) + self.y3_embed(y3) + self.x4_embed(x4) + self.y4_embed(y4) 55 | embedded = label_embeds + size_embeds + skew_embeds + corner_embeds 56 | 57 | return embedded 58 | 59 | 60 | class SuryaLayoutDecoder(SuryaADETRDecoderPreTrainedModel): 61 | _tied_weights_keys = None 62 | 63 | def __init__(self, config, **kwargs): 64 | super().__init__(config) 65 | embed_tokens = BboxEmbedding(config) 66 | self.model = SuryaADETRDecoderModel( 67 | config, 68 | embedder=embed_tokens, 69 | static_cache=settings.LAYOUT_STATIC_CACHE, 70 | max_boxes=settings.LAYOUT_MAX_BOXES 71 | ) 72 | self.vocab_size = config.vocab_size 73 | self.lm_head = nn.Linear(config.hidden_size, config.label_count, bias=False) 74 | self.bbox_head = nn.Linear(config.hidden_size, 6, bias=True) 75 | self.pre_output_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 76 | 77 | self.bbox_size = config.bbox_size 78 | self.label_count = config.label_count 79 | # Initialize weights and apply final processing 80 | self.post_init() 81 | 82 | def get_input_embeddings(self): 83 | return self.model.embed_tokens 84 | 85 | def set_input_embeddings(self, value): 86 | self.model.embed_tokens = value 87 | 88 | def set_decoder(self, decoder): 89 | self.model = decoder 90 | 91 | def get_decoder(self): 92 | return self.model 93 | 94 | # Ignore copy 95 | def forward( 96 | self, 97 | input_boxes: torch.LongTensor = None, 98 | input_boxes_counts: torch.LongTensor = None, 99 | cache_position: Optional[torch.LongTensor] = None, 100 | attention_mask: Optional[torch.Tensor] = None, 101 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 102 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 103 | use_cache: Optional[bool] = None, 104 | **kwargs 105 | ) -> Union[Tuple, CausalLMOutput]: 106 | outputs = self.model( 107 | input_ids=input_boxes, 108 | input_boxes_counts=input_boxes_counts, 109 | cache_position=cache_position, 110 | attention_mask=attention_mask, 111 | encoder_hidden_states=encoder_hidden_states, 112 | encoder_attention_mask=encoder_attention_mask, 113 | use_cache=use_cache, 114 | output_hidden_states=True, 115 | return_dict=True, 116 | ) 117 | 118 | hidden_states = self.pre_output_norm(outputs[0]) 119 | class_logits = self.lm_head(hidden_states) 120 | bbox_logits = F.sigmoid(self.bbox_head(hidden_states)) 121 | 122 | return LayoutModelOutput( 123 | bbox_logits=bbox_logits, 124 | class_logits=class_logits, 125 | hidden_states=outputs.hidden_states, 126 | ) -------------------------------------------------------------------------------- /surya/layout/model/encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | from surya.common.donut.encoder import DonutSwinPreTrainedModel, DonutSwinModelOutput, DonutSwinEmbeddings, DonutSwinEncoder 3 | 
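# DonutSwinLayoutModel below wraps the shared Donut Swin encoder and, in its
# forward() pass, adds a learned position embedding to the encoder's last
# hidden state before returning it.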
import torch 4 | from torch import nn 5 | 6 | 7 | class DonutSwinLayoutModel(DonutSwinPreTrainedModel): 8 | def __init__(self, config, add_pooling_layer=True, use_mask_token=False): 9 | super().__init__(config) 10 | self.config = config 11 | self.num_layers = len(config.depths) 12 | self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) 13 | 14 | self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token) 15 | self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid) 16 | 17 | self.position_embeddings = nn.Parameter(torch.zeros(1, config.encoder_length, config.hidden_size)) 18 | 19 | # Initialize weights and apply final processing 20 | self.post_init() 21 | 22 | def get_input_embeddings(self): 23 | return self.embeddings.patch_embeddings 24 | 25 | def _prune_heads(self, heads_to_prune): 26 | """ 27 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 28 | class PreTrainedModel 29 | """ 30 | for layer, heads in heads_to_prune.items(): 31 | self.encoder.layer[layer].attention.prune_heads(heads) 32 | 33 | def forward( 34 | self, 35 | pixel_values: Optional[torch.FloatTensor] = None, 36 | bool_masked_pos: Optional[torch.BoolTensor] = None, 37 | head_mask: Optional[torch.FloatTensor] = None, 38 | output_attentions: Optional[bool] = None, 39 | output_hidden_states: Optional[bool] = None, 40 | interpolate_pos_encoding: bool = False, 41 | return_dict: Optional[bool] = None, 42 | ) -> Union[Tuple, DonutSwinModelOutput]: 43 | r""" 44 | bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): 45 | Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). 46 | """ 47 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 48 | output_hidden_states = ( 49 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 50 | ) 51 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 52 | 53 | if pixel_values is None: 54 | raise ValueError("You have to specify pixel_values") 55 | 56 | # Prepare head mask if needed 57 | # 1.0 in head_mask indicate we keep the head 58 | # attention_probs has shape bsz x n_heads x N x N 59 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 60 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 61 | head_mask = self.get_head_mask(head_mask, len(self.config.depths)) 62 | 63 | embedding_output, input_dimensions = self.embeddings( 64 | pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding 65 | ) 66 | 67 | encoder_outputs = self.encoder( 68 | embedding_output, 69 | input_dimensions, 70 | head_mask=head_mask, 71 | output_attentions=output_attentions, 72 | output_hidden_states=output_hidden_states, 73 | return_dict=return_dict, 74 | ) 75 | 76 | last_hidden_state = encoder_outputs[0] 77 | last_hidden_state += self.position_embeddings[:, :last_hidden_state.size(1), :] 78 | 79 | return DonutSwinModelOutput( 80 | last_hidden_state=last_hidden_state, 81 | ) -------------------------------------------------------------------------------- /surya/layout/model/encoderdecoder.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Union, Tuple 3 | 4 | import torch 5 | from transformers import PreTrainedModel, 
VisionEncoderDecoderConfig, PretrainedConfig 6 | from transformers.modeling_outputs import BaseModelOutput 7 | from surya.common.s3 import S3DownloaderMixin 8 | from surya.layout.model.encoder import DonutSwinLayoutModel 9 | from surya.layout.model.decoder import SuryaLayoutDecoder 10 | from transformers.utils import ModelOutput 11 | 12 | @dataclass 13 | class LayoutBboxOutput(ModelOutput): 14 | bbox_logits: torch.FloatTensor = None 15 | class_logits: torch.FloatTensor = None 16 | decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None 17 | encoder_last_hidden_state: Optional[torch.FloatTensor] = None 18 | 19 | 20 | class SuryaLayoutModel(S3DownloaderMixin, PreTrainedModel): 21 | config_class = VisionEncoderDecoderConfig 22 | base_model_prefix = "vision_encoder_decoder" 23 | main_input_name = "pixel_values" 24 | supports_gradient_checkpointing = True 25 | _supports_param_buffer_assignment = False 26 | 27 | def __init__( 28 | self, 29 | config: Optional[PretrainedConfig] = None, 30 | encoder: Optional[PreTrainedModel] = None, 31 | decoder: Optional[PreTrainedModel] = None, 32 | ): 33 | # initialize with config 34 | # make sure input & output embeddings is not tied 35 | config.tie_word_embeddings = False 36 | config.decoder.tie_word_embeddings = False 37 | super().__init__(config) 38 | 39 | if encoder is None: 40 | encoder = DonutSwinLayoutModel(config.encoder) 41 | 42 | if decoder is None: 43 | decoder = SuryaLayoutDecoder(config.decoder, attn_implementation=config._attn_implementation) 44 | 45 | self.encoder = encoder 46 | self.decoder = decoder 47 | 48 | # make sure that the individual model's config refers to the shared config 49 | # so that the updates to the config will be synced 50 | self.encoder.config = self.config.encoder 51 | self.decoder.config = self.config.decoder 52 | 53 | def get_encoder(self): 54 | return self.encoder 55 | 56 | def get_decoder(self): 57 | return self.decoder 58 | 59 | def get_output_embeddings(self): 60 | return self.decoder.get_output_embeddings() 61 | 62 | def set_output_embeddings(self, new_embeddings): 63 | return self.decoder.set_output_embeddings(new_embeddings) 64 | 65 | def forward( 66 | self, 67 | pixel_values: Optional[torch.FloatTensor] = None, 68 | decoder_input_boxes: torch.LongTensor = None, # Shape (batch_size, num_boxes, 7), first 6 values all coords scaled 0 - 1024, with 1025 as padding, last value is the label, 0 to 11 69 | decoder_cache_position: Optional[torch.LongTensor] = None, 70 | decoder_attention_mask: Optional[torch.BoolTensor] = None, 71 | decoder_input_boxes_counts: torch.LongTensor = None, # Shape (batch_size), number of boxes in each image 72 | encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, 73 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, 74 | labels: Optional[torch.LongTensor] = None, 75 | use_cache: Optional[bool] = None, 76 | output_attentions: Optional[bool] = None, 77 | output_hidden_states: Optional[bool] = None, 78 | return_dict: Optional[bool] = None, 79 | **kwargs, 80 | ) -> Union[Tuple[torch.FloatTensor], LayoutBboxOutput]: 81 | kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} 82 | 83 | kwargs_decoder = { 84 | argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") 85 | } 86 | 87 | if encoder_outputs is None: 88 | if pixel_values is None: 89 | raise ValueError("You have to specify pixel_values") 90 | 91 | encoder_outputs = self.encoder( 92 | 
pixel_values=pixel_values, 93 | **kwargs_encoder, 94 | ) 95 | elif isinstance(encoder_outputs, tuple): 96 | encoder_outputs = BaseModelOutput(*encoder_outputs) 97 | 98 | encoder_hidden_states = encoder_outputs[0] 99 | 100 | # We need a start token as the first token 101 | assert decoder_input_boxes[0][0][0] == self.config.decoder_start_token_id 102 | assert decoder_input_boxes[0][0].shape == (7,) 103 | 104 | decoder_outputs = self.decoder( 105 | input_boxes=decoder_input_boxes, 106 | input_boxes_counts=decoder_input_boxes_counts, 107 | cache_position=decoder_cache_position, 108 | attention_mask=decoder_attention_mask, 109 | encoder_hidden_states=encoder_hidden_states, 110 | encoder_attention_mask=None, 111 | use_cache=use_cache, 112 | **kwargs_decoder, 113 | ) 114 | 115 | return LayoutBboxOutput( 116 | bbox_logits=decoder_outputs.bbox_logits, 117 | class_logits=decoder_outputs.class_logits, 118 | decoder_hidden_states=decoder_outputs.hidden_states, 119 | encoder_last_hidden_state=encoder_outputs.last_hidden_state 120 | ) 121 | 122 | def _reorder_cache(self, past_key_values, beam_idx): 123 | # apply decoder cache reordering here 124 | return self.decoder._reorder_cache(past_key_values, beam_idx) -------------------------------------------------------------------------------- /surya/layout/schema.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict, List 2 | 3 | from pydantic import BaseModel 4 | 5 | from surya.common.polygon import PolygonBox 6 | 7 | 8 | class LayoutBox(PolygonBox): 9 | label: str 10 | position: int 11 | top_k: Optional[Dict[str, float]] = None 12 | 13 | 14 | class LayoutResult(BaseModel): 15 | bboxes: List[LayoutBox] 16 | image_bbox: List[float] 17 | sliced: bool = False # Whether the image was sliced and reconstructed 18 | -------------------------------------------------------------------------------- /surya/layout/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def prediction_to_polygon(pred, img_size, bbox_scaler, skew_scaler, skew_min=0.001): 5 | w_scale = img_size[0] / bbox_scaler 6 | h_scale = img_size[1] / bbox_scaler 7 | 8 | boxes = pred 9 | cx = boxes[0] 10 | cy = boxes[1] 11 | width = boxes[2] 12 | height = boxes[3] 13 | x1 = cx - width / 2 14 | y1 = cy - height / 2 15 | x2 = cx + width / 2 16 | y2 = cy + height / 2 17 | skew_x = torch.floor((boxes[4] - skew_scaler) / 2) 18 | skew_y = torch.floor((boxes[5] - skew_scaler) / 2) 19 | 20 | # Ensures we don't get slightly warped boxes 21 | # Note that the values are later scaled, so this is in 1/1024 space 22 | skew_x[torch.abs(skew_x) < skew_min] = 0 23 | skew_y[torch.abs(skew_y) < skew_min] = 0 24 | 25 | polygon = [ 26 | x1 - skew_x, 27 | y1 - skew_y, 28 | x2 - skew_x, 29 | y1 + skew_y, 30 | x2 + skew_x, 31 | y2 + skew_y, 32 | x1 + skew_x, 33 | y2 - skew_y, 34 | ] 35 | poly = [] 36 | for i in range(4): 37 | poly.append( 38 | [polygon[2 * i].item() * w_scale, polygon[2 * i + 1].item() * h_scale] 39 | ) 40 | return poly 41 | -------------------------------------------------------------------------------- /surya/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | from surya.settings import settings 4 | 5 | 6 | def configure_logging(): 7 | # Setup surya logger 8 | logger = get_logger() 9 | 10 | if not logger.handlers: 11 | handler = logging.StreamHandler() 12 | formatter = logging.Formatter( 13 | 
"%(asctime)s [%(levelname)s] %(name)s: %(message)s" 14 | ) 15 | handler.setFormatter(formatter) 16 | logger.addHandler(handler) 17 | 18 | logger.setLevel(settings.LOGLEVEL) 19 | warnings.simplefilter(action="ignore", category=FutureWarning) 20 | 21 | 22 | def get_logger(): 23 | return logging.getLogger("surya") 24 | -------------------------------------------------------------------------------- /surya/models.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | 5 | from surya.common.predictor import BasePredictor 6 | from surya.detection import DetectionPredictor 7 | from surya.layout import LayoutPredictor 8 | from surya.logging import configure_logging 9 | from surya.ocr_error import OCRErrorPredictor 10 | from surya.recognition import RecognitionPredictor 11 | from surya.table_rec import TableRecPredictor 12 | 13 | configure_logging() 14 | 15 | 16 | def load_predictors( 17 | device: str | torch.device | None = None, dtype: torch.dtype | str | None = None 18 | ) -> Dict[str, BasePredictor]: 19 | return { 20 | "layout": LayoutPredictor(device=device, dtype=dtype), 21 | "ocr_error": OCRErrorPredictor(device=device, dtype=dtype), 22 | "recognition": RecognitionPredictor(device=device, dtype=dtype), 23 | "detection": DetectionPredictor(device=device, dtype=dtype), 24 | "table_rec": TableRecPredictor(device=device, dtype=dtype), 25 | } 26 | -------------------------------------------------------------------------------- /surya/ocr_error/__init__.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional 3 | 4 | from tqdm import tqdm 5 | 6 | from surya.common.predictor import BasePredictor 7 | from surya.ocr_error.loader import OCRErrorModelLoader 8 | from surya.ocr_error.model.config import ID2LABEL 9 | from surya.ocr_error.schema import OCRErrorDetectionResult 10 | from surya.settings import settings 11 | from surya.common.util import mark_step 12 | 13 | 14 | class OCRErrorPredictor(BasePredictor): 15 | model_loader_cls = OCRErrorModelLoader 16 | batch_size = settings.OCR_ERROR_BATCH_SIZE 17 | default_batch_sizes = {"cpu": 8, "mps": 8, "cuda": 64, "xla": 32} 18 | 19 | def __call__(self, texts: List[str], batch_size: Optional[int] = None): 20 | return self.batch_ocr_error_detection(texts, batch_size) 21 | 22 | def batch_ocr_error_detection( 23 | self, texts: List[str], batch_size: Optional[int] = None 24 | ): 25 | if batch_size is None: 26 | batch_size = self.get_batch_size() 27 | 28 | num_batches = math.ceil(len(texts) / batch_size) 29 | texts_processed = self.processor( 30 | texts, padding="longest", truncation=True, return_tensors="pt" 31 | ) 32 | predictions = [] 33 | for batch_idx in tqdm( 34 | range(num_batches), 35 | desc="Running OCR Error Detection", 36 | disable=self.disable_tqdm, 37 | ): 38 | start_idx, end_idx = batch_idx * batch_size, (batch_idx + 1) * batch_size 39 | batch_input_ids = texts_processed.input_ids[start_idx:end_idx].to( 40 | self.model.device 41 | ) 42 | batch_attention_mask = texts_processed.attention_mask[start_idx:end_idx].to( 43 | self.model.device 44 | ) 45 | 46 | # Pad to batch size 47 | current_batch_size = batch_input_ids.shape[0] 48 | if settings.OCR_ERROR_STATIC_CACHE: 49 | batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size) 50 | batch_attention_mask = self.pad_to_batch_size( 51 | batch_attention_mask, batch_size 52 | ) 53 | 54 | with settings.INFERENCE_MODE(): 55 | pred = 
self.model(batch_input_ids, attention_mask=batch_attention_mask) 56 | 57 | logits = pred.logits.argmax(dim=1).cpu().tolist()[:current_batch_size] 58 | predictions.extend(logits) 59 | mark_step() 60 | 61 | return OCRErrorDetectionResult( 62 | texts=texts, labels=[ID2LABEL[p] for p in predictions] 63 | ) 64 | -------------------------------------------------------------------------------- /surya/ocr_error/loader.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | 5 | from surya.common.load import ModelLoader 6 | from surya.logging import get_logger 7 | from surya.ocr_error.model.config import DistilBertConfig 8 | from surya.ocr_error.model.encoder import DistilBertForSequenceClassification 9 | from surya.ocr_error.tokenizer import DistilBertTokenizer 10 | from surya.settings import settings 11 | 12 | logger = get_logger() 13 | 14 | 15 | class OCRErrorModelLoader(ModelLoader): 16 | def __init__(self, checkpoint: Optional[str] = None): 17 | super().__init__(checkpoint) 18 | 19 | if self.checkpoint is None: 20 | self.checkpoint = settings.OCR_ERROR_MODEL_CHECKPOINT 21 | 22 | def model( 23 | self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE 24 | ) -> DistilBertForSequenceClassification: 25 | if device is None: 26 | device = settings.TORCH_DEVICE_MODEL 27 | if dtype is None: 28 | dtype = settings.MODEL_DTYPE 29 | 30 | config = DistilBertConfig.from_pretrained(self.checkpoint) 31 | model = ( 32 | DistilBertForSequenceClassification.from_pretrained( 33 | self.checkpoint, 34 | torch_dtype=dtype, 35 | config=config, 36 | ) 37 | .to(device) 38 | .eval() 39 | ) 40 | 41 | if settings.COMPILE_ALL or settings.COMPILE_OCR_ERROR: 42 | torch.set_float32_matmul_precision("high") 43 | torch._dynamo.config.cache_size_limit = 1 44 | torch._dynamo.config.suppress_errors = False 45 | 46 | logger.info( 47 | f"Compiling detection model {self.checkpoint} from {DistilBertForSequenceClassification.get_local_path(self.checkpoint)} onto device {device} with dtype {dtype}" 48 | ) 49 | compile_args = {"backend": "openxla"} if device == "xla" else {} 50 | model = torch.compile(model, **compile_args) 51 | 52 | return model 53 | 54 | def processor( 55 | self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE 56 | ) -> DistilBertTokenizer: 57 | return DistilBertTokenizer.from_pretrained(self.checkpoint) 58 | -------------------------------------------------------------------------------- /surya/ocr_error/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/surya/ocr_error/model/__init__.py -------------------------------------------------------------------------------- /surya/ocr_error/model/config.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Mapping 3 | 4 | from transformers.configuration_utils import PretrainedConfig 5 | from transformers.onnx import OnnxConfig 6 | 7 | from surya.common.s3 import S3DownloaderMixin 8 | 9 | ID2LABEL = { 10 | 0: 'good', 11 | 1: 'bad' 12 | } 13 | 14 | class DistilBertConfig(S3DownloaderMixin, PretrainedConfig): 15 | model_type = "distilbert" 16 | attribute_map = { 17 | "hidden_size": "dim", 18 | "num_attention_heads": "n_heads", 19 | "num_hidden_layers": "n_layers", 20 | } 21 | 22 | def __init__( 23 | self, 24 | vocab_size=30522, 25 | 
max_position_embeddings=512, 26 | sinusoidal_pos_embds=False, 27 | n_layers=6, 28 | n_heads=12, 29 | dim=768, 30 | hidden_dim=4 * 768, 31 | dropout=0.1, 32 | attention_dropout=0.1, 33 | activation="gelu", 34 | initializer_range=0.02, 35 | qa_dropout=0.1, 36 | seq_classif_dropout=0.2, 37 | pad_token_id=0, 38 | **kwargs, 39 | ): 40 | self.vocab_size = vocab_size 41 | self.max_position_embeddings = max_position_embeddings 42 | self.sinusoidal_pos_embds = sinusoidal_pos_embds 43 | self.n_layers = n_layers 44 | self.n_heads = n_heads 45 | self.dim = dim 46 | self.hidden_dim = hidden_dim 47 | self.dropout = dropout 48 | self.attention_dropout = attention_dropout 49 | self.activation = activation 50 | self.initializer_range = initializer_range 51 | self.qa_dropout = qa_dropout 52 | self.seq_classif_dropout = seq_classif_dropout 53 | super().__init__(**kwargs, pad_token_id=pad_token_id) 54 | 55 | 56 | class DistilBertOnnxConfig(OnnxConfig): 57 | @property 58 | def inputs(self) -> Mapping[str, Mapping[int, str]]: 59 | if self.task == "multiple-choice": 60 | dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} 61 | else: 62 | dynamic_axis = {0: "batch", 1: "sequence"} 63 | return OrderedDict( 64 | [ 65 | ("input_ids", dynamic_axis), 66 | ("attention_mask", dynamic_axis), 67 | ] 68 | ) -------------------------------------------------------------------------------- /surya/ocr_error/schema.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class OCRErrorDetectionResult(BaseModel): 7 | texts: List[str] 8 | labels: List[str] 9 | -------------------------------------------------------------------------------- /surya/recognition/languages.py: -------------------------------------------------------------------------------- 1 | CODE_TO_LANGUAGE = { 2 | "_math": "Math", 3 | "af": "Afrikaans", 4 | "am": "Amharic", 5 | "ar": "Arabic", 6 | "as": "Assamese", 7 | "az": "Azerbaijani", 8 | "be": "Belarusian", 9 | "bg": "Bulgarian", 10 | "bn": "Bengali", 11 | "br": "Breton", 12 | "bs": "Bosnian", 13 | "ca": "Catalan", 14 | "cs": "Czech", 15 | "cy": "Welsh", 16 | "da": "Danish", 17 | "de": "German", 18 | "el": "Greek", 19 | "en": "English", 20 | "eo": "Esperanto", 21 | "es": "Spanish", 22 | "et": "Estonian", 23 | "eu": "Basque", 24 | "fa": "Persian", 25 | "fi": "Finnish", 26 | "fr": "French", 27 | "fy": "Western Frisian", 28 | "ga": "Irish", 29 | "gd": "Scottish Gaelic", 30 | "gl": "Galician", 31 | "gu": "Gujarati", 32 | "ha": "Hausa", 33 | "he": "Hebrew", 34 | "hi": "Hindi", 35 | "hr": "Croatian", 36 | "hu": "Hungarian", 37 | "hy": "Armenian", 38 | "id": "Indonesian", 39 | "is": "Icelandic", 40 | "it": "Italian", 41 | "ja": "Japanese", 42 | "jv": "Javanese", 43 | "ka": "Georgian", 44 | "kk": "Kazakh", 45 | "km": "Khmer", 46 | "kn": "Kannada", 47 | "ko": "Korean", 48 | "ku": "Kurdish", 49 | "ky": "Kyrgyz", 50 | "la": "Latin", 51 | "lo": "Lao", 52 | "lt": "Lithuanian", 53 | "lv": "Latvian", 54 | "mg": "Malagasy", 55 | "mk": "Macedonian", 56 | "ml": "Malayalam", 57 | "mn": "Mongolian", 58 | "mr": "Marathi", 59 | "ms": "Malay", 60 | "my": "Burmese", 61 | "ne": "Nepali", 62 | "nl": "Dutch", 63 | "no": "Norwegian", 64 | "om": "Oromo", 65 | "or": "Oriya", 66 | "pa": "Punjabi", 67 | "pl": "Polish", 68 | "ps": "Pashto", 69 | "pt": "Portuguese", 70 | "ro": "Romanian", 71 | "ru": "Russian", 72 | "sa": "Sanskrit", 73 | "sd": "Sindhi", 74 | "si": "Sinhala", 75 | "sk": "Slovak", 76 | "sl": "Slovenian", 77 | "so": 
"Somali", 78 | "sq": "Albanian", 79 | "sr": "Serbian", 80 | "su": "Sundanese", 81 | "sv": "Swedish", 82 | "sw": "Swahili", 83 | "ta": "Tamil", 84 | "te": "Telugu", 85 | "th": "Thai", 86 | "tl": "Tagalog", 87 | "tr": "Turkish", 88 | "ug": "Uyghur", 89 | "uk": "Ukrainian", 90 | "ur": "Urdu", 91 | "uz": "Uzbek", 92 | "vi": "Vietnamese", 93 | "xh": "Xhosa", 94 | "yi": "Yiddish", 95 | "zh": "Chinese", 96 | } 97 | 98 | LANGUAGE_TO_CODE = {v: k for k, v in CODE_TO_LANGUAGE.items()} 99 | -------------------------------------------------------------------------------- /surya/recognition/loader.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from transformers.utils import is_flash_attn_2_available 5 | 6 | from surya.common.load import ModelLoader 7 | from surya.common.surya.config import SuryaModelConfig 8 | from surya.common.surya import SuryaModel 9 | from surya.common.surya.processor import SuryaOCRProcessor 10 | from surya.common.surya.processor.tokenizer import SuryaOCRTokenizer 11 | from surya.common.util import is_flash_attn_2_supported 12 | from surya.logging import get_logger 13 | from surya.settings import settings 14 | 15 | logger = get_logger() 16 | 17 | 18 | class RecognitionModelLoader(ModelLoader): 19 | def __init__(self, checkpoint: Optional[str] = None): 20 | super().__init__(checkpoint) 21 | 22 | if self.checkpoint is None: 23 | self.checkpoint = settings.RECOGNITION_MODEL_CHECKPOINT 24 | 25 | def model( 26 | self, 27 | device=settings.TORCH_DEVICE_MODEL, 28 | dtype=None, 29 | ) -> SuryaModel: 30 | if device is None: 31 | device = settings.TORCH_DEVICE_MODEL 32 | if dtype is None: 33 | # See https://github.com/pytorch/pytorch/issues/118122 - T4 (device version 7.5) will return true since it supports 34 | # emulated bf16, but falls back to very slow kernels, especially for SDPA 35 | if torch.cuda.is_bf16_supported(including_emulation=False): 36 | dtype = settings.MODEL_DTYPE_BFLOAT 37 | else: 38 | dtype = settings.MODEL_DTYPE 39 | 40 | torch.set_float32_matmul_precision("high") 41 | config = SuryaModelConfig.from_pretrained(self.checkpoint) 42 | 43 | if is_flash_attn_2_available() and is_flash_attn_2_supported(device): 44 | config.decoder._attn_implementation = "flash_attention_2" 45 | config.vision_encoder._attn_implementation = "flash_attention_2" 46 | else: 47 | config.decoder._attn_implementation = "sdpa" 48 | config.vision_encoder._attn_implementation = "sdpa" 49 | 50 | model = SuryaModel.from_pretrained( 51 | self.checkpoint, torch_dtype=dtype, config=config 52 | ).to(device) 53 | model = model.eval() 54 | 55 | logger.debug( 56 | f"Loaded recognition model {self.checkpoint} from {SuryaModel.get_local_path(self.checkpoint)} onto device {model.device} with dtype {dtype}, using decoder attention mechanism {model.config.decoder._attn_implementation}, encoder attention mechanism {model.config.vision_encoder._attn_implementation}." 
57 | ) 58 | return model 59 | 60 | def processor( 61 | self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE_BFLOAT 62 | ) -> SuryaOCRProcessor: 63 | config: SuryaModelConfig = SuryaModelConfig.from_pretrained(self.checkpoint) 64 | 65 | ocr_tokenizer = SuryaOCRTokenizer( 66 | special_tokens=config.special_ocr_tokens, model_checkpoint=self.checkpoint 67 | ) 68 | 69 | processor = SuryaOCRProcessor( 70 | ocr_tokenizer=ocr_tokenizer, 71 | blank_bbox_token_id=config.blank_bbox_token_id, 72 | num_register_tokens=config.num_register_tokens, 73 | sequence_length=None, 74 | patch_size=config.vision_encoder.patch_size, 75 | merge_size=config.vision_encoder.spatial_merge_size, 76 | model_device=device, 77 | ) 78 | config.eos_token_id = processor.eos_token_id 79 | config.pad_token_id = processor.pad_token_id 80 | config.bos_token_id = processor.bos_token_id 81 | 82 | return processor 83 | -------------------------------------------------------------------------------- /surya/recognition/postprocessing.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List, Dict 3 | 4 | from surya.recognition.schema import TextChar 5 | 6 | 7 | def truncate_repetitions(text: str, min_len=15): 8 | # From nougat, with some cleanup 9 | if len(text) < 2 * min_len: 10 | return text 11 | 12 | # try to find a length at which the tail is repeating 13 | max_rep_len = None 14 | for rep_len in range(min_len, int(len(text) / 2)): 15 | # check if there is a repetition at the end 16 | same = True 17 | for i in range(0, rep_len): 18 | if text[len(text) - rep_len - i - 1] != text[len(text) - i - 1]: 19 | same = False 20 | break 21 | 22 | if same: 23 | max_rep_len = rep_len 24 | 25 | if max_rep_len is None: 26 | return text 27 | 28 | lcs = text[-max_rep_len:] 29 | 30 | # remove all but the last repetition 31 | text_to_truncate = text 32 | while text_to_truncate.endswith(lcs): 33 | text_to_truncate = text_to_truncate[:-max_rep_len] 34 | 35 | return text[: len(text_to_truncate)] 36 | 37 | 38 | def extract_tags(proposed_tags: List[str]) -> List[str]: 39 | tags = [] 40 | for tag in proposed_tags: 41 | tag_match = re.match(tag_pattern, tag) 42 | if not tag_match: 43 | continue 44 | 45 | if not tag_match.group(1) == "/": 46 | continue 47 | 48 | tags.append(tag_match.group(2)) 49 | return tags 50 | 51 | 52 | tag_pattern = re.compile(r"<(/?)([a-z]+)([^>]*)>?", re.IGNORECASE) 53 | 54 | 55 | def cleanup_math(line: str): 56 | matches = re.finditer(r"(]*>)(.*?)", line, re.DOTALL) 57 | result = line 58 | 59 | for match in matches: 60 | opening_tag = match.group(1) # The opening tag with attributes 61 | full_match = match.group(0) # The entire content tag 62 | block_content = match.group(2) # Just the content inside the tags 63 | 64 | clean_block = re.sub(r"<[^>]+>", "", block_content) 65 | 66 | if not re.search(r"[\\\_]", clean_block): 67 | result = result.replace(full_match, clean_block) 68 | else: 69 | result = result.replace(full_match, f"{opening_tag}{clean_block}") 70 | 71 | return result 72 | 73 | 74 | def fix_unbalanced_tags( 75 | text_chars: List[TextChar], special_tokens: Dict[str, list] 76 | ) -> List[TextChar]: 77 | self_closing_tags = ["br"] 78 | 79 | open_tags = [] 80 | 81 | format_tags = extract_tags(special_tokens["formatting"]) + extract_tags( 82 | special_tokens["math_external"] 83 | ) 84 | 85 | for char in text_chars: 86 | if len(char.text) <= 1: 87 | continue 88 | 89 | tag_match = re.match(tag_pattern, char.text) 90 | if not tag_match: 91 | 
continue 92 | 93 | is_closing = tag_match.group(1) == "/" 94 | tag_name = tag_match.group(2).lower() 95 | 96 | if tag_name not in format_tags: 97 | continue 98 | 99 | if tag_name in self_closing_tags: 100 | continue 101 | 102 | # Self-closing tags 103 | if tag_match.group(3) and tag_match.group(3).strip().endswith("/"): 104 | continue 105 | 106 | if is_closing: 107 | if open_tags and open_tags[-1] == tag_name: 108 | open_tags.pop() 109 | else: 110 | open_tags.append(tag_name) 111 | 112 | for tag in open_tags: 113 | text_chars.append( 114 | TextChar( 115 | text=f"", 116 | confidence=0, 117 | polygon=[[0, 0], [1, 0], [1, 1], [0, 1]], 118 | bbox_valid=False, 119 | ) 120 | ) 121 | return text_chars 122 | -------------------------------------------------------------------------------- /surya/recognition/schema.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | from typing import Optional, List 4 | 5 | from pydantic import BaseModel, field_validator 6 | 7 | from surya.common.polygon import PolygonBox 8 | 9 | 10 | class BaseChar(PolygonBox): 11 | text: str 12 | confidence: Optional[float] = 0 13 | 14 | @field_validator("confidence", mode="before") 15 | @classmethod 16 | def validate_confidence(cls, v: float) -> float: 17 | if v is None: 18 | return 0 19 | elif math.isnan(v) or np.isnan(v): 20 | return 0 21 | return v 22 | 23 | 24 | class TextChar(BaseChar): 25 | bbox_valid: bool = True # This is false when the given bbox is not valid 26 | 27 | 28 | class TextWord(BaseChar): 29 | bbox_valid: bool = True 30 | 31 | 32 | class TextLine(BaseChar): 33 | chars: List[TextChar] # Individual characters in the line 34 | original_text_good: bool = False 35 | words: List[TextWord] | None = None 36 | 37 | 38 | class OCRResult(BaseModel): 39 | text_lines: List[TextLine] 40 | image_bbox: List[float] 41 | -------------------------------------------------------------------------------- /surya/recognition/util.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List, Tuple 3 | 4 | import numpy 5 | import torch 6 | 7 | from surya.common.polygon import PolygonBox 8 | from surya.recognition.schema import TextLine, TextWord, TextChar 9 | 10 | MATH_SYMBOLS = ["+", "-", "*", "=", "^", "_", "\\", "{", "}"] 11 | 12 | 13 | def unwrap_math(text: str) -> str: 14 | if len(text) > 50: 15 | return text 16 | 17 | # Detected as math, but does not contain LaTeX commands 18 | if ( 19 | re.match(r'^\s*\s*$', text, re.DOTALL) 20 | and text.count("", "", text) 25 | text = re.sub(r"", "", text) 26 | 27 | return text 28 | 29 | 30 | def detect_repeat_token(predicted_tokens: List[int], max_repeats: int = 40): 31 | if len(predicted_tokens) < max_repeats: 32 | return False 33 | 34 | # Detect repeats containing 1 or 2 tokens 35 | last_n = predicted_tokens[-max_repeats:] 36 | unique_tokens = len(set(last_n)) 37 | if unique_tokens > 5: 38 | return False 39 | 40 | return last_n[-unique_tokens:] == last_n[-unique_tokens * 2 : -unique_tokens] 41 | 42 | 43 | def sort_text_lines(lines: List[TextLine] | List[dict], tolerance=1.25): 44 | # Sorts in reading order. Not 100% accurate, this should only 45 | # be used as a starting point for more advanced sorting. 
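# Approach: group lines into buckets keyed by their rounded top y coordinate,
# sort each bucket left to right by x, then emit the buckets in top-to-bottom
# order.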
46 | vertical_groups = {} 47 | for line in lines: 48 | group_key = ( 49 | round( 50 | line.bbox[1] 51 | if isinstance(line, TextLine) 52 | else line["bbox"][1] / tolerance 53 | ) 54 | * tolerance 55 | ) 56 | if group_key not in vertical_groups: 57 | vertical_groups[group_key] = [] 58 | vertical_groups[group_key].append(line) 59 | 60 | # Sort each group horizontally and flatten the groups into a single list 61 | sorted_lines = [] 62 | for _, group in sorted(vertical_groups.items()): 63 | sorted_group = sorted( 64 | group, key=lambda x: x.bbox[0] if isinstance(x, TextLine) else x["bbox"][0] 65 | ) 66 | sorted_lines.extend(sorted_group) 67 | 68 | return sorted_lines 69 | 70 | 71 | def clean_close_polygons(bboxes: List[List[List[int]]], thresh: float = 0.1): 72 | if len(bboxes) < 2: 73 | return bboxes 74 | 75 | new_bboxes = [bboxes[0]] 76 | for i in range(1, len(bboxes)): 77 | close = True 78 | prev_bbox = bboxes[i - 1] 79 | bbox = bboxes[i] 80 | for j in range(4): 81 | if ( 82 | abs(bbox[j][0] - prev_bbox[j][0]) > thresh 83 | or abs(bbox[j][1] - prev_bbox[j][1]) > thresh 84 | ): 85 | close = False 86 | break 87 | 88 | if not close: 89 | new_bboxes.append(bboxes[i]) 90 | 91 | return new_bboxes 92 | 93 | 94 | def words_from_chars(chars: List[TextChar], line_box: PolygonBox): 95 | words = [] 96 | word = None 97 | for i, char in enumerate(chars): 98 | if not char.bbox_valid: 99 | if word: 100 | words.append(word) 101 | word = None 102 | continue 103 | 104 | if not word: 105 | word = TextWord(**char.model_dump()) 106 | 107 | # Fit bounds to line if first word 108 | if i == 0: 109 | word.merge_left(line_box) 110 | 111 | elif not char.text.strip(): 112 | if word: 113 | words.append(word) 114 | word = None 115 | else: 116 | # Merge bboxes 117 | word.merge(char) 118 | word.text = word.text + char.text 119 | 120 | if i == len(chars) - 1: 121 | word.merge_right(line_box) 122 | if word: 123 | words.append(word) 124 | 125 | return words 126 | 127 | 128 | def prediction_to_polygon_batch( 129 | pred: torch.Tensor, 130 | img_sizes: List[Tuple[int, int]], 131 | bbox_scaler, 132 | skew_scaler, 133 | skew_min=0.001, 134 | ): 135 | img_sizes = torch.from_numpy(numpy.array(img_sizes, dtype=numpy.float32)).to( 136 | pred.device 137 | ) 138 | w_scale = (img_sizes[:, 1] / bbox_scaler)[:, None, None] 139 | h_scale = (img_sizes[:, 0] / bbox_scaler)[:, None, None] 140 | 141 | cx = pred[:, :, 0] 142 | cy = pred[:, :, 1] 143 | width = pred[:, :, 2] 144 | height = pred[:, :, 3] 145 | 146 | x1 = cx - width / 2 147 | y1 = cy - height / 2 148 | x2 = cx + width / 2 149 | y2 = cy + height / 2 150 | 151 | skew_x = torch.floor((pred[:, :, 4] - skew_scaler) / 2) 152 | skew_y = torch.floor((pred[:, :, 5] - skew_scaler) / 2) 153 | 154 | skew_x[torch.abs(skew_x) < skew_min] = 0 155 | skew_y[torch.abs(skew_y) < skew_min] = 0 156 | 157 | polygons_flat = torch.stack( 158 | [ 159 | x1 - skew_x, 160 | y1 - skew_y, 161 | x2 - skew_x, 162 | y1 + skew_y, 163 | x2 + skew_x, 164 | y2 + skew_y, 165 | x1 + skew_x, 166 | y2 - skew_y, 167 | ], 168 | dim=2, 169 | ) 170 | 171 | batch_size, seq_len, _ = pred.shape 172 | polygons = polygons_flat.view(batch_size, seq_len, 4, 2) 173 | 174 | polygons[:, :, :, 0] *= w_scale 175 | polygons[:, :, :, 1] *= h_scale 176 | 177 | return polygons 178 | -------------------------------------------------------------------------------- /surya/scripts/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/surya/scripts/__init__.py -------------------------------------------------------------------------------- /surya/scripts/config.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import click 4 | import os 5 | from surya.input.load import load_from_folder, load_from_file 6 | from surya.settings import settings 7 | 8 | 9 | class CLILoader: 10 | def __init__(self, filepath: str, cli_options: dict, highres: bool = False): 11 | self.page_range = cli_options.get("page_range") 12 | if self.page_range: 13 | self.page_range = self.parse_range_str(self.page_range) 14 | self.filepath = filepath 15 | self.config = cli_options 16 | self.save_images = cli_options.get("images", False) 17 | self.debug = cli_options.get("debug", False) 18 | self.output_dir = cli_options.get("output_dir") 19 | 20 | self.load(highres) 21 | 22 | @staticmethod 23 | def common_options(fn): 24 | fn = click.argument("input_path", type=click.Path(exists=True), required=True)(fn) 25 | fn = click.option("--output_dir", type=click.Path(exists=False), required=False, default=os.path.join(settings.RESULT_DIR, "surya"), help="Directory to save output.")(fn) 26 | fn = click.option("--page_range", type=str, default=None, help="Page range to convert, specify comma separated page numbers or ranges. Example: 0,5-10,20")(fn) 27 | fn = click.option("--images", is_flag=True, help="Save images of detected bboxes.", default=False)(fn) 28 | fn = click.option('--debug', '-d', is_flag=True, help='Enable debug mode.', default=False)(fn) 29 | return fn 30 | 31 | def load(self, highres: bool = False): 32 | highres_images = None 33 | if os.path.isdir(self.filepath): 34 | images, names = load_from_folder(self.filepath, self.page_range) 35 | folder_name = os.path.basename(self.filepath) 36 | if highres: 37 | highres_images, _ = load_from_folder(self.filepath, self.page_range, settings.IMAGE_DPI_HIGHRES) 38 | else: 39 | images, names = load_from_file(self.filepath, self.page_range) 40 | folder_name = os.path.basename(self.filepath).split(".")[0] 41 | if highres: 42 | highres_images, _ = load_from_file(self.filepath, self.page_range, settings.IMAGE_DPI_HIGHRES) 43 | 44 | 45 | self.images = images 46 | self.highres_images = highres_images 47 | self.names = names 48 | 49 | self.result_path = os.path.abspath(os.path.join(self.output_dir, folder_name)) 50 | os.makedirs(self.result_path, exist_ok=True) 51 | 52 | @staticmethod 53 | def parse_range_str(range_str: str) -> List[int]: 54 | range_lst = range_str.split(",") 55 | page_lst = [] 56 | for i in range_lst: 57 | if "-" in i: 58 | start, end = i.split("-") 59 | page_lst += list(range(int(start), int(end) + 1)) 60 | else: 61 | page_lst.append(int(i)) 62 | page_lst = sorted(list(set(page_lst))) # Deduplicate page numbers and sort in order 63 | return page_lst -------------------------------------------------------------------------------- /surya/scripts/detect_layout.py: -------------------------------------------------------------------------------- 1 | import time 2 | import click 3 | import copy 4 | import json 5 | from collections import defaultdict 6 | 7 | from surya.layout import LayoutPredictor 8 | from surya.debug.draw import draw_polys_on_image 9 | from surya.logging import configure_logging, get_logger 10 | from surya.scripts.config import CLILoader 11 | import os 12 | 13 | configure_logging() 14 | logger = get_logger() 15 | 16 | 17 | 
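# detect_layout_cli below is a click command. The positional input_path plus
# the --output_dir, --page_range, --images and --debug options are attached by
# CLILoader.common_options (see surya/scripts/config.py). A hypothetical
# invocation, assuming the command is exposed as a console script (the entry
# point name lives in pyproject.toml, which is not shown here):
#
#     surya_layout path/to/doc.pdf --page_range 0-2 --images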
@click.command(help="Detect layout of an input file or folder (PDFs or image).") 18 | @CLILoader.common_options 19 | def detect_layout_cli(input_path: str, **kwargs): 20 | loader = CLILoader(input_path, kwargs) 21 | 22 | layout_predictor = LayoutPredictor() 23 | 24 | start = time.time() 25 | layout_predictions = layout_predictor(loader.images) 26 | 27 | if loader.debug: 28 | logger.debug(f"Layout took {time.time() - start} seconds") 29 | 30 | if loader.save_images: 31 | for idx, (image, layout_pred, name) in enumerate( 32 | zip(loader.images, layout_predictions, loader.names) 33 | ): 34 | polygons = [p.polygon for p in layout_pred.bboxes] 35 | labels = [f"{p.label}-{p.position}" for p in layout_pred.bboxes] 36 | bbox_image = draw_polys_on_image( 37 | polygons, copy.deepcopy(image), labels=labels 38 | ) 39 | bbox_image.save( 40 | os.path.join(loader.result_path, f"{name}_{idx}_layout.png") 41 | ) 42 | 43 | predictions_by_page = defaultdict(list) 44 | for idx, (pred, name, image) in enumerate( 45 | zip(layout_predictions, loader.names, loader.images) 46 | ): 47 | out_pred = pred.model_dump() 48 | out_pred["page"] = len(predictions_by_page[name]) + 1 49 | predictions_by_page[name].append(out_pred) 50 | 51 | with open( 52 | os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8" 53 | ) as f: 54 | json.dump(predictions_by_page, f, ensure_ascii=False) 55 | 56 | logger.info(f"Wrote results to {loader.result_path}") 57 | -------------------------------------------------------------------------------- /surya/scripts/detect_text.py: -------------------------------------------------------------------------------- 1 | import click 2 | import copy 3 | import json 4 | import time 5 | from collections import defaultdict 6 | 7 | from surya.detection import DetectionPredictor 8 | from surya.debug.draw import draw_polys_on_image 9 | from surya.logging import configure_logging, get_logger 10 | from surya.scripts.config import CLILoader 11 | import os 12 | 13 | configure_logging() 14 | logger = get_logger() 15 | 16 | 17 | @click.command(help="Detect bboxes in an input file or folder (PDFs or image).") 18 | @CLILoader.common_options 19 | def detect_text_cli(input_path: str, **kwargs): 20 | loader = CLILoader(input_path, kwargs) 21 | 22 | det_predictor = DetectionPredictor() 23 | 24 | start = time.time() 25 | predictions = det_predictor(loader.images, include_maps=loader.debug) 26 | end = time.time() 27 | if loader.debug: 28 | logger.debug(f"Detection took {end - start} seconds") 29 | 30 | if loader.save_images: 31 | for idx, (image, pred, name) in enumerate( 32 | zip(loader.images, predictions, loader.names) 33 | ): 34 | polygons = [p.polygon for p in pred.bboxes] 35 | bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image)) 36 | bbox_image.save(os.path.join(loader.result_path, f"{name}_{idx}_bbox.png")) 37 | 38 | if loader.debug: 39 | heatmap = pred.heatmap 40 | heatmap.save(os.path.join(loader.result_path, f"{name}_{idx}_heat.png")) 41 | 42 | predictions_by_page = defaultdict(list) 43 | for idx, (pred, name, image) in enumerate( 44 | zip(predictions, loader.names, loader.images) 45 | ): 46 | out_pred = pred.model_dump(exclude=["heatmap", "affinity_map"]) 47 | out_pred["page"] = len(predictions_by_page[name]) + 1 48 | predictions_by_page[name].append(out_pred) 49 | 50 | with open( 51 | os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8" 52 | ) as f: 53 | json.dump(predictions_by_page, f, ensure_ascii=False) 54 | 55 | logger.info(f"Wrote results to 
{loader.result_path}") 56 | -------------------------------------------------------------------------------- /surya/scripts/hf_to_s3.py: -------------------------------------------------------------------------------- 1 | import json 2 | import shutil 3 | import datetime 4 | from pathlib import Path 5 | import boto3 6 | 7 | from huggingface_hub import snapshot_download 8 | 9 | import click 10 | from tqdm import tqdm 11 | 12 | S3_API_URL = "https://1afbe4656a6b40d982ab5e730a39f6b9.r2.cloudflarestorage.com" 13 | 14 | 15 | @click.command(help="Uploads the data from huggingface to an S3 bucket") 16 | @click.argument("hf_repo_id", type=str) 17 | @click.argument("s3_path", type=str) 18 | @click.option("--bucket_name", type=str, default="datalab") 19 | @click.option("--access_key_id", type=str, default="") 20 | @click.option("--access_key_secret", type=str, default="") 21 | @click.option("--suffix", type=str, default="") 22 | def main( 23 | hf_repo_id: str, 24 | s3_path: str, 25 | bucket_name: str, 26 | access_key_id: str, 27 | access_key_secret: str, 28 | suffix: str, 29 | ): 30 | curr_date = datetime.datetime.now().strftime("%Y_%m_%d") 31 | s3_path = f"{s3_path}/{curr_date}" 32 | if suffix: 33 | s3_path = f"{s3_path}_{suffix}" 34 | 35 | download_folder = snapshot_download(repo_id=hf_repo_id) 36 | download_folder = Path(download_folder) 37 | contained_files = list(download_folder.glob("*")) 38 | contained_files = [f.name for f in contained_files] # Just get the base name 39 | manifest_file = download_folder / "manifest.json" 40 | 41 | with open(manifest_file, "w") as f: 42 | json.dump({"files": contained_files}, f) 43 | 44 | # Upload the files to S3 45 | s3_client = boto3.client( 46 | service_name="s3", 47 | endpoint_url=S3_API_URL, 48 | aws_access_key_id=access_key_id, 49 | aws_secret_access_key=access_key_secret, 50 | region_name="auto", 51 | ) 52 | 53 | # Iterate through all files in the folder 54 | for file_path in tqdm( 55 | download_folder.glob("*"), desc="Uploading files", unit="file" 56 | ): 57 | s3_key = f"{s3_path}/{file_path.name}" 58 | 59 | try: 60 | s3_client.upload_file(str(file_path), bucket_name, s3_key) 61 | except Exception as e: 62 | print(f"Error uploading {file_path}: {str(e)}") 63 | 64 | shutil.rmtree(download_folder) 65 | 66 | print(f"Uploaded files to {s3_path}") 67 | 68 | 69 | if __name__ == "__main__": 70 | main() 71 | -------------------------------------------------------------------------------- /surya/scripts/ocr_latex.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import click 4 | import json 5 | import time 6 | from collections import defaultdict 7 | 8 | from surya.logging import configure_logging, get_logger 9 | from surya.scripts.config import CLILoader 10 | from surya.recognition import RecognitionPredictor 11 | from surya.common.surya.schema import TaskNames 12 | 13 | configure_logging() 14 | logger = get_logger() 15 | 16 | 17 | @click.command(help="OCR LaTeX equations.") 18 | @CLILoader.common_options 19 | def ocr_latex_cli(input_path: str, **kwargs): 20 | loader = CLILoader(input_path, kwargs, highres=True) 21 | 22 | texify_predictor = RecognitionPredictor() 23 | tasks = [TaskNames.block_without_boxes] * len(loader.images) 24 | bboxes = [[[0, 0, image.width, image.height]] for image in loader.images] 25 | 26 | start = time.time() 27 | predictions_by_image = texify_predictor( 28 | loader.images, 29 | tasks, 30 | bboxes=bboxes, 31 | ) 32 | 33 | latex_predictions = [p.text_lines[0].text for p in 
predictions_by_image] 34 | 35 | if loader.debug: 36 | logger.debug(f"OCR took {time.time() - start:.2f} seconds") 37 | max_chars = max([len(latex) for latex in latex_predictions]) 38 | logger.debug(f"Max chars: {max_chars}") 39 | 40 | out_preds = defaultdict(list) 41 | for name, pred, image in zip(loader.names, latex_predictions, loader.images): 42 | out_pred = { 43 | "equation": pred, 44 | "page": len(out_preds[name]) + 1, 45 | } 46 | out_preds[name].append(out_pred) 47 | 48 | with open( 49 | os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8" 50 | ) as f: 51 | json.dump(out_preds, f, ensure_ascii=False) 52 | 53 | logger.info(f"Wrote results to {loader.result_path}") 54 | -------------------------------------------------------------------------------- /surya/scripts/ocr_text.py: -------------------------------------------------------------------------------- 1 | import os 2 | import click 3 | import json 4 | import time 5 | from collections import defaultdict 6 | 7 | from surya.common.surya.schema import TaskNames 8 | from surya.detection import DetectionPredictor 9 | from surya.debug.text import draw_text_on_image 10 | from surya.logging import configure_logging, get_logger 11 | from surya.recognition import RecognitionPredictor 12 | from surya.scripts.config import CLILoader 13 | 14 | configure_logging() 15 | logger = get_logger() 16 | 17 | 18 | @click.command(help="OCR text.") 19 | @click.option("--task_name", type=str, default=TaskNames.ocr_with_boxes) 20 | @click.option( 21 | "--disable_math", is_flag=True, default=False, help="Do not recognize math in OCR." 22 | ) 23 | @CLILoader.common_options 24 | def ocr_text_cli(input_path: str, task_name: str, disable_math: bool, **kwargs): 25 | loader = CLILoader(input_path, kwargs, highres=True) 26 | task_names = [task_name] * len(loader.images) 27 | 28 | det_predictor = DetectionPredictor() 29 | rec_predictor = RecognitionPredictor() 30 | 31 | start = time.time() 32 | predictions_by_image = rec_predictor( 33 | loader.images, 34 | task_names=task_names, 35 | det_predictor=det_predictor, 36 | highres_images=loader.highres_images, 37 | math_mode=not disable_math, 38 | ) 39 | 40 | if loader.debug: 41 | logger.debug(f"OCR took {time.time() - start:.2f} seconds") 42 | max_chars = max( 43 | [len(line.text) for p in predictions_by_image for line in p.text_lines] 44 | ) 45 | logger.debug(f"Max chars: {max_chars}") 46 | 47 | if loader.save_images: 48 | for idx, (name, image, pred) in enumerate( 49 | zip(loader.names, loader.images, predictions_by_image) 50 | ): 51 | bboxes = [line.bbox for line in pred.text_lines] 52 | pred_text = [line.text for line in pred.text_lines] 53 | page_image = draw_text_on_image(bboxes, pred_text, image.size) 54 | page_image.save(os.path.join(loader.result_path, f"{name}_{idx}_text.png")) 55 | 56 | out_preds = defaultdict(list) 57 | for name, pred, image in zip(loader.names, predictions_by_image, loader.images): 58 | out_pred = pred.model_dump() 59 | out_pred["page"] = len(out_preds[name]) + 1 60 | out_preds[name].append(out_pred) 61 | 62 | with open( 63 | os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8" 64 | ) as f: 65 | json.dump(out_preds, f, ensure_ascii=False) 66 | 67 | logger.info(f"Wrote results to {loader.result_path}") 68 | -------------------------------------------------------------------------------- /surya/scripts/run_streamlit_app.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | 4 | 5 | def 
streamlit_app_cli(): 6 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 7 | ocr_app_path = os.path.join(cur_dir, "streamlit_app.py") 8 | cmd = ["streamlit", "run", ocr_app_path, "--server.fileWatcherType", "none", "--server.headless", "true"] 9 | subprocess.run(cmd, env={**os.environ, "IN_STREAMLIT": "true"}) -------------------------------------------------------------------------------- /surya/scripts/run_texify_app.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | 4 | 5 | def texify_app_cli(): 6 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 7 | ocr_app_path = os.path.join(cur_dir, "texify_app.py") 8 | cmd = ["streamlit", "run", ocr_app_path, "--server.fileWatcherType", "none", "--server.headless", "true"] 9 | subprocess.run(cmd, env={**os.environ, "IN_STREAMLIT": "true"}) -------------------------------------------------------------------------------- /surya/scripts/table_recognition.py: -------------------------------------------------------------------------------- 1 | import os 2 | import click 3 | import copy 4 | import json 5 | from collections import defaultdict 6 | 7 | from surya.logging import configure_logging, get_logger 8 | from surya.scripts.config import CLILoader 9 | from surya.layout import LayoutPredictor 10 | from surya.table_rec import TableRecPredictor 11 | from surya.debug.draw import draw_bboxes_on_image 12 | from surya.common.util import rescale_bbox, expand_bbox 13 | 14 | configure_logging() 15 | logger = get_logger() 16 | 17 | 18 | @click.command(help="Recognize tables in an input file or folder (PDFs or image).") 19 | @CLILoader.common_options 20 | @click.option( 21 | "--skip_table_detection", 22 | is_flag=True, 23 | help="Tables are already cropped, so don't re-detect tables.", 24 | default=False, 25 | ) 26 | def table_recognition_cli(input_path: str, skip_table_detection: bool, **kwargs): 27 | loader = CLILoader(input_path, kwargs, highres=True) 28 | 29 | table_rec_predictor = TableRecPredictor() 30 | layout_predictor = LayoutPredictor() 31 | 32 | pnums = [] 33 | prev_name = None 34 | for i, name in enumerate(loader.names): 35 | if prev_name is None or prev_name != name: 36 | pnums.append(0) 37 | else: 38 | pnums.append(pnums[-1] + 1) 39 | 40 | prev_name = name 41 | 42 | layout_predictions = layout_predictor(loader.images) 43 | 44 | table_imgs = [] 45 | table_counts = [] 46 | 47 | for layout_pred, img, highres_img in zip( 48 | layout_predictions, loader.images, loader.highres_images 49 | ): 50 | # The table may already be cropped 51 | if skip_table_detection: 52 | table_imgs.append(highres_img) 53 | table_counts.append(1) 54 | else: 55 | # The bbox for the entire table 56 | bbox = [ 57 | line.bbox 58 | for line in layout_pred.bboxes 59 | if line.label in ["Table", "TableOfContents"] 60 | ] 61 | # Number of tables per page 62 | table_counts.append(len(bbox)) 63 | 64 | if len(bbox) == 0: 65 | continue 66 | 67 | page_table_imgs = [] 68 | highres_bbox = [] 69 | for bb in bbox: 70 | highres_bb = rescale_bbox(bb, img.size, highres_img.size) 71 | highres_bb = expand_bbox(highres_bb) 72 | page_table_imgs.append(highres_img.crop(highres_bb)) 73 | highres_bbox.append(highres_bb) 74 | 75 | table_imgs.extend(page_table_imgs) 76 | 77 | table_preds = table_rec_predictor(table_imgs) 78 | 79 | img_idx = 0 80 | prev_count = 0 81 | table_predictions = defaultdict(list) 82 | for i in range(sum(table_counts)): 83 | while i >= prev_count + table_counts[img_idx]: 84 | prev_count +=
table_counts[img_idx] 85 | img_idx += 1 86 | 87 | pred = table_preds[i] 88 | orig_name = loader.names[img_idx] 89 | pnum = pnums[img_idx] 90 | table_img = table_imgs[i] 91 | 92 | out_pred = pred.model_dump() 93 | out_pred["page"] = pnum + 1 94 | table_idx = i - prev_count 95 | out_pred["table_idx"] = table_idx 96 | table_predictions[orig_name].append(out_pred) 97 | 98 | if loader.save_images: 99 | rows = [line.bbox for line in pred.rows] 100 | cols = [line.bbox for line in pred.cols] 101 | row_labels = [f"Row {line.row_id}" for line in pred.rows] 102 | col_labels = [f"Col {line.col_id}" for line in pred.cols] 103 | cells = [line.bbox for line in pred.cells] 104 | 105 | rc_image = copy.deepcopy(table_img) 106 | rc_image = draw_bboxes_on_image( 107 | rows, rc_image, labels=row_labels, label_font_size=20, color="blue" 108 | ) 109 | rc_image = draw_bboxes_on_image( 110 | cols, rc_image, labels=col_labels, label_font_size=20, color="red" 111 | ) 112 | rc_image.save( 113 | os.path.join( 114 | loader.result_path, f"{name}_page{pnum + 1}_table{table_idx}_rc.png" 115 | ) 116 | ) 117 | 118 | cell_image = copy.deepcopy(table_img) 119 | cell_image = draw_bboxes_on_image(cells, cell_image, color="green") 120 | cell_image.save( 121 | os.path.join( 122 | loader.result_path, 123 | f"{name}_page{pnum + 1}_table{table_idx}_cells.png", 124 | ) 125 | ) 126 | 127 | with open( 128 | os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8" 129 | ) as f: 130 | json.dump(table_predictions, f, ensure_ascii=False) 131 | 132 | logger.info(f"Wrote results to {loader.result_path}") 133 | -------------------------------------------------------------------------------- /surya/scripts/texify_app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from typing import List 4 | 5 | from surya.recognition import RecognitionPredictor 6 | from surya.common.surya.schema import TaskNames 7 | 8 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = ( 9 | "1" # For some reason, transformers decided to use .isin for a simple op, which is not supported on MPS 10 | ) 11 | 12 | import io 13 | 14 | import pandas as pd 15 | import streamlit as st 16 | from streamlit_drawable_canvas import st_canvas 17 | import hashlib 18 | import pypdfium2 19 | 20 | from surya.settings import settings 21 | from PIL import Image 22 | 23 | MAX_WIDTH = 800 24 | MAX_HEIGHT = 1000 25 | 26 | 27 | def replace_fences(text): 28 | text = re.sub(r'<math display="block">(.*?)</math>', r"$$\1$$", text) 29 | text = re.sub(r"<math>(.*?)</math>", r"$\1$", text) 30 | text = re.sub(r'<math display="inline">(.*?)</math>', r"$\1$", text) 31 | return text 32 | 33 | 34 | @st.cache_resource() 35 | def load_predictor(): 36 | return RecognitionPredictor() 37 | 38 | 39 | @st.cache_data() 40 | def inference(pil_image: Image.Image, bbox: List[float]): 41 | input_img = pil_image.crop(bbox) 42 | bbox = [0, 0, input_img.width, input_img.height] 43 | model_output = predictor( 44 | [input_img], [TaskNames.block_without_boxes], bboxes=[[bbox]] 45 | ) 46 | return model_output[0].text_lines[0].text 47 | 48 | 49 | def open_pdf(pdf_file): 50 | stream = io.BytesIO(pdf_file.getvalue()) 51 | return pypdfium2.PdfDocument(stream) 52 | 53 | 54 | @st.cache_data() 55 | def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI_HIGHRES): 56 | doc = open_pdf(pdf_file) 57 | renderer = doc.render( 58 | pypdfium2.PdfBitmap.to_pil, 59 | page_indices=[page_num - 1], 60 | scale=dpi / 72, 61 | ) 62 | png = list(renderer)[0] 63 | png_image = png.convert("RGB") 64 | doc.close() 65 | return png_image 66 | 67
| 68 | @st.cache_data() 69 | def page_counter(pdf_file): 70 | doc = open_pdf(pdf_file) 71 | doc_len = len(doc) 72 | doc.close() 73 | return doc_len 74 | 75 | 76 | def resize_image(pil_image): 77 | if pil_image is None: 78 | return 79 | pil_image.thumbnail((MAX_WIDTH, MAX_HEIGHT), Image.Resampling.LANCZOS) 80 | 81 | 82 | def get_canvas_hash(pil_image): 83 | return hashlib.md5(pil_image.tobytes()).hexdigest() 84 | 85 | 86 | st.set_page_config(layout="wide") 87 | 88 | top_message = """### LaTeX OCR 89 | 90 | After the model loads, upload an image or a pdf, then draw a box around the equation or text you want to OCR by clicking and dragging. Surya will convert it to Markdown with LaTeX math on the right. 91 | """ 92 | 93 | st.markdown(top_message) 94 | col1, col2 = st.columns([0.7, 0.3]) 95 | 96 | predictor = load_predictor() 97 | 98 | in_file = st.sidebar.file_uploader( 99 | "PDF file or image:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"] 100 | ) 101 | if in_file is None: 102 | st.stop() 103 | 104 | if in_file is None: 105 | st.stop() 106 | 107 | filetype = in_file.type 108 | page_count = None 109 | if "pdf" in filetype: 110 | page_count = page_counter(in_file) 111 | page_number = st.sidebar.number_input( 112 | f"Page number out of {page_count}:", min_value=1, value=1, max_value=page_count 113 | ) 114 | pil_image = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES) 115 | else: 116 | pil_image = Image.open(in_file).convert("RGB") 117 | page_number = None 118 | 119 | if pil_image is None: 120 | st.stop() 121 | 122 | pil_image.thumbnail((MAX_WIDTH, MAX_HEIGHT), Image.Resampling.LANCZOS) 123 | canvas_hash = get_canvas_hash(pil_image) 124 | 125 | with col1: 126 | # Create a canvas component 127 | canvas_result = st_canvas( 128 | fill_color="rgba(255, 165, 0, 0.1)", # Fixed fill color with some opacity 129 | stroke_width=1, 130 | stroke_color="#FFAA00", 131 | background_color="#FFF", 132 | background_image=pil_image, 133 | update_streamlit=True, 134 | height=pil_image.height, 135 | width=pil_image.width, 136 | drawing_mode="rect", 137 | point_display_radius=0, 138 | key=canvas_hash, 139 | ) 140 | 141 | if not canvas_result.json_data: 142 | st.stop() 143 | 144 | objects = pd.json_normalize( 145 | canvas_result.json_data["objects"] 146 | ) # need to convert obj to str because PyArrow 147 | bbox_list = None 148 | if objects.shape[0] > 0: 149 | boxes = objects[objects["type"] == "rect"][["left", "top", "width", "height"]] 150 | boxes["right"] = boxes["left"] + boxes["width"] 151 | boxes["bottom"] = boxes["top"] + boxes["height"] 152 | bbox_list = boxes[["left", "top", "right", "bottom"]].values.tolist() 153 | 154 | if bbox_list: 155 | with col2: 156 | texts = [inference(pil_image, bbox) for bbox in bbox_list] 157 | for idx, latex in enumerate(reversed(texts)): 158 | st.markdown(f"### {len(texts) - idx}") 159 | st.markdown(replace_fences(latex), unsafe_allow_html=True) 160 | st.code(latex) 161 | st.divider() 162 | 163 | with col2: 164 | tips = """ 165 | ### Usage tips 166 | - Texify is sensitive to how you draw the box around the text you want to OCR. If you get bad results, try selecting a slightly different box, or splitting the box into multiple. 
167 | """ 168 | st.markdown(tips) 169 | -------------------------------------------------------------------------------- /surya/table_rec/loader.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | 5 | from surya.common.load import ModelLoader 6 | from surya.logging import get_logger 7 | from surya.settings import settings 8 | from surya.table_rec.model.config import ( 9 | SuryaTableRecConfig, 10 | SuryaTableRecDecoderConfig, 11 | DonutSwinTableRecConfig, 12 | ) 13 | from surya.table_rec.model.encoderdecoder import TableRecEncoderDecoderModel 14 | from surya.table_rec.processor import SuryaTableRecProcessor 15 | 16 | logger = get_logger() 17 | 18 | 19 | class TableRecModelLoader(ModelLoader): 20 | def __init__(self, checkpoint: Optional[str] = None): 21 | super().__init__(checkpoint) 22 | 23 | if self.checkpoint is None: 24 | self.checkpoint = settings.TABLE_REC_MODEL_CHECKPOINT 25 | 26 | def model( 27 | self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE 28 | ) -> TableRecEncoderDecoderModel: 29 | if device is None: 30 | device = settings.TORCH_DEVICE_MODEL 31 | if dtype is None: 32 | dtype = settings.MODEL_DTYPE 33 | 34 | config = SuryaTableRecConfig.from_pretrained(self.checkpoint) 35 | decoder_config = config.decoder 36 | decoder = SuryaTableRecDecoderConfig(**decoder_config) 37 | config.decoder = decoder 38 | 39 | encoder_config = config.encoder 40 | encoder = DonutSwinTableRecConfig(**encoder_config) 41 | config.encoder = encoder 42 | 43 | model = TableRecEncoderDecoderModel.from_pretrained( 44 | self.checkpoint, config=config, torch_dtype=dtype 45 | ) 46 | 47 | model = model.to(device) 48 | model = model.eval() 49 | 50 | if settings.COMPILE_ALL or settings.COMPILE_TABLE_REC: 51 | torch.set_float32_matmul_precision("high") 52 | torch._dynamo.config.cache_size_limit = 16 53 | torch._dynamo.config.suppress_errors = False 54 | 55 | logger.info( 56 | f"Compiling table recognition model {self.checkpoint} on device {device} with dtype {dtype}" 57 | ) 58 | compile_args = {"backend": "openxla"} if device == "xla" else {} 59 | model.encoder = torch.compile(model.encoder, **compile_args) 60 | model.decoder = torch.compile(model.decoder, **compile_args) 61 | 62 | logger.debug( 63 | f"Loaded table recognition model {self.checkpoint} from {TableRecEncoderDecoderModel.get_local_path(self.checkpoint)} onto device {device} with dtype {dtype}" 64 | ) 65 | return model 66 | 67 | def processor( 68 | self, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE 69 | ) -> SuryaTableRecProcessor: 70 | processor = SuryaTableRecProcessor(self.checkpoint) 71 | 72 | processor.token_pad_id = 0 73 | processor.token_eos_id = 1 74 | processor.token_bos_id = 1 75 | processor.token_query_end_id = 4 76 | return processor 77 | -------------------------------------------------------------------------------- /surya/table_rec/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikParuchuri/surya/8023f3a50f41f5eaefcdf423dd379ceaff0f504a/surya/table_rec/model/__init__.py -------------------------------------------------------------------------------- /surya/table_rec/model/encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from surya.common.donut.encoder import DonutSwinPreTrainedModel, 
DonutSwinModelOutput, DonutSwinEmbeddings, DonutSwinEncoder 7 | 8 | 9 | class DonutSwinModel(DonutSwinPreTrainedModel): 10 | def __init__(self, config, add_pooling_layer=True, use_mask_token=False): 11 | super().__init__(config) 12 | self.config = config 13 | self.num_layers = len(config.depths) 14 | self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) 15 | 16 | self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token) 17 | self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid) 18 | 19 | self.position_embeddings = None 20 | if hasattr(config, "encoder_length"): 21 | self.position_embeddings = nn.Parameter(torch.zeros(1, config.encoder_length, config.hidden_size)) 22 | 23 | # Initialize weights and apply final processing 24 | self.post_init() 25 | 26 | def get_input_embeddings(self): 27 | return self.embeddings.patch_embeddings 28 | 29 | def _prune_heads(self, heads_to_prune): 30 | """ 31 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 32 | class PreTrainedModel 33 | """ 34 | for layer, heads in heads_to_prune.items(): 35 | self.encoder.layer[layer].attention.prune_heads(heads) 36 | 37 | def forward( 38 | self, 39 | pixel_values: Optional[torch.FloatTensor] = None, 40 | bool_masked_pos: Optional[torch.BoolTensor] = None, 41 | head_mask: Optional[torch.FloatTensor] = None, 42 | output_attentions: Optional[bool] = None, 43 | output_hidden_states: Optional[bool] = None, 44 | interpolate_pos_encoding: bool = False, 45 | return_dict: Optional[bool] = None, 46 | ) -> Union[Tuple, DonutSwinModelOutput]: 47 | r""" 48 | bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): 49 | Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). 
50 | """ 51 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 52 | output_hidden_states = ( 53 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 54 | ) 55 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 56 | 57 | if pixel_values is None: 58 | raise ValueError("You have to specify pixel_values") 59 | 60 | # Prepare head mask if needed 61 | # 1.0 in head_mask indicate we keep the head 62 | # attention_probs has shape bsz x n_heads x N x N 63 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 64 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 65 | head_mask = self.get_head_mask(head_mask, len(self.config.depths)) 66 | 67 | embedding_output, input_dimensions = self.embeddings( 68 | pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding 69 | ) 70 | 71 | encoder_outputs = self.encoder( 72 | embedding_output, 73 | input_dimensions, 74 | head_mask=head_mask, 75 | output_attentions=output_attentions, 76 | output_hidden_states=output_hidden_states, 77 | return_dict=return_dict, 78 | ) 79 | 80 | last_hidden_state = encoder_outputs[0] 81 | 82 | if self.position_embeddings is not None: 83 | last_hidden_state += self.position_embeddings[:, :last_hidden_state.size(1), :] 84 | 85 | return DonutSwinModelOutput( 86 | last_hidden_state=last_hidden_state, 87 | ) 88 | -------------------------------------------------------------------------------- /surya/table_rec/model/encoderdecoder.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Union, Tuple, Dict 3 | 4 | import torch 5 | from transformers import PreTrainedModel, VisionEncoderDecoderConfig, PretrainedConfig 6 | from surya.common.s3 import S3DownloaderMixin 7 | from surya.table_rec.model.decoder import SuryaTableRecDecoder 8 | from surya.table_rec.model.encoder import DonutSwinModel 9 | from transformers.utils import ModelOutput 10 | 11 | 12 | @dataclass 13 | class TableRecOutput(ModelOutput): 14 | box_property_logits: Dict[str, torch.FloatTensor] 15 | decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None 16 | 17 | 18 | class TableRecEncoderDecoderModel(S3DownloaderMixin, PreTrainedModel): 19 | config_class = VisionEncoderDecoderConfig 20 | base_model_prefix = "vision_encoder_decoder" 21 | main_input_name = "pixel_values" 22 | supports_gradient_checkpointing = True 23 | _supports_param_buffer_assignment = False 24 | 25 | def __init__( 26 | self, 27 | config: Optional[PretrainedConfig] = None, 28 | encoder: Optional[PreTrainedModel] = None, 29 | decoder: Optional[PreTrainedModel] = None, 30 | ): 31 | # initialize with config 32 | # make sure input & output embeddings is not tied 33 | config.tie_word_embeddings = False 34 | config.decoder.tie_word_embeddings = False 35 | super().__init__(config) 36 | 37 | if encoder is None: 38 | encoder = DonutSwinModel(config.encoder) 39 | 40 | if decoder is None: 41 | decoder = SuryaTableRecDecoder(config.decoder, attn_implementation=config._attn_implementation) 42 | 43 | self.encoder = encoder 44 | self.decoder = decoder 45 | 46 | # make sure that the individual model's config refers to the shared config 47 | # so that the updates to the config will be synced 48 | self.encoder.config = self.config.encoder 49 | self.decoder.config = 
self.config.decoder 50 | 51 | def get_encoder(self): 52 | return self.encoder 53 | 54 | def get_decoder(self): 55 | return self.decoder 56 | 57 | def get_output_embeddings(self): 58 | return self.decoder.get_output_embeddings() 59 | 60 | def set_output_embeddings(self, new_embeddings): 61 | return self.decoder.set_output_embeddings(new_embeddings) 62 | 63 | def forward( 64 | self, 65 | decoder_input_ids: torch.LongTensor = None, 66 | decoder_cache_position: Optional[torch.LongTensor] = None, 67 | decoder_attention_mask: Optional[torch.LongTensor] = None, 68 | encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, 69 | use_cache: Optional[bool] = None, 70 | return_dict: Optional[bool] = None, 71 | **kwargs, 72 | ) -> Union[Tuple[torch.FloatTensor], TableRecOutput]: 73 | kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} 74 | 75 | kwargs_decoder = { 76 | argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") 77 | } 78 | 79 | # Decode 80 | decoder_outputs = self.decoder( 81 | input_labels=decoder_input_ids, 82 | input_boxes_counts=None, 83 | cache_position=decoder_cache_position, 84 | attention_mask=decoder_attention_mask, 85 | encoder_hidden_states=encoder_outputs, 86 | encoder_attention_mask=None, 87 | use_cache=use_cache, 88 | **kwargs_decoder, 89 | ) 90 | 91 | return TableRecOutput( 92 | box_property_logits=decoder_outputs.box_property_logits, 93 | decoder_hidden_states=decoder_outputs.hidden_states, 94 | ) 95 | 96 | def resize_token_embeddings(self, *args, **kwargs): 97 | raise NotImplementedError( 98 | "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the" 99 | " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))" 100 | ) 101 | 102 | def _reorder_cache(self, past_key_values, beam_idx): 103 | # apply decoder cache reordering here 104 | return self.decoder._reorder_cache(past_key_values, beam_idx) -------------------------------------------------------------------------------- /surya/table_rec/processor.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import PIL 4 | import torch 5 | from transformers import ProcessorMixin 6 | 7 | from surya.common.s3 import S3DownloaderMixin 8 | from surya.common.donut.processor import SuryaEncoderImageProcessor 9 | from surya.table_rec.shaper import LabelShaper 10 | from surya.settings import settings 11 | from surya.table_rec.model.config import BOX_DIM, SPECIAL_TOKENS 12 | 13 | 14 | class SuryaTableRecProcessor(S3DownloaderMixin, ProcessorMixin): 15 | attributes = ["image_processor"] 16 | image_processor_class = "AutoImageProcessor" 17 | 18 | def __init__(self, checkpoint, **kwargs): 19 | image_processor = SuryaEncoderImageProcessor.from_pretrained(checkpoint) 20 | image_processor.do_align_long_axis = False 21 | image_processor.max_size = settings.TABLE_REC_IMAGE_SIZE 22 | self.image_processor = image_processor 23 | super().__init__(image_processor) 24 | 25 | self.box_size = (BOX_DIM, BOX_DIM) 26 | self.special_token_count = SPECIAL_TOKENS 27 | self.shaper = LabelShaper() 28 | 29 | def resize_polygon(self, polygon, orig_size, new_size): 30 | w_scaler = new_size[0] / orig_size[0] 31 | h_scaler = new_size[1] / orig_size[1] 32 | 33 | for corner in polygon: 34 | corner[0] = corner[0] * w_scaler 35 | corner[1] = corner[1] * h_scaler 36 | 37 | if corner[0] < 0: 38 | corner[0] = 
0 39 | if corner[1] < 0: 40 | corner[1] = 0 41 | if corner[0] > new_size[0]: 42 | corner[0] = new_size[0] 43 | if corner[1] > new_size[1]: 44 | corner[1] = new_size[1] 45 | 46 | return polygon 47 | 48 | def __call__( 49 | self, 50 | images: List[PIL.Image.Image] | None, 51 | query_items: List[dict], 52 | columns: List[dict] | None = None, 53 | convert_images: bool = True, 54 | *args, 55 | **kwargs 56 | ): 57 | if convert_images: 58 | assert len(images) == len(query_items) 59 | assert len(images) > 0 60 | 61 | # Resize input query items 62 | for image, query_item in zip(images, query_items): 63 | query_item["polygon"] = self.resize_polygon(query_item["polygon"], image.size, self.box_size) 64 | 65 | query_items = self.shaper.convert_polygons_to_bboxes(query_items) 66 | query_labels = self.shaper.dict_to_labels(query_items) 67 | 68 | decoder_input_boxes = [] 69 | col_count = len(query_labels[0]) 70 | for label in query_labels: 71 | decoder_input_boxes.append([ 72 | [self.token_bos_id] * col_count, 73 | label, 74 | [self.token_query_end_id] * col_count 75 | ]) 76 | 77 | # Add columns to end of decoder input 78 | if columns: 79 | columns = self.shaper.convert_polygons_to_bboxes(columns) 80 | column_labels = self.shaper.dict_to_labels(columns) 81 | for decoder_box in decoder_input_boxes: 82 | decoder_box += column_labels 83 | 84 | input_boxes = torch.tensor(decoder_input_boxes, dtype=torch.long) 85 | input_boxes_mask = torch.ones_like(input_boxes, dtype=torch.long) 86 | 87 | inputs = { 88 | "input_ids": input_boxes, 89 | "attention_mask": input_boxes_mask 90 | } 91 | if convert_images: 92 | inputs["pixel_values"] = self.image_processor(images, *args, **kwargs)["pixel_values"] 93 | return inputs 94 | -------------------------------------------------------------------------------- /surya/table_rec/schema.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from pydantic import BaseModel 4 | 5 | from surya.common.polygon import PolygonBox 6 | 7 | 8 | class TableCell(PolygonBox): 9 | row_id: int 10 | colspan: int 11 | within_row_id: int 12 | cell_id: int 13 | is_header: bool 14 | rowspan: int | None = None 15 | merge_up: bool = False 16 | merge_down: bool = False 17 | col_id: int | None = None 18 | text_lines: List[dict] | None = None 19 | 20 | @property 21 | def label(self): 22 | return f'Cell {self.cell_id} {self.rowspan}/{self.colspan}' 23 | 24 | 25 | class TableRow(PolygonBox): 26 | row_id: int 27 | is_header: bool 28 | 29 | @property 30 | def label(self): 31 | return f'Row {self.row_id}' 32 | 33 | 34 | class TableCol(PolygonBox): 35 | col_id: int 36 | is_header: bool 37 | 38 | @property 39 | def label(self): 40 | return f'Column {self.col_id}' 41 | 42 | 43 | class TableResult(BaseModel): 44 | cells: List[TableCell] 45 | unmerged_cells: List[TableCell] 46 | rows: List[TableRow] 47 | cols: List[TableCol] 48 | image_bbox: List[float] 49 | -------------------------------------------------------------------------------- /surya/table_rec/shaper.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Dict 3 | import numpy as np 4 | 5 | from surya.table_rec.model.config import BOX_PROPERTIES, SPECIAL_TOKENS, BOX_DIM 6 | 7 | 8 | class LabelShaper: 9 | def __init__(self): 10 | self.property_keys = [k for (k, kcount, mode) in BOX_PROPERTIES] 11 | 12 | def dict_to_labels(self, label_components: List[dict]): 13 | if len(label_components) == 0: 14 | return [] 15 | 16 | 
out_list = [] 17 | for (k, kcount, mode) in BOX_PROPERTIES: 18 | for label_component in label_components: 19 | if k not in label_component: 20 | raise ValueError(f"Missing key {k} in label component {label_component}") 21 | 22 | if mode == "classification": 23 | assert isinstance(label_component[k], int) 24 | elif mode == "regression": 25 | assert (isinstance(label_component[k], (int, float)) and kcount == 1) or len(label_component[k]) == kcount 26 | else: 27 | raise ValueError(f"Invalid mode {k['mode']} for key {k}") 28 | 29 | for label_component in label_components: 30 | bbox = label_component["bbox"] 31 | for i in range(len(bbox)): 32 | if bbox[i] < 0: 33 | bbox[i] = 0 34 | if bbox[i] > BOX_DIM: 35 | bbox[i] = BOX_DIM 36 | 37 | vector = [] 38 | for (k, kcount, mode) in BOX_PROPERTIES: 39 | item = label_component[k] 40 | if isinstance(item, (list, tuple)): 41 | vector += list(item) 42 | elif isinstance(item, (float, int)): 43 | if mode == "classification": 44 | # Shift up for model 45 | item += SPECIAL_TOKENS 46 | vector.append(item) 47 | else: 48 | raise ValueError(f"Invalid item {item} for key {k}") 49 | 50 | out_list.append(vector) 51 | 52 | return out_list 53 | 54 | def component_idx(self, key): 55 | idx = 0 56 | for (k, kcount, mode) in BOX_PROPERTIES: 57 | if mode == "regression": 58 | incr = kcount 59 | elif mode == "classification": 60 | incr = 1 61 | else: 62 | raise ValueError(f"Invalid mode {mode} for key {k}") 63 | if k == key: 64 | return (idx, idx + incr) 65 | idx += incr 66 | raise ValueError(f"Key {key} not found in properties") 67 | 68 | def get_box_property(self, key, add_special_tokens=True): 69 | for (k, kcount, mode) in BOX_PROPERTIES: 70 | if k == key: 71 | # Add special token count 72 | if mode == "classification" and add_special_tokens: 73 | kcount += SPECIAL_TOKENS 74 | return (k, kcount, mode) 75 | raise ValueError(f"Key {key} not found in properties") 76 | 77 | def component_idx_dict(self): 78 | idx_dict = {} 79 | for (k, kcount, mode) in BOX_PROPERTIES: 80 | idx_dict[k] = self.component_idx(k) 81 | return idx_dict 82 | 83 | def convert_polygons_to_bboxes(self, label_components: List[Dict]): 84 | for i, label_component in enumerate(label_components): 85 | poly = label_component["polygon"] 86 | poly = np.clip(poly, 0, BOX_DIM) 87 | 88 | (x1, y1), (x2, y2), (x3, y3), (x4, y4) = poly 89 | cx = (x1 + x2 + x3 + x4) / 4 90 | cy = (y1 + y2 + y3 + y4) / 4 91 | width = (x2 + x3) / 2 - (x1 + x4) / 2 92 | height = (y3 + y4) / 2 - (y2 + y1) / 2 93 | bottom_avg_x = (x3 + x4) / 2 94 | top_avg_x = (x1 + x2) / 2 95 | right_avg_y = (y2 + y3) / 2 96 | left_avg_y = (y1 + y4) / 2 97 | 98 | x_skew = bottom_avg_x - top_avg_x 99 | y_skew = right_avg_y - left_avg_y 100 | x_skew += BOX_DIM // 2 # Shift up into positive space 101 | y_skew += BOX_DIM // 2 # Shift up into positive space 102 | new_poly = [ 103 | cx, 104 | cy, 105 | width, 106 | height, 107 | x_skew, 108 | y_skew 109 | ] 110 | label_component["bbox"] = new_poly 111 | 112 | return label_components 113 | 114 | def convert_bbox_to_polygon(self, box, skew_scaler=BOX_DIM // 2, skew_min=.001): 115 | cx = box[0] 116 | cy = box[1] 117 | width = box[2] 118 | height = box[3] 119 | x1 = cx - width / 2 120 | y1 = cy - height / 2 121 | x2 = cx + width / 2 122 | y2 = cy + height / 2 123 | skew_x = math.floor((box[4] - skew_scaler) / 2) 124 | skew_y = math.floor((box[5] - skew_scaler) / 2) 125 | 126 | # Ensures we don't get slightly warped boxes 127 | # Note that the values are later scaled, so this is in 1/1024 space 128 | if abs(skew_x) 
< skew_min: 129 | skew_x = 0 130 | 131 | if abs(skew_y) < skew_min: 132 | skew_y = 0 133 | 134 | polygon = [x1 - skew_x, y1 - skew_y, x2 - skew_x, y1 + skew_y, x2 + skew_x, y2 + skew_y, x1 + skew_x, 135 | y2 - skew_y] 136 | poly = [] 137 | for i in range(4): 138 | poly.append([ 139 | polygon[2 * i], 140 | polygon[2 * i + 1] 141 | ]) 142 | return poly 143 | 144 | 145 | 146 | -------------------------------------------------------------------------------- /table_recognition.py: -------------------------------------------------------------------------------- 1 | from surya.scripts.table_recognition import table_recognition_cli 2 | 3 | if __name__ == "__main__": 4 | table_recognition_cli() -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" 4 | 5 | import pytest 6 | from PIL import Image, ImageDraw 7 | 8 | from surya.detection import DetectionPredictor 9 | from surya.ocr_error import OCRErrorPredictor 10 | from surya.layout import LayoutPredictor 11 | from surya.recognition import RecognitionPredictor 12 | from surya.table_rec import TableRecPredictor 13 | 14 | 15 | @pytest.fixture(scope="session") 16 | def ocr_error_predictor() -> OCRErrorPredictor: 17 | ocr_error_predictor = OCRErrorPredictor() 18 | yield ocr_error_predictor 19 | del ocr_error_predictor 20 | 21 | 22 | @pytest.fixture(scope="session") 23 | def layout_predictor() -> LayoutPredictor: 24 | layout_predictor = LayoutPredictor() 25 | yield layout_predictor 26 | del layout_predictor 27 | 28 | 29 | @pytest.fixture(scope="session") 30 | def detection_predictor() -> DetectionPredictor: 31 | detection_predictor = DetectionPredictor() 32 | yield detection_predictor 33 | del detection_predictor 34 | 35 | 36 | @pytest.fixture(scope="session") 37 | def recognition_predictor() -> RecognitionPredictor: 38 | recognition_predictor = RecognitionPredictor() 39 | yield recognition_predictor 40 | del recognition_predictor 41 | 42 | 43 | @pytest.fixture(scope="session") 44 | def table_rec_predictor() -> TableRecPredictor: 45 | table_rec_predictor = TableRecPredictor() 46 | yield table_rec_predictor 47 | del table_rec_predictor 48 | 49 | 50 | @pytest.fixture() 51 | def test_image(): 52 | image = Image.new("RGB", (1024, 1024), "white") 53 | draw = ImageDraw.Draw(image) 54 | draw.text((10, 10), "Hello World", fill="black", font_size=72) 55 | draw.text( 56 | (10, 200), 57 | "This is a sentence of text.\nNow it is a paragraph.\nA three-line one.", 58 | fill="black", 59 | font_size=24, 60 | ) 61 | return image 62 | 63 | 64 | @pytest.fixture() 65 | def test_image_tall(): 66 | image = Image.new("RGB", (4096, 4096), "white") 67 | draw = ImageDraw.Draw(image) 68 | draw.text((10, 10), "Hello World", fill="black", font_size=72) 69 | draw.text( 70 | (4000, 4000), 71 | "This is a sentence of text.\n\nNow it is a paragraph.\n\nA three-line one.", 72 | fill="black", 73 | font_size=24, 74 | ) 75 | return image 76 | -------------------------------------------------------------------------------- /tests/test_detection.py: -------------------------------------------------------------------------------- 1 | def test_detection(detection_predictor, test_image): 2 | detection_results = detection_predictor([test_image]) 3 | 4 | assert len(detection_results) == 1 5 | assert detection_results[0].image_bbox == [0, 0, 1024, 1024] 6 | 7 | bboxes = detection_results[0].bboxes 8 | 
| assert len(bboxes) == 4 9 | 10 | 11 | def test_detection_chunking(detection_predictor, test_image_tall): 12 | detection_results = detection_predictor([test_image_tall]) 13 | 14 | assert len(detection_results) == 1 15 | assert detection_results[0].image_bbox == [0, 0, 4096, 4096] 16 | 17 | bboxes = detection_results[0].bboxes 18 | assert len(bboxes) >= 3 # Sometimes merges into 3 19 | assert abs(4000 - bboxes[1].polygon[0][0]) < 50 -------------------------------------------------------------------------------- /tests/test_latex_ocr.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from PIL import Image, ImageDraw 4 | 5 | from surya.common.surya.schema import TaskNames 6 | from surya.recognition import OCRResult 7 | 8 | 9 | def test_latex_ocr(recognition_predictor): 10 | img = Image.new("RGB", (200, 100), color="white") 11 | draw = ImageDraw.Draw(img) 12 | draw.text((10, 10), "E = mc2", fill="black", font_size=48) 13 | 14 | results: List[OCRResult] = recognition_predictor( 15 | [img], [TaskNames.block_without_boxes], bboxes=[[[0, 0, 200, 100]]] 16 | ) 17 | text = results[0].text_lines[0].text 18 | assert len(results) == 1 19 | 20 | assert text.startswith("<math") 21 | assert text.endswith("</math>") 22 | -------------------------------------------------------------------------------- /tests/test_layout.py: -------------------------------------------------------------------------------- 1 | def test_layout_topk(layout_predictor, test_image): 2 | layout_results = layout_predictor([test_image]) 3 | 4 | assert len(layout_results) == 1 5 | assert layout_results[0].image_bbox == [0, 0, 1024, 1024] 6 | 7 | bboxes = layout_results[0].bboxes 8 | assert len(bboxes) == 2 9 | 10 | assert bboxes[0].label == "SectionHeader" 11 | assert len(bboxes[0].top_k) == 5 12 | 13 | assert bboxes[1].label == "Text" 14 | assert len(bboxes[1].top_k) == 5 15 | -------------------------------------------------------------------------------- /tests/test_ocr_errors.py: -------------------------------------------------------------------------------- 1 | def test_garbled_text(ocr_error_predictor): 2 | text = """" 3 | ; dh vksj ls mifLFkr vf/koDrk % Jh vfuy dqekj 4 | 2. vfHk;qDr dh vksj ls mifLFkr vf/koDrk % Jh iznhi d 5 | """.strip() 6 | results = ocr_error_predictor([text]) 7 | assert results.labels[0] == "bad" 8 | 9 | 10 | def test_good_text(ocr_error_predictor): 11 | text = """" 12 | There are professions more harmful than industrial design, but only a very few of them.
13 | """.strip() 14 | results = ocr_error_predictor([text]) 15 | assert results.labels[0] == "good" -------------------------------------------------------------------------------- /tests/test_recognition.py: -------------------------------------------------------------------------------- 1 | import time 2 | from PIL import ImageDraw, Image 3 | 4 | 5 | def test_recognition(recognition_predictor, detection_predictor, test_image): 6 | recognition_results = recognition_predictor([test_image], None, detection_predictor) 7 | 8 | assert len(recognition_results) == 1 9 | assert recognition_results[0].image_bbox == [0, 0, 1024, 1024] 10 | 11 | text_lines = recognition_results[0].text_lines 12 | assert len(text_lines) == 4 13 | assert "Hello World" in text_lines[0].text 14 | 15 | 16 | def test_recognition_input_text(recognition_predictor, detection_predictor, test_image): 17 | start = time.time() 18 | recognition_predictor([test_image], None, detection_predictor) 19 | end = time.time() - start 20 | 21 | input_text = "a" * 400 22 | start2 = time.time() 23 | recognition_results = recognition_predictor( 24 | [test_image], None, detection_predictor, input_text=[input_text] 25 | ) 26 | end2 = time.time() - start2 27 | 28 | assert max([end, end2]) / min([end, end2]) < 1.5, ( 29 | "Input text should be truncated and not change inference time" 30 | ) 31 | 32 | assert len(recognition_results) == 1 33 | assert recognition_results[0].image_bbox == [0, 0, 1024, 1024] 34 | 35 | text_lines = recognition_results[0].text_lines 36 | assert len(text_lines) == 4 37 | assert "Hello World" in text_lines[0].text 38 | 39 | 40 | def test_recognition_drop_repeats(recognition_predictor, detection_predictor): 41 | image = Image.new("RGB", (1024, 128), "white") 42 | draw = ImageDraw.Draw(image) 43 | text = "a" * 80 44 | draw.text((5, 5), text, fill="black", font_size=24) 45 | 46 | recognition_results = recognition_predictor( 47 | [image], None, bboxes=[[[0, 0, 1024, 128]]], drop_repeated_text=True 48 | ) 49 | assert len(recognition_results) == 1 50 | result = recognition_results[0].text_lines 51 | assert result[0].text == "" 52 | -------------------------------------------------------------------------------- /tests/test_table_rec.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageDraw 2 | 3 | def test_table_rec(table_rec_predictor): 4 | data = [ 5 | ["Name", "Age", "City"], 6 | ["Alice", 25, "New York"], 7 | ["Bob", 30, "Los Angeles"], 8 | ["Charlie", 35, "Chicago"], 9 | ] 10 | test_image = draw_table(data) 11 | 12 | results = table_rec_predictor([test_image]) 13 | assert len(results) == 1 14 | assert results[0].image_bbox == [0, 0, test_image.size[0], test_image.size[1]] 15 | 16 | cells = results[0].cells 17 | assert len(cells) == 12 18 | for row_id in range(4): 19 | for col_id in range(3): 20 | cell = [c for c in cells if c.row_id == row_id and c.col_id == col_id] 21 | assert len(cell) == 1, f"Missing cell at row {row_id}, col {col_id}" 22 | 23 | def draw_table(data, cell_width=100, cell_height=40): 24 | rows = len(data) 25 | cols = len(data[0]) 26 | width = cols * cell_width 27 | height = rows * cell_height 28 | 29 | image = Image.new('RGB', (width, height), 'white') 30 | draw = ImageDraw.Draw(image) 31 | 32 | for i in range(rows + 1): 33 | y = i * cell_height 34 | draw.line([(0, y), (width, y)], fill='black', width=1) 35 | 36 | for i in range(cols + 1): 37 | x = i * cell_width 38 | draw.line([(x, 0), (x, height)], fill='black', width=1) 39 | 40 | for i in 
range(rows): 41 | for j in range(cols): 42 | text = str(data[i][j]) 43 | text_bbox = draw.textbbox((0, 0), text) 44 | text_width = text_bbox[2] - text_bbox[0] 45 | text_height = text_bbox[3] - text_bbox[1] 46 | 47 | x = j * cell_width + (cell_width - text_width) // 2 48 | y = i * cell_height + (cell_height - text_height) // 2 49 | 50 | draw.text((x, y), text, fill='black') 51 | 52 | return image -------------------------------------------------------------------------------- /texify_app.py: -------------------------------------------------------------------------------- 1 | from surya.scripts.run_texify_app import texify_app_cli 2 | 3 | if __name__ == "__main__": 4 | texify_app_cli() --------------------------------------------------------------------------------