├── README.md ├── config.yaml ├── data ├── black_plate.py ├── blue_plate.py ├── draw.py ├── green_plate.py ├── make_label.py ├── random_draw.py ├── random_plate.py └── yellow_plate.py ├── detect_explorer.py ├── detect_train.py ├── eval.py ├── models ├── detection_nn.py └── ocr_nn.py ├── ocr_config.py ├── ocr_explorer.py ├── ocr_test.py ├── ocr_train.py └── utils ├── data_augmentation.py ├── dataset.py ├── loss.py └── smu.png /README.md: -------------------------------------------------------------------------------- 1 | # :tent: PyTorch ResNet Transformer OCR 2 | 🌐 This is my personal project on Object Detection and OCR Recognition using Deep Learning Neural Networks. 3 | The project involves my personal practices on different architectures, loss functions, training techniques, data augmentation techniques, and more. 4 | 5 | 6 | 7 | ## Models implemented: 8 | - ResNet 9 | - Multi-Head Self Attention Transformer 10 | - W-Pod Net 11 | 12 | ## Framework used: 13 | - PyTorch 14 | 15 | ## Dataset: 16 | - CCPD 17 | 18 | ## Papers consulted: 19 | - [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/pdf/2109.10282.pdf) 20 | - [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385.pdf) 21 | 22 | 23 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | # **************************************************************************** # 2 | # # 3 | # ::: :::::::: # 4 | # config.yaml :+: :+: :+: # 5 | # +:+ +:+ +:+ # 6 | # By: peterli +#+ +:+ +#+ # 7 | # +#+#+#+#+#+ +#+ # 8 | # Created: 2023/04/18 17:18:05 by peterli #+# #+# # 9 | # Updated: 2023/04/18 17:20:02 by peterli ### ########.fr # 10 | # # 11 | # **************************************************************************** # 12 | 13 | # Detection and OCR Configurations 14 | # Author: Peter Li 15 | 16 | --- 17 | 18 | Data 
Config: 19 | train data path: "data/CCPD2019/ccpd_base" 20 | 21 | Train Config: 22 | - batch size: 32 23 | - max epoch: 100 24 | - net: ResNet 25 | - device: CUDA 26 | 27 | Test Config: 28 | - confidence threshold: 0.90 -------------------------------------------------------------------------------- /data/black_plate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from PIL import Image, ImageFont, ImageDraw 5 | 6 | 7 | class Draw: 8 | # Font 9 | _font = [ 10 | ImageFont.truetype( 11 | os.path.join(os.path.dirname(__file__), "res/eng_92.ttf"), 126 12 | ), # 文件路径 13 | ImageFont.truetype( 14 | os.path.join(os.path.dirname(__file__), "res/zh_cn_92.ttf"), 95 15 | ), 16 | ] 17 | # Background 18 | _bg = cv2.resize( 19 | cv2.imread(os.path.join(os.path.dirname(__file__), "res/black_bg.png")), 20 | (440, 140), 21 | ) 22 | 23 | def __call__(self, plate): 24 | if len(plate) != 7: 25 | print("ERROR: Invalid length") 26 | return None 27 | fg = self._draw_fg(plate) 28 | return cv2.cvtColor(cv2.bitwise_or(fg, self._bg), cv2.COLOR_BGR2RGB) 29 | 30 | def _draw_char(self, ch): 31 | # Image.new(mode, size, color) 32 | img = Image.new( 33 | "RGB", (45 if ch.isupper() or ch.isdigit() else 95, 140), (0, 0, 0) 34 | ) 35 | draw = ImageDraw.Draw(img) 36 | # draw.text(position, string, options) 37 | draw.text( 38 | (0, -11 if ch.isupper() or ch.isdigit() else 3), 39 | ch, 40 | fill=(255, 255, 255), 41 | font=self._font[0 if ch.isupper() or ch.isdigit() else 1], 42 | ) 43 | if img.width > 45: 44 | img = img.resize((45, 140)) 45 | return np.array(img) 46 | 47 | def _draw_fg(self, plate): 48 | img = np.array(Image.new("RGB", (440, 140), (0, 0, 0))) 49 | offset = 15 50 | # [0:140, 15:60] 51 | img[0:140, offset : offset + 45] = self._draw_char(plate[0]) 52 | offset = offset + 45 + 12 # 72 53 | # [0:140, 72:117] 54 | img[0:140, offset : offset + 45] = self._draw_char(plate[1]) 55 | offset = offset + 45 + 
34 # 151 56 | # [0:140,] 57 | for i in range(2, len(plate)): 58 | img[0:140, offset : offset + 45] = self._draw_char(plate[i]) 59 | offset = offset + 45 + 12 60 | return img 61 | 62 | 63 | if __name__ == "__main__": 64 | import argparse 65 | import matplotlib.pyplot as plt 66 | 67 | parser = argparse.ArgumentParser(description="Generate a black plate.") 68 | parser.add_argument( 69 | "plate", 70 | help="license plate number (default: 京A12345)", 71 | type=str, 72 | nargs="?", 73 | default="粤B45R5V", 74 | ) 75 | args = parser.parse_args() 76 | 77 | draw = Draw() 78 | plate = draw(args.plate) 79 | plt.imshow(plate) 80 | plt.show() 81 | -------------------------------------------------------------------------------- /data/blue_plate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from PIL import Image, ImageFont, ImageDraw 5 | 6 | 7 | class Draw: 8 | _font = [ 9 | ImageFont.truetype(os.path.join(os.path.dirname(__file__), "res/eng_92.ttf"), 126), 10 | ImageFont.truetype(os.path.join(os.path.dirname(__file__), "res/zh_cn_92.ttf"), 95) 11 | ] 12 | _bg = cv2.resize(cv2.imread(os.path.join(os.path.dirname(__file__), "res/blue_bg.png")), (440, 140)) 13 | 14 | def __call__(self, plate): 15 | if len(plate) != 7: 16 | print("ERROR: Invalid length") 17 | return None 18 | fg = self._draw_fg(plate) 19 | return cv2.cvtColor(cv2.bitwise_or(fg, self._bg), cv2.COLOR_BGR2RGB) 20 | 21 | def _draw_char(self, ch): 22 | img = Image.new("RGB", (45 if ch.isupper() or ch.isdigit() else 95, 140), (0, 0, 0)) 23 | draw = ImageDraw.Draw(img) 24 | draw.text( 25 | (0, -11 if ch.isupper() or ch.isdigit() else 3), ch, 26 | fill = (255, 255, 255), 27 | font = self._font[0 if ch.isupper() or ch.isdigit() else 1] 28 | ) 29 | if img.width > 45: 30 | img = img.resize((45, 140)) 31 | return np.array(img) 32 | 33 | def _draw_fg(self, plate): 34 | img = np.array(Image.new("RGB", (440, 140), (0, 0, 0))) 35 | offset = 
15 36 | img[0:140, offset:offset+45] = self._draw_char(plate[0]) 37 | offset = offset + 45 + 12 38 | img[0:140, offset:offset+45] = self._draw_char(plate[1]) 39 | offset = offset + 45 + 34 40 | for i in range(2, len(plate)): 41 | img[0:140, offset:offset+45] = self._draw_char(plate[i]) 42 | offset = offset + 45 + 12 43 | return img 44 | 45 | 46 | if __name__ == "__main__": 47 | import argparse 48 | import matplotlib.pyplot as plt 49 | 50 | parser = argparse.ArgumentParser(description="Generate a blue plate.") 51 | parser.add_argument("plate", help="license plate number (default: 京A12345)", type=str, nargs="?", default="京A12345") 52 | args = parser.parse_args() 53 | 54 | draw = Draw() 55 | plate = draw(args.plate) 56 | plt.imshow(plate) 57 | plt.show() 58 | -------------------------------------------------------------------------------- /data/draw.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from PIL import Image, ImageFont, ImageDraw 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def load_font(): 9 | return { 10 | "京": cv2.imread("res/ne000.png"), "津": cv2.imread("res/ne001.png"), "冀": cv2.imread("res/ne002.png"), 11 | "晋": cv2.imread("res/ne003.png"),"蒙": cv2.imread("res/ne004.png"),"辽": cv2.imread("res/ne005.png"), 12 | "吉": cv2.imread("res/ne006.png"),"黑": cv2.imread("res/ne007.png"),"沪": cv2.imread("res/ne008.png"), 13 | "苏": cv2.imread("res/ne009.png"),"浙": cv2.imread("res/ne010.png"),"皖": cv2.imread("res/ne011.png"), 14 | "闽": cv2.imread("res/ne012.png"),"赣": cv2.imread("res/ne013.png"),"鲁": cv2.imread("res/ne014.png"), 15 | "豫": cv2.imread("res/ne015.png"),"鄂": cv2.imread("res/ne016.png"),"湘": cv2.imread("res/ne017.png"), 16 | "粤": cv2.imread("res/ne018.png"),"桂": cv2.imread("res/ne019.png"),"琼": cv2.imread("res/ne020.png"), 17 | "渝": cv2.imread("res/ne021.png"),"川": cv2.imread("res/ne022.png"),"贵": cv2.imread("res/ne023.png"), 18 | "云": 
cv2.imread("res/ne024.png"),"藏": cv2.imread("res/ne025.png"),"陕": cv2.imread("res/ne026.png"), 19 | "甘": cv2.imread("res/ne027.png"),"青": cv2.imread("res/ne028.png"),"宁": cv2.imread("res/ne029.png"), 20 | "新": cv2.imread("res/ne030.png"),"A": cv2.imread("res/ne100.png"),"B": cv2.imread("res/ne101.png"), 21 | "C": cv2.imread("res/ne102.png"),"D": cv2.imread("res/ne103.png"),"E": cv2.imread("res/ne104.png"), 22 | "F": cv2.imread("res/ne105.png"),"G": cv2.imread("res/ne106.png"),"H": cv2.imread("res/ne107.png"), 23 | "J": cv2.imread("res/ne108.png"),"K": cv2.imread("res/ne109.png"),"L": cv2.imread("res/ne110.png"), 24 | "M": cv2.imread("res/ne111.png"),"N": cv2.imread("res/ne112.png"),"P": cv2.imread("res/ne113.png"), 25 | "Q": cv2.imread("res/ne114.png"),"R": cv2.imread("res/ne115.png"),"S": cv2.imread("res/ne116.png"), 26 | "T": cv2.imread("res/ne117.png"),"U": cv2.imread("res/ne118.png"),"V": cv2.imread("res/ne119.png"), 27 | "W": cv2.imread("res/ne120.png"),"X": cv2.imread("res/ne121.png"),"Y": cv2.imread("res/ne122.png"), 28 | "Z": cv2.imread("res/ne123.png"),"0": cv2.imread("res/ne124.png"),"1": cv2.imread("res/ne125.png"), 29 | "2": cv2.imread("res/ne126.png"),"3": cv2.imread("res/ne127.png"),"4": cv2.imread("res/ne128.png"), 30 | "5": cv2.imread("res/ne129.png"),"6": cv2.imread("res/ne130.png"),"7": cv2.imread("res/ne131.png"), 31 | "8": cv2.imread("res/ne132.png"),"9": cv2.imread("res/ne133.png") 32 | } 33 | 34 | 35 | class Draw: 36 | def __init__(self, bg): 37 | self._font = [ 38 | ImageFont.truetype(os.path.join(os.path.dirname(__file__), "res/eng_92.ttf"), 126), # 文件路径 39 | ImageFont.truetype(os.path.join(os.path.dirname(__file__), "res/zh_cn_92.ttf"), 95) 40 | ] if bg not in ["green_0", "green_1"] else load_font() 41 | self.bg = bg 42 | self.size = (440, 140) 43 | self.color = [] 44 | self._bg = None 45 | self.plane = (0, 140, 45, 9, 34) 46 | if bg == "black": 47 | self.color = [(0, 0, 0), (255, 255, 255)] 48 | self._bg = 
cv2.resize(cv2.imread(os.path.join(os.path.dirname(__file__), "res/black_bg.png")), self.size) # resolve backgrounds next to this file, like the font paths above 49 | elif bg == "blue": 50 | self.color = [(0, 0, 0), (255, 255, 255)] 51 | self._bg = cv2.resize(cv2.imread(os.path.join(os.path.dirname(__file__), "res/blue_bg.png")), self.size) 52 | elif bg == "green_0": 53 | self.city = 43 54 | self.size = (480, 140) 55 | self.plane = (25, 115, 43, 9, 49) 56 | self.color = [(255, 255, 255), (0, 0, 0)] 57 | self._bg = cv2.resize(cv2.imread(os.path.join(os.path.dirname(__file__), "res/green_bg_0.png")), self.size) 58 | elif bg == "green_1": 59 | self.city = 43 60 | self.size = (480, 140) 61 | self.plane = (25, 115, 43, 9, 49) 62 | self.color = [(255, 255, 255), (0, 0, 0)] 63 | self._bg = cv2.resize(cv2.imread(os.path.join(os.path.dirname(__file__), "res/green_bg_1.png")), self.size) 64 | elif bg == "yellow": 65 | self.color = [(255, 255, 255), (0, 0, 0)] 66 | self._bg = cv2.resize(cv2.imread(os.path.join(os.path.dirname(__file__), "res/yellow_bg.png")), self.size) 67 | 68 | def __call__(self, car_num): 69 | assert len(car_num) in (7, 8), "Error, car number length must be 7 or 8! but got {}->{}".format(car_num,len(car_num)) # message is the assertion payload; print() would have made it None 70 | fg = self._draw_fg(car_num) 71 | if self.bg in ["black", "blue"]: 72 | fg = cv2.bitwise_or(fg, self._bg) 73 | elif self.bg in ["yellow", "green_0", "green_1"]: 74 | fg = cv2.bitwise_and(fg, self._bg) 75 | return cv2.cvtColor(fg, cv2.COLOR_BGR2RGB) 76 | 77 | def _draw_ch(self, ch): 78 | if self.bg in ["green_0", "green_1"]: 79 | return cv2.resize(self._font[ch], (43 if ch.isupper() or ch.isdigit() else 45, 90)) 80 | img = Image.new("RGB", (45 if ch.isupper() or ch.isdigit() else 90, 140), self.color[0]) 81 | draw = ImageDraw.Draw(img) 82 | draw.text( 83 | (0, -11 if ch.isupper() or ch.isdigit() else 3), ch, 84 | fill=self.color[1], 85 | font=self._font[0] if ch.isupper() or ch.isdigit() else self._font[1] 86 | ) 87 | if img.width > 45: 88 | img = img.resize((45, 140)) 89 | return np.array(img) 90 | 91 | def _draw_fg(self, car_num): 92 | img = np.array(Image.new("RGB", self.size, self.color[0])) 93 | offset = 15 94 | img[self.plane[0]:self.plane[1], offset:offset + 45] = self._draw_ch(car_num[0]) 
95 | offset = offset + 45 + self.plane[3] 96 | img[self.plane[0]:self.plane[1], offset:offset + self.plane[2]] = self._draw_ch(car_num[1]) 97 | offset = offset + self.plane[2] + self.plane[4] 98 | for i in range(2, len(car_num)): 99 | img[self.plane[0]:self.plane[1], offset:offset + self.plane[2]] = self._draw_ch(car_num[i]) 100 | offset = offset + self.plane[2] + self.plane[3] 101 | return img 102 | 103 | 104 | if __name__ == '__main__': 105 | draw = Draw("yellow") 106 | plate = draw("京A12345") 107 | plt.imshow(plate) 108 | plt.show() 109 | -------------------------------------------------------------------------------- /data/green_plate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from PIL import Image, ImageDraw 5 | 6 | 7 | def load_font(): 8 | return { 9 | "京": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne000.png")), 10 | "津": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne001.png")), 11 | "冀": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne002.png")), 12 | "晋": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne003.png")), 13 | "蒙": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne004.png")), 14 | "辽": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne005.png")), 15 | "吉": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne006.png")), 16 | "黑": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne007.png")), 17 | "沪": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne008.png")), 18 | "苏": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne009.png")), 19 | "浙": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne010.png")), 20 | "皖": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne011.png")), 21 | "闽": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne012.png")), 22 | "赣": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne013.png")), 23 | "鲁": 
cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne014.png")), 24 | "豫": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne015.png")), 25 | "鄂": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne016.png")), 26 | "湘": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne017.png")), 27 | "粤": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne018.png")), 28 | "桂": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne019.png")), 29 | "琼": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne020.png")), 30 | "渝": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne021.png")), 31 | "川": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne022.png")), 32 | "贵": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne023.png")), 33 | "云": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne024.png")), 34 | "藏": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne025.png")), 35 | "陕": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne026.png")), 36 | "甘": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne027.png")), 37 | "青": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne028.png")), 38 | "宁": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne029.png")), 39 | "新": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne030.png")), 40 | "A": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne100.png")), 41 | "B": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne101.png")), 42 | "C": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne102.png")), 43 | "D": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne103.png")), 44 | "E": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne104.png")), 45 | "F": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne105.png")), 46 | "G": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne106.png")), 47 | "H": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne107.png")), 48 | "J": 
cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne108.png")), 49 | "K": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne109.png")), 50 | "L": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne110.png")), 51 | "M": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne111.png")), 52 | "N": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne112.png")), 53 | "P": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne113.png")), 54 | "Q": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne114.png")), 55 | "R": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne115.png")), 56 | "S": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne116.png")), 57 | "T": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne117.png")), 58 | "U": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne118.png")), 59 | "V": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne119.png")), 60 | "W": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne120.png")), 61 | "X": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne121.png")), 62 | "Y": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne122.png")), 63 | "Z": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne123.png")), 64 | "0": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne124.png")), 65 | "1": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne125.png")), 66 | "2": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne126.png")), 67 | "3": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne127.png")), 68 | "4": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne128.png")), 69 | "5": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne129.png")), 70 | "6": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne130.png")), 71 | "7": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne131.png")), 72 | "8": cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne132.png")), 73 | "9": 
cv2.imread(os.path.join(os.path.dirname(__file__), "res/ne133.png")) 74 | } 75 | 76 | 77 | class Draw: 78 | _font = load_font() 79 | _bg = [ 80 | cv2.resize(cv2.imread(os.path.join(os.path.dirname(__file__), "res/green_bg_0.png")), (480, 140)), 81 | cv2.resize(cv2.imread(os.path.join(os.path.dirname(__file__), "res/green_bg_1.png")), (480, 140)) 82 | ] 83 | 84 | def __call__(self, plate, bg=0): 85 | if len(plate) != 8: 86 | print("ERROR: Invalid length") 87 | return None 88 | fg = self._draw_fg(plate) 89 | return cv2.cvtColor(cv2.bitwise_and(fg, self._bg[bg]), cv2.COLOR_BGR2RGB) 90 | 91 | def _draw_char(self, ch): 92 | return cv2.resize(self._font[ch], (43 if ch.isupper() or ch.isdigit() else 45, 90)) 93 | 94 | def _draw_fg(self, plate): 95 | img = np.array(Image.new("RGB", (480, 140), (255, 255, 255))) 96 | offset = 15 97 | img[25:115, offset:offset+45] = self._draw_char(plate[0]) 98 | offset = offset + 45 + 9 99 | img[25:115, offset:offset+43] = self._draw_char(plate[1]) 100 | offset = offset + 43 + 49 101 | for i in range(2, len(plate)): 102 | img[25:115, offset:offset+43] = self._draw_char(plate[i]) 103 | offset = offset + 43 + 9 104 | return img 105 | 106 | 107 | if __name__ == "__main__": 108 | import argparse 109 | import matplotlib.pyplot as plt 110 | 111 | parser = argparse.ArgumentParser(description="Generate a green plate.") 112 | parser.add_argument("--background", help="set the backgrond index (default: 0)", type=int, default=0) 113 | parser.add_argument("plate", help="license plate number (default: 京AD12345)", type=str, nargs="?", default="京AD12345") 114 | args = parser.parse_args() 115 | 116 | draw = Draw() 117 | plate = draw(args.plate, args.background) 118 | plt.imshow(plate) 119 | plt.show() 120 | -------------------------------------------------------------------------------- /data/make_label.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | 4 | def point_in_polygon(x, y, pts): 5 | n = 
len(pts) // 2 6 | pts_x = [pts[i] for i in range(0, n)] 7 | pts_y = [pts[i] for i in range(n, len(pts))] 8 | if not min(pts_x) <= x <= max(pts_x) or not min(pts_y) <= y <= max(pts_y): 9 | return False 10 | res = False 11 | for i in range(n): 12 | j = n - 1 if i == 0 else i - 1 13 | if ((pts_y[i] > y) != (pts_y[j] > y)) and (x < (pts_x[j] - pts_x[i]) * (y - pts_y[i]) / (pts_y[j] - pts_y[i]) + pts_x[i]): 14 | res = not res 15 | return res 16 | 17 | 18 | def object_label(points, dims, stride):# pts,208,16 19 | scale = ((dims + 40.0) / 2.0) / stride 20 | size = dims // stride 21 | label = numpy.zeros((size, size, 9)) 22 | for i in range(size): 23 | y = (i + 0.5) / size 24 | for j in range(size): 25 | x = (j + 0.5) / size 26 | if point_in_polygon(x, y, points): 27 | label[i, j, 0] = 1 28 | pts = numpy.array(points).reshape((2, -1)) 29 | pts = pts * dims / stride 30 | pts -= numpy.array([[j + 0.5], [i + 0.5]]) 31 | pts = pts / scale 32 | label[i, j, 1:] = pts.reshape((-1,)) 33 | return label -------------------------------------------------------------------------------- /data/random_draw.py: -------------------------------------------------------------------------------- 1 | from fake_chs_lp.draw import Draw 2 | import random 3 | import math 4 | import matplotlib.pyplot as plt 5 | import argparse 6 | 7 | 8 | class RandomDraw: 9 | def __init__(self): 10 | self._draw = ["black", "blue", "green", "yellow"] 11 | self._province = ["皖", "沪", "津", "渝", "冀", "晋", "蒙", "辽", "吉", "黑", "苏", "浙", "京", "闽", "赣", "鲁", "豫", "鄂", "湘", 12 | "粤", "桂", "琼", "川", "贵", "云", "藏", "陕", "甘", "青", "宁", "新"] 13 | self._alpha = ["A", "B", "C", "D", "E", "F", "G", "H", "J", "K", "L", "M", "N", "P", "Q", "R", "S", 14 | "T", "U", "V", "W", "X", "Y", "Z"] 15 | self._ads = ["A", "B", "C", "D", "E", "F", "G", "H", "J", "K", "L", "M", "N", "P", "Q", "R", "S", 16 | "T", "U", "V", "W", "X", "Y", "Z", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9"] 17 | 18 | def __call__(self): 19 | draw = 
random.choice(self._draw) 20 | candidates = [self._province, self._alpha] 21 | if draw == "green": 22 | candidates += [self._ads] * 6 23 | label = "".join([random.choice(c) for c in candidates]) 24 | draw = draw + "_" + str(random.choice([0,1])) 25 | d = Draw(draw) 26 | return d(label), label 27 | elif draw == "black": 28 | if random.random() < 0.3: 29 | candidates += [self._ads] * 4 30 | candidates += [["港", "澳"]] 31 | else: 32 | candidates += [self._ads] * 5 33 | label = "".join([random.choice(c) for c in candidates]) 34 | d = Draw(draw) 35 | return d(label), label 36 | elif draw == "yellow": 37 | if random.random() < 0.3: 38 | candidates += [self._ads] * 4 39 | candidates += [["学"]] 40 | else: 41 | candidates += [self._ads] * 5 42 | label = "".join([random.choice(c) for c in candidates]) 43 | d = Draw(draw) 44 | return d(label), label 45 | else: 46 | candidates += [self._ads] * 5 47 | label = "".join([random.choice(c) for c in candidates]) 48 | d = Draw(draw) 49 | return d(label), label 50 | 51 | if __name__ == '__main__': 52 | parser = argparse.ArgumentParser(description="The input of car numbers") 53 | parser.add_argument("--num", help="numbers to random generate plates", default=9, type=int) 54 | args = parser.parse_args() 55 | 56 | random_draw = RandomDraw() 57 | rows = math.ceil(args.num / 3) 58 | cols = min(args.num, 3) 59 | for i in range(args.num): 60 | plate, label = random_draw() 61 | print(label) 62 | plt.subplot(rows, cols, i+1) 63 | plt.imshow(plate) 64 | plt.axis("off") 65 | plt.show() 66 | -------------------------------------------------------------------------------- /data/random_plate.py: -------------------------------------------------------------------------------- 1 | import random 2 | if __name__ == "__main__": 3 | import black_plate 4 | import blue_plate 5 | import yellow_plate 6 | import green_plate 7 | else: 8 | from . import black_plate 9 | from . import blue_plate 10 | from . import yellow_plate 11 | from . 
import green_plate 12 | 13 | 14 | class Draw: 15 | _draw = [ 16 | black_plate.Draw(), 17 | blue_plate.Draw(), 18 | yellow_plate.Draw(), 19 | green_plate.Draw() 20 | ] 21 | _provinces = ["皖", "沪", "津", "渝", "冀", "晋", "蒙", "辽", "吉", "黑", "苏", "浙", "京", "闽", "赣", "鲁", "豫", "鄂", "湘", "粤", "桂", "琼", "川", "贵", "云", "藏", "陕", "甘", "青", "宁", "新"] 22 | _alphabets = ["A", "B", "C", "D", "E", "F", "G", "H", "J", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"] 23 | _ads = ["A", "B", "C", "D", "E", "F", "G", "H", "J", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9"] 24 | 25 | def __call__(self): 26 | draw = random.choice(self._draw) 27 | candidates = [self._provinces, self._alphabets] 28 | if type(draw) == green_plate.Draw: 29 | candidates += [self._ads] * 6 30 | label = "".join([random.choice(c) for c in candidates]) 31 | return draw(label, random.randint(0, 1)), label 32 | elif type(draw) == black_plate.Draw: 33 | if random.random() < 0.3: 34 | candidates += [self._ads] * 4 35 | candidates += [["港", "澳"]] 36 | else: 37 | candidates += [self._ads] * 5 38 | label = "".join([random.choice(c) for c in candidates]) 39 | return draw(label), label 40 | elif type(draw) == yellow_plate.Draw: 41 | if random.random() < 0.3: 42 | candidates += [self._ads] * 4 43 | candidates += [["学"]] 44 | else: 45 | candidates += [self._ads] * 5 46 | label = "".join([random.choice(c) for c in candidates]) 47 | return draw(label), label 48 | else: 49 | candidates += [self._ads] * 5 50 | label = "".join([random.choice(c) for c in candidates]) 51 | return draw(label), label 52 | 53 | 54 | if __name__ == "__main__": 55 | import math 56 | import argparse 57 | import matplotlib.pyplot as plt 58 | 59 | parser = argparse.ArgumentParser(description="Generate a green plate.") 60 | parser.add_argument("--num", help="set the number of plates (default: 9)", type=int, default=9) 61 | args = parser.parse_args() 62 | 
63 | draw = Draw() 64 | rows = math.ceil(args.num / 3) 65 | cols = min(args.num, 3) 66 | for i in range(args.num): 67 | plate, label = draw() 68 | print(plate.shape) 69 | print(label) 70 | plt.subplot(rows, cols, i + 1) 71 | plt.imshow(plate) 72 | plt.axis("off") 73 | plt.show() 74 | -------------------------------------------------------------------------------- /data/yellow_plate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from PIL import Image, ImageFont, ImageDraw 5 | 6 | 7 | class Draw: 8 | _font = [ 9 | ImageFont.truetype(os.path.join(os.path.dirname(__file__), "res/eng_92.ttf"), 126), 10 | ImageFont.truetype(os.path.join(os.path.dirname(__file__), "res/zh_cn_92.ttf"), 95) 11 | ] 12 | _bg = cv2.resize(cv2.imread(os.path.join(os.path.dirname(__file__), "res/yellow_bg.png")), (440, 140)) 13 | 14 | def __call__(self, plate): 15 | if len(plate) != 7: 16 | print("ERROR: Invalid length") 17 | return None 18 | fg = self._draw_fg(plate) 19 | return cv2.cvtColor(cv2.bitwise_and(fg, self._bg), cv2.COLOR_BGR2RGB) 20 | 21 | def _draw_char(self, ch): 22 | img = Image.new("RGB", (45 if ch.isupper() or ch.isdigit() else 95, 140), (255, 255, 255)) 23 | draw = ImageDraw.Draw(img) 24 | draw.text( 25 | (0, -11 if ch.isupper() or ch.isdigit() else 3), ch, 26 | fill = (0, 0, 0), 27 | font = self._font[0 if ch.isupper() or ch.isdigit() else 1] 28 | ) 29 | if img.width > 45: 30 | img = img.resize((45, 140)) 31 | return np.array(img) 32 | 33 | def _draw_fg(self, plate): 34 | img = np.array(Image.new("RGB", (440, 140), (255, 255, 255))) 35 | offset = 15 36 | img[0:140, offset:offset+45] = self._draw_char(plate[0]) 37 | offset = offset + 45 + 12 38 | img[0:140, offset:offset+45] = self._draw_char(plate[1]) 39 | offset = offset + 45 + 34 40 | for i in range(2, len(plate)): 41 | img[0:140, offset:offset+45] = self._draw_char(plate[i]) 42 | offset = offset + 45 + 12 43 | return img 44 | 45 | 
46 | if __name__ == "__main__": 47 | import argparse 48 | import matplotlib.pyplot as plt 49 | 50 | parser = argparse.ArgumentParser(description="Generate a yellow plate.") 51 | parser.add_argument("plate", help="license plate number (default: 京A12345)", type=str, nargs="?", default="京A12345") 52 | args = parser.parse_args() 53 | 54 | draw = Draw() 55 | plate = draw(args.plate) 56 | plt.imshow(plate) 57 | plt.show() 58 | -------------------------------------------------------------------------------- /detect_explorer.py: -------------------------------------------------------------------------------- 1 | # **************************************************************************** # 2 | # # 3 | # ::: :::::::: # 4 | # detect_explorer.py :+: :+: :+: # 5 | # +:+ +:+ +:+ # 6 | # By: peterli +#+ +:+ +#+ # 7 | # +#+#+#+#+#+ +#+ # 8 | # Created: 2023/04/18 17:29:13 by peterli #+# #+# # 9 | # Updated: 2023/04/18 17:29:14 by peterli ### ########.fr # 10 | # # 11 | # **************************************************************************** # 12 | 13 | import config as config 14 | import cv2 15 | import torch 16 | from einops import rearrange 17 | import os 18 | import numpy 19 | 20 | 21 | class DExplorer: 22 | 23 | def __init__(self): 24 | self.net = config.net() 25 | if os.path.exists(config.weight): 26 | self.net.load_state_dict(torch.load(config.weight, map_location='cpu')) 27 | else: 28 | raise RuntimeError('Model parameters are not loaded') 29 | # self.net.to(config.device) 30 | self.net.eval() 31 | 32 | def __call__(self, image_o): 33 | image = image_o.copy() 34 | h, w, c = image.shape 35 | f = min(288 * max(h, w) / min(h, w), 608) / min(h, w) 36 | _w = int(w * f) + (0 if w % 16 == 0 else 16 - w % 16) 37 | _h = int(h * f) + (0 if h % 16 == 0 else 16 - h % 16) 38 | image = cv2.resize(image, (_w, _h), interpolation=cv2.INTER_AREA) 39 | image_tensor = torch.from_numpy(image) / 255 40 | image_tensor = rearrange(image_tensor, 'h w c ->() c h w') 41 | with 
torch.no_grad(): 42 | y = self.net(image_tensor).cpu() 43 | points = self.select_box(y, (_w, _h)) 44 | return points 45 | 46 | def select_box(self, predict, size, dims=208, stride=16): 47 | wh = numpy.array([[size[0]], [size[1]]]) 48 | probs = predict[0, :, :, 0:2] 49 | probs = torch.softmax(probs, dim=-1).numpy() 50 | affines = torch.cat( 51 | ( 52 | predict[0, :, :, 2:3], 53 | predict[0, :, :, 3:4], 54 | predict[0, :, :, 4:5], 55 | predict[0, :, :, 5:6], 56 | predict[0, :, :, 6:7], 57 | predict[0, :, :, 7:8] 58 | ), 59 | dim=2 60 | ) 61 | h, w, c = affines.shape 62 | affines = affines.reshape(h, w, 2, 3).numpy() 63 | scale = ((dims + 40.0) / 2.0) / stride 64 | unit = numpy.array([[-0.5, -0.5, 1], [0.5, -0.5, 1], [0.5, 0.5, 1], [-0.5, 0.5, 1]]).transpose((1, 0)) 65 | h, w, _ = probs.shape 66 | candidates = [] 67 | for i in range(h): 68 | for j in range(w): 69 | if probs[i, j, 1] > config.confidence_threshold: 70 | affine = affines[i, j] 71 | pts = affine @ unit 72 | # print(affine) 73 | # print(affine) 74 | pts *= scale 75 | pts += numpy.array([[j + 0.5], [i + 0.5]]) 76 | pts *= stride 77 | # print(pts) 78 | pts /= wh 79 | # exit() 80 | candidates.append((pts, probs[i, j, 1])) 81 | # break 82 | 83 | candidates.sort(key=lambda x: x[1], reverse=True) 84 | labels = [] 85 | for pts_c, prob_c in candidates: 86 | tl_c = pts_c.min(axis=1) 87 | br_c = pts_c.max(axis=1) 88 | overlap = False 89 | for pts_l, _ in labels: 90 | tl_l = pts_l.min(axis=1) 91 | br_l = pts_l.max(axis=1) 92 | if self.iou(tl_c, br_c, tl_l, br_l) > 0.1: 93 | overlap = True 94 | break 95 | if not overlap: 96 | labels.append((pts_c, prob_c)) 97 | return labels 98 | 99 | @staticmethod 100 | def iou(tl1, br1, tl2, br2): 101 | x1, y1 = tl1 102 | x2, y2 = br1 103 | x3, y3 = tl2 104 | x4, y4 = br2 105 | wh1 = br1 - tl1 106 | wh2 = br2 - tl2 107 | assert ((wh1 >= 0).sum() > 0 and (wh2 >= 0).sum() > 0) 108 | s1 = (y2 - y1) * (x2 - x1) 109 | s2 = (y4 - y3) * (x4 - x3) 110 | _x1 = max(x1, x3) 111 | _y1 = max(y1, 
y3) 112 | _x2 = min(x2, x4) 113 | _y2 = max(y2, y4) 114 | w = max(0, _x2 - _x1) 115 | h = max(0, _y2 - _y1) 116 | i = w * h 117 | return i / (s1 + s2 - i) 118 | 119 | 120 | if __name__ == '__main__': 121 | e = DExplorer() 122 | image = cv2.imread('test/test_image.jpg') 123 | label = e(image) 124 | print(label) 125 | -------------------------------------------------------------------------------- /detect_train.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from torch import nn 3 | from utils.loss import FocalLossManyClassification 4 | from utils.dataset import DetectDataset 5 | from einops import rearrange 6 | from tqdm import tqdm 7 | import config as config 8 | import torch 9 | import os 10 | 11 | 12 | class Trainer: 13 | 14 | def __init__(self): 15 | self.net = config.net() 16 | if os.path.exists(config.weight): 17 | self.net.load_state_dict(torch.load(config.weight, map_location='cpu')) 18 | print('Loaded Weights') 19 | else: 20 | print('Do not load weights') 21 | 22 | self.l1_loss = nn.L1Loss() 23 | self.c_loss = nn.CrossEntropyLoss() 24 | self.optimizer = torch.optim.Adam(self.net.parameters(),lr=0.00001) 25 | self.dataset = DetectDataset() 26 | self.data_loader = DataLoader(self.dataset, config.batch_size, drop_last=True) 27 | self.net.to(config.device) 28 | 29 | def train(self): 30 | 31 | for epoch in range(config.epoch): 32 | self.net.train() 33 | loss_sum = 0 34 | for i, (images, labels) in enumerate(tqdm(self.data_loader)): 35 | images = images.to(config.device) 36 | labels = labels.to(config.device) 37 | 38 | predict = self.net(images)# [b,13,13,8] [b, 13, 13, 9] 39 | loss_c, loss_p = self.count_loss(predict, labels) 40 | loss = loss_c + loss_p 41 | self.optimizer.zero_grad() 42 | loss.backward() 43 | self.optimizer.step() 44 | if (i+1) % 10==0: 45 | print(f"epoch={epoch+1},step={i},loss={loss.item()},loss_c={loss_c.item()},loss_p={loss_p.item()}") 46 | if (i+1) % 100 
def count_loss(self, predict, target):
    """Compute the classification and corner-regression losses.

    Args:
        predict: network output; the last axis holds 2 class logits
            followed by 6 affine parameters.
        target: label tensor; channel 0 of the last axis is the object
            flag (1 = positive cell, 0 = negative cell) and the remaining
            channels are the regression targets compared against the
            predicted corner coordinates.

    Returns:
        ``(loss_c, loss_p)``. Both are always tensors, so callers can call
        ``.item()`` on them even when a batch has no positive cells.
    """
    condition_positive = target[..., 0] == 1
    condition_negative = target[..., 0] == 0

    predict_positive = predict[condition_positive]
    predict_negative = predict[condition_negative]
    target_positive = target[condition_positive]
    target_negative = target[condition_negative]

    has_positive = predict_positive.shape[0] > 0

    # Start from tensor zeros so the return type does not depend on the batch.
    loss_c = predict.new_zeros(())
    loss_p = predict.new_zeros(())

    if predict_negative.shape[0] > 0:
        # Cross-entropy over an empty selection would be NaN, so skip it.
        loss_c = loss_c + self.c_loss(
            predict_negative[:, 0:2], target_negative[:, 0].long()
        )
    if has_positive:
        loss_c = loss_c + self.c_loss(
            predict_positive[:, 0:2], target_positive[:, 0].long()
        )

        # Channels 2..7 form a 2x3 affine matrix per positive cell.
        trans_m = predict_positive[:, 2:8].reshape(-1, 2, 3)
        # Homogeneous corners of a unit square centred on the cell, one per
        # column after the transpose.
        unit = torch.tensor(
            [[-0.5, -0.5, 1], [0.5, -0.5, 1], [0.5, 0.5, 1], [-0.5, 0.5, 1]],
            dtype=trans_m.dtype,
            device=trans_m.device,
        ).transpose(0, 1)
        # (n, 2, 3) @ (3, 4) -> (n, 2, 4), then flattened row-major to (n, 8).
        point_pred = torch.matmul(trans_m, unit)
        point_pred = point_pred.reshape(point_pred.shape[0], -1)
        loss_p = self.l1_loss(point_pred, target_positive[:, 1:])

    return loss_c, loss_p
