├── README.md ├── __pycache__ └── treePlotter.cpython-37.pyc ├── dataset.jpg ├── dataset.txt ├── decisionTree ├── 4.4-1'.png ├── 4.4-1.png ├── 4.4-2.png ├── 4.4-3.png ├── 4.6-1.png ├── 4.6-2.png ├── 4.6-4.png ├── 4.6-5.png ├── 4.6-6.png ├── ID3-1.png ├── ID3-2.png ├── ID3.png ├── gini.png ├── x.dot └── 决策树实验报告.md ├── pic_results ├── C4.5.jpg ├── CART.jpg ├── ID3.jpg ├── figure_C4.5.jpg ├── figure_CART.jpg ├── figure_ID3.jpg └── first_find_bset_Index.jpg ├── testdata.jpg ├── testset.txt ├── tree.py ├── treePlotter.py ├── 数据表.jpg └── 数据表.xlsx /README.md: -------------------------------------------------------------------------------- 1 | # Decision_tree-python 2 | ### 决策树分类(ID3,C4.5,CART) 3 | ### 三种算法的区别如下: 4 | #### (1) ID3算法以信息增益为准则来进行选择划分属性,选择信息增益最大的;
5 | #### (2) C4.5算法先从候选划分属性中找出信息增益高于平均水平的属性,再从中选择增益率最高的;
6 | #### (3) CART算法使用“基尼指数”来选择划分属性,选择基尼值最小的属性作为划分属性.
7 | 8 | ### 本次实验我的数据集如下所示: 9 | ##### 共分为四个属性特征:年龄段,有工作,有自己的房子,信贷情况; 10 | ##### 现根据这四种属性特征来决定是否给予贷款 11 | ![image](https://github.com/Erikfather/Decision_tree-python/blob/master/%E6%95%B0%E6%8D%AE%E8%A1%A8.jpg) 12 | 13 | 14 | #### 为了方便,我对数据集进行如下处理: 15 | #### 在编写代码之前,我们先对数据集进行属性标注。 16 | #### (0)年龄:0代表青年,1代表中年,2代表老年; 17 | #### (1)有工作:0代表否,1代表是; 18 | #### (2)有自己的房子:0代表否,1代表是; 19 | #### (3)信贷情况:0代表一般,1代表好,2代表非常好; 20 | #### (4)类别(是否给贷款):no代表否,yes代表是。 21 | #### 存入txt文件中: 22 | ![image](https://github.com/Erikfather/Decision_tree-python/blob/master/dataset.jpg) 23 | 24 | #### 然后分别利用ID3,C4.5,CART三种算法对数据集进行决策树分类; 25 | #### 具体代码见:tree.py和treePlotter.py 26 | 27 | #### 实验结果如下: 28 | ##### (1)先将txt中数据读入数组,并打印出数据集长度,即共有多少条数据,并计算出起始信息熵Ent(D),然后分别找出三种算法的首个最优特征索引 29 | ![image](https://github.com/Erikfather/Decision_tree-python/blob/master/pic_results/first_find_bset_Index.jpg) 30 | ##### (2)下面通过相应的最优划分属性分别创建相应的决策树,并对测试集进行测试,测试集如下: 31 | ![image](https://github.com/Erikfather/Decision_tree-python/blob/master/testdata.jpg) 32 | 33 | #### 结果如下: 34 | ##### ID3算法: 35 | ![iamge](https://github.com/Erikfather/Decision_tree-python/blob/master/pic_results/ID3.jpg) 36 | ![iamge](https://github.com/Erikfather/Decision_tree-python/blob/master/pic_results/figure_ID3.jpg) 37 | 38 | ##### C4.5算法: 39 | ![image](https://github.com/Erikfather/Decision_tree-python/blob/master/pic_results/C4.5.jpg) 40 | ![image](https://github.com/Erikfather/Decision_tree-python/blob/master/pic_results/figure_C4.5.jpg) 41 | 42 | ##### CART算法: 43 | ![image](https://github.com/Erikfather/Decision_tree-python/blob/master/pic_results/CART.jpg) 44 | ![image](https://github.com/Erikfather/Decision_tree-python/blob/master/pic_results/figure_CART.jpg) 45 | 46 | #### 由上面结果可以看出: 47 | ##### (1)ID3和C4.5的最优索引以及决策树形图是相同的,而CART的最优索引以及决策树形图与前面两者不同,这与它们的选择标准以及训练集有关; 48 | ##### (2)但同时我们也发现,三种算法对测试集的测试结果是相同的,经过后期手动匹配,结果完全正确,这说明我们的决策树实验结果是正确的。 49 | -------------------------------------------------------------------------------- /__pycache__/treePlotter.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/__pycache__/treePlotter.cpython-37.pyc -------------------------------------------------------------------------------- /dataset.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/dataset.jpg -------------------------------------------------------------------------------- /dataset.txt: -------------------------------------------------------------------------------- 1 | 0,0,0,0,0 2 | 0,0,0,1,0 3 | 0,1,0,1,1 4 | 0,1,1,0,1 5 | 0,0,0,0,0 6 | 1,0,0,0,0 7 | 1,0,0,1,0 8 | 1,1,1,1,1 9 | 1,0,1,2,1 10 | 1,0,1,2,1 11 | 2,0,1,2,1 12 | 2,0,1,1,1 13 | 2,1,0,1,1 14 | 2,1,0,2,1 15 | 2,0,0,0,0 16 | 2,0,0,2,0 17 | -------------------------------------------------------------------------------- /decisionTree/4.4-1'.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/decisionTree/4.4-1'.png -------------------------------------------------------------------------------- /decisionTree/4.4-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/decisionTree/4.4-1.png -------------------------------------------------------------------------------- /decisionTree/4.4-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/decisionTree/4.4-2.png -------------------------------------------------------------------------------- /decisionTree/4.4-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/decisionTree/4.4-3.png -------------------------------------------------------------------------------- /decisionTree/4.6-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/decisionTree/4.6-1.png -------------------------------------------------------------------------------- /decisionTree/4.6-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/decisionTree/4.6-2.png -------------------------------------------------------------------------------- /decisionTree/4.6-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/decisionTree/4.6-4.png -------------------------------------------------------------------------------- /decisionTree/4.6-5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/decisionTree/4.6-5.png -------------------------------------------------------------------------------- /decisionTree/4.6-6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/decisionTree/4.6-6.png -------------------------------------------------------------------------------- /decisionTree/ID3-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/decisionTree/ID3-1.png -------------------------------------------------------------------------------- /decisionTree/ID3-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/decisionTree/ID3-2.png -------------------------------------------------------------------------------- /decisionTree/ID3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/decisionTree/ID3.png -------------------------------------------------------------------------------- /decisionTree/gini.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/decisionTree/gini.png -------------------------------------------------------------------------------- /decisionTree/x.dot: -------------------------------------------------------------------------------- 1 | digraph DI { 2 | node1[label="color"] 3 | node2[label="F", shape="box"] 4 | node3[label="size"] 5 | node4[label="F", shape="box"] 6 | node5[label="T", shape="box"] 7 | 8 | node1 -> node2[label="purple"] 9 | node1 -> node3[label="yellow"] 10 | node3 -> node4[label="large"] 11 | node3 -> node5[label="small"] 12 | 13 | } -------------------------------------------------------------------------------- /decisionTree/决策树实验报告.md: -------------------------------------------------------------------------------- 1 | # 决策树实验报告 2 | 3 | ## 实验要求 4 | 5 | 1. 编程实现一个基于信息熵进行划分选择的决策树算法,并为表中的数据生成一棵决策树。 6 | 7 | | 编号 | 色泽 | 根蒂 | 敲声 | 纹理 | 脐部 | 触感 | 密度 | 含糖率 | 好瓜 | 8 | | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ----- | ------ | ---- | 9 | | 1 | 青绿 | 蜷缩 | 浊响 | 清晰 | 凹陷 | 硬滑 | 0.697 | 0.460 | 是 | 10 | | 2 | 乌黑 | 蜷缩 | 沉闷 | 清晰 | 凹陷 | 硬滑 | 0.774 | 0.376 | 是 | 11 | | 3 | 乌黑 | 蜷缩 | 浊响 | 清晰 | 凹陷 | 硬滑 | 0.634 | 0.264 | 是 | 12 | | 4 | 青绿 | 蜷缩 | 沉闷 | 清晰 | 凹陷 | 硬滑 | 0.608 | 0.318 | 是 | 13 | | 5 | 浅白 | 蜷缩 | 浊响 | 清晰 | 凹陷 | 硬滑 | 0.556 | 0.215 | 是 | 14 | | 6 | 青绿 | 稍蜷 | 浊响 | 清晰 | 稍凹 | 软粘 | 0.403 | 0.237 | 是 | 15 | | 7 | 乌黑 | 稍蜷 | 浊响 | 稍糊 | 稍凹 | 软粘 | 0.481 | 0.149 | 是 | 16 | | 8 | 乌黑 | 稍蜷 | 浊响 | 清晰 | 稍凹 | 硬滑 | 0.437 | 0.211 | 是 | 17 | | 9 | 乌黑 | 稍蜷 | 沉闷 | 稍糊 | 稍凹 | 硬滑 | 0.666 | 0.091 | 否 | 18 | | 10 | 青绿 | 硬挺 | 清脆 | 清晰 | 平坦 | 软粘 | 0.243 | 0.267 | 否 | 19 | | 11 | 浅白 | 硬挺 | 清脆 | 模糊 | 平坦 | 硬滑 | 0.245 | 0.057 | 否 | 20 | | 12 | 浅白 | 蜷缩 | 浊响 | 模糊 | 平坦 | 软粘 | 0.343 | 0.099 | 否 | 21 | | 13 | 青绿 | 稍蜷 | 浊响 | 稍糊 | 凹陷 | 硬滑 | 0.639 | 0.161 | 否 | 22 | | 14 | 浅白 | 稍蜷 | 沉闷 | 稍糊 | 凹陷 | 硬滑 | 0.657 | 0.198 | 否 | 23 | | 15 | 乌黑 | 稍蜷 | 浊响 | 清晰 | 稍凹 | 软粘 | 0.360 | 0.370 | 否 | 24 | | 16 | 浅白 | 蜷缩 | 浊响 | 模糊 | 平坦 | 硬滑 | 0.593 | 0.042 | 否 | 25 | | 17 | 青绿 | 蜷缩 | 沉闷 | 稍糊 | 稍凹 | 硬滑 | 0.719 | 0.103 | 否 | 26 | 27 | 2. 编程实现基于基尼指数进行划分选择的决策树算法,为表中数据生成预剪枝,后剪枝决策树,并与未剪枝的决策树进行比较。 28 | 29 | | 编号 | 色泽 | 根蒂 | 敲声 | 纹理 | 脐部 | 触感 | 好瓜 | 30 | | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | 31 | | 1 | 青绿 | 蜷缩 | 浊响 | 清晰 | 凹陷 | 硬滑 | 是 | 32 | | 2 | 乌黑 | 蜷缩 | 沉闷 | 清晰 | 凹陷 | 硬滑 | 是 | 33 | | 3 | 乌黑 | 蜷缩 | 浊响 | 清晰 | 凹陷 | 硬滑 | 是 | 34 | | 4 | 青绿 | 蜷缩 | 沉闷 | 清晰 | 凹陷 | 硬滑 | 是 | 35 | | 5 | 浅白 | 蜷缩 | 浊响 | 清晰 | 凹陷 | 硬滑 | 是 | 36 | | 6 | 青绿 | 稍蜷 | 浊响 | 清晰 | 稍凹 | 软粘 | 是 | 37 | | 7 | 乌黑 | 稍蜷 | 浊响 | 稍糊 | 稍凹 | 软粘 | 是 | 38 | | 8 | 乌黑 | 稍蜷 | 浊响 | 清晰 | 稍凹 | 硬滑 | 是 | 39 | | 9 | 乌黑 | 稍蜷 | 沉闷 | 稍糊 | 稍凹 | 硬滑 | 否 | 40 | | 10 | 青绿 | 硬挺 | 清脆 | 清晰 | 平坦 | 软粘 | 否 | 41 | | 11 | 浅白 | 硬挺 | 清脆 | 模糊 | 平坦 | 硬滑 | 否 | 42 | | 12 | 浅白 | 蜷缩 | 浊响 | 模糊 | 平坦 | 软粘 | 否 | 43 | | 13 | 青绿 | 稍蜷 | 浊响 | 稍糊 | 凹陷 | 硬滑 | 否 | 44 | | 14 | 浅白 | 稍蜷 | 沉闷 | 稍糊 | 凹陷 | 硬滑 | 否 | 45 | | 15 | 乌黑 | 稍蜷 | 浊响 | 清晰 | 稍凹 | 软粘 | 否 | 46 | | 16 | 浅白 | 蜷缩 | 浊响 | 模糊 | 平坦 | 硬滑 | 否 | 47 | | 17 | 青绿 | 蜷缩 | 沉闷 | 稍糊 | 稍凹 | 硬滑 | 否 | 48 | 49 | 3. 选择两个UCI数据集,对上述3种算法产生的未剪枝,预剪枝,后剪枝的决策树进行实验比较 50 | 51 | ## 原理 52 | 53 | ### 决策树 54 | 55 | * 定义:一种描述对实例进行分类的树形结构。决策树由点和有向边组成。节点有两种类型:内部节点和叶节点。内部节点表示一种特征或者属性,叶节点表示一个分类。构建决策树时通常采用自上而下的方法,在每一步选择一个最好的属性来分裂。[8] "最好" 的定义是使得子节点中的训练集尽量的纯。不同的算法使用不同的指标来定义"最好"。 56 | * 意义:每次都找不同的切分点,将样本空间逐渐进行细分,最后把属于同一类的空间进行合并,就形成了决策边界,树的层次越深,决策边界的切分就越细,区分越准确,同时也越有可能产生过拟合。 57 | * 决策树是一个预测模型;他代表的是对象属性与对象值之间的一种映射关系。树中每个节点表示某个对象,而每个分叉路径则代表的某个可能的属性值,而每个叶结点则对应从根节点到该叶节点所经历的路径所表示的对象的值。决策树仅有单一输出,若欲有复数输出,可以建立独立的决策树以处理不同输出。 58 | * 分支使用的是节点属性中的离散型数据,如果数据是连续型的,也需要转化成离散型数据才能在决策树中展示。 59 | * 决策树的路径具有一个重要的性质:互斥且完备,即每一个样本均被且只能被一条路径所覆盖。决策树学习算法主要由三部分构成: 60 | * 特征选择 61 | * 决策树生成 62 | * 决策树的剪枝 63 | * 决策树的建立 64 | 开始,构建根节点,将所有训练数据放在根节点,选择一个最优特征,按照这一特征的取值将训练数据分割为子集,使各个子集有一个当前条件下最好的分类。如果这些子集能被基本正确分类,那么构造叶节点,将对应子集集中到叶节点。如果有子集不能被正确分类,那么就这些子集选择新的最优特征,继续对其进行分割,构建相应的节点。递归进行上述的操作,直到所有训练数据子集均能被正确分类。 65 | * 节点分裂:一般当一个节点所代表的属性无法给出判断时,则选择将这一节点分成2个子节点(如不是二叉树的情况会分成n个子节点) 66 | * 阈值的确定,选择适当的阈值使得分类错误率最小。 67 | * 与其他的数据挖掘算法相比,决策树有许多优点: 68 | * 易于理解和解释,人们很容易理解决策树的意义。 69 | * 只需很少的数据准备,其他技术往往需要数据归一化。 70 | * 既可以处理数值型数据也可以处理类别型数据。其他技术往往只能处理一种数据类型。例如关联规则只能处理类别型的而神经网络只能处理数值型的数据。 71 | * 使用白箱模型,输出结果容易通过模型的结构来解释。而神经网络是黑箱模型,很难解释输出的结果。 72 | * 可以通过测试集来验证模型的性能,可以考虑模型的稳定性。 73 | * 强健控制,对噪声处理有好的强健性。 74 | * 可以很好的处理大规模数据 。 75 | 76 | ### ID3 77 | 78 | * 信息熵 entropy 79 | 熵是接收的每条消息中包含的信息的平均量,熵是对不确定性的测量。但是在信息世界,熵越高,则能传输越多的信息,熵越低,则意味着传输的信息越少。 80 | 当取自有限的样本时,熵的公式可以表示为: 81 | $$H(X) = -\sum^{n}_{i = 1}p_i \log_2 p_i$$ 82 | 其中若 $p_i = 0$ 则定义 $p_i\log p_i = 0$ 83 | 84 | ```python 85 | def calc_ent(datasets): 86 | data_length = len(datasets) 87 | label_count = {} 88 | for i in range(data_length): 89 | label = datasets[i][-1] 90 | if label not in label_count: 91 | label_count[label] = 0 92 | label_count[label] += 1 93 | ent = -sum([(p / data_length) * log(p / data_length, 2) 94 | for p in label_count.values()]) 95 | return ent 96 | ``` 97 | 98 | * 信息量 99 | 信息量是对信息的度量,就跟时间的度量是秒一样,当我们考虑一个离散的随机变量x的时候,当我们观察到的这个变量的一个具体值的时候,我们接收到了多少信息呢?多少信息用信息量来衡量,我们接受到的信息量跟具体发生的事件有关。信息的大小跟随机事件的概率有关。越小概率的事情发生了产生的信息量越大,越大概率的事情发生了产生的信息量越小 100 | * 信息增益 information gain 101 | 得知特征 $X$ 的信息而使得类 $Y$ 的信息的不确定性减少的程度。 102 | 特征 $A$ 对训练数据集 $D$ 的信息增益 $g(D, A)$ 定义为集合 $D$ 的经验熵 $H(D)$ 与特征 $A$ 给定条件下 $D$ 的经验条件熵 $H(D|A)$ 的差。 103 | $$g(D,A) = H(D) - H(D|A)$$ 104 | 一般地,熵 $H(Y)$ 与条件熵 $H(Y|X)$ 之差称为互信息(mutual information) 根据信息增益准则进行特征选择的方法是:对训练数据集D,计算其每个特征的信息增益,并比它们的大小,从而选择信息增益最大的特征。 105 | 假设训练数据集为 $D$,样本容量为 $|D|$,有 $K$ 个类别 $C_k$,$|C_k|$ 为类别 $C_k$ 的样本个数。某一特征 $A$ 有 $n$ 个不同的取值 $a_1, a_2, ..., a_n$。根据特征 $A$ 的取值可将数据集 $D$ 划分为 $n$ 个子集 $D_1, D_2, ..., D_n$, $|D_i$|为 $D_i$ 的样本个数。并记子集 $D_i$ 中属于类 $C_k$ 的样本的集合为 $D_{ik}$,$|D_{ik}|$ 为 $D_{ik}$ 的样本个数。 106 | 107 | ```python 108 | def cond_ent(datasets, axis=0): 109 | data_length = len(datasets) 110 | feature_sets = {} 111 | for i in range(data_length): 112 | feature = datasets[i][axis] 113 | if feature not in feature_sets: 114 | feature_sets[feature] = [] 115 | feature_sets[feature].append(datasets[i]) 116 | cond_ent = sum( 117 | [(len(p) / data_length) * calc_ent(p) for p in feature_sets.values()]) 118 | return cond_ent 119 | 120 | def info_gain(ent, cond_ent): 121 | return ent - cond_ent 122 | ``` 123 | 124 | ID3算法的核心是在决策树的各个结点上应用信息增益准则进行特征选择。具体做法是: 125 | 126 | * 从根节点开始,对结点计算所有可能特征的信息增益,选择信息增益最大的特征作为结点的特征,并由该特征的不同取值构建子节点; 127 | * 对子节点递归地调用以上方法,构建决策树; 128 | * 直到所有特征的信息增益均很小或者没有特征可选时为止。 129 | 130 | ```text 131 | 判断数据集中的每个子项是否属于同一类: 132 | if true: 133 | return 类标签; 134 | else: 135 | 寻找划分数据集的最佳特征 136 | 根据最佳特征划分数据集 137 | 创建分支节点 138 | for 每个划分的子集 139 | 递归调用createBranch(); 140 | return 分支节点 141 | ``` 142 | 143 | ### C4.5 144 | 145 | * 信息增益比 information gain ratio 146 | 以信息增益作为划分训练数据集的特征,存在偏向与选择取值较多的特征的问题。使用信息增益比可以对这个问题进行矫正。 147 | 特征 $A$ 对训练数据集 $D$ 的信息增益比 $g_R(D,A)$ 定义为其信息增益 $g(D,A)$ 和训练数据集 $D$ 关于特征 $A$ 的熵 $H_A(D)$ 的比值。 148 | $$g_R(D,A) = \frac{g(D,A)}{H_A(D,A)}$$ 149 | 其中 $H_A(D) = \sum_{i=1}^n \frac{|D_i|}{|D|} \log_2 \frac{|D_i|}{|D|}$, $n$ 是特征 $A$ 取值的个数。 150 | 151 | 相比 ID3 算法,C4.5 算法更换了特征选择的标准,使用信息增益比进行特征选择。不直接选择增益率最大的候选划分属性,候选划分属性中找出信息增益高于平均水平的属性(这样保证了大部分好的的特征),再从中选择增益率最高的(又保证了不会出现编号特征这种极端的情况) 152 | 对于连续值属性来说,可取值数目不再有限,因此可以采用离散化技术(如二分法)进行处理。将属性值从小到大排序,然后选择中间值作为分割点,数值比它小的点被划分到左子树,数值不小于它的点被分到又子树,计算分割的信息增益率,选择信息增益率最大的属性值进行分割。 153 | 154 | ```text 155 | Function C4.5(R:包含连续属性的无类别属性集合,C:类别属性,S:训练集) 156 | /*返回一棵决策树*/ 157 | Begin 158 | If S 为空,返回一个值为 Failure 的单个节点; 159 | If S 是由相同类别属性值的记录组成: 160 | 返回一个带有该值的单个节点; 161 | If R 为空,则返回一个单节点,其值为在 S 的记录中找出的频率最高的类别属性值; 162 | [注意未出现错误则意味着是不适合分类的记录]; 163 | For 所有的属性 R(Ri) Do 164 | If 属性 Ri 为连续属性,则 165 | Begin 166 | 将Ri的最小值赋给 A1: 167 | 将Rm的最大值赋给Am;/*m值手工设置*/ 168 | For j From 2 To m-1 Do Aj=A1+j*(A1Am)/m; 169 | 将 Ri 点的基于{< =Aj,>Aj}的最大信息增益属性 (Ri,S) 赋给 A; 170 | End; 171 | 将 R 中属性之间具有最大信息增益的属性 (D,S) 赋给 D; 172 | 将属性D的值赋给{dj/j=1,2...m}; 173 | 将分别由对应于 D 的值为 dj 的记录组成的S的子集赋给 {sj/j=1,2...m}; 174 | 返回一棵树,其根标记为 D;树枝标记为 d1,d2...dm; 175 | 再分别构造以下树: 176 | C4.5(R-{D},C,S1),C4.5(R-{D},C,S2)...C4.5(R-{D},C,Sm); 177 | 178 | End C4.5 179 | ``` 180 | 181 | ### CART 182 | 183 | CART与ID3区别: CART中用于选择变量的不纯性度量是Gini指数; 如果目标变量是标称的,并且是具有两个以上的类别,则CART可能考虑将目标类别合并成两个超类别(双化); 如果目标变量是连续的,则CART算法找出一组基于树的回归方程来预测目标变量。 184 | 185 | * Gini 指数 186 | 分类问题中假设有 $K$ 个类,样本点属于第 $k$ 个类的概率为 $p_k$,则概率分布的基尼指数为定义为 187 | $$Gini(p) = \sum_{k=1}^K p_k(1 - p_k) = 1 - \sum_{k = 1}^K p_k^2$$ 188 | 对于二分类问题和给定的样本集合 $D$ 其基尼指数为 189 | $$Gini(D) = 1 - \sum_{k=1}^K (\frac{|C_k|}{|D|})^2$$ 190 | 若样本集合 $D$ 根据特征 $A$ 是否取某一可能的值 $a$ 分割为 $D_1, D_2$ 两部分,则在特征 $A$ 的条件下集合 $D$ 的基尼指数定义为 191 | $$Gini(D,A) = \frac{|D_1|}{|D|} Gini(D_1) + \frac{|D_2|}{|D|} Gini(D_2)$$ 192 | $Gini(D)$ 反映了数据集 $D$ 的纯度,值越小,纯度越高。我们在候选集合中选择使得划分后基尼指数最小的属性作为最优化分属性。 193 | 194 | ```python 195 | def gini(data_set): 196 | """ 197 | 计算gini的值,即Gini(p) 198 | """ 199 | length = len(data_set) 200 | category_2_cnt = calculate_diff_count(data_set) 201 | sum = 0.0 202 | for category in category_2_cnt: 203 | sum += pow(float(category_2_cnt[category]) / length, 2) 204 | return 1 - sum 205 | ``` 206 | 207 | CART是一棵二叉树,采用二元切分法,每次把数据切成两份,分别进入左子树、右子树。而且每个非叶子节点都有两个孩子,所以CART的叶子节点比非叶子多1。相比ID3和C4.5,CART应用要多一些,既可以用于分类也可以用于回归。CART分类时,使用基尼指数(Gini)来选择最好的数据分割的特征,gini描述的是纯度,与信息熵的含义相似。CART中每一次迭代都会降低GINI系数。 208 | 209 | * 算法流程: 210 | 211 | 1. CART回归树预测回归连续型数据,假设X与Y分别是输入和输出变量,并且Y是连续变量。在训练数据集所在的输入空间中,递归的将每个区域划分为两个子区域并决定每个子区域上的输出值,构建二叉决策树。 212 | 2. 选择最优切分变量j与切分点 $s$:遍历变量 $j$,对规定的切分变量 $j$ 扫描切分点 $s$,选择使下式得到最小值时的 $(j,s)$ 对。其中 $R_m$ 是被划分的输入空间,$c_m$ 是空间 $R_m$ 对应的固定输出值。 213 | 3. 用选定的 $(j,s)$ 对,划分区域并决定相应的输出值。 214 | 4. 继续对两个子区域调用上述步骤,将输入空间划分为 $M$ 个区域 $R_1,R_2,…,R_m$,生成决策树。 215 | 216 | 当输入空间划分确定时,可以用平方误差来表示回归树对于训练数据的预测方法,用平方误差最小的准则求解每个单元上的最优输出值。 217 | $$\sum_{x_i \in R_m} (y_i - f(x_i))^2$$ 218 | 219 | ### 剪枝 220 | 221 | 为了避免决策树“过拟合”样本。前面的算法生成的决策树非常的详细而庞大,每个属性都被详细地加以考虑,决策树的树叶节点所覆盖的训练样本都是“纯”的。因此用这个决策树来对训练样本进行分类的话,你会发现对于训练样本而言,这个树表现堪称完美,它可以100%完美正确得对训练样本集中的样本进行分类(因为决策树本身就是100%完美拟合训练样本的产物)。但是,这会带来一个问题,如果训练样本中包含了一些错误,按照前面的算法,这些错误也会100%一点不留得被决策树学习了,这就是“过拟合”。 222 | 223 | #### 预剪枝 pre-pruning 224 | 225 | 预剪枝就是在树的构建过程(只用到训练集),设置一个阈值(样本个数小于预定阈值或GINI指数小于预定阈值),使得当在当前分裂节点中分裂前和分裂后的误差超过这个阈值则分列,否则不进行分裂操作。所有决策树的构建方法,都是在无法进一步降低熵的情况下才会停止创建分支的过程,为了避免过拟合,可以设定一个阈值,熵减小的数量小于这个阈值,即使还可以继续降低熵,也停止继续创建分支。但是这种方法实际中的效果并不好。 226 | 在划分之前,所有样本集中于根节点,若不进行划分,该节点被标记为叶节点,其类别标记为训练样例最多的类别。若进行划分在测试集上的准确率小于在根节点不进行划分的准确率,或增幅没有超过阈值,都不进行划分,作为一个叶节点返回当前数据集中最多的标签类型。 227 | 228 | * 预剪枝就是在完全正确分类训练集之前,较早地停止树的生长。 具体在什么时候停止决策树的生长有多种不同的方法: 229 | 1. 一种最为简单的方法就是在决策树到达一定高度的情况下就停止树的生长。 230 | 2. 到达此结点的实例具有相同的特征向量,而不必一定属于同一类, 也可停止生长。 231 | 3. 到达此结点的实例个数小于某一个阈值也可停止树的生长。 232 | 4. 还有一种更为普遍的做法是计算每次扩张对系统性能的增益,如果这个增益值小于某个阈值则不进行扩展。 233 | 234 | * 优点:快速,可以在构建决策树时进行剪枝,显著降低了过拟合风险。由于预剪枝不必生成整棵决策树,且算法相对简单,效率很高,适合解决大规模问题。但是尽管这一方法看起来很直接, 但是怎样精确地估计何时停止树的增长是相当困难的。 235 | * 缺点:预剪枝基于贪心思想,本质上禁止分支展开,给决策树带来了欠拟合的风险。因为视野效果问题 。 也就是说在相同的标准下,也许当前的扩展会造成过度拟合训练数据,但是更进一步的扩展能够满足要求,也有可能准确地拟合训练数据。这将使得算法过早地停止决策树的构造。 236 | 237 | ```python 238 | if pre_pruning: 239 | ans = [] 240 | for index in range(len(test_dataset)): # build label for test dataset 241 | ans.append(test_dataset[index][-1]) 242 | result_counter = Counter() 243 | for vec in dataset: 244 | result_counter[vec[-1]] += 1 245 | # what will it be if it is a leaf node 246 | leaf_output = result_counter.most_common(1)[0][0] 247 | root_acc = cal_acc(test_output=[leaf_output] * len(test_dataset), 248 | label=ans) 249 | outputs = [] 250 | ans = [] 251 | for value in uniqueVals: # expand the node 252 | cut_testset = splitdataset(test_dataset, bestFeat, value) 253 | cut_dataset = splitdataset(dataset, bestFeat, value) 254 | for vec in cut_testset: 255 | ans.append(vec[-1]) 256 | result_counter = Counter() 257 | for vec in cut_dataset: 258 | result_counter[vec[-1]] += 1 259 | leaf_output = result_counter.most_common(1)[0][0] # what will it be if it is a leaf node 260 | outputs += [leaf_output] * len(cut_testset) 261 | cut_acc = cal_acc(test_output=outputs, label=ans) 262 | 263 | if cut_acc <= root_acc + threshold: # whether expand the node or not 264 | return leaf_output 265 | ``` 266 | 267 | #### 后剪枝 post-pruning 268 | 269 | 决策树构造完成后进行剪枝。剪枝的过程是对拥有同样父节点的一组节点进行检查,判断如果将其合并,熵的增加量是否小于某一阈值。如果确实小,则这一组节点可以合并一个节点,其中包含了所有可能的结果。后剪枝是目前最普遍的做法。 270 | 后剪枝的剪枝过程是删除一些子树,然后用其叶子节点代替,这个叶子节点所标识的类别通过大多数原则 (majority class criterion) 确定。所谓大多数原则,是指剪枝过程中, 将一些子树删除而用叶节点代替,这个叶节点所标识的类别用这棵子树中大多数训练样本所属的类别来标识。相比于前剪枝,后剪枝方法更常用,是因为在前剪枝方法中精确地估计何时停止树增长很困难。 271 | 272 | * 优点:欠拟合风险小,泛化性能好 273 | * 缺点:在生成决策树之后完成,自底向上对所有非叶节点进行逐一考察,训练的时间开销较大 274 | 275 | ```python 276 | 277 | def prune_tree(node, prunedList): 278 | # Base case: we've reached a leaf 279 | if isinstance(node, Leaf): 280 | return node 281 | # If we reach a pruned node, make that node a leaf node and return. 282 | # Since it becomes a leaf node, the nodes 283 | # below it are automatically not considered 284 | if int(node.id) in prunedList: 285 | return Leaf(node.rows, node.id, node.depth) 286 | 287 | # Call this function recursively on the true branch 288 | node.true_branch = prune_tree(node.true_branch, prunedList) 289 | 290 | # Call this function recursively on the false branch 291 | node.false_branch = prune_tree(node.false_branch, prunedList) 292 | 293 | return node 294 | ``` 295 | 296 | ## 实现细节 297 | 298 | 测试数据使用李航老师的统计方法中对应的数据集。 299 | 300 | ### ID3 算法实现 301 | 302 | ```python 303 | def ID3_chooseBestFeature(dataset): 304 | numFeatures = len(dataset[0]) - 1 305 | baseEnt = cal_entropy(dataset) 306 | bestInfoGain = 0.0 307 | bestFeature = -1 308 | for i in range(numFeatures): # check all features 309 | featList = [example[i] for example in dataset] 310 | uniqueVals = set(featList) 311 | newEnt = 0.0 312 | # claculate entropy of every divide ways 313 | for value in uniqueVals: 314 | # choose the samples mmeeting the requirement 315 | subdataset = splitdataset(dataset, i, value) 316 | p = len(subdataset) / float(len(dataset)) 317 | newEnt += p * cal_entropy(subdataset) 318 | infoGain = baseEnt - newEnt 319 | 320 | if (infoGain > bestInfoGain): 321 | bestInfoGain = infoGain # choose the largest information gain 322 | bestFeature = i 323 | return bestFeature 324 | ``` 325 | 326 | ### C4.5 算法实现 327 | 328 | ```python 329 | def C45_chooseBestFeatureToSplit(dataset): 330 | numFeatures = len(dataset[0]) - 1 331 | baseEnt = cal_entropy(dataset) 332 | bestInfoGain_ratio = 0.0 333 | bestFeature = -1 334 | for i in range(numFeatures): # check every feature 335 | featList = [example[i] for example in dataset] 336 | uniqueVals = set(featList) 337 | newEnt = 0.0 338 | IV = 0.0 339 | for value in uniqueVals: 340 | subdataset = splitdataset(dataset, i, value) 341 | p = len(subdataset) / float(len(dataset)) 342 | newEnt += p * cal_entropy(subdataset) 343 | IV = IV - p * log(p, 2) 344 | infoGain = baseEnt - newEnt 345 | if (IV == 0): 346 | continue 347 | infoGain_ratio = infoGain / IV # infoGain_ratio of current feature 348 | 349 | if (infoGain_ratio > bestInfoGain_ratio): # choose the greatest gain ratio 350 | bestInfoGain_ratio = infoGain_ratio 351 | bestFeature = i # choose the feature corsbounding to the gain ratio 352 | return bestFeature 353 | ``` 354 | 355 | ### CART 算法实现 356 | 357 | ```python 358 | def CART_chooseBestFeature(dataset): 359 | numFeatures = len(dataset[0]) - 1 # except the column of labels 360 | bestGini = 999999.0 361 | bestFeature = -1 # default label 362 | 363 | for i in range(numFeatures): 364 | featList = [example[i] for example in dataset] 365 | uniqueVals = set(featList) # get the possible values of each feature 366 | gini = 0.0 367 | 368 | for value in uniqueVals: 369 | subdataset = splitdataset(dataset, i, value) 370 | p = len(subdataset) / float(len(dataset)) 371 | subp = len(splitdataset(subdataset, -1, '0')) / float(len(subdataset)) 372 | gini += p * (1.0 - pow(subp, 2) - pow(1 - subp, 2)) 373 | 374 | if (gini < bestGini): 375 | bestGini = gini 376 | bestFeature = i 377 | 378 | return bestFeature 379 | ``` 380 | 381 | ### 建树操作 382 | 383 | 因为建树过程相似,仅选取 ID3 算法的建树过程。 384 | 385 | ```python 386 | def ID3_create_tree(dataset, labels, test_dataset): 387 | classList = [example[-1] for example in dataset] 388 | if classList.count(classList[0]) == len(classList): 389 | # 类别完全相同,停止划分 390 | return classList[0] 391 | if len(dataset[0]) == 1: 392 | # 遍历完所有特征时返回出现次数最多的 393 | return majority_count(classList) 394 | bestFeat = ID3_choose_best_feature(dataset) 395 | bestFeatLabel = labels[bestFeat] 396 | print(u"此时最优索引为:" + (bestFeatLabel)) 397 | 398 | ID3Tree = {bestFeatLabel: {}} 399 | del (labels[bestFeat]) 400 | # 得到列表包括节点所有的属性值 401 | featValues = [example[bestFeat] for example in dataset] 402 | uniqueVals = set(featValues) 403 | 404 | for value in uniqueVals: # 枚举对用特征的每个取值 405 | subLabels = labels[:] 406 | ID3Tree[bestFeatLabel][value] = ID3_create_tree( 407 | split_dataset(dataset, bestFeat, value), 408 | subLabels, 409 | split_dataset(test_dataset, bestFeat, value)) 410 | 411 | if cut_acc >= root_acc: 412 | return leaf_output 413 | 414 | return ID3Tree # 如果没有剪枝返回节点 415 | ``` 416 | 417 | ### 预剪枝 418 | 419 | 在实现时检查划分,如果在测试集上的准确率下降或没有上升到一个阈值时,将进行剪枝。 420 | 421 | ```python 422 | if pre_pruning: 423 | ans = [] 424 | for index in range(len(test_dataset)): 425 | ans.append(test_dataset[index][-1]) 426 | result_counter = Counter() 427 | for vec in dataset: 428 | result_counter[vec[-1]] += 1 429 | leaf_output = result_counter.most_common(1)[0][0] 430 | root_acc = cal_acc(test_output=[leaf_output] * len(test_dataset), label=ans) 431 | # 若当前节点是叶节点的准确率 432 | outputs = [] 433 | ans = [] 434 | for value in uniqueVals: 435 | cut_testset = split_dataset(test_dataset, bestFeat, value) 436 | cut_dataset = split_dataset(dataset, bestFeat, value) 437 | for vec in cut_testset: 438 | ans.append(vec[-1]) 439 | result_counter = Counter() 440 | for vec in cut_dataset: 441 | result_counter[vec[-1]] += 1 442 | leaf_output = result_counter.most_common(1)[0][0] 443 | outputs += [leaf_output] * len(cut_testset) 444 | cut_acc = cal_acc(test_output=outputs, label=ans) # 不进行剪枝在测试集上的准确率 445 | 446 | if cut_acc <= root_acc + threshold: # 检查准确率上升情况 447 | return leaf_output 448 | ``` 449 | 450 | ### 后剪枝 451 | 452 | 因为后剪枝的方法是自下而上的判断是否应该进行剪枝,所以在实现时在返回节点对象之前进行剪枝,决定是返回对象还是返回单一类别。 453 | 454 | ```python 455 | if post_pruning: 456 | tree_output = test_tree(C45Tree, 457 | featLabels=total_labels, 458 | testDataSet=test_dataset) 459 | ans = [] 460 | for vec in test_dataset: 461 | ans.append(vec[-1]) 462 | root_acc = cal_acc(tree_output, ans) 463 | result_counter = Counter() 464 | for vec in dataset: 465 | result_counter[vec[-1]] += 1 466 | leaf_output = result_counter.most_common(1)[0][0] 467 | cut_acc = cal_acc([leaf_output] * len(test_dataset), ans) 468 | 469 | if cut_acc >= root_acc: 470 | return leaf_output 471 | ``` 472 | 473 | 表面上后剪枝的操作比预剪枝少,实际上在测试时递归的测试了当前节点所在子树的正确率。所以后剪枝带来的开销远大于的预剪枝。 474 | 475 | ## 实验结果分析 476 | 477 | ### 西瓜数据集3.0 478 | 479 | 实现基于信息熵进行划分的决策树算法(ID3)算法,并可视化结果如下: 480 | 481 | ![img](ID3-1.png) 482 | 483 | 可以注意到直接使用了编号信息进行划分,这一点和书中描述的 “基于信息增益的算法对可取数值较多的属性有偏好” 一致。下来除去编号属性并对结果进行可视化,结果如下: 484 | 485 | ![img](ID3-2.png) 486 | 487 | ### 西瓜数据集2.0 488 | 489 | 实现基于基尼指数进行划分选择的决策树算法,并进行相应的剪枝操作,使用书中的测试集,验证集划分。 490 | 491 | * 未剪枝 492 | 493 | * 预剪枝 494 | 495 | * 后剪枝 496 | 497 | 498 | ### UCI 数据集 499 | 500 | 使用 UCI 数据集来判断未剪枝,预剪枝,后剪枝的方法产生的差异。 501 | 502 | #### Iris 503 | 504 | 总数据集一共 150 条数据,共三个类别。在数据集中随机抽取30条数据组成测试集,再从中随机抽取30个作为训练时的验证集,余下的90条数据组成训练集。数据集的描述内容包括:花萼长度,花萼宽度,花瓣长度,花瓣宽度。 505 | 506 | * 使用 CART 进行实验 507 | * 未剪枝 508 | 509 | * 预剪枝 510 | 511 | * 后剪枝 512 | 不是剪枝一定要剪掉某个节点,如果树中本身的节点符合要求,可以不用剪枝。 513 | 514 | 515 | 经过观察,通过修改随机数种子来更改测试集和验证集的划分可以很大程度上影响决策树的准确率。 516 | 517 | 将数据集划分函数的 `random state` 从 15 改成 1 之后,后剪枝产生的决策树变为下图: 518 | ![img](4.6-4.png) 519 | 所以数据集的划分可以显著影响树的形状和结果的准确率。 520 | 521 | #### Balloons 522 | 523 | * 数据集描述 524 | 数据集较小,共20个样本,四个属性。[数据集链接](https://archive.ics.uci.edu/ml/datasets/Balloons) 525 | |列名|取值| 526 | |--|--| 527 | |color|yellow, purple| 528 | |size|large, small| 529 | |act|stretch, dip| 530 | |age|adult, child| 531 | |label|T, F| 532 | 533 | 其中选择4个样本作为测试集,4个样本作为训练时使用的验证集,余下的作为训练集。 534 | 535 | * 不剪枝 536 | ![img](4.6-5.png) 537 | * 预剪枝 538 | ![img](4.6-6.png) 539 | * 后剪枝 540 | ![img](4.6-5.png) 541 | 542 | ## 个人感悟 543 | 544 | 通过本次实验有如下的发现和收获: 545 | 546 | * 训练数据集大小和模型精度的关系: 547 | * 当训练数据集过小时,建立的模型精度过低,不具有参考价值。 548 | * 随训练数据集尺寸增大,建立模型的分类精度也会随之增大。 549 | * 当训练数据集尺寸增大到一定程度时,建立模型的精度不会再持续增大,且最大分类精度不会超过模型对训练数据的拟合度。 550 | * 测试数据集大小与模型精度的关系: 551 | * 当测试数据集过小时,所测模型精度不具有代表性,没有参考价值。 552 | * 随测试数据集尺寸增大,模型精度也会随之增大。 553 | * 当测试数据集尺寸增大到一定程度时,对模型精度的测量值不会再持续增大,并保持在某一数值上下微小浮动。 554 | * 属性个数对数据集大小与模型精度的关系的影响: 555 | * 当实例的属性个数过少时,所建模型精度低,没有参考价值。 556 | * 随实例的属性个数增多,所建立模型的精度也会随之增大。 557 | 558 | 通过本次实验透彻的了解了决策树各种构造算法和剪枝算法,通过使用 Python 进行实现基本的决策树和简单的剪枝算法,锻炼了我的代码能力和对相应伪代码的理解能力。 559 | 560 | 在一开始,我对于决策树的整体没有认识,不知道如何从零开始构建一棵决策树,通过仔细研究课本,将思路从认识整体调整为模拟数据的流向。 561 | 首先定义参与到结构中的决策树的数据格式为 list 数组类型的嵌套。然后思考如何构建一棵决策树,自然而然想到的就是递归,通过返回节点的实例来递归建树。 562 | 下面碰到的问题就在于如何有机统一不同特征的不同的取值可能和在推理时的便利实现相结合,那么 Python 内建的 dict 字典类型就是很好的选择,通过设置特征的不同的可能取值为 key 值,对应的 value 为递归返回的结果,若连接的为叶节点,则 value 为 '0' 或 '1' 表示正负例,若为决策树的内部节点,则储存对应的对象,将查询请求递归处理直到叶节点。 563 | 下来就是漫长的 debug 时间,通过 pyCharm 方便的断点,我能成功复现每个异常发生时的程序情况,思考为什么会发生当前的状况和怎样进行修改。一开始我在实现后剪枝是还是给整个函数传入一整棵树,先递归到叶节点在向上逐层剪枝。但是在我仔细研究课本过后,发现书上给的顺序本就是自底向上的,所以只需要在建树时先剪枝再返回节点即可。 564 | 565 | 通过实现决策树和两个剪枝算法,我懂得了实际生活中,专业知识是怎样应用与实践的。同时了解了所谓的复现是什么,通过别人的描述重现他人的工作。了解到在实验中应该先彻底明白算法再进行复现,否则可能因为理解的偏差很多工作可能是错误的。为了最大化效率,应该先明确每个实现细节,在确定完了细节之后从下向上构建代码。 566 | -------------------------------------------------------------------------------- /pic_results/C4.5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/pic_results/C4.5.jpg -------------------------------------------------------------------------------- /pic_results/CART.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/pic_results/CART.jpg -------------------------------------------------------------------------------- /pic_results/ID3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/pic_results/ID3.jpg -------------------------------------------------------------------------------- /pic_results/figure_C4.5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/pic_results/figure_C4.5.jpg -------------------------------------------------------------------------------- /pic_results/figure_CART.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/pic_results/figure_CART.jpg -------------------------------------------------------------------------------- /pic_results/figure_ID3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/pic_results/figure_ID3.jpg -------------------------------------------------------------------------------- /pic_results/first_find_bset_Index.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/pic_results/first_find_bset_Index.jpg -------------------------------------------------------------------------------- /testdata.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/testdata.jpg -------------------------------------------------------------------------------- /testset.txt: -------------------------------------------------------------------------------- 1 | 0,0,0,1,0 2 | 0,1,0,1,1 3 | 1,0,1,2,1 4 | 1,0,0,1,0 5 | 2,1,0,2,1 6 | 2,0,0,0,0 7 | 2,0,0,2,0 -------------------------------------------------------------------------------- /tree.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | 3 | from math import log 4 | import operator 5 | import treePlotter 6 | from collections import Counter 7 | 8 | 9 | pre_pruning = True 10 | post_pruning = True 11 | 12 | 13 | def read_dataset(filename): 14 | """ 15 | 年龄段:0代表青年,1代表中年,2代表老年; 16 | 有工作:0代表否,1代表是; 17 | 有自己的房子:0代表否,1代表是; 18 | 信贷情况:0代表一般,1代表好,2代表非常好; 19 | 类别(是否给贷款):0代表否,1代表是 20 | """ 21 | fr = open(filename, 'r') 22 | all_lines = fr.readlines() # list形式,每行为1个str 23 | # print all_lines 24 | labels = ['年龄段', '有工作', '有自己的房子', '信贷情况'] 25 | # featname=all_lines[0].strip().split(',') #list形式 26 | # featname=featname[:-1] 27 | labelCounts = {} 28 | dataset = [] 29 | for line in all_lines[0:]: 30 | line = line.strip().split(',') # 以逗号为分割符拆分列表 31 | dataset.append(line) 32 | return dataset, labels 33 | 34 | 35 | def read_testset(testfile): 36 | """ 37 | 年龄段:0代表青年,1代表中年,2代表老年; 38 | 有工作:0代表否,1代表是; 39 | 有自己的房子:0代表否,1代表是; 40 | 信贷情况:0代表一般,1代表好,2代表非常好; 41 | 类别(是否给贷款):0代表否,1代表是 42 | """ 43 | fr = open(testfile, 'r') 44 | all_lines = fr.readlines() 45 | testset = [] 46 | for line in all_lines[0:]: 47 | line = line.strip().split(',') # 以逗号为分割符拆分列表 48 | testset.append(line) 49 | return testset 50 | 51 | 52 | # 计算信息熵 53 | def cal_entropy(dataset): 54 | numEntries = len(dataset) 55 | labelCounts = {} 56 | # 给所有可能分类创建字典 57 | for featVec in dataset: 58 | currentlabel = featVec[-1] 59 | if currentlabel not in labelCounts.keys(): 60 | labelCounts[currentlabel] = 0 61 | labelCounts[currentlabel] += 1 62 | Ent = 0.0 63 | for key in labelCounts: 64 | p = float(labelCounts[key]) / numEntries 65 | Ent = Ent - p * log(p, 2) # 以2为底求对数 66 | return Ent 67 | 68 | 69 | # 划分数据集 70 | def splitdataset(dataset, axis, value): 71 | retdataset = [] # 创建返回的数据集列表 72 | for featVec in dataset: # 抽取符合划分特征的值 73 | if featVec[axis] == value: 74 | reducedfeatVec = featVec[:axis] # 去掉axis特征 75 | reducedfeatVec.extend(featVec[axis + 1:]) # 将符合条件的特征添加到返回的数据集列表 76 | retdataset.append(reducedfeatVec) 77 | return retdataset 78 | 79 | 80 | ''' 81 | 选择最好的数据集划分方式 82 | ID3算法:以信息增益为准则选择划分属性 83 | C4.5算法:使用“增益率”来选择划分属性 84 | ''' 85 | 86 | 87 | # ID3算法 88 | def ID3_chooseBestFeatureToSplit(dataset): 89 | numFeatures = len(dataset[0]) - 1 90 | baseEnt = cal_entropy(dataset) 91 | bestInfoGain = 0.0 92 | bestFeature = -1 93 | for i in range(numFeatures): # 遍历所有特征 94 | # for example in dataset: 95 | # featList=example[i] 96 | featList = [example[i] for example in dataset] 97 | uniqueVals = set(featList) # 将特征列表创建成为set集合,元素不可重复。创建唯一的分类标签列表 98 | newEnt = 0.0 99 | for value in uniqueVals: # 计算每种划分方式的信息熵 100 | subdataset = splitdataset(dataset, i, value) 101 | p = len(subdataset) / float(len(dataset)) 102 | newEnt += p * cal_entropy(subdataset) 103 | infoGain = baseEnt - newEnt 104 | print(u"ID3中第%d个特征的信息增益为:%.3f" % (i, infoGain)) 105 | if (infoGain > bestInfoGain): 106 | bestInfoGain = infoGain # 计算最好的信息增益 107 | bestFeature = i 108 | return bestFeature 109 | 110 | 111 | # C4.5算法 112 | def C45_chooseBestFeatureToSplit(dataset): 113 | numFeatures = len(dataset[0]) - 1 114 | baseEnt = cal_entropy(dataset) 115 | bestInfoGain_ratio = 0.0 116 | bestFeature = -1 117 | for i in range(numFeatures): # 遍历所有特征 118 | featList = [example[i] for example in dataset] 119 | uniqueVals = set(featList) # 将特征列表创建成为set集合,元素不可重复。创建唯一的分类标签列表 120 | newEnt = 0.0 121 | IV = 0.0 122 | for value in uniqueVals: # 计算每种划分方式的信息熵 123 | subdataset = splitdataset(dataset, i, value) 124 | p = len(subdataset) / float(len(dataset)) 125 | newEnt += p * cal_entropy(subdataset) 126 | IV = IV - p * log(p, 2) 127 | infoGain = baseEnt - newEnt 128 | if (IV == 0): # fix the overflow bug 129 | continue 130 | infoGain_ratio = infoGain / IV # 这个feature的infoGain_ratio 131 | print(u"C4.5中第%d个特征的信息增益率为:%.3f" % (i, infoGain_ratio)) 132 | if (infoGain_ratio > bestInfoGain_ratio): # 选择最大的gain ratio 133 | bestInfoGain_ratio = infoGain_ratio 134 | bestFeature = i # 选择最大的gain ratio对应的feature 135 | return bestFeature 136 | 137 | 138 | # CART算法 139 | def CART_chooseBestFeatureToSplit(dataset): 140 | numFeatures = len(dataset[0]) - 1 141 | bestGini = 999999.0 142 | bestFeature = -1 143 | for i in range(numFeatures): 144 | featList = [example[i] for example in dataset] 145 | uniqueVals = set(featList) 146 | gini = 0.0 147 | for value in uniqueVals: 148 | subdataset = splitdataset(dataset, i, value) 149 | p = len(subdataset) / float(len(dataset)) 150 | subp = len(splitdataset(subdataset, -1, '0')) / float(len(subdataset)) 151 | gini += p * (1.0 - pow(subp, 2) - pow(1 - subp, 2)) 152 | print(u"CART中第%d个特征的基尼值为:%.3f" % (i, gini)) 153 | if (gini < bestGini): 154 | bestGini = gini 155 | bestFeature = i 156 | return bestFeature 157 | 158 | 159 | def majorityCnt(classList): 160 | ''' 161 | 数据集已经处理了所有属性,但是类标签依然不是唯一的, 162 | 此时我们需要决定如何定义该叶子节点,在这种情况下,我们通常会采用多数表决的方法决定该叶子节点的分类 163 | ''' 164 | classCont = {} 165 | for vote in classList: 166 | if vote not in classCont.keys(): 167 | classCont[vote] = 0 168 | classCont[vote] += 1 169 | sortedClassCont = sorted(classCont.items(), key=operator.itemgetter(1), reverse=True) 170 | return sortedClassCont[0][0] 171 | 172 | 173 | # 利用ID3算法创建决策树 174 | def ID3_createTree(dataset, labels, test_dataset): 175 | classList = [example[-1] for example in dataset] 176 | if classList.count(classList[0]) == len(classList): 177 | # 类别完全相同,停止划分 178 | return classList[0] 179 | if len(dataset[0]) == 1: 180 | # 遍历完所有特征时返回出现次数最多的 181 | return majorityCnt(classList) 182 | bestFeat = ID3_chooseBestFeatureToSplit(dataset) 183 | bestFeatLabel = labels[bestFeat] 184 | print(u"此时最优索引为:" + (bestFeatLabel)) 185 | 186 | ID3Tree = {bestFeatLabel: {}} 187 | del (labels[bestFeat]) 188 | # 得到列表包括节点所有的属性值 189 | featValues = [example[bestFeat] for example in dataset] 190 | uniqueVals = set(featValues) 191 | 192 | if pre_pruning: 193 | ans = [] 194 | for index in range(len(test_dataset)): 195 | ans.append(test_dataset[index][-1]) 196 | result_counter = Counter() 197 | for vec in dataset: 198 | result_counter[vec[-1]] += 1 199 | leaf_output = result_counter.most_common(1)[0][0] 200 | root_acc = cal_acc(test_output=[leaf_output] * len(test_dataset), label=ans) 201 | outputs = [] 202 | ans = [] 203 | for value in uniqueVals: 204 | cut_testset = splitdataset(test_dataset, bestFeat, value) 205 | cut_dataset = splitdataset(dataset, bestFeat, value) 206 | for vec in cut_testset: 207 | ans.append(vec[-1]) 208 | result_counter = Counter() 209 | for vec in cut_dataset: 210 | result_counter[vec[-1]] += 1 211 | leaf_output = result_counter.most_common(1)[0][0] 212 | outputs += [leaf_output] * len(cut_testset) 213 | cut_acc = cal_acc(test_output=outputs, label=ans) 214 | 215 | if cut_acc <= root_acc: 216 | return leaf_output 217 | 218 | for value in uniqueVals: 219 | subLabels = labels[:] 220 | ID3Tree[bestFeatLabel][value] = ID3_createTree( 221 | splitdataset(dataset, bestFeat, value), 222 | subLabels, 223 | splitdataset(test_dataset, bestFeat, value)) 224 | 225 | if post_pruning: 226 | tree_output = classifytest(ID3Tree, 227 | featLabels=['年龄段', '有工作', '有自己的房子', '信贷情况'], 228 | testDataSet=test_dataset) 229 | ans = [] 230 | for vec in test_dataset: 231 | ans.append(vec[-1]) 232 | root_acc = cal_acc(tree_output, ans) 233 | result_counter = Counter() 234 | for vec in dataset: 235 | result_counter[vec[-1]] += 1 236 | leaf_output = result_counter.most_common(1)[0][0] 237 | cut_acc = cal_acc([leaf_output] * len(test_dataset), ans) 238 | 239 | if cut_acc >= root_acc: 240 | return leaf_output 241 | 242 | return ID3Tree 243 | 244 | 245 | def C45_createTree(dataset, labels, test_dataset): 246 | classList = [example[-1] for example in dataset] 247 | if classList.count(classList[0]) == len(classList): 248 | # 类别完全相同,停止划分 249 | return classList[0] 250 | if len(dataset[0]) == 1: 251 | # 遍历完所有特征时返回出现次数最多的 252 | return majorityCnt(classList) 253 | bestFeat = C45_chooseBestFeatureToSplit(dataset) 254 | bestFeatLabel = labels[bestFeat] 255 | print(u"此时最优索引为:" + (bestFeatLabel)) 256 | C45Tree = {bestFeatLabel: {}} 257 | del (labels[bestFeat]) 258 | # 得到列表包括节点所有的属性值 259 | featValues = [example[bestFeat] for example in dataset] 260 | uniqueVals = set(featValues) 261 | 262 | if pre_pruning: 263 | ans = [] 264 | for index in range(len(test_dataset)): 265 | ans.append(test_dataset[index][-1]) 266 | result_counter = Counter() 267 | for vec in dataset: 268 | result_counter[vec[-1]] += 1 269 | leaf_output = result_counter.most_common(1)[0][0] 270 | root_acc = cal_acc(test_output=[leaf_output] * len(test_dataset), label=ans) 271 | outputs = [] 272 | ans = [] 273 | for value in uniqueVals: 274 | cut_testset = splitdataset(test_dataset, bestFeat, value) 275 | cut_dataset = splitdataset(dataset, bestFeat, value) 276 | for vec in cut_testset: 277 | ans.append(vec[-1]) 278 | result_counter = Counter() 279 | for vec in cut_dataset: 280 | result_counter[vec[-1]] += 1 281 | leaf_output = result_counter.most_common(1)[0][0] 282 | outputs += [leaf_output] * len(cut_testset) 283 | cut_acc = cal_acc(test_output=outputs, label=ans) 284 | 285 | if cut_acc <= root_acc: 286 | return leaf_output 287 | 288 | for value in uniqueVals: 289 | subLabels = labels[:] 290 | C45Tree[bestFeatLabel][value] = C45_createTree( 291 | splitdataset(dataset, bestFeat, value), 292 | subLabels, 293 | splitdataset(test_dataset, bestFeat, value)) 294 | 295 | if post_pruning: 296 | tree_output = classifytest(C45Tree, 297 | featLabels=['年龄段', '有工作', '有自己的房子', '信贷情况'], 298 | testDataSet=test_dataset) 299 | ans = [] 300 | for vec in test_dataset: 301 | ans.append(vec[-1]) 302 | root_acc = cal_acc(tree_output, ans) 303 | result_counter = Counter() 304 | for vec in dataset: 305 | result_counter[vec[-1]] += 1 306 | leaf_output = result_counter.most_common(1)[0][0] 307 | cut_acc = cal_acc([leaf_output] * len(test_dataset), ans) 308 | 309 | if cut_acc >= root_acc: 310 | return leaf_output 311 | 312 | return C45Tree 313 | 314 | 315 | def CART_createTree(dataset, labels, test_dataset): 316 | classList = [example[-1] for example in dataset] 317 | if classList.count(classList[0]) == len(classList): 318 | # 类别完全相同,停止划分 319 | return classList[0] 320 | if len(dataset[0]) == 1: 321 | # 遍历完所有特征时返回出现次数最多的 322 | return majorityCnt(classList) 323 | bestFeat = CART_chooseBestFeatureToSplit(dataset) 324 | # print(u"此时最优索引为:"+str(bestFeat)) 325 | bestFeatLabel = labels[bestFeat] 326 | print(u"此时最优索引为:" + (bestFeatLabel)) 327 | CARTTree = {bestFeatLabel: {}} 328 | del (labels[bestFeat]) 329 | # 得到列表包括节点所有的属性值 330 | featValues = [example[bestFeat] for example in dataset] 331 | uniqueVals = set(featValues) 332 | 333 | if pre_pruning: 334 | ans = [] 335 | for index in range(len(test_dataset)): 336 | ans.append(test_dataset[index][-1]) 337 | result_counter = Counter() 338 | for vec in dataset: 339 | result_counter[vec[-1]] += 1 340 | leaf_output = result_counter.most_common(1)[0][0] 341 | root_acc = cal_acc(test_output=[leaf_output] * len(test_dataset), label=ans) 342 | outputs = [] 343 | ans = [] 344 | for value in uniqueVals: 345 | cut_testset = splitdataset(test_dataset, bestFeat, value) 346 | cut_dataset = splitdataset(dataset, bestFeat, value) 347 | for vec in cut_testset: 348 | ans.append(vec[-1]) 349 | result_counter = Counter() 350 | for vec in cut_dataset: 351 | result_counter[vec[-1]] += 1 352 | leaf_output = result_counter.most_common(1)[0][0] 353 | outputs += [leaf_output] * len(cut_testset) 354 | cut_acc = cal_acc(test_output=outputs, label=ans) 355 | 356 | if cut_acc <= root_acc: 357 | return leaf_output 358 | 359 | for value in uniqueVals: 360 | subLabels = labels[:] 361 | CARTTree[bestFeatLabel][value] = CART_createTree( 362 | splitdataset(dataset, bestFeat, value), 363 | subLabels, 364 | splitdataset(test_dataset, bestFeat, value)) 365 | 366 | if post_pruning: 367 | tree_output = classifytest(CARTTree, 368 | featLabels=['年龄段', '有工作', '有自己的房子', '信贷情况'], 369 | testDataSet=test_dataset) 370 | ans = [] 371 | for vec in test_dataset: 372 | ans.append(vec[-1]) 373 | root_acc = cal_acc(tree_output, ans) 374 | result_counter = Counter() 375 | for vec in dataset: 376 | result_counter[vec[-1]] += 1 377 | leaf_output = result_counter.most_common(1)[0][0] 378 | cut_acc = cal_acc([leaf_output] * len(test_dataset), ans) 379 | 380 | if cut_acc >= root_acc: 381 | return leaf_output 382 | 383 | return CARTTree 384 | 385 | 386 | def classify(inputTree, featLabels, testVec): 387 | """ 388 | 输入:决策树,分类标签,测试数据 389 | 输出:决策结果 390 | 描述:跑决策树 391 | """ 392 | firstStr = list(inputTree.keys())[0] 393 | secondDict = inputTree[firstStr] 394 | featIndex = featLabels.index(firstStr) 395 | classLabel = '0' 396 | for key in secondDict.keys(): 397 | if testVec[featIndex] == key: 398 | if type(secondDict[key]).__name__ == 'dict': 399 | classLabel = classify(secondDict[key], featLabels, testVec) 400 | else: 401 | classLabel = secondDict[key] 402 | return classLabel 403 | 404 | 405 | def classifytest(inputTree, featLabels, testDataSet): 406 | """ 407 | 输入:决策树,分类标签,测试数据集 408 | 输出:决策结果 409 | 描述:跑决策树 410 | """ 411 | classLabelAll = [] 412 | for testVec in testDataSet: 413 | classLabelAll.append(classify(inputTree, featLabels, testVec)) 414 | return classLabelAll 415 | 416 | 417 | def cal_acc(test_output, label): 418 | """ 419 | :param test_output: the output of testset 420 | :param label: the answer 421 | :return: the acc of 422 | """ 423 | assert len(test_output) == len(label) 424 | count = 0 425 | for index in range(len(test_output)): 426 | if test_output[index] == label[index]: 427 | count += 1 428 | 429 | return float(count / len(test_output)) 430 | 431 | 432 | if __name__ == '__main__': 433 | filename = 'dataset.txt' 434 | testfile = 'testset.txt' 435 | dataset, labels = read_dataset(filename) 436 | # dataset,features=createDataSet() 437 | print('dataset', dataset) 438 | print("---------------------------------------------") 439 | print(u"数据集长度", len(dataset)) 440 | print("Ent(D):", cal_entropy(dataset)) 441 | print("---------------------------------------------") 442 | 443 | print(u"以下为首次寻找最优索引:\n") 444 | print(u"ID3算法的最优特征索引为:" + str(ID3_chooseBestFeatureToSplit(dataset))) 445 | print("--------------------------------------------------") 446 | print(u"C4.5算法的最优特征索引为:" + str(C45_chooseBestFeatureToSplit(dataset))) 447 | print("--------------------------------------------------") 448 | print(u"CART算法的最优特征索引为:" + str(CART_chooseBestFeatureToSplit(dataset))) 449 | print(u"首次寻找最优索引结束!") 450 | print("---------------------------------------------") 451 | 452 | print(u"下面开始创建相应的决策树-------") 453 | 454 | while True: 455 | dec_tree = '1' 456 | # ID3决策树 457 | if dec_tree == '1': 458 | labels_tmp = labels[:] # 拷贝,createTree会改变labels 459 | ID3desicionTree = ID3_createTree(dataset, labels_tmp, test_dataset=read_testset(testfile)) 460 | print('ID3desicionTree:\n', ID3desicionTree) 461 | # treePlotter.createPlot(ID3desicionTree) 462 | treePlotter.ID3_Tree(ID3desicionTree) 463 | testSet = read_testset(testfile) 464 | print("下面为测试数据集结果:") 465 | print('ID3_TestSet_classifyResult:\n', classifytest(ID3desicionTree, labels, testSet)) 466 | print("---------------------------------------------") 467 | 468 | # C4.5决策树 469 | if dec_tree == '2': 470 | labels_tmp = labels[:] # 拷贝,createTree会改变labels 471 | C45desicionTree = C45_createTree(dataset, labels_tmp, test_dataset=read_testset(testfile)) 472 | print('C45desicionTree:\n', C45desicionTree) 473 | treePlotter.C45_Tree(C45desicionTree) 474 | testSet = read_testset(testfile) 475 | print("下面为测试数据集结果:") 476 | print('C4.5_TestSet_classifyResult:\n', classifytest(C45desicionTree, labels, testSet)) 477 | print("---------------------------------------------") 478 | 479 | # CART决策树 480 | if dec_tree == '3': 481 | labels_tmp = labels[:] # 拷贝,createTree会改变labels 482 | CARTdesicionTree = CART_createTree(dataset, labels_tmp, test_dataset=read_testset(testfile)) 483 | print('CARTdesicionTree:\n', CARTdesicionTree) 484 | treePlotter.CART_Tree(CARTdesicionTree) 485 | testSet = read_testset(testfile) 486 | print("下面为测试数据集结果:") 487 | print('CART_TestSet_classifyResult:\n', classifytest(CARTdesicionTree, labels, testSet)) 488 | 489 | break 490 | -------------------------------------------------------------------------------- /treePlotter.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import matplotlib.pyplot as plt 3 | from pylab import mpl 4 | mpl.rcParams['font.sans-serif'] = ['SimHei'] 5 | decisionNode = dict(boxstyle="sawtooth", fc="0.8") 6 | leafNode = dict(boxstyle="round4", fc="0.8") 7 | arrow_args = dict(arrowstyle="<-") 8 | 9 | def plotNode(nodeTxt, centerPt, parentPt, nodeType): 10 | createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', \ 11 | xytext=centerPt, textcoords='axes fraction', \ 12 | va="center", ha="center", bbox=nodeType, arrowprops=arrow_args) 13 | 14 | def getNumLeafs(myTree): 15 | numLeafs = 0 16 | firstStr = list(myTree.keys())[0] 17 | secondDict = myTree[firstStr] 18 | for key in secondDict.keys(): 19 | if type(secondDict[key]).__name__ == 'dict': 20 | numLeafs += getNumLeafs(secondDict[key]) 21 | else: 22 | numLeafs += 1 23 | return numLeafs 24 | 25 | def getTreeDepth(myTree): 26 | maxDepth = 0 27 | firstStr = list(myTree.keys())[0] 28 | secondDict = myTree[firstStr] 29 | for key in secondDict.keys(): 30 | if type(secondDict[key]).__name__ == 'dict': 31 | thisDepth = getTreeDepth(secondDict[key]) + 1 32 | else: 33 | thisDepth = 1 34 | if thisDepth > maxDepth: 35 | maxDepth = thisDepth 36 | return maxDepth 37 | 38 | def plotMidText(cntrPt, parentPt, txtString): 39 | xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0] 40 | yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1] 41 | createPlot.ax1.text(xMid, yMid, txtString) 42 | 43 | def plotTree(myTree, parentPt, nodeTxt): 44 | numLeafs = getNumLeafs(myTree) 45 | depth = getTreeDepth(myTree) 46 | firstStr = list(myTree.keys())[0] 47 | cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalw, plotTree.yOff) 48 | plotMidText(cntrPt, parentPt, nodeTxt) 49 | plotNode(firstStr, cntrPt, parentPt, decisionNode) 50 | secondDict = myTree[firstStr] 51 | plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD 52 | for key in secondDict.keys(): 53 | if type(secondDict[key]).__name__ == 'dict': 54 | plotTree(secondDict[key], cntrPt, str(key)) 55 | else: 56 | plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalw 57 | plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) 58 | plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) 59 | plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD 60 | 61 | def createPlot(inTree): 62 | fig = plt.figure(1, facecolor='white') 63 | fig.clf() 64 | axprops = dict(xticks=[], yticks=[]) 65 | createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) 66 | plotTree.totalw = float(getNumLeafs(inTree)) 67 | plotTree.totalD = float(getTreeDepth(inTree)) 68 | plotTree.xOff = -0.5 / plotTree.totalw 69 | plotTree.yOff = 1.0 70 | plotTree(inTree, (0.5, 1.0), '') 71 | #plt.show() 72 | #ID3决策树 73 | def ID3_Tree(inTree): 74 | fig = plt.figure(1, facecolor='white') 75 | fig.clf() 76 | axprops = dict(xticks=[], yticks=[]) 77 | createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) 78 | plotTree.totalw = float(getNumLeafs(inTree)) 79 | plotTree.totalD = float(getTreeDepth(inTree)) 80 | plotTree.xOff = -0.5 / plotTree.totalw 81 | plotTree.yOff = 1.0 82 | plotTree(inTree, (0.5, 1.0), '') 83 | plt.title("ID3决策树",fontsize=12,color='red') 84 | plt.show() 85 | 86 | #C4.5决策树 87 | def C45_Tree(inTree): 88 | fig = plt.figure(2, facecolor='white') 89 | fig.clf() 90 | axprops = dict(xticks=[], yticks=[]) 91 | createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) 92 | plotTree.totalw = float(getNumLeafs(inTree)) 93 | plotTree.totalD = float(getTreeDepth(inTree)) 94 | plotTree.xOff = -0.5 / plotTree.totalw 95 | plotTree.yOff = 1.0 96 | plotTree(inTree, (0.5, 1.0), '') 97 | plt.title("C4.5决策树",fontsize=12,color='red') 98 | plt.show() 99 | 100 | #CART决策树 101 | def CART_Tree(inTree): 102 | fig = plt.figure(3, facecolor='white') 103 | fig.clf() 104 | axprops = dict(xticks=[], yticks=[]) 105 | createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) 106 | plotTree.totalw = float(getNumLeafs(inTree)) 107 | plotTree.totalD = float(getTreeDepth(inTree)) 108 | plotTree.xOff = -0.5 / plotTree.totalw 109 | plotTree.yOff = 1.0 110 | plotTree(inTree, (0.5, 1.0), '') 111 | plt.title("CART决策树",fontsize=12,color='red') 112 | plt.show() 113 | 114 | -------------------------------------------------------------------------------- /数据表.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/数据表.jpg -------------------------------------------------------------------------------- /数据表.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Erikfather/Decision_tree-python/6d29eec25d083a9afec9ecb1f762770928f2287b/数据表.xlsx --------------------------------------------------------------------------------