├── README.md ├── crnn_predictor.py ├── data ├── README.md ├── test │ ├── README.md │ └── text │ │ └── README.md └── train │ ├── README.md │ └── text │ └── README.md ├── generate_data ├── font │ └── font.ttf └── generate_data.py ├── lstm_predictor.py ├── model └── README.md ├── symbol ├── __init__.py ├── bi_lstm.py ├── crnn.py └── lstm.py ├── train_bi_lstm.py ├── train_crnn.py └── train_lstm.py /README.md: -------------------------------------------------------------------------------- 1 | # CNN-LSTM-CTC text recognition 2 | 3 | I implemented three different models for text recognition. All of them use a CTC loss layer, so text images do not need to be segmented into individual characters. 4 | 5 | ### Disclaimer 6 | 7 | This code is based on the official MXNet warpctc example [here](https://github.com/dmlc/mxnet/tree/master/example/warpctc). 8 | 9 | ### Getting started 10 | * Build MXNet with Baidu Warp CTC by following the instructions [here](https://github.com/dmlc/mxnet/tree/master/example/warpctc). 11 | 12 | When I followed these official instructions to add Baidu Warp CTC to MXNet, I got errors because the latest version of Baidu Warp CTC conflicts with MXNet. Someone has since fixed this problem and updated the official MXNet warpctc example. If you still run into problems, please refer to this pull request [here](https://github.com/dmlc/mxnet/pull/3853). 13 | 14 | ### Generating data 15 | 16 | Run `generate_data.py` in `generate_data`. When generating training and test data, remember to change the output path and the number of samples in `generate_data.py` (I plan to add a friendlier way to generate training and test data when I have time). 17 | 18 | ### Train the model 19 | 20 | I implemented three different models for text recognition; you can find them in `symbol`: 21 | 22 | 1. LSTM + CTC; 23 | 2. Bidirectional LSTM + CTC; 24 | 3. CNN (a modified model similar to VGG) + Bidirectional LSTM + CTC.
Disclaimer: This CNN + LSTM + CTC model is a re-implementation of original CRNN which is based on torch. The official repository is available [here](https://github.com/bgshih/crnn). The arxiv paper is available [here](https://arxiv.org/pdf/1507.05717v1.pdf). 25 | 26 | * Start training: 27 | 28 | LSTM + CTC: 29 | 30 | ``` 31 | python train_lstm.py 32 | ``` 33 | 34 | Bidirection LSTM + CTC: 35 | 36 | ``` 37 | python train_bi_lstm.py 38 | ``` 39 | 40 | CNN + Bidirection LSTM + CTC: 41 | 42 | ``` 43 | python train_crnn.py 44 | ``` 45 | ### Prediction 46 | 47 | You can do the prediction using your trained model. I only write the predictors for model 1 and model 3, but it is very easy to write the predictor for model 2 when referring to the examples. 48 | 49 | Plesae run: 50 | ``` 51 | python lstm_predictor.py 52 | ``` 53 | or 54 | ``` 55 | python crnn_predictor.py 56 | ``` 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /crnn_predictor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2.7 2 | # coding=utf-8 3 | from __future__ import print_function 4 | import sys, os 5 | curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) 6 | sys.path.append("../../amalgamation/python/") 7 | sys.path.append("../../python/") 8 | import argparse 9 | 10 | from mxnet_predict import Predictor 11 | import mxnet as mx 12 | 13 | from symbol.crnn import crnn 14 | 15 | import numpy as np 16 | import cv2 17 | import os 18 | import matplotlib.pyplot as plt 19 | 20 | class lstm_ocr_model(object): 21 | # Keep Zero index for blank. 
(CTC request it) 22 | def __init__(self, path_of_json, path_of_params, classes, data_shape, batch_size, num_label, num_hidden, num_lstm_layer): 23 | super(lstm_ocr_model, self).__init__() 24 | self.path_of_json = path_of_json 25 | self.path_of_params = path_of_params 26 | self.classes = classes 27 | self.batch_size = batch_size 28 | self.data_shape = data_shape 29 | self.num_label = num_label 30 | self.num_hidden = num_hidden 31 | self.num_lstm_layer = num_lstm_layer 32 | self.predictor = None 33 | self.__init_ocr() 34 | 35 | def __init_ocr(self): 36 | init_c = [('l%d_init_c'%l, (self.batch_size, self.num_hidden)) for l in range(self.num_lstm_layer*2)] 37 | init_h = [('l%d_init_h'%l, (self.batch_size, self.num_hidden)) for l in range(self.num_lstm_layer*2)] 38 | init_states = init_c + init_h 39 | 40 | all_shapes = [('data', (batch_size, 1, self.data_shape[1], self.data_shape[0]))] + init_states + [('label', (self.batch_size, self.num_label))] 41 | all_shapes_dict = {} 42 | for _shape in all_shapes: 43 | all_shapes_dict[_shape[0]] = _shape[1] 44 | self.predictor = Predictor(open(self.path_of_json).read(), 45 | open(self.path_of_params).read(), 46 | all_shapes_dict,dev_type="gpu", dev_id=0) 47 | 48 | def forward_ocr(self, img): 49 | img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY) 50 | img = cv2.resize(img, self.data_shape) 51 | img = img.reshape((1, data_shape[1], data_shape[0])) 52 | img = np.multiply(img, 1/255.0) 53 | self.predictor.forward(data=img) 54 | prob = self.predictor.get_output(0) 55 | label_list = [] 56 | for p in prob: 57 | max_index = np.argsort(p)[::-1][0] 58 | label_list.append(max_index) 59 | return self.__get_string(label_list) 60 | 61 | def __get_string(self, label_list): 62 | # Do CTC label rule 63 | # CTC cannot emit a repeated symbol on consecutive timesteps 64 | ret = [] 65 | label_list2 = [0] + list(label_list) 66 | for i in range(len(label_list)): 67 | c1 = label_list2[i] 68 | c2 = label_list2[i+1] 69 | if c2 == 0 or c2 == c1: 70 | continue 71 | 
ret.append(c2) 72 | # change to ascii 73 | s = '' 74 | for l in ret: 75 | if l > 0 and l < (len(self.classes)+1): 76 | c = self.classes[l-1] 77 | else: 78 | c = '' 79 | s += c 80 | return s 81 | 82 | def parse_args(): 83 | parser = argparse.ArgumentParser(description='predictor') 84 | parser.add_argument('--img', dest='img', help='which image to use', 85 | default=os.path.join(os.getcwd(), 'data', 'demo', '20150105_14543723_Z.jpg'), type=str) 86 | args = parser.parse_args() 87 | return args 88 | 89 | 90 | if __name__ == '__main__': 91 | args = parse_args() 92 | json_path = os.path.join(os.getcwd(), 'model', 'crnn_ctc-symbol.json') 93 | param_path = os.path.join(os.getcwd(), 'model', 'crnn_ctc-0100.params') 94 | num_label = 9 # Set your max length of label, add one more for blank 95 | batch_size = 1 96 | num_hidden = 256 97 | num_lstm_layer = 2 98 | data_shape = (100, 32) 99 | classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", 100 | "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"] 101 | demo_img = args.img 102 | 103 | _lstm_ocr_model = lstm_ocr_model(json_path, param_path, classes, data_shape, batch_size, 104 | num_label, num_hidden, num_lstm_layer) 105 | img = cv2.imread(demo_img) 106 | #img = cv2.bitwise_not(img) 107 | _str = _lstm_ocr_model.forward_ocr(img) 108 | print('Result: ', _str) 109 | plt.imshow(img) 110 | plt.gca().text(0, 6.8, 111 | '{:s} {:s}'.format("prediction", _str), 112 | #bbox=dict(facecolor=colors[cls_id], alpha=0.5), 113 | fontsize=12, color='red') 114 | plt.show() 115 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | #### This is the default directory to store generated data 2 | -------------------------------------------------------------------------------- /data/test/README.md: 
-------------------------------------------------------------------------------- 1 | #### This is the default directory to store generated test data 2 | -------------------------------------------------------------------------------- /data/test/text/README.md: -------------------------------------------------------------------------------- 1 | #### This is the default directory to store generated test images data 2 | -------------------------------------------------------------------------------- /data/train/README.md: -------------------------------------------------------------------------------- 1 | #### This is the default directory to store generated train data 2 | -------------------------------------------------------------------------------- /data/train/text/README.md: -------------------------------------------------------------------------------- 1 | #### This is the default directory to store generated train image data 2 | -------------------------------------------------------------------------------- /generate_data/font/font.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oyxhust/CNN-LSTM-CTC-text-recognition/60250e82383ab5f3e59e062ec27b1171a8ed64c7/generate_data/font/font.ttf -------------------------------------------------------------------------------- /generate_data/generate_data.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import PIL 3 | from PIL import ImageFont 4 | from PIL import Image 5 | from PIL import ImageDraw 6 | import cv2; 7 | import numpy as np 8 | import random 9 | import os; 10 | from math import * 11 | import cPickle 12 | import re 13 | 14 | 15 | index = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4, "5": 5, "6": 6, "7": 7, "8": 8, "9": 9, "A": 10, "B": 11, "C": 12, "D": 13, "E": 14, "F": 15, "G": 16, "H": 17, 16 | "I":18, "J": 19, "K": 20, "L": 21, "M": 22, "N": 23, "O": 24, "P": 25, "Q": 26, "R": 27, 
"S": 28, "T": 29, "U": 30, "V": 31, 17 | "W": 32, "X": 33, "Y": 34, "Z": 35}; 18 | 19 | chars = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", 20 | "Y", "Z"]; 21 | 22 | def r(val): 23 | return int(np.random.random() * val) 24 | 25 | def random_pick(some_list, probabilities): 26 | x = random.uniform(0,1) 27 | cumulative_probability = 0.0 28 | for item, item_probability in zip(some_list, probabilities): 29 | cumulative_probability += item_probability 30 | if x < cumulative_probability:break 31 | return item 32 | 33 | def rot(img,angel,shape,max_angel,bg_gray): 34 | size_o = [shape[1],shape[0]] 35 | 36 | size = (shape[1] + int(shape[0]*cos((float(max_angel )/180) * 3.14)),shape[0]) 37 | 38 | 39 | interval = abs(int(sin((float(angel) /180) * 3.14)* shape[0])) 40 | 41 | pts1 = np.float32([[0,0], [0,size_o[1]], [size_o[0],0], [size_o[0], size_o[1]]]) 42 | if(angel>0): 43 | 44 | pts2 = np.float32([[interval,0],[0,size[1] ],[size[0],0 ],[size[0]-interval,size_o[1]]]) 45 | else: 46 | pts2 = np.float32([[0,0],[interval,size[1] ],[size[0]-interval,0 ],[size[0],size_o[1]]]) 47 | 48 | M = cv2.getPerspectiveTransform(pts1,pts2) 49 | dst = cv2.warpPerspective(img,M,size,borderValue=bg_gray) 50 | 51 | return dst 52 | 53 | def rotRandrom(img, factor, size, bg_gray): 54 | shape = size 55 | pts1 = np.float32([[0, 0], [0, shape[0]], [shape[1], 0], [shape[1], shape[0]]]) 56 | pts2 = np.float32([[r(factor), r(factor)], [ r(factor), shape[0] - r(factor)], [shape[1] - r(factor), r(factor)], 57 | [shape[1] - r(factor), shape[0] - r(factor)]]) 58 | M = cv2.getPerspectiveTransform(pts1, pts2) 59 | dst = cv2.warpPerspective(img, M, size, borderValue=bg_gray) 60 | return dst 61 | 62 | def tfactor(img): 63 | 64 | img[:,:] = img[:,:]*(0.8+ np.random.random()*0.2) 65 | return img 66 | 67 | 68 | def random_scale(x,y): 69 | gray_out = r(y+1-x) + x 70 | return gray_out 71 | 72 
| def text_Gengray(bg_gray, line): 73 | gray_flag = np.random.randint(2) 74 | if bg_gray < line: 75 | text_gray = random_scale(bg_gray + line, 255) 76 | elif bg_gray > (255 - line): 77 | text_gray = random_scale(0, bg_gray - line) 78 | else: 79 | text_gray = gray_flag*random_scale(0, bg_gray - line) + (1 - gray_flag)*random_scale(bg_gray+line, 255) 80 | return text_gray 81 | 82 | def GenCh(f,val, data_shape1, data_shape2, bg_gray, text_gray, text_position): 83 | img=Image.new("L", (data_shape1,data_shape2),bg_gray) 84 | draw = ImageDraw.Draw(img) 85 | draw.text((0, text_position),val.decode('utf-8'),text_gray,font=f) 86 | #draw.text((0, text_position),val.decode('utf-8'),0,font=f) 87 | A = np.array(img) 88 | return A 89 | 90 | def AddNoiseSingleChannel(single): 91 | diff = 255-single.max(); 92 | noise = np.random.normal(0,1+r(6),single.shape); 93 | noise = (noise - noise.min())/(noise.max()-noise.min()) 94 | noise= diff*noise; 95 | noise= noise.astype(np.uint8) 96 | dst = single + noise 97 | return dst 98 | 99 | def Addblur(img, val): 100 | blur_kernel = r(val) + 1 101 | #print blur_kernel 102 | img = cv2.blur(img, (blur_kernel,blur_kernel)) 103 | return img 104 | 105 | class GenText: 106 | def __init__(self, font, font_size, counter): 107 | self.font = ImageFont.truetype(font,font_size) 108 | self.counter = counter 109 | 110 | def draw(self,val,data_shape1, data_shape2): 111 | bg_gray = r(256) 112 | text_gray = text_Gengray(bg_gray, 60) 113 | text_position = random_scale(4,12) 114 | offset_left = int(np.random.random() * 30) 115 | offset_right = int(np.random.random() * 30) 116 | offset_middle = 17 117 | add_position = -1 118 | if self.counter > 4: 119 | add_number = np.random.randint(2) 120 | else: 121 | add_number = 0 122 | if add_number == 1: 123 | add_position = np.random.randint(self.counter-1) 124 | offset = offset_left + offset_right + offset_middle*add_number 125 | img = np.array(Image.new("L", (self.counter * data_shape1 + offset, data_shape2), bg_gray)) 
126 | base = offset_left 127 | for i in range(counter): 128 | #offset_middle_add = random_pick([0,1],[0.8,0.2])*offset_middle 129 | img[0: data_shape2, base : base + data_shape1]= GenCh(self.font,val[i], data_shape1, data_shape2, bg_gray, text_gray, text_position) 130 | base += data_shape1 131 | if add_position == i: 132 | base += offset_middle 133 | return img, bg_gray 134 | 135 | def generate(self,text, data_shape1, data_shape2): 136 | fg, bg_gray = self.draw(text.decode(encoding="utf-8"),data_shape1, data_shape2) 137 | com = rot(fg,r(60)-30,fg.shape,30, bg_gray) 138 | com = rotRandrom(com,10,(com.shape[1],com.shape[0]), bg_gray) 139 | com = tfactor(com) 140 | com = Addblur(com, 8) 141 | #com = AddNoiseSingleChannel(com) 142 | return com 143 | 144 | def genTextString(self, counter): 145 | textStr = ""; 146 | for idx in xrange(counter): 147 | textStr += chars[r(len(chars))] 148 | 149 | return textStr; 150 | 151 | 152 | if __name__ == '__main__': 153 | outputPath = "../data/train" 154 | #outputPath = "../data/test" 155 | gt = [] 156 | imgaePath = os.path.join(outputPath, 'text') 157 | num = 1000 158 | font_size = 60 159 | data_shape1 = 30 160 | data_shape2 = 80 161 | 162 | if (not os.path.exists(imgaePath)): 163 | os.mkdir(imgaePath) 164 | sum = 0 165 | for counter in xrange(2,10): 166 | G = GenText("./font/font.ttf", font_size, counter) 167 | for i in xrange(num): 168 | textStr = G.genTextString(counter) 169 | print textStr 170 | img = G.generate(textStr, data_shape1, data_shape2) 171 | #print img.shape 172 | cv2.imwrite(os.path.join(imgaePath, str(i + sum) + ".jpg"), img) 173 | gt.append(textStr) 174 | sum += num 175 | gt_file = open(os.path.join(outputPath, 'gt.pkl'), 'wb') 176 | cPickle.dump(gt, gt_file) 177 | gt_file.close() 178 | -------------------------------------------------------------------------------- /lstm_predictor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2.7 2 | # coding=utf-8 3 | 
from __future__ import print_function 4 | import sys, os 5 | curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) 6 | sys.path.append("../../amalgamation/python/") 7 | sys.path.append("../../python/") 8 | 9 | from mxnet_predict import Predictor 10 | import mxnet as mx 11 | 12 | from symbol.lstm import lstm_unroll 13 | 14 | import numpy as np 15 | import cv2 16 | import os 17 | import matplotlib.pyplot as plt 18 | 19 | class lstm_ocr_model(object): 20 | # Keep Zero index for blank. (CTC request it) 21 | def __init__(self, path_of_json, path_of_params, classes, data_shape, batch_size, num_label, num_hidden, num_lstm_layer): 22 | super(lstm_ocr_model, self).__init__() 23 | self.path_of_json = path_of_json 24 | self.path_of_params = path_of_params 25 | self.classes = classes 26 | self.batch_size = batch_size 27 | self.data_shape = data_shape 28 | self.num_label = num_label 29 | self.num_hidden = num_hidden 30 | self.num_lstm_layer = num_lstm_layer 31 | self.predictor = None 32 | self.__init_ocr() 33 | 34 | def __init_ocr(self): 35 | init_c = [('l%d_init_c'%l, (self.batch_size, self.num_hidden)) for l in range(self.num_lstm_layer)] 36 | init_h = [('l%d_init_h'%l, (self.batch_size, self.num_hidden)) for l in range(self.num_lstm_layer)] 37 | init_states = init_c + init_h 38 | 39 | all_shapes = [('data', (self.batch_size, self.data_shape[0] * self.data_shape[1]))] + init_states + [('label', (self.batch_size, self.num_label))] 40 | all_shapes_dict = {} 41 | for _shape in all_shapes: 42 | all_shapes_dict[_shape[0]] = _shape[1] 43 | self.predictor = Predictor(open(self.path_of_json).read(), 44 | open(self.path_of_params).read(), 45 | all_shapes_dict) 46 | 47 | def forward_ocr(self, img): 48 | img = cv2.resize(img, self.data_shape) 49 | img = img.transpose(1, 0) 50 | img = img.reshape((self.data_shape[0] * self.data_shape[1])) 51 | img = np.multiply(img, 1/255.0) 52 | self.predictor.forward(data=img) 53 | prob = self.predictor.get_output(0) 54 | 
label_list = [] 55 | for p in prob: 56 | max_index = np.argsort(p)[::-1][0] 57 | label_list.append(max_index) 58 | return self.__get_string(label_list) 59 | 60 | def __get_string(self, label_list): 61 | # Do CTC label rule 62 | # CTC cannot emit a repeated symbol on consecutive timesteps 63 | ret = [] 64 | label_list2 = [0] + list(label_list) 65 | for i in range(len(label_list)): 66 | c1 = label_list2[i] 67 | c2 = label_list2[i+1] 68 | if c2 == 0 or c2 == c1: 69 | continue 70 | ret.append(c2) 71 | # change to ascii 72 | s = '' 73 | for l in ret: 74 | if l > 0 and l < (len(self.classes)+1): 75 | c = self.classes[l-1] 76 | else: 77 | c = '' 78 | s += c 79 | return s 80 | 81 | if __name__ == '__main__': 82 | json_path = os.path.join(os.getcwd(), 'model', 'lctc-symbol.json') 83 | param_path = os.path.join(os.getcwd(), 'model', 'lctc-0100.params') 84 | num_label = 9 # Set your max length of label, add one more for blank 85 | batch_size = 1 86 | num_hidden = 100 87 | num_lstm_layer = 2 88 | data_shape = (80, 30) 89 | classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", 90 | "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"] 91 | demo_img = os.path.join(os.getcwd(), 'data', 'demo', '6986.jpg') 92 | 93 | _lstm_ocr_model = lstm_ocr_model(json_path, param_path, classes, data_shape, batch_size, 94 | num_label, num_hidden, num_lstm_layer) 95 | original_img = cv2.imread(demo_img) 96 | img = cv2.cvtColor(original_img, cv2.COLOR_BGR2GRAY) 97 | _str = _lstm_ocr_model.forward_ocr(img) 98 | print('Result: ', _str) 99 | plt.imshow(original_img) 100 | plt.gca().text(0, 8, 101 | '{:s} {:s}'.format("prediction", _str), 102 | #bbox=dict(facecolor=colors[cls_id], alpha=0.5), 103 | fontsize=12, color='red') 104 | plt.show() 105 | -------------------------------------------------------------------------------- /model/README.md: 
-------------------------------------------------------------------------------- 1 | #### This is the default directory to store all the models, including `*.params` and `*.json` 2 | -------------------------------------------------------------------------------- /symbol/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oyxhust/CNN-LSTM-CTC-text-recognition/60250e82383ab5f3e59e062ec27b1171a8ed64c7/symbol/__init__.py -------------------------------------------------------------------------------- /symbol/bi_lstm.py: -------------------------------------------------------------------------------- 1 | # pylint:skip-file 2 | import mxnet as mx 3 | import numpy as np 4 | from collections import namedtuple 5 | import time 6 | import math 7 | LSTMState = namedtuple("LSTMState", ["c", "h"]) 8 | LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias", 9 | "h2h_weight", "h2h_bias"]) 10 | LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol", 11 | "init_states", "last_states", "forward_state", "backward_state", 12 | "seq_data", "seq_labels", "seq_outputs", 13 | "param_blocks"]) 14 | 15 | def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.): 16 | """LSTM Cell symbol""" 17 | if dropout > 0.: 18 | indata = mx.sym.Dropout(data=indata, p=dropout) 19 | i2h = mx.sym.FullyConnected(data=indata, 20 | weight=param.i2h_weight, 21 | bias=param.i2h_bias, 22 | num_hidden=num_hidden * 4, 23 | name="t%d_l%d_i2h" % (seqidx, layeridx)) 24 | h2h = mx.sym.FullyConnected(data=prev_state.h, 25 | weight=param.h2h_weight, 26 | bias=param.h2h_bias, 27 | num_hidden=num_hidden * 4, 28 | name="t%d_l%d_h2h" % (seqidx, layeridx)) 29 | gates = i2h + h2h 30 | slice_gates = mx.sym.SliceChannel(gates, num_outputs=4, 31 | name="t%d_l%d_slice" % (seqidx, layeridx)) 32 | in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid") 33 | in_transform = mx.sym.Activation(slice_gates[1], 
act_type="tanh") 34 | forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid") 35 | out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid") 36 | next_c = (forget_gate * prev_state.c) + (in_gate * in_transform) 37 | next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh") 38 | return LSTMState(c=next_c, h=next_h) 39 | 40 | 41 | def bi_lstm_unroll(num_lstm_layer, seq_len, num_hidden, num_classes, num_label, dropout=0.): 42 | 43 | last_states = [] 44 | forward_param = [] 45 | backward_param = [] 46 | for i in range(num_lstm_layer*2): 47 | last_states.append(LSTMState(c = mx.sym.Variable("l%d_init_c" % i), h = mx.sym.Variable("l%d_init_h" % i))) 48 | if i % 2 == 0: 49 | forward_param.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i), 50 | i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i), 51 | h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i), 52 | h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i))) 53 | else: 54 | backward_param.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i), 55 | i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i), 56 | h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i), 57 | h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i))) 58 | 59 | # embeding layer 60 | data = mx.sym.Variable('data') 61 | label = mx.sym.Variable('label') 62 | wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1) 63 | 64 | forward_hidden = [] 65 | for seqidx in range(seq_len): 66 | hidden = wordvec[seqidx] 67 | for i in range(num_lstm_layer): 68 | next_state = lstm(num_hidden, indata=hidden, 69 | prev_state=last_states[2*i], 70 | param=forward_param[i], 71 | seqidx=seqidx, layeridx=0, dropout=dropout) 72 | hidden = next_state.h 73 | last_states[2*i] = next_state 74 | forward_hidden.append(hidden) 75 | 76 | backward_hidden = [] 77 | for seqidx in range(seq_len): 78 | k = seq_len - seqidx - 1 79 | hidden = wordvec[k] 80 | for i in range(num_lstm_layer): 81 | next_state = lstm(num_hidden, indata=hidden, 82 | 
prev_state=last_states[2*i + 1], 83 | param=backward_param[i], 84 | seqidx=k, layeridx=1,dropout=dropout) 85 | hidden = next_state.h 86 | last_states[2*i + 1] = next_state 87 | backward_hidden.insert(0, hidden) 88 | 89 | hidden_all = [] 90 | for i in range(seq_len): 91 | hidden_all.append(mx.sym.Concat(*[forward_hidden[i], backward_hidden[i]], dim=1)) 92 | 93 | hidden_concat = mx.sym.Concat(*hidden_all, dim=0) 94 | pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_classes) 95 | 96 | label = mx.sym.Reshape(data=label, shape=(-1,)) 97 | label = mx.sym.Cast(data = label, dtype = 'int32') 98 | sm = mx.sym.WarpCTC(data=pred, label=label, label_length = num_label, input_length = seq_len) 99 | 100 | return sm 101 | 102 | 103 | def bi_lstm_inference_symbol(input_size, seq_len, 104 | num_hidden, num_embed, num_label, dropout=0.): 105 | seqidx = 0 106 | embed_weight=mx.sym.Variable("embed_weight") 107 | cls_weight = mx.sym.Variable("cls_weight") 108 | cls_bias = mx.sym.Variable("cls_bias") 109 | last_states = [LSTMState(c = mx.sym.Variable("l0_init_c"), h = mx.sym.Variable("l0_init_h")), 110 | LSTMState(c = mx.sym.Variable("l1_init_c"), h = mx.sym.Variable("l1_init_h"))] 111 | forward_param = LSTMParam(i2h_weight=mx.sym.Variable("l0_i2h_weight"), 112 | i2h_bias=mx.sym.Variable("l0_i2h_bias"), 113 | h2h_weight=mx.sym.Variable("l0_h2h_weight"), 114 | h2h_bias=mx.sym.Variable("l0_h2h_bias")) 115 | backward_param = LSTMParam(i2h_weight=mx.sym.Variable("l1_i2h_weight"), 116 | i2h_bias=mx.sym.Variable("l1_i2h_bias"), 117 | h2h_weight=mx.sym.Variable("l1_h2h_weight"), 118 | h2h_bias=mx.sym.Variable("l1_h2h_bias")) 119 | data = mx.sym.Variable("data") 120 | embed = mx.sym.Embedding(data=data, input_dim=input_size, 121 | weight=embed_weight, output_dim=num_embed, name='embed') 122 | wordvec = mx.sym.SliceChannel(data=embed, num_outputs=seq_len, squeeze_axis=1) 123 | forward_hidden = [] 124 | for seqidx in range(seq_len): 125 | next_state = lstm(num_hidden, 
indata=wordvec[seqidx], 126 | prev_state=last_states[0], 127 | param=forward_param, 128 | seqidx=seqidx, layeridx=0, dropout=0.0) 129 | hidden = next_state.h 130 | last_states[0] = next_state 131 | forward_hidden.append(hidden) 132 | 133 | backward_hidden = [] 134 | for seqidx in range(seq_len): 135 | k = seq_len - seqidx - 1 136 | next_state = lstm(num_hidden, indata=wordvec[k], 137 | prev_state=last_states[1], 138 | param=backward_param, 139 | seqidx=k, layeridx=1, dropout=0.0) 140 | hidden = next_state.h 141 | last_states[1] = next_state 142 | backward_hidden.insert(0, hidden) 143 | 144 | hidden_all = [] 145 | for i in range(seq_len): 146 | hidden_all.append(mx.sym.Concat(*[forward_hidden[i], backward_hidden[i]], dim=1)) 147 | hidden_concat = mx.sym.Concat(*hidden_all, dim=0) 148 | fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_label, 149 | weight=cls_weight, bias=cls_bias, name='pred') 150 | sm = mx.sym.SoftmaxOutput(data=fc, name='softmax') 151 | output = [sm] 152 | for state in last_states: 153 | output.append(state.c) 154 | output.append(state.h) 155 | return mx.sym.Group(output) 156 | 157 | -------------------------------------------------------------------------------- /symbol/crnn.py: -------------------------------------------------------------------------------- 1 | # pylint:skip-file 2 | import mxnet as mx 3 | import numpy as np 4 | from collections import namedtuple 5 | import time 6 | import math 7 | LSTMState = namedtuple("LSTMState", ["c", "h"]) 8 | LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias", 9 | "h2h_weight", "h2h_bias"]) 10 | LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol", 11 | "init_states", "last_states", "forward_state", "backward_state", 12 | "seq_data", "seq_labels", "seq_outputs", 13 | "param_blocks"]) 14 | 15 | def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.): 16 | """LSTM Cell symbol""" 17 | if dropout > 0.: 18 | indata = mx.sym.Dropout(data=indata, p=dropout) 19 | 
i2h = mx.sym.FullyConnected(data=indata, 20 | weight=param.i2h_weight, 21 | bias=param.i2h_bias, 22 | num_hidden=num_hidden * 4, 23 | name="t%d_l%d_i2h" % (seqidx, layeridx)) 24 | h2h = mx.sym.FullyConnected(data=prev_state.h, 25 | weight=param.h2h_weight, 26 | bias=param.h2h_bias, 27 | num_hidden=num_hidden * 4, 28 | name="t%d_l%d_h2h" % (seqidx, layeridx)) 29 | gates = i2h + h2h 30 | slice_gates = mx.sym.SliceChannel(gates, num_outputs=4, 31 | name="t%d_l%d_slice" % (seqidx, layeridx)) 32 | in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid") 33 | in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh") 34 | forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid") 35 | out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid") 36 | next_c = (forget_gate * prev_state.c) + (in_gate * in_transform) 37 | next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh") 38 | return LSTMState(c=next_c, h=next_h) 39 | 40 | 41 | def crnn(num_lstm_layer, seq_len, num_hidden, num_classes, num_label, dropout=0.): 42 | 43 | last_states = [] 44 | forward_param = [] 45 | backward_param = [] 46 | for i in range(num_lstm_layer*2): 47 | last_states.append(LSTMState(c = mx.sym.Variable("l%d_init_c" % i), h = mx.sym.Variable("l%d_init_h" % i))) 48 | if i % 2 == 0: 49 | forward_param.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i), 50 | i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i), 51 | h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i), 52 | h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i))) 53 | else: 54 | backward_param.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i), 55 | i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i), 56 | h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i), 57 | h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i))) 58 | 59 | # input 60 | data = mx.sym.Variable('data') 61 | label = mx.sym.Variable('label') 62 | 63 | #CNN model- similar to VGG 64 | # group 1 65 | conv1_1 = mx.symbol.Convolution( 66 | data=data, 
kernel=(3, 3), pad=(1, 1), num_filter=64, name="conv1_1") 67 | relu1_1 = mx.symbol.Activation(data=conv1_1, act_type="relu", name="relu1_1") 68 | # conv1_2 = mx.symbol.Convolution( 69 | # data=relu1_1, kernel=(3, 3), pad=(1, 1), num_filter=64, name="conv1_2") 70 | # relu1_2 = mx.symbol.Activation(data=conv1_2, act_type="relu", name="relu1_2") 71 | pool1 = mx.symbol.Pooling( 72 | data=relu1_1, pool_type="max", kernel=(2, 2), stride=(2, 2), name="pool1") 73 | # group 2 74 | conv2_1 = mx.symbol.Convolution( 75 | data=pool1, kernel=(3, 3), pad=(1, 1), num_filter=128, name="conv2_1") 76 | relu2_1 = mx.symbol.Activation(data=conv2_1, act_type="relu", name="relu2_1") 77 | # conv2_2 = mx.symbol.Convolution( 78 | # data=relu2_1, kernel=(3, 3), pad=(1, 1), num_filter=128, name="conv2_2") 79 | # relu2_2 = mx.symbol.Activation(data=conv2_2, act_type="relu", name="relu2_2") 80 | pool2 = mx.symbol.Pooling( 81 | data=relu2_1, pool_type="max", kernel=(2, 2), stride=(2, 2), name="pool2") 82 | # group 3 83 | conv3_1 = mx.symbol.Convolution( 84 | data=pool2, kernel=(3, 3), pad=(1, 1), num_filter=256, name="conv3_1") 85 | batchnorm1 = mx.symbol.BatchNorm(data= conv3_1, name="batchnorm1") 86 | relu3_1 = mx.symbol.Activation(data=batchnorm1, act_type="relu", name="relu3_1") 87 | conv3_2 = mx.symbol.Convolution( 88 | data=relu3_1, kernel=(3, 3), pad=(1, 1), num_filter=256, name="conv3_2") 89 | relu3_2 = mx.symbol.Activation(data=conv3_2, act_type="relu", name="relu3_2") 90 | # conv3_3 = mx.symbol.Convolution( 91 | # data=relu3_2, kernel=(3, 3), pad=(1, 1), num_filter=256, name="conv3_3") 92 | # relu3_3 = mx.symbol.Activation(data=conv3_3, act_type="relu", name="relu3_3") 93 | pool3 = mx.symbol.Pooling( 94 | data=relu3_2, pool_type="max", kernel=(2, 1), stride=(2, 1), name="pool3") 95 | # group 4 96 | conv4_1 = mx.symbol.Convolution( 97 | data=pool3, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv4_1") 98 | batchnorm2 = mx.symbol.BatchNorm(data= conv4_1, name="batchnorm2") 99 | 
relu4_1 = mx.symbol.Activation(data=batchnorm2, act_type="relu", name="relu4_1") 100 | conv4_2 = mx.symbol.Convolution( 101 | data=batchnorm1, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv4_2") 102 | relu4_2 = mx.symbol.Activation(data=conv4_2, act_type="relu", name="relu4_2") 103 | # conv4_3 = mx.symbol.Convolution( 104 | # data=relu4_2, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv4_3") 105 | # relu4_3 = mx.symbol.Activation(data=conv4_3, act_type="relu", name="relu4_3") 106 | pool4 = mx.symbol.Pooling( 107 | data=batchnorm2, pool_type="max", kernel=(2, 2), stride=(2, 1), pad=(0, 1), name="pool4") 108 | # group 5 109 | conv5_1 = mx.symbol.Convolution( 110 | data=pool4, kernel=(2, 2), pad=(0, 0), num_filter=512, name="conv5_1") 111 | batchnorm3 = mx.symbol.BatchNorm(data= conv5_1, name="batchnorm3") 112 | # relu5_1 = mx.symbol.Activation(data=conv5_1, act_type="relu", name="relu5_1") 113 | # conv5_2 = mx.symbol.Convolution( 114 | # data=relu5_1, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv5_2") 115 | # relu5_2 = mx.symbol.Activation(data=conv5_2, act_type="relu", name="relu5_2") 116 | # conv5_3 = mx.symbol.Convolution( 117 | # data=relu5_2, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv5_3") 118 | # relu5_3 = mx.symbol.Activation(data=conv5_3, act_type="relu", name="relu5_3") 119 | # pool5 = mx.symbol.Pooling( 120 | # data=relu5_3, pool_type="max", kernel=(3, 3), stride=(1, 1), 121 | # pad=(1,1), name="pool5") 122 | if dropout > 0.: 123 | batchnorm3 = mx.sym.Dropout(data=batchnorm3, p=dropout) 124 | # arg_shape, output_shape, aux_shape = batchnorm3.infer_shape(data=(32,1,32,100)) 125 | # print(output_shape) 126 | cnn_out = mx.sym.transpose(data=batchnorm3, axes=(0,3,1,2), name="cnn_out") 127 | # arg_shape, output_shape, aux_shape = cnn_out.infer_shape(data=(32,1,32,100)) 128 | # print(output_shape) 129 | flatten_out = mx.sym.Flatten(data=cnn_out, name="flatten_out") 130 | # arg_shape, output_shape, aux_shape = 
flatten_out.infer_shape(data=(32,1,32,100)) 131 | # print(output_shape) 132 | wordvec = mx.sym.SliceChannel(data=flatten_out, num_outputs=seq_len, squeeze_axis=1) 133 | 134 | forward_hidden = [] 135 | for seqidx in range(seq_len): 136 | hidden = wordvec[seqidx] 137 | for i in range(num_lstm_layer): 138 | next_state = lstm(num_hidden, indata=hidden, 139 | prev_state=last_states[2*i], 140 | param=forward_param[i], 141 | seqidx=seqidx, layeridx=0, dropout=dropout) 142 | hidden = next_state.h 143 | last_states[2*i] = next_state 144 | forward_hidden.append(hidden) 145 | 146 | backward_hidden = [] 147 | for seqidx in range(seq_len): 148 | k = seq_len - seqidx - 1 149 | hidden = wordvec[k] 150 | for i in range(num_lstm_layer): 151 | next_state = lstm(num_hidden, indata=hidden, 152 | prev_state=last_states[2*i + 1], 153 | param=backward_param[i], 154 | seqidx=k, layeridx=1,dropout=dropout) 155 | hidden = next_state.h 156 | last_states[2*i + 1] = next_state 157 | backward_hidden.insert(0, hidden) 158 | 159 | 160 | hidden_all = [] 161 | for i in range(seq_len): 162 | hidden_all.append(mx.sym.Concat(*[forward_hidden[i], backward_hidden[i]], dim=1)) 163 | 164 | hidden_concat = mx.sym.Concat(*hidden_all, dim=0) 165 | pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_classes) 166 | 167 | label = mx.sym.Reshape(data=label, shape=(-1,)) 168 | label = mx.sym.Cast(data = label, dtype = 'int32') 169 | sm = mx.sym.WarpCTC(data=pred, label=label, label_length = num_label, input_length = seq_len) 170 | 171 | return sm 172 | 173 | -------------------------------------------------------------------------------- /symbol/lstm.py: -------------------------------------------------------------------------------- 1 | # pylint:skip-file 2 | import sys 3 | sys.path.insert(0, "../../python") 4 | import mxnet as mx 5 | import numpy as np 6 | from collections import namedtuple 7 | import time 8 | import math 9 | LSTMState = namedtuple("LSTMState", ["c", "h"]) 10 | LSTMParam = 
namedtuple("LSTMParam", ["i2h_weight", "i2h_bias", 11 | "h2h_weight", "h2h_bias"]) 12 | LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol", 13 | "init_states", "last_states", 14 | "seq_data", "seq_labels", "seq_outputs", 15 | "param_blocks"]) 16 | 17 | def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.): 18 | """LSTM Cell symbol""" 19 | if dropout > 0.: 20 | indata = mx.sym.Dropout(data=indata, p=dropout) 21 | i2h = mx.sym.FullyConnected(data=indata, 22 | weight=param.i2h_weight, 23 | bias=param.i2h_bias, 24 | num_hidden=num_hidden * 4, 25 | name="t%d_l%d_i2h" % (seqidx, layeridx)) 26 | h2h = mx.sym.FullyConnected(data=prev_state.h, 27 | weight=param.h2h_weight, 28 | bias=param.h2h_bias, 29 | num_hidden=num_hidden * 4, 30 | name="t%d_l%d_h2h" % (seqidx, layeridx)) 31 | gates = i2h + h2h 32 | slice_gates = mx.sym.SliceChannel(gates, num_outputs=4, 33 | name="t%d_l%d_slice" % (seqidx, layeridx)) 34 | in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid") 35 | in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh") 36 | forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid") 37 | out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid") 38 | next_c = (forget_gate * prev_state.c) + (in_gate * in_transform) 39 | next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh") 40 | return LSTMState(c=next_c, h=next_h) 41 | 42 | 43 | def lstm_unroll(num_lstm_layer, seq_len, 44 | num_hidden, num_classes, num_label, dropout=0.): 45 | param_cells = [] 46 | last_states = [] 47 | for i in range(num_lstm_layer): 48 | param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i), 49 | i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i), 50 | h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i), 51 | h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i))) 52 | state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i), 53 | h=mx.sym.Variable("l%d_init_h" % i)) 54 | last_states.append(state) 55 | assert(len(last_states) 
== num_lstm_layer) 56 | 57 | # embeding layer 58 | data = mx.sym.Variable('data') 59 | label = mx.sym.Variable('label') 60 | wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1) 61 | 62 | hidden_all = [] 63 | for seqidx in range(seq_len): 64 | hidden = wordvec[seqidx] 65 | for i in range(num_lstm_layer): 66 | next_state = lstm(num_hidden, indata=hidden, 67 | prev_state=last_states[i], 68 | param=param_cells[i], 69 | seqidx=seqidx, layeridx=i, dropout=dropout) 70 | hidden = next_state.h 71 | last_states[i] = next_state 72 | hidden_all.append(hidden) 73 | 74 | hidden_concat = mx.sym.Concat(*hidden_all, dim=0) 75 | pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_classes) 76 | 77 | label = mx.sym.Reshape(data=label, shape=(-1,)) 78 | label = mx.sym.Cast(data = label, dtype = 'int32') 79 | sm = mx.sym.WarpCTC(data=pred, label=label, label_length = num_label, input_length = seq_len) 80 | return sm 81 | 82 | -------------------------------------------------------------------------------- /train_bi_lstm.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys, random 3 | sys.path.insert(0, "../../python") 4 | import numpy as np 5 | import mxnet as mx 6 | import logging 7 | 8 | from symbol.bi_lstm import bi_lstm_unroll 9 | 10 | from io import BytesIO 11 | import cv2, random 12 | import cPickle 13 | import os 14 | 15 | 16 | class SimpleBatch(object): 17 | def __init__(self, data_names, data, label_names, label): 18 | self.data = data 19 | self.label = label 20 | self.data_names = data_names 21 | self.label_names = label_names 22 | 23 | self.pad = 0 24 | self.index = None # TODO: what is index? 
25 | 26 | @property 27 | def provide_data(self): 28 | return [(n, x.shape) for n, x in zip(self.data_names, self.data)] 29 | 30 | @property 31 | def provide_label(self): 32 | return [(n, x.shape) for n, x in zip(self.label_names, self.label)] 33 | 34 | class OCRIter(mx.io.DataIter): 35 | def __init__(self, batch_size, classes, data_shape, num_label, init_states, shuffle=True, train_flag=True): 36 | super(OCRIter, self).__init__() 37 | self.batch_size = batch_size 38 | self.data_shape = data_shape 39 | self.num_label = num_label 40 | self.init_states = init_states 41 | self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states] 42 | self.classes = classes 43 | if train_flag: 44 | self.data_path = os.path.join(os.getcwd(), 'data', 'train', 'text') 45 | self.label_path = os.path.join(os.getcwd(), 'data', 'train') 46 | else: 47 | self.data_path = os.path.join(os.getcwd(), 'data', 'test', 'text') 48 | self.label_path = os.path.join(os.getcwd(), 'data', 'test') 49 | self.image_set_index = self._load_image_set_index(shuffle) 50 | self.count = len(self.image_set_index) / self.batch_size 51 | self.gt = self._label_path_from_index() 52 | self.provide_data = [('data', (batch_size, data_shape[0]*data_shape[1]))] + init_states 53 | self.provide_label = [('label', (self.batch_size, num_label))] 54 | 55 | def __iter__(self): 56 | #print('iter') 57 | init_state_names = [x[0] for x in self.init_states] 58 | for k in range(self.count): 59 | data = [] 60 | label = [] 61 | for i in range(self.batch_size): 62 | img_name = self.image_set_index[i + k*self.batch_size] 63 | img = cv2.imread(os.path.join(self.data_path, img_name + '.jpg')) 64 | img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY) 65 | img = cv2.resize(img, self.data_shape) 66 | #print(img) 67 | img = img.transpose(1, 0) 68 | img = img.reshape((data_shape[0] * data_shape[1])) 69 | img = np.multiply(img, 1/255.0) 70 | #print(img) 71 | data.append(img) 72 | ret = np.zeros(self.num_label, int) 73 | plate_str = 
self.gt[int(img_name)] 74 | #print(plate_str) 75 | for number in range(len(plate_str)): 76 | ret[number] = self.classes.index(plate_str[number]) + 1 77 | #print(ret) 78 | label.append(ret) 79 | 80 | data_all = [mx.nd.array(data)] + self.init_state_arrays 81 | label_all = [mx.nd.array(label)] 82 | data_names = ['data'] + init_state_names 83 | label_names = ['label'] 84 | 85 | 86 | data_batch = SimpleBatch(data_names, data_all, label_names, label_all) 87 | yield data_batch 88 | 89 | def reset(self): 90 | pass 91 | 92 | def _load_image_set_index(self, shuffle): 93 | assert os.path.isdir(self.data_path), 'Path does not exist: {}'.format(image_set_path) 94 | image_set_index = [] 95 | list_dir = os.walk(self.data_path) 96 | for root, _, image_names in list_dir: 97 | for name in image_names: 98 | image_set_index.append(name.split('.')[0]) 99 | if shuffle: 100 | np.random.shuffle(image_set_index) 101 | return image_set_index 102 | 103 | def _label_path_from_index(self): 104 | label_file = os.path.join(self.label_path, 'gt.pkl') 105 | assert os.path.exists(label_file), 'Path does not exist: {}'.format(label_file) 106 | gt_file = open(label_file, 'rb') 107 | label_file = cPickle.load(gt_file) 108 | gt_file.close() 109 | return label_file 110 | 111 | BATCH_SIZE = 32 112 | SEQ_LENGTH = 80 113 | 114 | def ctc_label(p): 115 | ret = [] 116 | p1 = [0] + p 117 | for i in range(len(p)): 118 | c1 = p1[i] 119 | c2 = p1[i+1] 120 | if c2 == 0 or c2 == c1: 121 | continue 122 | ret.append(c2) 123 | return ret 124 | 125 | def remove_blank(l): 126 | ret = [] 127 | for i in range(len(l)): 128 | if l[i] == 0: 129 | break 130 | ret.append(l[i]) 131 | return ret 132 | 133 | def Accuracy(label, pred): 134 | global BATCH_SIZE 135 | global SEQ_LENGTH 136 | hit = 0. 137 | total = 0. 
138 | for i in range(BATCH_SIZE): 139 | l = remove_blank(label[i]) 140 | p = [] 141 | for k in range(SEQ_LENGTH): 142 | p.append(np.argmax(pred[k * BATCH_SIZE + i])) 143 | p = ctc_label(p) 144 | if len(p) == len(l): 145 | match = True 146 | for k in range(len(p)): 147 | if p[k] != int(l[k]): 148 | match = False 149 | break 150 | if match: 151 | hit += 1.0 152 | total += 1.0 153 | return hit / total 154 | 155 | if __name__ == '__main__': 156 | # set up logger 157 | log_file_name = "bi_lstm_plate.log" 158 | log_file = open(log_file_name, 'w') 159 | log_file.close() 160 | logging.basicConfig() 161 | logger = logging.getLogger() 162 | logger.setLevel(logging.INFO) 163 | fh = logging.FileHandler(log_file_name) 164 | logger.addHandler(fh) 165 | 166 | prefix = os.path.join(os.getcwd(), 'model', 'bi_lstm_ctc') 167 | 168 | num_hidden = 100 169 | num_lstm_layer = 2 170 | 171 | num_epoch = 100 172 | learning_rate = 0.001 173 | momentum = 0.9 174 | num_label = 9 175 | data_shape = (80, 30) 176 | classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", 177 | "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"] 178 | num_classes = len(classes) + 1 179 | 180 | contexts = [mx.context.gpu(0)] 181 | 182 | def sym_gen(seq_len): 183 | return bi_lstm_unroll(num_lstm_layer, seq_len, 184 | num_hidden=num_hidden, num_classes = num_classes, 185 | num_label = num_label) 186 | 187 | init_c = [('l%d_init_c'%l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer*2)] 188 | init_h = [('l%d_init_h'%l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer*2)] 189 | init_states = init_c + init_h 190 | 191 | data_train = OCRIter(BATCH_SIZE, classes, data_shape, num_label, init_states) 192 | data_val = OCRIter(BATCH_SIZE, classes, data_shape, num_label, init_states, train_flag=False) 193 | 194 | symbol = sym_gen(SEQ_LENGTH) 195 | 196 | model = mx.model.FeedForward(ctx=contexts, 197 | symbol=symbol, 198 | 
num_epoch=num_epoch, 199 | learning_rate=learning_rate, 200 | momentum=momentum, 201 | wd=0.00001, 202 | initializer=mx.init.Xavier(factor_type="in", magnitude=2.34)) 203 | 204 | import logging 205 | head = '%(asctime)-15s %(message)s' 206 | logging.basicConfig(level=logging.DEBUG, format=head) 207 | 208 | logger.info('begin fit') 209 | 210 | model.fit(X=data_train, eval_data=data_val, 211 | eval_metric = mx.metric.np(Accuracy), 212 | batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 50), logger = logger, 213 | epoch_end_callback = mx.callback.do_checkpoint(prefix, 1)) 214 | 215 | model.save("bi_lctc") 216 | 217 | -------------------------------------------------------------------------------- /train_crnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys, random 3 | sys.path.insert(0, "../../python") 4 | import numpy as np 5 | import mxnet as mx 6 | import logging 7 | 8 | from symbol.crnn import crnn 9 | 10 | from io import BytesIO 11 | import cv2, random 12 | import cPickle 13 | import os 14 | 15 | 16 | class SimpleBatch(object): 17 | def __init__(self, data_names, data, label_names, label): 18 | self.data = data 19 | self.label = label 20 | self.data_names = data_names 21 | self.label_names = label_names 22 | 23 | self.pad = 0 24 | self.index = None # TODO: what is index? 
25 | 26 | @property 27 | def provide_data(self): 28 | return [(n, x.shape) for n, x in zip(self.data_names, self.data)] 29 | 30 | @property 31 | def provide_label(self): 32 | return [(n, x.shape) for n, x in zip(self.label_names, self.label)] 33 | 34 | class OCRIter(mx.io.DataIter): 35 | def __init__(self, batch_size, classes, data_shape, num_label, init_states, shuffle=True, train_flag=True): 36 | super(OCRIter, self).__init__() 37 | self.batch_size = batch_size 38 | self.data_shape = data_shape 39 | self.num_label = num_label 40 | self.init_states = init_states 41 | self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states] 42 | self.classes = classes 43 | if train_flag: 44 | self.data_path = os.path.join(os.getcwd(), 'data', 'train', 'text') 45 | self.label_path = os.path.join(os.getcwd(), 'data', 'train') 46 | else: 47 | self.data_path = os.path.join(os.getcwd(), 'data', 'test', 'text') 48 | self.label_path = os.path.join(os.getcwd(), 'data', 'test') 49 | self.image_set_index = self._load_image_set_index(shuffle) 50 | self.count = len(self.image_set_index) / self.batch_size 51 | self.gt = self._label_path_from_index() 52 | self.provide_data = [('data', (batch_size, 1, data_shape[1], data_shape[0]))] + init_states 53 | self.provide_label = [('label', (self.batch_size, num_label))] 54 | 55 | def __iter__(self): 56 | #print('iter') 57 | init_state_names = [x[0] for x in self.init_states] 58 | for k in range(self.count): 59 | data = [] 60 | label = [] 61 | for i in range(self.batch_size): 62 | img_name = self.image_set_index[i + k*self.batch_size] 63 | img = cv2.imread(os.path.join(self.data_path, img_name + '.jpg'), cv2.IMREAD_GRAYSCALE) 64 | #img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY) 65 | img = cv2.resize(img, self.data_shape) 66 | img = img.reshape((1, data_shape[1], data_shape[0])) 67 | #print(img) 68 | #img = img.transpose(1, 0) 69 | #img = img.reshape((data_shape[0] * data_shape[1])) 70 | img = np.multiply(img, 1/255.0) 71 | #print(img) 72 | 
data.append(img) 73 | ret = np.zeros(self.num_label, int) 74 | plate_str = self.gt[int(img_name)] 75 | #print(plate_str) 76 | for number in range(len(plate_str)): 77 | ret[number] = self.classes.index(plate_str[number]) + 1 78 | #print(ret) 79 | label.append(ret) 80 | 81 | data_all = [mx.nd.array(data)] + self.init_state_arrays 82 | label_all = [mx.nd.array(label)] 83 | data_names = ['data'] + init_state_names 84 | label_names = ['label'] 85 | 86 | 87 | data_batch = SimpleBatch(data_names, data_all, label_names, label_all) 88 | yield data_batch 89 | 90 | def reset(self): 91 | pass 92 | 93 | def _load_image_set_index(self, shuffle): 94 | assert os.path.isdir(self.data_path), 'Path does not exist: {}'.format(image_set_path) 95 | image_set_index = [] 96 | list_dir = os.walk(self.data_path) 97 | for root, _, image_names in list_dir: 98 | for name in image_names: 99 | image_set_index.append(name.split('.')[0]) 100 | if shuffle: 101 | np.random.shuffle(image_set_index) 102 | return image_set_index 103 | 104 | def _label_path_from_index(self): 105 | label_file = os.path.join(self.label_path, 'gt.pkl') 106 | assert os.path.exists(label_file), 'Path does not exist: {}'.format(label_file) 107 | gt_file = open(label_file, 'rb') 108 | label_file = cPickle.load(gt_file) 109 | gt_file.close() 110 | return label_file 111 | 112 | BATCH_SIZE = 32 113 | SEQ_LENGTH = 25 114 | 115 | def ctc_label(p): 116 | ret = [] 117 | p1 = [0] + p 118 | for i in range(len(p)): 119 | c1 = p1[i] 120 | c2 = p1[i+1] 121 | if c2 == 0 or c2 == c1: 122 | continue 123 | ret.append(c2) 124 | return ret 125 | 126 | def remove_blank(l): 127 | ret = [] 128 | for i in range(len(l)): 129 | if l[i] == 0: 130 | break 131 | ret.append(l[i]) 132 | return ret 133 | 134 | def Accuracy(label, pred): 135 | global BATCH_SIZE 136 | global SEQ_LENGTH 137 | hit = 0. 138 | total = 0. 
139 | for i in range(BATCH_SIZE): 140 | l = remove_blank(label[i]) 141 | p = [] 142 | for k in range(SEQ_LENGTH): 143 | p.append(np.argmax(pred[k * BATCH_SIZE + i])) 144 | p = ctc_label(p) 145 | if len(p) == len(l): 146 | match = True 147 | for k in range(len(p)): 148 | if p[k] != int(l[k]): 149 | match = False 150 | break 151 | if match: 152 | hit += 1.0 153 | total += 1.0 154 | return hit / total 155 | 156 | if __name__ == '__main__': 157 | # set up logger 158 | log_file_name = "crnn_plate.log" 159 | log_file = open(log_file_name, 'w') 160 | log_file.close() 161 | logging.basicConfig() 162 | logger = logging.getLogger() 163 | logger.setLevel(logging.INFO) 164 | fh = logging.FileHandler(log_file_name) 165 | logger.addHandler(fh) 166 | 167 | prefix = os.path.join(os.getcwd(), 'model', 'crnn_ctc') 168 | 169 | num_hidden = 256 170 | num_lstm_layer = 2 171 | 172 | num_epoch = 100 173 | learning_rate = 0.001 174 | momentum = 0.9 175 | num_label = 9 176 | data_shape = (100, 32) 177 | classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", 178 | "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"] 179 | num_classes = len(classes) + 1 180 | 181 | contexts = [mx.context.gpu(0)] 182 | 183 | def sym_gen(seq_len): 184 | return crnn(num_lstm_layer, seq_len, 185 | num_hidden=num_hidden, num_classes = num_classes, 186 | num_label = num_label, dropout=0.3) 187 | 188 | init_c = [('l%d_init_c'%l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer*2)] 189 | init_h = [('l%d_init_h'%l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer*2)] 190 | init_states = init_c + init_h 191 | 192 | data_train = OCRIter(BATCH_SIZE, classes, data_shape, num_label, init_states) 193 | data_val = OCRIter(BATCH_SIZE, classes, data_shape, num_label, init_states, train_flag=False) 194 | 195 | symbol = sym_gen(SEQ_LENGTH) 196 | 197 | model = mx.model.FeedForward(ctx=contexts, 198 | symbol=symbol, 199 | 
num_epoch=num_epoch, 200 | learning_rate=learning_rate, 201 | momentum=momentum, 202 | wd=0.00001, 203 | #optimizer='AdaDelta', 204 | initializer=mx.init.Xavier(factor_type="in", magnitude=2.34)) 205 | 206 | import logging 207 | head = '%(asctime)-15s %(message)s' 208 | logging.basicConfig(level=logging.DEBUG, format=head) 209 | 210 | logger.info('begin fit') 211 | 212 | model.fit(X=data_train, eval_data=data_val, 213 | eval_metric = mx.metric.np(Accuracy), 214 | batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 50), logger = logger, 215 | epoch_end_callback = mx.callback.do_checkpoint(prefix, 1)) 216 | 217 | model.save("crnnctc") 218 | 219 | -------------------------------------------------------------------------------- /train_lstm.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme 2 | # pylint: disable=superfluous-parens, no-member, invalid-name 3 | from __future__ import print_function 4 | import sys, random 5 | sys.path.insert(0, "../../python") 6 | import numpy as np 7 | import mxnet as mx 8 | import logging 9 | 10 | from symbol.lstm import lstm_unroll 11 | 12 | from io import BytesIO 13 | import cv2, random 14 | import cPickle 15 | import os 16 | 17 | 18 | class SimpleBatch(object): 19 | def __init__(self, data_names, data, label_names, label): 20 | self.data = data 21 | self.label = label 22 | self.data_names = data_names 23 | self.label_names = label_names 24 | 25 | self.pad = 0 26 | self.index = None # TODO: what is index? 
27 | 28 | @property 29 | def provide_data(self): 30 | return [(n, x.shape) for n, x in zip(self.data_names, self.data)] 31 | 32 | @property 33 | def provide_label(self): 34 | return [(n, x.shape) for n, x in zip(self.label_names, self.label)] 35 | 36 | class OCRIter(mx.io.DataIter): 37 | def __init__(self, batch_size, classes, data_shape, num_label, init_states, shuffle=True, train_flag=True): 38 | super(OCRIter, self).__init__() 39 | self.batch_size = batch_size 40 | self.data_shape = data_shape 41 | self.num_label = num_label 42 | self.init_states = init_states 43 | self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states] 44 | self.classes = classes 45 | if train_flag: 46 | self.data_path = os.path.join(os.getcwd(), 'data', 'train', 'text') 47 | self.label_path = os.path.join(os.getcwd(), 'data', 'train') 48 | else: 49 | self.data_path = os.path.join(os.getcwd(), 'data', 'test', 'text') 50 | self.label_path = os.path.join(os.getcwd(), 'data', 'test') 51 | self.image_set_index = self._load_image_set_index(shuffle) 52 | self.count = len(self.image_set_index) / self.batch_size 53 | self.gt = self._label_path_from_index() 54 | self.provide_data = [('data', (batch_size, data_shape[0]*data_shape[1]))] + init_states 55 | self.provide_label = [('label', (self.batch_size, num_label))] 56 | 57 | def __iter__(self): 58 | #print('iter') 59 | init_state_names = [x[0] for x in self.init_states] 60 | for k in range(self.count): 61 | data = [] 62 | label = [] 63 | for i in range(self.batch_size): 64 | img_name = self.image_set_index[i + k*self.batch_size] 65 | img = cv2.imread(os.path.join(self.data_path, img_name + '.jpg')) 66 | img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY) 67 | img = cv2.resize(img, self.data_shape) 68 | #print(img) 69 | img = img.transpose(1, 0) 70 | img = img.reshape((data_shape[0] * data_shape[1])) 71 | img = np.multiply(img, 1/255.0) 72 | #print(img) 73 | data.append(img) 74 | ret = np.zeros(self.num_label, int) 75 | plate_str = 
self.gt[int(img_name)] 76 | #print(plate_str) 77 | for number in range(len(plate_str)): 78 | ret[number] = self.classes.index(plate_str[number]) + 1 79 | #print(ret) 80 | label.append(ret) 81 | 82 | data_all = [mx.nd.array(data)] + self.init_state_arrays 83 | label_all = [mx.nd.array(label)] 84 | data_names = ['data'] + init_state_names 85 | label_names = ['label'] 86 | 87 | 88 | data_batch = SimpleBatch(data_names, data_all, label_names, label_all) 89 | yield data_batch 90 | 91 | def reset(self): 92 | pass 93 | 94 | def _load_image_set_index(self, shuffle): 95 | assert os.path.isdir(self.data_path), 'Path does not exist: {}'.format(image_set_path) 96 | image_set_index = [] 97 | list_dir = os.walk(self.data_path) 98 | for root, _, image_names in list_dir: 99 | for name in image_names: 100 | image_set_index.append(name.split('.')[0]) 101 | if shuffle: 102 | np.random.shuffle(image_set_index) 103 | return image_set_index 104 | 105 | def _label_path_from_index(self): 106 | label_file = os.path.join(self.label_path, 'gt.pkl') 107 | assert os.path.exists(label_file), 'Path does not exist: {}'.format(label_file) 108 | gt_file = open(label_file, 'rb') 109 | label_file = cPickle.load(gt_file) 110 | gt_file.close() 111 | return label_file 112 | 113 | BATCH_SIZE = 32 114 | SEQ_LENGTH = 150 115 | 116 | def ctc_label(p): 117 | ret = [] 118 | p1 = [0] + p 119 | for i in range(len(p)): 120 | c1 = p1[i] 121 | c2 = p1[i+1] 122 | if c2 == 0 or c2 == c1: 123 | continue 124 | ret.append(c2) 125 | return ret 126 | 127 | def remove_blank(l): 128 | ret = [] 129 | for i in range(len(l)): 130 | if l[i] == 0: 131 | break 132 | ret.append(l[i]) 133 | return ret 134 | 135 | def Accuracy(label, pred): 136 | global BATCH_SIZE 137 | global SEQ_LENGTH 138 | hit = 0. 139 | total = 0. 
140 | for i in range(BATCH_SIZE): 141 | l = remove_blank(label[i]) 142 | p = [] 143 | for k in range(SEQ_LENGTH): 144 | p.append(np.argmax(pred[k * BATCH_SIZE + i])) 145 | p = ctc_label(p) 146 | if len(p) == len(l): 147 | match = True 148 | for k in range(len(p)): 149 | if p[k] != int(l[k]): 150 | match = False 151 | break 152 | if match: 153 | hit += 1.0 154 | total += 1.0 155 | return hit / total 156 | 157 | if __name__ == '__main__': 158 | # set up logger 159 | log_file_name = "lstm_plate.log" 160 | log_file = open(log_file_name, 'w') 161 | log_file.close() 162 | logging.basicConfig() 163 | logger = logging.getLogger() 164 | logger.setLevel(logging.INFO) 165 | fh = logging.FileHandler(log_file_name) 166 | logger.addHandler(fh) 167 | 168 | prefix = os.path.join(os.getcwd(), 'model', 'lstm_ctc') 169 | 170 | num_hidden = 100 171 | num_lstm_layer = 2 172 | 173 | num_epoch = 100 174 | learning_rate = 0.001 175 | momentum = 0.9 176 | num_label = 9 177 | data_shape = (80, 30) 178 | classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", 179 | "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"] 180 | num_classes = len(classes) + 1 181 | 182 | contexts = [mx.context.gpu(0)] 183 | 184 | def sym_gen(seq_len): 185 | return lstm_unroll(num_lstm_layer, seq_len, 186 | num_hidden=num_hidden, num_classes = num_classes, 187 | num_label = num_label) 188 | 189 | init_c = [('l%d_init_c'%l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)] 190 | init_h = [('l%d_init_h'%l, (BATCH_SIZE, num_hidden)) for l in range(num_lstm_layer)] 191 | init_states = init_c + init_h 192 | 193 | data_train = OCRIter(BATCH_SIZE, classes, data_shape, num_label, init_states) 194 | data_val = OCRIter(BATCH_SIZE, classes, data_shape, num_label, init_states, train_flag=False) 195 | 196 | symbol = sym_gen(SEQ_LENGTH) 197 | 198 | model = mx.model.FeedForward(ctx=contexts, 199 | symbol=symbol, 200 | num_epoch=num_epoch, 
201 | learning_rate=learning_rate, 202 | momentum=momentum, 203 | wd=0.00001, 204 | initializer=mx.init.Xavier(factor_type="in", magnitude=2.34)) 205 | 206 | import logging 207 | head = '%(asctime)-15s %(message)s' 208 | logging.basicConfig(level=logging.DEBUG, format=head) 209 | 210 | logger.info('begin fit') 211 | 212 | model.fit(X=data_train, eval_data=data_val, 213 | eval_metric = mx.metric.np(Accuracy), 214 | batch_end_callback=mx.callback.Speedometer(BATCH_SIZE, 50), logger = logger, 215 | epoch_end_callback = mx.callback.do_checkpoint(prefix, 1)) 216 | 217 | model.save("lctc") 218 | 219 | --------------------------------------------------------------------------------