├── .gitattributes
├── .gitignore
├── README.md
├── fonts
│   └── Ubuntu-M.ttf
├── imageGenerate.py
└── lstm+ctc
    ├── train.py
    └── utils.py

/.gitattributes:
--------------------------------------------------------------------------------
# Auto detect text files and perform LF normalization
* text=auto
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
.venv/
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Multiple end-to-end CAPTCHA recognition approaches

## Approaches
* CNN
* LSTM + CTC
* LSTM + WarpCTC
* GRU + CTC
* GRU + WarpCTC

## Results

Still being updated; the CNN and LSTM+CTC parts are finished.

|Method|Model|Checkpoint|Accuracy|Updated|
|--|--|--|--|--|
|CNN|CNN|checkpoints|98.5 %|20170823|
|LSTM + CTC|--|--|--|--|
|LSTM + WarpCTC|--|--|--|--|
|GRU + CTC|--|--|--|--|
|GRU + WarpCTC|--|--|--|--|
--------------------------------------------------------------------------------
/fonts/Ubuntu-M.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunjunee/verifyCodes/bcf7e518c0b901f7947ada4aa2915b04c556565b/fonts/Ubuntu-M.ttf
--------------------------------------------------------------------------------
/imageGenerate.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
@ Author: Jun Sun {Python3}
@ E-mail: sunjunee@qq.com
@ Create: 2017-08-21 15:04

Description: Generate CAPTCHA (verification code) images
"""
import random, os
from captcha.image import ImageCaptcha
from multiprocessing import Pool

# The ten digit characters used in the codes
char_set = '0123456789'
numProcess = 4

def gen_rand():
    """Return a random digit string of length 4 to 6."""
    buf = ""
    max_len = random.randint(4, 6)
    for i in range(max_len):
        buf += random.choice(char_set)
    return buf

def generateImg(imgDir, ind):
    """Generate one CAPTCHA image and save it as '<index>_<label>.png'."""
    captcha = ImageCaptcha(fonts=['fonts/Ubuntu-M.ttf'])
    theChars = gen_rand()
    img_name = '{:08d}'.format(ind) + '_' + theChars + '.png'
    img_path = imgDir + '/' + img_name
    captcha.write(theChars, img_path)

def run(num, path):
    """Generate `num` images into `path` using a process pool."""
    if not os.path.exists(path):
        os.mkdir(path)
    tasks = [(path, i) for i in range(num)]
    with Pool(processes=numProcess) as pool:
        pool.starmap(generateImg, tasks)

if __name__ == '__main__':
    run(10000, 'train')
    run(2000, 'validation')
--------------------------------------------------------------------------------
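Note: each image's label is stored only in its file name ('{index:08d}_{label}.png'), so a generated directory can be sanity-checked by re-parsing the names. A minimal sketch, assuming the naming scheme above; the summarize helper is illustrative and not part of the repo:

import os
from collections import Counter

def summarize(img_dir):
    # Count generated samples per label length by re-parsing the file names,
    # e.g. '00000042_8317.png' -> label '8317' -> length 4.
    lengths = Counter()
    for name in os.listdir(img_dir):
        if name.endswith('.png'):
            label = name.split('_')[1].split('.')[0]
            lengths[len(label)] += 1
    return lengths

# print(summarize('train'))   # e.g. Counter({4: ..., 5: ..., 6: ...})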
/lstm+ctc/train.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
@ Author: Jun Sun {Python3}
@ E-mail: sunjunee@qq.com
@ Create: 2017-08-21 20:18

Description: Train an LSTM with CTC loss for CAPTCHA recognition
"""

import os
import numpy as np
import tensorflow as tf
import time
import utils

restore = True
checkpoint_dir = './checkpoint/'
initial_learning_rate = 1e-3

num_layers = 2
num_hidden = 48
num_epochs = 100
batch_size = 50
save_steps = 200
validation_steps = 200

decay_rate = 0.9
decay_steps = 1000

beta1 = 0.9
beta2 = 0.999
momentum = 0.9

log_dir = './log'

image_width = 120
image_height = 45

num_features = image_height
num_classes = 10 + 1 + 1  # 10 digits + space + CTC blank

class Graph(object):
    def __init__(self):
        self.graph = tf.Graph()
        with self.graph.as_default():
            # LSTM input, shaped [batch_size, max_timesteps, num_features]:
            # num_features is one image column, max_timesteps is the number of columns.
            self.inputs = tf.placeholder(tf.float32, [None, None, num_features])
            shape = tf.shape(self.inputs)
            batch_s, _ = shape[0], shape[1]  # dynamic batch size (do not shadow the global batch_size)

            # Sparse placeholder required by ctc_loss
            self.labels = tf.sparse_placeholder(tf.int32)

            # 1-D vector holding the sequence length of every sample in the batch
            self.seq_len = tf.placeholder(tf.int32, [None])

            # Stacked RNN (two LSTM layers) with num_hidden units;
            # dynamic_rnn allows variable-length input sequences.
            stack = tf.contrib.rnn.MultiRNNCell(
                [tf.contrib.rnn.LSTMCell(num_hidden, state_is_tuple=True) for _ in range(num_layers)],
                state_is_tuple=True)
            outputs, _ = tf.nn.dynamic_rnn(stack, self.inputs, self.seq_len, dtype=tf.float32)
            outputs = tf.reshape(outputs, [-1, num_hidden])

            # Fully connected output layer
            W = tf.Variable(tf.truncated_normal([num_hidden, num_classes], stddev=0.1, dtype=tf.float32), name='W')
            b = tf.Variable(tf.constant(0., dtype=tf.float32, shape=[num_classes]), name='b')
            logits = tf.matmul(outputs, W) + b

            # Reshape back to [batch_size, -1, num_classes]
            logits = tf.reshape(logits, [batch_s, -1, num_classes])

            # Put the time dimension first: [max_timesteps, batch_size, num_classes]
            logits = tf.transpose(logits, (1, 0, 2))

            # CTC loss
            self.loss = tf.nn.ctc_loss(labels=self.labels, inputs=logits, sequence_length=self.seq_len)
            self.cost = tf.reduce_mean(self.loss)

            self.global_step = tf.Variable(0, trainable=False)
            self.learning_rate = tf.train.exponential_decay(initial_learning_rate, self.global_step, decay_steps, decay_rate, staircase=True)
            # self.optimizer = tf.train.MomentumOptimizer(learning_rate=self.learning_rate, momentum=momentum, use_nesterov=True).minimize(self.cost, global_step=self.global_step)
            # Note: Adam uses the fixed initial_learning_rate; the decayed self.learning_rate
            # above is only used by the commented-out momentum optimizer.
            self.optimizer = tf.train.AdamOptimizer(learning_rate=initial_learning_rate, beta1=beta1, beta2=beta2).minimize(self.cost, global_step=self.global_step)

            # Decoded predictions
            self.decoded, self.log_prob = tf.nn.ctc_beam_search_decoder(logits, self.seq_len, merge_repeated=False)
            self.dense_decoded = tf.sparse_tensor_to_dense(self.decoded[0], default_value=-1)

            # Label error rate (edit distance between prediction and ground truth)
            self.lerr = tf.reduce_mean(tf.edit_distance(tf.cast(self.decoded[0], tf.int32), self.labels))

            tf.summary.scalar('cost', self.cost)
            self.merged_summay = tf.summary.merge_all()
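# --- Illustrative sketch (not used elsewhere in this file) ---------------------
# A hedged example of the feed_dict shapes the Graph placeholders expect,
# assuming the preprocessing done by utils.DataIterator (images resized to
# 45x120 and transposed so that each of the 120 columns is one 45-feature
# time step). The _example_feed helper is for illustration only.
def _example_feed(g):
    dummy_inputs = np.zeros((batch_size, image_width, num_features), dtype=np.float32)  # [50, 120, 45]
    dummy_seq_len = np.full((batch_size,), image_width, dtype=np.int32)                  # 120 time steps per sample
    # The sparse labels placeholder is fed an (indices, values, shape) tuple,
    # e.g. the output of utils.sparse_tuple_from_label().
    dummy_labels = utils.sparse_tuple_from_label([[1, 2, 3, 4]] * batch_size)
    return {g.inputs: dummy_inputs, g.seq_len: dummy_seq_len, g.labels: dummy_labels}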
def train(train_dir=None, val_dir=None):
    # Load training and validation data
    print('Loading training data...')
    train_feeder = utils.DataIterator(data_dir=train_dir)
    print('Got images: ', train_feeder.size)

    print('Loading validation data...')
    val_feeder = utils.DataIterator(data_dir=val_dir)
    print('Got images: ', val_feeder.size)

    # Build the network
    g = Graph()

    # Total number of training samples
    num_train_samples = train_feeder.size
    # Number of batches per epoch
    num_batches_per_epoch = int(num_train_samples / batch_size)

    with tf.Session(graph=g.graph) as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)

        # If restore is True, load the latest checkpoint
        if restore:
            ckpt = tf.train.latest_checkpoint(checkpoint_dir)
            if ckpt:
                # global_step is restored as well
                saver.restore(sess, ckpt)
                print('Restored from checkpoint {0}'.format(ckpt))

        print('============begin training============')
        # Build one validation batch in the format the placeholders expect
        val_inputs, val_seq_len, val_labels = val_feeder.input_index_generate_batch()
        val_feed = {g.inputs: val_inputs, g.labels: val_labels, g.seq_len: val_seq_len}

        start_time = time.time()
        for cur_epoch in range(num_epochs):                          # loop over epochs
            shuffle_idx = np.random.permutation(num_train_samples)   # shuffle the training indices
            train_cost = 0

            for cur_batch in range(num_batches_per_epoch):           # loop over the batches of this epoch
                # Fetch one training batch in the placeholder input format
                indexs = [shuffle_idx[i % num_train_samples] for i in range(cur_batch * batch_size, (cur_batch + 1) * batch_size)]
                batch_inputs, batch_seq_len, batch_labels = train_feeder.input_index_generate_batch(indexs)
                feed = {g.inputs: batch_inputs, g.labels: batch_labels, g.seq_len: batch_seq_len}

                # One training step
                summary_str, batch_cost, step, _ = sess.run([g.merged_summay, g.cost, g.global_step, g.optimizer], feed)
                # Accumulate the loss
                train_cost += batch_cost

                # Progress logging
                if step % 50 == 1:
                    end_time = time.time()
                    print('No. %5d batches, loss: %5.2f, time: %3.1fs' % (step, batch_cost, end_time - start_time))
                    start_time = time.time()

                # Validate and save a checkpoint
                if step % validation_steps == 1:
                    if not os.path.isdir(checkpoint_dir):
                        os.mkdir(checkpoint_dir)
                    saver.save(sess, os.path.join(checkpoint_dir, 'ocr-model'), global_step=step)

                    # Decode the validation batch and report accuracy
                    dense_decoded, lastbatch_err, lr = sess.run([g.dense_decoded, g.lerr, g.learning_rate], val_feed)
                    acc = utils.accuracy_calculation(val_feeder.labels, dense_decoded, ignore_value=-1, isPrint=False)
                    print('-After %5d steps, Val accu: %4.2f%%' % (step, acc * 100))

if __name__ == '__main__':
    train(train_dir='train', val_dir='validation')
--------------------------------------------------------------------------------
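The training loop only prints the aggregate validation accuracy. To inspect individual predictions, the integer IDs in g.dense_decoded can be mapped back to digit strings with the decode_maps dictionary from utils.py. A minimal sketch; the decoded_to_text helper is illustrative and not part of the repo:

import utils

def decoded_to_text(decoded_row):
    # decode_maps maps 1..10 -> '0'..'9' and 0 -> '' (space);
    # -1 entries are the dense padding value and are skipped.
    return ''.join(utils.decode_maps[i] for i in decoded_row if i != -1)

# e.g. decoded_to_text([2, 4, 10, 1, -1]) -> '1390'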
/lstm+ctc/utils.py:
--------------------------------------------------------------------------------
import os
#import config
import numpy as np
import cv2

# 10 digits + blank + space
num_classes = 10 + 1 + 1
channel = 1
image_width = 120
image_height = 45
num_features = image_height * channel
SPACE_INDEX = 0
SPACE_TOKEN = ''
maxPrintLen = 10

charset = '0123456789'
encode_maps = {}
decode_maps = {}
for i, char in enumerate(charset, 1):
    encode_maps[char] = i
    decode_maps[i] = char
encode_maps[SPACE_TOKEN] = SPACE_INDEX
decode_maps[SPACE_INDEX] = SPACE_TOKEN

# Reads the data set and serves batches selected by index
class DataIterator:
    def __init__(self, data_dir):
        self.image_names = []
        self.image = []
        self.labels = []
        file_list = os.listdir(data_dir)
        for file_path in file_list:
            if file_path[-3:] != 'png':  # only keep .png images
                continue
            image_name = os.path.join(data_dir, file_path)
            self.image_names.append(image_name)

            im = cv2.imread(image_name, 0).astype(np.float32) / 255.
            im = cv2.resize(im, (image_width, image_height))
            im = im.swapaxes(0, 1)  # resize gives 45x120 (HxW); transpose to 120x45 so width is the time axis
            self.image.append(np.array(im))

            # The label is encoded in the file name: '<index>_<label>.png'
            code = file_path.split('_')[1].split('.')[0]
            code = [SPACE_INDEX if code == SPACE_TOKEN else encode_maps[c] for c in list(code)]
            self.labels.append(code)

    @property
    def size(self):
        return len(self.labels)

    def the_label(self, indexs):
        labels = []
        for i in indexs:
            labels.append(self.labels[i])
        return labels

    # Given a list of indices, return the corresponding images, sequence
    # lengths and sparse labels; with no indices, return the full data set.
    def input_index_generate_batch(self, index=None):
        if index is not None:
            image_batch = [self.image[i] for i in index]
            label_batch = [self.labels[i] for i in index]
        else:
            image_batch = self.image
            label_batch = self.labels

        def get_input_lens(sequences):
            lengths = np.asarray([len(s) for s in sequences], dtype=np.int64)
            return sequences, lengths

        batch_inputs, batch_seq_len = get_input_lens(np.array(image_batch))
        batch_labels = sparse_tuple_from_label(label_batch)
        return batch_inputs, batch_seq_len, batch_labels

def accuracy_calculation(original_seq, decoded_seq, ignore_value=-1, isPrint=True):
    count = 0
    for i, origin_label in enumerate(original_seq):
        decoded_label = [j for j in decoded_seq[i] if j != ignore_value]
        if origin_label == decoded_label:
            count += 1
    return count * 1.0 / len(original_seq)
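# --- Illustrative sketch (not called anywhere in the repo) ---------------------
# accuracy_calculation counts exact sequence matches after stripping the
# ignore_value padding, so a single wrong digit makes the whole sample count
# as incorrect. The _accuracy_example helper below only illustrates the
# expected behaviour.
def _accuracy_example():
    truth = [[2, 4, 1], [5, 5]]
    decoded = [[2, 4, 1, -1], [5, 6, -1]]  # second prediction is wrong
    return accuracy_calculation(truth, decoded, ignore_value=-1)  # -> 0.5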
def sparse_tuple_from_label(sequences, dtype=np.int32):
    """Create a sparse representation of x.
    Args:
        sequences: a list of lists of type dtype where each element is a sequence
    Returns:
        A tuple with (indices, values, shape)
    """
    indices = []
    values = []

    for n, seq in enumerate(sequences):
        indices.extend(zip([n] * len(seq), range(len(seq))))
        values.extend(seq)

    indices = np.asarray(indices, dtype=np.int64)
    values = np.asarray(values, dtype=dtype)
    shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1] + 1], dtype=np.int64)

    return indices, values, shape

def pad_input_sequences(sequences, maxlen=None, dtype=np.float32,
                        padding='post', truncating='post', value=0.):
    """Pads each sequence to the same length: the length of the longest
    sequence.
    If maxlen is provided, any sequence longer than maxlen is truncated to
    maxlen. Truncation happens off either the beginning or the end
    (default) of the sequence. Supports post-padding (default) and
    pre-padding.
    Args:
        sequences: list of lists where each element is a sequence
        maxlen: int, maximum length
        dtype: type to cast the resulting sequences to
        padding: 'pre' or 'post', pad either before or after each sequence
        truncating: 'pre' or 'post', remove values from sequences longer than
            maxlen either at the beginning or at the end of the sequence
        value: float, value used to pad the sequences
    Returns:
        x: numpy array with dimensions (number_of_sequences, maxlen)
        lengths: numpy array with the original sequence lengths
    """
    lengths = np.asarray([len(s) for s in sequences], dtype=np.int64)

    nb_samples = len(sequences)
    if maxlen is None:
        maxlen = np.max(lengths)

    # take the sample shape from the first non-empty sequence;
    # consistency is checked in the main loop below.
    sample_shape = tuple()
    for s in sequences:
        if len(s) > 0:
            sample_shape = np.asarray(s).shape[1:]
            break

    x = (np.ones((nb_samples, maxlen) + sample_shape) * value).astype(dtype)
    for idx, s in enumerate(sequences):
        if len(s) == 0:
            continue  # empty list was found
        if truncating == 'pre':
            trunc = s[-maxlen:]
        elif truncating == 'post':
            trunc = s[:maxlen]
        else:
            raise ValueError('Truncating type "%s" not understood' % truncating)

        # check that `trunc` has the expected shape
        trunc = np.asarray(trunc, dtype=dtype)
        if trunc.shape[1:] != sample_shape:
            raise ValueError('Shape of sample %s of sequence at position %s is different from expected shape %s' %
                             (trunc.shape[1:], idx, sample_shape))

        if padding == 'post':
            x[idx, :len(trunc)] = trunc
        elif padding == 'pre':
            x[idx, -len(trunc):] = trunc
        else:
            raise ValueError('Padding type "%s" not understood' % padding)
    return x, lengths
--------------------------------------------------------------------------------
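For reference, a small sketch of what sparse_tuple_from_label returns for two encoded labels; this tuple is the format fed to the tf.sparse_placeholder in train.py (the example values are arbitrary):

from utils import sparse_tuple_from_label

indices, values, shape = sparse_tuple_from_label([[2, 4, 1], [5, 6]])
# indices -> [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]]   (sample index, position in label)
# values  -> [2, 4, 1, 5, 6]                            (the encoded characters)
# shape   -> [2, 3]                                     (2 samples, longest label has 3 characters)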