def val():
    """Measure end-to-end plate reading accuracy on the images in ``test/``.

    The ground-truth plate is decoded from each file name: the third
    ``-``-separated field from the end is a ``_``-separated list of indices
    into the province, city-letter and character tables below. Each image
    is run through ``ReadPlate`` and the first returned plate string is
    compared with the ground truth.

    Files that cannot be parsed or read are skipped, and an image with no
    detection counts as a wrong prediction.
    """
    provinces = ["皖", "沪", "津", "渝", "冀", "晋", "蒙", "辽", "吉", "黑", "苏", "浙", "京", "闽", "赣", "鲁", "豫", "鄂", "湘", "粤", "桂", "琼", "川", "贵", "云", "藏", "陕", "甘", "青", "宁", "新", "警", "学", "O"]
    alphabets = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W',
                 'X', 'Y', 'Z', 'O']
    ads = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X',
           'Y', 'Z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'O']
    count = 0
    correct = 0
    read_plate = ReadPlate()
    for pic in os.listdir("test"):
        # Decode the label first so stray files never reach the detector.
        try:
            nums = pic.split("-")[-3].split("_")
            pro = provinces[int(nums[0])]
            city = alphabets[int(nums[1])]
            plates = [ads[int(n)] for n in nums[2:]]
        except (IndexError, ValueError):
            print(f"skip file with unparsable name: {pic}")
            continue
        true_label = pro + city + "".join(plates)

        image = cv2.imread('test/' + pic)
        if image is None:
            # cv2.imread returns None instead of raising on unreadable files.
            print(f"skip unreadable file: {pic}")
            continue

        detections = read_plate(image)
        # No detection: score as a miss instead of raising IndexError.
        predict = detections[0][0] if detections else ""

        if true_label == predict:
            result = "√"
            correct += 1
        else:
            result = "×"
        count += 1
        print(f"{true_label}--{predict}--{result}")
    if count == 0:
        print("accuracy=n/a (no images evaluated)")
        return
    print(f"accuracy={correct / count}")
    return
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 43 | self.fc = nn.Linear(512 * block.expansion, num_classes) 44 | 45 | def _make_layer(self, block, out_channels, num_blocks, stride): 46 | 47 | strides = [stride] + [1] * (num_blocks - 1) 48 | layers = [] 49 | for stride in strides: 50 | layers.append(block(self.in_channels, out_channels, stride)) 51 | self.in_channels = out_channels * block.expansion 52 | 53 | return nn.Sequential(*layers) 54 | 55 | def forward(self, x): 56 | output = self.conv1(x) 57 | output = self.conv2_x(output) 58 | output = self.conv3_x(output) 59 | output = self.conv4_x(output) 60 | output = self.conv5_x(output) 61 | output = self.avg_pool(output) 62 | output = output.view(output.size(0), -1) 63 | output = self.fc(output) 64 | output = rearrange(output, 'n c h w -> n h w c') 65 | 66 | return output 67 | 68 | 69 | 70 | class WpodNet(nn.Module): 71 | 72 | ''' 73 | Input dimension : 3 208 208 74 | 75 | Output dimension : 8 13 13 76 | ''' 77 | 78 | def __init__(self): 79 | super(WpodNet, self).__init__() 80 | resnet = resnet18(True) 81 | backbone = list(resnet.children()) 82 | self.backbone = nn.Sequential( 83 | nn.BatchNorm2d(3), 84 | *backbone[:3], 85 | *backbone[4:8], 86 | ) 87 | self.detection = nn.Conv2d(512, 8, 3, 1, 1) 88 | 89 | def forward(self, x): 90 | features = self.backbone(x) 91 | out = self.detection(features) 92 | out = rearrange(out, 'n c h w -> n h w c') 93 | return out 94 | 95 | class BasicBlock(nn.Module): 96 | 97 | def __init__(self, in_channels, out_channels, stride=1): 98 | super().__init__() 99 | 100 | self.residual_function = nn.Sequential( 101 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False), 102 | nn.BatchNorm2d(out_channels), 103 | nn.ReLU(inplace=True), 104 | nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False), 105 | nn.BatchNorm2d(out_channels * BasicBlock.expansion) 106 | ) 107 | 108 | self.shortcut = nn.Sequential() 109 | 110 | 
class BottleNeck(nn.Module):
    """Residual bottleneck block: 1x1 reduce, 3x3, 1x1 expand.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the two inner convolutions; the block
            outputs ``out_channels * BottleNeck.expansion`` channels.
        stride: stride of the 3x3 convolution (and of the projection
            shortcut when one is needed).
    """

    # Ratio between the block's output channels and ``out_channels``. The
    # constructor reads this attribute, so it must exist on the class.
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        expanded = out_channels * BottleNeck.expansion
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, expanded, kernel_size=1, bias=False),
            nn.BatchNorm2d(expanded),
        )

        # Identity shortcut unless the spatial size or channel count changes.
        self.shortcut = nn.Sequential()

        if stride != 1 or in_channels != expanded:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, expanded, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(expanded)
            )

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``."""
        return torch.relu(self.residual_function(x) + self.shortcut(x))
class SelfAttention(nn.Module):
    """Multi-head self-attention over a sequence of feature vectors.

    Args:
        embed_dim: size of each input/output vector; must be divisible by
            ``num_head``.
        num_head: number of attention heads.
        is_masked: when true, a lower-triangular mask is applied so each
            position only attends to itself and earlier positions.

    Shape:
        - Input: (batch, series, embed_dim)
        - Output: (batch, series, embed_dim)
    """

    def __init__(self, embed_dim, num_head, is_masked=True):
        super().__init__()
        assert embed_dim % num_head == 0
        self.num_head = num_head
        self.is_masked = is_masked
        # One projection produces query, key and value together.
        self.linear1 = nn.Linear(embed_dim, 3 * embed_dim)
        self.linear2 = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Apply attention to ``x`` of shape (batch, series, embed_dim)."""
        projected = self.linear1(x)
        batch, series, _ = projected.shape

        # (batch, series, 3*embed) -> (batch, heads, series, 3*head_dim)
        per_head = projected.reshape(batch, series, self.num_head, -1).transpose(1, 2)
        query, key, value = per_head.chunk(3, dim=-1)

        # Scaled dot-product scores, (batch, heads, series, series).
        scale = value.shape[-1] ** 0.5
        scores = (query @ key.transpose(-1, -2)) / scale

        if self.is_masked:
            size = scores.shape[-1]
            mask = torch.tril(torch.ones(size, size)).to(scores.device)
            # Masked-out entries become -1e10 so softmax gives them ~0 weight.
            scores = scores * mask - 1e10 * (1 - mask)

        weights = torch.softmax(scores, dim=-1)
        context = (weights @ value).permute(0, 2, 1, 3)

        # Concatenate the heads back into one vector per position.
        batch, series, heads, head_dim = context.shape
        merged = context.reshape(batch, series, heads * head_dim)
        return self.linear2(merged)
