├── Detect_GUI.py ├── Login_GUI.py ├── README.md ├── engine ├── __init__.py ├── exporter.py ├── model.py ├── predictor.py ├── results.py ├── trainer.py └── validator.py ├── lib └── share.py ├── setup.cfg ├── test ├── yolov8n.pt └── zidane.jpg ├── ui ├── detect_ui.py ├── login_ui.py ├── ori_ui │ ├── detect_ui.ui │ ├── login_ui.ui │ └── registe_ui.ui └── registe_ui.py ├── ui_img ├── delete.png ├── detect.JPG ├── detect_images.qrc ├── detect_images_rc.py ├── exit.png ├── images.png ├── login.JPG ├── login_images.qrc ├── login_images_rc.py ├── registe.JPG ├── registe_images.qrc ├── registe_images_rc.py ├── run.png ├── save.png └── upload.png ├── ultralytics ├── __init__.py ├── hub │ ├── __init__.py │ ├── auth.py │ ├── session.py │ └── utils.py ├── nn │ ├── __init__.py │ ├── autobackend.py │ ├── autoshape.py │ ├── modules.py │ └── tasks.py └── yolo │ ├── __init__.py │ ├── cfg │ ├── __init__.py │ └── default.yaml │ ├── data │ ├── __init__.py │ ├── augment.py │ ├── base.py │ ├── build.py │ ├── dataloaders │ │ ├── __init__.py │ │ ├── stream_loaders.py │ │ ├── v5augmentations.py │ │ └── v5loader.py │ ├── dataset.py │ ├── dataset_wrappers.py │ └── utils.py │ ├── engine │ ├── __init__.py │ ├── exporter.py │ ├── model.py │ ├── predictor.py │ ├── results.py │ ├── trainer.py │ └── validator.py │ ├── utils │ ├── __init__.py │ ├── autobatch.py │ ├── benchmarks.py │ ├── callbacks │ │ ├── __init__.py │ │ ├── base.py │ │ ├── clearml.py │ │ ├── comet.py │ │ ├── hub.py │ │ └── tensorboard.py │ ├── checks.py │ ├── dist.py │ ├── downloads.py │ ├── files.py │ ├── instance.py │ ├── loss.py │ ├── metrics.py │ ├── ops.py │ ├── plotting.py │ ├── tal.py │ └── torch_utils.py │ └── v8 │ ├── __init__.py │ ├── classify │ ├── __init__.py │ ├── predict.py │ ├── train.py │ └── val.py │ ├── detect │ ├── __init__.py │ ├── predict.py │ ├── train.py │ └── val.py │ └── segment │ ├── __init__.py │ ├── predict.py │ ├── train.py │ └── val.py ├── userInfo.csv ├── utils ├── __init__.py ├── autobatch.py ├── 
benchmarks.py ├── callbacks │ ├── __init__.py │ ├── base.py │ ├── clearml.py │ ├── comet.py │ ├── hub.py │ └── tensorboard.py ├── checks.py ├── dist.py ├── downloads.py ├── files.py ├── id_utils.py ├── instance.py ├── loss.py ├── metrics.py ├── ops.py ├── plotting.py ├── tal.py └── torch_utils.py └── v8 ├── __init__.py ├── classify ├── __init__.py ├── predict.py ├── train.py └── val.py ├── detect ├── __init__.py ├── predict.py ├── train.py └── val.py └── segment ├── __init__.py ├── predict.py ├── train.py └── val.py /Login_GUI.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from datetime import datetime 3 | from PyQt5.QtWidgets import * 4 | from utils.id_utils import get_id_info, sava_id_info # 账号信息工具函数 5 | from lib.share import shareInfo # 公共变量名 6 | # 导入QT-Design生成的UI 7 | from ui.login_ui import Ui_Form 8 | from ui.registe_ui import Ui_Dialog 9 | # 导入设计好的检测界面 10 | from Detect_GUI import Ui_MainWindow 11 | import matplotlib.backends.backend_tkagg 12 | # 界面登录 13 | class win_Login(QMainWindow): 14 | def __init__(self, parent = None): 15 | super(win_Login, self).__init__(parent) 16 | self.ui_login = Ui_Form() 17 | self.ui_login.setupUi(self) 18 | self.init_slots() 19 | self.hidden_pwd() 20 | 21 | # 密码输入框隐藏 22 | def hidden_pwd(self): 23 | self.ui_login.edit_password.setEchoMode(QLineEdit.Password) 24 | 25 | # 绑定信号槽 26 | def init_slots(self): 27 | self.ui_login.btn_login.clicked.connect(self.onSignIn) # 点击按钮登录 28 | self.ui_login.edit_password.returnPressed.connect(self.onSignIn) # 按下回车登录 29 | self.ui_login.btn_regeist.clicked.connect(self.create_id) 30 | 31 | # 跳转到注册界面 32 | def create_id(self): 33 | shareInfo.createWin = win_Registe() 34 | shareInfo.createWin.show() 35 | 36 | # 保存登录日志 37 | def sava_login_log(self, username): 38 | with open('login_log.txt', 'a', encoding='utf-8') as f: 39 | f.write(username + '\t log in at' + datetime.now().strftimestrftime+ '\r') 40 | 41 | # 登录 42 | def onSignIn(self): 43 | 
print("You pressed sign in") 44 | # 从登陆界面获得输入账户名与密码 45 | username = self.ui_login.edit_username.text().strip() 46 | password = self.ui_login.edit_password.text().strip() 47 | print(username) 48 | print(password) 49 | # 获得账号信息 50 | USER_PWD = get_id_info() 51 | # print(USER_PWD) 52 | if username not in USER_PWD.keys(): 53 | replay = QMessageBox.warning(self,"登陆失败!", "账号或密码输入错误", QMessageBox.Yes) 54 | else: 55 | # 若登陆成功,则跳转主界面 56 | if USER_PWD.get(username) == password: 57 | print("Jump to main window") 58 | # 所以使用公用变量名 59 | # shareInfo.mainWin = UI_Logic_Window() 60 | shareInfo.mainWin = Ui_MainWindow() 61 | shareInfo.mainWin.show() 62 | # 关闭当前窗口 63 | self.close() 64 | else: 65 | replay = QMessageBox.warning(self, "!", "账号或密码输入错误", QMessageBox.Yes) 66 | 67 | # 注册界面 68 | class win_Registe(QMainWindow): 69 | def __init__(self, parent = None): 70 | super(win_Registe, self).__init__(parent) 71 | self.ui_registe = Ui_Dialog() 72 | self.ui_registe.setupUi(self) 73 | self.init_slots() 74 | 75 | # 绑定槽信号 76 | def init_slots(self): 77 | self.ui_registe.pushButton_regiser.clicked.connect(self.new_account) 78 | self.ui_registe.pushButton_cancer.clicked.connect(self.cancel) 79 | 80 | # 创建新账户 81 | def new_account(self): 82 | print("Create new account") 83 | USER_PWD = get_id_info() 84 | # print(USER_PWD) 85 | new_username = self.ui_registe.edit_username.text().strip() 86 | new_password = self.ui_registe.edit_password.text().strip() 87 | # 判断用户名是否为空 88 | if new_username == "": 89 | replay = QMessageBox.warning(self, "!", "账号不准为空", QMessageBox.Yes) 90 | else: 91 | # 判断账号是否存在 92 | if new_username in USER_PWD.keys(): 93 | replay = QMessageBox.warning(self, "!", "账号已存在", QMessageBox.Yes) 94 | else: 95 | # 判断密码是否为空 96 | if new_password == "": 97 | replay = QMessageBox.warning(self, "!", "密码不能为空", QMessageBox.Yes) 98 | else: 99 | # 注册成功 100 | print("Successful!") 101 | sava_id_info(new_username, new_password) 102 | replay = QMessageBox.warning(self, "!", "注册成功!", QMessageBox.Yes) 103 | 
# 关闭界面 104 | self.close() 105 | 106 | # 取消注册 107 | def cancel(self): 108 | self.close() # 关闭当前界面 109 | 110 | if __name__ == "__main__": 111 | app = QApplication(sys.argv) 112 | # 利用共享变量名来实例化对象 113 | shareInfo.loginWin = win_Login() # 登录界面作为主界面 114 | shareInfo.loginWin.show() 115 | sys.exit(app.exec_()) 116 | 117 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # YOLOv8_PYQT5_GUI 2 | 基于YOLOv8和PYQT5的检测界面 3 | -------------------------------------------------------------------------------- /engine/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/engine/__init__.py -------------------------------------------------------------------------------- /lib/share.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : Ruihao 3 | # @ProjectName:yolov5-pyqt5 4 | 5 | class shareInfo: 6 | ''' 7 | 存取公用的界面名 # 参考自白月黑羽 www.python3.vip 8 | ''' 9 | mainWin = None 10 | loginWin = None 11 | createWin = None -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # Project-wide configuration file, can be used for package metadata and other toll configurations 2 | # Example usage: global configuration for PEP8 (via flake8) setting or default pytest arguments 3 | # Local usage: pip install pre-commit, pre-commit run --all-files 4 | 5 | [metadata] 6 | license_files = LICENSE 7 | description_file = README.md 8 | 9 | [tool:pytest] 10 | norecursedirs = 11 | .git 12 | dist 13 | build 14 | addopts = 15 | --doctest-modules 16 | --durations=25 17 | --color=yes 18 | 19 | [flake8] 20 | max-line-length = 120 21 | exclude = 
.tox,*.egg,build,temp 22 | select = E,W,F 23 | doctests = True 24 | verbose = 2 25 | # https://pep8.readthedocs.io/en/latest/intro.html#error-codes 26 | format = pylint 27 | # see: https://www.flake8rules.com/ 28 | ignore = E731,F405,E402,W504,E501 29 | # E731: Do not assign a lambda expression, use a def 30 | # F405: name may be undefined, or defined from star imports: module 31 | # E402: module level import not at top of file 32 | # W504: line break after binary operator 33 | # E501: line too long 34 | # removed: 35 | # F401: module imported but unused 36 | # E231: missing whitespace after ‘,’, ‘;’, or ‘:’ 37 | # E127: continuation line over-indented for visual indent 38 | # F403: ‘from module import *’ used; unable to detect undefined names 39 | 40 | 41 | [isort] 42 | # https://pycqa.github.io/isort/docs/configuration/options.html 43 | line_length = 120 44 | # see: https://pycqa.github.io/isort/docs/configuration/multi_line_output_modes.html 45 | multi_line_output = 0 46 | 47 | [yapf] 48 | based_on_style = pep8 49 | spaces_before_comment = 2 50 | COLUMN_LIMIT = 120 51 | COALESCE_BRACKETS = True 52 | SPACES_AROUND_POWER_OPERATOR = True 53 | SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET = True 54 | SPLIT_BEFORE_CLOSING_BRACKET = False 55 | SPLIT_BEFORE_FIRST_ARGUMENT = False 56 | # EACH_DICT_ENTRY_ON_SEPARATE_LINE = False 57 | -------------------------------------------------------------------------------- /test/yolov8n.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/test/yolov8n.pt -------------------------------------------------------------------------------- /test/zidane.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/test/zidane.jpg 
-------------------------------------------------------------------------------- /ui/detect_ui.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Form implementation generated from reading ui file 'ui/ori_ui/detect_ui.ui' 4 | # 5 | # Created by: PyQt5 UI code generator 5.15.7 6 | # 7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is 8 | # run again. Do not edit this file unless you know what you are doing. 9 | 10 | 11 | from PyQt5 import QtCore, QtGui, QtWidgets 12 | 13 | 14 | class Ui_MainWindow(object): 15 | def setupUi(self, MainWindow): 16 | MainWindow.setObjectName("MainWindow") 17 | MainWindow.resize(1500, 1000) 18 | MainWindow.setStyleSheet("QWidget#centralwidget{\n" 19 | " background-image: url(:/detect_background/detect.jpg);}") 20 | self.centralwidget = QtWidgets.QWidget(MainWindow) 21 | self.centralwidget.setObjectName("centralwidget") 22 | self.pushButton = QtWidgets.QPushButton(self.centralwidget) 23 | self.pushButton.setGeometry(QtCore.QRect(70, 810, 70, 70)) 24 | self.pushButton.setStyleSheet("border-image: url(:/detect_button_background/upload.png);\n" 25 | "\n" 26 | "") 27 | self.pushButton.setText("") 28 | self.pushButton.setObjectName("pushButton") 29 | self.pushButton_3 = QtWidgets.QPushButton(self.centralwidget) 30 | self.pushButton_3.setGeometry(QtCore.QRect(390, 810, 70, 70)) 31 | self.pushButton_3.setStyleSheet("border-image: url(:/detect_button_background/images.png);") 32 | self.pushButton_3.setText("") 33 | self.pushButton_3.setObjectName("pushButton_3") 34 | self.pushButton_4 = QtWidgets.QPushButton(self.centralwidget) 35 | self.pushButton_4.setGeometry(QtCore.QRect(730, 810, 70, 70)) 36 | self.pushButton_4.setStyleSheet("border-image: url(:/detect_button_background/save.png);") 37 | self.pushButton_4.setText("") 38 | self.pushButton_4.setObjectName("pushButton_4") 39 | self.pushButton_5 = QtWidgets.QPushButton(self.centralwidget) 40 | 
self.pushButton_5.setGeometry(QtCore.QRect(1050, 810, 70, 70)) 41 | self.pushButton_5.setStyleSheet("border-image: url(:/detect_button_background/delete.png);") 42 | self.pushButton_5.setText("") 43 | self.pushButton_5.setObjectName("pushButton_5") 44 | self.pushButton_6 = QtWidgets.QPushButton(self.centralwidget) 45 | self.pushButton_6.setGeometry(QtCore.QRect(1360, 810, 70, 70)) 46 | self.pushButton_6.setStyleSheet("border-image: url(:/detect_button_background/exit.png);") 47 | self.pushButton_6.setText("") 48 | self.pushButton_6.setObjectName("pushButton_6") 49 | self.label_2 = QtWidgets.QLabel(self.centralwidget) 50 | self.label_2.setGeometry(QtCore.QRect(190, 10, 1101, 80)) 51 | font = QtGui.QFont() 52 | font.setFamily("Adobe 黑体 Std R") 53 | font.setPointSize(28) 54 | self.label_2.setFont(font) 55 | self.label_2.setStyleSheet("") 56 | self.label_2.setObjectName("label_2") 57 | self.label_3 = QtWidgets.QLabel(self.centralwidget) 58 | self.label_3.setGeometry(QtCore.QRect(0, 80, 700, 700)) 59 | self.label_3.setStyleSheet("background-color: rgb(255, 255, 255);") 60 | self.label_3.setObjectName("label_3") 61 | self.label_4 = QtWidgets.QLabel(self.centralwidget) 62 | self.label_4.setGeometry(QtCore.QRect(800, 80, 700, 700)) 63 | self.label_4.setStyleSheet("background-color: rgb(255, 255, 255);") 64 | self.label_4.setObjectName("label_4") 65 | self.label_5 = QtWidgets.QLabel(self.centralwidget) 66 | self.label_5.setGeometry(QtCore.QRect(-1, 800, 1501, 141)) 67 | self.label_5.setStyleSheet("background-color: rgb(255, 255, 255);\n" 68 | "border-color: rgb(0, 0, 0);") 69 | self.label_5.setText("") 70 | self.label_5.setObjectName("label_5") 71 | self.lineEdit = QtWidgets.QLineEdit(self.centralwidget) 72 | self.lineEdit.setGeometry(QtCore.QRect(20, 890, 161, 40)) 73 | font = QtGui.QFont() 74 | font.setFamily("Adobe 宋体 Std L") 75 | font.setPointSize(20) 76 | self.lineEdit.setFont(font) 77 | self.lineEdit.setLayoutDirection(QtCore.Qt.RightToLeft) 78 | 
self.lineEdit.setObjectName("lineEdit") 79 | self.lineEdit_3 = QtWidgets.QLineEdit(self.centralwidget) 80 | self.lineEdit_3.setGeometry(QtCore.QRect(350, 890, 161, 40)) 81 | font = QtGui.QFont() 82 | font.setFamily("Adobe 宋体 Std L") 83 | font.setPointSize(20) 84 | self.lineEdit_3.setFont(font) 85 | self.lineEdit_3.setObjectName("lineEdit_3") 86 | self.lineEdit_4 = QtWidgets.QLineEdit(self.centralwidget) 87 | self.lineEdit_4.setGeometry(QtCore.QRect(690, 890, 161, 40)) 88 | font = QtGui.QFont() 89 | font.setFamily("Adobe 宋体 Std L") 90 | font.setPointSize(20) 91 | self.lineEdit_4.setFont(font) 92 | self.lineEdit_4.setObjectName("lineEdit_4") 93 | self.lineEdit_5 = QtWidgets.QLineEdit(self.centralwidget) 94 | self.lineEdit_5.setGeometry(QtCore.QRect(1000, 890, 161, 40)) 95 | font = QtGui.QFont() 96 | font.setFamily("Adobe 宋体 Std L") 97 | font.setPointSize(20) 98 | self.lineEdit_5.setFont(font) 99 | self.lineEdit_5.setObjectName("lineEdit_5") 100 | self.lineEdit_6 = QtWidgets.QLineEdit(self.centralwidget) 101 | self.lineEdit_6.setGeometry(QtCore.QRect(1300, 890, 161, 40)) 102 | font = QtGui.QFont() 103 | font.setFamily("Adobe 宋体 Std L") 104 | font.setPointSize(20) 105 | self.lineEdit_6.setFont(font) 106 | self.lineEdit_6.setObjectName("lineEdit_6") 107 | self.label_5.raise_() 108 | self.pushButton.raise_() 109 | self.pushButton_3.raise_() 110 | self.pushButton_4.raise_() 111 | self.pushButton_5.raise_() 112 | self.pushButton_6.raise_() 113 | self.label_2.raise_() 114 | self.label_3.raise_() 115 | self.label_4.raise_() 116 | self.lineEdit.raise_() 117 | self.lineEdit_3.raise_() 118 | self.lineEdit_4.raise_() 119 | self.lineEdit_5.raise_() 120 | self.lineEdit_6.raise_() 121 | MainWindow.setCentralWidget(self.centralwidget) 122 | self.menubar = QtWidgets.QMenuBar(MainWindow) 123 | self.menubar.setGeometry(QtCore.QRect(0, 0, 1500, 30)) 124 | self.menubar.setObjectName("menubar") 125 | MainWindow.setMenuBar(self.menubar) 126 | self.statusbar = 
QtWidgets.QStatusBar(MainWindow) 127 | self.statusbar.setObjectName("statusbar") 128 | MainWindow.setStatusBar(self.statusbar) 129 | 130 | self.retranslateUi(MainWindow) 131 | QtCore.QMetaObject.connectSlotsByName(MainWindow) 132 | 133 | def retranslateUi(self, MainWindow): 134 | _translate = QtCore.QCoreApplication.translate 135 | MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow")) 136 | self.label_2.setText(_translate("MainWindow", "

