├── README.md ├── config.py ├── generate.py ├── .gitignore ├── preprocess.py ├── main.py └── reference.py /README.md: -------------------------------------------------------------------------------- 1 | # CrackCaptcha 2 | Crack Captcha Using TensorFlow 3 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | VOCAB = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] 2 | CAPTCHA_LENGTH = 4 3 | VOCAB_LENGTH = len(VOCAB) 4 | DATA_LENGTH = 10000 5 | DATA_PATH = 'data' -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | from captcha.image import ImageCaptcha 2 | from PIL import Image 3 | 4 | text = '1234' 5 | image = ImageCaptcha() 6 | captcha = image.generate(text) 7 | captcha_image = Image.open(captcha) 8 | captcha_image.show() 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | data/ 8 | ckpt/ 9 | # C extensions 10 | *.so 11 | .idea/ 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
# /.gitignore (continued):
# *.manifest
# *.spec
#
# # Installer logs
# pip-log.txt
# pip-delete-this-directory.txt
#
# # Unit test / coverage reports
# htmlcov/
# .tox/
# .coverage
# .coverage.*
# .cache
# nosetests.xml
# coverage.xml
# *.cover
# .hypothesis/
#
# # Translations
# *.mo
# *.pot
#
# # Django stuff:
# *.log
# local_settings.py
#
# # Flask stuff:
# instance/
# .webassets-cache
#
# # Scrapy stuff:
# .scrapy
#
# # Sphinx documentation
# docs/_build/
#
# # PyBuilder
# target/
#
# # Jupyter Notebook
# .ipynb_checkpoints
#
# # pyenv
# .python-version
#
# # celery beat schedule file
# celerybeat-schedule
#
# # SageMath parsed files
# *.sage.py
#
# # Environments
# .env
# .venv
# env/
# venv/
# ENV/
#
# # Spyder project settings
# .spyderproject
# .spyproject
#
# # Rope project settings
# .ropeproject
#
# # mkdocs documentation
# /site
#
# # mypy
# .mypy_cache/
#
# --------------------------------------------------------------------------------
# /preprocess.py:
# --------------------------------------------------------------------------------
from os.path import join, exists
import pickle
from PIL import Image
from captcha.image import ImageCaptcha
import numpy as np
import random
from os import makedirs
from config import *


def generate_captcha(captcha_text):
    """
    render captcha text to an image and return it as np array
    :param captcha_text: source text
    :return: np array of the captcha image
    """
    image = ImageCaptcha()
    captcha = image.generate(captcha_text)
    captcha_image = Image.open(captcha)
    captcha_array = np.array(captcha_image)
    return captcha_array


def text2vec(text):
    """
    text to one-hot vector
    :param text: source text, every character must be in VOCAB
    :return: flat np array of length CAPTCHA_LENGTH * VOCAB_LENGTH,
             or False when text is longer than CAPTCHA_LENGTH
    """
    if len(text) > CAPTCHA_LENGTH:
        return False
    vector = np.zeros(CAPTCHA_LENGTH * VOCAB_LENGTH)

    for i, c in enumerate(text):
        # character i owns the slice [i * VOCAB_LENGTH, (i + 1) * VOCAB_LENGTH)
        index = i * VOCAB_LENGTH + VOCAB.index(c)
        vector[index] = 1
    return vector


def vec2text(vector):
    """
    vector to captcha text
    :param vector: np array (or anything np.asarray accepts) with
                   CAPTCHA_LENGTH * VOCAB_LENGTH entries
    :return: text
    """
    if not isinstance(vector, np.ndarray):
        vector = np.asarray(vector)
    # one row per captcha position, the largest entry of a row picks the character
    vector = np.reshape(vector, [CAPTCHA_LENGTH, -1])
    return ''.join(VOCAB[np.argmax(item)] for item in vector)


def get_random_text():
    """
    random captcha text
    :return: text of CAPTCHA_LENGTH characters drawn from VOCAB
    """
    return ''.join(random.choice(VOCAB) for _ in range(CAPTCHA_LENGTH))


