├── README.md ├── angle ├── README ├── __init__.py └── predict.py ├── crnn ├── README ├── crnn.py ├── dataset.py ├── keys_crnn.py └── models │ ├── crnn.py │ └── utils.py ├── ctpn ├── README ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── text_detect.cpython-36.pyc ├── ctpn │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── cfg.cpython-36.pyc │ │ ├── detectors.cpython-36.pyc │ │ └── other.cpython-36.pyc │ ├── cfg.py │ ├── demo.py │ ├── detectors.py │ ├── model.py │ ├── other.py │ ├── text.yml │ ├── text_proposal_connector.py │ ├── text_proposal_graph_builder.py │ └── train_net.py ├── data │ ├── demo │ │ ├── 001.jpg │ │ ├── 002.jpg │ │ ├── 003.jpg │ │ ├── 004.jpg │ │ ├── 005.jpg │ │ ├── 006.jpg │ │ ├── 007.jpg │ │ ├── 008.jpg │ │ ├── 009.jpg │ │ └── 010.png │ ├── oriented_results │ │ ├── 001.jpg │ │ ├── 002.jpg │ │ ├── 003.jpg │ │ ├── 004.jpg │ │ ├── 005.jpg │ │ ├── 006.jpg │ │ ├── 007.jpg │ │ ├── 008.jpg │ │ ├── 009.jpg │ │ └── 010.png │ └── results │ │ ├── 001.jpg │ │ ├── 002.jpg │ │ ├── 003.jpg │ │ └── 010.png ├── lib │ ├── datasets │ │ ├── __init__.py │ │ ├── ds_utils.py │ │ ├── factory.py │ │ ├── imdb.py │ │ └── pascal_voc.py │ ├── fast_rcnn │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── config.cpython-36.pyc │ │ │ └── nms_wrapper.cpython-36.pyc │ │ ├── bbox_transform.py │ │ ├── config.py │ │ ├── nms_wrapper.py │ │ ├── test.py │ │ └── train.py │ ├── networks │ │ ├── VGGnet_test.py │ │ ├── VGGnet_train.py │ │ ├── __init__.py │ │ ├── factory.py │ │ └── network.py │ ├── roi_data_layer │ │ ├── __init__.py │ │ ├── layer.py │ │ ├── minibatch.py │ │ └── roidb.py │ └── rpn_msr │ │ ├── anchor_target_layer_tf.py │ │ ├── generate_anchors.py │ │ └── proposal_layer_tf.py ├── models ├── prepare_training_data │ ├── ToVoc.py │ └── split_label.py └── text_detect.py ├── demo.ipynb ├── demo.py ├── img ├── tmp1.png ├── tmp1识别结果.png ├── tmp2.jpg └── tmp2识别结果.png ├── keras_model.py ├── model.py ├── pytorch_model.py ├── setup-cpu.sh ├── setup-python3.sh ├── setup.sh └── train ├── create-dataset.sh ├── create_dataset ├── create_dataset.py ├── fontA.ttf ├── textgen.py ├── viewlmdb.py └── 华文细黑.ttf ├── data ├── dataline │ ├── ff299a9c-b41b-11e7-89e1-1c1b0d6ddf51.jpg │ └── ff299a9c-b41b-11e7-89e1-1c1b0d6ddf51.txt └── lmdb │ ├── train │ ├── data.mdb │ └── lock.mdb │ └── val │ ├── data.mdb │ └── lock.mdb └── keras-train ├── allinonetrain.py ├── basemodel.png └── dataset.py /README.md: -------------------------------------------------------------------------------- 1 | # 基于tensorflow、keras/pytorch实现对图片文字检测及端到端的OCR文字识别 2 | 3 | 4 | ## 实现功能 5 | 6 | - 文字方向检测 0、90、180、270度检测 7 | - 文字检测 后期将切换到keras版本文本检测 实现keras端到端的文本检测及识别 8 | - 不定长OCR识别 9 | 10 | 11 | ## 环境部署 12 | ``` 13 | Bash 14 | ##GPU环境 15 | sh setup.sh 16 | ##CPU环境 17 | sh setup-cpu.sh 18 | ##CPU python3环境 19 | sh setup-python3.sh 20 | 使用环境:python3.6+tensorflow1.7+cpu/gpu 21 | ``` 22 | 23 | ## 模型训练 24 | * 一共分为3个网络 25 | * **1. 文本方向检测网络-Classify(vgg16)** 26 | * **2. 文本区域检测网络-CTPN(CNN+RNN)** 27 | * **3. EndToEnd文本识别网络-CRNN(CNN+GRU/LSTM+CTC)** 28 | 29 | ## 文字方向检测-vgg分类 30 | ``` 31 | 基于图像分类,在VGG16模型的基础上,训练0、90、180、270度检测的分类模型. 
32 | 详细代码参考angle/predict.py文件,训练图片8000张,准确率88.23% 33 | ``` 34 | 模型地址[BaiduCloud](https://pan.baidu.com/s/1zquQNdO0MUsLMsuwxbgPYg) 35 | 36 | ## 文字区域检测CTPN 37 | 支持CPU、GPU环境,一键部署, 38 | [文本检测训练参考](https://github.com/eragonruan/text-detection-ctpn) 39 | 40 | 41 | ## OCR 端到端识别:CRNN 42 | ### ocr识别采用GRU+CTC端到端识别技术,实现不分隔识别不定长文字 43 | 提供keras 与pytorch版本的训练代码,在理解keras的基础上,可以切换到pytorch版本,此版本更稳定 44 | 45 | 46 | ## 训练网络 47 | ### 1 对ctpn进行训练 48 | * 定位到路径--./ctpn/ctpn/train_net.py 49 | * 预训练的vgg网络路径[VGG_imagenet.npy](https://pan.baidu.com/s/1JO_ZojA5bkmJZsnxsShgkg) 50 | 将预训练权重下载下来,pretrained_model指向该路径即可, 51 | 此外整个模型的预训练权重[checkpoint](https://pan.baidu.com/s/1aT-vHgq7nvLy4M_T6SwR1Q) 52 | * ctpn数据集[百度云](https://pan.baidu.com/s/1NXFmdP_OgRF42xfHXUhBHQ) 53 | 数据集下载完成并解压后,将.ctpn/lib/datasets/pascal_voc.py 文件中的pascal_voc 类中的参数self.devkit_path指向数据集的路径即可 54 | 55 | ### 2 对crnn进行训练 56 | * keras版本 ./train/keras_train/train_batch.py model_path--指向预训练权重位置 57 | MODEL_PATH---指向模型训练保存的位置 58 | [keras模型预训练权重](https://pan.baidu.com/s/1vTG6-i_bFMWxQ_7xF06usg) 59 | * pythorch版本./train/pytorch-train/crnn_main.py 60 | ``` 61 | parser.add_argument( 62 | '--crnn', 63 | help="path to crnn (to continue training)", 64 | default=预训练权重的路径) 65 | parser.add_argument( 66 | '--experiment', 67 | help='Where to store samples and models', 68 | default=模型训练的权重保存位置) 69 | ``` 70 | [pytorch预训练权重](https://pan.baidu.com/s/1LEDNHEr3luloB7eZK6GOeA) 71 | 72 | 73 | ## 识别结果显示 74 | ### 文字检测及OCR识别结果 75 | ![示例图像1](./img/tmp1.png) 76 | `===========================================================` 77 | ![ctpn+crnn结果1](./img/tmp1识别结果.png) 78 | 79 | ![示例图像2](./img/tmp2.jpg) 80 | `===========================================================` 81 | ![ctpn+crnn结果2](./img/tmp2识别结果.png) 82 | 83 | 84 | ## 在思乐中的运用 85 | 从主播的直播间评论区可获得新进入直播间粉丝信息,以及直播间粉丝实时评论, 86 | 在无法获取直播平台数据时,通过CV可提取到需要的用户行为数据,用于用户价值的实时精准分级。 87 | 88 | ## 参考 89 | 90 | - [pytorch 实现crnn](https://github.com/meijieru/crnn.pytorch.git) 91 | - [keras-crnn 版本实现参考](https://www.zhihu.com/question/59645822) 92 | - [tensorflow-crnn](https://github.com/ilovin/lstm_ctc_ocr) 93 | - [tensorflow-ctpn](https://github.com/eragonruan/text-detection-ctpn ) 94 | - [CAFFE-CTPN](https://github.com/tianzhi0549/CTPN) 95 | -------------------------------------------------------------------------------- /angle/README: -------------------------------------------------------------------------------- 1 | # 文字方向检测-vgg分类 2 | ``` 3 | 基于图像分类,在VGG16模型的基础上,训练0、90、180、270度检测的分类模型. 
4 | 详细代码参考angle/predict.py文件,训练图片8000张,准确率88.23% 5 | ``` 6 | -------------------------------------------------------------------------------- /angle/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 图像文字方向检测 5 | @author: Wangmc 6 | """ 7 | -------------------------------------------------------------------------------- /angle/predict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # _Author_: wangmc 4 | # Date: 2018-04-22 18:13:46 5 | # Last Modified by: xiaofeng 6 | # Last Modified time: 2018-04-22 18:13:46 7 | ''' 8 | 根据给定的图形,分析文字的朝向 9 | ''' 10 | # from keras.models import load_model 11 | import numpy as np 12 | from PIL import Image 13 | from keras.applications.vgg16 import preprocess_input, VGG16 14 | from keras.layers import Dense 15 | from keras.models import Model 16 | # 编译模型,以较小的学习参数进行训练 17 | from keras.optimizers import SGD 18 | 19 | 20 | def load(): 21 | vgg = VGG16(weights=None, input_shape=(224, 224, 3)) 22 | # 修改输出层 3个输出 23 | x = vgg.layers[-2].output 24 | predictions_class = Dense( 25 | 4, activation='softmax', name='predictions_class')(x) 26 | prediction = [predictions_class] 27 | model = Model(inputs=vgg.input, outputs=prediction) 28 | sgd = SGD(lr=0.00001, momentum=0.9) 29 | model.compile( 30 | optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy']) 31 | model.load_weights( 32 | '/Users/xiaofeng/Code/Github/dataset/CHINESE_OCR/angle/modelAngle.h5') 33 | return model 34 | 35 | 36 | # 加载模型 37 | model = None 38 | 39 | 40 | def predict(path=None, img=None): 41 | global model 42 | if model is None: 43 | model = load() 44 | """ 45 | 图片文字方向预测 46 | """ 47 | ROTATE = [0, 90, 180, 270] 48 | if path is not None: 49 | im = Image.open(path).convert('RGB') 50 | elif img is not None: 51 | im = Image.fromarray(img).convert('RGB') 52 | w, h = im.size 53 | # 对图像进行剪裁 54 | # 左上角(int(0.1 * w), int(0.1 * h)) 55 | # 右下角(w - int(0.1 * w), h - int(0.1 * h)) 56 | xmin, ymin, xmax, ymax = int(0.1 * w), int( 57 | 0.1 * h), w - int(0.1 * w), h - int(0.1 * h) 58 | im = im.crop((xmin, ymin, xmax, ymax)) # 剪切图片边缘,清除边缘噪声 59 | # 对图片进行剪裁之后进行resize成(224,224) 60 | im = im.resize((224, 224)) 61 | # 将图像转化成数组形式 62 | img = np.array(im) 63 | img = preprocess_input(img.astype(np.float32)) 64 | pred = model.predict(np.array([img])) 65 | index = np.argmax(pred, axis=1)[0] 66 | return ROTATE[index] 67 | -------------------------------------------------------------------------------- /crnn/README: -------------------------------------------------------------------------------- 1 | # OCR 端到端识别:CRNN 2 | ## ocr识别采用GRU+CTC端到到识别技术,实现不分隔识别不定长文字 3 | 提供keras 与pytorch版本的训练代码,在理解keras的基础上,可以切换到pytorch版本,此版本更稳定 4 | -------------------------------------------------------------------------------- /crnn/crnn.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import sys 3 | 4 | sys.path.insert(1, "./crnn") 5 | import torch 6 | import torch.utils.data 7 | from torch.autograd import Variable 8 | import numpy as np 9 | import util 10 | import dataset 11 | import models.crnn as crnn 12 | import keys_crnn 13 | from math import * 14 | import cv2 15 | 16 | GPU = False 17 | 18 | 19 | def dumpRotateImage_(img, degree, pt1, pt2, pt3, pt4): 20 | height, width = img.shape[:2] 21 | heightNew = int(width * fabs(sin(radians(degree))) + height * fabs(cos(radians(degree)))) 22 | 
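        # Rotating a w*h image by `degree` expands its bounding box to
        # width h*|sin(θ)| + w*|cos(θ)| and height w*|sin(θ)| + h*|cos(θ)|,
        # so the output canvas is enlarged before warping to avoid clipping
        # the rotated corners.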
widthNew = int(height * fabs(sin(radians(degree))) + width * fabs(cos(radians(degree)))) 23 | matRotation = cv2.getRotationMatrix2D((width / 2, height / 2), degree, 1) 24 | matRotation[0, 2] += (widthNew - width) / 2 25 | matRotation[1, 2] += (heightNew - height) / 2 26 | imgRotation = cv2.warpAffine(img, matRotation, (widthNew, heightNew), borderValue=(255, 255, 255)) 27 | pt1 = list(pt1) 28 | pt3 = list(pt3) 29 | 30 | [[pt1[0]], [pt1[1]]] = np.dot(matRotation, np.array([[pt1[0]], [pt1[1]], [1]])) 31 | [[pt3[0]], [pt3[1]]] = np.dot(matRotation, np.array([[pt3[0]], [pt3[1]], [1]])) 32 | imgOut = imgRotation[int(pt1[1]):int(pt3[1]), int(pt1[0]):int(pt3[0])] 33 | height, width = imgOut.shape[:2] 34 | return imgOut 35 | 36 | 37 | def crnnSource(): 38 | alphabet = keys_crnn.alphabet 39 | converter = util.strLabelConverter(alphabet) 40 | if torch.cuda.is_available() and GPU: 41 | model = crnn.CRNN(32, 1, len(alphabet) + 1, 256, 1).cuda() 42 | else: 43 | model = crnn.CRNN(32, 1, len(alphabet) + 1, 256, 1).cpu() 44 | path = './crnn/samples/model_acc97.pth' 45 | model.eval() 46 | model.load_state_dict(torch.load(path)) 47 | return model, converter 48 | 49 | 50 | ##加载模型 51 | model, converter = crnnSource() 52 | 53 | 54 | def crnnOcr(image): 55 | """ 56 | crnn模型,ocr识别 57 | @@model, 58 | @@converter, 59 | @@im 60 | @@text_recs:text box 61 | 62 | """ 63 | scale = image.size[1] * 1.0 / 32 64 | w = image.size[0] / scale 65 | w = int(w) 66 | # print "im size:{},{}".format(image.size,w) 67 | transformer = dataset.resizeNormalize((w, 32)) 68 | if torch.cuda.is_available() and GPU: 69 | image = transformer(image).cuda() 70 | else: 71 | image = transformer(image).cpu() 72 | 73 | image = image.view(1, *image.size()) 74 | image = Variable(image) 75 | model.eval() 76 | preds = model(image) 77 | _, preds = preds.max(2) 78 | preds = preds.transpose(1, 0).contiguous().view(-1) 79 | preds_size = Variable(torch.IntTensor([preds.size(0)])) 80 | sim_pred = converter.decode(preds.data, preds_size.data, raw=False) 81 | if len(sim_pred) > 0: 82 | if sim_pred[0] == u'-': 83 | sim_pred = sim_pred[1:] 84 | 85 | return sim_pred 86 | -------------------------------------------------------------------------------- /crnn/dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # encoding: utf-8 3 | 4 | import random 5 | import sys 6 | 7 | import lmdb 8 | import numpy as np 9 | import six 10 | import torch 11 | import torchvision.transforms as transforms 12 | from PIL import Image 13 | from torch.utils.data import Dataset 14 | from torch.utils.data import sampler 15 | 16 | 17 | class lmdbDataset(Dataset): 18 | def __init__(self, root=None, transform=None, target_transform=None): 19 | self.env = lmdb.open( 20 | root, 21 | max_readers=1, 22 | readonly=True, 23 | lock=False, 24 | readahead=False, 25 | meminit=False) 26 | 27 | if not self.env: 28 | print('cannot creat lmdb from %s' % (root)) 29 | sys.exit(0) 30 | 31 | with self.env.begin(write=False) as txn: 32 | nSamples = int(txn.get('num-samples')) 33 | self.nSamples = nSamples 34 | 35 | self.transform = transform 36 | self.target_transform = target_transform 37 | 38 | def __len__(self): 39 | return self.nSamples 40 | 41 | def __getitem__(self, index): 42 | assert index <= len(self), 'index range error' 43 | index += 1 44 | with self.env.begin(write=False) as txn: 45 | img_key = 'image-%09d' % index 46 | imgbuf = txn.get(img_key) 47 | 48 | buf = six.BytesIO() 49 | buf.write(imgbuf) 50 | buf.seek(0) 51 | try: 52 | img = 
Image.open(buf).convert('L') 53 | except IOError: 54 | print('Corrupted image for %d' % index) 55 | return self[index + 1] 56 | 57 | if self.transform is not None: 58 | img = self.transform(img) 59 | 60 | label_key = 'label-%09d' % index 61 | label = str(txn.get(label_key)) 62 | if self.target_transform is not None: 63 | label = self.target_transform(label) 64 | 65 | return (img, label) 66 | 67 | 68 | class resizeNormalize(object): 69 | def __init__(self, size, interpolation=Image.BILINEAR): 70 | self.size = size 71 | self.interpolation = interpolation 72 | self.toTensor = transforms.ToTensor() 73 | 74 | def __call__(self, img): 75 | img = img.resize(self.size, self.interpolation) 76 | img = self.toTensor(img) 77 | img.sub_(0.5).div_(0.5) 78 | return img 79 | 80 | 81 | class randomSequentialSampler(sampler.Sampler): 82 | def __init__(self, data_source, batch_size): 83 | self.num_samples = len(data_source) 84 | self.batch_size = batch_size 85 | 86 | def __iter__(self): 87 | n_batch = len(self) // self.batch_size 88 | tail = len(self) % self.batch_size 89 | index = torch.LongTensor(len(self)).fill_(0) 90 | for i in range(n_batch): 91 | random_start = random.randint(0, len(self) - self.batch_size) 92 | batch_index = random_start + torch.range(0, self.batch_size - 1) 93 | index[i * self.batch_size:(i + 1) * self.batch_size] = batch_index 94 | # deal with tail 95 | if tail: 96 | random_start = random.randint(0, len(self) - self.batch_size) 97 | tail_index = random_start + torch.range(0, tail - 1) 98 | index[(i + 1) * self.batch_size:] = tail_index 99 | 100 | return iter(index) 101 | 102 | def __len__(self): 103 | return self.num_samples 104 | 105 | 106 | class alignCollate(object): 107 | def __init__(self, imgH=32, imgW=128, keep_ratio=False, min_ratio=1): 108 | self.imgH = imgH 109 | self.imgW = imgW 110 | self.keep_ratio = keep_ratio 111 | self.min_ratio = min_ratio 112 | 113 | def __call__(self, batch): 114 | images, labels = zip(*batch) 115 | 116 | imgH = self.imgH 117 | imgW = self.imgW 118 | if self.keep_ratio: 119 | ratios = [] 120 | for image in images: 121 | w, h = image.size 122 | ratios.append(w / float(h)) 123 | ratios.sort() 124 | max_ratio = ratios[-1] 125 | imgW = int(np.floor(max_ratio * imgH)) 126 | imgW = max(imgH * self.min_ratio, imgW) # assure imgH >= imgW 127 | 128 | transform = resizeNormalize((imgW, imgH)) 129 | images = [transform(image) for image in images] 130 | images = torch.cat([t.unsqueeze(0) for t in images], 0) 131 | 132 | return images, labels 133 | -------------------------------------------------------------------------------- /crnn/keys_crnn.py: -------------------------------------------------------------------------------- 1 | # coding:UTF-8 2 | alphabet = 
u'\'疗绚诚娇溜题贿者廖更纳加奉公一就汴计与路房原妇208-7其>:],,骑刈全消昏傈安久钟嗅不影处驽蜿资关椤地瘸专问忖票嫉炎韵要月田节陂鄙捌备拳伺眼网盎大傍心东愉汇蹿科每业里航晏字平录先13彤鲶产稍督腴有象岳注绍在泺文定核名水过理让偷率等这发”为含肥酉相鄱七编猥锛日镀蒂掰倒辆栾栗综涩州雌滑馀了机块司宰甙兴矽抚保用沧秩如收息滥页疑埠!!姥异橹钇向下跄的椴沫国绥獠报开民蜇何分凇长讥藏掏施羽中讲派嘟人提浼间世而古多倪唇饯控庚首赛蜓味断制觉技替艰溢潮夕钺外摘枋动双单啮户枇确锦曜杜或能效霜盒然侗电晁放步鹃新杖蜂吒濂瞬评总隍对独合也是府青天诲墙组滴级邀帘示已时骸仄泅和遨店雇疫持巍踮境只亨目鉴崤闲体泄杂作般轰化解迂诿蛭璀腾告版服省师小规程线海办引二桧牌砺洄裴修图痫胡许犊事郛基柴呼食研奶律蛋因葆察戏褒戒再李骁工貂油鹅章啄休场给睡纷豆器捎说敏学会浒设诊格廓查来霓室溆¢诡寥焕舜柒狐回戟砾厄实翩尿五入径惭喹股宇篝|;美期云九祺扮靠锝槌系企酰阊暂蚕忻豁本羹执条钦H獒限进季楦于芘玖铋茯未答粘括样精欠矢甥帷嵩扣令仔风皈行支部蓉刮站蜡救钊汗松嫌成可.鹤院从交政怕活调球局验髌第韫谗串到圆年米/*友忿检区看自敢刃个兹弄流留同没齿星聆轼湖什三建蛔儿椋汕震颧鲤跟力情璺铨陪务指族训滦鄣濮扒商箱十召慷辗所莞管护臭横硒嗓接侦六露党馋驾剖高侬妪幂猗绺骐央酐孝筝课徇缰门男西项句谙瞒秃篇教碲罚声呐景前富嘴鳌稀免朋啬睐去赈鱼住肩愕速旁波厅健茼厥鲟谅投攸炔数方击呋谈绩别愫僚躬鹧胪炳招喇膨泵蹦毛结54谱识陕粽婚拟构且搜任潘比郢妨醪陀桔碘扎选哈骷楷亿明缆脯监睫逻婵共赴淝凡惦及达揖谩澹减焰蛹番祁柏员禄怡峤龙白叽生闯起细装谕竟聚钙上导渊按艾辘挡耒盹饪臀记邮蕙受各医搂普滇朗茸带翻酚(光堤墟蔷万幻〓瑙辈昧盏亘蛀吉铰请子假闻税井诩哨嫂好面琐校馊鬣缂营访炖占农缀否经钚棵趟张亟吏茶谨捻论迸堂玉信吧瞠乡姬寺咬溏苄皿意赉宝尔钰艺特唳踉都荣倚登荐丧奇涵批炭近符傩感道着菊虹仲众懈濯颞眺南释北缝标既茗整撼迤贲挎耱拒某妍卫哇英矶藩治他元领膜遮穗蛾飞荒棺劫么市火温拈棚洼转果奕卸迪伸泳斗邡侄涨屯萋胭氡崮枞惧冒彩斜手豚随旭淑妞形菌吲沱争驯歹挟兆柱传至包内响临红功弩衡寂禁老棍耆渍织害氵渑布载靥嗬虽苹咨娄库雉榜帜嘲套瑚亲簸欧边6腿旮抛吹瞳得镓梗厨继漾愣憨士策窑抑躯襟脏参贸言干绸鳄穷藜音折详)举悍甸癌黎谴死罩迁寒驷袖媒蒋掘模纠恣观祖蛆碍位稿主澧跌筏京锏帝贴证糠才黄鲸略炯饱四出园犀牧容汉杆浈汰瑷造虫瘩怪驴济应花沣谔夙旅价矿以考su呦晒巡茅准肟瓴詹仟褂译桌混宁怦郑抿些余鄂饴攒珑群阖岔琨藓预环洮岌宀杲瀵最常囡周踊女鼓袭喉简范薯遐疏粱黜禧法箔斤遥汝奥直贞撑置绱集她馅逗钧橱魉[恙躁唤9旺膘待脾惫购吗依盲度瘿蠖俾之镗拇鲵厝簧续款展啃表剔品钻腭损清锶统涌寸滨贪链吠冈伎迥咏吁览防迅失汾阔逵绀蔑列川凭努熨揪利俱绉抢鸨我即责膦易毓鹊刹玷岿空嘞绊排术估锷违们苟铜播肘件烫审鲂广像铌惰铟巳胍鲍康憧色恢想拷尤疳知SYFDA峄裕帮握搔氐氘难墒沮雨叁缥悴藐湫娟苑稠颛簇后阕闭蕤缚怎佞码嘤蔡痊舱螯帕赫昵升烬岫、疵蜻髁蕨隶烛械丑盂梁强鲛由拘揉劭龟撤钩呕孛费妻漂求阑崖秤甘通深补赃坎床啪承吼量暇钼烨阂擎脱逮称P神属矗华届狍葑汹育患窒蛰佼静槎运鳗庆逝曼疱克代官此麸耧蚌晟例础榛副测唰缢迹灬霁身岁赭扛又菡乜雾板读陷徉贯郁虑变钓菜圾现琢式乐维渔浜左吾脑钡警T啵拴偌漱湿硕止骼魄积燥联踢玛|则窿见振畿送班钽您赵刨印讨踝籍谡舌崧汽蔽沪酥绒怖财帖肱私莎勋羔霸励哼帐将帅渠纪婴娩岭厘滕吻伤坝冠戊隆瘁介涧物黍并姗奢蹑掣垸锴命箍捉病辖琰眭迩艘绌繁寅若毋思诉类诈燮轲酮狂重反职筱县委磕绣奖晋濉志徽肠呈獐坻口片碰几村柿劳料获亩惕晕厌号罢池正鏖煨家棕复尝懋蜥锅岛扰队坠瘾钬@卧疣镇譬冰彷频黯据垄采八缪瘫型熹砰楠襁箐但嘶绳啤拍盥穆傲洗盯塘怔筛丿台恒喂葛永¥烟酒桦书砂蚝缉态瀚袄圳轻蛛超榧遛姒奘铮右荽望偻卡丶氰附做革索戚坨桷唁垅榻岐偎坛莨山殊微骇陈爨推嗝驹澡藁呤卤嘻糅逛侵郓酌德摇※鬃被慨殡羸昌泡戛鞋河宪沿玲鲨翅哽源铅语照邯址荃佬顺鸳町霭睾瓢夸椁晓酿痈咔侏券噎湍签嚷离午尚社锤背孟使浪缦潍鞅军姹驶笑鳟鲁》孽钜绿洱礴焯椰颖囔乌孔巴互性椽哞聘昨早暮胶炀隧低彗昝铁呓氽藉喔癖瑗姨权胱韦堑蜜酋楝砝毁靓歙锲究屋喳骨辨碑武鸠宫辜烊适坡殃培佩供走蜈迟翼况姣凛浔吃飘债犟金促苛崇坂莳畔绂兵蠕斋根砍亢欢恬崔剁餐榫快扶‖濒缠鳜当彭驭浦篮昀锆秸钳弋娣瞑夷龛苫拱致%嵊障隐弑初娓抉汩累蓖"唬助苓昙押毙破城郧逢嚏獭瞻溱婿赊跨恼璧萃姻貉灵炉密氛陶砸谬衔点琛沛枳层岱诺脍榈埂征冷裁打蹴素瘘逞蛐聊激腱萘踵飒蓟吆取咙簋涓矩曝挺揣座你史舵焱尘苏笈脚溉榨诵樊邓焊义庶儋蟋蒲赦呷杞诠豪还试颓茉太除紫逃痴草充鳕珉祗墨渭烩蘸慕璇镶穴嵘恶骂险绋幕碉肺戳刘潞秣纾潜銮洛须罘销瘪汞兮屉r林厕质探划狸殚善煊烹〒锈逯宸辍泱柚袍远蹋嶙绝峥娥缍雀徵认镱谷=贩勉撩鄯斐洋非祚泾诒饿撬威晷搭芍锥笺蓦候琊档礁沼卵荠忑朝凹瑞头仪弧孵畏铆突衲车浩气茂悖厢枕酝戴湾邹飚攘锂写宵翁岷无喜丈挑嗟绛殉议槽具醇淞笃郴阅饼底壕砚弈询缕庹翟零筷暨舟闺甯撞麂茌蔼很珲捕棠角阉媛娲诽剿尉爵睬韩诰匣危糍镯立浏阳少盆舔擘匪申尬铣旯抖赘瓯居ˇ哮游锭茏歌坏甚秒舞沙仗劲潺阿燧郭嗖霏忠材奂耐跺砀输岖媳氟极摆灿今扔腻枝奎药熄吨话q额慑嘌协喀壳埭视著於愧陲翌峁颅佛腹聋侯咎叟秀颇存较罪哄岗扫栏钾羌己璨枭霉煌涸衿键镝益岢奏连夯睿冥均糖狞蹊稻爸刿胥煜丽肿璃掸跚灾垂樾濑乎莲窄犹撮战馄软络显鸢胸宾妲恕埔蝌份遇巧瞟粒恰剥桡博讯凯堇阶滤卖斌骚彬兑磺樱舷两娱福仃差找桁÷净把阴污戬雷碓蕲楚罡焖抽妫咒仑闱尽邑菁爱贷沥鞑牡嗉崴骤塌嗦订拮滓捡锻次坪杩臃箬融珂鹗宗枚降鸬妯阄堰盐毅必杨崃俺甬状莘货耸菱腼铸唏痤孚澳懒溅翘疙杷淼缙骰喊悉砻坷艇赁界谤纣宴晃茹归饭梢铡街抄肼鬟苯颂撷戈炒咆茭瘙负仰客琉铢封卑珥椿镧窨鬲寿御袤铃萎砖餮脒裳肪孕嫣馗嵇恳氯江石褶冢祸阻狈羞银靳透咳叼敷芷啥它瓤兰痘懊逑肌往捺坊甩呻〃沦忘膻祟菅剧崆智坯臧霍墅攻眯倘拢骠铐庭岙瓠′缺泥迢捶??郏喙掷沌纯秘种听绘固螨团香盗妒埚蓝拖旱荞铀血遏汲辰叩拽幅硬惶桀漠措泼唑齐肾念酱虚屁耶旗砦闵婉馆拭绅韧忏窝醋葺顾辞倜堆辋逆玟贱疾董惘倌锕淘嘀莽俭笏绑鲷杈择蟀粥嗯驰逾案谪褓胫哩昕颚鲢绠躺鹄崂儒俨丝尕泌啊萸彰幺吟骄苣弦脊瑰〈诛镁析闪剪侧哟框螃守嬗燕狭铈缮概迳痧鲲俯售笼痣扉挖满咋援邱扇歪便玑绦峡蛇叨〖泽胃斓喋怂坟猪该蚬炕弥赞棣晔娠挲狡创疖铕镭稷挫弭啾翔粉履苘哦楼秕铂土锣瘟挣栉习享桢袅磨桂谦延坚蔚噗署谟猬钎恐嬉雒倦衅亏璩睹刻殿王算雕麻丘柯骆丸塍谚添鲈垓桎蚯芥予飕镦谌窗醚菀亮搪莺蒿羁足J真轶悬衷靛翊掩哒炅掐冼妮l谐稚荆擒犯陵虏浓崽刍陌傻孜千靖演矜钕煽杰酗渗伞栋俗泫戍罕沾疽灏煦芬磴叱阱榉湃蜀叉醒彪租郡篷屎良垢隗弱陨峪砷掴颁胎雯绵贬沐撵隘篙暖曹陡栓填臼彦瓶琪潼哪鸡摩啦俟锋域耻蔫疯纹撇毒绶痛酯忍爪赳歆嘹辕烈册朴钱吮毯癜娃谀邵厮炽璞邃丐追词瓒忆轧芫谯喷弟半冕裙掖墉绮寝苔势顷褥切衮君佳嫒蚩霞佚洙逊镖暹唛&殒顶碗獗轭铺蛊废恹汨崩珍那杵曲纺夏薰傀闳淬姘舀拧卷楂恍讪厩寮篪赓乘灭盅鞣沟慎挂饺鼾杳树缨丛絮娌臻嗳篡侩述衰矛圈蚜匕筹匿濞晨叶骋郝挚蚴滞增侍描瓣吖嫦蟒匾圣赌毡癞恺百曳需篓肮庖帏卿驿遗蹬鬓骡歉芎胳屐禽烦晌寄媾狄翡苒船廉终痞殇々畦饶改拆悻萄£瓿乃訾桅匮溧拥纱铍骗蕃龋缬父佐疚栎醍掳蓄x惆颜鲆榆〔猎敌暴谥鲫贾罗玻缄扦芪癣落徒臾恿猩托邴肄牵春陛耀刊拓蓓邳堕寇枉淌啡湄兽酷萼碚濠萤夹旬戮梭琥椭昔勺蜊绐晚孺僵宣摄冽旨萌忙蚤眉噼蟑付契瓜悼颡壁曾窕颢澎仿俑浑嵌浣乍碌褪乱蔟隙玩剐葫箫纲围伐决伙漩瑟刑肓镳缓蹭氨皓典畲坍铑檐塑洞倬储胴淳戾吐灼惺妙毕珐缈虱盖羰鸿磅谓髅娴苴唷蚣霹抨贤唠犬誓逍庠逼麓籼釉呜碧秧氩摔霄穸纨辟妈映完牛缴嗷炊恩荔茆掉紊慌莓羟阙萁磐另蕹辱鳐湮吡吩唐睦垠舒圜冗瞿溺芾囱匠僳汐菩饬漓黑霰浸濡窥毂蒡兢驻鹉芮诙迫雳厂忐臆猴鸣蚪栈箕羡渐莆捍眈哓趴蹼埕嚣骛宏淄斑噜严瑛垃椎诱压庾绞焘廿抡迄棘夫纬锹眨瞌侠脐竞瀑孳骧遁姜颦荪滚萦伪逸粳爬锁矣役趣洒颔诏逐奸甭惠攀蹄泛尼拼阮鹰亚颈惑勒〉际肛爷刚钨丰养冶鲽辉蔻画覆皴妊麦返醉皂擀〗酶凑粹悟诀硖港卜z杀涕±舍铠抵弛段敝镐奠拂轴跛袱et沉菇俎薪峦秭蟹历盟菠寡液肢喻染裱悱抱氙赤捅猛跑氮谣仁尺辊窍烙衍架擦倏璐瑁币楞胖夔趸邛惴饕虔蝎§哉贝宽辫炮扩饲籽魏菟锰伍猝末琳哚蛎邂呀姿鄞却歧仙恸椐森牒寤袒婆虢雅钉朵贼欲苞寰故龚坭嘘咫礼硷兀睢汶’铲烧绕诃浃钿哺柜讼颊璁腔洽咐脲簌筠镣玮鞠谁兼姆挥梯蝴谘漕刷躏宦弼b垌劈麟莉揭笙渎仕嗤仓配怏抬错泯镊孰猿邪仍秋鼬壹歇吵炼<尧射柬廷胧霾凳隋肚浮梦祥株堵退L鹫跎凶毽荟炫栩玳甜沂鹿顽伯爹赔蛴徐匡欣狰缸雹蟆疤默沤啜痂衣禅wih辽葳黝钗停沽棒馨颌肉吴硫
悯劾娈马啧吊悌镑峭帆瀣涉咸疸滋泣翦拙癸钥蜒+尾庄凝泉婢渴谊乞陆锉糊鸦淮IBN晦弗乔庥葡尻席橡傣渣拿惩麋斛缃矮蛏岘鸽姐膏催奔镒喱蠡摧钯胤柠拐璋鸥卢荡倾^_珀逄萧塾掇贮笆聂圃冲嵬M滔笕值炙偶蜱搐梆汪蔬腑鸯蹇敞绯仨祯谆梧糗鑫啸豺囹猾巢柄瀛筑踌沭暗苁鱿蹉脂蘖牢热木吸溃宠序泞偿拜檩厚朐毗螳吞媚朽担蝗橘畴祈糟盱隼郜惜珠裨铵焙琚唯咚噪骊丫滢勤棉呸咣淀隔蕾窈饨挨煅短匙粕镜赣撕墩酬馁豌颐抗酣氓佑搁哭递耷涡桃贻碣截瘦昭镌蔓氚甲猕蕴蓬散拾纛狼猷铎埋旖矾讳囊糜迈粟蚂紧鲳瘢栽稼羊锄斟睁桥瓮蹙祉醺鼻昱剃跳篱跷蒜翎宅晖嗑壑峻癫屏狠陋袜途憎祀莹滟佶溥臣约盛峰磁慵婪拦莅朕鹦粲裤哎疡嫖琵窟堪谛嘉儡鳝斩郾驸酊妄胜贺徙傅噌钢栅庇恋匝巯邈尸锚粗佟蛟薹纵蚊郅绢锐苗俞篆淆膀鲜煎诶秽寻涮刺怀噶巨褰魅灶灌桉藕谜舸薄搀恽借牯痉渥愿亓耘杠柩锔蚶钣珈喘蹒幽赐稗晤莱泔扯肯菪裆腩豉疆骜腐倭珏唔粮亡润慰伽橄玄誉醐胆龊粼塬陇彼削嗣绾芽妗垭瘴爽薏寨龈泠弹赢漪猫嘧涂恤圭茧烽屑痕巾赖荸凰腮畈亵蹲偃苇澜艮换骺烘苕梓颉肇哗悄氤涠葬屠鹭植竺佯诣鲇瘀鲅邦移滁冯耕癔戌茬沁巩悠湘洪痹锟循谋腕鳃钠捞焉迎碱伫急榷奈邝卯辄皲卟醛畹忧稳雄昼缩阈睑扌耗曦涅捏瞧邕淖漉铝耦禹湛喽莼琅诸苎纂硅始嗨傥燃臂赅嘈呆贵屹壮肋亍蚀卅豹腆邬迭浊}童螂捐圩勐触寞汊壤荫膺渌芳懿遴螈泰蓼蛤茜舅枫朔膝眙避梅判鹜璜牍缅垫藻黔侥惚懂踩腰腈札丞唾慈顿摹荻琬~斧沈滂胁胀幄莜Z匀鄄掌绰茎焚赋萱谑汁铒瞎夺蜗野娆冀弯篁懵灞隽芡脘俐辩芯掺喏膈蝈觐悚踹蔗熠鼠呵抓橼峨畜缔禾崭弃熊摒凸拗穹蒙抒祛劝闫扳阵醌踪喵侣搬仅荧赎蝾琦买婧瞄寓皎冻赝箩莫瞰郊笫姝筒枪遣煸袋舆痱涛母〇启践耙绲盘遂昊搞槿诬纰泓惨檬亻越Co憩熵祷钒暧塔阗胰咄娶魔琶钞邻扬杉殴咽弓〆髻】吭揽霆拄殖脆彻岩芝勃辣剌钝嘎甄佘皖伦授徕憔挪皇庞稔芜踏溴兖卒擢饥鳞煲‰账颗叻斯捧鳍琮讹蛙纽谭酸兔莒睇伟觑羲嗜宜褐旎辛卦诘筋鎏溪挛熔阜晰鳅丢奚灸呱献陉黛鸪甾萨疮拯洲疹辑叙恻谒允柔烂氏逅漆拎惋扈湟纭啕掬擞哥忽涤鸵靡郗瓷扁廊怨雏钮敦E懦憋汀拚啉腌岸f痼瞅尊咀眩飙忌仝迦熬毫胯篑茄腺凄舛碴锵诧羯後漏汤宓仞蚁壶谰皑铄棰罔辅晶苦牟闽\烃饮聿丙蛳朱煤涔鳖犁罐荼砒淦妤黏戎孑婕瑾戢钵枣捋砥衩狙桠稣阎肃梏诫孪昶婊衫嗔侃塞蜃樵峒貌屿欺缫阐栖诟珞荭吝萍嗽恂啻蜴磬峋俸豫谎徊镍韬魇晴U囟猜蛮坐囿伴亭肝佗蝠妃胞滩榴氖垩苋砣扪馏姓轩厉夥侈禀垒岑赏钛辐痔披纸碳“坞蠓挤荥沅悔铧帼蒌蝇apyng哀浆瑶凿桶馈皮奴苜佤伶晗铱炬优弊氢恃甫攥端锌灰稹炝曙邋亥眶碾拉萝绔捷浍腋姑菖凌涞麽锢桨潢绎镰殆锑渝铬困绽觎匈糙暑裹鸟盔肽迷綦『亳佝俘钴觇骥仆疝跪婶郯瀹唉脖踞针晾忒扼瞩叛椒疟嗡邗肆跆玫忡捣咧唆艄蘑潦笛阚沸泻掊菽贫斥髂孢镂赂麝鸾屡衬苷恪叠希粤爻喝茫惬郸绻庸撅碟宄妹膛叮饵崛嗲椅冤搅咕敛尹垦闷蝉霎勰败蓑泸肤鹌幌焦浠鞍刁舰乙竿裔。茵函伊兄丨娜匍謇莪宥似蝽翳酪翠粑薇祢骏赠叫Q噤噻竖芗莠潭俊羿耜O郫趁嗪囚蹶芒洁笋鹑敲硝啶堡渲揩』携宿遒颍扭棱割萜蔸葵琴捂饰衙耿掠募岂窖涟蔺瘤柞瞪怜匹距楔炜哆秦缎幼茁绪痨恨楸娅瓦桩雪嬴伏榔妥铿拌眠雍缇‘卓搓哌觞噩屈哧髓咦巅娑侑淫膳祝勾姊莴胄疃薛蜷胛巷芙芋熙闰勿窃狱剩钏幢陟铛慧靴耍k浙浇飨惟绗祜澈啼咪磷摞诅郦抹跃壬吕肖琏颤尴剡抠凋赚泊津宕殷倔氲漫邺涎怠$垮荬遵俏叹噢饽蜘孙筵疼鞭羧牦箭潴c眸祭髯啖坳愁芩驮倡巽穰沃胚怒凤槛剂趵嫁v邢灯鄢桐睽檗锯槟婷嵋圻诗蕈颠遭痢芸怯馥竭锗徜恭遍籁剑嘱苡龄僧桑潸弘澶楹悲讫愤腥悸谍椹呢桓葭攫阀翰躲敖柑郎笨橇呃魁燎脓葩磋垛玺狮沓砜蕊锺罹蕉翱虐闾巫旦茱嬷枯鹏贡芹汛矫绁拣禺佃讣舫惯乳趋疲挽岚虾衾蠹蹂飓氦铖孩稞瑜壅掀勘妓畅髋W庐牲蓿榕练垣唱邸菲昆婺穿绡麒蚱掂愚泷涪漳妩娉榄讷觅旧藤煮呛柳腓叭庵烷阡罂蜕擂猖咿媲脉【沏貅黠熏哲烁坦酵兜×潇撒剽珩圹乾摸樟帽嗒襄魂轿憬锡〕喃皆咖隅脸残泮袂鹂珊囤捆咤误徨闹淙芊淋怆囗拨梳渤RG绨蚓婀幡狩麾谢唢裸旌伉纶裂驳砼咛澄樨蹈宙澍倍貔操勇蟠摈砧虬够缁悦藿撸艹摁淹豇虎榭ˉ吱d°喧荀踱侮奋偕饷犍惮坑璎徘宛妆袈倩窦昂荏乖K怅撰鳙牙袁酞X痿琼闸雁趾荚虻涝《杏韭偈烤绫鞘卉症遢蓥诋杭荨匆竣簪辙敕虞丹缭咩黟m淤瑕咂铉硼茨嶂痒畸敬涿粪窘熟叔嫔盾忱裘憾梵赡珙咯娘庙溯胺葱痪摊荷卞乒髦寐铭坩胗枷爆溟嚼羚砬轨惊挠罄竽菏氧浅楣盼枢炸阆杯谏噬淇渺俪秆墓泪跻砌痰垡渡耽釜讶鳎煞呗韶舶绷鹳缜旷铊皱龌檀霖奄槐艳蝶旋哝赶骞蚧腊盈丁`蜚矸蝙睨嚓僻鬼醴夜彝磊笔拔栀糕厦邰纫逭纤眦膊馍躇烯蘼冬诤暄骶哑瘠」臊丕愈咱螺擅跋搏硪谄笠淡嘿骅谧鼎皋姚歼蠢驼耳胬挝涯狗蒽孓犷凉芦箴铤孤嘛坤V茴朦挞尖橙诞搴碇洵浚帚蜍漯柘嚎讽芭荤咻祠秉跖埃吓糯眷馒惹娼鲑嫩讴轮瞥靶褚乏缤宋帧删驱碎扑俩俄偏涣竹噱皙佰渚唧斡#镉刀崎筐佣夭贰肴峙哔艿匐牺镛缘仡嫡劣枸堀梨簿鸭蒸亦稽浴{衢束槲j阁揍疥棋潋聪窜乓睛插冉阪苍搽「蟾螟幸仇樽撂慢跤幔俚淅覃觊溶妖帛侨曰妾泗' 3 | -------------------------------------------------------------------------------- /crnn/models/crnn.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(1, "./crnn") 3 | import torch.nn as nn 4 | import utils 5 | 6 | 7 | class BidirectionalLSTM(nn.Module): 8 | def __init__(self, nIn, nHidden, nOut, ngpu): 9 | super(BidirectionalLSTM, self).__init__() 10 | self.ngpu = ngpu 11 | 12 | self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True) 13 | self.embedding = nn.Linear(nHidden * 2, nOut) 14 | 15 | def forward(self, input): 16 | recurrent, _ = utils.data_parallel(self.rnn, input, 17 | self.ngpu) # [T, b, h * 2] 18 | 19 | T, b, h = recurrent.size() 20 | t_rec = recurrent.view(T * b, h) 21 | output = utils.data_parallel(self.embedding, t_rec, 22 | self.ngpu) # [T * b, nOut] 23 | output = output.view(T, b, -1) 24 | 25 | return output 26 | 27 | 28 | class CRNN(nn.Module): 29 | def __init__(self, imgH, nc, nclass, nh, ngpu, n_rnn=2, leakyRelu=False): 30 | super(CRNN, self).__init__() 31 | self.ngpu = ngpu 32 | assert imgH % 16 == 0, 'imgH has to be a multiple of 16' 33 | 34 | ks = [3, 3, 3, 3, 3, 3, 2] 35 | ps = [1, 1, 1, 1, 1, 1, 0] 36 | ss = [1, 1, 1, 1, 1, 1, 1] 37 | nm = [64, 128, 256, 256, 512, 512, 512] 38 | 39 | cnn = nn.Sequential() 40 | 41 | def convRelu(i, batchNormalization=False): 42 | nIn = nc if i == 0 else nm[i - 1] 43 | nOut = nm[i] 44 | cnn.add_module('conv{0}'.format(i), 45 | nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i])) 46 | if batchNormalization: 47 | 
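                # BatchNorm2d is inserted, before the activation, only for
                # layers built with batchNormalization=True (convRelu(2),
                # convRelu(4) and convRelu(6) below).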
cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut)) 48 | if leakyRelu: 49 | cnn.add_module('relu{0}'.format(i), 50 | nn.LeakyReLU(0.2, inplace=True)) 51 | else: 52 | cnn.add_module('relu{0}'.format(i), nn.ReLU(True)) 53 | 54 | convRelu(0) 55 | cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2)) # 64x16x64 56 | convRelu(1) 57 | cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2)) # 128x8x32 58 | convRelu(2, True) 59 | convRelu(3) 60 | cnn.add_module('pooling{0}'.format(2), 61 | nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 256x4x16 62 | convRelu(4, True) 63 | convRelu(5) 64 | cnn.add_module('pooling{0}'.format(3), 65 | nn.MaxPool2d((2, 2), (2, 1), (0, 1))) # 512x2x16 66 | convRelu(6, True) # 512x1x16 67 | 68 | self.cnn = cnn 69 | self.rnn = nn.Sequential( 70 | BidirectionalLSTM(512, nh, nh, ngpu), 71 | BidirectionalLSTM(nh, nh, nclass, ngpu)) 72 | 73 | def forward(self, input): 74 | # conv features 75 | conv = utils.data_parallel(self.cnn, input, self.ngpu) 76 | b, c, h, w = conv.size() 77 | assert h == 1, "the height of conv must be 1" 78 | conv = conv.squeeze(2) 79 | conv = conv.permute(2, 0, 1) # [w, b, c] 80 | 81 | # rnn features 82 | output = utils.data_parallel(self.rnn, conv, self.ngpu) 83 | 84 | return output 85 | -------------------------------------------------------------------------------- /crnn/models/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # encoding: utf-8 3 | 4 | import torch.nn as nn 5 | import torch.nn.parallel 6 | 7 | 8 | def data_parallel(model, input, ngpu): 9 | if isinstance(input.data, torch.cuda.FloatTensor) and ngpu > 1: 10 | output = nn.parallel.data_parallel(model, input, range(ngpu)) 11 | else: 12 | output = model(input) 13 | return output 14 | -------------------------------------------------------------------------------- /ctpn/README: -------------------------------------------------------------------------------- 1 | # 文字区域检测CTPN 2 | 支持CPU、GPU环境,一键部署 3 | -------------------------------------------------------------------------------- /ctpn/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ctpn/__pycache__/text_detect.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/__pycache__/text_detect.cpython-36.pyc -------------------------------------------------------------------------------- /ctpn/ctpn/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /ctpn/ctpn/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/ctpn/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ctpn/ctpn/__pycache__/cfg.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/ctpn/__pycache__/cfg.cpython-36.pyc -------------------------------------------------------------------------------- /ctpn/ctpn/__pycache__/detectors.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/ctpn/__pycache__/detectors.cpython-36.pyc -------------------------------------------------------------------------------- /ctpn/ctpn/__pycache__/other.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/ctpn/__pycache__/other.cpython-36.pyc -------------------------------------------------------------------------------- /ctpn/ctpn/cfg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Config: 5 | MEAN = np.float32([102.9801, 115.9465, 122.7717]) 6 | # MEAN=np.float32([100.0, 100.0, 100.0]) 7 | TEST_GPU_ID = 0 8 | SCALE = 900 9 | MAX_SCALE = 1500 10 | TEXT_PROPOSALS_WIDTH = 0 11 | MIN_RATIO = 0.01 12 | LINE_MIN_SCORE = 0.6 13 | TEXT_LINE_NMS_THRESH = 0.3 14 | MAX_HORIZONTAL_GAP = 30 15 | TEXT_PROPOSALS_MIN_SCORE = 0.7 16 | TEXT_PROPOSALS_NMS_THRESH = 0.3 17 | MIN_NUM_PROPOSALS = 0 18 | MIN_V_OVERLAPS = 0.6 19 | MIN_SIZE_SIM = 0.6 20 | -------------------------------------------------------------------------------- /ctpn/ctpn/demo.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import shutil 4 | import sys 5 | 6 | import cv2 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 11 | sys.path.append(parentdir) 12 | 13 | from lib.networks.factory import get_network 14 | from lib.fast_rcnn.config import cfg 15 | from lib.fast_rcnn.test import test_ctpn 16 | from lib.fast_rcnn.nms_wrapper import nms 17 | from lib.utils.timer import Timer 18 | from text_proposal_connector import TextProposalConnector 19 | 20 | CLASSES = ('__background__', 'text') 21 | 22 | 23 | def connect_proposal(text_proposals, scores, im_size): 24 | cp = TextProposalConnector() 25 | line = cp.get_text_lines(text_proposals, scores, im_size) 26 | return line 27 | 28 | 29 | def save_results(image_name, im, line, thresh): 30 | inds = np.where(line[:, -1] >= thresh)[0] 31 | if len(inds) == 0: 32 | return 33 | 34 | for i in inds: 35 | bbox = line[i, :4] 36 | score = line[i, -1] 37 | cv2.rectangle( 38 | im, (bbox[0], bbox[1]), (bbox[2], bbox[3]), 39 | color=(0, 0, 255), 40 | thickness=1) 41 | image_name = image_name.split('/')[-1] 42 | cv2.imwrite(os.path.join("../data/results", image_name), im) 43 | 44 | 45 | def check_img(im): 46 | im_size = im.shape 47 | if max(im_size[0:2]) < 600: 48 | img = np.zeros((600, 600, 3), dtype=np.uint8) 49 | start_row = int((600 - im_size[0]) / 2) 50 | start_col = int((600 - im_size[1]) / 2) 51 | end_row = start_row + im_size[0] 52 | end_col = start_col + im_size[1] 53 | img[start_row:end_row, start_col:end_col, :] = im 54 | return img 55 | else: 56 | return im 57 | 58 | 59 | def ctpn(sess, net, image_name): 60 | img = cv2.imread(image_name) 61 | im = check_img(img) 62 | timer = Timer() 63 | timer.tic() 64 | scores, boxes = test_ctpn(sess, net, im) 65 | timer.toc() 66 | # print('Detection took {:.3f}s 
for ' 67 | # '{:d} object proposals').format(timer.total_time, boxes.shape[0]) 68 | 69 | # Visualize detections for each class 70 | CONF_THRESH = 0.9 71 | NMS_THRESH = 0.3 72 | dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32) 73 | keep = nms(dets, NMS_THRESH) 74 | dets = dets[keep, :] 75 | 76 | keep = np.where(dets[:, 4] >= 0.7)[0] 77 | dets = dets[keep, :] 78 | line = connect_proposal(dets[:, 0:4], dets[:, 4], im.shape) 79 | save_results(image_name, im, line, thresh=0.9) 80 | 81 | 82 | if __name__ == '__main__': 83 | if os.path.exists("../data/results/"): 84 | shutil.rmtree("../data/results/") 85 | os.makedirs("../data/results/") 86 | 87 | cfg.TEST.HAS_RPN = True # Use RPN for proposals 88 | # init session 89 | config = tf.ConfigProto(allow_soft_placement=True) 90 | sess = tf.Session(config=config) 91 | # load network 92 | net = get_network("VGGnet_test") 93 | # load model 94 | print('Loading network {:s}... '.format("VGGnet_test")), 95 | saver = tf.train.Saver() 96 | # saver.restore(sess, 97 | # os.path.join(os.getcwd(), "checkpoints/model_final.ckpt")) 98 | saver.restore(sess, 99 | os.path.join(os.getcwd(), 100 | "/Users/xiaofeng/Code/Github/dataset/CHINESE_OCR/ctpn/checkpoints/VGGnet_fast_rcnn_iter_50000.ckpt")) 101 | print(' done.') 102 | 103 | # Warmup on a dummy image 104 | im = 128 * np.ones((300, 300, 3), dtype=np.uint8) 105 | for i in range(2): 106 | _, _ = test_ctpn(sess, net, im) 107 | 108 | im_names = glob.glob(os.path.join(cfg.DATA_DIR, 'demo', '*.png')) + \ 109 | glob.glob(os.path.join(cfg.DATA_DIR, 'demo', '*.jpg')) 110 | 111 | for im_name in im_names: 112 | print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~') 113 | print('Demo for {:s}'.format(im_name)) 114 | ctpn(sess, net, im_name) 115 | -------------------------------------------------------------------------------- /ctpn/ctpn/detectors.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import sys 3 | 4 | import numpy as np 5 | 6 | from .cfg import Config as cfg 7 | from .other import normalize 8 | 9 | sys.path.append('..') 10 | from ..lib.fast_rcnn.nms_wrapper import nms 11 | # from lib.fast_rcnn.test import test_ctpn 12 | 13 | from .text_proposal_connector import TextProposalConnector 14 | 15 | 16 | class TextDetector: 17 | """ 18 | Detect text from an image 19 | """ 20 | 21 | def __init__(self): 22 | """ 23 | pass 24 | """ 25 | self.text_proposal_connector = TextProposalConnector() 26 | 27 | def detect(self, text_proposals, scores, size): 28 | """ 29 | Detecting texts from an image 30 | :return: the bounding boxes of the detected texts 31 | """ 32 | # text_proposals, scores=self.text_proposal_detector.detect(im, cfg.MEAN) 33 | keep_inds = np.where(scores > cfg.TEXT_PROPOSALS_MIN_SCORE)[0] 34 | text_proposals, scores = text_proposals[keep_inds], scores[keep_inds] 35 | 36 | sorted_indices = np.argsort(scores.ravel())[::-1] 37 | text_proposals, scores = text_proposals[sorted_indices], scores[sorted_indices] 38 | 39 | # nms for text proposals 40 | keep_inds = nms(np.hstack((text_proposals, scores)), cfg.TEXT_PROPOSALS_NMS_THRESH) 41 | text_proposals, scores = text_proposals[keep_inds], scores[keep_inds] 42 | 43 | scores = normalize(scores) 44 | 45 | text_lines = self.text_proposal_connector.get_text_lines(text_proposals, scores, size) 46 | 47 | keep_inds = self.filter_boxes(text_lines) 48 | text_lines = text_lines[keep_inds] 49 | 50 | if text_lines.shape[0] != 0: 51 | keep_inds = nms(text_lines, cfg.TEXT_LINE_NMS_THRESH) 52 | text_lines = 
text_lines[keep_inds] 53 | 54 | return text_lines 55 | 56 | def filter_boxes(self, boxes): 57 | heights = boxes[:, 3] - boxes[:, 1] + 1 58 | widths = boxes[:, 2] - boxes[:, 0] + 1 59 | scores = boxes[:, -1] 60 | return np.where((widths / heights > cfg.MIN_RATIO) & (scores > cfg.LINE_MIN_SCORE) & 61 | (widths > (cfg.TEXT_PROPOSALS_WIDTH * cfg.MIN_NUM_PROPOSALS)))[0] 62 | -------------------------------------------------------------------------------- /ctpn/ctpn/model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | import tensorflow as tf 5 | 6 | from .cfg import Config 7 | from .other import resize_im 8 | 9 | sys.path.append(os.getcwd()) 10 | from lib.fast_rcnn.config import cfg 11 | from lib.networks.factory import get_network 12 | from lib.fast_rcnn.test import test_ctpn 13 | 14 | # from ..lib.networks.factory import get_network 15 | # from ..lib.fast_rcnn.config import cfg 16 | # from..lib.fast_rcnn.test import test_ctpn 17 | ''' 18 | load network 19 | 输入的名称为'Net_model' 20 | 'VGGnet_test'--test 21 | 'VGGnet_train'-train 22 | ''' 23 | 24 | 25 | def load_tf_model(): 26 | cfg.TEST.HAS_RPN = True # Use RPN for proposals 27 | # init session 28 | config = tf.ConfigProto(allow_soft_placement=True) 29 | net = get_network("VGGnet_test") 30 | # load model 31 | saver = tf.train.Saver() 32 | # sess = tf.Session(config=config) 33 | sess = tf.Session() 34 | ckpt = tf.train.get_checkpoint_state( 35 | '/Users/xiaofeng/Code/Github/dataset/CHINESE_OCR/ctpn/ctpn_checkpoints/') 36 | reader = tf.train.NewCheckpointReader(ckpt.model_checkpoint_path) 37 | var_to_shape_map = reader.get_variable_to_shape_map() 38 | for key in var_to_shape_map: 39 | print("Tensor_name is : ", key) 40 | # print(reader.get_tensor(key)) 41 | saver.restore(sess, ckpt.model_checkpoint_path) 42 | print("load vggnet done") 43 | return sess, saver, net 44 | 45 | 46 | # init model 47 | sess, saver, net = load_tf_model() 48 | 49 | 50 | # 进行文本识别 51 | def ctpn(img): 52 | """ 53 | text box detect 54 | """ 55 | scale, max_scale = Config.SCALE, Config.MAX_SCALE 56 | # 对图像进行resize,输出的图像长宽 57 | img, f = resize_im(img, scale=scale, max_scale=max_scale) 58 | scores, boxes = test_ctpn(sess, net, img) 59 | return scores, boxes, img 60 | -------------------------------------------------------------------------------- /ctpn/ctpn/other.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from matplotlib import cm 4 | 5 | 6 | def prepare_img(im, mean): 7 | """ 8 | transform img into caffe's input img. 
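    Subtracts the per-channel mean and transposes the image from HWC to CHW,
    the channel ordering Caffe expects.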
9 | """ 10 | im_data = np.transpose(im - mean, (2, 0, 1)) 11 | return im_data 12 | 13 | 14 | def draw_boxes(im, 15 | bboxes, 16 | is_display=True, 17 | color=None, 18 | caption="Image", 19 | wait=True): 20 | """ 21 | boxes: bounding boxes 22 | """ 23 | text_recs = np.zeros((len(bboxes), 8), np.int) 24 | 25 | im = im.copy() 26 | index = 0 27 | for box in bboxes: 28 | if color == None: 29 | if len(box) == 8 or len(box) == 9: 30 | c = tuple(cm.jet([box[-1]])[0, 2::-1] * 255) 31 | else: 32 | c = tuple(np.random.randint(0, 256, 3)) 33 | else: 34 | c = color 35 | 36 | b1 = box[6] - box[7] / 2 37 | b2 = box[6] + box[7] / 2 38 | x1 = box[0] 39 | y1 = box[5] * box[0] + b1 40 | x2 = box[2] 41 | y2 = box[5] * box[2] + b1 42 | x3 = box[0] 43 | y3 = box[5] * box[0] + b2 44 | x4 = box[2] 45 | y4 = box[5] * box[2] + b2 46 | 47 | disX = x2 - x1 48 | disY = y2 - y1 49 | width = np.sqrt(disX * disX + disY * disY) 50 | fTmp0 = y3 - y1 51 | fTmp1 = fTmp0 * disY / width 52 | x = np.fabs(fTmp1 * disX / width) 53 | y = np.fabs(fTmp1 * disY / width) 54 | if box[5] < 0: 55 | x1 -= x 56 | y1 += y 57 | x4 += x 58 | y4 -= y 59 | else: 60 | x2 += x 61 | y2 += y 62 | x3 -= x 63 | y3 -= y 64 | cv2.line(im, (int(x1), int(y1)), (int(x2), int(y2)), c, 2) 65 | cv2.line(im, (int(x1), int(y1)), (int(x3), int(y3)), c, 2) 66 | cv2.line(im, (int(x4), int(y4)), (int(x2), int(y2)), c, 2) 67 | cv2.line(im, (int(x3), int(y3)), (int(x4), int(y4)), c, 2) 68 | text_recs[index, 0] = x1 69 | text_recs[index, 1] = y1 70 | text_recs[index, 2] = x2 71 | text_recs[index, 3] = y2 72 | text_recs[index, 4] = x3 73 | text_recs[index, 5] = y3 74 | text_recs[index, 6] = x4 75 | text_recs[index, 7] = y4 76 | index = index + 1 77 | # cv2.rectangle(im, tuple(box[:2]), tuple(box[2:4]), c,2) 78 | # cv2.waitKey(0) 79 | # cv2.imshow('kk', im) 80 | cv2.imwrite('/Users/xiaofeng/Code/Github/Chinese-OCR/test/test_result.png',im) 81 | 82 | return text_recs, im 83 | 84 | 85 | def threshold(coords, min_, max_): 86 | return np.maximum(np.minimum(coords, max_), min_) 87 | 88 | 89 | def clip_boxes(boxes, im_shape): 90 | """ 91 | Clip boxes to image boundaries. 
92 | """ 93 | boxes[:, 0::2] = threshold(boxes[:, 0::2], 0, im_shape[1] - 1) 94 | boxes[:, 1::2] = threshold(boxes[:, 1::2], 0, im_shape[0] - 1) 95 | return boxes 96 | 97 | 98 | def normalize(data): 99 | if data.shape[0] == 0: 100 | return data 101 | max_ = data.max() 102 | min_ = data.min() 103 | return (data - min_) / (max_ - min_) if max_ - min_ != 0 else data - min_ 104 | 105 | 106 | def resize_im(im, scale, max_scale=None): 107 | # 按照scale和图片的长宽的最小值的比值作为输入模型的图片的尺寸 108 | f = float(scale) / min(im.shape[0], im.shape[1]) 109 | if max_scale != None and f * max(im.shape[0], im.shape[1]) > max_scale: 110 | f = float(max_scale) / max(im.shape[0], im.shape[1]) 111 | return cv2.resize(im, (0, 0), fx=f, fy=f), f 112 | # return cv2.resize(im, (0, 0), fx=1.2, fy=1.2), f 113 | 114 | 115 | class Graph: 116 | def __init__(self, graph): 117 | self.graph = graph 118 | 119 | def sub_graphs_connected(self): 120 | sub_graphs = [] 121 | for index in range(self.graph.shape[0]): 122 | if not self.graph[:, index].any() and self.graph[index, :].any(): 123 | v = index 124 | sub_graphs.append([v]) 125 | while self.graph[v, :].any(): 126 | v = np.where(self.graph[v, :])[0][0] 127 | sub_graphs[-1].append(v) 128 | return sub_graphs 129 | -------------------------------------------------------------------------------- /ctpn/ctpn/text.yml: -------------------------------------------------------------------------------- 1 | EXP_DIR: ctpn_end2end 2 | LOG_DIR: ctpn 3 | IS_MULTISCALE: False 4 | NET_NAME: VGGnet 5 | ANCHOR_SCALES: [16] 6 | NCLASSES: 2 7 | TRAIN: 8 | OHEM: False 9 | RPN_BATCHSIZE: 300 10 | BATCH_SIZE: 300 11 | LOG_IMAGE_ITERS: 100 12 | DISPLAY: 10 13 | SNAPSHOT_ITERS: 1000 14 | HAS_RPN: True 15 | LEARNING_RATE: 0.001 16 | MOMENTUM: 0.9 17 | GAMMA: 0.1 18 | STEPSIZE: 90000 19 | IMS_PER_BATCH: 1 20 | BBOX_NORMALIZE_TARGETS_PRECOMPUTED: True 21 | RPN_POSITIVE_OVERLAP: 0.7 22 | RPN_BATCHSIZE: 256 23 | PROPOSAL_METHOD: gt 24 | BG_THRESH_LO: 0.0 25 | PRECLUDE_HARD_SAMPLES: True 26 | BBOX_INSIDE_WEIGHTS: [1, 1, 1, 1] 27 | RPN_BBOX_INSIDE_WEIGHTS: [1, 1, 1, 1] 28 | RPN_POSITIVE_WEIGHT: -1.0 29 | FG_FRACTION: 0.3 30 | WEIGHT_DECAY: 0.0005 31 | TEST: 32 | HAS_RPN: True 33 | -------------------------------------------------------------------------------- /ctpn/ctpn/text_proposal_connector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os, sys 3 | parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | sys.path.append(parentdir) 5 | 6 | from .text_proposal_graph_builder import TextProposalGraphBuilder 7 | 8 | 9 | class TextProposalConnector: 10 | """ 11 | Connect text proposals into text lines 12 | """ 13 | 14 | def __init__(self): 15 | self.graph_builder = TextProposalGraphBuilder() 16 | 17 | def group_text_proposals(self, text_proposals, scores, im_size): 18 | graph = self.graph_builder.build_graph(text_proposals, scores, im_size) 19 | return graph.sub_graphs_connected() 20 | 21 | def fit_y(self, X, Y, x1, x2): 22 | len(X) != 0 23 | # if X only include one point, the function will get line y=Y[0] 24 | if np.sum(X == X[0]) == len(X): 25 | return Y[0], Y[0] 26 | p = np.poly1d(np.polyfit(X, Y, 1)) 27 | return p(x1), p(x2) 28 | 29 | def get_text_lines(self, text_proposals, scores, im_size): 30 | """ 31 | text_proposals:boxes 32 | 33 | """ 34 | # tp=text proposal 35 | tp_groups = self.group_text_proposals(text_proposals, scores, 36 | im_size) ##find the text line 37 | 38 | text_lines = np.zeros((len(tp_groups), 8), np.float32) 39 
| 40 | for index, tp_indices in enumerate(tp_groups): 41 | text_line_boxes = text_proposals[list(tp_indices)] 42 | num = np.size(text_line_boxes) ##find 43 | X = (text_line_boxes[:, 0] + text_line_boxes[:, 2]) / 2 44 | Y = (text_line_boxes[:, 1] + text_line_boxes[:, 3]) / 2 45 | 46 | z1 = np.polyfit(X, Y, 1) 47 | p1 = np.poly1d(z1) 48 | 49 | x0 = np.min(text_line_boxes[:, 0]) 50 | x1 = np.max(text_line_boxes[:, 2]) 51 | 52 | offset = (text_line_boxes[0, 2] - text_line_boxes[0, 0]) * 0.5 53 | 54 | lt_y, rt_y = self.fit_y(text_line_boxes[:, 0], 55 | text_line_boxes[:, 1], x0 + offset, 56 | x1 - offset) 57 | lb_y, rb_y = self.fit_y(text_line_boxes[:, 0], 58 | text_line_boxes[:, 3], x0 + offset, 59 | x1 - offset) 60 | 61 | # the score of a text line is the average score of the scores 62 | # of all text proposals contained in the text line 63 | score = scores[list(tp_indices)].sum() / float(len(tp_indices)) 64 | 65 | text_lines[index, 0] = x0 66 | text_lines[index, 1] = min(lt_y, rt_y) 67 | text_lines[index, 2] = x1 68 | text_lines[index, 3] = max(lb_y, rb_y) 69 | text_lines[index, 4] = score 70 | text_lines[index, 5] = z1[0] 71 | text_lines[index, 6] = z1[1] 72 | height = np.mean((text_line_boxes[:, 3] - text_line_boxes[:, 1])) 73 | text_lines[index, 7] = height + 2.5 74 | # text_lines=clip_boxes(text_lines, im_size) 75 | 76 | return text_lines 77 | -------------------------------------------------------------------------------- /ctpn/ctpn/text_proposal_graph_builder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import os, sys 4 | parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 5 | sys.path.append(parentdir) 6 | 7 | from .cfg import Config as cfg 8 | from .other import Graph 9 | 10 | 11 | class TextProposalGraphBuilder: 12 | """ 13 | Build Text proposals into a graph. 
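    Each proposal is paired with its best-scoring successor within
    MAX_HORIZONTAL_GAP pixels, subject to vertical-overlap (MIN_V_OVERLAPS)
    and height-similarity (MIN_SIZE_SIM) checks; connected components of the
    resulting graph become text lines.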
14 | """ 15 | 16 | def get_successions(self, index): 17 | box = self.text_proposals[index] 18 | results = [] 19 | for left in range(int(box[0]) + 1, min(int(box[0]) + cfg.MAX_HORIZONTAL_GAP + 1, self.im_size[1])): 20 | adj_box_indices = self.boxes_table[left] 21 | for adj_box_index in adj_box_indices: 22 | if self.meet_v_iou(adj_box_index, index): 23 | results.append(adj_box_index) 24 | if len(results) != 0: 25 | return results 26 | return results 27 | 28 | def get_precursors(self, index): 29 | box = self.text_proposals[index] 30 | results = [] 31 | for left in range(int(box[0]) - 1, max(int(box[0] - cfg.MAX_HORIZONTAL_GAP), 0) - 1, -1): 32 | adj_box_indices = self.boxes_table[left] 33 | for adj_box_index in adj_box_indices: 34 | if self.meet_v_iou(adj_box_index, index): 35 | results.append(adj_box_index) 36 | if len(results) != 0: 37 | return results 38 | return results 39 | 40 | def is_succession_node(self, index, succession_index): 41 | precursors = self.get_precursors(succession_index) 42 | if self.scores[index] >= np.max(self.scores[precursors]): 43 | return True 44 | return False 45 | 46 | def meet_v_iou(self, index1, index2): 47 | def overlaps_v(index1, index2): 48 | h1 = self.heights[index1] 49 | h2 = self.heights[index2] 50 | y0 = max(self.text_proposals[index2][1], self.text_proposals[index1][1]) 51 | y1 = min(self.text_proposals[index2][3], self.text_proposals[index1][3]) 52 | return max(0, y1 - y0 + 1) / min(h1, h2) 53 | 54 | def size_similarity(index1, index2): 55 | h1 = self.heights[index1] 56 | h2 = self.heights[index2] 57 | return min(h1, h2) / max(h1, h2) 58 | 59 | return overlaps_v(index1, index2) >= cfg.MIN_V_OVERLAPS and \ 60 | size_similarity(index1, index2) >= cfg.MIN_SIZE_SIM 61 | 62 | def build_graph(self, text_proposals, scores, im_size): 63 | self.text_proposals = text_proposals 64 | self.scores = scores 65 | self.im_size = im_size 66 | self.heights = text_proposals[:, 3] - text_proposals[:, 1] + 1 67 | 68 | boxes_table = [[] for _ in range(self.im_size[1])] 69 | for index, box in enumerate(text_proposals): 70 | boxes_table[int(box[0])].append(index) 71 | self.boxes_table = boxes_table 72 | 73 | graph = np.zeros((text_proposals.shape[0], text_proposals.shape[0]), np.bool) 74 | 75 | for index, box in enumerate(text_proposals): 76 | successions = self.get_successions(index) 77 | if len(successions) == 0: 78 | continue 79 | succession_index = successions[np.argmax(scores[successions])] 80 | if self.is_succession_node(index, succession_index): 81 | # NOTE: a box can have multiple successions(precursors) if multiple successions(precursors) 82 | # have equal scores. 
83 | graph[index, succession_index] = True 84 | return Graph(graph) 85 | -------------------------------------------------------------------------------- /ctpn/ctpn/train_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # _Author_: xiaofeng 4 | # Date: 2018-04-16 10:55:15 5 | # Last Modified by: xiaofeng 6 | # Last Modified time: 2018-04-16 10:55:15 7 | ''' 8 | 使用keras进行网络训练,速度相对pytorch比较慢 9 | ''' 10 | import os.path as osp 11 | import pprint 12 | import sys, os 13 | 14 | # sys.path.append(os.getcwd()) 15 | # this_dir = os.path.dirname(__file__) 16 | parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 17 | sys.path.append(parentdir) 18 | 19 | from lib.fast_rcnn.train import get_training_roidb, train_net 20 | from lib.fast_rcnn.config import cfg_from_file, get_output_dir, get_log_dir 21 | from lib.datasets.factory import get_imdb 22 | from lib.networks.factory import get_network 23 | from lib.fast_rcnn.config import cfg 24 | 25 | if __name__ == '__main__': 26 | # 将text.yml的配置与默认config中的默认配置进行合并 27 | cfg_from_file('text.yml') 28 | print('Using config:~~~~~~~~~~~~~~~~') 29 | # 根据给定的名字,得到要加载的数据集 30 | imdb = get_imdb('voc_2007_trainval') 31 | print('Loaded dataset `{:s}` for training'.format(imdb.name)) 32 | # 准备训练数据 33 | roidb = get_training_roidb(imdb) 34 | # 模型输出的路径 35 | output_dir = get_output_dir(imdb, None) 36 | # summary的输出路径 37 | log_dir = get_log_dir(imdb) 38 | print('Output will be saved to `{:s}`'.format(output_dir)) 39 | print('Logs will be saved to `{:s}`'.format(log_dir)) 40 | 41 | device_name = '/gpu:0' 42 | print(device_name) 43 | 44 | network = get_network('VGGnet_train') 45 | 46 | train_net( 47 | network, 48 | imdb, 49 | roidb, 50 | output_dir=output_dir, 51 | log_dir=log_dir, 52 | pretrained_model= 53 | '/Users/xiaofeng/Code/Github/dataset/CHINESE_OCR/ctpn/pretrain/VGG_imagenet.npy', 54 | max_iters=180000, 55 | restore=bool(int(0))) 56 | -------------------------------------------------------------------------------- /ctpn/data/demo/001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/demo/001.jpg -------------------------------------------------------------------------------- /ctpn/data/demo/002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/demo/002.jpg -------------------------------------------------------------------------------- /ctpn/data/demo/003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/demo/003.jpg -------------------------------------------------------------------------------- /ctpn/data/demo/004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/demo/004.jpg -------------------------------------------------------------------------------- /ctpn/data/demo/005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/demo/005.jpg 
-------------------------------------------------------------------------------- /ctpn/data/demo/006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/demo/006.jpg -------------------------------------------------------------------------------- /ctpn/data/demo/007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/demo/007.jpg -------------------------------------------------------------------------------- /ctpn/data/demo/008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/demo/008.jpg -------------------------------------------------------------------------------- /ctpn/data/demo/009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/demo/009.jpg -------------------------------------------------------------------------------- /ctpn/data/demo/010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/demo/010.png -------------------------------------------------------------------------------- /ctpn/data/oriented_results/001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/oriented_results/001.jpg -------------------------------------------------------------------------------- /ctpn/data/oriented_results/002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/oriented_results/002.jpg -------------------------------------------------------------------------------- /ctpn/data/oriented_results/003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/oriented_results/003.jpg -------------------------------------------------------------------------------- /ctpn/data/oriented_results/004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/oriented_results/004.jpg -------------------------------------------------------------------------------- /ctpn/data/oriented_results/005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/oriented_results/005.jpg -------------------------------------------------------------------------------- /ctpn/data/oriented_results/006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/oriented_results/006.jpg 
-------------------------------------------------------------------------------- /ctpn/data/oriented_results/007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/oriented_results/007.jpg -------------------------------------------------------------------------------- /ctpn/data/oriented_results/008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/oriented_results/008.jpg -------------------------------------------------------------------------------- /ctpn/data/oriented_results/009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/oriented_results/009.jpg -------------------------------------------------------------------------------- /ctpn/data/oriented_results/010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/oriented_results/010.png -------------------------------------------------------------------------------- /ctpn/data/results/001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/results/001.jpg -------------------------------------------------------------------------------- /ctpn/data/results/002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/results/002.jpg -------------------------------------------------------------------------------- /ctpn/data/results/003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/results/003.jpg -------------------------------------------------------------------------------- /ctpn/data/results/010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/data/results/010.png -------------------------------------------------------------------------------- /ctpn/lib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # _Author_: xiaofeng 4 | # Date: 2018-04-08 14:41:12 5 | # Last Modified by: xiaofeng 6 | # Last Modified time: 2018-04-08 14:41:12 7 | 8 | from .imdb import imdb 9 | # from pascal_voc import pascal_voc 10 | from .pascal_voc import pascal_voc 11 | from . 
import factory 12 | 13 | def _which(program): 14 | import os 15 | def is_exe(fpath): 16 | return os.path.isfile(fpath) and os.access(fpath, os.X_OK) 17 | 18 | fpath, fname = os.path.split(program) 19 | if fpath: 20 | if is_exe(program): 21 | return program 22 | else: 23 | for path in os.environ["PATH"].split(os.pathsep): 24 | path = path.strip('"') 25 | exe_file = os.path.join(path, program) 26 | if is_exe(exe_file): 27 | return exe_file 28 | 29 | return None 30 | -------------------------------------------------------------------------------- /ctpn/lib/datasets/ds_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # _Author_: xiaofeng 4 | # Date: 2018-04-08 14:46:05 5 | # Last Modified by: xiaofeng 6 | # Last Modified time: 2018-04-08 14:46:05 7 | 8 | import numpy as np 9 | 10 | def unique_boxes(boxes, scale=1.0): 11 | """Return indices of unique boxes.""" 12 | v = np.array([1, 1e3, 1e6, 1e9]) 13 | hashes = np.round(boxes * scale).dot(v) 14 | _, index = np.unique(hashes, return_index=True) 15 | return np.sort(index) 16 | 17 | def xywh_to_xyxy(boxes): 18 | """Convert [x y w h] box format to [x1 y1 x2 y2] format.""" 19 | return np.hstack((boxes[:, 0:2], boxes[:, 0:2] + boxes[:, 2:4] - 1)) 20 | 21 | def xyxy_to_xywh(boxes): 22 | """Convert [x1 y1 x2 y2] box format to [x y w h] format.""" 23 | return np.hstack((boxes[:, 0:2], boxes[:, 2:4] - boxes[:, 0:2] + 1)) 24 | 25 | def validate_boxes(boxes, width=0, height=0): 26 | """Check that a set of boxes are valid.""" 27 | x1 = boxes[:, 0] 28 | y1 = boxes[:, 1] 29 | x2 = boxes[:, 2] 30 | y2 = boxes[:, 3] 31 | assert (x1 >= 0).all() 32 | assert (y1 >= 0).all() 33 | assert (x2 >= x1).all() 34 | assert (y2 >= y1).all() 35 | assert (x2 < width).all() 36 | assert (y2 < height).all() 37 | 38 | def filter_small_boxes(boxes, min_size): 39 | w = boxes[:, 2] - boxes[:, 0] 40 | h = boxes[:, 3] - boxes[:, 1] 41 | keep = np.where((w >= min_size) & (h > min_size))[0] 42 | return keep 43 | -------------------------------------------------------------------------------- /ctpn/lib/datasets/factory.py: -------------------------------------------------------------------------------- 1 | __sets = {} 2 | from .pascal_voc import pascal_voc 3 | 4 | 5 | def _selective_search_IJCV_top_k(split, year, top_k): 6 | imdb = pascal_voc(split, year) 7 | imdb.roidb_handler = imdb.selective_search_IJCV_roidb 8 | imdb.config['top_k'] = top_k 9 | return imdb 10 | 11 | 12 | # Set up voc__ using selective search "fast" mode 13 | for year in ['2007', '2012', '0712']: 14 | for split in ['train', 'val', 'trainval', 'test']: 15 | name = 'voc_{}_{}'.format(year, split) 16 | # __sets[name] = (lambda split=split, year=year: pascal_voc(split, year)) 17 | __sets[name] = (lambda split=split, year=year: pascal_voc(split, year)) 18 | 19 | 20 | def get_imdb(name): 21 | """Get an imdb (image database) by name.""" 22 | # print('__Sets', __sets) 23 | if name not in __sets: 24 | raise KeyError('Unknown dataset: {}'.format(name)) 25 | return __sets[name]() 26 | 27 | 28 | def list_imdbs(): 29 | """List all registered imdbs.""" 30 | return list(__sets.keys()) 31 | -------------------------------------------------------------------------------- /ctpn/lib/datasets/imdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import PIL 4 | import numpy as np 5 | import scipy.sparse 6 | 7 | from ..utils.bbox import bbox_overlaps 8 | from 
..fast_rcnn.config import cfg 9 | 10 | 11 | class imdb(object): 12 | def __init__(self, name): 13 | self._name = name 14 | self._num_classes = 0 15 | self._classes = [] 16 | self._image_index = [] 17 | self._obj_proposer = 'selective_search' 18 | self._roidb = None 19 | print(self.default_roidb) 20 | self._roidb_handler = self.default_roidb 21 | # Use this dict for storing dataset specific config options 22 | self.config = {} 23 | 24 | @property 25 | def name(self): 26 | return self._name 27 | 28 | @property 29 | def num_classes(self): 30 | return len(self._classes) 31 | 32 | @property 33 | def classes(self): 34 | return self._classes 35 | 36 | @property 37 | def image_index(self): 38 | return self._image_index 39 | 40 | @property 41 | def roidb_handler(self): 42 | return self._roidb_handler 43 | 44 | @roidb_handler.setter 45 | def roidb_handler(self, val): 46 | self._roidb_handler = val 47 | 48 | def set_proposal_method(self, method): 49 | method = eval('self.' + method + '_roidb') 50 | self.roidb_handler = method 51 | 52 | @property 53 | def roidb(self): 54 | # A roidb is a list of dictionaries, each with the following keys: 55 | # boxes 56 | # gt_overlaps 57 | # gt_classes 58 | # flipped 59 | if self._roidb is not None: 60 | return self._roidb 61 | self._roidb = self.roidb_handler() 62 | return self._roidb 63 | 64 | @property 65 | def cache_path(self): 66 | cache_path = osp.abspath(osp.join(cfg.DATA_DIR, 'cache')) 67 | if not os.path.exists(cache_path): 68 | os.makedirs(cache_path) 69 | return cache_path 70 | 71 | @property 72 | def num_images(self): 73 | return len(self.image_index) 74 | 75 | def image_path_at(self, i): 76 | raise NotImplementedError 77 | 78 | def default_roidb(self): 79 | raise NotImplementedError 80 | 81 | def _get_widths(self): 82 | return [ 83 | PIL.Image.open(self.image_path_at(i)).size[0] 84 | for i in range(self.num_images) 85 | ] 86 | 87 | def append_flipped_images(self): 88 | num_images = self.num_images 89 | widths = self._get_widths() 90 | for i in range(num_images): 91 | boxes = self.roidb[i]['boxes'].copy() 92 | oldx1 = boxes[:, 0].copy() 93 | oldx2 = boxes[:, 2].copy() 94 | boxes[:, 0] = widths[i] - oldx2 - 1 95 | boxes[:, 2] = widths[i] - oldx1 - 1 96 | for b in range(len(boxes)): 97 | if boxes[b][2] < boxes[b][0]: 98 | boxes[b][0] = 0 99 | assert (boxes[:, 2] >= boxes[:, 0]).all() 100 | entry = { 101 | 'boxes': boxes, 102 | 'gt_overlaps': self.roidb[i]['gt_overlaps'], 103 | 'gt_classes': self.roidb[i]['gt_classes'], 104 | 'flipped': True 105 | } 106 | 107 | if 'gt_ishard' in self.roidb[i] and 'dontcare_areas' in self.roidb[i]: 108 | entry['gt_ishard'] = self.roidb[i]['gt_ishard'].copy() 109 | dontcare_areas = self.roidb[i]['dontcare_areas'].copy() 110 | oldx1 = dontcare_areas[:, 0].copy() 111 | oldx2 = dontcare_areas[:, 2].copy() 112 | dontcare_areas[:, 0] = widths[i] - oldx2 - 1 113 | dontcare_areas[:, 2] = widths[i] - oldx1 - 1 114 | entry['dontcare_areas'] = dontcare_areas 115 | 116 | self.roidb.append(entry) 117 | 118 | self._image_index = self._image_index * 2 119 | 120 | def create_roidb_from_box_list(self, box_list, gt_roidb): 121 | assert len(box_list) == self.num_images, \ 122 | 'Number of boxes must match number of ground-truth images' 123 | roidb = [] 124 | for i in range(self.num_images): 125 | boxes = box_list[i] 126 | num_boxes = boxes.shape[0] 127 | overlaps = np.zeros( 128 | (num_boxes, self.num_classes), dtype=np.float32) 129 | 130 | if gt_roidb is not None and gt_roidb[i]['boxes'].size > 0: 131 | gt_boxes = gt_roidb[i]['boxes'] 132 | 
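# Descriptive note (added): for each candidate box, the best IoU against the ground truth is stored in the column of the best-matching ground-truth box's class; boxes that overlap no ground truth keep an all-zero row.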
gt_classes = gt_roidb[i]['gt_classes'] 133 | gt_overlaps = bbox_overlaps( 134 | boxes.astype(np.float), gt_boxes.astype(np.float)) 135 | argmaxes = gt_overlaps.argmax(axis=1) 136 | maxes = gt_overlaps.max(axis=1) 137 | I = np.where(maxes > 0)[0] 138 | overlaps[I, gt_classes[argmaxes[I]]] = maxes[I] 139 | 140 | overlaps = scipy.sparse.csr_matrix(overlaps) 141 | roidb.append({ 142 | 'boxes': 143 | boxes, 144 | 'gt_classes': 145 | np.zeros((num_boxes, ), dtype=np.int32), 146 | 'gt_overlaps': 147 | overlaps, 148 | 'flipped': 149 | False, 150 | 'seg_areas': 151 | np.zeros((num_boxes, ), dtype=np.float32), 152 | }) 153 | return roidb 154 | 155 | @staticmethod 156 | def merge_roidbs(a, b): 157 | assert len(a) == len(b) 158 | for i in range(len(a)): 159 | a[i]['boxes'] = np.vstack((a[i]['boxes'], b[i]['boxes'])) 160 | a[i]['gt_classes'] = np.hstack((a[i]['gt_classes'], 161 | b[i]['gt_classes'])) 162 | a[i]['gt_overlaps'] = scipy.sparse.vstack( 163 | [a[i]['gt_overlaps'], b[i]['gt_overlaps']]) 164 | a[i]['seg_areas'] = np.hstack((a[i]['seg_areas'], 165 | b[i]['seg_areas'])) 166 | return a 167 | -------------------------------------------------------------------------------- /ctpn/lib/datasets/pascal_voc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # _Author_: xiaofeng 4 | # Date: 2018-04-08 14:40:30 5 | # Last Modified by: xiaofeng 6 | # Last Modified time: 2018-04-08 14:40:30 7 | 8 | import os, sys 9 | import numpy as np 10 | import scipy.sparse 11 | parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 12 | sys.path.insert(0, parentdir) 13 | try: 14 | import cPickle as pickle 15 | except: 16 | import pickle 17 | #import pickle 18 | import uuid 19 | import scipy.io as sio 20 | import xml.etree.ElementTree as ET 21 | from .imdb import imdb 22 | 23 | from .ds_utils import * 24 | from ..fast_rcnn.config import cfg 25 | 26 | 27 | class pascal_voc(imdb): 28 | def __init__(self, image_set, year, devkit_path=None): 29 | imdb.__init__(self, 'voc_' + year + '_' + image_set) 30 | self._year = year 31 | self._image_set = image_set 32 | # Build the default root directory of the dataset 33 | self._devkit_path = self._get_default_path() if devkit_path is None \ 34 | else devkit_path 35 | # Keep the data files outside the repository (this hard-coded path overrides the default above) 36 | self._devkit_path = '/Users/xiaofeng/Code/Github/dataset/CHINESE_OCR/ctpn/VOCdevkit2007' 37 | # Resolve the directory of the dataset 38 | self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year) 39 | 40 | self._classes = ( 41 | '__background__', # always index 0 42 | 'text') 43 | 44 | self._class_to_ind = dict( 45 | list(zip(self.classes, list(range(self.num_classes))))) 46 | self._image_ext = '.jpg' 47 | # Read the image-set txt file to get the indexes of the corresponding images 48 | self._image_index = self._load_image_set_index() 49 | # Default to roidb handler 50 | #self._roidb_handler = self.selective_search_roidb 51 | self._roidb_handler = self.gt_roidb 52 | self._salt = str(uuid.uuid4()) 53 | self._comp_id = 'comp4' 54 | 55 | # PASCAL specific config options 56 | self.config = { 57 | 'cleanup': True, 58 | 'use_salt': True, 59 | 'use_diff': False, 60 | 'matlab_eval': False, 61 | 'rpn_file': None, 62 | 'min_size': 2 63 | } 64 | 65 | assert os.path.exists(self._devkit_path), \ 66 | 'VOCdevkit path does not exist: {}'.format(self._devkit_path) 67 | assert os.path.exists(self._data_path), \ 68 | 'Path does not exist: {}'.format(self._data_path) 69 | 70 | def image_path_at(self, i): 71 | """ 72 | Return the absolute path to image i in the image sequence.
73 | """ 74 | return self.image_path_from_index(self._image_index[i]) 75 | 76 | def image_path_from_index(self, index): 77 | """ 78 | Construct an image path from the image's "index" identifier. 79 | """ 80 | image_path = os.path.join(self._data_path, 'JPEGImages', 81 | index + self._image_ext) 82 | assert os.path.exists(image_path), \ 83 | 'Path does not exist: {}'.format(image_path) 84 | return image_path 85 | 86 | def _load_image_set_index(self): 87 | """ 88 | Load the indexes listed in this dataset's image set file. 89 | """ 90 | # Example path to image set file: 91 | # self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt 92 | image_set_file = os.path.join(self._data_path, 'ImageSets', 'Main', 93 | self._image_set + '.txt') 94 | assert os.path.exists(image_set_file), \ 95 | 'Path does not exist: {}'.format(image_set_file) 96 | with open(image_set_file) as f: 97 | image_index = [x.strip() for x in f.readlines()] 98 | return image_index 99 | 100 | def _get_default_path(self): 101 | """ 102 | Return the default path where PASCAL VOC is expected to be installed. 103 | """ 104 | return os.path.join(cfg.DATA_DIR, 'VOCdevkit' + self._year) 105 | 106 | def gt_roidb(self): 107 | """ 108 | Return the database of ground-truth regions of interest. 109 | This function loads/saves from/to a cache file to speed up future calls. 110 | """ 111 | # self.name is the string identifying the dataset split to read 112 | cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl') 113 | # If the cache file exists, load the roidb directly from the cache 114 | if os.path.exists(cache_file): 115 | with open(cache_file, 'rb') as fid: 116 | roidb = pickle.load(fid) 117 | print('{} gt roidb loaded from {}'.format(self.name, cache_file)) 118 | return roidb 119 | 120 | gt_roidb = [ 121 | self._load_pascal_annotation(index) for index in self.image_index 122 | ] 123 | with open(cache_file, 'wb') as fid: 124 | pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL) 125 | print('wrote gt roidb to {}'.format(cache_file)) 126 | 127 | return gt_roidb 128 | 129 | def selective_search_roidb(self): 130 | """ 131 | Return the database of selective search regions of interest. 132 | Ground-truth ROIs are also included. 133 | 134 | This function loads/saves from/to a cache file to speed up future calls.
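Note (added): this project wires self._roidb_handler to gt_roidb in __init__, so the selective-search path is only exercised if selected explicitly.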
135 | """ 136 | cache_file = os.path.join(self.cache_path, 137 | self.name + '_selective_search_roidb.pkl') 138 | 139 | if os.path.exists(cache_file): 140 | with open(cache_file, 'rb') as fid: 141 | roidb = pickle.load(fid) 142 | print('{} ss roidb loaded from {}'.format(self.name, cache_file)) 143 | return roidb 144 | 145 | if int(self._year) == 2007 or self._image_set != 'test': 146 | gt_roidb = self.gt_roidb() 147 | ss_roidb = self._load_selective_search_roidb(gt_roidb) 148 | roidb = imdb.merge_roidbs(gt_roidb, ss_roidb) 149 | else: 150 | roidb = self._load_selective_search_roidb(None) 151 | with open(cache_file, 'wb') as fid: 152 | pickle.dump(roidb, fid, pickle.HIGHEST_PROTOCOL) 153 | print('wrote ss roidb to {}'.format(cache_file)) 154 | 155 | return roidb 156 | 157 | def rpn_roidb(self): 158 | if int(self._year) == 2007 or self._image_set != 'test': 159 | gt_roidb = self.gt_roidb() 160 | rpn_roidb = self._load_rpn_roidb(gt_roidb) 161 | roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb) 162 | else: 163 | roidb = self._load_rpn_roidb(None) 164 | 165 | return roidb 166 | 167 | def _load_rpn_roidb(self, gt_roidb): 168 | filename = self.config['rpn_file'] 169 | print('loading {}'.format(filename)) 170 | assert os.path.exists(filename), \ 171 | 'rpn data not found at: {}'.format(filename) 172 | with open(filename, 'rb') as f: 173 | box_list = pickle.load(f) 174 | return self.create_roidb_from_box_list(box_list, gt_roidb) 175 | 176 | def _load_selective_search_roidb(self, gt_roidb): 177 | filename = os.path.abspath( 178 | os.path.join(cfg.DATA_DIR, 'selective_search_data', 179 | self.name + '.mat')) 180 | assert os.path.exists(filename), \ 181 | 'Selective search data not found at: {}'.format(filename) 182 | raw_data = sio.loadmat(filename)['boxes'].ravel() 183 | 184 | box_list = [] 185 | for i in range(raw_data.shape[0]): 186 | boxes = raw_data[i][:, (1, 0, 3, 2)] - 1 187 | keep = unique_boxes(boxes) 188 | boxes = boxes[keep, :] 189 | keep = filter_small_boxes(boxes, self.config['min_size']) 190 | boxes = boxes[keep, :] 191 | box_list.append(boxes) 192 | 193 | return self.create_roidb_from_box_list(box_list, gt_roidb) 194 | 195 | def _load_pascal_annotation(self, index): 196 | """ 197 | Load image and bounding boxes info from XML file in the PASCAL VOC 198 | format. 199 | """ 200 | filename = os.path.join(self._data_path, 'Annotations', index + '.xml') 201 | tree = ET.parse(filename) 202 | objs = tree.findall('object') 203 | num_objs = len(objs) 204 | 205 | boxes = np.zeros((num_objs, 4), dtype=np.uint16) 206 | gt_classes = np.zeros((num_objs), dtype=np.int32) 207 | overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32) 208 | # "Seg" area for pascal is just the box area 209 | seg_areas = np.zeros((num_objs), dtype=np.float32) 210 | ishards = np.zeros((num_objs), dtype=np.int32) 211 | 212 | # objs = diff_objs(or non_diff_objs) 213 | # ignore any objects with classes except the classes we are looking for 214 | 215 | # cls_objs = [ 216 | # obj for obj in objs if obj.find('name').text in self._classes 217 | # ] 218 | # objs = cls_objs 219 | 220 | #### 221 | 222 | # Load object bounding boxes into a data frame. 
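# A minimal annotation sketch (hypothetical values, for illustration only):
#   <object><name>text</name><difficult>0</difficult>
#     <bndbox><xmin>48</xmin><ymin>240</ymin><xmax>195</xmax><ymax>371</ymax></bndbox>
#   </object>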
223 | for ix, obj in enumerate(objs): 224 | bbox = obj.find('bndbox') 225 | # Keep pixel indexes as given (a 0-based variant is commented out below) 226 | x1 = float(bbox.find('xmin').text) 227 | y1 = float(bbox.find('ymin').text) 228 | x2 = float(bbox.find('xmax').text) 229 | y2 = float(bbox.find('ymax').text) 230 | ''' 231 | x1 = float(bbox.find('xmin').text) - 1 232 | y1 = float(bbox.find('ymin').text) - 1 233 | x2 = float(bbox.find('xmax').text) - 1 234 | y2 = float(bbox.find('ymax').text) - 1 235 | ''' 236 | diffc = obj.find('difficult') 237 | difficult = 0 if diffc is None else int(diffc.text) 238 | ishards[ix] = difficult 239 | 240 | cls = self._class_to_ind[obj.find('name').text.lower().strip()] 241 | # cls = self._class_to_ind[obj.find('name').text] 242 | 243 | boxes[ix, :] = [x1, y1, x2, y2] 244 | gt_classes[ix] = cls 245 | overlaps[ix, cls] = 1.0 246 | seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1) 247 | 248 | overlaps = scipy.sparse.csr_matrix(overlaps) 249 | 250 | return { 251 | 'boxes': boxes, 252 | 'gt_classes': gt_classes, 253 | 'gt_ishard': ishards, 254 | 'gt_overlaps': overlaps, 255 | 'flipped': False, 256 | 'seg_areas': seg_areas 257 | } 258 | 259 | def _get_comp_id(self): 260 | comp_id = (self._comp_id + '_' + self._salt 261 | if self.config['use_salt'] else self._comp_id) 262 | return comp_id 263 | 264 | def _get_voc_results_file_template(self): 265 | filename = self._get_comp_id() + '_det_' + self._image_set + '_{:s}.txt' 266 | filedir = os.path.join(self._devkit_path, 'results', 267 | 'VOC' + self._year, 'Main') 268 | if not os.path.exists(filedir): 269 | os.makedirs(filedir) 270 | path = os.path.join(filedir, filename) 271 | return path 272 | 273 | def _write_voc_results_file(self, all_boxes): 274 | for cls_ind, cls in enumerate(self.classes): 275 | if cls == '__background__': 276 | continue 277 | print('Writing {} VOC results file'.format(cls)) 278 | filename = self._get_voc_results_file_template().format(cls) 279 | with open(filename, 'wt') as f: 280 | for im_ind, index in enumerate(self.image_index): 281 | dets = all_boxes[cls_ind][im_ind] 282 | if len(dets) == 0: 283 | continue 284 | # the VOCdevkit expects 1-based indices 285 | for k in range(dets.shape[0]): 286 | f.write( 287 | '{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.format( 288 | index, dets[k, -1], dets[k, 0] + 1, 289 | dets[k, 1] + 1, dets[k, 2] + 1, 290 | dets[k, 3] + 1)) 291 | 292 | 293 | if __name__ == '__main__': 294 | d = pascal_voc('trainval', '2007') 295 | res = d.roidb 296 | from IPython import embed 297 | embed() 298 | -------------------------------------------------------------------------------- /ctpn/lib/fast_rcnn/__init__.py: -------------------------------------------------------------------------------- 1 | from . import config 2 | from . import nms_wrapper 3 | from . import test 4 | from .
import train 5 | -------------------------------------------------------------------------------- /ctpn/lib/fast_rcnn/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/lib/fast_rcnn/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ctpn/lib/fast_rcnn/__pycache__/config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/lib/fast_rcnn/__pycache__/config.cpython-36.pyc -------------------------------------------------------------------------------- /ctpn/lib/fast_rcnn/__pycache__/nms_wrapper.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/lib/fast_rcnn/__pycache__/nms_wrapper.cpython-36.pyc -------------------------------------------------------------------------------- /ctpn/lib/fast_rcnn/bbox_transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def bbox_transform(ex_rois, gt_rois): 5 | """ 6 | Computes the distance from ground-truth boxes to the given boxes, normed by their size 7 | :param ex_rois: n * 4 numpy array, given boxes 8 | :param gt_rois: n * 4 numpy array, ground-truth boxes 9 | :return: deltas: n * 4 numpy array, regression targets (dx, dy, dw, dh) 10 | """ 11 | ex_widths = ex_rois[:, 2] - ex_rois[:, 0] + 1.0 12 | ex_heights = ex_rois[:, 3] - ex_rois[:, 1] + 1.0 13 | ex_ctr_x = ex_rois[:, 0] + 0.5 * ex_widths 14 | ex_ctr_y = ex_rois[:, 1] + 0.5 * ex_heights 15 | 16 | assert np.min(ex_widths) > 0.1 and np.min(ex_heights) > 0.1, \ 17 | 'Invalid boxes found: {} {}'.
\ 18 | format(ex_rois[np.argmin(ex_widths), :], ex_rois[np.argmin(ex_heights), :]) 19 | 20 | gt_widths = gt_rois[:, 2] - gt_rois[:, 0] + 1.0 21 | gt_heights = gt_rois[:, 3] - gt_rois[:, 1] + 1.0 22 | gt_ctr_x = gt_rois[:, 0] + 0.5 * gt_widths 23 | gt_ctr_y = gt_rois[:, 1] + 0.5 * gt_heights 24 | 25 | # warnings.catch_warnings() 26 | # warnings.filterwarnings('error') 27 | targets_dx = (gt_ctr_x - ex_ctr_x) / ex_widths 28 | targets_dy = (gt_ctr_y - ex_ctr_y) / ex_heights 29 | targets_dw = np.log(gt_widths / ex_widths) 30 | targets_dh = np.log(gt_heights / ex_heights) 31 | 32 | targets = np.vstack( 33 | (targets_dx, targets_dy, targets_dw, targets_dh)).transpose() 34 | 35 | return targets 36 | 37 | 38 | def bbox_transform_inv(boxes, deltas): 39 | boxes = boxes.astype(deltas.dtype, copy=False) 40 | 41 | widths = boxes[:, 2] - boxes[:, 0] + 1.0 42 | heights = boxes[:, 3] - boxes[:, 1] + 1.0 43 | ctr_x = boxes[:, 0] + 0.5 * widths 44 | ctr_y = boxes[:, 1] + 0.5 * heights 45 | 46 | dx = deltas[:, 0::4] 47 | dy = deltas[:, 1::4] 48 | dw = deltas[:, 2::4] 49 | dh = deltas[:, 3::4] 50 | # NOTE: dx and dw are deliberately unused below; CTPN regresses only the vertical geometry, so each proposal keeps its anchor's x-center and width. 51 | pred_ctr_x = ctr_x[:, np.newaxis] 52 | pred_ctr_y = dy * heights[:, np.newaxis] + ctr_y[:, np.newaxis] 53 | pred_w = widths[:, np.newaxis] 54 | pred_h = np.exp(dh) * heights[:, np.newaxis] 55 | 56 | pred_boxes = np.zeros(deltas.shape, dtype=deltas.dtype) 57 | # x1 58 | pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w 59 | # y1 60 | pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h 61 | # x2 62 | pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w 63 | # y2 64 | pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h 65 | 66 | return pred_boxes 67 | 68 | 69 | def clip_boxes(boxes, im_shape): 70 | """ 71 | Clip boxes to image boundaries. 72 | """ 73 | 74 | # x1 >= 0 75 | boxes[:, 0::4] = np.maximum(np.minimum(boxes[:, 0::4], im_shape[1] - 1), 0) 76 | # y1 >= 0 77 | boxes[:, 1::4] = np.maximum(np.minimum(boxes[:, 1::4], im_shape[0] - 1), 0) 78 | # x2 < im_shape[1] 79 | boxes[:, 2::4] = np.maximum(np.minimum(boxes[:, 2::4], im_shape[1] - 1), 0) 80 | # y2 < im_shape[0] 81 | boxes[:, 3::4] = np.maximum(np.minimum(boxes[:, 3::4], im_shape[0] - 1), 0) 82 | return boxes 83 | -------------------------------------------------------------------------------- /ctpn/lib/fast_rcnn/config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | """Fast R-CNN config system. 8 | This file specifies default config options for Fast R-CNN. You should not 9 | change values in this file. Instead, you should write a config file (in yaml) 10 | and use cfg_from_file(yaml_file) to load it and override the default options. 11 | Most tools in $ROOT/tools take a --cfg option to specify an override file.
12 | - See tools/{train,test}_net.py for example code that uses cfg_from_file() 13 | - See experiments/cfgs/*.yml for example YAML config override files 14 | """ 15 | 16 | import os 17 | import os.path as osp 18 | from time import strftime, localtime 19 | 20 | import numpy as np 21 | from easydict import EasyDict as edict 22 | 23 | __C = edict() 24 | # Consumers can get config by: 25 | # from fast_rcnn_config import cfg 26 | cfg = __C 27 | 28 | # 29 | # Training options 30 | # 31 | 32 | # region proposal network (RPN) or not 33 | __C.IS_RPN = True 34 | __C.ANCHOR_SCALES = [16] 35 | __C.NCLASSES = 2 36 | 37 | # multiscale training and testing 38 | __C.IS_MULTISCALE = False 39 | __C.IS_EXTRAPOLATING = True 40 | 41 | __C.REGION_PROPOSAL = 'RPN' 42 | 43 | __C.NET_NAME = 'VGGnet' 44 | __C.SUBCLS_NAME = 'voxel_exemplars' 45 | 46 | __C.TRAIN = edict() 47 | # Adam, Momentum, RMS 48 | __C.TRAIN.SOLVER = 'Momentum' 49 | # learning rate 50 | __C.TRAIN.WEIGHT_DECAY = 0.0005 51 | __C.TRAIN.LEARNING_RATE = 0.001 52 | __C.TRAIN.MOMENTUM = 0.9 53 | __C.TRAIN.GAMMA = 0.1 54 | __C.TRAIN.STEPSIZE = 50000 55 | __C.TRAIN.DISPLAY = 10 56 | __C.TRAIN.LOG_IMAGE_ITERS = 100 57 | __C.TRAIN.OHEM = False 58 | __C.TRAIN.RANDOM_DOWNSAMPLE = False 59 | 60 | # Scales to compute real features 61 | __C.TRAIN.SCALES_BASE = (0.25, 0.5, 1.0, 2.0, 3.0) 62 | # __C.TRAIN.SCALES_BASE = (1.0,) 63 | 64 | # parameters for ROI generating 65 | # __C.TRAIN.SPATIAL_SCALE = 0.0625 66 | __C.TRAIN.KERNEL_SIZE = 5 67 | 68 | # Aspect ratio to use during training 69 | # __C.TRAIN.ASPECTS = (1, 0.75, 0.5, 0.25) 70 | __C.TRAIN.ASPECTS = (1, ) 71 | 72 | # Scales to use during training (can list multiple scales) 73 | # Each scale is the pixel size of an image's shortest side 74 | __C.TRAIN.SCALES = (600, ) 75 | 76 | # Max pixel size of the longest side of a scaled input image 77 | __C.TRAIN.MAX_SIZE = 1000 78 | 79 | # Images to use per minibatch 80 | __C.TRAIN.IMS_PER_BATCH = 1 81 | 82 | # Minibatch size (number of regions of interest [ROIs]) 83 | __C.TRAIN.BATCH_SIZE = 128 84 | 85 | # Fraction of minibatch that is labeled foreground (i.e. class > 0) 86 | __C.TRAIN.FG_FRACTION = 0.25 87 | 88 | # Overlap threshold for a ROI to be considered foreground (if >= FG_THRESH) 89 | __C.TRAIN.FG_THRESH = 0.5 90 | 91 | # Overlap threshold for a ROI to be considered background (class = 0 if 92 | # overlap in [LO, HI)) 93 | __C.TRAIN.BG_THRESH_HI = 0.5 94 | __C.TRAIN.BG_THRESH_LO = 0.1 95 | 96 | # Use horizontally-flipped images during training? 
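# (Flipping doubles the effective training set on the fly; see imdb.append_flipped_images.)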
97 | __C.TRAIN.USE_FLIPPED = True 98 | 99 | # Train bounding-box regressors 100 | __C.TRAIN.BBOX_REG = True 101 | 102 | # Overlap required between a ROI and ground-truth box in order for that ROI to 103 | # be used as a bounding-box regression training example 104 | __C.TRAIN.BBOX_THRESH = 0.5 105 | 106 | # Iterations between snapshots 107 | __C.TRAIN.SNAPSHOT_ITERS = 5000 108 | 109 | # solver.prototxt specifies the snapshot path prefix, this adds an optional 110 | # infix to yield the path: [_]_iters_XYZ.caffemodel 111 | __C.TRAIN.SNAPSHOT_PREFIX = 'VGGnet_fast_rcnn' 112 | __C.TRAIN.SNAPSHOT_INFIX = '' 113 | 114 | # Use a prefetch thread in roi_data_layer.layer 115 | # So far I haven't found this useful; likely more engineering work is required 116 | __C.TRAIN.USE_PREFETCH = False 117 | 118 | # Normalize the targets (subtract empirical mean, divide by empirical stddev) 119 | __C.TRAIN.BBOX_NORMALIZE_TARGETS = True 120 | # Deprecated (inside weights) 121 | # used for assigning weights for each coords (x1, y1, w, h) 122 | __C.TRAIN.BBOX_INSIDE_WEIGHTS = (1.0, 1.0, 1.0, 1.0) 123 | # Normalize the targets using "precomputed" (or made up) means and stdevs 124 | # (BBOX_NORMALIZE_TARGETS must also be True) 125 | __C.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED = True 126 | __C.TRAIN.BBOX_NORMALIZE_MEANS = (0.0, 0.0, 0.0, 0.0) 127 | __C.TRAIN.BBOX_NORMALIZE_STDS = (0.1, 0.1, 0.2, 0.2) 128 | # Faster R-CNN does not use pre-generated RoIs from selective search 129 | # __C.TRAIN.BBOX_NORMALIZE_STDS = (1, 1, 1, 1) 130 | 131 | # Train using these proposals 132 | __C.TRAIN.PROPOSAL_METHOD = 'selective_search' 133 | 134 | # Make minibatches from images that have similar aspect ratios (i.e. both 135 | # tall and thin or both short and wide) in order to avoid wasting computation 136 | # on zero-padding.
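# (With IMS_PER_BATCH = 1 above, each minibatch holds a single image, so grouping has little practical effect here.)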
137 | __C.TRAIN.ASPECT_GROUPING = True 138 | # preclude rois intersected with dontcare areas above the value 139 | __C.TRAIN.DONTCARE_AREA_INTERSECTION_HI = 0.5 140 | __C.TRAIN.PRECLUDE_HARD_SAMPLES = True 141 | # Use RPN to detect objects 142 | __C.TRAIN.HAS_RPN = True 143 | # IOU >= thresh: positive example 144 | __C.TRAIN.RPN_POSITIVE_OVERLAP = 0.7 145 | # IOU < thresh: negative example 146 | __C.TRAIN.RPN_NEGATIVE_OVERLAP = 0.3 147 | # If an anchor satisfies both the positive and the negative condition, set it to negative 148 | __C.TRAIN.RPN_CLOBBER_POSITIVES = False 149 | # Max number of foreground examples 150 | __C.TRAIN.RPN_FG_FRACTION = 0.5 151 | # Total number of examples 152 | __C.TRAIN.RPN_BATCHSIZE = 256 153 | # NMS threshold used on RPN proposals 154 | __C.TRAIN.RPN_NMS_THRESH = 0.7 155 | # Number of top scoring boxes to keep before applying NMS to RPN proposals 156 | __C.TRAIN.RPN_PRE_NMS_TOP_N = 12000 157 | # Number of top scoring boxes to keep after applying NMS to RPN proposals 158 | __C.TRAIN.RPN_POST_NMS_TOP_N = 2000 159 | # Proposal height and width both need to be greater than RPN_MIN_SIZE (at orig image scale) 160 | __C.TRAIN.RPN_MIN_SIZE = 8 161 | # Deprecated (outside weights) 162 | __C.TRAIN.RPN_BBOX_INSIDE_WEIGHTS = (1.0, 1.0, 1.0, 1.0) 163 | # Give the positive RPN examples weight of p * 1 / {num positives} 164 | # and give negatives a weight of (1 - p) 165 | # Set to -1.0 to use uniform example weighting 166 | __C.TRAIN.RPN_POSITIVE_WEIGHT = -1.0 167 | # __C.TRAIN.RPN_POSITIVE_WEIGHT = 0.5 168 | 169 | # 170 | # Testing options 171 | # 172 | 173 | __C.TEST = edict() 174 | 175 | # Scales to use during testing (can list multiple scales) 176 | # Each scale is the pixel size of an image's shortest side 177 | __C.TEST.SCALES = (900, ) 178 | 179 | # Max pixel size of the longest side of a scaled input image 180 | __C.TEST.MAX_SIZE = 1500 181 | 182 | # Overlap threshold used for non-maximum suppression (suppress boxes with 183 | # IoU >= this threshold) 184 | __C.TEST.NMS = 0.3 185 | 186 | # Experimental: treat the (K+1) units in the cls_score layer as linear 187 | # predictors (trained, eg, with one-vs-rest SVMs). 188 | __C.TEST.SVM = False 189 | 190 | # Test using bounding-box regressors 191 | __C.TEST.BBOX_REG = True 192 | 193 | # Propose boxes 194 | __C.TEST.HAS_RPN = True 195 | 196 | # Test using these proposals 197 | __C.TEST.PROPOSAL_METHOD = 'selective_search' 198 | 199 | ## NMS threshold used on RPN proposals 200 | __C.TEST.RPN_NMS_THRESH = 0.7 201 | ## Number of top scoring boxes to keep before applying NMS to RPN proposals 202 | # __C.TEST.RPN_PRE_NMS_TOP_N = 6000 203 | __C.TEST.RPN_PRE_NMS_TOP_N = 12000 204 | ## Number of top scoring boxes to keep after applying NMS to RPN proposals 205 | __C.TEST.RPN_POST_NMS_TOP_N = 1000 206 | # __C.TEST.RPN_POST_NMS_TOP_N = 2000 207 | # Proposal height and width both need to be greater than RPN_MIN_SIZE (at orig image scale) 208 | __C.TEST.RPN_MIN_SIZE = 8 209 | 210 | # 211 | # MISC 212 | # 213 | 214 | # The mapping from image coordinates to feature map coordinates might cause 215 | # some boxes that are distinct in image space to become identical in feature 216 | # coordinates. If DEDUP_BOXES > 0, then DEDUP_BOXES is used as the scale factor 217 | # for identifying duplicate boxes. 218 | # 1/16 is correct for {Alex,Caffe}Net, VGG_CNN_M_1024, and VGG16 219 | __C.DEDUP_BOXES = 1. / 16.
220 | 221 | # Pixel mean values (BGR order) as a (1, 1, 3) array 222 | # We use the same pixel mean for all networks even though it's not exactly what 223 | # they were trained with 224 | __C.PIXEL_MEANS = np.array([[[102.9801, 115.9465, 122.7717]]]) 225 | 226 | # For reproducibility 227 | # __C.RNG_SEED = 3 228 | __C.RNG_SEED = 3 229 | 230 | # A small number that's used many times 231 | __C.EPS = 1e-14 232 | 233 | # Root directory of project 234 | __C.ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..', '..')) 235 | 236 | # Data directory 237 | # __C.DATA_DIR = osp.abspath(osp.join(__C.ROOT_DIR, 'data')) 238 | # The VOC data directory has been moved outside the repository 239 | __C.DATA_DIR = '/Users/xiaofeng/Code/Github/dataset/CHINESE_OCR/ctpn/' 240 | 241 | # Model directory 242 | __C.MODELS_DIR = osp.abspath(osp.join(__C.ROOT_DIR, 'models', 'pascal_voc')) 243 | 244 | # Name (or path to) the matlab executable 245 | __C.MATLAB = 'matlab' 246 | 247 | # Place outputs under an experiments directory 248 | __C.EXP_DIR = 'default' 249 | __C.LOG_DIR = 'default' 250 | 251 | # Use GPU implementation of non-maximum suppression 252 | __C.USE_GPU_NMS = True 253 | 254 | # Default GPU device id 255 | __C.GPU_ID = 0 256 | 257 | 258 | def get_output_dir(imdb, weights_filename): 259 | """Return the directory where experimental artifacts are placed. 260 | If the directory does not exist, it is created. 261 | A canonical path is built using the name from an imdb and a network 262 | (if not None). 263 | """ 264 | outdir = osp.abspath( 265 | osp.join(__C.ROOT_DIR, 'output', __C.EXP_DIR, imdb.name)) 266 | if weights_filename is not None: 267 | outdir = osp.join(outdir, weights_filename) 268 | if not os.path.exists(outdir): 269 | os.makedirs(outdir) 270 | return outdir 271 | 272 | 273 | def get_log_dir(imdb): 274 | """Return the directory where experimental artifacts are placed. 275 | If the directory does not exist, it is created. 276 | A canonical path is built using the name from an imdb and a network 277 | (if not None). 278 | """ 279 | log_dir = osp.abspath( \ 280 | osp.join(__C.ROOT_DIR, 'logs', __C.LOG_DIR, imdb.name, strftime("%Y-%m-%d-%H-%M-%S", localtime()))) 281 | if not os.path.exists(log_dir): 282 | os.makedirs(log_dir) 283 | return log_dir 284 | 285 | 286 | def _merge_a_into_b(a, b): 287 | """Merge config dictionary a into config dictionary b, clobbering the 288 | options in b whenever they are also specified in a. 289 | """ 290 | if type(a) is not edict: 291 | return 292 | 293 | for k, v in a.items(): 294 | # a must specify keys that are in b 295 | # if not b.has_key(k): #--python2 296 | if k not in b: # python3 297 | raise KeyError('{} is not a valid config key'.format(k)) 298 | 299 | # the types must match, too 300 | old_type = type(b[k]) 301 | if old_type is not type(v): 302 | if isinstance(b[k], np.ndarray): 303 | v = np.array(v, dtype=b[k].dtype) 304 | else: 305 | raise ValueError(('Type mismatch ({} vs.
{}) ' 306 | 'for config key: {}').format( 307 | type(b[k]), type(v), k)) 308 | 309 | # recursively merge dicts 310 | if type(v) is edict: 311 | try: 312 | _merge_a_into_b(a[k], b[k]) 313 | except: 314 | print('Error under config key: {}'.format(k)) 315 | raise 316 | else: 317 | b[k] = v 318 | 319 | 320 | def cfg_from_file(filename): 321 | """Load a config file and merge it into the default options.""" 322 | import yaml 323 | with open(filename, 'r') as f: 324 | yaml_cfg = edict(yaml.load(f)) 325 | 326 | _merge_a_into_b(yaml_cfg, __C) 327 | 328 | 329 | def cfg_from_list(cfg_list): 330 | """Set config keys via list (e.g., from command line).""" 331 | from ast import literal_eval 332 | assert len(cfg_list) % 2 == 0 333 | for k, v in zip(cfg_list[0::2], cfg_list[1::2]): 334 | key_list = k.split('.') 335 | d = __C 336 | for subkey in key_list[:-1]: 337 | # assert d.has_key(subkey) 338 | assert subkey in d 339 | d = d[subkey] 340 | subkey = key_list[-1] 341 | assert subkey in d  # dict.has_key() was removed in Python 3 342 | try: 343 | value = literal_eval(v) 344 | except: 345 | # handle the case when v is a string literal 346 | value = v 347 | assert type(value) == type(d[subkey]), \ 348 | 'type {} does not match original type {}'.format( 349 | type(value), type(d[subkey])) 350 | d[subkey] = value 351 | -------------------------------------------------------------------------------- /ctpn/lib/fast_rcnn/nms_wrapper.py: -------------------------------------------------------------------------------- 1 | from .config import cfg 2 | from ..utils.cython_nms import nms as cython_nms 3 | 4 | try: 5 | from lib.utils.gpu_nms import gpu_nms 6 | except: 7 | gpu_nms = cython_nms 8 | pass 9 | 10 | 11 | def nms(dets, thresh): 12 | if dets.shape[0] == 0: 13 | return [] 14 | if cfg.USE_GPU_NMS: 15 | try: 16 | return gpu_nms(dets, thresh, device_id=cfg.GPU_ID) 17 | except: 18 | return cython_nms(dets, thresh) 19 | else: 20 | return cython_nms(dets, thresh) 21 | -------------------------------------------------------------------------------- /ctpn/lib/fast_rcnn/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # _Author_: xiaofeng 4 | # Date: 2018-04-08 14:31:45 5 | # Last Modified by: xiaofeng 6 | # Last Modified time: 2018-04-08 14:31:45 7 | 8 | import cv2, os, sys 9 | import numpy as np 10 | 11 | # sys.path.append(os.getcwd()) 12 | parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 13 | sys.path.insert(0, parentdir) 14 | from .config import cfg 15 | from ..utils.blob import im_list_to_blob 16 | 17 | # from ..utils.blob import im_list_to_blob 18 | 19 | 20 | def _get_image_blob(im): 21 | im_orig = im.astype(np.float32, copy=True) 22 | im_orig -= cfg.PIXEL_MEANS 23 | 24 | im_shape = im_orig.shape 25 | im_size_min = np.min(im_shape[0:2]) 26 | im_size_max = np.max(im_shape[0:2]) 27 | 28 | processed_ims = [] 29 | im_scale_factors = [] 30 | 31 | for target_size in cfg.TEST.SCALES: 32 | im_scale = float(target_size) / float(im_size_min) 33 | # Prevent the biggest axis from being more than MAX_SIZE 34 | if np.round(im_scale * im_size_max) > cfg.TEST.MAX_SIZE: 35 | im_scale = float(cfg.TEST.MAX_SIZE) / float(im_size_max) 36 | im = cv2.resize( 37 | im_orig, 38 | None, 39 | None, 40 | fx=im_scale, 41 | fy=im_scale, 42 | interpolation=cv2.INTER_LINEAR) 43 | im_scale_factors.append(im_scale) 44 | processed_ims.append(im) 45 | 46 | # Create a blob to hold the input images 47 | blob = im_list_to_blob(processed_ims) 48 | 49 | return blob,
np.array(im_scale_factors) 50 | 51 | 52 | def _get_blobs(im, rois): 53 | blobs = {'data': None, 'rois': None} 54 | blobs['data'], im_scale_factors = _get_image_blob(im) 55 | return blobs, im_scale_factors 56 | 57 | 58 | def test_ctpn(sess, net, im, boxes=None): 59 | blobs, im_scales = _get_blobs(im, boxes) 60 | if cfg.TEST.HAS_RPN: 61 | im_blob = blobs['data'] 62 | blobs['im_info'] = np.array( 63 | [[im_blob.shape[1], im_blob.shape[2], im_scales[0]]], 64 | dtype=np.float32) 65 | 66 | # forward pass 67 | if cfg.TEST.HAS_RPN: 68 | feed_dict = { 69 | net.data: blobs['data'], 70 | net.im_info: blobs['im_info'], 71 | net.keep_prob: 1.0 72 | } 73 | 74 | rois = sess.run([net.get_output('rois')[0]], feed_dict=feed_dict) 75 | rois = rois[0] 76 | 77 | scores = rois[:, 0] 78 | if cfg.TEST.HAS_RPN: 79 | assert len(im_scales) == 1, "Only single-image batch implemented" 80 | boxes = rois[:, 1:5] / im_scales[0] 81 | return scores, boxes 82 | -------------------------------------------------------------------------------- /ctpn/lib/fast_rcnn/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from ..fast_rcnn.config import cfg 9 | from ..roi_data_layer import roidb as rdl_roidb 10 | from ..roi_data_layer.layer import RoIDataLayer 11 | from..utils.timer import Timer 12 | # from lib.datasets import imdb as imdb 13 | 14 | _DEBUG = False 15 | 16 | 17 | class SolverWrapper(object): 18 | def __init__(self, 19 | sess, 20 | network, 21 | imdb, 22 | roidb, 23 | output_dir, 24 | logdir, 25 | pretrained_model=None): 26 | """Initialize the SolverWrapper.""" 27 | self.net = network 28 | self.imdb = imdb 29 | self.roidb = roidb 30 | self.output_dir = output_dir 31 | self.pretrained_model = pretrained_model 32 | 33 | print('Computing bounding-box regression targets...') 34 | if cfg.TRAIN.BBOX_REG: 35 | self.bbox_means, self.bbox_stds = rdl_roidb.add_bbox_regression_targets( 36 | roidb) 37 | print('done') 38 | 39 | # For checkpoint 40 | self.saver = tf.train.Saver( 41 | max_to_keep=1, write_version=tf.train.SaverDef.V2) 42 | self.writer = tf.summary.FileWriter( 43 | logdir=logdir, graph=tf.get_default_graph(), flush_secs=5) 44 | 45 | def snapshot(self, sess, iter): 46 | net = self.net 47 | if cfg.TRAIN.BBOX_REG and 'bbox_pred' in net.layers and cfg.TRAIN.BBOX_NORMALIZE_TARGETS: 48 | # save original values 49 | with tf.variable_scope('bbox_pred', reuse=True): 50 | weights = tf.get_variable("weights") 51 | biases = tf.get_variable("biases") 52 | 53 | orig_0 = weights.eval() 54 | orig_1 = biases.eval() 55 | 56 | # scale and shift with bbox reg unnormalization; then save snapshot 57 | weights_shape = weights.get_shape().as_list() 58 | sess.run( 59 | weights.assign(orig_0 * np.tile(self.bbox_stds, 60 | (weights_shape[0], 1)))) 61 | sess.run(biases.assign(orig_1 * self.bbox_stds + self.bbox_means)) 62 | 63 | if not os.path.exists(self.output_dir): 64 | os.makedirs(self.output_dir) 65 | 66 | infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX 67 | if cfg.TRAIN.SNAPSHOT_INFIX != '' else '') 68 | filename = (cfg.TRAIN.SNAPSHOT_PREFIX + infix + 69 | '_iter_{:d}'.format(iter + 1) + '.ckpt') 70 | filename = os.path.join(self.output_dir, filename) 71 | # save 72 | self.saver.save(sess, filename) 73 | print('Wrote snapshot to: {:s}'.format(filename)) 74 | 75 | if cfg.TRAIN.BBOX_REG and 'bbox_pred' in net.layers: 76 | # restore net to original state 77 | sess.run(weights.assign(orig_0)) 
78 | sess.run(biases.assign(orig_1)) 79 | 80 | def build_image_summary(self): 81 | # A simple graph for writing image summaries 82 | 83 | log_image_data = tf.placeholder(tf.uint8, [None, None, 3]) 84 | log_image_name = tf.placeholder(tf.string) 85 | # import tensorflow.python.ops.gen_logging_ops as logging_ops 86 | from tensorflow.python.ops import gen_logging_ops 87 | from tensorflow.python.framework import ops as _ops 88 | log_image = gen_logging_ops.image_summary( 89 | log_image_name, tf.expand_dims(log_image_data, 0), max_images=1) 90 | _ops.add_to_collection(_ops.GraphKeys.SUMMARIES, log_image) 91 | # log_image = tf.summary.image(log_image_name, tf.expand_dims(log_image_data, 0), max_outputs=1) 92 | return log_image, log_image_data, log_image_name 93 | 94 | def train_model(self, sess, max_iters, restore=False): 95 | """Network training loop.""" 96 | data_layer = get_data_layer(self.roidb, self.imdb.num_classes) 97 | total_loss, model_loss, rpn_cross_entropy, rpn_loss_box = self.net.build_loss( 98 | ohem=cfg.TRAIN.OHEM) 99 | # scalar summary 100 | tf.summary.scalar('rpn_reg_loss', rpn_loss_box) 101 | tf.summary.scalar('rpn_cls_loss', rpn_cross_entropy) 102 | tf.summary.scalar('model_loss', model_loss) 103 | tf.summary.scalar('total_loss', total_loss) 104 | summary_op = tf.summary.merge_all() 105 | 106 | log_image, log_image_data, log_image_name = \ 107 | self.build_image_summary() 108 | 109 | # optimizer 110 | lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False) 111 | if cfg.TRAIN.SOLVER == 'Adam': 112 | opt = tf.train.AdamOptimizer(cfg.TRAIN.LEARNING_RATE) 113 | elif cfg.TRAIN.SOLVER == 'RMS': 114 | opt = tf.train.RMSPropOptimizer(cfg.TRAIN.LEARNING_RATE) 115 | else: 116 | # lr = tf.Variable(0.0, trainable=False) 117 | momentum = cfg.TRAIN.MOMENTUM 118 | opt = tf.train.MomentumOptimizer(lr, momentum) 119 | 120 | global_step = tf.Variable(0, trainable=False) 121 | with_clip = True 122 | if with_clip: 123 | tvars = tf.trainable_variables() 124 | grads, norm = tf.clip_by_global_norm( 125 | tf.gradients(total_loss, tvars), 10.0) 126 | train_op = opt.apply_gradients( 127 | list(zip(grads, tvars)), global_step=global_step) 128 | else: 129 | train_op = opt.minimize(total_loss, global_step=global_step) 130 | 131 | # initialize variables 132 | sess.run(tf.global_variables_initializer()) 133 | restore_iter = 0 134 | 135 | # load vgg16 136 | if self.pretrained_model is not None and not restore: 137 | try: 138 | print(('Loading pretrained model ' 139 | 'weights from {:s}').format(self.pretrained_model)) 140 | self.net.load(self.pretrained_model, sess, True) 141 | except: 142 | raise ValueError('Check your pretrained model {:s}'.format( 143 | self.pretrained_model)) 144 | 145 | # resuming a trainer 146 | if restore: 147 | # try: 148 | print('output_dir:', self.output_dir) 149 | # Load the ckpt file path itself rather than the checkpoint index file 150 | ckpt = tf.train.get_checkpoint_state( 151 | self.output_dir + '/') 152 | print( 153 | 'Restoring from {}...'.format(ckpt.model_checkpoint_path), 154 | end=' ') 155 | self.saver.restore(sess, ckpt.model_checkpoint_path) 156 | stem = os.path.splitext( 157 | os.path.basename(ckpt.model_checkpoint_path))[0] 158 | restore_iter = int(stem.split('_')[-1]) 159 | sess.run(global_step.assign(restore_iter)) 160 | print('done') 161 | # except: 162 | 163 | # raise 'Check your pretrained {:s}'.format(ckpt.model_checkpoint_path) 164 | 165 | last_snapshot_iter = -1 166 | timer = Timer() 167 | print(restore_iter, max_iters) 168 | for iter in range(restore_iter, max_iters): 169 | timer.tic() 170 | # learning rate 171 |
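# Step decay sketch: every cfg.TRAIN.STEPSIZE iterations the rate is scaled by cfg.TRAIN.GAMMA, e.g. 0.001 -> 0.0001 after 50000 iterations with the defaults in config.py.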
print(iter) 172 | if iter != 0 and iter % cfg.TRAIN.STEPSIZE == 0: 173 | sess.run(tf.assign(lr, lr.eval() * cfg.TRAIN.GAMMA)) 174 | print(lr) 175 | 176 | # get one batch 177 | blobs = data_layer.forward() 178 | 179 | feed_dict = { 180 | self.net.data: blobs['data'], 181 | self.net.im_info: blobs['im_info'], 182 | self.net.keep_prob: 0.5, 183 | self.net.gt_boxes: blobs['gt_boxes'], 184 | self.net.gt_ishard: blobs['gt_ishard'], 185 | self.net.dontcare_areas: blobs['dontcare_areas'] 186 | } 187 | res_fetches = [] 188 | fetch_list = [ 189 | total_loss, model_loss, rpn_cross_entropy, rpn_loss_box, 190 | summary_op, train_op 191 | ] + res_fetches 192 | 193 | total_loss_val, model_loss_val, rpn_loss_cls_val, rpn_loss_box_val, \ 194 | summary_str, _ = sess.run(fetches=fetch_list, feed_dict=feed_dict) 195 | 196 | self.writer.add_summary( 197 | summary=summary_str, global_step=global_step.eval()) 198 | 199 | _diff_time = timer.toc(average=False) 200 | 201 | if (iter) % (cfg.TRAIN.DISPLAY) == 0: 202 | print( 203 | 'iter: %d / %d, total loss: %.4f, model loss: %.4f, rpn_loss_cls: %.4f, rpn_loss_box: %.4f, lr: %f' % \ 204 | (iter, max_iters, total_loss_val, model_loss_val, rpn_loss_cls_val, rpn_loss_box_val, lr.eval())) 205 | print('speed: {:.3f}s / iter'.format(_diff_time)) 206 | 207 | if (iter + 1) % cfg.TRAIN.SNAPSHOT_ITERS == 0: 208 | last_snapshot_iter = iter 209 | self.snapshot(sess, iter) 210 | 211 | if last_snapshot_iter != iter: 212 | self.snapshot(sess, iter) 213 | 214 | 215 | def get_training_roidb(imdb): 216 | """Returns a roidb (Region of Interest database) for use in training.""" 217 | if cfg.TRAIN.USE_FLIPPED: 218 | print('Appending horizontally-flipped training examples...') 219 | imdb.append_flipped_images() 220 | print('done') 221 | 222 | print('Preparing training data...') 223 | if cfg.TRAIN.HAS_RPN: 224 | rdl_roidb.prepare_roidb(imdb) 225 | else: 226 | rdl_roidb.prepare_roidb(imdb) 227 | print('done') 228 | 229 | return imdb.roidb 230 | 231 | 232 | def get_data_layer(roidb, num_classes): 233 | """return a data layer.""" 234 | if cfg.TRAIN.HAS_RPN: 235 | if cfg.IS_MULTISCALE: 236 | # obsolete 237 | # layer = GtDataLayer(roidb) 238 | raise NotImplementedError("Calling caffe modules...")
239 | else: 240 | layer = RoIDataLayer(roidb, num_classes) 241 | else: 242 | layer = RoIDataLayer(roidb, num_classes) 243 | 244 | return layer 245 | 246 | 247 | def train_net(network, 248 | imdb, 249 | roidb, 250 | output_dir, 251 | log_dir, 252 | pretrained_model=None, 253 | max_iters=40000, 254 | restore=False): 255 | """Train a Fast R-CNN network.""" 256 | 257 | config = tf.ConfigProto(allow_soft_placement=True) 258 | config.gpu_options.allocator_type = 'BFC' 259 | config.gpu_options.per_process_gpu_memory_fraction = 0.75 260 | with tf.Session(config=config) as sess: 261 | sw = SolverWrapper( 262 | sess, 263 | network, 264 | imdb, 265 | roidb, 266 | output_dir, 267 | logdir=log_dir, 268 | pretrained_model=pretrained_model) 269 | print('Solving...') 270 | sw.train_model(sess, max_iters, restore=restore) 271 | print('done solving') 272 | -------------------------------------------------------------------------------- /ctpn/lib/networks/VGGnet_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # _Author_: xiaofeng 4 | # Date: 2018-04-22 21:45:13 5 | # Last Modified by: xiaofeng 6 | # Last Modified time: 2018-04-22 21:45:13 7 | 8 | import tensorflow as tf 9 | import numpy as np 10 | from .network import Network 11 | from ..fast_rcnn.config import cfg 12 | 13 | 14 | class VGGnet_test(Network): 15 | def __init__(self, trainable=True): 16 | self.inputs = [] 17 | self.data = tf.placeholder(tf.float32, shape=[None, None, None, 3]) 18 | # a list of [image_height, image_width, scale_ratios] 19 | self.im_info = tf.placeholder(tf.float32, shape=[None, 3]) 20 | self.keep_prob = tf.placeholder(tf.float32) 21 | self.layers = dict({'data': self.data, 'im_info': self.im_info}) 22 | self.trainable = trainable 23 | self.setup() 24 | 25 | def setup(self): 26 | anchor_scales = cfg.ANCHOR_SCALES 27 | _feat_stride = [16, ] 28 | 29 | (self.feed('data').conv(3, 3, 64, 1, 1, name='conv1_1') 30 | .conv(3, 3, 64, 1, 1, name='conv1_2') 31 | .max_pool(2, 2, 2, 2, padding='VALID', name='pool1') 32 | .conv(3, 3, 128, 1, 1, name='conv2_1') 33 | .conv(3, 3, 128, 1, 1, name='conv2_2') 34 | .max_pool(2, 2, 2, 2, padding='VALID', name='pool2') 35 | .conv(3, 3, 256, 1, 1, name='conv3_1') 36 | .conv(3, 3, 256, 1, 1, name='conv3_2') 37 | .conv(3, 3, 256, 1, 1, name='conv3_3') 38 | .max_pool(2, 2, 2, 2, padding='VALID', name='pool3') 39 | .conv(3, 3, 512, 1, 1, name='conv4_1') 40 | .conv(3, 3, 512, 1, 1, name='conv4_2') 41 | .conv(3, 3, 512, 1, 1, name='conv4_3') 42 | .max_pool(2, 2, 2, 2, padding='VALID', name='pool4') 43 | .conv(3, 3, 512, 1, 1, name='conv5_1') 44 | .conv(3, 3, 512, 1, 1, name='conv5_2') 45 | .conv(3, 3, 512, 1, 1, name='conv5_3')) 46 | 47 | # 3x3 convolution, 512 output channels, stride 1x1 48 | # Use the feature map of VGG's last conv layer for the RPN region proposals 49 | (self.feed('conv5_3').conv(3, 3, 512, 1, 1, name='rpn_conv/3x3')) 50 | # The RPN conv output has 512 channels 51 | # Bidirectional LSTM with 128 hidden units per direction 52 | (self.feed('rpn_conv/3x3').Bilstm(512, 128, 512, name='lstm_o')) 53 | 54 | # Fully-connected layers on top of the LSTM output 55 | (self.feed('lstm_o').lstm_fc( 56 | 512, len(anchor_scales) * 10 * 4, name='rpn_bbox_pred')) 57 | (self.feed('lstm_o').lstm_fc( 58 | 512, len(anchor_scales) * 10 * 2, name='rpn_cls_score')) 59 | 60 | # shape is (1, H, W, Ax2) -> (1, H, WxA, 2) 61 | (self.feed('rpn_cls_score').spatial_reshape_layer( 62 | 2, name='rpn_cls_score_reshape') 63 | .spatial_softmax(name='rpn_cls_prob')) 64 | 65 | # shape is (1, H, WxA, 2) -> (1, H, W, Ax2) 66 | (self.feed('rpn_cls_prob').spatial_reshape_layer( 67 | len(anchor_scales) * 10 *
2, name='rpn_cls_prob_reshape')) 68 | 69 | (self.feed('rpn_cls_prob_reshape', 'rpn_bbox_pred', 'im_info') 70 | .proposal_layer(_feat_stride, anchor_scales, 'TEST', name='rois')) -------------------------------------------------------------------------------- /ctpn/lib/networks/VGGnet_train.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import tensorflow as tf 3 | 4 | from .network import Network 5 | from ..fast_rcnn.config import cfg 6 | 7 | 8 | class VGGnet_train(Network): 9 | def __init__(self, trainable=True): 10 | self.inputs = [] 11 | self.data = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='data') 12 | self.im_info = tf.placeholder(tf.float32, shape=[None, 3], name='im_info') 13 | self.gt_boxes = tf.placeholder(tf.float32, shape=[None, 5], name='gt_boxes') 14 | self.gt_ishard = tf.placeholder(tf.int32, shape=[None], name='gt_ishard') 15 | self.dontcare_areas = tf.placeholder(tf.float32, shape=[None, 4], name='dontcare_areas') 16 | self.keep_prob = tf.placeholder(tf.float32) 17 | self.layers = dict({'data': self.data, 'im_info': self.im_info, 'gt_boxes': self.gt_boxes, \ 18 | 'gt_ishard': self.gt_ishard, 'dontcare_areas': self.dontcare_areas}) 19 | self.trainable = trainable 20 | self.setup() 21 | 22 | def setup(self): 23 | # n_classes = 21 24 | n_classes = cfg.NCLASSES 25 | # anchor_scales = [8, 16, 32] 26 | anchor_scales = cfg.ANCHOR_SCALES 27 | _feat_stride = [16, ] 28 | # net frame 29 | (self.feed('data') 30 | .conv(3, 3, 64, 1, 1, name='conv1_1') 31 | .conv(3, 3, 64, 1, 1, name='conv1_2') 32 | .max_pool(2, 2, 2, 2, padding='VALID', name='pool1') 33 | .conv(3, 3, 128, 1, 1, name='conv2_1') 34 | .conv(3, 3, 128, 1, 1, name='conv2_2') 35 | .max_pool(2, 2, 2, 2, padding='VALID', name='pool2') 36 | .conv(3, 3, 256, 1, 1, name='conv3_1') 37 | .conv(3, 3, 256, 1, 1, name='conv3_2') 38 | .conv(3, 3, 256, 1, 1, name='conv3_3') 39 | .max_pool(2, 2, 2, 2, padding='VALID', name='pool3') 40 | .conv(3, 3, 512, 1, 1, name='conv4_1') 41 | .conv(3, 3, 512, 1, 1, name='conv4_2') 42 | .conv(3, 3, 512, 1, 1, name='conv4_3') 43 | .max_pool(2, 2, 2, 2, padding='VALID', name='pool4') 44 | .conv(3, 3, 512, 1, 1, name='conv5_1') 45 | .conv(3, 3, 512, 1, 1, name='conv5_2') 46 | .conv(3, 3, 512, 1, 1, name='conv5_3')) 47 | # ========= RPN ============ 48 | (self.feed('conv5_3') 49 | .conv(3, 3, 512, 1, 1, name='rpn_conv/3x3')) 50 | 51 | (self.feed('rpn_conv/3x3').Bilstm(512, 128, 512, name='lstm_o')) 52 | (self.feed('lstm_o').lstm_fc(512, len(anchor_scales) * 10 * 4, name='rpn_bbox_pred')) 53 | (self.feed('lstm_o').lstm_fc(512, len(anchor_scales) * 10 * 2, name='rpn_cls_score')) 54 | 55 | # generating training labels on the fly 56 | # output: rpn_labels(HxWxA, 2) rpn_bbox_targets(HxWxA, 4) rpn_bbox_inside_weights rpn_bbox_outside_weights 57 | # Label each anchor and compute its regression targets (in delta form), plus the inside and outside weights 58 | (self.feed('rpn_cls_score', 'gt_boxes', 'gt_ishard', 'dontcare_areas', 'im_info') 59 | .anchor_target_layer(_feat_stride, anchor_scales, name='rpn-data')) 60 | 61 | # shape is (1, H, W, Ax2) -> (1, H, WxA, 2) 62 | # Apply softmax to the raw scores to obtain probabilities between 0 and 1 63 | (self.feed('rpn_cls_score') 64 | .spatial_reshape_layer(2, name='rpn_cls_score_reshape') 65 | .spatial_softmax(name='rpn_cls_prob')) 66 | -------------------------------------------------------------------------------- /ctpn/lib/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .
import factory 2 | from .VGGnet_test import VGGnet_test 3 | from .VGGnet_train import VGGnet_train 4 | -------------------------------------------------------------------------------- /ctpn/lib/networks/factory.py: -------------------------------------------------------------------------------- 1 | from .VGGnet_test import VGGnet_test 2 | from .VGGnet_train import VGGnet_train 3 | 4 | 5 | def get_network(name): 6 | """Get a network by name.""" 7 | if name.split('_')[0] == 'VGGnet': 8 | if name.split('_')[1] == 'test': 9 | return VGGnet_test() 10 | elif name.split('_')[1] == 'train': 11 | return VGGnet_train() 12 | else: 13 | raise KeyError('Unknown network: {}'.format(name)) 14 | else: 15 | raise KeyError('Unknown network: {}'.format(name)) 16 | -------------------------------------------------------------------------------- /ctpn/lib/roi_data_layer/__init__.py: -------------------------------------------------------------------------------- 1 | from . import roidb 2 | -------------------------------------------------------------------------------- /ctpn/lib/roi_data_layer/layer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | 8 | """The data layer used during training to train a Fast R-CNN network. 9 | 10 | RoIDataLayer implements a Caffe Python layer. 11 | """ 12 | 13 | import numpy as np 14 | 15 | # TODO: make fast_rcnn irrelevant 16 | # >>>> obsolete, because it depends on sth outside of this project 17 | from ..fast_rcnn.config import cfg 18 | # <<<< obsolete 19 | from ..roi_data_layer.minibatch import get_minibatch 20 | 21 | 22 | class RoIDataLayer(object): 23 | """Fast R-CNN data layer used for training.""" 24 | 25 | def __init__(self, roidb, num_classes): 26 | """Set the roidb to be used by this layer during training.""" 27 | self._roidb = roidb 28 | self._num_classes = num_classes 29 | self._shuffle_roidb_inds() 30 | 31 | def _shuffle_roidb_inds(self): 32 | """Randomly permute the training roidb.""" 33 | self._perm = np.random.permutation(np.arange(len(self._roidb))) 34 | self._cur = 0 35 | 36 | def _get_next_minibatch_inds(self): 37 | """Return the roidb indices for the next minibatch.""" 38 | 39 | if cfg.TRAIN.HAS_RPN: 40 | if self._cur + cfg.TRAIN.IMS_PER_BATCH >= len(self._roidb): 41 | self._shuffle_roidb_inds() 42 | 43 | db_inds = self._perm[self._cur:self._cur + cfg.TRAIN.IMS_PER_BATCH] 44 | self._cur += cfg.TRAIN.IMS_PER_BATCH 45 | else: 46 | # sample images 47 | db_inds = np.zeros((cfg.TRAIN.IMS_PER_BATCH), dtype=np.int32) 48 | i = 0 49 | while (i < cfg.TRAIN.IMS_PER_BATCH): 50 | ind = self._perm[self._cur] 51 | num_objs = self._roidb[ind]['boxes'].shape[0] 52 | if num_objs != 0: 53 | db_inds[i] = ind 54 | i += 1 55 | 56 | self._cur += 1 57 | if self._cur >= len(self._roidb): 58 | self._shuffle_roidb_inds() 59 | 60 | return db_inds 61 | 62 | def _get_next_minibatch(self): 63 | """Return the blobs to be used for the next minibatch. 64 | 65 | If cfg.TRAIN.USE_PREFETCH is True, then blobs will be computed in a 66 | separate process and made available through self._blob_queue.
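In this project cfg.TRAIN.USE_PREFETCH is False, so the minibatch is built inline by the calls below.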
67 | """ 68 | db_inds = self._get_next_minibatch_inds() 69 | minibatch_db = [self._roidb[i] for i in db_inds] 70 | return get_minibatch(minibatch_db, self._num_classes) 71 | 72 | def forward(self): 73 | """Get blobs and copy them into this layer's top blob vector.""" 74 | blobs = self._get_next_minibatch() 75 | return blobs 76 | -------------------------------------------------------------------------------- /ctpn/lib/roi_data_layer/minibatch.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import numpy.random as npr 6 | 7 | from ..fast_rcnn.config import cfg 8 | from ..utils.blob import prep_im_for_blob, im_list_to_blob 9 | 10 | 11 | def get_minibatch(roidb, num_classes): 12 | """Given a roidb, construct a minibatch sampled from it.""" 13 | num_images = len(roidb) 14 | # Sample random scales to use for each image in this batch 15 | random_scale_inds = npr.randint( 16 | 0, high=len(cfg.TRAIN.SCALES), size=num_images) 17 | assert (cfg.TRAIN.BATCH_SIZE % num_images == 0), \ 18 | 'num_images ({}) must divide BATCH_SIZE ({})'. \ 19 | format(num_images, cfg.TRAIN.BATCH_SIZE) 20 | rois_per_image = cfg.TRAIN.BATCH_SIZE // num_images 21 | fg_rois_per_image = int(np.round(cfg.TRAIN.FG_FRACTION * rois_per_image)) 22 | 23 | # Get the input image blob, formatted for caffe 24 | im_blob, im_scales = _get_image_blob(roidb, random_scale_inds) 25 | 26 | blobs = {'data': im_blob} 27 | 28 | if cfg.TRAIN.HAS_RPN: 29 | assert len(im_scales) == 1, "Single batch only" 30 | assert len(roidb) == 1, "Single batch only" 31 | # gt boxes: (x1, y1, x2, y2, cls) 32 | gt_inds = np.where(roidb[0]['gt_classes'] != 0)[0] 33 | gt_boxes = np.empty((len(gt_inds), 5), dtype=np.float32) 34 | gt_boxes[:, 0:4] = roidb[0]['boxes'][gt_inds, :] * im_scales[0] 35 | gt_boxes[:, 4] = roidb[0]['gt_classes'][gt_inds] 36 | blobs['gt_boxes'] = gt_boxes 37 | blobs['gt_ishard'] = roidb[0]['gt_ishard'][gt_inds] \ 38 | if 'gt_ishard' in roidb[0] else np.zeros(gt_inds.size, dtype=int) 39 | # blobs['gt_ishard'] = roidb[0]['gt_ishard'][gt_inds] 40 | blobs['dontcare_areas'] = roidb[0]['dontcare_areas'] * im_scales[0] \ 41 | if 'dontcare_areas' in roidb[0] else np.zeros([0, 4], dtype=float) 42 | blobs['im_info'] = np.array( 43 | [[im_blob.shape[1], im_blob.shape[2], im_scales[0]]], 44 | dtype=np.float32) 45 | blobs['im_name'] = os.path.basename(roidb[0]['image']) 46 | 47 | else: # not using RPN 48 | # Now, build the region of interest and label blobs 49 | rois_blob = np.zeros((0, 5), dtype=np.float32) 50 | labels_blob = np.zeros((0), dtype=np.float32) 51 | bbox_targets_blob = np.zeros((0, 4 * num_classes), dtype=np.float32) 52 | bbox_inside_blob = np.zeros(bbox_targets_blob.shape, dtype=np.float32) 53 | # all_overlaps = [] 54 | for im_i in range(num_images): 55 | labels, overlaps, im_rois, bbox_targets, bbox_inside_weights \ 56 | = _sample_rois(roidb[im_i], fg_rois_per_image, rois_per_image, 57 | num_classes) 58 | 59 | # Add to RoIs blob 60 | rois = _project_im_rois(im_rois, im_scales[im_i]) 61 | batch_ind = im_i * np.ones((rois.shape[0], 1)) 62 | rois_blob_this_image = np.hstack((batch_ind, rois)) 63 | rois_blob = np.vstack((rois_blob, rois_blob_this_image)) 64 | 65 | # Add to labels, bbox targets, and bbox loss blobs 66 | labels_blob = np.hstack((labels_blob, labels)) 67 | bbox_targets_blob = np.vstack((bbox_targets_blob, bbox_targets)) 68 | bbox_inside_blob = np.vstack((bbox_inside_blob, 69 | bbox_inside_weights)) 70 | # all_overlaps =
np.hstack((all_overlaps, overlaps)) 71 | 72 | # For debug visualizations 73 | # _vis_minibatch(im_blob, rois_blob, labels_blob, all_overlaps) 74 | 75 | blobs['rois'] = rois_blob 76 | blobs['labels'] = labels_blob 77 | 78 | if cfg.TRAIN.BBOX_REG: 79 | blobs['bbox_targets'] = bbox_targets_blob 80 | blobs['bbox_inside_weights'] = bbox_inside_blob 81 | blobs['bbox_outside_weights'] = \ 82 | np.array(bbox_inside_blob > 0).astype(np.float32) 83 | 84 | return blobs 85 | 86 | 87 | def _sample_rois(roidb, fg_rois_per_image, rois_per_image, num_classes): 88 | """Generate a random sample of RoIs comprising foreground and background 89 | examples. 90 | """ 91 | # label = class RoI has max overlap with 92 | labels = roidb['max_classes'] 93 | overlaps = roidb['max_overlaps'] 94 | rois = roidb['boxes'] 95 | 96 | # Select foreground RoIs as those with >= FG_THRESH overlap 97 | fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0] 98 | # Guard against the case when an image has fewer than fg_rois_per_image 99 | # foreground RoIs 100 | fg_rois_per_this_image = np.minimum(fg_rois_per_image, fg_inds.size) 101 | # Sample foreground regions without replacement 102 | if fg_inds.size > 0: 103 | fg_inds = npr.choice( 104 | fg_inds, size=fg_rois_per_this_image, replace=False) 105 | 106 | # Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI) 107 | bg_inds = np.where((overlaps < cfg.TRAIN.BG_THRESH_HI) & 108 | (overlaps >= cfg.TRAIN.BG_THRESH_LO))[0] 109 | # Compute number of background RoIs to take from this image (guarding 110 | # against there being fewer than desired) 111 | bg_rois_per_this_image = rois_per_image - fg_rois_per_this_image 112 | bg_rois_per_this_image = np.minimum(bg_rois_per_this_image, bg_inds.size) 113 | # Sample foreground regions without replacement 114 | if bg_inds.size > 0: 115 | bg_inds = npr.choice( 116 | bg_inds, size=bg_rois_per_this_image, replace=False) 117 | 118 | # The indices that we're selecting (both fg and bg) 119 | keep_inds = np.append(fg_inds, bg_inds) 120 | # Select sampled values from various arrays: 121 | labels = labels[keep_inds] 122 | # Clamp labels for the background RoIs to 0 123 | labels[fg_rois_per_this_image:] = 0 124 | overlaps = overlaps[keep_inds] 125 | rois = rois[keep_inds] 126 | 127 | bbox_targets, bbox_inside_weights = _get_bbox_regression_labels( 128 | roidb['bbox_targets'][keep_inds, :], num_classes) 129 | 130 | return labels, overlaps, rois, bbox_targets, bbox_inside_weights 131 | 132 | 133 | def _get_image_blob(roidb, scale_inds): 134 | """Builds an input blob from the images in the roidb at the specified 135 | scales. 136 | """ 137 | num_images = len(roidb) 138 | processed_ims = [] 139 | im_scales = [] 140 | for i in range(num_images): 141 | im = cv2.imread(roidb[i]['image']) 142 | if roidb[i]['flipped']: 143 | im = im[:, ::-1, :] 144 | target_size = cfg.TRAIN.SCALES[scale_inds[i]] 145 | im, im_scale = prep_im_for_blob(im, cfg.PIXEL_MEANS, target_size, 146 | cfg.TRAIN.MAX_SIZE) 147 | im_scales.append(im_scale) 148 | processed_ims.append(im) 149 | 150 | # Create a blob to hold the input images 151 | blob = im_list_to_blob(processed_ims) 152 | 153 | return blob, im_scales 154 | 155 | 156 | def _project_im_rois(im_rois, im_scale_factor): 157 | """Project image RoIs into the rescaled training image.""" 158 | rois = im_rois * im_scale_factor 159 | return rois 160 | 161 | 162 | def _get_bbox_regression_labels(bbox_target_data, num_classes): 163 | """Bounding-box regression targets are stored in a compact form in the 164 | roidb. 
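    (Each row of bbox_target_data is [cls, tx, ty, tw, th]; the expansion
    described below turns, for example, the toy row [1, 0.1, -0.2, 0.3, 0.05]
    with num_classes == 3 into the 4*K = 12 vector
    [0, 0, 0, 0, 0.1, -0.2, 0.3, 0.05, 0, 0, 0, 0], and bbox_inside_weights
    gets cfg.TRAIN.BBOX_INSIDE_WEIGHTS in the same four slots.)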
165 | 166 | This function expands those targets into the 4-of-4*K representation used 167 | by the network (i.e. only one class has non-zero targets). The loss weights 168 | are similarly expanded. 169 | 170 | Returns: 171 | bbox_target_data (ndarray): N x 4K blob of regression targets 172 | bbox_inside_weights (ndarray): N x 4K blob of loss weights 173 | """ 174 | clss = bbox_target_data[:, 0] 175 | bbox_targets = np.zeros((clss.size, 4 * num_classes), dtype=np.float32) 176 | bbox_inside_weights = np.zeros(bbox_targets.shape, dtype=np.float32) 177 | inds = np.where(clss > 0)[0] 178 | for ind in inds: 179 | cls = clss[ind] 180 | start = 4 * cls 181 | end = start + 4 182 | bbox_targets[ind, start:end] = bbox_target_data[ind, 1:] 183 | bbox_inside_weights[ind, start:end] = cfg.TRAIN.BBOX_INSIDE_WEIGHTS 184 | return bbox_targets, bbox_inside_weights 185 | 186 | 187 | def _vis_minibatch(im_blob, rois_blob, labels_blob, overlaps): 188 | """Visualize a mini-batch for debugging.""" 189 | import matplotlib.pyplot as plt 190 | for i in range(rois_blob.shape[0]): 191 | rois = rois_blob[i, :] 192 | im_ind = rois[0] 193 | roi = rois[1:] 194 | im = im_blob[im_ind, :, :, :].transpose((1, 2, 0)).copy() 195 | im += cfg.PIXEL_MEANS 196 | im = im[:, :, (2, 1, 0)] 197 | im = im.astype(np.uint8) 198 | cls = labels_blob[i] 199 | plt.imshow(im) 200 | print('class: ', cls, ' overlap: ', overlaps[i]) 201 | plt.gca().add_patch( 202 | plt.Rectangle( 203 | (roi[0], roi[1]), 204 | roi[2] - roi[0], 205 | roi[3] - roi[1], 206 | fill=False, 207 | edgecolor='r', 208 | linewidth=3)) 209 | plt.show() 210 | -------------------------------------------------------------------------------- /ctpn/lib/roi_data_layer/roidb.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | import numpy as np 3 | 4 | from ..fast_rcnn.bbox_transform import bbox_transform 5 | from ..fast_rcnn.config import cfg 6 | from ..utils.bbox import bbox_overlaps 7 | 8 | 9 | def prepare_roidb(imdb): 10 | """Enrich the imdb's roidb by adding some derived quantities that 11 | are useful for training. This function precomputes the maximum 12 | overlap, taken over ground-truth boxes, between each ROI and 13 | each ground-truth box. The class with maximum overlap is also 14 | recorded. 15 | """ 16 | sizes = [ 17 | PIL.Image.open(imdb.image_path_at(i)).size 18 | for i in range(imdb.num_images) 19 | ] 20 | roidb = imdb.roidb 21 | for i in range(len(imdb.image_index)): 22 | roidb[i]['image'] = imdb.image_path_at(i) 23 | roidb[i]['width'] = sizes[i][0] 24 | roidb[i]['height'] = sizes[i][1] 25 | # need gt_overlaps as a dense array for argmax 26 | gt_overlaps = roidb[i]['gt_overlaps'].toarray() 27 | # max overlap with gt over classes (columns) 28 | max_overlaps = gt_overlaps.max(axis=1) 29 | # gt class that had the max overlap 30 | max_classes = gt_overlaps.argmax(axis=1) 31 | roidb[i]['max_classes'] = max_classes 32 | roidb[i]['max_overlaps'] = max_overlaps 33 | # sanity checks 34 | # max overlap of 0 => class should be zero (background) 35 | zero_inds = np.where(max_overlaps == 0)[0] 36 | assert all(max_classes[zero_inds] == 0) 37 | # max overlap > 0 => class should not be zero (must be a fg class) 38 | nonzero_inds = np.where(max_overlaps > 0)[0] 39 | assert all(max_classes[nonzero_inds] != 0) 40 | 41 | 42 | def add_bbox_regression_targets(roidb): 43 | """ 44 | Add information needed to train bounding-box regressors. 45 | For each roi find the corresponding gt box, and compute the distance. 
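    (Here 'distance' means the usual Faster R-CNN center/log-size
    parameterization computed by bbox_transform: dx = (gx - ex) / ew,
    dy = (gy - ey) / eh, dw = log(gw / ew), dh = log(gh / eh), where
    (ex, ey, ew, eh) describe the RoI and (gx, gy, gw, gh) the matched
    ground-truth box.)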
46 | then normalize the distances by subtracting the class mean and dividing by the class std
47 | """
48 | assert len(roidb) > 0
49 | assert 'max_classes' in roidb[0], 'Did you call prepare_roidb first?'
50 | 
51 | num_images = len(roidb)
52 | # Infer number of classes from the number of columns in gt_overlaps
53 | num_classes = roidb[0]['gt_overlaps'].shape[1]
54 | for im_i in range(num_images):
55 | rois = roidb[im_i]['boxes']
56 | max_overlaps = roidb[im_i]['max_overlaps']
57 | max_classes = roidb[im_i]['max_classes']
58 | roidb[im_i]['bbox_targets'] = \
59 | _compute_targets(rois, max_overlaps, max_classes)
60 | 
61 | if cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED:
62 | # Use fixed / precomputed "means" and "stds" instead of empirical values
63 | means = np.tile(
64 | np.array(cfg.TRAIN.BBOX_NORMALIZE_MEANS), (num_classes, 1))
65 | stds = np.tile(
66 | np.array(cfg.TRAIN.BBOX_NORMALIZE_STDS), (num_classes, 1))
67 | else:
68 | # Compute values needed for means and stds
69 | # var(x) = E(x^2) - E(x)^2
70 | class_counts = np.zeros((num_classes, 1)) + cfg.EPS
71 | sums = np.zeros((num_classes, 4))
72 | squared_sums = np.zeros((num_classes, 4))
73 | for im_i in range(num_images):
74 | targets = roidb[im_i]['bbox_targets']
75 | for cls in range(1, num_classes):
76 | cls_inds = np.where(targets[:, 0] == cls)[0]
77 | if cls_inds.size > 0:
78 | class_counts[cls] += cls_inds.size
79 | sums[cls, :] += targets[cls_inds, 1:].sum(axis=0)
80 | squared_sums[cls, :] += \
81 | (targets[cls_inds, 1:] ** 2).sum(axis=0)
82 | 
83 | means = sums / class_counts
84 | stds = np.sqrt(squared_sums / class_counts - means**2)
85 | # a very small std would blow up the normalized targets (nan/inf)
86 | assert np.min(stds) > 0.01, \
87 | 'Boxes std is too small, std:{}'.format(stds)
88 | 
89 | print('bbox target means:')
90 | print(means)
91 | print(means[1:, :].mean(axis=0)) # ignore bg class
92 | print('bbox target stdevs:')
93 | print(stds)
94 | print(stds[1:, :].mean(axis=0)) # ignore bg class
95 | 
96 | # Normalize targets
97 | if cfg.TRAIN.BBOX_NORMALIZE_TARGETS:
98 | print("Normalizing targets")
99 | for im_i in range(num_images):
100 | targets = roidb[im_i]['bbox_targets']
101 | for cls in range(1, num_classes):
102 | cls_inds = np.where(targets[:, 0] == cls)[0]
103 | roidb[im_i]['bbox_targets'][cls_inds, 1:] -= means[cls, :]
104 | roidb[im_i]['bbox_targets'][cls_inds, 1:] /= stds[cls, :]
105 | else:
106 | print("NOT normalizing targets")
107 | 
108 | # These values will be needed for making predictions
109 | # (the predictions will need to be unnormalized and uncentered)
110 | return means.ravel(), stds.ravel()
111 | 
112 | 
113 | def _compute_targets(rois, overlaps, labels):
114 | """
115 | Compute bounding-box regression targets for an image.
116 | For each RoI, find the corresponding gt box, then compute the distance.
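    Returns an N x 5 array: column 0 is the class label and columns 1:5 are
    the bbox_transform targets; rows whose RoI overlaps every gt box by less
    than cfg.TRAIN.BBOX_THRESH are left as zeros.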
117 | """ 118 | # Indices of ground-truth ROIs 119 | gt_inds = np.where(overlaps == 1)[0] 120 | if len(gt_inds) == 0: 121 | # Bail if the image has no ground-truth ROIs 122 | return np.zeros((rois.shape[0], 5), dtype=np.float32) 123 | # Indices of examples for which we try to make predictions 124 | ex_inds = np.where(overlaps >= cfg.TRAIN.BBOX_THRESH)[0] 125 | 126 | # Get IoU overlap between each ex ROI and gt ROI 127 | ex_gt_overlaps = bbox_overlaps( 128 | np.ascontiguousarray(rois[ex_inds, :], dtype=np.float), 129 | np.ascontiguousarray(rois[gt_inds, :], dtype=np.float)) 130 | 131 | # Find which gt ROI each ex ROI has max overlap with: 132 | # this will be the ex ROI's gt target 133 | gt_assignment = ex_gt_overlaps.argmax(axis=1) 134 | gt_rois = rois[gt_inds[gt_assignment], :] 135 | ex_rois = rois[ex_inds, :] 136 | 137 | targets = np.zeros((rois.shape[0], 5), dtype=np.float32) 138 | targets[ex_inds, 0] = labels[ex_inds] 139 | targets[ex_inds, 1:] = bbox_transform(ex_rois, gt_rois) 140 | return targets 141 | -------------------------------------------------------------------------------- /ctpn/lib/rpn_msr/anchor_target_layer_tf.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import numpy as np 3 | import numpy.random as npr 4 | 5 | from .generate_anchors import generate_anchors 6 | from ..fast_rcnn.bbox_transform import bbox_transform 7 | from ..fast_rcnn.config import cfg 8 | from ..utils.bbox import bbox_overlaps, bbox_intersections 9 | 10 | DEBUG = False 11 | 12 | 13 | def anchor_target_layer(rpn_cls_score, gt_boxes, gt_ishard, dontcare_areas, im_info, _feat_stride=[16, ], 14 | anchor_scales=[16, ]): 15 | """ 16 | Assign anchors to ground-truth targets. Produces anchor classification 17 | labels and bounding-box regression targets. 18 | Parameters 19 | ---------- 20 | rpn_cls_score: (1, H, W, Ax2) bg/fg scores of previous conv layer 21 | gt_boxes: (G, 5) vstack of [x1, y1, x2, y2, class] 22 | gt_ishard: (G, 1), 1 or 0 indicates difficult or not 23 | dontcare_areas: (D, 4), some areas may contains small objs but no labelling. 
D may be 0 24 | im_info: a list of [image_height, image_width, scale_ratios] 25 | _feat_stride: the downsampling ratio of feature map to the original input image 26 | anchor_scales: the scales to the basic_anchor (basic anchor is [16, 16]) 27 | ---------- 28 | Returns 29 | ---------- 30 | rpn_labels : (HxWxA, 1), for each anchor, 0 denotes bg, 1 fg, -1 dontcare 31 | rpn_bbox_targets: (HxWxA, 4), distances of the anchors to the gt_boxes(may contains some transform) 32 | that are the regression objectives 33 | rpn_bbox_inside_weights: (HxWxA, 4) weights of each boxes, mainly accepts hyper param in cfg 34 | rpn_bbox_outside_weights: (HxWxA, 4) used to balance the fg/bg, 35 | beacuse the numbers of bgs and fgs mays significiantly different 36 | """ 37 | _anchors = generate_anchors(scales=np.array(anchor_scales)) # 生成基本的anchor,一共9个 38 | _num_anchors = _anchors.shape[0] # 9个anchor 39 | 40 | if DEBUG: 41 | print('anchors:') 42 | print(_anchors) 43 | print('anchor shapes:') 44 | print(np.hstack(( 45 | _anchors[:, 2::4] - _anchors[:, 0::4], 46 | _anchors[:, 3::4] - _anchors[:, 1::4], 47 | ))) 48 | _counts = cfg.EPS 49 | _sums = np.zeros((1, 4)) 50 | _squared_sums = np.zeros((1, 4)) 51 | _fg_sum = 0 52 | _bg_sum = 0 53 | _count = 0 54 | 55 | # allow boxes to sit over the edge by a small amount 56 | _allowed_border = 0 57 | # map of shape (..., H, W) 58 | # height, width = rpn_cls_score.shape[1:3] 59 | 60 | im_info = im_info[0] # 图像的高宽及通道数 61 | 62 | # 在feature-map上定位anchor,并加上delta,得到在实际图像中anchor的真实坐标 63 | # Algorithm: 64 | # for each (H, W) location i 65 | # generate 9 anchor boxes centered on cell i 66 | # apply predicted bbox deltas at cell i to each of the 9 anchors 67 | # filter out-of-image anchors 68 | # measure GT overlap 69 | 70 | assert rpn_cls_score.shape[0] == 1, \ 71 | 'Only single item batches are supported' 72 | 73 | # map of shape (..., H, W) 74 | height, width = rpn_cls_score.shape[1:3] # feature-map的高宽 75 | 76 | if DEBUG: 77 | print('AnchorTargetLayer: height', height, 'width', width) 78 | print('') 79 | print('im_size: ({}, {})'.format(im_info[0], im_info[1])) 80 | print('scale: {}'.format(im_info[2])) 81 | print('height, width: ({}, {})'.format(height, width)) 82 | print('rpn: gt_boxes.shape', gt_boxes.shape) 83 | print('rpn: gt_boxes', gt_boxes) 84 | 85 | # 1. 
Generate proposals from bbox deltas and shifted anchors 86 | shift_x = np.arange(0, width) * _feat_stride 87 | shift_y = np.arange(0, height) * _feat_stride 88 | shift_x, shift_y = np.meshgrid(shift_x, shift_y) # in W H order 89 | # K is H x W 90 | shifts = np.vstack((shift_x.ravel(), shift_y.ravel(), 91 | shift_x.ravel(), shift_y.ravel())).transpose() # 生成feature-map和真实image上anchor之间的偏移量 92 | # add A anchors (1, A, 4) to 93 | # cell K shifts (K, 1, 4) to get 94 | # shift anchors (K, A, 4) 95 | # reshape to (K*A, 4) shifted anchors 96 | A = _num_anchors # 9个anchor 97 | K = shifts.shape[0] # 50*37,feature-map的宽乘高的大小 98 | all_anchors = (_anchors.reshape((1, A, 4)) + 99 | shifts.reshape((1, K, 4)).transpose((1, 0, 2))) # 相当于复制宽高的维度,然后相加 100 | all_anchors = all_anchors.reshape((K * A, 4)) 101 | total_anchors = int(K * A) 102 | 103 | # only keep anchors inside the image 104 | # 仅保留那些还在图像内部的anchor,超出图像的都删掉 105 | inds_inside = np.where( 106 | (all_anchors[:, 0] >= -_allowed_border) & 107 | (all_anchors[:, 1] >= -_allowed_border) & 108 | (all_anchors[:, 2] < im_info[1] + _allowed_border) & # width 109 | (all_anchors[:, 3] < im_info[0] + _allowed_border) # height 110 | )[0] 111 | 112 | if DEBUG: 113 | print('total_anchors', total_anchors) 114 | print('inds_inside', len(inds_inside)) 115 | 116 | # keep only inside anchors 117 | anchors = all_anchors[inds_inside, :] # 保留那些在图像内的anchor 118 | if DEBUG: 119 | print('anchors.shape', anchors.shape) 120 | 121 | # 至此,anchor准备好了 122 | # -------------------------------------------------------------- 123 | # label: 1 is positive, 0 is negative, -1 is dont care 124 | # (A) 125 | labels = np.empty((len(inds_inside),), dtype=np.float32) 126 | labels.fill(-1) # 初始化label,均为-1 127 | 128 | # overlaps between the anchors and the gt boxes 129 | # overlaps (ex, gt), shape is A x G 130 | # 计算anchor和gt-box的overlap,用来给anchor上标签 131 | overlaps = bbox_overlaps( 132 | np.ascontiguousarray(anchors, dtype=np.float), 133 | np.ascontiguousarray(gt_boxes, dtype=np.float)) # 假设anchors有x个,gt_boxes有y个,返回的是一个(x,y)的数组 134 | # 存放每一个anchor和每一个gtbox之间的overlap 135 | argmax_overlaps = overlaps.argmax(axis=1) # (A)#找到和每一个gtbox,overlap最大的那个anchor 136 | max_overlaps = overlaps[np.arange(len(inds_inside)), argmax_overlaps] 137 | gt_argmax_overlaps = overlaps.argmax(axis=0) # G#找到每个位置上9个anchor中与gtbox,overlap最大的那个 138 | gt_max_overlaps = overlaps[gt_argmax_overlaps, 139 | np.arange(overlaps.shape[1])] 140 | gt_argmax_overlaps = np.where(overlaps == gt_max_overlaps)[0] 141 | 142 | if not cfg.TRAIN.RPN_CLOBBER_POSITIVES: 143 | # assign bg labels first so that positive labels can clobber them 144 | labels[max_overlaps < cfg.TRAIN.RPN_NEGATIVE_OVERLAP] = 0 # 先给背景上标签,小于0.3overlap的 145 | 146 | # fg label: for each gt, anchor with highest overlap 147 | labels[gt_argmax_overlaps] = 1 # 每个位置上的9个anchor中overlap最大的认为是前景 148 | # fg label: above threshold IOU 149 | labels[max_overlaps >= cfg.TRAIN.RPN_POSITIVE_OVERLAP] = 1 # overlap大于0.7的认为是前景 150 | 151 | if cfg.TRAIN.RPN_CLOBBER_POSITIVES: 152 | # assign bg labels last so that negative labels can clobber positives 153 | labels[max_overlaps < cfg.TRAIN.RPN_NEGATIVE_OVERLAP] = 0 154 | 155 | # preclude dontcare areas 156 | if dontcare_areas is not None and dontcare_areas.shape[0] > 0: # 这里我们暂时不考虑有doncare_area的存在 157 | # intersec shape is D x A 158 | intersecs = bbox_intersections( 159 | np.ascontiguousarray(dontcare_areas, dtype=np.float), # D x 4 160 | np.ascontiguousarray(anchors, dtype=np.float) # A x 4 161 | ) 162 | intersecs_ = intersecs.sum(axis=0) # A x 1 
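        # anchors whose accumulated intersection with the dontcare regions
        # exceeds the threshold below are marked -1 and ignored by the loss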
163 | labels[intersecs_ > cfg.TRAIN.DONTCARE_AREA_INTERSECTION_HI] = -1 164 | 165 | # 这里我们暂时不考虑难样本的问题 166 | # preclude hard samples that are highly occlusioned, truncated or difficult to see 167 | if cfg.TRAIN.PRECLUDE_HARD_SAMPLES and gt_ishard is not None and gt_ishard.shape[0] > 0: 168 | assert gt_ishard.shape[0] == gt_boxes.shape[0] 169 | gt_ishard = gt_ishard.astype(int) 170 | gt_hardboxes = gt_boxes[gt_ishard == 1, :] 171 | if gt_hardboxes.shape[0] > 0: 172 | # H x A 173 | hard_overlaps = bbox_overlaps( 174 | np.ascontiguousarray(gt_hardboxes, dtype=np.float), # H x 4 175 | np.ascontiguousarray(anchors, dtype=np.float)) # A x 4 176 | hard_max_overlaps = hard_overlaps.max(axis=0) # (A) 177 | labels[hard_max_overlaps >= cfg.TRAIN.RPN_POSITIVE_OVERLAP] = -1 178 | max_intersec_label_inds = hard_overlaps.argmax(axis=1) # H x 1 179 | labels[max_intersec_label_inds] = -1 # 180 | 181 | # subsample positive labels if we have too many 182 | # 对正样本进行采样,如果正样本的数量太多的话 183 | # 限制正样本的数量不超过128个 184 | # TODO 这个后期可能还需要修改,毕竟如果使用的是字符的片段,那个正样本的数量是很多的。 185 | num_fg = int(cfg.TRAIN.RPN_FG_FRACTION * cfg.TRAIN.RPN_BATCHSIZE) 186 | fg_inds = np.where(labels == 1)[0] 187 | if len(fg_inds) > num_fg: 188 | disable_inds = npr.choice( 189 | fg_inds, size=(len(fg_inds) - num_fg), replace=False) # 随机去除掉一些正样本 190 | labels[disable_inds] = -1 # 变为-1 191 | 192 | # subsample negative labels if we have too many 193 | # 对负样本进行采样,如果负样本的数量太多的话 194 | # 正负样本总数是256,限制正样本数目最多128, 195 | # 如果正样本数量小于128,差的那些就用负样本补上,凑齐256个样本 196 | num_bg = cfg.TRAIN.RPN_BATCHSIZE - np.sum(labels == 1) 197 | bg_inds = np.where(labels == 0)[0] 198 | if len(bg_inds) > num_bg: 199 | disable_inds = npr.choice( 200 | bg_inds, size=(len(bg_inds) - num_bg), replace=False) 201 | labels[disable_inds] = -1 202 | # print "was %s inds, disabling %s, now %s inds" % ( 203 | # len(bg_inds), len(disable_inds), np.sum(labels == 0)) 204 | 205 | # 至此, 上好标签,开始计算rpn-box的真值 206 | # -------------------------------------------------------------- 207 | bbox_targets = np.zeros((len(inds_inside), 4), dtype=np.float32) 208 | bbox_targets = _compute_targets(anchors, gt_boxes[argmax_overlaps, :]) # 根据anchor和gtbox计算得真值(anchor和gtbox之间的偏差) 209 | 210 | bbox_inside_weights = np.zeros((len(inds_inside), 4), dtype=np.float32) 211 | bbox_inside_weights[labels == 1, :] = np.array(cfg.TRAIN.RPN_BBOX_INSIDE_WEIGHTS) # 内部权重,前景就给1,其他是0 212 | 213 | bbox_outside_weights = np.zeros((len(inds_inside), 4), dtype=np.float32) 214 | if cfg.TRAIN.RPN_POSITIVE_WEIGHT < 0: # 暂时使用uniform 权重,也就是正样本是1,负样本是0 215 | # uniform weighting of examples (given non-uniform sampling) 216 | num_examples = np.sum(labels >= 0) + 1 217 | # positive_weights = np.ones((1, 4)) * 1.0 / num_examples 218 | # negative_weights = np.ones((1, 4)) * 1.0 / num_examples 219 | positive_weights = np.ones((1, 4)) 220 | negative_weights = np.zeros((1, 4)) 221 | else: 222 | assert ((cfg.TRAIN.RPN_POSITIVE_WEIGHT > 0) & 223 | (cfg.TRAIN.RPN_POSITIVE_WEIGHT < 1)) 224 | positive_weights = (cfg.TRAIN.RPN_POSITIVE_WEIGHT / 225 | (np.sum(labels == 1)) + 1) 226 | negative_weights = ((1.0 - cfg.TRAIN.RPN_POSITIVE_WEIGHT) / 227 | (np.sum(labels == 0)) + 1) 228 | bbox_outside_weights[labels == 1, :] = positive_weights # 外部权重,前景是1,背景是0 229 | bbox_outside_weights[labels == 0, :] = negative_weights 230 | 231 | if DEBUG: 232 | _sums += bbox_targets[labels == 1, :].sum(axis=0) 233 | _squared_sums += (bbox_targets[labels == 1, :] ** 2).sum(axis=0) 234 | _counts += np.sum(labels == 1) 235 | means = _sums / _counts 236 | stds = np.sqrt(_squared_sums / 
_counts - means ** 2) 237 | print('means:') 238 | print(means) 239 | print('stdevs:') 240 | print(stds) 241 | 242 | # map up to original set of anchors 243 | # 一开始是将超出图像范围的anchor直接丢掉的,现在在加回来 244 | labels = _unmap(labels, total_anchors, inds_inside, fill=-1) # 这些anchor的label是-1,也即dontcare 245 | bbox_targets = _unmap(bbox_targets, total_anchors, inds_inside, fill=0) # 这些anchor的真值是0,也即没有值 246 | bbox_inside_weights = _unmap(bbox_inside_weights, total_anchors, inds_inside, fill=0) # 内部权重以0填充 247 | bbox_outside_weights = _unmap(bbox_outside_weights, total_anchors, inds_inside, fill=0) # 外部权重以0填充 248 | 249 | if DEBUG: 250 | print('rpn: max max_overlap', np.max(max_overlaps)) 251 | print('rpn: num_positive', np.sum(labels == 1)) 252 | print('rpn: num_negative', np.sum(labels == 0)) 253 | _fg_sum += np.sum(labels == 1) 254 | _bg_sum += np.sum(labels == 0) 255 | _count += 1 256 | print('rpn: num_positive avg', _fg_sum / _count) 257 | print('rpn: num_negative avg', _bg_sum / _count) 258 | 259 | # labels 260 | labels = labels.reshape((1, height, width, A)) # reshap一下label 261 | rpn_labels = labels 262 | 263 | # bbox_targets 264 | bbox_targets = bbox_targets \ 265 | .reshape((1, height, width, A * 4)) # reshape 266 | 267 | rpn_bbox_targets = bbox_targets 268 | # bbox_inside_weights 269 | bbox_inside_weights = bbox_inside_weights \ 270 | .reshape((1, height, width, A * 4)) 271 | 272 | rpn_bbox_inside_weights = bbox_inside_weights 273 | 274 | # bbox_outside_weights 275 | bbox_outside_weights = bbox_outside_weights \ 276 | .reshape((1, height, width, A * 4)) 277 | rpn_bbox_outside_weights = bbox_outside_weights 278 | 279 | return rpn_labels, rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights 280 | 281 | 282 | def _unmap(data, count, inds, fill=0): 283 | """ Unmap a subset of item (data) back to the original set of items (of 284 | size count) """ 285 | if len(data.shape) == 1: 286 | ret = np.empty((count,), dtype=np.float32) 287 | ret.fill(fill) 288 | ret[inds] = data 289 | else: 290 | ret = np.empty((count,) + data.shape[1:], dtype=np.float32) 291 | ret.fill(fill) 292 | ret[inds, :] = data 293 | return ret 294 | 295 | 296 | def _compute_targets(ex_rois, gt_rois): 297 | """Compute bounding-box regression targets for an image.""" 298 | 299 | assert ex_rois.shape[0] == gt_rois.shape[0] 300 | assert ex_rois.shape[1] == 4 301 | assert gt_rois.shape[1] == 5 302 | 303 | return bbox_transform(ex_rois, gt_rois[:, :4]).astype(np.float32, copy=False) 304 | -------------------------------------------------------------------------------- /ctpn/lib/rpn_msr/generate_anchors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def generate_basic_anchors(sizes, base_size=16): 5 | base_anchor = np.array([0, 0, base_size - 1, base_size - 1], np.int32) 6 | anchors = np.zeros((len(sizes), 4), np.int32) 7 | index = 0 8 | for h, w in sizes: 9 | anchors[index] = scale_anchor(base_anchor, h, w) 10 | index += 1 11 | return anchors 12 | 13 | 14 | def scale_anchor(anchor, h, w): 15 | x_ctr = (anchor[0] + anchor[2]) * 0.5 16 | y_ctr = (anchor[1] + anchor[3]) * 0.5 17 | scaled_anchor = anchor.copy() 18 | scaled_anchor[0] = x_ctr - w / 2 # xmin 19 | scaled_anchor[2] = x_ctr + w / 2 # xmax 20 | scaled_anchor[1] = y_ctr - h / 2 # ymin 21 | scaled_anchor[3] = y_ctr + h / 2 # ymax 22 | return scaled_anchor 23 | 24 | 25 | def generate_anchors(base_size=16, ratios=[0.5, 1, 2], 26 | scales=2 ** np.arange(3, 6)): 27 | heights = [11, 16, 23, 33, 48, 68, 97, 139, 
198, 283] 28 | widths = [16] 29 | sizes = [] 30 | for h in heights: 31 | for w in widths: 32 | sizes.append((h, w)) 33 | return generate_basic_anchors(sizes) 34 | 35 | 36 | if __name__ == '__main__': 37 | import time 38 | 39 | t = time.time() 40 | a = generate_anchors() 41 | print(time.time() - t) 42 | print(a) 43 | from IPython import embed 44 | 45 | embed() 46 | -------------------------------------------------------------------------------- /ctpn/lib/rpn_msr/proposal_layer_tf.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import numpy as np 3 | 4 | from .generate_anchors import generate_anchors 5 | from ..fast_rcnn.bbox_transform import bbox_transform_inv, clip_boxes 6 | from ..fast_rcnn.config import cfg 7 | from ..fast_rcnn.nms_wrapper import nms 8 | 9 | DEBUG = False 10 | # DEBUG=True 11 | """ 12 | Outputs object detection proposals by applying estimated bounding-box 13 | transformations to a set of regular boxes (called "anchors"). 14 | """ 15 | 16 | 17 | def proposal_layer(rpn_cls_prob_reshape, rpn_bbox_pred, im_info, cfg_key, _feat_stride=[16, ], anchor_scales=[16, ]): 18 | """ 19 | Parameters 20 | ---------- 21 | rpn_cls_prob_reshape: (1 , H , W , Ax2) outputs of RPN, prob of bg or fg 22 | NOTICE: the old version is ordered by (1, H, W, 2, A) !!!! 23 | rpn_bbox_pred: (1 , H , W , Ax4), rgs boxes output of RPN 24 | im_info: a list of [image_height, image_width, scale_ratios] 25 | cfg_key: 'TRAIN' or 'TEST' 26 | _feat_stride: the downsampling ratio of feature map to the original input image 27 | anchor_scales: the scales to the basic_anchor (basic anchor is [16, 16]) 28 | ---------- 29 | Returns 30 | ---------- 31 | rpn_rois : (1 x H x W x A, 5) e.g. [0, x1, y1, x2, y2] 32 | 33 | # Algorithm: 34 | # 35 | # for each (H, W) location i 36 | # generate A anchor boxes centered on cell i 37 | # apply predicted bbox deltas at cell i to each of the A anchors 38 | # clip predicted boxes to image 39 | # remove predicted boxes with either height or width < threshold 40 | # sort all (proposal, score) pairs by score from highest to lowest 41 | # take top pre_nms_topN proposals before NMS 42 | # apply NMS with threshold 0.7 to remaining proposals 43 | # take after_nms_topN proposals after NMS 44 | # return the top proposals (-> RoIs top, scores top) 45 | #layer_params = yaml.load(self.param_str_) 46 | 47 | """ 48 | cfg_key = cfg_key.decode('ascii') 49 | _anchors = generate_anchors( 50 | scales=np.array(anchor_scales)) # 生成基本的9个anchor 51 | # print('anchors', _anchors) 52 | _num_anchors = _anchors.shape[0] # 9个anchor 53 | 54 | im_info = im_info[0] # 原始图像的高宽、缩放尺度 55 | 56 | assert rpn_cls_prob_reshape.shape[0] == 1, \ 57 | 'Only single item batches are supported' 58 | 59 | pre_nms_topN = cfg[cfg_key].RPN_PRE_NMS_TOP_N # 12000,在做nms之前,最多保留的候选box数目 60 | post_nms_topN = cfg[cfg_key].RPN_POST_NMS_TOP_N # 2000,做完nms之后,最多保留的box的数目 61 | nms_thresh = cfg[cfg_key].RPN_NMS_THRESH # nms用参数,阈值是0.7 62 | min_size = cfg[cfg_key].RPN_MIN_SIZE # 候选box的最小尺寸,目前是16,高宽均要大于16 63 | # TODO 后期需要修改这个最小尺寸,改为8? 
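    # note: generate_anchors above is the CTPN variant (fixed width 16, ten
    # heights from 11 to 283), so _num_anchors is 10 here, not the 9 of
    # vanilla Faster R-CNN that the older comments mention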
64 | 65 | height, width = rpn_cls_prob_reshape.shape[1:3] # feature-map的高宽 66 | 67 | # the first set of _num_anchors channels are bg probs 68 | # the second set are the fg probs, which we want 69 | # (1, H, W, A) 70 | scores = np.reshape(np.reshape(rpn_cls_prob_reshape, [1, height, width, _num_anchors, 2])[:, :, :, :, 1], 71 | [1, height, width, _num_anchors]) 72 | # 提取到object的分数,non-object的我们不关心 73 | # 并reshape到1*H*W*9 74 | 75 | bbox_deltas = rpn_bbox_pred # 模型输出的pred是相对值,需要进一步处理成真实图像中的坐标 76 | # im_info = bottom[2].data[0, :] 77 | 78 | if DEBUG: 79 | print('im_size: ({}, {})'.format(im_info[0], im_info[1])) 80 | print('scale: {}'.format(im_info[2])) 81 | 82 | # 1. Generate proposals from bbox deltas and shifted anchors 83 | if DEBUG: 84 | print('score map size: {}'.format(scores.shape)) 85 | 86 | # Enumerate all shifts 87 | # 同anchor-target-layer-tf这个文件一样,生成anchor的shift,进一步得到整张图像上的所有anchor 88 | shift_x = np.arange(0, width) * _feat_stride 89 | shift_y = np.arange(0, height) * _feat_stride 90 | shift_x, shift_y = np.meshgrid(shift_x, shift_y) 91 | shifts = np.vstack((shift_x.ravel(), shift_y.ravel(), 92 | shift_x.ravel(), shift_y.ravel())).transpose() 93 | 94 | # Enumerate all shifted anchors: 95 | # 96 | # add A anchors (1, A, 4) to 97 | # cell K shifts (K, 1, 4) to get 98 | # shift anchors (K, A, 4) 99 | # reshape to (K*A, 4) shifted anchors 100 | A = _num_anchors 101 | K = shifts.shape[0] 102 | anchors = _anchors.reshape((1, A, 4)) + \ 103 | shifts.reshape((1, K, 4)).transpose((1, 0, 2)) 104 | anchors = anchors.reshape((K * A, 4)) # 这里得到的anchor就是整张图像上的所有anchor 105 | 106 | # Transpose and reshape predicted bbox transformations to get them 107 | # into the same order as the anchors: 108 | # bbox deltas will be (1, 4 * A, H, W) format 109 | # transpose to (1, H, W, 4 * A) 110 | # reshape to (1 * H * W * A, 4) where rows are ordered by (h, w, a) 111 | # in slowest to fastest order 112 | bbox_deltas = bbox_deltas.reshape((-1, 4)) # (HxWxA, 4) 113 | 114 | # Same story for the scores: 115 | scores = scores.reshape((-1, 1)) 116 | 117 | # Convert anchors into proposals via bbox transformations 118 | proposals = bbox_transform_inv(anchors, bbox_deltas) # 做逆变换,得到box在图像上的真实坐标 119 | 120 | # 2. clip predicted boxes to image 121 | # 将所有的proposal修建一下,超出图像范围的将会被修剪掉 122 | proposals = clip_boxes(proposals, im_info[:2]) 123 | 124 | # 3. remove predicted boxes with either height or width < threshold 125 | # (NOTE: convert min_size to input image scale stored in im_info[2]) 126 | # 移除那些proposal小于一定尺寸的proposal 127 | keep = _filter_boxes(proposals, min_size * im_info[2]) 128 | proposals = proposals[keep, :] # 保留剩下的proposal 129 | scores = scores[keep] 130 | bbox_deltas = bbox_deltas[keep, :] 131 | 132 | # # remove irregular boxes, too fat too tall 133 | # keep = _filter_irregular_boxes(proposals) 134 | # proposals = proposals[keep, :] 135 | # scores = scores[keep] 136 | 137 | # 4. sort all (proposal, score) pairs by score from highest to lowest 138 | # 5. take top pre_nms_topN (e.g. 6000) 139 | order = scores.ravel().argsort()[::-1] # score按得分的高低进行排序 140 | if pre_nms_topN > 0: # 保留12000个proposal进去做nms 141 | order = order[:pre_nms_topN] 142 | proposals = proposals[order, :] 143 | scores = scores[order] 144 | bbox_deltas = bbox_deltas[order, :] 145 | 146 | # 6. apply nms (e.g. threshold = 0.7) 147 | # 7. take after_nms_topN (e.g. 300) 148 | # 8. 
return the top proposals (-> RoIs top) 149 | keep = nms(np.hstack((proposals, scores)), 150 | nms_thresh) # 进行nms操作,保留2000个proposal 151 | if post_nms_topN > 0: 152 | keep = keep[:post_nms_topN] 153 | proposals = proposals[keep, :] 154 | scores = scores[keep] 155 | bbox_deltas = bbox_deltas[keep, :] 156 | 157 | # Output rois blob 158 | # Our RPN implementation only supports a single input image, so all 159 | # batch inds are 0 160 | blob = np.hstack((scores.astype(np.float32, copy=False), 161 | proposals.astype(np.float32, copy=False))) 162 | 163 | return blob, bbox_deltas 164 | 165 | 166 | def _filter_boxes(boxes, min_size): 167 | """Remove all boxes with any side smaller than min_size.""" 168 | ws = boxes[:, 2] - boxes[:, 0] + 1 169 | hs = boxes[:, 3] - boxes[:, 1] + 1 170 | keep = np.where((ws >= min_size) & (hs >= min_size))[0] 171 | return keep 172 | 173 | 174 | def _filter_irregular_boxes(boxes, min_ratio=0.2, max_ratio=5): 175 | """Remove all boxes with any side smaller than min_size.""" 176 | ws = boxes[:, 2] - boxes[:, 0] + 1 177 | hs = boxes[:, 3] - boxes[:, 1] + 1 178 | rs = ws / hs 179 | keep = np.where((rs <= max_ratio) & (rs >= min_ratio))[0] 180 | return keep 181 | -------------------------------------------------------------------------------- /ctpn/models: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/ctpn/models -------------------------------------------------------------------------------- /ctpn/prepare_training_data/ToVoc.py: -------------------------------------------------------------------------------- 1 | from xml.dom.minidom import Document 2 | import cv2 3 | import os 4 | import glob 5 | import shutil 6 | import numpy as np 7 | 8 | def generate_xml(name, lines, img_size, class_sets, doncateothers=True): 9 | doc = Document() 10 | 11 | def append_xml_node_attr(child, parent=None, text=None): 12 | ele = doc.createElement(child) 13 | if not text is None: 14 | text_node = doc.createTextNode(text) 15 | ele.appendChild(text_node) 16 | parent = doc if parent is None else parent 17 | parent.appendChild(ele) 18 | return ele 19 | 20 | img_name = name + '.jpg' 21 | # create header 22 | annotation = append_xml_node_attr('annotation') 23 | append_xml_node_attr('folder', parent=annotation, text='text') 24 | append_xml_node_attr('filename', parent=annotation, text=img_name) 25 | source = append_xml_node_attr('source', parent=annotation) 26 | append_xml_node_attr('database', parent=source, text='coco_text_database') 27 | append_xml_node_attr('annotation', parent=source, text='text') 28 | append_xml_node_attr('image', parent=source, text='text') 29 | append_xml_node_attr('flickrid', parent=source, text='000000') 30 | owner = append_xml_node_attr('owner', parent=annotation) 31 | append_xml_node_attr('name', parent=owner, text='ms') 32 | size = append_xml_node_attr('size', annotation) 33 | append_xml_node_attr('width', size, str(img_size[1])) 34 | append_xml_node_attr('height', size, str(img_size[0])) 35 | append_xml_node_attr('depth', size, str(img_size[2])) 36 | append_xml_node_attr('segmented', parent=annotation, text='0') 37 | 38 | # create objects 39 | objs = [] 40 | for line in lines: 41 | splitted_line = line.strip().lower().split() 42 | cls = splitted_line[0].lower() 43 | if not doncateothers and cls not in class_sets: 44 | continue 45 | cls = 'dontcare' if cls not in class_sets else cls 46 | if cls == 'dontcare': 47 | continue 48 | obj = 
append_xml_node_attr('object', parent=annotation) 49 | occlusion = int(0) 50 | x1, y1, x2, y2 = int(float(splitted_line[1]) + 1), int(float(splitted_line[2]) + 1), \ 51 | int(float(splitted_line[3]) + 1), int(float(splitted_line[4]) + 1) 52 | truncation = float(0) 53 | difficult = 1 if _is_hard(cls, truncation, occlusion, x1, y1, x2, y2) else 0 54 | truncted = 0 if truncation < 0.5 else 1 55 | 56 | append_xml_node_attr('name', parent=obj, text=cls) 57 | append_xml_node_attr('pose', parent=obj, text='none') 58 | append_xml_node_attr('truncated', parent=obj, text=str(truncted)) 59 | append_xml_node_attr('difficult', parent=obj, text=str(int(difficult))) 60 | bb = append_xml_node_attr('bndbox', parent=obj) 61 | append_xml_node_attr('xmin', parent=bb, text=str(x1)) 62 | append_xml_node_attr('ymin', parent=bb, text=str(y1)) 63 | append_xml_node_attr('xmax', parent=bb, text=str(x2)) 64 | append_xml_node_attr('ymax', parent=bb, text=str(y2)) 65 | 66 | o = {'class': cls, 'box': np.asarray([x1, y1, x2, y2], dtype=float), \ 67 | 'truncation': truncation, 'difficult': difficult, 'occlusion': occlusion} 68 | objs.append(o) 69 | 70 | return doc, objs 71 | 72 | 73 | def _is_hard(cls, truncation, occlusion, x1, y1, x2, y2): 74 | hard = False 75 | if y2 - y1 < 25 and occlusion >= 2: 76 | hard = True 77 | return hard 78 | if occlusion >= 3: 79 | hard = True 80 | return hard 81 | if truncation > 0.8: 82 | hard = True 83 | return hard 84 | return hard 85 | 86 | 87 | def build_voc_dirs(outdir): 88 | mkdir = lambda dir: os.makedirs(dir) if not os.path.exists(dir) else None 89 | mkdir(outdir) 90 | mkdir(os.path.join(outdir, 'Annotations')) 91 | mkdir(os.path.join(outdir, 'ImageSets')) 92 | mkdir(os.path.join(outdir, 'ImageSets', 'Layout')) 93 | mkdir(os.path.join(outdir, 'ImageSets', 'Main')) 94 | mkdir(os.path.join(outdir, 'ImageSets', 'Segmentation')) 95 | mkdir(os.path.join(outdir, 'JPEGImages')) 96 | mkdir(os.path.join(outdir, 'SegmentationClass')) 97 | mkdir(os.path.join(outdir, 'SegmentationObject')) 98 | return os.path.join(outdir, 'Annotations'), os.path.join(outdir, 'JPEGImages'), os.path.join(outdir, 'ImageSets', 99 | 'Main') 100 | 101 | 102 | if __name__ == '__main__': 103 | _outdir = 'TEXTVOC/VOC2007' 104 | _draw = bool(0) 105 | _dest_label_dir, _dest_img_dir, _dest_set_dir = build_voc_dirs(_outdir) 106 | _doncateothers = bool(1) 107 | for dset in ['train']: 108 | _labeldir = 'label_tmp' 109 | _imagedir = 're_image' 110 | class_sets = ('text', 'dontcare') 111 | class_sets_dict = dict((k, i) for i, k in enumerate(class_sets)) 112 | allclasses = {} 113 | fs = [open(os.path.join(_dest_set_dir, cls + '_' + dset + '.txt'), 'w') for cls in class_sets] 114 | ftrain = open(os.path.join(_dest_set_dir, dset + '.txt'), 'w') 115 | 116 | files = glob.glob(os.path.join(_labeldir, '*.txt')) 117 | files.sort() 118 | for file in files: 119 | path, basename = os.path.split(file) 120 | stem, ext = os.path.splitext(basename) 121 | with open(file, 'r') as f: 122 | lines = f.readlines() 123 | img_file = os.path.join(_imagedir, stem + '.jpg') 124 | 125 | print(img_file) 126 | img = cv2.imread(img_file) 127 | img_size = img.shape 128 | 129 | doc, objs = generate_xml(stem, lines, img_size, class_sets=class_sets, doncateothers=_doncateothers) 130 | 131 | cv2.imwrite(os.path.join(_dest_img_dir, stem + '.jpg'), img) 132 | xmlfile = os.path.join(_dest_label_dir, stem + '.xml') 133 | with open(xmlfile, 'w') as f: 134 | f.write(doc.toprettyxml(indent=' ')) 135 | 136 | ftrain.writelines(stem + '\n') 137 | 138 | cls_in_image = 
set([o['class'] for o in objs]) 139 | 140 | for obj in objs: 141 | cls = obj['class'] 142 | allclasses[cls] = 0 \ 143 | if not cls in list(allclasses.keys()) else allclasses[cls] + 1 144 | 145 | for cls in cls_in_image: 146 | if cls in class_sets: 147 | fs[class_sets_dict[cls]].writelines(stem + ' 1\n') 148 | for cls in class_sets: 149 | if cls not in cls_in_image: 150 | fs[class_sets_dict[cls]].writelines(stem + ' -1\n') 151 | 152 | 153 | (f.close() for f in fs) 154 | ftrain.close() 155 | 156 | print('~~~~~~~~~~~~~~~~~~~') 157 | print(allclasses) 158 | print('~~~~~~~~~~~~~~~~~~~') 159 | shutil.copyfile(os.path.join(_dest_set_dir, 'train.txt'), os.path.join(_dest_set_dir, 'val.txt')) 160 | shutil.copyfile(os.path.join(_dest_set_dir, 'train.txt'), os.path.join(_dest_set_dir, 'trainval.txt')) 161 | for cls in class_sets: 162 | shutil.copyfile(os.path.join(_dest_set_dir, cls + '_train.txt'), 163 | os.path.join(_dest_set_dir, cls + '_trainval.txt')) 164 | shutil.copyfile(os.path.join(_dest_set_dir, cls + '_train.txt'), 165 | os.path.join(_dest_set_dir, cls + '_val.txt')) 166 | -------------------------------------------------------------------------------- /ctpn/prepare_training_data/split_label.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import math 4 | import cv2 as cv 5 | 6 | path = '/media/D/code/OCR/text-detection-ctpn/data/mlt_english+chinese/image' 7 | gt_path = '/media/D/code/OCR/text-detection-ctpn/data/mlt_english+chinese/label' 8 | out_path = 're_image' 9 | if not os.path.exists(out_path): 10 | os.makedirs(out_path) 11 | files = os.listdir(path) 12 | files.sort() 13 | #files=files[:100] 14 | for file in files: 15 | _, basename = os.path.split(file) 16 | if basename.lower().split('.')[-1] not in ['jpg', 'png']: 17 | continue 18 | stem, ext = os.path.splitext(basename) 19 | gt_file = os.path.join(gt_path, 'gt_' + stem + '.txt') 20 | img_path = os.path.join(path, file) 21 | print(img_path) 22 | img = cv.imread(img_path) 23 | img_size = img.shape 24 | im_size_min = np.min(img_size[0:2]) 25 | im_size_max = np.max(img_size[0:2]) 26 | 27 | im_scale = float(600) / float(im_size_min) 28 | if np.round(im_scale * im_size_max) > 1200: 29 | im_scale = float(1200) / float(im_size_max) 30 | re_im = cv.resize(img, None, None, fx=im_scale, fy=im_scale, interpolation=cv.INTER_LINEAR) 31 | re_size = re_im.shape 32 | cv.imwrite(os.path.join(out_path, stem) + '.jpg', re_im) 33 | 34 | with open(gt_file, 'r') as f: 35 | lines = f.readlines() 36 | for line in lines: 37 | splitted_line = line.strip().lower().split(',') 38 | pt_x = np.zeros((4, 1)) 39 | pt_y = np.zeros((4, 1)) 40 | pt_x[0, 0] = int(float(splitted_line[0]) / img_size[1] * re_size[1]) 41 | pt_y[0, 0] = int(float(splitted_line[1]) / img_size[0] * re_size[0]) 42 | pt_x[1, 0] = int(float(splitted_line[2]) / img_size[1] * re_size[1]) 43 | pt_y[1, 0] = int(float(splitted_line[3]) / img_size[0] * re_size[0]) 44 | pt_x[2, 0] = int(float(splitted_line[4]) / img_size[1] * re_size[1]) 45 | pt_y[2, 0] = int(float(splitted_line[5]) / img_size[0] * re_size[0]) 46 | pt_x[3, 0] = int(float(splitted_line[6]) / img_size[1] * re_size[1]) 47 | pt_y[3, 0] = int(float(splitted_line[7]) / img_size[0] * re_size[0]) 48 | 49 | ind_x = np.argsort(pt_x, axis=0) 50 | pt_x = pt_x[ind_x] 51 | pt_y = pt_y[ind_x] 52 | 53 | if pt_y[0] < pt_y[1]: 54 | pt1 = (pt_x[0], pt_y[0]) 55 | pt3 = (pt_x[1], pt_y[1]) 56 | else: 57 | pt1 = (pt_x[1], pt_y[1]) 58 | pt3 = (pt_x[0], pt_y[0]) 59 | 60 | if pt_y[2] 
< pt_y[3]: 61 | pt2 = (pt_x[2], pt_y[2]) 62 | pt4 = (pt_x[3], pt_y[3]) 63 | else: 64 | pt2 = (pt_x[3], pt_y[3]) 65 | pt4 = (pt_x[2], pt_y[2]) 66 | 67 | xmin = int(min(pt1[0], pt2[0])) 68 | ymin = int(min(pt1[1], pt2[1])) 69 | xmax = int(max(pt2[0], pt4[0])) 70 | ymax = int(max(pt3[1], pt4[1])) 71 | 72 | if xmin < 0: 73 | xmin = 0 74 | if xmax > re_size[1] - 1: 75 | xmax = re_size[1] - 1 76 | if ymin < 0: 77 | ymin = 0 78 | if ymax > re_size[0] - 1: 79 | ymax = re_size[0] - 1 80 | 81 | width = xmax - xmin 82 | height = ymax - ymin 83 | 84 | # reimplement 85 | step = 16.0 86 | x_left = [] 87 | x_right = [] 88 | x_left.append(xmin) 89 | x_left_start = int(math.ceil(xmin / 16.0) * 16.0) 90 | if x_left_start == xmin: 91 | x_left_start = xmin + 16 92 | for i in np.arange(x_left_start, xmax, 16): 93 | x_left.append(i) 94 | x_left = np.array(x_left) 95 | 96 | x_right.append(x_left_start - 1) 97 | for i in range(1, len(x_left) - 1): 98 | x_right.append(x_left[i] + 15) 99 | x_right.append(xmax) 100 | x_right = np.array(x_right) 101 | 102 | idx = np.where(x_left == x_right) 103 | x_left = np.delete(x_left, idx, axis=0) 104 | x_right = np.delete(x_right, idx, axis=0) 105 | 106 | if not os.path.exists('label_tmp'): 107 | os.makedirs('label_tmp') 108 | with open(os.path.join('label_tmp', stem) + '.txt', 'a') as f: 109 | for i in range(len(x_left)): 110 | f.writelines("text\t") 111 | f.writelines(str(int(x_left[i]))) 112 | f.writelines("\t") 113 | f.writelines(str(int(ymin))) 114 | f.writelines("\t") 115 | f.writelines(str(int(x_right[i]))) 116 | f.writelines("\t") 117 | f.writelines(str(int(ymax))) 118 | f.writelines("\n") 119 | -------------------------------------------------------------------------------- /ctpn/text_detect.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | # import tensorflow as tf 3 | from .ctpn.detectors import TextDetector 4 | from .ctpn.model import ctpn 5 | from .ctpn.other import draw_boxes 6 | ''' 7 | 进行文区别于识别-网络结构为cnn+rnn 8 | ''' 9 | 10 | 11 | def text_detect(img): 12 | # ctpn网络测到 13 | scores, boxes, img = ctpn(img) 14 | textdetector = TextDetector() 15 | boxes = textdetector.detect(boxes, scores[:, np.newaxis], img.shape[:2]) 16 | # text_recs, tmp = draw_boxes(img, boxes, caption='im_name', wait=True, is_display=False) 17 | text_recs, tmp = draw_boxes( 18 | img, boxes, caption='im_name', wait=True, is_display=True) 19 | return text_recs, tmp, img 20 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import model" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": { 18 | "collapsed": false 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "from glob import glob\n", 23 | "from IPython.display import Image as display\n", 24 | "import numpy as np\n", 25 | "from PIL import Image\n", 26 | "import time\n", 27 | "paths = glob('./test/*.*')\n", 28 | "paths" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "## pytorch crnn" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": { 42 | "collapsed": false 43 | }, 44 | "outputs": [], 45 | "source": [] 46 | }, 47 | { 48 | "cell_type": "code", 49 | 
"execution_count": null, 50 | "metadata": { 51 | "collapsed": false 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "im = Image.open(paths[-2])\n", 56 | "img = np.array(im.convert('RGB'))\n", 57 | "t = time.time()\n", 58 | "result,img,angle = model.model(img,model='crnn', detectAngle=True) ## if model == crnn ,you should install pytorch\n", 59 | "print \"It takes time:{}s\".format(time.time()-t)\n", 60 | "print \"---------------------------------------\"\n", 61 | "print \"图像的文字朝向为:{}度\\n\".format(angle),\"识别结果:\\n\"\n", 62 | "\n", 63 | "for key in result:\n", 64 | " print result[key][1]\n", 65 | " \n", 66 | "Image.fromarray(img)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "## keras crnn " 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": { 80 | "collapsed": false 81 | }, 82 | "outputs": [], 83 | "source": [ 84 | "im = Image.open(paths[-2])\n", 85 | "img = np.array(im.convert('RGB'))\n", 86 | "t = time.time()\n", 87 | "result,img,angle = model.model(img,model='keras',detectAngle=True) ##if model == keras ,you should install keras\n", 88 | "print \"It takes time:{}s\".format(time.time()-t)\n", 89 | "print \"---------------------------------------\"\n", 90 | "print \"图像的文字朝向为:{}度\\n\".format(angle),\"识别结果:\\n\"\n", 91 | "for key in result:\n", 92 | " print result[key][1]\n", 93 | "Image.fromarray(img)" 94 | ] 95 | } 96 | ], 97 | "metadata": { 98 | "kernelspec": { 99 | "display_name": "Python [Root]", 100 | "language": "python", 101 | "name": "Python [Root]" 102 | }, 103 | "language_info": { 104 | "codemirror_mode": { 105 | "name": "ipython", 106 | "version": 2 107 | }, 108 | "file_extension": ".py", 109 | "mimetype": "text/x-python", 110 | "name": "python", 111 | "nbconvert_exporter": "python", 112 | "pygments_lexer": "ipython2", 113 | "version": "2.7.13" 114 | } 115 | }, 116 | "nbformat": 4, 117 | "nbformat_minor": 2 118 | } 119 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import time 3 | from glob import glob 4 | 5 | import numpy as np 6 | from PIL import Image 7 | 8 | import model 9 | # ces 10 | 11 | paths = glob('./test/*.*') 12 | 13 | if __name__ == '__main__': 14 | im = Image.open("./test/test.png") 15 | img = np.array(im.convert('RGB')) 16 | t = time.time() 17 | ''' 18 | result,img,angel分别对应-识别结果,图像的数组,文字旋转角度 19 | ''' 20 | result, img, angle = model.model( 21 | img, model='keras', adjust=True, detectAngle=True) 22 | print("It takes time:{}s".format(time.time() - t)) 23 | print("---------------------------------------") 24 | for key in result: 25 | print(result[key][1]) 26 | -------------------------------------------------------------------------------- /img/tmp1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/img/tmp1.png -------------------------------------------------------------------------------- /img/tmp1识别结果.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/img/tmp1识别结果.png -------------------------------------------------------------------------------- /img/tmp2.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/img/tmp2.jpg -------------------------------------------------------------------------------- /img/tmp2识别结果.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/img/tmp2识别结果.png -------------------------------------------------------------------------------- /keras_model.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | ##添加文本方向 检测模型,自动检测文字方向,0、90、180、270 3 | ##keras版本的OCR识别 4 | 5 | from math import * 6 | 7 | import cv2 8 | import numpy as np 9 | from PIL import Image 10 | 11 | from angle.predict import predict as angle_detect ##文字方向检测 12 | from ctpn.text_detect import text_detect 13 | from ocr.model import predict as ocr 14 | 15 | 16 | def crnnRec(im, text_recs, adjust=False): 17 | """ 18 | crnn模型,ocr识别 19 | @@model, 20 | @@converter, 21 | @@im:Array 22 | @@text_recs:text box 23 | 24 | """ 25 | index = 0 26 | results = {} 27 | xDim, yDim = im.shape[1], im.shape[0] 28 | 29 | for index, rec in enumerate(text_recs): 30 | results[index] = [ 31 | rec, 32 | ] 33 | xlength = int((rec[6] - rec[0]) * 0.1) 34 | ylength = int((rec[7] - rec[1]) * 0.2) 35 | if adjust: 36 | pt1 = (max(1, rec[0] - xlength), max(1, rec[1] - ylength)) 37 | pt2 = (rec[2], rec[3]) 38 | pt3 = (min(rec[6] + xlength, xDim - 2), 39 | min(yDim - 2, rec[7] + ylength)) 40 | pt4 = (rec[4], rec[5]) 41 | else: 42 | pt1 = (max(1, rec[0]), max(1, rec[1])) 43 | pt2 = (rec[2], rec[3]) 44 | pt3 = (min(rec[6], xDim - 2), min(yDim - 2, rec[7])) 45 | pt4 = (rec[4], rec[5]) 46 | 47 | degree = degrees(atan2(pt2[1] - pt1[1], pt2[0] - pt1[0])) ##图像倾斜角度 48 | 49 | partImg = dumpRotateImage(im, degree, pt1, pt2, pt3, pt4) 50 | 51 | image = Image.fromarray(partImg).convert('L') 52 | sim_pred = ocr(image) 53 | 54 | results[index].append(sim_pred) ##识别文字 55 | 56 | return results 57 | 58 | 59 | def dumpRotateImage(img, degree, pt1, pt2, pt3, pt4): 60 | height, width = img.shape[:2] 61 | heightNew = int(width * fabs(sin(radians(degree))) + 62 | height * fabs(cos(radians(degree)))) 63 | widthNew = int(height * fabs(sin(radians(degree))) + 64 | width * fabs(cos(radians(degree)))) 65 | matRotation = cv2.getRotationMatrix2D((width / 2, height / 2), degree, 1) 66 | matRotation[0, 2] += (widthNew - width) / 2 67 | matRotation[1, 2] += (heightNew - height) / 2 68 | imgRotation = cv2.warpAffine( 69 | img, matRotation, (widthNew, heightNew), borderValue=(255, 255, 255)) 70 | pt1 = list(pt1) 71 | pt3 = list(pt3) 72 | 73 | [[pt1[0]], [pt1[1]]] = np.dot(matRotation, 74 | np.array([[pt1[0]], [pt1[1]], [1]])) 75 | [[pt3[0]], [pt3[1]]] = np.dot(matRotation, 76 | np.array([[pt3[0]], [pt3[1]], [1]])) 77 | ydim, xdim = imgRotation.shape[:2] 78 | imgOut = imgRotation[max(1, int(pt1[1])):min(ydim - 1, int(pt3[1])), 79 | max(1, int(pt1[0])):min(xdim - 1, int(pt3[0]))] 80 | # height,width=imgOut.shape[:2] 81 | return imgOut 82 | 83 | 84 | def model(img, adjust=False, detectAngle=False): 85 | """ 86 | @@param:img, 87 | @@param:model,选择的ocr模型,支持keras\pytorch版本 88 | @@param:adjust 调整文字识别结果 89 | @@param:detectAngle,是否检测文字朝向 90 | 91 | """ 92 | angle = 0 93 | if detectAngle: 94 | 95 | angle = angle_detect(img=np.copy(img)) ##文字朝向检测 96 | im = Image.fromarray(img) 97 | if angle == 90: 98 | im = im.transpose(Image.ROTATE_90) 99 | elif angle == 180: 100 | im = im.transpose(Image.ROTATE_180) 101 | elif angle == 
270: 102 | im = im.transpose(Image.ROTATE_270) 103 | img = np.array(im) 104 | 105 | text_recs, tmp, img = text_detect(img) 106 | text_recs = sort_box(text_recs) 107 | result = crnnRec(img, text_recs, model, adjust=adjust) 108 | return result, tmp, angle 109 | 110 | 111 | def sort_box(box): 112 | """ 113 | 对box排序,及页面进行排版 114 | text_recs[index, 0] = x1 115 | text_recs[index, 1] = y1 116 | text_recs[index, 2] = x2 117 | text_recs[index, 3] = y2 118 | text_recs[index, 4] = x3 119 | text_recs[index, 5] = y3 120 | text_recs[index, 6] = x4 121 | text_recs[index, 7] = y4 122 | """ 123 | 124 | box = sorted(box, key=lambda x: sum([x[1], x[3], x[5], x[7]])) 125 | return box 126 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | ##添加文本方向 检测模型,自动检测文字方向,0、90、180、270 3 | from math import * 4 | 5 | import cv2 6 | import numpy as np 7 | from PIL import Image 8 | import sys 9 | 10 | sys.path.append("ocr") 11 | from angle.predict import predict as angle_detect ##文字方向检测 12 | 13 | from crnn.crnn import crnnOcr 14 | 15 | from ctpn.text_detect import text_detect 16 | from ocr.model import predict as ocr 17 | 18 | 19 | def crnnRec(im, text_recs, ocrMode='keras', adjust=False): 20 | """ 21 | crnn模型,ocr识别 22 | @@model, 23 | @@converter, 24 | @@im:Array 25 | @@text_recs:text box 26 | 27 | """ 28 | index = 0 29 | results = {} 30 | xDim, yDim = im.shape[1], im.shape[0] 31 | 32 | for index, rec in enumerate(text_recs): 33 | results[index] = [ 34 | rec, 35 | ] 36 | xlength = int((rec[6] - rec[0]) * 0.1) 37 | ylength = int((rec[7] - rec[1]) * 0.2) 38 | if adjust: 39 | pt1 = (max(1, rec[0] - xlength), max(1, rec[1] - ylength)) 40 | pt2 = (rec[2], rec[3]) 41 | pt3 = (min(rec[6] + xlength, xDim - 2), 42 | min(yDim - 2, rec[7] + ylength)) 43 | pt4 = (rec[4], rec[5]) 44 | else: 45 | pt1 = (max(1, rec[0]), max(1, rec[1])) 46 | pt2 = (rec[2], rec[3]) 47 | pt3 = (min(rec[6], xDim - 2), min(yDim - 2, rec[7])) 48 | pt4 = (rec[4], rec[5]) 49 | 50 | degree = degrees(atan2(pt2[1] - pt1[1], pt2[0] - pt1[0])) ##图像倾斜角度 51 | 52 | partImg = dumpRotateImage(im, degree, pt1, pt2, pt3, pt4) 53 | # 根据ctpn进行识别出的文字区域,进行不同文字区域的crnn识别 54 | image = Image.fromarray(partImg).convert('L') 55 | # 进行识别出的文字识别 56 | if ocrMode == 'keras': 57 | sim_pred = ocr(image) 58 | else: 59 | sim_pred = crnnOcr(image) 60 | 61 | results[index].append(sim_pred) ##识别文字 62 | 63 | return results 64 | 65 | 66 | def dumpRotateImage(img, degree, pt1, pt2, pt3, pt4): 67 | height, width = img.shape[:2] 68 | heightNew = int(width * fabs(sin(radians(degree))) + 69 | height * fabs(cos(radians(degree)))) 70 | widthNew = int(height * fabs(sin(radians(degree))) + 71 | width * fabs(cos(radians(degree)))) 72 | matRotation = cv2.getRotationMatrix2D((width / 2, height / 2), degree, 1) 73 | matRotation[0, 2] += (widthNew - width) / 2 74 | matRotation[1, 2] += (heightNew - height) / 2 75 | imgRotation = cv2.warpAffine( 76 | img, matRotation, (widthNew, heightNew), borderValue=(255, 255, 255)) 77 | pt1 = list(pt1) 78 | pt3 = list(pt3) 79 | 80 | [[pt1[0]], [pt1[1]]] = np.dot(matRotation, 81 | np.array([[pt1[0]], [pt1[1]], [1]])) 82 | [[pt3[0]], [pt3[1]]] = np.dot(matRotation, 83 | np.array([[pt3[0]], [pt3[1]], [1]])) 84 | ydim, xdim = imgRotation.shape[:2] 85 | imgOut = imgRotation[max(1, int(pt1[1])):min(ydim - 1, int(pt3[1])), 86 | max(1, int(pt1[0])):min(xdim - 1, int(pt3[0]))] 87 | # height,width=imgOut.shape[:2] 88 | return imgOut 89 | 90 | 
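# Example top-level usage (a sketch mirroring demo.py; the image path is
# illustrative):
#
#     from PIL import Image
#     import numpy as np
#     img = np.array(Image.open('./test/test.png').convert('RGB'))
#     result, drawn, angle = model(img, model='keras', detectAngle=True)
#     for key in sorted(result):
#         print(result[key][1])  # recognized text for each detected box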
91 | def model(img, model='keras', adjust=False, detectAngle=False): 92 | """ 93 | @@param:img, 94 | @@param:model,选择的ocr模型,支持keras\pytorch版本 95 | @@param:adjust 调整文字识别结果 96 | @@param:detectAngle,是否检测文字朝向 97 | 98 | """ 99 | angle = 0 100 | if detectAngle: 101 | # 进行文字旋转方向检测,分为[0, 90, 180, 270]四种情况 102 | angle = angle_detect(img=np.copy(img)) ##文字朝向检测 103 | print('The angel of this character is:', angle) 104 | im = Image.fromarray(img) 105 | print('Rotate the array of this img!') 106 | if angle == 90: 107 | im = im.transpose(Image.ROTATE_90) 108 | elif angle == 180: 109 | im = im.transpose(Image.ROTATE_180) 110 | elif angle == 270: 111 | im = im.transpose(Image.ROTATE_270) 112 | img = np.array(im) 113 | # 进行图像中的文字区域的识别 114 | text_recs, tmp, img=text_detect(img) 115 | # 识别区域排列 116 | text_recs = sort_box(text_recs) 117 | # 118 | result = crnnRec(img, text_recs, model, adjust=adjust) 119 | return result, tmp, angle 120 | 121 | 122 | def sort_box(box): 123 | """ 124 | 对box排序,及页面进行排版 125 | text_recs[index, 0] = x1 126 | text_recs[index, 1] = y1 127 | text_recs[index, 2] = x2 128 | text_recs[index, 3] = y2 129 | text_recs[index, 4] = x3 130 | text_recs[index, 5] = y3 131 | text_recs[index, 6] = x4 132 | text_recs[index, 7] = y4 133 | """ 134 | 135 | box = sorted(box, key=lambda x: sum([x[1], x[3], x[5], x[7]])) 136 | return box 137 | -------------------------------------------------------------------------------- /pytorch_model.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | ##添加文本方向 检测模型,自动检测文字方向,0、90、180、270 3 | ##pytorch版本的OCR识别 4 | from math import * 5 | 6 | import cv2 7 | import numpy as np 8 | from PIL import Image 9 | 10 | from angle.predict import predict as angle_detect ##文字方向检测 11 | from crnn.crnn import crnnOcr 12 | from ctpn.text_detect import text_detect 13 | 14 | 15 | def crnnRec(im, text_recs, adjust=False): 16 | """ 17 | crnn模型,ocr识别 18 | @@model, 19 | @@converter, 20 | @@im:Array 21 | @@text_recs:text box 22 | 23 | """ 24 | index = 0 25 | results = {} 26 | xDim, yDim = im.shape[1], im.shape[0] 27 | 28 | for index, rec in enumerate(text_recs): 29 | results[index] = [rec, ] 30 | xlength = int((rec[6] - rec[0]) * 0.1) 31 | ylength = int((rec[7] - rec[1]) * 0.2) 32 | if adjust: 33 | pt1 = (max(1, rec[0] - xlength), max(1, rec[1] - ylength)) 34 | pt2 = (rec[2], rec[3]) 35 | pt3 = (min(rec[6] + xlength, xDim - 2), min(yDim - 2, rec[7] + ylength)) 36 | pt4 = (rec[4], rec[5]) 37 | else: 38 | pt1 = (max(1, rec[0]), max(1, rec[1])) 39 | pt2 = (rec[2], rec[3]) 40 | pt3 = (min(rec[6], xDim - 2), min(yDim - 2, rec[7])) 41 | pt4 = (rec[4], rec[5]) 42 | 43 | degree = degrees(atan2(pt2[1] - pt1[1], pt2[0] - pt1[0])) ##图像倾斜角度 44 | 45 | partImg = dumpRotateImage(im, degree, pt1, pt2, pt3, pt4) 46 | 47 | image = Image.fromarray(partImg).convert('L') 48 | sim_pred = crnnOcr(image) 49 | results[index].append(sim_pred) ##识别文字 50 | 51 | return results 52 | 53 | 54 | def dumpRotateImage(img, degree, pt1, pt2, pt3, pt4): 55 | height, width = img.shape[:2] 56 | heightNew = int(width * fabs(sin(radians(degree))) + height * fabs(cos(radians(degree)))) 57 | widthNew = int(height * fabs(sin(radians(degree))) + width * fabs(cos(radians(degree)))) 58 | matRotation = cv2.getRotationMatrix2D((width / 2, height / 2), degree, 1) 59 | matRotation[0, 2] += (widthNew - width) / 2 60 | matRotation[1, 2] += (heightNew - height) / 2 61 | imgRotation = cv2.warpAffine(img, matRotation, (widthNew, heightNew), borderValue=(255, 255, 255)) 62 | pt1 = list(pt1) 63 | 
pt3 = list(pt3) 64 | 65 | [[pt1[0]], [pt1[1]]] = np.dot(matRotation, np.array([[pt1[0]], [pt1[1]], [1]])) 66 | [[pt3[0]], [pt3[1]]] = np.dot(matRotation, np.array([[pt3[0]], [pt3[1]], [1]])) 67 | ydim, xdim = imgRotation.shape[:2] 68 | imgOut = imgRotation[max(1, int(pt1[1])):min(ydim - 1, int(pt3[1])), max(1, int(pt1[0])):min(xdim - 1, int(pt3[0]))] 69 | # height,width=imgOut.shape[:2] 70 | return imgOut 71 | 72 | 73 | def model(img, adjust=False, detectAngle=False): 74 | """ 75 | @@param:img, input image as a numpy array 76 | @@param:adjust, fine-tune the detected text boxes before cropping 77 | @@param:detectAngle, whether to detect the text orientation first 78 | (recognition always uses the pytorch CRNN here) 79 | 80 | """ 81 | angle = 0 82 | if detectAngle: 83 | angle = angle_detect(img=np.copy(img)) ## text orientation detection 84 | im = Image.fromarray(img) 85 | if angle == 90: 86 | im = im.transpose(Image.ROTATE_90) 87 | elif angle == 180: 88 | im = im.transpose(Image.ROTATE_180) 89 | elif angle == 270: 90 | im = im.transpose(Image.ROTATE_270) 91 | img = np.array(im) 92 | 93 | text_recs, tmp, img = text_detect(img) 94 | text_recs = sort_box(text_recs) 95 | result = crnnRec(img, text_recs, adjust=adjust) 96 | return result, tmp, angle 97 | 98 | 99 | def sort_box(box): 100 | """ 101 | Sort the boxes top-to-bottom to lay out the page 102 | text_recs[index, 0] = x1 103 | text_recs[index, 1] = y1 104 | text_recs[index, 2] = x2 105 | text_recs[index, 3] = y2 106 | text_recs[index, 4] = x3 107 | text_recs[index, 5] = y3 108 | text_recs[index, 6] = x4 109 | text_recs[index, 7] = y4 110 | """ 111 | 112 | box = sorted(box, key=lambda x: sum([x[1], x[3], x[5], x[7]])) 113 | return box 114 | -------------------------------------------------------------------------------- /setup-cpu.sh: -------------------------------------------------------------------------------- 1 | conda create -n chinese-ocr python=2.7 pip scipy numpy PIL jupyter ## create the python environment with conda 2 | source activate chinese-ocr 3 | pip install easydict -i https://pypi.tuna.tsinghua.edu.cn/simple/ ## use a domestic (China) mirror for faster downloads 4 | pip install keras==2.0.8 -i https://pypi.tuna.tsinghua.edu.cn/simple/ 5 | pip install Cython opencv-python -i https://pypi.tuna.tsinghua.edu.cn/simple/ 6 | pip install matplotlib -i https://pypi.tuna.tsinghua.edu.cn/simple/ 7 | pip install -U pillow -i https://pypi.tuna.tsinghua.edu.cn/simple/ 8 | pip install h5py lmdb mahotas -i https://pypi.tuna.tsinghua.edu.cn/simple/ 9 | conda install pytorch=0.1.12 torchvision -c soumith 10 | conda install tensorflow=1.3 ## pinned to avoid cuda-related errors 11 | cd ./ctpn/lib/utils 12 | sh make-for-cpu.sh 13 | 14 | 15 | -------------------------------------------------------------------------------- /setup-python3.sh: -------------------------------------------------------------------------------- 1 | source activate base 2 | #conda create -n chinese-ocr3 python=3.6 pip scipy numpy Pillow jupyter 3 | #source activate chinese-ocr3 4 | pip install easydict -i https://pypi.tuna.tsinghua.edu.cn/simple/ 5 | pip install keras -i https://pypi.tuna.tsinghua.edu.cn/simple/ 6 | pip install Cython opencv-python -i https://pypi.tuna.tsinghua.edu.cn/simple/ 7 | pip install matplotlib -i https://pypi.tuna.tsinghua.edu.cn/simple/ 8 | pip install -U pillow -i https://pypi.tuna.tsinghua.edu.cn/simple/ 9 | pip install h5py lmdb mahotas -i https://pypi.tuna.tsinghua.edu.cn/simple/ 10 | pip install futures==3.1.1 -i https://pypi.tuna.tsinghua.edu.cn/simple/ 11 | pip install tensorflow -i https://pypi.tuna.tsinghua.edu.cn/simple/ 12 | cd ./ctpn/lib/utils 13 | ./make-for-cpu.sh 14 | -------------------------------------------------------------------------------- /setup.sh:
-------------------------------------------------------------------------------- 1 | conda create -n chinese-ocr python=2.7 pip scipy numpy PIL jupyter##运用conda 创建python环境 2 | source activate chinese-ocr 3 | pip install easydict -i https://pypi.tuna.tsinghua.edu.cn/simple/ ##选择国内源,速度更快 4 | pip install keras==2.0.8 -i https://pypi.tuna.tsinghua.edu.cn/simple/ 5 | pip install Cython opencv-python -i https://pypi.tuna.tsinghua.edu.cn/simple/ 6 | pip install matplotlib -i https://pypi.tuna.tsinghua.edu.cn/simple/ 7 | pip install -U pillow -i https://pypi.tuna.tsinghua.edu.cn/simple/ 8 | pip install h5py lmdb mahotas -i https://pypi.tuna.tsinghua.edu.cn/simple/ 9 | conda install pytorch=0.1.12 cuda80 torchvision -c soumith 10 | conda install tensorflow=1.3 tensorflow-gpu=1.3 ##解决cuda报错相关问题 11 | cd ./ctpn/lib/utils 12 | sh make.sh 13 | 14 | 15 | -------------------------------------------------------------------------------- /train/create-dataset.sh: -------------------------------------------------------------------------------- 1 | cd create_dataset 2 | python create_dataset.py 3 | -------------------------------------------------------------------------------- /train/create_dataset/create_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import lmdb # install lmdb by "pip install lmdb" 5 | import numpy as np 6 | 7 | 8 | # from genLineText import GenTextImage 9 | 10 | def checkImageIsValid(imageBin): 11 | if imageBin is None: 12 | return False 13 | imageBuf = np.fromstring(imageBin, dtype=np.uint8) 14 | img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE) 15 | if img is None: 16 | return False 17 | imgH, imgW = img.shape[0], img.shape[1] 18 | if imgH * imgW == 0: 19 | return False 20 | return True 21 | 22 | 23 | def writeCache(env, cache): 24 | with env.begin(write=True) as txn: 25 | for k, v in cache.items(): 26 | txn.put(k.encode(), v) 27 | 28 | 29 | def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True): 30 | """ 31 | Create LMDB dataset for CRNN training. 
32 | 33 | ARGS: 34 | outputPath : LMDB output path 35 | imagePathList : list of image paths 36 | labelList : list of corresponding groundtruth texts 37 | lexiconList : (optional) list of lexicon lists 38 | checkValid : if true, check the validity of every image 39 | """ 40 | # print (len(imagePathList) , len(labelList)) 41 | assert (len(imagePathList) == len(labelList)) 42 | nSamples = len(imagePathList) 43 | print('...................') 44 | env = lmdb.open(outputPath, map_size=1099511627776) 45 | 46 | cache = {} 47 | cnt = 1 48 | for i in range(nSamples): 49 | imagePath = imagePathList[i] 50 | label = labelList[i] 51 | if not os.path.exists(imagePath): 52 | print('%s does not exist' % imagePath) 53 | continue 54 | with open(imagePath, 'rb') as f: 55 | imageBin = f.read() 56 | if checkValid: 57 | if not checkImageIsValid(imageBin): 58 | print('%s is not a valid image' % imagePath) 59 | continue 60 | 61 | imageKey = 'image-%09d' % cnt 62 | labelKey = 'label-%09d' % cnt 63 | cache[imageKey] = imageBin 64 | cache[labelKey] = label.encode() 65 | if lexiconList: 66 | lexiconKey = 'lexicon-%09d' % cnt 67 | cache[lexiconKey] = ' '.join(lexiconList[i]).encode() 68 | if cnt % 1000 == 0: 69 | writeCache(env, cache) 70 | cache = {} 71 | print('Written %d / %d' % (cnt, nSamples)) 72 | cnt += 1 73 | nSamples = cnt - 1 74 | cache['num-samples'] = str(nSamples).encode() 75 | writeCache(env, cache) 76 | print('Created dataset with %d samples' % nSamples) 77 | 78 | 79 | def read_text(path): 80 | with open(path) as f: 81 | text = f.read() 82 | text = text.strip() 83 | 84 | return text 85 | 86 | 87 | import glob 88 | 89 | if __name__ == '__main__': 90 | 91 | ## lmdb output directory 92 | outputPath = '../data/lmdb/train' 93 | 94 | path = '../data/dataline/*.jpg' 95 | imagePathList = glob.glob(path) 96 | print('------------', len(imagePathList), '------------') 97 | imgLabelLists = [] 98 | for p in imagePathList: 99 | try: 100 | imgLabelLists.append((p, read_text(p.replace('.jpg', '.txt')))) 101 | except: # skip images without a readable .txt label file 102 | continue 103 | 104 | # imgLabelList = [ (p,read_text(p.replace('.jpg','.txt'))) for p in imagePathList] 105 | ## sort by label length 106 | imgLabelList = sorted(imgLabelLists, key=lambda x: len(x[1])) 107 | imgPaths = [p[0] for p in imgLabelList] 108 | txtLists = [p[1] for p in imgLabelList] 109 | 110 | createDataset(outputPath, imgPaths, txtLists, lexiconList=None, checkValid=True) 111 | -------------------------------------------------------------------------------- /train/create_dataset/fontA.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/train/create_dataset/fontA.ttf -------------------------------------------------------------------------------- /train/create_dataset/textgen.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from PIL import Image, ImageDraw, ImageFont 3 | import random 4 | import os 5 | 6 | 7 | def genImage(label, fontsize, color=(0, 0, 0), fontName="华文细黑.ttf"): 8 | img = Image.new("RGB", (int(fontsize * 1.2 * len(label)), int(fontsize * 2)), (255, 255, 255)) 9 | font = ImageFont.truetype(fontName, fontsize) 10 | draw = ImageDraw.Draw(img) 11 | draw.text((0, 0), label, fill=color, font=font) 12 | with open("../data/dataline/" + label + "-" + str(fontsize) + ".txt", "w", encoding='utf-8') as f: 13 | f.write(label) 14 | img.save("../data/dataline/" + label + "-" + str(fontsize) + ".jpg") 15 | 16 |
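genImage writes each sample twice: the rendered line as `<label>-<fontsize>.jpg` and the ground truth as a sibling `.txt`, which is exactly the pairing that `read_text(p.replace('.jpg', '.txt'))` in create_dataset.py relies on. After running both scripts, the LMDB contents can be spot-checked by reading one sample back with the same `image-%09d` / `label-%09d` key convention. A minimal read-back sketch (the path and index 1 simply mirror the defaults used above):

```
# Read sample 1 back out of the LMDB written by createDataset and decode it.
import cv2
import lmdb
import numpy as np

env = lmdb.open('../data/lmdb/train', readonly=True, lock=False)
with env.begin(write=False) as txn:
    nSamples = int(txn.get('num-samples'.encode()))
    imageBin = txn.get(('image-%09d' % 1).encode())
    label = txn.get(('label-%09d' % 1).encode()).decode('utf-8')
img = cv2.imdecode(np.frombuffer(imageBin, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
print(nSamples, label, img.shape)
env.close()
```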
17 | if __name__ == '__main__': 18 | alphabet = """'疗绚诚娇溜题贿者廖更纳加奉公一就汴计与路房原妇208-7其>:],,骑刈全消昏傈安久钟嗅不影处驽蜿资关椤地瘸专问忖票嫉炎韵要月田节陂鄙捌备拳伺眼网盎大傍心东愉汇蹿科每业里航晏字平录先13彤鲶产稍督腴有象岳注绍在泺文定核名水过理让偷率等这发”为含肥酉相鄱七编猥锛日镀蒂掰倒辆栾栗综涩州雌滑馀了机块司宰甙兴矽抚保用沧秩如收息滥页疑埠!!姥异橹钇向下跄的椴沫国绥獠报开民蜇何分凇长讥藏掏施羽中讲派嘟人提浼间世而古多倪唇饯控庚首赛蜓味断制觉技替艰溢潮夕钺外摘枋动双单啮户枇确锦曜杜或能效霜盒然侗电晁放步鹃新杖蜂吒濂瞬评总隍对独合也是府青天诲墙组滴级邀帘示已时骸仄泅和遨店雇疫持巍踮境只亨目鉴崤闲体泄杂作般轰化解迂诿蛭璀腾告版服省师小规程线海办引二桧牌砺洄裴修图痫胡许犊事郛基柴呼食研奶律蛋因葆察戏褒戒再李骁工貂油鹅章啄休场给睡纷豆器捎说敏学会浒设诊格廓查来霓室溆¢诡寥焕舜柒狐回戟砾厄实翩尿五入径惭喹股宇篝|;美期云九祺扮靠锝槌系企酰阊暂蚕忻豁本羹执条钦H獒限进季楦于芘玖铋茯未答粘括样精欠矢甥帷嵩扣令仔风皈行支部蓉刮站蜡救钊汗松嫌成可.鹤院从交政怕活调球局验髌第韫谗串到圆年米/*友忿检区看自敢刃个兹弄流留同没齿星聆轼湖什三建蛔儿椋汕震颧鲤跟力情璺铨陪务指族训滦鄣濮扒商箱十召慷辗所莞管护臭横硒嗓接侦六露党馋驾剖高侬妪幂猗绺骐央酐孝筝课徇缰门男西项句谙瞒秃篇教碲罚声呐景前富嘴鳌稀免朋啬睐去赈鱼住肩愕速旁波厅健茼厥鲟谅投攸炔数方击呋谈绩别愫僚躬鹧胪炳招喇膨泵蹦毛结54谱识陕粽婚拟构且搜任潘比郢妨醪陀桔碘扎选哈骷楷亿明缆脯监睫逻婵共赴淝凡惦及达揖谩澹减焰蛹番祁柏员禄怡峤龙白叽生闯起细装谕竟聚钙上导渊按艾辘挡耒盹饪臀记邮蕙受各医搂普滇朗茸带翻酚(光堤墟蔷万幻〓瑙辈昧盏亘蛀吉铰请子假闻税井诩哨嫂好面琐校馊鬣缂营访炖占农缀否经钚棵趟张亟吏茶谨捻论迸堂玉信吧瞠乡姬寺咬溏苄皿意赉宝尔钰艺特唳踉都荣倚登荐丧奇涵批炭近符傩感道着菊虹仲众懈濯颞眺南释北缝标既茗整撼迤贲挎耱拒某妍卫哇英矶藩治他元领膜遮穗蛾飞荒棺劫么市火温拈棚洼转果奕卸迪伸泳斗邡侄涨屯萋胭氡崮枞惧冒彩斜手豚随旭淑妞形菌吲沱争驯歹挟兆柱传至包内响临红功弩衡寂禁老棍耆渍织害氵渑布载靥嗬虽苹咨娄库雉榜帜嘲套瑚亲簸欧边6腿旮抛吹瞳得镓梗厨继漾愣憨士策窑抑躯襟脏参贸言干绸鳄穷藜音折详)举悍甸癌黎谴死罩迁寒驷袖媒蒋掘模纠恣观祖蛆碍位稿主澧跌筏京锏帝贴证糠才黄鲸略炯饱四出园犀牧容汉杆浈汰瑷造虫瘩怪驴济应花沣谔夙旅价矿以考su呦晒巡茅准肟瓴詹仟褂译桌混宁怦郑抿些余鄂饴攒珑群阖岔琨藓预环洮岌宀杲瀵最常囡周踊女鼓袭喉简范薯遐疏粱黜禧法箔斤遥汝奥直贞撑置绱集她馅逗钧橱魉[恙躁唤9旺膘待脾惫购吗依盲度瘿蠖俾之镗拇鲵厝簧续款展啃表剔品钻腭损清锶统涌寸滨贪链吠冈伎迥咏吁览防迅失汾阔逵绀蔑列川凭努熨揪利俱绉抢鸨我即责膦易毓鹊刹玷岿空嘞绊排术估锷违们苟铜播肘件烫审鲂广像铌惰铟巳胍鲍康憧色恢想拷尤疳知SYFDA峄裕帮握搔氐氘难墒沮雨叁缥悴藐湫娟苑稠颛簇后阕闭蕤缚怎佞码嘤蔡痊舱螯帕赫昵升烬岫、疵蜻髁蕨隶烛械丑盂梁强鲛由拘揉劭龟撤钩呕孛费妻漂求阑崖秤甘通深补赃坎床啪承吼量暇钼烨阂擎脱逮称P神属矗华届狍葑汹育患窒蛰佼静槎运鳗庆逝曼疱克代官此麸耧蚌晟例础榛副测唰缢迹灬霁身岁赭扛又菡乜雾板读陷徉贯郁虑变钓菜圾现琢式乐维渔浜左吾脑钡警T啵拴偌漱湿硕止骼魄积燥联踢玛|则窿见振畿送班钽您赵刨印讨踝籍谡舌崧汽蔽沪酥绒怖财帖肱私莎勋羔霸励哼帐将帅渠纪婴娩岭厘滕吻伤坝冠戊隆瘁介涧物黍并姗奢蹑掣垸锴命箍捉病辖琰眭迩艘绌繁寅若毋思诉类诈燮轲酮狂重反职筱县委磕绣奖晋濉志徽肠呈獐坻口片碰几村柿劳料获亩惕晕厌号罢池正鏖煨家棕复尝懋蜥锅岛扰队坠瘾钬@卧疣镇譬冰彷频黯据垄采八缪瘫型熹砰楠襁箐但嘶绳啤拍盥穆傲洗盯塘怔筛丿台恒喂葛永¥烟酒桦书砂蚝缉态瀚袄圳轻蛛超榧遛姒奘铮右荽望偻卡丶氰附做革索戚坨桷唁垅榻岐偎坛莨山殊微骇陈爨推嗝驹澡藁呤卤嘻糅逛侵郓酌德摇※鬃被慨殡羸昌泡戛鞋河宪沿玲鲨翅哽源铅语照邯址荃佬顺鸳町霭睾瓢夸椁晓酿痈咔侏券噎湍签嚷离午尚社锤背孟使浪缦潍鞅军姹驶笑鳟鲁》孽钜绿洱礴焯椰颖囔乌孔巴互性椽哞聘昨早暮胶炀隧低彗昝铁呓氽藉喔癖瑗姨权胱韦堑蜜酋楝砝毁靓歙锲究屋喳骨辨碑武鸠宫辜烊适坡殃培佩供走蜈迟翼况姣凛浔吃飘债犟金促苛崇坂莳畔绂兵蠕斋根砍亢欢恬崔剁餐榫快扶‖濒缠鳜当彭驭浦篮昀锆秸钳弋娣瞑夷龛苫拱致%嵊障隐弑初娓抉汩累蓖"唬助苓昙押毙破城郧逢嚏獭瞻溱婿赊跨恼璧萃姻貉灵炉密氛陶砸谬衔点琛沛枳层岱诺脍榈埂征冷裁打蹴素瘘逞蛐聊激腱萘踵飒蓟吆取咙簋涓矩曝挺揣座你史舵焱尘苏笈脚溉榨诵樊邓焊义庶儋蟋蒲赦呷杞诠豪还试颓茉太除紫逃痴草充鳕珉祗墨渭烩蘸慕璇镶穴嵘恶骂险绋幕碉肺戳刘潞秣纾潜銮洛须罘销瘪汞兮屉r林厕质探划狸殚善煊烹〒锈逯宸辍泱柚袍远蹋嶙绝峥娥缍雀徵认镱谷=贩勉撩鄯斐洋非祚泾诒饿撬威晷搭芍锥笺蓦候琊档礁沼卵荠忑朝凹瑞头仪弧孵畏铆突衲车浩气茂悖厢枕酝戴湾邹飚攘锂写宵翁岷无喜丈挑嗟绛殉议槽具醇淞笃郴阅饼底壕砚弈询缕庹翟零筷暨舟闺甯撞麂茌蔼很珲捕棠角阉媛娲诽剿尉爵睬韩诰匣危糍镯立浏阳少盆舔擘匪申尬铣旯抖赘瓯居ˇ哮游锭茏歌坏甚秒舞沙仗劲潺阿燧郭嗖霏忠材奂耐跺砀输岖媳氟极摆灿今扔腻枝奎药熄吨话q额慑嘌协喀壳埭视著於愧陲翌峁颅佛腹聋侯咎叟秀颇存较罪哄岗扫栏钾羌己璨枭霉煌涸衿键镝益岢奏连夯睿冥均糖狞蹊稻爸刿胥煜丽肿璃掸跚灾垂樾濑乎莲窄犹撮战馄软络显鸢胸宾妲恕埔蝌份遇巧瞟粒恰剥桡博讯凯堇阶滤卖斌骚彬兑磺樱舷两娱福仃差找桁÷净把阴污戬雷碓蕲楚罡焖抽妫咒仑闱尽邑菁爱贷沥鞑牡嗉崴骤塌嗦订拮滓捡锻次坪杩臃箬融珂鹗宗枚降鸬妯阄堰盐毅必杨崃俺甬状莘货耸菱腼铸唏痤孚澳懒溅翘疙杷淼缙骰喊悉砻坷艇赁界谤纣宴晃茹归饭梢铡街抄肼鬟苯颂撷戈炒咆茭瘙负仰客琉铢封卑珥椿镧窨鬲寿御袤铃萎砖餮脒裳肪孕嫣馗嵇恳氯江石褶冢祸阻狈羞银靳透咳叼敷芷啥它瓤兰痘懊逑肌往捺坊甩呻〃沦忘膻祟菅剧崆智坯臧霍墅攻眯倘拢骠铐庭岙瓠′缺泥迢捶??郏喙掷沌纯秘种听绘固螨团香盗妒埚蓝拖旱荞铀血遏汲辰叩拽幅硬惶桀漠措泼唑齐肾念酱虚屁耶旗砦闵婉馆拭绅韧忏窝醋葺顾辞倜堆辋逆玟贱疾董惘倌锕淘嘀莽俭笏绑鲷杈择蟀粥嗯驰逾案谪褓胫哩昕颚鲢绠躺鹄崂儒俨丝尕泌啊萸彰幺吟骄苣弦脊瑰〈诛镁析闪剪侧哟框螃守嬗燕狭铈缮概迳痧鲲俯售笼痣扉挖满咋援邱扇歪便玑绦峡蛇叨〖泽胃斓喋怂坟猪该蚬炕弥赞棣晔娠挲狡创疖铕镭稷挫弭啾翔粉履苘哦楼秕铂土锣瘟挣栉习享桢袅磨桂谦延坚蔚噗署谟猬钎恐嬉雒倦衅亏璩睹刻殿王算雕麻丘柯骆丸塍谚添鲈垓桎蚯芥予飕镦谌窗醚菀亮搪莺蒿羁足J真轶悬衷靛翊掩哒炅掐冼妮l谐稚荆擒犯陵虏浓崽刍陌傻孜千靖演矜钕煽杰酗渗伞栋俗泫戍罕沾疽灏煦芬磴叱阱榉湃蜀叉醒彪租郡篷屎良垢隗弱陨峪砷掴颁胎雯绵贬沐撵隘篙暖曹陡栓填臼彦瓶琪潼哪鸡摩啦俟锋域耻蔫疯纹撇毒绶痛酯忍爪赳歆嘹辕烈册朴钱吮毯癜娃谀邵厮炽璞邃丐追词瓒忆轧芫谯喷弟半冕裙掖墉绮寝苔势顷褥切衮君佳嫒蚩霞佚洙逊镖暹唛&殒顶碗獗轭铺蛊废恹汨崩珍那杵曲纺夏薰傀闳淬姘舀拧卷楂恍讪厩寮篪赓乘灭盅鞣沟慎挂饺鼾杳树缨丛絮娌臻嗳篡侩述衰矛圈蚜匕筹匿濞晨叶骋郝挚蚴滞增侍描瓣吖嫦蟒匾圣赌毡癞恺百曳需篓肮庖帏卿驿遗蹬鬓骡歉芎胳屐禽烦晌寄媾狄翡苒船廉终痞殇々畦饶改拆悻萄£瓿乃訾桅匮溧拥纱铍骗蕃龋缬父佐疚栎醍掳蓄x惆颜鲆榆〔猎敌暴谥鲫贾罗玻缄扦芪癣落徒臾恿猩托邴肄牵春陛耀刊拓蓓邳堕寇枉淌啡湄兽酷萼碚濠萤夹旬戮梭琥椭昔勺蜊绐晚孺僵宣摄冽旨萌忙蚤眉噼蟑付契瓜悼颡壁曾窕颢澎仿俑浑嵌浣乍碌褪乱蔟隙玩剐葫箫纲围伐决伙漩瑟刑肓镳缓蹭氨皓典畲坍铑檐塑洞倬储胴淳戾吐灼惺妙毕珐缈虱盖羰鸿磅谓髅娴苴唷蚣霹抨贤唠犬誓逍庠逼麓籼釉呜碧秧氩摔霄穸纨辟妈映完牛缴嗷炊恩荔茆掉紊慌莓羟阙萁磐另蕹辱鳐湮吡吩唐睦垠舒圜冗瞿溺芾囱匠僳汐菩饬漓黑霰浸濡窥毂蒡兢驻鹉芮诙迫雳厂忐臆猴鸣蚪栈箕羡渐莆捍眈哓趴蹼埕嚣骛宏淄斑噜严瑛垃椎诱压庾绞焘廿抡迄棘夫纬锹眨瞌侠脐竞瀑孳骧遁姜颦荪滚萦伪逸粳爬锁矣役趣洒颔诏逐奸甭惠攀蹄泛尼拼阮鹰亚颈惑勒〉际肛爷刚钨丰养冶鲽辉蔻画覆皴妊麦返醉皂擀〗酶凑粹悟诀硖港卜z杀涕±舍铠抵弛段敝镐奠拂轴跛袱et沉菇俎薪峦秭蟹历盟菠寡液肢喻染裱悱抱氙赤捅猛跑氮谣仁尺辊窍烙衍架擦倏璐瑁币楞胖夔趸邛惴饕虔蝎§哉贝宽辫炮扩饲籽魏菟锰伍猝末琳哚蛎邂呀姿鄞却歧仙恸椐森牒寤袒婆虢雅钉朵贼欲苞寰故龚坭嘘咫礼硷兀睢汶’铲烧绕诃浃钿哺柜讼颊璁腔洽咐脲簌筠镣玮鞠谁兼姆挥梯蝴谘漕刷躏宦弼b垌劈麟莉揭笙渎仕嗤仓配怏抬错泯镊孰猿邪仍秋鼬壹歇吵炼<尧射柬廷胧霾凳隋肚浮梦祥株
堵退L鹫跎凶毽荟炫栩玳甜沂鹿顽伯爹赔蛴徐匡欣狰缸雹蟆疤默沤啜痂衣禅wih辽葳黝钗停沽棒馨颌肉吴硫悯劾娈马啧吊悌镑峭帆瀣涉咸疸滋泣翦拙癸钥蜒+尾庄凝泉婢渴谊乞陆锉糊鸦淮IBN晦弗乔庥葡尻席橡傣渣拿惩麋斛缃矮蛏岘鸽姐膏催奔镒喱蠡摧钯胤柠拐璋鸥卢荡倾^_珀逄萧塾掇贮笆聂圃冲嵬M滔笕值炙偶蜱搐梆汪蔬腑鸯蹇敞绯仨祯谆梧糗鑫啸豺囹猾巢柄瀛筑踌沭暗苁鱿蹉脂蘖牢热木吸溃宠序泞偿拜檩厚朐毗螳吞媚朽担蝗橘畴祈糟盱隼郜惜珠裨铵焙琚唯咚噪骊丫滢勤棉呸咣淀隔蕾窈饨挨煅短匙粕镜赣撕墩酬馁豌颐抗酣氓佑搁哭递耷涡桃贻碣截瘦昭镌蔓氚甲猕蕴蓬散拾纛狼猷铎埋旖矾讳囊糜迈粟蚂紧鲳瘢栽稼羊锄斟睁桥瓮蹙祉醺鼻昱剃跳篱跷蒜翎宅晖嗑壑峻癫屏狠陋袜途憎祀莹滟佶溥臣约盛峰磁慵婪拦莅朕鹦粲裤哎疡嫖琵窟堪谛嘉儡鳝斩郾驸酊妄胜贺徙傅噌钢栅庇恋匝巯邈尸锚粗佟蛟薹纵蚊郅绢锐苗俞篆淆膀鲜煎诶秽寻涮刺怀噶巨褰魅灶灌桉藕谜舸薄搀恽借牯痉渥愿亓耘杠柩锔蚶钣珈喘蹒幽赐稗晤莱泔扯肯菪裆腩豉疆骜腐倭珏唔粮亡润慰伽橄玄誉醐胆龊粼塬陇彼削嗣绾芽妗垭瘴爽薏寨龈泠弹赢漪猫嘧涂恤圭茧烽屑痕巾赖荸凰腮畈亵蹲偃苇澜艮换骺烘苕梓颉肇哗悄氤涠葬屠鹭植竺佯诣鲇瘀鲅邦移滁冯耕癔戌茬沁巩悠湘洪痹锟循谋腕鳃钠捞焉迎碱伫急榷奈邝卯辄皲卟醛畹忧稳雄昼缩阈睑扌耗曦涅捏瞧邕淖漉铝耦禹湛喽莼琅诸苎纂硅始嗨傥燃臂赅嘈呆贵屹壮肋亍蚀卅豹腆邬迭浊}童螂捐圩勐触寞汊壤荫膺渌芳懿遴螈泰蓼蛤茜舅枫朔膝眙避梅判鹜璜牍缅垫藻黔侥惚懂踩腰腈札丞唾慈顿摹荻琬~斧沈滂胁胀幄莜Z匀鄄掌绰茎焚赋萱谑汁铒瞎夺蜗野娆冀弯篁懵灞隽芡脘俐辩芯掺喏膈蝈觐悚踹蔗熠鼠呵抓橼峨畜缔禾崭弃熊摒凸拗穹蒙抒祛劝闫扳阵醌踪喵侣搬仅荧赎蝾琦买婧瞄寓皎冻赝箩莫瞰郊笫姝筒枪遣煸袋舆痱涛母〇启践耙绲盘遂昊搞槿诬纰泓惨檬亻越Co憩熵祷钒暧塔阗胰咄娶魔琶钞邻扬杉殴咽弓〆髻】吭揽霆拄殖脆彻岩芝勃辣剌钝嘎甄佘皖伦授徕憔挪皇庞稔芜踏溴兖卒擢饥鳞煲‰账颗叻斯捧鳍琮讹蛙纽谭酸兔莒睇伟觑羲嗜宜褐旎辛卦诘筋鎏溪挛熔阜晰鳅丢奚灸呱献陉黛鸪甾萨疮拯洲疹辑叙恻谒允柔烂氏逅漆拎惋扈湟纭啕掬擞哥忽涤鸵靡郗瓷扁廊怨雏钮敦E懦憋汀拚啉腌岸f痼瞅尊咀眩飙忌仝迦熬毫胯篑茄腺凄舛碴锵诧羯後漏汤宓仞蚁壶谰皑铄棰罔辅晶苦牟闽\烃饮聿丙蛳朱煤涔鳖犁罐荼砒淦妤黏戎孑婕瑾戢钵枣捋砥衩狙桠稣阎肃梏诫孪昶婊衫嗔侃塞蜃樵峒貌屿欺缫阐栖诟珞荭吝萍嗽恂啻蜴磬峋俸豫谎徊镍韬魇晴U囟猜蛮坐囿伴亭肝佗蝠妃胞滩榴氖垩苋砣扪馏姓轩厉夥侈禀垒岑赏钛辐痔披纸碳“坞蠓挤荥沅悔铧帼蒌蝇apyng哀浆瑶凿桶馈皮奴苜佤伶晗铱炬优弊氢恃甫攥端锌灰稹炝曙邋亥眶碾拉萝绔捷浍腋姑菖凌涞麽锢桨潢绎镰殆锑渝铬困绽觎匈糙暑裹鸟盔肽迷綦『亳佝俘钴觇骥仆疝跪婶郯瀹唉脖踞针晾忒扼瞩叛椒疟嗡邗肆跆玫忡捣咧唆艄蘑潦笛阚沸泻掊菽贫斥髂孢镂赂麝鸾屡衬苷恪叠希粤爻喝茫惬郸绻庸撅碟宄妹膛叮饵崛嗲椅冤搅咕敛尹垦闷蝉霎勰败蓑泸肤鹌幌焦浠鞍刁舰乙竿裔。茵函伊兄丨娜匍謇莪宥似蝽翳酪翠粑薇祢骏赠叫Q噤噻竖芗莠潭俊羿耜O郫趁嗪囚蹶芒洁笋鹑敲硝啶堡渲揩』携宿遒颍扭棱割萜蔸葵琴捂饰衙耿掠募岂窖涟蔺瘤柞瞪怜匹距楔炜哆秦缎幼茁绪痨恨楸娅瓦桩雪嬴伏榔妥铿拌眠雍缇‘卓搓哌觞噩屈哧髓咦巅娑侑淫膳祝勾姊莴胄疃薛蜷胛巷芙芋熙闰勿窃狱剩钏幢陟铛慧靴耍k浙浇飨惟绗祜澈啼咪磷摞诅郦抹跃壬吕肖琏颤尴剡抠凋赚泊津宕殷倔氲漫邺涎怠$垮荬遵俏叹噢饽蜘孙筵疼鞭羧牦箭潴c眸祭髯啖坳愁芩驮倡巽穰沃胚怒凤槛剂趵嫁v邢灯鄢桐睽檗锯槟婷嵋圻诗蕈颠遭痢芸怯馥竭锗徜恭遍籁剑嘱苡龄僧桑潸弘澶楹悲讫愤腥悸谍椹呢桓葭攫阀翰躲敖柑郎笨橇呃魁燎脓葩磋垛玺狮沓砜蕊锺罹蕉翱虐闾巫旦茱嬷枯鹏贡芹汛矫绁拣禺佃讣舫惯乳趋疲挽岚虾衾蠹蹂飓氦铖孩稞瑜壅掀勘妓畅髋W庐牲蓿榕练垣唱邸菲昆婺穿绡麒蚱掂愚泷涪漳妩娉榄讷觅旧藤煮呛柳腓叭庵烷阡罂蜕擂猖咿媲脉【沏貅黠熏哲烁坦酵兜×潇撒剽珩圹乾摸樟帽嗒襄魂轿憬锡〕喃皆咖隅脸残泮袂鹂珊囤捆咤误徨闹淙芊淋怆囗拨梳渤RG绨蚓婀幡狩麾谢唢裸旌伉纶裂驳砼咛澄樨蹈宙澍倍貔操勇蟠摈砧虬够缁悦藿撸艹摁淹豇虎榭ˉ吱d°喧荀踱侮奋偕饷犍惮坑璎徘宛妆袈倩窦昂荏乖K怅撰鳙牙袁酞X痿琼闸雁趾荚虻涝《杏韭偈烤绫鞘卉症遢蓥诋杭荨匆竣簪辙敕虞丹缭咩黟m淤瑕咂铉硼茨嶂痒畸敬涿粪窘熟叔嫔盾忱裘憾梵赡珙咯娘庙溯胺葱痪摊荷卞乒髦寐铭坩胗枷爆溟嚼羚砬轨惊挠罄竽菏氧浅楣盼枢炸阆杯谏噬淇渺俪秆墓泪跻砌痰垡渡耽釜讶鳎煞呗韶舶绷鹳缜旷铊皱龌檀霖奄槐艳蝶旋哝赶骞蚧腊盈丁`蜚矸蝙睨嚓僻鬼醴夜彝磊笔拔栀糕厦邰纫逭纤眦膊馍躇烯蘼冬诤暄骶哑瘠」臊丕愈咱螺擅跋搏硪谄笠淡嘿骅谧鼎皋姚歼蠢驼耳胬挝涯狗蒽孓犷凉芦箴铤孤嘛坤V茴朦挞尖橙诞搴碇洵浚帚蜍漯柘嚎讽芭荤咻祠秉跖埃吓糯眷馒惹娼鲑嫩讴轮瞥靶褚乏缤宋帧删驱碎扑俩俄偏涣竹噱皙佰渚唧斡#镉刀崎筐佣夭贰肴峙哔艿匐牺镛缘仡嫡劣枸堀梨簿鸭蒸亦稽浴{衢束槲j阁揍疥棋潋聪窜乓睛插冉阪苍搽「蟾螟幸仇樽撂慢跤幔俚淅覃觊溶妖帛侨曰妾泗 """ 19 | charact = alphabet[:] 20 | textLen = len(charact) - 11 21 | for i in range(100): 22 | ss = random.randint(0, textLen) 23 | genImage(alphabet[ss:ss + 10], 20) 24 | genImage(alphabet[ss:ss + 10], 15) 25 | -------------------------------------------------------------------------------- /train/create_dataset/viewlmdb.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import lmdb # install lmdb by "pip install lmdb" 4 | 5 | outputPath = '../data/lmdb/train' 6 | env = lmdb.open(outputPath) 7 | txn = env.begin(write=False) 8 | for key, value in txn.cursor(): 9 | print(key, value) 10 | 11 | env.close() 12 | -------------------------------------------------------------------------------- /train/create_dataset/华文细黑.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/train/create_dataset/华文细黑.ttf -------------------------------------------------------------------------------- /train/data/dataline/ff299a9c-b41b-11e7-89e1-1c1b0d6ddf51.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/train/data/dataline/ff299a9c-b41b-11e7-89e1-1c1b0d6ddf51.jpg -------------------------------------------------------------------------------- /train/data/dataline/ff299a9c-b41b-11e7-89e1-1c1b0d6ddf51.txt: 
-------------------------------------------------------------------------------- 1 | 抱抱、包在我身上、超 -------------------------------------------------------------------------------- /train/data/lmdb/train/data.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/train/data/lmdb/train/data.mdb -------------------------------------------------------------------------------- /train/data/lmdb/train/lock.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/train/data/lmdb/train/lock.mdb -------------------------------------------------------------------------------- /train/data/lmdb/val/data.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/train/data/lmdb/val/data.mdb -------------------------------------------------------------------------------- /train/data/lmdb/val/lock.mdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/train/data/lmdb/val/lock.mdb -------------------------------------------------------------------------------- /train/keras-train/allinonetrain.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import random 3 | import sys 4 | 5 | import lmdb 6 | import numpy as np 7 | import six 8 | import torch 9 | import torchvision.transforms as transforms 10 | from PIL import Image 11 | from keras.layers import Flatten, BatchNormalization, Permute, TimeDistributed, Dense, Bidirectional, GRU 12 | from keras.layers import Input, Conv2D, MaxPooling2D, ZeroPadding2D 13 | from keras.models import Model 14 | from torch.utils.data import Dataset 15 | from torch.utils.data import sampler 16 | 17 | rnnunit = 256 18 | from keras import backend as K 19 | 20 | from keras.layers import Lambda 21 | from keras.optimizers import SGD 22 | 23 | 24 | class lmdbDataset(Dataset): 25 | def __init__(self, root=None, transform=None, target_transform=None): 26 | self.env = lmdb.open( 27 | root, 28 | max_readers=1, 29 | readonly=True, 30 | lock=False, 31 | readahead=False, 32 | meminit=False) 33 | 34 | if not self.env: 35 | print('cannot creat lmdb from %s' % (root)) 36 | sys.exit(0) 37 | 38 | with self.env.begin(write=False) as txn: 39 | nSamples = int(txn.get('num-samples'.encode())) 40 | print("nSamples:{}".format(nSamples)) 41 | self.nSamples = nSamples 42 | 43 | self.transform = transform 44 | self.target_transform = target_transform 45 | 46 | def __len__(self): 47 | return self.nSamples 48 | 49 | def __getitem__(self, index): 50 | assert index <= len(self), 'index range error' 51 | index += 1 52 | with self.env.begin(write=False) as txn: 53 | img_key = 'image-%09d' % index 54 | imgbuf = txn.get(img_key.encode()) 55 | 56 | buf = six.BytesIO() 57 | buf.write(imgbuf) 58 | buf.seek(0) 59 | try: 60 | img = Image.open(buf).convert('L') 61 | # img.save("1111111111.jpg") 62 | except IOError: 63 | print('Corrupted image for %d' % index) 64 | if index > self.nSamples - 1: 65 | index = 0 66 | return self[index + 1] 67 | 68 | if self.transform is not None: 69 | img = self.transform(img) 70 | 71 | label_key = 'label-%09d' % index 72 | label = 
str(txn.get(label_key.encode()), 'utf-8') 73 | 74 | if self.target_transform is not None: 75 | label = self.target_transform(label) 76 | # print(img,label) 77 | return (img, label) 78 | 79 | 80 | class resizeNormalize(object): 81 | def __init__(self, size, interpolation=Image.BILINEAR): 82 | self.size = size 83 | self.interpolation = interpolation 84 | self.toTensor = transforms.ToTensor() 85 | 86 | def __call__(self, img): 87 | img = img.resize(self.size, self.interpolation) 88 | img = self.toTensor(img) 89 | img.sub_(0.5).div_(0.5) 90 | return img 91 | 92 | 93 | class randomSequentialSampler(sampler.Sampler): 94 | def __init__(self, data_source, batch_size): 95 | self.num_samples = len(data_source) 96 | self.batch_size = batch_size 97 | 98 | def __iter__(self): 99 | n_batch = len(self) // self.batch_size 100 | tail = len(self) % self.batch_size 101 | index = torch.LongTensor(len(self)).fill_(0) 102 | for i in range(n_batch): 103 | random_start = random.randint(0, len(self) - self.batch_size) 104 | batch_index = random_start + torch.range(0, self.batch_size - 1) 105 | index[i * self.batch_size:(i + 1) * self.batch_size] = batch_index 106 | # deal with tail 107 | if tail: 108 | random_start = random.randint(0, len(self) - self.batch_size) 109 | tail_index = random_start + torch.range(0, tail - 1) 110 | index[(i + 1) * self.batch_size:] = tail_index 111 | 112 | return iter(index) 113 | 114 | def __len__(self): 115 | return self.num_samples 116 | 117 | 118 | class alignCollate(object): 119 | def __init__(self, imgH=32, imgW=100, keep_ratio=False, min_ratio=1): 120 | self.imgH = imgH 121 | self.imgW = imgW 122 | self.keep_ratio = keep_ratio 123 | self.min_ratio = min_ratio 124 | 125 | def __call__(self, batch): 126 | images, labels = zip(*batch) 127 | 128 | imgH = self.imgH 129 | imgW = self.imgW 130 | if self.keep_ratio: 131 | ratios = [] 132 | for image in images: 133 | w, h = image.size 134 | ratios.append(w / float(h)) 135 | ratios.sort() 136 | max_ratio = ratios[-1] 137 | imgW = int(np.floor(max_ratio * imgH)) 138 | imgW = max(imgH * self.min_ratio, imgW) # assure imgH >= imgW 139 | 140 | transform = resizeNormalize((imgW, imgH)) 141 | images = [transform(image) for image in images] 142 | images = torch.cat([t.unsqueeze(0) for t in images], 0) 143 | 144 | return images, labels 145 | 146 | 147 | def ctc_lambda_func(args): 148 | y_pred, labels, input_length, label_length = args 149 | # print("cccccccccc:",y_pred,labels,input_length,label_length) 150 | y_pred = y_pred[:, 2:, :] 151 | 152 | return K.ctc_batch_cost(labels, y_pred, input_length, label_length) 153 | 154 | 155 | def get_model(height, nclass): 156 | input = Input(shape=(height, None, 1), name='the_input') 157 | m = Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same', name='conv1')(input) 158 | m = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='pool1')(m) 159 | m = Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same', name='conv2')(m) 160 | m = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='pool2')(m) 161 | m = Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same', name='conv3')(m) 162 | m = Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same', name='conv4')(m) 163 | 164 | m = ZeroPadding2D(padding=(0, 1))(m) 165 | m = MaxPooling2D(pool_size=(2, 2), strides=(2, 1), padding='valid', name='pool3')(m) 166 | 167 | m = Conv2D(512, kernel_size=(3, 3), activation='relu', padding='same', name='conv5')(m) 168 | m = BatchNormalization(axis=1)(m) 169 | m = Conv2D(512, 
kernel_size=(3, 3), activation='relu', padding='same', name='conv6')(m) 170 | m = BatchNormalization(axis=1)(m) 171 | m = ZeroPadding2D(padding=(0, 1))(m) 172 | m = MaxPooling2D(pool_size=(2, 2), strides=(2, 1), padding='valid', name='pool4')(m) 173 | m = Conv2D(512, kernel_size=(2, 2), activation='relu', padding='valid', name='conv7')(m) 174 | 175 | m = Permute((2, 1, 3), name='permute')(m) 176 | m = TimeDistributed(Flatten(), name='timedistrib')(m) 177 | 178 | m = Bidirectional(GRU(rnnunit, return_sequences=True), name='blstm1')(m) 179 | m = Dense(rnnunit, name='blstm1_out', activation='linear')(m) 180 | m = Bidirectional(GRU(rnnunit, return_sequences=True), name='blstm2')(m) 181 | y_pred = Dense(nclass, name='blstm2_out', activation='softmax')(m) 182 | 183 | basemodel = Model(inputs=input, outputs=y_pred) 184 | 185 | labels = Input(name='the_labels', shape=[None, ], dtype='float32') 186 | input_length = Input(name='input_length', shape=[1], dtype='int64') 187 | label_length = Input(name='label_length', shape=[1], dtype='int64') 188 | 189 | loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([y_pred, labels, input_length, label_length]) 190 | model = Model(inputs=[input, labels, input_length, label_length], outputs=[loss_out]) 191 | sgd = SGD(lr=0.001, decay=1e-6, momentum=0.9, nesterov=True, clipnorm=5) 192 | # model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adadelta') 193 | model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=sgd) 194 | model.summary() 195 | return model, basemodel 196 | 197 | 198 | alphabet = """'疗绚诚娇溜题贿者廖更纳加奉公一就汴计与路房原妇208-7其>:],,骑刈全消昏傈安久钟嗅不影处驽蜿资关椤地瘸专问忖票嫉炎韵要月田节陂鄙捌备拳伺眼网盎大傍心东愉汇蹿科每业里航晏字平录先13彤鲶产稍督腴有象岳注绍在泺文定核名水过理让偷率等这发”为含肥酉相鄱七编猥锛日镀蒂掰倒辆栾栗综涩州雌滑馀了机块司宰甙兴矽抚保用沧秩如收息滥页疑埠!!姥异橹钇向下跄的椴沫国绥獠报开民蜇何分凇长讥藏掏施羽中讲派嘟人提浼间世而古多倪唇饯控庚首赛蜓味断制觉技替艰溢潮夕钺外摘枋动双单啮户枇确锦曜杜或能效霜盒然侗电晁放步鹃新杖蜂吒濂瞬评总隍对独合也是府青天诲墙组滴级邀帘示已时骸仄泅和遨店雇疫持巍踮境只亨目鉴崤闲体泄杂作般轰化解迂诿蛭璀腾告版服省师小规程线海办引二桧牌砺洄裴修图痫胡许犊事郛基柴呼食研奶律蛋因葆察戏褒戒再李骁工貂油鹅章啄休场给睡纷豆器捎说敏学会浒设诊格廓查来霓室溆¢诡寥焕舜柒狐回戟砾厄实翩尿五入径惭喹股宇篝|;美期云九祺扮靠锝槌系企酰阊暂蚕忻豁本羹执条钦H獒限进季楦于芘玖铋茯未答粘括样精欠矢甥帷嵩扣令仔风皈行支部蓉刮站蜡救钊汗松嫌成可.鹤院从交政怕活调球局验髌第韫谗串到圆年米/*友忿检区看自敢刃个兹弄流留同没齿星聆轼湖什三建蛔儿椋汕震颧鲤跟力情璺铨陪务指族训滦鄣濮扒商箱十召慷辗所莞管护臭横硒嗓接侦六露党馋驾剖高侬妪幂猗绺骐央酐孝筝课徇缰门男西项句谙瞒秃篇教碲罚声呐景前富嘴鳌稀免朋啬睐去赈鱼住肩愕速旁波厅健茼厥鲟谅投攸炔数方击呋谈绩别愫僚躬鹧胪炳招喇膨泵蹦毛结54谱识陕粽婚拟构且搜任潘比郢妨醪陀桔碘扎选哈骷楷亿明缆脯监睫逻婵共赴淝凡惦及达揖谩澹减焰蛹番祁柏员禄怡峤龙白叽生闯起细装谕竟聚钙上导渊按艾辘挡耒盹饪臀记邮蕙受各医搂普滇朗茸带翻酚(光堤墟蔷万幻〓瑙辈昧盏亘蛀吉铰请子假闻税井诩哨嫂好面琐校馊鬣缂营访炖占农缀否经钚棵趟张亟吏茶谨捻论迸堂玉信吧瞠乡姬寺咬溏苄皿意赉宝尔钰艺特唳踉都荣倚登荐丧奇涵批炭近符傩感道着菊虹仲众懈濯颞眺南释北缝标既茗整撼迤贲挎耱拒某妍卫哇英矶藩治他元领膜遮穗蛾飞荒棺劫么市火温拈棚洼转果奕卸迪伸泳斗邡侄涨屯萋胭氡崮枞惧冒彩斜手豚随旭淑妞形菌吲沱争驯歹挟兆柱传至包内响临红功弩衡寂禁老棍耆渍织害氵渑布载靥嗬虽苹咨娄库雉榜帜嘲套瑚亲簸欧边6腿旮抛吹瞳得镓梗厨继漾愣憨士策窑抑躯襟脏参贸言干绸鳄穷藜音折详)举悍甸癌黎谴死罩迁寒驷袖媒蒋掘模纠恣观祖蛆碍位稿主澧跌筏京锏帝贴证糠才黄鲸略炯饱四出园犀牧容汉杆浈汰瑷造虫瘩怪驴济应花沣谔夙旅价矿以考su呦晒巡茅准肟瓴詹仟褂译桌混宁怦郑抿些余鄂饴攒珑群阖岔琨藓预环洮岌宀杲瀵最常囡周踊女鼓袭喉简范薯遐疏粱黜禧法箔斤遥汝奥直贞撑置绱集她馅逗钧橱魉[恙躁唤9旺膘待脾惫购吗依盲度瘿蠖俾之镗拇鲵厝簧续款展啃表剔品钻腭损清锶统涌寸滨贪链吠冈伎迥咏吁览防迅失汾阔逵绀蔑列川凭努熨揪利俱绉抢鸨我即责膦易毓鹊刹玷岿空嘞绊排术估锷违们苟铜播肘件烫审鲂广像铌惰铟巳胍鲍康憧色恢想拷尤疳知SYFDA峄裕帮握搔氐氘难墒沮雨叁缥悴藐湫娟苑稠颛簇后阕闭蕤缚怎佞码嘤蔡痊舱螯帕赫昵升烬岫、疵蜻髁蕨隶烛械丑盂梁强鲛由拘揉劭龟撤钩呕孛费妻漂求阑崖秤甘通深补赃坎床啪承吼量暇钼烨阂擎脱逮称P神属矗华届狍葑汹育患窒蛰佼静槎运鳗庆逝曼疱克代官此麸耧蚌晟例础榛副测唰缢迹灬霁身岁赭扛又菡乜雾板读陷徉贯郁虑变钓菜圾现琢式乐维渔浜左吾脑钡警T啵拴偌漱湿硕止骼魄积燥联踢玛|则窿见振畿送班钽您赵刨印讨踝籍谡舌崧汽蔽沪酥绒怖财帖肱私莎勋羔霸励哼帐将帅渠纪婴娩岭厘滕吻伤坝冠戊隆瘁介涧物黍并姗奢蹑掣垸锴命箍捉病辖琰眭迩艘绌繁寅若毋思诉类诈燮轲酮狂重反职筱县委磕绣奖晋濉志徽肠呈獐坻口片碰几村柿劳料获亩惕晕厌号罢池正鏖煨家棕复尝懋蜥锅岛扰队坠瘾钬@卧疣镇譬冰彷频黯据垄采八缪瘫型熹砰楠襁箐但嘶绳啤拍盥穆傲洗盯塘怔筛丿台恒喂葛永¥烟酒桦书砂蚝缉态瀚袄圳轻蛛超榧遛姒奘铮右荽望偻卡丶氰附做革索戚坨桷唁垅榻岐偎坛莨山殊微骇陈爨推嗝驹澡藁呤卤嘻糅逛侵郓酌德摇※鬃被慨殡羸昌泡戛鞋河宪沿玲鲨翅哽源铅语照邯址荃佬顺鸳町霭睾瓢夸椁晓酿痈咔侏券噎湍签嚷离午尚社锤背孟使浪缦潍鞅军姹驶笑鳟鲁》孽钜绿洱礴焯椰颖囔乌孔巴互性椽哞聘昨早暮胶炀隧低彗昝铁呓氽藉喔癖瑗姨权胱韦堑蜜酋楝砝毁靓歙锲究屋喳骨辨碑武鸠宫辜烊适坡殃培佩供走蜈迟翼况姣凛浔吃飘债犟金促苛崇坂莳畔绂兵蠕斋根砍亢欢恬崔剁餐榫快扶‖濒缠鳜当彭驭浦篮昀锆秸钳弋娣瞑夷龛苫拱致%嵊障隐弑初娓抉汩累蓖"唬助苓昙押毙破城郧逢嚏獭瞻溱婿赊跨恼璧萃姻貉灵炉密氛陶砸谬衔点琛沛枳层岱诺脍榈埂征冷裁打蹴素瘘逞蛐聊激腱萘踵飒蓟吆取咙簋涓矩曝挺揣座你史舵焱尘苏笈脚溉榨诵樊邓焊义庶儋蟋蒲赦呷杞诠豪还试颓茉太除紫逃痴草充鳕珉祗墨渭烩蘸慕璇镶穴嵘恶骂险绋幕碉肺戳刘潞秣纾潜銮洛须罘销瘪汞兮屉r林
厕质探划狸殚善煊烹〒锈逯宸辍泱柚袍远蹋嶙绝峥娥缍雀徵认镱谷=贩勉撩鄯斐洋非祚泾诒饿撬威晷搭芍锥笺蓦候琊档礁沼卵荠忑朝凹瑞头仪弧孵畏铆突衲车浩气茂悖厢枕酝戴湾邹飚攘锂写宵翁岷无喜丈挑嗟绛殉议槽具醇淞笃郴阅饼底壕砚弈询缕庹翟零筷暨舟闺甯撞麂茌蔼很珲捕棠角阉媛娲诽剿尉爵睬韩诰匣危糍镯立浏阳少盆舔擘匪申尬铣旯抖赘瓯居ˇ哮游锭茏歌坏甚秒舞沙仗劲潺阿燧郭嗖霏忠材奂耐跺砀输岖媳氟极摆灿今扔腻枝奎药熄吨话q额慑嘌协喀壳埭视著於愧陲翌峁颅佛腹聋侯咎叟秀颇存较罪哄岗扫栏钾羌己璨枭霉煌涸衿键镝益岢奏连夯睿冥均糖狞蹊稻爸刿胥煜丽肿璃掸跚灾垂樾濑乎莲窄犹撮战馄软络显鸢胸宾妲恕埔蝌份遇巧瞟粒恰剥桡博讯凯堇阶滤卖斌骚彬兑磺樱舷两娱福仃差找桁÷净把阴污戬雷碓蕲楚罡焖抽妫咒仑闱尽邑菁爱贷沥鞑牡嗉崴骤塌嗦订拮滓捡锻次坪杩臃箬融珂鹗宗枚降鸬妯阄堰盐毅必杨崃俺甬状莘货耸菱腼铸唏痤孚澳懒溅翘疙杷淼缙骰喊悉砻坷艇赁界谤纣宴晃茹归饭梢铡街抄肼鬟苯颂撷戈炒咆茭瘙负仰客琉铢封卑珥椿镧窨鬲寿御袤铃萎砖餮脒裳肪孕嫣馗嵇恳氯江石褶冢祸阻狈羞银靳透咳叼敷芷啥它瓤兰痘懊逑肌往捺坊甩呻〃沦忘膻祟菅剧崆智坯臧霍墅攻眯倘拢骠铐庭岙瓠′缺泥迢捶??郏喙掷沌纯秘种听绘固螨团香盗妒埚蓝拖旱荞铀血遏汲辰叩拽幅硬惶桀漠措泼唑齐肾念酱虚屁耶旗砦闵婉馆拭绅韧忏窝醋葺顾辞倜堆辋逆玟贱疾董惘倌锕淘嘀莽俭笏绑鲷杈择蟀粥嗯驰逾案谪褓胫哩昕颚鲢绠躺鹄崂儒俨丝尕泌啊萸彰幺吟骄苣弦脊瑰〈诛镁析闪剪侧哟框螃守嬗燕狭铈缮概迳痧鲲俯售笼痣扉挖满咋援邱扇歪便玑绦峡蛇叨〖泽胃斓喋怂坟猪该蚬炕弥赞棣晔娠挲狡创疖铕镭稷挫弭啾翔粉履苘哦楼秕铂土锣瘟挣栉习享桢袅磨桂谦延坚蔚噗署谟猬钎恐嬉雒倦衅亏璩睹刻殿王算雕麻丘柯骆丸塍谚添鲈垓桎蚯芥予飕镦谌窗醚菀亮搪莺蒿羁足J真轶悬衷靛翊掩哒炅掐冼妮l谐稚荆擒犯陵虏浓崽刍陌傻孜千靖演矜钕煽杰酗渗伞栋俗泫戍罕沾疽灏煦芬磴叱阱榉湃蜀叉醒彪租郡篷屎良垢隗弱陨峪砷掴颁胎雯绵贬沐撵隘篙暖曹陡栓填臼彦瓶琪潼哪鸡摩啦俟锋域耻蔫疯纹撇毒绶痛酯忍爪赳歆嘹辕烈册朴钱吮毯癜娃谀邵厮炽璞邃丐追词瓒忆轧芫谯喷弟半冕裙掖墉绮寝苔势顷褥切衮君佳嫒蚩霞佚洙逊镖暹唛&殒顶碗獗轭铺蛊废恹汨崩珍那杵曲纺夏薰傀闳淬姘舀拧卷楂恍讪厩寮篪赓乘灭盅鞣沟慎挂饺鼾杳树缨丛絮娌臻嗳篡侩述衰矛圈蚜匕筹匿濞晨叶骋郝挚蚴滞增侍描瓣吖嫦蟒匾圣赌毡癞恺百曳需篓肮庖帏卿驿遗蹬鬓骡歉芎胳屐禽烦晌寄媾狄翡苒船廉终痞殇々畦饶改拆悻萄£瓿乃訾桅匮溧拥纱铍骗蕃龋缬父佐疚栎醍掳蓄x惆颜鲆榆〔猎敌暴谥鲫贾罗玻缄扦芪癣落徒臾恿猩托邴肄牵春陛耀刊拓蓓邳堕寇枉淌啡湄兽酷萼碚濠萤夹旬戮梭琥椭昔勺蜊绐晚孺僵宣摄冽旨萌忙蚤眉噼蟑付契瓜悼颡壁曾窕颢澎仿俑浑嵌浣乍碌褪乱蔟隙玩剐葫箫纲围伐决伙漩瑟刑肓镳缓蹭氨皓典畲坍铑檐塑洞倬储胴淳戾吐灼惺妙毕珐缈虱盖羰鸿磅谓髅娴苴唷蚣霹抨贤唠犬誓逍庠逼麓籼釉呜碧秧氩摔霄穸纨辟妈映完牛缴嗷炊恩荔茆掉紊慌莓羟阙萁磐另蕹辱鳐湮吡吩唐睦垠舒圜冗瞿溺芾囱匠僳汐菩饬漓黑霰浸濡窥毂蒡兢驻鹉芮诙迫雳厂忐臆猴鸣蚪栈箕羡渐莆捍眈哓趴蹼埕嚣骛宏淄斑噜严瑛垃椎诱压庾绞焘廿抡迄棘夫纬锹眨瞌侠脐竞瀑孳骧遁姜颦荪滚萦伪逸粳爬锁矣役趣洒颔诏逐奸甭惠攀蹄泛尼拼阮鹰亚颈惑勒〉际肛爷刚钨丰养冶鲽辉蔻画覆皴妊麦返醉皂擀〗酶凑粹悟诀硖港卜z杀涕±舍铠抵弛段敝镐奠拂轴跛袱et沉菇俎薪峦秭蟹历盟菠寡液肢喻染裱悱抱氙赤捅猛跑氮谣仁尺辊窍烙衍架擦倏璐瑁币楞胖夔趸邛惴饕虔蝎§哉贝宽辫炮扩饲籽魏菟锰伍猝末琳哚蛎邂呀姿鄞却歧仙恸椐森牒寤袒婆虢雅钉朵贼欲苞寰故龚坭嘘咫礼硷兀睢汶’铲烧绕诃浃钿哺柜讼颊璁腔洽咐脲簌筠镣玮鞠谁兼姆挥梯蝴谘漕刷躏宦弼b垌劈麟莉揭笙渎仕嗤仓配怏抬错泯镊孰猿邪仍秋鼬壹歇吵炼<尧射柬廷胧霾凳隋肚浮梦祥株堵退L鹫跎凶毽荟炫栩玳甜沂鹿顽伯爹赔蛴徐匡欣狰缸雹蟆疤默沤啜痂衣禅wih辽葳黝钗停沽棒馨颌肉吴硫悯劾娈马啧吊悌镑峭帆瀣涉咸疸滋泣翦拙癸钥蜒+尾庄凝泉婢渴谊乞陆锉糊鸦淮IBN晦弗乔庥葡尻席橡傣渣拿惩麋斛缃矮蛏岘鸽姐膏催奔镒喱蠡摧钯胤柠拐璋鸥卢荡倾^_珀逄萧塾掇贮笆聂圃冲嵬M滔笕值炙偶蜱搐梆汪蔬腑鸯蹇敞绯仨祯谆梧糗鑫啸豺囹猾巢柄瀛筑踌沭暗苁鱿蹉脂蘖牢热木吸溃宠序泞偿拜檩厚朐毗螳吞媚朽担蝗橘畴祈糟盱隼郜惜珠裨铵焙琚唯咚噪骊丫滢勤棉呸咣淀隔蕾窈饨挨煅短匙粕镜赣撕墩酬馁豌颐抗酣氓佑搁哭递耷涡桃贻碣截瘦昭镌蔓氚甲猕蕴蓬散拾纛狼猷铎埋旖矾讳囊糜迈粟蚂紧鲳瘢栽稼羊锄斟睁桥瓮蹙祉醺鼻昱剃跳篱跷蒜翎宅晖嗑壑峻癫屏狠陋袜途憎祀莹滟佶溥臣约盛峰磁慵婪拦莅朕鹦粲裤哎疡嫖琵窟堪谛嘉儡鳝斩郾驸酊妄胜贺徙傅噌钢栅庇恋匝巯邈尸锚粗佟蛟薹纵蚊郅绢锐苗俞篆淆膀鲜煎诶秽寻涮刺怀噶巨褰魅灶灌桉藕谜舸薄搀恽借牯痉渥愿亓耘杠柩锔蚶钣珈喘蹒幽赐稗晤莱泔扯肯菪裆腩豉疆骜腐倭珏唔粮亡润慰伽橄玄誉醐胆龊粼塬陇彼削嗣绾芽妗垭瘴爽薏寨龈泠弹赢漪猫嘧涂恤圭茧烽屑痕巾赖荸凰腮畈亵蹲偃苇澜艮换骺烘苕梓颉肇哗悄氤涠葬屠鹭植竺佯诣鲇瘀鲅邦移滁冯耕癔戌茬沁巩悠湘洪痹锟循谋腕鳃钠捞焉迎碱伫急榷奈邝卯辄皲卟醛畹忧稳雄昼缩阈睑扌耗曦涅捏瞧邕淖漉铝耦禹湛喽莼琅诸苎纂硅始嗨傥燃臂赅嘈呆贵屹壮肋亍蚀卅豹腆邬迭浊}童螂捐圩勐触寞汊壤荫膺渌芳懿遴螈泰蓼蛤茜舅枫朔膝眙避梅判鹜璜牍缅垫藻黔侥惚懂踩腰腈札丞唾慈顿摹荻琬~斧沈滂胁胀幄莜Z匀鄄掌绰茎焚赋萱谑汁铒瞎夺蜗野娆冀弯篁懵灞隽芡脘俐辩芯掺喏膈蝈觐悚踹蔗熠鼠呵抓橼峨畜缔禾崭弃熊摒凸拗穹蒙抒祛劝闫扳阵醌踪喵侣搬仅荧赎蝾琦买婧瞄寓皎冻赝箩莫瞰郊笫姝筒枪遣煸袋舆痱涛母〇启践耙绲盘遂昊搞槿诬纰泓惨檬亻越Co憩熵祷钒暧塔阗胰咄娶魔琶钞邻扬杉殴咽弓〆髻】吭揽霆拄殖脆彻岩芝勃辣剌钝嘎甄佘皖伦授徕憔挪皇庞稔芜踏溴兖卒擢饥鳞煲‰账颗叻斯捧鳍琮讹蛙纽谭酸兔莒睇伟觑羲嗜宜褐旎辛卦诘筋鎏溪挛熔阜晰鳅丢奚灸呱献陉黛鸪甾萨疮拯洲疹辑叙恻谒允柔烂氏逅漆拎惋扈湟纭啕掬擞哥忽涤鸵靡郗瓷扁廊怨雏钮敦E懦憋汀拚啉腌岸f痼瞅尊咀眩飙忌仝迦熬毫胯篑茄腺凄舛碴锵诧羯後漏汤宓仞蚁壶谰皑铄棰罔辅晶苦牟闽\烃饮聿丙蛳朱煤涔鳖犁罐荼砒淦妤黏戎孑婕瑾戢钵枣捋砥衩狙桠稣阎肃梏诫孪昶婊衫嗔侃塞蜃樵峒貌屿欺缫阐栖诟珞荭吝萍嗽恂啻蜴磬峋俸豫谎徊镍韬魇晴U囟猜蛮坐囿伴亭肝佗蝠妃胞滩榴氖垩苋砣扪馏姓轩厉夥侈禀垒岑赏钛辐痔披纸碳“坞蠓挤荥沅悔铧帼蒌蝇apyng哀浆瑶凿桶馈皮奴苜佤伶晗铱炬优弊氢恃甫攥端锌灰稹炝曙邋亥眶碾拉萝绔捷浍腋姑菖凌涞麽锢桨潢绎镰殆锑渝铬困绽觎匈糙暑裹鸟盔肽迷綦『亳佝俘钴觇骥仆疝跪婶郯瀹唉脖踞针晾忒扼瞩叛椒疟嗡邗肆跆玫忡捣咧唆艄蘑潦笛阚沸泻掊菽贫斥髂孢镂赂麝鸾屡衬苷恪叠希粤爻喝茫惬郸绻庸撅碟宄妹膛叮饵崛嗲椅冤搅咕敛尹垦闷蝉霎勰败蓑泸肤鹌幌焦浠鞍刁舰乙竿裔。茵函伊兄丨娜匍謇莪宥似蝽翳酪翠粑薇祢骏赠叫Q噤噻竖芗莠潭俊羿耜O郫趁嗪囚蹶芒洁笋鹑敲硝啶堡渲揩』携宿遒颍扭棱割萜蔸葵琴捂饰衙耿掠募岂窖涟蔺瘤柞瞪怜匹距楔炜哆秦缎幼茁绪痨恨楸娅瓦桩雪嬴伏榔妥铿拌眠雍缇‘卓搓哌觞噩屈哧髓咦巅娑侑淫膳祝勾姊莴胄疃薛蜷胛巷芙芋熙闰勿窃狱剩钏幢陟铛慧靴耍k浙浇飨惟绗祜澈啼咪磷摞诅郦抹跃壬吕肖琏颤尴剡抠凋赚泊津宕殷倔氲漫邺涎怠$垮荬遵俏叹噢饽蜘孙筵疼鞭羧牦箭潴c眸祭髯啖坳愁芩驮倡巽穰沃胚怒凤槛剂趵嫁v邢灯鄢桐睽檗锯槟婷嵋圻诗蕈颠遭痢芸怯馥竭锗徜恭遍籁剑嘱苡龄僧桑潸弘澶楹悲讫愤腥悸谍椹呢桓葭攫阀翰躲敖柑郎笨橇呃魁燎脓葩磋垛玺狮沓砜蕊锺罹蕉翱虐闾巫旦茱嬷枯鹏贡芹汛矫绁拣禺佃讣舫惯乳趋疲挽岚虾衾蠹蹂飓氦铖孩稞瑜壅掀勘妓畅髋W庐牲蓿榕练垣唱邸菲昆婺穿绡麒蚱掂愚泷涪漳妩娉榄讷觅旧藤煮呛柳腓叭庵烷阡罂蜕擂猖咿媲脉【沏貅黠熏哲烁坦酵兜×潇撒剽珩圹乾摸樟帽嗒襄魂轿憬锡〕喃皆咖隅脸残泮袂鹂珊囤捆咤误徨闹淙芊淋怆囗拨梳渤RG绨蚓婀幡狩麾谢唢裸旌伉纶裂驳砼咛澄樨蹈宙澍倍貔操勇蟠摈砧虬够缁悦藿撸艹摁淹豇虎榭ˉ吱d°喧荀踱侮奋偕饷犍惮坑璎徘宛妆袈倩窦昂荏乖K怅撰鳙牙袁酞X痿琼闸雁趾荚虻涝《杏韭偈烤绫鞘卉症遢蓥诋杭荨匆竣簪辙敕虞丹缭咩黟m淤瑕咂铉硼茨嶂痒畸敬涿粪窘熟叔嫔盾忱裘憾梵赡珙咯娘庙溯胺葱痪摊荷卞乒髦寐铭坩胗枷爆溟嚼羚砬轨惊挠罄竽菏氧浅楣盼枢炸阆杯谏噬淇渺俪秆墓泪跻砌痰垡渡耽釜讶鳎煞呗韶舶绷鹳缜旷铊皱龌檀霖奄槐艳蝶旋哝赶骞蚧腊盈丁`蜚矸蝙睨嚓僻鬼醴夜彝磊笔拔栀糕厦邰纫逭纤眦膊馍躇烯蘼冬诤暄骶哑瘠」臊丕愈咱螺擅跋搏硪谄笠淡嘿骅谧鼎皋姚歼蠢驼耳胬挝涯狗蒽孓犷凉芦箴铤孤嘛坤V茴朦挞尖橙诞搴碇洵浚帚蜍漯柘嚎讽芭荤咻祠秉跖埃吓糯眷馒惹娼鲑嫩讴轮瞥靶褚乏缤宋帧删驱碎扑俩俄偏涣竹噱皙佰渚唧斡#镉刀崎筐佣夭贰肴峙哔艿匐牺镛缘仡嫡劣枸堀梨簿鸭蒸亦稽浴{衢束槲j阁揍疥棋潋聪窜乓睛插冉阪苍搽「蟾螟幸仇樽撂
慢跤幔俚淅覃觊溶妖帛侨曰妾泗 """ 199 | characters = alphabet[:] 200 | 201 | nclass = len(characters) + 1 202 | 203 | trainroot = '../data/lmdb/train' 204 | valroot = '../data/lmdb/val' 205 | batchSize = 32 206 | workers = 4 207 | imgH = 32 208 | imgW = 256 209 | keep_ratio = False 210 | random_sample = False 211 | 212 | 213 | def one_hot(text, length=10, characters=characters): 214 | label = np.zeros(length) 215 | # print(type(text)) 216 | for i, char in enumerate(text): 217 | index = characters.find(char) 218 | if index == -1: 219 | index = characters.find(u' ') 220 | # print(i,char,length) 221 | # if i < length: 222 | label[i] = index 223 | return label 224 | 225 | 226 | n_len = 10 227 | 228 | 229 | def gen(loader, flag='train'): 230 | while True: 231 | i = 0 232 | n = len(loader) 233 | for X, Y in loader: 234 | X = X.numpy() 235 | X = X.reshape((-1, imgH, imgW, 1)) 236 | if flag == 'test': 237 | Y = Y.numpy() 238 | 239 | Y = np.array(Y) 240 | Length = int(imgW / 4) - 1 241 | batchs = X.shape[0] 242 | # Y = Y.numpy() 243 | if i > n - 1: 244 | i = 0 245 | break 246 | 247 | yield [X, Y, np.ones(batchs) * int(Length), np.ones(batchs) * n_len], np.ones(batchs) 248 | 249 | 250 | sampler = None 251 | train_dataset = lmdbDataset(root=trainroot, target_transform=one_hot) 252 | 253 | train_loader = torch.utils.data.DataLoader( 254 | train_dataset, batch_size=batchSize, 255 | shuffle=True, sampler=sampler, 256 | num_workers=int(workers), 257 | collate_fn=alignCollate(imgH=imgH, imgW=imgW, keep_ratio=keep_ratio)) 258 | 259 | test_dataset = lmdbDataset( 260 | root=valroot, transform=resizeNormalize((imgW, imgH)), target_transform=one_hot) 261 | 262 | test_loader = torch.utils.data.DataLoader( 263 | test_dataset, shuffle=True, batch_size=batchSize, num_workers=int(workers)) 264 | 265 | if __name__ == '__main__': 266 | from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau 267 | 268 | model, basemodel = get_model(height=imgH, nclass=nclass) 269 | import os 270 | 271 | if os.path.exists('/Users/xiaofeng/Code/Github/dataset/CHINESE_OCR/crnn_ocr/pretrain-models/keras.hdf5'): 272 | basemodel.load_weights('/Users/xiaofeng/Code/Github/dataset/CHINESE_OCR/crnn_ocr/pretrain-models/keras.hdf5') 273 | 274 | ##注意此处保存的是model的权重 275 | checkpointer = ModelCheckpoint(filepath="save_model/model{epoch:02d}-{val_loss:.4f}.hdf5", monitor='val_loss', 276 | verbose=0, save_weights_only=False, save_best_only=True) 277 | rlu = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=1, verbose=0, mode='auto', epsilon=0.0001, 278 | cooldown=0, min_lr=0) 279 | 280 | model.fit_generator(gen(train_loader, flag='train'), 281 | steps_per_epoch=102400, 282 | epochs=200, 283 | validation_data=gen(test_loader, flag='test'), 284 | callbacks=[checkpointer, rlu], 285 | validation_steps=1024) 286 | -------------------------------------------------------------------------------- /train/keras-train/basemodel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CodeREWorld/CV-OCR/943925b8bbe0f11c8079d10174f36d60faaeb88a/train/keras-train/basemodel.png -------------------------------------------------------------------------------- /train/keras-train/dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # !/usr/bin/python 3 | 4 | import random 5 | import sys 6 | 7 | import lmdb 8 | import numpy as np 9 | import six 10 | import torch 11 | import torchvision.transforms as transforms 12 | from PIL import Image 
13 | from torch.utils.data import Dataset 14 | from torch.utils.data import sampler 15 | 16 | 17 | 18 | class lmdbDataset(Dataset): 19 | def __init__(self, root=None, transform=None, target_transform=None): 20 | self.env = lmdb.open( 21 | root, 22 | max_readers=1, 23 | readonly=True, 24 | lock=False, 25 | readahead=False, 26 | meminit=False) 27 | 28 | if not self.env: 29 | print('cannot creat lmdb from %s' % (root)) 30 | sys.exit(0) 31 | 32 | with self.env.begin(write=False) as txn: 33 | nSamples = int(txn.get('num-samples'.encode())) 34 | print("nSamples:{}".format(nSamples)) 35 | self.nSamples = nSamples 36 | 37 | self.transform = transform 38 | self.target_transform = target_transform 39 | 40 | def __len__(self): 41 | return self.nSamples 42 | 43 | def __getitem__(self, index): 44 | assert index <= len(self), 'index range error' 45 | index += 1 46 | with self.env.begin(write=False) as txn: 47 | img_key = 'image-%09d' % index 48 | imgbuf = txn.get(img_key.encode()) 49 | 50 | buf = six.BytesIO() 51 | buf.write(imgbuf) 52 | buf.seek(0) 53 | try: 54 | img = Image.open(buf).convert('L') 55 | # img.save("1111111111.jpg") 56 | except IOError: 57 | print('Corrupted image for %d' % index) 58 | if index > self.nSamples - 1: 59 | index = 0 60 | return self[index + 1] 61 | 62 | if self.transform is not None: 63 | img = self.transform(img) 64 | 65 | label_key = 'label-%09d' % index 66 | label = str(txn.get(label_key.encode()), 'utf-8') 67 | 68 | if self.target_transform is not None: 69 | label = self.target_transform(label) 70 | # print(img,label) 71 | return (img, label) 72 | 73 | 74 | class resizeNormalize(object): 75 | def __init__(self, size, interpolation=Image.BILINEAR): 76 | self.size = size 77 | self.interpolation = interpolation 78 | self.toTensor = transforms.ToTensor() 79 | 80 | def __call__(self, img): 81 | img = img.resize(self.size, self.interpolation) 82 | img = self.toTensor(img) 83 | img.sub_(0.5).div_(0.5) 84 | return img 85 | 86 | 87 | class randomSequentialSampler(sampler.Sampler): 88 | def __init__(self, data_source, batch_size): 89 | self.num_samples = len(data_source) 90 | self.batch_size = batch_size 91 | 92 | def __iter__(self): 93 | n_batch = len(self) // self.batch_size 94 | tail = len(self) % self.batch_size 95 | index = torch.LongTensor(len(self)).fill_(0) 96 | for i in range(n_batch): 97 | random_start = random.randint(0, len(self) - self.batch_size) 98 | batch_index = random_start + torch.range(0, self.batch_size - 1) 99 | index[i * self.batch_size:(i + 1) * self.batch_size] = batch_index 100 | # deal with tail 101 | if tail: 102 | random_start = random.randint(0, len(self) - self.batch_size) 103 | tail_index = random_start + torch.range(0, tail - 1) 104 | index[(i + 1) * self.batch_size:] = tail_index 105 | 106 | return iter(index) 107 | 108 | def __len__(self): 109 | return self.num_samples 110 | 111 | 112 | class alignCollate(object): 113 | def __init__(self, imgH=32, imgW=100, keep_ratio=False, min_ratio=1): 114 | self.imgH = imgH 115 | self.imgW = imgW 116 | self.keep_ratio = keep_ratio 117 | self.min_ratio = min_ratio 118 | 119 | def __call__(self, batch): 120 | images, labels = zip(*batch) 121 | 122 | imgH = self.imgH 123 | imgW = self.imgW 124 | if self.keep_ratio: 125 | ratios = [] 126 | for image in images: 127 | w, h = image.size 128 | ratios.append(w / float(h)) 129 | ratios.sort() 130 | max_ratio = ratios[-1] 131 | imgW = int(np.floor(max_ratio * imgH)) 132 | imgW = max(imgH * self.min_ratio, imgW) # assure imgH >= imgW 133 | 134 | transform = 
resizeNormalize((imgW, imgH)) 135 | images = [transform(image) for image in images] 136 | images = torch.cat([t.unsqueeze(0) for t in images], 0) 137 | 138 | return images, labels 139 | --------------------------------------------------------------------------------
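One non-obvious detail in allinonetrain.py above is the `int(imgW / 4) - 1` that `gen()` passes as `input_length`: tracing the width through `get_model`'s stack (two stride-2 poolings, two width-stride-1 poolings each preceded by one column of zero padding, and a final valid 2x2 convolution) yields `imgW / 4 + 1` time steps, and `ctc_lambda_func` then discards the first two frames before calling `K.ctc_batch_cost`. A quick arithmetic check of that claim, assuming the layer stack exactly as defined in `get_model`:

```
# Width -> time-step arithmetic for the conv stack in get_model (e.g. imgW = 256).
def time_steps(w):
    w = w // 2       # pool1: pool (2, 2), stride (2, 2)
    w = w // 2       # pool2: pool (2, 2), stride (2, 2)
    w = (w + 2) - 1  # ZeroPadding2D((0, 1)) then pool3: width stride 1, pool width 2
    w = (w + 2) - 1  # ZeroPadding2D((0, 1)) then pool4: width stride 1, pool width 2
    w = w - 1        # conv7: 2x2 'valid' convolution
    return w

assert time_steps(256) == 256 // 4 + 1  # 65 frames come out of the CNN
# ctc_lambda_func drops the first 2 frames, leaving 63 == int(256 / 4) - 1,
# which is exactly the input_length that gen() feeds to K.ctc_batch_cost.
print(time_steps(256) - 2)
```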