├── data
├── effb1-train.png
└── res34-train.png
├── .idea
├── inspectionProfiles
│ ├── profiles_settings.xml
│ └── Project_Default.xml
├── workspace.xml
├── modules.xml
├── captcha_ocr.iml
└── misc.xml
├── config.py
├── get_norm.py
├── README.md
├── captcha_dataset.py
├── predict.py
└── train.py
/data/effb1-train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyf-xtu/captcha_ocr/HEAD/data/effb1-train.png
--------------------------------------------------------------------------------
/data/res34-train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyf-xtu/captcha_ocr/HEAD/data/res34-train.png
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/captcha_ocr.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2022/2/16 上午10:28
# @Author : zhangyunfei
# @File : config.py
# @Software: PyCharm
"""
Training configuration constants shared by train.py and predict.py.
"""
# Nominal input image size (height x width).
# NOTE(review): train.py and predict.py actually resize to 256x256 —
# these two values appear unused there; confirm before relying on them.
IMAGE_HEIGHT = 40
IMAGE_WIDTH = 100
# Number of output classes: 4 captcha characters x 62 symbols (0-9, a-z, A-Z).
num_classes = 248
# Mini-batch size used by the DataLoaders.
batch_size = 32
# Number of training epochs.
num_epoch = 200
# Learning rate for the Adam optimizer.
lr = 0.001
# Directory where checkpoints (best model, training record) are written.
checkpoints = 'checkpoints'

# Training / validation / test data directories.
train_dir = 'data/train'
val_dir = 'data/val'
test_dir = 'data/test'

--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
15 |
16 |
17 |
--------------------------------------------------------------------------------
/get_norm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2022/2/16 上午10:41
4 | # @Author : zhangyunfei
5 | # @File : get_norm.py
6 | # @Software: PyCharm
7 | import cv2 as cv
8 | import os
9 | import numpy as np
10 | import torch
11 | import torchvision
12 |
13 |
# Compute per-channel mean and std of an image folder (for transforms.Normalize)
def init_normalize(data_dir, size):
    """Return ([R, G, B] means, [R, G, B] stds), values in [0, 1].

    Averages each image's per-channel mean/std over the folder (a common
    approximation of the true dataset statistics).

    :param data_dir: directory containing the images (flat, no subdirs)
    :param size: [height, width] every image is resized to before measuring
    :raises ValueError: if the directory contains no readable image
    """
    img_h, img_w = size[0], size[1]
    means = [0, 0, 0]
    stdevs = [0, 0, 0]

    num_imgs = 0
    for pic in os.listdir(data_dir):
        img = cv.imread(os.path.join(data_dir, pic))
        if img is None:
            # skip non-image/unreadable files instead of crashing on resize
            continue
        num_imgs += 1
        # cv.resize takes dsize as (width, height) — the original passed
        # (img_h, img_w), silently swapping dimensions for non-square sizes
        img = cv.resize(img, (img_w, img_h))
        img = img.astype(np.float32) / 255.
        for i in range(3):
            means[i] += img[:, :, i].mean()
            stdevs[i] += img[:, :, i].std()

    if num_imgs == 0:
        raise ValueError('no readable images found in {}'.format(data_dir))

    # OpenCV loads BGR; reverse so the result is reported in RGB order
    means.reverse()
    stdevs.reverse()
    means = np.asarray(means) / num_imgs
    stdevs = np.asarray(stdevs) / num_imgs
    print('transforms.Normalize(normMean = {}, normStd = {})'.format(means, stdevs))
    print(list(means), list(stdevs))
    return list(means), list(stdevs)
44 |
45 |
if __name__ == '__main__':
    # Standalone run: print normalization stats for the test images.
    init_normalize('./data/test', [256, 256])
