├── DecisionTree_NoPruning.png
├── DecisionTree_Postpruning.png
├── DecisionTree_Prepruning.png
├── Decision_Tree_ID3.py
├── Decision_Tree_ID3_postpruning.py
├── Decision_Tree_ID3_prepruning.py
├── Decision_Tree_Visual.py
├── README.md
├── watermelon2.csv
├── watermelon2Training.csv
├── watermelon2Validation.csv
└── watermelon2_Decision_Tree_ID3.png

/DecisionTree_NoPruning.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JonathonYan1993/ML_DecisionTree_prepruning_postpruning/e0ae24ab1c3212059afd3d8456b443a8a4c51f85/DecisionTree_NoPruning.png
--------------------------------------------------------------------------------

/DecisionTree_Postpruning.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JonathonYan1993/ML_DecisionTree_prepruning_postpruning/e0ae24ab1c3212059afd3d8456b443a8a4c51f85/DecisionTree_Postpruning.png
--------------------------------------------------------------------------------

/DecisionTree_Prepruning.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JonathonYan1993/ML_DecisionTree_prepruning_postpruning/e0ae24ab1c3212059afd3d8456b443a8a4c51f85/DecisionTree_Prepruning.png
--------------------------------------------------------------------------------

/Decision_Tree_ID3.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Sat Jan 5 16:47:44 2019
This program implements simple decision-tree classification.
The tree is an ID3 decision tree: splitting attributes are chosen by the
information-gain criterion.
The program outputs a plot of the resulting tree.
Dataset: watermelon dataset 2.0.
Reference: Zhou Zhihua, "Machine Learning".
@author: yanji
"""

import pandas as pd
import math
import Decision_Tree_Visual

# Read watermelon dataset 2.0
watermelon = pd.read_csv('watermelon2.csv', encoding='gbk')

# Compute the information entropy of a dataset
def entropyCal(dfdata):
    dataSize = dfdata.shape[0]  # number of samples
    colSize = dfdata.shape[1]   # number of columns (attributes plus the final class column)
    typeCount = dict(dfdata.iloc[:, colSize-1].value_counts())  # count the samples in each class
    entropy = 0.0
    for key in typeCount:
        p = float(typeCount[key])/dataSize
        entropy = entropy - p*math.log(p, 2)  # logarithm to base 2
    return entropy

# Split off the subset whose value of a given attribute equals colValue
def splitDataset(dfdata, colName, colValue):
    # keep the rows whose value of colName equals colValue, then drop that column
    # (a boolean mask replaces the original row-by-row DataFrame.append,
    # which has been removed from pandas)
    restData = dfdata[dfdata[colName] == colValue].reset_index(drop=True)
    restData = restData.drop([colName], axis=1)
    return restData

# Choose the best splitting attribute for the current dataset
# ID3: the splitting attribute is the one with the largest information gain
def chooseBestFeatureToSplit(dfdata):
    dataSize = dfdata.shape[0]        # number of samples
    numFeature = dfdata.shape[1]-1    # number of attributes
    entropyBase = entropyCal(dfdata)  # entropy of the dataset before splitting
    infoGainMax = 0.0                 # initialize the maximum information gain
    bestFeature = ''                  # initialize the best splitting attribute
    for col in range(numFeature):
        featureValueCount = dict(dfdata.iloc[:, col].value_counts())  # values of this attribute and their counts
        entropyNew = 0.0
        for key, value in featureValueCount.items():
            # weighted sum of the entropies of the subsets produced by this split
            entropyNew += entropyCal(splitDataset(dfdata, dfdata.columns[col], key))*float(value/dataSize)
        infoGain = entropyBase - entropyNew  # information gain of this attribute
        if infoGain > infoGainMax:
            infoGainMax = infoGain
            bestFeature = dfdata.columns[col]  # keep track of the best attribute
    return bestFeature
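
# Illustrative sanity check, added for clarity; this helper is not part of
# the original script and is never called automatically. With 8 positive and
# 9 negative samples the entropy is -(8/17)log2(8/17)-(9/17)log2(9/17),
# about 0.998, matching Section 4.2.1 of the watermelon book, where '纹理'
# (texture) is also the attribute with the largest information gain.
def _sanityDemo():
    demo = pd.DataFrame({'class': ['yes']*8 + ['no']*9})
    print('entropy of an 8/9 split: %.3f' % entropyCal(demo))  # expected: 0.998
    print('best first split:', chooseBestFeatureToSplit(watermelon))  # expected: '纹理'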

# When no attributes are left to split on, or the samples all belong to one
# class, return the majority class by majority vote
def typeMajority(dfdata):
    # value_counts sorts by frequency, so the first key is the majority class
    typeCount = dict(dfdata.iloc[:, dfdata.shape[1]-1].value_counts())
    return list(typeCount.keys())[0]

# Build the decision tree
def creatDecisionTree(dfdata):
    # stop if the samples all belong to one class or no attributes are left to split on
    if (dfdata.shape[1] == 1 or len(dfdata.iloc[:, dfdata.shape[1]-1].unique()) == 1):
        return typeMajority(dfdata)
    bestFeature = chooseBestFeatureToSplit(dfdata)  # choose the best splitting attribute
    decisionTree = {bestFeature: {}}  # the tree is stored as a nested dict
    bestFeatureValueCount = dict(dfdata.loc[:, bestFeature].value_counts())  # all values of this attribute
    for key, value in bestFeatureValueCount.items():
        # grow the tree recursively, one branch per attribute value
        decisionTree[bestFeature][key] = creatDecisionTree(splitDataset(dfdata, bestFeature, key))
    return decisionTree

# Classify a new sample
def classify(inputTree, valSple):
    firstStr = list(inputTree.keys())[0]  # the attribute tested at the root of this (sub)tree
    secondDict = inputTree[firstStr]
    classLabel = None  # stays None if the sample's value matches no branch
    for key in secondDict.keys():
        if valSple[firstStr] == key:  # follow the branch that matches the sample's value
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], valSple)  # recurse into the subtree
            else:
                classLabel = secondDict[key]
    return classLabel

if __name__ == '__main__':
    # Build the decision tree for watermelon dataset 2.0
    watermelonDecisionTree = creatDecisionTree(watermelon)
    # Visualize it (the __main__ guard keeps the plot from appearing as a
    # side effect when the pruning scripts import this module)
    Decision_Tree_Visual.createTree(watermelonDecisionTree, 'ID3 decision tree - watermelon dataset 2.0')
--------------------------------------------------------------------------------

/Decision_Tree_ID3_postpruning.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Sat Jan 19 22:10:28 2019
This program implements post-pruning of a decision tree.
Dataset: Table 4.2 of the watermelon book, split into a training set and a
validation set.
Reference: Zhou Zhihua, "Machine Learning".
@author: yanji
"""

import Decision_Tree_ID3 as TreeID3
import Decision_Tree_Visual as TreeVisual
import pandas as pd

# Read the training set
watermelonTra = pd.read_csv('watermelon2Training.csv', encoding='gbk')
# Read the validation set
watermelonVal = pd.read_csv('watermelon2Validation.csv', encoding='gbk')

# Build the unpruned decision tree from the training set
treeOriginal = TreeID3.creatDecisionTree(watermelonTra)
# Visualize the unpruned tree
TreeVisual.createTree(treeOriginal, 'Unpruned decision tree')

# Compute the accuracy of a decision tree on the validation set
def valPrecision(thisTree, valdata):
    classTrue = list(valdata.iloc[:, -1])
    valNum = valdata.shape[0]
    if valNum == 0:
        return 0.0  # no validation samples reach this node
    classPred = []
    crtNum = 0  # number of correctly predicted samples
    for rowNum in range(valNum):
        # predict this sample's class; note it must come from valdata, not the
        # global watermelonVal, because pruning passes validation subsets here
        classSple = TreeID3.classify(thisTree, valdata.iloc[rowNum, :])
        classPred.append(classSple)
        if classTrue[rowNum] == classSple:  # check whether the prediction is correct
            crtNum += 1
    return crtNum/valNum  # classification accuracy
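
# Illustrative check, added for clarity; this helper is not part of the
# original script and is never called automatically. With the
# training/validation split of Table 4.2 in the watermelon book, the unpruned
# tree should score about 3/7 (roughly 0.429) on the validation set, which
# post-pruning then improves.
def _precisionDemo():
    print('unpruned tree, validation accuracy: %.3f' % valPrecision(treeOriginal, watermelonVal))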

# Post-prune the tree that has already been built
# Placing the pruning test before or after the recursive calls switches
# between top-down and bottom-up pruning
def createPostpruningTree(inputTree, dfdata, valdata):
    firstStr = list(inputTree.keys())[0]  # the attribute tested at this node
    secondDict = inputTree[firstStr]
    typedfdata = TreeID3.typeMajority(dfdata)  # majority class of the remaining training samples
    pruningTree = {firstStr: {}}   # the post-pruned tree being built
    contrastTree = {firstStr: {}}  # the "no split" tree for this node, used for comparison
    for key in secondDict.keys():
        # in the "no split" tree every branch predicts the majority class
        contrastTree[firstStr][key] = typedfdata
        # complete the pruned tree recursively
        if type(secondDict[key]).__name__ == 'dict':
            pruningTree[firstStr][key] = createPostpruningTree(secondDict[key], TreeID3.splitDataset(dfdata, firstStr, key), TreeID3.splitDataset(valdata, firstStr, key))
        else:
            pruningTree[firstStr][key] = secondDict[key]
    # compare validation accuracy with and without the split at this node
    precisionContrast = valPrecision(contrastTree, valdata)
    precisionPruning = valPrecision(pruningTree, valdata)
    # prune this node if doing so improves validation accuracy
    # the pruning test comes after the recursive calls, so pruning is bottom-up
    if precisionContrast > precisionPruning:
        #print(firstStr)
        #print(typedfdata)
        return typedfdata
    else:
        return pruningTree

# Build the post-pruned tree from the unpruned tree, the training set and the validation set
treePostpruning = createPostpruningTree(treeOriginal, watermelonTra, watermelonVal)
# Visualize the post-pruned tree
TreeVisual.createTree(treePostpruning, 'Post-pruned decision tree')
--------------------------------------------------------------------------------

/Decision_Tree_ID3_prepruning.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 17 20:43:08 2019
This program implements pre-pruning of a decision tree.
Dataset: Table 4.2 of the watermelon book, split into a training set and a
validation set.
Reference: Zhou Zhihua, "Machine Learning".
@author: yanji
"""
import Decision_Tree_ID3 as TreeID3
import Decision_Tree_Visual as TreeVisual
import pandas as pd

# Read the training set
watermelonTra = pd.read_csv('watermelon2Training.csv', encoding='gbk')
# Read the validation set
watermelonVal = pd.read_csv('watermelon2Validation.csv', encoding='gbk')

# Build the unpruned decision tree from the training set
treeOriginal = TreeID3.creatDecisionTree(watermelonTra)
# Visualize the unpruned tree
TreeVisual.createTree(treeOriginal, 'Unpruned decision tree')

# Decide whether splitting on an attribute raises validation accuracy,
# i.e. whether the number of correctly classified validation samples increases
def precisionRaiseJudge(chosenFeature, dfdata, valdata):
    # at this point dfdata and valdata are the subsets routed to this node
    # by the splits made higher up in the tree
    valNum = valdata.shape[0]
    typeColval = valdata.shape[1]-1
    # without the split: number of correct validation predictions
    numUnclassify = 0
    typedfdata = TreeID3.typeMajority(dfdata)
    for rowNum in range(valNum):
        if valdata.iloc[rowNum, typeColval] == typedfdata:
            numUnclassify += 1
    # with the split: number of correct validation predictions
    numclassify = 0
    chosenFeatureValueCount = dict(dfdata.loc[:, chosenFeature].value_counts())
    typedfdataClassify = {}
    for key in chosenFeatureValueCount.keys():
        # the class predicted for each value of the chosen attribute
        typedfdataClassify[key] = TreeID3.typeMajority(TreeID3.splitDataset(dfdata, chosenFeature, key))
    #print(typedfdataClassify)
    for rowNum in range(valNum):
        featureValue = valdata.iloc[rowNum][chosenFeature]
        typeValue = valdata.iloc[rowNum, typeColval]
        for key, value in typedfdataClassify.items():
            if (featureValue == key and typeValue == value):
                numclassify += 1
    # split only if it strictly increases the number of correct predictions
    if numclassify > numUnclassify:
        return True
    else:
        return False

# Build the pre-pruned decision tree
def createPrepruningTree(dfdata, valdata):
    # stop if the samples all belong to one class or no attributes are left to split on
    if (dfdata.shape[1] == 1 or len(dfdata.iloc[:, dfdata.shape[1]-1].unique()) == 1):
        return TreeID3.typeMajority(dfdata)
    bestFeature = TreeID3.chooseBestFeatureToSplit(dfdata)  # choose the best splitting attribute
    bestFeatureValueCount = dict(dfdata.loc[:, bestFeature].value_counts())  # all values of this attribute
    # decide whether this node should be split at all
    if not precisionRaiseJudge(bestFeature, dfdata, valdata):
        return TreeID3.typeMajority(dfdata)
    decisionTree = {bestFeature: {}}  # the tree is stored as a nested dict
    for key, value in bestFeatureValueCount.items():
        # grow the tree recursively, one branch per attribute value
        decisionTree[bestFeature][key] = createPrepruningTree(TreeID3.splitDataset(dfdata, bestFeature, key), TreeID3.splitDataset(valdata, bestFeature, key))
    return decisionTree

# Build the pre-pruned tree from the training set and the validation set
treePrepruning = createPrepruningTree(watermelonTra, watermelonVal)
# Visualize the pre-pruned tree
TreeVisual.createTree(treePrepruning, 'Pre-pruned decision tree')
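
# Illustrative comparison, added for clarity; this helper is not part of the
# original script and is never called automatically. Pre-pruning should leave
# far fewer leaves than the unpruned tree; with the Table 4.2 split, the
# textbook's pre-pruned tree keeps only the root split on '脐部' (umbilicus).
def _sizeDemo():
    # assumes treePrepruning is a dict, i.e. the root was actually split
    print('leaves before pruning:', TreeVisual.getNumLeafs(treeOriginal))
    print('leaves after pre-pruning:', TreeVisual.getNumLeafs(treePrepruning))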
--------------------------------------------------------------------------------

/Decision_Tree_Visual.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Mon Jan 14 18:49:24 2019
This program draws a decision tree stored as a nested dict, i.e. it
visualizes the tree
@author: yanji
"""
import matplotlib.pyplot as plt

# Display Chinese characters correctly (the node labels come from a gbk-encoded dataset)
plt.rcParams['font.sans-serif'] = ['SimHei']
# Display minus signs correctly
plt.rcParams['axes.unicode_minus'] = False
# Box styles for drawing nodes ("0.8" is a grayscale face color; the original
# "2" lies outside matplotlib's valid 0-1 grayscale range)
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
# Arrow style, see http://matplotlib.org/api/patches_api.html#matplotlib.patches.FancyArrowPatch
arrow_args = dict(arrowstyle="<-")

# Count the leaf nodes of the tree
def getNumLeafs(myTree):
    # initialize the number of leaves
    numLeafs = 0
    # list(myTree.keys())[0] is the first key, i.e. the attribute tested at the root
    firstStr = list(myTree.keys())[0]
    # the value under that key holds the branches
    secondDict = myTree[firstStr]
    # walk over every branch
    for key in secondDict.keys():
        # a branch leading to another dict is a subtree: recurse to find its leaves
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        # otherwise the branch ends in a leaf: count it
        else:
            numLeafs += 1
    # return the number of leaves
    return numLeafs

# Compute the depth of the tree
def getTreeDepth(myTree):
    # initialize the depth
    maxDepth = 0
    # the first key, i.e. the attribute tested at the root
    firstStr = list(myTree.keys())[0]
    # the value under that key holds the branches
    secondDict = myTree[firstStr]
    # walk over every branch
    for key in secondDict.keys():
        # a branch leading to another dict adds 1 plus the subtree's depth
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    # return the depth of the tree
    return maxDepth

# Draw one node and the arrow pointing at it
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # annotate adds the note nodeTxt anchored at the point xy
    # xy is the parent's position, xytext the center of the note
    # xycoords and textcoords choose the coordinate systems for xy and xytext
    # bbox sets the style of the box around the note, arrowprops the arrow style
    '''
    figure points: points measured from the lower-left corner of the figure
    figure pixels: pixels measured from the lower-left corner of the figure
    figure fraction: fractions of the figure, range ([0,1],[0,1]); (0,0) is
        the lower-left corner, (1,1) the upper-right corner
    axes points: points measured from the lower-left corner of the axes
    axes pixels: pixels measured from the lower-left corner of the axes
    axes fraction: like figure fraction, but relative to the axes rather than the figure
    '''
    createTree.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt,
                            textcoords='axes fraction', va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)

# Draw the text on the middle of a branch line (the attribute value)
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]  # x coordinate of the text
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]  # y coordinate of the text
    createTree.ax1.text(xMid, yMid, txtString)
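
# Worked example of the centering arithmetic used in plotTree below, added
# for clarity: with totalW = 4 leaves the leaf spacing is d = 1/4 and the x
# cursor starts at x0ff = -1/8, so the leaves land at 1/8, 3/8, 5/8, 7/8.
# A subtree holding all n = 4 leaves then places its root at
#   x0ff + (1+n)/2/totalW = -1/8 + 5/8 = 1/2,
# exactly the midpoint of its first and last leaves.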

# Draw the tree
def plotTree(myTree, parentPt, nodeTxt):
    # number of leaves of this (sub)tree
    numLeafs = getNumLeafs(myTree)
    # depth of this (sub)tree
    #depth = getTreeDepth(myTree)
    # the attribute tested at the root of this (sub)tree
    firstStr = list(myTree.keys())[0]
    # compute this root's position: it sits at the midpoint of the x range
    # covered by the subtree's leaves
    # with cursor x0 (= x0ff) and leaf spacing d = 1/totalW, a subtree with n
    # leaves puts its first leaf at x0+d and its last at x0+nd, so the root
    # goes to ((x0+d)+(x0+nd))/2
    cntrPt = (plotTree.x0ff + (1.0+float(numLeafs))/2.0/plotTree.totalW, plotTree.y0ff)
    #print(cntrPt)
    # draw the attribute value on the incoming branch
    plotMidText(cntrPt, parentPt, nodeTxt)
    # draw the node itself
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    # the branches under this node
    secondDict = myTree[firstStr]
    # move one level down in y, based on the total depth of the tree
    plotTree.y0ff = plotTree.y0ff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            # draw the subtree recursively
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # advance the x cursor; consecutive leaves are 1/plotTree.totalW apart
            plotTree.x0ff = plotTree.x0ff + 1.0/plotTree.totalW
            #print(plotTree.x0ff)
            # draw the leaf node
            plotNode(secondDict[key], (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode)
            # draw the attribute value on the branch
            plotMidText((plotTree.x0ff, plotTree.y0ff), cntrPt, str(key))
    # after the recursion returns, move back up one level in y
    plotTree.y0ff = plotTree.y0ff + 1.0/plotTree.totalD

# Draw a decision tree
def createTree(newTree, titleName):
    # create a figure with a white background
    fig = plt.figure(1, facecolor='white')
    # clear the figure
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    # create a 1x1 grid and hand its single Axes instance to createTree.ax1;
    # this attribute acts as a global that plotNode and plotMidText draw on
    createTree.ax1 = plt.subplot(111, frameon=False, **axprops)
    # total number of leaves
    plotTree.totalW = float(getNumLeafs(newTree))
    # total depth of the tree
    plotTree.totalD = float(getTreeDepth(newTree))
    # start the x cursor at -0.5/plotTree.totalW: the x axis has length 1, the
    # leaves are 1/plotTree.totalW apart, and this offset centers them
    plotTree.x0ff = -0.5/plotTree.totalW
    plotTree.y0ff = 1.0
    plotTree(newTree, (0.5, 1.0), '')
    plt.title(str(titleName), fontsize=14, color='red')
    plt.show()
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# ML_DecisionTree_prepruning_postpruning
This project builds and visualizes a decision tree with pandas and matplotlib,
and implements pre-pruning and post-pruning. The data are the watermelon
datasets from Sections 4.2 and 4.3 of Zhou Zhihua's "Machine Learning".
## 1. Repo structure
* Decision_Tree_ID3.py - builds the decision tree
* Decision_Tree_ID3_postpruning.py - post-pruning of the tree
* Decision_Tree_ID3_prepruning.py - pre-pruning of the tree
* Decision_Tree_Visual.py - tree visualization
## 2. Datasets
* watermelon2.csv
* watermelon2Training.csv
* watermelon2Validation.csv
## 3. Result plots
* watermelon2_Decision_Tree_ID3.png, etc.
--------------------------------------------------------------------------------

/watermelon2.csv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JonathonYan1993/ML_DecisionTree_prepruning_postpruning/e0ae24ab1c3212059afd3d8456b443a8a4c51f85/watermelon2.csv
--------------------------------------------------------------------------------

/watermelon2Training.csv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JonathonYan1993/ML_DecisionTree_prepruning_postpruning/e0ae24ab1c3212059afd3d8456b443a8a4c51f85/watermelon2Training.csv
--------------------------------------------------------------------------------

/watermelon2Validation.csv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JonathonYan1993/ML_DecisionTree_prepruning_postpruning/e0ae24ab1c3212059afd3d8456b443a8a4c51f85/watermelon2Validation.csv
--------------------------------------------------------------------------------

/watermelon2_Decision_Tree_ID3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JonathonYan1993/ML_DecisionTree_prepruning_postpruning/e0ae24ab1c3212059afd3d8456b443a8a4c51f85/watermelon2_Decision_Tree_ID3.png
--------------------------------------------------------------------------------