├── Detect_GUI.py
├── Login_GUI.py
├── README.md
├── engine
│   ├── __init__.py
│   ├── exporter.py
│   ├── model.py
│   ├── predictor.py
│   ├── results.py
│   ├── trainer.py
│   └── validator.py
├── lib
│   └── share.py
├── setup.cfg
├── test
│   ├── yolov8n.pt
│   └── zidane.jpg
├── ui
│   ├── detect_ui.py
│   ├── login_ui.py
│   ├── ori_ui
│   │   ├── detect_ui.ui
│   │   ├── login_ui.ui
│   │   └── registe_ui.ui
│   └── registe_ui.py
├── ui_img
│   ├── delete.png
│   ├── detect.JPG
│   ├── detect_images.qrc
│   ├── detect_images_rc.py
│   ├── exit.png
│   ├── images.png
│   ├── login.JPG
│   ├── login_images.qrc
│   ├── login_images_rc.py
│   ├── registe.JPG
│   ├── registe_images.qrc
│   ├── registe_images_rc.py
│   ├── run.png
│   ├── save.png
│   └── upload.png
├── ultralytics
│   ├── __init__.py
│   ├── hub
│   │   ├── __init__.py
│   │   ├── auth.py
│   │   ├── session.py
│   │   └── utils.py
│   ├── nn
│   │   ├── __init__.py
│   │   ├── autobackend.py
│   │   ├── autoshape.py
│   │   ├── modules.py
│   │   └── tasks.py
│   └── yolo
│       ├── __init__.py
│       ├── cfg
│       │   ├── __init__.py
│       │   └── default.yaml
│       ├── data
│       │   ├── __init__.py
│       │   ├── augment.py
│       │   ├── base.py
│       │   ├── build.py
│       │   ├── dataloaders
│       │   │   ├── __init__.py
│       │   │   ├── stream_loaders.py
│       │   │   ├── v5augmentations.py
│       │   │   └── v5loader.py
│       │   ├── dataset.py
│       │   ├── dataset_wrappers.py
│       │   └── utils.py
│       ├── engine
│       │   ├── __init__.py
│       │   ├── exporter.py
│       │   ├── model.py
│       │   ├── predictor.py
│       │   ├── results.py
│       │   ├── trainer.py
│       │   └── validator.py
│       ├── utils
│       │   ├── __init__.py
│       │   ├── autobatch.py
│       │   ├── benchmarks.py
│       │   ├── callbacks
│       │   │   ├── __init__.py
│       │   │   ├── base.py
│       │   │   ├── clearml.py
│       │   │   ├── comet.py
│       │   │   ├── hub.py
│       │   │   └── tensorboard.py
│       │   ├── checks.py
│       │   ├── dist.py
│       │   ├── downloads.py
│       │   ├── files.py
│       │   ├── instance.py
│       │   ├── loss.py
│       │   ├── metrics.py
│       │   ├── ops.py
│       │   ├── plotting.py
│       │   ├── tal.py
│       │   └── torch_utils.py
│       └── v8
│           ├── __init__.py
│           ├── classify
│           │   ├── __init__.py
│           │   ├── predict.py
│           │   ├── train.py
│           │   └── val.py
│           ├── detect
│           │   ├── __init__.py
│           │   ├── predict.py
│           │   ├── train.py
│           │   └── val.py
│           └── segment
│               ├── __init__.py
│               ├── predict.py
│               ├── train.py
│               └── val.py
├── userInfo.csv
├── utils
│   ├── __init__.py
│   ├── autobatch.py
│   ├── benchmarks.py
│   ├── callbacks
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── clearml.py
│   │   ├── comet.py
│   │   ├── hub.py
│   │   └── tensorboard.py
│   ├── checks.py
│   ├── dist.py
│   ├── downloads.py
│   ├── files.py
│   ├── id_utils.py
│   ├── instance.py
│   ├── loss.py
│   ├── metrics.py
│   ├── ops.py
│   ├── plotting.py
│   ├── tal.py
│   └── torch_utils.py
└── v8
    ├── __init__.py
    ├── classify
    │   ├── __init__.py
    │   ├── predict.py
    │   ├── train.py
    │   └── val.py
    ├── detect
    │   ├── __init__.py
    │   ├── predict.py
    │   ├── train.py
    │   └── val.py
    └── segment
        ├── __init__.py
        ├── predict.py
        ├── train.py
        └── val.py
/Login_GUI.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from datetime import datetime
3 | from PyQt5.QtWidgets import *
4 | from utils.id_utils import get_id_info, sava_id_info # account-info helper functions
5 | from lib.share import shareInfo # shared window references
6 | # Import the Qt Designer-generated UIs
7 | from ui.login_ui import Ui_Form
8 | from ui.registe_ui import Ui_Dialog
9 | # Import the detection window
10 | from Detect_GUI import Ui_MainWindow
11 | import matplotlib.backends.backend_tkagg
12 | # Login window
13 | class win_Login(QMainWindow):
14 | def __init__(self, parent = None):
15 | super(win_Login, self).__init__(parent)
16 | self.ui_login = Ui_Form()
17 | self.ui_login.setupUi(self)
18 | self.init_slots()
19 | self.hidden_pwd()
20 |
21 | # Hide the password as it is typed
22 | def hidden_pwd(self):
23 | self.ui_login.edit_password.setEchoMode(QLineEdit.Password)
24 |
25 | # Connect signals to slots
26 | def init_slots(self):
27 | self.ui_login.btn_login.clicked.connect(self.onSignIn) # log in on button click
28 | self.ui_login.edit_password.returnPressed.connect(self.onSignIn) # log in on Enter
29 | self.ui_login.btn_regeist.clicked.connect(self.create_id)
30 |
31 | # Open the registration window
32 | def create_id(self):
33 | shareInfo.createWin = win_Registe()
34 | shareInfo.createWin.show()
35 |
36 | # Append an entry to the login log
37 | def sava_login_log(self, username):
38 | with open('login_log.txt', 'a', encoding='utf-8') as f:
39 | f.write(username + '\t log in at ' + datetime.now().strftime('%Y-%m-%d %H:%M:%S') + '\n')
40 |
41 | # Log in
42 | def onSignIn(self):
43 | print("You pressed sign in")
44 | # Read the username and password entered on the login form
45 | username = self.ui_login.edit_username.text().strip()
46 | password = self.ui_login.edit_password.text().strip()
47 | print(username)
48 | print(password)
49 | # Fetch the stored account info
50 | USER_PWD = get_id_info()
51 | # print(USER_PWD)
52 | if username not in USER_PWD.keys():
53 | replay = QMessageBox.warning(self,"登陆失败!", "账号或密码输入错误", QMessageBox.Yes)
54 | else:
55 | # On success, jump to the main window
56 | if USER_PWD.get(username) == password:
57 | print("Jump to main window")
58 | # keep a reference on the shared variable so the window is not garbage-collected
59 | # shareInfo.mainWin = UI_Logic_Window()
60 | shareInfo.mainWin = Ui_MainWindow()
61 | shareInfo.mainWin.show()
62 | # Close the current window
63 | self.close()
64 | else:
65 | replay = QMessageBox.warning(self, "!", "账号或密码输入错误", QMessageBox.Yes)
66 |
67 | # Registration window
68 | class win_Registe(QMainWindow):
69 | def __init__(self, parent = None):
70 | super(win_Registe, self).__init__(parent)
71 | self.ui_registe = Ui_Dialog()
72 | self.ui_registe.setupUi(self)
73 | self.init_slots()
74 |
75 | # Connect signals to slots
76 | def init_slots(self):
77 | self.ui_registe.pushButton_regiser.clicked.connect(self.new_account)
78 | self.ui_registe.pushButton_cancer.clicked.connect(self.cancel)
79 |
80 | # Create a new account
81 | def new_account(self):
82 | print("Create new account")
83 | USER_PWD = get_id_info()
84 | # print(USER_PWD)
85 | new_username = self.ui_registe.edit_username.text().strip()
86 | new_password = self.ui_registe.edit_password.text().strip()
87 | # Make sure the username is not empty
88 | if new_username == "":
89 | replay = QMessageBox.warning(self, "!", "账号不准为空", QMessageBox.Yes)
90 | else:
91 | # Check whether the account already exists
92 | if new_username in USER_PWD.keys():
93 | replay = QMessageBox.warning(self, "!", "账号已存在", QMessageBox.Yes)
94 | else:
95 | # Make sure the password is not empty
96 | if new_password == "":
97 | replay = QMessageBox.warning(self, "!", "密码不能为空", QMessageBox.Yes)
98 | else:
99 | # Registration succeeded
100 | print("Successful!")
101 | sava_id_info(new_username, new_password)
102 | replay = QMessageBox.warning(self, "!", "注册成功!", QMessageBox.Yes)
103 | # Close the window
104 | self.close()
105 |
106 | # Cancel registration
107 | def cancel(self):
108 | self.close() # close the current window
109 |
110 | if __name__ == "__main__":
111 | app = QApplication(sys.argv)
112 | # Instantiate windows through the shared variables
113 | shareInfo.loginWin = win_Login() # the login window is the entry point
114 | shareInfo.loginWin.show()
115 | sys.exit(app.exec_())
116 |
117 |
--------------------------------------------------------------------------------
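Login_GUI.py leans on get_id_info/sava_id_info from utils/id_utils.py, whose source is not reproduced in this dump. Below is a minimal sketch of what they plausibly do, assuming userInfo.csv holds plain "username,password" rows; the file layout and defaults here are assumptions, not the repo's code.

# Hypothetical sketch of utils/id_utils.py (not shown in this dump): a
# CSV-backed username -> password store. Format and column order are assumed.
import csv

def get_id_info(filepath='userInfo.csv'):
    """Return {username: password} read from the CSV store."""
    users = {}
    try:
        with open(filepath, newline='', encoding='utf-8') as f:
            for row in csv.reader(f):
                if len(row) >= 2:
                    users[row[0]] = row[1]
    except FileNotFoundError:
        pass  # no accounts registered yet
    return users

def sava_id_info(username, password, filepath='userInfo.csv'):
    """Append a new username,password row."""
    with open(filepath, 'a', newline='', encoding='utf-8') as f:
        csv.writer(f).writerow([username, password])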
/README.md:
--------------------------------------------------------------------------------
1 | # YOLOv8_PYQT5_GUI
2 | A detection GUI built on YOLOv8 and PyQt5
3 |
--------------------------------------------------------------------------------
/engine/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/engine/__init__.py
--------------------------------------------------------------------------------
/lib/share.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author : Ruihao
3 | # @ProjectName:yolov5-pyqt5
4 |
5 | class shareInfo:
6 | '''
7 | Holds shared references to the application's windows # adapted from BaiYueHeiYu, www.python3.vip
8 | '''
9 | mainWin = None
10 | loginWin = None
11 | createWin = None
--------------------------------------------------------------------------------
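shareInfo exists because a PyQt window held only by a local variable can be garbage-collected (and closed) as soon as the creating method returns; parking each window on a module-level attribute keeps it alive. A minimal illustration of the difference (an illustrative sketch, not repo code):

from PyQt5.QtWidgets import QWidget
from lib.share import shareInfo

def open_window_wrong():
    win = QWidget()  # only a local reference
    win.show()       # window may be garbage-collected and vanish on return

def open_window_right():
    shareInfo.createWin = QWidget()  # module-level reference keeps it alive
    shareInfo.createWin.show()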
/setup.cfg:
--------------------------------------------------------------------------------
1 | # Project-wide configuration file, can be used for package metadata and other tool configurations
2 | # Example usage: global configuration for PEP8 (via flake8) settings or default pytest arguments
3 | # Local usage: pip install pre-commit, pre-commit run --all-files
4 |
5 | [metadata]
6 | license_files = LICENSE
7 | description_file = README.md
8 |
9 | [tool:pytest]
10 | norecursedirs =
11 | .git
12 | dist
13 | build
14 | addopts =
15 | --doctest-modules
16 | --durations=25
17 | --color=yes
18 |
19 | [flake8]
20 | max-line-length = 120
21 | exclude = .tox,*.egg,build,temp
22 | select = E,W,F
23 | doctests = True
24 | verbose = 2
25 | # https://pep8.readthedocs.io/en/latest/intro.html#error-codes
26 | format = pylint
27 | # see: https://www.flake8rules.com/
28 | ignore = E731,F405,E402,W504,E501
29 | # E731: Do not assign a lambda expression, use a def
30 | # F405: name may be undefined, or defined from star imports: module
31 | # E402: module level import not at top of file
32 | # W504: line break after binary operator
33 | # E501: line too long
34 | # removed:
35 | # F401: module imported but unused
36 | # E231: missing whitespace after ‘,’, ‘;’, or ‘:’
37 | # E127: continuation line over-indented for visual indent
38 | # F403: ‘from module import *’ used; unable to detect undefined names
39 |
40 |
41 | [isort]
42 | # https://pycqa.github.io/isort/docs/configuration/options.html
43 | line_length = 120
44 | # see: https://pycqa.github.io/isort/docs/configuration/multi_line_output_modes.html
45 | multi_line_output = 0
46 |
47 | [yapf]
48 | based_on_style = pep8
49 | spaces_before_comment = 2
50 | COLUMN_LIMIT = 120
51 | COALESCE_BRACKETS = True
52 | SPACES_AROUND_POWER_OPERATOR = True
53 | SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET = True
54 | SPLIT_BEFORE_CLOSING_BRACKET = False
55 | SPLIT_BEFORE_FIRST_ARGUMENT = False
56 | # EACH_DICT_ENTRY_ON_SEPARATE_LINE = False
57 |
--------------------------------------------------------------------------------
/test/yolov8n.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/test/yolov8n.pt
--------------------------------------------------------------------------------
/test/zidane.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/test/zidane.jpg
--------------------------------------------------------------------------------
/ui/detect_ui.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Form implementation generated from reading ui file 'ui/ori_ui/detect_ui.ui'
4 | #
5 | # Created by: PyQt5 UI code generator 5.15.7
6 | #
7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is
8 | # run again. Do not edit this file unless you know what you are doing.
9 |
10 |
11 | from PyQt5 import QtCore, QtGui, QtWidgets
12 |
13 |
14 | class Ui_MainWindow(object):
15 | def setupUi(self, MainWindow):
16 | MainWindow.setObjectName("MainWindow")
17 | MainWindow.resize(1500, 1000)
18 | MainWindow.setStyleSheet("QWidget#centralwidget{\n"
19 | " background-image: url(:/detect_background/detect.jpg);}")
20 | self.centralwidget = QtWidgets.QWidget(MainWindow)
21 | self.centralwidget.setObjectName("centralwidget")
22 | self.pushButton = QtWidgets.QPushButton(self.centralwidget)
23 | self.pushButton.setGeometry(QtCore.QRect(70, 810, 70, 70))
24 | self.pushButton.setStyleSheet("border-image: url(:/detect_button_background/upload.png);\n"
25 | "\n"
26 | "")
27 | self.pushButton.setText("")
28 | self.pushButton.setObjectName("pushButton")
29 | self.pushButton_3 = QtWidgets.QPushButton(self.centralwidget)
30 | self.pushButton_3.setGeometry(QtCore.QRect(390, 810, 70, 70))
31 | self.pushButton_3.setStyleSheet("border-image: url(:/detect_button_background/images.png);")
32 | self.pushButton_3.setText("")
33 | self.pushButton_3.setObjectName("pushButton_3")
34 | self.pushButton_4 = QtWidgets.QPushButton(self.centralwidget)
35 | self.pushButton_4.setGeometry(QtCore.QRect(730, 810, 70, 70))
36 | self.pushButton_4.setStyleSheet("border-image: url(:/detect_button_background/save.png);")
37 | self.pushButton_4.setText("")
38 | self.pushButton_4.setObjectName("pushButton_4")
39 | self.pushButton_5 = QtWidgets.QPushButton(self.centralwidget)
40 | self.pushButton_5.setGeometry(QtCore.QRect(1050, 810, 70, 70))
41 | self.pushButton_5.setStyleSheet("border-image: url(:/detect_button_background/delete.png);")
42 | self.pushButton_5.setText("")
43 | self.pushButton_5.setObjectName("pushButton_5")
44 | self.pushButton_6 = QtWidgets.QPushButton(self.centralwidget)
45 | self.pushButton_6.setGeometry(QtCore.QRect(1360, 810, 70, 70))
46 | self.pushButton_6.setStyleSheet("border-image: url(:/detect_button_background/exit.png);")
47 | self.pushButton_6.setText("")
48 | self.pushButton_6.setObjectName("pushButton_6")
49 | self.label_2 = QtWidgets.QLabel(self.centralwidget)
50 | self.label_2.setGeometry(QtCore.QRect(190, 10, 1101, 80))
51 | font = QtGui.QFont()
52 | font.setFamily("Adobe 黑体 Std R")
53 | font.setPointSize(28)
54 | self.label_2.setFont(font)
55 | self.label_2.setStyleSheet("")
56 | self.label_2.setObjectName("label_2")
57 | self.label_3 = QtWidgets.QLabel(self.centralwidget)
58 | self.label_3.setGeometry(QtCore.QRect(0, 80, 700, 700))
59 | self.label_3.setStyleSheet("background-color: rgb(255, 255, 255);")
60 | self.label_3.setObjectName("label_3")
61 | self.label_4 = QtWidgets.QLabel(self.centralwidget)
62 | self.label_4.setGeometry(QtCore.QRect(800, 80, 700, 700))
63 | self.label_4.setStyleSheet("background-color: rgb(255, 255, 255);")
64 | self.label_4.setObjectName("label_4")
65 | self.label_5 = QtWidgets.QLabel(self.centralwidget)
66 | self.label_5.setGeometry(QtCore.QRect(-1, 800, 1501, 141))
67 | self.label_5.setStyleSheet("background-color: rgb(255, 255, 255);\n"
68 | "border-color: rgb(0, 0, 0);")
69 | self.label_5.setText("")
70 | self.label_5.setObjectName("label_5")
71 | self.lineEdit = QtWidgets.QLineEdit(self.centralwidget)
72 | self.lineEdit.setGeometry(QtCore.QRect(20, 890, 161, 40))
73 | font = QtGui.QFont()
74 | font.setFamily("Adobe 宋体 Std L")
75 | font.setPointSize(20)
76 | self.lineEdit.setFont(font)
77 | self.lineEdit.setLayoutDirection(QtCore.Qt.RightToLeft)
78 | self.lineEdit.setObjectName("lineEdit")
79 | self.lineEdit_3 = QtWidgets.QLineEdit(self.centralwidget)
80 | self.lineEdit_3.setGeometry(QtCore.QRect(350, 890, 161, 40))
81 | font = QtGui.QFont()
82 | font.setFamily("Adobe 宋体 Std L")
83 | font.setPointSize(20)
84 | self.lineEdit_3.setFont(font)
85 | self.lineEdit_3.setObjectName("lineEdit_3")
86 | self.lineEdit_4 = QtWidgets.QLineEdit(self.centralwidget)
87 | self.lineEdit_4.setGeometry(QtCore.QRect(690, 890, 161, 40))
88 | font = QtGui.QFont()
89 | font.setFamily("Adobe 宋体 Std L")
90 | font.setPointSize(20)
91 | self.lineEdit_4.setFont(font)
92 | self.lineEdit_4.setObjectName("lineEdit_4")
93 | self.lineEdit_5 = QtWidgets.QLineEdit(self.centralwidget)
94 | self.lineEdit_5.setGeometry(QtCore.QRect(1000, 890, 161, 40))
95 | font = QtGui.QFont()
96 | font.setFamily("Adobe 宋体 Std L")
97 | font.setPointSize(20)
98 | self.lineEdit_5.setFont(font)
99 | self.lineEdit_5.setObjectName("lineEdit_5")
100 | self.lineEdit_6 = QtWidgets.QLineEdit(self.centralwidget)
101 | self.lineEdit_6.setGeometry(QtCore.QRect(1300, 890, 161, 40))
102 | font = QtGui.QFont()
103 | font.setFamily("Adobe 宋体 Std L")
104 | font.setPointSize(20)
105 | self.lineEdit_6.setFont(font)
106 | self.lineEdit_6.setObjectName("lineEdit_6")
107 | self.label_5.raise_()
108 | self.pushButton.raise_()
109 | self.pushButton_3.raise_()
110 | self.pushButton_4.raise_()
111 | self.pushButton_5.raise_()
112 | self.pushButton_6.raise_()
113 | self.label_2.raise_()
114 | self.label_3.raise_()
115 | self.label_4.raise_()
116 | self.lineEdit.raise_()
117 | self.lineEdit_3.raise_()
118 | self.lineEdit_4.raise_()
119 | self.lineEdit_5.raise_()
120 | self.lineEdit_6.raise_()
121 | MainWindow.setCentralWidget(self.centralwidget)
122 | self.menubar = QtWidgets.QMenuBar(MainWindow)
123 | self.menubar.setGeometry(QtCore.QRect(0, 0, 1500, 30))
124 | self.menubar.setObjectName("menubar")
125 | MainWindow.setMenuBar(self.menubar)
126 | self.statusbar = QtWidgets.QStatusBar(MainWindow)
127 | self.statusbar.setObjectName("statusbar")
128 | MainWindow.setStatusBar(self.statusbar)
129 |
130 | self.retranslateUi(MainWindow)
131 | QtCore.QMetaObject.connectSlotsByName(MainWindow)
132 |
133 | def retranslateUi(self, MainWindow):
134 | _translate = QtCore.QCoreApplication.translate
135 | MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
136 | self.label_2.setText(_translate("MainWindow", "<html><head/><body><p align=\"center\"><span style=\" font-size:28pt; font-weight:600; color:#ffffff;\">基于YOLOv8的疏果前期幼果检测演示软件</span></p></body></html>"))
137 | self.label_3.setText(_translate("MainWindow", "<html><head/><body><p align=\"center\"><span style=\" font-size:20pt;\">原始图像</span></p></body></html>"))
138 | self.label_4.setText(_translate("MainWindow", "<html><head/><body><p align=\"center\"><span style=\" font-size:20pt;\">检测图像</span></p></body></html>"))
139 | self.lineEdit.setToolTip(_translate("MainWindow", "<html><head/><body><p align=\"center\"><br/></p></body></html>"))
140 | self.lineEdit.setWhatsThis(_translate("MainWindow", "<html><head/><body><p align=\"center\"><br/></p></body></html>"))
141 | self.lineEdit.setText(_translate("MainWindow", "模型加载"))
142 | self.lineEdit_3.setText(_translate("MainWindow", "图像加载"))
143 | self.lineEdit_4.setText(_translate("MainWindow", "图像保存"))
144 | self.lineEdit_5.setText(_translate("MainWindow", "图像清除"))
145 | self.lineEdit_6.setText(_translate("MainWindow", "应用退出"))
146 | import ui_img.detect_images_rc
147 |
--------------------------------------------------------------------------------
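Classes generated by pyuic5, such as Ui_MainWindow above, are plain layout mixins with no window behavior of their own. The usual consumption pattern subclasses the Qt window type and calls setupUi on itself, as sketched below; this wrapper is illustrative, since the repo's actual wrapper lives in Detect_GUI.py, which is not reproduced in this dump.

import sys
from PyQt5.QtWidgets import QApplication, QMainWindow
from ui.detect_ui import Ui_MainWindow

class DetectWindow(QMainWindow):
    def __init__(self, parent=None):
        super().__init__(parent)
        self.ui = Ui_MainWindow()
        self.ui.setupUi(self)  # builds all widgets onto this window
        self.ui.pushButton_6.clicked.connect(self.close)  # exit button

if __name__ == '__main__':
    app = QApplication(sys.argv)
    win = DetectWindow()
    win.show()
    sys.exit(app.exec_())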
/ui/login_ui.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Form implementation generated from reading ui file 'ui/ori_ui/login_ui.ui'
4 | #
5 | # Created by: PyQt5 UI code generator 5.15.7
6 | #
7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is
8 | # run again. Do not edit this file unless you know what you are doing.
9 |
10 |
11 | from PyQt5 import QtCore, QtGui, QtWidgets
12 | import ui_img.login_images_rc
13 |
14 | class Ui_Form(object):
15 | def setupUi(self, Form):
16 | Form.setObjectName("Form")
17 | Form.setWindowTitle("基于YOLOv8的检测演示软件V1.0")
18 | Form.resize(800, 600)
19 | Form.setStyleSheet("QWidget#Form{background-image: url(:/login_background/login.JPG);}")
20 |
21 |
22 | self.edit_username = QtWidgets.QLineEdit(Form)
23 | self.edit_username.setGeometry(QtCore.QRect(250, 200, 250, 50))
24 | font = QtGui.QFont()
25 | font.setFamily("Adobe 黑体 Std R")
26 | font.setPointSize(18)
27 | self.edit_username.setFont(font)
28 | self.edit_username.setObjectName("edit_username")
29 |
30 | self.btn_login = QtWidgets.QPushButton(Form)
31 | self.btn_login.setGeometry(QtCore.QRect(220, 430, 120, 80))
32 | font = QtGui.QFont()
33 | font.setFamily("Adobe 黑体 Std R")
34 | font.setPointSize(18)
35 | self.btn_login.setFont(font)
36 | self.btn_login.setObjectName("btn_login")
37 |
38 | self.edit_password = QtWidgets.QLineEdit(Form)
39 | self.edit_password.setGeometry(QtCore.QRect(250, 300, 250, 50))
40 | font = QtGui.QFont()
41 | font.setFamily("Adobe 黑体 Std R")
42 | font.setPointSize(18)
43 | self.edit_password.setFont(font)
44 | self.edit_password.setObjectName("edit_password")
45 |
46 | self.label = QtWidgets.QLabel(Form)
47 | self.label.setGeometry(QtCore.QRect(0, 80, 781, 80))
48 | font = QtGui.QFont()
49 | font.setFamily("Adobe 黑体 Std R")
50 | font.setPointSize(22)
51 | self.label.setFont(font)
52 | self.label.setStyleSheet("")
53 | self.label.setObjectName("label")
54 |
55 | self.btn_regeist = QtWidgets.QPushButton(Form)
56 | self.btn_regeist.setGeometry(QtCore.QRect(420, 430, 120, 80))
57 | font = QtGui.QFont()
58 | font.setFamily("Adobe 黑体 Std R")
59 | font.setPointSize(18)
60 | self.btn_regeist.setFont(font)
61 | self.btn_regeist.setObjectName("btn_regeist")
62 |
63 | self.retranslateUi(Form)
64 | QtCore.QMetaObject.connectSlotsByName(Form)
65 |
66 | def retranslateUi(self, Form):
67 | _translate = QtCore.QCoreApplication.translate
68 | Form.setWindowTitle(_translate("Form", "Form"))
69 | self.edit_username.setPlaceholderText(_translate("Form", "用户名"))
70 | self.btn_login.setText(_translate("Form", "登录"))
71 | self.edit_password.setPlaceholderText(_translate("Form", "密码"))
72 | self.label.setText(_translate("Form", "<html><head/><body><p align=\"center\"><span style=\" font-size:20pt; color:#ffffff;\">基于YOLOv8的检测演示软件</span></p></body></html>"))
73 | self.btn_regeist.setText(_translate("Form", "注册"))
74 |
75 |
--------------------------------------------------------------------------------
/ui/ori_ui/detect_ui.ui:
--------------------------------------------------------------------------------
[Qt Designer XML for the detection window; the XML markup was lost in extraction. It defines the 1500x1000 MainWindow with the upload/images/save/delete/exit buttons, the 原始图像 and 检测图像 image panels, the caption line edits and the stylesheets, all of which are reproduced verbatim by the generated /ui/detect_ui.py above.]
--------------------------------------------------------------------------------
/ui/ori_ui/login_ui.ui:
--------------------------------------------------------------------------------
[Qt Designer XML for the 800x600 login Form; the XML markup was lost in extraction. It defines the username/password line edits, the 登录 and 注册 buttons and a white title label, matching the generated /ui/login_ui.py above, except that its background stylesheet references :/login_background/denglu.jpg where the generated code uses :/login_background/login.JPG.]
--------------------------------------------------------------------------------
/ui/ori_ui/registe_ui.ui:
--------------------------------------------------------------------------------
[Qt Designer XML for the 800x600 registration Dialog; the XML markup was lost in extraction. It defines the 账号注册 title label, the 注册账号 and 设置密码 line edits and the 注册/取消 buttons, matching the generated /ui/registe_ui.py below.]
--------------------------------------------------------------------------------
/ui/registe_ui.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Form implementation generated from reading ui file 'ui/ori_ui/registe_ui.ui'
4 | #
5 | # Created by: PyQt5 UI code generator 5.15.7
6 | #
7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is
8 | # run again. Do not edit this file unless you know what you are doing.
9 |
10 |
11 | from PyQt5 import QtCore, QtGui, QtWidgets
12 | import ui_img.registe_images_rc
13 |
14 | class Ui_Dialog(object):
15 | def setupUi(self, Dialog):
16 | Dialog.setObjectName("Dialog")
17 | Dialog.setWindowTitle("基于YOLOv8的检测演示软件V1.0")
18 | Dialog.resize(800, 600)
19 | Dialog.setStyleSheet("QDialog#Dialog{background-image: url(:/registe_background/registe.JPG);}\n"
20 | "")
21 |
22 |
23 | self.label = QtWidgets.QLabel(Dialog)
24 | self.label.setGeometry(QtCore.QRect(260, 120, 241, 81))
25 | font = QtGui.QFont()
26 | font.setFamily("Adobe 黑体 Std R")
27 | font.setPointSize(28)
28 | font.setBold(True)
29 | font.setWeight(75)
30 | self.label.setFont(font)
31 | self.label.setStyleSheet("")
32 | self.label.setObjectName("label")
33 |
34 | self.edit_password = QtWidgets.QLineEdit(Dialog)
35 | self.edit_password.setGeometry(QtCore.QRect(260, 320, 250, 50))
36 | font = QtGui.QFont()
37 | font.setFamily("Adobe 黑体 Std R")
38 | font.setPointSize(18)
39 | self.edit_password.setFont(font)
40 | self.edit_password.setObjectName("edit_password")
41 |
42 | self.pushButton_regiser = QtWidgets.QPushButton(Dialog)
43 | self.pushButton_regiser.setGeometry(QtCore.QRect(220, 410, 120, 80))
44 | font = QtGui.QFont()
45 | font.setFamily("Adobe 黑体 Std R")
46 | font.setPointSize(18)
47 | self.pushButton_regiser.setFont(font)
48 | self.pushButton_regiser.setObjectName("pushButton_regiser")
49 |
50 | self.edit_username = QtWidgets.QLineEdit(Dialog)
51 | self.edit_username.setGeometry(QtCore.QRect(260, 230, 250, 50))
52 | font = QtGui.QFont()
53 | font.setFamily("Adobe 黑体 Std R")
54 | font.setPointSize(18)
55 | self.edit_username.setFont(font)
56 | self.edit_username.setObjectName("edit_username")
57 |
58 | self.pushButton_cancer = QtWidgets.QPushButton(Dialog)
59 | self.pushButton_cancer.setGeometry(QtCore.QRect(420, 410, 120, 80))
60 | font = QtGui.QFont()
61 | font.setFamily("Adobe 黑体 Std R")
62 | font.setPointSize(18)
63 | self.pushButton_cancer.setFont(font)
64 | self.pushButton_cancer.setObjectName("pushButton_cancer")
65 |
66 | self.retranslateUi(Dialog)
67 | QtCore.QMetaObject.connectSlotsByName(Dialog)
68 |
69 | def retranslateUi(self, Dialog):
70 | _translate = QtCore.QCoreApplication.translate
71 | Dialog.setWindowTitle(_translate("Dialog", "Dialog"))
72 | self.label.setText(_translate("Dialog", "<html><head/><body><p align=\"center\"><span style=\" font-size:28pt; color:#ffffff;\">账号注册</span></p></body></html>"))
73 | self.edit_password.setPlaceholderText(_translate("Dialog", "设置密码"))
74 | self.pushButton_regiser.setText(_translate("Dialog", "注册"))
75 | self.edit_username.setPlaceholderText(_translate("Dialog", "注册账号"))
76 | self.pushButton_cancer.setText(_translate("Dialog", "取消"))
77 |
78 |
--------------------------------------------------------------------------------
/ui_img/delete.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ui_img/delete.png
--------------------------------------------------------------------------------
/ui_img/detect.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ui_img/detect.JPG
--------------------------------------------------------------------------------
/ui_img/detect_images.qrc:
--------------------------------------------------------------------------------
1 | <RCC>
2 | <qresource prefix="detect_background">
3 | <file>detect.JPG</file>
4 | </qresource>
5 | <qresource prefix="detect_button_background">
6 | <file>upload.png</file>
7 | <file>images.png</file>
8 | <file>delete.png</file>
9 | <file>exit.png</file>
10 | <file>run.png</file>
11 | <file>save.png</file>
12 | </qresource>
13 | </RCC>
14 |
--------------------------------------------------------------------------------
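Each .qrc in this folder is compiled into the matching _rc.py module listed in the tree (e.g. pyrcc5 ui_img/detect_images.qrc -o ui_img/detect_images_rc.py); importing the generated module registers the files with Qt's resource system under their prefixes. A quick sketch to verify that the aliases used by the stylesheets resolve, run from the project root:

from PyQt5.QtCore import QFile
import ui_img.detect_images_rc  # side effect: registers :/detect_background and :/detect_button_background

print(QFile(':/detect_button_background/upload.png').exists())  # expect True
print(QFile(':/detect_background/detect.JPG').exists())  # note: detect_ui.py asks for detect.jpg (lowercase)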
/ui_img/exit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ui_img/exit.png
--------------------------------------------------------------------------------
/ui_img/images.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ui_img/images.png
--------------------------------------------------------------------------------
/ui_img/login.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ui_img/login.JPG
--------------------------------------------------------------------------------
/ui_img/login_images.qrc:
--------------------------------------------------------------------------------
1 | <RCC>
2 | <qresource prefix="login_background">
3 | <file>login.JPG</file>
4 | </qresource>
5 | </RCC>
6 |
--------------------------------------------------------------------------------
/ui_img/registe.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ui_img/registe.JPG
--------------------------------------------------------------------------------
/ui_img/registe_images.qrc:
--------------------------------------------------------------------------------
1 | <RCC>
2 | <qresource prefix="registe_background">
3 | <file>registe.JPG</file>
4 | </qresource>
5 | </RCC>
6 |
--------------------------------------------------------------------------------
/ui_img/run.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ui_img/run.png
--------------------------------------------------------------------------------
/ui_img/save.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ui_img/save.png
--------------------------------------------------------------------------------
/ui_img/upload.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ui_img/upload.png
--------------------------------------------------------------------------------
/ultralytics/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | __version__ = '8.0.55'
4 |
5 | from ultralytics.yolo.engine.model import YOLO
6 | from ultralytics.yolo.utils.checks import check_yolo as checks
7 |
8 | __all__ = '__version__', 'YOLO', 'checks' # allow simpler import
9 |
--------------------------------------------------------------------------------
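The two public entry points re-exported here are enough for a smoke test against the assets bundled in this repo's test/ folder (a minimal sketch, run from the project root):

from ultralytics import YOLO, checks

checks()                            # print environment/dependency info
model = YOLO('test/yolov8n.pt')     # bundled nano detection weights
results = model('test/zidane.jpg')  # returns a list of Results objects
print(results[0].boxes)             # detected boxes (xyxy, confidence, class)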
/ultralytics/hub/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | import requests
4 |
5 | from ultralytics.hub.auth import Auth
6 | from ultralytics.hub.session import HUBTrainingSession
7 | from ultralytics.hub.utils import PREFIX, split_key
8 | from ultralytics.yolo.engine.model import YOLO
9 | from ultralytics.yolo.utils import LOGGER, emojis
10 |
11 |
12 | def start(key=''):
13 | """
14 | Start training models with Ultralytics HUB. Usage: from ultralytics.hub import start; start('API_KEY')
15 | """
16 | auth = Auth(key)
17 | if not auth.get_state():
18 | model_id = request_api_key(auth)
19 | else:
20 | _, model_id = split_key(key)
21 |
22 | if not model_id:
23 | raise ConnectionError(emojis('Connecting with global API key is not currently supported. ❌'))
24 |
25 | session = HUBTrainingSession(model_id=model_id, auth=auth)
26 | session.check_disk_space()
27 |
28 | model = YOLO(model=session.model_file, session=session)
29 | model.train(**session.train_args)
30 |
31 |
32 | def request_api_key(auth, max_attempts=3):
33 | """
34 | Prompt the user to input their API key. Returns the model ID.
35 | """
36 | import getpass
37 | for attempts in range(max_attempts):
38 | LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}')
39 | input_key = getpass.getpass('Enter your Ultralytics HUB API key:\n')
40 | auth.api_key, model_id = split_key(input_key)
41 |
42 | if auth.authenticate():
43 | LOGGER.info(f'{PREFIX}Authenticated ✅')
44 | return model_id
45 |
46 | LOGGER.warning(f'{PREFIX}Invalid API key ⚠️\n')
47 |
48 | raise ConnectionError(emojis(f'{PREFIX}Failed to authenticate ❌'))
49 |
50 |
51 | def reset_model(key=''):
52 | # Reset a trained model to an untrained state
53 | api_key, model_id = split_key(key)
54 | r = requests.post('https://api.ultralytics.com/model-reset', json={'apiKey': api_key, 'modelId': model_id})
55 |
56 | if r.status_code == 200:
57 | LOGGER.info(f'{PREFIX}Model reset successfully')
58 | return
59 | LOGGER.warning(f'{PREFIX}Model reset failure {r.status_code} {r.reason}')
60 |
61 |
62 | def export_fmts_hub():
63 | # Returns a list of HUB-supported export formats
64 | from ultralytics.yolo.engine.exporter import export_formats
65 | return list(export_formats()['Argument'][1:]) + ['ultralytics_tflite', 'ultralytics_coreml']
66 |
67 |
68 | def export_model(key='', format='torchscript'):
69 | # Export a model to all formats
70 | assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
71 | api_key, model_id = split_key(key)
72 | r = requests.post('https://api.ultralytics.com/export',
73 | json={
74 | 'apiKey': api_key,
75 | 'modelId': model_id,
76 | 'format': format})
77 | assert r.status_code == 200, f'{PREFIX}{format} export failure {r.status_code} {r.reason}'
78 | LOGGER.info(f'{PREFIX}{format} export started ✅')
79 |
80 |
81 | def get_export(key='', format='torchscript'):
82 | # Get an exported model dictionary with download URL
83 | assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
84 | api_key, model_id = split_key(key)
85 | r = requests.post('https://api.ultralytics.com/get-export',
86 | json={
87 | 'apiKey': api_key,
88 | 'modelId': model_id,
89 | 'format': format})
90 | assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}'
91 | return r.json()
92 |
93 |
94 | if __name__ == '__main__':
95 | start()
96 |
--------------------------------------------------------------------------------
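export_model() only queues a server-side job (it returns as soon as the request is accepted), while get_export() is called afterwards to retrieve the download URL. A hedged sketch of the round trip: the key below is a placeholder in the 'api_key_model_id' form that split_key() expects, and the fixed wait stands in for whatever polling a real client would do.

import time
from ultralytics.hub import export_model, get_export

KEY = 'API_KEY_MODEL_ID'  # placeholder 'api_key_model_id' pair

export_model(KEY, format='onnx')       # queue the export on the HUB server
time.sleep(60)                         # crude stand-in for polling until the job finishes
info = get_export(KEY, format='onnx')  # dict including the exported model's download URL
print(info)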
/ultralytics/hub/auth.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | import requests
4 |
5 | from ultralytics.hub.utils import HUB_API_ROOT, request_with_credentials
6 | from ultralytics.yolo.utils import is_colab
7 |
8 | API_KEY_PATH = 'https://hub.ultralytics.com/settings?tab=api+keys'
9 |
10 |
11 | class Auth:
12 | id_token = api_key = model_key = False
13 |
14 | def __init__(self, api_key=None):
15 | self.api_key = self._clean_api_key(api_key)
16 | self.authenticate() if self.api_key else self.auth_with_cookies()
17 |
18 | @staticmethod
19 | def _clean_api_key(key: str) -> str:
20 | """Strip model from key if present"""
21 | separator = '_'
22 | return key.split(separator)[0] if separator in key else key
23 |
24 | def authenticate(self) -> bool:
25 | """Attempt to authenticate with server"""
26 | try:
27 | header = self.get_auth_header()
28 | if header:
29 | r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header)
30 | if not r.json().get('success', False):
31 | raise ConnectionError('Unable to authenticate.')
32 | return True
33 | raise ConnectionError('User has not authenticated locally.')
34 | except ConnectionError:
35 | self.id_token = self.api_key = False # reset invalid
36 | return False
37 |
38 | def auth_with_cookies(self) -> bool:
39 | """
40 | Attempt to fetch authentication via cookies and set id_token.
41 | User must be logged in to HUB and running in a supported browser.
42 | """
43 | if not is_colab():
44 | return False # Currently only works with Colab
45 | try:
46 | authn = request_with_credentials(f'{HUB_API_ROOT}/v1/auth/auto')
47 | if authn.get('success', False):
48 | self.id_token = authn.get('data', {}).get('idToken', None)
49 | self.authenticate()
50 | return True
51 | raise ConnectionError('Unable to fetch browser authentication details.')
52 | except ConnectionError:
53 | self.id_token = False # reset invalid
54 | return False
55 |
56 | def get_auth_header(self):
57 | if self.id_token:
58 | return {'authorization': f'Bearer {self.id_token}'}
59 | elif self.api_key:
60 | return {'x-api-key': self.api_key}
61 | else:
62 | return None
63 |
64 | def get_state(self) -> bool:
65 | """Get the authentication state"""
66 | return self.id_token or self.api_key
67 |
68 | def set_api_key(self, key: str):
69 | """Get the authentication state"""
70 | self.api_key = key
71 |
--------------------------------------------------------------------------------
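get_auth_header() yields one of two header shapes depending on which credential is held: a Bearer token after browser-cookie auth, or an x-api-key header for API-key auth. A tiny demonstration, bypassing __init__ so no network request is made:

from ultralytics.hub.auth import Auth

auth = Auth.__new__(Auth)  # skip __init__ to avoid a live authentication attempt

auth.id_token, auth.api_key = 'JWT123', False
print(auth.get_auth_header())  # {'authorization': 'Bearer JWT123'}

auth.id_token, auth.api_key = False, 'abc123'
print(auth.get_auth_header())  # {'x-api-key': 'abc123'}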
/ultralytics/hub/session.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 | import signal
3 | import sys
4 | from pathlib import Path
5 | from time import sleep
6 |
7 | import requests
8 |
9 | from ultralytics.hub.utils import HUB_API_ROOT, check_dataset_disk_space, smart_request
10 | from ultralytics.yolo.utils import LOGGER, PREFIX, __version__, checks, emojis, is_colab, threaded
11 |
12 | AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
13 |
14 |
15 | class HUBTrainingSession:
16 |
17 | def __init__(self, model_id, auth):
18 | self.agent_id = None # identifies which instance is communicating with server
19 | self.model_id = model_id
20 | self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
21 | self.auth_header = auth.get_auth_header()
22 | self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds)
23 | self.timers = {} # rate limit timers (seconds)
24 | self.metrics_queue = {} # metrics queue
25 | self.model = self._get_model()
26 | self.alive = True
27 | self._start_heartbeat() # start heartbeats
28 | self._register_signal_handlers()
29 |
30 | def _register_signal_handlers(self):
31 | signal.signal(signal.SIGTERM, self._handle_signal)
32 | signal.signal(signal.SIGINT, self._handle_signal)
33 |
34 | def _handle_signal(self, signum, frame):
35 | """
36 | Prevent heartbeats from being sent on Colab after kill.
37 | This method does not use frame, it is included as it is
38 | passed by signal.
39 | """
40 | if self.alive is True:
41 | LOGGER.info(f'{PREFIX}Kill signal received! ❌')
42 | self._stop_heartbeat()
43 | sys.exit(signum)
44 |
45 | def _stop_heartbeat(self):
46 | """End the heartbeat loop"""
47 | self.alive = False
48 |
49 | def upload_metrics(self):
50 | payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'}
51 | smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2)
52 |
53 | def _get_model(self):
54 | # Returns model from database by id
55 | api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
56 |
57 | try:
58 | response = smart_request('get', api_url, headers=self.auth_header, thread=False, code=0)
59 | data = response.json().get('data', None)
60 |
61 | if data.get('status', None) == 'trained':
62 | raise ValueError(
63 | emojis(f'Model is already trained and uploaded to '
64 | f'https://hub.ultralytics.com/models/{self.model_id} 🚀'))
65 |
66 | if not data.get('data', None):
67 | raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix
68 | self.model_id = data['id']
69 |
70 | # TODO: restore server keys when dataset URL and GPU training are working
71 |
72 | self.train_args = {
73 | 'batch': data['batch_size'],
74 | 'epochs': data['epochs'],
75 | 'imgsz': data['imgsz'],
76 | 'patience': data['patience'],
77 | 'device': data['device'],
78 | 'cache': data['cache'],
79 | 'data': data['data']}
80 |
81 | self.model_file = data.get('cfg', data['weights'])
82 | self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
83 |
84 | return data
85 | except requests.exceptions.ConnectionError as e:
86 | raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e
87 | except Exception:
88 | raise
89 |
90 | def check_disk_space(self):
91 | if not check_dataset_disk_space(self.model['data']):
92 | raise MemoryError('Not enough disk space')
93 |
94 | def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
95 | # Upload a model to HUB
96 | if Path(weights).is_file():
97 | with open(weights, 'rb') as f:
98 | file = f.read()
99 | else:
100 | LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.')
101 | file = None
102 | url = f'{self.api_url}/upload'
103 | # url = 'http://httpbin.org/post' # for debug
104 | data = {'epoch': epoch}
105 | if final:
106 | data.update({'type': 'final', 'map': map})
107 | smart_request('post',
108 | url,
109 | data=data,
110 | files={'best.pt': file},
111 | headers=self.auth_header,
112 | retry=10,
113 | timeout=3600,
114 | thread=False,
115 | progress=True,
116 | code=4)
117 | else:
118 | data.update({'type': 'epoch', 'isBest': bool(is_best)})
119 | smart_request('post', url, data=data, files={'last.pt': file}, headers=self.auth_header, code=3)
120 |
121 | @threaded
122 | def _start_heartbeat(self):
123 | while self.alive:
124 | r = smart_request('post',
125 | f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
126 | json={
127 | 'agent': AGENT_NAME,
128 | 'agentId': self.agent_id},
129 | headers=self.auth_header,
130 | retry=0,
131 | code=5,
132 | thread=False) # already in a thread
133 | self.agent_id = r.json().get('data', {}).get('agentId', None)
134 | sleep(self.rate_limits['heartbeat'])
135 |
--------------------------------------------------------------------------------
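The rate_limits/timers/metrics_queue trio is what lets training callbacks throttle traffic to HUB: metrics are queued every epoch but flushed at most once per rate-limit window. A sketch of the consuming pattern; the real wiring lives in the callbacks modules (e.g. utils/callbacks/hub.py, not reproduced in this dump):

import time

def on_fit_epoch_end(session, epoch, metrics):
    # Queue this epoch's metrics, then flush only if the rate limit has elapsed.
    session.metrics_queue[epoch] = metrics
    if time.time() - session.timers.get('metrics', 0) > session.rate_limits['metrics']:
        session.upload_metrics()                 # POST the queued metrics to HUB
        session.timers['metrics'] = time.time()  # reset the rate-limit timer
        session.metrics_queue = {}               # clear the queue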
/ultralytics/hub/utils.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | import os
4 | import platform
5 | import shutil
6 | import sys
7 | import threading
8 | import time
9 | from pathlib import Path
10 | from random import random
11 |
12 | import requests
13 | from tqdm import tqdm
14 |
15 | from ultralytics.yolo.utils import (DEFAULT_CFG_DICT, ENVIRONMENT, LOGGER, ONLINE, RANK, SETTINGS, TESTS_RUNNING,
16 | TQDM_BAR_FORMAT, TryExcept, __version__, colorstr, emojis, get_git_origin_url,
17 | is_colab, is_git_dir, is_pip_package)
18 |
19 | PREFIX = colorstr('Ultralytics HUB: ')
20 | HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
21 | HUB_API_ROOT = os.environ.get('ULTRALYTICS_HUB_API', 'https://api.ultralytics.com')
22 |
23 |
24 | def check_dataset_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=2.0):
25 | # Check that url fits on disk with safety factor sf, i.e. require 2GB free if url size is 1GB with sf=2.0
26 | gib = 1 << 30 # bytes per GiB
27 | data = int(requests.head(url).headers['Content-Length']) / gib # dataset size (GiB)
28 | total, used, free = (x / gib for x in shutil.disk_usage('/')) # GiB
29 | LOGGER.info(f'{PREFIX}{data:.3f} GB dataset, {free:.1f}/{total:.1f} GB free disk space')
30 | if data * sf < free:
31 | return True # sufficient space
32 | LOGGER.warning(f'{PREFIX}WARNING: Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, '
33 | f'training cancelled ❌. Please free {data * sf - free:.1f} GB additional disk space and try again.')
34 | return False # insufficient space
35 |
36 |
37 | def request_with_credentials(url: str) -> any:
38 | """ Make an ajax request with cookies attached """
39 | if not is_colab():
40 | raise OSError('request_with_credentials() must run in a Colab environment')
41 | from google.colab import output # noqa
42 | from IPython import display # noqa
43 | display.display(
44 | display.Javascript("""
45 | window._hub_tmp = new Promise((resolve, reject) => {
46 | const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000)
47 | fetch("%s", {
48 | method: 'POST',
49 | credentials: 'include'
50 | })
51 | .then((response) => resolve(response.json()))
52 | .then((json) => {
53 | clearTimeout(timeout);
54 | }).catch((err) => {
55 | clearTimeout(timeout);
56 | reject(err);
57 | });
58 | });
59 | """ % url))
60 | return output.eval_js('_hub_tmp')
61 |
62 |
63 | def split_key(key=''):
64 | """
65 | Verify and split a 'api_key[sep]model_id' string, sep is one of '.' or '_'
66 |
67 | Args:
68 | key (str): The model key to split. If not provided, the user will be prompted to enter it.
69 |
70 | Returns:
71 | Tuple[str, str]: A tuple containing the API key and model ID.
72 | """
73 |
74 | import getpass
75 |
76 | error_string = emojis(f'{PREFIX}Invalid API key ⚠️\n') # error string
77 | if not key:
78 | key = getpass.getpass('Enter model key: ')
79 | sep = '_' if '_' in key else '.' if '.' in key else None # separator
80 | assert sep, error_string
81 | api_key, model_id = key.split(sep)
82 | assert len(api_key) and len(model_id), error_string
83 | return api_key, model_id
84 |
85 |
86 | def requests_with_progress(method, url, **kwargs):
87 | """
88 | Make an HTTP request using the specified method and URL, with an optional progress bar.
89 |
90 | Args:
91 | method (str): The HTTP method to use (e.g. 'GET', 'POST').
92 | url (str): The URL to send the request to.
93 | progress (bool, optional): Whether to display a progress bar. Defaults to False.
94 | **kwargs: Additional keyword arguments to pass to the underlying `requests.request` function.
95 |
96 | Returns:
97 | requests.Response: The response from the HTTP request.
98 |
99 | """
100 | progress = kwargs.pop('progress', False)
101 | if not progress:
102 | return requests.request(method, url, **kwargs)
103 | response = requests.request(method, url, stream=True, **kwargs)
104 | total = int(response.headers.get('content-length', 0)) # total size
105 | pbar = tqdm(total=total, unit='B', unit_scale=True, unit_divisor=1024, bar_format=TQDM_BAR_FORMAT)
106 | for data in response.iter_content(chunk_size=1024):
107 | pbar.update(len(data))
108 | pbar.close()
109 | return response
110 |
111 |
112 | def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbose=True, progress=False, **kwargs):
113 | """
114 | Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
115 |
116 | Args:
117 | method (str): The HTTP method to use for the request. Choices are 'post' and 'get'.
118 | url (str): The URL to make the request to.
119 | retry (int, optional): Number of retries to attempt before giving up. Default is 3.
120 | timeout (int, optional): Timeout in seconds after which the function will give up retrying. Default is 30.
121 | thread (bool, optional): Whether to execute the request in a separate daemon thread. Default is True.
122 | code (int, optional): An identifier for the request, used for logging purposes. Default is -1.
123 | verbose (bool, optional): A flag to determine whether to print out to console or not. Default is True.
124 | progress (bool, optional): Whether to show a progress bar during the request. Default is False.
125 | **kwargs: Keyword arguments to be passed to the requests function specified in method.
126 |
127 | Returns:
128 | requests.Response: The HTTP response object. If the request is executed in a separate thread, returns None.
129 |
130 | """
131 | retry_codes = (408, 500) # retry only these codes
132 |
133 | @TryExcept(verbose=verbose)
134 | def func(func_method, func_url, **func_kwargs):
135 | r = None # response
136 | t0 = time.time() # initial time for timer
137 | for i in range(retry + 1):
138 | if (time.time() - t0) > timeout:
139 | break
140 | r = requests_with_progress(func_method, func_url, **func_kwargs) # i.e. get(url, data, json, files)
141 | if r.status_code == 200:
142 | break
143 | try:
144 | m = r.json().get('message', 'No JSON message.')
145 | except AttributeError:
146 | m = 'Unable to read JSON.'
147 | if i == 0:
148 | if r.status_code in retry_codes:
149 | m += f' Retrying {retry}x for {timeout}s.' if retry else ''
150 | elif r.status_code == 429: # rate limit
151 | h = r.headers # response headers
152 | m = f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " \
153 | f"Please retry after {h['Retry-After']}s."
154 | if verbose:
155 | LOGGER.warning(f'{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})')
156 | if r.status_code not in retry_codes:
157 | return r
158 | time.sleep(2 ** i) # exponential backoff
159 | return r
160 |
161 | args = method, url
162 | kwargs['progress'] = progress
163 | if thread:
164 | threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
165 | else:
166 | return func(*args, **kwargs)
167 |
168 |
169 | class Traces:
170 |
171 | def __init__(self):
172 | """
173 | Initialize Traces for error tracking and reporting if tests are not currently running.
174 | """
175 | self.rate_limit = 3.0 # rate limit (seconds)
176 | self.t = 0.0 # rate limit timer (seconds)
177 | self.metadata = {
178 | 'sys_argv_name': Path(sys.argv[0]).name,
179 | 'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
180 | 'python': platform.python_version(),
181 | 'release': __version__,
182 | 'environment': ENVIRONMENT}
183 | self.enabled = \
184 | SETTINGS['sync'] and \
185 | RANK in (-1, 0) and \
186 | not TESTS_RUNNING and \
187 | ONLINE and \
188 | (is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git')
189 |
190 | def __call__(self, cfg, all_keys=False, traces_sample_rate=1.0):
191 | """
192 | Sync traces data if enabled in the global settings
193 |
194 | Args:
195 | cfg (IterableSimpleNamespace): Configuration for the task and mode.
196 | all_keys (bool): Sync all items, not just non-default values.
197 | traces_sample_rate (float): Fraction of traces captured from 0.0 to 1.0
198 | """
199 | t = time.time() # current time
200 | if self.enabled and random() < traces_sample_rate and (t - self.t) > self.rate_limit:
201 | self.t = t # reset rate limit timer
202 | cfg = vars(cfg) # convert type from IterableSimpleNamespace to dict
203 | if not all_keys: # filter cfg
204 | include_keys = {'task', 'mode'} # always include
205 | cfg = {
206 | k: (v.split(os.sep)[-1] if isinstance(v, str) and os.sep in v else v)
207 | for k, v in cfg.items() if v != DEFAULT_CFG_DICT.get(k, None) or k in include_keys}
208 | trace = {'uuid': SETTINGS['uuid'], 'cfg': cfg, 'metadata': self.metadata}
209 |
210 | # Send a request to the HUB API to sync analytics
211 | smart_request('post', f'{HUB_API_ROOT}/v1/usage/anonymous', json=trace, code=3, retry=0, verbose=False)
212 |
213 |
214 | # Run below code on hub/utils init -------------------------------------------------------------------------------------
215 | traces = Traces()
216 |
--------------------------------------------------------------------------------
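smart_request() sleeps 2**i seconds after attempt i, so retries back off exponentially, and the timeout check at the top of the loop caps the total time spent. With the default retry=3 the sleep schedule is 1, 2, 4 (and 8 after the final attempt), as the short computation below shows:

def backoff_schedule(retry=3):
    # smart_request sleeps 2**i seconds after attempt i = 0..retry
    return [2 ** i for i in range(retry + 1)]

print(backoff_schedule())       # [1, 2, 4, 8]
print(sum(backoff_schedule()))  # 15 seconds of sleep at most, further capped by `timeout`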
/ultralytics/nn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ultralytics/nn/__init__.py
--------------------------------------------------------------------------------
/ultralytics/yolo/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | from . import v8
4 |
5 | __all__ = 'v8', # tuple or list
6 |
--------------------------------------------------------------------------------
/ultralytics/yolo/cfg/default.yaml:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 | # Default training settings and hyperparameters for medium-augmentation COCO training
3 |
4 | task: detect # inference task, i.e. detect, segment, classify
5 | mode: train # YOLO mode, i.e. train, val, predict, export
6 |
7 | # Train settings -------------------------------------------------------------------------------------------------------
8 | model: # path to model file, i.e. yolov8n.pt, yolov8n.yaml
9 | data: # path to data file, i.e. coco128.yaml
10 | epochs: 100 # number of epochs to train for
11 | patience: 50 # epochs to wait for no observable improvement for early stopping of training
12 | batch: 16 # number of images per batch (-1 for AutoBatch)
13 | imgsz: 640 # size of input images as integer or w,h
14 | save: True # save train checkpoints and predict results
15 | save_period: -1 # Save checkpoint every x epochs (disabled if < 1)
16 | cache: False # True/ram, disk or False. Use cache for data loading
17 | device: # device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu
18 | workers: 8 # number of worker threads for data loading (per RANK if DDP)
19 | project: # project name
20 | name: # experiment name, results saved to 'project/name' directory
21 | exist_ok: False # whether to overwrite existing experiment
22 | pretrained: False # whether to use a pretrained model
23 | optimizer: SGD # optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
24 | verbose: True # whether to print verbose output
25 | seed: 0 # random seed for reproducibility
26 | deterministic: True # whether to enable deterministic mode
27 | single_cls: False # train multi-class data as single-class
28 | image_weights: False # use weighted image selection for training
29 | rect: False # support rectangular training if mode='train', support rectangular evaluation if mode='val'
30 | cos_lr: False # use cosine learning rate scheduler
31 | close_mosaic: 10 # disable mosaic augmentation for final 10 epochs
32 | resume: False # resume training from last checkpoint
33 | # Segmentation
34 | overlap_mask: True # masks should overlap during training (segment train only)
35 | mask_ratio: 4 # mask downsample ratio (segment train only)
36 | # Classification
37 | dropout: 0.0 # use dropout regularization (classify train only)
38 |
39 | # Val/Test settings ----------------------------------------------------------------------------------------------------
40 | val: True # validate/test during training
41 | split: val # dataset split to use for validation, i.e. 'val', 'test' or 'train'
42 | save_json: False # save results to JSON file
43 | save_hybrid: False # save hybrid version of labels (labels + additional predictions)
44 | conf: # object confidence threshold for detection (default 0.25 predict, 0.001 val)
45 | iou: 0.7 # intersection over union (IoU) threshold for NMS
46 | max_det: 300 # maximum number of detections per image
47 | half: False # use half precision (FP16)
48 | dnn: False # use OpenCV DNN for ONNX inference
49 | plots: True # save plots during train/val
50 |
51 | # Prediction settings --------------------------------------------------------------------------------------------------
52 | source: # source directory for images or videos
53 | show: False # show results if possible
54 | save_txt: False # save results as .txt file
55 | save_conf: False # save results with confidence scores
56 | save_crop: False # save cropped images with results
57 | hide_labels: False # hide labels
58 | hide_conf: False # hide confidence scores
59 | vid_stride: 1 # video frame-rate stride
60 | line_thickness: 3 # bounding box thickness (pixels)
61 | visualize: False # visualize model features
62 | augment: False # apply image augmentation to prediction sources
63 | agnostic_nms: False # class-agnostic NMS
64 | classes: # filter results by class, i.e. class=0, or class=[0,2,3]
65 | retina_masks: False # use high-resolution segmentation masks
66 | boxes: True # Show boxes in segmentation predictions
67 |
68 | # Export settings ------------------------------------------------------------------------------------------------------
69 | format: torchscript # format to export to
70 | keras: False # use Keras
71 | optimize: False # TorchScript: optimize for mobile
72 | int8: False # CoreML/TF INT8 quantization
73 | dynamic: False # ONNX/TF/TensorRT: dynamic axes
74 | simplify: False # ONNX: simplify model
75 | opset: # ONNX: opset version (optional)
76 | workspace: 4 # TensorRT: workspace size (GB)
77 | nms: False # CoreML: add NMS
78 |
79 | # Hyperparameters ------------------------------------------------------------------------------------------------------
80 | lr0: 0.01 # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
81 | lrf: 0.01 # final learning rate (lr0 * lrf)
82 | momentum: 0.937 # SGD momentum/Adam beta1
83 | weight_decay: 0.0005 # optimizer weight decay 5e-4
84 | warmup_epochs: 3.0 # warmup epochs (fractions ok)
85 | warmup_momentum: 0.8 # warmup initial momentum
86 | warmup_bias_lr: 0.1 # warmup initial bias lr
87 | box: 7.5 # box loss gain
88 | cls: 0.5 # cls loss gain (scale with pixels)
89 | dfl: 1.5 # dfl loss gain
90 | fl_gamma: 0.0 # focal loss gamma (EfficientDet default gamma=1.5)
91 | label_smoothing: 0.0 # label smoothing (fraction)
92 | nbs: 64 # nominal batch size
93 | hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
94 | hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
95 | hsv_v: 0.4 # image HSV-Value augmentation (fraction)
96 | degrees: 0.0 # image rotation (+/- deg)
97 | translate: 0.1 # image translation (+/- fraction)
98 | scale: 0.5 # image scale (+/- gain)
99 | shear: 0.0 # image shear (+/- deg)
100 | perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
101 | flipud: 0.0 # image flip up-down (probability)
102 | fliplr: 0.5 # image flip left-right (probability)
103 | mosaic: 1.0 # image mosaic (probability)
104 | mixup: 0.0 # image mixup (probability)
105 | copy_paste: 0.0 # segment copy-paste (probability)
106 |
107 | # Custom config.yaml ---------------------------------------------------------------------------------------------------
108 | cfg: # for overriding default.yaml
109 |
110 | # Debug, do not modify -------------------------------------------------------------------------------------------------
111 | v5loader: False # use legacy YOLOv5 dataloader
112 |
113 | # Tracker settings ------------------------------------------------------------------------------------------------------
114 | tracker: botsort.yaml # tracker type, ['botsort.yaml', 'bytetrack.yaml']
115 |
--------------------------------------------------------------------------------
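
Note: everything in default.yaml can be overridden at call time. A minimal sketch, assuming the bundled ultralytics package is importable and that 'yolov8n.pt' (shipped with this repo) and the example dataset 'coco128.yaml' resolve on your machine:

    from ultralytics import YOLO

    model = YOLO('yolov8n.pt')                               # weights included in this repo
    model.train(data='coco128.yaml', epochs=100, imgsz=640)  # keyword args override the YAML defaults above

Unset keys (model, data, conf, ...) keep their YAML defaults; explicit keyword arguments win.
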
/ultralytics/yolo/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | from .base import BaseDataset
4 | from .build import build_classification_dataloader, build_dataloader, load_inference_source
5 | from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
6 | from .dataset_wrappers import MixAndRectDataset
7 |
8 | __all__ = ('BaseDataset', 'ClassificationDataset', 'MixAndRectDataset', 'SemanticDataset', 'YOLODataset',
9 | 'build_classification_dataloader', 'build_dataloader', 'load_inference_source')
10 |
--------------------------------------------------------------------------------
/ultralytics/yolo/data/base.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | import glob
4 | import math
5 | import os
6 | from multiprocessing.pool import ThreadPool
7 | from pathlib import Path
8 | from typing import Optional
9 |
10 | import cv2
11 | import numpy as np
12 | from torch.utils.data import Dataset
13 | from tqdm import tqdm
14 |
15 | from ..utils import NUM_THREADS, TQDM_BAR_FORMAT
16 | from .utils import HELP_URL, IMG_FORMATS, LOCAL_RANK
17 |
18 |
19 | class BaseDataset(Dataset):
20 | """Base Dataset.
21 | Args:
22 | img_path (str): image path.
23 | pipeline (dict): a dict of image transforms.
24 | label_path (str): label path, this can also be an ann_file or other custom label path.
25 | """
26 |
27 | def __init__(self,
28 | img_path,
29 | imgsz=640,
30 | cache=False,
31 | augment=True,
32 | hyp=None,
33 | prefix='',
34 | rect=False,
35 | batch_size=None,
36 | stride=32,
37 | pad=0.5,
38 | single_cls=False,
39 | classes=None):
40 | super().__init__()
41 | self.img_path = img_path
42 | self.imgsz = imgsz
43 | self.augment = augment
44 | self.single_cls = single_cls
45 | self.prefix = prefix
46 |
47 | self.im_files = self.get_img_files(self.img_path)
48 | self.labels = self.get_labels()
49 | self.update_labels(include_class=classes) # single_cls and include_class
50 |
51 | self.ni = len(self.labels)
52 |
53 | # rect stuff
54 | self.rect = rect
55 | self.batch_size = batch_size
56 | self.stride = stride
57 | self.pad = pad
58 | if self.rect:
59 | assert self.batch_size is not None
60 | self.set_rectangle()
61 |
62 | # cache stuff
63 | self.ims = [None] * self.ni
64 | self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
65 | if cache:
66 | self.cache_images(cache)
67 |
68 | # transforms
69 | self.transforms = self.build_transforms(hyp=hyp)
70 |
71 | def get_img_files(self, img_path):
72 | """Read image files."""
73 | try:
74 | f = [] # image files
75 | for p in img_path if isinstance(img_path, list) else [img_path]:
76 | p = Path(p) # os-agnostic
77 | if p.is_dir(): # dir
78 | f += glob.glob(str(p / '**' / '*.*'), recursive=True)
79 | # f = list(p.rglob('*.*')) # pathlib
80 | elif p.is_file(): # file
81 | with open(p) as t:
82 | t = t.read().strip().splitlines()
83 | parent = str(p.parent) + os.sep
84 | f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
85 | # f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
86 | else:
87 | raise FileNotFoundError(f'{self.prefix}{p} does not exist')
88 | im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
89 | # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
90 | assert im_files, f'{self.prefix}No images found'
91 | except Exception as e:
92 | raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e
93 | return im_files
94 |
95 | def update_labels(self, include_class: Optional[list]):
96 | """include_class, filter labels to include only these classes (optional)"""
97 | include_class_array = np.array(include_class).reshape(1, -1)
98 | for i in range(len(self.labels)):
99 | if include_class is not None:
100 | cls = self.labels[i]['cls']
101 | bboxes = self.labels[i]['bboxes']
102 | segments = self.labels[i]['segments']
103 | j = (cls == include_class_array).any(1)
104 | self.labels[i]['cls'] = cls[j]
105 | self.labels[i]['bboxes'] = bboxes[j]
106 | if segments:
107 | self.labels[i]['segments'] = [segments[si] for si, idx in enumerate(j) if idx]
108 | if self.single_cls:
109 | self.labels[i]['cls'][:, 0] = 0
110 |
111 | def load_image(self, i):
112 | # Loads 1 image from dataset index 'i', returns (im, original hw, resized hw)
113 | im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
114 | if im is None: # not cached in RAM
115 | if fn.exists(): # load npy
116 | im = np.load(fn)
117 | else: # read image
118 | im = cv2.imread(f) # BGR
119 | if im is None:
120 | raise FileNotFoundError(f'Image Not Found {f}')
121 | h0, w0 = im.shape[:2] # orig hw
122 | r = self.imgsz / max(h0, w0) # ratio
123 | if r != 1: # if sizes are not equal
124 | interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
125 | im = cv2.resize(im, (math.ceil(w0 * r), math.ceil(h0 * r)), interpolation=interp)
126 | return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
127 | return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
128 |
129 | def cache_images(self, cache):
130 | # cache images to memory or disk
131 | gb = 0 # Gigabytes of cached images
132 | self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
133 | fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
134 | with ThreadPool(NUM_THREADS) as pool:
135 | results = pool.imap(fcn, range(self.ni))
136 | pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
137 | for i, x in pbar:
138 | if cache == 'disk':
139 | gb += self.npy_files[i].stat().st_size
140 | else: # 'ram'
141 | self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
142 | gb += self.ims[i].nbytes
143 | pbar.desc = f'{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})'
144 | pbar.close()
145 |
146 | def cache_images_to_disk(self, i):
147 | # Saves an image as an *.npy file for faster loading
148 | f = self.npy_files[i]
149 | if not f.exists():
150 | np.save(f.as_posix(), cv2.imread(self.im_files[i]))
151 |
152 | def set_rectangle(self):
153 | bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
154 | nb = bi[-1] + 1 # number of batches
155 |
156 | s = np.array([x.pop('shape') for x in self.labels]) # hw
157 | ar = s[:, 0] / s[:, 1] # aspect ratio
158 | irect = ar.argsort()
159 | self.im_files = [self.im_files[i] for i in irect]
160 | self.labels = [self.labels[i] for i in irect]
161 | ar = ar[irect]
162 |
163 | # Set training image shapes
164 | shapes = [[1, 1]] * nb
165 | for i in range(nb):
166 | ari = ar[bi == i]
167 | mini, maxi = ari.min(), ari.max()
168 | if maxi < 1:
169 | shapes[i] = [maxi, 1]
170 | elif mini > 1:
171 | shapes[i] = [1, 1 / mini]
172 |
173 | self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
174 | self.batch = bi # batch index of image
175 |
176 | def __getitem__(self, index):
177 | return self.transforms(self.get_label_info(index))
178 |
179 | def get_label_info(self, index):
180 | label = self.labels[index].copy()
181 | label.pop('shape', None) # shape is for rect, remove it
182 | label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)
183 | label['ratio_pad'] = (
184 | label['resized_shape'][0] / label['ori_shape'][0],
185 | label['resized_shape'][1] / label['ori_shape'][1],
186 | ) # for evaluation
187 | if self.rect:
188 | label['rect_shape'] = self.batch_shapes[self.batch[index]]
189 | label = self.update_labels_info(label)
190 | return label
191 |
192 | def __len__(self):
193 | return len(self.labels)
194 |
195 | def update_labels_info(self, label):
196 | """custom your label format here"""
197 | return label
198 |
199 | def build_transforms(self, hyp=None):
200 | """Users can custom augmentations here
201 | like:
202 | if self.augment:
203 | # training transforms
204 | return Compose([])
205 | else:
206 | # val transforms
207 | return Compose([])
208 | """
209 | raise NotImplementedError
210 |
211 | def get_labels(self):
212 | """Users can custom their own format here.
213 | Make sure your output is a list with each element like below:
214 | dict(
215 | im_file=im_file,
216 | shape=shape, # format: (height, width)
217 | cls=cls,
218 | bboxes=bboxes, # xywh
219 | segments=segments, # xy
220 | keypoints=keypoints, # xy
221 | normalized=True, # or False
222 | bbox_format="xyxy", # or xywh, ltwh
223 | )
224 | """
225 | raise NotImplementedError
226 |
--------------------------------------------------------------------------------
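
BaseDataset deliberately leaves get_labels() and build_transforms() unimplemented. Below is a minimal, hypothetical subclass (ToyDataset, with placeholder shapes and empty boxes) showing the label dict format documented in get_labels() and an identity transform; instantiating it needs a real image directory, e.g. the repo's ui_img folder:

    import numpy as np

    from ultralytics.yolo.data.base import BaseDataset

    class ToyDataset(BaseDataset):
        def get_labels(self):
            # one dict per image, in the format documented in get_labels() above
            return [dict(im_file=f, shape=(640, 640), cls=np.zeros((0, 1)),
                         bboxes=np.zeros((0, 4)), segments=[], keypoints=None,
                         normalized=True, bbox_format='xywh') for f in self.im_files]

        def build_transforms(self, hyp=None):
            return lambda labels: labels  # identity; real subclasses return a Compose

    # e.g. ds = ToyDataset('ui_img', augment=False); sample = ds[0]
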
/ultralytics/yolo/data/build.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | import os
4 | import random
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | import torch
9 | from PIL import Image
10 | from torch.utils.data import DataLoader, dataloader, distributed
11 |
12 | from ultralytics.yolo.data.dataloaders.stream_loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots,
13 | LoadStreams, LoadTensor, SourceTypes, autocast_list)
14 | from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
15 | from ultralytics.yolo.utils.checks import check_file
16 |
17 | from ..utils import LOGGER, colorstr
18 | from ..utils.torch_utils import torch_distributed_zero_first
19 | from .dataset import ClassificationDataset, YOLODataset
20 | from .utils import PIN_MEMORY, RANK
21 |
22 |
23 | class InfiniteDataLoader(dataloader.DataLoader):
24 | """Dataloader that reuses workers
25 |
26 | Uses same syntax as vanilla DataLoader
27 | """
28 |
29 | def __init__(self, *args, **kwargs):
30 | super().__init__(*args, **kwargs)
31 | object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
32 | self.iterator = super().__iter__()
33 |
34 | def __len__(self):
35 | return len(self.batch_sampler.sampler)
36 |
37 | def __iter__(self):
38 | for _ in range(len(self)):
39 | yield next(self.iterator)
40 |
41 |
42 | class _RepeatSampler:
43 | """Sampler that repeats forever
44 |
45 | Args:
46 | sampler (Sampler)
47 | """
48 |
49 | def __init__(self, sampler):
50 | self.sampler = sampler
51 |
52 | def __iter__(self):
53 | while True:
54 | yield from iter(self.sampler)
55 |
56 |
57 | def seed_worker(worker_id): # noqa
58 | # Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader
59 | worker_seed = torch.initial_seed() % 2 ** 32
60 | np.random.seed(worker_seed)
61 | random.seed(worker_seed)
62 |
63 |
64 | def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, rank=-1, mode='train'):
65 | assert mode in ['train', 'val']
66 | shuffle = mode == 'train'
67 | if cfg.rect and shuffle:
68 | LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
69 | shuffle = False
70 | with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
71 | dataset = YOLODataset(
72 | img_path=img_path,
73 | imgsz=cfg.imgsz,
74 | batch_size=batch,
75 | augment=mode == 'train', # augmentation
76 | hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
77 | rect=cfg.rect or rect, # rectangular batches
78 | cache=cfg.cache or None,
79 | single_cls=cfg.single_cls or False,
80 | stride=int(stride),
81 | pad=0.0 if mode == 'train' else 0.5,
82 | prefix=colorstr(f'{mode}: '),
83 | use_segments=cfg.task == 'segment',
84 | use_keypoints=cfg.task == 'keypoint',
85 | names=names,
86 | classes=cfg.classes)
87 |
88 | batch = min(batch, len(dataset))
89 | nd = torch.cuda.device_count() # number of CUDA devices
90 | workers = cfg.workers if mode == 'train' else cfg.workers * 2
91 | nw = min([os.cpu_count() // max(nd, 1), batch if batch > 1 else 0, workers]) # number of workers
92 | sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
93 | loader = DataLoader if cfg.image_weights or cfg.close_mosaic else InfiniteDataLoader # allow attribute updates
94 | generator = torch.Generator()
95 | generator.manual_seed(6148914691236517205 + RANK)
96 | return loader(dataset=dataset,
97 | batch_size=batch,
98 | shuffle=shuffle and sampler is None,
99 | num_workers=nw,
100 | sampler=sampler,
101 | pin_memory=PIN_MEMORY,
102 | collate_fn=getattr(dataset, 'collate_fn', None),
103 | worker_init_fn=seed_worker,
104 | generator=generator), dataset
105 |
106 |
107 | # build classification
108 | # TODO: using cfg like `build_dataloader`
109 | def build_classification_dataloader(path,
110 | imgsz=224,
111 | batch_size=16,
112 | augment=True,
113 | cache=False,
114 | rank=-1,
115 | workers=8,
116 | shuffle=True):
117 | # Returns a DataLoader to be used with the YOLO classifier
118 | with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
119 | dataset = ClassificationDataset(root=path, imgsz=imgsz, augment=augment, cache=cache)
120 | batch_size = min(batch_size, len(dataset))
121 | nd = torch.cuda.device_count()
122 | nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
123 | sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
124 | generator = torch.Generator()
125 | generator.manual_seed(6148914691236517205 + RANK)
126 | return InfiniteDataLoader(dataset,
127 | batch_size=batch_size,
128 | shuffle=shuffle and sampler is None,
129 | num_workers=nw,
130 | sampler=sampler,
131 | pin_memory=PIN_MEMORY,
132 | worker_init_fn=seed_worker,
133 | generator=generator) # or DataLoader(persistent_workers=True)
134 |
135 |
136 | def check_source(source):
137 | webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
138 | if isinstance(source, (str, int, Path)): # int for local usb camera
139 | source = str(source)
140 | is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
141 | is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
142 | webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
143 | screenshot = source.lower().startswith('screen')
144 | if is_url and is_file:
145 | source = check_file(source) # download
146 | elif isinstance(source, tuple(LOADERS)):
147 | in_memory = True
148 | elif isinstance(source, (list, tuple)):
149 | source = autocast_list(source) # convert all list elements to PIL or np arrays
150 | from_img = True
151 | elif isinstance(source, (Image.Image, np.ndarray)):
152 | from_img = True
153 | elif isinstance(source, torch.Tensor):
154 | tensor = True
155 | else:
156 | raise TypeError('Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict')
157 |
158 | return source, webcam, screenshot, from_img, in_memory, tensor
159 |
160 |
161 | def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1, stride=32, auto=True):
162 | """
163 | TODO: docs
164 | """
165 | source, webcam, screenshot, from_img, in_memory, tensor = check_source(source)
166 | source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img, tensor)
167 |
168 | # Dataloader
169 | if tensor:
170 | dataset = LoadTensor(source)
171 | elif in_memory:
172 | dataset = source
173 | elif webcam:
174 | dataset = LoadStreams(source,
175 | imgsz=imgsz,
176 | stride=stride,
177 | auto=auto,
178 | transforms=transforms,
179 | vid_stride=vid_stride)
180 |
181 | elif screenshot:
182 | dataset = LoadScreenshots(source, imgsz=imgsz, stride=stride, auto=auto, transforms=transforms)
183 | elif from_img:
184 | dataset = LoadPilAndNumpy(source, imgsz=imgsz, stride=stride, auto=auto, transforms=transforms)
185 | else:
186 | dataset = LoadImages(source,
187 | imgsz=imgsz,
188 | stride=stride,
189 | auto=auto,
190 | transforms=transforms,
191 | vid_stride=vid_stride)
192 |
193 | setattr(dataset, 'source_type', source_type) # attach source types
194 |
195 | return dataset
196 |
--------------------------------------------------------------------------------
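
check_source() routes a plain path, URL, stream, array, or tensor to the matching loader. A sketch, assuming the bundled zidane.jpg test image is reachable from the working directory (the printed SourceTypes repr may differ slightly by version):

    from ultralytics.yolo.data.build import load_inference_source

    dataset = load_inference_source(source='zidane.jpg', imgsz=640, stride=32)
    print(dataset.source_type)  # e.g. SourceTypes(webcam=False, screenshot=False, from_img=False, tensor=False)
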
/ultralytics/yolo/data/dataloaders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ultralytics/yolo/data/dataloaders/__init__.py
--------------------------------------------------------------------------------
/ultralytics/yolo/data/dataset_wrappers.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | import collections
4 | from copy import deepcopy
5 |
6 | from .augment import LetterBox
7 |
8 |
9 | class MixAndRectDataset:
10 | """A wrapper of multiple images mixed dataset.
11 |
12 | Args:
13 | dataset (:obj:`BaseDataset`): The dataset to be mixed.
14 | transforms (Sequence[dict]): config dict to be composed.
15 | """
16 |
17 | def __init__(self, dataset):
18 | self.dataset = dataset
19 | self.imgsz = dataset.imgsz
20 |
21 | def __len__(self):
22 | return len(self.dataset)
23 |
24 | def __getitem__(self, index):
25 | labels = deepcopy(self.dataset[index])
26 | for transform in self.dataset.transforms.tolist():
27 | # mosaic and mixup
28 | if hasattr(transform, 'get_indexes'):
29 | indexes = transform.get_indexes(self.dataset)
30 | if not isinstance(indexes, collections.abc.Sequence):
31 | indexes = [indexes]
32 | mix_labels = [deepcopy(self.dataset[index]) for index in indexes]
33 | labels['mix_labels'] = mix_labels
34 | if self.dataset.rect and isinstance(transform, LetterBox):
35 | transform.new_shape = self.dataset.batch_shapes[self.dataset.batch[index]]
36 | labels = transform(labels)
37 | if 'mix_labels' in labels:
38 | labels.pop('mix_labels')
39 | return labels
40 |
--------------------------------------------------------------------------------
/ultralytics/yolo/engine/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ultralytics/yolo/engine/__init__.py
--------------------------------------------------------------------------------
/ultralytics/yolo/utils/autobatch.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 | """
3 | Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch.
4 | """
5 |
6 | from copy import deepcopy
7 |
8 | import numpy as np
9 | import torch
10 |
11 | from ultralytics.yolo.utils import LOGGER, colorstr
12 | from ultralytics.yolo.utils.torch_utils import profile
13 |
14 |
15 | def check_train_batch_size(model, imgsz=640, amp=True):
16 | """
17 | Check YOLO training batch size using the autobatch() function.
18 |
19 | Args:
20 | model (torch.nn.Module): YOLO model to check batch size for.
21 | imgsz (int): Image size used for training.
22 | amp (bool): If True, use automatic mixed precision (AMP) for training.
23 |
24 | Returns:
25 | int: Optimal batch size computed using the autobatch() function.
26 | """
27 |
28 | with torch.cuda.amp.autocast(amp):
29 | return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
30 |
31 |
32 | def autobatch(model, imgsz=640, fraction=0.67, batch_size=16):
33 | """
34 | Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.
35 |
36 | Args:
37 | model: YOLO model to compute batch size for.
38 | imgsz (int, optional): The image size used as input for the YOLO model. Defaults to 640.
39 | fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.67.
40 | batch_size (int, optional): The default batch size to use if an error is detected. Defaults to 16.
41 |
42 | Returns:
43 | int: The optimal batch size.
44 | """
45 |
46 | # Check device
47 | prefix = colorstr('AutoBatch: ')
48 | LOGGER.info(f'{prefix}Computing optimal batch size for imgsz={imgsz}')
49 | device = next(model.parameters()).device # get model device
50 | if device.type == 'cpu':
51 | LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
52 | return batch_size
53 | if torch.backends.cudnn.benchmark:
54 | LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}')
55 | return batch_size
56 |
57 | # Inspect CUDA memory
58 | gb = 1 << 30 # bytes to GiB (1024 ** 3)
59 | d = str(device).upper() # 'CUDA:0'
60 | properties = torch.cuda.get_device_properties(device) # device properties
61 | t = properties.total_memory / gb # GiB total
62 | r = torch.cuda.memory_reserved(device) / gb # GiB reserved
63 | a = torch.cuda.memory_allocated(device) / gb # GiB allocated
64 | f = t - (r + a) # GiB free
65 | LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free')
66 |
67 | # Profile batch sizes
68 | batch_sizes = [1, 2, 4, 8, 16]
69 | try:
70 | img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
71 | results = profile(img, model, n=3, device=device)
72 |
73 | # Fit a solution
74 | y = [x[2] for x in results if x] # memory [2]
75 | p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit
76 | b = int((f * fraction - p[1]) / p[0]) # solve fit for batch size at target memory fraction
77 | if None in results: # some sizes failed
78 | i = results.index(None) # first fail index
79 | if b >= batch_sizes[i]: # y intercept above failure point
80 | b = batch_sizes[max(i - 1, 0)] # select prior safe point
81 | if b < 1 or b > 1024: # b outside of safe range
82 | b = batch_size
83 | LOGGER.info(f'{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.')
84 |
85 | fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted
86 | LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
87 | return b
88 | except Exception as e:
89 | LOGGER.warning(f'{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.')
90 | return batch_size
91 |
--------------------------------------------------------------------------------
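
The core of autobatch() is a first-degree polynomial fit of memory vs. batch size, solved for the batch that hits the target fraction of free memory. The same arithmetic, standalone, with illustrative numbers only:

    import numpy as np

    batch_sizes = [1, 2, 4, 8, 16]
    mem_gib = [0.5, 0.9, 1.7, 3.3, 6.5]          # pretend profiled usage: ~0.4 GiB/image + 0.1 GiB overhead
    p = np.polyfit(batch_sizes, mem_gib, deg=1)  # p[0] = slope (GiB/image), p[1] = intercept
    f, fraction = 10.0, 0.67                     # pretend 10 GiB free, target 67% usage
    b = int((f * fraction - p[1]) / p[0])        # same solve as in autobatch() above
    print(b)                                     # -> 16 with these numbers
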
/ultralytics/yolo/utils/benchmarks.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 | """
3 | Benchmark YOLO model formats for speed and accuracy
4 |
5 | Usage:
6 | from ultralytics.yolo.utils.benchmarks import benchmark
7 | benchmark(model='yolov8n.pt', imgsz=160)
8 |
9 | Format | `format=argument` | Model
10 | --- | --- | ---
11 | PyTorch | - | yolov8n.pt
12 | TorchScript | `torchscript` | yolov8n.torchscript
13 | ONNX | `onnx` | yolov8n.onnx
14 | OpenVINO | `openvino` | yolov8n_openvino_model/
15 | TensorRT | `engine` | yolov8n.engine
16 | CoreML | `coreml` | yolov8n.mlmodel
17 | TensorFlow SavedModel | `saved_model` | yolov8n_saved_model/
18 | TensorFlow GraphDef | `pb` | yolov8n.pb
19 | TensorFlow Lite | `tflite` | yolov8n.tflite
20 | TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite
21 | TensorFlow.js | `tfjs` | yolov8n_web_model/
22 | PaddlePaddle | `paddle` | yolov8n_paddle_model/
23 | """
24 |
25 | import platform
26 | import time
27 | from pathlib import Path
28 |
29 | from ultralytics import YOLO
30 | from ultralytics.yolo.engine.exporter import export_formats
31 | from ultralytics.yolo.utils import LINUX, LOGGER, MACOS, ROOT, SETTINGS
32 | from ultralytics.yolo.utils.checks import check_yolo
33 | from ultralytics.yolo.utils.downloads import download
34 | from ultralytics.yolo.utils.files import file_size
35 | from ultralytics.yolo.utils.torch_utils import select_device
36 |
37 |
38 | def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', imgsz=160, half=False, device='cpu', hard_fail=False):
39 | import pandas as pd
40 | pd.options.display.max_columns = 10
41 | pd.options.display.width = 120
42 | device = select_device(device, verbose=False)
43 | if isinstance(model, (str, Path)):
44 | model = YOLO(model)
45 |
46 | y = []
47 | t0 = time.time()
48 | for i, (name, format, suffix, cpu, gpu) in export_formats().iterrows(): # index, (name, format, suffix, CPU, GPU)
49 | emoji, filename = '❌', None # export defaults
50 | try:
51 | if model.task == 'classify':
52 | assert i != 11, 'paddle cls exports coming soon'
53 | assert i != 9 or LINUX, 'Edge TPU export only supported on Linux'
54 | if i == 10:
55 | assert MACOS or LINUX, 'TF.js export only supported on macOS and Linux'
56 | if 'cpu' in device.type:
57 | assert cpu, 'inference not supported on CPU'
58 | if 'cuda' in device.type:
59 | assert gpu, 'inference not supported on GPU'
60 |
61 | # Export
62 | if format == '-':
63 | filename = model.ckpt_path or model.cfg
64 | export = model # PyTorch format
65 | else:
66 | filename = model.export(imgsz=imgsz, format=format, half=half, device=device) # all others
67 | export = YOLO(filename, task=model.task)
68 | assert suffix in str(filename), 'export failed'
69 | emoji = '❎' # indicates export succeeded
70 |
71 | # Predict
72 | assert i not in (9, 10), 'inference not supported' # Edge TPU and TF.js are unsupported
73 | assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13' # CoreML
74 | if not (ROOT / 'assets/bus.jpg').exists():
75 | download(url='https://ultralytics.com/images/bus.jpg', dir=ROOT / 'assets')
76 | export.predict(ROOT / 'assets/bus.jpg', imgsz=imgsz, device=device, half=half)
77 |
78 | # Validate
79 | if model.task == 'detect':
80 | data, key = 'coco128.yaml', 'metrics/mAP50-95(B)'
81 | elif model.task == 'segment':
82 | data, key = 'coco128-seg.yaml', 'metrics/mAP50-95(M)'
83 | elif model.task == 'classify':
84 | data, key = 'imagenet100', 'metrics/accuracy_top5'
85 |
86 | results = export.val(data=data, batch=1, imgsz=imgsz, plots=False, device=device, half=half, verbose=False)
87 | metric, speed = results.results_dict[key], results.speed['inference']
88 | y.append([name, '✅', round(file_size(filename), 1), round(metric, 4), round(speed, 2)])
89 | except Exception as e:
90 | if hard_fail:
91 | assert type(e) is AssertionError, f'Benchmark hard_fail for {name}: {e}'
92 | LOGGER.warning(f'ERROR ❌️ Benchmark failure for {name}: {e}')
93 | y.append([name, emoji, round(file_size(filename), 1), None, None]) # mAP, t_inference
94 |
95 | # Print results
96 | check_yolo(device=device) # print system info
97 | df = pd.DataFrame(y, columns=['Format', 'Status❔', 'Size (MB)', key, 'Inference time (ms/im)'])
98 |
99 | name = Path(model.ckpt_path).name
100 | s = f'\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n'
101 | LOGGER.info(s)
102 | with open('benchmarks.log', 'a', errors='ignore', encoding='utf-8') as f:
103 | f.write(s)
104 |
105 | if hard_fail and isinstance(hard_fail, float):
106 | metrics = df[key].array # values to compare to floor
107 | floor = hard_fail # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
108 | assert all(x > floor for x in metrics if pd.notna(x)), f'HARD FAIL: one or more metric(s) < floor {floor}'
109 |
110 | return df
111 |
112 |
113 | if __name__ == '__main__':
114 | benchmark()
115 |
--------------------------------------------------------------------------------
/ultralytics/yolo/utils/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import add_integration_callbacks, default_callbacks
2 |
3 | __all__ = 'add_integration_callbacks', 'default_callbacks'
4 |
--------------------------------------------------------------------------------
/ultralytics/yolo/utils/callbacks/base.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 | """
3 | Base callbacks
4 | """
5 |
6 |
7 | # Trainer callbacks ----------------------------------------------------------------------------------------------------
8 | def on_pretrain_routine_start(trainer):
9 | pass
10 |
11 |
12 | def on_pretrain_routine_end(trainer):
13 | pass
14 |
15 |
16 | def on_train_start(trainer):
17 | pass
18 |
19 |
20 | def on_train_epoch_start(trainer):
21 | pass
22 |
23 |
24 | def on_train_batch_start(trainer):
25 | pass
26 |
27 |
28 | def optimizer_step(trainer):
29 | pass
30 |
31 |
32 | def on_before_zero_grad(trainer):
33 | pass
34 |
35 |
36 | def on_train_batch_end(trainer):
37 | pass
38 |
39 |
40 | def on_train_epoch_end(trainer):
41 | pass
42 |
43 |
44 | def on_fit_epoch_end(trainer):
45 | pass
46 |
47 |
48 | def on_model_save(trainer):
49 | pass
50 |
51 |
52 | def on_train_end(trainer):
53 | pass
54 |
55 |
56 | def on_params_update(trainer):
57 | pass
58 |
59 |
60 | def teardown(trainer):
61 | pass
62 |
63 |
64 | # Validator callbacks --------------------------------------------------------------------------------------------------
65 | def on_val_start(validator):
66 | pass
67 |
68 |
69 | def on_val_batch_start(validator):
70 | pass
71 |
72 |
73 | def on_val_batch_end(validator):
74 | pass
75 |
76 |
77 | def on_val_end(validator):
78 | pass
79 |
80 |
81 | # Predictor callbacks --------------------------------------------------------------------------------------------------
82 | def on_predict_start(predictor):
83 | pass
84 |
85 |
86 | def on_predict_batch_start(predictor):
87 | pass
88 |
89 |
90 | def on_predict_batch_end(predictor):
91 | pass
92 |
93 |
94 | def on_predict_postprocess_end(predictor):
95 | pass
96 |
97 |
98 | def on_predict_end(predictor):
99 | pass
100 |
101 |
102 | # Exporter callbacks ---------------------------------------------------------------------------------------------------
103 | def on_export_start(exporter):
104 | pass
105 |
106 |
107 | def on_export_end(exporter):
108 | pass
109 |
110 |
111 | default_callbacks = {
112 | # Run in trainer
113 | 'on_pretrain_routine_start': [on_pretrain_routine_start],
114 | 'on_pretrain_routine_end': [on_pretrain_routine_end],
115 | 'on_train_start': [on_train_start],
116 | 'on_train_epoch_start': [on_train_epoch_start],
117 | 'on_train_batch_start': [on_train_batch_start],
118 | 'optimizer_step': [optimizer_step],
119 | 'on_before_zero_grad': [on_before_zero_grad],
120 | 'on_train_batch_end': [on_train_batch_end],
121 | 'on_train_epoch_end': [on_train_epoch_end],
122 | 'on_fit_epoch_end': [on_fit_epoch_end], # fit = train + val
123 | 'on_model_save': [on_model_save],
124 | 'on_train_end': [on_train_end],
125 | 'on_params_update': [on_params_update],
126 | 'teardown': [teardown],
127 |
128 | # Run in validator
129 | 'on_val_start': [on_val_start],
130 | 'on_val_batch_start': [on_val_batch_start],
131 | 'on_val_batch_end': [on_val_batch_end],
132 | 'on_val_end': [on_val_end],
133 |
134 | # Run in predictor
135 | 'on_predict_start': [on_predict_start],
136 | 'on_predict_batch_start': [on_predict_batch_start],
137 | 'on_predict_postprocess_end': [on_predict_postprocess_end],
138 | 'on_predict_batch_end': [on_predict_batch_end],
139 | 'on_predict_end': [on_predict_end],
140 |
141 | # Run in exporter
142 | 'on_export_start': [on_export_start],
143 | 'on_export_end': [on_export_end]}
144 |
145 |
146 | def add_integration_callbacks(instance):
147 | from .clearml import callbacks as clearml_callbacks
148 | from .comet import callbacks as comet_callbacks
149 | from .hub import callbacks as hub_callbacks
150 | from .tensorboard import callbacks as tb_callbacks
151 |
152 | for x in clearml_callbacks, comet_callbacks, hub_callbacks, tb_callbacks:
153 | for k, v in x.items():
154 | if v not in instance.callbacks[k]: # prevent duplicate callbacks addition
155 | instance.callbacks[k].append(v) # callback[name].append(func)
156 |
--------------------------------------------------------------------------------
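
default_callbacks maps each event name to a list of functions, and add_integration_callbacks() appends the integration hooks below into that same structure. Registering a custom hook is the same append; a self-contained sketch (my_on_train_start is a made-up hook for illustration):

    from copy import deepcopy

    from ultralytics.yolo.utils.callbacks.base import default_callbacks

    def my_on_train_start(trainer):           # same signature as the stubs above
        print('training is starting')

    callbacks = deepcopy(default_callbacks)   # event name -> list of functions
    callbacks['on_train_start'].append(my_on_train_start)
    for f in callbacks['on_train_start']:     # how a trainer fires an event
        f(trainer=None)                       # None stands in for a real trainer here
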
/ultralytics/yolo/utils/callbacks/clearml.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 | from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING
3 | from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
4 |
5 | try:
6 | import clearml
7 | from clearml import Task
8 |
9 | assert clearml.__version__ # verify package is not directory
10 | assert not TESTS_RUNNING # do not log pytest
11 | except (ImportError, AssertionError):
12 | clearml = None
13 |
14 |
15 | def _log_images(imgs_dict, group='', step=0):
16 | task = Task.current_task()
17 | if task:
18 | for k, v in imgs_dict.items():
19 | task.get_logger().report_image(group, k, step, v)
20 |
21 |
22 | def on_pretrain_routine_start(trainer):
23 | try:
24 | task = Task.init(project_name=trainer.args.project or 'YOLOv8',
25 | task_name=trainer.args.name,
26 | tags=['YOLOv8'],
27 | output_uri=True,
28 | reuse_last_task_id=False,
29 | auto_connect_frameworks={'pytorch': False})
30 | task.connect(vars(trainer.args), name='General')
31 | except Exception as e:
32 | LOGGER.warning(f'WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}')
33 |
34 |
35 | def on_train_epoch_end(trainer):
36 | if trainer.epoch == 1:
37 | _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic', trainer.epoch)
38 |
39 |
40 | def on_fit_epoch_end(trainer):
41 | task = Task.current_task()
42 | if task and trainer.epoch == 0:
43 | model_info = {
44 | 'model/parameters': get_num_params(trainer.model),
45 | 'model/GFLOPs': round(get_flops(trainer.model), 3),
46 | 'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
47 | task.connect(model_info, name='Model')
48 |
49 |
50 | def on_train_end(trainer):
51 | task = Task.current_task()
52 | if task:
53 | task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False)
54 |
55 |
56 | callbacks = {
57 | 'on_pretrain_routine_start': on_pretrain_routine_start,
58 | 'on_train_epoch_end': on_train_epoch_end,
59 | 'on_fit_epoch_end': on_fit_epoch_end,
60 | 'on_train_end': on_train_end} if clearml else {}
61 |
--------------------------------------------------------------------------------
/ultralytics/yolo/utils/callbacks/comet.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 | from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING
3 | from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
4 |
5 | try:
6 | import comet_ml
7 |
8 | assert not TESTS_RUNNING # do not log pytest
9 | assert comet_ml.__version__ # verify package is not directory
10 | except (ImportError, AssertionError):
11 | comet_ml = None
12 |
13 |
14 | def on_pretrain_routine_start(trainer):
15 | try:
16 | experiment = comet_ml.Experiment(project_name=trainer.args.project or 'YOLOv8')
17 | experiment.log_parameters(vars(trainer.args))
18 | except Exception as e:
19 | LOGGER.warning(f'WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}')
20 |
21 |
22 | def on_train_epoch_end(trainer):
23 | experiment = comet_ml.get_global_experiment()
24 | if experiment:
25 | experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1)
26 | if trainer.epoch == 1:
27 | for f in trainer.save_dir.glob('train_batch*.jpg'):
28 | experiment.log_image(f, name=f.stem, step=trainer.epoch + 1)
29 |
30 |
31 | def on_fit_epoch_end(trainer):
32 | experiment = comet_ml.get_global_experiment()
33 | if experiment:
34 | experiment.log_metrics(trainer.metrics, step=trainer.epoch + 1)
35 | if trainer.epoch == 0:
36 | model_info = {
37 | 'model/parameters': get_num_params(trainer.model),
38 | 'model/GFLOPs': round(get_flops(trainer.model), 3),
39 | 'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
40 | experiment.log_metrics(model_info, step=trainer.epoch + 1)
41 |
42 |
43 | def on_train_end(trainer):
44 | experiment = comet_ml.get_global_experiment()
45 | if experiment:
46 | experiment.log_model('YOLOv8', file_or_folder=str(trainer.best), file_name='best.pt', overwrite=True)
47 |
48 |
49 | callbacks = {
50 | 'on_pretrain_routine_start': on_pretrain_routine_start,
51 | 'on_train_epoch_end': on_train_epoch_end,
52 | 'on_fit_epoch_end': on_fit_epoch_end,
53 | 'on_train_end': on_train_end} if comet_ml else {}
54 |
--------------------------------------------------------------------------------
/ultralytics/yolo/utils/callbacks/hub.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | import json
4 | from time import time
5 |
6 | from ultralytics.hub.utils import PREFIX, traces
7 | from ultralytics.yolo.utils import LOGGER
8 | from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
9 |
10 |
11 | def on_pretrain_routine_end(trainer):
12 | session = getattr(trainer, 'hub_session', None)
13 | if session:
14 | # Start timer for upload rate limit
15 | LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀')
16 | session.timers = {'metrics': time(), 'ckpt': time()} # start timer on session.rate_limit
17 |
18 |
19 | def on_fit_epoch_end(trainer):
20 | session = getattr(trainer, 'hub_session', None)
21 | if session:
22 | # Upload metrics after val end
23 | all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics}
24 | if trainer.epoch == 0:
25 | model_info = {
26 | 'model/parameters': get_num_params(trainer.model),
27 | 'model/GFLOPs': round(get_flops(trainer.model), 3),
28 | 'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
29 | all_plots = {**all_plots, **model_info}
30 | session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
31 | if time() - session.timers['metrics'] > session.rate_limits['metrics']:
32 | session.upload_metrics()
33 | session.timers['metrics'] = time() # reset timer
34 | session.metrics_queue = {} # reset queue
35 |
36 |
37 | def on_model_save(trainer):
38 | session = getattr(trainer, 'hub_session', None)
39 | if session:
40 | # Upload checkpoints with rate limiting
41 | is_best = trainer.best_fitness == trainer.fitness
42 | if time() - session.timers['ckpt'] > session.rate_limits['ckpt']:
43 | LOGGER.info(f'{PREFIX}Uploading checkpoint {session.model_id}')
44 | session.upload_model(trainer.epoch, trainer.last, is_best)
45 | session.timers['ckpt'] = time() # reset timer
46 |
47 |
48 | def on_train_end(trainer):
49 | session = getattr(trainer, 'hub_session', None)
50 | if session:
51 | # Upload final model and metrics with exponential backoff
52 | LOGGER.info(f'{PREFIX}Syncing final model...')
53 | session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True)
54 | session.alive = False # stop heartbeats
55 | LOGGER.info(f'{PREFIX}Done ✅\n'
56 | f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀')
57 |
58 |
59 | def on_train_start(trainer):
60 | traces(trainer.args, traces_sample_rate=1.0)
61 |
62 |
63 | def on_val_start(validator):
64 | traces(validator.args, traces_sample_rate=1.0)
65 |
66 |
67 | def on_predict_start(predictor):
68 | traces(predictor.args, traces_sample_rate=1.0)
69 |
70 |
71 | def on_export_start(exporter):
72 | traces(exporter.args, traces_sample_rate=1.0)
73 |
74 |
75 | callbacks = {
76 | 'on_pretrain_routine_end': on_pretrain_routine_end,
77 | 'on_fit_epoch_end': on_fit_epoch_end,
78 | 'on_model_save': on_model_save,
79 | 'on_train_end': on_train_end,
80 | 'on_train_start': on_train_start,
81 | 'on_val_start': on_val_start,
82 | 'on_predict_start': on_predict_start,
83 | 'on_export_start': on_export_start}
84 |
--------------------------------------------------------------------------------
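
The hub callbacks above share one rate-limit pattern: act only when enough time has elapsed since the last action, then reset the timer. Stripped to its skeleton (the 2.0-second limit is illustrative):

    from time import time

    timers = {'metrics': time()}       # last action time per key
    rate_limits = {'metrics': 2.0}     # minimum seconds between actions

    def maybe_upload_metrics():
        if time() - timers['metrics'] > rate_limits['metrics']:
            print('uploading metrics')     # the real callback calls session.upload_metrics()
            timers['metrics'] = time()     # reset timer
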
/ultralytics/yolo/utils/callbacks/tensorboard.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 | from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING, colorstr
3 |
4 | try:
5 | from torch.utils.tensorboard import SummaryWriter
6 |
7 | assert not TESTS_RUNNING # do not log pytest
8 | except (ImportError, AssertionError):
9 | SummaryWriter = None
10 |
11 | writer = None # TensorBoard SummaryWriter instance
12 |
13 |
14 | def _log_scalars(scalars, step=0):
15 | if writer:
16 | for k, v in scalars.items():
17 | writer.add_scalar(k, v, step)
18 |
19 |
20 | def on_pretrain_routine_start(trainer):
21 | if SummaryWriter:
22 | try:
23 | global writer
24 | writer = SummaryWriter(str(trainer.save_dir))
25 | prefix = colorstr('TensorBoard: ')
26 | LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
27 | except Exception as e:
28 | LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')
29 |
30 |
31 | def on_batch_end(trainer):
32 | _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
33 |
34 |
35 | def on_fit_epoch_end(trainer):
36 | _log_scalars(trainer.metrics, trainer.epoch + 1)
37 |
38 |
39 | callbacks = {
40 | 'on_pretrain_routine_start': on_pretrain_routine_start,
41 | 'on_fit_epoch_end': on_fit_epoch_end,
42 | 'on_batch_end': on_batch_end}
43 |
--------------------------------------------------------------------------------
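
The _log_scalars() helper is a thin wrapper over the standard SummaryWriter API. A standalone sketch, assuming tensorboard is installed ('runs/demo' is an arbitrary log directory):

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter('runs/demo')
    writer.add_scalar('train/box_loss', 0.5, 1)  # tag, value, global step
    writer.close()                               # then: tensorboard --logdir runs/demo
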
/ultralytics/yolo/utils/dist.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | import os
4 | import re
5 | import shutil
6 | import socket
7 | import sys
8 | import tempfile
9 | from pathlib import Path
10 |
11 | from . import USER_CONFIG_DIR
12 | from .torch_utils import TORCH_1_9
13 |
14 |
15 | def find_free_network_port() -> int:
16 | """Finds a free port on localhost.
17 |
18 | It is useful in single-node training when we don't want to connect to a real main node but have to set the
19 | `MASTER_PORT` environment variable.
20 | """
21 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
22 | s.bind(('127.0.0.1', 0))
23 | return s.getsockname()[1] # port
24 |
25 |
26 | def generate_ddp_file(trainer):
27 | module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1)
28 |
29 | content = f'''cfg = {vars(trainer.args)} \nif __name__ == "__main__":
30 | from {module} import {name}
31 |
32 | trainer = {name}(cfg=cfg)
33 | trainer.train()'''
34 | (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
35 | with tempfile.NamedTemporaryFile(prefix='_temp_',
36 | suffix=f'{id(trainer)}.py',
37 | mode='w+',
38 | encoding='utf-8',
39 | dir=USER_CONFIG_DIR / 'DDP',
40 | delete=False) as file:
41 | file.write(content)
42 | return file.name
43 |
44 |
45 | def generate_ddp_command(world_size, trainer):
46 | import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
47 | if not trainer.resume:
48 | shutil.rmtree(trainer.save_dir) # remove the save_dir
49 | file = str(Path(sys.argv[0]).resolve())
50 | safe_pattern = re.compile(r'^[a-zA-Z0-9_. /\\-]{1,128}$') # allowed characters and maximum of 128 characters
51 | if not (safe_pattern.match(file) and Path(file).exists() and file.endswith('.py')): # using CLI
52 | file = generate_ddp_file(trainer)
53 | dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
54 | port = find_free_network_port()
55 | exclude_args = ['save_dir']
56 | args = [f'{k}={v}' for k, v in vars(trainer.args).items() if k not in exclude_args]
57 | cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file] + args
58 | return cmd, file
59 |
60 |
61 | def ddp_cleanup(trainer, file):
62 | # delete temp file if created
63 | if f'{id(trainer)}.py' in file: # if temp_file suffix in file
64 | os.remove(file)
65 |
--------------------------------------------------------------------------------
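
generate_ddp_command() ultimately assembles a torch.distributed launch line. A sketch of the shape it takes (the trailing args come from vars(trainer.args), so they vary per run; the 2-process example below is illustrative):

    from ultralytics.yolo.utils.dist import find_free_network_port

    port = find_free_network_port()  # free localhost port for --master_port
    # resulting command, roughly:
    #   python -m torch.distributed.run --nproc_per_node 2 --master_port <port> file.py task=detect mode=train ...
    print(port)
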
/ultralytics/yolo/utils/files.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | import contextlib
4 | import glob
5 | import os
6 | import urllib
7 | from datetime import datetime
8 | from pathlib import Path
9 |
10 |
11 | class WorkingDirectory(contextlib.ContextDecorator):
12 | # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
13 | def __init__(self, new_dir):
14 | self.dir = new_dir # new dir
15 | self.cwd = Path.cwd().resolve() # current dir
16 |
17 | def __enter__(self):
18 | os.chdir(self.dir)
19 |
20 | def __exit__(self, exc_type, exc_val, exc_tb):
21 | os.chdir(self.cwd)
22 |
23 |
24 | def increment_path(path, exist_ok=False, sep='', mkdir=False):
25 | """
26 | Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
27 |
28 | If the path exists and exist_ok is not set to True, the path will be incremented by appending a number and sep to
29 | the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the
30 | number will be appended directly to the end of the path. If mkdir is set to True, the path will be created as a
31 | directory if it does not already exist.
32 |
33 | Args:
34 | path (str or pathlib.Path): Path to increment.
35 | exist_ok (bool, optional): If True, the path will not be incremented and will be returned as-is. Defaults to False.
36 | sep (str, optional): Separator to use between the path and the incrementation number. Defaults to an empty string.
37 | mkdir (bool, optional): If True, the path will be created as a directory if it does not exist. Defaults to False.
38 |
39 | Returns:
40 | pathlib.Path: Incremented path.
41 | """
42 | path = Path(path) # os-agnostic
43 | if path.exists() and not exist_ok:
44 | path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
45 |
46 | # Append an increasing number until a free path is found
47 | for n in range(2, 9999):
48 | p = f'{path}{sep}{n}{suffix}' # increment path
49 | if not os.path.exists(p):
50 | break
51 | path = Path(p)
52 |
53 | if mkdir:
54 | path.mkdir(parents=True, exist_ok=True) # make directory
55 |
56 | return path
57 |
58 |
59 | def file_age(path=__file__):
60 | # Return days since last file update
61 | dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
62 | return dt.days # + dt.seconds / 86400 # fractional days
63 |
64 |
65 | def file_date(path=__file__):
66 | # Return human-readable file modification date, i.e. '2021-3-26'
67 | t = datetime.fromtimestamp(Path(path).stat().st_mtime)
68 | return f'{t.year}-{t.month}-{t.day}'
69 |
70 |
71 | def file_size(path):
72 | # Return file/dir size (MB)
73 | if isinstance(path, (str, Path)):
74 | mb = 1 << 20 # bytes to MiB (1024 ** 2)
75 | path = Path(path)
76 | if path.is_file():
77 | return path.stat().st_size / mb
78 | elif path.is_dir():
79 | return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
80 | return 0.0
81 |
82 |
83 | def url2file(url):
84 | # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
85 | url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
86 | return Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
87 |
88 |
89 | def get_latest_run(search_dir='.'):
90 | # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
91 | last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
92 | return max(last_list, key=os.path.getctime) if last_list else ''
93 |
--------------------------------------------------------------------------------
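
increment_path() is easiest to see by running it twice on the same path. A sketch, writing under a throwaway 'runs/' directory:

    from ultralytics.yolo.utils.files import increment_path

    p1 = increment_path('runs/exp', mkdir=True)  # 'runs/exp' is free, so created as-is
    p2 = increment_path('runs/exp', mkdir=True)  # now taken, so incremented
    print(p1, p2)                                # runs/exp runs/exp2
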
/ultralytics/yolo/utils/loss.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from .metrics import bbox_iou
8 | from .tal import bbox2dist
9 |
10 |
11 | class VarifocalLoss(nn.Module):
12 | # Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367
13 | def __init__(self):
14 | super().__init__()
15 |
16 | def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0):
17 | weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
18 | with torch.cuda.amp.autocast(enabled=False):
19 | loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
20 | weight).sum()
21 | return loss
22 |
23 |
24 | class BboxLoss(nn.Module):
25 |
26 | def __init__(self, reg_max, use_dfl=False):
27 | super().__init__()
28 | self.reg_max = reg_max
29 | self.use_dfl = use_dfl
30 |
31 | def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
32 | # IoU loss
33 | weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
34 | iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
35 | loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
36 |
37 | # DFL loss
38 | if self.use_dfl:
39 | target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max)
40 | loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
41 | loss_dfl = loss_dfl.sum() / target_scores_sum
42 | else:
43 | loss_dfl = torch.tensor(0.0).to(pred_dist.device)
44 |
45 | return loss_iou, loss_dfl
46 |
47 | @staticmethod
48 | def _df_loss(pred_dist, target):
49 | # Return sum of left and right DFL losses
50 | # Distribution Focal Loss (DFL) proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
51 | tl = target.long() # target left
52 | tr = tl + 1 # target right
53 | wl = tr - target # weight left
54 | wr = 1 - wl # weight right
55 | return (F.cross_entropy(pred_dist, tl.view(-1), reduction='none').view(tl.shape) * wl +
56 | F.cross_entropy(pred_dist, tr.view(-1), reduction='none').view(tl.shape) * wr).mean(-1, keepdim=True)
57 |
--------------------------------------------------------------------------------
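
The left/right split in _df_loss() shares each continuous regression target between its two neighboring integer bins, weighted by proximity. A tiny numeric illustration:

    import torch

    target = torch.tensor([2.7])
    tl = target.long()  # tensor([2])      left bin
    tr = tl + 1         # tensor([3])      right bin
    wl = tr - target    # tensor([0.3000]) weight on the left bin (3 - 2.7)
    wr = 1 - wl         # tensor([0.7000]) weight on the right bin, closer to 2.7
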
/ultralytics/yolo/v8/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | from ultralytics.yolo.v8 import classify, detect, segment
4 |
5 | __all__ = 'classify', 'segment', 'detect'
6 |
--------------------------------------------------------------------------------
/ultralytics/yolo/v8/classify/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | from ultralytics.yolo.v8.classify.predict import ClassificationPredictor, predict
4 | from ultralytics.yolo.v8.classify.train import ClassificationTrainer, train
5 | from ultralytics.yolo.v8.classify.val import ClassificationValidator, val
6 |
7 | __all__ = 'ClassificationPredictor', 'predict', 'ClassificationTrainer', 'train', 'ClassificationValidator', 'val'
8 |
--------------------------------------------------------------------------------
/ultralytics/yolo/v8/classify/predict.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | import torch
4 |
5 | from ultralytics.yolo.engine.predictor import BasePredictor
6 | from ultralytics.yolo.engine.results import Results
7 | from ultralytics.yolo.utils import DEFAULT_CFG, ROOT
8 | from ultralytics.yolo.utils.plotting import Annotator
9 |
10 |
11 | class ClassificationPredictor(BasePredictor):
12 |
13 | def get_annotator(self, img):
14 | return Annotator(img, example=str(self.model.names), pil=True)
15 |
16 | def preprocess(self, img):
17 | img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
18 | return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
19 |
20 | def postprocess(self, preds, img, orig_imgs):
21 | results = []
22 | for i, pred in enumerate(preds):
23 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
24 | path, _, _, _, _ = self.batch
25 | img_path = path[i] if isinstance(path, list) else path
26 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, probs=pred))
27 |
28 | return results
29 |
30 | def write_results(self, idx, results, batch):
31 | p, im, im0 = batch
32 | log_string = ''
33 | if len(im.shape) == 3:
34 | im = im[None] # expand for batch dim
35 | self.seen += 1
36 | im0 = im0.copy()
37 | if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
38 | log_string += f'{idx}: '
39 | frame = self.dataset.count
40 | else:
41 | frame = getattr(self.dataset, 'frame', 0)
42 |
43 | self.data_path = p
44 | # save_path = str(self.save_dir / p.name) # im.jpg
45 | self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
46 | log_string += '%gx%g ' % im.shape[2:] # print string
47 | self.annotator = self.get_annotator(im0)
48 |
49 | result = results[idx]
50 | if len(result) == 0:
51 | return log_string
52 | prob = result.probs
53 | # Print results
54 | n5 = min(len(self.model.names), 5)
55 | top5i = prob.argsort(0, descending=True)[:n5].tolist() # top 5 indices
56 | log_string += f"{', '.join(f'{self.model.names[j]} {prob[j]:.2f}' for j in top5i)}, "
57 |
58 | # write
59 | text = '\n'.join(f'{prob[j]:.2f} {self.model.names[j]}' for j in top5i)
60 | if self.args.save or self.args.show: # Add text to image
61 | self.annotator.text((32, 32), text, txt_color=(255, 255, 255))
62 | if self.args.save_txt: # Write to file
63 | with open(f'{self.txt_path}.txt', 'a') as f:
64 | f.write(text + '\n')
65 |
66 | return log_string
67 |
68 |
69 | def predict(cfg=DEFAULT_CFG, use_python=False):
70 | model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
71 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
72 | else 'https://ultralytics.com/images/bus.jpg'
73 |
74 | args = dict(model=model, source=source)
75 | if use_python:
76 | from ultralytics import YOLO
77 | YOLO(model)(**args)
78 | else:
79 | predictor = ClassificationPredictor(overrides=args)
80 | predictor.predict_cli()
81 |
82 |
83 | if __name__ == '__main__':
84 | predict()
85 |
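The top-5 readout in write_results above is a plain argsort over the class probabilities; a minimal standalone sketch (the class names and scores below are made up for illustration):

    import torch

    names = {0: 'cat', 1: 'dog', 2: 'bird'}  # hypothetical class names
    prob = torch.tensor([0.10, 0.75, 0.15])  # hypothetical per-class scores for one image

    n5 = min(len(names), 5)  # report at most 5 classes
    top5i = prob.argsort(0, descending=True)[:n5].tolist()  # indices of the highest scores
    print(', '.join(f'{names[j]} {prob[j]:.2f}' for j in top5i))  # dog 0.75, bird 0.15, cat 0.10
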
--------------------------------------------------------------------------------
/ultralytics/yolo/v8/classify/train.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | import torch
4 | import torchvision
5 |
6 | from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
7 | from ultralytics.yolo import v8
8 | from ultralytics.yolo.data import build_classification_dataloader
9 | from ultralytics.yolo.engine.trainer import BaseTrainer
10 | from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
11 | from ultralytics.yolo.utils.torch_utils import is_parallel, strip_optimizer
12 |
13 |
14 | class ClassificationTrainer(BaseTrainer):
15 |
16 | def __init__(self, cfg=DEFAULT_CFG, overrides=None):
17 | if overrides is None:
18 | overrides = {}
19 | overrides['task'] = 'classify'
20 | super().__init__(cfg, overrides)
21 |
22 | def set_model_attributes(self):
23 | self.model.names = self.data['names']
24 |
25 | def get_model(self, cfg=None, weights=None, verbose=True):
26 | model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
27 | if weights:
28 | model.load(weights)
29 |
30 | pretrained = False
31 | for m in model.modules():
32 | if not pretrained and hasattr(m, 'reset_parameters'):
33 | m.reset_parameters()
34 | if isinstance(m, torch.nn.Dropout) and self.args.dropout:
35 | m.p = self.args.dropout # set dropout
36 | for p in model.parameters():
37 | p.requires_grad = True # for training
38 |
39 | # Update defaults
40 | if self.args.imgsz == 640:
41 | self.args.imgsz = 224
42 |
43 | return model
44 |
45 | def setup_model(self):
46 | """
47 | load/create/download model for any task
48 | """
49 | # classification models require special handling
50 |
51 | if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
52 | return
53 |
54 | model = str(self.model)
55 | # Load a YOLO model locally, from torchvision, or from Ultralytics assets
56 | if model.endswith('.pt'):
57 | self.model, _ = attempt_load_one_weight(model, device='cpu')
58 | for p in self.model.parameters():
59 | p.requires_grad = True # for training
60 | elif model.endswith('.yaml'):
61 | self.model = self.get_model(cfg=model)
62 | elif model in torchvision.models.__dict__:
63 | pretrained = True
64 | self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
65 | else:
66 | raise FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
67 | ClassificationModel.reshape_outputs(self.model, self.data['nc'])
68 |
69 | return # don't return ckpt. Classification doesn't support resume
70 |
71 | def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
72 | loader = build_classification_dataloader(path=dataset_path,
73 | imgsz=self.args.imgsz,
74 | batch_size=batch_size if mode == 'train' else (batch_size * 2),
75 | augment=mode == 'train',
76 | rank=rank,
77 | workers=self.args.workers)
78 | # Attach inference transforms
79 | if mode != 'train':
80 | if is_parallel(self.model):
81 | self.model.module.transforms = loader.dataset.torch_transforms
82 | else:
83 | self.model.transforms = loader.dataset.torch_transforms
84 | return loader
85 |
86 | def preprocess_batch(self, batch):
87 | batch['img'] = batch['img'].to(self.device)
88 | batch['cls'] = batch['cls'].to(self.device)
89 | return batch
90 |
91 | def progress_string(self):
92 | return ('\n' + '%11s' * (4 + len(self.loss_names))) % \
93 | ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
94 |
95 | def get_validator(self):
96 | self.loss_names = ['loss']
97 | return v8.classify.ClassificationValidator(self.test_loader, self.save_dir)
98 |
99 | def criterion(self, preds, batch):
100 | loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / self.args.nbs
101 | loss_items = loss.detach()
102 | return loss, loss_items
103 |
104 | # def label_loss_items(self, loss_items=None, prefix="train"):
105 | # """
106 | # Returns a loss dict with labelled training loss items tensor
107 | # """
108 | # # Not needed for classification but necessary for segmentation & detection
109 | # keys = [f"{prefix}/{x}" for x in self.loss_names]
110 | # if loss_items is not None:
111 | # loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
112 | # return dict(zip(keys, loss_items))
113 | # else:
114 | # return keys
115 |
116 | def label_loss_items(self, loss_items=None, prefix='train'):
117 | """
118 | Returns a loss dict with labelled training loss items tensor
119 | """
120 | # Not needed for classification but necessary for segmentation & detection
121 | keys = [f'{prefix}/{x}' for x in self.loss_names]
122 | if loss_items is None:
123 | return keys
124 | loss_items = [round(float(loss_items), 5)]
125 | return dict(zip(keys, loss_items))
126 |
127 | def resume_training(self, ckpt):
128 | pass
129 |
130 | def final_eval(self):
131 | for f in self.last, self.best:
132 | if f.exists():
133 | strip_optimizer(f) # strip optimizers
134 | # TODO: validate best.pt after training completes
135 | # if f is self.best:
136 | # LOGGER.info(f'\nValidating {f}...')
137 | # self.validator.args.save_json = True
138 | # self.metrics = self.validator(model=f)
139 | # self.metrics.pop('fitness', None)
140 | # self.run_callbacks('on_fit_epoch_end')
141 | LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
142 |
143 |
144 | def train(cfg=DEFAULT_CFG, use_python=False):
145 | model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
146 | data = cfg.data or 'mnist160' # or yolo.ClassificationDataset("mnist")
147 | device = cfg.device if cfg.device is not None else ''
148 |
149 | args = dict(model=model, data=data, device=device)
150 | if use_python:
151 | from ultralytics import YOLO
152 | YOLO(model).train(**args)
153 | else:
154 | trainer = ClassificationTrainer(overrides=args)
155 | trainer.train()
156 |
157 |
158 | if __name__ == '__main__':
159 | train()
160 |
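setup_model above resolves the model from three sources: a .pt checkpoint, a .yaml config, or a torchvision architecture name (e.g. 'resnet18'). A minimal sketch of driving the trainer directly, mirroring the train() helper (the model and dataset names are the same demo values used above):

    from ultralytics.yolo.v8.classify import ClassificationTrainer

    # 'mnist160' is the small demo dataset referenced in train(); any classification
    # dataset usable by build_classification_dataloader would work the same way.
    trainer = ClassificationTrainer(overrides=dict(model='yolov8n-cls.pt', data='mnist160'))
    trainer.train()
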
--------------------------------------------------------------------------------
/ultralytics/yolo/v8/classify/val.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | from ultralytics.yolo.data import build_classification_dataloader
4 | from ultralytics.yolo.engine.validator import BaseValidator
5 | from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER
6 | from ultralytics.yolo.utils.metrics import ClassifyMetrics
7 |
8 |
9 | class ClassificationValidator(BaseValidator):
10 |
11 | def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
12 | super().__init__(dataloader, save_dir, pbar, args)
13 | self.args.task = 'classify'
14 | self.metrics = ClassifyMetrics()
15 |
16 | def get_desc(self):
17 | return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc')
18 |
19 | def init_metrics(self, model):
20 | self.pred = []
21 | self.targets = []
22 |
23 | def preprocess(self, batch):
24 | batch['img'] = batch['img'].to(self.device, non_blocking=True)
25 | batch['img'] = batch['img'].half() if self.args.half else batch['img'].float()
26 | batch['cls'] = batch['cls'].to(self.device)
27 | return batch
28 |
29 | def update_metrics(self, preds, batch):
30 | n5 = min(len(self.model.names), 5)
31 | self.pred.append(preds.argsort(1, descending=True)[:, :n5])
32 | self.targets.append(batch['cls'])
33 |
34 | def finalize_metrics(self, *args, **kwargs):
35 | self.metrics.speed = self.speed
36 | # self.metrics.confusion_matrix = self.confusion_matrix # TODO: classification ConfusionMatrix
37 |
38 | def get_stats(self):
39 | self.metrics.process(self.targets, self.pred)
40 | return self.metrics.results_dict
41 |
42 | def get_dataloader(self, dataset_path, batch_size):
43 | return build_classification_dataloader(path=dataset_path,
44 | imgsz=self.args.imgsz,
45 | batch_size=batch_size,
46 | augment=False,
47 | shuffle=False,
48 | workers=self.args.workers)
49 |
50 | def print_results(self):
51 | pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
52 | LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))
53 |
54 |
55 | def val(cfg=DEFAULT_CFG, use_python=False):
56 | model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
57 | data = cfg.data or 'mnist160'
58 |
59 | args = dict(model=model, data=data)
60 | if use_python:
61 | from ultralytics import YOLO
62 | YOLO(model).val(**args)
63 | else:
64 | validator = ClassificationValidator(args=args)
65 | validator(model=args['model'])
66 |
67 |
68 | if __name__ == '__main__':
69 | val()
70 |
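update_metrics above buffers the top-n5 predicted indices and the targets; top-1/top-5 accuracy then falls out of a simple hit matrix. A sketch of that reduction (illustrative tensors, not the ClassifyMetrics internals, which are defined elsewhere):

    import torch

    pred = torch.tensor([[1, 0, 2], [2, 1, 0]])  # top-k class indices for two images (k=3 here)
    targets = torch.tensor([1, 0])               # ground-truth class per image

    correct = (pred == targets[:, None]).float() # (n_images, k) hit matrix
    top1 = correct[:, 0].mean().item()           # hit in the first column -> 0.5
    topk = correct.max(1).values.mean().item()   # hit anywhere in the top k -> 1.0
    print(top1, topk)
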
--------------------------------------------------------------------------------
/ultralytics/yolo/v8/detect/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | from .predict import DetectionPredictor, predict
4 | from .train import DetectionTrainer, train
5 | from .val import DetectionValidator, val
6 |
7 | __all__ = 'DetectionPredictor', 'predict', 'DetectionTrainer', 'train', 'DetectionValidator', 'val'
8 |
--------------------------------------------------------------------------------
/ultralytics/yolo/v8/detect/predict.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | import torch
4 |
5 | from ultralytics.yolo.engine.predictor import BasePredictor
6 | from ultralytics.yolo.engine.results import Results
7 | from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops
8 | from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
9 |
10 |
11 | class DetectionPredictor(BasePredictor):
12 |
13 | def get_annotator(self, img):
14 | return Annotator(img, line_width=self.args.line_thickness, example=str(self.model.names))
15 |
16 | def preprocess(self, img):
17 | img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
18 | img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
19 | img /= 255 # 0 - 255 to 0.0 - 1.0
20 | return img
21 |
22 | def postprocess(self, preds, img, orig_imgs):
23 | preds = ops.non_max_suppression(preds,
24 | self.args.conf,
25 | self.args.iou,
26 | agnostic=self.args.agnostic_nms,
27 | max_det=self.args.max_det,
28 | classes=self.args.classes)
29 |
30 | results = []
31 | for i, pred in enumerate(preds):
32 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
33 | if not isinstance(orig_imgs, torch.Tensor):
34 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
35 | path, _, _, _, _ = self.batch
36 | img_path = path[i] if isinstance(path, list) else path
37 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
38 | return results
39 |
40 | def write_results(self, idx, results, batch):
41 | p, im, im0 = batch
42 | log_string = ''
43 | if len(im.shape) == 3:
44 | im = im[None] # expand for batch dim
45 | self.seen += 1
46 | imc = im0.copy() if self.args.save_crop else im0
47 | if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
48 | log_string += f'{idx}: '
49 | frame = self.dataset.count
50 | else:
51 | frame = getattr(self.dataset, 'frame', 0)
52 | self.data_path = p
53 | self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
54 | log_string += '%gx%g ' % im.shape[2:] # print string
55 | self.annotator = self.get_annotator(im0)
56 |
57 | det = results[idx].boxes # TODO: make boxes inherit from tensors
58 | if len(det) == 0:
59 | return f'{log_string}(no detections), '
60 | for c in det.cls.unique():
61 | n = (det.cls == c).sum() # detections per class
62 | log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "
63 |
64 | # write
65 | for d in reversed(det):
66 | c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item())
67 | if self.args.save_txt: # Write to file
68 | line = (c, *d.xywhn.view(-1)) + (conf, ) * self.args.save_conf + (() if id is None else (id, ))
69 | with open(f'{self.txt_path}.txt', 'a') as f:
70 | f.write(('%g ' * len(line)).rstrip() % line + '\n')
71 | if self.args.save or self.args.show: # Add bbox to image
72 | name = ('' if id is None else f'id:{id} ') + self.model.names[c]
73 | label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
74 | self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
75 | if self.args.save_crop:
76 | save_one_box(d.xyxy,
77 | imc,
78 | file=self.save_dir / 'crops' / self.model.names[c] / f'{self.data_path.stem}.jpg',
79 | BGR=True)
80 |
81 | return log_string
82 |
83 |
84 | def predict(cfg=DEFAULT_CFG, use_python=False):
85 | model = cfg.model or 'yolov8n.pt'
86 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
87 | else 'https://ultralytics.com/images/bus.jpg'
88 |
89 | args = dict(model=model, source=source)
90 | if use_python:
91 | from ultralytics import YOLO
92 | YOLO(model)(**args)
93 | else:
94 | predictor = DetectionPredictor(overrides=args)
95 | predictor.predict_cli()
96 |
97 |
98 | if __name__ == '__main__':
99 | predict()
100 |
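The save_txt branch above writes one space-separated line per detection: class index, normalized xywh, then optional confidence and track id. A small parsing sketch for a label file produced by that loop (the values are invented):

    # Parse one YOLO-style label line back into its fields.
    line = '0 0.5120 0.4370 0.2210 0.3150 0.91'  # class cx cy w h [conf]
    fields = line.split()
    cls = int(fields[0])
    cx, cy, w, h = map(float, fields[1:5])       # normalized to image size
    conf = float(fields[5]) if len(fields) > 5 else None
    print(cls, (cx, cy, w, h), conf)
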
--------------------------------------------------------------------------------
/ultralytics/yolo/v8/segment/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | from .predict import SegmentationPredictor, predict
4 | from .train import SegmentationTrainer, train
5 | from .val import SegmentationValidator, val
6 |
7 | __all__ = 'SegmentationPredictor', 'predict', 'SegmentationTrainer', 'train', 'SegmentationValidator', 'val'
8 |
--------------------------------------------------------------------------------
/ultralytics/yolo/v8/segment/predict.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | import torch
4 |
5 | from ultralytics.yolo.engine.results import Results
6 | from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops
7 | from ultralytics.yolo.utils.plotting import colors, save_one_box
8 | from ultralytics.yolo.v8.detect.predict import DetectionPredictor
9 |
10 |
11 | class SegmentationPredictor(DetectionPredictor):
12 |
13 | def postprocess(self, preds, img, orig_imgs):
14 | # TODO: filter by classes
15 | p = ops.non_max_suppression(preds[0],
16 | self.args.conf,
17 | self.args.iou,
18 | agnostic=self.args.agnostic_nms,
19 | max_det=self.args.max_det,
20 | nc=len(self.model.names),
21 | classes=self.args.classes)
22 | results = []
23 | proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
24 | for i, pred in enumerate(p):
25 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
26 | path, _, _, _, _ = self.batch
27 | img_path = path[i] if isinstance(path, list) else path
28 | if not len(pred): # save empty boxes
29 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
30 | continue
31 | if self.args.retina_masks:
32 | if not isinstance(orig_imgs, torch.Tensor):
33 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
34 | masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
35 | else:
36 | masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
37 | if not isinstance(orig_imgs, torch.Tensor):
38 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
39 | results.append(
40 | Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
41 | return results
42 |
43 | def write_results(self, idx, results, batch):
44 | p, im, im0 = batch
45 | log_string = ''
46 | if len(im.shape) == 3:
47 | im = im[None] # expand for batch dim
48 | self.seen += 1
49 | imc = im0.copy() if self.args.save_crop else im0
50 | if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
51 | log_string += f'{idx}: '
52 | frame = self.dataset.count
53 | else:
54 | frame = getattr(self.dataset, 'frame', 0)
55 |
56 | self.data_path = p
57 | self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
58 | log_string += '%gx%g ' % im.shape[2:] # print string
59 | self.annotator = self.get_annotator(im0)
60 |
61 | result = results[idx]
62 | if len(result) == 0:
63 | return f'{log_string}(no detections), '
64 | det, mask = result.boxes, result.masks # boxes and masks TODO: make Boxes and Masks inherit from tensors
65 |
66 | # Print results
67 | for c in det.cls.unique():
68 | n = (det.cls == c).sum() # detections per class
69 | log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "
70 |
71 | # Mask plotting
72 | if self.args.save or self.args.show:
73 | im_gpu = torch.as_tensor(im0, dtype=torch.float16, device=mask.masks.device).permute(
74 | 2, 0, 1).flip(0).contiguous() / 255 if self.args.retina_masks else im[idx]
75 | self.annotator.masks(masks=mask.masks, colors=[colors(x, True) for x in det.cls], im_gpu=im_gpu)
76 |
77 | # Write results
78 | for j, d in enumerate(reversed(det)):
79 | c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item())
80 | if self.args.save_txt: # Write to file
81 | seg = mask.segments[len(det) - j - 1].copy().reshape(-1) # reversed mask.segments, (n,2) to (n*2)
82 | line = (c, *seg) + (conf, ) * self.args.save_conf + (() if id is None else (id, ))
83 | with open(f'{self.txt_path}.txt', 'a') as f:
84 | f.write(('%g ' * len(line)).rstrip() % line + '\n')
85 | if self.args.save or self.args.show: # Add bbox to image
86 | name = ('' if id is None else f'id:{id} ') + self.model.names[c]
87 | label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
88 | if self.args.boxes:
89 | self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
90 | if self.args.save_crop:
91 | save_one_box(d.xyxy,
92 | imc,
93 | file=self.save_dir / 'crops' / self.model.names[c] / f'{self.data_path.stem}.jpg',
94 | BGR=True)
95 |
96 | return log_string
97 |
98 |
99 | def predict(cfg=DEFAULT_CFG, use_python=False):
100 | model = cfg.model or 'yolov8n-seg.pt'
101 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
102 | else 'https://ultralytics.com/images/bus.jpg'
103 |
104 | args = dict(model=model, source=source)
105 | if use_python:
106 | from ultralytics import YOLO
107 | YOLO(model)(**args)
108 | else:
109 | predictor = SegmentationPredictor(overrides=args)
110 | predictor.predict_cli()
111 |
112 |
113 | if __name__ == '__main__':
114 | predict()
115 |
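ops.process_mask combines each detection's mask coefficients (pred[:, 6:]) with the image's shared prototype tensor. Schematically the core operation is a matrix product plus sigmoid; this is a shape sketch under assumed sizes, not the library implementation (which also crops and optionally upsamples):

    import torch

    n, nm, h, w = 2, 32, 80, 80                  # detections, coefficients, proto height/width
    coef = torch.randn(n, nm)                    # per-detection mask coefficients
    proto = torch.randn(nm, h, w)                # shared prototype masks for one image
    masks = (coef @ proto.view(nm, -1)).sigmoid().view(n, h, w)  # (2,32)@(32,6400) -> (2,80,80)
    print(masks.shape)                           # torch.Size([2, 80, 80])
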
--------------------------------------------------------------------------------
/ultralytics/yolo/v8/segment/train.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 | from copy import copy
3 |
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | from ultralytics.nn.tasks import SegmentationModel
8 | from ultralytics.yolo import v8
9 | from ultralytics.yolo.utils import DEFAULT_CFG, RANK
10 | from ultralytics.yolo.utils.ops import crop_mask, xyxy2xywh
11 | from ultralytics.yolo.utils.plotting import plot_images, plot_results
12 | from ultralytics.yolo.utils.tal import make_anchors
13 | from ultralytics.yolo.utils.torch_utils import de_parallel
14 | from ultralytics.yolo.v8.detect.train import Loss
15 |
16 |
17 | # BaseTrainer python usage
18 | class SegmentationTrainer(v8.detect.DetectionTrainer):
19 |
20 | def __init__(self, cfg=DEFAULT_CFG, overrides=None):
21 | if overrides is None:
22 | overrides = {}
23 | overrides['task'] = 'segment'
24 | super().__init__(cfg, overrides)
25 |
26 | def get_model(self, cfg=None, weights=None, verbose=True):
27 | model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
28 | if weights:
29 | model.load(weights)
30 |
31 | return model
32 |
33 | def get_validator(self):
34 | self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
35 | return v8.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
36 |
37 | def criterion(self, preds, batch):
38 | if not hasattr(self, 'compute_loss'):
39 | self.compute_loss = SegLoss(de_parallel(self.model), overlap=self.args.overlap_mask)
40 | return self.compute_loss(preds, batch)
41 |
42 | def plot_training_samples(self, batch, ni):
43 | images = batch['img']
44 | masks = batch['masks']
45 | cls = batch['cls'].squeeze(-1)
46 | bboxes = batch['bboxes']
47 | paths = batch['im_file']
48 | batch_idx = batch['batch_idx']
49 | plot_images(images, batch_idx, cls, bboxes, masks, paths=paths, fname=self.save_dir / f'train_batch{ni}.jpg')
50 |
51 | def plot_metrics(self):
52 | plot_results(file=self.csv, segment=True) # save results.png
53 |
54 |
55 | # Criterion class for computing training losses
56 | class SegLoss(Loss):
57 |
58 | def __init__(self, model, overlap=True): # model must be de-paralleled
59 | super().__init__(model)
60 | self.nm = model.model[-1].nm # number of masks
61 | self.overlap = overlap
62 |
63 | def __call__(self, preds, batch):
64 | loss = torch.zeros(4, device=self.device) # box, seg, cls, dfl
65 | feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
66 | batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
67 | pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
68 | (self.reg_max * 4, self.nc), 1)
69 |
70 | # b, grids, ..
71 | pred_scores = pred_scores.permute(0, 2, 1).contiguous()
72 | pred_distri = pred_distri.permute(0, 2, 1).contiguous()
73 | pred_masks = pred_masks.permute(0, 2, 1).contiguous()
74 |
75 | dtype = pred_scores.dtype
76 | imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
77 | anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
78 |
79 | # targets
80 | try:
81 | batch_idx = batch['batch_idx'].view(-1, 1)
82 | targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
83 | targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
84 | gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
85 | mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
86 | except RuntimeError as e:
87 | raise TypeError('ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n'
88 | "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
89 | "i.e. 'yolo train model=yolov8n-seg.pt data=coco128.yaml'.\nVerify your dataset is a "
90 | "correctly formatted 'segment' dataset using 'data=coco128-seg.yaml' "
91 | 'as an example.\nSee https://docs.ultralytics.com/tasks/segment/ for help.') from e
92 |
93 | # pboxes
94 | pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
95 |
96 | _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
97 | pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
98 | anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
99 |
100 | target_scores_sum = max(target_scores.sum(), 1)
101 |
102 | # cls loss
103 | # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
104 | loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
105 |
106 | if fg_mask.sum():
107 | # bbox loss
108 | loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor,
109 | target_scores, target_scores_sum, fg_mask)
110 | # masks loss
111 | masks = batch['masks'].to(self.device).float()
112 | if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
113 | masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
114 |
115 | for i in range(batch_size):
116 | if fg_mask[i].sum():
117 | mask_idx = target_gt_idx[i][fg_mask[i]]
118 | if self.overlap:
119 | gt_mask = torch.where(masks[[i]] == (mask_idx + 1).view(-1, 1, 1), 1.0, 0.0)
120 | else:
121 | gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
122 | xyxyn = target_bboxes[i][fg_mask[i]] / imgsz[[1, 0, 1, 0]]
123 | marea = xyxy2xywh(xyxyn)[:, 2:].prod(1)
124 | mxyxy = xyxyn * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)
125 | loss[1] += self.single_mask_loss(gt_mask, pred_masks[i][fg_mask[i]], proto[i], mxyxy, marea) # seg
126 |
127 | # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
128 | else:
129 | loss[1] += proto.sum() * 0 + pred_masks.sum() * 0
130 |
131 | # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
132 | else:
133 | loss[1] += proto.sum() * 0 + pred_masks.sum() * 0
134 |
135 | loss[0] *= self.hyp.box # box gain
136 | loss[1] *= self.hyp.box / batch_size # seg gain
137 | loss[2] *= self.hyp.cls # cls gain
138 | loss[3] *= self.hyp.dfl # dfl gain
139 |
140 | return loss.sum() * batch_size, loss.detach() # loss(box, seg, cls, dfl)
141 |
142 | def single_mask_loss(self, gt_mask, pred, proto, xyxy, area):
143 | # Mask loss for one image
144 | pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:]) # (n, 32) @ (32,80,80) -> (n,80,80)
145 | loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
146 | return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
147 |
148 |
149 | def train(cfg=DEFAULT_CFG, use_python=False):
150 | model = cfg.model or 'yolov8n-seg.pt'
151 | data = cfg.data or 'coco128-seg.yaml' # segmentation dataset yaml
152 | device = cfg.device if cfg.device is not None else ''
153 |
154 | args = dict(model=model, data=data, device=device)
155 | if use_python:
156 | from ultralytics import YOLO
157 | YOLO(model).train(**args)
158 | else:
159 | trainer = SegmentationTrainer(overrides=args)
160 | trainer.train()
161 |
162 |
163 | if __name__ == '__main__':
164 | train()
165 |
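With overlap_mask=True the dataset encodes all instances of an image in a single mask whose pixel values are instance ids (0 = background), which is why the loss recovers per-instance binary masks with torch.where above. A small sketch of that decoding step on a toy 3x3 mask:

    import torch

    mask = torch.tensor([[0, 1, 1],
                         [2, 2, 0],
                         [0, 0, 2]])             # pixel values are instance ids, 0 = background
    mask_idx = torch.tensor([0, 1])              # matched ground-truth instance indices
    gt_mask = torch.where(mask == (mask_idx + 1).view(-1, 1, 1), 1.0, 0.0)
    print(gt_mask.shape)                         # torch.Size([2, 3, 3]), one binary mask per instance
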
--------------------------------------------------------------------------------
/userInfo.csv:
--------------------------------------------------------------------------------
1 | vscode,123
2 | 123123,MBL
3 | 123123l,mbl
4 | 123456,123456
5 | 123123456,123123456
6 | 123456789,123456789
7 | 789789,789789
8 | 147147,147147
9 | 148148,148148
10 |
--------------------------------------------------------------------------------
/utils/autobatch.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 | """
3 | Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch.
4 | """
5 |
6 | from copy import deepcopy
7 |
8 | import numpy as np
9 | import torch
10 |
11 | from ultralytics.yolo.utils import LOGGER, colorstr
12 | from ultralytics.yolo.utils.torch_utils import profile
13 |
14 |
15 | def check_train_batch_size(model, imgsz=640, amp=True):
16 | """
17 | Check YOLO training batch size using the autobatch() function.
18 |
19 | Args:
20 | model (torch.nn.Module): YOLO model to check batch size for.
21 | imgsz (int): Image size used for training.
22 | amp (bool): If True, use automatic mixed precision (AMP) for training.
23 |
24 | Returns:
25 | int: Optimal batch size computed using the autobatch() function.
26 | """
27 |
28 | with torch.cuda.amp.autocast(amp):
29 | return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
30 |
31 |
32 | def autobatch(model, imgsz=640, fraction=0.67, batch_size=16):
33 | """
34 | Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.
35 |
36 | Args:
37 | model: YOLO model to compute batch size for.
38 | imgsz (int, optional): The image size used as input for the YOLO model. Defaults to 640.
39 | fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.67.
40 | batch_size (int, optional): The default batch size to use if an error is detected. Defaults to 16.
41 |
42 | Returns:
43 | int: The optimal batch size.
44 | """
45 |
46 | # Check device
47 | prefix = colorstr('AutoBatch: ')
48 | LOGGER.info(f'{prefix}Computing optimal batch size for imgsz={imgsz}')
49 | device = next(model.parameters()).device # get model device
50 | if device.type == 'cpu':
51 | LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
52 | return batch_size
53 | if torch.backends.cudnn.benchmark:
54 | LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}')
55 | return batch_size
56 |
57 | # Inspect CUDA memory
58 | gb = 1 << 30 # bytes to GiB (1024 ** 3)
59 | d = str(device).upper() # 'CUDA:0'
60 | properties = torch.cuda.get_device_properties(device) # device properties
61 | t = properties.total_memory / gb # GiB total
62 | r = torch.cuda.memory_reserved(device) / gb # GiB reserved
63 | a = torch.cuda.memory_allocated(device) / gb # GiB allocated
64 | f = t - (r + a) # GiB free
65 | LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free')
66 |
67 | # Profile batch sizes
68 | batch_sizes = [1, 2, 4, 8, 16]
69 | try:
70 | img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
71 | results = profile(img, model, n=3, device=device)
72 |
73 | # Fit a solution
74 | y = [x[2] for x in results if x] # memory [2]
75 | p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit
76 | b = int((f * fraction - p[1]) / p[0]) # solve fitted line for batch size at target memory
77 | if None in results: # some sizes failed
78 | i = results.index(None) # first fail index
79 | if b >= batch_sizes[i]: # estimate at or above first failed batch size
80 | b = batch_sizes[max(i - 1, 0)] # select prior safe point
81 | if b < 1 or b > 1024: # b outside of safe range
82 | b = batch_size
83 | LOGGER.info(f'{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.')
84 |
85 | fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted
86 | LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
87 | return b
88 | except Exception as e:
89 | LOGGER.warning(f'{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.')
90 | return batch_size
91 |
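The estimate above is a first-degree polynomial fit of measured memory against batch size, solved for the batch size that lands at the target fraction of free memory. The same arithmetic in isolation (the memory figures are invented):

    import numpy as np

    batch_sizes = [1, 2, 4, 8, 16]
    mem = [0.5, 0.9, 1.7, 3.3, 6.5]              # hypothetical GiB consumed at each batch size
    p = np.polyfit(batch_sizes, mem, deg=1)      # mem ~= p[0] * batch + p[1]
    f, fraction = 10.0, 0.67                     # free GiB and target utilisation
    b = int((f * fraction - p[1]) / p[0])        # solve p[0]*b + p[1] = f * fraction
    print(b)                                     # 16 for these numbers
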
--------------------------------------------------------------------------------
/utils/benchmarks.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 | """
3 | Benchmark YOLO model formats for speed and accuracy
4 |
5 | Usage:
6 | from ultralytics.yolo.utils.benchmarks import run_benchmarks
7 | run_benchmarks(model='yolov8n.pt', imgsz=160)
8 |
9 | Format | `format=argument` | Model
10 | --- | --- | ---
11 | PyTorch | - | yolov8n.pt
12 | TorchScript | `torchscript` | yolov8n.torchscript
13 | ONNX | `onnx` | yolov8n.onnx
14 | OpenVINO | `openvino` | yolov8n_openvino_model/
15 | TensorRT | `engine` | yolov8n.engine
16 | CoreML | `coreml` | yolov8n.mlmodel
17 | TensorFlow SavedModel | `saved_model` | yolov8n_saved_model/
18 | TensorFlow GraphDef | `pb` | yolov8n.pb
19 | TensorFlow Lite | `tflite` | yolov8n.tflite
20 | TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite
21 | TensorFlow.js | `tfjs` | yolov8n_web_model/
22 | PaddlePaddle | `paddle` | yolov8n_paddle_model/
23 | """
24 |
25 | import platform
26 | import time
27 | from pathlib import Path
28 |
29 | from ultralytics import YOLO
30 | from ultralytics.yolo.engine.exporter import export_formats
31 | from ultralytics.yolo.utils import LINUX, LOGGER, MACOS, ROOT, SETTINGS
32 | from ultralytics.yolo.utils.checks import check_yolo
33 | from ultralytics.yolo.utils.downloads import download
34 | from ultralytics.yolo.utils.files import file_size
35 | from ultralytics.yolo.utils.torch_utils import select_device
36 |
37 |
38 | def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', imgsz=160, half=False, device='cpu', hard_fail=False):
39 | import pandas as pd
40 | pd.options.display.max_columns = 10
41 | pd.options.display.width = 120
42 | device = select_device(device, verbose=False)
43 | if isinstance(model, (str, Path)):
44 | model = YOLO(model)
45 |
46 | y = []
47 | t0 = time.time()
48 | for i, (name, format, suffix, cpu, gpu) in export_formats().iterrows(): # index, (name, format, suffix, CPU, GPU)
49 | emoji, filename = '❌', None # export defaults
50 | try:
51 | if model.task == 'classify':
52 | assert i != 11, 'paddle cls exports coming soon'
53 | assert i != 9 or LINUX, 'Edge TPU export only supported on Linux'
54 | if i == 10:
55 | assert MACOS or LINUX, 'TF.js export only supported on macOS and Linux'
56 | if 'cpu' in device.type:
57 | assert cpu, 'inference not supported on CPU'
58 | if 'cuda' in device.type:
59 | assert gpu, 'inference not supported on GPU'
60 |
61 | # Export
62 | if format == '-':
63 | filename = model.ckpt_path or model.cfg
64 | export = model # PyTorch format
65 | else:
66 | filename = model.export(imgsz=imgsz, format=format, half=half, device=device) # all others
67 | export = YOLO(filename, task=model.task)
68 | assert suffix in str(filename), 'export failed'
69 | emoji = '❎' # indicates export succeeded
70 |
71 | # Predict
72 | assert i not in (9, 10), 'inference not supported' # Edge TPU and TF.js are unsupported
73 | assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13' # CoreML
74 | if not (ROOT / 'assets/bus.jpg').exists():
75 | download(url='https://ultralytics.com/images/bus.jpg', dir=ROOT / 'assets')
76 | export.predict(ROOT / 'assets/bus.jpg', imgsz=imgsz, device=device, half=half)
77 |
78 | # Validate
79 | if model.task == 'detect':
80 | data, key = 'coco128.yaml', 'metrics/mAP50-95(B)'
81 | elif model.task == 'segment':
82 | data, key = 'coco128-seg.yaml', 'metrics/mAP50-95(M)'
83 | elif model.task == 'classify':
84 | data, key = 'imagenet100', 'metrics/accuracy_top5'
85 |
86 | results = export.val(data=data, batch=1, imgsz=imgsz, plots=False, device=device, half=half, verbose=False)
87 | metric, speed = results.results_dict[key], results.speed['inference']
88 | y.append([name, '✅', round(file_size(filename), 1), round(metric, 4), round(speed, 2)])
89 | except Exception as e:
90 | if hard_fail:
91 | assert type(e) is AssertionError, f'Benchmark hard_fail for {name}: {e}'
92 | LOGGER.warning(f'ERROR ❌️ Benchmark failure for {name}: {e}')
93 | y.append([name, emoji, round(file_size(filename), 1), None, None]) # mAP, t_inference
94 |
95 | # Print results
96 | check_yolo(device=device) # print system info
97 | df = pd.DataFrame(y, columns=['Format', 'Status❔', 'Size (MB)', key, 'Inference time (ms/im)'])
98 |
99 | name = Path(model.ckpt_path).name
100 | s = f'\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n'
101 | LOGGER.info(s)
102 | with open('benchmarks.log', 'a', errors='ignore', encoding='utf-8') as f:
103 | f.write(s)
104 |
105 | if hard_fail and isinstance(hard_fail, float):
106 | metrics = df[key].array # values to compare to floor
107 | floor = hard_fail # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
108 | assert all(x > floor for x in metrics if pd.notna(x)), f'HARD FAIL: one or more metric(s) < floor {floor}'
109 |
110 | return df
111 |
112 |
113 | if __name__ == '__main__':
114 | benchmark()
115 |
--------------------------------------------------------------------------------
/utils/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import add_integration_callbacks, default_callbacks
2 |
3 | __all__ = 'add_integration_callbacks', 'default_callbacks'
4 |
--------------------------------------------------------------------------------
/utils/callbacks/base.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 | """
3 | Base callbacks
4 | """
5 |
6 |
7 | # Trainer callbacks ----------------------------------------------------------------------------------------------------
8 | def on_pretrain_routine_start(trainer):
9 | pass
10 |
11 |
12 | def on_pretrain_routine_end(trainer):
13 | pass
14 |
15 |
16 | def on_train_start(trainer):
17 | pass
18 |
19 |
20 | def on_train_epoch_start(trainer):
21 | pass
22 |
23 |
24 | def on_train_batch_start(trainer):
25 | pass
26 |
27 |
28 | def optimizer_step(trainer):
29 | pass
30 |
31 |
32 | def on_before_zero_grad(trainer):
33 | pass
34 |
35 |
36 | def on_train_batch_end(trainer):
37 | pass
38 |
39 |
40 | def on_train_epoch_end(trainer):
41 | pass
42 |
43 |
44 | def on_fit_epoch_end(trainer):
45 | pass
46 |
47 |
48 | def on_model_save(trainer):
49 | pass
50 |
51 |
52 | def on_train_end(trainer):
53 | pass
54 |
55 |
56 | def on_params_update(trainer):
57 | pass
58 |
59 |
60 | def teardown(trainer):
61 | pass
62 |
63 |
64 | # Validator callbacks --------------------------------------------------------------------------------------------------
65 | def on_val_start(validator):
66 | pass
67 |
68 |
69 | def on_val_batch_start(validator):
70 | pass
71 |
72 |
73 | def on_val_batch_end(validator):
74 | pass
75 |
76 |
77 | def on_val_end(validator):
78 | pass
79 |
80 |
81 | # Predictor callbacks --------------------------------------------------------------------------------------------------
82 | def on_predict_start(predictor):
83 | pass
84 |
85 |
86 | def on_predict_batch_start(predictor):
87 | pass
88 |
89 |
90 | def on_predict_batch_end(predictor):
91 | pass
92 |
93 |
94 | def on_predict_postprocess_end(predictor):
95 | pass
96 |
97 |
98 | def on_predict_end(predictor):
99 | pass
100 |
101 |
102 | # Exporter callbacks ---------------------------------------------------------------------------------------------------
103 | def on_export_start(exporter):
104 | pass
105 |
106 |
107 | def on_export_end(exporter):
108 | pass
109 |
110 |
111 | default_callbacks = {
112 | # Run in trainer
113 | 'on_pretrain_routine_start': [on_pretrain_routine_start],
114 | 'on_pretrain_routine_end': [on_pretrain_routine_end],
115 | 'on_train_start': [on_train_start],
116 | 'on_train_epoch_start': [on_train_epoch_start],
117 | 'on_train_batch_start': [on_train_batch_start],
118 | 'optimizer_step': [optimizer_step],
119 | 'on_before_zero_grad': [on_before_zero_grad],
120 | 'on_train_batch_end': [on_train_batch_end],
121 | 'on_train_epoch_end': [on_train_epoch_end],
122 | 'on_fit_epoch_end': [on_fit_epoch_end], # fit = train + val
123 | 'on_model_save': [on_model_save],
124 | 'on_train_end': [on_train_end],
125 | 'on_params_update': [on_params_update],
126 | 'teardown': [teardown],
127 |
128 | # Run in validator
129 | 'on_val_start': [on_val_start],
130 | 'on_val_batch_start': [on_val_batch_start],
131 | 'on_val_batch_end': [on_val_batch_end],
132 | 'on_val_end': [on_val_end],
133 |
134 | # Run in predictor
135 | 'on_predict_start': [on_predict_start],
136 | 'on_predict_batch_start': [on_predict_batch_start],
137 | 'on_predict_postprocess_end': [on_predict_postprocess_end],
138 | 'on_predict_batch_end': [on_predict_batch_end],
139 | 'on_predict_end': [on_predict_end],
140 |
141 | # Run in exporter
142 | 'on_export_start': [on_export_start],
143 | 'on_export_end': [on_export_end]}
144 |
145 |
146 | def add_integration_callbacks(instance):
147 | from .clearml import callbacks as clearml_callbacks
148 | from .comet import callbacks as comet_callbacks
149 | from .hub import callbacks as hub_callbacks
150 | from .tensorboard import callbacks as tb_callbacks
151 |
152 | for x in clearml_callbacks, comet_callbacks, hub_callbacks, tb_callbacks:
153 | for k, v in x.items():
154 | if v not in instance.callbacks[k]: # prevent duplicate callbacks addition
155 | instance.callbacks[k].append(v) # callback[name].append(func)
156 |
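Since default_callbacks maps each event name to a list of callables, registering a custom hook is just an append; events are then dispatched by iterating the list. A usage sketch (the per-run copy step is assumed to mirror how a trainer seeds its own callbacks dict, which is not shown here):

    from ultralytics.yolo.utils.callbacks import default_callbacks

    def log_epoch(trainer):                      # any callable taking the emitting object
        print(f'finished epoch {trainer.epoch}')

    callbacks = {k: list(v) for k, v in default_callbacks.items()}  # per-run copy
    callbacks['on_train_epoch_end'].append(log_epoch)               # register the hook
    # dispatch then amounts to: for f in callbacks[event]: f(trainer)
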
--------------------------------------------------------------------------------
/utils/callbacks/clearml.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 | from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING
3 | from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
4 |
5 | try:
6 | import clearml
7 | from clearml import Task
8 |
9 | assert clearml.__version__ # verify package is not directory
10 | assert not TESTS_RUNNING # do not log pytest
11 | except (ImportError, AssertionError):
12 | clearml = None
13 |
14 |
15 | def _log_images(imgs_dict, group='', step=0):
16 | task = Task.current_task()
17 | if task:
18 | for k, v in imgs_dict.items():
19 | task.get_logger().report_image(group, k, step, v)
20 |
21 |
22 | def on_pretrain_routine_start(trainer):
23 | try:
24 | task = Task.init(project_name=trainer.args.project or 'YOLOv8',
25 | task_name=trainer.args.name,
26 | tags=['YOLOv8'],
27 | output_uri=True,
28 | reuse_last_task_id=False,
29 | auto_connect_frameworks={'pytorch': False})
30 | task.connect(vars(trainer.args), name='General')
31 | except Exception as e:
32 | LOGGER.warning(f'WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}')
33 |
34 |
35 | def on_train_epoch_end(trainer):
36 | if trainer.epoch == 1:
37 | _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic', trainer.epoch)
38 |
39 |
40 | def on_fit_epoch_end(trainer):
41 | task = Task.current_task()
42 | if task and trainer.epoch == 0:
43 | model_info = {
44 | 'model/parameters': get_num_params(trainer.model),
45 | 'model/GFLOPs': round(get_flops(trainer.model), 3),
46 | 'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
47 | task.connect(model_info, name='Model')
48 |
49 |
50 | def on_train_end(trainer):
51 | task = Task.current_task()
52 | if task:
53 | task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False)
54 |
55 |
56 | callbacks = {
57 | 'on_pretrain_routine_start': on_pretrain_routine_start,
58 | 'on_train_epoch_end': on_train_epoch_end,
59 | 'on_fit_epoch_end': on_fit_epoch_end,
60 | 'on_train_end': on_train_end} if clearml else {}
61 |
--------------------------------------------------------------------------------
/utils/callbacks/comet.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 | from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING
3 | from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
4 |
5 | try:
6 | import comet_ml
7 |
8 | assert not TESTS_RUNNING # do not log pytest
9 | assert comet_ml.__version__ # verify package is not directory
10 | except (ImportError, AssertionError):
11 | comet_ml = None
12 |
13 |
14 | def on_pretrain_routine_start(trainer):
15 | try:
16 | experiment = comet_ml.Experiment(project_name=trainer.args.project or 'YOLOv8')
17 | experiment.log_parameters(vars(trainer.args))
18 | except Exception as e:
19 | LOGGER.warning(f'WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}')
20 |
21 |
22 | def on_train_epoch_end(trainer):
23 | experiment = comet_ml.get_global_experiment()
24 | if experiment:
25 | experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1)
26 | if trainer.epoch == 1:
27 | for f in trainer.save_dir.glob('train_batch*.jpg'):
28 | experiment.log_image(f, name=f.stem, step=trainer.epoch + 1)
29 |
30 |
31 | def on_fit_epoch_end(trainer):
32 | experiment = comet_ml.get_global_experiment()
33 | if experiment:
34 | experiment.log_metrics(trainer.metrics, step=trainer.epoch + 1)
35 | if trainer.epoch == 0:
36 | model_info = {
37 | 'model/parameters': get_num_params(trainer.model),
38 | 'model/GFLOPs': round(get_flops(trainer.model), 3),
39 | 'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
40 | experiment.log_metrics(model_info, step=trainer.epoch + 1)
41 |
42 |
43 | def on_train_end(trainer):
44 | experiment = comet_ml.get_global_experiment()
45 | if experiment:
46 | experiment.log_model('YOLOv8', file_or_folder=str(trainer.best), file_name='best.pt', overwrite=True)
47 |
48 |
49 | callbacks = {
50 | 'on_pretrain_routine_start': on_pretrain_routine_start,
51 | 'on_train_epoch_end': on_train_epoch_end,
52 | 'on_fit_epoch_end': on_fit_epoch_end,
53 | 'on_train_end': on_train_end} if comet_ml else {}
54 |
--------------------------------------------------------------------------------
/utils/callbacks/hub.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | import json
4 | from time import time
5 |
6 | from ultralytics.hub.utils import PREFIX, traces
7 | from ultralytics.yolo.utils import LOGGER
8 | from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
9 |
10 |
11 | def on_pretrain_routine_end(trainer):
12 | session = getattr(trainer, 'hub_session', None)
13 | if session:
14 | # Start timer for upload rate limit
15 | LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀')
16 | session.timers = {'metrics': time(), 'ckpt': time()} # start timer on session.rate_limit
17 |
18 |
19 | def on_fit_epoch_end(trainer):
20 | session = getattr(trainer, 'hub_session', None)
21 | if session:
22 | # Upload metrics after val end
23 | all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics}
24 | if trainer.epoch == 0:
25 | model_info = {
26 | 'model/parameters': get_num_params(trainer.model),
27 | 'model/GFLOPs': round(get_flops(trainer.model), 3),
28 | 'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
29 | all_plots = {**all_plots, **model_info}
30 | session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
31 | if time() - session.timers['metrics'] > session.rate_limits['metrics']:
32 | session.upload_metrics()
33 | session.timers['metrics'] = time() # reset timer
34 | session.metrics_queue = {} # reset queue
35 |
36 |
37 | def on_model_save(trainer):
38 | session = getattr(trainer, 'hub_session', None)
39 | if session:
40 | # Upload checkpoints with rate limiting
41 | is_best = trainer.best_fitness == trainer.fitness
42 | if time() - session.timers['ckpt'] > session.rate_limits['ckpt']:
43 | LOGGER.info(f'{PREFIX}Uploading checkpoint {session.model_id}')
44 | session.upload_model(trainer.epoch, trainer.last, is_best)
45 | session.timers['ckpt'] = time() # reset timer
46 |
47 |
48 | def on_train_end(trainer):
49 | session = getattr(trainer, 'hub_session', None)
50 | if session:
51 | # Upload final model and metrics with exponential backoff
52 | LOGGER.info(f'{PREFIX}Syncing final model...')
53 | session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True)
54 | session.alive = False # stop heartbeats
55 | LOGGER.info(f'{PREFIX}Done ✅\n'
56 | f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀')
57 |
58 |
59 | def on_train_start(trainer):
60 | traces(trainer.args, traces_sample_rate=1.0)
61 |
62 |
63 | def on_val_start(validator):
64 | traces(validator.args, traces_sample_rate=1.0)
65 |
66 |
67 | def on_predict_start(predictor):
68 | traces(predictor.args, traces_sample_rate=1.0)
69 |
70 |
71 | def on_export_start(exporter):
72 | traces(exporter.args, traces_sample_rate=1.0)
73 |
74 |
75 | callbacks = {
76 | 'on_pretrain_routine_end': on_pretrain_routine_end,
77 | 'on_fit_epoch_end': on_fit_epoch_end,
78 | 'on_model_save': on_model_save,
79 | 'on_train_end': on_train_end,
80 | 'on_train_start': on_train_start,
81 | 'on_val_start': on_val_start,
82 | 'on_predict_start': on_predict_start,
83 | 'on_export_start': on_export_start}
84 |
--------------------------------------------------------------------------------
/utils/callbacks/tensorboard.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 | from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING, colorstr
3 |
4 | try:
5 | from torch.utils.tensorboard import SummaryWriter
6 |
7 | assert not TESTS_RUNNING # do not log pytest
8 | except (ImportError, AssertionError):
9 | SummaryWriter = None
10 |
11 | writer = None # TensorBoard SummaryWriter instance
12 |
13 |
14 | def _log_scalars(scalars, step=0):
15 | if writer:
16 | for k, v in scalars.items():
17 | writer.add_scalar(k, v, step)
18 |
19 |
20 | def on_pretrain_routine_start(trainer):
21 | if SummaryWriter:
22 | try:
23 | global writer
24 | writer = SummaryWriter(str(trainer.save_dir))
25 | prefix = colorstr('TensorBoard: ')
26 | LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
27 | except Exception as e:
28 | LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')
29 |
30 |
31 | def on_batch_end(trainer):
32 | _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
33 |
34 |
35 | def on_fit_epoch_end(trainer):
36 | _log_scalars(trainer.metrics, trainer.epoch + 1)
37 |
38 |
39 | callbacks = {
40 | 'on_pretrain_routine_start': on_pretrain_routine_start,
41 | 'on_fit_epoch_end': on_fit_epoch_end,
42 | 'on_batch_end': on_batch_end}
43 |
--------------------------------------------------------------------------------
/utils/dist.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | import os
4 | import re
5 | import shutil
6 | import socket
7 | import sys
8 | import tempfile
9 | from pathlib import Path
10 |
11 | from . import USER_CONFIG_DIR
12 | from .torch_utils import TORCH_1_9
13 |
14 |
15 | def find_free_network_port() -> int:
16 | """Finds a free port on localhost.
17 |
18 | It is useful in single-node training when we don't want to connect to a real main node but have to set the
19 | `MASTER_PORT` environment variable.
20 | """
21 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
22 | s.bind(('127.0.0.1', 0))
23 | return s.getsockname()[1] # port
24 |
25 |
26 | def generate_ddp_file(trainer):
27 | module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1)
28 |
29 | content = f'''cfg = {vars(trainer.args)} \nif __name__ == "__main__":
30 | from {module} import {name}
31 |
32 | trainer = {name}(cfg=cfg)
33 | trainer.train()'''
34 | (USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
35 | with tempfile.NamedTemporaryFile(prefix='_temp_',
36 | suffix=f'{id(trainer)}.py',
37 | mode='w+',
38 | encoding='utf-8',
39 | dir=USER_CONFIG_DIR / 'DDP',
40 | delete=False) as file:
41 | file.write(content)
42 | return file.name
43 |
44 |
45 | def generate_ddp_command(world_size, trainer):
46 | import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
47 | if not trainer.resume:
48 | shutil.rmtree(trainer.save_dir) # remove the save_dir
49 | file = str(Path(sys.argv[0]).resolve())
50 | safe_pattern = re.compile(r'^[a-zA-Z0-9_. /\\-]{1,128}$') # allowed characters and maximum of 128 characters
51 | if not (safe_pattern.match(file) and Path(file).exists() and file.endswith('.py')): # using CLI
52 | file = generate_ddp_file(trainer)
53 | dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
54 | port = find_free_network_port()
55 | exclude_args = ['save_dir']
56 | args = [f'{k}={v}' for k, v in vars(trainer.args).items() if k not in exclude_args]
57 | cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file] + args
58 | return cmd, file
59 |
60 |
61 | def ddp_cleanup(trainer, file):
62 | # delete temp file if created
63 | if f'{id(trainer)}.py' in file: # if temp_file suffix in file
64 | os.remove(file)
65 |
--------------------------------------------------------------------------------
/utils/files.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | import contextlib
4 | import glob
5 | import os
6 | import urllib
7 | from datetime import datetime
8 | from pathlib import Path
9 |
10 |
11 | class WorkingDirectory(contextlib.ContextDecorator):
12 | # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
13 | def __init__(self, new_dir):
14 | self.dir = new_dir # new dir
15 | self.cwd = Path.cwd().resolve() # current dir
16 |
17 | def __enter__(self):
18 | os.chdir(self.dir)
19 |
20 | def __exit__(self, exc_type, exc_val, exc_tb):
21 | os.chdir(self.cwd)
22 |
23 |
24 | def increment_path(path, exist_ok=False, sep='', mkdir=False):
25 | """
26 | Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
27 |
28 | If the path exists and exist_ok is not set to True, the path will be incremented by appending a number and sep to
29 | the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the
30 | number will be appended directly to the end of the path. If mkdir is set to True, the path will be created as a
31 | directory if it does not already exist.
32 |
33 | Args:
34 | path (str or pathlib.Path): Path to increment.
35 | exist_ok (bool, optional): If True, the path will not be incremented and will be returned as-is. Defaults to False.
36 | sep (str, optional): Separator to use between the path and the incrementation number. Defaults to an empty string.
37 | mkdir (bool, optional): If True, the path will be created as a directory if it does not exist. Defaults to False.
38 |
39 | Returns:
40 | pathlib.Path: Incremented path.
41 | """
42 | path = Path(path) # os-agnostic
43 | if path.exists() and not exist_ok:
44 | path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
45 |
46 | # Find the first free incremented path
47 | for n in range(2, 9999):
48 | p = f'{path}{sep}{n}{suffix}' # increment path
49 | if not os.path.exists(p): # first free slot found
50 | break
51 | path = Path(p)
52 |
53 | if mkdir:
54 | path.mkdir(parents=True, exist_ok=True) # make directory
55 |
56 | return path
57 |
58 |
59 | def file_age(path=__file__):
60 | # Return days since last file update
61 | dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
62 | return dt.days # + dt.seconds / 86400 # fractional days
63 |
64 |
65 | def file_date(path=__file__):
66 | # Return human-readable file modification date, i.e. '2021-3-26'
67 | t = datetime.fromtimestamp(Path(path).stat().st_mtime)
68 | return f'{t.year}-{t.month}-{t.day}'
69 |
70 |
71 | def file_size(path):
72 | # Return file/dir size (MB)
73 | if isinstance(path, (str, Path)):
74 | mb = 1 << 20 # bytes to MiB (1024 ** 2)
75 | path = Path(path)
76 | if path.is_file():
77 | return path.stat().st_size / mb
78 | elif path.is_dir():
79 | return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
80 | return 0.0
81 |
82 |
83 | def url2file(url):
84 | # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
85 | url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
86 | return Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
87 |
88 |
89 | def get_latest_run(search_dir='.'):
90 | # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
91 | last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
92 | return max(last_list, key=os.path.getctime) if last_list else ''
93 |
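increment_path is what produces the runs/exp, runs/exp2, runs/exp3, ... directories. A quick usage sketch (the import path assumes the vendored ultralytics package layout used elsewhere in this repo):

    from ultralytics.yolo.utils.files import increment_path

    # Returns 'runs/exp' unchanged if it is free, otherwise the first available
    # 'runs/exp{n}'; mkdir=True would also create the directory.
    save_dir = increment_path('runs/exp', exist_ok=False, mkdir=False)
    print(save_dir)
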
--------------------------------------------------------------------------------
/utils/id_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Modified by: Ruihao
3 | # @ProjectName:yolov5-pyqt5
4 |
5 | '''
6 | Shared helpers for reading and writing account credentials.
7 | '''
8 | import csv
9 |
10 | # Append account credentials to the CSV file
11 | def sava_id_info(user, pwd):
12 | headers = ['name', 'key']
13 | values = [{'name':user, 'key':pwd}]
14 | with open('userInfo.csv', 'a', encoding='utf-8', newline='') as fp:
15 | writer = csv.DictWriter(fp, headers)
16 | writer.writerows(values)
17 |
18 | # Read account credentials from the CSV file
19 | def get_id_info():
20 | USER_PWD = {}
21 |     with open('userInfo.csv', 'r', encoding='utf-8') as csvfile:  # resolved relative to the project root
22 | spamreader = csv.reader(csvfile)
23 |         # Iterate over the rows, storing username/password pairs in a dict
24 | for row in spamreader:
25 | USER_PWD[row[0]] = row[1]
26 | return USER_PWD
27 |
--------------------------------------------------------------------------------
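
A hedged round-trip sketch of the two helpers above; run it from the project root, since `userInfo.csv` is resolved relative to the working directory (the credentials are illustrative):

```python
from utils.id_utils import get_id_info, sava_id_info  # note: 'sava' is the repo's spelling

sava_id_info('alice', 's3cret')           # appends one name/key row to userInfo.csv
accounts = get_id_info()                  # e.g. {'alice': 's3cret', ...}
assert accounts.get('alice') == 's3cret'
```
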
/utils/loss.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from .metrics import bbox_iou
8 | from .tal import bbox2dist
9 |
10 |
11 | class VarifocalLoss(nn.Module):
12 | # Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367
13 | def __init__(self):
14 | super().__init__()
15 |
16 | def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0):
17 | weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
18 | with torch.cuda.amp.autocast(enabled=False):
19 | loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
20 | weight).sum()
21 | return loss
22 |
23 |
24 | class BboxLoss(nn.Module):
25 |
26 | def __init__(self, reg_max, use_dfl=False):
27 | super().__init__()
28 | self.reg_max = reg_max
29 | self.use_dfl = use_dfl
30 |
31 | def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
32 | # IoU loss
33 | weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
34 | iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
35 | loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
36 |
37 | # DFL loss
38 | if self.use_dfl:
39 | target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max)
40 | loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
41 | loss_dfl = loss_dfl.sum() / target_scores_sum
42 | else:
43 | loss_dfl = torch.tensor(0.0).to(pred_dist.device)
44 |
45 | return loss_iou, loss_dfl
46 |
47 | @staticmethod
48 | def _df_loss(pred_dist, target):
49 | # Return sum of left and right DFL losses
50 | # Distribution Focal Loss (DFL) proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
51 | tl = target.long() # target left
52 | tr = tl + 1 # target right
53 | wl = tr - target # weight left
54 | wr = 1 - wl # weight right
55 | return (F.cross_entropy(pred_dist, tl.view(-1), reduction='none').view(tl.shape) * wl +
56 | F.cross_entropy(pred_dist, tr.view(-1), reduction='none').view(tl.shape) * wr).mean(-1, keepdim=True)
57 |
--------------------------------------------------------------------------------
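
The `_df_loss` interpolation above can be checked with a standalone numeric sketch: a continuous regression target is split between its two neighbouring integer bins, each weighted by proximity (all values here are illustrative):

```python
import torch
import torch.nn.functional as F

reg_max = 16
pred_dist = torch.randn(1, reg_max + 1)  # logits over the discrete bins 0..reg_max
target = torch.tensor([7.3])             # continuous distance target

tl = target.long()  # left bin, 7
tr = tl + 1         # right bin, 8
wl = tr - target    # left weight, 0.7
wr = 1 - wl         # right weight, 0.3
loss = (F.cross_entropy(pred_dist, tl, reduction='none') * wl +
        F.cross_entropy(pred_dist, tr, reduction='none') * wr)  # one scalar per target
```
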
/v8/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | from ultralytics.yolo.v8 import classify, detect, segment
4 |
5 | __all__ = 'classify', 'segment', 'detect'
6 |
--------------------------------------------------------------------------------
/v8/classify/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | from ultralytics.yolo.v8.classify.predict import ClassificationPredictor, predict
4 | from ultralytics.yolo.v8.classify.train import ClassificationTrainer, train
5 | from ultralytics.yolo.v8.classify.val import ClassificationValidator, val
6 |
7 | __all__ = 'ClassificationPredictor', 'predict', 'ClassificationTrainer', 'train', 'ClassificationValidator', 'val'
8 |
--------------------------------------------------------------------------------
/v8/classify/predict.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | import torch
4 |
5 | from ultralytics.yolo.engine.predictor import BasePredictor
6 | from ultralytics.yolo.engine.results import Results
7 | from ultralytics.yolo.utils import DEFAULT_CFG, ROOT
8 | from ultralytics.yolo.utils.plotting import Annotator
9 |
10 |
11 | class ClassificationPredictor(BasePredictor):
12 |
13 | def get_annotator(self, img):
14 | return Annotator(img, example=str(self.model.names), pil=True)
15 |
16 | def preprocess(self, img):
17 | img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
18 | return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
19 |
20 | def postprocess(self, preds, img, orig_imgs):
21 | results = []
22 | for i, pred in enumerate(preds):
23 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
24 | path, _, _, _, _ = self.batch
25 | img_path = path[i] if isinstance(path, list) else path
26 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, probs=pred))
27 |
28 | return results
29 |
30 | def write_results(self, idx, results, batch):
31 | p, im, im0 = batch
32 | log_string = ''
33 | if len(im.shape) == 3:
34 | im = im[None] # expand for batch dim
35 | self.seen += 1
36 | im0 = im0.copy()
37 | if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
38 | log_string += f'{idx}: '
39 | frame = self.dataset.count
40 | else:
41 | frame = getattr(self.dataset, 'frame', 0)
42 |
43 | self.data_path = p
44 | # save_path = str(self.save_dir / p.name) # im.jpg
45 | self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
46 | log_string += '%gx%g ' % im.shape[2:] # print string
47 | self.annotator = self.get_annotator(im0)
48 |
49 | result = results[idx]
50 | if len(result) == 0:
51 | return log_string
52 | prob = result.probs
53 | # Print results
54 | n5 = min(len(self.model.names), 5)
55 | top5i = prob.argsort(0, descending=True)[:n5].tolist() # top 5 indices
56 | log_string += f"{', '.join(f'{self.model.names[j]} {prob[j]:.2f}' for j in top5i)}, "
57 |
58 |         # Write results
59 |         text = '\n'.join(f'{prob[j]:.2f} {self.model.names[j]}' for j in top5i)
60 |         if self.args.save or self.args.show:  # Overlay top-5 text on the image
61 | self.annotator.text((32, 32), text, txt_color=(255, 255, 255))
62 | if self.args.save_txt: # Write to file
63 | with open(f'{self.txt_path}.txt', 'a') as f:
64 | f.write(text + '\n')
65 |
66 | return log_string
67 |
68 |
69 | def predict(cfg=DEFAULT_CFG, use_python=False):
70 | model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
71 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
72 | else 'https://ultralytics.com/images/bus.jpg'
73 |
74 | args = dict(model=model, source=source)
75 | if use_python:
76 | from ultralytics import YOLO
77 | YOLO(model)(**args)
78 | else:
79 | predictor = ClassificationPredictor(overrides=args)
80 | predictor.predict_cli()
81 |
82 |
83 | if __name__ == '__main__':
84 | predict()
85 |
--------------------------------------------------------------------------------
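
The top-5 selection in `write_results` above reduces to a single `argsort`; a minimal standalone sketch (class names and scores are illustrative):

```python
import torch

names = {0: 'cat', 1: 'dog', 2: 'bird'}  # illustrative class names
prob = torch.tensor([0.1, 0.7, 0.2])     # one image's class scores
n5 = min(len(names), 5)
top5i = prob.argsort(0, descending=True)[:n5].tolist()  # [1, 2, 0]
print(', '.join(f'{names[j]} {prob[j]:.2f}' for j in top5i))  # dog 0.70, bird 0.20, cat 0.10
```
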
/v8/classify/train.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | import torch
4 | import torchvision
5 |
6 | from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
7 | from ultralytics.yolo import v8
8 | from ultralytics.yolo.data import build_classification_dataloader
9 | from ultralytics.yolo.engine.trainer import BaseTrainer
10 | from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
11 | from ultralytics.yolo.utils.torch_utils import is_parallel, strip_optimizer
12 |
13 |
14 | class ClassificationTrainer(BaseTrainer):
15 |
16 | def __init__(self, cfg=DEFAULT_CFG, overrides=None):
17 | if overrides is None:
18 | overrides = {}
19 | overrides['task'] = 'classify'
20 | super().__init__(cfg, overrides)
21 |
22 | def set_model_attributes(self):
23 | self.model.names = self.data['names']
24 |
25 | def get_model(self, cfg=None, weights=None, verbose=True):
26 | model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
27 | if weights:
28 | model.load(weights)
29 |
30 | pretrained = False
31 | for m in model.modules():
32 | if not pretrained and hasattr(m, 'reset_parameters'):
33 | m.reset_parameters()
34 | if isinstance(m, torch.nn.Dropout) and self.args.dropout:
35 | m.p = self.args.dropout # set dropout
36 | for p in model.parameters():
37 | p.requires_grad = True # for training
38 |
39 | # Update defaults
40 | if self.args.imgsz == 640:
41 | self.args.imgsz = 224
42 |
43 | return model
44 |
45 | def setup_model(self):
46 | """
47 |         Load, create, or download the model for any task.
48 | """
49 | # classification models require special handling
50 |
51 | if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
52 | return
53 |
54 | model = str(self.model)
55 | # Load a YOLO model locally, from torchvision, or from Ultralytics assets
56 | if model.endswith('.pt'):
57 | self.model, _ = attempt_load_one_weight(model, device='cpu')
58 | for p in self.model.parameters():
59 | p.requires_grad = True # for training
60 | elif model.endswith('.yaml'):
61 | self.model = self.get_model(cfg=model)
62 | elif model in torchvision.models.__dict__:
63 | pretrained = True
64 | self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
65 | else:
66 |             raise FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
67 | ClassificationModel.reshape_outputs(self.model, self.data['nc'])
68 |
69 |         return  # don't return ckpt; classification doesn't support resume
70 |
71 | def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
72 | loader = build_classification_dataloader(path=dataset_path,
73 | imgsz=self.args.imgsz,
74 | batch_size=batch_size if mode == 'train' else (batch_size * 2),
75 | augment=mode == 'train',
76 | rank=rank,
77 | workers=self.args.workers)
78 | # Attach inference transforms
79 | if mode != 'train':
80 | if is_parallel(self.model):
81 | self.model.module.transforms = loader.dataset.torch_transforms
82 | else:
83 | self.model.transforms = loader.dataset.torch_transforms
84 | return loader
85 |
86 | def preprocess_batch(self, batch):
87 | batch['img'] = batch['img'].to(self.device)
88 | batch['cls'] = batch['cls'].to(self.device)
89 | return batch
90 |
91 | def progress_string(self):
92 | return ('\n' + '%11s' * (4 + len(self.loss_names))) % \
93 | ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
94 |
95 | def get_validator(self):
96 | self.loss_names = ['loss']
97 | return v8.classify.ClassificationValidator(self.test_loader, self.save_dir)
98 |
99 | def criterion(self, preds, batch):
100 | loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / self.args.nbs
101 | loss_items = loss.detach()
102 | return loss, loss_items
103 |
104 | # def label_loss_items(self, loss_items=None, prefix="train"):
105 | # """
106 | # Returns a loss dict with labelled training loss items tensor
107 | # """
108 | # # Not needed for classification but necessary for segmentation & detection
109 | # keys = [f"{prefix}/{x}" for x in self.loss_names]
110 | # if loss_items is not None:
111 | # loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
112 | # return dict(zip(keys, loss_items))
113 | # else:
114 | # return keys
115 |
116 | def label_loss_items(self, loss_items=None, prefix='train'):
117 | """
118 | Returns a loss dict with labelled training loss items tensor
119 | """
120 | # Not needed for classification but necessary for segmentation & detection
121 | keys = [f'{prefix}/{x}' for x in self.loss_names]
122 | if loss_items is None:
123 | return keys
124 | loss_items = [round(float(loss_items), 5)]
125 | return dict(zip(keys, loss_items))
126 |
127 | def resume_training(self, ckpt):
128 | pass
129 |
130 | def final_eval(self):
131 | for f in self.last, self.best:
132 | if f.exists():
133 | strip_optimizer(f) # strip optimizers
134 | # TODO: validate best.pt after training completes
135 | # if f is self.best:
136 | # LOGGER.info(f'\nValidating {f}...')
137 | # self.validator.args.save_json = True
138 | # self.metrics = self.validator(model=f)
139 | # self.metrics.pop('fitness', None)
140 | # self.run_callbacks('on_fit_epoch_end')
141 | LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
142 |
143 |
144 | def train(cfg=DEFAULT_CFG, use_python=False):
145 | model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
146 | data = cfg.data or 'mnist160' # or yolo.ClassificationDataset("mnist")
147 | device = cfg.device if cfg.device is not None else ''
148 |
149 | args = dict(model=model, data=data, device=device)
150 | if use_python:
151 | from ultralytics import YOLO
152 | YOLO(model).train(**args)
153 | else:
154 | trainer = ClassificationTrainer(overrides=args)
155 | trainer.train()
156 |
157 |
158 | if __name__ == '__main__':
159 | train()
160 |
--------------------------------------------------------------------------------
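
The `criterion` above is plain cross-entropy summed over the batch and scaled by the nominal batch size `nbs` (64 in this repo's `default.yaml`); a standalone numeric sketch:

```python
import torch
import torch.nn.functional as F

preds = torch.randn(4, 10)        # logits for 4 images, 10 classes
cls = torch.randint(0, 10, (4,))  # ground-truth class indices
nbs = 64                          # nominal batch size (self.args.nbs in the trainer)
loss = F.cross_entropy(preds, cls, reduction='sum') / nbs
loss_items = loss.detach()        # the scalar that label_loss_items rounds and labels
```
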
/v8/classify/val.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | from ultralytics.yolo.data import build_classification_dataloader
4 | from ultralytics.yolo.engine.validator import BaseValidator
5 | from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER
6 | from ultralytics.yolo.utils.metrics import ClassifyMetrics
7 |
8 |
9 | class ClassificationValidator(BaseValidator):
10 |
11 | def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
12 | super().__init__(dataloader, save_dir, pbar, args)
13 | self.args.task = 'classify'
14 | self.metrics = ClassifyMetrics()
15 |
16 | def get_desc(self):
17 | return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc')
18 |
19 | def init_metrics(self, model):
20 | self.pred = []
21 | self.targets = []
22 |
23 | def preprocess(self, batch):
24 | batch['img'] = batch['img'].to(self.device, non_blocking=True)
25 | batch['img'] = batch['img'].half() if self.args.half else batch['img'].float()
26 | batch['cls'] = batch['cls'].to(self.device)
27 | return batch
28 |
29 | def update_metrics(self, preds, batch):
30 | n5 = min(len(self.model.names), 5)
31 | self.pred.append(preds.argsort(1, descending=True)[:, :n5])
32 | self.targets.append(batch['cls'])
33 |
34 | def finalize_metrics(self, *args, **kwargs):
35 | self.metrics.speed = self.speed
36 | # self.metrics.confusion_matrix = self.confusion_matrix # TODO: classification ConfusionMatrix
37 |
38 | def get_stats(self):
39 | self.metrics.process(self.targets, self.pred)
40 | return self.metrics.results_dict
41 |
42 | def get_dataloader(self, dataset_path, batch_size):
43 | return build_classification_dataloader(path=dataset_path,
44 | imgsz=self.args.imgsz,
45 | batch_size=batch_size,
46 | augment=False,
47 | shuffle=False,
48 | workers=self.args.workers)
49 |
50 | def print_results(self):
51 | pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
52 | LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))
53 |
54 |
55 | def val(cfg=DEFAULT_CFG, use_python=False):
56 | model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
57 | data = cfg.data or 'mnist160'
58 |
59 | args = dict(model=model, data=data)
60 | if use_python:
61 | from ultralytics import YOLO
62 | YOLO(model).val(**args)
63 | else:
64 | validator = ClassificationValidator(args=args)
65 | validator(model=args['model'])
66 |
67 |
68 | if __name__ == '__main__':
69 | val()
70 |
--------------------------------------------------------------------------------
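
`update_metrics` above only accumulates top-5 predictions and targets; the eventual top-1/top-5 accuracies computed by `ClassifyMetrics` amount to something like this sketch:

```python
import torch

preds = torch.randn(8, 10)                        # logits for 8 images, 10 classes
targets = torch.randint(0, 10, (8,))
n5 = min(preds.shape[1], 5)
top5 = preds.argsort(1, descending=True)[:, :n5]  # (8, 5) predicted class indices
correct = top5 == targets[:, None]                # (8, 5) hit mask
top1 = correct[:, 0].float().mean().item()        # top-1 accuracy
top5_acc = correct.any(1).float().mean().item()   # top-5 accuracy
```
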
/v8/detect/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | from .predict import DetectionPredictor, predict
4 | from .train import DetectionTrainer, train
5 | from .val import DetectionValidator, val
6 |
7 | __all__ = 'DetectionPredictor', 'predict', 'DetectionTrainer', 'train', 'DetectionValidator', 'val'
8 |
--------------------------------------------------------------------------------
/v8/detect/predict.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | import torch
4 |
5 | from ultralytics.yolo.engine.predictor import BasePredictor
6 | from ultralytics.yolo.engine.results import Results
7 | from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops
8 | from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
9 |
10 |
11 | class DetectionPredictor(BasePredictor):
12 |
13 | def get_annotator(self, img):
14 | return Annotator(img, line_width=self.args.line_thickness, example=str(self.model.names))
15 |
16 | def preprocess(self, img):
17 | img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
18 | img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
19 | img /= 255 # 0 - 255 to 0.0 - 1.0
20 | return img
21 |
22 | def postprocess(self, preds, img, orig_imgs):
23 | preds = ops.non_max_suppression(preds,
24 | self.args.conf,
25 | self.args.iou,
26 | agnostic=self.args.agnostic_nms,
27 | max_det=self.args.max_det,
28 | classes=self.args.classes)
29 |
30 | results = []
31 | for i, pred in enumerate(preds):
32 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
33 | if not isinstance(orig_imgs, torch.Tensor):
34 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
35 | path, _, _, _, _ = self.batch
36 | img_path = path[i] if isinstance(path, list) else path
37 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
38 | return results
39 |
40 | def write_results(self, idx, results, batch):
41 | p, im, im0 = batch
42 | log_string = ''
43 | if len(im.shape) == 3:
44 | im = im[None] # expand for batch dim
45 | self.seen += 1
46 | imc = im0.copy() if self.args.save_crop else im0
47 | if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
48 | log_string += f'{idx}: '
49 | frame = self.dataset.count
50 | else:
51 | frame = getattr(self.dataset, 'frame', 0)
52 | self.data_path = p
53 | self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
54 | log_string += '%gx%g ' % im.shape[2:] # print string
55 | self.annotator = self.get_annotator(im0)
56 |
57 | det = results[idx].boxes # TODO: make boxes inherit from tensors
58 | if len(det) == 0:
59 | return f'{log_string}(no detections), '
60 | for c in det.cls.unique():
61 | n = (det.cls == c).sum() # detections per class
62 | log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "
63 |
64 |         # Write results
65 | for d in reversed(det):
66 | c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item())
67 | if self.args.save_txt: # Write to file
68 | line = (c, *d.xywhn.view(-1)) + (conf, ) * self.args.save_conf + (() if id is None else (id, ))
69 | with open(f'{self.txt_path}.txt', 'a') as f:
70 | f.write(('%g ' * len(line)).rstrip() % line + '\n')
71 | if self.args.save or self.args.show: # Add bbox to image
72 | name = ('' if id is None else f'id:{id} ') + self.model.names[c]
73 | label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
74 | self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
75 | if self.args.save_crop:
76 | save_one_box(d.xyxy,
77 | imc,
78 | file=self.save_dir / 'crops' / self.model.names[c] / f'{self.data_path.stem}.jpg',
79 | BGR=True)
80 |
81 | return log_string
82 |
83 |
84 | def predict(cfg=DEFAULT_CFG, use_python=False):
85 | model = cfg.model or 'yolov8n.pt'
86 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
87 | else 'https://ultralytics.com/images/bus.jpg'
88 |
89 | args = dict(model=model, source=source)
90 | if use_python:
91 | from ultralytics import YOLO
92 | YOLO(model)(**args)
93 | else:
94 | predictor = DetectionPredictor(overrides=args)
95 | predictor.predict_cli()
96 |
97 |
98 | if __name__ == '__main__':
99 | predict()
100 |
--------------------------------------------------------------------------------
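
When `save_txt` is enabled, `write_results` above emits one whitespace-separated line per detection: class index, normalised xywh box, then optionally confidence and track id. A standalone sketch of the formatting (values are illustrative):

```python
c = 0                                   # class index
xywhn = (0.5, 0.5, 0.25, 0.4)           # normalised centre-x, centre-y, width, height
conf, save_conf, id = 0.87, True, None  # id stays None unless tracking is active

line = (c, *xywhn) + (conf,) * save_conf + (() if id is None else (id,))
print(('%g ' * len(line)).rstrip() % line)  # -> 0 0.5 0.5 0.25 0.4 0.87
```
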
/v8/segment/__init__.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | from .predict import SegmentationPredictor, predict
4 | from .train import SegmentationTrainer, train
5 | from .val import SegmentationValidator, val
6 |
7 | __all__ = 'SegmentationPredictor', 'predict', 'SegmentationTrainer', 'train', 'SegmentationValidator', 'val'
8 |
--------------------------------------------------------------------------------
/v8/segment/predict.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 |
3 | import torch
4 |
5 | from ultralytics.yolo.engine.results import Results
6 | from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops
7 | from ultralytics.yolo.utils.plotting import colors, save_one_box
8 | from ultralytics.yolo.v8.detect.predict import DetectionPredictor
9 |
10 |
11 | class SegmentationPredictor(DetectionPredictor):
12 |
13 | def postprocess(self, preds, img, orig_imgs):
14 | # TODO: filter by classes
15 | p = ops.non_max_suppression(preds[0],
16 | self.args.conf,
17 | self.args.iou,
18 | agnostic=self.args.agnostic_nms,
19 | max_det=self.args.max_det,
20 | nc=len(self.model.names),
21 | classes=self.args.classes)
22 | results = []
23 | proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
24 | for i, pred in enumerate(p):
25 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
26 | path, _, _, _, _ = self.batch
27 | img_path = path[i] if isinstance(path, list) else path
28 | if not len(pred): # save empty boxes
29 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
30 | continue
31 | if self.args.retina_masks:
32 | if not isinstance(orig_imgs, torch.Tensor):
33 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
34 | masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
35 | else:
36 | masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
37 | if not isinstance(orig_imgs, torch.Tensor):
38 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
39 | results.append(
40 | Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
41 | return results
42 |
43 | def write_results(self, idx, results, batch):
44 | p, im, im0 = batch
45 | log_string = ''
46 | if len(im.shape) == 3:
47 | im = im[None] # expand for batch dim
48 | self.seen += 1
49 | imc = im0.copy() if self.args.save_crop else im0
50 | if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
51 | log_string += f'{idx}: '
52 | frame = self.dataset.count
53 | else:
54 | frame = getattr(self.dataset, 'frame', 0)
55 |
56 | self.data_path = p
57 | self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
58 | log_string += '%gx%g ' % im.shape[2:] # print string
59 | self.annotator = self.get_annotator(im0)
60 |
61 | result = results[idx]
62 | if len(result) == 0:
63 | return f'{log_string}(no detections), '
64 |         det, mask = result.boxes, result.masks  # boxes and masks tensors  TODO: make Boxes and Masks inherit from tensors
65 |
66 | # Print results
67 | for c in det.cls.unique():
68 | n = (det.cls == c).sum() # detections per class
69 | log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "
70 |
71 | # Mask plotting
72 | if self.args.save or self.args.show:
73 | im_gpu = torch.as_tensor(im0, dtype=torch.float16, device=mask.masks.device).permute(
74 | 2, 0, 1).flip(0).contiguous() / 255 if self.args.retina_masks else im[idx]
75 | self.annotator.masks(masks=mask.masks, colors=[colors(x, True) for x in det.cls], im_gpu=im_gpu)
76 |
77 | # Write results
78 | for j, d in enumerate(reversed(det)):
79 | c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item())
80 | if self.args.save_txt: # Write to file
81 | seg = mask.segments[len(det) - j - 1].copy().reshape(-1) # reversed mask.segments, (n,2) to (n*2)
82 | line = (c, *seg) + (conf, ) * self.args.save_conf + (() if id is None else (id, ))
83 | with open(f'{self.txt_path}.txt', 'a') as f:
84 | f.write(('%g ' * len(line)).rstrip() % line + '\n')
85 | if self.args.save or self.args.show: # Add bbox to image
86 | name = ('' if id is None else f'id:{id} ') + self.model.names[c]
87 | label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
88 | if self.args.boxes:
89 | self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
90 | if self.args.save_crop:
91 | save_one_box(d.xyxy,
92 | imc,
93 | file=self.save_dir / 'crops' / self.model.names[c] / f'{self.data_path.stem}.jpg',
94 | BGR=True)
95 |
96 | return log_string
97 |
98 |
99 | def predict(cfg=DEFAULT_CFG, use_python=False):
100 | model = cfg.model or 'yolov8n-seg.pt'
101 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
102 | else 'https://ultralytics.com/images/bus.jpg'
103 |
104 | args = dict(model=model, source=source)
105 | if use_python:
106 | from ultralytics import YOLO
107 | YOLO(model)(**args)
108 | else:
109 | predictor = SegmentationPredictor(overrides=args)
110 | predictor.predict_cli()
111 |
112 |
113 | if __name__ == '__main__':
114 | predict()
115 |
--------------------------------------------------------------------------------
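
In `write_results` above, boxes are iterated in reverse while `mask.segments` is indexed forwards; `len(det) - j - 1` keeps the two aligned. A minimal index sketch:

```python
det = ['box0', 'box1', 'box2']       # stand-ins for detection boxes
segments = ['seg0', 'seg1', 'seg2']  # stand-ins for the matching mask segments
for j, d in enumerate(reversed(det)):
    print(d, segments[len(det) - j - 1])  # box2 seg2, box1 seg1, box0 seg0
```
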
/v8/segment/train.py:
--------------------------------------------------------------------------------
1 | # Ultralytics YOLO 🚀, GPL-3.0 license
2 | from copy import copy
3 |
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | from ultralytics.nn.tasks import SegmentationModel
8 | from ultralytics.yolo import v8
9 | from ultralytics.yolo.utils import DEFAULT_CFG, RANK
10 | from ultralytics.yolo.utils.ops import crop_mask, xyxy2xywh
11 | from ultralytics.yolo.utils.plotting import plot_images, plot_results
12 | from ultralytics.yolo.utils.tal import make_anchors
13 | from ultralytics.yolo.utils.torch_utils import de_parallel
14 | from ultralytics.yolo.v8.detect.train import Loss
15 |
16 |
17 | # BaseTrainer python usage
18 | class SegmentationTrainer(v8.detect.DetectionTrainer):
19 |
20 | def __init__(self, cfg=DEFAULT_CFG, overrides=None):
21 | if overrides is None:
22 | overrides = {}
23 | overrides['task'] = 'segment'
24 | super().__init__(cfg, overrides)
25 |
26 | def get_model(self, cfg=None, weights=None, verbose=True):
27 | model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
28 | if weights:
29 | model.load(weights)
30 |
31 | return model
32 |
33 | def get_validator(self):
34 | self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
35 | return v8.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
36 |
37 | def criterion(self, preds, batch):
38 | if not hasattr(self, 'compute_loss'):
39 | self.compute_loss = SegLoss(de_parallel(self.model), overlap=self.args.overlap_mask)
40 | return self.compute_loss(preds, batch)
41 |
42 | def plot_training_samples(self, batch, ni):
43 | images = batch['img']
44 | masks = batch['masks']
45 | cls = batch['cls'].squeeze(-1)
46 | bboxes = batch['bboxes']
47 | paths = batch['im_file']
48 | batch_idx = batch['batch_idx']
49 | plot_images(images, batch_idx, cls, bboxes, masks, paths=paths, fname=self.save_dir / f'train_batch{ni}.jpg')
50 |
51 | def plot_metrics(self):
52 | plot_results(file=self.csv, segment=True) # save results.png
53 |
54 |
55 | # Criterion class for computing training losses
56 | class SegLoss(Loss):
57 |
58 | def __init__(self, model, overlap=True): # model must be de-paralleled
59 | super().__init__(model)
60 | self.nm = model.model[-1].nm # number of masks
61 | self.overlap = overlap
62 |
63 | def __call__(self, preds, batch):
64 |         loss = torch.zeros(4, device=self.device)  # box, seg, cls, dfl
65 | feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
66 | batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
67 | pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
68 | (self.reg_max * 4, self.nc), 1)
69 |
70 | # b, grids, ..
71 | pred_scores = pred_scores.permute(0, 2, 1).contiguous()
72 | pred_distri = pred_distri.permute(0, 2, 1).contiguous()
73 | pred_masks = pred_masks.permute(0, 2, 1).contiguous()
74 |
75 | dtype = pred_scores.dtype
76 | imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
77 | anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
78 |
79 | # targets
80 | try:
81 | batch_idx = batch['batch_idx'].view(-1, 1)
82 | targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
83 | targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
84 | gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
85 | mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
86 | except RuntimeError as e:
87 | raise TypeError('ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n'
88 | "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
89 | "i.e. 'yolo train model=yolov8n-seg.pt data=coco128.yaml'.\nVerify your dataset is a "
90 | "correctly formatted 'segment' dataset using 'data=coco128-seg.yaml' "
91 | 'as an example.\nSee https://docs.ultralytics.com/tasks/segment/ for help.') from e
92 |
93 | # pboxes
94 | pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
95 |
96 | _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
97 | pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
98 | anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
99 |
100 | target_scores_sum = max(target_scores.sum(), 1)
101 |
102 | # cls loss
103 | # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
104 | loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
105 |
106 | if fg_mask.sum():
107 | # bbox loss
108 | loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor,
109 | target_scores, target_scores_sum, fg_mask)
110 | # masks loss
111 | masks = batch['masks'].to(self.device).float()
112 | if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
113 | masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
114 |
115 | for i in range(batch_size):
116 | if fg_mask[i].sum():
117 | mask_idx = target_gt_idx[i][fg_mask[i]]
118 | if self.overlap:
119 | gt_mask = torch.where(masks[[i]] == (mask_idx + 1).view(-1, 1, 1), 1.0, 0.0)
120 | else:
121 | gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
122 | xyxyn = target_bboxes[i][fg_mask[i]] / imgsz[[1, 0, 1, 0]]
123 | marea = xyxy2xywh(xyxyn)[:, 2:].prod(1)
124 | mxyxy = xyxyn * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)
125 | loss[1] += self.single_mask_loss(gt_mask, pred_masks[i][fg_mask[i]], proto[i], mxyxy, marea) # seg
126 |
127 |                     # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
128 | else:
129 | loss[1] += proto.sum() * 0 + pred_masks.sum() * 0
130 |
131 | # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
132 | else:
133 | loss[1] += proto.sum() * 0 + pred_masks.sum() * 0
134 |
135 | loss[0] *= self.hyp.box # box gain
136 | loss[1] *= self.hyp.box / batch_size # seg gain
137 | loss[2] *= self.hyp.cls # cls gain
138 | loss[3] *= self.hyp.dfl # dfl gain
139 |
140 |         return loss.sum() * batch_size, loss.detach()  # loss(box, seg, cls, dfl)
141 |
142 | def single_mask_loss(self, gt_mask, pred, proto, xyxy, area):
143 | # Mask loss for one image
144 | pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:]) # (n, 32) @ (32,80,80) -> (n,80,80)
145 | loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
146 | return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
147 |
148 |
149 | def train(cfg=DEFAULT_CFG, use_python=False):
150 | model = cfg.model or 'yolov8n-seg.pt'
151 |     data = cfg.data or 'coco128-seg.yaml'  # segmentation dataset yaml
152 | device = cfg.device if cfg.device is not None else ''
153 |
154 | args = dict(model=model, data=data, device=device)
155 | if use_python:
156 | from ultralytics import YOLO
157 | YOLO(model).train(**args)
158 | else:
159 | trainer = SegmentationTrainer(overrides=args)
160 | trainer.train()
161 |
162 |
163 | if __name__ == '__main__':
164 | train()
165 |
--------------------------------------------------------------------------------
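
A shape sketch for `single_mask_loss` above: per-instance mask coefficients are combined with the shared prototype masks in one matrix product before the BCE map is cropped to each box and normalised by box area (sizes are illustrative):

```python
import torch

n, nm, h, w = 3, 32, 80, 80      # instances, prototypes, prototype height/width
coeffs = torch.randn(n, nm)      # pred_masks[i][fg_mask[i]] in the loss above
proto_i = torch.randn(nm, h, w)  # proto[i] in the loss above
pred_mask = (coeffs @ proto_i.view(nm, -1)).view(n, h, w)  # (3, 32) @ (32, 6400) -> (3, 80, 80)
```
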