149 | emb = rearrange(emb, " h w d -> (w h) d") 150 | # logits = torch.einsum('b i d, j d -> b i j', q, emb) 151 | return emb 152 | 153 | 154 | class OcrNet(nn.Module): 155 | 156 | ''' 157 | Input Dimension: 3, 48, 144 158 | 159 | Ouput Dimension: 27, Batch, Num of Classes 160 | ''' 161 | 162 | def __init__(self, num_class): 163 | super(OcrNet, self).__init__() 164 | resnet = resnet18(True) 165 | backbone = list(resnet.children()) 166 | self.backbone = nn.Sequential( 167 | nn.BatchNorm2d(3), 168 | *backbone[:3], 169 | *backbone[4:8], 170 | ) 171 | self.decoder = nn.Sequential( 172 | Block(512, 8, False), 173 | Block(512, 8, False), 174 | Block(512, 8, False), 175 | ) 176 | self.out_layer = nn.Linear(512, num_class) 177 | self.abs_pos_emb = AbsPosEmb((3, 9), 512) 178 | 179 | def forward(self, x): 180 | x = self.backbone(x) 181 | x = rearrange(x, "n c h w -> n (w h) c") 182 | x = x + self.abs_pos_emb() 183 | x = self.decoder(x) 184 | x = rearrange(x, "n s v -> s n v") 185 | return self.out_layer(x) 186 | 187 | 188 | if __name__ == "__main__": 189 | m = OcrNet(70) 190 | print(m) 191 | x = torch.randn(32, 3, 48, 144) 192 | print(m(x).shape) 193 | -------------------------------------------------------------------------------- /ocr_config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class_name = ['*', "皖", "沪", "津", "渝", "冀", "晋", "蒙", "辽", "吉", "黑", "苏", "浙", "京", "闽", "赣", "鲁", "豫", "鄂", "湘", "粤", 4 | "桂", "琼", "川", "贵", "云", "藏", "陕", "甘", "青", "宁", "新", "警", "学", '港', '澳', 'A', 'B', 'C', 'D', 'E', 'F', 5 | 'G', 'H', 'J', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 6 | 'Y', 'Z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] 7 | 8 | device = 'cuda:0' 9 | 10 | device = torch.device(device if torch.cuda.is_available() else 'cpu') 11 | num_class = len(class_name) 12 | weight = 'weights/ocr_net2.pth' 13 | # print(num_class) 14 | 15 | 
class Explorer:
    """Run the OCR network on a cropped plate image and decode the text."""

    def __init__(self, is_cuda=False):
        """Build the network and load its weights from ``config.weight``.

        Args:
            is_cuda: accepted for backward compatibility and not used; the
                device is taken from ``config.device``.

        Raises:
            RuntimeError: if the weight file does not exist.
        """
        self.device = config.device
        self.net = OcrNet(config.num_class)
        if os.path.exists(config.weight):
            self.net.load_state_dict(torch.load(config.weight, map_location='cpu'))
            print(f'Load success:{config.weight.split("/")[-1]}')
        else:
            raise RuntimeError('Model parameters are not loaded')
        self.net = self.net.to(self.device).eval()

    def __call__(self, image):
        """Return the decoded plate string for an H x W x C image array.

        The image is converted to a channel-first tensor scaled by 1/255,
        the per-step class with the highest score is picked, and the
        resulting character sequence is collapsed by ``deduplication``.
        """
        with torch.no_grad():
            image = torch.from_numpy(image).permute(2, 0, 1) / 255
            image = image.unsqueeze(0).to(self.device)
            out = self.net(image)
            # Flatten to (steps, classes) using the network's own class
            # count rather than a hard-coded 70.
            out = out.reshape(-1, out.shape[-1])
            indices = torch.argmax(out, dim=1).cpu().numpy().tolist()
        return self.deduplication(''.join(config.class_name[i] for i in indices))

    def deduplication(self, c):
        """Collapse a per-step character string into the final text.

        Consecutive repeats of a character are merged and every ``'*'`` is
        dropped; a ``'*'`` between two equal characters keeps them distinct.
        """
        previous = ''
        kept = []
        for ch in c:
            if ch != previous and ch != '*':
                kept.append(ch)
            previous = ch
        return ''.join(kept)