基于YOLOv8的疏果前期幼果检测演示软件

")) 137 | self.label_3.setText(_translate("MainWindow", "

原始图像

")) 138 | self.label_4.setText(_translate("MainWindow", "

检测图像

")) 139 | self.lineEdit.setToolTip(_translate("MainWindow", "


")) 140 | self.lineEdit.setWhatsThis(_translate("MainWindow", "


")) 141 | self.lineEdit.setText(_translate("MainWindow", "模型加载")) 142 | self.lineEdit_3.setText(_translate("MainWindow", "图像加载")) 143 | self.lineEdit_4.setText(_translate("MainWindow", "图像保存")) 144 | self.lineEdit_5.setText(_translate("MainWindow", "图像清除")) 145 | self.lineEdit_6.setText(_translate("MainWindow", "应用退出")) 146 | import ui_img.detect_images 147 | -------------------------------------------------------------------------------- /ui/login_ui.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Form implementation generated from reading ui file 'ui/ori_ui/login_ui.ui' 4 | # 5 | # Created by: PyQt5 UI code generator 5.15.7 6 | # 7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is 8 | # run again. Do not edit this file unless you know what you are doing. 9 | 10 | 11 | from PyQt5 import QtCore, QtGui, QtWidgets 12 | import ui_img.login_images_rc 13 | 14 | class Ui_Form(object): 15 | def setupUi(self, Form): 16 | Form.setObjectName("Form") 17 | Form.setWindowTitle("基于YOLOv8的检测演示软件V1.0") 18 | Form.resize(800, 600) 19 | Form.setStyleSheet("QWidget#Form{background-image: url(:/login_background/login.JPG);}") 20 | 21 | 22 | self.edit_username = QtWidgets.QLineEdit(Form) 23 | self.edit_username.setGeometry(QtCore.QRect(250, 200, 250, 50)) 24 | font = QtGui.QFont() 25 | font.setFamily("Adobe 黑体 Std R") 26 | font.setPointSize(18) 27 | self.edit_username.setFont(font) 28 | self.edit_username.setObjectName("edit_username") 29 | 30 | self.btn_login = QtWidgets.QPushButton(Form) 31 | self.btn_login.setGeometry(QtCore.QRect(220, 430, 120, 80)) 32 | font = QtGui.QFont() 33 | font.setFamily("Adobe 黑体 Std R") 34 | font.setPointSize(18) 35 | self.btn_login.setFont(font) 36 | self.btn_login.setObjectName("btn_login") 37 | 38 | self.edit_password = QtWidgets.QLineEdit(Form) 39 | self.edit_password.setGeometry(QtCore.QRect(250, 300, 250, 50)) 40 | font = QtGui.QFont() 41 | 
font.setFamily("Adobe 黑体 Std R") 42 | font.setPointSize(18) 43 | self.edit_password.setFont(font) 44 | self.edit_password.setObjectName("edit_password") 45 | 46 | self.label = QtWidgets.QLabel(Form) 47 | self.label.setGeometry(QtCore.QRect(0, 80, 781, 80)) 48 | font = QtGui.QFont() 49 | font.setFamily("Adobe 黑体 Std R") 50 | font.setPointSize(22) 51 | self.label.setFont(font) 52 | self.label.setStyleSheet("") 53 | self.label.setObjectName("label") 54 | 55 | self.btn_regeist = QtWidgets.QPushButton(Form) 56 | self.btn_regeist.setGeometry(QtCore.QRect(420, 430, 120, 80)) 57 | font = QtGui.QFont() 58 | font.setFamily("Adobe 黑体 Std R") 59 | font.setPointSize(18) 60 | self.btn_regeist.setFont(font) 61 | self.btn_regeist.setObjectName("btn_regeist") 62 | 63 | self.retranslateUi(Form) 64 | QtCore.QMetaObject.connectSlotsByName(Form) 65 | 66 | def retranslateUi(self, Form): 67 | _translate = QtCore.QCoreApplication.translate 68 | Form.setWindowTitle(_translate("Form", "Form")) 69 | self.edit_username.setPlaceholderText(_translate("Form", "用户名")) 70 | self.btn_login.setText(_translate("Form", "登录")) 71 | self.edit_password.setPlaceholderText(_translate("Form", "密码")) 72 | self.label.setText(_translate("Form", "

基于YOLOv8的检测演示软件

")) 73 | self.btn_regeist.setText(_translate("Form", "注册")) 74 | 75 | -------------------------------------------------------------------------------- /ui/ori_ui/detect_ui.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | MainWindow 4 | 5 | 6 | 7 | 0 8 | 0 9 | 1500 10 | 1000 11 | 12 | 13 | 14 | MainWindow 15 | 16 | 17 | QWidget#centralwidget{ 18 | background-image: url(:/detect_background/detect.jpg);} 19 | 20 | 21 | 22 | 23 | 24 | 70 25 | 810 26 | 70 27 | 70 28 | 29 | 30 | 31 | border-image: url(:/detect_button_background/upload.png); 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 390 43 | 810 44 | 70 45 | 70 46 | 47 | 48 | 49 | border-image: url(:/detect_button_background/images.png); 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 730 59 | 810 60 | 70 61 | 70 62 | 63 | 64 | 65 | border-image: url(:/detect_button_background/save.png); 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 1050 75 | 810 76 | 70 77 | 70 78 | 79 | 80 | 81 | border-image: url(:/detect_button_background/delete.png); 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 1360 91 | 810 92 | 70 93 | 70 94 | 95 | 96 | 97 | border-image: url(:/detect_button_background/exit.png); 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 190 107 | 10 108 | 1101 109 | 80 110 | 111 | 112 | 113 | 114 | Adobe 黑体 Std R 115 | 28 116 | 117 | 118 | 119 | 120 | 121 | 122 | <html><head/><body><p align="center"><span style=" font-size:28pt; font-weight:600; color:#ffffff;">基于YOLOv8的疏果前期幼果检测演示软件</span></p></body></html> 123 | 124 | 125 | 126 | 127 | 128 | 0 129 | 80 130 | 700 131 | 700 132 | 133 | 134 | 135 | background-color: rgb(255, 255, 255); 136 | 137 | 138 | <html><head/><body><p align="center"><span style=" font-size:20pt;">原始图像</span></p></body></html> 139 | 140 | 141 | 142 | 143 | 144 | 800 145 | 80 146 | 700 147 | 700 148 | 149 | 150 | 151 | background-color: rgb(255, 255, 255); 152 | 153 | 154 | <html><head/><body><p align="center"><span style=" 
font-size:20pt;">检测图像</span></p></body></html> 155 | 156 | 157 | 158 | 159 | 160 | -1 161 | 800 162 | 1501 163 | 141 164 | 165 | 166 | 167 | background-color: rgb(255, 255, 255); 168 | border-color: rgb(0, 0, 0); 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 20 178 | 890 179 | 161 180 | 40 181 | 182 | 183 | 184 | 185 | Adobe 宋体 Std L 186 | 20 187 | 188 | 189 | 190 | <html><head/><body><p align="center"><br/></p></body></html> 191 | 192 | 193 | <html><head/><body><p align="center"><br/></p></body></html> 194 | 195 | 196 | Qt::RightToLeft 197 | 198 | 199 | 模型加载 200 | 201 | 202 | 203 | 204 | 205 | 350 206 | 890 207 | 161 208 | 40 209 | 210 | 211 | 212 | 213 | Adobe 宋体 Std L 214 | 20 215 | 216 | 217 | 218 | 图像加载 219 | 220 | 221 | 222 | 223 | 224 | 690 225 | 890 226 | 161 227 | 40 228 | 229 | 230 | 231 | 232 | Adobe 宋体 Std L 233 | 20 234 | 235 | 236 | 237 | 图像保存 238 | 239 | 240 | 241 | 242 | 243 | 1000 244 | 890 245 | 161 246 | 40 247 | 248 | 249 | 250 | 251 | Adobe 宋体 Std L 252 | 20 253 | 254 | 255 | 256 | 图像清除 257 | 258 | 259 | 260 | 261 | 262 | 1300 263 | 890 264 | 161 265 | 40 266 | 267 | 268 | 269 | 270 | Adobe 宋体 Std L 271 | 20 272 | 273 | 274 | 275 | 应用退出 276 | 277 | 278 | label_5 279 | pushButton 280 | pushButton_3 281 | pushButton_4 282 | pushButton_5 283 | pushButton_6 284 | label_2 285 | label_3 286 | label_4 287 | lineEdit 288 | lineEdit_3 289 | lineEdit_4 290 | lineEdit_5 291 | lineEdit_6 292 | 293 | 294 | 295 | 296 | 0 297 | 0 298 | 1500 299 | 30 300 | 301 | 302 | 303 | 304 | 305 | 306 | 307 | 308 | 309 | 310 | -------------------------------------------------------------------------------- /ui/ori_ui/login_ui.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | Form 4 | 5 | 6 | 7 | 0 8 | 0 9 | 800 10 | 600 11 | 12 | 13 | 14 | Form 15 | 16 | 17 | QWidget#Form{ 18 | background-image: url(:/login_background/denglu.jpg);} 19 | 20 | 21 | 22 | 23 | 250 24 | 200 25 | 250 26 | 50 27 | 28 | 29 | 30 | 31 | Adobe 
黑体 Std R 32 | 18 33 | 34 | 35 | 36 | 用户名 37 | 38 | 39 | 40 | 41 | 42 | 220 43 | 430 44 | 120 45 | 80 46 | 47 | 48 | 49 | 50 | Adobe 黑体 Std R 51 | 18 52 | 53 | 54 | 55 | 登录 56 | 57 | 58 | 59 | 60 | 61 | 250 62 | 300 63 | 250 64 | 50 65 | 66 | 67 | 68 | 69 | Adobe 黑体 Std R 70 | 18 71 | 72 | 73 | 74 | 密码 75 | 76 | 77 | 78 | 79 | 80 | 0 81 | 80 82 | 781 83 | 80 84 | 85 | 86 | 87 | 88 | Adobe 黑体 Std R 89 | 22 90 | 91 | 92 | 93 | 94 | 95 | 96 | <html><head/><body><p align="center"><span style=" font-size:20pt; color:#ffffff;">基于YOLOv8的疏果前期幼果检测演示软件</span></p></body></html> 97 | 98 | 99 | 100 | 101 | 102 | 420 103 | 430 104 | 120 105 | 80 106 | 107 | 108 | 109 | 110 | Adobe 黑体 Std R 111 | 18 112 | 113 | 114 | 115 | 注册 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /ui/ori_ui/registe_ui.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | Dialog 4 | 5 | 6 | 7 | 0 8 | 0 9 | 800 10 | 600 11 | 12 | 13 | 14 | Dialog 15 | 16 | 17 | QDialog#Dialog{ 18 | background-image: url(:/registe_background/registe.JPG);} 19 | 20 | 21 | 22 | 23 | 24 | 25 | 260 26 | 120 27 | 241 28 | 81 29 | 30 | 31 | 32 | 33 | Adobe 黑体 Std R 34 | 28 35 | 75 36 | true 37 | 38 | 39 | 40 | 41 | 42 | 43 | <html><head/><body><p align="center"><span style=" font-size:28pt; color:#ffffff;">账号注册</span></p></body></html> 44 | 45 | 46 | 47 | 48 | 49 | 260 50 | 320 51 | 250 52 | 50 53 | 54 | 55 | 56 | 57 | Adobe 黑体 Std R 58 | 18 59 | 60 | 61 | 62 | 设置密码 63 | 64 | 65 | 66 | 67 | 68 | 220 69 | 410 70 | 120 71 | 80 72 | 73 | 74 | 75 | 76 | Adobe 黑体 Std R 77 | 18 78 | 79 | 80 | 81 | 注册 82 | 83 | 84 | 85 | 86 | 87 | 260 88 | 230 89 | 250 90 | 50 91 | 92 | 93 | 94 | 95 | Adobe 黑体 Std R 96 | 18 97 | 98 | 99 | 100 | 注册账号 101 | 102 | 103 | 104 | 105 | 106 | 420 107 | 410 108 | 120 109 | 80 110 | 111 | 112 | 113 | 114 | Adobe 黑体 Std R 115 | 18 116 | 117 | 118 | 119 | 取消 120 | 121 | 122 | 123 | 124 | 
125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /ui/registe_ui.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Form implementation generated from reading ui file 'ui/ori_ui/registe_ui.ui' 4 | # 5 | # Created by: PyQt5 UI code generator 5.15.7 6 | # 7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is 8 | # run again. Do not edit this file unless you know what you are doing. 9 | 10 | 11 | from PyQt5 import QtCore, QtGui, QtWidgets 12 | import ui_img.registe_images_rc 13 | 14 | class Ui_Dialog(object): 15 | def setupUi(self, Dialog): 16 | Dialog.setObjectName("Dialog") 17 | Dialog.setWindowTitle("基于YOLOv8的检测演示软件V1.0") 18 | Dialog.resize(800, 600) 19 | Dialog.setStyleSheet("QDialog#Dialog{background-image: url(:/registe_background/registe.JPG);}\n" 20 | "") 21 | 22 | 23 | self.label = QtWidgets.QLabel(Dialog) 24 | self.label.setGeometry(QtCore.QRect(260, 120, 241, 81)) 25 | font = QtGui.QFont() 26 | font.setFamily("Adobe 黑体 Std R") 27 | font.setPointSize(28) 28 | font.setBold(True) 29 | font.setWeight(75) 30 | self.label.setFont(font) 31 | self.label.setStyleSheet("") 32 | self.label.setObjectName("label") 33 | 34 | self.edit_password = QtWidgets.QLineEdit(Dialog) 35 | self.edit_password.setGeometry(QtCore.QRect(260, 320, 250, 50)) 36 | font = QtGui.QFont() 37 | font.setFamily("Adobe 黑体 Std R") 38 | font.setPointSize(18) 39 | self.edit_password.setFont(font) 40 | self.edit_password.setObjectName("edit_password") 41 | 42 | self.pushButton_regiser = QtWidgets.QPushButton(Dialog) 43 | self.pushButton_regiser.setGeometry(QtCore.QRect(220, 410, 120, 80)) 44 | font = QtGui.QFont() 45 | font.setFamily("Adobe 黑体 Std R") 46 | font.setPointSize(18) 47 | self.pushButton_regiser.setFont(font) 48 | self.pushButton_regiser.setObjectName("pushButton_regiser") 49 | 50 | self.edit_username = QtWidgets.QLineEdit(Dialog) 51 | 
self.edit_username.setGeometry(QtCore.QRect(260, 230, 250, 50)) 52 | font = QtGui.QFont() 53 | font.setFamily("Adobe 黑体 Std R") 54 | font.setPointSize(18) 55 | self.edit_username.setFont(font) 56 | self.edit_username.setObjectName("edit_username") 57 | 58 | self.pushButton_cancer = QtWidgets.QPushButton(Dialog) 59 | self.pushButton_cancer.setGeometry(QtCore.QRect(420, 410, 120, 80)) 60 | font = QtGui.QFont() 61 | font.setFamily("Adobe 黑体 Std R") 62 | font.setPointSize(18) 63 | self.pushButton_cancer.setFont(font) 64 | self.pushButton_cancer.setObjectName("pushButton_cancer") 65 | 66 | self.retranslateUi(Dialog) 67 | QtCore.QMetaObject.connectSlotsByName(Dialog) 68 | 69 | def retranslateUi(self, Dialog): 70 | _translate = QtCore.QCoreApplication.translate 71 | Dialog.setWindowTitle(_translate("Dialog", "Dialog")) 72 | self.label.setText(_translate("Dialog", "

账号注册

")) 73 | self.edit_password.setPlaceholderText(_translate("Dialog", "设置密码")) 74 | self.pushButton_regiser.setText(_translate("Dialog", "注册")) 75 | self.edit_username.setPlaceholderText(_translate("Dialog", "注册账号")) 76 | self.pushButton_cancer.setText(_translate("Dialog", "取消")) 77 | 78 | -------------------------------------------------------------------------------- /ui_img/delete.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ui_img/delete.png -------------------------------------------------------------------------------- /ui_img/detect.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ui_img/detect.JPG -------------------------------------------------------------------------------- /ui_img/detect_images.qrc: -------------------------------------------------------------------------------- 1 | 2 | 3 | detect.JPG 4 | 5 | 6 | upload.png 7 | images.png 8 | delete.png 9 | exit.png 10 | run.png 11 | save.png 12 | 13 | 14 | -------------------------------------------------------------------------------- /ui_img/exit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ui_img/exit.png -------------------------------------------------------------------------------- /ui_img/images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ui_img/images.png -------------------------------------------------------------------------------- /ui_img/login.JPG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ui_img/login.JPG -------------------------------------------------------------------------------- /ui_img/login_images.qrc: -------------------------------------------------------------------------------- 1 | 2 | 3 | login.JPG 4 | 5 | 6 | -------------------------------------------------------------------------------- /ui_img/registe.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ui_img/registe.JPG -------------------------------------------------------------------------------- /ui_img/registe_images.qrc: -------------------------------------------------------------------------------- 1 | 2 | 3 | registe.JPG 4 | 5 | 6 | -------------------------------------------------------------------------------- /ui_img/run.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ui_img/run.png -------------------------------------------------------------------------------- /ui_img/save.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ui_img/save.png -------------------------------------------------------------------------------- /ui_img/upload.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ui_img/upload.png -------------------------------------------------------------------------------- /ultralytics/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | 3 | 
def start(key=''):
    """
    Start training models with Ultralytics HUB. Usage: from ultralytics.hub import start; start('API_KEY')
    """
    auth = Auth(key)
    # Resolve the model id either from the provided key or by prompting the user.
    if auth.get_state():
        _, model_id = split_key(key)
    else:
        model_id = request_api_key(auth)

    if not model_id:
        raise ConnectionError(emojis('Connecting with global API key is not currently supported. ❌'))

    # Open a HUB training session for the model and verify local disk resources.
    session = HUBTrainingSession(model_id=model_id, auth=auth)
    session.check_disk_space()

    # Load the session's checkpoint and launch training with the server-side args.
    YOLO(model=session.model_file, session=session).train(**session.train_args)
def get_export(key='', format='torchscript'):
    """Get an exported model dictionary with its download URL.

    Args:
        key (str): 'api_key_model_id' model key (user is prompted if empty).
        format (str): Desired export format; must be one of export_fmts_hub().

    Returns:
        dict: Parsed JSON response from the HUB get-export endpoint.

    Raises:
        AssertionError: If the format is unsupported or the request fails.
    """
    # Bug fix: export_fmts_hub is a function and must be CALLED. The bare name made
    # `format in export_fmts_hub` raise TypeError on every call and interpolated the
    # function repr into the error message (sibling export_model calls it correctly).
    assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
    api_key, model_id = split_key(key)
    r = requests.post('https://api.ultralytics.com/get-export',
                      json={
                          'apiKey': api_key,
                          'modelId': model_id,
                          'format': format})
    assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}'
    return r.json()
    def set_api_key(self, key: str):
        """Set the API key used for subsequent authentication requests."""
        self.api_key = key
rate limit timers (seconds) 24 | self.metrics_queue = {} # metrics queue 25 | self.model = self._get_model() 26 | self.alive = True 27 | self._start_heartbeat() # start heartbeats 28 | self._register_signal_handlers() 29 | 30 | def _register_signal_handlers(self): 31 | signal.signal(signal.SIGTERM, self._handle_signal) 32 | signal.signal(signal.SIGINT, self._handle_signal) 33 | 34 | def _handle_signal(self, signum, frame): 35 | """ 36 | Prevent heartbeats from being sent on Colab after kill. 37 | This method does not use frame, it is included as it is 38 | passed by signal. 39 | """ 40 | if self.alive is True: 41 | LOGGER.info(f'{PREFIX}Kill signal received! ❌') 42 | self._stop_heartbeat() 43 | sys.exit(signum) 44 | 45 | def _stop_heartbeat(self): 46 | """End the heartbeat loop""" 47 | self.alive = False 48 | 49 | def upload_metrics(self): 50 | payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'} 51 | smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2) 52 | 53 | def _get_model(self): 54 | # Returns model from database by id 55 | api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}' 56 | 57 | try: 58 | response = smart_request('get', api_url, headers=self.auth_header, thread=False, code=0) 59 | data = response.json().get('data', None) 60 | 61 | if data.get('status', None) == 'trained': 62 | raise ValueError( 63 | emojis(f'Model is already trained and uploaded to ' 64 | f'https://hub.ultralytics.com/models/{self.model_id} 🚀')) 65 | 66 | if not data.get('data', None): 67 | raise ValueError('Dataset may still be processing. 
Please wait a minute and try again.') # RF fix 68 | self.model_id = data['id'] 69 | 70 | # TODO: restore when server keys when dataset URL and GPU train is working 71 | 72 | self.train_args = { 73 | 'batch': data['batch_size'], 74 | 'epochs': data['epochs'], 75 | 'imgsz': data['imgsz'], 76 | 'patience': data['patience'], 77 | 'device': data['device'], 78 | 'cache': data['cache'], 79 | 'data': data['data']} 80 | 81 | self.model_file = data.get('cfg', data['weights']) 82 | self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u 83 | 84 | return data 85 | except requests.exceptions.ConnectionError as e: 86 | raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e 87 | except Exception: 88 | raise 89 | 90 | def check_disk_space(self): 91 | if not check_dataset_disk_space(self.model['data']): 92 | raise MemoryError('Not enough disk space') 93 | 94 | def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False): 95 | # Upload a model to HUB 96 | if Path(weights).is_file(): 97 | with open(weights, 'rb') as f: 98 | file = f.read() 99 | else: 100 | LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. 
    @threaded
    def _start_heartbeat(self):
        """Periodically POST a heartbeat to HUB while ``self.alive`` is True.

        Executed off the main thread via the @threaded decorator (hence
        thread=False in the request below). Each response may carry an agentId,
        which is stored and echoed back on the next beat so the server can
        identify this instance.
        """
        while self.alive:
            r = smart_request('post',
                              f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
                              json={
                                  'agent': AGENT_NAME,
                                  'agentId': self.agent_id},
                              headers=self.auth_header,
                              retry=0,
                              code=5,
                              thread=False)  # already in a thread
            self.agent_id = r.json().get('data', {}).get('agentId', None)
            sleep(self.rate_limits['heartbeat'])  # heartbeat interval in seconds
def split_key(key=''):
    """
    Verify and split a 'api_key[sep]model_id' string, sep is one of '.' or '_'

    Args:
        key (str): The model key to split. If not provided, the user will be prompted to enter it.

    Returns:
        Tuple[str, str]: A tuple containing the API key and model ID.

    Raises:
        AssertionError: If the key has no separator, more than one separator,
            or an empty part on either side.
    """

    import getpass

    error_string = emojis(f'{PREFIX}Invalid API key ⚠️\n')  # error string
    if not key:
        key = getpass.getpass('Enter model key: ')
    sep = '_' if '_' in key else '.' if '.' in key else None  # separator
    assert sep, error_string
    parts = key.split(sep)
    # Bug fix: a malformed key with multiple separators previously crashed with a
    # raw unpacking ValueError; report it through the same friendly assertion.
    assert len(parts) == 2, error_string
    api_key, model_id = parts
    assert len(api_key) and len(model_id), error_string
    return api_key, model_id
119 | retry (int, optional): Number of retries to attempt before giving up. Default is 3. 120 | timeout (int, optional): Timeout in seconds after which the function will give up retrying. Default is 30. 121 | thread (bool, optional): Whether to execute the request in a separate daemon thread. Default is True. 122 | code (int, optional): An identifier for the request, used for logging purposes. Default is -1. 123 | verbose (bool, optional): A flag to determine whether to print out to console or not. Default is True. 124 | progress (bool, optional): Whether to show a progress bar during the request. Default is False. 125 | **kwargs: Keyword arguments to be passed to the requests function specified in method. 126 | 127 | Returns: 128 | requests.Response: The HTTP response object. If the request is executed in a separate thread, returns None. 129 | 130 | """ 131 | retry_codes = (408, 500) # retry only these codes 132 | 133 | @TryExcept(verbose=verbose) 134 | def func(func_method, func_url, **func_kwargs): 135 | r = None # response 136 | t0 = time.time() # initial time for timer 137 | for i in range(retry + 1): 138 | if (time.time() - t0) > timeout: 139 | break 140 | r = requests_with_progress(func_method, func_url, **func_kwargs) # i.e. get(url, data, json, files) 141 | if r.status_code == 200: 142 | break 143 | try: 144 | m = r.json().get('message', 'No JSON message.') 145 | except AttributeError: 146 | m = 'Unable to read JSON.' 147 | if i == 0: 148 | if r.status_code in retry_codes: 149 | m += f' Retrying {retry}x for {timeout}s.' if retry else '' 150 | elif r.status_code == 429: # rate limit 151 | h = r.headers # response headers 152 | m = f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " \ 153 | f"Please retry after {h['Retry-After']}s." 
class Traces:
    """Anonymous usage-analytics reporter, rate-limited and sampled.

    Collects static environment metadata at init and, when enabled via the
    global settings, syncs non-default config values to the HUB API on call.
    """

    def __init__(self):
        """
        Initialize Traces for error tracking and reporting if tests are not currently running.
        """
        self.rate_limit = 3.0  # rate limit (seconds)
        self.t = 0.0  # rate limit timer (seconds)
        # Static metadata attached to every trace payload.
        self.metadata = {
            'sys_argv_name': Path(sys.argv[0]).name,
            'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
            'python': platform.python_version(),
            'release': __version__,
            'environment': ENVIRONMENT}
        # Enabled only when sync is on, running as the main rank, not under
        # tests, online, and using the official pip/git distribution.
        self.enabled = \
            SETTINGS['sync'] and \
            RANK in (-1, 0) and \
            not TESTS_RUNNING and \
            ONLINE and \
            (is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git')

    def __call__(self, cfg, all_keys=False, traces_sample_rate=1.0):
        """
        Sync traces data if enabled in the global settings

        Args:
            cfg (IterableSimpleNamespace): Configuration for the task and mode.
            all_keys (bool): Sync all items, not just non-default values.
            traces_sample_rate (float): Fraction of traces captured from 0.0 to 1.0
        """
        t = time.time()  # current time
        if self.enabled and random() < traces_sample_rate and (t - self.t) > self.rate_limit:
            self.t = t  # reset rate limit timer
            cfg = vars(cfg)  # convert type from IterableSimpleNamespace to dict
            if not all_keys:  # filter cfg
                include_keys = {'task', 'mode'}  # always include
                # Keep only non-default values; paths are reduced to basenames.
                cfg = {
                    k: (v.split(os.sep)[-1] if isinstance(v, str) and os.sep in v else v)
                    for k, v in cfg.items() if v != DEFAULT_CFG_DICT.get(k, None) or k in include_keys}
            trace = {'uuid': SETTINGS['uuid'], 'cfg': cfg, 'metadata': self.metadata}

            # Send a request to the HUB API to sync analytics
            smart_request('post', f'{HUB_API_ROOT}/v1/usage/anonymous', json=trace, code=3, retry=0, verbose=False)
import v8 4 | 5 | __all__ = 'v8', # tuple or list 6 | -------------------------------------------------------------------------------- /ultralytics/yolo/cfg/default.yaml: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | # Default training settings and hyperparameters for medium-augmentation COCO training 3 | 4 | task: detect # inference task, i.e. detect, segment, classify 5 | mode: train # YOLO mode, i.e. train, val, predict, export 6 | 7 | # Train settings ------------------------------------------------------------------------------------------------------- 8 | model: # path to model file, i.e. yolov8n.pt, yolov8n.yaml 9 | data: # path to data file, i.e. coco128.yaml 10 | epochs: 100 # number of epochs to train for 11 | patience: 50 # epochs to wait for no observable improvement for early stopping of training 12 | batch: 16 # number of images per batch (-1 for AutoBatch) 13 | imgsz: 640 # size of input images as integer or w,h 14 | save: True # save train checkpoints and predict results 15 | save_period: -1 # Save checkpoint every x epochs (disabled if < 1) 16 | cache: False # True/ram, disk or False. Use cache for data loading 17 | device: # device to run on, i.e. 
cuda device=0 or device=0,1,2,3 or device=cpu 18 | workers: 8 # number of worker threads for data loading (per RANK if DDP) 19 | project: # project name 20 | name: # experiment name, results saved to 'project/name' directory 21 | exist_ok: False # whether to overwrite existing experiment 22 | pretrained: False # whether to use a pretrained model 23 | optimizer: SGD # optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp'] 24 | verbose: True # whether to print verbose output 25 | seed: 0 # random seed for reproducibility 26 | deterministic: True # whether to enable deterministic mode 27 | single_cls: False # train multi-class data as single-class 28 | image_weights: False # use weighted image selection for training 29 | rect: False # support rectangular training if mode='train', support rectangular evaluation if mode='val' 30 | cos_lr: False # use cosine learning rate scheduler 31 | close_mosaic: 10 # disable mosaic augmentation for final 10 epochs 32 | resume: False # resume training from last checkpoint 33 | # Segmentation 34 | overlap_mask: True # masks should overlap during training (segment train only) 35 | mask_ratio: 4 # mask downsample ratio (segment train only) 36 | # Classification 37 | dropout: 0.0 # use dropout regularization (classify train only) 38 | 39 | # Val/Test settings ---------------------------------------------------------------------------------------------------- 40 | val: True # validate/test during training 41 | split: val # dataset split to use for validation, i.e. 
'val', 'test' or 'train' 42 | save_json: False # save results to JSON file 43 | save_hybrid: False # save hybrid version of labels (labels + additional predictions) 44 | conf: # object confidence threshold for detection (default 0.25 predict, 0.001 val) 45 | iou: 0.7 # intersection over union (IoU) threshold for NMS 46 | max_det: 300 # maximum number of detections per image 47 | half: False # use half precision (FP16) 48 | dnn: False # use OpenCV DNN for ONNX inference 49 | plots: True # save plots during train/val 50 | 51 | # Prediction settings -------------------------------------------------------------------------------------------------- 52 | source: # source directory for images or videos 53 | show: False # show results if possible 54 | save_txt: False # save results as .txt file 55 | save_conf: False # save results with confidence scores 56 | save_crop: False # save cropped images with results 57 | hide_labels: False # hide labels 58 | hide_conf: False # hide confidence scores 59 | vid_stride: 1 # video frame-rate stride 60 | line_thickness: 3 # bounding box thickness (pixels) 61 | visualize: False # visualize model features 62 | augment: False # apply image augmentation to prediction sources 63 | agnostic_nms: False # class-agnostic NMS 64 | classes: # filter results by class, i.e. 
class=0, or class=[0,2,3] 65 | retina_masks: False # use high-resolution segmentation masks 66 | boxes: True # Show boxes in segmentation predictions 67 | 68 | # Export settings ------------------------------------------------------------------------------------------------------ 69 | format: torchscript # format to export to 70 | keras: False # use Keras 71 | optimize: False # TorchScript: optimize for mobile 72 | int8: False # CoreML/TF INT8 quantization 73 | dynamic: False # ONNX/TF/TensorRT: dynamic axes 74 | simplify: False # ONNX: simplify model 75 | opset: # ONNX: opset version (optional) 76 | workspace: 4 # TensorRT: workspace size (GB) 77 | nms: False # CoreML: add NMS 78 | 79 | # Hyperparameters ------------------------------------------------------------------------------------------------------ 80 | lr0: 0.01 # initial learning rate (i.e. SGD=1E-2, Adam=1E-3) 81 | lrf: 0.01 # final learning rate (lr0 * lrf) 82 | momentum: 0.937 # SGD momentum/Adam beta1 83 | weight_decay: 0.0005 # optimizer weight decay 5e-4 84 | warmup_epochs: 3.0 # warmup epochs (fractions ok) 85 | warmup_momentum: 0.8 # warmup initial momentum 86 | warmup_bias_lr: 0.1 # warmup initial bias lr 87 | box: 7.5 # box loss gain 88 | cls: 0.5 # cls loss gain (scale with pixels) 89 | dfl: 1.5 # dfl loss gain 90 | fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5) 91 | label_smoothing: 0.0 # label smoothing (fraction) 92 | nbs: 64 # nominal batch size 93 | hsv_h: 0.015 # image HSV-Hue augmentation (fraction) 94 | hsv_s: 0.7 # image HSV-Saturation augmentation (fraction) 95 | hsv_v: 0.4 # image HSV-Value augmentation (fraction) 96 | degrees: 0.0 # image rotation (+/- deg) 97 | translate: 0.1 # image translation (+/- fraction) 98 | scale: 0.5 # image scale (+/- gain) 99 | shear: 0.0 # image shear (+/- deg) 100 | perspective: 0.0 # image perspective (+/- fraction), range 0-0.001 101 | flipud: 0.0 # image flip up-down (probability) 102 | fliplr: 0.5 # image flip left-right 
(probability) 103 | mosaic: 1.0 # image mosaic (probability) 104 | mixup: 0.0 # image mixup (probability) 105 | copy_paste: 0.0 # segment copy-paste (probability) 106 | 107 | # Custom config.yaml --------------------------------------------------------------------------------------------------- 108 | cfg: # for overriding defaults.yaml 109 | 110 | # Debug, do not modify ------------------------------------------------------------------------------------------------- 111 | v5loader: False # use legacy YOLOv5 dataloader 112 | 113 | # Tracker settings ------------------------------------------------------------------------------------------------------ 114 | tracker: botsort.yaml # tracker type, ['botsort.yaml', 'bytetrack.yaml'] 115 | -------------------------------------------------------------------------------- /ultralytics/yolo/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | 3 | from .base import BaseDataset 4 | from .build import build_classification_dataloader, build_dataloader, load_inference_source 5 | from .dataset import ClassificationDataset, SemanticDataset, YOLODataset 6 | from .dataset_wrappers import MixAndRectDataset 7 | 8 | __all__ = ('BaseDataset', 'ClassificationDataset', 'MixAndRectDataset', 'SemanticDataset', 'YOLODataset', 9 | 'build_classification_dataloader', 'build_dataloader', 'load_inference_source') 10 | -------------------------------------------------------------------------------- /ultralytics/yolo/data/base.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | 3 | import glob 4 | import math 5 | import os 6 | from multiprocessing.pool import ThreadPool 7 | from pathlib import Path 8 | from typing import Optional 9 | 10 | import cv2 11 | import numpy as np 12 | from torch.utils.data import Dataset 13 | from tqdm import tqdm 14 | 15 | from ..utils import NUM_THREADS, 
TQDM_BAR_FORMAT 16 | from .utils import HELP_URL, IMG_FORMATS, LOCAL_RANK 17 | 18 | 19 | class BaseDataset(Dataset): 20 | """Base Dataset. 21 | Args: 22 | img_path (str): image path. 23 | pipeline (dict): a dict of image transforms. 24 | label_path (str): label path, this can also be an ann_file or other custom label path. 25 | """ 26 | 27 | def __init__(self, 28 | img_path, 29 | imgsz=640, 30 | cache=False, 31 | augment=True, 32 | hyp=None, 33 | prefix='', 34 | rect=False, 35 | batch_size=None, 36 | stride=32, 37 | pad=0.5, 38 | single_cls=False, 39 | classes=None): 40 | super().__init__() 41 | self.img_path = img_path 42 | self.imgsz = imgsz 43 | self.augment = augment 44 | self.single_cls = single_cls 45 | self.prefix = prefix 46 | 47 | self.im_files = self.get_img_files(self.img_path) 48 | self.labels = self.get_labels() 49 | self.update_labels(include_class=classes) # single_cls and include_class 50 | 51 | self.ni = len(self.labels) 52 | 53 | # rect stuff 54 | self.rect = rect 55 | self.batch_size = batch_size 56 | self.stride = stride 57 | self.pad = pad 58 | if self.rect: 59 | assert self.batch_size is not None 60 | self.set_rectangle() 61 | 62 | # cache stuff 63 | self.ims = [None] * self.ni 64 | self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files] 65 | if cache: 66 | self.cache_images(cache) 67 | 68 | # transforms 69 | self.transforms = self.build_transforms(hyp=hyp) 70 | 71 | def get_img_files(self, img_path): 72 | """Read image files.""" 73 | try: 74 | f = [] # image files 75 | for p in img_path if isinstance(img_path, list) else [img_path]: 76 | p = Path(p) # os-agnostic 77 | if p.is_dir(): # dir 78 | f += glob.glob(str(p / '**' / '*.*'), recursive=True) 79 | # f = list(p.rglob('*.*')) # pathlib 80 | elif p.is_file(): # file 81 | with open(p) as t: 82 | t = t.read().strip().splitlines() 83 | parent = str(p.parent) + os.sep 84 | f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path 85 | # f += 
[p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib) 86 | else: 87 | raise FileNotFoundError(f'{self.prefix}{p} does not exist') 88 | im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS) 89 | # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib 90 | assert im_files, f'{self.prefix}No images found' 91 | except Exception as e: 92 | raise FileNotFoundError(f'{self.prefix}Error loading data from {img_path}\n{HELP_URL}') from e 93 | return im_files 94 | 95 | def update_labels(self, include_class: Optional[list]): 96 | """include_class, filter labels to include only these classes (optional)""" 97 | include_class_array = np.array(include_class).reshape(1, -1) 98 | for i in range(len(self.labels)): 99 | if include_class is not None: 100 | cls = self.labels[i]['cls'] 101 | bboxes = self.labels[i]['bboxes'] 102 | segments = self.labels[i]['segments'] 103 | j = (cls == include_class_array).any(1) 104 | self.labels[i]['cls'] = cls[j] 105 | self.labels[i]['bboxes'] = bboxes[j] 106 | if segments: 107 | self.labels[i]['segments'] = [segments[si] for si, idx in enumerate(j) if idx] 108 | if self.single_cls: 109 | self.labels[i]['cls'][:, 0] = 0 110 | 111 | def load_image(self, i): 112 | # Loads 1 image from dataset index 'i', returns (im, resized hw) 113 | im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i] 114 | if im is None: # not cached in RAM 115 | if fn.exists(): # load npy 116 | im = np.load(fn) 117 | else: # read image 118 | im = cv2.imread(f) # BGR 119 | if im is None: 120 | raise FileNotFoundError(f'Image Not Found {f}') 121 | h0, w0 = im.shape[:2] # orig hw 122 | r = self.imgsz / max(h0, w0) # ratio 123 | if r != 1: # if sizes are not equal 124 | interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA 125 | im = cv2.resize(im, (math.ceil(w0 * r), math.ceil(h0 * r)), interpolation=interp) 126 | return im, (h0, w0), im.shape[:2] # im, 
hw_original, hw_resized 127 | return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized 128 | 129 | def cache_images(self, cache): 130 | # cache images to memory or disk 131 | gb = 0 # Gigabytes of cached images 132 | self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni 133 | fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image 134 | with ThreadPool(NUM_THREADS) as pool: 135 | results = pool.imap(fcn, range(self.ni)) 136 | pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0) 137 | for i, x in pbar: 138 | if cache == 'disk': 139 | gb += self.npy_files[i].stat().st_size 140 | else: # 'ram' 141 | self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i) 142 | gb += self.ims[i].nbytes 143 | pbar.desc = f'{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})' 144 | pbar.close() 145 | 146 | def cache_images_to_disk(self, i): 147 | # Saves an image as an *.npy file for faster loading 148 | f = self.npy_files[i] 149 | if not f.exists(): 150 | np.save(f.as_posix(), cv2.imread(self.im_files[i])) 151 | 152 | def set_rectangle(self): 153 | bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index 154 | nb = bi[-1] + 1 # number of batches 155 | 156 | s = np.array([x.pop('shape') for x in self.labels]) # hw 157 | ar = s[:, 0] / s[:, 1] # aspect ratio 158 | irect = ar.argsort() 159 | self.im_files = [self.im_files[i] for i in irect] 160 | self.labels = [self.labels[i] for i in irect] 161 | ar = ar[irect] 162 | 163 | # Set training image shapes 164 | shapes = [[1, 1]] * nb 165 | for i in range(nb): 166 | ari = ar[bi == i] 167 | mini, maxi = ari.min(), ari.max() 168 | if maxi < 1: 169 | shapes[i] = [maxi, 1] 170 | elif mini > 1: 171 | shapes[i] = [1, 1 / mini] 172 | 173 | self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride 174 | self.batch = bi # batch index of image 175 | 
    def __getitem__(self, index):
        """Return one fully transformed sample; the transform pipeline receives the raw label dict."""
        return self.transforms(self.get_label_info(index))

    def get_label_info(self, index):
        """Assemble the label dict for one sample: pixels plus original/resized shape metadata."""
        label = self.labels[index].copy()
        label.pop('shape', None)  # shape is for rect, remove it
        label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)
        label['ratio_pad'] = (
            label['resized_shape'][0] / label['ori_shape'][0],
            label['resized_shape'][1] / label['ori_shape'][1],
        )  # for evaluation
        if self.rect:
            # rectangular training: target shape is the precomputed shape of this image's batch
            label['rect_shape'] = self.batch_shapes[self.batch[index]]
        label = self.update_labels_info(label)
        return label

    def __len__(self):
        # dataset length = number of label dicts (not raw image files)
        return len(self.labels)

    def update_labels_info(self, label):
        """Hook for subclasses to customize the label format; default is a no-op."""
        return label

    def build_transforms(self, hyp=None):
        """Users can custom augmentations here
        like:
            if self.augment:
                # training transforms
                return Compose([])
            else:
                # val transforms
                return Compose([])
        """
        raise NotImplementedError

    def get_labels(self):
        """Users can custom their own format here.
213 | Make sure your output is a list with each element like below: 214 | dict( 215 | im_file=im_file, 216 | shape=shape, # format: (height, width) 217 | cls=cls, 218 | bboxes=bboxes, # xywh 219 | segments=segments, # xy 220 | keypoints=keypoints, # xy 221 | normalized=True, # or False 222 | bbox_format="xyxy", # or xywh, ltwh 223 | ) 224 | """ 225 | raise NotImplementedError 226 | -------------------------------------------------------------------------------- /ultralytics/yolo/data/build.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | 3 | import os 4 | import random 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import torch 9 | from PIL import Image 10 | from torch.utils.data import DataLoader, dataloader, distributed 11 | 12 | from ultralytics.yolo.data.dataloaders.stream_loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots, 13 | LoadStreams, LoadTensor, SourceTypes, autocast_list) 14 | from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS 15 | from ultralytics.yolo.utils.checks import check_file 16 | 17 | from ..utils import LOGGER, colorstr 18 | from ..utils.torch_utils import torch_distributed_zero_first 19 | from .dataset import ClassificationDataset, YOLODataset 20 | from .utils import PIN_MEMORY, RANK 21 | 22 | 23 | class InfiniteDataLoader(dataloader.DataLoader): 24 | """Dataloader that reuses workers 25 | 26 | Uses same syntax as vanilla DataLoader 27 | """ 28 | 29 | def __init__(self, *args, **kwargs): 30 | super().__init__(*args, **kwargs) 31 | object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler)) 32 | self.iterator = super().__iter__() 33 | 34 | def __len__(self): 35 | return len(self.batch_sampler.sampler) 36 | 37 | def __iter__(self): 38 | for _ in range(len(self)): 39 | yield next(self.iterator) 40 | 41 | 42 | class _RepeatSampler: 43 | """Sampler that repeats forever 44 | 45 | Args: 46 | sampler 
(Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        # Re-yield the wrapped sampler's indices forever so worker processes are never torn down
        while True:
            yield from iter(self.sampler)


def seed_worker(worker_id):  # noqa
    # Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, rank=-1, mode='train'):
    """Build a train/val dataloader and return (loader, dataset).

    Args:
        cfg: config namespace; reads imgsz, rect, cache, single_cls, task, workers,
            image_weights, close_mosaic, classes.
        batch (int): requested batch size (clamped to dataset length below).
        img_path (str): image directory/file path passed to YOLODataset.
        stride (int): model stride, used for rectangular batch shapes.
        rect (bool): force rectangular batches in addition to cfg.rect.
        names: class names passed through to the dataset.
        rank (int): DDP rank; -1 means single-process.
        mode (str): 'train' or 'val'; controls augmentation, shuffle and padding.
    """
    assert mode in ['train', 'val']
    shuffle = mode == 'train'
    if cfg.rect and shuffle:
        LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
        shuffle = False
    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
        dataset = YOLODataset(
            img_path=img_path,
            imgsz=cfg.imgsz,
            batch_size=batch,
            augment=mode == 'train',  # augmentation
            hyp=cfg,  # TODO: probably add a get_hyps_from_cfg function
            rect=cfg.rect or rect,  # rectangular batches
            cache=cfg.cache or None,
            single_cls=cfg.single_cls or False,
            stride=int(stride),
            pad=0.0 if mode == 'train' else 0.5,
            prefix=colorstr(f'{mode}: '),
            use_segments=cfg.task == 'segment',
            use_keypoints=cfg.task == 'keypoint',
            names=names,
            classes=cfg.classes)

    batch = min(batch, len(dataset))
    nd = torch.cuda.device_count()  # number of CUDA devices
    workers = cfg.workers if mode == 'train' else cfg.workers * 2
    nw = min([os.cpu_count() // max(nd, 1), batch if batch > 1 else 0, workers])  # number of workers
    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    loader = DataLoader if cfg.image_weights or cfg.close_mosaic else InfiniteDataLoader  # allow attribute updates
    generator = torch.Generator()
    # Fixed seed offset + RANK keeps per-process generators distinct but reproducible in DDP
    generator.manual_seed(6148914691236517205 + RANK)
return loader(dataset=dataset, 97 | batch_size=batch, 98 | shuffle=shuffle and sampler is None, 99 | num_workers=nw, 100 | sampler=sampler, 101 | pin_memory=PIN_MEMORY, 102 | collate_fn=getattr(dataset, 'collate_fn', None), 103 | worker_init_fn=seed_worker, 104 | generator=generator), dataset 105 | 106 | 107 | # build classification 108 | # TODO: using cfg like `build_dataloader` 109 | def build_classification_dataloader(path, 110 | imgsz=224, 111 | batch_size=16, 112 | augment=True, 113 | cache=False, 114 | rank=-1, 115 | workers=8, 116 | shuffle=True): 117 | # Returns Dataloader object to be used with YOLOv5 Classifier 118 | with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP 119 | dataset = ClassificationDataset(root=path, imgsz=imgsz, augment=augment, cache=cache) 120 | batch_size = min(batch_size, len(dataset)) 121 | nd = torch.cuda.device_count() 122 | nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) 123 | sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle) 124 | generator = torch.Generator() 125 | generator.manual_seed(6148914691236517205 + RANK) 126 | return InfiniteDataLoader(dataset, 127 | batch_size=batch_size, 128 | shuffle=shuffle and sampler is None, 129 | num_workers=nw, 130 | sampler=sampler, 131 | pin_memory=PIN_MEMORY, 132 | worker_init_fn=seed_worker, 133 | generator=generator) # or DataLoader(persistent_workers=True) 134 | 135 | 136 | def check_source(source): 137 | webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False 138 | if isinstance(source, (str, int, Path)): # int for local usb camera 139 | source = str(source) 140 | is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS) 141 | is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://')) 142 | webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file) 143 | screenshot = 
source.lower().startswith('screen') 144 | if is_url and is_file: 145 | source = check_file(source) # download 146 | elif isinstance(source, tuple(LOADERS)): 147 | in_memory = True 148 | elif isinstance(source, (list, tuple)): 149 | source = autocast_list(source) # convert all list elements to PIL or np arrays 150 | from_img = True 151 | elif isinstance(source, (Image.Image, np.ndarray)): 152 | from_img = True 153 | elif isinstance(source, torch.Tensor): 154 | tensor = True 155 | else: 156 | raise TypeError('Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict') 157 | 158 | return source, webcam, screenshot, from_img, in_memory, tensor 159 | 160 | 161 | def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1, stride=32, auto=True): 162 | """ 163 | TODO: docs 164 | """ 165 | source, webcam, screenshot, from_img, in_memory, tensor = check_source(source) 166 | source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img, tensor) 167 | 168 | # Dataloader 169 | if tensor: 170 | dataset = LoadTensor(source) 171 | elif in_memory: 172 | dataset = source 173 | elif webcam: 174 | dataset = LoadStreams(source, 175 | imgsz=imgsz, 176 | stride=stride, 177 | auto=auto, 178 | transforms=transforms, 179 | vid_stride=vid_stride) 180 | 181 | elif screenshot: 182 | dataset = LoadScreenshots(source, imgsz=imgsz, stride=stride, auto=auto, transforms=transforms) 183 | elif from_img: 184 | dataset = LoadPilAndNumpy(source, imgsz=imgsz, stride=stride, auto=auto, transforms=transforms) 185 | else: 186 | dataset = LoadImages(source, 187 | imgsz=imgsz, 188 | stride=stride, 189 | auto=auto, 190 | transforms=transforms, 191 | vid_stride=vid_stride) 192 | 193 | setattr(dataset, 'source_type', source_type) # attach source types 194 | 195 | return dataset 196 | -------------------------------------------------------------------------------- /ultralytics/yolo/data/dataloaders/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ultralytics/yolo/data/dataloaders/__init__.py --------------------------------------------------------------------------------
/ultralytics/yolo/data/dataset_wrappers.py:
--------------------------------------------------------------------------------
# Ultralytics YOLO 🚀, GPL-3.0 license

import collections
from copy import deepcopy

from .augment import LetterBox


class MixAndRectDataset:
    """A wrapper of multiple images mixed dataset.

    Applies the wrapped dataset's own transform pipeline, feeding multi-image
    transforms (mosaic/mixup) the extra samples they request.

    Args:
        dataset (:obj:`BaseDataset`): The dataset to be mixed.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.imgsz = dataset.imgsz

    def __len__(self):
        # delegate length to the wrapped dataset
        return len(self.dataset)

    def __getitem__(self, index):
        labels = deepcopy(self.dataset[index])
        for transform in self.dataset.transforms.tolist():
            # mosaic and mixup: transforms exposing get_indexes() ask for extra samples,
            # which are attached under 'mix_labels' for the transform to consume
            if hasattr(transform, 'get_indexes'):
                indexes = transform.get_indexes(self.dataset)
                if not isinstance(indexes, collections.abc.Sequence):
                    indexes = [indexes]
                mix_labels = [deepcopy(self.dataset[index]) for index in indexes]
                labels['mix_labels'] = mix_labels
            if self.dataset.rect and isinstance(transform, LetterBox):
                # rectangular mode: letterbox to this image's precomputed batch shape
                transform.new_shape = self.dataset.batch_shapes[self.dataset.batch[index]]
            labels = transform(labels)
        if 'mix_labels' in labels:
            labels.pop('mix_labels')  # drop helper key so consumers see a plain label dict
        return labels
--------------------------------------------------------------------------------
/ultralytics/yolo/engine/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbl1234/YOLOv8_PYQT5_GUI/0b8a3f89d7a4b8f6a55b66da5d425b7b2918a407/ultralytics/yolo/engine/__init__.py -------------------------------------------------------------------------------- /ultralytics/yolo/utils/autobatch.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | """ 3 | Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch. 4 | """ 5 | 6 | from copy import deepcopy 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from ultralytics.yolo.utils import LOGGER, colorstr 12 | from ultralytics.yolo.utils.torch_utils import profile 13 | 14 | 15 | def check_train_batch_size(model, imgsz=640, amp=True): 16 | """ 17 | Check YOLO training batch size using the autobatch() function. 18 | 19 | Args: 20 | model (torch.nn.Module): YOLO model to check batch size for. 21 | imgsz (int): Image size used for training. 22 | amp (bool): If True, use automatic mixed precision (AMP) for training. 23 | 24 | Returns: 25 | int: Optimal batch size computed using the autobatch() function. 26 | """ 27 | 28 | with torch.cuda.amp.autocast(amp): 29 | return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size 30 | 31 | 32 | def autobatch(model, imgsz=640, fraction=0.67, batch_size=16): 33 | """ 34 | Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory. 35 | 36 | Args: 37 | model: YOLO model to compute batch size for. 38 | imgsz (int, optional): The image size used as input for the YOLO model. Defaults to 640. 39 | fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.67. 40 | batch_size (int, optional): The default batch size to use if an error is detected. Defaults to 16. 41 | 42 | Returns: 43 | int: The optimal batch size. 
44 | """ 45 | 46 | # Check device 47 | prefix = colorstr('AutoBatch: ') 48 | LOGGER.info(f'{prefix}Computing optimal batch size for imgsz={imgsz}') 49 | device = next(model.parameters()).device # get model device 50 | if device.type == 'cpu': 51 | LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}') 52 | return batch_size 53 | if torch.backends.cudnn.benchmark: 54 | LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}') 55 | return batch_size 56 | 57 | # Inspect CUDA memory 58 | gb = 1 << 30 # bytes to GiB (1024 ** 3) 59 | d = str(device).upper() # 'CUDA:0' 60 | properties = torch.cuda.get_device_properties(device) # device properties 61 | t = properties.total_memory / gb # GiB total 62 | r = torch.cuda.memory_reserved(device) / gb # GiB reserved 63 | a = torch.cuda.memory_allocated(device) / gb # GiB allocated 64 | f = t - (r + a) # GiB free 65 | LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free') 66 | 67 | # Profile batch sizes 68 | batch_sizes = [1, 2, 4, 8, 16] 69 | try: 70 | img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes] 71 | results = profile(img, model, n=3, device=device) 72 | 73 | # Fit a solution 74 | y = [x[2] for x in results if x] # memory [2] 75 | p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit 76 | b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size) 77 | if None in results: # some sizes failed 78 | i = results.index(None) # first fail index 79 | if b >= batch_sizes[i]: # y intercept above failure point 80 | b = batch_sizes[max(i - 1, 0)] # select prior safe point 81 | if b < 1 or b > 1024: # b outside of safe range 82 | b = batch_size 83 | LOGGER.info(f'{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.') 84 | 85 | fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted 86 | 
LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅') 87 | return b 88 | except Exception as e: 89 | LOGGER.warning(f'{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.') 90 | return batch_size 91 | -------------------------------------------------------------------------------- /ultralytics/yolo/utils/benchmarks.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | """ 3 | Benchmark a YOLO model formats for speed and accuracy 4 | 5 | Usage: 6 | from ultralytics.yolo.utils.benchmarks import run_benchmarks 7 | run_benchmarks(model='yolov8n.pt', imgsz=160) 8 | 9 | Format | `format=argument` | Model 10 | --- | --- | --- 11 | PyTorch | - | yolov8n.pt 12 | TorchScript | `torchscript` | yolov8n.torchscript 13 | ONNX | `onnx` | yolov8n.onnx 14 | OpenVINO | `openvino` | yolov8n_openvino_model/ 15 | TensorRT | `engine` | yolov8n.engine 16 | CoreML | `coreml` | yolov8n.mlmodel 17 | TensorFlow SavedModel | `saved_model` | yolov8n_saved_model/ 18 | TensorFlow GraphDef | `pb` | yolov8n.pb 19 | TensorFlow Lite | `tflite` | yolov8n.tflite 20 | TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite 21 | TensorFlow.js | `tfjs` | yolov8n_web_model/ 22 | PaddlePaddle | `paddle` | yolov8n_paddle_model/ 23 | """ 24 | 25 | import platform 26 | import time 27 | from pathlib import Path 28 | 29 | from ultralytics import YOLO 30 | from ultralytics.yolo.engine.exporter import export_formats 31 | from ultralytics.yolo.utils import LINUX, LOGGER, MACOS, ROOT, SETTINGS 32 | from ultralytics.yolo.utils.checks import check_yolo 33 | from ultralytics.yolo.utils.downloads import download 34 | from ultralytics.yolo.utils.files import file_size 35 | from ultralytics.yolo.utils.torch_utils import select_device 36 | 37 | 38 | def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', imgsz=160, half=False, device='cpu', 
hard_fail=False): 39 | import pandas as pd 40 | pd.options.display.max_columns = 10 41 | pd.options.display.width = 120 42 | device = select_device(device, verbose=False) 43 | if isinstance(model, (str, Path)): 44 | model = YOLO(model) 45 | 46 | y = [] 47 | t0 = time.time() 48 | for i, (name, format, suffix, cpu, gpu) in export_formats().iterrows(): # index, (name, format, suffix, CPU, GPU) 49 | emoji, filename = '❌', None # export defaults 50 | try: 51 | if model.task == 'classify': 52 | assert i != 11, 'paddle cls exports coming soon' 53 | assert i != 9 or LINUX, 'Edge TPU export only supported on Linux' 54 | if i == 10: 55 | assert MACOS or LINUX, 'TF.js export only supported on macOS and Linux' 56 | if 'cpu' in device.type: 57 | assert cpu, 'inference not supported on CPU' 58 | if 'cuda' in device.type: 59 | assert gpu, 'inference not supported on GPU' 60 | 61 | # Export 62 | if format == '-': 63 | filename = model.ckpt_path or model.cfg 64 | export = model # PyTorch format 65 | else: 66 | filename = model.export(imgsz=imgsz, format=format, half=half, device=device) # all others 67 | export = YOLO(filename, task=model.task) 68 | assert suffix in str(filename), 'export failed' 69 | emoji = '❎' # indicates export succeeded 70 | 71 | # Predict 72 | assert i not in (9, 10), 'inference not supported' # Edge TPU and TF.js are unsupported 73 | assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13' # CoreML 74 | if not (ROOT / 'assets/bus.jpg').exists(): 75 | download(url='https://ultralytics.com/images/bus.jpg', dir=ROOT / 'assets') 76 | export.predict(ROOT / 'assets/bus.jpg', imgsz=imgsz, device=device, half=half) 77 | 78 | # Validate 79 | if model.task == 'detect': 80 | data, key = 'coco128.yaml', 'metrics/mAP50-95(B)' 81 | elif model.task == 'segment': 82 | data, key = 'coco128-seg.yaml', 'metrics/mAP50-95(M)' 83 | elif model.task == 'classify': 84 | data, key = 'imagenet100', 'metrics/accuracy_top5' 85 | 86 | results = 
export.val(data=data, batch=1, imgsz=imgsz, plots=False, device=device, half=half, verbose=False) 87 | metric, speed = results.results_dict[key], results.speed['inference'] 88 | y.append([name, '✅', round(file_size(filename), 1), round(metric, 4), round(speed, 2)]) 89 | except Exception as e: 90 | if hard_fail: 91 | assert type(e) is AssertionError, f'Benchmark hard_fail for {name}: {e}' 92 | LOGGER.warning(f'ERROR ❌️ Benchmark failure for {name}: {e}') 93 | y.append([name, emoji, round(file_size(filename), 1), None, None]) # mAP, t_inference 94 | 95 | # Print results 96 | check_yolo(device=device) # print system info 97 | df = pd.DataFrame(y, columns=['Format', 'Status❔', 'Size (MB)', key, 'Inference time (ms/im)']) 98 | 99 | name = Path(model.ckpt_path).name 100 | s = f'\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n' 101 | LOGGER.info(s) 102 | with open('benchmarks.log', 'a', errors='ignore', encoding='utf-8') as f: 103 | f.write(s) 104 | 105 | if hard_fail and isinstance(hard_fail, float): 106 | metrics = df[key].array # values to compare to floor 107 | floor = hard_fail # minimum metric floor to pass, i.e. 
= 0.29 mAP for YOLOv5n 108 | assert all(x > floor for x in metrics if pd.notna(x)), f'HARD FAIL: one or more metric(s) < floor {floor}' 109 | 110 | return df 111 | 112 | 113 | if __name__ == '__main__': 114 | benchmark() 115 | -------------------------------------------------------------------------------- /ultralytics/yolo/utils/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import add_integration_callbacks, default_callbacks 2 | 3 | __all__ = 'add_integration_callbacks', 'default_callbacks' 4 | -------------------------------------------------------------------------------- /ultralytics/yolo/utils/callbacks/base.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | """ 3 | Base callbacks 4 | """ 5 | 6 | 7 | # Trainer callbacks ---------------------------------------------------------------------------------------------------- 8 | def on_pretrain_routine_start(trainer): 9 | pass 10 | 11 | 12 | def on_pretrain_routine_end(trainer): 13 | pass 14 | 15 | 16 | def on_train_start(trainer): 17 | pass 18 | 19 | 20 | def on_train_epoch_start(trainer): 21 | pass 22 | 23 | 24 | def on_train_batch_start(trainer): 25 | pass 26 | 27 | 28 | def optimizer_step(trainer): 29 | pass 30 | 31 | 32 | def on_before_zero_grad(trainer): 33 | pass 34 | 35 | 36 | def on_train_batch_end(trainer): 37 | pass 38 | 39 | 40 | def on_train_epoch_end(trainer): 41 | pass 42 | 43 | 44 | def on_fit_epoch_end(trainer): 45 | pass 46 | 47 | 48 | def on_model_save(trainer): 49 | pass 50 | 51 | 52 | def on_train_end(trainer): 53 | pass 54 | 55 | 56 | def on_params_update(trainer): 57 | pass 58 | 59 | 60 | def teardown(trainer): 61 | pass 62 | 63 | 64 | # Validator callbacks -------------------------------------------------------------------------------------------------- 65 | def on_val_start(validator): 66 | pass 67 | 68 | 69 | def 
on_val_batch_start(validator):
    pass


def on_val_batch_end(validator):
    pass


def on_val_end(validator):
    pass


# Predictor callbacks --------------------------------------------------------------------------------------------------
def on_predict_start(predictor):
    pass


def on_predict_batch_start(predictor):
    pass


def on_predict_batch_end(predictor):
    pass


def on_predict_postprocess_end(predictor):
    pass


def on_predict_end(predictor):
    pass


# Exporter callbacks ---------------------------------------------------------------------------------------------------
def on_export_start(exporter):
    pass


def on_export_end(exporter):
    pass


# Mapping of event name -> list of handler callables. Hooks are fired by name,
# and integrations append their own handlers to each list (lists allow multiple
# handlers per event).
default_callbacks = {
    # Run in trainer
    'on_pretrain_routine_start': [on_pretrain_routine_start],
    'on_pretrain_routine_end': [on_pretrain_routine_end],
    'on_train_start': [on_train_start],
    'on_train_epoch_start': [on_train_epoch_start],
    'on_train_batch_start': [on_train_batch_start],
    'optimizer_step': [optimizer_step],
    'on_before_zero_grad': [on_before_zero_grad],
    'on_train_batch_end': [on_train_batch_end],
    'on_train_epoch_end': [on_train_epoch_end],
    'on_fit_epoch_end': [on_fit_epoch_end],  # fit = train + val
    'on_model_save': [on_model_save],
    'on_train_end': [on_train_end],
    'on_params_update': [on_params_update],
    'teardown': [teardown],

    # Run in validator
    'on_val_start': [on_val_start],
    'on_val_batch_start': [on_val_batch_start],
    'on_val_batch_end': [on_val_batch_end],
    'on_val_end': [on_val_end],

    # Run in predictor
    'on_predict_start': [on_predict_start],
    'on_predict_batch_start': [on_predict_batch_start],
    'on_predict_postprocess_end': [on_predict_postprocess_end],
'on_predict_batch_end': [on_predict_batch_end], 139 | 'on_predict_end': [on_predict_end], 140 | 141 | # Run in exporter 142 | 'on_export_start': [on_export_start], 143 | 'on_export_end': [on_export_end]} 144 | 145 | 146 | def add_integration_callbacks(instance): 147 | from .clearml import callbacks as clearml_callbacks 148 | from .comet import callbacks as comet_callbacks 149 | from .hub import callbacks as hub_callbacks 150 | from .tensorboard import callbacks as tb_callbacks 151 | 152 | for x in clearml_callbacks, comet_callbacks, hub_callbacks, tb_callbacks: 153 | for k, v in x.items(): 154 | if v not in instance.callbacks[k]: # prevent duplicate callbacks addition 155 | instance.callbacks[k].append(v) # callback[name].append(func) 156 | -------------------------------------------------------------------------------- /ultralytics/yolo/utils/callbacks/clearml.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING 3 | from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params 4 | 5 | try: 6 | import clearml 7 | from clearml import Task 8 | 9 | assert clearml.__version__ # verify package is not directory 10 | assert not TESTS_RUNNING # do not log pytest 11 | except (ImportError, AssertionError): 12 | clearml = None 13 | 14 | 15 | def _log_images(imgs_dict, group='', step=0): 16 | task = Task.current_task() 17 | if task: 18 | for k, v in imgs_dict.items(): 19 | task.get_logger().report_image(group, k, step, v) 20 | 21 | 22 | def on_pretrain_routine_start(trainer): 23 | try: 24 | task = Task.init(project_name=trainer.args.project or 'YOLOv8', 25 | task_name=trainer.args.name, 26 | tags=['YOLOv8'], 27 | output_uri=True, 28 | reuse_last_task_id=False, 29 | auto_connect_frameworks={'pytorch': False}) 30 | task.connect(vars(trainer.args), name='General') 31 | except Exception as e: 32 | LOGGER.warning(f'WARNING ⚠️ ClearML 
installed but not initialized correctly, not logging this run. {e}') 33 | 34 | 35 | def on_train_epoch_end(trainer): 36 | if trainer.epoch == 1: 37 | _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic', trainer.epoch) 38 | 39 | 40 | def on_fit_epoch_end(trainer): 41 | task = Task.current_task() 42 | if task and trainer.epoch == 0: 43 | model_info = { 44 | 'model/parameters': get_num_params(trainer.model), 45 | 'model/GFLOPs': round(get_flops(trainer.model), 3), 46 | 'model/speed(ms)': round(trainer.validator.speed['inference'], 3)} 47 | task.connect(model_info, name='Model') 48 | 49 | 50 | def on_train_end(trainer): 51 | task = Task.current_task() 52 | if task: 53 | task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False) 54 | 55 | 56 | callbacks = { 57 | 'on_pretrain_routine_start': on_pretrain_routine_start, 58 | 'on_train_epoch_end': on_train_epoch_end, 59 | 'on_fit_epoch_end': on_fit_epoch_end, 60 | 'on_train_end': on_train_end} if clearml else {} 61 | -------------------------------------------------------------------------------- /ultralytics/yolo/utils/callbacks/comet.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING 3 | from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params 4 | 5 | try: 6 | import comet_ml 7 | 8 | assert not TESTS_RUNNING # do not log pytest 9 | assert comet_ml.__version__ # verify package is not directory 10 | except (ImportError, AssertionError): 11 | comet_ml = None 12 | 13 | 14 | def on_pretrain_routine_start(trainer): 15 | try: 16 | experiment = comet_ml.Experiment(project_name=trainer.args.project or 'YOLOv8') 17 | experiment.log_parameters(vars(trainer.args)) 18 | except Exception as e: 19 | LOGGER.warning(f'WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. 
{e}') 20 | 21 | 22 | def on_train_epoch_end(trainer): 23 | experiment = comet_ml.get_global_experiment() 24 | if experiment: 25 | experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1) 26 | if trainer.epoch == 1: 27 | for f in trainer.save_dir.glob('train_batch*.jpg'): 28 | experiment.log_image(f, name=f.stem, step=trainer.epoch + 1) 29 | 30 | 31 | def on_fit_epoch_end(trainer): 32 | experiment = comet_ml.get_global_experiment() 33 | if experiment: 34 | experiment.log_metrics(trainer.metrics, step=trainer.epoch + 1) 35 | if trainer.epoch == 0: 36 | model_info = { 37 | 'model/parameters': get_num_params(trainer.model), 38 | 'model/GFLOPs': round(get_flops(trainer.model), 3), 39 | 'model/speed(ms)': round(trainer.validator.speed['inference'], 3)} 40 | experiment.log_metrics(model_info, step=trainer.epoch + 1) 41 | 42 | 43 | def on_train_end(trainer): 44 | experiment = comet_ml.get_global_experiment() 45 | if experiment: 46 | experiment.log_model('YOLOv8', file_or_folder=str(trainer.best), file_name='best.pt', overwrite=True) 47 | 48 | 49 | callbacks = { 50 | 'on_pretrain_routine_start': on_pretrain_routine_start, 51 | 'on_train_epoch_end': on_train_epoch_end, 52 | 'on_fit_epoch_end': on_fit_epoch_end, 53 | 'on_train_end': on_train_end} if comet_ml else {} 54 | -------------------------------------------------------------------------------- /ultralytics/yolo/utils/callbacks/hub.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | 3 | import json 4 | from time import time 5 | 6 | from ultralytics.hub.utils import PREFIX, traces 7 | from ultralytics.yolo.utils import LOGGER 8 | from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params 9 | 10 | 11 | def on_pretrain_routine_end(trainer): 12 | session = getattr(trainer, 'hub_session', None) 13 | if session: 14 | # Start timer for upload rate limit 15 | LOGGER.info(f'{PREFIX}View 
model at https://hub.ultralytics.com/models/{session.model_id} 🚀') 16 | session.timers = {'metrics': time(), 'ckpt': time()} # start timer on session.rate_limit 17 | 18 | 19 | def on_fit_epoch_end(trainer): 20 | session = getattr(trainer, 'hub_session', None) 21 | if session: 22 | # Upload metrics after val end 23 | all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics} 24 | if trainer.epoch == 0: 25 | model_info = { 26 | 'model/parameters': get_num_params(trainer.model), 27 | 'model/GFLOPs': round(get_flops(trainer.model), 3), 28 | 'model/speed(ms)': round(trainer.validator.speed['inference'], 3)} 29 | all_plots = {**all_plots, **model_info} 30 | session.metrics_queue[trainer.epoch] = json.dumps(all_plots) 31 | if time() - session.timers['metrics'] > session.rate_limits['metrics']: 32 | session.upload_metrics() 33 | session.timers['metrics'] = time() # reset timer 34 | session.metrics_queue = {} # reset queue 35 | 36 | 37 | def on_model_save(trainer): 38 | session = getattr(trainer, 'hub_session', None) 39 | if session: 40 | # Upload checkpoints with rate limiting 41 | is_best = trainer.best_fitness == trainer.fitness 42 | if time() - session.timers['ckpt'] > session.rate_limits['ckpt']: 43 | LOGGER.info(f'{PREFIX}Uploading checkpoint {session.model_id}') 44 | session.upload_model(trainer.epoch, trainer.last, is_best) 45 | session.timers['ckpt'] = time() # reset timer 46 | 47 | 48 | def on_train_end(trainer): 49 | session = getattr(trainer, 'hub_session', None) 50 | if session: 51 | # Upload final model and metrics with exponential standoff 52 | LOGGER.info(f'{PREFIX}Syncing final model...') 53 | session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True) 54 | session.alive = False # stop heartbeats 55 | LOGGER.info(f'{PREFIX}Done ✅\n' 56 | f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀') 57 | 58 | 59 | def on_train_start(trainer): 60 | 
def _log_scalars(scalars, step=0):
    """Write each tag/value pair in the mapping to TensorBoard, if a writer is active."""
    if not writer:
        return  # TensorBoard unavailable or not initialized for this run
    for tag, value in scalars.items():
        writer.add_scalar(tag, value, step)
def on_batch_end(trainer):
    """Log per-batch training losses to TensorBoard, stamped with the 1-based epoch."""
    train_losses = trainer.label_loss_items(trainer.tloss, prefix='train')
    _log_scalars(train_losses, trainer.epoch + 1)
def generate_ddp_file(trainer):
    """Write a throwaway DDP launcher script for this trainer and return its path."""
    cls = trainer.__class__
    module, name = cls.__module__, cls.__name__

    # Generated script re-creates the trainer from its serialized cfg and starts training
    content = f'''cfg = {vars(trainer.args)} \nif __name__ == "__main__":
    from {module} import {name}

    trainer = {name}(cfg=cfg)
    trainer.train()'''
    ddp_dir = USER_CONFIG_DIR / 'DDP'
    ddp_dir.mkdir(exist_ok=True)
    # delete=False: the launcher must outlive this context manager; ddp_cleanup() removes it later
    with tempfile.NamedTemporaryFile(prefix='_temp_',
                                     suffix=f'{id(trainer)}.py',
                                     mode='w+',
                                     encoding='utf-8',
                                     dir=ddp_dir,
                                     delete=False) as file:
        file.write(content)
    return file.name
def increment_path(path, exist_ok=False, sep='', mkdir=False):
    """
    Increment a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.

    If the path exists and exist_ok is False, a counter (joined with `sep`) is appended until a
    non-existing path is found; a file keeps its extension, a directory takes the counter at the
    end. With mkdir=True the resulting path is created as a directory if it does not exist.

    Args:
        path (str or pathlib.Path): Path to increment.
        exist_ok (bool, optional): If True, return the path as-is even when it exists. Defaults to False.
        sep (str, optional): Separator between the path and the counter. Defaults to ''.
        mkdir (bool, optional): If True, create the resulting path as a directory. Defaults to False.

    Returns:
        pathlib.Path: Incremented path.
    """
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        # Preserve the extension for files; directories append the counter directly
        if path.is_file():
            stem, suffix = str(path.with_suffix('')), path.suffix
        else:
            stem, suffix = str(path), ''

        candidate = None
        for n in range(2, 9999):
            candidate = f'{stem}{sep}{n}{suffix}'
            if not os.path.exists(candidate):
                break
        path = Path(candidate)

    if mkdir:
        path.mkdir(parents=True, exist_ok=True)  # make directory

    return path
def url2file(url):
    """Extract the bare filename from a URL, e.g. https://url.com/file.txt?auth -> file.txt."""
    # Pathlib collapses '://' to ':/'; restore it so the URL parses as expected
    cleaned = str(Path(url)).replace(':/', '://')
    decoded = urllib.parse.unquote(cleaned)  # decode %-escapes such as '%2F' -> '/'
    # Take the final path component, then drop any query string after '?'
    return Path(decoded).name.partition('?')[0]
class BboxLoss(nn.Module):
    """Bounding-box regression loss: CIoU term plus optional Distribution Focal Loss (DFL).

    Attributes:
        reg_max (int): number of DFL distribution bins minus one; pred_dist carries reg_max + 1 logits per side.
        use_dfl (bool): if False, the DFL term is a constant zero tensor.
    """

    def __init__(self, reg_max, use_dfl=False):
        super().__init__()
        self.reg_max = reg_max
        self.use_dfl = use_dfl

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        # Returns (loss_iou, loss_dfl); only foreground anchors (fg_mask True) contribute.
        # IoU loss
        # Per-anchor weight = summed target score; masked_select keeps foreground entries only.
        # NOTE(review): assumes fg_mask is a boolean mask aligned with target_scores' anchor dim — confirm at call site.
        weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
        iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
        loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

        # DFL loss
        if self.use_dfl:
            # Convert target boxes to left/top/right/bottom distances from anchor points, clamped to reg_max.
            target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max)
            loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
            loss_dfl = loss_dfl.sum() / target_scores_sum
        else:
            loss_dfl = torch.tensor(0.0).to(pred_dist.device)

        return loss_iou, loss_dfl

    @staticmethod
    def _df_loss(pred_dist, target):
        # Return sum of left and right DFL losses
        # Distribution Focal Loss (DFL) proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
        # Each continuous target is split between its two neighbouring integer bins,
        # weighted by proximity, so the distribution learns sub-bin precision.
        tl = target.long()  # target left
        tr = tl + 1  # target right
        wl = tr - target  # weight left
        wr = 1 - wl  # weight right
        return (F.cross_entropy(pred_dist, tl.view(-1), reduction='none').view(tl.shape) * wl +
                F.cross_entropy(pred_dist, tr.view(-1), reduction='none').view(tl.shape) * wr).mean(-1, keepdim=True)
-------------------------------------------------------------------------------- /ultralytics/yolo/v8/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | 3 | from ultralytics.yolo.v8 import classify, detect, segment 4 | 5 | __all__ = 'classify', 'segment', 'detect' 6 | -------------------------------------------------------------------------------- /ultralytics/yolo/v8/classify/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | 3 | from ultralytics.yolo.v8.classify.predict import ClassificationPredictor, predict 4 | from ultralytics.yolo.v8.classify.train import ClassificationTrainer, train 5 | from ultralytics.yolo.v8.classify.val import ClassificationValidator, val 6 | 7 | __all__ = 'ClassificationPredictor', 'predict', 'ClassificationTrainer', 'train', 'ClassificationValidator', 'val' 8 | -------------------------------------------------------------------------------- /ultralytics/yolo/v8/classify/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.yolo.engine.predictor import BasePredictor 6 | from ultralytics.yolo.engine.results import Results 7 | from ultralytics.yolo.utils import DEFAULT_CFG, ROOT 8 | from ultralytics.yolo.utils.plotting import Annotator 9 | 10 | 11 | class ClassificationPredictor(BasePredictor): 12 | 13 | def get_annotator(self, img): 14 | return Annotator(img, example=str(self.model.names), pil=True) 15 | 16 | def preprocess(self, img): 17 | img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device) 18 | return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 19 | 20 | def postprocess(self, preds, img, orig_imgs): 21 | results = [] 22 | for i, pred in enumerate(preds): 23 | orig_img = 
    def write_results(self, idx, results, batch):
        """Render/log classification results for one image of the batch.

        Args:
            idx: index of the image within the current batch.
            results: list of Results produced by postprocess().
            batch: (path, preprocessed image tensor, original image) triple.

        Returns:
            str: log string summarizing the top predictions for this image.
        """
        p, im, im0 = batch
        log_string = ''
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        self.seen += 1
        im0 = im0.copy()
        if self.source_type.webcam or self.source_type.from_img:  # batch_size >= 1
            log_string += f'{idx}: '
            frame = self.dataset.count
        else:
            frame = getattr(self.dataset, 'frame', 0)

        self.data_path = p
        # save_path = str(self.save_dir / p.name)  # im.jpg
        # Stream/video sources get a per-frame suffix so label files do not collide
        self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
        log_string += '%gx%g ' % im.shape[2:]  # print string
        self.annotator = self.get_annotator(im0)

        result = results[idx]
        if len(result) == 0:
            return log_string  # nothing predicted for this image
        prob = result.probs
        # Print results
        n5 = min(len(self.model.names), 5)  # handle models with fewer than 5 classes
        top5i = prob.argsort(0, descending=True)[:n5].tolist()  # top 5 indices
        log_string += f"{', '.join(f'{self.model.names[j]} {prob[j]:.2f}' for j in top5i)}, "

        # write
        text = '\n'.join(f'{prob[j]:.2f} {self.model.names[j]}' for j in top5i)
        if self.args.save or self.args.show:  # Add bbox to image
            self.annotator.text((32, 32), text, txt_color=(255, 255, 255))
        if self.args.save_txt:  # Write to file
            with open(f'{self.txt_path}.txt', 'a') as f:
                f.write(text + '\n')

        return log_string
dict(model=model, source=source) 75 | if use_python: 76 | from ultralytics import YOLO 77 | YOLO(model)(**args) 78 | else: 79 | predictor = ClassificationPredictor(overrides=args) 80 | predictor.predict_cli() 81 | 82 | 83 | if __name__ == '__main__': 84 | predict() 85 | -------------------------------------------------------------------------------- /ultralytics/yolo/v8/classify/train.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | 3 | import torch 4 | import torchvision 5 | 6 | from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight 7 | from ultralytics.yolo import v8 8 | from ultralytics.yolo.data import build_classification_dataloader 9 | from ultralytics.yolo.engine.trainer import BaseTrainer 10 | from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, colorstr 11 | from ultralytics.yolo.utils.torch_utils import is_parallel, strip_optimizer 12 | 13 | 14 | class ClassificationTrainer(BaseTrainer): 15 | 16 | def __init__(self, cfg=DEFAULT_CFG, overrides=None): 17 | if overrides is None: 18 | overrides = {} 19 | overrides['task'] = 'classify' 20 | super().__init__(cfg, overrides) 21 | 22 | def set_model_attributes(self): 23 | self.model.names = self.data['names'] 24 | 25 | def get_model(self, cfg=None, weights=None, verbose=True): 26 | model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1) 27 | if weights: 28 | model.load(weights) 29 | 30 | pretrained = False 31 | for m in model.modules(): 32 | if not pretrained and hasattr(m, 'reset_parameters'): 33 | m.reset_parameters() 34 | if isinstance(m, torch.nn.Dropout) and self.args.dropout: 35 | m.p = self.args.dropout # set dropout 36 | for p in model.parameters(): 37 | p.requires_grad = True # for training 38 | 39 | # Update defaults 40 | if self.args.imgsz == 640: 41 | self.args.imgsz = 224 42 | 43 | return model 44 | 45 | def setup_model(self): 46 | """ 47 | 
load/create/download model for any task 48 | """ 49 | # classification models require special handling 50 | 51 | if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed 52 | return 53 | 54 | model = str(self.model) 55 | # Load a YOLO model locally, from torchvision, or from Ultralytics assets 56 | if model.endswith('.pt'): 57 | self.model, _ = attempt_load_one_weight(model, device='cpu') 58 | for p in self.model.parameters(): 59 | p.requires_grad = True # for training 60 | elif model.endswith('.yaml'): 61 | self.model = self.get_model(cfg=model) 62 | elif model in torchvision.models.__dict__: 63 | pretrained = True 64 | self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None) 65 | else: 66 | FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.') 67 | ClassificationModel.reshape_outputs(self.model, self.data['nc']) 68 | 69 | return # dont return ckpt. Classification doesn't support resume 70 | 71 | def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'): 72 | loader = build_classification_dataloader(path=dataset_path, 73 | imgsz=self.args.imgsz, 74 | batch_size=batch_size if mode == 'train' else (batch_size * 2), 75 | augment=mode == 'train', 76 | rank=rank, 77 | workers=self.args.workers) 78 | # Attach inference transforms 79 | if mode != 'train': 80 | if is_parallel(self.model): 81 | self.model.module.transforms = loader.dataset.torch_transforms 82 | else: 83 | self.model.transforms = loader.dataset.torch_transforms 84 | return loader 85 | 86 | def preprocess_batch(self, batch): 87 | batch['img'] = batch['img'].to(self.device) 88 | batch['cls'] = batch['cls'].to(self.device) 89 | return batch 90 | 91 | def progress_string(self): 92 | return ('\n' + '%11s' * (4 + len(self.loss_names))) % \ 93 | ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size') 94 | 95 | def get_validator(self): 96 | self.loss_names = ['loss'] 97 | 
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir) 98 | 99 | def criterion(self, preds, batch): 100 | loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / self.args.nbs 101 | loss_items = loss.detach() 102 | return loss, loss_items 103 | 104 | # def label_loss_items(self, loss_items=None, prefix="train"): 105 | # """ 106 | # Returns a loss dict with labelled training loss items tensor 107 | # """ 108 | # # Not needed for classification but necessary for segmentation & detection 109 | # keys = [f"{prefix}/{x}" for x in self.loss_names] 110 | # if loss_items is not None: 111 | # loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats 112 | # return dict(zip(keys, loss_items)) 113 | # else: 114 | # return keys 115 | 116 | def label_loss_items(self, loss_items=None, prefix='train'): 117 | """ 118 | Returns a loss dict with labelled training loss items tensor 119 | """ 120 | # Not needed for classification but necessary for segmentation & detection 121 | keys = [f'{prefix}/{x}' for x in self.loss_names] 122 | if loss_items is None: 123 | return keys 124 | loss_items = [round(float(loss_items), 5)] 125 | return dict(zip(keys, loss_items)) 126 | 127 | def resume_training(self, ckpt): 128 | pass 129 | 130 | def final_eval(self): 131 | for f in self.last, self.best: 132 | if f.exists(): 133 | strip_optimizer(f) # strip optimizers 134 | # TODO: validate best.pt after training completes 135 | # if f is self.best: 136 | # LOGGER.info(f'\nValidating {f}...') 137 | # self.validator.args.save_json = True 138 | # self.metrics = self.validator(model=f) 139 | # self.metrics.pop('fitness', None) 140 | # self.run_callbacks('on_fit_epoch_end') 141 | LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}") 142 | 143 | 144 | def train(cfg=DEFAULT_CFG, use_python=False): 145 | model = cfg.model or 'yolov8n-cls.pt' # or "resnet18" 146 | data = cfg.data or 'mnist160' # or 
yolo.ClassificationDataset("mnist") 147 | device = cfg.device if cfg.device is not None else '' 148 | 149 | args = dict(model=model, data=data, device=device) 150 | if use_python: 151 | from ultralytics import YOLO 152 | YOLO(model).train(**args) 153 | else: 154 | trainer = ClassificationTrainer(overrides=args) 155 | trainer.train() 156 | 157 | 158 | if __name__ == '__main__': 159 | train() 160 | -------------------------------------------------------------------------------- /ultralytics/yolo/v8/classify/val.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | 3 | from ultralytics.yolo.data import build_classification_dataloader 4 | from ultralytics.yolo.engine.validator import BaseValidator 5 | from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER 6 | from ultralytics.yolo.utils.metrics import ClassifyMetrics 7 | 8 | 9 | class ClassificationValidator(BaseValidator): 10 | 11 | def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None): 12 | super().__init__(dataloader, save_dir, pbar, args) 13 | self.args.task = 'classify' 14 | self.metrics = ClassifyMetrics() 15 | 16 | def get_desc(self): 17 | return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc') 18 | 19 | def init_metrics(self, model): 20 | self.pred = [] 21 | self.targets = [] 22 | 23 | def preprocess(self, batch): 24 | batch['img'] = batch['img'].to(self.device, non_blocking=True) 25 | batch['img'] = batch['img'].half() if self.args.half else batch['img'].float() 26 | batch['cls'] = batch['cls'].to(self.device) 27 | return batch 28 | 29 | def update_metrics(self, preds, batch): 30 | n5 = min(len(self.model.names), 5) 31 | self.pred.append(preds.argsort(1, descending=True)[:, :n5]) 32 | self.targets.append(batch['cls']) 33 | 34 | def finalize_metrics(self, *args, **kwargs): 35 | self.metrics.speed = self.speed 36 | # self.metrics.confusion_matrix = self.confusion_matrix # TODO: classification 
ConfusionMatrix 37 | 38 | def get_stats(self): 39 | self.metrics.process(self.targets, self.pred) 40 | return self.metrics.results_dict 41 | 42 | def get_dataloader(self, dataset_path, batch_size): 43 | return build_classification_dataloader(path=dataset_path, 44 | imgsz=self.args.imgsz, 45 | batch_size=batch_size, 46 | augment=False, 47 | shuffle=False, 48 | workers=self.args.workers) 49 | 50 | def print_results(self): 51 | pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format 52 | LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5)) 53 | 54 | 55 | def val(cfg=DEFAULT_CFG, use_python=False): 56 | model = cfg.model or 'yolov8n-cls.pt' # or "resnet18" 57 | data = cfg.data or 'mnist160' 58 | 59 | args = dict(model=model, data=data) 60 | if use_python: 61 | from ultralytics import YOLO 62 | YOLO(model).val(**args) 63 | else: 64 | validator = ClassificationValidator(args=args) 65 | validator(model=args['model']) 66 | 67 | 68 | if __name__ == '__main__': 69 | val() 70 | -------------------------------------------------------------------------------- /ultralytics/yolo/v8/detect/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | 3 | from .predict import DetectionPredictor, predict 4 | from .train import DetectionTrainer, train 5 | from .val import DetectionValidator, val 6 | 7 | __all__ = 'DetectionPredictor', 'predict', 'DetectionTrainer', 'train', 'DetectionValidator', 'val' 8 | -------------------------------------------------------------------------------- /ultralytics/yolo/v8/detect/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.yolo.engine.predictor import BasePredictor 6 | from ultralytics.yolo.engine.results import Results 7 | from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops 8 | from ultralytics.yolo.utils.plotting 
class DetectionPredictor(BasePredictor):
    """Predictor for YOLOv8 detection models.

    Normalizes input images, applies non-maximum suppression to raw model output,
    and writes/plots the resulting boxes per image.
    """

    def get_annotator(self, img):
        """Return an Annotator for drawing boxes/labels on `img`."""
        # `example` is used by Annotator to size label text — presumably from the class-name set; TODO confirm
        return Annotator(img, line_width=self.args.line_thickness, example=str(self.model.names))

    def preprocess(self, img):
        """Convert an input image (numpy array or tensor) to a normalized model-device tensor."""
        img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
        img = img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
        img /= 255  # 0 - 255 to 0.0 - 1.0
        return img

    def postprocess(self, preds, img, orig_imgs):
        """Apply NMS to raw predictions and wrap each image's detections in a Results object."""
        preds = ops.non_max_suppression(preds,
                                        self.args.conf,
                                        self.args.iou,
                                        agnostic=self.args.agnostic_nms,
                                        max_det=self.args.max_det,
                                        classes=self.args.classes)

        results = []
        for i, pred in enumerate(preds):
            orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
            # Boxes are in letterboxed-input coordinates; rescale to the original image
            # unless orig_imgs is already a tensor (i.e. the preprocessed image itself).
            if not isinstance(orig_imgs, torch.Tensor):
                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            path, _, _, _, _ = self.batch
            img_path = path[i] if isinstance(path, list) else path
            results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
        return results

    def write_results(self, idx, results, batch):
        """Write txt labels, draw boxes and save crops for image `idx`; return a log string."""
        p, im, im0 = batch  # path, preprocessed image, original image
        log_string = ''
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        self.seen += 1
        imc = im0.copy() if self.args.save_crop else im0  # keep a clean copy for crops
        if self.source_type.webcam or self.source_type.from_img:  # batch_size >= 1
            log_string += f'{idx}: '
            frame = self.dataset.count
        else:
            frame = getattr(self.dataset, 'frame', 0)
        self.data_path = p
        # Frame suffix distinguishes per-frame label files for video/stream sources.
        self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
        log_string += '%gx%g ' % im.shape[2:]  # print string
        self.annotator = self.get_annotator(im0)

        det = results[idx].boxes  # TODO: make boxes inherit from tensors
        if len(det) == 0:
            return f'{log_string}(no detections), '
        # Summarize counts per class for the log line.
        for c in det.cls.unique():
            n = (det.cls == c).sum()  # detections per class
            log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "

        # write
        for d in reversed(det):
            c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item())
            if self.args.save_txt:  # Write to file
                # Optional confidence and track id columns, controlled by save_conf / presence of id.
                line = (c, *d.xywhn.view(-1)) + (conf, ) * self.args.save_conf + (() if id is None else (id, ))
                with open(f'{self.txt_path}.txt', 'a') as f:
                    f.write(('%g ' * len(line)).rstrip() % line + '\n')
            if self.args.save or self.args.show:  # Add bbox to image
                name = ('' if id is None else f'id:{id} ') + self.model.names[c]
                label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
                self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
            if self.args.save_crop:
                save_one_box(d.xyxy,
                             imc,
                             file=self.save_dir / 'crops' / self.model.names[c] / f'{self.data_path.stem}.jpg',
                             BGR=True)

        return log_string
