# Repository layout (reconstructed from the original repository dump):
#
#   ├── README.md
#   ├── code
#   │   ├── svm.py
#   │   └── svmtest.py
#   └── data
#       └── Mnist-image.rar
#
# ---------------------------------------------------------------------------
# /README.md
# ---------------------------------------------------------------------------
# # SVM
# Handwritten digit image recognition based on an SVM.
# Recreate the files following the directory layout above and the code
# runs as-is.
#
# Most of the content is adapted from the article by the original author at
# https://blog.csdn.net/liyuqian199695/article/details/54236092 ; I only made
# a few small changes, without which it would not run.
#
# ---------------------------------------------------------------------------
# /code/svm.py
# ---------------------------------------------------------------------------
"""Train an SVM classifier on 20x20 handwritten-digit PNG images.

Expected layout, relative to the project root (the parent of ``code/``)::

    data/Mnist-image/train/<digit>/<digit>_<n>.png
    model/svm.model        (written by running this file as a script)
"""
import os
import sys  # kept for backward compatibility of the module namespace
import time

import numpy as np
from PIL import Image
from sklearn import svm

try:  # joblib was removed from sklearn.externals in scikit-learn 0.23
    import joblib
except ImportError:  # very old scikit-learn that only shipped the vendored copy
    from sklearn.externals import joblib

# Number of features per image: 20 px * 20 px, flattened.
IMG_SIZE = 400
# Digit classes; each has its own sub-folder under train/ and test/.
CLASS_NAMES = [str(d) for d in range(10)]
# Project root = parent of the directory holding this file (code/). Derived
# from __file__ instead of sys.path[1], whose value depends on how Python
# was launched.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


def get_file_list(path):
    """Return the paths of all ``.png`` files directly inside *path*.

    The list is sorted so runs are reproducible (``os.listdir`` order is
    arbitrary).
    """
    return sorted(
        os.path.join(path, name) for name in os.listdir(path) if name.endswith(".png")
    )


def get_img_name_str(imgPath):
    """Return the file-name component of *imgPath*, e.g. ``0_1.png``."""
    # basename also understands '/' on Windows, unlike split(os.path.sep).
    return os.path.basename(imgPath)


def img2vector(imgFile):
    """Load an image as a 1 x (width*height) vector of 0/1 pixel values.

    The image is converted to 8-bit grayscale; each pixel is divided by 255
    and rounded, which binarises it (grey levels >= 128 become 1, lower
    values become 0).

    Args:
        imgFile: path of the image, e.g. ``.../0_1.png``.

    Returns:
        A numpy array of shape (1, width*height); (1, 400) for 20x20 images.
    """
    # Context manager closes the underlying file handle Pillow keeps open.
    with Image.open(imgFile) as img:
        img_arr = np.array(img.convert("L"), "i")
    img_binary = np.round(img_arr / 255)
    return np.reshape(img_binary, (1, -1))


def read_and_convert(imgFileList):
    """Convert a list of image files into a feature matrix and labels.

    Each file name must look like ``<digit>_<index>.png``; the part before
    the first ``_`` is used as the class label. Images must be 20x20.

    Args:
        imgFileList: paths of the images to load.

    Returns:
        (dataMat, dataLabel): a ``len(imgFileList) x 400`` float matrix and a
        list of string labels in the same order.
    """
    dataLabel = []
    dataMat = np.zeros((len(imgFileList), IMG_SIZE))
    for i, img_path in enumerate(imgFileList):
        img_name = get_img_name_str(img_path)
        dataLabel.append(img_name.split(".")[0].split("_")[0])
        dataMat[i, :] = img2vector(img_path)
    return dataMat, dataLabel


def read_all_data(base_path=None):
    """Read the full training set (classes 0-9).

    Prints each class name as it starts loading that class.

    Args:
        base_path: project root that contains ``data/``. Defaults to the
            parent of this file's directory.

    Returns:
        (dataMat, dataLabel): a ``n_samples x 400`` matrix and a numpy array
        of string labels.
    """
    if base_path is None:
        base_path = _PROJECT_ROOT
    train_root = os.path.join(base_path, "data", "Mnist-image", "train")
    mats = []
    labels = []
    for c in CLASS_NAMES:
        print(c)
        mat, lab = read_and_convert(get_file_list(os.path.join(train_root, c)))
        mats.append(mat)
        labels.extend(lab)
    # Concatenate once at the end instead of re-copying the growing matrix
    # for every class.
    return np.concatenate(mats, axis=0), np.array(labels)