class Trainer:
    """Train ``OcrNet`` on ``OcrDataSet`` with CTC loss."""

    def __init__(self, load_parameters=True):
        """Create the network, data loader, loss and optimizer.

        Args:
            load_parameters: when true, weights are loaded from
                ``ocr_config.weight`` if that file exists.

        Raises:
            RuntimeError: when ``load_parameters`` is false.
        """
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.net = OcrNet(ocr_config.num_class)
        # NOTE(review): with load_parameters=False this raises rather than
        # training from scratch; kept as-is because callers may rely on it.
        if os.path.exists(ocr_config.weight) and load_parameters:
            self.net.load_state_dict(torch.load(ocr_config.weight, map_location='cpu'))
            print('Going to train with pretrained weights')
        elif load_parameters:
            print('Going to train without pretrained weights')
        else:
            raise RuntimeError('Model parameters are not loaded')
        self.dataset = OcrDataSet()
        self.dataloader = DataLoader(self.dataset, 64, True)
        self.net = self.net.to(self.device).train()
        self.loss_func = nn.CTCLoss(blank=0, zero_infinity=True)
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.00001)

    def __call__(self, max_epochs=40):
        """Run the training loop, saving the weights after every epoch.

        Args:
            max_epochs: number of passes over the data loader (default 40).
        """
        for epoch in range(max_epochs):
            loss_sum = 0
            for images, targets, target_lengths in tqdm(self.dataloader):
                images = images.to(self.device)

                # The loss is given one flat tensor holding only the first
                # `length` entries of each padded target row.
                targets = torch.cat(
                    [row[:length] for row, length in zip(targets, target_lengths)],
                    dim=0,
                ).long().to(self.device)
                target_lengths = target_lengths.to(self.device)

                predict = self.net(images)
                s, n, v = predict.shape
                # Every sample uses the full output sequence length.
                input_lengths = torch.full(size=(n,), fill_value=s, dtype=torch.long)

                loss = self.loss_func(predict.log_softmax(2), targets, input_lengths, target_lengths)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                loss_sum += loss.item()
            logs = f'''{epoch},loss_sum: {loss_sum / len(self.dataloader)}'''
            torch.save(self.net.state_dict(), ocr_config.weight)
            print(logs)