class SegmentationPredictor(DetectionPredictor):
    """Predictor for YOLOv8 segmentation models.

    Extends DetectionPredictor with prototype-mask decoding: boxes come from NMS
    on preds[0], masks are reconstructed from the prototype tensor in preds[1].
    """

    def postprocess(self, preds, img, orig_imgs):
        """Run NMS, decode instance masks from prototypes, and build Results per image."""
        # TODO: filter by classes
        p = ops.non_max_suppression(preds[0],
                                    self.args.conf,
                                    self.args.iou,
                                    agnostic=self.args.agnostic_nms,
                                    max_det=self.args.max_det,
                                    nc=len(self.model.names),
                                    classes=self.args.classes)
        results = []
        proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
        for i, pred in enumerate(p):
            orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
            path, _, _, _, _ = self.batch
            img_path = path[i] if isinstance(path, list) else path
            if not len(pred):  # save empty boxes
                results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6]))
                continue
            if self.args.retina_masks:
                # Scale boxes to original-image coordinates BEFORE native-resolution mask decoding.
                if not isinstance(orig_imgs, torch.Tensor):
                    pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
                masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
            else:
                # Decode masks at inference resolution first, then rescale the boxes.
                masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
                if not isinstance(orig_imgs, torch.Tensor):
                    pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            results.append(
                Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
        return results

    def write_results(self, idx, results, batch):
        """Write txt labels (with polygon segments), draw masks/boxes and save crops; return a log string."""
        p, im, im0 = batch  # path, preprocessed image, original image
        log_string = ''
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        self.seen += 1
        imc = im0.copy() if self.args.save_crop else im0  # keep a clean copy for crops
        if self.source_type.webcam or self.source_type.from_img:  # batch_size >= 1
            log_string += f'{idx}: '
            frame = self.dataset.count
        else:
            frame = getattr(self.dataset, 'frame', 0)

        self.data_path = p
        # Frame suffix distinguishes per-frame label files for video/stream sources.
        self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
        log_string += '%gx%g ' % im.shape[2:]  # print string
        self.annotator = self.get_annotator(im0)

        result = results[idx]
        if len(result) == 0:
            return f'{log_string}(no detections), '
        det, mask = result.boxes, result.masks  # getting tensors TODO: mask mask,box inherit for tensor

        # Print results
        for c in det.cls.unique():
            n = (det.cls == c).sum()  # detections per class
            log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "

        # Mask plotting
        if self.args.save or self.args.show:
            # retina_masks were decoded at original resolution, so overlay on im0 (converted to
            # CHW float on the mask device); otherwise overlay on the preprocessed image im[idx].
            im_gpu = torch.as_tensor(im0, dtype=torch.float16, device=mask.masks.device).permute(
                2, 0, 1).flip(0).contiguous() / 255 if self.args.retina_masks else im[idx]
            self.annotator.masks(masks=mask.masks, colors=[colors(x, True) for x in det.cls], im_gpu=im_gpu)

        # Write results
        for j, d in enumerate(reversed(det)):
            c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item())
            if self.args.save_txt:  # Write to file
                seg = mask.segments[len(det) - j - 1].copy().reshape(-1)  # reversed mask.segments, (n,2) to (n*2)
                line = (c, *seg) + (conf, ) * self.args.save_conf + (() if id is None else (id, ))
                with open(f'{self.txt_path}.txt', 'a') as f:
                    f.write(('%g ' * len(line)).rstrip() % line + '\n')
            if self.args.save or self.args.show:  # Add bbox to image
                name = ('' if id is None else f'id:{id} ') + self.model.names[c]
                label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
                if self.args.boxes:
                    self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
            if self.args.save_crop:
                save_one_box(d.xyxy,
                             imc,
                             file=self.save_dir / 'crops' / self.model.names[c] / f'{self.data_path.stem}.jpg',
                             BGR=True)

        return log_string
