├── models
│   ├── __init__.py
│   ├── crnn.py
│   └── layer.py
├── test.jpg
├── datasets
│   ├── __init__.py
│   ├── collate_fn.py
│   └── dataloader.py
├── dictionary
│   ├── c2i_dict.pkl
│   └── i2c_dict.pkl
├── utils
│   ├── __init__.py
│   ├── editDistance.py
│   ├── EncoderDecoder.py
│   ├── generate_dict.py
│   ├── trans_dgrl.py
│   └── divide.py
├── requirements.txt
├── config.py
├── val.py
├── LICENSE
├── README.md
├── evaluate.py
├── train.py
├── .gitignore
└── main.py

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from .crnn import CRNN

--------------------------------------------------------------------------------
/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyttyx/Chinese_OCR/HEAD/test.jpg

--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
from .dataloader import train_Dataset
from .collate_fn import collate_fn

--------------------------------------------------------------------------------
/dictionary/c2i_dict.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyttyx/Chinese_OCR/HEAD/dictionary/c2i_dict.pkl

--------------------------------------------------------------------------------
/dictionary/i2c_dict.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyttyx/Chinese_OCR/HEAD/dictionary/i2c_dict.pkl

--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
from .EncoderDecoder import EncoderDecoder
from .editDistance import editDistance

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy==1.23.5
opencv_python==4.8.0.74
Pillow==10.0.1
torch==2.0.1
torchvision==0.15.2
tqdm==4.65.0

--------------------------------------------------------------------------------
/datasets/collate_fn.py:
--------------------------------------------------------------------------------
import torch

def collate_fn(batch: list):
    # Pad variable-length label sequences to the longest label in the batch.
    # Zero padding is harmless: CTCLoss only reads the first len_labels[i] entries.
    batch_size = len(batch)
    data = torch.stack([item[0] for item in batch])
    labels = [item[1] for item in batch]
    len_labels = torch.tensor([len(item[1]) for item in batch])
    max_label_length = max(len_labels)

    padded_labels = torch.zeros(batch_size, max_label_length, dtype=torch.int32)
    for i in range(batch_size):
        padded_labels[i, :len_labels[i]] = labels[i]

    return data, padded_labels, len_labels
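A quick, self-contained sketch of what collate_fn produces for two labels of different lengths (the tensor values are made up for illustration):

```python
import torch
from datasets import collate_fn

# Two fake samples: (image_tensor, encoded_label)
batch = [
    (torch.zeros(1, 32, 256), torch.tensor([5, 9, 3], dtype=torch.int32)),
    (torch.zeros(1, 32, 256), torch.tensor([7], dtype=torch.int32)),
]
data, padded, lengths = collate_fn(batch)
print(data.shape)   # torch.Size([2, 1, 32, 256])
print(padded)       # tensor([[5, 9, 3], [7, 0, 0]], dtype=torch.int32)
print(lengths)      # tensor([3, 1])
```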
--------------------------------------------------------------------------------
/utils/editDistance.py:
--------------------------------------------------------------------------------

def editDistance(str1: str, str2: str) -> int:
    # dp[i][j] = edit distance between str1[:i] and str2[:j]
    dp = [[0] * (len(str2) + 1) for _ in range(len(str1) + 1)]

    for i in range(len(str1) + 1):
        dp[i][0] = i
    for j in range(len(str2) + 1):
        dp[0][j] = j

    for i in range(1, len(str1) + 1):
        for j in range(1, len(str2) + 1):
            cost = 0 if str1[i - 1] == str2[j - 1] else 1
            dp[i][j] = min(
                dp[i - 1][j] + 1,        # delete str1[i-1]
                dp[i][j - 1] + 1,        # insert str2[j-1]
                dp[i - 1][j - 1] + cost  # substitute
            )

    return dp[len(str1)][len(str2)]
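For instance, the classic pair below needs exactly three single-character edits:

```python
from utils import editDistance

# kitten -> sitten (substitute) -> sittin (substitute) -> sitting (insert)
assert editDistance("kitten", "sitting") == 3
assert editDistance("", "abc") == 3          # three insertions
assert editDistance("same", "same") == 0
```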
--------------------------------------------------------------------------------
/models/crnn.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from .layer import CNNLayer, BiLSTMLayer, MobileNetV3

# CRNN model: a CNN feature extractor followed by a bidirectional LSTM transcription layer
class CRNN(nn.Module):
    def __init__(self, num_classes):
        super(CRNN, self).__init__()
        self.cnn = CNNLayer()
        #self.cnn = MobileNetV3('small')

        self.lstm_input_size = 512
        self.lstm_hidden_size = 512
        self.lstm = BiLSTMLayer(self.lstm_input_size, self.lstm_hidden_size, 2, num_classes)

        self.log_softmax = nn.LogSoftmax(dim=2)

    def forward(self, input):
        x = self.cnn(input)      # (N, 512, 1, W'): the CNN collapses the height to 1
        x = x.squeeze(2)         # (N, 512, W')
        x = x.permute(2, 0, 1)   # (T, N, C): time-major, as nn.CTCLoss expects
        t, n, c = x.size()
        assert c == self.lstm_input_size
        output = self.lstm(x)
        output = self.log_softmax(output)
        return output

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import os
import torch

CURRENT_EPOCH = 0
TRAIN_EPOCH = 100
BATCH_SIZE = 64

# Note: paths use Windows-style separators; adjust them on other platforms.
DATA_PATH = os.path.abspath('..\\Chinese_OCR_data\\datasets')
TRAIN_IMG_PATH = os.path.join(DATA_PATH, 'Train_images')
TRAIN_LABEL_PATH = os.path.join(DATA_PATH, 'Train_label')
TEST_IMG_PATH = os.path.join(DATA_PATH, 'Test_images')
TEST_LABEL_PATH = os.path.join(DATA_PATH, 'Test_label')

DICT_PATH = os.path.abspath('.\\dictionary')
I2C_DICT_NAME = 'i2c_dict.pkl'
C2I_DICT_NAME = 'c2i_dict.pkl'
I2C_PATH = os.path.join(DICT_PATH, I2C_DICT_NAME)
C2I_PATH = os.path.join(DICT_PATH, C2I_DICT_NAME)

MODEL_SAVE_PATH = os.path.abspath('.\\models\\save')
TRAIN_LOG = os.path.abspath('..\\Chinese_OCR_data\\models\\train_log')

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

--------------------------------------------------------------------------------
/val.py:
--------------------------------------------------------------------------------
import os
import torch
import torchvision.transforms as transforms

from models import CRNN
from utils import EncoderDecoder
import config

def val(img_list):
    val_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((32, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    model_name = 'CRNN_best.pth'

    en_decoder = EncoderDecoder(config.C2I_PATH, config.I2C_PATH)
    model = CRNN(2700).to(config.DEVICE)
    model.load_state_dict(torch.load(os.path.join(config.MODEL_SAVE_PATH, model_name)))
    model.eval()  # freeze BatchNorm statistics during inference
    output_list = list()

    with torch.no_grad():
        for img in img_list:
            input = val_transform(img)
            input = input.unsqueeze(0)
            input = input.to(config.DEVICE)
            outputs = model(input)
            output_str = en_decoder.TensorDecode(outputs[:, 0, :])  # one image at a time, so batch index 0
            output_list.append(output_str)

    return output_list
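A minimal usage sketch for val(), assuming a trained checkpoint named CRNN_best.pth exists under models/save and the script is run from the project root (the grayscale read matches what divide.py feeds in):

```python
import cv2
from val import val

img = cv2.imread('test.jpg', cv2.IMREAD_GRAYSCALE)  # ToPILImage accepts HxW uint8 arrays
print(val([img])[0])                                # recognized text of the line
```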
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 xyttyx

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Chinese_OCR
Handwritten Chinese text-line recognition with CRNN + CTC, reaching an accuracy of **0.954** on the HWDB2.x dataset.

## CRNN + CTC in detail
### The CRNN part
CRNN couples a CNN with an RNN for joint recognition. The CNN first extracts features from the image; the RNN then classifies, along the time axis (here, the horizontal axis of the image), the features extracted from each receptive field.
(to be continued)
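A small sketch tracing the tensor shapes through this repo's CRNN (the sizes follow models/layer.py, where the poolings and the final convolution reduce a 32×256 line image to 64 time steps of 512-dimensional features):

```python
import torch
from models import CRNN

model = CRNN(num_classes=2700)
x = torch.randn(1, 1, 32, 256)   # (N, C, H, W) grayscale line image
out = model(x)                   # CNN gives (1, 512, 1, 64), reshaped to time-major
print(out.shape)                 # torch.Size([64, 1, 2700]): (T, N, num_classes)
```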
### The CTC loss
PyTorch already ships CTCLoss as a built-in criterion, but understanding the underlying principle is still essential; an explanation of how CTCLoss works follows.
(to be continued)
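Until that explanation lands, here is a minimal example of the built-in criterion using this repo's conventions — time-major (T, N, C) log-probabilities, with index 0 reserved as the blank (see utils/generate_dict.py); the sizes are illustrative:

```python
import torch
import torch.nn as nn

T, N, C = 64, 2, 2700                                     # time steps, batch, classes
log_probs = torch.randn(T, N, C).log_softmax(2)           # what CRNN.forward outputs
targets = torch.randint(1, C, (N, 10), dtype=torch.long)  # labels never contain blank 0
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), 10, dtype=torch.long)

loss = nn.CTCLoss(blank=0)(log_probs, targets, input_lengths, target_lengths)
print(loss.item())
```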
***
## Training
Install the required packages from requirements.txt by running
`pip install -r requirements.txt`

In the parent directory of this project, create the dataset directory; its default name is **Chinese_OCR_data**, with the following default layout:
```
Chinese_OCR_data
|--datasets
| |--Test_images
| |--Test_label
| |--Train_images
| |--Train_label
|--models
```

Make sure that images and labels share the same base names across the image and label folders, and that their contents correspond one to one.
Note that labels default to .txt files and images to .jpg files. If you use the HWDB2.x dataset, utils/trans_dgrl.py can split each .dgrl file into images and labels.
**If you use trans_dgrl.py, check and adjust the paths it points to.**

Start training by running
`python train.py`

## Evaluation
evaluate.py evaluates the model on the test set. The accuracy it reports is computed as

***accuracy = average(1 - edit_distance / label_length)***

The definition and computation of edit distance are easy to look up; the implementation lives in utils.editDistance.
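As a concrete instance of this formula, computed with the repo's own editDistance:

```python
from utils import editDistance

target = "今天天气很好"        # ground-truth label, 6 characters
prediction = "今天天汽很好"    # one substituted character

accuracy = 1 - editDistance(target, prediction) / len(target)
print(accuracy)  # 1 - 1/6 ≈ 0.833
```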
***
#### Side notes
This network differs from the original CRNN in two ways:

1. nine convolutional layers are used instead of seven;
2. a two-layer biLSTM serves as the transcription layer.

The design of the nine-layer CNN follows the paper 《Attention机制在脱机中文手写体文本行识别中的应用》 (on applying the attention mechanism to offline handwritten Chinese text-line recognition). That paper adds an attention-based seq2seq head after a single-layer biLSTM and reports an accuracy of 0.9576; the best model here, obtained after 80 training epochs, reaches 0.954.

--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
import os
from tqdm import tqdm
import torch

from datasets import train_Dataset, collate_fn
from torch.utils.data import DataLoader

from models import CRNN
from utils import EncoderDecoder, editDistance
import config
import time

def evaluate(model=None):

    en_decoder = EncoderDecoder(config.C2I_PATH, config.I2C_PATH)
    '''
    with open(config.TRAIN_LOG,'r') as log:
        _ , model_name = log.read().split()
    '''
    dataset = train_Dataset(config.TEST_IMG_PATH, config.TEST_LABEL_PATH)
    dataloader = DataLoader(dataset, batch_size=48, shuffle=False, collate_fn=collate_fn, drop_last=True)

    if model is None:
        model_name = 'CRNN_epoch80.pth'
        model = CRNN(2700)  # class count must match the one used in training
        model.load_state_dict(torch.load(os.path.join(config.MODEL_SAVE_PATH, model_name)))
        model = model.to(config.DEVICE)
    model.eval()  # freeze BatchNorm statistics during evaluation

    accuracy_list = list()
    infer_time_list = list()
    with torch.no_grad():
        for inputs, labels, _ in tqdm(dataloader, 'evaluate processing'):

            #-------#
            begin_time = time.process_time_ns()
            inputs = inputs.to(config.DEVICE)
            outputs = model(inputs)
            end_time = time.process_time_ns()
            #-------#
            for i in range(len(inputs)):
                target = en_decoder.StringDecode(labels[i]).replace('#', '')  # strip padding/blank placeholders
                output_str = en_decoder.TensorDecode(outputs[:, i, :])
                edit_distance = editDistance(target, output_str)
                accuracy_list.append(1 - edit_distance / len(target))
                infer_time_list.append(end_time - begin_time)

    accuracy = sum(accuracy_list) / len(accuracy_list)
    infer_time_ns = sum(infer_time_list) / len(infer_time_list)
    print(f'Accuracy: {accuracy}, average infer time: {infer_time_ns / 1e6}ms')  # 1 ms = 1e6 ns
    return

if __name__ == '__main__':
    evaluate()
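One caveat on the timing above: time.process_time_ns() measures CPU time, and CUDA kernels run asynchronously, so the measured interval can misstate GPU inference time. A sketch of wall-clock timing with explicit synchronization — not part of the repo, the helper name timed_forward is illustrative:

```python
import time
import torch

def timed_forward(model, inputs, device):
    if device.type == 'cuda':
        torch.cuda.synchronize()      # drain previously queued GPU work
    t0 = time.perf_counter_ns()
    with torch.no_grad():
        outputs = model(inputs)
    if device.type == 'cuda':
        torch.cuda.synchronize()      # wait until the forward pass really finishes
    return outputs, time.perf_counter_ns() - t0
```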
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim

from datasets import train_Dataset, collate_fn
from torch.utils.data import DataLoader

from models import CRNN
from evaluate import evaluate
import config

def train():
    '''
    with open(config.TRAIN_LOG,'r') as log:
        current_epoch, model_name = log.read().split()
        current_epoch = int(current_epoch)
    '''
    model_name = 'CRNN_epoch86.pth'
    current_epoch = 86
    dataset = train_Dataset(config.TRAIN_IMG_PATH, config.TRAIN_LABEL_PATH)
    dataloader = DataLoader(dataset, batch_size=config.BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

    model = CRNN(2700)
    model.load_state_dict(torch.load(os.path.join(config.MODEL_SAVE_PATH, model_name)))
    model = model.to(config.DEVICE)

    criterion = nn.CTCLoss()  # the blank index defaults to 0, matching the dictionaries
    optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-5)

    for epoch in range(current_epoch + 1, current_epoch + 1 + config.TRAIN_EPOCH):
        model.train()  # evaluate() switches the model to eval mode at the end of each epoch
        for inputs, labels, len_labels in tqdm(dataloader, f'epoch:{epoch}'):
            inputs = inputs.to(config.DEVICE)
            labels = labels.to(config.DEVICE)
            len_labels = len_labels.to(config.DEVICE)
            optimizer.zero_grad()
            outputs = model(inputs)
            # every sample in the batch contributes the full T time steps
            len_seq = torch.ones(outputs.shape[1], dtype=torch.int32) * outputs.shape[0]
            len_seq = len_seq.to(config.DEVICE)
            loss = criterion(outputs, labels, len_seq, len_labels)
            loss.backward()
            optimizer.step()
        # save a checkpoint every epoch
        torch.save(model.state_dict(), os.path.join(config.MODEL_SAVE_PATH, f'CRNN_epoch{epoch}.pth'))
        print('Model has been saved')
        evaluate(model=model)

    return

if __name__ == '__main__':
    train()

--------------------------------------------------------------------------------
/datasets/dataloader.py:
--------------------------------------------------------------------------------
import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms

from utils.EncoderDecoder import EncoderDecoder
from config import C2I_PATH, I2C_PATH

train_transform = transforms.Compose([
    transforms.Resize((32, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

class train_Dataset(Dataset):
    def __init__(self, img_folder: str, labels_folder: str):
        super(train_Dataset, self).__init__()
        self.transform = train_transform
        self.img_folder = img_folder
        self.img_names = os.listdir(img_folder)
        self.label_names = os.listdir(labels_folder)
        assert len(self.img_names) == len(self.label_names)
        for i in range(len(self.img_names)):
            assert self.img_names[i].replace('.jpg', '') == self.label_names[i].replace('.txt', '')
        self.labels = list()
        self.en_decoder = EncoderDecoder(C2I_PATH, I2C_PATH)
        for label_name in self.label_names:
            with open(os.path.join(labels_folder, label_name), 'r', encoding='utf-8') as label_txt_file:
                tmp = label_txt_file.read()
                self.labels.append(tmp)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        img_name = self.img_names[index]
        img = Image.open(os.path.join(self.img_folder, img_name)).convert('L')  # force a single channel
        img = self.transform(img)
        label = self.labels[index]
        label = self.en_decoder.StringEncode(label)
        return img, label

class val_Dataset(Dataset):
    def __init__(self, img_folder: str, labels_folder: str):
        super(val_Dataset, self).__init__()
        self.img_names = os.listdir(img_folder)
        self.label_names = os.listdir(labels_folder)
        assert len(self.img_names) == len(self.label_names)
        for i in range(len(self.img_names)):
            assert self.img_names[i].replace('.jpg', '') == self.label_names[i].replace('.txt', '')

    def __len__(self):
        return len(self.label_names)
--------------------------------------------------------------------------------
/utils/EncoderDecoder.py:
--------------------------------------------------------------------------------
from __future__ import annotations  # allow `str | dict` annotations on Python < 3.10

import os
import pickle

import torch

class EncoderDecoder():
    def __init__(self, c2i_pkl_dict: str | dict = None, i2c_pkl_dict: str | dict = None):
        if isinstance(c2i_pkl_dict, str) and os.path.exists(c2i_pkl_dict):
            with open(c2i_pkl_dict, 'rb') as c2i_file:
                self.c2i_dict = pickle.load(c2i_file)
            assert isinstance(self.c2i_dict, dict)
        elif isinstance(c2i_pkl_dict, dict):
            self.c2i_dict = c2i_pkl_dict.copy()
        else:
            self.c2i_dict = None
        if isinstance(i2c_pkl_dict, str) and os.path.exists(i2c_pkl_dict):
            with open(i2c_pkl_dict, 'rb') as i2c_file:
                self.i2c_dict: dict = pickle.load(i2c_file)
            assert isinstance(self.i2c_dict, dict)
        elif isinstance(i2c_pkl_dict, dict):
            self.i2c_dict = i2c_pkl_dict.copy()
        else:
            self.i2c_dict = None

    def StringEncode(self, string: str) -> torch.Tensor:
        assert self.c2i_dict is not None
        enc = torch.zeros(len(string), dtype=torch.int32)
        for i in range(len(string)):
            # unknown characters map to 0, the blank/placeholder index
            enc[i] = self.c2i_dict[string[i]] if string[i] in self.c2i_dict.keys() else 0
        return enc

    def StringDecode(self, enc: torch.Tensor | list) -> str:
        assert self.i2c_dict is not None
        string = str()
        for i in range(len(enc)):
            key = int(enc[i])
            if key == 0 or key > len(self.i2c_dict):
                string += '#'  # '#' stands in for blank and out-of-dictionary indices
            else:
                string += self.i2c_dict[key]
        return string

    def TensorDecode(self, probs: torch.Tensor) -> str:  # probs: (T, C)
        # Greedy CTC decoding: argmax per time step, collapse repeats, drop blanks.
        assert self.i2c_dict is not None
        T, C = probs.size()
        enc = torch.zeros(T, dtype=torch.int32)
        for i in range(T):
            T_prob = probs[i, :]
            max_index = int(torch.argmax(T_prob))
            enc[i] = max_index
        tmp_string = self.StringDecode(enc)
        string = str()
        c = ''
        for i in range(T):
            if c != tmp_string[i]:
                c = tmp_string[i]
                string += c
        return string.replace('#', '')
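A round-trip sketch of the codec, using tiny hand-made dictionaries instead of the pickled ones (index 0 stays reserved for the CTC blank; the probability rows are made up):

```python
import torch
from utils import EncoderDecoder

c2i = {"你": 1, "好": 2}
i2c = {1: "你", 2: "好"}
codec = EncoderDecoder(c2i, i2c)

enc = codec.StringEncode("你好")        # tensor([1, 2], dtype=torch.int32)
print(codec.StringDecode(enc))          # 你好

# Greedy CTC decoding: argmax per step, collapse repeats, drop blanks.
probs = torch.tensor([[0.1, 0.8, 0.1],    # step 1 -> 你
                      [0.1, 0.8, 0.1],    # step 2 -> 你 (repeat, collapsed)
                      [0.9, 0.05, 0.05],  # step 3 -> blank
                      [0.1, 0.1, 0.8]])   # step 4 -> 好
print(codec.TensorDecode(probs))          # 你好
```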
--------------------------------------------------------------------------------
/utils/generate_dict.py:
--------------------------------------------------------------------------------
import os
import pickle

def generate_dict():
    # When run from VS Code, the working directory is the parent of utils/.
    # If you run this from inside utils/, adjust save_folder (and likewise
    # the data folder in generate_list_string below).
    save_folder = '.\\dictionary'
    i2c_dict_path = os.path.join(save_folder, 'i2c_dict.pkl')
    c2i_dict_path = os.path.join(save_folder, 'c2i_dict.pkl')

    if os.path.exists(i2c_dict_path):
        with open(i2c_dict_path, 'rb') as file:
            i2c_dict = pickle.load(file)
        assert isinstance(i2c_dict, dict)
    else:
        i2c_dict = dict()
    if os.path.exists(c2i_dict_path):
        with open(c2i_dict_path, 'rb') as file:
            c2i_dict = pickle.load(file)
        assert isinstance(c2i_dict, dict)
    else:
        c2i_dict = dict()

    # The two dictionaries must either both be empty or be consistent inverses.
    if len(c2i_dict) == 0 and len(i2c_dict) == 0:
        pass
    elif len(c2i_dict) != 0 and len(i2c_dict) != 0:
        for key in c2i_dict.keys():
            key_ = c2i_dict[key]
            if i2c_dict[key_] == key:
                continue
            else:
                print('fatal dict load')
                return
    else:
        print('fatal dict load')
        return

    update_dict(c2i_dict, i2c_dict)

    with open(c2i_dict_path, 'wb') as file:
        pickle.dump(c2i_dict, file)
    with open(i2c_dict_path, 'wb') as file:
        pickle.dump(i2c_dict, file)
    print('length of dict is {}'.format(len(c2i_dict)))

    return

def update_dict(c2i_dict: dict, i2c_dict: dict):
    i_number = len(c2i_dict) + 1  # index 0 is never assigned: it is reserved as the CTC blank
    list_string = generate_list_string()
    for string in list_string:
        for char in string:
            if char not in c2i_dict.keys():
                c2i_dict[char] = i_number
                i2c_dict[i_number] = char
                i_number += 1
    return

def generate_list_string():
    data_folder_path = '..\\Chinese_OCR_data\\datasets'
    txt_folder_name = 'Train_label'
    txt_folder_path = os.path.join(data_folder_path, txt_folder_name)
    list_string = list()
    for _0, _1, files in os.walk(txt_folder_path):
        for file in files:
            txt_file_path = os.path.join(txt_folder_path, file)
            with open(txt_file_path, 'r', encoding='utf-8') as txt_file:
                context = txt_file.read()
                list_string.append(context)
    return list_string

if __name__ == '__main__':
    generate_dict()
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Deep learning models
/models/save
*.pth
*.pt

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

--------------------------------------------------------------------------------
/utils/trans_dgrl.py:
--------------------------------------------------------------------------------
import struct
import os
import cv2 as cv
import numpy as np

dgrl_folder_path = '.\\data\\HWDB2.2Train'
total = 0

def trans_dgrl():
    for root, dirs, files in os.walk(dgrl_folder_path):
        for file in files:
            dgrl_file_path = os.path.join(root, file)
            read_from_dgrl(dgrl_file_path)
    return

def read_from_dgrl(dgrl):

    dir_name, base_name = os.path.split(dgrl)
    label_dir = dir_name + '_label'
    image_dir = dir_name + '_images'
    if not os.path.exists(label_dir):
        os.makedirs(label_dir)
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)

    with open(dgrl, 'rb') as f:
        # read the header size
        header_size = np.fromfile(f, dtype='uint8', count=4)
        header_size = sum([j << (i * 8) for i, j in enumerate(header_size)])
        # print(header_size)

        # read the rest of the header and extract code_length
        header = np.fromfile(f, dtype='uint8', count=header_size - 4)
        code_length = sum([j << (i * 8) for i, j in enumerate(header[-4:-2])])
        # print(code_length)

        # read the image record and extract the number of text lines
        image_record = np.fromfile(f, dtype='uint8', count=12)
        height = sum([j << (i * 8) for i, j in enumerate(image_record[:4])])
        width = sum([j << (i * 8) for i, j in enumerate(image_record[4:8])])
        line_num = sum([j << (i * 8) for i, j in enumerate(image_record[8:])])
        # print('image size:')
        # print(height, width, line_num)

        # read the record of each text line
        for k in range(line_num):
            # print(k + 1)

            # number of characters in this line
            char_num = np.fromfile(f, dtype='uint8', count=4)
            char_num = sum([j << (i * 8) for i, j in enumerate(char_num)])
            # print('character count:', char_num)

            # label of this line
            label = np.fromfile(f, dtype='uint8', count=code_length * char_num)
            label = [label[i] << (8 * (i % code_length)) for i in range(code_length * char_num)]
            label = [sum(label[i * code_length:(i + 1) * code_length]) for i in range(char_num)]
            label = [struct.pack('I', i).decode('gbk', 'ignore')[0] for i in label]
            # print('before joining:', label)
            label = ''.join(label)
            # strip the invisible \x00 characters; without this the saved files contain hidden garbage
            label = ''.join(label.split(b'\x00'.decode()))
            # print('after joining:', label)

            # position and size of this line
            pos_size = np.fromfile(f, dtype='uint8', count=16)
            y = sum([j << (i * 8) for i, j in enumerate(pos_size[:4])])
            x = sum([j << (i * 8) for i, j in enumerate(pos_size[4:8])])
            h = sum([j << (i * 8) for i, j in enumerate(pos_size[8:12])])
            w = sum([j << (i * 8) for i, j in enumerate(pos_size[12:])])
            # print(x, y, w, h)

            # bitmap of this line
            bitmap = np.fromfile(f, dtype='uint8', count=h * w)
            bitmap = np.array(bitmap).reshape(h, w)

            # save the label and the image
            label_file = os.path.join(label_dir, base_name.replace('.dgrl', '_' + str(k) + '.txt'))
            with open(label_file, 'w', encoding='utf-8') as f1:
                f1.write(label)
            bitmap_file = os.path.join(image_dir, base_name.replace('.dgrl', '_' + str(k) + '.jpg'))
            cv.imwrite(bitmap_file, bitmap)

            # progress indicator
            global total
            total += 1
            if total % 100 == 0:
                print(total)

if __name__ == '__main__':
    trans_dgrl()
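The repeated `sum([j << (i * 8) ...])` idiom above assembles little-endian integers from raw bytes by hand; a small sketch (with illustrative example bytes) showing it is equivalent to int.from_bytes:

```python
import numpy as np

# dgrl headers store multi-byte integers little-endian; these two reads agree.
header_size_bytes = np.array([64, 0, 0, 0], dtype='uint8')  # example bytes

manual = sum(int(b) << (i * 8) for i, b in enumerate(header_size_bytes))
builtin = int.from_bytes(header_size_bytes.tobytes(), byteorder='little')
assert manual == builtin == 64
```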
--------------------------------------------------------------------------------
/utils/divide.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np
import statistics  # for the median


def get_vvList(list_data):
    # Collect runs of consecutive indices whose projection value is positive.
    vv_list = list()
    v_list = list()
    for index, i in enumerate(list_data):
        if i > 0:
            v_list.append(index)
        else:
            if v_list:
                vv_list.append(v_list)
                v_list = []
    # flush the trailing run
    if len(v_list) > 0:
        vv_list.append(v_list)
    return vv_list


'''
img: the image whose horizontal projection determines the cut lines
base_img: once the cut lines are found on img, the actual cropping happens on base_img
THRESH: rows with fewer than THRESH text pixels are treated as blank separator rows
returns: two lists; each element of the first holds one cropped text-line image,
and the corresponding element of the second holds that line's height
'''
def slice_rows(img, base_img=None, THRESH=15):
    if base_img is None:
        base_img = img
    # adaptive (local) thresholding
    bi_img = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 5)

    # median filtering
    bi_img = cv2.medianBlur(bi_img, ksize=3)

    # cut rows by horizontal projection
    row, col = bi_img.shape
    hor_count = [0] * row  # white (i.e. text) pixel count of each row
    for i in range(row):
        for j in range(col):
            if bi_img[i][j] == 255:
                hor_count[i] = hor_count[i] + 1

    hor_count = np.array(hor_count)
    hor_count[np.where(hor_count < THRESH)] = 0

    hor_count = hor_count.tolist()

    vv_list = get_vvList(hor_count)

    img_list = []      # output list of line images
    rows_height = []   # height of each line

    for i in vv_list:
        img_hor = base_img[i[0]:i[-1] + 1, :]  # i[-1] is an inclusive row index

        if img_hor.size != 0:
            img_list.append([img_hor])
            rows_height.append(i[-1] - i[0] + 1)

    return img_list, rows_height


'''
input: img, the image to split into lines, or img_path pointing to it;
THRESH: rows with fewer than THRESH text pixels are treated as blank rows
output: a list whose elements are the segmented text-line images
'''
def divide_text(img=None, img_path=r'.\test.jpg', THRESH=15):
    if img is None:
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # correct uneven illumination
    # mean intensity over all pixels
    img = img.astype(float)
    mean_intensity = np.mean(img)

    # estimate the background illumination with a morphological closing
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (50, 50))
    background = cv2.morphologyEx(img, cv2.MORPH_CLOSE, kernel)

    # subtract the background illumination from the original image
    img = img - background

    # add the mean intensity back
    img = img + mean_intensity
    img = np.clip(img, 0, 255).astype('uint8')

    # contrast enhancement
    # alpha controls contrast, beta controls brightness
    alpha = 1.5
    beta = 0

    # adjust contrast and brightness with cv2.convertScaleAbs()
    base_img = cv2.convertScaleAbs(img, alpha=alpha, beta=beta)

    # cut the text lines
    img_list, rows_height = slice_rows(img, base_img, THRESH=THRESH)

    '''
    The projection method may over-segment (producing spurious lines) or
    under-segment (leaving one image that still contains several lines), so we
    take the median line height as a reference.
    flag marks every segmented line: 0 - normal, 1 - too small, 2 - too large,
    where "too large" means more than 1.8x the median and "too small" less than
    0.2x the median. Too-small lines are dropped; too-large ones are segmented
    a second time.
    '''
    median_height = statistics.median(rows_height)
    flag = [0] * len(img_list)
    for i, height in enumerate(rows_height):
        if height <= median_height * 0.2:
            flag[i] = 1
        elif height >= median_height * 1.8:
            flag[i] = 2

    large_idx = [i for i, item in enumerate(flag) if item == 2]
    for idx in large_idx:
        large_lines = img_list[idx][0]
        # erosion and dilation
        # structuring elements
        kernel1 = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))     # rectangle
        kernel2 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))  # ellipse
        kernel3 = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))    # cross
        # erode
        img_erode = cv2.erode(large_lines, kernel3, iterations=2)

        # dilate
        img_dilate = cv2.dilate(img_erode, kernel1, iterations=3)

        # second segmentation pass
        imgs_list, heights = slice_rows(img=img_dilate, base_img=large_lines, THRESH=3 * THRESH)

        del_idx = []
        for i, height in enumerate(heights):
            # only drop segments that are too small
            if height <= median_height * 0.2:
                del_idx.append(i)

        for i in reversed(del_idx):
            del imgs_list[i]

        # replace the old line image with the newly segmented list of lines
        img_list[idx] = [item[0] for item in imgs_list]

    # flatten the final result into one list of text-line images
    result_imgs = []
    for i, imgs in enumerate(img_list):
        if flag[i] == 1:
            continue  # skip lines that are too small
        for img in imgs:
            result_imgs.append(img)

    return result_imgs


# imgs = divide_text(img_path='./zh.jpg', THRESH=10)
# for img in imgs:
#     cv2.imshow('show', img)
#     cv2.waitKey(0)
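A tiny synthetic check of the projection bookkeeping: get_vvList turns a thresholded row histogram into runs of consecutive text rows, which slice_rows then crops (run from the project root):

```python
from utils.divide import get_vvList

# Thresholded horizontal projection: zeros are blank rows, positive values are text rows.
hor_count = [0, 0, 17, 25, 30, 0, 0, 0, 22, 19, 0]
print(get_vvList(hor_count))  # [[2, 3, 4], [8, 9]] -> two text lines, rows 2-4 and 8-9
```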
--------------------------------------------------------------------------------
/models/layer.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

class CNNLayer(nn.Module):
    # Before use, make sure the Dataset resizes inputs to a height of 32 (see datasets/dataloader.py).
    def __init__(self):
        super(CNNLayer, self).__init__()
        # suppose H = 32
        self.conv_1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),

            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),

            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.conv_2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),

            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),

            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.conv_3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),

            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),

            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),
        )
        self.conv_4 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),

            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),

            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),

            # this final convolution collapses the remaining height of 2 down to 1
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(2, 1), stride=(2, 1), padding=0),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.conv_1(x)
        x = self.conv_2(x)
        x = self.conv_3(x)
        x = self.conv_4(x)
        return x


class MobileNetV3(nn.Module):
    # input image size: to be finalized
    def __init__(self, mode: str):
        super(MobileNetV3, self).__init__()
        self.first_out_channels = 16
        self.first_conv = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=self.first_out_channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.Hardswish()
        )

        if mode == 'large':
            # refer to Table 1 in the paper
            mobile_setting = [
                # k, exp, out_c, se, nl, s,
                [3, 16, 16, False, 'RE', 1],
                [3, 64, 24, False, 'RE', 2],
                [3, 72, 24, False, 'RE', 1],
                [5, 72, 40, True, 'RE', 2],
                [5, 120, 40, True, 'RE', 1],
                [5, 120, 40, True, 'RE', 1],
                [3, 240, 80, False, 'HS', 2],
                [3, 200, 80, False, 'HS', 1],
                [3, 184, 80, False, 'HS', 1],
                [3, 184, 80, False, 'HS', 1],
                [3, 480, 112, True, 'HS', 1],
                [3, 672, 112, True, 'HS', 1],
                [5, 672, 160, True, 'HS', 1],
                [5, 960, 160, True, 'HS', 1],
                [5, 960, 160, True, 'HS', 1],
            ]
        else:
            # refer to Table 2 in the paper
            # (some strides differ from the table, presumably to preserve horizontal resolution for text lines)
            mobile_setting = [
                # k, exp, out_c, se, nl, s,
                [3, 16, 16, True, 'RE', 2],
                [3, 72, 24, False, 'RE', 2],
                [3, 88, 24, False, 'RE', 1],
                [5, 96, 40, True, 'HS', 2],
                [5, 240, 40, True, 'HS', 1],
                [5, 240, 40, True, 'HS', 1],
                [5, 120, 48, True, 'HS', 1],
                [5, 144, 48, True, 'HS', 1],
                [5, 288, 96, True, 'HS', 1],
                [5, 576, 96, True, 'HS', 1],
                [5, 576, 160, True, 'HS', 1],
            ]
        self.main_layers = self.make_layers(mobile_setting)  # B C H W
        self.maxpool = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))

    def make_layers(self, mobile_setting):
        layers = []
        for i in range(len(mobile_setting)):
            if i == 0:
                layer = MobileNetV3Block(self.first_out_channels,
                                         mobile_setting[i][2],
                                         mobile_setting[i][0],
                                         mobile_setting[i][1],
                                         mobile_setting[i][3],
                                         mobile_setting[i][4],
                                         mobile_setting[i][5])
            else:
                layer = MobileNetV3Block(mobile_setting[i - 1][2],
                                         mobile_setting[i][2],
                                         mobile_setting[i][0],
                                         mobile_setting[i][1],
                                         mobile_setting[i][3],
                                         mobile_setting[i][4],
                                         mobile_setting[i][5])
            layers.append(layer)
        return nn.Sequential(*layers)

    def forward(self, input):
        output = self.first_conv(input)
        output = self.main_layers(output)
        output: torch.Tensor = self.maxpool(output)
        b, c, h, w = output.size()
        output = output.view(b, -1, 1, w)
        return output

class MobileNetV3Block(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, exp_size: int, use_se=True, nl: str = 'RE', stride: int = 1):
        super(MobileNetV3Block, self).__init__()
        self.NL = nn.ReLU6() if nl == 'RE' else nn.Hardswish()
        self.use_se = use_se
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        # 1x1 expansion convolution (pointwise; the depthwise step is self.conv below)
        self.dwise = nn.Sequential(
            nn.Conv2d(in_channels, exp_size, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(exp_size),
            self.NL
        )

        # depthwise convolution (groups == exp_size)
        self.conv = nn.Sequential(
            nn.Conv2d(exp_size, exp_size, kernel_size=kernel_size, stride=stride,
                      padding=kernel_size // 2, groups=exp_size, bias=False),
            nn.BatchNorm2d(exp_size),
            self.NL
        )
        # Squeeze-and-Excitation (SE) block
        if self.use_se:
            self.se = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(exp_size, exp_size // 4, kernel_size=1, stride=1, padding=0),
                nn.ReLU(),
                nn.Conv2d(exp_size // 4, exp_size, kernel_size=1, stride=1, padding=0),
                nn.Hardsigmoid()
            )

        # 1x1 pointwise projection
        self.pwise = nn.Sequential(
            nn.Conv2d(exp_size, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        out = self.dwise(x)    # expand
        out = self.conv(out)   # depthwise
        if self.use_se:
            se_weight = self.se(out)
            out = out * se_weight
        out = self.pwise(out)  # project
        if self.in_channels == self.out_channels and self.stride == 1:
            out += x           # residual connection
        return out

class BiLSTMLayer(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, out_length):
        super(BiLSTMLayer, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True)
        self.fc = nn.Linear(hidden_size * 2, out_length)

    def forward(self, x):
        out, _ = self.lstm(x)
        out = self.fc(out)
        return out

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from PyQt5 import QtCore
from PyQt5.QtCore import QSize, Qt, QRect, QRectF, pyqtSignal
from PyQt5.QtGui import QColor, QPixmap, QPen
from PyQt5.QtWidgets import (QApplication, QDialog, QGridLayout, QPushButton, QTextEdit,
                             QFileDialog, QLabel, QGraphicsView, QGraphicsPixmapItem,
                             QGraphicsScene, QGraphicsItem)

from utils import divide
from val import val

class Form(QDialog):
    def __init__(self):
        super(Form, self).__init__()

        self.setWindowTitle("手写中文识别")  # window title: "Handwritten Chinese Recognition"
        self.resize(1024, 700)
        self.picture = None
        self.init_ui()

        self.graphicsView.save_signal.connect(self.pushButton_save.setEnabled)
        self.pushButton_cut.clicked.connect(self.pushButton_cut_clicked)
        self.pushButton_save.clicked.connect(self.pushButton_save_clicked)
        self.pushButton_xianshi.clicked.connect(self.boxSelect)
        self.pushButton_shibie.clicked.connect(self.shibie)

        # image_item = GraphicsPolygonItem()
        # image_item.setFlag(QGraphicsItem.ItemIsMovable)
        # self.scene.addItem(image_item)

    def boxSelect(self):
        options = QFileDialog.Options()
        options |= QFileDialog.ReadOnly
        file_name, _ = QFileDialog.getOpenFileName(self, "选择图片", "",
                                                   "Images (*.png *.xpm *.jpg);;All Files (*)", options=options)
        if file_name:
            pixmap = QPixmap(file_name)
            pixmap.save(r"test.jpg")
            self.picture = r"test.jpg"
            self.graphicsView.scene.clear()
            self.graphicsView.image_item = GraphicsPixmapItem(QPixmap(self.picture))
            self.graphicsView.image_item.setFlag(QGraphicsItem.ItemIsMovable)
            self.graphicsView.scene.addItem(self.graphicsView.image_item)
            size = self.graphicsView.image_item.pixmap().size()
            self.graphicsView.image_item.setPos(-size.width() / 2, -size.height() / 2)

    def shibie(self):
        # segment the saved image into text lines, then recognize each line
        img_list = divide.divide_text()
        str_list = val(img_list=img_list)
        str_show = '\n'.join(str_list)
        self.text_edit.setPlainText(str_show)
    def init_ui(self):

        background = QLabel(self)
        background.setStyleSheet("background-color: lightblue;")
        background.resize(1224, 700)
        background.move(0, 0)
        background.lower()  # keep the background at the bottom of the stacking order

        self.gridLayout = QGridLayout(self)
        self.pushButton_cut = QPushButton('剪切', self)  # "Cut"
        self.pushButton_cut.setCheckable(True)
        self.pushButton_cut.setMaximumSize(QSize(200, 16777215))
        self.gridLayout.addWidget(self.pushButton_cut, 1, 1, 1, 1)
        self.pushButton_cut.setStyleSheet(
            "background-color: rgb(120, 120, 120);"
            "color: white;"
            "border-radius: 5px;"
        )

        self.pushButton_save = QPushButton('保存', self)  # "Save"
        self.pushButton_save.setEnabled(False)
        self.pushButton_save.setMaximumSize(QSize(200, 16777215))
        self.gridLayout.addWidget(self.pushButton_save, 1, 2, 1, 1)
        self.pushButton_save.setStyleSheet(
            "background-color: rgb(120, 120, 120);"
            "color: white;"
            "border-radius: 5px;"
        )

        self.pushButton_xianshi = QPushButton('选取图片', self)  # "Choose image"
        self.pushButton_xianshi.setCheckable(True)
        self.pushButton_xianshi.setMaximumSize(QSize(200, 16777215))
        self.gridLayout.addWidget(self.pushButton_xianshi, 1, 0, 1, 1)
        self.pushButton_xianshi.setStyleSheet(
            "background-color: rgb(120, 120, 120);"
            "color: white;"
            "border-radius: 5px;"
        )

        self.pushButton_shibie = QPushButton('识别', self)  # "Recognize"
        self.pushButton_shibie.setCheckable(True)
        self.pushButton_shibie.setMaximumSize(QSize(200, 16777215))
        self.gridLayout.addWidget(self.pushButton_shibie, 1, 3, 1, 1)
        self.pushButton_shibie.setStyleSheet(
            "background-color: rgb(120, 120, 120);"
            "color: white;"
            "border-radius: 5px;"
        )

        self.text_edit = QTextEdit(self)
        self.gridLayout.addWidget(self.text_edit, 0, 3, 1, 2)

        self.graphicsView = GraphicsView(self.picture, self)
        self.graphicsView.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.graphicsView.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.gridLayout.addWidget(self.graphicsView, 0, 0, 1, 3)

        # button widths
        button_width = 150
        self.pushButton_cut.setFixedWidth(button_width)
        self.pushButton_save.setFixedWidth(button_width)
        self.pushButton_xianshi.setFixedWidth(button_width)
        self.pushButton_shibie.setFixedWidth(button_width)

        # button heights
        button_height = 40
        self.pushButton_cut.setFixedHeight(button_height)
        self.pushButton_save.setFixedHeight(button_height)
        self.pushButton_xianshi.setFixedHeight(button_height)
        self.pushButton_shibie.setFixedHeight(button_height)

        # text-edit size
        text_edit_height = 600
        text_edit_width = 300
        self.text_edit.setFixedHeight(text_edit_height)
        self.text_edit.setFixedWidth(text_edit_width)
        # GraphicsView size
        graphics_view_width = 800
        graphics_view_height = 600
        self.graphicsView.setFixedWidth(graphics_view_width)
        self.graphicsView.setFixedHeight(graphics_view_height)

    def pushButton_cut_clicked(self):
        if self.graphicsView.image_item.is_start_cut:
            self.graphicsView.image_item.is_start_cut = False
            self.graphicsView.image_item.setCursor(Qt.ArrowCursor)  # arrow cursor
        else:
            self.graphicsView.image_item.is_start_cut = True
            self.graphicsView.image_item.setCursor(Qt.CrossCursor)  # cross cursor

    def pushButton_save_clicked(self):
        rect = QRect(self.graphicsView.image_item.start_point.toPoint(),
                     self.graphicsView.image_item.end_point.toPoint())
        new_pixmap = self.graphicsView.image_item.pixmap().copy(rect)
        new_pixmap.save(r'test.jpg')
        self.picture = r"test.jpg"
        self.graphicsView.scene.clear()
        self.graphicsView.image_item = GraphicsPixmapItem(QPixmap(self.picture))
        self.graphicsView.image_item.setFlag(QGraphicsItem.ItemIsMovable)
        self.graphicsView.scene.addItem(self.graphicsView.image_item)
        size = self.graphicsView.image_item.pixmap().size()
        self.graphicsView.image_item.setPos(-size.width() / 2, -size.height() / 2)

class GraphicsView(QGraphicsView):
    save_signal = pyqtSignal(bool)

    def __init__(self, picture, parent=None):
        super(GraphicsView, self).__init__(parent)
        self.setBackgroundBrush(QColor(14, 20, 20))

        # anchor zooming at the mouse position
        self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)
        self.setResizeAnchor(QGraphicsView.AnchorUnderMouse)

        self.scene = QGraphicsScene()
        self.setScene(self.scene)

        self.image_item = GraphicsPixmapItem(QPixmap(picture))
        self.image_item.setFlag(QGraphicsItem.ItemIsMovable)
        self.scene.addItem(self.image_item)

        size = self.image_item.pixmap().size()
        # center the image
        self.image_item.setPos(-size.width() / 2, -size.height() / 2)

        self.scale(0.4, 0.4)

    def wheelEvent(self, event):
        '''wheel event: zoom in or out'''
        zoomInFactor = 1.25
        zoomOutFactor = 1 / zoomInFactor

        if event.angleDelta().y() > 0:
            zoomFactor = zoomInFactor
        else:
            zoomFactor = zoomOutFactor

        self.scale(zoomFactor, zoomFactor)

    def mouseReleaseEvent(self, event):
        '''mouse release event'''
        # print(self.image_item.is_finish_cut, self.image_item.is_start_cut)
        if self.image_item.is_finish_cut:
            self.save_signal.emit(True)
        else:
            self.save_signal.emit(False)


class GraphicsPixmapItem(QGraphicsPixmapItem):
    # note: pyqtSignal does not work here; QGraphicsPixmapItem is not a QObject,
    # so the save_signal lives on GraphicsView above instead

    def __init__(self, picture, parent=None):
        super(GraphicsPixmapItem, self).__init__(parent)

        self.setPixmap(picture)
        self.is_start_cut = False
        self.current_point = None
        self.is_finish_cut = False

    def mouseMoveEvent(self, event):
        '''mouse move event'''
        self.current_point = event.pos()
        if not self.is_start_cut or self.is_midbutton:
            self.moveBy(self.current_point.x() - self.start_point.x(),
                        self.current_point.y() - self.start_point.y())
            self.is_finish_cut = False
        self.update()

    def mousePressEvent(self, event):
        '''mouse press event'''
        super(GraphicsPixmapItem, self).mousePressEvent(event)
        self.start_point = event.pos()
        self.current_point = None
        self.is_finish_cut = False
        if event.button() == Qt.MidButton:
            self.is_midbutton = True
            self.update()
        else:
            self.is_midbutton = False
            self.update()

    def paint(self, painter, QStyleOptionGraphicsItem, QWidget):
        super(GraphicsPixmapItem, self).paint(painter, QStyleOptionGraphicsItem, QWidget)
        if self.is_start_cut and not self.is_midbutton:
            # draw the dashed rubber-band rectangle used for cutting
            pen = QPen(Qt.DashLine)
            pen.setColor(QColor(0, 150, 0, 70))
            pen.setWidth(3)
            painter.setPen(pen)
            painter.setBrush(QColor(0, 0, 255, 70))
            if not self.current_point:
                return
            painter.drawRect(QRectF(self.start_point, self.current_point))
            self.end_point = self.current_point
            self.is_finish_cut = True

if __name__ == '__main__':
    import sys

    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)
    app = QApplication(sys.argv)
    form = Form()
    form.show()
    app.exec_()
--------------------------------------------------------------------------------