├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── results_res101.png
├── captchaGenerator.py
├── datasets.py
├── docs
│   ├── number.png
│   └── number2.png
├── models.py
├── one_hot_encoding.py
├── predict.py
├── results.png
├── results.txt
├── settings.py
├── test.py
├── torch_util.py
└── train.py

/.gitattributes:
--------------------------------------------------------------------------------
# Auto detect text files and perform LF normalization
* text=auto
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Generated captcha data and model weights
dataset/
weights/

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 dpj

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CAPTCHA Recognition with Deep Learning

Based on and modified from: https://github.com/dee1024/pytorch-captcha-recognition

This project is dedicated to recognizing all kinds of CAPTCHAs with neural networks.

A follow-up repository builds on this codebase and adds many tricks to strengthen recognition, such as an attention mechanism, dual pooling, IBN modules, BNNeck, and center loss.

Link: https://github.com/pprp/captcha.Pytorch

Changes
===
- Added more of the models available in torchvision
- Renamed some files
- Added GPU support (CPU still works)
- Added: run a test pass after every training epoch (the model used in test.py must match the one in train.py)
- Added: write each test result to results.txt; run torch_util.py to plot the accuracy curve as results.png
- Training with RES152 as the base network only reaches 94% on mixed digits and uppercase letters, still short of the original author's accuracy; advice is welcome

Features
===
- __End to end: no extra image preprocessing is needed (such as character segmentation, image size normalization, character labeling, or per-character feature extraction)__
- __CAPTCHAs may contain digits, uppercase letters, and lowercase letters__
- __Self-generated CAPTCHAs serve as the training, test, and prediction sets__
- __Four digits only: recognition accuracy up to 99.9999%__
- __Four digits + uppercase letters: recognition accuracy around 96%__
- __Deep learning framework PyTorch + CAPTCHA generator ImageCaptcha__


How it works
===

- __Generating the training set__

  Use ImageCaptcha, a common Python CAPTCHA generation library, to generate 100,000 CAPTCHAs, all labeled automatically.
  The same idea applies to other CAPTCHA styles: find the matching generation algorithm to produce an already-labeled training set, or label the images yourself. Tens of thousands of samples are needed, so pure manual labeling takes a while; alternatively, an online labeling service can do the marking.

- __Training the convolutional network__

  Build a multi-layer convolutional network and train it as a multi-label classification model.
  Each labeled character is one-hot encoded (see the sketch below).
  Feed in batches of images and labels; after roughly 15 epochs the accuracy already exceeds 96%.
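
With the default character set (10 digits + 26 uppercase letters, see settings.py), each 4-character label becomes a 4 × 36 = 144-dimensional multi-hot vector. A minimal sketch using the helpers in one_hot_encoding.py:

```python
import numpy as np
import one_hot_encoding as ohe
import settings

vec = ohe.encode("AB12")
print(vec.shape, int(vec.sum()))  # (144,) 4 -- one hot bit per character
print(ohe.decode(vec))            # AB12
# each character owns its own ALL_CHAR_SET_LEN-wide slice of the vector:
print(settings.ALL_CHAR_SET[np.argmax(vec[0:settings.ALL_CHAR_SET_LEN])])  # A
```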

Recognition accuracy showcase
========
![](https://raw.githubusercontent.com/dee1024/pytorch-captcha-recognition/master/docs/number.png)

Quick start
====
- __Step 1: 10-minute environment setup__

  Python 3.6+, the ImageCaptcha library (pip install captcha), and PyTorch (see the official site, http://pytorch.org)

- __Step 2: generate CAPTCHAs__
  ```bash
  python captchaGenerator.py
  ```
  This creates labeled CAPTCHA images under dataset/train/. The amount can be 10k, 50k, or 100k, set via the count parameter in captchaGenerator.py (to build the test and prediction sets as well, see the sketch after these steps).

- __Step 3: train the model__
  ```bash
  python train.py
  ```
  Trains a CNN model (defined in models) on the CAPTCHA set generated in step 2. Checkpoints are saved to the weights folder, and the best result is stored as `cnn_best.pt`.

- __Step 4: test the model__
  ```bash
  python test.py
  ```
  Prints the model's accuracy to the console (e.g. 95%). If the accuracy is low, go back to step 2, generate a larger image set, and train again.

- __Step 5: predict with the model__
  ```bash
  python predict.py
  ```
  Prints the predicted CAPTCHA text to the console.
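
To build the test and prediction sets with the same script, point `path` at the other directories defined in settings.py before rerunning; a sketch of the relevant lines in captchaGenerator.py:

```python
import settings

path = settings.TRAIN_DATASET_PATH    # default: training set
# path = settings.TEST_DATASET_PATH     # switch to generate the test set
# path = settings.PREDICT_DATASET_PATH  # or the prediction set
```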

Visualization
===

![results_res101](assets/results_res101.png)

Contributing
===

We look forward to your pull requests!

Feel free to open an issue if you run into problems.

Authors
===
* __Dee Qiu__
* Contributor: __pprp__ <1115957667@qq.com>


Disclaimer
===
This project is intended for learning and exchange only.
--------------------------------------------------------------------------------
/assets/results_res101.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pprp/captcha_identify.pytorch/1089683c5d6a89481fa639af4473a2cc1ab5c79b/assets/results_res101.png
--------------------------------------------------------------------------------
/captchaGenerator.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-
from captcha.image import ImageCaptcha  # pip install captcha
from PIL import Image
import random
import time
import settings
import os

def random_captcha():
    captcha_text = []
    for i in range(settings.MAX_CAPTCHA):
        c = random.choice(settings.ALL_CHAR_SET)
        captcha_text.append(c)
    return ''.join(captcha_text)

# Generate a random text and the matching captcha image
def gen_captcha_text_and_image():
    image = ImageCaptcha()
    captcha_text = random_captcha()
    captcha_image = Image.open(image.generate(captcha_text))
    return captcha_text, captcha_image

if __name__ == '__main__':
    count = 1000
    path = settings.TRAIN_DATASET_PATH  # change this directory to generate the training, test, or prediction set
    if not os.path.exists(path):
        os.makedirs(path)

    for i in range(count):
        now = str(int(time.time()))
        text, image = gen_captcha_text_and_image()
        filename = text + '_' + now + '.jpg'
        image.save(path + os.path.sep + filename)
        print('saved %d : %s' % (i + 1, filename))
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-
import os
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from PIL import Image
import one_hot_encoding as ohe
import settings

class mydataset(Dataset):
    def __init__(self, folder, transform=None):
        self.train_image_file_paths = [os.path.join(folder, image_file) for image_file in os.listdir(folder)]
        # Use the caller-supplied transform if given; fall back to a plain ToTensor.
        self.transform = transform if transform is not None else transforms.Compose([
            transforms.ToTensor(),
            # transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

    def __len__(self):
        return len(self.train_image_file_paths)

    def __getitem__(self, idx):
        image_root = self.train_image_file_paths[idx]
        image_name = image_root.split(os.path.sep)[-1]
        image = Image.open(image_root)
        if self.transform is not None:
            image = self.transform(image)
        # For convenience, generated files are named "XXXX_<timestamp>.jpg", where
        # the first four characters are the captcha text (digits or uppercase
        # letters); that prefix is the label, one-hot encoded here.
        label = ohe.encode(image_name.split('_')[0])
        return image, label


transform = transforms.Compose([
    transforms.ColorJitter(),
    # transforms.Grayscale(),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def get_train_data_loader():
    dataset = mydataset(settings.TRAIN_DATASET_PATH, transform=transform)
    return DataLoader(dataset, batch_size=64, shuffle=True)

def get_test_data_loader():
    dataset = mydataset(settings.TEST_DATASET_PATH, transform=transform)
    return DataLoader(dataset, batch_size=1, shuffle=True)

def get_predict_data_loader():
    dataset = mydataset(settings.PREDICT_DATASET_PATH, transform=transform)
    return DataLoader(dataset, batch_size=1, shuffle=True)
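
# Quick sanity check -- a sketch that assumes captchaGenerator.py has already
# populated settings.TRAIN_DATASET_PATH with labeled images.
if __name__ == '__main__':
    loader = get_train_data_loader()
    images, labels = next(iter(loader))
    print(images.shape)  # e.g. torch.Size([64, 3, 60, 160])
    print(labels.shape)  # e.g. torch.Size([64, 144]) -- MAX_CAPTCHA * ALL_CHAR_SET_LEN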
--------------------------------------------------------------------------------
/docs/number.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pprp/captcha_identify.pytorch/1089683c5d6a89481fa639af4473a2cc1ab5c79b/docs/number.png
--------------------------------------------------------------------------------
/docs/number2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pprp/captcha_identify.pytorch/1089683c5d6a89481fa639af4473a2cc1ab5c79b/docs/number2.png
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-
import torch.nn as nn
import settings
import torchvision

# CNN Model (3 conv layers)
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.Dropout(0.5),  # drop 50% of the neurons
            nn.ReLU(),
            nn.MaxPool2d(2))
        # self.layer_vgg = nn.Sequential(*list(vgg.features._modules.values())[:-1])
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.Dropout(0.5),  # drop 50% of the neurons
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.layer3 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.Dropout(0.5),  # drop 50% of the neurons
            nn.ReLU(),
            nn.MaxPool2d(2))
            # nn.AdaptiveAvgPool2d((1,1)))
        self.fc = nn.Sequential(
            # nn.AdaptiveAvgPool2d((1,1)),
            nn.Linear((settings.IMAGE_WIDTH//8)*(settings.IMAGE_HEIGHT//8)*64, 1024),
            nn.Dropout(0.5),  # drop 50% of the neurons
            nn.ReLU())
        self.rfc = nn.Sequential(
            nn.Linear(1024, settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN),
        )

    def forward(self, x):
        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = out3.view(out3.size(0), -1)
        out5 = self.fc(out4)
        out6 = self.rfc(out5)
        return out6


class RES18(nn.Module):
    def __init__(self):
        super(RES18, self).__init__()
        self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
        self.base = torchvision.models.resnet18(pretrained=False)
        self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
    def forward(self, x):
        out = self.base(x)
        return out

class RES34(nn.Module):
    def __init__(self):
        super(RES34, self).__init__()
        self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
        self.base = torchvision.models.resnet34(pretrained=False)
        self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
    def forward(self, x):
        out = self.base(x)
        return out

class RES50(nn.Module):
    def __init__(self):
        super(RES50, self).__init__()
        self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
        self.base = torchvision.models.resnet50(pretrained=False)
        self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
    def forward(self, x):
        out = self.base(x)
        return out

class RES101(nn.Module):
    def __init__(self):
        super(RES101, self).__init__()
        self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
        self.base = torchvision.models.resnet101(pretrained=False)
        self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
    def forward(self, x):
        out = self.base(x)
        return out

class RES152(nn.Module):
    def __init__(self):
        super(RES152, self).__init__()
        self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
        self.base = torchvision.models.resnet152(pretrained=False)
        self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
    def forward(self, x):
        out = self.base(x)
        return out

class ALEXNET(nn.Module):
    def __init__(self):
        super(ALEXNET, self).__init__()
        self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
        self.base = torchvision.models.alexnet(pretrained=False)
        self.base.classifier[-1] = nn.Linear(4096, self.num_cls)
    def forward(self, x):
        out = self.base(x)
        return out

class VGG11(nn.Module):
    def __init__(self):
        super(VGG11, self).__init__()
        self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
        self.base = torchvision.models.vgg11(pretrained=False)
        self.base.classifier[-1] = nn.Linear(4096, self.num_cls)
    def forward(self, x):
        out = self.base(x)
        return out

class VGG13(nn.Module):
    def __init__(self):
        super(VGG13, self).__init__()
        self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
        self.base = torchvision.models.vgg13(pretrained=False)
        self.base.classifier[-1] = nn.Linear(4096, self.num_cls)
    def forward(self, x):
        out = self.base(x)
        return out

class VGG16(nn.Module):
    def __init__(self):
        super(VGG16, self).__init__()
        self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
        self.base = torchvision.models.vgg16(pretrained=False)
        self.base.classifier[-1] = nn.Linear(4096, self.num_cls)
    def forward(self, x):
        out = self.base(x)
        return out

class VGG19(nn.Module):
    def __init__(self):
        super(VGG19, self).__init__()
        self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
        self.base = torchvision.models.vgg19(pretrained=False)
        self.base.classifier[-1] = nn.Linear(4096, self.num_cls)
    def forward(self, x):
        out = self.base(x)
        return out

class SQUEEZENET(nn.Module):
    def __init__(self):
        super(SQUEEZENET, self).__init__()
        self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
        # SqueezeNet's final classifier layer is a 1x1 convolution, not a Linear
        self.base = torchvision.models.squeezenet1_0(pretrained=False)
        self.base.classifier[-3] = nn.Conv2d(512, self.num_cls, kernel_size=1)
    def forward(self, x):
        out = self.base(x)
        return out

class DENSE161(nn.Module):
    def __init__(self):
        super(DENSE161, self).__init__()
        self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
        self.base = torchvision.models.densenet161(pretrained=False)
        self.base.classifier = nn.Linear(self.base.classifier.in_features, self.num_cls)
    def forward(self, x):
        out = self.base(x)
        return out

class MOBILENET(nn.Module):
    def __init__(self):
        super(MOBILENET, self).__init__()
        self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
        self.base = torchvision.models.mobilenet_v2(pretrained=False)
        self.base.classifier = nn.Linear(self.base.last_channel, self.num_cls)
    def forward(self, x):
        out = self.base(x)
        return out
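
# Shape check -- a sketch: each model maps a batch of captcha images to
# MAX_CAPTCHA * ALL_CHAR_SET_LEN logits (4 * 36 = 144 with the default settings).
if __name__ == '__main__':
    import torch
    x = torch.zeros(1, 3, settings.IMAGE_HEIGHT, settings.IMAGE_WIDTH)
    print(CNN()(x).shape)  # torch.Size([1, 144])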
--------------------------------------------------------------------------------
/one_hot_encoding.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-
import numpy as np
import settings

# Characters are indexed over the full 0-9, A-Z, a-z, _ table (63 positions);
# with the default settings only digits and uppercase letters (36) are used.

def encode(text):
    vector = np.zeros(settings.ALL_CHAR_SET_LEN * settings.MAX_CAPTCHA, dtype=float)
    def char2pos(c):
        if c == '_':
            k = 62
            return k
        k = ord(c) - 48
        if k > 9:
            k = ord(c) - 65 + 10
        if k > 35:
            k = ord(c) - 97 + 26 + 10
        if k > 61:
            raise ValueError('error')
        return k
    for i, c in enumerate(text):
        idx = i * settings.ALL_CHAR_SET_LEN + char2pos(c)
        vector[idx] = 1.0
    return vector

def decode(vec):
    char_pos = vec.nonzero()[0]
    text = []
    for i, c in enumerate(char_pos):
        char_idx = c % settings.ALL_CHAR_SET_LEN
        if char_idx < 10:
            char_code = char_idx + ord('0')
        elif char_idx < 36:
            char_code = char_idx - 10 + ord('A')
        elif char_idx < 62:
            char_code = char_idx - 36 + ord('a')
        elif char_idx == 62:
            char_code = ord('_')
        else:
            raise ValueError('error')
        text.append(chr(char_code))
    return "".join(text)

if __name__ == '__main__':
    e = encode("BK7H")
    print(decode(e))
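    # Note (sketch): decode() reads the non-zero entries of the vector, so it
    # suits one-hot label vectors; raw network outputs are decoded with a
    # per-slice argmax instead (see test.py and predict.py).
    print(e.shape, int(e.sum()))  # (144,) 4 -- one hot bit per character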
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-
import numpy as np
import torch
from torch.autograd import Variable
#from visdom import Visdom # pip install Visdom
import settings
import datasets
from models import CNN
import argparse

def main():
    cnn = CNN()
    cnn.eval()

    parser = argparse.ArgumentParser(description="model path")
    parser.add_argument("--model-path", type=str, default="./weights/cnn_best.pt")
    args = parser.parse_args()

    # load the trained weights onto CPU
    cnn.load_state_dict(torch.load(args.model_path, map_location=torch.device('cpu')))

    predict_dataloader = datasets.get_predict_data_loader()

    #vis = Visdom()
    for i, (images, labels) in enumerate(predict_dataloader):
        image = images
        vimage = Variable(image)
        predict_label = cnn(vimage)

        c0 = settings.ALL_CHAR_SET[np.argmax(predict_label[0, 0:settings.ALL_CHAR_SET_LEN].data.numpy())]
        c1 = settings.ALL_CHAR_SET[np.argmax(predict_label[0, settings.ALL_CHAR_SET_LEN:2 * settings.ALL_CHAR_SET_LEN].data.numpy())]
        c2 = settings.ALL_CHAR_SET[np.argmax(predict_label[0, 2 * settings.ALL_CHAR_SET_LEN:3 * settings.ALL_CHAR_SET_LEN].data.numpy())]
        c3 = settings.ALL_CHAR_SET[np.argmax(predict_label[0, 3 * settings.ALL_CHAR_SET_LEN:4 * settings.ALL_CHAR_SET_LEN].data.numpy())]

        c = '%s%s%s%s' % (c0, c1, c2, c3)
        print(c)
        #vis.images(image, opts=dict(caption=c))

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pprp/captcha_identify.pytorch/1089683c5d6a89481fa639af4473a2cc1ab5c79b/results.png
--------------------------------------------------------------------------------
/results.txt:
--------------------------------------------------------------------------------
1,0.1
2,0.2
3,0.3
4,0.4
--------------------------------------------------------------------------------
/settings.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-
import os

# Characters that may appear in a captcha
# (string.digits + string.ascii_uppercase)
NUMBER = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
ALPHABET = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
# To include lowercase letters as well, extend the set with:
# ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']

ALL_CHAR_SET = NUMBER + ALPHABET
ALL_CHAR_SET_LEN = len(ALL_CHAR_SET)
MAX_CAPTCHA = 4

# Image size
IMAGE_HEIGHT = 60
IMAGE_WIDTH = 160

TRAIN_DATASET_PATH = 'dataset' + os.path.sep + 'train'
TEST_DATASET_PATH = 'dataset' + os.path.sep + 'test'
PREDICT_DATASET_PATH = 'dataset' + os.path.sep + 'predict'
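
# Derived sizes with the defaults above (sketch): ALL_CHAR_SET_LEN == 36, so
# every label / output vector has MAX_CAPTCHA * ALL_CHAR_SET_LEN == 144 entries.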
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-
import numpy as np
import torch
from torch.autograd import Variable
import settings
import datasets
from models import *
import one_hot_encoding
import argparse
import os
from tqdm import tqdm

# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

device = torch.device("cpu")

def main(model_path):
    cnn = CNN()
    cnn.eval()
    cnn.load_state_dict(torch.load(model_path, map_location=device))
    print("load cnn net.")

    test_dataloader = datasets.get_test_data_loader()

    correct = 0
    total = 0

    pBar = tqdm(total=len(test_dataloader))

    for i, (images, labels) in enumerate(test_dataloader):
        pBar.update(1)

        image = images
        vimage = Variable(image)
        predict_label = cnn(vimage)

        c0 = settings.ALL_CHAR_SET[np.argmax(predict_label[0, 0:settings.ALL_CHAR_SET_LEN].data.numpy())]
        c1 = settings.ALL_CHAR_SET[np.argmax(predict_label[0, settings.ALL_CHAR_SET_LEN:2 * settings.ALL_CHAR_SET_LEN].data.numpy())]
        c2 = settings.ALL_CHAR_SET[np.argmax(predict_label[0, 2 * settings.ALL_CHAR_SET_LEN:3 * settings.ALL_CHAR_SET_LEN].data.numpy())]
        c3 = settings.ALL_CHAR_SET[np.argmax(predict_label[0, 3 * settings.ALL_CHAR_SET_LEN:4 * settings.ALL_CHAR_SET_LEN].data.numpy())]
        predict_label = '%s%s%s%s' % (c0, c1, c2, c3)
        true_label = one_hot_encoding.decode(labels.numpy()[0])
        total += labels.size(0)
        if predict_label == true_label:
            correct += 1
        # if total % 200 == 0:
        #     print('Test Accuracy of the model on the %d test images: %f %%' % (total, 100 * correct / total))
    pBar.close()
    print('Test Accuracy of the model on the %d test images: %f %%' % (total, 100 * correct / total))

def test_data(model_path):
    cnn = CNN()
    cnn.eval()
    cnn.load_state_dict(torch.load(model_path, map_location=device))
    test_dataloader = datasets.get_test_data_loader()

    correct = 0
    total = 0

    for i, (images, labels) in enumerate(test_dataloader):

        image = images
        vimage = Variable(image)
        predict_label = cnn(vimage)

        c0 = settings.ALL_CHAR_SET[np.argmax(predict_label[0, 0:settings.ALL_CHAR_SET_LEN].data.numpy())]
        c1 = settings.ALL_CHAR_SET[np.argmax(predict_label[0, settings.ALL_CHAR_SET_LEN:2 * settings.ALL_CHAR_SET_LEN].data.numpy())]
        c2 = settings.ALL_CHAR_SET[np.argmax(predict_label[0, 2 * settings.ALL_CHAR_SET_LEN:3 * settings.ALL_CHAR_SET_LEN].data.numpy())]
        c3 = settings.ALL_CHAR_SET[np.argmax(predict_label[0, 3 * settings.ALL_CHAR_SET_LEN:4 * settings.ALL_CHAR_SET_LEN].data.numpy())]
        predict_label = '%s%s%s%s' % (c0, c1, c2, c3)
        true_label = one_hot_encoding.decode(labels.numpy()[0])
        total += labels.size(0)
        if predict_label == true_label:
            correct += 1
        # if total % 200 == 0:
        #     print('Test Accuracy of the model on the %d test images: %f %%' % (total, 100 * correct / total))
    return 100 * correct / total

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="test path")
    parser.add_argument('--model-path', type=str, default="weights/cnn_1.pt")

    args = parser.parse_args()
    main(args.model_path)
--------------------------------------------------------------------------------
/torch_util.py:
--------------------------------------------------------------------------------
import csv
import torch
import matplotlib.pyplot as plt

def init_seeds(seed=0):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def select_device(force_cpu=False):
    cuda = False if force_cpu else torch.cuda.is_available()
    device = torch.device('cuda:0' if cuda else 'cpu')
    if not cuda:
        print('Using CPU')
    if cuda:
        c = 1024 ** 2  # bytes to MB
        ng = torch.cuda.device_count()
        x = [torch.cuda.get_device_properties(i) for i in range(ng)]
        print("Using CUDA device0 _CudaDeviceProperties(name='%s', total_memory=%dMB)" %
              (x[0].name, x[0].total_memory / c))
        if ng > 0:
            # torch.cuda.set_device(0)  # OPTIONAL: Set GPU ID
            for i in range(1, ng):
                print("           device%g _CudaDeviceProperties(name='%s', total_memory=%dMB)" %
                      (i, x[i].name, x[i].total_memory / c))
        print('')  # skip a line
    return device

def plot_result():
    fig = plt.figure(15)
    with open("results.txt", "r") as f:
        f_csv = csv.reader(f)
        epoch, acc = [], []
        for row in f_csv:
            print(row)
            epoch.append(int(row[0]))
            acc.append(float(row[1]))
    plt.plot(epoch, acc)
    plt.title("result-epoch&acc")
    plt.xlabel("epoch")
    plt.ylabel("accuracy")
    fig.savefig("results.png", dpi=300)  # save before show(), which may clear the figure
    plt.show()

if __name__ == "__main__":
    plot_result()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-
import torch
import torch.nn as nn
from torch.autograd import Variable
import datasets
from models import *
# import torch_util
import os, shutil
import argparse
import test
import torchvision
import settings

# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

# Hyper Parameters
num_epochs = 300
batch_size = 20  # (unused: the loader in datasets.py sets its own batch size)
learning_rate = 0.001

# device = torch_util.select_device()
device = torch.device("cpu")

def main(args):
    cnn = CNN().to(device)

    cnn.train()
    criterion = nn.MultiLabelSoftMarginLoss()
    optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)

    if args.resume:
        cnn.load_state_dict(torch.load(args.model_path, map_location=device))

    os.makedirs("./weights", exist_ok=True)  # checkpoints are saved here

    max_acc = 0
    # Train the Model
    train_dataloader = datasets.get_train_data_loader()
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_dataloader):
            images = Variable(images)
            labels = Variable(labels.float())
            predict_labels = cnn(images)
            loss = criterion(predict_labels, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i + 1) % 2 == 0:
                print("epoch: %03g \t step: %03g \t loss: %.5f" % (epoch, i + 1, loss.item()))
        print("epoch: %03g \t step: %03g \t loss: %.5f" % (epoch, i, loss.item()))
        torch.save(cnn.state_dict(), "./weights/cnn_%03g.pt" % epoch)
        acc = test.test_data("./weights/cnn_%03g.pt" % epoch)
        # record per-epoch accuracy for torch_util.plot_result()
        with open("results.txt", "a") as f:
            f.write("%d,%f\n" % (epoch, acc))
        if max_acc < acc:
            print("update accuracy %.5f." % acc)
            max_acc = acc
            shutil.copy("./weights/cnn_%03g.pt" % epoch, "./weights/cnn_best.pt")
        else:
            print("do not update %.5f." % acc)

    torch.save(cnn.state_dict(), "./weights/cnn_last.pt")
    print("save last model")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="load path")
    parser.add_argument('--model-path', type=str, default="./weights/cnn_0.pt")
    parser.add_argument('--resume', action='store_true')

    args = parser.parse_args()
    main(args)
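
# Usage (sketch):
#   fresh run:           python train.py
#   resume a checkpoint: python train.py --resume --model-path ./weights/cnn_0.pt
# The best-performing epoch is copied to ./weights/cnn_best.pt; plot the
# per-epoch accuracy recorded in results.txt with:  python torch_util.py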
--------------------------------------------------------------------------------