def create_svm(dataMat, dataLabel, path, decision='ovr'):
    """Fit an ``sklearn.svm.SVC``, save it with joblib and return it.

    Args:
        dataMat: training feature matrix.
        dataLabel: training labels.
        path: destination file for ``joblib.dump``; missing parent
            directories are created.
        decision: ``decision_function_shape`` passed to ``SVC``.

    Returns:
        The fitted classifier.
    """
    clf = svm.SVC(decision_function_shape=decision)
    clf.fit(dataMat, dataLabel)
    model_dir = os.path.dirname(path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    joblib.dump(clf, path)
    return clf


if __name__ == '__main__':
    # time.clock() was removed in Python 3.8; perf_counter measures wall time.
    st = time.perf_counter()
    dataMat, dataLabel = read_all_data()
    model_path = os.path.join(_PROJECT_ROOT, 'model', 'svm.model')
    create_svm(dataMat, dataLabel, model_path, decision='ovr')
    et = time.perf_counter()
    print("Training spent {:.4f}s.".format(et - st))
# ---------------------------------------------------------------------------
# /code/svmtest.py
# ---------------------------------------------------------------------------
"""Evaluate the SVM model trained by svm.py on the per-digit test folders."""
import os
import sys  # kept for backward compatibility of the module namespace
import time

import numpy as np

try:  # joblib was removed from sklearn.externals in scikit-learn 0.23
    import joblib
except ImportError:  # very old scikit-learn that only shipped the vendored copy
    from sklearn.externals import joblib

try:
    # Original setup: project root on sys.path and code/ importable as a package.
    from code import svm
except ImportError:
    # "code" is also a standard-library module and usually wins that lookup.
    # When this file runs as a script its own directory is sys.path[0], so
    # the sibling svm.py can be imported directly.
    import svm


def _project_root():
    """Return the parent of the directory that holds this file."""
    return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


def svmtest(model_path, data_path=None):
    """Evaluate a saved model on every digit folder of the test set.

    For each class 0-9 this loads ``<data_path>/<digit>/*.png``. It prints
    the prediction time, the number of images whose prediction differs from
    the folder name, and the accuracy against the file-name labels. Empty
    folders are skipped. At the end it prints totals, the unweighted mean of
    the per-class accuracies ("Average accuracy") and the accuracy over all
    test images ("Overall accuracy").

    Args:
        model_path: model file written by ``joblib.dump`` (see svm.py). It is
            unpickled, so only load files you trust.
        data_path: directory containing the per-digit test folders. Defaults
            to ``<project root>/data/Mnist-image/test``.
    """
    if data_path is None:
        data_path = os.path.join(_project_root(), "data", "Mnist-image", "test")
    tcName = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    tst = time.perf_counter()
    allErrCount = 0
    allScore = 0.0
    allCorrect = 0
    allCount = 0
    evaluated = 0
    clf = joblib.load(model_path)
    for tcn in tcName:
        tflist = svm.get_file_list(os.path.join(data_path, tcn))
        if not tflist:
            # predict() rejects a 0-row matrix; skip instead of crashing.
            print("No test images for class " + tcn + ", skipped.")
            continue
        tdataMat, tdataLabel = svm.read_and_convert(tflist)
        print("test dataMat shape: {0}, test dataLabel len: {1} ".format(tdataMat.shape, len(tdataLabel)))

        pre_st = time.perf_counter()
        preResult = clf.predict(tdataMat)
        pre_et = time.perf_counter()
        print("Recognition " + tcn + " spent {:.4f}s.".format(pre_et - pre_st))
        # Errors are counted against the folder name (the expected digit).
        errCount = sum(1 for x in preResult if x != tcn)
        print("errorCount: {}.".format(errCount))
        allErrCount += errCount

        score_st = time.perf_counter()
        # Same value as clf.score(tdataMat, tdataLabel), but reuses the
        # predictions instead of running the SVM a second time.
        correct = int(np.sum(preResult == np.asarray(tdataLabel)))
        score = correct / len(tdataLabel)
        score_et = time.perf_counter()
        print("computing score spent {:.6f}s.".format(score_et - score_st))
        allScore += score
        allCorrect += correct
        allCount += len(tdataLabel)
        evaluated += 1
        print("score: {:.6f}.".format(score))
        print("error rate is {:.6f}.".format(1 - score))

    tet = time.perf_counter()
    print("Testing All class total spent {:.6f}s.".format(tet - tst))
    print("All error Count is: {}.".format(allErrCount))
    if evaluated == 0:
        print("No test images found under {}.".format(data_path))
        return
    # Unweighted mean over the classes that were actually evaluated.
    avgAccuracy = allScore / evaluated
    print("Average accuracy is: {:.6f}.".format(avgAccuracy))
    print("Average error rate is: {:.6f}.".format(1 - avgAccuracy))
    print("Overall accuracy is: {:.6f}.".format(allCorrect / allCount))
if __name__ == '__main__':
    # Project root = parent of the directory holding this file (code/).
    # Derived from __file__ instead of sys.path[1], which depends on how
    # Python was launched; os.path.join keeps the path portable.
    path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    model_path = os.path.join(path, 'model', 'svm.model')
    svmtest(model_path)

# ---------------------------------------------------------------------------
# /data/Mnist-image.rar: binary archive, not inlined. Download it from
# https://raw.githubusercontent.com/ECNUHP/SVM/da83babb6dc8e56512fa5b2edb8131aaba33a56c/data/Mnist-image.rar
# ---------------------------------------------------------------------------