def generate_data():
    """
    generate DATA_LENGTH captcha samples and write them to DATA_PATH/data.pkl
    the pickle file holds two objects in order: images (float32), one-hot labels (float32)
    """
    print('Generating Data...')
    data_x, data_y = [], []

    # generate data x and y
    for _ in range(DATA_LENGTH):
        text = get_random_text()
        # get captcha array
        data_x.append(generate_captcha(text))
        # get vector
        data_y.append(text2vec(text))

    # write data to pickle
    # exist_ok avoids the check-then-create race of exists() + makedirs()
    makedirs(DATA_PATH, exist_ok=True)

    x = np.asarray(data_x, np.float32)
    y = np.asarray(data_y, np.float32)
    with open(join(DATA_PATH, 'data.pkl'), 'wb') as f:
        pickle.dump(x, f)
        pickle.dump(y, f)


if __name__ == '__main__':
    # round-trip check of the encoding only; generate_data() is not called here
    vector = text2vec('1234')
    text = vec2text(vector)
    print(vector, text)

# --------------------------------------------------------------------------------
# /main.py:
# --------------------------------------------------------------------------------
import argparse
import tensorflow as tf
from config import *
from os.path import join, exists
from os import makedirs
import pickle
import math
from sklearn.model_selection import train_test_split

FLAGS = None


def standardize(x):
    """
    shift and scale x with its own overall mean and std
    :param x: object providing mean() and std()
    :return: (x - mean) / std
    """
    return (x - x.mean()) / x.std()

def load_data():
    """
    load data from pickle
    :return: standardized data x, data y
    """
    # NOTE: pickle.load can execute arbitrary code from the file,
    # only point --source_data at files you generated yourself
    with open(join(FLAGS.source_data), 'rb') as f:
        data_x = pickle.load(f)
        data_y = pickle.load(f)
    return standardize(data_x), data_y


def get_data(data_x, data_y):
    """
    split loaded data into train (60%), dev (20%) and test (20%)
    :param data_x:
    :param data_y:
    :return: Arrays
    """
    print('Data X Length', len(data_x), 'Data Y Length', len(data_y))
    print('Data X Example', data_x[0])
    print('Data Y Example', data_y[0])
    train_x, test_x, train_y, test_y = train_test_split(data_x, data_y, test_size=0.4, random_state=40)
    # the held-out 40% is halved into dev and test
    dev_x, test_x, dev_y, test_y, = train_test_split(test_x, test_y, test_size=0.5, random_state=40)

    print('Train X Shape', train_x.shape, 'Train Y Shape', train_y.shape)
    print('Dev X Shape', dev_x.shape, 'Dev Y Shape', dev_y.shape)
    print('Test X Shape', test_x.shape, 'Test Y Shape', test_y.shape)
    return train_x, train_y, dev_x, dev_y, test_x, test_y


