├── LICENSE
├── README.md
├── ctpn
│   ├── config.py
│   ├── ctpn.py
│   ├── dataset.py
│   └── utils.py
├── log
│   ├── 1.jpg
│   ├── 1_src.jpg
│   ├── 2.jpg
│   ├── 2_src.png
│   └── training_loss.png
├── predict.py
├── train.py
└── weights
    └── ctpn.pth

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Hans Hu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ctpn.pytorch
PyTorch implementation of CTPN (Detecting Text in Natural Image with Connectionist Text Proposal Network).

# Paper
https://arxiv.org/pdf/1609.03605.pdf

# train
Training datasets: ICDAR2013 and ICDAR2017.
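The VOC-style loader pairs each image with an annotation file of the same basename, so the default paths in *ctpn/config.py* imply a layout like the sketch below (the directory names are only the defaults and can be changed):
```
imagedata/
├── image/    # training images, e.g. img_001.jpg
└── xml/      # one VOC-style xml per image, e.g. img_001.xml
```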
To train on your own dataset, set `img_dir` and `label_dir` in *ctpn/config.py*, then run:
```
python train.py
```

![training loss](https://github.com/CrazySummerday/ctpn.pytorch/raw/master/log/training_loss.png)


# predict
Download the pretrained model from *./weights/*, change the test image path in *predict.py*, then run:
```
python predict.py
```
## result
![result_1](https://github.com/CrazySummerday/ctpn.pytorch/raw/master/log/1.jpg)
![result_2](https://github.com/CrazySummerday/ctpn.pytorch/raw/master/log/2.jpg)

# references
https://github.com/opconty/pytorch_ctpn

https://github.com/courao/ocr.pytorch
--------------------------------------------------------------------------------
/ctpn/config.py:
--------------------------------------------------------------------------------
#-*- coding:utf-8 -*-
import os

img_dir = '../imagedata/image/'
label_dir = '../imagedata/xml/'
num_workers = 4
pretrained_weights = ''

anchor_scale = 16
IOU_NEGATIVE = 0.3
IOU_POSITIVE = 0.7
IOU_SELECT = 0.7

RPN_POSITIVE_NUM = 150
RPN_TOTAL_NUM = 300

IMAGE_MEAN = [123.68, 116.779, 103.939]

# online hard example mining
OHEM = True
checkpoints_dir = './checkpoints'
--------------------------------------------------------------------------------
/ctpn/ctpn.py:
--------------------------------------------------------------------------------
#-*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from ctpn import config

'''
Regression loss: smooth L1 loss, computed over positive samples only, with
smoothing factor sigma:
    L = 0.5 * sigma * x**2    if |x| < 1/sigma
    L = |x| - 0.5/sigma       otherwise
1. Select the positive samples from the predicted boxes p and ground-truth boxes g.
2. |x| = |g - p|
3. Compute the loss with the smoothing threshold 1/sigma:
   (1) |x| > 1/sigma: loss = |x| - 0.5/sigma
   (2) |x| < 1/sigma: loss = 0.5 * sigma * |x|**2
'''
class RPN_REGR_Loss(nn.Module):
    def __init__(self, device, sigma=9.0):
        super(RPN_REGR_Loss, self).__init__()
        self.sigma = sigma
        self.device = device

    def forward(self, input, target):
        try:
            cls = target[0, :, 0]
            regression = target[0, :, 1:3]
            regr_keep = (cls == 1).nonzero()[:, 0]
            regr_true = regression[regr_keep]
            regr_pred = input[0][regr_keep]
            diff = torch.abs(regr_true - regr_pred)
            less_one = (diff < 1.0 / self.sigma).float()
            loss = less_one * 0.5 * diff ** 2 * self.sigma + torch.abs(1 - less_one) * (diff - 0.5 / self.sigma)
            loss = torch.sum(loss, 1)
            loss = torch.mean(loss) if loss.numel() > 0 else torch.tensor(0.0)
        except Exception as e:
            print('RPN_REGR_Loss Exception:', e)
            loss = torch.tensor(0.0)

        return loss.to(self.device)

'''
Classification loss: softmax loss.
1. OHEM mode:
   (1) select the positive samples and compute their softmax loss;
   (2) with N_neg negatives available and a total budget N, keep the top-K
       negative losses, where K = min(N_neg, N - num_pos);
   (3) loss = (positive loss + top-K negative loss) / N
2. Otherwise: NLL loss, clamped to the interval (0, 10).
'''
class RPN_CLS_Loss(nn.Module):
    def __init__(self, device):
        super(RPN_CLS_Loss, self).__init__()
        self.device = device
        self.L_cls = nn.CrossEntropyLoss(reduction='none')

    def forward(self, input, target):
        if config.OHEM:
            cls_gt = target[0][0]
            num_pos = 0
            loss_pos_sum = 0

            if len((cls_gt == 1).nonzero()) != 0:
                cls_pos = (cls_gt == 1).nonzero()[:, 0]
                gt_pos = cls_gt[cls_pos].long()
                cls_pred_pos = input[0][cls_pos]
                loss_pos = self.L_cls(cls_pred_pos.view(-1, 2), gt_pos.view(-1))
                loss_pos_sum = loss_pos.sum()
                num_pos = len(loss_pos)

            cls_neg = (cls_gt == 0).nonzero()[:, 0]
            gt_neg = cls_gt[cls_neg].long()
            cls_pred_neg = input[0][cls_neg]

            loss_neg = self.L_cls(cls_pred_neg.view(-1, 2), gt_neg.view(-1))
            loss_neg_topK, _ = torch.topk(loss_neg, min(len(loss_neg), config.RPN_TOTAL_NUM - num_pos))
            loss_cls = loss_pos_sum + loss_neg_topK.sum()
            loss_cls = loss_cls / config.RPN_TOTAL_NUM

            return loss_cls.to(self.device)
        else:
            y_true = target[0][0]
            cls_keep = (y_true != -1).nonzero()[:, 0]
            cls_true = y_true[cls_keep].long()
            cls_pred = input[0][cls_keep]
            loss = F.nll_loss(F.log_softmax(cls_pred, dim=-1), cls_true)
            loss = torch.clamp(torch.mean(loss), 0, 10) if loss.numel() > 0 else torch.tensor(0.0)

            return loss.to(self.device)


class basic_conv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=True):
        super(basic_conv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU(inplace=True) if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


'''
image -> feature map -> rpn -> blstm -> fc -> classifier
                                           -> regression
'''
class CTPN_Model(nn.Module):
    def __init__(self):
        super().__init__()
        base_model = models.vgg16(pretrained=False)
        layers = list(base_model.features)[:-1]  # drop the final max-pool: total stride 16
        self.base_layers = nn.Sequential(*layers)
        self.rpn = basic_conv(512, 512, 3, 1, 1, bn=False)
        self.brnn = nn.GRU(512, 128, bidirectional=True, batch_first=True)
        self.lstm_fc = basic_conv(256, 512, 1, 1, relu=True, bn=False)
        self.rpn_class = basic_conv(512, 10 * 2, 1, 1, relu=False, bn=False)
        self.rpn_regress = basic_conv(512, 10 * 2, 1, 1, relu=False, bn=False)

    def forward(self, x):
        x = self.base_layers(x)
        # rpn
        x = self.rpn(x)    # [b, c, h, w]

        x1 = x.permute(0, 2, 3, 1).contiguous()  # channels last [b, h, w, c]
        b = x1.size()  # b, h, w, c
        x1 = x1.view(b[0] * b[1], b[2], b[3])

        x2, _ = self.brnn(x1)

        xsz = x.size()
        x3 = x2.view(xsz[0], xsz[2], xsz[3], 256)  # [b, h, w, 256]

        x3 = x3.permute(0, 3, 1, 2).contiguous()  # channels first [b, c, h, w]
        x3 = self.lstm_fc(x3)
        x = x3

        cls = self.rpn_class(x)
        regression = self.rpn_regress(x)

        cls = cls.permute(0, 2, 3, 1).contiguous()
        regression = regression.permute(0, 2, 3, 1).contiguous()

        cls = cls.view(cls.size(0), cls.size(1) * cls.size(2) * 10, 2)
        regression = regression.view(regression.size(0), regression.size(1) * regression.size(2) * 10, 2)

        return cls, regression

if __name__ == '__main__':
    CTPN_Model()
--------------------------------------------------------------------------------
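An editor's sketch (not part of the repository) to sanity-check the shapes described above: with the last VGG16 max-pool removed the backbone downsamples by 16, and each feature-map cell predicts 10 anchors, so both heads return `(b, h/16 * w/16 * 10, 2)`.
```python
import torch
from ctpn.ctpn import CTPN_Model

model = CTPN_Model().eval()
with torch.no_grad():
    cls, regr = model(torch.zeros(1, 3, 320, 480))
# 320/16 * 480/16 * 10 = 20 * 30 * 10 = 6000 anchors
print(cls.shape, regr.shape)  # torch.Size([1, 6000, 2]) for both heads
```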
/ctpn/dataset.py:
--------------------------------------------------------------------------------
#-*- coding:utf-8 -*-
import os
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset
import xml.etree.ElementTree as ET
from ctpn.utils import cal_rpn

IMAGE_MEAN = [123.68, 116.779, 103.939]

'''
Read the ground-truth boxes of an image from a VOC xml file.
'''
def readxml(path):
    gtboxes = []
    xml = ET.parse(path)
    for elem in xml.iter():
        if 'object' in elem.tag:
            for attr in list(elem):
                if 'bndbox' in attr.tag:
                    xmin = int(round(float(attr.find('xmin').text)))
                    ymin = int(round(float(attr.find('ymin').text)))
                    xmax = int(round(float(attr.find('xmax').text)))
                    ymax = int(round(float(attr.find('ymax').text)))
                    gtboxes.append((xmin, ymin, xmax, ymax))

    return np.array(gtboxes)


'''
Load VOC-format data; returns the training image, the anchor targets and the labels.
'''
class VOCDataset(Dataset):
    def __init__(self, datadir, labelsdir):
        if not os.path.isdir(datadir):
            raise Exception('[ERROR] {} is not a directory'.format(datadir))
        if not os.path.isdir(labelsdir):
            raise Exception('[ERROR] {} is not a directory'.format(labelsdir))

        self.datadir = datadir
        self.img_names = os.listdir(self.datadir)
        self.labelsdir = labelsdir

    def __len__(self):
        return len(self.img_names)

    def generate_gtboxes(self, xml_path, rescale_fac=1.0):
        base_gtboxes = readxml(xml_path)
        gtboxes = []
        for base_gtbox in base_gtboxes:
            xmin, ymin, xmax, ymax = base_gtbox
            if rescale_fac > 1.0:
                xmin = int(xmin / rescale_fac)
                xmax = int(xmax / rescale_fac)
                ymin = int(ymin / rescale_fac)
                ymax = int(ymax / rescale_fac)
            # split the box into 16-px-wide strips aligned to the anchor grid
            prev = xmin
            for i in range(xmin // 16 + 1, xmax // 16 + 1):
                next = 16 * i - 0.5
                gtboxes.append((prev, ymin, next, ymax))
                prev = next
            gtboxes.append((prev, ymin, xmax, ymax))
        return np.array(gtboxes)

    def __getitem__(self, idx):
        img_name = self.img_names[idx]
        img_path = os.path.join(self.datadir, img_name)
        img = cv2.imread(img_path)
        h, w, c = img.shape
        rescale_fac = max(h, w) / 1000
        if rescale_fac > 1.0:
            h = int(h / rescale_fac)
            w = int(w / rescale_fac)
            img = cv2.resize(img, (w, h))

        xml_path = os.path.join(self.labelsdir, img_name.split('.')[0] + '.xml')
        gtbox = self.generate_gtboxes(xml_path, rescale_fac)

        # random horizontal flip
        if np.random.randint(2) == 1:
            img = img[:, ::-1, :]
            newx1 = w - gtbox[:, 2] - 1
            newx2 = w - gtbox[:, 0] - 1
            gtbox[:, 0] = newx1
            gtbox[:, 2] = newx2

        [cls, regr] = cal_rpn((h, w), (int(h / 16), int(w / 16)), 16, gtbox)
        regr = np.hstack([cls.reshape(cls.shape[0], 1), regr])
        cls = np.expand_dims(cls, axis=0)

        m_img = img - IMAGE_MEAN
        m_img = torch.from_numpy(m_img.transpose([2, 0, 1])).float()
        cls = torch.from_numpy(cls).float()
        regr = torch.from_numpy(regr).float()

        return m_img, cls, regr


################################################################################


class ICDARDataset(Dataset):
    def __init__(self, datadir, labelsdir):
        if not os.path.isdir(datadir):
            raise Exception('[ERROR] {} is not a directory'.format(datadir))
        if not os.path.isdir(labelsdir):
            raise Exception('[ERROR] {} is not a directory'.format(labelsdir))

        self.datadir = datadir
        self.img_names = os.listdir(self.datadir)
        self.labelsdir = labelsdir

    def __len__(self):
        return len(self.img_names)

    def box_transfer(self, coor_lists, rescale_fac=1.0):
        # keep each quadrilateral as a single axis-aligned box
        gtboxes = []
        for coor_list in coor_lists:
            coors_x = [int(coor_list[2 * i]) for i in range(4)]
            coors_y = [int(coor_list[2 * i + 1]) for i in range(4)]
            xmin = min(coors_x)
            xmax = max(coors_x)
            ymin = min(coors_y)
            ymax = max(coors_y)
            if rescale_fac > 1.0:
                xmin = int(xmin / rescale_fac)
                xmax = int(xmax / rescale_fac)
                ymin = int(ymin / rescale_fac)
                ymax = int(ymax / rescale_fac)
            gtboxes.append((xmin, ymin, xmax, ymax))
        return np.array(gtboxes)

    def box_transfer_v2(self, coor_lists, rescale_fac=1.0):
        # like box_transfer, but split each box into 16-px-wide strips
        gtboxes = []
        for coor_list in coor_lists:
            coors_x = [int(coor_list[2 * i]) for i in range(4)]
            coors_y = [int(coor_list[2 * i + 1]) for i in range(4)]
            xmin = min(coors_x)
            xmax = max(coors_x)
            ymin = min(coors_y)
            ymax = max(coors_y)
            if rescale_fac > 1.0:
                xmin = int(xmin / rescale_fac)
                xmax = int(xmax / rescale_fac)
                ymin = int(ymin / rescale_fac)
                ymax = int(ymax / rescale_fac)
            prev = xmin
            for i in range(xmin // 16 + 1, xmax // 16 + 1):
                next = 16 * i - 0.5
                gtboxes.append((prev, ymin, next, ymax))
                prev = next
            gtboxes.append((prev, ymin, xmax, ymax))
        return np.array(gtboxes)

    def parse_gtfile(self, gt_path, rescale_fac=1.0):
        coor_lists = list()
        with open(gt_path, 'r', encoding="utf-8-sig") as f:
            content = f.readlines()
            for line in content:
                coor_list = line.split(',')[:8]
                if len(coor_list) == 8:
                    coor_lists.append(coor_list)
        return self.box_transfer_v2(coor_lists, rescale_fac)

    def draw_boxes(self, img, cls, base_anchors, gt_box):
        # debug helper: draw positive anchors and ground-truth boxes
        for i in range(len(cls)):
            if cls[i] == 1:
                pt1 = (int(base_anchors[i][0]), int(base_anchors[i][1]))
                pt2 = (int(base_anchors[i][2]), int(base_anchors[i][3]))
                img = cv2.rectangle(img, pt1, pt2, (200, 100, 100))
        for i in range(gt_box.shape[0]):
            pt1 = (int(gt_box[i][0]), int(gt_box[i][1]))
            pt2 = (int(gt_box[i][2]), int(gt_box[i][3]))
            img = cv2.rectangle(img, pt1, pt2, (100, 200, 100))
        return img

    def __getitem__(self, idx):
        img_name = self.img_names[idx]
        img_path = os.path.join(self.datadir, img_name)
        img = cv2.imread(img_path)

        h, w, c = img.shape
        rescale_fac = max(h, w) / 1000
        if rescale_fac > 1.0:
            h = int(h / rescale_fac)
            w = int(w / rescale_fac)
            img = cv2.resize(img, (w, h))

        gt_path = os.path.join(self.labelsdir, img_name.split('.')[0] + '.txt')
        gtbox = self.parse_gtfile(gt_path, rescale_fac)

        # random horizontal flip
        if np.random.randint(2) == 1:
            img = img[:, ::-1, :]
            newx1 = w - gtbox[:, 2] - 1
            newx2 = w - gtbox[:, 0] - 1
            gtbox[:, 0] = newx1
            gtbox[:, 2] = newx2

        [cls, regr] = cal_rpn((h, w), (int(h / 16), int(w / 16)), 16, gtbox)
        regr = np.hstack([cls.reshape(cls.shape[0], 1), regr])
        cls = np.expand_dims(cls, axis=0)

        m_img = img - IMAGE_MEAN
        m_img = torch.from_numpy(m_img.transpose([2, 0, 1])).float()
        cls = torch.from_numpy(cls).float()
        regr = torch.from_numpy(regr).float()

        return m_img, cls, regr
--------------------------------------------------------------------------------
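An editor's sketch (not from the repository) of the strip splitting performed by `generate_gtboxes` and `box_transfer_v2` above: a ground-truth box is cut into 16-pixel-wide proposals aligned to the anchor grid.
```python
# Illustrative gt box; the loop mirrors the splitting code above.
xmin, ymin, xmax, ymax = 25, 10, 150, 40
prev = xmin
strips = []
for i in range(xmin // 16 + 1, xmax // 16 + 1):
    nxt = 16 * i - 0.5
    strips.append((prev, ymin, nxt, ymax))
    prev = nxt
strips.append((prev, ymin, xmax, ymax))
# strips[0] == (25, 10, 31.5, 40), strips[-1] == (143.5, 10, 150, 40)
```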
/ctpn/utils.py:
--------------------------------------------------------------------------------
#-*- coding:utf-8 -*-
import numpy as np
import cv2
from ctpn.config import *


'''
Anchor generation.
A pitfall worth recording: base_anchor holds the 10 anchors of the first
position; after shifting it across the feature map, anchors has shape
[10, h*w, 4]. Reshaping that directly to [10*h*w, 4] orders the anchors by
anchor shape (10 kinds) rather than by feature-map position, and training did
not converge that way. Since CTPN must link small anchors into full text
lines, anchors that are not ordered position by position appear to disturb
training badly.
Fix: reorder so that each feature-map point contributes its 10 anchors
consecutively, i.e. [10, h*w, 4] -> [h*w, 10, 4] -> [10*h*w, 4].
'''
def gen_anchor(featuresize, scale,
               heights=[11, 16, 23, 33, 48, 68, 97, 139, 198, 283],
               widths=[16, 16, 16, 16, 16, 16, 16, 16, 16, 16]):
    h, w = featuresize
    shift_x = np.arange(0, w) * scale
    shift_y = np.arange(0, h) * scale
    shift_x, shift_y = np.meshgrid(shift_x, shift_y)
    shift = np.stack((shift_x.ravel(), shift_y.ravel(), shift_x.ravel(), shift_y.ravel()), axis=1)

    # base anchor: center (x, y) -> corners (x1, y1, x2, y2)
    base_anchor = np.array([0, 0, 15, 15])
    xt = (base_anchor[0] + base_anchor[2]) * 0.5
    yt = (base_anchor[1] + base_anchor[3]) * 0.5
    heights = np.array(heights).reshape(len(heights), 1)
    widths = np.array(widths).reshape(len(widths), 1)
    x1 = xt - widths * 0.5
    y1 = yt - heights * 0.5
    x2 = xt + widths * 0.5
    y2 = yt + heights * 0.5
    base_anchor = np.hstack((x1, y1, x2, y2))

    anchor = list()
    for i in range(base_anchor.shape[0]):
        anchor_x1 = shift[:, 0] + base_anchor[i][0]
        anchor_y1 = shift[:, 1] + base_anchor[i][1]
        anchor_x2 = shift[:, 2] + base_anchor[i][2]
        anchor_y2 = shift[:, 3] + base_anchor[i][3]
        anchor.append(np.dstack((anchor_x1, anchor_y1, anchor_x2, anchor_y2)))

    return np.squeeze(np.array(anchor)).transpose((1, 0, 2)).reshape((-1, 4))

'''
IoU between anchors and gt boxes:
iou = inter_area / (bb_area + anchor_area - inter_area)
'''
def compute_iou(anchors, bbox):
    ious = np.zeros((len(anchors), len(bbox)), dtype=np.float32)
    anchor_area = (anchors[:, 2] - anchors[:, 0]) * (anchors[:, 3] - anchors[:, 1])
    for num, _bbox in enumerate(bbox):
        bb = np.tile(_bbox, (len(anchors), 1))
        bb_area = (bb[:, 2] - bb[:, 0]) * (bb[:, 3] - bb[:, 1])
        inter_h = np.maximum(np.minimum(bb[:, 3], anchors[:, 3]) - np.maximum(bb[:, 1], anchors[:, 1]), 0)
        inter_w = np.maximum(np.minimum(bb[:, 2], anchors[:, 2]) - np.maximum(bb[:, 0], anchors[:, 0]), 0)
        inter_area = inter_h * inter_w
        ious[:, num] = inter_area / (bb_area + anchor_area - inter_area)

    return ious

'''
Vertical regression targets (Vc, Vh) between anchors and gtboxes:
1. (x1, y1, x2, y2) -> (ctr_y, h)
2. Vc = (gt_y - anchor_y) / anchor_h
   Vh = log(gt_h / anchor_h)
'''
def bbox_transform(anchors, gtboxes):
    gt_y = (gtboxes[:, 1] + gtboxes[:, 3]) * 0.5
    gt_h = gtboxes[:, 3] - gtboxes[:, 1] + 1.0

    anchor_y = (anchors[:, 1] + anchors[:, 3]) * 0.5
    anchor_h = anchors[:, 3] - anchors[:, 1] + 1.0

    Vc = (gt_y - anchor_y) / anchor_h
    Vh = np.log(gt_h / anchor_h)

    return np.vstack((Vc, Vh)).transpose()
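# Editor's note -- a worked example of the (Vc, Vh) encoding above, with
# illustrative numbers: for an anchor (0, 10, 15, 40) and a gt box (0, 14, 15, 50):
#   anchor_y = (10 + 40) * 0.5 = 25,  anchor_h = 40 - 10 + 1 = 31
#   gt_y     = (14 + 50) * 0.5 = 32,  gt_h     = 50 - 14 + 1 = 37
#   Vc = (32 - 25) / 31 ~= 0.226,  Vh = log(37 / 31) ~= 0.177
# transform_bbox() below inverts this mapping at inference time.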
'''
Given anchors and the regression factors (Vc, Vh), recover the predicted bboxes.
'''
def transform_bbox(anchor, regression_factor):
    anchor_y = (anchor[:, 1] + anchor[:, 3]) * 0.5
    anchor_x = (anchor[:, 0] + anchor[:, 2]) * 0.5
    anchor_h = anchor[:, 3] - anchor[:, 1] + 1

    Vc = regression_factor[0, :, 0]
    Vh = regression_factor[0, :, 1]

    bbox_y = Vc * anchor_h + anchor_y
    bbox_h = np.exp(Vh) * anchor_h

    x1 = anchor_x - 16 * 0.5
    y1 = bbox_y - bbox_h * 0.5
    x2 = anchor_x + 16 * 0.5
    y2 = bbox_y + bbox_h * 0.5
    bbox = np.vstack((x1, y1, x2, y2)).transpose()

    return bbox

'''
Clip bboxes to the image boundary:
    x1 >= 0, y1 >= 0, x2 < im_shape[1], y2 < im_shape[0]
'''
def clip_bbox(bbox, im_shape):
    bbox[:, 0] = np.maximum(np.minimum(bbox[:, 0], im_shape[1] - 1), 0)
    bbox[:, 1] = np.maximum(np.minimum(bbox[:, 1], im_shape[0] - 1), 0)
    bbox[:, 2] = np.maximum(np.minimum(bbox[:, 2], im_shape[1] - 1), 0)
    bbox[:, 3] = np.maximum(np.minimum(bbox[:, 3], im_shape[0] - 1), 0)

    return bbox

'''
Size filter: drop bboxes smaller than the given minimum size.
'''
def filter_bbox(bbox, minsize):
    ws = bbox[:, 2] - bbox[:, 0] + 1
    hs = bbox[:, 3] - bbox[:, 1] + 1
    keep = np.where((ws >= minsize) & (hs >= minsize))[0]
    return keep

'''
RPN target assignment:
1. Generate the anchors.
2. Compute the IoU between anchors and ground-truth boxes.
3. Label each anchor from the IoU (1 positive, 0 negative, -1 ignored):
   (1) for each gt box, the anchor with the highest IoU is positive;
   (2) for each anchor, record its highest IoU over all gt boxes (max_overlap);
   (3) anchors with max_overlap above IOU_POSITIVE are positive, those below
       IOU_NEGATIVE are negative.
4. Ignore anchors that cross the image boundary (label -1).
5. Subsample so the numbers of positives/negatives stay within the budgets.
6. Compute the (Vc, Vh) targets against each anchor's best-matching gt box.
'''
def cal_rpn(imgsize, featuresize, scale, gtboxes):
    base_anchor = gen_anchor(featuresize, scale)
    overlaps = compute_iou(base_anchor, gtboxes)

    gt_argmax_overlaps = overlaps.argmax(axis=0)
    anchor_argmax_overlaps = overlaps.argmax(axis=1)
    anchor_max_overlaps = overlaps[range(overlaps.shape[0]), anchor_argmax_overlaps]

    labels = np.empty(base_anchor.shape[0])
    labels.fill(-1)
    labels[gt_argmax_overlaps] = 1
    labels[anchor_max_overlaps > IOU_POSITIVE] = 1
    labels[anchor_max_overlaps < IOU_NEGATIVE] = 0

    outside_anchor = np.where(
        (base_anchor[:, 0] < 0) |
        (base_anchor[:, 1] < 0) |
        (base_anchor[:, 2] >= imgsize[1]) |
        (base_anchor[:, 3] >= imgsize[0])
    )[0]
    labels[outside_anchor] = -1

    fg_index = np.where(labels == 1)[0]
    if (len(fg_index) > RPN_POSITIVE_NUM):
        labels[np.random.choice(fg_index, len(fg_index) - RPN_POSITIVE_NUM, replace=False)] = -1
    if not OHEM:
        bg_index = np.where(labels == 0)[0]
        num_bg = RPN_TOTAL_NUM - np.sum(labels == 1)
        if (len(bg_index) > num_bg):
            labels[np.random.choice(bg_index, len(bg_index) - num_bg, replace=False)] = -1

    bbox_targets = bbox_transform(base_anchor, gtboxes[anchor_argmax_overlaps, :])

    return [labels, bbox_targets]
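# Editor's note -- a minimal usage sketch for cal_rpn (numbers are
# illustrative): for a 320x480 input the feature map is (20, 30), giving
# 20 * 30 * 10 = 6000 anchors, so
#   labels, targets = cal_rpn((320, 480), (20, 30), 16, gtboxes)
# returns labels of shape (6000,) with values in {1, 0, -1} and targets of
# shape (6000, 2) holding (Vc, Vh) against each anchor's best gt box.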
'''
Non-maximum suppression: drop overlapping boxes.
'''
def nms(dets, thresh):
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        inds = np.where(ovr <= thresh)[0]
        order = order[inds + 1]
    return keep
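# Editor's note -- nms() expects rows of (x1, y1, x2, y2, score); a sketch
# with illustrative numbers:
#   dets = np.array([[ 10, 10,  40, 30, 0.9],
#                    [ 12, 11,  42, 31, 0.8],   # overlaps the first box
#                    [100, 50, 140, 80, 0.7]])
#   nms(dets, 0.3) -> [0, 2]   (the second box is suppressed: IoU ~= 0.80)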
'''
Graph-based text-line construction.
Sub-graph connection rule: build text lines from the paired proposals.
1. Scan the graph for an index whose column is all False but whose row is not:
   that proposal starts a text line.
2. Follow the True entry in row `index` to the next proposal, append it to the
   sub-graph, and continue from that proposal.
3. Repeat step 2 until the current row is all False.
4. Repeat steps 1-3 until the whole graph has been traversed.
Returns a list of text lines, each a list of proposal indices.
'''
class Graph:
    def __init__(self, graph):
        self.graph = graph

    def sub_graphs_connected(self):
        sub_graphs = []
        for index in range(self.graph.shape[0]):
            if not self.graph[:, index].any() and self.graph[index, :].any():
                v = index
                sub_graphs.append([v])
                while self.graph[v, :].any():
                    v = np.where(self.graph[v, :])[0][0]
                    sub_graphs[-1].append(v)

        return sub_graphs

'''
Configuration:
MAX_HORIZONTAL_GAP: max horizontal gap between proposals within a line
MIN_V_OVERLAPS:     min vertical IoU between proposals
MIN_SIZE_SIM:       min height similarity between proposals
'''
class TextLineCfg:
    SCALE = 600
    MAX_SCALE = 1200
    TEXT_PROPOSALS_WIDTH = 16
    MIN_NUM_PROPOSALS = 2
    MIN_RATIO = 0.5
    LINE_MIN_SCORE = 0.9
    TEXT_PROPOSALS_MIN_SCORE = 0.7
    TEXT_PROPOSALS_NMS_THRESH = 0.3
    MAX_HORIZONTAL_GAP = 60
    MIN_V_OVERLAPS = 0.6
    MIN_SIZE_SIM = 0.6


class TextProposalGraphBuilder:
    '''
    Pair up text proposals.
    '''
    def get_successions(self, index):
        '''
        Scan [x0, x0 + MAX_HORIZONTAL_GAP] for the successors
        of the proposal with the given index.
        '''
        box = self.text_proposals[index]
        results = []
        for left in range(int(box[0]) + 1, min(int(box[0]) + TextLineCfg.MAX_HORIZONTAL_GAP + 1, self.im_size[1])):
            adj_box_indices = self.boxes_table[left]
            for adj_box_index in adj_box_indices:
                if self.meet_v_iou(adj_box_index, index):
                    results.append(adj_box_index)
            if len(results) != 0:
                return results

        return results

    def get_precursors(self, index):
        '''
        Scan [x0 - MAX_HORIZONTAL_GAP, x0] for the predecessors
        of the proposal with the given index.
        '''
        box = self.text_proposals[index]
        results = []
        for left in range(int(box[0]) - 1, max(int(box[0] - TextLineCfg.MAX_HORIZONTAL_GAP), 0) - 1, -1):
            adj_box_indices = self.boxes_table[left]
            for adj_box_index in adj_box_indices:
                if self.meet_v_iou(adj_box_index, index):
                    results.append(adj_box_index)
            if len(results) != 0:
                return results

        return results

    def is_succession_node(self, index, succession_index):
        '''
        Check whether the two proposals form a pair.
        '''
        precursors = self.get_precursors(succession_index)
        if self.scores[index] >= np.max(self.scores[precursors]):
            return True

        return False

    def meet_v_iou(self, index1, index2):
        '''
        Check whether two proposals satisfy the vertical IoU conditions.
        overlaps_v:      vertical IoU, iou_v = inv_y / min(h1, h2)
        size_similarity: height similarity, sim = min(h1, h2) / max(h1, h2)
        '''
        def overlaps_v(index1, index2):
            h1 = self.heights[index1]
            h2 = self.heights[index2]
            y0 = max(self.text_proposals[index2][1], self.text_proposals[index1][1])
            y1 = min(self.text_proposals[index2][3], self.text_proposals[index1][3])
            return max(0, y1 - y0 + 1) / min(h1, h2)

        def size_similarity(index1, index2):
            h1 = self.heights[index1]
            h2 = self.heights[index2]
            return min(h1, h2) / max(h1, h2)

        return overlaps_v(index1, index2) >= TextLineCfg.MIN_V_OVERLAPS and \
               size_similarity(index1, index2) >= TextLineCfg.MIN_SIZE_SIM

    def build_graph(self, text_proposals, scores, im_size):
        '''
        Pair up the proposals.
        self.heights:     heights of all proposals
        self.boxes_table: proposals grouped by the x1 coordinate of their top-left corner
        graph: boolean [n, n] array marking paired proposals, n = number of proposals
        (1) take the successors of the current proposal Bi;
        (2) pick the highest-scoring successor, Bj;
        (3) take the predecessors of Bj;
        (4) if the highest-scoring predecessor of Bj is exactly Bi, the two form a pair.
        '''
        self.text_proposals = text_proposals
        self.scores = scores
        self.im_size = im_size
        self.heights = text_proposals[:, 3] - text_proposals[:, 1] + 1

        boxes_table = [[] for _ in range(self.im_size[1])]
        for index, box in enumerate(text_proposals):
            boxes_table[int(box[0])].append(index)
        self.boxes_table = boxes_table

        graph = np.zeros((text_proposals.shape[0], text_proposals.shape[0]), bool)

        for index, box in enumerate(text_proposals):
            successions = self.get_successions(index)
            if len(successions) == 0:
                continue
            succession_index = successions[np.argmax(scores[successions])]
            if self.is_succession_node(index, succession_index):
                graph[index, succession_index] = True

        return Graph(graph)
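# Editor's note -- a tiny illustration of the two classes above: if
# build_graph() pairs proposal 0 -> 1 and 1 -> 2 (True at (0, 1) and (1, 2)
# in the boolean matrix), Graph.sub_graphs_connected() walks the rows and
# returns [[0, 1, 2]]: the three proposals form a single text line.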
class TextProposalConnectorOriented:
    """
    Connect text proposals into text-line bboxes.
    """

    def __init__(self):
        self.graph_builder = TextProposalGraphBuilder()

    def group_text_proposals(self, text_proposals, scores, im_size):
        '''
        Connect the proposals and group them into text lines.
        '''
        graph = self.graph_builder.build_graph(text_proposals, scores, im_size)

        return graph.sub_graphs_connected()

    def fit_y(self, X, Y, x1, x2):
        '''
        Fit a first-order polynomial to (X, Y) and return its values at x1 and x2.
        '''
        # if all X coordinates are identical, no line can be fitted
        if np.sum(X == X[0]) == len(X):
            return Y[0], Y[0]
        p = np.poly1d(np.polyfit(X, Y, 1))
        return p(x1), p(x2)

    def get_text_lines(self, text_proposals, scores, im_size):
        '''
        Build text lines from the proposals:
        1. Group the proposals into text-line groups.
        2. Merge each group into one long text line:
           (1) collect the group's proposals, text_line_boxes;
           (2) take each proposal's center (X, Y) and the group's x extent (x0, x1);
           (3) fit the line z1 through all center points;
           (4) set offset to half a proposal width;
           (5) fit a line through the top-left corners and evaluate it at
               x = x0 + offset and x = x1 - offset -> (lt_y, rt_y);
           (6) fit a line through the bottom-left corners and evaluate it at
               x = x0 + offset and x = x1 - offset -> (lb_y, rb_y);
           (7) score the line with the mean score of its proposals;
           (8) store the line's basic data.
        3. Emit the final text-line boxes.
        '''
        tp_groups = self.group_text_proposals(text_proposals, scores, im_size)

        text_lines = np.zeros((len(tp_groups), 8), np.float32)
        for index, tp_indices in enumerate(tp_groups):
            text_line_boxes = text_proposals[list(tp_indices)]

            X = (text_line_boxes[:, 0] + text_line_boxes[:, 2]) / 2
            Y = (text_line_boxes[:, 1] + text_line_boxes[:, 3]) / 2
            x0 = np.min(text_line_boxes[:, 0])
            x1 = np.max(text_line_boxes[:, 2])

            z1 = np.polyfit(X, Y, 1)

            offset = (text_line_boxes[0, 2] - text_line_boxes[0, 0]) * 0.5

            lt_y, rt_y = self.fit_y(text_line_boxes[:, 0], text_line_boxes[:, 1], x0 + offset, x1 - offset)
            lb_y, rb_y = self.fit_y(text_line_boxes[:, 0], text_line_boxes[:, 3], x0 + offset, x1 - offset)

            score = scores[list(tp_indices)].sum() / float(len(tp_indices))

            text_lines[index, 0] = x0
            text_lines[index, 1] = min(lt_y, rt_y)  # smaller y of the line's top edge
            text_lines[index, 2] = x1
            text_lines[index, 3] = max(lb_y, rb_y)  # larger y of the line's bottom edge
            text_lines[index, 4] = score            # text-line score
            text_lines[index, 5] = z1[0]            # slope k and intercept b of the center-line fit
            text_lines[index, 6] = z1[1]
            height = np.mean((text_line_boxes[:, 3] - text_line_boxes[:, 1]))  # mean proposal height
            text_lines[index, 7] = height + 2.5

        text_recs = np.zeros((len(text_lines), 9), np.float32)
        index = 0
        for line in text_lines:
            b1 = line[6] - line[7] / 2  # intercepts of the top/bottom edges, from the center line and the height
            b2 = line[6] + line[7] / 2
            x1 = line[0]
            y1 = line[5] * line[0] + b1  # top-left
            x2 = line[2]
            y2 = line[5] * line[2] + b1  # top-right
            x3 = line[0]
            y3 = line[5] * line[0] + b2  # bottom-left
            x4 = line[2]
            y4 = line[5] * line[2] + b2  # bottom-right
            disX = x2 - x1
            disY = y2 - y1
            width = np.sqrt(disX * disX + disY * disY)  # text-line width

            fTmp0 = y3 - y1  # text-line height
            fTmp1 = fTmp0 * disY / width
            x = np.fabs(fTmp1 * disX / width)  # compensation along x
            y = np.fabs(fTmp1 * disY / width)  # compensation along y
            if line[5] < 0:
                x1 -= x
                y1 += y
                x4 += x
                y4 -= y
            else:
                x2 += x
                y2 += y
                x3 -= x
                y3 -= y
            text_recs[index, 0] = x1
            text_recs[index, 1] = y1
            text_recs[index, 2] = x2
            text_recs[index, 3] = y2
            text_recs[index, 4] = x3
            text_recs[index, 5] = y3
            text_recs[index, 6] = x4
            text_recs[index, 7] = y4
            text_recs[index, 8] = line[4]
            index = index + 1

        return text_recs

if __name__ == '__main__':
    anchor = gen_anchor((10, 15), 16)
--------------------------------------------------------------------------------
/log/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrazySummerday/ctpn.pytorch/99f6baf2780e550d7b4656ac7a7b90af9ade468f/log/1.jpg
--------------------------------------------------------------------------------
/log/1_src.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrazySummerday/ctpn.pytorch/99f6baf2780e550d7b4656ac7a7b90af9ade468f/log/1_src.jpg
--------------------------------------------------------------------------------
/log/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrazySummerday/ctpn.pytorch/99f6baf2780e550d7b4656ac7a7b90af9ade468f/log/2.jpg
--------------------------------------------------------------------------------
/log/2_src.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrazySummerday/ctpn.pytorch/99f6baf2780e550d7b4656ac7a7b90af9ade468f/log/2_src.png
--------------------------------------------------------------------------------
/log/training_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrazySummerday/ctpn.pytorch/99f6baf2780e550d7b4656ac7a7b90af9ade468f/log/training_loss.png
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
#-*- coding:utf-8 -*-
import os
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from ctpn import config
from ctpn.ctpn import CTPN_Model
from ctpn.utils import gen_anchor, transform_bbox, clip_bbox, filter_bbox, nms, TextProposalConnectorOriented


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
weights = './weights/ctpn.pth'
model = CTPN_Model().to(device)
model.load_state_dict(torch.load(weights, map_location=device)['model_state_dict'])
model.eval()


def get_text_boxes(image, display=True, prob_thresh=0.5):
    h, w = image.shape[:2]
    rescale_fac = max(h, w) / 1000
    if rescale_fac > 1.0:
        h = int(h / rescale_fac)
        w = int(w / rescale_fac)
        image = cv2.resize(image, (w, h))
    h, w = image.shape[:2]
    image_c = image.copy()
    image = image.astype(np.float32) - config.IMAGE_MEAN
    image = torch.from_numpy(image.transpose(2, 0, 1)).unsqueeze(0).float().to(device)

    with torch.no_grad():
        cls, regr = model(image)
        cls_prob = F.softmax(cls, dim=-1).cpu().numpy()
        regr = regr.cpu().numpy()
        anchor = gen_anchor((int(h / 16), int(w / 16)), 16)
        bbox = transform_bbox(anchor, regr)
        bbox = clip_bbox(bbox, [h, w])

        fg = np.where(cls_prob[0, :, 1] > prob_thresh)[0]
        select_anchor = bbox[fg, :]
        select_score = cls_prob[0, fg, 1]
        select_anchor = select_anchor.astype(np.int32)
        keep_index = filter_bbox(select_anchor, 16)

        select_anchor = select_anchor[keep_index]
        select_score = select_score[keep_index]
        select_score = np.reshape(select_score, (select_score.shape[0], 1))
        nmsbox = np.hstack((select_anchor, select_score))
        keep = nms(nmsbox, 0.3)
        select_anchor = select_anchor[keep]
        select_score = select_score[keep]

        textConn = TextProposalConnectorOriented()
        text = textConn.get_text_lines(select_anchor, select_score, [h, w])
        if display:
            for i in text:
                s = str(round(i[-1] * 100, 2)) + '%'
                i = [int(j) for j in i]
                cv2.line(image_c, (i[0], i[1]), (i[2], i[3]), (0, 0, 255), 2)
                cv2.line(image_c, (i[0], i[1]), (i[4], i[5]), (0, 0, 255), 2)
                cv2.line(image_c, (i[6], i[7]), (i[2], i[3]), (0, 0, 255), 2)
                cv2.line(image_c, (i[4], i[5]), (i[6], i[7]), (0, 0, 255), 2)
                cv2.putText(image_c, s, (i[0] + 13, i[1] + 13), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2, cv2.LINE_AA)

    return text, image_c

if __name__ == '__main__':
    img_path = 'images/21.jpg'
    input_img = cv2.imread(img_path)
    text, out_img = get_text_boxes(input_img)
    cv2.imwrite('results/21.jpg', out_img)
--------------------------------------------------------------------------------
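Each row returned by `get_text_boxes` is a quadrilateral plus score, (x1, y1, x2, y2, x3, y3, x4, y4, score). A hedged sketch of flattening those rows into axis-aligned boxes; the helper name `to_aabb` is the editor's invention, not part of the repository:
```python
import numpy as np

def to_aabb(text_recs):
    # (x1, y1, ..., x4, y4, score) -> (xmin, ymin, xmax, ymax, score)
    recs = np.asarray(text_recs, dtype=np.float32)
    xs, ys = recs[:, 0:8:2], recs[:, 1:8:2]
    return np.stack([xs.min(1), ys.min(1), xs.max(1), ys.max(1), recs[:, 8]], axis=1)
```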
/train.py:
--------------------------------------------------------------------------------
#-*- coding:utf-8 -*-
import os
import time
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch import optim
from ctpn.ctpn import CTPN_Model, RPN_CLS_Loss, RPN_REGR_Loss
from ctpn.dataset import VOCDataset
from ctpn import config
import visdom

random_seed = 2019
torch.random.manual_seed(random_seed)
np.random.seed(random_seed)

epochs = 80
lr = 1e-3
resume_epoch = 0


def save_checkpoint(state, epoch, loss_cls, loss_regr, loss, ext='pth'):
    check_path = os.path.join(config.checkpoints_dir,
                              'ctpn_ep{:02d}_{:.4f}_{:.4f}_{:.4f}.'.format(epoch, loss_cls, loss_regr, loss) + ext)
    try:
        torch.save(state, check_path)
    except BaseException as e:
        print(e)
        print('failed to save to {}'.format(check_path))
        return
    print('saving to {}'.format(check_path))

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


if __name__ == '__main__':
    dataset = VOCDataset(config.img_dir, config.label_dir)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=config.num_workers)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CTPN_Model().to(device)

    checkpoints_weight = config.pretrained_weights
    print('pretrained weight exists:', os.path.exists(checkpoints_weight))
    if os.path.exists(checkpoints_weight):
        print('using pretrained weight: {}'.format(checkpoints_weight))
        cc = torch.load(checkpoints_weight, map_location=device)
        model.load_state_dict(cc['model_state_dict'])
        resume_epoch = cc['epoch']
    else:
        model.apply(weights_init)

    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [35, 55, 70], gamma=0.1, last_epoch=-1)

    criterion_cls = RPN_CLS_Loss(device)
    criterion_regr = RPN_REGR_Loss(device)

    best_loss_cls = 100
    best_loss_regr = 100
    best_loss = 100
    best_model = None
    epochs += resume_epoch

    viz = visdom.Visdom(env='ctpn-train')
    n_iter = 0
    for epoch in range(resume_epoch + 1, epochs):
        print('Epoch {}/{}'.format(epoch, epochs))
        epoch_size = len(dataset)  # batch_size == 1
        model.train()
        epoch_loss_cls = 0
        epoch_loss_regr = 0
        epoch_loss = 0
        scheduler.step(epoch)
        for param_group in scheduler.optimizer.param_groups:
            print('lr: %s' % param_group['lr'])
        print('#' * 80)

        for batch_i, (imgs, clss, regrs) in enumerate(dataloader):
            since = time.time()
            imgs = imgs.to(device)
            clss = clss.to(device)
            regrs = regrs.to(device)

            optimizer.zero_grad()

            out_cls, out_regr = model(imgs)
            loss_regr = criterion_regr(out_regr, regrs)
            loss_cls = criterion_cls(out_cls, clss)

            loss = loss_cls + loss_regr
            loss.backward()
            optimizer.step()

            epoch_loss_cls += loss_cls.item()
            epoch_loss_regr += loss_regr.item()
            epoch_loss += loss.item()
            mmp = batch_i + 1
            n_iter += 1
            print('time:{}'.format(time.time() - since))
            print('EPOCH:{}/{}--BATCH:{}/{}\n'.format(epoch, epochs - 1, batch_i, epoch_size),
                  'batch: loss_cls:{:.4f}--loss_regr:{:.4f}--loss:{:.4f}\n'.format(loss_cls.item(), loss_regr.item(), loss.item()),
                  'epoch: loss_cls:{:.4f}--loss_regr:{:.4f}--loss:{:.4f}\n'.format(epoch_loss_cls / mmp, epoch_loss_regr / mmp, epoch_loss / mmp))
            if mmp % 100 == 0:
                viz.line(Y=np.array([epoch_loss_cls / mmp]), X=np.array([n_iter // 100]),
                         update='append', win='loss_cls', opts={'title': 'loss_cls'})
                viz.line(Y=np.array([epoch_loss_regr / mmp]), X=np.array([n_iter // 100]),
                         update='append', win='loss_regr', opts={'title': 'loss_regr'})
                viz.line(Y=np.array([epoch_loss / mmp]), X=np.array([n_iter // 100]),
                         update='append', win='loss_all', opts={'title': 'loss_all'})

        epoch_loss_cls /= epoch_size
        epoch_loss_regr /= epoch_size
        epoch_loss /= epoch_size
        print('Epoch:{}--{:.4f}--{:.4f}--{:.4f}'.format(epoch, epoch_loss_cls, epoch_loss_regr, epoch_loss))
        if best_loss_cls > epoch_loss_cls or best_loss_regr > epoch_loss_regr or best_loss > epoch_loss:
            best_loss = epoch_loss
            best_loss_regr = epoch_loss_regr
            best_loss_cls = epoch_loss_cls
            best_model = model
            save_checkpoint({'model_state_dict': best_model.state_dict(), 'epoch': epoch},
                            epoch,
                            best_loss_cls,
                            best_loss_regr,
                            best_loss)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
--------------------------------------------------------------------------------
/weights/ctpn.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CrazySummerday/ctpn.pytorch/99f6baf2780e550d7b4656ac7a7b90af9ade468f/weights/ctpn.pth
--------------------------------------------------------------------------------