├── README.md ├── caffe_model ├── caffenet_deploy_1.prototxt ├── caffenet_train_val_1.prototxt └── solver_1.prototxt ├── caffe_model_test ├── __init__.py ├── predict_base.py ├── predict_from_db.py └── predict_from_local.py ├── caffe_train ├── __init__.py └── caffe_train.py ├── run_images ├── caffe_model_dir.jpg ├── caffe_traintxt_1.jpg ├── caffe_traintxt_2.jpg ├── caffe_traintxt_3.jpg ├── caffe_traintxt_4.jpg ├── caffe_triaintxt_5.jpg ├── code_framework_1.jpg ├── create_lmdb_2.jpg ├── deploy_txt.jpg ├── img_url_2 (2).jpg ├── img_url_2.jpg ├── input_data_1.jpg ├── label_dir_1.jpg ├── label_dir_2.jpg ├── lmdb_img_1.jpg ├── mean_1.png ├── mean_2.png ├── predict_base_1.jpg ├── predict_from_db_1.jpg ├── predict_from_db_2.jpg ├── predict_from_local_1.jpg ├── sql_ret.png └── train_process_!.png ├── train_data_generate ├── __init__.py ├── create_lmdb.py └── create_mean_binaryproto.py └── utils ├── DbBse.py ├── __init__.py └── img_process.py /README.md: -------------------------------------------------------------------------------- 1 | 2 | 本文主要是使用caffe python做图片识别的示例包括训练数据lmdb生成,训练,以及模型测试,主要内容如下: 3 | 4 | 5 | ---------- 6 | 7 | 8 | 1. 训练,验证数据lmdb生成,主要包括:样本的预处理 (直方图均衡化,resize),训练样本以及验证样本的lmdb的生成,以及mean_file mean.binaryproto生成 9 | 10 | 2. caffe中模型的定义,主要是修改 [caffe Alexnet 训练文件train_val.prototxt](https://github.com/BVLC/caffe/blob/master/models/bvlc_alexnet/train_val.prototxt) ,以及[训练参数文件solver.prototxt ](https://github.com/BVLC/caffe/blob/master/models/bvlc_alexnet/solver.prototxt),还有[部署文件deploy.prototxt](https://github.com/BVLC/caffe/blob/master/models/bvlc_alexnet/deploy.prototxt) 11 | 12 | 3. 训练验证数据准备完成之后,就是模型的训练 13 | 14 | 4. 
得到训练模型之后,一般会进行本地测试以及从数据库获取url测试然后将结果写到数据库中 15 | 16 | 17 | ---------- 18 | 先上个[代码](https://github.com/Jayhello/python_caffe_train_test "github代码地址")的框架图,说明见图片(下面会有详细的讲解): 19 | 20 | ![这里写图片描述](http://img.blog.csdn.net/20170909195539501?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvaGFsdW9sdW8yMTE=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast) 21 | 22 | 下面给出最终的识别结果: 23 | 24 | 25 | ---------- 26 | 27 | 28 | ![这里写图片描述](http://img.blog.csdn.net/20170909200020022?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvaGFsdW9sdW8yMTE=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast) 29 | 30 | 31 | ### 1. 训练,验证数据lmdb生成 32 | 33 | 1. 对图片进行预处理包括直方图均衡化(Histogram equalization)以及resize到指定的大小,并生成lmdb格式,图片以及对于的标签(label) 34 | 35 | 2. 按照一定的比例生成,训练样本lmdb以及验证样本lmdb,以及mean_file mean.binaryproto 36 | 37 | 3. 在测试的时候,我们往往是从数据库中读取url以及id信息,然后将url转化为cv2 可以处理的图片样式,因此我们还要实现将url转化cv2可以处理的图片 38 | 39 | 40 | ### 1.1 图片进行预处理包括直方图均衡化,url->cv2 image 格式 41 | 下面通过代码来讲解(文件: utils->img_process.py): 42 | 43 | ```python 44 | # _*_coding:utf-8 _*_ 45 | 46 | import cv2 47 | import urllib 48 | import numpy as np 49 | 50 | IMG_HEIGHT = 227 51 | IMG_WIDTH = 227 52 | 53 | # 对图片做直方图均衡化处理 54 | def pre_process_img(img, img_height=IMG_HEIGHT, img_width=IMG_WIDTH): 55 | # firstly histogram equalization 56 | img[:, :, 0] = cv2.equalizeHist(img[:, :, 0]) 57 | img[:, :, 1] = cv2.equalizeHist(img[:, :, 1]) 58 | img[:, :, 2] = cv2.equalizeHist(img[:, :, 2]) 59 | 60 | # resize image to size 61 | img = cv2.resize(img, (img_width, img_height), interpolation=cv2.INTER_CUBIC) 62 | 63 | return img 64 | 65 | # 通过图片url将其转化为cv2可以处理的形式 66 | def get_cv_img__from_url(url): 67 | """ 68 | read image from url to cv codec 69 | :param url: 70 | :return: 71 | """ 72 | try: 73 | url_response = urllib.urlopen(url) 74 | img_array = np.array(bytearray(url_response.read()), dtype=np.uint8) 75 | img = cv2.imdecode(img_array, -1) 76 | return img 77 | except Exception, e: 78 | print e 79 | return None 
80 | 81 | 82 | if __name__ == '__main__': 83 | url = 'http://www.sanyarb.com.cn/images/attachement/jpg/site2/20161009/A121475977636942_change_ljx6a9_b.jpg' 84 | img = get_cv_img__from_url(url) 85 | cv2.imshow("zhan lang", img) 86 | 87 | img = pre_process_img(img) 88 | cv2.imshow("pre_process_img", img) 89 | cv2.waitKey() 90 | pass 91 | 92 | 93 | ``` 94 | 95 | 下面是下载网上的图片,然后对其进行直方图均衡化以及resize的运行的结果: 96 | 97 | ![这里写图片描述](http://img.blog.csdn.net/20170909200955191?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvaGFsdW9sdW8yMTE=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast) 98 | 99 | 100 | ---------- 101 | 102 | 103 | ### 1.2 图片按照一定的比例生成训练样本以及验证样本lmdb] 104 | 105 | ```python 106 | # _*_coding:utf-8 _*_ 107 | 108 | import sys 109 | sys.path.insert(0, '../../caffe_train_test/') 110 | import os 111 | import glob 112 | import random 113 | import numpy as np 114 | 115 | import cv2 116 | 117 | import caffe 118 | from caffe.proto import caffe_pb2 119 | import lmdb 120 | 121 | from utils.img_process import * 122 | 123 | # 根据图片和标签转化为对应的lmdb格式 124 | def make_datum(img, label): 125 | # image is numpy.ndarray format. BGR instead of RGB 126 | return caffe_pb2.Datum( 127 | channels=3, 128 | width=IMG_HEIGHT, 129 | height=IMG_WIDTH, 130 | label=label, 131 | data=np.rollaxis(img, 2).tostring()) 132 | 133 | 134 | # 创建lmdb的基类 135 | class GenerateLmdb(object): 136 | 137 | def __init__(self, img_path): 138 | """ 139 | img_path -> multiple calss directory 140 | like, class_1, class_2, class_3.... 
141 | each class has corresponding class image like class_1_1.png 142 | :param img_path: 143 | """ 144 | # get all the images in different class directory 145 | # 获取到多有的图片列表 146 | self.img_lst = glob.glob(os.path.join(img_path, '*', '*.png')) 147 | print 'input_img list num is %s' % len(self.img_lst) 148 | # shuffle all the images 149 | # 需要对列表乱序 150 | random.shuffle(self.img_lst) 151 | 152 | # 根据标签,比例生成训练lmdb以及验证lmdb 153 | def generate_lmdb(self, label_lst, percentage, train_path, validation_path): 154 | """ 155 | label_lst like ['class_1', 'class_2', 'class_3', .....] 156 | percentage like is 5 (4/5) then 80% be train image, (1/5) 20% be validation image 157 | train_path like that '/data/train/train_lmdb' 158 | validation_path like '/data/train/validation_lmdb' 159 | """ 160 | print 'now generate train lmdb' 161 | self._generate_lmdb(label_lst, percentage, True, train_path) 162 | print 'now generate validation lmdb' 163 | self._generate_lmdb(label_lst, percentage, False, validation_path) 164 | 165 | print '\n generate all images' 166 | 167 | def _generate_lmdb(self, label_lst, percentage, b_train, input_path): 168 | """ 169 | b_train is True means to generate train lmdb, or validation lmdb 170 | """ 171 | output_db = lmdb.open(input_path, map_size=int(1e12)) 172 | with output_db.begin(write=True) as in_txn: 173 | for idx, img_path in enumerate(self.img_lst): 174 | 175 | # create train data 176 | if b_train: 177 | # !=0 means validation data then skip loop 178 | if idx % percentage != 0: 179 | continue 180 | # create validation data 181 | else: 182 | # ==0 means train data then skip 183 | if idx % percentage == 0: 184 | continue 185 | 186 | img = cv2.imread(img_path, cv2.IMREAD_COLOR) 187 | img = pre_process_img(img) 188 | # path like that '../../class_1/0001.png' 189 | # so img_path.split('/')[-2] -> class_1 190 | label = label_lst.index(img_path.split('/')[-2]) 191 | datum = make_datum(img, label) 192 | in_txn.put('{:0>5d}'.format(idx), 
datum.SerializeToString()) 193 | print '{:0>5d}'.format(idx) + '->label: ', label, " " + img_path 194 | 195 | output_db.close() 196 | 197 | 198 | def get_label_lst_by_dir(f_dir): 199 | """ 200 | f_dir like 'home/user/class', sub dir 'class_1', 'class_2'...'class_n' 201 | :return: ['class_1', 'class_2'...'class_n'] 202 | """ 203 | return os.listdir(f_dir) 204 | 205 | if __name__ == '__main__': 206 | img_path = '../../ad_train/' 207 | cl = GenerateLmdb(img_path) 208 | 209 | train_lmdb = '/data6/light/storm_1_1/images/ad_train_py/input_data/train_lmdb' 210 | validation_lmdb = '/data6/light/storm_1_1/images/ad_train_py/input_data/validation_lmdb' 211 | 212 | # 删除原有的lmdb文件 213 | os.system('rm -rf ' + train_lmdb) 214 | os.system('rm -rf ' + validation_lmdb) 215 | 216 | input_path = '/data6/light/storm_1_1/images/ad_train/' 217 | label_lst = get_label_lst_by_dir(input_path) 218 | print 'label_lst is: %s' % ', '.join(label_lst) 219 | 220 | # (1/10)10% to be validation data, 90% to be train data 221 | # 1/10的文件为验证lmdb, 9/10为训练lmdb 222 | percentage = 10 223 | 224 | cl.generate_lmdb(label_lst, percentage, train_lmdb, validation_lmdb) 225 | 226 | pass 227 | 228 | 229 | ``` 230 | 231 | 下面是实践的运行截图(这个代码好早前就运行了,这次写bolg做了一些处理)下面是一个三分类的目录(前面做过十几中的分类,这里写bolg,做了简化) 232 | 类别标签是: ad_text(文字广告), ad_web(网页广告),others(其他类) 233 | 234 | 类别目录如下: 235 | 236 | ![分类目录](http://img.blog.csdn.net/20170909202101550?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvaGFsdW9sdW8yMTE=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast) 237 | 238 | 下面是输出的label列表: 239 | 240 | ![这里写图片描述](http://img.blog.csdn.net/20170909202034630?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvaGFsdW9sdW8yMTE=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast) 241 | 242 | 下面是运行 `python create_lmdb.py` 的部分日志结果(为了简便做了很多处理) 243 | 244 | 
![这里写图片描述](http://img.blog.csdn.net/20170909202526070?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvaGFsdW9sdW8yMTE=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast) 245 | 246 | 下面是最终生成的lmdb文件: 247 | 248 | ![这里写图片描述](http://img.blog.csdn.net/20170909202706259?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvaGFsdW9sdW8yMTE=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast) 249 | 250 | 到此我们生成了,caffe训练需要的lmdb文件 251 | 252 | 253 | ### 1.3 mean_file mean.binaryproto 254 | 255 | ```python 256 | 257 | # _*_ coding:utf-8 258 | 259 | import os 260 | 261 | # 生成,生成mean_binaryproto文件的字符串命令 262 | def get_mean_cmd(mean_tool_path, train_lmdb_path, mean_binaryproto_path): 263 | # create train command 264 | return '%s -backend=lmdb %s %s ' % (mean_tool_path, train_lmdb_path, mean_binaryproto_path) 265 | 266 | 267 | if __name__ == '__main__': 268 | # caffe mean 工具的路径 269 | mean_tool_path = '/home/ubuntu/caffe/build/tools/compute_image_mean' 270 | train_lmdb_path = '/home/xiongyu/input/train_lmdb' 271 | mean_binaryproto_path = '/home/xiongyu/input/mean.binaryproto' 272 | 273 | cmd = get_mean_cmd(mean_tool_path, train_lmdb_path, mean_binaryproto_path) 274 | print cmd 275 | 276 | # 执行生成命令 277 | os.system(cmd) 278 | 279 | ``` 280 | 281 | cmd合成的字符串 282 | 283 | ![这里写图片描述](http://img.blog.csdn.net/20170909205300908?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvaGFsdW9sdW8yMTE=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast) 284 | 285 | 实际生成的结果 286 | 287 | ![这里写图片描述](http://img.blog.csdn.net/20170909205312198?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvaGFsdW9sdW8yMTE=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast) 288 | 289 | ---------- 290 | 291 | 292 | ### 2. 
caffe中模型的配置文件的定义以及说明 293 | 294 | 295 | ---------- 296 | 297 | 298 | ### 2.1 训练模型定义 299 | 300 | caffe中模型的定义,主要是修改 [caffe Alexnet 训练文件train_val.prototxt](https://github.com/BVLC/caffe/blob/master/models/bvlc_alexnet/train_val.prototxt) 301 | 。主要修改mean_file mean.binaryproto,source train lmdb 路径, 302 | 303 | 304 | ![这里写图片描述](http://img.blog.csdn.net/20170909203645666?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvaGFsdW9sdW8yMTE=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast) 305 | 306 | 307 | ![这里写图片描述](http://img.blog.csdn.net/20170909203657259?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvaGFsdW9sdW8yMTE=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast) 308 | 309 | 310 | ![这里写图片描述](http://img.blog.csdn.net/20170909203707971?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvaGFsdW9sdW8yMTE=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast) 311 | 312 | 313 | ![这里写图片描述](http://img.blog.csdn.net/20170909203719180?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvaGFsdW9sdW8yMTE=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast) 314 | 315 | 由于这个示例主要讲的是3分类,因此还要修改num_output为3(记得修改对应的 **`部署文件`**) 316 | 317 | ![这里写图片描述](http://img.blog.csdn.net/20170909203731319?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvaGFsdW9sdW8yMTE=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast) 318 | 319 | 320 | ### 2.2 部署文件 321 | 322 | [部署文件deploy.prototxt](https://github.com/BVLC/caffe/blob/master/models/bvlc_alexnet/deploy.prototxt) 记得修改对应的num_output为3和训练文件一致 323 | 324 | 325 | ![这里写图片描述](http://img.blog.csdn.net/20170909205722029?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvaGFsdW9sdW8yMTE=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast) 326 | 327 | ### 2.3 训练运行参数文件 328 | 329 | [训练运行参数文件solver.prototxt ](https://github.com/BVLC/caffe/blob/master/models/bvlc_alexnet/solver.prototxt) 330 | 331 | ``` 332 | net: 
"/data6/light/storm_1_1/images/ad_train_py/caffe_model/caffenet_train_val_1.prototxt" 333 | test_iter: 1000 334 | # 每1000次做一次验证 335 | test_interval: 1000 336 | base_lr: 0.001 337 | lr_policy: "step" 338 | gamma: 0.1 339 | stepsize: 2500 340 | display: 50 341 | # 最大迭代次数 342 | max_iter: 30000 343 | momentum: 0.9 344 | # 权重衰减因子 345 | weight_decay: 0.0005 346 | # 每训练6000次生成一次模型快照 347 | snapshot: 5000 348 | # 模型快照前缀 349 | snapshot_prefix: "/data6/light/storm_1_1/images/ad_train_py/caffe_model/caffe_model_1" 350 | # GPU模式 351 | solver_mode: GPU 352 | 353 | ``` 354 | 355 | 下面看下最终生成的模型文件(文件太大删除了很多,只保留一个运行时的) 356 | 357 | ![这里写图片描述](http://img.blog.csdn.net/20170909210408514?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvaGFsdW9sdW8yMTE=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast) 358 | 359 | 360 | ---------- 361 | 362 | 363 | ### 3. 训练验证数据准备完成之后,就是模型的训练 364 | 365 | 代码类似与mean 文件的生成,这里就不解释了 366 | 367 | > command |& tee out.log 368 | , 将结果输出到标准输出流以及out.log文件中 369 | 370 | 371 | 372 | ```python 373 | # _*_ coding:utf-8 374 | 375 | import os 376 | 377 | 378 | def get_train_cmd(caffe_path, solver_path, log_path): 379 | # create train command 380 | return '%s train --solver %s |& tee %s ' % (caffe_path, solver_path, log_path) 381 | 382 | 383 | if __name__ == '__main__': 384 | 385 | caffe_path = "/home/xiongyu/caffe/build/tools/caffe" 386 | solver_path = "/home/xiongyu/caffe_models/caffe_model_1/solver_1.prototxt" 387 | log_path = "/home/xiongyu/caffe_models/caffe_model_1/model_1_train.log" 388 | 389 | train = get_train_cmd(caffe_path, solver_path, log_path) 390 | 391 | print train 392 | # use caffe to train model 393 | os.system(train) 394 | 395 | pass 396 | 397 | ``` 398 | 399 | 下面是训练时的部分截图: 400 | 401 | ![这里写图片描述](http://img.blog.csdn.net/20170909211336774?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvaGFsdW9sdW8yMTE=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast) 402 | 403 | 404 | ---------- 405 | 406 | 407 | ### 4. 
本地测试以及从数据库获取url测试然后将结果写到数据库中 408 | 409 | 410 | ---------- 411 | 412 | 413 | ### 4.1 测试基类文件predict_base.py 414 | 415 | 为了保证代码的模块性,测试的便捷性,这个基类提供给测试本地文件以及数据库文件调用 416 | 417 | ```python 418 | # _*_coding:utf-8 _*_ 419 | 420 | import sys 421 | sys.path.insert(0, '../../caffe_train_test/') 422 | import os 423 | import glob 424 | import cv2 425 | import caffe 426 | import lmdb 427 | import numpy as np 428 | from caffe.proto import caffe_pb2 429 | 430 | from utils.img_process import * 431 | 432 | 433 | class CaffePredict(object): 434 | 435 | def __init__(self, b_gpu, mean_path, deploy_path, model_path): 436 | # cpu或者是gpu模式 437 | if b_gpu: 438 | caffe.set_mode_gpu() 439 | else: 440 | caffe.set_mode_cpu() 441 | 442 | mean_blob = caffe_pb2.BlobProto() 443 | with open(mean_path) as f: 444 | mean_blob.ParseFromString(f.read()) 445 | 446 | mean_array = np.asarray(mean_blob.data, dtype=np.float32).\ 447 | reshape((mean_blob.channels, mean_blob.height, mean_blob.width)) 448 | 449 | self.net = caffe.Net(deploy_path, model_path, caffe.TEST) 450 | 451 | # Define image transformers 452 | self.transformer = caffe.io.Transformer({'data': self.net.blobs['data'].data.shape}) 453 | self.transformer.set_mean('data', mean_array) 454 | # puts the channel as the first dimention 455 | self.transformer.set_transpose('data', (2, 0, 1)) 456 | 457 | # predict只需要输入cv2 image格式图片即可 458 | def predict(self, img): 459 | img = pre_process_img(img) 460 | self.net.blobs['data'].data[...] 
= self.transformer.preprocess('data', img) 461 | out = self.net.forward() 462 | pred_probas = out['prob'] 463 | 464 | # predict result 465 | ret_lst = [round(f, 4) for f in pred_probas[0].tolist()] 466 | return ret_lst 467 | 468 | # 获取默认的caffe模型 469 | def get_default_caffe_predict(): 470 | # Read model architecture and trained model's weights 471 | mean_path = "/data6/light/storm_1_1/images/ad_train_py/input_data/mean.binaryproto" 472 | deploy_path = "/data6/light/storm_1_1/images/ad_train_py/caffe_model/caffenet_deploy_1.prototxt" 473 | model_path = "/data6/light/storm_1_1/images/ad_train_py/caffe_model/caffe_model_1_iter_10000.caffemodel" 474 | b_gpu = True 475 | caffe_predict = CaffePredict(b_gpu, mean_path, deploy_path, model_path) 476 | return caffe_predict 477 | 478 | 479 | if __name__ == '__main__': 480 | # 使用默认的模型识别 481 | caffe_predict = get_default_caffe_predict() 482 | 483 | img_path = '/data6/light/storm_1_1/images/ad_train_py/test_data/0.png' 484 | img = cv2.imread(img_path, cv2.IMREAD_COLOR) 485 | print caffe_predict.predict(img) 486 | 487 | pass 488 | 489 | ``` 490 | 491 | 识别一张图片,运行结果如下: 492 | 493 | ![这里写图片描述](http://img.blog.csdn.net/20170909212031010?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvaGFsdW9sdW8yMTE=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast) 494 | 495 | ### 4.2 测试本地目录所有图片文件 496 | 497 | `predict_from_local.py` 读取目录下的所有文件,并输出识别结果 498 | 499 | ```python 500 | import sys 501 | sys.path.insert(0, '../../caffe_train_test/') 502 | from predict_base import CaffePredict, get_default_caffe_predict 503 | import glob 504 | import cv2 505 | 506 | 507 | def get_img_lst(img_dir): 508 | """ 509 | img_dir: /data6/light/storm_1_1/images/ad_train_py/test_data/ 510 | lots of images like '0.jpg, 1.jpg ......' 
511 | """ 512 | return glob.glob(img_dir + "*.png") 513 | 514 | 515 | def predict_all(): 516 | path = '/data6/light/storm_1_1/images/ad_train_py/test_data/' 517 | img_lst = get_img_lst(path) 518 | caffe_predict = get_default_caffe_predict() 519 | 520 | for path in img_lst: 521 | try: 522 | img = cv2.imread(path, cv2.IMREAD_COLOR) 523 | # caffe_predict.predict is not thread safe,so can't be used in multiple thread 524 | # python is dummy multiple threads 525 | ret_lst = caffe_predict.predict(img) 526 | print path, ret_lst 527 | except Exception, e: 528 | print e 529 | 530 | 531 | if __name__ == '__main__': 532 | predict_all() 533 | pass 534 | 535 | ``` 536 | 537 | 运行结果如下: 538 | 539 | ![这里写图片描述](http://img.blog.csdn.net/20170909212335843?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvaGFsdW9sdW8yMTE=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast) 540 | 541 | 542 | ### 4.3 测试数据库所有图片文件 543 | 544 | 当然在实际的运行中我们往往测试几十万张图片,一般上传到服务器也很麻烦(图片要下载下来,然后打包在sz到linux目录,这样很麻烦而且,打包文件太大的话上传到服务器往往报错)。所以我们一般在数据库上面读取url然后识别,在把识别的结果写回到数据库,例如这样: 545 | 546 | ![这里写图片描述](http://img.blog.csdn.net/20170909213000172?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvaGFsdW9sdW8yMTE=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast) 547 | 548 | ```python 549 | # _*_ coding:utf-8 _*_ 550 | 551 | import sys 552 | sys.path.insert(0, '../../caffe_train_test/') 553 | from utils.DbBse import DbService, get_default_db 554 | from utils.img_process import get_cv_img__from_url 555 | from predict_base import CaffePredict, get_default_caffe_predict 556 | 557 | 558 | def predict_from_db(): 559 | """ 560 | get all the url and id from database and 561 | then predict, write predict result to database 562 | :return: 563 | """ 564 | db = get_default_db() 565 | 566 | # [(1, 'http://xxx.1.jpg'), (2, 'http://xxx.2.jpg).....] 
567 | url_id_lst = db.get_ad_info() 568 | 569 | print 'url_id_lst length is %s: ' % len(url_id_lst) 570 | print 'url_id_lst first is', url_id_lst[0] 571 | 572 | caffe_predict = get_default_caffe_predict() 573 | 574 | for item in url_id_lst: 575 | img = get_cv_img__from_url(item[1]) 576 | if img is None: 577 | continue 578 | 579 | ret_lst = caffe_predict.predict(img) 580 | # item[0] is id 581 | ret_lst.append(item[0]) 582 | # write result to database 583 | print item[1], ret_lst 584 | db.update_ad_info(ret_lst) 585 | 586 | 587 | if __name__ == '__main__': 588 | predict_from_db() 589 | pass 590 | 591 | ``` 592 | 593 | 下面是运行结果: 594 | 595 | ![这里写图片描述](http://img.blog.csdn.net/20170909213045877?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvaGFsdW9sdW8yMTE=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast) 596 | 597 | ![这里写图片描述](http://img.blog.csdn.net/20170909213105846?watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvaGFsdW9sdW8yMTE=/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/SouthEast) 598 | 599 | 600 | > http://adilmoujahid.com/posts/2016/06/introduction-deep-learning-python-caffe/ 601 | > 602 | > https://software.intel.com/en-us/articles/training-and-deploying-deep-learning-networks-with-caffe-optimized-for-intel-architecture 603 | 604 | 605 | # English README project framework 606 | 1. caffe_model:the training and deploy prototxt files 607 | 608 | 2. train_data_generate:generate training lmdb, validation lmdb, and mean_binaryproto 609 | 610 | 3. caffe_train: training caffe model 611 | 612 | 4. caffe_model_test:test model recognition results, both local files, and files from database 613 | 614 | 5. 
utils: image process function, url (of image) to cv2 format, database process
![](https://github.com/Jayhello/python_caffe_train_test/blob/master/run_images/predict_base_1.jpg) 668 | 669 | The below demo predict images from local directory 670 | 671 | ![](https://github.com/Jayhello/python_caffe_train_test/blob/master/run_images/predict_from_local_1.jpg) 672 | 673 | 674 | The below demo predict images from database and write recognition results to database 675 | 676 | ![](https://github.com/Jayhello/python_caffe_train_test/blob/master/run_images/predict_from_db_1.jpg) 677 | 678 | ![](https://github.com/Jayhello/python_caffe_train_test/blob/master/run_images/predict_from_db_2.jpg) 679 | 680 | 681 | ----------------- 682 | 683 | You can see all the explantion in this [bolg](http://blog.csdn.net/haluoluo211/article/details/77918156) 684 | 685 | 686 | Cited some contents from the below two articles. 687 | 688 | >http://adilmoujahid.com/posts/2016/06/introduction-deep-learning-python-caffe/ 689 | 690 | >https://software.intel.com/en-us/articles/training-and-deploying-deep-learning-networks-with-caffe-optimized-for-intel-architecture 691 | 692 | -------------------------------------------------------------------------------- /caffe_model/caffenet_deploy_1.prototxt: -------------------------------------------------------------------------------- 1 | name: "CaffeNet" 2 | layer { 3 | name: "data" 4 | type: "Input" 5 | top: "data" 6 | input_param { shape: { dim: 1 dim: 3 dim: 227 dim: 227 } } 7 | } 8 | layer { 9 | name: "conv1" 10 | type: "Convolution" 11 | bottom: "data" 12 | top: "conv1" 13 | convolution_param { 14 | num_output: 96 15 | kernel_size: 11 16 | stride: 4 17 | } 18 | } 19 | layer { 20 | name: "relu1" 21 | type: "ReLU" 22 | bottom: "conv1" 23 | top: "conv1" 24 | } 25 | layer { 26 | name: "pool1" 27 | type: "Pooling" 28 | bottom: "conv1" 29 | top: "pool1" 30 | pooling_param { 31 | pool: MAX 32 | kernel_size: 3 33 | stride: 2 34 | } 35 | } 36 | layer { 37 | name: "norm1" 38 | type: "LRN" 39 | bottom: "pool1" 40 | top: "norm1" 41 | lrn_param 
{ 42 | local_size: 5 43 | alpha: 0.0001 44 | beta: 0.75 45 | } 46 | } 47 | layer { 48 | name: "conv2" 49 | type: "Convolution" 50 | bottom: "norm1" 51 | top: "conv2" 52 | convolution_param { 53 | num_output: 256 54 | pad: 2 55 | kernel_size: 5 56 | group: 2 57 | } 58 | } 59 | layer { 60 | name: "relu2" 61 | type: "ReLU" 62 | bottom: "conv2" 63 | top: "conv2" 64 | } 65 | layer { 66 | name: "pool2" 67 | type: "Pooling" 68 | bottom: "conv2" 69 | top: "pool2" 70 | pooling_param { 71 | pool: MAX 72 | kernel_size: 3 73 | stride: 2 74 | } 75 | } 76 | layer { 77 | name: "norm2" 78 | type: "LRN" 79 | bottom: "pool2" 80 | top: "norm2" 81 | lrn_param { 82 | local_size: 5 83 | alpha: 0.0001 84 | beta: 0.75 85 | } 86 | } 87 | layer { 88 | name: "conv3" 89 | type: "Convolution" 90 | bottom: "norm2" 91 | top: "conv3" 92 | convolution_param { 93 | num_output: 384 94 | pad: 1 95 | kernel_size: 3 96 | } 97 | } 98 | layer { 99 | name: "relu3" 100 | type: "ReLU" 101 | bottom: "conv3" 102 | top: "conv3" 103 | } 104 | layer { 105 | name: "conv4" 106 | type: "Convolution" 107 | bottom: "conv3" 108 | top: "conv4" 109 | convolution_param { 110 | num_output: 384 111 | pad: 1 112 | kernel_size: 3 113 | group: 2 114 | } 115 | } 116 | layer { 117 | name: "relu4" 118 | type: "ReLU" 119 | bottom: "conv4" 120 | top: "conv4" 121 | } 122 | layer { 123 | name: "conv5" 124 | type: "Convolution" 125 | bottom: "conv4" 126 | top: "conv5" 127 | convolution_param { 128 | num_output: 256 129 | pad: 1 130 | kernel_size: 3 131 | group: 2 132 | } 133 | } 134 | layer { 135 | name: "relu5" 136 | type: "ReLU" 137 | bottom: "conv5" 138 | top: "conv5" 139 | } 140 | layer { 141 | name: "pool5" 142 | type: "Pooling" 143 | bottom: "conv5" 144 | top: "pool5" 145 | pooling_param { 146 | pool: MAX 147 | kernel_size: 3 148 | stride: 2 149 | } 150 | } 151 | layer { 152 | name: "fc6" 153 | type: "InnerProduct" 154 | bottom: "pool5" 155 | top: "fc6" 156 | inner_product_param { 157 | num_output: 4096 158 | } 159 | } 160 | 
layer { 161 | name: "relu6" 162 | type: "ReLU" 163 | bottom: "fc6" 164 | top: "fc6" 165 | } 166 | layer { 167 | name: "drop6" 168 | type: "Dropout" 169 | bottom: "fc6" 170 | top: "fc6" 171 | dropout_param { 172 | dropout_ratio: 0.5 173 | } 174 | } 175 | layer { 176 | name: "fc7" 177 | type: "InnerProduct" 178 | bottom: "fc6" 179 | top: "fc7" 180 | inner_product_param { 181 | num_output: 4096 182 | } 183 | } 184 | layer { 185 | name: "relu7" 186 | type: "ReLU" 187 | bottom: "fc7" 188 | top: "fc7" 189 | } 190 | layer { 191 | name: "drop7" 192 | type: "Dropout" 193 | bottom: "fc7" 194 | top: "fc7" 195 | dropout_param { 196 | dropout_ratio: 0.5 197 | } 198 | } 199 | layer { 200 | name: "fc8" 201 | type: "InnerProduct" 202 | bottom: "fc7" 203 | top: "fc8" 204 | inner_product_param { 205 | num_output: 3 206 | } 207 | } 208 | layer { 209 | name: "prob" 210 | type: "Softmax" 211 | bottom: "fc8" 212 | top: "prob" 213 | } -------------------------------------------------------------------------------- /caffe_model/caffenet_train_val_1.prototxt: -------------------------------------------------------------------------------- 1 | name: "CaffeNet" 2 | layer { 3 | name: "data" 4 | type: "Data" 5 | top: "data" 6 | top: "label" 7 | include { 8 | phase: TRAIN 9 | } 10 | transform_param { 11 | mirror: true 12 | crop_size: 227 13 | mean_file: "/data6/light/storm_1_1/images/ad_train_py/input_data/mean.binaryproto" 14 | } 15 | # mean pixel / channel-wise mean instead of mean image 16 | # transform_param { 17 | # crop_size: 227 18 | # mean_value: 104 19 | # mean_value: 117 20 | # mean_value: 123 21 | # mirror: true 22 | # } 23 | data_param { 24 | source: "/data6/light/storm_1_1/images/ad_train_py/input_data/train_lmdb" 25 | batch_size: 256 26 | backend: LMDB 27 | } 28 | } 29 | layer { 30 | name: "data" 31 | type: "Data" 32 | top: "data" 33 | top: "label" 34 | include { 35 | phase: TEST 36 | } 37 | transform_param { 38 | mirror: false 39 | crop_size: 227 40 | mean_file: 
"/data6/light/storm_1_1/images/ad_train_py/input_data/mean.binaryproto" 41 | } 42 | # mean pixel / channel-wise mean instead of mean image 43 | # transform_param { 44 | # crop_size: 227 45 | # mean_value: 104 46 | # mean_value: 117 47 | # mean_value: 123 48 | # mirror: true 49 | # } 50 | data_param { 51 | source: "/data6/light/storm_1_1/images/ad_train_py/input_data/validation_lmdb" 52 | batch_size: 50 53 | backend: LMDB 54 | } 55 | } 56 | layer { 57 | name: "conv1" 58 | type: "Convolution" 59 | bottom: "data" 60 | top: "conv1" 61 | param { 62 | lr_mult: 1 63 | decay_mult: 1 64 | } 65 | param { 66 | lr_mult: 2 67 | decay_mult: 0 68 | } 69 | convolution_param { 70 | num_output: 96 71 | kernel_size: 11 72 | stride: 4 73 | weight_filler { 74 | type: "gaussian" 75 | std: 0.01 76 | } 77 | bias_filler { 78 | type: "constant" 79 | value: 0 80 | } 81 | } 82 | } 83 | layer { 84 | name: "relu1" 85 | type: "ReLU" 86 | bottom: "conv1" 87 | top: "conv1" 88 | } 89 | layer { 90 | name: "pool1" 91 | type: "Pooling" 92 | bottom: "conv1" 93 | top: "pool1" 94 | pooling_param { 95 | pool: MAX 96 | kernel_size: 3 97 | stride: 2 98 | } 99 | } 100 | layer { 101 | name: "norm1" 102 | type: "LRN" 103 | bottom: "pool1" 104 | top: "norm1" 105 | lrn_param { 106 | local_size: 5 107 | alpha: 0.0001 108 | beta: 0.75 109 | } 110 | } 111 | layer { 112 | name: "conv2" 113 | type: "Convolution" 114 | bottom: "norm1" 115 | top: "conv2" 116 | param { 117 | lr_mult: 1 118 | decay_mult: 1 119 | } 120 | param { 121 | lr_mult: 2 122 | decay_mult: 0 123 | } 124 | convolution_param { 125 | num_output: 256 126 | pad: 2 127 | kernel_size: 5 128 | group: 2 129 | weight_filler { 130 | type: "gaussian" 131 | std: 0.01 132 | } 133 | bias_filler { 134 | type: "constant" 135 | value: 1 136 | } 137 | } 138 | } 139 | layer { 140 | name: "relu2" 141 | type: "ReLU" 142 | bottom: "conv2" 143 | top: "conv2" 144 | } 145 | layer { 146 | name: "pool2" 147 | type: "Pooling" 148 | bottom: "conv2" 149 | top: "pool2" 150 | 
pooling_param { 151 | pool: MAX 152 | kernel_size: 3 153 | stride: 2 154 | } 155 | } 156 | layer { 157 | name: "norm2" 158 | type: "LRN" 159 | bottom: "pool2" 160 | top: "norm2" 161 | lrn_param { 162 | local_size: 5 163 | alpha: 0.0001 164 | beta: 0.75 165 | } 166 | } 167 | layer { 168 | name: "conv3" 169 | type: "Convolution" 170 | bottom: "norm2" 171 | top: "conv3" 172 | param { 173 | lr_mult: 1 174 | decay_mult: 1 175 | } 176 | param { 177 | lr_mult: 2 178 | decay_mult: 0 179 | } 180 | convolution_param { 181 | num_output: 384 182 | pad: 1 183 | kernel_size: 3 184 | weight_filler { 185 | type: "gaussian" 186 | std: 0.01 187 | } 188 | bias_filler { 189 | type: "constant" 190 | value: 0 191 | } 192 | } 193 | } 194 | layer { 195 | name: "relu3" 196 | type: "ReLU" 197 | bottom: "conv3" 198 | top: "conv3" 199 | } 200 | layer { 201 | name: "conv4" 202 | type: "Convolution" 203 | bottom: "conv3" 204 | top: "conv4" 205 | param { 206 | lr_mult: 1 207 | decay_mult: 1 208 | } 209 | param { 210 | lr_mult: 2 211 | decay_mult: 0 212 | } 213 | convolution_param { 214 | num_output: 384 215 | pad: 1 216 | kernel_size: 3 217 | group: 2 218 | weight_filler { 219 | type: "gaussian" 220 | std: 0.01 221 | } 222 | bias_filler { 223 | type: "constant" 224 | value: 1 225 | } 226 | } 227 | } 228 | layer { 229 | name: "relu4" 230 | type: "ReLU" 231 | bottom: "conv4" 232 | top: "conv4" 233 | } 234 | layer { 235 | name: "conv5" 236 | type: "Convolution" 237 | bottom: "conv4" 238 | top: "conv5" 239 | param { 240 | lr_mult: 1 241 | decay_mult: 1 242 | } 243 | param { 244 | lr_mult: 2 245 | decay_mult: 0 246 | } 247 | convolution_param { 248 | num_output: 256 249 | pad: 1 250 | kernel_size: 3 251 | group: 2 252 | weight_filler { 253 | type: "gaussian" 254 | std: 0.01 255 | } 256 | bias_filler { 257 | type: "constant" 258 | value: 1 259 | } 260 | } 261 | } 262 | layer { 263 | name: "relu5" 264 | type: "ReLU" 265 | bottom: "conv5" 266 | top: "conv5" 267 | } 268 | layer { 269 | name: "pool5" 270 
| type: "Pooling" 271 | bottom: "conv5" 272 | top: "pool5" 273 | pooling_param { 274 | pool: MAX 275 | kernel_size: 3 276 | stride: 2 277 | } 278 | } 279 | layer { 280 | name: "fc6" 281 | type: "InnerProduct" 282 | bottom: "pool5" 283 | top: "fc6" 284 | param { 285 | lr_mult: 1 286 | decay_mult: 1 287 | } 288 | param { 289 | lr_mult: 2 290 | decay_mult: 0 291 | } 292 | inner_product_param { 293 | num_output: 4096 294 | weight_filler { 295 | type: "gaussian" 296 | std: 0.005 297 | } 298 | bias_filler { 299 | type: "constant" 300 | value: 1 301 | } 302 | } 303 | } 304 | layer { 305 | name: "relu6" 306 | type: "ReLU" 307 | bottom: "fc6" 308 | top: "fc6" 309 | } 310 | layer { 311 | name: "drop6" 312 | type: "Dropout" 313 | bottom: "fc6" 314 | top: "fc6" 315 | dropout_param { 316 | dropout_ratio: 0.5 317 | } 318 | } 319 | layer { 320 | name: "fc7" 321 | type: "InnerProduct" 322 | bottom: "fc6" 323 | top: "fc7" 324 | param { 325 | lr_mult: 1 326 | decay_mult: 1 327 | } 328 | param { 329 | lr_mult: 2 330 | decay_mult: 0 331 | } 332 | inner_product_param { 333 | num_output: 4096 334 | weight_filler { 335 | type: "gaussian" 336 | std: 0.005 337 | } 338 | bias_filler { 339 | type: "constant" 340 | value: 1 341 | } 342 | } 343 | } 344 | layer { 345 | name: "relu7" 346 | type: "ReLU" 347 | bottom: "fc7" 348 | top: "fc7" 349 | } 350 | layer { 351 | name: "drop7" 352 | type: "Dropout" 353 | bottom: "fc7" 354 | top: "fc7" 355 | dropout_param { 356 | dropout_ratio: 0.5 357 | } 358 | } 359 | layer { 360 | name: "fc8" 361 | type: "InnerProduct" 362 | bottom: "fc7" 363 | top: "fc8" 364 | param { 365 | lr_mult: 1 366 | decay_mult: 1 367 | } 368 | param { 369 | lr_mult: 2 370 | decay_mult: 0 371 | } 372 | inner_product_param { 373 | num_output: 3 374 | weight_filler { 375 | type: "gaussian" 376 | std: 0.01 377 | } 378 | bias_filler { 379 | type: "constant" 380 | value: 0 381 | } 382 | } 383 | } 384 | layer { 385 | name: "accuracy" 386 | type: "Accuracy" 387 | bottom: "fc8" 388 | 
bottom: "label" 389 | top: "accuracy" 390 | include { 391 | phase: TEST 392 | } 393 | } 394 | layer { 395 | name: "loss" 396 | type: "SoftmaxWithLoss" 397 | bottom: "fc8" 398 | bottom: "label" 399 | top: "loss" 400 | } -------------------------------------------------------------------------------- /caffe_model/solver_1.prototxt: -------------------------------------------------------------------------------- 1 | net: "/data6/light/storm_1_1/images/ad_train_py/caffe_model/caffenet_train_val_1.prototxt" 2 | test_iter: 1000 3 | test_interval: 1000 4 | base_lr: 0.001 5 | lr_policy: "step" 6 | gamma: 0.1 7 | stepsize: 2500 8 | display: 50 9 | max_iter: 30000 10 | momentum: 0.9 11 | weight_decay: 0.0005 12 | snapshot: 6000 13 | snapshot_prefix: "/data6/light/storm_1_1/images/ad_train_py/caffe_model/caffe_model_1" 14 | solver_mode: GPU -------------------------------------------------------------------------------- /caffe_model_test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/caffe_model_test/__init__.py -------------------------------------------------------------------------------- /caffe_model_test/predict_base.py: -------------------------------------------------------------------------------- 1 | # _*_coding:utf-8 _*_ 2 | 3 | import sys 4 | sys.path.insert(0, '../../caffe_train_test/') 5 | import os 6 | import glob 7 | import cv2 8 | import caffe 9 | import lmdb 10 | import numpy as np 11 | from caffe.proto import caffe_pb2 12 | 13 | from utils.img_process import * 14 | 15 | 16 | class CaffePredict(object): 17 | 18 | def __init__(self, b_gpu, mean_path, deploy_path, model_path): 19 | 20 | if b_gpu: 21 | caffe.set_mode_gpu() 22 | else: 23 | caffe.set_mode_cpu() 24 | 25 | mean_blob = caffe_pb2.BlobProto() 26 | with open(mean_path) as f: 27 | mean_blob.ParseFromString(f.read()) 28 | 29 | mean_array = 
# _*_coding:utf-8 _*_
# caffe_model_test/predict_base.py
# Wraps a trained Caffe model (mean file + deploy net + weights) for
# single-image prediction.

import sys
sys.path.insert(0, '../../caffe_train_test/')
import cv2
import caffe
import numpy as np
from caffe.proto import caffe_pb2

from utils.img_process import *


class CaffePredict(object):
    """Loads the mean file, deploy prototxt and weights once; predicts per image."""

    def __init__(self, b_gpu, mean_path, deploy_path, model_path):
        """
        :param b_gpu: True -> run on GPU, False -> CPU
        :param mean_path: path to mean.binaryproto (binary BlobProto)
        :param deploy_path: deploy prototxt path
        :param model_path: trained .caffemodel path
        """
        if b_gpu:
            caffe.set_mode_gpu()
        else:
            caffe.set_mode_cpu()

        mean_blob = caffe_pb2.BlobProto()
        # BUGFIX: mean.binaryproto is binary protobuf data -- it must be
        # opened in 'rb' mode (text mode corrupts the read on Windows and
        # fails outright on Python 3).
        with open(mean_path, 'rb') as f:
            mean_blob.ParseFromString(f.read())

        mean_array = np.asarray(mean_blob.data, dtype=np.float32).\
            reshape((mean_blob.channels, mean_blob.height, mean_blob.width))

        self.net = caffe.Net(deploy_path, model_path, caffe.TEST)

        # Define image transformers
        self.transformer = caffe.io.Transformer({'data': self.net.blobs['data'].data.shape})
        self.transformer.set_mean('data', mean_array)
        # puts the channel as the first dimension (HWC -> CHW)
        self.transformer.set_transpose('data', (2, 0, 1))

    def predict(self, img):
        """Return per-class probabilities (rounded to 4 decimals) for one BGR image.

        NOTE: not thread safe -- the net's input blob is shared state.
        """
        img = pre_process_img(img)
        self.net.blobs['data'].data[...] = self.transformer.preprocess('data', img)
        out = self.net.forward()
        pred_probas = out['prob']

        # predict result
        return [round(f, 4) for f in pred_probas[0].tolist()]


def get_default_caffe_predict():
    """Build a CaffePredict from the project's default model paths."""
    # Read model architecture and trained model's weights
    mean_path = "/data6/light/storm_1_1/images/ad_train_py/input_data/mean.binaryproto"
    deploy_path = "/data6/light/storm_1_1/images/ad_train_py/caffe_model/caffenet_deploy_1.prototxt"
    model_path = "/data6/light/storm_1_1/images/ad_train_py/caffe_model/caffe_model_1_iter_10000.caffemodel"
    b_gpu = True
    return CaffePredict(b_gpu, mean_path, deploy_path, model_path)


if __name__ == '__main__':
    caffe_predict = get_default_caffe_predict()

    img_path = '/data6/light/storm_1_1/images/ad_train_py/test_data/0.png'
    img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    print(caffe_predict.predict(img))
# _*_ coding:utf-8 _*_
# caffe_model_test/predict_from_db.py
# Pull (id, url) rows from MySQL, classify each image, write results back.

import sys
sys.path.insert(0, '../../caffe_train_test/')
from utils.DbBse import DbService, get_default_db
from utils.img_process import get_cv_img__from_url
from predict_base import CaffePredict, get_default_caffe_predict


def predict_from_db():
    """
    get all the url and id from database and
    then predict, write predict result to database
    :return:
    """
    db = get_default_db()

    # [(1, 'http://xxx.1.jpg'), (2, 'http://xxx.2.jpg').....]
    url_id_lst = db.get_ad_info()

    print('url_id_lst length is %s: ' % len(url_id_lst))
    # BUGFIX: guard the empty result set -- url_id_lst[0] raised IndexError
    # when the table had no rows.
    if not url_id_lst:
        return
    print('url_id_lst first is %s' % (url_id_lst[0],))

    caffe_predict = get_default_caffe_predict()

    for item in url_id_lst:
        img = get_cv_img__from_url(item[1])
        # download or decode failed -> skip this row
        if img is None:
            continue

        ret_lst = caffe_predict.predict(img)
        # item[0] is the database id; update_ad_info expects it last
        ret_lst.append(item[0])
        # write result to database
        print('%s %s' % (item[1], ret_lst))
        db.update_ad_info(ret_lst)


if __name__ == '__main__':
    predict_from_db()
12 | """ 13 | return glob.glob(img_dir + "*.png") 14 | 15 | 16 | def predict_all(): 17 | path = '/data6/light/storm_1_1/images/ad_train_py/test_data/' 18 | img_lst = get_img_lst(path) 19 | caffe_predict = get_default_caffe_predict() 20 | 21 | for path in img_lst: 22 | try: 23 | img = cv2.imread(path, cv2.IMREAD_COLOR) 24 | # caffe_predict.predict is not thread safe,so can't be used in multiple thread 25 | # python is dummy multiple threads 26 | ret_lst = caffe_predict.predict(img) 27 | print path, ret_lst 28 | except Exception, e: 29 | print e 30 | 31 | 32 | if __name__ == '__main__': 33 | predict_all() 34 | pass 35 | -------------------------------------------------------------------------------- /caffe_train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/caffe_train/__init__.py -------------------------------------------------------------------------------- /caffe_train/caffe_train.py: -------------------------------------------------------------------------------- 1 | # _*_ coding:utf-8 2 | 3 | import os 4 | 5 | 6 | def get_train_cmd(caffe_path, solver_path, log_path): 7 | # create train command 8 | return '%s train --solver %s |& tee %s ' % (caffe_path, solver_path, log_path) 9 | 10 | 11 | if __name__ == '__main__': 12 | 13 | caffe_path = "/home/xiongyu/caffe/build/tools/caffe" 14 | solver_path = "/home/xiongyu/caffe_models/caffe_model_1/solver_1.prototxt" 15 | log_path = "/home/xiongyu/caffe_models/caffe_model_1/model_1_train.log" 16 | 17 | train = get_train_cmd(caffe_path, solver_path, log_path) 18 | 19 | print train 20 | # use caffe to train model 21 | os.system(train) 22 | 23 | pass 24 | -------------------------------------------------------------------------------- /run_images/caffe_model_dir.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/caffe_model_dir.jpg -------------------------------------------------------------------------------- /run_images/caffe_traintxt_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/caffe_traintxt_1.jpg -------------------------------------------------------------------------------- /run_images/caffe_traintxt_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/caffe_traintxt_2.jpg -------------------------------------------------------------------------------- /run_images/caffe_traintxt_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/caffe_traintxt_3.jpg -------------------------------------------------------------------------------- /run_images/caffe_traintxt_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/caffe_traintxt_4.jpg -------------------------------------------------------------------------------- /run_images/caffe_triaintxt_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/caffe_triaintxt_5.jpg -------------------------------------------------------------------------------- /run_images/code_framework_1.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/code_framework_1.jpg -------------------------------------------------------------------------------- /run_images/create_lmdb_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/create_lmdb_2.jpg -------------------------------------------------------------------------------- /run_images/deploy_txt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/deploy_txt.jpg -------------------------------------------------------------------------------- /run_images/img_url_2 (2).jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/img_url_2 (2).jpg -------------------------------------------------------------------------------- /run_images/img_url_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/img_url_2.jpg -------------------------------------------------------------------------------- /run_images/input_data_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/input_data_1.jpg -------------------------------------------------------------------------------- /run_images/label_dir_1.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/label_dir_1.jpg -------------------------------------------------------------------------------- /run_images/label_dir_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/label_dir_2.jpg -------------------------------------------------------------------------------- /run_images/lmdb_img_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/lmdb_img_1.jpg -------------------------------------------------------------------------------- /run_images/mean_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/mean_1.png -------------------------------------------------------------------------------- /run_images/mean_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/mean_2.png -------------------------------------------------------------------------------- /run_images/predict_base_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/predict_base_1.jpg -------------------------------------------------------------------------------- /run_images/predict_from_db_1.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/predict_from_db_1.jpg -------------------------------------------------------------------------------- /run_images/predict_from_db_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/predict_from_db_2.jpg -------------------------------------------------------------------------------- /run_images/predict_from_local_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/predict_from_local_1.jpg -------------------------------------------------------------------------------- /run_images/sql_ret.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/sql_ret.png -------------------------------------------------------------------------------- /run_images/train_process_!.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/run_images/train_process_!.png -------------------------------------------------------------------------------- /train_data_generate/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jayhello/python_caffe_train_test/c7aa16cd4faba99f05b42543cb0b3ff1fbd4d6d9/train_data_generate/__init__.py -------------------------------------------------------------------------------- 
/train_data_generate/create_lmdb.py: -------------------------------------------------------------------------------- 1 | # _*_coding:utf-8 _*_ 2 | 3 | import sys 4 | sys.path.insert(0, '../../caffe_train_test/') 5 | import os 6 | import glob 7 | import random 8 | import numpy as np 9 | 10 | import cv2 11 | 12 | import caffe 13 | from caffe.proto import caffe_pb2 14 | import lmdb 15 | 16 | from utils.img_process import * 17 | 18 | 19 | def make_datum(img, label): 20 | # image is numpy.ndarray format. BGR instead of RGB 21 | return caffe_pb2.Datum( 22 | channels=3, 23 | width=IMG_HEIGHT, 24 | height=IMG_WIDTH, 25 | label=label, 26 | data=np.rollaxis(img, 2).tostring()) 27 | 28 | 29 | class GenerateLmdb(object): 30 | 31 | def __init__(self, img_path): 32 | """ 33 | img_path -> multiple calss directory 34 | like, class_1, class_2, class_3.... 35 | each class has corresponding class image like class_1_1.png 36 | :param img_path: 37 | """ 38 | # get all the images in different class directory 39 | self.img_lst = glob.glob(os.path.join(img_path, '*', '*.png')) 40 | print 'input_img list num is %s' % len(self.img_lst) 41 | # shuffle all the images 42 | random.shuffle(self.img_lst) 43 | 44 | def generate_lmdb(self, label_lst, percentage, train_path, validation_path): 45 | """ 46 | label_lst like ['class_1', 'class_2', 'class_3', .....] 
47 | percentage like is 5 (4/5) then 80% be train image, (1/5) 20% be validation image 48 | train_path like that '/data/train/train_lmdb' 49 | validation_path like '/data/train/validation_lmdb' 50 | """ 51 | print 'now generate train lmdb' 52 | self._generate_lmdb(label_lst, percentage, True, train_path) 53 | print 'now generate validation lmdb' 54 | self._generate_lmdb(label_lst, percentage, False, validation_path) 55 | 56 | print '\n generate all images' 57 | 58 | def _generate_lmdb(self, label_lst, percentage, b_train, input_path): 59 | """ 60 | b_train is True means to generate train lmdb, or validation lmdb 61 | """ 62 | output_db = lmdb.open(input_path, map_size=int(1e12)) 63 | with output_db.begin(write=True) as in_txn: 64 | for idx, img_path in enumerate(self.img_lst): 65 | 66 | # create train data 67 | if b_train: 68 | # !=0 means validation data then skip loop 69 | if idx % percentage != 0: 70 | continue 71 | # create validation data 72 | else: 73 | # ==0 means train data then skip 74 | if idx % percentage == 0: 75 | continue 76 | 77 | img = cv2.imread(img_path, cv2.IMREAD_COLOR) 78 | img = pre_process_img(img) 79 | # path like that '../../class_1/0001.png' 80 | # so img_path.split('/')[-2] -> class_1 81 | label = label_lst.index(img_path.split('/')[-2]) 82 | datum = make_datum(img, label) 83 | in_txn.put('{:0>5d}'.format(idx), datum.SerializeToString()) 84 | print '{:0>5d}'.format(idx) + '->label: ', label, " " + img_path 85 | 86 | output_db.close() 87 | 88 | 89 | def get_label_lst_by_dir(f_dir): 90 | """ 91 | f_dir like 'home/user/class', sub dir 'class_1', 'class_2'...'class_n' 92 | :return: ['class_1', 'class_2'...'class_n'] 93 | """ 94 | return os.listdir(f_dir) 95 | 96 | if __name__ == '__main__': 97 | img_path = '../../ad_train/' 98 | cl = GenerateLmdb(img_path) 99 | 100 | train_lmdb = '/data6/light/storm_1_1/images/ad_train_py/input_data/train_lmdb' 101 | validation_lmdb = '/data6/light/storm_1_1/images/ad_train_py/input_data/validation_lmdb' 102 
| 103 | os.system('rm -rf ' + train_lmdb) 104 | os.system('rm -rf ' + validation_lmdb) 105 | 106 | input_path = '/data6/light/storm_1_1/images/ad_train/' 107 | label_lst = get_label_lst_by_dir(input_path) 108 | print 'label_lst is: %s' % ', '.join(label_lst) 109 | 110 | # (1/10)10% to be validation data, 90% to be train data 111 | percentage = 10 112 | 113 | cl.generate_lmdb(label_lst, percentage, train_lmdb, validation_lmdb) 114 | 115 | pass 116 | -------------------------------------------------------------------------------- /train_data_generate/create_mean_binaryproto.py: -------------------------------------------------------------------------------- 1 | # _*_ coding:utf-8 2 | 3 | import os 4 | 5 | 6 | def get_mean_cmd(mean_tool_path, train_lmdb_path, mean_binaryproto_path): 7 | # create train command 8 | return '%s -backend=lmdb %s %s ' % (mean_tool_path, train_lmdb_path, mean_binaryproto_path) 9 | 10 | 11 | if __name__ == '__main__': 12 | mean_tool_path = '/home/ubuntu/caffe/build/tools/compute_image_mean' 13 | train_lmdb_path = '/home/xiongyu/input/train_lmdb' 14 | mean_binaryproto_path = '/home/xiongyu/input/mean.binaryproto' 15 | 16 | cmd = get_mean_cmd(mean_tool_path, train_lmdb_path, mean_binaryproto_path) 17 | print cmd 18 | 19 | os.system(cmd) 20 | -------------------------------------------------------------------------------- /utils/DbBse.py: -------------------------------------------------------------------------------- 1 | # _*_ coding:utf-8 _*_ 2 | import MySQLdb 3 | 4 | 5 | class DbBase(object): 6 | def __init__(self, **kwargs): 7 | self.db_config_file = kwargs['db_config'] 8 | self.config_db(self.db_config_file) 9 | 10 | def config_db(self, db_config): 11 | data = db_config 12 | host = data['host'] 13 | user = data['user'] 14 | pwd = data['pwd'] 15 | db = data['db'] 16 | port = data['port'] 17 | self.conn = MySQLdb.connect(host=host, port=port, user=user, passwd=pwd, db=db, charset="utf8", use_unicode=True) 18 | self.cursor = 
# _*_ coding:utf-8 _*_
# utils/DbBse.py
# MySQL access helpers for reading ad image urls and writing predictions.


class DbBase(object):
    """Holds one MySQL connection + cursor built from a config dict."""

    def __init__(self, **kwargs):
        # kwargs['db_config'] is a dict with host/user/pwd/db/port keys
        self.db_config_file = kwargs['db_config']
        self.config_db(self.db_config_file)

    def config_db(self, db_config):
        """Open the connection described by `db_config`."""
        import MySQLdb  # local import: only needed once a DB object is built

        data = db_config
        host = data['host']
        user = data['user']
        pwd = data['pwd']
        db = data['db']
        port = data['port']
        self.conn = MySQLdb.connect(host=host, port=port, user=user, passwd=pwd,
                                    db=db, charset="utf8", use_unicode=True)
        self.cursor = self.conn.cursor()


class DbService(DbBase):
    """Query/update helpers for the ad_text_web_set_2017818 table."""

    def __init__(self, **kwargs):
        super(DbService, self).__init__(**kwargs)

    def get_ad_info(self):
        """
        return all id and url
        [(1, 'http://xxx.1.jpg'), (2, 'http://xxx.2.jpg').....]
        :return:
        """
        sql = """select id, url from ad_text_web_set_2017818"""
        self.cursor.execute(sql)

        return [row for row in self.cursor]

    def update_ad_info(self, lst):
        """
        write predict result to database
        :param lst: [label_txt, label_web, label_others, id]
        :return:
        """
        try:
            sql = """
                update ad_text_web_set_2017818
                set label_txt=%s, label_web=%s, label_others=%s, modify_date=now()
                where id=%s
            """
            # BUGFIX: parameterized query instead of '%' string interpolation
            # (the old form was injectable and mis-quoted non-numeric values)
            self.cursor.execute(sql, (lst[0], lst[1], lst[2], lst[3]))
            self.conn.commit()
        except Exception as e:
            # best-effort: log and keep processing the remaining rows
            print(e)


def get_default_db():
    """Build a DbService against the default local MySQL instance."""
    ip = '127.0.0.1'
    port = 3307
    user = 'user'
    pwd = 'user'
    db = 'caffe'

    db_config = {}
    db_config['host'] = ip
    db_config['port'] = port
    db_config['user'] = user
    db_config['pwd'] = pwd
    db_config['db'] = db

    return DbService(db_config=db_config)


# _*_coding:utf-8 _*_
# utils/img_process.py
# Image helpers shared by lmdb generation and prediction code.

import numpy as np

# model input size (AlexNet-style 227x227)
IMG_HEIGHT = 227
IMG_WIDTH = 227


def pre_process_img(img, img_height=IMG_HEIGHT, img_width=IMG_WIDTH):
    """Histogram-equalize each BGR channel, then resize to the model input size.

    Returns a new array. BUGFIX: the original equalized the caller's image
    in place, silently mutating the argument.
    """
    import cv2

    img = img.copy()
    # firstly histogram equalization, per channel
    for c in range(3):
        img[:, :, c] = cv2.equalizeHist(img[:, :, c])

    # resize image to the network's input size
    return cv2.resize(img, (img_width, img_height), interpolation=cv2.INTER_CUBIC)


def get_cv_img__from_url(url):
    """
    read image from url to cv codec
    :param url:
    :return: decoded image, or None when download/decoding fails
    """
    import cv2
    import urllib
    from contextlib import closing

    try:
        # BUGFIX: the HTTP response was never closed (socket leak per call)
        with closing(urllib.urlopen(url)) as url_response:
            img_array = np.array(bytearray(url_response.read()), dtype=np.uint8)
        return cv2.imdecode(img_array, -1)
    except Exception as e:
        print(e)
        return None


if __name__ == '__main__':
    import cv2

    url = 'http://www.sanyarb.com.cn/images/attachement/jpg/site2/20161009/A121475977636942_change_ljx6a9_b.jpg'
    img = get_cv_img__from_url(url)
    cv2.imshow("zhan lang", img)

    img = pre_process_img(img)
    cv2.imshow("pre_process_img", img)
    cv2.waitKey()