# Criterion class for computing training losses
class SegLoss(Loss):
    """Segmentation training loss: detection Loss (box/cls/dfl) plus a per-instance
    prototype-mask BCE term (loss[1])."""

    def __init__(self, model, overlap=True):  # model must be de-paralleled
        super().__init__(model)
        self.nm = model.model[-1].nm  # number of masks
        self.overlap = overlap  # True when GT masks of one image share a single channel with instance ids

    def __call__(self, preds, batch):
        """Compute (total_loss * batch_size, detached per-term losses) for one batch."""
        loss = torch.zeros(4, device=self.device)  # box, cls, dfl
        # Training mode yields (feats, pred_masks, proto) directly; eval/export wraps it in preds[1].
        feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
        batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1)

        # b, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_masks = pred_masks.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # targets
        try:
            batch_idx = batch['batch_idx'].view(-1, 1)
            targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
            gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
        except RuntimeError as e:
            # A shape mismatch here usually means a detect-format dataset was fed to a segment model.
            raise TypeError('ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n'
                            "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
                            "i.e. 'yolo train model=yolov8n-seg.pt data=coco128.yaml'.\nVerify your dataset is a "
                            "correctly formatted 'segment' dataset using 'data=coco128-seg.yaml' "
                            'as an example.\nSee https://docs.ultralytics.com/tasks/segment/ for help.') from e

        # pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
            pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)

        target_scores_sum = max(target_scores.sum(), 1)  # avoid division by zero

        # cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        if fg_mask.sum():
            # bbox loss
            loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor,
                                              target_scores, target_scores_sum, fg_mask)
            # masks loss
            masks = batch['masks'].to(self.device).float()
            if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
                masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]

            for i in range(batch_size):
                if fg_mask[i].sum():
                    mask_idx = target_gt_idx[i][fg_mask[i]]
                    if self.overlap:
                        # Overlap mode: masks[i] holds instance ids; id k+1 selects instance k.
                        gt_mask = torch.where(masks[[i]] == (mask_idx + 1).view(-1, 1, 1), 1.0, 0.0)
                    else:
                        gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
                    # Normalize assigned boxes, then express them in proto-mask pixel coordinates.
                    xyxyn = target_bboxes[i][fg_mask[i]] / imgsz[[1, 0, 1, 0]]
                    marea = xyxy2xywh(xyxyn)[:, 2:].prod(1)
                    mxyxy = xyxyn * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)
                    loss[1] += self.single_mask_loss(gt_mask, pred_masks[i][fg_mask[i]], proto[i], mxyxy, marea)  # seg

                # WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
                else:
                    loss[1] += proto.sum() * 0 + pred_masks.sum() * 0

        # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
        else:
            loss[1] += proto.sum() * 0 + pred_masks.sum() * 0

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.box / batch_size  # seg gain
        loss[2] *= self.hyp.cls  # cls gain
        loss[3] *= self.hyp.dfl  # dfl gain

        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)

    def single_mask_loss(self, gt_mask, pred, proto, xyxy, area):
        # Mask loss for one image
        # Project mask coefficients through the prototypes, BCE against GT, then crop to the
        # box and normalize by box area so small instances are not under-weighted.
        pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:])  # (n, 32) @ (32,80,80) -> (n,80,80)
        loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
        return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
