├── .gitignore
├── BaseTool.py
├── README.md
├── cnn_model.py
├── dataset.rar
├── dataset
│   └── map.py
└── picReader.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.bmp
*.jpg

--------------------------------------------------------------------------------
/BaseTool.py:
--------------------------------------------------------------------------------
__author__ = 'jmh081701'
from dataset.map import maps as aspectMap
import os
from picReader import img2mat
import random
class data_generator:
    def __init__(self,aspect='area',seperate_ratio=0.1):
        '''
        :param aspect: which dataset to load, one of [area, letter, province]
        :param seperate_ratio: fraction of the train/val images that is randomly held out as the test set
        :return:
        '''
        self.train_dir = "dataset\\train\\%s\\" %aspect
        self.val_dir = "dataset\\val\\%s\\" % aspect
        self.seperate_ratio = seperate_ratio

        self.data_vector_set = []   #all image matrices
        self.data_label_set = []    #all labels

        self.train_set = []         #indices of the training samples
        self.train_batch_index = 0
        self.train_epoch =0
        self.valid_set = []         #indices of the validation samples
        self.valid_batch_index = 0
        self.test_set = []          #indices of the test samples
        self.test_batch_index = 0

        self.classes = 0            #at most 34 classes; updated while loading train and val
        self.data_set_cnt = 0

        self.load_train()
        self.load_valid()

    def load_train(self):
        for rt,dirs,files in os.walk(self.train_dir):
            self.classes = max(self.classes,len(dirs))
            if len(dirs)==0 :
                #a leaf directory: it holds the images of one class, named after the label
                label = int(rt.split('\\')[-1])
                for name in files:
                    img_filename = os.path.join(rt,name)
                    vec = img2mat(img_filename)
                    self.data_vector_set.append(vec)
                    self.data_label_set.append(label)
                    if random.random() < self.seperate_ratio:
                        self.test_set.append(self.data_set_cnt)
                    else:
                        self.train_set.append(self.data_set_cnt)
                    self.data_set_cnt +=1
    def load_valid(self):
        for rt,dirs,files in os.walk(self.val_dir):
            self.classes = max(self.classes,len(dirs))
            if len(dirs)==0 :
                #a leaf directory: it holds the images of one class, named after the label
                label = int(rt.split('\\')[-1])
                for name in files:
                    img_filename = os.path.join(rt,name)
                    vec = img2mat(img_filename)
                    self.data_vector_set.append(vec)
                    self.data_label_set.append(label)
                    if random.random() < self.seperate_ratio:
                        self.test_set.append(self.data_set_cnt)
                    else:
                        self.valid_set.append(self.data_set_cnt)
                    self.data_set_cnt +=1
    def next_train_batch(self,batch=100):
        input_x =[]
        input_y =[]
        for i in range(batch):
            input_x.append(self.data_vector_set[self.train_set[(self.train_batch_index + i)%len(self.train_set)]])
            y = [0] * 34
            y[self.data_label_set[self.train_set[(self.train_batch_index +i)%len(self.train_set)]]] = 1
            input_y.append(y)
        self.train_batch_index +=batch
        if self.train_batch_index > len(self.train_set) :
            self.train_epoch +=1
            self.train_batch_index %=len(self.train_set)
        return input_x,input_y,self.train_epoch

    def next_valid_batch(self,batch=100):
        input_x =[]
        input_y =[]
        for i in range(batch):
            #draw a random sample from the validation subset (not from the whole dataset)
            index = self.valid_set[random.randint(0,len(self.valid_set)-1)]
            input_x.append(self.data_vector_set[index])
            y = [0] * 34
            y[self.data_label_set[index]] = 1
            input_y.append(y)
        self.valid_batch_index +=batch
        self.valid_batch_index %=len(self.valid_set)
        return input_x,input_y,self.train_epoch
    def next_test_batch(self,batch=100):
        input_x =[]
        input_y =[]
        for i in range(batch):
            input_x.append(self.data_vector_set[self.test_set[(self.test_batch_index + i)%len(self.test_set)]])
            y = [0] * 34
            y[self.data_label_set[self.test_set[(self.test_batch_index +i)%(len(self.test_set))]]] = 1
            input_y.append(y)
        self.test_batch_index +=batch
        if self.test_batch_index > len(self.test_set) :
            self.train_epoch +=1
            self.test_batch_index %=len(self.test_set)
        return input_x,input_y,self.train_epoch
if __name__ == '__main__':
    data_gen = data_generator()
    print(len(data_gen.test_set))
    print(data_gen.next_train_batch(100))
    print(data_gen.next_valid_batch(100))
    print(data_gen.next_test_batch(100))

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# car-board-reg
CNN-based license plate character recognition

# Blog post
https://blog.csdn.net/jmh1996/article/details/88951797 [CNN: license plate recognition based on CNN]

# Dataset
## Plate composition
To keep the experiment simple, the plate characters are assumed to be already segmented, so plate recognition reduces to character recognition (a multi-class classification task) over three regions, seven characters in total.
For example: `京A·F0236`
The first part, `京`, is the province the plate is registered in; the `A` that follows is the issuing authority; the five characters after the separator `·` are the serial number.
Province (省市):
("皖", "沪", "津", "渝", "冀", "晋", "蒙", "辽", "吉", "黑", "苏", "浙", "京", "闽", "赣", "鲁", "豫", "鄂", "湘", "粤", "桂", "琼", "川", "贵", "云", "藏", "陕", "甘", "青", "宁", "新")
Area, i.e. the issuing authority (发牌单位):
("A","B","C","D","E","F","G","H","I","J","K","L","M","N","O","P","Q","R","S","T","U","V","W","X","Y","Z")
Letter (字符):
("0","1","2","3","4","5","6","7","8","9","A","B","C","D","E","F","G","H","J","K","L","M","N","P","Q","R","S","T","U","V","W","X","Y","Z")
## Dataset layout:
```
└─dataset
    ├─map.py ----> label-to-character mapping
    ├─test   ----> test set
    ├─train  ----> training set
    │   ├─area
    │   ├─letter
    │   └─province
    └─val    ----> validation set
        ├─area
        ├─letter
        └─province
```
The train and val directories each contain three sub-directories (area, letter and province), one per recognition task.
```map
{
    "province": {
        "0": "皖",
        "1": "沪",
        "2": "津",
        "3": "渝",
        "4": "冀",
        "5": "晋",
        "6": "蒙",
        "7": "辽",
        "8": "吉",
        "9": "黑",
        "10": "苏",
        "11": "浙",
        "12": "京",
        "13": "闽",
        "14": "赣",
        "15": "鲁",
        "16": "豫",
        "17": "鄂",
        "18": "湘",
        "19": "粤",
        "20": "桂",
        "21": "琼",
        "22": "川",
        "23": "贵",
        "24": "云",
        "25": "藏",
        "26": "陕",
        "27": "甘",
        "28": "青",
        "29": "宁",
        "30": "新"
    },
    "area": {
        "0": "A",
        "1": "B",
        "2": "C",
        "3": "D",
        "4": "E",
        "5": "F",
        "6": "G",
        "7": "H",
        "8": "I",
        "9": "J",
        "10": "K",
        "11": "L",
        "12": "M",
        "13": "N",
        "14": "O",
        "15": "P",
        "16": "Q",
        "17": "R",
        "18": "S",
        "19": "T",
        "20": "U",
        "21": "V",
        "22": "W",
        "23": "X",
        "24": "Y",
        "25": "Z"
    },
    "letter": {
        "0": "0",
        "1": "1",
        "2": "2",
        "3": "3",
        "4": "4",
        "5": "5",
        "6": "6",
        "7": "7",
        "8": "8",
        "9": "9",
        "10": "A",
        "11": "B",
        "12": "C",
        "13": "D",
        "14": "E",
        "15": "F",
        "16": "G",
        "17": "H",
        "18": "J",
        "19": "K",
        "20": "L",
        "21": "M",
        "22": "N",
        "23": "P",
        "24": "Q",
        "25": "R",
        "26": "S",
        "27": "T",
        "28": "U",
        "29": "V",
        "30": "W",
        "31": "X",
        "32": "Y",
        "33": "Z"
    }
}
```
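Since every class directory is named after its numeric label, a prediction can be mapped back to the corresponding character through `dataset/map.py`. A minimal sketch (illustrative only; `aspect` and `pred_index` are made-up values):

```python
# Illustrative: map a predicted class index back to the plate character.
from dataset.map import maps

aspect = "province"                     # one of "province", "area", "letter"
pred_index = 12                         # e.g. the argmax of the softmax output
print(maps[aspect][str(pred_index)])    # -> 京
```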
"31": "X", 129 | "32": "Y", 130 | "33": "Z" 131 | } 132 | } 133 | ``` 134 | # PIL读取Image文件 135 | 本例提供的训练集里面的每个图片都是`20x20` 的二值化后的灰度图,例如: 136 | ![在这里插入图片描述](https://img-blog.csdnimg.cn/20190401143457733.png) 137 | 因此,我们需要使用PIL库或opencv库把灰度图转换为我们方便处理的数据形式。本人是先转化为list of list. 138 | 139 | picReader.py 140 | 141 | ```python 142 | __author__ = 'jmh081701' 143 | from PIL import Image 144 | def img2mat(img_filename): 145 | #把所有的图片都resize为20x20 146 | img = Image.open(img_filename) 147 | img = img.resize((20,20)) 148 | mat = [[img.getpixel((x,y)) for x in range(0,img.size[0])] for y in range(0,img.size[1])] 149 | return mat 150 | def test(): 151 | mat = img2mat("dataset\\test\\1.bmp") 152 | print(mat) 153 | print(mat[0][0],len(mat),len(mat[0])) 154 | if __name__ == '__main__': 155 | test() 156 | ``` 157 | 样例输出: 158 | 159 | ```shell 160 | [[0, 0, 0, 0, 144, 212, 74, 17, 15, 60, 60, 62, 64, 67, 67, 68, 35, 0, 0, 0], [0, 0, 0, 0, 28, 119, 255, 101, 61, 233, 255, 255, 255, 255, 255, 255, 241, 44, 0, 0], [0, 0, 0, 0, 0, 15, 170, 92, 8, 14, 34, 31, 29, 29, 24, 74, 226, 38, 0, 0], [0, 0, 67, 220, 83, 4, 0, 0, 0, 84, 160, 0, 0, 0, 0, 52, 170, 11, 0, 0], [0, 0, 71, 255, 105, 10, 0, 0, 75, 230, 246, 124, 5, 0, 0, 49, 188, 19, 0, 0], [0, 0, 64, 255, 113, 15, 152, 216, 246, 255, 255, 255, 226, 225, 27, 46, 255, 59, 0, 0], [0, 0, 53, 255, 120, 22, 172, 249, 255, 255, 255, 255, 255, 255, 35, 33, 213, 61, 0, 0], [0, 0, 43, 255, 139, 105, 243, 254, 130, 231, 255, 139, 217, 255, 37, 35, 234, 63, 0, 0], [0, 0, 34, 247, 151, 68, 166, 248, 143, 225, 255, 159, 219, 255, 41, 37, 240, 50, 0, 0], [0, 0, 26, 240, 136, 38, 143, 246, 255, 255, 255, 255, 255, 255, 43, 29, 168, 0, 0, 0], [0, 0, 18, 231, 142, 44, 135, 246, 255, 255, 255, 230, 190, 98, 6, 25, 210, 49, 2, 0], [0, 0, 17, 223, 147, 49, 112, 214, 123, 226, 255, 147, 0, 0, 0, 28, 218, 10, 1, 0], [0, 0, 16, 212, 154, 56, 0, 0, 4, 69, 249, 149, 148, 216, 46, 18, 205, 94, 13, 0], [0, 0, 15, 200, 157, 59, 0, 11, 45, 255, 255, 242, 244, 255, 57, 3, 33, 13, 2, 0], [0, 0, 15, 196, 164, 66, 0, 66, 253, 198, 198, 198, 200, 225, 154, 87, 252, 90, 18, 0], [0, 0, 14, 184, 171, 73, 0, 8, 31, 1, 0, 0, 1, 16, 8, 13, 255, 110, 25, 0], [0, 0, 13, 175, 177, 79, 0, 0, 0, 0, 0, 0, 0, 0, 8, 37, 255, 117, 30, 0], [0, 0, 10, 134, 147, 69, 0, 0, 0, 0, 0, 0, 0, 0, 29, 127, 230, 24, 5, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 18, 2, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] 161 | 0 20 20 162 | ``` 163 | # 模型设计 164 | 数据集已经把一个车牌的三个部分都分开了,所以可以设计三个模型分别去识别这三部分。在本例中,本人为了简单期间,三个部分用了用一个CNN 网络结构,但是每个网络结构里面的具体参数是各自独立的。 165 | CNN网络结构: 166 | 1. 输入层:20x20 167 | 2. 第一层卷积:卷积核大小:3x3,卷积核个数:32,Stride 步长:1,Same 卷积 168 | 3. 第二层卷积:卷积核大下:3x3,卷积核个数:64,Stride 步长:1,Same卷积 169 | (两个卷积级联,效果就是5x5的卷积核,但是减少了参数个数) 170 | 4. 第二层池化:池化大小:2x2,max pool,Stride 步长:2 171 | 5. 第三层卷积:卷积核大小:3x3,卷积核个数:8,Stride 步长:1,Same卷积 172 | 6. 第三层池化:池化大小:2x2,max pooling,Stride :2。应该得到8个5x5的特征图。 173 | 平坦化:得到8x5x5=200维的向量 174 | 7. 第四层全连接:512个神经元,激活函数为relu。 175 | 8. 
The fifth layer is the classification layer: its 34 neurons allow for at most 34 classes. For province only the first 31 classes are meaningful, for area only the first 26, and for letter all 34 neurons are used. When generating the training data, the correct label therefore has to be encoded as a 34-dimensional one-hot vector.
# Data processing
The data-handling module mimics the way the MNIST dataset helpers are usually written: one class implements next_train_batch, next_test_batch and next_valid_batch.
BaseTool.py

```python
__author__ = 'jmh081701'
from dataset.map import maps as aspectMap
import os
from picReader import img2mat
import random
class data_generator:
    def __init__(self,aspect='area',seperate_ratio=0.1):
        '''
        :param aspect: which dataset to load, one of [area, letter, province]
        :param seperate_ratio: fraction of the train/val images that is randomly held out as the test set
        :return:
        '''
        self.train_dir = "dataset\\train\\%s\\" %aspect
        self.val_dir = "dataset\\val\\%s\\" % aspect
        self.seperate_ratio = seperate_ratio

        self.data_vector_set = []   #all image matrices
        self.data_label_set = []    #all labels

        self.train_set = []         #indices of the training samples
        self.train_batch_index = 0
        self.train_epoch =0
        self.valid_set = []         #indices of the validation samples
        self.valid_batch_index = 0
        self.test_set = []          #indices of the test samples
        self.test_batch_index = 0

        self.classes = 0            #at most 34 classes; updated while loading train and val
        self.data_set_cnt = 0

        self.load_train()
        self.load_valid()

    def load_train(self):
        for rt,dirs,files in os.walk(self.train_dir):
            self.classes = max(self.classes,len(dirs))
            if len(dirs)==0 :
                #a leaf directory: it holds the images of one class, named after the label
                label = int(rt.split('\\')[-1])
                for name in files:
                    img_filename = os.path.join(rt,name)
                    vec = img2mat(img_filename)
                    self.data_vector_set.append(vec)
                    self.data_label_set.append(label)
                    if random.random() < self.seperate_ratio:
                        self.test_set.append(self.data_set_cnt)
                    else:
                        self.train_set.append(self.data_set_cnt)
                    self.data_set_cnt +=1
    def load_valid(self):
        for rt,dirs,files in os.walk(self.val_dir):
            self.classes = max(self.classes,len(dirs))
            if len(dirs)==0 :
                #a leaf directory: it holds the images of one class, named after the label
                label = int(rt.split('\\')[-1])
                for name in files:
                    img_filename = os.path.join(rt,name)
                    vec = img2mat(img_filename)
                    self.data_vector_set.append(vec)
                    self.data_label_set.append(label)
                    if random.random() < self.seperate_ratio:
                        self.test_set.append(self.data_set_cnt)
                    else:
                        self.valid_set.append(self.data_set_cnt)
                    self.data_set_cnt +=1
    def next_train_batch(self,batch=100):
        input_x =[]
        input_y =[]
        for i in range(batch):
            input_x.append(self.data_vector_set[self.train_set[(self.train_batch_index + i)%len(self.train_set)]])
            y = [0] * 34
            y[self.data_label_set[self.train_set[(self.train_batch_index +i)%len(self.train_set)]]] = 1
            input_y.append(y)
        self.train_batch_index +=batch
        if self.train_batch_index > len(self.train_set) :
            self.train_epoch +=1
            self.train_batch_index %=len(self.train_set)
        return input_x,input_y,self.train_epoch

    def next_valid_batch(self,batch=100):
        input_x =[]
        input_y =[]
        for i in range(batch):
            #draw a random sample from the validation subset (not from the whole dataset)
            index = self.valid_set[random.randint(0,len(self.valid_set)-1)]
            input_x.append(self.data_vector_set[index])
            y = [0] * 34
            y[self.data_label_set[index]] = 1
            input_y.append(y)
        self.valid_batch_index +=batch
        self.valid_batch_index %=len(self.valid_set)
        return input_x,input_y,self.train_epoch
    def next_test_batch(self,batch=100):
        input_x =[]
        input_y =[]
        for i in range(batch):
            input_x.append(self.data_vector_set[self.test_set[(self.test_batch_index + i)%len(self.test_set)]])
            y = [0] * 34
            y[self.data_label_set[self.test_set[(self.test_batch_index +i)%(len(self.test_set))]]] = 1
            input_y.append(y)
        self.test_batch_index +=batch
        if self.test_batch_index > len(self.test_set) :
            self.train_epoch +=1
            self.test_batch_index %=len(self.test_set)
        return input_x,input_y,self.train_epoch
if __name__ == '__main__':
    data_gen = data_generator()
    print(len(data_gen.test_set))
    print(data_gen.next_train_batch(50))
    print(data_gen.next_valid_batch(50)[1])
    print(data_gen.next_train_batch(30))
```

# Building the CNN model
cnn_model.py

```python
__author__ = 'jmh081701'
import tensorflow as tf
from BaseTool import data_generator

batch_size = 100        # mini-batch size
learning_rate=1e-4      # learning rate
aspect = "area"
data_gen = data_generator(aspect)

input_x =tf.placeholder(dtype=tf.float32,shape=[None,20,20],name='input_x')
input_y =tf.placeholder(dtype=tf.float32,shape=[None,34],name='input_y')

with tf.name_scope('conv1'):
    W_C1 = tf.Variable(tf.truncated_normal(shape=[3,3,1,32],stddev=0.1))
    b_C1 = tf.Variable(tf.constant(0.1,tf.float32,shape=[32]))

    X=tf.reshape(input_x,[-1,20,20,1])
    featureMap_C1 = tf.nn.relu(tf.nn.conv2d(X,W_C1,strides=[1,1,1,1],padding='SAME') + b_C1 )

with tf.name_scope('conv2'):
    W_C2 = tf.Variable(tf.truncated_normal(shape=[3,3,32,64],stddev=0.1))
    b_C2 = tf.Variable(tf.constant(0.1,tf.float32,shape=[64]))
    featureMap_C2 = tf.nn.relu(tf.nn.conv2d(featureMap_C1,W_C2,strides=[1,1,1,1],padding='SAME') + b_C2)

with tf.name_scope('pooling2'):
    featureMap_S2 = tf.nn.max_pool(featureMap_C2,ksize=[1,2,2,1],strides=[1,2,2,1],padding='VALID')

with tf.name_scope('conv3'):
    W_C3 = tf.Variable(tf.truncated_normal(shape=[3,3,64,8],stddev=0.1))
    b_C3 = tf.Variable(tf.constant(0.1,shape=[8],dtype=tf.float32))
    featureMap_C3 = tf.nn.relu(tf.nn.conv2d(featureMap_S2,filter=W_C3,strides=[1,1,1,1],padding='SAME')+ b_C3)

with tf.name_scope('pooling3'):
    featureMap_S3 = tf.nn.max_pool(featureMap_C3,[1,2,2,1],[1,2,2,1],padding='VALID')

with tf.name_scope('fulnet'):
    featureMap_flatten = tf.reshape(featureMap_S3,[-1,5*5*8])
    W_F4 = tf.Variable(tf.truncated_normal(shape=[5*5*8,512],stddev=0.1))
    b_F4 = tf.Variable(tf.constant(0.1,shape=[512],dtype=tf.float32))
    out_F4 = tf.nn.relu(tf.matmul(featureMap_flatten,W_F4) + b_F4)
    #out_F4 =tf.nn.dropout(out_F4,keep_prob=0.5)
with tf.name_scope('output'):
    W_OUTPUT = tf.Variable(tf.truncated_normal(shape=[512,34],stddev=0.1))
    b_OUTPUT = tf.Variable(tf.constant(0.1,shape=[34],dtype=tf.float32))
    logits = tf.matmul(out_F4,W_OUTPUT)+b_OUTPUT

loss = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(labels=input_y,logits=logits))
train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)
predictY = tf.nn.softmax(logits)
y_pred=tf.arg_max(predictY,1)
bool_pred=tf.equal(tf.arg_max(input_y,1),y_pred)
right_rate=tf.reduce_mean(tf.to_float(bool_pred))

saver = tf.train.Saver()
def load_model(sess,dir,modelname):
    ckpt=tf.train.get_checkpoint_state(dir)
    if ckpt and ckpt.model_checkpoint_path:
        print("*"*30)
        print("load lastest model......")
        saver.restore(sess,dir+".\\"+modelname)
        print("*"*30)
def save_model(sess,dir,modelname):
    saver.save(sess,dir+modelname)
dir = r".//"
modelname = aspect

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    step = 1
    display_interval=200
    max_epoch = 50
    epoch = 0
    acc = 0
    load_model(sess,dir=dir,modelname=modelname)
    while True :
        if step % display_interval ==0:
            image_batch,label_batch,epoch = data_gen.next_valid_batch(batch_size)
            acc = sess.run(right_rate,feed_dict={input_x:image_batch,input_y:label_batch})
            print({'!'*30+str(epoch)+":"+str(step):acc})
        image_batch,label_batch,epoch = data_gen.next_train_batch(batch_size)
        sess.run([loss,train_op],{input_x:image_batch,input_y:label_batch})
        if(epoch> max_epoch):
            break
        step +=1
    while True :
        test_img,test_lab,test_epoch = data_gen.next_test_batch(batch_size)
        test_acc = sess.run(right_rate,{input_x:test_img,input_y:test_lab})
        acc = test_acc * 0.8 + acc * 0.2    #exponential moving average
        if(test_epoch!=epoch):
            print({"Test Over..... acc:":acc})
            break
    save_model(sess,dir,modelname)

```
# Training results
area:
```bash
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!3:200': 0.34}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!6:400': 0.61}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!9:600': 0.78}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!13:800': 0.73}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!16:1000': 0.8}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!19:1200': 0.88}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!23:1400': 0.76}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!25:1600': 0.86}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!28:1800': 0.89}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!32:2000': 0.83}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!35:2200': 0.87}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!38:2400': 0.93}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!42:2600': 0.89}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!45:2800': 0.9}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!48:3000': 0.95}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!50:3200': 0.97}
{'Test Over..... acc:': 0.9042283506058594}
```
province:

```bash
******************************
load lastest model......
******************************
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!7:200': 0.9}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!15:400': 0.88}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!22:600': 0.91}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!29:800': 0.92}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!36:1000': 0.95}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!44:1200': 0.99}
{'Test Over..... acc:': 0.7786719818115235}
```
As you can see, the model overfits somewhat: it does very well on the validation set but falls apart on the test set.
So dropout is added:

    out_F4 =tf.nn.dropout(out_F4,keep_prob=0.5)

The result is still pretty poor, haha, what a rubbish network, 233.
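(Side note: `keep_prob=0.5` above is hard-coded, so dropout also stays active during validation and testing. A common alternative, sketched here only and not what this repo does, is to feed the keep probability through a placeholder:)

```python
# Sketch only: dropout is active during training and becomes a no-op at evaluation time.
keep_prob = tf.placeholder(dtype=tf.float32, name='keep_prob')
out_F4 = tf.nn.dropout(out_F4, keep_prob=keep_prob)
# training step:       feed_dict={..., keep_prob: 0.5}
# validation / test:   feed_dict={..., keep_prob: 1.0}
```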
```bash
******************************
load lastest model......
******************************
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!7:200': 0.91}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!15:400': 0.92}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!24:600': 0.94}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!32:800': 0.92}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!40:1000': 0.95}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!48:1200': 0.98}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!55:1400': 0.94}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!63:1600': 0.96}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!71:1800': 0.95}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!79:2000': 0.99}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!87:2200': 0.95}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!95:2400': 0.98}
{'Test Over..... acc:': 0.857055978012085}
```
letter:
There are many more classes here, so max_epoch is raised a bit.
```bash
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!1:200': 0.02}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!3:400': 0.06}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!5:600': 0.07}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!6:800': 0.05}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!8:1000': 0.29}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!10:1200': 0.21}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!11:1400': 0.34}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!13:1600': 0.41}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!15:1800': 0.51}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!16:2000': 0.48}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!18:2200': 0.51}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!20:2400': 0.68}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!21:2600': 0.48}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!23:2800': 0.64}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!24:3000': 0.76}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!26:3200': 0.65}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!27:3400': 0.64}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!29:3600': 0.71}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!31:3800': 0.77}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!32:4000': 0.75}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!34:4200': 0.74}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!36:4400': 0.82}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!37:4600': 0.8}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!39:4800': 0.77}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!41:5000': 0.82}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!42:5200': 0.9}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!44:5400': 0.72}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!46:5600': 0.88}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!48:5800': 0.94}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!48:6000': 0.58}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!50:6200': 0.85}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!52:6400': 0.91}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!53:6600': 0.85}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!55:6800': 0.89}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!57:7000': 0.91}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!58:7200': 0.9}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!60:7400': 0.92}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!62:7600': 0.97}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!63:7800': 0.9}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!65:8000': 0.85}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!67:8200': 0.91}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!69:8400': 0.94}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!70:8600': 0.84}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!72:8800': 0.89}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!73:9000': 0.93}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!74:9200': 0.9}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!76:9400': 0.92}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!78:9600': 0.97}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!79:9800': 0.95}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!81:10000': 0.96}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!83:10200': 0.96}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!84:10400': 0.93}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!86:10600': 0.81}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!88:10800': 0.97}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!90:11000': 0.94}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!91:11200': 0.84}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!93:11400': 0.99}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!95:11600': 0.94}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!96:11800': 0.93}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!97:12000': 0.95}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!99:12200': 0.96}
{'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!100:12400': 0.98}
{'Test Over..... acc:': 0.8561504802688517}
```

--------------------------------------------------------------------------------
/cnn_model.py:
--------------------------------------------------------------------------------
__author__ = 'jmh081701'
import tensorflow as tf
from BaseTool import data_generator

batch_size = 100    # mini-batch size
lr=0.001            # learning rate; decayed to 80% of its value every 20 epochs
aspect = "letter"
data_gen = data_generator(aspect)

input_x =tf.placeholder(dtype=tf.float32,shape=[None,20,20],name='input_x')
input_y =tf.placeholder(dtype=tf.float32,shape=[None,34],name='input_y')
input_learning_rate = tf.placeholder(dtype=tf.float32,name='learning_rate')

with tf.name_scope('conv1'):
    W_C1 = tf.Variable(tf.truncated_normal(shape=[3,3,1,32],stddev=0.1))
    b_C1 = tf.Variable(tf.constant(0.1,tf.float32,shape=[32]))

    X=tf.reshape(input_x,[-1,20,20,1])
    featureMap_C1 = tf.nn.relu(tf.nn.conv2d(X,W_C1,strides=[1,1,1,1],padding='SAME') + b_C1 )

with tf.name_scope('conv2'):
    W_C2 = tf.Variable(tf.truncated_normal(shape=[3,3,32,64],stddev=0.1))
    b_C2 = tf.Variable(tf.constant(0.1,tf.float32,shape=[64]))
    featureMap_C2 = tf.nn.relu(tf.nn.conv2d(featureMap_C1,W_C2,strides=[1,1,1,1],padding='SAME') + b_C2)

with tf.name_scope('pooling2'):
    featureMap_S2 = tf.nn.max_pool(featureMap_C2,ksize=[1,2,2,1],strides=[1,2,2,1],padding='VALID')

with tf.name_scope('conv3'):
    W_C3 = tf.Variable(tf.truncated_normal(shape=[3,3,64,8],stddev=0.1))
    b_C3 = tf.Variable(tf.constant(0.1,shape=[8],dtype=tf.float32))
    featureMap_C3 = tf.nn.relu(tf.nn.conv2d(featureMap_S2,filter=W_C3,strides=[1,1,1,1],padding='SAME')+ b_C3)

with tf.name_scope('pooling3'):
    featureMap_S3 = tf.nn.max_pool(featureMap_C3,[1,2,2,1],[1,2,2,1],padding='VALID')

with tf.name_scope('fulnet'):
    featureMap_flatten = tf.reshape(featureMap_S3,[-1,5*5*8])
    W_F4 = tf.Variable(tf.truncated_normal(shape=[5*5*8,512],stddev=0.1))
    b_F4 = tf.Variable(tf.constant(0.1,shape=[512],dtype=tf.float32))
    out_F4 = tf.nn.relu(tf.matmul(featureMap_flatten,W_F4) + b_F4)
    out_F4 =tf.nn.dropout(out_F4,keep_prob=0.5)
with tf.name_scope('output'):
    W_OUTPUT = tf.Variable(tf.truncated_normal(shape=[512,34],stddev=0.1))
    b_OUTPUT = tf.Variable(tf.constant(0.1,shape=[34],dtype=tf.float32))
    logits = tf.matmul(out_F4,W_OUTPUT)+b_OUTPUT

loss = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(labels=input_y,logits=logits))
train_op = tf.train.AdamOptimizer(learning_rate=input_learning_rate).minimize(loss)
predictY = tf.nn.softmax(logits)
y_pred=tf.arg_max(predictY,1)
bool_pred=tf.equal(tf.arg_max(input_y,1),y_pred)
right_rate=tf.reduce_mean(tf.to_float(bool_pred))

saver = tf.train.Saver()
def load_model(sess,dir,modelname):
    ckpt=tf.train.get_checkpoint_state(dir)
    if ckpt and ckpt.model_checkpoint_path:
        print("*"*30)
        print("load lastest model......")
        saver.restore(sess,dir+".\\"+modelname)
        print("*"*30)
def save_model(sess,dir,modelname):
    saver.save(sess,dir+modelname)
dir = r".\\parameter\\%s\\"%aspect
modelname = aspect

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    step = 1
    display_interval=200
    max_epoch = 500
    epoch = 1
    acc = 0
    load_model(sess,dir=dir,modelname=modelname)
    while True :
        if step % display_interval ==0:
            image_batch,label_batch,epoch = data_gen.next_valid_batch(batch_size)
            acc = sess.run(right_rate,feed_dict={input_x:image_batch,input_y:label_batch,input_learning_rate:lr})
            print({'!'*30+str(epoch)+":"+str(step):acc})
        image_batch,label_batch,epoch = data_gen.next_train_batch(batch_size)
        #the learning-rate placeholder must be fed here as well, otherwise the train step fails
        sess.run([loss,train_op],{input_x:image_batch,input_y:label_batch,input_learning_rate:lr})
        if(epoch> max_epoch):
            break
        step +=1
        if (epoch % 20) ==0:
            lr =lr * 0.8
    while True :
        test_img,test_lab,test_epoch = data_gen.next_test_batch(batch_size)
        test_acc = sess.run(right_rate,feed_dict={input_x:test_img,input_y:test_lab,input_learning_rate:lr})
        acc = test_acc * 0.8 + acc * 0.2    #exponential moving average
        if(test_epoch!=epoch):
            print({"Test Over..... acc:":acc})
            break

    save_model(sess,dir,modelname)

--------------------------------------------------------------------------------
/dataset.rar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jmhIcoding/car-board-reg/4cdc0426248099cf2480a4e478c4e2c4682eb130/dataset.rar

--------------------------------------------------------------------------------
/dataset/map.py:
--------------------------------------------------------------------------------
maps={
    "province":
        {
            "0":"皖",
            "1":"沪",
            "2":"津",
            "3":"渝",
            "4":"冀",
            "5":"晋",
            "6":"蒙",
            "7":"辽",
            "8":"吉",
            "9":"黑",
            "10":"苏",
            "11":"浙",
            "12":"京",
            "13":"闽",
            "14":"赣",
            "15":"鲁",
            "16":"豫",
            "17":"鄂",
            "18":"湘",
            "19":"粤",
            "20":"桂",
            "21":"琼",
            "22":"川",
            "23":"贵",
            "24":"云",
            "25":"藏",
            "26":"陕",
            "27":"甘",
            "28":"青",
            "29":"宁",
            "30":"新"
        },
    "area":
        {
            "0":"A",
            "1":"B",
            "2":"C",
            "3":"D",
            "4":"E",
            "5":"F",
            "6":"G",
            "7":"H",
            "8":"I",
            "9":"J",
            "10":"K",
            "11":"L",
            "12":"M",
            "13":"N",
            "14":"O",
            "15":"P",
            "16":"Q",
            "17":"R",
            "18":"S",
            "19":"T",
            "20":"U",
            "21":"V",
            "22":"W",
            "23":"X",
            "24":"Y",
            "25":"Z"
        },
    "letter":
        {
            "0":"0",
            "1":"1",
            "2":"2",
            "3":"3",
            "4":"4",
            "5":"5",
            "6":"6",
            "7":"7",
            "8":"8",
            "9":"9",
            "10":"A",
            "11":"B",
            "12":"C",
            "13":"D",
            "14":"E",
            "15":"F",
            "16":"G",
            "17":"H",
            "18":"J",
            "19":"K",
            "20":"L",
            "21":"M",
            "22":"N",
            "23":"P",
            "24":"Q",
            "25":"R",
            "26":"S",
            "27":"T",
            "28":"U",
            "29":"V",
            "30":"W",
            "31":"X",
            "32":"Y",
            "33":"Z"
        }
}

--------------------------------------------------------------------------------
/picReader.py:
--------------------------------------------------------------------------------
__author__ = 'jmh081701'
from PIL import Image
def img2mat(img_filename):
    #resize every image to 20x20 and return its pixels as a list of lists
    img = Image.open(img_filename)
    img = img.resize((20,20))
    mat = [[img.getpixel((x,y)) for x in range(0,img.size[0])] for y in range(0,img.size[1])]
    return mat
def test():
    mat = img2mat("dataset\\test\\1.bmp")
    print(mat)
    print(mat[0][0],len(mat),len(mat[0]))
if __name__ == '__main__':
    test()

--------------------------------------------------------------------------------