├── .github ├── ISSUE_TEMPLATE │ └── bug-report.md └── workflows │ ├── ci.yml │ ├── package_testing.yml │ └── publish_pypi.yml ├── .gitignore ├── LICENSE ├── README.md ├── craft_text_detector ├── __init__.py ├── craft_utils.py ├── file_utils.py ├── image_utils.py ├── models │ ├── __init__.py │ ├── basenet │ │ ├── __init__.py │ │ └── vgg16_bn.py │ ├── craftnet.py │ └── refinenet.py ├── predict.py └── torch_utils.py ├── figures ├── craft_example.gif ├── idcard.png └── idcard2.jpg ├── requirements.txt ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── test_craft.py └── test_helpers.py /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "Bug Report" 3 | about: Submit a bug report to help us improve craft-text-detector 4 | 5 | --- 6 | 7 | ## 🐛 Bug 8 | 9 | 10 | 11 | ## To Reproduce 12 | 13 | Steps to reproduce the behavior: 14 | 15 | 1. 16 | 1. 17 | 1. 18 | 19 | 20 | 21 | ## Expected behavior 22 | 23 | 24 | 25 | ## Environment 26 | 27 | - craft-text-detector version (e.g., 0.3.1): 28 | - Python version (e.g., 3.6/3.7): 29 | - OS (e.g., Linux/Windows/MacOS): 30 | - How you installed craft-text-detector (`conda`, `pip`, source): 31 | - Any other relevant information: 32 | 33 | ## Additional context 34 | 35 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ${{ matrix.operating-system }} 13 | strategy: 14 | matrix: 15 | operating-system: [ubuntu-latest, windows-latest, macos-latest] 16 | python-version: [3.7, 3.8, 3.9] 17 | fail-fast: false 18 | 19 | steps: 20 | - name: Checkout 21 | uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Restore Ubuntu cache 27 | uses: actions/cache@v1 28 | if: matrix.operating-system == 'ubuntu-latest' 29 | with: 30 | path: ~/.cache/pip 31 | key: ${{ matrix.operating-system }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}} 32 | restore-keys: ${{ matrix.operating-system }}-${{ matrix.python-version }}- 33 | - name: Restore MacOS cache 34 | uses: actions/cache@v1 35 | if: matrix.operating-system == 'macos-latest' 36 | with: 37 | path: ~/Library/Caches/pip 38 | key: ${{ matrix.operating-system }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}} 39 | restore-keys: ${{ matrix.operating-system }}-${{ matrix.python-version }}- 40 | - name: Restore Windows cache 41 | uses: actions/cache@v1 42 | if: matrix.operating-system == 'windows-latest' 43 | with: 44 | path: ~\AppData\Local\pip\Cache 45 | key: ${{ matrix.operating-system }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}} 46 | restore-keys: ${{ matrix.operating-system }}-${{ matrix.python-version }}- 47 | - name: Update pip 48 | run: python -m pip install --upgrade pip 49 | - name: Install PyTorch on Linux and Windows 50 | if: > 51 | matrix.operating-system == 'ubuntu-latest' || 52 | matrix.operating-system == 'windows-latest' 53 | run: > 54 | pip install torch==1.8.1+cpu torchvision==0.9.1+cpu 55 | -f https://download.pytorch.org/whl/torch_stable.html 56 | - name: Install PyTorch on MacOS 57 | if: matrix.operating-system == 'macos-latest' 58 | run: pip install torch==1.8.1 torchvision==0.9.1 59 | - name: Install dependencies 60 | run: | 61 | python -m pip install --upgrade pip 62 | pip install -r requirements.txt
63 | - name: Lint with flake8 64 | run: | 65 | pip install flake8 66 | # stop the build if there are Python syntax errors or undefined names 67 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 68 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 69 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 70 | - name: Test with unittest 71 | run: | 72 | python -m unittest 73 | -------------------------------------------------------------------------------- /.github/workflows/package_testing.yml: -------------------------------------------------------------------------------- 1 | name: Package Testing 2 | 3 | on: 4 | schedule: 5 | - cron: '0 0 * * *' # Runs at 00:00 UTC every day 6 | 7 | jobs: 8 | build: 9 | runs-on: ${{ matrix.operating-system }} 10 | 11 | strategy: 12 | matrix: 13 | operating-system: [ubuntu-latest, windows-latest, macos-latest] 14 | python-version: [3.7, 3.8, 3.9] 15 | fail-fast: false 16 | 17 | steps: 18 | - name: Checkout 19 | uses: actions/checkout@v2 20 | - name: Set up Python 21 | uses: actions/setup-python@v2 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Restore Ubuntu cache 25 | uses: actions/cache@v1 26 | if: matrix.operating-system == 'ubuntu-latest' 27 | with: 28 | path: ~/.cache/pip 29 | key: ${{ matrix.operating-system }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}} 30 | restore-keys: ${{ matrix.operating-system }}-${{ matrix.python-version }}- 31 | - name: Restore MacOS cache 32 | uses: actions/cache@v1 33 | if: matrix.operating-system == 'macos-latest' 34 | with: 35 | path: ~/Library/Caches/pip 36 | key: ${{ matrix.operating-system }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}} 37 | restore-keys: ${{ matrix.operating-system }}-${{ matrix.python-version }}- 38 | - name: Restore Windows cache 39 | uses: actions/cache@v1 40 | if: matrix.operating-system == 'windows-latest' 41 | with: 42 | path: ~\AppData\Local\pip\Cache 43 | key: ${{ matrix.operating-system }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}} 44 | restore-keys: ${{ matrix.operating-system }}-${{ matrix.python-version }}- 45 | - name: Update pip 46 | run: python -m pip install --upgrade pip 47 | - name: Install PyTorch on Linux and Windows 48 | if: > 49 | matrix.operating-system == 'ubuntu-latest' || 50 | matrix.operating-system == 'windows-latest' 51 | run: > 52 | pip install torch==1.8.1+cpu torchvision==0.9.1+cpu 53 | -f https://download.pytorch.org/whl/torch_stable.html 54 | - name: Install PyTorch on MacOS 55 | if: matrix.operating-system == 'macos-latest' 56 | run: pip install torch==1.8.1 torchvision==0.9.1 57 | - name: Install latest craft-text-detector package 58 | run: > 59 | pip install --upgrade --force-reinstall craft-text-detector 60 | - name: Test with unittest 61 | run: | 62 | python -m unittest -------------------------------------------------------------------------------- /.github/workflows/publish_pypi.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [published, edited] 6 | 7 | jobs: 8 | deploy: 9 | 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Set up Python 15 | uses: actions/setup-python@v2 16 | with: 17 | python-version: '3.x' 18 | - name: Install dependencies 19 | run: | 20 | python -m pip install --upgrade pip 21 | pip install setuptools wheel twine 22 | - name: Build and publish 23 | env: 24 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 25 | TWINE_PASSWORD:
${{ secrets.PYPI_PASSWORD }} 26 | run: | 27 | python setup.py sdist bdist_wheel 28 | twine upload dist/* 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.swp 3 | *.pkl 4 | *.pth 5 | result* 6 | weights* 7 | .vscode 8 | .pypirc 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | pip-wheel-metadata/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | cover/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | .pybuilder/ 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | # For a library or package, you might want to ignore these files since the code is 97 | # intended to run in multiple environments; otherwise, check them in: 98 | # .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 108 | __pypackages__/ 109 | 110 | # Celery stuff 111 | celerybeat-schedule 112 | celerybeat.pid 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .env 119 | .venv 120 | env/ 121 | venv/ 122 | ENV/ 123 | env.bak/ 124 | venv.bak/ 125 | 126 | # Spyder project settings 127 | .spyderproject 128 | .spyproject 129 | 130 | # Rope project settings 131 | .ropeproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | 141 | # Pyre type checker 142 | .pyre/ 143 | 144 | # pytype static type analyzer 145 | .pytype/ 146 | 147 | # Cython debug symbols 148 | cython_debug/ 149 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019 NAVER Corp., 2020 Fatih C Akyon 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CRAFT: Character-Region Awareness For Text detection 2 | 3 |
<div align="center"> [badges: downloads, fcakyon twitter, build status, PyPI version, License: MIT] </div>
12 | 13 | Packaged, PyTorch-based, easy-to-use, cross-platform version of the CRAFT text detector | [Paper](https://arxiv.org/abs/1904.01941) | 14 | 15 | ## Overview 16 | 17 | PyTorch implementation of the CRAFT text detector, which effectively detects text areas by exploring each character region and the affinity between characters. The bounding boxes of text are obtained by finding minimum bounding rectangles on the binary map produced by thresholding the character-region and affinity scores. 18 | 19 | [teaser: figures/craft_example.gif] 20 | 21 | ## Getting started 22 | 23 | ### Installation 24 | 25 | - Install using pip: 26 | 27 | ```console 28 | pip install craft-text-detector 29 | ``` 30 | 31 | ### Basic Usage 32 | 33 | ```python 34 | # import Craft class 35 | from craft_text_detector import Craft 36 | 37 | # set image path and export folder directory 38 | image = 'figures/idcard.png' # can be filepath, PIL image or numpy array 39 | output_dir = 'outputs/' 40 | 41 | # create a craft instance 42 | craft = Craft(output_dir=output_dir, crop_type="poly", cuda=False) 43 | 44 | # apply craft text detection and export detected regions to output directory 45 | prediction_result = craft.detect_text(image) 46 | 47 | # unload models from ram/gpu 48 | craft.unload_craftnet_model() 49 | craft.unload_refinenet_model() 50 | ``` 51 | 52 | ### Advanced Usage 53 | 54 | ```python 55 | # import craft functions 56 | from craft_text_detector import ( 57 | read_image, 58 | load_craftnet_model, 59 | load_refinenet_model, 60 | get_prediction, 61 | export_detected_regions, 62 | export_extra_results, 63 | empty_cuda_cache 64 | ) 65 | 66 | # set image path and export folder directory 67 | image = 'figures/idcard.png' # can be filepath, PIL image or numpy array 68 | output_dir = 'outputs/' 69 | 70 | # read image 71 | image = read_image(image) 72 | 73 | # load models 74 | refine_net = load_refinenet_model(cuda=True) 75 | craft_net = load_craftnet_model(cuda=True) 76 | 77 | # perform prediction 78 | prediction_result = get_prediction( 79 | image=image, 80 | craft_net=craft_net, 81 | refine_net=refine_net, 82 | text_threshold=0.7, 83 | link_threshold=0.4, 84 | low_text=0.4, 85 | cuda=True, 86 | long_size=1280 87 | ) 88 | 89 | # export detected text regions 90 | exported_file_paths = export_detected_regions( 91 | image=image, 92 | regions=prediction_result["boxes"], 93 | output_dir=output_dir, 94 | rectify=True 95 | ) 96 | 97 | # export heatmap, detection points, box visualization 98 | export_extra_results( 99 | image=image, 100 | regions=prediction_result["boxes"], 101 | heatmaps=prediction_result["heatmaps"], 102 | output_dir=output_dir 103 | ) 104 | 105 | # unload models from gpu 106 | empty_cuda_cache() 107 | ``` 108 | -------------------------------------------------------------------------------- /craft_text_detector/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | from typing import Optional 5 | 6 | import craft_text_detector.craft_utils as craft_utils 7 | import craft_text_detector.file_utils as file_utils 8 | import craft_text_detector.image_utils as image_utils 9 | import craft_text_detector.predict as predict 10 | import craft_text_detector.torch_utils as torch_utils 11 | 12 | __version__ = "0.4.3" 13 | 14 | 15 | __all__ = [ 16 | "read_image", 17 | "load_craftnet_model", 18 | "load_refinenet_model", 19 | "get_prediction", 20 | "export_detected_regions", 21 | "export_extra_results", 22 | "empty_cuda_cache", 23 | "Craft", 24 | ] 25 |
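# A minimal usage sketch as a comment block (mirrors the README Basic Usage;
# assumes the default weights can be fetched from Google Drive into
# ~/.craft_text_detector on first run):
#
#   from craft_text_detector import Craft
#   craft = Craft(output_dir="outputs/", crop_type="poly", cuda=False)
#   prediction_result = craft.detect_text("figures/idcard.png")
#   craft.unload_craftnet_model()
#   craft.unload_refinenet_model()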
read_image = image_utils.read_image 27 | load_craftnet_model = craft_utils.load_craftnet_model 28 | load_refinenet_model = craft_utils.load_refinenet_model 29 | get_prediction = predict.get_prediction 30 | export_detected_regions = file_utils.export_detected_regions 31 | export_extra_results = file_utils.export_extra_results 32 | empty_cuda_cache = torch_utils.empty_cuda_cache 33 | 34 | 35 | class Craft: 36 | def __init__( 37 | self, 38 | output_dir=None, 39 | rectify=True, 40 | export_extra=True, 41 | text_threshold=0.7, 42 | link_threshold=0.4, 43 | low_text=0.4, 44 | cuda=False, 45 | long_size=1280, 46 | refiner=True, 47 | crop_type="poly", 48 | weight_path_craft_net: Optional[str] = None, 49 | weight_path_refine_net: Optional[str] = None, 50 | ): 51 | """ 52 | Arguments: 53 | output_dir: path to the results to be exported 54 | rectify: rectify detected polygon by affine transform 55 | export_extra: export heatmap, detection points, box visualization 56 | text_threshold: text confidence threshold 57 | link_threshold: link confidence threshold 58 | low_text: text low-bound score 59 | cuda: Use cuda for inference 60 | long_size: desired longest image size for inference 61 | refiner: enable link refiner 62 | crop_type: crop regions by detected boxes or polys ("poly" or "box") 63 | """ 64 | self.craft_net = None 65 | self.refine_net = None 66 | self.output_dir = output_dir 67 | self.rectify = rectify 68 | self.export_extra = export_extra 69 | self.text_threshold = text_threshold 70 | self.link_threshold = link_threshold 71 | self.low_text = low_text 72 | self.cuda = cuda 73 | self.long_size = long_size 74 | self.refiner = refiner 75 | self.crop_type = crop_type 76 | 77 | # load craftnet 78 | self.load_craftnet_model(weight_path_craft_net) 79 | # load refinernet if required 80 | if refiner: 81 | self.load_refinenet_model(weight_path_refine_net) 82 | 83 | def load_craftnet_model(self, weight_path: Optional[str] = None): 84 | """ 85 | Loads craftnet model 86 | """ 87 | self.craft_net = load_craftnet_model(self.cuda, weight_path=weight_path) 88 | 89 | def load_refinenet_model(self, weight_path: Optional[str] = None): 90 | """ 91 | Loads refinenet model 92 | """ 93 | self.refine_net = load_refinenet_model(self.cuda, weight_path=weight_path) 94 | 95 | def unload_craftnet_model(self): 96 | """ 97 | Unloads craftnet model 98 | """ 99 | self.craft_net = None 100 | empty_cuda_cache() 101 | 102 | def unload_refinenet_model(self): 103 | """ 104 | Unloads refinenet model 105 | """ 106 | self.refine_net = None 107 | empty_cuda_cache() 108 | 109 | def detect_text(self, image, image_path=None): 110 | """ 111 | Arguments: 112 | image: path to the image to be processed or numpy array or PIL image 113 | 114 | Output: 115 | { 116 | "masks": lists of predicted masks 2d as bool array, 117 | "boxes": list of coords of points of predicted boxes, 118 | "boxes_as_ratios": list of coords of points of predicted boxes as ratios of image size, 119 | "polys_as_ratios": list of coords of points of predicted polys as ratios of image size, 120 | "heatmaps": visualization of the detected characters/links, 121 | "text_crop_paths": list of paths of the exported text boxes/polys, 122 | "times": elapsed times of the sub modules, in seconds 123 | } 124 | """ 125 | 126 | if image_path is not None: 127 | print("Argument 'image_path' is deprecated, use 'image' instead.") 128 | image = image_path 129 | 130 | # perform prediction 131 | prediction_result = get_prediction( 132 | image=image, 133 | craft_net=self.craft_net, 134 | 
refine_net=self.refine_net, 135 | text_threshold=self.text_threshold, 136 | link_threshold=self.link_threshold, 137 | low_text=self.low_text, 138 | cuda=self.cuda, 139 | long_size=self.long_size, 140 | ) 141 | 142 | # arrange regions 143 | if self.crop_type == "box": 144 | regions = prediction_result["boxes"] 145 | elif self.crop_type == "poly": 146 | regions = prediction_result["polys"] 147 | else: 148 | raise TypeError("crop_type can only be 'poly' or 'box'") 149 | 150 | # export if output_dir is given 151 | prediction_result["text_crop_paths"] = [] 152 | if self.output_dir is not None: 153 | # export detected text regions 154 | if type(image) == str: 155 | file_name, file_ext = os.path.splitext(os.path.basename(image)) 156 | else: 157 | file_name = "image" 158 | exported_file_paths = export_detected_regions( 159 | image=image, 160 | regions=regions, 161 | file_name=file_name, 162 | output_dir=self.output_dir, 163 | rectify=self.rectify, 164 | ) 165 | prediction_result["text_crop_paths"] = exported_file_paths 166 | 167 | # export heatmap, detection points, box visualization 168 | if self.export_extra: 169 | export_extra_results( 170 | image=image, 171 | regions=regions, 172 | heatmaps=prediction_result["heatmaps"], 173 | file_name=file_name, 174 | output_dir=self.output_dir, 175 | ) 176 | 177 | # return prediction results 178 | return prediction_result 179 | -------------------------------------------------------------------------------- /craft_text_detector/craft_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from collections import OrderedDict 4 | from pathlib import Path 5 | from typing import Optional, Union 6 | 7 | import cv2 8 | import numpy as np 9 | 10 | import craft_text_detector.file_utils as file_utils 11 | import craft_text_detector.torch_utils as torch_utils 12 | 13 | CRAFT_GDRIVE_URL = "https://drive.google.com/uc?id=1bupFXqT-VU6Jjeul13XP7yx2Sg5IHr4J" 14 | REFINENET_GDRIVE_URL = ( 15 | "https://drive.google.com/uc?id=1xcE9qpJXp4ofINwXWVhhQIh9S8Z7cuGj" 16 | ) 17 | 18 | 19 | # unwarp coordinates 20 | def warpCoord(Minv, pt): 21 | out = np.matmul(Minv, (pt[0], pt[1], 1)) 22 | return np.array([out[0] / out[2], out[1] / out[2]]) 23 | 24 | 25 | def copyStateDict(state_dict): 26 | if list(state_dict.keys())[0].startswith("module"): 27 | start_idx = 1 28 | else: 29 | start_idx = 0 30 | new_state_dict = OrderedDict() 31 | for k, v in state_dict.items(): 32 | name = ".".join(k.split(".")[start_idx:]) 33 | new_state_dict[name] = v 34 | return new_state_dict 35 | 36 | 37 | def load_craftnet_model( 38 | cuda: bool = False, 39 | weight_path: Optional[Union[str, Path]] = None 40 | ): 41 | # get craft net path 42 | if weight_path is None: 43 | home_path = str(Path.home()) 44 | weight_path = Path( 45 | home_path, 46 | ".craft_text_detector", 47 | "weights", 48 | "craft_mlt_25k.pth" 49 | ) 50 | weight_path = Path(weight_path).resolve() 51 | weight_path.parent.mkdir(exist_ok=True, parents=True) 52 | weight_path = str(weight_path) 53 | 54 | # load craft net 55 | from craft_text_detector.models.craftnet import CraftNet 56 | 57 | craft_net = CraftNet() # initialize 58 | 59 | # check if weights are already downloaded, if not download 60 | url = CRAFT_GDRIVE_URL 61 | if not os.path.isfile(weight_path): 62 | print("Craft text detector weight will be downloaded to {}".format(weight_path)) 63 | 64 | file_utils.download(url=url, save_path=weight_path) 65 | 66 | # arrange device 67 | if cuda:
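# descriptive note: copyStateDict above strips the "module." prefix that a
# DataParallel checkpoint carries, so the same weight file loads both here and
# in the CPU branch below, before the net itself is wrapped in DataParallel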
craft_net.load_state_dict(copyStateDict(torch_utils.load(weight_path))) 69 | 70 | craft_net = craft_net.cuda() 71 | craft_net = torch_utils.DataParallel(craft_net) 72 | torch_utils.cudnn_benchmark = False 73 | else: 74 | craft_net.load_state_dict( 75 | copyStateDict(torch_utils.load(weight_path, map_location="cpu")) 76 | ) 77 | craft_net.eval() 78 | return craft_net 79 | 80 | 81 | def load_refinenet_model( 82 | cuda: bool = False, 83 | weight_path: Optional[Union[str, Path]] = None 84 | ): 85 | # get refine net path 86 | if weight_path is None: 87 | home_path = Path.home() 88 | weight_path = Path( 89 | home_path, 90 | ".craft_text_detector", 91 | "weights", 92 | "craft_refiner_CTW1500.pth" 93 | ) 94 | weight_path = Path(weight_path).resolve() 95 | weight_path.parent.mkdir(exist_ok=True, parents=True) 96 | weight_path = str(weight_path) 97 | 98 | # load refine net 99 | from craft_text_detector.models.refinenet import RefineNet 100 | 101 | refine_net = RefineNet() # initialize 102 | 103 | # check if weights are already downloaded, if not download 104 | url = REFINENET_GDRIVE_URL 105 | if not os.path.isfile(weight_path): 106 | print("Craft text refiner weight will be downloaded to {}".format(weight_path)) 107 | 108 | file_utils.download(url=url, save_path=weight_path) 109 | 110 | # arange device 111 | if cuda: 112 | refine_net.load_state_dict(copyStateDict(torch_utils.load(weight_path))) 113 | 114 | refine_net = refine_net.cuda() 115 | refine_net = torch_utils.DataParallel(refine_net) 116 | torch_utils.cudnn_benchmark = False 117 | else: 118 | refine_net.load_state_dict( 119 | copyStateDict(torch_utils.load(weight_path, map_location="cpu")) 120 | ) 121 | refine_net.eval() 122 | return refine_net 123 | 124 | 125 | def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text): 126 | # prepare data 127 | linkmap = linkmap.copy() 128 | textmap = textmap.copy() 129 | img_h, img_w = textmap.shape 130 | 131 | """ labeling method """ 132 | ret, text_score = cv2.threshold(textmap, low_text, 1, 0) 133 | ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0) 134 | 135 | text_score_comb = np.clip(text_score + link_score, 0, 1) 136 | nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats( 137 | text_score_comb.astype(np.uint8), connectivity=4 138 | ) 139 | 140 | det = [] 141 | mapper = [] 142 | for k in range(1, nLabels): 143 | # size filtering 144 | size = stats[k, cv2.CC_STAT_AREA] 145 | if size < 10: 146 | continue 147 | 148 | # thresholding 149 | if np.max(textmap[labels == k]) < text_threshold: 150 | continue 151 | 152 | # make segmentation map 153 | segmap = np.zeros(textmap.shape, dtype=np.uint8) 154 | segmap[labels == k] = 255 155 | 156 | # remove link area 157 | segmap[np.logical_and(link_score == 1, text_score == 0)] = 0 158 | 159 | x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP] 160 | w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT] 161 | niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2) 162 | sx, ex, sy, ey = (x - niter, x + w + niter + 1, y - niter, y + h + niter + 1) 163 | # boundary check 164 | if sx < 0: 165 | sx = 0 166 | if sy < 0: 167 | sy = 0 168 | if ex >= img_w: 169 | ex = img_w 170 | if ey >= img_h: 171 | ey = img_h 172 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1 + niter, 1 + niter)) 173 | segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel) 174 | 175 | # make box 176 | np_temp = np.roll(np.array(np.where(segmap != 0)), 1, axis=0) 177 | np_contours = np_temp.transpose().reshape(-1, 2) 
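# descriptive note: a rotated minimum-area rectangle is fitted to the component's
# pixel coordinates below, and cv2.boxPoints converts it to 4 corner points --
# this is the "minimum bounding rectangle on the binary map" step from the README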
178 | rectangle = cv2.minAreaRect(np_contours) 179 | box = cv2.boxPoints(rectangle) 180 | 181 | # boundary check due to minAreaRect may have out of range values 182 | # (see https://docs.opencv.org/3.4/d3/dc0/group__imgproc__shape.html#ga3d476a3417130ae5154aea421ca7ead9) 183 | for p in box: 184 | if p[0] < 0: 185 | p[0] = 0 186 | if p[1] < 0: 187 | p[1] = 0 188 | if p[0] >= img_w: 189 | p[0] = img_w 190 | if p[1] >= img_h: 191 | p[1] = img_h 192 | 193 | # align diamond-shape 194 | w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2]) 195 | box_ratio = max(w, h) / (min(w, h) + 1e-5) 196 | if abs(1 - box_ratio) <= 0.1: 197 | l, r = min(np_contours[:, 0]), max(np_contours[:, 0]) 198 | t, b = min(np_contours[:, 1]), max(np_contours[:, 1]) 199 | box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32) 200 | 201 | # make clock-wise order 202 | startidx = box.sum(axis=1).argmin() 203 | box = np.roll(box, 4 - startidx, 0) 204 | box = np.array(box) 205 | 206 | det.append(box) 207 | mapper.append(k) 208 | 209 | return det, labels, mapper 210 | 211 | 212 | def getPoly_core(boxes, labels, mapper, linkmap): 213 | # configs 214 | num_cp = 5 215 | max_len_ratio = 0.7 216 | expand_ratio = 1.45 217 | max_r = 2.0 218 | step_r = 0.2 219 | 220 | polys = [] 221 | for k, box in enumerate(boxes): 222 | # size filter for small instance 223 | w, h = ( 224 | int(np.linalg.norm(box[0] - box[1]) + 1), 225 | int(np.linalg.norm(box[1] - box[2]) + 1), 226 | ) 227 | if w < 10 or h < 10: 228 | polys.append(None) 229 | continue 230 | 231 | # warp image 232 | tar = np.float32([[0, 0], [w, 0], [w, h], [0, h]]) 233 | M = cv2.getPerspectiveTransform(box, tar) 234 | word_label = cv2.warpPerspective(labels, M, (w, h), flags=cv2.INTER_NEAREST) 235 | try: 236 | Minv = np.linalg.inv(M) 237 | except: 238 | polys.append(None) 239 | continue 240 | 241 | # binarization for selected label 242 | cur_label = mapper[k] 243 | word_label[word_label != cur_label] = 0 244 | word_label[word_label > 0] = 1 245 | 246 | """ Polygon generation """ 247 | # find top/bottom contours 248 | cp = [] 249 | max_len = -1 250 | for i in range(w): 251 | region = np.where(word_label[:, i] != 0)[0] 252 | if len(region) < 2: 253 | continue 254 | cp.append((i, region[0], region[-1])) 255 | length = region[-1] - region[0] + 1 256 | if length > max_len: 257 | max_len = length 258 | 259 | # pass if max_len is similar to h 260 | if h * max_len_ratio < max_len: 261 | polys.append(None) 262 | continue 263 | 264 | # get pivot points with fixed length 265 | tot_seg = num_cp * 2 + 1 266 | seg_w = w / tot_seg # segment width 267 | pp = [None] * num_cp # init pivot points 268 | cp_section = [[0, 0]] * tot_seg 269 | seg_height = [0] * num_cp 270 | seg_num = 0 271 | num_sec = 0 272 | prev_h = -1 273 | for i in range(0, len(cp)): 274 | (x, sy, ey) = cp[i] 275 | if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg: 276 | # average previous segment 277 | if num_sec == 0: 278 | break 279 | cp_section[seg_num] = [ 280 | cp_section[seg_num][0] / num_sec, 281 | cp_section[seg_num][1] / num_sec, 282 | ] 283 | num_sec = 0 284 | 285 | # reset variables 286 | seg_num += 1 287 | prev_h = -1 288 | 289 | # accumulate center points 290 | cy = (sy + ey) * 0.5 291 | cur_h = ey - sy + 1 292 | cp_section[seg_num] = [ 293 | cp_section[seg_num][0] + x, 294 | cp_section[seg_num][1] + cy, 295 | ] 296 | num_sec += 1 297 | 298 | if seg_num % 2 == 0: 299 | continue # No polygon area 300 | 301 | if prev_h < cur_h: 302 | pp[int((seg_num - 1) / 2)] = (x, cy) 303 | 
seg_height[int((seg_num - 1) / 2)] = cur_h 304 | prev_h = cur_h 305 | 306 | # processing last segment 307 | if num_sec != 0: 308 | cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec] 309 | 310 | # pass if num of pivots is not sufficient or segment width 311 | # is smaller than character height 312 | if None in pp or seg_w < np.max(seg_height) * 0.25: 313 | polys.append(None) 314 | continue 315 | 316 | # calc median height of pivot points 317 | half_char_h = np.median(seg_height) * expand_ratio / 2 318 | 319 | # calc gradient and apply to make horizontal pivots 320 | new_pp = [] 321 | for i, (x, cy) in enumerate(pp): 322 | dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0] 323 | dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1] 324 | if dx == 0: # gradient is zero 325 | new_pp.append([x, cy - half_char_h, x, cy + half_char_h]) 326 | continue 327 | rad = -math.atan2(dy, dx) 328 | c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad) 329 | new_pp.append([x - s, cy - c, x + s, cy + c]) 330 | 331 | # get edge points to cover character heatmaps 332 | isSppFound, isEppFound = False, False 333 | grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + ( 334 | pp[2][1] - pp[1][1] 335 | ) / (pp[2][0] - pp[1][0]) 336 | grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + ( 337 | pp[-3][1] - pp[-2][1] 338 | ) / (pp[-3][0] - pp[-2][0]) 339 | for r in np.arange(0.5, max_r, step_r): 340 | dx = 2 * half_char_h * r 341 | if not isSppFound: 342 | line_img = np.zeros(word_label.shape, dtype=np.uint8) 343 | dy = grad_s * dx 344 | p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy]) 345 | cv2.line( 346 | line_img, 347 | (int(p[0]), int(p[1])), 348 | (int(p[2]), int(p[3])), 349 | 1, 350 | thickness=1, 351 | ) 352 | if ( 353 | np.sum(np.logical_and(word_label, line_img)) == 0 354 | or r + 2 * step_r >= max_r 355 | ): 356 | spp = p 357 | isSppFound = True 358 | if not isEppFound: 359 | line_img = np.zeros(word_label.shape, dtype=np.uint8) 360 | dy = grad_e * dx 361 | p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy]) 362 | cv2.line( 363 | line_img, 364 | (int(p[0]), int(p[1])), 365 | (int(p[2]), int(p[3])), 366 | 1, 367 | thickness=1, 368 | ) 369 | if ( 370 | np.sum(np.logical_and(word_label, line_img)) == 0 371 | or r + 2 * step_r >= max_r 372 | ): 373 | epp = p 374 | isEppFound = True 375 | if isSppFound and isEppFound: 376 | break 377 | 378 | # pass if boundary of polygon is not found 379 | if not (isSppFound and isEppFound): 380 | polys.append(None) 381 | continue 382 | 383 | # make final polygon 384 | poly = [] 385 | poly.append(warpCoord(Minv, (spp[0], spp[1]))) 386 | for p in new_pp: 387 | poly.append(warpCoord(Minv, (p[0], p[1]))) 388 | poly.append(warpCoord(Minv, (epp[0], epp[1]))) 389 | poly.append(warpCoord(Minv, (epp[2], epp[3]))) 390 | for p in reversed(new_pp): 391 | poly.append(warpCoord(Minv, (p[2], p[3]))) 392 | poly.append(warpCoord(Minv, (spp[2], spp[3]))) 393 | 394 | # add to final result 395 | polys.append(np.array(poly)) 396 | 397 | return polys 398 | 399 | 400 | def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False): 401 | boxes, labels, mapper = getDetBoxes_core( 402 | textmap, linkmap, text_threshold, link_threshold, low_text 403 | ) 404 | 405 | if poly: 406 | polys = getPoly_core(boxes, labels, mapper, linkmap) 407 | else: 408 | polys = [None] * len(boxes) 409 | 410 | return boxes, polys 411 | 412 | 413 | def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2): 414 | if len(polys) > 0:
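# descriptive note: detections are mapped back to original-image coordinates here;
# the score maps are produced at half the resized input resolution, hence the
# default ratio_net=2 factor on top of the resize ratios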
polys = np.array(polys) 416 | for k in range(len(polys)): 417 | if polys[k] is not None: 418 | polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net) 419 | return polys 420 | -------------------------------------------------------------------------------- /craft_text_detector/file_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import cv2 5 | import gdown 6 | import numpy as np 7 | 8 | from craft_text_detector.image_utils import read_image 9 | 10 | 11 | def download(url: str, save_path: str): 12 | """ 13 | Downloads file from gdrive, shows progress. 14 | Example inputs: 15 | url: 'ftp://smartengines.com/midv-500/dataset/01_alb_id.zip' 16 | save_path: 'data/file.zip' 17 | """ 18 | 19 | # create save_dir if not present 20 | create_dir(os.path.dirname(save_path)) 21 | # download file 22 | gdown.download(url, save_path, quiet=False) 23 | 24 | 25 | def create_dir(_dir): 26 | """ 27 | Creates given directory if it is not present. 28 | """ 29 | if not os.path.exists(_dir): 30 | os.makedirs(_dir) 31 | 32 | 33 | def get_files(img_dir): 34 | imgs, masks, xmls = list_files(img_dir) 35 | return imgs, masks, xmls 36 | 37 | 38 | def list_files(in_path): 39 | img_files = [] 40 | mask_files = [] 41 | gt_files = [] 42 | for (dirpath, dirnames, filenames) in os.walk(in_path): 43 | for file in filenames: 44 | filename, ext = os.path.splitext(file) 45 | ext = str.lower(ext) 46 | if ( 47 | ext == ".jpg" 48 | or ext == ".jpeg" 49 | or ext == ".gif" 50 | or ext == ".png" 51 | or ext == ".pgm" 52 | ): 53 | img_files.append(os.path.join(dirpath, file)) 54 | elif ext == ".bmp": 55 | mask_files.append(os.path.join(dirpath, file)) 56 | elif ext == ".xml" or ext == ".gt" or ext == ".txt": 57 | gt_files.append(os.path.join(dirpath, file)) 58 | elif ext == ".zip": 59 | continue 60 | # img_files.sort() 61 | # mask_files.sort() 62 | # gt_files.sort() 63 | return img_files, mask_files, gt_files 64 | 65 | 66 | def rectify_poly(img, poly): 67 | # Use Affine transform 68 | n = int(len(poly) / 2) - 1 69 | width = 0 70 | height = 0 71 | for k in range(n): 72 | box = np.float32([poly[k], poly[k + 1], poly[-k - 2], poly[-k - 1]]) 73 | width += int( 74 | (np.linalg.norm(box[0] - box[1]) + np.linalg.norm(box[2] - box[3])) / 2 75 | ) 76 | height += np.linalg.norm(box[1] - box[2]) 77 | width = int(width) 78 | height = int(height / n) 79 | 80 | output_img = np.zeros((height, width, 3), dtype=np.uint8) 81 | width_step = 0 82 | for k in range(n): 83 | box = np.float32([poly[k], poly[k + 1], poly[-k - 2], poly[-k - 1]]) 84 | w = int((np.linalg.norm(box[0] - box[1]) + np.linalg.norm(box[2] - box[3])) / 2) 85 | 86 | # Top triangle 87 | pts1 = box[:3] 88 | pts2 = np.float32( 89 | [[width_step, 0], [width_step + w - 1, 0], [width_step + w - 1, height - 1]] 90 | ) 91 | M = cv2.getAffineTransform(pts1, pts2) 92 | warped_img = cv2.warpAffine( 93 | img, M, (width, height), borderMode=cv2.BORDER_REPLICATE 94 | ) 95 | warped_mask = np.zeros((height, width, 3), dtype=np.uint8) 96 | warped_mask = cv2.fillConvexPoly(warped_mask, np.int32(pts2), (1, 1, 1)) 97 | output_img[warped_mask == 1] = warped_img[warped_mask == 1] 98 | 99 | # Bottom triangle 100 | pts1 = np.vstack((box[0], box[2:])) 101 | pts2 = np.float32( 102 | [ 103 | [width_step, 0], 104 | [width_step + w - 1, height - 1], 105 | [width_step, height - 1], 106 | ] 107 | ) 108 | M = cv2.getAffineTransform(pts1, pts2) 109 | warped_img = cv2.warpAffine( 110 | img, M, (width, height), borderMode=cv2.BORDER_REPLICATE 
111 | ) 112 | warped_mask = np.zeros((height, width, 3), dtype=np.uint8) 113 | warped_mask = cv2.fillConvexPoly(warped_mask, np.int32(pts2), (1, 1, 1)) 114 | cv2.line( 115 | warped_mask, (width_step, 0), (width_step + w - 1, height - 1), (0, 0, 0), 1 116 | ) 117 | output_img[warped_mask == 1] = warped_img[warped_mask == 1] 118 | 119 | width_step += w 120 | return output_img 121 | 122 | 123 | def crop_poly(image, poly): 124 | # points should have 1*x*2 shape 125 | if len(poly.shape) == 2: 126 | poly = np.array([np.array(poly).astype(np.int32)]) 127 | 128 | # create mask with shape of image 129 | mask = np.zeros(image.shape[0:2], dtype=np.uint8) 130 | 131 | # method 1 smooth region 132 | cv2.drawContours(mask, [poly], -1, (255, 255, 255), -1, cv2.LINE_AA) 133 | # method 2 not so smooth region 134 | # cv2.fillPoly(mask, points, (255)) 135 | 136 | # crop around poly 137 | res = cv2.bitwise_and(image, image, mask=mask) 138 | rect = cv2.boundingRect(poly) # returns (x,y,w,h) of the rect 139 | cropped = res[rect[1] : rect[1] + rect[3], rect[0] : rect[0] + rect[2]] 140 | 141 | return cropped 142 | 143 | 144 | def export_detected_region(image, poly, file_path, rectify=True): 145 | """ 146 | Arguments: 147 | image: full image 148 | poly: bbox or poly points 149 | file_path: path to be exported 150 | rectify: rectify detected polygon by affine transform 151 | """ 152 | if rectify: 153 | # rectify poly region 154 | result_rgb = rectify_poly(image, poly) 155 | else: 156 | result_rgb = crop_poly(image, poly) 157 | 158 | # export cropped region 159 | result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR) 160 | cv2.imwrite(file_path, result_bgr) 161 | 162 | 163 | def export_detected_regions( 164 | image, 165 | regions, 166 | file_name: str = "image", 167 | output_dir: str = "output/", 168 | rectify: bool = False, 169 | ): 170 | """ 171 | Arguments: 172 | image: path to the image to be processed or numpy array or PIL image 173 | regions: list of bboxes or polys 174 | file_name (str): export image file name 175 | output_dir: folder to be exported 176 | rectify: rectify detected polygon by affine transform 177 | """ 178 | 179 | # read/convert image 180 | image = read_image(image) 181 | 182 | # deepcopy image so that original is not altered 183 | image = copy.deepcopy(image) 184 | 185 | # create crops dir 186 | crops_dir = os.path.join(output_dir, file_name + "_crops") 187 | create_dir(crops_dir) 188 | 189 | # init exported file paths 190 | exported_file_paths = [] 191 | 192 | # export regions 193 | for ind, region in enumerate(regions): 194 | # get export path 195 | file_path = os.path.join(crops_dir, "crop_" + str(ind) + ".png") 196 | # export region 197 | export_detected_region(image, poly=region, file_path=file_path, rectify=rectify) 198 | # note exported file path 199 | exported_file_paths.append(file_path) 200 | 201 | return exported_file_paths 202 | 203 | 204 | def export_extra_results( 205 | image, 206 | regions, 207 | heatmaps, 208 | file_name: str = "image", 209 | output_dir="output/", 210 | verticals=None, 211 | texts=None, 212 | ): 213 | """save text detection results one by one 214 | Args: 215 | image: path to the image to be processed or numpy array or PIL image 216 | file_name (str): export image file name 217 | regions (array): array of detected regions 218 | Shape: [num_detections, 4] for BB output / [num_detections, 4, 2] 219 | for QUAD output 220 | Return: 221 | None 222 | """ 223 | # read/convert image 224 | image = read_image(image) 225 | 226 | # result directory 227 | res_file =
os.path.join(output_dir, file_name + "_text_detection.txt") 228 | res_img_file = os.path.join(output_dir, file_name + "_text_detection.png") 229 | text_heatmap_file = os.path.join(output_dir, file_name + "_text_score_heatmap.png") 230 | link_heatmap_file = os.path.join(output_dir, file_name + "_link_score_heatmap.png") 231 | 232 | # create output dir 233 | create_dir(output_dir) 234 | 235 | # export heatmaps 236 | cv2.imwrite(text_heatmap_file, heatmaps["text_score_heatmap"]) 237 | cv2.imwrite(link_heatmap_file, heatmaps["link_score_heatmap"]) 238 | 239 | with open(res_file, "w") as f: 240 | for i, region in enumerate(regions): 241 | region = np.array(region).astype(np.int32).reshape((-1)) 242 | strResult = ",".join([str(r) for r in region]) + "\r\n" 243 | f.write(strResult) 244 | 245 | region = region.reshape(-1, 2) 246 | cv2.polylines( 247 | image, 248 | [region.reshape((-1, 1, 2))], 249 | True, 250 | color=(0, 0, 255), 251 | thickness=2, 252 | ) 253 | 254 | if texts is not None: 255 | font = cv2.FONT_HERSHEY_SIMPLEX 256 | font_scale = 0.5 257 | cv2.putText( 258 | image, 259 | "{}".format(texts[i]), 260 | (region[0][0] + 1, region[0][1] + 1), 261 | font, 262 | font_scale, 263 | (0, 0, 0), 264 | thickness=1, 265 | ) 266 | cv2.putText( 267 | image, 268 | "{}".format(texts[i]), 269 | tuple(region[0]), 270 | font, 271 | font_scale, 272 | (0, 255, 255), 273 | thickness=1, 274 | ) 275 | 276 | # Save result image 277 | cv2.imwrite(res_img_file, cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) 278 | -------------------------------------------------------------------------------- /craft_text_detector/image_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 3 | MIT License 4 | """ 5 | 6 | import cv2 7 | import numpy as np 8 | 9 | 10 | def read_image(image): 11 | if type(image) == str: 12 | img = cv2.imread(image) 13 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 14 | 15 | elif type(image) == bytes: 16 | nparr = np.frombuffer(image, np.uint8) 17 | img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) 18 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 19 | 20 | elif type(image) == np.ndarray: 21 | if len(image.shape) == 2: # grayscale 22 | img = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) 23 | elif len(image.shape) == 3 and image.shape[2] == 3: 24 | img = image 25 | elif len(image.shape) == 3 and image.shape[2] == 4: # RGBAscale 26 | img = image[:, :, :3] 27 | 28 | return img 29 | 30 | 31 | def normalizeMeanVariance( 32 | in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225) 33 | ): 34 | # should be RGB order 35 | img = in_img.copy().astype(np.float32) 36 | 37 | img -= np.array( 38 | [mean[0] * 255.0, mean[1] * 255.0, mean[2] * 255.0], dtype=np.float32 39 | ) 40 | img /= np.array( 41 | [variance[0] * 255.0, variance[1] * 255.0, variance[2] * 255.0], 42 | dtype=np.float32, 43 | ) 44 | return img 45 | 46 | 47 | def denormalizeMeanVariance( 48 | in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225) 49 | ): 50 | # should be RGB order 51 | img = in_img.copy() 52 | img *= variance 53 | img += mean 54 | img *= 255.0 55 | img = np.clip(img, 0, 255).astype(np.uint8) 56 | return img 57 | 58 | 59 | def resize_aspect_ratio(img, long_size, interpolation): 60 | height, width, channel = img.shape 61 | 62 | # set target image size 63 | target_size = long_size 64 | 65 | ratio = target_size / max(height, width) 66 | 67 | target_h, target_w = int(height * ratio), int(width * ratio) 68 | proc = cv2.resize(img, (target_w, target_h), 
interpolation=interpolation) 69 | 70 | # make canvas and paste image 71 | target_h32, target_w32 = target_h, target_w 72 | if target_h % 32 != 0: 73 | target_h32 = target_h + (32 - target_h % 32) 74 | if target_w % 32 != 0: 75 | target_w32 = target_w + (32 - target_w % 32) 76 | resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32) 77 | resized[0:target_h, 0:target_w, :] = proc 78 | target_h, target_w = target_h32, target_w32 79 | 80 | size_heatmap = (int(target_w / 2), int(target_h / 2)) 81 | 82 | return resized, ratio, size_heatmap 83 | 84 | 85 | def cvt2HeatmapImg(img): 86 | img = (np.clip(img, 0, 1) * 255).astype(np.uint8) 87 | img = cv2.applyColorMap(img, cv2.COLORMAP_JET) 88 | return img 89 | -------------------------------------------------------------------------------- /craft_text_detector/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fcakyon/craft-text-detector/6b10d3e0d178679e35796cf4426abb083b354239/craft_text_detector/models/__init__.py -------------------------------------------------------------------------------- /craft_text_detector/models/basenet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fcakyon/craft-text-detector/6b10d3e0d178679e35796cf4426abb083b354239/craft_text_detector/models/basenet/__init__.py -------------------------------------------------------------------------------- /craft_text_detector/models/basenet/vgg16_bn.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | from torchvision import models 7 | from torchvision.models.vgg import model_urls 8 | 9 | 10 | def init_weights(modules): 11 | for m in modules: 12 | if isinstance(m, nn.Conv2d): 13 | init.xavier_uniform_(m.weight.data) 14 | if m.bias is not None: 15 | m.bias.data.zero_() 16 | elif isinstance(m, nn.BatchNorm2d): 17 | m.weight.data.fill_(1) 18 | m.bias.data.zero_() 19 | elif isinstance(m, nn.Linear): 20 | m.weight.data.normal_(0, 0.01) 21 | m.bias.data.zero_() 22 | 23 | 24 | class vgg16_bn(torch.nn.Module): 25 | def __init__(self, pretrained=True, freeze=True): 26 | super(vgg16_bn, self).__init__() 27 | model_urls["vgg16_bn"] = model_urls["vgg16_bn"].replace("https://", "http://") 28 | vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features 29 | self.slice1 = torch.nn.Sequential() 30 | self.slice2 = torch.nn.Sequential() 31 | self.slice3 = torch.nn.Sequential() 32 | self.slice4 = torch.nn.Sequential() 33 | self.slice5 = torch.nn.Sequential() 34 | for x in range(12): # conv2_2 35 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 36 | for x in range(12, 19): # conv3_3 37 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 38 | for x in range(19, 29): # conv4_3 39 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 40 | for x in range(29, 39): # conv5_3 41 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 42 | 43 | # fc6, fc7 without atrous conv 44 | self.slice5 = torch.nn.Sequential( 45 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1), 46 | nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6), 47 | nn.Conv2d(1024, 1024, kernel_size=1), 48 | ) 49 | 50 | if not pretrained: 51 | init_weights(self.slice1.modules()) 52 | init_weights(self.slice2.modules()) 53 | init_weights(self.slice3.modules()) 54 | 
init_weights(self.slice4.modules()) 55 | 56 | init_weights(self.slice5.modules()) # no pretrained model for fc6 and fc7 57 | 58 | if freeze: 59 | for param in self.slice1.parameters(): # only first conv 60 | param.requires_grad = False 61 | 62 | def forward(self, X): 63 | h = self.slice1(X) 64 | h_relu2_2 = h 65 | h = self.slice2(h) 66 | h_relu3_2 = h 67 | h = self.slice3(h) 68 | h_relu4_3 = h 69 | h = self.slice4(h) 70 | h_relu5_3 = h 71 | h = self.slice5(h) 72 | h_fc7 = h 73 | vgg_outputs = namedtuple( 74 | "VggOutputs", ["fc7", "relu5_3", "relu4_3", "relu3_2", "relu2_2"] 75 | ) 76 | out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2) 77 | return out 78 | -------------------------------------------------------------------------------- /craft_text_detector/models/craftnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 3 | MIT License 4 | """ 5 | 6 | # -*- coding: utf-8 -*- 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from craft_text_detector.models.basenet.vgg16_bn import vgg16_bn, init_weights 12 | 13 | 14 | class double_conv(nn.Module): 15 | def __init__(self, in_ch, mid_ch, out_ch): 16 | super(double_conv, self).__init__() 17 | self.conv = nn.Sequential( 18 | nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1), 19 | nn.BatchNorm2d(mid_ch), 20 | nn.ReLU(inplace=True), 21 | nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1), 22 | nn.BatchNorm2d(out_ch), 23 | nn.ReLU(inplace=True), 24 | ) 25 | 26 | def forward(self, x): 27 | x = self.conv(x) 28 | return x 29 | 30 | 31 | class CraftNet(nn.Module): 32 | def __init__(self, pretrained=False, freeze=False): 33 | super(CraftNet, self).__init__() 34 | 35 | """ Base network """ 36 | self.basenet = vgg16_bn(pretrained, freeze) 37 | 38 | """ U network """ 39 | self.upconv1 = double_conv(1024, 512, 256) 40 | self.upconv2 = double_conv(512, 256, 128) 41 | self.upconv3 = double_conv(256, 128, 64) 42 | self.upconv4 = double_conv(128, 64, 32) 43 | 44 | num_class = 2 45 | self.conv_cls = nn.Sequential( 46 | nn.Conv2d(32, 32, kernel_size=3, padding=1), 47 | nn.ReLU(inplace=True), 48 | nn.Conv2d(32, 32, kernel_size=3, padding=1), 49 | nn.ReLU(inplace=True), 50 | nn.Conv2d(32, 16, kernel_size=3, padding=1), 51 | nn.ReLU(inplace=True), 52 | nn.Conv2d(16, 16, kernel_size=1), 53 | nn.ReLU(inplace=True), 54 | nn.Conv2d(16, num_class, kernel_size=1), 55 | ) 56 | 57 | init_weights(self.upconv1.modules()) 58 | init_weights(self.upconv2.modules()) 59 | init_weights(self.upconv3.modules()) 60 | init_weights(self.upconv4.modules()) 61 | init_weights(self.conv_cls.modules()) 62 | 63 | def forward(self, x): 64 | """ Base network """ 65 | sources = self.basenet(x) 66 | 67 | """ U network """ 68 | y = torch.cat([sources[0], sources[1]], dim=1) 69 | y = self.upconv1(y) 70 | 71 | y = F.interpolate( 72 | y, size=sources[2].size()[2:], mode="bilinear", align_corners=False 73 | ) 74 | y = torch.cat([y, sources[2]], dim=1) 75 | y = self.upconv2(y) 76 | 77 | y = F.interpolate( 78 | y, size=sources[3].size()[2:], mode="bilinear", align_corners=False 79 | ) 80 | y = torch.cat([y, sources[3]], dim=1) 81 | y = self.upconv3(y) 82 | 83 | y = F.interpolate( 84 | y, size=sources[4].size()[2:], mode="bilinear", align_corners=False 85 | ) 86 | y = torch.cat([y, sources[4]], dim=1) 87 | feature = self.upconv4(y) 88 | 89 | y = self.conv_cls(feature) 90 | 91 | return y.permute(0, 2, 3, 1), feature 92 | 93 | 94 | if __name__ == "__main__": 95 | 
model = CraftNet(pretrained=True).cuda() 96 | output, _ = model(torch.randn(1, 3, 768, 768).cuda()) 97 | print(output.shape) 98 | -------------------------------------------------------------------------------- /craft_text_detector/models/refinenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 3 | MIT License 4 | """ 5 | 6 | # -*- coding: utf-8 -*- 7 | import torch 8 | import torch.nn as nn 9 | from craft_text_detector.models.basenet.vgg16_bn import init_weights 10 | 11 | 12 | class RefineNet(nn.Module): 13 | def __init__(self): 14 | super(RefineNet, self).__init__() 15 | 16 | self.last_conv = nn.Sequential( 17 | nn.Conv2d(34, 64, kernel_size=3, padding=1), 18 | nn.BatchNorm2d(64), 19 | nn.ReLU(inplace=True), 20 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 21 | nn.BatchNorm2d(64), 22 | nn.ReLU(inplace=True), 23 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 24 | nn.BatchNorm2d(64), 25 | nn.ReLU(inplace=True), 26 | ) 27 | 28 | self.aspp1 = nn.Sequential( 29 | nn.Conv2d(64, 128, kernel_size=3, dilation=6, padding=6), 30 | nn.BatchNorm2d(128), 31 | nn.ReLU(inplace=True), 32 | nn.Conv2d(128, 128, kernel_size=1), 33 | nn.BatchNorm2d(128), 34 | nn.ReLU(inplace=True), 35 | nn.Conv2d(128, 1, kernel_size=1), 36 | ) 37 | 38 | self.aspp2 = nn.Sequential( 39 | nn.Conv2d(64, 128, kernel_size=3, dilation=12, padding=12), 40 | nn.BatchNorm2d(128), 41 | nn.ReLU(inplace=True), 42 | nn.Conv2d(128, 128, kernel_size=1), 43 | nn.BatchNorm2d(128), 44 | nn.ReLU(inplace=True), 45 | nn.Conv2d(128, 1, kernel_size=1), 46 | ) 47 | 48 | self.aspp3 = nn.Sequential( 49 | nn.Conv2d(64, 128, kernel_size=3, dilation=18, padding=18), 50 | nn.BatchNorm2d(128), 51 | nn.ReLU(inplace=True), 52 | nn.Conv2d(128, 128, kernel_size=1), 53 | nn.BatchNorm2d(128), 54 | nn.ReLU(inplace=True), 55 | nn.Conv2d(128, 1, kernel_size=1), 56 | ) 57 | 58 | self.aspp4 = nn.Sequential( 59 | nn.Conv2d(64, 128, kernel_size=3, dilation=24, padding=24), 60 | nn.BatchNorm2d(128), 61 | nn.ReLU(inplace=True), 62 | nn.Conv2d(128, 128, kernel_size=1), 63 | nn.BatchNorm2d(128), 64 | nn.ReLU(inplace=True), 65 | nn.Conv2d(128, 1, kernel_size=1), 66 | ) 67 | 68 | init_weights(self.last_conv.modules()) 69 | init_weights(self.aspp1.modules()) 70 | init_weights(self.aspp2.modules()) 71 | init_weights(self.aspp3.modules()) 72 | init_weights(self.aspp4.modules()) 73 | 74 | def forward(self, y, upconv4): 75 | refine = torch.cat([y.permute(0, 3, 1, 2), upconv4], dim=1) 76 | refine = self.last_conv(refine) 77 | 78 | aspp1 = self.aspp1(refine) 79 | aspp2 = self.aspp2(refine) 80 | aspp3 = self.aspp3(refine) 81 | aspp4 = self.aspp4(refine) 82 | 83 | # out = torch.add([aspp1, aspp2, aspp3, aspp4], dim=1) 84 | out = aspp1 + aspp2 + aspp3 + aspp4 85 | return out.permute(0, 2, 3, 1) # , refine.permute(0,2,3,1) 86 | -------------------------------------------------------------------------------- /craft_text_detector/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | import craft_text_detector.craft_utils as craft_utils 8 | import craft_text_detector.image_utils as image_utils 9 | import craft_text_detector.torch_utils as torch_utils 10 | 11 | 12 | def get_prediction( 13 | image, 14 | craft_net, 15 | refine_net=None, 16 | text_threshold: float = 0.7, 17 | link_threshold: float = 0.4, 18 | low_text: float = 0.4, 19 | cuda: bool = False, 20 | long_size: int = 1280, 21 
| poly: bool = True, 22 | ): 23 | """ 24 | Arguments: 25 | image: path to the image to be processed or numpy array or PIL image 27 | craft_net: craft net model 28 | refine_net: refine net model 29 | text_threshold: text confidence threshold 30 | link_threshold: link confidence threshold 31 | low_text: text low-bound score 32 | cuda: Use cuda for inference 34 | long_size: desired longest image size for inference 35 | poly: enable polygon type 36 | Output: 37 | {"boxes": list of coords of points of predicted boxes, 38 | "boxes_as_ratios": list of coords of points of predicted boxes as ratios of image size, 39 | "polys": list of coords of points of predicted polys, 40 | "polys_as_ratios": list of coords of points of predicted polys as ratios of image size, 41 | "heatmaps": visualizations of the detected characters/links, 42 | "times": elapsed times of the sub modules, in seconds} 43 | """ 44 | t0 = time.time() 45 | 46 | # read/convert image 47 | image = image_utils.read_image(image) 48 | 49 | # resize 50 | img_resized, target_ratio, size_heatmap = image_utils.resize_aspect_ratio( 51 | image, long_size, interpolation=cv2.INTER_LINEAR 52 | ) 53 | ratio_h = ratio_w = 1 / target_ratio 54 | resize_time = time.time() - t0 55 | t0 = time.time() 56 | 57 | # preprocessing 58 | x = image_utils.normalizeMeanVariance(img_resized) 59 | x = torch_utils.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w] 60 | x = torch_utils.Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w] 61 | if cuda: 62 | x = x.cuda() 63 | preprocessing_time = time.time() - t0 64 | t0 = time.time() 65 | 66 | # forward pass 67 | with torch_utils.no_grad(): 68 | y, feature = craft_net(x) 69 | craftnet_time = time.time() - t0 70 | t0 = time.time() 71 | 72 | # make score and link map 73 | score_text = y[0, :, :, 0].cpu().data.numpy() 74 | score_link = y[0, :, :, 1].cpu().data.numpy() 75 | 76 | # refine link 77 | if refine_net is not None: 78 | with torch_utils.no_grad(): 79 | y_refiner = refine_net(y, feature) 80 | score_link = y_refiner[0, :, :, 0].cpu().data.numpy() 81 | refinenet_time = time.time() - t0 82 | t0 = time.time() 83 | 84 | # Post-processing 85 | boxes, polys = craft_utils.getDetBoxes( 86 | score_text, score_link, text_threshold, link_threshold, low_text, poly 87 | ) 88 | 89 | # coordinate adjustment 90 | boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h) 91 | polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h) 92 | for k in range(len(polys)): 93 | if polys[k] is None: 94 | polys[k] = boxes[k] 95 | 96 | # get image size 97 | img_height = image.shape[0] 98 | img_width = image.shape[1] 99 | 100 | # calculate box coords as ratios to image size 101 | boxes_as_ratio = [] 102 | for box in boxes: 103 | boxes_as_ratio.append(box / [img_width, img_height]) 104 | boxes_as_ratio = np.array(boxes_as_ratio) 105 | 106 | # calculate poly coords as ratios to image size 107 | polys_as_ratio = [] 108 | for poly in polys: 109 | polys_as_ratio.append(poly / [img_width, img_height]) 110 | polys_as_ratio = np.array(polys_as_ratio) 111 | 112 | text_score_heatmap = image_utils.cvt2HeatmapImg(score_text) 113 | link_score_heatmap = image_utils.cvt2HeatmapImg(score_link) 114 | 115 | postprocess_time = time.time() - t0 116 | 117 | times = { 118 | "resize_time": resize_time, 119 | "preprocessing_time": preprocessing_time, 120 | "craftnet_time": craftnet_time, 121 | "refinenet_time": refinenet_time, 122 | 
"postprocess_time": postprocess_time, 123 | } 124 | 125 | return { 126 | "boxes": boxes, 127 | "boxes_as_ratios": boxes_as_ratio, 128 | "polys": polys, 129 | "polys_as_ratios": polys_as_ratio, 130 | "heatmaps": { 131 | "text_score_heatmap": text_score_heatmap, 132 | "link_score_heatmap": link_score_heatmap, 133 | }, 134 | "times": times, 135 | } 136 | -------------------------------------------------------------------------------- /craft_text_detector/torch_utils.py: -------------------------------------------------------------------------------- 1 | from torch import from_numpy, load, no_grad 2 | from torch.autograd import Variable 3 | from torch.backends.cudnn import benchmark as cudnn_benchmark 4 | from torch.cuda import empty_cache as empty_cuda_cache 5 | from torch.nn import DataParallel 6 | -------------------------------------------------------------------------------- /figures/craft_example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fcakyon/craft-text-detector/6b10d3e0d178679e35796cf4426abb083b354239/figures/craft_example.gif -------------------------------------------------------------------------------- /figures/idcard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fcakyon/craft-text-detector/6b10d3e0d178679e35796cf4426abb083b354239/figures/idcard.png -------------------------------------------------------------------------------- /figures/idcard2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fcakyon/craft-text-detector/6b10d3e0d178679e35796cf4426abb083b354239/figures/idcard2.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.6.0 2 | torchvision>=0.7.0 3 | opencv-python>=3.4.8.29,<4.5.4.62 4 | scipy>=1.3.2 5 | gdown>=3.10.1 6 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | exclude =.git,__pycache__,docs/source/conf.py,build,dist 4 | ignore = I101,I201,F401,F403,S001,D100,D101,D102,D103,D104,D105,D106,D107,D200,D205,D400,W504,D202,E203,W503,B006 5 | inline-quotes = " 6 | 7 | [mypy] 8 | ignore_missing_imports = True 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import re 4 | 5 | import setuptools 6 | 7 | 8 | def get_long_description(): 9 | base_dir = os.path.abspath(os.path.dirname(__file__)) 10 | with io.open(os.path.join(base_dir, "README.md"), encoding="utf-8") as f: 11 | return f.read() 12 | 13 | 14 | def get_requirements(): 15 | with open("requirements.txt") as f: 16 | return f.read().splitlines() 17 | 18 | 19 | def get_version(): 20 | current_dir = os.path.abspath(os.path.dirname(__file__)) 21 | version_file = os.path.join(current_dir, "craft_text_detector", "__init__.py") 22 | with io.open(version_file, encoding="utf-8") as f: 23 | return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', f.read(), re.M).group(1) 24 | 25 | 26 | setuptools.setup( 27 | name="craft-text-detector", 28 | version=get_version(), 29 | author="Fatih Cagatay Akyon", 30 | license="MIT", 31 | description="Fast and 
--------------------------------------------------------------------------------
/figures/craft_example.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fcakyon/craft-text-detector/6b10d3e0d178679e35796cf4426abb083b354239/figures/craft_example.gif
--------------------------------------------------------------------------------
/figures/idcard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fcakyon/craft-text-detector/6b10d3e0d178679e35796cf4426abb083b354239/figures/idcard.png
--------------------------------------------------------------------------------
/figures/idcard2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fcakyon/craft-text-detector/6b10d3e0d178679e35796cf4426abb083b354239/figures/idcard2.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.6.0
2 | torchvision>=0.7.0
3 | opencv-python>=3.4.8.29,<4.5.4.62
4 | scipy>=1.3.2
5 | gdown>=3.10.1
6 | 
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 | exclude = .git,__pycache__,docs/source/conf.py,build,dist
4 | ignore = I101,I201,F401,F403,S001,D100,D101,D102,D103,D104,D105,D106,D107,D200,D205,D400,W504,D202,E203,W503,B006
5 | inline-quotes = "
6 | 
7 | [mypy]
8 | ignore_missing_imports = True
9 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 | import re
4 | 
5 | import setuptools
6 | 
7 | 
8 | def get_long_description():
9 |     base_dir = os.path.abspath(os.path.dirname(__file__))
10 |     with io.open(os.path.join(base_dir, "README.md"), encoding="utf-8") as f:
11 |         return f.read()
12 | 
13 | 
14 | def get_requirements():
15 |     with open("requirements.txt") as f:
16 |         return f.read().splitlines()
17 | 
18 | 
19 | def get_version():
20 |     current_dir = os.path.abspath(os.path.dirname(__file__))
21 |     version_file = os.path.join(current_dir, "craft_text_detector", "__init__.py")
22 |     with io.open(version_file, encoding="utf-8") as f:
23 |         return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', f.read(), re.M).group(1)
24 | 
25 | 
26 | setuptools.setup(
27 |     name="craft-text-detector",
28 |     version=get_version(),
29 |     author="Fatih Cagatay Akyon",
30 |     license="MIT",
31 |     description="Fast and accurate text detection library built on CRAFT implementation",
32 |     long_description=get_long_description(),
33 |     long_description_content_type="text/markdown",
34 |     url="https://github.com/fcakyon/craft_text_detector",
35 |     packages=setuptools.find_packages(exclude=["tests"]),
36 |     install_requires=get_requirements(),
37 |     python_requires=">=3.7",
38 |     classifiers=[
39 |         "Development Status :: 5 - Production/Stable",
40 |         "License :: OSI Approved :: MIT License",
41 |         "Operating System :: OS Independent",
42 |         "Intended Audience :: Developers",
43 |         "Intended Audience :: Science/Research",
44 |         "Programming Language :: Python :: 3",
45 |         "Programming Language :: Python :: 3.7",
46 |         "Programming Language :: Python :: 3.8",
47 |         "Programming Language :: Python :: 3.9",
48 |         "Topic :: Software Development :: Libraries",
49 |         "Topic :: Software Development :: Libraries :: Python Modules",
50 |         "Topic :: Education",
51 |         "Topic :: Scientific/Engineering",
52 |         "Topic :: Scientific/Engineering :: Artificial Intelligence",
53 |         "Topic :: Scientific/Engineering :: Image Recognition",
54 |     ],
55 |     keywords="machine-learning, deep-learning, ml, pytorch, text, text-detection, craft",
56 | )
57 | 
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fcakyon/craft-text-detector/6b10d3e0d178679e35796cf4426abb083b354239/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_craft.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | from craft_text_detector import Craft
4 | 
5 | 
6 | class TestCraftTextDetector(unittest.TestCase):
7 |     image_path = "figures/idcard.png"
8 | 
9 |     def test_init(self):
10 |         craft = Craft(
11 |             output_dir=None,
12 |             rectify=True,
13 |             export_extra=False,
14 |             text_threshold=0.7,
15 |             link_threshold=0.4,
16 |             low_text=0.4,
17 |             cuda=False,
18 |             long_size=720,
19 |             refiner=False,
20 |             crop_type="poly",
21 |         )
22 |         self.assertTrue(craft)
23 | 
24 |     def test_load_craftnet_model(self):
25 |         # init craft
26 |         craft = Craft(
27 |             output_dir=None,
28 |             rectify=True,
29 |             export_extra=False,
30 |             text_threshold=0.7,
31 |             link_threshold=0.4,
32 |             low_text=0.4,
33 |             cuda=False,
34 |             long_size=720,
35 |             refiner=False,
36 |             crop_type="poly",
37 |         )
38 |         # remove craftnet model
39 |         craft.craft_net = None
40 |         # load craftnet model
41 |         craft.load_craftnet_model()
42 |         self.assertTrue(craft.craft_net)
43 | 
44 |     def test_load_refinenet_model(self):
45 |         # init craft
46 |         craft = Craft(
47 |             output_dir=None,
48 |             rectify=True,
49 |             export_extra=False,
50 |             text_threshold=0.7,
51 |             link_threshold=0.4,
52 |             low_text=0.4,
53 |             cuda=False,
54 |             long_size=720,
55 |             refiner=False,
56 |             crop_type="poly",
57 |         )
58 |         # remove refinenet model
59 |         craft.refine_net = None
60 |         # load refinenet model
61 |         craft.load_refinenet_model()
62 |         self.assertTrue(craft.refine_net)
63 | 
64 |     def test_detect_text(self):
65 |         # init craft
66 |         craft = Craft(
67 |             output_dir=None,
68 |             rectify=True,
69 |             export_extra=False,
70 |             text_threshold=0.7,
71 |             link_threshold=0.4,
72 |             low_text=0.4,
73 |             cuda=False,
74 |             long_size=720,
75 |             refiner=False,
76 |             crop_type="poly",
77 |         )
78 |         # detect text
79 |         prediction_result = craft.detect_text(image=self.image_path)
80 | 
81 |         self.assertEqual(len(prediction_result["boxes"]), 52)
self.assertEqual(len(prediction_result["boxes"][0]), 4) 83 | self.assertEqual(len(prediction_result["boxes"][0][0]), 2) 84 | self.assertEqual(int(prediction_result["boxes"][0][0][0]), 115) 85 | 86 | # init craft 87 | craft = Craft( 88 | output_dir=None, 89 | rectify=True, 90 | export_extra=False, 91 | text_threshold=0.7, 92 | link_threshold=0.4, 93 | low_text=0.4, 94 | cuda=False, 95 | long_size=720, 96 | refiner=True, 97 | crop_type="poly", 98 | ) 99 | # detect text 100 | prediction_result = craft.detect_text(image=self.image_path) 101 | 102 | self.assertEqual(len(prediction_result["boxes"]), 19) 103 | self.assertEqual(len(prediction_result["boxes"][0]), 4) 104 | self.assertEqual(len(prediction_result["boxes"][0][0]), 2) 105 | self.assertEqual(int(prediction_result["boxes"][0][2][0]), 661) 106 | 107 | # init craft 108 | craft = Craft( 109 | output_dir=None, 110 | rectify=False, 111 | export_extra=False, 112 | text_threshold=0.7, 113 | link_threshold=0.4, 114 | low_text=0.4, 115 | cuda=False, 116 | long_size=720, 117 | refiner=False, 118 | crop_type="box", 119 | ) 120 | # detect text 121 | prediction_result = craft.detect_text(image=self.image_path) 122 | 123 | self.assertEqual(len(prediction_result["boxes"]), 52) 124 | self.assertEqual(len(prediction_result["boxes"][0]), 4) 125 | self.assertEqual(len(prediction_result["boxes"][0][0]), 2) 126 | self.assertEqual(int(prediction_result["boxes"][0][2][0]), 244) 127 | 128 | # init craft 129 | craft = Craft( 130 | output_dir=None, 131 | rectify=False, 132 | export_extra=False, 133 | text_threshold=0.7, 134 | link_threshold=0.4, 135 | low_text=0.4, 136 | cuda=False, 137 | long_size=720, 138 | refiner=True, 139 | crop_type="box", 140 | ) 141 | # detect text 142 | prediction_result = craft.detect_text(image=self.image_path) 143 | 144 | self.assertEqual(len(prediction_result["boxes"]), 19) 145 | self.assertEqual(len(prediction_result["boxes"][0]), 4) 146 | self.assertEqual(len(prediction_result["boxes"][0][0]), 2) 147 | self.assertEqual(int(prediction_result["boxes"][0][2][0]), 661) 148 | 149 | 150 | if __name__ == "__main__": 151 | unittest.main() 152 | -------------------------------------------------------------------------------- /tests/test_helpers.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from pathlib import Path 3 | from tempfile import TemporaryDirectory 4 | 5 | from craft_text_detector import ( 6 | export_detected_regions, 7 | export_extra_results, 8 | get_prediction, 9 | load_craftnet_model, 10 | load_refinenet_model, 11 | read_image, 12 | ) 13 | 14 | 15 | class TestCraftTextDetectorHelpers(unittest.TestCase): 16 | image_path = "figures/idcard.png" 17 | 18 | def test_load_craftnet_model(self): 19 | craft_net = load_craftnet_model(cuda=False) 20 | self.assertTrue(craft_net) 21 | 22 | with TemporaryDirectory() as dir_name: 23 | weight_path = Path(dir_name, "weights.pth") 24 | self.assertFalse(weight_path.is_file()) 25 | load_craftnet_model(cuda=False, weight_path=weight_path) 26 | self.assertTrue(weight_path.is_file()) 27 | 28 | def test_load_refinenet_model(self): 29 | refine_net = load_refinenet_model(cuda=False) 30 | self.assertTrue(refine_net) 31 | 32 | with TemporaryDirectory() as dir_name: 33 | weight_path = Path(dir_name, "weights.pth") 34 | self.assertFalse(weight_path.is_file()) 35 | load_refinenet_model(cuda=False, weight_path=weight_path) 36 | self.assertTrue(weight_path.is_file()) 37 | 38 | def test_read_image(self): 39 | image = read_image(self.image_path) 40 | 
--------------------------------------------------------------------------------
/tests/test_helpers.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from pathlib import Path
3 | from tempfile import TemporaryDirectory
4 | 
5 | from craft_text_detector import (
6 |     export_detected_regions,
7 |     export_extra_results,
8 |     get_prediction,
9 |     load_craftnet_model,
10 |     load_refinenet_model,
11 |     read_image,
12 | )
13 | 
14 | 
15 | class TestCraftTextDetectorHelpers(unittest.TestCase):
16 |     image_path = "figures/idcard.png"
17 | 
18 |     def test_load_craftnet_model(self):
19 |         craft_net = load_craftnet_model(cuda=False)
20 |         self.assertTrue(craft_net)
21 | 
22 |         with TemporaryDirectory() as dir_name:
23 |             weight_path = Path(dir_name, "weights.pth")
24 |             self.assertFalse(weight_path.is_file())
25 |             load_craftnet_model(cuda=False, weight_path=weight_path)
26 |             self.assertTrue(weight_path.is_file())
27 | 
28 |     def test_load_refinenet_model(self):
29 |         refine_net = load_refinenet_model(cuda=False)
30 |         self.assertTrue(refine_net)
31 | 
32 |         with TemporaryDirectory() as dir_name:
33 |             weight_path = Path(dir_name, "weights.pth")
34 |             self.assertFalse(weight_path.is_file())
35 |             load_refinenet_model(cuda=False, weight_path=weight_path)
36 |             self.assertTrue(weight_path.is_file())
37 | 
38 |     def test_read_image(self):
39 |         image = read_image(self.image_path)
40 |         self.assertEqual(image.shape, (500, 786, 3))
41 | 
42 |     def test_get_prediction(self):
43 |         # load image
44 |         image = read_image(self.image_path)
45 | 
46 |         # load models
47 |         craft_net = load_craftnet_model()
48 |         refine_net = None
49 | 
50 |         # perform prediction
51 |         text_threshold = 0.9
52 |         link_threshold = 0.2
53 |         low_text = 0.2
54 |         cuda = False
55 |         prediction_result = get_prediction(
56 |             image=image,
57 |             craft_net=craft_net,
58 |             refine_net=refine_net,
59 |             text_threshold=text_threshold,
60 |             link_threshold=link_threshold,
61 |             low_text=low_text,
62 |             cuda=cuda,
63 |             long_size=720,
64 |         )
65 | 
66 |         self.assertEqual(len(prediction_result["boxes"]), 35)
67 |         self.assertEqual(len(prediction_result["boxes"][0]), 4)
68 |         self.assertEqual(len(prediction_result["boxes"][0][0]), 2)
69 |         self.assertEqual(int(prediction_result["boxes"][0][0][0]), 111)
70 |         self.assertEqual(len(prediction_result["polys"]), 35)
71 |         self.assertEqual(
72 |             prediction_result["heatmaps"]["text_score_heatmap"].shape, (240, 368, 3)
73 |         )
74 | 
75 |     def test_get_prediction_without_read_image(self):
76 |         # set image filepath
77 |         image = self.image_path
78 | 
79 |         # load models
80 |         craft_net = load_craftnet_model()
81 |         refine_net = None
82 | 
83 |         # perform prediction
84 |         text_threshold = 0.9
85 |         link_threshold = 0.2
86 |         low_text = 0.2
87 |         cuda = False
88 |         prediction_result = get_prediction(
89 |             image=image,
90 |             craft_net=craft_net,
91 |             refine_net=refine_net,
92 |             text_threshold=text_threshold,
93 |             link_threshold=link_threshold,
94 |             low_text=low_text,
95 |             cuda=cuda,
96 |             long_size=720,
97 |         )
98 | 
99 |         self.assertEqual(len(prediction_result["boxes"]), 35)
100 |         self.assertEqual(len(prediction_result["boxes"][0]), 4)
101 |         self.assertEqual(len(prediction_result["boxes"][0][0]), 2)
102 |         self.assertEqual(int(prediction_result["boxes"][0][0][0]), 111)
103 |         self.assertEqual(len(prediction_result["polys"]), 35)
104 |         self.assertEqual(
105 |             prediction_result["heatmaps"]["text_score_heatmap"].shape, (240, 368, 3)
106 |         )
107 | 
108 | 
109 | if __name__ == "__main__":
110 |     unittest.main()
111 | 
--------------------------------------------------------------------------------
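Editor's note: to close, a sketch of the lower-level pipeline that tests/test_helpers.py exercises, extended with the two export helpers it imports but never calls. The keyword names passed to export_detected_regions and export_extra_results below are assumptions inferred from the import list and may differ between versions; treat them as illustrative only.

from craft_text_detector import (
    export_detected_regions,
    export_extra_results,
    get_prediction,
    load_craftnet_model,
    load_refinenet_model,
    read_image,
)

image = read_image("figures/idcard.png")
craft_net = load_craftnet_model(cuda=False)
refine_net = load_refinenet_model(cuda=False)

prediction_result = get_prediction(
    image=image,
    craft_net=craft_net,
    refine_net=refine_net,
    text_threshold=0.7,
    link_threshold=0.4,
    low_text=0.4,
    cuda=False,
    long_size=1280,
)

# crop each detected region to its own image file (keyword names assumed)
export_detected_regions(
    image=image,
    regions=prediction_result["boxes"],
    output_dir="outputs/",
    rectify=True,
)

# save heatmaps and box visualizations alongside the crops (keyword names assumed)
export_extra_results(
    image=image,
    regions=prediction_result["boxes"],
    heatmaps=prediction_result["heatmaps"],
    output_dir="outputs/",
)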