def gauss_noise(image):
    """Add random noise to each channel of ``image`` in place and return it.

    For every channel a normally distributed field (standard deviation
    drawn from 1..6) is rescaled to the range [0, 255 - channel max] and
    added to the channel.
    """
    channel_count = image.shape[2]
    for idx in range(channel_count):
        channel = image[:, :, idx]
        # Largest value that can be added to this channel's maximum.
        headroom = 255 - channel.max()
        sigma = randint(1, 6)
        field = numpy.random.normal(0, sigma, channel.shape)
        lowest = field.min()
        highest = field.max()
        # Normalise to [0, 1], then stretch over the available headroom.
        field = (field - lowest) / (highest - lowest)
        field = headroom * field
        image[:, :, idx] = channel + field.astype(numpy.uint8)
    return image
def hsv_noise(img):
    """Randomly scale the hue, saturation and value of a BGR image.

    Hue is multiplied by a factor in [0.8, 1.6], saturation by a factor in
    [0.3, 1.0] and value by a factor in [0.2, 0.4]. Hue is cyclic, so the
    scaled hue wraps around the top of its range rather than overflowing
    the storage type.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    # OpenCV stores 8-bit hue as 0..179 and floating-point hue as 0..360.
    hue_period = 180 if hsv.dtype == numpy.uint8 else 360
    hue = hsv[:, :, 0].astype(numpy.float32) * (0.8 + random.uniform(0.0, 0.8))
    hsv[:, :, 0] = hue % hue_period
    # These two factors never exceed 1.0, so they cannot overflow.
    hsv[:, :, 1] = hsv[:, :, 1] * (0.3 + random.uniform(0.0, 0.7))
    hsv[:, :, 2] = hsv[:, :, 2] * (0.2 + random.uniform(0.0, 0.2))
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
random.uniform(2.0, 4.0) 139 | width = random.uniform(dims * 0.2, dims * 1.0) 140 | height = width / wh_ratio 141 | dx = random.uniform(0.0, dims - width) 142 | dy = random.uniform(0.0, dims - height) 143 | crop = transform_matrix( 144 | points_matrix(points), 145 | rect_matrix(dx, dy, dx + width, dy + height) 146 | ) 147 | max_angles = numpy.array([80.0, 80.0, 45.0]) 148 | angles = numpy.random.rand(3) * max_angles 149 | if angles.sum() > 120: 150 | angles = (angles / angles.sum()) * (max_angles / max_angles.sum()) 151 | rotate = rotate_matrix(dims, dims, angles) 152 | image, points = project(image, points, numpy.matmul(rotate, crop), dims) 153 | points /= dims 154 | image = hsv_noise(image) 155 | image = update(image,randint(-80,100),randint(-80,100)) 156 | return image, numpy.asarray(points).reshape((-1,)).tolist() 157 | 158 | 159 | def reconstruct_plates(image, plate_pts, out_size=(144, 48)): 160 | wh = numpy.array([[image.shape[1]], [image.shape[0]]]) 161 | plates = [] 162 | for pts in plate_pts: 163 | pts = points_matrix(pts * wh) 164 | t_pts = rect_matrix(0, 0, out_size[0], out_size[1]) 165 | m = transform_matrix(pts, t_pts) 166 | plate = cv2.warpPerspective(image, m, out_size) 167 | plates.append(plate) 168 | return plates 169 | 170 | def random_cut(image,size): 171 | h,w,c = image.shape 172 | min_side = min(h,w) 173 | h_sid_len = random.randint(int(0.2*min_side), int(0.9*min_side)) 174 | w_sid_len = random.randint(int(0.2 * min_side), int(0.9 * min_side)) 175 | h_s = random.randint(0, h-h_sid_len) 176 | w_s = random.randint(0, w - w_sid_len) 177 | image = image[h_s:h_s+h_sid_len,w_s:w_s+w_sid_len] 178 | image = cv2.resize(image,size,interpolation=cv2.INTER_AREA) 179 | return image 180 | 181 | 182 | def apply_plate(image, points, plate): 183 | points = [[points[2*i], points[2*i+1]] for i in range(4)] # to [(tlx,tly),(trx,try),(blx,bly),(brx,bry)] 184 | points = numpy.float32(points) 185 | h,w,_ = plate.shape # [140,440,3] 186 | pt2 = 
numpy.float32([[0,0],[w,0],[0,h],[w,h]]) 187 | m = cv2.getPerspectiveTransform(pt2,points) 188 | h,w,_ = image.shape #[208,208,3] 189 | mask = numpy.ones_like(plate, dtype=numpy.uint8) 190 | out_img = cv2.warpPerspective(plate, m, (w, h)) 191 | mask = cv2.warpPerspective(mask, m, (w, h)) 192 | mask = mask != 0 193 | image[mask] = out_img[mask] 194 | return image 195 | 196 | 197 | def augment_detect(image, points, dims, flip_prob=0.5): 198 | points = numpy.array(points).reshape((2, 4)) 199 | wh_ratio = random.uniform(2.0, 4.0) #[2.0,4.0) -> 3 200 | width = random.uniform(dims * 0.2, dims * 1.0) # [41.6, 208.0) -> 120 201 | height = width / wh_ratio # 120 / 3 = 40 202 | dx = random.uniform(0.0, dims - width) # [0.0, 120) -> 60 203 | dy = random.uniform(0.0, dims - height) # [0.0, 40) -> 20 204 | crop = transform_matrix( 205 | points_matrix(points),# shape=[2,4]->[3, 4] 206 | rect_matrix(dx, dy, dx + width, dy + height) #[60,20,180,60] 207 | ) 208 | # random rotate 209 | max_angles = numpy.array([80.0, 80.0, 45.0]) 210 | angles = numpy.random.rand(3) * max_angles 211 | if angles.sum() > 120: 212 | angles = (angles / angles.sum()) * (max_angles / max_angles.sum()) 213 | # print(angles) 214 | rotate = rotate_matrix(dims, dims, angles) 215 | # apply projection 216 | image, points = project(image, points, numpy.matmul(rotate, crop), dims) 217 | # scale the coordinates of points to [0, 1] 218 | points /= dims 219 | # random flip 220 | if random.random() < flip_prob: 221 | image = cv2.flip(image, 1) 222 | points[0] = 1 - points[0] 223 | points = points[..., [1, 0, 3, 2]] 224 | # color augment 225 | image = hsv_noise(image) 226 | # brightness augment 227 | image = update(image,randint(-80,300),randint(-80,250)) 228 | return image, numpy.asarray(points).reshape((-1,)).tolist() 229 | 230 | if __name__ == '__main__': 231 | import matplotlib.pyplot as plt 232 | img = cv2.imread("../test/test_image.jpg") 233 | smu_img,a = augment_sample(img) 234 | plt.imshow(smu_img) 235 | 
plt.show() 236 | 237 | 238 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | from data import make_label 2 | from torch.utils.data import Dataset 3 | from fake_chs_lp.random_plate import Draw 4 | import os 5 | from einops import rearrange 6 | import random 7 | import cv2 8 | from utils import data_augmentation 9 | import numpy 10 | import torch 11 | import ocr_config 12 | import config 13 | import re 14 | 15 | 16 | class OcrDataSet(Dataset): 17 | 18 | ''' 19 | Input: 3, 48, 144 20 | 21 | Output: 27, Batch, Num of Classes 22 | ''' 23 | 24 | def __init__(self): 25 | super(OcrDataSet, self).__init__() 26 | self.dataset = [] 27 | self.draw = Draw() 28 | for i in range(10000): 29 | self.dataset.append(1) 30 | self.smudge = data_augmentation.Smudge() 31 | 32 | def __len__(self): 33 | return len(self.dataset) 34 | 35 | def __getitem__(self, item): 36 | plate, label = self.draw() 37 | target = [] 38 | for i in label: 39 | target.append(ocr_config.class_name.index(i)) 40 | plate = cv2.cvtColor(plate, cv2.COLOR_RGB2BGR) 41 | 42 | ''' 43 | Data Augmentation 44 | ''' 45 | 46 | plate = self.data_to_enhance(plate) 47 | 48 | image = torch.from_numpy(plate).permute(2, 0, 1) / 255 49 | target_length = torch.tensor(len(target)).long() 50 | target = torch.tensor(target).reshape(-1).long() 51 | _target = torch.full(size=(15,), fill_value=0, dtype=torch.long) 52 | _target[:len(target)] = target 53 | 54 | return image, _target, target_length 55 | 56 | def data_to_enhance(self, plate): 57 | # Smudge 58 | plate = self.smudge(plate) 59 | # Gaussian Blur 60 | plate = data_augmentation.gauss_blur(plate) 61 | # Gaussian Noise 62 | plate = data_augmentation.gauss_noise(plate) 63 | plate, pts = data_augmentation.augment_sample(plate) 64 | plate = data_augmentation.reconstruct_plates(plate, [numpy.array(pts).reshape((2, 4))])[0] 65 | return plate 66 | 67 | 68 | 
class DetectDataset(Dataset):
    """Plate-detection dataset built from the images in ``config.image_root``.

    The four plate corners are parsed from each file name (see ``get_box``),
    so no separate annotation file is read.  Roughly half of the samples get
    a synthetic plate pasted over the annotated quadrilateral before the
    geometric augmentation is applied.
    """

    def __init__(self):
        super(DetectDataset, self).__init__()
        # List of (image_path, [x1, y1, x2, y2, x4, y4, x3, y3]) tuples.
        self.dataset = []
        self.draw = Draw()
        self.smudge = data_augmentation.Smudge()
        root = config.image_root
        for image_name in os.listdir(root):
            box = self.get_box(image_name)
            # The file name lists the corners as (3, 4, 1, 2); store them as
            # (1, 2, 4, 3).  Per the original author's note this is
            # "tl tr bl br" -- TODO confirm against the dataset spec.
            x3, y3, x4, y4, x1, y1, x2, y2 = box
            box = [x1, y1, x2, y2, x4, y4, x3, y3]
            self.dataset.append((f'{root}/{image_name}', box))

    def __len__(self):
        """Return the number of images found at construction time."""
        return len(self.dataset)

    def __getitem__(self, item):
        """Return ``(image_tensor, label)`` for sample ``item``.

        Per the original author's notes the image tensor is [3, 208, 208]
        and the label is [13, 13, 9]; stored points look like
        ``[364, 517, 440, 515, 368, 549, 444, 547]`` (tl tr bl br).

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        image_path, points = self.dataset[item]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'Cannot read image: {image_path}')

        # With probability 0.5 replace the real plate by a generated one.
        if random.random() < 0.5:
            plate, _ = self.draw()
            plate = cv2.cvtColor(plate, cv2.COLOR_RGB2BGR)
            plate = self.smudge(plate)
            image = data_augmentation.apply_plate(image, points, plate)
        # Regroup the interleaved corners into all x values, then all y values.
        [x1, y1, x2, y2, x4, y4, x3, y3] = points
        points = [x1, x2, x3, x4, y1, y2, y3, y4]
        image, pts = data_augmentation.augment_detect(image, points, 208)
        image_tensor = torch.from_numpy(image)/255
        image_tensor = rearrange(image_tensor, 'h w c -> c h w')
        label = make_label.object_label(pts,208,16)
        label = torch.from_numpy(label).float()
        return image_tensor,label

    @staticmethod
    def up_background(image):
        """Blur, add noise to and randomly cut ``image`` with a (208, 208) size."""
        image = data_augmentation.gauss_blur(image)
        image = data_augmentation.gauss_noise(image)
        image = data_augmentation.random_cut(image, (208, 208))
        return image

    def data_to_enhance(self, plate):
        """Smudge, blur, add noise to and warp ``plate``; return the reconstructed plate."""
        plate = self.smudge(plate)
        plate = data_augmentation.gauss_blur(plate)
        plate = data_augmentation.gauss_noise(plate)
        plate, pts = data_augmentation.augment_sample(plate)
        plate = data_augmentation.reconstruct_plates(plate, [numpy.array(pts).reshape((2, 4))])[0]
        return plate

    @staticmethod
    def get_box(name):
        """Parse the eight corner coordinates out of an image file name.

        The name is split on ``.``, ``&``, ``_`` and ``-``; fields 7..14 are
        returned as ints.  Raises ``ValueError`` if one of them is not numeric.
        """
        fields = re.split(r'[.&_-]', name)[7:15]
        return [int(field) for field in fields]