def main():
    """
    build the CNN graph, then train it (FLAGS.train != 0)
    or restore the latest checkpoint and report test accuracy
    """
    from os.path import dirname

    data_x, data_y = load_data()
    train_x, train_y, dev_x, dev_y, test_x, test_y = get_data(data_x, data_y)
    train_steps = math.ceil(train_x.shape[0] / FLAGS.train_batch_size)
    dev_steps = math.ceil(dev_x.shape[0] / FLAGS.dev_batch_size)
    test_steps = math.ceil(test_x.shape[0] / FLAGS.test_batch_size)

    global_step = tf.Variable(-1, trainable=False, name='global_step')

    # train and dev dataset
    train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y)).shuffle(10000)
    train_dataset = train_dataset.batch(FLAGS.train_batch_size)

    dev_dataset = tf.data.Dataset.from_tensor_slices((dev_x, dev_y))
    dev_dataset = dev_dataset.batch(FLAGS.dev_batch_size)

    test_dataset = tf.data.Dataset.from_tensor_slices((test_x, test_y))
    test_dataset = test_dataset.batch(FLAGS.test_batch_size)

    # a reinitializable iterator
    iterator = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)

    train_initializer = iterator.make_initializer(train_dataset)
    dev_initializer = iterator.make_initializer(dev_dataset)
    test_initializer = iterator.make_initializer(test_dataset)

    # input Layer
    with tf.variable_scope('inputs'):
        # x.shape = [-1, 60, 160, 3]
        x, y_label = iterator.get_next()

    keep_prob = tf.placeholder(tf.float32, [])

    y = tf.cast(x, tf.float32)

    # 3 CNN layers
    for _ in range(3):
        y = tf.layers.conv2d(y, filters=32, kernel_size=3, padding='same', activation=tf.nn.relu)
        y = tf.layers.max_pooling2d(y, pool_size=2, strides=2, padding='same')

    # 2 dense layers
    y = tf.layers.flatten(y)
    y = tf.layers.dense(y, 1024, activation=tf.nn.relu)
    # tf.nn.dropout takes a keep probability; tf.layers.dropout(rate=...) expects
    # a drop rate and does nothing unless training=True is passed
    y = tf.nn.dropout(y, keep_prob)
    # one group of VOCAB_LENGTH logits per captcha position, matching the label layout
    y = tf.layers.dense(y, CAPTCHA_LENGTH * VOCAB_LENGTH)

    y_reshape = tf.reshape(y, [-1, VOCAB_LENGTH])
    y_label_reshape = tf.reshape(y_label, [-1, VOCAB_LENGTH])

    # loss
    cross_entropy = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits=y_reshape, labels=y_label_reshape))

    # accuracy (per character, not per whole captcha)
    max_index_predict = tf.argmax(y_reshape, axis=-1)
    max_index_label = tf.argmax(y_label_reshape, axis=-1)
    correct_predict = tf.equal(max_index_predict, max_index_label)
    accuracy = tf.reduce_mean(tf.cast(correct_predict, tf.float32))

    # train
    train_op = tf.train.RMSPropOptimizer(FLAGS.learning_rate).minimize(cross_entropy, global_step=global_step)

    # saver
    saver = tf.train.Saver()

    # checkpoint dir
    # FLAGS.checkpoint_dir is the checkpoint path prefix handed to saver.save,
    # so the directory to create and to restore from is its parent
    ckpt_dir = dirname(FLAGS.checkpoint_dir) or '.'
    if not exists(ckpt_dir):
        makedirs(ckpt_dir)

    # the context manager closes the session on every exit path
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # global step
        gstep = 0

        if FLAGS.train:
            for epoch in range(FLAGS.epoch_num):
                # train
                sess.run(train_initializer)
                for step in range(int(train_steps)):
                    loss, acc, gstep, _ = sess.run([cross_entropy, accuracy, global_step, train_op],
                                                   feed_dict={keep_prob: FLAGS.keep_prob})
                    # print log
                    if step % FLAGS.steps_per_print == 0:
                        print('Global Step', gstep, 'Step', step, 'Train Loss', loss, 'Accuracy', acc)

                if epoch % FLAGS.epochs_per_dev == 0:
                    # dev
                    sess.run(dev_initializer)
                    for step in range(int(dev_steps)):
                        # a batch is only consumed on the steps that print
                        if step % FLAGS.steps_per_print == 0:
                            print('Dev Accuracy', sess.run(accuracy, feed_dict={keep_prob: 1}), 'Step', step)

                # save model
                if epoch % FLAGS.epochs_per_save == 0:
                    saver.save(sess, FLAGS.checkpoint_dir, global_step=gstep)

        else:
            # load model
            ckpt = tf.train.get_checkpoint_state(ckpt_dir)
            if ckpt:
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Restore from', ckpt.model_checkpoint_path)
                sess.run(test_initializer)
                for step in range(int(test_steps)):
                    # a batch is only consumed on the steps that print
                    if step % FLAGS.steps_per_print == 0:
                        print('Test Accuracy', sess.run(accuracy, feed_dict={keep_prob: 1}), 'Step', step)
            else:
                print('No Model Found')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Captcha')
    # batch sizes are used in arithmetic, so they must be parsed as int
    parser.add_argument('--train_batch_size', help='train batch size', default=128, type=int)
    parser.add_argument('--dev_batch_size', help='dev batch size', default=256, type=int)
    parser.add_argument('--test_batch_size', help='test batch size', default=256, type=int)
    parser.add_argument('--source_data', help='source data path', default='./data/data.pkl')
    parser.add_argument('--num_layer', help='num of layer', default=2, type=int)
    parser.add_argument('--num_units', help='num of units', default=64, type=int)
    parser.add_argument('--time_step', help='time steps', default=32, type=int)
    parser.add_argument('--embedding_size', help='embedding size', default=64, type=int)
    parser.add_argument('--category_num', help='category num', default=5, type=int)
    parser.add_argument('--learning_rate', help='learning rate', default=0.001, type=float)
    parser.add_argument('--epoch_num', help='num of epoch', default=10000, type=int)
    parser.add_argument('--epochs_per_test', help='epochs per test', default=100, type=int)
    parser.add_argument('--epochs_per_dev', help='epochs per dev', default=2, type=int)
    parser.add_argument('--epochs_per_save', help='epochs per save', default=10, type=int)
    parser.add_argument('--steps_per_print', help='steps per print', default=2, type=int)
    parser.add_argument('--steps_per_summary', help='steps per summary', default=100, type=int)
    parser.add_argument('--keep_prob', help='train keep prob dropout', default=0.5, type=float)
    parser.add_argument('--checkpoint_dir', help='checkpoint dir', default='ckpt/model.ckpt', type=str)
    parser.add_argument('--summaries_dir', help='summaries dir', default='summaries/', type=str)
    parser.add_argument('--train', help='train', default=1, type=int)

    FLAGS, args = parser.parse_known_args()
    main()

