# -*- coding: utf-8 -*-
# Repository README: https://raw.githubusercontent.com/plsong/keras-yolo3-test/HEAD/README.md
"""
Batch-test images with keras-yolov3 and save the results.

Every regular file in ``path`` is run through the detector. The annotated
image is written to ``result_path`` as ``result_<filename>``, and a textual
summary of each detection is written to ``result_path/result.txt``.

Project source: https://github.com/qqwweee/keras-yolo3
"""

import colorsys
import os
from timeit import default_timer as timer
import time

import numpy as np
from keras import backend as K
from keras.models import load_model
from keras.layers import Input
from PIL import Image, ImageFont, ImageDraw

from yolo3.model import yolo_eval, yolo_body, tiny_yolo_body
from yolo3.utils import letterbox_image
from keras.utils import multi_gpu_model

path = './test/'  # directory holding the images to be detected

# Create the directory that will hold the detection results.
result_path = './result'
if not os.path.exists(result_path):
    os.makedirs(result_path)

# NOTE: this runs at import time. Any regular files left in result_path by a
# previous run are deleted (sub-directories are left alone).
for i in os.listdir(result_path):
    path_file = os.path.join(result_path, i)
    if os.path.isfile(path_file):
        os.remove(path_file)

# Text log of the detection results. It is written by YOLO.detect_image and by
# the __main__ loop, and closed at the end of __main__.
# The encoding is explicit because the log contains non-ASCII text ('、'),
# which cannot be encoded under some platform defaults (e.g. cp1252).
txt_path = result_path + '/result.txt'
file = open(txt_path, 'w', encoding='utf-8')