def check_train_batch_size(model, imgsz=640, amp=True):
    """
    Check YOLO training batch size using the autobatch() function.

    Args:
        model (torch.nn.Module): YOLO model to check batch size for.
        imgsz (int): Image size used for training.
        amp (bool): If True, use automatic mixed precision (AMP) for training.

    Returns:
        int: Optimal batch size computed using the autobatch() function.
    """

    # Profile a deep copy so the caller's model state (mode, grads) is untouched.
    with torch.cuda.amp.autocast(amp):
        return autobatch(deepcopy(model).train(), imgsz)  # compute optimal batch size


def autobatch(model, imgsz=640, fraction=0.67, batch_size=16):
    """
    Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.

    Args:
        model: YOLO model to compute batch size for.
        imgsz (int, optional): The image size used as input for the YOLO model. Defaults to 640.
        fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.67.
        batch_size (int, optional): The default batch size to use if an error is detected. Defaults to 16.

    Returns:
        int: The optimal batch size.
    """

    # Check device
    prefix = colorstr('AutoBatch: ')
    LOGGER.info(f'{prefix}Computing optimal batch size for imgsz={imgsz}')
    device = next(model.parameters()).device  # get model device
    if device.type == 'cpu':
        LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
        return batch_size
    # cudnn.benchmark caches per-shape kernels, which skews per-batch-size memory profiling.
    if torch.backends.cudnn.benchmark:
        LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}')
        return batch_size

    # Inspect CUDA memory
    gb = 1 << 30  # bytes to GiB (1024 ** 3)
    d = str(device).upper()  # 'CUDA:0'
    properties = torch.cuda.get_device_properties(device)  # device properties
    t = properties.total_memory / gb  # GiB total
    r = torch.cuda.memory_reserved(device) / gb  # GiB reserved
    a = torch.cuda.memory_allocated(device) / gb  # GiB allocated
    f = t - (r + a)  # GiB free
    LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free')

    # Profile batch sizes
    batch_sizes = [1, 2, 4, 8, 16]
    try:
        img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
        results = profile(img, model, n=3, device=device)

        # Fit a solution: model memory ≈ p[0] * batch + p[1], then solve for the batch
        # size that fills `fraction` of the free memory.
        y = [x[2] for x in results if x]  # memory [2]
        p = np.polyfit(batch_sizes[:len(y)], y, deg=1)  # first degree polynomial fit
        b = int((f * fraction - p[1]) / p[0])  # y intercept (optimal batch size)
        if None in results:  # some sizes failed
            i = results.index(None)  # first fail index
            if b >= batch_sizes[i]:  # y intercept above failure point
                b = batch_sizes[max(i - 1, 0)]  # select prior safe point
        if b < 1 or b > 1024:  # b outside of safe range
            b = batch_size
            LOGGER.info(f'{prefix}WARNING ⚠️ CUDA anomaly detected, using default batch-size {batch_size}.')

        fraction = (np.polyval(p, b) + r + a) / t  # actual fraction predicted
        LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
        return b
    except Exception as e:
        # Best-effort: any profiling failure falls back to the caller-supplied default.
        LOGGER.warning(f'{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.')
        return batch_size
