├── .idea
├── .gitignore
├── Adaboost.iml
├── inspectionProfiles
│ └── profiles_settings.xml
├── misc.xml
└── modules.xml
├── LICENSE
├── README.md
├── main.py
└── 训练数据.png
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # 默认忽略的文件
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.idea/Adaboost.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Jack Lee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AdaBoost(自适应增强算法)
2 | 以弱分类器为决策树桩,用 AdaBoost 算法学习了一个强分类器
3 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # -*-coding:utf-8 -*-
2 | import numpy as np
3 | import matplotlib.pyplot as plt
def loadSimpData():
    """
    Build the small hand-written data set used to train the decision stumps.

    Parameters:
        none
    Returns:
        dataMat - 10x3 np.matrix, one sample per row and one feature per column
        classLabels - 1x10 np.matrix holding the +1.0 / -1.0 label of each sample
    """
    # Each entry pairs a three-feature sample with its class label.
    samples = (
        ((0., 1., 3.), -1.0),
        ((0., 3., 1.), -1.0),
        ((1., 2., 2.), -1.0),
        ((1., 1., 3.), -1.0),
        ((1., 2., 3.), -1.0),
        ((0., 1., 2.), -1.0),
        ((1., 1., 2.), 1.0),
        ((1., 1., 1.), 1.0),
        ((1., 3., 1.), -1.0),
        ((0., 2., 1.), -1.0),
    )
    feature_rows = [list(features) for features, _ in samples]
    label_row = [label for _, label in samples]
    return np.matrix(feature_rows), np.matrix(label_row)
25 |
def showDataSet(dataMat, labelMat):
    """
    Show the samples as a 3-D scatter plot, coloured by class.

    Parameters:
        dataMat - data matrix, one sample per row; the first three columns
                  are used as the x, y and z coordinates
        labelMat - class labels, one per sample; samples whose label is > 0
                   are drawn in red, all other samples in blue
    Returns:
        none
    """
    ax = plt.axes(projection='3d')
    points = np.asarray(dataMat)
    # Flatten so a 1xN matrix, an Nx1 matrix and a plain sequence all work.
    labels = np.asarray(labelMat).ravel()
    is_positive = labels > 0
    # Boolean-mask selection yields a (0, 3) array for a class without
    # samples, so an empty class plots nothing instead of raising IndexError.
    for mask, colour in ((is_positive, 'r'), (~is_positive, 'b')):
        group = points[mask]
        ax.scatter(group[:, 0], group[:, 1], group[:, 2], c=colour)
    plt.show()
50 |
def stumpClassify(dataMatrix, dimen, threshVal, threshIneq):
    """
    Classify every sample with a single-feature threshold test (decision stump).

    Parameters:
        dataMatrix - data matrix, one sample per row
        dimen - index of the feature column to test
        threshVal - threshold the feature is compared against
        threshIneq - 'lt' labels samples with feature <= threshVal as -1.0;
                     any other value labels samples with feature > threshVal
                     as -1.0
    Returns:
        retArray - (number of samples, 1) array of +1.0 / -1.0 predictions
    """
    feature = dataMatrix[:, dimen]
    # Pick the side of the threshold that is assigned to the negative class.
    negative = feature <= threshVal if threshIneq == 'lt' else feature > threshVal
    sample_count = np.shape(dataMatrix)[0]
    retArray = np.ones((sample_count, 1))  # everything starts as +1.0
    retArray[negative] = -1.0
    return retArray
