├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── ImageCaptchaOCR.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
├── development
│   ├── __pycache__
│   │   ├── gen_ImageCaptcha.cpython-38.pyc
│   │   ├── model.cpython-38.pyc
│   │   └── one_hot.cpython-38.pyc
│   ├── gen_ImageCaptcha.py
│   ├── model.py
│   ├── one_hot.py
│   ├── predict.py
│   ├── pth2onnx.py
│   └── server.py
├── requirement.txt
└── 英数验证码数据集.zip
/.gitignore:
--------------------------------------------------------------------------------
./development/pth2onnx.py
./英数验证码数据集.zip
/dataset/
/deploy/
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
--------------------------------------------------------------------------------
/.idea/ImageCaptchaOCR.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
@[TOC](Training a PyTorch CNN alphanumeric captcha recognition model)

### Android reverse engineering, JS reverse engineering, image recognition, orders taken online. Full source code + deployment + algorithm: contact QQ: 27788854, WeChat: taisuivip, [telegram: rtais00](https://t.me/rtais00)
### WeChat official account: R逆向
# Captcha style
#### Problem case:
Project on GitHub: [https://github.com/taisuii/ImageCaptchaOCR](https://github.com/taisuii/ImageCaptchaOCR)
The captchas typically consist of English letters and digits
![captcha sample](https://i-blog.csdnimg.cn/direct/57d9b28e20ad4d3c94b12e7f22b292e3.png)
#### Solution / approach:
Build a convolutional neural network whose output is a tensor of size 5 * (26 + 26 + 10) = 310, then decode that tensor back into letters and digits.
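In other words, the 310 outputs are read as 5 groups of 62 scores, and the argmax of each group picks one character. A minimal decoding sketch (the character order matches gen_ImageCaptcha.captcha_array shown later in the repo):
```python
import torch

# 62-character alphabet: digits + lowercase + uppercase, as in gen_ImageCaptcha.py
charset = list("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")

def decode(output: torch.Tensor) -> str:
    # Reshape the flat 5 * 62 = 310 scores to (5, 62), take the argmax of
    # each row and map it back to a character.
    scores = output.view(5, len(charset))
    return "".join(charset[i] for i in scores.argmax(dim=1))

print(decode(torch.randn(310)))  # e.g. a random 5-character string
```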
# Dataset preparation
#### Organizing the dataset:
First follow the WeChat official account R逆向 and reply "验证码数据集" to get the dataset.
Use two folders: train holds the training set and test holds the test set, laid out as shown below; each image's label is simply its file name.
![dataset folder layout](https://i-blog.csdnimg.cn/direct/de310d5d34ab470d9b6d41655d9b8792.png)
#### Custom dataset:
```python
class mydatasets(Dataset):
    def __init__(self, root_dir):
        super(mydatasets, self).__init__()
        self.list_image_path = [os.path.join(root_dir, image_name) for image_name in os.listdir(root_dir)]
        self.transforms = transforms.Compose([
            transforms.ToTensor()
        ])

    def __getitem__(self, index):
        image_path = self.list_image_path[index]
        img_ = Image.open(image_path)
        image_name = image_path.split("\\")[-1]
        img_tesor = self.transforms(img_)
        img_lable = image_name.split(".")[0]
        img_lable = one_hot.text2vec(img_lable)
        img_lable = img_lable.view(1, -1)[0]
        return img_tesor, img_lable

    def __len__(self):
        return self.list_image_path.__len__()
```
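A minimal usage sketch (assuming the ../dataset/train layout described above and the same imports as development/model.py):
```python
from torch.utils.data import DataLoader

# Quick check of the dataset wrapper: one batch of image tensors plus their
# flattened one-hot labels (5 characters * 62 classes = 310 values each).
train_datas = mydatasets("../dataset/train")
train_dataloader = DataLoader(train_datas, batch_size=64, shuffle=True)
imgs, labels = next(iter(train_dataloader))
print(imgs.shape)    # e.g. torch.Size([64, 3, 60, 160]) for the default ImageCaptcha size
print(labels.shape)  # torch.Size([64, 310])
```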
# Model training
#### Model code:
Just stack a few layers; the important part is the input and output sizes. gen_ImageCaptcha.captcha_size * len(gen_ImageCaptcha.captcha_array) is exactly 5 * (26 + 26 + 10).
```python
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),  # batch normalization
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        # global average pooling instead of Flatten
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(512, 4096),
            nn.Dropout(0.2),
            nn.ReLU(),
            nn.Linear(4096, gen_ImageCaptcha.captcha_size * len(gen_ImageCaptcha.captcha_array))
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.global_avg_pool(x)  # global average pooling
        x = x.view(x.size(0), -1)    # flatten
        x = self.fc(x)
        return x
```
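A quick shape sanity check, as a sketch (the dummy input size is just an example; global average pooling makes the head independent of the spatial size):
```python
import torch

# The network should emit one 310-dimensional score vector per image:
# 5 character positions * 62 possible characters.
model = MyModel()
dummy = torch.randn(1, 3, 60, 160)  # one RGB image, example size
out = model(dummy)
print(out.shape)  # torch.Size([1, 310])
```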
#### Training code
This is standard classification-style training; here the model is trained for 24 epochs.
```python
def train(epoch):
    train_datas = mydatasets("../dataset/train")
    train_dataloader = DataLoader(train_datas, batch_size=64, shuffle=True)
    m = MyModel().cuda()

    loss_fn = nn.MultiLabelSoftMarginLoss().cuda()
    optimizer = torch.optim.Adam(m.parameters(), lr=0.001)
    epoch_losses = []
    for i in range(epoch):
        losses = []
        # progress bar over the dataloader
        data_loader_tqdm = tqdm(train_dataloader)

        epoch_loss = 0
        for inputs, labels in data_loader_tqdm:
            inputs = inputs.cuda()
            labels = labels.cuda()
            optimizer.zero_grad()

            outputs = m(inputs)
            loss = loss_fn(outputs, labels)
            losses.append(loss.item())
            epoch_loss = np.mean(losses)
            data_loader_tqdm.set_description(
                f"This epoch is {str(i + 1)} and its loss is {loss.item()}, average loss {epoch_loss}"
            )

            loss.backward()
            optimizer.step()
        epoch_losses.append(epoch_loss)
        # save a checkpoint after every epoch
        torch.save(m.state_dict(), f'../deploy/model/{str(i + 1)}_{epoch_loss}.pth')


if __name__ == '__main__':
    train(24)
```
The loss curve looks like this:
![loss curve](https://i-blog.csdnimg.cn/direct/dc9d9922bdd641e7bba2ae6a167a4e4a.png)

# Invocation
#### Model prediction:
Convert the model output back into characters, i.e. decode one character from every 62 values.
```python
def test_pred():
    m = MyModel()
    m.load_state_dict(torch.load("../deploy/model/22_0.0007106820558649762.pth"))
    m.eval()
    test_data = mydatasets("../dataset/test")

    test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False)
    test_length = test_data.__len__()

    correct = 0

    for i, (imgs, lables) in enumerate(test_dataloader):
        imgs = imgs
        lables = lables
        lables = lables.view(-1, gen_ImageCaptcha.captcha_array.__len__())
        lables_text = one_hot.vectotext(lables)
        start_time = time.time()
        predict_outputs = m(imgs)
        predict_outputs = predict_outputs.view(-1, gen_ImageCaptcha.captcha_array.__len__())
        predict_labels = one_hot.vectotext(predict_outputs)
        print(time.time() - start_time)
        if predict_labels == lables_text:
            correct += 1
        else:
            pass
    print("Accuracy: {}".format(correct / test_length * 100))


if __name__ == '__main__':
    test_pred()
```
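For a single image outside the test loader, the same decode step applies; a minimal sketch (captcha.png is a placeholder path, and the checkpoint name is the one used above):
```python
import torch
from PIL import Image
from torchvision import transforms

import one_hot
import development.gen_ImageCaptcha as gen_ImageCaptcha
from model import MyModel

# One-off prediction on a single captcha image.
m = MyModel()
m.load_state_dict(torch.load("../deploy/model/22_0.0007106820558649762.pth"))
m.eval()

img = Image.open("captcha.png").convert("RGB")  # placeholder path
x = transforms.ToTensor()(img).unsqueeze(0)     # add a batch dimension
with torch.no_grad():
    out = m(x).view(-1, len(gen_ImageCaptcha.captcha_array))
print(one_hot.vectotext(out))
```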
#### Recognition speed and average success rate:
![results](https://i-blog.csdnimg.cn/direct/bd5480d5d23b43b0a0b0c6f9b60b6ab6.png)

![results](https://i-blog.csdnimg.cn/direct/1e0f43d4e7fa459cb30a6d3db0bdb687.png)

![results](https://i-blog.csdnimg.cn/direct/2e0b69513dca4972ab621910bbe99d69.png)
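#### Calling the HTTP service:
development/server.py wraps the model in a small Flask service on 127.0.0.1:19199. A client-side sketch (captcha.png is a placeholder; the requests library is an extra dependency):
```python
import base64
import requests

# Base64-encode the captcha image and post it to the endpoint
# defined in development/server.py.
with open("captcha.png", "rb") as f:
    image_base64 = base64.b64encode(f.read()).decode()

resp = requests.post(
    "http://127.0.0.1:19199/runtime/text/invoke",
    json={"project_name": "ctc_en5l_240516", "image": image_base64},
)
print(resp.json())  # {"data": "<predicted text>"}
```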
--------------------------------------------------------------------------------
/development/__pycache__/gen_ImageCaptcha.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taisuii/ImageCaptchaOCR/a838c3a51355fcefc748205a1ab5976598749ac4/development/__pycache__/gen_ImageCaptcha.cpython-38.pyc
--------------------------------------------------------------------------------
/development/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taisuii/ImageCaptchaOCR/a838c3a51355fcefc748205a1ab5976598749ac4/development/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/development/__pycache__/one_hot.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taisuii/ImageCaptchaOCR/a838c3a51355fcefc748205a1ab5976598749ac4/development/__pycache__/one_hot.cpython-38.pyc
--------------------------------------------------------------------------------
/development/gen_ImageCaptcha.py:
--------------------------------------------------------------------------------
import os
import random
import time

captcha_array = list("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
captcha_size = 5
from captcha.image import ImageCaptcha

if __name__ == '__main__':
    print(captcha_array)
    image = ImageCaptcha()
    for i in range(1):
        image_val = "".join(random.sample(captcha_array, 4))
        image_name = "./{}_{}.png".format(image_val, int(time.time()))
        print(image_name)
        image.write(image_val, image_name)
--------------------------------------------------------------------------------
/development/model.py:
--------------------------------------------------------------------------------
import torch
from matplotlib import pyplot as plt
from torch import nn
import development.gen_ImageCaptcha as gen_ImageCaptcha
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import one_hot
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np


class mydatasets(Dataset):
    def __init__(self, root_dir):
        super(mydatasets, self).__init__()
        self.list_image_path = [os.path.join(root_dir, image_name) for image_name in os.listdir(root_dir)]
        self.transforms = transforms.Compose([
            transforms.ToTensor()
        ])

    def __getitem__(self, index):
        image_path = self.list_image_path[index]
        img_ = Image.open(image_path)
        image_name = image_path.split("\\")[-1]
        img_tesor = self.transforms(img_)
        img_lable = image_name.split(".")[0]
        img_lable = one_hot.text2vec(img_lable)
        img_lable = img_lable.view(1, -1)[0]
        return img_tesor, img_lable

    def __len__(self):
        return self.list_image_path.__len__()


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),  # batch normalization
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        # global average pooling instead of Flatten
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(512, 4096),
            nn.Dropout(0.2),
            nn.ReLU(),
            nn.Linear(4096, gen_ImageCaptcha.captcha_size * len(gen_ImageCaptcha.captcha_array))
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.global_avg_pool(x)  # global average pooling
        x = x.view(x.size(0), -1)    # flatten
        x = self.fc(x)
        return x


def train(epoch):
    train_datas = mydatasets("../dataset/train")
    train_dataloader = DataLoader(train_datas, batch_size=64, shuffle=True)
    m = MyModel().cuda()

    loss_fn = nn.MultiLabelSoftMarginLoss().cuda()
    optimizer = torch.optim.Adam(m.parameters(), lr=0.001)
    epoch_losses = []
    for i in range(epoch):
        losses = []
        # progress bar over the dataloader
        data_loader_tqdm = tqdm(train_dataloader)

        epoch_loss = 0
        for inputs, labels in data_loader_tqdm:
            inputs = inputs.cuda()
            labels = labels.cuda()
            optimizer.zero_grad()

            outputs = m(inputs)
            loss = loss_fn(outputs, labels)
            losses.append(loss.item())
            epoch_loss = np.mean(losses)
            data_loader_tqdm.set_description(
                f"This epoch is {str(i + 1)} and its loss is {loss.item()}, average loss {epoch_loss}"
            )

            loss.backward()
            optimizer.step()
        epoch_losses.append(epoch_loss)
        # save a checkpoint after every epoch
        torch.save(m.state_dict(), f'../deploy/model/{str(i + 1)}_{epoch_loss}.pth')

    # plot the loss curve over the epochs
    data = np.array(epoch_losses)
    plt.figure(figsize=(10, 6))
    plt.plot(data)
    plt.title(f"{epoch} epoch loss change")
    plt.xlabel("epoch")
    plt.ylabel("Loss")
    # show the figure
    plt.show()
    print("completed. Model saved.")


if __name__ == '__main__':
    train(24)
--------------------------------------------------------------------------------
/development/one_hot.py:
--------------------------------------------------------------------------------
import development.gen_ImageCaptcha as gen_ImageCaptcha
import torch


def text2vec(text):
    vectors = torch.zeros((gen_ImageCaptcha.captcha_size, gen_ImageCaptcha.captcha_array.__len__()))
    for i in range(len(text)):
        vectors[i, gen_ImageCaptcha.captcha_array.index(text[i])] = 1
    return vectors


def vectotext(vec):
    vec = torch.argmax(vec, dim=1)

    text_label = ""
    for v in vec:
        text_label += gen_ImageCaptcha.captcha_array[v]
    return text_label


if __name__ == '__main__':
    vec = text2vec("aaabv")
    print(vec.shape)
    print(vectotext(vec))
--------------------------------------------------------------------------------
/development/predict.py:
--------------------------------------------------------------------------------
import time

import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
import one_hot
import torch
import development.gen_ImageCaptcha as gen_ImageCaptcha
from torchvision import transforms
from model import mydatasets, MyModel


def test_pred():
    m = MyModel()
    m.load_state_dict(torch.load("../deploy/model/22_0.0007106820558649762.pth"))
    m.eval()
    test_data = mydatasets("../dataset/test")

    test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False)
    test_length = test_data.__len__()

    correct = 0
    spendtime = []
    for i, (imgs, lables) in enumerate(test_dataloader):
        imgs = imgs
        lables = lables
        lables = lables.view(-1, gen_ImageCaptcha.captcha_array.__len__())
        lables_text = one_hot.vectotext(lables)
        start_time = time.time()
        predict_outputs = m(imgs)
        predict_outputs = predict_outputs.view(-1, gen_ImageCaptcha.captcha_array.__len__())
        predict_labels = one_hot.vectotext(predict_outputs)
        spendtime.append(time.time() - start_time)
        print(predict_labels, lables_text)
        if predict_labels == lables_text:
            correct += 1
        else:
            pass

    correct_ = correct / test_length * 100
    data = np.array(spendtime)
    plt.figure(figsize=(10, 6))
    plt.plot(data)
    plt.title(f"verify spend time, average spend: {np.mean(spendtime)} and Success rate: {correct_}")
    plt.ylabel("time")
    # show the figure
    plt.show()


if __name__ == '__main__':
    test_pred()
--------------------------------------------------------------------------------
/development/pth2onnx.py:
--------------------------------------------------------------------------------
from model import MyModel
import torch


def convert():
    # load the PyTorch model
    model_path = "model/resnet18_38_0.021147585306924.pth"
    model = MyModel()
    model.load_state_dict(torch.load(model_path))
    model.eval()
    # create an example input
    dummy_input = torch.randn(10, 3, 224, 224)
    # export the model to ONNX format
    torch.onnx.export(model, dummy_input, "model/resnet18.onnx", verbose=True)


if __name__ == '__main__':
    convert()
--------------------------------------------------------------------------------
/development/server.py:
--------------------------------------------------------------------------------
from flask import Flask, request, jsonify
import torch
from PIL import Image
import io
import base64
import torchvision.transforms as transforms

from development import gen_ImageCaptcha, one_hot
from development.model import MyModel

# load the model
m = MyModel()
m.load_state_dict(torch.load("../deploy/model/22_0.0007106820558649762.pth"))
m.eval()

# Flask app initialization
app = Flask(__name__)

# preprocessing steps (customize if necessary)
data_transform = transforms.Compose([
    transforms.ToTensor(),
])


# OCR handler
def ocr(image_bytes):
    # turn the raw bytes into a PIL image
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')

    # image preprocessing
    imgs = data_transform(image)

    # model inference
    predict_outputs = m(imgs.unsqueeze(0))  # add the batch dimension
    predict_outputs = predict_outputs.view(-1, gen_ImageCaptcha.captcha_array.__len__())

    # convert to a readable label
    predict_labels = one_hot.vectotext(predict_outputs)

    return predict_labels


# route handler
@app.route("/runtime/text/invoke", methods=["POST"])
def invoke_ocr():
    # extract the JSON payload from the request
    req_data = request.get_json()

    # extract the project name and the Base64-encoded image
    project_name = req_data.get("project_name", "")
    image_base64 = req_data.get("image", "")

    if project_name != "ctc_en5l_240516":
        return jsonify({"error": "Invalid project name"}), 400

    # decode the Base64 image
    image_bytes = base64.b64decode(image_base64)

    # run OCR
    try:
        ocr_result = ocr(image_bytes)
        return jsonify({"data": ocr_result})
    except Exception as e:
        return jsonify({"error": str(e)}), 500


# start the Flask app
if __name__ == "__main__":
    app.run(host="127.0.0.1", port=19199)
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taisuii/ImageCaptchaOCR/a838c3a51355fcefc748205a1ab5976598749ac4/requirement.txt
--------------------------------------------------------------------------------
/英数验证码数据集.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taisuii/ImageCaptchaOCR/a838c3a51355fcefc748205a1ab5976598749ac4/英数验证码数据集.zip
--------------------------------------------------------------------------------