├── icon.png ├── requirements.txt ├── setup.py ├── YOLOv8_Trainer.spec ├── icon.py ├── .github └── workflows │ └── python-app.yml ├── README.md ├── main.py ├── parameters.py ├── training.py ├── environment.py └── ui_components.py /icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DR-lin-eng/yologui/HEAD/icon.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ultralytics>=8.0.0 2 | torch>=1.7.0 3 | torchvision>=0.8.1 4 | pyside6>=6.0.0 5 | pyyaml>=5.3.1 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="yolov8-trainer", 5 | version="1.0.0", 6 | packages=find_packages(), 7 | install_requires=[ 8 | "ultralytics>=8.0.0", 9 | "torch>=1.7.0", 10 | "torchvision>=0.8.1", 11 | "pyside6>=6.0.0", 12 | "pyyaml>=5.3.1", 13 | ], 14 | python_requires=">=3.8", 15 | entry_points={ 16 | "console_scripts": [ 17 | "yolov8-trainer=main:main", 18 | ], 19 | }, 20 | include_package_data=True, 21 | package_data={ 22 | "": ["icon.png"], 23 | }, 24 | ) 25 | -------------------------------------------------------------------------------- /YOLOv8_Trainer.spec: -------------------------------------------------------------------------------- 1 | # -*- mode: python ; coding: utf-8 -*- 2 | 3 | block_cipher = None 4 | 5 | a = Analysis( 6 | ['main.py'], 7 | pathex=[], 8 | binaries=[], 9 | datas=[('icon.png', '.')], 10 | hiddenimports=[ 11 | 'yaml', 12 | 'PySide6.QtCore', 13 | 'PySide6.QtGui', 14 | 'PySide6.QtWidgets', 15 | 'ultralytics' 16 | ], 17 | hookspath=[], 18 | hooksconfig={}, 19 | runtime_hooks=[], 20 | excludes=[], 21 | win_no_prefer_redirects=False, 22 | win_private_assemblies=False, 23 | cipher=block_cipher, 24 | noarchive=False, 25 | ) 26 | 27 | pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher) 28 | 29 | exe = EXE( 30 | pyz, 31 | a.scripts, 32 | a.binaries, 33 | a.zipfiles, 34 | a.datas, 35 | [], 36 | name='YOLOv8_Trainer', 37 | debug=False, 38 | bootloader_ignore_signals=False, 39 | strip=False, 40 | upx=True, 41 | upx_exclude=[], 42 | runtime_tmpdir=None, 43 | console=False, 44 | disable_windowed_traceback=False, 45 | argv_emulation=False, 46 | target_arch=None, 47 | codesign_identity=None, 48 | entitlements_file=None, 49 | icon='icon.png', 50 | ) 51 | -------------------------------------------------------------------------------- /icon.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | 生成应用程序图标 6 | """ 7 | 8 | from PySide6.QtGui import QIcon, QPixmap, QPainter, QColor, QBrush, QPen, QFont, QLinearGradient 9 | from PySide6.QtCore import Qt, QSize, QRect 10 | import os 11 | 12 | def create_app_icon(): 13 | """创建应用程序图标并保存为文件""" 14 | # 创建图标 15 | icon_size = 256 16 | pixmap = QPixmap(icon_size, icon_size) 17 | pixmap.fill(Qt.transparent) 18 | 19 | # 开始绘制 20 | painter = QPainter(pixmap) 21 | painter.setRenderHint(QPainter.Antialiasing) 22 | painter.setRenderHint(QPainter.TextAntialiasing) 23 | 24 | # 创建渐变背景 25 | gradient = QLinearGradient(0, 0, icon_size, icon_size) 26 | gradient.setColorAt(0, QColor(52, 152, 219)) # 蓝色 27 | gradient.setColorAt(1, QColor(41, 128, 185)) # 深蓝色 28 | 29 | # 绘制圆形背景 30 | painter.setBrush(QBrush(gradient)) 31 | painter.setPen(Qt.NoPen) 32 | painter.drawEllipse(10, 10, icon_size - 20, icon_size - 20) 33 | 34 | # 绘制文字 35 | font = QFont("Arial", int(icon_size / 3)) 36 | font.setBold(True) 37 | painter.setFont(font) 38 | painter.setPen(QPen(QColor(255, 255, 255))) 39 | 40 | text_rect = QRect(10, 10, icon_size - 20, icon_size - 20) 41 | painter.drawText(text_rect, Qt.AlignCenter, "Y8") 42 | 43 | # 结束绘制 44 | painter.end() 45 | 46 | # 保存图标 47 | pixmap.save("icon.png", "PNG") 48 | 49 | return QIcon(pixmap) 50 | 51 | def generate_app_icon(): 52 | """生成应用程序图标文件(如果不存在)""" 53 | if not os.path.exists("icon.png"): 54 | create_app_icon() 55 | 56 | if __name__ == "__main__": 57 | create_app_icon() 58 | print("图标已保存为icon.png") 59 | -------------------------------------------------------------------------------- /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | name: Build Windows Executable 2 | 3 | on: 4 | push: 5 | branches: [ main, master ] 6 | pull_request: 7 | branches: [ main, master ] 8 | workflow_dispatch: # Allows manual triggering 9 | 10 | jobs: 11 | build: 12 | runs-on: windows-latest 13 | 14 | steps: 15 | - name: Checkout repository 16 | uses: actions/checkout@v2 17 | 18 | - name: Set up Python 3.10 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: '3.10' 22 | 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install pyinstaller 27 | pip install -r requirements.txt 28 | 29 | - name: Build with PyInstaller 30 | run: | 31 | pyinstaller --name="YOLOv8_Trainer" --onefile --windowed --icon=icon.png --add-data="icon.png;." main.py 32 | 33 | # Instead of using actions/upload-artifact, just create a release 34 | - name: Create Release 35 | id: create_release 36 | uses: actions/create-release@v1 37 | env: 38 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 39 | with: 40 | tag_name: v${{ github.run_number }} 41 | release_name: Build ${{ github.run_number }} 42 | draft: false 43 | prerelease: false 44 | 45 | # Upload the executable as a release asset 46 | - name: Upload Release Asset 47 | uses: actions/upload-release-asset@v1 48 | env: 49 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 50 | with: 51 | upload_url: ${{ steps.create_release.outputs.upload_url }} 52 | asset_path: ./dist/YOLOv8_Trainer.exe 53 | asset_name: YOLOv8_Trainer.exe 54 | asset_content_type: application/octet-stream 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # YOLOv8 训练工具 2 | 3 | 一个基于PySide6的YOLOv8训练图形界面工具,让深度学习目标检测、分类、分割和姿态估计任务变得简单易用。 4 | 代码还在完善中。。。。。 5 | 6 | ## 📝 项目简介 7 | 8 | YOLOv8 训练工具是一个图形界面应用程序,旨在简化YOLOv8模型的训练过程。无需编写代码或命令行操作,通过简洁直观的界面即可完成数据集准备、模型选择、训练参数配置以及训练过程监控。 9 | 10 | ### 支持的任务类型 11 | - 目标检测 (Detection) 12 | - 图像分类 (Classification) 13 | - 图像分割 (Segmentation) 14 | - 姿态估计 (Pose Estimation) 15 | 16 | ## ✨ 主要特点 17 | 18 | - 📊 **直观的图形界面**:避免命令行操作,所有功能都可通过图形界面完成 19 | - 🔍 **参数说明面板**:悬停或点击即可查看详细参数解释 20 | - 📁 **灵活的数据集支持**: 21 | - 支持YAML配置文件 22 | - 自动识别和处理多种分类数据集结构 23 | - 支持已预分割的训练/验证/测试集 24 | - 支持自动分割数据集 25 | - 🎮 **实时训练监控**:直观展示训练进度、损失曲线和性能指标 26 | - 🔧 **环境检测**:自动检测CUDA、GPU和PyTorch环境 27 | - 💾 **参数保存与加载**:保存常用参数配置,方便下次使用 28 | 29 | ## 🔧 系统要求 30 | 31 | - Python 3.7+ 32 | - PySide6 33 | - PyTorch 1.7+ 34 | - Ultralytics 8.0.0+ 35 | - CUDA (可选,但推荐用于GPU训练) 36 | - 最小系统要求: 37 | - 操作系统: Windows 10/11, macOS 10.14+, 或 Linux 38 | - RAM: 8GB+ 39 | - 存储空间: 5GB+ (取决于数据集大小) 40 | 41 | ## 📦 安装方法 42 | 43 | ```bash 44 | # 克隆仓库 45 | git clone https://github.com/DR-lin-eng/yologui.git 46 | cd yologui 47 | 48 | # 安装依赖 49 | pip install -r requirements.txt 50 | 51 | # 启动应用程序 52 | python main.py 53 | ``` 54 | 55 | ## 📖 使用指南 56 | 57 | ### 数据集准备 58 | 59 | #### 目标检测/分割/姿态估计数据集结构 60 | ``` 61 | dataset/ 62 | ├── images/ 63 | │ ├── train/ 64 | │ │ ├── image1.jpg 65 | │ │ └── ... 66 | │ └── val/ 67 | │ ├── image1.jpg 68 | │ └── ... 69 | ├── labels/ 70 | │ ├── train/ 71 | │ │ ├── image1.txt 72 | │ │ └── ... 73 | │ └── val/ 74 | │ ├── image1.txt 75 | │ └── ... 76 | └── data.yaml 77 | ``` 78 | 79 | #### 分类数据集结构选项 80 | 81 | 1. **直接文件夹结构** 82 | ``` 83 | dataset/ 84 | ├── class1/ 85 | │ ├── img1.jpg 86 | │ └── ... 87 | └── class2/ 88 | ├── img2.jpg 89 | └── ... 90 | ``` 91 | 92 | 2. **预分割数据集结构** 93 | ``` 94 | dataset/ 95 | ├── train/ 96 | │ ├── class1/ 97 | │ │ ├── img1.jpg 98 | │ │ └── ... 99 | │ └── class2/ 100 | │ ├── img2.jpg 101 | │ └── ... 102 | ├── val/ 103 | │ ├── class1/ 104 | │ └── class2/ 105 | └── test/ (可选) 106 | ├── class1/ 107 | └── class2/ 108 | ``` 109 | 110 | 3. **单层类别文件夹**(自动分割为训练/验证集) 111 | ``` 112 | dataset/ 113 | ├── class1/ 114 | │ ├── img1.jpg 115 | │ └── ... 116 | └── class2/ 117 | ├── img2.jpg 118 | └── ... 119 | ``` 120 | 121 | ### 界面说明 122 | 123 | #### 1. 训练设置标签页 124 | - **任务类型选择**:检测、分类、分割或姿态估计 125 | - **数据集选择**:YAML配置文件或文件夹路径 126 | - **模型选择**:预训练模型选择(从nano到xlarge不同大小) 127 | - **训练参数配置**:批次大小、学习率、训练轮数等参数设置 128 | - **参数说明面板**:实时显示所选参数的详细解释 129 | 130 | #### 2. 训练进度标签页 131 | - **进度条**:显示当前训练进度 132 | - **性能指标**:实时显示mAP、Precision、Recall等指标 133 | - **训练日志**:显示详细的训练输出信息 134 | 135 | #### 3. 环境信息标签页 136 | - **系统信息**:操作系统和Python版本 137 | - **YOLOv8信息**:YOLOv8安装状态和版本 138 | - **CUDA信息**:CUDA可用性和版本 139 | - **GPU信息**:检测到的GPU型号和显存 140 | 141 | ## 🔄 工作流程 142 | 143 | 1. 选择任务类型(检测/分类/分割/姿态估计) 144 | 2. 选择数据集(YAML文件或直接选择文件夹) 145 | 3. 选择预训练模型或自定义模型 146 | 4. 配置训练参数 147 | 5. 点击"开始训练"按钮 148 | 6. 在训练进度标签页监控训练过程 149 | 7. 训练完成后,查看生成的模型和评估结果 150 | 151 | ## 🔍 参数说明 152 | 153 | 工具中所有的训练参数都有详细说明,只需将鼠标悬停在参数上,或点击参数后的问号按钮即可查看详细说明。常用参数包括: 154 | 155 | - **batch**: 训练批次大小,根据显存调整 156 | - **imgsz**: 输入图像大小,单位为像素 157 | - **epochs**: 训练总轮数 158 | - **lr0**: 初始学习率 159 | - **patience**: 无改进时早停的轮数 160 | - **device**: 训练设备,空为自动选择 161 | 162 | ## ⚠️ 常见问题 163 | 164 | ### 训练速度很慢怎么办? 165 | - 确保已启用CUDA并选择了正确的GPU 166 | - 减小批次大小(batch) 167 | - 减小图像尺寸(imgsz) 168 | - 使用更轻量级的模型(如yolov8n代替yolov8x) 169 | 170 | ### 如何防止过拟合? 171 | - 增加数据增强参数(在"超参数"和"增强参数"组中) 172 | - 减小训练轮数 173 | - 使用早停(patience参数) 174 | - 增加正则化参数(如weight_decay) 175 | 176 | ### 分类数据集结构无法识别? 177 | - 检查是否符合上述三种结构之一 178 | - 使用"查看/编辑配置"功能手动调整YAML配置 179 | - 对于复杂结构,建议先整理成标准格式再使用 180 | 181 | ### 训练后模型文件在哪里? 182 | 模型文件默认保存在`runs/任务类型/实验名称`目录下。可以在保存参数中修改输出目录。 183 | 184 | ## 🤝 贡献 185 | 186 | 欢迎贡献代码、报告问题或提出功能建议!请通过GitHub Issues或Pull Requests参与项目开发。 187 | 188 | ## 📄 许可证 189 | 190 | 本项目采用 MIT 许可证。 191 | 192 | --- 193 | 194 | 祝您使用愉快!如遇到任何问题,请通过GitHub Issues反馈。 195 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import sys 5 | import os 6 | from PySide6.QtWidgets import QApplication, QMainWindow, QMessageBox, QSplashScreen 7 | from PySide6.QtCore import QTimer, Qt 8 | from PySide6.QtGui import QPixmap, QIcon 9 | 10 | from ui_components import MainWindow 11 | from environment import EnvironmentChecker 12 | from parameters import load_default_parameters 13 | from training import TrainingManager 14 | from icon import generate_app_icon 15 | 16 | 17 | class YOLOv8TrainerApp: 18 | def __init__(self): 19 | self.app = QApplication(sys.argv) 20 | self.app.setApplicationName("YOLOv8 训练工具") 21 | self.app.setStyle("Fusion") # 使用Fusion样式,看起来更现代 22 | 23 | # 生成应用图标 24 | generate_app_icon() 25 | if os.path.exists("icon.png"): 26 | self.app.setWindowIcon(QIcon("icon.png")) 27 | 28 | # 显示启动画面 29 | splash_pixmap = QPixmap(400, 300) 30 | splash_pixmap.fill(Qt.white) 31 | self.splash = QSplashScreen(splash_pixmap) 32 | self.splash.showMessage("正在加载YOLOv8训练工具...", Qt.AlignCenter, Qt.black) 33 | self.splash.show() 34 | self.app.processEvents() 35 | 36 | # 加载默认参数 37 | self.splash.showMessage("加载训练参数...", Qt.AlignCenter, Qt.black) 38 | self.app.processEvents() 39 | self.parameters = load_default_parameters() 40 | 41 | # 创建主窗口 42 | self.splash.showMessage("初始化界面...", Qt.AlignCenter, Qt.black) 43 | self.app.processEvents() 44 | self.main_window = MainWindow(self.parameters) 45 | 46 | # 创建训练管理器 47 | self.training_manager = TrainingManager() 48 | 49 | # 连接信号和槽 50 | self.connect_signals() 51 | 52 | # 第一次运行检查 53 | self.splash.showMessage("检查环境...", Qt.AlignCenter, Qt.black) 54 | self.app.processEvents() 55 | self.first_run_check() 56 | 57 | def connect_signals(self): 58 | """连接UI组件的信号和槽""" 59 | # 开始训练按钮 60 | self.main_window.start_training_button.clicked.connect(self.start_training) 61 | 62 | # 停止训练按钮 63 | self.main_window.stop_training_button.clicked.connect(self.stop_training) 64 | 65 | # 训练管理器的信号 66 | self.training_manager.progress_update.connect(self.main_window.update_progress) 67 | self.training_manager.training_finished.connect(self.training_finished) 68 | self.training_manager.training_error.connect(self.training_error) 69 | 70 | def first_run_check(self): 71 | """首次运行时检查环境""" 72 | env_checker = EnvironmentChecker() 73 | status = env_checker.check_all() 74 | 75 | if not status['yolov8_installed']: 76 | msg = QMessageBox() 77 | msg.setWindowTitle("环境检查") 78 | msg.setText("未检测到YOLOv8。是否安装?") 79 | msg.setStandardButtons(QMessageBox.Yes | QMessageBox.No) 80 | if msg.exec_() == QMessageBox.Yes: # 注意这里使用 exec_() 81 | env_checker.install_yolov8() 82 | 83 | # 更新CUDA状态 84 | self.main_window.update_cuda_status(status['cuda_available']) 85 | 86 | # 显示环境信息 87 | self.main_window.update_environment_info(status) 88 | 89 | def start_training(self): 90 | """开始训练过程""" 91 | # 从UI获取参数 92 | training_params = self.main_window.get_training_parameters() 93 | 94 | # 如果用户取消,则退出 95 | if training_params is None: 96 | return 97 | 98 | # 验证参数 99 | if not self.validate_parameters(training_params): 100 | return 101 | 102 | # 启动训练 103 | self.training_manager.start_training(training_params) 104 | 105 | # 更新UI状态 106 | self.main_window.set_training_mode(True) 107 | 108 | def stop_training(self): 109 | """停止训练过程""" 110 | self.training_manager.stop_training() 111 | 112 | def training_finished(self, success): 113 | """训练完成回调""" 114 | self.main_window.set_training_mode(False) 115 | if success: 116 | QMessageBox.information(self.main_window, "训练完成", "YOLOv8 训练已成功完成!") 117 | 118 | def training_error(self, error_message): 119 | """训练错误回调""" 120 | self.main_window.set_training_mode(False) 121 | QMessageBox.critical(self.main_window, "训练错误", f"训练过程中发生错误:\n{error_message}") 122 | 123 | def validate_parameters(self, params): 124 | """验证训练参数是否有效""" 125 | # 检查数据集 126 | if not params['data_path'] or not os.path.exists(params['data_path']): 127 | QMessageBox.warning(self.main_window, "参数错误", "请选择有效的数据集配置文件") 128 | return False 129 | 130 | # 检查模型 131 | if not params['model']: 132 | QMessageBox.warning(self.main_window, "参数错误", "请选择模型类型") 133 | return False 134 | 135 | return True 136 | 137 | def run(self): 138 | """运行应用程序""" 139 | # 显示主窗口 140 | self.main_window.show() 141 | 142 | # 关闭启动画面 143 | if hasattr(self, 'splash'): 144 | self.splash.finish(self.main_window) 145 | 146 | return self.app.exec() 147 | 148 | 149 | if __name__ == "__main__": 150 | app = YOLOv8TrainerApp() 151 | sys.exit(app.run()) 152 | -------------------------------------------------------------------------------- /parameters.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import yaml 6 | from collections import OrderedDict 7 | 8 | 9 | def load_default_parameters(): 10 | """加载默认训练参数""" 11 | return { 12 | # 数据参数 13 | 'data': { 14 | 'data_path': '', # 数据集路径 15 | 'batch': 16, # 批次大小 16 | 'imgsz': 640, # 图像大小 17 | 'cache': False, # 是否缓存图像到RAM 18 | 'single_cls': False, # 单类模式 19 | 'rect': False, # 矩形训练 20 | 'fraction': 1.0, # 数据集使用比例 21 | }, 22 | 23 | # 模型参数 24 | 'model': { 25 | 'model': 'yolov8n.pt', # 模型文件 26 | 'task': 'detect', # 任务类型 (detect, segment, classify, pose) 27 | 'pretrained': True, # 是否使用预训练权重 28 | 'resume': False, # 是否恢复训练 29 | }, 30 | 31 | # 训练参数 32 | 'training': { 33 | 'epochs': 100, # 训练轮数 34 | 'patience': 50, # 早停轮数 35 | 'optimizer': 'SGD', # 优化器 36 | 'lr0': 0.01, # 初始学习率 37 | 'lrf': 0.01, # 最终学习率系数 38 | 'momentum': 0.937, # SGD动量 39 | 'weight_decay': 0.0005, # 权重衰减 40 | 'warmup_epochs': 3.0, # 预热轮数 41 | 'warmup_momentum': 0.8, # 预热动量 42 | 'warmup_bias_lr': 0.1, # 预热偏置学习率 43 | 'device': '', # 训练设备 44 | 'cos_lr': False, # 余弦学习率 45 | 'close_mosaic': 10, # 最后N轮关闭马赛克增强 46 | 'amp': True, # 混合精度训练 47 | }, 48 | 49 | # 超参数 50 | 'hyp': { 51 | 'hsv_h': 0.015, # HSV-色调增强 52 | 'hsv_s': 0.7, # HSV-饱和度增强 53 | 'hsv_v': 0.4, # HSV-亮度增强 54 | 'degrees': 0.0, # 旋转角度 (±) 55 | 'translate': 0.1, # 平移 (±) 56 | 'scale': 0.5, # 缩放 (±) 57 | 'fliplr': 0.5, # 水平翻转概率 58 | 'flipud': 0.0, # 垂直翻转概率 59 | 'mosaic': 1.0, # 马赛克增强概率 60 | 'mixup': 0.0, # mixup增强概率 61 | 'copy_paste': 0.0, # 复制粘贴概率 62 | }, 63 | 64 | # 增强参数 65 | 'augment': { 66 | 'albumentations': '', # Albumentations设置 67 | 'blur': 0.0, # 模糊增强概率 68 | 'perspective': 0.0, # 透视变换概率 69 | 'shear': 0.0, # 剪切变换概率 70 | }, 71 | 72 | # 保存参数 73 | 'save': { 74 | 'project': 'runs/train', # 保存目录 75 | 'name': 'exp', # 实验名称 76 | 'exist_ok': False, # 覆盖现有实验 77 | 'save_period': -1, # 权重保存间隔(-1为仅保存最终权重) 78 | 'save_dir': '', # 实际保存目录(自动生成) 79 | }, 80 | 81 | # 可视化参数 82 | 'visual': { 83 | 'plots': True, # 是否绘制训练图表 84 | 'noval': False, # 只训练,不验证 85 | 'v5loader': False, # 使用YOLOv5的数据加载器 86 | }, 87 | 88 | # 高级参数 89 | 'advanced': { 90 | 'nbs': 64, # 标准批量大小 91 | 'overlap_mask': True, # 掩码重叠(分割) 92 | 'mask_ratio': 4, # 掩码下采样率(分割) 93 | 'dropout': 0.0, # 使用Dropout正则化 94 | 'val': True, # 是否在训练中进行验证 95 | 'seed': 0, # 全局随机种子 96 | 'workers': 8, # 数据加载线程数 97 | 'deterministic': True, # 确定性训练 98 | } 99 | } 100 | 101 | 102 | # 参数解释字典 103 | parameter_descriptions = { 104 | # 数据参数 105 | 'data_path': '数据集配置文件路径,YAML格式', 106 | 'batch': '训练批次大小,根据显存调整', 107 | 'imgsz': '输入图像大小,单位为像素', 108 | 'cache': '是否将图像缓存到RAM中以加速训练', 109 | 'single_cls': '将多类数据集视为单类数据集', 110 | 'rect': '使用矩形训练而不是方形训练', 111 | 'fraction': '数据集使用比例,1.0表示使用全部数据', 112 | 113 | # 模型参数 114 | 'model': '模型文件路径或预训练模型名称', 115 | 'task': '任务类型:检测(detect)、分割(segment)、分类(classify)或姿态估计(pose)', 116 | 'pretrained': '是否使用预训练权重', 117 | 'resume': '从上次中断处恢复训练', 118 | 119 | # 训练参数 120 | 'epochs': '训练总轮数', 121 | 'patience': '无改进时早停的轮数', 122 | 'optimizer': '优化器选择(SGD, Adam, AdamW等)', 123 | 'lr0': '初始学习率', 124 | 'lrf': '最终学习率=初始学习率×最终学习率系数', 125 | 'momentum': 'SGD动量因子', 126 | 'weight_decay': '权重衰减系数,用于L2正则化', 127 | 'warmup_epochs': '学习率预热的轮数', 128 | 'warmup_momentum': '预热阶段的初始动量', 129 | 'warmup_bias_lr': '预热阶段的偏置学习率', 130 | 'device': '训练设备,空为自动选择', 131 | 'cos_lr': '使用余弦学习率调度', 132 | 'close_mosaic': '最后N轮关闭马赛克增强以提高稳定性', 133 | 'amp': '使用自动混合精度训练以加速', 134 | 135 | # 超参数 136 | 'hsv_h': 'HSV色调增强因子', 137 | 'hsv_s': 'HSV饱和度增强因子', 138 | 'hsv_v': 'HSV亮度增强因子', 139 | 'degrees': '随机旋转角度范围(±度)', 140 | 'translate': '随机平移范围(±图像比例)', 141 | 'scale': '随机缩放范围(±图像比例)', 142 | 'fliplr': '水平翻转的概率', 143 | 'flipud': '垂直翻转的概率', 144 | 'mosaic': '马赛克增强的概率', 145 | 'mixup': 'Mixup增强的概率', 146 | 'copy_paste': '分割掩码复制粘贴的概率', 147 | 148 | # 增强参数 149 | 'albumentations': 'Albumentations数据增强库的设置', 150 | 'blur': '随机模糊的概率', 151 | 'perspective': '透视变换的概率', 152 | 'shear': '剪切变换的概率', 153 | 154 | # 保存参数 155 | 'project': '结果保存的项目文件夹', 156 | 'name': '实验名称', 157 | 'exist_ok': '是否允许覆盖现有实验文件夹', 158 | 'save_period': '权重保存间隔,-1表示只保存最终轮次', 159 | 'save_dir': '实际保存目录(自动生成)', 160 | 161 | # 可视化参数 162 | 'plots': '是否保存训练过程的图表', 163 | 'noval': '仅训练不验证', 164 | 'v5loader': '使用YOLOv5的数据加载器', 165 | 166 | # 高级参数 167 | 'nbs': '标称批次大小,用于权重缩放', 168 | 'overlap_mask': '在分割任务中是否允许掩码重叠', 169 | 'mask_ratio': '分割掩码的下采样率', 170 | 'dropout': 'Dropout比率,用于减少过拟合', 171 | 'val': '是否在训练过程中进行验证', 172 | 'seed': '随机种子,用于可重复性', 173 | 'workers': '数据加载线程数', 174 | 'deterministic': '是否使用确定性算法以确保可重复性', 175 | } 176 | 177 | 178 | def parse_data_yaml(yaml_path): 179 | """解析数据集YAML文件""" 180 | try: 181 | with open(yaml_path, 'r', encoding='utf-8') as file: 182 | data = yaml.safe_load(file) 183 | return data 184 | except Exception as e: 185 | print(f"解析YAML错误: {str(e)}") 186 | return None 187 | 188 | 189 | def save_data_yaml(yaml_path, data): 190 | """保存修改后的数据集YAML文件""" 191 | try: 192 | with open(yaml_path, 'w', encoding='utf-8') as file: 193 | yaml.dump(data, file, allow_unicode=True, default_flow_style=False) 194 | return True 195 | except Exception as e: 196 | print(f"保存YAML错误: {str(e)}") 197 | return False 198 | 199 | 200 | def get_command_line_args(params): 201 | """将GUI参数转换为命令行参数""" 202 | args = [] 203 | 204 | # 数据参数 205 | args.append(f"data={params['data_path']}") 206 | args.append(f"batch={params['batch']}") 207 | args.append(f"imgsz={params['imgsz']}") 208 | 209 | if params['cache']: 210 | args.append("cache=True") 211 | 212 | if params['single_cls']: 213 | args.append("single_cls=True") 214 | 215 | if params['rect']: 216 | args.append("rect=True") 217 | 218 | if params['fraction'] < 1.0: 219 | args.append(f"fraction={params['fraction']}") 220 | 221 | # 模型参数 222 | args.append(f"model={params['model']}") 223 | args.append(f"task={params['task']}") 224 | 225 | if not params['pretrained']: 226 | args.append("pretrained=False") 227 | 228 | if params['resume']: 229 | args.append("resume=True") 230 | 231 | # 训练参数 232 | args.append(f"epochs={params['epochs']}") 233 | args.append(f"patience={params['patience']}") 234 | args.append(f"optimizer={params['optimizer']}") 235 | args.append(f"lr0={params['lr0']}") 236 | args.append(f"lrf={params['lrf']}") 237 | args.append(f"momentum={params['momentum']}") 238 | args.append(f"weight_decay={params['weight_decay']}") 239 | args.append(f"warmup_epochs={params['warmup_epochs']}") 240 | args.append(f"warmup_momentum={params['warmup_momentum']}") 241 | args.append(f"warmup_bias_lr={params['warmup_bias_lr']}") 242 | 243 | if params['device']: 244 | args.append(f"device={params['device']}") 245 | 246 | if params['cos_lr']: 247 | args.append("cos_lr=True") 248 | 249 | if params['close_mosaic'] > 0: 250 | args.append(f"close_mosaic={params['close_mosaic']}") 251 | 252 | if not params['amp']: 253 | args.append("amp=False") 254 | 255 | # 超参数 256 | for key, value in params.items(): 257 | if key.startswith('hsv_') or key in ['degrees', 'translate', 'scale', 'fliplr', 'flipud', 'mosaic', 'mixup', 'copy_paste']: 258 | args.append(f"{key}={value}") 259 | 260 | # 保存参数 261 | args.append(f"project={params['project']}") 262 | args.append(f"name={params['name']}") 263 | 264 | if params['exist_ok']: 265 | args.append("exist_ok=True") 266 | 267 | if params['save_period'] > 0: 268 | args.append(f"save_period={params['save_period']}") 269 | 270 | # 可视化参数 271 | if not params['plots']: 272 | args.append("plots=False") 273 | 274 | if params['noval']: 275 | args.append("noval=True") 276 | 277 | if params['v5loader']: 278 | args.append("v5loader=True") 279 | 280 | # 高级参数 281 | args.append(f"nbs={params['nbs']}") 282 | 283 | if not params['overlap_mask']: 284 | args.append("overlap_mask=False") 285 | 286 | args.append(f"mask_ratio={params['mask_ratio']}") 287 | 288 | if params['dropout'] > 0: 289 | args.append(f"dropout={params['dropout']}") 290 | 291 | if not params['val']: 292 | args.append("val=False") 293 | 294 | if params['seed'] != 0: 295 | args.append(f"seed={params['seed']}") 296 | 297 | args.append(f"workers={params['workers']}") 298 | 299 | if not params['deterministic']: 300 | args.append("deterministic=False") 301 | 302 | return args 303 | -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import sys 6 | import time 7 | import threading 8 | import signal 9 | import re 10 | import shutil 11 | from datetime import datetime, timedelta 12 | import subprocess 13 | from PySide6.QtCore import QObject, Signal, QThread, QMutex, QWaitCondition 14 | from PySide6.QtWidgets import QApplication 15 | 16 | from parameters import get_command_line_args 17 | 18 | 19 | class TrainingThread(QThread): 20 | """用于执行YOLOv8训练的线程""" 21 | 22 | def __init__(self, command, env=None): 23 | super().__init__() 24 | self.command = command 25 | self.env = env if env else os.environ.copy() 26 | # 设置UTF-8编码环境变量以解决中文路径问题 27 | self.env["PYTHONIOENCODING"] = "utf-8" 28 | self.process = None 29 | self.stopped = False 30 | self.mutex = QMutex() 31 | self.condition = QWaitCondition() 32 | 33 | def run(self): 34 | """运行训练进程""" 35 | # 启动进程并重定向输出 36 | try: 37 | # 确保使用UTF-8编码处理所有输出 38 | self.process = subprocess.Popen( 39 | self.command, 40 | stdout=subprocess.PIPE, 41 | stderr=subprocess.STDOUT, 42 | text=True, 43 | bufsize=1, 44 | universal_newlines=True, 45 | env=self.env, 46 | encoding='utf-8', # 明确指定UTF-8编码 47 | errors='replace' # 对于无法解码的字符,使用替代方式而不是报错 48 | ) 49 | 50 | # 创建一个线程来持续读取输出,避免缓冲区填满导致阻塞 51 | def read_output(): 52 | for line in iter(self.process.stdout.readline, ''): 53 | if self.stopped: 54 | break 55 | if line: 56 | self.progress_line.emit(line.strip()) 57 | # 让UI有机会更新 58 | QApplication.processEvents() 59 | 60 | # 启动读取线程 61 | read_thread = threading.Thread(target=read_output) 62 | read_thread.daemon = True 63 | read_thread.start() 64 | 65 | # 等待进程完成 66 | exit_code = self.process.wait() 67 | # 等待读取线程结束 68 | read_thread.join(timeout=2) 69 | 70 | if exit_code == 0 and not self.stopped: 71 | self.finished.emit(True) 72 | else: 73 | self.finished.emit(False) 74 | 75 | except Exception as e: 76 | self.error.emit(str(e)) 77 | 78 | def stop(self): 79 | """停止训练线程""" 80 | self.stopped = True 81 | self.mutex.lock() 82 | if self.process and self.process.poll() is None: 83 | try: 84 | # 尝试先正常终止 85 | self.process.terminate() 86 | time.sleep(1) 87 | # 如果进程仍在运行,强制终止 88 | if self.process.poll() is None: 89 | self.process.kill() 90 | except: 91 | pass 92 | self.condition.wakeAll() 93 | self.mutex.unlock() 94 | 95 | # 定义信号 96 | progress_line = Signal(str) 97 | finished = Signal(bool) 98 | error = Signal(str) 99 | 100 | 101 | class TrainingManager(QObject): 102 | """管理YOLOv8训练过程""" 103 | 104 | def __init__(self): 105 | super().__init__() 106 | self.training_thread = None 107 | self.start_time = None 108 | self.current_epoch = 0 109 | self.total_epochs = 0 110 | self.current_metrics = {} 111 | self.training_dir = "" 112 | 113 | def start_training(self, params): 114 | """启动训练进程""" 115 | # 如果已经有一个训练线程在运行,先停止它 116 | if self.training_thread and self.training_thread.isRunning(): 117 | self.stop_training() 118 | 119 | # 确保params包含data_path键,以保持与main.py的兼容性 120 | if 'data_path' not in params and 'train_folder' in params: 121 | params['data_path'] = params['train_folder'] 122 | 123 | # 重置状态 124 | self.start_time = datetime.now() 125 | self.current_epoch = 0 126 | self.total_epochs = int(params.get('epochs', 100)) 127 | self.current_metrics = {} 128 | 129 | # 构建命令 - 使用yolo命令行格式 130 | # 首先查找是否有yolo命令可用 131 | yolo_cmd = "yolo" 132 | 133 | # 如果没有yolo命令可用,则回退到python模块调用 134 | if shutil.which(yolo_cmd) is None: 135 | cmd = [sys.executable, "-m", "ultralytics"] 136 | else: 137 | cmd = [yolo_cmd] 138 | 139 | # 确定任务类型 140 | task = params.get('task', 'detect') 141 | is_classification = params.get('is_classification', False) or task == 'classify' 142 | if is_classification: 143 | task = 'classify' 144 | 145 | # 添加任务类型和train命令 146 | cmd.append(task) 147 | cmd.append('train') 148 | 149 | # 处理数据路径 - 确保使用正斜杠以避免Windows路径转义问题 150 | data_arg_added = False # 跟踪是否已经添加了data参数 151 | 152 | if is_classification and params.get('direct_folder_mode', False): 153 | # 分类任务直接使用文件夹 - 确保直接指向目录而不是YAML文件 154 | folder_path = params['train_folder'].replace('\\', '/') 155 | # 移除末尾的斜杠(如果有) 156 | folder_path = folder_path.rstrip('/') 157 | 158 | # 检查目录结构 159 | train_dir_exists = os.path.isdir(os.path.join(folder_path, 'train')) 160 | val_dir_exists = os.path.isdir(os.path.join(folder_path, 'val')) 161 | 162 | # 如果文件夹结构正确 (有train和val子目录) 163 | if train_dir_exists and val_dir_exists: 164 | # 直接使用该目录 165 | cmd.append(f"data={folder_path}") 166 | data_arg_added = True 167 | else: 168 | # 检查目录中是否有类别子目录 169 | has_class_dirs = False 170 | class_dirs = [] 171 | 172 | for item in os.listdir(folder_path): 173 | item_path = os.path.join(folder_path, item) 174 | if os.path.isdir(item_path) and not item.startswith('.'): 175 | has_class_dirs = True 176 | class_dirs.append(item) 177 | 178 | # 如果有类别子目录,则可以使用split参数 179 | if has_class_dirs: 180 | cmd.append(f"data={folder_path}") 181 | cmd.append("split=0.9") # 90%训练、10%验证 182 | data_arg_added = True 183 | print(f"检测到分类数据文件夹,将使用自动分割: {folder_path}") 184 | print(f"发现类别: {', '.join(class_dirs)}") 185 | else: 186 | self.error.emit(f"无效的分类数据目录: {folder_path}\n需要train/val子目录或类别子目录") 187 | return 188 | 189 | # 如果尚未添加data参数,处理常规模式 190 | if not data_arg_added: 191 | if 'data_path' in params: 192 | # 确保对于分类任务,我们传递的是目录而不是YAML文件 193 | data_path = params['data_path'].replace('\\', '/') 194 | 195 | # 检查如果是分类任务,且路径以.yaml结尾 196 | if is_classification and data_path.lower().endswith('.yaml'): 197 | # 尝试读取YAML文件以获取正确的数据路径 198 | try: 199 | import yaml 200 | with open(data_path, 'r', encoding='utf-8') as f: 201 | yaml_data = yaml.safe_load(f) 202 | 203 | # 如果YAML包含路径信息,使用它 204 | if 'path' in yaml_data: 205 | actual_path = yaml_data['path'] 206 | # 如果是相对路径,相对于YAML所在目录 207 | if not os.path.isabs(actual_path): 208 | yaml_dir = os.path.dirname(data_path) 209 | actual_path = os.path.join(yaml_dir, actual_path) 210 | 211 | # 修复 f-string 问题 - 不在表达式内使用反斜杠替换 212 | replaced_path = actual_path.replace('\\', '/') 213 | cmd.append(f"data={replaced_path}") 214 | else: 215 | # 使用YAML文件所在的目录 216 | # 修复 f-string 问题 217 | dirname_path = os.path.dirname(data_path).replace('\\', '/') 218 | cmd.append(f"data={dirname_path}") 219 | except Exception as e: 220 | print(f"处理YAML文件时出错: {e}") 221 | # 回退到使用原始路径 222 | cmd.append(f"data={data_path}") 223 | else: 224 | # 非分类任务或非YAML文件,直接使用 225 | cmd.append(f"data={data_path}") 226 | elif is_classification and 'train_folder' in params: 227 | # 如果没有data_path但有train_folder,使用训练文件夹 228 | folder_path = params['train_folder'].replace('\\', '/') 229 | cmd.append(f"data={folder_path}") 230 | 231 | # 添加模型参数 232 | if 'model' in params: 233 | model_param = params['model'] 234 | if task == 'classify' and 'cls' not in model_param: 235 | # 确保分类任务使用分类模型 236 | model_name = model_param.split('.')[0] 237 | if not model_name.endswith('-cls'): 238 | model_name += '-cls' 239 | cmd.append(f"model={model_name}.pt") 240 | else: 241 | cmd.append(f"model={model_param}") 242 | 243 | # 添加图像大小参数 - 对于分类任务非常重要 244 | if is_classification: 245 | imgsz = params.get('imgsz', 224) # 分类默认使用224 246 | cmd.append(f"imgsz={imgsz}") 247 | 248 | # 添加其他参数 - 只添加非默认参数 249 | default_params = { 250 | 'batch': 16, 251 | 'imgsz': 640, # 检测默认640,分类默认224 252 | 'epochs': 100, 253 | 'patience': 50, 254 | 'lr0': 0.01, 255 | 'lrf': 0.01, 256 | 'momentum': 0.937, 257 | 'weight_decay': 0.0005, 258 | 'warmup_epochs': 3.0, 259 | 'warmup_momentum': 0.8, 260 | 'warmup_bias_lr': 0.1, 261 | 'box': 7.5, 262 | 'cls': 0.5, 263 | 'dfl': 1.5, 264 | 'fl_gamma': 0.0, 265 | 'hsv_h': 0.015, 266 | 'hsv_s': 0.7, 267 | 'hsv_v': 0.4, 268 | 'degrees': 0.0, 269 | 'translate': 0.1, 270 | 'scale': 0.5, 271 | 'shear': 0.0, 272 | 'perspective': 0.0, 273 | 'flipud': 0.0, 274 | 'fliplr': 0.5, 275 | 'mosaic': 1.0, 276 | 'mixup': 0.0, 277 | 'copy_paste': 0.0 278 | } 279 | 280 | # 设置工作目录为项目目录,避免路径问题 281 | if 'project' in params: 282 | # 修复 f-string 问题 283 | project_path = params['project'].replace('\\', '/') 284 | cmd.append(f"project={project_path}") 285 | 286 | for key, value in params.items(): 287 | if key not in ['data_path', 'train_folder', 'is_classification', 'direct_folder_mode', 'model', 'task', 'project', 'imgsz']: 288 | # 只添加非默认值或明确需要的值 289 | if key not in default_params or value != default_params.get(key): 290 | # 如果是路径类型的参数,确保使用正斜杠 291 | if key.endswith('_path') or key in ['save_dir']: 292 | if isinstance(value, str): 293 | # 修复 f-string 问题 294 | replaced_value = value.replace('\\', '/') 295 | cmd.append(f"{key}={replaced_value}") 296 | else: 297 | cmd.append(f"{key}={value}") 298 | 299 | print(f"执行命令: {' '.join(cmd)}") 300 | 301 | # 设置环境变量 302 | env_vars = os.environ.copy() 303 | # 添加PYTHONIOENCODING环境变量以确保正确处理UTF-8 304 | env_vars["PYTHONIOENCODING"] = "utf-8" 305 | 306 | # 处理device参数 - 确保有效值,避免空值引起错误 307 | if 'device' in params: 308 | device_value = params['device'] 309 | # 如果device为空字符串,默认使用所有可用设备(不设置CUDA_VISIBLE_DEVICES) 310 | if device_value and device_value != 'cpu': 311 | # 提取设备ID,去掉'cuda:'前缀 312 | device_id = device_value.replace('cuda:', '') 313 | if device_id and device_id.strip(): 314 | env_vars['CUDA_VISIBLE_DEVICES'] = device_id 315 | 316 | # 创建并启动训练线程 317 | self.training_thread = TrainingThread(cmd, env_vars) 318 | self.training_thread.progress_line.connect(self.process_progress_line) 319 | self.training_thread.finished.connect(self.training_finished) 320 | self.training_thread.error.connect(self.training_error) 321 | self.training_thread.start() 322 | 323 | def stop_training(self): 324 | """停止训练进程""" 325 | if self.training_thread and self.training_thread.isRunning(): 326 | self.training_thread.stop() 327 | self.training_thread.wait(5000) # 等待最多5秒让线程结束 328 | # 如果线程仍然在运行,我们不再等待 329 | if self.training_thread.isRunning(): 330 | print("警告: 训练线程没有及时结束") 331 | 332 | def process_progress_line(self, line): 333 | """处理训练进程的输出行""" 334 | # 尝试解析YOLOv8的输出 335 | try: 336 | # 创建基本的进度信息结构 337 | progress_info = { 338 | 'current_epoch': self.current_epoch, 339 | 'total_epochs': self.total_epochs, 340 | 'metrics': self.current_metrics.copy() if hasattr(self, 'current_metrics') else {}, 341 | 'output_line': line 342 | } 343 | 344 | # 捕获训练目录 345 | if "Results saved to" in line: 346 | try: 347 | self.training_dir = re.search(r"Results saved to\s+([^\s]+)", line).group(1) 348 | except: 349 | # 如果无法解析路径,使用一个安全的默认值 350 | self.training_dir = "runs/train/exp" 351 | print(f"无法解析训练目录,使用默认值: {self.training_dir}") 352 | 353 | # 匹配YOLOv8标准输出格式 - 例如: 1/100 1.49G 3.755 16 640: 90%|████████▉ | 2699/3000 [03:42<00:26, 11.17it/s] 354 | epoch_pattern = r"^(\d+)/(\d+)\s+[\d\.]+G\s+([0-9\.]+)\s+\d+\s+\d+:\s+(\d+)%\|[^|]*\|\s*(\d+)/(\d+)" 355 | epoch_match = re.search(epoch_pattern, line) 356 | 357 | if epoch_match: 358 | # 获取当前轮次和总轮次 359 | current_epoch = int(epoch_match.group(1)) 360 | total_epochs = int(epoch_match.group(2)) 361 | self.current_epoch = current_epoch 362 | self.total_epochs = total_epochs 363 | 364 | # 获取损失值 365 | loss = float(epoch_match.group(3)) 366 | 367 | # 获取进度百分比 368 | percent = int(epoch_match.group(4)) 369 | 370 | # 获取当前批次和总批次 371 | current_batch = int(epoch_match.group(5)) 372 | total_batch = int(epoch_match.group(6)) 373 | 374 | # 计算更精确的进度 375 | if current_epoch > 0 and total_epochs > 0: 376 | # 计算轮次进度和批次进度的组合 377 | epoch_progress = (current_epoch - 1) / total_epochs 378 | batch_progress = (current_batch / total_batch) / total_epochs 379 | total_progress = (epoch_progress + batch_progress) * 100 380 | progress_info['progress'] = total_progress 381 | 382 | # 提取时间信息 [03:42<00:26, 11.17it/s] 383 | time_pattern = r"\[(\d+):(\d+)<(\d+):(\d+)" 384 | time_match = re.search(time_pattern, line) 385 | if time_match: 386 | elapsed_min = int(time_match.group(1)) 387 | elapsed_sec = int(time_match.group(2)) 388 | remain_min = int(time_match.group(3)) 389 | remain_sec = int(time_match.group(4)) 390 | 391 | elapsed_time = str(timedelta(minutes=elapsed_min, seconds=elapsed_sec)) 392 | eta = str(timedelta(minutes=remain_min, seconds=remain_sec)) 393 | 394 | progress_info['elapsed_time'] = elapsed_time 395 | progress_info['eta'] = eta 396 | 397 | # 更新当前指标 398 | self.current_metrics['loss'] = loss 399 | progress_info['metrics'] = self.current_metrics.copy() 400 | 401 | # 匹配验证指标行 - 例如: classes top1_acc top5_acc: 0%| | 0/300 [00:00 1 else "0" 224 | 225 | if int(cuda_version_str) >= 11: 226 | cuda_package = f"cu{cuda_version_str}{cuda_version_minor}" 227 | else: 228 | cuda_package = f"cu{cuda_version_str}0" 229 | 230 | command = f"pip install torch=={version} torchvision torchaudio --index-url https://pypi.tuna.tsinghua.edu.cn/simple" 231 | else: 232 | # CPU版本 233 | command = f"pip install torch=={version} torchvision torchaudio --index-url https://pypi.tuna.tsinghua.edu.cn/simple" 234 | else: 235 | # 使用官方源 236 | if self.cuda_version: 237 | # CUDA版本 238 | cuda_version_str = self.cuda_version.split('.')[0] 239 | cuda_version_minor = self.cuda_version.split('.')[1] if len(self.cuda_version.split('.')) > 1 else "0" 240 | 241 | if int(cuda_version_str) >= 11: 242 | cuda_package = f"cu{cuda_version_str}{cuda_version_minor}" 243 | else: 244 | cuda_package = f"cu{cuda_version_str}0" 245 | 246 | command = f"pip install torch=={version} torchvision torchaudio" 247 | else: 248 | # CPU版本 249 | command = f"pip install torch=={version} torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu" 250 | 251 | self.command_text.setText(command) 252 | 253 | def start_install(self): 254 | """开始安装PyTorch""" 255 | command = self.command_text.text() 256 | command_args = command.split() 257 | 258 | # 禁用按钮并显示进度条 259 | self.install_button.setEnabled(False) 260 | self.progress_bar.setVisible(True) 261 | 262 | # 创建并启动安装线程 263 | self.install_thread = InstallThread(command_args) 264 | self.install_thread.progress_signal.connect(self.update_progress) 265 | self.install_thread.finished_signal.connect(self.installation_finished) 266 | self.install_thread.start() 267 | 268 | def update_progress(self, message): 269 | """更新安装进度""" 270 | self.log_label.setText(message) 271 | 272 | def installation_finished(self, success, message): 273 | """安装完成回调""" 274 | self.progress_bar.setVisible(False) 275 | self.log_label.setText(message) 276 | 277 | if success: 278 | QMessageBox.information(self, "安装成功", "PyTorch已成功安装!") 279 | else: 280 | QMessageBox.warning(self, "安装失败", f"PyTorch安装失败: {message}") 281 | 282 | # 重新启用安装按钮 283 | self.install_button.setEnabled(True) 284 | 285 | 286 | class EnvironmentChecker: 287 | """检查YOLOv8训练所需的环境""" 288 | 289 | def __init__(self): 290 | self.status = { 291 | 'yolov8_installed': False, 292 | 'cuda_available': False, 293 | 'gpu_info': [], 294 | 'python_version': platform.python_version(), 295 | 'os_info': f"{platform.system()} {platform.release()}", 296 | 'torch_version': "未安装", 297 | 'cuda_version': "未安装" 298 | } 299 | 300 | # 尝试导入torch 301 | try: 302 | import torch 303 | self.status['torch_version'] = torch.__version__ 304 | except ImportError: 305 | pass 306 | 307 | def check_all(self): 308 | """检查所有环境变量并返回状态""" 309 | self.check_yolov8() 310 | self.check_cuda() 311 | self.get_cuda_version() 312 | return self.status 313 | 314 | def check_yolov8(self): 315 | """检查是否安装了YOLOv8""" 316 | try: 317 | import ultralytics 318 | self.status['yolov8_installed'] = True 319 | self.status['yolov8_version'] = ultralytics.__version__ 320 | except ImportError: 321 | self.status['yolov8_installed'] = False 322 | 323 | def check_cuda(self): 324 | """检查CUDA是否可用并获取GPU信息""" 325 | try: 326 | import torch 327 | if torch.cuda.is_available(): 328 | self.status['cuda_available'] = True 329 | 330 | # 获取GPU信息 331 | gpu_count = torch.cuda.device_count() 332 | for i in range(gpu_count): 333 | gpu_name = torch.cuda.get_device_name(i) 334 | gpu_memory = torch.cuda.get_device_properties(i).total_memory / (1024**3) # 转换为GB 335 | self.status['gpu_info'].append({ 336 | 'index': i, 337 | 'name': gpu_name, 338 | 'memory': f"{gpu_memory:.2f} GB" 339 | }) 340 | else: 341 | self.status['cuda_available'] = False 342 | except ImportError: 343 | self.status['cuda_available'] = False 344 | 345 | def get_cuda_version(self): 346 | """获取CUDA版本""" 347 | try: 348 | if self.status['cuda_available']: 349 | import torch 350 | cuda_version = torch.version.cuda 351 | self.status['cuda_version'] = cuda_version if cuda_version else "未知" 352 | else: 353 | # 尝试从系统中检测CUDA 354 | try: 355 | nvcc_output = subprocess.check_output(['nvcc', '--version']).decode('utf-8') 356 | for line in nvcc_output.split('\n'): 357 | if 'release' in line: 358 | # 通常格式为 "release x.y" 359 | parts = line.split('release') 360 | if len(parts) > 1: 361 | version = parts[1].strip().split(' ')[0] 362 | self.status['cuda_version'] = version 363 | break 364 | except: 365 | pass 366 | except: 367 | self.status['cuda_version'] = "无法检测" 368 | 369 | def install_yolov8(self, parent=None, use_mirror=True, mirror_url=None, trust=True): 370 | """安装YOLOv8,支持镜像源配置""" 371 | try: 372 | # 准备安装命令 373 | cmd = [sys.executable, "-m", "pip", "install", "ultralytics"] 374 | 375 | # 添加镜像源参数 376 | if use_mirror and mirror_url: 377 | cmd.extend(["-i", mirror_url]) 378 | # 添加信任参数 379 | if trust: 380 | cmd.append("--trusted-host") 381 | # 提取主机名 382 | from urllib.parse import urlparse 383 | host = urlparse(mirror_url).netloc 384 | cmd.append(host) 385 | 386 | if parent: 387 | # 使用图形界面显示安装进度 388 | dialog = QDialog(parent) 389 | dialog.setWindowTitle("安装YOLOv8") 390 | dialog.resize(500, 300) 391 | 392 | layout = QVBoxLayout(dialog) 393 | 394 | # 显示安装命令 395 | command_label = QLabel("执行安装命令:") 396 | layout.addWidget(command_label) 397 | 398 | command_text = QLineEdit() 399 | command_text.setText(" ".join(cmd)) 400 | command_text.setReadOnly(True) 401 | layout.addWidget(command_text) 402 | 403 | # 进度条 404 | progress_bar = QProgressBar() 405 | progress_bar.setRange(0, 0) # 不确定进度 406 | layout.addWidget(progress_bar) 407 | 408 | # 日志显示 409 | log_label = QLabel("正在安装...") 410 | layout.addWidget(log_label) 411 | 412 | # 按钮 413 | button_box = QDialogButtonBox(QDialogButtonBox.Cancel) 414 | button_box.rejected.connect(dialog.reject) 415 | layout.addWidget(button_box) 416 | 417 | # 创建安装线程 418 | install_thread = InstallThread(cmd) 419 | 420 | # 连接信号 421 | def update_progress(message): 422 | log_label.setText(message) 423 | 424 | def installation_finished(success, message): 425 | progress_bar.setVisible(False) 426 | log_label.setText(message) 427 | 428 | if success: 429 | QMessageBox.information(dialog, "安装成功", "YOLOv8已成功安装!") 430 | dialog.accept() 431 | else: 432 | QMessageBox.warning(dialog, "安装失败", f"YOLOv8安装失败: {message}") 433 | # 仍然允许关闭对话框 434 | button_box.clear() 435 | button_box.addButton(QDialogButtonBox.Close) 436 | button_box.rejected.connect(dialog.reject) 437 | 438 | install_thread.progress_signal.connect(update_progress) 439 | install_thread.finished_signal.connect(installation_finished) 440 | 441 | # 显示对话框并启动线程 442 | install_thread.start() 443 | dialog.exec() 444 | 445 | # 安装后重新检查 446 | self.check_yolov8() 447 | return self.status['yolov8_installed'] 448 | else: 449 | # 命令行模式安装 450 | result = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) 451 | # 安装后重新检查 452 | self.check_yolov8() 453 | return self.status['yolov8_installed'] 454 | 455 | except Exception as e: 456 | if parent: 457 | QMessageBox.critical(parent, "安装错误", f"安装YOLOv8时出错: {str(e)}") 458 | return False 459 | 460 | def install_pytorch(self, parent=None): 461 | """安装PyTorch,根据CUDA可用性选择版本""" 462 | try: 463 | dialog = PytorchInstallDialog(self.status.get('cuda_version'), parent) 464 | 465 | if dialog.exec() == QDialog.Accepted: 466 | return True 467 | return False 468 | except Exception as e: 469 | if parent: 470 | QMessageBox.critical(parent, "安装错误", f"安装PyTorch时出错: {str(e)}") 471 | return False 472 | 473 | def configure_mirror(self, parent=None): 474 | """配置镜像源对话框""" 475 | dialog = MirrorConfigDialog(parent) 476 | 477 | if dialog.exec() == QDialog.Accepted: 478 | return { 479 | 'url': dialog.get_mirror_url(), 480 | 'trusted': dialog.is_trusted() 481 | } 482 | 483 | return None 484 | 485 | 486 | def get_python_executable(): 487 | """获取当前Python解释器路径""" 488 | return sys.executable 489 | 490 | 491 | def run_command(command, shell=False): 492 | """运行命令并返回输出""" 493 | try: 494 | result = subprocess.run( 495 | command, 496 | shell=shell, 497 | check=True, 498 | stdout=subprocess.PIPE, 499 | stderr=subprocess.PIPE, 500 | text=True 501 | ) 502 | return result.stdout 503 | except subprocess.CalledProcessError as e: 504 | return e.stderr 505 | 506 | 507 | def detect_nvidia_driver(): 508 | """检测NVIDIA驱动版本""" 509 | try: 510 | if platform.system() == "Windows": 511 | # Windows: 使用nvidiasmi 512 | output = subprocess.check_output(['nvidia-smi', '--query-gpu=driver_version', '--format=csv,noheader']).decode('utf-8') 513 | return output.strip() 514 | elif platform.system() == "Linux": 515 | # Linux: 使用nvidia-smi 516 | output = subprocess.check_output(['nvidia-smi', '--query-gpu=driver_version', '--format=csv,noheader']).decode('utf-8') 517 | return output.strip() 518 | else: 519 | return "未知" 520 | except: 521 | return "未检测到驱动" 522 | 523 | 524 | def detect_system_cuda(): 525 | """检测系统CUDA安装情况""" 526 | try: 527 | if platform.system() == "Windows": 528 | # 检查环境变量 529 | cuda_path = os.environ.get('CUDA_PATH') 530 | if cuda_path and os.path.exists(cuda_path): 531 | # 尝试从路径中提取版本 532 | path_parts = cuda_path.split('\\') 533 | for part in path_parts: 534 | if part.startswith('v'): 535 | return part[1:] # 去掉v前缀 536 | 537 | # 检查Program Files 538 | cuda_dirs = [] 539 | try: 540 | for root_dir in ['C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA', 'C:\\Program Files\\NVIDIA\\CUDA']: 541 | if os.path.exists(root_dir): 542 | for dir_name in os.listdir(root_dir): 543 | if dir_name.startswith('v'): 544 | cuda_dirs.append((dir_name[1:], os.path.join(root_dir, dir_name))) 545 | except: 546 | pass 547 | 548 | if cuda_dirs: 549 | # 返回最新版本 550 | cuda_dirs.sort(key=lambda x: [int(v) for v in x[0].split('.')]) 551 | return cuda_dirs[-1][0] 552 | 553 | elif platform.system() == "Linux": 554 | # 使用ldconfig检查 555 | try: 556 | output = subprocess.check_output(['ldconfig', '-p']).decode('utf-8') 557 | for line in output.split('\n'): 558 | if 'libcudart.so.' in line: 559 | # 提取版本号 560 | import re 561 | match = re.search(r'libcudart\.so\.(\d+\.\d+)', line) 562 | if match: 563 | return match.group(1) 564 | except: 565 | pass 566 | 567 | # 检查/usr/local目录 568 | try: 569 | cuda_dirs = [] 570 | if os.path.exists('/usr/local'): 571 | for dir_name in os.listdir('/usr/local'): 572 | if dir_name.startswith('cuda-'): 573 | cuda_dirs.append((dir_name[5:], os.path.join('/usr/local', dir_name))) 574 | 575 | if cuda_dirs: 576 | # 返回最新版本 577 | cuda_dirs.sort(key=lambda x: [int(v) for v in x[0].split('.')]) 578 | return cuda_dirs[-1][0] 579 | except: 580 | pass 581 | 582 | return "未检测到" 583 | except: 584 | return "检测失败" 585 | -------------------------------------------------------------------------------- /ui_components.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import time 6 | from datetime import datetime 7 | import yaml 8 | from PySide6.QtWidgets import ( 9 | QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, 10 | QLineEdit, QFileDialog, QComboBox, QCheckBox, QSpinBox, QDoubleSpinBox, 11 | QTabWidget, QScrollArea, QGroupBox, QProgressBar, QTextEdit, QGridLayout, 12 | QSplitter, QFrame, QMessageBox, QToolButton, QStyle, QSizePolicy, 13 | QFormLayout, QToolTip, QRadioButton, QButtonGroup 14 | ) 15 | from PySide6.QtCore import Qt, QSize, Slot, Signal, QPoint 16 | from PySide6.QtGui import QFont, QIcon, QPixmap, QColor, QPalette 17 | 18 | from parameters import parameter_descriptions, parse_data_yaml, save_data_yaml 19 | from PySide6.QtCore import Qt, QPropertyAnimation, QEasingCurve 20 | 21 | class CollapsibleBox(QWidget): 22 | """可折叠的分组框""" 23 | 24 | def __init__(self, title="", parent=None): 25 | super().__init__(parent) 26 | 27 | # 创建标题按钮 28 | self.toggleButton = QToolButton() 29 | self.toggleButton.setStyleSheet(""" 30 | QToolButton { 31 | font-weight: bold; 32 | font-size: 12px; 33 | background-color: #f0f0f0; 34 | border: 1px solid #cccccc; 35 | border-radius: 3px; 36 | padding: 5px; 37 | text-align: left; 38 | } 39 | 40 | QToolButton:hover { 41 | background-color: #e5e5e5; 42 | } 43 | 44 | QToolButton:pressed { 45 | background-color: #d0d0d0; 46 | } 47 | """) 48 | self.toggleButton.setToolButtonStyle(Qt.ToolButtonTextBesideIcon) 49 | self.toggleButton.setArrowType(Qt.RightArrow) 50 | self.toggleButton.setText(title) 51 | self.toggleButton.setCheckable(True) 52 | self.toggleButton.setChecked(False) 53 | 54 | # 滚动区域 55 | self.contentArea = QScrollArea() 56 | self.contentArea.setFrameShape(QFrame.NoFrame) 57 | self.contentArea.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed) 58 | self.contentArea.setMaximumHeight(0) 59 | self.contentArea.setMinimumHeight(0) 60 | 61 | # 内容控件 62 | self.content_widget = QWidget() 63 | self.content_layout = QVBoxLayout(self.content_widget) 64 | self.content_layout.setContentsMargins(10, 5, 5, 5) 65 | self.contentArea.setWidget(self.content_widget) 66 | self.contentArea.setWidgetResizable(True) 67 | 68 | # 主布局 69 | layout = QVBoxLayout(self) 70 | layout.setSpacing(0) 71 | layout.setContentsMargins(0, 0, 0, 0) 72 | layout.addWidget(self.toggleButton) 73 | layout.addWidget(self.contentArea) 74 | 75 | # 连接信号 76 | self.toggleButton.clicked.connect(self.toggle_contents) 77 | 78 | # 动画 79 | from PySide6.QtCore import QPropertyAnimation, QEasingCurve 80 | self.animation = QPropertyAnimation(self.contentArea, b"maximumHeight") 81 | self.animation.setDuration(300) # 300毫秒 82 | self.animation.setEasingCurve(QEasingCurve.OutCubic) 83 | self.animation.finished.connect(self.animation_finished) 84 | 85 | # 计算内容高度 86 | self.content_height = 0 87 | 88 | def toggle_contents(self, checked): 89 | """切换内容区域的展开/折叠状态""" 90 | # 更新按钮状态 91 | self.toggleButton.setArrowType(Qt.DownArrow if checked else Qt.RightArrow) 92 | 93 | # 如果已经有动画在运行,先停止 94 | if self.animation.state() == QPropertyAnimation.Running: 95 | self.animation.stop() 96 | 97 | # 开始高度和结束高度 98 | start_height = self.contentArea.height() 99 | end_height = self.content_height if checked else 0 100 | 101 | # 设置动画参数 102 | self.animation.setStartValue(start_height) 103 | self.animation.setEndValue(end_height) 104 | 105 | # 开始动画 106 | self.animation.start() 107 | 108 | def animation_finished(self): 109 | """动画结束回调""" 110 | # 如果折叠,确保最大高度为0 111 | if not self.toggleButton.isChecked(): 112 | self.contentArea.setMaximumHeight(0) 113 | else: 114 | # 如果展开,确保内容区域可以随内容变化 115 | self.contentArea.setMaximumHeight(self.content_height) 116 | 117 | def add_widget(self, widget): 118 | """添加控件到内容区域""" 119 | self.content_layout.addWidget(widget) 120 | 121 | # 重新计算内容高度 122 | self.content_widget.adjustSize() 123 | self.content_height = self.content_widget.height() 124 | 125 | # 如果已经展开,更新最大高度 126 | if self.toggleButton.isChecked(): 127 | self.contentArea.setMaximumHeight(self.content_height) 128 | 129 | def add_layout(self, layout): 130 | """添加布局到内容区域""" 131 | self.content_layout.addLayout(layout) 132 | 133 | # 重新计算内容高度 134 | self.content_widget.adjustSize() 135 | self.content_height = self.content_widget.height() 136 | 137 | # 如果已经展开,更新最大高度 138 | if self.toggleButton.isChecked(): 139 | self.contentArea.setMaximumHeight(self.content_height) 140 | 141 | def setContentLayout(self, layout): 142 | """设置内容区域的布局 (兼容旧代码)""" 143 | # 清除旧布局 144 | while self.content_layout.count(): 145 | item = self.content_layout.takeAt(0) 146 | if item.widget(): 147 | item.widget().deleteLater() 148 | 149 | # 添加新布局 150 | self.content_layout.addLayout(layout) 151 | 152 | # 计算内容高度 153 | self.content_widget.adjustSize() 154 | self.content_height = self.content_widget.height() 155 | 156 | def expand(self): 157 | """展开内容区域""" 158 | if not self.toggleButton.isChecked(): 159 | self.toggleButton.click() 160 | 161 | def collapse(self): 162 | """折叠内容区域""" 163 | if self.toggleButton.isChecked(): 164 | self.toggleButton.click() 165 | 166 | def setTitle(self, title): 167 | """设置标题""" 168 | self.toggleButton.setText(title) 169 | 170 | 171 | class ParameterWidget(QWidget): 172 | """参数设置控件""" 173 | 174 | parameterSelected = Signal(str, str) # 发出信号: 参数名称, 描述 175 | 176 | def __init__(self, name, value, param_type, description="", parent=None): 177 | super().__init__(parent) 178 | self.name = name 179 | self.param_type = param_type 180 | self.description = description 181 | 182 | layout = QHBoxLayout(self) 183 | layout.setContentsMargins(0, 5, 0, 5) 184 | 185 | # 参数名称标签 186 | name_label = QLabel(name + ":") 187 | name_label.setMinimumWidth(120) 188 | name_label.setToolTip(description) 189 | layout.addWidget(name_label) 190 | 191 | # 根据参数类型创建不同的输入控件 192 | if param_type == bool: 193 | self.input_widget = QCheckBox() 194 | self.input_widget.setChecked(value) 195 | elif param_type == int: 196 | self.input_widget = QSpinBox() 197 | self.input_widget.setRange(-999999, 999999) 198 | self.input_widget.setValue(value) 199 | elif param_type == float: 200 | self.input_widget = QDoubleSpinBox() 201 | self.input_widget.setRange(-999999.0, 999999.0) 202 | self.input_widget.setDecimals(4) 203 | self.input_widget.setValue(value) 204 | elif param_type == list: 205 | self.input_widget = QComboBox() 206 | self.input_widget.addItems(value) 207 | if isinstance(value, list) and len(value) > 0: 208 | self.input_widget.setCurrentIndex(0) 209 | else: # str 或其他 210 | self.input_widget = QLineEdit(str(value)) 211 | 212 | self.input_widget.setToolTip(description) 213 | layout.addWidget(self.input_widget) 214 | 215 | # 如果是文件路径,添加浏览按钮 216 | if name.endswith('_path') or name == 'model' or name == 'project': 217 | browse_button = QPushButton("浏览...") 218 | browse_button.clicked.connect(self.browse_file) 219 | layout.addWidget(browse_button) 220 | 221 | # 添加帮助按钮,点击显示详细描述 222 | if description: 223 | help_button = QPushButton("?") 224 | help_button.setFixedSize(24, 24) 225 | help_button.setStyleSheet(""" 226 | QPushButton { 227 | background-color: #3498db; 228 | color: white; 229 | border-radius: 12px; 230 | font-weight: bold; 231 | } 232 | QPushButton:hover { 233 | background-color: #2980b9; 234 | } 235 | """) 236 | help_button.clicked.connect(self.show_description) 237 | layout.addWidget(help_button) 238 | 239 | # 鼠标悬停或点击时发出信号 240 | self.setMouseTracking(True) 241 | 242 | def browse_file(self): 243 | """浏览文件或目录""" 244 | if self.name == 'data_path': 245 | file_path, _ = QFileDialog.getOpenFileName(self, "选择数据配置文件", "", "YAML文件 (*.yaml)") 246 | elif self.name == 'model': 247 | file_path, _ = QFileDialog.getOpenFileName(self, "选择模型文件", "", "PyTorch模型 (*.pt);;All Files (*)") 248 | elif self.name == 'project': 249 | file_path = QFileDialog.getExistingDirectory(self, "选择项目目录") 250 | else: 251 | file_path, _ = QFileDialog.getOpenFileName(self, "选择文件", "") 252 | 253 | if file_path: 254 | if isinstance(self.input_widget, QLineEdit): 255 | self.input_widget.setText(file_path) 256 | elif isinstance(self.input_widget, QComboBox): 257 | if self.input_widget.findText(file_path) == -1: 258 | self.input_widget.addItem(file_path) 259 | self.input_widget.setCurrentText(file_path) 260 | 261 | def show_description(self): 262 | """显示参数详细描述""" 263 | self.parameterSelected.emit(self.name, self.description) 264 | 265 | def enterEvent(self, event): 266 | """鼠标进入事件""" 267 | self.parameterSelected.emit(self.name, self.description) 268 | super().enterEvent(event) 269 | 270 | def get_value(self): 271 | """获取控件当前值""" 272 | if isinstance(self.input_widget, QCheckBox): 273 | return self.input_widget.isChecked() 274 | elif isinstance(self.input_widget, QSpinBox): 275 | return self.input_widget.value() 276 | elif isinstance(self.input_widget, QDoubleSpinBox): 277 | return self.input_widget.value() 278 | elif isinstance(self.input_widget, QComboBox): 279 | return self.input_widget.currentText() 280 | elif isinstance(self.input_widget, QLineEdit): 281 | value = self.input_widget.text() 282 | if self.param_type == int: 283 | try: 284 | return int(value) 285 | except ValueError: 286 | return 0 287 | elif self.param_type == float: 288 | try: 289 | return float(value) 290 | except ValueError: 291 | return 0.0 292 | return value 293 | 294 | def set_value(self, value): 295 | """设置控件值""" 296 | if isinstance(self.input_widget, QCheckBox): 297 | self.input_widget.setChecked(bool(value)) 298 | elif isinstance(self.input_widget, QSpinBox): 299 | self.input_widget.setValue(int(value)) 300 | elif isinstance(self.input_widget, QDoubleSpinBox): 301 | self.input_widget.setValue(float(value)) 302 | elif isinstance(self.input_widget, QComboBox): 303 | index = self.input_widget.findText(str(value)) 304 | if index >= 0: 305 | self.input_widget.setCurrentIndex(index) 306 | else: 307 | self.input_widget.addItem(str(value)) 308 | self.input_widget.setCurrentText(str(value)) 309 | elif isinstance(self.input_widget, QLineEdit): 310 | self.input_widget.setText(str(value)) 311 | 312 | 313 | class ParameterGroup(QWidget): 314 | """参数分组""" 315 | 316 | parameterSelected = Signal(str, str) # 发出信号: 参数名称, 描述 317 | 318 | def __init__(self, title, params, parent=None): 319 | super().__init__(parent) 320 | self.title = title 321 | self.params = params 322 | self.param_widgets = {} 323 | 324 | # 创建可折叠分组 325 | self.box = CollapsibleBox(title) 326 | layout = QVBoxLayout(self) 327 | layout.addWidget(self.box) 328 | 329 | # 创建网格布局 330 | grid_layout = QGridLayout() 331 | grid_layout.setColumnStretch(1, 1) 332 | 333 | # 为每个参数创建控件 334 | row = 0 335 | for name, value in params.items(): 336 | # 参数类型 337 | param_type = type(value) 338 | 339 | # 获取参数描述 340 | description = parameter_descriptions.get(name, "") 341 | 342 | # 创建参数控件 343 | param_widget = ParameterWidget(name, value, param_type, description) 344 | param_widget.parameterSelected.connect(self.on_parameter_selected) 345 | self.param_widgets[name] = param_widget 346 | 347 | # 添加到布局 348 | grid_layout.addWidget(param_widget, row, 0) 349 | row += 1 350 | 351 | # 设置分组的内容布局 352 | self.box.setContentLayout(grid_layout) 353 | 354 | def on_parameter_selected(self, name, description): 355 | """参数被选中的事件处理""" 356 | self.parameterSelected.emit(name, description) 357 | 358 | def get_values(self): 359 | """获取所有参数的当前值""" 360 | values = {} 361 | for name, widget in self.param_widgets.items(): 362 | values[name] = widget.get_value() 363 | return values 364 | 365 | def set_values(self, values): 366 | """设置所有参数的值""" 367 | for name, value in values.items(): 368 | if name in self.param_widgets: 369 | self.param_widgets[name].set_value(value) 370 | 371 | 372 | class TrainingTab(QWidget): 373 | """训练选项卡""" 374 | 375 | def __init__(self, parameters, parent=None): 376 | super().__init__(parent) 377 | self.parameters = parameters 378 | self.param_groups = {} 379 | 380 | # 主布局 381 | main_layout = QVBoxLayout(self) 382 | main_layout.setSpacing(15) 383 | 384 | # 左右分割区域 385 | splitter = QSplitter(Qt.Horizontal) 386 | 387 | # 左侧设置区域 388 | left_widget = QWidget() 389 | left_layout = QVBoxLayout(left_widget) 390 | left_layout.setContentsMargins(0, 0, 0, 0) 391 | 392 | # ====== 必要参数区域 ====== 393 | essential_group = QGroupBox("核心参数") 394 | essential_layout = QFormLayout(essential_group) 395 | essential_layout.setFieldGrowthPolicy(QFormLayout.AllNonFixedFieldsGrow) 396 | essential_layout.setLabelAlignment(Qt.AlignRight | Qt.AlignVCenter) 397 | essential_layout.setSpacing(10) 398 | 399 | # 选择任务类型 400 | self.task_combo = QComboBox() 401 | self.task_combo.addItems(["目标检测 (detect)", "分割 (segment)", "分类 (classify)", "姿态估计 (pose)"]) 402 | self.task_combo.setCurrentIndex(0) 403 | self.task_combo.currentIndexChanged.connect(self.on_task_changed) 404 | essential_layout.addRow("任务类型:", self.task_combo) 405 | 406 | # 数据配置文件 407 | self.data_path_layout = QHBoxLayout() 408 | self.data_path_input = QLineEdit() 409 | self.data_path_input.setPlaceholderText("选择数据集YAML文件") 410 | self.data_path_browse = QPushButton("浏览...") 411 | self.data_path_browse.clicked.connect(self.browse_data_file) 412 | self.data_view_button = QPushButton("查看/编辑") 413 | self.data_view_button.clicked.connect(self.view_data_config) 414 | 415 | self.data_path_layout.addWidget(self.data_path_input, 1) 416 | self.data_path_layout.addWidget(self.data_path_browse) 417 | self.data_path_layout.addWidget(self.data_view_button) 418 | essential_layout.addRow("数据配置:", self.data_path_layout) 419 | 420 | # 模型选择 421 | self.model_layout = QHBoxLayout() 422 | self.model_combo = QComboBox() 423 | self.model_combo.addItems([ 424 | "yolov8n.pt", # Nano 425 | "yolov8s.pt", # Small 426 | "yolov8m.pt", # Medium 427 | "yolov8l.pt", # Large 428 | "yolov8x.pt" # XLarge 429 | ]) 430 | self.model_browse = QPushButton("浏览...") 431 | self.model_browse.clicked.connect(self.browse_model_file) 432 | 433 | self.model_layout.addWidget(self.model_combo, 1) 434 | self.model_layout.addWidget(self.model_browse) 435 | essential_layout.addRow("模型:", self.model_layout) 436 | 437 | # 关键参数 438 | self.batch_spinbox = QSpinBox() 439 | self.batch_spinbox.setRange(1, 128) 440 | self.batch_spinbox.setValue(parameters['data']['batch']) 441 | essential_layout.addRow("批次大小:", self.batch_spinbox) 442 | 443 | self.imgsz_spinbox = QSpinBox() 444 | self.imgsz_spinbox.setRange(32, 1280) 445 | self.imgsz_spinbox.setValue(parameters['data']['imgsz']) 446 | self.imgsz_spinbox.setSingleStep(32) 447 | essential_layout.addRow("图像大小:", self.imgsz_spinbox) 448 | 449 | self.epochs_spinbox = QSpinBox() 450 | self.epochs_spinbox.setRange(1, 1000) 451 | self.epochs_spinbox.setValue(parameters['training']['epochs']) 452 | essential_layout.addRow("训练轮数:", self.epochs_spinbox) 453 | 454 | # 学习率 455 | self.lr0_spinbox = QDoubleSpinBox() 456 | self.lr0_spinbox.setRange(0.0001, 0.1) 457 | self.lr0_spinbox.setValue(parameters['training']['lr0']) 458 | self.lr0_spinbox.setDecimals(5) 459 | self.lr0_spinbox.setSingleStep(0.001) 460 | essential_layout.addRow("初始学习率:", self.lr0_spinbox) 461 | 462 | left_layout.addWidget(essential_group) 463 | 464 | # CUDA选择区域 465 | cuda_group = QGroupBox("硬件加速") 466 | cuda_layout = QGridLayout(cuda_group) 467 | 468 | self.cuda_available_label = QLabel("CUDA可用: 未检测") 469 | self.use_cuda_checkbox = QCheckBox("使用CUDA训练") 470 | self.use_cuda_checkbox.setChecked(True) 471 | self.device_label = QLabel("设备:") 472 | self.device_combo = QComboBox() 473 | self.device_combo.addItem("") # 自动选择 474 | 475 | cuda_layout.addWidget(self.cuda_available_label, 0, 0) 476 | cuda_layout.addWidget(self.use_cuda_checkbox, 0, 1) 477 | cuda_layout.addWidget(self.device_label, 1, 0) 478 | cuda_layout.addWidget(self.device_combo, 1, 1) 479 | 480 | left_layout.addWidget(cuda_group) 481 | 482 | # ====== 高级参数区域 ====== 483 | # 创建可折叠的高级参数区域 484 | advanced_box = CollapsibleBox("高级参数") 485 | advanced_widget = QWidget() 486 | advanced_layout = QVBoxLayout(advanced_widget) 487 | 488 | # 创建带滚动条的参数区域 489 | params_scroll_area = QScrollArea() 490 | params_scroll_area.setWidgetResizable(True) 491 | params_scroll_area.setFrameShape(QFrame.NoFrame) # 去掉边框 492 | params_scroll_area.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) # 禁用水平滚动条 493 | params_scroll_area.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) 494 | 495 | # 创建内部容器widget 496 | params_container = QWidget() 497 | params_layout = QVBoxLayout(params_container) 498 | params_layout.setSpacing(10) 499 | 500 | # 数据选择区域 501 | data_group = QGroupBox("数据集设置") 502 | data_layout = QVBoxLayout(data_group) 503 | data_layout.setSpacing(10) 504 | 505 | # 选择数据模式 506 | self.data_mode_label = QLabel("数据模式:") 507 | self.data_mode_combo = QComboBox() 508 | self.data_mode_combo.addItems(["YAML配置文件", "训练文件夹"]) 509 | self.data_mode_combo.setCurrentIndex(0) 510 | self.data_mode_combo.currentIndexChanged.connect(self.on_data_mode_changed) 511 | 512 | data_mode_layout = QHBoxLayout() 513 | data_mode_layout.addWidget(self.data_mode_label) 514 | data_mode_layout.addWidget(self.data_mode_combo) 515 | data_mode_layout.addStretch() 516 | 517 | data_layout.addLayout(data_mode_layout) 518 | 519 | # YAML配置文件选择 - 这部分已经移到核心参数中,这里只是备用 520 | self.yaml_group = QWidget() 521 | yaml_layout = QHBoxLayout(self.yaml_group) 522 | yaml_layout.setContentsMargins(0, 0, 0, 0) 523 | 524 | self.data_path_label = QLabel("数据配置文件:") 525 | yaml_layout.addWidget(self.data_path_label) 526 | 527 | data_layout.addWidget(self.yaml_group) 528 | self.yaml_group.setVisible(False) # 默认隐藏,已经在上面显示 529 | 530 | # 训练文件夹选择 531 | self.folder_group = QWidget() 532 | folder_layout = QVBoxLayout(self.folder_group) 533 | folder_layout.setContentsMargins(0, 0, 0, 0) 534 | folder_layout.setSpacing(10) 535 | 536 | # 训练文件夹 537 | train_folder_layout = QHBoxLayout() 538 | self.train_folder_label = QLabel("训练文件夹:") 539 | self.train_folder_input = QLineEdit() 540 | self.train_folder_input.setPlaceholderText("选择训练数据文件夹") 541 | self.train_folder_browse = QPushButton("浏览...") 542 | self.train_folder_browse.clicked.connect(self.browse_train_folder) 543 | 544 | train_folder_layout.addWidget(self.train_folder_label) 545 | train_folder_layout.addWidget(self.train_folder_input, 1) 546 | train_folder_layout.addWidget(self.train_folder_browse) 547 | 548 | folder_layout.addLayout(train_folder_layout) 549 | 550 | # 分类任务设置 - 改为可折叠的方式 551 | self.class_task_box = CollapsibleBox("分类数据结构") 552 | class_task_internal = QWidget() 553 | class_task_layout = QVBoxLayout(class_task_internal) 554 | 555 | # 数据集结构选择 556 | dataset_struct_layout = QHBoxLayout() 557 | self.dataset_struct_label = QLabel("结构类型:") 558 | self.dataset_struct_group = QButtonGroup(self) 559 | 560 | # 直接使用文件夹(不生成YAML) 561 | self.direct_folder_radio = QRadioButton("直接使用文件夹") 562 | self.direct_folder_radio.setChecked(True) 563 | self.direct_folder_radio.setToolTip("对于分类任务,可以直接使用文件夹路径进行训练,无需生成data.yaml") 564 | self.dataset_struct_group.addButton(self.direct_folder_radio, 1) 565 | 566 | # 预分割数据集结构 567 | self.presplit_folder_radio = QRadioButton("预分割数据集") 568 | self.presplit_folder_radio.setToolTip("数据集已包含train、val和test文件夹,每个文件夹下有类别子文件夹") 569 | self.dataset_struct_group.addButton(self.presplit_folder_radio, 2) 570 | 571 | # 单层类别文件夹结构 572 | self.single_folder_radio = QRadioButton("单层文件夹") 573 | self.single_folder_radio.setToolTip("数据集包含各个类别的文件夹,系统将自动分割为训练集和验证集") 574 | self.dataset_struct_group.addButton(self.single_folder_radio, 3) 575 | 576 | dataset_struct_layout.addWidget(self.dataset_struct_label) 577 | dataset_struct_layout.addWidget(self.direct_folder_radio) 578 | dataset_struct_layout.addWidget(self.presplit_folder_radio) 579 | dataset_struct_layout.addWidget(self.single_folder_radio) 580 | dataset_struct_layout.addStretch() 581 | 582 | class_task_layout.addLayout(dataset_struct_layout) 583 | 584 | # 连接信号 585 | self.dataset_struct_group.buttonClicked.connect(self.on_dataset_struct_changed) 586 | 587 | # 文件夹结构说明 - 改为更紧凑的显示方式 588 | self.struct_info_button = QPushButton("查看文件夹结构说明") 589 | self.struct_info_button.clicked.connect(self.show_folder_structure_info) 590 | class_task_layout.addWidget(self.struct_info_button) 591 | 592 | # 添加类设置到可折叠框 593 | self.class_task_box.add_widget(class_task_internal) 594 | self.class_task_box.setVisible(False) # 默认隐藏 595 | folder_layout.addWidget(self.class_task_box) 596 | 597 | # 非分类任务的设置 598 | self.nonclass_task_frame = QFrame() 599 | nonclass_task_layout = QHBoxLayout(self.nonclass_task_frame) 600 | nonclass_task_layout.setContentsMargins(0, 0, 0, 0) 601 | 602 | # 验证集比例 603 | self.val_split_label = QLabel("验证集比例:") 604 | self.val_split_spin = QDoubleSpinBox() 605 | self.val_split_spin.setRange(0.0, 0.5) 606 | self.val_split_spin.setSingleStep(0.05) 607 | self.val_split_spin.setValue(0.2) 608 | self.val_split_spin.setDecimals(2) 609 | 610 | # 生成YAML按钮 611 | self.generate_yaml_button = QPushButton("生成YAML配置") 612 | self.generate_yaml_button.clicked.connect(self.generate_yaml_config) 613 | 614 | nonclass_task_layout.addWidget(self.val_split_label) 615 | nonclass_task_layout.addWidget(self.val_split_spin) 616 | nonclass_task_layout.addWidget(self.generate_yaml_button) 617 | nonclass_task_layout.addStretch() 618 | 619 | folder_layout.addWidget(self.nonclass_task_frame) 620 | 621 | data_layout.addWidget(self.folder_group) 622 | 623 | # 默认显示YAML配置模式 624 | self.folder_group.setVisible(False) 625 | 626 | params_layout.addWidget(data_group) 627 | 628 | # 创建优化器组 629 | optimizer_group = self.create_optimizer_param_group() 630 | params_layout.addWidget(optimizer_group) 631 | 632 | # 创建增强组 633 | augment_group = self.create_augment_param_group() 634 | params_layout.addWidget(augment_group) 635 | 636 | # 添加其他参数组 637 | for group_name, group_params in parameters.items(): 638 | if group_name not in ['data', 'model', 'training', 'hyp', 'augment']: 639 | param_group = ParameterGroup(group_name.capitalize(), group_params) 640 | param_group.parameterSelected.connect(self.on_parameter_selected) 641 | self.param_groups[group_name] = param_group 642 | params_layout.addWidget(param_group) 643 | 644 | # 设置滚动区域的widget 645 | params_scroll_area.setWidget(params_container) 646 | advanced_layout.addWidget(params_scroll_area) 647 | 648 | # 设置高级参数区域的内容 649 | advanced_box.add_widget(advanced_widget) 650 | left_layout.addWidget(advanced_box) 651 | 652 | # 右侧参数说明区域 653 | right_widget = QWidget() 654 | right_layout = QVBoxLayout(right_widget) 655 | right_layout.setContentsMargins(10, 0, 0, 0) 656 | 657 | # 参数说明标题 658 | self.param_desc_title = QLabel("参数说明") 659 | self.param_desc_title.setStyleSheet(""" 660 | font-size: 16px; 661 | font-weight: bold; 662 | color: #333; 663 | padding: 5px; 664 | border-bottom: 1px solid #ddd; 665 | """) 666 | 667 | # 参数名称 668 | self.param_name_label = QLabel("") 669 | self.param_name_label.setStyleSheet(""" 670 | font-size: 14px; 671 | font-weight: bold; 672 | color: #444; 673 | padding: 5px; 674 | """) 675 | 676 | # 参数描述 677 | self.param_desc_text = QTextEdit() 678 | self.param_desc_text.setReadOnly(True) 679 | self.param_desc_text.setStyleSheet(""" 680 | background-color: #f8f9fa; 681 | border: 1px solid #e9ecef; 682 | border-radius: 5px; 683 | padding: 5px; 684 | font-size: 13px; 685 | """) 686 | 687 | right_layout.addWidget(self.param_desc_title) 688 | right_layout.addWidget(self.param_name_label) 689 | right_layout.addWidget(self.param_desc_text, 1) 690 | 691 | # 添加左右区域到分割器 692 | splitter.addWidget(left_widget) 693 | splitter.addWidget(right_widget) 694 | splitter.setStretchFactor(0, 3) # 左侧占比更大 695 | splitter.setStretchFactor(1, 1) # 右侧占比较小 696 | 697 | main_layout.addWidget(splitter, 1) 698 | 699 | # 训练控制按钮 700 | control_layout = QHBoxLayout() 701 | self.start_training_button = QPushButton("开始训练") 702 | self.start_training_button.setMinimumHeight(40) 703 | self.stop_training_button = QPushButton("停止训练") 704 | self.stop_training_button.setMinimumHeight(40) 705 | self.stop_training_button.setEnabled(False) 706 | 707 | control_layout.addWidget(self.start_training_button) 708 | control_layout.addWidget(self.stop_training_button) 709 | 710 | main_layout.addLayout(control_layout) 711 | 712 | # 初始状态下显示一些默认说明文字 713 | self.param_name_label.setText("参数使用说明") 714 | self.param_desc_text.setText("鼠标悬停在参数上可查看该参数的详细说明。\n\n" 715 | "点击参数后的问号按钮也可以查看详细说明。\n\n" 716 | "常见参数说明:\n" 717 | "- data_path: 数据集配置文件路径,YAML格式\n" 718 | "- batch: 训练批次大小,根据显存调整\n" 719 | "- imgsz: 输入图像大小,单位为像素\n" 720 | "- epochs: 训练总轮数\n" 721 | "- device: 训练设备,空为自动选择") 722 | 723 | # 尝试自动检测并填充数据路径 724 | self.auto_detect_data_path() 725 | 726 | def create_optimizer_param_group(self): 727 | """创建优化器参数组""" 728 | group = CollapsibleBox("优化器参数") 729 | widget = QWidget() 730 | layout = QFormLayout(widget) 731 | layout.setFieldGrowthPolicy(QFormLayout.AllNonFixedFieldsGrow) 732 | 733 | # 添加优化器选择 734 | self.optimizer_combo = QComboBox() 735 | self.optimizer_combo.addItems(["SGD", "Adam", "AdamW"]) 736 | self.optimizer_combo.setCurrentText(self.parameters['training']['optimizer']) 737 | layout.addRow("优化器:", self.optimizer_combo) 738 | 739 | # 添加动量参数 740 | self.momentum_spinbox = QDoubleSpinBox() 741 | self.momentum_spinbox.setRange(0.0, 0.999) 742 | self.momentum_spinbox.setValue(self.parameters['training']['momentum']) 743 | self.momentum_spinbox.setSingleStep(0.01) 744 | self.momentum_spinbox.setDecimals(3) 745 | layout.addRow("动量:", self.momentum_spinbox) 746 | 747 | # 添加权重衰减 748 | self.weight_decay_spinbox = QDoubleSpinBox() 749 | self.weight_decay_spinbox.setRange(0.0, 0.1) 750 | self.weight_decay_spinbox.setValue(self.parameters['training']['weight_decay']) 751 | self.weight_decay_spinbox.setSingleStep(0.0001) 752 | self.weight_decay_spinbox.setDecimals(5) 753 | layout.addRow("权重衰减:", self.weight_decay_spinbox) 754 | 755 | # 添加余弦学习率 756 | self.cos_lr_checkbox = QCheckBox() 757 | self.cos_lr_checkbox.setChecked(self.parameters['training']['cos_lr']) 758 | layout.addRow("余弦学习率调度:", self.cos_lr_checkbox) 759 | 760 | # 添加混合精度训练 761 | self.amp_checkbox = QCheckBox() 762 | self.amp_checkbox.setChecked(self.parameters['training']['amp']) 763 | layout.addRow("混合精度训练:", self.amp_checkbox) 764 | 765 | group.add_widget(widget) 766 | return group 767 | 768 | def create_augment_param_group(self): 769 | """创建数据增强参数组""" 770 | group = CollapsibleBox("数据增强参数") 771 | widget = QWidget() 772 | layout = QFormLayout(widget) 773 | layout.setFieldGrowthPolicy(QFormLayout.AllNonFixedFieldsGrow) 774 | 775 | # 添加HSV增强参数 776 | self.hsv_h_spinbox = QDoubleSpinBox() 777 | self.hsv_h_spinbox.setRange(0.0, 1.0) 778 | self.hsv_h_spinbox.setValue(self.parameters['hyp']['hsv_h']) 779 | self.hsv_h_spinbox.setSingleStep(0.01) 780 | self.hsv_h_spinbox.setDecimals(3) 781 | layout.addRow("HSV色调增强:", self.hsv_h_spinbox) 782 | 783 | self.hsv_s_spinbox = QDoubleSpinBox() 784 | self.hsv_s_spinbox.setRange(0.0, 1.0) 785 | self.hsv_s_spinbox.setValue(self.parameters['hyp']['hsv_s']) 786 | self.hsv_s_spinbox.setSingleStep(0.01) 787 | self.hsv_s_spinbox.setDecimals(3) 788 | layout.addRow("HSV饱和度增强:", self.hsv_s_spinbox) 789 | 790 | self.hsv_v_spinbox = QDoubleSpinBox() 791 | self.hsv_v_spinbox.setRange(0.0, 1.0) 792 | self.hsv_v_spinbox.setValue(self.parameters['hyp']['hsv_v']) 793 | self.hsv_v_spinbox.setSingleStep(0.01) 794 | self.hsv_v_spinbox.setDecimals(3) 795 | layout.addRow("HSV亮度增强:", self.hsv_v_spinbox) 796 | 797 | # 添加几何增强参数 798 | self.fliplr_spinbox = QDoubleSpinBox() 799 | self.fliplr_spinbox.setRange(0.0, 1.0) 800 | self.fliplr_spinbox.setValue(self.parameters['hyp']['fliplr']) 801 | self.fliplr_spinbox.setSingleStep(0.01) 802 | self.fliplr_spinbox.setDecimals(2) 803 | layout.addRow("水平翻转概率:", self.fliplr_spinbox) 804 | 805 | self.mosaic_spinbox = QDoubleSpinBox() 806 | self.mosaic_spinbox.setRange(0.0, 1.0) 807 | self.mosaic_spinbox.setValue(self.parameters['hyp']['mosaic']) 808 | self.mosaic_spinbox.setSingleStep(0.01) 809 | self.mosaic_spinbox.setDecimals(2) 810 | layout.addRow("马赛克增强概率:", self.mosaic_spinbox) 811 | 812 | self.mixup_spinbox = QDoubleSpinBox() 813 | self.mixup_spinbox.setRange(0.0, 1.0) 814 | self.mixup_spinbox.setValue(self.parameters['hyp']['mixup']) 815 | self.mixup_spinbox.setSingleStep(0.01) 816 | self.mixup_spinbox.setDecimals(2) 817 | layout.addRow("Mixup增强概率:", self.mixup_spinbox) 818 | 819 | group.add_widget(widget) 820 | return group 821 | 822 | def auto_detect_data_path(self): 823 | """自动检测并填充数据路径""" 824 | # 首先检查当前目录 825 | current_dir = os.getcwd() 826 | 827 | # 检查各种常见的数据目录名 828 | common_dirs = ["data", "datasets", "dataset", "yolo_data", "images", "annotations"] 829 | for dir_name in common_dirs: 830 | dir_path = os.path.join(current_dir, dir_name) 831 | if os.path.isdir(dir_path): 832 | # 搜索data.yaml文件 833 | yaml_path = os.path.join(dir_path, "data.yaml") 834 | if os.path.exists(yaml_path): 835 | self.data_path_input.setText(yaml_path) 836 | return 837 | 838 | # 递归搜索前三层目录 839 | def search_yaml(directory, depth=0): 840 | if depth > 2: # 限制搜索深度 841 | return None 842 | 843 | for root, dirs, files in os.walk(directory): 844 | for file in files: 845 | if file == "data.yaml": 846 | return os.path.join(root, file) 847 | 848 | # 递归搜索一层 849 | for dir_name in dirs: 850 | result = search_yaml(os.path.join(root, dir_name), depth + 1) 851 | if result: 852 | return result 853 | 854 | return None 855 | 856 | # 执行搜索 857 | yaml_path = search_yaml(current_dir) 858 | if yaml_path: 859 | self.data_path_input.setText(yaml_path) 860 | 861 | def on_parameter_selected(self, name, description): 862 | """参数被选择时更新右侧说明区域""" 863 | if not description: 864 | return 865 | 866 | self.param_name_label.setText(f"参数: {name}") 867 | self.param_desc_text.setText(description) 868 | 869 | def on_task_changed(self, index): 870 | """任务类型切换""" 871 | task_text = self.task_combo.currentText() 872 | 873 | # 检查是否为分类任务 874 | is_classification = "分类" in task_text or "classify" in task_text.lower() 875 | 876 | # 更新UI显示 877 | self.class_task_box.setVisible(is_classification) 878 | self.nonclass_task_frame.setVisible(not is_classification) 879 | 880 | # 如果是分类任务,更新数据模式提示 881 | if is_classification: 882 | self.train_folder_label.setText("分类数据文件夹:") 883 | self.train_folder_input.setPlaceholderText("选择包含各个类别文件夹的目录") 884 | else: 885 | self.train_folder_label.setText("训练文件夹:") 886 | self.train_folder_input.setPlaceholderText("选择包含images和labels的目录") 887 | 888 | def on_data_mode_changed(self, index): 889 | """数据模式切换""" 890 | if index == 0: # YAML配置文件 891 | self.yaml_group.setVisible(True) 892 | self.folder_group.setVisible(False) 893 | else: # 训练文件夹 894 | self.yaml_group.setVisible(False) 895 | self.folder_group.setVisible(True) 896 | 897 | # 如果是分类任务,显示分类任务相关控件 898 | task_text = self.task_combo.currentText() 899 | is_classification = "分类" in task_text or "classify" in task_text.lower() 900 | self.class_task_box.setVisible(is_classification) 901 | self.nonclass_task_frame.setVisible(not is_classification) 902 | 903 | def show_folder_structure_info(self): 904 | """显示文件夹结构说明对话框""" 905 | struct_type = 1 906 | if self.presplit_folder_radio.isChecked(): 907 | struct_type = 2 908 | elif self.single_folder_radio.isChecked(): 909 | struct_type = 3 910 | 911 | info_text = self.get_folder_structure_info(struct_type) 912 | 913 | from PySide6.QtWidgets import QDialog, QVBoxLayout, QTextEdit, QDialogButtonBox 914 | 915 | dialog = QDialog(self) 916 | dialog.setWindowTitle("文件夹结构说明") 917 | dialog.resize(400, 300) 918 | 919 | layout = QVBoxLayout(dialog) 920 | 921 | text_edit = QTextEdit() 922 | text_edit.setReadOnly(True) 923 | text_edit.setPlainText(info_text) 924 | text_edit.setStyleSheet("font-family: monospace;") 925 | layout.addWidget(text_edit) 926 | 927 | button_box = QDialogButtonBox(QDialogButtonBox.Close) 928 | button_box.rejected.connect(dialog.reject) 929 | layout.addWidget(button_box) 930 | 931 | dialog.exec() 932 | 933 | def get_folder_structure_info(self, struct_type): 934 | """获取文件夹结构说明文本""" 935 | if struct_type == 1: # 直接使用文件夹 936 | return "直接使用文件夹结构示例:\n" \ 937 | "└── dataset/\n" \ 938 | " ├── class1/\n" \ 939 | " │ ├── img1.jpg\n" \ 940 | " │ └── ...\n" \ 941 | " └── class2/\n" \ 942 | " ├── img2.jpg\n" \ 943 | " └── ...\n\n" \ 944 | "系统将直接使用该文件夹,无需生成YAML配置。" 945 | elif struct_type == 2: # 预分割的数据集 946 | return "预分割数据集结构示例:\n" \ 947 | "└── dataset/\n" \ 948 | " ├── train/\n" \ 949 | " │ ├── class1/\n" \ 950 | " │ └── class2/\n" \ 951 | " ├── val/\n" \ 952 | " │ ├── class1/\n" \ 953 | " │ └── class2/\n" \ 954 | " └── test/ (可选)\n" \ 955 | " ├── class1/\n" \ 956 | " └── class2/\n\n" \ 957 | "系统将根据这个结构生成data.yaml配置文件。" 958 | elif struct_type == 3: # 单层类别文件夹 959 | return "单层类别文件夹结构示例:\n" \ 960 | "└── dataset/\n" \ 961 | " ├── class1/\n" \ 962 | " │ ├── img1.jpg\n" \ 963 | " │ └── ...\n" \ 964 | " └── class2/\n" \ 965 | " ├── img2.jpg\n" \ 966 | " └── ...\n\n" \ 967 | "系统将自动分割为训练集和验证集,并生成相应的目录和data.yaml配置文件。" 968 | 969 | def on_dataset_struct_changed(self, button): 970 | """数据集结构选择改变时更新UI""" 971 | # 结构类型在点击"查看文件夹结构说明"时使用 972 | # 不需要更新显示,因为已经改为对话框方式 973 | pass 974 | 975 | def update_folder_structure_info(self, struct_type): 976 | """保留这个方法以兼容旧代码,现在不需要实时更新UI""" 977 | pass 978 | 979 | def browse_data_file(self): 980 | """浏览选择数据配置文件""" 981 | file_path, _ = QFileDialog.getOpenFileName(self, "选择数据配置文件", "", "YAML文件 (*.yaml)") 982 | if file_path: 983 | self.data_path_input.setText(file_path) 984 | 985 | def browse_train_folder(self): 986 | """浏览选择训练文件夹""" 987 | folder_path = QFileDialog.getExistingDirectory(self, "选择训练数据文件夹") 988 | if folder_path: 989 | self.train_folder_input.setText(folder_path) 990 | # 自动检测文件夹结构类型 991 | self.detect_folder_structure(folder_path) 992 | 993 | def detect_folder_structure(self, folder_path): 994 | """检测文件夹结构类型,自动选择合适的选项""" 995 | if not os.path.isdir(folder_path): 996 | return 997 | 998 | # 检查是否为预分割结构(有train/val子文件夹) 999 | train_dir = os.path.join(folder_path, 'train') 1000 | val_dir = os.path.join(folder_path, 'val') 1001 | 1002 | if os.path.isdir(train_dir) and os.path.isdir(val_dir): 1003 | # 检查train目录下是否有类别子文件夹 1004 | has_class_dirs = False 1005 | for item in os.listdir(train_dir): 1006 | if os.path.isdir(os.path.join(train_dir, item)) and not item.startswith('.'): 1007 | has_class_dirs = True 1008 | break 1009 | 1010 | if has_class_dirs: 1011 | # 这是预分割的分类数据集结构 1012 | self.presplit_folder_radio.setChecked(True) 1013 | return 1014 | 1015 | # 检查是否为单层类别文件夹结构 1016 | has_potential_class_dirs = False 1017 | has_images = False 1018 | 1019 | for item in os.listdir(folder_path): 1020 | item_path = os.path.join(folder_path, item) 1021 | if os.path.isdir(item_path) and not item.startswith('.'): 1022 | has_potential_class_dirs = True 1023 | 1024 | # 检查文件夹中是否有图像文件 1025 | for file in os.listdir(item_path): 1026 | if file.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')): 1027 | has_images = True 1028 | break 1029 | 1030 | if has_images: 1031 | break 1032 | 1033 | if has_potential_class_dirs and has_images: 1034 | # 这可能是单层类别文件夹结构 1035 | self.single_folder_radio.setChecked(True) 1036 | return 1037 | 1038 | # 默认使用直接文件夹模式 1039 | self.direct_folder_radio.setChecked(True) 1040 | 1041 | def generate_yaml_config(self): 1042 | """从训练文件夹生成YAML配置""" 1043 | train_folder = self.train_folder_input.text() 1044 | if not train_folder or not os.path.isdir(train_folder): 1045 | QMessageBox.warning(self, "错误", "请选择有效的训练文件夹") 1046 | return 1047 | 1048 | try: 1049 | # 创建YAML配置 1050 | import yaml 1051 | import os.path as osp 1052 | import shutil 1053 | 1054 | # 检查是否为分类任务 1055 | task_text = self.task_combo.currentText() 1056 | is_classification = "分类" in task_text or "classify" in task_text.lower() 1057 | 1058 | # 创建配置 1059 | config = {} 1060 | 1061 | if is_classification: 1062 | # 分类任务 1063 | # 获取当前选择的数据集结构类型 1064 | if self.direct_folder_radio.isChecked(): 1065 | # 直接使用文件夹,不生成YAML 1066 | QMessageBox.information(self, "直接使用文件夹", 1067 | "已选择直接使用文件夹模式,无需生成YAML配置。\n" 1068 | "训练时将直接使用该文件夹路径。") 1069 | return 1070 | 1071 | elif self.presplit_folder_radio.isChecked(): 1072 | # 已分割的数据集结构 1073 | config['path'] = train_folder 1074 | config['train'] = 'train' 1075 | config['val'] = 'val' 1076 | if os.path.isdir(os.path.join(train_folder, 'test')): 1077 | config['test'] = 'test' 1078 | 1079 | # 检测类别 1080 | train_dir = os.path.join(train_folder, 'train') 1081 | classes = [] 1082 | for item in os.listdir(train_dir): 1083 | if os.path.isdir(os.path.join(train_dir, item)) and not item.startswith('.'): 1084 | classes.append(item) 1085 | 1086 | elif self.single_folder_radio.isChecked(): 1087 | # 单层类别文件夹结构,需要自动分割 1088 | # 首先获取所有类别 1089 | classes = [] 1090 | for item in os.listdir(train_folder): 1091 | if os.path.isdir(os.path.join(train_folder, item)) and not item.startswith('.'): 1092 | classes.append(item) 1093 | 1094 | # 创建train和val目录 1095 | train_dir = os.path.join(train_folder, 'train') 1096 | val_dir = os.path.join(train_folder, 'val') 1097 | 1098 | # 检查目录是否已存在 1099 | if os.path.exists(train_dir) or os.path.exists(val_dir): 1100 | reply = QMessageBox.question( 1101 | self, 1102 | "目录已存在", 1103 | "train或val目录已存在,是否覆盖?", 1104 | QMessageBox.Yes | QMessageBox.No 1105 | ) 1106 | if reply == QMessageBox.No: 1107 | return 1108 | 1109 | # 删除已存在的目录 1110 | if os.path.exists(train_dir): 1111 | shutil.rmtree(train_dir) 1112 | if os.path.exists(val_dir): 1113 | shutil.rmtree(val_dir) 1114 | 1115 | # 创建目录 1116 | os.makedirs(train_dir, exist_ok=True) 1117 | os.makedirs(val_dir, exist_ok=True) 1118 | 1119 | # 为每个类别创建子目录 1120 | for cls in classes: 1121 | os.makedirs(os.path.join(train_dir, cls), exist_ok=True) 1122 | os.makedirs(os.path.join(val_dir, cls), exist_ok=True) 1123 | 1124 | # 获取验证集比例 1125 | val_ratio = self.val_split_spin.value() 1126 | 1127 | # 分割数据 1128 | import random 1129 | random.seed(42) # 固定随机种子以确保可重复 1130 | 1131 | for cls in classes: 1132 | cls_dir = os.path.join(train_folder, cls) 1133 | images = [f for f in os.listdir(cls_dir) 1134 | if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))] 1135 | 1136 | # 随机打乱 1137 | random.shuffle(images) 1138 | 1139 | # 计算分割点 1140 | split_idx = int(len(images) * (1 - val_ratio)) 1141 | train_images = images[:split_idx] 1142 | val_images = images[split_idx:] 1143 | 1144 | # 复制到train和val目录 1145 | for img in train_images: 1146 | shutil.copy2( 1147 | os.path.join(cls_dir, img), 1148 | os.path.join(train_dir, cls, img) 1149 | ) 1150 | 1151 | for img in val_images: 1152 | shutil.copy2( 1153 | os.path.join(cls_dir, img), 1154 | os.path.join(val_dir, cls, img) 1155 | ) 1156 | 1157 | # 更新配置 1158 | config['path'] = train_folder 1159 | config['train'] = 'train' 1160 | config['val'] = 'val' 1161 | 1162 | # 设置类别信息 1163 | config['nc'] = len(classes) 1164 | config['names'] = {i: name for i, name in enumerate(classes)} 1165 | 1166 | else: 1167 | # 检测/分割任务 1168 | config['path'] = train_folder 1169 | config['train'] = 'images/train' 1170 | config['val'] = 'images/val' 1171 | config['test'] = 'images/test' 1172 | 1173 | # 尝试自动检测类别 1174 | if os.path.exists(os.path.join(train_folder, 'labels')): 1175 | classes = set() 1176 | label_dir = os.path.join(train_folder, 'labels') 1177 | 1178 | for file in os.listdir(label_dir): 1179 | if file.endswith('.txt'): 1180 | with open(os.path.join(label_dir, file), 'r') as f: 1181 | for line in f: 1182 | parts = line.strip().split() 1183 | if parts: 1184 | class_id = int(parts[0]) 1185 | classes.add(class_id) 1186 | 1187 | config['nc'] = max(classes) + 1 if classes else 0 1188 | config['names'] = {i: f'class{i}' for i in range(config['nc'])} 1189 | else: 1190 | config['nc'] = 0 1191 | config['names'] = {} 1192 | 1193 | # 保存YAML文件 1194 | yaml_path = os.path.join(train_folder, 'data.yaml') 1195 | with open(yaml_path, 'w', encoding='utf-8') as f: 1196 | yaml.dump(config, f, allow_unicode=True, default_flow_style=False) 1197 | 1198 | # 更新UI 1199 | self.data_path_input.setText(yaml_path) 1200 | self.data_mode_combo.setCurrentIndex(0) # 切换到YAML模式 1201 | 1202 | QMessageBox.information(self, "成功", f"已生成YAML配置文件:\n{yaml_path}") 1203 | 1204 | except Exception as e: 1205 | QMessageBox.critical(self, "错误", f"生成YAML配置失败:\n{str(e)}") 1206 | 1207 | def browse_model_file(self): 1208 | """浏览选择模型文件""" 1209 | file_path, _ = QFileDialog.getOpenFileName(self, "选择模型文件", "", "PyTorch模型 (*.pt);;All Files (*)") 1210 | if file_path: 1211 | if self.model_combo.findText(file_path) == -1: 1212 | self.model_combo.addItem(file_path) 1213 | self.model_combo.setCurrentText(file_path) 1214 | 1215 | def view_data_config(self): 1216 | """查看/编辑数据配置""" 1217 | data_path = self.data_path_input.text() 1218 | if not data_path or not os.path.exists(data_path): 1219 | QMessageBox.warning(self, "错误", "请先选择有效的数据配置文件") 1220 | return 1221 | 1222 | # 解析YAML文件 1223 | data_config = parse_data_yaml(data_path) 1224 | if not data_config: 1225 | QMessageBox.warning(self, "错误", "无法解析数据配置文件") 1226 | return 1227 | 1228 | # 显示数据配置对话框 1229 | from PySide6.QtWidgets import QDialog, QDialogButtonBox, QTextEdit, QVBoxLayout 1230 | 1231 | dialog = QDialog(self) 1232 | dialog.setWindowTitle("数据配置") 1233 | dialog.resize(600, 400) 1234 | 1235 | layout = QVBoxLayout(dialog) 1236 | 1237 | # 显示YAML内容 1238 | text_edit = QTextEdit() 1239 | text_edit.setPlainText(yaml.dump(data_config, allow_unicode=True)) 1240 | layout.addWidget(text_edit) 1241 | 1242 | # 对话框按钮 1243 | button_box = QDialogButtonBox(QDialogButtonBox.Save | QDialogButtonBox.Cancel) 1244 | button_box.accepted.connect(lambda: self.save_data_config(data_path, text_edit.toPlainText(), dialog)) 1245 | button_box.rejected.connect(dialog.reject) 1246 | layout.addWidget(button_box) 1247 | 1248 | dialog.exec() 1249 | 1250 | def save_data_config(self, file_path, yaml_content, dialog): 1251 | """保存数据配置""" 1252 | try: 1253 | data = yaml.safe_load(yaml_content) 1254 | with open(file_path, 'w', encoding='utf-8') as file: 1255 | yaml.dump(data, file, allow_unicode=True) 1256 | dialog.accept() 1257 | except Exception as e: 1258 | QMessageBox.critical(dialog, "保存错误", f"保存YAML时出错:{str(e)}") 1259 | 1260 | def get_training_parameters(self): 1261 | """获取所有训练参数""" 1262 | params = {} 1263 | 1264 | # 获取任务类型 1265 | task_text = self.task_combo.currentText() 1266 | if "检测" in task_text or "detect" in task_text.lower(): 1267 | params['task'] = 'detect' 1268 | elif "分割" in task_text or "segment" in task_text.lower(): 1269 | params['task'] = 'segment' 1270 | elif "分类" in task_text or "classify" in task_text.lower(): 1271 | params['task'] = 'classify' 1272 | params['is_classification'] = True 1273 | elif "姿态" in task_text or "pose" in task_text.lower(): 1274 | params['task'] = 'pose' 1275 | 1276 | # 获取数据路径 1277 | if self.data_mode_combo.currentIndex() == 0: # YAML配置文件模式 1278 | params['data_path'] = self.data_path_input.text() 1279 | else: # 训练文件夹模式 1280 | train_folder = self.train_folder_input.text() 1281 | 1282 | # 分类任务 1283 | if params.get('is_classification', False): 1284 | # 根据选择的数据集结构类型处理 1285 | if self.direct_folder_radio.isChecked(): 1286 | # 直接使用文件夹 1287 | params['train_folder'] = train_folder 1288 | params['direct_folder_mode'] = True 1289 | elif self.presplit_folder_radio.isChecked() or self.single_folder_radio.isChecked(): 1290 | # 已分割或自动分割的数据集,生成或使用YAML 1291 | yaml_path = os.path.join(train_folder, 'data.yaml') 1292 | 1293 | if not os.path.exists(yaml_path): 1294 | reply = QMessageBox.question( 1295 | self, 1296 | "配置文件不存在", 1297 | "数据集YAML配置文件不存在,是否立即生成?", 1298 | QMessageBox.Yes | QMessageBox.No 1299 | ) 1300 | 1301 | if reply == QMessageBox.Yes: 1302 | self.generate_yaml_config() 1303 | params['data_path'] = yaml_path 1304 | else: 1305 | return None # 用户取消训练 1306 | else: 1307 | params['data_path'] = yaml_path 1308 | else: 1309 | # 检测/分割任务 1310 | yaml_path = os.path.join(train_folder, 'data.yaml') 1311 | 1312 | if not os.path.exists(yaml_path): 1313 | reply = QMessageBox.question( 1314 | self, 1315 | "配置文件不存在", 1316 | "数据集YAML配置文件不存在,是否立即生成?", 1317 | QMessageBox.Yes | QMessageBox.No 1318 | ) 1319 | 1320 | if reply == QMessageBox.Yes: 1321 | self.generate_yaml_config() 1322 | params['data_path'] = yaml_path 1323 | else: 1324 | return None # 用户取消训练 1325 | else: 1326 | params['data_path'] = yaml_path 1327 | 1328 | # 获取核心参数 1329 | params['model'] = self.model_combo.currentText() 1330 | params['batch'] = self.batch_spinbox.value() 1331 | params['imgsz'] = self.imgsz_spinbox.value() 1332 | params['epochs'] = self.epochs_spinbox.value() 1333 | params['lr0'] = self.lr0_spinbox.value() 1334 | 1335 | # 如果是分类任务,确保使用分类模型 1336 | if params.get('is_classification', False) and not params['model'].endswith('-cls.pt'): 1337 | model_name = params['model'].split('.')[0] 1338 | if not model_name.endswith('-cls'): 1339 | model_name += '-cls' 1340 | params['model'] = f"{model_name}.pt" 1341 | 1342 | # 提示用户使用分类模型 1343 | QMessageBox.information( 1344 | self, 1345 | "模型自动调整", 1346 | f"检测到分类任务,已自动调整为分类模型: {params['model']}" 1347 | ) 1348 | 1349 | # 获取优化器参数 1350 | params['optimizer'] = self.optimizer_combo.currentText() 1351 | params['momentum'] = self.momentum_spinbox.value() 1352 | params['weight_decay'] = self.weight_decay_spinbox.value() 1353 | params['cos_lr'] = self.cos_lr_checkbox.isChecked() 1354 | params['amp'] = self.amp_checkbox.isChecked() 1355 | 1356 | # 获取数据增强参数 1357 | params['hsv_h'] = self.hsv_h_spinbox.value() 1358 | params['hsv_s'] = self.hsv_s_spinbox.value() 1359 | params['hsv_v'] = self.hsv_v_spinbox.value() 1360 | params['fliplr'] = self.fliplr_spinbox.value() 1361 | params['mosaic'] = self.mosaic_spinbox.value() 1362 | params['mixup'] = self.mixup_spinbox.value() 1363 | 1364 | # 获取CUDA设置 1365 | if self.use_cuda_checkbox.isChecked(): 1366 | params['device'] = self.device_combo.currentText() 1367 | else: 1368 | params['device'] = 'cpu' 1369 | 1370 | # 获取各个分组的参数(仅获取已更改的值) 1371 | for group_name, param_group in self.param_groups.items(): 1372 | group_params = param_group.get_values() 1373 | for k, v in group_params.items(): 1374 | # 检查是否与默认值不同 1375 | default_value = self.parameters.get(group_name, {}).get(k) 1376 | if v != default_value: 1377 | params[k] = v 1378 | 1379 | return params 1380 | 1381 | def update_cuda_status(self, available): 1382 | """更新CUDA状态显示""" 1383 | if available: 1384 | self.cuda_available_label.setText("CUDA可用: 是") 1385 | self.use_cuda_checkbox.setEnabled(True) 1386 | self.use_cuda_checkbox.setChecked(True) 1387 | else: 1388 | self.cuda_available_label.setText("CUDA可用: 否") 1389 | self.use_cuda_checkbox.setEnabled(False) 1390 | self.use_cuda_checkbox.setChecked(False) 1391 | 1392 | def set_training_mode(self, training): 1393 | """设置界面的训练/非训练状态""" 1394 | self.start_training_button.setEnabled(not training) 1395 | self.stop_training_button.setEnabled(training) 1396 | 1397 | # 禁用/启用参数编辑 1398 | self.task_combo.setEnabled(not training) 1399 | self.data_mode_combo.setEnabled(not training) 1400 | self.data_path_input.setEnabled(not training) 1401 | self.data_path_browse.setEnabled(not training) 1402 | self.data_view_button.setEnabled(not training) 1403 | self.train_folder_input.setEnabled(not training) 1404 | self.train_folder_browse.setEnabled(not training) 1405 | self.direct_folder_radio.setEnabled(not training) 1406 | self.presplit_folder_radio.setEnabled(not training) 1407 | self.single_folder_radio.setEnabled(not training) 1408 | self.val_split_spin.setEnabled(not training) 1409 | self.generate_yaml_button.setEnabled(not training) 1410 | self.model_combo.setEnabled(not training) 1411 | self.model_browse.setEnabled(not training) 1412 | self.use_cuda_checkbox.setEnabled(not training) 1413 | self.device_combo.setEnabled(not training) 1414 | 1415 | # 禁用/启用核心参数 1416 | self.batch_spinbox.setEnabled(not training) 1417 | self.imgsz_spinbox.setEnabled(not training) 1418 | self.epochs_spinbox.setEnabled(not training) 1419 | self.lr0_spinbox.setEnabled(not training) 1420 | 1421 | # 禁用/启用优化器参数 1422 | self.optimizer_combo.setEnabled(not training) 1423 | self.momentum_spinbox.setEnabled(not training) 1424 | self.weight_decay_spinbox.setEnabled(not training) 1425 | self.cos_lr_checkbox.setEnabled(not training) 1426 | self.amp_checkbox.setEnabled(not training) 1427 | 1428 | # 禁用/启用所有参数分组 1429 | for group_name, param_group in self.param_groups.items(): 1430 | for name, widget in param_group.param_widgets.items(): 1431 | widget.setEnabled(not training) 1432 | 1433 | 1434 | class ProgressTab(QWidget): 1435 | """训练进度选项卡""" 1436 | 1437 | def __init__(self, parent=None): 1438 | super().__init__(parent) 1439 | 1440 | # 主布局 1441 | main_layout = QVBoxLayout(self) 1442 | main_layout.setSpacing(15) 1443 | 1444 | # 进度显示区域 1445 | progress_group = QGroupBox("训练进度") 1446 | progress_layout = QVBoxLayout(progress_group) 1447 | progress_layout.setSpacing(10) 1448 | 1449 | # 概览信息 - 使用卡片样式 1450 | from PySide6.QtWidgets import QFrame 1451 | overview_frame = QFrame() 1452 | overview_frame.setStyleSheet(""" 1453 | QFrame { 1454 | background-color: #f8f9fa; 1455 | border-radius: 8px; 1456 | border: 1px solid #e9ecef; 1457 | } 1458 | QLabel { 1459 | font-size: 14px; 1460 | color: #212529; 1461 | } 1462 | """) 1463 | overview_layout = QHBoxLayout(overview_frame) 1464 | 1465 | self.current_epoch_label = QLabel("当前轮次: 0/0") 1466 | self.current_epoch_label.setStyleSheet("font-weight: bold;") 1467 | 1468 | self.elapsed_time_label = QLabel("已用时间: 00:00:00") 1469 | self.elapsed_time_label.setStyleSheet("color: #495057;") 1470 | 1471 | self.eta_label = QLabel("预计剩余: 00:00:00") 1472 | self.eta_label.setStyleSheet("color: #0d6efd;") 1473 | 1474 | overview_layout.addWidget(self.current_epoch_label) 1475 | overview_layout.addWidget(self.elapsed_time_label) 1476 | overview_layout.addWidget(self.eta_label) 1477 | 1478 | progress_layout.addWidget(overview_frame) 1479 | 1480 | # 进度条 - 美化 1481 | self.progress_bar = QProgressBar() 1482 | self.progress_bar.setRange(0, 100) 1483 | self.progress_bar.setValue(0) 1484 | self.progress_bar.setStyleSheet(""" 1485 | QProgressBar { 1486 | border: none; 1487 | border-radius: 5px; 1488 | background-color: #e9ecef; 1489 | text-align: center; 1490 | height: 25px; 1491 | font-weight: bold; 1492 | } 1493 | 1494 | QProgressBar::chunk { 1495 | background-color: qlineargradient(x1:0, y1:0, x2:1, y2:0, stop:0 #0d6efd, stop:1 #0dcaf0); 1496 | border-radius: 5px; 1497 | } 1498 | """) 1499 | progress_layout.addWidget(self.progress_bar) 1500 | 1501 | # 指标显示 - 改为卡片式布局 1502 | metrics_frame = QFrame() 1503 | metrics_frame.setStyleSheet(""" 1504 | QFrame { 1505 | background-color: #f8f9fa; 1506 | border-radius: 8px; 1507 | border: 1px solid #e9ecef; 1508 | } 1509 | QLabel { 1510 | font-size: 14px; 1511 | } 1512 | """) 1513 | metrics_layout = QGridLayout(metrics_frame) 1514 | metrics_layout.setContentsMargins(15, 10, 15, 10) 1515 | metrics_layout.setSpacing(10) 1516 | 1517 | metrics_title = QLabel("性能指标") 1518 | metrics_title.setStyleSheet("font-weight: bold; font-size: 16px; color: #212529;") 1519 | metrics_layout.addWidget(metrics_title, 0, 0, 1, 3) 1520 | 1521 | # 指标标签 1522 | self.mAP_label = QLabel("mAP50-95: -") 1523 | self.mAP_label.setStyleSheet("color: #0d6efd;") 1524 | 1525 | self.mAP50_label = QLabel("mAP50: -") 1526 | self.mAP50_label.setStyleSheet("color: #20c997;") 1527 | 1528 | self.precision_label = QLabel("Precision: -") 1529 | self.precision_label.setStyleSheet("color: #fd7e14;") 1530 | 1531 | self.recall_label = QLabel("Recall: -") 1532 | self.recall_label.setStyleSheet("color: #6f42c1;") 1533 | 1534 | metrics_layout.addWidget(self.mAP_label, 1, 0) 1535 | metrics_layout.addWidget(self.mAP50_label, 1, 1) 1536 | metrics_layout.addWidget(self.precision_label, 2, 0) 1537 | metrics_layout.addWidget(self.recall_label, 2, 1) 1538 | 1539 | progress_layout.addWidget(metrics_frame) 1540 | 1541 | main_layout.addWidget(progress_group) 1542 | 1543 | # 日志显示区域 1544 | log_group = QGroupBox("训练日志") 1545 | log_layout = QVBoxLayout(log_group) 1546 | 1547 | self.log_text = QTextEdit() 1548 | self.log_text.setReadOnly(True) 1549 | self.log_text.setStyleSheet(""" 1550 | QTextEdit { 1551 | background-color: #212529; 1552 | color: #f8f9fa; 1553 | border: none; 1554 | border-radius: 5px; 1555 | font-family: Consolas, Monospace; 1556 | font-size: 12px; 1557 | padding: 5px; 1558 | } 1559 | """) 1560 | log_layout.addWidget(self.log_text) 1561 | 1562 | main_layout.addWidget(log_group, 1) # 添加拉伸因子 1563 | 1564 | def update_progress(self, progress_info): 1565 | """更新进度信息""" 1566 | # 更新输出日志 1567 | if 'output_line' in progress_info: 1568 | self.log_text.append(progress_info['output_line']) 1569 | # 滚动到底部 1570 | self.log_text.verticalScrollBar().setValue( 1571 | self.log_text.verticalScrollBar().maximum() 1572 | ) 1573 | 1574 | # 更新进度信息 1575 | if 'current_epoch' in progress_info and 'total_epochs' in progress_info: 1576 | self.current_epoch_label.setText( 1577 | f"当前轮次: {progress_info['current_epoch']}/{progress_info['total_epochs']}" 1578 | ) 1579 | 1580 | # 更新时间信息 1581 | if 'elapsed_time' in progress_info: 1582 | self.elapsed_time_label.setText(f"已用时间: {progress_info['elapsed_time']}") 1583 | 1584 | if 'eta' in progress_info: 1585 | self.eta_label.setText(f"预计剩余: {progress_info['eta']}") 1586 | 1587 | # 更新进度条 1588 | if 'progress' in progress_info: 1589 | self.progress_bar.setValue(int(progress_info['progress'])) 1590 | 1591 | # 更新指标 1592 | if 'metrics' in progress_info: 1593 | metrics = progress_info['metrics'] 1594 | if 'mAP50-95' in metrics: 1595 | self.mAP_label.setText(f"mAP50-95: {metrics['mAP50-95']:.4f}") 1596 | if 'mAP50' in metrics: 1597 | self.mAP50_label.setText(f"mAP50: {metrics['mAP50']:.4f}") 1598 | if 'precision' in metrics: 1599 | self.precision_label.setText(f"Precision: {metrics['precision']:.4f}") 1600 | if 'recall' in metrics: 1601 | self.recall_label.setText(f"Recall: {metrics['recall']:.4f}") 1602 | 1603 | 1604 | class EnvironmentTab(QWidget): 1605 | """环境信息选项卡""" 1606 | 1607 | def __init__(self, parent=None): 1608 | super().__init__(parent) 1609 | 1610 | # 主布局 1611 | main_layout = QVBoxLayout(self) 1612 | 1613 | # 系统信息 1614 | system_group = QGroupBox("系统信息") 1615 | system_layout = QGridLayout(system_group) 1616 | 1617 | self.os_label = QLabel("操作系统: 未检测") 1618 | self.python_version_label = QLabel("Python版本: 未检测") 1619 | 1620 | system_layout.addWidget(self.os_label, 0, 0) 1621 | system_layout.addWidget(self.python_version_label, 0, 1) 1622 | 1623 | main_layout.addWidget(system_group) 1624 | 1625 | # YOLOv8信息 1626 | yolo_group = QGroupBox("YOLOv8") 1627 | yolo_layout = QGridLayout(yolo_group) 1628 | 1629 | self.yolo_installed_label = QLabel("安装状态: 未检测") 1630 | self.yolo_version_label = QLabel("版本: 未检测") 1631 | self.install_yolo_button = QPushButton("安装YOLOv8") 1632 | 1633 | yolo_layout.addWidget(self.yolo_installed_label, 0, 0) 1634 | yolo_layout.addWidget(self.yolo_version_label, 0, 1) 1635 | yolo_layout.addWidget(self.install_yolo_button, 0, 2) 1636 | 1637 | main_layout.addWidget(yolo_group) 1638 | 1639 | # CUDA信息 1640 | cuda_group = QGroupBox("CUDA") 1641 | cuda_layout = QGridLayout(cuda_group) 1642 | 1643 | self.cuda_available_label = QLabel("CUDA可用: 未检测") 1644 | self.cuda_version_label = QLabel("CUDA版本: 未检测") 1645 | self.torch_version_label = QLabel("PyTorch版本: 未检测") 1646 | 1647 | cuda_layout.addWidget(self.cuda_available_label, 0, 0) 1648 | cuda_layout.addWidget(self.cuda_version_label, 0, 1) 1649 | cuda_layout.addWidget(self.torch_version_label, 1, 0, 1, 2) 1650 | 1651 | main_layout.addWidget(cuda_group) 1652 | 1653 | # GPU信息 1654 | gpu_group = QGroupBox("GPU信息") 1655 | gpu_layout = QVBoxLayout(gpu_group) 1656 | 1657 | self.gpu_info_text = QTextEdit() 1658 | self.gpu_info_text.setReadOnly(True) 1659 | gpu_layout.addWidget(self.gpu_info_text) 1660 | 1661 | main_layout.addWidget(gpu_group, 1) # 添加拉伸因子 1662 | 1663 | # 底部工具栏 1664 | tools_layout = QHBoxLayout() 1665 | 1666 | self.refresh_button = QPushButton("刷新") 1667 | tools_layout.addWidget(self.refresh_button) 1668 | 1669 | main_layout.addLayout(tools_layout) 1670 | 1671 | def update_environment_info(self, status): 1672 | """更新环境信息""" 1673 | # 系统信息 1674 | self.os_label.setText(f"操作系统: {status.get('os_info', '未知')}") 1675 | self.python_version_label.setText(f"Python版本: {status.get('python_version', '未知')}") 1676 | 1677 | # YOLOv8信息 1678 | if status.get('yolov8_installed', False): 1679 | self.yolo_installed_label.setText("安装状态: 已安装") 1680 | self.yolo_version_label.setText(f"版本: {status.get('yolov8_version', '未知')}") 1681 | self.install_yolo_button.setEnabled(False) 1682 | else: 1683 | self.yolo_installed_label.setText("安装状态: 未安装") 1684 | self.yolo_version_label.setText("版本: -") 1685 | self.install_yolo_button.setEnabled(True) 1686 | 1687 | # CUDA信息 1688 | if status.get('cuda_available', False): 1689 | self.cuda_available_label.setText("CUDA可用: 是") 1690 | else: 1691 | self.cuda_available_label.setText("CUDA可用: 否") 1692 | 1693 | self.cuda_version_label.setText(f"CUDA版本: {status.get('cuda_version', '未知')}") 1694 | self.torch_version_label.setText(f"PyTorch版本: {status.get('torch_version', '未知')}") 1695 | 1696 | # GPU信息 1697 | gpu_info = status.get('gpu_info', []) 1698 | if gpu_info: 1699 | gpu_text = "" 1700 | for gpu in gpu_info: 1701 | gpu_text += f"GPU {gpu['index']}: {gpu['name']} ({gpu['memory']})\n" 1702 | self.gpu_info_text.setText(gpu_text) 1703 | else: 1704 | self.gpu_info_text.setText("未检测到GPU") 1705 | 1706 | 1707 | class MainWindow(QMainWindow): 1708 | """主窗口""" 1709 | 1710 | def __init__(self, parameters, parent=None): 1711 | super().__init__(parent) 1712 | self.parameters = parameters 1713 | 1714 | # 设置窗口属性 1715 | self.setWindowTitle("YOLOv8 训练工具") 1716 | self.resize(1200, 800) 1717 | 1718 | # 创建中央部件 1719 | central_widget = QWidget() 1720 | self.setCentralWidget(central_widget) 1721 | 1722 | # 主布局 1723 | main_layout = QVBoxLayout(central_widget) 1724 | main_layout.setContentsMargins(10, 10, 10, 10) 1725 | 1726 | # 创建选项卡 1727 | tab_widget = QTabWidget() 1728 | tab_widget.setStyleSheet(""" 1729 | QTabWidget::pane { 1730 | border: 1px solid #cccccc; 1731 | border-radius: 5px; 1732 | background-color: #ffffff; 1733 | } 1734 | 1735 | QTabBar::tab { 1736 | background-color: #f0f0f0; 1737 | border: 1px solid #cccccc; 1738 | border-bottom: none; 1739 | border-top-left-radius: 4px; 1740 | border-top-right-radius: 4px; 1741 | padding: 8px 16px; 1742 | min-width: 100px; 1743 | font-weight: bold; 1744 | } 1745 | 1746 | QTabBar::tab:selected { 1747 | background-color: #ffffff; 1748 | border-bottom: 1px solid #ffffff; 1749 | } 1750 | 1751 | QTabBar::tab:hover:!selected { 1752 | background-color: #e5e5e5; 1753 | } 1754 | 1755 | QScrollBar:vertical { 1756 | border: none; 1757 | background: #f0f0f0; 1758 | width: 12px; 1759 | border-radius: 6px; 1760 | } 1761 | 1762 | QScrollBar::handle:vertical { 1763 | background: #a0a0a0; 1764 | min-height: 20px; 1765 | border-radius: 6px; 1766 | } 1767 | 1768 | QScrollBar::add-line:vertical, QScrollBar::sub-line:vertical { 1769 | border: none; 1770 | background: none; 1771 | } 1772 | 1773 | QScrollBar::add-page:vertical, QScrollBar::sub-page:vertical { 1774 | background: none; 1775 | } 1776 | 1777 | QPushButton { 1778 | background-color: #3498db; 1779 | color: white; 1780 | border: none; 1781 | border-radius: 4px; 1782 | padding: 8px 16px; 1783 | font-weight: bold; 1784 | } 1785 | 1786 | QPushButton:hover { 1787 | background-color: #2980b9; 1788 | } 1789 | 1790 | QPushButton:pressed { 1791 | background-color: #1c6ea4; 1792 | } 1793 | 1794 | QPushButton:disabled { 1795 | background-color: #cccccc; 1796 | color: #666666; 1797 | } 1798 | 1799 | QGroupBox { 1800 | border: 1px solid #cccccc; 1801 | border-radius: 5px; 1802 | margin-top: 1.5ex; 1803 | font-weight: bold; 1804 | font-size: 12px; 1805 | } 1806 | 1807 | QGroupBox::title { 1808 | subcontrol-origin: margin; 1809 | subcontrol-position: top left; 1810 | left: 10px; 1811 | padding: 0 5px; 1812 | color: #333333; 1813 | } 1814 | 1815 | QLineEdit, QSpinBox, QDoubleSpinBox, QComboBox { 1816 | border: 1px solid #cccccc; 1817 | border-radius: 4px; 1818 | padding: 5px; 1819 | background-color: #ffffff; 1820 | selection-background-color: #3498db; 1821 | } 1822 | 1823 | QLineEdit:focus, QSpinBox:focus, QDoubleSpinBox:focus, QComboBox:focus { 1824 | border: 1px solid #3498db; 1825 | } 1826 | 1827 | QProgressBar { 1828 | border: none; 1829 | border-radius: 5px; 1830 | background-color: #f0f0f0; 1831 | text-align: center; 1832 | height: 20px; 1833 | } 1834 | 1835 | QProgressBar::chunk { 1836 | background-color: #3498db; 1837 | border-radius: 5px; 1838 | } 1839 | 1840 | QTextEdit { 1841 | border: 1px solid #cccccc; 1842 | border-radius: 5px; 1843 | padding: 5px; 1844 | background-color: #ffffff; 1845 | selection-background-color: #3498db; 1846 | } 1847 | """) 1848 | 1849 | # 训练选项卡 1850 | self.training_tab = TrainingTab(parameters) 1851 | tab_widget.addTab(self.training_tab, "训练设置") 1852 | 1853 | # 进度选项卡 1854 | self.progress_tab = ProgressTab() 1855 | tab_widget.addTab(self.progress_tab, "训练进度") 1856 | 1857 | # 环境选项卡 1858 | self.environment_tab = EnvironmentTab() 1859 | tab_widget.addTab(self.environment_tab, "环境信息") 1860 | 1861 | main_layout.addWidget(tab_widget) 1862 | 1863 | # 获取训练控制按钮的引用 1864 | self.start_training_button = self.training_tab.start_training_button 1865 | self.stop_training_button = self.training_tab.stop_training_button 1866 | 1867 | # 设置应用程序图标 1868 | try: 1869 | from PySide6.QtGui import QIcon 1870 | self.setWindowIcon(QIcon("icon.png")) # 您需要提供一个图标文件 1871 | except: 1872 | pass 1873 | 1874 | # 状态栏 1875 | self.statusBar().showMessage("YOLOv8 训练工具已就绪") 1876 | 1877 | def update_cuda_status(self, available): 1878 | """更新CUDA状态""" 1879 | self.training_tab.update_cuda_status(available) 1880 | 1881 | def update_environment_info(self, status): 1882 | """更新环境信息""" 1883 | self.environment_tab.update_environment_info(status) 1884 | 1885 | # 更新CUDA设备选择 1886 | if status.get('cuda_available', False): 1887 | self.training_tab.device_combo.clear() 1888 | self.training_tab.device_combo.addItem("") # 自动选择 1889 | 1890 | # 添加每个GPU 1891 | for gpu in status.get('gpu_info', []): 1892 | self.training_tab.device_combo.addItem( 1893 | f"cuda:{gpu['index']} ({gpu['name']})" 1894 | ) 1895 | 1896 | def update_progress(self, progress_info): 1897 | """更新进度信息""" 1898 | self.progress_tab.update_progress(progress_info) 1899 | 1900 | def set_training_mode(self, training): 1901 | """设置界面的训练/非训练状态""" 1902 | self.training_tab.set_training_mode(training) 1903 | 1904 | def get_training_parameters(self): 1905 | """获取所有训练参数""" 1906 | return self.training_tab.get_training_parameters() 1907 | --------------------------------------------------------------------------------