├── .github
│   └── workflows
│       └── test.yml
├── .gitignore
├── CITATION.cff
├── LICENSE
├── README.rst
├── docs
│   ├── Makefile
│   ├── _static
│   │   ├── blla_heatmap.jpg
│   │   ├── blla_output.jpg
│   │   ├── bw.png
│   │   ├── custom.css
│   │   ├── kraken.png
│   │   ├── kraken_recognition.svg
│   │   ├── kraken_segmentation.svg
│   │   ├── kraken_segmodel.svg
│   │   ├── kraken_torchseqrecognizer.svg
│   │   ├── kraken_workflow.svg
│   │   ├── normal-reproduction-low-resolution.jpg
│   │   └── pat.png
│   ├── _templates
│   │   ├── sidebarintro.html
│   │   └── versions.html
│   ├── advanced.rst
│   ├── alto.xml
│   ├── api.rst
│   ├── api_docs.rst
│   ├── conf.py
│   ├── gpu.rst
│   ├── index.rst
│   ├── ketos.rst
│   ├── models.rst
│   ├── pagexml.xml
│   ├── redirect.html
│   ├── training.rst
│   └── vgsl.rst
├── environment.yml
├── environment_cuda.yml
├── kraken
│   ├── __init__.py
│   ├── align.py
│   ├── binarization.py
│   ├── blla.mlmodel
│   ├── blla.py
│   ├── containers.py
│   ├── contrib
│   │   ├── add_neural_ro.py
│   │   ├── baselineset_overlay.py
│   │   ├── extract_lines.py
│   │   ├── forced_alignment_overlay.py
│   │   ├── generate_scripts.py
│   │   ├── heatmap_overlay.py
│   │   ├── new_classes.json
│   │   ├── print_word_spreader.py
│   │   ├── recognition_boxes.py
│   │   ├── repolygonize.py
│   │   ├── segmentation_overlay.py
│   │   ├── set_seg_options.py
│   │   └── test_per_file.py
│   ├── iso15924.json
│   ├── ketos
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   ├── linegen.py
│   │   ├── pretrain.py
│   │   ├── recognition.py
│   │   ├── repo.py
│   │   ├── ro.py
│   │   ├── segmentation.py
│   │   ├── transcription.py
│   │   └── util.py
│   ├── kraken.py
│   ├── lib
│   │   ├── __init__.py
│   │   ├── arrow_dataset.py
│   │   ├── codec.py
│   │   ├── ctc_decoder.py
│   │   ├── dataset
│   │   │   ├── __init__.py
│   │   │   ├── recognition.py
│   │   │   ├── ro.py
│   │   │   ├── scripts.json
│   │   │   ├── segmentation.py
│   │   │   └── utils.py
│   │   ├── default_specs.py
│   │   ├── exceptions.py
│   │   ├── functional_im_transforms.py
│   │   ├── layers.py
│   │   ├── lineest.py
│   │   ├── log.py
│   │   ├── models.py
│   │   ├── morph.py
│   │   ├── pretrain
│   │   │   ├── __init__.py
│   │   │   ├── layers.py
│   │   │   ├── model.py
│   │   │   └── util.py
│   │   ├── progress.py
│   │   ├── register.py
│   │   ├── ro
│   │   │   ├── __init__.py
│   │   │   ├── layers.py
│   │   │   ├── model.py
│   │   │   └── util.py
│   │   ├── segmentation.py
│   │   ├── sl.py
│   │   ├── train.py
│   │   ├── util.py
│   │   ├── vgsl.py
│   │   └── xml.py
│   ├── linegen.py
│   ├── pageseg.py
│   ├── repo.py
│   ├── rpred.py
│   ├── serialization.py
│   ├── templates
│   │   ├── abbyyxml
│   │   ├── alto
│   │   ├── hocr
│   │   ├── layout.html
│   │   ├── main.js
│   │   ├── pagexml
│   │   ├── report
│   │   └── style.css
│   └── transcribe.py
├── pyproject.toml
├── pytest.ini
├── setup.cfg
├── setup.py
├── singularity
│   └── kraken.def
└── tests
    ├── resources
    │   ├── 000236.gt.txt
    │   ├── 000236.png
    │   ├── 170025120000003,0074-lite.xml
    │   ├── 170025120000003,0074.jpg
    │   ├── 170025120000003,0074.xml
    │   ├── FineReader10-schema-v1.xml
    │   ├── ONB_ibn_19110701_010.tif_line_1548924556947_449.png
    │   ├── alto-4-3.xsd
    │   ├── bl_records.json
    │   ├── bsb00084914_00007.xml
    │   ├── bw.png
    │   ├── cPAS-2000.xml
    │   ├── input.jpg
    │   ├── input.tif
    │   ├── merge_tests
    │   │   ├── 0006.gt.txt
    │   │   ├── 0006.jpg
    │   │   ├── 0007.gt.txt
    │   │   ├── 0007.jpg
    │   │   ├── 0008.gt.txt
    │   │   ├── 0008.jpg
    │   │   ├── 0014.jpg
    │   │   ├── 0014.xml
    │   │   ├── 0021.gt.txt
    │   │   ├── 0021.jpg
    │   │   ├── base.arrow
    │   │   ├── merge_codec_nfd.mlmodel
    │   │   └── merger.arrow
    │   ├── model_small.mlmodel
    │   ├── overfit.mlmodel
    │   ├── overfit_newpoly.mlmodel
    │   ├── pagecontent.xsd
    │   ├── records.json
    │   ├── segmentation.json
    │   └── xlink.xsd
    ├── test_align.py
    ├── test_arrow_dataset.py
    ├── test_binarization.py
    ├── test_cli.py
    ├── test_codec.py
    ├── test_dataset.py
    ├── test_layers.py
    ├── test_lineest.py
    ├── test_merging.py
    ├── test_models.py
    ├── test_newpolygons.py
    ├── test_pageseg.py
    ├── test_readingorder.py
    ├── test_repo.py
    ├── test_rpred.py
    ├── test_serialization.py
    ├── test_train.py
    ├── test_transcribe.py
    ├── test_vgsl.py
    └── test_xml.py
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Lint, test, build, and publish
2 | 
3 | on:
4 |   push:
5 | 
6 | 
7 | jobs:
8 |   lint_and_test:
9 |     name: Runs the linter and tests
10 |     runs-on: ubuntu-latest
11 |     strategy:
12 |       matrix:
13 |         python-version: [3.9, '3.10', '3.11', '3.12']
14 | 
15 |     steps:
16 |       - uses: actions/checkout@v4
17 |       - name: Set up Python ${{ matrix.python-version }}
18 |         uses: actions/setup-python@v5
19 |         with:
20 |           python-version: ${{ matrix.python-version }}
21 |       - name: Install dependencies and kraken
22 |         run: |
23 |           python -m pip install --upgrade pip
24 |           pip install .[test] flake8
25 |       - name: Lint with flake8
26 |         run: |
27 |           # stop the build if there are Python syntax errors or undefined names
28 |           flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
29 |           # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
30 |           flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
31 |       - name: Run tests, except training tests
32 |         run: |
33 |           pytest -k 'not test_train'
34 | 
35 |   build-n-publish-pypi:
36 |     name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI
37 |     needs: lint_and_test
38 |     runs-on: ubuntu-latest
39 |     if: startsWith(github.ref, 'refs/tags/')
40 | 
41 |     steps:
42 |       - uses: actions/checkout@v4
43 |         with:
44 |           fetch-depth: 0
45 |       - name: Set up Python 3.11
46 |         uses: actions/setup-python@v5
47 |         with:
48 |           python-version: 3.11
49 |       - name: Build a binary wheel and a source tarball
50 |         run: |
51 |           python -m pip install build --user
52 |           python -m build --sdist --wheel --outdir dist/ .
53 |       - name: Publish a Python distribution to PyPI
54 |         uses: pypa/gh-action-pypi-publish@release/v1
55 |       - name: Upload PyPI artifacts to GH storage
56 |         # v4 here: artifacts uploaded with upload-artifact@v3 cannot be
57 |         # retrieved by the download-artifact@v4 step further below
58 |         uses: actions/upload-artifact@v4
59 |         with:
60 |           name: pypi_packages
61 |           path: dist/*
62 | 
63 |   autodraft-gh-release:
64 |     name: Create github release
65 |     needs: build-n-publish-pypi
66 |     runs-on: ubuntu-latest
67 | 
68 |     steps:
69 |       - uses: actions/download-artifact@v4
70 |         with:
71 |           name: pypi_packages
72 |           path: pypi
73 |       - uses: "marvinpinto/action-automatic-releases@latest"
74 |         with:
75 |           repo_token: "${{ secrets.GITHUB_TOKEN }}"
76 |           prerelease: false
77 |           draft: true
78 |           files: |
79 |             pypi/*
80 | 
81 |   publish-gh-pages:
82 |     name: Update kraken.re github pages
83 |     needs: lint_and_test
84 |     runs-on: ubuntu-latest
85 |     if: |
86 |       github.ref == 'refs/heads/main' ||
87 |       startsWith(github.ref, 'refs/tags/')
88 | 
89 |     steps:
90 |       - uses: actions/checkout@v4
91 |         with:
92 |           fetch-depth: 0
93 |       - name: Set up Python 3.11
94 |         uses: actions/setup-python@v5
95 |         with:
96 |           python-version: 3.11
97 |       - name: Install sphinx-multiversion
98 |         run: python -m pip install sphinx-multiversion sphinx-autoapi
99 |       - name: Create docs
100 |         run: sphinx-multiversion docs build/html
101 |       - name: Create redirect
102 |         run: cp docs/redirect.html build/html/index.html
103 |       - name: Push gh-pages
104 |         uses: crazy-max/ghaction-github-pages@v4
105 |         with:
106 |           target_branch: gh-pages
107 |           build_dir: build/html
108 |           fqdn: kraken.re
109 |           jekyll: false
110 |         env:
111 |           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
112 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | [._]*.s[a-w][a-z]
2 | [._]s[a-w][a-z]
3 | *.un~
4 | Session.vim
5 | .netrwhist
6 | *~
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 | 
12 | # C extensions
13 | *.so
14 | 
15 | # Distribution / packaging
16 | .Python
17 | env/
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | 
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 | 
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 | 
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *,cover
51 | 
52 | # Sphinx documentation
53 | docs/_build/
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 |   - family-names: "Kiessling"
5 |     given-names: "Benjamin"
6 |     orcid: "https://orcid.org/0000-0001-9543-7827"
7 | title: "The Kraken OCR system"
8 | version: 4.1.2
9 | date-released: 2022-04-12
10 | url: "https://kraken.re"
11 | 
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | Description
2 | ===========
3 | 
4 | .. image:: https://github.com/mittagessen/kraken/actions/workflows/test.yml/badge.svg
5 |    :target: https://github.com/mittagessen/kraken/actions/workflows/test.yml
6 | 
7 | kraken is a turn-key OCR system optimized for historical and non-Latin script
8 | material.
9 | 
10 | kraken's main features are:
11 | 
12 | - Fully trainable layout analysis, reading order, and character recognition
13 | - `Right-to-Left <https://en.wikipedia.org/wiki/Right-to-left>`_, `BiDi
14 |   <https://en.wikipedia.org/wiki/Bi-directional_text>`_, and Top-to-Bottom
15 |   script support
16 | - `ALTO <https://www.loc.gov/standards/alto/>`_, PageXML, abbyyXML, and hOCR
17 |   output
18 | - Word bounding boxes and character cuts
19 | - Multi-script recognition support
20 | - `Public repository <https://zenodo.org/communities/ocr_models>`_ of model files
21 | - Variable recognition network architecture
22 | 
23 | Installation
24 | ============
25 | 
26 | kraken only runs on **Linux or Mac OS X**. Windows is not supported.
27 | 
28 | The latest stable releases can be installed from `PyPI <https://pypi.org/project/kraken/>`_:
29 | 
30 | ::
31 | 
32 |   $ pip install kraken
33 | 
34 | If you want direct PDF and multi-image TIFF/JPEG2000 support it is necessary to
35 | install the `pdf` extras package for PyPI:
36 | 
37 | ::
38 | 
39 |   $ pip install kraken[pdf]
40 | 
41 | or install `pyvips` manually with pip:
42 | 
43 | ::
44 | 
45 |   $ pip install pyvips
46 | 
47 | Conda environment files are provided for the seamless installation of the main
48 | branch as well:
49 | 
50 | ::
51 | 
52 |   $ git clone https://github.com/mittagessen/kraken.git
53 |   $ cd kraken
54 |   $ conda env create -f environment.yml
55 | 
56 | or:
57 | 
58 | ::
59 | 
60 |   $ git clone https://github.com/mittagessen/kraken.git
61 |   $ cd kraken
62 |   $ conda env create -f environment_cuda.yml
63 | 
64 | for CUDA acceleration with the appropriate hardware.
65 | 
66 | Finally you'll have to scrounge up a model to do the actual recognition of
67 | characters. To download the default model for printed French text and place it
68 | in the kraken directory for the current user:
69 | 
70 | ::
71 | 
72 |   $ kraken get 10.5281/zenodo.10592716
73 | 
74 | A list of libre models available in the central repository can be retrieved by
75 | running:
76 | 
77 | ::
78 | 
79 |   $ kraken list
80 | 
81 | Quickstart
82 | ==========
83 | 
84 | Recognizing text on an image using the default parameters including the
85 | prerequisite steps of binarization and page segmentation:
86 | 
87 | ::
88 | 
89 |   $ kraken -i image.tif image.txt binarize segment ocr
90 | 
91 | To binarize a single image using the nlbin algorithm:
92 | 
93 | ::
94 | 
95 |   $ kraken -i image.tif bw.png binarize
96 | 
97 | To segment an image (binarized or not) with the new baseline segmenter:
98 | 
99 | ::
100 | 
101 |   $ kraken -i image.tif lines.json segment -bl
102 | 
103 | 
104 | To segment and OCR an image using the model downloaded above:
105 | 
106 | ::
107 | 
108 |   $ kraken -i image.tif image.txt segment -bl ocr -m catmus-print-fondue-large.mlmodel
109 | 
110 | All subcommands and options are documented. Use the ``--help`` option to get more
111 | information.
112 | 
113 | Documentation
114 | =============
115 | 
116 | Have a look at the `docs <https://kraken.re>`_.
117 | 
118 | Related Software
119 | ================
120 | 
121 | These days kraken is quite closely linked to the `eScriptorium
122 | <https://escriptorium.fr>`_ project developed in the same eScripta research
123 | group. eScriptorium provides a user-friendly interface for annotating data,
124 | training models, and inference (but also much more). There is a `gitter channel
125 | <https://gitter.im/mittagessen/kraken>`_ that is mostly intended for
126 | coordinating technical development but is also a spot to find people with
127 | experience on applying kraken on a wide variety of material.
128 | 
129 | Funding
130 | =======
131 | 
132 | kraken is developed at the `École Pratique des Hautes Études <https://www.ephe.psl.eu>`_, `Université PSL <https://www.psl.eu>`_.
133 | 
134 | .. container:: twocol
135 | 
136 |    .. container::
137 | 
138 |       .. image:: https://raw.githubusercontent.com/mittagessen/kraken/main/docs/_static/normal-reproduction-low-resolution.jpg
139 |          :width: 100
140 |          :alt: Co-financed by the European Union
141 | 
142 |    .. container::
143 | 
144 |       This project was funded in part by the European Union (ERC, MiDRASH,
145 |       project number 101071829).
146 | 
147 | .. container:: twocol
148 | 
149 |    .. container::
150 | 
151 |       .. image:: https://raw.githubusercontent.com/mittagessen/kraken/main/docs/_static/normal-reproduction-low-resolution.jpg
152 |          :width: 100
153 |          :alt: Co-financed by the European Union
154 | 
155 |    .. container::
156 | 
157 |       This project was partially funded through the RESILIENCE project, funded from
158 |       the European Union’s Horizon 2020 Framework Programme for Research and
159 |       Innovation.
160 | 
161 | 
162 | .. container:: twocol
163 | 
164 |    .. container::
165 | 
166 |       .. image:: https://projet.biblissima.fr/sites/default/files/2021-11/biblissima-baseline-sombre-ia.png
167 |          :width: 400
168 |          :alt: Received funding from the Programme d’investissements d’Avenir
169 | 
170 |    .. container::
171 | 
172 |       This work benefited from a French State grant managed by the Agence Nationale de la
173 |       Recherche under the Programme d’Investissements d’Avenir, reference
174 |       ANR-21-ESRE-0005 (Biblissima+).
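
Python API
==========

All of the functionality of the command line drivers above is also exposed as a
Python API. The following is a minimal sketch of the quickstart pipeline, not
canonical usage — the image and model file names are placeholders and error
handling is omitted:

::

   from PIL import Image

   from kraken import blla, rpred
   from kraken.lib import models

   im = Image.open('image.tif')
   # trainable baseline segmentation (the equivalent of `segment -bl`)
   seg = blla.segment(im)
   # load a recognition model and recognize the segmented lines
   net = models.load_any('catmus-print-fondue-large.mlmodel')
   for record in rpred.rpred(net, im, seg):
       print(record.prediction)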
175 | 176 | 177 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/blla_heatmap.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/docs/_static/blla_heatmap.jpg -------------------------------------------------------------------------------- /docs/_static/blla_output.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/docs/_static/blla_output.jpg -------------------------------------------------------------------------------- /docs/_static/bw.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/docs/_static/bw.png -------------------------------------------------------------------------------- /docs/_static/custom.css: -------------------------------------------------------------------------------- 1 | pre { 2 | white-space: pre-wrap; 3 | } 4 | svg { 5 | width: 100%; 6 | } 7 | .highlight .err { 8 | border: inherit; 9 | box-sizing: inherit; 10 | } 11 | 12 | div.leftside { 13 | width: 110px; 14 | padding: 0px 3px 0px 0px; 15 | float: left; 16 | } 17 | 18 | div.rightside { 19 | margin-left: 125px; 20 | } 21 | 22 | dl.py { 23 | margin-top: 25px; 24 | } 25 | -------------------------------------------------------------------------------- /docs/_static/kraken.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/docs/_static/kraken.png -------------------------------------------------------------------------------- /docs/_static/normal-reproduction-low-resolution.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/docs/_static/normal-reproduction-low-resolution.jpg -------------------------------------------------------------------------------- /docs/_static/pat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/docs/_static/pat.png -------------------------------------------------------------------------------- /docs/_templates/sidebarintro.html: -------------------------------------------------------------------------------- 1 
| <h3>Useful Links</h3>
2 | <ul>
....
8 | </ul>
--------------------------------------------------------------------------------
/docs/_templates/versions.html:
--------------------------------------------------------------------------------
1 | {% if versions %}
2 | <h3>{{ _('Versions') }}</h3>
3 | <ul>
4 |   {%- for item in versions %}
5 |   <li><a href="{{ item.url }}">{{ item.name }}</a></li>
6 |   {%- endfor %}
7 | </ul>
8 | {% endif %}
9 | 
10 | 
--------------------------------------------------------------------------------
/docs/alto.xml:
--------------------------------------------------------------------------------
[example ALTO document: the XML markup was lost in extraction; only the source image reference "filename.jpg" and elisions survive]
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 | 
6 | # -- Project information -----------------------------------------------------
7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
8 | 
9 | project = 'kraken'
10 | copyright = '2015-2024, Benjamin Kiessling'
11 | author = 'Benjamin Kiessling'
12 | 
13 | from subprocess import Popen, PIPE
14 | pipe = Popen('git describe --tags --always main', stdout=PIPE, shell=True)
15 | release = pipe.stdout.read().decode('utf-8')
16 | 
17 | extensions = [
18 |     'sphinx.ext.autodoc',
19 |     'sphinx.ext.autodoc.typehints',
20 |     'autoapi.extension',
21 |     'sphinx.ext.napoleon',
22 |     'sphinx.ext.githubpages',
23 |     'sphinx_multiversion',
24 | ]
25 | 
26 | templates_path = ['_templates']
27 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
28 | 
29 | autodoc_typehints = 'description'
30 | 
31 | autoapi_type = 'python'
32 | autoapi_dirs = ['../kraken']
33 | 
34 | autoapi_options = ['members',
35 |                    'undoc-members',
36 |                    #'private-members',
37 |                    #'special-members',
38 |                    'show-inheritance',
39 |                    'show-module-summary',
40 |                    #'imported-members',
41 |                    ]
42 | autoapi_generate_api_docs = False
43 | 
44 | source_suffix = '.rst'
45 | 
46 | master_doc = 'index'
47 | 
48 | language = 'en'
49 | 
50 | pygments_style = 'sphinx'
51 | todo_include_todos = False
52 | 
53 | html_theme = 'alabaster'
54 | html_theme_options = {
55 |     'github_user': 'mittagessen',
56 |     'github_repo': 'kraken',
57 | }
58 | html_logo = '_static/kraken.png'
59 | 
60 | html_static_path = ['_static']
61 | html_css_files = [
62 |     'custom.css',
63 | ]
64 | 
65 | html_sidebars = {
66 |     'index': ['sidebarintro.html', 'navigation.html', 'searchbox.html', 'versions.html'],
67 |     '**': ['localtoc.html', 'relations.html', 'searchbox.html', 'versions.html']
68 | }
69 | 
70 | html_baseurl = 'kraken.re'
71 | htmlhelp_basename = 'krakendoc'
72 | 
73 | smv_branch_whitelist = r'main'
74 | smv_tag_whitelist = r'^[2-9]\.\d+(\.0)?$'
75 | 
--------------------------------------------------------------------------------
/docs/gpu.rst:
--------------------------------------------------------------------------------
1 | .. _gpu:
2 | 
3 | GPU Acceleration
4 | ================
5 | 
6 | The latest version of kraken uses a new pytorch backend which enables GPU
7 | acceleration both for training and recognition. Apart from a compatible Nvidia
8 | GPU, CUDA and cuDNN have to be installed so pytorch can run computation on it.
9 | 
10 | 
11 | 
--------------------------------------------------------------------------------
/docs/models.rst:
--------------------------------------------------------------------------------
1 | .. _models:
2 | 
3 | Models
4 | ======
5 | 
6 | There are currently three kinds of model files containing the recurrent
7 | neural networks that perform all of kraken's character recognition: ``pronn``
8 | files serializing old pickled ``pyrnn`` models as protobuf, clstm's native
9 | serialization, and versatile `Core ML
10 | <https://developer.apple.com/documentation/coreml>`_ models.
11 | 
12 | CoreML
13 | ------
14 | 
15 | Core ML allows arbitrary network architectures in a compact serialization with
16 | metadata. This is the default format in pytorch-based kraken.
17 | 
18 | Segmentation Models
19 | -------------------
20 | 
21 | Recognition Models
22 | ------------------
23 | 
24 | 
25 | 
--------------------------------------------------------------------------------
/docs/pagexml.xml:
--------------------------------------------------------------------------------
[example PAGE XML document: the element markup was lost in extraction; only placeholder line text and elisions survive]
--------------------------------------------------------------------------------
/docs/redirect.html:
--------------------------------------------------------------------------------
1 | <!DOCTYPE html>
2 | <html>
3 |   <head>
4 |     <title>Redirecting to main branch</title>
5 |     <meta charset="utf-8"/>
6 |     <meta http-equiv="refresh" content="0; url=./main/index.html"/>
7 |   </head>
8 |   <body>
9 |   </body>
10 | </html>
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: kraken
2 | channels:
3 |   - defaults
4 |   - conda-forge
5 | dependencies:
6 |   - python>=3.9
7 |   - python-bidi~=0.6.0
8 |   - lxml
9 |   - regex
10 |   - requests
11 |   - click>=8.1
12 |   - numpy~=2.0.0
13 |   - pillow~=9.2.0
14 |   - scipy~=1.13.0
15 |   - jinja2~=3.0
16 |   - conda-forge::torchvision-cpu>=0.5.0
17 |   - conda-forge::pytorch-cpu~=2.4.0
18 |   - jsonschema
19 |   - scikit-learn~=1.2.1
20 |   - scikit-image~=0.24.0
21 |   - shapely>=2.0.6,~=2.0.6
22 |   - pyvips
23 |   - imagemagick>=7.1.0
24 |   - pyarrow
25 |   - importlib-resources>=1.3.0
26 |   - conda-forge::lightning~=2.4.0
27 |   - conda-forge::torchmetrics>=1.1.0
28 |   - conda-forge::threadpoolctl~=3.5.0
29 |   - pip
30 |   - albumentations
31 |   - rich
32 |   - setuptools>=36.6.0,<70.0.0
33 |   - pip:
34 |     - coremltools~=8.1
35 |     - htrmopo
36 |     - platformdirs
37 |     - file:.
38 | 
--------------------------------------------------------------------------------
/environment_cuda.yml:
--------------------------------------------------------------------------------
1 | name: kraken
2 | channels:
3 |   - defaults
4 |   - conda-forge
5 | dependencies:
6 |   - python>=3.9
7 |   - python-bidi~=0.6.0
8 |   - lxml
9 |   - regex
10 |   - requests
11 |   - click>=8.1
12 |   - numpy~=1.23
13 |   - pillow>=9.2.0
14 |   - scipy~=1.13.0
15 |   - jinja2~=3.0
16 |   - conda-forge::torchvision>=0.5.0
17 |   - conda-forge::pytorch~=2.4.0
18 |   - cudatoolkit>=9.2
19 |   - jsonschema
20 |   - scikit-learn~=1.2.1
21 |   - scikit-image~=0.24.0
22 |   - shapely>=2.0.6,~=2.0.6
23 |   - pyvips
24 |   - imagemagick>=7.1.0
25 |   - pyarrow
26 |   - importlib-resources>=1.3.0
27 |   - conda-forge::lightning~=2.4.0
28 |   - conda-forge::torchmetrics>=1.1.0
29 |   - conda-forge::threadpoolctl~=3.5.0
30 |   - pip
31 |   - albumentations
32 |   - rich
33 |   - setuptools>=36.6.0,<70.0.0
34 |   - pip:
35 |     - coremltools~=8.1
36 |     - htrmopo
37 |     - platformdirs
38 |     - file:.
39 | 
--------------------------------------------------------------------------------
/kraken/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | entry point for kraken functionality
3 | """
4 | 
--------------------------------------------------------------------------------
/kraken/align.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 2021 Teklia
3 | # Copyright 2021 Benjamin Kiessling
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 
17 | """
18 | align
19 | ~~~~~
20 | 
21 | A character alignment module using a network output lattice and ground truth to
22 | accurately determine grapheme locations in input data.
23 | """
24 | import dataclasses
25 | import logging
26 | from dataclasses import dataclass
27 | from typing import TYPE_CHECKING, Literal, Optional
28 | 
29 | import torch
30 | from bidi.algorithm import get_display
31 | from PIL import Image
32 | 
33 | from kraken import rpred
34 | from kraken.containers import BaselineOCRRecord, Segmentation
35 | 
36 | if TYPE_CHECKING:
37 |     from kraken.lib.models import TorchSeqRecognizer
38 | 
39 | logger = logging.getLogger('kraken')
40 | 
41 | 
42 | def forced_align(doc: Segmentation, model: 'TorchSeqRecognizer', base_dir: Optional[Literal['L', 'R']] = None) -> Segmentation:
43 |     """
44 |     Performs a forced character alignment of text with recognition model
45 |     output activations.
46 | 
47 |     Args:
48 |         doc: Parsed document.
49 |         model: Recognition model to use for alignment.
50 |         base_dir: Base direction for the BiDi reordering of the ground truth.
51 |     Returns:
52 |         A Segmentation object whose records contain the aligned text.
53 |     """
54 |     im = Image.open(doc.imagename)
55 |     predictor = rpred.rpred(model, im, doc)
56 | 
57 |     records = []
58 | 
59 |     # enable training mode in last layer to get log_softmax output
60 |     model.nn.nn[-1].training = True
61 | 
62 |     for idx, line in enumerate(doc.lines):
63 |         # convert text to display order
64 |         do_text = get_display(line.text, base_dir=base_dir)
65 |         # encode into labels, ignoring unencodable sequences
66 |         labels = model.codec.encode(do_text).long()
67 |         next(predictor)
68 |         if model.outputs.shape[2] < 2*len(labels):
69 |             logger.warning(f'Could not align line {idx}. 
Output sequence length {model.outputs.shape[2]} < ' 70 | f'{2*len(labels)} (length of "{line.text}" after encoding).') 71 | records.append(BaselineOCRRecord('', [], [], line)) 72 | continue 73 | emission = torch.tensor(model.outputs).squeeze().T 74 | trellis = get_trellis(emission, labels) 75 | path = backtrack(trellis, emission, labels) 76 | path = merge_repeats(path, do_text) 77 | pred = [] 78 | pos = [] 79 | conf = [] 80 | for seg in path: 81 | pred.append(seg.label) 82 | pos.append((predictor._scale_val(seg.start, 0, predictor.box.size[0]), 83 | predictor._scale_val(seg.end, 0, predictor.box.size[0]))) 84 | conf.append(seg.score) 85 | records.append(BaselineOCRRecord(pred, pos, conf, line, display_order=True)) 86 | return dataclasses.replace(doc, lines=records) 87 | 88 | 89 | """ 90 | Copied from the forced alignment with Wav2Vec2 tutorial of pytorch available 91 | at: 92 | https://github.com/pytorch/audio/blob/main/examples/tutorials/forced_alignment_tutorial.py 93 | """ 94 | 95 | 96 | @dataclass 97 | class Point: 98 | token_index: int 99 | time_index: int 100 | score: float 101 | 102 | 103 | # Merge the labels 104 | @dataclass 105 | class Segment: 106 | label: str 107 | start: int 108 | end: int 109 | score: float 110 | 111 | def __repr__(self): 112 | return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})" 113 | 114 | @property 115 | def length(self): 116 | return self.end - self.start 117 | 118 | 119 | def get_trellis(emission, tokens): 120 | # width x labels in log domain 121 | num_frame = emission.size(0) 122 | num_tokens = len(tokens) 123 | 124 | # Trellis has extra dimensions for both time axis and tokens. 125 | # The extra dim for tokens represents (start-of-sentence) 126 | # The extra dim for time axis is for simplification of the code. 127 | trellis = torch.empty((num_frame + 1, num_tokens + 1)) 128 | trellis[0, 0] = 0 129 | trellis[1:, 0] = torch.cumsum(emission[:, 0], 0) 130 | trellis[0, -num_tokens:] = -float("inf") 131 | trellis[-num_tokens:, 0] = float("inf") 132 | 133 | for t in range(num_frame): 134 | trellis[t + 1, 1:] = torch.maximum( 135 | # Score for staying at the same token 136 | trellis[t, 1:] + emission[t, 0], 137 | # Score for changing to the next token 138 | trellis[t, :-1] + emission[t, tokens], 139 | ) 140 | return trellis 141 | 142 | 143 | def backtrack(trellis, emission, tokens): 144 | # Note: 145 | # j and t are indices for trellis, which has extra dimensions 146 | # for time and tokens at the beginning. 147 | # When referring to time frame index `T` in trellis, 148 | # the corresponding index in emission is `T-1`. 149 | # Similarly, when referring to token index `J` in trellis, 150 | # the corresponding index in transcript is `J-1`. 151 | j = trellis.size(1) - 1 152 | t_start = torch.argmax(trellis[:, j]).item() 153 | 154 | path = [] 155 | for t in range(t_start, 0, -1): 156 | # 1. Figure out if the current position was stay or change 157 | # Note (again): 158 | # `emission[J-1]` is the emission at time frame `J` of trellis dimension. 159 | # Score for token staying the same from time frame J-1 to T. 160 | stayed = trellis[t - 1, j] + emission[t - 1, 0] 161 | # Score for token changing from C-1 at T-1 to J at T. 162 | changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]] 163 | 164 | # 2. Store the path with frame-wise probability. 165 | prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item() 166 | # Return token index and time index in non-trellis coordinate. 
167 | path.append(Point(j - 1, t - 1, prob)) 168 | 169 | # 3. Update the token 170 | if changed > stayed: 171 | j -= 1 172 | if j == 0: 173 | break 174 | else: 175 | raise ValueError("Failed to align") 176 | return path[::-1] 177 | 178 | 179 | def merge_repeats(path, ground_truth): 180 | i1, i2 = 0, 0 181 | segments = [] 182 | while i1 < len(path): 183 | while i2 < len(path) and path[i1].token_index == path[i2].token_index: 184 | i2 += 1 185 | score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1) 186 | segments.append( 187 | Segment( 188 | ground_truth[path[i1].token_index], 189 | path[i1].time_index, 190 | path[i2 - 1].time_index + 1, 191 | score, 192 | ) 193 | ) 194 | i1 = i2 195 | return segments 196 | -------------------------------------------------------------------------------- /kraken/binarization.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2015 Benjamin Kiessling 3 | # 2014 Thomas M. Breuel 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 14 | # or implied. See the License for the specific language governing 15 | # permissions and limitations under the License. 16 | """ 17 | kraken.binarization 18 | ~~~~~~~~~~~~~~~~~~~ 19 | 20 | An adaptive binarization algorithm. 21 | """ 22 | import logging 23 | import warnings 24 | from typing import TYPE_CHECKING 25 | 26 | import numpy as np 27 | from scipy.ndimage import (affine_transform, binary_dilation, gaussian_filter, 28 | percentile_filter) 29 | from scipy.ndimage import zoom as _zoom 30 | 31 | from kraken.lib.exceptions import KrakenInputException 32 | from kraken.lib.util import array2pil, get_im_str, is_bitonal, pil2array 33 | 34 | if TYPE_CHECKING: 35 | from PIL import Image 36 | 37 | 38 | __all__ = ['nlbin'] 39 | 40 | logger = logging.getLogger(__name__) 41 | 42 | 43 | def nlbin(im: 'Image.Image', 44 | threshold: float = 0.5, 45 | zoom: float = 0.5, 46 | escale: float = 1.0, 47 | border: float = 0.1, 48 | perc: int = 80, 49 | range: int = 20, 50 | low: int = 5, 51 | high: int = 90) -> 'Image.Image': 52 | """ 53 | Performs binarization using non-linear processing. 54 | 55 | Args: 56 | im: Input image 57 | threshold: 58 | zoom: Zoom for background page estimation 59 | escale: Scale for estimating a mask over the text region 60 | border: Ignore this much of the border 61 | perc: Percentage for filters 62 | range: Range for filters 63 | low: Percentile for black estimation 64 | high: Percentile for white estimation 65 | 66 | Returns: 67 | PIL.Image.Image containing the binarized image 68 | 69 | Raises: 70 | KrakenInputException: When trying to binarize an empty image. 
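    Example (a sketch; the image path is a placeholder):

        >>> from PIL import Image
        >>> from kraken.binarization import nlbin
        >>> bw_im = nlbin(Image.open('page.jpg'))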
71 | """ 72 | im_str = get_im_str(im) 73 | logger.info(f'Binarizing {im_str}') 74 | if is_bitonal(im): 75 | logger.info(f'Skipping binarization because {im_str} is bitonal.') 76 | return im 77 | # convert to grayscale first 78 | logger.debug(f'Converting {im_str} to grayscale') 79 | im = im.convert('L') 80 | raw = pil2array(im) 81 | logger.debug('Scaling and normalizing') 82 | # rescale image to between -1 or 0 and 1 83 | raw = raw/float(np.iinfo(raw.dtype).max) 84 | # perform image normalization 85 | if np.amax(raw) == np.amin(raw): 86 | logger.warning(f'Trying to binarize empty image {im_str}') 87 | raise KrakenInputException('Image is empty') 88 | image = raw-np.amin(raw) 89 | image /= np.amax(image) 90 | 91 | logger.debug('Interpolation and percentile filtering') 92 | with warnings.catch_warnings(): 93 | warnings.simplefilter('ignore', UserWarning) 94 | m = _zoom(image, zoom) 95 | m = percentile_filter(m, perc, size=(range, 2)) 96 | m = percentile_filter(m, perc, size=(2, range)) 97 | mh, mw = m.shape 98 | oh, ow = image.shape 99 | scale = np.diag([mh * 1.0/oh, mw * 1.0/ow]) 100 | m = affine_transform(m, scale, output_shape=image.shape) 101 | w, h = np.minimum(np.array(image.shape), np.array(m.shape)) 102 | flat = np.clip(image[:w, :h]-m[:w, :h]+1, 0, 1) 103 | 104 | # estimate low and high thresholds 105 | d0, d1 = flat.shape 106 | o0, o1 = int(border*d0), int(border*d1) 107 | est = flat[o0:d0-o0, o1:d1-o1] 108 | logger.debug('Threshold estimates {}'.format(est)) 109 | # by default, we use only regions that contain 110 | # significant variance; this makes the percentile 111 | # based low and high estimates more reliable 112 | logger.debug('Refine estimates') 113 | v = est-gaussian_filter(est, escale*20.0) 114 | v = gaussian_filter(v**2, escale*20.0)**0.5 115 | v = (v > 0.3*np.amax(v)) 116 | v = binary_dilation(v, structure=np.ones((int(escale * 50), 1))) 117 | v = binary_dilation(v, structure=np.ones((1, int(escale * 50)))) 118 | est = est[v] 119 | lo = np.percentile(est.ravel(), low) 120 | hi = np.percentile(est.ravel(), high) 121 | flat -= lo 122 | flat /= (hi-lo) 123 | flat = np.clip(flat, 0, 1) 124 | logger.debug(f'Thresholding at {threshold}') 125 | bin = np.array(255*(flat > threshold), 'B') 126 | return array2pil(bin) 127 | -------------------------------------------------------------------------------- /kraken/blla.mlmodel: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/kraken/blla.mlmodel -------------------------------------------------------------------------------- /kraken/contrib/add_neural_ro.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Computes an additional reading order from a neural RO model and adds it to an 4 | ALTO document. 5 | """ 6 | import click 7 | 8 | @click.command() 9 | @click.option('-f', '--format-type', type=click.Choice(['alto']), default='alto', 10 | help='Sets the input document format. In ALTO and PageXML mode all ' 11 | 'data is extracted from xml files containing both baselines, polygons, and a ' 12 | 'link to source images.') 13 | @click.option('-i', '--model', default=None, show_default=True, type=click.Path(exists=True), 14 | help='Baseline detection model to use. 
15 | @click.argument('files', nargs=-1)
16 | def cli(format_type, model, files):
17 |     """
18 |     A script adding new neural reading orders to the input documents.
19 |     """
20 |     if len(files) == 0:
21 |         ctx = click.get_current_context()
22 |         click.echo(ctx.get_help())
23 |         ctx.exit()
24 | 
25 |     import uuid
26 | 
27 |     from kraken import blla
28 |     from kraken.lib import segmentation, vgsl, xml
29 | 
30 |     from lxml import etree
31 |     from dataclasses import asdict
32 | 
33 |     try:
34 |         net = vgsl.TorchVGSLModel.load_model(model)
35 |         ro_class_mapping = net.user_metadata['ro_class_mapping']
36 |         ro_model = net.aux_layers['ro_model']
37 |     except Exception:
38 |         from kraken.lib.ro import ROModel
39 |         net = ROModel.load_from_checkpoint(model)
40 |         ro_class_mapping = net.class_mapping
41 |         ro_model = net.ro_net
42 | 
43 |     for doc in files:
44 |         click.echo(f'Processing {doc} ', nl=False)
45 |         doc = xml.XMLPage(doc)
46 |         if doc.filetype != 'alto':
47 |             click.echo('Not an ALTO file. Skipping.')
48 |             continue
49 |         seg = doc.to_container()
50 |         lines = list(map(asdict, seg.lines))
51 |         _order = segmentation.neural_reading_order(lines=lines,
52 |                                                    regions=seg.regions,
53 |                                                    model=ro_model,
54 |                                                    im_size=doc.image_size[::-1],
55 |                                                    class_mapping=ro_class_mapping)
56 |         # reorder
57 |         lines = [lines[idx] for idx in _order]
58 |         # add ReadingOrder block to ALTO
59 |         tree = etree.parse(doc.filename)
60 |         alto = tree.getroot()
61 |         if alto.find('./{*}ReadingOrder'):
62 |             click.echo('Addition to files with explicit reading order not yet supported. Skipping.')
63 |             continue
64 |         ro = etree.Element('ReadingOrder')
65 |         og = etree.SubElement(ro, 'OrderedGroup')
66 |         og.set('ID', f'_{uuid.uuid4()}')
67 |         for line in lines:
68 |             el = etree.SubElement(og, 'ElementRef')
69 |             el.set('ID', f'_{uuid.uuid4()}')
70 |             el.set('REF', f'{line["id"]}')
71 |         tree.find('.//{*}Layout').addprevious(ro)
72 |         with open(doc.filename.with_suffix('.ro.xml'), 'wb') as fo:
73 |             fo.write(etree.tostring(tree, encoding='UTF-8', xml_declaration=True, pretty_print=True))
74 |         click.secho('\u2713', fg='green')
75 | 
76 | 
77 | if __name__ == '__main__':
78 |     cli()
79 | 
--------------------------------------------------------------------------------
/kraken/contrib/baselineset_overlay.py:
--------------------------------------------------------------------------------
1 | #!
/usr/bin/env python 2 | """ 3 | Produces semi-transparent neural segmenter output overlays 4 | """ 5 | import click 6 | 7 | 8 | @click.command() 9 | @click.argument('files', nargs=-1) 10 | def cli(files): 11 | 12 | from os.path import splitext 13 | 14 | import torch 15 | import torchvision.transforms as tf 16 | from PIL import Image 17 | 18 | from kraken.lib import dataset 19 | 20 | batch, channels, height, width = 1, 3, 1200, 0 21 | transforms = dataset.ImageInputTransforms(batch, height, width, channels, 0, valid_norm=False) 22 | 23 | torch.set_num_threads(1) 24 | 25 | ds = dataset.BaselineSet(files, im_transforms=transforms, mode='xml') 26 | 27 | for idx, batch in enumerate(ds): 28 | img = ds.imgs[idx] 29 | print(img) 30 | im = Image.open(img) 31 | res_tf = tf.Compose(transforms.transforms[:2]) 32 | scal_im = res_tf(im) 33 | o = batch['target'].numpy() 34 | heat = Image.fromarray((o[ds.class_mapping['baselines']['default']]*255).astype('uint8')) 35 | heat.save(splitext(img)[0] + '.heat.png') 36 | overlay = Image.new('RGBA', scal_im.size, (0, 130, 200, 255)) 37 | bl = Image.composite(overlay, scal_im.convert('RGBA'), heat) 38 | heat = Image.fromarray((o[ds.class_mapping['aux']['_start_separator']]*255).astype('uint8')) 39 | overlay = Image.new('RGBA', scal_im.size, (230, 25, 75, 255)) 40 | bl = Image.composite(overlay, bl, heat) 41 | heat = Image.fromarray((o[ds.class_mapping['aux']['_end_separator']]*255).astype('uint8')) 42 | overlay = Image.new('RGBA', scal_im.size, (60, 180, 75, 255)) 43 | Image.composite(overlay, bl, heat).save(splitext(img)[0] + '.overlay.png') 44 | del o 45 | del im 46 | 47 | 48 | if __name__ == '__main__': 49 | cli() 50 | -------------------------------------------------------------------------------- /kraken/contrib/extract_lines.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | import click 3 | 4 | 5 | @click.command() 6 | @click.option('-f', '--format-type', type=click.Choice(['xml', 'alto', 'page', 'binary']), default='xml', 7 | help='Sets the input document format. In ALTO and PageXML mode all ' 8 | 'data is extracted from xml files containing both baselines, polygons, and a ' 9 | 'link to source images.') 10 | @click.option('-i', '--model', default=None, show_default=True, type=click.Path(exists=True), 11 | help='Baseline detection model to use. Overrides format type and expects image files as input.') 12 | @click.option('--legacy-polygons', is_flag=True, help='Use the legacy polygon extractor.') 13 | @click.argument('files', nargs=-1) 14 | def cli(format_type, model, legacy_polygons, files): 15 | """ 16 | A small script extracting rectified line polygons as defined in either ALTO or 17 | PageXML files or run a model to do the same. 
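    Example invocations (a sketch; file and model names are placeholders):

        extract_lines.py -f alto page.xml
        extract_lines.py -i segmodel.mlmodel page.jpg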
18 | """ 19 | if len(files) == 0: 20 | ctx = click.get_current_context() 21 | click.echo(ctx.get_help()) 22 | ctx.exit() 23 | 24 | import io 25 | import json 26 | from os.path import splitext 27 | 28 | import pyarrow as pa 29 | from PIL import Image 30 | 31 | from kraken import blla 32 | from kraken.lib import segmentation, vgsl, xml 33 | 34 | if model is None: 35 | for doc in files: 36 | click.echo(f'Processing {doc} ', nl=False) 37 | if format_type != 'binary': 38 | data = xml.XMLPage(doc, format_type) 39 | if len(data.lines) > 0: 40 | bounds = data.to_container() 41 | for idx, (im, box) in enumerate(segmentation.extract_polygons(Image.open(bounds.imagename), bounds, legacy=legacy_polygons)): 42 | click.echo('.', nl=False) 43 | im.save('{}.{}.jpg'.format(splitext(bounds.imagename)[0], idx)) 44 | with open('{}.{}.gt.txt'.format(splitext(bounds.imagename)[0], idx), 'w') as fp: 45 | fp.write(box.text) 46 | else: 47 | with pa.memory_map(doc, 'rb') as source: 48 | ds_table = pa.ipc.open_file(source).read_all() 49 | raw_metadata = ds_table.schema.metadata 50 | if not raw_metadata or b'lines' not in raw_metadata: 51 | raise ValueError(f'{doc} does not contain a valid metadata record.') 52 | metadata = json.loads(raw_metadata[b'lines']) 53 | for idx in range(metadata['counts']['all']): 54 | sample = ds_table.column('lines')[idx].as_py() 55 | im = Image.open(io.BytesIO(sample['im'])) 56 | im.save('{}.{}.jpg'.format(splitext(doc)[0], idx)) 57 | with open('{}.{}.gt.txt'.format(splitext(doc)[0], idx), 'w') as fp: 58 | fp.write(sample['text']) 59 | click.echo() 60 | else: 61 | net = vgsl.TorchVGSLModel.load_model(model) 62 | for doc in files: 63 | click.echo(f'Processing {doc} ', nl=False) 64 | full_im = Image.open(doc) 65 | bounds = blla.segment(full_im, model=net) 66 | for idx, (im, box) in enumerate(segmentation.extract_polygons(full_im, bounds, legacy=legacy_polygons)): 67 | click.echo('.', nl=False) 68 | im.save('{}.{}.jpg'.format(splitext(doc)[0], idx)) 69 | click.echo() 70 | 71 | 72 | if __name__ == '__main__': 73 | cli() 74 | -------------------------------------------------------------------------------- /kraken/contrib/forced_alignment_overlay.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Draws a transparent overlay of the forced alignment output over the input 4 | image. 5 | """ 6 | import os 7 | import re 8 | import unicodedata 9 | from itertools import cycle 10 | from unicodedata import normalize 11 | 12 | import click 13 | from lxml import etree 14 | 15 | cmap = cycle([(230, 25, 75, 127), 16 | (60, 180, 75, 127), 17 | (255, 225, 25, 127), 18 | (0, 130, 200, 127), 19 | (245, 130, 48, 127), 20 | (145, 30, 180, 127), 21 | (70, 240, 240, 127)]) 22 | 23 | 24 | def slugify(value): 25 | """ 26 | Normalizes string, converts to lowercase, removes non-alpha characters, 27 | and converts spaces to hyphens. 
28 | """ 29 | value = unicodedata.normalize('NFKD', value) 30 | value = re.sub(r'[^\w\s-]', '', value).strip().lower() 31 | value = re.sub(r'[-\s]+', '-', value) 32 | return value 33 | 34 | 35 | def _repl_alto(fname, cuts): 36 | with open(fname, 'rb') as fp: 37 | doc = etree.parse(fp) 38 | lines = doc.findall('.//{*}TextLine') 39 | char_idx = 0 40 | for line, line_cuts in zip(lines, cuts.lines): 41 | idx = 0 42 | for el in line: 43 | if el.tag.endswith('Shape'): 44 | continue 45 | elif el.tag.endswith('SP'): 46 | idx += 1 47 | elif el.tag.endswith('String'): 48 | str_len = len(el.get('CONTENT')) 49 | # clear out all 50 | for chld in el: 51 | if chld.tag.endswith('Glyph'): 52 | el.remove(chld) 53 | for char in zip(line_cuts.prediction[idx:str_len], 54 | line_cuts.cuts[idx:str_len], 55 | line_cuts.confidences[idx:str_len]): 56 | glyph = etree.SubElement(el, 'Glyph') 57 | glyph.set('ID', f'char_{char_idx}') 58 | char_idx += 1 59 | glyph.set('CONTENT', char[0]) 60 | glyph.set('GC', str(char[2])) 61 | pol = etree.SubElement(etree.SubElement(glyph, 'Shape'), 'Polygon') 62 | pol.set('POINTS', ' '.join([str(coord) for pt in char[1] for coord in pt])) 63 | idx += str_len 64 | with open(f'{os.path.basename(fname)}_algn.xml', 'wb') as fp: 65 | doc.write(fp, encoding='UTF-8', xml_declaration=True) 66 | 67 | 68 | def _repl_page(fname, cuts): 69 | with open(fname, 'rb') as fp: 70 | doc = etree.parse(fp) 71 | lines = doc.findall('.//{*}TextLine') 72 | for line, line_cuts in zip(lines, cuts.lines): 73 | glyphs = line.findall('../{*}Glyph/{*}Coords') 74 | for glyph, cut in zip(glyphs, line_cuts): 75 | glyph.attrib['points'] = ' '.join([','.join([str(x) for x in pt]) for pt in cut]) 76 | with open(f'{os.path.basename(fname)}_algn.xml', 'wb') as fp: 77 | doc.write(fp, encoding='UTF-8', xml_declaration=True) 78 | 79 | 80 | @click.command() 81 | @click.option('-f', '--format-type', type=click.Choice(['alto', 'page']), default='page', 82 | help='Sets the input document format. In ALTO and PageXML mode all ' 83 | 'data is extracted from xml files containing both baselines, polygons, and a ' 84 | 'link to source images.') 85 | @click.option('-i', '--model', default=None, show_default=True, type=click.Path(exists=True), 86 | help='Transcription model to use.') 87 | @click.option('-u', '--normalization', show_default=True, type=click.Choice(['NFD', 'NFKD', 'NFC', 'NFKC']), 88 | default=None, 89 | help='Ground truth normalization') 90 | @click.option('-o', '--output', type=click.Choice(['xml', 'overlay']), 91 | show_default=True, default='overlay', help='Output mode. Either page or ' 92 | 'alto for xml output, overlay for image overlays.') 93 | @click.argument('files', nargs=-1) 94 | def cli(format_type, model, normalization, output, files): 95 | """ 96 | A script producing overlays of lines and regions from either ALTO or 97 | PageXML files or run a model to do the same. 
98 | """ 99 | if len(files) == 0: 100 | ctx = click.get_current_context() 101 | click.echo(ctx.get_help()) 102 | ctx.exit() 103 | 104 | from PIL import Image, ImageDraw 105 | 106 | from kraken import align 107 | from kraken.lib import models 108 | from kraken.lib.xml import XMLPage 109 | 110 | if format_type == 'alto': 111 | repl_fn = _repl_alto 112 | else: 113 | repl_fn = _repl_page 114 | 115 | click.echo(f'Loading model {model}') 116 | net = models.load_any(model) 117 | 118 | for doc in files: 119 | click.echo(f'Processing {doc} ', nl=False) 120 | data = XMLPage(doc) 121 | im = Image.open(data.imagename).convert('RGBA') 122 | result = align.forced_align(data.to_container(), net) 123 | if normalization: 124 | for line in data._lines: 125 | line["text"] = normalize(normalization, line["text"]) 126 | im = Image.open(data.imagename).convert('RGBA') 127 | result = align.forced_align(data.to_container(), net) 128 | if output == 'overlay': 129 | tmp = Image.new('RGBA', im.size, (0, 0, 0, 0)) 130 | draw = ImageDraw.Draw(tmp) 131 | for record in result.lines: 132 | for pol in record.cuts: 133 | c = next(cmap) 134 | draw.polygon([tuple(x) for x in pol], fill=c, outline=c[:3]) 135 | base_image = Image.alpha_composite(im, tmp) 136 | base_image.save(f'high_{os.path.basename(doc)}_algn.png') 137 | else: 138 | repl_fn(doc, result) 139 | click.secho('\u2713', fg='green') 140 | 141 | 142 | if __name__ == '__main__': 143 | cli() 144 | -------------------------------------------------------------------------------- /kraken/contrib/generate_scripts.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Script fetching the latest unicode Scripts.txt and dumping it as json. 4 | """ 5 | import json 6 | from urllib import request 7 | 8 | import regex 9 | 10 | uri = 'http://www.unicode.org/Public/UNIDATA/Scripts.txt' 11 | 12 | re = regex.compile(r'^(?P[0-9A-F]{4,6})(..(?P[0-9A-F]{4,6}))?\s+; (?P[A-Za-z]+)') 13 | 14 | with open('scripts.json', 'w') as fp, request.urlopen(uri) as req: 15 | d = [] 16 | for line in req: 17 | line = line.decode('utf-8') 18 | if line.startswith('#') or line.strip() == '': 19 | continue 20 | m = re.match(line) 21 | if m: 22 | print(line) 23 | start = int(m.group('start'), base=16) 24 | end = start 25 | if m.group('end'): 26 | end = int(m.group('end'), base=16) 27 | name = m.group('name') 28 | if len(d) > 0 and d[-1][2] == name and (start - 1 == d[-1][1] or start - 1 == d[-1][0]): 29 | print('merging {} and ({}, {}, {})'.format(d[-1], start, end, name)) 30 | d[-1] = (d[-1][0], end, name) 31 | else: 32 | d.append((start, end if end != start else None, name)) 33 | json.dump(d, fp) 34 | -------------------------------------------------------------------------------- /kraken/contrib/heatmap_overlay.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | """ 3 | Produces semi-transparent neural segmenter output overlays 4 | """ 5 | import click 6 | 7 | 8 | @click.command() 9 | @click.option('-i', '--model', default=None, show_default=True, type=click.Path(exists=True), 10 | help='Baseline detection model to use.') 11 | @click.argument('files', nargs=-1) 12 | def cli(model, files): 13 | """ 14 | Applies a BLLA baseline segmentation model and outputs the raw heatmaps of 15 | the first baseline class. 
16 | """ 17 | from os.path import splitext 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | import torchvision.transforms as tf 22 | from PIL import Image 23 | 24 | from kraken.lib import dataset, vgsl 25 | 26 | model = vgsl.TorchVGSLModel.load_model(model) 27 | model.eval() 28 | batch, channels, height, width = model.input 29 | 30 | transforms = dataset.ImageInputTransforms(batch, height, width, channels, 0, valid_norm=False) 31 | 32 | torch.set_num_threads(1) 33 | 34 | for img in files: 35 | print(img) 36 | im = Image.open(img) 37 | xs = transforms(im) 38 | 39 | with torch.no_grad(): 40 | o, _ = model.nn(xs.unsqueeze(0)) 41 | o = F.interpolate(o, size=xs.shape[1:]) 42 | o = o.squeeze().numpy() 43 | 44 | scal_im = tf.ToPILImage()(1-xs) 45 | heat = Image.fromarray((o[2]*255).astype('uint8')) 46 | heat.save(splitext(img)[0] + '.heat.png') 47 | overlay = Image.new('RGBA', scal_im.size, (0, 130, 200, 255)) 48 | bl = Image.composite(overlay, scal_im.convert('RGBA'), heat) 49 | heat = Image.fromarray((o[1]*255).astype('uint8')) 50 | overlay = Image.new('RGBA', scal_im.size, (230, 25, 75, 255)) 51 | bl = Image.composite(overlay, bl, heat) 52 | heat = Image.fromarray((o[0]*255).astype('uint8')) 53 | overlay = Image.new('RGBA', scal_im.size, (60, 180, 75, 255)) 54 | Image.composite(overlay, bl, heat).save(splitext(img)[0] + '.overlay.png') 55 | del o 56 | del im 57 | 58 | 59 | if __name__ == '__main__': 60 | cli() 61 | -------------------------------------------------------------------------------- /kraken/contrib/new_classes.json: -------------------------------------------------------------------------------- 1 | {"baselines": {"defaultLine": 2}, "regions": {"foo": 3}} 2 | -------------------------------------------------------------------------------- /kraken/contrib/recognition_boxes.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Draws transparent character bounding boxes over images giving a legacy 4 | segmenter model. 5 | """ 6 | 7 | import os 8 | import sys 9 | from itertools import cycle 10 | 11 | from PIL import Image, ImageDraw 12 | 13 | from kraken.binarization import nlbin 14 | from kraken.lib import models 15 | from kraken.pageseg import segment 16 | from kraken.rpred import rpred 17 | 18 | cmap = cycle([(230, 25, 75, 127), 19 | (60, 180, 75, 127), 20 | (255, 225, 25, 127), 21 | (0, 130, 200, 127), 22 | (245, 130, 48, 127), 23 | (145, 30, 180, 127), 24 | (70, 240, 240, 127)]) 25 | 26 | net = models.load_any(sys.argv[1]) 27 | 28 | for fname in sys.argv[2:]: 29 | im = Image.open(fname) 30 | print(fname) 31 | im = nlbin(im) 32 | res = segment(im, maxcolseps=0) 33 | pred = rpred(net, im, res) 34 | im = im.convert('RGBA') 35 | tmp = Image.new('RGBA', im.size, (0, 0, 0, 0)) 36 | draw = ImageDraw.Draw(tmp) 37 | for line in pred: 38 | for box in line.cuts: 39 | draw.rectangle(box, fill=next(cmap)) 40 | im = Image.alpha_composite(im, tmp) 41 | im.save('high_{}'.format(os.path.basename(fname))) 42 | -------------------------------------------------------------------------------- /kraken/contrib/repolygonize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Reads in a bunch of ALTO documents and repolygonizes the lines contained with 4 | the kraken polygonizer. 
5 | """ 6 | import click 7 | 8 | 9 | @click.command() 10 | @click.option('-f', '--format-type', type=click.Choice(['alto', 'page']), default='alto', 11 | help='Sets the input document format. In ALTO and PageXML mode all ' 12 | 'data is extracted from xml files containing both baselines, polygons, and a ' 13 | 'link to source images.') 14 | @click.option('-tl', '--topline', 'topline', show_default=True, flag_value='topline', 15 | help='Switch for the baseline location in the scripts. ' 16 | 'Set to topline if the data is annotated with a hanging baseline, as is ' 17 | 'common with Hebrew, Bengali, Devanagari, etc. Set to ' 18 | ' centerline for scripts annotated with a central line.') 19 | @click.option('-cl', '--centerline', 'topline', flag_value='centerline') 20 | @click.option('-bl', '--baseline', 'topline', flag_value='baseline', default='baseline') 21 | @click.option('--scale', show_default=True, type=click.INT, default=1800, help='A integer height ' 22 | 'containing optional scale factors of the input. Default 1800.') 23 | @click.argument('files', nargs=-1) 24 | def cli(format_type, topline, scale, files): 25 | """ 26 | A small script repolygonizing line boundaries in ALTO or PageXML files. 27 | """ 28 | if len(files) == 0: 29 | ctx = click.get_current_context() 30 | click.echo(ctx.get_help()) 31 | ctx.exit() 32 | 33 | from itertools import groupby 34 | from os.path import splitext 35 | 36 | from lxml import etree 37 | from PIL import Image 38 | 39 | from kraken.lib import xml 40 | from kraken.lib.segmentation import calculate_polygonal_environment 41 | 42 | def _repl_alto(fname, polygons): 43 | with open(fname, 'rb') as fp: 44 | doc = etree.parse(fp) 45 | lines = doc.findall('.//{*}TextLine') 46 | idx = 0 47 | for line in lines: 48 | if line.get('BASELINE') is None: 49 | continue 50 | pol = line.find('./{*}Shape/{*}Polygon') 51 | if pol is not None: 52 | if polygons[idx] is not None: 53 | pol.attrib['POINTS'] = ' '.join([str(coord) for pt in polygons[idx] for coord in pt]) 54 | else: 55 | pol.attrib['POINTS'] = '' 56 | idx += 1 57 | with open(splitext(fname)[0] + '_rewrite.xml', 'wb') as fp: 58 | doc.write(fp, encoding='UTF-8', xml_declaration=True) 59 | 60 | def _parse_page_coords(coords): 61 | points = [x for x in coords.split(' ')] 62 | points = [int(c) for point in points for c in point.split(',')] 63 | pts = zip(points[::2], points[1::2]) 64 | return [k for k, g in groupby(pts)] 65 | 66 | def _repl_page(fname, polygons): 67 | with open(fname, 'rb') as fp: 68 | doc = etree.parse(fp) 69 | lines = doc.findall('.//{*}TextLine') 70 | idx = 0 71 | for line in lines: 72 | base = line.find('./{*}Baseline') 73 | if base is not None and not base.get('points').isspace() and len(base.get('points')): 74 | try: 75 | _parse_page_coords(base.get('points')) 76 | except Exception: 77 | continue 78 | else: 79 | continue 80 | pol = line.find('./{*}Coords') 81 | if pol is not None: 82 | if polygons[idx] is not None: 83 | pol.attrib['points'] = ' '.join([','.join([str(x) for x in pt]) for pt in polygons[idx]]) 84 | else: 85 | pol.attrib['points'] = '' 86 | idx += 1 87 | with open(splitext(fname)[0] + '_rewrite.xml', 'wb') as fp: 88 | doc.write(fp, encoding='UTF-8', xml_declaration=True) 89 | 90 | if format_type == 'page': 91 | repl_fn = _repl_page 92 | else: 93 | repl_fn = _repl_alto 94 | 95 | topline = {'topline': True, 96 | 'baseline': False, 97 | 'centerline': None}[topline] 98 | 99 | for doc in files: 100 | click.echo(f'Processing {doc} ') 101 | seg = xml.XMLPage(doc).to_container() 102 | im = 
-------------------------------------------------------------------------------- /kraken/contrib/segmentation_overlay.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | Draws a transparent overlay of baseline segmenter output over a list of image
4 | files.
5 | """
6 | import dataclasses
7 | import os
8 | import re
9 | import unicodedata
10 | from collections import defaultdict
11 | from itertools import cycle
12 | 
13 | import click
14 | 
15 | cmap = cycle([(230, 25, 75, 127),
16 | (60, 180, 75, 127)])
17 | 
18 | bmap = (0, 130, 200, 255)
19 | 
20 | 
21 | def slugify(value):
22 | """
23 | Normalizes string, converts to lowercase, removes non-alpha characters,
24 | and converts spaces to hyphens.
25 | """
26 | value = unicodedata.normalize('NFKD', value)
27 | value = re.sub(r'[^\w\s-]', '', value).strip().lower()
28 | value = re.sub(r'[-\s]+', '-', value)
29 | return value
30 | 
31 | 
32 | @click.command()
33 | @click.option('-i', '--model', default=None, show_default=True, type=click.Path(exists=True),
34 | help='Baseline detection model to use. Overrides format type and expects image files as input.')
35 | @click.option('-d', '--text-direction', default='horizontal-lr',
36 | show_default=True,
37 | type=click.Choice(['horizontal-lr', 'horizontal-rl',
38 | 'vertical-lr', 'vertical-rl']),
39 | help='Sets principal text direction')
40 | @click.option('--repolygonize/--no-repolygonize', show_default=True,
41 | default=False, help='Repolygonizes line data in ALTO/PageXML '
42 | 'files. This ensures that the trained model is compatible with the '
43 | 'segmenter in kraken even if the original XML files either do '
44 | 'not contain anything but transcriptions and baseline information '
45 | 'or the polygon data was created using a different method. Will '
46 | 'be ignored in `path` mode. Note that this option will be slow '
47 | 'and will not scale input images to the same size as the segmenter '
48 | 'does.')
49 | @click.option('-tl', '--topline', 'topline', show_default=True, flag_value='topline',
50 | help='Switch for the baseline location in the scripts. ')
51 | @click.option('-cl', '--centerline', 'topline', flag_value='centerline')
52 | @click.option('-bl', '--baseline', 'topline', flag_value='baseline', default='baseline')
53 | @click.option('--height-scale', default=1800, show_default=True,
54 | help='Maximum height of the input image during repolygonization')
55 | @click.argument('files', nargs=-1)
56 | def cli(model, text_direction, repolygonize, topline, height_scale, files):
57 | """
58 | A script producing overlays of lines and regions from either ALTO or
59 | PageXML files, or running a segmentation model to do the same.
60 | """ 61 | if len(files) == 0: 62 | ctx = click.get_current_context() 63 | click.echo(ctx.get_help()) 64 | ctx.exit() 65 | 66 | from PIL import Image, ImageDraw 67 | 68 | from kraken import blla 69 | from kraken.lib import segmentation, vgsl, xml 70 | 71 | loc = {'topline': True, 72 | 'baseline': False, 73 | 'centerline': None} 74 | 75 | topline = loc[topline] 76 | 77 | if model is None: 78 | for doc in files: 79 | click.echo(f'Processing {doc} ', nl=False) 80 | data = xml.XMLPage(doc).to_container() 81 | if repolygonize: 82 | im = Image.open(data.imagename).convert('L') 83 | lines = data.lines 84 | polygons = segmentation.calculate_polygonal_environment(im, 85 | [x.baseline for x in lines], 86 | scale=(height_scale, 0), 87 | topline=topline) 88 | data.lines = [dataclasses.replace(orig, boundary=polygon) for orig, polygon in zip(lines, polygons)] 89 | # reorder lines by type 90 | lines = defaultdict(list) 91 | for line in data.lines: 92 | lines[line.tags['type']].append(line) 93 | im = Image.open(data.imagename).convert('RGBA') 94 | for t, ls in lines.items(): 95 | tmp = Image.new('RGBA', im.size, (0, 0, 0, 0)) 96 | draw = ImageDraw.Draw(tmp) 97 | for idx, line in enumerate(ls): 98 | c = next(cmap) 99 | if line.boundary: 100 | draw.polygon([tuple(x) for x in line.boundary], fill=c, outline=c[:3]) 101 | if line.baseline: 102 | draw.line([tuple(x) for x in line.baseline], fill=bmap, width=2, joint='curve') 103 | draw.text(line.baseline[0], str(idx), fill=(0, 0, 0, 255)) 104 | base_image = Image.alpha_composite(im, tmp) 105 | base_image.save(f'high_{os.path.basename(doc)}_lines_{slugify(t)}.png') 106 | for t, regs in data.regions.items(): 107 | tmp = Image.new('RGBA', im.size, (0, 0, 0, 0)) 108 | draw = ImageDraw.Draw(tmp) 109 | for reg in regs: 110 | c = next(cmap) 111 | try: 112 | draw.polygon(reg.boundary, fill=c, outline=c[:3]) 113 | except Exception: 114 | pass 115 | base_image = Image.alpha_composite(im, tmp) 116 | base_image.save(f'high_{os.path.basename(doc)}_regions_{slugify(t)}.png') 117 | click.secho('\u2713', fg='green') 118 | else: 119 | net = vgsl.TorchVGSLModel.load_model(model) 120 | for doc in files: 121 | click.echo(f'Processing {doc} ', nl=False) 122 | im = Image.open(doc) 123 | res = blla.segment(im, model=net, text_direction=text_direction) 124 | # reorder lines by type 125 | lines = defaultdict(list) 126 | for line in res.lines: 127 | lines[line.tags['type']].append(line) 128 | im = im.convert('RGBA') 129 | for t, ls in lines.items(): 130 | tmp = Image.new('RGBA', im.size, (0, 0, 0, 0)) 131 | draw = ImageDraw.Draw(tmp) 132 | for idx, line in enumerate(ls): 133 | c = next(cmap) 134 | draw.polygon([tuple(x) for x in line.boundary], fill=c, outline=c[:3]) 135 | draw.line([tuple(x) for x in line.baseline], fill=bmap, width=2, joint='curve') 136 | draw.text(line.baseline[0], str(idx), fill=(0, 0, 0, 255)) 137 | base_image = Image.alpha_composite(im, tmp) 138 | base_image.save(f'high_{os.path.basename(doc)}_lines_{slugify(t)}.png') 139 | for t, regs in res.regions.items(): 140 | tmp = Image.new('RGBA', im.size, (0, 0, 0, 0)) 141 | draw = ImageDraw.Draw(tmp) 142 | for reg in regs: 143 | c = next(cmap) 144 | try: 145 | draw.polygon([tuple(x) for x in reg.boundary], fill=c, outline=c[:3]) 146 | except Exception: 147 | pass 148 | 149 | base_image = Image.alpha_composite(im, tmp) 150 | base_image.save(f'high_{os.path.basename(doc)}_regions_{slugify(t)}.png') 151 | click.secho('\u2713', fg='green') 152 | 153 | 154 | if __name__ == '__main__': 155 | cli() 156 | 
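The same overlays can be produced without the script when experimenting interactively. A minimal sketch of the model branch above, drawing a single combined overlay instead of one image per line type; `seg.mlmodel` and `page.png` are placeholder paths:

```python
from PIL import Image, ImageDraw

from kraken import blla
from kraken.lib import vgsl

net = vgsl.TorchVGSLModel.load_model('seg.mlmodel')  # placeholder model path
im = Image.open('page.png').convert('RGBA')          # placeholder image path
res = blla.segment(im, model=net)

# draw all line boundaries and baselines onto one transparent layer
tmp = Image.new('RGBA', im.size, (0, 0, 0, 0))
draw = ImageDraw.Draw(tmp)
for line in res.lines:
    draw.polygon([tuple(pt) for pt in line.boundary], fill=(230, 25, 75, 127))
    draw.line([tuple(pt) for pt in line.baseline], fill=(0, 130, 200, 255), width=2)
Image.alpha_composite(im, tmp).save('overlay.png')
```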
-------------------------------------------------------------------------------- /kraken/contrib/set_seg_options.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | A script setting the metadata of segmentation models.
4 | """
5 | import shutil
6 | 
7 | import click
8 | 
9 | 
10 | @click.command()
11 | @click.option('-b', '--bounding-region', multiple=True, help='Sets region identifiers which bound line bounding polygons')
12 | @click.option('--topline', 'topline',
13 | help='Sets model metadata baseline location to either `--topline`, `--centerline`, or `--baseline`',
14 | flag_value='topline',
15 | show_default=True)
16 | @click.option('--centerline', 'topline', flag_value='centerline')
17 | @click.option('--baseline', 'topline', flag_value='baseline', default='baseline')
18 | @click.option('--pad', show_default=True, type=(int, int), default=(0, 0),
19 | help='Padding (left/right, top/bottom) around the page image')
20 | @click.option('--output-identifiers', type=click.Path(exists=True), help='Path '
21 | 'to a json file containing a dict updating the string identifiers '
22 | 'of line/region classes.')
23 | @click.argument('model', nargs=1, type=click.Path(exists=True))
24 | def cli(bounding_region, topline, pad, output_identifiers, model):
25 | """
26 | A script setting the metadata of segmentation models.
27 | """
28 | import json
29 | from kraken.lib import vgsl
30 | 
31 | net = vgsl.TorchVGSLModel.load_model(model)
32 | if net.model_type != 'segmentation':
33 | print('Model is not a segmentation model.')
34 | return
35 | 
36 | print('detectable line types:')
37 | for k, v in net.user_metadata['class_mapping']['baselines'].items():
38 | print(f' {k}\t{v}')
39 | print('detectable region types:')
40 | for k, v in net.user_metadata['class_mapping']['regions'].items():
41 | print(f' {k}\t{v}')
42 | 
43 | if output_identifiers:
44 | with open(output_identifiers, 'r') as fp:
45 | new_cls_map = json.load(fp)
46 | print('-> Updating class maps')
47 | if 'baselines' in new_cls_map:
48 | print('new baseline identifiers:')
49 | old_cls = {v: k for k, v in net.user_metadata['class_mapping']['baselines'].items()}
50 | new_cls = {v: k for k, v in new_cls_map['baselines'].items()}
51 | old_cls.update(new_cls)
52 | net.user_metadata['class_mapping']['baselines'] = {v: k for k, v in old_cls.items()}
53 | for k, v in net.user_metadata['class_mapping']['baselines'].items():
54 | print(f' {k}\t{v}')
55 | if 'regions' in new_cls_map:
56 | print('new region identifiers:')
57 | old_cls = {v: k for k, v in net.user_metadata['class_mapping']['regions'].items()}
58 | new_cls = {v: k for k, v in new_cls_map['regions'].items()}
59 | old_cls.update(new_cls)
60 | net.user_metadata['class_mapping']['regions'] = {v: k for k, v in old_cls.items()}
61 | for k, v in net.user_metadata['class_mapping']['regions'].items():
62 | print(f' {k}\t{v}')
63 | 
64 | print(f'existing bounding regions: {net.user_metadata["bounding_regions"]}')
65 | 
66 | if bounding_region:
67 | br = set(net.user_metadata["bounding_regions"])
68 | br_new = set(bounding_region)
69 | 
70 | print(f'-> removing: {br.difference(br_new)}')
71 | print(f'-> adding: {br_new.difference(br)}')
72 | net.user_metadata["bounding_regions"] = bounding_region
73 | 
74 | loc = {'topline': True,
75 | 'baseline': False,
76 | 'centerline': None}
77 | 
78 | rloc = {True: 'topline',
79 | False: 'baseline',
80 | None: 'centerline'}
81 | line_loc = rloc[net.user_metadata.get('topline', False)]
82 | 
83 | print(f'Model is {line_loc}')
84
| print(f'-> Setting to {topline}') 85 | net.user_metadata['topline'] = loc[topline] 86 | 87 | print(f"Model has padding {net.user_metadata['hyper_params']['padding'] if 'padding' in net.user_metadata['hyper_params'] else (0, 0)}") 88 | print(f'-> Setting to {pad}') 89 | net.user_metadata['hyper_params']['padding'] = pad 90 | 91 | shutil.copy(model, f'{model}.bak') 92 | net.save_model(model) 93 | 94 | 95 | if __name__ == '__main__': 96 | cli() 97 | -------------------------------------------------------------------------------- /kraken/iso15924.json: -------------------------------------------------------------------------------- 1 | {"520": "Tang", "20": "Xsux", "30": "Xpeo", "550": "Blis", "40": "Ugar", "50": "Egyp", "570": "Brai", "60": "Egyh", "437": "Loma", "70": "Egyd", "80": "Hluw", "90": "Maya", "95": "Sgnw", "610": "Inds", "100": "Mero", "101": "Merc", "105": "Sarb", "106": "Narb", "620": "Roro", "115": "Phnx", "116": "Lydi", "120": "Tfng", "123": "Samr", "124": "Armi", "125": "Hebr", "126": "Palm", "127": "Hatr", "130": "Prti", "131": "Phli", "132": "Phlp", "133": "Phlv", "134": "Avst", "135": "Syrc", "136": "Syrn", "137": "Syrj", "138": "Syre", "139": "Mani", "140": "Mand", "145": "Mong", "159": "Nbat", "160": "Arab", "161": "Aran", "165": "Nkoo", "166": "Adlm", "170": "Thaa", "175": "Orkh", "176": "Hung", "200": "Grek", "201": "Cari", "202": "Lyci", "204": "Copt", "206": "Goth", "210": "Ital", "211": "Runr", "212": "Ogam", "215": "Latn", "216": "Latg", "217": "Latf", "218": "Moon", "219": "Osge", "220": "Cyrl", "221": "Cyrs", "225": "Glag", "226": "Elba", "227": "Perm", "230": "Armn", "239": "Aghb", "240": "Geor", "241": "Geok", "755": "Dupl", "250": "Dsrt", "259": "Bass", "260": "Osma", "261": "Olck", "262": "Wara", "263": "Pauc", "264": "Mroo", "265": "Medf", "280": "Visp", "281": "Shaw", "282": "Plrd", "284": "Jamo", "285": "Bopo", "286": "Hang", "287": "Kore", "288": "Kits", "290": "Teng", "291": "Cirt", "292": "Sara", "293": "Piqd", "300": "Brah", "302": "Sidd", "305": "Khar", "310": "Guru", "312": "Gong", "313": "Gonm", "314": "Mahj", "315": "Deva", "316": "Sylo", "317": "Kthi", "318": "Sind", "319": "Shrd", "320": "Gujr", "321": "Takr", "322": "Khoj", "323": "Mult", "324": "Modi", "325": "Beng", "326": "Tirh", "327": "Orya", "328": "Dogr", "329": "Soyo", "330": "Tibt", "331": "Phag", "332": "Marc", "333": "Newa", "334": "Bhks", "335": "Lepc", "336": "Limb", "337": "Mtei", "338": "Ahom", "339": "Zanb", "340": "Telu", "343": "Gran", "344": "Saur", "345": "Knda", "346": "Taml", "347": "Mlym", "348": "Sinh", "349": "Cakm", "350": "Mymr", "351": "Lana", "352": "Thai", "353": "Tale", "354": "Talu", "355": "Khmr", "356": "Laoo", "357": "Kali", "358": "Cham", "359": "Tavt", "360": "Bali", "361": "Java", "362": "Sund", "363": "Rjng", "364": "Leke", "365": "Batk", "366": "Maka", "367": "Bugi", "370": "Tglg", "371": "Hano", "372": "Buhd", "373": "Tagb", "900": "Qaaa", "398": "Sora", "399": "Lisu", "400": "Lina", "401": "Linb", "403": "Cprt", "410": "Hira", "411": "Kana", "412": "Hrkt", "413": "Jpan", "420": "Nkgb", "430": "Ethi", "435": "Bamu", "436": "Kpel", "949": "Qabx", "438": "Mend", "439": "Afak", "440": "Cans", "445": "Cher", "450": "Hmng", "460": "Yiii", "470": "Vaii", "480": "Wole", "993": "Zsye", "994": "Zinh", "995": "Zmth", "996": "Zsym", "997": "Zxxx", "998": "Zyyy", "999": "Zzzz", "499": "Nshu", "500": "Hani", "501": "Hans", "502": "Hant", "503": "Hanb", "505": "Kitl", "510": "Jurc"} 2 | -------------------------------------------------------------------------------- 
/kraken/ketos/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2015 Benjamin Kiessling 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 13 | # or implied. See the License for the specific language governing 14 | # permissions and limitations under the License. 15 | """ 16 | kraken.ketos 17 | ~~~~~~~~~~~~~ 18 | 19 | Command line drivers for training functionality. 20 | """ 21 | 22 | import logging 23 | 24 | import click 25 | from PIL import Image 26 | from rich.traceback import install 27 | 28 | from kraken.lib import log 29 | 30 | from .dataset import compile 31 | from .linegen import line_generator 32 | from .pretrain import pretrain 33 | from .recognition import test, train 34 | from .repo import publish 35 | from .ro import roadd, rotrain 36 | from .segmentation import segtest, segtrain 37 | from .transcription import extract, transcription 38 | 39 | logging.captureWarnings(True) 40 | logger = logging.getLogger('kraken') 41 | # disable annoying lightning worker seeding log messages 42 | logging.getLogger("lightning.fabric.utilities.seed").setLevel(logging.ERROR) 43 | # install rich traceback handler 44 | install(suppress=[click]) 45 | 46 | # raise default max image size to 20k * 20k pixels 47 | Image.MAX_IMAGE_PIXELS = 20000 ** 2 48 | 49 | 50 | @click.group() 51 | @click.version_option() 52 | @click.pass_context 53 | @click.option('-v', '--verbose', default=0, count=True) 54 | @click.option('-s', '--seed', default=None, type=click.INT, 55 | help='Seed for numpy\'s and torch\'s RNG. Set to a fixed value to ' 56 | 'ensure reproducible random splits of data') 57 | @click.option('-r', '--deterministic/--no-deterministic', default=False, 58 | help="Enables deterministic training. If no seed is given and enabled the seed will be set to 42.") 59 | def cli(ctx, verbose, seed, deterministic): 60 | ctx.meta['deterministic'] = False if not deterministic else 'warn' 61 | if seed: 62 | from lightning.pytorch import seed_everything 63 | seed_everything(seed, workers=True) 64 | elif deterministic: 65 | from lightning.pytorch import seed_everything 66 | seed_everything(42, workers=True) 67 | 68 | ctx.meta['verbose'] = verbose 69 | log.set_logger(logger, level=30 - min(10 * verbose, 20)) 70 | 71 | 72 | cli.add_command(compile) 73 | cli.add_command(pretrain) 74 | cli.add_command(train) 75 | cli.add_command(test) 76 | cli.add_command(segtrain) 77 | cli.add_command(segtest) 78 | cli.add_command(publish) 79 | cli.add_command(rotrain) 80 | cli.add_command(roadd) 81 | 82 | # deprecated commands 83 | cli.add_command(line_generator) 84 | cli.add_command(extract) 85 | cli.add_command(transcription) 86 | 87 | if __name__ == '__main__': 88 | cli() 89 | -------------------------------------------------------------------------------- /kraken/ketos/dataset.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2022 Benjamin Kiessling 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 13 | # or implied. See the License for the specific language governing 14 | # permissions and limitations under the License. 15 | """ 16 | kraken.ketos.dataset 17 | ~~~~~~~~~~~~~~~~~~~~ 18 | 19 | Command line driver for dataset compilation 20 | """ 21 | import click 22 | 23 | from .util import _validate_manifests 24 | 25 | 26 | @click.command('compile') 27 | @click.pass_context 28 | @click.option('-o', '--output', show_default=True, type=click.Path(), default='dataset.arrow', help='Output dataset file') 29 | @click.option('--workers', show_default=True, default=1, help='Number of parallel workers for text line extraction.') 30 | @click.option('-f', '--format-type', type=click.Choice(['path', 'xml', 'alto', 'page']), default='xml', show_default=True, 31 | help='Sets the training data format. In ALTO and PageXML mode all ' 32 | 'data is extracted from xml files containing both baselines and a ' 33 | 'link to source images. In `path` mode arguments are image files ' 34 | 'sharing a prefix up to the last extension with JSON `.path` files ' 35 | 'containing the baseline information.') 36 | @click.option('-F', '--files', show_default=True, default=None, multiple=True, 37 | callback=_validate_manifests, type=click.File(mode='r', lazy=True), 38 | help='File(s) with additional paths to training data.') 39 | @click.option('--random-split', type=float, nargs=3, default=None, show_default=True, 40 | help='Creates a fixed random split of the input data with the ' 41 | 'proportions (train, validation, test). Overrides the save split option.') 42 | @click.option('--force-type', type=click.Choice(['bbox', 'baseline']), default=None, show_default=True, 43 | help='Forces the dataset type to a specific value. Can be used to ' 44 | '"convert" a line strip-type collection to a baseline-style ' 45 | 'dataset, e.g. to disable centerline normalization.') 46 | @click.option('--save-splits/--ignore-splits', show_default=True, default=True, 47 | help='Whether to serialize explicit splits contained in XML ' 48 | 'files. Is ignored in `path` mode.') 49 | @click.option('--skip-empty-lines/--keep-empty-lines', show_default=True, default=True, 50 | help='Whether to keep or skip empty text lines. Text-less ' 51 | 'datasets are useful for unsupervised pretraining but ' 52 | 'loading datasets with many empty lines for recognition ' 53 | 'training is inefficient.') 54 | @click.option('--recordbatch-size', show_default=True, default=100, 55 | help='Minimum number of records per RecordBatch written to the ' 56 | 'output file. Larger batches require more transient memory ' 57 | 'but slightly improve reading performance.') 58 | @click.option('--legacy-polygons', show_default=True, default=False, is_flag=True, 59 | help='Use the old polygon extractor.') 60 | @click.argument('ground_truth', nargs=-1, type=click.Path(exists=True, dir_okay=False)) 61 | def compile(ctx, output, workers, format_type, files, random_split, force_type, 62 | save_splits, skip_empty_lines, recordbatch_size, ground_truth, legacy_polygons): 63 | """ 64 | Precompiles a binary dataset from a collection of XML files. 
65 | """ 66 | from kraken.lib.progress import KrakenProgressBar 67 | 68 | from .util import message 69 | 70 | ground_truth = list(ground_truth) 71 | 72 | if files: 73 | ground_truth.extend(files) 74 | 75 | if not ground_truth: 76 | raise click.UsageError('No training data was provided to the compile command. Use the `ground_truth` argument.') 77 | 78 | from kraken.lib import arrow_dataset 79 | 80 | force_type = {'bbox': 'kraken_recognition_bbox', 81 | 'baseline': 'kraken_recognition_baseline', 82 | None: None}[force_type] 83 | 84 | with KrakenProgressBar() as progress: 85 | extract_task = progress.add_task('Extracting lines', total=0, start=False, visible=True if not ctx.meta['verbose'] else False) 86 | 87 | def _update_bar(advance, total): 88 | if not progress.tasks[0].started: 89 | progress.start_task(extract_task) 90 | progress.update(extract_task, total=total, advance=advance) 91 | 92 | arrow_dataset.build_binary_dataset(ground_truth, 93 | output, 94 | format_type, 95 | workers, 96 | save_splits, 97 | random_split, 98 | force_type, 99 | recordbatch_size, 100 | skip_empty_lines, 101 | _update_bar, 102 | legacy_polygons=legacy_polygons) 103 | 104 | message(f'Output file written to {output}') 105 | -------------------------------------------------------------------------------- /kraken/ketos/util.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2022 Benjamin Kiessling 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 13 | # or implied. See the License for the specific language governing 14 | # permissions and limitations under the License. 15 | """ 16 | kraken.ketos.util 17 | ~~~~~~~~~~~~~~~~~~~~ 18 | 19 | Command line driver helpers 20 | """ 21 | import glob 22 | import logging 23 | import os 24 | from typing import List, Optional, Tuple 25 | 26 | import click 27 | 28 | logging.captureWarnings(True) 29 | logger = logging.getLogger('kraken') 30 | 31 | 32 | def _validate_manifests(ctx, param, value): 33 | images = [] 34 | for manifest in value: 35 | try: 36 | for entry in manifest.readlines(): 37 | im_p = entry.rstrip('\r\n') 38 | if os.path.isfile(im_p): 39 | images.append(im_p) 40 | else: 41 | logger.warning('Invalid entry "{}" in {}'.format(im_p, manifest.name)) 42 | except UnicodeDecodeError: 43 | raise click.BadOptionUsage(param, 44 | f'File {manifest.name} is not a text file. 
Please ' 45 | 'ensure that the argument to `-t`/`-e` is a manifest ' 46 | 'file containing paths to training data (one per ' 47 | 'line).', 48 | ctx=ctx) 49 | return images 50 | 51 | 52 | def _expand_gt(ctx, param, value): 53 | images = [] 54 | for expression in value: 55 | images.extend([x for x in glob.iglob(expression, recursive=True) if os.path.isfile(x)]) 56 | return images 57 | 58 | 59 | def message(msg, **styles): 60 | if logger.getEffectiveLevel() >= 30: 61 | click.secho(msg, **styles) 62 | 63 | 64 | def to_ptl_device(device: str) -> Tuple[str, Optional[List[int]]]: 65 | if device.strip() == 'auto': 66 | return 'auto', 'auto' 67 | devices = device.split(',') 68 | if devices[0] in ['cpu', 'mps']: 69 | return devices[0], 'auto' 70 | elif any([devices[0].startswith(x) for x in ['tpu', 'cuda', 'hpu', 'ipu']]): 71 | devices = [device.split(':') for device in devices] 72 | devices = [(x[0].strip(), x[1].strip()) for x in devices] 73 | if len(set(x[0] for x in devices)) > 1: 74 | raise Exception('Can only use a single type of device at a time.') 75 | dev, _ = devices[0] 76 | if dev == 'cuda': 77 | dev = 'gpu' 78 | return dev, [int(x[1]) for x in devices] 79 | raise Exception(f'Invalid device {device} specified') 80 | -------------------------------------------------------------------------------- /kraken/lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/kraken/lib/__init__.py -------------------------------------------------------------------------------- /kraken/lib/ctc_decoder.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2017 Benjamin Kiessling 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 13 | # or implied. See the License for the specific language governing 14 | # permissions and limitations under the License. 15 | """ 16 | Decoders for softmax outputs of CTC trained networks. 17 | 18 | Decoders extract label sequences out of the raw output matrix of the line 19 | recognition network. There are multiple different approaches implemented here, 20 | from a simple greedy decoder, to the legacy ocropy thresholding decoder, and a 21 | more complex beam search decoder. 22 | 23 | Extracted label sequences are converted into the code point domain using kraken.lib.codec.PytorchCodec. 24 | """ 25 | 26 | import collections 27 | from itertools import groupby 28 | from typing import List, Tuple 29 | 30 | import numpy as np 31 | from scipy.ndimage import measurements 32 | from scipy.special import logsumexp 33 | 34 | __all__ = ['beam_decoder', 'greedy_decoder', 'blank_threshold_decoder'] 35 | 36 | 37 | def beam_decoder(outputs: np.ndarray, beam_size: int = 3) -> List[Tuple[int, int, int, float]]: 38 | """ 39 | Translates back the network output to a label sequence using 40 | same-prefix-merge beam search decoding as described in [0]. 41 | 42 | [0] Hannun, Awni Y., et al. "First-pass large vocabulary continuous speech 43 | recognition using bi-directional recurrent DNNs." 
arXiv preprint
44 | arXiv:1408.2873 (2014).
45 | 
46 | Args:
47 | outputs: (C, W) shaped softmax output tensor
48 | beam_size: Size of the beam
49 | 
50 | Returns:
51 | A list with tuples (class, start, end, prob). prob is the maximum value
52 | of the softmax output for the class in the region.
53 | """
54 | c, w = outputs.shape
55 | probs = np.log(outputs)
56 | beam = [(tuple(), (0.0, float('-inf')))] # type: List[Tuple[Tuple, Tuple[float, float]]]
57 | 
58 | # loop over each time step
59 | for t in range(w):
60 | next_beam = collections.defaultdict(lambda: 2*(float('-inf'),)) # type: dict
61 | # p_b -> prob for prefix ending in blank
62 | # p_nb -> prob for prefix not ending in blank
63 | for prefix, (p_b, p_nb) in beam:
64 | # only update ending-in-blank-prefix probability for blank
65 | n_p_b, n_p_nb = next_beam[prefix]
66 | n_p_b = logsumexp((n_p_b, p_b + probs[0, t], p_nb + probs[0, t]))
67 | next_beam[prefix] = (n_p_b, n_p_nb)
68 | # loop over non-blank classes
69 | for s in range(1, c):
70 | # only update the not-ending-in-blank-prefix probability for prefix+s
71 | l_end = prefix[-1][0] if prefix else None
72 | n_prefix = prefix + ((s, t, t),)
73 | n_p_b, n_p_nb = next_beam[n_prefix]
74 | if s == l_end:
75 | # substitute the previous non-blank-ending-prefix
76 | # probability for repeated labels
77 | n_p_nb = logsumexp((n_p_nb, p_b + probs[s, t]))
78 | else:
79 | n_p_nb = logsumexp((n_p_nb, p_b + probs[s, t], p_nb + probs[s, t]))
80 | 
81 | next_beam[n_prefix] = (n_p_b, n_p_nb)
82 | 
83 | # If s is repeated at the end we also update the unchanged
84 | # prefix. This is the merging case.
85 | if s == l_end:
86 | n_p_b, n_p_nb = next_beam[prefix]
87 | n_p_nb = logsumexp((n_p_nb, p_nb + probs[s, t]))
88 | # rewrite both new and old prefix positions
89 | next_beam[prefix[:-1] + ((prefix[-1][0], prefix[-1][1], t),)] = (n_p_b, n_p_nb)
90 | next_beam[n_prefix[:-1] + ((n_prefix[-1][0], n_prefix[-1][1], t),)] = next_beam.pop(n_prefix)
91 | 
92 | # Sort and trim the beam before moving on to the
93 | # next time-step.
94 | beam = sorted(next_beam.items(),
95 | key=lambda x: logsumexp(x[1]),
96 | reverse=True)
97 | beam = beam[:beam_size]
98 | return [(c, start, end, max(outputs[c, start:end+1])) for (c, start, end) in beam[0][0]]
99 | 
100 | 
101 | def greedy_decoder(outputs: np.ndarray) -> List[Tuple[int, int, int, float]]:
102 | """
103 | Translates back the network output to a label sequence using greedy/best
104 | path decoding as described in [0].
105 | 
106 | [0] Graves, Alex, et al. "Connectionist temporal classification: labelling
107 | unsegmented sequence data with recurrent neural networks." Proceedings of
108 | the 23rd international conference on Machine learning. ACM, 2006.
109 | 
110 | Args:
111 | outputs: (C, W) shaped softmax output tensor
112 | 
113 | Returns:
114 | A list with tuples (class, start, end, max). max is the maximum value
115 | of the softmax output for the class in the region.
116 | """
117 | labels = np.argmax(outputs, 0)
118 | seq_len = outputs.shape[1]
119 | mask = np.eye(outputs.shape[0], dtype='bool')[labels].T
120 | classes = []
121 | for label, group in groupby(zip(np.arange(seq_len), labels, outputs[mask]), key=lambda x: x[1]):
122 | lgroup = list(group)
123 | if label != 0:
124 | classes.append((label, lgroup[0][0], lgroup[-1][0], max(x[2] for x in lgroup)))
125 | return classes
126 | 
127 | 
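The collapse rule in `greedy_decoder` is easiest to see on a toy matrix (hand-checked values; row 0 is the CTC blank):

```python
import numpy as np

outputs = np.array([[0.90, 0.10, 0.10, 0.80, 0.10],   # blank
                    [0.05, 0.80, 0.70, 0.10, 0.10],   # class 1
                    [0.05, 0.10, 0.20, 0.10, 0.80]])  # class 2

# the best path over the 5 time steps is [0, 1, 1, 0, 2]; repeated labels
# collapse into one region and blanks are dropped
print(greedy_decoder(outputs))
# -> [(1, 1, 2, 0.8), (2, 4, 4, 0.8)]
```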
128 | def blank_threshold_decoder(outputs: np.ndarray, threshold: float = 0.5) -> List[Tuple[int, int, int, float]]:
129 | """
130 | Translates back the network output to a label sequence as in the original
131 | ocropy/clstm.
132 | 
133 | Thresholds on class 0, then assigns the maximum (non-zero) class to each
134 | region.
135 | 
136 | Args:
137 | outputs: (C, W) shaped softmax output tensor
138 | threshold: Threshold for 0 class when determining possible label
139 | locations.
140 | 
141 | Returns:
142 | A list with tuples (class, start, end, max). max is the maximum value
143 | of the softmax layer in the region.
144 | """
145 | outputs = outputs.T
146 | labels, n = measurements.label(outputs[:, 0] < threshold)
147 | mask = np.tile(labels.reshape(-1, 1), (1, outputs.shape[1]))
148 | maxima = measurements.maximum_position(outputs, mask, np.arange(1, np.amax(mask)+1))
149 | p = 0
150 | start = None
151 | x = []
152 | for idx, val in enumerate(labels):
153 | if val != 0 and start is None:
154 | start = idx
155 | p += 1
156 | if val == 0 and start is not None:
157 | if maxima[p-1][1] == 0:
158 | start = None
159 | else:
160 | x.append((maxima[p-1][1], start, idx, outputs[maxima[p-1]]))
161 | start = None
162 | # append last non-zero region to list if no zero region occurs after it
163 | if start:
164 | x.append((maxima[p-1][1], start, len(outputs), outputs[maxima[p-1]]))
165 | return [y for y in x if y[0] != 0]
166 | -------------------------------------------------------------------------------- /kraken/lib/dataset/__init__.py: --------------------------------------------------------------------------------
1 | #
2 | # Copyright 2022 Benjamin Kiessling
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
13 | # or implied. See the License for the specific language governing
14 | # permissions and limitations under the License.
15 | """
16 | Top-level module containing datasets for recognition and segmentation training.
17 | """
18 | from .recognition import (ArrowIPCRecognitionDataset, # NOQA
19 | GroundTruthDataset, PolygonGTDataset)
20 | from .ro import PageWiseROSet, PairWiseROSet # NOQA
21 | from .segmentation import BaselineSet # NOQA
22 | from .utils import (ImageInputTransforms, collate_sequences, # NOQA
23 | compute_confusions, global_align)
24 | -------------------------------------------------------------------------------- /kraken/lib/default_specs.py: --------------------------------------------------------------------------------
1 | #
2 | # Copyright 2020 Benjamin Kiessling
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 13 | # or implied. See the License for the specific language governing 14 | # permissions and limitations under the License. 15 | """ 16 | Default VGSL specs and hyperparameters 17 | """ 18 | 19 | SEGMENTATION_SPEC = '[1,1800,0,3 Cr7,7,64,2,2 Gn32 Cr3,3,128,2,2 Gn32 Cr3,3,128 Gn32 Cr3,3,256 Gn32 Cr3,3,256 Gn32 Lbx32 Lby32 Cr1,1,32 Gn32 Lby32 Lbx32]' # NOQA 20 | RECOGNITION_SPEC = '[1,120,0,1 Cr3,13,32 Do0.1,2 Mp2,2 Cr3,13,32 Do0.1,2 Mp2,2 Cr3,9,64 Do0.1,2 Mp2,2 Cr3,9,64 Do0.1,2 S1(1x0)1,3 Lbx200 Do0.1,2 Lbx200 Do0.1,2 Lbx200 Do]' # NOQA 21 | 22 | READING_ORDER_HYPER_PARAMS = {'lrate': 0.001, 23 | 'freq': 1.0, 24 | 'batch_size': 15000, 25 | 'min_epochs': 500, 26 | 'epochs': 3000, 27 | 'lag': 300, 28 | 'min_delta': None, 29 | 'quit': 'early', 30 | 'optimizer': 'Adam', 31 | 'momentum': 0.9, 32 | 'weight_decay': 0.01, 33 | 'schedule': 'cosine', 34 | 'completed_epochs': 0, 35 | # lr scheduler params 36 | # step/exp decay 37 | 'step_size': 10, 38 | 'gamma': 0.1, 39 | # reduce on plateau 40 | 'rop_factor': 0.1, 41 | 'rop_patience': 5, 42 | # cosine 43 | 'cos_t_max': 100, 44 | 'cos_min_lr': 0.001, 45 | 'warmup': 0, 46 | } 47 | 48 | RECOGNITION_PRETRAIN_HYPER_PARAMS = {'pad': 16, 49 | 'freq': 1.0, 50 | 'batch_size': 64, 51 | 'quit': 'early', 52 | 'epochs': -1, 53 | 'min_epochs': 100, 54 | 'lag': 5, 55 | 'min_delta': None, 56 | 'optimizer': 'Adam', 57 | 'lrate': 1e-6, 58 | 'momentum': 0.9, 59 | 'weight_decay': 0.01, 60 | 'schedule': 'cosine', 61 | 'completed_epochs': 0, 62 | 'augment': False, 63 | # lr scheduler params 64 | # step/exp decay 65 | 'step_size': 10, 66 | 'gamma': 0.1, 67 | # reduce on plateau 68 | 'rop_factor': 0.1, 69 | 'rop_patience': 5, 70 | # cosine 71 | 'cos_t_max': 100, 72 | 'cos_min_lr': 1e-7, 73 | # masking parameters 74 | 'mask_width': 4, 75 | 'mask_prob': 0.5, 76 | 'num_negatives': 100, 77 | 'logit_temp': 0.1, 78 | 'warmup': 32000, 79 | } 80 | 81 | RECOGNITION_HYPER_PARAMS = {'pad': 16, 82 | 'freq': 1.0, 83 | 'batch_size': 1, 84 | 'quit': 'early', 85 | 'epochs': -1, 86 | 'min_epochs': 0, 87 | 'lag': 10, 88 | 'min_delta': None, 89 | 'optimizer': 'Adam', 90 | 'lrate': 1e-3, 91 | 'momentum': 0.9, 92 | 'weight_decay': 0.0, 93 | 'schedule': 'constant', 94 | 'normalization': None, 95 | 'normalize_whitespace': True, 96 | 'completed_epochs': 0, 97 | 'augment': False, 98 | # lr scheduler params 99 | # step/exp decay 100 | 'step_size': 10, 101 | 'gamma': 0.1, 102 | # reduce on plateau 103 | 'rop_factor': 0.1, 104 | 'rop_patience': 5, 105 | # cosine 106 | 'cos_t_max': 50, 107 | 'cos_min_lr': 1e-4, 108 | 'warmup': 0, 109 | 'freeze_backbone': 0, 110 | } 111 | 112 | SEGMENTATION_HYPER_PARAMS = {'line_width': 8, 113 | 'padding': (0, 0), 114 | 'freq': 1.0, 115 | 'quit': 'fixed', 116 | 'epochs': 50, 117 | 'min_epochs': 0, 118 | 'lag': 10, 119 | 'min_delta': None, 120 | 'optimizer': 'Adam', 121 | 'lrate': 2e-4, 122 | 'momentum': 0.9, 123 | 'weight_decay': 1e-5, 124 | 'schedule': 'constant', 125 | 'completed_epochs': 0, 126 | 'augment': False, 127 | # lr scheduler params 128 | # step/exp decay 129 | 'step_size': 10, 130 | 'gamma': 0.1, 131 | # reduce on plateau 132 | 'rop_factor': 0.1, 133 | 'rop_patience': 5, 134 | # cosine 135 | 'cos_t_max': 50, 136 | 'cos_min_lr': 2e-5, 137 | 
'warmup': 0, 138 | } 139 | -------------------------------------------------------------------------------- /kraken/lib/exceptions.py: -------------------------------------------------------------------------------- 1 | """ 2 | kraken.lib.exceptions 3 | ~~~~~~~~~~~~~~~~~~~~~ 4 | 5 | All custom exceptions raised by kraken's modules and packages. Packages should 6 | always define their exceptions here. 7 | """ 8 | 9 | 10 | class KrakenCodecException(Exception): 11 | 12 | def __init__(self, message=None): 13 | Exception.__init__(self, message) 14 | 15 | 16 | class KrakenStopTrainingException(Exception): 17 | 18 | def __init__(self, message=None): 19 | Exception.__init__(self, message) 20 | 21 | 22 | class KrakenEncodeException(Exception): 23 | 24 | def __init__(self, message=None): 25 | Exception.__init__(self, message) 26 | 27 | 28 | class KrakenRecordException(Exception): 29 | 30 | def __init__(self, message=None): 31 | Exception.__init__(self, message) 32 | 33 | 34 | class KrakenInvalidModelException(Exception): 35 | 36 | def __init__(self, message=None): 37 | Exception.__init__(self, message) 38 | 39 | 40 | class KrakenInputException(Exception): 41 | 42 | def __init__(self, message=None): 43 | Exception.__init__(self, message) 44 | 45 | 46 | class KrakenRepoException(Exception): 47 | 48 | def __init__(self, message=None): 49 | Exception.__init__(self, message) 50 | 51 | 52 | class KrakenCairoSurfaceException(Exception): 53 | """ 54 | Raised when the Cairo surface couldn't be created. 55 | 56 | Attributes: 57 | message (str): Error message 58 | width (int): Width of the surface 59 | height (int): Height of the surface 60 | """ 61 | def __init__(self, message: str, width: int, height: int) -> None: 62 | self.message = message 63 | self.width = width 64 | self.height = height 65 | 66 | def __repr__(self) -> str: 67 | return repr(self.message) 68 | -------------------------------------------------------------------------------- /kraken/lib/functional_im_transforms.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2015 Benjamin Kiessling 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 13 | # or implied. See the License for the specific language governing 14 | # permissions and limitations under the License. 15 | """ 16 | Named functions for all the transforms that were lambdas in the past to 17 | facilitate pickling. 
18 | """ 19 | import unicodedata 20 | from pathlib import Path 21 | from typing import (TYPE_CHECKING, Any, Callable, Literal, Optional, Tuple, 22 | Union) 23 | 24 | import bidi.algorithm as bd 25 | import regex 26 | import torch 27 | from PIL.Image import Resampling 28 | 29 | from kraken.lib.lineest import CenterNormalizer, dewarp 30 | 31 | if TYPE_CHECKING: 32 | from os import PathLike 33 | 34 | from PIL import Image 35 | 36 | 37 | def pil_to_mode(im: 'Image.Image', mode: str) -> 'Image.Image': 38 | return im.convert(mode) 39 | 40 | 41 | def pil_to_bin(im: 'Image.Image') -> 'Image.Image': 42 | from kraken.binarization import nlbin 43 | return nlbin(im) 44 | 45 | 46 | def dummy(x: Any) -> Any: 47 | return x 48 | 49 | 50 | def pil_dewarp(im: 'Image.Image', lnorm: CenterNormalizer) -> 'Image.Image': 51 | return dewarp(lnorm, im) 52 | 53 | 54 | def pil_fixed_resize(im: 'Image.Image', scale: Tuple[int, int]) -> 'Image.Image': 55 | return _fixed_resize(im, scale, Resampling.LANCZOS) 56 | 57 | 58 | def tensor_invert(im: torch.Tensor) -> torch.Tensor: 59 | return im.max() - im 60 | 61 | 62 | def tensor_permute(im: torch.Tensor, perm: Tuple[int, ...]) -> torch.Tensor: 63 | return im.permute(*perm) 64 | 65 | 66 | def _fixed_resize(img: 'Image.Image', size: Tuple[int, int], interpolation: int = Resampling.LANCZOS): 67 | """ 68 | Doesn't do the annoying runtime scale dimension switching the default 69 | pytorch transform does. 70 | 71 | Args: 72 | img: image to resize 73 | size: Tuple (height, width) 74 | """ 75 | w, h = img.size 76 | oh, ow = size 77 | if oh == 0: 78 | oh = int(h * ow/w) 79 | elif ow == 0: 80 | ow = int(w * oh/h) 81 | img = img.resize((ow, oh), interpolation) 82 | return img 83 | 84 | 85 | def text_normalize(text: str, normalization: Literal['NFD', 'NFC', 'NFKD', 'NFKC']) -> str: 86 | return unicodedata.normalize(normalization, text) 87 | 88 | 89 | def text_whitespace_normalize(text: str) -> str: 90 | return regex.sub(r'\s', ' ', text).strip() 91 | 92 | 93 | def text_reorder(text: str, base_dir: Optional[Literal['L', 'R']] = None) -> str: 94 | return bd.get_display(text, base_dir=base_dir) 95 | 96 | 97 | def default_split(x: Union['PathLike', str]) -> str: 98 | x = Path(x) 99 | while x.suffixes: 100 | x = x.with_suffix('') 101 | return str(x) 102 | 103 | 104 | def suffix_split(x: Union['PathLike', str], split: Callable[[Union['PathLike', str]], str], suffix: str) -> str: 105 | return split(x) + suffix 106 | -------------------------------------------------------------------------------- /kraken/lib/lineest.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import TYPE_CHECKING 3 | 4 | import numpy as np 5 | from scipy.ndimage import affine_transform, gaussian_filter, uniform_filter 6 | 7 | if TYPE_CHECKING: 8 | from PIL import Image 9 | 10 | __all__ = ['CenterNormalizer', 'dewarp'] 11 | 12 | 13 | def scale_to_h(img, target_height, order=1, dtype=np.dtype('f'), cval=0): 14 | h, w = img.shape 15 | scale = target_height*1.0/h 16 | target_width = int(scale*w) 17 | with warnings.catch_warnings(): 18 | warnings.simplefilter('ignore', UserWarning) 19 | output = affine_transform(1.0*img, np.ones(2)/scale, order=order, 20 | output_shape=(target_height, target_width), 21 | mode='constant', cval=cval) 22 | output = np.array(output, dtype=dtype) 23 | return output 24 | 25 | 26 | class CenterNormalizer(object): 27 | def __init__(self, target_height=48, params=(4, 1.0, 0.3)): 28 | self.target_height = target_height 29 | 
self.range, self.smoothness, self.extra = params 30 | 31 | def setHeight(self, target_height): 32 | self.target_height = target_height 33 | 34 | def measure(self, line): 35 | h, w = line.shape 36 | # XXX: this filter is awfully slow 37 | smoothed = gaussian_filter(line, (h*0.5, h*self.smoothness), 38 | mode='constant') 39 | smoothed += 0.001*uniform_filter(smoothed, (h*0.5, w), mode='constant') 40 | self.shape = (h, w) 41 | a = np.argmax(smoothed, axis=0) 42 | a = gaussian_filter(a, h*self.extra) 43 | self.center = np.array(a, 'i') 44 | deltas = np.abs(np.arange(h)[:, np.newaxis]-self.center[np.newaxis, :]) 45 | self.mad = np.mean(deltas[line != 0]) 46 | self.r = int(1+self.range*self.mad) 47 | 48 | def dewarp(self, img, cval=0, dtype=np.dtype('f')): 49 | if img.shape != self.shape: 50 | raise Exception('Measured and dewarp image shapes different') 51 | h, w = img.shape 52 | padded = np.vstack([cval*np.ones((h, w)), img, cval*np.ones((h, w))]) 53 | center = self.center+h 54 | dewarped = [padded[center[i]-self.r:center[i]+self.r, i] for i in 55 | range(w)] 56 | dewarped = np.array(dewarped, dtype=dtype).T 57 | return dewarped 58 | 59 | def normalize(self, img, order=1, dtype=np.dtype('f'), cval=0): 60 | dewarped = self.dewarp(img, cval=cval, dtype=dtype) 61 | if dewarped.shape[0] == 0: 62 | dewarped = img 63 | scaled = scale_to_h(dewarped, self.target_height, order=order, 64 | dtype=dtype, cval=cval) 65 | return scaled 66 | 67 | 68 | def dewarp(normalizer: CenterNormalizer, im: 'Image.Image') -> 'Image.Image': 69 | """ 70 | Dewarps an image of a line using a kraken.lib.lineest.CenterNormalizer 71 | instance. 72 | 73 | Args: 74 | normalizer: A line normalizer instance 75 | im: Image to dewarp 76 | 77 | Returns: 78 | PIL.Image.Image containing the dewarped image. 79 | """ 80 | from kraken.lib.util import array2pil, pil2array 81 | 82 | line = pil2array(im) 83 | temp = np.amax(line)-line 84 | temp = temp*1.0/np.amax(temp) 85 | normalizer.measure(temp) 86 | line = normalizer.normalize(line, cval=np.amax(line)) 87 | return array2pil(line) 88 | -------------------------------------------------------------------------------- /kraken/lib/log.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2018 Benjamin Kiessling 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 13 | # or implied. See the License for the specific language governing 14 | # permissions and limitations under the License. 15 | """ 16 | kraken.lib.log 17 | ~~~~~~~~~~~~~~~~~ 18 | 19 | Handlers and formatters for logging. 
20 | """
21 | import logging
22 | 
23 | from rich.logging import RichHandler
24 | 
25 | 
26 | def set_logger(logger=None, level=logging.ERROR):
27 | logger.addHandler(RichHandler(rich_tracebacks=True))
28 | logger.setLevel(level)
29 | -------------------------------------------------------------------------------- /kraken/lib/morph.py: --------------------------------------------------------------------------------
1 | """
2 | Various add-ons to the SciPy morphology package
3 | """
4 | import numpy as np
5 | from scipy.ndimage import distance_transform_edt, filters
6 | from scipy.ndimage import find_objects as _find_objects
7 | from scipy.ndimage import label as _label
8 | 
9 | 
10 | def label(image: np.ndarray, **kw) -> np.ndarray:
11 | """
12 | Redefine the scipy.ndimage.label function to work with a wider
13 | range of data types. The default function is inconsistent about the data
14 | types it accepts on different platforms.
15 | """
16 | try:
17 | return _label(image, **kw)
18 | except Exception:
19 | pass
20 | types = ["int32", "uint32", "int64", "uint64", "int16", "uint16"]
21 | for t in types:
22 | try:
23 | return _label(np.array(image, dtype=t), **kw)
24 | except Exception:
25 | pass
26 | # let it raise the same exception as before
27 | return _label(image, **kw)
28 | 
29 | 
30 | def find_objects(image: np.ndarray, **kw) -> np.ndarray:
31 | """
32 | Redefine the scipy.ndimage.find_objects function to work with
33 | a wider range of data types. The default function is inconsistent about
34 | the data types it accepts on different platforms.
35 | """
36 | try:
37 | return _find_objects(image, **kw)
38 | except Exception:
39 | pass
40 | types = ["int32", "uint32", "int64", "uint64", "int16", "uint16"]
41 | for t in types:
42 | try:
43 | return _find_objects(np.array(image, dtype=t), **kw)
44 | except Exception:
45 | pass
46 | # let it raise the same exception as before
47 | return _find_objects(image, **kw)
48 | 
49 | 
50 | def r_dilation(image, size, origin=0):
51 | """Dilation with rectangular structuring element using maximum_filter"""
52 | return filters.maximum_filter(image, size, origin=origin)
53 | 
54 | 
55 | def r_erosion(image, size, origin=0):
56 | """Erosion with rectangular structuring element using minimum_filter"""
57 | return filters.minimum_filter(image, size, origin=origin)
58 | 
59 | 
60 | def rb_dilation(image, size, origin=0):
61 | """Binary dilation using linear filters."""
62 | output = np.zeros(image.shape, 'f')
63 | filters.uniform_filter(image, size, output=output, origin=origin,
64 | mode='constant', cval=0)
65 | return np.array(output > 0, 'i')
66 | 
67 | 
68 | def rb_erosion(image, size, origin=0):
69 | """Binary erosion using linear filters."""
70 | output = np.zeros(image.shape, 'f')
71 | filters.uniform_filter(image, size, output=output, origin=origin,
72 | mode='constant', cval=1)
73 | return np.array(output == 1, 'i')
74 | 
75 | 
76 | def rb_opening(image, size, origin=0):
77 | """Binary opening using linear filters."""
78 | image = rb_erosion(image, size, origin=origin)
79 | return rb_dilation(image, size, origin=origin)
80 | 
81 | 
82 | def spread_labels(labels, maxdist=9999999):
83 | """Spread the given labels to the background"""
84 | distances, features = distance_transform_edt(labels == 0,
85 | return_distances=1,
86 | return_indices=1)
87 | indexes = features[0] * labels.shape[1] + features[1]
88 | spread = labels.ravel()[indexes.ravel()].reshape(*labels.shape)
89 | spread *= (distances < maxdist)
90 | return spread
91 | 
92 | 
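`spread_labels` is the one helper here whose behaviour is not obvious from its one-line docstring; a small, hand-checkable illustration:

```python
import numpy as np

seeds = np.zeros((5, 5), dtype='int32')
seeds[0, 0] = 1   # seed for label 1 in the top-left corner
seeds[4, 4] = 2   # seed for label 2 in the bottom-right corner

# every background pixel takes the label of its nearest seed, so the
# array splits into a 1-region and a 2-region roughly along the
# anti-diagonal
print(spread_labels(seeds))
```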
93 | def correspondences(labels1, labels2):
94 | """Given two labeled images, compute an array giving the correspondences
95 | between labels in the two images."""
96 | q = 100000
97 | combo = labels1 * q + labels2
98 | result = np.unique(combo)
99 | result = np.array([result // q, result % q])
100 | return result
101 | 
102 | 
103 | def propagate_labels(image, labels, conflict=0):
104 | """Given an image and a set of labels, apply the labels
105 | to all the regions in the image that overlap a label.
106 | Assign the value `conflict` to any labels that have a conflict."""
107 | rlabels, _ = label(image)
108 | cors = correspondences(rlabels, labels)
109 | outputs = np.zeros(np.amax(rlabels) + 1, 'i')
110 | oops = -(1 << 30)
111 | for o, i in cors.T:
112 | if outputs[o] != 0:
113 | outputs[o] = oops
114 | else:
115 | outputs[o] = i
116 | outputs[outputs == oops] = conflict
117 | outputs[0] = 0
118 | return outputs[rlabels]
119 | 
120 | 
121 | def select_regions(binary, f, min=0, nbest=100000):
122 | """Given a scoring function f over slice tuples (as returned by
123 | find_objects), keeps at most nbest regions whose score is higher
124 | than min."""
125 | labels, n = label(binary)
126 | objects = find_objects(labels)
127 | scores = [f(o) for o in objects]
128 | best = np.argsort(scores)
129 | keep = np.zeros(len(objects) + 1, 'i')
130 | if nbest > 0:
131 | for i in best[-nbest:]:
132 | if scores[i] <= min:
133 | continue
134 | keep[i+1] = 1
135 | return keep[labels]
136 | -------------------------------------------------------------------------------- /kraken/lib/pretrain/__init__.py: --------------------------------------------------------------------------------
1 | #
2 | # Copyright 2022 Benjamin Kiessling
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
13 | # or implied. See the License for the specific language governing
14 | # permissions and limitations under the License.
15 | """
16 | Tools for unsupervised pretraining of recognition models.
17 | """
18 | 
19 | from .model import PretrainDataModule, RecognitionPretrainModel # NOQA
20 | -------------------------------------------------------------------------------- /kraken/lib/pretrain/layers.py: --------------------------------------------------------------------------------
1 | """
2 | Layers for VGSL models
3 | """
4 | from typing import TYPE_CHECKING, Dict, Optional, Tuple
5 | 
6 | import torch
7 | from torch.nn import Embedding, Linear, Module
8 | 
9 | from kraken.lib.pretrain.util import compute_mask_indices, sample_negatives
10 | 
11 | if TYPE_CHECKING:
12 | from kraken.lib.vgsl import VGSLBlock
13 | 
14 | # all tensors are ordered NCHW, the "feature" dimension is C, so the output of
15 | # an LSTM will be put into C same as the filters of a CNN.
16 | 
17 | __all__ = ['Wav2Vec2Mask']
18 | 
19 | 
20 | class Wav2Vec2Mask(Module):
21 | """
22 | A layer for Wav2Vec2-style masking. Needs to be placed just before
23 | recurrent/contextual layers.
24 | """
25 | 
26 | def __init__(self,
27 | context_encoder_input_dim: int,
28 | final_dim: int,
29 | mask_width: int,
30 | mask_prob: float,
31 | num_negatives: int) -> None:
32 | """
33 | 
34 | Args:
35 | context_encoder_input_dim: size of the `C` input dimension
36 | final_dim: size of the decoder `C` output dimension just before the
37 | final linear projection.
38 | mask_width: width of the non-overlapping masked areas.
39 | mask_prob: probability of masking at each time step
40 | num_negatives: number of negative samples with width mask_width *
41 | num_masks
42 | 
43 | Shape:
44 | - Inputs: :math:`(N, C, H, W)` where `N` batches, `C` channels, `H`
45 | height, and `W` width.
46 | - Outputs: :math:`(N, C, H, W)`
47 | """
48 | super().__init__()
49 | 
50 | self.context_encoder_input_dim = context_encoder_input_dim
51 | self.final_dim = final_dim
52 | self.mask_width = mask_width
53 | self.mask_prob = mask_prob
54 | self.num_negatives = num_negatives
55 | 
56 | # mask embedding replacing the masked out areas
57 | self.mask_emb = Embedding(1, context_encoder_input_dim)
58 | self.project_q = Linear(context_encoder_input_dim, final_dim)
59 | 
60 | def forward(self, inputs: torch.Tensor, seq_len: Optional[torch.Tensor] = None) -> Dict[str, Optional[torch.Tensor]]:
61 | N, C, H, W = inputs.shape
62 | if H != 1:
63 | raise Exception(f'Height has to be 1, not {H} for Wav2Vec2 masking layer.')
64 | 
65 | # NCHW -> NWC
66 | inputs = inputs.transpose(1, 3).reshape(-1, W, C)
67 | mask_indices = compute_mask_indices((N, W), self.mask_prob, self.mask_width)
68 | mask_indices = torch.from_numpy(mask_indices).to(inputs.device)
69 | 
70 | unmasked_features = inputs.clone()
71 | # mask out
72 | inputs[mask_indices] = self.mask_emb.weight
73 | # project into same dimensionality as final recurrent layer
74 | unmasked_features = self.project_q(unmasked_features)
75 | unmasked_samples = unmasked_features[mask_indices].view(unmasked_features.size(0), -1, unmasked_features.size(-1))
76 | 
77 | # negative samples
78 | negative_samples = sample_negatives(unmasked_samples, unmasked_samples.size(1), self.num_negatives)
79 | 
80 | # NWC -> NCHW
81 | inputs = inputs.permute(0, 2, 1).unsqueeze(2)
82 | return {'output': inputs,
83 | 'unmasked_samples': unmasked_samples,
84 | 'negative_samples': negative_samples,
85 | 'seq_len': seq_len,
86 | 'mask': mask_indices}
87 | 
88 | def get_shape(self, input: Tuple[int, int, int, int]) -> Tuple[int, int, int, int]:
89 | """
90 | Calculates the output shape from input 4D tuple NCHW.
91 | """
92 | return input
93 | 
94 | def get_spec(self, name) -> "VGSLBlock":
95 | """
96 | Generates a VGSL spec block from the layer instance.
97 | """
98 | return f'[1,{self.final_dim},0,{self.context_encoder_input_dim} W{{{name}}}{self.final_dim},{self.mask_width},{self.mask_prob},{self.num_negatives}]'
99 | 
100 | def deserialize(self, name: str, spec) -> None:
101 | """
102 | Sets the weights of an initialized module from a CoreML protobuf spec.
103 | """ 104 | # extract embedding parameters 105 | emb = [x for x in spec.neuralNetwork.layers if x.name == '{}_wave2vec2_emb'.format(name)][0].embedding 106 | weights = torch.Tensor(emb.weights.floatValue).resize_as_(self.mask_emb.weight.data) 107 | self.mask_emb.weight = torch.nn.Parameter(weights) 108 | # extract linear projection parameters 109 | lin = [x for x in spec.neuralNetwork.layers if x.name == '{}_wave2vec2_lin'.format(name)][0].innerProduct 110 | weights = torch.Tensor(lin.weights.floatValue).resize_as_(self.project_q.weight.data) 111 | bias = torch.Tensor(lin.bias.floatValue) 112 | self.project_q.weight = torch.nn.Parameter(weights) 113 | self.project_q.bias = torch.nn.Parameter(bias) 114 | 115 | def serialize(self, name: str, input: str, builder): 116 | """ 117 | Serializes the module using a NeuralNetworkBuilder. 118 | """ 119 | wave2vec2_emb_name = f'{name}_wave2vec2_emb' 120 | builder.add_embedding(wave2vec2_emb_name, self.mask_emb.weight.data.numpy(), 121 | None, 122 | self.context_encoder_input_dim, self.mask_width, 123 | has_bias=False, input_name=input, output_name=wave2vec2_emb_name) 124 | wave2vec2_lin_name = f'{name}_wave2vec2_lin' 125 | builder.add_inner_product(wave2vec2_lin_name, self.project_q.weight.data.numpy(), 126 | self.project_q.bias.data.numpy(), 127 | self.context_encoder_input_dim, self.final_dim, 128 | has_bias=True, input_name=input, output_name=wave2vec2_lin_name) 129 | return name 130 | -------------------------------------------------------------------------------- /kraken/lib/pretrain/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import random 7 | from typing import Sequence, Tuple, Union 8 | 9 | import numpy as np 10 | import torch 11 | 12 | 13 | def positive_integers_with_sum(n, total): 14 | ls = [0] 15 | rv = [] 16 | while len(ls) < n: 17 | c = random.randint(0, total) 18 | ls.append(c) 19 | ls = sorted(ls) 20 | ls.append(total) 21 | for i in range(1, len(ls)): 22 | rv.append(ls[i] - ls[i-1]) 23 | return rv 24 | 25 | 26 | def compute_masks(mask_prob: int, 27 | mask_width: int, 28 | num_neg_samples: int, 29 | seq_lens: Union[torch.Tensor, Sequence[int]]): 30 | """ 31 | Samples num_mask non-overlapping random masks of length mask_width in 32 | sequence of length seq_len. 33 | 34 | Args: 35 | mask_prob: Probability of each individual token being chosen as start 36 | of a masked sequence. Overall number of masks num_masks is 37 | mask_prob * sum(seq_lens) / mask_width. 38 | mask_width: width of each mask 39 | num_neg_samples: Number of samples from unmasked sequence parts (gets 40 | multiplied by num_mask) 41 | seq_lens: sequence lengths 42 | 43 | Returns: 44 | An index array containing 1 for masked bits, 2 for negative samples, 45 | the number of masks, and the actual number of negative samples. 
46 | """ 47 | mask_samples = np.zeros(sum(seq_lens)) 48 | num_masks = int(mask_prob * sum(seq_lens.numpy()) // mask_width) 49 | num_neg_samples = num_masks * num_neg_samples 50 | num_masks += num_neg_samples 51 | 52 | indices = [x+mask_width for x in positive_integers_with_sum(num_masks, sum(seq_lens)-num_masks*mask_width)] 53 | start = 0 54 | mask_slices = [] 55 | for i in indices: 56 | i_start = random.randint(start, i+start-mask_width) 57 | mask_slices.append(slice(i_start, i_start+mask_width)) 58 | start += i 59 | 60 | neg_idx = random.sample(range(len(mask_slices)), num_neg_samples) 61 | neg_slices = [mask_slices.pop(idx) for idx in sorted(neg_idx, reverse=True)] 62 | 63 | mask_samples[np.r_[tuple(mask_slices)]] = 1 64 | mask_samples[np.r_[tuple(neg_slices)]] = 2 65 | 66 | return mask_samples, num_masks - num_neg_samples, num_neg_samples 67 | 68 | 69 | def buffered_arange(max): 70 | if not hasattr(buffered_arange, "buf"): 71 | buffered_arange.buf = torch.LongTensor() 72 | if max > buffered_arange.buf.numel(): 73 | buffered_arange.buf.resize_(max) 74 | torch.arange(max, out=buffered_arange.buf) 75 | return buffered_arange.buf[:max] 76 | 77 | 78 | def sample_negatives(y, num_samples, num_neg_samples: int): 79 | B, W, C = y.shape 80 | y = y.view(-1, C) # BTC => (BxT)C 81 | 82 | with torch.no_grad(): 83 | tszs = (buffered_arange(num_samples).unsqueeze(-1).expand(-1, num_neg_samples).flatten()) 84 | 85 | neg_idxs = torch.randint(low=0, high=W - 1, size=(B, num_neg_samples * num_samples)) 86 | neg_idxs[neg_idxs >= tszs] += 1 87 | 88 | for i in range(1, B): 89 | neg_idxs[i] += i * W 90 | 91 | negs = y[neg_idxs.view(-1)] 92 | negs = negs.view(B, num_samples, num_neg_samples, C).permute(2, 0, 1, 3) # to NxBxTxC 93 | return negs 94 | 95 | 96 | def compute_mask_indices(shape: Tuple[int, int], mask_prob: float, mask_length: int = 4, mask_min_space: int = 2) -> np.ndarray: 97 | """ 98 | Computes random mask spans for a given shape 99 | 100 | Args: 101 | shape: the the shape for which to compute masks. 102 | should be of size 2 where first element is batch size and 2nd is timesteps 103 | mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by 104 | number of timesteps divided by length of mask span to mask approximately this percentage of all elements. 105 | however due to overlaps, the actual number will be smaller. 
106 | """ 107 | 108 | bsz, all_sz = shape 109 | mask = np.full((bsz, all_sz), False) 110 | 111 | all_num_mask = int(mask_prob * all_sz / float(mask_length) + np.random.rand()) 112 | 113 | mask_idcs = [] 114 | for i in range(bsz): 115 | # import ipdb; ipdb.set_trace() 116 | sz = all_sz 117 | num_mask = all_num_mask 118 | 119 | lengths = np.full(num_mask, mask_length) 120 | 121 | if sum(lengths) == 0: 122 | lengths[0] = min(mask_length, sz - 1) 123 | 124 | mask_idc = [] 125 | 126 | def arrange(s, e, length, keep_length): 127 | span_start = np.random.randint(s, e - length) 128 | mask_idc.extend(span_start + i for i in range(length)) 129 | 130 | new_parts = [] 131 | if span_start - s - mask_min_space >= keep_length: 132 | new_parts.append((s, span_start - mask_min_space + 1)) 133 | if e - span_start - keep_length - mask_min_space > keep_length: 134 | new_parts.append((span_start + length + mask_min_space, e)) 135 | return new_parts 136 | 137 | parts = [(0, sz)] 138 | min_length = min(lengths) 139 | for length in sorted(lengths, reverse=True): 140 | lens = np.fromiter( 141 | (e - s if e - s >= length + mask_min_space else 0 for s, e in parts), 142 | int, 143 | ) 144 | l_sum = np.sum(lens) 145 | if l_sum == 0: 146 | break 147 | probs = lens / np.sum(lens) 148 | c = np.random.choice(len(parts), p=probs) 149 | s, e = parts.pop(c) 150 | parts.extend(arrange(s, e, length, min_length)) 151 | mask_idc = np.asarray(mask_idc) 152 | 153 | mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) 154 | 155 | # make sure all masks are the same length in the batch by removing masks 156 | # if they are greater than the min length mask 157 | min_len = min([len(m) for m in mask_idcs]) 158 | 159 | for i, mask_idc in enumerate(mask_idcs): 160 | if len(mask_idc) > min_len: 161 | mask_idc = np.random.choice(mask_idc, min_len, replace=False) 162 | assert len(mask_idc) == min_len 163 | mask[i, mask_idc] = True 164 | 165 | return mask 166 | -------------------------------------------------------------------------------- /kraken/lib/progress.py: -------------------------------------------------------------------------------- 1 | # Copyright Benjamin Kiessling 2 | # Copyright The PyTorch Lightning team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | Handlers for rich-based progress bars. 
17 | """ 18 | from dataclasses import dataclass 19 | from typing import TYPE_CHECKING, Union 20 | 21 | from lightning.pytorch.callbacks.progress.rich_progress import ( 22 | CustomProgress, MetricsTextColumn, RichProgressBar) 23 | from rich import get_console, reconfigure 24 | from rich.default_styles import DEFAULT_STYLES 25 | from rich.progress import (BarColumn, DownloadColumn, Progress, ProgressColumn, 26 | TextColumn, TimeElapsedColumn, TimeRemainingColumn) 27 | from rich.text import Text 28 | 29 | if TYPE_CHECKING: 30 | from rich.console import RenderableType 31 | from rich.style import Style 32 | 33 | __all__ = ['KrakenProgressBar', 'KrakenDownloadProgressBar', 'KrakenTrainProgressBar'] 34 | 35 | 36 | class BatchesProcessedColumn(ProgressColumn): 37 | def __init__(self): 38 | super().__init__() 39 | 40 | def render(self, task) -> 'RenderableType': 41 | total = task.total if task.total != float("inf") else "--" 42 | return Text(f"{int(task.completed)}/{total}", style='magenta') 43 | 44 | 45 | class EarlyStoppingColumn(ProgressColumn): 46 | """ 47 | A column containing text. 48 | """ 49 | 50 | def __init__(self, trainer): 51 | self._trainer = trainer 52 | super().__init__() 53 | 54 | def render(self, task) -> Text: 55 | 56 | text = f'early_stopping: ' \ 57 | f'{self._trainer.early_stopping_callback.wait_count}/{self._trainer.early_stopping_callback.patience} ' \ 58 | f'{self._trainer.early_stopping_callback.best_score:.5f}' 59 | return Text(text, justify="left") 60 | 61 | 62 | class KrakenProgressBar(Progress): 63 | """ 64 | Adaptation of the default rich progress bar to fit with kraken/ketos output. 65 | """ 66 | def __init__(self, *args, **kwargs): 67 | columns = [TextColumn("[progress.description]{task.description}"), 68 | BarColumn(), 69 | TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), 70 | BatchesProcessedColumn(), 71 | TimeRemainingColumn(), 72 | TimeElapsedColumn()] 73 | kwargs['refresh_per_second'] = 1 74 | super().__init__(*columns, *args, **kwargs) 75 | 76 | 77 | class KrakenDownloadProgressBar(Progress): 78 | """ 79 | Adaptation of the default rich progress bar to fit with kraken/ketos download output. 80 | """ 81 | def __init__(self, *args, **kwargs): 82 | columns = [TextColumn("[progress.description]{task.description}"), 83 | BarColumn(), 84 | TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), 85 | DownloadColumn(), 86 | TimeRemainingColumn(), 87 | TimeElapsedColumn()] 88 | kwargs['refresh_per_second'] = 1 89 | super().__init__(*columns, *args, **kwargs) 90 | 91 | 92 | class KrakenTrainProgressBar(RichProgressBar): 93 | """ 94 | Adaptation of the default ptl rich progress bar to fit with kraken (segtrain, train) output. 95 | 96 | Args: 97 | refresh_rate: Determines at which rate (in number of batches) the progress bars get updated. 98 | Set it to ``0`` to disable the display. 99 | leave: Leaves the finished progress bar in the terminal at the end of the epoch. 
Default: False 100 | console_kwargs: Args for constructing a `Console` 101 | """ 102 | def __init__(self, 103 | *args, 104 | **kwargs): 105 | super().__init__(*args, **kwargs, theme=RichProgressBarTheme()) 106 | 107 | def _init_progress(self, trainer): 108 | if self.is_enabled and (self.progress is None or self._progress_stopped): 109 | self._reset_progress_bar_ids() 110 | reconfigure(**self._console_kwargs) 111 | self._console = get_console() 112 | self._console.clear_live() 113 | self._metric_component = MetricsTextColumn(trainer, 114 | self.theme.metrics, 115 | self.theme.metrics_text_delimiter, 116 | self.theme.metrics_format) 117 | columns = self.configure_columns(trainer) 118 | columns.append(self._metric_component) 119 | 120 | if trainer.early_stopping_callback: 121 | self._early_stopping_component = EarlyStoppingColumn(trainer) 122 | columns.append(self._early_stopping_component) 123 | 124 | self.progress = CustomProgress( 125 | *columns, 126 | auto_refresh=False, 127 | disable=self.is_disabled, 128 | console=self._console, 129 | ) 130 | self.progress.start() 131 | # progress has started 132 | self._progress_stopped = False 133 | 134 | def _get_train_description(self, current_epoch: int) -> str: 135 | return f"stage {current_epoch}/" \ 136 | f"{self.trainer.max_epochs-1 if self.trainer.max_epochs != -1 else '∞'}" 137 | 138 | 139 | @dataclass 140 | class RichProgressBarTheme: 141 | """Styles to associate to different base components. 142 | 143 | Args: 144 | description: Style for the progress bar description, e.g. Epoch x, Testing, etc. 145 | progress_bar: Style for the bar in progress. 146 | progress_bar_finished: Style for the finished progress bar. 147 | progress_bar_pulse: Style for the progress bar when `IterableDataset` is being processed. 148 | batch_progress: Style for the progress tracker (i.e. 10/50 batches completed). 149 | time: Style for the processed time and estimated time remaining. 150 | processing_speed: Style for the speed of the batches being processed. 151 | metrics: Style for the metrics. 152 | 153 | https://rich.readthedocs.io/en/stable/style.html 154 | """ 155 | 156 | description: Union[str, 'Style'] = DEFAULT_STYLES['progress.description'] 157 | progress_bar: Union[str, 'Style'] = DEFAULT_STYLES['bar.complete'] 158 | progress_bar_finished: Union[str, 'Style'] = DEFAULT_STYLES['bar.finished'] 159 | progress_bar_pulse: Union[str, 'Style'] = DEFAULT_STYLES['bar.pulse'] 160 | batch_progress: Union[str, 'Style'] = DEFAULT_STYLES['progress.description'] 161 | time: Union[str, 'Style'] = DEFAULT_STYLES['progress.elapsed'] 162 | processing_speed: Union[str, 'Style'] = DEFAULT_STYLES['progress.data.speed'] 163 | metrics: Union[str, 'Style'] = DEFAULT_STYLES['progress.description'] 164 | metrics_text_delimiter: str = ' ' 165 | metrics_format: str = '.3f' 166 | -------------------------------------------------------------------------------- /kraken/lib/register.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2025 Benjamin Kiessling 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 13 | # or implied. See the License for the specific language governing 14 | # permissions and limitations under the License. 15 | """ 16 | Register for hyperparameter values 17 | """ 18 | 19 | OPTIMIZERS = ['Adam', 'AdamW', 'SGD', 'RMSprop'] 20 | SCHEDULERS = ['cosine', 'constant', 'exponential', 'step', '1cycle', 'reduceonplateau'] 21 | STOPPERS = ['early', 'fixed'] 22 | PRECISIONS = ['transformer-engine', 'transformer-engine-float16', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', '32-true', '64-true'] 23 | -------------------------------------------------------------------------------- /kraken/lib/ro/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2023 Benjamin Kiessling 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 13 | # or implied. See the License for the specific language governing 14 | # permissions and limitations under the License. 15 | """ 16 | Tools for trainable reading order. 17 | """ 18 | 19 | from .model import ROModel # NOQA 20 | -------------------------------------------------------------------------------- /kraken/lib/ro/layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Layers for VGSL models 3 | """ 4 | from typing import TYPE_CHECKING, Tuple 5 | 6 | import torch 7 | from torch import nn 8 | 9 | if TYPE_CHECKING: 10 | from kraken.lib.vgsl import VGSLBlock 11 | 12 | # all tensors are ordered NCHW, the "feature" dimension is C, so the output of 13 | # an LSTM will be put into C same as the filters of a CNN. 14 | 15 | __all__ = ['MLP'] 16 | 17 | 18 | class MLP(nn.Module): 19 | """ 20 | A simple 2 layer MLP for reading order determination. 21 | """ 22 | def __init__(self, feature_size: int, hidden_size: int): 23 | super(MLP, self).__init__() 24 | self.fc1 = nn.Linear(feature_size, hidden_size) 25 | self.relu = nn.ReLU() 26 | self.fc2 = nn.Linear(hidden_size, 1) 27 | self.feature_size = feature_size 28 | self.hidden_size = hidden_size 29 | self.class_mapping = None 30 | 31 | def forward(self, x): 32 | x = self.fc1(x) 33 | x = self.relu(x) 34 | return self.fc2(x) 35 | 36 | def get_shape(self, input: Tuple[int, int, int, int]) -> Tuple[int, int, int, int]: 37 | """ 38 | Calculates the output shape from input 4D tuple NCHW. 39 | """ 40 | self.output_shape = input 41 | return input 42 | 43 | def get_spec(self, name) -> 'VGSLBlock': 44 | """ 45 | Generates a VGSL spec block from the layer instance. 46 | """ 47 | return f'[1,0,0,1 RO{{{name}}}{self.feature_size},{self.hidden_size}]' 48 | 49 | def deserialize(self, name: str, spec) -> None: 50 | """ 51 | Sets the weights of an initialized module from a CoreML protobuf spec. 
52 | """ 53 | # extract 1st linear projection parameters 54 | lin = [x for x in spec.neuralNetwork.layers if x.name == '{}_mlp_lin_0'.format(name)][0].innerProduct 55 | weights = torch.Tensor(lin.weights.floatValue).resize_as_(self.fc1.weight.data) 56 | bias = torch.Tensor(lin.bias.floatValue) 57 | self.fc1.weight = torch.nn.Parameter(weights) 58 | self.fc1.bias = torch.nn.Parameter(bias) 59 | # extract 2nd linear projection parameters 60 | lin = [x for x in spec.neuralNetwork.layers if x.name == '{}_mlp_lin_1'.format(name)][0].innerProduct 61 | weights = torch.Tensor(lin.weights.floatValue).resize_as_(self.fc2.weight.data) 62 | bias = torch.Tensor(lin.bias.floatValue) 63 | self.fc2.weight = torch.nn.Parameter(weights) 64 | self.fc2.bias = torch.nn.Parameter(bias) 65 | 66 | def serialize(self, name: str, input: str, builder): 67 | """ 68 | Serializes the module using a NeuralNetworkBuilder. 69 | """ 70 | builder.add_inner_product(f'{name}_mlp_lin_0', self.fc1.weight.data.numpy(), 71 | self.fc1.bias.data.numpy(), 72 | self.feature_size, self.hidden_size, 73 | has_bias=True, input_name=input, output_name=f'{name}_mlp_lin_0') 74 | builder.add_activation(f'{name}_mlp_lin_0_relu', 'RELU', f'{name}_mlp_lin_0', f'{name}_mlp_lin_0_relu') 75 | builder.add_inner_product(f'{name}_mlp_lin_1', self.fc2.weight.data.numpy(), 76 | self.fc2.bias.data.numpy(), 77 | self.hidden_size, 1, 78 | has_bias=True, input_name=f'{name}_mlp_lin_0_relu', output_name=f'{name}_mlp_lin_1') 79 | return name 80 | -------------------------------------------------------------------------------- /kraken/lib/ro/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import random 7 | from typing import Sequence, Union 8 | 9 | import numpy as np 10 | import torch 11 | 12 | 13 | def positive_integers_with_sum(n, total): 14 | ls = [0] 15 | rv = [] 16 | while len(ls) < n: 17 | c = random.randint(0, total) 18 | ls.append(c) 19 | ls = sorted(ls) 20 | ls.append(total) 21 | for i in range(1, len(ls)): 22 | rv.append(ls[i] - ls[i-1]) 23 | return rv 24 | 25 | 26 | def compute_masks(mask_prob: int, 27 | mask_width: int, 28 | num_neg_samples: int, 29 | seq_lens: Union[torch.Tensor, Sequence[int]]): 30 | """ 31 | Samples num_mask non-overlapping random masks of length mask_width in 32 | sequence of length seq_len. 33 | 34 | Args: 35 | mask_prob: Probability of each individual token being chosen as start 36 | of a masked sequence. Overall number of masks num_masks is 37 | mask_prob * sum(seq_lens) / mask_width. 38 | mask_width: width of each mask 39 | num_neg_samples: Number of samples from unmasked sequence parts (gets 40 | multiplied by num_mask) 41 | seq_lens: sequence lengths 42 | 43 | Returns: 44 | An index array containing 1 for masked bits, 2 for negative samples, 45 | the number of masks, and the actual number of negative samples. 
46 | """ 47 | mask_samples = np.zeros(sum(seq_lens)) 48 | num_masks = int(mask_prob * sum(seq_lens.numpy()) // mask_width) 49 | num_neg_samples = num_masks * num_neg_samples 50 | num_masks += num_neg_samples 51 | 52 | indices = [x+mask_width for x in positive_integers_with_sum(num_masks, sum(seq_lens)-num_masks*mask_width)] 53 | start = 0 54 | mask_slices = [] 55 | for i in indices: 56 | i_start = random.randint(start, i+start-mask_width) 57 | mask_slices.append(slice(i_start, i_start+mask_width)) 58 | start += i 59 | 60 | neg_idx = random.sample(range(len(mask_slices)), num_neg_samples) 61 | neg_slices = [mask_slices.pop(idx) for idx in sorted(neg_idx, reverse=True)] 62 | 63 | mask_samples[np.r_[tuple(mask_slices)]] = 1 64 | mask_samples[np.r_[tuple(neg_slices)]] = 2 65 | 66 | return mask_samples, num_masks - num_neg_samples, num_neg_samples 67 | -------------------------------------------------------------------------------- /kraken/lib/sl.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def dim0(s): 5 | """Dimension of the slice list for dimension 0.""" 6 | return s[0].stop-s[0].start 7 | 8 | 9 | def dim1(s): 10 | """Dimension of the slice list for dimension 1.""" 11 | return s[1].stop-s[1].start 12 | 13 | 14 | def area(a): 15 | """Return the area of the slice list (ignores anything past a[:2].""" 16 | return np.prod([max(x.stop-x.start, 0) for x in a[:2]]) 17 | 18 | 19 | def width(s): 20 | return s[1].stop-s[1].start 21 | 22 | 23 | def height(s): 24 | return s[0].stop-s[0].start 25 | 26 | 27 | def aspect(a): 28 | return height(a)*1.0/width(a) 29 | 30 | 31 | def xcenter(s): 32 | return np.mean([s[1].stop, s[1].start]) 33 | 34 | 35 | def ycenter(s): 36 | return np.mean([s[0].stop, s[0].start]) 37 | 38 | 39 | def center(s): 40 | return (ycenter(s), xcenter(s)) 41 | -------------------------------------------------------------------------------- /kraken/lib/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ocropus's magic PIL-numpy array conversion routines. They express slightly 3 | different behavior from PIL.Image.toarray(). 4 | """ 5 | import unicodedata 6 | import uuid 7 | from typing import TYPE_CHECKING, Callable, Literal, Optional, Union 8 | 9 | import numpy as np 10 | import torch 11 | from PIL import Image 12 | 13 | from kraken.containers import BBoxLine 14 | from kraken.lib import functional_im_transforms as F_t 15 | from kraken.lib.exceptions import KrakenInputException 16 | 17 | if TYPE_CHECKING: 18 | from os import PathLike 19 | 20 | __all__ = ['pil2array', 'array2pil', 'is_bitonal', 'make_printable', 'get_im_str', 'parse_gt_path'] 21 | 22 | 23 | def pil2array(im: Image.Image, alpha: int = 0) -> np.ndarray: 24 | if im.mode == '1': 25 | return np.array(im.convert('L')) 26 | return np.array(im) 27 | 28 | 29 | def array2pil(a: np.ndarray) -> Image.Image: 30 | if a.dtype == np.dtype("B"): 31 | if a.ndim == 2: 32 | return Image.frombytes("L", (a.shape[1], a.shape[0]), 33 | a.tobytes()) 34 | elif a.ndim == 3: 35 | return Image.frombytes("RGB", (a.shape[1], a.shape[0]), 36 | a.tobytes()) 37 | else: 38 | raise Exception("bad image rank") 39 | elif a.dtype == np.dtype('float32'): 40 | return Image.frombytes("F", (a.shape[1], a.shape[0]), a.tobytes()) 41 | else: 42 | raise Exception("unknown image type") 43 | 44 | 45 | def is_bitonal(im: Union[Image.Image, torch.Tensor]) -> bool: 46 | """ 47 | Tests a PIL image or torch tensor for bitonality. 
48 | 49 | Args: 50 | im: Image to test 51 | 52 | Returns: 53 | True if the image contains only two different color values. False 54 | otherwise. 55 | """ 56 | if isinstance(im, Image.Image): 57 | return im.getcolors(2) is not None and len(im.getcolors(2)) == 2 58 | elif isinstance(im, torch.Tensor): 59 | return len(im.unique()) == 2 60 | 61 | 62 | def get_im_str(im: Image.Image) -> str: 63 | return im.filename if hasattr(im, 'filename') else str(im) 64 | 65 | 66 | def is_printable(char: str) -> bool: 67 | """ 68 | Determines if a code point is printable/visible when printed. 69 | 70 | Args: 71 | char (str): Input code point. 72 | 73 | Returns: 74 | True if printable, False otherwise. 75 | """ 76 | letters = ('LC', 'Ll', 'Lm', 'Lo', 'Lt', 'Lu') 77 | numbers = ('Nd', 'Nl', 'No') 78 | punctuation = ('Pc', 'Pd', 'Pe', 'Pf', 'Pi', 'Po', 'Ps') 79 | symbol = ('Sc', 'Sk', 'Sm', 'So') 80 | printable = letters + numbers + punctuation + symbol 81 | 82 | return unicodedata.category(char) in printable 83 | 84 | 85 | def make_printable(char: str) -> str: 86 | """ 87 | Takes a Unicode code point and returns a printable representation of it. 88 | 89 | Args: 90 | char (str): Input code point 91 | 92 | Returns: 93 | Either the original code point, the name of the code point if it is a 94 | combining mark, whitespace etc., or the hex code if it is a control 95 | symbol. 96 | """ 97 | if not char or is_printable(char): 98 | return char 99 | elif unicodedata.category(char) in ('Cc', 'Cs', 'Co'): 100 | return '0x{:x}'.format(ord(char)) 101 | else: 102 | try: 103 | return unicodedata.name(char) 104 | except ValueError: 105 | return '0x{:x}'.format(ord(char)) 106 | 107 | 108 | def parse_gt_path(path: Union[str, 'PathLike'], 109 | suffix: str = '.gt.txt', 110 | split: Callable[[Union['PathLike', str]], str] = F_t.default_split, 111 | skip_empty_lines: bool = True, 112 | base_dir: Optional[Literal['L', 'R']] = None, 113 | text_direction: Literal['horizontal-lr', 'horizontal-rl', 'vertical-lr', 'vertical-rl'] = 'horizontal-lr') -> BBoxLine: 114 | """ 115 | Returns a BBoxLine from an image/text file pair. 116 | 117 | Args: 118 | path: Path to image file 119 | suffix: Suffix of the corresponding ground truth text file to image 120 | file in `path`. 121 | split: Suffix stripping function. 122 | skip_empty_lines: Whether to raise an exception if ground truth is 123 | empty or text file is missing. 124 | base_dir: Unicode BiDi algorithm base direction 125 | text_direction: Orientation of the line box.
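    Example:
        A sketch for a hypothetical line image ``0001.png`` with a sibling
        transcription ``0001.gt.txt``:

        >>> line = parse_gt_path('0001.png')
        >>> line.text    # contents of 0001.gt.txt
        >>> line.bbox    # ((0, 0), (w, 0), (w, h), (0, h)) from the image size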
126 | """ 127 | try: 128 | with Image.open(path) as im: 129 | w, h = im.size 130 | except Exception as e: 131 | raise KrakenInputException(e) 132 | 133 | gt = '' 134 | try: 135 | with open(F_t.suffix_split(path, split=split, suffix=suffix), 'r', encoding='utf-8') as fp: 136 | gt = fp.read().strip('\n\r') 137 | except FileNotFoundError: 138 | if not skip_empty_lines: 139 | raise KrakenInputException(f'No text file found for ground truth line {path}.') 140 | 141 | if not gt and skip_empty_lines: 142 | raise KrakenInputException(f'No text for ground truth line {path}.') 143 | 144 | return BBoxLine(id=f'_{uuid.uuid4()}', 145 | bbox=((0, 0), (w, 0), (w, h), (0, h)), 146 | text=gt, 147 | base_dir=base_dir, 148 | imagename=path, 149 | text_direction=text_direction) 150 | -------------------------------------------------------------------------------- /kraken/repo.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2025 Benjamin Kiessling 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 13 | # or implied. See the License for the specific language governing 14 | # permissions and limitations under the License. 15 | """ 16 | kraken.repo 17 | ~~~~~~~~~~~ 18 | 19 | Wrappers around the htrmopo reference implementation implementing 20 | kraken-specific filtering for repository querying operations. 21 | """ 22 | from collections import defaultdict 23 | from collections.abc import Callable 24 | from typing import Any, Dict, Optional, TypeVar, Literal 25 | 26 | 27 | from htrmopo import get_description as mopo_get_description 28 | from htrmopo import get_listing as mopo_get_listing 29 | from htrmopo.record import v0RepositoryRecord, v1RepositoryRecord 30 | 31 | 32 | _v0_or_v1_Record = TypeVar('_v0_or_v1_Record', v0RepositoryRecord, v1RepositoryRecord) 33 | 34 | 35 | def get_description(model_id: str, 36 | callback: Callable[..., Any] = lambda: None, 37 | version: Optional[Literal['v0', 'v1']] = None, 38 | filter_fn: Optional[Callable[[_v0_or_v1_Record], bool]] = lambda x: True) -> _v0_or_v1_Record: 39 | """ 40 | Filters the output of htrmopo.get_description with a custom function. 41 | 42 | Args: 43 | model_id: model DOI 44 | callback: Progress callback 45 | version: 46 | filter_fn: Function called to filter the retrieved record. 47 | """ 48 | desc = mopo_get_description(model_id, callback, version) 49 | if not filter_fn(desc): 50 | raise ValueError(f'Record {model_id} exists but is not a valid kraken record') 51 | return desc 52 | 53 | 54 | def get_listing(callback: Callable[[int, int], Any] = lambda total, advance: None, 55 | from_date: Optional[str] = None, 56 | filter_fn: Optional[Callable[[_v0_or_v1_Record], bool]] = lambda x: True) -> Dict[str, Dict[str, _v0_or_v1_Record]]: 57 | """ 58 | Returns a filtered representation of the model repository grouped by 59 | concept DOI. 60 | 61 | Args: 62 | callback: Progress callback 63 | from_data: 64 | filter_fn: Function called for each record object 65 | 66 | Returns: 67 | A dictionary mapping group DOIs to one record object per deposit. 
The 68 | record of the highest available schema version is retained. 69 | """ 70 | kwargs = {} 71 | if from_date is not None: 72 | kwargs['from'] = from_date 73 | repository = mopo_get_listing(callback, **kwargs) 74 | # aggregate models under their concept DOI 75 | concepts = defaultdict(list) 76 | for item in repository.values(): 77 | # filter records here 78 | item = {k: v for k, v in item.items() if filter_fn(v)} 79 | # both got the same DOI information 80 | record = item.get('v1', item.get('v0', None)) 81 | if record is not None: 82 | concepts[record.concept_doi].append(record) 83 | 84 | for k, v in concepts.items(): 85 | concepts[k] = sorted(v, key=lambda x: x.publication_date, reverse=True) 86 | 87 | return concepts 88 | -------------------------------------------------------------------------------- /kraken/templates/abbyyxml: -------------------------------------------------------------------------------- 1 | {%+ macro render_line(page, line) +%} 2 | 3 | {% for segment in line.recognition %} 4 | {% for char in segment.recognition %} 5 | {% if loop.first %} 6 | {{ char.text }} 7 | {% else %} 8 | {{ char.text }} 9 | {% endif %} 10 | {% endfor %} 11 | {% endfor %} 12 | 13 | 14 | {%+ endmacro %} 15 | 16 | 17 | 18 | {% for entity in page.entities %} 19 | {% if entity.type == "region" %} 20 | 21 | 22 | 23 | {%- for line in entity.lines -%} 24 | {{ render_line(page, line) }} 25 | {%- endfor -%} 26 | 27 | 28 | 29 | {% else %} 30 | 31 | 32 | 33 | {{ render_line(page, entity) }} 34 | 35 | 36 | 37 | {% endif %} 38 | {% endfor %} 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /kraken/templates/hocr: -------------------------------------------------------------------------------- 1 | {% macro render_line(line) -%} 2 | 3 | {% for segment in line.recognition %} 4 | {{ segment.text }} 5 | {% endfor -%} 6 | 7 |
8 | {%- endmacro -%} 9 | 10 | 11 | 12 | 13 | 14 | 15 | {% if page.scripts %} 16 | 17 | {% endif %} 18 | 19 | 20 |
21 | {% for entity in page.entities -%} 22 | {% if entity.type == "region" -%} 23 |
24 | {% for line in entity.lines -%} 25 | {{ render_line(line) }} 26 | {% endfor %} 27 |
28 | {% else -%} 29 | {{ render_line(entity) }} 30 | {% endif -%} 31 | {% endfor -%} 32 |
33 | 34 | 35 | -------------------------------------------------------------------------------- /kraken/templates/layout.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 11 | 14 | 15 | 16 |
17 | 25 | 26 |
27 | {% for page in pages %} 28 |
29 |
30 | photo 31 | {% for line in page.lines %} 32 | 33 | {% endfor %} 34 |
35 |
    36 | {% for line in page.lines %} 37 |
  • 38 | {% if line.text %} 39 | {{ line.text }} 40 | {% endif %} 41 |
  • 42 | {% endfor %} 43 |
44 |
45 | {% endfor %} 46 |
47 |
48 | 49 | 50 | -------------------------------------------------------------------------------- /kraken/templates/main.js: -------------------------------------------------------------------------------- 1 | const $ = (s) => document.querySelector(s) 2 | const $$ = (s) => document.querySelectorAll(s) 3 | const activeClass = 'active' 4 | const hoverClass = 'hovered' 5 | 6 | function activate(...els) { els.forEach(el => el.classList.add(activeClass)) } 7 | function deactivate(...els) { els.forEach(el => el.classList.remove(activeClass)) } 8 | function hoverate(...els) { els.forEach(el => el.classList.add(hoverClass)) } 9 | function dehoverate(...els) { els.forEach(el => el.classList.remove(hoverClass)) } 10 | 11 | document.addEventListener('DOMContentLoaded', function() { 12 | const uuid = $('meta[name=uuid]').getAttribute('content') 13 | const inputFields = $$('li[contenteditable=true]') 14 | 15 | const getLocalStorageId = (lineId) => `${uuid}__${lineId}` 16 | 17 | if (localStorage != null) { 18 | inputFields.forEach(function(li) { 19 | li.textContent = localStorage.getItem(getLocalStorageId(li.id)) || '' 20 | }); 21 | } 22 | 23 | // focus text fields when lines/words are clicked + mouseover 24 | $$('a.rect').forEach(function(a) { 25 | const field = document.getElementById(a.getAttribute('alt')) 26 | 27 | a.addEventListener('click', function(e) { 28 | e.preventDefault(); 29 | activate(a, field) 30 | field.focus() 31 | }) 32 | 33 | a.addEventListener('mouseover', () => hoverate(a, field)) 34 | a.addEventListener('mouseout', () => dehoverate(a, field)) 35 | }) 36 | 37 | // create mouseover effect on text fields 38 | inputFields.forEach(function(field) { 39 | const a = $(`a.rect[alt="${field.id}"]`) 40 | 41 | field.addEventListener('mouseover', () => hoverate(a, field)) 42 | field.addEventListener('mouseout', () => dehoverate(a, field)) 43 | field.addEventListener('focus', () => activate(a, field)) 44 | field.addEventListener('blur', () => deactivate(a, field)) 45 | 46 | field.addEventListener('keydown', function(e) { 47 | if (e.which != 13) return 48 | e.preventDefault() 49 | 50 | field.classList.add('corrected') 51 | field.nextElementSibling.focus() 52 | }) 53 | 54 | field.addEventListener('keyup', function () { 55 | localStorage.setItem(getLocalStorageId(field.id), field.textContent) 56 | }) 57 | }) 58 | 59 | 60 | // serializing the DOM to a file 61 | const button = $('button.download > a') 62 | button.addEventListener('click', function(e) { 63 | const path = window.location.pathname 64 | button.setAttribute('href', 'data:text/html,' + encodeURIComponent(document.documentElement.outerHTML)) 65 | button.setAttribute('download', path.substr(path.lastIndexOf('/') + 1)) 66 | }) 67 | }) 68 | -------------------------------------------------------------------------------- /kraken/templates/pagexml: -------------------------------------------------------------------------------- 1 | {%+ macro render_line(line) +%} 2 | 3 | {% if line.boundary %} 4 | 5 | {% endif %} 6 | {% if line.baseline %} 7 | 8 | {% endif %} 9 | {% if line.text is string %} 10 | {{ line.text|e }} 11 | {% else %} 12 | {% for segment in line.recognition %} 13 | 14 | {% if segment.boundary %} 15 | 16 | {% else %} 17 | 18 | {% endif %} 19 | {% for char in segment.recognition %} 20 | 21 | 22 | {{ char.text|e }} 23 | 24 | {% endfor %} 25 | {{ segment.text|e }} 26 | 27 | {% endfor %} 28 | {%+ if line.confidences|length %}{% for segment in line.recognition %}{{ segment.text|e }}{% endfor %}{% endif +%} 29 | {% endif %} 30 | 31 | {%+ 
endmacro %} 32 | 33 | 34 | 35 | kraken {{ metadata.version }} 36 | {{ page.date }} 37 | {{ page.date }} 38 | 39 | 40 | {% for entity in page.entities %} 41 | {% if entity.type == "region" %} 42 | {% if loop.previtem and loop.previtem.type == 'line' %} 43 | 44 | {% endif %} 45 | 46 | {% if entity.boundary %}{% endif %} 47 | {%- for line in entity.lines -%} 48 | {{ render_line(line) }} 49 | {%- endfor %} 50 | 51 | {% else %} 52 | {% if not loop.previtem or loop.previtem.type != 'line' %} 53 | 54 | 55 | {% endif %} 56 | {{ render_line(entity) }} 57 | {% if loop.last %} 58 | 59 | {% endif %} 60 | {% endif %} 61 | {% endfor %} 62 | 63 | 64 | -------------------------------------------------------------------------------- /kraken/templates/report: -------------------------------------------------------------------------------- 1 | === report {{ report.name }} === 2 | 3 | {{ report.chars }} Characters 4 | {{ report.errors }} Errors 5 | {{ '%0.2f'| format(report.character_accuracy) }}% Character Accuracy 6 | {{ '%0.2f'| format(report.character_CI_accucary) }}% Character Accuracy (Case-insensitive) 7 | {{ '%0.2f'| format(report.word_accuracy) }}% Word Accuracy 8 | 9 | {{ report.insertions }} Insertions 10 | {{ report.deletions }} Deletions 11 | {{ report.substitutions }} Substitutions 12 | 13 | Count Missed %Right 14 | {% for script in report.scripts %} 15 | {{ script.count }} {{ script.errors }} {{'%0.2f'| format(script.accuracy) }}% {{ script.script }} 16 | {% endfor %} 17 | 18 | Errors Correct-Generated 19 | {% for count in report.counts %} 20 | {{ count.errors }} {{ '{ ' }}{{ count.correct }}{{ ' }' }} - {{ '{ ' }}{{ count.generated }}{{ ' }' }} 21 | {% endfor %} 22 | -------------------------------------------------------------------------------- /kraken/templates/style.css: -------------------------------------------------------------------------------- 1 | body { 2 | background: #f3f3f3; 3 | height: 100vh; 4 | margin: 0; 5 | } 6 | 7 | div, ul, li { 8 | box-sizing: border-box; 9 | } 10 | 11 | #pages { 12 | height: 100vh; 13 | margin-left: 15%; 14 | overflow: hidden; 15 | } 16 | 17 | section.page { 18 | display: grid; 19 | grid-template-columns: 1fr 1fr; 20 | } 21 | 22 | /* Left column: facsimile */ 23 | .facsimile { 24 | align-self: center; 25 | justify-self: center; 26 | position: relative; 27 | margin: 1em; 28 | } 29 | 30 | .facsimile > a { 31 | position: absolute; 32 | z-index: 2; 33 | } 34 | 35 | .facsimile > img { 36 | width: 100%; 37 | } 38 | 39 | /* Line highlights on facsimile */ 40 | a.rect:hover, a.hovered { 41 | border: 2px solid rgba(255, 0, 0, .33); 42 | } 43 | 44 | a.active, li[contenteditable=true].active { 45 | border: 2px solid red; 46 | } 47 | 48 | /* Right column: list of input fields */ 49 | section.page > ul { 50 | counter-reset: mycounter; 51 | height: 100vh; 52 | list-style-type: none; 53 | margin: 0; 54 | overflow: auto; 55 | padding: 1em 1em 10vh 3em; 56 | position: relative; 57 | } 58 | 59 | li[contenteditable=true] { 60 | border: 2px dashed #CCC; 61 | height: 1.8em; 62 | margin: 0 0 1em 0; 63 | outline: none; 64 | padding: .2em; 65 | width: 100%; 66 | } 67 | 68 | li[contenteditable=true]:before { 69 | color: #BBB; 70 | content: counter(mycounter); 71 | counter-increment: mycounter; 72 | font-family: monospace; 73 | font-size: 1.2em; 74 | left: 0; 75 | position: absolute; 76 | text-align: right; 77 | width: 28px; 78 | } 79 | 80 | li[contenteditable=true].active:before { 81 | color: #222; 82 | } 83 | 84 | li[contenteditable=true]:hover, 85 | 
li[contenteditable=true].hovered { 86 | border: 2px solid rgba(255, 0, 0, .33); 87 | } 88 | 89 | li[contenteditable=true].corrected { 90 | background-color: #a6e6a6; 91 | } 92 | 93 | /* Left aside navigation menu */ 94 | nav { 95 | background: #444; 96 | font-family: sans-serif; 97 | position: fixed; 98 | left: 0; 99 | top: 0; 100 | bottom: 0; 101 | width: 15%; 102 | } 103 | 104 | nav li { 105 | display: inline-block; 106 | } 107 | 108 | nav a { 109 | color: white; 110 | text-decoration: none; 111 | } 112 | 113 | nav a:hover { 114 | text-decoration: underline; 115 | } 116 | 117 | button.download { 118 | background: white; 119 | border-radius: 0.3em; 120 | border: 3px solid #CCC; 121 | bottom: 50px; 122 | cursor: pointer; 123 | font-size: 1.1em; 124 | margin-left: 2.5%; 125 | position: fixed; 126 | width: 10%; 127 | } 128 | 129 | button.download > a { 130 | color: #222; 131 | } 132 | 133 | -------------------------------------------------------------------------------- /kraken/transcribe.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 2015 Benjamin Kiessling 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 13 | # or implied. See the License for the specific language governing 14 | # permissions and limitations under the License. 15 | """ 16 | Utility functions for ground truth transcription. 17 | """ 18 | import base64 19 | import logging 20 | import uuid 21 | from io import BytesIO 22 | from typing import Any, Dict, List 23 | 24 | from jinja2 import Environment, PackageLoader 25 | 26 | from kraken.lib.util import get_im_str 27 | 28 | logger = logging.getLogger() 29 | 30 | 31 | class TranscriptionInterface(object): 32 | 33 | def __init__(self, font=None, font_style=None): 34 | logger.info('Initializing transcription object.') 35 | logger.debug('Initializing jinja environment.') 36 | env = Environment(loader=PackageLoader('kraken', 'templates'), autoescape=True) 37 | logger.debug('Loading transcription template.') 38 | self.tmpl = env.get_template('layout.html') 39 | self.pages: List[Dict[Any, Any]] = [] 40 | self.font = {'font': font, 'style': font_style} 41 | self.text_direction = 'horizontal-tb' 42 | self.page_idx = 1 43 | self.line_idx = 1 44 | self.seg_idx = 1 45 | 46 | def add_page(self, im, segmentation=None): 47 | """ 48 | Adds an image to the transcription interface, optionally filling in 49 | line information from a Segmentation object. 50 | 51 | Args: 52 | im: Input image 53 | segmentation: Output of the segment method.
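        Example:
            A minimal sketch (assumes ``im`` is a PIL image and ``seg`` the
            Segmentation produced for it by the segmenter):

            >>> ti = TranscriptionInterface()
            >>> ti.add_page(im, seg)
            >>> with open('transcription.html', 'wb') as fp:
            ...     ti.write(fp)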
54 | """ 55 | im_str = get_im_str(im) 56 | logger.info(f'Adding page {im_str} with {len(segmentation.lines)} lines') 57 | page = {} 58 | fd = BytesIO() 59 | im.save(fd, format='png', optimize=True) 60 | page['index'] = self.page_idx 61 | self.page_idx += 1 62 | logger.debug('Base64 encoding image') 63 | page['img'] = 'data:image/png;base64,' + base64.b64encode(fd.getvalue()).decode('ascii') 64 | page['lines'] = [] 65 | logger.debug('Adding segmentation.') 66 | self.text_direction = segmentation.text_direction 67 | for line in segmentation.lines: 68 | bbox = line.bbox 69 | page['lines'].append({'index': self.line_idx, 70 | 'left': 100*int(bbox[0]) / im.size[0], 71 | 'top': 100*int(bbox[1]) / im.size[1], 72 | 'width': 100*(bbox[2] - bbox[0])/im.size[0], 73 | 'height': 100*(int(bbox[3]) - int(bbox[1]))/im.size[1], 74 | 'bbox': '{}, {}, {}, {}'.format(int(bbox[0]), 75 | int(bbox[1]), 76 | int(bbox[2]), 77 | int(bbox[3]))}) 78 | if line.text: 79 | page['lines'][-1]['text'] = line.prediction 80 | self.line_idx += 1 81 | self.pages.append(page) 82 | 83 | def write(self, fd): 84 | """ 85 | Writes the HTML file to a file descriptor. 86 | 87 | Args: 88 | fd (File): File descriptor (mode='rb') to write to. 89 | """ 90 | logger.info('Rendering and writing transcription.') 91 | fd.write(self.tmpl.render(uuid=f'_{uuid.uuid4()}', pages=self.pages, 92 | font=self.font, 93 | text_direction=self.text_direction).encode('utf-8')) 94 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | 3 | [build-system] 4 | requires = ["pbr>=5.7.0", "setuptools>=36.6.0,<70.0.0", "wheel"] 5 | build-backend = "pbr.build" 6 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = 3 | tests 4 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = kraken 3 | author = Benjamin Kiessling 4 | author_email = mittagessen@l.unchti.me 5 | summary = OCR/HTR engine for all the languages 6 | home_page = https://kraken.re 7 | long_description = file: README.rst 8 | long_description_content_type = text/x-rst; charset=UTF-8 9 | license = Apache 10 | classifier = 11 | Development Status :: 5 - Production/Stable 12 | Environment :: Console 13 | Environment :: GPU 14 | Intended Audience :: Science/Research 15 | License :: OSI Approved :: Apache Software License 16 | Operating System :: POSIX 17 | Programming Language :: Python :: 3.9 18 | Programming Language :: Python :: 3.10 19 | Programming Language :: Python :: 3.11 20 | Programming Language :: Python :: 3.12 21 | Topic :: Scientific/Engineering :: Image Recognition 22 | Topic :: Scientific/Engineering :: Artificial Intelligence 23 | 24 | keywords = 25 | ocr 26 | htr 27 | 28 | [files] 29 | packages = kraken 30 | 31 | [entry_points] 32 | console_scripts = 33 | kraken = kraken.kraken:cli 34 | ketos = kraken.ketos:cli 35 | 36 | [flake8] 37 | max_line_length = 160 38 | exclude = tests/* 39 | 40 | [options] 41 | python_requires = >=3.9,<3.13 42 | install_requires = 43 | jsonschema 44 | lxml 45 | requests 46 | click>=8.1 47 | numpy~=2.0.0 48 | Pillow>=9.2.0 49 | regex 50 | scipy~=1.13.0 51 | protobuf>=3.0.0 52 | coremltools~=8.1 53 | jinja2~=3.0 54 | 
python-bidi~=0.6.0 55 | torchvision>=0.5.0 56 | torch~=2.4.0 57 | scikit-learn~=1.5.0 58 | scikit-image~=0.24.0 59 | shapely>=2.0.6,~=2.0.6 60 | pyarrow 61 | htrmopo>=0.3,~=0.3 62 | lightning~=2.4.0 63 | torchmetrics>=1.1.0 64 | threadpoolctl~=3.5.0 65 | platformdirs 66 | rich 67 | 68 | [options.extras_require] 69 | test = hocr-spec; pytest 70 | pdf = pyvips 71 | augment = albumentations 72 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from setuptools import setup 3 | 4 | setup( 5 | include_package_data=True, 6 | setup_requires=['pbr'], 7 | pbr=True, 8 | ) 9 | -------------------------------------------------------------------------------- /singularity/kraken.def: -------------------------------------------------------------------------------- 1 | BootStrap: library 2 | From: debian 3 | 4 | %post 5 | apt update 6 | apt -y --allow-unauthenticated full-upgrade 7 | apt -y --allow-unauthenticated install --no-install-recommends git python3 python3-pip 8 | apt clean 9 | pip install "kraken[pdf,augment]" 10 | pip install "numpy<1.24" 11 | pip cache purge 12 | mkdir /sps /pbs 13 | -------------------------------------------------------------------------------- /tests/resources/000236.gt.txt: -------------------------------------------------------------------------------- 1 | ܠܐ ܡܕܡ ܢܗܦܘܟ ܗܘܐ ܠܗ ܠܠܐ ܡܕܡ. ܝܬܝܪ ܕܐܣܗܕ -------------------------------------------------------------------------------- /tests/resources/000236.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/tests/resources/000236.png -------------------------------------------------------------------------------- /tests/resources/170025120000003,0074-lite.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | TRP 5 | 2016-06-16T16:57:15.027+02:00 6 | 2018-07-04T17:25:44.389+02:00 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | $pag:39 24 | 25 | 26 | 27 | $pag:39 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | y salvedades alli expressadas; Y fue acetada; 44 | 45 | 46 | 47 | 48 | 49 | 50 | y assi mismo el dho$.dicho $ofi:Patron $ant:Miguel $ant:Carreras, 51 | 52 | 53 | 54 | 55 | 56 | 57 | y el $ofi:Rndo$.Reverendo $ant:Miguel $ant:Carreras $ofi:pbro$.presbítero residente en la 58 | 59 | 60 | 61 | $-nor su hijo, De todos sus bienes, con los pactos 62 | y salvedades alli expressadas; Y fue acetada; 63 | y assi mismo el dho$.dicho $ofi:Patron $ant:Miguel $ant:Carreras, 64 | y el $ofi:Rndo$.Reverendo $ant:Miguel $ant:Carreras $ofi:pbro$.presbítero residente en la 65 | Parq.$^l$.Parroquial Igla$.Iglesia de dha$.dicha villa de $top:Canet Padre é, hijo 66 | hizieron donacion a la dha$.dicha $ant:Anna $ant:Maria su 67 | hija y hermana resp.$^e$.respectivamente por todos sus drôs$.derechos de le:$- 68 | $-gitima Paterna, Materna y otros de ducientas$.doscientas 69 | libras de moneda Bar$.barcelonesa; arca y vestidos corres:$- 70 | $-pondientes, con promesa de pagar en esta 71 | forma, ésto es arcas, ropas y joyas el dia de las 72 | Bodas; cien libras del dia de la fecha, á, medio 73 | año y las restantes cien libras del dho$.dicho dia de la 74 | fecha á tres años prox.$^s$.proximos venturos bajo obli:$- 75 | $-gacion de todos sus bienes; cuya dha$.dicha 
donacion 76 | fue echa con el pacto revercional acos:$- 77 | $-tumbrado; Y fue azetada por la dha$.dicha $ant:Anna 78 | $ant:Maria por quien fue echa la diffinición cor:$- 79 | $-respondiente de dhos$.dichos sus dros$.derechos a favor del dho$.dicho 80 | su Padre y hermano resp.$^e$.respectivamente y salvose el de 81 | futura sucession: Y en su consequencia hizo 82 | la correspondiente constitucion dotal al 83 | dho$.dicho $ant:Joseph $ant:Vancells su venidero esposo; y este 84 | acetandola prometió en su caso restituir 85 | bajo obligacion de todos sus bienes. 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /tests/resources/170025120000003,0074.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/tests/resources/170025120000003,0074.jpg -------------------------------------------------------------------------------- /tests/resources/ONB_ibn_19110701_010.tif_line_1548924556947_449.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/tests/resources/ONB_ibn_19110701_010.tif_line_1548924556947_449.png -------------------------------------------------------------------------------- /tests/resources/bw.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/tests/resources/bw.png -------------------------------------------------------------------------------- /tests/resources/input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/tests/resources/input.jpg -------------------------------------------------------------------------------- /tests/resources/input.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/tests/resources/input.tif -------------------------------------------------------------------------------- /tests/resources/merge_tests/0006.gt.txt: -------------------------------------------------------------------------------- 1 | Ud; lib; 2 | -------------------------------------------------------------------------------- /tests/resources/merge_tests/0006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/tests/resources/merge_tests/0006.jpg -------------------------------------------------------------------------------- /tests/resources/merge_tests/0007.gt.txt: -------------------------------------------------------------------------------- 1 | ex~ Sĩ 2 | -------------------------------------------------------------------------------- /tests/resources/merge_tests/0007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/tests/resources/merge_tests/0007.jpg -------------------------------------------------------------------------------- /tests/resources/merge_tests/0008.gt.txt: 
-------------------------------------------------------------------------------- 1 | /Tngt -------------------------------------------------------------------------------- /tests/resources/merge_tests/0008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/tests/resources/merge_tests/0008.jpg -------------------------------------------------------------------------------- /tests/resources/merge_tests/0014.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/tests/resources/merge_tests/0014.jpg -------------------------------------------------------------------------------- /tests/resources/merge_tests/0014.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | pixel 7 | 8 | 0014.jpg 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 23 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 40 | 41 | 46 | 47 | 48 | 49 | 50 | 57 | 58 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /tests/resources/merge_tests/0021.gt.txt: -------------------------------------------------------------------------------- 1 | Ah hodi; Ly9ẽ -------------------------------------------------------------------------------- /tests/resources/merge_tests/0021.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/tests/resources/merge_tests/0021.jpg -------------------------------------------------------------------------------- /tests/resources/merge_tests/base.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/tests/resources/merge_tests/base.arrow -------------------------------------------------------------------------------- /tests/resources/merge_tests/merge_codec_nfd.mlmodel: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/tests/resources/merge_tests/merge_codec_nfd.mlmodel -------------------------------------------------------------------------------- /tests/resources/merge_tests/merger.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/tests/resources/merge_tests/merger.arrow -------------------------------------------------------------------------------- /tests/resources/model_small.mlmodel: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/tests/resources/model_small.mlmodel -------------------------------------------------------------------------------- /tests/resources/overfit.mlmodel: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/tests/resources/overfit.mlmodel -------------------------------------------------------------------------------- /tests/resources/overfit_newpoly.mlmodel: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mittagessen/kraken/14778a4ef2d061d688efcf979cb038d34e4b87e0/tests/resources/overfit_newpoly.mlmodel -------------------------------------------------------------------------------- /tests/resources/segmentation.json: -------------------------------------------------------------------------------- 1 | {"boxes": [[0, 29, 518, 56], [25, 54, 122, 82], [9, 74, 95, 119], [103, 75, 146, 131], [7, 138, 136, 231], [10, 228, 122, 348], [13, 230, 65, 285], [74, 304, 121, 354], [12, 353, 143, 405], [15, 450, 109, 521], [17, 511, 147, 574], [108, 544, 151, 597], [30, 591, 143, 694], [21, 696, 149, 838], [13, 832, 155, 900], [3, 880, 93, 970], [20, 989, 60, 1036], [13, 1096, 67, 1152], [87, 1502, 126, 1558], [7, 1866, 132, 1949], [21, 1978, 93, 2051], [26, 2048, 120, 2091], [518, 297, 580, 337], [654, 293, 1088, 332], [514, 353, 1294, 398], [519, 407, 1294, 447], [515, 453, 1292, 499], [518, 505, 1290, 546], [517, 553, 1292, 594], [514, 603, 1292, 647], [518, 652, 1293, 693], [519, 700, 1296, 742], [518, 750, 1296, 797], [518, 799, 1292, 841], [514, 848, 1296, 897], [515, 895, 885, 944], [517, 943, 1294, 990], [514, 995, 1351, 1043], [513, 1043, 1294, 1094], [513, 1094, 1293, 1141], [512, 1143, 1294, 1192], [512, 1192, 1293, 1240], [513, 1241, 1294, 1284], [517, 1290, 1292, 1331], [515, 1340, 1291, 1383], [514, 1388, 1295, 1438], [517, 1436, 1292, 1487], [516, 1483, 1291, 1539], [1078, 1546, 1283, 1584], [530, 1581, 1291, 1636], [514, 1639, 1291, 1689], [512, 1680, 859, 1716], [1389, 24, 1453, 45]], "text_direction": "horizontal-lr", "script_detection": false} -------------------------------------------------------------------------------- /tests/resources/xlink.xsd: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | -------------------------------------------------------------------------------- /tests/test_align.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import unittest 4 | from pathlib import Path 5 | 6 | from kraken.align import forced_align 7 | from kraken.lib import xml 8 | 9 | thisfile = Path(__file__).resolve().parent 10 | resources = thisfile / 'resources' 11 | 12 | class TestKrakenAlign(unittest.TestCase): 13 | """ 14 | Tests for the forced alignment module 15 | """ 16 | def setUp(self): 17 | self.doc = resources / '170025120000003,0074.xml' 18 | self.bls = xml.XMLPage(self.doc).to_container() 19 | 20 | def test_forced_align_simple(self): 21 | """ 22 | Simple alignment test. 
23 | """ 24 | pass 25 | -------------------------------------------------------------------------------- /tests/test_arrow_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import json 4 | import unittest 5 | import tempfile 6 | import pyarrow as pa 7 | 8 | from pathlib import Path 9 | from pytest import raises, fixture 10 | 11 | import kraken 12 | from kraken.lib import xml 13 | from kraken.lib.arrow_dataset import build_binary_dataset 14 | 15 | thisfile = Path(__file__).resolve().parent 16 | resources = thisfile / 'resources' 17 | 18 | def _validate_ds(self, path, num_lines, num_empty_lines, ds_type): 19 | with pa.memory_map(path, 'rb') as source: 20 | ds_table = pa.ipc.open_file(source).read_all() 21 | raw_metadata = ds_table.schema.metadata 22 | if not raw_metadata or b'lines' not in raw_metadata: 23 | raise ValueError(f'{file} does not contain a valid metadata record.') 24 | metadata = json.loads(raw_metadata[b'lines']) 25 | self.assertEqual(metadata['type'], 26 | ds_type, 27 | f'Unexpected dataset type (expected: {ds_type}, found: {metadata["type"]}') 28 | self.assertEqual(metadata['counts']['all'], 29 | num_lines, 30 | 'Unexpected number of lines in dataset metadata ' 31 | f'(expected: {num_lines}, found: {metadata["counts"]["all"]}') 32 | self.assertEqual(len(ds_table), 33 | num_lines, 34 | 'Unexpected number of rows in arrow table ' 35 | f'(expected: {num_lines}, found: {metadata["counts"]["all"]}') 36 | 37 | real_empty_lines = len([line for line in ds_table.column('lines') if not str(line[0])]) 38 | self.assertEqual(real_empty_lines, 39 | num_empty_lines, 40 | 'Unexpected number of empty lines in dataset ' 41 | f'(expected: {num_empty_lines}, found: {real_empty_lines}') 42 | 43 | 44 | class TestKrakenArrowCompilation(unittest.TestCase): 45 | """ 46 | Tests for binary datasets 47 | """ 48 | def setUp(self): 49 | self.xml = resources / '170025120000003,0074-lite.xml' 50 | self.seg = xml.XMLPage(self.xml).to_container() 51 | self.box_lines = [resources / '000236.png'] 52 | 53 | def test_build_path_dataset(self): 54 | with tempfile.NamedTemporaryFile() as tmp_file: 55 | build_binary_dataset(files=4*self.box_lines, 56 | output_file=tmp_file.name, 57 | format_type='path') 58 | _validate_ds(self, tmp_file.name, 4, 0, 'kraken_recognition_bbox') 59 | 60 | def test_build_xml_dataset(self): 61 | with tempfile.NamedTemporaryFile() as tmp_file: 62 | build_binary_dataset(files=[self.xml], 63 | output_file=tmp_file.name, 64 | format_type='xml') 65 | _validate_ds(self, tmp_file.name, 4, 0, 'kraken_recognition_baseline') 66 | 67 | def test_build_seg_dataset(self): 68 | with tempfile.NamedTemporaryFile() as tmp_file: 69 | build_binary_dataset(files=[self.seg], 70 | output_file=tmp_file.name, 71 | format_type=None) 72 | _validate_ds(self, tmp_file.name, 4, 0, 'kraken_recognition_baseline') 73 | 74 | def test_forced_type_dataset(self): 75 | with tempfile.NamedTemporaryFile() as tmp_file: 76 | build_binary_dataset(files=4*self.box_lines, 77 | output_file=tmp_file.name, 78 | format_type='path', 79 | force_type='kraken_recognition_baseline') 80 | _validate_ds(self, tmp_file.name, 4, 0, 'kraken_recognition_baseline') 81 | 82 | def test_build_empty_dataset(self): 83 | """ 84 | Test that empty lines are retained in compiled dataset. 
85 | """ 86 | with tempfile.NamedTemporaryFile() as tmp_file: 87 | build_binary_dataset(files=[self.xml], 88 | output_file=tmp_file.name, 89 | format_type='xml', 90 | skip_empty_lines=False) 91 | _validate_ds(self, tmp_file.name, 5, 1, 'kraken_recognition_baseline') 92 | 93 | @fixture(autouse=True) 94 | def caplog_fixture(self, caplog): 95 | # make pytest caplog fixture available 96 | self.caplog = caplog 97 | 98 | def test_build_image_error(self): 99 | """ 100 | Test that image load errors are handled. 101 | """ 102 | # change resource path so it will not resolve 103 | bad_box_lines = [path.with_name(f"bogus_{path.stem}") for path in self.box_lines] 104 | with tempfile.NamedTemporaryFile() as tmp_file: 105 | build_binary_dataset(files=bad_box_lines, 106 | output_file=tmp_file.name, 107 | format_type='xml') 108 | # expect zero resulting lines due to image load error 109 | _validate_ds(self, tmp_file.name, 0, 0, 'kraken_recognition_baseline') 110 | # expect one warning log message; should include the file image filename 111 | assert len(self.caplog.records) == 1 112 | log_record = self.caplog.records[0] 113 | assert log_record.levelname == "WARNING" 114 | assert f"Invalid input file {bad_box_lines[0]}" in log_record.message 115 | -------------------------------------------------------------------------------- /tests/test_binarization.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import unittest 3 | from pathlib import Path 4 | 5 | from PIL import Image 6 | from pytest import raises 7 | 8 | from kraken.binarization import nlbin 9 | from kraken.lib.exceptions import KrakenInputException 10 | 11 | thisfile = Path(__file__).resolve().parent 12 | resources = thisfile / 'resources' 13 | 14 | class TestBinarization(unittest.TestCase): 15 | 16 | """ 17 | Tests of the nlbin function for binarization of images 18 | """ 19 | def test_not_binarize_empty(self): 20 | """ 21 | Test that mode '1' images aren't binarized again. 22 | """ 23 | with raises(KrakenInputException): 24 | with Image.new('1', (1000,1000)) as im: 25 | nlbin(im) 26 | 27 | def test_not_binarize_bw(self): 28 | """ 29 | Test that mode '1' images aren't binarized again. 30 | """ 31 | with Image.open(resources / 'bw.png') as im: 32 | self.assertEqual(im, nlbin(im)) 33 | 34 | def test_binarize_no_bw(self): 35 | """ 36 | Tests binarization of image formats without a 1bpp mode (JPG). 37 | """ 38 | with Image.open(resources / 'input.jpg') as im: 39 | res = nlbin(im) 40 | # calculate histogram and check if only pixels of value 0/255 exist 41 | self.assertEqual(254, res.histogram().count(0), msg='Output not ' 42 | 'binarized') 43 | 44 | def test_binarize_tif(self): 45 | """ 46 | Tests binarization of RGB TIFF images. 47 | """ 48 | with Image.open(resources /'input.tif') as im: 49 | res = nlbin(im) 50 | # calculate histogram and check if only pixels of value 0/255 exist 51 | self.assertEqual(254, res.histogram().count(0), msg='Output not ' 52 | 'binarized') 53 | 54 | def test_binarize_grayscale(self): 55 | """ 56 | Test binarization of mode 'L' images. 
57 | """ 58 | with Image.open(resources / 'input.tif') as im: 59 | res = nlbin(im.convert('L')) 60 | # calculate histogram and check if only pixels of value 0/255 exist 61 | self.assertEqual(254, res.histogram().count(0), msg='Output not ' 62 | 'binarized') 63 | -------------------------------------------------------------------------------- /tests/test_cli.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import tempfile 4 | import unittest 5 | from pathlib import Path 6 | 7 | import click 8 | import numpy as np 9 | from click.testing import CliRunner 10 | from PIL import Image 11 | from pytest import raises 12 | 13 | from kraken.kraken import cli 14 | 15 | thisfile = Path(__file__).resolve().parent 16 | resources = thisfile / 'resources' 17 | 18 | class TestCLI(unittest.TestCase): 19 | """ 20 | Testing the kraken CLI 21 | """ 22 | 23 | def setUp(self): 24 | self.temp = tempfile.NamedTemporaryFile(delete=False) 25 | self.runner = CliRunner() 26 | self.color_img = resources / 'input.tif' 27 | self.bw_img = resources / 'bw.png' 28 | 29 | def tearDown(self): 30 | self.temp.close() 31 | os.unlink(self.temp.name) 32 | 33 | def test_binarize_color(self): 34 | """ 35 | Tests binarization of color images. 36 | """ 37 | with tempfile.NamedTemporaryFile() as fp: 38 | result = self.runner.invoke(cli, ['-i', self.color_img, fp.name, 'binarize']) 39 | self.assertEqual(result.exit_code, 0) 40 | self.assertEqual(tuple(map(lambda x: x[1], Image.open(fp).getcolors())), (0, 255)) 41 | 42 | def test_binarize_bw(self): 43 | """ 44 | Tests binarization of b/w images. 45 | """ 46 | with tempfile.NamedTemporaryFile() as fp: 47 | result = self.runner.invoke(cli, ['-i', self.bw_img, fp.name, 'binarize']) 48 | self.assertEqual(result.exit_code, 0) 49 | bw = np.array(Image.open(self.bw_img)) 50 | new = np.array(Image.open(fp.name)) 51 | self.assertTrue(np.all(bw == new)) 52 | 53 | def test_segment_color(self): 54 | """ 55 | Tests that segmentation is aborted when given color image. 
56 | """ 57 | with tempfile.NamedTemporaryFile() as fp: 58 | result = self.runner.invoke(cli, ['-r', '-i', self.color_img, fp.name, 'segment']) 59 | self.assertEqual(result.exit_code, 1) 60 | -------------------------------------------------------------------------------- /tests/test_lineest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import unittest 3 | from pathlib import Path 4 | 5 | from PIL import Image 6 | from pytest import raises 7 | 8 | from kraken.lib import lineest 9 | 10 | thisfile = Path(__file__).resolve().parent 11 | resources = thisfile / 'resources' 12 | 13 | class TestLineest(unittest.TestCase): 14 | 15 | """ 16 | Testing centerline estimator 17 | """ 18 | 19 | def setUp(self): 20 | self.lnorm = lineest.CenterNormalizer() 21 | 22 | def test_dewarp_bw(self): 23 | """ 24 | Test dewarping of a single line in B/W 25 | """ 26 | with Image.open(resources / '000236.png') as im: 27 | o = lineest.dewarp(self.lnorm, im.convert('1')) 28 | self.assertEqual(self.lnorm.target_height, o.size[1]) 29 | 30 | def test_dewarp_gray(self): 31 | """ 32 | Test dewarping of a single line in grayscale 33 | """ 34 | with Image.open(resources /'000236.png') as im: 35 | o = lineest.dewarp(self.lnorm, im.convert('L')) 36 | self.assertEqual(self.lnorm.target_height, o.size[1]) 37 | 38 | def test_dewarp_fail_color(self): 39 | """ 40 | Test dewarping of a color line fails 41 | """ 42 | with raises(ValueError): 43 | with Image.open(resources /'000236.png') as im: 44 | lineest.dewarp(self.lnorm, im.convert('RGB')) 45 | 46 | def test_dewarp_bw_undewarpable(self): 47 | """ 48 | Test dewarping of an undewarpable line. 49 | """ 50 | with Image.open(resources /'ONB_ibn_19110701_010.tif_line_1548924556947_449.png') as im: 51 | o = lineest.dewarp(self.lnorm, im) 52 | self.assertEqual(self.lnorm.target_height, o.size[1]) 53 | 54 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import pickle 4 | import tempfile 5 | import unittest 6 | from pathlib import Path 7 | 8 | from pytest import raises 9 | 10 | from kraken.lib import models 11 | from kraken.lib.exceptions import KrakenInvalidModelException 12 | 13 | thisfile = Path(__file__).resolve().parent 14 | resources = thisfile / 'resources' 15 | 16 | class TestModels(unittest.TestCase): 17 | """ 18 | Testing model loading routines 19 | """ 20 | 21 | def setUp(self): 22 | self.temp = tempfile.NamedTemporaryFile(delete=False) 23 | 24 | def tearDown(self): 25 | self.temp.close() 26 | os.unlink(self.temp.name) 27 | 28 | def test_load_invalid(self): 29 | """ 30 | Tests correct handling of invalid files. 
31 | """ 32 | with raises(KrakenInvalidModelException): 33 | models.load_any(self.temp.name) 34 | -------------------------------------------------------------------------------- /tests/test_pageseg.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import unittest 3 | from pathlib import Path 4 | 5 | from PIL import Image 6 | from pytest import raises 7 | 8 | from kraken.lib.exceptions import KrakenInputException 9 | from kraken.pageseg import segment 10 | 11 | thisfile = Path(__file__).resolve().parent 12 | resources = thisfile / 'resources' 13 | 14 | class TestPageSeg(unittest.TestCase): 15 | 16 | """ 17 | Tests of the page segmentation functionality 18 | """ 19 | def test_segment_color(self): 20 | """ 21 | Test correct handling of color input. 22 | """ 23 | with raises(KrakenInputException): 24 | with Image.open(resources / 'input.jpg') as im: 25 | segment(im) 26 | 27 | def test_segment_bw(self): 28 | """ 29 | Tests segmentation of bi-level input. 30 | """ 31 | with Image.open(resources / 'bw.png') as im: 32 | seg = segment(im) 33 | self.assertEqual(seg.type, 'bbox') 34 | # test if line count is roughly correct 35 | self.assertAlmostEqual(len(seg.lines), 30, msg='Segmentation differs ' 36 | 'wildly from true line count', delta=5) 37 | # check if lines do not extend beyond image 38 | for line in seg.lines: 39 | box = line.bbox 40 | self.assertLess(0, box[0], msg='Line x0 < 0') 41 | self.assertLess(0, box[1], msg='Line y0 < 0') 42 | self.assertGreater(im.size[0], box[2], msg='Line x1 > {}'.format(im.size[0])) 43 | self.assertGreater(im.size[1], box[3], msg='Line y1 > {}'.format(im.size[1])) 44 | -------------------------------------------------------------------------------- /tests/test_repo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import shutil 3 | import tempfile 4 | import unittest 5 | from pathlib import Path 6 | 7 | from kraken import repo 8 | 9 | thisfile = Path(__file__).resolve().parent 10 | resources = thisfile / 'resources' 11 | 12 | class TestRepo(unittest.TestCase): 13 | """ 14 | Testing our wrappers around HTRMoPo 15 | """ 16 | 17 | def setUp(self): 18 | self.temp_model = tempfile.TemporaryDirectory() 19 | self.temp_path = Path(self.temp_model.name) 20 | 21 | def tearDown(self): 22 | shutil.rmtree(self.temp_model.name) 23 | 24 | def test_listing(self): 25 | """ 26 | Tests fetching the model list. 27 | """ 28 | records = repo.get_listing() 29 | self.assertGreater(len(records), 15) 30 | 31 | def test_get_description(self): 32 | """ 33 | Tests fetching the description of a model. 34 | """ 35 | record = repo.get_description('10.5281/zenodo.8425684') 36 | self.assertEqual(record.doi, '10.5281/zenodo.8425684') 37 | 38 | def test_prev_record_version_get_description(self): 39 | """ 40 | Tests fetching the description of a model that has a superseding newer version. 
41 | """ 42 | record = repo.get_description('10.5281/zenodo.6657809') 43 | self.assertEqual(record.doi, '10.5281/zenodo.6657809') 44 | -------------------------------------------------------------------------------- /tests/test_transcribe.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import os 4 | import unittest 5 | from io import BytesIO 6 | from pathlib import Path 7 | 8 | from lxml import etree 9 | from PIL import Image 10 | 11 | from kraken import containers 12 | from kraken.transcribe import TranscriptionInterface 13 | 14 | thisfile = Path(__file__).resolve().parent 15 | resources = thisfile / 'resources' 16 | 17 | class TestTranscriptionInterface(unittest.TestCase): 18 | 19 | """ 20 | Test of the transcription interface generation 21 | """ 22 | def setUp(self): 23 | with open(resources /'records.json', 'r') as fp: 24 | self.box_records = [containers.BBoxOCRRecord(**x) for x in json.load(fp)] 25 | 26 | self.box_segmentation = containers.Segmentation(type='bbox', 27 | imagename='foo.png', 28 | text_direction='horizontal-lr', 29 | lines=self.box_records, 30 | script_detection=True, 31 | regions={}) 32 | 33 | self.im = Image.open(resources / 'input.jpg') 34 | 35 | def test_transcription_generation(self): 36 | """ 37 | Tests creation of transcription interfaces with segmentation. 38 | """ 39 | tr = TranscriptionInterface() 40 | tr.add_page(im = self.im, segmentation=self.box_segmentation) 41 | fp = BytesIO() 42 | tr.write(fp) 43 | # this will not throw an exception ever so we need a better validator 44 | etree.HTML(fp.getvalue()) 45 | -------------------------------------------------------------------------------- /tests/test_vgsl.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import tempfile 4 | import unittest 5 | 6 | import torch 7 | from pytest import raises 8 | 9 | from kraken.lib import layers, vgsl 10 | 11 | 12 | class TestVGSL(unittest.TestCase): 13 | """ 14 | Testing VGSL module 15 | """ 16 | def test_helper_train(self): 17 | """ 18 | Tests train/eval mode helper methods 19 | """ 20 | rnn = vgsl.TorchVGSLModel('[1,1,0,48 Lbx10 Do O1c57]') 21 | rnn.train() 22 | self.assertTrue(torch.is_grad_enabled()) 23 | self.assertTrue(rnn.nn.training) 24 | rnn.eval() 25 | self.assertFalse(torch.is_grad_enabled()) 26 | self.assertFalse(rnn.nn.training) 27 | 28 | @unittest.skip('works randomly on ci') 29 | def test_helper_threads(self): 30 | """ 31 | Test openmp threads helper method. 32 | """ 33 | rnn = vgsl.TorchVGSLModel('[1,1,0,48 Lbx10 Do O1c57]') 34 | rnn.set_num_threads(4) 35 | self.assertEqual(torch.get_num_threads(), 4) 36 | 37 | def test_save_model(self): 38 | """ 39 | Test model serialization. 40 | """ 41 | rnn = vgsl.TorchVGSLModel('[1,1,0,48 Lbx10 Do O1c57]') 42 | with tempfile.TemporaryDirectory() as dir: 43 | rnn.save_model(dir + '/foo.mlmodel') 44 | self.assertTrue(os.path.exists(dir + '/foo.mlmodel')) 45 | 46 | def test_append(self): 47 | """ 48 | Test appending one VGSL spec to another. 49 | """ 50 | rnn = vgsl.TorchVGSLModel('[1,1,0,48 Lbx10 Do O1c57]') 51 | rnn.append(1, '[Cr1,1,2 Gn2 Cr3,3,4]') 52 | self.assertEqual(rnn.spec, '[1,1,0,48 Lbx{L_0}10 Cr{C_1}1,1,2 Gn{Gn_2}2 Cr{C_3}3,3,4]') 53 | 54 | def test_resize(self): 55 | """ 56 | Tests resizing of output layers. 
57 | """ 58 | rnn = vgsl.TorchVGSLModel('[1,1,0,48 Lbx10 Do O1c57]') 59 | rnn.resize_output(80) 60 | self.assertEqual(rnn.nn[-1].lin.out_features, 80) 61 | 62 | def test_del_resize(self): 63 | """ 64 | Tests resizing of output layers with entry deletion. 65 | """ 66 | rnn = vgsl.TorchVGSLModel('[1,1,0,48 Lbx10 Do O1c57]') 67 | rnn.resize_output(80, [2, 4, 5, 6, 7, 12, 25]) 68 | self.assertEqual(rnn.nn[-1].lin.out_features, 80) 69 | 70 | def test_nested_serial_model(self): 71 | """ 72 | Test the creation of a nested serial model. 73 | """ 74 | net = vgsl.TorchVGSLModel('[1,48,0,1 Cr4,2,1,4,2 ([Cr4,2,1,1,1 Do Cr3,3,2,1,1] [Cr4,2,1,1,1 Cr3,3,2,1,1 Do]) S1(1x0)1,3 Lbx2 Do0.5 Lbx2]') 75 | self.assertIsInstance(net.nn[1], layers.MultiParamParallel) 76 | for x in net.nn[1].children(): 77 | self.assertIsInstance(x, layers.MultiParamSequential) 78 | self.assertEqual(len(x), 3) 79 | 80 | def test_parallel_model_inequal(self): 81 | """ 82 | Test proper raising of ValueError when parallel layers do not have the same output shape. 83 | """ 84 | with raises(ValueError): 85 | net = vgsl.TorchVGSLModel('[1,48,0,1 Cr4,2,1,4,2 [Cr4,2,1,1,1 (Cr4,2,1,4,2 Cr3,3,2,1,1) S1(1x0)1,3 Lbx2 Do0.5] Lbx2]') 86 | 87 | def test_complex_serialization(self): 88 | """ 89 | Test proper serialization and deserialization of a complex model. 90 | """ 91 | net = vgsl.TorchVGSLModel('[1,48,0,1 Cr4,2,1,4,2 ([Cr4,2,1,1,1 Do Cr3,3,2,1,1] [Cr4,2,1,1,1 Cr3,3,2,1,1 Do]) S1(1x0)1,3 Lbx2 Do0.5 Lbx2]') 92 | -------------------------------------------------------------------------------- /tests/test_xml.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import tempfile 4 | import unittest 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | from pytest import raises 9 | 10 | from kraken.lib import xml 11 | 12 | thisfile = Path(__file__).resolve().parent 13 | resources = thisfile / 'resources' 14 | 15 | class TestXMLParser(unittest.TestCase): 16 | """ 17 | Tests XML (ALTO/PAGE) parsing 18 | """ 19 | def setUp(self): 20 | self.page_doc = resources / 'cPAS-2000.xml' 21 | self.alto_doc = resources / 'bsb00084914_00007.xml' 22 | 23 | def test_page_parsing(self): 24 | """ 25 | Test parsing of PAGE XML files with reading order. 26 | """ 27 | doc = xml.XMLPage(self.page_doc, filetype='page') 28 | self.assertEqual(len(doc.get_sorted_lines()), 97) 29 | self.assertEqual(len([item for x in doc.regions.values() for item in x]), 4) 30 | 31 | def test_alto_parsing(self): 32 | """ 33 | Test parsing of ALTO XML files with reading order. 34 | """ 35 | doc = xml.XMLPage(self.alto_doc, filetype='alto') 36 | 37 | def test_auto_parsing(self): 38 | """ 39 | Test parsing of PAGE and ALTO XML files with auto-format determination. 40 | """ 41 | doc = xml.XMLPage(self.page_doc, filetype='xml') 42 | self.assertEqual(doc.filetype, 'page') 43 | doc = xml.XMLPage(self.alto_doc, filetype='xml') 44 | self.assertEqual(doc.filetype, 'alto') 45 | 46 | def test_failure_page_alto_parsing(self): 47 | """ 48 | Test that parsing ALTO files with PAGE as format fails. 49 | """ 50 | with raises(ValueError): 51 | xml.XMLPage(self.alto_doc, filetype='page') 52 | 53 | def test_failure_alto_page_parsing(self): 54 | """ 55 | Test that parsing PAGE files with ALTO as format fails. 56 | """ 57 | with raises(ValueError): 58 | xml.XMLPage(self.page_doc, filetype='alto') 59 | 60 | --------------------------------------------------------------------------------