def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', imgsz=160, half=False, device='cpu', hard_fail=False):
    """Benchmark one YOLO model across every export format: export, predict, validate,
    and tabulate size/metric/speed per format.

    Args:
        model: Path/str of a checkpoint, or an already-built YOLO object.
        imgsz: Inference image size.
        half: Use FP16 where supported.
        device: Device string passed to select_device.
        hard_fail: If truthy, re-raise non-assertion failures; if a float, also assert
            every metric exceeds that floor.

    Returns:
        pandas.DataFrame with one row per export format.
    """
    import pandas as pd
    pd.options.display.max_columns = 10
    pd.options.display.width = 120
    device = select_device(device, verbose=False)
    if isinstance(model, (str, Path)):
        model = YOLO(model)

    y = []  # one result row per format
    t0 = time.time()
    for i, (name, format, suffix, cpu, gpu) in export_formats().iterrows():  # index, (name, format, suffix, CPU, GPU)
        emoji, filename = '❌', None  # export defaults
        try:
            # Assertions act as per-format skip conditions: an AssertionError drops to the
            # except block, which records the format as unsupported.
            if model.task == 'classify':
                assert i != 11, 'paddle cls exports coming soon'
            assert i != 9 or LINUX, 'Edge TPU export only supported on Linux'
            if i == 10:
                assert MACOS or LINUX, 'TF.js export only supported on macOS and Linux'
            if 'cpu' in device.type:
                assert cpu, 'inference not supported on CPU'
            if 'cuda' in device.type:
                assert gpu, 'inference not supported on GPU'

            # Export
            if format == '-':
                filename = model.ckpt_path or model.cfg
                export = model  # PyTorch format
            else:
                filename = model.export(imgsz=imgsz, format=format, half=half, device=device)  # all others
                export = YOLO(filename, task=model.task)
            assert suffix in str(filename), 'export failed'
            emoji = '❎'  # indicates export succeeded

            # Predict
            assert i not in (9, 10), 'inference not supported'  # Edge TPU and TF.js are unsupported
            assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13'  # CoreML
            if not (ROOT / 'assets/bus.jpg').exists():
                download(url='https://ultralytics.com/images/bus.jpg', dir=ROOT / 'assets')
            export.predict(ROOT / 'assets/bus.jpg', imgsz=imgsz, device=device, half=half)

            # Validate
            # NOTE(review): `data`/`key` are only bound for these three tasks — if model.task is
            # ever something else, the DataFrame construction below would hit a NameError; confirm
            # the task set is closed.
            if model.task == 'detect':
                data, key = 'coco128.yaml', 'metrics/mAP50-95(B)'
            elif model.task == 'segment':
                data, key = 'coco128-seg.yaml', 'metrics/mAP50-95(M)'
            elif model.task == 'classify':
                data, key = 'imagenet100', 'metrics/accuracy_top5'

            results = export.val(data=data, batch=1, imgsz=imgsz, plots=False, device=device, half=half, verbose=False)
            metric, speed = results.results_dict[key], results.speed['inference']
            y.append([name, '✅', round(file_size(filename), 1), round(metric, 4), round(speed, 2)])
        except Exception as e:
            if hard_fail:
                assert type(e) is AssertionError, f'Benchmark hard_fail for {name}: {e}'
            LOGGER.warning(f'ERROR ❌️ Benchmark failure for {name}: {e}')
            y.append([name, emoji, round(file_size(filename), 1), None, None])  # mAP, t_inference

    # Print results
    check_yolo(device=device)  # print system info
    df = pd.DataFrame(y, columns=['Format', 'Status❔', 'Size (MB)', key, 'Inference time (ms/im)'])

    name = Path(model.ckpt_path).name
    s = f'\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n'
    LOGGER.info(s)
    with open('benchmarks.log', 'a', errors='ignore', encoding='utf-8') as f:
        f.write(s)

    if hard_fail and isinstance(hard_fail, float):
        metrics = df[key].array  # values to compare to floor
        floor = hard_fail  # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
        assert all(x > floor for x in metrics if pd.notna(x)), f'HARD FAIL: one or more metric(s) < floor {floor}'

    return df