68 |
def buildStump(dataArr, classLabels, D):
    """
    Find the decision stump with the lowest weighted error on the data set.

    Parameters:
        dataArr - data matrix, one sample per row
        classLabels - class labels, one per sample (transposed internally
                      into a column)
        D - column of sample weights, one per sample
    Returns:
        bestStump - dict with keys 'dim' (feature index), 'thresh'
                    (threshold) and 'ineq' ('lt' or 'gt') of the best stump
        minError - weighted error of the best stump, as produced by
                   D.T * errArr
        bestClassEst - predictions of the best stump for every sample
    Tips:
        Three nested loops are searched: every feature, every candidate
        threshold along that feature, and both inequality directions. Each
        candidate is scored with stumpClassify() against the true labels,
        and the candidate with the smallest weighted error is kept.
    """
    # np.asmatrix is used because the np.mat alias was removed in NumPy 2.0.
    dataMatrix = np.asmatrix(dataArr)
    labelMat = np.asmatrix(classLabels).T
    m, n = np.shape(dataMatrix)
    numSteps = 10
    bestStump = {}
    bestClassEst = np.asmatrix(np.zeros((m, 1)))
    minError = np.inf
    for i in range(n):
        rangeMin = dataMatrix[:, i].min()
        rangeMax = dataMatrix[:, i].max()
        stepSize = (rangeMax - rangeMin) / numSteps
        # j starts at -1 so one threshold below the feature minimum is tried.
        for j in range(-1, numSteps + 1):
            threshVal = rangeMin + float(j) * stepSize
            for inequal in ('lt', 'gt'):
                predictedVals = stumpClassify(dataMatrix, i, threshVal, inequal)
                # 1 marks a misclassified sample, 0 a correct one.
                errArr = np.asmatrix(np.ones((m, 1)))
                errArr[predictedVals == labelMat] = 0
                weightedError = D.T * errArr
                # Strict '<' keeps the first candidate found on ties.
                if weightedError < minError:
                    minError = weightedError
                    bestClassEst = predictedVals.copy()
                    bestStump['dim'] = i
                    bestStump['thresh'] = threshVal
                    bestStump['ineq'] = inequal
    return bestStump, minError, bestClassEst
114 |
115 |
def adaBoostTrainDS(dataArr, classLabels, numIt=40):
    """
    Train an AdaBoost ensemble of decision stumps.

    Parameters:
        dataArr - data matrix, one sample per row
        classLabels - class labels, one per sample
        numIt - maximum number of boosting rounds (default 40)
    Returns:
        weakClassArr - list of stump dicts as returned by buildStump(), each
                       extended with its weight under the key 'alpha'
        aggClassEst - column matrix with the accumulated alpha-weighted class
                      estimate of every training sample
    Tips:
        Each round fits the best stump for the current sample weights,
        computes its weight alpha, re-weights the samples and adds the
        stump's vote to the running estimate. Training stops early once the
        sign of the running estimate matches every training label.
    """
    weakClassArr = []  # trained weak classifiers
    m = np.shape(dataArr)[0]  # number of samples
    # np.asmatrix is used because the np.mat alias was removed in NumPy 2.0.
    labelMat = np.asmatrix(classLabels).T  # labels as a column, built once
    D = np.asmatrix(np.ones((m, 1)) / m)  # uniform initial sample weights
    aggClassEst = np.asmatrix(np.zeros((m, 1)))  # running weighted vote

    for _ in range(numIt):
        # 1. Best weak classifier for the current sample weights.
        bestStump, error, classEst = buildStump(dataArr, classLabels, D)

        # 2. Weight of this classifier. The error is pulled out as a plain
        #    scalar with .item() because float() on a 1x1 matrix is
        #    deprecated in recent NumPy. max() guards against division by
        #    zero when the stump makes no mistakes.
        err = np.asarray(error).item()
        alpha = float(0.5 * np.log((1.0 - err) / max(err, 1e-16)))
        bestStump['alpha'] = alpha
        weakClassArr.append(bestStump)

        # 3. Re-weight the samples and renormalise so the weights sum to 1.
        expon = np.multiply(-1 * alpha * labelMat, classEst)
        D = np.multiply(D, np.exp(expon))
        D = D / D.sum()

        # 4. Add this classifier's weighted vote to the running estimate.
        aggClassEst += alpha * classEst

        # 5. Stop as soon as the ensemble classifies every sample correctly.
        errorRate = np.sum(np.sign(aggClassEst) != labelMat) / m
        if errorRate == 0.0:
            break

    return weakClassArr, aggClassEst
if __name__ == '__main__':
    # Demo run: load the toy data, plot it, train the ensemble and print
    # the learned stumps followed by the accumulated class estimates.
    samples, labels = loadSimpData()
    showDataSet(samples, labels)
    trained = adaBoostTrainDS(samples, labels)
    stumps, estimates = trained
    print(stumps)
    print(estimates)
165 |
--------------------------------------------------------------------------------
/训练数据.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Patrick9313/AdaBoost/dc8dbad03f04744384c79f8243c4a1afc010f937/训练数据.png
--------------------------------------------------------------------------------