48 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # captcha_ocr
2 | # 2022DCIC-基于文本字符的交易验证码识别-baseline
3 |
4 | ## 比赛地址
5 | https://www.dcic-china.com/competitions/10023
6 | ## 介绍
7 | 验证码作为性价比较高的安全验证方法,在多种场合得到了广泛的应用,有效地防止了机器人进行身份欺骗,其中,以基于文本字符的静态验证码最为常见。随着使用的深入,噪声点、噪声线、重叠、形变等干扰手段层出不穷,不断提升安全防范级别。RPA技术作为企业数字化转型的关键,因为其非侵入式的部署方式备受企业青睐,验证码识别率不高往往限制了RPA技术的应用。一个能同时过滤多种干扰的验证码模型,对于相关自动化技术的拓展使用有着一定的商业价值。
8 | 赛题任务:本次大赛以已标记字符信息的实例字符验证码图像数据为训练样本,参赛选手需基于提供的样本构建模型,对测试集中的字符验证码图像进行识别,提取有效的字符信息。训练数据集不局限于提供的数据,可以加入公开的数据集。
9 |
10 | ## 代码环境
11 | ```bash
12 | baseline是基于pytorch框架的,版本为1.6+
13 | 其他库如OpenCV、torchvision自行安装
14 | 如果使用efficientnet模型,
15 | 安装命令:pip install efficientnet_pytorch
16 | 使用请参考 https://github.com/lukemelas/EfficientNet-PyTorch
17 | ```
18 | ## 数据处理
19 | ```bash
20 | 需要将训练数据解压到data/train
21 | 需要将测试数据解压到data/test
22 | ```
23 | ## 代码介绍
24 | ```bash
25 | captcha_dataset.py
26 | 数据处理、加载
27 | config.py
28 | 训练配置文件
29 | train.py
30 | 训练文件
31 | 模型切换,如果使用resnet34,则需要将下面的注释开,同时注释efficientnet模型
32 | # 使用框架封装好的模型,使用预训练模型resnet34
33 | # model = models.resnet34(pretrained=True)
34 | # # 使用预训练模型需要修改fc层
35 | # num_fcs = model.fc.in_features
36 | # # print(num_fcs)
37 | # model.fc = nn.Sequential(
38 | # nn.Linear(num_fcs, num_classes)
39 | # )
40 | # 使用efficientnet网络模型
41 | model = EfficientNet.from_pretrained('efficientnet-b1', num_classes=248)
42 | 如果使用多卡训练需要设置
43 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2' 指定使用的GPU
44 | 然后取消236、237行的注释
45 | # device_ids = [0, 1, 2]
46 | # model = torch.nn.DataParallel(model, device_ids=device_ids)
47 | predict.py
48 | 预测文件,并生成提交文件
49 | ```
50 | ## 关于分数
51 |
52 | 没有提交成绩,贴出部分训练过程
53 | resnet34 分数88+
54 | 
55 | efficientb1 分数92+
56 | 
57 |
58 | ## 关于上分点
59 | ```bash
60 | 1.使用更好的模型如repvgg、efficientb2-b7等
61 | 2.使用数据增强mixup和cutmix等策略,暂未开源
62 | ```
--------------------------------------------------------------------------------
/captcha_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2022/2/16 上午10:23
4 | # @Author : zhangyunfei
5 | # @File : captcha_dataset.py
6 | # @Software: PyCharm
7 | import os
8 | from PIL import Image
9 | import torch
10 | from torch.utils.data import Dataset
11 |
# Label alphabet: digits, then lowercase, then uppercase (62 symbols total)
source = [str(d) for d in range(10)]
source += [chr(c) for c in range(ord('a'), ord('z') + 1)]
source += [chr(c) for c in range(ord('A'), ord('Z') + 1)]
alphabet = ''.join(source)
17 |
# Read an image from disk and convert it to 3-channel RGB
def img_loader(img_path):
    """Return the image at ``img_path`` as an RGB PIL image.

    ``Image.open`` is lazy and keeps the underlying file handle open; the
    ``with`` block forces it to close once the pixel data has been read
    (``convert`` returns an independent copy), avoiding file-handle leaks
    when the DataLoader opens thousands of images.
    """
    with Image.open(img_path) as img:
        return img.convert('RGB')
22 |
# Build (image_path, one-hot target) samples from a list of image paths
def make_dataset(data_path, alphabet, num_class, num_char):
    """Turn image paths into (path, target) training pairs.

    The label is the last ``num_char`` characters of the file stem
    (e.g. ``.../aB3x.png`` -> ``aB3x``); the target is the concatenation
    of ``num_char`` one-hot vectors of length ``num_class``.

    :param data_path: iterable of image file paths
    :param alphabet: string mapping class index -> character
    :param num_class: number of classes per character position
    :param num_char: number of characters per captcha
    :raises ValueError: if a label character is not in ``alphabet``
    """
    samples = []
    for img_path in data_path:
        # Use splitext instead of split('.png') so any extension works,
        # and num_char instead of a hard-coded 4.
        stem = os.path.splitext(os.path.basename(img_path))[0]
        target_str = stem[-num_char:]
        assert len(target_str) == num_char
        target = []
        for char in target_str:
            vec = [0] * num_class
            # index() raises for unknown characters; the original find()
            # returned -1 and set the *last* slot, silently corrupting labels
            vec[alphabet.index(char)] = 1
            target += vec
        samples.append((img_path, target))
    print(len(samples))
    return samples