# --------------------------------------------------------------------------------
# /reference.py:
# --------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
from captcha.image import ImageCaptcha
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import random

number = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']


# alphabet = ['a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z']
# ALPHABET = ['A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z']

# def random_captcha_text(char_set=number+alphabet+ALPHABET, captcha_size=4):
def random_captcha_text(char_set=None, captcha_size=4):
    """
    pick random captcha characters
    :param char_set: characters to draw from, defaults to the module-level number list
    :param captcha_size: how many characters to draw
    :return: list of single-character strings
    """
    # resolved at call time instead of binding a mutable list as the default
    if char_set is None:
        char_set = number
    return [random.choice(char_set) for _ in range(captcha_size)]


def gen_captcha_text_and_image():
    """
    generate a random captcha
    :return: captcha text, captcha image as np array
    """
    image = ImageCaptcha()

    captcha_text = random_captcha_text()
    captcha_text = ''.join(captcha_text)

    captcha = image.generate(captcha_text)
    # image.write(captcha_text, captcha_text + '.jpg')

    captcha_image = Image.open(captcha)
    captcha_image = np.array(captcha_image)
    return captcha_text, captcha_image


def convert2gray(img):
    """
    average the last axis of an image with more than 2 dimensions
    :param img: np array
    :return: np array, img itself when it has at most 2 dimensions
    """
    if len(img.shape) > 2:
        gray = np.mean(img, -1)
        # the mean above is the fast conversion, the standard one is:
        # r, g, b = img[:,:,0], img[:,:,1], img[:,:,2]
        # gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
        return gray
    else:
        return img


def text2vec(text):
    """
    text to one-hot vector
    :param text: digit string of at most MAX_CAPTCHA characters
    :return: flat np array of length MAX_CAPTCHA * CHAR_SET_LEN
    :raises ValueError: when text is longer than MAX_CAPTCHA
    """
    text_len = len(text)
    if text_len > MAX_CAPTCHA:
        # message reads: "captcha is at most 4 characters"
        raise ValueError('验证码最长4个字符')

    vector = np.zeros(MAX_CAPTCHA * CHAR_SET_LEN)
    for i, c in enumerate(text):
        # int(c) only works for the digit character set; a larger set needs
        # an explicit character -> index mapping
        idx = i * CHAR_SET_LEN + int(c)
        vector[idx] = 1
    return vector


# convert a vector back to text
def vec2text(vec):
    """
    one-hot vector to text
    :param vec: np array as produced by text2vec
    :return: digit string, one digit per non-zero entry
    """
    char_pos = vec.nonzero()[0]
    # every non-zero index is position * CHAR_SET_LEN + digit, so the digit is
    # the index modulo CHAR_SET_LEN (not the enumeration counter)
    return "".join(str(c % CHAR_SET_LEN) for c in char_pos)