if __name__ == '__main__':
    # Manual smoke test: build both datasets and display the first OCR sample.
    import matplotlib.pyplot as plt
    import numpy as np
    data_set = DetectDataset()
    data_ocr = OcrDataSet()
    img, target, tl = data_ocr[0]
    img = torch.permute(img, [1,2,0])
    print(img.shape, target.shape,tl)
    plt.imshow(img)
    plt.show()

# --------------------------------------------------------------------------------
# /utils/loss.py:
# --------------------------------------------------------------------------------
# **************************************************************************** #
#                                                                              #
#                                                         :::      ::::::::    #
#    loss.py                                            :+:      :+:    :+:    #
#                                                     +:+ +:+         +:+      #
#    By: peterli                                    +#+  +:+       +#+         #
#                                                 +#+#+#+#+#+   +#+            #
#    Created: 2023/04/18 17:28:32 by peterli           #+#    #+#              #
#    Updated: 2023/04/18 17:28:33 by peterli          ###   ########.fr        #
#                                                                              #
# **************************************************************************** #

import torch
from torch import nn
import numpy


class FocalLossManyClassification(nn.Module):
    """Multi-class focal loss, averaged over every classified position.

    Per position the loss is ``-alpha_t * (1 - p_t) ** gamma * log(p_t)``
    where ``p_t`` is the softmax probability of the target class.

    Args:
        num_class: number of classes (size of dim 1 of the logits).
        alpha: class weighting.
            ``None``: every class has weight 1.
            list / ``numpy.ndarray`` of length ``num_class``: per-class
            weights, normalised to sum to 1.
            float: class ``balance_index`` gets ``alpha``, all other classes
            get ``1 - alpha``.
        gamma: focusing exponent applied to ``(1 - p_t)``.
        balance_index: class singled out when ``alpha`` is a float.
        smooth: optional value in [0, 1]; the one-hot target is clamped to
            ``[smooth, 1 - smooth]``.

    Raises:
        TypeError: if ``alpha`` has an unsupported type.
        ValueError: if ``smooth`` lies outside [0, 1].
    """

    def __init__(self, num_class, alpha=None, gamma=2, balance_index=-1, smooth=None):
        super(FocalLossManyClassification, self).__init__()
        self.num_class = num_class
        self.alpha = alpha
        self.gamma = gamma
        self.smooth = smooth

        # Whatever the input form, self.alpha ends up a (num_class, 1) float tensor.
        if self.alpha is None:
            self.alpha = torch.ones(self.num_class, 1)
        elif isinstance(self.alpha, (list, numpy.ndarray)):
            assert len(self.alpha) == self.num_class
            weights = torch.as_tensor(alpha, dtype=torch.float32).view(self.num_class, 1)
            # Kept as a tensor (not converted to NumPy) so forward() can index it.
            self.alpha = weights / weights.sum()
        elif isinstance(self.alpha, float):
            weights = torch.ones(self.num_class, 1)
            weights *= 1 - self.alpha
            weights[balance_index] = self.alpha
            self.alpha = weights
        else:
            raise TypeError('Not support alpha type')

        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError('smooth value should be in [0,1]')

    def forward(self, input, target):
        """Compute the mean focal loss.

        Args:
            input: raw logits of shape (N, C) or (N, C, d1, d2, ...).
            target: integer class indices with one entry per classified
                position, i.e. N or N * d1 * d2 * ... elements.

        Returns:
            A scalar tensor.
        """
        prob = torch.softmax(input, dim=1)

        if prob.dim() > 2:
            # N,C,d1,d2 -> N,C,m (m=d1*d2*...) -> N*m,C
            prob = prob.view(prob.size(0), prob.size(1), -1)
            prob = prob.permute(0, 2, 1).contiguous()
            prob = prob.view(-1, prob.size(-1))
        # reshape (not view) also accepts non-contiguous targets; indices are
        # kept on the same device as the probabilities to avoid a CPU round-trip.
        idx = target.reshape(-1, 1).long().to(prob.device)

        epsilon = 1e-10

        one_hot_key = torch.zeros(idx.size(0), self.num_class, device=prob.device)
        one_hot_key = one_hot_key.scatter_(1, idx, 1)

        if self.smooth:
            one_hot_key = torch.clamp(
                one_hot_key, self.smooth, 1.0 - self.smooth)
        pt = (one_hot_key * prob).sum(1) + epsilon
        logpt = pt.log()

        # One weight per sample, shape (N,).  Flattening is essential: a
        # (N, 1, 1) weight would broadcast against pt into an (N, 1, N) tensor.
        alpha = self.alpha.to(prob.device).view(-1)[idx.view(-1)]

        loss = -1 * alpha * torch.pow((1 - pt), self.gamma) * logpt
        return loss.mean()

# --------------------------------------------------------------------------------
# /utils/smu.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/PeterrrrLi/ResNet-Transformer-OCR-Pytorch/9a22ad38e0490ca32c3274be01208cadcc7c2662/utils/smu.png
# --------------------------------------------------------------------------------