├── README.md ├── car_identify.ipynb ├── car_identify.py ├── check_image.py ├── tools └── mytool.py └── 测试结果.jpg /README.md: -------------------------------------------------------------------------------- 1 | 2 | 使用神经网络对汽车车型进行识别(对车头,车身,车尾任意部分拍照都可以识别) 3 | 4 | 本程序的功能是对拍照后的车辆图片进行识别,判断该车是什么品牌的车,是品牌的什么车型。例如你照了一辆车,它可以判断他是大众的速腾,还是丰田的开凯美瑞,还是日产的天籁等等。之所做这么一款程序,源自本人以前的一个心愿吧。很久前在在唯品会工作期间,开发过一个品牌logo识别程序,能识别出品牌的程序,那时就想,能不能开发个识别车辆车型的程序,世界那么大,那么品牌的车,人只能识别其中很少品牌,如果能做个程序能通过照相把车品牌识别出来是多么酷,但是那时的想法只是通过车标logo进行匹配。直到后来见到一个懂车帝的app,我用自己的手机拍了自己的车,它精准识别出我的车是速腾,没有照车标,我是有点惊讶的,当时想它是怎么做到的。知道后来自己学了神经网络和深度学习的相关技术,就想或许卷积神经网络能很好的解决车型识别的问题。于是我就动手做了这样一个识别汽车车型的神经网络模型,识别率达83%左右(更多数据可以再进一步提高准确率),当然也能很好识别出我的速腾车。由于训练数据都是来自百度图片,百度图片有点坑,搜速腾车的图片可能给你混杂辆迈腾或者高尔夫,或者是车内内饰图片~~我也没有那么多时间去人工处理,所以对精准度有点影响,不然精准度会更高。当然有个good idea是拿goole图片训练个模型去检查百度中错误的图片,然后用百度图片的模型去检查google图片的错误~~这样可以达到自动去除错误的图片。 5 | 6 | 7 | 如果你觉得这代码对你的工作或者学习有用,请点个星 8 | 9 | 需要交流或者帮助的,可以发邮件给我676995058@qq.com,微信QQ同号: 676995058 10 | 11 | 测试结果如下 12 | 13 | ![Image text](https://github.com/blueapplehe/car_identify/blob/master/%E6%B5%8B%E8%AF%95%E7%BB%93%E6%9E%9C.jpg) 14 | -------------------------------------------------------------------------------- /car_identify.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | import os,shutil,sys 4 | sys.path.append('/data/py/lib/') 5 | import keras 6 | import time 7 | from keras import models 8 | from keras import layers 9 | from keras import optimizers 10 | #from keras.applications import VGG16 11 | from keras.applications import xception 12 | from keras.applications.resnet import ResNet50 13 | from keras.applications.resnet import preprocess_input, decode_predictions 14 | from keras.applications.inception_v3 import InceptionV3 15 | from keras.preprocessing import image 16 | from keras.layers import Dense, GlobalAveragePooling2D 17 | from keras import backend as K 18 | from keras.models import Model 19 | from keras.preprocessing.image import ImageDataGenerator 20 | import matplotlib.pyplot as plt 21 | ####定义一些常用的训练调整参数########### 22 | epochs=5 #定义训练轮数 23 | batch_size=20 #每批数量 24 | lock_layer_num=0; #锁住的层数 25 | lr=1e-4 #学习率 26 | dense_num=256 #连接层数量 27 | pre_train_epochs=1 #预训练轮数,0表示不进行预训练 28 | img_height=299 #训练图片高度 29 | img_width=299 #训练图片宽度 30 | is_load_model=False #是否加载自己训练的历史模型 31 | ########################## 32 | 33 | base_dir='/data/keras/download/qiche'#汽车图片根目录 34 | train_dir=os.path.join(base_dir,'train')#汽车图片训练目录 35 | validation_dir=os.path.join(base_dir,'validation')#汽车图片验证目录 36 | test_dir=os.path.join(base_dir,'test')#汽车图片测试目录 37 | #精选一些品牌的汽车种类,引入更多的品牌的种类,不会大影响识别准确率,放心推广到更多的品牌和车型, 38 | #这里不演示更多的品牌,是因为我的显卡太烂了,图片太多,训练速度有点满 39 | mod_names=["速腾","迈腾","雷凌","卡罗拉","凯美瑞", 40 | "天籁","雅阁","朗逸","威驰","福克斯", 41 | "福睿斯","蒙迪欧","轩逸","帕萨特","途观", 42 | "飞度","锋范"] 43 | 44 | mod_num=len(mod_names)#汽车车型总数 45 | 46 | 47 | #使用图片数据增强,降低拟合的有效手段 48 | train_datagen=ImageDataGenerator( 49 | rescale=1./255, 50 | rotation_range=30, 51 | width_shift_range=0.2, 52 | height_shift_range=0.2, 53 | shear_range=0.2, 54 | zoom_range=0.2, 55 | horizontal_flip=True, 56 | fill_mode='nearest') 57 | 58 | #验证,测试数据不能进行数据增强 59 | test_datagen=ImageDataGenerator(rescale=1./255) 60 | train_generator=train_datagen.flow_from_directory( 61 | train_dir, 62 | target_size=(img_height,img_width), 63 | batch_size=batch_size, 64 | class_mode='categorical' 65 | ) 66 | 67 | validation_generator=test_datagen.flow_from_directory( 68 | validation_dir, 69 | target_size=(img_height,img_width), 70 | batch_size=batch_size, 71 | class_mode='categorical' 72 | ) 73 | 74 | 75 | if is_load_model is False: 76 | # 构建不带分类器的预训练模型 77 | base_model = xception.Xception(weights="imagenet",include_top=False,input_shape=(img_height,img_width,3)) 78 | 79 | # 添加全局平均池化层 80 | x = base_model.output 81 | x = GlobalAveragePooling2D()(x) 82 | 83 | # 添加一个全连接层 84 | x = Dense(dense_num, activation='relu')(x) 85 | 86 | # 添加一个分类器 87 | predictions = Dense(mod_num, activation='softmax')(x) 88 | 89 | # 构建我们需要训练的完整模型 90 | model = Model(inputs=base_model.input, outputs=predictions) 91 | 92 | # 锁住所有 Xception 的卷积层 93 | for layer in base_model.layers: 94 | layer.trainable = False 95 | 96 | #预训练 97 | if pre_train_epochs>0: 98 | model.compile(optimizer=optimizers.rmsprop_v2.RMSProp(learning_rate=1e-3), loss='categorical_crossentropy',metrics=['acc']) 99 | history=model.fit_generator( 100 | train_generator, 101 | steps_per_epoch=train_generator.n/train_generator.batch_size, 102 | epochs=pre_train_epochs, 103 | validation_data=validation_generator, 104 | validation_steps=validation_generator.n/validation_generator.batch_size 105 | ) 106 | 107 | # 现在顶层应该训练好了,开始微调 Xception的卷积层。 108 | # 锁住底下的几层,然后训练其余的顶层。 109 | # 看看每一层的名字和层号,看看我们应该锁多少层呢: 110 | # for i, layer in enumerate(base_model.layers): 111 | # print(i, layer.name) 112 | 113 | # 锁住的层数 114 | for layer in model.layers[:lock_layer_num]: 115 | layer.trainable = False 116 | for layer in model.layers[lock_layer_num:]: 117 | layer.trainable = True 118 | 119 | # 设置一个很低的学习率,使用 SGD 来微调 120 | model.compile(optimizer=optimizers.rmsprop_v2.RMSprop(lr=lr), loss='categorical_crossentropy',metrics=['acc']) 121 | 122 | # 继续训练模型 123 | history=model.fit_generator( 124 | train_generator, 125 | steps_per_epoch=train_generator.n/train_generator.batch_size, 126 | epochs=epochs, 127 | validation_data=validation_generator, 128 | validation_steps=validation_generator.n/validation_generator.batch_size 129 | ) 130 | #保存训练好的模型 131 | time_t=time.strftime("%m%d%H%M", time.localtime()) 132 | model.save('/data/keras/models/%s.h'%time_t) 133 | 134 | 135 | #显示训练过程中精度变化 136 | if is_load_model is False: 137 | acc=history.history['acc'] 138 | val_acc=history.history['val_acc'] 139 | loss=history.history['loss'] 140 | val_loss=history.history['val_loss'] 141 | epochs=range(1,len(acc)+1) 142 | plt.plot(epochs,acc,'bo',label='Training acc') 143 | plt.plot(epochs,val_acc,'b',label='Validation acc') 144 | plt.legend() 145 | plt.figure() 146 | plt.show() 147 | 148 | #显示测试结果 149 | from keras.preprocessing import image 150 | import numpy as np 151 | import cv2 152 | from tools.mytool import MyTool 153 | test_imgs=['/data/keras/download/qiche/timg2.jpg', 154 | '/data/keras/download/qiche/su1.jpg', 155 | '/data/keras/download/qiche/su2.jpg', 156 | '/data/keras/download/qiche/su3.jpg', 157 | '/data/test/su21.jpg', 158 | '/data/test/su22.jpg', 159 | '/data/test/su23.jpg', 160 | '/data/test/su24.jpeg', 161 | '/data/test/su25.jpeg', 162 | '/data/test/su26.jpeg', 163 | '/data/test/mt20.jpg', 164 | '/data/test/mt21.jpg', 165 | '/data/test/mt22.jpg', 166 | '/data/test/mt23.jpg', 167 | '/data/test/kll20.jpg', 168 | '/data/test/kll21.jpg', 169 | '/data/test/kll22.jpg', 170 | '/data/test/kll23.jpg', 171 | '/data/test/kll24.jpg', 172 | ] 173 | 174 | for img_path in test_imgs: 175 | #img = image.load_img(img_path, target_size=(img_height, img_width)) 176 | img =cv2.imread(img_path) 177 | # plt.imshow(img) 178 | # plt.show() 179 | img=MyTool.cro_img(img,img_height,img_width) 180 | plt.imshow(img) 181 | plt.show() 182 | 183 | x = image.img_to_array(img) 184 | x=x/255 185 | x = np.expand_dims(x, axis=0) 186 | preds = model.predict(x) 187 | paixu=dict(zip(train_generator.class_indices,preds[0])) 188 | paixu= sorted(paixu.items(), key=lambda x: x[1], reverse=True) 189 | print(paixu) -------------------------------------------------------------------------------- /check_image.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | if __name__ == "__main__": 4 | base_dir='/data/keras/download/qiche/train' 5 | #Image.open("./00b2602fc9aca88659ecf11f6f42ae1a.jpg") 6 | i=0 7 | for root, dirs, files in os.walk(base_dir, topdown=False): 8 | for name in files: 9 | #print(os.path.join(root, name)) 10 | file_path=os.path.join(root, name) 11 | try: 12 | with open(file_path, 'rb') as f: 13 | img_PIL = Image.open(f) 14 | except Exception as e: 15 | print(str(e)) 16 | i=i+1 17 | os.remove(file_path) 18 | for name in dirs: 19 | #print(os.path.join(root, name)) 20 | pass 21 | print(i) -------------------------------------------------------------------------------- /tools/mytool.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | class MyTool(object): 3 | @staticmethod 4 | def cro_img(img,t_height=299,t_width=299): 5 | a = img.shape 6 | height=a[0] 7 | width=a[1] 8 | if height/width>t_height/t_width: 9 | height2=t_height/t_width*width 10 | height2=int(height2) 11 | n=int((height-height2)/2.0) 12 | cropImg=img[n:(height2+n), 0:width] 13 | else: 14 | width2=t_width*1.0/t_height*height 15 | width2=int(width2) 16 | n=int((width-width2)/2.0) 17 | cropImg=img[ 0:height, n:(width2+n)] 18 | cropImg=cv2.resize(cropImg,(t_width,t_height),interpolation=cv2.INTER_AREA ) 19 | return cropImg -------------------------------------------------------------------------------- /测试结果.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blueapplehe/car_identify/5a9f4a0838ba6c0cc4420e7fa9bed22e65987f94/测试结果.jpg --------------------------------------------------------------------------------