├── CallFrame.py ├── Frame.py ├── GetTestImage.py ├── GetTrainImage.py ├── README.md ├── SaveGesture.py ├── TestGesture.py ├── TestInTest.py ├── Train.py ├── Train_inputdata.py ├── Train_model.py └── ges_ico ├── frame.ico ├── ges1.ico ├── ges2.ico ├── ges3.ico ├── ges4.ico ├── ges5.ico └── white.ico /CallFrame.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import serial # 这个模块是通信模块 4 | from PyQt5.QtWidgets import QApplication, QMainWindow, QMessageBox 5 | from PyQt5.QtCore import Qt, pyqtSignal, QDateTime, QThread 6 | from PyQt5.QtGui import QIcon 7 | from PyQt5.QtGui import QPixmap, QPalette 8 | from SaveGesture import * 9 | from TestGesture import * 10 | from Frame import * 11 | 12 | global gesture_action 13 | 14 | class MyMainWindow(QMainWindow,Ui_MainWindow): 15 | def __init__(self, parent=None): 16 | super(MyMainWindow, self).__init__() 17 | self.setupUi(self) # 在界面文件Frame中以及根据界面自动定义了 18 | self.initUI() 19 | 20 | def initUI(self): 21 | # 给按钮连接槽函数(CloseButton在Frame中自动连接了) 22 | self.GetGestureButton.clicked.connect(self.GetGesture) 23 | self.JudgeButton.clicked.connect(self.JudgeGesture) 24 | self.ExcuteGestureButton.clicked.connect(self.ExcuteGesture) 25 | self.HelpButton.clicked.connect(self.Help) 26 | 27 | # 窗口设置美化 28 | self.setWindowTitle('手势识别') 29 | self.setWindowIcon(QIcon('./ges_ico/frame.ico')) 30 | self.resize(750, 485) 31 | 32 | # 线程操作用于显示时间 33 | self.initxianceng() 34 | 35 | # 单独给CloseButton添加标签 36 | self.CloseButton.setProperty('color', 'gray') # 自定义标签 37 | self.GetGestureButton.setProperty('color', 'same') 38 | self.JudgeButton.setProperty('color', 'same') 39 | self.ExcuteGestureButton.setProperty('color', 'same') 40 | self.HelpButton.setProperty('color', 'same') 41 | 42 | # 定义槽函数 43 | def GetGesture(self): 44 | self.LitResultlabel.setText("") 45 | self.ImaResultlabel.setPixmap(QPixmap('./ges_ico/white.ico')) 46 | self.LitResultlabel.setAutoFillBackground(False) 47 | saveGesture() 48 | self.LitResultlabel.setText("已经将该图像保存在电脑本地") 49 | self.LitResultlabel.setAlignment(Qt.AlignCenter) 50 | 51 | def JudgeGesture(self): 52 | global gesture_action # 要修改全局变量需要先在函数里面声明一下 53 | self.LitResultlabel.setText("正在调用卷积神经网络识别图像") 54 | self.LitResultlabel.setAlignment(Qt.AlignCenter) 55 | QApplication.processEvents() # 这里需要刷新一下,否则上面的文字不显示 56 | gesture_num = evaluate_one_image() 57 | if gesture_num == 1: 58 | gesture_action = "1" 59 | self.result_show_1() 60 | elif gesture_num == 2: 61 | gesture_action = "2" 62 | self.result_show_2() 63 | elif gesture_num == 3: 64 | gesture_action = "3" 65 | self.result_show_3() 66 | elif gesture_num == 4: 67 | gesture_action = "4" 68 | self.result_show_4() 69 | elif gesture_num == 5: 70 | gesture_action = "5" 71 | self.result_show_5() 72 | 73 | def ExcuteGesture(self): 74 | self.serial_communicate() 75 | 76 | def Help(self): 77 | QMessageBox.information(self, "操作提示框", "获取手势:通过OpenCV和摄像头获取一张即时照片。\n" 78 | "判断手势:通过之前训练好的参数和卷积神经网络判断手势。\n" 79 | "执行手势:根据识别的手势姿态控制机械手作业。") 80 | 81 | def serial_communicate(self): 82 | """ 83 | 把gesture_action的字符(1~5)传给下位机 84 | """ 85 | print(gesture_action) 86 | try: 87 | # G7电脑的左端口是COM4 88 | portx = "COM4" 89 | # 波特率,我STM32单片机设置的为115200 90 | bps = 115200 91 | # 超时设置,None:永远等待操作,0为立即返回请求结果,其他值为等待超时时间(单位为秒) 92 | timex = 5 93 | # 打开串口,并得到串口对象 94 | ser = serial.Serial(portx, bps, timeout=timex) 95 | # 写数据 96 | # result=ser.write("1".encode("utf-8")) # encode("gbk")是中国中文的字符 97 | result = ser.write(gesture_action.encode("utf-8")) 98 | print("成功:", result) 99 | 100 | ser.close() # 关闭串口 101 | 102 | except Exception as extioc: 103 | print("---异常---:", extioc) 104 | 105 | def result_show_1(self): 106 | self.LitResultlabel.setText("判断结果:该手势为剪刀") 107 | self.LitResultlabel.setAutoFillBackground(True) # 允许上色 108 | palette = QPalette() # palette 调色板 109 | palette.setColor(QPalette.Window, Qt.lightGray) 110 | self.LitResultlabel.setPalette(palette) 111 | self.ImaResultlabel.setToolTip('这是一个示意图片结果') # 鼠标放在上面出现提示框 112 | self.ImaResultlabel.setPixmap(QPixmap('./ges_ico/ges1.ico')) 113 | self.LitResultlabel.setAlignment(Qt.AlignCenter) 114 | self.ImaResultlabel.setAlignment(Qt.AlignCenter) 115 | 116 | def result_show_2(self): 117 | self.LitResultlabel.setText("判断结果:该手势为石头") 118 | self.LitResultlabel.setAutoFillBackground(True) # 允许上色 119 | palette = QPalette() # palette 调色板 120 | palette.setColor(QPalette.Window, Qt.lightGray) 121 | self.LitResultlabel.setPalette(palette) 122 | self.ImaResultlabel.setToolTip('这是一个示意图片结果') # 鼠标放在上面出现提示框 123 | self.ImaResultlabel.setPixmap(QPixmap('./ges_ico/ges3.ico')) 124 | self.LitResultlabel.setAlignment(Qt.AlignCenter) 125 | self.ImaResultlabel.setAlignment(Qt.AlignCenter) 126 | 127 | def result_show_3(self): 128 | self.LitResultlabel.setText("判断结果:该手势为布") 129 | self.LitResultlabel.setAutoFillBackground(True) 130 | palette = QPalette() # palette 调色板 131 | palette.setColor(QPalette.Window, Qt.lightGray) 132 | self.LitResultlabel.setPalette(palette) 133 | self.ImaResultlabel.setToolTip('这是一个示意图片结果') # 鼠标放在上面出现提示框 134 | self.ImaResultlabel.setPixmap(QPixmap('./ges_ico/ges2.ico')) 135 | self.LitResultlabel.setAlignment(Qt.AlignCenter) 136 | self.ImaResultlabel.setAlignment(Qt.AlignCenter) 137 | 138 | def result_show_4(self): 139 | self.LitResultlabel.setText("判断结果:该手势为OK") 140 | self.LitResultlabel.setAutoFillBackground(True) 141 | palette = QPalette() # palette 调色板 142 | palette.setColor(QPalette.Window, Qt.lightGray) 143 | self.LitResultlabel.setPalette(palette) 144 | self.ImaResultlabel.setToolTip('这是一个示意图片结果') # 鼠标放在上面出现提示框 145 | self.ImaResultlabel.setPixmap(QPixmap('./ges_ico/ges4.ico')) 146 | self.LitResultlabel.setAlignment(Qt.AlignCenter) 147 | self.ImaResultlabel.setAlignment(Qt.AlignCenter) 148 | 149 | def result_show_5(self): 150 | self.LitResultlabel.setText("判断结果:该手势为good") 151 | self.LitResultlabel.setAutoFillBackground(True) 152 | palette = QPalette() # palette 调色板 153 | palette.setColor(QPalette.Window, Qt.lightGray) 154 | self.LitResultlabel.setPalette(palette) 155 | self.ImaResultlabel.setToolTip('这是一个示意图片结果') # 鼠标放在上面出现提示框 156 | self.ImaResultlabel.setPixmap(QPixmap('./ges_ico/ges5.ico')) 157 | self.LitResultlabel.setAlignment(Qt.AlignCenter) 158 | self.ImaResultlabel.setAlignment(Qt.AlignCenter) 159 | 160 | def initxianceng(self): 161 | # 创建线程 162 | self.backend = BackendThread() 163 | # 信号连接槽函数 164 | self.backend.update_date.connect(self.handleDisplay) 165 | # 开始线程 166 | self.backend.start() 167 | 168 | # 将当期时间输出到文本框 169 | def handleDisplay(self, data): 170 | self.statusBar().showMessage(data) 171 | 172 | 173 | # 后台线程更新时间 174 | class BackendThread(QThread): 175 | update_date = pyqtSignal(str) 176 | 177 | def run(self): 178 | while True: 179 | date = QDateTime.currentDateTime() 180 | currTime = date.toString('yyyy-MM-dd hh:mm:ss') 181 | self.update_date.emit(str(currTime)) 182 | time.sleep(1) # 推迟执行的1秒 183 | 184 | 185 | if __name__ == "__main__": 186 | app = QApplication(sys.argv) # sys.argv是一个命令行参数列表 187 | myWin = MyMainWindow() 188 | myWin.setObjectName('Window') 189 | # 给窗口背景上色 190 | qssStyle = ''' 191 | QPushButton[color='gray']{ 192 | background-color:rgb(205,197,191) 193 | } 194 | QPushButton[color='same']{ 195 | background-color:rgb(225,238,238) 196 | } 197 | #Window{ 198 | background-color:rgb(162,181,205) 199 | } 200 | ''' 201 | myWin.setStyleSheet(qssStyle) 202 | myWin.show() 203 | sys.exit(app.exec_()) 204 | 205 | -------------------------------------------------------------------------------- /Frame.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Form implementation generated from reading ui file 'Frame.ui' 4 | # 5 | # Created by: PyQt5 UI code generator 5.11.3 6 | # 7 | # WARNING! All changes made in this file will be lost! 8 | 9 | from PyQt5 import QtCore, QtGui, QtWidgets 10 | 11 | class Ui_MainWindow(object): 12 | def setupUi(self, MainWindow): 13 | MainWindow.setObjectName("MainWindow") 14 | MainWindow.resize(740, 511) 15 | self.centralwidget = QtWidgets.QWidget(MainWindow) 16 | self.centralwidget.setObjectName("centralwidget") 17 | self.groupBox = QtWidgets.QGroupBox(self.centralwidget) 18 | self.groupBox.setGeometry(QtCore.QRect(20, 30, 211, 411)) 19 | self.groupBox.setObjectName("groupBox") 20 | self.GetGestureButton = QtWidgets.QPushButton(self.groupBox) 21 | self.GetGestureButton.setGeometry(QtCore.QRect(20, 40, 171, 51)) 22 | self.GetGestureButton.setObjectName("GetGestureButton") 23 | self.HelpButton = QtWidgets.QPushButton(self.groupBox) 24 | self.HelpButton.setGeometry(QtCore.QRect(20, 310, 171, 51)) 25 | self.HelpButton.setObjectName("HelpButton") 26 | self.ExcuteGestureButton = QtWidgets.QPushButton(self.groupBox) 27 | self.ExcuteGestureButton.setGeometry(QtCore.QRect(20, 220, 171, 51)) 28 | self.ExcuteGestureButton.setObjectName("ExcuteGestureButton") 29 | self.JudgeButton = QtWidgets.QPushButton(self.groupBox) 30 | self.JudgeButton.setGeometry(QtCore.QRect(20, 130, 171, 51)) 31 | self.JudgeButton.setObjectName("JudgeButton") 32 | self.groupBox_2 = QtWidgets.QGroupBox(self.centralwidget) 33 | self.groupBox_2.setGeometry(QtCore.QRect(300, 30, 371, 291)) 34 | self.groupBox_2.setObjectName("groupBox_2") 35 | self.LitResultlabel = QtWidgets.QLabel(self.groupBox_2) 36 | self.LitResultlabel.setGeometry(QtCore.QRect(60, 50, 241, 21)) 37 | self.LitResultlabel.setText("") 38 | self.LitResultlabel.setObjectName("LitResultlabel") 39 | self.ImaResultlabel = QtWidgets.QLabel(self.groupBox_2) 40 | self.ImaResultlabel.setGeometry(QtCore.QRect(70, 80, 221, 171)) 41 | self.ImaResultlabel.setText("") 42 | self.ImaResultlabel.setObjectName("ImaResultlabel") 43 | self.CloseButton = QtWidgets.QPushButton(self.centralwidget) 44 | self.CloseButton.setGeometry(QtCore.QRect(540, 370, 111, 31)) 45 | self.CloseButton.setObjectName("CloseButton") 46 | MainWindow.setCentralWidget(self.centralwidget) 47 | self.menubar = QtWidgets.QMenuBar(MainWindow) 48 | self.menubar.setGeometry(QtCore.QRect(0, 0, 740, 26)) 49 | self.menubar.setObjectName("menubar") 50 | MainWindow.setMenuBar(self.menubar) 51 | self.statusbar = QtWidgets.QStatusBar(MainWindow) 52 | self.statusbar.setObjectName("statusbar") 53 | MainWindow.setStatusBar(self.statusbar) 54 | 55 | self.retranslateUi(MainWindow) 56 | self.CloseButton.clicked.connect(MainWindow.close) 57 | QtCore.QMetaObject.connectSlotsByName(MainWindow) 58 | 59 | def retranslateUi(self, MainWindow): 60 | _translate = QtCore.QCoreApplication.translate 61 | MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow")) 62 | self.groupBox.setTitle(_translate("MainWindow", "操作按钮")) 63 | self.GetGestureButton.setText(_translate("MainWindow", "获取手势")) 64 | self.HelpButton.setText(_translate("MainWindow", "操作提示")) 65 | self.ExcuteGestureButton.setText(_translate("MainWindow", "执行手势")) 66 | self.JudgeButton.setText(_translate("MainWindow", "判断手势")) 67 | self.groupBox_2.setTitle(_translate("MainWindow", "结果显示")) 68 | self.CloseButton.setText(_translate("MainWindow", "关闭")) 69 | 70 | -------------------------------------------------------------------------------- /GetTestImage.py: -------------------------------------------------------------------------------- 1 | import cv2 as cv 2 | 3 | img_roi_y = 30 4 | img_roi_x = 200 5 | img_roi_height = 350 # [2]设置ROI区域的高度 6 | img_roi_width = 350 # [3]设置ROI区域的宽度 7 | capture = cv.VideoCapture(0) 8 | index = 1 9 | num = 150 10 | while True: 11 | ret, frame = capture.read() 12 | if ret is True: 13 | img_roi = frame[img_roi_y:(img_roi_y + img_roi_height), img_roi_x:(img_roi_x + img_roi_width)] 14 | cv.imshow("frame", img_roi) 15 | index += 1 16 | if index % 6 == 0: # 每20帧保存一次图像 17 | num += 1 18 | cv.imwrite("D:/python/deep-learning/Gesture-recognition/data/test/" 19 | + "gesture_5."+str(num) + ".jpg", img_roi) 20 | c = cv.waitKey(50) # 每50ms判断一下键盘的触发。 0则为无限等待。 21 | if c == 27: # 在ASCII码中27表示ESC键,ord函数可以将字符转换为ASCII码。 22 | break 23 | if index == 300: 24 | break 25 | else: 26 | break 27 | 28 | cv.destroyAllWindows() 29 | capture.release() 30 | -------------------------------------------------------------------------------- /GetTrainImage.py: -------------------------------------------------------------------------------- 1 | import cv2 as cv 2 | 3 | img_roi_y = 30 4 | img_roi_x = 200 5 | img_roi_height = 300 # [2]设置ROI区域的高度 6 | img_roi_width = 300 # [3]设置ROI区域的宽度 7 | capture = cv.VideoCapture(0) 8 | index = 1 9 | num = 1100 10 | while True: 11 | ret, frame = capture.read() 12 | if ret is True: 13 | img_roi = frame[img_roi_y:(img_roi_y + img_roi_height), img_roi_x:(img_roi_x + img_roi_width)] 14 | cv.imshow("frame", img_roi) 15 | index += 1 16 | if index % 5 == 0: # 每5帧保存一次图像 17 | num += 1 18 | cv.imwrite("D:/python/deep-learning/Gesture-recognition/data/train/" 19 | + "gesture_1."+str(num) + ".jpg", img_roi) 20 | c = cv.waitKey(50) # 每50ms判断一下键盘的触发。 0则为无限等待。 21 | if c == 27: # 在ASCII码中27表示ESC键,ord函数可以将字符转换为ASCII码。 22 | break 23 | if index == 1000: 24 | break 25 | else: 26 | break 27 | 28 | cv.destroyAllWindows() 29 | capture.release() 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## 利用卷积神经网络实时识别手势动作 2 | 3 | 一共识别5种手势动作 4 | 1. 剪刀动作 2.石头动作 3.布动作 4.OK动作 5.good动作 5 | 6 | ### 项目文件 7 | 8 | 项目文件列表如下: 9 | 10 | 1. `data`:存放训练集、测试集,实时保存的图像(用于在线检测)。 11 | 2. `ges_ico`:存放UI窗口使用的各种图标。 12 | 3. `log`:存放训练的CNN网络的模型参数。 13 | 4. `CallFrame.py`:界面窗口的逻辑文件,用来调用界面文件并编写信号与槽函数。 14 | 5. `Frame.py`:界面窗口的界面文件,通过PyQt5的designer工具生成。 15 | 6. `GetTestImage.py`:利用OpenCV获取图片并标记,用来制作测试集。 16 | 7. `GetTrainImage.py`:利用OpenCV获取图片并标记,用来制作训练集。 17 | 8. `SaveGesture.py`:利用OpenCV实时获取图片,并进行预处理,用于在线检测手势。 18 | 9. `TestGesture.py`:将实时获取的图片送入已训练好的CNN中判断其手势动作。 19 | 10. `TestInTest.py`:将测试集送入已训练好的CNN中判断该网络模型的准确率。 20 | 11. `Train.py`:训练CNN模型函数,并将训练好的模型参数保存在本地。 21 | 12. `Train_inputdata.py`:用来读取数据集的图像和标签,并打包成batch形式。 22 | 13. `Train_model.py`:模型结构,采用AlexNet结构。 23 | 24 | ### 使用方法 25 | 26 | 先用Train.py训练好模型参数,然后运行CallFrame.py调用出界面窗口, 27 | 点击窗口的相应按钮就可以在线检测手势动作,其中的执行手势按钮是和下位机通信(如STM32单片机), 28 | 通过串口函数将识别结果传给下位机,实现根据手势动作控制的功能。 29 | 30 | ### 测试结果: 31 | 使用该模型训练到900步的时候在测试集上正确率可以稳定在95%左右。 32 | (训练集:1,2,3,4号动作各有1300张照片,5号动作有1450张照片;测试集:每种动作各有200张照片) 33 | 34 | 35 | ### 未来改进: 36 | (1)图像预处理多一些如去除背景 37 | (2)在线检测图像的时候加一个预选框。 38 | 39 | #### 详细说明与调参过程见我的CSDN: https://blog.csdn.net/Amigo_1997/article/details/89174493?spm=1001.2014.3001.5502 40 | 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /SaveGesture.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | 4 | def saveGesture(): 5 | """ 6 | 原始图像的大小为(480,640,3) 7 | 保存的图像是切割后的图像(400,400,3) 8 | """ 9 | cameraCapture = cv2.VideoCapture(0) 10 | success, frame = cameraCapture.read() 11 | if success is True: 12 | cv2.imwrite("./data/testImage/" + "test.jpg", frame) 13 | 14 | testImg = cv2.imread('./data/testImage/test.jpg') 15 | print(testImg.shape) # G7笔记本的摄像头是(480,640,3) 高度,宽度,通道数 16 | 17 | img_roi_y = 30 18 | img_roi_x = 200 19 | img_roi_height = 350 # [2]设置ROI区域的高度 20 | img_roi_width = 350 # [3]设置ROI区域的宽度 21 | img_roi = testImg[img_roi_y:(img_roi_y + img_roi_height), img_roi_x:(img_roi_x + img_roi_width)] 22 | 23 | cv2.imshow("[ROI_Img]", img_roi) 24 | cv2.imwrite("./data/testImage/roi/" + "img_roi.jpg", img_roi) 25 | cv2.waitKey(0) 26 | cv2.destroyWindow("[ROI_Img]") 27 | -------------------------------------------------------------------------------- /TestGesture.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import matplotlib.pyplot as plt 3 | import Train_inputdata 4 | import Train_model 5 | import os 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | 10 | def get_one_image(train): 11 | n = len(train) 12 | ind = np.random.randint(0, n) 13 | img_dir = train[ind] 14 | 15 | image = Image.open(img_dir) 16 | image = image.resize([227, 227]) 17 | image = np.array(image) 18 | return image 19 | 20 | 21 | def evaluate_one_image(): 22 | """ 23 | Test one image against the saved models and parameters 24 | 返回字符串1~5 25 | """ 26 | train_dir = 'D:/python/deep-learning/Gesture-recognition/data/testImage/roi/' 27 | train, train_label = Train_inputdata.get_files(train_dir) 28 | image_array = get_one_image(train) 29 | 30 | with tf.Graph().as_default(): 31 | BATCH_SIZE = 1 32 | N_CLASSES = 5 33 | 34 | image = tf.cast(image_array, tf.float32) 35 | image = tf.image.per_image_standardization(image) 36 | image = tf.reshape(image, [1, 227, 227, 3]) 37 | logit = Train_model.cnn_inference(image, BATCH_SIZE, N_CLASSES, keep_prob=1) 38 | 39 | logit = tf.nn.softmax(logit) 40 | 41 | x = tf.placeholder(tf.float32, shape=[227, 227, 3]) 42 | 43 | logs_train_dir = 'D:/python/deep-learning/Gesture-recognition/log/' 44 | 45 | saver = tf.train.Saver() 46 | 47 | with tf.Session() as sess: 48 | 49 | print("Reading checkpoints...") 50 | ckpt = tf.train.get_checkpoint_state(logs_train_dir) 51 | if ckpt and ckpt.model_checkpoint_path: 52 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 53 | saver.restore(sess, ckpt.model_checkpoint_path) 54 | print('Loading success, global_step is %s' % global_step) 55 | else: 56 | print('No checkpoint file found') 57 | 58 | prediction = sess.run(logit, feed_dict={x: image_array}) 59 | max_index = np.argmax(prediction) 60 | if max_index == 0: 61 | print('This is a scissor with possibility %.6f' % prediction[:, 0]) 62 | return 1 63 | elif max_index == 1: 64 | print('This is a rock with possibility %.6f' % prediction[:, 1]) 65 | return 2 66 | elif max_index == 2: 67 | print('This is a paper with possibility %.6f' % prediction[:, 2]) 68 | return 3 69 | elif max_index == 3: 70 | print('This is a ok with possibility %.6f' % prediction[:, 3]) 71 | return 4 72 | elif max_index == 4: 73 | print('This is a good with possibility %.6f' % prediction[:, 4]) 74 | return 5 75 | 76 | 77 | 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /TestInTest.py: -------------------------------------------------------------------------------- 1 | import Train_inputdata 2 | import Train_model 3 | import tensorflow as tf 4 | import numpy as np 5 | 6 | with tf.Graph().as_default(): # 不加的话graph就会冲突 7 | IMG_W = 227 8 | IMG_H = 227 9 | BATCH_SIZE = 500 10 | CAPACITY = 2000 11 | N_CLASSES = 5 12 | 13 | test_dir = 'D:/python/deep-learning/Gesture-recognition/data/test/' 14 | 15 | test, test_label = Train_inputdata.get_files(test_dir) 16 | test_batch, test_label_batch = Train_inputdata.get_batch(test, 17 | test_label, 18 | IMG_W, 19 | IMG_H, 20 | BATCH_SIZE, 21 | CAPACITY) 22 | 23 | test_logit = Train_model.cnn_inference(test_batch, BATCH_SIZE, N_CLASSES, keep_prob=1) 24 | test_acc = Train_model.evaluation(test_logit, test_label_batch) 25 | test_loss = Train_model.losses(test_logit, test_label_batch) 26 | 27 | logs_train_dir = 'D:/python/deep-learning/Gesture-recognition/log/' 28 | 29 | saver = tf.train.Saver() 30 | 31 | with tf.Session() as sess: 32 | 33 | coord = tf.train.Coordinator() 34 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 35 | 36 | print("Reading checkpoints...") 37 | ckpt = tf.train.get_checkpoint_state(logs_train_dir) 38 | 39 | # 对最新的模型进行测验 40 | # if ckpt and ckpt.model_checkpoint_path: 41 | # global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 42 | # saver.restore(sess, ckpt.model_checkpoint_path) 43 | # print('Loading success, global_step is %s' % global_step) 44 | # else: 45 | # print('No checkpoint file found') 46 | 47 | # 对所有保存的模型进行测验 48 | if ckpt and ckpt.all_model_checkpoint_paths: 49 | for path in ckpt.all_model_checkpoint_paths: 50 | saver.restore(sess, path) 51 | global_step = path.split('/')[-1].split('-')[-1] 52 | print('Loading success, global_step is %s' % global_step) 53 | accuracy, loss = sess.run([test_acc, test_loss]) 54 | print("测试集正确率是:%.2f%%" % (accuracy * 100)) 55 | print("测试集损失率:%.2f" % loss) 56 | else: 57 | print('No checkpoint file found') 58 | 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /Train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tensorflow as tf 4 | import matplotlib.pyplot as plt 5 | 6 | import Train_model 7 | import Train_inputdata 8 | 9 | 10 | N_CLASSES = 5 # 5种手势 11 | IMG_W = 227 # resize图像,太大的话训练时间久 12 | IMG_H = 227 13 | BATCH_SIZE = 32 # 由于数据集较小,所以选值较大 一次训练512张已经是该电脑的极限了 14 | CAPACITY = 320 15 | MAX_STEP = 1000 16 | learning_rate = 0.0001 # 一般小于0.0001 17 | 18 | train_dir = 'D:/python/deep-learning/Gesture-recognition/data/train/' 19 | logs_train_dir = 'D:/python/deep-learning/Gesture-recognition/log/' # 记录训练过程与保存模型 20 | 21 | train, train_label = Train_inputdata.get_files(train_dir) 22 | train_batch, train_label_batch = Train_inputdata.get_batch(train, 23 | train_label, 24 | IMG_W, 25 | IMG_H, 26 | BATCH_SIZE, 27 | CAPACITY) 28 | 29 | train_logits = Train_model.cnn_inference(train_batch, BATCH_SIZE, N_CLASSES, keep_prob=0.5) 30 | train_loss = Train_model.losses(train_logits, train_label_batch) 31 | train_op = Train_model.training(train_loss, learning_rate) 32 | train__acc = Train_model.evaluation(train_logits, train_label_batch) 33 | 34 | 35 | summary_op = tf.summary.merge_all() # 这个是log汇总记录 36 | 37 | # 可视化为了画折线图 38 | step_list = list(range(50)) # 因为后来的cnn_list有20个 39 | cnn_list1 = [] # tra_acc 40 | cnn_list2 = [] # tra_loss 41 | 42 | fig = plt.figure() # 建立可视化图像框 43 | ax = fig.add_subplot(2, 1, 1) # 子图总行数、列数,位置 44 | ax.yaxis.grid(True) 45 | ax.set_title('cnn_accuracy ', fontsize=14, y=1.02) 46 | ax.set_xlabel('step') 47 | ax.set_ylabel('accuracy') 48 | bx = fig.add_subplot(2, 1, 2) 49 | bx.yaxis.grid(True) 50 | bx.set_title('cnn_loss ', fontsize=14, y=1.02) 51 | bx.set_xlabel('step') 52 | bx.set_ylabel('loss') 53 | 54 | 55 | # 初始化,如果存在变量则是必不可少的操作 56 | with tf.Session() as sess: 57 | sess.run(tf.global_variables_initializer()) 58 | 59 | # 产生一个writer来写log文件 60 | train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph) 61 | # 产生一个saver来存储训练好的模型 62 | saver = tf.train.Saver() 63 | 64 | # 队列监控 65 | # batch训练法用到了队列,不想用队列也可以用placeholder 66 | coord = tf.train.Coordinator() 67 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 68 | 69 | try: 70 | # 执行MAX_STEP步的训练,一步一个batch 71 | for step in np.arange(MAX_STEP): 72 | if coord.should_stop(): 73 | break 74 | # 启动以下操作节点,这里不能用train_op,因为它在第二次迭代是None,会导致session出错,改为_ 75 | _op, tra_loss, tra_acc = sess.run([train_op, train_loss, train__acc]) 76 | 77 | # tes_acc, tes_loss = sess.run([test_acc, test_loss]) 78 | # 每隔20步打印一次当前的loss以及acc,同时记录log,写入writer,并画个图 79 | if step % 20 == 0: 80 | print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0)) 81 | summary_str = sess.run(summary_op) 82 | train_writer.add_summary(summary_str, step) 83 | cnn_list1.append(tra_acc) 84 | cnn_list2.append(tra_loss) 85 | # 每隔40步,保存一次训练好的模型 86 | if step % 60 == 0 or (step + 1) == MAX_STEP: 87 | checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt') 88 | saver.save(sess, checkpoint_path, global_step=step) 89 | 90 | ax.plot(step_list, cnn_list1, color="r", label=train) 91 | bx.plot(step_list, cnn_list2, color="r", label=train) 92 | 93 | plt.tight_layout() 94 | plt.show() 95 | 96 | except tf.errors.OutOfRangeError: 97 | print('Done training -- epoch limit reached') 98 | finally: 99 | coord.request_stop() 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /Train_inputdata.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import os 4 | 5 | 6 | def get_files(file_dir): 7 | """ 8 | 输入: 存放训练照片的文件地址 9 | 返回: 图像列表, 标签列表 10 | """ 11 | # 建立空列表 12 | gesture_1 = [] # 剪刀 13 | label_gesture_1 = [] 14 | gesture_2 = [] # 石头 15 | label_gesture_2 = [] 16 | gesture_3 = [] # 布 17 | label_gesture_3 = [] 18 | gesture_4 = [] # OK 19 | label_gesture_4 = [] 20 | gesture_5 = [] # good 21 | label_gesture_5 = [] 22 | 23 | # 读取标记好的图像和加入标签 24 | for file in os.listdir(file_dir): # file就是要读取的照片 25 | name = file.split(sep='.') # 因为照片的格式是gesture_1.1.jpg/gesture_1.2.jpg 26 | if name[0] == 'gesture_1': # 所以只用读取 . 前面这个字符串 27 | gesture_1.append(file_dir + file) 28 | label_gesture_1.append(0) # 把图像和标签加入列表 29 | elif name[0] == 'gesture_2': 30 | gesture_2.append(file_dir + file) 31 | label_gesture_2.append(1) 32 | elif name[0] == 'gesture_3': 33 | gesture_3.append(file_dir + file) 34 | label_gesture_3.append(2) 35 | elif name[0] == 'gesture_4': 36 | gesture_4.append(file_dir + file) 37 | label_gesture_4.append(3) 38 | else: 39 | gesture_5.append(file_dir + file) 40 | label_gesture_5.append(4) 41 | 42 | print('There are %d scissors\nThere are %d paper\nThere are %d rock\nThere are %d ok\nThere are %d good' 43 | % (len(gesture_1), len(gesture_3), len(gesture_2), len(gesture_4), len(gesture_5))) 44 | 45 | image_list = np.hstack((gesture_1, gesture_2, gesture_3, gesture_4, gesture_5)) # 在水平方向平铺合成一个行向量 46 | label_list = np.hstack((label_gesture_1, label_gesture_2, label_gesture_3, label_gesture_4, label_gesture_5)) 47 | 48 | temp = np.array([image_list, label_list]) # 生成一个两行数组列表,大小是2 X (1300X4+1450) 49 | temp = temp.transpose() # 转置向量,大小变成(1300X4+1450) X 2 50 | np.random.shuffle(temp) # 乱序,打乱这2500个例子的顺序,一个乱序的效果不太明显 51 | np.random.shuffle(temp) 52 | np.random.shuffle(temp) 53 | 54 | image_list = list(temp[:, 0]) # 所有行,列=0 55 | label_list = list(temp[:, 1]) # 所有行,列=1 56 | label_list = [int(float(i)) for i in label_list] # 把标签列表转化为int类型 57 | 58 | return image_list, label_list 59 | 60 | 61 | def get_batch(image, label, image_W, image_H, batch_size, capacity): 62 | """ 63 | 输入: 64 | image,label :要生成batch的图像和标签 65 | image_W,image_H: 图像的宽度和高度 66 | batch_size: 每个batch(小批次)有多少张图片数据 67 | capacity: 队列的最大容量 68 | 返回: 69 | image_batch: 4D tensor [batch_size, width, height, 3], dtype=tf.float32 70 | label_batch: 1D tensor [batch_size], dtype=tf.int32 71 | """ 72 | # 将列表转换成tf能够识别的格式 73 | image = tf.cast(image, tf.string) 74 | label = tf.cast(label, tf.int32) 75 | 76 | # 生成队列(牵扯到线程概念,便于batch训练) 77 | """ 78 | 队列的理解:每次训练时,从队列中取一个batch送到网络进行训练, 79 | 然后又有新的图片从训练库中注入队列,这样循环往复。 80 | 队列相当于起到了训练库到网络模型间数据管道的作用, 81 | 训练数据通过队列送入网络。 82 | """ 83 | input_queue = tf.train.slice_input_producer([image, label]) 84 | 85 | # 图像的读取需要tf.read_file(),标签则可以直接赋值 86 | image_contents = tf.read_file(input_queue[0]) 87 | image = tf.image.decode_jpeg(image_contents, channels=3) # 解码彩色的.jpg图像 88 | label = input_queue[1] 89 | 90 | # 统一图片大小 91 | image = tf.image.resize_images(image, [image_H, image_W], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) 92 | image = tf.cast(image, tf.float32) 93 | image = tf.image.per_image_standardization(image) # 标准化图片(图片归一化) 94 | 95 | # 打包batch的大小 96 | image_batch, label_batch = tf.train.batch([image, label], 97 | batch_size=batch_size, 98 | num_threads=64, # 涉及到线程,配合队列 99 | capacity=capacity) 100 | 101 | # 下面两行代码应该也多余了,放在这里确保一下格式不会出问题 102 | image_batch = tf.cast(image_batch, tf.float32) 103 | label_batch = tf.cast(label_batch, tf.int32) 104 | 105 | return image_batch, label_batch 106 | -------------------------------------------------------------------------------- /Train_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def cnn_inference(images, batch_size, n_classes, keep_prob): 5 | 6 | """ 7 | 使用AlexNet结构 8 | 输入 9 | images 输入的图像 10 | batch_size 每个批次的大小 11 | n_classes n分类 12 | keep_prob droupout保存的比例( 设置神经元被选中的概率) 13 | 返回 14 | softmax_linear 还差一个softmax 15 | """ 16 | # 第一层的卷积层conv1,卷积核为3X3,有16个 17 | with tf.variable_scope('conv1') as scope: 18 | # 建立weights和biases的共享变量 19 | # conv1, shape = [kernel size, kernel size, channels, kernel numbers] 20 | weights = tf.get_variable('weights', 21 | shape=[11, 11, 3, 96], 22 | dtype=tf.float32, 23 | initializer=tf.truncated_normal_initializer(stddev=0.1,dtype=tf.float32)) # stddev标准差 24 | biases = tf.get_variable('biases', 25 | shape=[96], 26 | dtype=tf.float32, 27 | initializer=tf.constant_initializer(0.1)) 28 | # 卷积层 strides = [1, x_movement, y_movement, 1], padding填充周围有valid和same可选择 29 | conv = tf.nn.conv2d(images, weights, strides=[1, 4, 4, 1], padding='SAME') 30 | pre_activation = tf.nn.bias_add(conv, biases) # 加入偏差 31 | conv1 = tf.nn.relu(pre_activation, name=scope.name) # 加上激活函数非线性化处理,且是在conv1的命名空间 32 | 33 | # 第一层的池化层pool1和规范化norm1(特征缩放) 34 | with tf.variable_scope('pooling1_lrn') as scope: 35 | pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], 36 | padding='VALID', name='pooling1') 37 | norm1 = tf.nn.lrn(pool1, depth_radius=4, bias=1.0, alpha=0.001/9.0, 38 | beta=0.75,name='norm1') 39 | # ksize是池化窗口的大小=[1,height,width,1],一般height=width=池化窗口的步长 40 | # 池化窗口的步长一般是比卷积核多移动一位 41 | # tf.nn.lrn是Local Response Normalization,(局部响应归一化) 42 | 43 | # 第二层的卷积层cov2,这里的命名空间和第一层不一样,所以可以和第一层取同名 44 | with tf.variable_scope('conv2') as scope: 45 | weights = tf.get_variable('weights', 46 | shape=[5, 5, 96, 256], # 这里只有第三位数字96需要等于上一层的tensor维度 47 | dtype=tf.float32, 48 | initializer=tf.truncated_normal_initializer(stddev=0.1,dtype=tf.float32)) 49 | biases = tf.get_variable('biases', 50 | shape=[256], 51 | dtype=tf.float32, 52 | initializer=tf.constant_initializer(0.1)) 53 | conv = tf.nn.conv2d(norm1, weights, strides=[1, 1, 1, 1],padding='SAME') 54 | pre_activation = tf.nn.bias_add(conv, biases) 55 | conv2 = tf.nn.relu(pre_activation, name='conv2') 56 | 57 | # 第二层的池化层pool2和规范化norm2 58 | with tf.variable_scope('pooling2_lrn') as scope: 59 | norm2 = tf.nn.lrn(conv2, depth_radius=4, bias=1.0, alpha=0.001/9.0, 60 | beta=0.75,name='norm2') 61 | pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], 62 | padding='VALID',name='pooling2') 63 | # 这里选择了先规范化再池化 64 | 65 | # conv3 66 | with tf.variable_scope('conv3') as scope: 67 | weights = tf.get_variable('weights', 68 | shape=[3, 3, 256, 384], 69 | dtype=tf.float32, 70 | initializer=tf.truncated_normal_initializer(stddev=0.1,dtype=tf.float32)) 71 | biases = tf.get_variable('biases', 72 | shape=[384], 73 | dtype=tf.float32, 74 | initializer=tf.constant_initializer(0.1)) 75 | conv = tf.nn.conv2d(pool2, weights, strides=[1, 1, 1, 1],padding='SAME') 76 | pre_activation = tf.nn.bias_add(conv, biases) 77 | conv3 = tf.nn.relu(pre_activation, name='conv3') 78 | 79 | # conv4 80 | with tf.variable_scope('conv4') as scope: 81 | weights = tf.get_variable('weights', 82 | shape=[3, 3, 384, 384], 83 | dtype=tf.float32, 84 | initializer=tf.truncated_normal_initializer(stddev=0.1,dtype=tf.float32)) 85 | biases = tf.get_variable('biases', 86 | shape=[384], 87 | dtype=tf.float32, 88 | initializer=tf.constant_initializer(0.1)) 89 | conv = tf.nn.conv2d(conv3, weights, strides=[1, 1, 1, 1],padding='SAME') 90 | pre_activation = tf.nn.bias_add(conv, biases) 91 | conv4 = tf.nn.relu(pre_activation, name='conv4') 92 | 93 | # conv5 94 | with tf.variable_scope('conv5') as scope: 95 | weights = tf.get_variable('weights', 96 | shape=[3, 3, 384, 256], 97 | dtype=tf.float32, 98 | initializer=tf.truncated_normal_initializer(stddev=0.1,dtype=tf.float32)) 99 | biases = tf.get_variable('biases', 100 | shape=[256], 101 | dtype=tf.float32, 102 | initializer=tf.constant_initializer(0.1)) 103 | conv = tf.nn.conv2d(conv4, weights, strides=[1, 1, 1, 1],padding='SAME') 104 | pre_activation = tf.nn.bias_add(conv, biases) 105 | conv5 = tf.nn.relu(pre_activation, name='conv5') 106 | 107 | # 池化 108 | with tf.variable_scope('pooling') as scope: 109 | pooling = tf.nn.max_pool(conv5, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], 110 | padding='VALID', name='pooling1') 111 | 112 | # 第三层为全连接层local3 113 | with tf.variable_scope('local3') as scope: 114 | # flatten-把卷积过的多维tensor拉平成二维张量(矩阵) 115 | reshape = tf.reshape(pooling, shape=[batch_size, -1]) # batch_size表明了有多少个样本 116 | 117 | dim = reshape.get_shape()[1].value # 知道-1(代表任意)这里具体是多少个 118 | weights = tf.get_variable('weights', 119 | shape=[dim, 1024], # 连接1024个神经元 120 | dtype=tf.float32, 121 | initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32)) 122 | biases = tf.get_variable('biases', 123 | shape=[1024], 124 | dtype=tf.float32, 125 | initializer=tf.constant_initializer(0.1)) 126 | local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name) # 矩阵相乘加上bias 127 | local3 = tf.nn.dropout(local3, keep_prob) # 设置神经元被选中的概率 128 | 129 | # 第四层为全连接层local4 130 | with tf.variable_scope('local4') as scope: 131 | weights = tf.get_variable('weights', 132 | shape=[1024, 1024], # 再连接1024个神经元 133 | dtype=tf.float32, 134 | initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32)) 135 | biases = tf.get_variable('biases', 136 | shape=[1024], 137 | dtype=tf.float32, 138 | initializer=tf.constant_initializer(0.1)) 139 | local4 = tf.nn.relu(tf.matmul(local3, weights) + biases, name='local4') 140 | local4 = tf.nn.dropout(local4, keep_prob) 141 | 142 | # 第五层为输出层softmax_linear 143 | with tf.variable_scope('softmax_linear') as scope: 144 | weights = tf.get_variable('weights', 145 | shape=[1024, n_classes], 146 | dtype=tf.float32, 147 | initializer=tf.truncated_normal_initializer(stddev=0.005,dtype=tf.float32)) 148 | biases = tf.get_variable('biases', 149 | shape=[n_classes], 150 | dtype=tf.float32, 151 | initializer=tf.constant_initializer(0.1)) 152 | softmax_linear = tf.add(tf.matmul(local4, weights), biases, name='softmax_linear') 153 | # 这里只是命名为softmax_linear,真正的softmax函数放在下面的losses函数里面和交叉熵结合在一起了,这样可以提高运算速度。 154 | # softmax_linear的行数=local4的行数,列数=weights的列数=bias的行数=需要分类的个数 155 | # 经过softmax函数用于分类过程中,它将多个神经元的输出,映射到(0,1)区间内,可以看成概率来理解 156 | softmax_linear = tf.nn.dropout(softmax_linear, keep_prob) 157 | 158 | return softmax_linear 159 | 160 | 161 | def losses(logits, labels): 162 | """ 163 | 输入 164 | logits: 经过cnn_inference处理过的tensor 165 | labels: 对应的标签 166 | 返回 167 | loss: 损失函数(交叉熵) 168 | """ 169 | with tf.variable_scope('loss') as scope: 170 | # 下面把交叉熵和softmax合到一起写是为了通过spares提高计算速度 171 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels, name='loss_per_eg') 172 | loss = tf.reduce_mean(cross_entropy, name='loss') # 求所有样本的平均loss 173 | return loss 174 | 175 | 176 | def training(loss, learning_rate): 177 | """ 178 | 输入 179 | loss: 损失函数(交叉熵) 180 | learning_rate: 学习率 181 | 返回 182 | train_op: 训练的最优值 183 | """ 184 | with tf.name_scope('optimizer'): 185 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) 186 | # global_step不是共享变量,初始值为0,设定trainable=False 可以防止该变量被数据流图的 GraphKeys.TRAINABLE_VARIABLES 收集, 187 | # 这样我们就不会在训练的时候尝试更新它的值。 188 | global_step = tf.Variable(0, name='global_step', trainable=False) 189 | train_op = optimizer.minimize(loss, global_step= global_step) 190 | return train_op 191 | 192 | 193 | def evaluation(logits, labels): 194 | """ 195 | 输入 196 | logits: 经过cnn_inference处理过的tensor 197 | labels: 198 | 返回 199 | accuracy:正确率 200 | """ 201 | with tf.variable_scope('accuracy') as scope: 202 | prediction = tf.nn.softmax(logits) # 这个logits有n_classes列 203 | # prediction每行的最大元素(1)的索引和label的值相同则为1 否则为0。 204 | correct = tf.nn.in_top_k(prediction, labels, 1) 205 | # correct = tf.nn.in_top_k(logits, labels, 1) 也可以不需要prediction过渡,因为最大值的索引没变,这里这样写是为了更好理解 206 | correct = tf.cast(correct, tf.float16) # 记得要转换格式 207 | accuracy = tf.reduce_mean(correct) 208 | return accuracy 209 | -------------------------------------------------------------------------------- /ges_ico/frame.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingjianzhang1997/gesture-recognition/2298c4b11a5f96dc3a73a1ed392882215d3df07f/ges_ico/frame.ico -------------------------------------------------------------------------------- /ges_ico/ges1.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingjianzhang1997/gesture-recognition/2298c4b11a5f96dc3a73a1ed392882215d3df07f/ges_ico/ges1.ico -------------------------------------------------------------------------------- /ges_ico/ges2.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingjianzhang1997/gesture-recognition/2298c4b11a5f96dc3a73a1ed392882215d3df07f/ges_ico/ges2.ico -------------------------------------------------------------------------------- /ges_ico/ges3.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingjianzhang1997/gesture-recognition/2298c4b11a5f96dc3a73a1ed392882215d3df07f/ges_ico/ges3.ico -------------------------------------------------------------------------------- /ges_ico/ges4.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingjianzhang1997/gesture-recognition/2298c4b11a5f96dc3a73a1ed392882215d3df07f/ges_ico/ges4.ico -------------------------------------------------------------------------------- /ges_ico/ges5.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingjianzhang1997/gesture-recognition/2298c4b11a5f96dc3a73a1ed392882215d3df07f/ges_ico/ges5.ico -------------------------------------------------------------------------------- /ges_ico/white.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingjianzhang1997/gesture-recognition/2298c4b11a5f96dc3a73a1ed392882215d3df07f/ges_ico/white.ico --------------------------------------------------------------------------------