├── test
│   ├── AW73RHW_18771.jpg
│   ├── BK38RRH_45316.jpg
│   ├── FV45YUG_81667.jpg
│   ├── KA93EL_51307.jpg
│   ├── MX90ECH_89776.jpg
│   ├── QG27TAP_29957.jpg
│   ├── TM77VH_62129.jpg
│   └── Y81AEM_16610.jpg
├── fonts
│   ├── RockoFLF-Bold.ttf
│   └── UKNumberPlate.ttf
├── requirements.txt
├── README.md
├── data_aug.py
├── utils.py
├── main.py
├── model
│   └── LPRnet.py
└── gen_plates.py

/test/AW73RHW_18771.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bluesy7585/tensorflow_LPRnet/HEAD/test/AW73RHW_18771.jpg
--------------------------------------------------------------------------------
/test/BK38RRH_45316.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bluesy7585/tensorflow_LPRnet/HEAD/test/BK38RRH_45316.jpg
--------------------------------------------------------------------------------
/test/FV45YUG_81667.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bluesy7585/tensorflow_LPRnet/HEAD/test/FV45YUG_81667.jpg
--------------------------------------------------------------------------------
/test/KA93EL_51307.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bluesy7585/tensorflow_LPRnet/HEAD/test/KA93EL_51307.jpg
--------------------------------------------------------------------------------
/test/MX90ECH_89776.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bluesy7585/tensorflow_LPRnet/HEAD/test/MX90ECH_89776.jpg
--------------------------------------------------------------------------------
/test/QG27TAP_29957.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bluesy7585/tensorflow_LPRnet/HEAD/test/QG27TAP_29957.jpg
--------------------------------------------------------------------------------
/test/TM77VH_62129.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bluesy7585/tensorflow_LPRnet/HEAD/test/TM77VH_62129.jpg
--------------------------------------------------------------------------------
/test/Y81AEM_16610.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bluesy7585/tensorflow_LPRnet/HEAD/test/Y81AEM_16610.jpg
--------------------------------------------------------------------------------
/fonts/RockoFLF-Bold.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bluesy7585/tensorflow_LPRnet/HEAD/fonts/RockoFLF-Bold.ttf
--------------------------------------------------------------------------------
/fonts/UKNumberPlate.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bluesy7585/tensorflow_LPRnet/HEAD/fonts/UKNumberPlate.ttf
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
absl-py==0.7.1
astor==0.8.0
bleach==1.5.0
certifi==2019.6.16
gast==0.2.2
grpcio==1.22.0
html5lib==0.9999999
Markdown==3.1.1
numpy==1.16.4
opencv-python==3.4.3.18
Pillow==5.2.0
protobuf==3.9.1
six==1.12.0
tensorboard==1.8.0
tensorflow-gpu==1.8.0
termcolor==1.1.0
Werkzeug==0.15.5
wincertstore==0.2
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# tensorflow LPRnet
A TensorFlow implementation of LPRnet, a lightweight deep network for number plate recognition.

- multiple-scale CNN features
- CTC loss for variable-length character sequences
- no RNN layers

## training
Generate plate images for training:

`python gen_plates.py`

Generate validation images:

`python gen_plates.py -s .\valid -n 200`

Train:

`python main.py -m train`

or train with runtime-generated images:

`python main.py -m train -r`

A model checkpoint is saved every `SAVE_STEPS` steps.
Validation is performed every `VALIDATE_EPOCHS` epochs.

## test
Generate test images:

`python gen_plates.py -s .\test -n 200`

Restore a checkpoint for testing:

`python main.py -m test -c [checkpoint]`

e.g.

```
python main.py -m test -c .\checkpoint\LPRnet_steps8000_loss_0.069.ckpt
...
val loss: 0.31266
plate accuracy: 192-200 0.960, char accuracy: 1105-1115 0.99103
```

### test single image

To test a single image and show the result:

`python main.py -m test -c [checkpoint] --img [image fullpath]`

e.g.
```
python main.py -m test -c .\checkpoint\LPRnet_steps5000_loss_0.215.ckpt --img .\test\AW73RHW_18771.jpg
...
restore from checkpoint: .\checkpoint\LPRnet_steps5000_loss_0.215.ckpt
AM73RHW
```
## train custom data
Change `TRAIN_DIR` and `VAL_DIR` in LPRnet.py to the folders containing your training/validation data.
Image filenames must follow the format [label]_XXXX,
e.g. AB12CD_0000.jpg

- char set

  Change `CHARS` if the possible characters in your labels differ from the default.

- char length

  The default input resolution (94x24) yields 24 timesteps in the CTC layer.
  If your images contain more than 8 characters, consider a wider input resolution for good performance;
  e.g. an input width of 128 yields 32 timesteps in the CTC layer.
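
  The timestep count comes from the two width-halving max-pool layers in `model/LPRnet.py` (both use SAME padding, so each halving rounds up). A quick sketch of the arithmetic:

  ```
  import math

  def ctc_timesteps(input_width):
      # two max-pools in cnn_layers() stride 2 along the width;
      # SAME padding rounds each halving up
      w = input_width
      for stride in (2, 2):
          w = math.ceil(w / stride)
      return w

  print(ctc_timesteps(94))   # 24
  print(ctc_timesteps(128))  # 32
  ```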

## references
- [LPRnet](https://arxiv.org/abs/1806.10447 "LPRnet")
- https://github.com/lyl8213/Plate_Recognition-LPRnet
- https://github.com/sirius-ai/LPRNet_Pytorch
- https://github.com/mahavird/my_deep_anpr
--------------------------------------------------------------------------------
/data_aug.py:
--------------------------------------------------------------------------------
import random
import numpy as np
import cv2


def jitter(img, jitter=0.1):
    # randomly rescale width and height by up to +/- jitter
    rows, cols, _ = img.shape
    j_width = float(cols) * random.uniform(1 - jitter, 1 + jitter)
    j_height = float(rows) * random.uniform(1 - jitter, 1 + jitter)
    img = cv2.resize(img, (int(j_width), int(j_height)))
    return img

def rotate(img, angle=5):
    # random rotation of up to +/- angle degrees with slight rescaling
    scale = random.uniform(0.9, 1.1)
    angle = random.uniform(-angle, angle)

    rows, cols, _ = img.shape
    M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, scale)
    dst = img.copy()
    dst = cv2.warpAffine(img, M, (cols, rows), dst, cv2.INTER_LINEAR)

    return dst

def perspective(img):
    # random perspective warp: move each corner inward by 5-10% of the size
    h, w, _ = img.shape
    per = random.uniform(0.05, 0.1)
    w_p = int(w * per)
    h_p = int(h * per)

    pts1 = np.float32([[0, 0], [0, h], [w, 0], [w, h]])
    pts2 = np.float32([[random.randint(0, w_p), random.randint(0, h_p)],
                       [random.randint(0, w_p), h - random.randint(0, h_p)],
                       [w - random.randint(0, w_p), random.randint(0, h_p)],
                       [w - random.randint(0, w_p), h - random.randint(0, h_p)]])

    M = cv2.getPerspectiveTransform(pts1, pts2)
    img = cv2.warpPerspective(img, M, (w, h))
    return img

def crop_subimage(img, margin=3):
    # crop a random sub-window, shrinking by up to margin pixels per side
    ran_margin = random.randint(0, margin)
    rows, cols, _ = img.shape
    crop_h = rows - ran_margin
    crop_w = cols - ran_margin
    row_start = random.randint(0, ran_margin)
    cols_start = random.randint(0, ran_margin)
    sub_img = img[row_start:row_start + crop_h, cols_start:cols_start + crop_w]
    return sub_img

def hsv_space_variation(ori_img, scale):
    # vary brightness, hue and saturation in HSV space and add Gaussian noise

    rows, cols, _ = ori_img.shape

    hsv_img = cv2.cvtColor(ori_img, cv2.COLOR_RGB2HSV)
    hsv_img = np.array(hsv_img, dtype=np.float32)
    img = hsv_img[:, :, 2]

    # Gaussian noise
    noise_std = random.randint(5, 20)
    noise = np.random.normal(0, noise_std, (rows, cols))

    # brightness scale
    img = img * scale
    img = np.clip(img, 0, 255)
    img = np.add(img, noise)

    # random hue variation
    hsv_img[:, :, 0] += random.randint(-5, 5)

    # random saturation variation
    hsv_img[:, :, 1] += random.randint(-30, 30)

    hsv_img[:, :, 2] = img
    hsv_img = np.clip(hsv_img, 0, 255)
    hsv_img = np.array(hsv_img, dtype=np.uint8)
    rgb_img = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2RGB)

    return rgb_img

def data_augmentation(img):
    # full pipeline: jitter, optional rotate/perspective, random crop,
    # then color-space variation

    img = jitter(img)

    if random.choice([True, False]):
        img = rotate(img)

    if random.choice([True, False]):
        img = perspective(img)

    img = crop_subimage(img)

    bright_scale = random.uniform(0.6, 1.2)
    img_out = hsv_space_variation(img, scale=bright_scale)

    return img_out
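

if __name__ == "__main__":
    # Usage sketch (added for illustration): augment a generated plate a few
    # times and display the results. The image path is an assumption; point it
    # at any [label]_XXXX.jpg produced by gen_plates.py.
    sample = cv2.imread('./test/AW73RHW_18771.jpg')
    if sample is None:
        raise SystemExit('sample image not found; run gen_plates.py first')
    for _ in range(3):
        cv2.imshow('augmented', data_augmentation(sample))
        cv2.waitKey(0)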
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2
import os
import random
from data_aug import data_augmentation
import gen_plates as gen
import model.LPRnet as model


def encode_label(label, char_dict):
    encode = [char_dict[c] for c in label]
    return encode

def sparse_tuple_from(sequences, dtype=np.int32):
    """
    Create a sparse representation of x.
    Args:
        sequences: a list of lists of type dtype where each element is a sequence
    Returns:
        A tuple with (indices, values, shape)
    """
    indices = []
    values = []

    for n, seq in enumerate(sequences):
        indices.extend(zip([n] * len(seq), range(len(seq))))
        values.extend(seq)

    indices = np.asarray(indices, dtype=np.int64)
    values = np.asarray(values, dtype=dtype)
    shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1] + 1], dtype=np.int64)

    return indices, values, shape

class DataIterator:
    def __init__(self, img_dir, runtime_generate=False):
        self.img_dir = img_dir
        self.batch_size = model.BATCH_SIZE
        self.channel_num = model.CH_NUM
        self.img_w, self.img_h = model.IMG_SIZE

        if runtime_generate:
            self.generator = gen.ImageGenerator('./fonts', model.CHARS)
        else:
            self.init()

    def init(self):
        self.filenames = []
        self.labels = []
        fs = os.listdir(self.img_dir)
        for filename in fs:
            self.filenames.append(filename)
            label = filename.split('_')[0]  # format: [label]_[random number].jpg
            label = encode_label(label, model.CHARS_DICT)
            self.labels.append(label)
        self.sample_num = len(self.labels)
        self.labels = np.array(self.labels)
        self.random_index = list(range(self.sample_num))
        random.shuffle(self.random_index)
        self.cur_index = 0

    def next_sample_ind(self):
        ret = self.random_index[self.cur_index]
        self.cur_index += 1
        if self.cur_index >= self.sample_num:
            self.cur_index = 0
            random.shuffle(self.random_index)
        return ret

    def next_batch(self):

        batch_size = self.batch_size
        images = np.zeros([batch_size, self.img_h, self.img_w, self.channel_num])
        labels = []

        for i in range(batch_size):
            sample_ind = self.next_sample_ind()
            fname = self.filenames[sample_ind]
            img = cv2.imread(os.path.join(self.img_dir, fname))
            #img = data_augmentation(img)
            img = cv2.resize(img, (self.img_w, self.img_h))
            images[i] = img

            labels.append(self.labels[sample_ind])

        sparse_labels = sparse_tuple_from(labels)

        return images, sparse_labels, labels

    def next_test_batch(self):

        start = 0
        end = self.batch_size
        is_last_batch = False

        while not is_last_batch:
            if end >= self.sample_num:
                end = self.sample_num
                is_last_batch = True

            #print("s: {} e: {}".format(start, end))

            cur_batch_size = end - start
            images = np.zeros([cur_batch_size, self.img_h, self.img_w, self.channel_num])

            for j, i in enumerate(range(start, end)):
                fname = self.filenames[i]
                img = cv2.imread(os.path.join(self.img_dir, fname))
                img = cv2.resize(img, (self.img_w, self.img_h))
                images[j, ...] = img

            labels = self.labels[start:end, ...]
            sparse_labels = sparse_tuple_from(labels)

            start = end
            end += self.batch_size

            yield images, sparse_labels, labels

    def next_gen_batch(self):

        batch_size = self.batch_size
        imgs, labels = self.generator.generate_images(batch_size)
        labels = [encode_label(label, model.CHARS_DICT) for label in labels]

        images = np.zeros([batch_size, self.img_h, self.img_w, self.channel_num])
        for i, img in enumerate(imgs):
            img = data_augmentation(img)
            img = cv2.resize(img, (self.img_w, self.img_h))
            images[i, ...] = img

        sparse_labels = sparse_tuple_from(labels)

        return images, sparse_labels, labels
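

if __name__ == "__main__":
    # Worked example (added for illustration) of the sparse label format that
    # tf.sparse_placeholder expects. 'AB1' and 'CD' are arbitrary sample labels.
    encoded = [encode_label(s, model.CHARS_DICT) for s in ('AB1', 'CD')]
    indices, values, shape = sparse_tuple_from(encoded)
    print(indices)  # a (sequence, position) pair for every character
    print(values)   # [0 1 25 2 3]: character indices from CHARS_DICT
    print(shape)    # [2 3]: 2 sequences, the longest has 3 characters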
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import utils
import os
import time
import numpy as np
import argparse
import cv2
from model.LPRnet import *


def infer_single_image(checkpoint, fname):

    if not os.path.isfile(fname):
        print('file {} does not exist.'.format(fname))
        return

    img_w = IMG_SIZE[0]
    img_h = IMG_SIZE[1]

    img = cv2.imread(fname)
    img = cv2.resize(img, (img_w, img_h))
    img_batch = np.expand_dims(img, axis=0)

    # print(img_batch.shape)
    lprnet = LPRnet(is_train=False)

    with tf.Session() as sess:
        sess.run(lprnet.init)
        saver = tf.train.Saver(tf.global_variables())

        if not restore_checkpoint(sess, saver, checkpoint, is_train=False):
            return

        test_feed = {lprnet.inputs: img_batch}
        dense_decode = sess.run(lprnet.dense_decoded, test_feed)

        decoded_labels = []
        for item in dense_decode:
            expression = ['' if i == -1 else DECODE_DICT[i] for i in item]
            expression = ''.join(expression)
            decoded_labels.append(expression)

        for l in decoded_labels:
            print(l)

        #cv2.imshow(os.path.basename(fname), img)
        #cv2.waitKey(0)


def inference(sess, model, val_gen):

    def compare(labels, dense_decode):
        char_count = 0
        match_count = 0

        if len(labels) != len(dense_decode):
            print('length mismatch. {} != {}'.format(len(labels), len(dense_decode)))
            return char_count, match_count

        for label, decode in zip(labels, dense_decode):
            no_blank = decode[decode != -1]
            char_count += len(label)

            if np.array_equal(label, no_blank):
                match_count += 1
        return char_count, match_count

    total = 0
    total_match = 0
    total_chars = 0
    total_edits = 0
    batch_count = 0
    loss_sum = 0

    for test_inputs, test_targets, test_labels in val_gen.next_test_batch():
        test_feed = {model.inputs: test_inputs,
                     model.targets: test_targets}

        dense_decode, edit_sum, loss = sess.run([model.dense_decoded, model.edit_dis, model.loss], test_feed)

        loss_sum += loss
        char_count, match_count = compare(test_labels, dense_decode)

        total_chars += char_count
        total_edits += edit_sum

        total += len(test_labels)
        total_match += match_count
        batch_count += 1

    if batch_count > 0:
        print('val loss: {:.5f}'.format(loss_sum / batch_count))

    if total > 0 and total_chars > 0:
        acc = total_match / total
        # char accuracy is edit-distance based: (total chars - total edits) / total chars
        char_acc = (total_chars - total_edits) / total_chars
        print("plate accuracy: {}-{} {:.3f}, char accuracy: {}-{} {:.5f}" \
              .format(total_match, total, acc, int(total_chars - total_edits), total_chars, char_acc))

def restore_checkpoint(sess, saver, ckpt, is_train=True):
    try:
        saver.restore(sess, ckpt)
        print('restore from checkpoint: {}'.format(ckpt))
        return True
    except:
        if is_train:
            print("train from scratch")
        else:
            print("no valid checkpoint provided")
        return False

def train(checkpoint, runtime_generate=False):
    lprnet = LPRnet(is_train=True)
    train_gen = utils.DataIterator(img_dir=TRAIN_DIR, runtime_generate=runtime_generate)
    val_gen = utils.DataIterator(img_dir=VAL_DIR)

    def train_batch(train_gen):
        if runtime_generate:
            train_inputs, train_targets, _ = train_gen.next_gen_batch()
        else:
            train_inputs, train_targets, _ = train_gen.next_batch()

        feed = {lprnet.inputs: train_inputs, lprnet.targets: train_targets}

        loss, steps, _, lr = sess.run( \
            [lprnet.loss, lprnet.global_step, lprnet.optimizer, lprnet.learning_rate], feed)

        if steps > 0 and steps % SAVE_STEPS == 0:
            ckpt_dir = CHECKPOINT_DIR
            ckpt_file = os.path.join(ckpt_dir, \
                'LPRnet_steps{}_loss_{:.3f}.ckpt'.format(steps, loss))
            if not os.path.isdir(ckpt_dir): os.mkdir(ckpt_dir)
            saver.save(sess, ckpt_file)
            print('checkpoint ', ckpt_file)
        return loss, steps, lr

    with tf.Session() as sess:
        sess.run(lprnet.init)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=30)
        restore_checkpoint(sess, saver, checkpoint)

        print('training...')
        for curr_epoch in range(TRAIN_EPOCHS):
            print('Epoch {}/{}'.format(curr_epoch + 1, TRAIN_EPOCHS))
            train_loss = lr = 0
            st = time.time()
            for batch in range(BATCH_PER_EPOCH):
                b_loss, steps, lr = train_batch(train_gen)
                train_loss += b_loss
            tim = time.time() - st
            train_loss /= BATCH_PER_EPOCH
            log = "train loss: {:.3f}, steps: {}, time: {:.1f}s, learning rate: {:.5f}"
            print(log.format(train_loss, steps, tim, lr))

            if curr_epoch > 0 and curr_epoch % VALIDATE_EPOCHS == 0:
                inference(sess, lprnet, val_gen)


def test(checkpoint):
    lprnet = LPRnet(is_train=False)
    test_gen = utils.DataIterator(img_dir=TEST_DIR)
    with tf.Session() as sess:
        sess.run(lprnet.init)
        saver = tf.train.Saver(tf.global_variables())

        if not restore_checkpoint(sess, saver, checkpoint, is_train=False):
            return

        inference(sess, lprnet, test_gen)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--ckpt", help="checkpoint",
                        type=str, default=None)
    parser.add_argument("-m", "--mode", help="train or test",
                        type=str, default="train")
    parser.add_argument("-r", "--runtime", help="train with runtime-generated images",
                        action='store_true')
    parser.add_argument("--img", help="image fullpath to test",
                        type=str, default=None)

    args = parser.parse_args()

    if args.mode == 'train':
        train(checkpoint=args.ckpt, runtime_generate=args.runtime)
    elif args.mode == 'test':
        if args.img is None:
            test(checkpoint=args.ckpt)
        else:
            infer_single_image(checkpoint=args.ckpt, fname=args.img)
    else:
        print('unknown mode:', args.mode)
--------------------------------------------------------------------------------
/model/LPRnet.py:
--------------------------------------------------------------------------------
import tensorflow as tf

TRAIN_EPOCHS = 300
INITIAL_LEARNING_RATE = 1e-3
DECAY_STEPS = 2000
LEARNING_RATE_DECAY_FACTOR = 0.9
MOMENTUM = 0.9

SAVE_STEPS = 1000
VALIDATE_EPOCHS = 10

BATCH_SIZE = 64
BATCH_PER_EPOCH = 50

TRAIN_DIR = 'train'
VAL_DIR = 'valid'
TEST_DIR = 'test'

CHECKPOINT_DIR = './checkpoint'

IMG_SIZE = [94, 24]  # width, height
CH_NUM = 3

CHARS = "ABCDEFGHJKLMNPQRSTUVWXYZ0123456789"  # exclude I, O
CHARS_DICT = {char: i for i, char in enumerate(CHARS)}
DECODE_DICT = {i: char for i, char in enumerate(CHARS)}

NUM_CLASS = len(CHARS) + 1  # +1 for the CTC blank label

def small_basic_block(inputdata, out_channel, name=None):
    with tf.variable_scope(name):
        out_div4 = int(out_channel / 4)
        conv1 = conv2d(inputdata, out_div4, ksize=[1, 1], name='conv1')
        relu1 = tf.nn.relu(conv1)

        conv2 = conv2d(relu1, out_div4, ksize=[1, 3], name='conv2')
        relu2 = tf.nn.relu(conv2)

        conv3 = conv2d(relu2, out_div4, ksize=[3, 1], name='conv3')
        relu3 = tf.nn.relu(conv3)

        conv4 = conv2d(relu3, out_channel, ksize=[1, 1], name='conv4')
        bn = tf.layers.batch_normalization(conv4)
        relu = tf.nn.relu(bn)
        return relu

def conv2d(inputdata, out_channel, ksize, stride=[1, 1, 1, 1], pad='SAME', name=None):

    with tf.variable_scope(name):
        in_channel = inputdata.get_shape().as_list()[3]
        filter_shape = [ksize[0], ksize[1], in_channel, out_channel]
        weights = tf.get_variable('w', filter_shape, dtype=tf.float32, initializer=tf.glorot_uniform_initializer())
        biases = tf.get_variable('b', [out_channel], dtype=tf.float32, initializer=tf.constant_initializer())
        conv = tf.nn.conv2d(inputdata, weights,
                            strides=stride,
                            padding=pad)
        add_bias = tf.nn.bias_add(conv, biases)
        return add_bias

def global_context(inputdata, ksize, strides, name=None):
    with tf.variable_scope(name):
        avg_pool = tf.nn.avg_pool(inputdata,
                                  ksize=ksize,
                                  strides=strides,
                                  padding='SAME')
        # normalize the pooled features by their mean of squares
        sqm = tf.reduce_mean(tf.square(avg_pool))
        out = tf.div(avg_pool, sqm)
        return out
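

# Shape walk-through for the default IMG_SIZE = [94, 24] input (note added for
# clarity; tensors are [N, H, W, C] and every conv/pool uses SAME padding):
#   inputs                 [N, 24, 94, 3]
#   conv1 -> max1          [N, 24, 94, 64]   (stride 1)
#   sbb1  -> max2          [N, 24, 47, 128]  (max2 strides 2 along the width)
#   sbb2 -> sbb3 -> max3   [N, 24, 24, 256]  (ceil(47 / 2) = 24)
# The global-context branches are pooled to width 24 and concatenated, so
# conv_out is [N, 24, 24, NUM_CLASS]; reduce_mean over axis=1 (height) gives
# logits of shape [N, 24, NUM_CLASS], i.e. 24 CTC timesteps (see README).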

class LPRnet:

    def __init__(self, is_train):

        width, height = IMG_SIZE
        self.inputs = tf.placeholder(
            tf.float32,
            shape=(None, height, width, CH_NUM),
            name='inputs')

        self.targets = tf.sparse_placeholder(tf.int32)

        logits = self.cnn_layers(self.inputs, is_train)
        logits_shape = tf.shape(logits)

        cur_batch_size = logits_shape[0]
        timesteps = logits_shape[1]

        seq_len = tf.fill([cur_batch_size], timesteps)

        # CTC ops expect time-major logits: [timesteps, batch, num_class]
        logits = tf.transpose(logits, (1, 0, 2))
        decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_len, merge_repeated=False)

        self.dense_decoded = tf.sparse_tensor_to_dense(decoded[0], default_value=-1, name='decoded')

        self.edit_dis = tf.reduce_sum(tf.edit_distance(tf.cast(decoded[0], tf.int32), \
                                                       self.targets, normalize=False))

        ctc_loss = tf.nn.ctc_loss(labels=self.targets, inputs=logits, sequence_length=seq_len)
        self.loss = tf.reduce_mean(ctc_loss)

        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                                   global_step,
                                                   DECAY_STEPS,
                                                   LEARNING_RATE_DECAY_FACTOR,
                                                   staircase=True)

        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)\
            .minimize(ctc_loss, global_step=global_step)

        self.logits = logits
        self.global_step = global_step
        self.optimizer = optimizer
        self.learning_rate = learning_rate

        self.init = tf.global_variables_initializer()

    def cnn_layers(self, inputs, is_train):

        ## backbone
        conv1 = conv2d(inputs, 64, ksize=[3, 3], name='conv1')
        conv1_bn = tf.layers.batch_normalization(conv1)
        conv1_relu = tf.nn.relu(conv1_bn)
        max1 = tf.nn.max_pool(conv1_relu,
                              ksize=[1, 3, 3, 1],
                              strides=[1, 1, 1, 1],
                              padding='SAME')
        sbb1 = small_basic_block(max1, 128, name='sbb1')
        max2 = tf.nn.max_pool(sbb1,
                              ksize=[1, 3, 3, 1],
                              strides=[1, 1, 2, 1],
                              padding='SAME')

        sbb2 = small_basic_block(max2, 256, name='sbb2')
        sbb3 = small_basic_block(sbb2, 256, name='sbb3')
        max3 = tf.nn.max_pool(sbb3,
                              ksize=[1, 3, 3, 1],
                              strides=[1, 1, 2, 1],
                              padding='SAME')

        dropout1 = tf.layers.dropout(max3, training=is_train)

        conv2 = conv2d(dropout1, 256, ksize=[1, 4], name='conv_d1')
        conv2_bn = tf.layers.batch_normalization(conv2)
        conv2_relu = tf.nn.relu(conv2_bn)

        dropout2 = tf.layers.dropout(conv2_relu, training=is_train)

        conv3 = conv2d(dropout2, NUM_CLASS, ksize=[13, 1], name='conv_d2')
        conv3_bn = tf.layers.batch_normalization(conv3)
        conv3_relu = tf.nn.relu(conv3_bn)

        ## global context
        scale1 = global_context(conv1,
                                ksize=[1, 1, 4, 1],
                                strides=[1, 1, 4, 1],
                                name='gc1')

        scale2 = global_context(sbb1,
                                ksize=[1, 1, 4, 1],
                                strides=[1, 1, 4, 1],
                                name='gc2')

        scale3 = global_context(sbb3,
                                ksize=[1, 1, 2, 1],
                                strides=[1, 1, 2, 1],
                                name='gc3')

        sqm = tf.reduce_mean(tf.square(conv3_relu))
        scale4 = tf.div(conv3_relu, sqm)

        #print(scale1.get_shape().as_list())
        #print(scale2.get_shape().as_list())
        #print(scale3.get_shape().as_list())
        #print(scale4.get_shape().as_list())

        gc_concat = tf.concat([scale1, scale2, scale3, scale4], 3)
        conv_out = conv2d(gc_concat, NUM_CLASS, ksize=(1, 1), name='conv_out')

        logits = tf.reduce_mean(conv_out, axis=1)
        #print(logits.get_shape().as_list())

        return logits
--------------------------------------------------------------------------------
/gen_plates.py:
--------------------------------------------------------------------------------
import argparse
import os
import random
import sys
import cv2
import numpy as np

from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont

from data_aug import data_augmentation
import model.LPRnet as model

class ImageGenerator:
    def __init__(self, ttf_dir, char_set, char_height=36):

        self.chars = char_set
        self.letters = []
        self.digits = []
        for c in char_set:
            if str.isalpha(c):
                self.letters.append(c)
            else:
                self.digits.append(c)

        self.char_height = char_height
        self.ttf_dir = ttf_dir
        self.fonts, self.font_char_ims = self.load_fonts(ttf_dir)

        # base plate colors in BGR order
        white = [1, 1, 1]
        yellow = [0, 1, 1]
        blue = [1, 0, 0]
        self.black_text_colors = [white, yellow]
        self.white_text_colors = [blue]

    def random_text_plate_colors(self, min_diff=0.3, black_text=True):
        # pick two intensities at least min_diff apart so text stays readable
        high = random.uniform(min_diff, 1.0)
        low = random.uniform(0.0, high - min_diff)
        text_color, plate_color = (low, high) if black_text else (high, low)
        return text_color, plate_color

    def load_fonts(self, folder_path):
        font_char_ims = {}
        fonts = [f for f in os.listdir(folder_path) if f.endswith('.ttf')]
        for font in fonts:
            font_char_ims[font] = dict(self.generate_char_imgs(\
                os.path.join(folder_path, font), self.char_height))
        return fonts, font_char_ims

    def generate_char_imgs(self, font_path, output_height):
        font_size = output_height * 4
        font = ImageFont.truetype(font_path, font_size)
        height = max(font.getsize(c)[1] for c in self.chars)

        for c in self.chars:
            width = font.getsize(c)[0]
            im = Image.new("RGBA", (width, height), (0, 0, 0))

            draw = ImageDraw.Draw(im)
            draw.text((0, 0), c, (255, 255, 255), font=font)
            scale = float(output_height) / height
            im = im.resize((int(width * scale), output_height), Image.ANTIALIAS)
            yield c, np.array(im)[:, :, 0]

    def generate_code(self):
        # random 1~2 letters + 1~2 digits + 2~3 letters;
        # '-' marks where extra spacing goes and is stripped later
        pre_n = random.randint(1, 2)
        pre_letters = [random.choice(self.letters) for _ in range(pre_n)]
        digit_n = random.randint(1, 2)
        digits = [random.choice(self.digits) for _ in range(digit_n)]
        post_n = random.randint(2, 3)
        post_letters = [random.choice(self.letters) for _ in range(post_n)]

        code = ''.join(pre_letters) + ''.join(digits) + '-' + ''.join(post_letters)
        return code

    def getOneRandomFont(self):
        return random.choice(self.fonts)

    def getCharGivenLabelFont(self, label, font):
        char_ims = self.font_char_ims[font]
        char_img = char_ims[label]
        return char_img, label

    def generate_images(self, number):

        images = []
        labels = []

        for _ in range(number):

            char_height = self.char_height
            code = self.generate_code()

            space = round(char_height * random.uniform(0.0, 0.3))
            char_spacing = []
            for c in code:
                if c == '-':
                    char_spacing[-1] += space
                else:
                    char_spacing.append(space)

            code = code.replace('-', '')

            # generate letter/number images
            char_ims = []
            char_font = self.getOneRandomFont()

            for i, c in enumerate(code):
                char, label = self.getCharGivenLabelFont(c, char_font)
                char_ims.append(char)

            char_width_sum = sum(char_im.shape[1] for char_im in char_ims)

            top_padding = round(random.uniform(0.1, 1.0) * char_height)
            bot_padding = round(random.uniform(0.1, 1.0) * char_height)
            left_padding = round(random.uniform(0.1, 1.0) * char_height)
            right_padding = round(random.uniform(0.1, 1.0) * char_height)

            Plate_h = (char_height + top_padding + bot_padding)
            Plate_w = (char_width_sum + left_padding + right_padding + sum(char_spacing[:-1]))

            out_shape = (Plate_h, Plate_w)
            text_mask = np.zeros(out_shape)

            x = left_padding
            y = top_padding

            for ind, c in enumerate(code):
                char_im = char_ims[ind]
                ix, iy = int(x), int(y)
                text_mask[iy:iy + char_im.shape[0], ix:ix + char_im.shape[1]] = char_im
                x += char_im.shape[1] + char_spacing[ind]

            is_black_text = random.choice([True, False])
            text_color, plate_color = self.random_text_plate_colors(black_text=is_black_text)

            plate_mask = (255. - text_mask)

            if is_black_text:
                color = np.array(random.choice(self.black_text_colors))
            else:
                color = np.array(random.choice(self.white_text_colors))

            w_color = color * plate_color

            # compose the plate: gray-level text over a colored background
            dim = (Plate_h, Plate_w, 3)
            Plate = np.ones(dim)
            Plate[:, :, 0] = text_mask * text_color
            Plate[:, :, 1] = text_mask * text_color
            Plate[:, :, 2] = text_mask * text_color
            Plate[:, :, 0] += plate_mask * w_color[0]
            Plate[:, :, 1] += plate_mask * w_color[1]
            Plate[:, :, 2] += plate_mask * w_color[2]
            Plate = Plate.astype(np.uint8)

            images.append(Plate)
            labels.append(code)

        return images, labels



if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("-s", "--save_dir", help="save directory",
                        type=str, default="./train")
    parser.add_argument("-n", "--num", help="number of images",
                        type=int, default=1000)

    args = parser.parse_args()
    save_dir = args.save_dir

    def labelToSaveFilename(label):
        rand_tail = random.randint(10000, 99999)
        name = '{}_{}.jpg'.format(label, rand_tail)
        return name

    FONT_HEIGHT = 32
    ttfCharGen = ImageGenerator('./fonts/', char_set=model.CHARS, char_height=FONT_HEIGHT)

    plates, labels = ttfCharGen.generate_images(args.num)
    #for plate in plates:
    #    cv2.imshow('', plate)
    #    cv2.waitKey(0)

    if not os.path.isdir(save_dir): os.mkdir(save_dir)

    for img, label in zip(plates, labels):
        full_path = os.path.join(save_dir, labelToSaveFilename(label))
        img = data_augmentation(img)
        cv2.imwrite(full_path, img)
--------------------------------------------------------------------------------