├── yolo2_data ├── car.jpg ├── detection.jpg └── coco_classes.txt ├── config.py ├── README.md ├── decode.py ├── Main.py ├── Loss.py ├── model_darknet19.py └── utils.py /yolo2_data/car.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KOD-Chen/YOLOv2-Tensorflow/HEAD/yolo2_data/car.jpg -------------------------------------------------------------------------------- /yolo2_data/detection.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KOD-Chen/YOLOv2-Tensorflow/HEAD/yolo2_data/detection.jpg -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # -------------------------------------- 3 | # @Time : 2018/5/16$ 17:12$ 4 | # @Author : KOD Chen 5 | # @Email : 821237536@qq.com 6 | # @File : configs$.py 7 | # Description :anchor尺寸、coco数据集的80个classes类别名称 8 | # -------------------------------------- 9 | 10 | anchors = [[0.57273, 0.677385], 11 | [1.87446, 2.06253], 12 | [3.33843, 5.47434], 13 | [7.88282, 3.52778], 14 | [9.77052, 9.16828]] 15 | 16 | def read_coco_labels(): 17 | f = open("./yolo2_data/coco_classes.txt") 18 | class_names = [] 19 | for l in f.readlines(): 20 | l = l.strip() # 去掉回车'\n' 21 | class_names.append(l) 22 | return class_names 23 | 24 | class_names = read_coco_labels() 25 | -------------------------------------------------------------------------------- /yolo2_data/coco_classes.txt: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # YOLOv2-Tensorflow
2 | ## 声明:
3 | 更详细的代码解读[Tensorflow实现YOLO2](https://zhuanlan.zhihu.com/p/36902889).
4 | 欢迎关注[我的知乎](https://www.zhihu.com/people/chensicheng/posts).

5 | 6 | ## 运行环境:
7 | Python3 + Tensorflow1.5 + OpenCV-python3.3.1 + Numpy1.13
8 | windows和ubuntu环境都可以

9 | 10 | ## 准备工作:
11 | 请在[yolo2检测模型](https://pan.baidu.com/s/1ZeT5HerjQxyUZ_L9d3X52w)下载模型,并放到yolo2_model文件夹下

12 | 13 | ## 文件说明:
14 | 1、model_darknet19.py:yolo2网络模型——darknet19
15 | 2、decode.py:解码darknet19网络得到的参数
16 | 3、utils.py:功能函数,包含:预处理输入图片、筛选边界框NMS、绘制筛选后的边界框
17 | 4、config.py:配置文件,包含anchor尺寸、coco数据集的80个classes类别名称
18 | 5、Main.py:YOLO_v2主函数,对应程序有三个步骤:
19 | (1)输入图片进入darknet19网络得到特征图,并进行解码得到:xmin xmax表示的边界框、置信度、类别概率
20 | (2)筛选解码后的回归边界框——NMS
21 | (3)绘制筛选后的边界框
22 | 6、Loss.py:Yolo_v2 Loss损失函数(train时候用,预测时候没有调用此程序)
23 | (1)IOU值最大的那个anchor与ground truth匹配,对应的预测框用来预测这个ground truth:计算xywh、置信度c(目标值为1)、类别概率p误差。
24 | (2)IOU小于某阈值的anchor对应的预测框:只计算置信度c(目标值为0)误差。
25 | (3)剩下IOU大于某阈值但不是max的anchor对应的预测框:丢弃,不计算任何误差。
26 | 7、yolo2_data文件夹:包含待检测输入图片car.jpg、检测后的输出图片detection.jpg、coco数据集80个类别名称coco_classes.txt

27 | 28 | ## 运行Main.py即可得到效果图:
29 | 1、car.jpg:输入的待检测图片

30 | ![image](https://github.com/KOD-Chen/YOLOv2-Tensorflow/blob/master/yolo2_data/car.jpg)
31 | 2、detected.jpg:检测结果可视化

32 | ![image](https://github.com/KOD-Chen/YOLOv2-Tensorflow/blob/master/yolo2_data/detection.jpg)
33 | -------------------------------------------------------------------------------- /decode.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # -------------------------------------- 3 | # @Time : 2018/5/15$ 17:01$ 4 | # @Author : KOD Chen 5 | # @Email : 821237536@qq.com 6 | # @File : decode$.py 7 | # Description :解码darknet19网络得到的参数. 8 | # -------------------------------------- 9 | 10 | import tensorflow as tf 11 | import numpy as np 12 | 13 | def decode(model_output,output_sizes=(13,13),num_class=80,anchors=None): 14 | ''' 15 | model_output:darknet19网络输出的特征图 16 | output_sizes:darknet19网络输出的特征图大小,默认是13*13(默认输入416*416,下采样32) 17 | ''' 18 | H, W = output_sizes 19 | num_anchors = len(anchors) # 这里的anchor是在configs文件中设置的 20 | anchors = tf.constant(anchors, dtype=tf.float32) # 将传入的anchors转变成tf格式的常量列表 21 | 22 | # 13*13*num_anchors*(num_class+5),第一个维度自适应batchsize 23 | detection_result = tf.reshape(model_output,[-1,H*W,num_anchors,num_class+5]) 24 | 25 | # darknet19网络输出转化——偏移量、置信度、类别概率 26 | xy_offset = tf.nn.sigmoid(detection_result[:,:,:,0:2]) # 中心坐标相对于该cell左上角的偏移量,sigmoid函数归一化到0-1 27 | wh_offset = tf.exp(detection_result[:,:,:,2:4]) #相对于anchor的wh比例,通过e指数解码 28 | obj_probs = tf.nn.sigmoid(detection_result[:,:,:,4]) # 置信度,sigmoid函数归一化到0-1 29 | class_probs = tf.nn.softmax(detection_result[:,:,:,5:]) # 网络回归的是'得分',用softmax转变成类别概率 30 | 31 | # 构建特征图每个cell的左上角的xy坐标 32 | height_index = tf.range(H,dtype=tf.float32) # range(0,13) 33 | width_index = tf.range(W,dtype=tf.float32) # range(0,13) 34 | # 变成x_cell=[[0,1,...,12],...,[0,1,...,12]]和y_cell=[[0,0,...,0],[1,...,1]...,[12,...,12]] 35 | x_cell,y_cell = tf.meshgrid(height_index,width_index) 36 | x_cell = tf.reshape(x_cell,[1,-1,1]) # 和上面[H*W,num_anchors,num_class+5]对应 37 | y_cell = tf.reshape(y_cell,[1,-1,1]) 38 | 39 | # decode 40 | bbox_x = (x_cell + xy_offset[:,:,:,0]) / W 41 | bbox_y = (y_cell + xy_offset[:,:,:,1]) / H 42 | bbox_w = (anchors[:,0] * wh_offset[:,:,:,0]) / W 43 | bbox_h = (anchors[:,1] * wh_offset[:,:,:,1]) / H 44 | # 中心坐标+宽高box(x,y,w,h) -> xmin=x-w/2 -> 左上+右下box(xmin,ymin,xmax,ymax) 45 | bboxes = tf.stack([bbox_x-bbox_w/2, bbox_y-bbox_h/2, 46 | bbox_x+bbox_w/2, bbox_y+bbox_h/2], axis=3) 47 | 48 | return bboxes, obj_probs, class_probs 49 | -------------------------------------------------------------------------------- /Main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # -------------------------------------- 3 | # @Time : 2018/5/16$ 17:17$ 4 | # @Author : KOD Chen 5 | # @Email : 821237536@qq.com 6 | # @File : Main$.py 7 | # Description :YOLO_v2主函数. 8 | # -------------------------------------- 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | import cv2 13 | from PIL import Image 14 | 15 | from YOLO_v2.model_darknet19 import darknet 16 | from YOLO_v2.decode import decode 17 | from YOLO_v2.utils import preprocess_image, postprocess, draw_detection 18 | from YOLO_v2.config import anchors, class_names 19 | 20 | def main(): 21 | input_size = (416,416) 22 | image_file = './yolo2_data/car.jpg' 23 | image = cv2.imread(image_file) 24 | image_shape = image.shape[:2] #只取wh,channel=3不取 25 | 26 | # copy、resize416*416、归一化、在第0维增加存放batchsize维度 27 | image_cp = preprocess_image(image,input_size) 28 | 29 | # 【1】输入图片进入darknet19网络得到特征图,并进行解码得到:xmin xmax表示的边界框、置信度、类别概率 30 | tf_image = tf.placeholder(tf.float32,[1,input_size[0],input_size[1],3]) 31 | model_output = darknet(tf_image) # darknet19网络输出的特征图 32 | output_sizes = input_size[0]//32, input_size[1]//32 # 特征图尺寸是图片下采样32倍 33 | output_decoded = decode(model_output=model_output,output_sizes=output_sizes, 34 | num_class=len(class_names),anchors=anchors) # 解码 35 | 36 | model_path = "./yolo2_model/yolo2_coco.ckpt" 37 | saver = tf.train.Saver() 38 | with tf.Session() as sess: 39 | saver.restore(sess,model_path) 40 | bboxes,obj_probs,class_probs = sess.run(output_decoded,feed_dict={tf_image:image_cp}) 41 | 42 | # 【2】筛选解码后的回归边界框——NMS(post process后期处理) 43 | bboxes,scores,class_max_index = postprocess(bboxes,obj_probs,class_probs,image_shape=image_shape) 44 | 45 | # 【3】绘制筛选后的边界框 46 | img_detection = draw_detection(image, bboxes, scores, class_max_index, class_names) 47 | cv2.imwrite("./yolo2_data/detection.jpg", img_detection) 48 | print('YOLO_v2 detection has done!') 49 | cv2.imshow("detection_results", img_detection) 50 | cv2.waitKey(0) 51 | 52 | if __name__ == '__main__': 53 | main() 54 | -------------------------------------------------------------------------------- /Loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # -------------------------------------- 3 | # @Time : 2018/5/12$ 17:49$ 4 | # @Author : KOD Chen 5 | # @Email : 821237536@qq.com 6 | # @File : Loss$.py 7 | # Description :Yolo_v2 Loss损失函数. 8 | # -------------------------------------- 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | def compute_loss(predictions,targets,anchors,scales,num_classes=20,output_size=(13,13)): 14 | W,H = output_size 15 | C = num_classes 16 | B = len(anchors) 17 | anchors = tf.constant(anchors,dtype=tf.float32) 18 | anchors = tf.reshape(anchors,[1,1,B,2]) # 存放输入的anchors的wh 19 | 20 | # 【1】ground truth:期望值、真实值 21 | sprob,sconf,snoob,scoor = scales # loss不同部分的前面系数 22 | _coords = targets["coords"] # ground truth [-1, H*W, B, 4],真实坐标xywh 23 | _probs = targets["probs"] # class probability [-1, H*W, B, C] ,类别概率——one hot形式,C维 24 | _confs = targets["confs"] # 1 for object, 0 for background, [-1, H*W, B],置信度,每个边界框一个 25 | # ground truth计算IOU-->_up_left, _down_right 26 | _wh = tf.pow(_coords[:, :, :, 2:4], 2) * np.reshape([W, H], [1, 1, 1, 2]) 27 | _areas = _wh[:, :, :, 0] * _wh[:, :, :, 1] 28 | _centers = _coords[:, :, :, 0:2] 29 | _up_left, _down_right = _centers - (_wh * 0.5), _centers + (_wh * 0.5) 30 | # ground truth汇总 31 | truths = tf.concat([_coords, tf.expand_dims(_confs, -1), _probs], 3) 32 | 33 | # 【2】decode the net prediction:预测值、网络输出值 34 | predictions = tf.reshape(predictions,[-1,H,W,B,(5+C)]) 35 | # t_x, t_y, t_w, t_h 36 | coords = tf.reshape(predictions[:,:,:,:,0:4],[-1,H*W,B,4]) 37 | coords_xy = tf.nn.sigmoid(coords[:,:,:,0:2]) # 0-1,xy是相对于cell左上角的偏移量 38 | coords_wh = tf.sqrt(tf.exp(coords[:,:,:,2:4])*anchors/np.reshape([W,H],[1,1,1,2])) # 0-1,除以特征图的尺寸13,解码成相对于整张图片的wh 39 | coords = tf.concat([coords_xy,coords_wh],axis=3) # [batch_size, H*W, B, 4] 40 | # 置信度 41 | confs = tf.nn.sigmoid(predictions[:,:,:,:,4]) 42 | confs = tf.reshape(confs,[-1,H*W,B,1]) # 每个边界框一个置信度,每个cell有B个边界框 43 | # 类别概率 44 | probs = tf.nn.softmax(predictions[:,:,:,:,5:]) # 网络最后输出是"得分",通过softmax变成概率 45 | probs = tf.reshape(probs,[-1,H*W,B,C]) 46 | # prediction汇总 47 | preds = tf.concat([coords,confs,probs],axis=3) # [-1, H*W, B, (4+1+C)] 48 | # prediction计算IOU-->up_left, down_right 49 | wh = tf.pow(coords[:, :, :, 2:4], 2) * np.reshape([W, H], [1, 1, 1, 2]) 50 | areas = wh[:, :, :, 0] * wh[:, :, :, 1] 51 | centers = coords[:, :, :, 0:2] 52 | up_left, down_right = centers - (wh * 0.5), centers + (wh * 0.5) 53 | 54 | 55 | # ※※※【3】计算ground truth和anchor的IOU:※※※ 56 | # 计算IOU只考虑形状,先将anchor与ground truth的中心点都偏移到同一位置(cell左上角),然后计算出对应的IOU值。 57 | # ①IOU值最大的那个anchor与ground truth匹配,对应的预测框用来预测这个ground truth:计算xywh、置信度c(目标值为1)、类别概率p误差。 58 | # ②IOU小于某阈值的anchor对应的预测框:只计算置信度c(目标值为0)误差。 59 | # ③剩下IOU大于某阈值但不是max的anchor对应的预测框:丢弃,不计算任何误差。 60 | inter_upleft = tf.maximum(up_left, _up_left) 61 | inter_downright = tf.minimum(down_right, _down_right) 62 | inter_wh = tf.maximum(inter_downright - inter_upleft, 0.0) 63 | intersects = inter_wh[:, :, :, 0] * inter_wh[:, :, :, 1] 64 | ious = tf.truediv(intersects, areas + _areas - intersects) 65 | 66 | best_iou_mask = tf.equal(ious, tf.reduce_max(ious, axis=2, keep_dims=True)) 67 | best_iou_mask = tf.cast(best_iou_mask, tf.float32) 68 | mask = best_iou_mask * _confs # [-1, H*W, B] 69 | mask = tf.expand_dims(mask, -1) # [-1, H*W, B, 1] 70 | 71 | # 【4】计算各项损失所占的比例权重weight 72 | confs_w = snoob * (1 - mask) + sconf * mask 73 | coords_w = scoor * mask 74 | probs_w = sprob * mask 75 | weights = tf.concat([coords_w, confs_w, probs_w], axis=3) 76 | 77 | # 【5】计算loss:ground truth汇总和prediction汇总均方差损失函数,再乘以相应的比例权重 78 | loss = tf.pow(preds - truths, 2) * weights 79 | loss = tf.reduce_sum(loss, axis=[1, 2, 3]) 80 | loss = 0.5 * tf.reduce_mean(loss) 81 | 82 | return loss 83 | -------------------------------------------------------------------------------- /model_darknet19.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # -------------------------------------- 3 | # @Time : 2018/5/15$ 12:12$ 4 | # @Author : KOD Chen 5 | # @Email : 821237536@qq.com 6 | # @File : model_darknet19$.py 7 | # Description :yolo2网络模型——darknet19. 8 | # -------------------------------------- 9 | 10 | import os 11 | import tensorflow as tf 12 | import numpy as np 13 | 14 | ################# 基础层:conv/pool/reorg(带passthrough的重组层) ############################################# 15 | # 激活函数 16 | def leaky_relu(x): 17 | return tf.nn.leaky_relu(x,alpha=0.1,name='leaky_relu') # 或者tf.maximum(0.1*x,x) 18 | 19 | # Conv+BN:yolo2中每个卷积层后面都有一个BN层 20 | def conv2d(x,filters_num,filters_size,pad_size=0,stride=1,batch_normalize=True, 21 | activation=leaky_relu,use_bias=False,name='conv2d'): 22 | # padding,注意: 不用padding="SAME",否则可能会导致坐标计算错误 23 | if pad_size > 0: 24 | x = tf.pad(x,[[0,0],[pad_size,pad_size],[pad_size,pad_size],[0,0]]) 25 | # 有BN层,所以后面有BN层的conv就不用偏置bias,并先不经过激活函数activation 26 | out = tf.layers.conv2d(x,filters=filters_num,kernel_size=filters_size,strides=stride, 27 | padding='VALID',activation=None,use_bias=use_bias,name=name) 28 | # BN,如果有,应该在卷积层conv和激活函数activation之间 29 | if batch_normalize: 30 | out = tf.layers.batch_normalization(out,axis=-1,momentum=0.9,training=False,name=name+'_bn') 31 | if activation: 32 | out = activation(out) 33 | return out 34 | 35 | # max_pool 36 | def maxpool(x,size=2,stride=2,name='maxpool'): 37 | return tf.layers.max_pooling2d(x,pool_size=size,strides=stride) 38 | 39 | # reorg layer(带passthrough的重组层) 40 | def reorg(x,stride): 41 | return tf.space_to_depth(x,block_size=stride) 42 | # 或者return tf.extract_image_patches(x,ksizes=[1,stride,stride,1],strides=[1,stride,stride,1], 43 | # rates=[1,1,1,1],padding='VALID') 44 | ######################################################################################################### 45 | 46 | ################################### Darknet19 ########################################################### 47 | # 默认是coco数据集,最后一层维度是anchor_num*(class_num+5)=5*(80+5)=425 48 | def darknet(images,n_last_channels=425): 49 | net = conv2d(images, filters_num=32, filters_size=3, pad_size=1, name='conv1') 50 | net = maxpool(net, size=2, stride=2, name='pool1') 51 | 52 | net = conv2d(net, 64, 3, 1, name='conv2') 53 | net = maxpool(net, 2, 2, name='pool2') 54 | 55 | net = conv2d(net, 128, 3, 1, name='conv3_1') 56 | net = conv2d(net, 64, 1, 0, name='conv3_2') 57 | net = conv2d(net, 128, 3, 1, name='conv3_3') 58 | net = maxpool(net, 2, 2, name='pool3') 59 | 60 | net = conv2d(net, 256, 3, 1, name='conv4_1') 61 | net = conv2d(net, 128, 1, 0, name='conv4_2') 62 | net = conv2d(net, 256, 3, 1, name='conv4_3') 63 | net = maxpool(net, 2, 2, name='pool4') 64 | 65 | net = conv2d(net, 512, 3, 1, name='conv5_1') 66 | net = conv2d(net, 256, 1, 0,name='conv5_2') 67 | net = conv2d(net,512, 3, 1, name='conv5_3') 68 | net = conv2d(net, 256, 1, 0, name='conv5_4') 69 | net = conv2d(net, 512, 3, 1, name='conv5_5') 70 | shortcut = net # 存储这一层特征图,以便后面passthrough层 71 | net = maxpool(net, 2, 2, name='pool5') 72 | 73 | net = conv2d(net, 1024, 3, 1, name='conv6_1') 74 | net = conv2d(net, 512, 1, 0, name='conv6_2') 75 | net = conv2d(net, 1024, 3, 1, name='conv6_3') 76 | net = conv2d(net, 512, 1, 0, name='conv6_4') 77 | net = conv2d(net, 1024, 3, 1, name='conv6_5') 78 | 79 | net = conv2d(net, 1024, 3, 1, name='conv7_1') 80 | net = conv2d(net, 1024, 3, 1, name='conv7_2') 81 | # shortcut增加了一个中间卷积层,先采用64个1*1卷积核进行卷积,然后再进行passthrough处理 82 | # 这样26*26*512 -> 26*26*64 -> 13*13*256的特征图 83 | shortcut = conv2d(shortcut, 64, 1, 0, name='conv_shortcut') 84 | shortcut = reorg(shortcut, 2) 85 | net = tf.concat([shortcut, net], axis=-1) # channel整合到一起 86 | net = conv2d(net, 1024, 3, 1, name='conv8') 87 | 88 | # detection layer:最后用一个1*1卷积去调整channel,该层没有BN层和激活函数 89 | output = conv2d(net, filters_num=n_last_channels, filters_size=1, batch_normalize=False, 90 | activation=None, use_bias=True, name='conv_dec') 91 | 92 | return output 93 | ######################################################################################################### 94 | 95 | if __name__ == '__main__': 96 | x = tf.random_normal([1, 416, 416, 3]) 97 | model_output = darknet(x) 98 | 99 | saver = tf.train.Saver() 100 | with tf.Session() as sess: 101 | # 必须先restore模型才能打印shape;导入模型时,上面每层网络的name不能修改,否则找不到 102 | saver.restore(sess, "./yolo2_model/yolo2_coco.ckpt") 103 | print(sess.run(model_output).shape) # (1,13,13,425) 104 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # -------------------------------------- 3 | # @Time : 2018/5/16$ 14:48$ 4 | # @Author : KOD Chen 5 | # @Email : 821237536@qq.com 6 | # @File : utils$.py 7 | # Description :功能函数,包含:预处理输入图片、筛选边界框NMS、绘制筛选后的边界框。 8 | # -------------------------------------- 9 | 10 | import random 11 | import colorsys 12 | import cv2 13 | import numpy as np 14 | 15 | # 【1】图像预处理(pre process前期处理) 16 | def preprocess_image(image,image_size=(416,416)): 17 | # 复制原图像 18 | image_cp = np.copy(image).astype(np.float32) 19 | 20 | # resize image 21 | image_rgb = cv2.cvtColor(image_cp,cv2.COLOR_BGR2RGB) 22 | image_resized = cv2.resize(image_rgb,image_size) 23 | 24 | # normalize归一化 25 | image_normalized = image_resized.astype(np.float32) / 225.0 26 | 27 | # 增加一个维度在第0维——batch_size 28 | image_expanded = np.expand_dims(image_normalized,axis=0) 29 | 30 | return image_expanded 31 | 32 | # 【2】筛选解码后的回归边界框——NMS(post process后期处理) 33 | def postprocess(bboxes,obj_probs,class_probs,image_shape=(416,416),threshold=0.5): 34 | # bboxes表示为:图片中有多少box就多少行;4列分别是box(xmin,ymin,xmax,ymax) 35 | bboxes = np.reshape(bboxes,[-1,4]) 36 | # 将所有box还原成图片中真实的位置 37 | bboxes[:,0:1] *= float(image_shape[1]) # xmin*width 38 | bboxes[:,1:2] *= float(image_shape[0]) # ymin*height 39 | bboxes[:,2:3] *= float(image_shape[1]) # xmax*width 40 | bboxes[:,3:4] *= float(image_shape[0]) # ymax*height 41 | bboxes = bboxes.astype(np.int32) 42 | 43 | # (1)cut the box:将边界框超出整张图片(0,0)—(415,415)的部分cut掉 44 | bbox_min_max = [0,0,image_shape[1]-1,image_shape[0]-1] 45 | bboxes = bboxes_cut(bbox_min_max,bboxes) 46 | 47 | # ※※※置信度*max类别概率=类别置信度scores※※※ 48 | obj_probs = np.reshape(obj_probs,[-1]) 49 | class_probs = np.reshape(class_probs,[len(obj_probs),-1]) 50 | class_max_index = np.argmax(class_probs,axis=1) # 得到max类别概率对应的维度 51 | class_probs = class_probs[np.arange(len(obj_probs)),class_max_index] 52 | scores = obj_probs * class_probs 53 | 54 | # ※※※类别置信度scores>threshold的边界框bboxes留下※※※ 55 | keep_index = scores > threshold 56 | class_max_index = class_max_index[keep_index] 57 | scores = scores[keep_index] 58 | bboxes = bboxes[keep_index] 59 | 60 | # (2)排序top_k(默认为400) 61 | class_max_index,scores,bboxes = bboxes_sort(class_max_index,scores,bboxes) 62 | # ※※※(3)NMS※※※ 63 | class_max_index,scores,bboxes = bboxes_nms(class_max_index,scores,bboxes) 64 | 65 | return bboxes,scores,class_max_index 66 | 67 | # 【3】绘制筛选后的边界框 68 | def draw_detection(im, bboxes, scores, cls_inds, labels, thr=0.3): 69 | # Generate colors for drawing bounding boxes. 70 | hsv_tuples = [(x/float(len(labels)), 1., 1.) for x in range(len(labels))] 71 | colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples)) 72 | colors = list( 73 | map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),colors)) 74 | random.seed(10101) # Fixed seed for consistent colors across runs. 75 | random.shuffle(colors) # Shuffle colors to decorrelate adjacent classes. 76 | random.seed(None) # Reset seed to default. 77 | # draw image 78 | imgcv = np.copy(im) 79 | h, w, _ = imgcv.shape 80 | for i, box in enumerate(bboxes): 81 | if scores[i] < thr: 82 | continue 83 | cls_indx = cls_inds[i] 84 | 85 | thick = int((h + w) / 300) 86 | cv2.rectangle(imgcv,(box[0], box[1]), (box[2], box[3]),colors[cls_indx], thick) 87 | mess = '%s: %.3f' % (labels[cls_indx], scores[i]) 88 | if box[1] < 20: 89 | text_loc = (box[0] + 2, box[1] + 15) 90 | else: 91 | text_loc = (box[0], box[1] - 10) 92 | # cv2.rectangle(imgcv, (box[0], box[1]-20), ((box[0]+box[2])//3+120, box[1]-8), (125, 125, 125), -1) # puttext函数的背景 93 | cv2.putText(imgcv, mess, text_loc, cv2.FONT_HERSHEY_SIMPLEX, 1e-3*h, (255,255,255), thick//3) 94 | return imgcv 95 | 96 | ######################## 对应【2】:筛选解码后的回归边界框######################################### 97 | # (1)cut the box:将边界框超出整张图片(0,0)—(415,415)的部分cut掉 98 | def bboxes_cut(bbox_min_max,bboxes): 99 | bboxes = np.copy(bboxes) 100 | bboxes = np.transpose(bboxes) 101 | bbox_min_max = np.transpose(bbox_min_max) 102 | # cut the box 103 | bboxes[0] = np.maximum(bboxes[0],bbox_min_max[0]) # xmin 104 | bboxes[1] = np.maximum(bboxes[1],bbox_min_max[1]) # ymin 105 | bboxes[2] = np.minimum(bboxes[2],bbox_min_max[2]) # xmax 106 | bboxes[3] = np.minimum(bboxes[3],bbox_min_max[3]) # ymax 107 | bboxes = np.transpose(bboxes) 108 | return bboxes 109 | 110 | # (2)按类别置信度scores降序,对边界框进行排序并仅保留top_k 111 | def bboxes_sort(classes,scores,bboxes,top_k=400): 112 | index = np.argsort(-scores) 113 | classes = classes[index][:top_k] 114 | scores = scores[index][:top_k] 115 | bboxes = bboxes[index][:top_k] 116 | return classes,scores,bboxes 117 | 118 | # (3)计算IOU+NMS 119 | # 计算两个box的IOU 120 | def bboxes_iou(bboxes1,bboxes2): 121 | bboxes1 = np.transpose(bboxes1) 122 | bboxes2 = np.transpose(bboxes2) 123 | 124 | # 计算两个box的交集:交集左上角的点取两个box的max,交集右下角的点取两个box的min 125 | int_ymin = np.maximum(bboxes1[0], bboxes2[0]) 126 | int_xmin = np.maximum(bboxes1[1], bboxes2[1]) 127 | int_ymax = np.minimum(bboxes1[2], bboxes2[2]) 128 | int_xmax = np.minimum(bboxes1[3], bboxes2[3]) 129 | 130 | # 计算两个box交集的wh:如果两个box没有交集,那么wh为0(按照计算方式wh为负数,跟0比较取最大值) 131 | int_h = np.maximum(int_ymax-int_ymin,0.) 132 | int_w = np.maximum(int_xmax-int_xmin,0.) 133 | 134 | # 计算IOU 135 | int_vol = int_h * int_w # 交集面积 136 | vol1 = (bboxes1[2] - bboxes1[0]) * (bboxes1[3] - bboxes1[1]) # bboxes1面积 137 | vol2 = (bboxes2[2] - bboxes2[0]) * (bboxes2[3] - bboxes2[1]) # bboxes2面积 138 | IOU = int_vol / (vol1 + vol2 - int_vol) # IOU=交集/并集 139 | return IOU 140 | # NMS,或者用tf.image.non_max_suppression(boxes, scores,self.max_output_size, self.iou_threshold) 141 | def bboxes_nms(classes, scores, bboxes, nms_threshold=0.5): 142 | keep_bboxes = np.ones(scores.shape, dtype=np.bool) 143 | for i in range(scores.size-1): 144 | if keep_bboxes[i]: 145 | # Computer overlap with bboxes which are following. 146 | overlap = bboxes_iou(bboxes[i], bboxes[(i+1):]) 147 | # Overlap threshold for keeping + checking part of the same class 148 | keep_overlap = np.logical_or(overlap < nms_threshold, classes[(i+1):] != classes[i]) 149 | keep_bboxes[(i+1):] = np.logical_and(keep_bboxes[(i+1):], keep_overlap) 150 | 151 | idxes = np.where(keep_bboxes) 152 | return classes[idxes], scores[idxes], bboxes[idxes] 153 | ################################################################################################### 154 | --------------------------------------------------------------------------------