├── .idea ├── .gitignore ├── Adaboost.iml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml └── modules.xml ├── LICENSE ├── README.md ├── main.py └── 训练数据.png /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # 默认忽略的文件 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.idea/Adaboost.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Jack Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AdaBoost(自适应增强算法) 2 | 以弱分类器为决策树桩,用 AdaBoost 算法学习了一个强分类器 3 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8 -*- 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | def loadSimpData(): 5 | """ 6 | 创建单层决策树的数据集 7 | Parameters: 8 | 无 9 | Returns: 10 | dataMat - 数据矩阵 11 | classLabels - 数据标签 12 | """ 13 | dataMat = np.matrix([[0., 1., 3.], 14 | [0., 3., 1.], 15 | [1., 2., 2.], 16 | [1., 1., 3.], 17 | [1., 2., 3.], 18 | [0., 1., 2.], 19 | [1., 1., 2.], 20 | [1., 1., 1.], 21 | [1., 3., 1.], 22 | [0., 2., 1.]]) 23 | classLabels = np.matrix([-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0]) 24 | return dataMat, classLabels 25 | 26 | def showDataSet(dataMat, labelMat): 27 | """ 28 | 数据可视化 29 | Parameters: 30 | dataMat - 数据矩阵 31 | labelMat - 数据标签 32 | Returns: 33 | 无 34 | """ 35 | ax = plt.axes(projection='3d') 36 | data_plus = [] #正样本 37 | data_minus = [] #负样本 38 | labelMat = labelMat.T #label矩阵转置 39 | #将数据集分别存放到正负样本的矩阵 40 | for i in range(len(dataMat)): 41 | if labelMat[i] > 0: 42 | data_plus.append(dataMat[i]) 43 | else: 44 | data_minus.append(dataMat[i]) 45 | data_plus_np = np.array(data_plus) #转换为numpy矩阵 46 | data_minus_np = np.array(data_minus) #转换为numpy矩阵 47 | ax.scatter(np.transpose(data_plus_np)[0], np.transpose(data_plus_np)[1], np.transpose(data_plus_np)[2], c='r') #正样本散点图 48 | ax.scatter(np.transpose(data_minus_np)[0], np.transpose(data_minus_np)[1], np.transpose(data_minus_np)[2], c='b') #负样本散点图 49 | plt.show() 50 | 51 | def stumpClassify(dataMatrix, dimen, threshVal, threshIneq): 52 | """ 53 | 单层决策树分类函数 54 | Parameters: 55 | dataMatrix - 数据矩阵 56 | dimen - 第dimen列,也就是第几个特征 57 | threshVal - 阈值 58 | threshIneq - 划分的符号 - lt:less than,gt:greater than 59 | Returns: 60 | retArray - 分类结果 61 | """ 62 | retArray = np.ones((np.shape(dataMatrix)[0], 1)) # 初始化retArray为1 63 | if threshIneq == 'lt': 64 | retArray[dataMatrix[:, dimen] <= threshVal] = -1.0 # 如果小于阈值,则赋值为-1 65 | else: 66 | retArray[dataMatrix[:, dimen] > threshVal] = -1.0 # 如果大于阈值,则赋值为-1 67 | return retArray 68 | 69 | def buildStump(dataArr, classLabels, D): 70 | """ 71 | 找到数据集上最佳的单层决策树 72 | Parameters: 73 | dataArr - 数据矩阵 74 | classLabels - 数据标签 75 | D - 样本权重 76 | Returns: 77 | bestStump - 最佳单层决策树信息 78 | minError - 最小误差 79 | bestClasEst - 最佳的分类结果 80 | Tips: 81 | 在已经写好单层决策树桩的前提下,这个函数需要用来确定哪个特征作为划分维度、划分的阈值以及划分的符号,从而输出“最佳的单层决策树”。 82 | 具体来说,我们需要做一个嵌套三层的遍历:第一层遍历所有特征,第二层遍历这一维度特征所有可能的阈值,第三层遍历划分的符号; 83 | 在确定以上三个关键信息之后,我们只需要调用决策树桩函数并获得其预测结果,结合真值计算误差; 84 | 将误差最小的决策树桩的信息用一个字典储存下来,作为最终的输出结果; 85 | """ 86 | dataMatrix = np.mat(dataArr) 87 | labelMat = np.mat(classLabels).T 88 | m, n = np.shape(dataMatrix) 89 | numSteps = 10.0 90 | bestStump = {} 91 | bestClassEst = np.mat(np.zeros((m, 1))) 92 | minError = np.inf 93 | for i in range(n): 94 | rangeMin = dataMatrix[:, i].min() 95 | rangeMax = dataMatrix[:, i].max() 96 | stepSize = (rangeMax - rangeMin) / numSteps 97 | for j in range(-1, int(numSteps) + 1): 98 | for inequal in ['lt', 'gt']: 99 | threshVal = (rangeMin + float(j) * stepSize) 100 | predictedVals = stumpClassify(dataMatrix, i, threshVal, inequal) 101 | errArr = np.mat(np.ones((m, 1))) 102 | errArr[predictedVals == labelMat] = 0 103 | weightedError = D.T * errArr 104 | # 输出决策树信息,最小误差,估计的类别向量 105 | # print("split: dim %d, thresh %.2f, thresh ineqal: %s, the weighted error is %.3f" % \ 106 | # (i, threshVal, inequal, weightedError)) 107 | if weightedError < minError: 108 | minError = weightedError 109 | bestClassEst = predictedVals.copy() 110 | bestStump['dim'] = i 111 | bestStump['thresh'] = threshVal 112 | bestStump['ineq'] = inequal 113 | return bestStump, minError, bestClassEst 114 | 115 | 116 | def adaBoostTrainDS(dataArr, classLabels, numIt=40): 117 | """ 118 | 完整决策树训练 119 | Parameters: 120 | dataArr - 数据矩阵 121 | classLabels - 数据标签 122 | numIt - 默认迭代次数 123 | Returns: 124 | weakClassArr- 完整决策树信息 125 | aggClassEst- 最终训练数据权值分布 126 | Tips: 127 | 基于我们已经写好的最优决策树函数,我们可以在现有数据上进行迭代,不断更新数据权重、算法权重与其对应的决策树桩, 128 | 直到误差为零,退出循环 129 | """ 130 | weakClassArr = [] # 用于存储弱分类器 131 | m = np.shape(dataArr)[0] # 数据点数量 132 | D = np.mat(np.ones((m, 1)) / m) # 初始化权值矩阵 133 | aggClassEst = np.mat(np.zeros((m, 1))) # 记录每个数据点的类别估计累计值 134 | 135 | for i in range(numIt): 136 | # 1. 使用buildStump()函数得到当前迭代中最佳的弱分类器 137 | bestStump, error, classEst = buildStump(dataArr, classLabels, D) 138 | 139 | # 2. 计算当前弱分类器的权重alpha,防止过拟合,需要控制alpha小于1.0 140 | alpha = float(0.5 * np.log((1.0 - error) / max(error, 1e-16))) 141 | bestStump['alpha'] = alpha 142 | weakClassArr.append(bestStump) 143 | 144 | # 3. 更新权值向量D 145 | expon = np.multiply(-1 * alpha * np.mat(classLabels).T, classEst) 146 | D = np.multiply(D, np.exp(expon)) 147 | D = D / D.sum() 148 | 149 | # 4. 计算每个数据点的类别估计累计值 150 | aggClassEst += alpha * classEst 151 | 152 | # 5. 计算错误率,如果错误率为0,则直接跳出循环 153 | errorRate = np.sum(np.sign(aggClassEst) != np.mat(classLabels).T) / m 154 | if errorRate == 0.0: 155 | break 156 | 157 | # 返回弱分类器集合和每个数据点的类别估计累计值 158 | return weakClassArr, aggClassEst 159 | if __name__ == '__main__': 160 | dataArr, classLabels = loadSimpData() 161 | showDataSet(dataArr, classLabels) 162 | weakClassArr,aggClassEst = adaBoostTrainDS(dataArr, classLabels) 163 | print(weakClassArr) 164 | print(aggClassEst) 165 | -------------------------------------------------------------------------------- /训练数据.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Patrick9313/AdaBoost/dc8dbad03f04744384c79f8243c4a1afc010f937/训练数据.png --------------------------------------------------------------------------------