37 |
# Dataset over captcha images and their one-hot encoded labels
class CaptchaData(Dataset):
    """torch Dataset yielding ``(image, target_tensor)`` pairs.

    Targets are flat one-hot vectors of length ``num_class * num_char``.
    """

    def __init__(self, data_path, num_class=62, num_char=4,
                 transform=None, target_transform=None, alphabet=alphabet):
        super(Dataset, self).__init__()
        self.data_path = data_path
        self.num_class = num_class
        self.num_char = num_char
        self.transform = transform
        self.target_transform = target_transform
        self.alphabet = alphabet
        # Pre-compute the full (path, one-hot target) list once up front.
        self.samples = make_dataset(self.data_path, self.alphabet,
                                    self.num_class, self.num_char)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, target = self.samples[index]
        image = img_loader(path)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, torch.Tensor(target)
64 |
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2022/2/16 上午10:27
4 | # @Author : zhangyunfei
5 | # @File : predict.py
6 | # @Software: PyCharm
7 | import os
8 | import cv2 as cv
9 | import torch
10 | from PIL import Image
11 | import torch.nn as nn
12 | from torchvision import transforms
13 | from get_norm import init_normalize
14 | import csv
15 |
# 62-symbol label alphabet: digits, then lowercase, then uppercase
source = [str(d) for d in range(10)]
source += [chr(c) for c in range(ord('a'), ord('z') + 1)]
source += [chr(c) for c in range(ord('A'), ord('Z') + 1)]
alphabet = ''.join(source)
20 |
21 |
def predict():
    """Run the trained model over data/test and write sub/submit_021601.csv.

    Loads the fully serialized model from checkpoints, normalizes test
    images with statistics computed on the test set itself, and decodes
    the (1, 248) output into a 4-character string per image.
    """
    model_path = './checkpoints/best_model.pth'
    test_dir = './data/test'
    test_mean, test_std = init_normalize(test_dir, size=[256, 256])
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=test_mean,
            std=test_std
        )
    ])
    print(torch.cuda.is_available())
    # Run on GPU when available, otherwise CPU — the original called
    # .cuda() unconditionally and crashed on CPU-only machines despite
    # checking availability above.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = torch.load(model_path, map_location=device)
    model = model.to(device)
    # eval() freezes BatchNorm/Dropout; without it the network predicts
    # in training mode and the outputs are wrong
    model.eval()

    images = os.listdir(test_dir)
    images.sort(key=lambda x: int(x[:-4]))  # filenames are '<num>.png'
    res = []
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for img in images:
            img_path = os.path.join(test_dir, img)
            image_read = Image.open(img_path)
            gray = image_read.convert('RGB')
            gray = transform(gray)
            image = gray.view(1, 3, 256, 256).to(device)
            output = model(image)
            # (1, 248) -> (4, 62): one row of class scores per character
            output = output.view(-1, 62)
            output = nn.functional.softmax(output, dim=1)
            output = torch.argmax(output, dim=1)
            output = output.view(-1, 4)[0]
            pred = ''.join([alphabet[i] for i in output.cpu().numpy()])
            print(img, pred)
            res.append({'num': int(img[:-4]), 'tag': pred})

    header = ['num', 'tag']
    os.makedirs('sub', exist_ok=True)
    # newline='' prevents the csv module from emitting blank rows on Windows
    with open('sub/submit_021601.csv', 'w', newline='', encoding='utf_8_sig') as f:
        f_csv = csv.DictWriter(f, header)
        f_csv.writeheader()
        f_csv.writerows(res)


if __name__ == '__main__':
    predict()
68 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2022/2/16 上午10:25
4 | # @Author : zhangyunfei
5 | # @File : train.py
6 | # @Software: PyCharm
7 | import os
8 | import json
9 | import torch
10 | from torchvision import models, transforms
11 | import torch.nn as nn
12 | import torch.optim as optim
13 | from captcha_dataset import CaptchaData
14 | from torch.utils.data import DataLoader
15 | import config
16 | import time
17 | import random
18 | import numpy as np
19 | from get_norm import init_normalize
20 | from efficientnet_pytorch import EfficientNet
21 |
22 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
23 |
24 |
# Split a sample list into training and validation subsets (90 / 10)
def split_data(files):
    """Shuffle ``files`` in place, then split 90% train / 10% val.

    :param files: list of sample paths (mutated by the shuffle)
    :return: (train_list, val_list)
    """
    random.shuffle(files)
    cut = int(len(files) * 0.9)  # boundary index between train and val
    return files[:cut], files[cut:]