# The vector (size MAX_CAPTCHA * CHAR_SET_LEN) is 0/1 encoded; with the full
# 63-character set every 63 entries encode one character, so both the order
# and the characters are kept. Example for that character set:
# vec = text2vec("F5Sd")
# text = vec2text(vec)
# print(text)  # F5Sd
# vec = text2vec("SFd5")
# text = vec2text(vec)
# print(text)  # SFd5


# generate one training batch
def get_next_batch(batch_size=128):
    """
    generate a batch of grayscale captcha images and one-hot labels
    :param batch_size: number of samples
    :return: batch_x [batch_size, IMAGE_HEIGHT * IMAGE_WIDTH], batch_y [batch_size, MAX_CAPTCHA * CHAR_SET_LEN]
    """
    batch_x = np.zeros([batch_size, IMAGE_HEIGHT * IMAGE_WIDTH])
    batch_y = np.zeros([batch_size, MAX_CAPTCHA * CHAR_SET_LEN])

    # sometimes the generated image size is not (60, 160, 3), retry until it is
    def wrap_gen_captcha_text_and_image():
        while True:
            text, image = gen_captcha_text_and_image()
            if image.shape == (60, 160, 3):
                return text, image

    for i in range(batch_size):
        text, image = wrap_gen_captcha_text_and_image()
        image = convert2gray(image)

        batch_x[i, :] = image.flatten() / 255  # (image.flatten()-128)/128 would give zero mean
        batch_y[i, :] = text2vec(text)

    return batch_x, batch_y


# define the CNN
def crack_captcha_cnn(w_alpha=0.01, b_alpha=0.1):
    """
    build the network on the module-level placeholders X and keep_prob
    :param w_alpha: scale of the random normal weight initialisation
    :param b_alpha: scale of the random normal bias initialisation
    :return: logits tensor [-1, MAX_CAPTCHA * CHAR_SET_LEN]
    """
    x = tf.reshape(X, shape=[-1, IMAGE_HEIGHT, IMAGE_WIDTH, 1])

    # w_c1_alpha = np.sqrt(2.0/(IMAGE_HEIGHT*IMAGE_WIDTH)) #
    # w_c2_alpha = np.sqrt(2.0/(3*3*32))
    # w_c3_alpha = np.sqrt(2.0/(3*3*64))
    # w_d1_alpha = np.sqrt(2.0/(8*32*64))
    # out_alpha = np.sqrt(2.0/1024)

    # 3 conv layer
    w_c1 = tf.Variable(w_alpha * tf.random_normal([3, 3, 1, 32]))
    b_c1 = tf.Variable(b_alpha * tf.random_normal([32]))
    conv1 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(x, w_c1, strides=[1, 1, 1, 1], padding='SAME'), b_c1))
    conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    conv1 = tf.nn.dropout(conv1, keep_prob)

    w_c2 = tf.Variable(w_alpha * tf.random_normal([3, 3, 32, 64]))
    b_c2 = tf.Variable(b_alpha * tf.random_normal([64]))
    conv2 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv1, w_c2, strides=[1, 1, 1, 1], padding='SAME'), b_c2))
    conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    conv2 = tf.nn.dropout(conv2, keep_prob)

    w_c3 = tf.Variable(w_alpha * tf.random_normal([3, 3, 64, 64]))
    b_c3 = tf.Variable(b_alpha * tf.random_normal([64]))
    conv3 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv2, w_c3, strides=[1, 1, 1, 1], padding='SAME'), b_c3))
    conv3 = tf.nn.max_pool(conv3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    conv3 = tf.nn.dropout(conv3, keep_prob)

    # Fully connected layer
    # 8 * 20 * 64: a 60 x 160 input halved three times with SAME padding, 64 channels
    w_d = tf.Variable(w_alpha * tf.random_normal([8 * 20 * 64, 1024]))
    b_d = tf.Variable(b_alpha * tf.random_normal([1024]))
    dense = tf.reshape(conv3, [-1, w_d.get_shape().as_list()[0]])
    dense = tf.nn.relu(tf.add(tf.matmul(dense, w_d), b_d))
    dense = tf.nn.dropout(dense, keep_prob)

    w_out = tf.Variable(w_alpha * tf.random_normal([1024, MAX_CAPTCHA * CHAR_SET_LEN]))
    b_out = tf.Variable(b_alpha * tf.random_normal([MAX_CAPTCHA * CHAR_SET_LEN]))
    out = tf.add(tf.matmul(dense, w_out), b_out)
    return out


# training
def train_crack_captcha_cnn():
    """
    train until a 100-sample batch reaches more than 50% per-character accuracy,
    then save the model to ./model/crack_capcha.model-<step>
    """
    output = crack_captcha_cnn()
    # TF 1.x only accepts named arguments here
    loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=Y, logits=output))
    optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
    predict = tf.reshape(output, [-1, MAX_CAPTCHA, CHAR_SET_LEN])
    max_idx_p = tf.argmax(predict, 2)
    max_idx_l = tf.argmax(tf.reshape(Y, [-1, MAX_CAPTCHA, CHAR_SET_LEN]), 2)
    correct_pred = tf.equal(max_idx_p, max_idx_l)
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        step = 0
        while True:
            batch_x, batch_y = get_next_batch(64)
            _, loss_ = sess.run([optimizer, loss], feed_dict={X: batch_x, Y: batch_y, keep_prob: 0.75})
            print(step, loss_)

            # compute the accuracy every 10 steps
            if step % 10 == 0:
                batch_x_test, batch_y_test = get_next_batch(100)
                acc = sess.run(accuracy, feed_dict={X: batch_x_test, Y: batch_y_test, keep_prob: 1.})
                print(step, acc)
                # if the accuracy is above 50%, save the model and finish training
                if acc > 0.50:
                    saver.save(sess, "./model/crack_capcha.model", global_step=step)
                    break

            step += 1


