├── data
│   ├── 2an6.png
│   └── 2g3e.png
├── __init__.py
├── README.md
├── config.yaml
├── models.py
├── datasets.py
├── test.py
└── train.py

/data/2an6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a2king/Captcha_NumAlphabet/HEAD/data/2an6.png
--------------------------------------------------------------------------------
/data/2g3e.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a2king/Captcha_NumAlphabet/HEAD/data/2g3e.png
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# Created on 2020/12/17 16:39
# Project:
# @Author: CaoYugang

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Captcha_NumAlphabet
基于CNN的数字+字母验证码识别训练项目(PyTorch版)

## 实例
数据集及训练好的模型已上传至网盘,可下载用于环境测试 (百度云:链接:https://pan.baidu.com/s/1DglFQ1hl3mHzAooWxqBU8g 提取码:aegd )

7 | 实际训练过程中出现标签数据重复时,可通过'标签_随机值.jpg'实现相同标签的图片多样化,例:2Ab6_1231345.jpg 8 | 9 | ## 数据集展示 10 | ![avatar](data/2an6.png) 11 | ![avatar](data/2g3e.png) 12 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | # 配置 2 | width: 180 3 | height: 100 4 | alphabet: 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ 5 | numchar: 4 6 | train: 7 | epoch: 200 # 遍历数据集次数 8 | pre_epoch: 0 # 定义已经遍历数据集的次数 9 | batch_size: 256 # 批处理尺寸(batch_size) 10 | lr: 0.001 # 学习率 11 | train_data: ./data/train # 训练集路径 12 | test_data: ./data/test # 训练集路径 13 | is_gpu: True # 是否使用gpu 14 | num_workers: 8 # 并行处理数据进程数,根据显存大小自定义,显存越小work数越小 15 | out_model_path: ./model # 网络模型保存地址,需写绝对路径 16 | test: 17 | model_path: H:\pytouch_product\VerificationCode\cyg\GitHub_Captcha_NumAlphabet\dome_model.path # 测试所用的模型路径 18 | is_gpu: False # 是否使用gpu -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | # Created on 2020/11/3 13:38 4 | # Project: 5 | # @Author: CaoYugang 6 | 7 | import torch.nn as nn 8 | 9 | 10 | class CNN(nn.Module): 11 | def __init__(self, num_class=36, num_char=4, width=100, height=35): 12 | super(CNN, self).__init__() 13 | self.num_class = num_class 14 | self.num_char = num_char 15 | self.line_size = int(512 * (width // 2 // 2 // 2 // 2) * (height // 2 // 2 // 2 // 2)) 16 | self.conv1 = nn.Sequential( 17 | nn.Conv2d(3, 16, 3, padding=(1, 1)), 18 | nn.MaxPool2d(2, 2), 19 | nn.BatchNorm2d(16), 20 | nn.ReLU() 21 | ) 22 | self.conv2 = nn.Sequential( 23 | nn.Conv2d(16, 64, 3, padding=(1, 1)), 24 | nn.MaxPool2d(2, 2), 25 | nn.BatchNorm2d(64), 26 | nn.ReLU() 27 | ) 28 | self.conv3 = nn.Sequential( 29 | nn.Conv2d(64, 512, 3, padding=(1, 1)), 30 | nn.MaxPool2d(2, 2), 31 | nn.BatchNorm2d(512), 32 | 
nn.ReLU() 33 | ) 34 | self.conv4 = nn.Sequential( 35 | nn.Conv2d(512, 512, 3, padding=(1, 1)), 36 | nn.MaxPool2d(2, 2), 37 | nn.BatchNorm2d(512), 38 | nn.ReLU() 39 | ) 40 | self.fc = nn.Linear(self.line_size, self.num_class * self.num_char) 41 | 42 | def forward(self, x): 43 | x = self.conv1(x) 44 | x = self.conv2(x) 45 | x = self.conv3(x) 46 | x = self.conv4(x) 47 | x = x.view(-1, self.line_size) 48 | x = self.fc(x) 49 | return x 50 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | # Created on 2020/11/3 13:38 4 | # Project: 5 | # @Author: CaoYugang 6 | 7 | import os 8 | from PIL import Image 9 | import torch 10 | from torch.utils.data import Dataset 11 | 12 | # source = [str(i) for i in range(0, 10)] 13 | # source += [chr(i) for i in range(65, 65 + 26)] 14 | # alphabet = ''.join(source) 15 | 16 | 17 | def img_loader(img_path): 18 | img = Image.open(img_path) 19 | # 判断图片是否是RGB(简单判断图像是否是PNG格式) 20 | img = img if not img_path.endswith("png") else img.convert('RGB') 21 | return img.convert('RGB') 22 | 23 | 24 | def make_dataset(data_path, alphabet, num_class, num_char): 25 | img_names = os.listdir(data_path) 26 | samples = [] 27 | for img_name in img_names: 28 | img_path = os.path.join(data_path, img_name) 29 | target_str = img_name.replace("\\", "/").split('/')[-1].split('.')[0].split("_")[0] 30 | assert len(target_str) == num_char 31 | target = [] 32 | for char in target_str: 33 | vec = [0] * num_class 34 | vec[alphabet.find(char)] = 1 35 | target += vec 36 | samples.append((img_path, target)) 37 | return samples 38 | 39 | 40 | class CaptchaData(Dataset): 41 | def __init__(self, data_path, num_class=36, num_char=4, 42 | transform=None, target_transform=None, alphabet="0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"): 43 | super(Dataset, self).__init__() 44 | 
self.data_path = data_path 45 | self.num_class = num_class 46 | self.num_char = num_char 47 | self.transform = transform 48 | self.target_transform = target_transform 49 | self.alphabet = alphabet 50 | self.samples = make_dataset(self.data_path, self.alphabet, 51 | self.num_class, self.num_char) 52 | 53 | def __len__(self): 54 | return len(self.samples) 55 | 56 | def __getitem__(self, index): 57 | img_path, target = self.samples[index] 58 | img = img_loader(img_path) 59 | if self.transform is not None: 60 | img = self.transform(img) 61 | if self.target_transform is not None: 62 | target = self.target_transform(target) 63 | return img, torch.Tensor(target) 64 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Feb 13 20:07:17 2019 4 | 5 | @author: icetong 6 | """ 7 | import logging 8 | from io import BytesIO 9 | 10 | import torch 11 | import torch.nn as nn 12 | import yaml 13 | from PIL import ImageSequence, Image 14 | 15 | from models import CNN 16 | from torchvision.transforms import Compose, ToTensor, Resize 17 | 18 | logging.basicConfig(level=logging.INFO, 19 | format='%(asctime)s -[PID:%(process)s]-%(levelname)s-%(funcName)s-%(lineno)d: [ %(message)s ]', 20 | datefmt="%Y-%m-%d %H:%M:%S") 21 | 22 | with open('./config.yaml', 'r', encoding='utf-8') as f_config: 23 | config_result = f_config.read() 24 | config = yaml.load(config_result, Loader=yaml.FullLoader) 25 | 26 | model_path = config["test"]["model_path"] 27 | use_gpu = config["test"]["is_gpu"] 28 | width = config["width"] 29 | height = config["height"] 30 | alphabet = config["alphabet"] 31 | numchar = config["numchar"] 32 | 33 | # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 34 | model_net = CNN() 35 | 36 | 37 | def load_net(): 38 | global model_net 39 | model_net = CNN(num_class=len(alphabet), 
num_char=int(numchar), width=width, height=height) 40 | if use_gpu: 41 | model_net = model_net.cuda() 42 | model_net.eval() 43 | model_net.load_state_dict(torch.load(model_path)) 44 | else: 45 | model_net.eval() 46 | model_net.load_state_dict(torch.load(model_path, map_location='cpu')) 47 | 48 | 49 | def predict_image(img): 50 | global model_net 51 | with torch.no_grad(): 52 | img = img.convert('RGB') 53 | transforms = Compose([Resize((height, width)), ToTensor()]) 54 | img = transforms(img) 55 | 56 | if use_gpu: 57 | img = img.view(1, 3, height, width).cuda() 58 | else: 59 | img = img.view(1, 3, height, width) 60 | output = model_net(img) 61 | 62 | output = output.view(-1, len(alphabet)) 63 | output = nn.functional.softmax(output, dim=1) 64 | output = torch.argmax(output, dim=1) 65 | output = output.view(-1, 4)[0] 66 | return ''.join([alphabet[i] for i in output.cpu().numpy()]) 67 | 68 | 69 | def gif2jpg(gif_image): 70 | image = gif_image 71 | jpg_list = [] 72 | for index, f in enumerate(ImageSequence.Iterator(image)): 73 | # 获取图像序列存储 74 | # f.show() 75 | if index % 3 == 0: 76 | f = f.convert('RGB') 77 | output_buffer = BytesIO() 78 | f.save(output_buffer, format='JPEG') 79 | jpg_list.append(f) 80 | return jpg_list 81 | 82 | 83 | def predict_gif(gif_image): 84 | result_list = {} 85 | for _image in gif2jpg(gif_image): 86 | result = predict_image(_image) 87 | if result in result_list: 88 | result_list[result] += 1 89 | else: 90 | result_list[result] = 1 91 | return sorted([(key, result_list[key]) for key in result_list], key=lambda x: x[1], reverse=True)[0][0] 92 | 93 | 94 | if __name__ == "__main__": 95 | # predict() 96 | load_net() 97 | 98 | # gif_image = Image.open(r"F:\xftp_tmpfile\checkpoints_haiguan\1608197652.gif") 99 | # v_code = predict_gif(gif_image) 100 | # print(v_code) 101 | 102 | v_code = predict_image(Image.open(r"C:\Users\caoyugang\Downloads\data\test/fngg.png")) 103 | print(v_code) 104 | 
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | # Created on 2020/11/3 13:38 4 | # Project: 5 | # @Author: CaoYugang 6 | import logging 7 | 8 | import torch 9 | import torch.nn as nn 10 | import yaml 11 | from torch.autograd import Variable 12 | from models import CNN 13 | from datasets import CaptchaData 14 | from torch.utils.data import DataLoader 15 | from torchvision.transforms import Compose, ToTensor, Resize 16 | import time 17 | import os 18 | logging.basicConfig(level=logging.INFO, 19 | format='%(asctime)s -[PID:%(process)s]-%(levelname)s-%(funcName)s-%(lineno)d: [ %(message)s ]', 20 | datefmt="%Y-%m-%d %H:%M:%S") 21 | 22 | with open('./config.yaml', 'r', encoding='utf-8') as f_config: 23 | config_result = f_config.read() 24 | config = yaml.load(config_result, Loader=yaml.FullLoader) 25 | 26 | 27 | batch_size = config["train"]["batch_size"] 28 | base_lr = config["train"]["lr"] 29 | max_epoch = config["train"]["epoch"] 30 | model_path = config["train"]["out_model_path"] 31 | train_data_path = config["train"]["train_data"] 32 | test_data_path = config["train"]["test_data"] 33 | num_workers = config["train"]["num_workers"] 34 | use_gpu = config["train"]["is_gpu"] 35 | width = config["width"] 36 | height = config["height"] 37 | alphabet = config["alphabet"] 38 | numchar = config["numchar"] 39 | # restor = False 40 | 41 | if not os.path.exists(model_path): 42 | logging.info("新建训练模型保存路径:{}".format(model_path)) 43 | os.makedirs(model_path) 44 | 45 | 46 | def calculat_acc(output, target): 47 | output, target = output.view(-1, len(alphabet)), target.view(-1, len(alphabet)) 48 | output = nn.functional.softmax(output, dim=1) 49 | output = torch.argmax(output, dim=1) 50 | target = torch.argmax(target, dim=1) 51 | output, target = output.view(-1, int(numchar)), target.view(-1, 
int(numchar)) 52 | correct_list = [] 53 | for i, j in zip(target, output): 54 | if torch.equal(i, j): 55 | correct_list.append(1) 56 | else: 57 | correct_list.append(0) 58 | acc = sum(correct_list) / len(correct_list) 59 | return acc 60 | 61 | 62 | def train(): 63 | transforms = Compose([Resize((height, width)), ToTensor()]) 64 | train_dataset = CaptchaData(train_data_path, num_class=len(alphabet), num_char=int(numchar), transform=transforms, alphabet=alphabet) 65 | train_data_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, 66 | shuffle=True, drop_last=True) 67 | test_data = CaptchaData(test_data_path, num_class=len(alphabet), num_char=int(numchar), transform=transforms, alphabet=alphabet) 68 | test_data_loader = DataLoader(test_data, batch_size=batch_size, 69 | num_workers=num_workers, shuffle=True, drop_last=True) 70 | cnn = CNN(num_class=len(alphabet), num_char=int(numchar), width=width, height=height) 71 | if use_gpu: 72 | cnn.cuda() 73 | 74 | optimizer = torch.optim.Adam(cnn.parameters(), lr=base_lr) 75 | criterion = nn.MultiLabelSoftMarginLoss() 76 | 77 | for epoch in range(max_epoch): 78 | start_ = time.time() 79 | 80 | loss_history = [] 81 | acc_history = [] 82 | cnn.train() 83 | for img, target in train_data_loader: 84 | img = Variable(img) 85 | target = Variable(target) 86 | if use_gpu: 87 | img = img.cuda() 88 | target = target.cuda() 89 | output = cnn(img) 90 | loss = criterion(output, target) 91 | optimizer.zero_grad() 92 | loss.backward() 93 | optimizer.step() 94 | 95 | acc = calculat_acc(output, target) 96 | acc_history.append(float(acc)) 97 | loss_history.append(float(loss)) 98 | print('epoch:{},train_loss: {:.4}|train_acc: {:.4}'.format( 99 | epoch, 100 | torch.mean(torch.Tensor(loss_history)), 101 | torch.mean(torch.Tensor(acc_history)), 102 | )) 103 | 104 | loss_history = [] 105 | acc_history = [] 106 | cnn.eval() 107 | for img, target in test_data_loader: 108 | img = Variable(img) 109 | target = 
Variable(target) 110 | if torch.cuda.is_available(): 111 | img = img.cuda() 112 | target = target.cuda() 113 | output = cnn(img) 114 | 115 | acc = calculat_acc(output, target) 116 | acc_history.append(float(acc)) 117 | loss_history.append(float(loss)) 118 | print('test_loss: {:.4}|test_acc: {:.4}'.format( 119 | torch.mean(torch.Tensor(loss_history)), 120 | torch.mean(torch.Tensor(acc_history)), 121 | )) 122 | print('epoch: {}|time: {:.4f}'.format(epoch, time.time() - start_)) 123 | torch.save(cnn.state_dict(), os.path.join(model_path, "model_{}.path".format(epoch))) 124 | 125 | 126 | if __name__ == "__main__": 127 | train() 128 | --------------------------------------------------------------------------------