├── readme_image ├── Cage1.jpg ├── Cage2.jpg ├── iCaptcha.jpg ├── test_acc.png ├── 压力测试结果.png ├── Kaptcha_2.png ├── Kaptcha_3.png ├── Kaptcha_5.png ├── jcaptcha1.jpg ├── jcaptcha2.jpg ├── jcaptcha3.jpg ├── patchca_1.png ├── train_acc.png ├── SkewPassImage.jpg ├── bug_api启动失败.png ├── py_Captcha-1.jpg ├── SimpleCaptcha_1.jpg ├── SimpleCaptcha_2.jpg └── SimpleCaptcha_3.jpg ├── conf ├── captcha_config.json ├── sample_config.json └── sample_config.md ├── tools ├── gen_md_content.py ├── collect_labels.py └── correction_captcha.py ├── requirements.txt ├── gen_sample_by_captcha.py ├── recognize_local.py ├── .gitignore ├── webserver_captcha_image.py ├── recognize_time_test.py ├── recognize_online.py ├── cnnlib ├── recognition_object.py └── network.py ├── webserver_recognize_api.py ├── test_batch.py ├── verify_and_split_data.py ├── LICENSE ├── train_model.py └── README.md /readme_image/Cage1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nickliqian/cnn_captcha/HEAD/readme_image/Cage1.jpg -------------------------------------------------------------------------------- /readme_image/Cage2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nickliqian/cnn_captcha/HEAD/readme_image/Cage2.jpg -------------------------------------------------------------------------------- /readme_image/iCaptcha.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nickliqian/cnn_captcha/HEAD/readme_image/iCaptcha.jpg -------------------------------------------------------------------------------- /readme_image/test_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nickliqian/cnn_captcha/HEAD/readme_image/test_acc.png -------------------------------------------------------------------------------- /readme_image/压力测试结果.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nickliqian/cnn_captcha/HEAD/readme_image/压力测试结果.png -------------------------------------------------------------------------------- /readme_image/Kaptcha_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nickliqian/cnn_captcha/HEAD/readme_image/Kaptcha_2.png -------------------------------------------------------------------------------- /readme_image/Kaptcha_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nickliqian/cnn_captcha/HEAD/readme_image/Kaptcha_3.png -------------------------------------------------------------------------------- /readme_image/Kaptcha_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nickliqian/cnn_captcha/HEAD/readme_image/Kaptcha_5.png -------------------------------------------------------------------------------- /readme_image/jcaptcha1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nickliqian/cnn_captcha/HEAD/readme_image/jcaptcha1.jpg -------------------------------------------------------------------------------- /readme_image/jcaptcha2.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/nickliqian/cnn_captcha/HEAD/readme_image/jcaptcha2.jpg -------------------------------------------------------------------------------- /readme_image/jcaptcha3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nickliqian/cnn_captcha/HEAD/readme_image/jcaptcha3.jpg -------------------------------------------------------------------------------- /readme_image/patchca_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nickliqian/cnn_captcha/HEAD/readme_image/patchca_1.png -------------------------------------------------------------------------------- /readme_image/train_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nickliqian/cnn_captcha/HEAD/readme_image/train_acc.png -------------------------------------------------------------------------------- /readme_image/SkewPassImage.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nickliqian/cnn_captcha/HEAD/readme_image/SkewPassImage.jpg -------------------------------------------------------------------------------- /readme_image/bug_api启动失败.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nickliqian/cnn_captcha/HEAD/readme_image/bug_api启动失败.png -------------------------------------------------------------------------------- /readme_image/py_Captcha-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nickliqian/cnn_captcha/HEAD/readme_image/py_Captcha-1.jpg -------------------------------------------------------------------------------- /readme_image/SimpleCaptcha_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nickliqian/cnn_captcha/HEAD/readme_image/SimpleCaptcha_1.jpg -------------------------------------------------------------------------------- /readme_image/SimpleCaptcha_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nickliqian/cnn_captcha/HEAD/readme_image/SimpleCaptcha_2.jpg -------------------------------------------------------------------------------- /readme_image/SimpleCaptcha_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nickliqian/cnn_captcha/HEAD/readme_image/SimpleCaptcha_3.jpg -------------------------------------------------------------------------------- /conf/captcha_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "root_dir": "sample/origin/", 3 | "image_suffix": "png", 4 | "characters": "0123456789abcdefghijklmnopqrstuvwxyz", 5 | "count": 20000, 6 | "char_count": 4, 7 | "width": 100, 8 | "height": 60 9 | } 10 | -------------------------------------------------------------------------------- /tools/gen_md_content.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import re 3 | 4 | 5 | file_path = "../README.md" 6 | with open(file_path, "r") as f: 7 | content = f.readlines() 8 | 9 | for c in content: 10 | c = c.strip() 11 | pattern = r"^#+\s[0-9.]+\s" 12 | r = re.match(pattern, c) 13 | if r: 14 | c1 = re.sub(pattern, "", c) 15 | 
c2 = re.sub(r"#+\s", "", c) 16 | string = '{} '.format(c1, c2) 17 | print(string) 18 | 19 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.7.1 2 | astor==0.7.1 3 | bleach==1.5.0 4 | captcha==0.3 5 | certifi==2019.3.9 6 | chardet==3.0.4 7 | Click==7.0 8 | cycler==0.10.0 9 | easydict==1.8 10 | Flask==1.0.2 11 | gast==0.2.2 12 | grpcio==1.19.0 13 | html5lib==0.9999999 14 | idna==2.7 15 | itsdangerous==1.1.0 16 | Jinja2==2.10.1 17 | Markdown==3.1 18 | MarkupSafe==1.1.1 19 | matplotlib==2.1.0 20 | numpy==1.16.2 21 | olefile==0.46 22 | Pillow==4.3.0 23 | protobuf==3.6.0 24 | pyparsing==2.4.0 25 | python-dateutil==2.8.0 26 | pytz==2018.9 27 | requests==2.19.1 28 | six==1.12.0 29 | tensorboard==1.7.0 30 | tensorflow==1.7.0 31 | termcolor==1.1.0 32 | urllib3==1.23 33 | Werkzeug==0.15.2 34 | -------------------------------------------------------------------------------- /tools/collect_labels.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: UTF-8 -*- 3 | """ 4 | 统计样本的标签,并写入文件labels.json 5 | """ 6 | import os 7 | import json 8 | 9 | 10 | image_dir = "../sample/origin" 11 | image_list = os.listdir(image_dir) 12 | 13 | labels = set() 14 | for img in image_list: 15 | split_result = img.split("_") 16 | if len(split_result) == 2: 17 | label, name = split_result 18 | if label: 19 | for word in label: 20 | labels.add(word) 21 | else: 22 | pass 23 | 24 | print("共有标签{}种".format(len(labels))) 25 | 26 | with open("./labels.json", "w") as f: 27 | f.write(json.dumps("".join(list(labels)), ensure_ascii=False)) 28 | 29 | print("将标签列表写入文件labels.json成功") 30 | -------------------------------------------------------------------------------- /conf/sample_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "origin_image_dir": "sample/origin/", 3 | "new_image_dir": "sample/new_train/", 4 | "train_image_dir": "sample/train/", 5 | "test_image_dir": "sample/test/", 6 | "api_image_dir": "sample/api/", 7 | "online_image_dir": "sample/online/", 8 | "local_image_dir": "sample/local/", 9 | "model_save_dir": "model/", 10 | "image_width": 100, 11 | "image_height": 60, 12 | "max_captcha": 4, 13 | "image_suffix": "png", 14 | "char_set": "0123456789abcdefghijklmnopqrstuvwxyz", 15 | "use_labels_json_file": false, 16 | "remote_url": "http://127.0.0.1:6100/captcha/", 17 | "cycle_stop": 20000, 18 | "acc_stop": 0.99, 19 | "cycle_save": 500, 20 | "enable_gpu": 1, 21 | "train_batch_size": 128, 22 | "test_batch_size": 100 23 | } -------------------------------------------------------------------------------- /conf/sample_config.md: -------------------------------------------------------------------------------- 1 | ## 图片文件夹 2 | ``` 3 | origin_image_dir = "./sample/origin/" # 原始文件 4 | train_image_dir = "./sample/train/" # 训练集 5 | test_image_dir = "./sample/test/" # 测试集 6 | api_image_dir = "./sample/api/" # api接收的图片储存路径 7 | online_image_dir = "./sample/online/" # 从验证码url获取的图片的储存路径 8 | ``` 9 | ## 模型文件夹 10 | ``` 11 | model_save_dir = "./model/" # 训练好的模型储存路径 12 | ``` 13 | ## 图片相关参数 14 | ``` 15 | image_width = 80 # 图片宽度 16 | image_height = 40 # 图片高度 17 | max_captcha = 4 # 验证码字符个数 18 | image_suffix = "jpg" # 图片文件后缀 19 | ``` 20 | ## 是否从文件中的导入标签 21 | ``` 22 | use_labels_json_file = False 23 | ``` 24 | ## 验证码字符相关参数 25 | ``` 26 | char_set = "0123456789abcdefghijklmnopqrstuvwxyz" 27 
| char_set = "abcdefghijklmnopqrstuvwxyz" 28 | char_set = "0123456789" 29 | ``` 30 | ## 在线识别远程验证码地址 31 | ``` 32 | remote_url = "http://127.0.0.1:6100/captcha/" 33 | ``` 34 | ## 训练相关参数 35 | ``` 36 | cycle_stop = 3000 # 到指定迭代次数后停止 37 | acc_stop = 0.99 # 到指定准确率后停止 38 | cycle_save = 500 # 每训练指定轮数就保存一次(覆盖之前的模型) 39 | enable_gpu = 0 # 使用GPU还是CPU,使用GPU需要安装对应版本的tensorflow-gpu==1.7.0 40 | ``` -------------------------------------------------------------------------------- /gen_sample_by_captcha.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | 使用captcha lib生成验证码(前提:pip install captcha) 4 | """ 5 | from captcha.image import ImageCaptcha 6 | import os 7 | import random 8 | import time 9 | import json 10 | 11 | 12 | def gen_special_img(text, file_path, width, height): 13 | # 生成img文件 14 | generator = ImageCaptcha(width=width, height=height) # 指定大小 15 | img = generator.generate_image(text) # 生成图片 16 | img.save(file_path) # 保存图片 17 | 18 | 19 | def gen_ima_by_batch(root_dir, image_suffix, characters, count, char_count, width, height): 20 | # 判断文件夹是否存在 21 | if not os.path.exists(root_dir): 22 | os.makedirs(root_dir) 23 | 24 | for index, i in enumerate(range(count)): 25 | text = "" 26 | for j in range(char_count): 27 | text += random.choice(characters) 28 | 29 | timec = str(time.time()).replace(".", "") 30 | p = os.path.join(root_dir, "{}_{}.{}".format(text, timec, image_suffix)) 31 | gen_special_img(text, p, width, height) 32 | 33 | print("Generate captcha image => {}".format(index + 1)) 34 | 35 | 36 | def main(): 37 | with open("conf/captcha_config.json", "r") as f: 38 | config = json.load(f) 39 | # 配置参数 40 | root_dir = config["root_dir"] # 图片储存路径 41 | image_suffix = config["image_suffix"] # 图片储存后缀 42 | characters = config["characters"] # 图片上显示的字符集 # characters = "0123456789abcdefghijklmnopqrstuvwxyz" 43 | count = config["count"] # 生成多少张样本 44 | char_count = config["char_count"] # 图片上的字符数量 45 | 46 | # 设置图片高度和宽度 47 | width = config["width"] 48 | height = config["height"] 49 | 50 | gen_ima_by_batch(root_dir, image_suffix, characters, count, char_count, width, height) 51 | 52 | 53 | if __name__ == '__main__': 54 | main() 55 | -------------------------------------------------------------------------------- /recognize_local.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: UTF-8 -*- 3 | """ 4 | 使用自建的接口识别来自网络的验证码 5 | 需要配置参数: 6 | remote_url = "https://www.xxxxxxx.com/getImg" 验证码链接地址 7 | rec_times = 1 识别的次数 8 | """ 9 | import datetime 10 | import requests 11 | from io import BytesIO 12 | import time 13 | import json 14 | import os 15 | 16 | 17 | def recognize_captcha(test_path, save_path, image_suffix): 18 | image_file_name = 'captcha.{}'.format(image_suffix) 19 | 20 | with open(test_path, "rb") as f: 21 | content = f.read() 22 | 23 | # 识别 24 | s = time.time() 25 | url = "http://127.0.0.1:6000/b" 26 | files = {'image_file': (image_file_name, BytesIO(content), 'application')} 27 | r = requests.post(url=url, files=files) 28 | e = time.time() 29 | 30 | # 识别结果 31 | print("接口响应: {}".format(r.text)) 32 | predict_text = json.loads(r.text)["value"] 33 | now_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 34 | print("【{}】 耗时:{}ms 预测结果:{}".format(now_time, int((e-s)*1000), predict_text)) 35 | 36 | # 保存文件 37 | img_name = "{}_{}.{}".format(predict_text, str(time.time()).replace(".", ""), image_suffix) 38 | path = os.path.join(save_path, img_name) 39 | with open(path, 
"wb") as f: 40 | f.write(content) 41 | print("============== end ==============") 42 | 43 | 44 | def main(): 45 | with open("conf/sample_config.json", "r") as f: 46 | sample_conf = json.load(f) 47 | 48 | # 配置相关参数 49 | test_path = "sample/test/0401_15440848576253345.png" # 测试识别的图片路径 50 | save_path = sample_conf["local_image_dir"] # 保存的地址 51 | image_suffix = sample_conf["image_suffix"] # 文件后缀 52 | recognize_captcha(test_path, save_path, image_suffix) 53 | 54 | 55 | if __name__ == '__main__': 56 | main() 57 | 58 | 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # pycharm 107 | .idea/ 108 | 109 | # 数据文件 110 | sample/ 111 | model/ 112 | labels.json 113 | test.csv 114 | loss_test.csv 115 | loss_train.csv 116 | 117 | 118 | logs/ -------------------------------------------------------------------------------- /webserver_captcha_image.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | 验证码图片接口,访问`/captcha/1`获得图片 4 | """ 5 | from captcha.image import ImageCaptcha 6 | import os 7 | import random 8 | from flask import Flask, request, jsonify, Response, make_response 9 | import json 10 | import io 11 | 12 | 13 | # Flask对象 14 | app = Flask(__name__) 15 | basedir = os.path.abspath(os.path.dirname(__file__)) 16 | 17 | 18 | with open("conf/captcha_config.json", "r") as f: 19 | config = json.load(f) 20 | # 配置参数 21 | root_dir = config["root_dir"] # 图片储存路径 22 | image_suffix = config["image_suffix"] # 图片储存后缀 23 | characters = config["characters"] # 图片上显示的字符集 # characters = "0123456789abcdefghijklmnopqrstuvwxyz" 24 | count = config["count"] # 生成多少张样本 25 | char_count = config["char_count"] # 图片上的字符数量 26 | 27 | # 设置图片高度和宽度 28 | width = 
config["width"] 29 | height = config["height"] 30 | 31 | 32 | def response_headers(content): 33 | resp = Response(content) 34 | resp.headers['Access-Control-Allow-Origin'] = '*' 35 | return resp 36 | 37 | 38 | def gen_special_img(): 39 | # 随机文字 40 | text = "" 41 | for j in range(char_count): 42 | text += random.choice(characters) 43 | print(text) 44 | # 生成img文件 45 | generator = ImageCaptcha(width=width, height=height) # 指定大小 46 | img = generator.generate_image(text) # 生成图片 47 | imgByteArr = io.BytesIO() 48 | img.save(imgByteArr, format='PNG') 49 | imgByteArr = imgByteArr.getvalue() 50 | return imgByteArr 51 | 52 | 53 | @app.route('/captcha/', methods=['GET']) 54 | def show_photo(): 55 | if request.method == 'GET': 56 | image_data = gen_special_img() 57 | response = make_response(image_data) 58 | response.headers['Content-Type'] = 'image/png' 59 | response.headers['Access-Control-Allow-Origin'] = '*' 60 | return response 61 | else: 62 | pass 63 | 64 | 65 | if __name__ == '__main__': 66 | app.run( 67 | host='0.0.0.0', 68 | port=6100, 69 | debug=True 70 | ) 71 | -------------------------------------------------------------------------------- /recognize_time_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: UTF-8 -*- 3 | """ 4 | 使用自建的接口识别来自网络的验证码 5 | 需要配置参数: 6 | remote_url = "https://www.xxxxxxx.com/getImg" 验证码链接地址 7 | rec_times = 1 识别的次数 8 | """ 9 | import datetime 10 | import requests 11 | from io import BytesIO 12 | import time 13 | import json 14 | import os 15 | 16 | 17 | def recognize_captcha(index, test_path, save_path, image_suffix): 18 | image_file_name = 'captcha.{}'.format(image_suffix) 19 | 20 | with open(test_path, "rb") as f: 21 | content = f.read() 22 | 23 | # 识别 24 | s = time.time() 25 | url = "http://127.0.0.1:6000/b" 26 | files = {'image_file': (image_file_name, BytesIO(content), 'application')} 27 | r = requests.post(url=url, files=files) 28 | e = time.time() 29 | 30 | # 测试参数 31 | result_dict = json.loads(r.text)["value"] # 响应 32 | predict_text = result_dict["value"] # 识别结果 33 | whole_time_for_work = int((e - s) * 1000) 34 | speed_time_by_rec = result_dict["speed_time(ms)"] # 模型识别耗时 35 | request_time_by_rec = whole_time_for_work - speed_time_by_rec # 请求耗时 36 | now_time = datetime.datetime.now().strftime('%Y-%m-%d@%H:%M:%S') # 当前时间 37 | 38 | # 记录日志 39 | log = "{},{},{},{},{},{}\n"\ 40 | .format(index, predict_text, now_time, whole_time_for_work, speed_time_by_rec, request_time_by_rec) 41 | with open("./test.csv", "a+") as f: 42 | f.write(log) 43 | 44 | # 输出结果到控制台 45 | print("次数:{},结果:{},时刻:{},总耗时:{}ms,识别:{}ms,请求:{}ms" 46 | .format(index, predict_text, now_time, whole_time_for_work, speed_time_by_rec, request_time_by_rec)) 47 | 48 | # 保存文件 49 | # img_name = "{}_{}.{}".format(predict_text, str(time.time()).replace(".", ""), image_suffix) 50 | # path = os.path.join(save_path, img_name) 51 | # with open(path, "wb") as f: 52 | # f.write(content) 53 | 54 | 55 | def main(): 56 | with open("conf/sample_config.json", "r") as f: 57 | sample_conf = json.load(f) 58 | 59 | # 配置相关参数 60 | test_file = "sample/test/0001_15430304076164024.png" # 测试识别的图片路径 61 | save_path = sample_conf["local_image_dir"] # 保存的地址 62 | image_suffix = sample_conf["image_suffix"] # 文件后缀 63 | for i in range(20000): 64 | recognize_captcha(i, test_file, save_path, image_suffix) 65 | 66 | 67 | if __name__ == '__main__': 68 | main() 69 | 70 | 71 | -------------------------------------------------------------------------------- 
/recognize_online.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: UTF-8 -*- 3 | """ 4 | 使用自建的接口识别来自网络的验证码 5 | 需要配置参数: 6 | remote_url = "https://www.xxxxxxx.com/getImg" 验证码链接地址 7 | rec_times = 1 识别的次数 8 | """ 9 | import datetime 10 | import requests 11 | from io import BytesIO 12 | import time 13 | import json 14 | import os 15 | 16 | 17 | def recognize_captcha(remote_url, rec_times, save_path, image_suffix): 18 | image_file_name = 'captcha.{}'.format(image_suffix) 19 | 20 | headers = { 21 | 'user-agent': "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/65.0.3325.146 Safari/537.36", 22 | } 23 | 24 | for index in range(rec_times): 25 | # 请求 26 | while True: 27 | try: 28 | response = requests.request("GET", remote_url, headers=headers, timeout=6) 29 | if response.text: 30 | break 31 | else: 32 | print("retry, response.text is empty") 33 | except Exception as ee: 34 | print(ee) 35 | 36 | # 识别 37 | s = time.time() 38 | url = "http://127.0.0.1:6000/b" 39 | files = {'image_file': (image_file_name, BytesIO(response.content), 'application')} 40 | r = requests.post(url=url, files=files) 41 | e = time.time() 42 | 43 | # 识别结果 44 | print("接口响应: {}".format(r.text)) 45 | predict_text = json.loads(r.text)["value"] 46 | now_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 47 | print("【{}】 index:{} 耗时:{}ms 预测结果:{}".format(now_time, index, int((e-s)*1000), predict_text)) 48 | 49 | # 保存文件 50 | img_name = "{}_{}.{}".format(predict_text, str(time.time()).replace(".", ""), image_suffix) 51 | path = os.path.join(save_path, img_name) 52 | with open(path, "wb") as f: 53 | f.write(response.content) 54 | print("============== end ==============") 55 | 56 | 57 | def main(): 58 | with open("conf/sample_config.json", "r") as f: 59 | sample_conf = json.load(f) 60 | 61 | # 配置相关参数 62 | save_path = sample_conf["online_image_dir"] # 下载图片保存的地址 63 | remote_url = sample_conf["remote_url"] # 网络验证码地址 64 | image_suffix = sample_conf["image_suffix"] # 文件后缀 65 | rec_times = 1 66 | recognize_captcha(remote_url, rec_times, save_path, image_suffix) 67 | 68 | 69 | if __name__ == '__main__': 70 | main() 71 | 72 | 73 | -------------------------------------------------------------------------------- /tools/correction_captcha.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: UTF-8 -*- 3 | """ 4 | 人工在线验证脚本 5 | """ 6 | import requests 7 | from io import BytesIO 8 | import time 9 | import matplotlib.pyplot as plt 10 | import json 11 | import numpy as np 12 | from PIL import Image 13 | import os 14 | 15 | 16 | def correction(fail_path, pass_path, correction_times, remote_url): 17 | headers = { 18 | 'user-agent': "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/65.0.3325.146 Safari/537.36", 19 | } 20 | 21 | fail_count = 0 22 | for index in range(correction_times): 23 | # 请求 24 | while True: 25 | try: 26 | response = requests.request("GET", remote_url, headers=headers, timeout=10) 27 | break 28 | except Exception as e: 29 | print(e) 30 | 31 | # 识别 32 | s = time.time() 33 | url = "http://127.0.0.1:6000/b" 34 | files = {'image_file': ('captcha.jpg', BytesIO(response.content), 'application')} 35 | r = requests.post(url=url, files=files) 36 | e = time.time() 37 | print(index, int((e-s)*1000), "ms") 38 | print(r.text) 39 | time.sleep(2) 40 | 41 | # 识别结果 42 | predict_text = json.loads(r.text)["value"] 43 | f = 
plt.figure() 44 | ax = f.add_subplot(111) 45 | ax.text(0.1, 0.9, "备注", ha='center', va='center', transform=ax.transAxes) 46 | 47 | # 图片字节流转为image array 48 | img = BytesIO(response.content) 49 | img = Image.open(img, mode="r") 50 | captcha_array = np.array(img) 51 | plt.imshow(captcha_array) 52 | 53 | # 预测图片 54 | print("预测: {}\n".format(predict_text)) 55 | 56 | # 显示图片和预测结果 57 | plt.text(20, 2, 'predict:{}'.format(predict_text)) 58 | plt.show() 59 | 60 | q = input("index:<{}> 正确按enter,错误输入真实值后会保存:".format(index)) 61 | img_name = "{}_{}".format(q, str(time.time()).replace(".", "")) 62 | if q: 63 | path = os.path.join(fail_path, img_name) 64 | with open(path, "wb") as f: 65 | f.write(response.content) 66 | fail_count += 1 67 | else: 68 | path = os.path.join(pass_path, img_name) 69 | with open(path, "wb") as f: 70 | f.write(response.content) 71 | 72 | print("==============") 73 | 74 | rate = (correction_times - fail_count)/correction_times 75 | print("Pass Rate: {}".format(rate)) 76 | 77 | 78 | def main(): 79 | fail_path = "./sample/fail_sample/" 80 | pass_path = "./sample/pass_sample/" 81 | correction_times = 10 82 | remote_url = "https://www.xxxxxxx.com/getImg" 83 | 84 | correction(fail_path, pass_path, correction_times, remote_url) 85 | 86 | 87 | if __name__ == '__main__': 88 | main() 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /cnnlib/recognition_object.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | 识别图像的类,为了快速进行多次识别可以调用此类下面的方法: 4 | R = Recognizer(image_height, image_width, max_captcha) 5 | for i in range(10): 6 | r_img = Image.open(str(i) + ".jpg") 7 | t = R.rec_image(r_img) 8 | 简单的图片每张基本上可以达到毫秒级的识别速度 9 | """ 10 | import tensorflow as tf 11 | import numpy as np 12 | from PIL import Image 13 | from cnnlib.network import CNN 14 | import json 15 | 16 | 17 | class Recognizer(CNN): 18 | def __init__(self, image_height, image_width, max_captcha, char_set, model_save_dir): 19 | # 初始化变量 20 | super(Recognizer, self).__init__(image_height, image_width, max_captcha, char_set, model_save_dir) 21 | 22 | # 新建图和会话 23 | self.g = tf.Graph() 24 | self.sess = tf.Session(graph=self.g) 25 | # 使用指定的图和会话 26 | with self.g.as_default(): 27 | # 迭代循环前,写出所有用到的张量的计算表达式,如果写在循环中,会发生内存泄漏,拖慢识别的速度 28 | # tf初始化占位符 29 | self.X = tf.placeholder(tf.float32, [None, self.image_height * self.image_width]) # 特征向量 30 | self.Y = tf.placeholder(tf.float32, [None, self.max_captcha * self.char_set_len]) # 标签 31 | self.keep_prob = tf.placeholder(tf.float32) # dropout值 32 | # 加载网络和模型参数 33 | self.y_predict = self.model() 34 | self.predict = tf.argmax(tf.reshape(self.y_predict, [-1, self.max_captcha, self.char_set_len]), 2) 35 | saver = tf.train.Saver() 36 | with self.sess.as_default() as sess: 37 | saver.restore(sess, self.model_save_dir) 38 | 39 | # def __del__(self): 40 | # self.sess.close() 41 | # print("session close") 42 | 43 | def rec_image(self, img): 44 | # 读取图片 45 | img_array = np.array(img) 46 | test_image = self.convert2gray(img_array) 47 | test_image = test_image.flatten() / 255 48 | # 使用指定的图和会话 49 | with self.g.as_default(): 50 | with self.sess.as_default() as sess: 51 | text_list = sess.run(self.predict, feed_dict={self.X: [test_image], self.keep_prob: 1.}) 52 | 53 | # 获取结果 54 | predict_text = text_list[0].tolist() 55 | p_text = "" 56 | for p in predict_text: 57 | p_text += str(self.char_set[p]) 58 | 59 | # 返回识别结果 60 | return p_text 61 | 62 | 63 | def main(): 64 | with 
open("conf/sample_config.json", "r", encoding="utf-8") as f: 65 | sample_conf = json.load(f) 66 | image_height = sample_conf["image_height"] 67 | image_width = sample_conf["image_width"] 68 | max_captcha = sample_conf["max_captcha"] 69 | char_set = sample_conf["char_set"] 70 | model_save_dir = sample_conf["model_save_dir"] 71 | R = Recognizer(image_height, image_width, max_captcha, char_set, model_save_dir) 72 | r_img = Image.open("./sample/test/2b3n_6915e26c67a52bc0e4e13d216eb62b37.jpg") 73 | t = R.rec_image(r_img) 74 | print(t) 75 | 76 | 77 | if __name__ == '__main__': 78 | main() 79 | -------------------------------------------------------------------------------- /webserver_recognize_api.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | 构建flask接口服务 4 | 接收 files={'image_file': ('captcha.jpg', BytesIO(bytes), 'application')} 参数识别验证码 5 | 需要配置参数: 6 | image_height = 40 7 | image_width = 80 8 | max_captcha = 4 9 | """ 10 | import json 11 | from io import BytesIO 12 | import os 13 | from cnnlib.recognition_object import Recognizer 14 | 15 | import time 16 | from flask import Flask, request, jsonify, Response 17 | from PIL import Image 18 | 19 | # 默认使用CPU 20 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 21 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 22 | 23 | with open("conf/sample_config.json", "r") as f: 24 | sample_conf = json.load(f) 25 | # 配置参数 26 | image_height = sample_conf["image_height"] 27 | image_width = sample_conf["image_width"] 28 | max_captcha = sample_conf["max_captcha"] 29 | api_image_dir = sample_conf["api_image_dir"] 30 | model_save_dir = sample_conf["model_save_dir"] 31 | image_suffix = sample_conf["image_suffix"] # 文件后缀 32 | use_labels_json_file = sample_conf['use_labels_json_file'] 33 | 34 | if use_labels_json_file: 35 | with open("tools/labels.json", "r") as f: 36 | char_set = f.read().strip() 37 | else: 38 | char_set = sample_conf["char_set"] 39 | 40 | # Flask对象 41 | app = Flask(__name__) 42 | basedir = os.path.abspath(os.path.dirname(__file__)) 43 | 44 | # 生成识别对象,需要配置参数 45 | R = Recognizer(image_height, image_width, max_captcha, char_set, model_save_dir) 46 | 47 | # 如果你需要使用多个模型,可以参照原有的例子配置路由和编写逻辑 48 | # Q = Recognizer(image_height, image_width, max_captcha, char_set, model_save_dir) 49 | 50 | 51 | def response_headers(content): 52 | resp = Response(content) 53 | resp.headers['Access-Control-Allow-Origin'] = '*' 54 | return resp 55 | 56 | 57 | @app.route('/b', methods=['POST']) 58 | def up_image(): 59 | if request.method == 'POST' and request.files.get('image_file'): 60 | timec = str(time.time()).replace(".", "") 61 | file = request.files.get('image_file') 62 | img = file.read() 63 | img = BytesIO(img) 64 | img = Image.open(img, mode="r") 65 | # username = request.form.get("name") 66 | print("接收图片尺寸: {}".format(img.size)) 67 | s = time.time() 68 | value = R.rec_image(img) 69 | e = time.time() 70 | print("识别结果: {}".format(value)) 71 | # 保存图片 72 | print("保存图片: {}{}_{}.{}".format(api_image_dir, value, timec, image_suffix)) 73 | file_name = "{}_{}.{}".format(value, timec, image_suffix) 74 | file_path = os.path.join(api_image_dir + file_name) 75 | img.save(file_path) 76 | result = { 77 | 'time': timec, # 时间戳 78 | 'value': value, # 预测的结果 79 | 'speed_time(ms)': int((e - s) * 1000) # 识别耗费的时间 80 | } 81 | img.close() 82 | return jsonify(result) 83 | else: 84 | content = json.dumps({"error_code": "1001"}) 85 | resp = response_headers(content) 86 | return resp 87 | 88 | 89 | if __name__ == '__main__': 90 
| app.run( 91 | host='0.0.0.0', 92 | port=6000, 93 | debug=True 94 | ) 95 | -------------------------------------------------------------------------------- /test_batch.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | import time 7 | from PIL import Image 8 | import random 9 | import os 10 | from cnnlib.network import CNN 11 | 12 | 13 | class TestError(Exception): 14 | pass 15 | 16 | 17 | class TestBatch(CNN): 18 | def __init__(self, img_path, char_set, model_save_dir, total): 19 | # 模型路径 20 | self.model_save_dir = model_save_dir 21 | # 打乱文件顺序 22 | self.img_path = img_path 23 | self.img_list = os.listdir(img_path) 24 | random.seed(time.time()) 25 | random.shuffle(self.img_list) 26 | 27 | # 获得图片宽高和字符长度基本信息 28 | label, captcha_array = self.gen_captcha_text_image() 29 | 30 | captcha_shape = captcha_array.shape 31 | captcha_shape_len = len(captcha_shape) 32 | if captcha_shape_len == 3: 33 | image_height, image_width, channel = captcha_shape 34 | self.channel = channel 35 | elif captcha_shape_len == 2: 36 | image_height, image_width = captcha_shape 37 | else: 38 | raise TestError("图片转换为矩阵时出错,请检查图片格式") 39 | 40 | # 初始化变量 41 | super(TestBatch, self).__init__(image_height, image_width, len(label), char_set, model_save_dir) 42 | self.total = total 43 | 44 | # 相关信息打印 45 | print("-->图片尺寸: {} X {}".format(image_height, image_width)) 46 | print("-->验证码长度: {}".format(self.max_captcha)) 47 | print("-->验证码共{}类 {}".format(self.char_set_len, char_set)) 48 | print("-->使用测试集为 {}".format(img_path)) 49 | 50 | def gen_captcha_text_image(self): 51 | """ 52 | 返回一个验证码的array形式和对应的字符串标签 53 | :return:tuple (str, numpy.array) 54 | """ 55 | img_name = random.choice(self.img_list) 56 | # 标签 57 | label = img_name.split("_")[0] 58 | # 文件 59 | img_file = os.path.join(self.img_path, img_name) 60 | captcha_image = Image.open(img_file) 61 | captcha_array = np.array(captcha_image) # 向量化 62 | 63 | return label, captcha_array 64 | 65 | def test_batch(self): 66 | y_predict = self.model() 67 | total = self.total 68 | right = 0 69 | 70 | saver = tf.train.Saver() 71 | with tf.Session() as sess: 72 | saver.restore(sess, self.model_save_dir) 73 | s = time.time() 74 | for i in range(total): 75 | # test_text, test_image = gen_special_num_image(i) 76 | test_text, test_image = self.gen_captcha_text_image() # 随机 77 | test_image = self.convert2gray(test_image) 78 | test_image = test_image.flatten() / 255 79 | 80 | predict = tf.argmax(tf.reshape(y_predict, [-1, self.max_captcha, self.char_set_len]), 2) 81 | text_list = sess.run(predict, feed_dict={self.X: [test_image], self.keep_prob: 1.}) 82 | predict_text = text_list[0].tolist() 83 | p_text = "" 84 | for p in predict_text: 85 | p_text += str(self.char_set[p]) 86 | print("origin: {} predict: {}".format(test_text, p_text)) 87 | if test_text == p_text: 88 | right += 1 89 | else: 90 | pass 91 | e = time.time() 92 | rate = str(right/total * 100) + "%" 93 | print("测试结果: {}/{}".format(right, total)) 94 | print("{}个样本识别耗时{}秒,准确率{}".format(total, e-s, rate)) 95 | 96 | 97 | def main(): 98 | with open("conf/sample_config.json", "r") as f: 99 | sample_conf = json.load(f) 100 | 101 | test_image_dir = sample_conf["test_image_dir"] 102 | model_save_dir = sample_conf["model_save_dir"] 103 | 104 | use_labels_json_file = sample_conf['use_labels_json_file'] 105 | 106 | if use_labels_json_file: 107 | with open("tools/labels.json", "r") as f: 108 | char_set = f.read().strip() 
109 | else: 110 | char_set = sample_conf["char_set"] 111 | 112 | total = 100 113 | tb = TestBatch(test_image_dir, char_set, model_save_dir, total) 114 | tb.test_batch() 115 | 116 | 117 | if __name__ == '__main__': 118 | main() 119 | -------------------------------------------------------------------------------- /verify_and_split_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | 验证图片尺寸和分离测试集(5%)和训练集(95%) 3 | 初始化的时候使用,有新的图片后,可以把图片放在new目录里面使用。 4 | """ 5 | import json 6 | 7 | from PIL import Image 8 | import random 9 | import os 10 | import shutil 11 | 12 | 13 | def verify(origin_dir, real_width, real_height, image_suffix): 14 | """ 15 | 校验图片大小 16 | :return: 17 | """ 18 | if not os.path.exists(origin_dir): 19 | print("【警告】找不到目录{},即将创建".format(origin_dir)) 20 | os.makedirs(origin_dir) 21 | 22 | print("开始校验原始图片集") 23 | # 图片真实尺寸 24 | real_size = (real_width, real_height) 25 | # 图片名称列表和数量 26 | img_list = os.listdir(origin_dir) 27 | total_count = len(img_list) 28 | print("原始集共有图片: {}张".format(total_count)) 29 | 30 | # 无效图片列表 31 | bad_img = [] 32 | 33 | # 遍历所有图片进行验证 34 | for index, img_name in enumerate(img_list): 35 | file_path = os.path.join(origin_dir, img_name) 36 | # 过滤图片不正确的后缀 37 | if not img_name.endswith(image_suffix): 38 | bad_img.append((index, img_name, "文件后缀不正确")) 39 | continue 40 | 41 | # 过滤图片标签不标准的情况 42 | prefix, posfix = img_name.split("_") 43 | if prefix == "" or posfix == "": 44 | bad_img.append((index, img_name, "图片标签异常")) 45 | continue 46 | 47 | # 图片无法正常打开 48 | try: 49 | img = Image.open(file_path) 50 | except OSError: 51 | bad_img.append((index, img_name, "图片无法正常打开")) 52 | continue 53 | 54 | # 图片尺寸有异常 55 | if real_size == img.size: 56 | print("{} pass".format(index), end='\r') 57 | else: 58 | bad_img.append((index, img_name, "图片尺寸异常为:{}".format(img.size))) 59 | 60 | print("====以下{}张图片有异常====".format(len(bad_img))) 61 | if bad_img: 62 | for b in bad_img: 63 | print("[第{}张图片] [{}] [{}]".format(b[0], b[1], b[2])) 64 | else: 65 | print("未发现异常(共 {} 张图片)".format(len(img_list))) 66 | print("========end") 67 | return bad_img 68 | 69 | 70 | def split(origin_dir, train_dir, test_dir, bad_imgs): 71 | """ 72 | 分离训练集和测试集 73 | :return: 74 | """ 75 | if not os.path.exists(origin_dir): 76 | print("【警告】找不到目录{},即将创建".format(origin_dir)) 77 | os.makedirs(origin_dir) 78 | 79 | print("开始分离原始图片集为:测试集(5%)和训练集(95%)") 80 | 81 | # 图片名称列表和数量 82 | img_list = os.listdir(origin_dir) 83 | for img in bad_imgs: 84 | img_list.remove(img) 85 | total_count = len(img_list) 86 | print("共分配{}张图片到训练集和测试集,其中{}张为异常留在原始目录".format(total_count, len(bad_imgs))) 87 | 88 | # 创建文件夹 89 | if not os.path.exists(train_dir): 90 | os.mkdir(train_dir) 91 | 92 | if not os.path.exists(test_dir): 93 | os.mkdir(test_dir) 94 | 95 | # 测试集 96 | test_count = int(total_count*0.05) 97 | test_set = set() 98 | for i in range(test_count): 99 | while True: 100 | file_name = random.choice(img_list) 101 | if file_name in test_set: 102 | pass 103 | else: 104 | test_set.add(file_name) 105 | img_list.remove(file_name) 106 | break 107 | 108 | test_list = list(test_set) 109 | print("测试集数量为:{}".format(len(test_list))) 110 | for file_name in test_list: 111 | src = os.path.join(origin_dir, file_name) 112 | dst = os.path.join(test_dir, file_name) 113 | shutil.move(src, dst) 114 | 115 | # 训练集 116 | train_list = img_list 117 | print("训练集数量为:{}".format(len(train_list))) 118 | for file_name in train_list: 119 | src = os.path.join(origin_dir, file_name) 120 | dst = os.path.join(train_dir, file_name) 121 | 
shutil.move(src, dst) 122 | 123 | if os.listdir(origin_dir) == 0: 124 | print("migration done") 125 | 126 | 127 | def main(): 128 | with open("conf/sample_config.json", "r") as f: 129 | sample_conf = json.load(f) 130 | 131 | # 图片路径 132 | origin_dir = sample_conf["origin_image_dir"] 133 | new_dir = sample_conf["new_image_dir"] 134 | train_dir = sample_conf["train_image_dir"] 135 | test_dir = sample_conf["test_image_dir"] 136 | # 图片尺寸 137 | real_width = sample_conf["image_width"] 138 | real_height = sample_conf["image_height"] 139 | # 图片后缀 140 | image_suffix = sample_conf["image_suffix"] 141 | 142 | for image_dir in [origin_dir, new_dir]: 143 | print(">>> 开始校验目录:[{}]".format(image_dir)) 144 | bad_images_info = verify(image_dir, real_width, real_height, image_suffix) 145 | bad_imgs = [] 146 | for info in bad_images_info: 147 | bad_imgs.append(info[1]) 148 | split(image_dir, train_dir, test_dir, bad_imgs) 149 | 150 | 151 | if __name__ == '__main__': 152 | main() 153 | -------------------------------------------------------------------------------- /cnnlib/network.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import os 4 | from PIL import Image 5 | import random 6 | 7 | 8 | class CNN(object): 9 | def __init__(self, image_height, image_width, max_captcha, char_set, model_save_dir): 10 | # 初始值 11 | self.image_height = image_height 12 | self.image_width = image_width 13 | self.max_captcha = max_captcha 14 | self.char_set = char_set 15 | self.char_set_len = len(char_set) 16 | self.model_save_dir = model_save_dir # 模型路径 17 | with tf.name_scope('parameters'): 18 | self.w_alpha = 0.01 19 | self.b_alpha = 0.1 20 | # tf初始化占位符 21 | with tf.name_scope('data'): 22 | self.X = tf.placeholder(tf.float32, [None, self.image_height * self.image_width]) # 特征向量 23 | self.Y = tf.placeholder(tf.float32, [None, self.max_captcha * self.char_set_len]) # 标签 24 | self.keep_prob = tf.placeholder(tf.float32) # dropout值 25 | 26 | @staticmethod 27 | def convert2gray(img): 28 | """ 29 | 图片转为灰度图,如果是3通道图则计算,单通道图则直接返回 30 | :param img: 31 | :return: 32 | """ 33 | if len(img.shape) > 2: 34 | r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2] 35 | gray = 0.2989 * r + 0.5870 * g + 0.1140 * b 36 | return gray 37 | else: 38 | return img 39 | 40 | def text2vec(self, text): 41 | """ 42 | 转标签为oneHot编码 43 | :param text: str 44 | :return: numpy.array 45 | """ 46 | text_len = len(text) 47 | if text_len > self.max_captcha: 48 | raise ValueError('验证码最长{}个字符'.format(self.max_captcha)) 49 | 50 | vector = np.zeros(self.max_captcha * self.char_set_len) 51 | 52 | for i, ch in enumerate(text): 53 | idx = i * self.char_set_len + self.char_set.index(ch) 54 | vector[idx] = 1 55 | return vector 56 | 57 | def model(self): 58 | x = tf.reshape(self.X, shape=[-1, self.image_height, self.image_width, 1]) 59 | print(">>> input x: {}".format(x)) 60 | 61 | # 卷积层1 62 | wc1 = tf.get_variable(name='wc1', shape=[3, 3, 1, 32], dtype=tf.float32, 63 | initializer=tf.contrib.layers.xavier_initializer()) 64 | bc1 = tf.Variable(self.b_alpha * tf.random_normal([32])) 65 | conv1 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(x, wc1, strides=[1, 1, 1, 1], padding='SAME'), bc1)) 66 | conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 67 | conv1 = tf.nn.dropout(conv1, self.keep_prob) 68 | 69 | # 卷积层2 70 | wc2 = tf.get_variable(name='wc2', shape=[3, 3, 32, 64], dtype=tf.float32, 71 | initializer=tf.contrib.layers.xavier_initializer()) 72 | bc2 = 
tf.Variable(self.b_alpha * tf.random_normal([64])) 73 | conv2 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv1, wc2, strides=[1, 1, 1, 1], padding='SAME'), bc2)) 74 | conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 75 | conv2 = tf.nn.dropout(conv2, self.keep_prob) 76 | 77 | # 卷积层3 78 | wc3 = tf.get_variable(name='wc3', shape=[3, 3, 64, 128], dtype=tf.float32, 79 | initializer=tf.contrib.layers.xavier_initializer()) 80 | bc3 = tf.Variable(self.b_alpha * tf.random_normal([128])) 81 | conv3 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv2, wc3, strides=[1, 1, 1, 1], padding='SAME'), bc3)) 82 | conv3 = tf.nn.max_pool(conv3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 83 | conv3 = tf.nn.dropout(conv3, self.keep_prob) 84 | print(">>> convolution 3: ", conv3.shape) 85 | next_shape = conv3.shape[1] * conv3.shape[2] * conv3.shape[3] 86 | 87 | # 全连接层1 88 | wd1 = tf.get_variable(name='wd1', shape=[next_shape, 1024], dtype=tf.float32, 89 | initializer=tf.contrib.layers.xavier_initializer()) 90 | bd1 = tf.Variable(self.b_alpha * tf.random_normal([1024])) 91 | dense = tf.reshape(conv3, [-1, wd1.get_shape().as_list()[0]]) 92 | dense = tf.nn.relu(tf.add(tf.matmul(dense, wd1), bd1)) 93 | dense = tf.nn.dropout(dense, self.keep_prob) 94 | 95 | # 全连接层2 96 | wout = tf.get_variable('name', shape=[1024, self.max_captcha * self.char_set_len], dtype=tf.float32, 97 | initializer=tf.contrib.layers.xavier_initializer()) 98 | bout = tf.Variable(self.b_alpha * tf.random_normal([self.max_captcha * self.char_set_len])) 99 | 100 | with tf.name_scope('y_prediction'): 101 | y_predict = tf.add(tf.matmul(dense, wout), bout) 102 | 103 | return y_predict 104 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /train_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import time 8 | from PIL import Image 9 | import random 10 | import os 11 | from cnnlib.network import CNN 12 | 13 | 14 | class TrainError(Exception): 15 | pass 16 | 17 | 18 | class TrainModel(CNN): 19 | def __init__(self, train_img_path, verify_img_path, char_set, model_save_dir, cycle_stop, acc_stop, cycle_save, 20 | image_suffix, train_batch_size, test_batch_size, verify=False): 21 | # 训练相关参数 22 | self.cycle_stop = cycle_stop 23 | self.acc_stop = acc_stop 24 | self.cycle_save = cycle_save 25 | self.train_batch_size = train_batch_size 26 | self.test_batch_size = test_batch_size 27 | 28 | self.image_suffix = image_suffix 29 | char_set = [str(i) for i in char_set] 30 | 31 | # 打乱文件顺序+校验图片格式 32 | self.train_img_path = train_img_path 33 | self.train_images_list = os.listdir(train_img_path) 34 | # 校验格式 35 | if verify: 36 | self.confirm_image_suffix() 37 | # 打乱文件顺序 38 | random.seed(time.time()) 39 | random.shuffle(self.train_images_list) 40 | 41 | # 验证集文件 42 | self.verify_img_path = verify_img_path 43 | self.verify_images_list = os.listdir(verify_img_path) 44 | 45 | # 获得图片宽高和字符长度基本信息 46 | label, captcha_array = self.gen_captcha_text_image(train_img_path, self.train_images_list[0]) 47 | 48 | captcha_shape = captcha_array.shape 49 | captcha_shape_len = len(captcha_shape) 50 | if captcha_shape_len == 3: 51 | image_height, image_width, channel = captcha_shape 52 | self.channel = channel 53 | elif captcha_shape_len == 2: 54 | image_height, image_width = captcha_shape 55 | else: 56 | raise TrainError("图片转换为矩阵时出错,请检查图片格式") 57 | 58 | # 初始化变量 59 | super(TrainModel, self).__init__(image_height, image_width, len(label), char_set, model_save_dir) 60 | 61 | # 相关信息打印 62 | print("-->图片尺寸: {} X {}".format(image_height, image_width)) 63 | print("-->验证码长度: {}".format(self.max_captcha)) 64 | print("-->验证码共{}类 {}".format(self.char_set_len, char_set)) 65 | print("-->使用测试集为 {}".format(train_img_path)) 66 | print("-->使验证集为 {}".format(verify_img_path)) 67 | 68 | # test model input and output 69 | print(">>> Start model test") 70 | batch_x, batch_y = self.get_batch(0, size=100) 71 | print(">>> input batch images shape: {}".format(batch_x.shape)) 72 | print(">>> input batch labels shape: {}".format(batch_y.shape)) 73 | 74 | @staticmethod 75 | def gen_captcha_text_image(img_path, img_name): 76 | """ 77 | 返回一个验证码的array形式和对应的字符串标签 78 | :return:tuple (str, numpy.array) 79 | """ 80 | # 标签 81 | label = img_name.split("_")[0] 82 | # 文件 83 | img_file = os.path.join(img_path, img_name) 84 | captcha_image = Image.open(img_file) 85 | captcha_array = np.array(captcha_image) # 向量化 86 | return label, captcha_array 87 | 88 | def get_batch(self, n, size=128): 89 | batch_x = np.zeros([size, self.image_height * self.image_width]) # 初始化 90 | batch_y = np.zeros([size, self.max_captcha * self.char_set_len]) # 初始化 91 | 92 | max_batch = int(len(self.train_images_list) / size) 93 | # print(max_batch) 94 | if max_batch - 1 < 0: 95 | raise TrainError("训练集图片数量需要大于每批次训练的图片数量") 96 | if n > max_batch - 1: 97 | n = n % max_batch 98 | s = n * size 99 | e = (n + 1) * size 100 | this_batch = self.train_images_list[s:e] 101 | # print("{}:{}".format(s, e)) 102 | 103 | for i, img_name in enumerate(this_batch): 104 | label, image_array = 
self.gen_captcha_text_image(self.train_img_path, img_name) 105 | image_array = self.convert2gray(image_array) # 灰度化图片 106 | batch_x[i, :] = image_array.flatten() / 255 # flatten 转为一维 107 | batch_y[i, :] = self.text2vec(label) # 生成 oneHot 108 | return batch_x, batch_y 109 | 110 | def get_verify_batch(self, size=100): 111 | batch_x = np.zeros([size, self.image_height * self.image_width]) # 初始化 112 | batch_y = np.zeros([size, self.max_captcha * self.char_set_len]) # 初始化 113 | 114 | verify_images = [] 115 | for i in range(size): 116 | verify_images.append(random.choice(self.verify_images_list)) 117 | 118 | for i, img_name in enumerate(verify_images): 119 | label, image_array = self.gen_captcha_text_image(self.verify_img_path, img_name) 120 | image_array = self.convert2gray(image_array) # 灰度化图片 121 | batch_x[i, :] = image_array.flatten() / 255 # flatten 转为一维 122 | batch_y[i, :] = self.text2vec(label) # 生成 oneHot 123 | return batch_x, batch_y 124 | 125 | def confirm_image_suffix(self): 126 | # 在训练前校验所有文件格式 127 | print("开始校验所有图片后缀") 128 | for index, img_name in enumerate(self.train_images_list): 129 | print("{} image pass".format(index), end='\r') 130 | if not img_name.endswith(self.image_suffix): 131 | raise TrainError('confirm images suffix:you request [.{}] file but get file [{}]' 132 | .format(self.image_suffix, img_name)) 133 | print("所有图片格式校验通过") 134 | 135 | def train_cnn(self): 136 | y_predict = self.model() 137 | print(">>> input batch predict shape: {}".format(y_predict.shape)) 138 | print(">>> End model test") 139 | # 计算概率 损失 140 | with tf.name_scope('cost'): 141 | cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=y_predict, labels=self.Y)) 142 | # 梯度下降 143 | with tf.name_scope('train'): 144 | optimizer = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cost) 145 | # 计算准确率 146 | predict = tf.reshape(y_predict, [-1, self.max_captcha, self.char_set_len]) # 预测结果 147 | max_idx_p = tf.argmax(predict, 2) # 预测结果 148 | max_idx_l = tf.argmax(tf.reshape(self.Y, [-1, self.max_captcha, self.char_set_len]), 2) # 标签 149 | # 计算准确率 150 | correct_pred = tf.equal(max_idx_p, max_idx_l) 151 | with tf.name_scope('char_acc'): 152 | accuracy_char_count = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) 153 | with tf.name_scope('image_acc'): 154 | accuracy_image_count = tf.reduce_mean(tf.reduce_min(tf.cast(correct_pred, tf.float32), axis=1)) 155 | # 模型保存对象 156 | saver = tf.train.Saver() 157 | with tf.Session() as sess: 158 | init = tf.global_variables_initializer() 159 | sess.run(init) 160 | # 恢复模型 161 | if os.path.exists(self.model_save_dir): 162 | try: 163 | saver.restore(sess, self.model_save_dir) 164 | # 判断捕获model文件夹中没有模型文件的错误 165 | except ValueError: 166 | print("model文件夹为空,将创建新模型") 167 | else: 168 | pass 169 | # 写入日志 170 | tf.summary.FileWriter("logs/", sess.graph) 171 | 172 | step = 1 173 | for i in range(self.cycle_stop): 174 | batch_x, batch_y = self.get_batch(i, size=self.train_batch_size) 175 | # 梯度下降训练 176 | _, cost_ = sess.run([optimizer, cost], 177 | feed_dict={self.X: batch_x, self.Y: batch_y, self.keep_prob: 0.75}) 178 | if step % 10 == 0: 179 | # 基于训练集的测试 180 | batch_x_test, batch_y_test = self.get_batch(i, size=self.train_batch_size) 181 | acc_char = sess.run(accuracy_char_count, feed_dict={self.X: batch_x_test, self.Y: batch_y_test, self.keep_prob: 1.}) 182 | acc_image = sess.run(accuracy_image_count, feed_dict={self.X: batch_x_test, self.Y: batch_y_test, self.keep_prob: 1.}) 183 | print("第{}次训练 >>> ".format(step)) 184 | print("[训练集] 字符准确率为 {:.5f} 图片准确率为 {:.5f} >>> 
loss {:.10f}".format(acc_char, acc_image, cost_)) 185 | 186 | # with open("loss_train.csv", "a+") as f: 187 | # f.write("{},{},{},{}\n".format(step, acc_char, acc_image, cost_)) 188 | 189 | # 基于验证集的测试 190 | batch_x_verify, batch_y_verify = self.get_verify_batch(size=self.test_batch_size) 191 | acc_char = sess.run(accuracy_char_count, feed_dict={self.X: batch_x_verify, self.Y: batch_y_verify, self.keep_prob: 1.}) 192 | acc_image = sess.run(accuracy_image_count, feed_dict={self.X: batch_x_verify, self.Y: batch_y_verify, self.keep_prob: 1.}) 193 | print("[验证集] 字符准确率为 {:.5f} 图片准确率为 {:.5f} >>> loss {:.10f}".format(acc_char, acc_image, cost_)) 194 | 195 | # with open("loss_test.csv", "a+") as f: 196 | # f.write("{}, {},{},{}\n".format(step, acc_char, acc_image, cost_)) 197 | 198 | # 准确率达到99%后保存并停止 199 | if acc_image > self.acc_stop: 200 | saver.save(sess, self.model_save_dir) 201 | print("验证集准确率达到99%,保存模型成功") 202 | break 203 | # 每训练500轮就保存一次 204 | if i % self.cycle_save == 0: 205 | saver.save(sess, self.model_save_dir) 206 | print("定时保存模型成功") 207 | step += 1 208 | saver.save(sess, self.model_save_dir) 209 | 210 | def recognize_captcha(self): 211 | label, captcha_array = self.gen_captcha_text_image(self.train_img_path, random.choice(self.train_images_list)) 212 | 213 | f = plt.figure() 214 | ax = f.add_subplot(111) 215 | ax.text(0.1, 0.9, "origin:" + label, ha='center', va='center', transform=ax.transAxes) 216 | plt.imshow(captcha_array) 217 | # 预测图片 218 | image = self.convert2gray(captcha_array) 219 | image = image.flatten() / 255 220 | 221 | y_predict = self.model() 222 | 223 | saver = tf.train.Saver() 224 | with tf.Session() as sess: 225 | saver.restore(sess, self.model_save_dir) 226 | predict = tf.argmax(tf.reshape(y_predict, [-1, self.max_captcha, self.char_set_len]), 2) 227 | text_list = sess.run(predict, feed_dict={self.X: [image], self.keep_prob: 1.}) 228 | predict_text = text_list[0].tolist() 229 | 230 | print("正确: {} 预测: {}".format(label, predict_text)) 231 | # 显示图片和预测结果 232 | p_text = "" 233 | for p in predict_text: 234 | p_text += str(self.char_set[p]) 235 | print(p_text) 236 | plt.text(20, 1, 'predict:{}'.format(p_text)) 237 | plt.show() 238 | 239 | 240 | def main(): 241 | with open("conf/sample_config.json", "r") as f: 242 | sample_conf = json.load(f) 243 | 244 | train_image_dir = sample_conf["train_image_dir"] 245 | verify_image_dir = sample_conf["test_image_dir"] 246 | model_save_dir = sample_conf["model_save_dir"] 247 | cycle_stop = sample_conf["cycle_stop"] 248 | acc_stop = sample_conf["acc_stop"] 249 | cycle_save = sample_conf["cycle_save"] 250 | enable_gpu = sample_conf["enable_gpu"] 251 | image_suffix = sample_conf['image_suffix'] 252 | use_labels_json_file = sample_conf['use_labels_json_file'] 253 | train_batch_size = sample_conf['train_batch_size'] 254 | test_batch_size = sample_conf['test_batch_size'] 255 | 256 | if use_labels_json_file: 257 | with open("tools/labels.json", "r") as f: 258 | char_set = f.read().strip() 259 | else: 260 | char_set = sample_conf["char_set"] 261 | 262 | if not enable_gpu: 263 | # 设置以下环境变量可开启CPU识别 264 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 265 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 266 | 267 | tm = TrainModel(train_image_dir, verify_image_dir, char_set, model_save_dir, cycle_stop, acc_stop, cycle_save, 268 | image_suffix, train_batch_size, test_batch_size, verify=False) 269 | tm.train_cnn() # 开始训练模型 270 | # tm.recognize_captcha() # 识别图片示例 271 | 272 | 273 | if __name__ == '__main__': 274 | main() 275 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # cnn_captcha 2 | use CNN recognize captcha by tensorflow. 3 | 本项目针对字符型图片验证码,使用tensorflow实现卷积神经网络,进行验证码识别。 4 | 项目封装了比较通用的**校验、训练、验证、识别、API模块**,极大的减少了识别字符型验证码花费的时间和精力。 5 | 6 | 项目已经帮助很多同学高效完成了验证码识别任务。 7 | 如果你在使用过程中出现了bug和做了良好的改进,欢迎提出issue和PR,作者会尽快回复,希望能和你共同完善项目。 8 | 9 | 如果你需要识别点选、拖拽类验证码,或者有目标检测需求,也可以参考这个项目[nickliqian/darknet_captcha](https://github.com/nickliqian/darknet_captcha)。 10 | 11 | # 时间表 12 | #### 2018.11.12 13 | 初版Readme.md 14 | #### 2018.11.21 15 | 加入关于验证码识别的一些说明 16 | #### 2018.11.24 17 | 优化校验数据集图片的规则 18 | #### 2018.11.26 19 | 新增`train_model_v2.py`文件,训练过程中同时输出训练集和验证集的准确率 20 | #### 2018.12.06 21 | 新增多模型部署支持,修复若干bug 22 | #### 2018.12.08 23 | 优化模型识别速度,支持api压力测试和统计耗时 24 | #### 2019.02.19 25 | 1. 新增一种准确率计算方式 26 | 2. TAG: v1.0 27 | #### 2019.04.12 28 | 1. 只保留一种`train_model.py`文件 29 | 2. 优化代码结构 30 | 3. 把通用配置抽取到`sample_config.json`和`captcha_config.json` 31 | 4. 修复若干大家在issue提出的问题 32 | #### 2019.06.01 33 | 1. 完善readme文档,文档不长,请大家一定要读完~ 34 | 2. 使用cnnlib目录存放神经网络结构代码 35 | 3. 做了一版训练数据统计,大家可以参考我们的训练次数、时长和准确率 36 | 4. TAG: v2.0 37 | 38 | # 目录 39 | 1 项目介绍 40 | - 1.1 关于验证码识别 41 | - 1.2 目录结构 42 | - 1.3 依赖 43 | - 1.4 模型结构 44 | 45 | 2 如何使用 46 | - 2.1 数据集 47 | - 2.2 配置文件 48 | - 2.3 验证和拆分数据集 49 | - 2.4 训练模型 50 | - 2.5 批量验证 51 | - 2.6 启动WebServer 52 | - 2.7 调用接口识别 53 | - 2.8 部署 54 | - 2.9 部署多个模型 55 | - 2.10 在线识别 56 | 57 | 3 统计数据 58 | - 3.1 训练数据统计 59 | - 3.2 压力测试 60 | 61 | 4 开发说明 62 | 63 | 5 已知BUG 64 | 65 | 66 | 67 | # 1 项目介绍 68 | ## 1.1 关于验证码识别 69 | 验证码识别大多是爬虫会遇到的问题,也可以作为图像识别的入门案例。目前通常使用如下几种方法: 70 | 71 | | 方法名称 | 相关要点 | 72 | | ------ | ------ | 73 | | tesseract | 仅适合识别没有干扰和扭曲的图片,训练起来很麻烦 | 74 | | 其他开源识别库 | 不够通用,识别率未知 | 75 | | 付费OCR API | 需求量大的情形成本很高 | 76 | | 图像处理+机器学习分类算法 | 涉及多种技术,学习成本高,且不通用 | 77 | | 卷积神经网络 | 一定的学习成本,算法适用于多类验证码 | 78 | 79 | 这里说一下使用传统的**图像处理和机器学习算法**,涉及多种技术: 80 | 81 | 1. 图像处理 82 | - 前处理(灰度化、二值化) 83 | - 图像分割 84 | - 裁剪(去边框) 85 | - 图像滤波、降噪 86 | - 去背景 87 | - 颜色分离 88 | - 旋转 89 | 2. 
机器学习 90 | - KNN 91 | - SVM 92 | 93 | 使用这类方法对使用者的要求较高,且由于图片的变化类型较多,处理的方法不够通用,经常花费很多时间去调整处理步骤和相关算法。 94 | 而使用**卷积神经网络**,只需要通过简单的前处理,就可以实现大部分静态字符型验证码的端到端识别,效果很好,通用性很高。 95 | 96 | 这里列出目前**常用的验证码**生成库: 97 | >参考:[Java验证全家桶](https://www.cnblogs.com/cynchanpin/p/6912301.html) 98 | 99 | | 语言 | 验证码库名称 | 链接 | 样例 | 100 | | ------ | ------ | ------ | ------ | 101 | | Java | JCaptcha | [示例](https://jcaptcha.atlassian.net/wiki/spaces/general/pages/1212427/Samples+tests) | ![效果1](./readme_image/jcaptcha1.jpg) ![效果2](./readme_image/jcaptcha2.jpg) ![效果3](./readme_image/jcaptcha3.jpg) | 102 | | Java | JCaptcha4Struts2 | | | 103 | | Java | SimpleCaptcha | [例子](https://www.oschina.net/p/simplecaptcha) | ![效果1](./readme_image/SimpleCaptcha_1.jpg) ![效果2](./readme_image/SimpleCaptcha_2.jpg) ![效果3](./readme_image/SimpleCaptcha_3.jpg) | 104 | | Java | kaptcha | [例子](https://github.com/linghushaoxia/kaptcha) | ![水纹效果](./readme_image/Kaptcha_5.png) ![鱼眼效果](./readme_image/Kaptcha_2.png) ![阴影效果](./readme_image/Kaptcha_3.png) | 105 | | Java | patchca | | ![效果1](./readme_image/patchca_1.png) | 106 | | Java | imageRandom | | | 107 | | Java | iCaptcha | | ![效果1](./readme_image/iCaptcha.jpg) | 108 | | Java | SkewPassImage | | ![效果1](./readme_image/SkewPassImage.jpg) | 109 | | Java | Cage | | ![效果1](./readme_image/Cage1.jpg) ![效果2](./readme_image/Cage2.jpg) | 110 | | Python | captcha | [例子](https://github.com/nickliqian/cnn_captcha/blob/master/gen_image/gen_sample_by_captcha.py) | ![py_Captcha](./readme_image/py_Captcha-1.jpg) | 111 | | Python | pycapt | [例子](https://github.com/aboutmydreams/pycapt) | ![pycapt](https://github.com/aboutmydreams/pycapt/raw/master/img/do4.png) | 112 | | PHP | Gregwar/Captcha | [文档](https://github.com/Gregwar/Captcha) | | 113 | | PHP | mewebstudio/captcha | [文档](https://github.com/mewebstudio/captcha) | | 114 | 115 | ## 1.2 目录结构 116 | ### 1.2.1 基本配置 117 | | 序号 | 文件名称 | 说明 | 118 | | ------ | ------ | ------ | 119 | | 1 | `conf/` | 配置文件目录 | 120 | | 2 | `sample/` | 数据集目录 | 121 | | 3 | `model/` | 模型文件目录 | 122 | | 4 | `cnnlib/` | 封装CNN的相关代码目录 | 123 | ### 1.2.2 训练模型 124 | | 序号 | 文件名称 | 说明 | 125 | | ------ | ------ | ------ | 126 | | 1 | verify_and_split_data.py | 验证数据集、拆分数据为训练集和测试集 | 127 | | 2 | network.py | cnn网络基类 | 128 | | 3 | train_model.py | 训练模型 | 129 | | 4 | test_batch.py | 批量验证 | 130 | | 5 | gen_image/gen_sample_by_captcha.py | 生成验证码的脚本 | 131 | | 6 | gen_image/collect_labels.py | 用于统计验证码标签(常用于中文验证码) | 132 | 133 | ### 1.2.3 web接口 134 | | 序号 | 文件名称 | 说明 | 135 | | ------ | ------ | ------ | 136 | | 1 | webserver_captcha_image.py | 获取验证码接口 | 137 | | 2 | webserver_recognize_api.py | 提供在线识别验证码接口 | 138 | | 3 | recognize_online.py | 使用接口识别的例子 | 139 | | 4 | recognize_local.py | 测试本地图片的例子 | 140 | | 5 | recognize_time_test.py | 压力测试识别耗时和请求响应耗时 | 141 | 142 | ## 1.3 依赖 143 | ``` 144 | pip install -r requirements.txt 145 | ``` 146 | 注意:如果需要使用GPU进行训练,请把文件中的tenforflow修改为tensorflow-gpu 147 | 148 | ## 1.4 模型结构 149 | 150 | | 序号 | 层级 | 151 | | :------: | :------: | 152 | | 输入 | input | 153 | | 1 | 卷积层 + 池化层 + 降采样层 + ReLU | 154 | | 2 | 卷积层 + 池化层 + 降采样层 + ReLU | 155 | | 3 | 卷积层 + 池化层 + 降采样层 + ReLU | 156 | | 4 | 全连接 + 降采样层 + Relu | 157 | | 5 | 全连接 + softmax | 158 | | 输出 | output | 159 | 160 | # 2 如何使用 161 | ## 2.1 数据集 162 | 原始数据集可以存放在`./sample/origin`目录中。 163 | 为了便于处理,图片最好以`2e8j_17322d3d4226f0b5c5a71d797d2ba7f7.jpg`格式命名(标签_序列号.后缀)。 164 | 165 | 如果你没有训练集,你可以使用`gen_sample_by_captcha.py`文件生成训练集文件。 166 | 生成之前你需要修改相关配置`conf/captcha_config.json`(路径、文件后缀、字符集等)。 167 | ``` 168 | { 169 | "root_dir": "sample/origin/", # 验证码保存路径 
170 | "image_suffix": "png", # 验证码图片后缀 171 | "characters": "0123456789", # 生成验证码的可选字符 172 | "count": 1000, # 生成验证码的图片数量 173 | "char_count": 4, # 每张验证码图片上的字符数量 174 | "width": 100, # 图片宽度 175 | "height": 60 # 图片高度 176 | } 177 | ``` 178 | 179 | ## 2.2 配置文件 180 | 创建一个新项目前,需要自行**修改相关配置文件**`conf/sample_config.json`。 181 | ``` 182 | { 183 | "origin_image_dir": "sample/origin/", # 原始文件 184 | "new_image_dir": "sample/new_train/", # 新的训练样本 185 | "train_image_dir": "sample/train/", # 训练集 186 | "test_image_dir": "sample/test/", # 测试集 187 | "api_image_dir": "sample/api/", # api接收的图片储存路径 188 | "online_image_dir": "sample/online/", # 从验证码url获取的图片的储存路径 189 | "local_image_dir": "sample/local/", # 本地保存图片的路径 190 | "model_save_dir": "model/", # 从验证码url获取的图片的储存路径 191 | "image_width": 100, # 图片宽度 192 | "image_height": 60, # 图片高度 193 | "max_captcha": 4, # 验证码字符个数 194 | "image_suffix": "png", # 图片文件后缀 195 | "char_set": "0123456789abcdefghijklmnopqrstuvwxyz", # 验证码识别结果类别 196 | "use_labels_json_file": false, # 是否开启读取`labels.json`内容 197 | "remote_url": "http://127.0.0.1:6100/captcha/", # 验证码远程获取地址 198 | "cycle_stop": 3000, # 启动任务后的训练指定次数后停止 199 | "acc_stop": 0.99, # 训练到指定准确率后停止 200 | "cycle_save": 500, # 训练指定次数后定时保存模型 201 | "enable_gpu": 0, # 是否开启GUP训练 202 | "train_batch_size": 128, # 训练时每次使用的图片张数,如果CPU或者GPU内存太小可以减少这个参数 203 | "test_batch_size": 100 # 每批次测试时验证的图片张数,不要超过验证码集的总数 204 | } 205 | 206 | ``` 207 | 关于`验证码识别结果类别`,假设你的样本是中文验证码,你可以使用`tools/collect_labels.py`脚本进行标签的统计。 208 | 会生成文件`gen_image/labels.json`存放所有标签,在配置文件中设置`use_labels_json_file = True`开启读取`labels.json`内容作为`结果类别`。 209 | 210 | ## 2.3 验证和拆分数据集 211 | 此功能会校验原始图片集的尺寸和测试图片是否能打开,并按照19:1的比例拆分出训练集和测试集。 212 | 所以需要分别创建和指定三个文件夹:origin,train,test用于存放相关文件。 213 | 214 | 也可以修改为不同的目录,但是最好修改为绝对路径。 215 | 文件夹创建好之后,执行以下命令即可: 216 | ``` 217 | python3 verify_and_split_data.py 218 | ``` 219 | 一般会有类似下面的提示 220 | ``` 221 | >>> 开始校验目录:[sample/origin/] 222 | 开始校验原始图片集 223 | 原始集共有图片: 1001张 224 | ====以下1张图片有异常==== 225 | [第0张图片] [.DStore] [文件后缀不正确] 226 | ========end 227 | 开始分离原始图片集为:测试集(5%)和训练集(95%) 228 | 共分配1000张图片到训练集和测试集,其中1张为异常留在原始目录 229 | 测试集数量为:50 230 | 训练集数量为:950 231 | >>> 开始校验目录:[sample/new_train/] 232 | 【警告】找不到目录sample/new_train/,即将创建 233 | 开始校验原始图片集 234 | 原始集共有图片: 0张 235 | ====以下0张图片有异常==== 236 | 未发现异常(共 0 张图片) 237 | ========end 238 | 开始分离原始图片集为:测试集(5%)和训练集(95%) 239 | 共分配0张图片到训练集和测试集,其中0张为异常留在原始目录 240 | 测试集数量为:0 241 | 训练集数量为:0 242 | ``` 243 | 程序会同时校验和分割`origin_image_dir`和`new_image_dir`两个目录中的图片;后续有了更多的样本,可以把样本放在`new_image_dir`目录中再次执行`verify_and_split_data`。 244 | 程序会把无效的文件留在原文件夹。 245 | 246 | 此外,当你有新的样本需要一起训练,可以放在`sample/new`目录下,再次运行`python3 verify_and_split_data.py`即可。 247 | 需要注意的是,如果新的样本中有新增的标签,你需要把新的标签增加到`char_set`配置中或者`labels.json`文件中。 248 | 249 | ## 2.4 训练模型 250 | 创建好训练集和测试集之后,就可以开始训练模型了。 251 | 训练的过程中会输出日志,日志展示当前的训练轮数、准确率和loss。 252 | **此时的准确率是训练集图片的准确率,代表训练集的图片识别情况** 253 | 例如: 254 | ``` 255 | 第10次训练 >>> 256 | [训练集] 字符准确率为 0.03000 图片准确率为 0.00000 >>> loss 0.1698757857 257 | [验证集] 字符准确率为 0.04000 图片准确率为 0.00000 >>> loss 0.1698757857 258 | ``` 259 | 字符准确率和图片准确率的解释: 260 | ``` 261 | 假设:有100张图片,每张图片四个字符,共400个字符。我们这里把任务拆分为为需要识别400个字符 262 | 字符准确率:识别400的字符中,正确字符的占比。 263 | 图片准确率:100张图片中,4个字符完全识别准确的图片占比。 264 | ``` 265 | 这里不具体介绍tensorflow安装相关问题,直奔主题。 266 | 确保图片相关参数和目录设置正确后,执行以下命令开始训练: 267 | ``` 268 | python3 train_model.py 269 | ``` 270 | 也可以根据`train_model.py`的`main`函数中的代码调用类开始训练或执行一次简单的识别演示。 271 | 272 | 由于训练集中常常不包含所有的样本特征,所以会出现训练集准确率是100%而测试集准确率不足100%的情况,此时提升准确率的一个解决方案是增加正确标记后的负样本。 273 | 274 | ## 2.5 批量验证 275 | 使用测试集的图片进行验证,输出准确率。 276 | ``` 277 | python3 test_batch.py 278 | ``` 279 | 
同样可以根据`main`函数中的代码调用类开始验证。 280 | 281 | ## 2.6 启动WebServer 282 | 项目已经封装好加载模型和识别图片的类,启动`web server`后调用接口就可以使用识别服务。 283 | 启动`web server` 284 | ``` 285 | python3 webserver_recognize_api.py 286 | ``` 287 | 接口url为`http://127.0.0.1:6000/b` 288 | 289 | ## 2.7 调用接口识别 290 | 使用requests调用接口: 291 | ``` 292 | url = "http://127.0.0.1:6000/b" 293 | files = {'image_file': (image_file_name, open('captcha.jpg', 'rb'), 'application')} 294 | r = requests.post(url=url, files=files) 295 | ``` 296 | 返回的结果是一个json: 297 | ``` 298 | { 299 | 'time': '1542017705.9152594', 300 | 'value': 'jsp1', 301 | } 302 | ``` 303 | 文件`recognize_local.py`是使用接口识别本地的例子,这个例子运行成功,那么识别验证码的一套流程基本上是走了一遍了。 304 | 在线识别验证码是显示中常用场景,文件`recognize_online.py`是使用接口在线识别的例子,参见:`## 2.11 在线识别`。 305 | 306 | ## 2.8 部署 307 | 部署的时候,把`webserver_recognize_api.py`文件的最后一行修改为如下内容: 308 | ``` 309 | app.run(host='0.0.0.0',port=5000,debug=False) 310 | ``` 311 | 然后开启端口访问权限,就可以通过外网访问了。 312 | 另外为了开启多进程处理请求,可以使用uwsgi+nginx组合进行部署。 313 | 这部分可以参考:[Flask部署选择](http://docs.jinkan.org/docs/flask/deploying/index.html) 314 | 315 | ## 2.9 部署多个模型 316 | 部署多个模型: 317 | 在`webserver_recognize_api.py`文件汇总,新建一个Recognizer对象; 318 | 并参照原有`up_image`函数编写的路由和识别逻辑。 319 | ``` 320 | Q = Recognizer(image_height, image_width, max_captcha, char_set, model_save_dir) 321 | ``` 322 | 注意修改这一行: 323 | ``` 324 | value = Q.rec_image(img) 325 | ``` 326 | 327 | ## 2.10 在线识别 328 | 在线识别验证码是显示中常用场景,即实时获取目标验证码来调用接口进行识别。 329 | 为了测试的完整性,这里搭建了一个验证码获取接口,通过执行下面的命令启动: 330 | ``` 331 | python webserver_captcha_image.py 332 | ``` 333 | 启动后通过访问此地址:`http://127.0.0.1:6100/captcha/`可以接收到验证码图片的二进制流文件。 334 | 具体进行在线识别任务的demo参见:`recognize_online.py`。 335 | 336 | # 3 数据统计 337 | ## 3.1 训练数据统计 338 | 由于很多同学提出,“需要训练多久呀?”、“准确率可以达到多少?”、“为什么我的准确率一直是0?”类似的疑问。 339 | 这一小节,使用默认配置(2019.06.02),把训练过程中的数据做了统计,给大家做一个展示。 340 | 本次测试条件如下: 341 | - 验证码:本项目自带生成验证码程序,数字+小写英文 342 | - 数量:20000张 343 | - 计算引擎:GPU 344 | - GPU型号:笔记本,GTX 950X 2G显卡 345 | 346 | 经过测试: 347 | 5000次,25分钟,**训练集**字符准确率84%,图片准确率51%; 348 | 9190次,46分钟,**训练集**字符准确率100%,图片准确率100%; 349 | 12000,60分钟,**测试集**的准确率基本上已经跑不动了。 350 | 351 | 使用`test_batch.py`测试,日志如下: 352 | ``` 353 | 100个样本识别耗时6.513171672821045秒,准确率37.0% 354 | ``` 355 | 有37%的准确率,可以说是识别成功的第一步了。 356 | 357 | 曲线图如下: 358 | 训练集- 359 | ![train_acc](readme_image/train_acc.png) 360 | 361 | 测试集- 362 | ![test_acc](readme_image/test_acc.png) 363 | 364 | 365 | ## 3.2 压力测试和统计数据 366 | 提供了一个简易的压力测试脚本,可以统计api运行过程中识别耗时和请求耗时的相关数据,不过图需要自己用Excel拉出来。 367 | 打开文件`recognize_time_test.py`,修改`main`函数下的`test_file`路径,这里会重复使用一张图片来访问是被接口。 368 | 最后数据会储存在test.csv文件中。 369 | 使用如下命令运行: 370 | ``` 371 | python3 recognize_time_test.py 372 | ----输出如下 373 | 2938,5150,13:30:25,总耗时:29ms,识别:15ms,请求:14ms 374 | 2939,5150,13:30:25,总耗时:41ms,识别:21ms,请求:20ms 375 | 2940,5150,13:30:25,总耗时:47ms,识别:16ms,请求:31ms 376 | ``` 377 | 这里对一个模型进行了两万次测试后,一组数据test.csv。 378 | 把test.csv使用箱线图进行分析后可以看到: 379 | ![压力测试结果](readme_image/压力测试结果.png) 380 | - 单次请求API总耗时(平均值):27ms 381 | - 单次识别耗时(平均值):12ms 382 | - 每次请求耗时(平均值):15ms 383 | 其中有:请求API总耗时 = 识别耗时 + 请求耗时 384 | 385 | # 4 开发说明 386 | - 20190209 387 | 1. 目前tensorboard展示支持的不是很好。 388 | - 20190601 389 | 1. 最近比较忙,issue回的有点慢,请大家见谅 390 | 2. dev分支开发到一半一直没时间弄,今天儿童节花了一下午时间更新了一下:) 391 | 3. 感谢看到这里的你,谢谢你的支持 392 | 393 | # 4 已知BUG 394 | 1. 
使用pycharm启动recognize_api.py文件报错 395 | ``` 396 | 2018-12-01 00:35:15.106333: W T:\src\github\tensorflow\tensorflow\core\framework\op_kernel.cc:1273] OP_REQUIRES failed at save_restore_tensor.cc:170 : Invalid argument: Unsuccessful TensorSliceReader constructor: Failed to get matching files on ./model/: Not found: FindFirstFile failed for: ./model : ϵͳ�Ҳ���ָ����·���� 397 | ; No such process 398 | ...... 399 | tensorflow.python.framework.errors_impl.InvalidArgumentError: Unsuccessful TensorSliceReader constructor: Failed to get matching files on ./model/: Not found: FindFirstFile failed for: ./model : ϵͳ\udcd5Ҳ\udcbb\udcb5\udcbdָ\udcb6\udca8\udcb5\udcc4·\udcbe\udcb6\udca1\udca3 400 | ; No such process 401 | [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]] 402 | ``` 403 | 由pycharm默认设置了工作空间,导致读取相对路径的model文件夹出错。 404 | 解决办法:编辑运行配置,设置工作空间为项目目录即可。 405 | ![bug_api启动失败](readme_image/bug_api启动失败.png) 406 | 407 | 2. FileNotFoundError: [Errno 2] No such file or directory: 'xxxxxx' 408 | 目录下有文件夹不存在,在指定目录创建好文件夹即可。 409 | 410 | 3. api程序在运行过程中内存越占越大 411 | 结果查阅资料:[链接](https://blog.csdn.net/The_lastest/article/details/81130500) 412 | 在迭代循环时,不能再包含任何张量的计算表达式,否在会内存溢出。 413 | 将张量的计算表达式放到init初始化执行后,识别速度得到极大的提升。 414 | 415 | 4. 加载多个模型报错 416 | 原因是两个Recognizer对象都使用了默认的Graph。 417 | 解决办法是在创建对象的时候不使用默认Graph,新建graph,这样每个Recognizer都使用不同的graph,就不会冲突了。 418 | 419 | 5. Flask程序用于生产 420 | 可以参考官方文档:[Flask的生产配置](http://docs.jinkan.org/docs/flask/config.html) 421 | 422 | 6. OOM happens 423 | ``` 424 | Hint: If you want to see a list of allocated tensors when OOM happens, 425 | add report_tensor_allocations_upon_oom to RunOptions for current allocation info. 426 | ``` 427 | 尽可能关闭其他占用GPU或者CPU的任务,或者减小`sample_config.json`中的`train_batch_size`参数。 428 | --------------------------------------------------------------------------------
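
As a footnote to known bug 3 above, here is a small standalone TF1 sketch (not project code) of the fix: tensor expressions such as `tf.argmax` / `tf.reshape` are built once, and the per-request loop only calls `sess.run`, so the default graph (and memory use) stops growing.

```
# -*- coding: utf-8 -*-
# Standalone TF1 sketch of the fix for known bug 3: build tensor expressions once,
# never inside the per-request loop, otherwise the default graph keeps growing.
import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 4])
w = tf.Variable(tf.ones([4, 2]))
predict_op = tf.argmax(tf.matmul(x, w), 1)   # built ONCE, e.g. in an __init__ method

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(3):                       # per-request loop: only sess.run here
        print(sess.run(predict_op, feed_dict={x: np.random.rand(1, 4)}))
```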