├── README.md ├── SVM&ANN.zip └── SVM&ANN └── Code ├── ANN.py └── Dataload.py /README.md: -------------------------------------------------------------------------------- 1 | # SVM-ANN 2 | The code for NIRS analysis based SVM&ANN 3 | # 一、数据来源 4 | 使用药品数据,共310个样本,每条样本404个变量,根据活性成分,分成4类 5 | 图片如下: 6 | ![](https://img-blog.csdnimg.cn/2245e0129ed349eaae824a0963aae5f2.png?x-oss-process=image/watermark,type_ZHJvaWRzYW5zZmFsbGJhY2s,shadow_50,text_Q1NETiBARWNob19Db2Rl,size_20,color_FFFFFF,t_70,g_se,x_16) 7 | 8 | ![药品数据光谱](https://img-blog.csdnimg.cn/7d76d4b1ed4d4f71a150298bdae3da74.png?x-oss-process=image/watermark,type_ZHJvaWRzYW5zZmFsbGJhY2s,shadow_50,text_Q1NETiBARWNob19Db2Rl,size_20,color_FFFFFF,t_70,g_se,x_16 9 | # 二、代码解读 10 | ## 2.1 载入数据 11 | 12 | ```python 13 | def plotspc(x_col, data_x, tp): 14 | # figsize = 5, 3 15 | figsize = 8, 5.5 16 | figure, ax = plt.subplots(figsize=figsize) 17 | # ax = plt.figure(figsize=(5,3)) 18 | x_col = x_col[::-1] # 数组逆序 19 | y_col = np.transpose(data_x) 20 | plt.plot(x_col, y_col ) 21 | plt.tick_params(labelsize=12) 22 | labels = ax.get_xticklabels() + ax.get_yticklabels() 23 | [label.set_fontname('Times New Roman') for label in labels] 24 | font = {'weight': 'normal', 25 | 'size': 16, 26 | } 27 | plt.xlabel("Wavenumber/$\mathregular{cm^{-1}}$", font) 28 | plt.ylabel("Absorbance", font) 29 | # plt.title("The spectrum of the {} dataset".format(tp), fontweight="semibold", fontsize='x-large') 30 | plt.show() 31 | plt.tick_params(labelsize=23) 32 | 33 | def TableDataLoad(tp, test_ratio, start, end, seed): 34 | 35 | # global data_x 36 | data_path = '..//Data//table.csv' 37 | Rawdata = np.loadtxt(open(data_path, 'rb'), dtype=np.float64, delimiter=',', skiprows=0) 38 | table_random_state = seed 39 | 40 | if tp =='raw': 41 | data_x = Rawdata[0:, start:end] 42 | 43 | # x_col = np.linspace(0, 400, 400) 44 | if tp =='SG': 45 | SGdata_path = './/Code//TableSG.csv' 46 | data = np.loadtxt(open(SGdata_path, 'rb'), dtype=np.float64, delimiter=',', skiprows=0) 47 | data_x = data[0:, start:end] 48 | if tp =='SNV': 49 | SNVata_path = './/Code//TableSNV.csv' 50 | data = np.loadtxt(open(SNVata_path, 'rb'), dtype=np.float64, delimiter=',', skiprows=0) 51 | data_x = data[0:, start:end] 52 | if tp == 'MSC': 53 | MSCdata_path = './/Code//TableMSC.csv' 54 | data = np.loadtxt(open(MSCdata_path, 'rb'), dtype=np.float64, delimiter=',', skiprows=0) 55 | data_x = data[0:, start:end] 56 | data_y = Rawdata[0:, -1] 57 | x_col = np.linspace(7400, 10507, 404) 58 | plotspc(x_col, data_x[:, :], tp=0) 59 | 60 | x_data = np.array(data_x) 61 | y_data = np.array(data_y) 62 | X_train, X_test, y_train, y_test = train_test_split(x_data, y_data, test_size=test_ratio,random_state=table_random_state) 63 | return X_train, X_test, y_train, y_test 64 | ``` 65 | ![结果](https://img-blog.csdnimg.cn/b5603b89569642d88fc8366a46311d28.png?x-oss-process=image/watermark,type_ZHJvaWRzYW5zZmFsbGJhY2s,shadow_50,text_Q1NETiBARWNob19Db2Rl,size_20,color_FFFFFF,t_70,g_se,x_16) 66 | 67 | 68 | ## 2.1 基于SVM的药品光谱分类 69 | ### 2.1.1 建立SVM参数寻找,找到最佳的SVM参数 70 | 71 | ```python 72 | def SVM_Classs_test(train_x ,train_y,test_x,test_y): 73 | params = [ 74 | {'kernel': ['linear'], 'C': [ 0.1,0.5, 1, 1.5,2,3,5, 10, 15,50,100],'gamma': [1e-7,1e-6,1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1]}, 75 | {'kernel': ['poly'], 'C': [0.1, 1, 2, 10, 15], 'degree': [2, 3]} 76 | # {'kernel': ['rbf'], 'C': [0.1, 1, 2, 10], 'gamma':[1e-3, 1e-2, 1e-1, 1, 2]} 77 | ] 78 | 79 | model = ms.GridSearchCV(svm.SVC(probability=True), 80 | params, 81 | refit=True, 82 | return_train_score=True, # 后续版本需要指定True才有score方法 83 | cv=10) 84 | model.fit(train_x, train_y) 85 | model_best = model.best_estimator_ 86 | pred_test_y = model_best.predict(test_x) 87 | acc_test = accuracy_score(test_y, pred_test_y) 88 | print("best pamer is {}".format(model.best_estimator_)) 89 | print("acc is {}".format(acc_test)) 90 | print(sm.classification_report(test_y, pred_test_y)) 91 | 92 | ``` 93 | 结果如下: 94 | ![在这里插入图片描述](https://img-blog.csdnimg.cn/53bea34399e74c06ae5bf9d29057e024.png?x-oss-process=image/watermark,type_ZHJvaWRzYW5zZmFsbGJhY2s,shadow_50,text_Q1NETiBARWNob19Db2Rl,size_20,color_FFFFFF,t_70,g_se,x_16) 95 | 96 | ### 2 .1.2 进行svm的测试 97 | 98 | ```python 99 | def SVM(train_x ,train_y,test_x,test_y): 100 | clf = svm.SVC(probability=True, C=10, gamma=1e-7,kernel='linear') 101 | clf.fit(train_x, train_y) 102 | pred = clf.predict(test_x) 103 | acc_test = accuracy_score(test_y, pred) 104 | return acc_test 105 | ``` 106 | 107 | ```python 108 | if __name__ == '__main__': 109 | test_ratio = 0.3 110 | tp = 'raw' 111 | 112 | X_train, X_test, y_train, y_test = TableDataLoad(tp=tp, test_ratio=test_ratio, start=0, end=404, seed=80) 113 | 114 | SVM_Classs_test(train_x=X_train, train_y=y_train, test_x=X_test, test_y=y_test) 115 | 116 | acc = SVM(train_x=X_train, train_y=y_train, test_x=X_test, test_y=y_test) 117 | print(acc) 118 | ``` 119 | 测试结果 120 | ![在这里插入图片描述](https://img-blog.csdnimg.cn/f6e66bcfa7c44ca3810602e7cc128aef.png) 121 | -------------------------------------------------------------------------------- /SVM&ANN.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuSiry/SVM-ANN/39697be2a83e90d453e0647adf13a7fff16e2d2f/SVM&ANN.zip -------------------------------------------------------------------------------- /SVM&ANN/Code/ANN.py: -------------------------------------------------------------------------------- 1 | import sklearn.svm as svm 2 | import sklearn.metrics as sm 3 | import sklearn.model_selection as ms 4 | from sklearn.metrics import accuracy_score 5 | from Code.Dataload import TableDataLoad 6 | 7 | 8 | 9 | 10 | def SVM_Classs_test(train_x ,train_y,test_x,test_y): 11 | params = [ 12 | {'kernel': ['linear'], 'C': [ 0.1,0.5, 1, 1.5,2,3,5, 10, 15,50,100],'gamma': [1e-7,1e-6,1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1]}, 13 | {'kernel': ['poly'], 'C': [0.1, 1, 2, 10, 15], 'degree': [2, 3]} 14 | # {'kernel': ['rbf'], 'C': [0.1, 1, 2, 10], 'gamma':[1e-3, 1e-2, 1e-1, 1, 2]} 15 | ] 16 | 17 | model = ms.GridSearchCV(svm.SVC(probability=True), 18 | params, 19 | refit=True, 20 | return_train_score=True, # 后续版本需要指定True才有score方法 21 | cv=10) 22 | model.fit(train_x, train_y) 23 | model_best = model.best_estimator_ 24 | pred_test_y = model_best.predict(test_x) 25 | acc_test = accuracy_score(test_y, pred_test_y) 26 | print("best pamer is {}".format(model.best_estimator_)) 27 | print("acc is {}".format(acc_test)) 28 | print(sm.classification_report(test_y, pred_test_y)) 29 | 30 | def SVM(train_x ,train_y,test_x,test_y): 31 | clf = svm.SVC(probability=True, C=10, gamma=1e-7,kernel='linear') 32 | clf.fit(train_x, train_y) 33 | pred = clf.predict(test_x) 34 | acc_test = accuracy_score(test_y, pred) 35 | return acc_test 36 | 37 | if __name__ == '__main__': 38 | test_ratio = 0.3 39 | tp = 'raw' 40 | 41 | X_train, X_test, y_train, y_test = TableDataLoad(tp=tp, test_ratio=test_ratio, start=0, end=404, seed=80) 42 | 43 | SVM_Classs_test(train_x=X_train, train_y=y_train, test_x=X_test, test_y=y_test) 44 | 45 | acc = SVM(train_x=X_train, train_y=y_train, test_x=X_test, test_y=y_test) 46 | print(acc) -------------------------------------------------------------------------------- /SVM&ANN/Code/Dataload.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from sklearn.model_selection import train_test_split 3 | from sklearn.preprocessing import scale,MinMaxScaler,Normalizer,StandardScaler 4 | import numpy as np 5 | 6 | def plotspc(x_col, data_x, tp): 7 | # figsize = 5, 3 8 | figsize = 8, 5.5 9 | figure, ax = plt.subplots(figsize=figsize) 10 | # ax = plt.figure(figsize=(5,3)) 11 | x_col = x_col[::-1] # 数组逆序 12 | y_col = np.transpose(data_x) 13 | plt.plot(x_col, y_col ) 14 | plt.tick_params(labelsize=12) 15 | labels = ax.get_xticklabels() + ax.get_yticklabels() 16 | [label.set_fontname('Times New Roman') for label in labels] 17 | font = {'weight': 'normal', 18 | 'size': 16, 19 | } 20 | plt.xlabel("Wavenumber/$\mathregular{cm^{-1}}$", font) 21 | plt.ylabel("Absorbance", font) 22 | # plt.title("The spectrum of the {} dataset".format(tp), fontweight="semibold", fontsize='x-large') 23 | plt.show() 24 | plt.tick_params(labelsize=23) 25 | 26 | def TableDataLoad(tp, test_ratio, start, end, seed): 27 | 28 | # global data_x 29 | data_path = '..//Data//table.csv' 30 | Rawdata = np.loadtxt(open(data_path, 'rb'), dtype=np.float64, delimiter=',', skiprows=0) 31 | table_random_state = seed 32 | 33 | if tp =='raw': 34 | data_x = Rawdata[0:, start:end] 35 | 36 | # x_col = np.linspace(0, 400, 400) 37 | if tp =='SG': 38 | SGdata_path = './/Code//TableSG.csv' 39 | data = np.loadtxt(open(SGdata_path, 'rb'), dtype=np.float64, delimiter=',', skiprows=0) 40 | data_x = data[0:, start:end] 41 | if tp =='SNV': 42 | SNVata_path = './/Code//TableSNV.csv' 43 | data = np.loadtxt(open(SNVata_path, 'rb'), dtype=np.float64, delimiter=',', skiprows=0) 44 | data_x = data[0:, start:end] 45 | if tp == 'MSC': 46 | MSCdata_path = './/Code//TableMSC.csv' 47 | data = np.loadtxt(open(MSCdata_path, 'rb'), dtype=np.float64, delimiter=',', skiprows=0) 48 | data_x = data[0:, start:end] 49 | data_y = Rawdata[0:, -1] 50 | x_col = np.linspace(7400, 10507, 404) 51 | plotspc(x_col, data_x[:, :], tp=0) 52 | 53 | x_data = np.array(data_x) 54 | y_data = np.array(data_y) 55 | X_train, X_test, y_train, y_test = train_test_split(x_data, y_data, test_size=test_ratio,random_state=table_random_state) 56 | return X_train, X_test, y_train, y_test --------------------------------------------------------------------------------