├── MNIST_data ├── t10k-images-idx3-ubyte.gz ├── t10k-labels-idx1-ubyte.gz ├── train-images-idx3-ubyte.gz └── train-labels-idx1-ubyte.gz ├── Model ├── checkpoint ├── mnist_model.ckpt.index └── mnist_model.ckpt.meta ├── PaintBoard.py ├── README ├── README.md ├── Recognize.py ├── Train.py ├── image ├── 1.jpg ├── num.zip └── test.png └── main.py /MNIST_data/t10k-images-idx3-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-huicheng/Learning-diary/cad83acb549676976ca10c6e2d1b0cd6ddeed36c/MNIST_data/t10k-images-idx3-ubyte.gz -------------------------------------------------------------------------------- /MNIST_data/t10k-labels-idx1-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-huicheng/Learning-diary/cad83acb549676976ca10c6e2d1b0cd6ddeed36c/MNIST_data/t10k-labels-idx1-ubyte.gz -------------------------------------------------------------------------------- /MNIST_data/train-images-idx3-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-huicheng/Learning-diary/cad83acb549676976ca10c6e2d1b0cd6ddeed36c/MNIST_data/train-images-idx3-ubyte.gz -------------------------------------------------------------------------------- /MNIST_data/train-labels-idx1-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-huicheng/Learning-diary/cad83acb549676976ca10c6e2d1b0cd6ddeed36c/MNIST_data/train-labels-idx1-ubyte.gz -------------------------------------------------------------------------------- /Model/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "mnist_model.ckpt" 2 | all_model_checkpoint_paths: "mnist_model.ckpt" 3 | -------------------------------------------------------------------------------- /Model/mnist_model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-huicheng/Learning-diary/cad83acb549676976ca10c6e2d1b0cd6ddeed36c/Model/mnist_model.ckpt.index -------------------------------------------------------------------------------- /Model/mnist_model.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-huicheng/Learning-diary/cad83acb549676976ca10c6e2d1b0cd6ddeed36c/Model/mnist_model.ckpt.meta -------------------------------------------------------------------------------- /PaintBoard.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import sys 5 | from PyQt5.QtWidgets import QWidget 6 | from PyQt5.Qt import * 7 | from PyQt5.QtCore import Qt 8 | 9 | class PaintBoard(QWidget): 10 | 11 | def __init__(self): 12 | super().__init__() 13 | 14 | self.init() #初始化 15 | 16 | def init(self): 17 | 18 | self.size = QSize(280,280) 19 | 20 | self.board = QPixmap(self.size) #新建 QPixmap 作为画板 21 | self.board.fill(Qt.white) #用白色填充画板 22 | 23 | self.lastPos = QPoint(0,0) # 鼠标位置 24 | self.currentPos = QPoint(0,0) 25 | 26 | self.penThickness = 23 #画笔宽度 27 | self.penColor = QColor('black') 28 | 29 | self.painter = QPainter() #定义画笔 30 | 31 | self.setFixedSize(self.size) 32 | 33 | def Clear(self): #清空画板,即用白色填充 34 | self.board.fill(Qt.white) 35 | self.update() 36 | 37 | def GetContentAsQImage(self): #将画板转化为图片 38 | image = self.board.toImage() 39 | return image 40 | 41 | def paintEvent(self, paintEvent): #由起始位置到 42 | self.painter.begin(self) 43 | self.painter.drawPixmap(0,0,self.board) 44 | self.painter.end() 45 | 46 | def mousePressEvent(self, mouseEvent): #鼠标点击 47 | 48 | self.currentPos = mouseEvent.pos() 49 | self.lastPos = self.currentPos 50 | 51 | 52 | def mouseMoveEvent(self,mouseEvent): #鼠标移动并画线 53 | self.currentPos = mouseEvent.pos() 54 | self.painter.begin(self.board) 55 | self.painter.setPen(QPen(self.penColor,self.penThickness)) 56 | self.painter.drawLine(self.lastPos,self.currentPos) 57 | self.painter.end() 58 | self.lastPos = self.currentPos 59 | self.update() 60 | 61 | 62 | if __name__ == '__main__': 63 | 64 | app = QApplication(sys.argv) 65 | pb = PaintBoard() 66 | pb.show() 67 | 68 | sys.exit(app.exec_()) 69 | 70 | -------------------------------------------------------------------------------- /README: -------------------------------------------------------------------------------- 1 | DATE: 2018.4.5 2 | 3 | 手写体识别 4 | 5 | http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html 6 | 7 | 由softmax()回归方程实现 8 | 9 | y = tf.nn.softmax(tf.matmul(x * W) + b) 10 | 11 | 12 | loss = tf.ruduce_sum(y_ * tf.log(y) 13 | 14 | PyQt5 shixianjiemian 15 | 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning-diary 2 | 3 | 基于MNIST 手写数字数据集,运用卷积神经网络(`Convolutional Neural Network,CNN`)算法实现手写数字识别。 4 | 首先,设计手写体数字识别模型算法,通过`TensorFlow`实现算法;其次,从`MNIST`数据库中每次选出50组图片作为一个集合对算法进行训练, 5 | 不断优化算法的模型参数,并保存模型参数; 6 | 然后,设计识别算法,其中包含图片进行预处理、读取模型参数,识别手写体数字图片等功能; 7 | 最后,由`PyQt5`实现可视化界面,使用可视化界面的画板手写数字,使用识别算法对手写体数字进行识别,返回识别结果。 8 | 9 | -------------------------------------------------------------------------------- /Recognize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import cv2 4 | import tensorflow as tf 5 | import numpy as np 6 | from PIL import Image,ImageFilter 7 | from tensorflow.examples.tutorials.mnist import input_data 8 | 9 | class Recognize(): 10 | 11 | def __init__(self): #构造函数 将 Train 训练的模型导入 12 | 13 | self.sess = tf.Session() 14 | self.saver = tf.train.import_meta_graph('Model/mnist_model.ckpt.meta') 15 | self.saver.restore(self.sess, 'Model/mnist_model.ckpt') 16 | 17 | self.graph = tf.get_default_graph() 18 | self.x_data = self.graph.get_tensor_by_name('x_data:0') 19 | self.y_data = self.graph.get_tensor_by_name('y_data:0') 20 | self.keep_prob = self.graph.get_tensor_by_name('keep_prob:0') 21 | 22 | self.loss = self.graph.get_tensor_by_name('loss:0') 23 | self.train = self.graph.get_tensor_by_name('Variable/Adam:0') 24 | self.y = self.graph.get_tensor_by_name('y:0') 25 | 26 | 27 | def __del__(self): #析构函数 讲学习后的模型保存 并 关闭回话 28 | print('#') 29 | self.saver.save(self.sess,'Model/mnist_model.ckpt') 30 | self.sess.close() 31 | print('#') 32 | 33 | 34 | def Picture(self,fileName): # 通过文件路径 对图片处理 35 | 36 | image = Image.open(fileName).convert('L') 37 | image = image.resize((28,28),Image.ANTIALIAS) 38 | image.save(fileName) 39 | 40 | tv = list(image.getdata()) 41 | value = [(255-x)*1.0/255.0 for x in tv] 42 | #value = [(255-x)*1.0/255.0-0.4 for x in tv] 43 | #value = [x if(x>0) else 0 for x in value] 44 | v = np.array(value) 45 | v = v.reshape((28,28)) 46 | a = np.max(v,1); 47 | b = np.max(v,0); 48 | for i in range(27,-1,-1): 49 | if a[i] == 0: 50 | v = np.delete(v,i,0) 51 | if b[i] == 0: 52 | v = np.delete(v,i,1) 53 | len0 = 28 - len(v) 54 | len1 = 28 - len(v[0]) 55 | i0 = len0 // 2 56 | j0 = len0 - i0 57 | i1 = len1 // 2 58 | j1 = len1 - i1 59 | v = np.pad(v,((i0,j0),(i1,j1)),'constant',constant_values = (0,0)) 60 | #self.printf(v) 61 | value = v.reshape(784) 62 | return value 63 | 64 | 65 | def recognizeNumber(self,image): # 将处理过得图片识别成 独热码 66 | 67 | result = self.sess.run(self.y,feed_dict={self.x_data:image,self.keep_prob:1.0}) 68 | 69 | for x in result: 70 | for i in x: 71 | print('%.2f' % (i),end=' ') 72 | print() 73 | 74 | label = np.argmax(result, 1) 75 | return label 76 | 77 | def Binarization(self,path,savePath): 78 | image = cv2.imread(path,0) 79 | ret,image = cv2.threshold(image,127,255,cv2.THRESH_BINARY) 80 | im = Image.fromarray(image) 81 | im.save(savePath) 82 | 83 | 84 | def Learn(self,image,label): # 通过提供 图片 和 label 训练模型 85 | self.sess.run(self.train,feed_dict={self.x_data:image, 86 | self.y_data:label,keep_prob:0.5}) 87 | ''' 88 | for x in image: 89 | self.printf(x,28,28) 90 | for x in label: 91 | print(x) 92 | ''' 93 | 94 | 95 | def printf(self,data): # 将图片以 0 1 形式输出 96 | 97 | for i in range(len(data)): 98 | for j in range(len(data[0])): 99 | #print("%.2f"%(data[i][j]),end=' ') 100 | if data[i][j] > 0.05: 101 | print('1',end='') 102 | else: 103 | print('0',end='') 104 | print() 105 | 106 | 107 | if __name__=='__main__': #用于测试: 108 | 109 | 110 | image = Image.open('image/1.jpg').convert('L') 111 | image.save('image/1.jpg') 112 | image = [] 113 | re = Recognize() 114 | mnist = input_data.read_data_sets('MNIST_data', one_hot=True) 115 | test_x = mnist.test.images[:20] 116 | test_y = mnist.test.labels[:20] 117 | label = re.recognizeNumber(test_x) 118 | print(label) 119 | print(np.argmax(test_y,1)) 120 | 121 | -------------------------------------------------------------------------------- /Train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | #-*- coding: utf-8 -*- 3 | 4 | import os 5 | import numpy as np 6 | import tensorflow as tf 7 | from tensorflow.examples.tutorials.mnist import input_data 8 | 9 | mnist = input_data.read_data_sets('MNIST_data', one_hot=True) 10 | # one_hot 独热码的编码(encoding)形式 11 | # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 的十位数字 12 | # 0 : 1000000000 13 | # 1 : 0100000000 14 | # 2 : 0010000000 15 | # 3 : 0001000000 16 | # 4 : 0000100000 ... 17 | 18 | 19 | x_data = tf.placeholder("float", shape=[None, 784], name='x_data') # 输入 20 | y_data = tf.placeholder("float", shape=[None, 10], name='y_data') # 实际值 21 | 22 | # 初始化权重 23 | def weight_variable(shape,names): 24 | initial = tf.truncated_normal(shape, stddev=0.1) # 产生正态分布 标准差0.1 25 | return tf.Variable(initial,name=names) 26 | # 初始化偏置 27 | def bias_variable(shape,names): 28 | initial = tf.constant(0.1, shape=shape) # 定义常量 29 | return tf.Variable(initial,name=names) 30 | 31 | # 卷积层 32 | def conv2d(x,W): 33 | return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 34 | ''' 35 | tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, name=None) 36 | input: 输入图像,张量[batch, in_height, in _width, in_channels] 37 | filter: 卷积核, 张量[filter_height, filter_width, in_channels, out_channels] 38 | strides: 步长,一维向量,长度4 39 | padding:卷积方式,'SAME' 'VALID' 40 | ''' 41 | # 池化层 42 | def max_pool_2x2(x): 43 | return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') # 最大池化 44 | ''' 45 | tf.nn.max_pool(value, ksize, strides, padding, name=None) 46 | value: 输入,一般是卷积层的输出 feature map 47 | ksize: 池化窗口大小,[1, height, width, 1] 48 | strides: 窗口每个维度滑动步长 [1, strides, strides, 1] 49 | padding:和卷积类似,'SAME' 'VALID' 50 | ''' 51 | # 第一层卷积 卷积在每个5*5中算出32个特征 52 | W_conv1 = weight_variable([5, 5, 1, 32],'W_conv1') 53 | b_conv1 = bias_variable([32],'b_conv1') 54 | tf.summary.histogram('W_conv1',W_conv1) 55 | tf.summary.histogram('b_conv1',b_conv1) 56 | 57 | x_image = tf.reshape(x_data, [-1, 28, 28, 1]) 58 | tf.summary.image('input', x_image, 10) 59 | 60 | h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 61 | h_pool1 = max_pool_2x2(h_conv1) 62 | 63 | # 第二层卷积 64 | W_conv2 = weight_variable([5, 5, 32, 64],'W_conv2') 65 | b_conv2 = bias_variable([64],'b_conv2') 66 | tf.summary.histogram('W_conv2',W_conv2) 67 | tf.summary.histogram('b_conv2',b_conv2) 68 | 69 | h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 70 | h_pool2 = max_pool_2x2(h_conv2) 71 | 72 | # 密集连接层 图片尺寸缩减到了7*7, 本层用1024个神经元处理 73 | W_fc1 = weight_variable([7 * 7 * 64, 1024],'W_fc1') 74 | b_fc1 = bias_variable([1024],'b_fc1') 75 | 76 | h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64]) 77 | h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 78 | 79 | # dropout 防止过拟合 80 | keep_prob = tf.placeholder("float", name='keep_prob') 81 | h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) 82 | 83 | # 输出层 最后添加一个Softmax层 84 | W_fc2 = weight_variable([1024, 10],'W_fc2') 85 | b_fc2 = bias_variable([10],'b_fc2') 86 | 87 | y = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2, name='y') 88 | 89 | # 训练和评估模型 90 | loss = -tf.reduce_sum(y_data * tf.log(y),name = 'loss') 91 | tf.summary.scalar('loss',loss) 92 | train = tf.train.AdamOptimizer(1e-4).minimize(loss) 93 | correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_data, 1)) 94 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"),name='accuracy') 95 | tf.summary.scalar('accuracy',accuracy) 96 | sess = tf.Session() 97 | sess.run(tf.global_variables_initializer()) 98 | 99 | saver = tf.train.Saver() 100 | logdir = 'log_1' 101 | merge = tf.summary.merge_all() 102 | writer = tf.summary.FileWriter(logdir,sess.graph) 103 | 104 | for i in range(10000): 105 | batch = mnist.train.next_batch(50) 106 | if i % 100 == 0: 107 | train_accuracy = sess.run(accuracy,feed_dict={x_data: batch[0], y_data: batch[1], keep_prob: 1.0}) 108 | print("step %d, training accuracy %g"%(i, train_accuracy)) 109 | _,summary = sess.run([train,merge],feed_dict={x_data: batch[0], y_data: batch[1], keep_prob: 0.5}) 110 | writer.add_summary(summary, global_step=i) 111 | print(sess.run(accuracy,feed_dict={x_data: mnist.test.images, y_data: mnist.test.labels, keep_prob: 1.0})) 112 | 113 | # 保存模型 114 | saver.save(sess, "Model/mnist_model.ckpt") 115 | 116 | writer.close() 117 | sess.close() 118 | -------------------------------------------------------------------------------- /image/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-huicheng/Learning-diary/cad83acb549676976ca10c6e2d1b0cd6ddeed36c/image/1.jpg -------------------------------------------------------------------------------- /image/num.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-huicheng/Learning-diary/cad83acb549676976ca10c6e2d1b0cd6ddeed36c/image/num.zip -------------------------------------------------------------------------------- /image/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-huicheng/Learning-diary/cad83acb549676976ca10c6e2d1b0cd6ddeed36c/image/test.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import sys #导入必要的包 5 | import numpy as np 6 | from PaintBoard import PaintBoard #自定义的包 7 | from Recognize import Recognize #自定义的包 8 | from PyQt5.Qt import * 9 | from PyQt5.QtCore import Qt 10 | from PyQt5.QtWidgets import * 11 | from PyQt5.QtGui import QIcon,QPixmap 12 | 13 | class Windows(QWidget): 14 | 15 | def __init__(self): # 类的构造函数 16 | 17 | super().__init__() 18 | 19 | self.init() 20 | 21 | 22 | def init(self): # 初始化函数 被构造函数调用 23 | 24 | self.labels = [] # 数据初始化 labels为独热码 25 | self.images = [] # images 是图片 26 | self.paintBoard = PaintBoard() # 画板 27 | self.re = Recognize() 28 | # UI初始化 29 | self.clear = QPushButton('清空',self) # 创建 button 并关联事件处理器函数 30 | self.clear.setToolTip('点击可清空画板、识别结果、打开的图片') 31 | self.clear.clicked.connect(self.Clear) 32 | 33 | self.recognize = QPushButton('数字识别',self) 34 | self.recognize.clicked.connect(self.onClickRecognize) 35 | 36 | self.resLabel = QLabel(self) # 创建 Label 37 | self.resLabel.setFixedSize(80,45) 38 | self.resLabel.setText('识别结果:') 39 | self.resLabel.setAlignment(Qt.AlignCenter) 40 | 41 | self.result = QLabel(self) 42 | self.result.setFixedSize(45,45) 43 | self.result.setFont(QFont('sans-serif',20,QFont.Bold)) 44 | self.result.setStyleSheet('QLabel{border:2px solid black}') 45 | self.result.setAlignment(Qt.AlignCenter) 46 | 47 | splitter = QSplitter(self) # 占位符 48 | 49 | self.numImage = QLabel(self) 50 | self.numImage.setAlignment(Qt.AlignCenter) 51 | 52 | self.open = QPushButton('打开图片并识别',self) 53 | self.open.clicked.connect(self.openImage) 54 | 55 | self.save = QPushButton('保存图片',self) 56 | self.save.clicked.connect(self.saveImage) 57 | 58 | resLayout = QHBoxLayout() # 识别结果显示的布局 59 | resLayout.addWidget(splitter) 60 | resLayout.addWidget(self.resLabel) 61 | resLayout.addWidget(self.result) 62 | resLayout.addWidget(splitter) 63 | 64 | imageLayout = QHBoxLayout() # 打开图片显示的布局 65 | imageLayout.addWidget(splitter) 66 | imageLayout.addWidget(self.numImage) 67 | imageLayout.addWidget(splitter) 68 | 69 | menu = QVBoxLayout() # 创建右侧工具栏 垂直布局 70 | menu.setContentsMargins(10,10,10,10) 71 | menu.addWidget(self.clear) # 将 button and label 添加到右侧 72 | menu.addWidget(self.recognize) 73 | menu.addLayout(resLayout) 74 | menu.addWidget(splitter) 75 | menu.addLayout(imageLayout) 76 | menu.addWidget(self.open) 77 | menu.addWidget(self.save) 78 | 79 | subLayout1 = QHBoxLayout() 80 | subLayout1.addWidget(self.paintBoard) 81 | subLayout1.addLayout(menu) 82 | 83 | self.numLabel = QLabel(self) # 底侧的 button and label 84 | self.numLabel.setFixedSize(45,30) 85 | self.numLabel.setText('数字:') 86 | self.numLabel.setAlignment(Qt.AlignCenter) 87 | 88 | self.number = QSpinBox(self) 89 | self.number.setFixedSize(35,30) 90 | self.number.setMaximum(9) 91 | self.number.setMinimum(0) 92 | self.number.setSingleStep(1) 93 | 94 | 95 | self.addData = QPushButton('添加数据',self) 96 | self.addData.setToolTip('用于添加数据') 97 | self.addData.setFixedSize(85,30) 98 | self.addData.clicked.connect(self.AddData) 99 | 100 | self.learn = QPushButton('执行学习',self) 101 | self.learn.setFixedSize(85,30) 102 | self.learn.clicked.connect(self.Learn) 103 | 104 | self.status = QLabel(self) 105 | self.status.setText('--状态栏--') 106 | self.status.setToolTip('我是 --状态栏--') 107 | self.status.setAlignment(Qt.AlignCenter) 108 | 109 | subLayout2 = QHBoxLayout() # 底侧的布局 110 | subLayout2.addWidget(self.numLabel) 111 | subLayout2.addWidget(self.number) 112 | subLayout2.addWidget(self.addData) 113 | subLayout2.addWidget(self.learn) 114 | subLayout2.addWidget(self.status) 115 | 116 | mainLayout = QVBoxLayout(self) # 总布局 117 | mainLayout.setSpacing(10) 118 | mainLayout.addLayout(subLayout1) 119 | mainLayout.addLayout(subLayout2) 120 | 121 | self.setLayout(mainLayout) 122 | 123 | self.setFixedSize(550,340) # 设置属性 124 | self.setWindowTitle('***手写体数字识别***') # 标题 125 | self.setWindowIcon(QIcon('image/Icon.jpg')) # 图标 126 | self.center() # 居中 127 | self.show() 128 | 129 | # 事件处理器函数 130 | def Clear(self): # 清空界面 131 | 132 | self.paintBoard.Clear() 133 | self.result.setText('') 134 | self.numImage.clear() 135 | self.status.setText('--已清空--') 136 | 137 | 138 | def saveImage(self): # 保存图片 139 | 140 | savePath = QFileDialog.getSaveFileName(self,'saveImage', 141 | '/home/cheng/test.png','Image(*.png *.jpg)') 142 | if savePath[0] == '': 143 | self.status.setText('--取消保存--') 144 | return 145 | image = self.paintBoard.GetContentAsQImage() 146 | print(savePath[0]) 147 | image.save(savePath[0]) 148 | self.status.setText('--图片保存成功--') 149 | 150 | 151 | def openImage(self): #打开图片 并识别 152 | 153 | openPath = QFileDialog.getOpenFileName(self,'openImage', 154 | 'image','Image(*.png *.jpg *.bmp)') 155 | if openPath[0] == '': 156 | self.status.setText('--取消打开图片--') 157 | return 158 | 159 | print(openPath[0]) 160 | savePath = 'image/test.png' 161 | self.re.Binarization(openPath[0],savePath) 162 | image = QPixmap(savePath).scaled(40,40) 163 | self.numImage.setPixmap(image) 164 | 165 | picture = [] 166 | picture.append(self.re.Picture(savePath)) 167 | label = self.re.recognizeNumber(picture) 168 | self.result.setText(str(label[0])) 169 | self.status.setText('--以识别--') 170 | 171 | 172 | def onClickRecognize(self): # 识别画板图片 173 | 174 | image = self.paintBoard.GetContentAsQImage() 175 | image.save('image/test.png') 176 | picture = [] 177 | picture.append(self.re.Picture('image/test.png')) 178 | label = self.re.recognizeNumber(picture) 179 | self.result.setText(str(label[0])) 180 | self.status.setText('--以识别--') 181 | 182 | 183 | def AddData(self): # 添加数据 184 | 185 | index = self.number.text() 186 | label = np.zeros([10]) 187 | label[int(index)] = 1 188 | self.labels.append(label) 189 | image = self.paintBoard.GetContentAsQImage() 190 | image.save('image/test.png') 191 | self.images.append(self.re.Picture('image/test.png')) 192 | self.status.setText('--添加成功--') 193 | 194 | 195 | def Learn(self): # 执行学习 196 | 197 | if len(self.images) == 0: 198 | self.status.setText('--没有数据 请添加--') 199 | return 200 | 201 | self.status.setText('--学习中。。。--') 202 | self.re.Learn(self.images,self.labels) 203 | self.status.setText('--学习完成--') 204 | 205 | 206 | def center(self): # 使窗口居中 207 | 208 | qr = self.frameGeometry() 209 | cp = QDesktopWidget().availableGeometry().center() 210 | qr.moveCenter(cp) 211 | self.move(qr.topLeft()) 212 | 213 | 214 | def closeEvent(self,event): # 退出时调用Recognize的析构函数 215 | del self.re 216 | event.accept() 217 | 218 | 219 | if __name__ == '__main__': # 只有执行本文件使执行 220 | 221 | app = QApplication(sys.argv) 222 | window = Windows() 223 | 224 | sys.exit(app.exec_()) 225 | 226 | 227 | --------------------------------------------------------------------------------