├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── basenet ├── __init__.py └── vgg16_bn.py ├── craft.py ├── craft_utils.py ├── data ├── cookbook.jpg └── textbook.jpg ├── figures └── craft_example.gif ├── file_utils.py ├── imgproc.py ├── onnx-export.py ├── onnx-inference.py ├── onnx └── placeholder.txt ├── requirements.txt └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.swp 3 | *.pkl 4 | *.pth 5 | result* 6 | weights* -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Use Python 3.9 slim image 2 | FROM python:3.9 3 | 4 | RUN pip install --upgrade pip 5 | 6 | # Install system dependencies 7 | RUN apt-get update && apt-get install -y \ 8 | libglib2.0-0 \ 9 | libsm6 \ 10 | libxext6 \ 11 | libxrender-dev \ 12 | libgl1-mesa-glx \ 13 | && apt-get clean \ 14 | && rm -rf /var/lib/apt/lists/* 15 | 16 | # Set the working directory in the container 17 | WORKDIR /app 18 | 19 | # Copy the current directory contents into the container at /app 20 | COPY . /app 21 | 22 | # Install torch and torchvision first (which will bring in opencv) 23 | # RUN pip install numpy>=1.24.0 24 | RUN pip install torch==2.2.2+cpu torchvision==0.17.2+cpu --index-url https://download.pytorch.org/whl/cpu 25 | 26 | 27 | # Install any needed packages specified in requirements.txt 28 | RUN pip install --no-cache-dir -r requirements.txt 29 | 30 | 31 | # Install craft-text-detector separately with --no-deps 32 | RUN pip install --no-deps craft-text-detector==0.4.3 33 | RUN pip install opencv-python==4.5.4.60 34 | RUN pip install gdown==5.2.0 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019-present NAVER Corp. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CRAFT Text Detection ONNX Export & Inference 🚀 2 | 3 | This repository contains code to export CRAFT (Character Region Awareness For Text Detection) model to ONNX format and run inference. 
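For orientation, a minimal sketch of what inference with the exported model looks like is shown below; the repo's `onnx-inference.py` contains the full pipeline (aspect-ratio resizing, box decoding, and result export), and the paths and the 1280x960 input size used here are this repo's defaults.

```python
# Minimal sketch (assumes onnx-export.py has already produced onnx/craft-detect.onnx).
import cv2
import numpy as np
import onnxruntime as rt

import imgproc  # repo module: ImageNet mean/variance normalization

sess = rt.InferenceSession("onnx/craft-detect.onnx")
img = cv2.imread("data/textbook.jpg")
img = cv2.resize(img, (1280, 960))            # exported graph expects a fixed 1280x960 input
x = imgproc.normalizeMeanVariance(img)        # float32, mean/variance normalized
x = np.transpose(x, (2, 0, 1))[np.newaxis]    # HWC -> NCHW, add batch dimension
y, _ = sess.run(None, {sess.get_inputs()[0].name: x})
score_text, score_link = y[0, :, :, 0], y[0, :, :, 1]  # half-resolution region/affinity maps
```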
4 | 5 | ## 💎 Support 6 | If you find this tool useful, consider [becoming a sponsor](https://github.com/sponsors/ajinkya933) for $1/month and get priority support! 7 | 8 | --- 9 | 10 | ## Prerequisites 📋 11 | - Docker 🐳 12 | - High-resolution text images 🖼️ 13 | 14 | ## Directory Structure 📁 15 | ``` 16 | 17 | ├── data/ # Add your high-res images here 18 | ├── onnx/ # ONNX model will be exported here 19 | ├── weights/ # craft_mlt_25k.pth model goes here 20 | ├── outputs/ # Detection results will be saved here 21 | ├── Dockerfile 22 | ├── onnx-export.py 23 | └── onnx-inference.py 24 | ``` 25 | 26 | ## Quick Start 🏃‍♂️ 27 | 28 | 1. **Add Images and download pth file** 📸 29 | - Place your high-resolution text images in the `data` directory 30 | - Download pytorch model from [here](https://drive.google.com/file/d/1yN6_XLZVuKGL-3-w9MuqPqiM3QfAPVGV/view?usp=sharing), and save it in `weights` folder 31 | 32 | 2. **Build Docker Image** 🔨 33 | ```bash 34 | docker build -t craft-onnx:latest . 35 | ``` 36 | 37 | 3. **Run Docker Container** 🐋 38 | ```bash 39 | docker run -it craft-onnx:latest /bin/bash 40 | ``` 41 | 42 | 4. **Export ONNX Model** 📤 43 | ```bash 44 | python3 onnx-export.py 45 | ``` 46 | This will: 47 | - Take a sample image from `data` directory 48 | - Export ONNX graph to `onnx` folder 49 | 50 | 5. **Run Inference** 🔍 51 | ```bash 52 | python3 onnx-inference.py 53 | ``` 54 | - Uses the exported ONNX model 55 | - Saves detection results in `outputs` directory 56 | 57 | ## Model Details ⚙️ 58 | - Input size: 1280x960 59 | - Optimized for high-resolution document images 60 | - CPU-friendly inference 61 | 62 | ## Notes 📝 63 | - Make sure images are readable and have sufficient resolution 64 | - The model works best with clear, well-lit document images 65 | - Check `outputs` directory for detection results 66 | 67 | ## 💎 Sponsor $1 per month 68 | 👉 [My GitHub Sponsors link](https://github.com/sponsors/ajinkya933) 69 | 70 | 71 | ### 🌟 Sponsor Benefits 72 | - 🏢 Priority support for integrating this tool into your company's infrastructure 73 | - 🛠️ Direct assistance with project-related issues and customizations 74 | - 💡 Technical consultation for your specific use cases 75 | - 🚀 Early access to new features and improvements 76 | - ⭐ Recognition in our sponsors list 77 | 78 | ### Why Sponsor? 79 | Your sponsorship helps me maintain the code, ensuring it remains a robust and reliable tool for the community. Every contribution, no matter how small, makes a difference! 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | ## License 📄 89 | [MIT License](LICENSE) 90 | 91 | ### This Repository 92 | This repository's modifications and additions are licensed under MIT License. 
93 | 94 | ### Third-Party Licenses 95 | - CRAFT Text Detector: [MIT License](https://github.com/clovaai/CRAFT-pytorch/blob/master/LICENSE) 96 | - PyTorch: [BSD License](https://github.com/pytorch/pytorch/blob/master/LICENSE) 97 | - ONNX Runtime: [MIT License](https://github.com/microsoft/onnxruntime/blob/master/LICENSE) 98 | 99 | ### Acknowledgments 🙏 100 | This work builds upon: 101 | - [CRAFT-pytorch](https://github.com/clovaai/CRAFT-pytorch) by CLOVA AI Research 102 | - Other open source projects listed in requirements.txt 103 | -------------------------------------------------------------------------------- /basenet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajinkya933/CRAFT-pytorch/42b9ec502698147df110a1b90c5400ae4d12c5cf/basenet/__init__.py -------------------------------------------------------------------------------- /basenet/vgg16_bn.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | from torchvision import models 6 | from torchvision.models.vgg import VGG16_BN_Weights 7 | 8 | def init_weights(modules): 9 | for m in modules: 10 | if isinstance(m, nn.Conv2d): 11 | init.xavier_uniform_(m.weight.data) 12 | if m.bias is not None: 13 | m.bias.data.zero_() 14 | elif isinstance(m, nn.BatchNorm2d): 15 | m.weight.data.fill_(1) 16 | m.bias.data.zero_() 17 | elif isinstance(m, nn.Linear): 18 | m.weight.data.normal_(0, 0.01) 19 | m.bias.data.zero_() 20 | 21 | class vgg16_bn(torch.nn.Module): 22 | def __init__(self, pretrained=True, freeze=True): 23 | super(vgg16_bn, self).__init__() 24 | # Use the weights parameter based on the pretrained flag 25 | weights = VGG16_BN_Weights.IMAGENET1K_V1 if pretrained else None 26 | vgg_pretrained_features = models.vgg16_bn(weights=weights).features 27 | 28 | self.slice1 = torch.nn.Sequential() 29 | self.slice2 = torch.nn.Sequential() 30 | self.slice3 = torch.nn.Sequential() 31 | self.slice4 = torch.nn.Sequential() 32 | self.slice5 = torch.nn.Sequential() 33 | 34 | for x in range(12): # conv2_2 35 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 36 | for x in range(12, 19): # conv3_3 37 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 38 | for x in range(19, 29): # conv4_3 39 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 40 | for x in range(29, 39): # conv5_3 41 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 42 | 43 | # fc6, fc7 without atrous conv 44 | self.slice5 = torch.nn.Sequential( 45 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1), 46 | nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6), 47 | nn.Conv2d(1024, 1024, kernel_size=1) 48 | ) 49 | 50 | if not pretrained: 51 | init_weights(self.slice1.modules()) 52 | init_weights(self.slice2.modules()) 53 | init_weights(self.slice3.modules()) 54 | init_weights(self.slice4.modules()) 55 | 56 | init_weights(self.slice5.modules()) # no pretrained model for fc6 and fc7 57 | 58 | if freeze: 59 | for param in self.slice1.parameters(): # only first conv 60 | param.requires_grad = False 61 | 62 | def forward(self, X): 63 | h = self.slice1(X) 64 | h_relu2_2 = h 65 | h = self.slice2(h) 66 | h_relu3_2 = h 67 | h = self.slice3(h) 68 | h_relu4_3 = h 69 | h = self.slice4(h) 70 | h_relu5_3 = h 71 | h = self.slice5(h) 72 | h_fc7 = h 73 | vgg_outputs = namedtuple("VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2']) 
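# The named tuple simply bundles these intermediate activations; craft.py consumes
# them as the skip connections of its U-shaped decoder (sources[0] through sources[4]).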
74 | out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2) 75 | return out -------------------------------------------------------------------------------- /craft.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 3 | MIT License 4 | """ 5 | 6 | # -*- coding: utf-8 -*- 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from basenet.vgg16_bn import vgg16_bn, init_weights 12 | 13 | class double_conv(nn.Module): 14 | def __init__(self, in_ch, mid_ch, out_ch): 15 | super(double_conv, self).__init__() 16 | self.conv = nn.Sequential( 17 | nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1), 18 | nn.BatchNorm2d(mid_ch), 19 | nn.ReLU(inplace=True), 20 | nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1), 21 | nn.BatchNorm2d(out_ch), 22 | nn.ReLU(inplace=True) 23 | ) 24 | 25 | def forward(self, x): 26 | x = self.conv(x) 27 | return x 28 | 29 | 30 | class CRAFT(nn.Module): 31 | def __init__(self, pretrained=False, freeze=False): 32 | super(CRAFT, self).__init__() 33 | 34 | """ Base network """ 35 | self.basenet = vgg16_bn(pretrained, freeze) 36 | 37 | """ U network """ 38 | self.upconv1 = double_conv(1024, 512, 256) 39 | self.upconv2 = double_conv(512, 256, 128) 40 | self.upconv3 = double_conv(256, 128, 64) 41 | self.upconv4 = double_conv(128, 64, 32) 42 | 43 | num_class = 2 44 | self.conv_cls = nn.Sequential( 45 | nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True), 46 | nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True), 47 | nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True), 48 | nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True), 49 | nn.Conv2d(16, num_class, kernel_size=1), 50 | ) 51 | 52 | init_weights(self.upconv1.modules()) 53 | init_weights(self.upconv2.modules()) 54 | init_weights(self.upconv3.modules()) 55 | init_weights(self.upconv4.modules()) 56 | init_weights(self.conv_cls.modules()) 57 | 58 | def forward(self, x): 59 | """ Base network """ 60 | sources = self.basenet(x) 61 | 62 | """ U network """ 63 | y = torch.cat([sources[0], sources[1]], dim=1) 64 | y = self.upconv1(y) 65 | 66 | y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False) 67 | y = torch.cat([y, sources[2]], dim=1) 68 | y = self.upconv2(y) 69 | 70 | y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False) 71 | y = torch.cat([y, sources[3]], dim=1) 72 | y = self.upconv3(y) 73 | 74 | y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False) 75 | y = torch.cat([y, sources[4]], dim=1) 76 | feature = self.upconv4(y) 77 | 78 | y = self.conv_cls(feature) 79 | 80 | return y.permute(0,2,3,1), feature 81 | 82 | if __name__ == '__main__': 83 | model = CRAFT(pretrained=True).cuda() 84 | output, _ = model(torch.randn(1, 3, 768, 768).cuda()) 85 | print(output.shape) -------------------------------------------------------------------------------- /craft_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 
3 | MIT License 4 | """ 5 | 6 | # -*- coding: utf-8 -*- 7 | import numpy as np 8 | import cv2 9 | import math 10 | 11 | """ auxilary functions """ 12 | # unwarp corodinates 13 | def warpCoord(Minv, pt): 14 | out = np.matmul(Minv, (pt[0], pt[1], 1)) 15 | return np.array([out[0]/out[2], out[1]/out[2]]) 16 | """ end of auxilary functions """ 17 | 18 | 19 | def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text): 20 | # prepare data 21 | linkmap = linkmap.copy() 22 | textmap = textmap.copy() 23 | img_h, img_w = textmap.shape 24 | 25 | """ labeling method """ 26 | ret, text_score = cv2.threshold(textmap, low_text, 1, 0) 27 | ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0) 28 | 29 | text_score_comb = np.clip(text_score + link_score, 0, 1) 30 | nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(text_score_comb.astype(np.uint8), connectivity=4) 31 | 32 | det = [] 33 | mapper = [] 34 | for k in range(1,nLabels): 35 | # size filtering 36 | size = stats[k, cv2.CC_STAT_AREA] 37 | if size < 10: continue 38 | 39 | # thresholding 40 | if np.max(textmap[labels==k]) < text_threshold: continue 41 | 42 | # make segmentation map 43 | segmap = np.zeros(textmap.shape, dtype=np.uint8) 44 | segmap[labels==k] = 255 45 | segmap[np.logical_and(link_score==1, text_score==0)] = 0 # remove link area 46 | x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP] 47 | w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT] 48 | niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2) 49 | sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1 50 | # boundary check 51 | if sx < 0 : sx = 0 52 | if sy < 0 : sy = 0 53 | if ex >= img_w: ex = img_w 54 | if ey >= img_h: ey = img_h 55 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(1 + niter, 1 + niter)) 56 | segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel) 57 | 58 | # make box 59 | np_contours = np.roll(np.array(np.where(segmap!=0)),1,axis=0).transpose().reshape(-1,2) 60 | rectangle = cv2.minAreaRect(np_contours) 61 | box = cv2.boxPoints(rectangle) 62 | 63 | # align diamond-shape 64 | w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2]) 65 | box_ratio = max(w, h) / (min(w, h) + 1e-5) 66 | if abs(1 - box_ratio) <= 0.1: 67 | l, r = min(np_contours[:,0]), max(np_contours[:,0]) 68 | t, b = min(np_contours[:,1]), max(np_contours[:,1]) 69 | box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32) 70 | 71 | # make clock-wise order 72 | startidx = box.sum(axis=1).argmin() 73 | box = np.roll(box, 4-startidx, 0) 74 | box = np.array(box) 75 | 76 | det.append(box) 77 | mapper.append(k) 78 | 79 | return det, labels, mapper 80 | 81 | def getPoly_core(boxes, labels, mapper, linkmap): 82 | # configs 83 | num_cp = 5 84 | max_len_ratio = 0.7 85 | expand_ratio = 1.45 86 | max_r = 2.0 87 | step_r = 0.2 88 | 89 | polys = [] 90 | for k, box in enumerate(boxes): 91 | # size filter for small instance 92 | w, h = int(np.linalg.norm(box[0] - box[1]) + 1), int(np.linalg.norm(box[1] - box[2]) + 1) 93 | if w < 30 or h < 30: 94 | polys.append(None); continue 95 | 96 | # warp image 97 | tar = np.float32([[0,0],[w,0],[w,h],[0,h]]) 98 | M = cv2.getPerspectiveTransform(box, tar) 99 | word_label = cv2.warpPerspective(labels, M, (w, h), flags=cv2.INTER_NEAREST) 100 | try: 101 | Minv = np.linalg.inv(M) 102 | except: 103 | polys.append(None); continue 104 | 105 | # binarization for selected label 106 | cur_label = mapper[k] 107 | word_label[word_label != cur_label] = 0 108 | 
word_label[word_label > 0] = 1 109 | 110 | """ Polygon generation """ 111 | # find top/bottom contours 112 | cp = [] 113 | max_len = -1 114 | for i in range(w): 115 | region = np.where(word_label[:,i] != 0)[0] 116 | if len(region) < 2 : continue 117 | cp.append((i, region[0], region[-1])) 118 | length = region[-1] - region[0] + 1 119 | if length > max_len: max_len = length 120 | 121 | # pass if max_len is similar to h 122 | if h * max_len_ratio < max_len: 123 | polys.append(None); continue 124 | 125 | # get pivot points with fixed length 126 | tot_seg = num_cp * 2 + 1 127 | seg_w = w / tot_seg # segment width 128 | pp = [None] * num_cp # init pivot points 129 | cp_section = [[0, 0]] * tot_seg 130 | seg_height = [0] * num_cp 131 | seg_num = 0 132 | num_sec = 0 133 | prev_h = -1 134 | for i in range(0,len(cp)): 135 | (x, sy, ey) = cp[i] 136 | if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg: 137 | # average previous segment 138 | if num_sec == 0: break 139 | cp_section[seg_num] = [cp_section[seg_num][0] / num_sec, cp_section[seg_num][1] / num_sec] 140 | num_sec = 0 141 | 142 | # reset variables 143 | seg_num += 1 144 | prev_h = -1 145 | 146 | # accumulate center points 147 | cy = (sy + ey) * 0.5 148 | cur_h = ey - sy + 1 149 | cp_section[seg_num] = [cp_section[seg_num][0] + x, cp_section[seg_num][1] + cy] 150 | num_sec += 1 151 | 152 | if seg_num % 2 == 0: continue # No polygon area 153 | 154 | if prev_h < cur_h: 155 | pp[int((seg_num - 1)/2)] = (x, cy) 156 | seg_height[int((seg_num - 1)/2)] = cur_h 157 | prev_h = cur_h 158 | 159 | # processing last segment 160 | if num_sec != 0: 161 | cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec] 162 | 163 | # pass if num of pivots is not sufficient or segment widh is smaller than character height 164 | if None in pp or seg_w < np.max(seg_height) * 0.25: 165 | polys.append(None); continue 166 | 167 | # calc median maximum of pivot points 168 | half_char_h = np.median(seg_height) * expand_ratio / 2 169 | 170 | # calc gradiant and apply to make horizontal pivots 171 | new_pp = [] 172 | for i, (x, cy) in enumerate(pp): 173 | dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0] 174 | dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1] 175 | if dx == 0: # gradient if zero 176 | new_pp.append([x, cy - half_char_h, x, cy + half_char_h]) 177 | continue 178 | rad = - math.atan2(dy, dx) 179 | c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad) 180 | new_pp.append([x - s, cy - c, x + s, cy + c]) 181 | 182 | # get edge points to cover character heatmaps 183 | isSppFound, isEppFound = False, False 184 | grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0]) 185 | grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0]) 186 | for r in np.arange(0.5, max_r, step_r): 187 | dx = 2 * half_char_h * r 188 | if not isSppFound: 189 | line_img = np.zeros(word_label.shape, dtype=np.uint8) 190 | dy = grad_s * dx 191 | p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy]) 192 | cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1) 193 | if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r: 194 | spp = p 195 | isSppFound = True 196 | if not isEppFound: 197 | line_img = np.zeros(word_label.shape, dtype=np.uint8) 198 | dy = grad_e * dx 199 | p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy]) 200 | cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1) 201 
| if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r: 202 | epp = p 203 | isEppFound = True 204 | if isSppFound and isEppFound: 205 | break 206 | 207 | # pass if boundary of polygon is not found 208 | if not (isSppFound and isEppFound): 209 | polys.append(None); continue 210 | 211 | # make final polygon 212 | poly = [] 213 | poly.append(warpCoord(Minv, (spp[0], spp[1]))) 214 | for p in new_pp: 215 | poly.append(warpCoord(Minv, (p[0], p[1]))) 216 | poly.append(warpCoord(Minv, (epp[0], epp[1]))) 217 | poly.append(warpCoord(Minv, (epp[2], epp[3]))) 218 | for p in reversed(new_pp): 219 | poly.append(warpCoord(Minv, (p[2], p[3]))) 220 | poly.append(warpCoord(Minv, (spp[2], spp[3]))) 221 | 222 | # add to final result 223 | polys.append(np.array(poly)) 224 | 225 | return polys 226 | 227 | def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False): 228 | boxes, labels, mapper = getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text) 229 | 230 | if poly: 231 | polys = getPoly_core(boxes, labels, mapper, linkmap) 232 | else: 233 | polys = [None] * len(boxes) 234 | 235 | return boxes, polys 236 | 237 | def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net = 2): 238 | if len(polys) > 0: 239 | polys = np.array(polys) 240 | for k in range(len(polys)): 241 | if polys[k] is not None: 242 | polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net) 243 | return polys 244 | -------------------------------------------------------------------------------- /data/cookbook.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajinkya933/CRAFT-pytorch/42b9ec502698147df110a1b90c5400ae4d12c5cf/data/cookbook.jpg -------------------------------------------------------------------------------- /data/textbook.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajinkya933/CRAFT-pytorch/42b9ec502698147df110a1b90c5400ae4d12c5cf/data/textbook.jpg -------------------------------------------------------------------------------- /figures/craft_example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajinkya933/CRAFT-pytorch/42b9ec502698147df110a1b90c5400ae4d12c5cf/figures/craft_example.gif -------------------------------------------------------------------------------- /file_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import numpy as np 4 | import cv2 5 | import imgproc 6 | 7 | # borrowed from https://github.com/lengstrom/fast-style-transfer/blob/master/src/utils.py 8 | def get_files(img_dir): 9 | imgs, masks, xmls = list_files(img_dir) 10 | return imgs, masks, xmls 11 | 12 | def list_files(in_path): 13 | img_files = [] 14 | mask_files = [] 15 | gt_files = [] 16 | for (dirpath, dirnames, filenames) in os.walk(in_path): 17 | for file in filenames: 18 | filename, ext = os.path.splitext(file) 19 | ext = str.lower(ext) 20 | if ext == '.jpg' or ext == '.jpeg' or ext == '.gif' or ext == '.png' or ext == '.pgm': 21 | img_files.append(os.path.join(dirpath, file)) 22 | elif ext == '.bmp': 23 | mask_files.append(os.path.join(dirpath, file)) 24 | elif ext == '.xml' or ext == '.gt' or ext == '.txt': 25 | gt_files.append(os.path.join(dirpath, file)) 26 | elif ext == '.zip': 27 | continue 28 | # img_files.sort() 29 | # mask_files.sort() 30 | # gt_files.sort() 31 | 
return img_files, mask_files, gt_files 32 | 33 | def saveResult(img_file, img, boxes, dirname='./result/', verticals=None, texts=None): 34 | """ save text detection result one by one 35 | Args: 36 | img_file (str): image file name 37 | img (array): raw image context 38 | boxes (array): array of result file 39 | Shape: [num_detections, 4] for BB output / [num_detections, 4] for QUAD output 40 | Return: 41 | None 42 | """ 43 | img = np.array(img) 44 | 45 | # make result file list 46 | filename, file_ext = os.path.splitext(os.path.basename(img_file)) 47 | 48 | # result directory 49 | res_file = dirname + "res_" + filename + '.txt' 50 | res_img_file = dirname + "res_" + filename + '.jpg' 51 | 52 | if not os.path.isdir(dirname): 53 | os.mkdir(dirname) 54 | 55 | with open(res_file, 'w') as f: 56 | for i, box in enumerate(boxes): 57 | poly = np.array(box).astype(np.int32).reshape((-1)) 58 | strResult = ','.join([str(p) for p in poly]) + '\r\n' 59 | f.write(strResult) 60 | 61 | poly = poly.reshape(-1, 2) 62 | cv2.polylines(img, [poly.reshape((-1, 1, 2))], True, color=(0, 0, 255), thickness=2) 63 | ptColor = (0, 255, 255) 64 | if verticals is not None: 65 | if verticals[i]: 66 | ptColor = (255, 0, 0) 67 | 68 | if texts is not None: 69 | font = cv2.FONT_HERSHEY_SIMPLEX 70 | font_scale = 0.5 71 | cv2.putText(img, "{}".format(texts[i]), (poly[0][0]+1, poly[0][1]+1), font, font_scale, (0, 0, 0), thickness=1) 72 | cv2.putText(img, "{}".format(texts[i]), tuple(poly[0]), font, font_scale, (0, 255, 255), thickness=1) 73 | 74 | # Save result image 75 | cv2.imwrite(res_img_file, img) 76 | 77 | -------------------------------------------------------------------------------- /imgproc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 
3 | MIT License 4 | """ 5 | 6 | # -*- coding: utf-8 -*- 7 | import numpy as np 8 | from skimage import io 9 | import cv2 10 | 11 | def loadImage(img_file): 12 | img = io.imread(img_file) # RGB order 13 | if img.shape[0] == 2: img = img[0] 14 | if len(img.shape) == 2 : img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) 15 | if img.shape[2] == 4: img = img[:,:,:3] 16 | img = np.array(img) 17 | 18 | return img 19 | 20 | def normalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)): 21 | # should be RGB order 22 | img = in_img.copy().astype(np.float32) 23 | 24 | img -= np.array([mean[0] * 255.0, mean[1] * 255.0, mean[2] * 255.0], dtype=np.float32) 25 | img /= np.array([variance[0] * 255.0, variance[1] * 255.0, variance[2] * 255.0], dtype=np.float32) 26 | return img 27 | 28 | def denormalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)): 29 | # should be RGB order 30 | img = in_img.copy() 31 | img *= variance 32 | img += mean 33 | img *= 255.0 34 | img = np.clip(img, 0, 255).astype(np.uint8) 35 | return img 36 | 37 | def resize_aspect_ratio(img, square_size, interpolation, mag_ratio=1): 38 | height, width, channel = img.shape 39 | 40 | # magnify image size 41 | target_size = mag_ratio * max(height, width) 42 | 43 | # set original image size 44 | if target_size > square_size: 45 | target_size = square_size 46 | 47 | ratio = target_size / max(height, width) 48 | 49 | target_h, target_w = int(height * ratio), int(width * ratio) 50 | proc = cv2.resize(img, (target_w, target_h), interpolation = interpolation) 51 | 52 | 53 | # make canvas and paste image 54 | target_h32, target_w32 = target_h, target_w 55 | if target_h % 32 != 0: 56 | target_h32 = target_h + (32 - target_h % 32) 57 | if target_w % 32 != 0: 58 | target_w32 = target_w + (32 - target_w % 32) 59 | resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32) 60 | resized[0:target_h, 0:target_w, :] = proc 61 | target_h, target_w = target_h32, target_w32 62 | 63 | size_heatmap = (int(target_w/2), int(target_h/2)) 64 | 65 | return resized, ratio, size_heatmap 66 | 67 | def cvt2HeatmapImg(img): 68 | img = (np.clip(img, 0, 1) * 255).astype(np.uint8) 69 | img = cv2.applyColorMap(img, cv2.COLORMAP_JET) 70 | return img 71 | -------------------------------------------------------------------------------- /onnx-export.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import cv2 4 | import imgproc 5 | from craft import CRAFT 6 | from collections import OrderedDict 7 | 8 | 9 | def copyStateDict(state_dict): 10 | if list(state_dict.keys())[0].startswith("module"): 11 | start_idx = 1 12 | else: 13 | start_idx = 0 14 | new_state_dict = OrderedDict() 15 | for k, v in state_dict.items(): 16 | name = ".".join(k.split(".")[start_idx:]) 17 | new_state_dict[name] = v 18 | return new_state_dict 19 | 20 | 21 | 22 | 23 | 24 | 25 | # load net 26 | net = CRAFT() # initialize 27 | 28 | 29 | net.load_state_dict(copyStateDict(torch.load('./weights/craft_mlt_25k.pth', map_location='cpu'))) 30 | #net = net.cuda() 31 | net.eval() 32 | # load data 33 | image = imgproc.loadImage('./data/cookbook.jpg') 34 | 35 | # Calculate dimensions that maintain aspect ratio 36 | original_height, original_width = 4096, 3072 37 | aspect_ratio = original_width / original_height 38 | 39 | # Target width of 1280 (common size for CRAFT) 40 | target_width = 1280 41 | target_height = 960 # Fixed height that's divisible by 32 42 | 
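# Note: the backbone needs input dimensions that are multiples of 32, and no dynamic
# axes are declared in the export below, so the ONNX graph is fixed at 1280x960; the
# region/affinity score maps it produces are half that resolution (640x480).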
43 | print(f"Export dimensions: {target_width}x{target_height}") 44 | 45 | # resize 46 | img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio( 47 | image, 48 | square_size=target_width, 49 | interpolation=cv2.INTER_LINEAR, 50 | mag_ratio=1.5 51 | ) 52 | img_resized = cv2.resize(img_resized, (target_width, target_height), interpolation=cv2.INTER_LINEAR) 53 | 54 | # preprocessing 55 | x = imgproc.normalizeMeanVariance(img_resized) 56 | x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w] 57 | x = Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w] 58 | # x = x.cuda() 59 | 60 | # Verify shape before export 61 | print(f"Export tensor shape: {x.shape}") # Should print [1, 3, 960, 1280] 62 | 63 | # trace export 64 | torch.onnx.export(net, 65 | x, 66 | 'onnx/craft-detect.onnx', 67 | input_names=['input'], 68 | output_names=['output'], 69 | export_params=True, 70 | opset_version=11, 71 | do_constant_folding=True, 72 | verbose=False) 73 | -------------------------------------------------------------------------------- /onnx-inference.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | start = datetime.now() 3 | 4 | from craft_text_detector import export_detected_regions 5 | import torch 6 | #torch.set_num_threads(3) 7 | import cv2 8 | import onnxruntime as rt 9 | 10 | import craft_utils 11 | import imgproc 12 | 13 | 14 | #refine_net = load_refinenet_model(cuda=False) 15 | 16 | #sess = rt.InferenceSession("model_640_480.onnx") #resize to 320x240 - bad results 17 | sess = rt.InferenceSession("onnx/craft-detect.onnx") #1000x960 resize to 512x384 18 | input_name = sess.get_inputs()[0].name 19 | 20 | output_dir = 'outputs/' 21 | 22 | img = cv2.imread('data/textbook.jpg') 23 | #img=imgproc.loadImage('frame1.jpg') 24 | print(img.shape) #(1944, 2592, 3) 25 | 26 | # Store original image dimensions 27 | original_height, original_width = img.shape[:2] 28 | print(f"Original dimensions: {original_width}x{original_height}") 29 | 30 | # Calculate target dimensions 31 | target_width = 1280 32 | target_height = 960 33 | 34 | # First resize using the original function 35 | img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio( 36 | img, 37 | square_size=target_width, 38 | interpolation=cv2.INTER_LINEAR, 39 | mag_ratio=1.5 40 | ) 41 | 42 | # Then resize to exact dimensions 43 | img_resized = cv2.resize(img_resized, (target_width, target_height), interpolation=cv2.INTER_LINEAR) 44 | 45 | # Calculate actual ratios based on final dimensions 46 | ratio_w = original_width / target_width 47 | ratio_h = original_height / target_height 48 | 49 | print(f"Resize ratios - width: {ratio_w}, height: {ratio_h}") 50 | 51 | x = imgproc.normalizeMeanVariance(img_resized) 52 | x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w] 53 | x = x.unsqueeze(0) # [c, h, w] to [b, c, h, w] 54 | 55 | # Optional: verify tensor shape before inference 56 | print(f"Input tensor shape: {x.shape}") # Should print [1, 3, 960, 1280] 57 | 58 | y, feature = sess.run(None, {input_name: x.numpy()}) 59 | 60 | # make score and link map 61 | score_text = y[0, :, :, 0] 62 | score_link = y[0, :, :, 1] 63 | 64 | # refine link 65 | #with torch.no_grad(): 66 | # y_refiner = refine_net(y, feature) 67 | #score_link = y_refiner[0,:,:,0].cpu().data.numpy() 68 | 69 | boxes, polys = craft_utils.getDetBoxes(score_text, score_link, 0.7, 0.4, 0.4, True) 70 | boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h) 71 | 72 | # Comment 
out the problematic polys adjustment 73 | #polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h) 74 | 75 | print(datetime.now()- start) 76 | 77 | export_detected_regions( 78 | image=img, 79 | regions=boxes, 80 | output_dir=output_dir, 81 | rectify=True) 82 | import file_utils 83 | file_utils.saveResult('outputs/', img[:,:,::-1], boxes, dirname=output_dir) 84 | -------------------------------------------------------------------------------- /onnx/placeholder.txt: -------------------------------------------------------------------------------- 1 | created to maintain folder structure on github, as github dosent allow empty folders to exist. 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-image>=0.19.0 2 | scipy==1.13.1 3 | onnx==1.17.0 4 | onnxruntime==1.19.2 5 | 6 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 3 | MIT License 4 | """ 5 | 6 | # -*- coding: utf-8 -*- 7 | import sys 8 | import os 9 | import time 10 | import argparse 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.backends.cudnn as cudnn 15 | from torch.autograd import Variable 16 | 17 | from PIL import Image 18 | 19 | import cv2 20 | from skimage import io 21 | import numpy as np 22 | import craft_utils 23 | import imgproc 24 | import file_utils 25 | import json 26 | import zipfile 27 | 28 | from craft import CRAFT 29 | 30 | from collections import OrderedDict 31 | def copyStateDict(state_dict): 32 | if list(state_dict.keys())[0].startswith("module"): 33 | start_idx = 1 34 | else: 35 | start_idx = 0 36 | new_state_dict = OrderedDict() 37 | for k, v in state_dict.items(): 38 | name = ".".join(k.split(".")[start_idx:]) 39 | new_state_dict[name] = v 40 | return new_state_dict 41 | 42 | def str2bool(v): 43 | return v.lower() in ("yes", "y", "true", "t", "1") 44 | 45 | parser = argparse.ArgumentParser(description='CRAFT Text Detection') 46 | parser.add_argument('--trained_model', default='weights/craft_mlt_25k.pth', type=str, help='pretrained model') 47 | parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold') 48 | parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score') 49 | parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold') 50 | parser.add_argument('--cuda', default=torch.cuda.is_available(), type=str2bool, help='Use cuda for inference') 51 | parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference') 52 | parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio') 53 | parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type') 54 | parser.add_argument('--show_time', default=False, action='store_true', help='show processing time') 55 | parser.add_argument('--test_folder', default='/data/', type=str, help='folder path to input images') 56 | parser.add_argument('--refine', default=False, action='store_true', help='enable link refiner') 57 | parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str, help='pretrained refiner model') 58 | 59 | args = parser.parse_args() 60 | 61 | 62 | """ For test images in a folder """ 63 | 
image_list, _, _ = file_utils.get_files(args.test_folder) 64 | 65 | result_folder = './result/' 66 | if not os.path.isdir(result_folder): 67 | os.mkdir(result_folder) 68 | 69 | def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None): 70 | t0 = time.time() 71 | 72 | # resize 73 | img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(image, args.canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=args.mag_ratio) 74 | ratio_h = ratio_w = 1 / target_ratio 75 | 76 | # preprocessing 77 | x = imgproc.normalizeMeanVariance(img_resized) 78 | x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w] 79 | x = Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w] 80 | if cuda: 81 | x = x.cuda() 82 | 83 | # forward pass 84 | with torch.no_grad(): 85 | y, feature = net(x) 86 | 87 | # make score and link map 88 | score_text = y[0,:,:,0].cpu().data.numpy() 89 | score_link = y[0,:,:,1].cpu().data.numpy() 90 | 91 | # refine link 92 | if refine_net is not None: 93 | with torch.no_grad(): 94 | y_refiner = refine_net(y, feature) 95 | score_link = y_refiner[0,:,:,0].cpu().data.numpy() 96 | 97 | t0 = time.time() - t0 98 | t1 = time.time() 99 | 100 | # Post-processing 101 | boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly) 102 | 103 | # coordinate adjustment 104 | boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h) 105 | polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h) 106 | for k in range(len(polys)): 107 | if polys[k] is None: polys[k] = boxes[k] 108 | 109 | t1 = time.time() - t1 110 | 111 | # render results (optional) 112 | render_img = score_text.copy() 113 | render_img = np.hstack((render_img, score_link)) 114 | ret_score_text = imgproc.cvt2HeatmapImg(render_img) 115 | 116 | if args.show_time : print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1)) 117 | 118 | return boxes, polys, ret_score_text 119 | 120 | 121 | 122 | if __name__ == '__main__': 123 | # load net 124 | net = CRAFT() # initialize 125 | 126 | print('Loading weights from checkpoint (' + args.trained_model + ')') 127 | if args.cuda: 128 | net.load_state_dict(copyStateDict(torch.load(args.trained_model))) 129 | else: 130 | net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu'))) 131 | 132 | if args.cuda: 133 | net = net.cuda() 134 | net = torch.nn.DataParallel(net) 135 | cudnn.benchmark = False 136 | 137 | net.eval() 138 | 139 | # LinkRefiner 140 | refine_net = None 141 | if args.refine: 142 | from refinenet import RefineNet 143 | refine_net = RefineNet() 144 | print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')') 145 | if args.cuda: 146 | refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model))) 147 | refine_net = refine_net.cuda() 148 | refine_net = torch.nn.DataParallel(refine_net) 149 | else: 150 | refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu'))) 151 | 152 | refine_net.eval() 153 | args.poly = True 154 | 155 | t = time.time() 156 | 157 | # load data 158 | for k, image_path in enumerate(image_list): 159 | print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r') 160 | image = imgproc.loadImage(image_path) 161 | 162 | bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly, refine_net) 163 | 164 | # save score text 165 | filename, file_ext = 
os.path.splitext(os.path.basename(image_path)) 166 | mask_file = result_folder + "/res_" + filename + '_mask.jpg' 167 | cv2.imwrite(mask_file, score_text) 168 | 169 | file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder) 170 | 171 | print("elapsed time : {}s".format(time.time() - t)) 172 | --------------------------------------------------------------------------------