├── PKLot_server
│   ├── detect_Logs
│   │   ├── detect.txt
│   │   ├── logs.txt
│   │   └── 1.png
│   ├── model_data
│   │   ├── coco_classes.txt
│   │   ├── voc_classes.txt
│   │   ├── timg.jpg
│   │   └── simhei.ttf
│   ├── nets
│   │   ├── __pycache__
│   │   │   ├── resnet.cpython-37.pyc
│   │   │   ├── resnet.cpython-38.pyc
│   │   │   ├── centernet.cpython-37.pyc
│   │   │   ├── centernet.cpython-38.pyc
│   │   │   ├── hourglass.cpython-37.pyc
│   │   │   ├── hourglass.cpython-38.pyc
│   │   │   ├── centernet_training.cpython-37.pyc
│   │   │   └── centernet_training.cpython-38.pyc
│   │   ├── resnet.py
│   │   ├── hourglass.py
│   │   ├── centernet.py
│   │   └── centernet_training.py
│   ├── utils
│   │   ├── __pycache__
│   │   │   ├── utils.cpython-37.pyc
│   │   │   └── utils.cpython-38.pyc
│   │   └── utils.py
│   ├── data_fack.py
│   ├── DataServer.py
│   ├── PKLotServer.py
│   ├── get_gt_txt.py
│   ├── get_dr_txt.py
│   ├── MainWindow.py
│   ├── predictUI.py
│   ├── centernet.py
│   └── get_map.py
├── .gitattributes
├── PKLot_client
│   ├── background.jpg
│   ├── PKlot_client.py
│   ├── client_window.ui
│   └── client_window.py
└── README.md
/PKLot_server/detect_Logs/detect.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/PKLot_server/detect_Logs/logs.txt:
--------------------------------------------------------------------------------
1 | 2023-01-09 08:08:22;44;0;0
--------------------------------------------------------------------------------
/PKLot_server/model_data/coco_classes.txt:
--------------------------------------------------------------------------------
1 | space-occupied
2 | space-empty
3 |
--------------------------------------------------------------------------------
/PKLot_server/model_data/voc_classes.txt:
--------------------------------------------------------------------------------
1 | space-occupied
2 | space-empty
3 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/PKLot_client/background.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_client/background.jpg
--------------------------------------------------------------------------------
/PKLot_server/detect_Logs/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/detect_Logs/1.png
--------------------------------------------------------------------------------
/PKLot_server/model_data/timg.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/model_data/timg.jpg
--------------------------------------------------------------------------------
/PKLot_server/model_data/simhei.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/model_data/simhei.ttf
--------------------------------------------------------------------------------
/PKLot_server/nets/__pycache__/resnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/nets/__pycache__/resnet.cpython-37.pyc
--------------------------------------------------------------------------------
/PKLot_server/nets/__pycache__/resnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/nets/__pycache__/resnet.cpython-38.pyc
--------------------------------------------------------------------------------
/PKLot_server/utils/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/utils/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/PKLot_server/utils/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/utils/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/PKLot_server/nets/__pycache__/centernet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/nets/__pycache__/centernet.cpython-37.pyc
--------------------------------------------------------------------------------
/PKLot_server/nets/__pycache__/centernet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/nets/__pycache__/centernet.cpython-38.pyc
--------------------------------------------------------------------------------
/PKLot_server/nets/__pycache__/hourglass.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/nets/__pycache__/hourglass.cpython-37.pyc
--------------------------------------------------------------------------------
/PKLot_server/nets/__pycache__/hourglass.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/nets/__pycache__/hourglass.cpython-38.pyc
--------------------------------------------------------------------------------
/PKLot_server/nets/__pycache__/centernet_training.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/nets/__pycache__/centernet_training.cpython-37.pyc
--------------------------------------------------------------------------------
/PKLot_server/nets/__pycache__/centernet_training.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AiXing-w/Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics/HEAD/PKLot_server/nets/__pycache__/centernet_training.cpython-38.pyc
--------------------------------------------------------------------------------
/PKLot_server/data_fack.py:
--------------------------------------------------------------------------------
1 | import random
2 | max_set = 42
3 | year = "2022"
4 | month = "5"
5 | day = "16"
6 |
7 | with open("./detect_Logs/countlogs", 'a') as f:
8 | for i in range(24):
9 | for j in range(60):
10 | for k in range(0, 60, 5):
11 | rdi = random.randint(0, max_set)
12 | f.write(str(i) + ";" + str(rdi) + ";" + str(max_set-rdi))
13 | f.write("\n")
14 |
--------------------------------------------------------------------------------
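Note that data_fack.py defines year/month/day and loops over j (minutes) and k (seconds) but writes only the hour i into each record, so the generated log carries no full timestamps. A minimal corrected sketch, assuming a full `date time;occupied;free` record was intended (the exact field layout expected downstream is not shown, so treat the format as hypothetical):

    import random

    max_set = 42
    year, month, day = "2022", "5", "16"

    with open("./detect_Logs/countlogs", 'a') as f:
        for i in range(24):                # hour
            for j in range(60):            # minute
                for k in range(0, 60, 5):  # second, every 5 s
                    rdi = random.randint(0, max_set)
                    # hypothetical record: "YYYY-M-D HH:MM:SS;occupied;free"
                    f.write("%s-%s-%s %02d:%02d:%02d;%d;%d\n"
                            % (year, month, day, i, j, k, rdi, max_set - rdi))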
/PKLot_server/DataServer.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | def FormetDayTime(datelist):
4 | day = ""
5 | for i in range(3):
6 | if i:
7 | day += '-'
8 | day += datelist[i]
9 |
10 | time = ""
11 | for i in range(3,len(datelist)):
12 | if i != 3:
13 | time += ':'
14 | time += datelist[i]
15 | return day, time
16 |
17 |
18 | def SaveData(day, time, num):
19 |     # Save the count
20 | if not os.path.exists("dayLogs"):
21 | os.mkdir("dayLogs")
22 | with open(os.path.join("dayLogs", day), "a", encoding='utf-8') as f:
23 | f.write(str(num))
24 | f.write("\t")
25 | f.write(time)
26 | f.write("\n")
27 |
--------------------------------------------------------------------------------
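A quick usage sketch (the values are hypothetical): FormetDayTime takes a flat list of date-time components, year/month/day first, and rejoins them into a 'YYYY-MM-DD' date and an 'HH:MM:SS' time; SaveData then appends the count to a per-day log named after that date.

    from DataServer import FormetDayTime, SaveData

    day, time = FormetDayTime(["2023", "01", "09", "08", "08", "22"])
    print(day)               # 2023-01-09
    print(time)              # 08:08:22
    SaveData(day, time, 44)  # appends "44\t08:08:22\n" to dayLogs/2023-01-09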
/README.md:
--------------------------------------------------------------------------------
1 | # Vehicles-ParkingSpaces-Object-detection-and-traffic-statistics
2 | Detects and counts the vehicles and parking spaces in an area with an object-detection algorithm, so that users can learn the parking situation ahead of time and be guided to a parking spot; it also collects statistics on the actual traffic flow.
3 | 
4 | # Features
5 | ## Server (PKLot_server)
6 | PKLot_server is the server side. It runs object detection on the parking spaces to judge whether a car is already parked in each space, computes the optimal parking region, and records space occupancy for statistical visualization. The detection results are stored in detect_Logs for later analysis; when statistics are requested, the server reads the corresponding data from detect_Logs and plots it.
7 | 
8 | ## Client (PKLot_client)
9 | PKLot_client is the client side. It issues parking-space queries, receives the occupancy status detected by the server together with the optimal parking region, and annotates them on the display to guide the user.
10 | 
11 | ## Visualization
12 | 
13 | The visualization windows are built with PyQt5.
14 | 
15 | # Weight files
16 | 
17 | After downloading, the weight files should be placed in the \PKLot_server\model_data folder.
18 | 
19 | **Weight file download:**
20 | 
21 | Link: https://pan.baidu.com/s/1X_iAieXwVn02HEo5ihFSFw?pwd=xqbr
22 | 
23 | Extraction code: xqbr
24 | 
25 | # Test samples
26 | 
27 | **Test sample download link:**
28 | 
29 | Link: https://pan.baidu.com/s/1wH5_PkoUYxQTXvEt6wzoHA?pwd=m76z
30 | 
31 | Extraction code: m76z
32 | 
33 | # Demo video
34 | 
35 | Video link: [https://www.bilibili.com/video/BV1XG4y1w7kT](https://www.bilibili.com/video/BV1XG4y1w7kT)
36 |
--------------------------------------------------------------------------------
/PKLot_server/PKLotServer.py:
--------------------------------------------------------------------------------
1 | #!coding=utf-8
2 | import sys
3 | import threading
4 | import socket
5 | import struct
6 |
7 |
8 | def socket_service():
9 | try:
10 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
11 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
12 |         # Bind to port 9001
13 |         s.bind(('127.0.0.1', 9001))
14 |         # Set the listen backlog
15 | s.listen(10)
16 | except socket.error as msg:
17 | print(msg)
18 | sys.exit(1)
19 | print('Waiting connection...')
20 |
21 | while 1:
22 |         # Wait for and accept a request (the program blocks here; once a connection request arrives, a thread is started to receive the data)
23 |         conn, addr = s.accept()
24 |         # Receive the data on a separate thread
25 | t = threading.Thread(target=deal_data, args=(conn, addr))
26 | t.start()
27 |
28 |
29 | def deal_data(conn, addr):
30 | print('Accept new connection from {0}'.format(addr))
31 | # conn.settimeout(500)
32 |     # Reply sent once the connection is accepted
33 | conn.send('Hi, Welcome to the server!'.encode('utf-8'))
34 |
35 | while 1:
36 |         # Size of the fixed-length header carrying the file name and file size
37 | fileinfo_size = struct.calcsize('128sl')
38 |         # Receive the file name / file size header
39 |         buf = conn.recv(fileinfo_size)
40 |         # Check whether the header was received
41 | if buf:
42 |             # Unpack the file name and file size
43 | filename, filesize = struct.unpack('128sl', buf)
44 | fn = filename.strip(b'\00')
45 | fn = fn.decode()
46 |             print('new file name is {0}, filesize is {1}'.format(str(fn), filesize))
47 |
48 |             recvd_size = 0  # bytes received so far
49 |             # Store the file in the directory this script runs from
50 | fp = open('./' + str(fn), 'wb')
51 | print('start receiving...')
52 |
53 |             # Write the chunked binary stream to the file
54 | while not recvd_size == filesize:
55 | if filesize - recvd_size > 1024:
56 | data = conn.recv(1024)
57 | recvd_size += len(data)
58 | else:
59 | data = conn.recv(filesize - recvd_size)
60 | recvd_size = filesize
61 | fp.write(data)
62 | fp.close()
63 | print('end receive...')
64 |             # Transfer finished; close the connection
65 | conn.close()
66 | break
67 |
68 |
69 | if __name__ == "__main__":
70 | socket_service()
71 |
--------------------------------------------------------------------------------
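PKLotServer.py expects each upload to begin with a struct.pack('128sl', ...) header (a null-padded 128-byte file name plus the file size), followed by the raw bytes. A minimal matching sender, as a sketch; the address is the server's hard-coded 127.0.0.1:9001, and because '128sl' uses native sizing, both ends must agree on the size of 'l' (run them on the same platform, or switch both sides to an explicit format such as '=128sq'):

    import os
    import socket
    import struct

    def send_file(path, host='127.0.0.1', port=9001):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.connect((host, port))
        print(s.recv(1024).decode('utf-8'))   # server greeting
        name = os.path.basename(path).encode()
        # header layout must match the server's struct.unpack('128sl', buf)
        s.send(struct.pack('128sl', name, os.path.getsize(path)))
        with open(path, 'rb') as fp:
            while True:
                data = fp.read(1024)
                if not data:
                    break
                s.send(data)
        s.close()

    if __name__ == '__main__':
        send_file('model_data/timg.jpg')   # any file; the path is illustrative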
/PKLot_server/get_gt_txt.py:
--------------------------------------------------------------------------------
1 | #----------------------------------------------------#
2 | #   Generate the ground-truth files for the test set.
3 | #   A video tutorial (in Chinese) is available at
4 | #   https://www.bilibili.com/video/BV1zE411u7Vw
5 | #----------------------------------------------------#
6 | import sys
7 | import os
8 | import glob
9 | import xml.etree.ElementTree as ET
10 |
11 | '''
12 | !!!!!!!!!!!!! NOTE !!!!!!!!!!!!!
13 | # If the XML files contain irrelevant classes, there is code below to filter them out!
14 | '''
15 | #---------------------------------------------------#
16 | #   Load the classes
17 | #---------------------------------------------------#
18 | def get_classes(classes_path):
19 | '''loads the classes'''
20 | with open(classes_path) as f:
21 | class_names = f.readlines()
22 | class_names = [c.strip() for c in class_names]
23 | return class_names
24 |
25 | image_ids = open('VOCdevkit/VOC2007/ImageSets/Main/test.txt').read().strip().split()
26 |
27 | if not os.path.exists("./input"):
28 | os.makedirs("./input")
29 | if not os.path.exists("./input/ground-truth"):
30 | os.makedirs("./input/ground-truth")
31 |
32 | for image_id in image_ids:
33 | with open("./input/ground-truth/"+image_id+".txt", "w") as new_f:
34 | root = ET.parse("VOCdevkit/VOC2007/Annotations/"+image_id+".xml").getroot()
35 | for obj in root.findall('object'):
36 | difficult_flag = False
37 | if obj.find('difficult')!=None:
38 | difficult = obj.find('difficult').text
39 | if int(difficult)==1:
40 | difficult_flag = True
41 | obj_name = obj.find('name').text
42 |             '''
43 |             !!!!!!!!!!!! NOTE !!!!!!!!!!!!
44 |             # If the XML files contain irrelevant classes, uncomment the code below
45 |             # and use the corresponding classes.txt to filter them!
46 |             '''
47 | # classes_path = 'model_data/voc_classes.txt'
48 | # class_names = get_classes(classes_path)
49 | # if obj_name not in class_names:
50 | # continue
51 |
52 | bndbox = obj.find('bndbox')
53 | left = bndbox.find('xmin').text
54 | top = bndbox.find('ymin').text
55 | right = bndbox.find('xmax').text
56 | bottom = bndbox.find('ymax').text
57 |
58 | if difficult_flag:
59 | new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom))
60 | else:
61 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
62 |
63 | print("Conversion completed!")
64 |
--------------------------------------------------------------------------------
/PKLot_client/PKlot_client.py:
--------------------------------------------------------------------------------
1 | from client_window import Ui_MainWindow
2 | from PyQt5.QtWidgets import QWidget, QMainWindow, QApplication, QGraphicsScene, QGraphicsPixmapItem, QFileDialog, QMessageBox
3 | from PyQt5.QtGui import QImage, QPixmap
4 | import sys
5 | import os
6 | import datetime
7 | import cv2
8 | import numpy as np
9 | from PIL import Image, ImageDraw, ImageFont
10 |
11 | import multiprocessing as mp
12 | import socket
13 | from PIL import Image
14 | from io import BytesIO
15 |
16 |
17 | class client_window(QMainWindow, Ui_MainWindow):
18 | def __init__(self):
19 | super().__init__()
20 | self.setupUi(self)
21 |         img_src = cv2.imread("background.jpg")  # read the image
22 |         img_src = cv2.cvtColor(img_src, cv2.COLOR_BGR2RGB)  # convert the channel order from BGR to RGB
23 | label_width = self.label.width()
24 | label_height = self.label.height()
25 | temp_imgSrc = QImage(img_src[:], img_src.shape[1], img_src.shape[0], img_src.shape[1] * 3, QImage.Format_RGB888)
26 |         # Convert the image to a QPixmap for display
27 | self.pixmap_imgSrc = QPixmap.fromImage(temp_imgSrc).scaled(label_width, label_height)
28 | self.label.setPixmap(QPixmap(self.pixmap_imgSrc))
29 | self.pushButton.clicked.connect(self.request)
30 |
31 | def request(self):
32 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
33 | sock.connect(('127.0.0.1', 8000))
34 | sock.send(b'123') # (b'GET / HTTP/1.1\r\nHost: 127.0.0.1:8000\r\n\r\n')
35 | data = sock.recv(4096)
36 | date, op, ep, cnt = data.decode().split(';')
37 | self.lineEdit.setText(date)
38 | self.lineEdit_2.setText(op)
39 | self.lineEdit_3.setText(ep)
40 | sock.close()
41 | cnt = int(cnt.strip())
42 |         img_src = cv2.imread("background.jpg")  # read the image
43 |         img_src = cv2.cvtColor(img_src, cv2.COLOR_BGR2RGB)  # convert the channel order from BGR to RGB
44 |
45 | label_width = self.label.width()
46 | label_height = self.label.height()
47 |
48 | if cnt == 1:
49 | cv2.rectangle(img_src, (10, 10), (img_src.shape[0] // 2, img_src.shape[1] // 2), (255, 0, 0), 5)
50 | textx = (img_src.shape[0] // 2 - 15) // 3
51 | texty = (img_src.shape[1] // 2 - 50) // 3
52 | cv2.putText(img_src, "Best", (textx,texty), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 5)
53 | cv2.putText(img_src, "Region", (textx - 15, texty + 80), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 5)
54 | elif cnt == 2:
55 | cv2.rectangle(img_src, (img_src.shape[0] // 2, 0), (img_src.shape[0], img_src.shape[1] // 2 - 10), (255, 0, 0), 5)
56 | textx = img_src.shape[0] // 2 + (img_src.shape[0] // 2) // 3
57 | texty = (img_src.shape[1] // 2 - 50) // 3
58 | cv2.putText(img_src, "Best", (textx, texty), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 5)
59 | cv2.putText(img_src, "Region", (textx - 15, texty + 80), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 5)
60 | elif cnt == 3:
61 | cv2.rectangle(img_src, (10, img_src.shape[0] // 2), (img_src.shape[0] // 2, img_src.shape[1] - 10), (255, 0, 0), 5)
62 | textx = (img_src.shape[0] // 2 - 15) // 3
63 | texty = img_src.shape[0] // 2 + (img_src.shape[1] // 2 - 50) // 3
64 | cv2.putText(img_src, "Best", (textx, texty), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 5)
65 | cv2.putText(img_src, "Region", (textx - 15, texty + 80), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 5)
66 |
67 | elif cnt == 4:
68 | cv2.rectangle(img_src, (img_src.shape[0] // 2, img_src.shape[1] // 2), (img_src.shape[0] - 10, img_src.shape[1] - 10), (255, 0, 0), 5)
69 | textx = img_src.shape[0] // 2 + (img_src.shape[0] // 2 - 15) // 3
70 | texty = img_src.shape[0] // 2 + (img_src.shape[1] // 2 - 50) // 3
71 | cv2.putText(img_src, "Best", (textx, texty), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 5)
72 | cv2.putText(img_src, "Region", (textx - 15, texty + 80), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 5)
73 |
74 |
75 | temp_imgSrc = QImage(img_src[:], img_src.shape[1], img_src.shape[0], img_src.shape[1] * 3, QImage.Format_RGB888)
76 |         # Convert the image to a QPixmap for display
77 | self.pixmap_imgSrc = QPixmap.fromImage(temp_imgSrc).scaled(label_width, label_height)
78 | self.label.setPixmap(QPixmap(self.pixmap_imgSrc))
79 |
80 | if __name__ == '__main__':
81 | app = QApplication(sys.argv)
82 | window = client_window()
83 | window.show()
84 | sys.exit(app.exec_())
85 |
--------------------------------------------------------------------------------
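The client sends an arbitrary request body to 127.0.0.1:8000 and expects a single 'date;occupied;empty;region' reply, where region is 1-4 and selects the quadrant highlighted as the best parking area. A throwaway mock server for exercising the UI without the detection backend, as a sketch (the reply values are made up):

    import socket

    srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    srv.bind(('127.0.0.1', 8000))
    srv.listen(1)
    while True:
        conn, addr = srv.accept()
        conn.recv(1024)   # the client's dummy request (b'123')
        # date;occupied;empty;best-region -- the fields request() splits on ';'
        conn.send(b'2023-01-09 08:08:22;44;0;1')
        conn.close()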
/PKLot_client/client_window.ui:
--------------------------------------------------------------------------------
[Qt Designer XML for the client window; the markup was stripped during extraction and only the text nodes survive. Recoverable content: a 677x385 MainWindow holding a QLabel with placeholder text "连接服务器失败...请稍后重试" (failed to connect to the server... please retry later), three labeled read-only line edits "查询时间:" (query time), "已占车位:" (occupied spaces) and "剩余车位:" (remaining spaces), and a "查询车位" (query spaces) push button. The generated client_window.py below reproduces the same layout.]
--------------------------------------------------------------------------------
/PKLot_client/client_window.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Form implementation generated from reading ui file 'client_window.ui'
4 | #
5 | # Created by: PyQt5 UI code generator 5.15.4
6 | #
7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is
8 | # run again. Do not edit this file unless you know what you are doing.
9 |
10 |
11 | from PyQt5 import QtCore, QtGui, QtWidgets
12 |
13 |
14 | class Ui_MainWindow(object):
15 | def setupUi(self, MainWindow):
16 | MainWindow.setObjectName("MainWindow")
17 | MainWindow.resize(677, 385)
18 | self.centralwidget = QtWidgets.QWidget(MainWindow)
19 | self.centralwidget.setObjectName("centralwidget")
20 | self.label = QtWidgets.QLabel(self.centralwidget)
21 | self.label.setGeometry(QtCore.QRect(20, 20, 261, 291))
22 | self.label.setObjectName("label")
23 | self.label_2 = QtWidgets.QLabel(self.centralwidget)
24 | self.label_2.setGeometry(QtCore.QRect(30, 520, 631, 41))
25 | self.label_2.setText("")
26 | self.label_2.setObjectName("label_2")
27 | self.layoutWidget = QtWidgets.QWidget(self.centralwidget)
28 | self.layoutWidget.setGeometry(QtCore.QRect(313, 40, 321, 271))
29 | self.layoutWidget.setObjectName("layoutWidget")
30 | self.horizontalLayout_4 = QtWidgets.QHBoxLayout(self.layoutWidget)
31 | self.horizontalLayout_4.setContentsMargins(0, 0, 0, 0)
32 | self.horizontalLayout_4.setObjectName("horizontalLayout_4")
33 | self.line = QtWidgets.QFrame(self.layoutWidget)
34 | self.line.setFrameShape(QtWidgets.QFrame.VLine)
35 | self.line.setFrameShadow(QtWidgets.QFrame.Sunken)
36 | self.line.setObjectName("line")
37 | self.horizontalLayout_4.addWidget(self.line)
38 | self.verticalLayout = QtWidgets.QVBoxLayout()
39 | self.verticalLayout.setObjectName("verticalLayout")
40 | self.horizontalLayout = QtWidgets.QHBoxLayout()
41 | self.horizontalLayout.setObjectName("horizontalLayout")
42 | self.label_3 = QtWidgets.QLabel(self.layoutWidget)
43 | self.label_3.setObjectName("label_3")
44 | self.horizontalLayout.addWidget(self.label_3)
45 | self.lineEdit = QtWidgets.QLineEdit(self.layoutWidget)
46 | self.lineEdit.setEnabled(False)
47 | self.lineEdit.setReadOnly(True)
48 | self.lineEdit.setObjectName("lineEdit")
49 | self.horizontalLayout.addWidget(self.lineEdit)
50 | self.verticalLayout.addLayout(self.horizontalLayout)
51 | self.horizontalLayout_2 = QtWidgets.QHBoxLayout()
52 | self.horizontalLayout_2.setObjectName("horizontalLayout_2")
53 | self.label_4 = QtWidgets.QLabel(self.layoutWidget)
54 | self.label_4.setObjectName("label_4")
55 | self.horizontalLayout_2.addWidget(self.label_4)
56 | self.lineEdit_2 = QtWidgets.QLineEdit(self.layoutWidget)
57 | self.lineEdit_2.setEnabled(False)
58 | self.lineEdit_2.setReadOnly(True)
59 | self.lineEdit_2.setObjectName("lineEdit_2")
60 | self.horizontalLayout_2.addWidget(self.lineEdit_2)
61 | self.verticalLayout.addLayout(self.horizontalLayout_2)
62 | self.horizontalLayout_3 = QtWidgets.QHBoxLayout()
63 | self.horizontalLayout_3.setObjectName("horizontalLayout_3")
64 | self.label_5 = QtWidgets.QLabel(self.layoutWidget)
65 | self.label_5.setObjectName("label_5")
66 | self.horizontalLayout_3.addWidget(self.label_5)
67 | self.lineEdit_3 = QtWidgets.QLineEdit(self.layoutWidget)
68 | self.lineEdit_3.setEnabled(False)
69 | self.lineEdit_3.setReadOnly(True)
70 | self.lineEdit_3.setObjectName("lineEdit_3")
71 | self.horizontalLayout_3.addWidget(self.lineEdit_3)
72 | self.verticalLayout.addLayout(self.horizontalLayout_3)
73 | self.pushButton = QtWidgets.QPushButton(self.layoutWidget)
74 | self.pushButton.setObjectName("pushButton")
75 | self.verticalLayout.addWidget(self.pushButton)
76 | self.horizontalLayout_4.addLayout(self.verticalLayout)
77 | MainWindow.setCentralWidget(self.centralwidget)
78 | self.menubar = QtWidgets.QMenuBar(MainWindow)
79 | self.menubar.setGeometry(QtCore.QRect(0, 0, 677, 26))
80 | self.menubar.setObjectName("menubar")
81 | MainWindow.setMenuBar(self.menubar)
82 | self.statusbar = QtWidgets.QStatusBar(MainWindow)
83 | self.statusbar.setObjectName("statusbar")
84 | MainWindow.setStatusBar(self.statusbar)
85 |
86 | self.retranslateUi(MainWindow)
87 | QtCore.QMetaObject.connectSlotsByName(MainWindow)
88 |
89 | def retranslateUi(self, MainWindow):
90 | _translate = QtCore.QCoreApplication.translate
91 | MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
92 | self.label.setText(_translate("MainWindow", "连接服务器失败...请稍后重试"))
93 | self.label_3.setText(_translate("MainWindow", "查询时间:"))
94 | self.label_4.setText(_translate("MainWindow", "已占车位:"))
95 | self.label_5.setText(_translate("MainWindow", "剩余车位:"))
96 | self.pushButton.setText(_translate("MainWindow", "查询车位"))
97 |
--------------------------------------------------------------------------------
/PKLot_server/get_dr_txt.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import tensorflow as tf
5 | from PIL import Image
6 | from tqdm import tqdm
7 |
8 | from centernet import CenterNet
9 | from nets.centernet import centernet
10 | from utils.utils import centernet_correct_boxes, letterbox_image, nms
11 |
12 | gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
13 | for gpu in gpus:
14 | tf.config.experimental.set_memory_growth(gpu, True)
15 |
16 | '''
17 | The confidence threshold is set low here because computing mAP needs the Recall and Precision values at many different thresholds.
18 | Only if enough boxes are kept will the computed mAP be accurate; see the principles behind mAP for details.
19 | The Recall and Precision printed while computing mAP are the values at a threshold of 0.5.
20 | 
21 | The txt files produced in ./input/detection-results/ will contain more boxes than a direct predict, because the threshold here is low;
22 | the goal is to obtain Recall and Precision at different thresholds so that mAP can be computed.
23 | 
24 | self.nms_threhold here is the IoU used for non-maximum suppression; see how NMS works for details.
25 | If a low-scoring box overlaps a higher-scoring box with an IoU above self.nms_threhold, the low-scoring box is removed.
26 | 
27 | Some readers may know of mAP@0.5 and mAP@0.5:0.95; self.nms_threhold=0.5 here does not mean mAP@0.5.
28 | To evaluate mAP at another overlap, e.g. mAP@0.75, set MINOVERLAP in get_map.py.
29 | '''
30 | def preprocess_image(image):
31 | mean = [0.40789655, 0.44719303, 0.47026116]
32 | std = [0.2886383 , 0.27408165, 0.27809834]
33 | return ((np.float32(image) / 255.) - mean) / std
34 |
35 | class mAP_CenterNet(CenterNet):
36 |     #---------------------------------------------------#
37 |     #   Detect an image
38 |     #---------------------------------------------------#
39 | def detect_image(self,image_id,image):
40 | f = open("./input/detection-results/"+image_id+".txt","w")
41 | self.confidence = 0.01
42 | self.nms_threhold = 0.5
43 |         #---------------------------------------------------------#
44 |         #   Convert the image to RGB, so grayscale images do not error out during prediction.
45 |         #---------------------------------------------------------#
46 | image = image.convert('RGB')
47 |
48 | image_shape = np.array(np.shape(image)[0:2])
49 |         #---------------------------------------------------------#
50 |         #   Pad the image with gray bars for a distortion-free resize
51 |         #---------------------------------------------------------#
52 | crop_img = letterbox_image(image, [self.input_shape[0],self.input_shape[1]])
53 |         #----------------------------------------------------------------------------------#
54 |         #   Convert RGB to BGR, because the original centernet_hourglass weights were trained on BGR images
55 |         #----------------------------------------------------------------------------------#
56 | photo = np.array(crop_img,dtype = np.float32)[:,:,::-1]
57 |         #-----------------------------------------------------------#
58 |         #   Preprocess and normalize the image; the resulting photo has shape [1, 512, 512, 3]
59 |         #-----------------------------------------------------------#
60 | photo = np.reshape(preprocess_image(photo),[1,self.input_shape[0],self.input_shape[1],self.input_shape[2]])
61 |
62 | preds = self.get_pred(photo).numpy()
63 |         #-------------------------------------------------------#
64 |         #   For CenterNet, establishing the center is crucial.
65 |         #   Large objects contain plenty of local information,
66 |         #   so for one large object the center point is hard to pin down.
67 |         #   Max-pooling NMS cannot remove these local boxes,
68 |         #   so an extra pass of box-level NMS is applied here.
69 |         #   In practice the extra NMS changes little with an hourglass backbone, but matters more with resnet.
70 |         #-------------------------------------------------------#
71 | if self.nms:
72 | preds = np.array(nms(preds,self.nms_threhold))
73 |
74 | if len(preds[0])<=0:
75 | return
76 |
77 |         #-----------------------------------------------------------#
78 |         #   Convert the predictions to fractional (normalized) coordinates
79 |         #-----------------------------------------------------------#
80 | preds[0][:,0:4] = preds[0][:,0:4]/(self.input_shape[0]/4)
81 |
82 | det_label = preds[0][:, -1]
83 | det_conf = preds[0][:, -2]
84 | det_xmin, det_ymin, det_xmax, det_ymax = preds[0][:, 0], preds[0][:, 1], preds[0][:, 2], preds[0][:, 3]
85 |         #-----------------------------------------------------------#
86 |         #   Keep only the boxes whose score exceeds the confidence threshold
87 |         #-----------------------------------------------------------#
88 | top_indices = [i for i, conf in enumerate(det_conf) if conf >= self.confidence]
89 | top_conf = det_conf[top_indices]
90 | top_label_indices = det_label[top_indices].tolist()
91 | top_xmin, top_ymin, top_xmax, top_ymax = np.expand_dims(det_xmin[top_indices],-1),np.expand_dims(det_ymin[top_indices],-1),np.expand_dims(det_xmax[top_indices],-1),np.expand_dims(det_ymax[top_indices],-1)
92 |
93 |         #-----------------------------------------------------------#
94 |         #   Remove the gray-bar padding
95 |         #-----------------------------------------------------------#
96 | boxes = centernet_correct_boxes(top_ymin,top_xmin,top_ymax,top_xmax,np.array([self.input_shape[0],self.input_shape[1]]),image_shape)
97 |
98 | for i, c in enumerate(top_label_indices):
99 | predicted_class = self.class_names[int(c)]
100 | score = str(top_conf[i])
101 |
102 | top, left, bottom, right = boxes[i]
103 | f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))
104 |
105 | f.close()
106 | return
107 |
108 | centernet = mAP_CenterNet()
109 | image_ids = open('VOCdevkit/VOC2007/ImageSets/Main/test.txt').read().strip().split()
110 |
111 | if not os.path.exists("./input"):
112 | os.makedirs("./input")
113 | if not os.path.exists("./input/detection-results"):
114 | os.makedirs("./input/detection-results")
115 | if not os.path.exists("./input/images-optional"):
116 | os.makedirs("./input/images-optional")
117 |
118 | for image_id in tqdm(image_ids):
119 | image_path = "./VOCdevkit/VOC2007/JPEGImages/"+image_id+".jpg"
120 | image = Image.open(image_path)
121 | # image.save("./input/images-optional/"+image_id+".jpg")
122 | centernet.detect_image(image_id,image)
123 |
124 | print("Conversion completed!")
125 |
--------------------------------------------------------------------------------
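The preprocessing above can be checked in isolation: letterbox to 512x512, flip RGB to BGR, normalize with the mean/std from preprocess_image, and reshape to a batch of one. A standalone sketch with a dummy image (run from the PKLot_server directory so the utils import resolves):

    import numpy as np
    from PIL import Image

    from utils.utils import letterbox_image

    mean = [0.40789655, 0.44719303, 0.47026116]
    std = [0.2886383, 0.27408165, 0.27809834]

    image = Image.new('RGB', (640, 480))                       # dummy input
    crop_img = letterbox_image(image, [512, 512])              # pad to 512x512
    photo = np.array(crop_img, dtype=np.float32)[:, :, ::-1]   # RGB -> BGR
    photo = ((photo / 255.) - mean) / std                      # normalize
    photo = np.reshape(photo, [1, 512, 512, 3])
    print(photo.shape)                                         # (1, 512, 512, 3)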
/PKLot_server/MainWindow.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Form implementation generated from reading ui file 'MainWindow.ui'
4 | #
5 | # Created by: PyQt5 UI code generator 5.15.4
6 | #
7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is
8 | # run again. Do not edit this file unless you know what you are doing.
9 |
10 |
11 | from PyQt5 import QtCore, QtGui, QtWidgets
12 |
13 |
14 | class Ui_MainWindow(object):
15 | def setupUi(self, MainWindow):
16 | MainWindow.setObjectName("MainWindow")
17 | MainWindow.resize(800, 600)
18 | self.centralwidget = QtWidgets.QWidget(MainWindow)
19 | self.centralwidget.setObjectName("centralwidget")
20 | self.layoutWidget = QtWidgets.QWidget(self.centralwidget)
21 | self.layoutWidget.setGeometry(QtCore.QRect(21, 11, 131, 551))
22 | self.layoutWidget.setObjectName("layoutWidget")
23 | self.verticalLayout = QtWidgets.QVBoxLayout(self.layoutWidget)
24 | self.verticalLayout.setContentsMargins(0, 0, 0, 0)
25 | self.verticalLayout.setObjectName("verticalLayout")
26 | self.pushButton_1 = QtWidgets.QPushButton(self.layoutWidget)
27 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed)
28 | sizePolicy.setHorizontalStretch(20)
29 | sizePolicy.setVerticalStretch(20)
30 | sizePolicy.setHeightForWidth(self.pushButton_1.sizePolicy().hasHeightForWidth())
31 | self.pushButton_1.setSizePolicy(sizePolicy)
32 | self.pushButton_1.setSizeIncrement(QtCore.QSize(20, 20))
33 | self.pushButton_1.setBaseSize(QtCore.QSize(20, 20))
34 | font = QtGui.QFont()
35 | font.setPointSize(10)
36 | self.pushButton_1.setFont(font)
37 | self.pushButton_1.setObjectName("pushButton_1")
38 | self.verticalLayout.addWidget(self.pushButton_1)
39 | self.pushButton_2 = QtWidgets.QPushButton(self.layoutWidget)
40 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed)
41 | sizePolicy.setHorizontalStretch(20)
42 | sizePolicy.setVerticalStretch(20)
43 | sizePolicy.setHeightForWidth(self.pushButton_2.sizePolicy().hasHeightForWidth())
44 | self.pushButton_2.setSizePolicy(sizePolicy)
45 | self.pushButton_2.setSizeIncrement(QtCore.QSize(20, 20))
46 | self.pushButton_2.setBaseSize(QtCore.QSize(20, 20))
47 | font = QtGui.QFont()
48 | font.setPointSize(10)
49 | self.pushButton_2.setFont(font)
50 | self.pushButton_2.setObjectName("pushButton_2")
51 | self.verticalLayout.addWidget(self.pushButton_2)
52 | self.pushButton_3 = QtWidgets.QPushButton(self.layoutWidget)
53 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed)
54 | sizePolicy.setHorizontalStretch(20)
55 | sizePolicy.setVerticalStretch(20)
56 | sizePolicy.setHeightForWidth(self.pushButton_3.sizePolicy().hasHeightForWidth())
57 | self.pushButton_3.setSizePolicy(sizePolicy)
58 | self.pushButton_3.setSizeIncrement(QtCore.QSize(20, 20))
59 | self.pushButton_3.setBaseSize(QtCore.QSize(20, 20))
60 | font = QtGui.QFont()
61 | font.setPointSize(10)
62 | self.pushButton_3.setFont(font)
63 | self.pushButton_3.setObjectName("pushButton_3")
64 | self.verticalLayout.addWidget(self.pushButton_3)
65 | self.pushButton = QtWidgets.QPushButton(self.layoutWidget)
66 | self.pushButton.setObjectName("pushButton")
67 | self.verticalLayout.addWidget(self.pushButton)
68 | self.label = QtWidgets.QLabel(self.centralwidget)
69 | self.label.setGeometry(QtCore.QRect(160, 10, 621, 511))
70 | self.label.setObjectName("label")
71 | MainWindow.setCentralWidget(self.centralwidget)
72 | self.statusbar = QtWidgets.QStatusBar(MainWindow)
73 | self.statusbar.setObjectName("statusbar")
74 | MainWindow.setStatusBar(self.statusbar)
75 | self.action_16 = QtWidgets.QAction(MainWindow)
76 | self.action_16.setObjectName("action_16")
77 | self.actiona = QtWidgets.QAction(MainWindow)
78 | self.actiona.setObjectName("actiona")
79 | self.actiona_2 = QtWidgets.QAction(MainWindow)
80 | self.actiona_2.setObjectName("actiona_2")
81 | self.actionb = QtWidgets.QAction(MainWindow)
82 | self.actionb.setObjectName("actionb")
83 | self.actiona_3 = QtWidgets.QAction(MainWindow)
84 | self.actiona_3.setObjectName("actiona_3")
85 | self.actiona_4 = QtWidgets.QAction(MainWindow)
86 | self.actiona_4.setObjectName("actiona_4")
87 | self.actionb_2 = QtWidgets.QAction(MainWindow)
88 | self.actionb_2.setObjectName("actionb_2")
89 | self.actiona_5 = QtWidgets.QAction(MainWindow)
90 | self.actiona_5.setObjectName("actiona_5")
91 | self.actionb_3 = QtWidgets.QAction(MainWindow)
92 | self.actionb_3.setObjectName("actionb_3")
93 |
94 | self.retranslateUi(MainWindow)
95 | QtCore.QMetaObject.connectSlotsByName(MainWindow)
96 |
97 | def retranslateUi(self, MainWindow):
98 | _translate = QtCore.QCoreApplication.translate
99 | MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
100 | self.pushButton_1.setText(_translate("MainWindow", "图片检测"))
101 | self.pushButton_2.setText(_translate("MainWindow", "视频检测"))
102 | self.pushButton_3.setText(_translate("MainWindow", "实时监控"))
103 | self.pushButton.setText(_translate("MainWindow", "统计时段"))
104 | self.label.setText(_translate("MainWindow", "检测中..."))
105 | self.action_16.setText(_translate("MainWindow", "按天统计"))
106 | self.actiona.setText(_translate("MainWindow", "a"))
107 | self.actiona_2.setText(_translate("MainWindow", "按月统计"))
108 | self.actionb.setText(_translate("MainWindow", "按年统计"))
109 | self.actiona_3.setText(_translate("MainWindow", "a"))
110 | self.actiona_4.setText(_translate("MainWindow", "实时"))
111 | self.actionb_2.setText(_translate("MainWindow", "打开文件"))
112 | self.actiona_5.setText(_translate("MainWindow", "暂停"))
113 | self.actionb_3.setText(_translate("MainWindow", "全屏"))
114 |
--------------------------------------------------------------------------------
/PKLot_server/nets/resnet.py:
--------------------------------------------------------------------------------
1 | #-------------------------------------------------------------#
2 | #   The ResNet50 backbone
3 | #-------------------------------------------------------------#
4 | from tensorflow.keras import layers
5 | from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D,
6 | Conv2DTranspose, Dropout, MaxPooling2D,
7 | ZeroPadding2D)
8 | from tensorflow.keras.regularizers import l2
9 | from tensorflow.keras.initializers import RandomNormal
10 |
11 |
12 | def identity_block(input_tensor, kernel_size, filters, stage, block):
13 |
14 | filters1, filters2, filters3 = filters
15 |
16 | conv_name_base = 'res' + str(stage) + block + '_branch'
17 | bn_name_base = 'bn' + str(stage) + block + '_branch'
18 |
19 | x = Conv2D(filters1, (1, 1), kernel_initializer=RandomNormal(stddev=0.02), name=conv_name_base + '2a', use_bias=False)(input_tensor)
20 | x = BatchNormalization(name=bn_name_base + '2a')(x)
21 | x = Activation('relu')(x)
22 |
23 | x = Conv2D(filters2, kernel_size, padding='same', kernel_initializer=RandomNormal(stddev=0.02), name=conv_name_base + '2b', use_bias=False)(x)
24 | x = BatchNormalization(name=bn_name_base + '2b')(x)
25 | x = Activation('relu')(x)
26 |
27 | x = Conv2D(filters3, (1, 1), kernel_initializer=RandomNormal(stddev=0.02), name=conv_name_base + '2c', use_bias=False)(x)
28 | x = BatchNormalization(name=bn_name_base + '2c')(x)
29 |
30 | x = layers.add([x, input_tensor])
31 | x = Activation('relu')(x)
32 | return x
33 |
34 |
35 | def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
36 |
37 | filters1, filters2, filters3 = filters
38 |
39 | conv_name_base = 'res' + str(stage) + block + '_branch'
40 | bn_name_base = 'bn' + str(stage) + block + '_branch'
41 |
42 | x = Conv2D(filters1, (1, 1), strides=strides, kernel_initializer=RandomNormal(stddev=0.02),
43 | name=conv_name_base + '2a', use_bias=False)(input_tensor)
44 | x = BatchNormalization(name=bn_name_base + '2a')(x)
45 | x = Activation('relu')(x)
46 |
47 | x = Conv2D(filters2, kernel_size, padding='same', kernel_initializer=RandomNormal(stddev=0.02),
48 | name=conv_name_base + '2b', use_bias=False)(x)
49 | x = BatchNormalization(name=bn_name_base + '2b')(x)
50 | x = Activation('relu')(x)
51 |
52 | x = Conv2D(filters3, (1, 1), kernel_initializer=RandomNormal(stddev=0.02), name=conv_name_base + '2c', use_bias=False)(x)
53 | x = BatchNormalization(name=bn_name_base + '2c')(x)
54 |
55 | shortcut = Conv2D(filters3, (1, 1), strides=strides, kernel_initializer=RandomNormal(stddev=0.02),
56 | name=conv_name_base + '1', use_bias=False)(input_tensor)
57 | shortcut = BatchNormalization(name=bn_name_base + '1')(shortcut)
58 |
59 | x = layers.add([x, shortcut])
60 | x = Activation('relu')(x)
61 | return x
62 |
63 |
64 | def ResNet50(inputs):
65 | # 512x512x3
66 | x = ZeroPadding2D((3, 3))(inputs)
67 | # 256,256,64
68 | x = Conv2D(64, (7, 7), kernel_initializer=RandomNormal(stddev=0.02), strides=(2, 2), name='conv1', use_bias=False)(x)
69 | x = BatchNormalization(name='bn_conv1')(x)
70 | x = Activation('relu')(x)
71 |
72 | # 256,256,64 -> 128,128,64
73 | x = MaxPooling2D((3, 3), strides=(2, 2), padding="same")(x)
74 |
75 | # 128,128,64 -> 128,128,256
76 | x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
77 | x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
78 | x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
79 |
80 | # 128,128,256 -> 64,64,512
81 | x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
82 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
83 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
84 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
85 |
86 | # 64,64,512 -> 32,32,1024
87 | x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
88 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
89 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
90 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
91 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
92 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
93 |
94 | # 32,32,1024 -> 16,16,2048
95 | x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
96 | x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
97 | x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
98 |
99 | return x
100 |
101 | def centernet_head(x,num_classes):
102 | x = Dropout(rate=0.5)(x)
103 |     #-------------------------------#
104 |     #   Decoder
105 |     #-------------------------------#
106 | num_filters = 256
107 | # 16, 16, 2048 -> 32, 32, 256 -> 64, 64, 128 -> 128, 128, 64
108 | for i in range(3):
109 |         # Upsample
110 | x = Conv2DTranspose(num_filters // pow(2, i), (4, 4), strides=2, use_bias=False, padding='same',
111 | kernel_initializer='he_normal',
112 | kernel_regularizer=l2(5e-4))(x)
113 | x = BatchNormalization()(x)
114 | x = Activation('relu')(x)
115 |     # The final feature map is 128,128,64
116 | # hm header
117 | y1 = Conv2D(64, 3, padding='same', use_bias=False, kernel_initializer=RandomNormal(stddev=0.02), kernel_regularizer=l2(5e-4))(x)
118 | y1 = BatchNormalization()(y1)
119 | y1 = Activation('relu')(y1)
120 | y1 = Conv2D(num_classes, 1, kernel_initializer=RandomNormal(stddev=0.02), kernel_regularizer=l2(5e-4), activation='sigmoid')(y1)
121 |
122 | # wh header
123 | y2 = Conv2D(64, 3, padding='same', use_bias=False, kernel_initializer=RandomNormal(stddev=0.02), kernel_regularizer=l2(5e-4))(x)
124 | y2 = BatchNormalization()(y2)
125 | y2 = Activation('relu')(y2)
126 | y2 = Conv2D(2, 1, kernel_initializer=RandomNormal(stddev=0.02), kernel_regularizer=l2(5e-4))(y2)
127 |
128 | # reg header
129 | y3 = Conv2D(64, 3, padding='same', use_bias=False, kernel_initializer=RandomNormal(stddev=0.02), kernel_regularizer=l2(5e-4))(x)
130 | y3 = BatchNormalization()(y3)
131 | y3 = Activation('relu')(y3)
132 | y3 = Conv2D(2, 1, kernel_initializer=RandomNormal(stddev=0.02), kernel_regularizer=l2(5e-4))(y3)
133 | return y1, y2, y3
134 |
--------------------------------------------------------------------------------
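The shape comments can be verified by wiring the backbone to the head on a dummy input; a minimal sketch, run from the PKLot_server directory (20 classes is an arbitrary choice):

    from tensorflow.keras.layers import Input
    from tensorflow.keras.models import Model

    from nets.resnet import ResNet50, centernet_head

    image_input = Input(shape=(512, 512, 3))
    C5 = ResNet50(image_input)                       # 16, 16, 2048
    y1, y2, y3 = centernet_head(C5, num_classes=20)  # hm, wh, reg heads
    model = Model(image_input, [y1, y2, y3])
    for t in model.outputs:
        print(t.shape)  # (None, 128, 128, 20), (None, 128, 128, 2), (None, 128, 128, 2)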
/PKLot_server/nets/hourglass.py:
--------------------------------------------------------------------------------
1 | import tensorflow.keras.backend as K
2 | from tensorflow.keras.layers import (Activation, Add, BatchNormalization, Conv2D, Input, UpSampling2D,
3 | ZeroPadding2D)
4 | from tensorflow.keras.models import Model
5 | from tensorflow.keras.initializers import RandomNormal
6 |
7 |
8 | def conv2d(x, k, out_dim, name, stride=1):
9 | padding = (k - 1) // 2
10 | x = ZeroPadding2D(padding=padding, name=name + '.pad')(x)
11 | x = Conv2D(out_dim, k, strides=stride, kernel_initializer=RandomNormal(stddev=0.02), use_bias=False, name=name + '.conv')(x)
12 | x = BatchNormalization(epsilon=1e-5, name=name + '.bn')(x)
13 | x = Activation('relu', name=name + '.relu')(x)
14 | return x
15 |
16 | def residual(x, out_dim, name, stride=1):
17 |     #-----------------------------------#
18 |     #   Residual block
19 |     #   Two forms:
20 |     #   1. shortcut with a conv, changing the dimensions
21 |     #   2. identity shortcut, adding depth
22 |     #-----------------------------------#
23 | shortcut = x
24 | num_channels = K.int_shape(shortcut)[-1]
25 |
26 | x = ZeroPadding2D(padding=1, name=name + '.pad1')(x)
27 | x = Conv2D(out_dim, 3, strides=stride, kernel_initializer=RandomNormal(stddev=0.02), use_bias=False, name=name + '.conv1')(x)
28 | x = BatchNormalization(epsilon=1e-5, name=name + '.bn1')(x)
29 | x = Activation('relu', name=name + '.relu1')(x)
30 |
31 | x = Conv2D(out_dim, 3, padding='same', kernel_initializer=RandomNormal(stddev=0.02), use_bias=False, name=name + '.conv2')(x)
32 | x = BatchNormalization(epsilon=1e-5, name=name + '.bn2')(x)
33 |
34 | if num_channels != out_dim or stride != 1:
35 | shortcut = Conv2D(out_dim, 1, strides=stride, kernel_initializer=RandomNormal(stddev=0.02), use_bias=False, name=name + '.shortcut.0')(
36 | shortcut)
37 | shortcut = BatchNormalization(epsilon=1e-5, name=name + '.shortcut.1')(shortcut)
38 |
39 | x = Add(name=name + '.add')([x, shortcut])
40 | x = Activation('relu', name=name + '.relu')(x)
41 | return x
42 |
43 | def bottleneck_layer(x, num_channels, hgid):
44 |     #-----------------------------------#
45 |     #   The deep structure in the middle
46 |     #-----------------------------------#
47 | pow_str = 'center.' * 5
48 | x = residual(x, num_channels, name='kps.%d.%s0' % (hgid, pow_str))
49 | x = residual(x, num_channels, name='kps.%d.%s1' % (hgid, pow_str))
50 | x = residual(x, num_channels, name='kps.%d.%s2' % (hgid, pow_str))
51 | x = residual(x, num_channels, name='kps.%d.%s3' % (hgid, pow_str))
52 | return x
53 |
54 | def connect_left_right(left, right, num_channels, num_channels_next, name):
55 |     # Convolutions on the upper (skip) branch
56 |     left = residual(left, num_channels_next, name=name + 'skip.0')
57 |     left = residual(left, num_channels_next, name=name + 'skip.1')
58 |     # Upsampling on the right branch
59 | out = residual(right, num_channels, name=name + 'out.0')
60 | out = residual(out, num_channels_next, name=name + 'out.1')
61 | out = UpSampling2D(name=name + 'out.upsampleNN')(out)
62 |     # Join the two branches by addition
63 | out = Add(name=name + 'out.add')([left, out])
64 | return out
65 |
66 | def pre(x, num_channels):
67 |     #-----------------------------------#
68 |     #   Preprocessing before the image enters the pyramid:
69 |     #   usually one plain convolution
70 |     #   plus a residual block
71 |     #-----------------------------------#
72 | x = conv2d(x, 7, 128, name='pre.0', stride=2)
73 | x = residual(x, num_channels, name='pre.1', stride=2)
74 | return x
75 |
76 | def left_features(bottom, hgid, dims):
77 | #-------------------------------------------------#
78 |     #   Five downsampling steps
79 | # f1, f2, f4 , f8, f16, f32 : 1, 1/2, 1/4 1/8, 1/16, 1/32 resolution
80 | # 5 times reduce/increase: (256, 384, 384, 384, 512)
81 | #-------------------------------------------------#
82 | features = [bottom]
83 | for kk, nh in enumerate(dims):
84 | x = residual(features[-1], nh, name='kps.%d%s.down.0' % (hgid, str(kk)), stride=2)
85 | x = residual(x, nh, name='kps.%d%s.down.1' % (hgid, str(kk)))
86 | features.append(x)
87 | return features
88 |
89 | def right_features(leftfeatures, hgid, dims):
90 | #-------------------------------------------------#
91 |     #   Five upsampling steps, connected to the skip branches
92 | # f1, f2, f4 , f8, f16, f32 : 1, 1/2, 1/4 1/8, 1/16, 1/32 resolution
93 | # 5 times reduce/increase: (256, 384, 384, 384, 512)
94 | #-------------------------------------------------#
95 | rf = bottleneck_layer(leftfeatures[-1], dims[-1], hgid)
96 | for kk in reversed(range(len(dims))):
97 | pow_str = ''
98 | for _ in range(kk):
99 | pow_str += 'center.'
100 | rf = connect_left_right(leftfeatures[kk], rf, dims[kk], dims[max(kk - 1, 0)], name='kps.%d.%s' % (hgid, pow_str))
101 | return rf
102 |
103 |
104 | def create_heads(num_classes, rf1, hgid):
105 | y1 = Conv2D(256, 3, kernel_initializer=RandomNormal(stddev=0.02), use_bias=True, padding='same', name='hm.%d.0.conv' % hgid)(rf1)
106 | y1 = Activation('relu', name='hm.%d.0.relu' % hgid)(y1)
107 | y1 = Conv2D(num_classes, 1, use_bias=True, name='hm.%d.1' % hgid, activation = "sigmoid")(y1)
108 |
109 | y2 = Conv2D(256, 3, kernel_initializer=RandomNormal(stddev=0.02), use_bias=True, padding='same', name='wh.%d.0.conv' % hgid)(rf1)
110 | y2 = Activation('relu', name='wh.%d.0.relu' % hgid)(y2)
111 | y2 = Conv2D(2, 1, use_bias=True, name='wh.%d.1' % hgid)(y2)
112 |
113 | y3 = Conv2D(256, 3, kernel_initializer=RandomNormal(stddev=0.02), use_bias=True, padding='same', name='reg.%d.0.conv' % hgid)(rf1)
114 | y3 = Activation('relu', name='reg.%d.0.relu' % hgid)(y3)
115 | y3 = Conv2D(2, 1, use_bias=True, name='reg.%d.1' % hgid)(y3)
116 |
117 | return [y1,y2,y3]
118 |
119 | def hourglass_module(num_classes, bottom, cnv_dim, hgid, dims):
120 |     # Left half: downsampling
121 | lfs = left_features(bottom, hgid, dims)
122 |
123 |     # Right half: upsampling plus the middle connections
124 | rf1 = right_features(lfs, hgid, dims)
125 | rf1 = conv2d(rf1, 3, cnv_dim, name='cnvs.%d' % hgid)
126 |
127 | heads = create_heads(num_classes, rf1, hgid)
128 | return heads, rf1
129 |
130 |
131 | def HourglassNetwork(inputs, num_stacks, num_classes, cnv_dim=256, dims=[256, 384, 384, 384, 512]):
132 |     inter = pre(inputs, cnv_dim)
133 | outputs = []
134 | for i in range(num_stacks):
135 | prev_inter = inter
136 | _heads, inter = hourglass_module(num_classes, inter, cnv_dim, i, dims)
137 | outputs.append(_heads)
138 | if i < num_stacks - 1:
139 | inter_ = Conv2D(cnv_dim, 1, kernel_initializer=RandomNormal(stddev=0.02), use_bias=False, name='inter_.%d.0' % i)(prev_inter)
140 | inter_ = BatchNormalization(epsilon=1e-5, name='inter_.%d.1' % i)(inter_)
141 |
142 | cnv_ = Conv2D(cnv_dim, 1, kernel_initializer=RandomNormal(stddev=0.02), use_bias=False, name='cnv_.%d.0' % i)(inter)
143 | cnv_ = BatchNormalization(epsilon=1e-5, name='cnv_.%d.1' % i)(cnv_)
144 |
145 | inter = Add(name='inters.%d.inters.add' % i)([inter_, cnv_])
146 | inter = Activation('relu', name='inters.%d.inters.relu' % i)(inter)
147 | inter = residual(inter, cnv_dim, 'inters.%d' % i)
148 | return outputs
149 |
150 | if __name__ == "__main__":
151 | image_input = Input(shape=(512, 512, 3))
152 | outputs = HourglassNetwork(image_input,2,20)
153 | model = Model(image_input,outputs[-1])
154 | model.summary()
155 |
156 |
--------------------------------------------------------------------------------
/PKLot_server/nets/centernet.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras.layers import Input, Lambda, MaxPooling2D
3 | from tensorflow.keras.models import Model
4 |
5 | from nets.centernet_training import loss
6 | from nets.hourglass import HourglassNetwork
7 | from nets.resnet import ResNet50, centernet_head
8 |
9 |
10 | def nms(heat, kernel=3):
11 | hmax = MaxPooling2D((kernel, kernel), strides=1, padding='SAME')(heat)
12 | heat = tf.where(tf.equal(hmax, heat), heat, tf.zeros_like(heat))
13 | return heat
14 |
15 | def topk(hm, max_objects=100):
16 |     #-------------------------------------------------------------------------#
17 |     #   When predicting on the COCO dataset with a 512x512x3 image:
18 |     #   h = w = 128, num_classes = 80
19 |     #   heatmap -> b, 128, 128, 80
20 |     #   Apply NMS to the heatmap: a 3x3 max pooling keeps only
21 |     #   the highest-scoring feature point within each local region.
22 |     #-------------------------------------------------------------------------#
23 | hm = nms(hm)
24 | b, h, w, c = tf.shape(hm)[0], tf.shape(hm)[1], tf.shape(hm)[2], tf.shape(hm)[3]
25 |     #-------------------------------------------#
26 |     #   Flatten the results to (b, 128 * 128 * 80)
27 |     #-------------------------------------------#
28 | hm = tf.reshape(hm, (b, -1))
29 | #-----------------------------#
30 | # (b, k), (b, k)
31 | #-----------------------------#
32 | scores, indices = tf.math.top_k(hm, k=max_objects, sorted=True)
33 |
34 |     #--------------------------------------#
35 |     #   Recover the class, grid position and index.
36 |     #--------------------------------------#
37 | class_ids = indices % c
38 | xs = indices // c % w
39 | ys = indices // c // w
40 | indices = ys * w + xs
41 | return scores, indices, class_ids, xs, ys
42 |
43 | def decode(hm, wh, reg, max_objects=100,num_classes=20):
44 | #-----------------------------------------------------#
45 | # hm b, 128, 128, num_classes
46 | # wh b, 128, 128, 2
47 | # reg b, 128, 128, 2
48 | # scores b, max_objects
49 | # indices b, max_objects
50 | # class_ids b, max_objects
51 | # xs b, max_objects
52 | # ys b, max_objects
53 | #-----------------------------------------------------#
54 | scores, indices, class_ids, xs, ys = topk(hm, max_objects=max_objects)
55 | b = tf.shape(hm)[0]
56 |
57 | #-----------------------------------------------------#
58 | # wh b, 128 * 128, 2
59 | # reg b, 128 * 128, 2
60 | #-----------------------------------------------------#
61 | reg = tf.reshape(reg, [b, -1, 2])
62 | wh = tf.reshape(wh, [b, -1, 2])
63 | length = tf.shape(wh)[1]
64 |
65 |     #-----------------------------------------------------#
66 |     #   Compute the flat (1-D) indices
67 |     #   batch_idx           b, max_objects
68 |     #-----------------------------------------------------#
69 | batch_idx = tf.expand_dims(tf.range(0, b), 1)
70 | batch_idx = tf.tile(batch_idx, (1, max_objects))
71 | full_indices = tf.reshape(batch_idx, [-1]) * tf.cast(length, tf.int32) + tf.reshape(indices, [-1])
72 |
73 |     #-----------------------------------------------------#
74 |     #   Gather the parameters of the top-k boxes
75 |     #-----------------------------------------------------#
76 | topk_reg = tf.gather(tf.reshape(reg, [-1,2]), full_indices)
77 | topk_reg = tf.reshape(topk_reg, [b, -1, 2])
78 |
79 | topk_wh = tf.gather(tf.reshape(wh, [-1,2]), full_indices)
80 | topk_wh = tf.reshape(topk_wh, [b, -1, 2])
81 |
82 |     #-----------------------------------------------------#
83 |     #   Apply the offsets to get the adjusted box centers
84 |     #   topk_cx             b,k,1
85 |     #   topk_cy             b,k,1
86 |     #-----------------------------------------------------#
87 | topk_cx = tf.cast(tf.expand_dims(xs, axis=-1), tf.float32) + topk_reg[..., 0:1]
88 | topk_cy = tf.cast(tf.expand_dims(ys, axis=-1), tf.float32) + topk_reg[..., 1:2]
89 |
90 |     #-----------------------------------------------------#
91 |     #   Compute the top-left and bottom-right box corners
92 |     #   topk_x1             b,k,1   top-left x
93 |     #   topk_y1             b,k,1   top-left y
94 |     #   topk_x2             b,k,1   bottom-right x
95 |     #   topk_y2             b,k,1   bottom-right y
96 |     #-----------------------------------------------------#
97 | topk_x1, topk_y1 = topk_cx - topk_wh[..., 0:1] / 2, topk_cy - topk_wh[..., 1:2] / 2
98 | topk_x2, topk_y2 = topk_cx + topk_wh[..., 0:1] / 2, topk_cy + topk_wh[..., 1:2] / 2
99 |
100 |     #-----------------------------------------------------#
101 |     #   scores              b,k,1   box scores
102 |     #   class_ids           b,k,1   box classes
103 |     #-----------------------------------------------------#
104 | scores = tf.expand_dims(scores, axis=-1)
105 | class_ids = tf.cast(tf.expand_dims(class_ids, axis=-1), tf.float32)
106 |
107 |     #-----------------------------------------------------#
108 |     #   detections stacks all box parameters:
109 |     #   the first four are the coordinates, the last two the score and class
110 |     #-----------------------------------------------------#
111 | detections = tf.concat([topk_x1, topk_y1, topk_x2, topk_y2, scores, class_ids], axis=-1)
112 |
113 | return detections
114 |
115 |
116 | def centernet(input_shape, num_classes, backbone='resnet50', max_objects=100, mode="train", num_stacks=2):
117 | assert backbone in ['resnet50', 'hourglass']
118 | output_size = input_shape[0] // 4
119 | image_input = Input(shape=input_shape)
120 | hm_input = Input(shape=(output_size, output_size, num_classes))
121 | wh_input = Input(shape=(max_objects, 2))
122 | reg_input = Input(shape=(max_objects, 2))
123 | reg_mask_input = Input(shape=(max_objects,))
124 | index_input = Input(shape=(max_objects,))
125 |
126 | if backbone=='resnet50':
127 |         #-----------------------------------#
128 |         #   Extract features from the input image
129 |         #   512, 512, 3 -> 16, 16, 2048
130 |         #-----------------------------------#
131 | C5 = ResNet50(image_input)
132 |         #--------------------------------------------------------------------------------------------------------#
133 |         #   Upsample the extracted features, then run the classification and regression predictions
134 |         #   16, 16, 2048 -> 32, 32, 256 -> 64, 64, 128 -> 128, 128, 64 -> 128, 128, 64 -> 128, 128, num_classes
135 |         #                                                            -> 128, 128, 64 -> 128, 128, 2
136 |         #                                                            -> 128, 128, 64 -> 128, 128, 2
137 |         #--------------------------------------------------------------------------------------------------------#
138 | y1, y2, y3 = centernet_head(C5,num_classes)
139 |
140 | if mode=="train":
141 | loss_ = Lambda(loss, name='centernet_loss')([y1, y2, y3, hm_input, wh_input, reg_input, reg_mask_input, index_input])
142 | model = Model(inputs=[image_input, hm_input, wh_input, reg_input, reg_mask_input, index_input], outputs=[loss_])
143 | return model
144 | else:
145 | detections = Lambda(lambda x: decode(*x, max_objects=max_objects,
146 | num_classes=num_classes))([y1, y2, y3])
147 | prediction_model = Model(inputs=image_input, outputs=detections)
148 | return prediction_model
149 |
150 | else:
151 | outs = HourglassNetwork(image_input,num_stacks,num_classes)
152 |
153 | if mode=="train":
154 | loss_all = []
155 | for out in outs:
156 | y1, y2, y3 = out
157 | loss_ = Lambda(loss)([y1, y2, y3, hm_input, wh_input, reg_input, reg_mask_input, index_input])
158 | loss_all.append(loss_)
159 | loss_all = Lambda(tf.reduce_mean,name='centernet_loss')(loss_all)
160 |
161 | model = Model(inputs=[image_input, hm_input, wh_input, reg_input, reg_mask_input, index_input], outputs=loss_all)
162 | return model
163 | else:
164 | y1, y2, y3 = outs[-1]
165 | detections = Lambda(lambda x: decode(*x, max_objects=max_objects,
166 | num_classes=num_classes))([y1, y2, y3])
167 | prediction_model = Model(inputs=image_input, outputs=[detections])
168 | return prediction_model
169 |
--------------------------------------------------------------------------------
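The index arithmetic in topk can be traced by hand: an index into the flattened (b, h*w*c) heatmap encodes the class in its lowest digits (mod c), then the x grid position, then y. A tiny pure-Python check of the same formulas, using the h = w = 128, c = 80 values from the comments:

    h, w, c = 128, 128, 80
    # flat index of the cell (y=2, x=5), class 7, in an (h, w, c) heatmap
    y, x, cls = 2, 5, 7
    index = (y * w + x) * c + cls

    class_id = index % c      # 7
    xs = index // c % w       # 5
    ys = index // c // w      # 2
    flat_pos = ys * w + xs    # 261: index into the reshaped (h*w, 2) wh/reg tensors
    print(class_id, xs, ys, flat_pos)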
/PKLot_server/utils/utils.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import numpy as np
4 | import tensorflow as tf
5 | from PIL import Image
6 |
7 | def letterbox_image(image, size):
8 | iw, ih = image.size
9 | w, h = size
10 | scale = min(w/iw, h/ih)
11 | nw = int(iw*scale)
12 | nh = int(ih*scale)
13 |
14 | image = image.resize((nw,nh), Image.BICUBIC)
15 | new_image = Image.new('RGB', size, (128,128,128))
16 | new_image.paste(image, ((w-nw)//2, (h-nh)//2))
17 | return new_image
18 |
19 | def centernet_correct_boxes(top, left, bottom, right, input_shape, image_shape):
20 | new_shape = image_shape*np.min(input_shape/image_shape)
21 |
22 | offset = (input_shape-new_shape)/2./input_shape
23 | scale = input_shape/new_shape
24 |
25 | box_yx = np.concatenate(((top+bottom)/2,(left+right)/2),axis=-1)
26 | box_hw = np.concatenate((bottom-top,right-left),axis=-1)
27 |
28 | box_yx = (box_yx - offset) * scale
29 | box_hw *= scale
30 |
31 | box_mins = box_yx - (box_hw / 2.)
32 | box_maxes = box_yx + (box_hw / 2.)
33 | boxes = np.concatenate([
34 | box_mins[:, 0:1],
35 | box_mins[:, 1:2],
36 | box_maxes[:, 0:1],
37 | box_maxes[:, 1:2]
38 | ],axis=-1)
39 | boxes *= np.concatenate([image_shape, image_shape],axis=-1)
40 | return boxes
41 |
42 | def nms(results, nms):
43 | outputs = []
44 |     # process each image in the batch
45 | for i in range(len(results)):
46 | #------------------------------------------------#
47 |         # for a walkthrough of the procedure, see
48 | # https://www.bilibili.com/video/BV1Lz411B7nQ
49 | #------------------------------------------------#
50 | detections = results[i]
51 | unique_class = np.unique(detections[:,-1])
52 |
53 | best_box = []
54 | if len(unique_class) == 0:
55 | results.append(best_box)
56 | continue
57 |         # Loop over the classes.
58 |         # NMS keeps, within a local region, only the highest-scoring box of a given class,
59 |         # so looping over the classes runs the suppression independently per class.
60 | for c in unique_class:
61 | cls_mask = detections[:,-1] == c
62 |
63 | detection = detections[cls_mask]
64 | scores = detection[:,4]
65 |             # sort this class's detections by score, from high to low.
66 | arg_sort = np.argsort(scores)[::-1]
67 | detection = detection[arg_sort]
68 | while np.shape(detection)[0]>0:
69 |                 # take the highest-scoring box, measure its overlap with the remaining boxes, and drop those that overlap too much.
70 | best_box.append(detection[0])
71 | if len(detection) == 1:
72 | break
73 | ious = iou(best_box[-1],detection[1:])
74 |                 detection = detection[1:][ious < nms]
75 |         outputs.append(best_box)
76 |     return outputs
77 | 
78 | def iou(b1, b2):
79 |     # IoU between a single box b1 and an array of boxes b2, both as x1, y1, x2, y2
80 |     b1_x1, b1_y1, b1_x2, b1_y2 = b1[0], b1[1], b1[2], b1[3]
81 |     b2_x1, b2_y1, b2_x2, b2_y2 = b2[:, 0], b2[:, 1], b2[:, 2], b2[:, 3]
82 | 
83 |     inter_rect_x1 = np.maximum(b1_x1, b2_x1)
84 |     inter_rect_y1 = np.maximum(b1_y1, b2_y1)
85 |     inter_rect_x2 = np.minimum(b1_x2, b2_x2)
86 |     inter_rect_y2 = np.minimum(b1_y2, b2_y2)
87 | 
88 |     inter_area = np.maximum(inter_rect_x2 - inter_rect_x1, 0) * \
89 |                  np.maximum(inter_rect_y2 - inter_rect_y1, 0)
90 | 
91 |     area_b1 = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
92 |     area_b2 = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
93 | 
94 |     return inter_area / np.maximum(area_b1 + area_b2 - inter_area, 1e-6)
95 | 
96 | def draw_gaussian(heatmap, center, radius, k=1):
97 |     # splat a 2D Gaussian of the given radius onto the heatmap, centred at `center` (x, y)
98 |     diameter = 2 * radius + 1
99 |     gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)
100 | 
101 |     x, y = int(center[0]), int(center[1])
102 |     height, width = heatmap.shape[0:2]
103 | 
104 |     # clip the Gaussian patch so it stays inside the heatmap
105 |     left, right = min(x, radius), min(width - x, radius + 1)
106 |     top, bottom = min(y, radius), min(height - y, radius + 1)
107 | 
108 |     masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
109 |     masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]
110 |     if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:  # TODO debug
110 | np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
111 | return heatmap
112 |
113 | def gaussian2D(shape, sigma=1):
114 | m, n = [(ss - 1.) / 2. for ss in shape]
115 | y, x = np.ogrid[-m:m + 1, -n:n + 1]
116 |
117 | h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
118 | h[h < np.finfo(h.dtype).eps * h.max()] = 0
119 | return h
120 |
121 | def gaussian_radius(det_size, min_overlap=0.7):
122 | height, width = det_size
123 |
124 | a1 = 1
125 | b1 = (height + width)
126 | c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
127 | sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
128 | r1 = (b1 + sq1) / 2
129 |
130 | a2 = 4
131 | b2 = 2 * (height + width)
132 | c2 = (1 - min_overlap) * width * height
133 | sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
134 | r2 = (b2 + sq2) / 2
135 |
136 | a3 = 4 * min_overlap
137 | b3 = -2 * min_overlap * (height + width)
138 | c3 = (min_overlap - 1) * width * height
139 | sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
140 | r3 = (b3 + sq3) / 2
141 | return min(r1, r2, r3)
142 |
143 |
144 | class ModelCheckpoint(tf.keras.callbacks.Callback):
145 | def __init__(self, filepath, monitor='val_loss', verbose=0,
146 | save_best_only=False, save_weights_only=False,
147 | mode='auto', period=1):
148 | super(ModelCheckpoint, self).__init__()
149 | self.monitor = monitor
150 | self.verbose = verbose
151 | self.filepath = filepath
152 | self.save_best_only = save_best_only
153 | self.save_weights_only = save_weights_only
154 | self.period = period
155 | self.epochs_since_last_save = 0
156 |
157 | if mode not in ['auto', 'min', 'max']:
158 | warnings.warn('ModelCheckpoint mode %s is unknown, '
159 | 'fallback to auto mode.' % (mode),
160 | RuntimeWarning)
161 | mode = 'auto'
162 |
163 | if mode == 'min':
164 | self.monitor_op = np.less
165 | self.best = np.Inf
166 | elif mode == 'max':
167 | self.monitor_op = np.greater
168 | self.best = -np.Inf
169 | else:
170 | if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
171 | self.monitor_op = np.greater
172 | self.best = -np.Inf
173 | else:
174 | self.monitor_op = np.less
175 | self.best = np.Inf
176 |
177 | def on_epoch_end(self, epoch, logs=None):
178 | logs = logs or {}
179 | self.epochs_since_last_save += 1
180 | if self.epochs_since_last_save >= self.period:
181 | self.epochs_since_last_save = 0
182 | filepath = self.filepath.format(epoch=epoch + 1, **logs)
183 | if self.save_best_only:
184 | current = logs.get(self.monitor)
185 | if current is None:
186 | warnings.warn('Can save best model only with %s available, '
187 | 'skipping.' % (self.monitor), RuntimeWarning)
188 | else:
189 | if self.monitor_op(current, self.best):
190 | if self.verbose > 0:
191 | print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
192 | ' saving model to %s'
193 | % (epoch + 1, self.monitor, self.best,
194 | current, filepath))
195 | self.best = current
196 | if self.save_weights_only:
197 | self.model.save_weights(filepath, overwrite=True)
198 | else:
199 | self.model.save(filepath, overwrite=True)
200 | else:
201 | if self.verbose > 0:
202 | print('\nEpoch %05d: %s did not improve' %
203 | (epoch + 1, self.monitor))
204 | else:
205 | if self.verbose > 0:
206 | print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
207 | if self.save_weights_only:
208 | self.model.save_weights(filepath, overwrite=True)
209 | else:
210 | self.model.save(filepath, overwrite=True)
211 |
--------------------------------------------------------------------------------
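
The two helpers near the end of this file implement CenterNet's label-encoding math: gaussian_radius solves three quadratics and returns the largest Gaussian radius for which a box shifted by that much still clears the min_overlap IoU, and draw_gaussian splats a Gaussian of that radius onto the class heatmap. A minimal sanity-check sketch (the object size and heatmap resolution are illustrative values, not taken from the repository):

    import numpy as np
    from utils.utils import gaussian_radius, draw_gaussian

    # Splat one hypothetical 40x24 (h x w) object onto an empty 128x128 heatmap.
    heatmap = np.zeros((128, 128), dtype=np.float32)
    radius = max(0, int(gaussian_radius((40, 24), min_overlap=0.7)))
    heatmap = draw_gaussian(heatmap, center=(64, 64), radius=radius)
    assert heatmap[64, 64] == 1.0  # the Gaussian peaks at exactly 1 on the object centre
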
/PKLot_server/predictUI.py:
--------------------------------------------------------------------------------
1 | from MainWindow import Ui_MainWindow
2 | from PyQt5.QtWidgets import QWidget, QMainWindow, QApplication, QGraphicsScene, QGraphicsPixmapItem, QFileDialog, QMessageBox
3 | from PyQt5.QtGui import QImage, QPixmap
4 | import sys
5 | import datetime
6 | import cv2
7 | import numpy as np
8 | from PIL import ImageDraw, ImageFont
9 | from centernet import CenterNet
10 | from DataServer import FormetDayTime
11 | from threading import Thread, Lock
12 | import socket
13 | from PIL import Image
14 | import matplotlib.pyplot as plt
15 |
16 | plt.rcParams["font.sans-serif"] = ["SimHei"]
17 | plt.rcParams["axes.unicode_minus"] = False
18 |
19 | centernet = CenterNet()
20 | lock = Lock()
21 |
22 | class predictWindow(QMainWindow, Ui_MainWindow):
23 | def __init__(self):
24 | super().__init__()
25 | self.setupUi(self)
26 |         img_src = cv2.imread("model_data/timg.jpg")  # read the placeholder image
27 |         img_src = cv2.cvtColor(img_src, cv2.COLOR_BGR2RGB)  # convert BGR to RGB
28 | label_width = self.label.width()
29 | label_height = self.label.height()
30 | temp_imgSrc = QImage(img_src[:], img_src.shape[1], img_src.shape[0], img_src.shape[1] * 3, QImage.Format_RGB888)
31 |         # convert the image to a QPixmap for display
32 | self.pixmap_imgSrc = QPixmap.fromImage(temp_imgSrc).scaled(label_width, label_height)
33 | now = datetime.datetime.now()
34 | time = now.strftime("%Y %m %d %H %M %S")
35 | timelist = time.split(" ")
36 | self.mon = int(timelist[1])
37 | self.isOn = False
38 | self.isVideoOn = False
39 | self.isImgOn = False
40 | self.day, self.time = FormetDayTime(timelist)
41 | self.label.setPixmap(QPixmap(self.pixmap_imgSrc))
42 | self.pushButton.clicked.connect(self.CountTime)
43 | self.pushButton_1.clicked.connect(self.imgOnOff)
44 | self.pushButton_2.clicked.connect(self.videoOnOff)
45 | self.pushButton_3.clicked.connect(self.On_Off)
46 |
47 | def CountTime(self):
48 | try:
49 | tco = np.zeros(24)
50 | tce = np.zeros(24)
51 | lock.acquire()
52 | with open("./detect_Logs/countlogs") as f:
53 | for fn in f.readlines():
54 | print(fn)
55 | tco[int(fn.split(';')[0])] += int(fn.split(";")[1])
56 | tce[int(fn.split(';')[0])] += int(fn.split(";")[2].strip())
57 |
58 | lock.release()
59 | label = []
60 | for i in range(24):
61 | label.append(str(i))
62 | rate = tco / (tco + tce)
63 | plt.plot(range(24), rate, label="时段内车位占用比")
64 | plt.xticks(range(24), label)
65 | plt.xlim((0, 23))
66 | plt.xlabel("小时")
67 | plt.ylabel("车位占有率")
68 | plt.legend()
69 | plt.title("时段车位统计")
70 | plt.show()
71 |         except Exception:
72 | msg_box = QMessageBox(QMessageBox.Warning, 'Warning', '数据加载失败!')
73 | msg_box.exec_()
74 |
75 | def imgOnOff(self):
76 |         self.isImgOn = not self.isImgOn
77 | if not self.isImgOn:
78 | self.pushButton_1.setText("图片检测")
79 | else:
80 | self.pushButton_1.setText("完成图片检测")
81 |
82 | if self.isImgOn:
83 | try:
84 | image_file, _ = QFileDialog.getOpenFileName(self, 'Open file', '\\', 'Image files (*.jpg *.gif *.png *.jpeg)')
85 |             except Exception:
86 | msg_box = QMessageBox(QMessageBox.Warning, 'Warning', '未选中图片!')
87 | msg_box.exec_()
88 | print('Open Error! Try again!')
89 | return
90 |
91 | try:
92 | image = Image.open(image_file)
93 | r_image, obj1sum, obj2sum, color1, color2, max_cnt = centernet.detect_image(image)
94 |
95 | draw = ImageDraw.Draw(r_image)
96 | fontStyle = ImageFont.truetype(
97 | font="model_data/simhei.ttf", size=20, encoding='utf-8')
98 |
99 |                 # # draw a background box behind the legend (kept for reference)
100 | # draw.rectangle(
101 | # [tuple((0, 560)), tuple((180, 640))],
102 | # fill=(255, 255, 255), outline='black')
103 |
104 | draw.ellipse((10, 560, 30, 580), fill=color1)
105 | draw.ellipse((10, 600, 30, 620), fill=color2)
106 | draw.text((50, 560), "被占车位:" + str(obj1sum), color1, font=fontStyle)
107 | draw.text((50, 600), "空车位 :" + str(obj2sum), color2, font=fontStyle)
108 |
109 | r_image = np.array(r_image)
110 |                 # keep RGB ordering here: QImage.Format_RGB888 below expects RGB, so no BGR conversion is needed
111 |                 img_src = r_image  # cv2.cvtColor(r_image, cv2.COLOR_RGB2BGR)
112 | label_width = self.label.width()
113 | label_height = self.label.height()
114 | temp_imgSrc = QImage(img_src[:], img_src.shape[1], img_src.shape[0], img_src.shape[1] * 3,
115 | QImage.Format_RGB888)
116 |
117 |                 # convert the image to a QPixmap for display
118 | pixmap_imgSrc = QPixmap.fromImage(temp_imgSrc).scaled(label_width, label_height)
119 | self.label.setPixmap(QPixmap(pixmap_imgSrc))
120 |
121 |             except Exception:
122 | msg_box = QMessageBox(QMessageBox.Warning, 'Warning', '识别失败!')
123 | msg_box.exec_()
124 | else:
125 | self.label.setPixmap(QPixmap(self.pixmap_imgSrc))
126 |
127 |
128 | def videoOnOff(self):
129 |         self.isVideoOn = not self.isVideoOn
130 | if not self.isVideoOn:
131 | self.pushButton_2.setText("视频检测")
132 | else:
133 | self.pushButton_2.setText("停止视频检测")
134 | if self.isVideoOn:
135 | try:
136 | image_file, _ = QFileDialog.getOpenFileName(self, 'Open file', '\\',
137 | 'Video files (*.gif *.mp4 *.avi *.dat *.mkv *.flv *.vob *.3gp)')
138 | capture = cv2.VideoCapture(image_file)
139 |             except Exception:
140 | msg_box = QMessageBox(QMessageBox.Warning, 'Warning', '未选中视频!')
141 | msg_box.exec_()
142 | print('Open Error! Try again!')
143 | return
144 |
145 | try:
146 | global centernet
147 | while self.isVideoOn:
148 | ref, frame = capture.read()
149 |
150 | now = datetime.datetime.now()
151 | time = now.strftime("%Y %m %d %H %M %S")
152 | timelist = time.split(" ")
153 | self.day, self.time = FormetDayTime(timelist)
154 |
155 |                     # convert the frame to a PIL Image
156 |                     frame = Image.fromarray(np.uint8(frame))
157 |                     # run detection
158 |                     frame, obj1sum, obj2sum, color1, color2, max_cnt = centernet.detect_image(frame)
159 |
160 | lock.acquire()
161 | with open("./detect_Logs/logs.txt", 'w') as f:
162 | f.write(self.day + " " + self.time + ';')
163 | f.write(str(obj1sum)+';')
164 | f.write(str(obj2sum)+";")
165 | f.write(str(max_cnt))
166 |
167 | with open("./detect_Logs/countlogs", 'a') as f:
168 | f.write(self.time.split(":")[0] + ';')
169 | f.write(str(obj1sum)+';')
170 | f.write(str(obj2sum))
171 | f.write("\n")
172 | lock.release()
173 |                     # set up drawing on the frame
174 | draw = ImageDraw.Draw(frame)
175 | fontStyle = ImageFont.truetype(
176 | font="model_data/simhei.ttf", size=20, encoding='utf-8')
177 |
178 |                     # draw a background box behind the legend (kept for reference)
179 | # draw.rectangle(
180 | # [tuple((0, 560)), tuple((180, 640))],
181 | # fill=(255, 255, 255), outline='black')
182 |
183 | draw.ellipse((10, 560, 30, 580), fill=color1)
184 | draw.ellipse((10, 600, 30, 620), fill=color2)
185 | draw.text((50, 560), "被占车位:" + str(obj1sum), color1, font=fontStyle)
186 | draw.text((50, 600), "空车位 :" + str(obj2sum), color2, font=fontStyle)
187 | draw.text((0, 0), "时间:" + self.day + " " + self.time, (255, 255, 255), font=fontStyle)
188 |
189 | frame.save("detect_Logs/1.png")
190 | frame = np.array(frame)
191 |
192 | img_src = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
193 |
194 | label_width = self.label.width()
195 | label_height = self.label.height()
196 | temp_imgSrc = QImage(img_src[:], img_src.shape[1], img_src.shape[0], img_src.shape[1] * 3,
197 | QImage.Format_RGB888)
198 |
199 |                     # convert the image to a QPixmap for display
200 | pixmap_imgSrc = QPixmap.fromImage(temp_imgSrc).scaled(label_width, label_height)
201 | self.label.setPixmap(QPixmap(pixmap_imgSrc))
202 |
203 | c = cv2.waitKey(1) & 0xff
204 |
205 | if c == 27:
206 | capture.release()
207 | break
208 |
209 | capture.release()
210 | cv2.destroyAllWindows()
211 |
212 |
213 |             except Exception:
214 | self.label.setPixmap(QPixmap(self.pixmap_imgSrc))
215 | else:
216 | self.label.setPixmap(QPixmap(self.pixmap_imgSrc))
217 |
218 | def On_Off(self):
219 |         self.isOn = not self.isOn
220 | if not self.isOn:
221 | self.pushButton_3.setText("实时监测")
222 | else:
223 | self.pushButton_3.setText("结束实时监测")
224 | if self.isOn:
225 | capture = cv2.VideoCapture(0)
226 | global centernet
227 | while self.isOn:
228 |                 # get the current time
229 | now = datetime.datetime.now()
230 | time = now.strftime("%Y %m %d %H %M %S")
231 | timelist = time.split(" ")
232 | self.day, self.time = FormetDayTime(timelist)
233 |
234 |                 # read one frame
235 |                 ref, frame = capture.read()
236 |                 # convert the frame to a PIL Image
237 |                 frame = Image.fromarray(np.uint8(frame))
238 |                 # run detection
239 |                 frame, obj1sum, obj2sum, color1, color2, max_cnt = centernet.detect_image(frame)
240 |
241 |                 # set up drawing on the frame
242 | draw = ImageDraw.Draw(frame)
243 | fontStyle = ImageFont.truetype(
244 | font="model_data/simhei.ttf", size=20, encoding='utf-8')
245 |
246 |                 # draw background boxes behind the legend (kept for reference)
247 | # draw.rectangle(
248 | # [tuple((0, 0)), tuple((180, 40))],
249 | # fill=(255, 255, 255), outline='black')
250 | # draw.rectangle(
251 | # [tuple((0, 0)), tuple((420, 40))],
252 | # fill=(255, 255, 255), outline='black')
253 |
254 | draw.ellipse((35, 410, 55, 430), fill=color1)
255 | draw.ellipse((35, 440, 55, 460), fill=color2)
256 | draw.text((60, 410), "被占车位:" + str(obj1sum), fill=color1, font=fontStyle)
257 | draw.text((60, 440), "空车位 :" + str(obj2sum), fill=color2, font=fontStyle)
258 | draw.text((0, 0), "时间:" + self.day + " " + self.time, (0, 0, 0), font=fontStyle)
259 | lock.acquire()
260 | with open("./detect_Logs/logs.txt", 'w') as f:
261 | f.write(self.day + " " + self.time + ';')
262 | f.write(str(obj1sum) + ';')
263 | f.write(str(obj2sum)+";")
264 | f.write(str(max_cnt))
265 |
266 | with open("./detect_Logs/countlogs", 'a') as f:
267 | f.write(self.time.split(":")[0] + ';')
268 | f.write(str(obj1sum) + ';')
269 | f.write(str(obj2sum))
270 | f.write("\n")
271 |
272 | lock.release()
273 | frame.save("detect_Logs/1.png")
274 | frame = np.array(frame)
275 |
276 |                 # swap channels so the buffer matches QImage's RGB888 layout
277 | img_src = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
278 | label_width = self.label.width()
279 | label_height = self.label.height()
280 | temp_imgSrc = QImage(img_src[:], img_src.shape[1], img_src.shape[0], img_src.shape[1] * 3,
281 | QImage.Format_RGB888)
282 |
283 |                 # convert the image to a QPixmap for display
284 | pixmap_imgSrc = QPixmap.fromImage(temp_imgSrc).scaled(label_width, label_height)
285 | self.label.setPixmap(QPixmap(pixmap_imgSrc))
286 |
287 |                 # if int(timelist[4]) % 5 == 0 and int(timelist[5]) == 0:
288 |                 #     # save once every five minutes
289 |                 #     SaveData(self.day, self.time, num)
290 | c = cv2.waitKey(1) & 0xff
291 |
292 | if c == 27:
293 | capture.release()
294 | break
295 | capture.release()
296 | cv2.destroyAllWindows()
297 | else:
298 | self.label.setPixmap(QPixmap(self.pixmap_imgSrc))
299 |
300 |
301 | def detect_process():
302 | app = QApplication(sys.argv)
303 | window = predictWindow()
304 | window.show()
305 | sys.exit(app.exec_())
306 |
307 | def response_process():
308 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
309 | sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
310 | sock.bind(('127.0.0.1', 8000))
311 | sock.listen(5)
312 |     while True:
313 | cli_sock, cli_addr = sock.accept()
314 | req = cli_sock.recv(4096)
315 | lock.acquire()
316 | with open('detect_Logs/logs.txt') as f:
317 | word = f.readline()
318 | cli_sock.send(word.encode())
319 | cli_sock.close()
320 | lock.release()
321 |
322 | if __name__ == "__main__":
323 | t1 = Thread(target=detect_process)
324 | t2 = Thread(target=response_process)
325 | t1.start()
326 | t2.start()
327 | t1.join()
328 | t2.join()
329 |
--------------------------------------------------------------------------------
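
The response_process thread above is a tiny status service: it listens on 127.0.0.1:8000, ignores the request payload, and answers every connection with the first line of detect_Logs/logs.txt, which the detection loops write as "day time;occupied;empty;busiest_quadrant". A minimal client sketch (assuming the server is already running on the same machine):

    import socket

    # Query the status service once and unpack its semicolon-separated reply.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.connect(('127.0.0.1', 8000))
    sock.send(b'status')  # any payload works; the server never parses it
    reply = sock.recv(4096).decode()
    sock.close()

    day_time, occupied, empty, busiest = reply.split(';')
    print(day_time + ": " + occupied + " occupied, " + empty + " empty, busiest quadrant: " + busiest)
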
/PKLot_server/centernet.py:
--------------------------------------------------------------------------------
1 | import colorsys
2 | import os
3 | import time
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 | from PIL import ImageDraw, ImageFont
8 |
9 | from nets.centernet import centernet
10 | from utils.utils import centernet_correct_boxes, letterbox_image, nms
11 |
12 |
13 | def preprocess_image(image):
14 | mean = [0.40789655, 0.44719303, 0.47026116]
15 | std = [0.2886383 , 0.27408165, 0.27809834]
16 | return ((np.float32(image) / 255.) - mean) / std
17 |
18 | #--------------------------------------------#
19 | # To run prediction with your own trained model you need to change 3 parameters:
20 | # model_path, classes_path and backbone.
21 | # All of them must be updated!
22 | # If a shape mismatch occurs, double-check that model_path and
23 | # classes_path match the values used during training.
24 | #--------------------------------------------#
25 |
26 | opt = 0
27 | if opt == 0:
28 | m_p = 'model_data/centernet_hourglass_coco.h5'
29 | c_p = 'model_data/coco_classes.txt'
30 | backb = 'hourglass'
31 | ns = False
32 |
33 | else:
34 | m_p = 'model_data/centernet_resnet50_voc.h5'
35 | c_p = 'model_data/voc_classes.txt'
36 | backb = 'resnet50'
37 | ns = True
38 |
39 | class CenterNet(object):
40 | global m_p
41 | global c_p
42 | global backb
43 | global ns
44 | _defaults = {
45 | # "model_path" : 'model_data/centernet_resnet50_voc.h5',
46 | # "classes_path" : 'model_data/voc_classes.txt',
47 | # "backbone" : 'resnet50',
48 | "model_path" : m_p, #'model_data/centernet_hourglass_coco.h5',
49 | "classes_path" : c_p, #'model_data/coco_classes.txt',
50 | "backbone" : backb, # 'hourglass',
51 | "input_shape" : [512,512,3],
52 | "confidence" : 0.3,
53 |         # True is recommended when the backbone is resnet50,
54 |         # False when the backbone is hourglass;
55 |         # you can also choose based on the detection results.
56 | # "nms" : True,
57 | "nms": ns,
58 | "nms_threhold" : 0.3,
59 | }
60 |
61 | @classmethod
62 | def get_defaults(cls, n):
63 | if n in cls._defaults:
64 | return cls._defaults[n]
65 | else:
66 | return "Unrecognized attribute name '" + n + "'"
67 |
68 | #---------------------------------------------------#
69 |     #   Initialize CenterNet
70 | #---------------------------------------------------#
71 | def __init__(self, **kwargs):
72 | self.__dict__.update(self._defaults)
73 | self.class_names = self._get_class()
74 | self.generate()
75 |
76 | #---------------------------------------------------#
77 |     #   Load all the class names
78 | #---------------------------------------------------#
79 | def _get_class(self):
80 | classes_path = os.path.expanduser(self.classes_path)
81 | with open(classes_path) as f:
82 | class_names = f.readlines()
83 | class_names = [c.strip() for c in class_names]
84 | return class_names
85 |
86 | #---------------------------------------------------#
87 |     #   Load the model
88 | #---------------------------------------------------#
89 | def generate(self):
90 | model_path = os.path.expanduser(self.model_path)
91 | assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'
92 |
93 | #----------------------------------------#
94 |         #   Count the classes
95 | #----------------------------------------#
96 | self.num_classes = len(self.class_names)
97 |
98 | #----------------------------------------#
99 |         #   Build the CenterNet model
100 | #----------------------------------------#
101 | self.centernet = centernet(self.input_shape,num_classes=self.num_classes,backbone=self.backbone,mode='predict')
102 | self.centernet.load_weights(self.model_path, by_name=True)
103 |
104 | print('{} model, anchors, and classes loaded.'.format(self.model_path))
105 |
106 |         # assign each class a distinct color for drawing boxes
107 | hsv_tuples = [(x / len(self.class_names), 1., 1.)
108 | for x in range(len(self.class_names))]
109 | self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
110 | self.colors = list(
111 | map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
112 | self.colors))
113 |
114 | @tf.function
115 | def get_pred(self, photo):
116 | preds = self.centernet(photo, training=False)
117 | return preds
118 | #---------------------------------------------------#
119 |     #   Detect objects in an image
120 | #---------------------------------------------------#
121 | def detect_image(self, image):
122 | x_max_shape, y_max_shape = self.input_shape[0], self.input_shape[1]
123 | #---------------------------------------------------------#
124 |         #   Convert the image to RGB here so grayscale inputs do not crash prediction.
125 | #---------------------------------------------------------#
126 | image = image.convert('RGB')
127 |
128 | image_shape = np.array(np.shape(image)[0:2])
129 | #---------------------------------------------------------#
130 |         #   Pad the image with gray bars for a distortion-free resize.
131 | #---------------------------------------------------------#
132 | crop_img = letterbox_image(image, [self.input_shape[0],self.input_shape[1]])
133 | #----------------------------------------------------------------------------------#
134 |         #   Convert RGB to BGR, because the original centernet_hourglass weights were trained on BGR images.
135 | #----------------------------------------------------------------------------------#
136 | photo = np.array(crop_img,dtype = np.float32)[:,:,::-1]
137 | #-----------------------------------------------------------#
138 |         #   Preprocess and normalize; the resulting photo has shape [1, 512, 512, 3].
139 | #-----------------------------------------------------------#
140 | photo = np.reshape(preprocess_image(photo),[1,self.input_shape[0],self.input_shape[1],self.input_shape[2]])
141 |
142 | preds = self.get_pred(photo).numpy()
143 | #-------------------------------------------------------#
144 |         # For CenterNet, pinning down the object center is crucial.
145 |         # A large object carries a lot of local evidence, so its
146 |         # center point is hard to determine uniquely, and the
147 |         # max-pooling style of NMS cannot remove those local boxes.
148 |         # Hence the extra box-level NMS pass applied below (see the sketch after this file).
149 |         # In practice the extra NMS barely matters with the hourglass backbone but matters a lot with resnet.
150 | #-------------------------------------------------------#
151 | if self.nms:
152 | preds = np.array(nms(preds,self.nms_threhold))
153 |
154 |         if len(preds[0]) <= 0:
155 |             return image, 0, 0, self.colors[0], self.colors[1], 0  # keep the 6-tuple return signature used below
156 |
157 | #-----------------------------------------------------------#
158 |         #   Convert the box coordinates to fractions of the feature-map size.
159 | #-----------------------------------------------------------#
160 | preds[0][:,0:4] = preds[0][:,0:4]/(self.input_shape[0]/4)
161 |
162 | det_label = preds[0][:, -1]
163 | det_conf = preds[0][:, -2]
164 | det_xmin, det_ymin, det_xmax, det_ymax = preds[0][:, 0], preds[0][:, 1], preds[0][:, 2], preds[0][:, 3]
165 | #-----------------------------------------------------------#
166 |         #   Keep only the boxes whose score is above the confidence threshold.
167 | #-----------------------------------------------------------#
168 | top_indices = [i for i, conf in enumerate(det_conf) if conf >= self.confidence]
169 | top_conf = det_conf[top_indices]
170 | top_objs = det_label[top_indices]
171 | obj1sum = len(top_objs[top_objs == 0])
172 | obj2sum = len(top_objs[top_objs == 1])
173 |
174 | top_label_indices = top_objs.tolist()
175 | top_xmin, top_ymin, top_xmax, top_ymax = np.expand_dims(det_xmin[top_indices],-1),np.expand_dims(det_ymin[top_indices],-1),np.expand_dims(det_xmax[top_indices],-1),np.expand_dims(det_ymax[top_indices],-1)
176 |
177 | #-----------------------------------------------------------#
178 |         #   Undo the gray-bar (letterbox) padding.
179 | #-----------------------------------------------------------#
180 | boxes = centernet_correct_boxes(top_ymin,top_xmin,top_ymax,top_xmax,np.array([self.input_shape[0],self.input_shape[1]]),image_shape)
181 |
182 | font = ImageFont.truetype(font='model_data/simhei.ttf',size=np.floor(3e-2 * np.shape(image)[1] + 0.5).astype('int32'))
183 |
184 | thickness = max((np.shape(image)[0] + np.shape(image)[1]) // self.input_shape[0], 1)
185 | cnt1 = 0
186 | cnt2 = 0
187 | cnt3 = 0
188 | cnt4 = 0
189 | max_add = 0
190 | max_cnt = 0
191 | for i, c in enumerate(top_label_indices):
192 | predicted_class = self.class_names[int(c)]
193 | score = top_conf[i]
194 |
195 | top, left, bottom, right = boxes[i]
196 | x_mid = (left + right) / 2
197 | y_mid = (bottom + top) / 2
198 | if predicted_class == "space-empty":
199 | if x_mid <= x_max_shape / 2 and y_mid <= y_max_shape / 2:
200 | cnt1 += 1
201 | elif x_mid >= x_max_shape / 2 and y_mid <= y_max_shape / 2:
202 | cnt2 += 1
203 | elif x_mid <= x_max_shape / 2 and y_mid >= y_max_shape / 2:
204 | cnt3 += 1
205 | else:
206 | cnt4 += 1
207 |
208 |
209 | top = top - 5
210 | left = left - 5
211 | bottom = bottom + 5
212 | right = right + 5
213 |
214 | top = max(0, np.floor(top + 0.5).astype('int32'))
215 | left = max(0, np.floor(left + 0.5).astype('int32'))
216 | bottom = min(np.shape(image)[0], np.floor(bottom + 0.5).astype('int32'))
217 | right = min(np.shape(image)[1], np.floor(right + 0.5).astype('int32'))
218 |
219 |             # draw the box
220 | label = '{} {:.2f}'.format(predicted_class, score)
221 | draw = ImageDraw.Draw(image)
222 | label_size = draw.textsize(label, font)
223 | label = label.encode('utf-8')
224 | # print(label, top, left, bottom, right)
225 |
226 | if top - label_size[1] >= 0:
227 | text_origin = np.array([left, top - label_size[1]])
228 | else:
229 | text_origin = np.array([left, top + 1])
230 |
231 |             for t in range(thickness):  # renamed from i to avoid shadowing the enumerate index
232 |                 draw.rectangle(
233 |                     [left + t, top + t, right - t, bottom - t],
234 |                     outline=self.colors[int(c)])
235 | # draw.rectangle(
236 | # [tuple(text_origin), tuple(text_origin + label_size)],
237 | # fill=self.colors[int(c)])
238 | # draw.text(text_origin, str(label,'UTF-8'), fill=(0, 0, 0), font=font)
239 | del draw
240 |
241 | if cnt1 > max_add:
242 | max_add = cnt1
243 | max_cnt = 1
244 |
245 | if cnt2 > max_add:
246 | max_add = cnt2
247 | max_cnt = 2
248 |
249 | if cnt3 > max_add:
250 | max_add = cnt3
251 | max_cnt = 3
252 |
253 | if cnt4 > max_add:
254 | max_add = cnt4
255 | max_cnt = 4
256 | print(cnt1, cnt2, cnt3, cnt4)
257 | return image, obj1sum, obj2sum, self.colors[0], self.colors[1], max_cnt
258 |
259 | def draw_lines(self):
260 | pass
261 |
262 | def get_FPS(self, image, test_interval):
263 | #---------------------------------------------------------#
264 |         #   Convert the image to RGB here so grayscale inputs do not crash prediction.
265 | #---------------------------------------------------------#
266 | image = image.convert('RGB')
267 |
268 | image_shape = np.array(np.shape(image)[0:2])
269 | #---------------------------------------------------------#
270 |         #   Pad the image with gray bars for a distortion-free resize.
271 | #---------------------------------------------------------#
272 | crop_img = letterbox_image(image, [self.input_shape[0],self.input_shape[1]])
273 | #----------------------------------------------------------------------------------#
274 |         #   Convert RGB to BGR, because the original centernet_hourglass weights were trained on BGR images.
275 | #----------------------------------------------------------------------------------#
276 | photo = np.array(crop_img,dtype = np.float32)[:,:,::-1]
277 | #-----------------------------------------------------------#
278 |         #   Preprocess and normalize; the resulting photo has shape [1, 512, 512, 3].
279 | #-----------------------------------------------------------#
280 | photo = np.reshape(preprocess_image(photo), [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]])
281 |
282 | preds = self.get_pred(photo).numpy()
283 |
284 | if self.nms:
285 | preds = np.array(nms(preds, self.nms_threhold))
286 |
287 | if len(preds[0])>0:
288 | preds[0][:, 0:4] = preds[0][:, 0:4] / (self.input_shape[0] / 4)
289 |
290 | det_label = preds[0][:, -1]
291 | det_conf = preds[0][:, -2]
292 | det_xmin, det_ymin, det_xmax, det_ymax = preds[0][:, 0], preds[0][:, 1], preds[0][:, 2], preds[0][:, 3]
293 |
294 | top_indices = [i for i, conf in enumerate(det_conf) if conf >= self.confidence]
295 | top_conf = det_conf[top_indices]
296 | top_label_indices = det_label[top_indices].tolist()
297 | top_xmin, top_ymin, top_xmax, top_ymax = np.expand_dims(det_xmin[top_indices],-1),np.expand_dims(det_ymin[top_indices],-1),np.expand_dims(det_xmax[top_indices],-1),np.expand_dims(det_ymax[top_indices],-1)
298 |
299 | boxes = centernet_correct_boxes(top_ymin,top_xmin,top_ymax,top_xmax,np.array([self.input_shape[0],self.input_shape[1]]),image_shape)
300 |
301 | t1 = time.time()
302 | for _ in range(test_interval):
303 | preds = self.get_pred(photo).numpy()
304 |
305 | if self.nms:
306 | preds = np.array(nms(preds, self.nms_threhold))
307 |
308 | if len(preds[0])>0:
309 | preds[0][:, 0:4] = preds[0][:, 0:4] / (self.input_shape[0] / 4)
310 |
311 | det_label = preds[0][:, -1]
312 | det_conf = preds[0][:, -2]
313 | det_xmin, det_ymin, det_xmax, det_ymax = preds[0][:, 0], preds[0][:, 1], preds[0][:, 2], preds[0][:, 3]
314 |
315 | top_indices = [i for i, conf in enumerate(det_conf) if conf >= self.confidence]
316 | top_conf = det_conf[top_indices]
317 | top_label_indices = det_label[top_indices].tolist()
318 | top_xmin, top_ymin, top_xmax, top_ymax = np.expand_dims(det_xmin[top_indices],-1),np.expand_dims(det_ymin[top_indices],-1),np.expand_dims(det_xmax[top_indices],-1),np.expand_dims(det_ymax[top_indices],-1)
319 |
320 | boxes = centernet_correct_boxes(top_ymin,top_xmin,top_ymax,top_xmax,np.array([self.input_shape[0],self.input_shape[1]]),image_shape)
321 |
322 | t2 = time.time()
323 | tact_time = (t2 - t1) / test_interval
324 | return tact_time
325 |
--------------------------------------------------------------------------------
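
For orientation, a minimal usage sketch of the wrapper class above. The image path reuses the bundled model_data/timg.jpg; everything else just follows the return signature of detect_image (annotated image, occupied count, empty count, the two legend colors, and the busiest quadrant):

    from PIL import Image
    from centernet import CenterNet

    # Load the weights once, then run detection on a single image.
    model = CenterNet()
    image = Image.open("model_data/timg.jpg")
    r_image, occupied, empty, color_occ, color_emp, busiest = model.detect_image(image)
    print("occupied:", occupied, "empty:", empty, "busiest quadrant:", busiest)
    r_image.show()
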
/PKLot_server/nets/centernet_training.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from random import shuffle
4 |
5 | import cv2
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import scipy.signal
9 | import tensorflow as tf
10 | from PIL import Image
11 | from utils.utils import draw_gaussian, gaussian_radius
12 |
13 | def preprocess_image(image):
14 | mean = [0.40789655, 0.44719303, 0.47026116]
15 | std = [0.2886383 , 0.27408165, 0.27809834]
16 | return ((np.float32(image) / 255.) - mean) / std
17 |
18 | def focal_loss(hm_pred, hm_true):
19 | #-------------------------------------------------------------------------#
20 |     # Find the positive and negative samples in each image:
21 |     # each ground-truth box yields exactly one positive sample,
22 |     # and every other feature point is a negative sample.
23 | #-------------------------------------------------------------------------#
24 | pos_mask = tf.cast(tf.equal(hm_true, 1), tf.float32)
25 | #-------------------------------------------------------------------------#
26 |     # Negative samples near a positive feature point get a smaller weight.
27 | #-------------------------------------------------------------------------#
28 | neg_mask = tf.cast(tf.less(hm_true, 1), tf.float32)
29 | neg_weights = tf.pow(1 - hm_true, 4)
30 |
31 | #-------------------------------------------------------------------------#
32 |     # Compute the focal loss: hard examples get large weights, easy ones small.
33 | #-------------------------------------------------------------------------#
34 | pos_loss = -tf.math.log(tf.clip_by_value(hm_pred, 1e-6, 1.)) * tf.pow(1 - hm_pred, 2) * pos_mask
35 | neg_loss = -tf.math.log(tf.clip_by_value(1 - hm_pred, 1e-6, 1.)) * tf.pow(hm_pred, 2) * neg_weights * neg_mask
36 |
37 | num_pos = tf.reduce_sum(pos_mask)
38 | pos_loss = tf.reduce_sum(pos_loss)
39 | neg_loss = tf.reduce_sum(neg_loss)
40 |
41 | #-------------------------------------------------------------------------#
42 |     # Normalize the loss by the number of positives.
43 | #-------------------------------------------------------------------------#
44 | cls_loss = tf.cond(tf.greater(num_pos, 0), lambda: (pos_loss + neg_loss) / num_pos, lambda: neg_loss)
45 | return cls_loss
46 |
47 |
48 | def reg_l1_loss(y_pred, y_true, indices, mask):
49 | #-------------------------------------------------------------------------#
50 |     # Get the batch size and the channel count of the prediction.
51 | #-------------------------------------------------------------------------#
52 | b, c = tf.shape(y_pred)[0], tf.shape(y_pred)[-1]
53 | k = tf.shape(indices)[1]
54 |
55 | y_pred = tf.reshape(y_pred, (b, -1, c))
56 | length = tf.shape(y_pred)[1]
57 | indices = tf.cast(indices, tf.int32)
58 |
59 | #-------------------------------------------------------------------------#
60 |     # Use the indices to gather, from the predictions, the feature points that match the ground-truth boxes.
61 | #-------------------------------------------------------------------------#
62 | batch_idx = tf.expand_dims(tf.range(0, b), 1)
63 | batch_idx = tf.tile(batch_idx, (1, k))
64 | full_indices = (tf.reshape(batch_idx, [-1]) * tf.cast(length, tf.int32) +
65 | tf.reshape(indices, [-1]))
66 |
67 | y_pred = tf.gather(tf.reshape(y_pred, [-1,c]),full_indices)
68 | y_pred = tf.reshape(y_pred, [b, -1, c])
69 |
70 | mask = tf.tile(tf.expand_dims(mask, axis=-1), (1, 1, 2))
71 | #-------------------------------------------------------------------------#
72 |     # Compute the (masked) L1 loss.
73 | #-------------------------------------------------------------------------#
74 | total_loss = tf.reduce_sum(tf.abs(y_true * mask - y_pred * mask))
75 | reg_loss = total_loss / (tf.reduce_sum(mask) + 1e-4)
76 | return reg_loss
77 |
78 |
79 | def loss(args):
80 | #-----------------------------------------------------------------------------------------------------------------#
81 |     # hm_pred:  predicted heatmap            (batch_size, 128, 128, num_classes)
82 |     # wh_pred:  predicted width/height       (batch_size, 128, 128, 2)
83 |     # reg_pred: predicted center offset      (batch_size, 128, 128, 2)
84 |     # hm_true:  ground-truth heatmap         (batch_size, 128, 128, num_classes)
85 |     # wh_true:  ground-truth width/height    (batch_size, max_objects, 2)
86 |     # reg_true: ground-truth center offset   (batch_size, max_objects, 2)
87 |     # reg_mask: mask of valid ground truths  (batch_size, max_objects)
88 |     # indices:  flat feature-map coordinates of the ground truths (batch_size, max_objects)
89 | #-----------------------------------------------------------------------------------------------------------------#
90 | hm_pred, wh_pred, reg_pred, hm_true, wh_true, reg_true, reg_mask, indices = args
91 | hm_loss = focal_loss(hm_pred, hm_true)
92 | wh_loss = 0.1 * reg_l1_loss(wh_pred, wh_true, indices, reg_mask)
93 | reg_loss = reg_l1_loss(reg_pred, reg_true, indices, reg_mask)
94 | total_loss = hm_loss + wh_loss + reg_loss
95 | # total_loss = tf.Print(total_loss,[hm_loss,wh_loss,reg_loss])
96 | return total_loss
97 |
98 | def rand(a=0, b=1):
99 | return np.random.rand()*(b-a) + a
100 |
101 | class Generator(object):
102 | def __init__(self,batch_size,train_lines,val_lines,
103 | input_size,num_classes,max_objects=200):
104 |
105 | self.batch_size = batch_size
106 | self.train_lines = train_lines
107 | self.val_lines = val_lines
108 | self.input_size = input_size
109 | self.output_size = (int(input_size[0]/4) , int(input_size[1]/4))
110 | self.num_classes = num_classes
111 | self.max_objects = max_objects
112 |
113 | def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True):
114 |         '''Random preprocessing for real-time data augmentation'''
115 | line = annotation_line.split()
116 | image = Image.open(line[0])
117 | iw, ih = image.size
118 | h, w = input_shape
119 | box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
120 |
121 | if not random:
122 | # resize image
123 | scale = min(w/iw, h/ih)
124 | nw = int(iw*scale)
125 | nh = int(ih*scale)
126 | dx = (w-nw)//2
127 | dy = (h-nh)//2
128 |
129 | image = image.resize((nw,nh), Image.BICUBIC)
130 | new_image = Image.new('RGB', (w,h), (128,128,128))
131 | new_image.paste(image, (dx, dy))
132 | image_data = np.array(new_image, np.float32)
133 |
134 | # correct boxes
135 | box_data = np.zeros((len(box),5))
136 | if len(box)>0:
137 | np.random.shuffle(box)
138 | box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
139 | box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
140 | box[:, 0:2][box[:, 0:2]<0] = 0
141 | box[:, 2][box[:, 2]>w] = w
142 | box[:, 3][box[:, 3]>h] = h
143 | box_w = box[:, 2] - box[:, 0]
144 | box_h = box[:, 3] - box[:, 1]
145 | box = box[np.logical_and(box_w>1, box_h>1)]
146 | box_data = np.zeros((len(box),5))
147 | box_data[:len(box)] = box
148 |
149 | return image_data, box_data
150 |
151 | # resize image
152 | new_ar = w/h * rand(1-jitter,1+jitter)/rand(1-jitter,1+jitter)
153 | scale = rand(0.25, 2)
154 | if new_ar < 1:
155 | nh = int(scale*h)
156 | nw = int(nh*new_ar)
157 | else:
158 | nw = int(scale*w)
159 | nh = int(nw/new_ar)
160 | image = image.resize((nw,nh), Image.BICUBIC)
161 |
162 | # place image
163 | dx = int(rand(0, w-nw))
164 | dy = int(rand(0, h-nh))
165 | new_image = Image.new('RGB', (w,h), (128,128,128))
166 | new_image.paste(image, (dx, dy))
167 | image = new_image
168 |
169 | # flip image or not
170 | flip = rand()<.5
171 | if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)
172 |
173 | # distort image
174 | hue = rand(-hue, hue)
175 | sat = rand(1, sat) if rand()<.5 else 1/rand(1, sat)
176 | val = rand(1, val) if rand()<.5 else 1/rand(1, val)
177 | x = cv2.cvtColor(np.array(image,np.float32)/255, cv2.COLOR_RGB2HSV)
178 |         x[..., 0] += hue*360
179 |         x[..., 0][x[..., 0] > 360] -= 360  # the hue channel of float HSV wraps at 360 in OpenCV
180 |         x[..., 0][x[..., 0] < 0] += 360
181 | x[..., 1] *= sat
182 | x[..., 2] *= val
183 | x[x[:,:, 0]>360, 0] = 360
184 | x[:, :, 1:][x[:, :, 1:]>1] = 1
185 | x[x<0] = 0
186 | image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB)*255
187 |
188 |
189 | # correct boxes
190 | box_data = np.zeros((len(box),5))
191 | if len(box)>0:
192 | np.random.shuffle(box)
193 | box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
194 | box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
195 | if flip: box[:, [0,2]] = w - box[:, [2,0]]
196 | box[:, 0:2][box[:, 0:2]<0] = 0
197 | box[:, 2][box[:, 2]>w] = w
198 | box[:, 3][box[:, 3]>h] = h
199 | box_w = box[:, 2] - box[:, 0]
200 | box_h = box[:, 3] - box[:, 1]
201 | box = box[np.logical_and(box_w>1, box_h>1)] # discard invalid box
202 | box_data = np.zeros((len(box),5))
203 | box_data[:len(box)] = box
204 | if len(box) == 0:
205 | return image_data, []
206 |
207 | if (box_data[:,:4]>0).any():
208 | return image_data, box_data
209 | else:
210 | return image_data, []
211 |
212 | def generate(self, train=True, eager=False):
213 | while True:
214 | if train:
215 |                 # shuffle the training lines
216 | shuffle(self.train_lines)
217 | lines = self.train_lines
218 | else:
219 | shuffle(self.val_lines)
220 | lines = self.val_lines
221 |
222 | batch_images = np.zeros((self.batch_size, self.input_size[0], self.input_size[1], self.input_size[2]), dtype=np.float32)
223 | batch_hms = np.zeros((self.batch_size, self.output_size[0], self.output_size[1], self.num_classes), dtype=np.float32)
224 | batch_whs = np.zeros((self.batch_size, self.max_objects, 2), dtype=np.float32)
225 | batch_regs = np.zeros((self.batch_size, self.max_objects, 2), dtype=np.float32)
226 | batch_reg_masks = np.zeros((self.batch_size, self.max_objects), dtype=np.float32)
227 | batch_indices = np.zeros((self.batch_size, self.max_objects), dtype=np.float32)
228 |
229 | b = 0
230 | for annotation_line in lines:
231 | img,y = self.get_random_data(annotation_line,self.input_size[0:2],random=train)
232 |
233 | if len(y)!=0:
234 | boxes = np.array(y[:,:4],dtype=np.float32)
235 | boxes[:,0] = boxes[:,0]/self.input_size[1]*self.output_size[1]
236 | boxes[:,1] = boxes[:,1]/self.input_size[0]*self.output_size[0]
237 | boxes[:,2] = boxes[:,2]/self.input_size[1]*self.output_size[1]
238 | boxes[:,3] = boxes[:,3]/self.input_size[0]*self.output_size[0]
239 |
240 | for i in range(len(y)):
241 | bbox = boxes[i].copy()
242 | bbox = np.array(bbox)
243 | bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, self.output_size[1] - 1)
244 | bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, self.output_size[0] - 1)
245 | cls_id = int(y[i,-1])
246 |
247 | h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
248 | if h > 0 and w > 0:
249 | ct = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
250 | ct_int = ct.astype(np.int32)
251 |
252 |                             # paint the Gaussian onto the class heatmap
253 |                             radius = gaussian_radius((math.ceil(h), math.ceil(w)))
254 |                             radius = max(0, int(radius))
255 |                             batch_hms[b, :, :, cls_id] = draw_gaussian(batch_hms[b, :, :, cls_id], ct_int, radius)
256 | 
257 |                             batch_whs[b, i] = 1. * w, 1. * h
258 |                             # center offset from the integer grid point
259 |                             batch_regs[b, i] = ct - ct_int
260 |                             # set the mask to 1 to mark this slot as a real object (the rest stay padded with 0)
261 |                             batch_reg_masks[b, i] = 1
262 |                             # flat index: element ct_int[0] of row ct_int[1]
263 |                             batch_indices[b, i] = ct_int[1] * self.output_size[0] + ct_int[0]
264 |
265 |                 # convert RGB to BGR to match the pretrained weights
266 | img = np.array(img,dtype = np.float32)[:,:,::-1]
267 | batch_images[b] = preprocess_image(img)
268 | b = b + 1
269 | if b == self.batch_size:
270 | b = 0
271 | if eager:
272 | yield batch_images, batch_hms, batch_whs, batch_regs, batch_reg_masks, batch_indices
273 | else:
274 | yield [batch_images, batch_hms, batch_whs, batch_regs, batch_reg_masks, batch_indices], np.zeros((self.batch_size,))
275 |
276 | batch_images = np.zeros((self.batch_size, self.input_size[0], self.input_size[1], 3), dtype=np.float32)
277 |
278 | batch_hms = np.zeros((self.batch_size, self.output_size[0], self.output_size[1], self.num_classes),
279 | dtype=np.float32)
280 | batch_whs = np.zeros((self.batch_size, self.max_objects, 2), dtype=np.float32)
281 | batch_regs = np.zeros((self.batch_size, self.max_objects, 2), dtype=np.float32)
282 | batch_reg_masks = np.zeros((self.batch_size, self.max_objects), dtype=np.float32)
283 | batch_indices = np.zeros((self.batch_size, self.max_objects), dtype=np.float32)
284 |
285 |
286 | class LossHistory(tf.keras.callbacks.Callback):
287 | def __init__(self, log_dir):
288 | import datetime
289 | curr_time = datetime.datetime.now()
290 | time_str = datetime.datetime.strftime(curr_time,'%Y_%m_%d_%H_%M_%S')
291 | self.log_dir = log_dir
292 | self.time_str = time_str
293 | self.save_path = os.path.join(self.log_dir, "loss_" + str(self.time_str))
294 | self.losses = []
295 | self.val_loss = []
296 |
297 | os.makedirs(self.save_path)
298 |
299 | def on_epoch_end(self, batch, logs={}):
300 | self.losses.append(logs.get('loss'))
301 | self.val_loss.append(logs.get('val_loss'))
302 | with open(os.path.join(self.save_path, "epoch_loss_" + str(self.time_str) + ".txt"), 'a') as f:
303 | f.write(str(logs.get('loss')))
304 | f.write("\n")
305 | with open(os.path.join(self.save_path, "epoch_val_loss_" + str(self.time_str) + ".txt"), 'a') as f:
306 | f.write(str(logs.get('val_loss')))
307 | f.write("\n")
308 |
309 | self.loss_plot()
310 |
311 | def loss_plot(self):
312 | iters = range(len(self.losses))
313 |
314 | plt.figure()
315 | plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
316 | plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
317 | try:
318 | if len(self.losses) < 25:
319 | num = 5
320 | else:
321 | num = 15
322 |
323 | plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
324 | plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
325 | except:
326 | pass
327 |
328 | plt.grid(True)
329 | plt.xlabel('Epoch')
330 | plt.ylabel('Loss')
331 | plt.title('A Loss Curve')
332 | plt.legend(loc="upper right")
333 |
334 | plt.savefig(os.path.join(self.save_path, "epoch_loss_" + str(self.time_str) + ".png"))
335 |
336 | plt.cla()
337 | plt.close("all")
338 |
--------------------------------------------------------------------------------
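
To tie the generator and loss together, here is a minimal training sketch. It is illustrative only: the annotation file "2007_train.txt" (one line per image, "path x1,y1,x2,y2,cls ..."), the 10% validation split, the learning rate and the epoch count are all assumptions rather than values taken from this repository:

    from tensorflow.keras.optimizers import Adam
    from nets.centernet import centernet
    from nets.centernet_training import Generator

    with open("2007_train.txt") as f:
        lines = f.readlines()
    num_val = int(len(lines) * 0.1)
    train_lines, val_lines = lines[:-num_val], lines[-num_val:]

    input_shape = (512, 512, 3)
    model = centernet(input_shape, num_classes=2, backbone='resnet50', mode="train")
    # The model's single output already is the loss, so the compiled loss just passes it through.
    model.compile(loss={'centernet_loss': lambda y_true, y_pred: y_pred}, optimizer=Adam(1e-3))

    batch_size = 4
    gen = Generator(batch_size, train_lines, val_lines, input_shape, num_classes=2)
    model.fit(gen.generate(True),
              steps_per_epoch=max(1, len(train_lines) // batch_size),
              validation_data=gen.generate(False),
              validation_steps=max(1, len(val_lines) // batch_size),
              epochs=50)
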
/PKLot_server/get_map.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import os
4 | import shutil
5 | import operator
6 | import sys
7 | import argparse
8 | import math
9 |
10 | import numpy as np
11 |
12 | '''
13 | Computes the mAP.
14 | Code cloned from https://github.com/Cartucho/mAP
15 | To evaluate mAP at a different IoU threshold, e.g. mAP0.75, set MINOVERLAP = 0.75.
16 | '''
17 | MINOVERLAP = 0.5
18 |
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('-na', '--no-animation', help="no animation is shown.", action="store_true")
21 | parser.add_argument('-np', '--no-plot', help="no plot is shown.", action="store_true")
22 | parser.add_argument('-q', '--quiet', help="minimalistic console output.", action="store_true")
23 | parser.add_argument('-i', '--ignore', nargs='+', type=str, help="ignore a list of classes.")
24 | parser.add_argument('--set-class-iou', nargs='+', type=str, help="set IoU for a specific class.")
25 | args = parser.parse_args()
26 |
27 | '''
28 |  0,0 ------> x (width)
29 |   |
30 |   |  (Left,Top)
31 |   |      *_________
32 |   |      |         |
33 |   |      |         |
34 |   y      |_________|
35 |  (height)          *
36 |                    (Right,Bottom)
37 | '''
38 |
39 | if args.ignore is None:
40 | args.ignore = []
41 |
42 | specific_iou_flagged = False
43 | if args.set_class_iou is not None:
44 | specific_iou_flagged = True
45 |
46 | os.chdir(os.path.dirname(os.path.abspath(__file__)))
47 |
48 | GT_PATH = os.path.join(os.getcwd(), 'input', 'ground-truth')
49 | DR_PATH = os.path.join(os.getcwd(), 'input', 'detection-results')
50 | IMG_PATH = os.path.join(os.getcwd(), 'input', 'images-optional')
51 | if os.path.exists(IMG_PATH):
52 | for dirpath, dirnames, files in os.walk(IMG_PATH):
53 | if not files:
54 | args.no_animation = True
55 | else:
56 | args.no_animation = True
57 |
58 | show_animation = False
59 | if not args.no_animation:
60 | try:
61 | import cv2
62 | show_animation = True
63 | except ImportError:
64 | print("\"opencv-python\" not found, please install to visualize the results.")
65 | args.no_animation = True
66 |
67 | draw_plot = False
68 | if not args.no_plot:
69 | try:
70 | import matplotlib.pyplot as plt
71 | draw_plot = True
72 | except ImportError:
73 | print("\"matplotlib\" not found, please install it to get the resulting plots.")
74 | args.no_plot = True
75 |
76 |
77 | def log_average_miss_rate(precision, fp_cumsum, num_images):
78 | """
79 | log-average miss rate:
80 | Calculated by averaging miss rates at 9 evenly spaced FPPI points
81 |         between 10^-2 and 10^0, in log-space.
82 |
83 | output:
84 | lamr | log-average miss rate
85 | mr | miss rate
86 | fppi | false positives per image
87 |
88 | references:
89 | [1] Dollar, Piotr, et al. "Pedestrian Detection: An Evaluation of the
90 | State of the Art." Pattern Analysis and Machine Intelligence, IEEE
91 | Transactions on 34.4 (2012): 743 - 761.
92 | """
93 |
94 | if precision.size == 0:
95 | lamr = 0
96 | mr = 1
97 | fppi = 0
98 | return lamr, mr, fppi
99 |
100 | fppi = fp_cumsum / float(num_images)
101 | mr = (1 - precision)
102 |
103 | fppi_tmp = np.insert(fppi, 0, -1.0)
104 | mr_tmp = np.insert(mr, 0, 1.0)
105 |
106 | ref = np.logspace(-2.0, 0.0, num = 9)
107 | for i, ref_i in enumerate(ref):
108 | j = np.where(fppi_tmp <= ref_i)[-1][-1]
109 | ref[i] = mr_tmp[j]
110 |
111 | lamr = math.exp(np.mean(np.log(np.maximum(1e-10, ref))))
112 |
113 | return lamr, mr, fppi
114 |
115 | """
116 | throw error and exit
117 | """
118 | def error(msg):
119 | print(msg)
120 | sys.exit(0)
121 |
122 | """
123 | check if the number is a float between 0.0 and 1.0
124 | """
125 | def is_float_between_0_and_1(value):
126 | try:
127 | val = float(value)
128 | if val > 0.0 and val < 1.0:
129 | return True
130 | else:
131 | return False
132 | except ValueError:
133 | return False
134 |
135 | """
136 | Calculate the AP given the recall and precision array
137 | 1st) We compute a version of the measured precision/recall curve with
138 | precision monotonically decreasing
139 | 2nd) We compute the AP as the area under this curve by numerical integration.
140 | """
141 | def voc_ap(rec, prec):
142 | """
143 | --- Official matlab code VOC2012---
144 | mrec=[0 ; rec ; 1];
145 | mpre=[0 ; prec ; 0];
146 | for i=numel(mpre)-1:-1:1
147 | mpre(i)=max(mpre(i),mpre(i+1));
148 | end
149 | i=find(mrec(2:end)~=mrec(1:end-1))+1;
150 | ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
151 | """
152 |     rec.insert(0, 0.0)  # insert 0.0 at beginning of list
153 |     rec.append(1.0)     # insert 1.0 at end of list
154 |     mrec = rec[:]
155 |     prec.insert(0, 0.0) # insert 0.0 at beginning of list
156 |     prec.append(0.0)    # insert 0.0 at end of list
157 | mpre = prec[:]
158 | """
159 | This part makes the precision monotonically decreasing
160 | (goes from the end to the beginning)
161 | matlab: for i=numel(mpre)-1:-1:1
162 | mpre(i)=max(mpre(i),mpre(i+1));
163 | """
164 | for i in range(len(mpre)-2, -1, -1):
165 | mpre[i] = max(mpre[i], mpre[i+1])
166 | """
167 | This part creates a list of indexes where the recall changes
168 | matlab: i=find(mrec(2:end)~=mrec(1:end-1))+1;
169 | """
170 | i_list = []
171 | for i in range(1, len(mrec)):
172 | if mrec[i] != mrec[i-1]:
173 | i_list.append(i) # if it was matlab would be i + 1
174 | """
175 | The Average Precision (AP) is the area under the curve
176 | (numerical integration)
177 | matlab: ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
178 | """
179 | ap = 0.0
180 | for i in i_list:
181 | ap += ((mrec[i]-mrec[i-1])*mpre[i])
182 | return ap, mrec, mpre
183 |
184 |
185 | """
186 | Convert the lines of a file to a list
187 | """
188 | def file_lines_to_list(path):
189 | # open txt file lines to a list
190 | with open(path) as f:
191 | content = f.readlines()
192 | # remove whitespace characters like `\n` at the end of each line
193 | content = [x.strip() for x in content]
194 | return content
195 |
196 | """
197 | Draws text in image
198 | """
199 | def draw_text_in_image(img, text, pos, color, line_width):
200 | font = cv2.FONT_HERSHEY_PLAIN
201 | fontScale = 1
202 | lineType = 1
203 | bottomLeftCornerOfText = pos
204 | cv2.putText(img, text,
205 | bottomLeftCornerOfText,
206 | font,
207 | fontScale,
208 | color,
209 | lineType)
210 | text_width, _ = cv2.getTextSize(text, font, fontScale, lineType)[0]
211 | return img, (line_width + text_width)
212 |
213 | """
214 | Plot - adjust axes
215 | """
216 | def adjust_axes(r, t, fig, axes):
217 | # get text width for re-scaling
218 | bb = t.get_window_extent(renderer=r)
219 | text_width_inches = bb.width / fig.dpi
220 | # get axis width in inches
221 | current_fig_width = fig.get_figwidth()
222 | new_fig_width = current_fig_width + text_width_inches
223 |     proportion = new_fig_width / current_fig_width
224 |     # get axis limit
225 |     x_lim = axes.get_xlim()
226 |     axes.set_xlim([x_lim[0], x_lim[1]*proportion])
227 |
228 | """
229 | Draw plot using Matplotlib
230 | """
231 | def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar):
232 | # sort the dictionary by decreasing value, into a list of tuples
233 | sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1))
234 | # unpacking the list of tuples into two lists
235 | sorted_keys, sorted_values = zip(*sorted_dic_by_value)
236 | #
237 | if true_p_bar != "":
238 | """
239 | Special case to draw in:
240 | - green -> TP: True Positives (object detected and matches ground-truth)
241 | - red -> FP: False Positives (object detected but does not match ground-truth)
242 | - orange -> FN: False Negatives (object not detected but present in the ground-truth)
243 | """
244 | fp_sorted = []
245 | tp_sorted = []
246 | for key in sorted_keys:
247 | fp_sorted.append(dictionary[key] - true_p_bar[key])
248 | tp_sorted.append(true_p_bar[key])
249 | plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Positive')
250 | plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Positive', left=fp_sorted)
251 | # add legend
252 | plt.legend(loc='lower right')
253 | """
254 | Write number on side of bar
255 | """
256 | fig = plt.gcf() # gcf - get current figure
257 | axes = plt.gca()
258 | r = fig.canvas.get_renderer()
259 | for i, val in enumerate(sorted_values):
260 | fp_val = fp_sorted[i]
261 | tp_val = tp_sorted[i]
262 | fp_str_val = " " + str(fp_val)
263 | tp_str_val = fp_str_val + " " + str(tp_val)
264 | # trick to paint multicolor with offset:
265 | # first paint everything and then repaint the first number
266 | t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold')
267 | plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold')
268 | if i == (len(sorted_values)-1): # largest bar
269 | adjust_axes(r, t, fig, axes)
270 | else:
271 | plt.barh(range(n_classes), sorted_values, color=plot_color)
272 | """
273 | Write number on side of bar
274 | """
275 | fig = plt.gcf() # gcf - get current figure
276 | axes = plt.gca()
277 | r = fig.canvas.get_renderer()
278 | for i, val in enumerate(sorted_values):
279 | str_val = " " + str(val) # add a space before
280 | if val < 1.0:
281 | str_val = " {0:.2f}".format(val)
282 | t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold')
283 | # re-set axes to show number inside the figure
284 | if i == (len(sorted_values)-1): # largest bar
285 | adjust_axes(r, t, fig, axes)
286 | # set window title
287 | fig.canvas.set_window_title(window_title)
288 | # write classes in y axis
289 | tick_font_size = 12
290 | plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size)
291 | """
292 | Re-scale height accordingly
293 | """
294 | init_height = fig.get_figheight()
295 |     # compute the matrix height in points and inches
296 | dpi = fig.dpi
297 | height_pt = n_classes * (tick_font_size * 1.4) # 1.4 (some spacing)
298 | height_in = height_pt / dpi
299 | # compute the required figure height
300 | top_margin = 0.15 # in percentage of the figure height
301 | bottom_margin = 0.05 # in percentage of the figure height
302 | figure_height = height_in / (1 - top_margin - bottom_margin)
303 | # set new height
304 | if figure_height > init_height:
305 | fig.set_figheight(figure_height)
306 |
307 | # set plot title
308 | plt.title(plot_title, fontsize=14)
309 | # set axis titles
310 | # plt.xlabel('classes')
311 | plt.xlabel(x_label, fontsize='large')
312 | # adjust size of window
313 | fig.tight_layout()
314 | # save the plot
315 | fig.savefig(output_path)
316 | # show image
317 | if to_show:
318 | plt.show()
319 | # close the plot
320 | plt.close()
321 |
322 | """
323 | Create a ".temp_files/" and "results/" directory
324 | """
325 | TEMP_FILES_PATH = ".temp_files"
326 | if not os.path.exists(TEMP_FILES_PATH): # if it doesn't exist already
327 | os.makedirs(TEMP_FILES_PATH)
328 | results_files_path = "results"
329 | if os.path.exists(results_files_path): # if it exists already
330 | # reset the results directory
331 | shutil.rmtree(results_files_path)
332 |
333 | os.makedirs(results_files_path)
334 | if draw_plot:
335 | os.makedirs(os.path.join(results_files_path, "AP"))
336 | os.makedirs(os.path.join(results_files_path, "F1"))
337 | os.makedirs(os.path.join(results_files_path, "Recall"))
338 | os.makedirs(os.path.join(results_files_path, "Precision"))
339 | if show_animation:
340 | os.makedirs(os.path.join(results_files_path, "images", "detections_one_by_one"))
341 |
342 | """
343 | ground-truth
344 | Load each of the ground-truth files into a temporary ".json" file.
345 | Create a list of all the class names present in the ground-truth (gt_classes).
346 | """
347 | # get a list with the ground-truth files
348 | ground_truth_files_list = glob.glob(GT_PATH + '/*.txt')
349 | if len(ground_truth_files_list) == 0:
350 | error("Error: No ground-truth files found!")
351 | ground_truth_files_list.sort()
352 | # dictionary with counter per class
353 | gt_counter_per_class = {}
354 | counter_images_per_class = {}
355 |
356 | for txt_file in ground_truth_files_list:
357 | #print(txt_file)
358 | file_id = txt_file.split(".txt", 1)[0]
359 | file_id = os.path.basename(os.path.normpath(file_id))
360 | # check if there is a correspondent detection-results file
361 | temp_path = os.path.join(DR_PATH, (file_id + ".txt"))
362 | if not os.path.exists(temp_path):
363 | error_msg = "Error. File not found: {}\n".format(temp_path)
364 | error_msg += "(You can avoid this error message by running extra/intersect-gt-and-dr.py)"
365 | error(error_msg)
366 | lines_list = file_lines_to_list(txt_file)
367 | # create ground-truth dictionary
368 | bounding_boxes = []
369 | is_difficult = False
370 | already_seen_classes = []
371 | for line in lines_list:
372 | try:
373 | if "difficult" in line:
374 | class_name, left, top, right, bottom, _difficult = line.split()
375 | is_difficult = True
376 | else:
377 | class_name, left, top, right, bottom = line.split()
378 |
379 |         except ValueError:
380 | if "difficult" in line:
381 | line_split = line.split()
382 | _difficult = line_split[-1]
383 | bottom = line_split[-2]
384 | right = line_split[-3]
385 | top = line_split[-4]
386 | left = line_split[-5]
387 | class_name = ""
388 | for name in line_split[:-5]:
389 | class_name += name + " "
390 | class_name = class_name[:-1]
391 | is_difficult = True
392 | else:
393 | line_split = line.split()
394 | bottom = line_split[-1]
395 | right = line_split[-2]
396 | top = line_split[-3]
397 | left = line_split[-4]
398 | class_name = ""
399 | for name in line_split[:-4]:
400 | class_name += name + " "
401 | class_name = class_name[:-1]
402 | if class_name in args.ignore:
403 | continue
404 |             bbox = left + " " + top + " " + right + " " + bottom
405 | if is_difficult:
406 | bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False, "difficult":True})
407 | is_difficult = False
408 | else:
409 | bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False})
410 | if class_name in gt_counter_per_class:
411 | gt_counter_per_class[class_name] += 1
412 | else:
413 | gt_counter_per_class[class_name] = 1
414 |
415 |         if class_name not in already_seen_classes:
416 |             if class_name in counter_images_per_class:
417 |                 counter_images_per_class[class_name] += 1
418 |             else:
419 |                 counter_images_per_class[class_name] = 1
420 |             already_seen_classes.append(class_name)
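    |         # counter_images_per_class tracks, per class, how many images contain
    |         # at least one instance; it is used later as n_images when computing
    |         # the log-average miss rate. (This block must sit outside the if/else
    |         # above, otherwise the per-image counter never gets past 1.)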
421 |
422 |
423 | with open(TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json", 'w') as outfile:
424 | json.dump(bounding_boxes, outfile)
425 |
426 | gt_classes = list(gt_counter_per_class.keys())
427 | gt_classes = sorted(gt_classes)
428 | n_classes = len(gt_classes)
429 |
430 | """
431 | Check format of the flag --set-class-iou (if used)
432 | e.g. check if class exists
433 | """
434 | if specific_iou_flagged:
435 | n_args = len(args.set_class_iou)
436 | error_msg = \
437 | '\n --set-class-iou [class_1] [IoU_1] [class_2] [IoU_2] [...]'
438 | if n_args % 2 != 0:
439 | error('Error, missing arguments. Flag usage:' + error_msg)
440 | # [class_1] [IoU_1] [class_2] [IoU_2]
441 | # specific_iou_classes = ['class_1', 'class_2']
442 | specific_iou_classes = args.set_class_iou[::2] # even
443 | # iou_list = ['IoU_1', 'IoU_2']
444 | iou_list = args.set_class_iou[1::2] # odd
445 | if len(specific_iou_classes) != len(iou_list):
446 | error('Error, missing arguments. Flag usage:' + error_msg)
447 | for tmp_class in specific_iou_classes:
448 | if tmp_class not in gt_classes:
449 | error('Error, unknown class \"' + tmp_class + '\". Flag usage:' + error_msg)
450 | for num in iou_list:
451 | if not is_float_between_0_and_1(num):
452 | error('Error, IoU must be between 0.0 and 1.0. Flag usage:' + error_msg)
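    |     # usage example (hypothetical class names and values):
    |     #   --set-class-iou car 0.6 person 0.45
    |     # overrides MINOVERLAP with 0.6 for "car" and 0.45 for "person"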
453 |
454 | """
455 | detection-results
456 | Load each of the detection-results files into a temporary ".json" file.
457 | """
458 | dr_files_list = glob.glob(DR_PATH + '/*.txt')
459 | dr_files_list.sort()
460 |
461 | for class_index, class_name in enumerate(gt_classes):
462 | bounding_boxes = []
463 | for txt_file in dr_files_list:
464 |             file_id = txt_file.split(".txt", 1)[0]
465 | file_id = os.path.basename(os.path.normpath(file_id))
466 | temp_path = os.path.join(GT_PATH, (file_id + ".txt"))
467 | if class_index == 0:
468 | if not os.path.exists(temp_path):
469 | error_msg = "Error. File not found: {}\n".format(temp_path)
470 | error_msg += "(You can avoid this error message by running extra/intersect-gt-and-dr.py)"
471 | error(error_msg)
472 | lines = file_lines_to_list(txt_file)
473 | for line in lines:
474 | try:
475 | tmp_class_name, confidence, left, top, right, bottom = line.split()
476 |             except ValueError:  # class name contains spaces; parse from the right
477 | line_split = line.split()
478 | bottom = line_split[-1]
479 | right = line_split[-2]
480 | top = line_split[-3]
481 | left = line_split[-4]
482 | confidence = line_split[-5]
483 | tmp_class_name = ""
484 | for name in line_split[:-5]:
485 | tmp_class_name += name + " "
486 | tmp_class_name = tmp_class_name[:-1]
487 |
488 | if tmp_class_name == class_name:
489 |                 bbox = left + " " + top + " " + right + " " + bottom
490 | bounding_boxes.append({"confidence":confidence, "file_id":file_id, "bbox":bbox})
491 |
492 | bounding_boxes.sort(key=lambda x:float(x['confidence']), reverse=True)
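    |         # sorting by descending confidence means the cumulative TP/FP pass
    |         # below is equivalent to sweeping the score threshold from high to
    |         # low, which is what the precision/recall curve (and hence AP) needs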
493 | with open(TEMP_FILES_PATH + "/" + class_name + "_dr.json", 'w') as outfile:
494 | json.dump(bounding_boxes, outfile)
495 |
496 | """
497 | Calculate the AP for each class
498 | """
499 | sum_AP = 0.0
500 | ap_dictionary = {}
501 | lamr_dictionary = {}
502 | with open(results_files_path + "/results.txt", 'w') as results_file:
503 | results_file.write("# AP and precision/recall per class\n")
504 | count_true_positives = {}
505 |
506 | for class_index, class_name in enumerate(gt_classes):
507 | count_true_positives[class_name] = 0
508 | """
509 | Load detection-results of that class
510 | """
511 | dr_file = TEMP_FILES_PATH + "/" + class_name + "_dr.json"
512 | dr_data = json.load(open(dr_file))
513 | """
514 | Assign detection-results to ground-truth objects
515 | """
516 | nd = len(dr_data)
517 | tp = [0] * nd
518 | fp = [0] * nd
519 | score = [0] * nd
520 | score05_idx = 0
521 | for idx, detection in enumerate(dr_data):
522 | file_id = detection["file_id"]
523 | score[idx] = float(detection["confidence"])
524 | if score[idx] > 0.5:
525 | score05_idx = idx
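    |             # because detections are sorted by descending confidence, score05_idx
    |             # ends up pointing at the last detection with confidence > 0.5; it is
    |             # used later to report F1/Recall/Precision at a 0.5 score threshold.
    |             # Note: if no detection exceeds 0.5 it stays 0 (the top detection).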
526 |
527 | if show_animation:
528 | ground_truth_img = glob.glob1(IMG_PATH, file_id + ".*")
529 | if len(ground_truth_img) == 0:
530 | error("Error. Image not found with id: " + file_id)
531 | elif len(ground_truth_img) > 1:
532 | error("Error. Multiple image with id: " + file_id)
533 | else:
534 | img = cv2.imread(IMG_PATH + "/" + ground_truth_img[0])
535 | img_cumulative_path = results_files_path + "/images/" + ground_truth_img[0]
536 | if os.path.isfile(img_cumulative_path):
537 | img_cumulative = cv2.imread(img_cumulative_path)
538 | else:
539 | img_cumulative = img.copy()
540 | bottom_border = 60
541 | BLACK = [0, 0, 0]
542 | img = cv2.copyMakeBorder(img, 0, bottom_border, 0, 0, cv2.BORDER_CONSTANT, value=BLACK)
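    |                         # the 60 px black strip added below the image is where
    |                         # the two lines of animation status text are drawn later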
543 |
544 | gt_file = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json"
545 | ground_truth_data = json.load(open(gt_file))
546 | ovmax = -1
547 | gt_match = -1
548 | bb = [ float(x) for x in detection["bbox"].split() ]
549 | for obj in ground_truth_data:
550 | if obj["class_name"] == class_name:
551 | bbgt = [ float(x) for x in obj["bbox"].split() ]
552 | bi = [max(bb[0],bbgt[0]), max(bb[1],bbgt[1]), min(bb[2],bbgt[2]), min(bb[3],bbgt[3])]
553 | iw = bi[2] - bi[0] + 1
554 | ih = bi[3] - bi[1] + 1
555 | if iw > 0 and ih > 0:
556 | # compute overlap (IoU) = area of intersection / area of union
557 | ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + (bbgt[2] - bbgt[0]
558 | + 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih
559 | ov = iw * ih / ua
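    |                         # worked example (illustrative numbers): bb = [0,0,9,9]
    |                         # and bbgt = [5,5,14,14] give iw = ih = 5, intersection
    |                         # = 25, union = 100 + 100 - 25 = 175, IoU = 25/175 ~ 0.14
    |                         # (the +1 terms treat coordinates as inclusive pixel indices)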
560 | if ov > ovmax:
561 | ovmax = ov
562 | gt_match = obj
563 |
564 | if show_animation:
565 | status = "NO MATCH FOUND!"
566 | min_overlap = MINOVERLAP
567 | if specific_iou_flagged:
568 | if class_name in specific_iou_classes:
569 | index = specific_iou_classes.index(class_name)
570 | min_overlap = float(iou_list[index])
571 | if ovmax >= min_overlap:
572 | if "difficult" not in gt_match:
573 | if not bool(gt_match["used"]):
574 | tp[idx] = 1
575 | gt_match["used"] = True
576 | count_true_positives[class_name] += 1
577 | with open(gt_file, 'w') as f:
578 | f.write(json.dumps(ground_truth_data))
579 | if show_animation:
580 | status = "MATCH!"
581 | else:
582 | fp[idx] = 1
583 | if show_animation:
584 | status = "REPEATED MATCH!"
585 | else:
586 | fp[idx] = 1
587 | if ovmax > 0:
588 | status = "INSUFFICIENT OVERLAP"
589 |
590 | """
591 | Draw image to show animation
592 | """
593 | if show_animation:
594 |             height, width = img.shape[:2]
595 | # colors (OpenCV works with BGR)
596 | white = (255,255,255)
597 | light_blue = (255,200,100)
598 | green = (0,255,0)
599 | light_red = (30,30,255)
600 | # 1st line
601 | margin = 10
602 | v_pos = int(height - margin - (bottom_border / 2.0))
603 | text = "Image: " + ground_truth_img[0] + " "
604 | img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
605 | text = "Class [" + str(class_index) + "/" + str(n_classes) + "]: " + class_name + " "
606 | img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), light_blue, line_width)
607 | if ovmax != -1:
608 | color = light_red
609 | if status == "INSUFFICIENT OVERLAP":
610 | text = "IoU: {0:.2f}% ".format(ovmax*100) + "< {0:.2f}% ".format(min_overlap*100)
611 | else:
612 | text = "IoU: {0:.2f}% ".format(ovmax*100) + ">= {0:.2f}% ".format(min_overlap*100)
613 | color = green
614 | img, _ = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
615 | # 2nd line
616 | v_pos += int(bottom_border / 2.0)
617 | rank_pos = str(idx+1) # rank position (idx starts at 0)
618 | text = "Detection #rank: " + rank_pos + " confidence: {0:.2f}% ".format(float(detection["confidence"])*100)
619 | img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
620 | color = light_red
621 | if status == "MATCH!":
622 | color = green
623 | text = "Result: " + status + " "
624 | img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
625 |
626 | font = cv2.FONT_HERSHEY_SIMPLEX
627 |                 if ovmax > 0: # if the bounding boxes intersect
628 | bbgt = [ int(round(float(x))) for x in gt_match["bbox"].split() ]
629 | cv2.rectangle(img,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
630 | cv2.rectangle(img_cumulative,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
631 | cv2.putText(img_cumulative, class_name, (bbgt[0],bbgt[1] - 5), font, 0.6, light_blue, 1, cv2.LINE_AA)
632 | bb = [int(i) for i in bb]
633 | cv2.rectangle(img,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
634 | cv2.rectangle(img_cumulative,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
635 | cv2.putText(img_cumulative, class_name, (bb[0],bb[1] - 5), font, 0.6, color, 1, cv2.LINE_AA)
636 | # show image
637 | cv2.imshow("Animation", img)
638 | cv2.waitKey(20) # show for 20 ms
639 | # save image to results
640 | output_img_path = results_files_path + "/images/detections_one_by_one/" + class_name + "_detection" + str(idx) + ".jpg"
641 | cv2.imwrite(output_img_path, img)
642 | # save the image with all the objects drawn to it
643 | cv2.imwrite(img_cumulative_path, img_cumulative)
644 |
645 | cumsum = 0
646 | for idx, val in enumerate(fp):
647 | fp[idx] += cumsum
648 | cumsum += val
649 |
650 | cumsum = 0
651 | for idx, val in enumerate(tp):
652 | tp[idx] += cumsum
653 | cumsum += val
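    |         # tp and fp now hold cumulative counts over the ranked detections,
    |         # e.g. per-detection tp = [1, 0, 1] becomes [1, 1, 2]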
654 |
655 | rec = tp[:]
656 | for idx, val in enumerate(tp):
657 | rec[idx] = float(tp[idx]) / np.maximum(gt_counter_per_class[class_name], 1)
658 |
659 | prec = tp[:]
660 | for idx, val in enumerate(tp):
661 | prec[idx] = float(tp[idx]) / np.maximum((fp[idx] + tp[idx]), 1)
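    |         # rec[i]  = cumulative TP / number of ground-truth objects of this class
    |         # prec[i] = cumulative TP / (cumulative TP + cumulative FP)
    |         # (np.maximum(..., 1) guards against division by zero)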
662 |
663 | ap, mrec, mprec = voc_ap(rec[:], prec[:])
664 | F1 = np.array(rec)*np.array(prec)*2 / np.where((np.array(prec)+np.array(rec))==0, 1, (np.array(prec)+np.array(rec)))
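    |         # F1 = 2 * P * R / (P + R); np.where substitutes a denominator of 1
    |         # wherever P + R == 0, so F1 is 0 there instead of a division error.
    |         # voc_ap (presumably defined earlier in this file) integrates the area
    |         # under the monotone precision envelope and returns it together with
    |         # the interpolated mrec/mprec points used for the plot below.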
665 |
666 | sum_AP += ap
667 | text = "{0:.2f}%".format(ap*100) + " = " + class_name + " AP " #class_name + " AP = {0:.2f}%".format(ap*100)
668 |
669 |         if len(prec) > 0:
670 | F1_text = "{0:.2f}".format(F1[score05_idx]) + " = " + class_name + " F1 "
671 | Recall_text = "{0:.2f}%".format(rec[score05_idx]*100) + " = " + class_name + " Recall "
672 | Precision_text = "{0:.2f}%".format(prec[score05_idx]*100) + " = " + class_name + " Precision "
673 | else:
674 | F1_text = "0.00" + " = " + class_name + " F1 "
675 | Recall_text = "0.00%" + " = " + class_name + " Recall "
676 | Precision_text = "0.00%" + " = " + class_name + " Precision "
677 |
678 | rounded_prec = [ '%.2f' % elem for elem in prec ]
679 | rounded_rec = [ '%.2f' % elem for elem in rec ]
680 |         results_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall: " + str(rounded_rec) + "\n\n")
681 | if not args.quiet:
682 |             if len(prec) > 0:
683 |                 print(text + "\t||\tscore_threshold=0.5 : " + "F1=" + "{0:.2f}".format(F1[score05_idx])\
684 |                     + " ; Recall=" + "{0:.2f}%".format(rec[score05_idx]*100) + " ; Precision=" + "{0:.2f}%".format(prec[score05_idx]*100))
685 |             else:
686 |                 print(text + "\t||\tscore_threshold=0.5 : F1=0.00 ; Recall=0.00% ; Precision=0.00%")
687 | ap_dictionary[class_name] = ap
688 |
689 | n_images = counter_images_per_class[class_name]
690 | lamr, mr, fppi = log_average_miss_rate(np.array(rec), np.array(fp), n_images)
691 | lamr_dictionary[class_name] = lamr
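    |             # log-average miss rate: miss rate (1 - recall) sampled against
    |             # false positives per image (fp / n_images); lower is better. See
    |             # the log_average_miss_rate helper above for the exact sampling.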
692 |
693 | """
694 | Draw plot
695 | """
696 | if draw_plot:
697 | plt.plot(rec, prec, '-o')
698 | area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]]
699 | area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]]
700 | plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r')
701 |
702 | fig = plt.gcf()
703 |             fig.canvas.manager.set_window_title('AP ' + class_name)  # canvas.set_window_title was removed in Matplotlib 3.6; go through the manager
704 |
705 | plt.title('class: ' + text)
706 | plt.xlabel('Recall')
707 | plt.ylabel('Precision')
708 | axes = plt.gca()
709 | axes.set_xlim([0.0,1.0])
710 | axes.set_ylim([0.0,1.05])
711 | fig.savefig(results_files_path + "/AP/" + class_name + ".png")
712 | plt.cla()
713 |
714 | plt.plot(score, F1, "-", color='orangered')
715 |             plt.title('class: ' + F1_text + "\nscore_threshold=0.5")
716 |             plt.xlabel('Score_Threshold')
717 | plt.ylabel('F1')
718 | axes = plt.gca()
719 | axes.set_xlim([0.0,1.0])
720 | axes.set_ylim([0.0,1.05])
721 | fig.savefig(results_files_path + "/F1/" + class_name + ".png")
722 | plt.cla()
723 |
724 | plt.plot(score, rec, "-H", color='gold')
725 |             plt.title('class: ' + Recall_text + "\nscore_threshold=0.5")
726 |             plt.xlabel('Score_Threshold')
727 | plt.ylabel('Recall')
728 | axes = plt.gca()
729 | axes.set_xlim([0.0,1.0])
730 | axes.set_ylim([0.0,1.05])
731 | fig.savefig(results_files_path + "/Recall/" + class_name + ".png")
732 | plt.cla()
733 |
734 | plt.plot(score, prec, "-s", color='palevioletred')
735 |             plt.title('class: ' + Precision_text + "\nscore_threshold=0.5")
736 |             plt.xlabel('Score_Threshold')
737 | plt.ylabel('Precision')
738 | axes = plt.gca()
739 | axes.set_xlim([0.0,1.0])
740 | axes.set_ylim([0.0,1.05])
741 | fig.savefig(results_files_path + "/Precision/" + class_name + ".png")
742 | plt.cla()
743 |
744 | if show_animation:
745 | cv2.destroyAllWindows()
746 |
747 | results_file.write("\n# mAP of all classes\n")
748 | mAP = sum_AP / n_classes
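    |     # mAP here is the unweighted mean of the per-class APs at the single IoU
    |     # threshold MINOVERLAP, not a COCO-style average over several thresholds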
749 | text = "mAP = {0:.2f}%".format(mAP*100)
750 | results_file.write(text + "\n")
751 | print(text)
752 |
753 | # remove the temp_files directory
754 | shutil.rmtree(TEMP_FILES_PATH)
755 |
756 | """
757 | Count the total number of detection-results
758 | """
759 | # iterate through all the files
760 | det_counter_per_class = {}
761 | for txt_file in dr_files_list:
762 | # get lines to list
763 | lines_list = file_lines_to_list(txt_file)
764 | for line in lines_list:
765 | class_name = line.split()[0]
766 | # check if class is in the ignore list, if yes skip
767 | if class_name in args.ignore:
768 | continue
769 | # count that object
770 | if class_name in det_counter_per_class:
771 | det_counter_per_class[class_name] += 1
772 | else:
773 | # if class didn't exist yet
774 | det_counter_per_class[class_name] = 1
775 | #print(det_counter_per_class)
776 | dr_classes = list(det_counter_per_class.keys())
777 |
778 |
779 | """
780 | Plot the total number of occurrences of each class in the ground-truth
781 | """
782 | if draw_plot:
783 | window_title = "ground-truth-info"
784 | plot_title = "ground-truth\n"
785 | plot_title += "(" + str(len(ground_truth_files_list)) + " files and " + str(n_classes) + " classes)"
786 | x_label = "Number of objects per class"
787 | output_path = results_files_path + "/ground-truth-info.png"
788 | to_show = False
789 | plot_color = 'forestgreen'
790 | draw_plot_func(
791 | gt_counter_per_class,
792 | n_classes,
793 | window_title,
794 | plot_title,
795 | x_label,
796 | output_path,
797 | to_show,
798 | plot_color,
799 | '',
800 | )
801 |
802 | """
803 | Write number of ground-truth objects per class to results.txt
804 | """
805 | with open(results_files_path + "/results.txt", 'a') as results_file:
806 | results_file.write("\n# Number of ground-truth objects per class\n")
807 | for class_name in sorted(gt_counter_per_class):
808 | results_file.write(class_name + ": " + str(gt_counter_per_class[class_name]) + "\n")
809 |
810 | """
811 | Finish counting true positives
812 | """
813 | for class_name in dr_classes:
814 | # if class exists in detection-result but not in ground-truth then there are no true positives in that class
815 | if class_name not in gt_classes:
816 | count_true_positives[class_name] = 0
817 | #print(count_true_positives)
818 |
819 | """
820 | Plot the total number of occurrences of each class in the "detection-results" folder
821 | """
822 | if draw_plot:
823 | window_title = "detection-results-info"
824 | # Plot title
825 | plot_title = "detection-results\n"
826 | plot_title += "(" + str(len(dr_files_list)) + " files and "
827 | count_non_zero_values_in_dictionary = sum(int(x) > 0 for x in list(det_counter_per_class.values()))
828 | plot_title += str(count_non_zero_values_in_dictionary) + " detected classes)"
829 | # end Plot title
830 | x_label = "Number of objects per class"
831 | output_path = results_files_path + "/detection-results-info.png"
832 | to_show = False
833 | plot_color = 'forestgreen'
834 | true_p_bar = count_true_positives
835 | draw_plot_func(
836 | det_counter_per_class,
837 | len(det_counter_per_class),
838 | window_title,
839 | plot_title,
840 | x_label,
841 | output_path,
842 | to_show,
843 | plot_color,
844 | true_p_bar
845 | )
846 |
847 | """
848 | Write number of detected objects per class to results.txt
849 | """
850 | with open(results_files_path + "/results.txt", 'a') as results_file:
851 | results_file.write("\n# Number of detected objects per class\n")
852 | for class_name in sorted(dr_classes):
853 | n_det = det_counter_per_class[class_name]
854 | text = class_name + ": " + str(n_det)
855 | text += " (tp:" + str(count_true_positives[class_name]) + ""
856 | text += ", fp:" + str(n_det - count_true_positives[class_name]) + ")\n"
857 | results_file.write(text)
858 |
859 | """
860 | Draw log-average miss rate plot (Show lamr of all classes in decreasing order)
861 | """
862 | if draw_plot:
863 | window_title = "lamr"
864 | plot_title = "log-average miss rate"
865 | x_label = "log-average miss rate"
866 | output_path = results_files_path + "/lamr.png"
867 | to_show = False
868 | plot_color = 'royalblue'
869 | draw_plot_func(
870 | lamr_dictionary,
871 | n_classes,
872 | window_title,
873 | plot_title,
874 | x_label,
875 | output_path,
876 | to_show,
877 | plot_color,
878 | ""
879 | )
880 |
881 | """
882 | Draw mAP plot (Show AP's of all classes in decreasing order)
883 | """
884 | if draw_plot:
885 | window_title = "mAP"
886 | plot_title = "mAP = {0:.2f}%".format(mAP*100)
887 | x_label = "Average Precision"
888 | output_path = results_files_path + "/mAP.png"
889 | to_show = True
890 | plot_color = 'royalblue'
891 | draw_plot_func(
892 | ap_dictionary,
893 | n_classes,
894 | window_title,
895 | plot_title,
896 | x_label,
897 | output_path,
898 | to_show,
899 | plot_color,
900 | ""
901 | )
902 |
--------------------------------------------------------------------------------