├── .github
│   ├── ISSUE_TEMPLATE
│   │   └── bug-report.md
│   └── workflows
│       ├── ci.yml
│       ├── publish_conda.yml
│       └── publish_pypi.yml
├── .gitignore
├── LICENSE
├── README.md
├── conda
│   ├── conda_build_config.yaml
│   └── meta.yaml
├── craft_text_detector
│   ├── __init__.py
│   ├── craft_utils.py
│   ├── file_utils.py
│   ├── imgproc.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── basenet
│   │   │   ├── __init__.py
│   │   │   └── vgg16_bn.py
│   │   ├── craftnet.py
│   │   └── refinenet.py
│   └── predict.py
├── environment.yml
├── figures
│   ├── craft_example.gif
│   └── idcard.png
├── requirements.txt
├── setup.cfg
├── setup.py
└── tests
    ├── __init__.py
    └── test_craft.py
/.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "Bug Report" 3 | about: Submit a bug report to help us improve craft-text-detector 4 | 5 | --- 6 | 7 | ## 🐛 Bug 8 | 9 | 10 | 11 | ## To Reproduce 12 | 13 | Steps to reproduce the behavior: 14 | 15 | 1. 16 | 1. 17 | 1. 18 | 19 | 20 | 21 | ## Expected behavior 22 | 23 | 24 | 25 | ## Environment 26 | 27 | - craft-text-detector version (e.g., 0.1.7): 28 | - Python version (e.g., 3.6/3.7): 29 | - OS (e.g., Linux/Windows/MacOS): 30 | - How you installed craft-text-detector (`conda`, `pip`, source): 31 | - Any other relevant information: 32 | 33 | ## Additional context 34 | 35 | 36 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ${{ matrix.operating-system }} 13 | strategy: 14 | matrix: 15 | operating-system: [ubuntu-latest, windows-latest, macos-latest] 16 | python-version: [3.5, 3.6, 3.7, 3.8] 17 | fail-fast: false 18 | 19 | steps: 20 | - name: Checkout 21 | uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v1.1.1 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Restore Ubuntu cache 27 | uses: actions/cache@v1 28 | if: matrix.operating-system == 'ubuntu-latest' 29 | with: 30 | path: ~/.cache/pip 31 | key: ${{ matrix.operating-system }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}} 32 | restore-keys: ${{ matrix.operating-system }}-${{ matrix.python-version }}- 33 | - name: Restore MacOS cache 34 | uses: actions/cache@v1 35 | if: matrix.operating-system == 'macos-latest' 36 | with: 37 | path: ~/Library/Caches/pip 38 | key: ${{ matrix.operating-system }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}} 39 | restore-keys: ${{ matrix.operating-system }}-${{ matrix.python-version }}- 40 | - name: Restore Windows cache 41 | uses: actions/cache@v1 42 | if: matrix.operating-system == 'windows-latest' 43 | with: 44 | path: ~\AppData\Local\pip\Cache 45 | key: ${{ matrix.operating-system }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}} 46 | restore-keys: ${{ matrix.operating-system }}-${{ matrix.python-version }}- 47 | - name: Update pip 48 | run: python -m pip install --upgrade pip 49 | - name: Install PyTorch on Linux and Windows 50 | if: > 51 | matrix.operating-system == 'ubuntu-latest' || 52 | matrix.operating-system == 'windows-latest' 53 | run: > 54 | pip install torch==1.4.0+cpu torchvision==0.5.0+cpu 55 | -f https://download.pytorch.org/whl/torch_stable.html 56 | - name: Install PyTorch on MacOS 57 | if: matrix.operating-system == 'macos-latest' 58 | run: pip install torch==1.4.0 torchvision==0.5.0 59 | - name: Install dependencies 60 | run: | 61 | python -m pip install --upgrade pip 62 |
pip install -r requirements.txt 63 | - name: Lint with flake8 64 | run: | 65 | pip install flake8 66 | # stop the build if there are Python syntax errors or undefined names 67 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 68 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 69 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 70 | - name: Test with unittest 71 | run: | 72 | python -m unittest 73 | -------------------------------------------------------------------------------- /.github/workflows/publish_conda.yml: -------------------------------------------------------------------------------- 1 | name: Publish Conda Package 2 | 3 | on: 4 | release: 5 | types: [published, edited] 6 | 7 | jobs: 8 | publish: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@master 12 | - name: publish-to-conda 13 | uses: fcakyon/conda-package-publish-action@master 14 | with: 15 | subdir: 'conda' 16 | anacondatoken: ${{ secrets.ANACONDA_TOKEN }} 17 | platforms: 'osx linux win' -------------------------------------------------------------------------------- /.github/workflows/publish_pypi.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [published, edited] 6 | 7 | jobs: 8 | deploy: 9 | 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Set up Python 15 | uses: actions/setup-python@v1 16 | with: 17 | python-version: '3.x' 18 | - name: Install dependencies 19 | run: | 20 | python -m pip install --upgrade pip 21 | pip install setuptools wheel twine 22 | - name: Build and publish 23 | env: 24 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 25 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 26 | run: | 27 | python setup.py sdist bdist_wheel 28 | twine upload dist/* 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.swp 3 | *.pkl 4 | *.pth 5 | result* 6 | weights* 7 | .vscode 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 107 | __pypackages__/ 108 | 109 | # Celery stuff 110 | celerybeat-schedule 111 | celerybeat.pid 112 | 113 | # SageMath parsed files 114 | *.sage.py 115 | 116 | # Environments 117 | .env 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | 143 | # pytype static type analyzer 144 | .pytype/ 145 | 146 | # Cython debug symbols 147 | cython_debug/ 148 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019-present NAVER Corp. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 
20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repo is deprecated. Please refer to [new up-to-date repo](https://github.com/fcakyon/craft-text-detector). 2 | -------------------------------------------------------------------------------- /conda/conda_build_config.yaml: -------------------------------------------------------------------------------- 1 | python: 2 | - 3.5 3 | - 3.6 4 | - 3.7 5 | -------------------------------------------------------------------------------- /conda/meta.yaml: -------------------------------------------------------------------------------- 1 | {% set data = load_setup_py_data() %} 2 | 3 | package: 4 | name: craft-text-detector 5 | version: {{ data['version'] }} 6 | 7 | source: 8 | path: .. 9 | 10 | build: 11 | number: 0 12 | script: python setup.py install --single-version-externally-managed --record=record.txt 13 | 14 | requirements: 15 | build: 16 | - python 17 | - numpy>=1.11.1 18 | - scipy 19 | - opencv 20 | - pytorch>=0.4.1 21 | - torchvision>=0.2.1 22 | 23 | run: 24 | - python 25 | - numpy>=1.11.1 26 | - scipy 27 | - opencv 28 | - gdown>=3.10.1 29 | - pytorch>=0.4.1 30 | - torchvision>=0.2.1 31 | 32 | test: 33 | imports: 34 | - craft_text_detector 35 | 36 | about: 37 | home: {{ data['url'] }} 38 | license: {{ data['license'] }} 39 | summary: {{ data['description'] }} 40 | -------------------------------------------------------------------------------- /craft_text_detector/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | __version__ = "0.1.8" 4 | 5 | from craft_text_detector.imgproc import read_image 6 | 7 | from craft_text_detector.file_utils import (export_detected_regions, 8 | export_extra_results) 9 | 10 | from craft_text_detector.predict import (load_craftnet_model, 11 | load_refinenet_model, 12 | get_prediction) 13 | 14 | # load craft model 15 | craft_net = load_craftnet_model() 16 | 17 | 18 | # detect texts 19 | def detect_text(image_path, 20 | output_dir=None, 21 | rectify=True, 22 | export_extra=True, 23 | text_threshold=0.7, 24 | link_threshold=0.4, 25 | low_text=0.4, 26 | cuda=False, 27 | long_size=1280, 28 | show_time=False, 29 | refiner=True, 30 | crop_type="poly"): 31 | """ 32 | Arguments: 33 | image_path: path to the image to be processed 34 | output_dir: path to the results to be exported 35 | rectify: rectify detected polygon by affine transform 36 | export_extra: export heatmap, detection points, box visualization 37 | text_threshold: text confidence threshold 38 | link_threshold: link confidence threshold 39 | low_text: text low-bound score 40 | cuda: Use cuda for inference 41 | long_size: desired longest image size for inference 42 | show_time: show processing time 43 | refiner: enable link refiner 44 | crop_type: crop regions by detected boxes or polys ("poly" or "box") 45 | Output: 46 | {"boxes": list of coords of points of predicted boxes, 47 | "boxes_as_ratios": list of coords of points of predicted boxes as ratios of image size, 48 | "polys": list of coords of points of predicted polys, 49 | "polys_as_ratios": list of coords of points of predicted polys as ratios of image size, 50 | "heatmaps": visualization of the detected characters/links, 51 | "text_crop_paths": list of paths of the exported text boxes/polys} 52 | """ 53 | # load image 54 | image = read_image(image_path) 55 | 56 | # load refiner
if required 57 | if refiner: 58 | refine_net = load_refinenet_model() 59 | else: 60 | refine_net = None 61 | 62 | # perform prediction 63 | prediction_result = get_prediction(image=image, 64 | craft_net=craft_net, 65 | refine_net=refine_net, 66 | text_threshold=text_threshold, 67 | link_threshold=link_threshold, 68 | low_text=low_text, 69 | cuda=cuda, 70 | long_size=long_size, 71 | show_time=show_time) 72 | 73 | # arrange regions 74 | if crop_type == "box": 75 | regions = prediction_result["boxes"] 76 | elif crop_type == "poly": 77 | regions = prediction_result["polys"] 78 | else: 79 | raise ValueError("crop_type can only be 'poly' or 'box'") 80 | 81 | # export if output_dir is given 82 | prediction_result["text_crop_paths"] = [] 83 | if output_dir is not None: 84 | # export detected text regions 85 | exported_file_paths = export_detected_regions(image_path=image_path, 86 | image=image, 87 | regions=regions, 88 | output_dir=output_dir, 89 | rectify=rectify) 90 | prediction_result["text_crop_paths"] = exported_file_paths 91 | 92 | # export heatmap, detection points, box visualization 93 | if export_extra: 94 | export_extra_results(image_path=image_path, 95 | image=image, 96 | regions=regions, 97 | heatmaps=prediction_result["heatmaps"], 98 | output_dir=output_dir) 99 | 100 | # return prediction results 101 | return prediction_result 102 | -------------------------------------------------------------------------------- /craft_text_detector/craft_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 3 | MIT License 4 | """ 5 | 6 | # -*- coding: utf-8 -*- 7 | import numpy as np 8 | import cv2 9 | import math 10 | 11 | """ auxiliary functions """ 12 | # unwarp coordinates 13 | 14 | 15 | def warpCoord(Minv, pt): 16 | out = np.matmul(Minv, (pt[0], pt[1], 1)) 17 | return np.array([out[0] / out[2], out[1] / out[2]]) 18 | 19 | 20 | """ end of auxiliary functions """ 21 | 22 | 23 | def getDetBoxes_core(textmap, linkmap, 24 | text_threshold, link_threshold, low_text): 25 | # prepare data 26 | linkmap = linkmap.copy() 27 | textmap = textmap.copy() 28 | img_h, img_w = textmap.shape 29 | 30 | """ labeling method """ 31 | ret, text_score = cv2.threshold(textmap, low_text, 1, 0) 32 | ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0) 33 | 34 | text_score_comb = np.clip(text_score + link_score, 0, 1) 35 | nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats( 36 | text_score_comb.astype(np.uint8), connectivity=4) 37 | 38 | det = [] 39 | mapper = [] 40 | for k in range(1, nLabels): 41 | # size filtering 42 | size = stats[k, cv2.CC_STAT_AREA] 43 | if size < 10: 44 | continue 45 | 46 | # thresholding 47 | if np.max(textmap[labels == k]) < text_threshold: 48 | continue 49 | 50 | # make segmentation map 51 | segmap = np.zeros(textmap.shape, dtype=np.uint8) 52 | segmap[labels == k] = 255 53 | 54 | # remove link area 55 | segmap[np.logical_and(link_score == 1, text_score == 0)] = 0 56 | 57 | x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP] 58 | w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT] 59 | niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2) 60 | sx, ex, sy, ey = (x - niter, 61 | x + w + niter + 1, 62 | y - niter, 63 | y + h + niter + 1) 64 | # boundary check 65 | if sx < 0: 66 | sx = 0 67 | if sy < 0: 68 | sy = 0 69 | if ex >= img_w: 70 | ex = img_w 71 | if ey >= img_h: 72 | ey = img_h 73 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1 + niter, 74 | 1 +
niter)) 75 | segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel) 76 | 77 | # make box 78 | np_temp = np.roll(np.array(np.where(segmap != 0)), 1, axis=0) 79 | np_contours = np_temp.transpose().reshape(-1, 2) 80 | rectangle = cv2.minAreaRect(np_contours) 81 | box = cv2.boxPoints(rectangle) 82 | 83 | # align diamond-shape 84 | w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2]) 85 | box_ratio = max(w, h) / (min(w, h) + 1e-5) 86 | if abs(1 - box_ratio) <= 0.1: 87 | l, r = min(np_contours[:, 0]), max(np_contours[:, 0]) 88 | t, b = min(np_contours[:, 1]), max(np_contours[:, 1]) 89 | box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32) 90 | 91 | # make clockwise order 92 | startidx = box.sum(axis=1).argmin() 93 | box = np.roll(box, 4 - startidx, 0) 94 | box = np.array(box) 95 | 96 | det.append(box) 97 | mapper.append(k) 98 | 99 | return det, labels, mapper 100 | 101 | 102 | def getPoly_core(boxes, labels, mapper, linkmap): 103 | # configs 104 | num_cp = 5 105 | max_len_ratio = 0.7 106 | expand_ratio = 1.45 107 | max_r = 2.0 108 | step_r = 0.2 109 | 110 | polys = [] 111 | for k, box in enumerate(boxes): 112 | # size filter for small instance 113 | w, h = (int(np.linalg.norm(box[0] - box[1]) + 1), 114 | int(np.linalg.norm(box[1] - box[2]) + 1)) 115 | if w < 10 or h < 10: 116 | polys.append(None) 117 | continue 118 | 119 | # warp image 120 | tar = np.float32([[0, 0], [w, 0], [w, h], [0, h]]) 121 | M = cv2.getPerspectiveTransform(box, tar) 122 | word_label = cv2.warpPerspective(labels, 123 | M, 124 | (w, h), 125 | flags=cv2.INTER_NEAREST) 126 | try: 127 | Minv = np.linalg.inv(M) 128 | except np.linalg.LinAlgError: 129 | polys.append(None) 130 | continue 131 | 132 | # binarization for selected label 133 | cur_label = mapper[k] 134 | word_label[word_label != cur_label] = 0 135 | word_label[word_label > 0] = 1 136 | 137 | """ Polygon generation """ 138 | # find top/bottom contours 139 | cp = [] 140 | max_len = -1 141 | for i in range(w): 142 | region = np.where(word_label[:, i] != 0)[0] 143 | if len(region) < 2: 144 | continue 145 | cp.append((i, region[0], region[-1])) 146 | length = region[-1] - region[0] + 1 147 | if length > max_len: 148 | max_len = length 149 | 150 | # pass if max_len is similar to h 151 | if h * max_len_ratio < max_len: 152 | polys.append(None) 153 | continue 154 | 155 | # get pivot points with fixed length 156 | tot_seg = num_cp * 2 + 1 157 | seg_w = w / tot_seg # segment width 158 | pp = [None] * num_cp # init pivot points 159 | cp_section = [[0, 0]] * tot_seg 160 | seg_height = [0] * num_cp 161 | seg_num = 0 162 | num_sec = 0 163 | prev_h = -1 164 | for i in range(0, len(cp)): 165 | (x, sy, ey) = cp[i] 166 | if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg: 167 | # average previous segment 168 | if num_sec == 0: 169 | break 170 | cp_section[seg_num] = [cp_section[seg_num][0] / num_sec, 171 | cp_section[seg_num][1] / num_sec] 172 | num_sec = 0 173 | 174 | # reset variables 175 | seg_num += 1 176 | prev_h = -1 177 | 178 | # accumulate center points 179 | cy = (sy + ey) * 0.5 180 | cur_h = ey - sy + 1 181 | cp_section[seg_num] = [cp_section[seg_num][0] + x, 182 | cp_section[seg_num][1] + cy] 183 | num_sec += 1 184 | 185 | if seg_num % 2 == 0: 186 | continue # No polygon area 187 | 188 | if prev_h < cur_h: 189 | pp[int((seg_num - 1) / 2)] = (x, cy) 190 | seg_height[int((seg_num - 1) / 2)] = cur_h 191 | prev_h = cur_h 192 | 193 | # processing last segment 194 | if num_sec != 0: 195 | cp_section[-1] = [cp_section[-1][0] / num_sec, 196 | cp_section[-1][1]
/ num_sec] 197 | 198 | # pass if num of pivots is not sufficient or segment width 199 | # is smaller than character height 200 | if None in pp or seg_w < np.max(seg_height) * 0.25: 201 | polys.append(None) 202 | continue 203 | 204 | # calc half character height from median of pivot heights 205 | half_char_h = np.median(seg_height) * expand_ratio / 2 206 | 207 | # calc gradient and apply to make horizontal pivots 208 | new_pp = [] 209 | for i, (x, cy) in enumerate(pp): 210 | dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0] 211 | dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1] 212 | if dx == 0: # gradient is zero 213 | new_pp.append([x, cy - half_char_h, x, cy + half_char_h]) 214 | continue 215 | rad = - math.atan2(dy, dx) 216 | c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad) 217 | new_pp.append([x - s, cy - c, x + s, cy + c]) 218 | 219 | # get edge points to cover character heatmaps 220 | isSppFound, isEppFound = False, False 221 | grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0]) 222 | grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0]) 223 | for r in np.arange(0.5, max_r, step_r): 224 | dx = 2 * half_char_h * r 225 | if not isSppFound: 226 | line_img = np.zeros(word_label.shape, dtype=np.uint8) 227 | dy = grad_s * dx 228 | p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy]) 229 | cv2.line(line_img, 230 | (int(p[0]), int(p[1])), 231 | (int(p[2]), int(p[3])), 232 | 1, 233 | thickness=1) 234 | if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r: 235 | spp = p 236 | isSppFound = True 237 | if not isEppFound: 238 | line_img = np.zeros(word_label.shape, dtype=np.uint8) 239 | dy = grad_e * dx 240 | p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy]) 241 | cv2.line(line_img, 242 | (int(p[0]), int(p[1])), 243 | (int(p[2]), int(p[3])), 244 | 1, 245 | thickness=1) 246 | if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r: 247 | epp = p 248 | isEppFound = True 249 | if isSppFound and isEppFound: 250 | break 251 | 252 | # pass if boundary of polygon is not found 253 | if not (isSppFound and isEppFound): 254 | polys.append(None) 255 | continue 256 | 257 | # make final polygon 258 | poly = [] 259 | poly.append(warpCoord(Minv, (spp[0], spp[1]))) 260 | for p in new_pp: 261 | poly.append(warpCoord(Minv, (p[0], p[1]))) 262 | poly.append(warpCoord(Minv, (epp[0], epp[1]))) 263 | poly.append(warpCoord(Minv, (epp[2], epp[3]))) 264 | for p in reversed(new_pp): 265 | poly.append(warpCoord(Minv, (p[2], p[3]))) 266 | poly.append(warpCoord(Minv, (spp[2], spp[3]))) 267 | 268 | # add to final result 269 | polys.append(np.array(poly)) 270 | 271 | return polys 272 | 273 | 274 | def getDetBoxes(textmap, 275 | linkmap, 276 | text_threshold, 277 | link_threshold, 278 | low_text, 279 | poly=False): 280 | boxes, labels, mapper = getDetBoxes_core(textmap, 281 | linkmap, 282 | text_threshold, 283 | link_threshold, 284 | low_text) 285 | 286 | if poly: 287 | polys = getPoly_core(boxes, labels, mapper, linkmap) 288 | else: 289 | polys = [None] * len(boxes) 290 | 291 | return boxes, polys 292 | 293 | 294 | def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2): 295 | if len(polys) > 0: 296 | polys = np.array(polys) 297 | for k in range(len(polys)): 298 | if polys[k] is not None: 299 | polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net) 300 | return polys 301 | --------------------------------------------------------------------------------
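Note: the post-processing entry points above are consumed by predict.get_prediction. Below is a minimal standalone sketch of how they chain together (an illustration only, not part of the package; the random arrays merely stand in for the network's half-resolution text/link score maps):

import numpy as np

import craft_text_detector.craft_utils as craft_utils

# stand-ins for the half-resolution score maps produced by the network
score_text = np.random.rand(360, 640).astype(np.float32)
score_link = np.random.rand(360, 640).astype(np.float32)

# threshold the maps, group connected components, fit rotated rectangles
boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                       text_threshold=0.7,
                                       link_threshold=0.4,
                                       low_text=0.4,
                                       poly=False)

# map half-resolution coordinates back to the input image;
# ratio_net=2 compensates for the network's 2x downsampling
boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w=1.0, ratio_h=1.0)
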
/craft_text_detector/file_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import cv2 4 | import copy 5 | import gdown 6 | import numpy as np 7 | 8 | 9 | def download(url: str, save_path: str): 10 | """ 11 | Downloads file from a Google Drive url via gdown, shows progress. 12 | Example inputs: 13 | url: 'https://drive.google.com/uc?id=1bupFXqT-VU6Jjeul13XP7yx2Sg5IHr4J' 14 | save_path: 'weights/craft_mlt_25k.pth' 15 | """ 16 | 17 | # create save_dir if not present 18 | create_dir(os.path.dirname(save_path)) 19 | # download file 20 | gdown.download(url, save_path, quiet=False) 21 | 22 | 23 | def create_dir(_dir): 24 | """ 25 | Creates given directory if it is not present. 26 | """ 27 | if not os.path.exists(_dir): 28 | os.makedirs(_dir) 29 | 30 | 31 | def get_files(img_dir): 32 | imgs, masks, xmls = list_files(img_dir) 33 | return imgs, masks, xmls 34 | 35 | 36 | def list_files(in_path): 37 | img_files = [] 38 | mask_files = [] 39 | gt_files = [] 40 | for (dirpath, dirnames, filenames) in os.walk(in_path): 41 | for file in filenames: 42 | filename, ext = os.path.splitext(file) 43 | ext = str.lower(ext) 44 | if (ext == '.jpg' or ext == '.jpeg' or ext == '.gif' or 45 | ext == '.png' or ext == '.pgm'): 46 | img_files.append(os.path.join(dirpath, file)) 47 | elif ext == '.bmp': 48 | mask_files.append(os.path.join(dirpath, file)) 49 | elif ext == '.xml' or ext == '.gt' or ext == '.txt': 50 | gt_files.append(os.path.join(dirpath, file)) 51 | elif ext == '.zip': 52 | continue 53 | # img_files.sort() 54 | # mask_files.sort() 55 | # gt_files.sort() 56 | return img_files, mask_files, gt_files 57 | 58 | 59 | def rectify_poly(img, poly): 60 | # Use Affine transform 61 | n = int(len(poly) / 2) - 1 62 | width = 0 63 | height = 0 64 | for k in range(n): 65 | box = np.float32([poly[k], poly[k + 1], poly[-k - 2], poly[-k - 1]]) 66 | width += int((np.linalg.norm(box[0] - box[1]) + np.linalg.norm(box[2] - box[3])) / 2) 67 | height += np.linalg.norm(box[1] - box[2]) 68 | width = int(width) 69 | height = int(height / n) 70 | 71 | output_img = np.zeros((height, width, 3), dtype=np.uint8) 72 | width_step = 0 73 | for k in range(n): 74 | box = np.float32([poly[k], poly[k + 1], poly[-k - 2], poly[-k - 1]]) 75 | w = int((np.linalg.norm(box[0] - box[1]) + np.linalg.norm(box[2] - box[3])) / 2) 76 | 77 | # Top triangle 78 | pts1 = box[:3] 79 | pts2 = np.float32([[width_step, 0], [width_step + w - 1, 0], [width_step + w - 1, height - 1]]) 80 | M = cv2.getAffineTransform(pts1, pts2) 81 | warped_img = cv2.warpAffine(img, M, (width, height), borderMode=cv2.BORDER_REPLICATE) 82 | warped_mask = np.zeros((height, width, 3), dtype=np.uint8) 83 | warped_mask = cv2.fillConvexPoly(warped_mask, np.int32(pts2), (1, 1, 1)) 84 | output_img[warped_mask == 1] = warped_img[warped_mask == 1] 85 | 86 | # Bottom triangle 87 | pts1 = np.vstack((box[0], box[2:])) 88 | pts2 = np.float32([[width_step, 0], [width_step + w - 1, height - 1], [width_step, height - 1]]) 89 | M = cv2.getAffineTransform(pts1, pts2) 90 | warped_img = cv2.warpAffine(img, M, (width, height), borderMode=cv2.BORDER_REPLICATE) 91 | warped_mask = np.zeros((height, width, 3), dtype=np.uint8) 92 | warped_mask = cv2.fillConvexPoly(warped_mask, np.int32(pts2), (1, 1, 1)) 93 | cv2.line(warped_mask, (width_step, 0), (width_step + w - 1, height - 1), (0, 0, 0), 1) 94 | output_img[warped_mask == 1] = warped_img[warped_mask == 1] 95 | 96 | width_step += w 97 | return output_img 98 | 99 | 100 | def crop_poly(image, poly): 101 | # points
should have 1*x*2 shape 102 | if len(poly.shape) == 2: 103 | poly = np.array([np.array(poly).astype(np.int32)]) 104 | 105 | # create mask with shape of image 106 | mask = np.zeros(image.shape[0:2], dtype=np.uint8) 107 | 108 | # method 1 smooth region 109 | cv2.drawContours(mask, [poly], -1, (255, 255, 255), -1, cv2.LINE_AA) 110 | # method 2 not so smooth region 111 | # cv2.fillPoly(mask, points, (255)) 112 | 113 | # crop around poly 114 | res = cv2.bitwise_and(image, image, mask=mask) 115 | rect = cv2.boundingRect(poly) # returns (x,y,w,h) of the rect 116 | cropped = res[rect[1]: rect[1] + rect[3], rect[0]: rect[0] + rect[2]] 117 | 118 | return cropped 119 | 120 | 121 | def export_detected_region(image, poly, file_path, rectify=True): 122 | """ 123 | Arguments: 124 | image: full image 125 | poly: bbox or poly points 126 | file_path: path to be exported 127 | rectify: rectify detected polygon by affine transform 128 | """ 129 | if rectify: 130 | # rectify poly region 131 | result = rectify_poly(image, poly) 132 | else: 133 | result = crop_poly(image, poly) 134 | 135 | # export cropped region 136 | cv2.imwrite(file_path, result) 137 | 138 | 139 | def export_detected_regions(image_path, image, regions, 140 | output_dir: str = "output/", 141 | rectify: bool = False): 142 | """ 143 | Arguments: 144 | image_path: path to original image 145 | image: full/original image 146 | regions: list of bboxes or polys 147 | output_dir: folder to be exported 148 | rectify: rectify detected polygon by affine transform 149 | """ 150 | # deepcopy image so that original is not altered 151 | image = copy.deepcopy(image) 152 | 153 | # get file name 154 | file_name, file_ext = os.path.splitext(os.path.basename(image_path)) 155 | 156 | # create crops dir 157 | crops_dir = os.path.join(output_dir, file_name + "_crops") 158 | create_dir(crops_dir) 159 | 160 | # init exported file paths 161 | exported_file_paths = [] 162 | 163 | # export regions 164 | for ind, region in enumerate(regions): 165 | # get export path 166 | file_path = os.path.join(crops_dir, "crop_" + str(ind) + ".png") 167 | # export region 168 | export_detected_region(image, 169 | poly=region, 170 | file_path=file_path, 171 | rectify=rectify) 172 | # note exported file path 173 | exported_file_paths.append(file_path) 174 | 175 | return exported_file_paths 176 | 177 | 178 | def export_extra_results(image_path, 179 | image, 180 | regions, 181 | heatmaps, 182 | output_dir='output/', 183 | verticals=None, 184 | texts=None): 185 | """ save text detection result one by one 186 | Args: 187 | image_path (str): image file name 188 | image (array): raw image context 189 | regions (array): array of detected region coordinates; 190 | shape: [num_detections, 4, 2] for box 191 | output, variable-length polys otherwise 192 | Return: 193 | None 194 | """ 195 | image = np.array(image) 196 | 197 | # make result file list 198 | filename, file_ext = os.path.splitext(os.path.basename(image_path)) 199 | 200 | # result directory 201 | res_file = os.path.join(output_dir, 202 | filename + '.txt') 203 | res_img_file = os.path.join(output_dir, 204 | filename + '.png') 205 | text_heatmap_file = os.path.join(output_dir, 206 | filename + '_text_score_heatmap.png') 207 | link_heatmap_file = os.path.join(output_dir, 208 | filename + '_link_score_heatmap.png') 209 | 210 | # create output dir 211 | create_dir(output_dir) 212 | 213 | # export heatmaps 214 | cv2.imwrite(text_heatmap_file, heatmaps["text_score_heatmap"]) 215 | cv2.imwrite(link_heatmap_file, heatmaps["link_score_heatmap"]) 216 | 217 | with
open(res_file, 'w') as f: 218 | for i, region in enumerate(regions): 219 | region = np.array(region).astype(np.int32).reshape((-1)) 220 | strResult = ','.join([str(r) for r in region]) + '\r\n' 221 | f.write(strResult) 222 | 223 | region = region.reshape(-1, 2) 224 | cv2.polylines(image, 225 | [region.reshape((-1, 1, 2))], 226 | True, 227 | color=(0, 0, 255), 228 | thickness=2) 229 | 230 | if texts is not None: 231 | font = cv2.FONT_HERSHEY_SIMPLEX 232 | font_scale = 0.5 233 | cv2.putText(image, "{}".format(texts[i]), 234 | (region[0][0] + 1, region[0][1] + 1), 235 | font, 236 | font_scale, 237 | (0, 0, 0), 238 | thickness=1) 239 | cv2.putText(image, 240 | "{}".format(texts[i]), 241 | tuple(region[0]), 242 | font, 243 | font_scale, 244 | (0, 255, 255), 245 | thickness=1) 246 | 247 | # Save result image 248 | cv2.imwrite(res_img_file, image) 249 | -------------------------------------------------------------------------------- /craft_text_detector/imgproc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 3 | MIT License 4 | """ 5 | 6 | # -*- coding: utf-8 -*- 7 | import numpy as np 8 | import cv2 9 | 10 | 11 | def read_image(img_file): 12 | img = cv2.imread(img_file) 13 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 14 | # following two cases are not explained in the original repo 15 | if img.shape[0] == 2: 16 | img = img[0] 17 | if img.shape[2] == 4: 18 | img = img[:, :, :3] 19 | 20 | return img 21 | 22 | 23 | def normalizeMeanVariance(in_img, 24 | mean=(0.485, 0.456, 0.406), 25 | variance=(0.229, 0.224, 0.225)): 26 | # should be RGB order 27 | img = in_img.copy().astype(np.float32) 28 | 29 | img -= np.array([mean[0] * 255.0, 30 | mean[1] * 255.0, 31 | mean[2] * 255.0], dtype=np.float32) 32 | img /= np.array([variance[0] * 255.0, 33 | variance[1] * 255.0, 34 | variance[2] * 255.0], dtype=np.float32) 35 | return img 36 | 37 | 38 | def denormalizeMeanVariance(in_img, 39 | mean=(0.485, 0.456, 0.406), 40 | variance=(0.229, 0.224, 0.225)): 41 | # should be RGB order 42 | img = in_img.copy() 43 | img *= variance 44 | img += mean 45 | img *= 255.0 46 | img = np.clip(img, 0, 255).astype(np.uint8) 47 | return img 48 | 49 | 50 | def resize_aspect_ratio(img, long_size, interpolation): 51 | height, width, channel = img.shape 52 | 53 | # set target image size 54 | target_size = long_size 55 | 56 | ratio = target_size / max(height, width) 57 | 58 | target_h, target_w = int(height * ratio), int(width * ratio) 59 | proc = cv2.resize(img, (target_w, target_h), interpolation=interpolation) 60 | 61 | # make canvas and paste image 62 | target_h32, target_w32 = target_h, target_w 63 | if target_h % 32 != 0: 64 | target_h32 = target_h + (32 - target_h % 32) 65 | if target_w % 32 != 0: 66 | target_w32 = target_w + (32 - target_w % 32) 67 | resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32) 68 | resized[0:target_h, 0:target_w, :] = proc 69 | target_h, target_w = target_h32, target_w32 70 | 71 | size_heatmap = (int(target_w / 2), int(target_h / 2)) 72 | 73 | return resized, ratio, size_heatmap 74 | 75 | 76 | def cvt2HeatmapImg(img): 77 | img = (np.clip(img, 0, 1) * 255).astype(np.uint8) 78 | img = cv2.applyColorMap(img, cv2.COLORMAP_JET) 79 | return img 80 | -------------------------------------------------------------------------------- /craft_text_detector/models/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/fcakyon/craft-text-detector-old-repo/9b8915b7c7e00e0e3edb67d8a89b894c34c89b7a/craft_text_detector/models/__init__.py -------------------------------------------------------------------------------- /craft_text_detector/models/basenet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fcakyon/craft-text-detector-old-repo/9b8915b7c7e00e0e3edb67d8a89b894c34c89b7a/craft_text_detector/models/basenet/__init__.py -------------------------------------------------------------------------------- /craft_text_detector/models/basenet/vgg16_bn.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | from torchvision import models 7 | from torchvision.models.vgg import model_urls 8 | 9 | def init_weights(modules): 10 | for m in modules: 11 | if isinstance(m, nn.Conv2d): 12 | init.xavier_uniform_(m.weight.data) 13 | if m.bias is not None: 14 | m.bias.data.zero_() 15 | elif isinstance(m, nn.BatchNorm2d): 16 | m.weight.data.fill_(1) 17 | m.bias.data.zero_() 18 | elif isinstance(m, nn.Linear): 19 | m.weight.data.normal_(0, 0.01) 20 | m.bias.data.zero_() 21 | 22 | class vgg16_bn(torch.nn.Module): 23 | def __init__(self, pretrained=True, freeze=True): 24 | super(vgg16_bn, self).__init__() 25 | model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace('https://', 'http://') 26 | vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features 27 | self.slice1 = torch.nn.Sequential() 28 | self.slice2 = torch.nn.Sequential() 29 | self.slice3 = torch.nn.Sequential() 30 | self.slice4 = torch.nn.Sequential() 31 | self.slice5 = torch.nn.Sequential() 32 | for x in range(12): # conv2_2 33 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 34 | for x in range(12, 19): # conv3_3 35 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 36 | for x in range(19, 29): # conv4_3 37 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 38 | for x in range(29, 39): # conv5_3 39 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 40 | 41 | # fc6, fc7 without atrous conv 42 | self.slice5 = torch.nn.Sequential( 43 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1), 44 | nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6), 45 | nn.Conv2d(1024, 1024, kernel_size=1) 46 | ) 47 | 48 | if not pretrained: 49 | init_weights(self.slice1.modules()) 50 | init_weights(self.slice2.modules()) 51 | init_weights(self.slice3.modules()) 52 | init_weights(self.slice4.modules()) 53 | 54 | init_weights(self.slice5.modules()) # no pretrained model for fc6 and fc7 55 | 56 | if freeze: 57 | for param in self.slice1.parameters(): # only first conv 58 | param.requires_grad= False 59 | 60 | def forward(self, X): 61 | h = self.slice1(X) 62 | h_relu2_2 = h 63 | h = self.slice2(h) 64 | h_relu3_2 = h 65 | h = self.slice3(h) 66 | h_relu4_3 = h 67 | h = self.slice4(h) 68 | h_relu5_3 = h 69 | h = self.slice5(h) 70 | h_fc7 = h 71 | vgg_outputs = namedtuple("VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2']) 72 | out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2) 73 | return out 74 | -------------------------------------------------------------------------------- /craft_text_detector/models/craftnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER 
Corp. 3 | MIT License 4 | """ 5 | 6 | # -*- coding: utf-8 -*- 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from craft_text_detector.models.basenet.vgg16_bn import vgg16_bn, init_weights 12 | 13 | 14 | class double_conv(nn.Module): 15 | def __init__(self, in_ch, mid_ch, out_ch): 16 | super(double_conv, self).__init__() 17 | self.conv = nn.Sequential( 18 | nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1), 19 | nn.BatchNorm2d(mid_ch), 20 | nn.ReLU(inplace=True), 21 | nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1), 22 | nn.BatchNorm2d(out_ch), 23 | nn.ReLU(inplace=True) 24 | ) 25 | 26 | def forward(self, x): 27 | x = self.conv(x) 28 | return x 29 | 30 | 31 | class CRAFT(nn.Module): 32 | def __init__(self, pretrained=False, freeze=False): 33 | super(CRAFT, self).__init__() 34 | 35 | """ Base network """ 36 | self.basenet = vgg16_bn(pretrained, freeze) 37 | 38 | """ U network """ 39 | self.upconv1 = double_conv(1024, 512, 256) 40 | self.upconv2 = double_conv(512, 256, 128) 41 | self.upconv3 = double_conv(256, 128, 64) 42 | self.upconv4 = double_conv(128, 64, 32) 43 | 44 | num_class = 2 45 | self.conv_cls = nn.Sequential( 46 | nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True), 47 | nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True), 48 | nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True), 49 | nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True), 50 | nn.Conv2d(16, num_class, kernel_size=1), 51 | ) 52 | 53 | init_weights(self.upconv1.modules()) 54 | init_weights(self.upconv2.modules()) 55 | init_weights(self.upconv3.modules()) 56 | init_weights(self.upconv4.modules()) 57 | init_weights(self.conv_cls.modules()) 58 | 59 | def forward(self, x): 60 | """ Base network """ 61 | sources = self.basenet(x) 62 | 63 | """ U network """ 64 | y = torch.cat([sources[0], sources[1]], dim=1) 65 | y = self.upconv1(y) 66 | 67 | y = F.interpolate(y, size=sources[2].size()[2:], 68 | mode='bilinear', align_corners=False) 69 | y = torch.cat([y, sources[2]], dim=1) 70 | y = self.upconv2(y) 71 | 72 | y = F.interpolate(y, size=sources[3].size()[2:], 73 | mode='bilinear', align_corners=False) 74 | y = torch.cat([y, sources[3]], dim=1) 75 | y = self.upconv3(y) 76 | 77 | y = F.interpolate(y, size=sources[4].size()[2:], 78 | mode='bilinear', align_corners=False) 79 | y = torch.cat([y, sources[4]], dim=1) 80 | feature = self.upconv4(y) 81 | 82 | y = self.conv_cls(feature) 83 | 84 | return y.permute(0, 2, 3, 1), feature 85 | 86 | 87 | if __name__ == '__main__': 88 | model = CRAFT(pretrained=True).cuda() 89 | output, _ = model(torch.randn(1, 3, 768, 768).cuda()) 90 | print(output.shape) 91 | -------------------------------------------------------------------------------- /craft_text_detector/models/refinenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 
3 | MIT License 4 | """ 5 | 6 | # -*- coding: utf-8 -*- 7 | import torch 8 | import torch.nn as nn 9 | from craft_text_detector.models.basenet.vgg16_bn import init_weights 10 | 11 | 12 | class RefineNet(nn.Module): 13 | def __init__(self): 14 | super(RefineNet, self).__init__() 15 | 16 | self.last_conv = nn.Sequential( 17 | nn.Conv2d(34, 64, kernel_size=3, padding=1), 18 | nn.BatchNorm2d(64), 19 | nn.ReLU(inplace=True), 20 | 21 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 22 | nn.BatchNorm2d(64), 23 | nn.ReLU(inplace=True), 24 | 25 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 26 | nn.BatchNorm2d(64), 27 | nn.ReLU(inplace=True) 28 | ) 29 | 30 | self.aspp1 = nn.Sequential( 31 | nn.Conv2d(64, 128, kernel_size=3, dilation=6, padding=6), 32 | nn.BatchNorm2d(128), 33 | nn.ReLU(inplace=True), 34 | 35 | nn.Conv2d(128, 128, kernel_size=1), 36 | nn.BatchNorm2d(128), 37 | nn.ReLU(inplace=True), 38 | 39 | nn.Conv2d(128, 1, kernel_size=1) 40 | ) 41 | 42 | self.aspp2 = nn.Sequential( 43 | nn.Conv2d(64, 128, kernel_size=3, dilation=12, padding=12), 44 | nn.BatchNorm2d(128), 45 | nn.ReLU(inplace=True), 46 | 47 | nn.Conv2d(128, 128, kernel_size=1), 48 | nn.BatchNorm2d(128), 49 | nn.ReLU(inplace=True), 50 | 51 | nn.Conv2d(128, 1, kernel_size=1) 52 | ) 53 | 54 | self.aspp3 = nn.Sequential( 55 | nn.Conv2d(64, 128, kernel_size=3, dilation=18, padding=18), 56 | nn.BatchNorm2d(128), 57 | nn.ReLU(inplace=True), 58 | 59 | nn.Conv2d(128, 128, kernel_size=1), 60 | nn.BatchNorm2d(128), 61 | nn.ReLU(inplace=True), 62 | 63 | nn.Conv2d(128, 1, kernel_size=1) 64 | ) 65 | 66 | self.aspp4 = nn.Sequential( 67 | nn.Conv2d(64, 128, kernel_size=3, dilation=24, padding=24), 68 | nn.BatchNorm2d(128), 69 | nn.ReLU(inplace=True), 70 | 71 | nn.Conv2d(128, 128, kernel_size=1), 72 | nn.BatchNorm2d(128), 73 | nn.ReLU(inplace=True), 74 | 75 | nn.Conv2d(128, 1, kernel_size=1) 76 | ) 77 | 78 | init_weights(self.last_conv.modules()) 79 | init_weights(self.aspp1.modules()) 80 | init_weights(self.aspp2.modules()) 81 | init_weights(self.aspp3.modules()) 82 | init_weights(self.aspp4.modules()) 83 | 84 | def forward(self, y, upconv4): 85 | refine = torch.cat([y.permute(0, 3, 1, 2), upconv4], dim=1) 86 | refine = self.last_conv(refine) 87 | 88 | aspp1 = self.aspp1(refine) 89 | aspp2 = self.aspp2(refine) 90 | aspp3 = self.aspp3(refine) 91 | aspp4 = self.aspp4(refine) 92 | 93 | # out = torch.add([aspp1, aspp2, aspp3, aspp4], dim=1) 94 | out = aspp1 + aspp2 + aspp3 + aspp4 95 | return out.permute(0, 2, 3, 1) # , refine.permute(0,2,3,1) 96 | -------------------------------------------------------------------------------- /craft_text_detector/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from pathlib import Path 4 | 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | from torch.autograd import Variable 8 | 9 | import cv2 10 | import numpy as np 11 | 12 | import craft_text_detector.craft_utils as craft_utils 13 | import craft_text_detector.imgproc as imgproc 14 | import craft_text_detector.file_utils as file_utils 15 | from craft_text_detector.models.craftnet import CRAFT 16 | 17 | from collections import OrderedDict 18 | 19 | CRAFT_GDRIVE_URL = "https://drive.google.com/uc?id=1bupFXqT-VU6Jjeul13XP7yx2Sg5IHr4J" 20 | REFINENET_GDRIVE_URL = "https://drive.google.com/uc?id=1xcE9qpJXp4ofINwXWVhhQIh9S8Z7cuGj" 21 | 22 | 23 | def copyStateDict(state_dict): 24 | if list(state_dict.keys())[0].startswith("module"): 25 | start_idx = 1 26 | else: 27 | start_idx = 0 28 
| new_state_dict = OrderedDict() 29 | for k, v in state_dict.items(): 30 | name = ".".join(k.split(".")[start_idx:]) 31 | new_state_dict[name] = v 32 | return new_state_dict 33 | 34 | 35 | def str2bool(v): 36 | return v.lower() in ("yes", "y", "true", "t", "1") 37 | 38 | 39 | def load_craftnet_model(cuda: bool = False): 40 | # get craft net path 41 | home_path = str(Path.home()) 42 | weight_path = os.path.join(home_path, 43 | ".craft_text_detector", 44 | "weights", 45 | "craft_mlt_25k.pth") 46 | # load craft net 47 | craft_net = CRAFT() # initialize 48 | 49 | # check if weights are already downloaded, if not download 50 | url = CRAFT_GDRIVE_URL 51 | if not os.path.isfile(weight_path): 52 | print("Craft text detector weight will be downloaded to {}" 53 | .format(weight_path)) 54 | 55 | file_utils.download(url=url, save_path=weight_path) 56 | 57 | # arrange device 58 | if cuda: 59 | craft_net.load_state_dict(copyStateDict(torch.load(weight_path))) 60 | 61 | craft_net = craft_net.cuda() 62 | craft_net = torch.nn.DataParallel(craft_net) 63 | cudnn.benchmark = False 64 | else: 65 | craft_net.load_state_dict(copyStateDict(torch.load(weight_path, 66 | map_location='cpu'))) 67 | craft_net.eval() 68 | return craft_net 69 | 70 | 71 | def load_refinenet_model(cuda: bool = False): 72 | # get refine net path 73 | home_path = str(Path.home()) 74 | weight_path = os.path.join(home_path, 75 | ".craft_text_detector", 76 | "weights", 77 | "craft_refiner_CTW1500.pth") 78 | # load refine net 79 | from craft_text_detector.models.refinenet import RefineNet 80 | refine_net = RefineNet() # initialize 81 | 82 | # check if weights are already downloaded, if not download 83 | url = REFINENET_GDRIVE_URL 84 | if not os.path.isfile(weight_path): 85 | print("Craft text refiner weight will be downloaded to {}" 86 | .format(weight_path)) 87 | 88 | file_utils.download(url=url, save_path=weight_path) 89 | 90 | # arrange device 91 | if cuda: 92 | refine_net.load_state_dict(copyStateDict(torch.load(weight_path))) 93 | 94 | refine_net = refine_net.cuda() 95 | refine_net = torch.nn.DataParallel(refine_net) 96 | cudnn.benchmark = False 97 | else: 98 | refine_net.load_state_dict(copyStateDict(torch.load(weight_path, 99 | map_location='cpu'))) 100 | refine_net.eval() 101 | return refine_net 102 | 103 | 104 | def get_prediction(image, 105 | craft_net, 106 | refine_net=None, 107 | text_threshold: float = 0.7, 108 | link_threshold: float = 0.4, 109 | low_text: float = 0.4, 110 | cuda: bool = False, 111 | long_size: int = 1280, 112 | mag_ratio: float = 1.5, 113 | poly: bool = True, 114 | show_time: bool = False): 115 | """ 116 | Arguments: 117 | image: image to be processed 118 | craft_net: craft net model 119 | refine_net: refine net model 120 | text_threshold: text confidence threshold 121 | link_threshold: link confidence threshold 122 | low_text: text low-bound score 123 | cuda: Use cuda for inference 124 | long_size: desired longest image size for inference 125 | mag_ratio: image magnification ratio (currently unused) 126 | poly: enable polygon type 127 | show_time: show processing time 128 | Output: 129 | {"boxes": list of coords of points of predicted boxes, 130 | "boxes_as_ratios": list of coords of points of predicted boxes as ratios of image size, 131 | "polys": list of coords of points of predicted polys, 132 | "polys_as_ratios": list of coords of points of predicted polys as ratios of image size, 133 | "heatmaps": visualizations of the detected characters/links} 134 | """ 135 | 136 |
t0 = time.time() 137 | 138 | # resize 139 | img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio( 140 | image, long_size, interpolation=cv2.INTER_LINEAR) 141 | ratio_h = ratio_w = 1 / target_ratio 142 | 143 | # preprocessing 144 | x = imgproc.normalizeMeanVariance(img_resized) 145 | x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w] 146 | x = Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w] 147 | if cuda: 148 | x = x.cuda() 149 | 150 | # forward pass 151 | with torch.no_grad(): 152 | y, feature = craft_net(x) 153 | 154 | # make score and link map 155 | score_text = y[0, :, :, 0].cpu().data.numpy() 156 | score_link = y[0, :, :, 1].cpu().data.numpy() 157 | 158 | # refine link 159 | if refine_net is not None: 160 | with torch.no_grad(): 161 | y_refiner = refine_net(y, feature) 162 | score_link = y_refiner[0, :, :, 0].cpu().data.numpy() 163 | 164 | t0 = time.time() - t0 165 | t1 = time.time() 166 | 167 | # Post-processing 168 | boxes, polys = craft_utils.getDetBoxes( 169 | score_text, score_link, text_threshold, link_threshold, low_text, 170 | poly) 171 | 172 | # coordinate adjustment 173 | boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h) 174 | polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h) 175 | for k in range(len(polys)): 176 | if polys[k] is None: 177 | polys[k] = boxes[k] 178 | 179 | t1 = time.time() - t1 180 | 181 | # render results (optional) 182 | text_score_heatmap = imgproc.cvt2HeatmapImg(score_text) 183 | link_score_heatmap = imgproc.cvt2HeatmapImg(score_link) 184 | 185 | if show_time: 186 | print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1)) 187 | 188 | # get image size 189 | img_height = image.shape[0] 190 | img_width = image.shape[1] 191 | 192 | # calculate box coords as ratios to image size 193 | boxes_as_ratio = [] 194 | for box in boxes: 195 | boxes_as_ratio.append(box / [img_width, img_height]) 196 | boxes_as_ratio = np.array(boxes_as_ratio) 197 | 198 | # calculate poly coords as ratios to image size 199 | polys_as_ratio = [] 200 | for poly in polys: 201 | polys_as_ratio.append(poly / [img_width, img_height]) 202 | polys_as_ratio = np.array(polys_as_ratio) 203 | 204 | return {"boxes": boxes, 205 | "boxes_as_ratios": boxes_as_ratio, 206 | "polys": polys, 207 | "polys_as_ratios": polys_as_ratio, 208 | "heatmaps": {"text_score_heatmap": text_score_heatmap, 209 | "link_score_heatmap": link_score_heatmap}} 210 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | 2 | name: craft 3 | #prefix: /your/custom/path/envs/craft 4 | channels: 5 | - conda-forge 6 | - pytorch 7 | - anaconda 8 | dependencies: 9 | - python=3.6 10 | - pip=19.3.1 11 | - pytorch::pytorch>=1.4.0 12 | - pytorch::torchvision>=0.5.0 13 | - pip: 14 | - scikit-image==0.14.2 15 | - opencv-python==3.4.2.17 16 | - scipy==1.4.1 17 | - pytesseract==0.3.3 18 | -------------------------------------------------------------------------------- /figures/craft_example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fcakyon/craft-text-detector-old-repo/9b8915b7c7e00e0e3edb67d8a89b894c34c89b7a/figures/craft_example.gif -------------------------------------------------------------------------------- /figures/idcard.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/fcakyon/craft-text-detector-old-repo/9b8915b7c7e00e0e3edb67d8a89b894c34c89b7a/figures/idcard.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=0.4.1 2 | torchvision>=0.2.1 3 | opencv-python==3.4.8.29 4 | scipy>=1.3.2 5 | gdown>=3.10.1 -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 119 3 | exclude =.git,__pycache__,docs/source/conf.py,build,dist 4 | ignore = I101,I201,F401,F403,S001,D100,D101,D102,D103,D104,D105,D106,D107,D200,D205,D400,W504,D202,E203,W503,B006 5 | inline-quotes = " 6 | 7 | [mypy] 8 | ignore_missing_imports = True 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import re 4 | import setuptools 5 | 6 | 7 | def get_long_description(): 8 | base_dir = os.path.abspath(os.path.dirname(__file__)) 9 | with io.open(os.path.join(base_dir, "README.md"), encoding="utf-8") as f: 10 | return f.read() 11 | 12 | 13 | def get_requirements(): 14 | with open('requirements.txt') as f: 15 | return f.read().splitlines() 16 | 17 | 18 | def get_version(): 19 | current_dir = os.path.abspath(os.path.dirname(__file__)) 20 | version_file = os.path.join(current_dir, 21 | "craft_text_detector", 22 | "__init__.py") 23 | with io.open(version_file, encoding="utf-8") as f: 24 | return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', 25 | f.read(), 26 | re.M).group(1) 27 | 28 | 29 | setuptools.setup( 30 | name="craft-text-detector", 31 | version=get_version(), 32 | author="Fatih Cagatay Akyon", 33 | license="MIT", 34 | description="Fast and accurate text detection library built on CRAFT implementation", 35 | long_description=get_long_description(), 36 | long_description_content_type="text/markdown", 37 | url="https://github.com/fcakyon/craft_text_detector", 38 | packages=setuptools.find_packages(exclude=["tests"]), 39 | install_requires=get_requirements(), 40 | python_requires='>=3.5', 41 | classifiers=[ 42 | "License :: OSI Approved :: MIT License", 43 | "Operating System :: OS Independent", 44 | "Intended Audience :: Developers", 45 | "Intended Audience :: Science/Research", 46 | "Programming Language :: Python :: 3", 47 | "Programming Language :: Python :: 3.5", 48 | "Programming Language :: Python :: 3.6", 49 | "Programming Language :: Python :: 3.7", 50 | "Programming Language :: Python :: 3.8", 51 | "Topic :: Software Development :: Libraries", 52 | "Topic :: Software Development :: Libraries :: Python Modules" 53 | ] 54 | ) 55 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fcakyon/craft-text-detector-old-repo/9b8915b7c7e00e0e3edb67d8a89b894c34c89b7a/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_craft.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import craft_text_detector 3 | 4 | image_path = 'figures/idcard.png' 5 | 6 | 7 | class TestCraftTextDetector(unittest.TestCase): 8 | def test_load_craftnet_model(self): 9 | 
craft_net = craft_text_detector.load_craftnet_model() 10 | self.assertTrue(craft_net) 11 | 12 | def test_load_refinenet_model(self): 13 | refine_net = craft_text_detector.load_refinenet_model() 14 | self.assertTrue(refine_net) 15 | 16 | def test_get_prediction(self): 17 | # load image 18 | image = craft_text_detector.read_image(image_path) 19 | 20 | # load models 21 | craft_net = craft_text_detector.load_craftnet_model() 22 | refine_net = None 23 | 24 | # perform prediction 25 | text_threshold = 0.9 26 | link_threshold = 0.2 27 | low_text = 0.2 28 | cuda = False 29 | show_time = False 30 | get_prediction = craft_text_detector.get_prediction 31 | prediction_result = get_prediction(image=image, 32 | craft_net=craft_net, 33 | refine_net=refine_net, 34 | text_threshold=text_threshold, 35 | link_threshold=link_threshold, 36 | low_text=low_text, 37 | cuda=cuda, 38 | long_size=720, 39 | show_time=show_time) 40 | 41 | self.assertEqual(len(prediction_result["boxes"]), 35) 42 | self.assertEqual(len(prediction_result["boxes"][0]), 4) 43 | self.assertEqual(len(prediction_result["boxes"][0][0]), 2) 44 | self.assertEqual(int(prediction_result["boxes"][0][0][0]), 111) 45 | self.assertEqual(len(prediction_result["polys"]), 35) 46 | self.assertEqual(prediction_result["heatmaps"]["text_score_heatmap"].shape, (240, 368, 3)) 47 | 48 | def test_detect_text(self): 49 | prediction_result = craft_text_detector.detect_text(image_path=image_path, 50 | output_dir=None, 51 | rectify=True, 52 | export_extra=False, 53 | text_threshold=0.7, 54 | link_threshold=0.4, 55 | low_text=0.4, 56 | cuda=False, 57 | long_size=720, 58 | show_time=False, 59 | refiner=False, 60 | crop_type="poly") 61 | self.assertEqual(len(prediction_result["boxes"]), 52) 62 | self.assertEqual(len(prediction_result["boxes"][0]), 4) 63 | self.assertEqual(len(prediction_result["boxes"][0][0]), 2) 64 | self.assertEqual(int(prediction_result["boxes"][0][0][0]), 115) 65 | 66 | prediction_result = craft_text_detector.detect_text(image_path=image_path, 67 | output_dir=None, 68 | rectify=True, 69 | export_extra=False, 70 | text_threshold=0.7, 71 | link_threshold=0.4, 72 | low_text=0.4, 73 | cuda=False, 74 | long_size=720, 75 | show_time=False, 76 | refiner=True, 77 | crop_type="poly") 78 | self.assertEqual(len(prediction_result["boxes"]), 19) 79 | self.assertEqual(len(prediction_result["boxes"][0]), 4) 80 | self.assertEqual(len(prediction_result["boxes"][0][0]), 2) 81 | self.assertEqual(int(prediction_result["boxes"][0][2][0]), 661) 82 | 83 | prediction_result = craft_text_detector.detect_text(image_path=image_path, 84 | output_dir=None, 85 | rectify=False, 86 | export_extra=False, 87 | text_threshold=0.7, 88 | link_threshold=0.4, 89 | low_text=0.4, 90 | cuda=False, 91 | long_size=720, 92 | show_time=False, 93 | refiner=False, 94 | crop_type="box") 95 | self.assertEqual(len(prediction_result["boxes"]), 52) 96 | self.assertEqual(len(prediction_result["boxes"][0]), 4) 97 | self.assertEqual(len(prediction_result["boxes"][0][0]), 2) 98 | self.assertEqual(int(prediction_result["boxes"][0][2][0]), 244) 99 | 100 | prediction_result = craft_text_detector.detect_text(image_path=image_path, 101 | output_dir=None, 102 | rectify=False, 103 | export_extra=False, 104 | text_threshold=0.7, 105 | link_threshold=0.4, 106 | low_text=0.4, 107 | cuda=False, 108 | long_size=720, 109 | show_time=False, 110 | refiner=True, 111 | crop_type="box") 112 | self.assertEqual(len(prediction_result["boxes"]), 19) 113 | self.assertEqual(len(prediction_result["boxes"][0]), 4) 114 | 
self.assertEqual(len(prediction_result["boxes"][0][0]), 2) 115 | self.assertEqual(int(prediction_result["boxes"][0][2][0]), 661) 116 | 117 | 118 | if __name__ == '__main__': 119 | unittest.main() 120 | --------------------------------------------------------------------------------
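Quick-start example (a minimal usage sketch based on the public API above; it uses the repository's own test image and the library's default thresholds — adjust paths and parameters as needed):

from craft_text_detector import detect_text

# run detection on the bundled test image and export crops,
# heatmaps and box visualizations under ./output/
prediction_result = detect_text(image_path="figures/idcard.png",
                                output_dir="output/",
                                rectify=True,
                                export_extra=True,
                                text_threshold=0.7,
                                link_threshold=0.4,
                                low_text=0.4,
                                cuda=False,
                                long_size=1280,
                                refiner=True,
                                crop_type="poly")

print("detected {} text regions".format(len(prediction_result["boxes"])))
print(prediction_result["text_crop_paths"])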