├── .editorconfig ├── .github └── workflows │ ├── close-inactive-issues.yaml │ ├── main.yaml │ └── secret-scanning.yaml ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── Makefile ├── README.md ├── config ├── arg_plate_example.yaml └── latin_plate_example.yaml ├── docs ├── architecture.md ├── contributing.md ├── index.md ├── installation.md ├── reference.md └── usage.md ├── fast_plate_ocr ├── __init__.py ├── cli │ ├── __init__.py │ ├── cli.py │ ├── onnx_converter.py │ ├── train.py │ ├── utils.py │ ├── valid.py │ ├── visualize_augmentation.py │ └── visualize_predictions.py ├── common │ ├── __init__.py │ └── utils.py ├── inference │ ├── __init__.py │ ├── config.py │ ├── hub.py │ ├── onnx_inference.py │ ├── process.py │ └── utils.py ├── py.typed └── train │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── augmentation.py │ └── dataset.py │ ├── model │ ├── __init__.py │ ├── config.py │ ├── custom.py │ ├── layer_blocks.py │ └── models.py │ └── utilities │ ├── __init__.py │ ├── backend_utils.py │ └── utils.py ├── mkdocs.yml ├── poetry.lock ├── poetry.toml ├── pyproject.toml └── test ├── __init__.py ├── assets ├── __init__.py ├── test_plate_1.png └── test_plate_2.png ├── conftest.py └── fast_lp_ocr ├── __init__.py ├── inference ├── __init__.py ├── test_hub.py ├── test_onnx_inference.py └── test_process.py └── train ├── __init__.py ├── test_config.py ├── test_custom.py ├── test_models.py └── test_utils.py /.editorconfig: -------------------------------------------------------------------------------- 1 | # EditorConfig: https://EditorConfig.org 2 | root = true 3 | 4 | [*] 5 | charset = utf-8 6 | end_of_line = lf 7 | indent_style = space 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | 11 | [*.py] 12 | indent_size = 4 13 | max_line_length = 100 14 | -------------------------------------------------------------------------------- /.github/workflows/close-inactive-issues.yaml: -------------------------------------------------------------------------------- 1 | name: Close inactive issues 2 | 3 | on: 4 | schedule: 5 | - cron: "30 1 * * *" # Runs daily at 1:30 AM UTC 6 | 7 | jobs: 8 | close-issues: 9 | runs-on: ubuntu-latest 10 | permissions: 11 | issues: write 12 | pull-requests: write 13 | steps: 14 | - uses: actions/stale@v5 15 | with: 16 | days-before-issue-stale: 90 # The number of days old an issue can be before marking it stale 17 | days-before-issue-close: 14 # The number of days to wait to close an issue after it being marked stale 18 | stale-issue-label: "stale" 19 | stale-issue-message: "This issue is stale because it has been open for 90 days with no activity." 20 | close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale." 
21 | days-before-pr-stale: -1 # Disables stale behavior for PRs 22 | days-before-pr-close: -1 # Disables closing behavior for PRs 23 | repo-token: ${{ secrets.GITHUB_TOKEN }} 24 | -------------------------------------------------------------------------------- /.github/workflows/main.yaml: -------------------------------------------------------------------------------- 1 | name: Test and Deploy 2 | on: 3 | push: 4 | branches: 5 | - master 6 | tags: 7 | - 'v*' 8 | pull_request: 9 | branches: 10 | - master 11 | jobs: 12 | test: 13 | name: Test 14 | strategy: 15 | fail-fast: false 16 | matrix: 17 | python-version: [ '3.10', '3.11', '3.12' ] 18 | os: [ ubuntu-latest ] 19 | runs-on: ${{ matrix.os }} 20 | steps: 21 | - uses: actions/checkout@v3 22 | 23 | - name: Install poetry 24 | run: pipx install poetry 25 | 26 | - uses: actions/setup-python@v4 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | cache: 'poetry' 30 | 31 | - name: Install dependencies 32 | run: poetry install --all-extras 33 | 34 | - name: Check format 35 | run: make check_format 36 | 37 | - name: Run linters 38 | run: make lint 39 | 40 | - name: Run tests 41 | run: make test 42 | 43 | publish-to-pypi: 44 | name: Build and Publish to PyPI 45 | needs: 46 | - test 47 | if: "startsWith(github.ref, 'refs/tags/v')" 48 | runs-on: ubuntu-latest 49 | environment: 50 | name: pypi 51 | url: https://pypi.org/p/fast-plate-ocr 52 | permissions: 53 | id-token: write 54 | steps: 55 | - uses: actions/checkout@v3 56 | 57 | - name: Install poetry 58 | run: pipx install poetry 59 | 60 | - name: Setup Python 61 | uses: actions/setup-python@v3 62 | with: 63 | python-version: '3.10' 64 | 65 | - name: Build a binary wheel 66 | run: poetry build 67 | 68 | - name: Publish distribution 📦 to PyPI 69 | uses: pypa/gh-action-pypi-publish@release/v1 70 | 71 | github-release: 72 | name: Create GitHub release 73 | needs: 74 | - publish-to-pypi 75 | runs-on: ubuntu-latest 76 | 77 | permissions: 78 | contents: write 79 | 80 | steps: 81 | - uses: actions/checkout@v3 82 | 83 | - name: Check package version matches tag 84 | id: check-version 85 | uses: samuelcolvin/check-python-version@v4.1 86 | with: 87 | version_file_path: 'pyproject.toml' 88 | 89 | - name: Create GitHub Release 90 | env: 91 | GITHUB_TOKEN: ${{ github.token }} 92 | tag: ${{ github.ref_name }} 93 | run: | 94 | gh release create "$tag" \ 95 | --repo="$GITHUB_REPOSITORY" \ 96 | --title="${GITHUB_REPOSITORY#*/} ${tag#v}" \ 97 | --generate-notes 98 | 99 | update_docs: 100 | name: Update documentation 101 | needs: 102 | - github-release 103 | runs-on: ubuntu-latest 104 | 105 | steps: 106 | - uses: actions/checkout@v3 107 | with: 108 | fetch-depth: 0 109 | 110 | - name: Install poetry 111 | run: pipx install poetry 112 | 113 | - uses: actions/setup-python@v4 114 | with: 115 | python-version: '3.10' 116 | cache: 'poetry' 117 | 118 | - name: Configure Git user 119 | run: | 120 | git config --local user.email "github-actions[bot]@users.noreply.github.com" 121 | git config --local user.name "github-actions[bot]" 122 | 123 | - name: Retrieve version 124 | id: check-version 125 | uses: samuelcolvin/check-python-version@v4.1 126 | with: 127 | version_file_path: 'pyproject.toml' 128 | skip_env_check: true 129 | 130 | - name: Deploy the docs 131 | run: | 132 | poetry run mike deploy \ 133 | --update-aliases \ 134 | --push \ 135 | --branch docs-site \ 136 | ${{ steps.check-version.outputs.VERSION_MAJOR_MINOR }} latest 137 | 
-------------------------------------------------------------------------------- /.github/workflows/secret-scanning.yaml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - master 5 | pull_request: 6 | branches: 7 | - master 8 | 9 | name: Secret Leaks 10 | jobs: 11 | trufflehog: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Checkout code 15 | uses: actions/checkout@v4 16 | with: 17 | fetch-depth: 0 18 | - name: Secret Scanning 19 | uses: trufflesecurity/trufflehog@main 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # pyenv 156 | .python-version 157 | 158 | # CUDA DNN 159 | cudnn64_7.dll 160 | 161 | # Train folder 162 | train_val_set/ 163 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | ## [0.3.0] - 2024-12-08 9 | 10 | ### Added 11 | 12 | - New Global model using MobileViTV2, trained with data from +65 countries and 85k+ plates 🚀. 13 | 14 | [0.3.0]: https://github.com/ankandrew/fast-plate-ocr/compare/v0.2.0...v0.3.0 15 | 16 | ## [0.2.0] - 2024-10-14 17 | 18 | ### Added 19 | 20 | - New European model using MobileViTV2, trained on +40 countries 🚀. 21 | - Added more logging to train script. 22 | 23 | [0.2.0]: https://github.com/ankandrew/fast-plate-ocr/compare/v0.1.6...v0.2.0 24 | 25 | ## [0.1.6] - 2024-05-09 26 | 27 | ### Added 28 | 29 | - Add new Argentinian model trained with more (synthetic) data. 30 | - Add option to visualize only predictions which have low char prob. 31 | - Add onnxsim for simplifying ONNX model when exporting. 32 | 33 | [0.1.6]: https://github.com/ankandrew/fast-plate-ocr/compare/v0.1.5...v0.1.6 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 ankandrew 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Directories 2 | SRC_PATHS := fast_plate_ocr/ test/ 3 | 4 | # Tasks 5 | .PHONY: help 6 | help: 7 | @echo "Available targets:" 8 | @echo " help : Show this help message" 9 | @echo " format : Format code using Ruff format" 10 | @echo " check_format : Check code formatting with Ruff format" 11 | @echo " ruff : Run Ruff linter" 12 | @echo " pylint : Run Pylint linter" 13 | @echo " mypy : Run MyPy static type checker" 14 | @echo " lint : Run linters (Ruff, Pylint and Mypy)" 15 | @echo " test : Run tests using pytest" 16 | @echo " checks : Check format, lint, and test" 17 | @echo " clean : Clean up caches and build artifacts" 18 | 19 | .PHONY: format 20 | format: 21 | @echo "==> Sorting imports..." 22 | @# Currently, the Ruff formatter does not sort imports, see https://docs.astral.sh/ruff/formatter/#sorting-imports 23 | @poetry run ruff check --select I --fix $(SRC_PATHS) 24 | @echo "=====> Formatting code..." 25 | @poetry run ruff format $(SRC_PATHS) 26 | 27 | .PHONY: check_format 28 | check_format: 29 | @echo "=====> Checking format..." 30 | @poetry run ruff format --check --diff $(SRC_PATHS) 31 | @echo "=====> Checking imports are sorted..." 32 | @poetry run ruff check --select I --exit-non-zero-on-fix $(SRC_PATHS) 33 | 34 | .PHONY: ruff 35 | ruff: 36 | @echo "=====> Running Ruff..." 37 | @poetry run ruff check $(SRC_PATHS) 38 | 39 | .PHONY: pylint 40 | pylint: 41 | @echo "=====> Running Pylint..." 42 | @poetry run pylint $(SRC_PATHS) 43 | 44 | .PHONY: mypy 45 | mypy: 46 | @echo "=====> Running Mypy..." 47 | @poetry run mypy $(SRC_PATHS) 48 | 49 | .PHONY: lint 50 | lint: ruff pylint mypy 51 | 52 | .PHONY: test 53 | test: 54 | @echo "=====> Running tests..." 55 | @poetry run pytest test/ 56 | 57 | .PHONY: clean 58 | clean: 59 | @echo "=====> Cleaning caches..." 
60 | @poetry run ruff clean 61 | @rm -rf .cache .pytest_cache .mypy_cache build dist *.egg-info 62 | 63 | checks: format lint test 64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Fast & Lightweight License Plate OCR 2 | 3 | [![Actions status](https://github.com/ankandrew/fast-plate-ocr/actions/workflows/main.yaml/badge.svg)](https://github.com/ankandrew/fast-plate-ocr/actions) 4 | [![Keras 3](https://img.shields.io/badge/Keras-3-red?logo=keras&logoColor=red&labelColor=white)](https://keras.io/keras_3/) 5 | [![image](https://img.shields.io/pypi/v/fast-plate-ocr.svg)](https://pypi.python.org/pypi/fast-plate-ocr) 6 | [![image](https://img.shields.io/pypi/pyversions/fast-plate-ocr.svg)](https://pypi.python.org/pypi/fast-plate-ocr) 7 | [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) 8 | [![Pylint](https://img.shields.io/badge/linting-pylint-yellowgreen)](https://github.com/pylint-dev/pylint) 9 | [![Checked with mypy](http://www.mypy-lang.org/static/mypy_badge.svg)](http://mypy-lang.org/) 10 | [![ONNX Model](https://img.shields.io/badge/model-ONNX-blue?logo=onnx&logoColor=white)](https://onnx.ai/) 11 | [![Hugging Face Spaces](https://img.shields.io/badge/🤗%20Hugging%20Face-Spaces-orange)](https://huggingface.co/spaces/ankandrew/fast-alpr) 12 | [![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg)](https://ankandrew.github.io/fast-plate-ocr/) 13 | [![image](https://img.shields.io/pypi/l/fast-plate-ocr.svg)](https://pypi.python.org/pypi/fast-plate-ocr) 14 | 15 | ![Intro](https://raw.githubusercontent.com/ankandrew/fast-plate-ocr/4a7dd34c9803caada0dc50a33b59487b63dd4754/extra/demo.gif) 16 | 17 | --- 18 | 19 | ### Introduction 20 | 21 | **Lightweight** and **fast** OCR models for license plate text recognition. You can train models from scratch or use 22 | the trained models for inference. 23 | 24 | The idea is to use this after a plate object detector, since the OCR expects the cropped plates. 25 | 26 | ### Features 27 | 28 | - **Keras 3 Backend Support**: Compatible with **[TensorFlow](https://www.tensorflow.org/)**, **[JAX](https://github.com/google/jax)**, and **[PyTorch](https://pytorch.org/)** backends 🧠 29 | - **Augmentation Variety**: Diverse **augmentations** via **[Albumentations](https://albumentations.ai/)** library 🖼️ 30 | - **Efficient Execution**: **Lightweight** models that are cheap to run 💰 31 | - **ONNX Runtime Inference**: **Fast** and **optimized** inference with **[ONNX runtime](https://onnxruntime.ai/)** ⚡ 32 | - **User-Friendly CLI**: Simplified **CLI** for **training** and **validating** OCR models 🛠️ 33 | - **Model HUB**: Access to a collection of **pre-trained models** ready for inference 🌟 34 | 35 | ### Available Models 36 | 37 | | Model Name | Time b=1 (ms)[1] | Throughput (plates/second)[1] | Accuracy[2] | Dataset | 38 | |:----------------------------------------:|:----------------:|:-----------------------------:|:-----------:|:--------------------------------------------------------------------------------------------------------------:| 39 | | `argentinian-plates-cnn-model` | 2.1 | 476 | 94.05% | Non-synthetic, plates up to 2020. Dataset [arg_plate_dataset.zip](https://github.com/ankandrew/fast-plate-ocr/releases/download/arg-plates/arg_plate_dataset.zip). | 40 | | `argentinian-plates-cnn-synth-model` | 2.1 | 476 | 94.19% | Plates up to 2020 + synthetic plates. Dataset [arg_plate_dataset_plus_synth.zip](https://github.com/ankandrew/fast-plate-ocr/releases/download/arg-plates/arg_plate_dataset_plus_synth.zip). | 41 | | `european-plates-mobile-vit-v2-model` | 2.9 | 344 | 92.5%[3] | European plates (from +40 countries, trained on 40k+ plates). | 42 | | 🆕🔥 `global-plates-mobile-vit-v2-model` | 2.9 | 344 | 93.3%[4] | Worldwide plates (from +65 countries, trained on 85k+ plates). | 43 | 44 | > [!TIP] 45 | > Try `fast-plate-ocr` pre-trained models in [Hugging Face Spaces](https://huggingface.co/spaces/ankandrew/fast-alpr). 46 | 47 |
48 | Notes 49 | 50 | _[1] Inference on Mac M1 chip using CPUExecutionProvider. Utilizing CoreMLExecutionProvider speeds up inference by 5x in the CNN models._ 51 | 52 | _[2] Accuracy is what we refer to as plate_acc. See [metrics section](#model-metrics)._ 53 | 54 | _[3] For detailed accuracy for each country, see [results](https://github.com/ankandrew/fast-plate-ocr/releases/download/arg-plates/european_mobile_vit_v2_ocr_results.json) and the corresponding [val split](https://github.com/ankandrew/fast-plate-ocr/releases/download/arg-plates/european_mobile_vit_v2_ocr_val.zip) used._ 55 | 56 | _[4] For detailed accuracy for each country, see [results](https://github.com/ankandrew/fast-plate-ocr/releases/download/arg-plates/global_mobile_vit_v2_ocr_results.json)._ 57 |
59 | 60 |
61 | Reproduce results 62 | 63 | * Calculate Inference Time: 64 | 65 | ```shell 66 | pip install fast_plate_ocr 67 | ``` 68 | 69 | ```python 70 | from fast_plate_ocr import ONNXPlateRecognizer 71 | 72 | m = ONNXPlateRecognizer("argentinian-plates-cnn-model") 73 | m.benchmark() 74 | ``` 75 | * Calculate Model accuracy: 76 | 77 | ```shell 78 | pip install fast-plate-ocr[train] 79 | curl -LO https://github.com/ankandrew/fast-plate-ocr/releases/download/arg-plates/arg_cnn_ocr_config.yaml 80 | curl -LO https://github.com/ankandrew/fast-plate-ocr/releases/download/arg-plates/arg_cnn_ocr.keras 81 | curl -LO https://github.com/ankandrew/fast-plate-ocr/releases/download/arg-plates/arg_plate_benchmark.zip 82 | unzip arg_plate_benchmark.zip 83 | fast_plate_ocr valid \ 84 | -m arg_cnn_ocr.keras \ 85 | --config-file arg_cnn_ocr_config.yaml \ 86 | --annotations benchmark/annotations.csv 87 | ``` 88 | 89 |
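Note [1] above reports a ~5x speed-up for the CNN models when switching from the CPU to the CoreML execution provider. A minimal sketch to reproduce that comparison on Apple Silicon — the `providers` argument is an assumption (recent releases expose it; check the reference docs for your installed version, and on versions without it, drop the argument and let onnxruntime pick the default provider):

```python
from fast_plate_ocr import ONNXPlateRecognizer

# `providers` is an assumed keyword argument that would be forwarded to the
# underlying onnxruntime session; verify it exists in your installed version.
m = ONNXPlateRecognizer(
    "argentinian-plates-cnn-model",
    providers=["CoreMLExecutionProvider", "CPUExecutionProvider"],
)
m.benchmark()
```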
90 | 91 | ### Inference 92 | 93 | For inference, install: 94 | 95 | ```shell 96 | pip install fast_plate_ocr 97 | ``` 98 | 99 | #### Usage 100 | 101 | To predict from an image on disk: 102 | 103 | ```python 104 | from fast_plate_ocr import ONNXPlateRecognizer 105 | 106 | m = ONNXPlateRecognizer('argentinian-plates-cnn-model') 107 | print(m.run('test_plate.png')) 108 | ``` 109 | 110 |
111 | run demo 112 | 113 | ![Run demo](https://github.com/ankandrew/fast-plate-ocr/blob/ac3d110c58f62b79072e3a7af15720bb52a45e4e/extra/inference_demo.gif?raw=true) 114 | 115 |
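Since the OCR expects already-cropped plates, a typical pipeline feeds it the crop produced by a plate detector. A minimal sketch, assuming a detector has already given you the bounding box and that `run()` also accepts in-memory grayscale NumPy arrays in addition to paths (check the reference docs for your version; otherwise write the crop to disk first):

```python
import cv2

from fast_plate_ocr import ONNXPlateRecognizer

m = ONNXPlateRecognizer("argentinian-plates-cnn-model")

# Hypothetical frame and bounding box coming from your plate detector
frame = cv2.imread("full_frame.jpg")
x1, y1, x2, y2 = 100, 200, 240, 270
crop = cv2.cvtColor(frame[y1:y2, x1:x2], cv2.COLOR_BGR2GRAY)

print(m.run(crop))
```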
116 | 117 | To run model benchmark: 118 | 119 | ```python 120 | from fast_plate_ocr import ONNXPlateRecognizer 121 | 122 | m = ONNXPlateRecognizer('argentinian-plates-cnn-model') 123 | m.benchmark() 124 | ``` 125 | 126 |
127 | benchmark demo 128 | 129 | ![Benchmark demo](https://github.com/ankandrew/fast-plate-ocr/blob/ac3d110c58f62b79072e3a7af15720bb52a45e4e/extra/benchmark_demo.gif?raw=true) 130 | 131 |
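If you also need per-character probabilities, e.g. to filter out low-confidence reads, the recognizer can return them alongside the decoded text. A short sketch, assuming the `return_confidence` flag described in the reference docs:

```python
from fast_plate_ocr import ONNXPlateRecognizer

m = ONNXPlateRecognizer("argentinian-plates-cnn-model")

# `return_confidence` is assumed from the reference docs: it returns the
# per-slot character probabilities next to the decoded plate strings.
plates, confidence = m.run("test_plate.png", return_confidence=True)
print(plates)      # e.g. ['AB123CD']
print(confidence)  # array of shape (n_images, max_plate_slots)
```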
132 | 133 | Make sure to check out the [docs](https://ankandrew.github.io/fast-plate-ocr) for more information. 134 | 135 | ### CLI 136 | 137 | CLI 138 | 139 | To train or use the CLI tool, you'll need to install: 140 | 141 | ```shell 142 | pip install fast_plate_ocr[train] 143 | ``` 144 | 145 | > [!IMPORTANT] 146 | > Make sure you have installed a supported backend for Keras. 147 | 148 | #### Train Model 149 | 150 | To train the model, you will need: 151 | 152 | 1. A configuration used for the OCR model. Depending on your use case, you might have more plate slots or a different set 153 | of characters. Take a look at the config for Argentinian license plates as an example: 154 | ```yaml 155 | # Config example for Argentinian License Plates 156 | # The old license plates contain 6 slots/characters (i.e. JUH697) 157 | # and new 'Mercosur' contain 7 slots/characters (i.e. AB123CD) 158 | 159 | # Max number of plate slots supported. This represents the number of model classification heads. 160 | max_plate_slots: 7 161 | # All the possible characters for the model output. 162 | alphabet: '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_' 163 | # Padding character for plates whose length is smaller than MAX_PLATE_SLOTS. It should still be present in the alphabet. 164 | pad_char: '_' 165 | # Image height which is fed to the model. 166 | img_height: 70 167 | # Image width which is fed to the model. 168 | img_width: 140 169 | ``` 170 | 2. A labeled dataset; 171 | see [arg_plate_dataset.zip](https://github.com/ankandrew/fast-plate-ocr/releases/download/arg-plates/arg_plate_dataset.zip) 172 | for the expected data format. 173 | 3. Run the train script: 174 | ```shell 175 | # You can set the backend to either TensorFlow, JAX or PyTorch 176 | # (just make sure it is installed) 177 | KERAS_BACKEND=tensorflow fast_plate_ocr train \ 178 | --annotations path_to_the_train.csv \ 179 | --val-annotations path_to_the_val.csv \ 180 | --config-file config.yaml \ 181 | --batch-size 128 \ 182 | --epochs 750 \ 183 | --dense \ 184 | --early-stopping-patience 100 \ 185 | --reduce-lr-patience 50 186 | ``` 187 | 188 | You will probably want to change the augmentation pipeline applied to your dataset. 189 | 190 | To do this, define an Albumentations pipeline: 191 | 192 | ```python 193 | import albumentations as A 194 | 195 | transform_pipeline = A.Compose( 196 | [ 197 | # ... 198 | A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=1), 199 | A.MotionBlur(blur_limit=(3, 5), p=0.1), 200 | A.CoarseDropout(max_holes=10, max_height=4, max_width=4, p=0.3), 201 | # ... and any other augmentation ... 202 | ] 203 | ) 204 | 205 | # Export to a file (the resulting YAML can be used by the train script) 206 | A.save(transform_pipeline, "./transform_pipeline.yaml", data_format="yaml") 207 | ``` 208 | 209 | You can then train with the custom transformation pipeline using the `--augmentation-path` option. 210 | 211 | #### Visualize Augmentation 212 | 213 | It's useful to visualize the augmentation pipeline before training the model. This helps you decide whether to apply 214 | heavier or lighter augmentation, since overly aggressive augmentation can hurt the model.
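Before reaching for the CLI, you can sanity-check the exported pipeline on a single image directly in Python — a minimal sketch, where `plate.png` is a placeholder path for one of your cropped plates:

```python
import albumentations as A
import cv2
import matplotlib.pyplot as plt

# Load the pipeline previously exported with A.save(...)
transform = A.load("./transform_pipeline.yaml", data_format="yaml")

# The models are fed grayscale plates, so preview in grayscale as well
img = cv2.imread("plate.png", cv2.IMREAD_GRAYSCALE)
augmented = transform(image=img)["image"]

plt.imshow(augmented, cmap="gray")
plt.show()
```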
215 | 216 | You might want to see the augmented image next to the original, to see how much it changed: 217 | 218 | ```shell 219 | fast_plate_ocr visualize-augmentation \ 220 | --img-dir benchmark/imgs \ 221 | --columns 2 \ 222 | --show-original \ 223 | --augmentation-path './transform_pipeline.yaml' 224 | ``` 225 | 226 | You will see something like: 227 | 228 | ![Augmented Images](https://github.com/ankandrew/fast-plate-ocr/blob/ac3d110c58f62b79072e3a7af15720bb52a45e4e/extra/image_augmentation.gif?raw=true) 229 | 230 | #### Validate Model 231 | 232 | After finishing training, you can validate the model on a labeled test dataset. 233 | 234 | Example: 235 | 236 | ```shell 237 | fast_plate_ocr valid \ 238 | --model arg_cnn_ocr.keras \ 239 | --config-file arg_plate_example.yaml \ 240 | --annotations benchmark/annotations.csv 241 | ``` 242 | 243 | #### Visualize Predictions 244 | 245 | Once you finish training your model, you can view the model predictions on raw data with: 246 | 247 | ```shell 248 | fast_plate_ocr visualize-predictions \ 249 | --model arg_cnn_ocr.keras \ 250 | --img-dir benchmark/imgs \ 251 | --config-file arg_cnn_ocr_config.yaml 252 | ``` 253 | 254 | You will see something like: 255 | 256 | ![Visualize Predictions](https://github.com/ankandrew/fast-plate-ocr/blob/ac3d110c58f62b79072e3a7af15720bb52a45e4e/extra/visualize_predictions.gif?raw=true) 257 | 258 | #### Export as ONNX 259 | 260 | Exporting the Keras model to ONNX format can help speed up inference. 261 | 262 | ```shell 263 | fast_plate_ocr export-onnx \ 264 | --model arg_cnn_ocr.keras \ 265 | --output-path arg_cnn_ocr.onnx \ 266 | --opset 18 \ 267 | --config-file arg_cnn_ocr_config.yaml 268 | ``` 269 | 270 | ### Keras Backend 271 | 272 | To train the model, you can install the ML framework you like the most. **Keras 3** has 273 | support for **TensorFlow**, **JAX** and **PyTorch** backends. 274 | 275 | To change the Keras backend you can either: 276 | 277 | 1. Export the `KERAS_BACKEND` environment variable, e.g. to use JAX for training: 278 | ```shell 279 | KERAS_BACKEND=jax fast_plate_ocr train --config-file ... 280 | ``` 281 | 2. Edit your local config file at `~/.keras/keras.json`. 282 | 283 | _Note: You will probably need to install your desired framework for training._ 284 | 285 | ### Model Architecture 286 | 287 | The current model architecture is quite simple but effective. 288 | See [cnn_ocr_model](https://github.com/ankandrew/cnn-ocr-lp/blob/e59b738bad86d269c82101dfe7a3bef49b3a77c7/fast_plate_ocr/train/model/models.py#L23-L23) 289 | for implementation details. 290 | 291 | The model output consists of several heads. Each head represents the prediction of a character of the 292 | plate. If the plate consists of 7 characters at most (`max_plate_slots=7`), then the model would have 7 heads. 293 | 294 | Example of Argentinian plates: 295 | 296 | ![Model head](https://raw.githubusercontent.com/ankandrew/fast-plate-ocr/4a7dd34c9803caada0dc50a33b59487b63dd4754/extra/FCN.png) 297 | 298 | Each head will output a probability distribution over the `vocabulary` specified during training. So the output 299 | prediction for a single plate will be of shape `(max_plate_slots, vocabulary_size)`. 300 | 301 | ### Model Metrics 302 | 303 | During training, you will see the following metrics: 304 | 305 | * **plate_acc**: Computes the fraction of **license plates** that were **fully classified**. For a single plate, if the 306 | ground truth is `ABC123` and the prediction is also `ABC123`, it would score 1.
However, if the prediction was 307 | `ABD123`, it would score 0, as **not all characters** were correctly classified. 308 | 309 | * **cat_acc**: Computes the accuracy of **individual characters**, regardless of whether the whole plate was 310 | **correctly classified**. For example, if the correct label is `ABC123` and the prediction is `ABC133`, it would yield 311 | an accuracy of 83.3% (5 out of 6 characters correctly classified), rather than 0% as in plate_acc, since the plate is 312 | not completely classified correctly. 313 | 314 | * **top_3_k**: Computes how frequently the true character is included in the **top-3 predictions** 315 | (the three predictions with the highest probability). 316 | 317 | ### Contributing 318 | 319 | Contributions to the repo are greatly appreciated. Whether it's bug fixes, feature enhancements, or new models, 320 | your contributions are warmly welcomed. 321 | 322 | To start contributing or to begin development, you can follow these steps: 323 | 324 | 1. Clone the repo 325 | ```shell 326 | git clone https://github.com/ankandrew/fast-plate-ocr.git 327 | ``` 328 | 2. Install all dependencies using [Poetry](https://python-poetry.org/docs/#installation): 329 | ```shell 330 | poetry install --all-extras 331 | ``` 332 | 3. To ensure your changes pass linting and tests before submitting a PR: 333 | ```shell 334 | make checks 335 | ``` 336 | 337 | If you want to train a model and share it, we'll add it to the HUB 🚀 338 | 339 | If you're looking to contribute to the repo, some cool things are in the backlog: 340 | 341 | - [ ] Implement [STN](https://arxiv.org/abs/1506.02025) using Keras 3 (with `keras.ops`) 342 | - [ ] Implement [SVTRv2](https://arxiv.org/abs/2411.15858). 343 | - [ ] Implement CTC loss function, so we can choose that or CE loss. 344 | - [ ] Extra head for country recognition, making it configurable. 345 | -------------------------------------------------------------------------------- /config/arg_plate_example.yaml: -------------------------------------------------------------------------------- 1 | # Config example for Argentinian License Plates 2 | # The old license plates contain 6 slots/characters (i.e. JUH697) 3 | # and new 'Mercosur' contain 7 slots/characters (i.e. AB123CD) 4 | 5 | # Max number of plate slots supported. This represents the number of model classification heads. 6 | max_plate_slots: 7 7 | # All the possible characters for the model output. 8 | alphabet: '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_' 9 | # Padding character for plates whose length is smaller than MAX_PLATE_SLOTS. It should still be present in the alphabet. 10 | pad_char: '_' 11 | # Image height which is fed to the model. 12 | img_height: 70 13 | # Image width which is fed to the model. 14 | img_width: 140 15 | -------------------------------------------------------------------------------- /config/latin_plate_example.yaml: -------------------------------------------------------------------------------- 1 | # Config example for Latin plates from 70 countries 2 | 3 | # Max number of plate slots supported. This represents the number of model classification heads. 4 | max_plate_slots: 9 5 | # All the possible characters for the model output. 6 | alphabet: '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_' 7 | # Padding character for plates whose length is smaller than MAX_PLATE_SLOTS. It should still be present in the alphabet. 8 | pad_char: '_' 9 | # Image height which is fed to the model. 10 | img_height: 70 11 | # Image width which is fed to the model.
12 | img_width: 140 13 | -------------------------------------------------------------------------------- /docs/architecture.md: -------------------------------------------------------------------------------- 1 | ### ConvNet (CNN) model 2 | 3 | The current model architecture is quite simple but effective. It just consists of a few CNN layers with several output 4 | heads. 5 | See [cnn_ocr_model](https://github.com/ankandrew/cnn-ocr-lp/blob/e59b738bad86d269c82101dfe7a3bef49b3a77c7/fast_plate_ocr/train/model/models.py#L23-L23) 6 | for implementation details. 7 | 8 | The model output consists of several heads. Each head represents the prediction of a character of the 9 | plate. If the plate consists of 7 characters at most (`max_plate_slots=7`), then the model would have 7 heads. 10 | 11 | Example of Argentinian plates: 12 | 13 | ![Model head](https://raw.githubusercontent.com/ankandrew/fast-plate-ocr/4a7dd34c9803caada0dc50a33b59487b63dd4754/extra/FCN.png) 14 | 15 | Each head will output a probability distribution over the `vocabulary` specified during training. So the output 16 | prediction for a single plate will be of shape `(max_plate_slots, vocabulary_size)`. 17 | 18 | ### Model Metrics 19 | 20 | During training, you will see the following metrics: 21 | 22 | * **plate_acc**: Computes the fraction of **license plates** that were **fully classified**. For a single plate, if the 23 | ground truth is `ABC123` and the prediction is also `ABC123`, it would score 1. However, if the prediction was 24 | `ABD123`, it would score 0, as **not all characters** were correctly classified. 25 | 26 | * **cat_acc**: Computes the accuracy of **individual characters**, regardless of whether the whole plate was 27 | **correctly classified**. For example, if the correct label is `ABC123` and the prediction is `ABC133`, it would yield 28 | an accuracy of 83.3% (5 out of 6 characters correctly classified), rather than 0% as in plate_acc, since the plate is 29 | not completely classified correctly. 30 | 31 | * **top_3_k**: Computes how frequently the true character is included in the **top-3 predictions** 32 | (the three predictions with the highest probability). 33 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | Contributions are greatly appreciated. Whether it's bug fixes, feature enhancements, or new models, 2 | your contributions are warmly welcomed. 3 | 4 | To start contributing or to begin development, you can follow these steps: 5 | 6 | 1. Clone the repo 7 | ```shell 8 | git clone https://github.com/ankandrew/fast-plate-ocr.git 9 | ``` 10 | 2. Install all dependencies using [Poetry](https://python-poetry.org/docs/#installation): 11 | ```shell 12 | poetry install --all-extras 13 | ``` 14 | 3. To ensure your changes pass linting and tests before submitting a PR: 15 | ```shell 16 | make checks 17 | ``` 18 | 19 | ???+ tip 20 | If you want to train a model and share it, we'll add it to the HUB 🚀 21 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Fast & Lightweight License Plate OCR 2 | 3 | ![Intro](https://raw.githubusercontent.com/ankandrew/fast-plate-ocr/4a7dd34c9803caada0dc50a33b59487b63dd4754/extra/demo.gif) 4 | 5 | **FastPlateOCR** is a **lightweight** and **fast** OCR framework for **license plate text recognition**.
You can train 6 | models from scratch or use the trained models for inference. 7 | 8 | The idea is to use this after a plate object detector, since the OCR expects the cropped plates. 9 | 10 | ### Features 11 | 12 | - **Keras 3 Backend Support**: Compatible with **TensorFlow**, **JAX**, and **PyTorch** backends 🧠 13 | - **Augmentation Variety**: Diverse augmentations via **Albumentations** library 🖼️ 14 | - **Efficient Execution**: **Lightweight** models that are cheap to run 💰 15 | - **ONNX Runtime Inference**: **Fast** and **optimized** inference with ONNX runtime ⚡ 16 | - **User-Friendly CLI**: Simplified **CLI** for **training** and **validating** OCR models 🛠️ 17 | - **Model HUB**: Access to a collection of pre-trained models ready for inference 🌟 18 | 19 | ### Model Zoo 20 | 21 | We currently have the following available models: 22 | 23 | | Model Name | Time b=1 (ms)[1] | Throughput (plates/second)[1] | Accuracy[2] | Dataset | 24 | |:----------------------------------------:|:----------------:|:-----------------------------:|:-----------:|:--------------------------------------------------------------------------------------------------------------:| 25 | | `argentinian-plates-cnn-model` | 2.1 | 476 | 94.05% | Non-synthetic, plates up to 2020. Dataset [arg_plate_dataset.zip](https://github.com/ankandrew/fast-plate-ocr/releases/download/arg-plates/arg_plate_dataset.zip). | 26 | | `argentinian-plates-cnn-synth-model` | 2.1 | 476 | 94.19% | Plates up to 2020 + synthetic plates. Dataset [arg_plate_dataset_plus_synth.zip](https://github.com/ankandrew/fast-plate-ocr/releases/download/arg-plates/arg_plate_dataset_plus_synth.zip). | 27 | | `european-plates-mobile-vit-v2-model` | 2.9 | 344 | 92.5%[3] | European plates (from +40 countries, trained on 40k+ plates). | 28 | | 🆕🔥 `global-plates-mobile-vit-v2-model` | 2.9 | 344 | 93.3%[4] | Worldwide plates (from +65 countries, trained on 85k+ plates). | 29 | 30 | _[1] Inference on Mac M1 chip using CPUExecutionProvider. Utilizing CoreMLExecutionProvider speeds up inference 31 | by 5x in the CNN models._ 32 | 33 | _[2] Accuracy is what we refer to as plate_acc. See [metrics section](#model-metrics)._ 34 | 35 | _[3] For detailed accuracy for each country, see [results](https://github.com/ankandrew/fast-plate-ocr/releases/download/arg-plates/european_mobile_vit_v2_ocr_results.json) and 36 | the corresponding [val split](https://github.com/ankandrew/fast-plate-ocr/releases/download/arg-plates/european_mobile_vit_v2_ocr_val.zip) used._ 37 | 38 | _[4] For detailed accuracy for each country, see [results](https://github.com/ankandrew/fast-plate-ocr/releases/download/arg-plates/global_mobile_vit_v2_ocr_results.json)._ 39 | 40 |
41 | Reproduce results. 42 | 43 | Calculate Inference Time: 44 | 45 | ```shell 46 | pip install fast_plate_ocr 47 | ``` 48 | 49 | ```python 50 | from fast_plate_ocr import ONNXPlateRecognizer 51 | 52 | m = ONNXPlateRecognizer("argentinian-plates-cnn-model") 53 | m.benchmark() 54 | ``` 55 | 56 | Calculate Model accuracy: 57 | 58 | ```shell 59 | pip install fast-plate-ocr[train] 60 | curl -LO https://github.com/ankandrew/fast-plate-ocr/releases/download/arg-plates/arg_cnn_ocr_config.yaml 61 | curl -LO https://github.com/ankandrew/fast-plate-ocr/releases/download/arg-plates/arg_cnn_ocr.keras 62 | curl -LO https://github.com/ankandrew/fast-plate-ocr/releases/download/arg-plates/arg_plate_benchmark.zip 63 | unzip arg_plate_benchmark.zip 64 | fast_plate_ocr valid \ 65 | -m arg_cnn_ocr.keras \ 66 | --config-file arg_cnn_ocr_config.yaml \ 67 | --annotations benchmark/annotations.csv 68 | ``` 69 | 70 |
71 | -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | ### Inference 2 | 3 | For **inference**, install: 4 | 5 | ```shell 6 | pip install fast_plate_ocr 7 | ``` 8 | 9 | ### Train 10 | 11 | To **train** or use the **CLI tool**, you'll need to install: 12 | 13 | ```shell 14 | pip install fast_plate_ocr[train] 15 | ``` 16 | 17 | ???+ info 18 | You will probably need to **install** your desired framework for training. FastPlateOCR doesn't 19 | force you to use any specific framework. See the [Keras backend](usage.md#keras-backend) section. 20 | -------------------------------------------------------------------------------- /docs/reference.md: -------------------------------------------------------------------------------- 1 | ::: fast_plate_ocr.inference.onnx_inference 2 | -------------------------------------------------------------------------------- /docs/usage.md: -------------------------------------------------------------------------------- 1 | ### API 2 | 3 | To predict from an image on disk: 4 | 5 | ```python 6 | from fast_plate_ocr import ONNXPlateRecognizer 7 | 8 | m = ONNXPlateRecognizer('argentinian-plates-cnn-model') 9 | print(m.run('test_plate.png')) 10 | ``` 11 | 12 |
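You can also pass several images in one call. A short sketch, assuming `run()` accepts a list of image paths for batched inference, as the reference docs describe — both paths below are placeholders:

```python
from fast_plate_ocr import ONNXPlateRecognizer

m = ONNXPlateRecognizer('argentinian-plates-cnn-model')

# A list of paths is assumed to be accepted for batched inference;
# check the reference docs for your installed version.
print(m.run(['plate_1.png', 'plate_2.png']))
```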
13 | Demo 14 | 15 |
16 | Inference Demo 17 |
18 | 19 |
20 | 21 | To run model benchmark: 22 | 23 | ```python 24 | from fast_plate_ocr import ONNXPlateRecognizer 25 | 26 | m = ONNXPlateRecognizer('argentinian-plates-cnn-model') 27 | m.benchmark() 28 | ``` 29 | 30 |
31 | Demo 32 | 33 |
34 | Benchmark Demo 35 |
36 | 37 |
38 | 39 | For a full list of options, see the [Reference](reference.md). 40 | 41 | ### CLI 42 | 43 | CLI 44 | 45 | To train or use the CLI tool, you'll need to install: 46 | 47 | ```shell 48 | pip install fast_plate_ocr[train] 49 | ``` 50 | 51 | #### Train Model 52 | 53 | To train the model, you will need: 54 | 55 | 1. A configuration used for the OCR model. Depending on your use case, you might have more plate slots or a different set 56 | of characters. Take a look at the config for Argentinian license plates as an example: 57 | ```yaml 58 | # Config example for Argentinian License Plates 59 | # The old license plates contain 6 slots/characters (i.e. JUH697) 60 | # and new 'Mercosur' contain 7 slots/characters (i.e. AB123CD) 61 | 62 | # Max number of plate slots supported. This represents the number of model classification heads. 63 | max_plate_slots: 7 64 | # All the possible characters for the model output. 65 | alphabet: '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_' 66 | # Padding character for plates whose length is smaller than MAX_PLATE_SLOTS. It should still be present in the alphabet. 67 | pad_char: '_' 68 | # Image height which is fed to the model. 69 | img_height: 70 70 | # Image width which is fed to the model. 71 | img_width: 140 72 | ``` 73 | 2. A labeled dataset; 74 | see [arg_plate_dataset.zip](https://github.com/ankandrew/fast-plate-ocr/releases/download/arg-plates/arg_plate_dataset.zip) 75 | for the expected data format. 76 | 3. Run the train script: 77 | ```shell 78 | # You can set the backend to either TensorFlow, JAX or PyTorch 79 | # (just make sure it is installed) 80 | KERAS_BACKEND=tensorflow fast_plate_ocr train \ 81 | --annotations path_to_the_train.csv \ 82 | --val-annotations path_to_the_val.csv \ 83 | --config-file config.yaml \ 84 | --batch-size 128 \ 85 | --epochs 750 \ 86 | --dense \ 87 | --early-stopping-patience 100 \ 88 | --reduce-lr-patience 50 89 | ``` 90 | 91 | You will probably want to change the augmentation pipeline applied to your dataset. 92 | 93 | To do this, define an Albumentations pipeline: 94 | 95 | ```python 96 | import albumentations as A 97 | 98 | transform_pipeline = A.Compose( 99 | [ 100 | # ... 101 | A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=1), 102 | A.MotionBlur(blur_limit=(3, 5), p=0.1), 103 | A.CoarseDropout(max_holes=10, max_height=4, max_width=4, p=0.3), 104 | # ... and any other augmentation ... 105 | ] 106 | ) 107 | 108 | # Export to a file (the resulting YAML can be used by the train script) 109 | A.save(transform_pipeline, "./transform_pipeline.yaml", data_format="yaml") 110 | ``` 111 | 112 | You can then train with the custom transformation pipeline using the `--augmentation-path` option. 113 | 114 | #### Visualize Augmentation 115 | 116 | It's useful to visualize the augmentation pipeline before training the model. This helps you decide whether to apply 117 | heavier or lighter augmentation, since overly aggressive augmentation can hurt the model.
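When no `--augmentation-path` is given, training falls back to the package's built-in pipeline (the `TRAIN_AUGMENTATION` object that the train CLI itself imports). A quick way to preview what that default does to one of your plates — `plate.png` is a placeholder path:

```python
import cv2

from fast_plate_ocr.train.data.augmentation import TRAIN_AUGMENTATION

# The models are fed grayscale images, so preview in grayscale too
img = cv2.imread("plate.png", cv2.IMREAD_GRAYSCALE)
augmented = TRAIN_AUGMENTATION(image=img)["image"]
cv2.imwrite("plate_augmented.png", augmented)
```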
118 | 119 | You might want to see the augmented image next to the original, to see how much it changed: 120 | 121 | ```shell 122 | fast_plate_ocr visualize-augmentation \ 123 | --img-dir benchmark/imgs \ 124 | --columns 2 \ 125 | --show-original \ 126 | --augmentation-path './transform_pipeline.yaml' 127 | ``` 128 | 129 | You will see something like: 130 | 131 | ![Augmented Images](https://github.com/ankandrew/fast-plate-ocr/blob/ac3d110c58f62b79072e3a7af15720bb52a45e4e/extra/image_augmentation.gif?raw=true) 132 | 133 | #### Validate Model 134 | 135 | After finishing training, you can validate the model on a labeled test dataset. 136 | 137 | Example: 138 | 139 | ```shell 140 | fast_plate_ocr valid \ 141 | --model arg_cnn_ocr.keras \ 142 | --config-file arg_plate_example.yaml \ 143 | --annotations benchmark/annotations.csv 144 | ``` 145 | 146 | #### Visualize Predictions 147 | 148 | Once you finish training your model, you can view the model predictions on raw data with: 149 | 150 | ```shell 151 | fast_plate_ocr visualize-predictions \ 152 | --model arg_cnn_ocr.keras \ 153 | --img-dir benchmark/imgs \ 154 | --config-file arg_cnn_ocr_config.yaml 155 | ``` 156 | 157 | You will see something like: 158 | 159 | ![Visualize Predictions](https://github.com/ankandrew/fast-plate-ocr/blob/ac3d110c58f62b79072e3a7af15720bb52a45e4e/extra/visualize_predictions.gif?raw=true) 160 | 161 | #### Export as ONNX 162 | 163 | Exporting the Keras model to ONNX format can help speed up inference. 164 | 165 | ```shell 166 | fast_plate_ocr export-onnx \ 167 | --model arg_cnn_ocr.keras \ 168 | --output-path arg_cnn_ocr.onnx \ 169 | --opset 18 \ 170 | --config-file arg_cnn_ocr_config.yaml 171 | ``` 172 | 173 | ### Keras Backend 174 | 175 | To train the model, you can install the ML framework you like the most. **Keras 3** has 176 | support for **TensorFlow**, **JAX** and **PyTorch** backends. 177 | 178 | To change the Keras backend you can either: 179 | 180 | 1. Export the `KERAS_BACKEND` environment variable, e.g. to use JAX for training: 181 | ```shell 182 | KERAS_BACKEND=jax fast_plate_ocr train --config-file ... 183 | ``` 184 | 2. Edit your local config file at `~/.keras/keras.json`. 185 | 186 | ???+ tip 187 | **Usually training with JAX and TensorFlow is faster.** 188 | 189 | _Note: You will probably need to install your desired framework for training._ 190 | -------------------------------------------------------------------------------- /fast_plate_ocr/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Fast Plate OCR package. 3 | """ 4 | 5 | from fast_plate_ocr.inference.onnx_inference import ONNXPlateRecognizer 6 | 7 | __all__ = ["ONNXPlateRecognizer"] 8 | -------------------------------------------------------------------------------- /fast_plate_ocr/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ankandrew/fast-plate-ocr/49954297c2226ad6c488f6fbfb93169b579b815d/fast_plate_ocr/cli/__init__.py -------------------------------------------------------------------------------- /fast_plate_ocr/cli/cli.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main CLI used when training a FastPlateOCR model.
3 | """ 4 | 5 | try: 6 | import click 7 | 8 | from fast_plate_ocr.cli.onnx_converter import export_onnx 9 | from fast_plate_ocr.cli.train import train 10 | from fast_plate_ocr.cli.valid import valid 11 | from fast_plate_ocr.cli.visualize_augmentation import visualize_augmentation 12 | from fast_plate_ocr.cli.visualize_predictions import visualize_predictions 13 | except ImportError as e: 14 | raise ImportError("Make sure to 'pip install fast-plate-ocr[train]' to run this!") from e 15 | 16 | 17 | @click.group(context_settings={"max_content_width": 120}) 18 | def main_cli(): 19 | """FastPlateOCR CLI.""" 20 | 21 | 22 | main_cli.add_command(visualize_predictions) 23 | main_cli.add_command(visualize_augmentation) 24 | main_cli.add_command(valid) 25 | main_cli.add_command(train) 26 | main_cli.add_command(export_onnx) 27 | -------------------------------------------------------------------------------- /fast_plate_ocr/cli/onnx_converter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for converting Keras models to ONNX format. 3 | """ 4 | 5 | import logging 6 | import pathlib 7 | import shutil 8 | from tempfile import NamedTemporaryFile 9 | 10 | import click 11 | import numpy as np 12 | import onnx 13 | import onnxruntime as rt 14 | import onnxsim 15 | import tensorflow as tf 16 | import tf2onnx 17 | from tf2onnx import constants as tf2onnx_constants 18 | 19 | from fast_plate_ocr.common.utils import log_time_taken 20 | from fast_plate_ocr.train.model.config import load_config_from_yaml 21 | from fast_plate_ocr.train.utilities.utils import load_keras_model 22 | 23 | logging.basicConfig( 24 | level=logging.INFO, format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" 25 | ) 26 | 27 | 28 | # pylint: disable=too-many-arguments,too-many-locals 29 | 30 | 31 | @click.command(context_settings={"max_content_width": 120}) 32 | @click.option( 33 | "-m", 34 | "--model", 35 | "model_path", 36 | required=True, 37 | type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path), 38 | help="Path to the saved .keras model.", 39 | ) 40 | @click.option( 41 | "--output-path", 42 | required=True, 43 | type=str, 44 | help="Output name for ONNX model.", 45 | ) 46 | @click.option( 47 | "--simplify/--no-simplify", 48 | default=False, 49 | show_default=True, 50 | help="Simplify ONNX model using onnxsim.", 51 | ) 52 | @click.option( 53 | "--config-file", 54 | required=True, 55 | type=click.Path(exists=True, file_okay=True, path_type=pathlib.Path), 56 | help="Path pointing to the model license plate OCR config.", 57 | ) 58 | @click.option( 59 | "--opset", 60 | default=16, 61 | type=click.IntRange(max=max(tf2onnx_constants.OPSET_TO_IR_VERSION)), 62 | show_default=True, 63 | help="Opset version for ONNX.", 64 | ) 65 | def export_onnx( 66 | model_path: pathlib.Path, 67 | output_path: str, 68 | simplify: bool, 69 | config_file: pathlib.Path, 70 | opset: int, 71 | ) -> None: 72 | """ 73 | Export Keras models to ONNX format. 
74 | """ 75 | config = load_config_from_yaml(config_file) 76 | model = load_keras_model( 77 | model_path, 78 | vocab_size=config.vocabulary_size, 79 | max_plate_slots=config.max_plate_slots, 80 | ) 81 | spec = (tf.TensorSpec((None, config.img_height, config.img_width, 1), tf.uint8, name="input"),) 82 | # Convert from Keras to ONNX using tf2onnx library 83 | with NamedTemporaryFile(suffix=".onnx") as tmp: 84 | tmp_onnx = tmp.name 85 | model_proto, _ = tf2onnx.convert.from_keras( 86 | model, 87 | input_signature=spec, 88 | opset=opset, 89 | output_path=tmp_onnx, 90 | ) 91 | if simplify: 92 | logging.info("Simplifying ONNX model ...") 93 | model_simp, check = onnxsim.simplify(onnx.load(tmp_onnx)) 94 | assert check, "Simplified ONNX model could not be validated!" 95 | onnx.save(model_simp, output_path) 96 | else: 97 | shutil.copy(tmp_onnx, output_path) 98 | output_names = [n.name for n in model_proto.graph.output] 99 | x = np.random.randint(0, 256, size=(1, config.img_height, config.img_width, 1), dtype=np.uint8) 100 | # Run dummy inference and log time taken 101 | m = rt.InferenceSession(output_path) 102 | with log_time_taken("ONNX inference took:"): 103 | onnx_pred = m.run(output_names, {"input": x}) 104 | # Check if ONNX and keras have the same results 105 | if not np.allclose(model.predict(x, verbose=0), onnx_pred[0], rtol=1e-5, atol=1e-5): 106 | logging.warning("ONNX model output was not close to Keras model for the given tolerance!") 107 | logging.info("Model converted to ONNX! Saved at %s", output_path) 108 | 109 | 110 | if __name__ == "__main__": 111 | export_onnx() 112 | -------------------------------------------------------------------------------- /fast_plate_ocr/cli/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for training the License Plate OCR models. 
3 | """ 4 | 5 | import pathlib 6 | import shutil 7 | from datetime import datetime 8 | from typing import Literal 9 | 10 | import albumentations as A 11 | import click 12 | from keras.src.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard 13 | from keras.src.optimizers import Adam 14 | from torch.utils.data import DataLoader 15 | 16 | from fast_plate_ocr.cli.utils import print_params, print_train_details 17 | from fast_plate_ocr.train.data.augmentation import TRAIN_AUGMENTATION 18 | from fast_plate_ocr.train.data.dataset import LicensePlateDataset 19 | from fast_plate_ocr.train.model.config import load_config_from_yaml 20 | from fast_plate_ocr.train.model.custom import ( 21 | cat_acc_metric, 22 | cce_loss, 23 | plate_acc_metric, 24 | top_3_k_metric, 25 | ) 26 | from fast_plate_ocr.train.model.models import cnn_ocr_model 27 | 28 | # ruff: noqa: PLR0913 29 | # pylint: disable=too-many-arguments,too-many-locals 30 | 31 | 32 | @click.command(context_settings={"max_content_width": 120}) 33 | @click.option( 34 | "--dense/--no-dense", 35 | default=True, 36 | show_default=True, 37 | help="Whether to use Fully Connected layers in model head or not.", 38 | ) 39 | @click.option( 40 | "--config-file", 41 | required=True, 42 | type=click.Path(exists=True, file_okay=True, path_type=pathlib.Path), 43 | help="Path pointing to the model license plate OCR config.", 44 | ) 45 | @click.option( 46 | "--annotations", 47 | required=True, 48 | type=click.Path(exists=True, file_okay=True, path_type=pathlib.Path), 49 | help="Path pointing to the train annotations CSV file.", 50 | ) 51 | @click.option( 52 | "--val-annotations", 53 | required=True, 54 | type=click.Path(exists=True, file_okay=True, path_type=pathlib.Path), 55 | help="Path pointing to the train validation CSV file.", 56 | ) 57 | @click.option( 58 | "--augmentation-path", 59 | type=click.Path(exists=True, file_okay=True, path_type=pathlib.Path), 60 | help="YAML file pointing to the augmentation pipeline saved with Albumentations.save(...)", 61 | ) 62 | @click.option( 63 | "--lr", 64 | default=1e-3, 65 | show_default=True, 66 | type=float, 67 | help="Initial learning rate to use.", 68 | ) 69 | @click.option( 70 | "--label-smoothing", 71 | default=0.05, 72 | show_default=True, 73 | type=float, 74 | help="Amount of label smoothing to apply.", 75 | ) 76 | @click.option( 77 | "--batch-size", 78 | default=128, 79 | show_default=True, 80 | type=int, 81 | help="Batch size for training.", 82 | ) 83 | @click.option( 84 | "--num-workers", 85 | default=0, 86 | show_default=True, 87 | type=int, 88 | help="How many subprocesses to load data, used in the torch DataLoader.", 89 | ) 90 | @click.option( 91 | "--output-dir", 92 | default="./trained_models", 93 | type=click.Path(dir_okay=True, path_type=pathlib.Path), 94 | help="Output directory where model will be saved.", 95 | ) 96 | @click.option( 97 | "--epochs", 98 | default=500, 99 | show_default=True, 100 | type=int, 101 | help="Number of training epochs.", 102 | ) 103 | @click.option( 104 | "--tensorboard", 105 | "-t", 106 | is_flag=True, 107 | help="Whether to use TensorBoard visualization tool.", 108 | ) 109 | @click.option( 110 | "--tensorboard-dir", 111 | "-l", 112 | default="tensorboard_logs", 113 | show_default=True, 114 | type=click.Path(path_type=pathlib.Path), 115 | help="The path of the directory where to save the TensorBoard log files.", 116 | ) 117 | @click.option( 118 | "--early-stopping-patience", 119 | default=100, 120 | show_default=True, 121 | type=int, 122 | help="Stop 
training when 'val_plate_acc' doesn't improve for X epochs.", 123 | ) 124 | @click.option( 125 | "--reduce-lr-patience", 126 | default=60, 127 | show_default=True, 128 | type=int, 129 | help="Patience to reduce the learning rate if 'val_plate_acc' doesn't improve within X epochs.", 130 | ) 131 | @click.option( 132 | "--reduce-lr-factor", 133 | default=0.85, 134 | show_default=True, 135 | type=float, 136 | help="Reduce the learning rate by this factor when 'val_plate_acc' doesn't improve.", 137 | ) 138 | @click.option( 139 | "--activation", 140 | default="relu", 141 | show_default=True, 142 | type=str, 143 | help="Activation function to use.", 144 | ) 145 | @click.option( 146 | "--pool-layer", 147 | default="max", 148 | show_default=True, 149 | type=click.Choice(["max", "avg"]), 150 | help="Choose the pooling layer to use.", 151 | ) 152 | @click.option( 153 | "--weights-path", 154 | type=click.Path(exists=True, file_okay=True, path_type=pathlib.Path), 155 | help="Path to the pretrained model weights file.", 156 | ) 157 | @print_params(table_title="CLI Training Parameters", c1_title="Parameter", c2_title="Details") 158 | def train( 159 | dense: bool, 160 | config_file: pathlib.Path, 161 | annotations: pathlib.Path, 162 | val_annotations: pathlib.Path, 163 | augmentation_path: pathlib.Path | None, 164 | lr: float, 165 | label_smoothing: float, 166 | batch_size: int, 167 | num_workers: int, 168 | output_dir: pathlib.Path, 169 | epochs: int, 170 | tensorboard: bool, 171 | tensorboard_dir: pathlib.Path, 172 | early_stopping_patience: int, 173 | reduce_lr_patience: int, 174 | reduce_lr_factor: float, 175 | activation: str, 176 | pool_layer: Literal["max", "avg"], 177 | weights_path: pathlib.Path | None, 178 | ) -> None: 179 | """ 180 | Train the License Plate OCR model. 
181 | """ 182 | train_augmentation = ( 183 | A.load(augmentation_path, data_format="yaml") if augmentation_path else TRAIN_AUGMENTATION 184 | ) 185 | config = load_config_from_yaml(config_file) 186 | print_train_details(train_augmentation, config.model_dump()) 187 | train_torch_dataset = LicensePlateDataset( 188 | annotations_file=annotations, 189 | transform=train_augmentation, 190 | config=config, 191 | ) 192 | train_dataloader = DataLoader( 193 | train_torch_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True 194 | ) 195 | 196 | if val_annotations: 197 | val_torch_dataset = LicensePlateDataset( 198 | annotations_file=val_annotations, 199 | config=config, 200 | ) 201 | val_dataloader = DataLoader( 202 | val_torch_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False 203 | ) 204 | else: 205 | val_dataloader = None 206 | 207 | # Train 208 | model = cnn_ocr_model( 209 | h=config.img_height, 210 | w=config.img_width, 211 | dense=dense, 212 | max_plate_slots=config.max_plate_slots, 213 | vocabulary_size=config.vocabulary_size, 214 | activation=activation, 215 | pool_layer=pool_layer, 216 | ) 217 | 218 | if weights_path: 219 | model.load_weights(weights_path) 220 | 221 | model.compile( 222 | loss=cce_loss(vocabulary_size=config.vocabulary_size, label_smoothing=label_smoothing), 223 | optimizer=Adam(lr), 224 | metrics=[ 225 | cat_acc_metric( 226 | max_plate_slots=config.max_plate_slots, vocabulary_size=config.vocabulary_size 227 | ), 228 | plate_acc_metric( 229 | max_plate_slots=config.max_plate_slots, vocabulary_size=config.vocabulary_size 230 | ), 231 | top_3_k_metric(vocabulary_size=config.vocabulary_size), 232 | ], 233 | ) 234 | 235 | output_dir /= datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 236 | output_dir.mkdir(parents=True, exist_ok=True) 237 | model_file_path = output_dir / "cnn_ocr-epoch_{epoch:02d}-acc_{val_plate_acc:.3f}.keras" 238 | 239 | # Save params and config used for training 240 | shutil.copy(config_file, output_dir / "config.yaml") 241 | A.save(train_augmentation, output_dir / "train_augmentation.yaml", "yaml") 242 | 243 | callbacks = [ 244 | # Reduce the learning rate by 0.5x if 'val_plate_acc' doesn't improve within X epochs 245 | ReduceLROnPlateau( 246 | "val_plate_acc", 247 | patience=reduce_lr_patience, 248 | factor=reduce_lr_factor, 249 | min_lr=1e-6, 250 | verbose=1, 251 | ), 252 | # Stop training when 'val_plate_acc' doesn't improve for X epochs 253 | EarlyStopping( 254 | monitor="val_plate_acc", 255 | patience=early_stopping_patience, 256 | mode="max", 257 | restore_best_weights=False, 258 | verbose=1, 259 | ), 260 | # We don't use EarlyStopping restore_best_weights=True because it won't restore the best 261 | # weights when it didn't manage to EarlyStop but finished all epochs 262 | ModelCheckpoint( 263 | model_file_path, 264 | monitor="val_plate_acc", 265 | mode="max", 266 | save_best_only=True, 267 | verbose=1, 268 | ), 269 | ] 270 | 271 | if tensorboard: 272 | run_dir = tensorboard_dir / datetime.now().strftime("run_%Y%m%d_%H%M%S") 273 | run_dir.mkdir(parents=True, exist_ok=True) 274 | callbacks.append(TensorBoard(log_dir=run_dir)) 275 | 276 | model.fit(train_dataloader, epochs=epochs, validation_data=val_dataloader, callbacks=callbacks) 277 | 278 | 279 | if __name__ == "__main__": 280 | train() 281 | -------------------------------------------------------------------------------- /fast_plate_ocr/cli/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils used for the CLI 
scripts. 3 | """ 4 | 5 | import inspect 6 | import pathlib 7 | from collections.abc import Callable 8 | from functools import wraps 9 | from typing import Any 10 | 11 | import albumentations as A 12 | from rich import box 13 | from rich.console import Console 14 | from rich.pretty import Pretty 15 | from rich.table import Table 16 | 17 | 18 | def print_variables_as_table( 19 | c1_title: str, c2_title: str, title: str = "Variables Table", **kwargs: Any 20 | ) -> None: 21 | """ 22 | Prints variables in a formatted table using the rich library. 23 | 24 | Args: 25 | c1_title (str): Title of the first column. 26 | c2_title (str): Title of the second column. 27 | title (str): Title of the table. 28 | **kwargs (Any): Variable names and values to be printed. 29 | """ 30 | console = Console() 31 | console.print("\n") 32 | table = Table(title=title, show_header=True, header_style="bold blue", box=box.ROUNDED) 33 | table.add_column(c1_title, min_width=20, justify="left", style="bold") 34 | table.add_column(c2_title, min_width=60, justify="left", style="bold") 35 | 36 | for key, value in kwargs.items(): 37 | if isinstance(value, pathlib.Path): 38 | value = str(value) # noqa: PLW2901 39 | table.add_row(f"[bold]{key}[/bold]", Pretty(value)) 40 | 41 | console.print(table) 42 | 43 | 44 | def print_params( 45 | table_title: str = "Parameters Table", c1_title: str = "Variable", c2_title: str = "Value" 46 | ) -> Callable: 47 | """ 48 | A decorator that prints the parameters of a function in a formatted table 49 | using the rich library. 50 | 51 | Args: 52 | c1_title (str, optional): Title of the first column. Defaults to "Variable". 53 | c2_title (str, optional): Title of the second column. Defaults to "Value". 54 | table_title (str, optional): Title of the table. Defaults to "Parameters Table". 55 | 56 | Returns: 57 | Callable: The wrapped function with parameter printing functionality. 58 | """ 59 | 60 | def decorator(func: Callable) -> Callable: 61 | @wraps(func) 62 | def wrapper(*args: Any, **kwargs: Any) -> Any: 63 | func_signature = inspect.signature(func) 64 | bound_arguments = func_signature.bind(*args, **kwargs) 65 | bound_arguments.apply_defaults() 66 | params = dict(bound_arguments.arguments.items()) 67 | print_variables_as_table(c1_title, c2_title, table_title, **params) 68 | return func(*args, **kwargs) 69 | 70 | return wrapper 71 | 72 | return decorator 73 | 74 | 75 | def print_train_details(augmentation: A.Compose, config: dict[str, Any]) -> None: 76 | console = Console() 77 | console.print("\n") 78 | console.print("[bold blue]Augmentation Pipeline:[/bold blue]") 79 | console.print(Pretty(augmentation)) 80 | console.print("\n") 81 | console.print("[bold blue]Configuration:[/bold blue]") 82 | console.print(Pretty(config)) 83 | console.print("\n") 84 | -------------------------------------------------------------------------------- /fast_plate_ocr/cli/valid.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for validating trained OCR models. 
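Example invocation (a sketch; the model and annotations paths are placeholders):

    python -m fast_plate_ocr.cli.valid -m trained_models/cnn_ocr.keras \
        --config-file config/latin_plate_example.yaml -a data/val.csv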
3 | """ 4 | 5 | import pathlib 6 | 7 | import click 8 | from torch.utils.data import DataLoader 9 | 10 | from fast_plate_ocr.train.data.dataset import LicensePlateDataset 11 | 12 | # Custom metris / losses 13 | from fast_plate_ocr.train.model.config import load_config_from_yaml 14 | from fast_plate_ocr.train.utilities import utils 15 | 16 | 17 | @click.command(context_settings={"max_content_width": 120}) 18 | @click.option( 19 | "-m", 20 | "--model", 21 | "model_path", 22 | required=True, 23 | type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path), 24 | help="Path to the saved .keras model.", 25 | ) 26 | @click.option( 27 | "--config-file", 28 | required=True, 29 | type=click.Path(exists=True, file_okay=True, path_type=pathlib.Path), 30 | help="Path pointing to the model license plate OCR config.", 31 | ) 32 | @click.option( 33 | "-a", 34 | "--annotations", 35 | required=True, 36 | type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path), 37 | help="Annotations file used for validation.", 38 | ) 39 | @click.option( 40 | "-b", 41 | "--batch-size", 42 | default=1, 43 | show_default=True, 44 | type=int, 45 | help="Batch size.", 46 | ) 47 | def valid( 48 | model_path: pathlib.Path, 49 | config_file: pathlib.Path, 50 | annotations: pathlib.Path, 51 | batch_size: int, 52 | ) -> None: 53 | """ 54 | Validate the trained OCR model on a labeled set. 55 | """ 56 | config = load_config_from_yaml(config_file) 57 | model = utils.load_keras_model( 58 | model_path, vocab_size=config.vocabulary_size, max_plate_slots=config.max_plate_slots 59 | ) 60 | val_torch_dataset = LicensePlateDataset(annotations_file=annotations, config=config) 61 | val_dataloader = DataLoader(val_torch_dataset, batch_size=batch_size, shuffle=False) 62 | model.evaluate(val_dataloader) 63 | 64 | 65 | if __name__ == "__main__": 66 | valid() 67 | -------------------------------------------------------------------------------- /fast_plate_ocr/cli/visualize_augmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to visualize the augmented plates used during training. 
3 | """ 4 | 5 | import pathlib 6 | import random 7 | from math import ceil 8 | 9 | import albumentations as A 10 | import click 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import numpy.typing as npt 14 | 15 | from fast_plate_ocr.train.data.augmentation import TRAIN_AUGMENTATION 16 | from fast_plate_ocr.train.utilities import utils 17 | 18 | 19 | def _set_seed(seed: int | None) -> None: 20 | """Set random seed for reproducing augmentations.""" 21 | if seed: 22 | random.seed(seed) 23 | np.random.seed(seed) 24 | 25 | 26 | def load_images( 27 | img_dir: pathlib.Path, 28 | num_images: int, 29 | shuffle: bool, 30 | height: int, 31 | width: int, 32 | augmentation: A.Compose, 33 | ) -> tuple[list[npt.NDArray[np.uint8]], list[npt.NDArray[np.uint8]]]: 34 | images = utils.load_images_from_folder( 35 | img_dir, height=height, width=width, shuffle=shuffle, limit=num_images 36 | ) 37 | augmented_images = [augmentation(image=i)["image"] for i in images] 38 | return images, augmented_images 39 | 40 | 41 | def display_images( 42 | images: list[npt.NDArray[np.uint8]], 43 | augmented_images: list[npt.NDArray[np.uint8]], 44 | columns: int, 45 | rows: int, 46 | show_original: bool, 47 | ) -> None: 48 | num_images = len(images) 49 | total_plots = rows * columns 50 | num_pages = ceil(num_images / total_plots) 51 | for page in range(num_pages): 52 | _, axs = plt.subplots(rows, columns, figsize=(8, 8)) 53 | axs = axs.flatten() 54 | for i, ax in enumerate(axs): 55 | idx = page * total_plots + i 56 | if idx < num_images: 57 | if show_original: 58 | img_to_show = np.concatenate((images[idx], augmented_images[idx]), axis=1) 59 | else: 60 | img_to_show = augmented_images[idx] 61 | ax.imshow(img_to_show, cmap="gray") 62 | ax.axis("off") 63 | else: 64 | ax.axis("off") 65 | plt.tight_layout() 66 | plt.show() 67 | 68 | 69 | # ruff: noqa: PLR0913 70 | # pylint: disable=too-many-arguments,too-many-locals 71 | 72 | 73 | @click.command(context_settings={"max_content_width": 120}) 74 | @click.option( 75 | "--img-dir", 76 | "-d", 77 | required=True, 78 | type=click.Path(exists=True, dir_okay=True, path_type=pathlib.Path), 79 | help="Path to the images that will be augmented and visualized.", 80 | ) 81 | @click.option( 82 | "--num-images", 83 | "-n", 84 | type=int, 85 | default=1_000, 86 | show_default=True, 87 | help="Maximum number of images to visualize.", 88 | ) 89 | @click.option( 90 | "--augmentation-path", 91 | type=click.Path(exists=True, file_okay=True, path_type=pathlib.Path), 92 | help="YAML file pointing to the augmentation pipeline saved with Albumentations.save(...)", 93 | ) 94 | @click.option( 95 | "--shuffle", 96 | "-s", 97 | is_flag=True, 98 | default=False, 99 | help="Whether to shuffle the images before plotting them.", 100 | ) 101 | @click.option( 102 | "--columns", 103 | "-c", 104 | type=int, 105 | default=3, 106 | show_default=True, 107 | help="Number of columns in the grid layout for displaying images.", 108 | ) 109 | @click.option( 110 | "--rows", 111 | "-r", 112 | type=int, 113 | default=4, 114 | show_default=True, 115 | help="Number of rows in the grid layout for displaying images.", 116 | ) 117 | @click.option( 118 | "--height", 119 | "-h", 120 | type=int, 121 | default=70, 122 | show_default=True, 123 | help="Height to which the images will be resize.", 124 | ) 125 | @click.option( 126 | "--width", 127 | "-w", 128 | type=int, 129 | default=140, 130 | show_default=True, 131 | help="Width to which the images will be resize.", 132 | ) 133 | @click.option( 134 | "--show-original", 135 
| "-o", 136 | is_flag=True, 137 | help="Show the original image along with the augmented one.", 138 | ) 139 | @click.option( 140 | "--seed", 141 | type=int, 142 | help="Seed for reproducing augmentations.", 143 | ) 144 | def visualize_augmentation( 145 | img_dir: pathlib.Path, 146 | num_images: int, 147 | augmentation_path: pathlib.Path | None, 148 | shuffle: bool, 149 | columns: int, 150 | rows: int, 151 | height: int, 152 | width: int, 153 | seed: int | None, 154 | show_original: bool, 155 | ) -> None: 156 | """ 157 | Visualize augmentation pipeline applied to raw images. 158 | """ 159 | _set_seed(seed) 160 | aug = A.load(augmentation_path, data_format="yaml") if augmentation_path else TRAIN_AUGMENTATION 161 | images, augmented_images = load_images(img_dir, num_images, shuffle, height, width, aug) 162 | display_images(images, augmented_images, columns, rows, show_original) 163 | 164 | 165 | if __name__ == "__main__": 166 | # pylint: disable = no-value-for-parameter 167 | visualize_augmentation() 168 | -------------------------------------------------------------------------------- /fast_plate_ocr/cli/visualize_predictions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for displaying an image with the OCR model predictions. 3 | """ 4 | 5 | import logging 6 | import pathlib 7 | 8 | import click 9 | import cv2 10 | import keras 11 | import numpy as np 12 | 13 | from fast_plate_ocr.train.model.config import load_config_from_yaml 14 | from fast_plate_ocr.train.utilities import utils 15 | from fast_plate_ocr.train.utilities.utils import postprocess_model_output 16 | 17 | logging.basicConfig( 18 | level=logging.INFO, format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" 19 | ) 20 | 21 | 22 | @click.command(context_settings={"max_content_width": 120}) 23 | @click.option( 24 | "-m", 25 | "--model", 26 | "model_path", 27 | required=True, 28 | type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=pathlib.Path), 29 | help="Path to the saved .keras model.", 30 | ) 31 | @click.option( 32 | "--config-file", 33 | required=True, 34 | type=click.Path(exists=True, file_okay=True, path_type=pathlib.Path), 35 | help="Path pointing to the model license plate OCR config.", 36 | ) 37 | @click.option( 38 | "-d", 39 | "--img-dir", 40 | required=True, 41 | type=click.Path(exists=True, dir_okay=True, file_okay=False, path_type=pathlib.Path), 42 | help="Directory containing the images to make predictions from.", 43 | ) 44 | @click.option( 45 | "-l", 46 | "--low-conf-thresh", 47 | type=float, 48 | default=0.35, 49 | show_default=True, 50 | help="Threshold for displaying low confidence characters.", 51 | ) 52 | @click.option( 53 | "-f", 54 | "--filter-conf", 55 | type=float, 56 | help="Display plates that any of the plate characters are below this number.", 57 | ) 58 | def visualize_predictions( 59 | model_path: pathlib.Path, 60 | config_file: pathlib.Path, 61 | img_dir: pathlib.Path, 62 | low_conf_thresh: float, 63 | filter_conf: float | None, 64 | ): 65 | """ 66 | Visualize OCR model predictions on unlabeled data. 
67 | """ 68 | config = load_config_from_yaml(config_file) 69 | model = utils.load_keras_model( 70 | model_path, vocab_size=config.vocabulary_size, max_plate_slots=config.max_plate_slots 71 | ) 72 | images = utils.load_images_from_folder( 73 | img_dir, width=config.img_width, height=config.img_height 74 | ) 75 | for image in images: 76 | x = np.expand_dims(image, 0) 77 | prediction = model(x, training=False) 78 | prediction = keras.ops.stop_gradient(prediction).numpy() 79 | plate, probs = postprocess_model_output( 80 | prediction=prediction, 81 | alphabet=config.alphabet, 82 | max_plate_slots=config.max_plate_slots, 83 | vocab_size=config.vocabulary_size, 84 | ) 85 | if not filter_conf or (filter_conf and np.any(probs < filter_conf)): 86 | utils.display_predictions( 87 | image=image, plate=plate, probs=probs, low_conf_thresh=low_conf_thresh 88 | ) 89 | cv2.destroyAllWindows() 90 | 91 | 92 | if __name__ == "__main__": 93 | visualize_predictions() 94 | -------------------------------------------------------------------------------- /fast_plate_ocr/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ankandrew/fast-plate-ocr/49954297c2226ad6c488f6fbfb93169b579b815d/fast_plate_ocr/common/__init__.py -------------------------------------------------------------------------------- /fast_plate_ocr/common/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common utilities used across the package. 3 | """ 4 | 5 | import logging 6 | import time 7 | from collections.abc import Callable, Iterator 8 | from contextlib import contextmanager 9 | 10 | 11 | @contextmanager 12 | def log_time_taken(process_name: str) -> Iterator[None]: 13 | """ 14 | A concise context manager to time code snippets and log the result. 15 | 16 | Usage: 17 | with log_time_taken("process_name"): 18 | # Code snippet to be timed 19 | 20 | :param process_name: Name of the process being timed. 21 | """ 22 | time_start: float = time.perf_counter() 23 | try: 24 | yield 25 | finally: 26 | time_end: float = time.perf_counter() 27 | time_elapsed: float = time_end - time_start 28 | logger = logging.getLogger(__name__) 29 | logger.info("Computation time of '%s' = %.3fms", process_name, 1_000 * time_elapsed) 30 | 31 | 32 | @contextmanager 33 | def measure_time() -> Iterator[Callable[[], float]]: 34 | """ 35 | A context manager for measuring execution time (in milliseconds) within its code block. 36 | 37 | usage: 38 | with code_timer() as timer: 39 | # Code snippet to be timed 40 | print(f"Code took: {timer()} seconds") 41 | """ 42 | start_time = end_time = time.perf_counter() 43 | yield lambda: (end_time - start_time) * 1_000 44 | end_time = time.perf_counter() 45 | -------------------------------------------------------------------------------- /fast_plate_ocr/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ankandrew/fast-plate-ocr/49954297c2226ad6c488f6fbfb93169b579b815d/fast_plate_ocr/inference/__init__.py -------------------------------------------------------------------------------- /fast_plate_ocr/inference/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model config reading/parsing for doing inference. 
3 | """ 4 | 5 | from os import PathLike 6 | from typing import TypedDict 7 | 8 | import yaml 9 | 10 | # pylint: disable=duplicate-code 11 | 12 | 13 | class PlateOCRConfig(TypedDict): 14 | """ 15 | Plate OCR Config used for inference. 16 | 17 | This has the same attributes as the one used in the training Pydantic BaseModel. We use this to 18 | avoid having Pydantic as a required dependency of the minimal package install. 19 | """ 20 | 21 | max_plate_slots: int 22 | """ 23 | Max number of plate slots supported. This represents the number of model classification heads. 24 | """ 25 | alphabet: str 26 | """ 27 | All the possible character set for the model output. 28 | """ 29 | pad_char: str 30 | """ 31 | Padding character for plates which length is smaller than MAX_PLATE_SLOTS. 32 | """ 33 | img_height: int 34 | """ 35 | Image height which is fed to the model. 36 | """ 37 | img_width: int 38 | """ 39 | Image width which is fed to the model. 40 | """ 41 | 42 | 43 | def load_config_from_yaml(yaml_file_path: str | PathLike[str]) -> PlateOCRConfig: 44 | """ 45 | Read and parse a yaml containing the Plate OCR config. 46 | 47 | Note: This is currently not using Pydantic for parsing/validating to avoid adding it a python 48 | dependency as part of the minimal installation. 49 | """ 50 | with open(yaml_file_path, encoding="utf-8") as f_in: 51 | config: PlateOCRConfig = yaml.safe_load(f_in) 52 | return config 53 | -------------------------------------------------------------------------------- /fast_plate_ocr/inference/hub.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities function used for doing inference with the OCR models. 3 | """ 4 | 5 | import logging 6 | import pathlib 7 | import shutil 8 | import urllib.request 9 | from http import HTTPStatus 10 | from typing import Literal 11 | 12 | from tqdm.asyncio import tqdm 13 | 14 | from fast_plate_ocr.inference.utils import safe_write 15 | 16 | BASE_URL: str = "https://github.com/ankandrew/cnn-ocr-lp/releases/download" 17 | OcrModel = Literal[ 18 | "argentinian-plates-cnn-model", 19 | "argentinian-plates-cnn-synth-model", 20 | "european-plates-mobile-vit-v2-model", 21 | "global-plates-mobile-vit-v2-model", 22 | ] 23 | 24 | 25 | AVAILABLE_ONNX_MODELS: dict[OcrModel, tuple[str, str]] = { 26 | "argentinian-plates-cnn-model": ( 27 | f"{BASE_URL}/arg-plates/arg_cnn_ocr.onnx", 28 | f"{BASE_URL}/arg-plates/arg_cnn_ocr_config.yaml", 29 | ), 30 | "argentinian-plates-cnn-synth-model": ( 31 | f"{BASE_URL}/arg-plates/arg_cnn_ocr_synth.onnx", 32 | f"{BASE_URL}/arg-plates/arg_cnn_ocr_config.yaml", 33 | ), 34 | "european-plates-mobile-vit-v2-model": ( 35 | f"{BASE_URL}/arg-plates/european_mobile_vit_v2_ocr.onnx", 36 | f"{BASE_URL}/arg-plates/european_mobile_vit_v2_ocr_config.yaml", 37 | ), 38 | "global-plates-mobile-vit-v2-model": ( 39 | f"{BASE_URL}/arg-plates/global_mobile_vit_v2_ocr.onnx", 40 | f"{BASE_URL}/arg-plates/global_mobile_vit_v2_ocr_config.yaml", 41 | ), 42 | } 43 | """Available ONNX models for doing inference.""" 44 | 45 | MODEL_CACHE_DIR: pathlib.Path = pathlib.Path.home() / ".cache" / "fast-plate-ocr" 46 | """Default location where models will be stored.""" 47 | 48 | 49 | def _download_with_progress(url: str, filename: pathlib.Path) -> None: 50 | """ 51 | Download utility function with progress bar. 52 | 53 | :param url: URL of the model to download. 54 | :param filename: Where to save the OCR model. 
55 | """ 56 | with urllib.request.urlopen(url) as response, safe_write(filename, mode="wb") as out_file: 57 | if response.getcode() != HTTPStatus.OK: 58 | raise ValueError(f"Failed to download file from {url}. Status code: {response.status}") 59 | 60 | file_size = int(response.headers.get("Content-Length", 0)) 61 | desc = f"Downloading {filename.name}" 62 | 63 | with tqdm.wrapattr(out_file, "write", total=file_size, desc=desc) as f_out: 64 | shutil.copyfileobj(response, f_out) 65 | 66 | 67 | def download_model( 68 | model_name: OcrModel, 69 | save_directory: pathlib.Path | None = None, 70 | force_download: bool = False, 71 | ) -> tuple[pathlib.Path, pathlib.Path]: 72 | """ 73 | Download an OCR model and the config to a given directory. 74 | 75 | :param model_name: Which model to download. 76 | :param save_directory: Directory to save the OCR model. It should point to a folder. If not 77 | supplied, this will point to '~/.cache/' 78 | :param force_download: Force and download the model if it already exists in `save_directory`. 79 | :return: A tuple consisting of (model_downloaded_path, config_downloaded_path). 80 | """ 81 | if model_name not in AVAILABLE_ONNX_MODELS: 82 | available_models = ", ".join(AVAILABLE_ONNX_MODELS.keys()) 83 | raise ValueError(f"Unknown model {model_name}. Use one of [{available_models}]") 84 | 85 | if save_directory is None: 86 | save_directory = MODEL_CACHE_DIR / model_name 87 | elif save_directory.is_file(): 88 | raise ValueError(f"Expected a directory, but got {save_directory}") 89 | 90 | save_directory.mkdir(parents=True, exist_ok=True) 91 | 92 | model_url, config_url = AVAILABLE_ONNX_MODELS[model_name] 93 | model_filename = save_directory / model_url.split("/")[-1] 94 | config_filename = save_directory / config_url.split("/")[-1] 95 | 96 | if not force_download and model_filename.is_file() and config_filename.is_file(): 97 | logging.info( 98 | "Skipping download of '%s' model, already exists at %s", 99 | model_name, 100 | save_directory, 101 | ) 102 | return model_filename, config_filename 103 | 104 | # Download the model if not present or if we want to force the download 105 | if force_download or not model_filename.is_file(): 106 | logging.info("Downloading model to %s", model_filename) 107 | _download_with_progress(url=model_url, filename=model_filename) 108 | 109 | # Same for the config 110 | if force_download or not config_filename.is_file(): 111 | logging.info("Downloading config to %s", config_filename) 112 | _download_with_progress(url=config_url, filename=config_filename) 113 | 114 | return model_filename, config_filename 115 | -------------------------------------------------------------------------------- /fast_plate_ocr/inference/onnx_inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | ONNX inference module. 
3 | """ 4 | 5 | import logging 6 | import os 7 | import pathlib 8 | from collections.abc import Sequence 9 | from typing import Literal 10 | 11 | import numpy as np 12 | import numpy.typing as npt 13 | import onnxruntime as ort 14 | from rich.console import Console 15 | from rich.panel import Panel 16 | from rich.table import Table 17 | from rich.text import Text 18 | 19 | from fast_plate_ocr.common.utils import measure_time 20 | from fast_plate_ocr.inference import hub 21 | from fast_plate_ocr.inference.config import load_config_from_yaml 22 | from fast_plate_ocr.inference.hub import OcrModel 23 | from fast_plate_ocr.inference.process import postprocess_output, preprocess_image, read_plate_image 24 | 25 | 26 | def _load_image_from_source( 27 | source: str | list[str] | npt.NDArray | list[npt.NDArray], 28 | ) -> npt.NDArray | list[npt.NDArray]: 29 | """ 30 | Loads an image from a given source. 31 | 32 | :param source: Path to the input image file, list of paths, or numpy array representing one or 33 | multiple images. 34 | :return: Numpy array representing the input image(s) or a list of numpy arrays. 35 | """ 36 | if isinstance(source, str): 37 | # Shape returned (H, W) 38 | return read_plate_image(source) 39 | 40 | if isinstance(source, list): 41 | # Are image paths 42 | if all(isinstance(s, str) for s in source): 43 | # List returned with array item of shape (H, W) 44 | return [read_plate_image(i) for i in source] # type: ignore[arg-type] 45 | # Are list of numpy arrays 46 | if all(isinstance(a, np.ndarray) for a in source): 47 | # List returned with array item of shape (H, W) 48 | return source # type: ignore[return-value] 49 | raise ValueError("Expected source to be a list of `str` or `np.ndarray`!") 50 | 51 | if isinstance(source, np.ndarray): 52 | # Squeeze grayscale channel dimension if supplied 53 | source = source.squeeze() 54 | if source.ndim != 2: 55 | raise ValueError("Expected source array to be of shape (H, W) or (H, W, 1).") 56 | # Shape returned (H, W) 57 | return source 58 | 59 | raise ValueError("Unsupported input type. Only file path or numpy array is supported.") 60 | 61 | 62 | class ONNXPlateRecognizer: 63 | """ 64 | ONNX inference class for performing license plates OCR. 65 | """ 66 | 67 | def __init__( 68 | self, 69 | hub_ocr_model: OcrModel | None = None, 70 | device: Literal["cuda", "cpu", "auto"] = "auto", 71 | providers: Sequence[str | tuple[str, dict]] | None = None, 72 | sess_options: ort.SessionOptions | None = None, 73 | model_path: str | os.PathLike[str] | None = None, 74 | config_path: str | os.PathLike[str] | None = None, 75 | force_download: bool = False, 76 | ) -> None: 77 | """ 78 | Initializes the ONNXPlateRecognizer with the specified OCR model and inference device. 79 | 80 | The current OCR models available from the HUB are: 81 | 82 | - `argentinian-plates-cnn-model`: OCR for Argentinian license plates. Uses fully conv 83 | architecture. 84 | - `argentinian-plates-cnn-synth-model`: OCR for Argentinian license plates trained with 85 | synthetic and real data. Uses fully conv architecture. 86 | - `european-plates-mobile-vit-v2-model`: OCR for European license plates. Uses MobileVIT-2 87 | for the backbone. 88 | - `global-plates-mobile-vit-v2-model`: OCR for Global license plates (+65 countries). 89 | Uses MobileVIT-2 for the backbone. 90 | 91 | Args: 92 | hub_ocr_model: Name of the OCR model to use from the HUB. 93 | device: Device type for inference. Should be one of ('cpu', 'cuda', 'auto'). 
In 94 | 'auto' mode, the device will be deduced from 95 | `onnxruntime.get_available_providers()`. 96 | providers: Optional sequence of providers in order of decreasing precedence. If not 97 | specified, all available providers are used based on the device argument. 98 | sess_options: Advanced session options for ONNX Runtime. 99 | model_path: Path to ONNX model file to use (in case you want to use a custom one). 100 | config_path: Path to config file to use (in case you want to use a custom one). 101 | force_download: Force the download, even if the model already exists. 102 | Returns: 103 | None. 104 | """ 105 | self.logger = logging.getLogger(__name__) 106 | 107 | if providers is not None: 108 | self.providers = providers 109 | self.logger.info("Using custom providers: %s", providers) 110 | else: 111 | if device == "cuda": 112 | self.providers = ["CUDAExecutionProvider"] 113 | elif device == "cpu": 114 | self.providers = ["CPUExecutionProvider"] 115 | elif device == "auto": 116 | self.providers = ort.get_available_providers() 117 | else: 118 | raise ValueError( 119 | f"Device should be one of ('cpu', 'cuda', 'auto'). Got '{device}'." 120 | ) 121 | 122 | self.logger.info("Using device '%s' with providers: %s", device, self.providers) 123 | 124 | if model_path and config_path: 125 | model_path = pathlib.Path(model_path) 126 | config_path = pathlib.Path(config_path) 127 | if not model_path.exists() or not config_path.exists(): 128 | raise FileNotFoundError("Missing model/config file!") 129 | self.model_name = model_path.stem 130 | elif hub_ocr_model: 131 | self.model_name = hub_ocr_model 132 | model_path, config_path = hub.download_model( 133 | model_name=hub_ocr_model, force_download=force_download 134 | ) 135 | else: 136 | raise ValueError( 137 | "Either provide a model from the HUB or a custom model_path and config_path" 138 | ) 139 | 140 | self.config = load_config_from_yaml(config_path) 141 | self.model = ort.InferenceSession( 142 | model_path, providers=self.providers, sess_options=sess_options 143 | ) 144 | self.logger.info("Using ONNX Runtime with %s.", self.providers) 145 | 146 | def benchmark(self, n_iter: int = 10_000, include_processing: bool = False) -> None: 147 | """ 148 | Benchmark time taken to run the OCR model. This reports the average inference time and the 149 | throughput in plates per second. 150 | 151 | Args: 152 | n_iter: The number of iterations to run the benchmark. This determines how many times 153 | the inference will be executed to compute the average performance metrics. 154 | include_processing: Indicates whether the benchmark should include preprocessing and 155 | postprocessing times in the measurement.
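Example (a sketch):

    m = ONNXPlateRecognizer("argentinian-plates-cnn-model")
    m.benchmark(n_iter=1_000, include_processing=True)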
156 | """ 157 | cum_time = 0.0 158 | x = np.random.randint( 159 | 0, 256, size=(1, self.config["img_height"], self.config["img_width"], 1), dtype=np.uint8 160 | ) 161 | for _ in range(n_iter): 162 | with measure_time() as time_taken: 163 | if include_processing: 164 | self.run(x) 165 | else: 166 | self.model.run(None, {"input": x}) 167 | cum_time += time_taken() 168 | 169 | avg_time = (cum_time / n_iter) if n_iter > 0 else 0.0 170 | avg_pps = (1_000 / avg_time) if n_iter > 0 else 0.0 171 | 172 | console = Console() 173 | model_info = Panel( 174 | Text(f"Model: {self.model_name}\nProviders: {self.providers}", style="bold green"), 175 | title="Model Information", 176 | border_style="bright_blue", 177 | expand=False, 178 | ) 179 | console.print(model_info) 180 | table = Table(title=f"Benchmark for '{self.model_name}' Model", border_style="bright_blue") 181 | table.add_column("Metric", justify="center", style="cyan", no_wrap=True) 182 | table.add_column("Value", justify="center", style="magenta") 183 | table.add_row("Number of Iterations", str(n_iter)) 184 | table.add_row("Average Time (ms)", f"{avg_time:.4f}") 185 | table.add_row("Plates Per Second (PPS)", f"{avg_pps:.4f}") 186 | console.print(table) 187 | 188 | def run( 189 | self, 190 | source: str | list[str] | npt.NDArray | list[npt.NDArray], 191 | return_confidence: bool = False, 192 | ) -> tuple[list[str], npt.NDArray] | list[str]: 193 | """ 194 | Performs OCR to recognize license plate characters from an image or a list of images. 195 | 196 | Args: 197 | source: The path(s) to the image(s), a numpy array representing an image or a list 198 | of NumPy arrays. If a numpy array is provided, it is expected to already be in 199 | grayscale format, with shape `(H, W) `or `(H, W, 1)`. A list of numpy arrays with 200 | different image sizes may also be provided. 201 | return_confidence: Whether to return confidence scores along with plate predictions. 202 | 203 | Returns: 204 | A list of plates for each input image. If `return_confidence` is True, a numpy 205 | array is returned with the shape `(N, plate_slots)`, where N is the batch size and 206 | each plate slot is the confidence for the recognized license plate character. 207 | """ 208 | x = _load_image_from_source(source) 209 | # Preprocess 210 | x = preprocess_image(x, self.config["img_height"], self.config["img_width"]) 211 | # Run model 212 | y: list[npt.NDArray] = self.model.run(None, {"input": x}) 213 | # Postprocess model output 214 | return postprocess_output( 215 | y[0], 216 | self.config["max_plate_slots"], 217 | self.config["alphabet"], 218 | return_confidence=return_confidence, 219 | ) 220 | -------------------------------------------------------------------------------- /fast_plate_ocr/inference/process.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for processing model input/output. 3 | """ 4 | 5 | import os 6 | 7 | import cv2 8 | import numpy as np 9 | import numpy.typing as npt 10 | 11 | 12 | def read_plate_image(image_path: str) -> npt.NDArray: 13 | """ 14 | Read image from disk as a grayscale image. 15 | 16 | :param image_path: The path to the license plate image. 17 | :return: The image as a NumPy array. 
18 | """ 19 | if not os.path.exists(image_path): 20 | raise ValueError(f"{image_path} file doesn't exist!") 21 | return cv2.imread(image_path, cv2.IMREAD_GRAYSCALE) 22 | 23 | 24 | def preprocess_image( 25 | image: npt.NDArray | list[npt.NDArray], img_height: int, img_width: int 26 | ) -> npt.NDArray: 27 | """ 28 | Preprocess the image(s), so they're ready to be fed to the model. 29 | 30 | Note: We don't normalize the pixel values between [0, 1] here, because that the model first 31 | layer does that. 32 | 33 | :param image: The image(s) contained in a NumPy array. 34 | :param img_height: The desired height of the resized image. 35 | :param img_width: The desired width of the resized image. 36 | :return: A numpy array with shape (N, H, W, 1). 37 | """ 38 | # Add batch dimension: (H, W) -> (1, H, W) 39 | if isinstance(image, np.ndarray): 40 | image = np.expand_dims(image, axis=0) 41 | 42 | imgs = np.array( 43 | [ 44 | cv2.resize(im.squeeze(), (img_width, img_height), interpolation=cv2.INTER_LINEAR) 45 | for im in image 46 | ] 47 | ) 48 | # Add channel dimension 49 | imgs = np.expand_dims(imgs, axis=-1) 50 | return imgs 51 | 52 | 53 | def postprocess_output( 54 | model_output: npt.NDArray, 55 | max_plate_slots: int, 56 | model_alphabet: str, 57 | return_confidence: bool = False, 58 | ) -> tuple[list[str], npt.NDArray] | list[str]: 59 | """ 60 | Post-processes model output and return license plate string, and optionally the probabilities. 61 | 62 | :param model_output: Output from the model containing predictions. 63 | :param max_plate_slots: Maximum number of characters in a license plate. 64 | :param model_alphabet: Alphabet used by the model for character encoding. 65 | :param return_confidence: Flag to indicate whether to return confidence scores along with plate 66 | predictions. 67 | :return: Decoded license plate characters as a list, optionally with confidence scores. The 68 | confidence scores have shape (N, max_plate_slots) where N is the batch size. 69 | """ 70 | predictions = model_output.reshape((-1, max_plate_slots, len(model_alphabet))) 71 | prediction_indices = np.argmax(predictions, axis=-1) 72 | alphabet_array = np.array(list(model_alphabet)) 73 | plate_chars = alphabet_array[prediction_indices] 74 | plates: list[str] = np.apply_along_axis("".join, 1, plate_chars).tolist() 75 | if return_confidence: 76 | probs = np.max(predictions, axis=-1) 77 | return plates, probs 78 | return plates 79 | -------------------------------------------------------------------------------- /fast_plate_ocr/inference/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities used around the inference package. 3 | """ 4 | 5 | import os 6 | from collections.abc import Iterator 7 | from contextlib import contextmanager 8 | from pathlib import Path 9 | from typing import IO, Any 10 | 11 | 12 | @contextmanager 13 | def safe_write( 14 | file: str | os.PathLike[str], 15 | mode: str = "wb", 16 | encoding: str | None = None, 17 | **kwargs: Any, 18 | ) -> Iterator[IO]: 19 | """ 20 | Context manager for safe file writing. 21 | 22 | Opens the specified file for writing and yields a file object. 23 | If an exception occurs during writing, the file is removed before raising the exception. 
24 | """ 25 | try: 26 | with open(file, mode, encoding=encoding, **kwargs) as f: 27 | yield f 28 | except Exception as e: 29 | Path(file).unlink(missing_ok=True) 30 | raise e 31 | -------------------------------------------------------------------------------- /fast_plate_ocr/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ankandrew/fast-plate-ocr/49954297c2226ad6c488f6fbfb93169b579b815d/fast_plate_ocr/py.typed -------------------------------------------------------------------------------- /fast_plate_ocr/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ankandrew/fast-plate-ocr/49954297c2226ad6c488f6fbfb93169b579b815d/fast_plate_ocr/train/__init__.py -------------------------------------------------------------------------------- /fast_plate_ocr/train/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ankandrew/fast-plate-ocr/49954297c2226ad6c488f6fbfb93169b579b815d/fast_plate_ocr/train/data/__init__.py -------------------------------------------------------------------------------- /fast_plate_ocr/train/data/augmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Augmentations used for training the OCR model. 3 | """ 4 | 5 | import albumentations as A 6 | import cv2 7 | 8 | BORDER_COLOR_BLACK: tuple[int, int, int] = (0, 0, 0) 9 | 10 | TRAIN_AUGMENTATION = A.Compose( 11 | [ 12 | A.ShiftScaleRotate( 13 | shift_limit=0.06, 14 | scale_limit=0.1, 15 | rotate_limit=9, 16 | border_mode=cv2.BORDER_CONSTANT, 17 | value=BORDER_COLOR_BLACK, 18 | p=1, 19 | ), 20 | A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=1), 21 | A.MotionBlur(blur_limit=(3, 5), p=0.1), 22 | A.OneOf( 23 | [ 24 | A.CoarseDropout(max_holes=10, max_height=4, max_width=4, p=0.3), 25 | A.PixelDropout(dropout_prob=0.01, p=0.2), 26 | ], 27 | p=0.7, 28 | ), 29 | ] 30 | ) 31 | """Training augmentations recipe.""" 32 | -------------------------------------------------------------------------------- /fast_plate_ocr/train/data/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset module. 3 | """ 4 | 5 | import os 6 | from os import PathLike 7 | 8 | import albumentations as A 9 | import numpy.typing as npt 10 | import pandas as pd 11 | from torch.utils.data import Dataset 12 | 13 | from fast_plate_ocr.train.model.config import PlateOCRConfig 14 | from fast_plate_ocr.train.utilities import utils 15 | 16 | 17 | class LicensePlateDataset(Dataset): 18 | def __init__( 19 | self, 20 | annotations_file: str | PathLike[str], 21 | config: PlateOCRConfig, 22 | transform: A.Compose | None = None, 23 | ) -> None: 24 | annotations = pd.read_csv(annotations_file) 25 | annotations["image_path"] = ( 26 | os.path.dirname(os.path.realpath(annotations_file)) + os.sep + annotations["image_path"] 27 | ) 28 | assert ( 29 | annotations["plate_text"].str.len() <= config.max_plate_slots 30 | ).all(), "Plates are longer than max_plate_slots specified param. Change the parameter." 
31 | self.annotations = annotations.to_numpy() 32 | self.config = config 33 | self.transform = transform 34 | 35 | def __len__(self) -> int: 36 | return self.annotations.shape[0] 37 | 38 | def __getitem__(self, idx) -> tuple[npt.NDArray, npt.NDArray]: 39 | image_path, plate_text = self.annotations[idx] 40 | x = utils.read_plate_image( 41 | image_path=image_path, 42 | img_height=self.config.img_height, 43 | img_width=self.config.img_width, 44 | ) 45 | y = utils.target_transform( 46 | plate_text=plate_text, 47 | max_plate_slots=self.config.max_plate_slots, 48 | alphabet=self.config.alphabet, 49 | pad_char=self.config.pad_char, 50 | ) 51 | if self.transform: 52 | x = self.transform(image=x)["image"] 53 | return x, y 54 | -------------------------------------------------------------------------------- /fast_plate_ocr/train/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ankandrew/fast-plate-ocr/49954297c2226ad6c488f6fbfb93169b579b815d/fast_plate_ocr/train/model/__init__.py -------------------------------------------------------------------------------- /fast_plate_ocr/train/model/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Config values used throughout the code. 3 | """ 4 | 5 | from os import PathLike 6 | 7 | import yaml 8 | from pydantic import BaseModel, computed_field, model_validator 9 | 10 | 11 | class PlateOCRConfig(BaseModel, extra="forbid", frozen=True): 12 | """ 13 | Model License Plate OCR config. 14 | """ 15 | 16 | max_plate_slots: int 17 | """ 18 | Max number of plate slots supported. This represents the number of model classification heads. 19 | """ 20 | 21 | alphabet: str 22 | """ 23 | All the possible character set for the model output. 24 | """ 25 | pad_char: str 26 | """ 27 | Padding character for plates which length is smaller than MAX_PLATE_SLOTS. 28 | """ 29 | img_height: int 30 | """ 31 | Image height which is fed to the model. 32 | """ 33 | img_width: int 34 | """ 35 | Image width which is fed to the model. 36 | """ 37 | 38 | @computed_field  # type: ignore[misc] 39 | @property 40 | def vocabulary_size(self) -> int: 41 | return len(self.alphabet) 42 | 43 | @model_validator(mode="after") 44 | def check_pad_in_alphabet(self) -> "PlateOCRConfig": 45 | if self.pad_char not in self.alphabet: 46 | raise ValueError("Pad character must be present in model alphabet.") 47 | return self 48 | 49 | 50 | def load_config_from_yaml(yaml_file_path: str | PathLike[str]) -> PlateOCRConfig: 51 | """Read and parse a yaml containing the Plate OCR config.""" 52 | with open(yaml_file_path, encoding="utf-8") as f_in: 53 | yaml_content = yaml.safe_load(f_in) 54 | config = PlateOCRConfig(**yaml_content) 55 | return config 56 | -------------------------------------------------------------------------------- /fast_plate_ocr/train/model/custom.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom metrics and loss functions. 3 | """ 4 | 5 | from keras import losses, metrics, ops 6 | 7 | 8 | def cat_acc_metric(max_plate_slots: int, vocabulary_size: int): 9 | """ 10 | Categorical accuracy metric. 11 | """ 12 | 13 | def cat_acc(y_true, y_pred): 14 | """ 15 | This is simply the CategoricalAccuracy for multi-class label problems. For example, if the 16 | correct label is ABC123 and ABC133 is predicted, it will not score 0% like plate_acc 17 | (which requires the whole plate to be classified correctly), but 83.3% (5/6).
18 | """ 19 | y_true = ops.reshape(y_true, newshape=(-1, max_plate_slots, vocabulary_size)) 20 | y_pred = ops.reshape(y_pred, newshape=(-1, max_plate_slots, vocabulary_size)) 21 | return ops.mean(metrics.categorical_accuracy(y_true, y_pred)) 22 | 23 | return cat_acc 24 | 25 | 26 | def plate_acc_metric(max_plate_slots: int, vocabulary_size: int): 27 | """ 28 | Plate accuracy metric. 29 | """ 30 | 31 | def plate_acc(y_true, y_pred): 32 | """ 33 | Compute how many plates were correctly classified. For a single plate, if ground truth is 34 | 'ABC 123', and the prediction is 'ABC 123', then this would give a score of 1. If the 35 | prediction was ABD 123, it would score 0. 36 | """ 37 | y_true = ops.reshape(y_true, newshape=(-1, max_plate_slots, vocabulary_size)) 38 | y_pred = ops.reshape(y_pred, newshape=(-1, max_plate_slots, vocabulary_size)) 39 | et = ops.equal(ops.argmax(y_true, axis=-1), ops.argmax(y_pred, axis=-1)) 40 | return ops.mean(ops.cast(ops.all(et, axis=-1, keepdims=False), dtype="float32")) 41 | 42 | return plate_acc 43 | 44 | 45 | def top_3_k_metric(vocabulary_size: int): 46 | """ 47 | Top 3 K categorical accuracy metric. 48 | """ 49 | 50 | def top_3_k(y_true, y_pred): 51 | """ 52 | Calculates how often the true character is found in the 3 predictions with the highest 53 | probability. 54 | """ 55 | # Reshape into 2-d 56 | y_true = ops.reshape(y_true, newshape=(-1, vocabulary_size)) 57 | y_pred = ops.reshape(y_pred, newshape=(-1, vocabulary_size)) 58 | return ops.mean(metrics.top_k_categorical_accuracy(y_true, y_pred, k=3)) 59 | 60 | return top_3_k 61 | 62 | 63 | # Custom loss 64 | def cce_loss(vocabulary_size: int, label_smoothing: float = 0.2): 65 | """ 66 | Categorical cross-entropy loss. 67 | """ 68 | 69 | def cce(y_true, y_pred): 70 | """ 71 | Computes the categorical cross-entropy loss. 72 | """ 73 | y_true = ops.reshape(y_true, newshape=(-1, vocabulary_size)) 74 | y_pred = ops.reshape(y_pred, newshape=(-1, vocabulary_size)) 75 | return ops.mean( 76 | losses.categorical_crossentropy( 77 | y_true, y_pred, from_logits=False, label_smoothing=label_smoothing 78 | ) 79 | ) 80 | 81 | return cce 82 | -------------------------------------------------------------------------------- /fast_plate_ocr/train/model/layer_blocks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Layer blocks used in the OCR model. 
3 | """ 4 | 5 | from keras import regularizers 6 | from keras.src.layers import ( 7 | Activation, 8 | AveragePooling2D, 9 | BatchNormalization, 10 | Conv2D, 11 | MaxPooling2D, 12 | SeparableConv2D, 13 | ) 14 | 15 | 16 | def block_no_bn(i, k=3, n_c=64, s=1, padding="same", activation: str = "relu"): 17 | x1 = Conv2D( 18 | kernel_size=k, 19 | filters=n_c, 20 | strides=s, 21 | padding=padding, 22 | kernel_regularizer=regularizers.l2(0.01), 23 | use_bias=False, 24 | )(i) 25 | x2 = Activation(activation)(x1) 26 | return x2, x1 27 | 28 | 29 | def block_no_activation(i, k=3, n_c=64, s=1, padding="same"): 30 | x = Conv2D( 31 | kernel_size=k, 32 | filters=n_c, 33 | strides=s, 34 | padding=padding, 35 | kernel_regularizer=regularizers.l2(0.01), 36 | use_bias=False, 37 | )(i) 38 | x = BatchNormalization()(x) 39 | return x 40 | 41 | 42 | def block_bn(i, k=3, n_c=64, s=1, padding="same", activation: str = "relu"): 43 | x1 = Conv2D( 44 | kernel_size=k, 45 | filters=n_c, 46 | strides=s, 47 | padding=padding, 48 | kernel_regularizer=regularizers.l2(0.01), 49 | use_bias=False, 50 | )(i) 51 | x2 = BatchNormalization()(x1) 52 | x2 = Activation(activation)(x2) 53 | return x2, x1 54 | 55 | 56 | def block_bn_no_l2(i, k=3, n_c=64, s=1, padding="same", activation: str = "relu"): 57 | x1 = Conv2D(kernel_size=k, filters=n_c, strides=s, padding=padding, use_bias=False)(i) 58 | x2 = BatchNormalization()(x1) 59 | x2 = Activation(activation)(x2) 60 | return x2, x1 61 | 62 | 63 | def block_bn_sep_conv_l2( 64 | i, k=3, n_c=64, s=1, padding="same", depth_multiplier=1, activation: str = "relu" 65 | ): 66 | l2_kernel_reg = regularizers.l2(0.01) 67 | x1 = SeparableConv2D( 68 | kernel_size=k, 69 | filters=n_c, 70 | depth_multiplier=depth_multiplier, 71 | strides=s, 72 | padding=padding, 73 | use_bias=False, 74 | depthwise_regularizer=l2_kernel_reg, 75 | pointwise_regularizer=l2_kernel_reg, 76 | )(i) 77 | x2 = BatchNormalization()(x1) 78 | x2 = Activation(activation)(x2) 79 | return x2, x1 80 | 81 | 82 | def block_bn_relu6(i, k=3, n_c=64, s=1, padding="same", activation: str = "relu6"): 83 | x1 = Conv2D( 84 | kernel_size=k, 85 | filters=n_c, 86 | strides=s, 87 | padding=padding, 88 | kernel_regularizer=regularizers.l2(0.01), 89 | use_bias=False, 90 | )(i) 91 | x2 = BatchNormalization()(x1) 92 | x2 = Activation(activation)(x2) 93 | return x2, x1 94 | 95 | 96 | def block_bn_relu6_no_l2(i, k=3, n_c=64, s=1, padding="same", activation: str = "relu6"): 97 | x1 = Conv2D(kernel_size=k, filters=n_c, strides=s, padding=padding, use_bias=False)(i) 98 | x2 = BatchNormalization()(x1) 99 | x2 = Activation(activation)(x2) 100 | return x2, x1 101 | 102 | 103 | def block_average_conv_down(x, n_c, padding="same", activation: str = "relu"): 104 | x = AveragePooling2D(pool_size=2, strides=1, padding="valid")(x) 105 | x = Conv2D( 106 | filters=n_c, 107 | kernel_size=3, 108 | strides=2, 109 | padding=padding, 110 | kernel_regularizer=regularizers.l2(0.01), 111 | use_bias=False, 112 | )(x) 113 | x = BatchNormalization()(x) 114 | x = Activation(activation)(x) 115 | return x 116 | 117 | 118 | def block_max_conv_down(x, n_c, padding="same", activation: str = "relu"): 119 | x = MaxPooling2D(pool_size=2, strides=1, padding="valid")(x) 120 | x = Conv2D( 121 | filters=n_c, 122 | kernel_size=3, 123 | strides=2, 124 | padding=padding, 125 | kernel_regularizer=regularizers.l2(0.01), 126 | use_bias=False, 127 | )(x) 128 | x = BatchNormalization()(x) 129 | x = Activation(activation)(x) 130 | return x 131 | 
-------------------------------------------------------------------------------- /fast_plate_ocr/train/model/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model definitions for the FastLP OCR. 3 | """ 4 | 5 | from typing import Literal 6 | 7 | from keras.src.activations import softmax 8 | from keras.src.layers import ( 9 | Activation, 10 | Concatenate, 11 | Dense, 12 | Dropout, 13 | GlobalAveragePooling2D, 14 | Input, 15 | Rescaling, 16 | Reshape, 17 | Softmax, 18 | ) 19 | from keras.src.models import Model 20 | 21 | from fast_plate_ocr.train.model.layer_blocks import ( 22 | block_average_conv_down, 23 | block_bn, 24 | block_max_conv_down, 25 | block_no_activation, 26 | ) 27 | 28 | 29 | def cnn_ocr_model( 30 | h: int, 31 | w: int, 32 | max_plate_slots: int, 33 | vocabulary_size: int, 34 | dense: bool = True, 35 | activation: str = "relu", 36 | pool_layer: Literal["avg", "max"] = "max", 37 | ) -> Model: 38 | """ 39 | OCR model implemented with just CNN layers (v2). 40 | """ 41 | input_tensor = Input((h, w, 1)) 42 | x = Rescaling(1.0 / 255)(input_tensor) 43 | # Pooling-Conv layer 44 | if pool_layer == "avg": 45 | block_pool_conv = block_average_conv_down 46 | elif pool_layer == "max": 47 | block_pool_conv = block_max_conv_down 48 | # Backbone 49 | x = block_pool_conv(x, n_c=32, padding="same", activation=activation) 50 | x, _ = block_bn(x, k=3, n_c=64, s=1, padding="same", activation=activation) 51 | x, _ = block_bn(x, k=1, n_c=64, s=1, padding="same", activation=activation) 52 | x = block_pool_conv(x, n_c=64, padding="same", activation=activation) 53 | x, _ = block_bn(x, k=3, n_c=128, s=1, padding="same", activation=activation) 54 | x, _ = block_bn(x, k=1, n_c=128, s=1, padding="same", activation=activation) 55 | x = block_pool_conv(x, n_c=128, padding="same", activation=activation) 56 | x, _ = block_bn(x, k=3, n_c=128, s=1, padding="same", activation=activation) 57 | x, _ = block_bn(x, k=1, n_c=256, s=1, padding="same", activation=activation) 58 | x = block_pool_conv(x, n_c=256, padding="same", activation=activation) 59 | x, _ = block_bn(x, k=1, n_c=512, s=1, padding="same", activation=activation) 60 | x, _ = block_bn(x, k=1, n_c=1024, s=1, padding="same", activation=activation) 61 | x = ( 62 | head(x, max_plate_slots, vocabulary_size) 63 | if dense 64 | else head_no_fc(x, max_plate_slots, vocabulary_size) 65 | ) 66 | return Model(inputs=input_tensor, outputs=x) 67 | 68 | 69 | def head(x, max_plate_slots: int, vocabulary_size: int): 70 | """ 71 | Model's head with Fully Connected (FC) layers. 72 | """ 73 | x = GlobalAveragePooling2D()(x) 74 | # dropout for more robust learning 75 | x = Dropout(0.5)(x) 76 | dense_outputs = [ 77 | Activation(softmax)(Dense(units=vocabulary_size)(x)) for _ in range(max_plate_slots) 78 | ] 79 | # concat all the dense outputs 80 | x = Concatenate()(dense_outputs) 81 | return x 82 | 83 | 84 | def head_no_fc(x, max_plate_slots: int, vocabulary_size: int): 85 | """ 86 | Model's head without Fully Connected (FC) layers. 
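A 1x1 convolution expands the feature map to max_plate_slots * vocabulary_size channels,
global average pooling collapses the spatial dimensions, and a softmax over the vocabulary
axis yields one probability distribution per plate slot.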
87 | """ 88 | x = block_no_activation(x, k=1, n_c=max_plate_slots * vocabulary_size, s=1, padding="same") 89 | x = GlobalAveragePooling2D()(x) 90 | x = Reshape((max_plate_slots, vocabulary_size, 1))(x) 91 | x = Softmax(axis=-2)(x) 92 | return x 93 | -------------------------------------------------------------------------------- /fast_plate_ocr/train/utilities/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ankandrew/fast-plate-ocr/49954297c2226ad6c488f6fbfb93169b579b815d/fast_plate_ocr/train/utilities/__init__.py -------------------------------------------------------------------------------- /fast_plate_ocr/train/utilities/backend_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils for Keras supported backends. 3 | """ 4 | 5 | import os 6 | from typing import Literal, TypeAlias 7 | 8 | Framework: TypeAlias = Literal["jax", "tensorflow", "torch"] 9 | """Supported backend frameworks for Keras.""" 10 | 11 | 12 | def set_jax_backend() -> None: 13 | """Set Keras backend to jax.""" 14 | set_keras_backend("jax") 15 | 16 | 17 | def set_tensorflow_backend() -> None: 18 | """Set Keras backend to tensorflow.""" 19 | set_keras_backend("tensorflow") 20 | 21 | 22 | def set_pytorch_backend() -> None: 23 | """Set Keras backend to pytorch.""" 24 | set_keras_backend("torch") 25 | 26 | 27 | def set_keras_backend(framework: Framework) -> None: 28 | """Set the Keras backend to a given framework.""" 29 | os.environ["KERAS_BACKEND"] = framework 30 | 31 | 32 | def reload_keras_backend(framework: Framework) -> None: 33 | """Reload the Keras backend with a given framework.""" 34 | # pylint: disable=import-outside-toplevel 35 | import keras 36 | 37 | keras.config.set_backend(framework) 38 | -------------------------------------------------------------------------------- /fast_plate_ocr/train/utilities/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions module 3 | """ 4 | 5 | import logging 6 | import pathlib 7 | import random 8 | 9 | import cv2 10 | import keras 11 | import numpy as np 12 | import numpy.typing as npt 13 | 14 | from fast_plate_ocr.train.model.custom import ( 15 | cat_acc_metric, 16 | cce_loss, 17 | plate_acc_metric, 18 | top_3_k_metric, 19 | ) 20 | 21 | 22 | def one_hot_plate(plate: str, alphabet: str) -> list[list[int]]: 23 | return [[0 if char != letter else 1 for char in alphabet] for letter in plate] 24 | 25 | 26 | def target_transform( 27 | plate_text: str, 28 | max_plate_slots: int, 29 | alphabet: str, 30 | pad_char: str, 31 | ) -> npt.NDArray[np.uint8]: 32 | # Pad the plates which length is smaller than 'max_plate_slots' 33 | plate_text = plate_text.ljust(max_plate_slots, pad_char) 34 | # Generate numpy arrays with one-hot encoding of plates 35 | encoded_plate = np.array(one_hot_plate(plate_text, alphabet=alphabet), dtype=np.uint8) 36 | return encoded_plate 37 | 38 | 39 | def read_plate_image(image_path: str, img_height: int, img_width: int) -> npt.NDArray: 40 | """ 41 | Read and resize a license plate image. 42 | 43 | :param image_path: The path to the license plate image. 44 | :param img_height: The desired height of the resized image. 45 | :param img_width: The desired width of the resized image. 46 | :return: The resized license plate image as a NumPy array. 
-------------------------------------------------------------------------------- /fast_plate_ocr/train/utilities/utils.py: --------------------------------------------------------------------------------
1 | """
2 | Utility functions module.
3 | """
4 | 
5 | import logging
6 | import pathlib
7 | import random
8 | 
9 | import cv2
10 | import keras
11 | import numpy as np
12 | import numpy.typing as npt
13 | 
14 | from fast_plate_ocr.train.model.custom import (
15 |     cat_acc_metric,
16 |     cce_loss,
17 |     plate_acc_metric,
18 |     top_3_k_metric,
19 | )
20 | 
21 | 
22 | def one_hot_plate(plate: str, alphabet: str) -> list[list[int]]:
23 |     return [[0 if char != letter else 1 for char in alphabet] for letter in plate]
24 | 
25 | 
26 | def target_transform(
27 |     plate_text: str,
28 |     max_plate_slots: int,
29 |     alphabet: str,
30 |     pad_char: str,
31 | ) -> npt.NDArray[np.uint8]:
32 |     # Pad plates whose length is smaller than 'max_plate_slots'
33 |     plate_text = plate_text.ljust(max_plate_slots, pad_char)
34 |     # Generate a numpy array with the one-hot encoding of the plate
35 |     encoded_plate = np.array(one_hot_plate(plate_text, alphabet=alphabet), dtype=np.uint8)
36 |     return encoded_plate
37 | 
38 | 
39 | def read_plate_image(image_path: str, img_height: int, img_width: int) -> npt.NDArray:
40 |     """
41 |     Read a license plate image in grayscale and resize it.
42 | 
43 |     :param image_path: The path to the license plate image.
44 |     :param img_height: The desired height of the resized image.
45 |     :param img_width: The desired width of the resized image.
46 |     :return: The resized license plate image as a NumPy array.
47 |     """
48 |     img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
49 |     img = cv2.resize(img, (img_width, img_height), interpolation=cv2.INTER_LINEAR)
50 |     img = np.expand_dims(img, -1)
51 |     return img
52 | 
53 | 
54 | def load_keras_model(
55 |     model_path: str | pathlib.Path,
56 |     vocab_size: int,
57 |     max_plate_slots: int,
58 | ) -> keras.Model:
59 |     """
60 |     Load the Keras OCR model, registering its custom loss and metrics.
61 |     """
62 |     custom_objects = {
63 |         "cce": cce_loss(vocabulary_size=vocab_size),
64 |         "cat_acc": cat_acc_metric(max_plate_slots=max_plate_slots, vocabulary_size=vocab_size),
65 |         "plate_acc": plate_acc_metric(max_plate_slots=max_plate_slots, vocabulary_size=vocab_size),
66 |         "top_3_k": top_3_k_metric(vocabulary_size=vocab_size),
67 |     }
68 |     model = keras.models.load_model(model_path, custom_objects=custom_objects)
69 |     return model
70 | 
71 | 
72 | IMG_EXTENSIONS: set[str] = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff", ".webp"}
73 | """Valid image extensions for the scope of this script."""
74 | 
75 | 
76 | def load_images_from_folder(
77 |     img_dir: pathlib.Path,
78 |     width: int,
79 |     height: int,
80 |     shuffle: bool = False,
81 |     limit: int | None = None,
82 | ) -> list[npt.NDArray]:
83 |     """
84 |     Return all images read from a directory. This uses the same read function used during training.
85 |     """
86 |     image_paths = sorted(
87 |         str(f.resolve()) for f in img_dir.iterdir() if f.is_file() and f.suffix in IMG_EXTENSIONS
88 |     )
89 |     if limit:  # note: the limit is applied before shuffling, so it keeps the first N sorted paths
90 |         image_paths = image_paths[:limit]
91 |     if shuffle:
92 |         random.shuffle(image_paths)
93 |     images = [read_plate_image(i, img_height=height, img_width=width) for i in image_paths]
94 |     return images
95 | 
96 | 
97 | def postprocess_model_output(
98 |     prediction: npt.NDArray,
99 |     alphabet: str,
100 |     max_plate_slots: int,
101 |     vocab_size: int,
102 | ) -> tuple[str, npt.NDArray]:
103 |     """
104 |     Return plate text and per-character confidence scores from the raw model output.
105 |     """
106 |     prediction = prediction.reshape((max_plate_slots, vocab_size))
107 |     probs = np.max(prediction, axis=-1)
108 |     prediction = np.argmax(prediction, axis=-1)
109 |     plate = "".join([alphabet[x] for x in prediction])
110 |     return plate, probs
111 | 
112 | 
113 | def low_confidence_positions(probs, thresh=0.3) -> npt.NDArray:
114 |     """Return indices of elements in `probs` less than `thresh`, indicating low confidence."""
115 |     return np.where(np.array(probs) < thresh)[0]
116 | 
117 | 
118 | def display_predictions(
119 |     image: npt.NDArray,
120 |     plate: str,
121 |     probs: npt.NDArray,
122 |     low_conf_thresh: float,
123 | ) -> None:
124 |     """
125 |     Display the plate image with its predicted text and confidence overlaid.
126 |     """
127 |     plate_str = "".join(plate)
128 |     logging.info("Plate: %s", plate_str)
129 |     logging.info("Confidence: %s", probs)
130 |     image_to_show = cv2.resize(image, None, fx=3, fy=3, interpolation=cv2.INTER_LINEAR)
131 |     # Convert grayscale to BGR so the overlay text can be drawn in color
132 |     image_to_show = cv2.cvtColor(image_to_show, cv2.COLOR_GRAY2BGR)
133 |     # Average probability across slots, shown as a percentage
134 |     avg_prob = np.mean(probs) * 100
135 |     cv2.putText(  # draw a thick black pass first...
136 |         image_to_show,
137 |         f"{plate_str} {avg_prob:.2f}%",
138 |         org=(5, 30),
139 |         fontFace=cv2.FONT_HERSHEY_SIMPLEX,
140 |         fontScale=1,
141 |         color=(0, 0, 0),
142 |         lineType=1,
143 |         thickness=6,
144 |     )
145 |     cv2.putText(  # ...then a thinner white pass on top, producing outlined text
146 |         image_to_show,
147 |         f"{plate_str} {avg_prob:.2f}%",
148 |         org=(5, 30),
149 |         fontFace=cv2.FONT_HERSHEY_SIMPLEX,
150 |         fontScale=1,
151 |         color=(255, 255, 255),
152 |         lineType=1,
153 |         thickness=2,
154 |     )
155 |     # Display characters with low confidence
156 |     low_conf_chars = "Low conf. on: " + " ".join(
157 |         [plate[i] for i in low_confidence_positions(probs, thresh=low_conf_thresh)]
158 |     )
159 |     cv2.putText(
160 |         image_to_show,
161 |         low_conf_chars,
162 |         org=(5, 200),
163 |         fontFace=cv2.FONT_HERSHEY_SIMPLEX,
164 |         fontScale=0.7,
165 |         color=(0, 0, 220),
166 |         lineType=1,
167 |         thickness=2,
168 |     )
169 |     cv2.imshow("plates", image_to_show)
170 |     if cv2.waitKey(0) & 0xFF == ord("q"):
171 |         return
172 | 
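A round-trip sketch tying target_transform and postprocess_model_output together (illustrative only; the 37-symbol alphabet below is an assumption, chosen to match the plate configs used elsewhere in the repo):

    import numpy as np

    from fast_plate_ocr.train.utilities.utils import postprocess_model_output, target_transform

    alphabet = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_"
    encoded = target_transform("AB123CD", max_plate_slots=7, alphabet=alphabet, pad_char="_")
    assert encoded.shape == (7, len(alphabet))  # one one-hot row per plate slot

    # Decoding the one-hot target recovers the original text with probability 1.0 per slot.
    plate, probs = postprocess_model_output(
        encoded.astype(np.float32), alphabet, max_plate_slots=7, vocab_size=len(alphabet)
    )
    assert plate == "AB123CD" and probs.tolist() == [1.0] * 7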
-------------------------------------------------------------------------------- /mkdocs.yml: --------------------------------------------------------------------------------
1 | site_name: FastPlateOCR
2 | site_author: ankandrew
3 | site_description: Fast & Lightweight License Plate OCR
4 | repo_url: https://github.com/ankandrew/fast-plate-ocr
5 | theme:
6 |   name: material
7 |   features:
8 |     - navigation.instant
9 |     - navigation.instant.progress
10 |     - search.suggest
11 |     - search.highlight
12 |     - content.code.copy
13 |   palette:
14 |     - scheme: default
15 |       toggle:
16 |         icon: material/lightbulb-outline
17 |         name: Switch to dark mode
18 |     - scheme: slate
19 |       toggle:
20 |         icon: material/lightbulb
21 |         name: Switch to light mode
22 | nav:
23 |   - Introduction: index.md
24 |   - Installation: installation.md
25 |   - Usage: usage.md
26 |   - Architecture: architecture.md
27 |   - Contributing: contributing.md
28 |   - Reference: reference.md
29 | plugins:
30 |   - search
31 |   - mike:
32 |       alias_type: symlink
33 |       canonical_version: latest
34 |   - mkdocstrings:
35 |       handlers:
36 |         python:
37 |           paths: [ .
] 38 | options: 39 | members_order: source 40 | separate_signature: true 41 | filters: [ "!^_" ] 42 | docstring_options: 43 | ignore_init_summary: true 44 | merge_init_into_class: true 45 | show_signature_annotations: true 46 | signature_crossrefs: true 47 | extra: 48 | version: 49 | provider: mike 50 | generator: false 51 | markdown_extensions: 52 | - admonition 53 | - pymdownx.highlight: 54 | anchor_linenums: true 55 | line_spans: __span 56 | pygments_lang_class: true 57 | - pymdownx.inlinehilite 58 | - pymdownx.snippets 59 | - pymdownx.details 60 | - pymdownx.superfences 61 | - toc: 62 | permalink: true 63 | title: Page contents 64 | -------------------------------------------------------------------------------- /poetry.toml: -------------------------------------------------------------------------------- 1 | [virtualenvs] 2 | in-project = true 3 | prefer-active-python = true 4 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "fast-plate-ocr" 3 | version = "0.3.0" 4 | description = "Fast & Lightweight OCR for vehicle license plates." 5 | authors = ["ankandrew <61120139+ankandrew@users.noreply.github.com>"] 6 | readme = "README.md" 7 | repository = "https://github.com/ankandrew/fast-plate-ocr/" 8 | documentation = "https://ankandrew.github.io/fast-plate-ocr" 9 | keywords = ["plate-recognition", "license-plate-recognition", "license-plate-ocr"] 10 | license = "MIT" 11 | classifiers = [ 12 | "Typing :: Typed", 13 | "Intended Audience :: Developers", 14 | "Intended Audience :: Education", 15 | "Intended Audience :: Science/Research", 16 | "Operating System :: OS Independent", 17 | "Topic :: Software Development", 18 | "Topic :: Scientific/Engineering", 19 | "Topic :: Software Development :: Libraries", 20 | "Topic :: Software Development :: Build Tools", 21 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 22 | "Topic :: Software Development :: Libraries :: Python Modules", 23 | ] 24 | 25 | [tool.poetry.dependencies] 26 | python = "^3.10" 27 | # Packages required for doing inference 28 | numpy = ">=1.20" 29 | opencv-python = "*" 30 | pyyaml = ">=5.1" 31 | tqdm = "*" 32 | rich = "*" 33 | 34 | # Install onnxruntime-gpu only on systems other than macOS or Raspberry Pi 35 | onnxruntime-gpu = { version = ">=1.19.2", markers = "sys_platform != 'darwin' and platform_machine != 'armv7l' and platform_machine != 'aarch64' and (platform_system == 'Linux' or platform_system == 'Windows')" } 36 | # Fallback to onnxruntime for macOS, Raspberry Pi, and other unsupported platforms 37 | onnxruntime = { version = ">=1.19.2", markers = "sys_platform == 'darwin' or platform_machine == 'armv7l' or platform_machine == 'aarch64'" } 38 | 39 | # Training packages are optional 40 | albumentations = { version = "*", optional = true } 41 | click = { version = "*", optional = true } 42 | keras = { version = ">=3.1.1", optional = true } 43 | matplotlib = { version = "*", optional = true } 44 | pandas = { version = "*", optional = true } 45 | pydantic = { version = "^2.0.0", optional = true } 46 | tensorboard = { version = "*", optional = true } 47 | tensorflow = { version = "*", optional = true } 48 | tf2onnx = { version = "*", optional = true } 49 | torch = { version = "*", optional = true } 50 | 51 | # Optional packages for creating the docs 52 | mkdocs-material = { version = "*", optional = true } 53 | mkdocstrings = {version = "*", extras = 
["python"], optional = true} 54 | mike = { version = "*", optional = true } 55 | onnxsim = { version = ">0.4.10", optional = true } 56 | 57 | [tool.poetry.extras] 58 | train = [ 59 | "albumentations", 60 | "click", 61 | "keras", 62 | "matplotlib", 63 | "pandas", 64 | "pydantic", 65 | "tensorboard", 66 | "tensorflow", 67 | "tf2onnx", 68 | "torch", 69 | "onnxsim", 70 | ] 71 | docs = ["mkdocs-material", "mkdocstrings", "mike"] 72 | 73 | [tool.poetry.group.test.dependencies] 74 | pytest = "*" 75 | 76 | [tool.poetry.group.dev.dependencies] 77 | mypy = "*" 78 | ruff = "*" 79 | pandas-stubs = "^2.2.0.240218" 80 | pylint = "*" 81 | types-pyyaml = "^6.0.12.20240311" 82 | 83 | [tool.poetry.scripts] 84 | fast_plate_ocr = "fast_plate_ocr.cli.cli:main_cli" 85 | 86 | [tool.ruff] 87 | line-length = 100 88 | target-version = "py310" 89 | 90 | [tool.ruff.lint] 91 | select = [ 92 | # pycodestyle 93 | "E", 94 | "W", 95 | # Pyflakes 96 | "F", 97 | # pep8-naming 98 | "N", 99 | # pyupgrade 100 | "UP", 101 | # flake8-bugbear 102 | "B", 103 | # flake8-simplify 104 | "SIM", 105 | # flake8-unused-arguments 106 | "ARG", 107 | # Pylint 108 | "PL", 109 | # Perflint 110 | "PERF", 111 | # Ruff-specific rules 112 | "RUF", 113 | # pandas-vet 114 | "PD", 115 | ] 116 | ignore = ["N812", "PLR2004", "PD011"] 117 | fixable = ["ALL"] 118 | unfixable = [] 119 | 120 | [tool.ruff.lint.pylint] 121 | max-args = 8 122 | 123 | [tool.ruff.format] 124 | line-ending = "lf" 125 | 126 | [tool.mypy] 127 | disable_error_code = "import-untyped" 128 | [[tool.mypy.overrides]] 129 | module = ["albumentations"] 130 | ignore_missing_imports = true 131 | 132 | [tool.pylint.typecheck] 133 | generated-members = ["cv2.*"] 134 | signature-mutators = [ 135 | "click.decorators.option", 136 | "click.decorators.argument", 137 | "click.decorators.version_option", 138 | "click.decorators.help_option", 139 | "click.decorators.pass_context", 140 | "click.decorators.confirmation_option" 141 | ] 142 | 143 | [tool.pylint.format] 144 | max-line-length = 100 145 | 146 | [tool.pylint."messages control"] 147 | disable = ["missing-class-docstring", "missing-function-docstring", "wrong-import-order"] 148 | 149 | [tool.pylint.design] 150 | max-args = 8 151 | min-public-methods = 1 152 | 153 | [tool.pylint.basic] 154 | no-docstring-rgx = "^__|^test_" 155 | 156 | [build-system] 157 | requires = ["poetry-core"] 158 | build-backend = "poetry.core.masonry.api" 159 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test package. 3 | """ 4 | 5 | from pathlib import Path 6 | 7 | PROJECT_ROOT_DIR = Path(__file__).resolve().parent.parent 8 | -------------------------------------------------------------------------------- /test/assets/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Assets package used in tests. 
3 | """
4 | 
5 | import pathlib
6 | 
7 | ASSETS_DIR = pathlib.Path(__file__).resolve().parent
8 | """Path pointing to test/assets directory"""
9 | 
-------------------------------------------------------------------------------- /test/assets/test_plate_1.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/ankandrew/fast-plate-ocr/49954297c2226ad6c488f6fbfb93169b579b815d/test/assets/test_plate_1.png
-------------------------------------------------------------------------------- /test/assets/test_plate_2.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/ankandrew/fast-plate-ocr/49954297c2226ad6c488f6fbfb93169b579b815d/test/assets/test_plate_2.png
-------------------------------------------------------------------------------- /test/conftest.py: --------------------------------------------------------------------------------
1 | """
2 | A `conftest.py` file provides fixtures for an entire directory. Fixtures defined here can be
3 | used by any test in the package without being imported explicitly (pytest discovers them
4 | automatically).
5 | """
6 | 
7 | import pytest
8 | 
9 | 
10 | @pytest.fixture(scope="function")
11 | def temp_directory(tmpdir):
12 |     """
13 |     Example fixture to create a temporary directory for testing.
14 |     """
15 |     temp_dir = tmpdir.mkdir("temp")
16 |     yield temp_dir
17 |     temp_dir.remove()
18 | 
-------------------------------------------------------------------------------- /test/fast_lp_ocr/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/ankandrew/fast-plate-ocr/49954297c2226ad6c488f6fbfb93169b579b815d/test/fast_lp_ocr/__init__.py
-------------------------------------------------------------------------------- /test/fast_lp_ocr/inference/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/ankandrew/fast-plate-ocr/49954297c2226ad6c488f6fbfb93169b579b815d/test/fast_lp_ocr/inference/__init__.py
-------------------------------------------------------------------------------- /test/fast_lp_ocr/inference/test_hub.py: --------------------------------------------------------------------------------
1 | """
2 | Tests for ONNX hub module.
3 | """
4 | 
5 | from http import HTTPStatus
6 | 
7 | import pytest
8 | import requests
9 | 
10 | from fast_plate_ocr.inference.hub import AVAILABLE_ONNX_MODELS
11 | 
12 | 
13 | @pytest.mark.parametrize("model_name", AVAILABLE_ONNX_MODELS.keys())
14 | def test_model_and_config_urls(model_name):
15 |     """
16 |     Check that the model and config URLs for AVAILABLE_ONNX_MODELS are reachable.
17 |     """
18 |     model_url, config_url = AVAILABLE_ONNX_MODELS[model_name]
19 | 
20 |     for url in [model_url, config_url]:
21 |         response = requests.head(url, timeout=5, allow_redirects=True)
22 |         assert (
23 |             response.status_code == HTTPStatus.OK
24 |         ), f"URL {url} is not accessible, got {response.status_code}"
25 | 
-------------------------------------------------------------------------------- /test/fast_lp_ocr/inference/test_onnx_inference.py: --------------------------------------------------------------------------------
1 | """
2 | Tests for ONNX inference module.
3 | """ 4 | 5 | from collections.abc import Iterator 6 | 7 | import cv2 8 | import numpy.typing as npt 9 | import pytest 10 | 11 | from fast_plate_ocr import ONNXPlateRecognizer 12 | from test.assets import ASSETS_DIR 13 | 14 | 15 | @pytest.fixture(scope="module", name="onnx_model") 16 | def onnx_model_fixture() -> Iterator[ONNXPlateRecognizer]: 17 | yield ONNXPlateRecognizer("argentinian-plates-cnn-model", device="cpu") 18 | 19 | 20 | @pytest.mark.parametrize( 21 | "input_image, expected_result", 22 | [ 23 | # Single image path 24 | (str(ASSETS_DIR / "test_plate_1.png"), 1), 25 | # Multiple Image paths 26 | ( 27 | [str(ASSETS_DIR / "test_plate_1.png"), str(ASSETS_DIR / "test_plate_2.png")], 28 | 2, 29 | ), 30 | # NumPy array with single image 31 | (cv2.imread(str(ASSETS_DIR / "test_plate_1.png"), cv2.IMREAD_GRAYSCALE), 1), 32 | # NumPy array with batch images 33 | ( 34 | [ 35 | cv2.imread(str(ASSETS_DIR / "test_plate_1.png"), cv2.IMREAD_GRAYSCALE), 36 | cv2.imread(str(ASSETS_DIR / "test_plate_2.png"), cv2.IMREAD_GRAYSCALE), 37 | ], 38 | 2, 39 | ), 40 | ], 41 | ) 42 | def test_result_from_different_image_sources( 43 | input_image: str | list[str] | npt.NDArray, 44 | expected_result: int, 45 | onnx_model: ONNXPlateRecognizer, 46 | ) -> None: 47 | actual_result = len(onnx_model.run(input_image)) 48 | assert actual_result == expected_result 49 | -------------------------------------------------------------------------------- /test/fast_lp_ocr/inference/test_process.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for inference process module. 3 | """ 4 | 5 | import numpy as np 6 | import numpy.typing as npt 7 | import pytest 8 | 9 | from fast_plate_ocr.inference.process import postprocess_output 10 | 11 | 12 | @pytest.mark.parametrize( 13 | "model_output, max_plate_slots, model_alphabet, expected_plates", 14 | [ 15 | ( 16 | np.array( 17 | [ 18 | [[0.5, 0.4, 0.1], [0.2, 0.6, 0.2], [0.1, 0.4, 0.5]], 19 | [[0.1, 0.1, 0.8], [0.2, 0.2, 0.6], [0.1, 0.4, 0.5]], 20 | ], 21 | dtype=np.float32, 22 | ), 23 | 3, 24 | "ABC", 25 | ["ABC", "CCC"], 26 | ), 27 | ( 28 | np.array( 29 | [[[0.1, 0.4, 0.5], [0.6, 0.2, 0.2], [0.1, 0.5, 0.4]]], 30 | dtype=np.float32, 31 | ), 32 | 3, 33 | "ABC", 34 | ["CAB"], 35 | ), 36 | ], 37 | ) 38 | def test_postprocess_output( 39 | model_output: npt.NDArray, 40 | max_plate_slots: int, 41 | model_alphabet: str, 42 | expected_plates: list[str], 43 | ) -> None: 44 | actual_plate = postprocess_output(model_output, max_plate_slots, model_alphabet) 45 | assert actual_plate == expected_plates 46 | -------------------------------------------------------------------------------- /test/fast_lp_ocr/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ankandrew/fast-plate-ocr/49954297c2226ad6c488f6fbfb93169b579b815d/test/fast_lp_ocr/train/__init__.py -------------------------------------------------------------------------------- /test/fast_lp_ocr/train/test_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for config module 3 | """ 4 | 5 | from pathlib import Path 6 | 7 | import pytest 8 | 9 | from fast_plate_ocr.train.model.config import PlateOCRConfig, load_config_from_yaml 10 | from test import PROJECT_ROOT_DIR 11 | 12 | 13 | @pytest.mark.parametrize( 14 | "file_path", 15 | [f for f in PROJECT_ROOT_DIR.joinpath("config").iterdir() if f.suffix in (".yaml", ".yml")], 16 | ) 17 | def 
test_yaml_configs(file_path: Path) -> None: 18 | load_config_from_yaml(file_path) 19 | 20 | 21 | @pytest.mark.parametrize( 22 | "raw_config", 23 | [ 24 | { 25 | "max_plate_slots": 7, 26 | # Pad char not in alphabet, should raise exception 27 | "alphabet": "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ", 28 | "pad_char": "_", 29 | "img_height": 70, 30 | "img_width": 140, 31 | } 32 | ], 33 | ) 34 | def test_invalid_config_raises(raw_config: dict) -> None: 35 | with pytest.raises(ValueError): 36 | PlateOCRConfig(**raw_config) 37 | -------------------------------------------------------------------------------- /test/fast_lp_ocr/train/test_custom.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the custom metric/losses module. 3 | """ 4 | 5 | # ruff: noqa: E402 6 | # pylint: disable=wrong-import-position,wrong-import-order,ungrouped-imports 7 | # fmt: off 8 | from fast_plate_ocr.train.utilities.backend_utils import set_pytorch_backend 9 | 10 | set_pytorch_backend() 11 | # fmt: on 12 | 13 | import pytest 14 | import torch 15 | 16 | from fast_plate_ocr.train.model.custom import cat_acc_metric, plate_acc_metric 17 | 18 | 19 | @pytest.mark.parametrize( 20 | "y_true, y_pred, expected_accuracy", 21 | [ 22 | (torch.tensor([[[1, 0]] * 6]), torch.tensor([[[0.9, 0.1]] * 6]), 1.0), 23 | ], 24 | ) 25 | def test_cat_acc(y_true: torch.Tensor, y_pred: torch.Tensor, expected_accuracy: float) -> None: 26 | actual_accuracy = cat_acc_metric(2, 1)(y_true, y_pred) 27 | assert actual_accuracy == expected_accuracy 28 | 29 | 30 | @pytest.mark.parametrize( 31 | "y_true, y_pred, expected_accuracy", 32 | [ 33 | ( 34 | torch.tensor( 35 | [ 36 | [ 37 | [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 38 | [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 39 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 40 | [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 41 | [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 42 | [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 43 | ], 44 | [ 45 | [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 46 | [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 47 | [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 48 | [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 49 | [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 50 | [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 51 | ], 52 | ] 53 | ), 54 | torch.tensor( 55 | [ 56 | [ 57 | [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 58 | [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 59 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 60 | [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 61 | [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 62 | [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 63 | ], 64 | [ 65 | [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 66 | [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 67 | [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 68 | [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 69 | [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 70 | [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 71 | ], 72 | ] 73 | ), 74 | # First batch slice plate was recognized completely correct but second one wasn't 75 | # So 50% of plates were recognized correctly 76 | 0.5, 77 | ), 78 | ], 79 | ) 80 | def test_plate_accuracy( 81 | y_true: torch.Tensor, y_pred: torch.Tensor, expected_accuracy: float 82 | ) -> None: 83 | actual_accuracy = plate_acc_metric(y_true.shape[1], 
y_true.shape[2])(y_true, y_pred).item()
84 |     assert actual_accuracy == expected_accuracy
85 | 
-------------------------------------------------------------------------------- /test/fast_lp_ocr/train/test_models.py: --------------------------------------------------------------------------------
1 | """
2 | Test OCR models module.
3 | """
4 | 
5 | import pytest
6 | from keras import Input
7 | 
8 | from fast_plate_ocr.train.model import models
9 | 
10 | 
11 | @pytest.mark.parametrize(
12 |     "max_plate_slots, vocabulary_size, expected_hidden_units", [(7, 37, 7 * 37)]
13 | )
14 | def test_head(max_plate_slots: int, vocabulary_size: int, expected_hidden_units: int) -> None:
15 |     x = Input((70, 140, 1))
16 |     out_tensor = models.head(x, max_plate_slots, vocabulary_size)
17 |     actual_hidden_units = out_tensor.shape[-1]
18 |     assert actual_hidden_units == expected_hidden_units
19 | 
20 | 
21 | @pytest.mark.parametrize(
22 |     "max_plate_slots, vocabulary_size, expected_hidden_units", [(7, 37, 7 * 37)]
23 | )
24 | def test_head_no_fc(
25 |     max_plate_slots: int, vocabulary_size: int, expected_hidden_units: int
26 | ) -> None:
27 |     x = Input((70, 140, 1))
28 |     out_tensor = models.head_no_fc(x, max_plate_slots, vocabulary_size)
29 |     actual_hidden_units = out_tensor.shape[1] * out_tensor.shape[2]
30 |     assert actual_hidden_units == expected_hidden_units
31 | 
-------------------------------------------------------------------------------- /test/fast_lp_ocr/train/test_utils.py: --------------------------------------------------------------------------------
1 | """
2 | Test utils module.
3 | """
4 | 
5 | import pytest
6 | 
7 | from fast_plate_ocr.train.utilities import utils
8 | 
9 | 
10 | @pytest.mark.parametrize(
11 |     "plate, alphabet, expected_result",
12 |     [
13 |         (
14 |             "AB12",
15 |             "ABCD123456",
16 |             [
17 |                 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
18 |                 [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
19 |                 [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
20 |                 [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
21 |             ],
22 |         ),
23 |         (
24 |             "ABC123",
25 |             "0123456789ABCDEFGHIJ",
26 |             [
27 |                 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
28 |                 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
29 |                 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
30 |                 [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
31 |                 [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
32 |                 [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
33 |             ],
34 |         ),
35 |     ],
36 | )
37 | def test_one_hot_plate(plate: str, alphabet: str, expected_result: list[list[int]]) -> None:
38 |     actual_result = utils.one_hot_plate(plate, alphabet)
39 |     assert actual_result == expected_result
40 | 
--------------------------------------------------------------------------------
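Closing usage sketch for the public inference API that the tests above exercise (the model name and asset path come from test_onnx_inference.py; treating the return value as a list of recognized plate strings is an assumption based on how those tests measure its length):

    from fast_plate_ocr import ONNXPlateRecognizer

    recognizer = ONNXPlateRecognizer("argentinian-plates-cnn-model", device="cpu")
    # run() also accepts a list of paths or grayscale NumPy arrays, per the parametrized test.
    plates = recognizer.run("test/assets/test_plate_1.png")
    print(plates)  # expected: a single-element list for a single input image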