38 |
39 |
# Deterministically shuffle a sample list
def random_data(files):
    """Shuffle ``files`` in place with a fixed seed (2022) and return it.

    Re-seeding on every call makes the ordering identical run to run.
    """
    random.seed(2022)
    random.shuffle(files)
    return files
46 |
47 |
# Compute exact-match accuracy over a batch of multi-character predictions
def calculat_acc(output, target, num_class=62, num_char=4):
    """Fraction of samples whose *entire* character sequence is correct.

    Generalizes the previously hard-coded 62 classes / 4 characters into
    backward-compatible keyword parameters.

    :param output: raw model scores, shape (batch, num_char * num_class)
    :param target: one-hot labels, same shape as ``output``
    :param num_class: symbols per character position (default 62)
    :param num_char: characters per captcha (default 4)
    :return: accuracy in [0, 1] as a Python float
    """
    # Per-character class decisions
    output = output.view(-1, num_class)
    target = target.view(-1, num_class)
    # softmax is monotonic, so it does not change the argmax; kept to
    # mirror the inference path
    output = torch.argmax(nn.functional.softmax(output, dim=1), dim=1)
    target = torch.argmax(target, dim=1)
    # Regroup into per-sample sequences: a sample only counts as correct
    # when all num_char characters match
    output = output.view(-1, num_char)
    target = target.view(-1, num_char)
    correct_list = [1 if torch.equal(t, o) else 0 for t, o in zip(target, output)]
    if not correct_list:
        return 0.0  # empty batch: avoid ZeroDivisionError
    return sum(correct_list) / len(correct_list)
63 |
64 |
# Seed every RNG in use so training runs are reproducible
def seed_it(seed):
    """Seed Python, NumPy and PyTorch (CPU + CUDA) random generators.

    :param seed: integer seed applied to every generator
    """
    random.seed(seed)
    # Python reads PYTHONHASHSEED — the original set "PYTHONSEED", which
    # nothing consumes. (Only affects child interpreters; set regardless.)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN autotune non-deterministic kernels, which
    # defeats the purpose of this function — it must be off.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = True
76 |
77 |
# Train the model, validating every epoch and checkpointing the best weights
def train(model, loss_func, optimizer, checkpoints, epochs, lr_scheduler=None):
    """Run the full training loop.

    Relies on the module-level globals ``train_data``, ``val_data`` and
    ``device`` defined in the ``__main__`` block of this file.

    :param model: network to optimize (already moved to ``device``)
    :param loss_func: criterion taking (outputs, labels)
    :param optimizer: optimizer over ``model.parameters()``
    :param checkpoints: directory for best_model.pth and record.txt
    :param epochs: number of epochs to run
    :param lr_scheduler: optional scheduler, stepped once per batch.
        NOTE(review): per-batch stepping suits warm-up/OneCycle schedulers;
        an epoch-based scheduler would need the step moved out of the loop.
    """
    print('Train......................')
    # One row per epoch: [epoch, train_loss, train_acc, val_loss, val_acc]
    record = []
    best_acc = 0
    best_epoch = 0
    # The original `range(1, epochs)` ran only epochs-1 iterations;
    # `epochs + 1` fixes the off-by-one so all configured epochs run.
    for epoch in range(1, epochs + 1):
        start_time = time.time()
        model.train()  # re-enable train mode (eval() is set during validation)
        train_loss, train_acc, val_loss, val_acc = [], [], [], []
        for i, (inputs, labels) in enumerate(train_data):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = loss_func(outputs, labels)
            # Gradients accumulate by default; clear them every step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            acc = calculat_acc(outputs, labels)
            train_acc.append(float(acc))
            train_loss.append(float(loss))
            if lr_scheduler:
                lr_scheduler.step()
        # Validation pass: no gradients, eval mode
        with torch.no_grad():
            model.eval()
            for i, (inputs, labels) in enumerate(val_data):
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                loss = loss_func(outputs, labels)
                acc = calculat_acc(outputs, labels)
                val_acc.append(float(acc))
                val_loss.append(float(loss))

        # Per-epoch averages of the batch statistics
        train_loss_epoch = torch.mean(torch.Tensor(train_loss))
        train_acc_epoch = torch.mean(torch.Tensor(train_acc))
        val_loss_epoch = torch.mean(torch.Tensor(val_loss))
        val_acc_epoch = torch.mean(torch.Tensor(val_acc))
        record.append(
            [epoch, train_loss_epoch.item(), train_acc_epoch.item(), val_loss_epoch.item(), val_acc_epoch.item()])
        end_time = time.time()
        print(
            'epoch:{} | time:{:.4f} | train_loss:{:.4f} | train_acc:{:.4f} | eval_loss:{:.4f} | val_acc:{:.4f}'.format(
                epoch,
                end_time - start_time,
                train_loss_epoch,
                train_acc_epoch,
                val_loss_epoch,
                val_acc_epoch))

        # Keep the checkpoint with the best validation accuracy so far
        best_model_path = os.path.join(checkpoints, 'best_model.pth')
        if val_acc_epoch >= best_acc:
            best_acc = val_acc_epoch
            best_epoch = epoch
            torch.save(model, best_model_path)
        print('Best Accuracy for Validation :{:.4f} at epoch {:d}'.format(best_acc, best_epoch))
    # Persist the per-epoch history for later inspection/plotting
    record_json = json.dumps(record)
    with open(os.path.join(checkpoints, 'record.txt'), 'w+', encoding='utf8') as ff:
        ff.write(record_json)
