├── .github
│   ├── ISSUE_TEMPLATE
│   │   └── bug-report.md
│   └── workflows
│       ├── ci.yml
│       ├── package_testing.yml
│       └── publish_pypi.yml
├── .gitignore
├── LICENSE
├── README.md
├── craft_text_detector
│   ├── __init__.py
│   ├── craft_utils.py
│   ├── file_utils.py
│   ├── image_utils.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── basenet
│   │   │   ├── __init__.py
│   │   │   └── vgg16_bn.py
│   │   ├── craftnet.py
│   │   └── refinenet.py
│   ├── predict.py
│   └── torch_utils.py
├── figures
│   ├── craft_example.gif
│   ├── idcard.png
│   └── idcard2.jpg
├── requirements.txt
├── setup.cfg
├── setup.py
└── tests
    ├── __init__.py
    ├── test_craft.py
    └── test_helpers.py
/.github/ISSUE_TEMPLATE/bug-report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: "Bug Report"
3 | about: Submit a bug report to help us improve craft-text-detector
4 |
5 | ---
6 |
7 | ## 🐛 Bug
8 |
9 |
10 |
11 | ## To Reproduce
12 |
13 | Steps to reproduce the behavior:
14 |
15 | 1.
16 | 1.
17 | 1.
18 |
19 |
20 |
21 | ## Expected behavior
22 |
23 |
24 |
25 | ## Environment
26 |
27 | - craft-text-detector version (e.g., 0.3.1):
28 | - Python version (e.g., 3.6/3.7):
29 | - OS (e.g., Linux/Windows/MacOS):
30 | - How you installed craft-text-detector (`conda`, `pip`, source):
31 | - Any other relevant information:
32 |
33 | ## Additional context
34 |
35 |
36 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | push:
5 | branches: [ master ]
6 | pull_request:
7 | branches: [ master ]
8 |
9 | jobs:
10 | build:
11 |
12 | runs-on: ubuntu-latest
13 | strategy:
14 | matrix:
15 | operating-system: [ubuntu-latest, windows-latest, macos-latest]
16 | python-version: [3.7, 3.8, 3.9]
17 | fail-fast: false
18 |
19 | steps:
20 | - name: Checkout
21 | uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: ${{ matrix.python-version }}
26 | - name: Restore Ubuntu cache
27 | uses: actions/cache@v1
28 | if: matrix.operating-system == 'ubuntu-latest'
29 | with:
30 | path: ~/.cache/pip
31 | key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
32 | restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}-
33 | - name: Restore MacOS cache
34 | uses: actions/cache@v1
35 | if: matrix.operating-system == 'macos-latest'
36 | with:
37 | path: ~/Library/Caches/pip
38 | key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
39 | restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}-
40 | - name: Restore Windows cache
41 | uses: actions/cache@v1
42 | if: matrix.operating-system == 'windows-latest'
43 | with:
44 | path: ~\AppData\Local\pip\Cache
45 | key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
46 | restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}-
47 | - name: Update pip
48 | run: python -m pip install --upgrade pip
49 | - name: Install PyTorch on Linux and Windows
50 | if: >
51 | matrix.operating-system == 'ubuntu-latest' ||
52 | matrix.operating-system == 'windows-latest'
53 | run: >
54 | pip install torch==1.8.1+cpu torchvision==0.9.1+cpu
55 | -f https://download.pytorch.org/whl/torch_stable.html
56 | - name: Install PyTorch on MacOS
57 | if: matrix.operating-system == 'macos-latest'
58 | run: pip install torch==1.8.1 torchvision==0.9.1
59 | - name: Install dependencies
60 | run: |
61 | python -m pip install --upgrade pip
62 | pip install -r requirements.txt
63 | - name: Lint with flake8
64 | run: |
65 | pip install flake8
66 | # stop the build if there are Python syntax errors or undefined names
67 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
68 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
69 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
70 | - name: Test with unittest
71 | run: |
72 | python -m unittest
73 |
--------------------------------------------------------------------------------
/.github/workflows/package_testing.yml:
--------------------------------------------------------------------------------
1 | name: Package Testing
2 |
3 | on:
4 | schedule:
5 | - cron: '0 0 * * *' # Runs at 00:00 UTC every day
6 |
7 | jobs:
8 | build:
9 | runs-on: ubuntu-latest
10 |
11 | strategy:
12 | matrix:
13 | operating-system: [ubuntu-latest, windows-latest, macos-latest]
14 | python-version: [3.7, 3.8, 3.9]
15 | fail-fast: false
16 |
17 | steps:
18 | - name: Checkout
19 | uses: actions/checkout@v2
20 | - name: Set up Python
21 | uses: actions/setup-python@v2
22 | with:
23 | python-version: ${{ matrix.python-version }}
24 | - name: Restore Ubuntu cache
25 | uses: actions/cache@v1
26 | if: matrix.operating-system == 'ubuntu-latest'
27 | with:
28 | path: ~/.cache/pip
29 | key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
30 | restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}-
31 | - name: Restore MacOS cache
32 | uses: actions/cache@v1
33 | if: matrix.operating-system == 'macos-latest'
34 | with:
35 | path: ~/Library/Caches/pip
36 | key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
37 | restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}-
38 | - name: Restore Windows cache
39 | uses: actions/cache@v1
40 | if: matrix.operating-system == 'windows-latest'
41 | with:
42 | path: ~\AppData\Local\pip\Cache
43 | key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
44 | restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}-
45 | - name: Update pip
46 | run: python -m pip install --upgrade pip
47 | - name: Install PyTorch on Linux and Windows
48 | if: >
49 | matrix.operating-system == 'ubuntu-latest' ||
50 | matrix.operating-system == 'windows-latest'
51 | run: >
52 | pip install torch==1.8.1+cpu torchvision==0.9.1+cpu
53 | -f https://download.pytorch.org/whl/torch_stable.html
54 | - name: Install PyTorch on MacOS
55 | if: matrix.operating-system == 'macos-latest'
56 | run: pip install torch==1.8.1 torchvision==0.9.1
57 | - name: Install latest craft-text-detector package
58 | run: >
59 | pip install --upgrade --force-reinstall craft-text-detector
60 | - name: Test with unittest
61 | run: |
62 | python -m unittest
--------------------------------------------------------------------------------
/.github/workflows/publish_pypi.yml:
--------------------------------------------------------------------------------
1 | name: Upload Python Package
2 |
3 | on:
4 | release:
5 | types: [published, edited]
6 |
7 | jobs:
8 | deploy:
9 |
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - uses: actions/checkout@v2
14 | - name: Set up Python
15 | uses: actions/setup-python@v2
16 | with:
17 | python-version: '3.x'
18 | - name: Install dependencies
19 | run: |
20 | python -m pip install --upgrade pip
21 | pip install setuptools wheel twine
22 | - name: Build and publish
23 | env:
24 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
25 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
26 | run: |
27 | python setup.py sdist bdist_wheel
28 | twine upload dist/*
29 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.swp
3 | *.pkl
4 | *.pth
5 | result*
6 | weights*
7 | .vscode
8 | .pypirc
9 |
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | pip-wheel-metadata/
33 | share/python-wheels/
34 | *.egg-info/
35 | .installed.cfg
36 | *.egg
37 | MANIFEST
38 |
39 | # PyInstaller
40 | # Usually these files are written by a python script from a template
41 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
42 | *.manifest
43 | *.spec
44 |
45 | # Installer logs
46 | pip-log.txt
47 | pip-delete-this-directory.txt
48 |
49 | # Unit test / coverage reports
50 | htmlcov/
51 | .tox/
52 | .nox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *.cover
59 | *.py,cover
60 | .hypothesis/
61 | .pytest_cache/
62 | cover/
63 |
64 | # Translations
65 | *.mo
66 | *.pot
67 |
68 | # Django stuff:
69 | *.log
70 | local_settings.py
71 | db.sqlite3
72 | db.sqlite3-journal
73 |
74 | # Flask stuff:
75 | instance/
76 | .webassets-cache
77 |
78 | # Scrapy stuff:
79 | .scrapy
80 |
81 | # Sphinx documentation
82 | docs/_build/
83 |
84 | # PyBuilder
85 | .pybuilder/
86 | target/
87 |
88 | # Jupyter Notebook
89 | .ipynb_checkpoints
90 |
91 | # IPython
92 | profile_default/
93 | ipython_config.py
94 |
95 | # pyenv
96 | # For a library or package, you might want to ignore these files since the code is
97 | # intended to run in multiple environments; otherwise, check them in:
98 | # .python-version
99 |
100 | # pipenv
101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
104 | # install all needed dependencies.
105 | #Pipfile.lock
106 |
107 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
108 | __pypackages__/
109 |
110 | # Celery stuff
111 | celerybeat-schedule
112 | celerybeat.pid
113 |
114 | # SageMath parsed files
115 | *.sage.py
116 |
117 | # Environments
118 | .env
119 | .venv
120 | env/
121 | venv/
122 | ENV/
123 | env.bak/
124 | venv.bak/
125 |
126 | # Spyder project settings
127 | .spyderproject
128 | .spyproject
129 |
130 | # Rope project settings
131 | .ropeproject
132 |
133 | # mkdocs documentation
134 | /site
135 |
136 | # mypy
137 | .mypy_cache/
138 | .dmypy.json
139 | dmypy.json
140 |
141 | # Pyre type checker
142 | .pyre/
143 |
144 | # pytype static type analyzer
145 | .pytype/
146 |
147 | # Cython debug symbols
148 | cython_debug/
149 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2019 NAVER Corp., 2020 Fatih C Akyon
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # CRAFT: Character Region Awareness for Text Detection
2 |
 13 | Packaged, PyTorch-based, easy-to-use, cross-platform version of the CRAFT text detector | [Paper](https://arxiv.org/abs/1904.01941)
14 |
15 | ## Overview
16 |
 17 | PyTorch implementation of the CRAFT text detector, which detects text areas by exploring each character region and the affinity between characters. Text bounding boxes are obtained by thresholding the character-region and affinity score maps and then finding minimum-area bounding rectangles on the resulting binary map.
18 |
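Under the hood, this post-processing is a few lines of OpenCV. The snippet below is a simplified, illustrative sketch of what `craft_utils.getDetBoxes_core` does; random toy score maps stand in for the real network outputs, and the size/confidence filtering of the real code is omitted:

```python
import cv2
import numpy as np

# toy score maps standing in for the network's character-region and affinity outputs
text_score_map = np.random.rand(368, 640).astype(np.float32)
link_score_map = np.random.rand(368, 640).astype(np.float32)

# threshold both maps and combine them into a single binary map
_, text_score = cv2.threshold(text_score_map, 0.4, 1, 0)  # low_text
_, link_score = cv2.threshold(link_score_map, 0.4, 1, 0)  # link_threshold
combined = np.clip(text_score + link_score, 0, 1).astype(np.uint8)

# every connected component is a text-region candidate; its minimum-area
# rectangle gives the four corner points of the detected box
n_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(combined, connectivity=4)
for k in range(1, n_labels):
    points = np.column_stack(np.where(labels == k)[::-1]).astype(np.float32)  # (x, y) pairs
    box = cv2.boxPoints(cv2.minAreaRect(points))
```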
19 |
20 |
21 | ## Getting started
22 |
23 | ### Installation
24 |
25 | - Install using pip:
26 |
27 | ```console
28 | pip install craft-text-detector
29 | ```
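- To install the latest code from source instead, pip can install directly from the repository (the repository path below is inferred from this project's asset URLs and may differ):

```console
pip install git+https://github.com/fcakyon/craft-text-detector
```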
30 |
31 | ### Basic Usage
32 |
33 | ```python
34 | # import Craft class
35 | from craft_text_detector import Craft
36 |
37 | # set image path and export folder directory
38 | image = 'figures/idcard.png' # can be filepath, PIL image or numpy array
39 | output_dir = 'outputs/'
40 |
41 | # create a craft instance
42 | craft = Craft(output_dir=output_dir, crop_type="poly", cuda=False)
43 |
44 | # apply craft text detection and export detected regions to output directory
45 | prediction_result = craft.detect_text(image)
46 |
47 | # unload models from ram/gpu
48 | craft.unload_craftnet_model()
49 | craft.unload_refinenet_model()
50 | ```
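`detect_text` returns a plain dict; its keys are documented in the `Craft.detect_text` docstring. A minimal sketch of consuming it:

```python
# a minimal sketch of reading the result dict returned above
print(len(prediction_result["boxes"]))       # number of detected text regions
print(prediction_result["boxes"][0])         # 4 corner points of the first region (if any were found)
print(prediction_result["text_crop_paths"])  # paths of the crops exported to output_dir
print(prediction_result["times"])            # per-stage timings in seconds
```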
51 |
52 | ### Advanced Usage
53 |
54 | ```python
55 | # import craft functions
56 | from craft_text_detector import (
57 | read_image,
58 | load_craftnet_model,
59 | load_refinenet_model,
60 | get_prediction,
61 | export_detected_regions,
62 | export_extra_results,
63 | empty_cuda_cache
64 | )
65 |
66 | # set image path and export folder directory
67 | image = 'figures/idcard.png' # can be filepath, PIL image or numpy array
68 | output_dir = 'outputs/'
69 |
70 | # read image
71 | image = read_image(image)
72 |
73 | # load models
74 | refine_net = load_refinenet_model(cuda=True)
75 | craft_net = load_craftnet_model(cuda=True)
76 |
77 | # perform prediction
78 | prediction_result = get_prediction(
79 | image=image,
80 | craft_net=craft_net,
81 | refine_net=refine_net,
82 | text_threshold=0.7,
83 | link_threshold=0.4,
84 | low_text=0.4,
85 | cuda=True,
86 | long_size=1280
87 | )
88 |
89 | # export detected text regions
90 | exported_file_paths = export_detected_regions(
91 | image=image,
92 | regions=prediction_result["boxes"],
93 | output_dir=output_dir,
94 | rectify=True
95 | )
96 |
97 | # export heatmap, detection points, box visualization
98 | export_extra_results(
99 | image=image,
100 | regions=prediction_result["boxes"],
101 | heatmaps=prediction_result["heatmaps"],
102 | output_dir=output_dir
103 | )
104 |
105 | # unload models from gpu
106 | empty_cuda_cache()
107 | ```
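For a custom visualization on top of the exported files, the detected regions can also be drawn directly; below is a minimal sketch mirroring what `export_extra_results` does internally:

```python
import cv2
import numpy as np

# draw every detected box on the image and save it
# (output_dir was already created by the export calls above)
visualization = image.copy()
for region in prediction_result["boxes"]:
    points = np.array(region).astype(np.int32).reshape((-1, 1, 2))
    cv2.polylines(visualization, [points], True, color=(0, 0, 255), thickness=2)
cv2.imwrite(output_dir + "boxes.png", cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR))
```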
108 |
--------------------------------------------------------------------------------
/craft_text_detector/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import os
4 | from typing import Optional
5 |
6 | import craft_text_detector.craft_utils as craft_utils
7 | import craft_text_detector.file_utils as file_utils
8 | import craft_text_detector.image_utils as image_utils
9 | import craft_text_detector.predict as predict
10 | import craft_text_detector.torch_utils as torch_utils
11 |
12 | __version__ = "0.4.3"
13 |
14 |
15 | __all__ = [
16 | "read_image",
17 | "load_craftnet_model",
18 | "load_refinenet_model",
19 | "get_prediction",
20 | "export_detected_regions",
21 | "export_extra_results",
22 | "empty_cuda_cache",
23 | "Craft",
24 | ]
25 |
26 | read_image = image_utils.read_image
27 | load_craftnet_model = craft_utils.load_craftnet_model
28 | load_refinenet_model = craft_utils.load_refinenet_model
29 | get_prediction = predict.get_prediction
30 | export_detected_regions = file_utils.export_detected_regions
31 | export_extra_results = file_utils.export_extra_results
32 | empty_cuda_cache = torch_utils.empty_cuda_cache
33 |
34 |
35 | class Craft:
36 | def __init__(
37 | self,
38 | output_dir=None,
39 | rectify=True,
40 | export_extra=True,
41 | text_threshold=0.7,
42 | link_threshold=0.4,
43 | low_text=0.4,
44 | cuda=False,
45 | long_size=1280,
46 | refiner=True,
47 | crop_type="poly",
48 | weight_path_craft_net: Optional[str] = None,
49 | weight_path_refine_net: Optional[str] = None,
50 | ):
51 | """
52 | Arguments:
53 | output_dir: path to the results to be exported
54 | rectify: rectify detected polygon by affine transform
55 | export_extra: export heatmap, detection points, box visualization
56 | text_threshold: text confidence threshold
57 | link_threshold: link confidence threshold
58 | low_text: text low-bound score
59 | cuda: Use cuda for inference
60 | long_size: desired longest image size for inference
61 | refiner: enable link refiner
62 | crop_type: crop regions by detected boxes or polys ("poly" or "box")
63 | """
64 | self.craft_net = None
65 | self.refine_net = None
66 | self.output_dir = output_dir
67 | self.rectify = rectify
68 | self.export_extra = export_extra
69 | self.text_threshold = text_threshold
70 | self.link_threshold = link_threshold
71 | self.low_text = low_text
72 | self.cuda = cuda
73 | self.long_size = long_size
74 | self.refiner = refiner
75 | self.crop_type = crop_type
76 |
77 | # load craftnet
78 | self.load_craftnet_model(weight_path_craft_net)
79 | # load refinernet if required
80 | if refiner:
81 | self.load_refinenet_model(weight_path_refine_net)
82 |
83 | def load_craftnet_model(self, weight_path: Optional[str] = None):
84 | """
85 | Loads craftnet model
86 | """
87 | self.craft_net = load_craftnet_model(self.cuda, weight_path=weight_path)
88 |
89 | def load_refinenet_model(self, weight_path: Optional[str] = None):
90 | """
91 | Loads refinenet model
92 | """
93 | self.refine_net = load_refinenet_model(self.cuda, weight_path=weight_path)
94 |
95 | def unload_craftnet_model(self):
96 | """
97 | Unloads craftnet model
98 | """
99 | self.craft_net = None
100 | empty_cuda_cache()
101 |
102 | def unload_refinenet_model(self):
103 | """
104 | Unloads refinenet model
105 | """
106 | self.refine_net = None
107 | empty_cuda_cache()
108 |
109 | def detect_text(self, image, image_path=None):
110 | """
111 | Arguments:
112 | image: path to the image to be processed or numpy array or PIL image
113 |
114 | Output:
115 |             {
116 |                 "boxes": list of coords of points of predicted boxes,
117 |                 "boxes_as_ratios": list of coords of points of predicted boxes as ratios of image size,
118 |                 "polys": list of coords of points of predicted polys,
119 |                 "polys_as_ratios": list of coords of points of predicted polys as ratios of image size,
120 |                 "heatmaps": visualization of the detected characters/links,
121 |                 "text_crop_paths": list of paths of the exported text boxes/polys,
122 |                 "times": elapsed times of the sub modules, in seconds
123 |             }
124 | """
125 |
126 | if image_path is not None:
127 | print("Argument 'image_path' is deprecated, use 'image' instead.")
128 | image = image_path
129 |
130 | # perform prediction
131 | prediction_result = get_prediction(
132 | image=image,
133 | craft_net=self.craft_net,
134 | refine_net=self.refine_net,
135 | text_threshold=self.text_threshold,
136 | link_threshold=self.link_threshold,
137 | low_text=self.low_text,
138 | cuda=self.cuda,
139 | long_size=self.long_size,
140 | )
141 |
142 |         # arrange regions
143 | if self.crop_type == "box":
144 | regions = prediction_result["boxes"]
145 | elif self.crop_type == "poly":
146 | regions = prediction_result["polys"]
147 | else:
148 |             raise ValueError("crop_type can be only 'poly' or 'box'")
149 |
150 | # export if output_dir is given
151 | prediction_result["text_crop_paths"] = []
152 | if self.output_dir is not None:
153 | # export detected text regions
154 | if type(image) == str:
155 | file_name, file_ext = os.path.splitext(os.path.basename(image))
156 | else:
157 | file_name = "image"
158 | exported_file_paths = export_detected_regions(
159 | image=image,
160 | regions=regions,
161 | file_name=file_name,
162 | output_dir=self.output_dir,
163 | rectify=self.rectify,
164 | )
165 | prediction_result["text_crop_paths"] = exported_file_paths
166 |
167 | # export heatmap, detection points, box visualization
168 | if self.export_extra:
169 | export_extra_results(
170 | image=image,
171 | regions=regions,
172 | heatmaps=prediction_result["heatmaps"],
173 | file_name=file_name,
174 | output_dir=self.output_dir,
175 | )
176 |
177 | # return prediction results
178 | return prediction_result
179 |
--------------------------------------------------------------------------------
/craft_text_detector/craft_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from collections import OrderedDict
4 | from pathlib import Path
5 | from typing import Optional, Union
6 |
7 | import cv2
8 | import numpy as np
9 |
10 | import craft_text_detector.file_utils as file_utils
11 | import craft_text_detector.torch_utils as torch_utils
12 |
13 | CRAFT_GDRIVE_URL = "https://drive.google.com/uc?id=1bupFXqT-VU6Jjeul13XP7yx2Sg5IHr4J"
14 | REFINENET_GDRIVE_URL = (
15 | "https://drive.google.com/uc?id=1xcE9qpJXp4ofINwXWVhhQIh9S8Z7cuGj"
16 | )
17 |
18 |
 19 | # unwarp coordinates
20 | def warpCoord(Minv, pt):
21 | out = np.matmul(Minv, (pt[0], pt[1], 1))
22 | return np.array([out[0] / out[2], out[1] / out[2]])
23 |
24 |
25 | def copyStateDict(state_dict):
26 | if list(state_dict.keys())[0].startswith("module"):
27 | start_idx = 1
28 | else:
29 | start_idx = 0
30 | new_state_dict = OrderedDict()
31 | for k, v in state_dict.items():
32 | name = ".".join(k.split(".")[start_idx:])
33 | new_state_dict[name] = v
34 | return new_state_dict
35 |
36 |
37 | def load_craftnet_model(
38 | cuda: bool = False,
39 | weight_path: Optional[Union[str, Path]] = None
40 | ):
41 | # get craft net path
42 | if weight_path is None:
43 | home_path = str(Path.home())
44 | weight_path = Path(
45 | home_path,
46 | ".craft_text_detector",
47 | "weights",
48 | "craft_mlt_25k.pth"
49 | )
50 | weight_path = Path(weight_path).resolve()
51 | weight_path.parent.mkdir(exist_ok=True, parents=True)
52 | weight_path = str(weight_path)
53 |
54 | # load craft net
55 | from craft_text_detector.models.craftnet import CraftNet
56 |
57 | craft_net = CraftNet() # initialize
58 |
59 | # check if weights are already downloaded, if not download
60 | url = CRAFT_GDRIVE_URL
61 | if not os.path.isfile(weight_path):
62 | print("Craft text detector weight will be downloaded to {}".format(weight_path))
63 |
64 | file_utils.download(url=url, save_path=weight_path)
65 |
 66 |     # arrange device
67 | if cuda:
68 | craft_net.load_state_dict(copyStateDict(torch_utils.load(weight_path)))
69 |
70 | craft_net = craft_net.cuda()
71 | craft_net = torch_utils.DataParallel(craft_net)
 72 |         torch_utils.cudnn.benchmark = False  # set on the cudnn module so it takes effect
73 | else:
74 | craft_net.load_state_dict(
75 | copyStateDict(torch_utils.load(weight_path, map_location="cpu"))
76 | )
77 | craft_net.eval()
78 | return craft_net
79 |
80 |
81 | def load_refinenet_model(
82 | cuda: bool = False,
83 | weight_path: Optional[Union[str, Path]] = None
84 | ):
85 | # get refine net path
86 | if weight_path is None:
87 | home_path = Path.home()
88 | weight_path = Path(
89 | home_path,
90 | ".craft_text_detector",
91 | "weights",
92 | "craft_refiner_CTW1500.pth"
93 | )
94 | weight_path = Path(weight_path).resolve()
95 | weight_path.parent.mkdir(exist_ok=True, parents=True)
96 | weight_path = str(weight_path)
97 |
98 | # load refine net
99 | from craft_text_detector.models.refinenet import RefineNet
100 |
101 | refine_net = RefineNet() # initialize
102 |
103 | # check if weights are already downloaded, if not download
104 | url = REFINENET_GDRIVE_URL
105 | if not os.path.isfile(weight_path):
106 | print("Craft text refiner weight will be downloaded to {}".format(weight_path))
107 |
108 | file_utils.download(url=url, save_path=weight_path)
109 |
110 |     # arrange device
111 | if cuda:
112 | refine_net.load_state_dict(copyStateDict(torch_utils.load(weight_path)))
113 |
114 | refine_net = refine_net.cuda()
115 | refine_net = torch_utils.DataParallel(refine_net)
116 |         torch_utils.cudnn.benchmark = False  # set on the cudnn module so it takes effect
117 | else:
118 | refine_net.load_state_dict(
119 | copyStateDict(torch_utils.load(weight_path, map_location="cpu"))
120 | )
121 | refine_net.eval()
122 | return refine_net
123 |
124 |
125 | def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text):
126 | # prepare data
127 | linkmap = linkmap.copy()
128 | textmap = textmap.copy()
129 | img_h, img_w = textmap.shape
130 |
131 | """ labeling method """
132 | ret, text_score = cv2.threshold(textmap, low_text, 1, 0)
133 | ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0)
134 |
135 | text_score_comb = np.clip(text_score + link_score, 0, 1)
136 | nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(
137 | text_score_comb.astype(np.uint8), connectivity=4
138 | )
139 |
140 | det = []
141 | mapper = []
142 | for k in range(1, nLabels):
143 | # size filtering
144 | size = stats[k, cv2.CC_STAT_AREA]
145 | if size < 10:
146 | continue
147 |
148 | # thresholding
149 | if np.max(textmap[labels == k]) < text_threshold:
150 | continue
151 |
152 | # make segmentation map
153 | segmap = np.zeros(textmap.shape, dtype=np.uint8)
154 | segmap[labels == k] = 255
155 |
156 | # remove link area
157 | segmap[np.logical_and(link_score == 1, text_score == 0)] = 0
158 |
159 | x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]
160 | w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]
161 | niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2)
162 | sx, ex, sy, ey = (x - niter, x + w + niter + 1, y - niter, y + h + niter + 1)
163 | # boundary check
164 | if sx < 0:
165 | sx = 0
166 | if sy < 0:
167 | sy = 0
168 | if ex >= img_w:
169 | ex = img_w
170 | if ey >= img_h:
171 | ey = img_h
172 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1 + niter, 1 + niter))
173 | segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel)
174 |
175 | # make box
176 | np_temp = np.roll(np.array(np.where(segmap != 0)), 1, axis=0)
177 | np_contours = np_temp.transpose().reshape(-1, 2)
178 | rectangle = cv2.minAreaRect(np_contours)
179 | box = cv2.boxPoints(rectangle)
180 |
181 |         # boundary check: minAreaRect may return out-of-range values
182 | # (see https://docs.opencv.org/3.4/d3/dc0/group__imgproc__shape.html#ga3d476a3417130ae5154aea421ca7ead9)
183 | for p in box:
184 | if p[0] < 0:
185 | p[0] = 0
186 | if p[1] < 0:
187 | p[1] = 0
188 | if p[0] >= img_w:
189 | p[0] = img_w
190 | if p[1] >= img_h:
191 | p[1] = img_h
192 |
193 | # align diamond-shape
194 | w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
195 | box_ratio = max(w, h) / (min(w, h) + 1e-5)
196 | if abs(1 - box_ratio) <= 0.1:
197 | l, r = min(np_contours[:, 0]), max(np_contours[:, 0])
198 | t, b = min(np_contours[:, 1]), max(np_contours[:, 1])
199 | box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)
200 |
201 | # make clock-wise order
202 | startidx = box.sum(axis=1).argmin()
203 | box = np.roll(box, 4 - startidx, 0)
204 | box = np.array(box)
205 |
206 | det.append(box)
207 | mapper.append(k)
208 |
209 | return det, labels, mapper
210 |
211 |
212 | def getPoly_core(boxes, labels, mapper, linkmap):
213 | # configs
214 | num_cp = 5
215 | max_len_ratio = 0.7
216 | expand_ratio = 1.45
217 | max_r = 2.0
218 | step_r = 0.2
219 |
220 | polys = []
221 | for k, box in enumerate(boxes):
222 | # size filter for small instance
223 | w, h = (
224 | int(np.linalg.norm(box[0] - box[1]) + 1),
225 | int(np.linalg.norm(box[1] - box[2]) + 1),
226 | )
227 | if w < 10 or h < 10:
228 | polys.append(None)
229 | continue
230 |
231 | # warp image
232 | tar = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
233 | M = cv2.getPerspectiveTransform(box, tar)
234 | word_label = cv2.warpPerspective(labels, M, (w, h), flags=cv2.INTER_NEAREST)
235 | try:
236 | Minv = np.linalg.inv(M)
237 |         except np.linalg.LinAlgError:  # degenerate box, transform is not invertible
238 | polys.append(None)
239 | continue
240 |
241 | # binarization for selected label
242 | cur_label = mapper[k]
243 | word_label[word_label != cur_label] = 0
244 | word_label[word_label > 0] = 1
245 |
246 | """ Polygon generation """
247 | # find top/bottom contours
248 | cp = []
249 | max_len = -1
250 | for i in range(w):
251 | region = np.where(word_label[:, i] != 0)[0]
252 | if len(region) < 2:
253 | continue
254 | cp.append((i, region[0], region[-1]))
255 | length = region[-1] - region[0] + 1
256 | if length > max_len:
257 | max_len = length
258 |
259 | # pass if max_len is similar to h
260 | if h * max_len_ratio < max_len:
261 | polys.append(None)
262 | continue
263 |
264 | # get pivot points with fixed length
265 | tot_seg = num_cp * 2 + 1
266 | seg_w = w / tot_seg # segment width
267 | pp = [None] * num_cp # init pivot points
268 | cp_section = [[0, 0]] * tot_seg
269 | seg_height = [0] * num_cp
270 | seg_num = 0
271 | num_sec = 0
272 | prev_h = -1
273 | for i in range(0, len(cp)):
274 | (x, sy, ey) = cp[i]
275 | if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg:
276 | # average previous segment
277 | if num_sec == 0:
278 | break
279 | cp_section[seg_num] = [
280 | cp_section[seg_num][0] / num_sec,
281 | cp_section[seg_num][1] / num_sec,
282 | ]
283 | num_sec = 0
284 |
285 | # reset variables
286 | seg_num += 1
287 | prev_h = -1
288 |
289 | # accumulate center points
290 | cy = (sy + ey) * 0.5
291 | cur_h = ey - sy + 1
292 | cp_section[seg_num] = [
293 | cp_section[seg_num][0] + x,
294 | cp_section[seg_num][1] + cy,
295 | ]
296 | num_sec += 1
297 |
298 | if seg_num % 2 == 0:
299 | continue # No polygon area
300 |
301 | if prev_h < cur_h:
302 | pp[int((seg_num - 1) / 2)] = (x, cy)
303 | seg_height[int((seg_num - 1) / 2)] = cur_h
304 | prev_h = cur_h
305 |
306 | # processing last segment
307 | if num_sec != 0:
308 | cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec]
309 |
310 |         # pass if num of pivots is not sufficient or segment width
311 | # is smaller than character height
312 | if None in pp or seg_w < np.max(seg_height) * 0.25:
313 | polys.append(None)
314 | continue
315 |
316 | # calc median maximum of pivot points
317 | half_char_h = np.median(seg_height) * expand_ratio / 2
318 |
319 |         # calc gradient and apply to make horizontal pivots
320 | new_pp = []
321 | for i, (x, cy) in enumerate(pp):
322 | dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0]
323 | dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1]
324 |             if dx == 0:  # gradient is zero
325 | new_pp.append([x, cy - half_char_h, x, cy + half_char_h])
326 | continue
327 | rad = -math.atan2(dy, dx)
328 | c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad)
329 | new_pp.append([x - s, cy - c, x + s, cy + c])
330 |
331 | # get edge points to cover character heatmaps
332 | isSppFound, isEppFound = False, False
333 | grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (
334 | pp[2][1] - pp[1][1]
335 | ) / (pp[2][0] - pp[1][0])
336 | grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (
337 | pp[-3][1] - pp[-2][1]
338 | ) / (pp[-3][0] - pp[-2][0])
339 | for r in np.arange(0.5, max_r, step_r):
340 | dx = 2 * half_char_h * r
341 | if not isSppFound:
342 | line_img = np.zeros(word_label.shape, dtype=np.uint8)
343 | dy = grad_s * dx
344 | p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy])
345 | cv2.line(
346 | line_img,
347 | (int(p[0]), int(p[1])),
348 | (int(p[2]), int(p[3])),
349 | 1,
350 | thickness=1,
351 | )
352 | if (
353 | np.sum(np.logical_and(word_label, line_img)) == 0
354 | or r + 2 * step_r >= max_r
355 | ):
356 | spp = p
357 | isSppFound = True
358 | if not isEppFound:
359 | line_img = np.zeros(word_label.shape, dtype=np.uint8)
360 | dy = grad_e * dx
361 | p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy])
362 | cv2.line(
363 | line_img,
364 | (int(p[0]), int(p[1])),
365 | (int(p[2]), int(p[3])),
366 | 1,
367 | thickness=1,
368 | )
369 | if (
370 | np.sum(np.logical_and(word_label, line_img)) == 0
371 | or r + 2 * step_r >= max_r
372 | ):
373 | epp = p
374 | isEppFound = True
375 | if isSppFound and isEppFound:
376 | break
377 |
378 | # pass if boundary of polygon is not found
379 | if not (isSppFound and isEppFound):
380 | polys.append(None)
381 | continue
382 |
383 | # make final polygon
384 | poly = []
385 | poly.append(warpCoord(Minv, (spp[0], spp[1])))
386 | for p in new_pp:
387 | poly.append(warpCoord(Minv, (p[0], p[1])))
388 | poly.append(warpCoord(Minv, (epp[0], epp[1])))
389 | poly.append(warpCoord(Minv, (epp[2], epp[3])))
390 | for p in reversed(new_pp):
391 | poly.append(warpCoord(Minv, (p[2], p[3])))
392 | poly.append(warpCoord(Minv, (spp[2], spp[3])))
393 |
394 | # add to final result
395 | polys.append(np.array(poly))
396 |
397 | return polys
398 |
399 |
400 | def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False):
401 | boxes, labels, mapper = getDetBoxes_core(
402 | textmap, linkmap, text_threshold, link_threshold, low_text
403 | )
404 |
405 | if poly:
406 | polys = getPoly_core(boxes, labels, mapper, linkmap)
407 | else:
408 | polys = [None] * len(boxes)
409 |
410 | return boxes, polys
411 |
412 |
413 | def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2):
414 | if len(polys) > 0:
415 | polys = np.array(polys)
416 | for k in range(len(polys)):
417 | if polys[k] is not None:
418 | polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net)
419 | return polys
420 |
--------------------------------------------------------------------------------
/craft_text_detector/file_utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 |
4 | import cv2
5 | import gdown
6 | import numpy as np
7 |
8 | from craft_text_detector.image_utils import read_image
9 |
10 |
11 | def download(url: str, save_path: str):
12 | """
13 | Downloads file from gdrive, shows progress.
14 | Example inputs:
 15 |         url: 'https://drive.google.com/uc?id=<gdrive-file-id>'
16 | save_path: 'data/file.zip'
17 | """
18 |
19 | # create save_dir if not present
20 | create_dir(os.path.dirname(save_path))
21 | # download file
22 | gdown.download(url, save_path, quiet=False)
23 |
24 |
25 | def create_dir(_dir):
26 | """
27 | Creates given directory if it is not present.
28 | """
29 | if not os.path.exists(_dir):
30 | os.makedirs(_dir)
31 |
32 |
33 | def get_files(img_dir):
34 | imgs, masks, xmls = list_files(img_dir)
35 | return imgs, masks, xmls
36 |
37 |
38 | def list_files(in_path):
39 | img_files = []
40 | mask_files = []
41 | gt_files = []
42 | for (dirpath, dirnames, filenames) in os.walk(in_path):
43 | for file in filenames:
44 | filename, ext = os.path.splitext(file)
45 | ext = str.lower(ext)
46 | if (
47 | ext == ".jpg"
48 | or ext == ".jpeg"
49 | or ext == ".gif"
50 | or ext == ".png"
51 | or ext == ".pgm"
52 | ):
53 | img_files.append(os.path.join(dirpath, file))
54 | elif ext == ".bmp":
55 | mask_files.append(os.path.join(dirpath, file))
56 | elif ext == ".xml" or ext == ".gt" or ext == ".txt":
57 | gt_files.append(os.path.join(dirpath, file))
58 | elif ext == ".zip":
59 | continue
60 | # img_files.sort()
61 | # mask_files.sort()
62 | # gt_files.sort()
63 | return img_files, mask_files, gt_files
64 |
65 |
66 | def rectify_poly(img, poly):
67 | # Use Affine transform
68 | n = int(len(poly) / 2) - 1
69 | width = 0
70 | height = 0
71 | for k in range(n):
72 | box = np.float32([poly[k], poly[k + 1], poly[-k - 2], poly[-k - 1]])
73 | width += int(
74 | (np.linalg.norm(box[0] - box[1]) + np.linalg.norm(box[2] - box[3])) / 2
75 | )
76 | height += np.linalg.norm(box[1] - box[2])
77 | width = int(width)
78 | height = int(height / n)
79 |
80 | output_img = np.zeros((height, width, 3), dtype=np.uint8)
81 | width_step = 0
82 | for k in range(n):
83 | box = np.float32([poly[k], poly[k + 1], poly[-k - 2], poly[-k - 1]])
84 | w = int((np.linalg.norm(box[0] - box[1]) + np.linalg.norm(box[2] - box[3])) / 2)
85 |
86 | # Top triangle
87 | pts1 = box[:3]
88 | pts2 = np.float32(
89 | [[width_step, 0], [width_step + w - 1, 0], [width_step + w - 1, height - 1]]
90 | )
91 | M = cv2.getAffineTransform(pts1, pts2)
92 | warped_img = cv2.warpAffine(
93 | img, M, (width, height), borderMode=cv2.BORDER_REPLICATE
94 | )
95 | warped_mask = np.zeros((height, width, 3), dtype=np.uint8)
96 | warped_mask = cv2.fillConvexPoly(warped_mask, np.int32(pts2), (1, 1, 1))
97 | output_img[warped_mask == 1] = warped_img[warped_mask == 1]
98 |
99 | # Bottom triangle
100 | pts1 = np.vstack((box[0], box[2:]))
101 | pts2 = np.float32(
102 | [
103 | [width_step, 0],
104 | [width_step + w - 1, height - 1],
105 | [width_step, height - 1],
106 | ]
107 | )
108 | M = cv2.getAffineTransform(pts1, pts2)
109 | warped_img = cv2.warpAffine(
110 | img, M, (width, height), borderMode=cv2.BORDER_REPLICATE
111 | )
112 | warped_mask = np.zeros((height, width, 3), dtype=np.uint8)
113 | warped_mask = cv2.fillConvexPoly(warped_mask, np.int32(pts2), (1, 1, 1))
114 | cv2.line(
115 | warped_mask, (width_step, 0), (width_step + w - 1, height - 1), (0, 0, 0), 1
116 | )
117 | output_img[warped_mask == 1] = warped_img[warped_mask == 1]
118 |
119 | width_step += w
120 | return output_img
121 |
122 |
123 | def crop_poly(image, poly):
124 | # points should have 1*x*2 shape
125 | if len(poly.shape) == 2:
126 | poly = np.array([np.array(poly).astype(np.int32)])
127 |
128 | # create mask with shape of image
129 | mask = np.zeros(image.shape[0:2], dtype=np.uint8)
130 |
131 | # method 1 smooth region
132 | cv2.drawContours(mask, [poly], -1, (255, 255, 255), -1, cv2.LINE_AA)
133 | # method 2 not so smooth region
134 | # cv2.fillPoly(mask, points, (255))
135 |
136 | # crop around poly
137 | res = cv2.bitwise_and(image, image, mask=mask)
138 | rect = cv2.boundingRect(poly) # returns (x,y,w,h) of the rect
139 | cropped = res[rect[1] : rect[1] + rect[3], rect[0] : rect[0] + rect[2]]
140 |
141 | return cropped
142 |
143 |
144 | def export_detected_region(image, poly, file_path, rectify=True):
145 | """
146 | Arguments:
147 | image: full image
148 | points: bbox or poly points
149 | file_path: path to be exported
150 | rectify: rectify detected polygon by affine transform
151 | """
152 | if rectify:
153 | # rectify poly region
154 | result_rgb = rectify_poly(image, poly)
155 | else:
156 | result_rgb = crop_poly(image, poly)
157 |
158 |     # export cropped region
159 | result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)
160 | cv2.imwrite(file_path, result_bgr)
161 |
162 |
163 | def export_detected_regions(
164 | image,
165 | regions,
166 | file_name: str = "image",
167 | output_dir: str = "output/",
168 | rectify: bool = False,
169 | ):
170 | """
171 | Arguments:
172 | image: path to the image to be processed or numpy array or PIL image
173 | regions: list of bboxes or polys
174 | file_name (str): export image file name
175 | output_dir: folder to be exported
176 | rectify: rectify detected polygon by affine transform
177 | """
178 |
179 | # read/convert image
180 | image = read_image(image)
181 |
182 | # deepcopy image so that original is not altered
183 | image = copy.deepcopy(image)
184 |
185 | # create crops dir
186 | crops_dir = os.path.join(output_dir, file_name + "_crops")
187 | create_dir(crops_dir)
188 |
189 | # init exported file paths
190 | exported_file_paths = []
191 |
192 | # export regions
193 | for ind, region in enumerate(regions):
194 | # get export path
195 | file_path = os.path.join(crops_dir, "crop_" + str(ind) + ".png")
196 | # export region
197 | export_detected_region(image, poly=region, file_path=file_path, rectify=rectify)
198 | # note exported file path
199 | exported_file_paths.append(file_path)
200 |
201 | return exported_file_paths
202 |
203 |
204 | def export_extra_results(
205 | image,
206 | regions,
207 | heatmaps,
208 | file_name: str = "image",
209 | output_dir="output/",
210 | verticals=None,
211 | texts=None,
212 | ):
213 |     """save text detection results (heatmaps, box visualization, region coords) one by one
214 |     Args:
215 |         image: path to the image to be processed or numpy array or PIL image
216 |         regions: list of bboxes or polys to be drawn onto the image
217 |         heatmaps: dict with "text_score_heatmap" and "link_score_heatmap" images
218 |         file_name (str): export image file name
219 |         output_dir: folder to be exported
220 |     Return:
221 |         None
222 |     """
223 | # read/convert image
224 | image = read_image(image)
225 |
226 | # result directory
227 | res_file = os.path.join(output_dir, file_name + "_text_detection.txt")
228 | res_img_file = os.path.join(output_dir, file_name + "_text_detection.png")
229 | text_heatmap_file = os.path.join(output_dir, file_name + "_text_score_heatmap.png")
230 | link_heatmap_file = os.path.join(output_dir, file_name + "_link_score_heatmap.png")
231 |
232 | # create output dir
233 | create_dir(output_dir)
234 |
235 | # export heatmaps
236 | cv2.imwrite(text_heatmap_file, heatmaps["text_score_heatmap"])
237 | cv2.imwrite(link_heatmap_file, heatmaps["link_score_heatmap"])
238 |
239 | with open(res_file, "w") as f:
240 | for i, region in enumerate(regions):
241 | region = np.array(region).astype(np.int32).reshape((-1))
242 | strResult = ",".join([str(r) for r in region]) + "\r\n"
243 | f.write(strResult)
244 |
245 | region = region.reshape(-1, 2)
246 | cv2.polylines(
247 | image,
248 | [region.reshape((-1, 1, 2))],
249 | True,
250 | color=(0, 0, 255),
251 | thickness=2,
252 | )
253 |
254 | if texts is not None:
255 | font = cv2.FONT_HERSHEY_SIMPLEX
256 | font_scale = 0.5
257 | cv2.putText(
258 | image,
259 | "{}".format(texts[i]),
260 | (region[0][0] + 1, region[0][1] + 1),
261 | font,
262 | font_scale,
263 | (0, 0, 0),
264 | thickness=1,
265 | )
266 | cv2.putText(
267 | image,
268 | "{}".format(texts[i]),
269 | tuple(region[0]),
270 | font,
271 | font_scale,
272 | (0, 255, 255),
273 | thickness=1,
274 | )
275 |
276 | # Save result image
277 | cv2.imwrite(res_img_file, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
278 |
--------------------------------------------------------------------------------
/craft_text_detector/image_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2019-present NAVER Corp.
3 | MIT License
4 | """
5 |
6 | import cv2
7 | import numpy as np
8 |
9 |
 10 | def read_image(image):
 11 |     if type(image) == str:
 12 |         img = cv2.imread(image)
 13 |         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
 14 |     elif type(image) == bytes:
 15 |         nparr = np.frombuffer(image, np.uint8)
 16 |         img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
 17 |         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
 18 |     else:
 19 |         image = np.asarray(image)  # no-op for numpy arrays, converts PIL images
 20 |         if len(image.shape) == 2:  # grayscale
 21 |             img = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
 22 |         elif len(image.shape) == 3 and image.shape[2] == 3:  # RGB
 23 |             img = image
 24 |         elif len(image.shape) == 3 and image.shape[2] == 4:  # RGBA
 25 |             img = image[:, :, :3]
 26 |         else:
 27 |             raise ValueError("unsupported image shape: {}".format(image.shape))
 28 |     return img
29 |
30 |
31 | def normalizeMeanVariance(
32 | in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)
33 | ):
34 | # should be RGB order
35 | img = in_img.copy().astype(np.float32)
36 |
37 | img -= np.array(
38 | [mean[0] * 255.0, mean[1] * 255.0, mean[2] * 255.0], dtype=np.float32
39 | )
40 | img /= np.array(
41 | [variance[0] * 255.0, variance[1] * 255.0, variance[2] * 255.0],
42 | dtype=np.float32,
43 | )
44 | return img
45 |
46 |
47 | def denormalizeMeanVariance(
48 | in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)
49 | ):
50 | # should be RGB order
51 | img = in_img.copy()
52 | img *= variance
53 | img += mean
54 | img *= 255.0
55 | img = np.clip(img, 0, 255).astype(np.uint8)
56 | return img
57 |
58 |
59 | def resize_aspect_ratio(img, long_size, interpolation):
60 | height, width, channel = img.shape
61 |
62 | # set target image size
63 | target_size = long_size
64 |
65 | ratio = target_size / max(height, width)
66 |
67 | target_h, target_w = int(height * ratio), int(width * ratio)
68 | proc = cv2.resize(img, (target_w, target_h), interpolation=interpolation)
69 |
70 | # make canvas and paste image
71 | target_h32, target_w32 = target_h, target_w
72 | if target_h % 32 != 0:
73 | target_h32 = target_h + (32 - target_h % 32)
74 | if target_w % 32 != 0:
75 | target_w32 = target_w + (32 - target_w % 32)
76 | resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32)
77 | resized[0:target_h, 0:target_w, :] = proc
78 | target_h, target_w = target_h32, target_w32
79 |
80 | size_heatmap = (int(target_w / 2), int(target_h / 2))
81 |
82 | return resized, ratio, size_heatmap
83 |
84 |
85 | def cvt2HeatmapImg(img):
86 | img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
87 | img = cv2.applyColorMap(img, cv2.COLORMAP_JET)
88 | return img
89 |
--------------------------------------------------------------------------------
/craft_text_detector/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fcakyon/craft-text-detector/6b10d3e0d178679e35796cf4426abb083b354239/craft_text_detector/models/__init__.py
--------------------------------------------------------------------------------
/craft_text_detector/models/basenet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fcakyon/craft-text-detector/6b10d3e0d178679e35796cf4426abb083b354239/craft_text_detector/models/basenet/__init__.py
--------------------------------------------------------------------------------
/craft_text_detector/models/basenet/vgg16_bn.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.init as init
6 | from torchvision import models
7 | from torchvision.models.vgg import model_urls
8 |
9 |
10 | def init_weights(modules):
11 | for m in modules:
12 | if isinstance(m, nn.Conv2d):
13 | init.xavier_uniform_(m.weight.data)
14 | if m.bias is not None:
15 | m.bias.data.zero_()
16 | elif isinstance(m, nn.BatchNorm2d):
17 | m.weight.data.fill_(1)
18 | m.bias.data.zero_()
19 | elif isinstance(m, nn.Linear):
20 | m.weight.data.normal_(0, 0.01)
21 | m.bias.data.zero_()
22 |
23 |
24 | class vgg16_bn(torch.nn.Module):
25 | def __init__(self, pretrained=True, freeze=True):
26 | super(vgg16_bn, self).__init__()
27 | model_urls["vgg16_bn"] = model_urls["vgg16_bn"].replace("https://", "http://")
28 | vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
29 | self.slice1 = torch.nn.Sequential()
30 | self.slice2 = torch.nn.Sequential()
31 | self.slice3 = torch.nn.Sequential()
32 | self.slice4 = torch.nn.Sequential()
33 | self.slice5 = torch.nn.Sequential()
34 | for x in range(12): # conv2_2
35 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
36 | for x in range(12, 19): # conv3_3
37 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
38 | for x in range(19, 29): # conv4_3
39 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
40 | for x in range(29, 39): # conv5_3
41 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
42 |
43 | # fc6, fc7 without atrous conv
44 | self.slice5 = torch.nn.Sequential(
45 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
46 | nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
47 | nn.Conv2d(1024, 1024, kernel_size=1),
48 | )
49 |
50 | if not pretrained:
51 | init_weights(self.slice1.modules())
52 | init_weights(self.slice2.modules())
53 | init_weights(self.slice3.modules())
54 | init_weights(self.slice4.modules())
55 |
56 | init_weights(self.slice5.modules()) # no pretrained model for fc6 and fc7
57 |
58 | if freeze:
59 | for param in self.slice1.parameters(): # only first conv
60 | param.requires_grad = False
61 |
62 | def forward(self, X):
63 | h = self.slice1(X)
64 | h_relu2_2 = h
65 | h = self.slice2(h)
66 | h_relu3_2 = h
67 | h = self.slice3(h)
68 | h_relu4_3 = h
69 | h = self.slice4(h)
70 | h_relu5_3 = h
71 | h = self.slice5(h)
72 | h_fc7 = h
73 | vgg_outputs = namedtuple(
74 | "VggOutputs", ["fc7", "relu5_3", "relu4_3", "relu3_2", "relu2_2"]
75 | )
76 | out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
77 | return out
78 |
--------------------------------------------------------------------------------
/craft_text_detector/models/craftnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2019-present NAVER Corp.
3 | MIT License
4 | """
5 |
6 | # -*- coding: utf-8 -*-
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from craft_text_detector.models.basenet.vgg16_bn import vgg16_bn, init_weights
12 |
13 |
14 | class double_conv(nn.Module):
15 | def __init__(self, in_ch, mid_ch, out_ch):
16 | super(double_conv, self).__init__()
17 | self.conv = nn.Sequential(
18 | nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
19 | nn.BatchNorm2d(mid_ch),
20 | nn.ReLU(inplace=True),
21 | nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
22 | nn.BatchNorm2d(out_ch),
23 | nn.ReLU(inplace=True),
24 | )
25 |
26 | def forward(self, x):
27 | x = self.conv(x)
28 | return x
29 |
30 |
31 | class CraftNet(nn.Module):
32 | def __init__(self, pretrained=False, freeze=False):
33 | super(CraftNet, self).__init__()
34 |
35 | """ Base network """
36 | self.basenet = vgg16_bn(pretrained, freeze)
37 |
38 | """ U network """
39 | self.upconv1 = double_conv(1024, 512, 256)
40 | self.upconv2 = double_conv(512, 256, 128)
41 | self.upconv3 = double_conv(256, 128, 64)
42 | self.upconv4 = double_conv(128, 64, 32)
43 |
44 | num_class = 2
45 | self.conv_cls = nn.Sequential(
46 | nn.Conv2d(32, 32, kernel_size=3, padding=1),
47 | nn.ReLU(inplace=True),
48 | nn.Conv2d(32, 32, kernel_size=3, padding=1),
49 | nn.ReLU(inplace=True),
50 | nn.Conv2d(32, 16, kernel_size=3, padding=1),
51 | nn.ReLU(inplace=True),
52 | nn.Conv2d(16, 16, kernel_size=1),
53 | nn.ReLU(inplace=True),
54 | nn.Conv2d(16, num_class, kernel_size=1),
55 | )
56 |
57 | init_weights(self.upconv1.modules())
58 | init_weights(self.upconv2.modules())
59 | init_weights(self.upconv3.modules())
60 | init_weights(self.upconv4.modules())
61 | init_weights(self.conv_cls.modules())
62 |
63 | def forward(self, x):
64 | """ Base network """
65 | sources = self.basenet(x)
66 |
67 | """ U network """
68 | y = torch.cat([sources[0], sources[1]], dim=1)
69 | y = self.upconv1(y)
70 |
71 | y = F.interpolate(
72 | y, size=sources[2].size()[2:], mode="bilinear", align_corners=False
73 | )
74 | y = torch.cat([y, sources[2]], dim=1)
75 | y = self.upconv2(y)
76 |
77 | y = F.interpolate(
78 | y, size=sources[3].size()[2:], mode="bilinear", align_corners=False
79 | )
80 | y = torch.cat([y, sources[3]], dim=1)
81 | y = self.upconv3(y)
82 |
83 | y = F.interpolate(
84 | y, size=sources[4].size()[2:], mode="bilinear", align_corners=False
85 | )
86 | y = torch.cat([y, sources[4]], dim=1)
87 | feature = self.upconv4(y)
88 |
89 | y = self.conv_cls(feature)
90 |
91 | return y.permute(0, 2, 3, 1), feature
92 |
93 |
94 | if __name__ == "__main__":
95 | model = CraftNet(pretrained=True).cuda()
96 | output, _ = model(torch.randn(1, 3, 768, 768).cuda())
97 | print(output.shape)
98 |
--------------------------------------------------------------------------------
/craft_text_detector/models/refinenet.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2019-present NAVER Corp.
3 | MIT License
4 | """
5 |
6 | # -*- coding: utf-8 -*-
7 | import torch
8 | import torch.nn as nn
9 | from craft_text_detector.models.basenet.vgg16_bn import init_weights
10 |
11 |
12 | class RefineNet(nn.Module):
13 | def __init__(self):
14 | super(RefineNet, self).__init__()
15 |
16 | self.last_conv = nn.Sequential(
17 | nn.Conv2d(34, 64, kernel_size=3, padding=1),
18 | nn.BatchNorm2d(64),
19 | nn.ReLU(inplace=True),
20 | nn.Conv2d(64, 64, kernel_size=3, padding=1),
21 | nn.BatchNorm2d(64),
22 | nn.ReLU(inplace=True),
23 | nn.Conv2d(64, 64, kernel_size=3, padding=1),
24 | nn.BatchNorm2d(64),
25 | nn.ReLU(inplace=True),
26 | )
27 |
28 | self.aspp1 = nn.Sequential(
29 | nn.Conv2d(64, 128, kernel_size=3, dilation=6, padding=6),
30 | nn.BatchNorm2d(128),
31 | nn.ReLU(inplace=True),
32 | nn.Conv2d(128, 128, kernel_size=1),
33 | nn.BatchNorm2d(128),
34 | nn.ReLU(inplace=True),
35 | nn.Conv2d(128, 1, kernel_size=1),
36 | )
37 |
38 | self.aspp2 = nn.Sequential(
39 | nn.Conv2d(64, 128, kernel_size=3, dilation=12, padding=12),
40 | nn.BatchNorm2d(128),
41 | nn.ReLU(inplace=True),
42 | nn.Conv2d(128, 128, kernel_size=1),
43 | nn.BatchNorm2d(128),
44 | nn.ReLU(inplace=True),
45 | nn.Conv2d(128, 1, kernel_size=1),
46 | )
47 |
48 | self.aspp3 = nn.Sequential(
49 | nn.Conv2d(64, 128, kernel_size=3, dilation=18, padding=18),
50 | nn.BatchNorm2d(128),
51 | nn.ReLU(inplace=True),
52 | nn.Conv2d(128, 128, kernel_size=1),
53 | nn.BatchNorm2d(128),
54 | nn.ReLU(inplace=True),
55 | nn.Conv2d(128, 1, kernel_size=1),
56 | )
57 |
58 | self.aspp4 = nn.Sequential(
59 | nn.Conv2d(64, 128, kernel_size=3, dilation=24, padding=24),
60 | nn.BatchNorm2d(128),
61 | nn.ReLU(inplace=True),
62 | nn.Conv2d(128, 128, kernel_size=1),
63 | nn.BatchNorm2d(128),
64 | nn.ReLU(inplace=True),
65 | nn.Conv2d(128, 1, kernel_size=1),
66 | )
67 |
68 | init_weights(self.last_conv.modules())
69 | init_weights(self.aspp1.modules())
70 | init_weights(self.aspp2.modules())
71 | init_weights(self.aspp3.modules())
72 | init_weights(self.aspp4.modules())
73 |
74 | def forward(self, y, upconv4):
75 | refine = torch.cat([y.permute(0, 3, 1, 2), upconv4], dim=1)
76 | refine = self.last_conv(refine)
77 |
78 | aspp1 = self.aspp1(refine)
79 | aspp2 = self.aspp2(refine)
80 | aspp3 = self.aspp3(refine)
81 | aspp4 = self.aspp4(refine)
82 |
83 | # out = torch.add([aspp1, aspp2, aspp3, aspp4], dim=1)
84 | out = aspp1 + aspp2 + aspp3 + aspp4
85 | return out.permute(0, 2, 3, 1) # , refine.permute(0,2,3,1)
86 |
--------------------------------------------------------------------------------
/craft_text_detector/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | import cv2
5 | import numpy as np
6 |
7 | import craft_text_detector.craft_utils as craft_utils
8 | import craft_text_detector.image_utils as image_utils
9 | import craft_text_detector.torch_utils as torch_utils
10 |
11 |
12 | def get_prediction(
13 | image,
14 | craft_net,
15 | refine_net=None,
16 | text_threshold: float = 0.7,
17 | link_threshold: float = 0.4,
18 | low_text: float = 0.4,
19 | cuda: bool = False,
20 | long_size: int = 1280,
21 | poly: bool = True,
22 | ):
23 | """
 24 |         Arguments:
 25 |             image: path to the image to be processed or numpy array or PIL image
 26 |             craft_net: craft net model
 27 |             refine_net: refine net model
 28 |             text_threshold: text confidence threshold
 29 |             link_threshold: link confidence threshold
 30 |             low_text: text low-bound score
 31 |             cuda: Use cuda for inference
 32 |             long_size: desired longest image size for inference
 33 |             poly: enable polygon type output; a detection falls back to its
 34 |                 box when no polygon can be generated for it
 35 |         Output:
 36 |         {"boxes": list of coords of points of predicted boxes,
 37 |         "boxes_as_ratios": list of coords of points of predicted boxes as ratios of image size,
 38 |         "polys": list of coords of points of predicted polys,
 39 |         "polys_as_ratios": list of coords of points of predicted polys as ratios of image size,
 40 |         "heatmaps": visualizations of the detected characters/links,
 41 |         "times": elapsed times of the sub modules, in seconds}
 42 |
43 | """
44 | t0 = time.time()
45 |
46 | # read/convert image
47 | image = image_utils.read_image(image)
48 |
49 | # resize
50 | img_resized, target_ratio, size_heatmap = image_utils.resize_aspect_ratio(
51 | image, long_size, interpolation=cv2.INTER_LINEAR
52 | )
53 | ratio_h = ratio_w = 1 / target_ratio
54 | resize_time = time.time() - t0
55 | t0 = time.time()
56 |
57 | # preprocessing
58 | x = image_utils.normalizeMeanVariance(img_resized)
59 | x = torch_utils.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w]
60 | x = torch_utils.Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w]
61 | if cuda:
62 | x = x.cuda()
63 | preprocessing_time = time.time() - t0
64 | t0 = time.time()
65 |
66 | # forward pass
67 | with torch_utils.no_grad():
68 | y, feature = craft_net(x)
69 | craftnet_time = time.time() - t0
70 | t0 = time.time()
71 |
72 | # make score and link map
73 | score_text = y[0, :, :, 0].cpu().data.numpy()
74 | score_link = y[0, :, :, 1].cpu().data.numpy()
75 |
76 | # refine link
77 | if refine_net is not None:
78 | with torch_utils.no_grad():
79 | y_refiner = refine_net(y, feature)
80 | score_link = y_refiner[0, :, :, 0].cpu().data.numpy()
81 | refinenet_time = time.time() - t0
82 | t0 = time.time()
83 |
84 | # Post-processing
85 | boxes, polys = craft_utils.getDetBoxes(
86 | score_text, score_link, text_threshold, link_threshold, low_text, poly
87 | )
88 |
89 | # coordinate adjustment
90 | boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
91 | polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
92 | for k in range(len(polys)):
93 | if polys[k] is None:
94 | polys[k] = boxes[k]
95 |
96 | # get image size
97 | img_height = image.shape[0]
98 | img_width = image.shape[1]
99 |
100 | # calculate box coords as ratios to image size
101 | boxes_as_ratio = []
102 | for box in boxes:
103 | boxes_as_ratio.append(box / [img_width, img_height])
104 | boxes_as_ratio = np.array(boxes_as_ratio)
105 |
106 | # calculate poly coords as ratios to image size
107 | polys_as_ratio = []
108 | for poly in polys:
109 | polys_as_ratio.append(poly / [img_width, img_height])
110 | polys_as_ratio = np.array(polys_as_ratio)
111 |
112 | text_score_heatmap = image_utils.cvt2HeatmapImg(score_text)
113 | link_score_heatmap = image_utils.cvt2HeatmapImg(score_link)
114 |
115 | postprocess_time = time.time() - t0
116 |
117 | times = {
118 | "resize_time": resize_time,
119 | "preprocessing_time": preprocessing_time,
120 | "craftnet_time": craftnet_time,
121 | "refinenet_time": refinenet_time,
122 | "postprocess_time": postprocess_time,
123 | }
124 |
125 | return {
126 | "boxes": boxes,
127 | "boxes_as_ratios": boxes_as_ratio,
128 | "polys": polys,
129 | "polys_as_ratios": polys_as_ratio,
130 | "heatmaps": {
131 | "text_score_heatmap": text_score_heatmap,
132 | "link_score_heatmap": link_score_heatmap,
133 | },
134 | "times": times,
135 | }
136 |
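For reference, a minimal call sketch for get_prediction; the threshold values match the signature defaults above, and the loader helpers are the ones exercised in tests/test_helpers.py:

```python
# Minimal usage sketch for get_prediction; paths and values are illustrative.
from craft_text_detector import get_prediction, load_craftnet_model, read_image

image = read_image("figures/idcard.png")     # file path, numpy array, or PIL image
craft_net = load_craftnet_model(cuda=False)  # downloads weights on first use

prediction_result = get_prediction(
    image=image,
    craft_net=craft_net,
    refine_net=None,       # or load_refinenet_model(cuda=False) to refine links
    text_threshold=0.7,
    link_threshold=0.4,
    low_text=0.4,
    cuda=False,
    long_size=1280,
)
print(len(prediction_result["boxes"]), "boxes detected")
print(prediction_result["times"])  # per-stage timings in seconds
```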
--------------------------------------------------------------------------------
/craft_text_detector/torch_utils.py:
--------------------------------------------------------------------------------
1 | # Thin re-export layer: all torch imports used by the package live here.
2 | # Note: torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4
3 | # and is re-exported only for backward compatibility.
4 | from torch import from_numpy, load, no_grad
5 | from torch.autograd import Variable
6 | from torch.backends.cudnn import benchmark as cudnn_benchmark
7 | from torch.cuda import empty_cache as empty_cuda_cache
8 | from torch.nn import DataParallel
9 | 
--------------------------------------------------------------------------------
/figures/craft_example.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fcakyon/craft-text-detector/6b10d3e0d178679e35796cf4426abb083b354239/figures/craft_example.gif
--------------------------------------------------------------------------------
/figures/idcard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fcakyon/craft-text-detector/6b10d3e0d178679e35796cf4426abb083b354239/figures/idcard.png
--------------------------------------------------------------------------------
/figures/idcard2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fcakyon/craft-text-detector/6b10d3e0d178679e35796cf4426abb083b354239/figures/idcard2.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.6.0
2 | torchvision>=0.7.0
3 | opencv-python>=3.4.8.29,<4.5.4.62
4 | scipy>=1.3.2
5 | gdown>=3.10.1
6 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 | exclude = .git,__pycache__,docs/source/conf.py,build,dist
4 | ignore = I101,I201,F401,F403,S001,D100,D101,D102,D103,D104,D105,D106,D107,D200,D205,D400,W504,D202,E203,W503,B006
5 | inline-quotes = "
6 |
7 | [mypy]
8 | ignore_missing_imports = True
9 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 | import re
4 |
5 | import setuptools
6 |
7 |
8 | def get_long_description():
9 | base_dir = os.path.abspath(os.path.dirname(__file__))
10 | with io.open(os.path.join(base_dir, "README.md"), encoding="utf-8") as f:
11 | return f.read()
12 |
13 |
14 | def get_requirements():
15 | with open("requirements.txt") as f:
16 | return f.read().splitlines()
17 |
18 |
19 | def get_version():
20 | current_dir = os.path.abspath(os.path.dirname(__file__))
21 | version_file = os.path.join(current_dir, "craft_text_detector", "__init__.py")
22 | with io.open(version_file, encoding="utf-8") as f:
23 | return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', f.read(), re.M).group(1)
24 |
25 |
26 | setuptools.setup(
27 | name="craft-text-detector",
28 | version=get_version(),
29 | author="Fatih Cagatay Akyon",
30 | license="MIT",
31 | description="Fast and accurate text detection library built on CRAFT implementation",
32 | long_description=get_long_description(),
33 | long_description_content_type="text/markdown",
34 |     url="https://github.com/fcakyon/craft-text-detector",
35 | packages=setuptools.find_packages(exclude=["tests"]),
36 | install_requires=get_requirements(),
37 | python_requires=">=3.7",
38 | classifiers=[
39 | "Development Status :: 5 - Production/Stable",
40 | "License :: OSI Approved :: MIT License",
41 | "Operating System :: OS Independent",
42 | "Intended Audience :: Developers",
43 | "Intended Audience :: Science/Research",
44 | "Programming Language :: Python :: 3",
45 | "Programming Language :: Python :: 3.6",
46 | "Programming Language :: Python :: 3.7",
47 | "Programming Language :: Python :: 3.8",
48 | "Topic :: Software Development :: Libraries",
49 | "Topic :: Software Development :: Libraries :: Python Modules",
50 | "Topic :: Education",
51 | "Topic :: Scientific/Engineering",
52 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
53 | "Topic :: Scientific/Engineering :: Image Recognition",
54 | ],
55 | keywords="machine-learning, deep-learning, ml, pytorch, text, text-detection, craft",
56 | )
57 |
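As a quick illustration, get_version() relies on a module-level assignment in craft_text_detector/__init__.py that its regex can match; a self-contained check of that pattern (the version string here is hypothetical):

```python
# Demonstrates the pattern get_version()'s regex expects; "0.3.1" is made up.
import re

sample = '__version__ = "0.3.1"\n'
match = re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', sample, re.M)
assert match is not None and match.group(1) == "0.3.1"
```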
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fcakyon/craft-text-detector/6b10d3e0d178679e35796cf4426abb083b354239/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_craft.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from craft_text_detector import Craft
4 |
5 |
6 | class TestCraftTextDetector(unittest.TestCase):
7 | image_path = "figures/idcard.png"
8 |
9 | def test_init(self):
10 | craft = Craft(
11 | output_dir=None,
12 | rectify=True,
13 | export_extra=False,
14 | text_threshold=0.7,
15 | link_threshold=0.4,
16 | low_text=0.4,
17 | cuda=False,
18 | long_size=720,
19 | refiner=False,
20 | crop_type="poly",
21 | )
22 | self.assertTrue(craft)
23 |
24 | def test_load_craftnet_model(self):
25 | # init craft
26 | craft = Craft(
27 | output_dir=None,
28 | rectify=True,
29 | export_extra=False,
30 | text_threshold=0.7,
31 | link_threshold=0.4,
32 | low_text=0.4,
33 | cuda=False,
34 | long_size=720,
35 | refiner=False,
36 | crop_type="poly",
37 | )
38 | # remove craftnet model
39 | craft.craft_net = None
40 | # load craftnet model
41 | craft.load_craftnet_model()
42 | self.assertTrue(craft.craft_net)
43 |
44 | def test_load_refinenet_model(self):
45 | # init craft
46 | craft = Craft(
47 | output_dir=None,
48 | rectify=True,
49 | export_extra=False,
50 | text_threshold=0.7,
51 | link_threshold=0.4,
52 | low_text=0.4,
53 | cuda=False,
54 | long_size=720,
55 | refiner=False,
56 | crop_type="poly",
57 | )
58 | # remove refinenet model
59 | craft.refine_net = None
60 | # load refinenet model
61 | craft.load_refinenet_model()
62 | self.assertTrue(craft.refine_net)
63 |
64 | def test_detect_text(self):
65 | # init craft
66 | craft = Craft(
67 | output_dir=None,
68 | rectify=True,
69 | export_extra=False,
70 | text_threshold=0.7,
71 | link_threshold=0.4,
72 | low_text=0.4,
73 | cuda=False,
74 | long_size=720,
75 | refiner=False,
76 | crop_type="poly",
77 | )
78 | # detect text
79 | prediction_result = craft.detect_text(image=self.image_path)
80 |
81 | self.assertEqual(len(prediction_result["boxes"]), 52)
82 | self.assertEqual(len(prediction_result["boxes"][0]), 4)
83 | self.assertEqual(len(prediction_result["boxes"][0][0]), 2)
84 | self.assertEqual(int(prediction_result["boxes"][0][0][0]), 115)
85 |
86 | # init craft
87 | craft = Craft(
88 | output_dir=None,
89 | rectify=True,
90 | export_extra=False,
91 | text_threshold=0.7,
92 | link_threshold=0.4,
93 | low_text=0.4,
94 | cuda=False,
95 | long_size=720,
96 | refiner=True,
97 | crop_type="poly",
98 | )
99 | # detect text
100 | prediction_result = craft.detect_text(image=self.image_path)
101 |
102 | self.assertEqual(len(prediction_result["boxes"]), 19)
103 | self.assertEqual(len(prediction_result["boxes"][0]), 4)
104 | self.assertEqual(len(prediction_result["boxes"][0][0]), 2)
105 | self.assertEqual(int(prediction_result["boxes"][0][2][0]), 661)
106 |
107 | # init craft
108 | craft = Craft(
109 | output_dir=None,
110 | rectify=False,
111 | export_extra=False,
112 | text_threshold=0.7,
113 | link_threshold=0.4,
114 | low_text=0.4,
115 | cuda=False,
116 | long_size=720,
117 | refiner=False,
118 | crop_type="box",
119 | )
120 | # detect text
121 | prediction_result = craft.detect_text(image=self.image_path)
122 |
123 | self.assertEqual(len(prediction_result["boxes"]), 52)
124 | self.assertEqual(len(prediction_result["boxes"][0]), 4)
125 | self.assertEqual(len(prediction_result["boxes"][0][0]), 2)
126 | self.assertEqual(int(prediction_result["boxes"][0][2][0]), 244)
127 |
128 | # init craft
129 | craft = Craft(
130 | output_dir=None,
131 | rectify=False,
132 | export_extra=False,
133 | text_threshold=0.7,
134 | link_threshold=0.4,
135 | low_text=0.4,
136 | cuda=False,
137 | long_size=720,
138 | refiner=True,
139 | crop_type="box",
140 | )
141 | # detect text
142 | prediction_result = craft.detect_text(image=self.image_path)
143 |
144 | self.assertEqual(len(prediction_result["boxes"]), 19)
145 | self.assertEqual(len(prediction_result["boxes"][0]), 4)
146 | self.assertEqual(len(prediction_result["boxes"][0][0]), 2)
147 | self.assertEqual(int(prediction_result["boxes"][0][2][0]), 661)
148 |
149 |
150 | if __name__ == "__main__":
151 | unittest.main()
152 |
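Taken together, the cases above map out the Craft constructor surface; a minimal end-to-end sketch using the same parameters as the tests:

```python
# End-to-end sketch mirroring the test parameters above.
from craft_text_detector import Craft

craft = Craft(
    output_dir=None,    # set a directory path to export crops and visualizations
    rectify=True,
    export_extra=False,
    text_threshold=0.7,
    link_threshold=0.4,
    low_text=0.4,
    cuda=False,
    long_size=720,
    refiner=False,      # True loads the link refiner; the tests above show it
    crop_type="poly",   #   merging 52 word-level boxes into 19 line-level ones
)
prediction_result = craft.detect_text(image="figures/idcard.png")
print(len(prediction_result["boxes"]), "regions detected")
```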
--------------------------------------------------------------------------------
/tests/test_helpers.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from pathlib import Path
3 | from tempfile import TemporaryDirectory
4 |
5 | from craft_text_detector import (
6 | export_detected_regions,
7 | export_extra_results,
8 | get_prediction,
9 | load_craftnet_model,
10 | load_refinenet_model,
11 | read_image,
12 | )
13 |
14 |
15 | class TestCraftTextDetectorHelpers(unittest.TestCase):
16 | image_path = "figures/idcard.png"
17 |
18 | def test_load_craftnet_model(self):
19 | craft_net = load_craftnet_model(cuda=False)
20 | self.assertTrue(craft_net)
21 |
22 | with TemporaryDirectory() as dir_name:
23 | weight_path = Path(dir_name, "weights.pth")
24 | self.assertFalse(weight_path.is_file())
25 | load_craftnet_model(cuda=False, weight_path=weight_path)
26 | self.assertTrue(weight_path.is_file())
27 |
28 | def test_load_refinenet_model(self):
29 | refine_net = load_refinenet_model(cuda=False)
30 | self.assertTrue(refine_net)
31 |
32 | with TemporaryDirectory() as dir_name:
33 | weight_path = Path(dir_name, "weights.pth")
34 | self.assertFalse(weight_path.is_file())
35 | load_refinenet_model(cuda=False, weight_path=weight_path)
36 | self.assertTrue(weight_path.is_file())
37 |
38 | def test_read_image(self):
39 | image = read_image(self.image_path)
40 |         self.assertEqual(image.shape, (500, 786, 3))
41 |
42 | def test_get_prediction(self):
43 | # load image
44 | image = read_image(self.image_path)
45 |
46 | # load models
47 | craft_net = load_craftnet_model()
48 | refine_net = None
49 |
50 | # perform prediction
51 | text_threshold = 0.9
52 | link_threshold = 0.2
53 | low_text = 0.2
54 | cuda = False
55 | prediction_result = get_prediction(
56 | image=image,
57 | craft_net=craft_net,
58 | refine_net=refine_net,
59 | text_threshold=text_threshold,
60 | link_threshold=link_threshold,
61 | low_text=low_text,
62 | cuda=cuda,
63 | long_size=720,
64 | )
65 |
66 | self.assertEqual(len(prediction_result["boxes"]), 35)
67 | self.assertEqual(len(prediction_result["boxes"][0]), 4)
68 | self.assertEqual(len(prediction_result["boxes"][0][0]), 2)
69 | self.assertEqual(int(prediction_result["boxes"][0][0][0]), 111)
70 | self.assertEqual(len(prediction_result["polys"]), 35)
71 | self.assertEqual(
72 | prediction_result["heatmaps"]["text_score_heatmap"].shape, (240, 368, 3)
73 | )
74 |
75 | def test_get_prediction_without_read_image(self):
76 | # set image filepath
77 | image = self.image_path
78 |
79 | # load models
80 | craft_net = load_craftnet_model()
81 | refine_net = None
82 |
83 | # perform prediction
84 | text_threshold = 0.9
85 | link_threshold = 0.2
86 | low_text = 0.2
87 | cuda = False
88 | prediction_result = get_prediction(
89 | image=image,
90 | craft_net=craft_net,
91 | refine_net=refine_net,
92 | text_threshold=text_threshold,
93 | link_threshold=link_threshold,
94 | low_text=low_text,
95 | cuda=cuda,
96 | long_size=720,
97 | )
98 |
99 | self.assertEqual(len(prediction_result["boxes"]), 35)
100 | self.assertEqual(len(prediction_result["boxes"][0]), 4)
101 | self.assertEqual(len(prediction_result["boxes"][0][0]), 2)
102 | self.assertEqual(int(prediction_result["boxes"][0][0][0]), 111)
103 | self.assertEqual(len(prediction_result["polys"]), 35)
104 | self.assertEqual(
105 | prediction_result["heatmaps"]["text_score_heatmap"].shape, (240, 368, 3)
106 | )
107 |
108 |
109 | if __name__ == "__main__":
110 | unittest.main()
111 |
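As the two loader tests above pin down, passing a weight_path that does not yet exist makes the loader download the weights to that exact location, so callers can control where models are cached:

```python
# Sketch: download-on-demand weight caching, mirroring test_load_craftnet_model.
from pathlib import Path
from tempfile import TemporaryDirectory

from craft_text_detector import load_craftnet_model

with TemporaryDirectory() as dir_name:
    weight_path = Path(dir_name, "weights.pth")
    craft_net = load_craftnet_model(cuda=False, weight_path=weight_path)
    assert weight_path.is_file()  # weights were downloaded and cached here
```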
--------------------------------------------------------------------------------