class YOLO(object):
    """Keras YOLOv3 detector that draws its detections onto PIL images.

    Configuration comes from ``_defaults`` and can be overridden with keyword
    arguments to the constructor, e.g. ``YOLO(score=0.5)``.
    """

    _defaults = {
        "model_path": 'model_data/yolo.h5',
        "anchors_path": 'model_data/yolo_anchors.txt',
        "classes_path": 'model_data/coco_classes.txt',
        "score" : 0.3,
        "iou" : 0.45,
        "model_image_size" : (416, 416),
        "gpu_num" : 1,
    }

    @classmethod
    def get_defaults(cls, n):
        """Return the default value of option ``n``.

        Unknown names return an error string instead of raising.
        """
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    def __init__(self, **kwargs):
        """Apply defaults and overrides, then load the model and build outputs."""
        self.__dict__.update(self._defaults)  # set up default values
        self.__dict__.update(kwargs)  # and update with user overrides
        self.class_names = self._get_class()
        self.anchors = self._get_anchors()
        self.sess = K.get_session()
        self.boxes, self.scores, self.classes = self.generate()

    def _get_class(self):
        """Read class names, one per line (stripped), from ``classes_path``."""
        classes_path = os.path.expanduser(self.classes_path)
        with open(classes_path) as f:
            class_names = f.readlines()
        class_names = [c.strip() for c in class_names]
        return class_names

    def _get_anchors(self):
        """Read comma-separated anchor values from the first line of
        ``anchors_path`` and return them as an (N, 2) float array."""
        anchors_path = os.path.expanduser(self.anchors_path)
        with open(anchors_path) as f:
            anchors = f.readline()
        anchors = [float(x) for x in anchors.split(',')]
        return np.array(anchors).reshape(-1, 2)

    def generate(self):
        """Load the model and build the filtered-detection output tensors.

        Returns:
            The ``(boxes, scores, classes)`` tensors produced by ``yolo_eval``.
            ``detect_image`` evaluates them by feeding the image data and
            ``self.input_image_shape``.
        """
        model_path = os.path.expanduser(self.model_path)
        assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'

        # Load model, or construct model and load weights.
        num_anchors = len(self.anchors)
        num_classes = len(self.class_names)
        is_tiny_version = num_anchors == 6  # default setting
        try:
            self.yolo_model = load_model(model_path, compile=False)
        except Exception:
            # load_model failed (presumably the .h5 holds weights only).
            # Build the architecture here and load the weights into it. Use
            # the expanded path, as load_model does, so '~' paths work here too.
            self.yolo_model = tiny_yolo_body(Input(shape=(None, None, 3)), num_anchors // 2, num_classes) \
                if is_tiny_version else yolo_body(Input(shape=(None, None, 3)), num_anchors // 3, num_classes)
            self.yolo_model.load_weights(model_path)  # make sure model, anchors and classes match
        else:
            # Each output channel block holds (num_classes + 5) values per anchor.
            assert self.yolo_model.layers[-1].output_shape[-1] == \
                num_anchors / len(self.yolo_model.output) * (num_classes + 5), \
                'Mismatch between model and given anchor and class sizes'

        print('{} model, anchors, and classes loaded.'.format(model_path))

        # Generate colors for drawing bounding boxes: evenly spaced hues.
        hsv_tuples = [(x / len(self.class_names), 1., 1.)
                      for x in range(len(self.class_names))]
        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        self.colors = list(
            map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
                self.colors))
        # Shuffle with a private fixed-seed generator. Colors stay stable across
        # runs and adjacent classes are decorrelated. This gives the same
        # permutation as seeding the global RNG with 10101, but leaves NumPy's
        # global random state untouched.
        np.random.RandomState(10101).shuffle(self.colors)

        # Generate output tensor targets for filtered bounding boxes.
        self.input_image_shape = K.placeholder(shape=(2, ))
        if self.gpu_num >= 2:
            self.yolo_model = multi_gpu_model(self.yolo_model, gpus=self.gpu_num)
        boxes, scores, classes = yolo_eval(self.yolo_model.output, self.anchors,
                                           len(self.class_names), self.input_image_shape,
                                           score_threshold=self.score, iou_threshold=self.iou)
        return boxes, scores, classes

    @staticmethod
    def _label_size(draw, text, font):
        """Return the pixel ``(width, height)`` of ``text`` drawn with ``font``.

        ``ImageDraw.textsize`` was removed in Pillow 10; there we fall back to
        ``textbbox``. Anchored at (0, 0), its right/bottom edges give the extent
        measured from the draw origin, approximating what ``textsize`` returned.
        """
        if hasattr(draw, 'textsize'):
            return draw.textsize(text, font)
        _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right, bottom

    def detect_image(self, image):
        """Run detection on a PIL image and draw the results onto it.

        Besides drawing, this prints progress to stdout and writes a summary of
        every detection to the module-level ``file`` (result.txt).

        Args:
            image: PIL image. Boxes and labels are drawn on it in place.

        Returns:
            The same image object, annotated.
        """
        start = timer()  # start timing

        if self.model_image_size != (None, None):
            assert self.model_image_size[0] % 32 == 0, 'Multiples of 32 required'
            assert self.model_image_size[1] % 32 == 0, 'Multiples of 32 required'
            boxed_image = letterbox_image(image, tuple(reversed(self.model_image_size)))
        else:
            # Round each side down to a multiple of 32.
            new_image_size = (image.width - (image.width % 32),
                              image.height - (image.height % 32))
            boxed_image = letterbox_image(image, new_image_size)
        image_data = np.array(boxed_image, dtype='float32')

        print(image_data.shape)  # print the shape of the letterboxed input
        image_data /= 255.
        image_data = np.expand_dims(image_data, 0)  # Add batch dimension.

        out_boxes, out_scores, out_classes = self.sess.run(
            [self.boxes, self.scores, self.classes],
            feed_dict={
                self.yolo_model.input: image_data,
                self.input_image_shape: [image.size[1], image.size[0]],
                K.learning_phase(): 0
            })

        print('Found {} boxes for {}'.format(len(out_boxes), 'img'))  # report how many boxes were found

        # Both values scale with the image. Clamp them to >= 1 so that small
        # images still get a visible outline (thickness 0 drew nothing) and a
        # valid font size.
        font = ImageFont.truetype(font='font/FiraMono-Medium.otf',
                                  size=max(1, int(np.floor(2e-2 * image.size[1] + 0.2))))
        thickness = max(1, (image.size[0] + image.size[1]) // 500)

        # Log the number of detected boxes.
        file.write('find ' + str(len(out_boxes)) + ' target(s) \n')

        draw = ImageDraw.Draw(image)
        for i, c in reversed(list(enumerate(out_classes))):
            predicted_class = self.class_names[c]
            box = out_boxes[i]
            score = out_scores[i]

            label = '{} {:.2f}'.format(predicted_class, score)
            label_size = self._label_size(draw, label, font)

            # Round the box to integer pixels and clip it to the image bounds.
            top, left, bottom, right = box
            top = max(0, np.floor(top + 0.5).astype('int32'))
            left = max(0, np.floor(left + 0.5).astype('int32'))
            bottom = min(image.size[1], np.floor(bottom + 0.5).astype('int32'))
            right = min(image.size[0], np.floor(right + 0.5).astype('int32'))

            # Log the detected location.
            file.write(predicted_class + ' score: ' + str(score)
                       + ' \nlocation: top: ' + str(top)
                       + '、 bottom: ' + str(bottom)
                       + '、 left: ' + str(left)
                       + '、 right: ' + str(right) + '\n')

            print(label, (left, top), (right, bottom))

            # Put the label above the box when it fits, otherwise just inside it.
            if top - label_size[1] >= 0:
                text_origin = np.array([left, top - label_size[1]])
            else:
                text_origin = np.array([left, top + 1])

            # Draw the outline as `thickness` nested 1-px rectangles.
            for offset in range(thickness):
                draw.rectangle(
                    [left + offset, top + offset, right - offset, bottom - offset],
                    outline=self.colors[c])
            draw.rectangle(
                [tuple(text_origin), tuple(text_origin + label_size)],
                fill=self.colors[c])
            draw.text(text_origin, label, fill=(0, 0, 0), font=font)
        del draw

        end = timer()
        print('time consume:%.3f s ' % (end - start))
        return image

    def close_session(self):
        """Close the Keras/TensorFlow session obtained in ``__init__``."""
        self.sess.close()


# Image detection driver.

if __name__ == '__main__':

    t1 = time.time()
    yolo = None
    try:
        yolo = YOLO()
        # Sorted so the order of entries in result.txt is reproducible.
        for filename in sorted(os.listdir(path)):
            image_path = os.path.join(path, filename)
            if not os.path.isfile(image_path):
                continue  # skip sub-directories
            try:
                image = Image.open(image_path)
            except OSError:
                # Not an image PIL can read; skip it instead of aborting the run.
                print('skip non-image file: ' + image_path)
                continue
            file.write(filename + ' detect_result:\n')
            r_image = yolo.detect_image(image)
            file.write('\n')
            # r_image.show()  # uncomment to display the detection result
            image_save_path = os.path.join(result_path, 'result_' + filename)
            print('detect result save to....:' + image_save_path)
            r_image.save(image_save_path)

        time_sum = time.time() - t1
        file.write('time sum: ' + str(time_sum) + 's')
        print('time sum:', time_sum)
    finally:
        # Always flush/close the log and release the session, even on failure.
        file.close()
        if yolo is not None:
            yolo.close_session()