├── .idea
│   ├── misc.xml
│   ├── modules.xml
│   ├── pytorch_ctpn.iml
│   └── workspace.xml
├── checkpoints
│   └── note
├── config.py
├── ctpn_model.py
├── ctpn_predict.py
├── ctpn_train.py
├── ctpn_utils.py
├── data
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   └── dataset.cpython-36.pyc
│   └── dataset.py
├── images
│   ├── android_det.jpg
│   ├── android_rec.jpg
│   └── onto_android.md
├── logs
│   ├── ANDROID_OCR.pdf
│   ├── loss.png
│   └── training_logs.pdf
├── readme.md
└── results
    ├── ANDROID_DETECTION_SKEW.GIF
    ├── ANDROID_RECO_DEMO.GIF
    ├── detection_res.png
    ├── r0.jpg
    ├── r1.jpg
    ├── r2.jpg
    └── r3.jpg
/checkpoints/note:
--------------------------------------------------------------------------------
1 | For a number of reasons, the pretrained weights are no longer available. Thanks for your attention.
2 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | #-*- coding:utf-8 -*-
2 | #'''
3 | # Created on 18-12-11 上午10:09
4 | #
5 | # @Author: Greg Gao(laygin)
6 | #'''
7 | import os
8 |
9 | # base_dir = 'path to dataset base dir'
10 | base_dir = './images'
11 | img_dir = os.path.join(base_dir, 'VOC2007_text_detection/JPEGImages')
12 | xml_dir = os.path.join(base_dir, 'VOC2007_text_detection/Annotations')
13 |
14 | train_txt_file = os.path.join(base_dir, r'VOC2007_text_detection/ImageSets/Main/train.txt')
15 | val_txt_file = os.path.join(base_dir, r'VOC2007_text_detection/ImageSets/Main/val.txt')
16 |
17 |
18 | anchor_scale = 16
19 | IOU_NEGATIVE = 0.3
20 | IOU_POSITIVE = 0.7
21 | IOU_SELECT = 0.7
22 |
23 | RPN_POSITIVE_NUM = 150
24 | RPN_TOTAL_NUM = 300
25 |
26 | # bgr mean values can be found here: https://github.com/fchollet/deep-learning-models/blob/master/imagenet_utils.py
27 | IMAGE_MEAN = [123.68, 116.779, 103.939]
28 |
29 |
30 | checkpoints_dir = './checkpoints'
31 | outputs = r'./logs'
32 |
--------------------------------------------------------------------------------
/ctpn_model.py:
--------------------------------------------------------------------------------
1 | #-*- coding:utf-8 -*-
2 | #'''
3 | # Created on 18-12-11 上午10:01
4 | #
5 | # @Author: Greg Gao(laygin)
6 | #'''
7 | import os
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torchvision.models as models
12 |
13 |
14 | class RPN_REGR_Loss(nn.Module):
15 | def __init__(self, device, sigma=9.0):
16 | super(RPN_REGR_Loss, self).__init__()
17 | self.sigma = sigma
18 | self.device = device
19 |
20 | def forward(self, input, target):
21 | '''
22 | smooth L1 loss
23 | :param input:y_preds
24 | :param target: y_true
25 | :return:
26 | '''
27 | try:
28 | cls = target[0, :, 0]
29 | regr = target[0, :, 1:3]
30 | regr_keep = (cls == 1).nonzero()[:, 0]
31 | regr_true = regr[regr_keep]
32 | regr_pred = input[0][regr_keep]
33 | diff = torch.abs(regr_true - regr_pred)
34 | less_one = (diff<1.0/self.sigma).float()
35 | loss = less_one * 0.5 * diff ** 2 * self.sigma + torch.abs(1- less_one) * (diff - 0.5/self.sigma)
36 | loss = torch.sum(loss, 1)
37 | loss = torch.mean(loss) if loss.numel() > 0 else torch.tensor(0.0)
38 | except Exception as e:
39 | print('RPN_REGR_Loss Exception:', e)
40 | # print(input, target)
41 | loss = torch.tensor(0.0)
42 |
43 | return loss.to(self.device)
44 |
45 |
46 | class RPN_CLS_Loss(nn.Module):
47 | def __init__(self,device):
48 | super(RPN_CLS_Loss, self).__init__()
49 | self.device = device
50 |
51 | def forward(self, input, target):
52 | y_true = target[0][0]
53 | cls_keep = (y_true != -1).nonzero()[:, 0]
54 | cls_true = y_true[cls_keep].long()
55 | cls_pred = input[0][cls_keep]
56 | loss = F.nll_loss(F.log_softmax(cls_pred, dim=-1), cls_true) # original is sparse_softmax_cross_entropy_with_logits
57 | # loss = nn.BCEWithLogitsLoss()(cls_pred[:,0], cls_true.float()) # 18-12-8
58 | loss = torch.clamp(torch.mean(loss), 0, 10) if loss.numel() > 0 else torch.tensor(0.0)
59 | return loss.to(self.device)
60 |
61 |
62 | class BasicConv(nn.Module):
63 | def __init__(self,
64 | in_planes,
65 | out_planes,
66 | kernel_size,
67 | stride=1,
68 | padding=0,
69 | dilation=1,
70 | groups=1,
71 | relu=True,
72 | bn=True,
73 | bias=True):
74 | super(BasicConv, self).__init__()
75 | self.out_channels = out_planes
76 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
77 | self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
78 | self.relu = nn.ReLU(inplace=True) if relu else None
79 |
80 | def forward(self, x):
81 | x = self.conv(x)
82 | if self.bn is not None:
83 | x = self.bn(x)
84 | if self.relu is not None:
85 | x = self.relu(x)
86 | return x
87 |
88 |
89 | class CTPN_Model(nn.Module):
90 | def __init__(self):
91 | super().__init__()
92 | base_model = models.vgg16(pretrained=False)
93 | layers = list(base_model.features)[:-1]
94 | self.base_layers = nn.Sequential(*layers) # block5_conv3 output
95 | self.rpn = BasicConv(512, 512, 3,1,1,bn=False)
96 | self.brnn = nn.GRU(512,128, bidirectional=True, batch_first=True)
97 | self.lstm_fc = BasicConv(256, 512,1,1,relu=True, bn=False)
98 | self.rpn_class = BasicConv(512, 10*2, 1, 1, relu=False,bn=False)
99 | self.rpn_regress = BasicConv(512, 10 * 2, 1, 1, relu=False, bn=False)
100 |
101 | def forward(self, x):
102 | x = self.base_layers(x)
103 | # rpn
104 | x = self.rpn(x)
105 |
106 | x1 = x.permute(0,2,3,1).contiguous() # channels last
107 | b = x1.size() # batch_size, h, w, c
108 | x1 = x1.view(b[0]*b[1], b[2], b[3])
109 |
110 | x2, _ = self.brnn(x1)
111 |
112 | xsz = x.size()
113 | x3 = x2.view(xsz[0], xsz[2], xsz[3], 256) # torch.Size([4, 20, 20, 256])
114 |
115 | x3 = x3.permute(0,3,1,2).contiguous() # channels first
116 | x3 = self.lstm_fc(x3)
117 | x = x3
118 |
119 | cls = self.rpn_class(x)
120 | regr = self.rpn_regress(x)
121 |
122 | cls = cls.permute(0,2,3,1).contiguous()
123 | regr = regr.permute(0,2,3,1).contiguous()
124 |
125 | cls = cls.view(cls.size(0), cls.size(1)*cls.size(2)*10, 2)
126 | regr = regr.view(regr.size(0), regr.size(1)*regr.size(2)*10, 2)
127 |
128 | return cls, regr
129 |
--------------------------------------------------------------------------------
/ctpn_predict.py:
--------------------------------------------------------------------------------
1 | #-*- coding:utf-8 -*-
2 | #'''
3 | # Created on 18-12-11 上午10:03
4 | #
5 | # @Author: Greg Gao(laygin)
6 | #'''
7 | import os
8 | os.environ['CUDA_VISIBLE_DEVICES'] = ''
9 | import cv2
10 | import numpy as np
11 |
12 | import torch
13 | import torch.nn.functional as F
14 | from ctpn_model import CTPN_Model
15 | from ctpn_utils import gen_anchor, bbox_transfor_inv, clip_box, filter_bbox,nms, TextProposalConnectorOriented
16 | from ctpn_utils import resize
17 | import config
18 |
19 |
20 | prob_thresh = 0.7
21 | width = 600
22 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
23 | weights = os.path.join(config.checkpoints_dir, 'trained weights file.pth.tar')
24 | img_path = 'path to test image'
25 |
26 | model = CTPN_Model()
27 | model.load_state_dict(torch.load(weights, map_location=device)['model_state_dict'])
28 | model.to(device)
29 | model.eval()
30 |
31 |
32 | def dis(image):
33 | cv2.imshow('image', image)
34 | cv2.waitKey(0)
35 | cv2.destroyAllWindows()
36 |
37 |
38 | image = cv2.imread(img_path)
39 | image = resize(image, width=width)
40 | image_c = image.copy()
41 | h, w = image.shape[:2]
42 | image = image.astype(np.float32) - config.IMAGE_MEAN
43 | image = torch.from_numpy(image.transpose(2, 0, 1)).unsqueeze(0).float()
44 |
45 |
46 | with torch.no_grad():
47 | image = image.to(device)
48 | cls, regr = model(image)
49 | cls_prob = F.softmax(cls, dim=-1).cpu().numpy()
50 | regr = regr.cpu().numpy()
51 | anchor = gen_anchor((int(h / 16), int(w / 16)), 16)
52 | bbox = bbox_transfor_inv(anchor, regr)
53 | bbox = clip_box(bbox, [h, w])
54 |
55 | fg = np.where(cls_prob[0, :, 1] > prob_thresh)[0]
56 | select_anchor = bbox[fg, :]
57 | select_score = cls_prob[0, fg, 1]
58 | select_anchor = select_anchor.astype(np.int32)
59 |
60 | keep_index = filter_bbox(select_anchor, 16)
61 |
62 |     # nms
63 | select_anchor = select_anchor[keep_index]
64 | select_score = select_score[keep_index]
65 | select_score = np.reshape(select_score, (select_score.shape[0], 1))
66 | nmsbox = np.hstack((select_anchor, select_score))
67 | keep = nms(nmsbox, 0.3)
68 | select_anchor = select_anchor[keep]
69 | select_score = select_score[keep]
70 |
71 |     # text lines
72 | textConn = TextProposalConnectorOriented()
73 | text = textConn.get_text_lines(select_anchor, select_score, [h, w])
74 | print(text)
75 |
76 | for i in text:
77 | s = str(round(i[-1] * 100, 2)) + '%'
78 | i = [int(j) for j in i]
79 | cv2.line(image_c, (i[0], i[1]), (i[2], i[3]), (0, 0, 255), 2)
80 | cv2.line(image_c, (i[0], i[1]), (i[4], i[5]), (0, 0, 255), 2)
81 | cv2.line(image_c, (i[6], i[7]), (i[2], i[3]), (0, 0, 255), 2)
82 | cv2.line(image_c, (i[4], i[5]), (i[6], i[7]), (0, 0, 255), 2)
83 | cv2.putText(image_c, s, (i[0]+13, i[1]+13),
84 | cv2.FONT_HERSHEY_SIMPLEX,
85 | 1,
86 | (255,0,0),
87 | 2,
88 | cv2.LINE_AA)
89 |
90 | dis(image_c)
91 |
--------------------------------------------------------------------------------
/ctpn_train.py:
--------------------------------------------------------------------------------
1 | #-*- coding:utf-8 -*-
2 | #'''
3 | # Created on 18-12-27 上午10:31
4 | #
5 | # @Author: Greg Gao(laygin)
6 | #'''
7 | import os
8 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
9 | import torch
10 | from torch.utils.data import DataLoader
11 | from torch import optim
12 | import numpy as np
13 | import argparse
14 |
15 | import config
16 | from ctpn_model import CTPN_Model, RPN_CLS_Loss, RPN_REGR_Loss
17 | from data import VOCDataset
18 |
19 |
20 | random_seed = 2019
21 | torch.random.manual_seed(random_seed)
22 | np.random.seed(random_seed)
23 |
24 | num_workers = 8
25 | epochs = 20
26 | lr = 1e-3
27 | resume_epoch = 0
28 | pre_weights = os.path.join(config.checkpoints_dir, 'ctpn_keras_weights.pth.tar')
29 |
30 |
31 | def get_arguments():
32 |     parser = argparse.ArgumentParser(description='Pytorch CTPN For Text Detection')
33 | parser.add_argument('--num-workers', type=int, default=num_workers)
34 | parser.add_argument('--image-dir', type=str, default=config.img_dir)
35 | parser.add_argument('--labels-dir', type=str, default=config.xml_dir)
36 | parser.add_argument('--pretrained-weights', type=str,default=pre_weights)
37 | return parser.parse_args()
38 |
39 |
40 | def save_checkpoint(state, epoch, loss_cls, loss_regr, loss, ext='pth.tar'):
41 | check_path = os.path.join(config.checkpoints_dir,
42 | f'ctpn_ep{epoch:02d}_'
43 | f'{loss_cls:.4f}_{loss_regr:.4f}_{loss:.4f}.{ext}')
44 |
45 | torch.save(state, check_path)
46 | print('saving to {}'.format(check_path))
47 |
48 |
49 | args = vars(get_arguments())
50 |
51 | if __name__ == '__main__':
52 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
53 | checkpoints_weight = args['pretrained_weights']
54 | if os.path.exists(checkpoints_weight):
55 | pretrained = False
56 |
57 | dataset = VOCDataset(args['image_dir'], args['labels_dir'])
58 | dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=args['num_workers'])
59 | model = CTPN_Model()
60 | model.to(device)
61 |
62 | if os.path.exists(checkpoints_weight):
63 | print('using pretrained weight: {}'.format(checkpoints_weight))
64 | cc = torch.load(checkpoints_weight, map_location=device)
65 | model.load_state_dict(cc['model_state_dict'])
66 | resume_epoch = cc['epoch']
67 |
68 |     params_to_update = model.parameters()
69 |     optimizer = optim.SGD(params_to_update, lr=lr, momentum=0.9)
70 |
71 |     criterion_cls = RPN_CLS_Loss(device)
72 |     criterion_regr = RPN_REGR_Loss(device)
73 |
74 | best_loss_cls = 100
75 | best_loss_regr = 100
76 | best_loss = 100
77 | best_model = None
78 | epochs += resume_epoch
79 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
80 |
81 | for epoch in range(resume_epoch+1, epochs):
82 | print(f'Epoch {epoch}/{epochs}')
83 | print('#'*50)
84 | epoch_size = len(dataset) // 1
85 | model.train()
86 | epoch_loss_cls = 0
87 | epoch_loss_regr = 0
88 | epoch_loss = 0
89 | scheduler.step(epoch)
90 |
91 | for batch_i, (imgs, clss, regrs) in enumerate(dataloader):
92 | imgs = imgs.to(device)
93 | clss = clss.to(device)
94 | regrs = regrs.to(device)
95 |
96 | optimizer.zero_grad()
97 |
98 | out_cls, out_regr = model(imgs)
99 |             loss_cls = criterion_cls(out_cls, clss)
100 |             loss_regr = criterion_regr(out_regr, regrs)
101 |
102 | loss = loss_cls + loss_regr # total loss
103 | loss.backward()
104 | optimizer.step()
105 |
106 | epoch_loss_cls += loss_cls.item()
107 | epoch_loss_regr += loss_regr.item()
108 | epoch_loss += loss.item()
109 | mmp = batch_i+1
110 |
111 | print(f'Ep:{epoch}/{epochs-1}--'
112 | f'Batch:{batch_i}/{epoch_size}\n'
113 | f'batch: loss_cls:{loss_cls.item():.4f}--loss_regr:{loss_regr.item():.4f}--loss:{loss.item():.4f}\n'
114 | f'Epoch: loss_cls:{epoch_loss_cls/mmp:.4f}--loss_regr:{epoch_loss_regr/mmp:.4f}--'
115 | f'loss:{epoch_loss/mmp:.4f}\n')
116 |
117 | epoch_loss_cls /= epoch_size
118 | epoch_loss_regr /= epoch_size
119 | epoch_loss /= epoch_size
120 | print(f'Epoch:{epoch}--{epoch_loss_cls:.4f}--{epoch_loss_regr:.4f}--{epoch_loss:.4f}')
121 | if best_loss_cls > epoch_loss_cls or best_loss_regr > epoch_loss_regr or best_loss > epoch_loss:
122 | best_loss = epoch_loss
123 | best_loss_regr = epoch_loss_regr
124 | best_loss_cls = epoch_loss_cls
125 | best_model = model
126 | save_checkpoint({'model_state_dict': best_model.state_dict(),
127 | 'epoch': epoch},
128 | epoch,
129 | best_loss_cls,
130 | best_loss_regr,
131 | best_loss)
132 |
133 | if torch.cuda.is_available():
134 | torch.cuda.empty_cache()
135 |
136 |
--------------------------------------------------------------------------------
/ctpn_utils.py:
--------------------------------------------------------------------------------
1 | #-*- coding:utf-8 -*-
2 | #'''
3 | # Created on 18-12-11 上午10:05
4 | #
5 | # @Author: Greg Gao(laygin)
6 | #'''
7 | import numpy as np
8 | import cv2
9 | from config import *
10 |
11 |
12 | def resize(image, width=None, height=None, inter=cv2.INTER_AREA):
13 | # initialize the dimensions of the image to be resized and
14 | # grab the image size
15 | dim = None
16 | (h, w) = image.shape[:2]
17 |
18 | # if both the width and height are None, then return the
19 | # original image
20 | if width is None and height is None:
21 | return image
22 |
23 | # check to see if the width is None
24 | if width is None:
25 | # calculate the ratio of the height and construct the
26 | # dimensions
27 | r = height / float(h)
28 | dim = (int(w * r), height)
29 |
30 | # otherwise, the height is None
31 | else:
32 | # calculate the ratio of the width and construct the
33 | # dimensions
34 | r = width / float(w)
35 | dim = (width, int(h * r))
36 |
37 | # resize the image
38 | resized = cv2.resize(image, dim, interpolation=inter)
39 |
40 | # return the resized image
41 | return resized
42 |
43 |
44 | def gen_anchor(featuresize, scale):
45 | """
46 | gen base anchor from feature map [HXW][9][4]
47 | reshape [HXW][9][4] to [HXWX9][4]
48 | """
49 | heights = [11, 16, 23, 33, 48, 68, 97, 139, 198, 283]
50 | widths = [16, 16, 16, 16, 16, 16, 16, 16, 16, 16]
51 |
52 | # gen k=9 anchor size (h,w)
53 | heights = np.array(heights).reshape(len(heights), 1)
54 | widths = np.array(widths).reshape(len(widths), 1)
55 |
56 | base_anchor = np.array([0, 0, 15, 15])
57 | # center x,y
58 | xt = (base_anchor[0] + base_anchor[2]) * 0.5
59 | yt = (base_anchor[1] + base_anchor[3]) * 0.5
60 |
61 | # x1 y1 x2 y2
62 | x1 = xt - widths * 0.5
63 | y1 = yt - heights * 0.5
64 | x2 = xt + widths * 0.5
65 | y2 = yt + heights * 0.5
66 | base_anchor = np.hstack((x1, y1, x2, y2))
67 |
68 | h, w = featuresize
69 | shift_x = np.arange(0, w) * scale
70 | shift_y = np.arange(0, h) * scale
71 | # apply shift
72 | anchor = []
73 | for i in shift_y:
74 | for j in shift_x:
75 | anchor.append(base_anchor + [j, i, j, i])
76 | return np.array(anchor).reshape((-1, 4))
77 |
78 |
79 | def cal_iou(box1, box1_area, boxes2, boxes2_area):
80 | """
81 | box1 [x1,y1,x2,y2]
82 | boxes2 [Msample,x1,y1,x2,y2]
83 | """
84 | x1 = np.maximum(box1[0], boxes2[:, 0])
85 | x2 = np.minimum(box1[2], boxes2[:, 2])
86 | y1 = np.maximum(box1[1], boxes2[:, 1])
87 | y2 = np.minimum(box1[3], boxes2[:, 3])
88 |
89 | intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0)
90 | iou = intersection / (box1_area + boxes2_area[:] - intersection[:])
91 | return iou
92 |
93 |
94 | def cal_overlaps(boxes1, boxes2):
95 | """
96 | boxes1 [Nsample,x1,y1,x2,y2] anchor
97 |     boxes2 [Msample,x1,y1,x2,y2] ground-truth box
98 |
99 | """
100 | area1 = (boxes1[:, 0] - boxes1[:, 2]) * (boxes1[:, 1] - boxes1[:, 3])
101 | area2 = (boxes2[:, 0] - boxes2[:, 2]) * (boxes2[:, 1] - boxes2[:, 3])
102 |
103 | overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0]))
104 |
105 | # calculate the intersection of boxes1(anchor) and boxes2(GT box)
106 | for i in range(boxes1.shape[0]):
107 | overlaps[i][:] = cal_iou(boxes1[i], area1[i], boxes2, area2)
108 |
109 | return overlaps
110 |
111 |
112 | def bbox_transfrom(anchors, gtboxes):
113 | """
114 | compute relative predicted vertical coordinates Vc ,Vh
115 | with respect to the bounding box location of an anchor
116 | """
117 | regr = np.zeros((anchors.shape[0], 2))
118 | Cy = (gtboxes[:, 1] + gtboxes[:, 3]) * 0.5
119 | Cya = (anchors[:, 1] + anchors[:, 3]) * 0.5
120 | h = gtboxes[:, 3] - gtboxes[:, 1] + 1.0
121 | ha = anchors[:, 3] - anchors[:, 1] + 1.0
122 |
123 | Vc = (Cy - Cya) / ha
124 | Vh = np.log(h / ha)
125 |
126 | return np.vstack((Vc, Vh)).transpose()
127 |
128 |
129 | def bbox_transfor_inv(anchor, regr):
130 | """
131 | return predict bbox
132 | """
133 |
134 | Cya = (anchor[:, 1] + anchor[:, 3]) * 0.5
135 | ha = anchor[:, 3] - anchor[:, 1] + 1
136 |
137 | Vcx = regr[0, :, 0]
138 | Vhx = regr[0, :, 1]
139 |
140 | Cyx = Vcx * ha + Cya
141 | hx = np.exp(Vhx) * ha
142 | xt = (anchor[:, 0] + anchor[:, 2]) * 0.5
143 |
144 | x1 = xt - 16 * 0.5
145 | y1 = Cyx - hx * 0.5
146 | x2 = xt + 16 * 0.5
147 | y2 = Cyx + hx * 0.5
148 | bbox = np.vstack((x1, y1, x2, y2)).transpose()
149 |
150 | return bbox
151 |
152 |
153 | def clip_box(bbox, im_shape):
154 | # x1 >= 0
155 | bbox[:, 0] = np.maximum(np.minimum(bbox[:, 0], im_shape[1] - 1), 0)
156 | # y1 >= 0
157 | bbox[:, 1] = np.maximum(np.minimum(bbox[:, 1], im_shape[0] - 1), 0)
158 | # x2 < im_shape[1]
159 | bbox[:, 2] = np.maximum(np.minimum(bbox[:, 2], im_shape[1] - 1), 0)
160 | # y2 < im_shape[0]
161 | bbox[:, 3] = np.maximum(np.minimum(bbox[:, 3], im_shape[0] - 1), 0)
162 |
163 | return bbox
164 |
165 |
166 | def filter_bbox(bbox, minsize):
167 | ws = bbox[:, 2] - bbox[:, 0] + 1
168 | hs = bbox[:, 3] - bbox[:, 1] + 1
169 | keep = np.where((ws >= minsize) & (hs >= minsize))[0]
170 | return keep
171 |
172 |
173 | def cal_rpn(imgsize, featuresize, scale, gtboxes):
174 | imgh, imgw = imgsize
175 |
176 | # gen base anchor
177 | base_anchor = gen_anchor(featuresize, scale)
178 |
179 | # calculate iou
180 | overlaps = cal_overlaps(base_anchor, gtboxes)
181 |
182 | # init labels -1 don't care 0 is negative 1 is positive
183 | labels = np.empty(base_anchor.shape[0])
184 | labels.fill(-1)
185 |
186 | # for each GT box corresponds to an anchor which has highest IOU
187 | gt_argmax_overlaps = overlaps.argmax(axis=0)
188 |
189 | # the anchor with the highest IOU overlap with a GT box
190 | anchor_argmax_overlaps = overlaps.argmax(axis=1)
191 | anchor_max_overlaps = overlaps[range(overlaps.shape[0]), anchor_argmax_overlaps]
192 |
193 | # IOU > IOU_POSITIVE
194 | labels[anchor_max_overlaps > IOU_POSITIVE] = 1
195 |     # IOU < IOU_NEGATIVE
196 |     labels[anchor_max_overlaps < IOU_NEGATIVE] = 0
197 |     # ensure that every GT box matches at least one positive anchor
198 |     labels[gt_argmax_overlaps] = 1
199 |
200 |     # only keep anchors that lie inside the image
201 |     outside_anchor = np.where(
202 |         (base_anchor[:, 0] < 0) |
203 |         (base_anchor[:, 1] < 0) |
204 |         (base_anchor[:, 2] >= imgw) |
205 | (base_anchor[:, 3] >= imgh)
206 | )[0]
207 | labels[outside_anchor] = -1
208 |
209 |     # subsample positive labels if there are more than RPN_POSITIVE_NUM
210 | fg_index = np.where(labels == 1)[0]
211 | if (len(fg_index) > RPN_POSITIVE_NUM):
212 | labels[np.random.choice(fg_index, len(fg_index) - RPN_POSITIVE_NUM, replace=False)] = -1
213 |
214 | # subsample negative labels
215 | bg_index = np.where(labels == 0)[0]
216 | num_bg = RPN_TOTAL_NUM - np.sum(labels == 1)
217 | if (len(bg_index) > num_bg):
218 | # print('bgindex:',len(bg_index),'num_bg',num_bg)
219 | labels[np.random.choice(bg_index, len(bg_index) - num_bg, replace=False)] = -1
220 |
221 | # calculate bbox targets
222 | # debug here
223 | bbox_targets = bbox_transfrom(base_anchor, gtboxes[anchor_argmax_overlaps, :])
224 | # bbox_targets=[]
225 |
226 | return [labels, bbox_targets], base_anchor
227 |
228 |
229 | def nms(dets, thresh):
230 | x1 = dets[:, 0]
231 | y1 = dets[:, 1]
232 | x2 = dets[:, 2]
233 | y2 = dets[:, 3]
234 | scores = dets[:, 4]
235 |
236 | areas = (x2 - x1 + 1) * (y2 - y1 + 1)
237 | order = scores.argsort()[::-1]
238 |
239 | keep = []
240 | while order.size > 0:
241 | i = order[0]
242 | keep.append(i)
243 | xx1 = np.maximum(x1[i], x1[order[1:]])
244 | yy1 = np.maximum(y1[i], y1[order[1:]])
245 | xx2 = np.minimum(x2[i], x2[order[1:]])
246 | yy2 = np.minimum(y2[i], y2[order[1:]])
247 |
248 | w = np.maximum(0.0, xx2 - xx1 + 1)
249 | h = np.maximum(0.0, yy2 - yy1 + 1)
250 | inter = w * h
251 | ovr = inter / (areas[i] + areas[order[1:]] - inter)
252 |
253 | inds = np.where(ovr <= thresh)[0]
254 | order = order[inds + 1]
255 | return keep
256 |
257 |
258 | # for predict
259 | class Graph:
260 | def __init__(self, graph):
261 | self.graph = graph
262 |
263 | def sub_graphs_connected(self):
264 | sub_graphs = []
265 | for index in range(self.graph.shape[0]):
266 | if not self.graph[:, index].any() and self.graph[index, :].any():
267 | v = index
268 | sub_graphs.append([v])
269 | while self.graph[v, :].any():
270 | v = np.where(self.graph[v, :])[0][0]
271 | sub_graphs[-1].append(v)
272 | return sub_graphs
273 |
274 |
275 | class TextLineCfg:
276 | SCALE = 600
277 | MAX_SCALE = 1200
278 | TEXT_PROPOSALS_WIDTH = 16
279 | MIN_NUM_PROPOSALS = 2
280 | MIN_RATIO = 0.5
281 | LINE_MIN_SCORE = 0.9
282 | MAX_HORIZONTAL_GAP = 60
283 | TEXT_PROPOSALS_MIN_SCORE = 0.7
284 | TEXT_PROPOSALS_NMS_THRESH = 0.3
285 | MIN_V_OVERLAPS = 0.6
286 | MIN_SIZE_SIM = 0.6
287 |
288 |
289 | class TextProposalGraphBuilder:
290 | """
291 | Build Text proposals into a graph.
292 | """
293 |
294 | def get_successions(self, index):
295 | box = self.text_proposals[index]
296 | results = []
297 | for left in range(int(box[0]) + 1, min(int(box[0]) + TextLineCfg.MAX_HORIZONTAL_GAP + 1, self.im_size[1])):
298 | adj_box_indices = self.boxes_table[left]
299 | for adj_box_index in adj_box_indices:
300 | if self.meet_v_iou(adj_box_index, index):
301 | results.append(adj_box_index)
302 | if len(results) != 0:
303 | return results
304 | return results
305 |
306 | def get_precursors(self, index):
307 | box = self.text_proposals[index]
308 | results = []
309 | for left in range(int(box[0]) - 1, max(int(box[0] - TextLineCfg.MAX_HORIZONTAL_GAP), 0) - 1, -1):
310 | adj_box_indices = self.boxes_table[left]
311 | for adj_box_index in adj_box_indices:
312 | if self.meet_v_iou(adj_box_index, index):
313 | results.append(adj_box_index)
314 | if len(results) != 0:
315 | return results
316 | return results
317 |
318 | def is_succession_node(self, index, succession_index):
319 | precursors = self.get_precursors(succession_index)
320 | if self.scores[index] >= np.max(self.scores[precursors]):
321 | return True
322 | return False
323 |
324 | def meet_v_iou(self, index1, index2):
325 | def overlaps_v(index1, index2):
326 | h1 = self.heights[index1]
327 | h2 = self.heights[index2]
328 | y0 = max(self.text_proposals[index2][1], self.text_proposals[index1][1])
329 | y1 = min(self.text_proposals[index2][3], self.text_proposals[index1][3])
330 | return max(0, y1 - y0 + 1) / min(h1, h2)
331 |
332 | def size_similarity(index1, index2):
333 | h1 = self.heights[index1]
334 | h2 = self.heights[index2]
335 | return min(h1, h2) / max(h1, h2)
336 |
337 | return overlaps_v(index1, index2) >= TextLineCfg.MIN_V_OVERLAPS and \
338 | size_similarity(index1, index2) >= TextLineCfg.MIN_SIZE_SIM
339 |
340 | def build_graph(self, text_proposals, scores, im_size):
341 | self.text_proposals = text_proposals
342 | self.scores = scores
343 | self.im_size = im_size
344 | self.heights = text_proposals[:, 3] - text_proposals[:, 1] + 1
345 |
346 | boxes_table = [[] for _ in range(self.im_size[1])]
347 | for index, box in enumerate(text_proposals):
348 | boxes_table[int(box[0])].append(index)
349 | self.boxes_table = boxes_table
350 |
351 | graph = np.zeros((text_proposals.shape[0], text_proposals.shape[0]), np.bool)
352 |
353 | for index, box in enumerate(text_proposals):
354 | successions = self.get_successions(index)
355 | if len(successions) == 0:
356 | continue
357 | succession_index = successions[np.argmax(scores[successions])]
358 | if self.is_succession_node(index, succession_index):
359 | # NOTE: a box can have multiple successions(precursors) if multiple successions(precursors)
360 | # have equal scores.
361 | graph[index, succession_index] = True
362 | return Graph(graph)
363 |
364 |
365 | class TextProposalConnectorOriented:
366 | """
367 | Connect text proposals into text lines
368 | """
369 |
370 | def __init__(self):
371 | self.graph_builder = TextProposalGraphBuilder()
372 |
373 | def group_text_proposals(self, text_proposals, scores, im_size):
374 | graph = self.graph_builder.build_graph(text_proposals, scores, im_size)
375 | return graph.sub_graphs_connected()
376 |
377 | def fit_y(self, X, Y, x1, x2):
378 | # len(X) != 0
379 | # if X only include one point, the function will get line y=Y[0]
380 | if np.sum(X == X[0]) == len(X):
381 | return Y[0], Y[0]
382 | p = np.poly1d(np.polyfit(X, Y, 1))
383 | return p(x1), p(x2)
384 |
385 | def get_text_lines(self, text_proposals, scores, im_size):
386 | """
387 | text_proposals:boxes
388 |
389 | """
390 | # tp=text proposal
391 |         tp_groups = self.group_text_proposals(text_proposals, scores, im_size)  # build the graph first to find which proposals make up each text line
392 |
393 | text_lines = np.zeros((len(tp_groups), 8), np.float32)
394 |
395 | for index, tp_indices in enumerate(tp_groups):
396 |             text_line_boxes = text_proposals[list(tp_indices)]  # all the small boxes belonging to this text line
397 |             X = (text_line_boxes[:, 0] + text_line_boxes[:, 2]) / 2  # center x, y of each small box
398 | Y = (text_line_boxes[:, 1] + text_line_boxes[:, 3]) / 2
399 |
400 |             z1 = np.polyfit(X, Y, 1)  # least-squares fit of a straight line through the box centers
401 |
402 |             x0 = np.min(text_line_boxes[:, 0])  # minimum x of the text line
403 |             x1 = np.max(text_line_boxes[:, 2])  # maximum x of the text line
404 |
405 |             offset = (text_line_boxes[0, 2] - text_line_boxes[0, 0]) * 0.5  # half the width of a small box
406 |
407 |             # fit a line through the top-left corners of all boxes, then get the y values at the leftmost/rightmost x of the text line
408 | lt_y, rt_y = self.fit_y(text_line_boxes[:, 0], text_line_boxes[:, 1], x0 + offset, x1 - offset)
409 |             # fit a line through the bottom-left corners of all boxes, then get the y values at the leftmost/rightmost x of the text line
410 | lb_y, rb_y = self.fit_y(text_line_boxes[:, 0], text_line_boxes[:, 3], x0 + offset, x1 - offset)
411 |
412 |             score = scores[list(tp_indices)].sum() / float(len(tp_indices))  # mean score of all small boxes, used as the text line score
413 |
414 | text_lines[index, 0] = x0
415 |             text_lines[index, 1] = min(lt_y, rt_y)  # smaller y of the text line's top edge
416 | text_lines[index, 2] = x1
417 |             text_lines[index, 3] = max(lb_y, rb_y)  # larger y of the text line's bottom edge
418 |             text_lines[index, 4] = score  # text line score
419 |             text_lines[index, 5] = z1[0]  # k and b of the line fitted through the box centers
420 | text_lines[index, 6] = z1[1]
421 |             height = np.mean((text_line_boxes[:, 3] - text_line_boxes[:, 1]))  # average height of the small boxes
422 | text_lines[index, 7] = height + 2.5
423 |
424 | text_recs = np.zeros((len(text_lines), 9), np.float)
425 | index = 0
426 | for line in text_lines:
427 |             b1 = line[6] - line[7] / 2  # from the height and the fitted center line, get the intercepts b of the top and bottom lines
428 | b2 = line[6] + line[7] / 2
429 | x1 = line[0]
430 |             y1 = line[5] * line[0] + b1  # top-left
431 | x2 = line[2]
432 |             y2 = line[5] * line[2] + b1  # top-right
433 | x3 = line[0]
434 |             y3 = line[5] * line[0] + b2  # bottom-left
435 | x4 = line[2]
436 |             y4 = line[5] * line[2] + b2  # bottom-right
437 | disX = x2 - x1
438 | disY = y2 - y1
439 |             width = np.sqrt(disX * disX + disY * disY)  # text line width
440 |
441 |             fTmp0 = y3 - y1  # text line height
442 | fTmp1 = fTmp0 * disY / width
443 |             x = np.fabs(fTmp1 * disX / width)  # compensation along x
444 | y = np.fabs(fTmp1 * disY / width)
445 | if line[5] < 0:
446 | x1 -= x
447 | y1 += y
448 | x4 += x
449 | y4 -= y
450 | else:
451 | x2 += x
452 | y2 += y
453 | x3 -= x
454 | y3 -= y
455 | text_recs[index, 0] = x1
456 | text_recs[index, 1] = y1
457 | text_recs[index, 2] = x2
458 | text_recs[index, 3] = y2
459 | text_recs[index, 4] = x3
460 | text_recs[index, 5] = y3
461 | text_recs[index, 6] = x4
462 | text_recs[index, 7] = y4
463 | text_recs[index, 8] = line[4]
464 | index = index + 1
465 |
466 | return text_recs
467 |
468 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | #-*- coding:utf-8 -*-
2 | #'''
3 | # Created on 18-12-27 上午10:33
4 | #
5 | # @Author: Greg Gao(laygin)
6 | #'''
7 | from .dataset import VOCDataset
--------------------------------------------------------------------------------
/data/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opconty/pytorch_ctpn/89ea784d7776a08962c6d5bc5591730e15156d07/data/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opconty/pytorch_ctpn/89ea784d7776a08962c6d5bc5591730e15156d07/data/__pycache__/dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | #-*- coding:utf-8 -*-
2 | #'''
3 | # Created on 18-12-27 上午10:34
4 | #
5 | # @Author: Greg Gao(laygin)
6 | #'''
7 |
8 | import os
9 | import xml.etree.ElementTree as ET
10 | import numpy as np
11 | import cv2
12 | from torch.utils.data import Dataset
13 | import torch
14 | from config import IMAGE_MEAN
15 | from ctpn_utils import cal_rpn
16 |
17 |
18 | def readxml(path):
19 | gtboxes = []
20 | imgfile = ''
21 | xml = ET.parse(path)
22 | for elem in xml.iter():
23 | if 'filename' in elem.tag:
24 | imgfile = elem.text
25 | if 'object' in elem.tag:
26 | for attr in list(elem):
27 | if 'bndbox' in attr.tag:
28 | xmin = int(round(float(attr.find('xmin').text)))
29 | ymin = int(round(float(attr.find('ymin').text)))
30 | xmax = int(round(float(attr.find('xmax').text)))
31 | ymax = int(round(float(attr.find('ymax').text)))
32 |
33 | gtboxes.append((xmin, ymin, xmax, ymax))
34 |
35 | return np.array(gtboxes), imgfile
36 |
37 |
38 | # for ctpn text detection
39 | class VOCDataset(Dataset):
40 | def __init__(self,
41 | datadir,
42 | labelsdir):
43 | '''
44 |
45 |         Every image found in datadir is used (no txt split file is read).
46 |         :param datadir: images directory
47 |         :param labelsdir: annotations directory
48 | '''
49 | if not os.path.isdir(datadir):
50 | raise Exception('[ERROR] {} is not a directory'.format(datadir))
51 | if not os.path.isdir(labelsdir):
52 | raise Exception('[ERROR] {} is not a directory'.format(labelsdir))
53 |
54 | self.datadir = datadir
55 | self.img_names = os.listdir(self.datadir)
56 | self.labelsdir = labelsdir
57 |
58 | def __len__(self):
59 | return len(self.img_names)
60 |
61 | def __getitem__(self, idx):
62 | img_name = self.img_names[idx]
63 | img_path = os.path.join(self.datadir, img_name)
64 | print(img_path)
65 | xml_path = os.path.join(self.labelsdir, img_name.replace('.jpg', '.xml'))
66 | gtbox, _ = readxml(xml_path)
67 | img = cv2.imread(img_path)
68 | h, w, c = img.shape
69 |         # random horizontal flip
70 | if np.random.randint(2) == 1:
71 | img = img[:, ::-1, :]
72 | newx1 = w - gtbox[:, 2] - 1
73 | newx2 = w - gtbox[:, 0] - 1
74 | gtbox[:, 0] = newx1
75 | gtbox[:, 2] = newx2
76 |
77 | [cls, regr], _ = cal_rpn((h, w), (int(h / 16), int(w / 16)), 16, gtbox)
78 |
79 | m_img = img - IMAGE_MEAN
80 |
81 | regr = np.hstack([cls.reshape(cls.shape[0], 1), regr])
82 |
83 | cls = np.expand_dims(cls, axis=0)
84 |
85 | # transform to torch tensor
86 | m_img = torch.from_numpy(m_img.transpose([2, 0, 1])).float()
87 | cls = torch.from_numpy(cls).float()
88 | regr = torch.from_numpy(regr).float()
89 |
90 | return m_img, cls, regr
91 |
92 |
--------------------------------------------------------------------------------
/images/android_det.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opconty/pytorch_ctpn/89ea784d7776a08962c6d5bc5591730e15156d07/images/android_det.jpg
--------------------------------------------------------------------------------
/images/android_rec.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opconty/pytorch_ctpn/89ea784d7776a08962c6d5bc5591730e15156d07/images/android_rec.jpg
--------------------------------------------------------------------------------
/images/onto_android.md:
--------------------------------------------------------------------------------
1 | ## Android Chinese OCR
2 | > 2019-03-20 wed
3 |
4 | ## DEMO VERSION
5 | ### DETECTION
6 | 
7 |
8 | ### RECOGNITION
9 | 
10 |
11 | ### END-TO-END
12 | COMING SOOOOOOOOOOOOOOOOOOOON
13 |
--------------------------------------------------------------------------------
/logs/ANDROID_OCR.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opconty/pytorch_ctpn/89ea784d7776a08962c6d5bc5591730e15156d07/logs/ANDROID_OCR.pdf
--------------------------------------------------------------------------------
/logs/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opconty/pytorch_ctpn/89ea784d7776a08962c6d5bc5591730e15156d07/logs/loss.png
--------------------------------------------------------------------------------
/logs/training_logs.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opconty/pytorch_ctpn/89ea784d7776a08962c6d5bc5591730e15156d07/logs/training_logs.pdf
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | ## Pytorch CTPN
2 | > update 19-03-20 wed: android ocr
3 |
4 |
5 | This is a pytorch implementation of [CTPN (Detecting Text in Natural Image with Connectionist Text Proposal Network)](https://arxiv.org/pdf/1609.03605.pdf), inspired by [keras-ocr](https://github.com/xiaomaxiao/keras_ocr).
6 |
7 | The training log is available: [Training Log](./logs/training_logs.pdf) (Chinese).
8 |
9 | |model|size|
10 | |:--:|:--:|
11 | |keras-CTPN|142M|
12 | |**pytorch-CTPN**|**67.6M**|
13 |
14 | ### train
15 | - ~~download ctpn model weights (converted from keras ctpn weights) `ctpn_keras_weights.pth.tar` from [dropbox](https://www.dropbox.com/s/81zfc50x6g6fauz/ctpn_keras_weights.pth.tar?dl=0), and move it to **./checkpoints/**~~ (*For a number of reasons, the pretrained weights are no longer available. Thanks for your attention.*)
16 | - ~~download [VOC2007_text_detection Chinese Text Detection dataset](http://not_available_any_more_due_to_lack_of_space) and move it to **./images/**~~
17 | - run `python ctpn_train.py --image-dir image_dir --labels-dir labels_dir --num-workers num_workers`
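
Judging from the paths in `config.py`, the dataset under `./images/` is expected to follow a VOC-style layout roughly like the sketch below. Note that the current `VOCDataset` does not read the train/val txt files; it simply lists every image in `--image-dir`.

```
images/
└── VOC2007_text_detection/
    ├── JPEGImages/           # *.jpg images  (img_dir in config.py)
    ├── Annotations/          # VOC-style *.xml boxes  (xml_dir in config.py)
    └── ImageSets/Main/
        ├── train.txt
        └── val.txt
```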
18 |
19 | ### predict
20 | - ~~download the pretrained weights from [dropbox](https://www.dropbox.com/s/r1zjw167a5lsk4l/ctpn_ep18_0.0074_0.0121_0.0195%28w-lstm%29.pth.tar?dl=0)~~
21 | - Please refer to [predict.py](./ctpn_predict.py) for more details.
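
For quick reference, here is a minimal inference sketch distilled from [predict.py](./ctpn_predict.py). The checkpoint and image paths are placeholders, and it assumes a checkpoint saved by `ctpn_train.py` (a dict with a `model_state_dict` key):

```python
# Minimal CTPN inference sketch (paths below are placeholders).
import cv2
import numpy as np
import torch
import torch.nn.functional as F

import config
from ctpn_model import CTPN_Model
from ctpn_utils import (TextProposalConnectorOriented, bbox_transfor_inv,
                        clip_box, filter_bbox, gen_anchor, nms, resize)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = CTPN_Model().to(device)
ckpt = torch.load('./checkpoints/your_checkpoint.pth.tar', map_location=device)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()

# preprocess: resize to width 600 and subtract the mean, as in ctpn_predict.py
image = resize(cv2.imread('path/to/test_image.jpg'), width=600)
h, w = image.shape[:2]
inp = image.astype(np.float32) - config.IMAGE_MEAN
inp = torch.from_numpy(inp.transpose(2, 0, 1)).unsqueeze(0).to(device)

with torch.no_grad():
    cls, regr = model(inp)
    prob = F.softmax(cls, dim=-1).cpu().numpy()
    anchors = gen_anchor((int(h / 16), int(w / 16)), 16)
    bbox = clip_box(bbox_transfor_inv(anchors, regr.cpu().numpy()), [h, w])

# keep confident proposals, drop tiny ones, then NMS
fg = np.where(prob[0, :, 1] > 0.7)[0]
boxes, scores = bbox[fg].astype(np.int32), prob[0, fg, 1]
keep = filter_bbox(boxes, 16)
boxes, scores = boxes[keep], scores[keep].reshape(-1, 1)
keep = nms(np.hstack((boxes, scores)), 0.3)

# connect the surviving proposals into text lines
lines = TextProposalConnectorOriented().get_text_lines(boxes[keep], scores[keep], [h, w])
print(lines)  # each row: four corner points (x1,y1 ... x4,y4) plus the line score
```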
22 |
23 | ### results
24 | [Training Log](./logs/training_logs.pdf)(Chinese)
25 |
26 | ### Android DEMO
27 | These days, I'm working on deploying this model on Android devices. You can check the results [here](./logs/ANDROID_OCR.pdf).
28 |
29 | **Android text recognition 4-23**
30 | > Found that adopting a [skew transform](./results/ANDROID_DETECTION_SKEW.GIF) can significantly improve recognition accuracy. (It may take a few seconds, depending heavily on your hardware and input image size.)
31 |
32 | 
33 |
34 | ### reference
35 | - [CTPN (Detecting Text in Natural Image with Connectionist Text Proposal Network)](https://arxiv.org/pdf/1609.03605.pdf)
36 | - [keras-ocr](https://github.com/xiaomaxiao/keras_ocr)
37 |
38 | ### Licence
39 | [MIT License](https://opensource.org/licenses/MIT)
40 |
--------------------------------------------------------------------------------
/results/ANDROID_DETECTION_SKEW.GIF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opconty/pytorch_ctpn/89ea784d7776a08962c6d5bc5591730e15156d07/results/ANDROID_DETECTION_SKEW.GIF
--------------------------------------------------------------------------------
/results/ANDROID_RECO_DEMO.GIF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opconty/pytorch_ctpn/89ea784d7776a08962c6d5bc5591730e15156d07/results/ANDROID_RECO_DEMO.GIF
--------------------------------------------------------------------------------
/results/detection_res.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opconty/pytorch_ctpn/89ea784d7776a08962c6d5bc5591730e15156d07/results/detection_res.png
--------------------------------------------------------------------------------
/results/r0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opconty/pytorch_ctpn/89ea784d7776a08962c6d5bc5591730e15156d07/results/r0.jpg
--------------------------------------------------------------------------------
/results/r1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opconty/pytorch_ctpn/89ea784d7776a08962c6d5bc5591730e15156d07/results/r1.jpg
--------------------------------------------------------------------------------
/results/r2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opconty/pytorch_ctpn/89ea784d7776a08962c6d5bc5591730e15156d07/results/r2.jpg
--------------------------------------------------------------------------------
/results/r3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opconty/pytorch_ctpn/89ea784d7776a08962c6d5bc5591730e15156d07/results/r3.jpg
--------------------------------------------------------------------------------