├── data
├── 2an6.png
└── 2g3e.png
├── __init__.py
├── README.md
├── config.yaml
├── models.py
├── datasets.py
├── test.py
└── train.py
/data/2an6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a2king/Captcha_NumAlphabet/HEAD/data/2an6.png
--------------------------------------------------------------------------------
/data/2g3e.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/a2king/Captcha_NumAlphabet/HEAD/data/2g3e.png
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- encoding: utf-8 -*-
3 | # Created on 2020/12/17 16:39
4 | # Project:
5 | # @Author: CaoYugang
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Captcha_NumAlphabet
2 | 基于CNN的数字字母验证码识别训练项目pytorch版
3 |
4 | ## 实例
5 | 数据集、数据集模型已上传至网盘,环境测试时可下载 (百度云:链接:https://pan.baidu.com/s/1DglFQ1hl3mHzAooWxqBU8g 提取码:aegd )
6 |
7 | 实际训练过程中出现标签数据重复时,可通过'标签_随机值.jpg'实现相同标签的图片多样化,例:2Ab6_1231345.jpg
8 |
9 | ## 数据集展示
10 | ![2an6](./data/2an6.png)
11 | ![2g3e](./data/2g3e.png)
12 |
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | # 配置
2 | width: 180
3 | height: 100
4 | alphabet: 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ
5 | numchar: 4
6 | train:
7 | epoch: 200 # 遍历数据集次数
8 | pre_epoch: 0 # 定义已经遍历数据集的次数
9 | batch_size: 256 # 批处理尺寸(batch_size)
10 | lr: 0.001 # 学习率
11 | train_data: ./data/train # 训练集路径
12 | test_data: ./data/test # 测试集路径
13 | is_gpu: True # 是否使用gpu
14 | num_workers: 8 # 并行处理数据进程数,根据显存大小自定义,显存越小work数越小
15 | out_model_path: ./model # 网络模型保存地址,需写绝对路径
16 | test:
17 | model_path: H:\pytouch_product\VerificationCode\cyg\GitHub_Captcha_NumAlphabet\dome_model.path # 测试所用的模型路径
18 | is_gpu: False # 是否使用gpu
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- encoding: utf-8 -*-
3 | # Created on 2020/11/3 13:38
4 | # Project:
5 | # @Author: CaoYugang
6 |
7 | import torch.nn as nn
8 |
9 |
class CNN(nn.Module):
    """Four-stage convolutional classifier for fixed-length captcha strings.

    Every stage is ``Conv2d(3x3, padding=1) -> MaxPool2d(2, 2) -> BatchNorm2d
    -> ReLU``, so each stage halves the spatial size (rounding down). The
    flattened feature map feeds a single linear layer that emits
    ``num_class`` scores for each of the ``num_char`` character positions,
    concatenated into one vector per sample.

    Args:
        num_class: Number of symbols a character position can take.
        num_char: Number of characters in one captcha.
        width: Width in pixels of the images the network will receive.
        height: Height in pixels of the images the network will receive.
    """

    def __init__(self, num_class=36, num_char=4, width=100, height=35):
        super(CNN, self).__init__()
        self.num_class = num_class
        self.num_char = num_char
        # Four 2x2 poolings shrink each spatial dimension by floor-halving it
        # four times; the last stage has 512 channels.
        self.line_size = int(512 * (width // 2 // 2 // 2 // 2) * (height // 2 // 2 // 2 // 2))
        # NOTE: attribute names and layer order define the state_dict keys
        # used by saved checkpoints, so they must not change.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=(1, 1)),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(16),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 64, 3, padding=(1, 1)),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 512, 3, padding=(1, 1)),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(512, 512, 3, padding=(1, 1)),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.fc = nn.Linear(self.line_size, self.num_class * self.num_char)

    def forward(self, x):
        """Map a ``(batch, 3, height, width)`` tensor to per-position scores.

        Returns:
            Tensor of shape ``(batch, num_class * num_char)`` holding raw
            (un-normalised) scores.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not match the
                ``width``/``height`` the model was built for.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        # Flatten per sample and keep the batch dimension fixed. Reshaping
        # with view(-1, line_size) would let a wrong-sized input be folded
        # into extra batch rows; this way the linear layer rejects it.
        x = x.view(x.size(0), -1)
        return self.fc(x)
50 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- encoding: utf-8 -*-
3 | # Created on 2020/11/3 13:38
4 | # Project:
5 | # @Author: CaoYugang
6 |
7 | import os
8 | from PIL import Image
9 | import torch
10 | from torch.utils.data import Dataset
11 |
12 | # source = [str(i) for i in range(0, 10)]
13 | # source += [chr(i) for i in range(65, 65 + 26)]
14 | # alphabet = ''.join(source)
15 |
16 |
def img_loader(img_path):
    """Load the image stored at ``img_path`` as an RGB PIL image.

    Args:
        img_path: Path of the image file to read.

    Returns:
        A new ``PIL.Image.Image`` in mode ``'RGB'``.
    """
    # Every source mode is converted to RGB, so no per-extension branch is
    # needed. convert() returns a fully loaded copy, which lets the file
    # handle opened by Image.open be closed right away instead of lingering
    # until garbage collection.
    with Image.open(img_path) as img:
        return img.convert('RGB')
22 |
23 |
def make_dataset(data_path, alphabet, num_class, num_char):
    """Build the ``(image_path, one_hot_label)`` list for a captcha folder.

    The label of an image is read from its file name: the part before the
    first ``.`` and before the first ``_`` (``2Ab6_1231345.jpg`` is labelled
    ``2Ab6``). Each character becomes a one-hot vector of length
    ``num_class`` indexed by its position in ``alphabet``; the vectors of all
    characters are concatenated.

    Args:
        data_path: Directory whose entries are the captcha images.
        alphabet: String of allowed symbols; the index of a symbol is its class.
        num_class: Length of each per-character one-hot vector.
        num_char: Required number of characters in every label.

    Returns:
        List of ``(img_path, target)`` tuples where ``target`` is a flat list
        of ``num_char * num_class`` zeros and ones.

    Raises:
        ValueError: If a label does not have ``num_char`` characters, or
            contains a symbol whose alphabet index is not in
            ``[0, num_class)``.
    """
    samples = []
    for img_name in os.listdir(data_path):
        img_path = os.path.join(data_path, img_name)
        target_str = img_name.replace("\\", "/").split('/')[-1].split('.')[0].split("_")[0]
        # Explicit raise instead of assert: asserts vanish under ``python -O``.
        if len(target_str) != num_char:
            raise ValueError(
                "label {!r} of {!r} has {} characters, expected {}".format(
                    target_str, img_path, len(target_str), num_char))
        target = []
        for char in target_str:
            index = alphabet.find(char)
            # find() returns -1 for an unknown symbol; used as a list index
            # that would silently mark the LAST class as the label.
            if not 0 <= index < num_class:
                raise ValueError(
                    "character {!r} in label of {!r} is not one of the {} "
                    "known classes".format(char, img_path, num_class))
            vec = [0] * num_class
            vec[index] = 1
            target += vec
        samples.append((img_path, target))
    return samples
38 |
39 |
class CaptchaData(Dataset):
    """Dataset of captcha images whose labels are encoded in the file names.

    The sample list is built once in the constructor by ``make_dataset``;
    images are only opened when an item is requested.

    Args:
        data_path: Directory containing the captcha images.
        num_class: Length of the one-hot vector used for each character.
        num_char: Number of characters in every captcha label.
        transform: Optional callable applied to the loaded RGB image.
        target_transform: Optional callable applied to the label list.
        alphabet: Allowed symbols; a symbol's index is its class id.
    """

    # NOTE(review): the default num_class (36) is smaller than the default
    # alphabet (62 symbols), so the defaults cannot encode every symbol of
    # that alphabet; callers are expected to pass num_class=len(alphabet).
    def __init__(self, data_path, num_class=36, num_char=4,
                 transform=None, target_transform=None, alphabet="0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"):
        # super() must name this class; naming Dataset skips Dataset itself
        # in the method resolution order.
        super(CaptchaData, self).__init__()
        self.data_path = data_path
        self.num_class = num_class
        self.num_char = num_char
        self.transform = transform
        self.target_transform = target_transform
        self.alphabet = alphabet
        self.samples = make_dataset(self.data_path, self.alphabet,
                                    self.num_class, self.num_char)

    def __len__(self):
        """Return the number of samples found in ``data_path``."""
        return len(self.samples)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        The image is loaded as RGB and passed through ``transform`` if one
        was given; the label goes through ``target_transform`` if given and
        is then wrapped in a ``torch.Tensor``.
        """
        img_path, target = self.samples[index]
        img = img_loader(img_path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, torch.Tensor(target)
64 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Wed Feb 13 20:07:17 2019
4 |
5 | @author: icetong
6 | """
7 | import logging
8 | from io import BytesIO
9 |
10 | import torch
11 | import torch.nn as nn
12 | import yaml
13 | from PIL import ImageSequence, Image
14 |
15 | from models import CNN
16 | from torchvision.transforms import Compose, ToTensor, Resize
17 |
# Log lines carry a timestamp, the process id, level, function name and line.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s -[PID:%(process)s]-%(levelname)s-%(funcName)s-%(lineno)d: [ %(message)s ]',
    datefmt="%Y-%m-%d %H:%M:%S",
)

# Settings come from config.yaml, resolved against the current working directory.
with open('./config.yaml', 'r', encoding='utf-8') as f_config:
    config_result = f_config.read()
config = yaml.load(config_result, Loader=yaml.FullLoader)

# Inference settings from the "test" section.
model_path, use_gpu = (config["test"][key] for key in ("model_path", "is_gpu"))
# Image geometry and label layout shared with training.
width, height, alphabet, numchar = (
    config[key] for key in ("width", "height", "alphabet", "numchar")
)

# Placeholder network with default hyper-parameters; load_net() replaces it
# with one sized from the configuration and loaded from the checkpoint.
model_net = CNN()
35 |
36 |
def load_net():
    """Build the network from the configuration and load the checkpoint.

    Rebinds the module-level ``model_net`` to a ``CNN`` sized from
    ``alphabet``, ``numchar``, ``width`` and ``height``, moves it to the GPU
    when ``use_gpu`` is set, switches it to evaluation mode and loads the
    weights stored at ``model_path``.
    """
    global model_net
    model_net = CNN(num_class=len(alphabet), num_char=int(numchar), width=width, height=height)
    if use_gpu:
        model_net = model_net.cuda()
    model_net.eval()
    # On the CPU path the checkpoint tensors are remapped to the CPU while
    # loading; on the GPU path torch.load is called without a map_location.
    load_kwargs = {} if use_gpu else {'map_location': 'cpu'}
    model_net.load_state_dict(torch.load(model_path, **load_kwargs))
47 |
48 |
def predict_image(img):
    """Recognise the captcha text in a single PIL image.

    The image is converted to RGB, resized to the configured
    ``(height, width)``, run through ``model_net`` and decoded by taking the
    highest-scoring symbol of ``alphabet`` at every character position.

    Args:
        img: A PIL image containing one captcha.

    Returns:
        The predicted string, ``numchar`` characters long.
    """
    with torch.no_grad():
        transforms = Compose([Resize((height, width)), ToTensor()])
        batch = transforms(img.convert('RGB')).view(1, 3, height, width)
        if use_gpu:
            batch = batch.cuda()
        output = model_net(batch)

        # One row of scores per character position.
        output = output.view(-1, len(alphabet))
        output = nn.functional.softmax(output, dim=1)
        output = torch.argmax(output, dim=1)
        # Group positions per image using the configured captcha length
        # (not a hard-coded 4) so other values of ``numchar`` decode fully.
        output = output.view(-1, int(numchar))[0]
        return ''.join([alphabet[i] for i in output.cpu().numpy()])
67 |
68 |
def gif2jpg(gif_image):
    """Return every third frame of an animated image as RGB PIL images.

    Args:
        gif_image: A multi-frame PIL image (for example an opened GIF).

    Returns:
        List of RGB copies of frames 0, 3, 6, ... in order.
    """
    # No JPEG encoding happens here: the frames themselves are returned, so
    # encoding each one into a throw-away buffer would be wasted work.
    # convert() yields an independent copy, so each kept frame survives the
    # iterator advancing to the next one.
    return [
        frame.convert('RGB')
        for index, frame in enumerate(ImageSequence.Iterator(gif_image))
        if index % 3 == 0
    ]
81 |
82 |
def predict_gif(gif_image):
    """Recognise an animated captcha by majority vote over sampled frames.

    Each frame returned by ``gif2jpg`` is decoded with ``predict_image``;
    the text predicted most often wins. On a tie the text that appeared
    first is returned.
    """
    votes = {}
    for frame in gif2jpg(gif_image):
        text = predict_image(frame)
        votes[text] = votes.get(text, 0) + 1
    # Stable sort: among equal counts the earliest-seen prediction stays first.
    ranked = sorted(votes.items(), key=lambda item: item[1], reverse=True)
    return ranked[0][0]
92 |
93 |
if __name__ == "__main__":
    # Manual smoke test: load the configured checkpoint, then decode one image.
    load_net()

    # To classify an animated GIF instead:
    # gif_image = Image.open(r"F:\xftp_tmpfile\checkpoints_haiguan\1608197652.gif")
    # v_code = predict_gif(gif_image)
    # print(v_code)

    sample_image = Image.open(r"C:\Users\caoyugang\Downloads\data\test/fngg.png")
    v_code = predict_image(sample_image)
    print(v_code)
104 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- encoding: utf-8 -*-
3 | # Created on 2020/11/3 13:38
4 | # Project:
5 | # @Author: CaoYugang
6 | import logging
7 |
8 | import torch
9 | import torch.nn as nn
10 | import yaml
11 | from torch.autograd import Variable
12 | from models import CNN
13 | from datasets import CaptchaData
14 | from torch.utils.data import DataLoader
15 | from torchvision.transforms import Compose, ToTensor, Resize
16 | import time
17 | import os
# Log lines carry a timestamp, the process id, level, function name and line.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s -[PID:%(process)s]-%(levelname)s-%(funcName)s-%(lineno)d: [ %(message)s ]',
    datefmt="%Y-%m-%d %H:%M:%S",
)

# Settings come from config.yaml, resolved against the current working directory.
with open('./config.yaml', 'r', encoding='utf-8') as f_config:
    config_result = f_config.read()
config = yaml.load(config_result, Loader=yaml.FullLoader)

# Training hyper-parameters and paths from the "train" section; the names on
# the left line up one-to-one with the keys on the right.
(batch_size, base_lr, max_epoch, model_path,
 train_data_path, test_data_path, num_workers, use_gpu) = (
    config["train"][key]
    for key in ("batch_size", "lr", "epoch", "out_model_path",
                "train_data", "test_data", "num_workers", "is_gpu")
)
# Image geometry and label layout shared with inference.
width, height, alphabet, numchar = (
    config[key] for key in ("width", "height", "alphabet", "numchar")
)

# Make sure the checkpoint directory exists before training starts.
if not os.path.exists(model_path):
    logging.info("新建训练模型保存路径:{}".format(model_path))
    os.makedirs(model_path)
44 |
45 |
def calculat_acc(output, target):
    """Return the fraction of captchas in a batch predicted entirely right.

    ``output`` holds the network scores and ``target`` the one-hot labels.
    Both are viewed as rows of ``len(alphabet)`` values (one row per
    character position), reduced to class indices with ``argmax`` and
    regrouped into rows of ``numchar`` positions. A captcha counts as
    correct only when every position matches.
    """
    num_symbols = len(alphabet)
    chars_per_sample = int(numchar)

    flat_output = output.view(-1, num_symbols)
    flat_target = target.view(-1, num_symbols)

    scores = nn.functional.softmax(flat_output, dim=1)
    predicted = torch.argmax(scores, dim=1).view(-1, chars_per_sample)
    expected = torch.argmax(flat_target, dim=1).view(-1, chars_per_sample)

    # One flag per captcha: 1 when all of its characters were predicted.
    hits = [1 if torch.equal(want, got) else 0
            for want, got in zip(expected, predicted)]
    return sum(hits) / len(hits)
60 |
61 |
def train():
    """Train the captcha CNN and save a checkpoint after every epoch.

    Uses the module-level settings read from the configuration: data paths,
    batch size, learning rate, number of epochs, image size, alphabet and
    captcha length. Each epoch runs one pass over the training loader with
    Adam and ``MultiLabelSoftMarginLoss``, then one evaluation pass over the
    test loader, prints the mean loss and whole-captcha accuracy of both,
    and writes ``model_<epoch>.path`` into ``model_path``.
    """
    transforms = Compose([Resize((height, width)), ToTensor()])
    num_class = len(alphabet)
    num_char = int(numchar)

    train_dataset = CaptchaData(train_data_path, num_class=num_class, num_char=num_char,
                                transform=transforms, alphabet=alphabet)
    train_data_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers,
                                   shuffle=True, drop_last=True)
    test_data = CaptchaData(test_data_path, num_class=num_class, num_char=num_char,
                            transform=transforms, alphabet=alphabet)
    test_data_loader = DataLoader(test_data, batch_size=batch_size,
                                  num_workers=num_workers, shuffle=True, drop_last=True)
    cnn = CNN(num_class=num_class, num_char=num_char, width=width, height=height)
    if use_gpu:
        cnn.cuda()

    optimizer = torch.optim.Adam(cnn.parameters(), lr=base_lr)
    criterion = nn.MultiLabelSoftMarginLoss()

    for epoch in range(max_epoch):
        start_ = time.time()

        # ---- training pass ----
        loss_history = []
        acc_history = []
        cnn.train()
        for img, target in train_data_loader:
            # Tensors are used directly; the Variable wrapper is not needed.
            if use_gpu:
                img = img.cuda()
                target = target.cuda()
            output = cnn(img)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            acc = calculat_acc(output, target)
            acc_history.append(float(acc))
            loss_history.append(float(loss))
        print('epoch:{},train_loss: {:.4}|train_acc: {:.4}'.format(
            epoch,
            torch.mean(torch.Tensor(loss_history)),
            torch.mean(torch.Tensor(acc_history)),
        ))

        # ---- evaluation pass ----
        loss_history = []
        acc_history = []
        cnn.eval()
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for img, target in test_data_loader:
                # Follow the configured device, exactly like the training
                # loop; keying on torch.cuda.is_available() would put the
                # data on the GPU while the model stays on the CPU whenever
                # is_gpu is False on a CUDA machine.
                if use_gpu:
                    img = img.cuda()
                    target = target.cuda()
                output = cnn(img)
                # Compute the loss of THIS batch; otherwise the value logged
                # as test loss is the last training batch's loss.
                loss = criterion(output, target)

                acc = calculat_acc(output, target)
                acc_history.append(float(acc))
                loss_history.append(float(loss))
        print('test_loss: {:.4}|test_acc: {:.4}'.format(
            torch.mean(torch.Tensor(loss_history)),
            torch.mean(torch.Tensor(acc_history)),
        ))
        print('epoch: {}|time: {:.4f}'.format(epoch, time.time() - start_))
        torch.save(cnn.state_dict(), os.path.join(model_path, "model_{}.path".format(epoch)))
124 |
125 |
# Script entry point: run the training loop when this file is executed directly.
if __name__ == "__main__":
    train()
128 |
--------------------------------------------------------------------------------