├── .gitignore ├── README.md ├── config.py ├── custom ├── graphicsView.py ├── listWidgetItems.py ├── listWidgets.py ├── mediaplayer.py ├── stackedWidget.py ├── tableWidget.py ├── treeView.py └── warn.mp3 ├── flags.py ├── font └── simhei.ttf ├── icons ├── branch-close.png ├── branch-open.png ├── mask.png ├── video.png ├── 右旋转.png ├── 图片.png ├── 左旋转.png ├── 摄像头.png └── 识别.png ├── main.py ├── maskimg └── demo.png ├── model_data ├── coco_classes.txt ├── voc_classes.txt └── yolo_anchors.txt ├── nets ├── mobilenet_v1.py ├── mobilenet_v2.py ├── mobilenet_v3.py ├── yolo4.py └── yolo_training.py ├── requirements.txt ├── utils ├── dataloader.py └── utils.py └── yolo.py /.gitignore: -------------------------------------------------------------------------------- 1 | # ignore map, miou, datasets 2 | model_data/yolov4_mobile_mask.pth 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 8 | ## PyQt5实现目标检测系统搭建——支持多线程 9 | --- 10 | 11 | ## 功能 12 | 1. 支持读取本地图片 13 | 2. 支持读取本地视频 14 | 3. 
支持打开摄像头实时检测 15 | 4. 支持多线程,防止卡顿 16 | 5. 支持检测到人脸未佩戴口罩时记录,并语音警告 17 | 18 | ## 所需环境 19 | torch==1.2.0 20 | PyQt5==5.15.7 21 | pygame==2.2.0 22 | 23 | ## 文件下载 24 | 模型权重下载后放入**model_data**文件夹中 25 | 26 | [yolov4_mobile_mask.pth](https://github.com/Egrt/YOLO_PyQt5/releases/download/V1.0/yolov7_mobile_mask.pth) 27 | 28 | ## 运行 29 | 运行根目录下**main.py**启动界面: 30 | 31 | ![界面演示](maskimg/demo.png) -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from custom.tableWidget import * 2 | from custom.listWidgetItems import * 3 | 4 | 5 | # Implemented functions 6 | items = [ 7 | GrayingItem, 8 | FilterItem, 9 | MorphItem, 10 | GradItem, 11 | ThresholdItem, 12 | EdgeItem, 13 | ContourItem, 14 | EqualizeItem, 15 | HoughLineItem, 16 | LightItem, 17 | GammaItem 18 | ] 19 | 20 | tables = [ 21 | GrayingTableWidget, 22 | FilterTabledWidget, 23 | EqualizeTableWidget, 24 | MorphTabledWidget, 25 | GradTabledWidget, 26 | ThresholdTableWidget, 27 | EdgeTableWidget, 28 | ContourTableWidget, 29 | HoughLineTableWidget, 30 | LightTableWidget, 31 | GammaITabelWidget 32 | 33 | ] 34 | 35 | -------------------------------------------------------------------------------- /custom/graphicsView.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from PyQt5.QtGui import * 3 | from PyQt5.QtCore import * 4 | from PyQt5.QtWidgets import * 5 | 6 | 7 | class GraphicsView(QGraphicsView): 8 | def __init__(self, parent=None): 9 | super(GraphicsView, self).__init__(parent=parent) 10 | self._zoom = 0 11 | self._empty = True 12 | self._photo = QGraphicsPixmapItem() 13 | self._scene = QGraphicsScene(self) 14 | self._scene.addItem(self._photo) 15 | self.setScene(self._scene) 16 | self.setAlignment(Qt.AlignCenter) # 居中显示 17 | self.setDragMode(QGraphicsView.ScrollHandDrag) # 设置拖动 18 | self.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff) 19 | self.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) 20 | self.setMinimumSize(640, 480) 21 | 22 | def contextMenuEvent(self, event): 23 | if not self.has_photo(): 24 | return 25 | menu = QMenu() 26 | save_action = QAction('另存为', self) 27 | save_action.triggered.connect(self.save_current) # 传递额外值 28 | menu.addAction(save_action) 29 | menu.exec(QCursor.pos()) 30 | 31 | def save_current(self): 32 | file_name = QFileDialog.getSaveFileName(self, '另存为', './', 'Image files(*.jpg *.gif *.png)')[0] 33 | print(file_name) 34 | if file_name: 35 | self._photo.pixmap().save(file_name) 36 | 37 | def get_image(self): 38 | if self.has_photo(): 39 | return self._photo.pixmap().toImage() 40 | 41 | def set_image(self,img): 42 | self._photo.setPixmap(img) 43 | self.fitInView2() 44 | 45 | def has_photo(self): 46 | return not self._empty 47 | 48 | def change_image(self, img): 49 | self.update_image(img) 50 | self.fitInView() 51 | def img_to_pixmap(self, img): 52 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # bgr -> rgb 53 | h, w, c = img.shape # 获取图片形状 54 | image = QImage(img, w, h, 3 * w, QImage.Format_RGB888) 55 | return QPixmap.fromImage(image) 56 | 57 | def update_image(self, img): 58 | self._empty = False 59 | self._photo.setPixmap(self.img_to_pixmap(img)) 60 | 61 | def fitInView(self, scale=True): 62 | rect = QRectF(self._photo.pixmap().rect()) 63 | if not rect.isNull(): 64 | self.setSceneRect(rect) 65 | if self.has_photo(): 66 | unity = self.transform().mapRect(QRectF(0, 0, 1, 1)) 67 | self.scale(1 / unity.width(), 1 / unity.height()) 68 | viewrect 
= self.viewport().rect() 69 | scenerect = self.transform().mapRect(rect) 70 | factor = min(viewrect.width() / scenerect.width(), 71 | viewrect.height() / scenerect.height()) 72 | self.scale(factor, factor) 73 | self._zoom = 0 74 | def fitInView2(self, scale=True): 75 | rect = QRectF(self._photo.pixmap().rect()) 76 | if not rect.isNull(): 77 | self.setSceneRect(rect) 78 | 79 | unity = self.transform().mapRect(QRectF(0, 0, 1, 1)) 80 | self.scale(1 / unity.width(), 1 / unity.height()) 81 | viewrect = self.viewport().rect() 82 | scenerect = self.transform().mapRect(rect) 83 | factor = min(viewrect.width() / scenerect.width(), 84 | viewrect.height() / scenerect.height()) 85 | self.scale(factor, factor) 86 | self._zoom = 0 87 | 88 | def wheelEvent(self, event): 89 | if self.has_photo(): 90 | if event.angleDelta().y() > 0: 91 | factor = 1.25 92 | self._zoom += 1 93 | else: 94 | factor = 0.8 95 | self._zoom -= 1 96 | if self._zoom > 0: 97 | self.scale(factor, factor) 98 | elif self._zoom == 0: 99 | self.fitInView() 100 | else: 101 | self._zoom = 0 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /custom/listWidgetItems.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PyQt5.QtCore import QSize 3 | from PyQt5.QtGui import QIcon, QColor 4 | from PyQt5.QtWidgets import QListWidgetItem, QPushButton 5 | from flags import * 6 | from yolo import YOLO 7 | 8 | class MyItem(QListWidgetItem): 9 | def __init__(self, name=None, parent=None): 10 | super(MyItem, self).__init__(name, parent=parent) 11 | self.setIcon(QIcon('icons/识别.png')) 12 | self.setSizeHint(QSize(80, 60)) # fixed item size 13 | def get_params(self): 14 | protected = [v for v in dir(self) if v.startswith('_') and not v.startswith('__')] 15 | param = {} 16 | for v in protected: 17 | param[v.replace('_', '', 1)] = self.__getattribute__(v) 18 | return param 19 | 20 | def update_params(self, param): 21 | for k, v in param.items(): 22 | if '_' + k in dir(self): 23 | self.__setattr__('_' + k, v) 24 | 25 | 26 | class GrayingItem(MyItem): 27 | def __init__(self, parent=None): 28 | super(GrayingItem, self).__init__(' 灰度化 ', parent=parent) 29 | self._mode = BGR2GRAY_COLOR 30 | 31 | def __call__(self, img): 32 | img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) 33 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 34 | return img 35 | 36 | 37 | class FilterItem(MyItem): 38 | 39 | def __init__(self, parent=None): 40 | super().__init__('平滑处理', parent=parent) 41 | self._ksize = 3 42 | self._kind = MEAN_FILTER 43 | self._sigmax = 0 44 | 45 | def __call__(self, img): 46 | if self._kind == MEAN_FILTER: 47 | img = cv2.blur(img, (self._ksize, self._ksize)) 48 | elif self._kind == GAUSSIAN_FILTER: 49 | img = cv2.GaussianBlur(img, (self._ksize, self._ksize), self._sigmax) 50 | elif self._kind == MEDIAN_FILTER: 51 | img = cv2.medianBlur(img, self._ksize) 52 | return img 53 | 54 | 55 | class MorphItem(MyItem): 56 | def __init__(self, parent=None): 57 | super().__init__(' 形态学 ', parent=parent) 58 | self._ksize = 3 59 | self._op = ERODE_MORPH_OP 60 | self._kshape = RECT_MORPH_SHAPE 61 | 62 | def __call__(self, img): 63 | op = MORPH_OP[self._op] 64 | kshape = MORPH_SHAPE[self._kshape] 65 | kernel = cv2.getStructuringElement(kshape, (self._ksize, self._ksize)) 66 | img = cv2.morphologyEx(img, op, kernel) 67 | return img 68 | 69 | 70 | class GradItem(MyItem): 71 | 72 | def __init__(self, parent=None): 73 | super().__init__('图像梯度', parent=parent) 74 | self._kind = 
SOBEL_GRAD 75 | self._ksize = 3 76 | self._dx = 1 77 | self._dy = 0 78 | 79 | def __call__(self, img): 80 | if self._dx == 0 and self._dy == 0 and self._kind != LAPLACIAN_GRAD: 81 | self.setBackground(QColor(255, 0, 0)) 82 | self.setText('图像梯度 (无效: dx与dy不同时为0)') 83 | else: 84 | self.setBackground(QColor(200, 200, 200)) 85 | self.setText('图像梯度') 86 | if self._kind == SOBEL_GRAD: 87 | img = cv2.Sobel(img, -1, self._dx, self._dy, ksize=self._ksize) 88 | elif self._kind == SCHARR_GRAD: 89 | img = cv2.Scharr(img, -1, self._dx, self._dy) 90 | elif self._kind == LAPLACIAN_GRAD: 91 | img = cv2.Laplacian(img, -1) 92 | return img 93 | 94 | 95 | class ThresholdItem(MyItem): 96 | def __init__(self, parent=None): 97 | super().__init__('阈值处理', parent=parent) 98 | self._thresh = 127 99 | self._maxval = 255 100 | self._method = BINARY_THRESH_METHOD 101 | 102 | def __call__(self, img): 103 | method = THRESH_METHOD[self._method] 104 | img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) 105 | img = cv2.threshold(img, self._thresh, self._maxval, method)[1] 106 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 107 | return img 108 | 109 | 110 | class EdgeItem(MyItem): 111 | def __init__(self, parent=None): 112 | super(EdgeItem, self).__init__('边缘检测', parent=parent) 113 | self._thresh1 = 20 114 | self._thresh2 = 100 115 | 116 | def __call__(self, img): 117 | img = cv2.Canny(img, threshold1=self._thresh1, threshold2=self._thresh2) 118 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 119 | return img 120 | 121 | 122 | class ContourItem(MyItem): 123 | def __init__(self, parent=None): 124 | super(ContourItem, self).__init__('轮廓检测', parent=parent) 125 | self._mode = TREE_CONTOUR_MODE 126 | self._method = SIMPLE_CONTOUR_METHOD 127 | self._bbox = NORMAL_CONTOUR 128 | 129 | def __call__(self, img): 130 | mode = CONTOUR_MODE[self._mode] 131 | method = CONTOUR_METHOD[self._method] 132 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 133 | cnts, _ = cv2.findContours(img, mode, method) 134 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 135 | if self._bbox == RECT_CONTOUR: 136 | bboxs = [cv2.boundingRect(cnt) for cnt in cnts] 137 | print(bboxs) 138 | for x, y, w, h in bboxs: 139 | img = cv2.rectangle(img, (x, y), (x + w, y + h), (255, 0, 0), thickness=2) 140 | elif self._bbox == MINRECT_CONTOUR: 141 | bboxs = [np.int0(cv2.boxPoints(cv2.minAreaRect(cnt))) for cnt in cnts] 142 | img = cv2.drawContours(img, bboxs, -1, (255, 0, 0), thickness=2) 143 | elif self._bbox == MINCIRCLE_CONTOUR: 144 | circles = [cv2.minEnclosingCircle(cnt) for cnt in cnts] 145 | print(circles) 146 | for (x, y), r in circles: 147 | img = cv2.circle(img, (int(x), int(y)), int(r), (255, 0, 0), thickness=2) 148 | elif self._bbox == NORMAL_CONTOUR: 149 | img = cv2.drawContours(img, cnts, -1, (255, 0, 0), thickness=2) 150 | 151 | return img 152 | 153 | 154 | class EqualizeItem(MyItem): 155 | def __init__(self, parent=None): 156 | super().__init__(' 均衡化 ', parent=parent) 157 | self._blue = True 158 | self._green = True 159 | self._red = True 160 | 161 | def __call__(self, img): 162 | b, g, r = cv2.split(img) 163 | if self._blue: 164 | b = cv2.equalizeHist(b) 165 | if self._green: 166 | g = cv2.equalizeHist(g) 167 | if self._red: 168 | r = cv2.equalizeHist(r) 169 | return cv2.merge((b, g, r)) 170 | 171 | 172 | class HoughLineItem(MyItem): 173 | def __init__(self, parent=None): 174 | super(HoughLineItem, self).__init__('直线检测', parent=parent) 175 | self._rho = 1 176 | self._theta = np.pi / 180 177 | self._thresh = 10 178 | self._min_length = 20 179 | self._max_gap = 5 180 | 181 | def 
__call__(self, img): 182 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 183 | lines = cv2.HoughLinesP(img, self._rho, self._theta, self._thresh, minLineLength=self._min_length, 184 | maxLineGap=self._max_gap) 185 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 186 | if lines is None: return img 187 | for line in lines: 188 | for x1, y1, x2, y2 in line: 189 | img = cv2.line(img, (x1, y1), (x2, y2), (0, 255, 0), thickness=2) 190 | return img 191 | 192 | 193 | class LightItem(MyItem): 194 | def __init__(self, parent=None): 195 | super(LightItem, self).__init__('亮度调节', parent=parent) 196 | self._alpha = 1 197 | self._beta = 0 198 | 199 | def __call__(self, img): 200 | blank = np.zeros(img.shape, img.dtype) 201 | img = cv2.addWeighted(img, self._alpha, blank, 1 - self._alpha, self._beta) 202 | return img 203 | 204 | 205 | class GammaItem(MyItem): 206 | def __init__(self, parent=None): 207 | super(GammaItem, self).__init__('伽马校正', parent=parent) 208 | self._gamma = 1 209 | 210 | def __call__(self, img): 211 | gamma_table = [np.power(x / 255.0, self._gamma) * 255.0 for x in range(256)] 212 | gamma_table = np.round(np.array(gamma_table)).astype(np.uint8) 213 | return cv2.LUT(img, gamma_table) 214 | -------------------------------------------------------------------------------- /custom/listWidgets.py: -------------------------------------------------------------------------------- 1 | from PyQt5.QtGui import * 2 | from PyQt5.QtCore import * 3 | from PyQt5.QtWidgets import * 4 | 5 | from config import items 6 | 7 | 8 | class MyListWidget(QListWidget): 9 | def __init__(self, parent=None): 10 | super().__init__(parent=parent) 11 | self.mainwindow = parent 12 | self.setDragEnabled(True) 13 | # 选中不显示虚线 14 | # self.setEditTriggers(QAbstractItemView.NoEditTriggers) 15 | self.setFocusPolicy(Qt.NoFocus) 16 | 17 | 18 | class UsedListWidget(MyListWidget): 19 | def __init__(self, parent=None): 20 | super().__init__(parent=parent) 21 | self.setAcceptDrops(True) 22 | self.setFlow(QListView.TopToBottom) # 设置列表方向 23 | self.setDefaultDropAction(Qt.MoveAction) # 设置拖放为移动而不是复制一个 24 | self.setDragDropMode(QAbstractItemView.InternalMove) # 设置拖放模式, 内部拖放 25 | self.itemClicked.connect(self.show_attr) 26 | self.setMinimumWidth(300) 27 | 28 | self.move_item = None 29 | 30 | def contextMenuEvent(self, e): 31 | # 右键菜单事件 32 | item = self.itemAt(self.mapFromGlobal(QCursor.pos())) 33 | if not item: return # 判断是否是空白区域 34 | menu = QMenu() 35 | delete_action = QAction('删除', self) 36 | delete_action.triggered.connect(lambda: self.delete_item(item)) # 传递额外值 37 | menu.addAction(delete_action) 38 | menu.exec(QCursor.pos()) 39 | 40 | def delete_item(self, item): 41 | # 删除操作 42 | self.takeItem(self.row(item)) 43 | self.mainwindow.update_image() # 更新frame 44 | # self.mainwindow.dock_attr.close() 45 | 46 | def dropEvent(self, event): 47 | super().dropEvent(event) 48 | self.mainwindow.update_image() 49 | 50 | def show_attr(self): 51 | item = self.itemAt(self.mapFromGlobal(QCursor.pos())) 52 | if not item: return 53 | param = item.get_params() # 获取当前item的属性 54 | if type(item) in items: 55 | index = items.index(type(item)) # 获取item对应的table索引 56 | self.mainwindow.stackedWidget.setCurrentIndex(index) 57 | self.mainwindow.stackedWidget.currentWidget().update_params(param) # 更新对应的table 58 | self.mainwindow.dock_attr.show() 59 | 60 | 61 | class FuncListWidget(MyListWidget): 62 | def __init__(self, parent=None): 63 | super().__init__(parent=parent) 64 | self.setFixedHeight(64) 65 | self.setFlow(QListView.LeftToRight) # 设置列表方向 66 | 
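# FuncListWidget is the horizontal icon strip across the top of the window:
# clicking an entry instantiates a fresh item of the same class and appends it
# to the mainwindow's useListWidget, whose items process_image() then applies
# in order to the current frame.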
self.setViewMode(QListView.IconMode) # icon view mode 67 | self.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff) # hide the scroll bar 68 | self.setAcceptDrops(False) 69 | for itemType in items: 70 | self.addItem(itemType()) 71 | self.itemClicked.connect(self.add_used_function) 72 | 73 | def add_used_function(self): 74 | func_item = self.currentItem() 75 | if type(func_item) in items: 76 | use_item = type(func_item)() 77 | self.mainwindow.useListWidget.addItem(use_item) 78 | self.mainwindow.update_image() 79 | 80 | def enterEvent(self, event): 81 | self.setCursor(Qt.PointingHandCursor) 82 | 83 | def leaveEvent(self, event): 84 | self.setCursor(Qt.ArrowCursor) 85 | self.setCurrentRow(-1) # clear the selection 86 | -------------------------------------------------------------------------------- /custom/tableWidget.py: -------------------------------------------------------------------------------- 1 | from PyQt5.QtWidgets import * 2 | from PyQt5.QtCore import * 3 | 4 | 5 | class TableWidget(QTableWidget): 6 | def __init__(self, parent=None): 7 | super(TableWidget, self).__init__(parent=parent) 8 | self.mainwindow = parent 9 | self.setShowGrid(True) # show grid lines 10 | self.setAlternatingRowColors(True) # alternate row colors 11 | self.setEditTriggers(QAbstractItemView.NoEditTriggers) 12 | self.horizontalHeader().setVisible(False) 13 | self.verticalHeader().setVisible(False) 14 | self.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch) 15 | self.verticalHeader().setSectionResizeMode(QHeaderView.Stretch) 16 | self.horizontalHeader().setStretchLastSection(True) 17 | self.setFocusPolicy(Qt.NoFocus) 18 | 19 | def signal_connect(self): 20 | for spinbox in self.findChildren(QSpinBox): 21 | spinbox.valueChanged.connect(self.update_item) 22 | for 
doublespinbox in self.findChildren(QDoubleSpinBox): 23 | doublespinbox.valueChanged.connect(self.update_item) 24 | for combox in self.findChildren(QComboBox): 25 | combox.currentIndexChanged.connect(self.update_item) 26 | for checkbox in self.findChildren(QCheckBox): 27 | checkbox.stateChanged.connect(self.update_item) 28 | 29 | def update_item(self): 30 | param = self.get_params() 31 | self.mainwindow.useListWidget.currentItem().update_params(param) 32 | self.mainwindow.update_image() 33 | 34 | def update_params(self, param=None): 35 | for key in param.keys(): 36 | box = self.findChild(QWidget, name=key) 37 | if isinstance(box, QSpinBox) or isinstance(box, QDoubleSpinBox): 38 | box.setValue(param[key]) 39 | elif isinstance(box, QComboBox): 40 | box.setCurrentIndex(param[key]) 41 | elif isinstance(box, QCheckBox): 42 | box.setChecked(param[key]) 43 | 44 | def get_params(self): 45 | param = {} 46 | for spinbox in self.findChildren(QSpinBox): 47 | param[spinbox.objectName()] = spinbox.value() 48 | for doublespinbox in self.findChildren(QDoubleSpinBox): 49 | param[doublespinbox.objectName()] = doublespinbox.value() 50 | for combox in self.findChildren(QComboBox): 51 | param[combox.objectName()] = combox.currentIndex() 52 | for combox in self.findChildren(QCheckBox): 53 | param[combox.objectName()] = combox.isChecked() 54 | return param 55 | 56 | 57 | class GrayingTableWidget(TableWidget): 58 | def __init__(self, parent=None): 59 | super(GrayingTableWidget, self).__init__(parent=parent) 60 | 61 | 62 | class FilterTabledWidget(TableWidget): 63 | def __init__(self, parent=None): 64 | super(FilterTabledWidget, self).__init__(parent=parent) 65 | 66 | self.kind_comBox = QComboBox() 67 | self.kind_comBox.addItems(['均值滤波', '高斯滤波', '中值滤波']) 68 | self.kind_comBox.setObjectName('kind') 69 | 70 | self.ksize_spinBox = QSpinBox() 71 | self.ksize_spinBox.setObjectName('ksize') 72 | self.ksize_spinBox.setMinimum(1) 73 | self.ksize_spinBox.setSingleStep(2) 74 | 75 | self.setColumnCount(2) 76 | self.setRowCount(2) 77 | self.setItem(0, 0, QTableWidgetItem('类型')) 78 | self.setCellWidget(0, 1, self.kind_comBox) 79 | self.setItem(1, 0, QTableWidgetItem('核大小')) 80 | self.setCellWidget(1, 1, self.ksize_spinBox) 81 | 82 | self.signal_connect() 83 | 84 | 85 | class MorphTabledWidget(TableWidget): 86 | def __init__(self, parent=None): 87 | super(MorphTabledWidget, self).__init__(parent=parent) 88 | 89 | self.op_comBox = QComboBox() 90 | self.op_comBox.addItems(['腐蚀操作', '膨胀操作', '开操作', '闭操作', '梯度操作', '顶帽操作', '黑帽操作']) 91 | self.op_comBox.setObjectName('op') 92 | 93 | self.ksize_spinBox = QSpinBox() 94 | self.ksize_spinBox.setMinimum(1) 95 | self.ksize_spinBox.setSingleStep(2) 96 | self.ksize_spinBox.setObjectName('ksize') 97 | 98 | self.kshape_comBox = QComboBox() 99 | self.kshape_comBox.addItems(['方形', '十字形', '椭圆形']) 100 | self.kshape_comBox.setObjectName('kshape') 101 | 102 | self.setColumnCount(2) 103 | self.setRowCount(3) 104 | self.setItem(0, 0, QTableWidgetItem('类型')) 105 | self.setCellWidget(0, 1, self.op_comBox) 106 | self.setItem(1, 0, QTableWidgetItem('核大小')) 107 | self.setCellWidget(1, 1, self.ksize_spinBox) 108 | self.setItem(2, 0, QTableWidgetItem('核形状')) 109 | self.setCellWidget(2, 1, self.kshape_comBox) 110 | self.signal_connect() 111 | 112 | 113 | class GradTabledWidget(TableWidget): 114 | def __init__(self, parent=None): 115 | super(GradTabledWidget, self).__init__(parent=parent) 116 | 117 | self.kind_comBox = QComboBox() 118 | self.kind_comBox.addItems(['Sobel算子', 'Scharr算子', 'Laplacian算子']) 119 | 
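# Convention worth noting: each editor widget's objectName must equal the
# corresponding item attribute without its leading underscore ('kind' <->
# GradItem._kind, 'ksize' <-> GradItem._ksize); TableWidget.get_params() and
# MyItem.update_params() rely on this mapping to stay in sync.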
self.kind_comBox.setObjectName('kind') 120 | 121 | self.ksize_spinBox = QSpinBox() 122 | self.ksize_spinBox.setMinimum(1) 123 | self.ksize_spinBox.setSingleStep(2) 124 | self.ksize_spinBox.setObjectName('ksize') 125 | 126 | self.dx_spinBox = QSpinBox() 127 | self.dx_spinBox.setMaximum(1) 128 | self.dx_spinBox.setMinimum(0) 129 | self.dx_spinBox.setSingleStep(1) 130 | self.dx_spinBox.setObjectName('dx') 131 | 132 | self.dy_spinBox = QSpinBox() 133 | self.dy_spinBox.setMaximum(1) 134 | self.dy_spinBox.setMinimum(0) 135 | self.dy_spinBox.setSingleStep(1) 136 | self.dy_spinBox.setObjectName('dy') 137 | 138 | self.setColumnCount(2) 139 | self.setRowCount(4) 140 | 141 | self.setItem(0, 0, QTableWidgetItem('类型')) 142 | self.setCellWidget(0, 1, self.kind_comBox) 143 | self.setItem(1, 0, QTableWidgetItem('核大小')) 144 | self.setCellWidget(1, 1, self.ksize_spinBox) 145 | self.setItem(2, 0, QTableWidgetItem('x方向')) 146 | self.setCellWidget(2, 1, self.dx_spinBox) 147 | self.setItem(3, 0, QTableWidgetItem('y方向')) 148 | self.setCellWidget(3, 1, self.dy_spinBox) 149 | 150 | self.signal_connect() 151 | 152 | 153 | class ThresholdTableWidget(TableWidget): 154 | def __init__(self, parent=None): 155 | super(ThresholdTableWidget, self).__init__(parent=parent) 156 | 157 | self.thresh_spinBox = QSpinBox() 158 | self.thresh_spinBox.setObjectName('thresh') 159 | self.thresh_spinBox.setMaximum(255) 160 | self.thresh_spinBox.setMinimum(0) 161 | self.thresh_spinBox.setSingleStep(1) 162 | 163 | self.maxval_spinBox = QSpinBox() 164 | self.maxval_spinBox.setObjectName('maxval') 165 | self.maxval_spinBox.setMaximum(255) 166 | self.maxval_spinBox.setMinimum(0) 167 | self.maxval_spinBox.setSingleStep(1) 168 | 169 | self.method_comBox = QComboBox() 170 | self.method_comBox.addItems(['二进制阈值化', '反二进制阈值化', '截断阈值化', '阈值化为0', '反阈值化为0', '大津算法']) 171 | self.method_comBox.setObjectName('method') 172 | 173 | self.setColumnCount(2) 174 | self.setRowCount(3) 175 | 176 | self.setItem(0, 0, QTableWidgetItem('类型')) 177 | self.setCellWidget(0, 1, self.method_comBox) 178 | self.setItem(1, 0, QTableWidgetItem('阈值')) 179 | self.setCellWidget(1, 1, self.thresh_spinBox) 180 | self.setItem(2, 0, QTableWidgetItem('最大值')) 181 | self.setCellWidget(2, 1, self.maxval_spinBox) 182 | 183 | self.signal_connect() 184 | 185 | 186 | class EdgeTableWidget(TableWidget): 187 | def __init__(self, parent=None): 188 | super(EdgeTableWidget, self).__init__(parent=parent) 189 | 190 | self.thresh1_spinBox = QSpinBox() 191 | self.thresh1_spinBox.setMinimum(0) 192 | self.thresh1_spinBox.setMaximum(255) 193 | self.thresh1_spinBox.setSingleStep(1) 194 | self.thresh1_spinBox.setObjectName('thresh1') 195 | 196 | self.thresh2_spinBox = QSpinBox() 197 | self.thresh2_spinBox.setMinimum(0) 198 | self.thresh2_spinBox.setMaximum(255) 199 | self.thresh2_spinBox.setSingleStep(1) 200 | self.thresh2_spinBox.setObjectName('thresh2') 201 | 202 | self.setColumnCount(2) 203 | self.setRowCount(2) 204 | 205 | self.setItem(0, 0, QTableWidgetItem('阈值1')) 206 | self.setCellWidget(0, 1, self.thresh1_spinBox) 207 | self.setItem(1, 0, QTableWidgetItem('阈值2')) 208 | self.setCellWidget(1, 1, self.thresh2_spinBox) 209 | self.signal_connect() 210 | 211 | 212 | class ContourTableWidget(TableWidget): 213 | def __init__(self, parent=None): 214 | super(ContourTableWidget, self).__init__(parent=parent) 215 | 216 | self.bbox_comBox = QComboBox() 217 | self.bbox_comBox.addItems(['正常轮廓', '外接矩形', '最小外接矩形', '最小外接圆']) 218 | self.bbox_comBox.setObjectName('bbox') 219 | 220 | self.mode_comBox = QComboBox() 
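# The combo-box index is used directly as the key into CONTOUR_MODE /
# CONTOUR_METHOD in flags.py; ContourItem.__call__ then maps it to the real
# cv2.RETR_* / cv2.CHAIN_APPROX_* constants.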
221 | self.mode_comBox.addItems(['外轮廓', '轮廓列表', '外轮廓与内孔', '轮廓等级树']) 222 | self.mode_comBox.setObjectName('mode') 223 | 224 | self.method_comBox = QComboBox() 225 | self.method_comBox.addItems(['无近似', '简易近似']) 226 | self.method_comBox.setObjectName('method') 227 | 228 | self.setColumnCount(2) 229 | self.setRowCount(3) 230 | 231 | self.setItem(0, 0, QTableWidgetItem('轮廓模式')) 232 | self.setCellWidget(0, 1, self.mode_comBox) 233 | self.setItem(1, 0, QTableWidgetItem('轮廓近似')) 234 | self.setCellWidget(1, 1, self.method_comBox) 235 | self.setItem(2, 0, QTableWidgetItem('边界模式')) 236 | self.setCellWidget(2, 1, self.bbox_comBox) 237 | self.signal_connect() 238 | 239 | 240 | class EqualizeTableWidget(TableWidget): 241 | def __init__(self, parent=None): 242 | super(EqualizeTableWidget, self).__init__(parent=parent) 243 | self.red_checkBox = QCheckBox() 244 | self.red_checkBox.setObjectName('red') 245 | self.red_checkBox.setTristate(False) 246 | self.blue_checkBox = QCheckBox() 247 | self.blue_checkBox.setObjectName('blue') 248 | self.blue_checkBox.setTristate(False) 249 | self.green_checkBox = QCheckBox() 250 | self.green_checkBox.setObjectName('green') 251 | self.green_checkBox.setTristate(False) 252 | 253 | self.setColumnCount(2) 254 | self.setRowCount(3) 255 | 256 | self.setItem(0, 0, QTableWidgetItem('R通道')) 257 | self.setCellWidget(0, 1, self.red_checkBox) 258 | self.setItem(1, 0, QTableWidgetItem('G通道')) 259 | self.setCellWidget(1, 1, self.green_checkBox) 260 | self.setItem(2, 0, QTableWidgetItem('B通道')) 261 | self.setCellWidget(2, 1, self.blue_checkBox) 262 | self.signal_connect() 263 | 264 | 265 | class HoughLineTableWidget(TableWidget): 266 | def __init__(self, parent=None): 267 | super(HoughLineTableWidget, self).__init__(parent=parent) 268 | 269 | self.thresh_spinBox = QSpinBox() 270 | self.thresh_spinBox.setMinimum(0) 271 | self.thresh_spinBox.setSingleStep(1) 272 | self.thresh_spinBox.setObjectName('thresh') 273 | 274 | self.min_length_spinBox = QSpinBox() 275 | self.min_length_spinBox.setMinimum(0) 276 | self.min_length_spinBox.setSingleStep(1) 277 | self.min_length_spinBox.setObjectName('min_length') 278 | 279 | self.max_gap_spinbox = QSpinBox() 280 | self.max_gap_spinbox.setMinimum(0) 281 | self.max_gap_spinbox.setSingleStep(1) 282 | self.max_gap_spinbox.setObjectName('max_gap') 283 | 284 | self.setColumnCount(2) 285 | self.setRowCount(3) 286 | 287 | self.setItem(0, 0, QTableWidgetItem('交点阈值')) 288 | self.setCellWidget(0, 1, self.thresh_spinBox) 289 | self.setItem(1, 0, QTableWidgetItem('最小长度')) 290 | self.setCellWidget(1, 1, self.min_length_spinBox) 291 | self.setItem(2, 0, QTableWidgetItem('最大间距')) 292 | self.setCellWidget(2, 1, self.max_gap_spinbox) 293 | self.signal_connect() 294 | 295 | 296 | class LightTableWidget(TableWidget): 297 | def __init__(self, parent=None): 298 | super(LightTableWidget, self).__init__(parent=parent) 299 | 300 | self.alpha_spinBox = QDoubleSpinBox() 301 | self.alpha_spinBox.setMinimum(0) 302 | self.alpha_spinBox.setMaximum(3) 303 | self.alpha_spinBox.setSingleStep(0.1) 304 | self.alpha_spinBox.setObjectName('alpha') 305 | 306 | self.beta_spinbox = QSpinBox() 307 | self.beta_spinbox.setMinimum(0) 308 | self.beta_spinbox.setSingleStep(1) 309 | self.beta_spinbox.setObjectName('beta') 310 | 311 | self.setColumnCount(2) 312 | self.setRowCount(2) 313 | 314 | self.setItem(0, 0, QTableWidgetItem('alpha')) 315 | self.setCellWidget(0, 1, self.alpha_spinBox) 316 | self.setItem(1, 0, QTableWidgetItem('beta')) 317 | self.setCellWidget(1, 1, self.beta_spinbox) 318 | 
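# signal_connect(), inherited from TableWidget, wires every QSpinBox,
# QDoubleSpinBox, QComboBox and QCheckBox on this table to update_item(), so
# any edit immediately pushes the new parameters to the selected item and
# re-renders the image.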
self.signal_connect() 319 | 320 | 321 | class GammaITabelWidget(TableWidget): 322 | def __init__(self, parent=None): 323 | super(GammaITabelWidget, self).__init__(parent=parent) 324 | self.gamma_spinbox = QDoubleSpinBox() 325 | self.gamma_spinbox.setMinimum(0) 326 | self.gamma_spinbox.setSingleStep(0.1) 327 | self.gamma_spinbox.setObjectName('gamma') 328 | 329 | self.setColumnCount(2) 330 | self.setRowCount(1) 331 | 332 | self.setItem(0, 0, QTableWidgetItem('gamma')) 333 | self.setCellWidget(0, 1, self.gamma_spinbox) 334 | self.signal_connect() 335 | -------------------------------------------------------------------------------- /custom/treeView.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from PyQt5.QtWidgets import * 5 | from PyQt5.QtCore import * 6 | 7 | 8 | class FileSystemTreeView(QTreeView, QDockWidget): 9 | def __init__(self, parent=None): 10 | super().__init__(parent=parent) 11 | self.mainwindow = parent 12 | self.fileSystemModel = QFileSystemModel() 13 | self.fileSystemModel.setRootPath('.') 14 | self.setModel(self.fileSystemModel) 15 | # 隐藏size,date等列 16 | self.setColumnWidth(0, 200) 17 | self.setColumnHidden(1, True) 18 | self.setColumnHidden(2, True) 19 | self.setColumnHidden(3, True) 20 | # 不显示标题栏 21 | self.header().hide() 22 | # 设置动画 23 | self.setAnimated(True) 24 | # 选中不显示虚线 25 | self.setFocusPolicy(Qt.NoFocus) 26 | self.doubleClicked.connect(self.select_image) 27 | self.setMinimumWidth(200) 28 | 29 | def select_image(self, file_index): 30 | file_name = self.fileSystemModel.filePath(file_index) 31 | if file_name.endswith(('.jpg', '.png', '.bmp')): 32 | src_img = cv2.imdecode(np.fromfile(file_name, dtype=np.uint8), -1) 33 | self.mainwindow.change_image(src_img) 34 | 35 | 36 | -------------------------------------------------------------------------------- /custom/warn.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Egrt/YOLO_PyQt5/457fbd769c6c97a69ff31f22fbb02cff4a3e562f/custom/warn.mp3 -------------------------------------------------------------------------------- /flags.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | GRAYING_STACKED_WIDGET = 0 4 | FILTER_STACKED_WIDGET = 1 5 | MORPH_STACKED_WIDGET = 2 6 | GRAD_STACKED_WIDGET = 3 7 | THRESH_STACKED_WIDGET = 4 8 | EDGE_STACKED_WIDGET = 5 9 | 10 | BGR2GRAY_COLOR = 0 11 | GRAY2BGR_COLOR = 1 12 | COLOR = { 13 | BGR2GRAY_COLOR: cv2.COLOR_BGR2GRAY, 14 | GRAY2BGR_COLOR: cv2.COLOR_GRAY2BGR 15 | } 16 | 17 | MEAN_FILTER = 0 18 | GAUSSIAN_FILTER = 1 19 | MEDIAN_FILTER = 2 20 | 21 | ERODE_MORPH_OP = 0 22 | DILATE_MORPH_OP = 1 23 | OPEN_MORPH_OP = 2 24 | CLOSE_MORPH_OP = 3 25 | GRADIENT_MORPH_OP = 4 26 | TOPHAT_MORPH_OP = 5 27 | BLACKHAT_MORPH_OP = 6 28 | 29 | MORPH_OP = { 30 | ERODE_MORPH_OP: cv2.MORPH_ERODE, 31 | DILATE_MORPH_OP: cv2.MORPH_DILATE, 32 | OPEN_MORPH_OP: cv2.MORPH_OPEN, 33 | CLOSE_MORPH_OP: cv2.MORPH_CLOSE, 34 | GRADIENT_MORPH_OP: cv2.MORPH_GRADIENT, 35 | TOPHAT_MORPH_OP: cv2.MORPH_TOPHAT, 36 | BLACKHAT_MORPH_OP: cv2.MORPH_BLACKHAT 37 | } 38 | 39 | RECT_MORPH_SHAPE = 0 40 | CROSS_MORPH_SHAPE = 1 41 | ELLIPSE_MORPH_SHAPE = 2 42 | 43 | MORPH_SHAPE = { 44 | RECT_MORPH_SHAPE: cv2.MORPH_RECT, 45 | CROSS_MORPH_SHAPE: cv2.MORPH_CROSS, 46 | ELLIPSE_MORPH_SHAPE: cv2.MORPH_ELLIPSE 47 | } 48 | 49 | SOBEL_GRAD = 0 50 | SCHARR_GRAD = 1 51 | LAPLACIAN_GRAD = 2 52 | 53 | BINARY_THRESH_METHOD = 0 54 | 
BINARY_INV_THRESH_METHOD = 1 55 | TRUNC_THRESH_METHOD = 2 56 | TOZERO_THRESH_METHOD = 3 57 | TOZERO_INV_THRESH_METHOD = 4 58 | OTSU_THRESH_METHOD = 5 59 | THRESH_METHOD = { 60 | BINARY_THRESH_METHOD: cv2.THRESH_BINARY, # 0 61 | BINARY_INV_THRESH_METHOD: cv2.THRESH_BINARY_INV, # 1 62 | TRUNC_THRESH_METHOD: cv2.THRESH_TRUNC, # 2 63 | TOZERO_THRESH_METHOD: cv2.THRESH_TOZERO, # 3 64 | TOZERO_INV_THRESH_METHOD: cv2.THRESH_TOZERO_INV, # 4 65 | OTSU_THRESH_METHOD: cv2.THRESH_OTSU # 5 66 | } 67 | 68 | EXTERNAL_CONTOUR_MODE = 0 69 | LIST_CONTOUR_MODE = 1 70 | CCOMP_CONTOUR_MODE = 2 71 | TREE_CONTOUR_MODE = 3 72 | CONTOUR_MODE = { 73 | EXTERNAL_CONTOUR_MODE: cv2.RETR_EXTERNAL, 74 | LIST_CONTOUR_MODE: cv2.RETR_LIST, 75 | CCOMP_CONTOUR_MODE: cv2.RETR_CCOMP, 76 | TREE_CONTOUR_MODE: cv2.RETR_TREE 77 | } 78 | 79 | NONE_CONTOUR_METHOD = 0 80 | SIMPLE_CONTOUR_METHOD = 1 81 | CONTOUR_METHOD = { 82 | NONE_CONTOUR_METHOD: cv2.CHAIN_APPROX_NONE, 83 | SIMPLE_CONTOUR_METHOD: cv2.CHAIN_APPROX_SIMPLE 84 | } 85 | 86 | NORMAL_CONTOUR = 0 87 | RECT_CONTOUR = 1 88 | MINRECT_CONTOUR = 2 89 | MINCIRCLE_CONTOUR = 3 90 | 91 | 92 | # 均衡化 93 | BLUE_CHANNEL = 0 94 | GREEN_CHANNEL = 1 95 | RED_CHANNEL = 2 96 | ALL_CHANNEL = 3 97 | -------------------------------------------------------------------------------- /font/simhei.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Egrt/YOLO_PyQt5/457fbd769c6c97a69ff31f22fbb02cff4a3e562f/font/simhei.ttf -------------------------------------------------------------------------------- /icons/branch-close.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Egrt/YOLO_PyQt5/457fbd769c6c97a69ff31f22fbb02cff4a3e562f/icons/branch-close.png -------------------------------------------------------------------------------- /icons/branch-open.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Egrt/YOLO_PyQt5/457fbd769c6c97a69ff31f22fbb02cff4a3e562f/icons/branch-open.png -------------------------------------------------------------------------------- /icons/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Egrt/YOLO_PyQt5/457fbd769c6c97a69ff31f22fbb02cff4a3e562f/icons/mask.png -------------------------------------------------------------------------------- /icons/video.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Egrt/YOLO_PyQt5/457fbd769c6c97a69ff31f22fbb02cff4a3e562f/icons/video.png -------------------------------------------------------------------------------- /icons/右旋转.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Egrt/YOLO_PyQt5/457fbd769c6c97a69ff31f22fbb02cff4a3e562f/icons/右旋转.png -------------------------------------------------------------------------------- /icons/图片.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Egrt/YOLO_PyQt5/457fbd769c6c97a69ff31f22fbb02cff4a3e562f/icons/图片.png -------------------------------------------------------------------------------- /icons/左旋转.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Egrt/YOLO_PyQt5/457fbd769c6c97a69ff31f22fbb02cff4a3e562f/icons/左旋转.png 
-------------------------------------------------------------------------------- /icons/摄像头.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Egrt/YOLO_PyQt5/457fbd769c6c97a69ff31f22fbb02cff4a3e562f/icons/摄像头.png -------------------------------------------------------------------------------- /icons/识别.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Egrt/YOLO_PyQt5/457fbd769c6c97a69ff31f22fbb02cff4a3e562f/icons/识别.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import sys 3 | import time 4 | 5 | import cv2 6 | import numpy as np 7 | import qdarkstyle 8 | from PIL import Image 9 | from PyQt5 import QtCore, QtGui, QtWidgets 10 | from PyQt5.Qt import QThread 11 | from PyQt5.QtCore import * 12 | from PyQt5.QtGui import * 13 | from PyQt5.QtWidgets import * 14 | 15 | from custom.graphicsView import GraphicsView 16 | from custom.listWidgets import * 17 | from custom.stackedWidget import * 18 | from custom.treeView import FileSystemTreeView 19 | from yolo import YOLO 20 | 21 | ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID("myappid") 22 | 23 | # Worker thread for real-time detection 24 | class DetectThread(QThread): 25 | Send_signal = pyqtSignal(np.ndarray, int) 26 | 27 | def __init__(self, fileName): 28 | super(DetectThread, self).__init__() 29 | self.capture = cv2.VideoCapture(fileName) 30 | self.count = 0 31 | self.warn = False # whether to emit a warning 32 | 33 | def run(self): 34 | ret, self.frame = self.capture.read() 35 | while ret: 36 | self.detectCall() 37 | ret, self.frame = self.capture.read() 38 | 39 | def detectCall(self): 40 | fps = 0.0 41 | t1 = time.time() 42 | # current frame 43 | frame = self.frame 44 | # convert BGR to RGB 45 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 46 | # convert to a PIL Image 47 | frame = Image.fromarray(np.uint8(frame)) 48 | # run detection 49 | frame_new, predicted_class = yolo.detect_image(frame) 50 | frame = np.array(frame_new) 51 | if predicted_class == "face": 52 | self.count = self.count+1 53 | else: 54 | self.count = 0 55 | # convert RGB back to BGR for OpenCV display 56 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) 57 | fps = 1. / (time.time() - t1) 58 | print("fps= %.2f" % (fps)) 59 | frame = cv2.putText(frame, "fps= %.2f" % ( 60 | fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) 61 | if self.count > 30: 62 | self.count = 0 63 | self.warn = True 64 | else: 65 | self.warn = False 66 | # emit the Qt signal 67 | self.Send_signal.emit(frame, self.warn) 68 | 69 | class MyApp(QMainWindow): 70 | def __init__(self): 71 | super(MyApp, self).__init__() 72 | 73 | self.cap = cv2.VideoCapture() 74 | self.CAM_NUM = 0 75 | self.thread_status = False # whether the detection thread is running 76 | self.tool_bar = self.addToolBar('工具栏') 77 | self.action_right_rotate = QAction( 78 | QIcon("icons/右旋转.png"), "向右旋转90°", self) 79 | self.action_left_rotate = QAction( 80 | QIcon("icons/左旋转.png"), "向左旋转90°", self) 81 | self.action_opencam = QAction(QIcon("icons/摄像头.png"), "开启摄像头", self) 82 | self.action_video = QAction(QIcon("icons/video.png"), "加载视频", self) 83 | self.action_image = QAction(QIcon("icons/图片.png"), "加载图片", self) 84 | self.action_right_rotate.triggered.connect(self.right_rotate) 85 | self.action_left_rotate.triggered.connect(self.left_rotate) 86 | self.action_opencam.triggered.connect(self.opencam) 87 | self.action_video.triggered.connect(self.openvideo) 88 | 
self.action_image.triggered.connect(self.openimage) 89 | self.tool_bar.addActions((self.action_left_rotate, self.action_right_rotate, 90 | self.action_opencam, self.action_video, self.action_image)) 91 | self.stackedWidget = StackedWidget(self) 92 | self.fileSystemTreeView = FileSystemTreeView(self) 93 | self.graphicsView = GraphicsView(self) 94 | self.dock_file = QDockWidget(self) 95 | self.dock_file.setWidget(self.fileSystemTreeView) 96 | self.dock_file.setTitleBarWidget(QLabel('目录')) 97 | self.dock_file.setFeatures(QDockWidget.NoDockWidgetFeatures) 98 | 99 | self.dock_attr = QDockWidget(self) 100 | self.dock_attr.setWidget(self.stackedWidget) 101 | self.dock_attr.setTitleBarWidget(QLabel('上报数据')) 102 | self.dock_attr.setFeatures(QDockWidget.NoDockWidgetFeatures) 103 | 104 | self.setCentralWidget(self.graphicsView) 105 | self.addDockWidget(Qt.LeftDockWidgetArea, self.dock_file) 106 | self.addDockWidget(Qt.RightDockWidgetArea, self.dock_attr) 107 | 108 | self.setWindowTitle('口罩佩戴检测') 109 | self.setWindowIcon(QIcon('icons/mask.png')) 110 | self.src_img = None 111 | self.cur_img = None 112 | 113 | def update_image(self): 114 | if self.src_img is None: 115 | return 116 | img = self.process_image() 117 | self.cur_img = img 118 | self.graphicsView.update_image(img) 119 | 120 | def change_image(self, img): 121 | self.src_img = img 122 | img = self.process_image() 123 | self.cur_img = img 124 | self.graphicsView.change_image(img) 125 | 126 | def process_image(self): 127 | img = self.src_img.copy() 128 | for i in range(self.useListWidget.count()): 129 | img = self.useListWidget.item(i)(img) 130 | return img 131 | 132 | def right_rotate(self): 133 | self.graphicsView.rotate(90) 134 | 135 | def left_rotate(self): 136 | self.graphicsView.rotate(-90) 137 | 138 | def add_item(self, image): 139 | # 总Widget 140 | wight = QWidget() 141 | # 总体横向布局 142 | layout_main = QHBoxLayout() 143 | map_l = QLabel() # 图片显示 144 | map_l.setFixedSize(60, 40) 145 | map_l.setPixmap(image.scaled(60, 40)) 146 | # 右边的纵向布局 147 | layout_right = QVBoxLayout() 148 | # 右下的的横向布局 149 | layout_right_down = QHBoxLayout() # 右下的横向布局 150 | layout_right_down.addWidget( 151 | QLabel(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))) 152 | 153 | # 按照从左到右, 从上到下布局添加 154 | layout_main.addWidget(map_l) # 最左边的图片 155 | layout_right.addWidget(QLabel('警告!检测到未佩戴口罩')) # 右边的纵向布局 156 | layout_right.addLayout(layout_right_down) # 右下角横向布局 157 | layout_main.addLayout(layout_right) # 右边的布局 158 | wight.setLayout(layout_main) # 布局给wight 159 | item = QListWidgetItem() # 创建QListWidgetItem对象 160 | item.setSizeHint(QSize(300, 80)) # 设置QListWidgetItem大小 161 | self.stackedWidget.addItem(item) # 添加item 162 | self.stackedWidget.setItemWidget(item, wight) # 为item设置widget 163 | 164 | def openvideo(self): 165 | print(self.thread_status) 166 | if self.thread_status == False: 167 | 168 | fileName, filetype = QFileDialog.getOpenFileName( 169 | self, "选择视频", "D:/", "*.mp4;;*.flv;;All Files(*)") 170 | 171 | flag = self.cap.open(fileName) 172 | if flag == False: 173 | msg = QtWidgets.QMessageBox.warning(self, u"警告", u"请选择视频文件", 174 | buttons=QtWidgets.QMessageBox.Ok, 175 | defaultButton=QtWidgets.QMessageBox.Ok) 176 | else: 177 | self.detectThread = DetectThread(fileName) 178 | self.detectThread.Send_signal.connect(self.Display) 179 | self.detectThread.start() 180 | self.action_video.setText('关闭视频') 181 | self.thread_status = True 182 | elif self.thread_status == True: 183 | self.detectThread.terminate() 184 | if self.cap.isOpened(): 185 | self.cap.release() 186 | 
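# Note: QThread.terminate() stops the worker abruptly and may interrupt a
# cv2 read mid-frame. A gentler shutdown (sketch only, assuming run() were
# changed to poll isInterruptionRequested() in its loop) would be:
#     self.detectThread.requestInterruption()
#     self.detectThread.wait()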
self.action_video.setText('打开视频') 187 | self.thread_status = False 188 | 189 | def openimage(self): 190 | if self.thread_status == False: 191 | fileName, filetype = QFileDialog.getOpenFileName( 192 | self, "选择图片", "D:/", "*.jpg;;*.png;;All Files(*)") 193 | if fileName != '': 194 | src_img = Image.open(fileName) 195 | r_image, predicted_class = yolo.detect_image(src_img) 196 | r_image = np.array(r_image) 197 | showImage = QtGui.QImage( 198 | r_image.data, r_image.shape[1], r_image.shape[0], QtGui.QImage.Format_RGB888) 199 | self.graphicsView.set_image(QtGui.QPixmap.fromImage(showImage)) 200 | 201 | def opencam(self): 202 | if self.thread_status == False: 203 | flag = self.cap.open(self.CAM_NUM) 204 | if flag == False: 205 | msg = QtWidgets.QMessageBox.warning(self, u"警告", u"请检测相机与电脑是否连接正确", 206 | buttons=QtWidgets.QMessageBox.Ok, 207 | defaultButton=QtWidgets.QMessageBox.Ok) 208 | else: 209 | self.detectThread = DetectThread(self.CAM_NUM) 210 | self.detectThread.Send_signal.connect(self.Display) 211 | self.detectThread.start() 212 | self.action_video.setText('关闭视频') 213 | self.thread_status = True 214 | else: 215 | self.detectThread.terminate() 216 | if self.cap.isOpened(): 217 | self.cap.release() 218 | self.action_video.setText('打开视频') 219 | self.thread_status = False 220 | 221 | def Display(self, frame, warn): 222 | 223 | im = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 224 | showImage = QtGui.QImage( 225 | im.data, im.shape[1], im.shape[0], QtGui.QImage.Format_RGB888) 226 | self.graphicsView.set_image(QtGui.QPixmap.fromImage(showImage)) 227 | 228 | def closeEvent(self, event): 229 | ok = QtWidgets.QPushButton() 230 | cacel = QtWidgets.QPushButton() 231 | msg = QtWidgets.QMessageBox( 232 | QtWidgets.QMessageBox.Warning, u"关闭", u"确定退出?") 233 | msg.addButton(ok, QtWidgets.QMessageBox.ActionRole) 234 | msg.addButton(cacel, QtWidgets.QMessageBox.RejectRole) 235 | ok.setText(u'确定') 236 | cacel.setText(u'取消') 237 | if msg.exec_() == QtWidgets.QMessageBox.RejectRole: 238 | event.ignore() 239 | else: 240 | if self.thread_status == True: 241 | self.detectThread.terminate() 242 | if self.cap.isOpened(): 243 | self.cap.release() 244 | event.accept() 245 | 246 | 247 | if __name__ == "__main__": 248 | # 初始化yolo模型 249 | yolo = YOLO() 250 | app = QApplication(sys.argv) 251 | app.setStyleSheet(qdarkstyle.load_stylesheet_pyqt5()) 252 | window = MyApp() 253 | window.show() 254 | sys.exit(app.exec_()) 255 | -------------------------------------------------------------------------------- /maskimg/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Egrt/YOLO_PyQt5/457fbd769c6c97a69ff31f22fbb02cff4a3e562f/maskimg/demo.png -------------------------------------------------------------------------------- /model_data/coco_classes.txt: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 
52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush 81 | -------------------------------------------------------------------------------- /model_data/voc_classes.txt: -------------------------------------------------------------------------------- 1 | face 2 | face_mask -------------------------------------------------------------------------------- /model_data/yolo_anchors.txt: -------------------------------------------------------------------------------- 1 | 12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401 -------------------------------------------------------------------------------- /nets/mobilenet_v1.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | import torchvision.models._utils as _utils 8 | from torch.autograd import Variable 9 | 10 | 11 | def conv_bn(inp, oup, stride = 1): 12 | return nn.Sequential( 13 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 14 | nn.BatchNorm2d(oup), 15 | nn.ReLU6(inplace=True) 16 | ) 17 | 18 | def conv_dw(inp, oup, stride = 1): 19 | return nn.Sequential( 20 | # part1 21 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 22 | nn.BatchNorm2d(inp), 23 | nn.ReLU6(inplace=True), 24 | 25 | # part2 26 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 27 | nn.BatchNorm2d(oup), 28 | nn.ReLU6(inplace=True), 29 | ) 30 | 31 | class MobileNetV1(nn.Module): 32 | def __init__(self): 33 | super(MobileNetV1, self).__init__() 34 | self.stage1 = nn.Sequential( 35 | # 416,416,3 -> 208,208,32 36 | conv_bn(3, 32, 2), 37 | # 208,208,32 -> 208,208,64 38 | conv_dw(32, 64, 1), 39 | 40 | # 208,208,64 -> 104,104,128 41 | conv_dw(64, 128, 2), 42 | conv_dw(128, 128, 1), 43 | 44 | # 104,104,128 -> 52,52,256 45 | conv_dw(128, 256, 2), 46 | conv_dw(256, 256, 1), 47 | ) 48 | # 52,52,256 -> 26,26,512 49 | self.stage2 = nn.Sequential( 50 | conv_dw(256, 512, 2), 51 | conv_dw(512, 512, 1), 52 | conv_dw(512, 512, 1), 53 | conv_dw(512, 512, 1), 54 | conv_dw(512, 512, 1), 55 | conv_dw(512, 512, 1), 56 | ) 57 | # 26,26,512 -> 13,13,1024 58 | self.stage3 = nn.Sequential( 59 | conv_dw(512, 1024, 2), 60 | conv_dw(1024, 1024, 1), 61 | ) 62 | self.avg = nn.AdaptiveAvgPool2d((1,1)) 63 | self.fc = nn.Linear(1024, 1000) 64 | 65 | def forward(self, x): 66 | x = self.stage1(x) 67 | x = self.stage2(x) 68 | x = self.stage3(x) 69 | x = self.avg(x) 70 | # x = self.model(x) 71 | x = x.view(-1, 1024) 72 | x = self.fc(x) 73 | return x 74 | 75 | def mobilenet_v1(pretrained=False, progress=True): 76 | model = MobileNetV1() 77 | if pretrained: 78 | print("mobilenet_v1 has no pretrained model") 79 | return model 80 | 81 | if __name__ == "__main__": 82 | import torch 83 | from torchsummary import summary 84 | 85 | # 需要使用device来指定网络在GPU还是CPU运行 86 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 87 | model = mobilenet_v1().to(device) 88 | summary(model, input_size=(3, 416, 416)) 89 | -------------------------------------------------------------------------------- /nets/mobilenet_v2.py: -------------------------------------------------------------------------------- 1 | 
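# MobileNetV2 backbone: a stack of InvertedResidual (expand -> depthwise ->
# project) blocks. nets/yolo4.py slices self.features into three stages to
# obtain the 52x52, 26x26 and 13x13 feature maps fed to the YOLOv4 head.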
from torch import nn 2 | # from torchvision.models.utils import load_state_dict_from_url 3 | 4 | model_urls = { 5 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 6 | } 7 | 8 | 9 | def _make_divisible(v, divisor, min_value=None): 10 | if min_value is None: 11 | min_value = divisor 12 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 13 | if new_v < 0.9 * v: 14 | new_v += divisor 15 | return new_v 16 | 17 | class ConvBNReLU(nn.Sequential): 18 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 19 | padding = (kernel_size - 1) // 2 20 | super(ConvBNReLU, self).__init__( 21 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 22 | nn.BatchNorm2d(out_planes), 23 | nn.ReLU6(inplace=True) 24 | ) 25 | 26 | class InvertedResidual(nn.Module): 27 | def __init__(self, inp, oup, stride, expand_ratio): 28 | super(InvertedResidual, self).__init__() 29 | self.stride = stride 30 | assert stride in [1, 2] 31 | 32 | hidden_dim = int(round(inp * expand_ratio)) 33 | self.use_res_connect = self.stride == 1 and inp == oup 34 | 35 | layers = [] 36 | if expand_ratio != 1: 37 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 38 | 39 | layers.extend([ 40 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), 41 | 42 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 43 | nn.BatchNorm2d(oup), 44 | ]) 45 | self.conv = nn.Sequential(*layers) 46 | 47 | def forward(self, x): 48 | if self.use_res_connect: 49 | return x + self.conv(x) 50 | else: 51 | return self.conv(x) 52 | 53 | 54 | class MobileNetV2(nn.Module): 55 | def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8): 56 | super(MobileNetV2, self).__init__() 57 | block = InvertedResidual 58 | input_channel = 32 59 | last_channel = 1280 60 | 61 | if inverted_residual_setting is None: 62 | inverted_residual_setting = [ 63 | # t, c, n, s 64 | # 208,208,32 -> 208,208,16 65 | [1, 16, 1, 1], 66 | # 208,208,16 -> 104,104,24 67 | [6, 24, 2, 2], 68 | # 104,104,24 -> 52,52,32 69 | [6, 32, 3, 2], 70 | 71 | # 52,52,32 -> 26,26,64 72 | [6, 64, 4, 2], 73 | # 26,26,64 -> 26,26,96 74 | [6, 96, 3, 1], 75 | 76 | # 26,26,96 -> 13,13,160 77 | [6, 160, 3, 2], 78 | # 13,13,160 -> 13,13,320 79 | [6, 320, 1, 1], 80 | ] 81 | 82 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 83 | raise ValueError("inverted_residual_setting should be non-empty " 84 | "or a 4-element list, got {}".format(inverted_residual_setting)) 85 | 86 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 87 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 88 | 89 | # 416,416,3 -> 208,208,32 90 | features = [ConvBNReLU(3, input_channel, stride=2)] 91 | 92 | for t, c, n, s in inverted_residual_setting: 93 | output_channel = _make_divisible(c * width_mult, round_nearest) 94 | for i in range(n): 95 | stride = s if i == 0 else 1 96 | features.append(block(input_channel, output_channel, stride, expand_ratio=t)) 97 | input_channel = output_channel 98 | 99 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 100 | self.features = nn.Sequential(*features) 101 | 102 | self.classifier = nn.Sequential( 103 | nn.Dropout(0.2), 104 | nn.Linear(self.last_channel, num_classes), 105 | ) 106 | 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 110 | if 
m.bias is not None: 111 | nn.init.zeros_(m.bias) 112 | elif isinstance(m, nn.BatchNorm2d): 113 | nn.init.ones_(m.weight) 114 | nn.init.zeros_(m.bias) 115 | elif isinstance(m, nn.Linear): 116 | nn.init.normal_(m.weight, 0, 0.01) 117 | nn.init.zeros_(m.bias) 118 | 119 | def forward(self, x): 120 | x = self.features(x) 121 | x = x.mean([2, 3]) 122 | x = self.classifier(x) 123 | return x 124 | 125 | def mobilenet_v2(pretrained=False, progress=True): 126 | model = MobileNetV2() 127 | # if pretrained: 128 | # state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], model_dir="model_data", 129 | # progress=progress) 130 | # model.load_state_dict(state_dict) 131 | 132 | return model 133 | 134 | if __name__ == "__main__": 135 | print(mobilenet_v2()) 136 | -------------------------------------------------------------------------------- /nets/mobilenet_v3.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def _make_divisible(v, divisor, min_value=None): 8 | if min_value is None: 9 | min_value = divisor 10 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 11 | # Make sure that round down does not go down by more than 10%. 12 | if new_v < 0.9 * v: 13 | new_v += divisor 14 | return new_v 15 | 16 | class h_sigmoid(nn.Module): 17 | def __init__(self, inplace=True): 18 | super(h_sigmoid, self).__init__() 19 | self.relu = nn.ReLU6(inplace=inplace) 20 | 21 | def forward(self, x): 22 | return self.relu(x + 3) / 6 23 | 24 | 25 | class h_swish(nn.Module): 26 | def __init__(self, inplace=True): 27 | super(h_swish, self).__init__() 28 | self.sigmoid = h_sigmoid(inplace=inplace) 29 | 30 | def forward(self, x): 31 | return x * self.sigmoid(x) 32 | 33 | 34 | class SELayer(nn.Module): 35 | def __init__(self, channel, reduction=4): 36 | super(SELayer, self).__init__() 37 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 38 | self.fc = nn.Sequential( 39 | nn.Linear(channel, _make_divisible(channel // reduction, 8)), 40 | nn.ReLU(inplace=True), 41 | nn.Linear(_make_divisible(channel // reduction, 8), channel), 42 | h_sigmoid() 43 | ) 44 | 45 | def forward(self, x): 46 | b, c, _, _ = x.size() 47 | y = self.avg_pool(x).view(b, c) 48 | y = self.fc(y).view(b, c, 1, 1) 49 | return x * y 50 | 51 | 52 | def conv_3x3_bn(inp, oup, stride): 53 | return nn.Sequential( 54 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 55 | nn.BatchNorm2d(oup), 56 | h_swish() 57 | ) 58 | 59 | 60 | def conv_1x1_bn(inp, oup): 61 | return nn.Sequential( 62 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 63 | nn.BatchNorm2d(oup), 64 | h_swish() 65 | ) 66 | 67 | 68 | class InvertedResidual(nn.Module): 69 | def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs): 70 | super(InvertedResidual, self).__init__() 71 | assert stride in [1, 2] 72 | 73 | self.identity = stride == 1 and inp == oup 74 | 75 | if inp == hidden_dim: 76 | self.conv = nn.Sequential( 77 | # dw 78 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False), 79 | nn.BatchNorm2d(hidden_dim), 80 | h_swish() if use_hs else nn.ReLU(inplace=True), 81 | # Squeeze-and-Excite 82 | SELayer(hidden_dim) if use_se else nn.Identity(), 83 | # pw-linear 84 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 85 | nn.BatchNorm2d(oup), 86 | ) 87 | else: 88 | self.conv = nn.Sequential( 89 | 90 | # pw 91 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 92 | nn.BatchNorm2d(hidden_dim), 93 | h_swish() if use_hs 
else nn.ReLU(inplace=True), 94 | 95 | # dw 96 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False), 97 | nn.BatchNorm2d(hidden_dim), 98 | 99 | # Squeeze-and-Excite 100 | SELayer(hidden_dim) if use_se else nn.Identity(), 101 | 102 | h_swish() if use_hs else nn.ReLU(inplace=True), 103 | 104 | # pw-linear 105 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 106 | nn.BatchNorm2d(oup), 107 | ) 108 | 109 | def forward(self, x): 110 | if self.identity: 111 | return x + self.conv(x) 112 | else: 113 | return self.conv(x) 114 | 115 | 116 | class MobileNetV3(nn.Module): 117 | def __init__(self, num_classes=1000, width_mult=1.): 118 | super(MobileNetV3, self).__init__() 119 | # setting of inverted residual blocks 120 | self.cfgs = [ 121 | #` k, t, c, SE,HS,s 122 | # 208,208,16 -> 208,208,16 123 | [3, 1, 16, 0, 0, 1], 124 | 125 | # 208,208,16 -> 104,104,24 126 | [3, 4, 24, 0, 0, 2], 127 | [3, 3, 24, 0, 0, 1], 128 | 129 | # 104,104,24 -> 52,52,40 130 | [5, 3, 40, 1, 0, 2], 131 | [5, 3, 40, 1, 0, 1], 132 | [5, 3, 40, 1, 0, 1], 133 | 134 | # 52,52,40 -> 26,26,80 135 | [3, 6, 80, 0, 1, 2], 136 | [3, 2.5, 80, 0, 1, 1], 137 | [3, 2.3, 80, 0, 1, 1], 138 | [3, 2.3, 80, 0, 1, 1], 139 | 140 | # 26,26,80 -> 26,26,112 141 | [3, 6, 112, 1, 1, 1], 142 | [3, 6, 112, 1, 1, 1], 143 | 144 | # 26,26,112 -> 13,13,160 145 | [5, 6, 160, 1, 1, 2], 146 | [5, 6, 160, 1, 1, 1], 147 | [5, 6, 160, 1, 1, 1] 148 | ] 149 | 150 | input_channel = _make_divisible(16 * width_mult, 8) 151 | # 416,416,3 -> 208,208,16 152 | layers = [conv_3x3_bn(3, input_channel, 2)] 153 | 154 | block = InvertedResidual 155 | for k, t, c, use_se, use_hs, s in self.cfgs: 156 | output_channel = _make_divisible(c * width_mult, 8) 157 | exp_size = _make_divisible(input_channel * t, 8) 158 | layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs)) 159 | input_channel = output_channel 160 | self.features = nn.Sequential(*layers) 161 | 162 | self.conv = conv_1x1_bn(input_channel, exp_size) 163 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 164 | output_channel = _make_divisible(1280 * width_mult, 8) if width_mult > 1.0 else 1280 165 | self.classifier = nn.Sequential( 166 | nn.Linear(exp_size, output_channel), 167 | h_swish(), 168 | nn.Dropout(0.2), 169 | nn.Linear(output_channel, num_classes), 170 | ) 171 | 172 | self._initialize_weights() 173 | 174 | def forward(self, x): 175 | x = self.features(x) 176 | x = self.conv(x) 177 | x = self.avgpool(x) 178 | x = x.view(x.size(0), -1) 179 | x = self.classifier(x) 180 | return x 181 | 182 | def _initialize_weights(self): 183 | for m in self.modules(): 184 | if isinstance(m, nn.Conv2d): 185 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 186 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 187 | if m.bias is not None: 188 | m.bias.data.zero_() 189 | elif isinstance(m, nn.BatchNorm2d): 190 | m.weight.data.fill_(1) 191 | m.bias.data.zero_() 192 | elif isinstance(m, nn.Linear): 193 | n = m.weight.size(1) 194 | m.weight.data.normal_(0, 0.01) 195 | m.bias.data.zero_() 196 | 197 | def mobilenet_v3(pretrained=False, **kwargs): 198 | model = MobileNetV3(**kwargs) 199 | if pretrained: 200 | state_dict = torch.load('./model_data/mobilenetv3-large-1cd25616.pth') 201 | model.load_state_dict(state_dict, strict=True) 202 | return model 203 | 204 | -------------------------------------------------------------------------------- /nets/yolo4.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from nets.mobilenet_v1 import mobilenet_v1 7 | from nets.mobilenet_v2 import mobilenet_v2 8 | from nets.mobilenet_v3 import mobilenet_v3 9 | 10 | 11 | class MobileNetV1(nn.Module): 12 | def __init__(self, pretrained = False): 13 | super(MobileNetV1, self).__init__() 14 | self.model = mobilenet_v1(pretrained=pretrained) 15 | 16 | def forward(self, x): 17 | out3 = self.model.stage1(x) 18 | out4 = self.model.stage2(out3) 19 | out5 = self.model.stage3(out4) 20 | return out3, out4, out5 21 | 22 | class MobileNetV2(nn.Module): 23 | def __init__(self, pretrained = False): 24 | super(MobileNetV2, self).__init__() 25 | self.model = mobilenet_v2(pretrained=pretrained) 26 | 27 | def forward(self, x): 28 | out3 = self.model.features[:7](x) 29 | out4 = self.model.features[7:14](out3) 30 | out5 = self.model.features[14:18](out4) 31 | return out3, out4, out5 32 | 33 | class MobileNetV3(nn.Module): 34 | def __init__(self, pretrained = False): 35 | super(MobileNetV3, self).__init__() 36 | self.model = mobilenet_v3(pretrained=pretrained) 37 | 38 | def forward(self, x): 39 | out3 = self.model.features[:7](x) 40 | out4 = self.model.features[7:13](out3) 41 | out5 = self.model.features[13:16](out4) 42 | return out3, out4, out5 43 | 44 | def conv2d(filter_in, filter_out, kernel_size, groups=1, stride=1): 45 | pad = (kernel_size - 1) // 2 if kernel_size else 0 46 | return nn.Sequential(OrderedDict([ 47 | ("conv", nn.Conv2d(filter_in, filter_out, kernel_size=kernel_size, stride=stride, padding=pad, groups=groups, bias=False)), 48 | ("bn", nn.BatchNorm2d(filter_out)), 49 | ("relu", nn.ReLU6(inplace=True)), 50 | ])) 51 | 52 | def conv_dw(filter_in, filter_out, stride = 1): 53 | return nn.Sequential( 54 | nn.Conv2d(filter_in, filter_in, 3, stride, 1, groups=filter_in, bias=False), 55 | nn.BatchNorm2d(filter_in), 56 | nn.ReLU6(inplace=True), 57 | 58 | nn.Conv2d(filter_in, filter_out, 1, 1, 0, bias=False), 59 | nn.BatchNorm2d(filter_out), 60 | nn.ReLU6(inplace=True), 61 | ) 62 | 63 | #---------------------------------------------------# 64 | # SPP结构,利用不同大小的池化核进行池化 65 | # 池化后堆叠 66 | #---------------------------------------------------# 67 | class SpatialPyramidPooling(nn.Module): 68 | def __init__(self, pool_sizes=[5, 9, 13]): 69 | super(SpatialPyramidPooling, self).__init__() 70 | 71 | self.maxpools = nn.ModuleList([nn.MaxPool2d(pool_size, 1, pool_size//2) for pool_size in pool_sizes]) 72 | 73 | def forward(self, x): 74 | features = [maxpool(x) for maxpool in self.maxpools[::-1]] 75 | features = torch.cat(features + [x], dim=1) 76 | 77 | return features 78 | 79 | #---------------------------------------------------# 80 | # 卷积 + 上采样 81 | #---------------------------------------------------# 82 | class 
Upsample(nn.Module):
83 |     def __init__(self, in_channels, out_channels):
84 |         super(Upsample, self).__init__()
85 | 
86 |         self.upsample = nn.Sequential(
87 |             conv2d(in_channels, out_channels, 1),
88 |             nn.Upsample(scale_factor=2, mode='nearest')
89 |         )
90 | 
91 |     def forward(self, x):
92 |         x = self.upsample(x)
93 |         return x
94 | 
95 | #---------------------------------------------------#
96 | #   Block of three convolutions
97 | #---------------------------------------------------#
98 | def make_three_conv(filters_list, in_filters):
99 |     m = nn.Sequential(
100 |         conv2d(in_filters, filters_list[0], 1),
101 |         conv_dw(filters_list[0], filters_list[1]),
102 |         conv2d(filters_list[1], filters_list[0], 1),
103 |     )
104 |     return m
105 | 
106 | #---------------------------------------------------#
107 | #   Block of five convolutions
108 | #---------------------------------------------------#
109 | def make_five_conv(filters_list, in_filters):
110 |     m = nn.Sequential(
111 |         conv2d(in_filters, filters_list[0], 1),
112 |         conv_dw(filters_list[0], filters_list[1]),
113 |         conv2d(filters_list[1], filters_list[0], 1),
114 |         conv_dw(filters_list[0], filters_list[1]),
115 |         conv2d(filters_list[1], filters_list[0], 1),
116 |     )
117 |     return m
118 | 
119 | #---------------------------------------------------#
120 | #   Final head that produces the yolov4 output
121 | #---------------------------------------------------#
122 | def yolo_head(filters_list, in_filters):
123 |     m = nn.Sequential(
124 |         conv_dw(in_filters, filters_list[0]),
125 | 
126 |         nn.Conv2d(filters_list[0], filters_list[1], 1),
127 |     )
128 |     return m
129 | 
130 | 
131 | #---------------------------------------------------#
132 | #   yolo_body
133 | #---------------------------------------------------#
134 | class YoloBody(nn.Module):
135 |     def __init__(self, num_anchors, num_classes, backbone="mobilenetv2", pretrained=False):
136 |         super(YoloBody, self).__init__()
137 |         #---------------------------------------------------#
138 |         #   Build the mobilenet backbone and obtain three effective feature layers.
139 |         #---------------------------------------------------#
140 |         if backbone == "mobilenetv1":
141 |             #---------------------------------------------------#
142 |             #   52,52,256;26,26,512;13,13,1024
143 |             #---------------------------------------------------#
144 |             self.backbone = MobileNetV1(pretrained=pretrained)
145 |             in_filters = [256,512,1024]
146 |         elif backbone == "mobilenetv2":
147 |             #---------------------------------------------------#
148 |             #   52,52,32;26,26,96;13,13,320
149 |             #---------------------------------------------------#
150 |             self.backbone = MobileNetV2(pretrained=pretrained)
151 |             in_filters = [32,96,320]
152 |         elif backbone == "mobilenetv3":
153 |             #---------------------------------------------------#
154 |             #   52,52,40;26,26,112;13,13,160
155 |             #---------------------------------------------------#
156 |             self.backbone = MobileNetV3(pretrained=pretrained)
157 |             in_filters = [40,112,160]
158 |         else:
159 |             raise ValueError('Unsupported backbone `{}`; use mobilenetv1, mobilenetv2 or mobilenetv3.'.format(backbone))
160 | 
161 |         self.conv1 = make_three_conv([512, 1024], in_filters[2])
162 |         self.SPP = SpatialPyramidPooling()
163 |         self.conv2 = make_three_conv([512, 1024], 2048)
164 | 
165 |         self.upsample1 = Upsample(512, 256)
166 |         self.conv_for_P4 = conv2d(in_filters[1], 256,1)
167 |         self.make_five_conv1 = make_five_conv([256, 512], 512)
168 | 
169 |         self.upsample2 = Upsample(256, 128)
170 |         self.conv_for_P3 = conv2d(in_filters[0], 128,1)
171 |         self.make_five_conv2 = make_five_conv([128, 256], 256)
172 | 
173 |         # 3*(5+num_classes) = 3*(5+20) = 3*(4+1+20)=75
174 | 
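        # Worked example (added for illustration; the class count is an assumption,
        # not part of the original source): with the usual 3 anchors per scale, the
        # head below outputs num_anchors * (5 + num_classes) channels, where
        # 5 = 4 box offsets (tx, ty, tw, th) + 1 objectness score. For a 2-class
        # task such as mask / no-mask that is 3 * (5 + 2) = 21; for VOC's 20
        # classes it is the 3 * (5 + 20) = 75 quoted in the comment above.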
final_out_filter2 = num_anchors * (5 + num_classes) 175 | self.yolo_head3 = yolo_head([256, final_out_filter2],128) 176 | 177 | self.down_sample1 = conv_dw(128, 256,stride=2) 178 | self.make_five_conv3 = make_five_conv([256, 512],512) 179 | 180 | # 3*(5+num_classes) = 3*(5+20) = 3*(4+1+20)=75 181 | final_out_filter1 = num_anchors * (5 + num_classes) 182 | self.yolo_head2 = yolo_head([512, final_out_filter1], 256) 183 | 184 | self.down_sample2 = conv_dw(256, 512,stride=2) 185 | self.make_five_conv4 = make_five_conv([512, 1024], 1024) 186 | 187 | # 3*(5+num_classes) = 3*(5+20) = 3*(4+1+20)=75 188 | final_out_filter0 = num_anchors * (5 + num_classes) 189 | self.yolo_head1 = yolo_head([1024, final_out_filter0], 512) 190 | 191 | 192 | def forward(self, x): 193 | # backbone 194 | x2, x1, x0 = self.backbone(x) 195 | 196 | P5 = self.conv1(x0) 197 | P5 = self.SPP(P5) 198 | P5 = self.conv2(P5) 199 | 200 | P5_upsample = self.upsample1(P5) 201 | P4 = self.conv_for_P4(x1) 202 | P4 = torch.cat([P4,P5_upsample],axis=1) 203 | P4 = self.make_five_conv1(P4) 204 | 205 | P4_upsample = self.upsample2(P4) 206 | P3 = self.conv_for_P3(x2) 207 | P3 = torch.cat([P3,P4_upsample],axis=1) 208 | P3 = self.make_five_conv2(P3) 209 | 210 | P3_downsample = self.down_sample1(P3) 211 | P4 = torch.cat([P3_downsample,P4],axis=1) 212 | P4 = self.make_five_conv3(P4) 213 | 214 | P4_downsample = self.down_sample2(P4) 215 | P5 = torch.cat([P4_downsample,P5],axis=1) 216 | P5 = self.make_five_conv4(P5) 217 | 218 | out2 = self.yolo_head3(P3) 219 | out1 = self.yolo_head2(P4) 220 | out0 = self.yolo_head1(P5) 221 | 222 | return out0, out1, out2 223 | 224 | -------------------------------------------------------------------------------- /nets/yolo_training.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from random import shuffle 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | import torch.nn.functional as F 8 | from matplotlib.colors import rgb_to_hsv, hsv_to_rgb 9 | from PIL import Image 10 | from utils.utils import bbox_iou, merge_bboxes 11 | 12 | def jaccard(_box_a, _box_b): 13 | b1_x1, b1_x2 = _box_a[:, 0] - _box_a[:, 2] / 2, _box_a[:, 0] + _box_a[:, 2] / 2 14 | b1_y1, b1_y2 = _box_a[:, 1] - _box_a[:, 3] / 2, _box_a[:, 1] + _box_a[:, 3] / 2 15 | b2_x1, b2_x2 = _box_b[:, 0] - _box_b[:, 2] / 2, _box_b[:, 0] + _box_b[:, 2] / 2 16 | b2_y1, b2_y2 = _box_b[:, 1] - _box_b[:, 3] / 2, _box_b[:, 1] + _box_b[:, 3] / 2 17 | box_a = torch.zeros_like(_box_a) 18 | box_b = torch.zeros_like(_box_b) 19 | box_a[:, 0], box_a[:, 1], box_a[:, 2], box_a[:, 3] = b1_x1, b1_y1, b1_x2, b1_y2 20 | box_b[:, 0], box_b[:, 1], box_b[:, 2], box_b[:, 3] = b2_x1, b2_y1, b2_x2, b2_y2 21 | A = box_a.size(0) 22 | B = box_b.size(0) 23 | max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), 24 | box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) 25 | min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), 26 | box_b[:, :2].unsqueeze(0).expand(A, B, 2)) 27 | inter = torch.clamp((max_xy - min_xy), min=0) 28 | 29 | inter = inter[:, :, 0] * inter[:, :, 1] 30 | # 计算先验框和真实框各自的面积 31 | area_a = ((box_a[:, 2]-box_a[:, 0]) * 32 | (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] 33 | area_b = ((box_b[:, 2]-box_b[:, 0]) * 34 | (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] 35 | # 求IOU 36 | union = area_a + area_b - inter 37 | return inter / union # [A,B] 38 | 39 | #---------------------------------------------------# 40 | # 平滑标签 41 | 
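#   A quick numeric sketch (added for illustration; the values are made up): with
#   num_classes = 2 and label_smoothing = 0.05, a one-hot target [0, 1] becomes
#       [0, 1] * (1 - 0.05) + 0.05 / 2 = [0.025, 0.975]
#   so the classifier is never pushed toward fully saturated 0/1 outputs.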
#---------------------------------------------------# 42 | def smooth_labels(y_true, label_smoothing,num_classes): 43 | return y_true * (1.0 - label_smoothing) + label_smoothing / num_classes 44 | 45 | def box_ciou(b1, b2): 46 | """ 47 | 输入为: 48 | ---------- 49 | b1: tensor, shape=(batch, feat_w, feat_h, anchor_num, 4), xywh 50 | b2: tensor, shape=(batch, feat_w, feat_h, anchor_num, 4), xywh 51 | 52 | 返回为: 53 | ------- 54 | ciou: tensor, shape=(batch, feat_w, feat_h, anchor_num, 1) 55 | """ 56 | # 求出预测框左上角右下角 57 | b1_xy = b1[..., :2] 58 | b1_wh = b1[..., 2:4] 59 | b1_wh_half = b1_wh/2. 60 | b1_mins = b1_xy - b1_wh_half 61 | b1_maxes = b1_xy + b1_wh_half 62 | # 求出真实框左上角右下角 63 | b2_xy = b2[..., :2] 64 | b2_wh = b2[..., 2:4] 65 | b2_wh_half = b2_wh/2. 66 | b2_mins = b2_xy - b2_wh_half 67 | b2_maxes = b2_xy + b2_wh_half 68 | 69 | # 求真实框和预测框所有的iou 70 | intersect_mins = torch.max(b1_mins, b2_mins) 71 | intersect_maxes = torch.min(b1_maxes, b2_maxes) 72 | intersect_wh = torch.max(intersect_maxes - intersect_mins, torch.zeros_like(intersect_maxes)) 73 | intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1] 74 | b1_area = b1_wh[..., 0] * b1_wh[..., 1] 75 | b2_area = b2_wh[..., 0] * b2_wh[..., 1] 76 | union_area = b1_area + b2_area - intersect_area 77 | iou = intersect_area / torch.clamp(union_area,min = 1e-6) 78 | 79 | # 计算中心的差距 80 | center_distance = torch.sum(torch.pow((b1_xy - b2_xy), 2), axis=-1) 81 | 82 | # 找到包裹两个框的最小框的左上角和右下角 83 | enclose_mins = torch.min(b1_mins, b2_mins) 84 | enclose_maxes = torch.max(b1_maxes, b2_maxes) 85 | enclose_wh = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(intersect_maxes)) 86 | # 计算对角线距离 87 | enclose_diagonal = torch.sum(torch.pow(enclose_wh,2), axis=-1) 88 | ciou = iou - 1.0 * (center_distance) / torch.clamp(enclose_diagonal,min = 1e-6) 89 | 90 | v = (4 / (math.pi ** 2)) * torch.pow((torch.atan(b1_wh[..., 0]/torch.clamp(b1_wh[..., 1],min = 1e-6)) - torch.atan(b2_wh[..., 0]/torch.clamp(b2_wh[..., 1],min = 1e-6))), 2) 91 | alpha = v / torch.clamp((1.0 - iou + v),min=1e-6) 92 | ciou = ciou - alpha * v 93 | return ciou 94 | 95 | def clip_by_tensor(t,t_min,t_max): 96 | t=t.float() 97 | result = (t >= t_min).float() * t + (t < t_min).float() * t_min 98 | result = (result <= t_max).float() * result + (result > t_max).float() * t_max 99 | return result 100 | 101 | def MSELoss(pred,target): 102 | return (pred-target)**2 103 | 104 | def BCELoss(pred,target): 105 | epsilon = 1e-7 106 | pred = clip_by_tensor(pred, epsilon, 1.0 - epsilon) 107 | output = -target * torch.log(pred) - (1.0 - target) * torch.log(1.0 - pred) 108 | return output 109 | 110 | class YOLOLoss(nn.Module): 111 | def __init__(self, anchors, num_classes, img_size, label_smooth=0, cuda=True, normalize=True): 112 | super(YOLOLoss, self).__init__() 113 | self.anchors = anchors 114 | self.num_anchors = len(anchors) 115 | self.num_classes = num_classes 116 | self.bbox_attrs = 5 + num_classes 117 | self.img_size = img_size 118 | self.feature_length = [img_size[0]//32,img_size[0]//16,img_size[0]//8] 119 | self.label_smooth = label_smooth 120 | 121 | self.ignore_threshold = 0.5 122 | self.lambda_conf = 1.0 123 | self.lambda_cls = 1.0 124 | self.lambda_loc = 1.0 125 | self.cuda = cuda 126 | self.normalize = normalize 127 | 128 | def forward(self, input, targets=None): 129 | #----------------------------------------------------# 130 | # input的shape为 bs, 3*(5+num_classes), 13, 13 131 | # bs, 3*(5+num_classes), 26, 26 132 | # bs, 3*(5+num_classes), 52, 52 133 | 
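        #   For example (added for illustration, assuming 3 anchors and
        #   num_classes = 2): the 13x13 head is (bs, 3*(5+2), 13, 13), i.e.
        #   (bs, 21, 13, 13), and the view/permute below rearranges it into
        #   (bs, 3, 13, 13, 7), one (tx, ty, tw, th, obj, cls_0, cls_1)
        #   vector per anchor per grid cell.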
#----------------------------------------------------#
134 | 
135 |         #-----------------------#
136 |         #   Number of images in the batch
137 |         #-----------------------#
138 |         bs = input.size(0)
139 |         #-----------------------#
140 |         #   Height of the feature map
141 |         #-----------------------#
142 |         in_h = input.size(2)
143 |         #-----------------------#
144 |         #   Width of the feature map
145 |         #-----------------------#
146 |         in_w = input.size(3)
147 | 
148 |         #-----------------------------------------------------------------------#
149 |         #   Compute the stride,
150 |         #   i.e. how many pixels of the original image each feature point covers.
151 |         #   For a 13x13 feature map, one feature point corresponds to 32 pixels of the original image.
152 |         #   For a 26x26 feature map, one feature point corresponds to 16 pixels of the original image.
153 |         #   For a 52x52 feature map, one feature point corresponds to 8 pixels of the original image.
154 |         #   stride_h = stride_w = 32, 16, 8
155 |         #-----------------------------------------------------------------------#
156 |         stride_h = self.img_size[1] / in_h
157 |         stride_w = self.img_size[0] / in_w
158 | 
159 | 
160 |         #-------------------------------------------------#
161 |         #   The scaled_anchors obtained here are sized relative to the feature map
162 |         #-------------------------------------------------#
163 |         scaled_anchors = [(a_w / stride_w, a_h / stride_h) for a_w, a_h in self.anchors]
164 | 
165 |         #-----------------------------------------------#
166 |         #   There are three inputs in total; their shapes are
167 |         #   batch_size, 3, 13, 13, 5 + num_classes
168 |         #   batch_size, 3, 26, 26, 5 + num_classes
169 |         #   batch_size, 3, 52, 52, 5 + num_classes
170 |         #-----------------------------------------------#
171 |         prediction = input.view(bs, int(self.num_anchors/3),
172 |                                 self.bbox_attrs, in_h, in_w).permute(0, 1, 3, 4, 2).contiguous()
173 | 
174 |         # Objectness confidence: whether an object is present
175 |         conf = torch.sigmoid(prediction[..., 4])
176 |         # Class confidence
177 |         pred_cls = torch.sigmoid(prediction[..., 5:])
178 | 
179 |         #---------------------------------------------------------------#
180 |         #   Find which anchor boxes contain objects by computing
181 |         #   the IoU between ground-truth boxes and anchors:
182 |         #   mask        batch_size, 3, in_h, in_w      feature points with a target
183 |         #   noobj_mask  batch_size, 3, in_h, in_w      feature points without a target
184 |         #   t_box       batch_size, 3, in_h, in_w, 4   ground-truth centers, widths and heights
185 |         #   tconf       batch_size, 3, in_h, in_w      ground-truth confidence
186 |         #   tcls        batch_size, 3, in_h, in_w, num_classes   ground-truth classes
187 |         #----------------------------------------------------------------#
188 |         mask, noobj_mask, t_box, tconf, tcls, box_loss_scale_x, box_loss_scale_y = self.get_target(targets, scaled_anchors, in_w, in_h, self.ignore_threshold)
189 | 
190 |         #---------------------------------------------------------------#
191 |         #   Decode the predictions and measure their overlap with the ground truth.
192 |         #   Predictions that overlap too much are ignored, because those feature
193 |         #   points are already fairly accurate and are unsuitable as negative samples.
194 |         #----------------------------------------------------------------#
195 |         noobj_mask, pred_boxes_for_ciou = self.get_ignore(prediction, targets, scaled_anchors, in_w, in_h, noobj_mask)
196 | 
197 |         if self.cuda:
198 |             mask, noobj_mask = mask.cuda(), noobj_mask.cuda()
199 |             box_loss_scale_x, box_loss_scale_y = box_loss_scale_x.cuda(), box_loss_scale_y.cuda()
200 |             tconf, tcls = tconf.cuda(), tcls.cuda()
201 |             pred_boxes_for_ciou = pred_boxes_for_ciou.cuda()
202 |             t_box = t_box.cuda()
203 | 
204 |         box_loss_scale = 2 - box_loss_scale_x * box_loss_scale_y
205 |         #---------------------------------------------------------------#
206 |         #   Compute the CIoU between predictions and ground truth
207 |         #----------------------------------------------------------------#
208 |         ciou = (1 - box_ciou(pred_boxes_for_ciou[mask.bool()], t_box[mask.bool()])) * box_loss_scale[mask.bool()]
209 |         loss_loc = torch.sum(ciou)
210 | 
211 |         # Confidence (objectness) loss
212 |         loss_conf = torch.sum(BCELoss(conf, mask) * mask) + \
213 |                     torch.sum(BCELoss(conf, mask) * noobj_mask)
214 | 
215 |         loss_cls = torch.sum(BCELoss(pred_cls[mask == 
1],self.label_smooth,self.num_classes))) 216 | 217 | loss = loss_conf * self.lambda_conf + loss_cls * self.lambda_cls + loss_loc * self.lambda_loc 218 | 219 | if self.normalize: 220 | num_pos = torch.sum(mask) 221 | num_pos = torch.max(num_pos, torch.ones_like(num_pos)) 222 | else: 223 | num_pos = bs/3 224 | return loss, num_pos 225 | 226 | def get_target(self, target, anchors, in_w, in_h, ignore_threshold): 227 | #-----------------------------------------------------# 228 | # 计算一共有多少张图片 229 | #-----------------------------------------------------# 230 | bs = len(target) 231 | #-------------------------------------------------------# 232 | # 获得当前特征层先验框所属的编号,方便后面对先验框筛选 233 | #-------------------------------------------------------# 234 | anchor_index = [[0,1,2],[3,4,5],[6,7,8]][self.feature_length.index(in_w)] 235 | subtract_index = [0,3,6][self.feature_length.index(in_w)] 236 | #-------------------------------------------------------# 237 | # 创建全是0或者全是1的阵列 238 | #-------------------------------------------------------# 239 | mask = torch.zeros(bs, int(self.num_anchors/3), in_h, in_w, requires_grad=False) 240 | noobj_mask = torch.ones(bs, int(self.num_anchors/3), in_h, in_w, requires_grad=False) 241 | 242 | tx = torch.zeros(bs, int(self.num_anchors/3), in_h, in_w, requires_grad=False) 243 | ty = torch.zeros(bs, int(self.num_anchors/3), in_h, in_w, requires_grad=False) 244 | tw = torch.zeros(bs, int(self.num_anchors/3), in_h, in_w, requires_grad=False) 245 | th = torch.zeros(bs, int(self.num_anchors/3), in_h, in_w, requires_grad=False) 246 | t_box = torch.zeros(bs, int(self.num_anchors/3), in_h, in_w, 4, requires_grad=False) 247 | tconf = torch.zeros(bs, int(self.num_anchors/3), in_h, in_w, requires_grad=False) 248 | tcls = torch.zeros(bs, int(self.num_anchors/3), in_h, in_w, self.num_classes, requires_grad=False) 249 | 250 | box_loss_scale_x = torch.zeros(bs, int(self.num_anchors/3), in_h, in_w, requires_grad=False) 251 | box_loss_scale_y = torch.zeros(bs, int(self.num_anchors/3), in_h, in_w, requires_grad=False) 252 | for b in range(bs): 253 | if len(target[b])==0: 254 | continue 255 | #-------------------------------------------------------# 256 | # 计算出正样本在特征层上的中心点 257 | #-------------------------------------------------------# 258 | gxs = target[b][:, 0:1] * in_w 259 | gys = target[b][:, 1:2] * in_h 260 | 261 | #-------------------------------------------------------# 262 | # 计算出正样本相对于特征层的宽高 263 | #-------------------------------------------------------# 264 | gws = target[b][:, 2:3] * in_w 265 | ghs = target[b][:, 3:4] * in_h 266 | 267 | #-------------------------------------------------------# 268 | # 计算出正样本属于特征层的哪个特征点 269 | #-------------------------------------------------------# 270 | gis = torch.floor(gxs) 271 | gjs = torch.floor(gys) 272 | 273 | #-------------------------------------------------------# 274 | # 将真实框转换一个形式 275 | # num_true_box, 4 276 | #-------------------------------------------------------# 277 | gt_box = torch.FloatTensor(torch.cat([torch.zeros_like(gws), torch.zeros_like(ghs), gws, ghs], 1)) 278 | 279 | #-------------------------------------------------------# 280 | # 将先验框转换一个形式 281 | # 9, 4 282 | #-------------------------------------------------------# 283 | anchor_shapes = torch.FloatTensor(torch.cat((torch.zeros((self.num_anchors, 2)), torch.FloatTensor(anchors)), 1)) 284 | #-------------------------------------------------------# 285 | # 计算交并比 286 | # num_true_box, 9 287 | #-------------------------------------------------------# 288 | anch_ious = 
jaccard(gt_box, anchor_shapes) 289 | 290 | #-------------------------------------------------------# 291 | # 计算重合度最大的先验框是哪个 292 | # num_true_box, 293 | #-------------------------------------------------------# 294 | best_ns = torch.argmax(anch_ious,dim=-1) 295 | for i, best_n in enumerate(best_ns): 296 | if best_n not in anchor_index: 297 | continue 298 | #-------------------------------------------------------------# 299 | # 取出各类坐标: 300 | # gi和gj代表的是真实框对应的特征点的x轴y轴坐标 301 | # gx和gy代表真实框的x轴和y轴坐标 302 | # gw和gh代表真实框的宽和高 303 | #-------------------------------------------------------------# 304 | gi = gis[i].long() 305 | gj = gjs[i].long() 306 | gx = gxs[i] 307 | gy = gys[i] 308 | gw = gws[i] 309 | gh = ghs[i] 310 | if (gj < in_h) and (gi < in_w): 311 | best_n = best_n - subtract_index 312 | #----------------------------------------# 313 | # noobj_mask代表无目标的特征点 314 | #----------------------------------------# 315 | noobj_mask[b, best_n, gj, gi] = 0 316 | #----------------------------------------# 317 | # mask代表有目标的特征点 318 | #----------------------------------------# 319 | mask[b, best_n, gj, gi] = 1 320 | #----------------------------------------# 321 | # tx、ty代表中心的真实值 322 | #----------------------------------------# 323 | tx[b, best_n, gj, gi] = gx 324 | ty[b, best_n, gj, gi] = gy 325 | #----------------------------------------# 326 | # tw、th代表宽高的真实值 327 | #----------------------------------------# 328 | tw[b, best_n, gj, gi] = gw 329 | th[b, best_n, gj, gi] = gh 330 | #----------------------------------------# 331 | # 用于获得xywh的比例 332 | # 大目标loss权重小,小目标loss权重大 333 | #----------------------------------------# 334 | box_loss_scale_x[b, best_n, gj, gi] = target[b][i, 2] 335 | box_loss_scale_y[b, best_n, gj, gi] = target[b][i, 3] 336 | #----------------------------------------# 337 | # tconf代表物体置信度 338 | #----------------------------------------# 339 | tconf[b, best_n, gj, gi] = 1 340 | #----------------------------------------# 341 | # tcls代表种类置信度 342 | #----------------------------------------# 343 | tcls[b, best_n, gj, gi, target[b][i, 4].long()] = 1 344 | else: 345 | print('Step {0} out of bound'.format(b)) 346 | print('gj: {0}, height: {1} | gi: {2}, width: {3}'.format(gj, in_h, gi, in_w)) 347 | continue 348 | t_box[...,0] = tx 349 | t_box[...,1] = ty 350 | t_box[...,2] = tw 351 | t_box[...,3] = th 352 | return mask, noobj_mask, t_box, tconf, tcls, box_loss_scale_x, box_loss_scale_y 353 | 354 | 355 | def get_ignore(self,prediction,target,scaled_anchors,in_w, in_h,noobj_mask): 356 | #-----------------------------------------------------# 357 | # 计算一共有多少张图片 358 | #-----------------------------------------------------# 359 | bs = len(target) 360 | #-------------------------------------------------------# 361 | # 获得当前特征层先验框所属的编号,方便后面对先验框筛选 362 | #-------------------------------------------------------# 363 | anchor_index = [[0,1,2],[3,4,5],[6,7,8]][self.feature_length.index(in_w)] 364 | scaled_anchors = np.array(scaled_anchors)[anchor_index] 365 | 366 | # 先验框的中心位置的调整参数 367 | x = torch.sigmoid(prediction[..., 0]) 368 | y = torch.sigmoid(prediction[..., 1]) 369 | # 先验框的宽高调整参数 370 | w = prediction[..., 2] # Width 371 | h = prediction[..., 3] # Height 372 | 373 | FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor 374 | LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor 375 | 376 | # 生成网格,先验框中心,网格左上角 377 | grid_x = torch.linspace(0, in_w - 1, in_w).repeat(in_h, 1).repeat( 378 | int(bs*self.num_anchors/3), 1, 1).view(x.shape).type(FloatTensor) 379 | grid_y = 
torch.linspace(0, in_h - 1, in_h).repeat(in_w, 1).t().repeat( 380 | int(bs*self.num_anchors/3), 1, 1).view(y.shape).type(FloatTensor) 381 | 382 | # 生成先验框的宽高 383 | anchor_w = FloatTensor(scaled_anchors).index_select(1, LongTensor([0])) 384 | anchor_h = FloatTensor(scaled_anchors).index_select(1, LongTensor([1])) 385 | 386 | anchor_w = anchor_w.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(w.shape) 387 | anchor_h = anchor_h.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(h.shape) 388 | 389 | #-------------------------------------------------------# 390 | # 计算调整后的先验框中心与宽高 391 | #-------------------------------------------------------# 392 | pred_boxes = FloatTensor(prediction[..., :4].shape) 393 | pred_boxes[..., 0] = x + grid_x 394 | pred_boxes[..., 1] = y + grid_y 395 | pred_boxes[..., 2] = torch.exp(w) * anchor_w 396 | pred_boxes[..., 3] = torch.exp(h) * anchor_h 397 | for i in range(bs): 398 | pred_boxes_for_ignore = pred_boxes[i] 399 | #-------------------------------------------------------# 400 | # 将预测结果转换一个形式 401 | # pred_boxes_for_ignore num_anchors, 4 402 | #-------------------------------------------------------# 403 | pred_boxes_for_ignore = pred_boxes_for_ignore.view(-1, 4) 404 | #-------------------------------------------------------# 405 | # 计算真实框,并把真实框转换成相对于特征层的大小 406 | # gt_box num_true_box, 4 407 | #-------------------------------------------------------# 408 | if len(target[i]) > 0: 409 | gx = target[i][:, 0:1] * in_w 410 | gy = target[i][:, 1:2] * in_h 411 | gw = target[i][:, 2:3] * in_w 412 | gh = target[i][:, 3:4] * in_h 413 | gt_box = torch.FloatTensor(torch.cat([gx, gy, gw, gh],-1)).type(FloatTensor) 414 | 415 | #-------------------------------------------------------# 416 | # 计算交并比 417 | # anch_ious num_true_box, num_anchors 418 | #-------------------------------------------------------# 419 | anch_ious = jaccard(gt_box, pred_boxes_for_ignore) 420 | #-------------------------------------------------------# 421 | # 每个先验框对应真实框的最大重合度 422 | # anch_ious_max num_anchors 423 | #-------------------------------------------------------# 424 | anch_ious_max, _ = torch.max(anch_ious,dim=0) 425 | anch_ious_max = anch_ious_max.view(pred_boxes[i].size()[:3]) 426 | noobj_mask[i][anch_ious_max>self.ignore_threshold] = 0 427 | return noobj_mask, pred_boxes 428 | 429 | 430 | def rand(a=0, b=1): 431 | return np.random.rand()*(b-a) + a 432 | 433 | 434 | class Generator(object): 435 | def __init__(self,batch_size, 436 | train_lines, image_size, 437 | ): 438 | 439 | self.batch_size = batch_size 440 | self.train_lines = train_lines 441 | self.train_batches = len(train_lines) 442 | self.image_size = image_size 443 | 444 | def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True): 445 | '''r实时数据增强的随机预处理''' 446 | line = annotation_line.split() 447 | image = Image.open(line[0]) 448 | iw, ih = image.size 449 | h, w = input_shape 450 | box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]]) 451 | 452 | if not random: 453 | scale = min(w/iw, h/ih) 454 | nw = int(iw*scale) 455 | nh = int(ih*scale) 456 | dx = (w-nw)//2 457 | dy = (h-nh)//2 458 | 459 | image = image.resize((nw,nh), Image.BICUBIC) 460 | new_image = Image.new('RGB', (w,h), (128,128,128)) 461 | new_image.paste(image, (dx, dy)) 462 | image_data = np.array(new_image, np.float32) 463 | 464 | # 调整目标框坐标 465 | box_data = np.zeros((len(box), 5)) 466 | if len(box) > 0: 467 | np.random.shuffle(box) 468 | box[:, [0, 2]] = box[:, [0, 2]] * nw / iw + dx 469 | box[:, [1, 3]] = box[:, 
[1, 3]] * nh / ih + dy 470 | box[:, 0:2][box[:, 0:2] < 0] = 0 471 | box[:, 2][box[:, 2] > w] = w 472 | box[:, 3][box[:, 3] > h] = h 473 | box_w = box[:, 2] - box[:, 0] 474 | box_h = box[:, 3] - box[:, 1] 475 | box = box[np.logical_and(box_w > 1, box_h > 1)] # 保留有效框 476 | box_data = np.zeros((len(box), 5)) 477 | box_data[:len(box)] = box 478 | 479 | return image_data, box_data 480 | 481 | # resize image 482 | new_ar = w/h * rand(1-jitter,1+jitter)/rand(1-jitter,1+jitter) 483 | scale = rand(.25, 2) 484 | if new_ar < 1: 485 | nh = int(scale*h) 486 | nw = int(nh*new_ar) 487 | else: 488 | nw = int(scale*w) 489 | nh = int(nw/new_ar) 490 | image = image.resize((nw,nh), Image.BICUBIC) 491 | 492 | # place image 493 | dx = int(rand(0, w-nw)) 494 | dy = int(rand(0, h-nh)) 495 | new_image = Image.new('RGB', (w,h), (128,128,128)) 496 | new_image.paste(image, (dx, dy)) 497 | image = new_image 498 | 499 | # flip image or not 500 | flip = rand()<.5 501 | if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT) 502 | 503 | # distort image 504 | hue = rand(-hue, hue) 505 | sat = rand(1, sat) if rand()<.5 else 1/rand(1, sat) 506 | val = rand(1, val) if rand()<.5 else 1/rand(1, val) 507 | x = cv2.cvtColor(np.array(image,np.float32)/255, cv2.COLOR_RGB2HSV) 508 | x[..., 0] += hue*360 509 | x[..., 0][x[..., 0]>1] -= 1 510 | x[..., 0][x[..., 0]<0] += 1 511 | x[..., 1] *= sat 512 | x[..., 2] *= val 513 | x[x[:,:, 0]>360, 0] = 360 514 | x[:, :, 1:][x[:, :, 1:]>1] = 1 515 | x[x<0] = 0 516 | image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB)*255 517 | 518 | # correct boxes 519 | box_data = np.zeros((len(box),5)) 520 | if len(box)>0: 521 | np.random.shuffle(box) 522 | box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx 523 | box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy 524 | if flip: box[:, [0,2]] = w - box[:, [2,0]] 525 | box[:, 0:2][box[:, 0:2]<0] = 0 526 | box[:, 2][box[:, 2]>w] = w 527 | box[:, 3][box[:, 3]>h] = h 528 | box_w = box[:, 2] - box[:, 0] 529 | box_h = box[:, 3] - box[:, 1] 530 | box = box[np.logical_and(box_w>1, box_h>1)] # discard invalid box 531 | box_data = np.zeros((len(box),5)) 532 | box_data[:len(box)] = box 533 | 534 | return image_data, box_data 535 | 536 | def get_random_data_with_Mosaic(self, annotation_line, input_shape, hue=.1, sat=1.5, val=1.5): 537 | '''random preprocessing for real-time data augmentation''' 538 | h, w = input_shape 539 | min_offset_x = 0.3 540 | min_offset_y = 0.3 541 | scale_low = 1-min(min_offset_x,min_offset_y) 542 | scale_high = scale_low+0.2 543 | 544 | image_datas = [] 545 | box_datas = [] 546 | index = 0 547 | 548 | place_x = [0,0,int(w*min_offset_x),int(w*min_offset_x)] 549 | place_y = [0,int(h*min_offset_y),int(h*min_offset_y),0] 550 | for line in annotation_line: 551 | # 每一行进行分割 552 | line_content = line.split() 553 | # 打开图片 554 | image = Image.open(line_content[0]) 555 | image = image.convert("RGB") 556 | # 图片的大小 557 | iw, ih = image.size 558 | # 保存框的位置 559 | box = np.array([np.array(list(map(int,box.split(',')))) for box in line_content[1:]]) 560 | 561 | # 是否翻转图片 562 | flip = rand()<.5 563 | if flip and len(box)>0: 564 | image = image.transpose(Image.FLIP_LEFT_RIGHT) 565 | box[:, [0,2]] = iw - box[:, [2,0]] 566 | 567 | # 对输入进来的图片进行缩放 568 | new_ar = w/h 569 | scale = rand(scale_low, scale_high) 570 | if new_ar < 1: 571 | nh = int(scale*h) 572 | nw = int(nh*new_ar) 573 | else: 574 | nw = int(scale*w) 575 | nh = int(nw/new_ar) 576 | image = image.resize((nw,nh), Image.BICUBIC) 577 | 578 | # 进行色域变换 579 | hue = rand(-hue, hue) 580 | sat = rand(1, sat) if rand()<.5 else 1/rand(1, 
sat) 581 | val = rand(1, val) if rand()<.5 else 1/rand(1, val) 582 | x = cv2.cvtColor(np.array(image,np.float32)/255, cv2.COLOR_RGB2HSV) 583 | x[..., 0] += hue*360 584 | x[..., 0][x[..., 0]>1] -= 1 585 | x[..., 0][x[..., 0]<0] += 1 586 | x[..., 1] *= sat 587 | x[..., 2] *= val 588 | x[x[:,:, 0]>360, 0] = 360 589 | x[:, :, 1:][x[:, :, 1:]>1] = 1 590 | x[x<0] = 0 591 | image = cv2.cvtColor(x, cv2.COLOR_HSV2RGB) # numpy array, 0 to 1 592 | 593 | image = Image.fromarray((image*255).astype(np.uint8)) 594 | # 将图片进行放置,分别对应四张分割图片的位置 595 | dx = place_x[index] 596 | dy = place_y[index] 597 | new_image = Image.new('RGB', (w,h), (128,128,128)) 598 | new_image.paste(image, (dx, dy)) 599 | image_data = np.array(new_image) 600 | 601 | 602 | index = index + 1 603 | box_data = [] 604 | # 对box进行重新处理 605 | if len(box)>0: 606 | np.random.shuffle(box) 607 | box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx 608 | box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy 609 | box[:, 0:2][box[:, 0:2]<0] = 0 610 | box[:, 2][box[:, 2]>w] = w 611 | box[:, 3][box[:, 3]>h] = h 612 | box_w = box[:, 2] - box[:, 0] 613 | box_h = box[:, 3] - box[:, 1] 614 | box = box[np.logical_and(box_w>1, box_h>1)] 615 | box_data = np.zeros((len(box),5)) 616 | box_data[:len(box)] = box 617 | 618 | image_datas.append(image_data) 619 | box_datas.append(box_data) 620 | 621 | # 将图片分割,放在一起 622 | cutx = np.random.randint(int(w*min_offset_x), int(w*(1 - min_offset_x))) 623 | cuty = np.random.randint(int(h*min_offset_y), int(h*(1 - min_offset_y))) 624 | 625 | new_image = np.zeros([h,w,3]) 626 | new_image[:cuty, :cutx, :] = image_datas[0][:cuty, :cutx, :] 627 | new_image[cuty:, :cutx, :] = image_datas[1][cuty:, :cutx, :] 628 | new_image[cuty:, cutx:, :] = image_datas[2][cuty:, cutx:, :] 629 | new_image[:cuty, cutx:, :] = image_datas[3][:cuty, cutx:, :] 630 | 631 | # 对框进行进一步的处理 632 | new_boxes = np.array(merge_bboxes(box_datas, cutx, cuty)) 633 | 634 | if len(new_boxes) == 0: 635 | return new_image, [] 636 | if (new_boxes[:,:4]>0).any(): 637 | return new_image, new_boxes 638 | else: 639 | return new_image, [] 640 | 641 | def generate(self, train = True, mosaic = True): 642 | while True: 643 | shuffle(self.train_lines) 644 | lines = self.train_lines 645 | inputs = [] 646 | targets = [] 647 | flag = True 648 | n = len(lines) 649 | for i in range(len(lines)): 650 | if mosaic == True: 651 | if flag and (i+4) < n: 652 | img,y = self.get_random_data_with_Mosaic(lines[i:i+4], self.image_size[0:2]) 653 | i = (i+4) % n 654 | else: 655 | img,y = self.get_random_data(lines[i], self.image_size[0:2], random=train) 656 | i = (i+1) % n 657 | flag = bool(1-flag) 658 | else: 659 | img,y = self.get_random_data(lines[i], self.image_size[0:2], random=train) 660 | i = (i+1) % n 661 | 662 | if len(y)!=0: 663 | boxes = np.array(y[:,:4],dtype=np.float32) 664 | boxes[:,0] = boxes[:,0]/self.image_size[1] 665 | boxes[:,1] = boxes[:,1]/self.image_size[0] 666 | boxes[:,2] = boxes[:,2]/self.image_size[1] 667 | boxes[:,3] = boxes[:,3]/self.image_size[0] 668 | 669 | boxes = np.maximum(np.minimum(boxes,1),0) 670 | boxes[:,2] = boxes[:,2] - boxes[:,0] 671 | boxes[:,3] = boxes[:,3] - boxes[:,1] 672 | 673 | boxes[:,0] = boxes[:,0] + boxes[:,2]/2 674 | boxes[:,1] = boxes[:,1] + boxes[:,3]/2 675 | y = np.concatenate([boxes,y[:,-1:]],axis=-1) 676 | 677 | img = np.array(img,dtype = np.float32) 678 | 679 | inputs.append(np.transpose(img/255.0,(2,0,1))) 680 | targets.append(np.array(y,dtype = np.float32)) 681 | if len(targets) == self.batch_size: 682 | tmp_inp = np.array(inputs) 683 | tmp_targets = targets 684 | 
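                    # Note added for clarity: tmp_inp is an array of shape
                    # (batch_size, 3, H, W) with pixel values scaled to [0, 1], and
                    # tmp_targets is a list of per-image arrays of normalized
                    # (cx, cy, w, h, class) rows, which is the format
                    # YOLOLoss.get_target consumes.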
inputs = [] 685 | targets = [] 686 | yield tmp_inp, tmp_targets 687 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy==1.2.1 2 | numpy==1.17.0 3 | matplotlib==3.1.2 4 | opencv_python==4.1.2.30 5 | torch==1.2.0 6 | torchvision==0.4.0 7 | tqdm==4.60.0 8 | Pillow==8.2.0 9 | h5py==2.10.0 10 | qdarkstyle==2.8.1 11 | pygame==2.2.0 12 | PyQt5==5.15.7 -------------------------------------------------------------------------------- /utils/dataloader.py: -------------------------------------------------------------------------------- 1 | from random import shuffle 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | import torch.nn.functional as F 7 | from PIL import Image 8 | from torch.autograd import Variable 9 | from torch.utils.data import DataLoader 10 | from torch.utils.data.dataset import Dataset 11 | from utils.utils import bbox_iou, merge_bboxes 12 | from matplotlib.colors import rgb_to_hsv, hsv_to_rgb 13 | from nets.yolo_training import Generator 14 | import cv2 15 | 16 | class YoloDataset(Dataset): 17 | def __init__(self, train_lines, image_size, mosaic=True, is_train=True): 18 | super(YoloDataset, self).__init__() 19 | 20 | self.train_lines = train_lines 21 | self.train_batches = len(train_lines) 22 | self.image_size = image_size 23 | self.mosaic = mosaic 24 | self.flag = True 25 | self.is_train = is_train 26 | 27 | def __len__(self): 28 | return self.train_batches 29 | 30 | def rand(self, a=0, b=1): 31 | return np.random.rand() * (b - a) + a 32 | 33 | def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True): 34 | """实时数据增强的随机预处理""" 35 | line = annotation_line.split() 36 | image = Image.open(line[0]) 37 | iw, ih = image.size 38 | h, w = input_shape 39 | box = np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]]) 40 | 41 | if not random: 42 | scale = min(w/iw, h/ih) 43 | nw = int(iw*scale) 44 | nh = int(ih*scale) 45 | dx = (w-nw)//2 46 | dy = (h-nh)//2 47 | 48 | image = image.resize((nw,nh), Image.BICUBIC) 49 | new_image = Image.new('RGB', (w,h), (128,128,128)) 50 | new_image.paste(image, (dx, dy)) 51 | image_data = np.array(new_image, np.float32) 52 | 53 | # 调整目标框坐标 54 | box_data = np.zeros((len(box), 5)) 55 | if len(box) > 0: 56 | np.random.shuffle(box) 57 | box[:, [0, 2]] = box[:, [0, 2]] * nw / iw + dx 58 | box[:, [1, 3]] = box[:, [1, 3]] * nh / ih + dy 59 | box[:, 0:2][box[:, 0:2] < 0] = 0 60 | box[:, 2][box[:, 2] > w] = w 61 | box[:, 3][box[:, 3] > h] = h 62 | box_w = box[:, 2] - box[:, 0] 63 | box_h = box[:, 3] - box[:, 1] 64 | box = box[np.logical_and(box_w > 1, box_h > 1)] # 保留有效框 65 | box_data = np.zeros((len(box), 5)) 66 | box_data[:len(box)] = box 67 | 68 | return image_data, box_data 69 | 70 | # 调整图片大小 71 | new_ar = w / h * self.rand(1 - jitter, 1 + jitter) / self.rand(1 - jitter, 1 + jitter) 72 | scale = self.rand(.25, 2) 73 | if new_ar < 1: 74 | nh = int(scale * h) 75 | nw = int(nh * new_ar) 76 | else: 77 | nw = int(scale * w) 78 | nh = int(nw / new_ar) 79 | image = image.resize((nw, nh), Image.BICUBIC) 80 | 81 | # 放置图片 82 | dx = int(self.rand(0, w - nw)) 83 | dy = int(self.rand(0, h - nh)) 84 | new_image = Image.new('RGB', (w, h), 85 | (np.random.randint(0, 255), np.random.randint(0, 255), np.random.randint(0, 255))) 86 | new_image.paste(image, (dx, dy)) 87 | image = new_image 88 | 89 | # 是否翻转图片 90 | flip = self.rand() < .5 91 | if 
flip: 92 | image = image.transpose(Image.FLIP_LEFT_RIGHT) 93 | 94 | # 色域变换 95 | hue = self.rand(-hue, hue) 96 | sat = self.rand(1, sat) if self.rand() < .5 else 1 / self.rand(1, sat) 97 | val = self.rand(1, val) if self.rand() < .5 else 1 / self.rand(1, val) 98 | x = cv2.cvtColor(np.array(image,np.float32)/255, cv2.COLOR_RGB2HSV) 99 | x[..., 0] += hue*360 100 | x[..., 0][x[..., 0]>1] -= 1 101 | x[..., 0][x[..., 0]<0] += 1 102 | x[..., 1] *= sat 103 | x[..., 2] *= val 104 | x[x[:,:, 0]>360, 0] = 360 105 | x[:, :, 1:][x[:, :, 1:]>1] = 1 106 | x[x<0] = 0 107 | image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB)*255 108 | 109 | # 调整目标框坐标 110 | box_data = np.zeros((len(box), 5)) 111 | if len(box) > 0: 112 | np.random.shuffle(box) 113 | box[:, [0, 2]] = box[:, [0, 2]] * nw / iw + dx 114 | box[:, [1, 3]] = box[:, [1, 3]] * nh / ih + dy 115 | if flip: 116 | box[:, [0, 2]] = w - box[:, [2, 0]] 117 | box[:, 0:2][box[:, 0:2] < 0] = 0 118 | box[:, 2][box[:, 2] > w] = w 119 | box[:, 3][box[:, 3] > h] = h 120 | box_w = box[:, 2] - box[:, 0] 121 | box_h = box[:, 3] - box[:, 1] 122 | box = box[np.logical_and(box_w > 1, box_h > 1)] # 保留有效框 123 | box_data = np.zeros((len(box), 5)) 124 | box_data[:len(box)] = box 125 | 126 | return image_data, box_data 127 | 128 | def get_random_data_with_Mosaic(self, annotation_line, input_shape, hue=.1, sat=1.5, val=1.5): 129 | h, w = input_shape 130 | min_offset_x = 0.3 131 | min_offset_y = 0.3 132 | scale_low = 1 - min(min_offset_x, min_offset_y) 133 | scale_high = scale_low + 0.2 134 | 135 | image_datas = [] 136 | box_datas = [] 137 | index = 0 138 | 139 | place_x = [0, 0, int(w * min_offset_x), int(w * min_offset_x)] 140 | place_y = [0, int(h * min_offset_y), int(h * min_offset_y), 0] 141 | for line in annotation_line: 142 | # 每一行进行分割 143 | line_content = line.split() 144 | # 打开图片 145 | image = Image.open(line_content[0]) 146 | image = image.convert("RGB") 147 | # 图片的大小 148 | iw, ih = image.size 149 | # 保存框的位置 150 | box = np.array([np.array(list(map(int, box.split(',')))) for box in line_content[1:]]) 151 | 152 | # 是否翻转图片 153 | flip = self.rand() < .5 154 | if flip and len(box) > 0: 155 | image = image.transpose(Image.FLIP_LEFT_RIGHT) 156 | box[:, [0, 2]] = iw - box[:, [2, 0]] 157 | 158 | # 对输入进来的图片进行缩放 159 | new_ar = w / h 160 | scale = self.rand(scale_low, scale_high) 161 | if new_ar < 1: 162 | nh = int(scale * h) 163 | nw = int(nh * new_ar) 164 | else: 165 | nw = int(scale * w) 166 | nh = int(nw / new_ar) 167 | image = image.resize((nw, nh), Image.BICUBIC) 168 | 169 | # 进行色域变换 170 | hue = self.rand(-hue, hue) 171 | sat = self.rand(1, sat) if self.rand() < .5 else 1 / self.rand(1, sat) 172 | val = self.rand(1, val) if self.rand() < .5 else 1 / self.rand(1, val) 173 | x = cv2.cvtColor(np.array(image,np.float32)/255, cv2.COLOR_RGB2HSV) 174 | x[..., 0] += hue*360 175 | x[..., 0][x[..., 0]>1] -= 1 176 | x[..., 0][x[..., 0]<0] += 1 177 | x[..., 1] *= sat 178 | x[..., 2] *= val 179 | x[x[:,:, 0]>360, 0] = 360 180 | x[:, :, 1:][x[:, :, 1:]>1] = 1 181 | x[x<0] = 0 182 | image = cv2.cvtColor(x, cv2.COLOR_HSV2RGB) # numpy array, 0 to 1 183 | 184 | image = Image.fromarray((image * 255).astype(np.uint8)) 185 | # 将图片进行放置,分别对应四张分割图片的位置 186 | dx = place_x[index] 187 | dy = place_y[index] 188 | new_image = Image.new('RGB', (w, h), 189 | (np.random.randint(0, 255), np.random.randint(0, 255), np.random.randint(0, 255))) 190 | new_image.paste(image, (dx, dy)) 191 | image_data = np.array(new_image) 192 | 193 | index = index + 1 194 | box_data = [] 195 | # 对box进行重新处理 196 | if len(box) > 0: 
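                # Note added for clarity: the block below maps the annotation boxes
                # onto the mosaic canvas. The x coordinates are rescaled by nw/iw
                # and shifted by dx, the y coordinates by nh/ih and dy; boxes are
                # then clipped to the canvas, and any box whose width or height
                # collapses to 1 pixel or less is dropped.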
197 | np.random.shuffle(box) 198 | box[:, [0, 2]] = box[:, [0, 2]] * nw / iw + dx 199 | box[:, [1, 3]] = box[:, [1, 3]] * nh / ih + dy 200 | box[:, 0:2][box[:, 0:2] < 0] = 0 201 | box[:, 2][box[:, 2] > w] = w 202 | box[:, 3][box[:, 3] > h] = h 203 | box_w = box[:, 2] - box[:, 0] 204 | box_h = box[:, 3] - box[:, 1] 205 | box = box[np.logical_and(box_w > 1, box_h > 1)] 206 | box_data = np.zeros((len(box), 5)) 207 | box_data[:len(box)] = box 208 | 209 | image_datas.append(image_data) 210 | box_datas.append(box_data) 211 | 212 | # 将图片分割,放在一起 213 | cutx = np.random.randint(int(w * min_offset_x), int(w * (1 - min_offset_x))) 214 | cuty = np.random.randint(int(h * min_offset_y), int(h * (1 - min_offset_y))) 215 | 216 | new_image = np.zeros([h, w, 3]) 217 | new_image[:cuty, :cutx, :] = image_datas[0][:cuty, :cutx, :] 218 | new_image[cuty:, :cutx, :] = image_datas[1][cuty:, :cutx, :] 219 | new_image[cuty:, cutx:, :] = image_datas[2][cuty:, cutx:, :] 220 | new_image[:cuty, cutx:, :] = image_datas[3][:cuty, cutx:, :] 221 | 222 | # 对框进行进一步的处理 223 | new_boxes = np.array(merge_bboxes(box_datas, cutx, cuty)) 224 | 225 | return new_image, new_boxes 226 | 227 | def __getitem__(self, index): 228 | lines = self.train_lines 229 | n = self.train_batches 230 | index = index % n 231 | if self.mosaic: 232 | if self.flag and (index + 4) < n: 233 | img, y = self.get_random_data_with_Mosaic(lines[index:index + 4], self.image_size[0:2]) 234 | else: 235 | img, y = self.get_random_data(lines[index], self.image_size[0:2], random=self.is_train) 236 | self.flag = bool(1-self.flag) 237 | else: 238 | img, y = self.get_random_data(lines[index], self.image_size[0:2], random=self.is_train) 239 | 240 | if len(y) != 0: 241 | # 从坐标转换成0~1的百分比 242 | boxes = np.array(y[:, :4], dtype=np.float32) 243 | boxes[:, 0] = boxes[:, 0] / self.image_size[1] 244 | boxes[:, 1] = boxes[:, 1] / self.image_size[0] 245 | boxes[:, 2] = boxes[:, 2] / self.image_size[1] 246 | boxes[:, 3] = boxes[:, 3] / self.image_size[0] 247 | 248 | boxes = np.maximum(np.minimum(boxes, 1), 0) 249 | boxes[:, 2] = boxes[:, 2] - boxes[:, 0] 250 | boxes[:, 3] = boxes[:, 3] - boxes[:, 1] 251 | 252 | boxes[:, 0] = boxes[:, 0] + boxes[:, 2] / 2 253 | boxes[:, 1] = boxes[:, 1] + boxes[:, 3] / 2 254 | y = np.concatenate([boxes, y[:, -1:]], axis=-1) 255 | img = np.array(img, dtype=np.float32) 256 | 257 | tmp_inp = np.transpose(img / 255.0, (2, 0, 1)) 258 | tmp_targets = np.array(y, dtype=np.float32) 259 | return tmp_inp, tmp_targets 260 | 261 | 262 | # DataLoader中collate_fn使用 263 | def yolo_dataset_collate(batch): 264 | images = [] 265 | bboxes = [] 266 | for img, box in batch: 267 | images.append(img) 268 | bboxes.append(box) 269 | images = np.array(images) 270 | return images, bboxes 271 | 272 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import math 4 | import os 5 | import time 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from PIL import Image, ImageDraw, ImageFont 12 | from torch.autograd import Variable 13 | from torchvision.ops import nms 14 | 15 | 16 | class DecodeBox(nn.Module): 17 | def __init__(self, anchors, num_classes, img_size): 18 | super(DecodeBox, self).__init__() 19 | #-----------------------------------------------------------# 20 | # 13x13的特征层对应的anchor是[142, 110], [192, 243], [459, 401] 21 | # 
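        #   (Note added for clarity: these are presumably the values shipped in
        #   model_data/yolo_anchors.txt. The coarsest 13x13 grid is assigned the
        #   largest anchors because each of its cells covers a 32-pixel stride
        #   of the 416x416 input.)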
26x26的特征层对应的anchor是[36, 75], [76, 55], [72, 146] 22 | # 52x52的特征层对应的anchor是[12, 16], [19, 36], [40, 28] 23 | #-----------------------------------------------------------# 24 | self.anchors = anchors 25 | self.num_anchors = len(anchors) 26 | self.num_classes = num_classes 27 | self.bbox_attrs = 5 + num_classes 28 | self.img_size = img_size 29 | 30 | def forward(self, input): 31 | #-----------------------------------------------# 32 | # 输入的input一共有三个,他们的shape分别是 33 | # batch_size, 255, 13, 13 34 | # batch_size, 255, 26, 26 35 | # batch_size, 255, 52, 52 36 | #-----------------------------------------------# 37 | batch_size = input.size(0) 38 | input_height = input.size(2) 39 | input_width = input.size(3) 40 | 41 | #-----------------------------------------------# 42 | # 输入为416x416时 43 | # stride_h = stride_w = 32、16、8 44 | #-----------------------------------------------# 45 | stride_h = self.img_size[1] / input_height 46 | stride_w = self.img_size[0] / input_width 47 | #-------------------------------------------------# 48 | # 此时获得的scaled_anchors大小是相对于特征层的 49 | #-------------------------------------------------# 50 | scaled_anchors = [(anchor_width / stride_w, anchor_height / stride_h) for anchor_width, anchor_height in self.anchors] 51 | 52 | #-----------------------------------------------# 53 | # 输入的input一共有三个,他们的shape分别是 54 | # batch_size, 3, 13, 13, 85 55 | # batch_size, 3, 26, 26, 85 56 | # batch_size, 3, 52, 52, 85 57 | #-----------------------------------------------# 58 | prediction = input.view(batch_size, self.num_anchors, 59 | self.bbox_attrs, input_height, input_width).permute(0, 1, 3, 4, 2).contiguous() 60 | 61 | # 先验框的中心位置的调整参数 62 | x = torch.sigmoid(prediction[..., 0]) 63 | y = torch.sigmoid(prediction[..., 1]) 64 | # 先验框的宽高调整参数 65 | w = prediction[..., 2] 66 | h = prediction[..., 3] 67 | # 获得置信度,是否有物体 68 | conf = torch.sigmoid(prediction[..., 4]) 69 | # 种类置信度 70 | pred_cls = torch.sigmoid(prediction[..., 5:]) 71 | 72 | FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor 73 | LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor 74 | 75 | #----------------------------------------------------------# 76 | # 生成网格,先验框中心,网格左上角 77 | # batch_size,3,13,13 78 | #----------------------------------------------------------# 79 | grid_x = torch.linspace(0, input_width - 1, input_width).repeat(input_height, 1).repeat( 80 | batch_size * self.num_anchors, 1, 1).view(x.shape).type(FloatTensor) 81 | grid_y = torch.linspace(0, input_height - 1, input_height).repeat(input_width, 1).t().repeat( 82 | batch_size * self.num_anchors, 1, 1).view(y.shape).type(FloatTensor) 83 | 84 | #----------------------------------------------------------# 85 | # 按照网格格式生成先验框的宽高 86 | # batch_size,3,13,13 87 | #----------------------------------------------------------# 88 | anchor_w = FloatTensor(scaled_anchors).index_select(1, LongTensor([0])) 89 | anchor_h = FloatTensor(scaled_anchors).index_select(1, LongTensor([1])) 90 | anchor_w = anchor_w.repeat(batch_size, 1).repeat(1, 1, input_height * input_width).view(w.shape) 91 | anchor_h = anchor_h.repeat(batch_size, 1).repeat(1, 1, input_height * input_width).view(h.shape) 92 | 93 | #----------------------------------------------------------# 94 | # 利用预测结果对先验框进行调整 95 | # 首先调整先验框的中心,从先验框中心向右下角偏移 96 | # 再调整先验框的宽高。 97 | #----------------------------------------------------------# 98 | pred_boxes = FloatTensor(prediction[..., :4].shape) 99 | pred_boxes[..., 0] = x.data + grid_x 100 | pred_boxes[..., 1] = y.data + grid_y 101 | 
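        # Decode equations, noted for clarity: bx = sigmoid(tx) + cx and
        # by = sigmoid(ty) + cy (x and y were already passed through a sigmoid
        # above), while the two lines below compute bw = pw * exp(tw) and
        # bh = ph * exp(th); everything here is in feature-map units.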
pred_boxes[..., 2] = torch.exp(w.data) * anchor_w 102 | pred_boxes[..., 3] = torch.exp(h.data) * anchor_h 103 | 104 | # fig = plt.figure() 105 | # ax = fig.add_subplot(121) 106 | # if input_height==13: 107 | # plt.ylim(0,13) 108 | # plt.xlim(0,13) 109 | # elif input_height==26: 110 | # plt.ylim(0,26) 111 | # plt.xlim(0,26) 112 | # elif input_height==52: 113 | # plt.ylim(0,52) 114 | # plt.xlim(0,52) 115 | # plt.scatter(grid_x.cpu(),grid_y.cpu()) 116 | 117 | # anchor_left = grid_x - anchor_w/2 118 | # anchor_top = grid_y - anchor_h/2 119 | 120 | # rect1 = plt.Rectangle([anchor_left[0,0,5,5],anchor_top[0,0,5,5]],anchor_w[0,0,5,5],anchor_h[0,0,5,5],color="r",fill=False) 121 | # rect2 = plt.Rectangle([anchor_left[0,1,5,5],anchor_top[0,1,5,5]],anchor_w[0,1,5,5],anchor_h[0,1,5,5],color="r",fill=False) 122 | # rect3 = plt.Rectangle([anchor_left[0,2,5,5],anchor_top[0,2,5,5]],anchor_w[0,2,5,5],anchor_h[0,2,5,5],color="r",fill=False) 123 | 124 | # ax.add_patch(rect1) 125 | # ax.add_patch(rect2) 126 | # ax.add_patch(rect3) 127 | 128 | # ax = fig.add_subplot(122) 129 | # if input_height==13: 130 | # plt.ylim(0,13) 131 | # plt.xlim(0,13) 132 | # elif input_height==26: 133 | # plt.ylim(0,26) 134 | # plt.xlim(0,26) 135 | # elif input_height==52: 136 | # plt.ylim(0,52) 137 | # plt.xlim(0,52) 138 | # plt.scatter(grid_x.cpu(),grid_y.cpu()) 139 | # plt.scatter(pred_boxes[0,:,5,5,0].cpu(),pred_boxes[0,:,5,5,1].cpu(),c='r') 140 | 141 | # pre_left = pred_boxes[...,0] - pred_boxes[...,2]/2 142 | # pre_top = pred_boxes[...,1] - pred_boxes[...,3]/2 143 | 144 | # rect1 = plt.Rectangle([pre_left[0,0,5,5],pre_top[0,0,5,5]],pred_boxes[0,0,5,5,2],pred_boxes[0,0,5,5,3],color="r",fill=False) 145 | # rect2 = plt.Rectangle([pre_left[0,1,5,5],pre_top[0,1,5,5]],pred_boxes[0,1,5,5,2],pred_boxes[0,1,5,5,3],color="r",fill=False) 146 | # rect3 = plt.Rectangle([pre_left[0,2,5,5],pre_top[0,2,5,5]],pred_boxes[0,2,5,5,2],pred_boxes[0,2,5,5,3],color="r",fill=False) 147 | 148 | # ax.add_patch(rect1) 149 | # ax.add_patch(rect2) 150 | # ax.add_patch(rect3) 151 | 152 | # plt.show() 153 | 154 | #----------------------------------------------------------# 155 | # 将输出结果调整成相对于输入图像大小 156 | #----------------------------------------------------------# 157 | _scale = torch.Tensor([stride_w, stride_h] * 2).type(FloatTensor) 158 | output = torch.cat((pred_boxes.view(batch_size, -1, 4) * _scale, 159 | conf.view(batch_size, -1, 1), pred_cls.view(batch_size, -1, self.num_classes)), -1) 160 | return output.data 161 | 162 | def letterbox_image(image, size): 163 | iw, ih = image.size 164 | w, h = size 165 | scale = min(w/iw, h/ih) 166 | nw = int(iw*scale) 167 | nh = int(ih*scale) 168 | 169 | image = image.resize((nw,nh), Image.BICUBIC) 170 | new_image = Image.new('RGB', size, (128,128,128)) 171 | new_image.paste(image, ((w-nw)//2, (h-nh)//2)) 172 | return new_image 173 | 174 | def yolo_correct_boxes(top, left, bottom, right, input_shape, image_shape): 175 | new_shape = image_shape*np.min(input_shape/image_shape) 176 | 177 | offset = (input_shape-new_shape)/2./input_shape 178 | scale = input_shape/new_shape 179 | 180 | box_yx = np.concatenate(((top+bottom)/2,(left+right)/2),axis=-1)/input_shape 181 | box_hw = np.concatenate((bottom-top,right-left),axis=-1)/input_shape 182 | 183 | box_yx = (box_yx - offset) * scale 184 | box_hw *= scale 185 | 186 | box_mins = box_yx - (box_hw / 2.) 187 | box_maxes = box_yx + (box_hw / 2.) 
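    # Worked example (added for illustration): a 640x480 image letterboxed into a
    # 416x416 input uses scale = min(416/640, 416/480) = 0.65, so the resized
    # content is 416x312 with 52-pixel gray bars at the top and bottom; the offset
    # and scale arithmetic above removes those bars before the boxes are mapped
    # back to original-image coordinates below.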
188 | boxes = np.concatenate([ 189 | box_mins[:, 0:1], 190 | box_mins[:, 1:2], 191 | box_maxes[:, 0:1], 192 | box_maxes[:, 1:2] 193 | ],axis=-1) 194 | boxes *= np.concatenate([image_shape, image_shape],axis=-1) 195 | return boxes 196 | 197 | def bbox_iou(box1, box2, x1y1x2y2=True): 198 | """ 199 | 计算IOU 200 | """ 201 | if not x1y1x2y2: 202 | b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2 203 | b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2 204 | b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2 205 | b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2 206 | else: 207 | b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3] 208 | b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3] 209 | 210 | inter_rect_x1 = torch.max(b1_x1, b2_x1) 211 | inter_rect_y1 = torch.max(b1_y1, b2_y1) 212 | inter_rect_x2 = torch.min(b1_x2, b2_x2) 213 | inter_rect_y2 = torch.min(b1_y2, b2_y2) 214 | 215 | inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1 + 1, min=0) * \ 216 | torch.clamp(inter_rect_y2 - inter_rect_y1 + 1, min=0) 217 | 218 | b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1) 219 | b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1) 220 | 221 | iou = inter_area / (b1_area + b2_area - inter_area + 1e-16) 222 | 223 | return iou 224 | 225 | 226 | def non_max_suppression(prediction, num_classes, conf_thres=0.5, nms_thres=0.4): 227 | #----------------------------------------------------------# 228 | # 将预测结果的格式转换成左上角右下角的格式。 229 | # prediction [batch_size, num_anchors, 85] 230 | #----------------------------------------------------------# 231 | box_corner = prediction.new(prediction.shape) 232 | box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2 233 | box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2 234 | box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2 235 | box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2 236 | prediction[:, :, :4] = box_corner[:, :, :4] 237 | 238 | output = [None for _ in range(len(prediction))] 239 | for image_i, image_pred in enumerate(prediction): 240 | #----------------------------------------------------------# 241 | # 对种类预测部分取max。 242 | # class_conf [num_anchors, 1] 种类置信度 243 | # class_pred [num_anchors, 1] 种类 244 | #----------------------------------------------------------# 245 | class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True) 246 | 247 | #----------------------------------------------------------# 248 | # 利用置信度进行第一轮筛选 249 | #----------------------------------------------------------# 250 | conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze() 251 | 252 | #----------------------------------------------------------# 253 | # 根据置信度进行预测结果的筛选 254 | #----------------------------------------------------------# 255 | image_pred = image_pred[conf_mask] 256 | class_conf = class_conf[conf_mask] 257 | class_pred = class_pred[conf_mask] 258 | if not image_pred.size(0): 259 | continue 260 | #-------------------------------------------------------------------------# 261 | # detections [num_anchors, 7] 262 | # 7的内容为:x1, y1, x2, y2, obj_conf, class_conf, class_pred 263 | #-------------------------------------------------------------------------# 264 | detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1) 265 | 266 | #------------------------------------------# 267 | # 获得预测结果中包含的所有种类 268 | 
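        #   (Note added for clarity: the loop below runs torchvision's nms once
        #   per class, scoring each box by obj_conf * class_conf, so boxes of
        #   different classes never suppress one another.)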
def non_max_suppression(prediction, num_classes, conf_thres=0.5, nms_thres=0.4):
    #----------------------------------------------------------#
    #   Convert the predictions from (center x, center y, w, h)
    #   to top-left / bottom-right corner format.
    #   prediction  [batch_size, num_anchors, 85]
    #----------------------------------------------------------#
    box_corner = prediction.new(prediction.shape)
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4]

    output = [None for _ in range(len(prediction))]
    for image_i, image_pred in enumerate(prediction):
        #----------------------------------------------------------#
        #   Take the max over the class predictions.
        #   class_conf  [num_anchors, 1]    class confidence
        #   class_pred  [num_anchors, 1]    class index
        #----------------------------------------------------------#
        class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)

        #----------------------------------------------------------#
        #   First round of filtering using the confidence threshold
        #----------------------------------------------------------#
        conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze()

        #----------------------------------------------------------#
        #   Keep only the predictions that passed the threshold
        #----------------------------------------------------------#
        image_pred = image_pred[conf_mask]
        class_conf = class_conf[conf_mask]
        class_pred = class_pred[conf_mask]
        if not image_pred.size(0):
            continue
        #-------------------------------------------------------------------------#
        #   detections  [num_anchors, 7]
        #   The 7 columns are: x1, y1, x2, y2, obj_conf, class_conf, class_pred
        #-------------------------------------------------------------------------#
        detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1)

        #------------------------------------------#
        #   Collect all classes present in the predictions
        #------------------------------------------#
        unique_labels = detections[:, -1].cpu().unique()

        if prediction.is_cuda:
            unique_labels = unique_labels.cuda()
            detections = detections.cuda()

        for c in unique_labels:
            #------------------------------------------#
            #   Gather all surviving predictions for this class
            #------------------------------------------#
            detections_class = detections[detections[:, -1] == c]

            #------------------------------------------#
            #   The official torchvision NMS is faster
            #   than a hand-rolled loop!
            #------------------------------------------#
            keep = nms(
                detections_class[:, :4],
                detections_class[:, 4] * detections_class[:, 5],
                nms_thres
            )
            max_detections = detections_class[keep]

            # # Sort by objectness * class confidence, descending
            # _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True)
            # detections_class = detections_class[conf_sort_index]
            # # Non-maximum suppression
            # max_detections = []
            # while detections_class.size(0):
            #     # Take the highest-confidence box of this class, then walk down the
            #     # list and drop every box whose IoU with it exceeds nms_thres.
            #     max_detections.append(detections_class[0].unsqueeze(0))
            #     if len(detections_class) == 1:
            #         break
            #     ious = bbox_iou(max_detections[-1], detections_class[1:])
            #     detections_class = detections_class[1:][ious < nms_thres]
            # # Stack the kept boxes
            # max_detections = torch.cat(max_detections).data

            # Add max detections to outputs
            output[image_i] = max_detections if output[image_i] is None else torch.cat(
                (output[image_i], max_detections))

    return output


def merge_bboxes(bboxes, cutx, cuty):
    """
    Merge the boxes of the four sub-images of a Mosaic-style crop.
    Boxes are clipped at the cut lines (cutx, cuty); boxes thinner
    than 5 pixels after clipping are dropped.
    """
    merge_bbox = []
    for i in range(len(bboxes)):
        for box in bboxes[i]:
            tmp_box = []
            x1, y1, x2, y2 = box[0], box[1], box[2], box[3]

            if i == 0:
                if y1 > cuty or x1 > cutx:
                    continue
                if y2 >= cuty and y1 <= cuty:
                    y2 = cuty
                    if y2-y1 < 5:
                        continue
                if x2 >= cutx and x1 <= cutx:
                    x2 = cutx
                    if x2-x1 < 5:
                        continue

            if i == 1:
                if y2 < cuty or x1 > cutx:
                    continue

                if y2 >= cuty and y1 <= cuty:
                    y1 = cuty
                    if y2-y1 < 5:
                        continue

                if x2 >= cutx and x1 <= cutx:
                    x2 = cutx
                    if x2-x1 < 5:
                        continue

            if i == 2:
                if y2 < cuty or x2 < cutx:
                    continue

                if y2 >= cuty and y1 <= cuty:
                    y1 = cuty
                    if y2-y1 < 5:
                        continue

                if x2 >= cutx and x1 <= cutx:
                    x1 = cutx
                    if x2-x1 < 5:
                        continue

            if i == 3:
                if y1 > cuty or x2 < cutx:
                    continue

                if y2 >= cuty and y1 <= cuty:
                    y2 = cuty
                    if y2-y1 < 5:
                        continue

                if x2 >= cutx and x1 <= cutx:
                    x1 = cutx
                    if x2-x1 < 5:
                        continue

            tmp_box.append(x1)
            tmp_box.append(y1)
            tmp_box.append(x2)
            tmp_box.append(y2)
            tmp_box.append(box[-1])
            merge_bbox.append(tmp_box)
    return merge_bbox
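For orientation, a sketch of the shapes flowing through `non_max_suppression`. The values are illustrative, and it assumes the torchvision `nms` import that the function relies on is present at the top of utils.py, which is outside this excerpt:

```python
# prediction: [batch, num_anchors, 5 + num_classes] in (cx, cy, w, h, obj, cls...) order.
# The output is a list with one entry per image: either None, or a tensor
# [num_kept, 7] holding x1, y1, x2, y2, obj_conf, class_conf, class_pred.
import torch

num_classes = 2                               # e.g. mask / no-mask
pred = torch.rand(1, 22743, 5 + num_classes)  # 22743 = 3*(76^2 + 38^2 + 19^2) anchors at 608x608
pred[..., :4] *= 608                          # spread the random boxes over the input size
detections = non_max_suppression(pred, num_classes, conf_thres=0.5, nms_thres=0.4)
if detections[0] is not None:
    print(detections[0].shape)                # torch.Size([num_kept, 7])
```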
--------------------------------------------------------------------------------
/yolo.py:
--------------------------------------------------------------------------------
#-------------------------------------#
#   The YOLO prediction class
#-------------------------------------#
import colorsys
import os
import time
import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from PIL import Image, ImageDraw, ImageFont

from nets.yolo4 import YoloBody
from utils.utils import (DecodeBox, bbox_iou, letterbox_image,
                         non_max_suppression, yolo_correct_boxes)


#--------------------------------------------#
#   To run prediction with your own trained model,
#   three parameters must all be changed:
#   model_path, classes_path and backbone.
#   If a shape-mismatch error appears, double-check
#   the model_path and classes_path used during training.
#--------------------------------------------#
class YOLO(object):
    _defaults = {
        "model_path"        : 'model_data/yolov4_mobile_mask.pth',
        "anchors_path"      : 'model_data/yolo_anchors.txt',
        "classes_path"      : 'model_data/voc_classes.txt',
        "backbone"          : 'mobilenetv3',
        "model_image_size"  : (608, 608, 3),
        "confidence"        : 0.5,
        "iou"               : 0.3,
        "cuda"              : True,
        #---------------------------------------------------------------------#
        #   Controls whether letterbox_image is used for a distortion-free
        #   resize of the input image. In repeated tests, a plain resize
        #   (letterbox_image off) gave better results for this model.
        #---------------------------------------------------------------------#
        "letterbox_image"   : True,
    }

    @classmethod
    def get_defaults(cls, n):
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    #---------------------------------------------------#
    #   Initialize YOLO
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        self.__dict__.update(kwargs)    # apply keyword overrides (previously accepted but silently ignored)
        self.class_names = self._get_class()
        self.anchors = self._get_anchors()
        self.generate()

    #---------------------------------------------------#
    #   Load the class names
    #---------------------------------------------------#
    def _get_class(self):
        classes_path = os.path.expanduser(self.classes_path)
        with open(classes_path) as f:
            class_names = f.readlines()
        class_names = [c.strip() for c in class_names]
        return class_names

    #---------------------------------------------------#
    #   Load the anchor boxes
    #---------------------------------------------------#
    def _get_anchors(self):
        anchors_path = os.path.expanduser(self.anchors_path)
        with open(anchors_path) as f:
            anchors = f.readline()
        anchors = [float(x) for x in anchors.split(',')]
        return np.array(anchors).reshape([-1, 3, 2])[::-1, :, :]

    #---------------------------------------------------#
    #   Build the model
    #---------------------------------------------------#
    def generate(self):
        #---------------------------------------------------#
        #   Build the yolov4 model
        #---------------------------------------------------#
        self.net = YoloBody(len(self.anchors[0]), len(self.class_names), backbone=self.backbone).eval()

        #---------------------------------------------------#
        #   Load the trained weights
        #---------------------------------------------------#
        print('Loading weights into state dict...')
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        state_dict = torch.load(self.model_path, map_location=device)
        self.net.load_state_dict(state_dict)
        print('Finished!')

        if self.cuda:
            os.environ["CUDA_VISIBLE_DEVICES"] = '0'
            self.net = nn.DataParallel(self.net)
            self.net = self.net.cuda()

        #---------------------------------------------------#
        #   Build one decoder for each of the three feature layers
        #---------------------------------------------------#
        self.yolo_decodes = []
        for i in range(3):
            self.yolo_decodes.append(DecodeBox(self.anchors[i], len(self.class_names),
                                               (self.model_image_size[1], self.model_image_size[0])))

        print('{} model, anchors, and classes loaded.'.format(self.model_path))

        #   Assign a distinct color to each class for drawing boxes
        hsv_tuples = [(x / len(self.class_names), 1., 1.)
                      for x in range(len(self.class_names))]
        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        self.colors = list(
            map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
                self.colors))
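With the keyword handling added to `__init__` above, any key in `_defaults` can be overridden per instance. A minimal sketch, assuming the bundled weights are present so `generate()` can load them:

```python
# Instantiate with a stricter confidence threshold and CPU-only inference;
# any key in _defaults can be overridden this way.
yolo = YOLO(confidence=0.6, cuda=False)
print(yolo.get_defaults("confidence"))   # 0.5 -- the class-level default is unchanged
print(yolo.confidence)                   # 0.6 -- the instance-level override
```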
    #---------------------------------------------------#
    #   Detect objects in a single image
    #---------------------------------------------------#
    def detect_image(self, image):
        image_shape = np.array(np.shape(image)[0:2])

        #---------------------------------------------------------#
        #   Add gray bars for a distortion-free resize,
        #   or resize directly without preserving the aspect ratio.
        #---------------------------------------------------------#
        if self.letterbox_image:
            crop_img = np.array(letterbox_image(image, (self.model_image_size[1], self.model_image_size[0])))
        else:
            crop_img = image.convert('RGB')
            crop_img = crop_img.resize((self.model_image_size[1], self.model_image_size[0]), Image.BICUBIC)
        photo = np.array(crop_img, dtype=np.float32) / 255.0
        photo = np.transpose(photo, (2, 0, 1))
        #---------------------------------------------------------#
        #   Add the batch_size dimension
        #---------------------------------------------------------#
        images = [photo]

        with torch.no_grad():
            images = torch.from_numpy(np.asarray(images))
            if self.cuda:
                images = images.cuda()

            #---------------------------------------------------------#
            #   Run the image through the network
            #---------------------------------------------------------#
            outputs = self.net(images)
            output_list = []
            for i in range(3):
                output_list.append(self.yolo_decodes[i](outputs[i]))

            #---------------------------------------------------------#
            #   Stack the predicted boxes, then run non-maximum suppression
            #---------------------------------------------------------#
            output = torch.cat(output_list, 1)
            batch_detections = non_max_suppression(output, len(self.class_names),
                                                   conf_thres=self.confidence,
                                                   nms_thres=self.iou)

        #---------------------------------------------------------#
        #   If nothing was detected, return the original image
        #---------------------------------------------------------#
        try:
            batch_detections = batch_detections[0].cpu().numpy()
        except AttributeError:          # batch_detections[0] is None when no object was found
            predicted_class = "none"
            return image, predicted_class

        #---------------------------------------------------------#
        #   Filter the boxes by score
        #---------------------------------------------------------#
        top_index = batch_detections[:, 4] * batch_detections[:, 5] > self.confidence
        top_conf = batch_detections[top_index, 4] * batch_detections[top_index, 5]
        top_label = np.array(batch_detections[top_index, -1], np.int32)
        top_bboxes = np.array(batch_detections[top_index, :4])
        top_xmin = np.expand_dims(top_bboxes[:, 0], -1)
        top_ymin = np.expand_dims(top_bboxes[:, 1], -1)
        top_xmax = np.expand_dims(top_bboxes[:, 2], -1)
        top_ymax = np.expand_dims(top_bboxes[:, 3], -1)

        #-----------------------------------------------------------------#
        #   When letterbox_image was used, gray bars were added around the
        #   image before it entered the network, so top_bboxes is relative
        #   to the padded image and the bars must be removed here.
        #-----------------------------------------------------------------#
        if self.letterbox_image:
            boxes = yolo_correct_boxes(top_ymin, top_xmin, top_ymax, top_xmax,
                                       np.array([self.model_image_size[0], self.model_image_size[1]]), image_shape)
        else:
            top_xmin = top_xmin / self.model_image_size[1] * image_shape[1]
            top_ymin = top_ymin / self.model_image_size[0] * image_shape[0]
            top_xmax = top_xmax / self.model_image_size[1] * image_shape[1]
            top_ymax = top_ymax / self.model_image_size[0] * image_shape[0]
            boxes = np.concatenate([top_ymin, top_xmin, top_ymax, top_xmax], axis=-1)

        # The font ships under font/, not model_data/ (the original path was wrong).
        font = ImageFont.truetype(font='font/simhei.ttf',
                                  size=np.floor(3e-2 * np.shape(image)[1] + 0.5).astype('int32'))

        thickness = max((np.shape(image)[0] + np.shape(image)[1]) // self.model_image_size[0], 1)

        predicted_class = "none"    # fallback in case the score filter removes every box
        for i, c in enumerate(top_label):
            predicted_class = self.class_names[c]
            score = top_conf[i]

            top, left, bottom, right = boxes[i]
            top = top - 5
            left = left - 5
            bottom = bottom + 5
            right = right + 5

            top = max(0, np.floor(top + 0.5).astype('int32'))
            left = max(0, np.floor(left + 0.5).astype('int32'))
            bottom = min(np.shape(image)[0], np.floor(bottom + 0.5).astype('int32'))
            right = min(np.shape(image)[1], np.floor(right + 0.5).astype('int32'))

            #   Draw the box and its label
            label = '{} {:.2f}'.format(predicted_class, score)
            draw = ImageDraw.Draw(image)
            label_size = draw.textsize(label, font)
            label = label.encode('utf-8')
            print(label, top, left, bottom, right)

            if top - label_size[1] >= 0:
                text_origin = np.array([left, top - label_size[1]])
            else:
                text_origin = np.array([left, top + 1])

            for t in range(thickness):
                draw.rectangle(
                    [left + t, top + t, right - t, bottom - t],
                    outline=self.colors[self.class_names.index(predicted_class)])
            draw.rectangle(
                [tuple(text_origin), tuple(text_origin + label_size)],
                fill=self.colors[self.class_names.index(predicted_class)])
            draw.text(text_origin, str(label, 'UTF-8'), fill=(0, 0, 0), font=font)
            del draw
        return image, predicted_class    # return the drawn image and the last predicted class
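A minimal single-image usage sketch; any local image path works, and `maskimg/demo.png` is used here only because it ships with the repo:

```python
# Run one image through the detector. detect_image returns a PIL image
# with the boxes drawn on it, plus the last predicted class name
# ("none" if nothing was found).
from PIL import Image

yolo = YOLO()
img = Image.open('maskimg/demo.png')
result, predicted_class = yolo.detect_image(img)
result.show()
print(predicted_class)
```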
def detect_video(video_path, output_path):
    yolo = YOLO()
    #-------------------------------------#
    #   To read from a camera instead of a file,
    #   pass a device index, e.g. cv2.VideoCapture(0)
    #-------------------------------------#
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        raise IOError("Could not open the video source; check the camera connection")
    isOutput = output_path != ""
    video_FourCC = int(capture.get(cv2.CAP_PROP_FOURCC))        # video codec (FourCC)
    video_fps = capture.get(cv2.CAP_PROP_FPS)                   # frame rate
    video_size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)),
                  int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))  # frame width and height
    if isOutput:
        print("!!! TYPE:", type(output_path), type(video_FourCC), type(video_fps), type(video_size))
        out = cv2.VideoWriter(output_path, video_FourCC, video_fps, video_size)    # writer for the output video
    fps = 0.0
    while True:
        t1 = time.time()
        #   Read one frame; stop when the stream ends
        ret, frame = capture.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        #   Convert to a PIL Image
        frame = Image.fromarray(np.uint8(frame))

        #   Run detection (detect_image returns the image and the class name)
        frame, _ = yolo.detect_image(frame)
        frame = np.array(frame)

        #   RGB to BGR to match OpenCV's display format
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

        fps = (fps + (1. / (time.time() - t1))) / 2
        print("fps= %.2f" % fps)
        frame = cv2.putText(frame, "fps= %.2f" % fps, (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
        cv2.namedWindow("Mask", cv2.WINDOW_NORMAL)
        cv2.imshow("Mask", frame)
        if isOutput:
            out.write(frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    capture.release()
    if isOutput:            # only release the writer if one was created
        out.release()
    cv2.destroyAllWindows() # make sure the window closes instead of hanging
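Finally, a minimal sketch of driving the video loop; the file names are illustrative, and an empty `output_path` displays without saving:

```python
# Detect masks in a local video and save the annotated result;
# press 'q' in the display window to stop early.
if __name__ == "__main__":
    detect_video("test.mp4", "test_out.mp4")
    # or, to display without saving:
    # detect_video("test.mp4", "")
```
--------------------------------------------------------------------------------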