├── PKLot_server ├── detect_Logs │ ├── detect.txt │ ├── logs.txt │ └── 1.png ├── model_data │ ├── coco_classes.txt │ ├── voc_classes.txt │ ├── timg.jpg │ └── simhei.ttf ├── nets │ ├── __pycache__ │ │ ├── resnet.cpython-37.pyc │ │ ├── resnet.cpython-38.pyc │ │ ├── centernet.cpython-37.pyc │ │ ├── centernet.cpython-38.pyc │ │ ├── hourglass.cpython-37.pyc │ │ ├── hourglass.cpython-38.pyc │ │ ├── centernet_training.cpython-37.pyc │ │ └── centernet_training.cpython-38.pyc │ ├── resnet.py │ ├── hourglass.py │ ├── centernet.py │ └── centernet_training.py ├── utils │ ├── __pycache__ │ │ ├── utils.cpython-37.pyc │ │ └── utils.cpython-38.pyc │ └── utils.py ├── data_fack.py ├── DataServer.py ├── PKLotServer.py ├── get_gt_txt.py ├── get_dr_txt.py ├── MainWindow.py ├── predictUI.py ├── centernet.py └── get_map.py ├── .gitattributes ├── PKLot_client ├── background.jpg ├── PKlot_client.py ├── client_window.ui └── client_window.py └── README.md /PKLot_server/detect_Logs/detect.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /PKLot_server/detect_Logs/logs.txt: -------------------------------------------------------------------------------- 1 | 2023-01-09 08:08:22;44;0;0 -------------------------------------------------------------------------------- /PKLot_server/model_data/coco_classes.txt: -------------------------------------------------------------------------------- 1 | space-occupied 2 | space-empty 3 | -------------------------------------------------------------------------------- /PKLot_server/model_data/voc_classes.txt: -------------------------------------------------------------------------------- 1 | space-occupied 2 | space-empty 3 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /PKLot_client/background.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_client/background.jpg -------------------------------------------------------------------------------- /PKLot_server/detect_Logs/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/detect_Logs/1.png -------------------------------------------------------------------------------- /PKLot_server/model_data/timg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/model_data/timg.jpg -------------------------------------------------------------------------------- /PKLot_server/model_data/simhei.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/model_data/simhei.ttf -------------------------------------------------------------------------------- /PKLot_server/nets/__pycache__/resnet.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/nets/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /PKLot_server/nets/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/nets/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /PKLot_server/utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /PKLot_server/utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /PKLot_server/nets/__pycache__/centernet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/nets/__pycache__/centernet.cpython-37.pyc -------------------------------------------------------------------------------- /PKLot_server/nets/__pycache__/centernet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/nets/__pycache__/centernet.cpython-38.pyc -------------------------------------------------------------------------------- /PKLot_server/nets/__pycache__/hourglass.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/nets/__pycache__/hourglass.cpython-37.pyc -------------------------------------------------------------------------------- /PKLot_server/nets/__pycache__/hourglass.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/nets/__pycache__/hourglass.cpython-38.pyc -------------------------------------------------------------------------------- /PKLot_server/nets/__pycache__/centernet_training.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/nets/__pycache__/centernet_training.cpython-37.pyc -------------------------------------------------------------------------------- /PKLot_server/nets/__pycache__/centernet_training.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/nets/__pycache__/centernet_training.cpython-38.pyc
--------------------------------------------------------------------------------
/PKLot_server/data_fack.py:
--------------------------------------------------------------------------------
1 | import random
2 | 
3 | max_set = 42      # total number of simulated parking spaces
4 | year = "2022"
5 | month = "5"
6 | day = "16"
7 | 
8 | # Write one fake "hour;occupied;empty" record for every 5-second slot of the
9 | # day, to exercise the statistics visualization without running the detector.
10 | with open("./detect_Logs/countlogs", 'a') as f:
11 |     for i in range(24):                # hour
12 |         for j in range(60):            # minute
13 |             for k in range(0, 60, 5):  # second, in 5 s steps
14 |                 rdi = random.randint(0, max_set)
15 |                 f.write(str(i) + ";" + str(rdi) + ";" + str(max_set-rdi))
16 |                 f.write("\n")
17 | 
--------------------------------------------------------------------------------
/PKLot_server/DataServer.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | def FormetDayTime(datelist):
4 |     # Split a ["Y", "m", "d", "H", "M", "S"] list into "Y-m-d" and "H:M:S"
5 |     day = ""
6 |     for i in range(3):
7 |         if i:
8 |             day += '-'
9 |         day += datelist[i]
10 | 
11 |     time = ""
12 |     for i in range(3, len(datelist)):
13 |         if i != 3:
14 |             time += ':'
15 |         time += datelist[i]
16 |     return day, time
17 | 
18 | 
19 | def SaveData(day, time, num):
20 |     # Append the count for this timestamp to the day's log file
21 |     if not os.path.exists("dayLogs"):
22 |         os.mkdir("dayLogs")
23 |     with open(os.path.join("dayLogs", day), "a", encoding='utf-8') as f:
24 |         f.write(str(num))
25 |         f.write("\t")
26 |         f.write(time)
27 |         f.write("\n")
28 | 
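A minimal usage sketch for the two helpers above (not a file in this repo; the six-field timestamp list matches the strftime("%Y %m %d %H %M %S") format that predictUI.py builds):

import datetime
from DataServer import FormetDayTime, SaveData

now = datetime.datetime.now()
timelist = now.strftime("%Y %m %d %H %M %S").split(" ")
day, time = FormetDayTime(timelist)  # e.g. day = "2023-01-09", time = "08:08:22"
SaveData(day, time, 44)              # appends "44\t08:08:22" to dayLogs/2023-01-09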
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics
2 | An object-detection algorithm detects and counts the vehicles and parking spaces in an area, so that users can learn how many spaces are free before arriving and be guided to a parking spot; statistics on actual traffic flow are collected as well.
3 | 
4 | # Features
5 | ## Server (PKLot_server)
6 | PKLot_server is the server side. It runs object detection on the parking spaces to decide whether a space is already taken, computes the best parking region, and records the space occupancy for the statistics visualization. The detection results are stored in detect_Logs for later analysis; when statistics are requested, the server reads the corresponding data back from detect_Logs and plots it.
7 | 
8 | ## Client (PKLot_client)
9 | PKLot_client is the client side. It issues the parking-space query, receives the space occupancy and the best parking region detected by the server, and marks them on the image to guide the user.
10 | 
11 | ## Visualization
12 | 
13 | The visualization windows are built with PyQt5.
14 | 
15 | # Weights file
16 | 
17 | After downloading, the weights file should be placed in the \PKLot_server\model_data folder.
18 | 
19 | **Weights file download:**
20 | 
21 | Link: https://pan.baidu.com/s/1X_iAieXwVn02HEo5ihFSFw?pwd=xqbr
22 | 
23 | Extraction code: xqbr
24 | 
25 | # Test samples
26 | 
27 | **Test sample download link:**
28 | 
29 | Link: https://pan.baidu.com/s/1wH5_PkoUYxQTXvEt6wzoHA?pwd=m76z
30 | 
31 | Extraction code: m76z
32 | 
33 | # Demo video
34 | 
35 | Video link: [https://www.bilibili.com/video/BV1XG4y1w7kT](https://www.bilibili.com/video/BV1XG4y1w7kT)
--------------------------------------------------------------------------------
/PKLot_server/PKLotServer.py:
--------------------------------------------------------------------------------
1 | #!coding=utf-8
2 | 
3 | import sys
4 | import threading
5 | import socket
6 | import struct
7 | 
8 | 
9 | def socket_service():
10 |     try:
11 |         s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
12 |         s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
13 |         # Bind to port 9001
14 |         s.bind(('127.0.0.1', 9001))
15 |         # Set the listen backlog
16 |         s.listen(10)
17 |     except socket.error as msg:
18 |         print(msg)
19 |         sys.exit(1)
20 |     print('Waiting connection...')
21 | 
22 |     while 1:
23 |         # Block until a connection request arrives, then start a receiving thread
24 |         conn, addr = s.accept()
25 |         # Receive the data on a separate thread
26 |         t = threading.Thread(target=deal_data, args=(conn, addr))
27 |         t.start()
28 | 
29 | 
30 | def deal_data(conn, addr):
31 |     print('Accept new connection from {0}'.format(addr))
32 |     # conn.settimeout(500)
33 |     # Reply once the request has been received
34 |     conn.send('Hi, Welcome to the server!'.encode('utf-8'))
35 | 
36 |     while 1:
37 |         # Size of the fixed header carrying the file name and file size
38 |         fileinfo_size = struct.calcsize('128sl')
39 |         # Receive the file name and file size
40 |         buf = conn.recv(fileinfo_size)
41 |         # Check whether the file header was received
42 |         if buf:
43 |             # Unpack the file name and file size
44 |             filename, filesize = struct.unpack('128sl', buf)
45 |             fn = filename.strip(b'\00')
46 |             fn = fn.decode()
47 |             print('file name is {0}, filesize is {1}'.format(str(fn), filesize))
48 | 
49 |             recvd_size = 0  # bytes received so far
50 |             # Store the file in the directory of this script
51 |             fp = open('./' + str(fn), 'wb')
52 |             print('start receiving...')
53 | 
54 |             # Write the binary stream to the file chunk by chunk
55 |             while not recvd_size == filesize:
56 |                 if filesize - recvd_size > 1024:
57 |                     data = conn.recv(1024)
58 |                     recvd_size += len(data)
59 |                 else:
60 |                     data = conn.recv(filesize - recvd_size)
61 |                     recvd_size = filesize
62 |                 fp.write(data)
63 |             fp.close()
64 |             print('end receive...')
65 |             # Transfer finished; close the connection
66 |             conn.close()
67 |             break
68 | 
69 | 
70 | if __name__ == "__main__":
71 |     socket_service()
72 | 
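The '128sl' format above means every transfer begins with a fixed header: a 128-byte file-name field followed by a native long holding the file size. A minimal matching sender sketch (not a file in this repo; the uploaded file name is hypothetical, and both ends must run on the same platform so the native struct layout agrees):

import os
import socket
import struct

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(('127.0.0.1', 9001))
print(sock.recv(1024).decode('utf-8'))  # welcome message

filepath = 'frame.jpg'  # hypothetical file to upload
# Same '128sl' layout the server unpacks: padded name + file size
header = struct.pack('128sl', os.path.basename(filepath).encode('utf-8'),
                     os.path.getsize(filepath))
sock.send(header)
with open(filepath, 'rb') as fp:
    while True:
        data = fp.read(1024)
        if not data:
            break
        sock.send(data)
sock.close()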
--------------------------------------------------------------------------------
/PKLot_server/get_gt_txt.py:
--------------------------------------------------------------------------------
1 | #----------------------------------------------------#
2 | #   Generate the ground-truth files for the test set
3 | #   A video tutorial is available at
4 | #   https://www.bilibili.com/video/BV1zE411u7Vw
5 | #----------------------------------------------------#
6 | import sys
7 | import os
8 | import glob
9 | import xml.etree.ElementTree as ET
10 | 
11 | '''
12 | !!!!!!!!!!!!! NOTE !!!!!!!!!!!!!
13 | # When the xml files contain irrelevant classes, the code below can filter them out!
14 | '''
15 | #---------------------------------------------------#
16 | #   Load the classes
17 | #---------------------------------------------------#
18 | def get_classes(classes_path):
19 |     '''loads the classes'''
20 |     with open(classes_path) as f:
21 |         class_names = f.readlines()
22 |     class_names = [c.strip() for c in class_names]
23 |     return class_names
24 | 
25 | image_ids = open('VOCdevkit/VOC2007/ImageSets/Main/test.txt').read().strip().split()
26 | 
27 | if not os.path.exists("./input"):
28 |     os.makedirs("./input")
29 | if not os.path.exists("./input/ground-truth"):
30 |     os.makedirs("./input/ground-truth")
31 | 
32 | for image_id in image_ids:
33 |     with open("./input/ground-truth/"+image_id+".txt", "w") as new_f:
34 |         root = ET.parse("VOCdevkit/VOC2007/Annotations/"+image_id+".xml").getroot()
35 |         for obj in root.findall('object'):
36 |             difficult_flag = False
37 |             if obj.find('difficult')!=None:
38 |                 difficult = obj.find('difficult').text
39 |                 if int(difficult)==1:
40 |                     difficult_flag = True
41 |             obj_name = obj.find('name').text
42 |             '''
43 |             !!!!!!!!!!!! NOTE !!!!!!!!!!!!
44 |             # When the xml files contain irrelevant classes, uncomment the code
45 |             # below to filter them against the matching classes.txt!
46 |             '''
47 |             # classes_path = 'model_data/voc_classes.txt'
48 |             # class_names = get_classes(classes_path)
49 |             # if obj_name not in class_names:
50 |             #     continue
51 | 
52 |             bndbox = obj.find('bndbox')
53 |             left = bndbox.find('xmin').text
54 |             top = bndbox.find('ymin').text
55 |             right = bndbox.find('xmax').text
56 |             bottom = bndbox.find('ymax').text
57 | 
58 |             if difficult_flag:
59 |                 new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom))
60 |             else:
61 |                 new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
62 | 
63 | print("Conversion completed!")
64 | 
--------------------------------------------------------------------------------
/PKLot_client/PKlot_client.py:
--------------------------------------------------------------------------------
1 | from client_window import Ui_MainWindow
2 | from PyQt5.QtWidgets import QWidget, QMainWindow, QApplication, QGraphicsScene, QGraphicsPixmapItem, QFileDialog, QMessageBox
3 | from PyQt5.QtGui import QImage, QPixmap
4 | import sys
5 | import os
6 | import datetime
7 | import cv2
8 | import numpy as np
9 | from PIL import Image, ImageDraw, ImageFont
10 | 
11 | import multiprocessing as mp
12 | import socket
13 | from PIL import Image
14 | from io import BytesIO
15 | 
16 | 
17 | class client_window(QMainWindow, Ui_MainWindow):
18 |     def __init__(self):
19 |         super().__init__()
20 |         self.setupUi(self)
21 |         img_src = cv2.imread("background.jpg")  # read the image
22 |         img_src = cv2.cvtColor(img_src, cv2.COLOR_BGR2RGB)  # convert the colour channels
23 |         label_width = self.label.width()
24 |         label_height = self.label.height()
25 |         temp_imgSrc = QImage(img_src[:], img_src.shape[1], img_src.shape[0], img_src.shape[1] * 3, QImage.Format_RGB888)
26 |         # convert the image to a QPixmap for display
27 |         self.pixmap_imgSrc = QPixmap.fromImage(temp_imgSrc).scaled(label_width, label_height)
28 |         self.label.setPixmap(QPixmap(self.pixmap_imgSrc))
29 |         self.pushButton.clicked.connect(self.request)
30 | 
31 |     def request(self):
32 |         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
33 |         sock.connect(('127.0.0.1', 8000))
34 |         sock.send(b'123')  # (b'GET / HTTP/1.1\r\nHost: 127.0.0.1:8000\r\n\r\n')
35 |         data = sock.recv(4096)
36 |         date, op, ep, cnt = data.decode().split(';')
37 |         self.lineEdit.setText(date)
38 |         self.lineEdit_2.setText(op)
39 |         self.lineEdit_3.setText(ep)
40 |         sock.close()
41 |         cnt = int(cnt.strip())
42 |         img_src = cv2.imread("background.jpg")  # read the image
43 |         img_src = cv2.cvtColor(img_src, cv2.COLOR_BGR2RGB)  # convert the colour channels
44 | 
45 |         label_width = self.label.width()
46 |         label_height = self.label.height()
47 | 
48 |         if cnt == 1:
49 |             cv2.rectangle(img_src, (10, 10), (img_src.shape[0] // 2, img_src.shape[1] // 2), (255, 0, 0), 5)
50 |             textx = (img_src.shape[0] // 2 - 15) // 3
51 |             texty = (img_src.shape[1] // 2 - 50) // 3
52 |             cv2.putText(img_src, "Best", (textx,texty), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 5)
53 |             cv2.putText(img_src, "Region", (textx - 15, texty + 80), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 5)
54 |         elif cnt == 2:
55 |             cv2.rectangle(img_src, (img_src.shape[0] // 2, 0), (img_src.shape[0], img_src.shape[1] // 2 - 10), (255, 0, 0), 5)
56 |             textx = img_src.shape[0] // 2 + (img_src.shape[0] // 2) // 3
57 |             texty = (img_src.shape[1] // 2 - 50) // 3
58 |             cv2.putText(img_src, "Best", (textx, texty), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 5)
59 |             cv2.putText(img_src, "Region", (textx - 15, texty + 80), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 5)
60 |         elif cnt == 3:
61 |             cv2.rectangle(img_src, (10, img_src.shape[0] // 2), (img_src.shape[0] // 2, img_src.shape[1] -
10), (255, 0, 0), 5) 62 | textx = (img_src.shape[0] // 2 - 15) // 3 63 | texty = img_src.shape[0] // 2 + (img_src.shape[1] // 2 - 50) // 3 64 | cv2.putText(img_src, "Best", (textx, texty), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 5) 65 | cv2.putText(img_src, "Region", (textx - 15, texty + 80), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 5) 66 | 67 | elif cnt == 4: 68 | cv2.rectangle(img_src, (img_src.shape[0] // 2, img_src.shape[1] // 2), (img_src.shape[0] - 10, img_src.shape[1] - 10), (255, 0, 0), 5) 69 | textx = img_src.shape[0] // 2 + (img_src.shape[0] // 2 - 15) // 3 70 | texty = img_src.shape[0] // 2 + (img_src.shape[1] // 2 - 50) // 3 71 | cv2.putText(img_src, "Best", (textx, texty), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 5) 72 | cv2.putText(img_src, "Region", (textx - 15, texty + 80), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 5) 73 | 74 | 75 | temp_imgSrc = QImage(img_src[:], img_src.shape[1], img_src.shape[0], img_src.shape[1] * 3, QImage.Format_RGB888) 76 | # 将图片转换为QPixmap方便显示 77 | self.pixmap_imgSrc = QPixmap.fromImage(temp_imgSrc).scaled(label_width, label_height) 78 | self.label.setPixmap(QPixmap(self.pixmap_imgSrc)) 79 | 80 | if __name__ == '__main__': 81 | app = QApplication(sys.argv) 82 | window = client_window() 83 | window.show() 84 | sys.exit(app.exec_()) 85 | -------------------------------------------------------------------------------- /PKLot_client/client_window.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | MainWindow 4 | 5 | 6 | 7 | 0 8 | 0 9 | 677 10 | 385 11 | 12 | 13 | 14 | MainWindow 15 | 16 | 17 | 18 | 19 | 20 | 20 21 | 20 22 | 261 23 | 291 24 | 25 | 26 | 27 | 连接服务器失败...请稍后重试 28 | 29 | 30 | 31 | 32 | 33 | 30 34 | 520 35 | 631 36 | 41 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 313 47 | 40 48 | 321 49 | 271 50 | 51 | 52 | 53 | 54 | 55 | 56 | Qt::Vertical 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 查询时间: 68 | 69 | 70 | 71 | 72 | 73 | 74 | false 75 | 76 | 77 | true 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 已占车位: 89 | 90 | 91 | 92 | 93 | 94 | 95 | false 96 | 97 | 98 | true 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 剩余车位: 110 | 111 | 112 | 113 | 114 | 115 | 116 | false 117 | 118 | 119 | true 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 查询车位 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 0 141 | 0 142 | 677 143 | 26 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- /PKLot_client/client_window.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Form implementation generated from reading ui file 'client_window.ui' 4 | # 5 | # Created by: PyQt5 UI code generator 5.15.4 6 | # 7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is 8 | # run again. Do not edit this file unless you know what you are doing. 
9 | 10 | 11 | from PyQt5 import QtCore, QtGui, QtWidgets 12 | 13 | 14 | class Ui_MainWindow(object): 15 | def setupUi(self, MainWindow): 16 | MainWindow.setObjectName("MainWindow") 17 | MainWindow.resize(677, 385) 18 | self.centralwidget = QtWidgets.QWidget(MainWindow) 19 | self.centralwidget.setObjectName("centralwidget") 20 | self.label = QtWidgets.QLabel(self.centralwidget) 21 | self.label.setGeometry(QtCore.QRect(20, 20, 261, 291)) 22 | self.label.setObjectName("label") 23 | self.label_2 = QtWidgets.QLabel(self.centralwidget) 24 | self.label_2.setGeometry(QtCore.QRect(30, 520, 631, 41)) 25 | self.label_2.setText("") 26 | self.label_2.setObjectName("label_2") 27 | self.layoutWidget = QtWidgets.QWidget(self.centralwidget) 28 | self.layoutWidget.setGeometry(QtCore.QRect(313, 40, 321, 271)) 29 | self.layoutWidget.setObjectName("layoutWidget") 30 | self.horizontalLayout_4 = QtWidgets.QHBoxLayout(self.layoutWidget) 31 | self.horizontalLayout_4.setContentsMargins(0, 0, 0, 0) 32 | self.horizontalLayout_4.setObjectName("horizontalLayout_4") 33 | self.line = QtWidgets.QFrame(self.layoutWidget) 34 | self.line.setFrameShape(QtWidgets.QFrame.VLine) 35 | self.line.setFrameShadow(QtWidgets.QFrame.Sunken) 36 | self.line.setObjectName("line") 37 | self.horizontalLayout_4.addWidget(self.line) 38 | self.verticalLayout = QtWidgets.QVBoxLayout() 39 | self.verticalLayout.setObjectName("verticalLayout") 40 | self.horizontalLayout = QtWidgets.QHBoxLayout() 41 | self.horizontalLayout.setObjectName("horizontalLayout") 42 | self.label_3 = QtWidgets.QLabel(self.layoutWidget) 43 | self.label_3.setObjectName("label_3") 44 | self.horizontalLayout.addWidget(self.label_3) 45 | self.lineEdit = QtWidgets.QLineEdit(self.layoutWidget) 46 | self.lineEdit.setEnabled(False) 47 | self.lineEdit.setReadOnly(True) 48 | self.lineEdit.setObjectName("lineEdit") 49 | self.horizontalLayout.addWidget(self.lineEdit) 50 | self.verticalLayout.addLayout(self.horizontalLayout) 51 | self.horizontalLayout_2 = QtWidgets.QHBoxLayout() 52 | self.horizontalLayout_2.setObjectName("horizontalLayout_2") 53 | self.label_4 = QtWidgets.QLabel(self.layoutWidget) 54 | self.label_4.setObjectName("label_4") 55 | self.horizontalLayout_2.addWidget(self.label_4) 56 | self.lineEdit_2 = QtWidgets.QLineEdit(self.layoutWidget) 57 | self.lineEdit_2.setEnabled(False) 58 | self.lineEdit_2.setReadOnly(True) 59 | self.lineEdit_2.setObjectName("lineEdit_2") 60 | self.horizontalLayout_2.addWidget(self.lineEdit_2) 61 | self.verticalLayout.addLayout(self.horizontalLayout_2) 62 | self.horizontalLayout_3 = QtWidgets.QHBoxLayout() 63 | self.horizontalLayout_3.setObjectName("horizontalLayout_3") 64 | self.label_5 = QtWidgets.QLabel(self.layoutWidget) 65 | self.label_5.setObjectName("label_5") 66 | self.horizontalLayout_3.addWidget(self.label_5) 67 | self.lineEdit_3 = QtWidgets.QLineEdit(self.layoutWidget) 68 | self.lineEdit_3.setEnabled(False) 69 | self.lineEdit_3.setReadOnly(True) 70 | self.lineEdit_3.setObjectName("lineEdit_3") 71 | self.horizontalLayout_3.addWidget(self.lineEdit_3) 72 | self.verticalLayout.addLayout(self.horizontalLayout_3) 73 | self.pushButton = QtWidgets.QPushButton(self.layoutWidget) 74 | self.pushButton.setObjectName("pushButton") 75 | self.verticalLayout.addWidget(self.pushButton) 76 | self.horizontalLayout_4.addLayout(self.verticalLayout) 77 | MainWindow.setCentralWidget(self.centralwidget) 78 | self.menubar = QtWidgets.QMenuBar(MainWindow) 79 | self.menubar.setGeometry(QtCore.QRect(0, 0, 677, 26)) 80 | self.menubar.setObjectName("menubar") 81 | 
MainWindow.setMenuBar(self.menubar) 82 | self.statusbar = QtWidgets.QStatusBar(MainWindow) 83 | self.statusbar.setObjectName("statusbar") 84 | MainWindow.setStatusBar(self.statusbar) 85 | 86 | self.retranslateUi(MainWindow) 87 | QtCore.QMetaObject.connectSlotsByName(MainWindow) 88 | 89 | def retranslateUi(self, MainWindow): 90 | _translate = QtCore.QCoreApplication.translate 91 | MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow")) 92 | self.label.setText(_translate("MainWindow", "连接服务器失败...请稍后重试")) 93 | self.label_3.setText(_translate("MainWindow", "查询时间:")) 94 | self.label_4.setText(_translate("MainWindow", "已占车位:")) 95 | self.label_5.setText(_translate("MainWindow", "剩余车位:")) 96 | self.pushButton.setText(_translate("MainWindow", "查询车位")) 97 | -------------------------------------------------------------------------------- /PKLot_server/get_dr_txt.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | from PIL import Image 6 | from tqdm import tqdm 7 | 8 | from centernet import CenterNet 9 | from nets.centernet import centernet 10 | from utils.utils import centernet_correct_boxes, letterbox_image, nms 11 | 12 | gpus = tf.config.experimental.list_physical_devices(device_type='GPU') 13 | for gpu in gpus: 14 | tf.config.experimental.set_memory_growth(gpu, True) 15 | 16 | ''' 17 | 这里设置的门限值较低是因为计算map需要用到不同门限条件下的Recall和Precision值。 18 | 所以只有保留的框足够多,计算的map才会更精确,详情可以了解map的原理。 19 | 计算map时输出的Recall和Precision值指的是门限为0.5时的Recall和Precision值。 20 | 21 | 此处获得的./input/detection-results/里面的txt的框的数量会比直接predict多一些,这是因为这里的门限低, 22 | 目的是为了计算不同门限条件下的Recall和Precision值,从而实现map的计算。 23 | 24 | 这里的self.nms_threhold指的是非极大抑制所用到的iou,具体的可以了解非极大抑制的原理, 25 | 如果低分框与高分框的iou大于这里设定的self.nms_threhold,那么该低分框将会被剔除。 26 | 27 | 可能有些同学知道有0.5和0.5:0.95的mAP,这里的self.nms_threhold=0.5不代表mAP0.5。 28 | 如果想要设定mAP0.x,比如设定mAP0.75,可以去get_map.py设定MINOVERLAP。 29 | ''' 30 | def preprocess_image(image): 31 | mean = [0.40789655, 0.44719303, 0.47026116] 32 | std = [0.2886383 , 0.27408165, 0.27809834] 33 | return ((np.float32(image) / 255.) 
- mean) / std 34 | 35 | class mAP_CenterNet(CenterNet): 36 | #---------------------------------------------------# 37 | # 检测图片 38 | #---------------------------------------------------# 39 | def detect_image(self,image_id,image): 40 | f = open("./input/detection-results/"+image_id+".txt","w") 41 | self.confidence = 0.01 42 | self.nms_threhold = 0.5 43 | #---------------------------------------------------------# 44 | # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。 45 | #---------------------------------------------------------# 46 | image = image.convert('RGB') 47 | 48 | image_shape = np.array(np.shape(image)[0:2]) 49 | #---------------------------------------------------------# 50 | # 给图像增加灰条,实现不失真的resize 51 | #---------------------------------------------------------# 52 | crop_img = letterbox_image(image, [self.input_shape[0],self.input_shape[1]]) 53 | #----------------------------------------------------------------------------------# 54 | # 将RGB转化成BGR,这是因为原始的centernet_hourglass权值是使用BGR通道的图片训练的 55 | #----------------------------------------------------------------------------------# 56 | photo = np.array(crop_img,dtype = np.float32)[:,:,::-1] 57 | #-----------------------------------------------------------# 58 | # 图片预处理,归一化。获得的photo的shape为[1, 512, 512, 3] 59 | #-----------------------------------------------------------# 60 | photo = np.reshape(preprocess_image(photo),[1,self.input_shape[0],self.input_shape[1],self.input_shape[2]]) 61 | 62 | preds = self.get_pred(photo).numpy() 63 | #-------------------------------------------------------# 64 | # 对于centernet网络来讲,确立中心非常重要。 65 | # 对于大目标而言,会存在许多的局部信息。 66 | # 此时对于同一个大目标,中心点比较难以确定。 67 | # 使用最大池化的非极大抑制方法无法去除局部框 68 | # 所以我还是写了另外一段对框进行非极大抑制的代码 69 | # 实际测试中,hourglass为主干网络时有无额外的nms相差不大,resnet相差较大。 70 | #-------------------------------------------------------# 71 | if self.nms: 72 | preds = np.array(nms(preds,self.nms_threhold)) 73 | 74 | if len(preds[0])<=0: 75 | return 76 | 77 | #-----------------------------------------------------------# 78 | # 将预测结果转换成小数的形式 79 | #-----------------------------------------------------------# 80 | preds[0][:,0:4] = preds[0][:,0:4]/(self.input_shape[0]/4) 81 | 82 | det_label = preds[0][:, -1] 83 | det_conf = preds[0][:, -2] 84 | det_xmin, det_ymin, det_xmax, det_ymax = preds[0][:, 0], preds[0][:, 1], preds[0][:, 2], preds[0][:, 3] 85 | #-----------------------------------------------------------# 86 | # 筛选出其中得分高于confidence的框 87 | #-----------------------------------------------------------# 88 | top_indices = [i for i, conf in enumerate(det_conf) if conf >= self.confidence] 89 | top_conf = det_conf[top_indices] 90 | top_label_indices = det_label[top_indices].tolist() 91 | top_xmin, top_ymin, top_xmax, top_ymax = np.expand_dims(det_xmin[top_indices],-1),np.expand_dims(det_ymin[top_indices],-1),np.expand_dims(det_xmax[top_indices],-1),np.expand_dims(det_ymax[top_indices],-1) 92 | 93 | #-----------------------------------------------------------# 94 | # 去掉灰条部分 95 | #-----------------------------------------------------------# 96 | boxes = centernet_correct_boxes(top_ymin,top_xmin,top_ymax,top_xmax,np.array([self.input_shape[0],self.input_shape[1]]),image_shape) 97 | 98 | for i, c in enumerate(top_label_indices): 99 | predicted_class = self.class_names[int(c)] 100 | score = str(top_conf[i]) 101 | 102 | top, left, bottom, right = boxes[i] 103 | f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom)))) 104 | 105 | f.close() 106 | return 107 | 108 | centernet = 
mAP_CenterNet() 109 | image_ids = open('VOCdevkit/VOC2007/ImageSets/Main/test.txt').read().strip().split() 110 | 111 | if not os.path.exists("./input"): 112 | os.makedirs("./input") 113 | if not os.path.exists("./input/detection-results"): 114 | os.makedirs("./input/detection-results") 115 | if not os.path.exists("./input/images-optional"): 116 | os.makedirs("./input/images-optional") 117 | 118 | for image_id in tqdm(image_ids): 119 | image_path = "./VOCdevkit/VOC2007/JPEGImages/"+image_id+".jpg" 120 | image = Image.open(image_path) 121 | # image.save("./input/images-optional/"+image_id+".jpg") 122 | centernet.detect_image(image_id,image) 123 | 124 | print("Conversion completed!") 125 | -------------------------------------------------------------------------------- /PKLot_server/MainWindow.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Form implementation generated from reading ui file 'MainWindow.ui' 4 | # 5 | # Created by: PyQt5 UI code generator 5.15.4 6 | # 7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is 8 | # run again. Do not edit this file unless you know what you are doing. 9 | 10 | 11 | from PyQt5 import QtCore, QtGui, QtWidgets 12 | 13 | 14 | class Ui_MainWindow(object): 15 | def setupUi(self, MainWindow): 16 | MainWindow.setObjectName("MainWindow") 17 | MainWindow.resize(800, 600) 18 | self.centralwidget = QtWidgets.QWidget(MainWindow) 19 | self.centralwidget.setObjectName("centralwidget") 20 | self.layoutWidget = QtWidgets.QWidget(self.centralwidget) 21 | self.layoutWidget.setGeometry(QtCore.QRect(21, 11, 131, 551)) 22 | self.layoutWidget.setObjectName("layoutWidget") 23 | self.verticalLayout = QtWidgets.QVBoxLayout(self.layoutWidget) 24 | self.verticalLayout.setContentsMargins(0, 0, 0, 0) 25 | self.verticalLayout.setObjectName("verticalLayout") 26 | self.pushButton_1 = QtWidgets.QPushButton(self.layoutWidget) 27 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed) 28 | sizePolicy.setHorizontalStretch(20) 29 | sizePolicy.setVerticalStretch(20) 30 | sizePolicy.setHeightForWidth(self.pushButton_1.sizePolicy().hasHeightForWidth()) 31 | self.pushButton_1.setSizePolicy(sizePolicy) 32 | self.pushButton_1.setSizeIncrement(QtCore.QSize(20, 20)) 33 | self.pushButton_1.setBaseSize(QtCore.QSize(20, 20)) 34 | font = QtGui.QFont() 35 | font.setPointSize(10) 36 | self.pushButton_1.setFont(font) 37 | self.pushButton_1.setObjectName("pushButton_1") 38 | self.verticalLayout.addWidget(self.pushButton_1) 39 | self.pushButton_2 = QtWidgets.QPushButton(self.layoutWidget) 40 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed) 41 | sizePolicy.setHorizontalStretch(20) 42 | sizePolicy.setVerticalStretch(20) 43 | sizePolicy.setHeightForWidth(self.pushButton_2.sizePolicy().hasHeightForWidth()) 44 | self.pushButton_2.setSizePolicy(sizePolicy) 45 | self.pushButton_2.setSizeIncrement(QtCore.QSize(20, 20)) 46 | self.pushButton_2.setBaseSize(QtCore.QSize(20, 20)) 47 | font = QtGui.QFont() 48 | font.setPointSize(10) 49 | self.pushButton_2.setFont(font) 50 | self.pushButton_2.setObjectName("pushButton_2") 51 | self.verticalLayout.addWidget(self.pushButton_2) 52 | self.pushButton_3 = QtWidgets.QPushButton(self.layoutWidget) 53 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed) 54 | sizePolicy.setHorizontalStretch(20) 55 | sizePolicy.setVerticalStretch(20) 56 | 
sizePolicy.setHeightForWidth(self.pushButton_3.sizePolicy().hasHeightForWidth()) 57 | self.pushButton_3.setSizePolicy(sizePolicy) 58 | self.pushButton_3.setSizeIncrement(QtCore.QSize(20, 20)) 59 | self.pushButton_3.setBaseSize(QtCore.QSize(20, 20)) 60 | font = QtGui.QFont() 61 | font.setPointSize(10) 62 | self.pushButton_3.setFont(font) 63 | self.pushButton_3.setObjectName("pushButton_3") 64 | self.verticalLayout.addWidget(self.pushButton_3) 65 | self.pushButton = QtWidgets.QPushButton(self.layoutWidget) 66 | self.pushButton.setObjectName("pushButton") 67 | self.verticalLayout.addWidget(self.pushButton) 68 | self.label = QtWidgets.QLabel(self.centralwidget) 69 | self.label.setGeometry(QtCore.QRect(160, 10, 621, 511)) 70 | self.label.setObjectName("label") 71 | MainWindow.setCentralWidget(self.centralwidget) 72 | self.statusbar = QtWidgets.QStatusBar(MainWindow) 73 | self.statusbar.setObjectName("statusbar") 74 | MainWindow.setStatusBar(self.statusbar) 75 | self.action_16 = QtWidgets.QAction(MainWindow) 76 | self.action_16.setObjectName("action_16") 77 | self.actiona = QtWidgets.QAction(MainWindow) 78 | self.actiona.setObjectName("actiona") 79 | self.actiona_2 = QtWidgets.QAction(MainWindow) 80 | self.actiona_2.setObjectName("actiona_2") 81 | self.actionb = QtWidgets.QAction(MainWindow) 82 | self.actionb.setObjectName("actionb") 83 | self.actiona_3 = QtWidgets.QAction(MainWindow) 84 | self.actiona_3.setObjectName("actiona_3") 85 | self.actiona_4 = QtWidgets.QAction(MainWindow) 86 | self.actiona_4.setObjectName("actiona_4") 87 | self.actionb_2 = QtWidgets.QAction(MainWindow) 88 | self.actionb_2.setObjectName("actionb_2") 89 | self.actiona_5 = QtWidgets.QAction(MainWindow) 90 | self.actiona_5.setObjectName("actiona_5") 91 | self.actionb_3 = QtWidgets.QAction(MainWindow) 92 | self.actionb_3.setObjectName("actionb_3") 93 | 94 | self.retranslateUi(MainWindow) 95 | QtCore.QMetaObject.connectSlotsByName(MainWindow) 96 | 97 | def retranslateUi(self, MainWindow): 98 | _translate = QtCore.QCoreApplication.translate 99 | MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow")) 100 | self.pushButton_1.setText(_translate("MainWindow", "图片检测")) 101 | self.pushButton_2.setText(_translate("MainWindow", "视频检测")) 102 | self.pushButton_3.setText(_translate("MainWindow", "实时监控")) 103 | self.pushButton.setText(_translate("MainWindow", "统计时段")) 104 | self.label.setText(_translate("MainWindow", "检测中...")) 105 | self.action_16.setText(_translate("MainWindow", "按天统计")) 106 | self.actiona.setText(_translate("MainWindow", "a")) 107 | self.actiona_2.setText(_translate("MainWindow", "按月统计")) 108 | self.actionb.setText(_translate("MainWindow", "按年统计")) 109 | self.actiona_3.setText(_translate("MainWindow", "a")) 110 | self.actiona_4.setText(_translate("MainWindow", "实时")) 111 | self.actionb_2.setText(_translate("MainWindow", "打开文件")) 112 | self.actiona_5.setText(_translate("MainWindow", "暂停")) 113 | self.actionb_3.setText(_translate("MainWindow", "全屏")) 114 | -------------------------------------------------------------------------------- /PKLot_server/nets/resnet.py: -------------------------------------------------------------------------------- 1 | #-------------------------------------------------------------# 2 | # ResNet50的网络部分 3 | #-------------------------------------------------------------# 4 | from tensorflow.keras import layers 5 | from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D, 6 | Conv2DTranspose, Dropout, MaxPooling2D, 7 | ZeroPadding2D) 8 | from 
tensorflow.keras.regularizers import l2 9 | from tensorflow.keras.initializers import RandomNormal 10 | 11 | 12 | def identity_block(input_tensor, kernel_size, filters, stage, block): 13 | 14 | filters1, filters2, filters3 = filters 15 | 16 | conv_name_base = 'res' + str(stage) + block + '_branch' 17 | bn_name_base = 'bn' + str(stage) + block + '_branch' 18 | 19 | x = Conv2D(filters1, (1, 1), kernel_initializer=RandomNormal(stddev=0.02), name=conv_name_base + '2a', use_bias=False)(input_tensor) 20 | x = BatchNormalization(name=bn_name_base + '2a')(x) 21 | x = Activation('relu')(x) 22 | 23 | x = Conv2D(filters2, kernel_size, padding='same', kernel_initializer=RandomNormal(stddev=0.02), name=conv_name_base + '2b', use_bias=False)(x) 24 | x = BatchNormalization(name=bn_name_base + '2b')(x) 25 | x = Activation('relu')(x) 26 | 27 | x = Conv2D(filters3, (1, 1), kernel_initializer=RandomNormal(stddev=0.02), name=conv_name_base + '2c', use_bias=False)(x) 28 | x = BatchNormalization(name=bn_name_base + '2c')(x) 29 | 30 | x = layers.add([x, input_tensor]) 31 | x = Activation('relu')(x) 32 | return x 33 | 34 | 35 | def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)): 36 | 37 | filters1, filters2, filters3 = filters 38 | 39 | conv_name_base = 'res' + str(stage) + block + '_branch' 40 | bn_name_base = 'bn' + str(stage) + block + '_branch' 41 | 42 | x = Conv2D(filters1, (1, 1), strides=strides, kernel_initializer=RandomNormal(stddev=0.02), 43 | name=conv_name_base + '2a', use_bias=False)(input_tensor) 44 | x = BatchNormalization(name=bn_name_base + '2a')(x) 45 | x = Activation('relu')(x) 46 | 47 | x = Conv2D(filters2, kernel_size, padding='same', kernel_initializer=RandomNormal(stddev=0.02), 48 | name=conv_name_base + '2b', use_bias=False)(x) 49 | x = BatchNormalization(name=bn_name_base + '2b')(x) 50 | x = Activation('relu')(x) 51 | 52 | x = Conv2D(filters3, (1, 1), kernel_initializer=RandomNormal(stddev=0.02), name=conv_name_base + '2c', use_bias=False)(x) 53 | x = BatchNormalization(name=bn_name_base + '2c')(x) 54 | 55 | shortcut = Conv2D(filters3, (1, 1), strides=strides, kernel_initializer=RandomNormal(stddev=0.02), 56 | name=conv_name_base + '1', use_bias=False)(input_tensor) 57 | shortcut = BatchNormalization(name=bn_name_base + '1')(shortcut) 58 | 59 | x = layers.add([x, shortcut]) 60 | x = Activation('relu')(x) 61 | return x 62 | 63 | 64 | def ResNet50(inputs): 65 | # 512x512x3 66 | x = ZeroPadding2D((3, 3))(inputs) 67 | # 256,256,64 68 | x = Conv2D(64, (7, 7), kernel_initializer=RandomNormal(stddev=0.02), strides=(2, 2), name='conv1', use_bias=False)(x) 69 | x = BatchNormalization(name='bn_conv1')(x) 70 | x = Activation('relu')(x) 71 | 72 | # 256,256,64 -> 128,128,64 73 | x = MaxPooling2D((3, 3), strides=(2, 2), padding="same")(x) 74 | 75 | # 128,128,64 -> 128,128,256 76 | x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) 77 | x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') 78 | x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') 79 | 80 | # 128,128,256 -> 64,64,512 81 | x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') 82 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='b') 83 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') 84 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') 85 | 86 | # 64,64,512 -> 32,32,1024 87 | x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') 88 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') 89 | x = identity_block(x, 3, 
[256, 256, 1024], stage=4, block='c') 90 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') 91 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') 92 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') 93 | 94 | # 32,32,1024 -> 16,16,2048 95 | x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') 96 | x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') 97 | x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') 98 | 99 | return x 100 | 101 | def centernet_head(x,num_classes): 102 | x = Dropout(rate=0.5)(x) 103 | #-------------------------------# 104 | # 解码器 105 | #-------------------------------# 106 | num_filters = 256 107 | # 16, 16, 2048 -> 32, 32, 256 -> 64, 64, 128 -> 128, 128, 64 108 | for i in range(3): 109 | # 进行上采样 110 | x = Conv2DTranspose(num_filters // pow(2, i), (4, 4), strides=2, use_bias=False, padding='same', 111 | kernel_initializer='he_normal', 112 | kernel_regularizer=l2(5e-4))(x) 113 | x = BatchNormalization()(x) 114 | x = Activation('relu')(x) 115 | # 最终获得128,128,64的特征层 116 | # hm header 117 | y1 = Conv2D(64, 3, padding='same', use_bias=False, kernel_initializer=RandomNormal(stddev=0.02), kernel_regularizer=l2(5e-4))(x) 118 | y1 = BatchNormalization()(y1) 119 | y1 = Activation('relu')(y1) 120 | y1 = Conv2D(num_classes, 1, kernel_initializer=RandomNormal(stddev=0.02), kernel_regularizer=l2(5e-4), activation='sigmoid')(y1) 121 | 122 | # wh header 123 | y2 = Conv2D(64, 3, padding='same', use_bias=False, kernel_initializer=RandomNormal(stddev=0.02), kernel_regularizer=l2(5e-4))(x) 124 | y2 = BatchNormalization()(y2) 125 | y2 = Activation('relu')(y2) 126 | y2 = Conv2D(2, 1, kernel_initializer=RandomNormal(stddev=0.02), kernel_regularizer=l2(5e-4))(y2) 127 | 128 | # reg header 129 | y3 = Conv2D(64, 3, padding='same', use_bias=False, kernel_initializer=RandomNormal(stddev=0.02), kernel_regularizer=l2(5e-4))(x) 130 | y3 = BatchNormalization()(y3) 131 | y3 = Activation('relu')(y3) 132 | y3 = Conv2D(2, 1, kernel_initializer=RandomNormal(stddev=0.02), kernel_regularizer=l2(5e-4))(y3) 133 | return y1, y2, y3 134 | -------------------------------------------------------------------------------- /PKLot_server/nets/hourglass.py: -------------------------------------------------------------------------------- 1 | import tensorflow.keras.backend as K 2 | from tensorflow.keras.layers import (Activation, Add, BatchNormalization, Conv2D, Input, UpSampling2D, 3 | ZeroPadding2D) 4 | from tensorflow.keras.models import Model 5 | from tensorflow.keras.initializers import RandomNormal 6 | 7 | 8 | def conv2d(x, k, out_dim, name, stride=1): 9 | padding = (k - 1) // 2 10 | x = ZeroPadding2D(padding=padding, name=name + '.pad')(x) 11 | x = Conv2D(out_dim, k, strides=stride, kernel_initializer=RandomNormal(stddev=0.02), use_bias=False, name=name + '.conv')(x) 12 | x = BatchNormalization(epsilon=1e-5, name=name + '.bn')(x) 13 | x = Activation('relu', name=name + '.relu')(x) 14 | return x 15 | 16 | def residual(x, out_dim, name, stride=1): 17 | #-----------------------------------# 18 | # 残差网络结构 19 | # 两个形态 20 | # 1、残差边有卷积,改变维度 21 | # 2、残差边无卷积,加大深度 22 | #-----------------------------------# 23 | shortcut = x 24 | num_channels = K.int_shape(shortcut)[-1] 25 | 26 | x = ZeroPadding2D(padding=1, name=name + '.pad1')(x) 27 | x = Conv2D(out_dim, 3, strides=stride, kernel_initializer=RandomNormal(stddev=0.02), use_bias=False, name=name + '.conv1')(x) 28 | x = BatchNormalization(epsilon=1e-5, name=name + '.bn1')(x) 29 | x = Activation('relu', 
name=name + '.relu1')(x) 30 | 31 | x = Conv2D(out_dim, 3, padding='same', kernel_initializer=RandomNormal(stddev=0.02), use_bias=False, name=name + '.conv2')(x) 32 | x = BatchNormalization(epsilon=1e-5, name=name + '.bn2')(x) 33 | 34 | if num_channels != out_dim or stride != 1: 35 | shortcut = Conv2D(out_dim, 1, strides=stride, kernel_initializer=RandomNormal(stddev=0.02), use_bias=False, name=name + '.shortcut.0')( 36 | shortcut) 37 | shortcut = BatchNormalization(epsilon=1e-5, name=name + '.shortcut.1')(shortcut) 38 | 39 | x = Add(name=name + '.add')([x, shortcut]) 40 | x = Activation('relu', name=name + '.relu')(x) 41 | return x 42 | 43 | def bottleneck_layer(x, num_channels, hgid): 44 | #-----------------------------------# 45 | # 中间的深度结构 46 | #-----------------------------------# 47 | pow_str = 'center.' * 5 48 | x = residual(x, num_channels, name='kps.%d.%s0' % (hgid, pow_str)) 49 | x = residual(x, num_channels, name='kps.%d.%s1' % (hgid, pow_str)) 50 | x = residual(x, num_channels, name='kps.%d.%s2' % (hgid, pow_str)) 51 | x = residual(x, num_channels, name='kps.%d.%s3' % (hgid, pow_str)) 52 | return x 53 | 54 | def connect_left_right(left, right, num_channels, num_channels_next, name): 55 | # 图像上半部分的卷积 56 | left = residual(left, num_channels_next, name=name + 'skip.0') 57 | left = residual(left, num_channels_next, name=name + 'skip.1') 58 | # 图像右半部分的上采样 59 | out = residual(right, num_channels, name=name + 'out.0') 60 | out = residual(out, num_channels_next, name=name + 'out.1') 61 | out = UpSampling2D(name=name + 'out.upsampleNN')(out) 62 | # 利用相加进行全连接 63 | out = Add(name=name + 'out.add')([left, out]) 64 | return out 65 | 66 | def pre(x, num_channels): 67 | #-----------------------------------# 68 | # 图片进入金字塔前的预处理 69 | # 一般是一次普通卷积 70 | # 加上残差结构 71 | #-----------------------------------# 72 | x = conv2d(x, 7, 128, name='pre.0', stride=2) 73 | x = residual(x, num_channels, name='pre.1', stride=2) 74 | return x 75 | 76 | def left_features(bottom, hgid, dims): 77 | #-------------------------------------------------# 78 | # 进行五次下采样 79 | # f1, f2, f4 , f8, f16, f32 : 1, 1/2, 1/4 1/8, 1/16, 1/32 resolution 80 | # 5 times reduce/increase: (256, 384, 384, 384, 512) 81 | #-------------------------------------------------# 82 | features = [bottom] 83 | for kk, nh in enumerate(dims): 84 | x = residual(features[-1], nh, name='kps.%d%s.down.0' % (hgid, str(kk)), stride=2) 85 | x = residual(x, nh, name='kps.%d%s.down.1' % (hgid, str(kk))) 86 | features.append(x) 87 | return features 88 | 89 | def right_features(leftfeatures, hgid, dims): 90 | #-------------------------------------------------# 91 | # 进行五次上采样,并进行连接 92 | # f1, f2, f4 , f8, f16, f32 : 1, 1/2, 1/4 1/8, 1/16, 1/32 resolution 93 | # 5 times reduce/increase: (256, 384, 384, 384, 512) 94 | #-------------------------------------------------# 95 | rf = bottleneck_layer(leftfeatures[-1], dims[-1], hgid) 96 | for kk in reversed(range(len(dims))): 97 | pow_str = '' 98 | for _ in range(kk): 99 | pow_str += 'center.' 
100 | rf = connect_left_right(leftfeatures[kk], rf, dims[kk], dims[max(kk - 1, 0)], name='kps.%d.%s' % (hgid, pow_str)) 101 | return rf 102 | 103 | 104 | def create_heads(num_classes, rf1, hgid): 105 | y1 = Conv2D(256, 3, kernel_initializer=RandomNormal(stddev=0.02), use_bias=True, padding='same', name='hm.%d.0.conv' % hgid)(rf1) 106 | y1 = Activation('relu', name='hm.%d.0.relu' % hgid)(y1) 107 | y1 = Conv2D(num_classes, 1, use_bias=True, name='hm.%d.1' % hgid, activation = "sigmoid")(y1) 108 | 109 | y2 = Conv2D(256, 3, kernel_initializer=RandomNormal(stddev=0.02), use_bias=True, padding='same', name='wh.%d.0.conv' % hgid)(rf1) 110 | y2 = Activation('relu', name='wh.%d.0.relu' % hgid)(y2) 111 | y2 = Conv2D(2, 1, use_bias=True, name='wh.%d.1' % hgid)(y2) 112 | 113 | y3 = Conv2D(256, 3, kernel_initializer=RandomNormal(stddev=0.02), use_bias=True, padding='same', name='reg.%d.0.conv' % hgid)(rf1) 114 | y3 = Activation('relu', name='reg.%d.0.relu' % hgid)(y3) 115 | y3 = Conv2D(2, 1, use_bias=True, name='reg.%d.1' % hgid)(y3) 116 | 117 | return [y1,y2,y3] 118 | 119 | def hourglass_module(num_classes, bottom, cnv_dim, hgid, dims): 120 | # 左边下采样的部分 121 | lfs = left_features(bottom, hgid, dims) 122 | 123 | # 右边上采样与中间的连接部分 124 | rf1 = right_features(lfs, hgid, dims) 125 | rf1 = conv2d(rf1, 3, cnv_dim, name='cnvs.%d' % hgid) 126 | 127 | heads = create_heads(num_classes, rf1, hgid) 128 | return heads, rf1 129 | 130 | 131 | def HourglassNetwork(inpnuts, num_stacks, num_classes, cnv_dim=256, dims=[256, 384, 384, 384, 512]): 132 | inter = pre(inpnuts, cnv_dim) 133 | outputs = [] 134 | for i in range(num_stacks): 135 | prev_inter = inter 136 | _heads, inter = hourglass_module(num_classes, inter, cnv_dim, i, dims) 137 | outputs.append(_heads) 138 | if i < num_stacks - 1: 139 | inter_ = Conv2D(cnv_dim, 1, kernel_initializer=RandomNormal(stddev=0.02), use_bias=False, name='inter_.%d.0' % i)(prev_inter) 140 | inter_ = BatchNormalization(epsilon=1e-5, name='inter_.%d.1' % i)(inter_) 141 | 142 | cnv_ = Conv2D(cnv_dim, 1, kernel_initializer=RandomNormal(stddev=0.02), use_bias=False, name='cnv_.%d.0' % i)(inter) 143 | cnv_ = BatchNormalization(epsilon=1e-5, name='cnv_.%d.1' % i)(cnv_) 144 | 145 | inter = Add(name='inters.%d.inters.add' % i)([inter_, cnv_]) 146 | inter = Activation('relu', name='inters.%d.inters.relu' % i)(inter) 147 | inter = residual(inter, cnv_dim, 'inters.%d' % i) 148 | return outputs 149 | 150 | if __name__ == "__main__": 151 | image_input = Input(shape=(512, 512, 3)) 152 | outputs = HourglassNetwork(image_input,2,20) 153 | model = Model(image_input,outputs[-1]) 154 | model.summary() 155 | 156 | -------------------------------------------------------------------------------- /PKLot_server/nets/centernet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.layers import Input, Lambda, MaxPooling2D 3 | from tensorflow.keras.models import Model 4 | 5 | from nets.centernet_training import loss 6 | from nets.hourglass import HourglassNetwork 7 | from nets.resnet import ResNet50, centernet_head 8 | 9 | 10 | def nms(heat, kernel=3): 11 | hmax = MaxPooling2D((kernel, kernel), strides=1, padding='SAME')(heat) 12 | heat = tf.where(tf.equal(hmax, heat), heat, tf.zeros_like(heat)) 13 | return heat 14 | 15 | def topk(hm, max_objects=100): 16 | #-------------------------------------------------------------------------# 17 | # 当利用512x512x3图片进行coco数据集预测的时候 18 | # h = w = 128 num_classes = 80 19 | # Hot map热力图 -> b, 128, 
128, 80 20 | # 进行热力图的非极大抑制,利用3x3的卷积对热力图进行最大值筛选 21 | # 找出一定区域内,得分最大的特征点。 22 | #-------------------------------------------------------------------------# 23 | hm = nms(hm) 24 | b, h, w, c = tf.shape(hm)[0], tf.shape(hm)[1], tf.shape(hm)[2], tf.shape(hm)[3] 25 | #-------------------------------------------# 26 | # 将所有结果平铺,获得(b, 128 * 128 * 80) 27 | #-------------------------------------------# 28 | hm = tf.reshape(hm, (b, -1)) 29 | #-----------------------------# 30 | # (b, k), (b, k) 31 | #-----------------------------# 32 | scores, indices = tf.math.top_k(hm, k=max_objects, sorted=True) 33 | 34 | #--------------------------------------# 35 | # 计算求出种类、网格点以及索引。 36 | #--------------------------------------# 37 | class_ids = indices % c 38 | xs = indices // c % w 39 | ys = indices // c // w 40 | indices = ys * w + xs 41 | return scores, indices, class_ids, xs, ys 42 | 43 | def decode(hm, wh, reg, max_objects=100,num_classes=20): 44 | #-----------------------------------------------------# 45 | # hm b, 128, 128, num_classes 46 | # wh b, 128, 128, 2 47 | # reg b, 128, 128, 2 48 | # scores b, max_objects 49 | # indices b, max_objects 50 | # class_ids b, max_objects 51 | # xs b, max_objects 52 | # ys b, max_objects 53 | #-----------------------------------------------------# 54 | scores, indices, class_ids, xs, ys = topk(hm, max_objects=max_objects) 55 | b = tf.shape(hm)[0] 56 | 57 | #-----------------------------------------------------# 58 | # wh b, 128 * 128, 2 59 | # reg b, 128 * 128, 2 60 | #-----------------------------------------------------# 61 | reg = tf.reshape(reg, [b, -1, 2]) 62 | wh = tf.reshape(wh, [b, -1, 2]) 63 | length = tf.shape(wh)[1] 64 | 65 | #-----------------------------------------------------# 66 | # 找到其在1维上的索引 67 | # batch_idx b, max_objects 68 | #-----------------------------------------------------# 69 | batch_idx = tf.expand_dims(tf.range(0, b), 1) 70 | batch_idx = tf.tile(batch_idx, (1, max_objects)) 71 | full_indices = tf.reshape(batch_idx, [-1]) * tf.cast(length, tf.int32) + tf.reshape(indices, [-1]) 72 | 73 | #-----------------------------------------------------# 74 | # 取出top_k个框对应的参数 75 | #-----------------------------------------------------# 76 | topk_reg = tf.gather(tf.reshape(reg, [-1,2]), full_indices) 77 | topk_reg = tf.reshape(topk_reg, [b, -1, 2]) 78 | 79 | topk_wh = tf.gather(tf.reshape(wh, [-1,2]), full_indices) 80 | topk_wh = tf.reshape(topk_wh, [b, -1, 2]) 81 | 82 | #-----------------------------------------------------# 83 | # 利用参数获得调整后预测框的中心 84 | # topk_cx b,k,1 85 | # topk_cy b,k,1 86 | #-----------------------------------------------------# 87 | topk_cx = tf.cast(tf.expand_dims(xs, axis=-1), tf.float32) + topk_reg[..., 0:1] 88 | topk_cy = tf.cast(tf.expand_dims(ys, axis=-1), tf.float32) + topk_reg[..., 1:2] 89 | 90 | #-----------------------------------------------------# 91 | # 计算预测框左上角和右下角 92 | # topk_x1 b,k,1 预测框左上角x轴坐标 93 | # topk_y1 b,k,1 预测框左上角y轴坐标 94 | # topk_x2 b,k,1 预测框右下角x轴坐标 95 | # topk_y2 b,k,1 预测框右下角y轴坐标 96 | #-----------------------------------------------------# 97 | topk_x1, topk_y1 = topk_cx - topk_wh[..., 0:1] / 2, topk_cy - topk_wh[..., 1:2] / 2 98 | topk_x2, topk_y2 = topk_cx + topk_wh[..., 0:1] / 2, topk_cy + topk_wh[..., 1:2] / 2 99 | 100 | #-----------------------------------------------------# 101 | # scores b,k,1 预测框得分 102 | # class_ids b,k,1 预测框种类 103 | #-----------------------------------------------------# 104 | scores = tf.expand_dims(scores, axis=-1) 105 | class_ids = tf.cast(tf.expand_dims(class_ids, axis=-1), 
tf.float32) 106 | 107 | #-----------------------------------------------------# 108 | # detections 预测框所有参数的堆叠 109 | # 前四个是预测框的坐标,后两个是预测框的得分与种类 110 | #-----------------------------------------------------# 111 | detections = tf.concat([topk_x1, topk_y1, topk_x2, topk_y2, scores, class_ids], axis=-1) 112 | 113 | return detections 114 | 115 | 116 | def centernet(input_shape, num_classes, backbone='resnet50', max_objects=100, mode="train", num_stacks=2): 117 | assert backbone in ['resnet50', 'hourglass'] 118 | output_size = input_shape[0] // 4 119 | image_input = Input(shape=input_shape) 120 | hm_input = Input(shape=(output_size, output_size, num_classes)) 121 | wh_input = Input(shape=(max_objects, 2)) 122 | reg_input = Input(shape=(max_objects, 2)) 123 | reg_mask_input = Input(shape=(max_objects,)) 124 | index_input = Input(shape=(max_objects,)) 125 | 126 | if backbone=='resnet50': 127 | #-----------------------------------# 128 | # 对输入图片进行特征提取 129 | # 512, 512, 3 -> 16, 16, 2048 130 | #-----------------------------------# 131 | C5 = ResNet50(image_input) 132 | #--------------------------------------------------------------------------------------------------------# 133 | # 对获取到的特征进行上采样,进行分类预测和回归预测 134 | # 16, 16, 2048 -> 32, 32, 256 -> 64, 64, 128 -> 128, 128, 64 -> 128, 128, 64 -> 128, 128, num_classes 135 | # -> 128, 128, 64 -> 128, 128, 2 136 | # -> 128, 128, 64 -> 128, 128, 2 137 | #--------------------------------------------------------------------------------------------------------# 138 | y1, y2, y3 = centernet_head(C5,num_classes) 139 | 140 | if mode=="train": 141 | loss_ = Lambda(loss, name='centernet_loss')([y1, y2, y3, hm_input, wh_input, reg_input, reg_mask_input, index_input]) 142 | model = Model(inputs=[image_input, hm_input, wh_input, reg_input, reg_mask_input, index_input], outputs=[loss_]) 143 | return model 144 | else: 145 | detections = Lambda(lambda x: decode(*x, max_objects=max_objects, 146 | num_classes=num_classes))([y1, y2, y3]) 147 | prediction_model = Model(inputs=image_input, outputs=detections) 148 | return prediction_model 149 | 150 | else: 151 | outs = HourglassNetwork(image_input,num_stacks,num_classes) 152 | 153 | if mode=="train": 154 | loss_all = [] 155 | for out in outs: 156 | y1, y2, y3 = out 157 | loss_ = Lambda(loss)([y1, y2, y3, hm_input, wh_input, reg_input, reg_mask_input, index_input]) 158 | loss_all.append(loss_) 159 | loss_all = Lambda(tf.reduce_mean,name='centernet_loss')(loss_all) 160 | 161 | model = Model(inputs=[image_input, hm_input, wh_input, reg_input, reg_mask_input, index_input], outputs=loss_all) 162 | return model 163 | else: 164 | y1, y2, y3 = outs[-1] 165 | detections = Lambda(lambda x: decode(*x, max_objects=max_objects, 166 | num_classes=num_classes))([y1, y2, y3]) 167 | prediction_model = Model(inputs=image_input, outputs=[detections]) 168 | return prediction_model 169 | -------------------------------------------------------------------------------- /PKLot_server/utils/utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | from PIL import Image 6 | 7 | def letterbox_image(image, size): 8 | iw, ih = image.size 9 | w, h = size 10 | scale = min(w/iw, h/ih) 11 | nw = int(iw*scale) 12 | nh = int(ih*scale) 13 | 14 | image = image.resize((nw,nh), Image.BICUBIC) 15 | new_image = Image.new('RGB', size, (128,128,128)) 16 | new_image.paste(image, ((w-nw)//2, (h-nh)//2)) 17 | return new_image 18 | 19 | def 
20 |     new_shape = image_shape*np.min(input_shape/image_shape)
21 | 
22 |     offset = (input_shape-new_shape)/2./input_shape
23 |     scale = input_shape/new_shape
24 | 
25 |     box_yx = np.concatenate(((top+bottom)/2,(left+right)/2),axis=-1)
26 |     box_hw = np.concatenate((bottom-top,right-left),axis=-1)
27 | 
28 |     box_yx = (box_yx - offset) * scale
29 |     box_hw *= scale
30 | 
31 |     box_mins = box_yx - (box_hw / 2.)
32 |     box_maxes = box_yx + (box_hw / 2.)
33 |     boxes = np.concatenate([
34 |         box_mins[:, 0:1],
35 |         box_mins[:, 1:2],
36 |         box_maxes[:, 0:1],
37 |         box_maxes[:, 1:2]
38 |     ],axis=-1)
39 |     boxes *= np.concatenate([image_shape, image_shape],axis=-1)
40 |     return boxes
41 | 
42 | def nms(results, nms_thresh):
43 |     outputs = []
44 |     # Process every image in the batch
45 |     for i in range(len(results)):
46 |         #------------------------------------------------#
47 |         #   For a walkthrough of the procedure see
48 |         #   https://www.bilibili.com/video/BV1Lz411B7nQ
49 |         #------------------------------------------------#
50 |         detections = results[i]
51 |         unique_class = np.unique(detections[:,-1])
52 | 
53 |         best_box = []
54 |         if len(unique_class) == 0:
55 |             outputs.append(best_box)
56 |             continue
57 |         # Loop over the classes:
58 |         # NMS keeps, within a given region, only the highest-scoring box of a single class,
59 |         # so looping over the classes lets us suppress each class independently.
60 |         for c in unique_class:
61 |             cls_mask = detections[:,-1] == c
62 | 
63 |             detection = detections[cls_mask]
64 |             scores = detection[:,4]
65 |             # Sort this class by score, highest first.
66 |             arg_sort = np.argsort(scores)[::-1]
67 |             detection = detection[arg_sort]
68 |             while np.shape(detection)[0]>0:
69 |                 # Take the highest-scoring box, compute its overlap with all remaining boxes, and drop those that overlap too much.
70 |                 best_box.append(detection[0])
71 |                 if len(detection) == 1:
72 |                     break
73 |                 ious = iou(best_box[-1],detection[1:])
74 |                 detection = detection[1:][ious < nms_thresh]
75 |         outputs.append(best_box)
76 |     return outputs
77 | 
78 | def iou(b1, b2):
79 |     b1_x1, b1_y1, b1_x2, b1_y2 = b1[0], b1[1], b1[2], b1[3]
80 |     b2_x1, b2_y1, b2_x2, b2_y2 = b2[:, 0], b2[:, 1], b2[:, 2], b2[:, 3]
81 | 
82 |     inter_rect_x1 = np.maximum(b1_x1, b2_x1)
83 |     inter_rect_y1 = np.maximum(b1_y1, b2_y1)
84 |     inter_rect_x2 = np.minimum(b1_x2, b2_x2)
85 |     inter_rect_y2 = np.minimum(b1_y2, b2_y2)
86 | 
87 |     inter_area = np.maximum(inter_rect_x2 - inter_rect_x1, 0) * np.maximum(inter_rect_y2 - inter_rect_y1, 0)
88 | 
89 |     area_b1 = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
90 |     area_b2 = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
91 | 
92 |     iou = inter_area / np.maximum(area_b1 + area_b2 - inter_area, 1e-6)
93 |     return iou
94 | 
95 | def draw_gaussian(heatmap, center, radius, k=1):
96 |     diameter = 2 * radius + 1
97 |     gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)
98 | 
99 |     x, y = int(center[0]), int(center[1])
100 | 
101 |     height, width = heatmap.shape[0:2]
102 | 
103 |     left, right = min(x, radius), min(width - x, radius + 1)
104 |     top, bottom = min(y, radius), min(height - y, radius + 1)
105 | 
106 |     masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
107 |     masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]
108 | 
109 |     if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: # TODO debug
110 |         np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
111 |     return heatmap
112 | 
113 | def gaussian2D(shape, sigma=1):
114 |     m, n = [(ss - 1.) / 2. for ss in shape]
115 |     y, x = np.ogrid[-m:m + 1, -n:n + 1]
116 | 
117 |     h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
118 |     h[h < np.finfo(h.dtype).eps * h.max()] = 0
119 |     return h
120 | 
121 | def gaussian_radius(det_size, min_overlap=0.7):
122 |     height, width = det_size
123 | 
124 |     a1 = 1
125 |     b1 = (height + width)
126 |     c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
127 |     sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
128 |     r1 = (b1 + sq1) / 2
129 | 
130 |     a2 = 4
131 |     b2 = 2 * (height + width)
132 |     c2 = (1 - min_overlap) * width * height
133 |     sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
134 |     r2 = (b2 + sq2) / 2
135 | 
136 |     a3 = 4 * min_overlap
137 |     b3 = -2 * min_overlap * (height + width)
138 |     c3 = (min_overlap - 1) * width * height
139 |     sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
140 |     r3 = (b3 + sq3) / 2
141 |     return min(r1, r2, r3)
142 | 
143 | 
144 | class ModelCheckpoint(tf.keras.callbacks.Callback):
145 |     def __init__(self, filepath, monitor='val_loss', verbose=0,
146 |                  save_best_only=False, save_weights_only=False,
147 |                  mode='auto', period=1):
148 |         super(ModelCheckpoint, self).__init__()
149 |         self.monitor = monitor
150 |         self.verbose = verbose
151 |         self.filepath = filepath
152 |         self.save_best_only = save_best_only
153 |         self.save_weights_only = save_weights_only
154 |         self.period = period
155 |         self.epochs_since_last_save = 0
156 | 
157 |         if mode not in ['auto', 'min', 'max']:
158 |             warnings.warn('ModelCheckpoint mode %s is unknown, '
159 |                           'fallback to auto mode.'
% (mode), 160 | RuntimeWarning) 161 | mode = 'auto' 162 | 163 | if mode == 'min': 164 | self.monitor_op = np.less 165 | self.best = np.Inf 166 | elif mode == 'max': 167 | self.monitor_op = np.greater 168 | self.best = -np.Inf 169 | else: 170 | if 'acc' in self.monitor or self.monitor.startswith('fmeasure'): 171 | self.monitor_op = np.greater 172 | self.best = -np.Inf 173 | else: 174 | self.monitor_op = np.less 175 | self.best = np.Inf 176 | 177 | def on_epoch_end(self, epoch, logs=None): 178 | logs = logs or {} 179 | self.epochs_since_last_save += 1 180 | if self.epochs_since_last_save >= self.period: 181 | self.epochs_since_last_save = 0 182 | filepath = self.filepath.format(epoch=epoch + 1, **logs) 183 | if self.save_best_only: 184 | current = logs.get(self.monitor) 185 | if current is None: 186 | warnings.warn('Can save best model only with %s available, ' 187 | 'skipping.' % (self.monitor), RuntimeWarning) 188 | else: 189 | if self.monitor_op(current, self.best): 190 | if self.verbose > 0: 191 | print('\nEpoch %05d: %s improved from %0.5f to %0.5f,' 192 | ' saving model to %s' 193 | % (epoch + 1, self.monitor, self.best, 194 | current, filepath)) 195 | self.best = current 196 | if self.save_weights_only: 197 | self.model.save_weights(filepath, overwrite=True) 198 | else: 199 | self.model.save(filepath, overwrite=True) 200 | else: 201 | if self.verbose > 0: 202 | print('\nEpoch %05d: %s did not improve' % 203 | (epoch + 1, self.monitor)) 204 | else: 205 | if self.verbose > 0: 206 | print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath)) 207 | if self.save_weights_only: 208 | self.model.save_weights(filepath, overwrite=True) 209 | else: 210 | self.model.save(filepath, overwrite=True) 211 | -------------------------------------------------------------------------------- /PKLot_server/predictUI.py: -------------------------------------------------------------------------------- 1 | from MainWindow import Ui_MainWindow 2 | from PyQt5.QtWidgets import QWidget, QMainWindow, QApplication, QGraphicsScene, QGraphicsPixmapItem, QFileDialog, QMessageBox 3 | from PyQt5.QtGui import QImage, QPixmap 4 | import sys 5 | import datetime 6 | import cv2 7 | import numpy as np 8 | from PIL import ImageDraw, ImageFont 9 | from centernet import CenterNet 10 | from DataServer import FormetDayTime 11 | from threading import Thread, Lock 12 | import socket 13 | from PIL import Image 14 | import matplotlib.pyplot as plt 15 | 16 | plt.rcParams["font.sans-serif"] = ["SimHei"] 17 | plt.rcParams["axes.unicode_minus"] = False 18 | 19 | centernet = CenterNet() 20 | lock = Lock() 21 | 22 | class predictWindow(QMainWindow, Ui_MainWindow): 23 | def __init__(self): 24 | super().__init__() 25 | self.setupUi(self) 26 | img_src = cv2.imread("model_data/timg.jpg") # 读取图像 27 | img_src = cv2.cvtColor(img_src, cv2.COLOR_BGR2RGB) # 转换图像通道 28 | label_width = self.label.width() 29 | label_height = self.label.height() 30 | temp_imgSrc = QImage(img_src[:], img_src.shape[1], img_src.shape[0], img_src.shape[1] * 3, QImage.Format_RGB888) 31 | # 将图片转换为QPixmap方便显示 32 | self.pixmap_imgSrc = QPixmap.fromImage(temp_imgSrc).scaled(label_width, label_height) 33 | now = datetime.datetime.now() 34 | time = now.strftime("%Y %m %d %H %M %S") 35 | timelist = time.split(" ") 36 | self.mon = int(timelist[1]) 37 | self.isOn = False 38 | self.isVideoOn = False 39 | self.isImgOn = False 40 | self.day, self.time = FormetDayTime(timelist) 41 | self.label.setPixmap(QPixmap(self.pixmap_imgSrc)) 42 | self.pushButton.clicked.connect(self.CountTime) 
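        # ------------------------------------------------------------------
        # Illustrative sketch (not used anywhere in this project): the
        # ndarray -> QPixmap conversion above is repeated in every detection
        # handler below and could be factored into a helper like this.
        # `to_pixmap` is a hypothetical name, not part of the repository.
        # ------------------------------------------------------------------
        def to_pixmap(img_rgb, width, height):
            # img_rgb: HxWx3 uint8 array in RGB channel order
            qimg = QImage(img_rgb[:], img_rgb.shape[1], img_rgb.shape[0],
                          img_rgb.shape[1] * 3, QImage.Format_RGB888)
            return QPixmap.fromImage(qimg).scaled(width, height)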
43 | self.pushButton_1.clicked.connect(self.imgOnOff) 44 | self.pushButton_2.clicked.connect(self.videoOnOff) 45 | self.pushButton_3.clicked.connect(self.On_Off) 46 | 47 | def CountTime(self): 48 | try: 49 | tco = np.zeros(24) 50 | tce = np.zeros(24) 51 | lock.acquire() 52 | with open("./detect_Logs/countlogs") as f: 53 | for fn in f.readlines(): 54 | print(fn) 55 | tco[int(fn.split(';')[0])] += int(fn.split(";")[1]) 56 | tce[int(fn.split(';')[0])] += int(fn.split(";")[2].strip()) 57 | 58 | lock.release() 59 | label = [] 60 | for i in range(24): 61 | label.append(str(i)) 62 | rate = tco / (tco + tce) 63 | plt.plot(range(24), rate, label="时段内车位占用比") 64 | plt.xticks(range(24), label) 65 | plt.xlim((0, 23)) 66 | plt.xlabel("小时") 67 | plt.ylabel("车位占有率") 68 | plt.legend() 69 | plt.title("时段车位统计") 70 | plt.show() 71 | except: 72 | msg_box = QMessageBox(QMessageBox.Warning, 'Warning', '数据加载失败!') 73 | msg_box.exec_() 74 | 75 | def imgOnOff(self): 76 | self.isImgOn = ~self.isImgOn 77 | if not self.isImgOn: 78 | self.pushButton_1.setText("图片检测") 79 | else: 80 | self.pushButton_1.setText("完成图片检测") 81 | 82 | if self.isImgOn: 83 | try: 84 | image_file, _ = QFileDialog.getOpenFileName(self, 'Open file', '\\', 'Image files (*.jpg *.gif *.png *.jpeg)') 85 | except: 86 | msg_box = QMessageBox(QMessageBox.Warning, 'Warning', '未选中图片!') 87 | msg_box.exec_() 88 | print('Open Error! Try again!') 89 | return 90 | 91 | try: 92 | image = Image.open(image_file) 93 | r_image, obj1sum, obj2sum, color1, color2, max_cnt = centernet.detect_image(image) 94 | 95 | draw = ImageDraw.Draw(r_image) 96 | fontStyle = ImageFont.truetype( 97 | font="model_data/simhei.ttf", size=20, encoding='utf-8') 98 | 99 | # # 绘制框和文本 100 | # draw.rectangle( 101 | # [tuple((0, 560)), tuple((180, 640))], 102 | # fill=(255, 255, 255), outline='black') 103 | 104 | draw.ellipse((10, 560, 30, 580), fill=color1) 105 | draw.ellipse((10, 600, 30, 620), fill=color2) 106 | draw.text((50, 560), "被占车位:" + str(obj1sum), color1, font=fontStyle) 107 | draw.text((50, 600), "空车位 :" + str(obj2sum), color2, font=fontStyle) 108 | 109 | r_image = np.array(r_image) 110 | # RGBtoBGR满足opencv显示格式 111 | img_src = r_image # cv2.cvtColor(r_image, cv2.COLOR_RGB2BGR) 112 | label_width = self.label.width() 113 | label_height = self.label.height() 114 | temp_imgSrc = QImage(img_src[:], img_src.shape[1], img_src.shape[0], img_src.shape[1] * 3, 115 | QImage.Format_RGB888) 116 | 117 | # 将图片转换为QPixmap方便显示 118 | pixmap_imgSrc = QPixmap.fromImage(temp_imgSrc).scaled(label_width, label_height) 119 | self.label.setPixmap(QPixmap(pixmap_imgSrc)) 120 | 121 | except: 122 | msg_box = QMessageBox(QMessageBox.Warning, 'Warning', '识别失败!') 123 | msg_box.exec_() 124 | else: 125 | self.label.setPixmap(QPixmap(self.pixmap_imgSrc)) 126 | 127 | 128 | def videoOnOff(self): 129 | self.isVideoOn = ~self.isVideoOn 130 | if not self.isVideoOn: 131 | self.pushButton_2.setText("视频检测") 132 | else: 133 | self.pushButton_2.setText("停止视频检测") 134 | if self.isVideoOn: 135 | try: 136 | image_file, _ = QFileDialog.getOpenFileName(self, 'Open file', '\\', 137 | 'Video files (*.gif *.mp4 *.avi *.dat *.mkv *.flv *.vob *.3gp)') 138 | capture = cv2.VideoCapture(image_file) 139 | except: 140 | msg_box = QMessageBox(QMessageBox.Warning, 'Warning', '未选中视频!') 141 | msg_box.exec_() 142 | print('Open Error! 
Try again!') 143 | return 144 | 145 | try: 146 | global centernet 147 | while self.isVideoOn: 148 | ref, frame = capture.read() 149 | 150 | now = datetime.datetime.now() 151 | time = now.strftime("%Y %m %d %H %M %S") 152 | timelist = time.split(" ") 153 | self.day, self.time = FormetDayTime(timelist) 154 | 155 | # 转变成Image 156 | frame = Image.fromarray(np.uint8(frame)) 157 | # 进行检测 158 | frame, obj1sum, obj2sum, color1, color2, max_cnt = centernet.detect_image(frame) 159 | 160 | lock.acquire() 161 | with open("./detect_Logs/logs.txt", 'w') as f: 162 | f.write(self.day + " " + self.time + ';') 163 | f.write(str(obj1sum)+';') 164 | f.write(str(obj2sum)+";") 165 | f.write(str(max_cnt)) 166 | 167 | with open("./detect_Logs/countlogs", 'a') as f: 168 | f.write(self.time.split(":")[0] + ';') 169 | f.write(str(obj1sum)+';') 170 | f.write(str(obj2sum)) 171 | f.write("\n") 172 | lock.release() 173 | # 设置 174 | draw = ImageDraw.Draw(frame) 175 | fontStyle = ImageFont.truetype( 176 | font="model_data/simhei.ttf", size=20, encoding='utf-8') 177 | 178 | # 绘制框和文本 179 | # draw.rectangle( 180 | # [tuple((0, 560)), tuple((180, 640))], 181 | # fill=(255, 255, 255), outline='black') 182 | 183 | draw.ellipse((10, 560, 30, 580), fill=color1) 184 | draw.ellipse((10, 600, 30, 620), fill=color2) 185 | draw.text((50, 560), "被占车位:" + str(obj1sum), color1, font=fontStyle) 186 | draw.text((50, 600), "空车位 :" + str(obj2sum), color2, font=fontStyle) 187 | draw.text((0, 0), "时间:" + self.day + " " + self.time, (255, 255, 255), font=fontStyle) 188 | 189 | frame.save("detect_Logs/1.png") 190 | frame = np.array(frame) 191 | 192 | img_src = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) 193 | 194 | label_width = self.label.width() 195 | label_height = self.label.height() 196 | temp_imgSrc = QImage(img_src[:], img_src.shape[1], img_src.shape[0], img_src.shape[1] * 3, 197 | QImage.Format_RGB888) 198 | 199 | # 将图片转换为QPixmap方便显示 200 | pixmap_imgSrc = QPixmap.fromImage(temp_imgSrc).scaled(label_width, label_height) 201 | self.label.setPixmap(QPixmap(pixmap_imgSrc)) 202 | 203 | c = cv2.waitKey(1) & 0xff 204 | 205 | if c == 27: 206 | capture.release() 207 | break 208 | 209 | capture.release() 210 | cv2.destroyAllWindows() 211 | 212 | 213 | except: 214 | self.label.setPixmap(QPixmap(self.pixmap_imgSrc)) 215 | else: 216 | self.label.setPixmap(QPixmap(self.pixmap_imgSrc)) 217 | 218 | def On_Off(self): 219 | self.isOn = ~self.isOn 220 | if not self.isOn: 221 | self.pushButton_3.setText("实时监测") 222 | else: 223 | self.pushButton_3.setText("结束实时监测") 224 | if self.isOn: 225 | capture = cv2.VideoCapture(0) 226 | global centernet 227 | while self.isOn: 228 | # 获取当前时间 229 | now = datetime.datetime.now() 230 | time = now.strftime("%Y %m %d %H %M %S") 231 | timelist = time.split(" ") 232 | self.day, self.time = FormetDayTime(timelist) 233 | 234 | # 读取某一帧 235 | ref, frame = capture.read() 236 | # 转变成Image 237 | frame = Image.fromarray(np.uint8(frame)) 238 | # 进行检测 239 | frame, obj1sum, obj2sum, color1, color2, max_cnt = centernet.detect_image(frame) 240 | 241 | # 设置 242 | draw = ImageDraw.Draw(frame) 243 | fontStyle = ImageFont.truetype( 244 | font="model_data/simhei.ttf", size=20, encoding='utf-8') 245 | 246 | # 绘制框和文本 247 | # draw.rectangle( 248 | # [tuple((0, 0)), tuple((180, 40))], 249 | # fill=(255, 255, 255), outline='black') 250 | # draw.rectangle( 251 | # [tuple((0, 0)), tuple((420, 40))], 252 | # fill=(255, 255, 255), outline='black') 253 | 254 | draw.ellipse((35, 410, 55, 430), fill=color1) 255 | draw.ellipse((35, 440, 55, 460), fill=color2) 
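            # --------------------------------------------------------------
            # Record layouts of the two log files, as inferred from the
            # writes in videoOnOff above and just below in this loop:
            #   detect_Logs/logs.txt  (overwritten each frame):
            #       "<day> <time>;<occupied>;<empty>;<quadrant with most empty spaces>"
            #   detect_Logs/countlogs (appended each frame):
            #       "<hour>;<occupied>;<empty>"
            # CountTime above aggregates countlogs per hour to plot occupancy.
            # --------------------------------------------------------------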
256 | draw.text((60, 410), "被占车位:" + str(obj1sum), fill=color1, font=fontStyle) 257 | draw.text((60, 440), "空车位 :" + str(obj2sum), fill=color2, font=fontStyle) 258 | draw.text((0, 0), "时间:" + self.day + " " + self.time, (0, 0, 0), font=fontStyle) 259 | lock.acquire() 260 | with open("./detect_Logs/logs.txt", 'w') as f: 261 | f.write(self.day + " " + self.time + ';') 262 | f.write(str(obj1sum) + ';') 263 | f.write(str(obj2sum)+";") 264 | f.write(str(max_cnt)) 265 | 266 | with open("./detect_Logs/countlogs", 'a') as f: 267 | f.write(self.time.split(":")[0] + ';') 268 | f.write(str(obj1sum) + ';') 269 | f.write(str(obj2sum)) 270 | f.write("\n") 271 | 272 | lock.release() 273 | frame.save("detect_Logs/1.png") 274 | frame = np.array(frame) 275 | 276 | # RGBtoBGR满足opencv显示格式 277 | img_src = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) 278 | label_width = self.label.width() 279 | label_height = self.label.height() 280 | temp_imgSrc = QImage(img_src[:], img_src.shape[1], img_src.shape[0], img_src.shape[1] * 3, 281 | QImage.Format_RGB888) 282 | 283 | # 将图片转换为QPixmap方便显示 284 | pixmap_imgSrc = QPixmap.fromImage(temp_imgSrc).scaled(label_width, label_height) 285 | self.label.setPixmap(QPixmap(pixmap_imgSrc)) 286 | 287 | # if int(timelist[4]) % 5 == 0 and int(timelist[5]) == 0: 288 | # # 每隔五分钟保存一次 289 | # SaveData(self.day, self.time, num) 290 | c = cv2.waitKey(1) & 0xff 291 | 292 | if c == 27: 293 | capture.release() 294 | break 295 | capture.release() 296 | cv2.destroyAllWindows() 297 | else: 298 | self.label.setPixmap(QPixmap(self.pixmap_imgSrc)) 299 | 300 | 301 | def detect_process(): 302 | app = QApplication(sys.argv) 303 | window = predictWindow() 304 | window.show() 305 | sys.exit(app.exec_()) 306 | 307 | def response_process(): 308 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 309 | sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 310 | sock.bind(('127.0.0.1', 8000)) 311 | sock.listen(5) 312 | while 1: 313 | cli_sock, cli_addr = sock.accept() 314 | req = cli_sock.recv(4096) 315 | lock.acquire() 316 | with open('detect_Logs/logs.txt') as f: 317 | word = f.readline() 318 | cli_sock.send(word.encode()) 319 | cli_sock.close() 320 | lock.release() 321 | 322 | if __name__ == "__main__": 323 | t1 = Thread(target=detect_process) 324 | t2 = Thread(target=response_process) 325 | t1.start() 326 | t2.start() 327 | t1.join() 328 | t2.join() 329 | -------------------------------------------------------------------------------- /PKLot_server/centernet.py: -------------------------------------------------------------------------------- 1 | import colorsys 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from PIL import ImageDraw, ImageFont 8 | 9 | from nets.centernet import centernet 10 | from utils.utils import centernet_correct_boxes, letterbox_image, nms 11 | 12 | 13 | def preprocess_image(image): 14 | mean = [0.40789655, 0.44719303, 0.47026116] 15 | std = [0.2886383 , 0.27408165, 0.27809834] 16 | return ((np.float32(image) / 255.) - mean) / std 17 | 18 | #--------------------------------------------# 19 | # 使用自己训练好的模型预测需要修改3个参数 20 | # model_path、classes_path和backbone 21 | # 都需要修改! 
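# (As implemented below: opt = 0 selects the hourglass backbone with the
#  COCO-style weights and class list, opt = 1 the resnet50 backbone with
#  the VOC-style ones.)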
22 | # 如果出现shape不匹配,一定要注意 23 | # 训练时的model_path和classes_path参数的修改 24 | #--------------------------------------------# 25 | 26 | opt = 0 27 | if opt == 0: 28 | m_p = 'model_data/centernet_hourglass_coco.h5' 29 | c_p = 'model_data/coco_classes.txt' 30 | backb = 'hourglass' 31 | ns = False 32 | 33 | else: 34 | m_p = 'model_data/centernet_resnet50_voc.h5' 35 | c_p = 'model_data/voc_classes.txt' 36 | backb = 'resnet50' 37 | ns = True 38 | 39 | class CenterNet(object): 40 | global m_p 41 | global c_p 42 | global backb 43 | global ns 44 | _defaults = { 45 | # "model_path" : 'model_data/centernet_resnet50_voc.h5', 46 | # "classes_path" : 'model_data/voc_classes.txt', 47 | # "backbone" : 'resnet50', 48 | "model_path" : m_p, #'model_data/centernet_hourglass_coco.h5', 49 | "classes_path" : c_p, #'model_data/coco_classes.txt', 50 | "backbone" : backb, # 'hourglass', 51 | "input_shape" : [512,512,3], 52 | "confidence" : 0.3, 53 | # backbone为resnet50时建议设置为True 54 | # backbone为hourglass时建议设置为False 55 | # 也可以根据检测效果自行选择 56 | # "nms" : True, 57 | "nms": ns, 58 | "nms_threhold" : 0.3, 59 | } 60 | 61 | @classmethod 62 | def get_defaults(cls, n): 63 | if n in cls._defaults: 64 | return cls._defaults[n] 65 | else: 66 | return "Unrecognized attribute name '" + n + "'" 67 | 68 | #---------------------------------------------------# 69 | # 初始化centernet 70 | #---------------------------------------------------# 71 | def __init__(self, **kwargs): 72 | self.__dict__.update(self._defaults) 73 | self.class_names = self._get_class() 74 | self.generate() 75 | 76 | #---------------------------------------------------# 77 | # 获得所有的分类 78 | #---------------------------------------------------# 79 | def _get_class(self): 80 | classes_path = os.path.expanduser(self.classes_path) 81 | with open(classes_path) as f: 82 | class_names = f.readlines() 83 | class_names = [c.strip() for c in class_names] 84 | return class_names 85 | 86 | #---------------------------------------------------# 87 | # 载入模型 88 | #---------------------------------------------------# 89 | def generate(self): 90 | model_path = os.path.expanduser(self.model_path) 91 | assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.' 92 | 93 | #----------------------------------------# 94 | # 计算种类数量 95 | #----------------------------------------# 96 | self.num_classes = len(self.class_names) 97 | 98 | #----------------------------------------# 99 | # 创建centernet模型 100 | #----------------------------------------# 101 | self.centernet = centernet(self.input_shape,num_classes=self.num_classes,backbone=self.backbone,mode='predict') 102 | self.centernet.load_weights(self.model_path, by_name=True) 103 | 104 | print('{} model, anchors, and classes loaded.'.format(self.model_path)) 105 | 106 | # 画框设置不同的颜色 107 | hsv_tuples = [(x / len(self.class_names), 1., 1.) 
108 |                       for x in range(len(self.class_names))]
109 |         self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
110 |         self.colors = list(
111 |             map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
112 |                 self.colors))
113 | 
114 |     @tf.function
115 |     def get_pred(self, photo):
116 |         preds = self.centernet(photo, training=False)
117 |         return preds
118 |     #---------------------------------------------------#
119 |     #   Detect an image
120 |     #---------------------------------------------------#
121 |     def detect_image(self, image):
122 |         x_max_shape, y_max_shape = self.input_shape[0], self.input_shape[1]
123 |         #---------------------------------------------------------#
124 |         #   Convert the image to RGB here so grayscale inputs do not crash the prediction.
125 |         #---------------------------------------------------------#
126 |         image = image.convert('RGB')
127 | 
128 |         image_shape = np.array(np.shape(image)[0:2])
129 |         #---------------------------------------------------------#
130 |         #   Add gray bars to the image for a distortion-free resize.
131 |         #---------------------------------------------------------#
132 |         crop_img = letterbox_image(image, [self.input_shape[0],self.input_shape[1]])
133 |         #----------------------------------------------------------------------------------#
134 |         #   Convert RGB to BGR, because the original centernet_hourglass weights were trained on BGR images.
135 |         #----------------------------------------------------------------------------------#
136 |         photo = np.array(crop_img,dtype = np.float32)[:,:,::-1]
137 |         #-----------------------------------------------------------#
138 |         #   Preprocess and normalize the image; the resulting photo has shape [1, 512, 512, 3]
139 |         #-----------------------------------------------------------#
140 |         photo = np.reshape(preprocess_image(photo),[1,self.input_shape[0],self.input_shape[1],self.input_shape[2]])
141 | 
142 |         preds = self.get_pred(photo).numpy()
143 |         #-------------------------------------------------------#
144 |         #   For CenterNet, establishing the center point is crucial.
145 |         #   Large objects contain a lot of local information,
146 |         #   so the center of a single large object is hard to pin down
147 |         #   and the max-pooling style NMS cannot remove the local boxes.
148 |         #   That is why an extra pass of box-level NMS is implemented here as well.
149 |         #   In practice, with the hourglass backbone the extra NMS changes little; with resnet the difference is larger.
150 |         #-------------------------------------------------------#
151 |         if self.nms:
152 |             preds = np.array(nms(preds,self.nms_threhold))
153 | 
154 |         if len(preds[0])<=0:
155 |             return image, 0, 0, self.colors[0], self.colors[1], 0  # keep the six-value contract expected by callers
156 | 
157 |         #-----------------------------------------------------------#
158 |         #   Convert the predictions to fractional (0-1) coordinates.
159 |         #-----------------------------------------------------------#
160 |         preds[0][:,0:4] = preds[0][:,0:4]/(self.input_shape[0]/4)
161 | 
162 |         det_label = preds[0][:, -1]
163 |         det_conf = preds[0][:, -2]
164 |         det_xmin, det_ymin, det_xmax, det_ymax = preds[0][:, 0], preds[0][:, 1], preds[0][:, 2], preds[0][:, 3]
165 |         #-----------------------------------------------------------#
166 |         #   Keep only the boxes whose score exceeds confidence.
167 |         #-----------------------------------------------------------#
168 |         top_indices = [i for i, conf in enumerate(det_conf) if conf >= self.confidence]
169 |         top_conf = det_conf[top_indices]
170 |         top_objs = det_label[top_indices]
171 |         obj1sum = len(top_objs[top_objs == 0])
172 |         obj2sum = len(top_objs[top_objs == 1])
173 | 
174 |         top_label_indices = top_objs.tolist()
175 |         top_xmin, top_ymin, top_xmax, top_ymax = np.expand_dims(det_xmin[top_indices],-1),np.expand_dims(det_ymin[top_indices],-1),np.expand_dims(det_xmax[top_indices],-1),np.expand_dims(det_ymax[top_indices],-1)
176 | 
177 |         #-----------------------------------------------------------#
178 |         #   Remove the gray letterbox bars.
179 |         #-----------------------------------------------------------#
180 |         boxes = 
centernet_correct_boxes(top_ymin,top_xmin,top_ymax,top_xmax,np.array([self.input_shape[0],self.input_shape[1]]),image_shape) 181 | 182 | font = ImageFont.truetype(font='model_data/simhei.ttf',size=np.floor(3e-2 * np.shape(image)[1] + 0.5).astype('int32')) 183 | 184 | thickness = max((np.shape(image)[0] + np.shape(image)[1]) // self.input_shape[0], 1) 185 | cnt1 = 0 186 | cnt2 = 0 187 | cnt3 = 0 188 | cnt4 = 0 189 | max_add = 0 190 | max_cnt = 0 191 | for i, c in enumerate(top_label_indices): 192 | predicted_class = self.class_names[int(c)] 193 | score = top_conf[i] 194 | 195 | top, left, bottom, right = boxes[i] 196 | x_mid = (left + right) / 2 197 | y_mid = (bottom + top) / 2 198 | if predicted_class == "space-empty": 199 | if x_mid <= x_max_shape / 2 and y_mid <= y_max_shape / 2: 200 | cnt1 += 1 201 | elif x_mid >= x_max_shape / 2 and y_mid <= y_max_shape / 2: 202 | cnt2 += 1 203 | elif x_mid <= x_max_shape / 2 and y_mid >= y_max_shape / 2: 204 | cnt3 += 1 205 | else: 206 | cnt4 += 1 207 | 208 | 209 | top = top - 5 210 | left = left - 5 211 | bottom = bottom + 5 212 | right = right + 5 213 | 214 | top = max(0, np.floor(top + 0.5).astype('int32')) 215 | left = max(0, np.floor(left + 0.5).astype('int32')) 216 | bottom = min(np.shape(image)[0], np.floor(bottom + 0.5).astype('int32')) 217 | right = min(np.shape(image)[1], np.floor(right + 0.5).astype('int32')) 218 | 219 | # 画框框 220 | label = '{} {:.2f}'.format(predicted_class, score) 221 | draw = ImageDraw.Draw(image) 222 | label_size = draw.textsize(label, font) 223 | label = label.encode('utf-8') 224 | # print(label, top, left, bottom, right) 225 | 226 | if top - label_size[1] >= 0: 227 | text_origin = np.array([left, top - label_size[1]]) 228 | else: 229 | text_origin = np.array([left, top + 1]) 230 | 231 | for i in range(thickness): 232 | draw.rectangle( 233 | [left + i, top + i, right - i, bottom - i], 234 | outline=self.colors[int(c)]) 235 | # draw.rectangle( 236 | # [tuple(text_origin), tuple(text_origin + label_size)], 237 | # fill=self.colors[int(c)]) 238 | # draw.text(text_origin, str(label,'UTF-8'), fill=(0, 0, 0), font=font) 239 | del draw 240 | 241 | if cnt1 > max_add: 242 | max_add = cnt1 243 | max_cnt = 1 244 | 245 | if cnt2 > max_add: 246 | max_add = cnt2 247 | max_cnt = 2 248 | 249 | if cnt3 > max_add: 250 | max_add = cnt3 251 | max_cnt = 3 252 | 253 | if cnt4 > max_add: 254 | max_add = cnt4 255 | max_cnt = 4 256 | print(cnt1, cnt2, cnt3, cnt4) 257 | return image, obj1sum, obj2sum, self.colors[0], self.colors[1], max_cnt 258 | 259 | def draw_lines(self): 260 | pass 261 | 262 | def get_FPS(self, image, test_interval): 263 | #---------------------------------------------------------# 264 | # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。 265 | #---------------------------------------------------------# 266 | image = image.convert('RGB') 267 | 268 | image_shape = np.array(np.shape(image)[0:2]) 269 | #---------------------------------------------------------# 270 | # 给图像增加灰条,实现不失真的resize 271 | #---------------------------------------------------------# 272 | crop_img = letterbox_image(image, [self.input_shape[0],self.input_shape[1]]) 273 | #----------------------------------------------------------------------------------# 274 | # 将RGB转化成BGR,这是因为原始的centernet_hourglass权值是使用BGR通道的图片训练的 275 | #----------------------------------------------------------------------------------# 276 | photo = np.array(crop_img,dtype = np.float32)[:,:,::-1] 277 | #-----------------------------------------------------------# 278 | # 图片预处理,归一化。获得的photo的shape为[1, 512, 512, 
3] 279 | #-----------------------------------------------------------# 280 | photo = np.reshape(preprocess_image(photo), [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]]) 281 | 282 | preds = self.get_pred(photo).numpy() 283 | 284 | if self.nms: 285 | preds = np.array(nms(preds, self.nms_threhold)) 286 | 287 | if len(preds[0])>0: 288 | preds[0][:, 0:4] = preds[0][:, 0:4] / (self.input_shape[0] / 4) 289 | 290 | det_label = preds[0][:, -1] 291 | det_conf = preds[0][:, -2] 292 | det_xmin, det_ymin, det_xmax, det_ymax = preds[0][:, 0], preds[0][:, 1], preds[0][:, 2], preds[0][:, 3] 293 | 294 | top_indices = [i for i, conf in enumerate(det_conf) if conf >= self.confidence] 295 | top_conf = det_conf[top_indices] 296 | top_label_indices = det_label[top_indices].tolist() 297 | top_xmin, top_ymin, top_xmax, top_ymax = np.expand_dims(det_xmin[top_indices],-1),np.expand_dims(det_ymin[top_indices],-1),np.expand_dims(det_xmax[top_indices],-1),np.expand_dims(det_ymax[top_indices],-1) 298 | 299 | boxes = centernet_correct_boxes(top_ymin,top_xmin,top_ymax,top_xmax,np.array([self.input_shape[0],self.input_shape[1]]),image_shape) 300 | 301 | t1 = time.time() 302 | for _ in range(test_interval): 303 | preds = self.get_pred(photo).numpy() 304 | 305 | if self.nms: 306 | preds = np.array(nms(preds, self.nms_threhold)) 307 | 308 | if len(preds[0])>0: 309 | preds[0][:, 0:4] = preds[0][:, 0:4] / (self.input_shape[0] / 4) 310 | 311 | det_label = preds[0][:, -1] 312 | det_conf = preds[0][:, -2] 313 | det_xmin, det_ymin, det_xmax, det_ymax = preds[0][:, 0], preds[0][:, 1], preds[0][:, 2], preds[0][:, 3] 314 | 315 | top_indices = [i for i, conf in enumerate(det_conf) if conf >= self.confidence] 316 | top_conf = det_conf[top_indices] 317 | top_label_indices = det_label[top_indices].tolist() 318 | top_xmin, top_ymin, top_xmax, top_ymax = np.expand_dims(det_xmin[top_indices],-1),np.expand_dims(det_ymin[top_indices],-1),np.expand_dims(det_xmax[top_indices],-1),np.expand_dims(det_ymax[top_indices],-1) 319 | 320 | boxes = centernet_correct_boxes(top_ymin,top_xmin,top_ymax,top_xmax,np.array([self.input_shape[0],self.input_shape[1]]),image_shape) 321 | 322 | t2 = time.time() 323 | tact_time = (t2 - t1) / test_interval 324 | return tact_time 325 | -------------------------------------------------------------------------------- /PKLot_server/nets/centernet_training.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from random import shuffle 4 | 5 | import cv2 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import scipy.signal 9 | import tensorflow as tf 10 | from PIL import Image 11 | from utils.utils import draw_gaussian, gaussian_radius 12 | 13 | def preprocess_image(image): 14 | mean = [0.40789655, 0.44719303, 0.47026116] 15 | std = [0.2886383 , 0.27408165, 0.27809834] 16 | return ((np.float32(image) / 255.) 
- mean) / std
17 | 
18 | def focal_loss(hm_pred, hm_true):
19 |     #-------------------------------------------------------------------------#
20 |     #   Find the positive and negative samples of each image.
21 |     #   Each ground-truth box corresponds to one positive sample;
22 |     #   every feature point that is not a positive sample is a negative one.
23 |     #-------------------------------------------------------------------------#
24 |     pos_mask = tf.cast(tf.equal(hm_true, 1), tf.float32)
25 |     #-------------------------------------------------------------------------#
26 |     #   Negative samples near a positive feature point get a smaller weight.
27 |     #-------------------------------------------------------------------------#
28 |     neg_mask = tf.cast(tf.less(hm_true, 1), tf.float32)
29 |     neg_weights = tf.pow(1 - hm_true, 4)
30 | 
31 |     #-------------------------------------------------------------------------#
32 |     #   Compute the focal loss: hard samples get large weights, easy samples small ones.
33 |     #-------------------------------------------------------------------------#
34 |     pos_loss = -tf.math.log(tf.clip_by_value(hm_pred, 1e-6, 1.)) * tf.pow(1 - hm_pred, 2) * pos_mask
35 |     neg_loss = -tf.math.log(tf.clip_by_value(1 - hm_pred, 1e-6, 1.)) * tf.pow(hm_pred, 2) * neg_weights * neg_mask
36 | 
37 |     num_pos = tf.reduce_sum(pos_mask)
38 |     pos_loss = tf.reduce_sum(pos_loss)
39 |     neg_loss = tf.reduce_sum(neg_loss)
40 | 
41 |     #-------------------------------------------------------------------------#
42 |     #   Normalize the loss by the number of positives.
43 |     #-------------------------------------------------------------------------#
44 |     cls_loss = tf.cond(tf.greater(num_pos, 0), lambda: (pos_loss + neg_loss) / num_pos, lambda: neg_loss)
45 |     return cls_loss
46 | 
47 | 
48 | def reg_l1_loss(y_pred, y_true, indices, mask):
49 |     #-------------------------------------------------------------------------#
50 |     #   Get the batch size and the number of channels.
51 |     #-------------------------------------------------------------------------#
52 |     b, c = tf.shape(y_pred)[0], tf.shape(y_pred)[-1]
53 |     k = tf.shape(indices)[1]
54 | 
55 |     y_pred = tf.reshape(y_pred, (b, -1, c))
56 |     length = tf.shape(y_pred)[1]
57 |     indices = tf.cast(indices, tf.int32)
58 | 
59 |     #-------------------------------------------------------------------------#
60 |     #   Use the indices to pick, from the predictions, the feature points that correspond to ground-truth boxes.
61 |     #-------------------------------------------------------------------------#
62 |     batch_idx = tf.expand_dims(tf.range(0, b), 1)
63 |     batch_idx = tf.tile(batch_idx, (1, k))
64 |     full_indices = (tf.reshape(batch_idx, [-1]) * tf.cast(length, tf.int32) +
65 |                     tf.reshape(indices, [-1]))
66 | 
67 |     y_pred = tf.gather(tf.reshape(y_pred, [-1,c]),full_indices)
68 |     y_pred = tf.reshape(y_pred, [b, -1, c])
69 | 
70 |     mask = tf.tile(tf.expand_dims(mask, axis=-1), (1, 1, 2))
71 |     #-------------------------------------------------------------------------#
72 |     #   Compute the L1 loss over the masked entries.
73 |     #-------------------------------------------------------------------------#
74 |     total_loss = tf.reduce_sum(tf.abs(y_true * mask - y_pred * mask))
75 |     reg_loss = total_loss / (tf.reduce_sum(mask) + 1e-4)
76 |     return reg_loss
77 | 
78 | 
79 | def loss(args):
80 |     #-----------------------------------------------------------------------------------------------------------------#
81 |     #   hm_pred:  predicted heatmap                  (batch_size, 128, 128, num_classes)
82 |     #   wh_pred:  predicted widths and heights       (batch_size, 128, 128, 2)
83 |     #   reg_pred: predicted center offsets           (batch_size, 128, 128, 2)
84 |     #   hm_true:  ground-truth heatmap               (batch_size, 128, 128, num_classes)
85 |     #   wh_true:  ground-truth widths and heights    (batch_size, max_objects, 2)
86 |     #   reg_true: ground-truth center offsets        (batch_size, max_objects, 2)
87 |     #   reg_mask: mask of the valid ground truths    (batch_size, max_objects)
88 |     #   indices:  grid indices of the ground truths  (batch_size, max_objects)
89 | 
#-----------------------------------------------------------------------------------------------------------------# 90 | hm_pred, wh_pred, reg_pred, hm_true, wh_true, reg_true, reg_mask, indices = args 91 | hm_loss = focal_loss(hm_pred, hm_true) 92 | wh_loss = 0.1 * reg_l1_loss(wh_pred, wh_true, indices, reg_mask) 93 | reg_loss = reg_l1_loss(reg_pred, reg_true, indices, reg_mask) 94 | total_loss = hm_loss + wh_loss + reg_loss 95 | # total_loss = tf.Print(total_loss,[hm_loss,wh_loss,reg_loss]) 96 | return total_loss 97 | 98 | def rand(a=0, b=1): 99 | return np.random.rand()*(b-a) + a 100 | 101 | class Generator(object): 102 | def __init__(self,batch_size,train_lines,val_lines, 103 | input_size,num_classes,max_objects=200): 104 | 105 | self.batch_size = batch_size 106 | self.train_lines = train_lines 107 | self.val_lines = val_lines 108 | self.input_size = input_size 109 | self.output_size = (int(input_size[0]/4) , int(input_size[1]/4)) 110 | self.num_classes = num_classes 111 | self.max_objects = max_objects 112 | 113 | def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True): 114 | '''r实时数据增强的随机预处理''' 115 | line = annotation_line.split() 116 | image = Image.open(line[0]) 117 | iw, ih = image.size 118 | h, w = input_shape 119 | box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]]) 120 | 121 | if not random: 122 | # resize image 123 | scale = min(w/iw, h/ih) 124 | nw = int(iw*scale) 125 | nh = int(ih*scale) 126 | dx = (w-nw)//2 127 | dy = (h-nh)//2 128 | 129 | image = image.resize((nw,nh), Image.BICUBIC) 130 | new_image = Image.new('RGB', (w,h), (128,128,128)) 131 | new_image.paste(image, (dx, dy)) 132 | image_data = np.array(new_image, np.float32) 133 | 134 | # correct boxes 135 | box_data = np.zeros((len(box),5)) 136 | if len(box)>0: 137 | np.random.shuffle(box) 138 | box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx 139 | box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy 140 | box[:, 0:2][box[:, 0:2]<0] = 0 141 | box[:, 2][box[:, 2]>w] = w 142 | box[:, 3][box[:, 3]>h] = h 143 | box_w = box[:, 2] - box[:, 0] 144 | box_h = box[:, 3] - box[:, 1] 145 | box = box[np.logical_and(box_w>1, box_h>1)] 146 | box_data = np.zeros((len(box),5)) 147 | box_data[:len(box)] = box 148 | 149 | return image_data, box_data 150 | 151 | # resize image 152 | new_ar = w/h * rand(1-jitter,1+jitter)/rand(1-jitter,1+jitter) 153 | scale = rand(0.25, 2) 154 | if new_ar < 1: 155 | nh = int(scale*h) 156 | nw = int(nh*new_ar) 157 | else: 158 | nw = int(scale*w) 159 | nh = int(nw/new_ar) 160 | image = image.resize((nw,nh), Image.BICUBIC) 161 | 162 | # place image 163 | dx = int(rand(0, w-nw)) 164 | dy = int(rand(0, h-nh)) 165 | new_image = Image.new('RGB', (w,h), (128,128,128)) 166 | new_image.paste(image, (dx, dy)) 167 | image = new_image 168 | 169 | # flip image or not 170 | flip = rand()<.5 171 | if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT) 172 | 173 | # distort image 174 | hue = rand(-hue, hue) 175 | sat = rand(1, sat) if rand()<.5 else 1/rand(1, sat) 176 | val = rand(1, val) if rand()<.5 else 1/rand(1, val) 177 | x = cv2.cvtColor(np.array(image,np.float32)/255, cv2.COLOR_RGB2HSV) 178 | x[..., 0] += hue*360 179 | x[..., 0][x[..., 0]>1] -= 1 180 | x[..., 0][x[..., 0]<0] += 1 181 | x[..., 1] *= sat 182 | x[..., 2] *= val 183 | x[x[:,:, 0]>360, 0] = 360 184 | x[:, :, 1:][x[:, :, 1:]>1] = 1 185 | x[x<0] = 0 186 | image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB)*255 187 | 188 | 189 | # correct boxes 190 | box_data = np.zeros((len(box),5)) 191 | if len(box)>0: 
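            # The boxes come in the pixel coordinates of the original iw x ih
            # image; after resizing to nw x nh and pasting at offset (dx, dy)
            # they are remapped below with x' = x*nw/iw + dx and
            # y' = y*nh/ih + dy, then clipped to the w x h canvas.
            # Worked mini-example: iw = 1000, nw = 500, dx = 10 gives
            # x = 200 -> 200 * 500/1000 + 10 = 110.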
192 | np.random.shuffle(box) 193 | box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx 194 | box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy 195 | if flip: box[:, [0,2]] = w - box[:, [2,0]] 196 | box[:, 0:2][box[:, 0:2]<0] = 0 197 | box[:, 2][box[:, 2]>w] = w 198 | box[:, 3][box[:, 3]>h] = h 199 | box_w = box[:, 2] - box[:, 0] 200 | box_h = box[:, 3] - box[:, 1] 201 | box = box[np.logical_and(box_w>1, box_h>1)] # discard invalid box 202 | box_data = np.zeros((len(box),5)) 203 | box_data[:len(box)] = box 204 | if len(box) == 0: 205 | return image_data, [] 206 | 207 | if (box_data[:,:4]>0).any(): 208 | return image_data, box_data 209 | else: 210 | return image_data, [] 211 | 212 | def generate(self, train=True, eager=False): 213 | while True: 214 | if train: 215 | # 打乱 216 | shuffle(self.train_lines) 217 | lines = self.train_lines 218 | else: 219 | shuffle(self.val_lines) 220 | lines = self.val_lines 221 | 222 | batch_images = np.zeros((self.batch_size, self.input_size[0], self.input_size[1], self.input_size[2]), dtype=np.float32) 223 | batch_hms = np.zeros((self.batch_size, self.output_size[0], self.output_size[1], self.num_classes), dtype=np.float32) 224 | batch_whs = np.zeros((self.batch_size, self.max_objects, 2), dtype=np.float32) 225 | batch_regs = np.zeros((self.batch_size, self.max_objects, 2), dtype=np.float32) 226 | batch_reg_masks = np.zeros((self.batch_size, self.max_objects), dtype=np.float32) 227 | batch_indices = np.zeros((self.batch_size, self.max_objects), dtype=np.float32) 228 | 229 | b = 0 230 | for annotation_line in lines: 231 | img,y = self.get_random_data(annotation_line,self.input_size[0:2],random=train) 232 | 233 | if len(y)!=0: 234 | boxes = np.array(y[:,:4],dtype=np.float32) 235 | boxes[:,0] = boxes[:,0]/self.input_size[1]*self.output_size[1] 236 | boxes[:,1] = boxes[:,1]/self.input_size[0]*self.output_size[0] 237 | boxes[:,2] = boxes[:,2]/self.input_size[1]*self.output_size[1] 238 | boxes[:,3] = boxes[:,3]/self.input_size[0]*self.output_size[0] 239 | 240 | for i in range(len(y)): 241 | bbox = boxes[i].copy() 242 | bbox = np.array(bbox) 243 | bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, self.output_size[1] - 1) 244 | bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, self.output_size[0] - 1) 245 | cls_id = int(y[i,-1]) 246 | 247 | h, w = bbox[3] - bbox[1], bbox[2] - bbox[0] 248 | if h > 0 and w > 0: 249 | ct = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32) 250 | ct_int = ct.astype(np.int32) 251 | 252 | # 获得热力图 253 | radius = gaussian_radius((math.ceil(h), math.ceil(w))) 254 | radius = max(0, int(radius)) 255 | batch_hms[b, :, :, cls_id] = draw_gaussian(batch_hms[b, :, :, cls_id], ct_int, radius) 256 | 257 | batch_whs[b, i] = 1. * w, 1. 
* h 258 | # 计算中心偏移量 259 | batch_regs[b, i] = ct - ct_int 260 | # 将对应的mask设置为1,用于排除多余的0 261 | batch_reg_masks[b, i] = 1 262 | # 表示第ct_int[1]行的第ct_int[0]个。 263 | batch_indices[b, i] = ct_int[1] * self.output_size[0] + ct_int[0] 264 | 265 | # 将RGB转化成BGR 266 | img = np.array(img,dtype = np.float32)[:,:,::-1] 267 | batch_images[b] = preprocess_image(img) 268 | b = b + 1 269 | if b == self.batch_size: 270 | b = 0 271 | if eager: 272 | yield batch_images, batch_hms, batch_whs, batch_regs, batch_reg_masks, batch_indices 273 | else: 274 | yield [batch_images, batch_hms, batch_whs, batch_regs, batch_reg_masks, batch_indices], np.zeros((self.batch_size,)) 275 | 276 | batch_images = np.zeros((self.batch_size, self.input_size[0], self.input_size[1], 3), dtype=np.float32) 277 | 278 | batch_hms = np.zeros((self.batch_size, self.output_size[0], self.output_size[1], self.num_classes), 279 | dtype=np.float32) 280 | batch_whs = np.zeros((self.batch_size, self.max_objects, 2), dtype=np.float32) 281 | batch_regs = np.zeros((self.batch_size, self.max_objects, 2), dtype=np.float32) 282 | batch_reg_masks = np.zeros((self.batch_size, self.max_objects), dtype=np.float32) 283 | batch_indices = np.zeros((self.batch_size, self.max_objects), dtype=np.float32) 284 | 285 | 286 | class LossHistory(tf.keras.callbacks.Callback): 287 | def __init__(self, log_dir): 288 | import datetime 289 | curr_time = datetime.datetime.now() 290 | time_str = datetime.datetime.strftime(curr_time,'%Y_%m_%d_%H_%M_%S') 291 | self.log_dir = log_dir 292 | self.time_str = time_str 293 | self.save_path = os.path.join(self.log_dir, "loss_" + str(self.time_str)) 294 | self.losses = [] 295 | self.val_loss = [] 296 | 297 | os.makedirs(self.save_path) 298 | 299 | def on_epoch_end(self, batch, logs={}): 300 | self.losses.append(logs.get('loss')) 301 | self.val_loss.append(logs.get('val_loss')) 302 | with open(os.path.join(self.save_path, "epoch_loss_" + str(self.time_str) + ".txt"), 'a') as f: 303 | f.write(str(logs.get('loss'))) 304 | f.write("\n") 305 | with open(os.path.join(self.save_path, "epoch_val_loss_" + str(self.time_str) + ".txt"), 'a') as f: 306 | f.write(str(logs.get('val_loss'))) 307 | f.write("\n") 308 | 309 | self.loss_plot() 310 | 311 | def loss_plot(self): 312 | iters = range(len(self.losses)) 313 | 314 | plt.figure() 315 | plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss') 316 | plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss') 317 | try: 318 | if len(self.losses) < 25: 319 | num = 5 320 | else: 321 | num = 15 322 | 323 | plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss') 324 | plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss') 325 | except: 326 | pass 327 | 328 | plt.grid(True) 329 | plt.xlabel('Epoch') 330 | plt.ylabel('Loss') 331 | plt.title('A Loss Curve') 332 | plt.legend(loc="upper right") 333 | 334 | plt.savefig(os.path.join(self.save_path, "epoch_loss_" + str(self.time_str) + ".png")) 335 | 336 | plt.cla() 337 | plt.close("all") 338 | -------------------------------------------------------------------------------- /PKLot_server/get_map.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import shutil 5 | import operator 6 | import sys 7 | import argparse 8 | import math 9 | 10 | import numpy as np 11 | 12 | ''' 13 | 用于计算mAP 14 | 
代码克隆自https://github.com/Cartucho/mAP 15 | 如果想要设定mAP0.x,比如计算mAP0.75,可以设定MINOVERLAP = 0.75。 16 | ''' 17 | MINOVERLAP = 0.5 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('-na', '--no-animation', help="no animation is shown.", action="store_true") 21 | parser.add_argument('-np', '--no-plot', help="no plot is shown.", action="store_true") 22 | parser.add_argument('-q', '--quiet', help="minimalistic console output.", action="store_true") 23 | parser.add_argument('-i', '--ignore', nargs='+', type=str, help="ignore a list of classes.") 24 | parser.add_argument('--set-class-iou', nargs='+', type=str, help="set IoU for a specific class.") 25 | args = parser.parse_args() 26 | 27 | ''' 28 | 0,0 ------> x (width) 29 | | 30 | | (Left,Top) 31 | | *_________ 32 | | | | 33 | | | 34 | y |_________| 35 | (height) * 36 | (Right,Bottom) 37 | ''' 38 | 39 | if args.ignore is None: 40 | args.ignore = [] 41 | 42 | specific_iou_flagged = False 43 | if args.set_class_iou is not None: 44 | specific_iou_flagged = True 45 | 46 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 47 | 48 | GT_PATH = os.path.join(os.getcwd(), 'input', 'ground-truth') 49 | DR_PATH = os.path.join(os.getcwd(), 'input', 'detection-results') 50 | IMG_PATH = os.path.join(os.getcwd(), 'input', 'images-optional') 51 | if os.path.exists(IMG_PATH): 52 | for dirpath, dirnames, files in os.walk(IMG_PATH): 53 | if not files: 54 | args.no_animation = True 55 | else: 56 | args.no_animation = True 57 | 58 | show_animation = False 59 | if not args.no_animation: 60 | try: 61 | import cv2 62 | show_animation = True 63 | except ImportError: 64 | print("\"opencv-python\" not found, please install to visualize the results.") 65 | args.no_animation = True 66 | 67 | draw_plot = False 68 | if not args.no_plot: 69 | try: 70 | import matplotlib.pyplot as plt 71 | draw_plot = True 72 | except ImportError: 73 | print("\"matplotlib\" not found, please install it to get the resulting plots.") 74 | args.no_plot = True 75 | 76 | 77 | def log_average_miss_rate(precision, fp_cumsum, num_images): 78 | """ 79 | log-average miss rate: 80 | Calculated by averaging miss rates at 9 evenly spaced FPPI points 81 | between 10e-2 and 10e0, in log-space. 82 | 83 | output: 84 | lamr | log-average miss rate 85 | mr | miss rate 86 | fppi | false positives per image 87 | 88 | references: 89 | [1] Dollar, Piotr, et al. "Pedestrian Detection: An Evaluation of the 90 | State of the Art." Pattern Analysis and Machine Intelligence, IEEE 91 | Transactions on 34.4 (2012): 743 - 761. 
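        note:
            as implemented below, mr = 1 - precision and
            fppi = fp_cumsum / num_images; the nine reference FPPI points
            are np.logspace(-2.0, 0.0, num=9), and lamr is
            exp(mean(log(max(1e-10, ref)))) over the miss rates sampled at
            those points.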
92 | """ 93 | 94 | if precision.size == 0: 95 | lamr = 0 96 | mr = 1 97 | fppi = 0 98 | return lamr, mr, fppi 99 | 100 | fppi = fp_cumsum / float(num_images) 101 | mr = (1 - precision) 102 | 103 | fppi_tmp = np.insert(fppi, 0, -1.0) 104 | mr_tmp = np.insert(mr, 0, 1.0) 105 | 106 | ref = np.logspace(-2.0, 0.0, num = 9) 107 | for i, ref_i in enumerate(ref): 108 | j = np.where(fppi_tmp <= ref_i)[-1][-1] 109 | ref[i] = mr_tmp[j] 110 | 111 | lamr = math.exp(np.mean(np.log(np.maximum(1e-10, ref)))) 112 | 113 | return lamr, mr, fppi 114 | 115 | """ 116 | throw error and exit 117 | """ 118 | def error(msg): 119 | print(msg) 120 | sys.exit(0) 121 | 122 | """ 123 | check if the number is a float between 0.0 and 1.0 124 | """ 125 | def is_float_between_0_and_1(value): 126 | try: 127 | val = float(value) 128 | if val > 0.0 and val < 1.0: 129 | return True 130 | else: 131 | return False 132 | except ValueError: 133 | return False 134 | 135 | """ 136 | Calculate the AP given the recall and precision array 137 | 1st) We compute a version of the measured precision/recall curve with 138 | precision monotonically decreasing 139 | 2nd) We compute the AP as the area under this curve by numerical integration. 140 | """ 141 | def voc_ap(rec, prec): 142 | """ 143 | --- Official matlab code VOC2012--- 144 | mrec=[0 ; rec ; 1]; 145 | mpre=[0 ; prec ; 0]; 146 | for i=numel(mpre)-1:-1:1 147 | mpre(i)=max(mpre(i),mpre(i+1)); 148 | end 149 | i=find(mrec(2:end)~=mrec(1:end-1))+1; 150 | ap=sum((mrec(i)-mrec(i-1)).*mpre(i)); 151 | """ 152 | rec.insert(0, 0.0) # insert 0.0 at begining of list 153 | rec.append(1.0) # insert 1.0 at end of list 154 | mrec = rec[:] 155 | prec.insert(0, 0.0) # insert 0.0 at begining of list 156 | prec.append(0.0) # insert 0.0 at end of list 157 | mpre = prec[:] 158 | """ 159 | This part makes the precision monotonically decreasing 160 | (goes from the end to the beginning) 161 | matlab: for i=numel(mpre)-1:-1:1 162 | mpre(i)=max(mpre(i),mpre(i+1)); 163 | """ 164 | for i in range(len(mpre)-2, -1, -1): 165 | mpre[i] = max(mpre[i], mpre[i+1]) 166 | """ 167 | This part creates a list of indexes where the recall changes 168 | matlab: i=find(mrec(2:end)~=mrec(1:end-1))+1; 169 | """ 170 | i_list = [] 171 | for i in range(1, len(mrec)): 172 | if mrec[i] != mrec[i-1]: 173 | i_list.append(i) # if it was matlab would be i + 1 174 | """ 175 | The Average Precision (AP) is the area under the curve 176 | (numerical integration) 177 | matlab: ap=sum((mrec(i)-mrec(i-1)).*mpre(i)); 178 | """ 179 | ap = 0.0 180 | for i in i_list: 181 | ap += ((mrec[i]-mrec[i-1])*mpre[i]) 182 | return ap, mrec, mpre 183 | 184 | 185 | """ 186 | Convert the lines of a file to a list 187 | """ 188 | def file_lines_to_list(path): 189 | # open txt file lines to a list 190 | with open(path) as f: 191 | content = f.readlines() 192 | # remove whitespace characters like `\n` at the end of each line 193 | content = [x.strip() for x in content] 194 | return content 195 | 196 | """ 197 | Draws text in image 198 | """ 199 | def draw_text_in_image(img, text, pos, color, line_width): 200 | font = cv2.FONT_HERSHEY_PLAIN 201 | fontScale = 1 202 | lineType = 1 203 | bottomLeftCornerOfText = pos 204 | cv2.putText(img, text, 205 | bottomLeftCornerOfText, 206 | font, 207 | fontScale, 208 | color, 209 | lineType) 210 | text_width, _ = cv2.getTextSize(text, font, fontScale, lineType)[0] 211 | return img, (line_width + text_width) 212 | 213 | """ 214 | Plot - adjust axes 215 | """ 216 | def adjust_axes(r, t, fig, axes): 217 | # get text width for 
re-scaling 218 | bb = t.get_window_extent(renderer=r) 219 | text_width_inches = bb.width / fig.dpi 220 | # get axis width in inches 221 | current_fig_width = fig.get_figwidth() 222 | new_fig_width = current_fig_width + text_width_inches 223 | propotion = new_fig_width / current_fig_width 224 | # get axis limit 225 | x_lim = axes.get_xlim() 226 | axes.set_xlim([x_lim[0], x_lim[1]*propotion]) 227 | 228 | """ 229 | Draw plot using Matplotlib 230 | """ 231 | def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar): 232 | # sort the dictionary by decreasing value, into a list of tuples 233 | sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1)) 234 | # unpacking the list of tuples into two lists 235 | sorted_keys, sorted_values = zip(*sorted_dic_by_value) 236 | # 237 | if true_p_bar != "": 238 | """ 239 | Special case to draw in: 240 | - green -> TP: True Positives (object detected and matches ground-truth) 241 | - red -> FP: False Positives (object detected but does not match ground-truth) 242 | - orange -> FN: False Negatives (object not detected but present in the ground-truth) 243 | """ 244 | fp_sorted = [] 245 | tp_sorted = [] 246 | for key in sorted_keys: 247 | fp_sorted.append(dictionary[key] - true_p_bar[key]) 248 | tp_sorted.append(true_p_bar[key]) 249 | plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Positive') 250 | plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Positive', left=fp_sorted) 251 | # add legend 252 | plt.legend(loc='lower right') 253 | """ 254 | Write number on side of bar 255 | """ 256 | fig = plt.gcf() # gcf - get current figure 257 | axes = plt.gca() 258 | r = fig.canvas.get_renderer() 259 | for i, val in enumerate(sorted_values): 260 | fp_val = fp_sorted[i] 261 | tp_val = tp_sorted[i] 262 | fp_str_val = " " + str(fp_val) 263 | tp_str_val = fp_str_val + " " + str(tp_val) 264 | # trick to paint multicolor with offset: 265 | # first paint everything and then repaint the first number 266 | t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold') 267 | plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold') 268 | if i == (len(sorted_values)-1): # largest bar 269 | adjust_axes(r, t, fig, axes) 270 | else: 271 | plt.barh(range(n_classes), sorted_values, color=plot_color) 272 | """ 273 | Write number on side of bar 274 | """ 275 | fig = plt.gcf() # gcf - get current figure 276 | axes = plt.gca() 277 | r = fig.canvas.get_renderer() 278 | for i, val in enumerate(sorted_values): 279 | str_val = " " + str(val) # add a space before 280 | if val < 1.0: 281 | str_val = " {0:.2f}".format(val) 282 | t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold') 283 | # re-set axes to show number inside the figure 284 | if i == (len(sorted_values)-1): # largest bar 285 | adjust_axes(r, t, fig, axes) 286 | # set window title 287 | fig.canvas.set_window_title(window_title) 288 | # write classes in y axis 289 | tick_font_size = 12 290 | plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size) 291 | """ 292 | Re-scale height accordingly 293 | """ 294 | init_height = fig.get_figheight() 295 | # comput the matrix height in points and inches 296 | dpi = fig.dpi 297 | height_pt = n_classes * (tick_font_size * 1.4) # 1.4 (some spacing) 298 | height_in = height_pt / dpi 299 | # compute the required figure height 300 | top_margin = 0.15 # in percentage of 
the figure height 301 | bottom_margin = 0.05 # in percentage of the figure height 302 | figure_height = height_in / (1 - top_margin - bottom_margin) 303 | # set new height 304 | if figure_height > init_height: 305 | fig.set_figheight(figure_height) 306 | 307 | # set plot title 308 | plt.title(plot_title, fontsize=14) 309 | # set axis titles 310 | # plt.xlabel('classes') 311 | plt.xlabel(x_label, fontsize='large') 312 | # adjust size of window 313 | fig.tight_layout() 314 | # save the plot 315 | fig.savefig(output_path) 316 | # show image 317 | if to_show: 318 | plt.show() 319 | # close the plot 320 | plt.close() 321 | 322 | """ 323 | Create a ".temp_files/" and "results/" directory 324 | """ 325 | TEMP_FILES_PATH = ".temp_files" 326 | if not os.path.exists(TEMP_FILES_PATH): # if it doesn't exist already 327 | os.makedirs(TEMP_FILES_PATH) 328 | results_files_path = "results" 329 | if os.path.exists(results_files_path): # if it exist already 330 | # reset the results directory 331 | shutil.rmtree(results_files_path) 332 | 333 | os.makedirs(results_files_path) 334 | if draw_plot: 335 | os.makedirs(os.path.join(results_files_path, "AP")) 336 | os.makedirs(os.path.join(results_files_path, "F1")) 337 | os.makedirs(os.path.join(results_files_path, "Recall")) 338 | os.makedirs(os.path.join(results_files_path, "Precision")) 339 | if show_animation: 340 | os.makedirs(os.path.join(results_files_path, "images", "detections_one_by_one")) 341 | 342 | """ 343 | ground-truth 344 | Load each of the ground-truth files into a temporary ".json" file. 345 | Create a list of all the class names present in the ground-truth (gt_classes). 346 | """ 347 | # get a list with the ground-truth files 348 | ground_truth_files_list = glob.glob(GT_PATH + '/*.txt') 349 | if len(ground_truth_files_list) == 0: 350 | error("Error: No ground-truth files found!") 351 | ground_truth_files_list.sort() 352 | # dictionary with counter per class 353 | gt_counter_per_class = {} 354 | counter_images_per_class = {} 355 | 356 | for txt_file in ground_truth_files_list: 357 | #print(txt_file) 358 | file_id = txt_file.split(".txt", 1)[0] 359 | file_id = os.path.basename(os.path.normpath(file_id)) 360 | # check if there is a correspondent detection-results file 361 | temp_path = os.path.join(DR_PATH, (file_id + ".txt")) 362 | if not os.path.exists(temp_path): 363 | error_msg = "Error. 
File not found: {}\n".format(temp_path) 364 | error_msg += "(You can avoid this error message by running extra/intersect-gt-and-dr.py)" 365 | error(error_msg) 366 | lines_list = file_lines_to_list(txt_file) 367 | # create ground-truth dictionary 368 | bounding_boxes = [] 369 | is_difficult = False 370 | already_seen_classes = [] 371 | for line in lines_list: 372 | try: 373 | if "difficult" in line: 374 | class_name, left, top, right, bottom, _difficult = line.split() 375 | is_difficult = True 376 | else: 377 | class_name, left, top, right, bottom = line.split() 378 | 379 | except: 380 | if "difficult" in line: 381 | line_split = line.split() 382 | _difficult = line_split[-1] 383 | bottom = line_split[-2] 384 | right = line_split[-3] 385 | top = line_split[-4] 386 | left = line_split[-5] 387 | class_name = "" 388 | for name in line_split[:-5]: 389 | class_name += name + " " 390 | class_name = class_name[:-1] 391 | is_difficult = True 392 | else: 393 | line_split = line.split() 394 | bottom = line_split[-1] 395 | right = line_split[-2] 396 | top = line_split[-3] 397 | left = line_split[-4] 398 | class_name = "" 399 | for name in line_split[:-4]: 400 | class_name += name + " " 401 | class_name = class_name[:-1] 402 | if class_name in args.ignore: 403 | continue 404 | bbox = left + " " + top + " " + right + " " +bottom 405 | if is_difficult: 406 | bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False, "difficult":True}) 407 | is_difficult = False 408 | else: 409 | bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False}) 410 | if class_name in gt_counter_per_class: 411 | gt_counter_per_class[class_name] += 1 412 | else: 413 | gt_counter_per_class[class_name] = 1 414 | 415 | if class_name not in already_seen_classes: 416 | if class_name in counter_images_per_class: 417 | counter_images_per_class[class_name] += 1 418 | else: 419 | counter_images_per_class[class_name] = 1 420 | already_seen_classes.append(class_name) 421 | 422 | 423 | with open(TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json", 'w') as outfile: 424 | json.dump(bounding_boxes, outfile) 425 | 426 | gt_classes = list(gt_counter_per_class.keys()) 427 | gt_classes = sorted(gt_classes) 428 | n_classes = len(gt_classes) 429 | 430 | """ 431 | Check format of the flag --set-class-iou (if used) 432 | e.g. check if class exists 433 | """ 434 | if specific_iou_flagged: 435 | n_args = len(args.set_class_iou) 436 | error_msg = \ 437 | '\n --set-class-iou [class_1] [IoU_1] [class_2] [IoU_2] [...]' 438 | if n_args % 2 != 0: 439 | error('Error, missing arguments. Flag usage:' + error_msg) 440 | # [class_1] [IoU_1] [class_2] [IoU_2] 441 | # specific_iou_classes = ['class_1', 'class_2'] 442 | specific_iou_classes = args.set_class_iou[::2] # even 443 | # iou_list = ['IoU_1', 'IoU_2'] 444 | iou_list = args.set_class_iou[1::2] # odd 445 | if len(specific_iou_classes) != len(iou_list): 446 | error('Error, missing arguments. Flag usage:' + error_msg) 447 | for tmp_class in specific_iou_classes: 448 | if tmp_class not in gt_classes: 449 | error('Error, unknown class \"' + tmp_class + '\". Flag usage:' + error_msg) 450 | for num in iou_list: 451 | if not is_float_between_0_and_1(num): 452 | error('Error, IoU must be between 0.0 and 1.0. Flag usage:' + error_msg) 453 | 454 | """ 455 | detection-results 456 | Load each of the detection-results files into a temporary ".json" file. 
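Each line of a detection-results file is expected to look like
    <class_name> <confidence> <left> <top> <right> <bottom>
e.g. "space-empty 0.87 24 36 98 120" (illustrative values, not real output);
class names that themselves contain spaces are handled by the fallback
branch of the parser below.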
457 | """ 458 | dr_files_list = glob.glob(DR_PATH + '/*.txt') 459 | dr_files_list.sort() 460 | 461 | for class_index, class_name in enumerate(gt_classes): 462 | bounding_boxes = [] 463 | for txt_file in dr_files_list: 464 | file_id = txt_file.split(".txt",1)[0] 465 | file_id = os.path.basename(os.path.normpath(file_id)) 466 | temp_path = os.path.join(GT_PATH, (file_id + ".txt")) 467 | if class_index == 0: 468 | if not os.path.exists(temp_path): 469 | error_msg = "Error. File not found: {}\n".format(temp_path) 470 | error_msg += "(You can avoid this error message by running extra/intersect-gt-and-dr.py)" 471 | error(error_msg) 472 | lines = file_lines_to_list(txt_file) 473 | for line in lines: 474 | try: 475 | tmp_class_name, confidence, left, top, right, bottom = line.split() 476 | except: 477 | line_split = line.split() 478 | bottom = line_split[-1] 479 | right = line_split[-2] 480 | top = line_split[-3] 481 | left = line_split[-4] 482 | confidence = line_split[-5] 483 | tmp_class_name = "" 484 | for name in line_split[:-5]: 485 | tmp_class_name += name + " " 486 | tmp_class_name = tmp_class_name[:-1] 487 | 488 | if tmp_class_name == class_name: 489 | bbox = left + " " + top + " " + right + " " +bottom 490 | bounding_boxes.append({"confidence":confidence, "file_id":file_id, "bbox":bbox}) 491 | 492 | bounding_boxes.sort(key=lambda x:float(x['confidence']), reverse=True) 493 | with open(TEMP_FILES_PATH + "/" + class_name + "_dr.json", 'w') as outfile: 494 | json.dump(bounding_boxes, outfile) 495 | 496 | """ 497 | Calculate the AP for each class 498 | """ 499 | sum_AP = 0.0 500 | ap_dictionary = {} 501 | lamr_dictionary = {} 502 | with open(results_files_path + "/results.txt", 'w') as results_file: 503 | results_file.write("# AP and precision/recall per class\n") 504 | count_true_positives = {} 505 | 506 | for class_index, class_name in enumerate(gt_classes): 507 | count_true_positives[class_name] = 0 508 | """ 509 | Load detection-results of that class 510 | """ 511 | dr_file = TEMP_FILES_PATH + "/" + class_name + "_dr.json" 512 | dr_data = json.load(open(dr_file)) 513 | """ 514 | Assign detection-results to ground-truth objects 515 | """ 516 | nd = len(dr_data) 517 | tp = [0] * nd 518 | fp = [0] * nd 519 | score = [0] * nd 520 | score05_idx = 0 521 | for idx, detection in enumerate(dr_data): 522 | file_id = detection["file_id"] 523 | score[idx] = float(detection["confidence"]) 524 | if score[idx] > 0.5: 525 | score05_idx = idx 526 | 527 | if show_animation: 528 | ground_truth_img = glob.glob1(IMG_PATH, file_id + ".*") 529 | if len(ground_truth_img) == 0: 530 | error("Error. Image not found with id: " + file_id) 531 | elif len(ground_truth_img) > 1: 532 | error("Error. 
"""
Calculate the AP for each class
"""
sum_AP = 0.0
ap_dictionary = {}
lamr_dictionary = {}
with open(results_files_path + "/results.txt", 'w') as results_file:
    results_file.write("# AP and precision/recall per class\n")
    count_true_positives = {}

    for class_index, class_name in enumerate(gt_classes):
        count_true_positives[class_name] = 0
        """
        Load detection-results of that class
        """
        dr_file = TEMP_FILES_PATH + "/" + class_name + "_dr.json"
        dr_data = json.load(open(dr_file))
        """
        Assign detection-results to ground-truth objects
        """
        nd = len(dr_data)
        tp = [0] * nd
        fp = [0] * nd
        score = [0] * nd
        score05_idx = 0
        for idx, detection in enumerate(dr_data):
            file_id = detection["file_id"]
            score[idx] = float(detection["confidence"])
            if score[idx] > 0.5:
                score05_idx = idx

            if show_animation:
                # find the ground-truth image that matches this file_id
                ground_truth_img = glob.glob1(IMG_PATH, file_id + ".*")
                if len(ground_truth_img) == 0:
                    error("Error. Image not found with id: " + file_id)
                elif len(ground_truth_img) > 1:
                    error("Error. Multiple images with id: " + file_id)
                else:
                    img = cv2.imread(IMG_PATH + "/" + ground_truth_img[0])
                    img_cumulative_path = results_files_path + "/images/" + ground_truth_img[0]
                    if os.path.isfile(img_cumulative_path):
                        img_cumulative = cv2.imread(img_cumulative_path)
                    else:
                        img_cumulative = img.copy()
                    # add a bottom border to make room for the status text
                    bottom_border = 60
                    BLACK = [0, 0, 0]
                    img = cv2.copyMakeBorder(img, 0, bottom_border, 0, 0, cv2.BORDER_CONSTANT, value=BLACK)

            # assign the detection to the same-class ground-truth object with the highest IoU
            gt_file = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json"
            ground_truth_data = json.load(open(gt_file))
            ovmax = -1
            gt_match = -1
            bb = [ float(x) for x in detection["bbox"].split() ]
            for obj in ground_truth_data:
                if obj["class_name"] == class_name:
                    bbgt = [ float(x) for x in obj["bbox"].split() ]
                    bi = [max(bb[0],bbgt[0]), max(bb[1],bbgt[1]), min(bb[2],bbgt[2]), min(bb[3],bbgt[3])]
                    iw = bi[2] - bi[0] + 1
                    ih = bi[3] - bi[1] + 1
                    if iw > 0 and ih > 0:
                        # compute overlap (IoU) = area of intersection / area of union
                        ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + (bbgt[2] - bbgt[0]
                                + 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih
                        ov = iw * ih / ua
                        if ov > ovmax:
                            ovmax = ov
                            gt_match = obj

            # a detection is a true positive only if it clears the minimum IoU
            # and its ground-truth box has not been matched before
            if show_animation:
                status = "NO MATCH FOUND!"
            min_overlap = MINOVERLAP
            if specific_iou_flagged:
                if class_name in specific_iou_classes:
                    index = specific_iou_classes.index(class_name)
                    min_overlap = float(iou_list[index])
            if ovmax >= min_overlap:
                if "difficult" not in gt_match:
                    if not bool(gt_match["used"]):
                        # true positive
                        tp[idx] = 1
                        gt_match["used"] = True
                        count_true_positives[class_name] += 1
                        # update the ".json" file
                        with open(gt_file, 'w') as f:
                            f.write(json.dumps(ground_truth_data))
                        if show_animation:
                            status = "MATCH!"
                    else:
                        # false positive (multiple detections of the same object)
                        fp[idx] = 1
                        if show_animation:
                            status = "REPEATED MATCH!"
            else:
                # false positive
                fp[idx] = 1
                if ovmax > 0:
                    status = "INSUFFICIENT OVERLAP"
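            # Worked example of the IoU rule above (hypothetical numbers): for
            # bb = [10, 10, 50, 50] and bbgt = [20, 20, 60, 60] the intersection is
            # 31*31 = 961 px, the union is 1681 + 1681 - 961 = 2401 px, so
            # IoU = 961/2401 ~= 0.40 and the pair matches only if min_overlap <= 0.40.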

            """
            Draw image to show animation
            """
            if show_animation:
                height, width = img.shape[:2]
                # colors (OpenCV works with BGR)
                white = (255,255,255)
                light_blue = (255,200,100)
                green = (0,255,0)
                light_red = (30,30,255)
                # 1st line
                margin = 10
                v_pos = int(height - margin - (bottom_border / 2.0))
                text = "Image: " + ground_truth_img[0] + " "
                img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
                text = "Class [" + str(class_index) + "/" + str(n_classes) + "]: " + class_name + " "
                img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), light_blue, line_width)
                if ovmax != -1:
                    color = light_red
                    if status == "INSUFFICIENT OVERLAP":
                        text = "IoU: {0:.2f}% ".format(ovmax*100) + "< {0:.2f}% ".format(min_overlap*100)
                    else:
                        text = "IoU: {0:.2f}% ".format(ovmax*100) + ">= {0:.2f}% ".format(min_overlap*100)
                        color = green
                    img, _ = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
                # 2nd line
                v_pos += int(bottom_border / 2.0)
                rank_pos = str(idx+1)  # rank position (idx starts at 0)
                text = "Detection #rank: " + rank_pos + " confidence: {0:.2f}% ".format(float(detection["confidence"])*100)
                img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
                color = light_red
                if status == "MATCH!":
                    color = green
                text = "Result: " + status + " "
                img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)

                font = cv2.FONT_HERSHEY_SIMPLEX
                if ovmax > 0:  # if the bounding boxes intersect, draw the matched ground-truth box
                    bbgt = [ int(round(float(x))) for x in gt_match["bbox"].split() ]
                    cv2.rectangle(img,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
                    cv2.rectangle(img_cumulative,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
                    cv2.putText(img_cumulative, class_name, (bbgt[0],bbgt[1] - 5), font, 0.6, light_blue, 1, cv2.LINE_AA)
                bb = [int(i) for i in bb]
                cv2.rectangle(img,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
                cv2.rectangle(img_cumulative,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
                cv2.putText(img_cumulative, class_name, (bb[0],bb[1] - 5), font, 0.6, color, 1, cv2.LINE_AA)
                # show image
                cv2.imshow("Animation", img)
                cv2.waitKey(20)  # show for 20 ms
                # save image to results
                output_img_path = results_files_path + "/images/detections_one_by_one/" + class_name + "_detection" + str(idx) + ".jpg"
                cv2.imwrite(output_img_path, img)
                # save the image with all the objects drawn to it
                cv2.imwrite(img_cumulative_path, img_cumulative)

        # compute cumulative false positives
        cumsum = 0
        for idx, val in enumerate(fp):
            fp[idx] += cumsum
            cumsum += val
        # compute cumulative true positives
        cumsum = 0
        for idx, val in enumerate(tp):
            tp[idx] += cumsum
            cumsum += val
        # recall = cumulative TP / number of ground-truth objects
        rec = tp[:]
        for idx, val in enumerate(tp):
            rec[idx] = float(tp[idx]) / np.maximum(gt_counter_per_class[class_name], 1)
        # precision = cumulative TP / number of detections so far
        prec = tp[:]
        for idx, val in enumerate(tp):
            prec[idx] = float(tp[idx]) / np.maximum((fp[idx] + tp[idx]), 1)

        ap, mrec, mprec = voc_ap(rec[:], prec[:])
        F1 = np.array(rec)*np.array(prec)*2 / np.where((np.array(prec)+np.array(rec))==0, 1, (np.array(prec)+np.array(rec)))

        sum_AP += ap
        text = "{0:.2f}%".format(ap*100) + " = " + class_name + " AP "  # class_name + " AP = {0:.2f}%".format(ap*100)

        if len(prec) > 0:
            F1_text = "{0:.2f}".format(F1[score05_idx]) + " = " + class_name + " F1 "
            Recall_text = "{0:.2f}%".format(rec[score05_idx]*100) + " = " + class_name + " Recall "
            Precision_text = "{0:.2f}%".format(prec[score05_idx]*100) + " = " + class_name + " Precision "
        else:
            F1_text = "0.00" + " = " + class_name + " F1 "
            Recall_text = "0.00%" + " = " + class_name + " Recall "
            Precision_text = "0.00%" + " = " + class_name + " Precision "

        rounded_prec = [ '%.2f' % elem for elem in prec ]
        rounded_rec = [ '%.2f' % elem for elem in rec ]
        results_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n")
        if not args.quiet:
            if len(prec) > 0:
                print(text + "\t||\tscore_threshold=0.5 : " + "F1=" + "{0:.2f}".format(F1[score05_idx])\
                    + " ; Recall=" + "{0:.2f}%".format(rec[score05_idx]*100) + " ; Precision=" + "{0:.2f}%".format(prec[score05_idx]*100))
            else:
                print(text + "\t||\tscore_threshold=0.5 : F1=0.00 ; Recall=0.00% ; Precision=0.00%")
        ap_dictionary[class_name] = ap

        n_images = counter_images_per_class[class_name]
        lamr, mr, fppi = log_average_miss_rate(np.array(rec), np.array(fp), n_images)
        lamr_dictionary[class_name] = lamr

        """
        Draw plot
        """
        if draw_plot:
            plt.plot(rec, prec, '-o')
            # add a point (mrec[-2], 0.0) so the filled area closes cleanly
            area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]]
            area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]]
            plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r')
            # set window title
            fig = plt.gcf()
            fig.canvas.set_window_title('AP ' + class_name)
            # set plot title
            plt.title('class: ' + text)
            plt.xlabel('Recall')
            plt.ylabel('Precision')
            axes = plt.gca()
            axes.set_xlim([0.0, 1.0])
            axes.set_ylim([0.0, 1.05])
            fig.savefig(results_files_path + "/AP/" + class_name + ".png")
            plt.cla()  # clear axes for the next plot

            plt.plot(score, F1, "-", color='orangered')
            plt.title('class: ' + F1_text + "\nscore_threshold=0.5")
            plt.xlabel('Score_Threshold')
            plt.ylabel('F1')
            axes = plt.gca()
            axes.set_xlim([0.0, 1.0])
            axes.set_ylim([0.0, 1.05])
            fig.savefig(results_files_path + "/F1/" + class_name + ".png")
            plt.cla()

            plt.plot(score, rec, "-H", color='gold')
            plt.title('class: ' + Recall_text + "\nscore_threshold=0.5")
            plt.xlabel('Score_Threshold')
            plt.ylabel('Recall')
            axes = plt.gca()
            axes.set_xlim([0.0, 1.0])
            axes.set_ylim([0.0, 1.05])
            fig.savefig(results_files_path + "/Recall/" + class_name + ".png")
            plt.cla()

            plt.plot(score, prec, "-s", color='palevioletred')
            plt.title('class: ' + Precision_text + "\nscore_threshold=0.5")
            plt.xlabel('Score_Threshold')
            plt.ylabel('Precision')
            axes = plt.gca()
            axes.set_xlim([0.0, 1.0])
            axes.set_ylim([0.0, 1.05])
            fig.savefig(results_files_path + "/Precision/" + class_name + ".png")
            plt.cla()

    if show_animation:
        cv2.destroyAllWindows()

    results_file.write("\n# mAP of all classes\n")
    mAP = sum_AP / n_classes
    text = "mAP = {0:.2f}%".format(mAP*100)
    results_file.write(text + "\n")
    print(text)

# remove the temp_files directory
shutil.rmtree(TEMP_FILES_PATH)

""" 759 | # iterate through all the files 760 | det_counter_per_class = {} 761 | for txt_file in dr_files_list: 762 | # get lines to list 763 | lines_list = file_lines_to_list(txt_file) 764 | for line in lines_list: 765 | class_name = line.split()[0] 766 | # check if class is in the ignore list, if yes skip 767 | if class_name in args.ignore: 768 | continue 769 | # count that object 770 | if class_name in det_counter_per_class: 771 | det_counter_per_class[class_name] += 1 772 | else: 773 | # if class didn't exist yet 774 | det_counter_per_class[class_name] = 1 775 | #print(det_counter_per_class) 776 | dr_classes = list(det_counter_per_class.keys()) 777 | 778 | 779 | """ 780 | Plot the total number of occurences of each class in the ground-truth 781 | """ 782 | if draw_plot: 783 | window_title = "ground-truth-info" 784 | plot_title = "ground-truth\n" 785 | plot_title += "(" + str(len(ground_truth_files_list)) + " files and " + str(n_classes) + " classes)" 786 | x_label = "Number of objects per class" 787 | output_path = results_files_path + "/ground-truth-info.png" 788 | to_show = False 789 | plot_color = 'forestgreen' 790 | draw_plot_func( 791 | gt_counter_per_class, 792 | n_classes, 793 | window_title, 794 | plot_title, 795 | x_label, 796 | output_path, 797 | to_show, 798 | plot_color, 799 | '', 800 | ) 801 | 802 | """ 803 | Write number of ground-truth objects per class to results.txt 804 | """ 805 | with open(results_files_path + "/results.txt", 'a') as results_file: 806 | results_file.write("\n# Number of ground-truth objects per class\n") 807 | for class_name in sorted(gt_counter_per_class): 808 | results_file.write(class_name + ": " + str(gt_counter_per_class[class_name]) + "\n") 809 | 810 | """ 811 | Finish counting true positives 812 | """ 813 | for class_name in dr_classes: 814 | # if class exists in detection-result but not in ground-truth then there are no true positives in that class 815 | if class_name not in gt_classes: 816 | count_true_positives[class_name] = 0 817 | #print(count_true_positives) 818 | 819 | """ 820 | Plot the total number of occurences of each class in the "detection-results" folder 821 | """ 822 | if draw_plot: 823 | window_title = "detection-results-info" 824 | # Plot title 825 | plot_title = "detection-results\n" 826 | plot_title += "(" + str(len(dr_files_list)) + " files and " 827 | count_non_zero_values_in_dictionary = sum(int(x) > 0 for x in list(det_counter_per_class.values())) 828 | plot_title += str(count_non_zero_values_in_dictionary) + " detected classes)" 829 | # end Plot title 830 | x_label = "Number of objects per class" 831 | output_path = results_files_path + "/detection-results-info.png" 832 | to_show = False 833 | plot_color = 'forestgreen' 834 | true_p_bar = count_true_positives 835 | draw_plot_func( 836 | det_counter_per_class, 837 | len(det_counter_per_class), 838 | window_title, 839 | plot_title, 840 | x_label, 841 | output_path, 842 | to_show, 843 | plot_color, 844 | true_p_bar 845 | ) 846 | 847 | """ 848 | Write number of detected objects per class to results.txt 849 | """ 850 | with open(results_files_path + "/results.txt", 'a') as results_file: 851 | results_file.write("\n# Number of detected objects per class\n") 852 | for class_name in sorted(dr_classes): 853 | n_det = det_counter_per_class[class_name] 854 | text = class_name + ": " + str(n_det) 855 | text += " (tp:" + str(count_true_positives[class_name]) + "" 856 | text += ", fp:" + str(n_det - count_true_positives[class_name]) + ")\n" 857 | results_file.write(text) 858 | 859 | 
""" 860 | Draw log-average miss rate plot (Show lamr of all classes in decreasing order) 861 | """ 862 | if draw_plot: 863 | window_title = "lamr" 864 | plot_title = "log-average miss rate" 865 | x_label = "log-average miss rate" 866 | output_path = results_files_path + "/lamr.png" 867 | to_show = False 868 | plot_color = 'royalblue' 869 | draw_plot_func( 870 | lamr_dictionary, 871 | n_classes, 872 | window_title, 873 | plot_title, 874 | x_label, 875 | output_path, 876 | to_show, 877 | plot_color, 878 | "" 879 | ) 880 | 881 | """ 882 | Draw mAP plot (Show AP's of all classes in decreasing order) 883 | """ 884 | if draw_plot: 885 | window_title = "mAP" 886 | plot_title = "mAP = {0:.2f}%".format(mAP*100) 887 | x_label = "Average Precision" 888 | output_path = results_files_path + "/mAP.png" 889 | to_show = True 890 | plot_color = 'royalblue' 891 | draw_plot_func( 892 | ap_dictionary, 893 | n_classes, 894 | window_title, 895 | plot_title, 896 | x_label, 897 | output_path, 898 | to_show, 899 | plot_color, 900 | "" 901 | ) 902 | --------------------------------------------------------------------------------