def on_val_batch_end(validator):
    """Default no-op placeholder for the 'on_val_batch_end' event."""
    pass


def on_val_end(validator):
    """Default no-op placeholder for the 'on_val_end' event."""
    pass


# Predictor callbacks --------------------------------------------------------------------------------------------------
def on_predict_start(predictor):
    """Default no-op placeholder for the 'on_predict_start' event."""
    pass


def on_predict_batch_start(predictor):
    """Default no-op placeholder for the 'on_predict_batch_start' event."""
    pass


def on_predict_batch_end(predictor):
    """Default no-op placeholder for the 'on_predict_batch_end' event."""
    pass


def on_predict_postprocess_end(predictor):
    """Default no-op placeholder for the 'on_predict_postprocess_end' event."""
    pass


def on_predict_end(predictor):
    """Default no-op placeholder for the 'on_predict_end' event."""
    pass


# Exporter callbacks ---------------------------------------------------------------------------------------------------
def on_export_start(exporter):
    """Default no-op placeholder for the 'on_export_start' event."""
    pass


def on_export_end(exporter):
    """Default no-op placeholder for the 'on_export_end' event."""
    pass
def add_integration_callbacks(instance):
    """Register every available integration callback (ClearML, Comet, HUB, TensorBoard)
    on `instance.callbacks`, skipping functions that are already registered."""
    from .clearml import callbacks as clearml_callbacks
    from .comet import callbacks as comet_callbacks
    from .hub import callbacks as hub_callbacks
    from .tensorboard import callbacks as tb_callbacks

    integrations = (clearml_callbacks, comet_callbacks, hub_callbacks, tb_callbacks)
    for callback_map in integrations:
        for event, func in callback_map.items():
            registered = instance.callbacks[event]
            if func not in registered:  # prevent duplicate callbacks addition
                registered.append(func)
def on_fit_epoch_end(trainer):
    """At the end of the first fit epoch, attach basic model statistics to the
    active ClearML task (parameter count, GFLOPs, inference speed)."""
    task = Task.current_task()
    if not task or trainer.epoch != 0:
        return
    info = {
        'model/parameters': get_num_params(trainer.model),
        'model/GFLOPs': round(get_flops(trainer.model), 3),
        'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
    task.connect(info, name='Model')
def on_fit_epoch_end(trainer):
    """Log epoch metrics (and, on the first epoch, model info) to Comet."""
    experiment = comet_ml.get_global_experiment()
    if not experiment:
        return
    step = trainer.epoch + 1
    experiment.log_metrics(trainer.metrics, step=step)
    if trainer.epoch == 0:  # static model stats only need logging once
        experiment.log_metrics(
            {
                'model/parameters': get_num_params(trainer.model),
                'model/GFLOPs': round(get_flops(trainer.model), 3),
                'model/speed(ms)': round(trainer.validator.speed['inference'], 3)},
            step=step)
def on_model_save(trainer):
    """Upload the latest checkpoint to Ultralytics HUB, rate-limited per session."""
    session = getattr(trainer, 'hub_session', None)
    if not session:
        return  # training is not attached to a HUB session
    is_best = trainer.best_fitness == trainer.fitness
    elapsed = time() - session.timers['ckpt']
    if elapsed > session.rate_limits['ckpt']:
        LOGGER.info(f'{PREFIX}Uploading checkpoint {session.model_id}')
        session.upload_model(trainer.epoch, trainer.last, is_best)
        session.timers['ckpt'] = time()  # reset the rate-limit timer
def _log_scalars(scalars, step=0):
    """Write each (tag, value) pair in *scalars* to TensorBoard at *step*.

    No-op when the module-level ``writer`` has not been initialized.
    """
    if not writer:
        return
    for tag, value in scalars.items():
        writer.add_scalar(tag, value, step)
def find_free_network_port() -> int:
    """Return a TCP port number that is currently free on localhost.

    Useful in single-node training when we don't want to connect to a real
    main node but still have to set the `MASTER_PORT` environment variable.
    Binding to port 0 makes the OS pick an ephemeral free port for us.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with sock:
        sock.bind(('127.0.0.1', 0))
        _, port = sock.getsockname()
        return port
class WorkingDirectory(contextlib.ContextDecorator):
    """Temporarily change the process working directory.

    Usage: ``@WorkingDirectory(dir)`` decorator or
    ``with WorkingDirectory(dir):`` context manager.
    """

    def __init__(self, new_dir):
        self.dir = new_dir  # directory to switch into
        self.cwd = Path.cwd().resolve()  # directory to restore on exit

    def __enter__(self):
        os.chdir(self.dir)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the original directory even if the managed body raised.
        os.chdir(self.cwd)
def file_age(path=__file__):
    """Return the number of whole days since *path* was last modified."""
    mtime = Path(path).stat().st_mtime  # seconds since the epoch
    delta = datetime.now() - datetime.fromtimestamp(mtime)
    return delta.days  # fractional part (delta.seconds / 86400) is discarded
# 读取csv文件获得账户信息 (read userInfo.csv and return account info)
def get_id_info():
    """Read userInfo.csv from the project root and return {username: password}.

    The file is written by ``sava_id_info`` with UTF-8 encoding, so it must be
    read back with the same encoding — the original relied on the platform
    default codec, which breaks non-ASCII usernames on e.g. Windows/GBK.
    Blank or truncated rows are skipped instead of raising IndexError.

    Returns:
        dict: mapping of username to password for every well-formed row.
    """
    USER_PWD = {}
    with open('userInfo.csv', 'r', encoding='utf-8') as csvfile:  # project root
        spamreader = csv.reader(csvfile)
        # Walk the file row by row, storing username -> password.
        for row in spamreader:
            if len(row) >= 2:  # guard against blank/malformed lines
                USER_PWD[row[0]] = row[1]
    return USER_PWD
class VarifocalLoss(nn.Module):
    """Varifocal loss by Zhang et al. (https://arxiv.org/abs/2008.13367)."""

    def __init__(self):
        super().__init__()

    def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0):
        # Negatives are down-weighted by the focal term; positives by gt_score.
        focal = alpha * pred_score.sigmoid().pow(gamma) * (1 - label)
        weight = focal + gt_score * label
        # Compute BCE in full precision regardless of any active autocast.
        with torch.cuda.amp.autocast(enabled=False):
            bce = F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none')
            return (bce * weight).sum()
Focal Loss https://ieeexplore.ieee.org/document/9792391 51 | tl = target.long() # target left 52 | tr = tl + 1 # target right 53 | wl = tr - target # weight left 54 | wr = 1 - wl # weight right 55 | return (F.cross_entropy(pred_dist, tl.view(-1), reduction='none').view(tl.shape) * wl + 56 | F.cross_entropy(pred_dist, tr.view(-1), reduction='none').view(tl.shape) * wr).mean(-1, keepdim=True) 57 | -------------------------------------------------------------------------------- /v8/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | 3 | from ultralytics.yolo.v8 import classify, detect, segment 4 | 5 | __all__ = 'classify', 'segment', 'detect' 6 | -------------------------------------------------------------------------------- /v8/classify/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | 3 | from ultralytics.yolo.v8.classify.predict import ClassificationPredictor, predict 4 | from ultralytics.yolo.v8.classify.train import ClassificationTrainer, train 5 | from ultralytics.yolo.v8.classify.val import ClassificationValidator, val 6 | 7 | __all__ = 'ClassificationPredictor', 'predict', 'ClassificationTrainer', 'train', 'ClassificationValidator', 'val' 8 | -------------------------------------------------------------------------------- /v8/classify/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.yolo.engine.predictor import BasePredictor 6 | from ultralytics.yolo.engine.results import Results 7 | from ultralytics.yolo.utils import DEFAULT_CFG, ROOT 8 | from ultralytics.yolo.utils.plotting import Annotator 9 | 10 | 11 | class ClassificationPredictor(BasePredictor): 12 | 13 | def get_annotator(self, img): 14 | return Annotator(img, example=str(self.model.names), 
pil=True) 15 | 16 | def preprocess(self, img): 17 | img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device) 18 | return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 19 | 20 | def postprocess(self, preds, img, orig_imgs): 21 | results = [] 22 | for i, pred in enumerate(preds): 23 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 24 | path, _, _, _, _ = self.batch 25 | img_path = path[i] if isinstance(path, list) else path 26 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, probs=pred)) 27 | 28 | return results 29 | 30 | def write_results(self, idx, results, batch): 31 | p, im, im0 = batch 32 | log_string = '' 33 | if len(im.shape) == 3: 34 | im = im[None] # expand for batch dim 35 | self.seen += 1 36 | im0 = im0.copy() 37 | if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1 38 | log_string += f'{idx}: ' 39 | frame = self.dataset.count 40 | else: 41 | frame = getattr(self.dataset, 'frame', 0) 42 | 43 | self.data_path = p 44 | # save_path = str(self.save_dir / p.name) # im.jpg 45 | self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}') 46 | log_string += '%gx%g ' % im.shape[2:] # print string 47 | self.annotator = self.get_annotator(im0) 48 | 49 | result = results[idx] 50 | if len(result) == 0: 51 | return log_string 52 | prob = result.probs 53 | # Print results 54 | n5 = min(len(self.model.names), 5) 55 | top5i = prob.argsort(0, descending=True)[:n5].tolist() # top 5 indices 56 | log_string += f"{', '.join(f'{self.model.names[j]} {prob[j]:.2f}' for j in top5i)}, " 57 | 58 | # write 59 | text = '\n'.join(f'{prob[j]:.2f} {self.model.names[j]}' for j in top5i) 60 | if self.args.save or self.args.show: # Add bbox to image 61 | self.annotator.text((32, 32), text, txt_color=(255, 255, 255)) 62 | if self.args.save_txt: # Write to file 63 | with open(f'{self.txt_path}.txt', 'a') as 
def predict(cfg=DEFAULT_CFG, use_python=False):
    """Run classification inference with a YOLOv8-cls model.

    Falls back to the bundled assets directory (if present) or a demo image
    URL when no source is configured.
    """
    model = cfg.model or 'yolov8n-cls.pt'  # or "resnet18"
    if cfg.source is not None:
        source = cfg.source
    elif (ROOT / 'assets').exists():
        source = ROOT / 'assets'
    else:
        source = 'https://ultralytics.com/images/bus.jpg'

    args = dict(model=model, source=source)
    if use_python:
        from ultralytics import YOLO
        YOLO(model)(**args)
    else:
        ClassificationPredictor(overrides=args).predict_cli()
