├── tests ├── __init__.py ├── data │ └── test-european.jpg └── test_ocr.py ├── multiocr ├── pipelines │ ├── aws_textract │ │ ├── __init__.py │ │ └── engine.py │ ├── doctr_ocr │ │ ├── __init__.py │ │ ├── doctr │ │ │ ├── io │ │ │ │ ├── __init__.py │ │ │ │ └── image │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── base.py │ │ │ │ │ ├── tensorflow.py │ │ │ │ │ └── pytorch.py │ │ │ ├── version.py │ │ │ ├── models │ │ │ │ ├── factory │ │ │ │ │ └── __init__.py │ │ │ │ ├── obj_detection │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── faster_rcnn │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── pytorch.py │ │ │ │ ├── artefacts │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── face.py │ │ │ │ │ └── barcode.py │ │ │ │ ├── modules │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── transformer │ │ │ │ │ │ └── __init__.py │ │ │ │ │ └── vision_transformer │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── pytorch.py │ │ │ │ │ │ └── tensorflow.py │ │ │ │ ├── detection │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── _utils │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── tensorflow.py │ │ │ │ │ │ └── pytorch.py │ │ │ │ │ ├── predictor │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── tensorflow.py │ │ │ │ │ │ └── pytorch.py │ │ │ │ │ ├── linknet │ │ │ │ │ │ └── __init__.py │ │ │ │ │ ├── differentiable_binarization │ │ │ │ │ │ └── __init__.py │ │ │ │ │ ├── zoo.py │ │ │ │ │ └── core.py │ │ │ │ ├── recognition │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── predictor │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── tensorflow.py │ │ │ │ │ │ ├── pytorch.py │ │ │ │ │ │ └── _utils.py │ │ │ │ │ ├── crnn │ │ │ │ │ │ └── __init__.py │ │ │ │ │ ├── sar │ │ │ │ │ │ └── __init__.py │ │ │ │ │ ├── master │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── base.py │ │ │ │ │ ├── parseq │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── base.py │ │ │ │ │ ├── vitstr │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── base.py │ │ │ │ │ ├── core.py │ │ │ │ │ ├── zoo.py │ │ │ │ │ └── utils.py │ │ │ │ ├── classification │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── vgg │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── pytorch.py │ │ │ │ │ │ └── tensorflow.py │ │ │ │ │ ├── mobilenet │ │ │ │ │ │ └── __init__.py │ │ │ │ │ ├── resnet │ │ │ │ │ │ └── __init__.py │ │ │ │ │ ├── vit │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── tensorflow.py │ │ │ │ │ │ └── pytorch.py │ │ │ │ │ ├── magc_resnet │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── pytorch.py │ │ │ │ │ │ └── tensorflow.py │ │ │ │ │ ├── predictor │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── tensorflow.py │ │ │ │ │ │ └── pytorch.py │ │ │ │ │ └── zoo.py │ │ │ │ ├── __init__.py │ │ │ │ ├── predictor │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── tensorflow.py │ │ │ │ │ ├── base.py │ │ │ │ │ └── pytorch.py │ │ │ │ ├── utils │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── pytorch.py │ │ │ │ │ └── tensorflow.py │ │ │ │ ├── preprocessor │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── tensorflow.py │ │ │ │ │ └── pytorch.py │ │ │ │ ├── core.py │ │ │ │ └── zoo.py │ │ │ ├── transforms │ │ │ │ ├── __init__.py │ │ │ │ ├── functional │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── pytorch.py │ │ │ │ └── modules │ │ │ │ │ └── __init__.py │ │ │ ├── datasets │ │ │ │ ├── __init__.py │ │ │ │ └── vocabs.py │ │ │ ├── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── common_types.py │ │ │ │ ├── fonts.py │ │ │ │ ├── multithreading.py │ │ │ │ ├── repr.py │ │ │ │ └── data.py │ │ │ ├── __init__.py │ │ │ └── file_utils.py │ │ ├── readme.md │ │ └── engine.py │ ├── paddle_ocr │ │ ├── __init__.py │ │ └── engine.py │ ├── tesseract │ │ ├── __init__.py │ │ └── engine.py │ ├── __init__.py │ └── easy_ocr │ │ └── engine.py ├── __init__.py ├── base_class.py ├── utils.py └── 
main.py ├── Dockerfile ├── pyproject.toml ├── LICENSE ├── .gitignore └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /multiocr/pipelines/aws_textract/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /multiocr/pipelines/paddle_ocr/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /multiocr/pipelines/tesseract/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/io/__init__.py: -------------------------------------------------------------------------------- 1 | from .image import * -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.7.0a0' 2 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/factory/__init__.py: -------------------------------------------------------------------------------- 1 | from .hub import * 2 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.7 2 | 3 | RUN apt-get update 4 | RUN apt-get install -y git 5 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/obj_detection/__init__.py: -------------------------------------------------------------------------------- 1 | from .faster_rcnn import * 2 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .vocabs import * 2 | from .utils import * -------------------------------------------------------------------------------- /tests/data/test-european.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s-aravindh/multiocr/HEAD/tests/data/test-european.jpg -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/artefacts/__init__.py: -------------------------------------------------------------------------------- 1 | from .barcode import * 2 | from .face import * 3 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/modules/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .transformer import * 2 | from .vision_transformer import * 3 | -------------------------------------------------------------------------------- /multiocr/__init__.py: -------------------------------------------------------------------------------- 1 | from multiocr.main import OcrEngine 2 | from multiocr.pipelines.doctr_ocr import * 3 | from multiocr.pipelines.doctr_ocr.doctr import * -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/detection/__init__.py: -------------------------------------------------------------------------------- 1 | from .differentiable_binarization import * 2 | from .linknet import * 3 | from .zoo import * 4 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .common_types import * 2 | from .data import * 3 | from .geometry import * 4 | from .metrics import * 5 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/recognition/__init__.py: -------------------------------------------------------------------------------- 1 | from .crnn import * 2 | from .master import * 3 | from .sar import * 4 | from .vitstr import * 5 | from .parseq import * 6 | from .zoo import * 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/__init__.py: -------------------------------------------------------------------------------- 1 | from . import io, models, datasets, transforms, utils 2 | from .file_utils import is_tf_available, is_torch_available 3 | from .version import __version__ # noqa: F401 4 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/classification/__init__.py: -------------------------------------------------------------------------------- 1 | from .mobilenet import * 2 | from .resnet import * 3 | from .vgg import * 4 | from .magc_resnet import * 5 | from .vit import * 6 | from .zoo import * 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import artefacts 2 | from .classification import * 3 | from .detection import * 4 | from .recognition import * 5 | from .zoo import * 6 | from .factory import * 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/detection/_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import is_tf_available 2 | 3 | if is_tf_available(): 4 | from .tensorflow import * 5 | else: 6 | from .pytorch import * 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/obj_detection/faster_rcnn/__init__.py: -------------------------------------------------------------------------------- 1 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import is_tf_available, is_torch_available 2 | 3 | if not is_tf_available() and is_torch_available(): 4 | from .pytorch import * 5 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/classification/vgg/__init__.py: -------------------------------------------------------------------------------- 1 | from .....doctr.file_utils import is_tf_available, is_torch_available 2 | 3 | if is_tf_available(): 4 | from .tensorflow import * 5 | elif is_torch_available(): 6 | from .pytorch import * 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/predictor/__init__.py: -------------------------------------------------------------------------------- 1 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import is_tf_available 2 | 3 | if is_tf_available(): 4 | from .tensorflow import * 5 | else: 6 | from .pytorch import * # type: ignore[assignment] 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/classification/mobilenet/__init__.py: -------------------------------------------------------------------------------- 1 | from .....doctr.file_utils import is_tf_available, is_torch_available 2 | 3 | if is_tf_available(): 4 | from .tensorflow import * 5 | elif is_torch_available(): 6 | from .pytorch import * 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/io/image/__init__.py: -------------------------------------------------------------------------------- 1 | from ....doctr.file_utils import is_tf_available, is_torch_available 2 | 3 | from .base import * 4 | 5 | if is_tf_available(): 6 | from .tensorflow import * 7 | elif is_torch_available(): 8 | from .pytorch import * 9 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/detection/predictor/__init__.py: -------------------------------------------------------------------------------- 1 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import is_tf_available 2 | 3 | if is_tf_available(): 4 | from .tensorflow import * 5 | else: 6 | from .pytorch import * # type: ignore[assignment] 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/recognition/predictor/__init__.py: -------------------------------------------------------------------------------- 1 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import is_tf_available 2 | 3 | if is_tf_available(): 4 
| from .tensorflow import * 5 | else: 6 | from .pytorch import * # type: ignore[assignment] 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from ....doctr.file_utils import is_tf_available, is_torch_available 2 | 3 | if is_tf_available(): 4 | from .tensorflow import * 5 | elif is_torch_available(): 6 | from .pytorch import * # type: ignore[assignment] 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/readme.md: -------------------------------------------------------------------------------- 1 | # References & Credits 2 | All code for **DocTR OCR** is sourced from the [python-doctr package from mindee](https://github.com/mindee/doctr). 3 | 4 | Thanks to [@felixdittrich92](https://github.com/felixdittrich92) for his suggestion to integrate DocTR. -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/transforms/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import is_tf_available, is_torch_available 2 | 3 | if is_tf_available(): 4 | from .tensorflow import * 5 | elif is_torch_available(): 6 | from .pytorch import * 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from multiocr.pipelines.aws_textract.engine import * 2 | from multiocr.pipelines.tesseract.engine import * 3 | from multiocr.pipelines.paddle_ocr.engine import * 4 | from multiocr.pipelines.easy_ocr.engine import * 5 | from multiocr.pipelines.doctr_ocr.engine import * -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/classification/resnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .....doctr.file_utils import is_tf_available, is_torch_available 2 | 3 | if is_tf_available(): 4 | from .tensorflow import * 5 | elif is_torch_available(): 6 | from .pytorch import * # type: ignore[assignment] 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/preprocessor/__init__.py: -------------------------------------------------------------------------------- 1 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import is_tf_available, is_torch_available 2 | 3 | if is_tf_available(): 4 | from .tensorflow import * 5 | elif is_torch_available(): 6 | from .pytorch import * # type: ignore[assignment] 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/recognition/crnn/__init__.py: -------------------------------------------------------------------------------- 1 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import is_tf_available, is_torch_available 2 | 3 | if is_tf_available(): 4 | from .tensorflow import * 5 | elif is_torch_available(): 6 | from .pytorch import * # type: ignore[assignment] 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/recognition/sar/__init__.py:
-------------------------------------------------------------------------------- 1 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import is_tf_available, is_torch_available 2 | 3 | if is_tf_available(): 4 | from .tensorflow import * 5 | elif is_torch_available(): 6 | from .pytorch import * # type: ignore[assignment] 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/classification/vit/__init__.py: -------------------------------------------------------------------------------- 1 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import is_tf_available, is_torch_available 2 | 3 | if is_tf_available(): 4 | from .tensorflow import * 5 | elif is_torch_available(): 6 | from .pytorch import * # type: ignore[assignment] 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/detection/linknet/__init__.py: -------------------------------------------------------------------------------- 1 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import is_tf_available, is_torch_available 2 | 3 | if is_tf_available(): 4 | from .tensorflow import * 5 | elif is_torch_available(): 6 | from .pytorch import * # type: ignore[assignment] 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/modules/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import is_tf_available, is_torch_available 2 | 3 | if is_tf_available(): 4 | from .tensorflow import * 5 | elif is_torch_available(): 6 | from .pytorch import * # type: ignore[assignment] 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/recognition/master/__init__.py: -------------------------------------------------------------------------------- 1 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import is_tf_available, is_torch_available 2 | 3 | if is_tf_available(): 4 | from .tensorflow import * 5 | elif is_torch_available(): 6 | from .pytorch import * # type: ignore[assignment] 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/recognition/parseq/__init__.py: -------------------------------------------------------------------------------- 1 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import is_tf_available, is_torch_available 2 | 3 | if is_tf_available(): 4 | from .tensorflow import * 5 | elif is_torch_available(): 6 | from .pytorch import * # type: ignore[assignment] 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/recognition/vitstr/__init__.py: -------------------------------------------------------------------------------- 1 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import is_tf_available, is_torch_available 2 | 3 | if is_tf_available(): 4 | from .tensorflow import * 5 | elif is_torch_available(): 6 | from .pytorch import * # type: ignore[assignment] 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/classification/magc_resnet/__init__.py: -------------------------------------------------------------------------------- 1 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import 
is_tf_available, is_torch_available 2 | 3 | if is_tf_available(): 4 | from .tensorflow import * 5 | elif is_torch_available(): 6 | from .pytorch import * # type: ignore[assignment] 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/classification/predictor/__init__.py: -------------------------------------------------------------------------------- 1 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import is_tf_available, is_torch_available 2 | 3 | if is_tf_available(): 4 | from .tensorflow import * 5 | elif is_torch_available(): 6 | from .pytorch import * # type: ignore[assignment] 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/modules/vision_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import is_tf_available, is_torch_available 2 | 3 | if is_tf_available(): 4 | from .tensorflow import * 5 | elif is_torch_available(): 6 | from .pytorch import * # type: ignore[assignment] 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/detection/differentiable_binarization/__init__.py: -------------------------------------------------------------------------------- 1 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import is_tf_available, is_torch_available 2 | 3 | if is_tf_available(): 4 | from .tensorflow import * 5 | elif is_torch_available(): 6 | from .pytorch import * # type: ignore[assignment] 7 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/transforms/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import is_tf_available, is_torch_available 2 | 3 | from .base import * 4 | 5 | if is_tf_available(): 6 | from .tensorflow import * 7 | elif is_torch_available(): 8 | from .pytorch import * # type: ignore[assignment] 9 | -------------------------------------------------------------------------------- /multiocr/base_class.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class OCR(ABC): 5 | 6 | @abstractmethod 7 | def text_extraction(self): 8 | pass 9 | 10 | @abstractmethod 11 | def text_extraction_to_json(self): 12 | pass 13 | 14 | @abstractmethod 15 | def text_extraction_to_df(self): 16 | pass 17 | 18 | @abstractmethod 19 | def extract_plain_text(self): 20 | pass -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "multiocr" 3 | version = "0.1.4" 4 | description = "" 5 | authors = ["Aravindh <32878238+s-aravindh@users.noreply.github.com>"] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.8" 10 | pytesseract = "^0.3.10" 11 | paddlepaddle = "^2.4.2" 12 | paddleocr = "^2.6.1.3" 13 | easyocr = "^1.7.0" 14 | unidecode = "^1.3.6" 15 | langdetect = "^1.0.9" 16 | 17 | 18 | [build-system] 19 | requires = ["poetry-core"] 20 | build-backend = "poetry.core.masonry.api" 21 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/core.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | 7 | from typing import Any, Dict, Optional 8 | 9 | from multiocr.pipelines.doctr_ocr.doctr.utils.repr import NestedObject 10 | 11 | __all__ = ["BaseModel"] 12 | 13 | 14 | class BaseModel(NestedObject): 15 | """Implements abstract DetectionModel class""" 16 | 17 | def __init__(self, cfg: Optional[Dict[str, Any]] = None) -> None: 18 | super().__init__() 19 | self.cfg = cfg 20 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/utils/common_types.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from pathlib import Path 7 | from typing import List, Tuple, Union 8 | 9 | __all__ = ["Point2D", "BoundingBox", "Polygon4P", "Polygon", "Bbox"] 10 | 11 | 12 | Point2D = Tuple[float, float] 13 | BoundingBox = Tuple[Point2D, Point2D] 14 | Polygon4P = Tuple[Point2D, Point2D, Point2D, Point2D] 15 | Polygon = List[Point2D] 16 | AbstractPath = Union[str, Path] 17 | AbstractFile = Union[AbstractPath, bytes] 18 | Bbox = Tuple[float, float, float, float] 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Aravindh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/detection/_utils/tensorflow.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | import tensorflow as tf 7 | 8 | __all__ = ["erode", "dilate"] 9 | 10 | 11 | def erode(x: tf.Tensor, kernel_size: int) -> tf.Tensor: 12 | """Performs erosion on a given tensor 13 | 14 | Args: 15 | x: boolean tensor of shape (N, H, W, C) 16 | kernel_size: the size of the kernel to use for erosion 17 | Returns: 18 | the eroded tensor 19 | """ 20 | 21 | return 1 - tf.nn.max_pool2d(1 - x, kernel_size, strides=1, padding="SAME") 22 | 23 | 24 | def dilate(x: tf.Tensor, kernel_size: int) -> tf.Tensor: 25 | """Performs dilation on a given tensor 26 | 27 | Args: 28 | x: boolean tensor of shape (N, H, W, C) 29 | kernel_size: the size of the kernel to use for dilation 30 | Returns: 31 | the dilated tensor 32 | """ 33 | 34 | return tf.nn.max_pool2d(x, kernel_size, strides=1, padding="SAME") 35 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/detection/_utils/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from torch import Tensor 7 | from torch.nn.functional import max_pool2d 8 | 9 | __all__ = ["erode", "dilate"] 10 | 11 | 12 | def erode(x: Tensor, kernel_size: int) -> Tensor: 13 | """Performs erosion on a given tensor 14 | 15 | Args: 16 | x: boolean tensor of shape (N, C, H, W) 17 | kernel_size: the size of the kernel to use for erosion 18 | Returns: 19 | the eroded tensor 20 | """ 21 | _pad = (kernel_size - 1) // 2 22 | 23 | return 1 - max_pool2d(1 - x, kernel_size, stride=1, padding=_pad) 24 | 25 | 26 | def dilate(x: Tensor, kernel_size: int) -> Tensor: 27 | """Performs dilation on a given tensor 28 | 29 | Args: 30 | x: boolean tensor of shape (N, C, H, W) 31 | kernel_size: the size of the kernel to use for dilation 32 | Returns: 33 | the dilated tensor 34 | """ 35 | _pad = (kernel_size - 1) // 2 36 | 37 | return max_pool2d(x, kernel_size, stride=1, padding=_pad) 38 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/utils/fonts.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import logging 7 | import platform 8 | from typing import Optional 9 | 10 | from PIL import ImageFont 11 | 12 | __all__ = ["get_font"] 13 | 14 | 15 | def get_font(font_family: Optional[str] = None, font_size: int = 13) -> ImageFont.ImageFont: 16 | """Resolves a compatible ImageFont for the system 17 | 18 | Args: 19 | font_family: the font family to use 20 | font_size: the size of the font upon rendering 21 | 22 | Returns: 23 | the Pillow font 24 | """ 25 | 26 | # Font selection 27 | if font_family is None: 28 | try: 29 | font = ImageFont.truetype("FreeMono.ttf" if platform.system() == "Linux" else "Arial.ttf", font_size) 30 | except OSError: 31 | font = ImageFont.load_default() 32 | logging.warning( 33 | "unable to load recommended font family. Loading default PIL font," 34 | "font size issues may be expected." 35 | "To prevent this, it is recommended to specify the value of 'font_family'." 
36 | ) 37 | else: 38 | font = ImageFont.truetype(font_family, font_size) 39 | 40 | return font 41 | -------------------------------------------------------------------------------- /multiocr/utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageDraw 2 | import fitz  # PyMuPDF, required by is_digital_page below 3 | 4 | def draw_bounding_boxes(image_file, text_dict): 5 | # Open the image file 6 | image = Image.open(image_file).convert("RGBA") 7 | # Initialize the drawing context 8 | draw = ImageDraw.Draw(image) 9 | 10 | # Draw a green bounding box around each word 11 | for v in text_dict: 12 | left = v["coordinates"]["xmin"] 13 | top = v["coordinates"]["ymin"] 14 | right = v["coordinates"]["xmax"] 15 | bottom = v["coordinates"]["ymax"] 16 | draw.rectangle((left, top, right, bottom), outline='green', width=2) 17 | 18 | # Return the image with bounding boxes drawn over the words 19 | return image 20 | 21 | 22 | def is_digital_page(pdf_path, page_num): 23 | # Open the PDF file and select the specified page 24 | with fitz.open(pdf_path) as doc: 25 | page = doc[page_num] 26 | 27 | # Check if the page is a scanned image 28 | if page.is_image: 29 | return False 30 | 31 | # Check if the page contains any text 32 | text = page.get_text() 33 | if text.strip(): 34 | # Check if the text is really text or binary contents 35 | if text.isprintable(): 36 | return True 37 | else: 38 | return False 39 | 40 | # If the page is not a scanned image and does not contain text, it's likely a digital page 41 | return True 42 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/recognition/vitstr/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import List, Tuple 7 | 8 | import numpy as np 9 | 10 | from ....datasets import encode_sequences 11 | from ..core import RecognitionPostProcessor 12 | 13 | 14 | class _ViTSTR: 15 | vocab: str 16 | max_length: int 17 | 18 | def build_target( 19 | self, 20 | gts: List[str], 21 | ) -> Tuple[np.ndarray, List[int]]: 22 | """Encode a list of gts sequences into a np array and gives the corresponding* 23 | sequence lengths. 24 | 25 | Args: 26 | gts: list of ground-truth labels 27 | 28 | Returns: 29 | A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch) 30 | """ 31 | encoded = encode_sequences( 32 | sequences=gts, 33 | vocab=self.vocab, 34 | target_size=self.max_length, 35 | eos=len(self.vocab), 36 | sos=len(self.vocab) + 1, 37 | ) 38 | seq_len = [len(word) for word in gts] 39 | return encoded, seq_len 40 | 41 | 42 | class _ViTSTRPostProcessor(RecognitionPostProcessor): 43 | """Abstract class to postprocess the raw output of the model 44 | 45 | Args: 46 | vocab: string containing the ordered sequence of supported characters 47 | """ 48 | 49 | def __init__( 50 | self, 51 | vocab: str, 52 | ) -> None: 53 | super().__init__(vocab) 54 | self._embedding = list(vocab) + ["", ""] 55 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/recognition/parseq/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0.
4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import List, Tuple 7 | 8 | import numpy as np 9 | 10 | from ....datasets import encode_sequences 11 | from ..core import RecognitionPostProcessor 12 | 13 | 14 | class _PARSeq: 15 | vocab: str 16 | max_length: int 17 | 18 | def build_target( 19 | self, 20 | gts: List[str], 21 | ) -> Tuple[np.ndarray, List[int]]: 22 | """Encode a list of gts sequences into a np array and gives the corresponding* 23 | sequence lengths. 24 | 25 | Args: 26 | gts: list of ground-truth labels 27 | 28 | Returns: 29 | A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch) 30 | """ 31 | encoded = encode_sequences( 32 | sequences=gts, 33 | vocab=self.vocab, 34 | target_size=self.max_length, 35 | eos=len(self.vocab), 36 | sos=len(self.vocab) + 1, 37 | pad=len(self.vocab) + 2, 38 | ) 39 | seq_len = [len(word) for word in gts] 40 | return encoded, seq_len 41 | 42 | 43 | class _PARSeqPostProcessor(RecognitionPostProcessor): 44 | """Abstract class to postprocess the raw output of the model 45 | 46 | Args: 47 | vocab: string containing the ordered sequence of supported characters 48 | """ 49 | 50 | def __init__( 51 | self, 52 | vocab: str, 53 | ) -> None: 54 | super().__init__(vocab) 55 | self._embedding = list(vocab) + ["", "", ""] 56 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/recognition/master/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import List, Tuple 7 | 8 | import numpy as np 9 | 10 | from ....datasets import encode_sequences 11 | from ..core import RecognitionPostProcessor 12 | 13 | 14 | class _MASTER: 15 | vocab: str 16 | max_length: int 17 | 18 | def build_target( 19 | self, 20 | gts: List[str], 21 | ) -> Tuple[np.ndarray, List[int]]: 22 | """Encode a list of gts sequences into a np array and gives the corresponding* 23 | sequence lengths. 24 | 25 | Args: 26 | gts: list of ground-truth labels 27 | 28 | Returns: 29 | A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch) 30 | """ 31 | encoded = encode_sequences( 32 | sequences=gts, 33 | vocab=self.vocab, 34 | target_size=self.max_length, 35 | eos=len(self.vocab), 36 | sos=len(self.vocab) + 1, 37 | pad=len(self.vocab) + 2, 38 | ) 39 | seq_len = [len(word) for word in gts] 40 | return encoded, seq_len 41 | 42 | 43 | class _MASTERPostProcessor(RecognitionPostProcessor): 44 | """Abstract class to postprocess the raw output of the model 45 | 46 | Args: 47 | vocab: string containing the ordered sequence of supported characters 48 | """ 49 | 50 | def __init__( 51 | self, 52 | vocab: str, 53 | ) -> None: 54 | super().__init__(vocab) 55 | self._embedding = list(vocab) + [""] + [""] + [""] 56 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/detection/predictor/tensorflow.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | from typing import Any, Dict, List, Union 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | from tensorflow import keras 11 | 12 | from multiocr.pipelines.doctr_ocr.doctr.models.preprocessor import PreProcessor 13 | from multiocr.pipelines.doctr_ocr.doctr.utils.repr import NestedObject 14 | 15 | __all__ = ["DetectionPredictor"] 16 | 17 | 18 | class DetectionPredictor(NestedObject): 19 | """Implements an object able to localize text elements in a document 20 | 21 | Args: 22 | pre_processor: transform inputs for easier batched model inference 23 | model: core detection architecture 24 | """ 25 | 26 | _children_names: List[str] = ["pre_processor", "model"] 27 | 28 | def __init__( 29 | self, 30 | pre_processor: PreProcessor, 31 | model: keras.Model, 32 | ) -> None: 33 | self.pre_processor = pre_processor 34 | self.model = model 35 | 36 | def __call__( 37 | self, 38 | pages: List[Union[np.ndarray, tf.Tensor]], 39 | **kwargs: Any, 40 | ) -> List[Dict[str, np.ndarray]]: 41 | # Dimension check 42 | if any(page.ndim != 3 for page in pages): 43 | raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") 44 | 45 | processed_batches = self.pre_processor(pages) 46 | predicted_batches = [ 47 | self.model(batch, return_preds=True, training=False, **kwargs)["preds"] for batch in processed_batches 48 | ] 49 | return [pred for batch in predicted_batches for pred in batch] 50 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/recognition/core.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import List, Tuple 7 | 8 | import numpy as np 9 | 10 | from multiocr.pipelines.doctr_ocr.doctr.datasets import encode_sequences 11 | from multiocr.pipelines.doctr_ocr.doctr.utils.repr import NestedObject 12 | 13 | __all__ = ["RecognitionPostProcessor", "RecognitionModel"] 14 | 15 | 16 | class RecognitionModel(NestedObject): 17 | """Implements abstract RecognitionModel class""" 18 | 19 | vocab: str 20 | max_length: int 21 | 22 | def build_target( 23 | self, 24 | gts: List[str], 25 | ) -> Tuple[np.ndarray, List[int]]: 26 | """Encode a list of gts sequences into a np array and gives the corresponding* 27 | sequence lengths. 28 | 29 | Args: 30 | gts: list of ground-truth labels 31 | 32 | Returns: 33 | A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch) 34 | """ 35 | encoded = encode_sequences(sequences=gts, vocab=self.vocab, target_size=self.max_length, eos=len(self.vocab)) 36 | seq_len = [len(word) for word in gts] 37 | return encoded, seq_len 38 | 39 | 40 | class RecognitionPostProcessor(NestedObject): 41 | """Abstract class to postprocess the raw output of the model 42 | 43 | Args: 44 | vocab: string containing the ordered sequence of supported characters 45 | """ 46 | 47 | def __init__( 48 | self, 49 | vocab: str, 50 | ) -> None: 51 | self.vocab = vocab 52 | self._embedding = list(self.vocab) + [""] 53 | 54 | def extra_repr(self) -> str: 55 | return f"vocab_size={len(self.vocab)}" 56 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/datasets/vocabs.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 
2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import string 7 | from typing import Dict 8 | 9 | __all__ = ["VOCABS"] 10 | 11 | 12 | VOCABS: Dict[str, str] = { 13 | "digits": string.digits, 14 | "ascii_letters": string.ascii_letters, 15 | "punctuation": string.punctuation, 16 | "currency": "£€¥¢฿", 17 | "ancient_greek": "αβγδεζηθικλμνξοπρστυφχψωΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩ", 18 | "arabic_letters": "ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىي", 19 | "persian_letters": "پچڢڤگ", 20 | "hindi_digits": "٠١٢٣٤٥٦٧٨٩", 21 | "arabic_diacritics": "ًٌٍَُِّْ", 22 | "arabic_punctuation": "؟؛«»—", 23 | } 24 | 25 | VOCABS["latin"] = VOCABS["digits"] + VOCABS["ascii_letters"] + VOCABS["punctuation"] 26 | VOCABS["english"] = VOCABS["latin"] + "°" + VOCABS["currency"] 27 | VOCABS["legacy_french"] = VOCABS["latin"] + "°" + "àâéèêëîïôùûçÀÂÉÈËÎÏÔÙÛÇ" + VOCABS["currency"] 28 | VOCABS["french"] = VOCABS["english"] + "àâéèêëîïôùûüçÀÂÉÈÊËÎÏÔÙÛÜÇ" 29 | VOCABS["portuguese"] = VOCABS["english"] + "áàâãéêíïóôõúüçÁÀÂÃÉÊÍÏÓÔÕÚÜÇ" 30 | VOCABS["spanish"] = VOCABS["english"] + "áéíóúüñÁÉÍÓÚÜÑ" + "¡¿" 31 | VOCABS["german"] = VOCABS["english"] + "äöüßÄÖÜẞ" 32 | VOCABS["arabic"] = ( 33 | VOCABS["digits"] 34 | + VOCABS["hindi_digits"] 35 | + VOCABS["arabic_letters"] 36 | + VOCABS["persian_letters"] 37 | + VOCABS["arabic_diacritics"] 38 | + VOCABS["arabic_punctuation"] 39 | + VOCABS["punctuation"] 40 | ) 41 | VOCABS["czech"] = VOCABS["english"] + "áčďéěíňóřšťúůýžÁČĎÉĚÍŇÓŘŠŤÚŮÝŽ" 42 | VOCABS["vietnamese"] = ( 43 | VOCABS["english"] 44 | + "áàảạãăắằẳẵặâấầẩẫậéèẻẽẹêếềểễệóòỏõọôốồổộỗơớờởợỡúùủũụưứừửữựiíìỉĩịýỳỷỹỵ" 45 | + "ÁÀẢẠÃĂẮẰẲẴẶÂẤẦẨẪẬÉÈẺẼẸÊẾỀỂỄỆÓÒỎÕỌÔỐỒỔỘỖƠỚỜỞỢỠÚÙỦŨỤƯỨỪỬỮỰIÍÌỈĨỊÝỲỶỸỴ" 46 | ) 47 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/io/image/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from pathlib import Path 7 | from typing import Optional, Tuple 8 | 9 | import cv2 10 | import numpy as np 11 | 12 | from ....doctr.utils.common_types import AbstractFile 13 | 14 | __all__ = ["read_img_as_numpy"] 15 | 16 | 17 | def read_img_as_numpy( 18 | file: AbstractFile, 19 | output_size: Optional[Tuple[int, int]] = None, 20 | rgb_output: bool = True, 21 | ) -> np.ndarray: 22 | """Read an image file into numpy format 23 | 24 | >>> from doctr.documents import read_img 25 | >>> page = read_img("path/to/your/doc.jpg") 26 | 27 | Args: 28 | file: the path to the image file 29 | output_size: the expected output size of each page in format H x W 30 | rgb_output: whether the output ndarray channel order should be RGB instead of BGR. 
31 | 32 | Returns: 33 | the page decoded as numpy ndarray of shape H x W x 3 34 | """ 35 | 36 | if isinstance(file, (str, Path)): 37 | if not Path(file).is_file(): 38 | raise FileNotFoundError(f"unable to access {file}") 39 | img = cv2.imread(str(file), cv2.IMREAD_COLOR) 40 | elif isinstance(file, bytes): 41 | _file: np.ndarray = np.frombuffer(file, np.uint8) 42 | img = cv2.imdecode(_file, cv2.IMREAD_COLOR) 43 | else: 44 | raise TypeError("unsupported object type for argument 'file'") 45 | 46 | # Validity check 47 | if img is None: 48 | raise ValueError("unable to read file.") 49 | # Resizing 50 | if isinstance(output_size, tuple): 51 | img = cv2.resize(img, output_size[::-1], interpolation=cv2.INTER_LINEAR) 52 | # Switch the channel order 53 | if rgb_output: 54 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 55 | return img 56 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/detection/predictor/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import Any, List, Union 7 | 8 | import numpy as np 9 | import torch 10 | from torch import nn 11 | 12 | from multiocr.pipelines.doctr_ocr.doctr.models.preprocessor import PreProcessor 13 | from multiocr.pipelines.doctr_ocr.doctr.models.utils import set_device_and_dtype 14 | 15 | __all__ = ["DetectionPredictor"] 16 | 17 | 18 | class DetectionPredictor(nn.Module): 19 | """Implements an object able to localize text elements in a document 20 | 21 | Args: 22 | pre_processor: transform inputs for easier batched model inference 23 | model: core detection architecture 24 | """ 25 | 26 | def __init__( 27 | self, 28 | pre_processor: PreProcessor, 29 | model: nn.Module, 30 | ) -> None: 31 | super().__init__() 32 | self.pre_processor = pre_processor 33 | self.model = model.eval() 34 | 35 | @torch.no_grad() 36 | def forward( 37 | self, 38 | pages: List[Union[np.ndarray, torch.Tensor]], 39 | **kwargs: Any, 40 | ) -> List[np.ndarray]: 41 | # Dimension check 42 | if any(page.ndim != 3 for page in pages): 43 | raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") 44 | 45 | processed_batches = self.pre_processor(pages) 46 | _params = next(self.model.parameters()) 47 | self.model, processed_batches = set_device_and_dtype( 48 | self.model, processed_batches, _params.device, _params.dtype 49 | ) 50 | predicted_batches = [self.model(batch, return_preds=True, **kwargs)["preds"] for batch in processed_batches] 51 | return [pred for batch in predicted_batches for pred in batch] 52 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/classification/predictor/tensorflow.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | from typing import List, Union 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | from tensorflow import keras 11 | 12 | from multiocr.pipelines.doctr_ocr.doctr.models.preprocessor import PreProcessor 13 | from multiocr.pipelines.doctr_ocr.doctr.utils.repr import NestedObject 14 | 15 | __all__ = ["CropOrientationPredictor"] 16 | 17 | 18 | class CropOrientationPredictor(NestedObject): 19 | """Implements an object able to detect the reading direction of a text box. 20 | 4 possible orientations: 0, 90, 180, 270 degrees counter clockwise. 21 | 22 | Args: 23 | pre_processor: transform inputs for easier batched model inference 24 | model: core classification architecture (backbone + classification head) 25 | """ 26 | 27 | _children_names: List[str] = ["pre_processor", "model"] 28 | 29 | def __init__( 30 | self, 31 | pre_processor: PreProcessor, 32 | model: keras.Model, 33 | ) -> None: 34 | self.pre_processor = pre_processor 35 | self.model = model 36 | 37 | def __call__( 38 | self, 39 | crops: List[Union[np.ndarray, tf.Tensor]], 40 | ) -> List[int]: 41 | # Dimension check 42 | if any(crop.ndim != 3 for crop in crops): 43 | raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.") 44 | 45 | processed_batches = self.pre_processor(crops) 46 | predicted_batches = [self.model(batch, training=False) for batch in processed_batches] 47 | 48 | # Postprocess predictions 49 | predicted_batches = [out_batch.numpy().argmax(1) for out_batch in predicted_batches] 50 | 51 | return [int(pred) for batch in predicted_batches for pred in batch] 52 | -------------------------------------------------------------------------------- /tests/test_ocr.py: -------------------------------------------------------------------------------- 1 | 2 | from multiocr import OcrEngine 3 | import os 4 | import pandas as pd 5 | import json 6 | 7 | image_file = "/Users/aravindh/Documents/GitHub/multiocr/tests/data/test-european.jpg" 8 | 9 | 10 | all_ocr = [ 11 | # test with custom config for each ocr 12 | { 13 | "ocr_name":"tesseract", 14 | "config":{ 15 | "lang": "eng", 16 | "config" : "--psm 6" 17 | } 18 | }, 19 | { 20 | "ocr_name":"paddle_ocr", 21 | "config":{ 22 | "lang":"en" 23 | } 24 | }, 25 | { 26 | "ocr_name":"aws_textract", 27 | "config":{ 28 | "region_name":os.getenv("region_name"), 29 | "aws_access_key_id":os.getenv("aws_access_key_id"), 30 | "aws_secret_access_key":os.getenv("aws_secret_access_key") 31 | } 32 | }, 33 | { 34 | "ocr_name":"easy_ocr", 35 | "config":{ 36 | "lang_list": ["en"] 37 | } 38 | }, 39 | # test with no configs 40 | { 41 | "ocr_name":"tesseract" 42 | }, 43 | { 44 | "ocr_name":"paddle_ocr" 45 | }, 46 | { 47 | "ocr_name":"aws_textract" 48 | }, 49 | { 50 | "ocr_name":"easy_ocr" 51 | } 52 | ] 53 | 54 | for ocr in all_ocr: 55 | ocr_name = None 56 | ocr_config = None 57 | ocr_name = ocr["ocr_name"] 58 | if ocr.get("config"): 59 | ocr_config = ocr["config"] 60 | engine = OcrEngine(ocr_name, ocr_config) 61 | 62 | text_dict = engine.text_extraction(image_file) 63 | assert type(text_dict) == list 64 | assert len(text_dict)>1 65 | 66 | json_str = engine.text_extraction_to_json(text_dict) 67 | assert type(json_str) == str 68 | assert type(json.loads(json_str)) == list 69 | 70 | df = engine.text_extraction_to_df(text_dict) 71 | assert type(df) == pd.DataFrame 72 | assert len(df)>1 73 | 74 | plain_text = engine.extract_plain_text(text_dict) 75 | assert type(plain_text) == str 76 | -------------------------------------------------------------------------------- 
/multiocr/pipelines/doctr_ocr/doctr/models/artefacts/face.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import List, Tuple 7 | 8 | import cv2 9 | import numpy as np 10 | 11 | from ....doctr.utils.repr import NestedObject 12 | 13 | __all__ = ["FaceDetector"] 14 | 15 | 16 | class FaceDetector(NestedObject): 17 | 18 | """Implements a face detector to detect profile pictures on resumes, IDS, driving licenses, passports... 19 | Based on open CV CascadeClassifier (haarcascades) 20 | 21 | Args: 22 | n_faces: maximal number of faces to detect on a single image, default = 1 23 | """ 24 | 25 | def __init__( 26 | self, 27 | n_faces: int = 1, 28 | ) -> None: 29 | self.n_faces = n_faces 30 | # Instantiate classifier 31 | self.detector = cv2.CascadeClassifier( 32 | cv2.data.haarcascades + "haarcascade_frontalface_default.xml" # type: ignore[attr-defined] 33 | ) 34 | 35 | def extra_repr(self) -> str: 36 | return f"n_faces={self.n_faces}" 37 | 38 | def __call__( 39 | self, 40 | img: np.ndarray, 41 | ) -> List[Tuple[float, float, float, float]]: 42 | """Detect n_faces on the img 43 | 44 | Args: 45 | img: image to detect faces on 46 | 47 | Returns: 48 | A list of size n_faces, each face is a tuple of relative xmin, ymin, xmax, ymax 49 | """ 50 | height, width = img.shape[:2] 51 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 52 | 53 | faces = self.detector.detectMultiScale(gray, 1.5, 3) 54 | # If faces are detected, keep only the biggest ones 55 | rel_faces = [] 56 | if len(faces) > 0: 57 | x, y, w, h = sorted(faces, key=lambda x: x[2] + x[3])[-min(self.n_faces, len(faces))] 58 | xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height 59 | rel_faces.append((xmin, ymin, xmax, ymax)) 60 | 61 | return rel_faces 62 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/utils/multithreading.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | 7 | import multiprocessing as mp 8 | import os 9 | from multiprocessing.pool import ThreadPool 10 | from typing import Any, Callable, Iterable, Iterator, Optional 11 | 12 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import ENV_VARS_TRUE_VALUES 13 | 14 | __all__ = ["multithread_exec"] 15 | 16 | 17 | def multithread_exec(func: Callable[[Any], Any], seq: Iterable[Any], threads: Optional[int] = None) -> Iterator[Any]: 18 | """Execute a given function in parallel for each element of a given sequence 19 | 20 | >>> from doctr.utils.multithreading import multithread_exec 21 | >>> entries = [1, 4, 8] 22 | >>> results = multithread_exec(lambda x: x ** 2, entries) 23 | 24 | Args: 25 | func: function to be executed on each element of the iterable 26 | seq: iterable 27 | threads: number of workers to be used for multiprocessing 28 | 29 | Returns: 30 | iterator of the function's results using the iterable as inputs 31 | 32 | Notes: 33 | This function uses ThreadPool from multiprocessing package, which uses `/dev/shm` directory for shared memory. 
34 | If you do not have write permissions for this directory (if you run `doctr` on AWS Lambda for instance), 35 | you might want to disable multiprocessing. To achieve that, set 'DOCTR_MULTIPROCESSING_DISABLE' to 'TRUE'. 36 | """ 37 | 38 | threads = threads if isinstance(threads, int) else min(16, mp.cpu_count()) 39 | # Single-thread 40 | if threads < 2 or os.environ.get("DOCTR_MULTIPROCESSING_DISABLE", "").upper() in ENV_VARS_TRUE_VALUES: 41 | results = map(func, seq) 42 | # Multi-threading 43 | else: 44 | with ThreadPool(threads) as tp: 45 | # ThreadPool's map function returns a list, but seq could be of a different type 46 | # That's why wrapping result in map to return iterator 47 | results = map(lambda x: x, tp.map(func, seq)) 48 | return results 49 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/classification/predictor/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import List, Union 7 | 8 | import numpy as np 9 | import torch 10 | from torch import nn 11 | 12 | from multiocr.pipelines.doctr_ocr.doctr.models.preprocessor import PreProcessor 13 | from multiocr.pipelines.doctr_ocr.doctr.models.utils import set_device_and_dtype 14 | 15 | __all__ = ["CropOrientationPredictor"] 16 | 17 | 18 | class CropOrientationPredictor(nn.Module): 19 | """Implements an object able to detect the reading direction of a text box. 20 | 4 possible orientations: 0, 90, 180, 270 degrees counter clockwise. 21 | 22 | Args: 23 | pre_processor: transform inputs for easier batched model inference 24 | model: core classification architecture (backbone + classification head) 25 | """ 26 | 27 | def __init__( 28 | self, 29 | pre_processor: PreProcessor, 30 | model: nn.Module, 31 | ) -> None: 32 | super().__init__() 33 | self.pre_processor = pre_processor 34 | self.model = model.eval() 35 | 36 | @torch.no_grad() 37 | def forward( 38 | self, 39 | crops: List[Union[np.ndarray, torch.Tensor]], 40 | ) -> List[int]: 41 | # Dimension check 42 | if any(crop.ndim != 3 for crop in crops): 43 | raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") 44 | 45 | processed_batches = self.pre_processor(crops) 46 | _params = next(self.model.parameters()) 47 | self.model, processed_batches = set_device_and_dtype( 48 | self.model, processed_batches, _params.device, _params.dtype 49 | ) 50 | predicted_batches = [self.model(batch) for batch in processed_batches] 51 | 52 | # Postprocess predictions 53 | predicted_batches = [out_batch.argmax(dim=1).cpu().detach().numpy() for out_batch in predicted_batches] 54 | 55 | return [int(pred) for batch in predicted_batches for pred in batch] 56 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/utils/repr.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | # Adapted from https://github.com/pytorch/torch/blob/master/torch/nn/modules/module.py 7 | 8 | from typing import List 9 | 10 | __all__ = ["NestedObject"] 11 | 12 | 13 | def _addindent(s_, num_spaces): 14 | s = s_.split("\n") 15 | # don't do anything for single-line stuff 16 | if len(s) == 1: 17 | return s_ 18 | first = s.pop(0) 19 | s = [(num_spaces * " ") + line for line in s] 20 | s = "\n".join(s) 21 | s = first + "\n" + s 22 | return s 23 | 24 | 25 | class NestedObject: 26 | _children_names: List[str] 27 | 28 | def extra_repr(self) -> str: 29 | return "" 30 | 31 | def __repr__(self): 32 | # We treat the extra repr like the sub-object, one item per line 33 | extra_lines = [] 34 | extra_repr = self.extra_repr() 35 | # empty string will be split into list [''] 36 | if extra_repr: 37 | extra_lines = extra_repr.split("\n") 38 | child_lines = [] 39 | if hasattr(self, "_children_names"): 40 | for key in self._children_names: 41 | child = getattr(self, key) 42 | if isinstance(child, list) and len(child) > 0: 43 | child_str = ",\n".join([repr(subchild) for subchild in child]) 44 | if len(child) > 1: 45 | child_str = _addindent(f"\n{child_str},", 2) + "\n" 46 | child_str = f"[{child_str}]" 47 | else: 48 | child_str = repr(child) 49 | child_str = _addindent(child_str, 2) 50 | child_lines.append("(" + key + "): " + child_str) 51 | lines = extra_lines + child_lines 52 | 53 | main_str = self.__class__.__name__ + "(" 54 | if lines: 55 | # simple one-liner info, which most builtin Modules will use 56 | if len(extra_lines) == 1 and not child_lines: 57 | main_str += extra_lines[0] 58 | else: 59 | main_str += "\n " + "\n ".join(lines) + "\n" 60 | 61 | main_str += ")" 62 | return main_str 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .vscode 131 | poetry.lock 132 | .DS_Store 133 | -------------------------------------------------------------------------------- /multiocr/main.py: -------------------------------------------------------------------------------- 1 | from multiocr.pipelines import AwsTextractOcr 2 | from multiocr.pipelines import TesseractOcr 3 | from multiocr.pipelines import PaddleOcr 4 | from multiocr.pipelines import EasyOcr 5 | from multiocr.pipelines import DoctrOCR 6 | from multiocr.base_class import OCR 7 | from typing import Union 8 | 9 | ENGINE_DICT = { 10 | "aws_textract":AwsTextractOcr, 11 | "tesseract":TesseractOcr, 12 | "paddle_ocr":PaddleOcr, 13 | "easy_ocr": EasyOcr, 14 | "doctr_ocr":DoctrOCR 15 | } 16 | 17 | avail_ocr_backends = list(ENGINE_DICT.keys()) 18 | 19 | class OcrEngineSelectionError(Exception): 20 | def __init__(self, msg:str) -> None: 21 | super().__init__() 22 | pass 23 | 24 | 25 | class OcrEngine(OCR): 26 | def __init__(self, engine:str, config:Union[dict, None]=None) -> None: 27 | self.config = config 28 | self.engine = ENGINE_DICT[engine](self.config) if engine in avail_ocr_backends else None 29 | if self.engine is None: 30 | raise OcrEngineSelectionError(f"only these ocr backends are available : {avail_ocr_backends}") 31 | 32 | def text_extraction(self, image_file): 33 | text_dict = self.engine.text_extraction(image_file) 34 | return text_dict 35 | 36 | def text_extraction_to_json(self, text_dict): 37 | json_dict = self.engine.text_extraction_to_json(text_dict) 38 | return json_dict 39 | 40 | def text_extraction_to_df(self, text_dict): 41 | df = self.engine.text_extraction_to_df(text_dict) 42 | return df 43 | 44 | def extract_plain_text(self, text_dict): 45 | plain_text = self.engine.extract_plain_text(text_dict) 46 | return plain_text 47 | 48 | if __name__ == "__main__": 49 | import os 50 | image_file = "/Users/aravindh/Documents/GitHub/multiocr/tests/data/test-european.jpg" 51 | paddle_config = { 52 | "lang":"en" 53 | } 54 | tess_config = { 55 | "lang": "eng", 56 | "config" : "--psm 6" 57 | } 58 | aws_textract_config = { 59 | "region_name":os.getenv("region_name"), 60 | "aws_access_key_id":os.getenv("aws_access_key_id"), 61 | "aws_secret_access_key":os.getenv("aws_secret_access_key") 62 | } 63 | 64 | easy_ocr_config = { 65 | "lang_list": ["en"] 66 | } 67 | engine = OcrEngine("paddle_ocr", paddle_config) 68 | text_dict = engine.text_extraction(image_file) 69 | json = engine.text_extraction_to_json(text_dict) 70 | df = engine.text_extraction_to_df(text_dict) 71 | plain_text = engine.extract_plain_text(text_dict) 72 | print() 73 | -------------------------------------------------------------------------------- 
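Regardless of which backend `main.py` above selects, every engine returns the same `text_dict` structure: a list of word-level dicts with `text`, `confidence` and `coordinates` keys. A minimal sketch of consuming that shared structure (the image path and the `easy_ocr` backend choice are placeholders, not part of the source):

```python
from multiocr import OcrEngine

# Every backend yields: [{"text": ..., "confidence": ...,
#   "coordinates": {"xmin": ..., "ymin": ..., "xmax": ..., "ymax": ...}}, ...]
engine = OcrEngine("easy_ocr", {"lang_list": ["en"]})
text_dict = engine.text_extraction("path/to/image.jpg")  # placeholder path

for word in text_dict:
    box = word["coordinates"]
    print(f"{word['text']!r} ({word['confidence']:.2f}) "
          f"-> [{box['xmin']}, {box['ymin']}, {box['xmax']}, {box['ymax']}]")
```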
/multiocr/pipelines/doctr_ocr/doctr/models/classification/zoo.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import Any, List 7 | 8 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import is_tf_available 9 | 10 | from .. import classification 11 | from ..preprocessor import PreProcessor 12 | from .predictor import CropOrientationPredictor 13 | 14 | __all__ = ["crop_orientation_predictor"] 15 | 16 | ARCHS: List[str] = [ 17 | "magc_resnet31", 18 | "mobilenet_v3_small", 19 | "mobilenet_v3_small_r", 20 | "mobilenet_v3_large", 21 | "mobilenet_v3_large_r", 22 | "resnet18", 23 | "resnet31", 24 | "resnet34", 25 | "resnet50", 26 | "resnet34_wide", 27 | "vgg16_bn_r", 28 | "vit_s", 29 | "vit_b", 30 | ] 31 | ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_orientation"] 32 | 33 | 34 | def _crop_orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> CropOrientationPredictor: 35 | if arch not in ORIENTATION_ARCHS: 36 | raise ValueError(f"unknown architecture '{arch}'") 37 | 38 | # Load directly classifier from backbone 39 | _model = classification.__dict__[arch](pretrained=pretrained) 40 | kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) 41 | kwargs["std"] = kwargs.get("std", _model.cfg["std"]) 42 | kwargs["batch_size"] = kwargs.get("batch_size", 64) 43 | input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:] 44 | predictor = CropOrientationPredictor( 45 | PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model 46 | ) 47 | return predictor 48 | 49 | 50 | def crop_orientation_predictor( 51 | arch: str = "mobilenet_v3_small_orientation", pretrained: bool = False, **kwargs: Any 52 | ) -> CropOrientationPredictor: 53 | """Orientation classification architecture. 54 | 55 | >>> import numpy as np 56 | >>> from doctr.models import crop_orientation_predictor 57 | >>> model = crop_orientation_predictor(arch='classif_mobilenet_v3_small', pretrained=True) 58 | >>> input_crop = (255 * np.random.rand(600, 800, 3)).astype(np.uint8) 59 | >>> out = model([input_crop]) 60 | 61 | Args: 62 | arch: name of the architecture to use (e.g. 'mobilenet_v3_small') 63 | pretrained: If True, returns a model pre-trained on our recognition crops dataset 64 | 65 | Returns: 66 | CropOrientationPredictor 67 | """ 68 | 69 | return _crop_orientation_predictor(arch, pretrained, **kwargs) 70 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/recognition/zoo.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import Any, List 7 | 8 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import is_tf_available 9 | from multiocr.pipelines.doctr_ocr.doctr.models.preprocessor import PreProcessor 10 | 11 | from .. 
import recognition 12 | from .predictor import RecognitionPredictor 13 | 14 | __all__ = ["recognition_predictor"] 15 | 16 | 17 | ARCHS: List[str] = [ 18 | "crnn_vgg16_bn", 19 | "crnn_mobilenet_v3_small", 20 | "crnn_mobilenet_v3_large", 21 | "sar_resnet31", 22 | "master", 23 | "vitstr_small", 24 | "vitstr_base", 25 | "parseq", 26 | ] 27 | 28 | 29 | def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredictor: 30 | if isinstance(arch, str): 31 | if arch not in ARCHS: 32 | raise ValueError(f"unknown architecture '{arch}'") 33 | 34 | _model = recognition.__dict__[arch]( 35 | pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True) 36 | ) 37 | else: 38 | if not isinstance( 39 | arch, (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq) 40 | ): 41 | raise ValueError(f"unknown architecture: {type(arch)}") 42 | _model = arch 43 | 44 | kwargs.pop("pretrained_backbone", None) 45 | 46 | kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) 47 | kwargs["std"] = kwargs.get("std", _model.cfg["std"]) 48 | kwargs["batch_size"] = kwargs.get("batch_size", 32) 49 | input_shape = _model.cfg["input_shape"][:2] if is_tf_available() else _model.cfg["input_shape"][-2:] 50 | predictor = RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs), _model) 51 | 52 | return predictor 53 | 54 | 55 | def recognition_predictor(arch: Any = "crnn_vgg16_bn", pretrained: bool = False, **kwargs: Any) -> RecognitionPredictor: 56 | """Text recognition architecture. 57 | 58 | Example:: 59 | >>> import numpy as np 60 | >>> from doctr.models import recognition_predictor 61 | >>> model = recognition_predictor(pretrained=True) 62 | >>> input_page = (255 * np.random.rand(32, 128, 3)).astype(np.uint8) 63 | >>> out = model([input_page]) 64 | 65 | Args: 66 | arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn') 67 | pretrained: If True, returns a model pre-trained on our text recognition dataset 68 | 69 | Returns: 70 | Recognition predictor 71 | """ 72 | 73 | return _predictor(arch, pretrained, **kwargs) 74 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/recognition/predictor/tensorflow.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | from typing import Any, List, Tuple, Union 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | 11 | from multiocr.pipelines.doctr_ocr.doctr.models.preprocessor import PreProcessor 12 | from multiocr.pipelines.doctr_ocr.doctr.utils.repr import NestedObject 13 | 14 | from ..core import RecognitionModel 15 | from ._utils import remap_preds, split_crops 16 | 17 | __all__ = ["RecognitionPredictor"] 18 | 19 | 20 | class RecognitionPredictor(NestedObject): 21 | """Implements an object able to identify character sequences in images 22 | 23 | Args: 24 | pre_processor: transform inputs for easier batched model inference 25 | model: core detection architecture 26 | split_wide_crops: wether to use crop splitting for high aspect ratio crops 27 | """ 28 | 29 | _children_names: List[str] = ["pre_processor", "model"] 30 | 31 | def __init__( 32 | self, 33 | pre_processor: PreProcessor, 34 | model: RecognitionModel, 35 | split_wide_crops: bool = True, 36 | ) -> None: 37 | super().__init__() 38 | self.pre_processor = pre_processor 39 | self.model = model 40 | self.split_wide_crops = split_wide_crops 41 | self.critical_ar = 8 # Critical aspect ratio 42 | self.dil_factor = 1.4 # Dilation factor to overlap the crops 43 | self.target_ar = 6 # Target aspect ratio 44 | 45 | def __call__( 46 | self, 47 | crops: List[Union[np.ndarray, tf.Tensor]], 48 | **kwargs: Any, 49 | ) -> List[Tuple[str, float]]: 50 | if len(crops) == 0: 51 | return [] 52 | # Dimension check 53 | if any(crop.ndim != 3 for crop in crops): 54 | raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.") 55 | 56 | # Split crops that are too wide 57 | remapped = False 58 | if self.split_wide_crops: 59 | new_crops, crop_map, remapped = split_crops(crops, self.critical_ar, self.target_ar, self.dil_factor) 60 | if remapped: 61 | crops = new_crops 62 | 63 | # Resize & batch them 64 | processed_batches = self.pre_processor(crops) 65 | 66 | # Forward it 67 | raw = [ 68 | self.model(batch, return_preds=True, training=False, **kwargs)["preds"] # type: ignore[operator] 69 | for batch in processed_batches 70 | ] 71 | 72 | # Process outputs 73 | out = [charseq for batch in raw for charseq in batch] 74 | 75 | # Remap crops 76 | if self.split_wide_crops and remapped: 77 | out = remap_preds(out, crop_map, self.dil_factor) 78 | 79 | return out 80 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/artefacts/barcode.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import List, Tuple 7 | 8 | import cv2 9 | import numpy as np 10 | 11 | __all__ = ["BarCodeDetector"] 12 | 13 | 14 | class BarCodeDetector: 15 | 16 | """Implements a Bar-code detector. 
17 | For now, only horizontal (or with a small angle) bar-codes are supported 18 | 19 | Args: 20 | min_size: minimum relative size of a barcode on the page 21 | canny_minval: lower bound for canny hysteresis 22 | canny_maxval: upper-bound for canny hysteresis 23 | """ 24 | 25 | def __init__(self, min_size: float = 1 / 6, canny_minval: int = 50, canny_maxval: int = 150) -> None: 26 | self.min_size = min_size 27 | self.canny_minval = canny_minval 28 | self.canny_maxval = canny_maxval 29 | 30 | def __call__( 31 | self, 32 | img: np.ndarray, 33 | ) -> List[Tuple[float, float, float, float]]: 34 | """Detect Barcodes on the image 35 | Args: 36 | img: np image 37 | 38 | Returns: 39 | A list of tuples: [(xmin, ymin, xmax, ymax), ...] containing barcodes rel. coordinates 40 | """ 41 | # get image size and define parameters 42 | height, width = img.shape[:2] 43 | k = (1 + int(width / 512)) * 10 # spatial extension of kernels, 512 -> 20, 1024 -> 30, ... 44 | min_w = int(width * self.min_size) # minimal size of a possible barcode 45 | 46 | # Detect edges 47 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 48 | edges = cv2.Canny(gray, self.canny_minval, self.canny_maxval, apertureSize=3) 49 | 50 | # Horizontal dilation to aggregate bars of the potential barcode 51 | # without aggregating text lines of the page vertically 52 | edges = cv2.dilate(edges, np.ones((1, k), np.uint8)) 53 | 54 | # Instantiate a barcode-shaped kernel and erode to keep only vertical-bar structures 55 | bar_code_kernel: np.ndarray = np.zeros((k, 3), np.uint8) 56 | bar_code_kernel[..., [0, 2]] = 1 57 | edges = cv2.erode(edges, bar_code_kernel, iterations=1) 58 | 59 | # Opening to remove noise 60 | edges = cv2.morphologyEx(edges, cv2.MORPH_OPEN, np.ones((k, k), np.uint8)) 61 | 62 | # Dilation to retrieve vertical length (lost at the first dilation) 63 | edges = cv2.dilate(edges, np.ones((k, 1), np.uint8)) 64 | 65 | # Find contours, and keep the widest as barcodes 66 | contours, _ = cv2.findContours(edges, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 67 | barcodes = [] 68 | for contour in contours: 69 | x, y, w, h = cv2.boundingRect(contour) 70 | if w >= min_w: 71 | barcodes.append((x / width, y / height, (x + w) / width, (y + h) / height)) 72 | 73 | return barcodes 74 | -------------------------------------------------------------------------------- /multiocr/pipelines/easy_ocr/engine.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | from easyocr import Reader 4 | from typing import Union 5 | 6 | class EasyOcr: 7 | def __init__(self, config: Union[dict, None]=None): 8 | self.config = config 9 | if not self.config: 10 | self.config = { 11 | "lang_list": ["en"] 12 | } 13 | self.ocr = Reader(**self.config) 14 | 15 | def text_extraction(self, image_file): 16 | try: 17 | result = self.ocr.readtext(image_file) 18 | self.raw_ocr = result 19 | except Exception as e: 20 | raise Exception(f"Error detecting text in image: {e}") 21 | 22 | text_dict = [] 23 | for detection in result: 24 | text = detection[1] 25 | confidence = detection[2] 26 | xmin = int(min([w[0] for w in detection[0]])) 27 | ymin = int(min([w[1] for w in detection[0]])) 28 | xmax = int(max([w[0] for w in detection[0]])) 29 | ymax = int(max([w[1] for w in detection[0]])) 30 | word_dict = { 31 | "text": text, 32 | "confidence": confidence, 33 | "coordinates": { 34 | "xmin": xmin, 35 | "ymin": ymin, 36 | "xmax": xmax, 37 | "ymax": ymax 38 | } 39 | } 40 | text_dict.append(word_dict) 41 | 42 | return text_dict 
43 | 44 | def text_extraction_to_json(self, text_dict): 45 | try: 46 | return json.dumps(text_dict) 47 | except Exception as e: 48 | raise Exception(f"Error converting text extraction to JSON: {e}") 49 | 50 | def text_extraction_to_df(self, text_dict): 51 | rows = [] 52 | for v in text_dict: 53 | rows.append([v['text'], v['confidence'], v['coordinates']['xmin'], v['coordinates']['ymin'], 54 | v['coordinates']['xmax'], v['coordinates']['ymax']]) 55 | df = pd.DataFrame(rows, columns=['text', 'confidence', 'xmin', 'ymin', 'xmax', 'ymax']) 56 | 57 | try: 58 | return df 59 | except Exception as e: 60 | raise Exception(f"Error converting text extraction to dataframe: {e}") 61 | 62 | def extract_plain_text(self, text_dict): 63 | plain_text = '' 64 | for v in text_dict: 65 | plain_text += v['text'] + ' ' 66 | 67 | try: 68 | return plain_text 69 | except Exception as e: 70 | raise Exception(f"Error converting text extraction to plain text: {e}") 71 | 72 | if __name__ == "__main__": 73 | config = { 74 | "lang_list": ["en"] 75 | } 76 | image_file = "/Users/aravindh/Documents/GitHub/multiocr/tests/data/test-european.jpg" 77 | ocr = EasyOcr(config) 78 | data = ocr.text_extraction(image_file) 79 | json_data = ocr.text_extraction_to_json(data) 80 | plain_text_data = ocr.extract_plain_text(data) 81 | pd_df = ocr.text_extraction_to_df(data) 82 | print() 83 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/obj_detection/faster_rcnn/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import Any, Dict 7 | 8 | from torchvision.models.detection import FasterRCNN, FasterRCNN_MobileNet_V3_Large_FPN_Weights, faster_rcnn 9 | 10 | from ...utils import load_pretrained_params 11 | 12 | __all__ = ["fasterrcnn_mobilenet_v3_large_fpn"] 13 | 14 | 15 | default_cfgs: Dict[str, Dict[str, Any]] = { 16 | "fasterrcnn_mobilenet_v3_large_fpn": { 17 | "input_shape": (3, 1024, 1024), 18 | "mean": (0.485, 0.456, 0.406), 19 | "std": (0.229, 0.224, 0.225), 20 | "classes": ["background", "qr_code", "bar_code", "logo", "photo"], 21 | "url": "https://doctr-static.mindee.com/models?id=v0.4.1/fasterrcnn_mobilenet_v3_large_fpn-d5b2490d.pt&src=0", 22 | }, 23 | } 24 | 25 | 26 | def _fasterrcnn(arch: str, pretrained: bool, **kwargs: Any) -> FasterRCNN: 27 | _kwargs = { 28 | "image_mean": default_cfgs[arch]["mean"], 29 | "image_std": default_cfgs[arch]["std"], 30 | "box_detections_per_img": 150, 31 | "box_score_thresh": 0.5, 32 | "box_positive_fraction": 0.35, 33 | "box_nms_thresh": 0.2, 34 | "rpn_nms_thresh": 0.2, 35 | "num_classes": len(default_cfgs[arch]["classes"]), 36 | } 37 | 38 | # Build the model 39 | _kwargs.update(kwargs) 40 | model = faster_rcnn.__dict__[arch](weights=None, weights_backbone=None, **_kwargs) 41 | model.cfg = default_cfgs[arch] 42 | 43 | if pretrained: 44 | # Load pretrained parameters 45 | load_pretrained_params(model, default_cfgs[arch]["url"]) 46 | else: 47 | # Filter keys 48 | state_dict = { 49 | k: v 50 | for k, v in faster_rcnn.__dict__[arch](weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.DEFAULT) 51 | .state_dict() 52 | .items() 53 | if not k.startswith("roi_heads.") 54 | } 55 | 56 | # Load state dict 57 | model.load_state_dict(state_dict, strict=False) 58 | 59 | return model 60 | 61 | 62 | def 
fasterrcnn_mobilenet_v3_large_fpn(pretrained: bool = False, **kwargs: Any) -> FasterRCNN: 63 | """Faster-RCNN architecture with a MobileNet V3 backbone as described in `"Faster R-CNN: Towards Real-Time 64 | Object Detection with Region Proposal Networks" `_. 65 | 66 | >>> import torch 67 | >>> from doctr.models.obj_detection import fasterrcnn_mobilenet_v3_large_fpn 68 | >>> model = fasterrcnn_mobilenet_v3_large_fpn(pretrained=True) 69 | >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) 70 | >>> out = model(input_tensor) 71 | 72 | Args: 73 | pretrained (bool): If True, returns a model pre-trained on our object detection dataset 74 | 75 | Returns: 76 | object detection architecture 77 | """ 78 | 79 | return _fasterrcnn("fasterrcnn_mobilenet_v3_large_fpn", pretrained, **kwargs) 80 | -------------------------------------------------------------------------------- /multiocr/pipelines/paddle_ocr/engine.py: -------------------------------------------------------------------------------- 1 | from multiocr.base_class import OCR 2 | import json 3 | import pandas as pd 4 | import paddleocr 5 | from PIL import Image 6 | from typing import Union 7 | 8 | class PaddleOcr: 9 | def __init__(self, config:Union[dict, None]=None): 10 | self.config = config 11 | if not self.config: 12 | self.config = { 13 | "lang":"en" 14 | } 15 | self.ocr = paddleocr.PaddleOCR(**self.config) 16 | 17 | def text_extraction(self, image_file): 18 | try: 19 | text = self.ocr.ocr(image_file) 20 | self.raw_ocr = text 21 | except Exception as e: 22 | raise Exception(f"Error detecting text in image: {e}") 23 | 24 | text_dict = [] 25 | for line in text: 26 | for word in line: 27 | xmin = min([w[0] for w in word[0]]) 28 | ymin = min([w[1] for w in word[0]]) 29 | xmax = max([w[0] for w in word[0]]) 30 | ymax = max([w[1] for w in word[0]]) 31 | word_dict = { 32 | "text": word[1][0], 33 | "confidence": word[1][1], 34 | "coordinates":{ 35 | "xmin":xmin, 36 | "ymin":ymin, 37 | "xmax":xmax, 38 | "ymax":ymax 39 | } 40 | } 41 | text_dict.append(word_dict) 42 | 43 | return text_dict 44 | 45 | def text_extraction_to_json(self, text_dict): 46 | try: 47 | return json.dumps(text_dict) 48 | except Exception as e: 49 | raise Exception(f"Error converting text extraction to JSON: {e}") 50 | 51 | def text_extraction_to_df(self, text_dict): 52 | rows = [] 53 | 54 | for v in text_dict: 55 | rows.append([v['text'], v['confidence'], v['coordinates']['xmin'], v['coordinates']['ymin'], 56 | v['coordinates']['xmax'], v['coordinates']['ymax']]) 57 | 58 | df = pd.DataFrame(rows, columns=['text', 'confidence', 'xmin', 'ymin', 'xmax', 'ymax']) 59 | 60 | try: 61 | return df 62 | except Exception as e: 63 | raise Exception(f"Error converting text extraction to dataframe: {e}") 64 | 65 | def extract_plain_text(self, text_dict): 66 | plain_text = '' 67 | 68 | for v in text_dict: 69 | plain_text += v['text'] + ' ' 70 | 71 | try: 72 | return plain_text 73 | except Exception as e: 74 | raise Exception(f"Error converting text extraction to plain text: {e}") 75 | 76 | if __name__ == "__main__": 77 | config = { 78 | "lang":"en" 79 | } 80 | image_file = "/Users/aravindh/Documents/GitHub/multiocr/tests/data/test-european.jpg" 81 | ocr = PaddleOcr(config) 82 | data = ocr.text_extraction(image_file) 83 | json_data = ocr.text_extraction_to_json(data) 84 | plain_text_data = ocr.extract_plain_text(data) 85 | pd_df = ocr.text_extraction_to_df(data) 86 | print() -------------------------------------------------------------------------------- 
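The `text_extraction` method of `PaddleOcr` above reduces each 4-point PaddleOCR polygon to an axis-aligned box by taking the min/max of its x and y coordinates. A small illustration of that reduction with an invented quadrilateral (no PaddleOCR call involved):

```python
# Invented 4-point word polygon [(x, y), ...], in the format PaddleOCR returns per word
polygon = [[12.0, 30.0], [118.0, 28.0], [119.0, 52.0], [13.0, 54.0]]

xmin = min(p[0] for p in polygon)  # 12.0
ymin = min(p[1] for p in polygon)  # 28.0
xmax = max(p[0] for p in polygon)  # 119.0
ymax = max(p[1] for p in polygon)  # 54.0

print({"xmin": xmin, "ymin": ymin, "xmax": xmax, "ymax": ymax})
```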
/multiocr/pipelines/doctr_ocr/doctr/models/recognition/predictor/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import Any, List, Sequence, Tuple, Union 7 | 8 | import numpy as np 9 | import torch 10 | from torch import nn 11 | 12 | from multiocr.pipelines.doctr_ocr.doctr.models.preprocessor import PreProcessor 13 | from multiocr.pipelines.doctr_ocr.doctr.models.utils import set_device_and_dtype 14 | 15 | from ._utils import remap_preds, split_crops 16 | 17 | __all__ = ["RecognitionPredictor"] 18 | 19 | 20 | class RecognitionPredictor(nn.Module): 21 | """Implements an object able to identify character sequences in images 22 | 23 | Args: 24 | pre_processor: transform inputs for easier batched model inference 25 | model: core detection architecture 26 | split_wide_crops: wether to use crop splitting for high aspect ratio crops 27 | """ 28 | 29 | def __init__( 30 | self, 31 | pre_processor: PreProcessor, 32 | model: nn.Module, 33 | split_wide_crops: bool = True, 34 | ) -> None: 35 | super().__init__() 36 | self.pre_processor = pre_processor 37 | self.model = model.eval() 38 | self.split_wide_crops = split_wide_crops 39 | self.critical_ar = 8 # Critical aspect ratio 40 | self.dil_factor = 1.4 # Dilation factor to overlap the crops 41 | self.target_ar = 6 # Target aspect ratio 42 | 43 | @torch.no_grad() 44 | def forward( 45 | self, 46 | crops: Sequence[Union[np.ndarray, torch.Tensor]], 47 | **kwargs: Any, 48 | ) -> List[Tuple[str, float]]: 49 | if len(crops) == 0: 50 | return [] 51 | # Dimension check 52 | if any(crop.ndim != 3 for crop in crops): 53 | raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.") 54 | 55 | # Split crops that are too wide 56 | remapped = False 57 | if self.split_wide_crops: 58 | new_crops, crop_map, remapped = split_crops( 59 | crops, # type: ignore[arg-type] 60 | self.critical_ar, 61 | self.target_ar, 62 | self.dil_factor, 63 | isinstance(crops[0], np.ndarray), 64 | ) 65 | if remapped: 66 | crops = new_crops 67 | 68 | # Resize & batch them 69 | processed_batches = self.pre_processor(crops) 70 | 71 | # Forward it 72 | _params = next(self.model.parameters()) 73 | self.model, processed_batches = set_device_and_dtype( 74 | self.model, processed_batches, _params.device, _params.dtype 75 | ) 76 | raw = [self.model(batch, return_preds=True, **kwargs)["preds"] for batch in processed_batches] 77 | 78 | # Process outputs 79 | out = [charseq for batch in raw for charseq in batch] 80 | 81 | # Remap crops 82 | if self.split_wide_crops and remapped: 83 | out = remap_preds(out, crop_map, self.dil_factor) 84 | 85 | return out 86 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/file_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | # Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py 7 | 8 | import importlib.util 9 | import logging 10 | import os 11 | import sys 12 | 13 | CLASS_NAME: str = "words" 14 | 15 | 16 | if sys.version_info < (3, 8): # pragma: no cover 17 | import importlib_metadata 18 | else: 19 | import importlib.metadata as importlib_metadata 20 | 21 | 22 | __all__ = ["is_tf_available", "is_torch_available", "CLASS_NAME"] 23 | 24 | ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} 25 | ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) 26 | 27 | USE_TF = os.environ.get("USE_TF", "AUTO").upper() 28 | USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() 29 | 30 | 31 | if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: 32 | _torch_available = importlib.util.find_spec("torch") is not None 33 | if _torch_available: 34 | try: 35 | _torch_version = importlib_metadata.version("torch") 36 | logging.info(f"PyTorch version {_torch_version} available.") 37 | except importlib_metadata.PackageNotFoundError: # pragma: no cover 38 | _torch_available = False 39 | else: # pragma: no cover 40 | logging.info("Disabling PyTorch because USE_TF is set") 41 | _torch_available = False 42 | 43 | 44 | if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: 45 | _tf_available = importlib.util.find_spec("tensorflow") is not None 46 | if _tf_available: 47 | candidates = ( 48 | "tensorflow", 49 | "tensorflow-cpu", 50 | "tensorflow-gpu", 51 | "tf-nightly", 52 | "tf-nightly-cpu", 53 | "tf-nightly-gpu", 54 | "intel-tensorflow", 55 | "tensorflow-rocm", 56 | "tensorflow-macos", 57 | ) 58 | _tf_version = None 59 | # For the metadata, we have to look for both tensorflow and tensorflow-cpu 60 | for pkg in candidates: 61 | try: 62 | _tf_version = importlib_metadata.version(pkg) 63 | break 64 | except importlib_metadata.PackageNotFoundError: 65 | pass 66 | _tf_available = _tf_version is not None 67 | if _tf_available: 68 | if int(_tf_version.split(".")[0]) < 2: # type: ignore[union-attr] # pragma: no cover 69 | logging.info(f"TensorFlow found but with version {_tf_version}. DocTR requires version 2 minimum.") 70 | _tf_available = False 71 | else: 72 | logging.info(f"TensorFlow version {_tf_version} available.") 73 | else: # pragma: no cover 74 | logging.info("Disabling Tensorflow because USE_TORCH is set") 75 | _tf_available = False 76 | 77 | 78 | if not _torch_available and not _tf_available: # pragma: no cover 79 | raise ModuleNotFoundError( 80 | "DocTR requires either TensorFlow or PyTorch to be installed. Please ensure one of them" 81 | " is installed and that either USE_TF or USE_TORCH is enabled." 82 | ) 83 | 84 | 85 | def is_torch_available(): 86 | return _torch_available 87 | 88 | 89 | def is_tf_available(): 90 | return _tf_available 91 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/classification/vgg/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | from copy import deepcopy 7 | from typing import Any, Dict, List, Optional 8 | 9 | from torch import nn 10 | from torchvision.models import vgg as tv_vgg 11 | 12 | from .....doctr.datasets import VOCABS 13 | 14 | from ...utils import load_pretrained_params 15 | 16 | __all__ = ["vgg16_bn_r"] 17 | 18 | 19 | default_cfgs: Dict[str, Dict[str, Any]] = { 20 | "vgg16_bn_r": { 21 | "mean": (0.694, 0.695, 0.693), 22 | "std": (0.299, 0.296, 0.301), 23 | "input_shape": (3, 32, 32), 24 | "classes": list(VOCABS["french"]), 25 | "url": "https://doctr-static.mindee.com/models?id=v0.4.1/vgg16_bn_r-d108c19c.pt&src=0", 26 | }, 27 | } 28 | 29 | 30 | def _vgg( 31 | arch: str, 32 | pretrained: bool, 33 | tv_arch: str, 34 | num_rect_pools: int = 3, 35 | ignore_keys: Optional[List[str]] = None, 36 | **kwargs: Any, 37 | ) -> tv_vgg.VGG: 38 | kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"])) 39 | kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"]) 40 | 41 | _cfg = deepcopy(default_cfgs[arch]) 42 | _cfg["num_classes"] = kwargs["num_classes"] 43 | _cfg["classes"] = kwargs["classes"] 44 | kwargs.pop("classes") 45 | 46 | # Build the model 47 | model = tv_vgg.__dict__[tv_arch](**kwargs) 48 | # List the MaxPool2d 49 | pool_idcs = [idx for idx, m in enumerate(model.features) if isinstance(m, nn.MaxPool2d)] 50 | # Replace their kernel with rectangular ones 51 | for idx in pool_idcs[-num_rect_pools:]: 52 | model.features[idx] = nn.MaxPool2d((2, 1)) 53 | # Patch average pool & classification head 54 | model.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 55 | model.classifier = nn.Linear(512, kwargs["num_classes"]) 56 | # Load pretrained parameters 57 | if pretrained: 58 | # The number of classes is not the same as the number of classes in the pretrained model => 59 | # remove the last layer weights 60 | _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None 61 | load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys) 62 | 63 | model.cfg = _cfg 64 | 65 | return model 66 | 67 | 68 | def vgg16_bn_r(pretrained: bool = False, **kwargs: Any) -> tv_vgg.VGG: 69 | """VGG-16 architecture as described in `"Very Deep Convolutional Networks for Large-Scale Image Recognition" 70 | `_, modified by adding batch normalization, rectangular pooling and a simpler 71 | classification head. 
72 | 73 | >>> import torch 74 | >>> from doctr.models import vgg16_bn_r 75 | >>> model = vgg16_bn_r(pretrained=False) 76 | >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32) 77 | >>> out = model(input_tensor) 78 | 79 | Args: 80 | pretrained (bool): If True, returns a model pre-trained on ImageNet 81 | 82 | Returns: 83 | VGG feature extractor 84 | """ 85 | 86 | return _vgg( 87 | "vgg16_bn_r", 88 | pretrained, 89 | "vgg16_bn", 90 | 3, 91 | ignore_keys=["classifier.weight", "classifier.bias"], 92 | **kwargs, 93 | ) 94 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/engine.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | from PIL import Image 4 | from typing import Union 5 | from .doctr.models import ocr_predictor 6 | import numpy as np 7 | 8 | class DoctrOCR: 9 | def __init__(self, config: Union[dict, None] = None): 10 | self.config = config 11 | if not self.config: 12 | self.config = {} 13 | self.model = ocr_predictor(pretrained=True, **self.config) 14 | 15 | def text_extraction(self, image_file): 16 | try: 17 | if isinstance(image_file, str): 18 | image = np.array(Image.open(image_file).convert("RGB")) 19 | elif isinstance(image_file, Image): 20 | image = np.array(image_file.convert("RGB")) 21 | 22 | result = self.model([image]) 23 | self.raw_ocr = result 24 | 25 | text_dict = [] 26 | for page in result.pages: 27 | h,w = page.dimensions 28 | for block in page.blocks: 29 | for line in block.lines: 30 | for word in line.words: 31 | text = word.value.strip() 32 | if text: 33 | confidence = word.confidence 34 | box = word.geometry 35 | text_dict.append({ 36 | 'text': text, 37 | 'confidence': confidence, 38 | 'coordinates': { 39 | 'xmin': box[0][0]*w, 40 | 'ymin': box[0][1]*h, 41 | 'xmax': box[1][0]*w, 42 | 'ymax': box[1][1]*h 43 | } 44 | }) 45 | return text_dict 46 | except Exception as e: 47 | raise Exception(f"Error detecting text in image: {e}") 48 | 49 | def text_extraction_to_json(self, text_dict): 50 | try: 51 | return json.dumps(text_dict) 52 | except Exception as e: 53 | raise Exception(f"Error converting text extraction to JSON: {e}") 54 | 55 | def text_extraction_to_df(self, text_dict): 56 | try: 57 | rows = [] 58 | 59 | for v in text_dict: 60 | rows.append([v['text'], v['confidence'], v['coordinates']['xmin'], v['coordinates']['ymin'], 61 | v['coordinates']['xmax'], v['coordinates']['ymax']]) 62 | 63 | df = pd.DataFrame(rows, columns=['text', 'confidence', 'xmin', 'ymin', 'xmax', 'ymax']) 64 | 65 | return df 66 | except Exception as e: 67 | raise Exception(f"Error converting text extraction to dataframe: {e}") 68 | 69 | def extract_plain_text(self, text_dict): 70 | try: 71 | plain_text = '' 72 | 73 | for v in text_dict: 74 | plain_text += v['text'] + ' ' 75 | 76 | return plain_text 77 | except Exception as e: 78 | raise Exception(f"Error converting text extraction to plain text: {e}") 79 | 80 | if __name__ == "__main__": 81 | image_file = "/Users/aravindh/Documents/GitHub/multiocr/tests/data/test-european.jpg" 82 | engine = DoctrOCR() 83 | text_dict = engine.text_extraction(image_file) 84 | json_op = engine.text_extraction_to_json(text_dict) 85 | df = engine.text_extraction_to_df(text_dict) 86 | plain_text = engine.extract_plain_text(text_dict) 87 | print() 88 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/io/image/tensorflow.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import Tuple 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | from PIL import Image 11 | from tensorflow.keras.utils import img_to_array 12 | 13 | from multiocr.pipelines.doctr_ocr.doctr.utils.common_types import AbstractPath 14 | 15 | __all__ = ["tensor_from_pil", "read_img_as_tensor", "decode_img_as_tensor", "tensor_from_numpy", "get_img_shape"] 16 | 17 | 18 | def tensor_from_pil(pil_img: Image, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor: 19 | """Convert a PIL Image to a TensorFlow tensor 20 | 21 | Args: 22 | pil_img: a PIL image 23 | dtype: the output tensor data type 24 | 25 | Returns: 26 | decoded image as tensor 27 | """ 28 | 29 | npy_img = img_to_array(pil_img) 30 | 31 | return tensor_from_numpy(npy_img, dtype) 32 | 33 | 34 | def read_img_as_tensor(img_path: AbstractPath, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor: 35 | """Read an image file as a TensorFlow tensor 36 | 37 | Args: 38 | img_path: location of the image file 39 | dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. 40 | 41 | Returns: 42 | decoded image as a tensor 43 | """ 44 | 45 | if dtype not in (tf.uint8, tf.float16, tf.float32): 46 | raise ValueError("insupported value for dtype") 47 | 48 | img = tf.io.read_file(img_path) 49 | img = tf.image.decode_jpeg(img, channels=3) 50 | 51 | if dtype != tf.uint8: 52 | img = tf.image.convert_image_dtype(img, dtype=dtype) 53 | img = tf.clip_by_value(img, 0, 1) 54 | 55 | return img 56 | 57 | 58 | def decode_img_as_tensor(img_content: bytes, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor: 59 | """Read a byte stream as a TensorFlow tensor 60 | 61 | Args: 62 | img_content: bytes of a decoded image 63 | dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. 64 | 65 | Returns: 66 | decoded image as a tensor 67 | """ 68 | 69 | if dtype not in (tf.uint8, tf.float16, tf.float32): 70 | raise ValueError("insupported value for dtype") 71 | 72 | img = tf.io.decode_image(img_content, channels=3) 73 | 74 | if dtype != tf.uint8: 75 | img = tf.image.convert_image_dtype(img, dtype=dtype) 76 | img = tf.clip_by_value(img, 0, 1) 77 | 78 | return img 79 | 80 | 81 | def tensor_from_numpy(npy_img: np.ndarray, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor: 82 | """Read an image file as a TensorFlow tensor 83 | 84 | Args: 85 | img: image encoded as a numpy array of shape (H, W, C) in np.uint8 86 | dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. 
87 | 88 | Returns: 89 | same image as a tensor of shape (H, W, C) 90 | """ 91 | 92 | if dtype not in (tf.uint8, tf.float16, tf.float32): 93 | raise ValueError("unsupported value for dtype") 94 | 95 | if dtype == tf.uint8: 96 | img = tf.convert_to_tensor(npy_img, dtype=dtype) 97 | else: 98 | img = tf.image.convert_image_dtype(npy_img, dtype=dtype) 99 | img = tf.clip_by_value(img, 0, 1) 100 | 101 | return img 102 | 103 | 104 | def get_img_shape(img: tf.Tensor) -> Tuple[int, int]: 105 | return img.shape[:2] 106 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multiocr 2 | This package provides a common interface for multiple OCR backends. 3 | 4 | # Installation 5 | ``` 6 | pip install multiocr 7 | ``` 8 | # Supported OCR Backends 9 | 10 | - Tesseract 11 | - PaddleOCR 12 | - Aws Textract 13 | - EasyOCR 14 | - Doctr-Ocr 15 | 16 | The output format is the same across all OCR backends. 17 | 18 | # Code Example 19 | **Tesseract** 20 | ```python 21 | from multiocr import OcrEngine 22 | 23 | config = { 24 | "lang": "eng", 25 | "config" : "--psm 6" 26 | } 27 | image_file = "path/to/image.jpg" 28 | engine = OcrEngine("tesseract", config) 29 | text_dict = engine.text_extraction(image_file) 30 | json = engine.text_extraction_to_json(text_dict) 31 | df = engine.text_extraction_to_df(text_dict) 32 | plain_text = engine.extract_plain_text(text_dict) 33 | ``` 34 | **PaddleOCR** 35 | ```python 36 | from multiocr import OcrEngine 37 | 38 | config = { 39 | "lang":"en" 40 | } 41 | image_file = "path/to/image.jpg" 42 | engine = OcrEngine("paddle_ocr", config) 43 | text_dict = engine.text_extraction(image_file) 44 | json = engine.text_extraction_to_json(text_dict) 45 | df = engine.text_extraction_to_df(text_dict) 46 | plain_text = engine.extract_plain_text(text_dict) 47 | ``` 48 | **Aws Textract** 49 | ```python 50 | from multiocr import OcrEngine 51 | 52 | config = { 53 | "region_name":os.getenv("region_name"), 54 | "aws_access_key_id":os.getenv("aws_access_key_id"), 55 | "aws_secret_access_key":os.getenv("aws_secret_access_key") 56 | } 57 | image_file = "path/to/image.jpg" 58 | 59 | engine = OcrEngine("aws_textract", config) 60 | text_dict = engine.text_extraction(image_file) 61 | json = engine.text_extraction_to_json(text_dict) 62 | df = engine.text_extraction_to_df(text_dict) 63 | plain_text = engine.extract_plain_text(text_dict) 64 | ``` 65 | 66 | **EasyOCR** 67 | ```python 68 | from multiocr import OcrEngine 69 | 70 | config = { 71 | "lang_list": ["en"] 72 | } 73 | image_file = "path/to/image.jpg" 74 | engine = OcrEngine("easy_ocr", config) 75 | text_dict = engine.text_extraction(image_file) 76 | json = engine.text_extraction_to_json(text_dict) 77 | df = engine.text_extraction_to_df(text_dict) 78 | plain_text = engine.extract_plain_text(text_dict) 79 | ``` 80 | 81 | **Doctr-Ocr** 82 | ```python 83 | from multiocr import OcrEngine 84 | 85 | image_file = "path/to/image.jpg" 86 | engine = OcrEngine("doctr_ocr") 87 | text_dict = engine.text_extraction(image_file) 88 | json = engine.text_extraction_to_json(text_dict) 89 | df = engine.text_extraction_to_df(text_dict) 90 | plain_text = engine.extract_plain_text(text_dict) 91 | 92 | ``` 93 | 94 | If you want to access the output of each OCR engine in its own raw format, you can fetch it this way: 95 | 96 | ``` 97 | raw_ocr_output = engine.engine.raw_ocr 98 | ``` 99 | 100 | **config** holds each OCR's input parameters and should be a Python dictionary. If not given, it defaults to each library's default parameters.
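For instance, a minimal sketch relying on those defaults (assuming Tesseract is installed locally; the image path is a placeholder):

```python
from multiocr import OcrEngine

# No config passed: the Tesseract backend falls back to {"lang": "eng"}
engine = OcrEngine("tesseract")
text_dict = engine.text_extraction("path/to/image.jpg")
```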
101 | 102 | The input parameters for each OCR differ, and you can look at its respective repo for all allowable parameters. 103 | 104 | # Reference & Acknowledgements 105 | 106 | - [Pytesseract](https://github.com/madmaze/pytesseract) 107 | - [Tesseract](https://github.com/tesseract-ocr/tesseract) 108 | - [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR) 109 | - [AWS Textract](https://docs.aws.amazon.com/textract/latest/dg/what-is.html) 110 | - [EasyOCR](https://www.jaided.ai/easyocr/) 111 | - [Doctr-Ocr](https://github.com/mindee/doctr) 112 | 113 | # WIP - OCR Backends 114 | - [ ] MMOCR 115 | - [ ] Google Vision 116 | - [ ] Azure OCR 117 | - [ ] DocTR 118 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/io/image/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from io import BytesIO 7 | from typing import Tuple 8 | 9 | import numpy as np 10 | import torch 11 | from PIL import Image 12 | from torchvision.transforms.functional import to_tensor 13 | 14 | from ....doctr.utils.common_types import AbstractPath 15 | 16 | __all__ = ["tensor_from_pil", "read_img_as_tensor", "decode_img_as_tensor", "tensor_from_numpy", "get_img_shape"] 17 | 18 | 19 | def tensor_from_pil(pil_img: Image, dtype: torch.dtype = torch.float32) -> torch.Tensor: 20 | """Convert a PIL Image to a PyTorch tensor 21 | 22 | Args: 23 | pil_img: a PIL image 24 | dtype: the output tensor data type 25 | 26 | Returns: 27 | decoded image as tensor 28 | """ 29 | 30 | if dtype == torch.float32: 31 | img = to_tensor(pil_img) 32 | else: 33 | img = tensor_from_numpy(np.array(pil_img, np.uint8, copy=True), dtype) 34 | 35 | return img 36 | 37 | 38 | def read_img_as_tensor(img_path: AbstractPath, dtype: torch.dtype = torch.float32) -> torch.Tensor: 39 | """Read an image file as a PyTorch tensor 40 | 41 | Args: 42 | img_path: location of the image file 43 | dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. 44 | 45 | Returns: 46 | decoded image as a tensor 47 | """ 48 | 49 | if dtype not in (torch.uint8, torch.float16, torch.float32): 50 | raise ValueError("unsupported value for dtype") 51 | 52 | pil_img = Image.open(img_path, mode="r").convert("RGB") 53 | 54 | return tensor_from_pil(pil_img, dtype) 55 | 56 | 57 | def decode_img_as_tensor(img_content: bytes, dtype: torch.dtype = torch.float32) -> torch.Tensor: 58 | """Read a byte stream as a PyTorch tensor 59 | 60 | Args: 61 | img_content: bytes of a decoded image 62 | dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255.
63 | 64 | Returns: 65 | decoded image as a tensor 66 | """ 67 | 68 | if dtype not in (torch.uint8, torch.float16, torch.float32): 69 | raise ValueError("insupported value for dtype") 70 | 71 | pil_img = Image.open(BytesIO(img_content), mode="r").convert("RGB") 72 | 73 | return tensor_from_pil(pil_img, dtype) 74 | 75 | 76 | def tensor_from_numpy(npy_img: np.ndarray, dtype: torch.dtype = torch.float32) -> torch.Tensor: 77 | """Read an image file as a PyTorch tensor 78 | 79 | Args: 80 | img: image encoded as a numpy array of shape (H, W, C) in np.uint8 81 | dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. 82 | 83 | Returns: 84 | same image as a tensor of shape (C, H, W) 85 | """ 86 | 87 | if dtype not in (torch.uint8, torch.float16, torch.float32): 88 | raise ValueError("insupported value for dtype") 89 | 90 | if dtype == torch.float32: 91 | img = to_tensor(npy_img) 92 | else: 93 | img = torch.from_numpy(npy_img) 94 | # put it from HWC to CHW format 95 | img = img.permute((2, 0, 1)).contiguous() 96 | if dtype == torch.float16: 97 | # Switch to FP16 98 | img = img.to(dtype=torch.float16).div(255) 99 | 100 | return img 101 | 102 | 103 | def get_img_shape(img: torch.Tensor) -> Tuple[int, int]: 104 | return img.shape[-2:] # type: ignore[return-value] 105 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/detection/zoo.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import Any, List 7 | 8 | from multiocr.pipelines.doctr_ocr.doctr.file_utils import is_tf_available, is_torch_available 9 | 10 | from .. import detection 11 | from ..preprocessor import PreProcessor 12 | from .predictor import DetectionPredictor 13 | 14 | __all__ = ["detection_predictor"] 15 | 16 | ARCHS: List[str] 17 | ROT_ARCHS: List[str] 18 | 19 | 20 | if is_tf_available(): 21 | ARCHS = ["db_resnet50", "db_mobilenet_v3_large", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"] 22 | ROT_ARCHS = ["linknet_resnet18_rotation"] 23 | elif is_torch_available(): 24 | ARCHS = [ 25 | "db_resnet34", 26 | "db_resnet50", 27 | "db_mobilenet_v3_large", 28 | "linknet_resnet18", 29 | "linknet_resnet34", 30 | "linknet_resnet50", 31 | ] 32 | ROT_ARCHS = ["db_resnet50_rotation"] 33 | 34 | 35 | def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor: 36 | if isinstance(arch, str): 37 | if arch not in ARCHS + ROT_ARCHS: 38 | raise ValueError(f"unknown architecture '{arch}'") 39 | 40 | if arch not in ROT_ARCHS and not assume_straight_pages: 41 | raise AssertionError( 42 | "You are trying to use a model trained on straight pages while not assuming" 43 | " your pages are straight. 
If you have only straight documents, don't pass" 44 | " assume_straight_pages=False, otherwise you should use one of these archs:" 45 | f"{ROT_ARCHS}" 46 | ) 47 | 48 | _model = detection.__dict__[arch]( 49 | pretrained=pretrained, 50 | pretrained_backbone=kwargs.get("pretrained_backbone", True), 51 | assume_straight_pages=assume_straight_pages, 52 | ) 53 | else: 54 | if not isinstance(arch, (detection.DBNet, detection.LinkNet)): 55 | raise ValueError(f"unknown architecture: {type(arch)}") 56 | 57 | _model = arch 58 | _model.assume_straight_pages = assume_straight_pages 59 | 60 | kwargs.pop("pretrained_backbone", None) 61 | 62 | kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) 63 | kwargs["std"] = kwargs.get("std", _model.cfg["std"]) 64 | kwargs["batch_size"] = kwargs.get("batch_size", 1) 65 | predictor = DetectionPredictor( 66 | PreProcessor(_model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:], **kwargs), 67 | _model, 68 | ) 69 | return predictor 70 | 71 | 72 | def detection_predictor( 73 | arch: Any = "db_resnet50", 74 | pretrained: bool = False, 75 | assume_straight_pages: bool = True, 76 | **kwargs: Any, 77 | ) -> DetectionPredictor: 78 | """Text detection architecture. 79 | 80 | >>> import numpy as np 81 | >>> from doctr.models import detection_predictor 82 | >>> model = detection_predictor(arch='db_resnet50', pretrained=True) 83 | >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8) 84 | >>> out = model([input_page]) 85 | 86 | Args: 87 | arch: name of the architecture or model itself to use (e.g. 'db_resnet50') 88 | pretrained: If True, returns a model pre-trained on our text detection dataset 89 | assume_straight_pages: If True, fit straight boxes to the page 90 | 91 | Returns: 92 | Detection predictor 93 | """ 94 | 95 | return _predictor(arch, pretrained, assume_straight_pages, **kwargs) 96 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/recognition/predictor/_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | from typing import List, Tuple, Union 7 | 8 | import numpy as np 9 | 10 | from ..utils import merge_multi_strings 11 | 12 | __all__ = ["split_crops", "remap_preds"] 13 | 14 | 15 | def split_crops( 16 | crops: List[np.ndarray], 17 | max_ratio: float, 18 | target_ratio: int, 19 | dilation: float, 20 | channels_last: bool = True, 21 | ) -> Tuple[List[np.ndarray], List[Union[int, Tuple[int, int]]], bool]: 22 | """Chunk crops horizontally to match a given aspect ratio 23 | 24 | Args: 25 | crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise 26 | max_ratio: the maximum aspect ratio that won't trigger the chunk 27 | target_ratio: when crops are chunked, they will be chunked to match this aspect ratio 28 | dilation: the width dilation of final chunks (to provide some overlaps) 29 | channels_last: whether the numpy array has dimensions in channels last order 30 | 31 | Returns: 32 | a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required 33 | """ 34 | 35 | _remap_required = False 36 | crop_map: List[Union[int, Tuple[int, int]]] = [] 37 | new_crops: List[np.ndarray] = [] 38 | for crop in crops: 39 | h, w = crop.shape[:2] if channels_last else crop.shape[-2:] 40 | aspect_ratio = w / h 41 | if aspect_ratio > max_ratio: 42 | # Determine the number of crops, reference aspect ratio = 4 = 128 / 32 43 | num_subcrops = int(aspect_ratio // target_ratio) 44 | # Find the new widths, additional dilation factor to overlap crops 45 | width = dilation * w / num_subcrops 46 | centers = [(w / num_subcrops) * (1 / 2 + idx) for idx in range(num_subcrops)] 47 | # Get the crops 48 | if channels_last: 49 | _crops = [ 50 | crop[:, max(0, int(round(center - width / 2))) : min(w - 1, int(round(center + width / 2))), :] 51 | for center in centers 52 | ] 53 | else: 54 | _crops = [ 55 | crop[:, :, max(0, int(round(center - width / 2))) : min(w - 1, int(round(center + width / 2)))] 56 | for center in centers 57 | ] 58 | # Avoid sending zero-sized crops 59 | _crops = [crop for crop in _crops if all(s > 0 for s in crop.shape)] 60 | # Record the slice of crops 61 | crop_map.append((len(new_crops), len(new_crops) + len(_crops))) 62 | new_crops.extend(_crops) 63 | # At least one crop will require merging 64 | _remap_required = True 65 | else: 66 | crop_map.append(len(new_crops)) 67 | new_crops.append(crop) 68 | 69 | return new_crops, crop_map, _remap_required 70 | 71 | 72 | def remap_preds( 73 | preds: List[Tuple[str, float]], crop_map: List[Union[int, Tuple[int, int]]], dilation: float 74 | ) -> List[Tuple[str, float]]: 75 | remapped_out = [] 76 | for _idx in crop_map: 77 | # Crop hasn't been split 78 | if isinstance(_idx, int): 79 | remapped_out.append(preds[_idx]) 80 | else: 81 | # unzip 82 | vals, probs = zip(*preds[_idx[0] : _idx[1]]) 83 | # Merge the string values 84 | remapped_out.append((merge_multi_strings(vals, dilation), min(probs))) # type: ignore[arg-type] 85 | return remapped_out 86 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/detection/core.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | from typing import List 7 | 8 | import cv2 9 | import numpy as np 10 | 11 | from multiocr.pipelines.doctr_ocr.doctr.utils.repr import NestedObject 12 | 13 | __all__ = ["DetectionPostProcessor"] 14 | 15 | 16 | class DetectionPostProcessor(NestedObject): 17 | """Abstract class to postprocess the raw output of the model 18 | 19 | Args: 20 | box_thresh (float): minimal objectness score to consider a box 21 | bin_thresh (float): threshold to apply to segmentation raw heatmap 22 | assume straight_pages (bool): if True, fit straight boxes only 23 | """ 24 | 25 | def __init__(self, box_thresh: float = 0.5, bin_thresh: float = 0.5, assume_straight_pages: bool = True) -> None: 26 | self.box_thresh = box_thresh 27 | self.bin_thresh = bin_thresh 28 | self.assume_straight_pages = assume_straight_pages 29 | self._opening_kernel: np.ndarray = np.ones((3, 3), dtype=np.uint8) 30 | 31 | def extra_repr(self) -> str: 32 | return f"bin_thresh={self.bin_thresh}, box_thresh={self.box_thresh}" 33 | 34 | @staticmethod 35 | def box_score(pred: np.ndarray, points: np.ndarray, assume_straight_pages: bool = True) -> float: 36 | """Compute the confidence score for a polygon : mean of the p values on the polygon 37 | 38 | Args: 39 | pred (np.ndarray): p map returned by the model 40 | 41 | Returns: 42 | polygon objectness 43 | """ 44 | h, w = pred.shape[:2] 45 | 46 | if assume_straight_pages: 47 | xmin = np.clip(np.floor(points[:, 0].min()).astype(np.int32), 0, w - 1) 48 | xmax = np.clip(np.ceil(points[:, 0].max()).astype(np.int32), 0, w - 1) 49 | ymin = np.clip(np.floor(points[:, 1].min()).astype(np.int32), 0, h - 1) 50 | ymax = np.clip(np.ceil(points[:, 1].max()).astype(np.int32), 0, h - 1) 51 | return pred[ymin : ymax + 1, xmin : xmax + 1].mean() 52 | 53 | else: 54 | mask: np.ndarray = np.zeros((h, w), np.int32) 55 | cv2.fillPoly(mask, [points.astype(np.int32)], 1.0) 56 | product = pred * mask 57 | return np.sum(product) / np.count_nonzero(product) 58 | 59 | def bitmap_to_boxes( 60 | self, 61 | pred: np.ndarray, 62 | bitmap: np.ndarray, 63 | ) -> np.ndarray: 64 | raise NotImplementedError 65 | 66 | def __call__( 67 | self, 68 | proba_map, 69 | ) -> List[List[np.ndarray]]: 70 | """Performs postprocessing for a list of model outputs 71 | 72 | Args: 73 | proba_map: probability map of shape (N, H, W, C) 74 | 75 | Returns: 76 | list of N class predictions (for each input sample), where each class predictions is a list of C tensors 77 | of shape (*, 5) or (*, 6) 78 | """ 79 | 80 | if proba_map.ndim != 4: 81 | raise AssertionError(f"arg `proba_map` is expected to be 4-dimensional, got {proba_map.ndim}.") 82 | 83 | # Erosion + dilation on the binary map 84 | bin_map = [ 85 | [ 86 | cv2.morphologyEx(bmap[..., idx], cv2.MORPH_OPEN, self._opening_kernel) 87 | for idx in range(proba_map.shape[-1]) 88 | ] 89 | for bmap in (proba_map >= self.bin_thresh).astype(np.uint8) 90 | ] 91 | 92 | return [ 93 | [self.bitmap_to_boxes(pmaps[..., idx], bmaps[idx]) for idx in range(proba_map.shape[-1])] 94 | for pmaps, bmaps in zip(proba_map, bin_map) 95 | ] 96 | -------------------------------------------------------------------------------- /multiocr/pipelines/tesseract/engine.py: -------------------------------------------------------------------------------- 1 | from multiocr.base_class import OCR 2 | import json 3 | import pytesseract 4 | import pandas as pd 5 | from PIL import Image 6 | from typing import Union 7 | 8 | class TesseractOcr(OCR): 9 | """ 10 | The TextractOcr class takes an image file path as input. 
It has four methods: 11 | 12 | text_extraction(): This method extracts text from the image using Tesseract OCR and returns the text as a dictionary with the block IDs as keys and the text, confidence score, and bounding box coordinates as values. 13 | text_extraction_to_json(text_dict): This method takes the dictionary output from text_extraction() as input and saves it to a JSON file. 14 | text_extraction_to_df(text_dict): This method takes the dictionary output from text_extraction() as input and saves it to a Pandas DataFrame with columns for the text, confidence score, and bounding box coordinates. 15 | extract_plain_text(text_dict): This method takes the dictionary output from text_extraction() as input and saves the plain text to a text file. 16 | """ 17 | def __init__(self, config:Union[dict, None]=None): 18 | self.config = config 19 | if not self.config: 20 | self.config = { 21 | "lang": "eng" 22 | } 23 | self.config.pop("output_type", None) 24 | 25 | def text_extraction(self, image_file): 26 | try: 27 | text = pytesseract.image_to_data( Image.open(image_file), output_type='dict', **self.config) 28 | self.raw_ocr = text 29 | except Exception as e: 30 | raise Exception(f"Error detecting text in image: {e}") 31 | 32 | text_dict = [] 33 | 34 | for i in range(len(text['text'])): 35 | if text['conf'][i] > -1: 36 | text_dict.append({'text': text['text'][i], 'confidence': text['conf'][i], 37 | 'coordinates': {'xmin': text['left'][i], 'ymin': text['top'][i], 38 | 'xmax': text['width'][i]+text['left'][i], 'ymax': text['height'][i]+text['top'][i]}}) 39 | 40 | return text_dict 41 | 42 | def text_extraction_to_json(self, text_dict): 43 | try: 44 | return json.dumps(text_dict) 45 | except Exception as e: 46 | raise Exception(f"Error converting text extraction to JSON: {e}") 47 | 48 | def text_extraction_to_df(self, text_dict): 49 | rows = [] 50 | 51 | for v in text_dict: 52 | rows.append([v['text'], v['confidence'], v['coordinates']['xmin'], v['coordinates']['ymin'], 53 | v['coordinates']['xmax'], v['coordinates']['ymax']]) 54 | 55 | df = pd.DataFrame(rows, columns=['text', 'confidence', 'xmin', 'ymin', 'xmax', 'ymax']) 56 | 57 | try: 58 | return df 59 | except Exception as e: 60 | raise Exception(f"Error converting text extraction to dataframe: {e}") 61 | 62 | def extract_plain_text(self, text_dict): 63 | plain_text = '' 64 | 65 | for v in text_dict: 66 | plain_text += v['text'] + ' ' 67 | 68 | try: 69 | return plain_text 70 | except Exception as e: 71 | raise Exception(f"Error converting text extraction to plain text: {e}") 72 | 73 | if __name__ == "__main__": 74 | image_file = "/Users/aravindh/Documents/GitHub/multiocr/tests/data/test-european.jpg" 75 | config = { 76 | "lang": "eng", 77 | "config" : "--psm 6" 78 | } 79 | engine = TesseractOcr(config) 80 | text_dict = engine.text_extraction(image_file) 81 | json_op = engine.text_extraction_to_json(text_dict) 82 | df = engine.text_extraction_to_df(text_dict) 83 | plain_text = engine.extract_plain_text(text_dict) 84 | print() -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/recognition/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
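# --- Illustrative note (added; not part of the original module) ---------------
# The helpers below re-assemble the transcription of a wide word that was
# chunked into overlapping crops by `split_crops`
# (recognition/predictor/_utils.py). Expected behaviour, taken from the
# docstrings further down:
#
#   >>> merge_strings("abcd", "cdefgh", 1.4)
#   'abcdefgh'
#   >>> merge_multi_strings(["abc", "bcdef", "difghi", "aijkl"], 1.4)
#   'abcdefghijkl'
#
# The dilation factor (1.4 above) should match the one used when the crops were
# split, since it bounds the overlap estimate used when characters repeat.
# -------------------------------------------------------------------------------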
5 | 6 | from typing import List 7 | 8 | from rapidfuzz.distance import Levenshtein 9 | 10 | __all__ = ["merge_strings", "merge_multi_strings"] 11 | 12 | 13 | def merge_strings(a: str, b: str, dil_factor: float) -> str: 14 | """Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters. 15 | 16 | Args: 17 | a: first char seq, suffix should be similar to b's prefix. 18 | b: second char seq, prefix should be similar to a's suffix. 19 | dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is 20 | only used when the mother sequence is splitted on a character repetition 21 | 22 | Returns: 23 | A merged character sequence. 24 | 25 | Example:: 26 | >>> from doctr.model.recognition.utils import merge_sequences 27 | >>> merge_sequences('abcd', 'cdefgh', 1.4) 28 | 'abcdefgh' 29 | >>> merge_sequences('abcdi', 'cdefgh', 1.4) 30 | 'abcdefgh' 31 | """ 32 | seq_len = min(len(a), len(b)) 33 | if seq_len == 0: # One sequence is empty, return the other 34 | return b if len(a) == 0 else a 35 | 36 | # Initialize merging index and corresponding score (mean Levenstein) 37 | min_score, index = 1.0, 0 # No overlap, just concatenate 38 | 39 | scores = [Levenshtein.distance(a[-i:], b[:i], processor=None) / i for i in range(1, seq_len + 1)] 40 | 41 | # Edge case (split in the middle of char repetitions): if it starts with 2 or more 0 42 | if len(scores) > 1 and (scores[0], scores[1]) == (0, 0): 43 | # Compute n_overlap (number of overlapping chars, geometrically determined) 44 | n_overlap = round(len(b) * (dil_factor - 1) / dil_factor) 45 | # Find the number of consecutive zeros in the scores list 46 | # Impossible to have a zero after a non-zero score in that case 47 | n_zeros = sum(val == 0 for val in scores) 48 | # Index is bounded by the geometrical overlap to avoid collapsing repetitions 49 | min_score, index = 0, min(n_zeros, n_overlap) 50 | 51 | else: # Common case: choose the min score index 52 | for i, score in enumerate(scores): 53 | if score < min_score: 54 | min_score, index = score, i + 1 # Add one because first index is an overlap of 1 char 55 | 56 | # Merge with correct overlap 57 | if index == 0: 58 | return a + b 59 | return a[:-1] + b[index - 1 :] 60 | 61 | 62 | def merge_multi_strings(seq_list: List[str], dil_factor: float) -> str: 63 | """Recursively merges consecutive string sequences with overlapping characters. 64 | 65 | Args: 66 | seq_list: list of sequences to merge. Sequences need to be ordered from left to right. 67 | dil_factor: dilation factor of the boxes to overlap, should be > 1. 
This parameter is 68 | only used when the mother sequence is splitted on a character repetition 69 | 70 | Returns: 71 | A merged character sequence 72 | 73 | Example:: 74 | >>> from doctr.model.recognition.utils import merge_multi_sequences 75 | >>> merge_multi_sequences(['abc', 'bcdef', 'difghi', 'aijkl'], 1.4) 76 | 'abcdefghijkl' 77 | """ 78 | 79 | def _recursive_merge(a: str, seq_list: List[str], dil_factor: float) -> str: 80 | # Recursive version of compute_overlap 81 | if len(seq_list) == 1: 82 | return merge_strings(a, seq_list[0], dil_factor) 83 | return _recursive_merge(merge_strings(a, seq_list[0], dil_factor), seq_list[1:], dil_factor) 84 | 85 | return _recursive_merge("", seq_list, dil_factor) 86 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/modules/vision_transformer/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import math 7 | from typing import Tuple 8 | 9 | import torch 10 | from torch import nn 11 | 12 | __all__ = ["PatchEmbedding"] 13 | 14 | 15 | class PatchEmbedding(nn.Module): 16 | """Compute 2D patch embeddings with cls token and positional encoding""" 17 | 18 | def __init__(self, input_shape: Tuple[int, int, int], embed_dim: int, patch_size: Tuple[int, int]) -> None: 19 | super().__init__() 20 | channels, height, width = input_shape 21 | self.patch_size = patch_size 22 | self.interpolate = True if patch_size[0] == patch_size[1] else False 23 | self.grid_size = tuple([s // p for s, p in zip((height, width), self.patch_size)]) 24 | self.num_patches = self.grid_size[0] * self.grid_size[1] 25 | 26 | self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim)) 27 | self.positions = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim)) 28 | self.projection = nn.Conv2d(channels, embed_dim, kernel_size=self.patch_size, stride=self.patch_size) 29 | 30 | def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: 31 | """ 32 | 100 % borrowed from: 33 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_vit.py 34 | 35 | This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher 36 | resolution images. 
37 | 38 | Source: 39 | https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py 40 | """ 41 | 42 | num_patches = embeddings.shape[1] - 1 43 | num_positions = self.positions.shape[1] - 1 44 | if num_patches == num_positions and height == width: 45 | return self.positions 46 | class_pos_embed = self.positions[:, 0] 47 | patch_pos_embed = self.positions[:, 1:] 48 | dim = embeddings.shape[-1] 49 | h0 = float(height // self.patch_size[0]) 50 | w0 = float(width // self.patch_size[1]) 51 | # we add a small number to avoid floating point error in the interpolation 52 | # see discussion at https://github.com/facebookresearch/dino/issues/8 53 | h0, w0 = h0 + 0.1, w0 + 0.1 54 | patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) 55 | patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) 56 | patch_pos_embed = nn.functional.interpolate( 57 | patch_pos_embed, 58 | scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), 59 | mode="bilinear", 60 | align_corners=False, 61 | recompute_scale_factor=True, 62 | ) 63 | assert int(h0) == patch_pos_embed.shape[-2], "height of interpolated patch embedding doesn't match" 64 | assert int(w0) == patch_pos_embed.shape[-1], "width of interpolated patch embedding doesn't match" 65 | 66 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 67 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 68 | 69 | def forward(self, x: torch.Tensor) -> torch.Tensor: 70 | B, C, H, W = x.shape 71 | assert H % self.patch_size[0] == 0, "Image height must be divisible by patch height" 72 | assert W % self.patch_size[1] == 0, "Image width must be divisible by patch width" 73 | 74 | # patchify image 75 | patches = self.projection(x).flatten(2).transpose(1, 2) 76 | 77 | cls_tokens = self.cls_token.expand(B, -1, -1) # (batch_size, 1, d_model) 78 | # concate cls_tokens to patches 79 | embeddings = torch.cat([cls_tokens, patches], dim=1) # (batch_size, num_patches + 1, d_model) 80 | # add positions to embeddings 81 | if self.interpolate: 82 | embeddings += self.interpolate_pos_encoding(embeddings, H, W) 83 | else: 84 | embeddings += self.positions 85 | 86 | return embeddings # (batch_size, num_patches + 1, d_model) 87 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/classification/vgg/tensorflow.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
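# --- Illustrative note (added; not part of the original module) ---------------
# The module below builds the VGG-16 backbone (batch normalization plus
# rectangular pooling) used for recognition. A hedged usage sketch, mirroring
# the `vgg16_bn_r` docstring further down; keyword arguments are forwarded to
# the `VGG` constructor through `_vgg`:
#
#   import tensorflow as tf
#   model = vgg16_bn_r(pretrained=False, include_top=False, input_shape=(32, 32, 3))
#   feats = model(tf.random.uniform((1, 32, 32, 3)), training=False)
#
# With `include_top=False` the classification head is skipped, so the output is
# the raw feature map rather than class logits.
# -------------------------------------------------------------------------------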
5 | 6 | from copy import deepcopy 7 | from typing import Any, Dict, List, Optional, Tuple 8 | 9 | from tensorflow.keras import layers 10 | from tensorflow.keras.models import Sequential 11 | 12 | from multiocr.pipelines.doctr_ocr.doctr.datasets import VOCABS 13 | 14 | from ...utils import conv_sequence, load_pretrained_params 15 | 16 | __all__ = ["VGG", "vgg16_bn_r"] 17 | 18 | 19 | default_cfgs: Dict[str, Dict[str, Any]] = { 20 | "vgg16_bn_r": { 21 | "mean": (0.5, 0.5, 0.5), 22 | "std": (1.0, 1.0, 1.0), 23 | "input_shape": (32, 32, 3), 24 | "classes": list(VOCABS["french"]), 25 | "url": "https://doctr-static.mindee.com/models?id=v0.4.1/vgg16_bn_r-c5836cea.zip&src=0", 26 | }, 27 | } 28 | 29 | 30 | class VGG(Sequential): 31 | """Implements the VGG architecture from `"Very Deep Convolutional Networks for Large-Scale Image Recognition" 32 | `_. 33 | 34 | Args: 35 | num_blocks: number of convolutional block in each stage 36 | planes: number of output channels in each stage 37 | rect_pools: whether pooling square kernels should be replace with rectangular ones 38 | include_top: whether the classifier head should be instantiated 39 | num_classes: number of output classes 40 | input_shape: shapes of the input tensor 41 | """ 42 | 43 | def __init__( 44 | self, 45 | num_blocks: List[int], 46 | planes: List[int], 47 | rect_pools: List[bool], 48 | include_top: bool = False, 49 | num_classes: int = 1000, 50 | input_shape: Optional[Tuple[int, int, int]] = None, 51 | cfg: Optional[Dict[str, Any]] = None, 52 | ) -> None: 53 | _layers = [] 54 | # Specify input_shape only for the first layer 55 | kwargs = {"input_shape": input_shape} 56 | for nb_blocks, out_chan, rect_pool in zip(num_blocks, planes, rect_pools): 57 | for _ in range(nb_blocks): 58 | _layers.extend(conv_sequence(out_chan, "relu", True, kernel_size=3, **kwargs)) # type: ignore[arg-type] 59 | kwargs = {} 60 | _layers.append(layers.MaxPooling2D((2, 1 if rect_pool else 2))) 61 | 62 | if include_top: 63 | _layers.extend([layers.GlobalAveragePooling2D(), layers.Dense(num_classes)]) 64 | super().__init__(_layers) 65 | self.cfg = cfg 66 | 67 | 68 | def _vgg( 69 | arch: str, pretrained: bool, num_blocks: List[int], planes: List[int], rect_pools: List[bool], **kwargs: Any 70 | ) -> VGG: 71 | kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"])) 72 | kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"]) 73 | kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"]) 74 | 75 | _cfg = deepcopy(default_cfgs[arch]) 76 | _cfg["num_classes"] = kwargs["num_classes"] 77 | _cfg["classes"] = kwargs["classes"] 78 | _cfg["input_shape"] = kwargs["input_shape"] 79 | kwargs.pop("classes") 80 | 81 | # Build the model 82 | model = VGG(num_blocks, planes, rect_pools, cfg=_cfg, **kwargs) 83 | # Load pretrained parameters 84 | if pretrained: 85 | load_pretrained_params(model, default_cfgs[arch]["url"]) 86 | 87 | return model 88 | 89 | 90 | def vgg16_bn_r(pretrained: bool = False, **kwargs: Any) -> VGG: 91 | """VGG-16 architecture as described in `"Very Deep Convolutional Networks for Large-Scale Image Recognition" 92 | `_, modified by adding batch normalization, rectangular pooling and a simpler 93 | classification head. 
94 | 95 | >>> import tensorflow as tf 96 | >>> from doctr.models import vgg16_bn_r 97 | >>> model = vgg16_bn_r(pretrained=False) 98 | >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32) 99 | >>> out = model(input_tensor) 100 | 101 | Args: 102 | pretrained (bool): If True, returns a model pre-trained on ImageNet 103 | 104 | Returns: 105 | VGG feature extractor 106 | """ 107 | 108 | return _vgg( 109 | "vgg16_bn_r", pretrained, [2, 2, 3, 3, 3], [64, 128, 256, 512, 512], [False, False, True, True, True], **kwargs 110 | ) 111 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/modules/vision_transformer/tensorflow.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import math 7 | from typing import Any, Tuple 8 | 9 | import tensorflow as tf 10 | from tensorflow.keras import layers 11 | 12 | from multiocr.pipelines.doctr_ocr.doctr.utils.repr import NestedObject 13 | 14 | __all__ = ["PatchEmbedding"] 15 | 16 | 17 | class PatchEmbedding(layers.Layer, NestedObject): 18 | """Compute 2D patch embeddings with cls token and positional encoding""" 19 | 20 | def __init__(self, input_shape: Tuple[int, int, int], embed_dim: int, patch_size: Tuple[int, int]) -> None: 21 | super().__init__() 22 | height, width, _ = input_shape 23 | self.patch_size = patch_size 24 | self.interpolate = True if patch_size[0] == patch_size[1] else False 25 | self.grid_size = tuple([s // p for s, p in zip((height, width), self.patch_size)]) 26 | self.num_patches = self.grid_size[0] * self.grid_size[1] 27 | 28 | self.cls_token = self.add_weight(shape=(1, 1, embed_dim), initializer="zeros", trainable=True, name="cls_token") 29 | self.positions = self.add_weight( 30 | shape=(1, self.num_patches + 1, embed_dim), 31 | initializer="zeros", 32 | trainable=True, 33 | name="positions", 34 | ) 35 | self.projection = layers.Conv2D( 36 | filters=embed_dim, 37 | kernel_size=self.patch_size, 38 | strides=self.patch_size, 39 | padding="valid", 40 | data_format="channels_last", 41 | use_bias=True, 42 | kernel_initializer="glorot_uniform", 43 | bias_initializer="zeros", 44 | name="projection", 45 | ) 46 | 47 | def interpolate_pos_encoding(self, embeddings: tf.Tensor, height: int, width: int) -> tf.Tensor: 48 | """ 49 | 100 % borrowed from: 50 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_tf_vit.py 51 | 52 | This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher 53 | resolution images. 
54 | 55 | Source: 56 | https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py 57 | """ 58 | 59 | seq_len, dim = embeddings.shape[1:] 60 | num_patches = seq_len - 1 61 | 62 | num_positions = self.positions.shape[1] - 1 63 | 64 | if num_patches == num_positions and height == width: 65 | return self.positions 66 | class_pos_embed = self.positions[:, :1] 67 | patch_pos_embed = self.positions[:, 1:] 68 | h0 = height // self.patch_size[0] 69 | w0 = width // self.patch_size[1] 70 | patch_pos_embed = tf.image.resize( 71 | images=tf.reshape( 72 | patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) 73 | ), 74 | size=(h0, w0), 75 | method="bilinear", 76 | ) 77 | 78 | shape = patch_pos_embed.shape 79 | assert h0 == shape[-3], "height of interpolated patch embedding doesn't match" 80 | assert w0 == shape[-2], "width of interpolated patch embedding doesn't match" 81 | 82 | patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim)) 83 | return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1) 84 | 85 | def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor: 86 | B, H, W, C = x.shape 87 | assert H % self.patch_size[0] == 0, "Image height must be divisible by patch height" 88 | assert W % self.patch_size[1] == 0, "Image width must be divisible by patch width" 89 | # patchify image 90 | patches = self.projection(x, **kwargs) # (batch_size, num_patches, d_model) 91 | patches = tf.reshape(patches, (B, self.num_patches, -1)) # (batch_size, num_patches, d_model) 92 | 93 | cls_tokens = tf.repeat(self.cls_token, B, axis=0) # (batch_size, 1, d_model) 94 | # concate cls_tokens to patches 95 | embeddings = tf.concat([cls_tokens, patches], axis=1) # (batch_size, num_patches + 1, d_model) 96 | # add positions to embeddings 97 | if self.interpolate: 98 | embeddings += self.interpolate_pos_encoding(embeddings, H, W) 99 | else: 100 | embeddings += self.positions 101 | 102 | return embeddings # (batch_size, num_patches + 1, d_model) 103 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/utils/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | # Adapted from https://github.com/pytorch/vision/blob/master/torchvision/datasets/utils.py 7 | 8 | import hashlib 9 | import logging 10 | import os 11 | import re 12 | import urllib 13 | import urllib.error 14 | import urllib.request 15 | from pathlib import Path 16 | from typing import Optional, Union 17 | 18 | from tqdm.auto import tqdm 19 | 20 | __all__ = ["download_from_url"] 21 | 22 | 23 | # matches bfd8deac from resnet18-bfd8deac.ckpt 24 | HASH_REGEX = re.compile(r"-([a-f0-9]*)\.") 25 | USER_AGENT = "mindee/doctr" 26 | 27 | 28 | def _urlretrieve(url: str, filename: Union[Path, str], chunk_size: int = 1024) -> None: 29 | with open(filename, "wb") as fh: 30 | with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response: 31 | with tqdm(total=response.length) as pbar: 32 | for chunk in iter(lambda: response.read(chunk_size), ""): 33 | if not chunk: 34 | break 35 | pbar.update(chunk_size) 36 | fh.write(chunk) 37 | 38 | 39 | def _check_integrity(file_path: Union[str, Path], hash_prefix: str) -> bool: 40 | with open(file_path, "rb") as f: 41 | sha_hash = hashlib.sha256(f.read()).hexdigest() 42 | 43 | return sha_hash[: len(hash_prefix)] == hash_prefix 44 | 45 | 46 | def download_from_url( 47 | url: str, 48 | file_name: Optional[str] = None, 49 | hash_prefix: Optional[str] = None, 50 | cache_dir: Optional[str] = None, 51 | cache_subdir: Optional[str] = None, 52 | ) -> Path: 53 | """Download a file using its URL 54 | 55 | >>> from doctr.models import download_from_url 56 | >>> download_from_url("https://yoursource.com/yourcheckpoint-yourhash.zip") 57 | 58 | Args: 59 | url: the URL of the file to download 60 | file_name: optional name of the file once downloaded 61 | hash_prefix: optional expected SHA256 hash of the file 62 | cache_dir: cache directory 63 | cache_subdir: subfolder to use in the cache 64 | 65 | Returns: 66 | the location of the downloaded file 67 | 68 | Note: 69 | You can change cache directory location by using `DOCTR_CACHE_DIR` environment variable. 70 | """ 71 | 72 | if not isinstance(file_name, str): 73 | file_name = url.rpartition("/")[-1].split("&")[0] 74 | 75 | cache_dir = ( 76 | str(os.environ.get("DOCTR_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "doctr"))) 77 | if cache_dir is None 78 | else cache_dir 79 | ) 80 | 81 | # Check hash in file name 82 | if hash_prefix is None: 83 | r = HASH_REGEX.search(file_name) 84 | hash_prefix = r.group(1) if r else None 85 | 86 | folder_path = Path(cache_dir) if cache_subdir is None else Path(cache_dir, cache_subdir) 87 | file_path = folder_path.joinpath(file_name) 88 | # Check file existence 89 | if file_path.is_file() and (hash_prefix is None or _check_integrity(file_path, hash_prefix)): 90 | logging.info(f"Using downloaded & verified file: {file_path}") 91 | return file_path 92 | 93 | try: 94 | # Create folder hierarchy 95 | folder_path.mkdir(parents=True, exist_ok=True) 96 | except OSError: 97 | error_message = f"Failed creating cache direcotry at {folder_path}" 98 | if os.environ.get("DOCTR_CACHE_DIR", ""): 99 | error_message += " using path from 'DOCTR_CACHE_DIR' environment variable." 100 | else: 101 | error_message += ( 102 | ". You can change default cache directory using 'DOCTR_CACHE_DIR' environment variable if needed." 
103 | ) 104 | logging.error(error_message) 105 | raise 106 | # Download the file 107 | try: 108 | print(f"Downloading {url} to {file_path}") 109 | _urlretrieve(url, file_path) 110 | except (urllib.error.URLError, IOError) as e: 111 | if url[:5] == "https": 112 | url = url.replace("https:", "http:") 113 | print("Failed download. Trying https -> http instead." f" Downloading {url} to {file_path}") 114 | _urlretrieve(url, file_path) 115 | else: 116 | raise e 117 | 118 | # Remove corrupted files 119 | if isinstance(hash_prefix, str) and not _check_integrity(file_path, hash_prefix): 120 | # Remove file 121 | os.remove(file_path) 122 | raise ValueError(f"corrupted download, the hash of {url} does not match its expected value") 123 | 124 | return file_path 125 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/zoo.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import Any 7 | 8 | from .detection.zoo import detection_predictor 9 | from .predictor import OCRPredictor 10 | from .recognition.zoo import recognition_predictor 11 | 12 | __all__ = ["ocr_predictor"] 13 | 14 | 15 | def _predictor( 16 | det_arch: Any, 17 | reco_arch: Any, 18 | pretrained: bool, 19 | pretrained_backbone: bool = True, 20 | assume_straight_pages: bool = True, 21 | preserve_aspect_ratio: bool = True, 22 | symmetric_pad: bool = True, 23 | det_bs: int = 2, 24 | reco_bs: int = 128, 25 | detect_orientation: bool = False, 26 | detect_language: bool = False, 27 | **kwargs, 28 | ) -> OCRPredictor: 29 | # Detection 30 | det_predictor = detection_predictor( 31 | det_arch, 32 | pretrained=pretrained, 33 | pretrained_backbone=pretrained_backbone, 34 | batch_size=det_bs, 35 | assume_straight_pages=assume_straight_pages, 36 | preserve_aspect_ratio=preserve_aspect_ratio, 37 | symmetric_pad=symmetric_pad, 38 | ) 39 | 40 | # Recognition 41 | reco_predictor = recognition_predictor( 42 | reco_arch, 43 | pretrained=pretrained, 44 | pretrained_backbone=pretrained_backbone, 45 | batch_size=reco_bs, 46 | ) 47 | 48 | return OCRPredictor( 49 | det_predictor, 50 | reco_predictor, 51 | assume_straight_pages=assume_straight_pages, 52 | preserve_aspect_ratio=preserve_aspect_ratio, 53 | symmetric_pad=symmetric_pad, 54 | detect_orientation=detect_orientation, 55 | detect_language=detect_language, 56 | **kwargs, 57 | ) 58 | 59 | 60 | def ocr_predictor( 61 | det_arch: Any = "db_resnet50", 62 | reco_arch: Any = "crnn_vgg16_bn", 63 | pretrained: bool = False, 64 | pretrained_backbone: bool = True, 65 | assume_straight_pages: bool = True, 66 | preserve_aspect_ratio: bool = True, 67 | symmetric_pad: bool = True, 68 | export_as_straight_boxes: bool = False, 69 | detect_orientation: bool = False, 70 | detect_language: bool = False, 71 | **kwargs: Any, 72 | ) -> OCRPredictor: 73 | """End-to-end OCR architecture using one model for localization, and another for text recognition. 74 | 75 | >>> import numpy as np 76 | >>> from doctr.models import ocr_predictor 77 | >>> model = ocr_predictor('db_resnet50', 'crnn_vgg16_bn', pretrained=True) 78 | >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8) 79 | >>> out = model([input_page]) 80 | 81 | Args: 82 | det_arch: name of the detection architecture or the model itself to use 83 | (e.g. 
'db_resnet50', 'db_mobilenet_v3_large') 84 | reco_arch: name of the recognition architecture or the model itself to use 85 | (e.g. 'crnn_vgg16_bn', 'sar_resnet31') 86 | pretrained: If True, returns a model pre-trained on our OCR dataset 87 | pretrained_backbone: If True, returns a model with a pretrained backbone 88 | assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages 89 | without rotated textual elements. 90 | preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before 91 | running the detection model on it. 92 | symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right. 93 | export_as_straight_boxes: when assume_straight_pages is set to False, export final predictions 94 | (potentially rotated) as straight bounding boxes. 95 | detect_orientation: if True, the estimated general page orientation will be added to the predictions for each 96 | page. Doing so will slightly deteriorate the overall latency. 97 | detect_language: if True, the language prediction will be added to the predictions for each 98 | page. Doing so will slightly deteriorate the overall latency. 99 | kwargs: keyword args of `OCRPredictor` 100 | 101 | Returns: 102 | OCR predictor 103 | """ 104 | 105 | return _predictor( 106 | det_arch, 107 | reco_arch, 108 | pretrained, 109 | pretrained_backbone=pretrained_backbone, 110 | assume_straight_pages=assume_straight_pages, 111 | preserve_aspect_ratio=preserve_aspect_ratio, 112 | symmetric_pad=symmetric_pad, 113 | export_as_straight_boxes=export_as_straight_boxes, 114 | detect_orientation=detect_orientation, 115 | detect_language=detect_language, 116 | **kwargs, 117 | ) -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/preprocessor/tensorflow.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import math 7 | from typing import Any, List, Tuple, Union 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | 12 | from multiocr.pipelines.doctr_ocr.doctr.transforms import Normalize, Resize 13 | from multiocr.pipelines.doctr_ocr.doctr.utils.multithreading import multithread_exec 14 | from multiocr.pipelines.doctr_ocr.doctr.utils.repr import NestedObject 15 | 16 | __all__ = ["PreProcessor"] 17 | 18 | 19 | class PreProcessor(NestedObject): 20 | """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization. 
21 | 22 | Args: 23 | output_size: expected size of each page in format (H, W) 24 | batch_size: the size of page batches 25 | mean: mean value of the training distribution by channel 26 | std: standard deviation of the training distribution by channel 27 | """ 28 | 29 | _children_names: List[str] = ["resize", "normalize"] 30 | 31 | def __init__( 32 | self, 33 | output_size: Tuple[int, int], 34 | batch_size: int, 35 | mean: Tuple[float, float, float] = (0.5, 0.5, 0.5), 36 | std: Tuple[float, float, float] = (1.0, 1.0, 1.0), 37 | fp16: bool = False, 38 | **kwargs: Any, 39 | ) -> None: 40 | self.batch_size = batch_size 41 | self.resize = Resize(output_size, **kwargs) 42 | # Perform the division by 255 at the same time 43 | self.normalize = Normalize(mean, std) 44 | 45 | def batch_inputs(self, samples: List[tf.Tensor]) -> List[tf.Tensor]: 46 | """Gather samples into batches for inference purposes 47 | 48 | Args: 49 | samples: list of samples (tf.Tensor) 50 | 51 | Returns: 52 | list of batched samples 53 | """ 54 | 55 | num_batches = int(math.ceil(len(samples) / self.batch_size)) 56 | batches = [ 57 | tf.stack(samples[idx * self.batch_size : min((idx + 1) * self.batch_size, len(samples))], axis=0) 58 | for idx in range(int(num_batches)) 59 | ] 60 | 61 | return batches 62 | 63 | def sample_transforms(self, x: Union[np.ndarray, tf.Tensor]) -> tf.Tensor: 64 | if x.ndim != 3: 65 | raise AssertionError("expected list of 3D Tensors") 66 | if isinstance(x, np.ndarray): 67 | if x.dtype not in (np.uint8, np.float32): 68 | raise TypeError("unsupported data type for numpy.ndarray") 69 | x = tf.convert_to_tensor(x) 70 | elif x.dtype not in (tf.uint8, tf.float16, tf.float32): 71 | raise TypeError("unsupported data type for torch.Tensor") 72 | # Data type & 255 division 73 | if x.dtype == tf.uint8: 74 | x = tf.image.convert_image_dtype(x, dtype=tf.float32) 75 | # Resizing 76 | x = self.resize(x) 77 | 78 | return x 79 | 80 | def __call__(self, x: Union[tf.Tensor, np.ndarray, List[Union[tf.Tensor, np.ndarray]]]) -> List[tf.Tensor]: 81 | """Prepare document data for model forwarding 82 | 83 | Args: 84 | x: list of images (np.array) or tensors (already resized and batched) 85 | Returns: 86 | list of page batches 87 | """ 88 | 89 | # Input type check 90 | if isinstance(x, (np.ndarray, tf.Tensor)): 91 | if x.ndim != 4: 92 | raise AssertionError("expected 4D Tensor") 93 | if isinstance(x, np.ndarray): 94 | if x.dtype not in (np.uint8, np.float32): 95 | raise TypeError("unsupported data type for numpy.ndarray") 96 | x = tf.convert_to_tensor(x) 97 | elif x.dtype not in (tf.uint8, tf.float16, tf.float32): 98 | raise TypeError("unsupported data type for torch.Tensor") 99 | 100 | # Data type & 255 division 101 | if x.dtype == tf.uint8: 102 | x = tf.image.convert_image_dtype(x, dtype=tf.float32) 103 | # Resizing 104 | if (x.shape[1], x.shape[2]) != self.resize.output_size: 105 | x = tf.image.resize(x, self.resize.output_size, method=self.resize.method) 106 | 107 | batches = [x] 108 | 109 | elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, tf.Tensor)) for sample in x): 110 | # Sample transform (to tensor, resize) 111 | samples = list(multithread_exec(self.sample_transforms, x)) 112 | # Batching 113 | batches = self.batch_inputs(samples) 114 | else: 115 | raise TypeError(f"invalid input type: {type(x)}") 116 | 117 | # Batch transforms (normalize) 118 | batches = list(multithread_exec(self.normalize, batches)) 119 | 120 | return batches 121 | 
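# --- Illustrative usage note (added; not part of the original module) ---------
# A hedged sketch of how the detection / recognition predictors drive this
# preprocessor: a ragged list of uint8 pages goes in, resized and normalized
# float batches (grouped by `batch_size`) come out.
#
#   import numpy as np
#   pre = PreProcessor(output_size=(1024, 1024), batch_size=2)
#   pages = [(255 * np.random.rand(768, 512, 3)).astype(np.uint8) for _ in range(3)]
#   batches = pre(pages)
#   # len(batches) == 2, with shapes (2, 1024, 1024, 3) and (1, 1024, 1024, 3)
# -------------------------------------------------------------------------------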
-------------------------------------------------------------------------------- /multiocr/pipelines/aws_textract/engine.py: -------------------------------------------------------------------------------- 1 | from multiocr.base_class import OCR 2 | import boto3 3 | import pandas as pd 4 | import json 5 | from typing import Union 6 | from PIL import Image 7 | import os 8 | 9 | class AwsTextractOcr(OCR): 10 | """ 11 | 12 | The TextractOcr class takes an image file path as input and an optional AWS region. It has four methods: 13 | 14 | text_extraction(): This method extracts text from the image using AWS Textract and returns the text as a dictionary with the block IDs as keys and the text, confidence score, and bounding box coordinates as values. 15 | text_extraction_to_json(text_dict): This method takes the dictionary output from text_extraction() as input and saves it to a JSON file. 16 | """ 17 | 18 | def __init__(self, config: Union[dict, None]=None): 19 | 20 | self.config = config 21 | if not self.config: 22 | self.config = { 23 | "region_name":os.getenv("region_name"), 24 | "aws_access_key_id":os.getenv("aws_access_key_id"), 25 | "aws_secret_access_key":os.getenv("aws_secret_access_key") 26 | } 27 | self.client = boto3.client('textract', **self.config) 28 | 29 | def text_extraction(self, image_file): 30 | try: 31 | img = Image.open(image_file) 32 | with open(image_file, 'rb') as f: 33 | image_bytes = f.read() 34 | except Exception as e: 35 | raise Exception(f"Error reading image file: {e}") 36 | 37 | try: 38 | response = self.client.detect_document_text(Document={'Bytes': image_bytes}) 39 | self.raw_ocr = response 40 | # with open("./aws_response.json","r") as f: 41 | # response = json. loads(f.read()) 42 | except Exception as e: 43 | raise Exception(f"Error detecting text in image: {e}") 44 | 45 | text_dict = [] 46 | 47 | for block in response['Blocks']: 48 | if block['BlockType'] == 'LINE': 49 | continue 50 | # text_dict[block['Id']] = {'text': block['Text'], 'confidence': block['Confidence'], 51 | # 'coordinates': block['Geometry']['BoundingBox']} 52 | elif block['BlockType'] == 'WORD': 53 | if block['Id'] not in text_dict: 54 | word = {'text': block['Text'], 'confidence': block['Confidence'], 55 | 'coordinates': block['Geometry']['BoundingBox']} 56 | w = word["coordinates"]["Width"]*img.width 57 | h = word["coordinates"]["Height"]*img.height 58 | x = word["coordinates"]["Left"]*img.width 59 | y = word["coordinates"]["Top"]*img.height 60 | word_dict = { 61 | "text": word["text"], 62 | "confidence": word["confidence"], 63 | "coordinates": { 64 | "xmin": x, 65 | "ymin": y, 66 | "xmax": x+w, 67 | "ymax": y+h 68 | } 69 | } 70 | text_dict.append(word_dict) 71 | return text_dict 72 | 73 | def text_extraction_to_json(self, text_dict): 74 | try: 75 | return json.dumps(text_dict) 76 | except Exception as e: 77 | raise Exception(f"Error converting text extraction to JSON: {e}") 78 | 79 | def text_extraction_to_df(self, text_dict): 80 | rows = [] 81 | 82 | for v in text_dict: 83 | rows.append([v['text'], v['confidence'], v['coordinates']['xmin'], v['coordinates']['ymin'], 84 | v['coordinates']['xmax'], v['coordinates']['ymax']]) 85 | 86 | df = pd.DataFrame(rows, columns=['text', 'confidence', 'xmin', 'ymin', 'xmax', 'ymax']) 87 | 88 | try: 89 | return df 90 | except Exception as e: 91 | raise Exception(f"Error converting text extraction to dataframe: {e}") 92 | 93 | def extract_plain_text(self, text_dict): 94 | plain_text = '' 95 | 96 | for v in text_dict: 97 | plain_text += v['text'] + ' ' 
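        # NOTE (added comment): words are joined with single spaces in the order
        # Textract returned them (the Tesseract engine above does the same), so the
        # resulting string keeps a trailing space.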
98 | 99 | try: 100 | return plain_text 101 | except Exception as e: 102 | raise Exception(f"Error converting text extraction to plain text: {e}") 103 | 104 | if __name__ == "__main__": 105 | import os 106 | config = { 107 | "region_name":os.getenv("region_name"), 108 | "aws_access_key_id":os.getenv("aws_access_key_id"), 109 | "aws_secret_access_key":os.getenv("aws_secret_access_key") 110 | } 111 | image_file = "/Users/aravindh/Documents/GitHub/multiocr/tests/data/test-european.jpg" 112 | ocr = AwsTextractOcr(config) 113 | data = ocr.text_extraction(image_file) 114 | jsn_dt = ocr.text_extraction_to_json(data) 115 | pln_txt = ocr.extract_plain_text(data) 116 | df = ocr.text_extraction_to_df(data) 117 | print() -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/transforms/functional/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from copy import deepcopy 7 | from typing import Tuple 8 | 9 | import numpy as np 10 | import torch 11 | from torchvision.transforms import functional as F 12 | 13 | from multiocr.pipelines.doctr_ocr.doctr.utils.geometry import rotate_abs_geoms 14 | 15 | from .base import create_shadow_mask, crop_boxes 16 | 17 | __all__ = ["invert_colors", "rotate_sample", "crop_detection", "random_shadow"] 18 | 19 | 20 | def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor: 21 | out = F.rgb_to_grayscale(img, num_output_channels=3) 22 | # Random RGB shift 23 | shift_shape = [img.shape[0], 3, 1, 1] if img.ndim == 4 else [3, 1, 1] 24 | rgb_shift = min_val + (1 - min_val) * torch.rand(shift_shape) 25 | # Inverse the color 26 | if out.dtype == torch.uint8: 27 | out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8) 28 | else: 29 | out = out * rgb_shift.to(dtype=out.dtype) 30 | # Inverse the color 31 | out = 255 - out if out.dtype == torch.uint8 else 1 - out 32 | return out 33 | 34 | 35 | def rotate_sample( 36 | img: torch.Tensor, 37 | geoms: np.ndarray, 38 | angle: float, 39 | expand: bool = False, 40 | ) -> Tuple[torch.Tensor, np.ndarray]: 41 | """Rotate image around the center, interpolation=NEAREST, pad with 0 (black) 42 | 43 | Args: 44 | img: image to rotate 45 | geoms: array of geometries of shape (N, 4) or (N, 4, 2) 46 | angle: angle in degrees. 
+: counter-clockwise, -: clockwise 47 | expand: whether the image should be padded before the rotation 48 | 49 | Returns: 50 | A tuple of rotated img (tensor), rotated geometries of shape (N, 4, 2) 51 | """ 52 | rotated_img = F.rotate(img, angle=angle, fill=0, expand=expand) # Interpolation NEAREST by default 53 | rotated_img = rotated_img[:3] # when expand=True, it expands to RGBA channels 54 | # Get absolute coords 55 | _geoms = deepcopy(geoms) 56 | if _geoms.shape[1:] == (4,): 57 | if np.max(_geoms) <= 1: 58 | _geoms[:, [0, 2]] *= img.shape[-1] 59 | _geoms[:, [1, 3]] *= img.shape[-2] 60 | elif _geoms.shape[1:] == (4, 2): 61 | if np.max(_geoms) <= 1: 62 | _geoms[..., 0] *= img.shape[-1] 63 | _geoms[..., 1] *= img.shape[-2] 64 | else: 65 | raise AssertionError("invalid format for arg `geoms`") 66 | 67 | # Rotate the boxes: xmin, ymin, xmax, ymax or polygons --> (4, 2) polygon 68 | rotated_geoms: np.ndarray = rotate_abs_geoms( 69 | _geoms, 70 | angle, 71 | img.shape[1:], # type: ignore[arg-type] 72 | expand, 73 | ).astype(np.float32) 74 | 75 | # Always return relative boxes to avoid label confusions when resizing is performed aferwards 76 | rotated_geoms[..., 0] = rotated_geoms[..., 0] / rotated_img.shape[2] 77 | rotated_geoms[..., 1] = rotated_geoms[..., 1] / rotated_img.shape[1] 78 | 79 | return rotated_img, np.clip(rotated_geoms, 0, 1) 80 | 81 | 82 | def crop_detection( 83 | img: torch.Tensor, boxes: np.ndarray, crop_box: Tuple[float, float, float, float] 84 | ) -> Tuple[torch.Tensor, np.ndarray]: 85 | """Crop and image and associated bboxes 86 | 87 | Args: 88 | img: image to crop 89 | boxes: array of boxes to clip, absolute (int) or relative (float) 90 | crop_box: box (xmin, ymin, xmax, ymax) to crop the image. Relative coords. 91 | 92 | Returns: 93 | A tuple of cropped image, cropped boxes, where the image is not resized. 94 | """ 95 | if any(val < 0 or val > 1 for val in crop_box): 96 | raise AssertionError("coordinates of arg `crop_box` should be relative") 97 | h, w = img.shape[-2:] 98 | xmin, ymin = int(round(crop_box[0] * (w - 1))), int(round(crop_box[1] * (h - 1))) 99 | xmax, ymax = int(round(crop_box[2] * (w - 1))), int(round(crop_box[3] * (h - 1))) 100 | cropped_img = F.crop(img, ymin, xmin, ymax - ymin, xmax - xmin) 101 | # Crop the box 102 | boxes = crop_boxes(boxes, crop_box if boxes.max() <= 1 else (xmin, ymin, xmax, ymax)) 103 | 104 | return cropped_img, boxes 105 | 106 | 107 | def random_shadow(img: torch.Tensor, opacity_range: Tuple[float, float], **kwargs) -> torch.Tensor: 108 | """Crop and image and associated bboxes 109 | 110 | Args: 111 | img: image to modify 112 | opacity_range: the minimum and maximum desired opacity of the shadow 113 | 114 | Returns: 115 | shaded image 116 | """ 117 | 118 | shadow_mask = create_shadow_mask(img.shape[1:], **kwargs) # type: ignore[arg-type] 119 | 120 | opacity = np.random.uniform(*opacity_range) 121 | shadow_tensor = 1 - torch.from_numpy(shadow_mask[None, ...]) 122 | 123 | # Add some blur to make it believable 124 | k = 7 + 2 * int(4 * np.random.rand(1)) 125 | sigma = np.random.uniform(0.5, 5.0) 126 | shadow_tensor = F.gaussian_blur(shadow_tensor, k, sigma=[sigma, sigma]) 127 | 128 | return opacity * shadow_tensor * img + (1 - opacity) * img 129 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/preprocessor/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 
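# --- Illustrative note (added; not part of the original module) ---------------
# This file is the PyTorch counterpart of preprocessor/tensorflow.py: the same
# resize / rescale / normalize pipeline, but numpy HWC arrays are permuted into
# channel-first CHW tensors. A hedged sketch:
#
#   import numpy as np
#   pre = PreProcessor(output_size=(1024, 1024), batch_size=2)
#   pages = [(255 * np.random.rand(768, 512, 3)).astype(np.uint8)]
#   batches = pre(pages)          # -> [tensor of shape (1, 3, 1024, 1024)]
# -------------------------------------------------------------------------------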
2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import math 7 | from typing import Any, List, Tuple, Union 8 | 9 | import numpy as np 10 | import torch 11 | from torch import nn 12 | from torchvision.transforms import functional as F 13 | from torchvision.transforms import transforms as T 14 | 15 | from multiocr.pipelines.doctr_ocr.doctr.transforms import Resize 16 | from multiocr.pipelines.doctr_ocr.doctr.utils.multithreading import multithread_exec 17 | 18 | __all__ = ["PreProcessor"] 19 | 20 | 21 | class PreProcessor(nn.Module): 22 | """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization. 23 | 24 | Args: 25 | output_size: expected size of each page in format (H, W) 26 | batch_size: the size of page batches 27 | mean: mean value of the training distribution by channel 28 | std: standard deviation of the training distribution by channel 29 | """ 30 | 31 | def __init__( 32 | self, 33 | output_size: Tuple[int, int], 34 | batch_size: int, 35 | mean: Tuple[float, float, float] = (0.5, 0.5, 0.5), 36 | std: Tuple[float, float, float] = (1.0, 1.0, 1.0), 37 | fp16: bool = False, 38 | **kwargs: Any, 39 | ) -> None: 40 | super().__init__() 41 | self.batch_size = batch_size 42 | self.resize: T.Resize = Resize(output_size, **kwargs) 43 | # Perform the division by 255 at the same time 44 | self.normalize = T.Normalize(mean, std) 45 | 46 | def batch_inputs(self, samples: List[torch.Tensor]) -> List[torch.Tensor]: 47 | """Gather samples into batches for inference purposes 48 | 49 | Args: 50 | samples: list of samples of shape (C, H, W) 51 | 52 | Returns: 53 | list of batched samples (*, C, H, W) 54 | """ 55 | 56 | num_batches = int(math.ceil(len(samples) / self.batch_size)) 57 | batches = [ 58 | torch.stack(samples[idx * self.batch_size : min((idx + 1) * self.batch_size, len(samples))], dim=0) 59 | for idx in range(int(num_batches)) 60 | ] 61 | 62 | return batches 63 | 64 | def sample_transforms(self, x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: 65 | if x.ndim != 3: 66 | raise AssertionError("expected list of 3D Tensors") 67 | if isinstance(x, np.ndarray): 68 | if x.dtype not in (np.uint8, np.float32): 69 | raise TypeError("unsupported data type for numpy.ndarray") 70 | x = torch.from_numpy(x.copy()).permute(2, 0, 1) 71 | elif x.dtype not in (torch.uint8, torch.float16, torch.float32): 72 | raise TypeError("unsupported data type for torch.Tensor") 73 | # Resizing 74 | x = self.resize(x) 75 | # Data type 76 | if x.dtype == torch.uint8: 77 | x = x.to(dtype=torch.float32).div(255).clip(0, 1) # type: ignore[union-attr] 78 | else: 79 | x = x.to(dtype=torch.float32) # type: ignore[union-attr] 80 | 81 | return x 82 | 83 | def __call__(self, x: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]) -> List[torch.Tensor]: 84 | """Prepare document data for model forwarding 85 | 86 | Args: 87 | x: list of images (np.array) or tensors (already resized and batched) 88 | Returns: 89 | list of page batches 90 | """ 91 | 92 | # Input type check 93 | if isinstance(x, (np.ndarray, torch.Tensor)): 94 | if x.ndim != 4: 95 | raise AssertionError("expected 4D Tensor") 96 | if isinstance(x, np.ndarray): 97 | if x.dtype not in (np.uint8, np.float32): 98 | raise TypeError("unsupported data type for numpy.ndarray") 99 | x = torch.from_numpy(x.copy()).permute(0, 3, 1, 2) 100 | elif x.dtype not in (torch.uint8, torch.float16, torch.float32): 101 | raise TypeError("unsupported 
data type for torch.Tensor") 102 | # Resizing 103 | if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]: 104 | x = F.resize(x, self.resize.size, interpolation=self.resize.interpolation) 105 | # Data type 106 | if x.dtype == torch.uint8: # type: ignore[union-attr] 107 | x = x.to(dtype=torch.float32).div(255).clip(0, 1) # type: ignore[union-attr] 108 | else: 109 | x = x.to(dtype=torch.float32) # type: ignore[union-attr] 110 | batches = [x] 111 | 112 | elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, torch.Tensor)) for sample in x): 113 | # Sample transform (to tensor, resize) 114 | samples = list(multithread_exec(self.sample_transforms, x)) 115 | # Batching 116 | batches = self.batch_inputs(samples) 117 | else: 118 | raise TypeError(f"invalid input type: {type(x)}") 119 | 120 | # Batch transforms (normalize) 121 | batches = list(multithread_exec(self.normalize, batches)) 122 | 123 | return batches 124 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/utils/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import logging 7 | from typing import Any, List, Optional, Tuple, Union 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from ....doctr.utils.data import download_from_url 13 | 14 | __all__ = ["load_pretrained_params", "conv_sequence_pt", "set_device_and_dtype", "export_model_to_onnx", "_copy_tensor"] 15 | 16 | 17 | def _copy_tensor(x: torch.Tensor) -> torch.Tensor: 18 | return x.clone().detach() 19 | 20 | 21 | def load_pretrained_params( 22 | model: nn.Module, 23 | url: Optional[str] = None, 24 | hash_prefix: Optional[str] = None, 25 | overwrite: bool = False, 26 | ignore_keys: Optional[List[str]] = None, 27 | **kwargs: Any, 28 | ) -> None: 29 | """Load a set of parameters onto a model 30 | 31 | >>> from doctr.models import load_pretrained_params 32 | >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip") 33 | 34 | Args: 35 | model: the PyTorch model to be loaded 36 | url: URL of the zipped set of parameters 37 | hash_prefix: first characters of SHA256 expected hash 38 | overwrite: should the zip extraction be enforced if the archive has already been extracted 39 | ignore_keys: list of weights to be ignored from the state_dict 40 | """ 41 | 42 | if url is None: 43 | logging.warning("Invalid model URL, using default initialization.") 44 | else: 45 | archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs) 46 | 47 | # Read state_dict 48 | state_dict = torch.load(archive_path, map_location="cpu") 49 | 50 | # Remove weights from the state_dict 51 | if ignore_keys is not None and len(ignore_keys) > 0: 52 | for key in ignore_keys: 53 | state_dict.pop(key) 54 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 55 | if set(missing_keys) != set(ignore_keys) or len(unexpected_keys) > 0: 56 | raise ValueError("unable to load state_dict, due to non-matching keys.") 57 | else: 58 | # Load weights 59 | model.load_state_dict(state_dict) 60 | 61 | 62 | def conv_sequence_pt( 63 | in_channels: int, 64 | out_channels: int, 65 | relu: bool = False, 66 | bn: bool = False, 67 | **kwargs: Any, 68 | ) -> List[nn.Module]: 69 | """Builds a convolutional-based layer sequence 70 | 71 | >>> 
from torch.nn import Sequential 72 | >>> from doctr.models import conv_sequence 73 | >>> module = Sequential(conv_sequence(3, 32, True, True, kernel_size=3)) 74 | 75 | Args: 76 | out_channels: number of output channels 77 | relu: whether ReLU should be used 78 | bn: should a batch normalization layer be added 79 | 80 | Returns: 81 | list of layers 82 | """ 83 | # No bias before Batch norm 84 | kwargs["bias"] = kwargs.get("bias", not bn) 85 | # Add activation directly to the conv if there is no BN 86 | conv_seq: List[nn.Module] = [nn.Conv2d(in_channels, out_channels, **kwargs)] 87 | 88 | if bn: 89 | conv_seq.append(nn.BatchNorm2d(out_channels)) 90 | 91 | if relu: 92 | conv_seq.append(nn.ReLU(inplace=True)) 93 | 94 | return conv_seq 95 | 96 | 97 | def set_device_and_dtype( 98 | model: Any, batches: List[torch.Tensor], device: Union[str, torch.device], dtype: torch.dtype 99 | ) -> Tuple[Any, List[torch.Tensor]]: 100 | """Set the device and dtype of a model and its batches 101 | 102 | >>> import torch 103 | >>> from torch import nn 104 | >>> from doctr.models.utils import set_device_and_dtype 105 | >>> model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4)) 106 | >>> batches = [torch.rand(8) for _ in range(2)] 107 | >>> model, batches = set_device_and_dtype(model, batches, device="cuda", dtype=torch.float16) 108 | 109 | Args: 110 | model: the model to be set 111 | batches: the batches to be set 112 | device: the device to be used 113 | dtype: the dtype to be used 114 | 115 | Returns: 116 | the model and batches set 117 | """ 118 | 119 | return model.to(device=device, dtype=dtype), [batch.to(device=device, dtype=dtype) for batch in batches] 120 | 121 | 122 | def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.Tensor, **kwargs: Any) -> str: 123 | """Export model to ONNX format. 124 | 125 | >>> import torch 126 | >>> from doctr.models.classification import resnet18 127 | >>> from doctr.models.utils import export_model_to_onnx 128 | >>> model = resnet18(pretrained=True) 129 | >>> export_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32)) 130 | 131 | Args: 132 | model: the PyTorch model to be exported 133 | model_name: the name for the exported model 134 | dummy_input: the dummy input to the model 135 | kwargs: additional arguments to be passed to torch.onnx.export 136 | 137 | Returns: 138 | the path to the exported model 139 | """ 140 | torch.onnx.export( 141 | model, 142 | dummy_input, 143 | f"{model_name}.onnx", 144 | input_names=["input"], 145 | output_names=["logits"], 146 | dynamic_axes={"input": {0: "batch_size"}, "logits": {0: "batch_size"}}, 147 | export_params=True, 148 | verbose=False, 149 | **kwargs, 150 | ) 151 | logging.info(f"Model exported to {model_name}.onnx") 152 | return f"{model_name}.onnx" 153 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/classification/magc_resnet/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
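# --- Illustrative note (added; not part of the original module) ---------------
# `MAGC` below is a residual, multi-header global-context attention block: it
# returns a tensor with the same shape as its input, so it can be appended to
# any ResNet stage (which is what `magc_resnet31` does via `attn_module`).
# A hedged shape check, assuming `inplanes` is divisible by `headers`:
#
#   import torch
#   attn = MAGC(inplanes=256, headers=8, attn_scale=True)
#   x = torch.rand(2, 256, 8, 32)
#   assert attn(x).shape == x.shape   # residual output, shape preserved
# -------------------------------------------------------------------------------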
5 | 6 | 7 | import math 8 | from copy import deepcopy 9 | from functools import partial 10 | from typing import Any, Dict, List, Optional, Tuple 11 | 12 | import torch 13 | from torch import nn 14 | 15 | from multiocr.pipelines.doctr_ocr.doctr.datasets import VOCABS 16 | 17 | from ...utils.pytorch import load_pretrained_params 18 | from ..resnet.pytorch import ResNet 19 | 20 | __all__ = ["magc_resnet31"] 21 | 22 | 23 | default_cfgs: Dict[str, Dict[str, Any]] = { 24 | "magc_resnet31": { 25 | "mean": (0.694, 0.695, 0.693), 26 | "std": (0.299, 0.296, 0.301), 27 | "input_shape": (3, 32, 32), 28 | "classes": list(VOCABS["french"]), 29 | "url": "https://doctr-static.mindee.com/models?id=v0.4.1/magc_resnet31-857391d8.pt&src=0", 30 | }, 31 | } 32 | 33 | 34 | class MAGC(nn.Module): 35 | """Implements the Multi-Aspect Global Context Attention, as described in 36 | `_. 37 | 38 | Args: 39 | inplanes: input channels 40 | headers: number of headers to split channels 41 | attn_scale: if True, re-scale attention to counteract the variance distibutions 42 | ratio: bottleneck ratio 43 | **kwargs 44 | """ 45 | 46 | def __init__( 47 | self, 48 | inplanes: int, 49 | headers: int = 8, 50 | attn_scale: bool = False, 51 | ratio: float = 0.0625, # bottleneck ratio of 1/16 as described in paper 52 | cfg: Optional[Dict[str, Any]] = None, 53 | ) -> None: 54 | super().__init__() 55 | 56 | self.headers = headers 57 | self.inplanes = inplanes 58 | self.attn_scale = attn_scale 59 | self.planes = int(inplanes * ratio) 60 | 61 | self.single_header_inplanes = int(inplanes / headers) 62 | 63 | self.conv_mask = nn.Conv2d(self.single_header_inplanes, 1, kernel_size=1) 64 | self.softmax = nn.Softmax(dim=1) 65 | 66 | self.transform = nn.Sequential( 67 | nn.Conv2d(self.inplanes, self.planes, kernel_size=1), 68 | nn.LayerNorm([self.planes, 1, 1]), 69 | nn.ReLU(inplace=True), 70 | nn.Conv2d(self.planes, self.inplanes, kernel_size=1), 71 | ) 72 | 73 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 74 | batch, _, height, width = inputs.size() 75 | # (N * headers, C / headers, H , W) 76 | x = inputs.view(batch * self.headers, self.single_header_inplanes, height, width) 77 | shortcut = x 78 | # (N * headers, C / headers, H * W) 79 | shortcut = shortcut.view(batch * self.headers, self.single_header_inplanes, height * width) 80 | 81 | # (N * headers, 1, H, W) 82 | context_mask = self.conv_mask(x) 83 | # (N * headers, H * W) 84 | context_mask = context_mask.view(batch * self.headers, -1) 85 | 86 | # scale variance 87 | if self.attn_scale and self.headers > 1: 88 | context_mask = context_mask / math.sqrt(self.single_header_inplanes) 89 | 90 | # (N * headers, H * W) 91 | context_mask = self.softmax(context_mask) 92 | 93 | # (N * headers, C / headers) 94 | context = (shortcut * context_mask.unsqueeze(1)).sum(-1) 95 | 96 | # (N, C, 1, 1) 97 | context = context.view(batch, self.headers * self.single_header_inplanes, 1, 1) 98 | 99 | # Transform: B, C, 1, 1 -> B, C, 1, 1 100 | transformed = self.transform(context) 101 | return inputs + transformed 102 | 103 | 104 | def _magc_resnet( 105 | arch: str, 106 | pretrained: bool, 107 | num_blocks: List[int], 108 | output_channels: List[int], 109 | stage_stride: List[int], 110 | stage_conv: List[bool], 111 | stage_pooling: List[Optional[Tuple[int, int]]], 112 | ignore_keys: Optional[List[str]] = None, 113 | **kwargs: Any, 114 | ) -> ResNet: 115 | kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"])) 116 | kwargs["classes"] = kwargs.get("classes", 
default_cfgs[arch]["classes"]) 117 | 118 | _cfg = deepcopy(default_cfgs[arch]) 119 | _cfg["num_classes"] = kwargs["num_classes"] 120 | _cfg["classes"] = kwargs["classes"] 121 | kwargs.pop("classes") 122 | 123 | # Build the model 124 | model = ResNet( 125 | num_blocks, 126 | output_channels, 127 | stage_stride, 128 | stage_conv, 129 | stage_pooling, 130 | attn_module=partial(MAGC, headers=8, attn_scale=True), 131 | cfg=_cfg, 132 | **kwargs, 133 | ) 134 | # Load pretrained parameters 135 | if pretrained: 136 | # The number of classes is not the same as the number of classes in the pretrained model => 137 | # remove the last layer weights 138 | _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None 139 | load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys) 140 | 141 | return model 142 | 143 | 144 | def magc_resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet: 145 | """Resnet31 architecture with Multi-Aspect Global Context Attention as described in 146 | `"MASTER: Multi-Aspect Non-local Network for Scene Text Recognition", 147 | `_. 148 | 149 | >>> import torch 150 | >>> from doctr.models import magc_resnet31 151 | >>> model = magc_resnet31(pretrained=False) 152 | >>> input_tensor = torch.rand((1, 3, 224, 224), dtype=tf.float32) 153 | >>> out = model(input_tensor) 154 | 155 | Args: 156 | pretrained: boolean, True if model is pretrained 157 | 158 | Returns: 159 | A feature extractor model 160 | """ 161 | 162 | return _magc_resnet( 163 | "magc_resnet31", 164 | pretrained, 165 | [1, 2, 5, 3], 166 | [256, 256, 512, 512], 167 | [1, 1, 1, 1], 168 | [True] * 4, 169 | [(2, 2), (2, 1), None, None], 170 | origin_stem=False, 171 | stem_channels=128, 172 | ignore_keys=["13.weight", "13.bias"], 173 | **kwargs, 174 | ) 175 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/predictor/tensorflow.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import Any, List, Union 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | 11 | from multiocr.pipelines.doctr_ocr.doctr.io.elements import Document 12 | from multiocr.pipelines.doctr_ocr.doctr.models._utils import estimate_orientation, get_language 13 | from multiocr.pipelines.doctr_ocr.doctr.models.detection.predictor import DetectionPredictor 14 | from multiocr.pipelines.doctr_ocr.doctr.models.recognition.predictor import RecognitionPredictor 15 | from multiocr.pipelines.doctr_ocr.doctr.utils.geometry import rotate_boxes, rotate_image 16 | from multiocr.pipelines.doctr_ocr.doctr.utils.repr import NestedObject 17 | 18 | from .base import _OCRPredictor 19 | 20 | __all__ = ["OCRPredictor"] 21 | 22 | 23 | class OCRPredictor(NestedObject, _OCRPredictor): 24 | """Implements an object able to localize and identify text elements in a set of documents 25 | 26 | Args: 27 | det_predictor: detection module 28 | reco_predictor: recognition module 29 | assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages 30 | without rotated textual elements. 31 | straighten_pages: if True, estimates the page general orientation based on the median line orientation. 32 | Then, rotates page before passing it to the deep learning modules. 
The final predictions will be remapped 33 | accordingly. Doing so will improve performances for documents with page-uniform rotations. 34 | detect_orientation: if True, the estimated general page orientation will be added to the predictions for each 35 | page. Doing so will slightly deteriorate the overall latency. 36 | detect_language: if True, the language prediction will be added to the predictions for each 37 | page. Doing so will slightly deteriorate the overall latency. 38 | kwargs: keyword args of `DocumentBuilder` 39 | """ 40 | 41 | _children_names = ["det_predictor", "reco_predictor", "doc_builder"] 42 | 43 | def __init__( 44 | self, 45 | det_predictor: DetectionPredictor, 46 | reco_predictor: RecognitionPredictor, 47 | assume_straight_pages: bool = True, 48 | straighten_pages: bool = False, 49 | preserve_aspect_ratio: bool = True, 50 | symmetric_pad: bool = True, 51 | detect_orientation: bool = False, 52 | detect_language: bool = False, 53 | **kwargs: Any, 54 | ) -> None: 55 | self.det_predictor = det_predictor 56 | self.reco_predictor = reco_predictor 57 | _OCRPredictor.__init__( 58 | self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs 59 | ) 60 | self.detect_orientation = detect_orientation 61 | self.detect_language = detect_language 62 | 63 | def __call__( 64 | self, 65 | pages: List[Union[np.ndarray, tf.Tensor]], 66 | **kwargs: Any, 67 | ) -> Document: 68 | # Dimension check 69 | if any(page.ndim != 3 for page in pages): 70 | raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") 71 | 72 | origin_page_shapes = [page.shape[:2] for page in pages] 73 | 74 | # Detect document rotation and rotate pages 75 | if self.detect_orientation: 76 | origin_page_orientations = [estimate_orientation(page) for page in pages] 77 | orientations = [ 78 | {"value": orientation_page, "confidence": 1.0} for orientation_page in origin_page_orientations 79 | ] 80 | else: 81 | orientations = None 82 | if self.straighten_pages: 83 | origin_page_orientations = ( 84 | origin_page_orientations if self.detect_orientation else [estimate_orientation(page) for page in pages] 85 | ) 86 | pages = [rotate_image(page, -angle, expand=True) for page, angle in zip(pages, origin_page_orientations)] 87 | 88 | # Localize text elements 89 | loc_preds_dict = self.det_predictor(pages, **kwargs) 90 | assert all( 91 | len(loc_pred) == 1 for loc_pred in loc_preds_dict 92 | ), "Detection Model in ocr_predictor should output only one class" 93 | 94 | loc_preds: List[np.ndarray] = [list(loc_pred.values())[0] for loc_pred in loc_preds_dict] 95 | 96 | # Rectify crops if aspect ratio 97 | loc_preds = self._remove_padding(pages, loc_preds) 98 | 99 | # Crop images 100 | crops, loc_preds = self._prepare_crops( 101 | pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages 102 | ) 103 | # Rectify crop orientation 104 | if not self.assume_straight_pages: 105 | crops, loc_preds = self._rectify_crops(crops, loc_preds) 106 | 107 | # Identify character sequences 108 | word_preds = self.reco_predictor([crop for page_crops in crops for crop in page_crops], **kwargs) 109 | 110 | boxes, text_preds = self._process_predictions(loc_preds, word_preds) 111 | 112 | if self.detect_language: 113 | languages = [get_language(" ".join([item[0] for item in text_pred])) for text_pred in text_preds] 114 | languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages] 115 | else: 116 | languages_dict = None 117 | # Rotate back 
pages and boxes while keeping original image size 118 | if self.straighten_pages: 119 | boxes = [ 120 | rotate_boxes( 121 | page_boxes, 122 | angle, 123 | orig_shape=page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:], 124 | target_shape=mask, # type: ignore[arg-type] 125 | ) 126 | for page_boxes, page, angle, mask in zip(boxes, pages, origin_page_orientations, origin_page_shapes) 127 | ] 128 | 129 | out = self.doc_builder( 130 | boxes, 131 | text_preds, 132 | origin_page_shapes, # type: ignore[arg-type] 133 | orientations, 134 | languages_dict, 135 | ) 136 | return out 137 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/utils/tensorflow.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import logging 7 | import os 8 | from typing import Any, Callable, List, Optional, Tuple, Union 9 | from zipfile import ZipFile 10 | 11 | import tensorflow as tf 12 | import tf2onnx 13 | from tensorflow.keras import Model, layers 14 | 15 | from multiocr.pipelines.doctr_ocr.doctr.utils.data import download_from_url 16 | 17 | logging.getLogger("tensorflow").setLevel(logging.DEBUG) 18 | 19 | 20 | __all__ = ["load_pretrained_params", "conv_sequence", "IntermediateLayerGetter", "export_model_to_onnx", "_copy_tensor"] 21 | 22 | 23 | def _copy_tensor(x: tf.Tensor) -> tf.Tensor: 24 | return tf.identity(x) 25 | 26 | 27 | def load_pretrained_params( 28 | model: Model, 29 | url: Optional[str] = None, 30 | hash_prefix: Optional[str] = None, 31 | overwrite: bool = False, 32 | internal_name: str = "weights", 33 | **kwargs: Any, 34 | ) -> None: 35 | """Load a set of parameters onto a model 36 | 37 | >>> from doctr.models import load_pretrained_params 38 | >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip") 39 | 40 | Args: 41 | model: the keras model to be loaded 42 | url: URL of the zipped set of parameters 43 | hash_prefix: first characters of SHA256 expected hash 44 | overwrite: should the zip extraction be enforced if the archive has already been extracted 45 | internal_name: name of the ckpt files 46 | """ 47 | 48 | if url is None: 49 | logging.warning("Invalid model URL, using default initialization.") 50 | else: 51 | archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs) 52 | 53 | # Unzip the archive 54 | params_path = archive_path.parent.joinpath(archive_path.stem) 55 | if not params_path.is_dir() or overwrite: 56 | with ZipFile(archive_path, "r") as f: 57 | f.extractall(path=params_path) 58 | 59 | # Load weights 60 | model.load_weights(f"{params_path}{os.sep}{internal_name}") 61 | 62 | 63 | def conv_sequence( 64 | out_channels: int, 65 | activation: Optional[Union[str, Callable]] = None, 66 | bn: bool = False, 67 | padding: str = "same", 68 | kernel_initializer: str = "he_normal", 69 | **kwargs: Any, 70 | ) -> List[layers.Layer]: 71 | """Builds a convolutional-based layer sequence 72 | 73 | >>> from tensorflow.keras import Sequential 74 | >>> from doctr.models import conv_sequence 75 | >>> module = Sequential(conv_sequence(32, 'relu', True, kernel_size=3, input_shape=[224, 224, 3])) 76 | 77 | Args: 78 | out_channels: number of output channels 79 | activation: activation to be used (default: no activation) 80 | bn: should a batch normalization layer 
be added 81 | padding: padding scheme 82 | kernel_initializer: kernel initializer 83 | 84 | Returns: 85 | list of layers 86 | """ 87 | # No bias before Batch norm 88 | kwargs["use_bias"] = kwargs.get("use_bias", not bn) 89 | # Add activation directly to the conv if there is no BN 90 | kwargs["activation"] = activation if not bn else None 91 | conv_seq = [layers.Conv2D(out_channels, padding=padding, kernel_initializer=kernel_initializer, **kwargs)] 92 | 93 | if bn: 94 | conv_seq.append(layers.BatchNormalization()) 95 | 96 | if (isinstance(activation, str) or callable(activation)) and bn: 97 | # Activation function can either be a string or a function ('relu' or tf.nn.relu) 98 | conv_seq.append(layers.Activation(activation)) 99 | 100 | return conv_seq 101 | 102 | 103 | class IntermediateLayerGetter(Model): 104 | """Implements an intermediate layer getter 105 | 106 | >>> from tensorflow.keras.applications import ResNet50 107 | >>> from doctr.models import IntermediateLayerGetter 108 | >>> target_layers = ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"] 109 | >>> feat_extractor = IntermediateLayerGetter(ResNet50(include_top=False, pooling=False), target_layers) 110 | 111 | Args: 112 | model: the model to extract feature maps from 113 | layer_names: the list of layers to retrieve the feature map from 114 | """ 115 | 116 | def __init__(self, model: Model, layer_names: List[str]) -> None: 117 | intermediate_fmaps = [model.get_layer(layer_name).get_output_at(0) for layer_name in layer_names] 118 | super().__init__(model.input, outputs=intermediate_fmaps) 119 | 120 | def __repr__(self) -> str: 121 | return f"{self.__class__.__name__}()" 122 | 123 | 124 | def export_model_to_onnx( 125 | model: Model, model_name: str, dummy_input: List[tf.TensorSpec], **kwargs: Any 126 | ) -> Tuple[str, List[str]]: 127 | """Export model to ONNX format. 
128 | 129 | >>> import tensorflow as tf 130 | >>> from doctr.models.classification import resnet18 131 | >>> from doctr.models.utils import export_model_to_onnx 132 | >>> model = resnet18(pretrained=True, include_top=True) 133 | >>> export_model_to_onnx(model, "my_model", 134 | >>> dummy_input=[tf.TensorSpec([None, 32, 32, 3], tf.float32, name="input")]) 135 | 136 | Args: 137 | model: the keras model to be exported 138 | model_name: the name for the exported model 139 | dummy_input: the dummy input to the model 140 | kwargs: additional arguments to be passed to tf2onnx 141 | 142 | Returns: 143 | the path to the exported model and a list with the output layer names 144 | """ 145 | large_model = kwargs.get("large_model", False) 146 | model_proto, _ = tf2onnx.convert.from_keras( 147 | model, 148 | input_signature=dummy_input, 149 | output_path=f"{model_name}.zip" if large_model else f"{model_name}.onnx", 150 | **kwargs, 151 | ) 152 | # Get the output layer names 153 | output = [n.name for n in model_proto.graph.output] 154 | 155 | # Models which are too large (weights > 2GB when converting to ONNX) need to be handled 156 | # with external tensor storage, where the graph and weights are stored separately in an archive 157 | if large_model: 158 | logging.info(f"Model exported to {model_name}.zip") 159 | return f"{model_name}.zip", output 160 | 161 | logging.info(f"Model exported to {model_name}.onnx") 162 | return f"{model_name}.onnx", output 163 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/classification/vit/tensorflow.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 
5 | 6 | from copy import deepcopy 7 | from typing import Any, Dict, Optional, Tuple 8 | 9 | import tensorflow as tf 10 | from tensorflow.keras import Sequential, layers 11 | 12 | from multiocr.pipelines.doctr_ocr.doctr.datasets import VOCABS 13 | from multiocr.pipelines.doctr_ocr.doctr.models.modules.transformer import EncoderBlock 14 | from multiocr.pipelines.doctr_ocr.doctr.models.modules.vision_transformer.tensorflow import PatchEmbedding 15 | from multiocr.pipelines.doctr_ocr.doctr.utils.repr import NestedObject 16 | 17 | from ...utils import load_pretrained_params 18 | 19 | __all__ = ["vit_s", "vit_b"] 20 | 21 | 22 | default_cfgs: Dict[str, Dict[str, Any]] = { 23 | "vit_s": { 24 | "mean": (0.694, 0.695, 0.693), 25 | "std": (0.299, 0.296, 0.301), 26 | "input_shape": (32, 32, 3), 27 | "classes": list(VOCABS["french"]), 28 | "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_s-6300fcc9.zip&src=0", 29 | }, 30 | "vit_b": { 31 | "mean": (0.694, 0.695, 0.693), 32 | "std": (0.299, 0.296, 0.301), 33 | "input_shape": (32, 32, 3), 34 | "classes": list(VOCABS["french"]), 35 | "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_b-57158446.zip&src=0", 36 | }, 37 | } 38 | 39 | 40 | class ClassifierHead(layers.Layer, NestedObject): 41 | """Classifier head for Vision Transformer 42 | 43 | Args: 44 | num_classes: number of output classes 45 | """ 46 | 47 | def __init__(self, num_classes: int) -> None: 48 | super().__init__() 49 | 50 | self.head = layers.Dense(num_classes, kernel_initializer="he_normal", name="dense") 51 | 52 | def call(self, x: tf.Tensor) -> tf.Tensor: 53 | # (batch_size, num_classes) cls token 54 | return self.head(x[:, 0]) 55 | 56 | 57 | class VisionTransformer(Sequential): 58 | """VisionTransformer architecture as described in 59 | `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", 60 | `_. 
61 | 62 | Args: 63 | d_model: dimension of the transformer layers 64 | num_layers: number of transformer layers 65 | num_heads: number of attention heads 66 | ffd_ratio: multiplier for the hidden dimension of the feedforward layer 67 | patch_size: size of the patches 68 | input_shape: size of the input image 69 | dropout: dropout rate 70 | num_classes: number of output classes 71 | include_top: whether the classifier head should be instantiated 72 | """ 73 | 74 | def __init__( 75 | self, 76 | d_model: int, 77 | num_layers: int, 78 | num_heads: int, 79 | ffd_ratio: int, 80 | patch_size: Tuple[int, int] = (4, 4), 81 | input_shape: Tuple[int, int, int] = (32, 32, 3), 82 | dropout: float = 0.0, 83 | num_classes: int = 1000, 84 | include_top: bool = True, 85 | cfg: Optional[Dict[str, Any]] = None, 86 | ) -> None: 87 | _layers = [ 88 | PatchEmbedding(input_shape, d_model, patch_size), 89 | EncoderBlock( 90 | num_layers, 91 | num_heads, 92 | d_model, 93 | d_model * ffd_ratio, 94 | dropout, 95 | activation_fct=layers.Activation("gelu"), 96 | ), 97 | ] 98 | if include_top: 99 | _layers.append(ClassifierHead(num_classes)) 100 | 101 | super().__init__(_layers) 102 | self.cfg = cfg 103 | 104 | 105 | def _vit( 106 | arch: str, 107 | pretrained: bool, 108 | **kwargs: Any, 109 | ) -> VisionTransformer: 110 | kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"])) 111 | kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"]) 112 | kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"]) 113 | 114 | _cfg = deepcopy(default_cfgs[arch]) 115 | _cfg["num_classes"] = kwargs["num_classes"] 116 | _cfg["input_shape"] = kwargs["input_shape"] 117 | _cfg["classes"] = kwargs["classes"] 118 | kwargs.pop("classes") 119 | 120 | # Build the model 121 | model = VisionTransformer(cfg=_cfg, **kwargs) 122 | # Load pretrained parameters 123 | if pretrained: 124 | load_pretrained_params(model, default_cfgs[arch]["url"]) 125 | 126 | return model 127 | 128 | 129 | def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer: 130 | """VisionTransformer-S architecture 131 | `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", 132 | `_. Patches: (H, W) -> (H/8, W/8) 133 | 134 | NOTE: unofficial config used in ViTSTR and ParSeq 135 | 136 | >>> import tensorflow as tf 137 | >>> from doctr.models import vit_s 138 | >>> model = vit_s(pretrained=False) 139 | >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32) 140 | >>> out = model(input_tensor) 141 | 142 | Args: 143 | pretrained: boolean, True if model is pretrained 144 | 145 | Returns: 146 | A feature extractor model 147 | """ 148 | 149 | return _vit( 150 | "vit_s", 151 | pretrained, 152 | d_model=384, 153 | num_layers=12, 154 | num_heads=6, 155 | ffd_ratio=4, 156 | **kwargs, 157 | ) 158 | 159 | 160 | def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer: 161 | """VisionTransformer-B architecture as described in 162 | `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", 163 | `_. 
Patches: (H, W) -> (H/8, W/8) 164 | 165 | >>> import tensorflow as tf 166 | >>> from doctr.models import vit_b 167 | >>> model = vit_b(pretrained=False) 168 | >>> input_tensor = tf.random.uniform(shape=[1, 32, 32, 3], maxval=1, dtype=tf.float32) 169 | >>> out = model(input_tensor) 170 | 171 | Args: 172 | pretrained: boolean, True if model is pretrained 173 | 174 | Returns: 175 | A feature extractor model 176 | """ 177 | 178 | return _vit( 179 | "vit_b", 180 | pretrained, 181 | d_model=768, 182 | num_layers=12, 183 | num_heads=12, 184 | ffd_ratio=4, 185 | **kwargs, 186 | ) 187 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/classification/vit/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from copy import deepcopy 7 | from typing import Any, Dict, List, Optional, Tuple 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from multiocr.pipelines.doctr_ocr.doctr.datasets import VOCABS 13 | from multiocr.pipelines.doctr_ocr.doctr.models.modules.transformer import EncoderBlock 14 | from multiocr.pipelines.doctr_ocr.doctr.models.modules.vision_transformer.pytorch import PatchEmbedding 15 | 16 | from ...utils.pytorch import load_pretrained_params 17 | 18 | __all__ = ["vit_s", "vit_b"] 19 | 20 | 21 | default_cfgs: Dict[str, Dict[str, Any]] = { 22 | "vit_s": { 23 | "mean": (0.694, 0.695, 0.693), 24 | "std": (0.299, 0.296, 0.301), 25 | "input_shape": (3, 32, 32), 26 | "classes": list(VOCABS["french"]), 27 | "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_s-5d05442d.pt&src=0", 28 | }, 29 | "vit_b": { 30 | "mean": (0.694, 0.695, 0.693), 31 | "std": (0.299, 0.296, 0.301), 32 | "input_shape": (3, 32, 32), 33 | "classes": list(VOCABS["french"]), 34 | "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_b-0fbef167.pt&src=0", 35 | }, 36 | } 37 | 38 | 39 | class ClassifierHead(nn.Module): 40 | """Classifier head for Vision Transformer 41 | 42 | Args: 43 | in_channels: number of input channels 44 | num_classes: number of output classes 45 | """ 46 | 47 | def __init__( 48 | self, 49 | in_channels: int, 50 | num_classes: int, 51 | ) -> None: 52 | super().__init__() 53 | 54 | self.head = nn.Linear(in_channels, num_classes) 55 | 56 | def forward(self, x: torch.Tensor) -> torch.Tensor: 57 | # (batch_size, num_classes) cls token 58 | return self.head(x[:, 0]) 59 | 60 | 61 | class VisionTransformer(nn.Sequential): 62 | """VisionTransformer architecture as described in 63 | `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", 64 | `_. 
65 | 66 | Args: 67 | d_model: dimension of the transformer layers 68 | num_layers: number of transformer layers 69 | num_heads: number of attention heads 70 | ffd_ratio: multiplier for the hidden dimension of the feedforward layer 71 | patch_size: size of the patches 72 | input_shape: size of the input image 73 | dropout: dropout rate 74 | num_classes: number of output classes 75 | include_top: whether the classifier head should be instantiated 76 | """ 77 | 78 | def __init__( 79 | self, 80 | d_model: int, 81 | num_layers: int, 82 | num_heads: int, 83 | ffd_ratio: int, 84 | patch_size: Tuple[int, int] = (4, 4), 85 | input_shape: Tuple[int, int, int] = (3, 32, 32), 86 | dropout: float = 0.0, 87 | num_classes: int = 1000, 88 | include_top: bool = True, 89 | cfg: Optional[Dict[str, Any]] = None, 90 | ) -> None: 91 | _layers: List[nn.Module] = [ 92 | PatchEmbedding(input_shape, d_model, patch_size), 93 | EncoderBlock(num_layers, num_heads, d_model, d_model * ffd_ratio, dropout, nn.GELU()), 94 | ] 95 | if include_top: 96 | _layers.append(ClassifierHead(d_model, num_classes)) 97 | 98 | super().__init__(*_layers) 99 | self.cfg = cfg 100 | 101 | 102 | def _vit( 103 | arch: str, 104 | pretrained: bool, 105 | ignore_keys: Optional[List[str]] = None, 106 | **kwargs: Any, 107 | ) -> VisionTransformer: 108 | kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"])) 109 | kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"]) 110 | kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"]) 111 | 112 | _cfg = deepcopy(default_cfgs[arch]) 113 | _cfg["num_classes"] = kwargs["num_classes"] 114 | _cfg["input_shape"] = kwargs["input_shape"] 115 | _cfg["classes"] = kwargs["classes"] 116 | kwargs.pop("classes") 117 | 118 | # Build the model 119 | model = VisionTransformer(cfg=_cfg, **kwargs) 120 | # Load pretrained parameters 121 | if pretrained: 122 | # The number of classes is not the same as the number of classes in the pretrained model => 123 | # remove the last layer weights 124 | _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None 125 | load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys) 126 | 127 | return model 128 | 129 | 130 | def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer: 131 | """VisionTransformer-S architecture 132 | `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", 133 | `_. Patches: (H, W) -> (H/8, W/8) 134 | 135 | NOTE: unofficial config used in ViTSTR and ParSeq 136 | 137 | >>> import torch 138 | >>> from doctr.models import vit_s 139 | >>> model = vit_s(pretrained=False) 140 | >>> input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32) 141 | >>> out = model(input_tensor) 142 | 143 | Args: 144 | pretrained: boolean, True if model is pretrained 145 | 146 | Returns: 147 | A feature extractor model 148 | """ 149 | 150 | return _vit( 151 | "vit_s", 152 | pretrained, 153 | d_model=384, 154 | num_layers=12, 155 | num_heads=6, 156 | ffd_ratio=4, 157 | ignore_keys=["2.head.weight", "2.head.bias"], 158 | **kwargs, 159 | ) 160 | 161 | 162 | def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer: 163 | """VisionTransformer-B architecture as described in 164 | `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", 165 | `_. 
Patches: (H, W) -> (H/8, W/8) 166 | 167 | >>> import torch 168 | >>> from doctr.models import vit_b 169 | >>> model = vit_b(pretrained=False) 170 | >>> input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32) 171 | >>> out = model(input_tensor) 172 | 173 | Args: 174 | pretrained: boolean, True if model is pretrained 175 | 176 | Returns: 177 | A feature extractor model 178 | """ 179 | 180 | return _vit( 181 | "vit_b", 182 | pretrained, 183 | d_model=768, 184 | num_layers=12, 185 | num_heads=12, 186 | ffd_ratio=4, 187 | ignore_keys=["2.head.weight", "2.head.bias"], 188 | **kwargs, 189 | ) 190 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/predictor/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import Any, List, Optional, Tuple 7 | 8 | import numpy as np 9 | 10 | from multiocr.pipelines.doctr_ocr.doctr.models.builder import DocumentBuilder 11 | from multiocr.pipelines.doctr_ocr.doctr.utils.geometry import extract_crops, extract_rcrops 12 | 13 | from .._utils import rectify_crops, rectify_loc_preds 14 | from ..classification import crop_orientation_predictor 15 | from ..classification.predictor import CropOrientationPredictor 16 | 17 | __all__ = ["_OCRPredictor"] 18 | 19 | 20 | class _OCRPredictor: 21 | """Implements an object able to localize and identify text elements in a set of documents 22 | 23 | Args: 24 | assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages 25 | without rotated textual elements. 26 | straighten_pages: if True, estimates the page general orientation based on the median line orientation. 27 | Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped 28 | accordingly. Doing so will improve performances for documents with page-uniform rotations. 29 | preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding) 30 | symmetric_pad: if True and preserve_aspect_ratio is True, pad the image symmetrically. 
31 | kwargs: keyword args of `DocumentBuilder` 32 | """ 33 | 34 | crop_orientation_predictor: Optional[CropOrientationPredictor] 35 | 36 | def __init__( 37 | self, 38 | assume_straight_pages: bool = True, 39 | straighten_pages: bool = False, 40 | preserve_aspect_ratio: bool = True, 41 | symmetric_pad: bool = True, 42 | **kwargs: Any, 43 | ) -> None: 44 | self.assume_straight_pages = assume_straight_pages 45 | self.straighten_pages = straighten_pages 46 | self.crop_orientation_predictor = None if assume_straight_pages else crop_orientation_predictor(pretrained=True) 47 | self.doc_builder = DocumentBuilder(**kwargs) 48 | self.preserve_aspect_ratio = preserve_aspect_ratio 49 | self.symmetric_pad = symmetric_pad 50 | 51 | @staticmethod 52 | def _generate_crops( 53 | pages: List[np.ndarray], 54 | loc_preds: List[np.ndarray], 55 | channels_last: bool, 56 | assume_straight_pages: bool = False, 57 | ) -> List[List[np.ndarray]]: 58 | extraction_fn = extract_crops if assume_straight_pages else extract_rcrops 59 | 60 | crops = [ 61 | extraction_fn(page, _boxes[:, :4], channels_last=channels_last) # type: ignore[operator] 62 | for page, _boxes in zip(pages, loc_preds) 63 | ] 64 | return crops 65 | 66 | @staticmethod 67 | def _prepare_crops( 68 | pages: List[np.ndarray], 69 | loc_preds: List[np.ndarray], 70 | channels_last: bool, 71 | assume_straight_pages: bool = False, 72 | ) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]: 73 | crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages) 74 | 75 | # Avoid sending zero-sized crops 76 | is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops] 77 | crops = [ 78 | [crop for crop, _kept in zip(page_crops, page_kept) if _kept] 79 | for page_crops, page_kept in zip(crops, is_kept) 80 | ] 81 | loc_preds = [_boxes[_kept] for _boxes, _kept in zip(loc_preds, is_kept)] 82 | 83 | return crops, loc_preds 84 | 85 | def _rectify_crops( 86 | self, 87 | crops: List[List[np.ndarray]], 88 | loc_preds: List[np.ndarray], 89 | ) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]: 90 | # Work at a page level 91 | orientations = [self.crop_orientation_predictor(page_crops) for page_crops in crops] # type: ignore[misc] 92 | rect_crops = [rectify_crops(page_crops, orientation) for page_crops, orientation in zip(crops, orientations)] 93 | rect_loc_preds = [ 94 | rectify_loc_preds(page_loc_preds, orientation) if len(page_loc_preds) > 0 else page_loc_preds 95 | for page_loc_preds, orientation in zip(loc_preds, orientations) 96 | ] 97 | return rect_crops, rect_loc_preds # type: ignore[return-value] 98 | 99 | def _remove_padding( 100 | self, 101 | pages: List[np.ndarray], 102 | loc_preds: List[np.ndarray], 103 | ) -> List[np.ndarray]: 104 | if self.preserve_aspect_ratio: 105 | # Rectify loc_preds to remove padding 106 | rectified_preds = [] 107 | for page, loc_pred in zip(pages, loc_preds): 108 | h, w = page.shape[0], page.shape[1] 109 | if h > w: 110 | # y unchanged, dilate x coord 111 | if self.symmetric_pad: 112 | if self.assume_straight_pages: 113 | loc_pred[:, [0, 2]] = np.clip((loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5, 0, 1) 114 | else: 115 | loc_pred[:, :, 0] = np.clip((loc_pred[:, :, 0] - 0.5) * h / w + 0.5, 0, 1) 116 | else: 117 | if self.assume_straight_pages: 118 | loc_pred[:, [0, 2]] *= h / w 119 | else: 120 | loc_pred[:, :, 0] *= h / w 121 | elif w > h: 122 | # x unchanged, dilate y coord 123 | if self.symmetric_pad: 124 | if self.assume_straight_pages: 125 | loc_pred[:, [1, 3]] = 
np.clip((loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5, 0, 1) 126 | else: 127 | loc_pred[:, :, 1] = np.clip((loc_pred[:, :, 1] - 0.5) * w / h + 0.5, 0, 1) 128 | else: 129 | if self.assume_straight_pages: 130 | loc_pred[:, [1, 3]] *= w / h 131 | else: 132 | loc_pred[:, :, 1] *= w / h 133 | rectified_preds.append(loc_pred) 134 | return rectified_preds 135 | return loc_preds 136 | 137 | @staticmethod 138 | def _process_predictions( 139 | loc_preds: List[np.ndarray], 140 | word_preds: List[Tuple[str, float]], 141 | ) -> Tuple[List[np.ndarray], List[List[Tuple[str, float]]]]: 142 | text_preds = [] 143 | if len(loc_preds) > 0: 144 | # Text 145 | _idx = 0 146 | for page_boxes in loc_preds: 147 | text_preds.append(word_preds[_idx : _idx + page_boxes.shape[0]]) 148 | _idx += page_boxes.shape[0] 149 | 150 | return loc_preds, text_preds 151 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/predictor/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | from typing import Any, List, Union 7 | 8 | import numpy as np 9 | import torch 10 | from torch import nn 11 | 12 | from multiocr.pipelines.doctr_ocr.doctr.io.elements import Document 13 | from multiocr.pipelines.doctr_ocr.doctr.models._utils import estimate_orientation, get_language 14 | from multiocr.pipelines.doctr_ocr.doctr.models.detection.predictor import DetectionPredictor 15 | from multiocr.pipelines.doctr_ocr.doctr.models.recognition.predictor import RecognitionPredictor 16 | from multiocr.pipelines.doctr_ocr.doctr.utils.geometry import rotate_boxes, rotate_image 17 | 18 | from .base import _OCRPredictor 19 | 20 | __all__ = ["OCRPredictor"] 21 | 22 | 23 | class OCRPredictor(nn.Module, _OCRPredictor): 24 | """Implements an object able to localize and identify text elements in a set of documents 25 | 26 | Args: 27 | det_predictor: detection module 28 | reco_predictor: recognition module 29 | assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages 30 | without rotated textual elements. 31 | straighten_pages: if True, estimates the page general orientation based on the median line orientation. 32 | Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped 33 | accordingly. Doing so will improve performances for documents with page-uniform rotations. 34 | detect_orientation: if True, the estimated general page orientation will be added to the predictions for each 35 | page. Doing so will slightly deteriorate the overall latency. 36 | detect_language: if True, the language prediction will be added to the predictions for each 37 | page. Doing so will slightly deteriorate the overall latency. 
38 | kwargs: keyword args of `DocumentBuilder` 39 | """ 40 | 41 | def __init__( 42 | self, 43 | det_predictor: DetectionPredictor, 44 | reco_predictor: RecognitionPredictor, 45 | assume_straight_pages: bool = True, 46 | straighten_pages: bool = False, 47 | preserve_aspect_ratio: bool = True, 48 | symmetric_pad: bool = True, 49 | detect_orientation: bool = False, 50 | detect_language: bool = False, 51 | **kwargs: Any, 52 | ) -> None: 53 | nn.Module.__init__(self) 54 | self.det_predictor = det_predictor.eval() # type: ignore[attr-defined] 55 | self.reco_predictor = reco_predictor.eval() # type: ignore[attr-defined] 56 | _OCRPredictor.__init__( 57 | self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs 58 | ) 59 | self.detect_orientation = detect_orientation 60 | self.detect_language = detect_language 61 | 62 | @torch.no_grad() 63 | def forward( 64 | self, 65 | pages: List[Union[np.ndarray, torch.Tensor]], 66 | **kwargs: Any, 67 | ) -> Document: 68 | # Dimension check 69 | if any(page.ndim != 3 for page in pages): 70 | raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") 71 | 72 | origin_page_shapes = [page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:] for page in pages] 73 | 74 | # Detect document rotation and rotate pages 75 | if self.detect_orientation: 76 | origin_page_orientations = [estimate_orientation(page) for page in pages] # type: ignore[arg-type] 77 | orientations = [ 78 | {"value": orientation_page, "confidence": 1.0} for orientation_page in origin_page_orientations 79 | ] 80 | else: 81 | orientations = None 82 | if self.straighten_pages: 83 | origin_page_orientations = ( 84 | origin_page_orientations 85 | if self.detect_orientation 86 | else [estimate_orientation(page) for page in pages] # type: ignore[arg-type] 87 | ) 88 | pages = [ 89 | rotate_image(page, -angle, expand=True) # type: ignore[arg-type] 90 | for page, angle in zip(pages, origin_page_orientations) 91 | ] 92 | 93 | # Localize text elements 94 | loc_preds = self.det_predictor(pages, **kwargs) 95 | assert all( 96 | len(loc_pred) == 1 for loc_pred in loc_preds 97 | ), "Detection Model in ocr_predictor should output only one class" 98 | 99 | loc_preds = [list(loc_pred.values())[0] for loc_pred in loc_preds] 100 | # Check whether crop mode should be switched to channels first 101 | channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray) 102 | 103 | # Rectify crops if aspect ratio 104 | loc_preds = self._remove_padding(pages, loc_preds) # type: ignore[arg-type] 105 | 106 | # Crop images 107 | crops, loc_preds = self._prepare_crops( 108 | pages, # type: ignore[arg-type] 109 | loc_preds, 110 | channels_last=channels_last, 111 | assume_straight_pages=self.assume_straight_pages, 112 | ) 113 | # Rectify crop orientation 114 | if not self.assume_straight_pages: 115 | crops, loc_preds = self._rectify_crops(crops, loc_preds) 116 | # Identify character sequences 117 | word_preds = self.reco_predictor([crop for page_crops in crops for crop in page_crops], **kwargs) 118 | 119 | boxes, text_preds = self._process_predictions(loc_preds, word_preds) 120 | 121 | if self.detect_language: 122 | languages = [get_language(" ".join([item[0] for item in text_pred])) for text_pred in text_preds] 123 | languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages] 124 | else: 125 | languages_dict = None 126 | # Rotate back pages and boxes while keeping original image size 127 | if self.straighten_pages: 128 | 
boxes = [ 129 | rotate_boxes( 130 | page_boxes, 131 | angle, 132 | orig_shape=page.shape[:2] 133 | if isinstance(page, np.ndarray) 134 | else page.shape[1:], # type: ignore[arg-type] 135 | target_shape=mask, # type: ignore[arg-type] 136 | ) 137 | for page_boxes, page, angle, mask in zip(boxes, pages, origin_page_orientations, origin_page_shapes) 138 | ] 139 | 140 | out = self.doc_builder( 141 | boxes, 142 | text_preds, 143 | [page.shape[:2] if channels_last else page.shape[-2:] for page in pages], # type: ignore[misc] 144 | orientations, 145 | languages_dict, 146 | ) 147 | return out 148 | -------------------------------------------------------------------------------- /multiocr/pipelines/doctr_ocr/doctr/models/classification/magc_resnet/tensorflow.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2023, Mindee. 2 | 3 | # This program is licensed under the Apache License 2.0. 4 | # See LICENSE or go to for full license details. 5 | 6 | import math 7 | from copy import deepcopy 8 | from functools import partial 9 | from typing import Any, Dict, List, Optional, Tuple 10 | 11 | import tensorflow as tf 12 | from tensorflow.keras import layers 13 | from tensorflow.keras.models import Sequential 14 | 15 | from multiocr.pipelines.doctr_ocr.doctr.datasets import VOCABS 16 | 17 | from ...utils import load_pretrained_params 18 | from ..resnet.tensorflow import ResNet 19 | 20 | __all__ = ["magc_resnet31"] 21 | 22 | 23 | default_cfgs: Dict[str, Dict[str, Any]] = { 24 | "magc_resnet31": { 25 | "mean": (0.694, 0.695, 0.693), 26 | "std": (0.299, 0.296, 0.301), 27 | "input_shape": (32, 32, 3), 28 | "classes": list(VOCABS["french"]), 29 | "url": "https://doctr-static.mindee.com/models?id=v0.6.0/magc_resnet31-addbb705.zip&src=0", 30 | }, 31 | } 32 | 33 | 34 | class MAGC(layers.Layer): 35 | """Implements the Multi-Aspect Global Context Attention, as described in 36 | `_. 
37 | 38 | Args: 39 | inplanes: input channels 40 | headers: number of headers to split channels 41 | attn_scale: if True, re-scale attention to counteract the variance distributions 42 | ratio: bottleneck ratio 43 | **kwargs 44 | """ 45 | 46 | def __init__( 47 | self, 48 | inplanes: int, 49 | headers: int = 8, 50 | attn_scale: bool = False, 51 | ratio: float = 0.0625, # bottleneck ratio of 1/16 as described in paper 52 | **kwargs, 53 | ) -> None: 54 | super().__init__(**kwargs) 55 | 56 | self.headers = headers # h 57 | self.inplanes = inplanes # C 58 | self.attn_scale = attn_scale 59 | self.planes = int(inplanes * ratio) 60 | 61 | self.single_header_inplanes = int(inplanes / headers) # C / h 62 | 63 | self.conv_mask = layers.Conv2D(filters=1, kernel_size=1, kernel_initializer=tf.initializers.he_normal()) 64 | 65 | self.transform = Sequential( 66 | [ 67 | layers.Conv2D(filters=self.planes, kernel_size=1, kernel_initializer=tf.initializers.he_normal()), 68 | layers.LayerNormalization([1, 2, 3]), 69 | layers.ReLU(), 70 | layers.Conv2D(filters=self.inplanes, kernel_size=1, kernel_initializer=tf.initializers.he_normal()), 71 | ], 72 | name="transform", 73 | ) 74 | 75 | def context_modeling(self, inputs: tf.Tensor) -> tf.Tensor: 76 | b, h, w, c = (tf.shape(inputs)[i] for i in range(4)) 77 | 78 | # B, H, W, C -->> B*h, H, W, C/h 79 | x = tf.reshape(inputs, shape=(b, h, w, self.headers, self.single_header_inplanes)) 80 | x = tf.transpose(x, perm=(0, 3, 1, 2, 4)) 81 | x = tf.reshape(x, shape=(b * self.headers, h, w, self.single_header_inplanes)) 82 | 83 | # Compute shortcut 84 | shortcut = x 85 | # B*h, 1, H*W, C/h 86 | shortcut = tf.reshape(shortcut, shape=(b * self.headers, 1, h * w, self.single_header_inplanes)) 87 | # B*h, 1, C/h, H*W 88 | shortcut = tf.transpose(shortcut, perm=[0, 1, 3, 2]) 89 | 90 | # Compute context mask 91 | # B*h, H, W, 1 92 | context_mask = self.conv_mask(x) 93 | # B*h, 1, H*W, 1 94 | context_mask = tf.reshape(context_mask, shape=(b * self.headers, 1, h * w, 1)) 95 | # scale variance 96 | if self.attn_scale and self.headers > 1: 97 | context_mask = context_mask / math.sqrt(self.single_header_inplanes) 98 | # B*h, 1, H*W, 1 99 | context_mask = tf.keras.activations.softmax(context_mask, axis=2) 100 | 101 | # Compute context 102 | # B*h, 1, C/h, 1 103 | context = tf.matmul(shortcut, context_mask) 104 | context = tf.reshape(context, shape=(b, 1, c, 1)) 105 | # B, 1, 1, C 106 | context = tf.transpose(context, perm=(0, 1, 3, 2)) 107 | # Set shape to resolve shape when calling this module in the Sequential MAGCResnet 108 | batch, chan = inputs.get_shape().as_list()[0], inputs.get_shape().as_list()[-1] 109 | context.set_shape([batch, 1, 1, chan]) 110 | return context 111 | 112 | def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor: 113 | # Context modeling: B, H, W, C -> B, 1, 1, C 114 | context = self.context_modeling(inputs) 115 | # Transform: B, 1, 1, C -> B, 1, 1, C 116 | transformed = self.transform(context) 117 | return inputs + transformed 118 | 119 | 120 | def _magc_resnet( 121 | arch: str, 122 | pretrained: bool, 123 | num_blocks: List[int], 124 | output_channels: List[int], 125 | stage_downsample: List[bool], 126 | stage_conv: List[bool], 127 | stage_pooling: List[Optional[Tuple[int, int]]], 128 | origin_stem: bool = True, 129 | **kwargs: Any, 130 | ) -> ResNet: 131 | kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"])) 132 | kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"]) 133 | kwargs["classes"] = 
kwargs.get("classes", default_cfgs[arch]["classes"]) 134 | 135 | _cfg = deepcopy(default_cfgs[arch]) 136 | _cfg["num_classes"] = kwargs["num_classes"] 137 | _cfg["classes"] = kwargs["classes"] 138 | _cfg["input_shape"] = kwargs["input_shape"] 139 | kwargs.pop("classes") 140 | 141 | # Build the model 142 | model = ResNet( 143 | num_blocks, 144 | output_channels, 145 | stage_downsample, 146 | stage_conv, 147 | stage_pooling, 148 | origin_stem, 149 | attn_module=partial(MAGC, headers=8, attn_scale=True), 150 | cfg=_cfg, 151 | **kwargs, 152 | ) 153 | # Load pretrained parameters 154 | if pretrained: 155 | load_pretrained_params(model, default_cfgs[arch]["url"]) 156 | 157 | return model 158 | 159 | 160 | def magc_resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet: 161 | """Resnet31 architecture with Multi-Aspect Global Context Attention as described in 162 | `"MASTER: Multi-Aspect Non-local Network for Scene Text Recognition", 163 | `_. 164 | 165 | >>> import tensorflow as tf 166 | >>> from doctr.models import magc_resnet31 167 | >>> model = magc_resnet31(pretrained=False) 168 | >>> input_tensor = tf.random.uniform(shape=[1, 224, 224, 3], maxval=1, dtype=tf.float32) 169 | >>> out = model(input_tensor) 170 | 171 | Args: 172 | pretrained: boolean, True if model is pretrained 173 | 174 | Returns: 175 | A feature extractor model 176 | """ 177 | 178 | return _magc_resnet( 179 | "magc_resnet31", 180 | pretrained, 181 | [1, 2, 5, 3], 182 | [256, 256, 512, 512], 183 | [False] * 4, 184 | [True] * 4, 185 | [(2, 2), (2, 1), None, None], 186 | False, 187 | stem_channels=128, 188 | **kwargs, 189 | ) 190 | --------------------------------------------------------------------------------