├── FPS_test.py
├── README.md
├── contrast.png
├── dark_result.jpg
├── dehaze.png
├── get_dr_txt.py
├── get_gt_txt.py
├── get_map.py
├── kmeans_for_anchors.py
├── light.py
├── predict.py
├── test.py
├── train.py
├── video.py
├── voc_annotation.py
├── yolo.py
└── yolo1.py

/FPS_test.py:
--------------------------------------------------------------------------------
import time
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont

from utils.utils import (DecodeBox, bbox_iou, letterbox_image,
                         non_max_suppression, yolo_correct_boxes)
from yolo import YOLO

'''
This FPS test excludes pre-processing (normalization and resizing) and drawing.
What it covers: network inference, score thresholding and non-maximum suppression.
The test procedure follows https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch
The FPS measured in video.py will be lower than this one, because the camera read
rate is limited and that pipeline also includes pre-processing and drawing.
'''


class FPS_YOLO(YOLO):
    def get_FPS(self, image, test_interval):
        image_shape = np.array(np.shape(image)[0:2])
        if self.letterbox_image:
            crop_img = np.array(letterbox_image(image, (self.model_image_size[1], self.model_image_size[0])))
        else:
            crop_img = image.convert('RGB')
            crop_img = crop_img.resize((self.model_image_size[1], self.model_image_size[0]), Image.BICUBIC)
        photo = np.array(crop_img, dtype=np.float32) / 255.0
        photo = np.transpose(photo, (2, 0, 1))
        images = [photo]

        # warm-up pass before timing starts
        with torch.no_grad():
            images = torch.from_numpy(np.asarray(images))
            if self.cuda:
                images = images.cuda()

            outputs = self.net(images)
            output_list = []
            for i in range(2):
                output_list.append(self.yolo_decodes[i](outputs[i]))

            output = torch.cat(output_list, 1)
            batch_detections = non_max_suppression(output, len(self.class_names),
                                                   conf_thres=self.confidence,
                                                   nms_thres=self.iou)
            try:
                batch_detections = batch_detections[0].cpu().numpy()
                top_index = batch_detections[:, 4] * batch_detections[:, 5] > self.confidence
                top_conf = batch_detections[top_index, 4] * batch_detections[top_index, 5]
                top_label = np.array(batch_detections[top_index, -1], np.int32)
                top_bboxes = np.array(batch_detections[top_index, :4])
                top_xmin, top_ymin, top_xmax, top_ymax = np.expand_dims(top_bboxes[:, 0], -1), np.expand_dims(
                    top_bboxes[:, 1], -1), np.expand_dims(top_bboxes[:, 2], -1), np.expand_dims(top_bboxes[:, 3], -1)

                if self.letterbox_image:
                    boxes = yolo_correct_boxes(top_ymin, top_xmin, top_ymax, top_xmax,
                                               np.array([self.model_image_size[0], self.model_image_size[1]]),
                                               image_shape)
                else:
                    top_xmin = top_xmin / self.model_image_size[1] * image_shape[1]
                    top_ymin = top_ymin / self.model_image_size[0] * image_shape[0]
                    top_xmax = top_xmax / self.model_image_size[1] * image_shape[1]
                    top_ymax = top_ymax / self.model_image_size[0] * image_shape[0]
                    boxes = np.concatenate([top_ymin, top_xmin, top_ymax, top_xmax], axis=-1)
            except:
                pass
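        # (added note) The timed loop below wraps asynchronous CUDA calls with
        # time.time(). A common refinement, an assumption here rather than part
        # of the original script, is to synchronize before reading each
        # timestamp so that queued kernels are included in the measurement:
        #
        #     if self.cuda:
        #         torch.cuda.synchronize()
        #     t1 = time.time()
        #     ...timed loop...
        #     if self.cuda:
        #         torch.cuda.synchronize()
        #     t2 = time.time()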
        t1 = time.time()
        for _ in range(test_interval):
            with torch.no_grad():
                outputs = self.net(images)
                output_list = []
                for i in range(2):
                    output_list.append(self.yolo_decodes[i](outputs[i]))

                output = torch.cat(output_list, 1)
                batch_detections = non_max_suppression(output, len(self.class_names),
                                                       conf_thres=self.confidence,
                                                       nms_thres=self.iou)
                try:
                    batch_detections = batch_detections[0].cpu().numpy()
                    top_index = batch_detections[:, 4] * batch_detections[:, 5] > self.confidence
                    top_conf = batch_detections[top_index, 4] * batch_detections[top_index, 5]
                    top_label = np.array(batch_detections[top_index, -1], np.int32)
                    top_bboxes = np.array(batch_detections[top_index, :4])
                    top_xmin, top_ymin, top_xmax, top_ymax = np.expand_dims(top_bboxes[:, 0], -1), np.expand_dims(
                        top_bboxes[:, 1], -1), np.expand_dims(top_bboxes[:, 2], -1), np.expand_dims(top_bboxes[:, 3],
                                                                                                    -1)

                    if self.letterbox_image:
                        boxes = yolo_correct_boxes(top_ymin, top_xmin, top_ymax, top_xmax,
                                                   np.array([self.model_image_size[0], self.model_image_size[1]]),
                                                   image_shape)
                    else:
                        top_xmin = top_xmin / self.model_image_size[1] * image_shape[1]
                        top_ymin = top_ymin / self.model_image_size[0] * image_shape[0]
                        top_xmax = top_xmax / self.model_image_size[1] * image_shape[1]
                        top_ymax = top_ymax / self.model_image_size[0] * image_shape[0]
                        boxes = np.concatenate([top_ymin, top_xmin, top_ymax, top_xmax], axis=-1)
                except:
                    pass
        t2 = time.time()
        tact_time = (t2 - t1) / test_interval
        return tact_time


yolo = FPS_YOLO()
test_interval = 100
img = Image.open('img/1.jpg')
tact_time = yolo.get_FPS(img, test_interval)
print(str(tact_time) + ' seconds, ' + str(1 / tact_time) + ' FPS, @batch_size 1')

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## YOLOV4-Tiny: a PyTorch implementation of the You Only Look Once-Tiny object detection model
---

**Update, 7 February 2021:**
**A careful comparison against the darknet repository's network structure showed that the order of P5_Upsample and feat1 was swapped; this has been corrected and the weights retrained. A letterbox_image option was added; with letterbox_image disabled the network's mAP improves.**

## Contents
1. [Performance](#performance)
2. [Environment](#environment)
3. [Attention](#attention)
4. [TricksSet](#tricksset)
5. [Download](#download)
6. [How2predict](#how2predict)
7. [How2train](#how2train)
8. [How2eval](#how2eval)
9. [Reference](#reference)

## Performance
| Training dataset | Weights file | Test dataset | Input size | mAP 0.5:0.95 | mAP 0.5 |
| :-----: | :-----: | :------: | :------: | :------: | :-----: |
| VOC07+12+COCO | [yolov4_tiny_weights_voc.pth](https://github.com/bubbliiiing/yolov4-tiny-pytorch/releases/download/v1.0/yolov4_tiny_weights_voc.pth) | VOC-Test07 | 416x416 | - | 77.8
| COCO-Train2017 | [yolov4_tiny_weights_coco.pth](https://github.com/bubbliiiing/yolov4-tiny-pytorch/releases/download/v1.0/yolov4_tiny_weights_coco.pth) | COCO-Val2017 | 416x416 | 21.5 | 41.0

## Environment
torch==1.2.0

## Attention
The yolov4_tiny_weights_coco.pth and yolov4_tiny_weights_voc.pth used by the code were trained on 416x416 images.

## TricksSet
In train.py:
1. The mosaic parameter controls whether Mosaic data augmentation is used.
2. Cosine_scheduler controls whether cosine-annealing learning-rate decay is used.
3. label_smoothing controls whether label smoothing is applied.
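For reference, the corresponding switches as they appear in this repository's train.py (the README name Cosine_scheduler corresponds to the variable Cosine_lr there; the values below are this repository's defaults):
```python
mosaic = False       # Mosaic data augmentation (kept off here: noted as unstable)
Cosine_lr = False    # cosine-annealing learning-rate decay
smooth_label = 0     # label-smoothing factor; 0 disables smoothing
```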
## Download
The yolov4_tiny_weights_coco.pth and yolov4_tiny_weights_voc.pth needed for training can be downloaded from Baidu Netdisk.
Link: https://pan.baidu.com/s/1B37A_-Fcx8TsAK-M4hm90g  Extraction code: 5te5

The VOC dataset can be downloaded here:
VOC2007+2012 training set
Link: https://pan.baidu.com/s/16pemiBGd-P9q2j7dZKGDFA  Extraction code: eiw9

VOC2007 test set
Link: https://pan.baidu.com/s/1BnMiFwlNwIWG9gsd4jHLig  Extraction code: dsda

## How2predict
### a) Using the pretrained weights
1. After downloading and unpacking the repository, download yolov4_tiny_voc.pth from Baidu Netdisk, place it in model_data, run predict.py and enter
```python
img/street.jpg
```
2. video.py can be used for camera-based detection.
### b) Using your own trained weights
1. Train following the training steps.
2. In yolo.py, modify model_path and classes_path so that they point to your trained files; **model_path is the weights file under the logs folder, and classes_path lists the classes that model_path was trained on**.
```python
_defaults = {
    "model_path": 'model_data/yolov4_tiny_weights_coco.pth',
    "anchors_path": 'model_data/yolo_anchors.txt',
    "classes_path": 'model_data/coco_classes.txt',
    "score": 0.5,
    "iou": 0.3,
    # use 416x416 if GPU memory is small
    # use 608x608 if GPU memory is large
    "model_image_size": (416, 416)
}
```
3. Run predict.py and enter
```python
img/street.jpg
```
4. video.py can be used for camera-based detection.

## How2train
1. Training uses the VOC format.
2. Before training, put the annotation files into the Annotation folder under VOCdevkit/VOC2007.
3. Before training, put the image files into the JPEGImages folder under VOCdevkit/VOC2007.
4. Before training, run voc2yolo4.py to generate the corresponding txt files.
5. Then run voc_annotation.py in the repository root, changing classes to your own classes first. **Do not use Chinese labels, and make sure folder names contain no spaces!**
```python
classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
```
6. This generates 2007_train.txt; each line holds an **image path** followed by its **ground-truth boxes** (an illustrative line is shown right after this list).
7. **Before training you must create a txt file under model_data listing the classes to detect, and point classes_path in train.py at it**, for example:
```python
classes_path = 'model_data/new_classes.txt'
```
model_data/new_classes.txt then contains:
```python
cat
dog
...
```
8. Run train.py to start training.
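For reference, a line of 2007_train.txt as generated by voc_annotation.py in this family of repositories typically looks as follows; the path and box values are illustrative only (format assumed: image path, then space-separated boxes as x_min,y_min,x_max,y_max,class_id):
```python
VOCdevkit/VOC2007/JPEGImages/000005.jpg 263,211,324,339,8 165,264,253,372,8
```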
## How2eval
The evaluation procedure follows the video https://www.bilibili.com/video/BV1zE411u7Vw
The steps are the same; get_dr_txt.py, get_gt_txt.py and the related files are already provided, so there is no need to create them yourself.
1. Evaluation uses the VOC format.
2. Before evaluating, put the annotation files into the Annotation folder under VOCdevkit/VOC2007.
3. Before evaluating, put the image files into the JPEGImages folder under VOCdevkit/VOC2007.
4. Before evaluating, run voc2yolo4.py to generate the txt used for evaluation, VOCdevkit/VOC2007/ImageSets/Main/test.txt. Note that if the whole VOC2007 dataset is to be used for evaluation, simply set trainval_percent to 0.
5. In yolo.py, modify model_path and classes_path so that they point to your trained files; **model_path is the weights file under the logs folder, and classes_path lists the classes that model_path was trained on**.
6. Run get_dr_txt.py and get_gt_txt.py to generate the corresponding txt files under ./input/detection-results and ./input/ground-truth.
7. Run get_map.py to compute the model's mAP.

## mAP computation update
The get_gt_txt.py, get_dr_txt.py and get_map.py files have been updated.
get_map.py is cloned from https://github.com/Cartucho/mAP
The mAP computation procedure is explained in: https://www.bilibili.com/video/BV1zE411u7Vw
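As get_map.py's own docstring notes, the IoU threshold used for matching detections to ground truth is controlled by MINOVERLAP; to report mAP at a stricter IoU, for example mAP0.75, change one line in get_map.py:
```python
MINOVERLAP = 0.75
```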
## Reference
https://github.com/qqwweee/keras-yolo3/
https://github.com/Cartucho/mAP
https://github.com/Ma-Dan/keras-yolo4

--------------------------------------------------------------------------------
/contrast.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ichangchangchang/yolov4-tiny-pytorch-master/8febc690e6f8832894c1066d0d14e7f3f8e08c6a/contrast.png

--------------------------------------------------------------------------------
/dark_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ichangchangchang/yolov4-tiny-pytorch-master/8febc690e6f8832894c1066d0d14e7f3f8e08c6a/dark_result.jpg

--------------------------------------------------------------------------------
/dehaze.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ichangchangchang/yolov4-tiny-pytorch-master/8febc690e6f8832894c1066d0d14e7f3f8e08c6a/dehaze.png

--------------------------------------------------------------------------------
/get_dr_txt.py:
--------------------------------------------------------------------------------
import numpy as np
import os
import torch
from yolo import YOLO
from PIL import Image, ImageFont, ImageDraw
from utils.utils import non_max_suppression, bbox_iou, DecodeBox, letterbox_image, yolo_correct_boxes
from tqdm import tqdm

class mAP_Yolo(YOLO):
    # ---------------------------------------------------#
    #   Detect a single image
    # ---------------------------------------------------#
    def detect_image(self, image_id, image):
        self.confidence = 0.01
        self.iou = 0.5
        f = open("./input/detection-results/" + image_id + ".txt", "w")
        image_shape = np.array(np.shape(image)[0:2])

        # ---------------------------------------------------------#
        #   Add grey bars to the image for a distortion-free resize
        #   (a plain resize can also be used for recognition)
        # ---------------------------------------------------------#
        if self.letterbox_image:
            crop_img = np.array(letterbox_image(image, (self.model_image_size[1], self.model_image_size[0])))
        else:
            crop_img = image.convert('RGB')
            crop_img = crop_img.resize((self.model_image_size[1], self.model_image_size[0]), Image.BICUBIC)
        photo = np.array(crop_img, dtype=np.float32) / 255.0
        photo = np.transpose(photo, (2, 0, 1))
        # ---------------------------------------------------------#
        #   Add the batch_size dimension
        # ---------------------------------------------------------#
        images = [photo]

        with torch.no_grad():
            images = torch.from_numpy(np.asarray(images))
            if self.cuda:
                images = images.cuda()

            # ---------------------------------------------------------#
            #   Feed the image into the network for prediction
            # ---------------------------------------------------------#
            outputs = self.net(images)
            output_list = []
            for i in range(2):
                output_list.append(self.yolo_decodes[i](outputs[i]))

            # ---------------------------------------------------------#
            #   Stack the predicted boxes, then run non-maximum suppression
            # ---------------------------------------------------------#
            output = torch.cat(output_list, 1)
            batch_detections = non_max_suppression(output, len(self.class_names),
                                                   conf_thres=self.confidence,
                                                   nms_thres=self.iou)

            try:
                batch_detections = batch_detections[0].cpu().numpy()
            except:
                return

            # ---------------------------------------------------------#
            #   Filter the predicted boxes by score
            # ---------------------------------------------------------#
            top_index = batch_detections[:, 4] * batch_detections[:, 5] > self.confidence
            top_conf = batch_detections[top_index, 4] * batch_detections[top_index, 5]
            top_label = np.array(batch_detections[top_index, -1], np.int32)
            top_bboxes = np.array(batch_detections[top_index, :4])
            top_xmin, top_ymin, top_xmax, top_ymax = np.expand_dims(top_bboxes[:, 0], -1), np.expand_dims(
                top_bboxes[:, 1], -1), np.expand_dims(top_bboxes[:, 2], -1), np.expand_dims(top_bboxes[:, 3], -1)

            # -----------------------------------------------------------------#
            #   letterbox_image adds grey bars around the image before it is
            #   passed to the network, so top_bboxes is relative to the padded
            #   image and the grey-bar offset has to be removed.
            # -----------------------------------------------------------------#
            if self.letterbox_image:
                boxes = yolo_correct_boxes(top_ymin, top_xmin, top_ymax, top_xmax,
                                           np.array([self.model_image_size[0], self.model_image_size[1]]), image_shape)
            else:
                top_xmin = top_xmin / self.model_image_size[1] * image_shape[1]
                top_ymin = top_ymin / self.model_image_size[0] * image_shape[0]
                top_xmax = top_xmax / self.model_image_size[1] * image_shape[1]
                top_ymax = top_ymax / self.model_image_size[0] * image_shape[0]
                boxes = np.concatenate([top_ymin, top_xmin, top_ymax, top_xmax], axis=-1)

        for i, c in enumerate(top_label):
            predicted_class = self.class_names[c]
            score = str(top_conf[i])

            top, left, bottom, right = boxes[i]
            f.write("%s %s %s %s %s %s\n" % (
                predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)), str(int(bottom))))

        f.close()
        return


yolo = mAP_Yolo()
image_ids = open('VOCdevkit/VOC2007/ImageSets/Main/test.txt').read().strip().split()

if not os.path.exists("./input"):
    os.makedirs("./input")
if not os.path.exists("./input/detection-results"):
    os.makedirs("./input/detection-results")
if not os.path.exists("./input/images-optional"):
    os.makedirs("./input/images-optional")

for image_id in tqdm(image_ids):

    if "to" in image_id:
        image_path = "./VOCdevkit/VOC2007/JPEGImages/" + image_id + ".png"
    else:
        image_path = "./VOCdevkit/VOC2007/JPEGImages/" + image_id + ".jpg"
    image = Image.open(image_path)
    # uncomment to save the images so the later mAP computation can be visualized
    # image.save("./input/images-optional/" + image_id + ".jpg")
    yolo.detect_image(image_id, image)

print("Conversion completed!")
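# (added note) Each line of ./input/detection-results/<image_id>.txt written above
# holds one detection as "class score left top right bottom", for example
# (illustrative values only):
#
#     car 0.9987 156 48 412 300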
--------------------------------------------------------------------------------
/get_gt_txt.py:
--------------------------------------------------------------------------------
import sys
import os
import glob
import xml.etree.ElementTree as ET


# ---------------------------------------------------#
#   Get the classes
# ---------------------------------------------------#
def get_classes(classes_path):
    '''loads the classes'''
    with open(classes_path) as f:
        class_names = f.readlines()
    class_names = [c.strip() for c in class_names]
    return class_names


image_ids = open('VOCdevkit/VOC2007/ImageSets/Main/test.txt').read().strip().split()

if not os.path.exists("./input"):
    os.makedirs("./input")
if not os.path.exists("./input/ground-truth"):
    os.makedirs("./input/ground-truth")

for image_id in image_ids:
    with open("./input/ground-truth/" + image_id + ".txt", "w") as new_f:
        root = ET.parse("VOCdevkit/VOC2007/Annotations/" + image_id + ".xml").getroot()
        for obj in root.findall('object'):
            difficult_flag = False
            if obj.find('difficult') != None:
                difficult = obj.find('difficult').text
                if int(difficult) == 1:
                    difficult_flag = True
            obj_name = obj.find('name').text

            bndbox = obj.find('bndbox')
            left = bndbox.find('xmin').text
            top = bndbox.find('ymin').text
            right = bndbox.find('xmax').text
            bottom = bndbox.find('ymax').text

            if difficult_flag:
                new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom))
            else:
                new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))

print("Conversion completed!")

--------------------------------------------------------------------------------
/get_map.py:
--------------------------------------------------------------------------------
1 | import glob 2 | import json 3 | import os 4 | import shutil 5 | import operator 6 | import sys 7 | import argparse 8 | import math 9 | 10 | import numpy as np 11 | 12 | ''' 13 | Computes mAP. 14 | The code is cloned from https://github.com/Cartucho/mAP 15 | To evaluate at mAP0.x, e.g. mAP0.75, set MINOVERLAP = 0.75. 16 | ''' 17 | MINOVERLAP = 0.5 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('-na', '--no-animation', help="no animation is shown.", action="store_true") 21 | parser.add_argument('-np', '--no-plot', help="no plot is shown.", action="store_true") 22 | parser.add_argument('-q', '--quiet', help="minimalistic console output.", action="store_true") 23 | parser.add_argument('-i', '--ignore', nargs='+', type=str, help="ignore a list of classes.") 24 | parser.add_argument('--set-class-iou', nargs='+', type=str, help="set IoU for a specific class.") 25 | args = parser.parse_args() 26 | 27 | ''' 28 | 0,0 ------> x (width) 29 | | 30 | | (Left,Top) 31 | | *_________ 32 | | | | 33 | | | 34 | y |_________| 35 | (height) * 36 | (Right,Bottom) 37 | ''' 38 | 39 | if args.ignore is None: 40 | args.ignore = [] 41 | 42 | specific_iou_flagged = False 43 | if args.set_class_iou is not None: 44 | specific_iou_flagged = True 45 | 46 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 47 | 48 | GT_PATH = os.path.join(os.getcwd(), 'input', 'ground-truth') 49 | DR_PATH = os.path.join(os.getcwd(), 'input', 'detection-results') 50 | IMG_PATH = os.path.join(os.getcwd(), 'input', 'images-optional') 51 | if os.path.exists(IMG_PATH): 52 | for dirpath, dirnames, files in os.walk(IMG_PATH): 53 | if not files: 54 | args.no_animation = True 55 | else: 56 | args.no_animation = True 57 | 58 | show_animation = False 59 | if not args.no_animation: 60 | try: 61 | import cv2 62 | show_animation = True 63 | except ImportError: 64 | print("\"opencv-python\" not found, please install to visualize the results.")
65 | args.no_animation = True 66 | 67 | draw_plot = False 68 | if not args.no_plot: 69 | try: 70 | import matplotlib.pyplot as plt 71 | draw_plot = True 72 | except ImportError: 73 | print("\"matplotlib\" not found, please install it to get the resulting plots.") 74 | args.no_plot = True 75 | 76 | 77 | def log_average_miss_rate(precision, fp_cumsum, num_images): 78 | """ 79 | log-average miss rate: 80 | Calculated by averaging miss rates at 9 evenly spaced FPPI points 81 | between 10e-2 and 10e0, in log-space. 82 | 83 | output: 84 | lamr | log-average miss rate 85 | mr | miss rate 86 | fppi | false positives per image 87 | 88 | references: 89 | [1] Dollar, Piotr, et al. "Pedestrian Detection: An Evaluation of the 90 | State of the Art." Pattern Analysis and Machine Intelligence, IEEE 91 | Transactions on 34.4 (2012): 743 - 761. 92 | """ 93 | 94 | if precision.size == 0: 95 | lamr = 0 96 | mr = 1 97 | fppi = 0 98 | return lamr, mr, fppi 99 | 100 | fppi = fp_cumsum / float(num_images) 101 | mr = (1 - precision) 102 | 103 | fppi_tmp = np.insert(fppi, 0, -1.0) 104 | mr_tmp = np.insert(mr, 0, 1.0) 105 | 106 | ref = np.logspace(-2.0, 0.0, num = 9) 107 | for i, ref_i in enumerate(ref): 108 | j = np.where(fppi_tmp <= ref_i)[-1][-1] 109 | ref[i] = mr_tmp[j] 110 | 111 | lamr = math.exp(np.mean(np.log(np.maximum(1e-10, ref)))) 112 | 113 | return lamr, mr, fppi 114 | 115 | """ 116 | throw error and exit 117 | """ 118 | def error(msg): 119 | print(msg) 120 | sys.exit(0) 121 | 122 | """ 123 | check if the number is a float between 0.0 and 1.0 124 | """ 125 | def is_float_between_0_and_1(value): 126 | try: 127 | val = float(value) 128 | if val > 0.0 and val < 1.0: 129 | return True 130 | else: 131 | return False 132 | except ValueError: 133 | return False 134 | 135 | """ 136 | Calculate the AP given the recall and precision array 137 | 1st) We compute a version of the measured precision/recall curve with 138 | precision monotonically decreasing 139 | 2nd) We compute the AP as the area under this curve by numerical integration. 
140 | """ 141 | def voc_ap(rec, prec): 142 | """ 143 | --- Official matlab code VOC2012--- 144 | mrec=[0 ; rec ; 1]; 145 | mpre=[0 ; prec ; 0]; 146 | for i=numel(mpre)-1:-1:1 147 | mpre(i)=max(mpre(i),mpre(i+1)); 148 | end 149 | i=find(mrec(2:end)~=mrec(1:end-1))+1; 150 | ap=sum((mrec(i)-mrec(i-1)).*mpre(i)); 151 | """ 152 | rec.insert(0, 0.0) # insert 0.0 at begining of list 153 | rec.append(1.0) # insert 1.0 at end of list 154 | mrec = rec[:] 155 | prec.insert(0, 0.0) # insert 0.0 at begining of list 156 | prec.append(0.0) # insert 0.0 at end of list 157 | mpre = prec[:] 158 | """ 159 | This part makes the precision monotonically decreasing 160 | (goes from the end to the beginning) 161 | matlab: for i=numel(mpre)-1:-1:1 162 | mpre(i)=max(mpre(i),mpre(i+1)); 163 | """ 164 | for i in range(len(mpre)-2, -1, -1): 165 | mpre[i] = max(mpre[i], mpre[i+1]) 166 | """ 167 | This part creates a list of indexes where the recall changes 168 | matlab: i=find(mrec(2:end)~=mrec(1:end-1))+1; 169 | """ 170 | i_list = [] 171 | for i in range(1, len(mrec)): 172 | if mrec[i] != mrec[i-1]: 173 | i_list.append(i) # if it was matlab would be i + 1 174 | """ 175 | The Average Precision (AP) is the area under the curve 176 | (numerical integration) 177 | matlab: ap=sum((mrec(i)-mrec(i-1)).*mpre(i)); 178 | """ 179 | ap = 0.0 180 | for i in i_list: 181 | ap += ((mrec[i]-mrec[i-1])*mpre[i]) 182 | return ap, mrec, mpre 183 | 184 | 185 | """ 186 | Convert the lines of a file to a list 187 | """ 188 | def file_lines_to_list(path): 189 | # open txt file lines to a list 190 | with open(path) as f: 191 | content = f.readlines() 192 | # remove whitespace characters like `\n` at the end of each line 193 | content = [x.strip() for x in content] 194 | return content 195 | 196 | """ 197 | Draws text in image 198 | """ 199 | def draw_text_in_image(img, text, pos, color, line_width): 200 | font = cv2.FONT_HERSHEY_PLAIN 201 | fontScale = 1 202 | lineType = 1 203 | bottomLeftCornerOfText = pos 204 | cv2.putText(img, text, 205 | bottomLeftCornerOfText, 206 | font, 207 | fontScale, 208 | color, 209 | lineType) 210 | text_width, _ = cv2.getTextSize(text, font, fontScale, lineType)[0] 211 | return img, (line_width + text_width) 212 | 213 | """ 214 | Plot - adjust axes 215 | """ 216 | def adjust_axes(r, t, fig, axes): 217 | # get text width for re-scaling 218 | bb = t.get_window_extent(renderer=r) 219 | text_width_inches = bb.width / fig.dpi 220 | # get axis width in inches 221 | current_fig_width = fig.get_figwidth() 222 | new_fig_width = current_fig_width + text_width_inches 223 | propotion = new_fig_width / current_fig_width 224 | # get axis limit 225 | x_lim = axes.get_xlim() 226 | axes.set_xlim([x_lim[0], x_lim[1]*propotion]) 227 | 228 | """ 229 | Draw plot using Matplotlib 230 | """ 231 | def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar): 232 | # sort the dictionary by decreasing value, into a list of tuples 233 | sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1)) 234 | # unpacking the list of tuples into two lists 235 | sorted_keys, sorted_values = zip(*sorted_dic_by_value) 236 | # 237 | if true_p_bar != "": 238 | """ 239 | Special case to draw in: 240 | - green -> TP: True Positives (object detected and matches ground-truth) 241 | - red -> FP: False Positives (object detected but does not match ground-truth) 242 | - orange -> FN: False Negatives (object not detected but present in the ground-truth) 243 | """ 244 | fp_sorted = [] 
245 | tp_sorted = [] 246 | for key in sorted_keys: 247 | fp_sorted.append(dictionary[key] - true_p_bar[key]) 248 | tp_sorted.append(true_p_bar[key]) 249 | plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Positive') 250 | plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Positive', left=fp_sorted) 251 | # add legend 252 | plt.legend(loc='lower right') 253 | """ 254 | Write number on side of bar 255 | """ 256 | fig = plt.gcf() # gcf - get current figure 257 | axes = plt.gca() 258 | r = fig.canvas.get_renderer() 259 | for i, val in enumerate(sorted_values): 260 | fp_val = fp_sorted[i] 261 | tp_val = tp_sorted[i] 262 | fp_str_val = " " + str(fp_val) 263 | tp_str_val = fp_str_val + " " + str(tp_val) 264 | # trick to paint multicolor with offset: 265 | # first paint everything and then repaint the first number 266 | t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold') 267 | plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold') 268 | if i == (len(sorted_values)-1): # largest bar 269 | adjust_axes(r, t, fig, axes) 270 | else: 271 | plt.barh(range(n_classes), sorted_values, color=plot_color) 272 | """ 273 | Write number on side of bar 274 | """ 275 | fig = plt.gcf() # gcf - get current figure 276 | axes = plt.gca() 277 | r = fig.canvas.get_renderer() 278 | for i, val in enumerate(sorted_values): 279 | str_val = " " + str(val) # add a space before 280 | if val < 1.0: 281 | str_val = " {0:.2f}".format(val) 282 | t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold') 283 | # re-set axes to show number inside the figure 284 | if i == (len(sorted_values)-1): # largest bar 285 | adjust_axes(r, t, fig, axes) 286 | # set window title 287 | fig.canvas.set_window_title(window_title) 288 | # write classes in y axis 289 | tick_font_size = 12 290 | plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size) 291 | """ 292 | Re-scale height accordingly 293 | """ 294 | init_height = fig.get_figheight() 295 | # comput the matrix height in points and inches 296 | dpi = fig.dpi 297 | height_pt = n_classes * (tick_font_size * 1.4) # 1.4 (some spacing) 298 | height_in = height_pt / dpi 299 | # compute the required figure height 300 | top_margin = 0.15 # in percentage of the figure height 301 | bottom_margin = 0.05 # in percentage of the figure height 302 | figure_height = height_in / (1 - top_margin - bottom_margin) 303 | # set new height 304 | if figure_height > init_height: 305 | fig.set_figheight(figure_height) 306 | 307 | # set plot title 308 | plt.title(plot_title, fontsize=14) 309 | # set axis titles 310 | # plt.xlabel('classes') 311 | plt.xlabel(x_label, fontsize='large') 312 | # adjust size of window 313 | fig.tight_layout() 314 | # save the plot 315 | fig.savefig(output_path) 316 | # show image 317 | if to_show: 318 | plt.show() 319 | # close the plot 320 | plt.close() 321 | 322 | """ 323 | Create a ".temp_files/" and "results/" directory 324 | """ 325 | TEMP_FILES_PATH = ".temp_files" 326 | if not os.path.exists(TEMP_FILES_PATH): # if it doesn't exist already 327 | os.makedirs(TEMP_FILES_PATH) 328 | results_files_path = "results" 329 | if os.path.exists(results_files_path): # if it exist already 330 | # reset the results directory 331 | shutil.rmtree(results_files_path) 332 | 333 | os.makedirs(results_files_path) 334 | if draw_plot: 335 | os.makedirs(os.path.join(results_files_path, "AP")) 336 | os.makedirs(os.path.join(results_files_path, "F1")) 337 
| os.makedirs(os.path.join(results_files_path, "Recall")) 338 | os.makedirs(os.path.join(results_files_path, "Precision")) 339 | if show_animation: 340 | os.makedirs(os.path.join(results_files_path, "images", "detections_one_by_one")) 341 | 342 | """ 343 | ground-truth 344 | Load each of the ground-truth files into a temporary ".json" file. 345 | Create a list of all the class names present in the ground-truth (gt_classes). 346 | """ 347 | # get a list with the ground-truth files 348 | ground_truth_files_list = glob.glob(GT_PATH + '/*.txt') 349 | if len(ground_truth_files_list) == 0: 350 | error("Error: No ground-truth files found!") 351 | ground_truth_files_list.sort() 352 | # dictionary with counter per class 353 | gt_counter_per_class = {} 354 | counter_images_per_class = {} 355 | 356 | for txt_file in ground_truth_files_list: 357 | #print(txt_file) 358 | file_id = txt_file.split(".txt", 1)[0] 359 | file_id = os.path.basename(os.path.normpath(file_id)) 360 | # check if there is a correspondent detection-results file 361 | temp_path = os.path.join(DR_PATH, (file_id + ".txt")) 362 | if not os.path.exists(temp_path): 363 | error_msg = "Error. File not found: {}\n".format(temp_path) 364 | error_msg += "(You can avoid this error message by running extra/intersect-gt-and-dr.py)" 365 | error(error_msg) 366 | lines_list = file_lines_to_list(txt_file) 367 | # create ground-truth dictionary 368 | bounding_boxes = [] 369 | is_difficult = False 370 | already_seen_classes = [] 371 | for line in lines_list: 372 | try: 373 | if "difficult" in line: 374 | class_name, left, top, right, bottom, _difficult = line.split() 375 | is_difficult = True 376 | else: 377 | class_name, left, top, right, bottom = line.split() 378 | 379 | except: 380 | if "difficult" in line: 381 | line_split = line.split() 382 | _difficult = line_split[-1] 383 | bottom = line_split[-2] 384 | right = line_split[-3] 385 | top = line_split[-4] 386 | left = line_split[-5] 387 | class_name = "" 388 | for name in line_split[:-5]: 389 | class_name += name + " " 390 | class_name = class_name[:-1] 391 | is_difficult = True 392 | else: 393 | line_split = line.split() 394 | bottom = line_split[-1] 395 | right = line_split[-2] 396 | top = line_split[-3] 397 | left = line_split[-4] 398 | class_name = "" 399 | for name in line_split[:-4]: 400 | class_name += name + " " 401 | class_name = class_name[:-1] 402 | if class_name in args.ignore: 403 | continue 404 | bbox = left + " " + top + " " + right + " " +bottom 405 | if is_difficult: 406 | bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False, "difficult":True}) 407 | is_difficult = False 408 | else: 409 | bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False}) 410 | if class_name in gt_counter_per_class: 411 | gt_counter_per_class[class_name] += 1 412 | else: 413 | gt_counter_per_class[class_name] = 1 414 | 415 | if class_name not in already_seen_classes: 416 | if class_name in counter_images_per_class: 417 | counter_images_per_class[class_name] += 1 418 | else: 419 | counter_images_per_class[class_name] = 1 420 | already_seen_classes.append(class_name) 421 | 422 | 423 | with open(TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json", 'w') as outfile: 424 | json.dump(bounding_boxes, outfile) 425 | 426 | gt_classes = list(gt_counter_per_class.keys()) 427 | gt_classes = sorted(gt_classes) 428 | n_classes = len(gt_classes) 429 | 430 | """ 431 | Check format of the flag --set-class-iou (if used) 432 | e.g. 
check if class exists 433 | """ 434 | if specific_iou_flagged: 435 | n_args = len(args.set_class_iou) 436 | error_msg = \ 437 | '\n --set-class-iou [class_1] [IoU_1] [class_2] [IoU_2] [...]' 438 | if n_args % 2 != 0: 439 | error('Error, missing arguments. Flag usage:' + error_msg) 440 | # [class_1] [IoU_1] [class_2] [IoU_2] 441 | # specific_iou_classes = ['class_1', 'class_2'] 442 | specific_iou_classes = args.set_class_iou[::2] # even 443 | # iou_list = ['IoU_1', 'IoU_2'] 444 | iou_list = args.set_class_iou[1::2] # odd 445 | if len(specific_iou_classes) != len(iou_list): 446 | error('Error, missing arguments. Flag usage:' + error_msg) 447 | for tmp_class in specific_iou_classes: 448 | if tmp_class not in gt_classes: 449 | error('Error, unknown class \"' + tmp_class + '\". Flag usage:' + error_msg) 450 | for num in iou_list: 451 | if not is_float_between_0_and_1(num): 452 | error('Error, IoU must be between 0.0 and 1.0. Flag usage:' + error_msg) 453 | 454 | """ 455 | detection-results 456 | Load each of the detection-results files into a temporary ".json" file. 457 | """ 458 | dr_files_list = glob.glob(DR_PATH + '/*.txt') 459 | dr_files_list.sort() 460 | 461 | for class_index, class_name in enumerate(gt_classes): 462 | bounding_boxes = [] 463 | for txt_file in dr_files_list: 464 | file_id = txt_file.split(".txt",1)[0] 465 | file_id = os.path.basename(os.path.normpath(file_id)) 466 | temp_path = os.path.join(GT_PATH, (file_id + ".txt")) 467 | if class_index == 0: 468 | if not os.path.exists(temp_path): 469 | error_msg = "Error. File not found: {}\n".format(temp_path) 470 | error_msg += "(You can avoid this error message by running extra/intersect-gt-and-dr.py)" 471 | error(error_msg) 472 | lines = file_lines_to_list(txt_file) 473 | for line in lines: 474 | try: 475 | tmp_class_name, confidence, left, top, right, bottom = line.split() 476 | except: 477 | line_split = line.split() 478 | bottom = line_split[-1] 479 | right = line_split[-2] 480 | top = line_split[-3] 481 | left = line_split[-4] 482 | confidence = line_split[-5] 483 | tmp_class_name = "" 484 | for name in line_split[:-5]: 485 | tmp_class_name += name + " " 486 | tmp_class_name = tmp_class_name[:-1] 487 | 488 | if tmp_class_name == class_name: 489 | bbox = left + " " + top + " " + right + " " +bottom 490 | bounding_boxes.append({"confidence":confidence, "file_id":file_id, "bbox":bbox}) 491 | 492 | bounding_boxes.sort(key=lambda x:float(x['confidence']), reverse=True) 493 | with open(TEMP_FILES_PATH + "/" + class_name + "_dr.json", 'w') as outfile: 494 | json.dump(bounding_boxes, outfile) 495 | 496 | """ 497 | Calculate the AP for each class 498 | """ 499 | sum_AP = 0.0 500 | ap_dictionary = {} 501 | lamr_dictionary = {} 502 | with open(results_files_path + "/results.txt", 'w') as results_file: 503 | results_file.write("# AP and precision/recall per class\n") 504 | count_true_positives = {} 505 | 506 | for class_index, class_name in enumerate(gt_classes): 507 | count_true_positives[class_name] = 0 508 | """ 509 | Load detection-results of that class 510 | """ 511 | dr_file = TEMP_FILES_PATH + "/" + class_name + "_dr.json" 512 | dr_data = json.load(open(dr_file)) 513 | """ 514 | Assign detection-results to ground-truth objects 515 | """ 516 | nd = len(dr_data) 517 | tp = [0] * nd 518 | fp = [0] * nd 519 | score = [0] * nd 520 | score05_idx = 0 521 | for idx, detection in enumerate(dr_data): 522 | file_id = detection["file_id"] 523 | score[idx] = float(detection["confidence"]) 524 | if score[idx] > 0.5: 525 | score05_idx = idx 526 | 
527 | if show_animation: 528 | ground_truth_img = glob.glob1(IMG_PATH, file_id + ".*") 529 | if len(ground_truth_img) == 0: 530 | error("Error. Image not found with id: " + file_id) 531 | elif len(ground_truth_img) > 1: 532 | error("Error. Multiple image with id: " + file_id) 533 | else: 534 | img = cv2.imread(IMG_PATH + "/" + ground_truth_img[0]) 535 | img_cumulative_path = results_files_path + "/images/" + ground_truth_img[0] 536 | if os.path.isfile(img_cumulative_path): 537 | img_cumulative = cv2.imread(img_cumulative_path) 538 | else: 539 | img_cumulative = img.copy() 540 | bottom_border = 60 541 | BLACK = [0, 0, 0] 542 | img = cv2.copyMakeBorder(img, 0, bottom_border, 0, 0, cv2.BORDER_CONSTANT, value=BLACK) 543 | 544 | gt_file = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json" 545 | ground_truth_data = json.load(open(gt_file)) 546 | ovmax = -1 547 | gt_match = -1 548 | bb = [ float(x) for x in detection["bbox"].split() ] 549 | for obj in ground_truth_data: 550 | if obj["class_name"] == class_name: 551 | bbgt = [ float(x) for x in obj["bbox"].split() ] 552 | bi = [max(bb[0],bbgt[0]), max(bb[1],bbgt[1]), min(bb[2],bbgt[2]), min(bb[3],bbgt[3])] 553 | iw = bi[2] - bi[0] + 1 554 | ih = bi[3] - bi[1] + 1 555 | if iw > 0 and ih > 0: 556 | # compute overlap (IoU) = area of intersection / area of union 557 | ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + (bbgt[2] - bbgt[0] 558 | + 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih 559 | ov = iw * ih / ua 560 | if ov > ovmax: 561 | ovmax = ov 562 | gt_match = obj 563 | 564 | if show_animation: 565 | status = "NO MATCH FOUND!" 566 | min_overlap = MINOVERLAP 567 | if specific_iou_flagged: 568 | if class_name in specific_iou_classes: 569 | index = specific_iou_classes.index(class_name) 570 | min_overlap = float(iou_list[index]) 571 | if ovmax >= min_overlap: 572 | if "difficult" not in gt_match: 573 | if not bool(gt_match["used"]): 574 | tp[idx] = 1 575 | gt_match["used"] = True 576 | count_true_positives[class_name] += 1 577 | with open(gt_file, 'w') as f: 578 | f.write(json.dumps(ground_truth_data)) 579 | if show_animation: 580 | status = "MATCH!" 581 | else: 582 | fp[idx] = 1 583 | if show_animation: 584 | status = "REPEATED MATCH!" 
585 | else: 586 | fp[idx] = 1 587 | if ovmax > 0: 588 | status = "INSUFFICIENT OVERLAP" 589 | 590 | """ 591 | Draw image to show animation 592 | """ 593 | if show_animation: 594 | height, widht = img.shape[:2] 595 | # colors (OpenCV works with BGR) 596 | white = (255,255,255) 597 | light_blue = (255,200,100) 598 | green = (0,255,0) 599 | light_red = (30,30,255) 600 | # 1st line 601 | margin = 10 602 | v_pos = int(height - margin - (bottom_border / 2.0)) 603 | text = "Image: " + ground_truth_img[0] + " " 604 | img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0) 605 | text = "Class [" + str(class_index) + "/" + str(n_classes) + "]: " + class_name + " " 606 | img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), light_blue, line_width) 607 | if ovmax != -1: 608 | color = light_red 609 | if status == "INSUFFICIENT OVERLAP": 610 | text = "IoU: {0:.2f}% ".format(ovmax*100) + "< {0:.2f}% ".format(min_overlap*100) 611 | else: 612 | text = "IoU: {0:.2f}% ".format(ovmax*100) + ">= {0:.2f}% ".format(min_overlap*100) 613 | color = green 614 | img, _ = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width) 615 | # 2nd line 616 | v_pos += int(bottom_border / 2.0) 617 | rank_pos = str(idx+1) # rank position (idx starts at 0) 618 | text = "Detection #rank: " + rank_pos + " confidence: {0:.2f}% ".format(float(detection["confidence"])*100) 619 | img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0) 620 | color = light_red 621 | if status == "MATCH!": 622 | color = green 623 | text = "Result: " + status + " " 624 | img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width) 625 | 626 | font = cv2.FONT_HERSHEY_SIMPLEX 627 | if ovmax > 0: # if there is intersections between the bounding-boxes 628 | bbgt = [ int(round(float(x))) for x in gt_match["bbox"].split() ] 629 | cv2.rectangle(img,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2) 630 | cv2.rectangle(img_cumulative,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2) 631 | cv2.putText(img_cumulative, class_name, (bbgt[0],bbgt[1] - 5), font, 0.6, light_blue, 1, cv2.LINE_AA) 632 | bb = [int(i) for i in bb] 633 | cv2.rectangle(img,(bb[0],bb[1]),(bb[2],bb[3]),color,2) 634 | cv2.rectangle(img_cumulative,(bb[0],bb[1]),(bb[2],bb[3]),color,2) 635 | cv2.putText(img_cumulative, class_name, (bb[0],bb[1] - 5), font, 0.6, color, 1, cv2.LINE_AA) 636 | # show image 637 | cv2.imshow("Animation", img) 638 | cv2.waitKey(20) # show for 20 ms 639 | # save image to results 640 | output_img_path = results_files_path + "/images/detections_one_by_one/" + class_name + "_detection" + str(idx) + ".jpg" 641 | cv2.imwrite(output_img_path, img) 642 | # save the image with all the objects drawn to it 643 | cv2.imwrite(img_cumulative_path, img_cumulative) 644 | 645 | cumsum = 0 646 | for idx, val in enumerate(fp): 647 | fp[idx] += cumsum 648 | cumsum += val 649 | 650 | cumsum = 0 651 | for idx, val in enumerate(tp): 652 | tp[idx] += cumsum 653 | cumsum += val 654 | 655 | rec = tp[:] 656 | for idx, val in enumerate(tp): 657 | rec[idx] = float(tp[idx]) / np.maximum(gt_counter_per_class[class_name], 1) 658 | 659 | prec = tp[:] 660 | for idx, val in enumerate(tp): 661 | prec[idx] = float(tp[idx]) / np.maximum((fp[idx] + tp[idx]), 1) 662 | 663 | ap, mrec, mprec = voc_ap(rec[:], prec[:]) 664 | F1 = np.array(rec)*np.array(prec)*2 / np.where((np.array(prec)+np.array(rec))==0, 1, (np.array(prec)+np.array(rec))) 665 | 666 | sum_AP += ap 667 | text = 
"{0:.2f}%".format(ap*100) + " = " + class_name + " AP " #class_name + " AP = {0:.2f}%".format(ap*100) 668 | 669 | if len(prec)>0: 670 | F1_text = "{0:.2f}".format(F1[score05_idx]) + " = " + class_name + " F1 " 671 | Recall_text = "{0:.2f}%".format(rec[score05_idx]*100) + " = " + class_name + " Recall " 672 | Precision_text = "{0:.2f}%".format(prec[score05_idx]*100) + " = " + class_name + " Precision " 673 | else: 674 | F1_text = "0.00" + " = " + class_name + " F1 " 675 | Recall_text = "0.00%" + " = " + class_name + " Recall " 676 | Precision_text = "0.00%" + " = " + class_name + " Precision " 677 | 678 | rounded_prec = [ '%.2f' % elem for elem in prec ] 679 | rounded_rec = [ '%.2f' % elem for elem in rec ] 680 | results_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n") 681 | if not args.quiet: 682 | if len(prec)>0: 683 | print(text + "\t||\tscore_threhold=0.5 : " + "F1=" + "{0:.2f}".format(F1[score05_idx])\ 684 | + " ; Recall=" + "{0:.2f}%".format(rec[score05_idx]*100) + " ; Precision=" + "{0:.2f}%".format(prec[score05_idx]*100)) 685 | else: 686 | print(text + "\t||\tscore_threhold=0.5 : F1=0.00% ; Recall=0.00% ; Precision=0.00%") 687 | ap_dictionary[class_name] = ap 688 | 689 | n_images = counter_images_per_class[class_name] 690 | lamr, mr, fppi = log_average_miss_rate(np.array(rec), np.array(fp), n_images) 691 | lamr_dictionary[class_name] = lamr 692 | 693 | """ 694 | Draw plot 695 | """ 696 | if draw_plot: 697 | plt.plot(rec, prec, '-o') 698 | area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]] 699 | area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]] 700 | plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r') 701 | 702 | fig = plt.gcf() 703 | fig.canvas.set_window_title('AP ' + class_name) 704 | 705 | plt.title('class: ' + text) 706 | plt.xlabel('Recall') 707 | plt.ylabel('Precision') 708 | axes = plt.gca() 709 | axes.set_xlim([0.0,1.0]) 710 | axes.set_ylim([0.0,1.05]) 711 | fig.savefig(results_files_path + "/AP/" + class_name + ".png") 712 | plt.cla() 713 | 714 | plt.plot(score, F1, "-", color='orangered') 715 | plt.title('class: ' + F1_text + "\nscore_threhold=0.5") 716 | plt.xlabel('Score_Threhold') 717 | plt.ylabel('F1') 718 | axes = plt.gca() 719 | axes.set_xlim([0.0,1.0]) 720 | axes.set_ylim([0.0,1.05]) 721 | fig.savefig(results_files_path + "/F1/" + class_name + ".png") 722 | plt.cla() 723 | 724 | plt.plot(score, rec, "-H", color='gold') 725 | plt.title('class: ' + Recall_text + "\nscore_threhold=0.5") 726 | plt.xlabel('Score_Threhold') 727 | plt.ylabel('Recall') 728 | axes = plt.gca() 729 | axes.set_xlim([0.0,1.0]) 730 | axes.set_ylim([0.0,1.05]) 731 | fig.savefig(results_files_path + "/Recall/" + class_name + ".png") 732 | plt.cla() 733 | 734 | plt.plot(score, prec, "-s", color='palevioletred') 735 | plt.title('class: ' + Precision_text + "\nscore_threhold=0.5") 736 | plt.xlabel('Score_Threhold') 737 | plt.ylabel('Precision') 738 | axes = plt.gca() 739 | axes.set_xlim([0.0,1.0]) 740 | axes.set_ylim([0.0,1.05]) 741 | fig.savefig(results_files_path + "/Precision/" + class_name + ".png") 742 | plt.cla() 743 | 744 | if show_animation: 745 | cv2.destroyAllWindows() 746 | 747 | results_file.write("\n# mAP of all classes\n") 748 | mAP = sum_AP / n_classes 749 | text = "mAP = {0:.2f}%".format(mAP*100) 750 | results_file.write(text + "\n") 751 | print(text) 752 | 753 | # remove the temp_files directory 754 | shutil.rmtree(TEMP_FILES_PATH) 755 | 756 | """ 757 | Count total of detection-results 758 | 
""" 759 | # iterate through all the files 760 | det_counter_per_class = {} 761 | for txt_file in dr_files_list: 762 | # get lines to list 763 | lines_list = file_lines_to_list(txt_file) 764 | for line in lines_list: 765 | class_name = line.split()[0] 766 | # check if class is in the ignore list, if yes skip 767 | if class_name in args.ignore: 768 | continue 769 | # count that object 770 | if class_name in det_counter_per_class: 771 | det_counter_per_class[class_name] += 1 772 | else: 773 | # if class didn't exist yet 774 | det_counter_per_class[class_name] = 1 775 | #print(det_counter_per_class) 776 | dr_classes = list(det_counter_per_class.keys()) 777 | 778 | 779 | """ 780 | Plot the total number of occurences of each class in the ground-truth 781 | """ 782 | if draw_plot: 783 | window_title = "ground-truth-info" 784 | plot_title = "ground-truth\n" 785 | plot_title += "(" + str(len(ground_truth_files_list)) + " files and " + str(n_classes) + " classes)" 786 | x_label = "Number of objects per class" 787 | output_path = results_files_path + "/ground-truth-info.png" 788 | to_show = False 789 | plot_color = 'forestgreen' 790 | draw_plot_func( 791 | gt_counter_per_class, 792 | n_classes, 793 | window_title, 794 | plot_title, 795 | x_label, 796 | output_path, 797 | to_show, 798 | plot_color, 799 | '', 800 | ) 801 | 802 | """ 803 | Write number of ground-truth objects per class to results.txt 804 | """ 805 | with open(results_files_path + "/results.txt", 'a') as results_file: 806 | results_file.write("\n# Number of ground-truth objects per class\n") 807 | for class_name in sorted(gt_counter_per_class): 808 | results_file.write(class_name + ": " + str(gt_counter_per_class[class_name]) + "\n") 809 | 810 | """ 811 | Finish counting true positives 812 | """ 813 | for class_name in dr_classes: 814 | # if class exists in detection-result but not in ground-truth then there are no true positives in that class 815 | if class_name not in gt_classes: 816 | count_true_positives[class_name] = 0 817 | #print(count_true_positives) 818 | 819 | """ 820 | Plot the total number of occurences of each class in the "detection-results" folder 821 | """ 822 | if draw_plot: 823 | window_title = "detection-results-info" 824 | # Plot title 825 | plot_title = "detection-results\n" 826 | plot_title += "(" + str(len(dr_files_list)) + " files and " 827 | count_non_zero_values_in_dictionary = sum(int(x) > 0 for x in list(det_counter_per_class.values())) 828 | plot_title += str(count_non_zero_values_in_dictionary) + " detected classes)" 829 | # end Plot title 830 | x_label = "Number of objects per class" 831 | output_path = results_files_path + "/detection-results-info.png" 832 | to_show = False 833 | plot_color = 'forestgreen' 834 | true_p_bar = count_true_positives 835 | draw_plot_func( 836 | det_counter_per_class, 837 | len(det_counter_per_class), 838 | window_title, 839 | plot_title, 840 | x_label, 841 | output_path, 842 | to_show, 843 | plot_color, 844 | true_p_bar 845 | ) 846 | 847 | """ 848 | Write number of detected objects per class to results.txt 849 | """ 850 | with open(results_files_path + "/results.txt", 'a') as results_file: 851 | results_file.write("\n# Number of detected objects per class\n") 852 | for class_name in sorted(dr_classes): 853 | n_det = det_counter_per_class[class_name] 854 | text = class_name + ": " + str(n_det) 855 | text += " (tp:" + str(count_true_positives[class_name]) + "" 856 | text += ", fp:" + str(n_det - count_true_positives[class_name]) + ")\n" 857 | results_file.write(text) 858 | 859 | 
""" 860 | Draw log-average miss rate plot (Show lamr of all classes in decreasing order) 861 | """ 862 | if draw_plot: 863 | window_title = "lamr" 864 | plot_title = "log-average miss rate" 865 | x_label = "log-average miss rate" 866 | output_path = results_files_path + "/lamr.png" 867 | to_show = False 868 | plot_color = 'royalblue' 869 | draw_plot_func( 870 | lamr_dictionary, 871 | n_classes, 872 | window_title, 873 | plot_title, 874 | x_label, 875 | output_path, 876 | to_show, 877 | plot_color, 878 | "" 879 | ) 880 | 881 | """ 882 | Draw mAP plot (Show AP's of all classes in decreasing order) 883 | """ 884 | if draw_plot: 885 | window_title = "mAP" 886 | plot_title = "mAP = {0:.2f}%".format(mAP*100) 887 | x_label = "Average Precision" 888 | output_path = results_files_path + "/mAP.png" 889 | to_show = True 890 | plot_color = 'royalblue' 891 | draw_plot_func( 892 | ap_dictionary, 893 | n_classes, 894 | window_title, 895 | plot_title, 896 | x_label, 897 | output_path, 898 | to_show, 899 | plot_color, 900 | "" 901 | ) 902 | -------------------------------------------------------------------------------- /kmeans_for_anchors.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import random 3 | import xml.etree.ElementTree as ET 4 | 5 | import numpy as np 6 | 7 | 8 | def cas_iou(box,cluster): 9 | x = np.minimum(cluster[:,0],box[0]) 10 | y = np.minimum(cluster[:,1],box[1]) 11 | 12 | intersection = x * y 13 | area1 = box[0] * box[1] 14 | 15 | area2 = cluster[:,0] * cluster[:,1] 16 | iou = intersection / (area1 + area2 -intersection) 17 | 18 | return iou 19 | 20 | def avg_iou(box,cluster): 21 | return np.mean([np.max(cas_iou(box[i],cluster)) for i in range(box.shape[0])]) 22 | 23 | 24 | def kmeans(box,k): 25 | # 取出一共有多少框 26 | row = box.shape[0] 27 | 28 | # 每个框各个点的位置 29 | distance = np.empty((row,k)) 30 | 31 | # 最后的聚类位置 32 | last_clu = np.zeros((row,)) 33 | 34 | np.random.seed() 35 | 36 | # 随机选5个当聚类中心 37 | cluster = box[np.random.choice(row,k,replace = False)] 38 | # cluster = random.sample(row, k) 39 | while True: 40 | # 计算每一行距离五个点的iou情况。 41 | for i in range(row): 42 | distance[i] = 1 - cas_iou(box[i],cluster) 43 | 44 | # 取出最小点 45 | near = np.argmin(distance,axis=1) 46 | 47 | if (last_clu == near).all(): 48 | break 49 | 50 | # 求每一个类的中位点 51 | for j in range(k): 52 | cluster[j] = np.median( 53 | box[near == j],axis=0) 54 | 55 | last_clu = near 56 | 57 | return cluster 58 | 59 | def load_data(path): 60 | data = [] 61 | # 对于每一个xml都寻找box 62 | for xml_file in glob.glob('{}/*xml'.format(path)): 63 | tree = ET.parse(xml_file) 64 | height = int(tree.findtext('./size/height')) 65 | width = int(tree.findtext('./size/width')) 66 | if height<=0 or width<=0: 67 | continue 68 | 69 | # 对于每一个目标都获得它的宽高 70 | for obj in tree.iter('object'): 71 | xmin = int(float(obj.findtext('bndbox/xmin'))) / width 72 | ymin = int(float(obj.findtext('bndbox/ymin'))) / height 73 | xmax = int(float(obj.findtext('bndbox/xmax'))) / width 74 | ymax = int(float(obj.findtext('bndbox/ymax'))) / height 75 | 76 | xmin = np.float64(xmin) 77 | ymin = np.float64(ymin) 78 | xmax = np.float64(xmax) 79 | ymax = np.float64(ymax) 80 | # 得到宽高 81 | data.append([xmax-xmin,ymax-ymin]) 82 | return np.array(data) 83 | 84 | 85 | if __name__ == '__main__': 86 | # 运行该程序会计算'./VOCdevkit/VOC2007/Annotations'的xml 87 | # 会生成yolo_anchors.txt 88 | SIZE = 416 89 | anchors_num = 6 90 | # 载入数据集,可以使用VOC的xml 91 | path = r'./VOCdevkit/VOC2007/Annotations' 92 | 93 | # 载入所有的xml 94 | # 存储格式为转化为比例后的width,height 95 | data 
--------------------------------------------------------------------------------
/light.py:
--------------------------------------------------------------------------------
from PIL import Image
import numpy as np
from yolo import YOLO


def MinRgb(c):
    return min(c[0], c[1], c[2])


def SumRgb(c):
    return c[0] + c[1] + c[2]


def Invert(img):
    img = 255 - img
    return img


def GetA(R, G, B, k=100):
    # by default k=100: as in the original paper, take the top 100 pixels after sorting
    rlist = []
    height, width = R.shape[0], R.shape[1]
    for hi in range(height):
        for wi in range(width):
            rlist.append([R[hi][wi], G[hi][wi], B[hi][wi]])
    rlist.sort(key=MinRgb)
    rlist.reverse()
    rlist = rlist[:k]
    rlist.sort(key=SumRgb)
    rlist.reverse()
    return rlist[0][0], rlist[0][1], rlist[0][2]


def CalT(R, G, B, r_A, g_A, b_A, size=1, w=0.76):
    # a size x size window is used; border pixels need padding when they
    # become the window centre, so pad each side with rows/columns of 255
    ts = (size - 1) // 2
    height, width = R.shape[0], R.shape[1]
    R_f = np.pad(R, ((ts, ts), (ts, ts)), 'constant', constant_values=(255, 255)) / r_A
    G_f = np.pad(G, ((ts, ts), (ts, ts)), 'constant', constant_values=(255, 255)) / g_A
    B_f = np.pad(B, ((ts, ts), (ts, ts)), 'constant', constant_values=(255, 255)) / b_A

    shape = (height, width, size, size)
    strides = R_f.itemsize * np.array([width + ts * 2, 1, width + ts * 2, 1])

    blocks_R = np.lib.stride_tricks.as_strided(R_f, shape=shape, strides=strides)
    blocks_G = np.lib.stride_tricks.as_strided(G_f, shape=shape, strides=strides)
    blocks_B = np.lib.stride_tricks.as_strided(B_f, shape=shape, strides=strides)

    t = np.zeros((height, width))
    for hi in range(height):
        for wi in range(width):
            t[hi, wi] = 1 - w * min(np.min(blocks_R[hi, wi]), np.min(blocks_G[hi, wi]), np.min(blocks_B[hi, wi]))
            if t[hi, wi] < 0.5:
                t[hi, wi] = 2 * t[hi, wi] * t[hi, wi]
    return t


def DeHaze(img):
    # get image width and height
    # width, height = img.size
    # get the RGB arrays of the image
    img = np.asarray(img, dtype=np.int32)
    R, G, B = img[:, :, 0], img[:, :, 1], img[:, :, 2]
    # invert the channels
    R, G, B = Invert(R), Invert(G), Invert(B)
    # estimate the atmospheric light A
    r_A, g_A, b_A = GetA(R, G, B)
    t = CalT(R, G, B, r_A, g_A, b_A)
    # recover the scene radiance (i.e. the dehazed inverted image)
    J_R = (R - r_A) / t + r_A
    J_G = (G - g_A) / t + g_A
    J_B = (B - b_A) / t + b_A
    # invert back to obtain the enhanced low-light image
    J_R, J_G, J_B = Invert(J_R), Invert(J_G), Invert(J_B)
    r = Image.fromarray(J_R).convert('L')
    g = Image.fromarray(J_G).convert('L')
    b = Image.fromarray(J_B).convert('L')
    image = Image.merge("RGB", (r, g, b))
    image.save("dark_result.jpg")
    image.show()
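# (added note) DeHaze implements low-light enhancement by dehazing the inverted
# image: invert the RGB channels (a dark image becomes "hazy"), estimate the
# atmospheric light A (GetA), estimate the transmission with a dark-channel-style
# prior (CalT), recover J = (I - A) / t + A, and invert back. In CalT,
#
#     t = 1 - w * min over channels and window of (I_c / A_c)
#
# with t < 0.5 remapped to 2 * t^2, matching the code above.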
if __name__ == '__main__':
    yolo = YOLO()
    while True:
        img = input('Input image filename:')
        try:
            image = Image.open(img)
        except:
            print('Open Error! Try again!')
            continue
        else:
            DeHaze(image)
            img = Image.open("dark_result.jpg")
            r_image = yolo.detect_image(img)
            r_image.show()
            r_image.save('7.jpg')

--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
from PIL import Image

from yolo import YOLO

yolo = YOLO()

while True:
    img = input('Input image filename:')
    try:
        image = Image.open(img)
    except:
        print('Open Error! Try again!')
        continue
    else:
        r_image = yolo.detect_image(image)
        r_image.show()
        r_image.save('6.jpg')

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
# --------------------------------------------#
#   Inspect the network parameters and structure
# --------------------------------------------#
import torch
from torchsummary import summary

from nets.Dw_yolo4_tiny import YoloBody

if __name__ == "__main__":
    # device selects whether the network runs on the GPU or the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = YoloBody(3, 20).to(device)
    summary(model, input_size=(3, 416, 416))

    # darknet
    # Total params: 5,918,006
    # Trainable params: 5,918,006

    # mobilenet1:
    # Total params: 2,842,230
    # Trainable params: 2,842,230

    # Total params: 4,075,926

    # depthwise-separable-convolution darknet
    # Total params: 4,878,903
    # Trainable params: 4,878,903
    # Non-trainable params: 0

    # final depthwise-separable-convolution version
    # Total params: 2,787,383
    # Trainable params: 2,787,383
    # Non-trainable params: 0

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# -------------------------------------#
#   Train on the dataset
# -------------------------------------#
import os
import numpy as np
import time
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from utils.dataloader import yolo_dataset_collate, YoloDataset
from nets.yolo_training import YOLOLoss, Generator
from nets.Dw_yolo4_tiny import YoloBody
from tqdm import tqdm


# current learning rate
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']


# ---------------------------------------------------#
#   Get the classes and the anchors
# ---------------------------------------------------#
def get_classes(classes_path):
    with open(classes_path) as f:
        class_names = f.readlines()
    class_names = [c.strip() for c in class_names]
    return class_names


def get_anchors(anchors_path):
    with open(anchors_path) as f:
        anchors = f.readline()
    anchors = [float(x) for x in anchors.split(',')]
    return np.array(anchors).reshape([-1, 3, 2])
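# (added note) get_anchors turns the flat "w,h, w,h, ..." line of
# yolo_anchors.txt into groups of three anchor pairs: with the six anchors used
# by yolov4-tiny the array has shape (2, 3, 2), so len(anchors[0]) == 3 in the
# main block below, matching YoloBody(3, 20) as instantiated in test.py.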
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # -------------------------------------#
2 | #   Train the model on the dataset
3 | # -------------------------------------#
4 | import os
5 | import numpy as np
6 | import time
7 | import torch
8 | from torch.autograd import Variable
9 | import torch.nn as nn
10 | import torch.optim as optim
11 | import torch.nn.functional as F
12 | import torch.backends.cudnn as cudnn
13 | from torch.utils.data import DataLoader
14 | from utils.dataloader import yolo_dataset_collate, YoloDataset
15 | from nets.yolo_training import YOLOLoss, Generator
16 | from nets.Dw_yolo4_tiny import YoloBody
17 | from tqdm import tqdm
18 | 
19 | 
20 | # return the current learning rate of the optimizer
21 | def get_lr(optimizer):
22 |     for param_group in optimizer.param_groups:
23 |         return param_group['lr']
24 | 
25 | 
26 | # ---------------------------------------------------#
27 | #   Load the class names and the anchors
28 | # ---------------------------------------------------#
29 | def get_classes(classes_path):
30 |     with open(classes_path) as f:
31 |         class_names = f.readlines()
32 |     class_names = [c.strip() for c in class_names]
33 |     return class_names
34 | 
35 | 
36 | def get_anchors(anchors_path):
37 |     with open(anchors_path) as f:
38 |         anchors = f.readline()
39 |     anchors = [float(x) for x in anchors.split(',')]
40 |     return np.array(anchors).reshape([-1, 3, 2])
41 | 
42 | 
43 | def fit_one_epoch(net, yolo_losses, epoch, epoch_size, epoch_size_val, gen, genval, Epoch, cuda):
44 |     total_loss = 0
45 |     val_loss = 0
46 | 
47 |     net.train()
48 |     with tqdm(total=epoch_size, desc=f'Epoch {epoch + 1}/{Epoch}', postfix={}, mininterval=0.3) as pbar:
49 |         for iteration, batch in enumerate(gen):
50 |             if iteration >= epoch_size:
51 |                 break
52 |             images, targets = batch[0], batch[1]
53 |             with torch.no_grad():
54 |                 if cuda:
55 |                     images = Variable(torch.from_numpy(images).type(torch.FloatTensor)).cuda()
56 |                     targets = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets]
57 |                 else:
58 |                     images = Variable(torch.from_numpy(images).type(torch.FloatTensor))
59 |                     targets = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets]
60 | 
61 |             # ----------------------#
62 |             #   zero the gradients
63 |             # ----------------------#
64 |             optimizer.zero_grad()
65 |             # forward pass
66 |             outputs = net(images)
67 |             losses = []
68 |             num_pos_all = 0
69 |             # compute the loss of each detection head
70 |             for i in range(2):
71 |                 loss_item, num_pos = yolo_losses[i](outputs[i], targets)
72 |                 losses.append(loss_item)
73 |                 num_pos_all += num_pos
74 | 
75 |             loss = sum(losses) / num_pos_all
76 |             # backward pass
77 |             loss.backward()
78 |             # optimizer step: update the network parameters
79 |             optimizer.step()
80 | 
81 |             total_loss += loss.item()
82 |             pbar.set_postfix(**{'total_loss': total_loss / (iteration + 1),
83 |                                 'lr': get_lr(optimizer)})
84 |             pbar.update(1)
85 | 
86 |     net.eval()
87 |     print('Start Validation')
88 |     with tqdm(total=epoch_size_val, desc=f'Epoch {epoch + 1}/{Epoch}', postfix={}, mininterval=0.3) as pbar:
89 |         for iteration, batch in enumerate(genval):
90 |             if iteration >= epoch_size_val:
91 |                 break
92 |             images_val, targets_val = batch[0], batch[1]
93 | 
94 |             with torch.no_grad():
95 |                 if cuda:
96 |                     images_val = Variable(torch.from_numpy(images_val).type(torch.FloatTensor)).cuda()
97 |                     targets_val = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets_val]
98 |                 else:
99 |                     images_val = Variable(torch.from_numpy(images_val).type(torch.FloatTensor))
100 |                     targets_val = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets_val]
101 |                 optimizer.zero_grad()
102 |                 outputs = net(images_val)
103 |                 losses = []
104 |                 num_pos_all = 0
105 |                 for i in range(2):
106 |                     loss_item, num_pos = yolo_losses[i](outputs[i], targets_val)
107 |                     losses.append(loss_item)
108 |                     num_pos_all += num_pos
109 |                 loss = sum(losses) / num_pos_all
110 |                 val_loss += loss.item()
111 |             pbar.set_postfix(**{'total_loss': val_loss / (iteration + 1)})
112 |             pbar.update(1)
113 |     print('Finish Validation')
114 |     print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
115 |     print('Total Loss: %.4f || Val Loss: %.4f ' % (total_loss / (epoch_size + 1), val_loss / (epoch_size_val + 1)))
116 |     print('Saving state, iter:', str(epoch + 1))
117 |     torch.save(model.state_dict(), 'log_perfect/Epoch%d-Total_Loss%.4f-Val_Loss%.4f.pth' % (
118 |         (epoch + 1), total_loss / (epoch_size + 1), val_loss / (epoch_size_val + 1)))
119 | 
120 |     # step the plateau scheduler on the epoch-averaged validation loss
121 |     lr_scheduler.step(val_loss / (epoch_size_val + 1))
122 | 
123 | 
124 | 
125 | 
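Note that the scalar being optimized is the summed head losses normalized by the number of positive anchors matched in the batch, not by the batch size. With made-up toy numbers:

```python
losses = [4.2, 2.8]   # YOLOLoss values from the two detection heads (hypothetical)
num_pos_all = 14      # positive anchors matched across the whole batch (hypothetical)
loss = sum(losses) / num_pos_all
print(loss)           # 0.5 -- crowded images are not up-weighted per matched anchor
```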
126 | if __name__ == "__main__":
127 | 
128 |     # trained on a GPU, so Cuda is True
129 |     Cuda = True
130 |     # -------------------------------#
131 |     #   whether to use the PyTorch DataLoader
132 |     # -------------------------------#
133 |     Use_Data_Loader = True
134 |     normalize = False
135 |     # -------------------------------#
136 |     #   input shape
137 |     #   416x416 keeps GPU memory usage low
138 |     input_shape = (416, 416)
139 | 
140 |     # paths to the classes file and the anchors file
141 |     anchors_path = 'model_data/yolo_anchors.txt'
142 |     classes_path = 'model_data/new.txt'
143 |     # load the classes and anchors
144 |     class_names = get_classes(classes_path)
145 |     anchors = get_anchors(anchors_path)
146 |     num_classes = len(class_names)
147 | 
148 |     # mosaic: Mosaic data augmentation, True or False; disabled here because it is unstable
149 |     # Cosine_lr: cosine-annealing learning rate schedule, True or False
150 |     mosaic = False
151 |     Cosine_lr = False
152 |     smooth_label = 0
153 | 
154 |     # ------------------------------------------------------#
155 |     #   build the yolo model
156 |     #   be sure to change classes_path and the corresponding txt file before training
157 |     # ------------------------------------------------------#
158 |     model = YoloBody(len(anchors[0]), num_classes)
159 |     model_path = "log_perfect/Epoch966-Total_Loss2.3349-Val_Loss1.8903.pth"
160 |     # load pretrained weights to speed up training
161 |     print('Loading weights into state dict...')
162 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
163 |     model_dict = model.state_dict()
164 |     pretrained_dict = torch.load(model_path, map_location=device)
165 |     pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)}  # keep only shape-compatible weights
166 |     model_dict.update(pretrained_dict)
167 |     model.load_state_dict(model_dict)
168 |     print('Finished!')
169 | 
170 |     net = model.train()
171 | 
172 |     if Cuda:
173 |         net = torch.nn.DataParallel(model)
174 |         cudnn.benchmark = True
175 |         net = net.cuda()
176 | 
177 |     # build the loss functions, one per detection head
178 |     yolo_losses = []
179 |     for i in range(2):
180 |         yolo_losses.append(YOLOLoss(np.reshape(anchors, [-1, 2]), num_classes, \
181 |                                     (input_shape[1], input_shape[0]), smooth_label, Cuda, normalize))
182 | 
183 |     # image paths and labels
184 |     annotation_path = '2007_train.txt'
185 |     # It is normal for 2007_test.txt and 2007_val.txt to be empty; training does not use them.
186 |     # With the split below, the validation-to-training ratio is 1:9.
187 |     val_split = 0.1
188 |     with open(annotation_path) as f:
189 |         lines = f.readlines()
190 |     np.random.seed(10101)
191 |     np.random.shuffle(lines)
192 |     np.random.seed(None)
193 |     num_val = int(len(lines) * val_split)
194 |     num_train = len(lines) - num_val
195 | 
196 |     if True:
197 |         lr = 1e-3
198 |         Batch_size = 16
199 |         Init_Epoch = 384
200 |         Freeze_Epoch = 384  # equal to Init_Epoch: the frozen stage is skipped when resuming
201 | 
202 |         optimizer = optim.Adam(net.parameters(), lr)
203 |         if Cosine_lr:
204 |             lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=1e-5)
205 |         else:
206 |             # lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.92)
207 |             # cut the learning rate by a factor of 10 once the loss has stopped improving for 10 epochs
208 |             lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10,
209 |                                                                 verbose=False,
210 |                                                                 threshold=0.0001, threshold_mode='rel', cooldown=0,
211 |                                                                 min_lr=0,
212 |                                                                 eps=1e-08)
213 | 
214 |         if Use_Data_Loader:
215 |             train_dataset = YoloDataset(lines[:num_train], (input_shape[0], input_shape[1]), mosaic=mosaic,
216 |                                         is_train=True)
217 |             val_dataset = YoloDataset(lines[num_train:], (input_shape[0], input_shape[1]), mosaic=False, is_train=False)
218 |             gen = DataLoader(train_dataset, shuffle=True, batch_size=Batch_size, num_workers=4, pin_memory=True,
219 |                              drop_last=True, collate_fn=yolo_dataset_collate)
220 |             gen_val = DataLoader(val_dataset, shuffle=True, batch_size=Batch_size, num_workers=4, pin_memory=True,
221 |                                  drop_last=True, collate_fn=yolo_dataset_collate)
222 |         else:
223 |             gen = Generator(Batch_size, lines[:num_train],
224 |                             (input_shape[0], input_shape[1])).generate(train=True, mosaic=mosaic)
225 |             gen_val = Generator(Batch_size, lines[num_train:],
226 |                                 (input_shape[0], input_shape[1])).generate(train=False, mosaic=mosaic)
227 | 
228 |         epoch_size = max(1, num_train // Batch_size)
229 |         epoch_size_val = num_val // Batch_size
230 |         # ------------------------------------#
231 |         #   freeze the backbone for the first training stage
232 |         # ------------------------------------#
233 |         for param in model.backbone.parameters():
234 |             param.requires_grad = False
235 | 
236 |         for epoch in range(Init_Epoch, Freeze_Epoch):
237 |             fit_one_epoch(net, yolo_losses, epoch, epoch_size, epoch_size_val, gen, gen_val, Freeze_Epoch, Cuda)
238 |             # lr_scheduler.step()
239 | 
240 |     if True:
241 |         lr = 1e-4
242 |         # lr = 1.83e-6
243 |         # lr = 3.87e-6
244 |         Batch_size = 4
245 |         Freeze_Epoch = 920
246 |         Unfreeze_Epoch = 970
247 | 
248 |         optimizer = optim.Adam(net.parameters(), lr)
249 |         if Cosine_lr:
250 |             lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=1e-5)
251 |         else:
252 |             # lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.92)
253 |             lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10,
254 |                                                                 verbose=False,
255 |                                                                 threshold=0.0001, threshold_mode='rel', cooldown=0,
256 |                                                                 min_lr=0,
257 |                                                                 eps=1e-08)
258 | 
259 |         if Use_Data_Loader:
260 |             train_dataset = YoloDataset(lines[:num_train], (input_shape[0], input_shape[1]), mosaic=mosaic,
261 |                                         is_train=True)
262 |             val_dataset = YoloDataset(lines[num_train:], (input_shape[0], input_shape[1]), mosaic=False, is_train=False)
263 |             gen = DataLoader(train_dataset, shuffle=True, batch_size=Batch_size, num_workers=4, pin_memory=True,
264 |                              drop_last=True, collate_fn=yolo_dataset_collate)
265 |             gen_val = DataLoader(val_dataset, shuffle=True, batch_size=Batch_size, num_workers=4, pin_memory=True,
266 |                                  drop_last=True, collate_fn=yolo_dataset_collate)
267 |         else:
268 |             gen = Generator(Batch_size, lines[:num_train],
269 |                             (input_shape[0], input_shape[1])).generate(train=True, mosaic=mosaic)
270 |             gen_val = Generator(Batch_size, lines[num_train:],
271 |                                 (input_shape[0], input_shape[1])).generate(train=False, mosaic=mosaic)
272 | 
273 |         epoch_size = max(1, num_train // Batch_size)
274 |         epoch_size_val = num_val // Batch_size
275 |         for param in model.backbone.parameters():
276 |             param.requires_grad = True  # unfreeze the backbone for the second training stage
277 | 
278 |         for epoch in range(Freeze_Epoch, Unfreeze_Epoch):
279 |             fit_one_epoch(net, yolo_losses, epoch, epoch_size, epoch_size_val, gen, gen_val, Unfreeze_Epoch, Cuda)
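get_anchors above expects the single comma-separated line that kmeans_for_anchors.py writes to yolo_anchors.txt; a quick round-trip check of the file format (sketch):

```python
import numpy as np

# yolo_anchors.txt holds one line such as "10,14, 23,27, 37,58, 81,82, 135,169, 344,319"
line = open('model_data/yolo_anchors.txt').readline()
anchors = np.array([float(x) for x in line.split(',')]).reshape(-1, 3, 2)
print(anchors.shape)  # (2, 3, 2): two detection heads x three anchors x (w, h)
```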
--------------------------------------------------------------------------------
/video.py:
--------------------------------------------------------------------------------
1 | # -------------------------------------#
2 | #   Run detection on a camera stream or a video file
3 | # -------------------------------------#
4 | import time
5 | 
6 | import cv2
7 | import numpy as np
8 | from PIL import Image
9 | 
10 | from yolo import YOLO
11 | yolo = YOLO()
12 | # -------------------------------------#
13 | #   open the camera; on a Jetson Nano this line needs to change, because CSI cameras are opened differently
14 | # capture = cv2.VideoCapture("3.mp4")
15 | 
16 | capture = cv2.VideoCapture(0)
17 | fps = 0.0
18 | while True:
19 |     t1 = time.time()
20 |     # read one frame
21 |     ref, frame = capture.read()
22 |     # convert BGR to RGB
23 |     frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
24 |     # convert to a PIL Image
25 |     frame = Image.fromarray(np.uint8(frame))
26 |     # run detection
27 |     frame = np.array(yolo.detect_image(frame))
28 |     # convert RGB back to BGR for OpenCV display
29 |     frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
30 | 
31 |     fps = (fps + (1. / (time.time() - t1))) / 2
32 |     print("fps= %.2f" % (fps))
33 |     frame = cv2.putText(frame, "fps= %.2f" % (fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
34 | 
35 |     cv2.imshow("video", frame)
36 | 
37 |     c = cv2.waitKey(1) & 0xff
38 |     if c == 27:
39 |         capture.release()
40 |         break
41 | 
--------------------------------------------------------------------------------
/voc_annotation.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------#
2 | #   Be sure to edit `classes` before running.
3 | #   If the generated 2007_train.txt contains no object information,
4 | #   it is because `classes` was not set correctly.
5 | # ---------------------------------------------#
6 | import xml.etree.ElementTree as ET
7 | from os import getcwd
8 | 
9 | sets = [('2007', 'train'), ('2007', 'val'), ('2007', 'test')]
10 | 
11 | classes = ["Multi-rotor"]
12 | 
13 | 
14 | def convert_annotation(year, image_id, list_file):
15 |     in_file = open('VOCdevkit/VOC%s/Annotations/%s.xml' % (year, image_id), encoding='utf-8')
16 |     tree = ET.parse(in_file)
17 |     root = tree.getroot()
18 | 
19 |     for obj in root.iter('object'):
20 |         difficult = 0
21 |         if obj.find('difficult') is not None:
22 |             difficult = obj.find('difficult').text
23 | 
24 |         cls = obj.find('name').text
25 |         if cls not in classes or int(difficult) == 1:
26 |             continue
27 |         cls_id = classes.index(cls)
28 |         xmlbox = obj.find('bndbox')
29 |         b = (int(xmlbox.find('xmin').text), int(xmlbox.find('ymin').text), int(xmlbox.find('xmax').text),
30 |              int(xmlbox.find('ymax').text))
31 |         list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
32 | 
33 | 
34 | wd = getcwd()
35 | 
36 | for year, image_set in sets:
37 |     image_ids = open('VOCdevkit/VOC%s/ImageSets/Main/%s.txt' % (year, image_set)).read().strip().split()
38 |     list_file = open('%s_%s.txt' % (year, image_set), 'w')
39 |     imagename = "to"  # image ids containing "to" are stored as .png, the rest as .jpg
40 |     for image_id in image_ids:
41 |         if imagename in image_id:
42 |             list_file.write('%s/VOCdevkit/VOC%s/JPEGImages/%s.png' % (wd, year, image_id))
43 |         else:
44 |             list_file.write('%s/VOCdevkit/VOC%s/JPEGImages/%s.jpg' % (wd, year, image_id))
45 |         convert_annotation(year, image_id, list_file)
46 |         list_file.write('\n')
47 |     list_file.close()
48 | 
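Each line that voc_annotation.py writes has the form `<image path> x_min,y_min,x_max,y_max,class_id ...`, with one comma-separated group per object. A minimal sketch of reading such a line back (parse_annotation_line and the sample path are hypothetical, not part of the repo):

```python
import numpy as np

def parse_annotation_line(line):
    parts = line.strip().split()
    path = parts[0]
    # one (x_min, y_min, x_max, y_max, class_id) row per object
    boxes = np.array([list(map(int, group.split(','))) for group in parts[1:]])
    return path, boxes

path, boxes = parse_annotation_line('/data/VOCdevkit/VOC2007/JPEGImages/0001.jpg 48,240,195,371,0')
print(path, boxes.shape)  # .../0001.jpg (1, 5)
```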
--------------------------------------------------------------------------------
/yolo.py:
--------------------------------------------------------------------------------
1 | # -------------------------------------#
2 | #   YOLO detection class
3 | # -------------------------------------#
4 | import colorsys
5 | import os
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from PIL import Image, ImageDraw, ImageFont
10 | from nets.Dw_yolo4_tiny import YoloBody
11 | from utils.utils import (DecodeBox, bbox_iou, letterbox_image,
12 |                          non_max_suppression, yolo_correct_boxes)
13 | 
14 | 
15 | class YOLO(object):
16 |     _defaults = {
17 |         "model_path": 'log_perfect/Epoch966-Total_Loss2.3349-Val_Loss1.8903.pth',
18 |         "anchors_path": 'model_data/yolo_anchors.txt',
19 |         "classes_path": 'model_data/new.txt',
20 |         "model_image_size": (416, 416, 3),
21 |         "confidence": 0.5,
22 |         "iou": 0.3,
23 |         "cuda": True,
24 |         "letterbox_image": False,
25 |     }
26 | 
27 |     @classmethod
28 |     def get_defaults(cls, n):
29 |         if n in cls._defaults:
30 |             return cls._defaults[n]
31 |         else:
32 |             return "Unrecognized attribute name '" + n + "'"
33 | 
34 |     # initialize YOLO
35 | 
36 |     def __init__(self, **kwargs):
37 |         self.__dict__.update(self._defaults)
38 |         self.class_names = self._get_class()
39 |         self.anchors = self._get_anchors()
40 |         self.generate()
41 | 
42 |     # load all class names
43 |     def _get_class(self):
44 |         classes_path = os.path.expanduser(self.classes_path)
45 |         with open(classes_path) as f:
46 |             class_names = f.readlines()
47 |         class_names = [c.strip() for c in class_names]
48 |         return class_names
49 | 
50 |     # ---------------------------------------------------#
51 |     #   load all the anchors
52 |     # ---------------------------------------------------#
53 |     def _get_anchors(self):
54 |         anchors_path = os.path.expanduser(self.anchors_path)
55 |         with open(anchors_path) as f:
56 |             anchors = f.readline()
57 |         anchors = [float(x) for x in anchors.split(',')]
58 |         return np.array(anchors).reshape([-1, 3, 2])
59 | 
60 |     # ---------------------------------------------------#
61 |     #   build the model
62 |     # ---------------------------------------------------#
63 |     def generate(self):
64 | 
65 |         # build the dw_yolov4_tiny model
66 | 
67 |         self.net = YoloBody(len(self.anchors[0]), len(self.class_names))
68 | 
69 |         # load the trained weights
70 |         print('Loading weights into state dict...')
71 |         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
72 |         state_dict = torch.load(self.model_path, map_location=device)
73 |         self.net.load_state_dict(state_dict)
74 | 
75 |         # data = torch.randn((1, 3, 416, 416)).cuda().half()
76 |         # self.net_trt = torch2trt(self.net, [data], fp16_mode=True)
77 |         # torch.save(self.net_trt.state_dict(), "logs/net_trt.pth")
78 | 
79 |         print('Finished!')
80 | 
81 |         if self.cuda:
82 |             os.environ["CUDA_VISIBLE_DEVICES"] = '0'
83 |             self.net = nn.DataParallel(self.net)
84 |             # self.net = self.net.half()
85 |             self.net = self.net.cuda().eval()
86 | 
87 |             # self.net_trt = self.net_trt.half()
88 |             # self.net_trt = nn.DataParallel(self.net_trt)
89 |             # self.net_trt = self.net_trt.cuda()
90 |         # ---------------------------------------------------#
91 |         #   build the decoders for the two feature layers
92 |         # ---------------------------------------------------#
93 |         self.yolo_decodes = []
94 |         self.anchors_mask = [[3, 4, 5], [1, 2, 3]]
95 |         for i in range(2):
96 |             self.yolo_decodes.append(
97 |                 DecodeBox(np.reshape(self.anchors, [-1, 2])[self.anchors_mask[i]], len(self.class_names),
98 |                           (self.model_image_size[1], self.model_image_size[0])))
99 | 
100 |         print('{} model, anchors, and classes loaded.'.format(self.model_path))
101 |         # assign a distinct color to each class for drawing boxes
102 |         hsv_tuples = [(x / len(self.class_names), 1., 1.)
103 |                       for x in range(len(self.class_names))]
104 |         self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
105 |         self.colors = list(
106 |             map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
107 |                 self.colors))
108 | 
109 |     # run detection on a single image
110 |     def detect_image(self, image):
111 |         image_shape = np.array(np.shape(image)[0:2])
112 | 
113 |         # letterbox the image with gray bars for a distortion-free resize,
114 |         # or resize directly when letterbox_image is off
115 |         if self.letterbox_image:
116 |             crop_img = np.array(letterbox_image(image, (self.model_image_size[1], self.model_image_size[0])))
117 |         else:
118 |             crop_img = image.convert('RGB')
119 |             crop_img = crop_img.resize((self.model_image_size[1], self.model_image_size[0]), Image.BICUBIC)
120 |         photo = np.array(crop_img, dtype=np.float32) / 255.0
121 |         photo = np.transpose(photo, (2, 0, 1))
122 | 
123 |         # add the batch_size dimension
124 | 
125 |         images = [photo]
126 | 
127 |         with torch.no_grad():
128 |             images = torch.from_numpy(np.asarray(images)).float()
129 |             # images = torch.from_numpy(np.asarray(images))
130 |             if self.cuda:
131 |                 images = images.cuda()
132 |                 # images = images.half()
133 | 
134 |             # ---------------------------------------------------------#
135 |             #   feed the image through the network
136 |             # ---------------------------------------------------------#
137 |             outputs = self.net(images)
138 |             # outputs_trt = self.net_trt(images)  # compare against the TensorRT-accelerated model
139 | 
140 |             output_list = []
141 |             for i in range(2):
142 |                 output_list.append(self.yolo_decodes[i](outputs[i]))
143 | 
144 |             # ---------------------------------------------------------#
145 |             #   stack the predicted boxes, then run non-maximum suppression
146 |             # ---------------------------------------------------------#
147 |             output = torch.cat(output_list, 1)
148 |             batch_detections = non_max_suppression(output, len(self.class_names),
149 |                                                    conf_thres=self.confidence,
150 |                                                    nms_thres=self.iou)
151 |             # if nothing was detected, return the original image
152 | 
153 |             try:
154 |                 batch_detections = batch_detections[0].cpu().numpy()
155 |             except Exception:
156 |                 return image
157 | 
158 |             # filter the boxes by score
159 |             top_index = batch_detections[:, 4] * batch_detections[:, 5] > self.confidence
160 |             top_conf = batch_detections[top_index, 4] * batch_detections[top_index, 5]
161 |             top_label = np.array(batch_detections[top_index, -1], np.int32)
162 |             top_bboxes = np.array(batch_detections[top_index, :4])
163 |             top_xmin, top_ymin, top_xmax, top_ymax = np.expand_dims(top_bboxes[:, 0], -1), np.expand_dims(
164 |                 top_bboxes[:, 1], -1), np.expand_dims(top_bboxes[:, 2], -1), np.expand_dims(top_bboxes[:, 3], -1)
165 | 
166 |             # when letterbox_image is on, the boxes are relative to the padded image,
167 |             # so the gray bars have to be removed before drawing on the original
168 |             if self.letterbox_image:
169 |                 boxes = yolo_correct_boxes(top_ymin, top_xmin, top_ymax, top_xmax,
170 |                                            np.array([self.model_image_size[0], self.model_image_size[1]]), image_shape)
171 |             else:
172 |                 top_xmin = top_xmin / self.model_image_size[1] * image_shape[1]
173 |                 top_ymin = top_ymin / self.model_image_size[0] * image_shape[0]
174 |                 top_xmax = top_xmax / self.model_image_size[1] * image_shape[1]
175 |                 top_ymax = top_ymax / self.model_image_size[0] * image_shape[0]
176 |                 boxes = np.concatenate([top_ymin, top_xmin, top_ymax, top_xmax], axis=-1)
177 | 
178 |         font = ImageFont.truetype(font='model_data/simhei.ttf',
179 |                                   size=np.floor(3e-2 * np.shape(image)[1] + 0.5).astype('int32'))
180 | 
181 |         thickness = max((np.shape(image)[0] + np.shape(image)[1]) // self.model_image_size[0], 1)
182 | 
183 |         for i, c in enumerate(top_label):
184 |             predicted_class = self.class_names[c]
185 |             score = top_conf[i]
186 | 
187 |             top, left, bottom, right = boxes[i]
188 |             top = top - 5
189 |             left = left - 5
190 |             bottom = bottom + 5
191 |             right = right + 5
192 | 
193 |             top = max(0, np.floor(top + 0.5).astype('int32'))
194 |             left = max(0, np.floor(left + 0.5).astype('int32'))
195 |             bottom = min(np.shape(image)[0], np.floor(bottom + 0.5).astype('int32'))
196 |             right = min(np.shape(image)[1], np.floor(right + 0.5).astype('int32'))
197 | 
198 |             # draw the box and its label
199 |             label = '{} {:.2f}'.format(predicted_class, score)
200 |             draw = ImageDraw.Draw(image)
201 |             label_size = draw.textsize(label, font)
202 |             label = label.encode('utf-8')
203 |             print(label, top, left, bottom, right)
204 | 
205 |             if top - label_size[1] >= 0:
206 |                 text_origin = np.array([left, top - label_size[1]])
207 |             else:
208 |                 text_origin = np.array([left, top + 1])
209 | 
210 |             for i in range(thickness):
211 |                 draw.rectangle(
212 |                     [left + i, top + i, right - i, bottom - i],
213 |                     outline=self.colors[self.class_names.index(predicted_class)])
214 |             draw.rectangle(
215 |                 [tuple(text_origin), tuple(text_origin + label_size)],
216 |                 fill=self.colors[self.class_names.index(predicted_class)])
217 |             draw.text(text_origin, str(label, 'UTF-8'), fill=(0, 0, 0), font=font)
218 |             del draw
219 |         return image
220 | 
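When letterbox_image is off, the else-branch in detect_image maps boxes from 416x416 network space back to the original image by plain per-axis scaling. The same mapping as a standalone sketch (scale_box is a hypothetical helper, not part of the repo):

```python
import numpy as np

def scale_box(box, image_shape, model_size=(416, 416)):
    # box: (y_min, x_min, y_max, x_max) in network coordinates
    # image_shape: (height, width) of the original image
    h, w = image_shape
    y_min, x_min, y_max, x_max = box
    return np.array([y_min * h / model_size[0], x_min * w / model_size[1],
                     y_max * h / model_size[0], x_max * w / model_size[1]])

print(scale_box((100, 50, 300, 200), (720, 1280)))
# [173.08 153.85 519.23 615.38] -- each axis scaled by original size / model size
```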
--------------------------------------------------------------------------------
/yolo1.py:
--------------------------------------------------------------------------------
1 | # -------------------------------------#
2 | #   YOLO detection class
3 | # -------------------------------------#
4 | import colorsys
5 | import os
6 | 
7 | import cv2
8 | import numpy as np
9 | import torch
10 | import torch.backends.cudnn as cudnn
11 | import torch.nn as nn
12 | from PIL import Image, ImageDraw, ImageFont
13 | from torch.autograd import Variable
14 | from nets.yolo4_tiny import YoloBody
15 | from utils.utils import (DecodeBox, bbox_iou, letterbox_image,
16 |                          non_max_suppression, yolo_correct_boxes)
17 | 
18 | 
19 | # --------------------------------------------#
20 | #   To predict with your own trained model, two parameters
21 | #   must be changed: model_path and classes_path!
22 | #   If a shape mismatch occurs, double-check that model_path
23 | #   and classes_path match the ones used during training.
24 | # --------------------------------------------#
25 | class YOLO(object):
26 |     _defaults = {
27 |         "model_path": 'logs/Epoch99-Total_Loss1.7024-Val_Loss1.2447.pth',
28 |         "anchors_path": 'model_data/yolo_anchors.txt',
29 |         "classes_path": 'model_data/new.txt',
30 |         "model_image_size": (416, 416, 3),
31 |         "confidence": 0.5,
32 |         "iou": 0.3,
33 |         "cuda": True,
34 |         # ---------------------------------------------------------------------#
35 |         #   This flag controls whether letterbox_image is used for a distortion-free resize;
36 |         #   repeated tests showed that a plain resize (letterbox_image off) gives better results.
37 |         # ---------------------------------------------------------------------#
38 |         "letterbox_image": False,
39 |     }
40 | 
41 |     @classmethod
42 |     def get_defaults(cls, n):
43 |         if n in cls._defaults:
44 |             return cls._defaults[n]
45 |         else:
46 |             return "Unrecognized attribute name '" + n + "'"
47 | 
48 |     # ---------------------------------------------------#
49 |     #   initialize YOLO
50 |     # ---------------------------------------------------#
51 |     def __init__(self, **kwargs):
52 |         self.__dict__.update(self._defaults)
53 |         self.class_names = self._get_class()
54 |         self.anchors = self._get_anchors()
55 |         self.generate()
56 | 
57 |     # ---------------------------------------------------#
58 |     #   load all class names
59 |     # ---------------------------------------------------#
60 |     def _get_class(self):
61 |         classes_path = os.path.expanduser(self.classes_path)
62 |         with open(classes_path) as f:
63 |             class_names = f.readlines()
64 |         class_names = [c.strip() for c in class_names]
65 |         return class_names
66 | 
67 |     # ---------------------------------------------------#
68 |     #   load all the anchors
69 |     # ---------------------------------------------------#
70 |     def _get_anchors(self):
71 |         anchors_path = os.path.expanduser(self.anchors_path)
72 |         with open(anchors_path) as f:
73 |             anchors = f.readline()
74 |         anchors = [float(x) for x in anchors.split(',')]
75 |         return np.array(anchors).reshape([-1, 3, 2])
76 | 
77 |     # ---------------------------------------------------#
78 |     #   build the model
79 |     # ---------------------------------------------------#
80 |     def generate(self):
81 |         # ---------------------------------------------------#
82 |         #   build the yolov4_tiny model
83 |         # ---------------------------------------------------#
84 |         self.net = YoloBody(len(self.anchors[0]), len(self.class_names))
85 | 
86 |         # ---------------------------------------------------#
87 |         #   load the yolov4_tiny weights
88 |         # ---------------------------------------------------#
89 |         print('Loading weights into state dict...')
90 |         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
91 |         state_dict = torch.load(self.model_path, map_location=device)
92 |         self.net.load_state_dict(state_dict)
93 | 
94 |         # data = torch.randn((1, 3, 416, 416)).cuda().half()
95 |         # self.net_trt = torch2trt(self.net, [data], fp16_mode=True)
96 |         # torch.save(self.net_trt.state_dict(), "logs/net_trt.pth")
97 | 
98 |         print('Finished!')
99 | 
100 |         if self.cuda:
101 |             os.environ["CUDA_VISIBLE_DEVICES"] = '0'
102 |             self.net = nn.DataParallel(self.net)
103 |             # self.net = self.net.half()
104 |             self.net = self.net.cuda().eval()
105 | 
106 |             # self.net_trt = self.net_trt.half()
107 |             # self.net_trt = nn.DataParallel(self.net_trt)
108 |             # self.net_trt = self.net_trt.cuda()
109 |         # ---------------------------------------------------#
110 |         #   build the decoders for the two feature layers
111 |         # ---------------------------------------------------#
112 |         self.yolo_decodes = []
113 |         self.anchors_mask = [[3, 4, 5], [1, 2, 3]]
114 |         for i in range(2):
115 |             self.yolo_decodes.append(
116 |                 DecodeBox(np.reshape(self.anchors, [-1, 2])[self.anchors_mask[i]], len(self.class_names),
117 |                           (self.model_image_size[1], self.model_image_size[0])))
118 | 
119 |         print('{} model, anchors, and classes loaded.'.format(self.model_path))
120 |         # assign a distinct color to each class for drawing boxes
121 |         hsv_tuples = [(x / len(self.class_names), 1., 1.)
122 |                       for x in range(len(self.class_names))]
123 |         self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
124 |         self.colors = list(
125 |             map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
126 |                 self.colors))
127 | 
128 |     # ---------------------------------------------------#
129 |     #   run detection on a single image
130 |     # ---------------------------------------------------#
131 |     def detect_image(self, image):
132 |         image_shape = np.array(np.shape(image)[0:2])
133 | 
134 |         # ---------------------------------------------------------#
135 |         #   letterbox the image with gray bars for a distortion-free resize,
136 |         #   or resize directly for recognition
137 |         # ---------------------------------------------------------#
138 |         if self.letterbox_image:
139 |             crop_img = np.array(letterbox_image(image, (self.model_image_size[1], self.model_image_size[0])))
140 |         else:
141 |             crop_img = image.convert('RGB')
142 |             crop_img = crop_img.resize((self.model_image_size[1], self.model_image_size[0]), Image.BICUBIC)
143 |         photo = np.array(crop_img, dtype=np.float32) / 255.0
144 |         photo = np.transpose(photo, (2, 0, 1))
145 |         # ---------------------------------------------------------#
146 |         #   add the batch_size dimension
147 |         # ---------------------------------------------------------#
148 |         images = [photo]
149 | 
150 |         with torch.no_grad():
151 |             images = torch.from_numpy(np.asarray(images)).float()
152 |             # images = torch.from_numpy(np.asarray(images))
153 |             if self.cuda:
154 |                 images = images.cuda()
155 |                 # images = images.half()
156 | 
157 |             # ---------------------------------------------------------#
158 |             #   feed the image through the network
159 |             # ---------------------------------------------------------#
160 |             outputs = self.net(images)
161 |             # outputs_trt = self.net_trt(images)  # compare against the TensorRT-accelerated model
162 | 
163 |             output_list = []
164 |             for i in range(2):
165 |                 output_list.append(self.yolo_decodes[i](outputs[i]))
166 | 
167 |             # ---------------------------------------------------------#
168 |             #   stack the predicted boxes, then run non-maximum suppression
169 |             # ---------------------------------------------------------#
170 |             output = torch.cat(output_list, 1)
171 |             batch_detections = non_max_suppression(output, len(self.class_names),
172 |                                                    conf_thres=self.confidence,
173 |                                                    nms_thres=self.iou)
174 | 
175 |             # ---------------------------------------------------------#
176 |             #   if nothing was detected, return the original image
177 |             # ---------------------------------------------------------#
178 |             try:
179 |                 batch_detections = batch_detections[0].cpu().numpy()
180 |             except Exception:
181 |                 return image
182 | 
183 |             # ---------------------------------------------------------#
184 |             #   filter the boxes by score
185 |             # ---------------------------------------------------------#
186 |             top_index = batch_detections[:, 4] * batch_detections[:, 5] > self.confidence
187 |             top_conf = batch_detections[top_index, 4] * batch_detections[top_index, 5]
188 |             top_label = np.array(batch_detections[top_index, -1], np.int32)
189 |             top_bboxes = np.array(batch_detections[top_index, :4])
190 |             top_xmin, top_ymin, top_xmax, top_ymax = np.expand_dims(top_bboxes[:, 0], -1), np.expand_dims(
191 |                 top_bboxes[:, 1], -1), np.expand_dims(top_bboxes[:, 2], -1), np.expand_dims(top_bboxes[:, 3], -1)
192 | 
193 |             # -----------------------------------------------------------------#
194 |             #   letterbox_image added gray bars around the image before prediction,
195 |             #   so top_bboxes is relative to the padded image;
196 |             #   the padding has to be removed here.
197 |             # -----------------------------------------------------------------#
198 |             if self.letterbox_image:
199 |                 boxes = yolo_correct_boxes(top_ymin, top_xmin, top_ymax, top_xmax,
200 |                                            np.array([self.model_image_size[0], self.model_image_size[1]]), image_shape)
201 |             else:
202 |                 top_xmin = top_xmin / self.model_image_size[1] * image_shape[1]
203 |                 top_ymin = top_ymin / self.model_image_size[0] * image_shape[0]
204 |                 top_xmax = top_xmax / self.model_image_size[1] * image_shape[1]
205 |                 top_ymax = top_ymax / self.model_image_size[0] * image_shape[0]
206 |                 boxes = np.concatenate([top_ymin, top_xmin, top_ymax, top_xmax], axis=-1)
207 | 
208 |         font = ImageFont.truetype(font='model_data/simhei.ttf',
209 |                                   size=np.floor(3e-2 * np.shape(image)[1] + 0.5).astype('int32'))
210 | 
211 |         thickness = max((np.shape(image)[0] + np.shape(image)[1]) // self.model_image_size[0], 1)
212 | 
213 |         for i, c in enumerate(top_label):
214 |             predicted_class = self.class_names[c]
215 |             score = top_conf[i]
216 | 
217 |             top, left, bottom, right = boxes[i]
218 |             top = top - 5
219 |             left = left - 5
220 |             bottom = bottom + 5
221 |             right = right + 5
222 | 
223 |             top = max(0, np.floor(top + 0.5).astype('int32'))
224 |             left = max(0, np.floor(left + 0.5).astype('int32'))
225 |             bottom = min(np.shape(image)[0], np.floor(bottom + 0.5).astype('int32'))
226 |             right = min(np.shape(image)[1], np.floor(right + 0.5).astype('int32'))
227 | 
228 |             # draw the box and its label
229 |             label = '{} {:.2f}'.format(predicted_class, score)
230 |             draw = ImageDraw.Draw(image)
231 |             label_size = draw.textsize(label, font)
232 |             label = label.encode('utf-8')
233 |             print(label, top, left, bottom, right)
234 | 
235 |             if top - label_size[1] >= 0:
236 |                 text_origin = np.array([left, top - label_size[1]])
237 |             else:
238 |                 text_origin = np.array([left, top + 1])
239 | 
240 |             for i in 
range(thickness): 241 | draw.rectangle( 242 | [left + i, top + i, right - i, bottom - i], 243 | outline=self.colors[self.class_names.index(predicted_class)]) 244 | draw.rectangle( 245 | [tuple(text_origin), tuple(text_origin + label_size)], 246 | fill=self.colors[self.class_names.index(predicted_class)]) 247 | draw.text(text_origin, str(label, 'UTF-8'), fill=(0, 0, 0), font=font) 248 | del draw 249 | return image 250 | --------------------------------------------------------------------------------
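yolo.py (depthwise-separable backbone from nets.Dw_yolo4_tiny) and yolo1.py (the original backbone from nets.yolo4_tiny) expose the same YOLO interface, so the two models can be compared side by side on one image. A sketch, assuming both weight files named in their respective _defaults exist:

```python
from PIL import Image

import yolo    # YOLO backed by nets.Dw_yolo4_tiny
import yolo1   # YOLO backed by nets.yolo4_tiny

img = Image.open('img/street.jpg')  # any test image
yolo.YOLO().detect_image(img.copy()).save('dw_result.jpg')      # hypothetical output names
yolo1.YOLO().detect_image(img.copy()).save('baseline_result.jpg')
```

detect_image draws on the image it is given, so passing a copy keeps the two results independent.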