├── rotate_model.pth ├── README.md ├── rotate_all_image.py ├── example.py ├── train_torch.py └── label.py /rotate_model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/8yteDance/RotateCaptcha/HEAD/rotate_model.pth -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 项目介绍 2 | 3 | 本项目主要用于介绍一种旋转类验证码基于CNN识别的通用方法,包含双图和单图,并基于某书的双图训练了一个误差极小的双图旋转验证码识别模型。 4 | 5 | ![image](https://github.com/8yteDance/RotateCaptcha/assets/164896531/961218cd-746e-4d28-8786-601eaf283ba7) 6 | 7 | 如果想直接使用某书模型,直接参考`example.py`,不用继续向下看。 8 | 9 | # 自行训练操作步骤 10 | 11 | ## 第一步:准备训练语料 12 | 13 | 将尽可能多的验证码图片下载下来,并使用标注系统`label.py`(需要自行修改代码中的路径,如果是单图旋转的话,也需要自行修改代码)旋转到正确角度(0度)。 14 | 15 | 然后将旋转到正确角度的代码进行去重,我为了方便使用了Duplicate Cleaner 5,将多余的重复图片删除掉,某书大概能剩余个五十来张不重复素材。 16 | 17 | 将所有图片分别旋转到对应360度图片,并保存到硬盘,参考`rotate_all_image.py`。 18 | 19 | 不需要手动分测试集语料,训练代码会自动分测试集。 20 | 21 | ## 第二步:开始训练 22 | 23 | 修改`train_torch.py`中第`60行`代码处的路径,将其指向到你的图片路径。 24 | 25 | ## 第三步:使用模型预测 26 | 27 | 参考`example.py` 28 | 29 | ## 关于训练的一些技巧 30 | 31 | 第一次训练可以把`69行`的`num_epochs`调整大一些,比如200,然后观察训练时打印的loss数值,找到一个loss最低的轮数,修改num_epochs为上次loss值最低的轮数重新跑一次,基本上训练出来效果就会非常好 32 | 33 | # 结尾 34 | 35 | 这些代码全是ChatGPT写的,我只是给ChatGPT提供了靠谱的思路,这些代码你们也可以让ChatGPT给你们写,不过记得用英文提问,这样ChatGPT的理解能力会有巨大提升,能写出来靠谱的代码~ 36 | -------------------------------------------------------------------------------- /rotate_all_image.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | 4 | def rotate_image(image_path, degrees_to_rotate, output_path): 5 | # 打开图像 6 | image = Image.open(image_path) 7 | 8 | print(degrees_to_rotate) 9 | # 旋转图像 10 | rotated_image = image.rotate(-degrees_to_rotate) 11 | 12 | # 保存旋转后的图像 13 | rotated_image.save(output_path) 14 | 15 | print("图像已旋转并保存到:", output_path) 16 | 17 | def rotate_and_save_images(input_dir, output_dir, degree_step=1): 18 | # 确保输出目录存在 19 | os.makedirs(output_dir, exist_ok=True) 20 | 21 | # 遍历输入目录下的所有文件 22 | for filename in os.listdir(input_dir): 23 | if filename.endswith(".png") or filename.endswith(".jpg"): 24 | # 打开图像 25 | image_path = os.path.join(input_dir, filename) 26 | image = Image.open(image_path) 27 | 28 | # 生成旋转的图像并保存 29 | for degree in range(0, 360, degree_step): 30 | rotated_image = image.rotate(degree) 31 | output_filename = f"{os.path.splitext(filename)[0]}_{degree}.png" 32 | output_path = os.path.join(output_dir, output_filename) 33 | rotated_image.save(output_path) 34 | print(f"已生成并保存旋转 {degree} 度的图像:{output_path}") 35 | 36 | 37 | 38 | 39 | 40 | if __name__ == '__main__': 41 | 42 | # 输入目录和输出目录 43 | input_directory = "temp" 44 | output_directory = "360" 45 | 46 | # 调用函数生成旋转的图像并保存 47 | rotate_and_save_images(input_directory, output_directory) 48 | 49 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | import torch 5 | import torch.nn as nn 6 | from torchvision import transforms 7 | from PIL import Image 8 | 9 | # Define the CNN architecture 10 | class CNN(nn.Module): 11 | def __init__(self): 12 | super(CNN, self).__init__() 13 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1) 14 | self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 15 | self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 16 | self.pool = nn.MaxPool2d(2, 2) 17 | self.fc1 = nn.Linear(64 * 5 * 5, 128) 18 | self.fc2 = nn.Linear(128, 360) # 360 classes for 0-358 degrees rotation 19 | 20 | def forward(self, x): 21 | x = self.pool(torch.relu(self.conv1(x))) 22 | x = self.pool(torch.relu(self.conv2(x))) 23 | x = self.pool(torch.relu(self.conv3(x))) 24 | x = x.view(-1, 64 * 5 * 5) 25 | x = torch.relu(self.fc1(x)) 26 | x = self.fc2(x) 27 | return x 28 | 29 | # Load the trained model 30 | model = CNN() 31 | model.load_state_dict(torch.load('rotate_model.pth')) 32 | model.eval() 33 | 34 | # Define transformations for input images 35 | transform = transforms.Compose([ 36 | transforms.Resize((40, 40)), 37 | transforms.ToTensor(), 38 | ]) 39 | 40 | # Function to predict rotation angle 41 | def predict_rotation_angle(image_path, model, transform): 42 | # Open and preprocess the image 43 | image = Image.open(image_path).convert('RGB') 44 | image = transform(image).unsqueeze(0) # Add batch dimension 45 | 46 | # Make prediction 47 | with torch.no_grad(): 48 | output = model(image) 49 | _, predicted = torch.max(output, 1) 50 | predicted_angle = predicted.item() # Get the predicted angle 51 | return predicted_angle 52 | 53 | if __name__ == '__main__': 54 | 55 | for name in random.sample(os.listdir("360"), 10): 56 | 57 | image_path = f'360/{name}' 58 | angle = image_path.split("_")[2].split(".")[0] 59 | 60 | # Predict rotation angle 61 | ts = time.time() 62 | predicted_angle = predict_rotation_angle(image_path, model, transform) 63 | print("路径:", image_path,", 真实角度:" , angle,", 预测角度:", predicted_angle, "耗时:",time.time()-ts) 64 | 65 | -------------------------------------------------------------------------------- /train_torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.utils.data import DataLoader, Dataset 5 | from torchvision import transforms 6 | from PIL import Image 7 | import os 8 | 9 | # Define the CNN architecture 10 | class CNN(nn.Module): 11 | def __init__(self): 12 | super(CNN, self).__init__() 13 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1) 14 | self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 15 | self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 16 | self.pool = nn.MaxPool2d(2, 2) 17 | self.fc1 = nn.Linear(64 * 5 * 5, 128) 18 | self.fc2 = nn.Linear(128, 360) # 360 classes for 0-359 degrees rotation 19 | 20 | def forward(self, x): 21 | x = self.pool(torch.relu(self.conv1(x))) 22 | x = self.pool(torch.relu(self.conv2(x))) 23 | x = self.pool(torch.relu(self.conv3(x))) 24 | x = x.view(-1, 64 * 5 * 5) 25 | x = torch.relu(self.fc1(x)) 26 | x = self.fc2(x) 27 | return x 28 | 29 | # Define dataset class 30 | class RotationDataset(Dataset): 31 | def __init__(self, root_dir, transform=None): 32 | self.root_dir = root_dir 33 | self.transform = transform 34 | self.image_files = sorted(os.listdir(root_dir)) 35 | 36 | def __len__(self): 37 | return len(self.image_files) 38 | 39 | def __getitem__(self, idx): 40 | img_name = self.image_files[idx] 41 | img_path = os.path.join(self.root_dir, img_name) 42 | image = Image.open(img_path).convert('RGB') 43 | label = int(img_name.split('_')[-1].split('.')[0]) # Extract label from file name 44 | # Ensure label is within bounds 45 | label = label % 359 # Limit label to 0-358 range 46 | if self.transform: 47 | image = self.transform(image) 48 | return image, label 49 | 50 | # Set device 51 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 52 | 53 | # Define transformations 54 | transform = transforms.Compose([ 55 | transforms.Resize((40, 40)), 56 | transforms.ToTensor(), 57 | ]) 58 | 59 | # Define dataset and dataloader 60 | dataset = RotationDataset(root_dir='360', transform=transform) 61 | dataloader = DataLoader(dataset, batch_size=64, shuffle=True) 62 | 63 | # Initialize the model, loss function, and optimizer 64 | model = CNN().to(device) 65 | criterion = nn.CrossEntropyLoss() 66 | optimizer = optim.Adam(model.parameters(), lr=0.001) 67 | 68 | # Training loop 69 | num_epochs = 100 70 | for epoch in range(num_epochs): 71 | running_loss = 0.0 72 | for i, data in enumerate(dataloader, 0): 73 | inputs, labels = data[0].to(device), data[1].to(device) 74 | 75 | # Zero the parameter gradients 76 | optimizer.zero_grad() 77 | 78 | # Forward + backward + optimize 79 | outputs = model(inputs) 80 | loss = criterion(outputs, labels) 81 | loss.backward() 82 | optimizer.step() 83 | 84 | # Print statistics 85 | running_loss += loss.item() 86 | if i % 100 == 99: # Print every 100 mini-batches 87 | print('[%d, %5d] loss: %.3f' % 88 | (epoch + 1, i + 1, running_loss / 100)) 89 | running_loss = 0.0 90 | 91 | print('Finished Training') 92 | 93 | # Save the model 94 | torch.save(model.state_dict(), 'rotate_model.pth') 95 | -------------------------------------------------------------------------------- /label.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import tkinter as tk 4 | import tkinter.messagebox 5 | from PIL import Image, ImageTk 6 | 7 | class ImageRotatorApp: 8 | def __init__(self, master, labels_path, output_path, all_image_index=500): 9 | 10 | self.config = {} 11 | 12 | self.output_path = output_path 13 | self.labels_path=labels_path 14 | self.master = master 15 | self.title = "双图旋转验证码标注系统" 16 | self.master.title(self.title) 17 | 18 | # Initialize variables 19 | self.angle = 0 20 | self.image_index = 1 21 | self.all_image_index = all_image_index 22 | 23 | if os.path.exists(f"{labels_path}/config.json"): 24 | with open(f"{labels_path}/config.json") as f: 25 | self.config = json.loads(f.read()) 26 | keys = [int(k) for k in self.config.keys()] 27 | self.image_index = int(max(keys))+1 28 | 29 | # Load images 30 | self.image1 = Image.open(f"{labels_path}/bg_{self.image_index}.png") 31 | self.image2 = Image.open(f"{labels_path}/center_{self.image_index}.png") 32 | self.image1_tk = ImageTk.PhotoImage(self.image1) 33 | self.image2_tk = ImageTk.PhotoImage(self.image2) 34 | 35 | self.master.title(f"{self.title} - {self.image_index} / {self.all_image_index}") 36 | 37 | # Create canvas 38 | self.canvas = tk.Canvas(master, width=self.image1.width, height=self.image1.height) 39 | self.canvas.pack() 40 | 41 | # Display images on canvas 42 | self.image1_id = self.canvas.create_image(0, 0, anchor=tk.NW, image=self.image1_tk) 43 | self.center_x = (self.image1.width - self.image2.width) // 2 44 | self.center_y = (self.image1.height - self.image2.height) // 2 45 | self.image2_id = self.canvas.create_image(self.center_x, self.center_y, anchor=tk.NW, image=self.image2_tk) 46 | 47 | # Create rotation slider 48 | self.rotation_slider = tk.Scale(master, from_=0, to=360, orient=tk.HORIZONTAL, length=360, command=self.rotate_image) 49 | self.rotation_slider.pack() 50 | 51 | # Create angle entry 52 | self.input_frame = tk.Frame(master) 53 | self.input_frame.pack() 54 | 55 | self.decrease_button = tk.Button(self.input_frame, text="-", command=lambda: self.adjust_angle(-1)) 56 | self.decrease_button.pack(side=tk.LEFT, pady=5) 57 | self.angle_var = tk.StringVar() 58 | self.angle_var.set(str(self.angle)) 59 | validate_cmd = master.register(self.validate_input) 60 | self.angle_entry = tk.Entry(self.input_frame, textvariable=self.angle_var, validate="key", validatecommand=(validate_cmd, '%P')) 61 | self.angle_entry.pack(side=tk.LEFT) 62 | self.increase_button = tk.Button(self.input_frame, text="+", command=lambda: self.adjust_angle(1)) 63 | self.increase_button.pack(side=tk.LEFT, pady=5) 64 | 65 | # Create "Next Group" button 66 | self.next_button = tk.Button(master, text="Next Group", command=self.next_group) 67 | self.next_button.pack(pady=10) 68 | 69 | # Create "Commit" button 70 | self.commit_button = tk.Button(master, text="Commit", command=self.commit_angle) 71 | self.commit_button.pack(pady=10) 72 | 73 | 74 | def rotate_image(self, angle): 75 | if not hasattr(self,"angle_entry"): 76 | return 77 | self.angle = int(angle) 78 | self.angle_entry.delete(0, tk.END) 79 | self.angle_entry.insert(0, str(self.angle)) 80 | 81 | # Rotate image2 82 | self.rotated_image2 = self.image2.rotate(-self.angle) 83 | self.rotated_image2_tk = ImageTk.PhotoImage(self.rotated_image2) 84 | self.canvas.itemconfig(self.image2_id, image=self.rotated_image2_tk) 85 | # Update image2 position to keep it centered 86 | self.canvas.coords(self.image2_id, self.center_x, self.center_y) 87 | 88 | def validate_input(self, event): 89 | value = self.angle_var.get() 90 | if not value.isdigit(): 91 | self.angle_var.set(str(self.angle)) 92 | else: 93 | self.rotation_slider.set(value) 94 | self.rotate_image(value) 95 | 96 | def adjust_angle(self, delta): 97 | new_angle = int(self.angle_var.get()) + delta 98 | if 0 <= new_angle <= 360: 99 | self.angle_var.set(str(new_angle)) 100 | self.rotation_slider.set(new_angle) 101 | self.rotate_image(new_angle) 102 | 103 | def save_rotate_image(self, image_path, degrees_to_rotate, output_path): 104 | # 打开图像 105 | image = Image.open(image_path) 106 | 107 | # 旋转图像 108 | rotated_image = image.rotate(-degrees_to_rotate) 109 | 110 | # 保存旋转后的图像 111 | rotated_image.save(output_path) 112 | 113 | print("图像已旋转并保存到:", output_path) 114 | 115 | def commit_angle(self): 116 | angle = int(self.angle_var.get()) 117 | self.config[str(self.image_index)] = angle 118 | with open(f"{self.labels_path}/config.json", "w") as f: 119 | f.write(json.dumps(self.config)) 120 | self.save_rotate_image(f"{self.labels_path}/center_{self.image_index}.png", angle, f"{self.output_path}/center_{self.image_index}.png") 121 | def next_group(self): 122 | # Change images 123 | self.commit_angle() 124 | self.image_index += 1 125 | if not os.path.exists(f"{self.labels_path}/bg_{self.image_index}.png") or not os.path.exists(f"{self.labels_path}/center_{self.image_index}.png"): 126 | tk.messagebox.showinfo("提示", "文件未找到或者已经标注结束") 127 | return 128 | 129 | self.master.title(f"{self.title} - {self.image_index} / {self.all_image_index}") 130 | self.image1 = Image.open(f"{self.labels_path}/bg_{self.image_index}.png") 131 | self.image2 = Image.open(f"{self.labels_path}/center_{self.image_index}.png") 132 | 133 | self.image1_tk = ImageTk.PhotoImage(self.image1) 134 | self.image2_tk = ImageTk.PhotoImage(self.image2) 135 | 136 | # Update canvas 137 | self.canvas.itemconfig(self.image1_id, image=self.image1_tk) 138 | self.canvas.itemconfig(self.image2_id, image=self.image2_tk) 139 | self.rotation_slider.set(0) # Reset slider 140 | self.angle = 0 141 | self.angle_entry.delete(0, tk.END) 142 | self.angle_entry.insert(0, str(self.angle)) 143 | 144 | def main(): 145 | labels_path = "xhs_captcha_imgs" # 需要标注的图片路径 146 | output_path = "temp" # 输出旋转到正确角度的验证码图片 147 | 148 | root = tk.Tk() 149 | app = ImageRotatorApp(root, labels_path, output_path) 150 | root.mainloop() 151 | 152 | if __name__ == "__main__": 153 | main() 154 | --------------------------------------------------------------------------------