torch.nn.Dropout) and self.args.dropout: 35 | m.p = self.args.dropout # set dropout 36 | for p in model.parameters(): 37 | p.requires_grad = True # for training 38 | 39 | # Update defaults 40 | if self.args.imgsz == 640: 41 | self.args.imgsz = 224 42 | 43 | return model 44 | 45 | def setup_model(self): 46 | """ 47 | load/create/download model for any task 48 | """ 49 | # classification models require special handling 50 | 51 | if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed 52 | return 53 | 54 | model = str(self.model) 55 | # Load a YOLO model locally, from torchvision, or from Ultralytics assets 56 | if model.endswith('.pt'): 57 | self.model, _ = attempt_load_one_weight(model, device='cpu') 58 | for p in self.model.parameters(): 59 | p.requires_grad = True # for training 60 | elif model.endswith('.yaml'): 61 | self.model = self.get_model(cfg=model) 62 | elif model in torchvision.models.__dict__: 63 | pretrained = True 64 | self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None) 65 | else: 66 | FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.') 67 | ClassificationModel.reshape_outputs(self.model, self.data['nc']) 68 | 69 | return # dont return ckpt. 
Classification doesn't support resume 70 | 71 | def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'): 72 | loader = build_classification_dataloader(path=dataset_path, 73 | imgsz=self.args.imgsz, 74 | batch_size=batch_size if mode == 'train' else (batch_size * 2), 75 | augment=mode == 'train', 76 | rank=rank, 77 | workers=self.args.workers) 78 | # Attach inference transforms 79 | if mode != 'train': 80 | if is_parallel(self.model): 81 | self.model.module.transforms = loader.dataset.torch_transforms 82 | else: 83 | self.model.transforms = loader.dataset.torch_transforms 84 | return loader 85 | 86 | def preprocess_batch(self, batch): 87 | batch['img'] = batch['img'].to(self.device) 88 | batch['cls'] = batch['cls'].to(self.device) 89 | return batch 90 | 91 | def progress_string(self): 92 | return ('\n' + '%11s' * (4 + len(self.loss_names))) % \ 93 | ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size') 94 | 95 | def get_validator(self): 96 | self.loss_names = ['loss'] 97 | return v8.classify.ClassificationValidator(self.test_loader, self.save_dir) 98 | 99 | def criterion(self, preds, batch): 100 | loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / self.args.nbs 101 | loss_items = loss.detach() 102 | return loss, loss_items 103 | 104 | # def label_loss_items(self, loss_items=None, prefix="train"): 105 | # """ 106 | # Returns a loss dict with labelled training loss items tensor 107 | # """ 108 | # # Not needed for classification but necessary for segmentation & detection 109 | # keys = [f"{prefix}/{x}" for x in self.loss_names] 110 | # if loss_items is not None: 111 | # loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats 112 | # return dict(zip(keys, loss_items)) 113 | # else: 114 | # return keys 115 | 116 | def label_loss_items(self, loss_items=None, prefix='train'): 117 | """ 118 | Returns a loss dict with labelled training loss items tensor 119 | """ 120 | # Not 
def train(cfg=DEFAULT_CFG, use_python=False):
    """Train a YOLOv8 classification model from CLI-style config defaults."""
    args = dict(
        model=cfg.model or 'yolov8n-cls.pt',  # or "resnet18"
        data=cfg.data or 'mnist160',  # or yolo.ClassificationDataset("mnist")
        device=cfg.device if cfg.device is not None else '')
    if use_python:
        from ultralytics import YOLO
        YOLO(args['model']).train(**args)
    else:
        ClassificationTrainer(overrides=args).train()
def val(cfg=DEFAULT_CFG, use_python=False):
    """Validate a YOLOv8 classification model on a classification dataset."""
    args = dict(
        model=cfg.model or 'yolov8n-cls.pt',  # or "resnet18"
        data=cfg.data or 'mnist160')
    if use_python:
        from ultralytics import YOLO
        YOLO(args['model']).val(**args)
    else:
        validator = ClassificationValidator(args=args)
        validator(model=args['model'])
-------------------------------------------------------------------------------- /v8/detect/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | 3 | from .predict import DetectionPredictor, predict 4 | from .train import DetectionTrainer, train 5 | from .val import DetectionValidator, val 6 | 7 | __all__ = 'DetectionPredictor', 'predict', 'DetectionTrainer', 'train', 'DetectionValidator', 'val' 8 | -------------------------------------------------------------------------------- /v8/detect/predict.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, GPL-3.0 license 2 | 3 | import torch 4 | 5 | from ultralytics.yolo.engine.predictor import BasePredictor 6 | from ultralytics.yolo.engine.results import Results 7 | from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops 8 | from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box 9 | 10 | 11 | class DetectionPredictor(BasePredictor): 12 | 13 | def get_annotator(self, img): 14 | return Annotator(img, line_width=self.args.line_thickness, example=str(self.model.names)) 15 | 16 | def preprocess(self, img): 17 | img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device) 18 | img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 19 | img /= 255 # 0 - 255 to 0.0 - 1.0 20 | return img 21 | 22 | def postprocess(self, preds, img, orig_imgs): 23 | preds = ops.non_max_suppression(preds, 24 | self.args.conf, 25 | self.args.iou, 26 | agnostic=self.args.agnostic_nms, 27 | max_det=self.args.max_det, 28 | classes=self.args.classes) 29 | 30 | results = [] 31 | for i, pred in enumerate(preds): 32 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 33 | if not isinstance(orig_imgs, torch.Tensor): 34 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 35 | path, _, _, _, _ = 
self.batch 36 | img_path = path[i] if isinstance(path, list) else path 37 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred)) 38 | return results 39 | 40 | def write_results(self, idx, results, batch): 41 | p, im, im0 = batch 42 | log_string = '' 43 | if len(im.shape) == 3: 44 | im = im[None] # expand for batch dim 45 | self.seen += 1 46 | imc = im0.copy() if self.args.save_crop else im0 47 | if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1 48 | log_string += f'{idx}: ' 49 | frame = self.dataset.count 50 | else: 51 | frame = getattr(self.dataset, 'frame', 0) 52 | self.data_path = p 53 | self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}') 54 | log_string += '%gx%g ' % im.shape[2:] # print string 55 | self.annotator = self.get_annotator(im0) 56 | 57 | det = results[idx].boxes # TODO: make boxes inherit from tensors 58 | if len(det) == 0: 59 | return f'{log_string}(no detections), ' 60 | for c in det.cls.unique(): 61 | n = (det.cls == c).sum() # detections per class 62 | log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, " 63 | 64 | # write 65 | for d in reversed(det): 66 | c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item()) 67 | if self.args.save_txt: # Write to file 68 | line = (c, *d.xywhn.view(-1)) + (conf, ) * self.args.save_conf + (() if id is None else (id, )) 69 | with open(f'{self.txt_path}.txt', 'a') as f: 70 | f.write(('%g ' * len(line)).rstrip() % line + '\n') 71 | if self.args.save or self.args.show: # Add bbox to image 72 | name = ('' if id is None else f'id:{id} ') + self.model.names[c] 73 | label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}') 74 | self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) 75 | if self.args.save_crop: 76 | save_one_box(d.xyxy, 77 | imc, 78 | file=self.save_dir / 'crops' / self.model.names[c] 
def predict(cfg=DEFAULT_CFG, use_python=False):
    """Run YOLOv8 detection inference.

    cfg: config namespace; falls back to 'yolov8n.pt' and the bundled assets
    (or a remote sample image) when model/source are unset.
    use_python: route through the Python YOLO API instead of the CLI predictor.
    """
    model = cfg.model or 'yolov8n.pt'
    if cfg.source is not None:
        source = cfg.source
    else:
        source = ROOT / 'assets' if (ROOT / 'assets').exists() else 'https://ultralytics.com/images/bus.jpg'

    args = dict(model=model, source=source)
    if use_python:
        from ultralytics import YOLO
        YOLO(model)(**args)
    else:
        DetectionPredictor(overrides=args).predict_cli()
classes=self.args.classes) 22 | results = [] 23 | proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported 24 | for i, pred in enumerate(p): 25 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 26 | path, _, _, _, _ = self.batch 27 | img_path = path[i] if isinstance(path, list) else path 28 | if not len(pred): # save empty boxes 29 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])) 30 | continue 31 | if self.args.retina_masks: 32 | if not isinstance(orig_imgs, torch.Tensor): 33 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 34 | masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC 35 | else: 36 | masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC 37 | if not isinstance(orig_imgs, torch.Tensor): 38 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 39 | results.append( 40 | Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)) 41 | return results 42 | 43 | def write_results(self, idx, results, batch): 44 | p, im, im0 = batch 45 | log_string = '' 46 | if len(im.shape) == 3: 47 | im = im[None] # expand for batch dim 48 | self.seen += 1 49 | imc = im0.copy() if self.args.save_crop else im0 50 | if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1 51 | log_string += f'{idx}: ' 52 | frame = self.dataset.count 53 | else: 54 | frame = getattr(self.dataset, 'frame', 0) 55 | 56 | self.data_path = p 57 | self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}') 58 | log_string += '%gx%g ' % im.shape[2:] # print string 59 | self.annotator = self.get_annotator(im0) 60 | 61 | result = results[idx] 62 | if len(result) == 0: 63 | return f'{log_string}(no detections), ' 64 | det, mask = 
result.boxes, result.masks # getting tensors TODO: mask mask,box inherit for tensor 65 | 66 | # Print results 67 | for c in det.cls.unique(): 68 | n = (det.cls == c).sum() # detections per class 69 | log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, " 70 | 71 | # Mask plotting 72 | if self.args.save or self.args.show: 73 | im_gpu = torch.as_tensor(im0, dtype=torch.float16, device=mask.masks.device).permute( 74 | 2, 0, 1).flip(0).contiguous() / 255 if self.args.retina_masks else im[idx] 75 | self.annotator.masks(masks=mask.masks, colors=[colors(x, True) for x in det.cls], im_gpu=im_gpu) 76 | 77 | # Write results 78 | for j, d in enumerate(reversed(det)): 79 | c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item()) 80 | if self.args.save_txt: # Write to file 81 | seg = mask.segments[len(det) - j - 1].copy().reshape(-1) # reversed mask.segments, (n,2) to (n*2) 82 | line = (c, *seg) + (conf, ) * self.args.save_conf + (() if id is None else (id, )) 83 | with open(f'{self.txt_path}.txt', 'a') as f: 84 | f.write(('%g ' * len(line)).rstrip() % line + '\n') 85 | if self.args.save or self.args.show: # Add bbox to image 86 | name = ('' if id is None else f'id:{id} ') + self.model.names[c] 87 | label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}') 88 | if self.args.boxes: 89 | self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) 90 | if self.args.save_crop: 91 | save_one_box(d.xyxy, 92 | imc, 93 | file=self.save_dir / 'crops' / self.model.names[c] / f'{self.data_path.stem}.jpg', 94 | BGR=True) 95 | 96 | return log_string 97 | 98 | 99 | def predict(cfg=DEFAULT_CFG, use_python=False): 100 | model = cfg.model or 'yolov8n-seg.pt' 101 | source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \ 102 | else 'https://ultralytics.com/images/bus.jpg' 103 | 104 | args = dict(model=model, source=source) 105 | if use_python: 106 | from 
# BaseTrainer python usage
class SegmentationTrainer(v8.detect.DetectionTrainer):
    """Trainer for YOLOv8 segmentation models; forces task='segment' on top of DetectionTrainer."""

    def __init__(self, cfg=DEFAULT_CFG, overrides=None):
        overrides = {} if overrides is None else overrides
        overrides['task'] = 'segment'  # pin the task regardless of caller overrides
        super().__init__(cfg, overrides)

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Build a SegmentationModel, optionally loading pretrained weights."""
        model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def get_validator(self):
        """Return a SegmentationValidator; also defines the loss names used for logging."""
        self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
        return v8.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

    def criterion(self, preds, batch):
        """Compute the segmentation training loss (SegLoss is built lazily on first call)."""
        if not hasattr(self, 'compute_loss'):
            self.compute_loss = SegLoss(de_parallel(self.model), overlap=self.args.overlap_mask)
        return self.compute_loss(preds, batch)

    def plot_training_samples(self, batch, ni):
        """Save a mosaic of training images with boxes and masks for batch `ni`."""
        plot_images(batch['img'],
                    batch['batch_idx'],
                    batch['cls'].squeeze(-1),
                    batch['bboxes'],
                    batch['masks'],
                    paths=batch['im_file'],
                    fname=self.save_dir / f'train_batch{ni}.jpg')

    def plot_metrics(self):
        """Plot training metrics from the CSV log."""
        plot_results(file=self.csv, segment=True)  # save results.png
# Criterion class for computing training losses
class SegLoss(Loss):
    """Detection loss extended with a per-instance mask (segmentation) term."""

    def __init__(self, model, overlap=True):  # model must be de-paralleled
        super().__init__(model)
        self.nm = model.model[-1].nm  # number of masks
        self.overlap = overlap

    def __call__(self, preds, batch):
        """Return (total_loss * batch_size, detached per-term losses).

        Loss vector layout: [box, seg, cls, dfl].
        """
        loss = torch.zeros(4, device=self.device)  # box, seg, cls, dfl
        feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
        batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
        pred_distri, pred_scores = torch.cat([f.view(feats[0].shape[0], self.no, -1) for f in feats], 2).split(
            (self.reg_max * 4, self.nc), 1)

        # b, grids, ..
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_masks = pred_masks.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Targets: a RuntimeError here usually means a detect dataset was fed to a segment model
        try:
            batch_idx = batch['batch_idx'].view(-1, 1)
            targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
            gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
        except RuntimeError as err:
            raise TypeError('ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n'
                            "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
                            "i.e. 'yolo train model=yolov8n-seg.pt data=coco128.yaml'.\nVerify your dataset is a "
                            "correctly formatted 'segment' dataset using 'data=coco128-seg.yaml' "
                            'as an example.\nSee https://docs.ultralytics.com/tasks/segment/ for help.') from err

        # pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
            pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)

        target_scores_sum = max(target_scores.sum(), 1)

        # cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        if fg_mask.sum():
            # bbox loss
            loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points,
                                              target_bboxes / stride_tensor, target_scores, target_scores_sum,
                                              fg_mask)
            # masks loss
            masks = batch['masks'].to(self.device).float()
            if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample to proto resolution
                masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]

            for bi in range(batch_size):
                if fg_mask[bi].sum():
                    mask_idx = target_gt_idx[bi][fg_mask[bi]]
                    if self.overlap:
                        # masks are encoded as one map with instance index+1 per pixel
                        gt_mask = torch.where(masks[[bi]] == (mask_idx + 1).view(-1, 1, 1), 1.0, 0.0)
                    else:
                        gt_mask = masks[batch_idx.view(-1) == bi][mask_idx]
                    xyxyn = target_bboxes[bi][fg_mask[bi]] / imgsz[[1, 0, 1, 0]]
                    marea = xyxy2xywh(xyxyn)[:, 2:].prod(1)
                    mxyxy = xyxyn * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)
                    loss[1] += self.single_mask_loss(gt_mask, pred_masks[bi][fg_mask[bi]], proto[bi], mxyxy,
                                                     marea)  # seg

                # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
                else:
                    loss[1] += proto.sum() * 0 + pred_masks.sum() * 0

        # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
        else:
            loss[1] += proto.sum() * 0 + pred_masks.sum() * 0

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.box / batch_size  # seg gain
        loss[2] *= self.hyp.cls  # cls gain
        loss[3] *= self.hyp.dfl  # dfl gain

        return loss.sum() * batch_size, loss.detach()  # loss(box, seg, cls, dfl)

    def single_mask_loss(self, gt_mask, pred, proto, xyxy, area):
        """Mask BCE loss for one image, cropped to each box and normalized by box area."""
        pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:])  # (n, 32) @ (32,80,80) -> (n,80,80)
        loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
        return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
def train(cfg=DEFAULT_CFG, use_python=False):
    """Train a YOLOv8 segmentation model.

    cfg: config namespace; falls back to 'yolov8n-seg.pt' / 'coco128-seg.yaml'
    when model/data are unset. use_python routes through the Python YOLO API
    instead of the trainer class.
    """
    model = cfg.model or 'yolov8n-seg.pt'
    data = cfg.data or 'coco128-seg.yaml'  # or yolo.ClassificationDataset("mnist")
    device = cfg.device if cfg.device is not None else ''

    args = dict(model=model, data=data, device=device)
    if use_python:
        from ultralytics import YOLO
        YOLO(model).train(**args)
    else:
        SegmentationTrainer(overrides=args).train()


if __name__ == '__main__':
    train()