def crack_captcha(captcha_image):
    """
    predict the text of one captcha with the checkpoint ./model/crack_capcha.model-810
    :param captcha_image: flattened grayscale image, as fed to placeholder X
    :return: list with the predicted class index of every captcha position
    """
    output = crack_captcha_cnn()

    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, "./model/crack_capcha.model-810")

        predict = tf.argmax(tf.reshape(output, [-1, MAX_CAPTCHA, CHAR_SET_LEN]), 2)
        text_list = sess.run(predict, feed_dict={X: [captcha_image], keep_prob: 1})
        text = text_list[0].tolist()
        return text


if __name__ == '__main__':
    # NOTE: train == 0 runs training, train == 1 runs prediction
    train = 1
    if train == 0:
        number = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
        # alphabet = ['a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p','q','r','s','t','u','v','w','x','y','z']
        # ALPHABET = ['A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z']

        text, image = gen_captcha_text_and_image()
        print("验证码图像channel:", image.shape)  # (60, 160, 3)
        # image size
        IMAGE_HEIGHT = 60
        IMAGE_WIDTH = 160
        MAX_CAPTCHA = len(text)
        print("验证码文本最长字符数", MAX_CAPTCHA)
        # text to vector
        # char_set = number + alphabet + ALPHABET + ['_']  # '_' pads captchas shorter than 4 characters
        char_set = number
        CHAR_SET_LEN = len(char_set)

        X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH])
        Y = tf.placeholder(tf.float32, [None, MAX_CAPTCHA * CHAR_SET_LEN])
        keep_prob = tf.placeholder(tf.float32)  # dropout

        train_crack_captcha_cnn()
    if train == 1:
        number = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
        IMAGE_HEIGHT = 60
        IMAGE_WIDTH = 160
        char_set = number
        CHAR_SET_LEN = len(char_set)

        text, image = gen_captcha_text_and_image()

        f = plt.figure()
        ax = f.add_subplot(111)
        ax.text(0.1, 0.9, text, ha='center', va='center', transform=ax.transAxes)
        plt.imshow(image)

        plt.show()

        MAX_CAPTCHA = len(text)
        image = convert2gray(image)
        image = image.flatten() / 255

        X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH])
        Y = tf.placeholder(tf.float32, [None, MAX_CAPTCHA * CHAR_SET_LEN])
        keep_prob = tf.placeholder(tf.float32)  # dropout

        predict_text = crack_captcha(image)
        print("正确: {} 预测: {}".format(text, predict_text))

# --------------------------------------------------------------------------------