├── LICENSE.md
├── README.md
├── average_meter.py
├── collatefn.py
├── dataset.py
├── demo.py
├── english.txt
├── gen_image.py
├── img_aug.py
├── infer_tool.py
├── label_converter.py
├── logger.py
├── loss.py
├── metric.py
├── model.py
├── requirements.txt
└── train.py


--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 william_lzw
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 | 
7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 | 
9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## MicroOCR
2 | A micro OCR network.
3 | 
4 | This model can handle complex recognition tasks without an LSTM, and its accuracy and speed are better than those of ResNet and CRNN models.
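5 | 
6 | ## Quick start
7 | A minimal forward pass through the network (a sketch based on the `__main__` check in model.py; the input width here is just an example):
8 | ```python
9 | import torch
10 | from model import MicroMLPNet
11 | 
12 | # 62 characters in english.txt + CTC blank + space = 64 classes
13 | model = MicroMLPNet(in_channels=3, nh=64, depth=2, nclass=64, img_height=32)
14 | x = torch.randn(1, 3, 32, 160)  # (batch, channels, height, width)
15 | out = model(x)
16 | print(out.shape)  # torch.Size([1, 40, 64]): one class distribution per horizontal step
17 | ```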
18 | 
19 | ## Script Description
20 | 
21 | ```shell
22 | MicroOCR
23 | ├── README.md # description of MicroOCR
24 | ├── average_meter.py # average meter
25 | ├── collatefn.py # batch data processing
26 | ├── label_converter.py # label converter
27 | ├── dataset.py # data preprocessing for training and evaluation
28 | ├── demo.py # inference demo
29 | ├── english.txt # character vocabulary
30 | ├── gen_image.py # generate images for train and eval
31 | ├── img_aug.py # image augmentation
32 | ├── infer_tool.py # inference tool
33 | ├── logger.py # logger
34 | ├── loss.py # CTC loss definition
35 | ├── metric.py # word/char accuracy metrics
36 | ├── model.py # MicroMLPNet
37 | ├── requirements.txt # dependencies
38 | ├── train.py # train the model
39 | ```
40 | 
41 | ## Generate data for train and eval
42 | ```shell
43 | python gen_image.py
44 | ```
45 | 
46 | ## Training
47 | ```shell
48 | python train.py
49 | ```
50 | 
51 | ## Inference
52 | ```shell
53 | python demo.py
54 | ```
--------------------------------------------------------------------------------
/average_meter.py:
--------------------------------------------------------------------------------
1 | 
2 | class AverageMeter(object):
3 |     """Computes and stores the average and current value"""
4 |     def __init__(self):
5 |         self.reset()
6 | 
7 |     def reset(self):
8 |         self.val = 0
9 |         self.avg = 0
10 |         self.sum = 0
11 |         self.count = 0
12 | 
13 |     def update(self, val, n=1):
14 |         self.val = val
15 |         self.sum += val * n
16 |         self.count += n
17 |         self.avg = self.sum / self.count
--------------------------------------------------------------------------------
/collatefn.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List
2 | 
3 | import torch
4 | from torch import Tensor
5 | import numpy as np
6 | import cv2
7 | from torchvision import transforms
8 | 
9 | 
10 | def resize_with_specific_height(input_h: int, img: np.ndarray) -> np.ndarray:
11 |     """
12 |     Resize an image to the given height, keeping the aspect ratio.
13 |     :param input_h: target height
14 |     :param img: image to resize
15 |     :return: the resized image
16 |     """
17 |     resize_ratio = input_h / img.shape[0]
18 |     return cv2.resize(img, (0, 0), fx=resize_ratio,
19 |                       fy=resize_ratio, interpolation=cv2.INTER_LINEAR)
20 | 
21 | 
22 | def width_pad_img(img: np.ndarray,
23 |                   target_width: int,
24 |                   pad_value: int = 0) -> np.ndarray:
25 |     """
26 |     Pad an image to the target width, keeping its height unchanged.
27 |     :param img: image to pad
28 |     :param target_width: target width
29 |     :param pad_value: fill value for the padding
30 |     :return: the padded image
31 |     """
32 |     if len(img.shape) == 3:
33 |         height, width, channels = img.shape
34 |         if target_width > width:
35 |             to_return_img = np.ones(
36 |                 [height, target_width, channels], dtype=img.dtype) * pad_value
37 |             to_return_img[:height, :width, :] = img
38 |         else:
39 |             to_return_img = img
40 |     elif len(img.shape) == 2:
41 |         height, width = img.shape
42 |         if target_width > width:
43 |             to_return_img = np.ones(
44 |                 [height, target_width], dtype=img.dtype) * pad_value
45 |             to_return_img[:height, :width] = img
46 |         else:
47 |             to_return_img = img
48 |     return to_return_img
49 | 
50 | 
51 | class RecCollateFn:
52 |     """
53 |     Scale every image in the batch to a fixed height; the batch width is the
54 |     width of the widest image, rounded up to a multiple of 8.
55 |     """
56 | 
57 |     def __init__(self, input_h: int = 32):
58 |         self.input_h = input_h
59 |         self.transforms = transforms.ToTensor()
60 | 
61 |     def __call__(self,
62 |                  batch: List[Dict[str, np.ndarray]]) -> Dict[str, Any]:
63 |         resize_images: List[Tensor] = []
64 |         # scale all images to the target height
65 |         all_same_height_images = [
66 |             resize_with_specific_height(
67 |                 self.input_h, batch_index['images']) for batch_index in batch]
68 |         # find the maximum width in the batch
69 |         max_img_w = max({m_img.shape[1] for m_img in all_same_height_images})
70 |         # make sure the maximum width is a multiple of 8
71 |         max_img_w = int(np.ceil(max_img_w / 8) * 8)
72 |         labels = []
73 |         for i in range(len(batch)):
74 |             labels.append(batch[i]['labels'])
75 |             img = width_pad_img(
76 |                 all_same_height_images[i], max_img_w)
77 |             img = self.transforms(img)
78 |             resize_images.append(img)
79 |         images = torch.stack(resize_images)
80 |         return {'images': images, 'labels': labels}
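81 | 
82 | 
83 | if __name__ == '__main__':
84 |     # Minimal usage sketch: the random "images" below are stand-ins for real
85 |     # dataset samples, and the shapes are arbitrary.
86 |     fake_batch = [
87 |         {'images': np.random.randint(0, 255, (60, 100, 3), dtype=np.uint8), 'labels': 'ab12'},
88 |         {'images': np.random.randint(0, 255, (60, 150, 3), dtype=np.uint8), 'labels': 'xyz'},
89 |     ]
90 |     collate = RecCollateFn(input_h=32)
91 |     out = collate(fake_batch)
92 |     # heights are unified to 32; widths are padded to the batch max (a multiple of 8)
93 |     print(out['images'].shape, out['labels'])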
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Union, Any
2 | import os
3 | 
4 | import cv2
5 | import numpy as np
6 | from torch.utils.data import Dataset
7 | import torchvision.transforms.functional as F
8 | 
9 | from img_aug import DataProcess, cv2pil, pil2cv
10 | 
11 | 
12 | class TextLineDataset(Dataset):
13 |     def __init__(self,
14 |                  data_dir: str,
15 |                  label_file_list: Union[str, List[str]],
16 |                  character: str,
17 |                  in_channels: int,
18 |                  augmentation: bool = False):
19 |         self.in_channels = in_channels
20 |         self.aug = DataProcess()
21 |         self.augmentation = augmentation
22 |         self._get_image_info_list(label_file_list)
23 |         self.str2idx = dict(zip(character, range(len(character))))
24 |         self.str2idx[' '] = len(self.str2idx)
25 |         self.data_dir = data_dir
26 | 
27 |     def _get_image_info_list(self,
28 |                              file_list: Union[str, List[str]]) -> None:
29 |         if isinstance(file_list, str):
30 |             file_list = [file_list]
31 |         self.data_lines = []
32 |         for file in file_list:
33 |             with open(file, "r", encoding="utf-8") as f:
34 |                 lines = f.readlines()
35 |                 self.data_lines.extend(lines)
36 | 
37 |     def __len__(self) -> int:
38 |         return len(self.data_lines)
39 | 
40 |     def __getitem__(self, index: int) -> Dict[str, Any]:
41 |         img_name, label = self.data_lines[index].strip().split('\t')
42 |         img_path = os.path.join(self.data_dir, img_name)
43 |         image = cv2.imdecode(np.fromfile(
44 |             img_path, dtype=np.uint8), cv2.IMREAD_COLOR if self.in_channels == 3 else cv2.IMREAD_GRAYSCALE)
45 |         if self.in_channels == 3:
46 |             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
47 |         if self.augmentation:
48 |             image = pil2cv(self.aug.aug_img(cv2pil(image)))
49 |         return dict(images=image, labels=label)
50 | 
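51 | 
52 | if __name__ == '__main__':
53 |     # Minimal usage sketch; assumes gen_image.py has been run with its
54 |     # default output paths so the images and label file exist on disk.
55 |     alphabet = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
56 |     ds = TextLineDataset('D:/dataset/gen/train', 'D:/dataset/gen/train.txt',
57 |                          alphabet, in_channels=3)
58 |     sample = ds[0]
59 |     print(sample['labels'], sample['images'].shape)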
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | 
4 | import cv2
5 | 
6 | from infer_tool import RecInfer
7 | 
8 | 
9 | def main():
10 |     parser = argparse.ArgumentParser(description='MicroOCR')
11 |     parser.add_argument('--vocabulary_path', default='english.txt',
12 |                         help='vocabulary path')
13 |     parser.add_argument('--model_path',
14 |                         default='./save_model/micromlp_nh64_depth2_best_rec.pth',
15 |                         help='model path')
16 |     parser.add_argument('--nh', default=64, type=int, help='nh')
17 |     parser.add_argument('--depth', default=2, type=int, help='depth')
18 |     parser.add_argument(
19 |         '--in_channels', default=3, help='in channels', type=int)
20 |     cfg = parser.parse_args()
21 |     infer = RecInfer(cfg)
22 |     img = cv2.imread('D:/dataset/gen/test/00000.jpg',
23 |                      cv2.IMREAD_COLOR if cfg.in_channels == 3 else cv2.IMREAD_GRAYSCALE)
24 |     if cfg.in_channels == 3:
25 |         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
26 |     t0 = time.time()
27 |     out = infer.predict(img)
28 |     t1 = time.time()
29 |     print(out, '{:.2f}ms'.format((t1-t0)*1000))
30 | 
31 | 
32 | if __name__ == "__main__":
33 |     main()
34 | 
--------------------------------------------------------------------------------
/english.txt:
--------------------------------------------------------------------------------
1 | 0
2 | 1
3 | 2
4 | 3
5 | 4
6 | 5
7 | 6
8 | 7
9 | 8
10 | 9
11 | a
12 | b
13 | c
14 | d
15 | e
16 | f
17 | g
18 | h
19 | i
20 | j
21 | k
22 | l
23 | m
24 | n
25 | o
26 | p
27 | q
28 | r
29 | s
30 | t
31 | u
32 | v
33 | w
34 | x
35 | y
36 | z
37 | A
38 | B
39 | C
40 | D
41 | E
42 | F
43 | G
44 | H
45 | I
46 | J
47 | K
48 | L
49 | M
50 | N
51 | O
52 | P
53 | Q
54 | R
55 | S
56 | T
57 | U
58 | V
59 | W
60 | X
61 | Y
62 | Z
--------------------------------------------------------------------------------
/gen_image.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | 
4 | from PIL import Image
5 | from captcha.image import ImageCaptcha
6 | 
7 | alphabet = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
8 | gen = ImageCaptcha(160, 60)
9 | 
10 | 
11 | def gen_img(max_length):
12 |     """
13 |     max_length: number of characters in the captcha
14 |     """
15 |     content = [random.randrange(0, len(alphabet)) for _ in range(max_length)]
16 |     s = ''.join([alphabet[i] for i in content])
17 |     d = gen.generate(s)
18 |     img = Image.open(d)
19 |     return s, img
20 | 
21 | 
22 | def gen_dataset(img_root, count):
23 |     if not os.path.exists(img_root):
24 |         os.makedirs(img_root)
25 |     label_name = img_root + '.txt'
26 |     with open(label_name, mode='w+', encoding="utf-8") as fs:
27 |         for i in range(count):
28 |             # fixed length of 4; use random.randint(3, 8) for variable-length text
29 |             length = 4
30 |             code, img = gen_img(length)
31 |             img_name = str(i).zfill(5)+'.jpg'
32 |             img_path = img_root + '/' + img_name
33 |             fs.write(img_name+'\t'+code+"\n")
34 |             img.save(img_path)
35 | 
36 | 
37 | if __name__ == '__main__':
38 |     gen_dataset('D:/dataset/gen/train', 10000)
39 |     gen_dataset('D:/dataset/gen/test', 500)
--------------------------------------------------------------------------------
/img_aug.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import abc
3 | import random
4 | import numpy as np
5 | from PIL import Image, ImageEnhance, ImageFilter, ImageOps, ImageDraw
6 | import math
7 | 
8 | 
9 | def cv2pil(image):
10 |     """
11 |     Convert a BGR numpy image to a PIL Image
12 |     :param image: image array
13 |     :return: Image object
14 |     """
15 |     assert isinstance(image, np.ndarray), 'input image type is not cv2'
16 |     if len(image.shape) == 2:
17 |         return Image.fromarray(image)
18 |     elif len(image.shape) == 3:
19 |         return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
20 | 
21 | 
22 | def get_pil_image(image):
23 |     """
24 |     Convert any supported image to PIL format
25 |     :param image: image
26 |     :return: the image as a PIL Image
27 |     """
28 |     if isinstance(image, Image.Image):  # or isinstance(image, PIL.JpegImagePlugin.JpegImageFile):
29 |         return image
30 |     elif isinstance(image, np.ndarray):
31 |         return cv2pil(image)
32 | 
33 | 
34 | def get_cv_image(image):
35 |     """
36 |     Convert any supported image to numpy format
37 |     :param image: image
38 |     :return: the image data as an ndarray
39 |     """
40 |     if isinstance(image, np.ndarray):
41 |         return image
42 |     elif isinstance(image, Image.Image):  # or isinstance(image, PIL.JpegImagePlugin.JpegImageFile):
43 |         return pil2cv(image)
44 | 
45 | 
46 | def pil2cv(image):
47 |     """
48 |     Convert a PIL Image to an ndarray image
49 |     :param image: Image object
50 |     :return: ndarray image array
51 |     """
52 |     if len(image.split()) == 1:
53 |         return np.asarray(image)
54 |     elif len(image.split()) == 3:
55 |         return cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
56 |     elif len(image.split()) == 4:
57 |         return cv2.cvtColor(np.asarray(image), cv2.COLOR_RGBA2BGR)
58 | 
59 | 
60 | class TransBase(object):
61 |     """
62 |     Base class for data augmentations
63 |     """
64 | 
65 |     def __init__(self, probability=1.):
66 |         """
67 |         Initialize the object
68 |         :param probability: probability of applying the transform
69 |         """
70 |         super(TransBase, self).__init__()
71 |         self.probability = probability
72 | 
73 |     @abc.abstractmethod
74 |     def trans_function(self, _image):
75 |         """
76 |         The transform itself; subclasses must override this
77 |         :param _image: image to process
78 |         :return: the processed Image object
79 |         """
80 |         pass
81 | 
82 |     # @utils.zlog
83 |     def process(self, _image):
84 |         """
85 |         Apply the transform with the configured probability
86 |         :param _image: image to process
87 |         :return: the processed Image object
88 |         """
89 |         if np.random.random() < self.probability:
90 |             return self.trans_function(_image)
91 |         else:
92 |             return _image
93 | 
94 |     def __call__(self, _image):
95 |         """
96 |         Overload () so the transform can be called directly
97 |         :param _image: image to process
98 |         :return: the processed Image
99 |         """
100 |         return self.process(_image)
101 | 
102 | 
103 | 
104 | class SightTransfer(TransBase):
105 |     """
106 |     Random viewpoint (perspective) transform
107 |     """
108 | 
109 |     def setparam(self):
110 |         self.horizontal_sight_directions = ('left', 'mid', 'right')
111 |         self.vertical_sight_directions = ('up', 'mid', 'down')
112 |         self.angle_left_right = 5
113 |         self.angle_up_down = 5
114 |         self.angle_vertical = 5
115 |         self.angle_horizontal = 5
116 | 
117 |     def trans_function(self, image):
118 |         horizontal_sight_direction = self.horizontal_sight_directions[random.randint(0, 2)]
119 |         vertical_sight_direction = self.vertical_sight_directions[random.randint(0, 2)]
120 |         image = get_cv_image(image)
121 |         image = self.sight_transfer([image], horizontal_sight_direction, vertical_sight_direction)
122 |         image = image[0]
123 |         image = get_pil_image(image)
124 |         return image
125 | 
126 |     @staticmethod
127 |     def rand_reduce(val):
128 |         return int(np.random.random() * val)
129 | 
130 |     def left_right_transfer(self, img, is_left=True, angle=None):
131 |         """ Left or right viewpoint; defaults to the left viewpoint
132 |         :param img: original front-view image
133 |         :param is_left: whether to use the left viewpoint
134 |         :param angle: angle
135 |         :return:
136 |         """
137 |         if angle is None:
138 |             angle = self.angle_left_right  # self.rand_reduce(self.angle_left_right)
139 | 
140 |         shape = img.shape
141 |         size_src = (shape[1], shape[0])
142 |         # coordinates of the four corners of the source image
143 |         pts1 = np.float32([[0, 0], [0, size_src[1]], [size_src[0], 0], [size_src[0], size_src[1]]])
144 |         # compute the corner positions after the projection tilt
145 |         interval = abs(int(math.sin((float(angle) / 180) * math.pi) * shape[0]))
146 |         # coordinates of the four corners of the target image
147 |         if is_left:
148 |             pts2 = np.float32([[0, 0], [0, size_src[1]],
149 |                                [size_src[0], interval], [size_src[0], size_src[1] - interval]])
150 |         else:
151 |             pts2 = np.float32([[0, interval], [0, size_src[1] - interval],
152 |                                [size_src[0], 0], [size_src[0], size_src[1]]])
153 |         # get the 3x3 projective/perspective transform matrix
154 |         matrix = cv2.getPerspectiveTransform(pts1, pts2)
155 |         dst = cv2.warpPerspective(img, matrix, size_src)
156 |         return dst, matrix, size_src
157 | 
158 |     def up_down_transfer(self, img, is_down=True, angle=None):
159 |         """ Up or down viewpoint; defaults to the down viewpoint
160 |         :param img: original front-view image
161 |         :param is_down: whether to use the down viewpoint
162 |         :param angle: angle
163 |         :return:
164 |         """
165 |         if angle is None:
166 |             angle = self.rand_reduce(self.angle_up_down)
167 | 
168 |         shape = img.shape
169 |         size_src = (shape[1], shape[0])
170 |         # coordinates of the four corners of the source image
171 |         pts1 = np.float32([[0, 0], [0, size_src[1]], [size_src[0], 0], [size_src[0], size_src[1]]])
172 |         # compute the corner positions after the projection tilt
173 |         interval = abs(int(math.sin((float(angle) / 180) * math.pi) * shape[0]))
174 |         # coordinates of the four corners of the target image
175 |         if is_down:
176 |             pts2 = np.float32([[interval, 0], [0, size_src[1]],
177 |                                [size_src[0] - interval, 0], [size_src[0], size_src[1]]])
178 |         else:
179 |             pts2 = np.float32([[0, 0], [interval, size_src[1]],
180 |                                [size_src[0], 0], [size_src[0] - interval, size_src[1]]])
181 |         # get the 3x3 projective/perspective transform matrix
182 |         matrix = cv2.getPerspectiveTransform(pts1, pts2)
183 |         dst = cv2.warpPerspective(img, matrix, size_src)
184 |         return dst, matrix, size_src
185 | 
186 |     def vertical_tilt_transfer(self, img, is_left_high=True):
187 |         """ Vertical tilt by a random angle (up or down, at most half of self.angle_vertical)
188 |         :param img: input image as a numpy array
189 |         :param is_left_high: whether the left side is higher than the right after projection
190 |         """
191 |         angle = self.rand_reduce(self.angle_vertical)
192 | 
193 |         shape = img.shape
194 |         size_src = [shape[1], shape[0]]
195 |         # coordinates of the four corners of the source image
196 |         pts1 = np.float32([[0, 0], [0, size_src[1]], [size_src[0], 0], [size_src[0], size_src[1]]])
197 | 
198 |         # compute the offset and output shape after the vertical tilt
199 |         interval = abs(int(math.sin((float(angle) / 180) * math.pi) * shape[1]))
200 |         size_target = (int(math.cos((float(angle) / 180) * math.pi) * shape[1]), shape[0] + interval)
201 |         # coordinates of the four corners of the target image
202 |         if is_left_high:
203 |             pts2 = np.float32([[0, 0], [0, size_target[1] - interval],
204 |                                [size_target[0], interval], [size_target[0], size_target[1]]])
205 |         else:
206 |             pts2 = np.float32([[0, interval], [0, size_target[1]],
207 |                                [size_target[0], 0], [size_target[0], size_target[1] - interval]])
208 | 
209 |         # get the 3x3 projective/perspective transform matrix
210 |         matrix = cv2.getPerspectiveTransform(pts1, pts2)
211 |         dst = cv2.warpPerspective(img, matrix, size_target)
212 |         return dst, matrix, size_target
213 | 
214 |     def horizontal_tilt_transfer(self, img, is_right_tilt=True):
215 |         """ Horizontal tilt by a random angle (left or right, at most half of self.angle_horizontal)
216 |         :param img: input image as a numpy array
217 |         :param is_right_tilt: tilt direction of the projection (right or left)
218 |         """
219 |         angle = self.rand_reduce(self.angle_horizontal)
220 | 
221 |         shape = img.shape
222 |         size_src = [shape[1], shape[0]]
223 |         # coordinates of the four corners of the source image
224 |         pts1 = np.float32([[0, 0], [0, size_src[1]], [size_src[0], 0], [size_src[0], size_src[1]]])
225 | 
226 |         # compute the offset and output shape after the horizontal tilt
227 |         interval = abs(int(math.sin((float(angle) / 180) * math.pi) * shape[0]))
228 |         size_target = (shape[1] + interval, int(math.cos((float(angle) / 180) * math.pi) * shape[0]))
229 |         # coordinates of the four corners of the target image
230 |         if is_right_tilt:
231 |             pts2 = np.float32([[interval, 0], [0, size_target[1]],
232 |                                [size_target[0], 0], [size_target[0] - interval, size_target[1]]])
233 |         else:
234 |             pts2 = np.float32([[0, 0], [interval, size_target[1]],
235 |                                [size_target[0] - interval, 0], [size_target[0], size_target[1]]])
236 | 
237 |         # get the 3x3 projective/perspective transform matrix
238 |         matrix = cv2.getPerspectiveTransform(pts1, pts2)
239 |         dst = cv2.warpPerspective(img, matrix, size_target)
240 |         return dst, matrix, size_target
241 | 
242 |     def sight_transfer(self, images, horizontal_sight_direction, vertical_sight_direction):
243 |         """ Apply the viewpoint transform to a list of images
244 |         :param images: list of images
245 |         :param horizontal_sight_direction: horizontal viewpoint direction
246 |         :param vertical_sight_direction: vertical viewpoint direction
247 |         :return:
248 |         """
249 |         flag = 0
250 |         img_num = len(images)
251 |         # left/right viewpoint
252 |         if horizontal_sight_direction == 'left':
253 |             flag += 1
254 |             images[0], matrix, size = self.left_right_transfer(images[0], is_left=True)
255 |             for i in range(1, img_num):
256 |                 images[i] = cv2.warpPerspective(images[i], matrix, size)
257 |         elif horizontal_sight_direction == 'right':
258 |             flag -= 1
259 |             images[0], matrix, size = self.left_right_transfer(images[0], is_left=False)
260 |             for i in range(1, img_num):
261 |                 images[i] = cv2.warpPerspective(images[i], matrix, size)
262 |         else:
263 |             pass
264 |         # up/down viewpoint
265 |         if vertical_sight_direction == 'down':
266 |             flag += 1
267 |             images[0], matrix, size = self.up_down_transfer(images[0], is_down=True)
268 |             for i in range(1, img_num):
269 |                 images[i] = cv2.warpPerspective(images[i], matrix, size)
270 |         elif vertical_sight_direction == 'up':
271 |             flag -= 1
272 |             images[0], matrix, size = self.up_down_transfer(images[0], is_down=False)
273 |             for i in range(1, img_num):
274 |                 images[i] = cv2.warpPerspective(images[i], matrix, size)
275 |         else:
276 |             pass
277 | 
278 |         # lower-left or upper-right viewpoint
279 |         if abs(flag) == 2:
280 |             images[0], matrix, size = self.vertical_tilt_transfer(images[0], is_left_high=True)
281 |             for i in range(1, img_num):
282 |                 images[i] = cv2.warpPerspective(images[i], matrix, size)
283 | 
284 |             images[0], matrix, size = self.horizontal_tilt_transfer(images[0], is_right_tilt=True)
285 |             for i in range(1, img_num):
286 |                 images[i] = cv2.warpPerspective(images[i], matrix, size)
287 |         # upper-left or lower-right viewpoint
288 |         elif abs(flag) == 1:
289 |             images[0], matrix, size = self.vertical_tilt_transfer(images[0], is_left_high=False)
290 |             for i in range(1, img_num):
291 |                 images[i] = cv2.warpPerspective(images[i], matrix, size)
292 | 
293 |             images[0], matrix, size = self.horizontal_tilt_transfer(images[0], is_right_tilt=False)
294 |             for i in range(1, img_num):
295 |                 images[i] = cv2.warpPerspective(images[i], matrix, size)
296 |         else:
297 |             pass
298 | 
299 |         return images
300 | 
301 | 
302 | class Blur(TransBase):
303 |     """
304 |     Random Gaussian blur
305 |     """
306 | 
307 |     def setparam(self, lower=0, upper=1):
308 |         self.lower = lower
309 |         self.upper = upper
310 |         assert self.upper >= self.lower, "upper must be >= lower."
311 |         assert self.lower >= 0, "lower must be non-negative."
312 | 
313 |     def trans_function(self, image):
314 |         image = get_pil_image(image)
315 |         image = image.filter(ImageFilter.GaussianBlur(radius=random.uniform(self.lower, self.upper)))  # use the configured radius range
316 |         return image
317 | 
318 | 
319 | class MotionBlur(TransBase):
320 |     """
321 |     Random motion blur
322 |     """
323 | 
324 |     def setparam(self, degree=5, angle=180):
325 |         self.degree = degree
326 |         self.angle = angle
327 | 
328 |     def trans_function(self, image):
329 |         image = get_pil_image(image)
330 |         angle = random.randint(0, self.angle)
331 |         M = cv2.getRotationMatrix2D((self.degree / 2, self.degree / 2), angle, 1)
332 |         motion_blur_kernel = np.diag(np.ones(self.degree))
333 |         motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (self.degree, self.degree))
334 |         motion_blur_kernel = motion_blur_kernel / self.degree
335 |         image = image.filter(ImageFilter.Kernel(size=(self.degree, self.degree), kernel=motion_blur_kernel.reshape(-1)))
336 |         return image
337 | 
338 | class RandomHsv(TransBase):
339 |     def setparam(self, hue_keep=0.1, saturation_keep=0.7, value_keep=0.4):
340 |         self.hue_keep = hue_keep
341 |         self.saturation_keep = saturation_keep
342 |         self.value_keep = value_keep
343 | 
344 |     def trans_function(self, image):
345 |         image = get_cv_image(image)
346 |         hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
347 |         # hue, saturation, value
348 |         hsv[:, :, 0] = hsv[:, :, 0] * (self.hue_keep + np.random.random() * (1 - self.hue_keep))
349 |         hsv[:, :, 1] = hsv[:, :, 1] * (self.saturation_keep + np.random.random() * (1 - self.saturation_keep))
350 |         hsv[:, :, 2] = hsv[:, :, 2] * (self.value_keep + np.random.random() * (1 - self.value_keep))
351 |         image = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
352 |         image = get_pil_image(image)
353 |         return image
354 | 
355 | class Smudge(TransBase):
356 |     def setparam(self):
357 |         pass
358 | 
359 |     def trans_function(self, image):
360 |         image = get_cv_image(image)
361 |         smu = cv2.imread("smu.jpg")
362 |         rows = self.rand_reduce(smu.shape[0] - image.shape[0])
363 |         cols = self.rand_reduce(smu.shape[1] - image.shape[1])
364 |         add_smu = smu[rows:rows + image.shape[0], cols:cols + image.shape[1]]
365 |         image = cv2.bitwise_not(image)
366 |         image = cv2.bitwise_and(add_smu, image)
367 |         image = cv2.bitwise_not(image)
368 |         image = get_pil_image(image)
369 |         return image
370 | 
371 |     @staticmethod
372 |     def rand_reduce(val):
373 |         return int(np.random.random() * val)
374 | 
375 | class DataProcess:
376 |     def __init__(self):
377 |         """
378 |         Text data augmentation pipeline
379 |         """
380 |         self.sight_transfer = SightTransfer(probability=0.5)
381 |         self.blur = Blur(probability=0.3)
382 |         self.motion_blur = MotionBlur(probability=0.3)
383 |         self.rand_hsv = RandomHsv(probability=0.3)
384 |         self.sight_transfer.setparam()
385 |         self.blur.setparam()
386 |         self.motion_blur.setparam()
387 |         self.rand_hsv.setparam()
388 | 
389 | 
390 |     def aug_img(self, img):
391 |         img = self.motion_blur.process(img)
392 |         img = self.blur.process(img)
393 |         img = self.sight_transfer.process(img)
394 |         img = self.rand_hsv.process(img)
395 |         return img
396 | 
397 | 
398 | if __name__ == '__main__':
399 |     # quick manual test of the Smudge transform (requires 002.png and smu.jpg)
400 |     img = Image.open('002.png')
401 |     aa = Smudge()
402 |     aa.setparam()
403 |     img = aa.trans_function(img)
404 |     img.save('00001.jpg')
405 |     # data_augment = DataProcess()
406 |     # augmented_img = data_augment.aug_img(img)
407 |     # augmented_img.show()
--------------------------------------------------------------------------------
/infer_tool.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, List
2 | 
3 | import numpy as np
4 | import torch
5 | from torch import Tensor
6 | from torchvision import transforms
7 | 
8 | from train import build_model, load_model, build_converter
9 | from collatefn import resize_with_specific_height
10 | 
11 | 
12 | class RecInfer:
13 |     def __init__(self, cfg):
14 |         self.device = torch.device('cpu')
15 |         character = []
16 |         with open(cfg.vocabulary_path, mode='r', encoding='utf-8') as fa:
17 |             lines = fa.readlines()
18 |             for line in lines:
19 |                 character.append(line.strip())
20 |         self.converter = build_converter(character)
21 |         self.model = build_model(cfg.in_channels, cfg.nh, cfg.depth, self.converter.num_of_classes)
22 |         load_model(cfg.model_path, self.model)
23 |         self.model.to(self.device)
24 |         self.model.eval()
25 |         self.transforms = transforms.ToTensor()
26 | 
27 |     def predict(self, img: np.ndarray) -> List[Tuple[str, List[float]]]:
28 |         img = resize_with_specific_height(32, img)
29 |         tensor = self.transforms(img)
30 |         tensor = tensor.unsqueeze(dim=0)
31 |         tensor = tensor.to(self.device)
32 |         out: Tensor = self.model(tensor)
33 |         txt = self.converter.decode(out, False)
34 |         return txt
35 | 
--------------------------------------------------------------------------------
/label_converter.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 | 
3 | import torch
4 | from torch import Tensor
5 | import numpy as np
6 | 
7 | 
8 | class CTCLabelConverter(object):
9 |     def __init__(self, character: str):
10 |         list_character = list(character)
11 |         self.num_of_classes = len(character) + 2
12 |         self.idx2char = []
13 |         self.idx2char.append('_')
14 |         for line in list_character:
15 |             line = line.strip()
16 |             if line != '':
17 |                 self.idx2char.append(line)
18 |         self.idx2char.append(' ')
19 |         self.char2idx = {}
20 |         for idx, char in enumerate(self.idx2char):
21 |             self.char2idx[char] = idx
22 | 
23 |     def str2idx(self, strings):
24 |         """Convert strings to indexes.
25 |         Args:
26 |             strings (list[str]): ['hello', 'world'].
27 |         Returns:
28 |             indexes (list[int]): flattened indexes, e.g. [1,2,3,3,4,5,4,6,3,7].
29 |         """
30 |         indexes = []
31 |         for string in strings:
32 |             for char in string:
33 |                 char_idx = self.char2idx.get(char)
34 |                 if char_idx is None:
35 |                     raise Exception(f'Character: {char} not in dict')
36 |                 indexes.append(char_idx)
37 |         return indexes
38 | 
39 |     def encode(self, strings: List[str]) -> Tuple[Tensor, Tensor]:
40 |         targets_lengths = [len(s) for s in strings]
41 |         targets = self.str2idx(strings)
42 |         return torch.LongTensor(targets), torch.LongTensor(targets_lengths)
43 | 
44 |     def decode(self,
45 |                preds: Tensor,
46 |                raw: bool = False) -> List[Tuple[str, List[float]]]:
47 |         preds = preds.softmax(dim=2)
48 |         preds_score, preds_idx = preds.max(dim=2)
49 |         preds_idx = preds_idx.detach().cpu().numpy().tolist()
50 |         preds_score = preds_score.detach().cpu().numpy().tolist()
51 |         result_list = []
52 |         for word, score in zip(preds_idx, preds_score):
53 |             if raw:
54 |                 result_list.append(
55 |                     (''.join([self.idx2char[char_idx] for char_idx in word]), score))
56 |             else:
57 |                 char_list = []
58 |                 score_list = []
59 |                 for i, char_idx in enumerate(word):
60 |                     if char_idx != 0 and (not (i > 0 and word[i - 1] == char_idx)):
61 |                         char_list.append(self.idx2char[char_idx])
62 |                         score_list.append(score[i])
63 |                 result_list.append((''.join(char_list), score_list))
64 |         return result_list
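65 | 
66 | 
67 | if __name__ == '__main__':
68 |     # Minimal sketch of the CTC decoding rule: repeated indexes collapse unless
69 |     # separated by the blank (index 0), so [1, 1, 0, 1, 2] decodes to 'aab'.
70 |     import torch.nn.functional as F
71 |     converter = CTCLabelConverter('abc')
72 |     logits = F.one_hot(torch.tensor([[1, 1, 0, 1, 2]]),
73 |                        converter.num_of_classes).float()
74 |     print(converter.decode(logits))            # [('aab', [...])]
75 |     print(converter.decode(logits, raw=True))  # [('aa_ab', [...])]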
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import logging
4 | import functools
5 | 
6 | from termcolor import colored
7 | 
8 | 
9 | @functools.lru_cache()
10 | def create_logger(output_dir, dist_rank=0, name=''):
11 |     # create logger
12 |     logger = logging.getLogger(name)
13 |     logger.setLevel(logging.DEBUG)
14 |     logger.propagate = False
15 | 
16 |     # create formatter
17 |     fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
18 |     color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
19 |         colored('(%(filename)s %(lineno)d)', 'yellow') + \
20 |         ': %(levelname)s %(message)s'
21 | 
22 |     # create console handlers for master process
23 |     if dist_rank == 0:
24 |         console_handler = logging.StreamHandler(sys.stdout)
25 |         console_handler.setLevel(logging.DEBUG)
26 |         console_handler.setFormatter(
27 |             logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
28 |         logger.addHandler(console_handler)
29 | 
30 |     # create file handlers
31 |     file_handler = logging.FileHandler(os.path.join(
32 |         output_dir, f'log_rank{dist_rank}.txt'), mode='a')
33 |     file_handler.setLevel(logging.DEBUG)
34 |     file_handler.setFormatter(logging.Formatter(
35 |         fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
36 |     logger.addHandler(file_handler)
37 | 
38 |     return logger
39 | 
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | 
3 | import torch
4 | from torch import nn, Tensor
5 | from torch.nn import CTCLoss as TorchCTCLoss
6 | 
7 | 
8 | class CTCLoss(nn.Module):
9 |     def __init__(self, blank_idx: int, reduction: str = 'mean'):
10 |         super().__init__()
11 |         self.loss_func = TorchCTCLoss(
12 |             blank=blank_idx, reduction=reduction, zero_infinity=True)
13 | 
14 |     def focal_ctc_loss(self, ctc_loss, alpha=0.25, gamma=0.5):  # 0.99, 1
15 |         prob = torch.exp(-ctc_loss)
16 |         focal_loss = alpha * (1 - prob).pow(gamma) * ctc_loss
17 |         return focal_loss.mean()
18 | 
19 |     def forward(self,
20 |                 pred: Tensor,
21 |                 label: Tensor,
22 |                 label_length: Tensor) -> Dict[str, Tensor]:
23 |         pred = pred.permute(1, 0, 2)
24 |         batch_size = pred.size(1)
25 |         pred = pred.log_softmax(2)
26 |         preds_lengths = torch.LongTensor([pred.size(0)] * batch_size)
27 |         loss = self.loss_func(pred, label, preds_lengths, label_length)
28 |         # loss = self.focal_ctc_loss(loss)
29 |         return dict(loss=loss)
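30 | 
31 | 
32 | if __name__ == '__main__':
33 |     # Minimal shape check: batch of 2, 8 time steps, 5 classes (blank = 0).
34 |     # The module permutes (B, T, C) predictions to (T, B, C) for torch's CTCLoss.
35 |     loss_fn = CTCLoss(blank_idx=0)
36 |     pred = torch.randn(2, 8, 5)
37 |     label = torch.LongTensor([1, 2, 3, 2, 1])   # two targets, flattened
38 |     label_length = torch.LongTensor([3, 2])
39 |     print(loss_fn(pred, label, label_length))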
--------------------------------------------------------------------------------
/metric.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Any
2 | 
3 | from torch import Tensor
4 | 
5 | 
6 | class RecMetric(object):
7 |     def __init__(self, converter):
8 |         """
9 |         Computes text recognition accuracy metrics
10 |         :param converter: converter used to decode labels
11 |         """
12 |         self.converter = converter
13 | 
14 |     def __call__(self,
15 |                  predictions: Tensor,
16 |                  labels: List[str]) -> Dict[str, Any]:
17 |         preds_list = self.converter.decode(predictions)
18 |         raws_list = self.converter.decode(predictions, True)
19 |         word_correct = 0
20 |         char_correct = 0
21 |         show_str = []
22 | 
23 |         for (raw_str, raw_score), (pred_str, pred_score), target_str in zip(
24 |                 raws_list, preds_list, labels):
25 |             show_str.append(f'{raw_str} -> {pred_str} -> {target_str}')
26 |             if pred_str == target_str:
27 |                 word_correct += 1
28 |             for idxa, idxb in zip(pred_str, target_str):
29 |                 if idxa == idxb:
30 |                     char_correct += 1
31 |         return dict(word_correct=word_correct,
32 |                     char_correct=char_correct,
33 |                     show_str=show_str)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.functional import adaptive_avg_pool1d
4 | 
5 | 
6 | class ConvBNACT(nn.Module):
7 |     def __init__(self, in_channels, out_channels, kernel_size,
8 |                  stride=1, padding=0, groups=1):
9 |         super().__init__()
10 |         self.conv = nn.Conv2d(in_channels, out_channels,
11 |                               kernel_size, stride, padding, 1, groups, False)
12 |         self.bn = nn.BatchNorm2d(out_channels)
13 |         self.act = nn.GELU()
14 | 
15 |     def forward(self, x):
16 |         x = self.conv(x)
17 |         x = self.bn(x)
18 |         x = self.act(x)
19 |         return x
20 | 
21 | 
22 | class MicroBlock(nn.Module):
23 |     def __init__(self, nh):
24 |         super().__init__()
25 |         self.conv1 = ConvBNACT(nh, nh, 1)
26 |         self.conv2 = ConvBNACT(nh, nh, 3, 1, 1, nh)
27 | 
28 |     def forward(self, x):
29 |         x = self.conv1(x)
30 |         x = x + self.conv2(x)
31 |         return x
32 | 
33 | 
34 | class MicroStage(nn.Sequential):
35 |     def __init__(self, depth, nh):
36 |         super().__init__(*[MicroBlock(nh) for _ in range(depth)])
37 | 
38 | 
39 | class MLP(nn.Sequential):
40 |     def __init__(self, input_dim: int, hidden_dim: int):
41 |         super().__init__(
42 |             nn.Linear(input_dim, hidden_dim),
43 |             nn.GELU(),
44 |             nn.Linear(hidden_dim, input_dim),
45 |             nn.Dropout(0.5)
46 |         )
47 | 
48 | 
49 | class Residual(nn.Module):
50 |     def __init__(self, fn: nn.Module):
51 |         super().__init__()
52 |         self.fn = fn
53 | 
54 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
55 |         return x + self.fn(x)
56 | 
57 | 
58 | class MLPBlock(nn.Module):
59 |     def __init__(self, input_dim: int, hidden_dim: int):
60 |         super().__init__()
61 |         self.layer_norm1 = nn.LayerNorm(input_dim)
62 |         self.residual1 = Residual(MLP(input_dim, hidden_dim))
63 |         self.layer_norm2 = nn.LayerNorm(input_dim)
64 |         self.conv = nn.Conv2d(
65 |             input_dim, input_dim, kernel_size=3, padding=1, groups=input_dim, bias=False)
66 |         self.layer_norm3 = nn.LayerNorm(input_dim)
67 |         self.residual2 = Residual(MLP(input_dim, hidden_dim))
68 | 
69 |     def forward(self, x):
70 |         x = x.permute(0, 3, 2, 1)
71 |         x = self.layer_norm1(x)
72 |         x = self.residual1(x)
73 |         x = self.layer_norm2(x)
74 |         x = x.permute(0, 3, 2, 1)
75 |         x = self.conv(x)
76 |         x = x.permute(0, 3, 2, 1)
77 |         x = self.layer_norm3(x)
78 |         x = self.residual2(x)
79 |         x = x.permute(0, 3, 2, 1)
80 |         return x
81 | 
82 | 
83 | class MLPStage(nn.Sequential):
84 |     def __init__(self, depth, input_dim, hidden_dim):
85 |         super().__init__(*[MLPBlock(input_dim, hidden_dim)
86 |                            for _ in range(depth)])
87 | 
88 | 
89 | class Tokenizer(nn.Sequential):
90 |     def __init__(self, in_channels: int, hidden_dim: int, out_dim: int):
91 |         super().__init__(
92 |             ConvBNACT(in_channels, hidden_dim // 2, 3, 2, 1),
93 |             ConvBNACT(hidden_dim // 2, hidden_dim // 2, 3, 1, 1),
94 |             ConvBNACT(hidden_dim // 2, out_dim, 3, 1, 1),
95 |             nn.MaxPool2d(3, 2, 1)
96 |         )
97 | 
98 | 
99 | class MicroMLPNet(nn.Module):
100 |     def __init__(self, in_channels=3, nh=64, depth=2, nclass=60, img_height=32):
101 |         """
102 |         nh=512 also works
103 |         """
104 |         super().__init__()
105 |         self.embed = Tokenizer(in_channels, nh, nh)
106 |         self.micro_stages = MicroStage(depth, nh)
107 |         self.mlp_stages = MLPStage(depth, nh, nh)
108 |         self.flatten = nn.Flatten(1, 2)
109 |         self.dropout = nn.Dropout(0.5)
110 |         linear_in = nh * int(img_height//4)
111 |         self.fc = nn.Linear(linear_in, nclass)
112 | 
113 |     def forward(self, x):
114 |         x_shape = x.size()
115 |         x = self.embed(x)
116 |         x = self.micro_stages(x)
117 |         x = self.mlp_stages(x)
118 |         x = self.flatten(x)
119 |         x = adaptive_avg_pool1d(x, int(x_shape[3]/4))
120 |         x = x.permute(0, 2, 1)
121 |         x = self.dropout(x)
122 |         x = self.fc(x)
123 |         return x
124 | 
125 | 
126 | if __name__ == '__main__':
127 |     import time
128 |     x = torch.randn(1, 3, 32, 128)
129 |     model = MicroMLPNet(nh=32, depth=2, nclass=62, img_height=32)
130 |     t0 = time.time()
131 |     out = model(x)
132 |     t1 = time.time()
133 |     print(out.shape, (t1-t0)*1000)
134 |     # torch.save(model, 'test.pth')
135 |     # from torchsummaryX import summary
136 |     # summary(model, x)
137 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv_python
2 | torch
3 | torchsummary>=1.5.1
4 | tqdm
5 | torchvision
6 | numpy
7 | Pillow
8 | termcolor
9 | captcha
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | 
5 | import torch
6 | import torch.optim as optim
7 | from torch.utils.data import DataLoader
8 | 
9 | from model import MicroMLPNet
10 | from average_meter import AverageMeter
11 | from metric import RecMetric
12 | from loss import CTCLoss
13 | from dataset import TextLineDataset
14 | from collatefn import RecCollateFn
15 | from label_converter import CTCLabelConverter
16 | from logger import create_logger
17 | 
18 | 
19 | if not os.path.exists('save_model'):
20 |     os.makedirs('save_model')
21 | if not os.path.exists('log'):
22 |     os.makedirs('log')
23 | logger = create_logger('log')
24 | 
25 | 
26 | def test_model(model, device, data_loader, converter, metric, loss_func, show_str_size):
27 |     model.eval()
28 |     with torch.no_grad():
29 |         running_char_corrects, running_word_corrects, running_all_word, running_all_char = 0, 0, 0, 0
30 |         show_strs = []
31 |         since = time.time()
32 |         for batch_idx, batch_data in enumerate(data_loader):
33 |             batch_data['targets'], batch_data['targets_lengths'] = converter.encode(
34 |                 batch_data['labels'])
35 |             batch_data['images'] = batch_data['images'].to(device)
36 |             batch_data['targets'] = batch_data['targets'].to(device)
37 |             batch_data['targets_lengths'] = batch_data['targets_lengths'].to(
38 |                 device)
39 |             predicted = model.forward(
40 |                 batch_data['images'])
41 |             loss_dict = loss_func(
42 |                 predicted, batch_data['targets'], batch_data['targets_lengths'])
43 |             acc_dict = metric(predicted, batch_data['labels'])
44 |             show_strs.extend(acc_dict['show_str'])
45 |             running_char_corrects += acc_dict['char_correct']
46 |             running_word_corrects += acc_dict['word_correct']
47 |             running_all_char += torch.sum(batch_data['targets_lengths']).item()
48 |             running_all_word += len(batch_data['images'])
49 |             if (batch_idx+1) == len(data_loader):
50 |                 logger.info('Eval:[step {}/{} ({:.0f}%)] Loss:{:.4f} Word Acc:{:.4f} '
51 |                             'Char Acc:{:.4f} Cost time:{:5.0f}s'.format(
52 |                                 running_all_word,
53 |                                 len(data_loader.dataset),
54 |                                 100. * (batch_idx+1) / len(data_loader),
55 |                                 loss_dict['loss'].item(),
56 |                                 running_word_corrects / running_all_word,
57 |                                 running_char_corrects / running_all_char,
58 |                                 time.time()-since))
59 |         for s in show_strs[:show_str_size]:
60 |             logger.info(s)
61 |         model.train()
62 |         val_word_accu = running_word_corrects / \
63 |             running_all_word if running_all_word != 0 else 0.
64 |         val_char_accu = running_char_corrects / \
65 |             running_all_char if running_all_char != 0 else 0.
66 |         return val_word_accu, val_char_accu
67 | 
68 | 
69 | def train_model(cfg):
70 |     device = torch.device("cuda:{}".format(cfg.gpu_index)
71 |                           if torch.cuda.is_available() else "cpu")
72 |     with open(cfg.vocabulary_path, mode='r', encoding='utf-8') as fa:
73 |         lines = fa.readlines()
74 |         character = [line.strip() for line in lines]
75 |     train_loader = build_dataloader(
76 |         cfg.train_root, cfg.train_list, cfg.batch_size, cfg.workers, character, cfg.in_channels, is_train=True, aug=True)
77 |     test_loader = build_dataloader(
78 |         cfg.test_root, cfg.test_list, cfg.batch_size, cfg.workers, character, cfg.in_channels, is_train=False)  # no shuffling for evaluation
79 |     converter = build_converter(character)
80 |     loss_func = build_loss().to(device)
81 |     loss_average = build_average_meter()
82 |     metric = build_metric(converter)
83 |     model = build_model(
84 |         cfg.in_channels, cfg.nh, cfg.depth, converter.num_of_classes).to(device)
85 |     if cfg.model_path != '':
86 |         load_model(cfg.model_path, model)
87 |     optimizer = build_optimizer(model, cfg.lr)
88 |     scheduler = build_scheduler(optimizer)
89 |     val_word_accu, val_char_accu, best_word_accu = 0., 0., 0.
90 |     for epoch in range(cfg.epochs):
91 |         model.train()
92 |         running_char_corrects, running_word_corrects, running_all_word, running_all_char = 0, 0, 0, 0
93 |         since = time.time()
94 |         for batch_idx, batch_data in enumerate(train_loader):
95 |             batch_data['targets'], batch_data['targets_lengths'] = converter.encode(
96 |                 batch_data['labels'])
97 |             batch_data['images'] = batch_data['images'].to(device)
98 |             batch_data['targets'] = batch_data['targets'].to(device)
99 |             batch_data['targets_lengths'] = batch_data['targets_lengths'].to(
100 |                 device)
101 |             predicted = model.forward(
102 |                 batch_data['images'])
103 |             loss_dict = loss_func(
104 |                 predicted, batch_data['targets'], batch_data['targets_lengths'])
105 |             optimizer.zero_grad()
106 |             loss_dict['loss'].backward()
107 |             optimizer.step()
108 |             loss_average.update(loss_dict['loss'].item())
109 |             acc_dict = metric(predicted, batch_data['labels'])
110 |             running_char_corrects += acc_dict['char_correct']
111 |             running_word_corrects += acc_dict['word_correct']
112 |             running_all_char += torch.sum(batch_data['targets_lengths']).item()
113 |             running_all_word += len(batch_data['images'])
114 |             cost_time = time.time()-since
115 |             if (batch_idx+1) % cfg.display_step_interval == 0 or (batch_idx+1) == len(train_loader):
116 |                 logger.info('Train:[epoch {}/{}][step {}/{} ({:.0f}%)] lr:{:.5f} Loss:{:.4f} Word Acc:{:.4f} '
117 |                             'Char Acc:{:.4f} Cost time:{:5.0f}s Estimated time:{:5.0f}s'.format(
118 |                                 epoch+1,
119 |                                 cfg.epochs,
120 |                                 running_all_word//len(batch_data['images']),
121 |                                 len(train_loader.dataset)//len(
122 |                                     batch_data['images']),
123 |                                 100. * (batch_idx+1) / len(train_loader),
124 |                                 scheduler.get_last_lr()[0],
125 |                                 loss_average.avg,
126 |                                 running_word_corrects / running_all_word,
127 |                                 running_char_corrects / running_all_char,
128 |                                 cost_time,
129 |                                 cost_time*len(train_loader) / (batch_idx+1) - cost_time))
130 |             if (batch_idx+1) % cfg.eval_step_interval == 0:
131 |                 val_word_accu, val_char_accu = test_model(
132 |                     model, device, test_loader, converter, metric, loss_func, cfg.show_str_size)
133 |                 if val_word_accu > best_word_accu:
134 |                     best_word_accu = val_word_accu
135 |                     save_model(cfg.model_type, model, cfg.nh, cfg.depth, 'best',
136 |                                best_word_accu, val_char_accu)
137 |         if (epoch+1) % cfg.save_epoch_interval == 0:
138 |             val_word_accu, val_char_accu = test_model(
139 |                 model, device, test_loader, converter, metric, loss_func, cfg.show_str_size)
140 |             if val_word_accu > best_word_accu:
141 |                 best_word_accu = val_word_accu
142 |                 save_epoch = 'best'
143 |             else:
144 |                 save_epoch = epoch+1
145 |             save_model(cfg.model_type, model, cfg.nh, cfg.depth, save_epoch,
146 |                        val_word_accu, val_char_accu)
147 |         loss_average.reset()
148 |         scheduler.step()
149 | 
150 | 
151 | def build_converter(character):
152 |     return CTCLabelConverter(character)
153 | 
154 | 
155 | def build_average_meter():
156 |     return AverageMeter()
157 | 
158 | 
159 | def build_metric(converter):
160 |     return RecMetric(converter)
161 | 
162 | 
163 | def build_loss(blank_idx=0, reduction='sum'):
164 |     return CTCLoss(blank_idx, reduction)
165 | 
166 | 
167 | def build_optimizer(model, lr=0.0001):
168 |     # return optim.SGD(model.parameters(), lr, momentum=0.9, weight_decay=0.0001)
169 |     return optim.Adam(model.parameters(), lr, betas=(0.5, 0.999), weight_decay=0.0001)
170 | 
171 | 
172 | def build_dataset(data_dir, label_file_list, character, in_channels, augmentation):
173 |     return TextLineDataset(data_dir, label_file_list, character, in_channels, augmentation)
174 | 
175 | 
176 | def build_collate_fn():
177 |     return RecCollateFn(32)
178 | 
179 | 
180 | def build_dataloader(data_dir, label_file_list, batch_size,
181 |                      num_workers, character, in_channels, is_train=False, aug=False):
182 |     dataset = build_dataset(
183 |         data_dir, label_file_list, character, in_channels, aug)
184 |     collate_fn = build_collate_fn()
185 |     loader = DataLoader(dataset=dataset, batch_size=batch_size,
186 |                         collate_fn=collate_fn, shuffle=is_train,
187 |                         num_workers=num_workers)
188 |     return loader
189 | 
190 | 
191 | def save_model(model_type, model, nh, depth, epoch, word_acc, char_acc):
192 |     if epoch == 'best':
193 |         save_path = './save_model/{}_nh{}_depth{}_best_rec.pth'.format(
194 |             model_type, nh, depth)
195 |         if os.path.exists(save_path):
196 |             data = torch.load(save_path)
197 |             if 'model' in data and data['wordAcc'] > word_acc:
198 |                 return
199 |     else:
200 |         save_path = './save_model/{}_nh{}_depth{}_epoch{}_wordAcc{:05f}_charAcc{:05f}.pth'.format(
201 |             model_type, nh, depth, epoch, word_acc, char_acc)
202 |     torch.save({
203 |         'model': model.state_dict(),
204 |         'nh': nh,
205 |         'depth': depth,
206 |         'wordAcc': word_acc,
207 |         'charAcc': char_acc},
208 |         save_path)
209 |     logger.info('save model to:'+save_path)
210 | 
211 | 
212 | def load_model(model_path, model):
213 |     data = torch.load(model_path)
214 |     if 'model' in data:
215 |         model.load_state_dict(data['model'])
216 |         logger.info('Model loaded nh {}, depth {}, wordAcc {} , charAcc {}'.format(
217 |             data['nh'], data['depth'], data['wordAcc'], data['charAcc']))
218 | 
219 | 
220 | def build_scheduler(optimizer):
221 |     scheduler = optim.lr_scheduler.MultiStepLR(
222 |         optimizer, milestones=[10], gamma=0.1)
223 |     # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 20)
224 |     return scheduler
225 | 
226 | 
227 | def build_model(in_channels, nh, depth, nclass):
228 |     model = MicroMLPNet(in_channels=in_channels, nh=nh, depth=depth, nclass=nclass, img_height=32)
229 |     return model
230 | 
231 | 
232 | def main():
233 |     parser = argparse.ArgumentParser(description='MicroOCR')
234 |     parser.add_argument('--train_root', default='D:/dataset/gen/',
235 |                         help='path to train dataset dir')
236 |     parser.add_argument('--test_root', default='D:/dataset/gen/',
237 |                         help='path to test dataset dir')
238 |     parser.add_argument(
239 |         '--train_list', default='D:/dataset/gen/train.txt', help='path to train dataset label file')
240 |     parser.add_argument(
241 |         '--test_list', default='D:/dataset/gen/test.txt', help='path to test dataset label file')
242 |     parser.add_argument('--vocabulary_path', default='english.txt',
243 |                         help='vocabulary path')
244 |     parser.add_argument('--model_path', default='',
245 |                         help='model path')
246 |     parser.add_argument('--model_type', default='micromlp',
247 |                         help='model type', type=str)
248 |     parser.add_argument(
249 |         '--nh', default=256, help='feature width; use a larger value for more complex image backgrounds', type=int)
250 |     parser.add_argument(
251 |         '--depth', default=2, help='network depth; use a larger value for larger training sets', type=int)
252 |     parser.add_argument(
253 |         '--in_channels', default=3, help='in channels', type=int)
254 |     parser.add_argument('--lr', default=0.001,
255 |                         help='initial learning rate', type=float)
256 |     parser.add_argument('--batch_size', default=8, type=int,
257 |                         help='batch size')
258 |     parser.add_argument('--workers', default=0,
259 |                         help='number of data loading workers', type=int)
260 |     parser.add_argument('--epochs', default=20,
261 |                         help='number of total epochs', type=int)
262 |     parser.add_argument('--display_step_interval', default=50,
263 |                         help='display step interval', type=int)
264 |     parser.add_argument('--eval_step_interval', default=500,
265 |                         help='eval step interval', type=int)
266 |     parser.add_argument('--save_epoch_interval', default=1,
267 |                         help='save checkpoint epoch interval', type=int)
268 |     parser.add_argument('--show_str_size', default=10,
269 |                         help='show str size', type=int)
270 |     parser.add_argument('--gpu_index', default=0, type=int,
271 |                         help='gpu index')
272 |     cfg = parser.parse_args()
273 | 
274 |     train_model(cfg)
275 | 
276 | 
277 | if __name__ == '__main__':
278 |     main()
279 | 
--------------------------------------------------------------------------------