├── basenet
│   ├── __init__.py
│   └── vgg16_bn.py
├── .gitignore
├── figures
│   └── craft_example.gif
├── requirements.txt
├── LICENSE
├── imgproc.py
├── refinenet.py
├── craft.py
├── file_utils.py
├── README.md
├── test.py
└── craft_utils.py

/basenet/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.swp
3 | *.pkl
4 | *.pth
5 | result*
6 | weights*
7 | .vscode
8 | .mypy_cache/
--------------------------------------------------------------------------------
/figures/craft_example.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gaarv/CRAFT-pytorch/master/figures/craft_example.gif
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | lmdb
2 | torchvision
3 | nltk
4 | natsort
5 | opencv-python
6 | scikit-image
7 | Pillow==6.2.1
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2019-present NAVER Corp.
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/imgproc.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2019-present NAVER Corp.
3 | MIT License
4 | """
5 |
6 | # -*- coding: utf-8 -*-
7 | import numpy as np
8 | from skimage import io
9 | import cv2
10 |
11 | def loadImage(img_file):
12 |     img = io.imread(img_file)           # RGB order
13 |     if img.shape[0] == 2: img = img[0]  # take the first frame of a multi-frame image
14 |     if len(img.shape) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
15 |     if img.shape[2] == 4: img = img[:,:,:3]  # drop the alpha channel
16 |     img = np.array(img)
17 |
18 |     return img
19 |
20 | def normalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)):
21 |     # should be RGB order
22 |     img = in_img.copy().astype(np.float32)
23 |
24 |     img -= np.array([mean[0] * 255.0, mean[1] * 255.0, mean[2] * 255.0], dtype=np.float32)
25 |     img /= np.array([variance[0] * 255.0, variance[1] * 255.0, variance[2] * 255.0], dtype=np.float32)
26 |     return img
27 |
28 | def denormalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)):
29 |     # should be RGB order
30 |     img = in_img.copy()
31 |     img *= variance
32 |     img += mean
33 |     img *= 255.0
34 |     img = np.clip(img, 0, 255).astype(np.uint8)
35 |     return img
36 |
37 | def resize_aspect_ratio(img, square_size, interpolation, mag_ratio=1):
38 |     height, width, channel = img.shape
39 |
40 |     # magnify image size
41 |     target_size = mag_ratio * max(height, width)
42 |
43 |     # cap the target size at the maximum canvas size
44 |     if target_size > square_size:
45 |         target_size = square_size
46 |
47 |     ratio = target_size / max(height, width)
48 |
49 |     target_h, target_w = int(height * ratio), int(width * ratio)
50 |     proc = cv2.resize(img, (target_w, target_h), interpolation=interpolation)
51 |
52 |
53 |     # make canvas (padded up to multiples of 32) and paste image
54 |     target_h32, target_w32 = target_h, target_w
55 |     if target_h % 32 != 0:
56 |         target_h32 = target_h + (32 - target_h % 32)
57 |     if target_w % 32 != 0:
58 |         target_w32 = target_w + (32 - target_w % 32)
59 |     resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32)
60 |     resized[0:target_h, 0:target_w, :] = proc
61 |     target_h, target_w = target_h32, target_w32
62 |
63 |     size_heatmap = (int(target_w/2), int(target_h/2))
64 |
65 |     return resized, ratio, size_heatmap
66 |
67 | def cvt2HeatmapImg(img):
68 |     img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
69 |     img = cv2.applyColorMap(img, cv2.COLORMAP_JET)
70 |     return img
71 |
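72 | if __name__ == '__main__':
73 |     # Editor's sketch (not part of the original file): exercises the resize /
74 |     # normalization pipeline on a synthetic image. The 800x600 size and the
75 |     # canvas/magnification values below are arbitrary illustration choices.
76 |     dummy = (np.random.rand(600, 800, 3) * 255).astype(np.uint8)
77 |     resized, ratio, size_heatmap = resize_aspect_ratio(dummy, square_size=1280, interpolation=cv2.INTER_LINEAR, mag_ratio=1.5)
78 |     x = normalizeMeanVariance(resized)
79 |     # the canvas is padded up to multiples of 32; score maps come out at half size
80 |     print(resized.shape, ratio, size_heatmap, x.dtype)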
--------------------------------------------------------------------------------
/refinenet.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2019-present NAVER Corp.
3 | MIT License
4 | """
5 |
6 | # -*- coding: utf-8 -*-
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.autograd import Variable
11 | from basenet.vgg16_bn import init_weights
12 |
13 |
14 | class RefineNet(nn.Module):
15 |     def __init__(self):
16 |         super(RefineNet, self).__init__()
17 |
18 |         self.last_conv = nn.Sequential(
19 |             nn.Conv2d(34, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
20 |             nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
21 |             nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)
22 |         )
23 |
24 |         self.aspp1 = nn.Sequential(
25 |             nn.Conv2d(64, 128, kernel_size=3, dilation=6, padding=6), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
26 |             nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
27 |             nn.Conv2d(128, 1, kernel_size=1)
28 |         )
29 |
30 |         self.aspp2 = nn.Sequential(
31 |             nn.Conv2d(64, 128, kernel_size=3, dilation=12, padding=12), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
32 |             nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
33 |             nn.Conv2d(128, 1, kernel_size=1)
34 |         )
35 |
36 |         self.aspp3 = nn.Sequential(
37 |             nn.Conv2d(64, 128, kernel_size=3, dilation=18, padding=18), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
38 |             nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
39 |             nn.Conv2d(128, 1, kernel_size=1)
40 |         )
41 |
42 |         self.aspp4 = nn.Sequential(
43 |             nn.Conv2d(64, 128, kernel_size=3, dilation=24, padding=24), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
44 |             nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
45 |             nn.Conv2d(128, 1, kernel_size=1)
46 |         )
47 |
48 |         init_weights(self.last_conv.modules())
49 |         init_weights(self.aspp1.modules())
50 |         init_weights(self.aspp2.modules())
51 |         init_weights(self.aspp3.modules())
52 |         init_weights(self.aspp4.modules())
53 |
54 |     def forward(self, y, upconv4):
55 |         refine = torch.cat([y.permute(0,3,1,2), upconv4], dim=1)
56 |         refine = self.last_conv(refine)
57 |
58 |         aspp1 = self.aspp1(refine)
59 |         aspp2 = self.aspp2(refine)
60 |         aspp3 = self.aspp3(refine)
61 |         aspp4 = self.aspp4(refine)
62 |
63 |         #out = torch.add([aspp1, aspp2, aspp3, aspp4], dim=1)
64 |         out = aspp1 + aspp2 + aspp3 + aspp4
65 |         return out.permute(0, 2, 3, 1)  # , refine.permute(0,2,3,1)
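66 |
67 | if __name__ == '__main__':
68 |     # Editor's sketch (not part of the original file): a pure shape check.
69 |     # `y` mimics CRAFT's (N, H, W, 2) region/affinity output and the second
70 |     # input its (N, 32, H, W) upconv4 feature map, so the concatenation in
71 |     # forward() has the 34 channels that last_conv expects.
72 |     net = RefineNet()
73 |     with torch.no_grad():
74 |         out = net(torch.randn(1, 96, 96, 2), torch.randn(1, 32, 96, 96))
75 |     print(out.shape)    # expected: torch.Size([1, 96, 96, 1])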
--------------------------------------------------------------------------------
/craft.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2019-present NAVER Corp.
3 | MIT License
4 | """
5 |
6 | # -*- coding: utf-8 -*-
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from basenet.vgg16_bn import vgg16_bn, init_weights
12 |
13 | class double_conv(nn.Module):
14 |     def __init__(self, in_ch, mid_ch, out_ch):
15 |         super(double_conv, self).__init__()
16 |         self.conv = nn.Sequential(
17 |             nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
18 |             nn.BatchNorm2d(mid_ch),
19 |             nn.ReLU(inplace=True),
20 |             nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
21 |             nn.BatchNorm2d(out_ch),
22 |             nn.ReLU(inplace=True)
23 |         )
24 |
25 |     def forward(self, x):
26 |         x = self.conv(x)
27 |         return x
28 |
29 |
30 | class CRAFT(nn.Module):
31 |     def __init__(self, pretrained=False, freeze=False):
32 |         super(CRAFT, self).__init__()
33 |
34 |         """ Base network """
35 |         self.basenet = vgg16_bn(pretrained, freeze)
36 |
37 |         """ U network """
38 |         self.upconv1 = double_conv(1024, 512, 256)
39 |         self.upconv2 = double_conv(512, 256, 128)
40 |         self.upconv3 = double_conv(256, 128, 64)
41 |         self.upconv4 = double_conv(128, 64, 32)
42 |
43 |         num_class = 2
44 |         self.conv_cls = nn.Sequential(
45 |             nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
46 |             nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
47 |             nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),
48 |             nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),
49 |             nn.Conv2d(16, num_class, kernel_size=1),
50 |         )
51 |
52 |         init_weights(self.upconv1.modules())
53 |         init_weights(self.upconv2.modules())
54 |         init_weights(self.upconv3.modules())
55 |         init_weights(self.upconv4.modules())
56 |         init_weights(self.conv_cls.modules())
57 |
58 |     def forward(self, x):
59 |         """ Base network """
60 |         sources = self.basenet(x)
61 |
62 |         """ U network """
63 |         y = torch.cat([sources[0], sources[1]], dim=1)
64 |         y = self.upconv1(y)
65 |
66 |         y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False)
67 |         y = torch.cat([y, sources[2]], dim=1)
68 |         y = self.upconv2(y)
69 |
70 |         y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False)
71 |         y = torch.cat([y, sources[3]], dim=1)
72 |         y = self.upconv3(y)
73 |
74 |         y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False)
75 |         y = torch.cat([y, sources[4]], dim=1)
76 |         feature = self.upconv4(y)
77 |
78 |         y = self.conv_cls(feature)
79 |
80 |         return y.permute(0,2,3,1), feature
81 |
82 | if __name__ == '__main__':
83 |     model = CRAFT(pretrained=True).cuda()   # requires a CUDA device; drop .cuda() calls to run on CPU
84 |     output, _ = model(torch.randn(1, 3, 768, 768).cuda())
85 |     print(output.shape)
--------------------------------------------------------------------------------
/basenet/vgg16_bn.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.init as init
6 | from torchvision import models
7 | from torchvision.models.vgg import model_urls  # NOTE: model_urls was removed in torchvision 0.13, so this file needs an older torchvision
8 |
9 | def init_weights(modules):
10 |     for m in modules:
11 |         if isinstance(m, nn.Conv2d):
12 |             init.xavier_uniform_(m.weight.data)
13 |             if m.bias is not None:
14 |                 m.bias.data.zero_()
15 |         elif isinstance(m, nn.BatchNorm2d):
16 |             m.weight.data.fill_(1)
17 |             m.bias.data.zero_()
18 |         elif isinstance(m, nn.Linear):
19 |             m.weight.data.normal_(0, 0.01)
20 |             m.bias.data.zero_()
21 |
22 | class vgg16_bn(torch.nn.Module):
23 |     def __init__(self, pretrained=True, freeze=True):
24 |         super(vgg16_bn, self).__init__()
25 |         model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace('https://', 'http://')  # workaround for SSL issues when downloading pretrained weights
26 |         vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
27 |         self.slice1 = torch.nn.Sequential()
28 |         self.slice2 = torch.nn.Sequential()
29 |         self.slice3 = torch.nn.Sequential()
30 |         self.slice4 = torch.nn.Sequential()
31 |         self.slice5 = torch.nn.Sequential()
32 |         for x in range(12):         # conv2_2
33 |             self.slice1.add_module(str(x), vgg_pretrained_features[x])
34 |         for x in range(12, 19):     # conv3_3
35 |             self.slice2.add_module(str(x), vgg_pretrained_features[x])
36 |         for x in range(19, 29):     # conv4_3
37 |             self.slice3.add_module(str(x), vgg_pretrained_features[x])
38 |         for x in range(29, 39):     # conv5_3
39 |             self.slice4.add_module(str(x), vgg_pretrained_features[x])
40 |
41 |         # fc6, fc7 without atrous conv
42 |         self.slice5 = torch.nn.Sequential(
43 |             nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
44 |             nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
45 |             nn.Conv2d(1024, 1024, kernel_size=1)
46 |         )
47 |
48 |         if not pretrained:
49 |             init_weights(self.slice1.modules())
50 |             init_weights(self.slice2.modules())
51 |             init_weights(self.slice3.modules())
52 |             init_weights(self.slice4.modules())
53 |
54 |         init_weights(self.slice5.modules())     # no pretrained model for fc6 and fc7
55 |
56 |         if freeze:
57 |             for param in self.slice1.parameters():      # only first conv block
58 |                 param.requires_grad = False
59 |
60 |     def forward(self, X):
61 |         h = self.slice1(X)
62 |         h_relu2_2 = h
63 |         h = self.slice2(h)
64 |         h_relu3_2 = h
65 |         h = self.slice3(h)
66 |         h_relu4_3 = h
67 |         h = self.slice4(h)
68 |         h_relu5_3 = h
69 |         h = self.slice5(h)
70 |         h_fc7 = h
71 |         vgg_outputs = namedtuple("VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2'])
72 |         out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
73 |         return out
74 |
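75 | if __name__ == '__main__':
76 |     # Editor's sketch (not part of the original file): prints the feature
77 |     # pyramid shapes consumed by CRAFT. pretrained=False avoids a weight
78 |     # download for this quick check.
79 |     net = vgg16_bn(pretrained=False, freeze=False)
80 |     with torch.no_grad():
81 |         out = net(torch.randn(1, 3, 768, 768))
82 |     for name, t in zip(out._fields, out):
83 |         print(name, tuple(t.shape))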
--------------------------------------------------------------------------------
/file_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import numpy as np
4 | import cv2
5 | import imgproc
6 |
7 | # borrowed from https://github.com/lengstrom/fast-style-transfer/blob/master/src/utils.py
8 | def get_files(img_dir):
9 |     imgs, masks, xmls = list_files(img_dir)
10 |     return imgs, masks, xmls
11 |
12 | def list_files(in_path):
13 |     img_files = []
14 |     mask_files = []
15 |     gt_files = []
16 |     for (dirpath, dirnames, filenames) in os.walk(in_path):
17 |         for file in filenames:
18 |             filename, ext = os.path.splitext(file)
19 |             ext = str.lower(ext)
20 |             if ext == '.jpg' or ext == '.jpeg' or ext == '.gif' or ext == '.png' or ext == '.pgm':
21 |                 img_files.append(os.path.join(dirpath, file))
22 |             elif ext == '.bmp':
23 |                 mask_files.append(os.path.join(dirpath, file))
24 |             elif ext == '.xml' or ext == '.gt' or ext == '.txt':
25 |                 gt_files.append(os.path.join(dirpath, file))
26 |             elif ext == '.zip':
27 |                 continue
28 |     img_files.sort()
29 |     mask_files.sort()
30 |     gt_files.sort()
31 |     return img_files, mask_files, gt_files
32 |
33 | def saveResult(img_file, img, boxes, dirname='./result/', verticals=None, texts=None):
34 |     """ save text detection result one by one
35 |     Args:
36 |         img_file (str): image file name
37 |         img (array): raw image content
38 |         boxes (array): detected boxes or polygons
39 |             Shape: [num_detections, 4, 2] for box output / [num_detections, n, 2] for polygon output
40 |     Return:
41 |         None
42 |     """
43 |     img = np.array(img)
44 |
45 |     # make result file list
46 |     filename, file_ext = os.path.splitext(os.path.basename(img_file))
47 |
48 |     # result directory
49 |     res_file = dirname + "res_" + filename + '.txt'
50 |     res_img_file = dirname + "res_" + filename + file_ext
51 |
52 |     if not os.path.isdir(dirname):
53 |         os.mkdir(dirname)
54 |
55 |     with open(res_file, 'w') as f:
56 |         for i, box in enumerate(boxes):
57 |             poly = np.array(box).astype(np.int32).reshape((-1))
58 |             strResult = ','.join([str(p) for p in poly]) + '\r\n'
59 |             f.write(strResult)
60 |
61 |             poly = poly.reshape(-1, 2)
62 |             cv2.polylines(img, [poly.reshape((-1, 1, 2))], True, color=(0, 0, 255), thickness=2)
63 |             ptColor = (0, 255, 255)
64 |             if verticals is not None:
65 |                 if verticals[i]:
66 |                     ptColor = (255, 0, 0)
67 |
68 |             if texts is not None:
69 |                 font = cv2.FONT_HERSHEY_SIMPLEX
70 |                 font_scale = 0.5
71 |                 cv2.putText(img, "{}".format(texts[i]), (poly[0][0]+1, poly[0][1]+1), font, font_scale, (0, 0, 0), thickness=1)
72 |                 cv2.putText(img, "{}".format(texts[i]), tuple(poly[0]), font, font_scale, (0, 255, 255), thickness=1)
73 |
74 |     # Save result image
75 |     cv2.imwrite(res_img_file, img)
76 |
77 |
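78 | if __name__ == '__main__':
79 |     # Editor's sketch (not part of the original file): writes one synthetic
80 |     # detection for a blank image into ./result/ to illustrate the output format.
81 |     dummy = np.zeros((100, 200, 3), dtype=np.uint8)
82 |     box = np.array([[10, 10], [90, 10], [90, 40], [10, 40]])
83 |     saveResult('dummy.jpg', dummy, [box], texts=['sample'])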
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## CRAFT: Character-Region Awareness For Text detection
2 | Official PyTorch implementation of the CRAFT text detector | [Paper](https://arxiv.org/abs/1904.01941) | [Pretrained Model](https://drive.google.com/open?id=1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ) | [Supplementary](https://youtu.be/HI8MzpY8KMI)
3 |
4 | **[Youngmin Baek](mailto:youngmin.baek@navercorp.com), Bado Lee, Dongyoon Han, Sangdoo Yun, Hwalsuk Lee.**
5 |
6 | Clova AI Research, NAVER Corp.
7 |
8 | ### Sample Results
9 |
10 | ### Overview
11 | PyTorch implementation of the CRAFT text detector, which detects text regions by exploring each character region and the affinity between characters. Text bounding boxes are obtained by finding minimum bounding rectangles on the binary map produced by thresholding the character-region and affinity scores.
12 |
13 | ![teaser](./figures/craft_example.gif)
14 |
15 | ## Updates
16 | **13 Jun, 2019**: Initial update
17 | **20 Jul, 2019**: Added post-processing for polygon result
18 | **28 Sep, 2019**: Added the trained model on IC15 and the link refiner
19 |
20 |
21 | ## Getting started
22 | ### Install dependencies
23 | #### Requirements
24 | - PyTorch>=0.4.1
25 | - torchvision>=0.2.1
26 | - opencv-python>=3.4.2
27 | - check requirements.txt
28 | ```
29 | pip install -r requirements.txt
30 | ```
31 |
32 | ### Training
33 | The code for training is not included in this repository, and we cannot release the full training code for IP reasons.
34 |
35 |
36 | ### Test instruction using pretrained model
37 | - Download the trained models
38 |
39 | | *Model name* | *Used datasets* | *Languages* | *Purpose* | *Model Link* |
40 | | :--- | :--- | :--- | :--- | :--- |
41 | | General | SynthText, IC13, IC17 | Eng + MLT | For general purpose | [Click](https://drive.google.com/open?id=1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ) |
42 | | IC15 | SynthText, IC15 | Eng | For IC15 only | [Click](https://drive.google.com/open?id=1i2R7UIUqmkUtF0jv_3MXTqmQ_9wuAnLf) |
43 | | LinkRefiner | CTW1500 | - | Used with the General Model | [Click](https://drive.google.com/open?id=1XSaFwBkOaFOdtk4Ane3DFyJGPRw6v5bO) |
44 |
45 | * Run with pretrained model
46 | ``` (with python 3.7)
47 | python test.py --trained_model=[weightfile] --test_folder=[folder path to test images]
48 | ```
49 |
50 | The result images and score maps will be saved to `./result` by default.
51 |
52 | ### Arguments
53 | * `--trained_model`: pretrained model
54 | * `--text_threshold`: text confidence threshold
55 | * `--low_text`: text low-bound score
56 | * `--link_threshold`: link confidence threshold
57 | * `--cuda`: use cuda for inference (default: True)
58 | * `--canvas_size`: max image size for inference
59 | * `--mag_ratio`: image magnification ratio
60 | * `--poly`: enable polygon type result
61 | * `--show_time`: show processing time
62 | * `--test_folder`: folder path to input images
63 | * `--refine`: use link refiner for sentence-level datasets
64 | * `--refiner_model`: pretrained refiner model
65 | * `--save_bboxes`: save each detected box as a cropped image (default: True)
66 | * `--save_result`: save the result image with bounding boxes and the score-map mask (default: False)
67 | * `--output_folder`: folder path to output images (default: `./result/`)
68 |
69 |
70 | ## Links
71 | - WebDemo : https://demo.ocr.clova.ai/
72 | - Repo of recognition : https://github.com/clovaai/deep-text-recognition-benchmark
73 |
74 | ## Citation
75 | ```
76 | @inproceedings{baek2019character,
77 |   title={Character Region Awareness for Text Detection},
78 |   author={Baek, Youngmin and Lee, Bado and Han, Dongyoon and Yun, Sangdoo and Lee, Hwalsuk},
79 |   booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
80 |   pages={9365--9374},
81 |   year={2019}
82 | }
83 | ```
84 |
85 | ## License
86 | ```
87 | Copyright (c) 2019-present NAVER Corp.
88 |
89 | Permission is hereby granted, free of charge, to any person obtaining a copy
90 | of this software and associated documentation files (the "Software"), to deal
91 | in the Software without restriction, including without limitation the rights
92 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
93 | copies of the Software, and to permit persons to whom the Software is
94 | furnished to do so, subject to the following conditions:
95 |
96 | The above copyright notice and this permission notice shall be included in
97 | all copies or substantial portions of the Software.
98 |
99 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
100 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
101 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
102 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
103 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
104 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
105 | THE SOFTWARE.
106 | ```
107 |
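108 | ## Programmatic usage
109 | A minimal sketch (editor's addition, not an official API) of running the detector from Python instead of the `test.py` CLI; the weight path, image path, and threshold values below are placeholders to adjust for your setup.
110 | ```python
111 | import torch, cv2
112 | import imgproc, craft_utils
113 | from craft import CRAFT
114 |
115 | net = CRAFT()
116 | sd = torch.load('weights/craft_mlt_25k.pth', map_location='cpu')  # placeholder path
117 | sd = {(k[7:] if k.startswith('module.') else k): v for k, v in sd.items()}  # strip DataParallel prefix
118 | net.load_state_dict(sd)
119 | net.eval()
120 |
121 | image = imgproc.loadImage('sample.jpg')  # placeholder image
122 | img_resized, ratio, _ = imgproc.resize_aspect_ratio(image, 1280, cv2.INTER_LINEAR, mag_ratio=1.5)
123 | x = torch.from_numpy(imgproc.normalizeMeanVariance(img_resized)).permute(2, 0, 1).unsqueeze(0)
124 | with torch.no_grad():
125 |     y, _ = net(x)
126 | boxes, _ = craft_utils.getDetBoxes(y[0, :, :, 0].numpy(), y[0, :, :, 1].numpy(),
127 |                                    text_threshold=0.7, link_threshold=0.4, low_text=0.4)
128 | boxes = craft_utils.adjustResultCoordinates(boxes, 1 / ratio, 1 / ratio)
129 | print('detected', len(boxes), 'boxes')
130 | ```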
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2019-present NAVER Corp.
3 | MIT License
4 | """
5 |
6 | # -*- coding: utf-8 -*-
7 | import sys
8 | import os
9 | import time
10 | import argparse
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.backends.cudnn as cudnn
15 | from torch.autograd import Variable
16 |
17 | from PIL import Image
18 |
19 | import cv2
20 | from skimage import io
21 | import numpy as np
22 | import craft_utils
23 | import imgproc
24 | import file_utils
25 | import json
26 | import zipfile
27 |
28 | from craft import CRAFT
29 |
30 | from collections import OrderedDict
31 | def copyStateDict(state_dict):
32 |     if list(state_dict.keys())[0].startswith("module"):
33 |         start_idx = 1
34 |     else:
35 |         start_idx = 0
36 |     new_state_dict = OrderedDict()
37 |     for k, v in state_dict.items():
38 |         name = ".".join(k.split(".")[start_idx:])
39 |         new_state_dict[name] = v
40 |     return new_state_dict
41 |
42 | def str2bool(v):
43 |     if isinstance(v, bool):
44 |         return v
45 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
46 |         return True
47 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
48 |         return False
49 |     else:
50 |         raise argparse.ArgumentTypeError('Boolean value expected.')
51 |
52 | parser = argparse.ArgumentParser(description='CRAFT Text Detection')
53 | parser.add_argument('--trained_model', default='weights/craft_mlt_25k.pth', type=str, help='pretrained model')
54 | parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
55 | parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
56 | parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
57 | parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')
58 | parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
59 | parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
60 | parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
61 | parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')
62 | parser.add_argument('--test_folder', default='/data/', type=str, help='folder path to input images')
63 | parser.add_argument('--refine', default=False, action='store_true', help='enable link refiner')
64 | parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str, help='pretrained refiner model')
65 | parser.add_argument('--save_bboxes', default=True, type=str2bool, help='save bounding boxes as cropped images')
66 | parser.add_argument('--save_result', default=False, type=str2bool, help='save result image with bounding boxes')
67 | parser.add_argument('--output_folder', default='./result/', type=str, help='folder path to output images')
68 |
69 | args = parser.parse_args()
70 |
71 |
72 | """ For test images in a folder """
73 | image_list, _, _ = file_utils.get_files(args.test_folder)
74 |
75 | result_folder = args.output_folder
76 | if not os.path.isdir(result_folder):
77 |     os.mkdir(result_folder)
78 |
79 | def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None):
80 |     t0 = time.time()
81 |
82 |     # resize
83 |     img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(image, args.canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=args.mag_ratio)
84 |     ratio_h = ratio_w = 1 / target_ratio
85 |
86 |     # preprocessing
87 |     x = imgproc.normalizeMeanVariance(img_resized)
88 |     x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
89 |     x = Variable(x.unsqueeze(0))                # [c, h, w] to [b, c, h, w]
90 |     if cuda:
91 |         x = x.cuda()
92 |
93 |     # forward pass
94 |     with torch.no_grad():
95 |         y, feature = net(x)
96 |
97 |     # make score and link map
98 |     score_text = y[0,:,:,0].cpu().data.numpy()
99 |     score_link = y[0,:,:,1].cpu().data.numpy()
100 |
101 |     # refine link
102 |     if refine_net is not None:
103 |         with torch.no_grad():
104 |             y_refiner = refine_net(y, feature)
105 |         score_link = y_refiner[0,:,:,0].cpu().data.numpy()
106 |
107 |     t0 = time.time() - t0
108 |     t1 = time.time()
109 |
110 |     # Post-processing
111 |     boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly)
112 |
113 |     # coordinate adjustment
114 |     boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
115 |     polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
116 |     for k in range(len(polys)):
117 |         if polys[k] is None: polys[k] = boxes[k]
118 |
119 |     t1 = time.time() - t1
120 |
121 |     # render results (optional)
122 |     render_img = score_text.copy()
123 |     render_img = np.hstack((render_img, score_link))
124 |     ret_score_text = imgproc.cvt2HeatmapImg(render_img)
125 |
126 |     if args.show_time: print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))
127 |
128 |     return boxes, polys, ret_score_text
129 |
130 |
131 | # convert bounding shape to bounding box
132 | def bounding_box(points):
133 |     points = points.astype(np.int16)
134 |     x_coordinates, y_coordinates = zip(*points)
135 |     return [(min(x_coordinates), min(y_coordinates)), (max(x_coordinates), max(y_coordinates))]
136 |
137 |
138 | if __name__ == '__main__':
139 |     # load net
140 |     net = CRAFT()     # initialize
141 |
142 |     print('Loading weights from checkpoint (' + args.trained_model + ')')
143 |     if args.cuda:
144 |         net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
145 |     else:
146 |         net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu')))
147 |
148 |     if args.cuda:
149 |         net = net.cuda()
150 |         net = torch.nn.DataParallel(net)
151 |         cudnn.benchmark = False
152 |
153 |     net.eval()
154 |
155 |     # LinkRefiner
156 |     refine_net = None
157 |     if args.refine:
158 |         from refinenet import RefineNet
159 |         refine_net = RefineNet()
160 |         print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
161 |         if args.cuda:
162 |             refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
163 |             refine_net = refine_net.cuda()
164 |             refine_net = torch.nn.DataParallel(refine_net)
165 |         else:
166 |             refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))
167 |
168 |         refine_net.eval()
169 |         args.poly = True
170 |
171 |     t = time.time()
172 |
173 |     # load data
174 |     for k, image_path in enumerate(image_list):
175 |         print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
176 |         image = imgproc.loadImage(image_path)
177 |
178 |         bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly, refine_net)
179 |
180 |         filename, file_ext = os.path.splitext(os.path.basename(image_path))
181 |
182 |         # save cropped boxes
183 |         if args.save_bboxes:
184 |             for i, bbs in enumerate(bboxes):
185 |                 crop = bounding_box(bbs)
186 |                 cropped = image[crop[0][1]:crop[1][1], crop[0][0]:crop[1][0]]
187 |                 cv2.imwrite(result_folder + '/res_' + filename + '_cropped_' + str(i).zfill(4) + file_ext, cropped)
188 |
189 |         # save score text
190 |         if args.save_result:
191 |             mask_file = result_folder + "/res_" + filename + '_mask' + file_ext
192 |             cv2.imwrite(mask_file, score_text)
193 |             file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)
194 |
195 |     print("elapsed time : {}s".format(time.time() - t))
--------------------------------------------------------------------------------
/craft_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2019-present NAVER Corp.
3 | MIT License
4 | """
5 |
6 | # -*- coding: utf-8 -*-
7 | import numpy as np
8 | import cv2
9 | import math
10 |
11 | """ auxiliary functions """
12 | # unwarp coordinates
13 | def warpCoord(Minv, pt):
14 |     out = np.matmul(Minv, (pt[0], pt[1], 1))
15 |     return np.array([out[0]/out[2], out[1]/out[2]])
16 | """ end of auxiliary functions """
17 |
18 |
19 | def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text):
20 |     # prepare data
21 |     linkmap = linkmap.copy()
22 |     textmap = textmap.copy()
23 |     img_h, img_w = textmap.shape
24 |
25 |     """ labeling method """
26 |     ret, text_score = cv2.threshold(textmap, low_text, 1, 0)
27 |     ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0)
28 |
29 |     text_score_comb = np.clip(text_score + link_score, 0, 1)
30 |     nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(text_score_comb.astype(np.uint8), connectivity=4)
31 |
32 |     det = []
33 |     mapper = []
34 |     for k in range(1,nLabels):
35 |         # size filtering
36 |         size = stats[k, cv2.CC_STAT_AREA]
37 |         if size < 10: continue
38 |
39 |         # thresholding
40 |         if np.max(textmap[labels==k]) < text_threshold: continue
41 |
42 |         # make segmentation map
43 |         segmap = np.zeros(textmap.shape, dtype=np.uint8)
44 |         segmap[labels==k] = 255
45 |         segmap[np.logical_and(link_score==1, text_score==0)] = 0   # remove link area
46 |         x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]
47 |         w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]
48 |         niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2)
49 |         sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1
50 |         # boundary check
51 |         if sx < 0: sx = 0
52 |         if sy < 0: sy = 0
53 |         if ex >= img_w: ex = img_w
54 |         if ey >= img_h: ey = img_h
55 |         kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1 + niter, 1 + niter))
56 |         segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel)
57 |
58 |         # make box
59 |         np_contours = np.roll(np.array(np.where(segmap!=0)),1,axis=0).transpose().reshape(-1,2)
60 |         rectangle = cv2.minAreaRect(np_contours)
61 |         box = cv2.boxPoints(rectangle)
62 |
63 |         # align diamond-shape
64 |         w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
65 |         box_ratio = max(w, h) / (min(w, h) + 1e-5)
66 |         if abs(1 - box_ratio) <= 0.1:
67 |             l, r = min(np_contours[:,0]), max(np_contours[:,0])
68 |             t, b = min(np_contours[:,1]), max(np_contours[:,1])
69 |             box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)
70 |
71 |         # make clockwise order
72 |         startidx = box.sum(axis=1).argmin()
73 |         box = np.roll(box, 4-startidx, 0)
74 |         box = np.array(box)
75 |
76 |         det.append(box)
77 |         mapper.append(k)
78 |
79 |     return det, labels, mapper
80 |
81 | def getPoly_core(boxes, labels, mapper, linkmap):
82 |     # configs
83 |     num_cp = 5
84 |     max_len_ratio = 0.7
85 |     expand_ratio = 1.45
86 |     max_r = 2.0
87 |     step_r = 0.2
88 |
89 |     polys = []
90 |     for k, box in enumerate(boxes):
91 |         # size filter for small instance
92 |         w, h = int(np.linalg.norm(box[0] - box[1]) + 1), int(np.linalg.norm(box[1] - box[2]) + 1)
93 |         if w < 10 or h < 10:
94 |             polys.append(None); continue
95 |
96 |         # warp image
97 |         tar = np.float32([[0,0],[w,0],[w,h],[0,h]])
98 |         M = cv2.getPerspectiveTransform(box, tar)
99 |         word_label = cv2.warpPerspective(labels, M, (w, h), flags=cv2.INTER_NEAREST)
100 |         try:
101 |             Minv = np.linalg.inv(M)
102 |         except:
103 |             polys.append(None); continue
104 |
105 |         # binarization for selected label
106 |         cur_label = mapper[k]
107 |         word_label[word_label != cur_label] = 0
108 |         word_label[word_label > 0] = 1
109 |
110 |         """ Polygon generation """
111 |         # find top/bottom contours
112 |         cp = []
113 |         max_len = -1
114 |         for i in range(w):
115 |             region = np.where(word_label[:,i] != 0)[0]
116 |             if len(region) < 2: continue
117 |             cp.append((i, region[0], region[-1]))
118 |             length = region[-1] - region[0] + 1
119 |             if length > max_len: max_len = length
120 |
121 |         # skip if max_len is close to h (region is not horizontal enough)
122 |         if h * max_len_ratio < max_len:
123 |             polys.append(None); continue
124 |
125 |         # get pivot points with fixed length
126 |         tot_seg = num_cp * 2 + 1
127 |         seg_w = w / tot_seg     # segment width
128 |         pp = [None] * num_cp    # init pivot points
129 |         cp_section = [[0, 0]] * tot_seg
130 |         seg_height = [0] * num_cp
131 |         seg_num = 0
132 |         num_sec = 0
133 |         prev_h = -1
134 |         for i in range(0,len(cp)):
135 |             (x, sy, ey) = cp[i]
136 |             if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg:
137 |                 # average previous segment
138 |                 if num_sec == 0: break
139 |                 cp_section[seg_num] = [cp_section[seg_num][0] / num_sec, cp_section[seg_num][1] / num_sec]
140 |                 num_sec = 0
141 |
142 |                 # reset variables
143 |                 seg_num += 1
144 |                 prev_h = -1
145 |
146 |             # accumulate center points
147 |             cy = (sy + ey) * 0.5
148 |             cur_h = ey - sy + 1
149 |             cp_section[seg_num] = [cp_section[seg_num][0] + x, cp_section[seg_num][1] + cy]
150 |             num_sec += 1
151 |
152 |             if seg_num % 2 == 0: continue   # no polygon area
153 |
154 |             if prev_h < cur_h:
155 |                 pp[int((seg_num - 1)/2)] = (x, cy)
156 |                 seg_height[int((seg_num - 1)/2)] = cur_h
157 |                 prev_h = cur_h
158 |
159 |         # processing last segment
160 |         if num_sec != 0:
161 |             cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec]
162 |
163 |         # pass if num of pivots is not sufficient or segment width is smaller than character height
164 |         if None in pp or seg_w < np.max(seg_height) * 0.25:
165 |             polys.append(None); continue
166 |
167 |         # calc median of the per-segment maximum heights
168 |         half_char_h = np.median(seg_height) * expand_ratio / 2
169 |
170 |         # calc gradient and apply to make horizontal pivots
171 |         new_pp = []
172 |         for i, (x, cy) in enumerate(pp):
173 |             dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0]
174 |             dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1]
175 |             if dx == 0:     # gradient is zero
176 |                 new_pp.append([x, cy - half_char_h, x, cy + half_char_h])
177 |                 continue
178 |             rad = - math.atan2(dy, dx)
179 |             c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad)
180 |             new_pp.append([x - s, cy - c, x + s, cy + c])
181 |
182 |         # get edge points to cover character heatmaps
183 |         isSppFound, isEppFound = False, False
184 |         grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0])
185 |         grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0])
186 |         for r in np.arange(0.5, max_r, step_r):
187 |             dx = 2 * half_char_h * r
188 |             if not isSppFound:
189 |                 line_img = np.zeros(word_label.shape, dtype=np.uint8)
190 |                 dy = grad_s * dx
191 |                 p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy])
192 |                 cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
193 |                 if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
194 |                     spp = p
195 |                     isSppFound = True
196 |             if not isEppFound:
197 |                 line_img = np.zeros(word_label.shape, dtype=np.uint8)
198 |                 dy = grad_e * dx
199 |                 p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy])
200 |                 cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
201 |                 if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
202 |                     epp = p
203 |                     isEppFound = True
204 |             if isSppFound and isEppFound:
205 |                 break
206 |
207 |         # pass if boundary of polygon is not found
208 |         if not (isSppFound and isEppFound):
209 |             polys.append(None); continue
210 |
211 |         # make final polygon
212 |         poly = []
213 |         poly.append(warpCoord(Minv, (spp[0], spp[1])))
214 |         for p in new_pp:
215 |             poly.append(warpCoord(Minv, (p[0], p[1])))
216 |         poly.append(warpCoord(Minv, (epp[0], epp[1])))
217 |         poly.append(warpCoord(Minv, (epp[2], epp[3])))
218 |         for p in reversed(new_pp):
219 |             poly.append(warpCoord(Minv, (p[2], p[3])))
220 |         poly.append(warpCoord(Minv, (spp[2], spp[3])))
221 |
222 |         # add to final result
223 |         polys.append(np.array(poly))
224 |
225 |     return polys
226 |
227 | def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False):
228 |     boxes, labels, mapper = getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text)
229 |
230 |     if poly:
231 |         polys = getPoly_core(boxes, labels, mapper, linkmap)
232 |     else:
233 |         polys = [None] * len(boxes)
234 |
235 |     return boxes, polys
236 |
237 | def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2):
238 |     if len(polys) > 0:
239 |         polys = np.array(polys)
240 |         for k in range(len(polys)):
241 |             if polys[k] is not None:
242 |                 polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net)
243 |     return polys
244 |
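245 | if __name__ == '__main__':
246 |     # Editor's sketch (not part of the original file): runs the box decoder on
247 |     # synthetic score maps; a single bright region should yield a single box.
248 |     textmap = np.zeros((64, 64), dtype=np.float32)
249 |     textmap[20:30, 10:50] = 0.9
250 |     linkmap = np.zeros_like(textmap)
251 |     boxes, polys = getDetBoxes(textmap, linkmap, text_threshold=0.7, link_threshold=0.4, low_text=0.4)
252 |     print(len(boxes), 'box(es) found:', boxes[0].tolist() if boxes else None)
--------------------------------------------------------------------------------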