164 |
165 |
if __name__ == '__main__':
    # Optionally fix all RNG seeds for reproducibility
    # seed_it(2022)
    # Number of output classes (4 characters x 62 symbols)
    num_classes = config.num_classes
    # Mini-batch size
    batch_size = config.batch_size  # config.batch_size
    # Number of training epochs
    epochs = config.num_epoch  # config.num_epoch
    # Learning rate
    lr = config.lr
    # Directory where checkpoints are written
    checkpoints = config.checkpoints
    if not os.path.exists(checkpoints):
        os.makedirs(checkpoints)
    # Training and test data locations
    train_dir = config.train_dir
    test_dir = config.test_dir
    # Per-channel mean/std of the training images, used by Normalize below
    train_mean, train_std = init_normalize(train_dir, size=[256, 256])
    # Image transforms
    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),  # rescale
        transforms.RandomRotation((-5, 5)),  # small random rotation (augmentation)
        # transforms.RandomVerticalFlip(p=0.2),  # random flip (disabled)
        transforms.ToTensor(),  # convert to tensor
        transforms.Normalize(
            mean=train_mean,
            std=train_std
        )
    ])
    val_transform = transforms.Compose([
        transforms.Resize((256, 256)),  # rescale
        transforms.ToTensor(),  # convert to tensor
        transforms.Normalize(
            mean=train_mean,
            std=train_std
        )
    ])
    # Collect the training image paths
    files = os.listdir(train_dir)
    img_paths = []
    for img in files:
        img_path = os.path.join(train_dir, img)
        img_paths.append(img_path)
    # Split into training and validation subsets
    train_paths, val_paths = split_data(img_paths)
    # Wrap the training split in a Dataset
    train_dataset = CaptchaData(train_paths, transform=train_transform)
    # DataLoader for training
    # NOTE(review): num_workers=32/16 assumes a many-core machine; lower
    # these on smaller boxes.
    train_data = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=32)
    # Wrap the validation split in a Dataset
    val_dataset = CaptchaData(val_paths, transform=val_transform)
    val_data = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=True, num_workers=16)
    print('训练集数量:%s 验证集数量:%s' % (train_dataset.__len__(), val_dataset.__len__()))

    # Alternative: torchvision pretrained resnet34
    # model = models.resnet34(pretrained=True)
    # # the fc layer must be replaced to match our class count
    # num_fcs = model.fc.in_features
    # # print(num_fcs)
    # model.fc = nn.Sequential(
    #     nn.Linear(num_fcs, num_classes)
    # )
    # EfficientNet-b1 with a 248-way head
    model = EfficientNet.from_pretrained('efficientnet-b1', num_classes=248)
    # print(model)
    # Use the GPU when available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    # Multi-GPU data-parallel training (uncomment to enable):
    # device_ids = [0, 1, 3, 4, 5]
    # model = torch.nn.DataParallel(model, device_ids=device_ids)

    # Loss: multi-label soft margin over the concatenated one-hot targets
    loss_func = nn.MultiLabelSoftMarginLoss()
    # Adam optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Train
    train(model, loss_func, optimizer, checkpoints, epochs)
246 |
--------------------------------------------------------------------------------