├── .gitignore ├── README-zh.md ├── README.md ├── icon.jpg ├── main.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | /__pycache__ 2 | /data 3 | /output/* 4 | -------------------------------------------------------------------------------- /README-zh.md: -------------------------------------------------------------------------------- 1 | # PyQt5-RamanSpectraClassification 2 | 3 | 基于 PyQt5 的拉曼光谱分类 GUI 4 | 5 | version: 2.0 6 | 7 | © 本工作属于学长的硕士毕业论文《拉曼光谱结合化学计量学在磷矿检测中的应用》(2020) 8 | 9 | 如果本工作对您有帮助,请留下一个star,万分感激! 10 | 11 | ## 主要功能 12 | 13 | ### 导入训练集 14 | 15 | 所需的训练集格式要求: 16 | 17 | * 要求为 csv 文件 18 | * 第一行为浓度标记,二分类时将根据分类阈值分为高低两类 19 | * 第一列为光谱波长,在分析时无具体意义 20 | 21 | ### 参数配置 22 | 23 | 分类模式: 24 | 25 | * 二分类,根据分类阈值将浓度标记分为高低两类 26 | * 多分类,直接按浓度标记分为原本的类目 27 | * 展示分类数目 28 | 29 | 主成分数目: 30 | 31 | * 可自定义 32 | * 默认 3 种 33 | 34 | 交叉验证折数: 35 | 36 | * 可自定义 37 | * 默认 10 折 38 | 39 | ### 主成分分析 40 | 41 | 使用 PCA 方法,根据配置的主成分数目进行降维,并输出每维的方差贡献率及其总和 42 | 43 | ### 支持向量机训练 44 | 45 | 使用网格搜索来寻找最优参数,并展示 46 | 47 | * 二分类时标签为 0,1 48 | * 多分类时为 浓度 * 100 即去除了小数点,并转为 int 类型 49 | * 根据最优参数进行交叉验证,并输出交叉验证的准确率及平均准确率 50 | 51 | ### 保存和导入 SVM 模型文件 52 | 53 | 可以保存和导入上述步骤的模型文件,免去重新训练的重复步骤 54 | 55 | ### 预测 56 | 57 | 导入需要预测的数据集进行预测: 58 | 59 | * 预测集的读取方法和训练集相同,因此需要保证第一行第一列无数据,因为这部分内容会被丢弃 60 | * 预测结果也会保存至 csv 文件中 61 | 62 | ### 数据展示 63 | 64 | 每步操作都会即时在右侧表格中展示当前步骤的输出数据 65 | 66 | ## 使用说明 67 | 68 | ### 运行环境 69 | 70 | Anaconda (Python 3.7.1 64-bit) 71 | 72 | Windows 10 Pro 64-bit 73 | 74 | 命令行启动: 75 | 76 | ```shell 77 | python main.py 78 | ``` 79 | 80 | ### 涉及依赖 81 | 82 | pyqt5 83 | 84 | pandas 85 | 86 | numpy 87 | 88 | sklearn 89 | 90 | ### 文件说明 91 | 92 | 代码中包含了相当多的注释,方便读者理解 93 | 94 | #### main.py 95 | 96 | 核心功能,涉及: 97 | 98 | * GUI 初始化及主要流程 99 | * csv 文件读取及标签处理 100 | * 表格展示及处理 101 | * 参数配置、按钮操作 102 | 103 | #### utils.py 104 | 105 | 主要是涉及机器学习的内容: 106 | 107 | * 主成分分析 108 | * 支持向量机训练,包括交叉验证、模型保存 109 | * 模型读取 110 | 111 | ### 项目成果 112 | 113 | 实现了基于 pyqt5 的光谱分类处理的整个流程 GUI 包括: 114 | 115 | * 文件导入、文本展示、按钮事件、布局排版 116 | * QTableWidget 表格展示 117 | 118 | 实现了基于 sklearn 的支持向量机模型训练及预测 119 | 120 | 实现了文件存取,数据统计及分析处理 121 | 122 | ### 总结 123 | 124 | 是第一次使用 pyqt5 制作 GUI,之前用过 Electron 125 | 126 | python 有天然的功能优势,但在界面布局和美化上,难以与基于 H5 的 Electron 相比 127 | 128 | 踩了很多的坑,也对 python 和涉及到的各种库有了更多的了解 129 | 130 | 参考了大量的资料,无法再一一列举,在此对所有在网络上辛勤奉献的同行们表示由衷的感激 131 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyQt5-RamanSpectraClassification 2 | 3 | Raman spectral classification base on PyQt5 4 | 5 | 基于 PyQt5 的拉曼光谱分类 GUI [查看中文版README](https://github.com/Cheereus/PyQt5-RamanSpectraClassification/blob/master/README-zh.md) 6 | 7 | version: 2.0 8 | 9 | © This project belongs to my senior's Master thesis《拉曼光谱结合化学计量学在磷矿检测中的应用》(2020) 10 | 11 | If this project is useful to you, please give me a star ! 12 | 13 | ## Main Function 14 | 15 | ### Import Training Set 16 | 17 | Training set file format requirements: 18 | 19 | * It should be `.csv` file. 20 | * The first row is the concentration mark, and it will be divided into high and low categories according to the classification threshold in the binary classification. 21 | * The first column is the spectral wavelength, which has no specific meaning in our data analysis. 22 | 23 | ### Parameter Configuration 24 | 25 | Classification mode: 26 | 27 | * Binary classification. According to the classification threshold, the concentration markers are divided into high and low categories. 28 | * Multi-class classification. datasets will be directly divided into the original categories by concentration mark. 29 | * Display the number of categories. 30 | 31 | Number of principal components: 32 | 33 | * It can be customized. 34 | * 3 by default. 35 | 36 | Cross-validation folds: 37 | 38 | * It can be customized. 39 | * 10 by default. 40 | 41 | ### Principal component analysis (PCA) 42 | 43 | Perform dimensionality reduction according to the number of principal components configured, and output the variance contribution rate of each dimension and their sum. 44 | 45 | ### Support vector machine (SVM) training 46 | 47 | Use grid search to find the optimal parameters and display them. 48 | 49 | * The labels are 0 and 1 in binary classification. 50 | * In multi-class classification, the labels are concentration * 100 (so that the decimal point is removed), and convert them into int type. 51 | * Perform cross-validation according to the optimal parameters, and output the accuracy rates and average accuracy rate of the cross-validation. 52 | 53 | ### Save and import SVM model file 54 | 55 | The model of the above steps can be saved and imported as file, eliminating the need for repeated steps of retraining 56 | 57 | ### Prediction 58 | 59 | Import the data set that needs to be predicted then make prediction: 60 | 61 | * The reading method of the prediction set is the same as that of the training set, so it is necessary to ensure that there is no data in the first row and first column, because this part of the content will be discarded. 62 | * The results will alse be saved as `.csv` file. 63 | 64 | ### Data demonstration 65 | 66 | Each step will instantly display the output data of the current step in the table on the right. 67 | 68 | ## Instructions for use 69 | 70 | ### Running environment 71 | 72 | Anaconda (Python 3.7.1 64-bit) 73 | 74 | Windows 10 Pro 64-bit 75 | 76 | Command line start: 77 | 78 | ```shell 79 | python main.py 80 | ``` 81 | 82 | ### Dependencies 83 | 84 | pyqt5 85 | 86 | pandas 87 | 88 | numpy 89 | 90 | sklearn 91 | 92 | ### File instructions 93 | 94 | The code contains a lot of comments to facilitate readers’ understanding. 95 | 96 | #### main.py 97 | 98 | Core functions including: 99 | 100 | * GUI initialization and main process. 101 | * `.csv` file operations and label treatment. 102 | * Table display and processing. 103 | * Parameter configuration and button operations. 104 | 105 | #### utils.py 106 | 107 | Mainly includes machine learning fuctions: 108 | 109 | * PCA 110 | * SVM training including cross-validation. 111 | * Save and import model file. 112 | 113 | ### Project achievements 114 | 115 | The entire process GUI of spectral classification processing based on PyQt5 including: 116 | 117 | * File import, text display, button events, layout and typesetting. 118 | * QTableWidget table display. 119 | 120 | Support vector machine model training and prediction based on sklearn. 121 | 122 | File access, data statistics and analysis processing. 123 | 124 | ### Interface language 125 | 126 | The interface language is `Chinese`, if you need a English version (or you want to provide a English version) please send me a email (fanwei1995@hotmail.com) or create an issue. 127 | -------------------------------------------------------------------------------- /icon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheereus/PyQt5-RamanSpectraClassification/dc36282796268a739522354c8c22a02cd7888674/icon.jpg -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # 2.0版 旨在尽量减少数据和逻辑耦合 增强模块化和可维护性 2 | import sys 3 | import random 4 | import pandas as pd 5 | import numpy as np 6 | 7 | # 引入 PyQt5 相关 8 | from PyQt5 import QtCore,QtGui 9 | from PyQt5.QtGui import QIcon 10 | from PyQt5.QtWidgets import * 11 | 12 | # 引入自定义模块 13 | from utils import * 14 | 15 | # 程序窗口 16 | class ApplicationWindow(QMainWindow): 17 | def __init__(self): 18 | 19 | # 程序窗口初始化 20 | super().__init__() 21 | self.setAttribute(QtCore.Qt.WA_DeleteOnClose) 22 | self.center() 23 | self.setWindowTitle("拉曼光谱分析") 24 | self.setWindowIcon(QIcon('icon.jpg')) 25 | 26 | # 全局变量 27 | self.classType = 'binary' # 分类模式 默认二分类 28 | self.classNum = 2 # 分类数目 默认是二分类模式下的 2 29 | self.threshold = 20 # 分类阈值 默认 20 30 | self.components = 3 # 主成分数 默认 3 31 | self.scross = 10 # S折交叉验证 默认 10 32 | self.X = None # 原始数据 33 | self.headline = None # 原始表头 即浓度 34 | self.newX = None # 降维后的数据 35 | self.labels = None # 标注序列 36 | self.ratio = None # 方差贡献率 37 | self.OSVM = None # 最佳 SVM 模型 38 | self.pcaModel = None # pca 模型 39 | 40 | # 布局 41 | self.main_widget = QWidget(self) 42 | mlayout = QHBoxLayout(self.main_widget) 43 | llayout = QFormLayout() 44 | rlayout = QVBoxLayout() 45 | 46 | addForm(self, llayout) # 表单初始化 47 | addTable(self, rlayout) # 表格初始化 48 | self.addMenu() # 菜单栏初始化 49 | 50 | # 窗口初始化 51 | mlayout.addLayout(llayout) 52 | mlayout.addLayout(rlayout) 53 | self.main_widget.setFocus() 54 | self.setCentralWidget(self.main_widget) 55 | 56 | # 居中 57 | def center(self): 58 | qr = self.frameGeometry() 59 | cp = QDesktopWidget().availableGeometry().center() 60 | qr.moveCenter(cp) 61 | self.move(qr.topLeft()) 62 | # 菜单栏初始化 63 | def addMenu(self): 64 | self.file_menu = QMenu('文件', self) 65 | self.file_menu.addAction('导入训练集', lambda: openTrain(self)) 66 | self.file_menu.addAction('导入模型', lambda: openModel(self)) 67 | self.menuBar().addMenu(self.file_menu) 68 | self.help_menu = QMenu('帮助', self) 69 | self.menuBar().addSeparator() 70 | self.menuBar().addMenu(self.help_menu) 71 | self.help_menu.addAction('关于', self.about) 72 | def fileQuit(self): 73 | self.close() 74 | def closeEvent(self, ce): 75 | self.fileQuit() 76 | def about(self): 77 | QMessageBox.about(self, "关于", 78 | """ 79 | 对拉曼光谱数据进行分类的支持向量机模型 80 | 将拉曼光谱数据根据浓度分类并进行预测 81 | 82 | 作者:CheeReus_11 83 | """ 84 | ) 85 | 86 | # 文件操作 87 | 88 | # csv文件读取 文件中第一行为浓度 89 | # 输出 纯数据矩阵 X 、浓度向量 y 其中浓度向量去除了小数点 90 | def csvReader(filePath): 91 | 92 | # 由于路径中的中文会导致 read_csv 报错,因此先用 open 打开 93 | f = open(filePath,'r') 94 | data=pd.read_csv(f,header=None, low_memory=False) 95 | f.close() 96 | array_data = np.array(data) 97 | X=array_data.T[1:,1:] # 纯数据 98 | y=array_data.T[1:,0] # 第一行 99 | y = y * 100 # 支持向量机训练仅支持 int 类型的 label 因此将浓度值进行去除小数点的处理 100 | return X,y 101 | 102 | def openTrain(self): 103 | 104 | filePath,filetype = QFileDialog.getOpenFileName(self,"选取文件","./", "CSV Files (*.csv)") 105 | if filePath: 106 | showLabel(self.filepath, "已选择训练集:" + filePath) 107 | self.X, self.headline = csvReader(filePath) 108 | self.labels, self.classNum = getLabel(self) 109 | setTable(self, self.labels, self.X) 110 | self.btn0.setEnabled(True) 111 | self.btn1.setEnabled(True) 112 | self.btn2.setEnabled(True) 113 | self.btn3.setEnabled(True) 114 | self.btn4.setEnabled(True) 115 | 116 | else: 117 | QMessageBox.warning(self,"温馨提示","打开文件错误,请重新尝试!",QMessageBox.Cancel) 118 | 119 | # 导入训练集后才允许修改参数 120 | self.btn0.setEnabled(True) 121 | self.btn1.setEnabled(True) 122 | self.btn2.setEnabled(True) 123 | self.btn3.setEnabled(True) 124 | 125 | # 读取模型 126 | def openModel(self): 127 | filePath,filetype = QFileDialog.getOpenFileName(self,"选取文件","./", "model Files (*.pkl)") 128 | if filePath: 129 | self.OSVM, self.pcaModel, self.classNum, self.threshold, self.components, self.scross = modelReader(filePath) 130 | 131 | self.thresholdLabel.setText('二分类阈值:' + str(self.threshold)) 132 | self.componentsLabel.setText('主成分数目:' + str(self.components)) 133 | self.scrossLabel.setText('交叉验证数目:' + str(self.scross)) 134 | self.filepath.setText("已选择模型:" + filePath) 135 | if self.classNum > 2: 136 | self.classType = 'multi' 137 | showLabel(self.classTypeLabel,'当前模式:多分类(' + str(self.classNum) + '类)') 138 | else: 139 | self.classType = 'binary' 140 | showLabel(self.classTypeLabel,'当前模式:二分类') 141 | self.btn6.setVisible(True) 142 | # 导入模型后不允许修改参数及主成分分析、训练等操作 143 | self.btn0.setEnabled(False) 144 | self.btn1.setEnabled(False) 145 | self.btn2.setEnabled(False) 146 | self.btn3.setEnabled(False) 147 | self.btn4.setEnabled(False) 148 | self.btn5.setEnabled(False) 149 | self.ratioLabel.setVisible(False) 150 | self.btn5.setEnabled(False) 151 | self.svmLabel.setVisible(False) 152 | 153 | else: 154 | QMessageBox.warning(self,"温馨提示","打开文件错误,请重新尝试!",QMessageBox.Cancel) 155 | 156 | # 表单操作 157 | 158 | # 表单初始化 159 | def addForm(self,llayout): 160 | 161 | self.filepath = QLabel('请先导入数据或模型!') 162 | self.filepath.setFixedWidth(600) 163 | llayout.addRow(self.filepath) 164 | 165 | # 参数展示及修改 166 | self.classTypeLabel = QLabel('当前模式:二分类') 167 | self.thresholdLabel = QLabel('二分类阈值:' + str(self.threshold)) 168 | self.componentsLabel = QLabel('主成分数目:' + str(self.components)) 169 | self.scrossLabel = QLabel('交叉验证折数:' + str(self.scross)) 170 | self.btn0 = QPushButton('切换') 171 | self.btn1 = QPushButton('修改') 172 | self.btn2 = QPushButton('修改') 173 | self.btn3 = QPushButton('修改') 174 | self.btn0.clicked.connect(lambda: toggleClassType(self)) 175 | self.btn1.clicked.connect(lambda: changeParas(self,1)) 176 | self.btn2.clicked.connect(lambda: changeParas(self,2)) 177 | self.btn3.clicked.connect(lambda: changeParas(self,3)) 178 | llayout.setLabelAlignment(QtCore.Qt.AlignRight) # 标签右对齐 179 | llayout.addRow(self.classTypeLabel,self.btn0) 180 | llayout.addRow(self.thresholdLabel,self.btn1) 181 | llayout.addRow(self.componentsLabel,self.btn2) 182 | llayout.addRow(self.scrossLabel,self.btn3) 183 | 184 | # 导入训练集后才允许修改参数 185 | self.btn0.setEnabled(False) 186 | self.btn1.setEnabled(False) 187 | self.btn2.setEnabled(False) 188 | self.btn3.setEnabled(False) 189 | 190 | # 主成分分析 191 | self.btn4 = QPushButton('主成分分析') 192 | self.btn4.clicked.connect(lambda: pca(self, self.X, self.components, showTable=True)) 193 | llayout.addRow(self.btn4) 194 | self.ratioLabel = QLabel('方差贡献率:') 195 | llayout.addRow(self.ratioLabel) 196 | self.btn4.setEnabled(False) 197 | self.ratioLabel.setVisible(False) 198 | 199 | # 训练支持向量机 200 | self.btn5 = QPushButton('训练支持向量机') 201 | self.btn5.clicked.connect(lambda: getSVM(self, self.newX, self.labels, self.scross)) 202 | llayout.addRow(self.btn5) 203 | self.svmLabel = QLabel('交叉验证准确率:') 204 | llayout.addRow(self.svmLabel) 205 | self.btn5.setEnabled(False) 206 | self.svmLabel.setVisible(False) 207 | 208 | # 预测 209 | self.btn6 = QPushButton('开始预测') 210 | self.btn6.clicked.connect(lambda: getPredict(self)) 211 | self.predictLabel = QLabel('预测结果已保存至 output/predictResult.csv') 212 | llayout.addRow(self.btn6) 213 | llayout.addRow(self.predictLabel) 214 | self.btn6.setVisible(False) 215 | self.predictLabel.setVisible(False) 216 | 217 | # 清空表格 218 | self.btn7 = QPushButton('清空右侧表格') 219 | self.btn7.clicked.connect(lambda: TableOperator.clearTable(self)) 220 | llayout.addRow(self.btn7) 221 | self.btn7.setEnabled(False) 222 | 223 | # 设置并展示label文字 224 | def showLabel(labelObj, labelTxt): 225 | labelObj.setText(labelTxt) 226 | labelObj.setVisible(True) 227 | 228 | # 表格操作 229 | 230 | # 表格初始化 231 | def addTable(self,rlayout): 232 | self.tableWidget = QTableWidget() 233 | self.tableWidget.setRowCount(20) 234 | self.tableWidget.setColumnCount(20) 235 | self.tableWidget.setVerticalHeaderLabels([str(item) for item in range(20)]) 236 | rlayout.addWidget(self.tableWidget) 237 | 238 | # 清空表格 239 | def clearTable(self): 240 | self.tableWidget.setRowCount(20) 241 | self.tableWidget.setColumnCount(20) 242 | self.tableWidget.setVerticalHeaderLabels([str(item) for item in range(20)]) 243 | for i in range(20): 244 | for j in range(20): 245 | #为每个表格内添加空白数据 246 | self.tableWidget.setItem(i,j,QTableWidgetItem('')) 247 | # 关闭清空数据按钮 248 | self.btn7.setEnabled(False) 249 | 250 | # 表格数据展示 与第一版区别主要在 此处不再处理label 只关心数据填充 251 | def setTable(self,head,data): 252 | data = data.T 253 | rows, columns = data.shape 254 | if head is None: 255 | head = range(rows) 256 | self.tableWidget.setRowCount(rows+1) 257 | self.tableWidget.setColumnCount(columns) 258 | self.tableWidget.setVerticalHeaderLabels([str(item) for item in range(rows+1)]) 259 | for k in range(columns): 260 | #为表格添加内置头部 261 | self.tableWidget.setItem(0,k,QTableWidgetItem(str(head[k]))) 262 | for i in range(rows): 263 | for j in range(columns): 264 | #为每个表格内添加数据 265 | self.tableWidget.setItem(i+1,j,QTableWidgetItem(str(data[i,j]))) 266 | # 开启清空数据按钮 267 | self.btn7.setEnabled(True) 268 | 269 | # 数据预处理及参数配置 270 | 271 | # 获取分类标签及分类数目 二分类时按阈值返回0或1 多分类时原样返回 272 | # 同时返回分类数目 并更新界面文字 273 | def getLabel(self): 274 | 275 | if self.classType == 'multi': 276 | nums = len(pd.value_counts(self.headline)) 277 | showLabel(self.classTypeLabel,'当前模式:多分类(' + str(nums) + '类)') 278 | return self.headline, nums 279 | else: 280 | headSize = len(self.headline) 281 | labels = np.zeros(headSize) 282 | for i in range(headSize): 283 | if self.headline[i] > (self.threshold * 100): 284 | labels[i] = 1 285 | showLabel(self.classTypeLabel,'当前模式:二分类') 286 | return labels, 2 287 | 288 | # 切换分类模式 289 | def toggleClassType(self): 290 | if self.classType == 'binary': 291 | self.classType = 'multi' 292 | else: 293 | self.classType = 'binary' 294 | self.labels, self.classNum = getLabel(self) 295 | if self.classType == 'binary': 296 | setTable(self, self.labels, self.X) 297 | else: 298 | setTable(self, self.labels / 100, self.X) 299 | # 切换后需要重新训练 因此关闭部分按钮和文字 300 | self.ratioLabel.setVisible(False) 301 | self.btn5.setEnabled(False) 302 | self.svmLabel.setVisible(False) 303 | self.btn6.setVisible(False) 304 | self.predictLabel.setVisible(False) 305 | 306 | # 参数设定 307 | def changeParas(self,type): 308 | if type == 1: 309 | self.threshold, ok = QInputDialog.getDouble(self, "二分类阈值", "请输入阈值:", 25.00, 0, 100, 2) 310 | self.thresholdLabel.setText('二分类阈值:' + str(self.threshold)) 311 | if type == 2: 312 | self.components, ok = QInputDialog.getInt(self, "主成分数目", "请输入整数:", 3, 2, 1000, 0) 313 | self.componentsLabel.setText('主成分数目:' + str(self.components)) 314 | if type == 3: 315 | self.scross, ok = QInputDialog.getInt(self, "交叉验证数目", "请输入整数:", 10, 2, 20, 0) 316 | self.scrossLabel.setText('交叉验证数目:' + str(self.scross)) 317 | 318 | # 机器学习模块 319 | 320 | # 主成分分析 321 | def pca(self,X,n, showTable): 322 | self.newX, self.ratio, self.pcaModel = pca_op(X,n) 323 | ratioText = "方差贡献率:\n" 324 | sum = 0 325 | for i in range(len(self.ratio)): 326 | ratioText = ratioText + str(self.ratio[i]) + "\n" 327 | sum = sum + self.ratio[i] 328 | ratioText = ratioText + '\n总贡献率:' + str(sum) + "\n降维后数据及标注序列已保存至 output 目录下" 329 | showLabel(self.ratioLabel,ratioText) 330 | if showTable: 331 | setTable(self, self.labels, self.newX) 332 | self.btn5.setEnabled(True) 333 | 334 | # 支持向量机训练 335 | def getSVM(self,newX,labels,s): 336 | self.OSVM, scores, bestParas = cross_validation(newX,labels,s) 337 | modelSave(self.OSVM, self.pcaModel, self.classNum, self.threshold, self.components, self.scross) 338 | 339 | # 如果要针对不同的核函数加 if else 就改这个 bestText 340 | bestText = "最优参数:\n" + "C:" + str(bestParas['C']) + '\ngamma:' + str(bestParas['gamma']) + '\ndegree:' + str(bestParas['degree']) + '\nkernel:' + str(bestParas['kernel']) + '\ndecision_function_shape:' + str(bestParas['decision_function_shape']) 341 | 342 | svmText = "\n交叉验证准确率:\n" 343 | avg = 0 344 | for i in range(len(scores)): 345 | avg = avg + scores[i] 346 | svmText = svmText + str(scores[i]) + "\n" 347 | avg = avg / len(scores) 348 | svmText = bestText + svmText + '平均准确率:' + str(avg) + "\n模型已保存至 output/svm_model_with_pca.pkl" 349 | showLabel(self.svmLabel,svmText) 350 | self.btn6.setVisible(True) 351 | 352 | # 预测 353 | def getPredict(self): 354 | # 开始预测后不允许修改参数及主成分分析、训练等操作,除非重新导入训练集 355 | self.btn0.setEnabled(False) 356 | self.btn1.setEnabled(False) 357 | self.btn2.setEnabled(False) 358 | self.btn3.setEnabled(False) 359 | self.btn4.setEnabled(False) 360 | self.btn5.setEnabled(False) 361 | filePath,filetype = QFileDialog.getOpenFileName(self,"选取文件","./", "CSV Files (*.csv)") 362 | if filePath: 363 | self.filepath.setText("已选测试集:" + filePath) 364 | self.X, self.headline = csvReader(filePath) 365 | self.newX = re_pca(self.X, self.pcaModel) 366 | self.labels = self.OSVM.predict(self.newX) 367 | 368 | correct = 0 369 | if self.classType == 'binary': 370 | labelY, nums = getLabel(self) 371 | for i in range(len(self.labels)): 372 | if self.labels[i] == labelY[i]: 373 | correct = correct + 1 374 | else: 375 | for i in range(len(self.labels)): 376 | if self.labels[i] == self.headline[i]: 377 | correct = correct + 1 378 | correctRate = correct / len(self.labels) 379 | 380 | if self.classType == 'multi': 381 | self.labels = self.labels / 100 382 | setTable(self, self.labels, self.newX) 383 | data = np.hstack((self.labels[:,None], self.X)) 384 | data = np.hstack(((self.headline / 100)[:,None], data)) 385 | pd.DataFrame.to_csv(pd.DataFrame(data.T),'output/predictResult.csv', mode='w',header=None,index=None) 386 | showLabel(self.predictLabel,'准确率:' + str(correctRate) + '\n预测结果已保存至 output/predictResult.csv') 387 | 388 | if __name__ == '__main__': 389 | app = QApplication(sys.argv) 390 | aw = ApplicationWindow() 391 | aw.showMaximized() 392 | #sys.exit(qApp.exec_()) 393 | app.exec_() 394 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from sklearn.decomposition import PCA # 导入PCA模块 4 | from sklearn.preprocessing import StandardScaler, MinMaxScaler, Normalizer # 导入数据预处理归一化类 5 | from sklearn import svm 6 | from sklearn.model_selection import cross_val_score 7 | from sklearn.externals import joblib 8 | from sklearn.model_selection import GridSearchCV 9 | 10 | # 输入数据矩阵 X 及主成分数目 c 11 | # 主成分数目 c 要小于数据矩阵 X 的长和宽 即 c <= min(x.shape[0],x.shape[1]) 12 | # 输出降维后的数据及贡献率 13 | def pca_op(X,c=3): 14 | 15 | prepress = Normalizer() # 16 | x = prepress.fit_transform(X) # 拟合转换数据一统一量纲标准化 17 | pca_result = PCA(n_components=c) # 降维后有c个主成分 18 | pca_result.fit(x) # 训练 19 | newX = pca_result.fit_transform(x) # 降维后的数据 20 | 21 | # 保存为csv文件 22 | np.savetxt('output/pca_x.csv', newX, delimiter = ',') 23 | 24 | return newX, pca_result.explained_variance_ratio_, pca_result 25 | 26 | # 使用先有的 pca 模型进行直接降维 输出降维后的矩阵 27 | def re_pca(X, pcaModel): 28 | prepress = Normalizer() # 29 | x = prepress.fit_transform(X) # 拟合转换数据一统一量纲标准化 30 | newX = pcaModel.transform(x) 31 | return newX 32 | 33 | # 输入降维后的数据 x 标注 y 交叉验证折数 s 34 | # 交叉验证折数是有限制的 必须保证训练集每个类都能至少分为 s 份 即单个类的数目 sn 应满足 sn / s >= 1 35 | # 输出最佳的SVM模型 36 | def cross_validation(x,y,s=10): 37 | 38 | svc = svm.SVC() 39 | parameters = [ 40 | { 41 | 'C': [1, 3, 5], 42 | 'gamma': [0.001, 0.1, 1, 10], 43 | 'degree': [3,5,7,9], 44 | 'kernel': ['linear','poly', 'rbf', 'sigmoid'], 45 | 'decision_function_shape': ['ovo', 'ovr' ,None] 46 | } 47 | ] 48 | clf=GridSearchCV(svc,parameters,cv=s,refit=True) 49 | y = y.astype('int') 50 | clf.fit(x, y) 51 | print(clf.best_params_) 52 | print(clf.best_score_) 53 | 54 | cross_model = svm.SVC(C=clf.best_params_['C'],degree=clf.best_params_['degree'],kernel=clf.best_params_['kernel'],gamma=clf.best_params_['gamma'], decision_function_shape=clf.best_params_['decision_function_shape'], verbose=0) 55 | scores = cross_val_score(cross_model, x, y.ravel(), cv=s) 56 | 57 | return clf.best_estimator_, scores, clf.best_params_ 58 | 59 | class model: 60 | def __init__(self, OSVM, pcaModel, classNum, threshold, n_components, scross): 61 | self.svmModel = OSVM 62 | self.pcaModel = pcaModel 63 | self.classNum = classNum 64 | self.threshold = threshold 65 | self.n_components = n_components 66 | self.scross = scross 67 | 68 | # 保存模型 69 | def modelSave(OSVM, pcaModel, classNum, threshold, n_components, scross): 70 | modelObj = model(OSVM, pcaModel, classNum, threshold, n_components, scross) 71 | joblib.dump(modelObj, 'output/svm_model_with_pca.pkl') 72 | 73 | # 读取模型 74 | def modelReader(filePath): 75 | 76 | f = open(filePath,'rb') 77 | model = joblib.load(f) 78 | f.close() 79 | return model.svmModel, model.pcaModel, model.classNum, model.threshold, model.n_components, model.scross --------------------------------------------------------------------------------