└── Salary_DecisionTree.ipynb /Salary_DecisionTree.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 134, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "[['A1' 'A3' 'B7' 'B6' 'A2' 4000]\n", 13 | " ['B1' 'A3' 'B7' 'D6' 'B2' 3000]\n", 14 | " ['B1' 'A3' 'B7' 'A6' 'A2' 6000]\n", 15 | " ...\n", 16 | " ['B1' 'B3' 'D7' 'A6' 'A2' 10000]\n", 17 | " ['C1' 'C3' 'A7' 'C6' 'A2' 3000]\n", 18 | " ['D1' 'B3' 'B7' 'A6' 'A2' 17000]]\n" 19 | ] 20 | } 21 | ], 22 | "source": [ 23 | "#数据装载\n", 24 | "import csv\n", 25 | "import numpy as np\n", 26 | "import pandas as pd\n", 27 | "\n", 28 | "df = pd.read_csv('train_data.csv')\n", 29 | "data = np.array(df)\n", 30 | "print data" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 135, 36 | "metadata": {}, 37 | "outputs": [ 38 | { 39 | "name": "stdout", 40 | "output_type": "stream", 41 | "text": [ 42 | "[['A1' 'A3' 'B7' 'B6' 'A2' 3000]\n", 43 | " ['B1' 'A3' 'B7' 'D6' 'B2' 3000]\n", 44 | " ['B1' 'A3' 'B7' 'A6' 'A2' 5000]\n", 45 | " ...\n", 46 | " ['B1' 'B3' 'D7' 'A6' 'A2' 10000]\n", 47 | " ['C1' 'C3' 'A7' 'C6' 'A2' 3000]\n", 48 | " ['D1' 'B3' 'B7' 'A6' 'A2' 10000]]\n" 49 | ] 50 | } 51 | ], 52 | "source": [ 53 | "#预测值分层\n", 54 | "for i in range(len(data)):\n", 55 | " if 1000<=data[i][5]<3000:\n", 56 | " data[i][5] = 1000\n", 57 | " if 3000<=data[i][5]<5000:\n", 58 | " data[i][5] = 3000\n", 59 | " if 5000<=data[i][5]<8000:\n", 60 | " data[i][5] = 5000\n", 61 | " if 8000<=data[i][5]<10000:\n", 62 | " data[i][5] = 8000\n", 63 | " if 10000<=data[i][5]<40000:\n", 64 | " data[i][5] = 10000\n", 65 | " if data[i][5]>=40000:\n", 66 | " data[i][5] = 40000\n", 67 | "\n", 68 | "print data" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 136, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "#划分数据集\n", 78 | "train_data = data[:35001]\n", 79 | "test_data = data[35001:]" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 137, 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "name": "stdout", 89 | "output_type": "stream", 90 | "text": [ 91 | "[['A1' 'A3' 'B7' 'B6' 'A2' 3000]\n", 92 | " ['B1' 'A3' 'B7' 'D6' 'B2' 3000]\n", 93 | " ['B1' 'A3' 'B7' 'A6' 'A2' 5000]\n", 94 | " ...\n", 95 | " ['D1' 'B3' 'A7' 'B6' 'A2' 5000]\n", 96 | " ['A1' 'C3' 'A7' 'A6' 'A2' 10000]\n", 97 | " ['A1' 'A3' 'B7' 'B6' 'A2' 5000]]\n", 98 | "['experience', 'education', 'major', 'location', 'company_type']\n" 99 | ] 100 | } 101 | ], 102 | "source": [ 103 | "#生成训练数据\n", 104 | "def createDataSet():\n", 105 | " dataSet = train_data\n", 106 | " atrributes = ['experience','education','major','location','company_type']\n", 107 | " return dataSet, atrributes\n", 108 | "\n", 109 | "dataSet, attributes = createDataSet()\n", 110 | "print dataSet\n", 111 | "print attributes" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "# C4.5" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 138, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "#计算熵值\n", 128 | "import math\n", 129 | "def calcEnt(dataSet):\n", 130 | " labels = set([dataVec[-1] for dataVec in dataSet])\n", 131 | " labelSet = {}\n", 132 | " for label in labels:\n", 133 | " labelSet[label] = 0\n", 134 | " #样本总个数\n", 135 | " totalNum = len(dataSet)\n", 136 | " #类别集合\n", 137 | " #计算每个类别的样本个数\n", 138 | " for dataVec in dataSet:\n", 139 | " label = dataVec[-1]\n", 140 | " if label not in labelSet.keys():\n", 141 | " labelSet[label] = 0\n", 142 | " labelSet[label] += 1\n", 143 | " Ent = 0\n", 144 | " #计算熵值\n", 145 | " for key in labelSet:\n", 146 | " pi = float(labelSet[key])/totalNum\n", 147 | " Ent -= pi*math.log(pi,2)\n", 148 | " #print labelSet\n", 149 | " return Ent" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 139, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "#按给定特征划分数据集:返回第featNum个特征其值为value的样本集合,且返回的样本数据中已经去除该特征\n", 159 | "def splitDataSet(dataSet, featNum, featvalue):\n", 160 | " retDataSet = []\n", 161 | " for dataVec in dataSet:\n", 162 | " if dataVec[featNum] == featvalue:\n", 163 | " splitData = dataVec[:featNum]\n", 164 | " splitData = np.append(splitData,dataVec[featNum+1:])\n", 165 | " retDataSet.append(splitData)\n", 166 | " return np.array(retDataSet)" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 140, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "#选择最好的特征划分数据集\n", 176 | "def chooseBestFeatToSplit(dataSet, labelSet):\n", 177 | " featNum = len(dataSet[0]) - 1\n", 178 | " maxInfoGain = 0\n", 179 | " bestFeat = -1\n", 180 | " #计算样本熵值,对应公式中:H(X)\n", 181 | " baseEnt = calcEnt(dataSet)\n", 182 | " #以每一个特征进行分类,找出使信息增益最大的特征\n", 183 | " for i in range(featNum):\n", 184 | " featList = [dataVec[i] for dataVec in dataSet]\n", 185 | " featList = set(featList)\n", 186 | " #print labelSet[i],\"的所有属性值\",featList\n", 187 | " newEnt = 0\n", 188 | " #计算以第i个特征进行分类后的熵值,对应公式中:H(X|Y)\n", 189 | " for featValue in featList:\n", 190 | " subDataSet = splitDataSet(dataSet, i, featValue)\n", 191 | " prob = len(subDataSet)/float(len(dataSet))\n", 192 | " newEnt += prob*calcEnt(subDataSet)\n", 193 | " #ID3算法:计算信息增益,对应公式中:g(X,Y)=H(X)-H(X|Y)\n", 194 | " #infoGain = baseEnt - newEnt\n", 195 | " #C4.5算法:计算信息增益比\n", 196 | " selfClass = np.array([dataVec[i] for dataVec in dataSet]).reshape(-1,1)\n", 197 | " if calcEnt(selfClass) == 0:\n", 198 | " return i\n", 199 | " infoGain = (baseEnt - newEnt)/calcEnt(selfClass)\n", 200 | " #print labelSet[i],\"的信息增益率\", infoGain\n", 201 | " #找出最大的熵值以及其对应的特征\n", 202 | " if infoGain > maxInfoGain:\n", 203 | " maxInfoGain = infoGain\n", 204 | " bestFeat = i\n", 205 | " return bestFeat" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": 141, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "#如果决策树递归生成完毕,且叶子节点中样本不是属于同一类,则以少数服从多数原则确定该叶子节点类别\n", 215 | "def majorityCnt(labelList):\n", 216 | " labels = set(labelList)\n", 217 | " labelSet = {}\n", 218 | " for label in labels:\n", 219 | " labelSet[label] = 0\n", 220 | " #统计每个类别的样本个数\n", 221 | " for label in labelList:\n", 222 | " if label not in labelSet.keys():\n", 223 | " labelSet[label] = 0\n", 224 | " labelSet[label] += 1\n", 225 | " #iteritems:返回列表迭代器\n", 226 | " #operator.itemgeter(1):获取对象第一个域的值\n", 227 | " #True:降序\n", 228 | " sortedLabelSet = sorted(labelSet.items(), key=operator.itemgetter(1), reverse=True)\n", 229 | " return sortedLabelSet[0][0]" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 142, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "import operator\n", 239 | "\n", 240 | "#创建决策树\n", 241 | "def createDecideTree(dataSet, featName):\n", 242 | " #数据集的分类类别\n", 243 | " classList = [dataVec[-1] for dataVec in dataSet]\n", 244 | " #所有样本属于同一类时,停止划分,返回该类别\n", 245 | " if len(classList) == classList.count(classList[0]):\n", 246 | " return classList[0]\n", 247 | " #所有特征已经遍历完,停止划分,返回样本数最多的类别\n", 248 | " if len(dataSet[0]) == 1:\n", 249 | " return majorityCnt(classList)\n", 250 | " #选择最好的特征进行划分\n", 251 | " bestFeat = chooseBestFeatToSplit(dataSet, featName)\n", 252 | " bestFeatName = featName[bestFeat]\n", 253 | " #print \"当前最佳属性是\", bestFeatName\n", 254 | " del featName[bestFeat]\n", 255 | " #以字典形式表示树\n", 256 | " DTree = {bestFeatName:{}}\n", 257 | " #根据选择的特征,遍历该特征的所有属性值,在每个划分子集上递归调用createDecideTree\n", 258 | " featValue = [dataVec[bestFeat] for dataVec in dataSet]\n", 259 | " featValue = set(featValue)\n", 260 | " #print \"剩余候选属性是\", featName\n", 261 | " for value in featValue:\n", 262 | " subFeatName = featName[:]\n", 263 | " DTree[bestFeatName][value] = createDecideTree(splitDataSet(dataSet,bestFeat,value), subFeatName)\n", 264 | " return DTree" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 190, 270 | "metadata": {}, 271 | "outputs": [ 272 | { 273 | "name": "stdout", 274 | "output_type": "stream", 275 | "text": [ 276 | "{'experience': {'A1': {'location': {'D6': {'company_type': {'C2': 3000, 'A2': {'education': {'C3': {'major': {'F7': 3000, 'B7': 3000, 'A7': 3000, 'D7': 3000, 'E7': 1000, 'C7': 3000}}, 'A3': {'major': {'F7': 3000, 'B7': 3000, 'A7': 3000, 'D7': 3000, 'E7': 3000, 'C7': 3000}}, 'E3': {'major': {'D7': 8000, 'B7': 8000}}, 'B3': {'major': {'F7': 3000, 'B7': 3000, 'A7': 3000, 'D7': 3000, 'E7': 3000, 'C7': 3000}}}}, 'B2': {'major': {'E7': {'education': {'C3': 3000, 'B3': 3000}}, 'B7': {'education': {'C3': 3000, 'A3': 8000, 'B3': 10000}}, 'C7': {'education': {'C3': 8000, 'A3': 1000, 'B3': 3000}}, 'A7': {'education': {'C3': 3000, 'A3': 1000, 'B3': 3000}}, 'F7': 5000}}}}, 'B6': {'education': {'C3': {'company_type': {'C2': 1000, 'A2': {'major': {'F7': 3000, 'B7': 3000, 'A7': 3000, 'D7': 3000, 'E7': 3000, 'C7': 3000}}, 'D2': 3000, 'B2': {'major': {'D7': 1000, 'B7': 3000, 'C7': 1000, 'A7': 3000}}}}, 'A3': {'company_type': {'A2': {'major': {'F7': 3000, 'B7': 3000, 'A7': 3000, 'D7': 3000, 'E7': 3000, 'C7': 3000}}, 'D2': 5000, 'B2': {'major': {'D7': 3000, 'E7': 3000, 'B7': 3000, 'C7': 3000, 'A7': 3000}}}}, 'D3': 10000, 'E3': {'major': {'F7': {'company_type': {'A2': 3000}}, 'B7': {'company_type': {'A2': 8000}}, 'A7': {'company_type': {'A2': 5000}}, 'D7': {'company_type': {'C2': 3000, 'A2': 5000}}, 'E7': {'company_type': {'A2': 5000, 'B2': 5000}}, 'C7': {'company_type': {'A2': 5000}}}}, 'B3': {'major': {'F7': {'company_type': {'A2': 3000, 'B2': 3000}}, 'B7': {'company_type': {'C2': 1000, 'A2': 3000, 'B2': 3000}}, 'A7': {'company_type': {'C2': 3000, 'A2': 3000, 'B2': 3000}}, 'D7': {'company_type': {'A2': 3000, 'D2': 3000}}, 'E7': {'company_type': {'A2': 3000, 'B2': 3000}}, 'C7': {'company_type': {'A2': 3000, 'B2': 5000}}}}}}, 'E6': {'company_type': {'C2': 3000, 'A2': {'major': {'F7': {'education': {'C3': 3000, 'A3': 3000, 'D3': 10000, 'E3': 10000, 'B3': 3000}}, 'B7': {'education': {'C3': 3000, 'A3': 3000, 'B3': 3000}}, 'A7': {'education': {'C3': 3000, 'A3': 3000, 'B3': 3000}}, 'D7': {'education': {'C3': 3000, 'A3': 5000, 'E3': 5000, 'B3': 3000}}, 'E7': {'education': {'C3': 1000, 'A3': 1000, 'B3': 1000}}, 'C7': {'education': {'C3': 3000, 'A3': 3000, 'B3': 1000}}}}, 'B2': {'education': {'A3': {'major': {'C7': 5000, 'A7': 3000}}, 'B3': {'major': {'B7': 10000, 'C7': 3000, 'A7': 10000}}}}}}, 'C6': {'education': {'C3': {'company_type': {'C2': 8000, 'A2': {'major': {'F7': 3000, 'B7': 3000, 'A7': 3000, 'D7': 3000, 'E7': 3000, 'C7': 3000}}, 'B2': {'major': {'A7': 5000}}}}, 'A3': {'company_type': {'C2': {'major': {'D7': 1000, 'B7': 3000, 'A7': 5000}}, 'A2': {'major': {'F7': 3000, 'B7': 3000, 'A7': 3000, 'D7': 3000, 'E7': 3000, 'C7': 3000}}, 'B2': {'major': {'D7': 1000, 'E7': 3000, 'B7': 3000, 'C7': 3000, 'A7': 3000}}}}, 'D3': 10000, 'E3': {'company_type': {'A2': {'major': {'F7': 10000, 'B7': 5000, 'A7': 5000, 'D7': 3000, 'E7': 8000, 'C7': 5000}}, 'B2': {'major': {'A7': 1000, 'F7': 3000}}}}, 'B3': {'company_type': {'C2': {'major': {'A7': 1000}}, 'A2': {'major': {'F7': 3000, 'B7': 3000, 'A7': 3000, 'D7': 3000, 'E7': 3000, 'C7': 3000}}, 'B2': {'major': {'B7': 3000, 'C7': 3000, 'A7': 3000, 'F7': 1000}}}}}}, 'A6': {'education': {'C3': {'company_type': {'A2': {'major': {'F7': 3000, 'B7': 3000, 'A7': 3000, 'D7': 3000, 'E7': 3000, 'C7': 3000}}, 'D2': 3000, 'B2': {'major': {'D7': 1000, 'B7': 3000, 'C7': 1000, 'A7': 3000}}}}, 'A3': {'company_type': {'C2': {'major': {'F7': 5000, 'B7': 3000, 'A7': 1000, 'D7': 3000, 'E7': 5000, 'C7': 5000}}, 'A2': {'major': {'F7': 3000, 'B7': 3000, 'A7': 3000, 'D7': 3000, 'E7': 3000, 'C7': 3000}}, 'D2': 1000, 'B2': {'major': {'D7': 1000, 'E7': 3000, 'B7': 3000, 'C7': 3000, 'A7': 3000}}}}, 'D3': {'company_type': {'C2': 8000, 'A2': {'major': {'D7': 10000, 'B7': 10000, 'F7': 10000}}}}, 'E3': {'company_type': {'C2': {'major': {'B7': 8000, 'A7': 5000}}, 'A2': {'major': {'F7': 1000, 'B7': 10000, 'A7': 5000, 'D7': 5000, 'E7': 1000, 'C7': 5000}}, 'B2': {'major': {'B7': 5000, 'C7': 5000, 'A7': 5000}}}}, 'B3': {'major': {'F7': {'company_type': {'A2': 3000}}, 'B7': {'company_type': {'C2': 5000, 'A2': 3000, 'B2': 5000}}, 'A7': {'company_type': {'C2': 5000, 'A2': 3000, 'B2': 3000}}, 'D7': {'company_type': {'A2': 5000, 'D2': 3000}}, 'E7': {'company_type': {'C2': 3000, 'A2': 3000, 'B2': 3000}}, 'C7': {'company_type': {'C2': 1000, 'A2': 5000, 'B2': 3000}}}}}}}}, 'C1': {'education': {'C3': {'company_type': {'A2': {'location': {'D6': {'major': {'F7': 3000, 'B7': 3000, 'A7': 3000, 'D7': 3000, 'E7': 3000, 'C7': 3000}}, 'B6': {'major': {'F7': 8000, 'B7': 3000, 'A7': 3000, 'D7': 5000, 'E7': 10000, 'C7': 3000}}, 'E6': {'major': {'D7': 5000, 'E7': 3000, 'B7': 3000, 'C7': 8000, 'A7': 3000}}, 'C6': {'major': {'D7': 5000, 'E7': 5000, 'B7': 3000, 'C7': 3000, 'A7': 3000}}, 'A6': {'major': {'F7': 8000, 'B7': 5000, 'A7': 3000, 'D7': 3000, 'E7': 10000, 'C7': 5000}}}}, 'B2': {'major': {'E7': 5000, 'B7': {'location': {'E6': 1000, 'C6': 5000}}, 'A7': {'location': {'D6': 3000, 'B6': 1000, 'E6': 3000, 'C6': 1000, 'A6': 3000}}}}}}, 'A3': {'location': {'D6': {'company_type': {'A2': {'major': {'F7': 3000, 'B7': 5000, 'A7': 3000, 'D7': 3000, 'E7': 3000, 'C7': 5000}}, 'D2': 3000, 'B2': {'major': {'E7': 10000, 'B7': 3000, 'C7': 3000, 'A7': 10000, 'F7': 3000}}}}, 'B6': {'company_type': {'C2': {'major': {'B7': 10000, 'A7': 5000}}, 'A2': {'major': {'F7': 5000, 'B7': 5000, 'A7': 3000, 'D7': 5000, 'E7': 3000, 'C7': 5000}}, 'D2': {'major': {'D7': 1000, 'B7': 3000}}, 'B2': {'major': {'D7': 10000, 'E7': 3000, 'B7': 5000, 'C7': 8000, 'A7': 3000}}}}, 'E6': {'major': {'F7': {'company_type': {'A2': 3000}}, 'B7': {'company_type': {'A2': 3000, 'B2': 3000}}, 'A7': {'company_type': {'A2': 3000, 'D2': 3000, 'B2': 10000}}, 'D7': {'company_type': {'A2': 8000}}, 'E7': {'company_type': {'A2': 8000, 'B2': 10000}}, 'C7': {'company_type': {'A2': 5000, 'B2': 5000}}}}, 'C6': {'company_type': {'A2': {'major': {'F7': 5000, 'B7': 3000, 'A7': 3000, 'D7': 3000, 'E7': 3000, 'C7': 5000}}, 'B2': {'major': {'D7': 3000, 'E7': 8000, 'B7': 10000, 'C7': 8000, 'A7': 3000}}}}, 'A6': {'major': {'F7': {'company_type': {'A2': 5000}}, 'B7': {'company_type': {'A2': 10000, 'B2': 5000}}, 'A7': {'company_type': {'A2': 5000, 'B2': 5000}}, 'D7': {'company_type': {'A2': 5000}}, 'E7': {'company_type': {'A2': 5000, 'B2': 10000}}, 'C7': {'company_type': {'A2': 5000, 'B2': 5000}}}}}}, 'D3': 10000, 'E3': {'location': {'D6': {'company_type': {'A2': {'major': {'D7': 8000, 'C7': 5000, 'A7': 8000}}, 'B2': 10000}}, 'B6': {'major': {'D7': {'company_type': {'A2': 5000}}, 'B7': 10000, 'C7': {'company_type': {'A2': 10000}}, 'A7': {'company_type': {'A2': 10000}}, 'F7': 10000}}, 'C6': {'major': {'D7': {'company_type': {'A2': 10000}}, 'B7': 10000, 'C7': 5000, 'A7': {'company_type': {'A2': 10000}}, 'F7': 10000}}, 'A6': {'major': {'F7': 10000, 'B7': {'company_type': {'A2': 10000, 'B2': 10000}}, 'A7': {'company_type': {'A2': 10000, 'B2': 10000}}, 'D7': {'company_type': {'A2': 10000, 'B2': 10000}}, 'E7': {'company_type': {'A2': 10000, 'B2': 10000}}, 'C7': {'company_type': {'A2': 10000}}}}}}, 'B3': {'location': {'D6': {'major': {'F7': {'company_type': {'A2': 5000, 'B2': 5000}}, 'B7': {'company_type': {'A2': 10000, 'B2': 10000}}, 'A7': {'company_type': {'A2': 3000, 'B2': 10000}}, 'D7': {'company_type': {'A2': 3000}}, 'E7': {'company_type': {'A2': 8000}}, 'C7': {'company_type': {'A2': 10000, 'B2': 5000}}}}, 'B6': {'company_type': {'C2': {'major': {'D7': 10000, 'C7': 10000}}, 'A2': {'major': {'F7': 5000, 'B7': 10000, 'A7': 10000, 'D7': 5000, 'E7': 5000, 'C7': 10000}}, 'D2': 5000, 'B2': {'major': {'E7': 10000, 'B7': 5000, 'C7': 10000, 'A7': 5000, 'F7': 3000}}}}, 'E6': {'company_type': {'C2': 5000, 'A2': {'major': {'F7': 5000, 'B7': 10000, 'A7': 5000, 'D7': 10000, 'E7': 5000, 'C7': 10000}}, 'B2': {'major': {'B7': 3000, 'C7': 10000, 'A7': 3000}}}}, 'C6': {'company_type': {'C2': 5000, 'A2': {'major': {'F7': 8000, 'B7': 5000, 'A7': 3000, 'D7': 10000, 'E7': 8000, 'C7': 10000}}, 'B2': {'major': {'E7': 5000, 'B7': 5000, 'C7': 3000, 'A7': 3000, 'F7': 10000}}}}, 'A6': {'company_type': {'C2': {'major': {'E7': 8000, 'B7': 10000, 'A7': 3000}}, 'A2': {'major': {'F7': 10000, 'B7': 10000, 'A7': 10000, 'D7': 10000, 'E7': 10000, 'C7': 10000}}, 'D2': 3000, 'B2': {'major': {'E7': 10000, 'B7': 10000, 'C7': 5000, 'A7': 10000, 'F7': 8000}}}}}}}}, 'E1': {'education': {'C3': {'major': {'D7': 5000, 'B7': {'company_type': {'A2': {'location': {'D6': 5000, 'B6': 8000, 'A6': 3000}}}}, 'C7': {'company_type': {'A2': {'location': {5000: 5000, 10000: 10000}}}}, 'A7': {'location': {'D6': 3000, 'B6': {'company_type': {'A2': 10000}}, 'A6': {'company_type': {'A2': 40000, 'B2': 10000}}}}}}, 'A3': {'company_type': {'A2': {'location': {'D6': {'major': {'E7': 10000, 'B7': 10000, 'C7': 10000, 'A7': 10000}}, 'B6': {'major': {'F7': 40000, 'B7': 10000, 'A7': 10000, 'D7': 40000, 'E7': 3000, 'C7': 10000}}, 'E6': {'major': {'B7': 10000, 'C7': 10000, 'A7': 10000}}, 'C6': {'major': {'D7': 10000, 'E7': 10000, 'B7': 10000, 'C7': 10000, 'A7': 10000}}, 'A6': {'major': {'D7': 10000, 'E7': 5000, 'B7': 10000, 'C7': 10000, 'A7': 10000}}}}, 'D2': 3000, 'B2': {'location': {'D6': {'major': {'C7': 10000, 'A7': 3000}}, 'B6': 3000, 'C6': 10000, 'A6': {'major': {'C7': 10000, 'A7': 5000}}}}}}, 'D3': {'company_type': {'A2': {'major': {'D7': 10000, 'B7': 40000, 'C7': 10000}}}}, 'E3': {'company_type': {'A2': {'location': {'D6': 40000, 'B6': {'major': {'E7': 10000, 'B7': 8000}}, 'A6': {'major': {'D7': 40000, 'B7': 10000, 'C7': 10000, 'A7': 10000}}}}}}, 'B3': {'location': {'D6': {'major': {'F7': 10000, 'B7': {'company_type': {'A2': 10000}}, 'A7': {'company_type': {'A2': 10000, 'B2': 10000}}, 'D7': {'company_type': {'A2': 40000}}, 'E7': 10000, 'C7': {'company_type': {'A2': 10000}}}}, 'B6': {'major': {'E7': {'company_type': {'A2': 8000, 'B2': 10000}}, 'B7': {'company_type': {'C2': 10000, 'A2': 10000, 'B2': 8000}}, 'C7': {'company_type': {'A2': 10000, 'B2': 40000}}, 'A7': {'company_type': {'C2': 10000, 'A2': 10000, 'B2': 10000}}, 'F7': 10000}}, 'E6': {'company_type': {'A2': {'major': {'D7': 10000, 'B7': 10000, 'C7': 10000, 'A7': 10000}}, 'B2': {'major': {'B7': 3000, 'A7': 5000}}}}, 'C6': {'major': {'F7': 10000, 'B7': {'company_type': {'A2': 10000}}, 'A7': {'company_type': {'A2': 10000, 'B2': 10000}}, 'D7': {'company_type': {'A2': 40000}}, 'E7': {'company_type': {'A2': 10000}}, 'C7': {'company_type': {'A2': 10000, 'B2': 10000}}}}, 'A6': {'company_type': {'A2': {'major': {'F7': 10000, 'B7': 10000, 'A7': 10000, 'D7': 10000, 'E7': 10000, 'C7': 10000}}, 'B2': {'major': {'B7': 10000, 'C7': 10000, 'A7': 10000}}}}}}}}, 'B1': {'education': {'C3': {'company_type': {'C2': {'major': {'D7': {'location': {'B6': 5000, 'C6': 1000}}, 'E7': 3000, 'B7': {'location': {'C6': 5000, 'A6': 10000}}, 'A7': {'location': {'B6': 3000, 'A6': 3000}}}}, 'A2': {'location': {'D6': {'major': {'F7': 3000, 'B7': 3000, 'A7': 3000, 'D7': 3000, 'E7': 3000, 'C7': 3000}}, 'B6': {'major': {'F7': 3000, 'B7': 3000, 'A7': 3000, 'D7': 3000, 'E7': 3000, 'C7': 3000}}, 'E6': {'major': {'F7': 8000, 'B7': 3000, 'A7': 3000, 'D7': 1000, 'E7': 3000, 'C7': 3000}}, 'C6': {'major': {'F7': 3000, 'B7': 3000, 'A7': 3000, 'D7': 3000, 'E7': 3000, 'C7': 3000}}, 'A6': {'major': {'F7': 3000, 'B7': 3000, 'A7': 3000, 'D7': 3000, 'E7': 3000, 'C7': 3000}}}}, 'D2': {'location': {'E6': 3000, 'C6': 3000, 'A6': 5000}}, 'B2': {'major': {'D7': {'location': {'D6': 1000, 'C6': 3000}}, 'E7': 10000, 'B7': {'location': {'B6': 5000, 'A6': 3000}}, 'C7': 8000, 'A7': {'location': {'B6': 3000, 'C6': 3000, 'A6': 1000}}}}}}, 'A3': {'location': {'D6': {'major': {'F7': {'company_type': {'A2': 3000, 'B2': 1000}}, 'B7': {'company_type': {'A2': 3000, 'B2': 3000}}, 'A7': {'company_type': {'C2': 3000, 'A2': 3000, 'B2': 3000}}, 'D7': {'company_type': {'A2': 3000, 'B2': 3000}}, 'E7': {'company_type': {'A2': 3000, 'B2': 5000}}, 'C7': {'company_type': {'A2': 3000, 'B2': 3000}}}}, 'B6': {'company_type': {'C2': {'major': {'A7': 5000}}, 'A2': {'major': {'F7': 3000, 'B7': 3000, 'A7': 3000, 'D7': 3000, 'E7': 3000, 'C7': 3000}}, 'D2': {'major': {'A7': 5000}}, 'B2': {'major': {'D7': 3000, 'E7': 3000, 'B7': 3000, 'C7': 3000, 'A7': 3000}}}}, 'E6': {'company_type': {'A2': {'major': {'F7': 1000, 'B7': 3000, 'A7': 3000, 'D7': 3000, 'E7': 3000, 'C7': 3000}}, 'B2': {'major': {'D7': 3000, 'E7': 8000, 'B7': 3000, 'A7': 1000, 'F7': 1000}}}}, 'C6': {'major': {'F7': {'company_type': {'A2': 3000}}, 'B7': {'company_type': {'A2': 3000, 'B2': 3000}}, 'A7': {'company_type': {'A2': 3000, 'D2': 3000, 'B2': 3000}}, 'D7': {'company_type': {'A2': 3000, 'B2': 3000}}, 'E7': {'company_type': {'C2': 1000, 'A2': 3000, 'B2': 1000}}, 'C7': {'company_type': {'A2': 3000, 'B2': 3000}}}}, 'A6': {'company_type': {'C2': {'major': {'E7': 3000, 'B7': 3000, 'A7': 5000}}, 'A2': {'major': {'F7': 3000, 'B7': 5000, 'A7': 3000, 'D7': 3000, 'E7': 5000, 'C7': 5000}}, 'D2': {'major': {'E7': 5000, 'B7': 5000, 'A7': 3000}}, 'B2': {'major': {'D7': 3000, 'E7': 8000, 'B7': 3000, 'C7': 5000, 'A7': 3000}}}}}}, 'D3': 10000, 'E3': {'major': {'F7': {'location': {'B6': {'company_type': {'A2': 10000}}, 'C6': {'company_type': {'C2': 5000, 'A2': 3000}}, 'A6': 8000}}, 'B7': {'location': {'D6': 8000, 'B6': {'company_type': {'A2': 10000}}, 'C6': {'company_type': {'C2': 10000, 'A2': 10000}}, 'A6': {'company_type': {'C2': 10000, 'A2': 10000, 'B2': 10000}}}}, 'A7': {'company_type': {'C2': 3000, 'A2': {'location': {'B6': 8000, 'E6': 3000, 'C6': 8000, 'A6': 8000}}, 'B2': 10000}}, 'D7': {'company_type': {'A2': {'location': {'B6': 5000, 'C6': 10000, 'A6': 5000}}}}, 'E7': {'company_type': {'A2': {'location': {'B6': 3000, 'C6': 8000, 'A6': 8000}}}}, 'C7': {'company_type': {'C2': 8000, 'A2': {'location': {'D6': 3000, 'B6': 3000, 'C6': 5000, 'A6': 10000}}, 'B2': 5000}}}}, 'B3': {'location': {'D6': {'company_type': {'C2': {'major': {'B7': 3000, 'A7': 8000}}, 'A2': {'major': {'F7': 3000, 'B7': 3000, 'A7': 3000, 'D7': 3000, 'E7': 3000, 'C7': 3000}}, 'B2': {'major': {'B7': 3000, 'C7': 8000, 'A7': 3000}}}}, 'B6': {'major': {'F7': {'company_type': {'C2': 5000, 'A2': 3000, 'B2': 5000}}, 'B7': {'company_type': {'C2': 8000, 'A2': 5000, 'B2': 5000}}, 'A7': {'company_type': {'C2': 3000, 'A2': 3000, 'B2': 3000}}, 'D7': {'company_type': {'A2': 5000, 'D2': 5000, 'B2': 3000}}, 'E7': {'company_type': {'C2': 5000, 'A2': 3000, 'D2': 10000, 'B2': 3000}}, 'C7': {'company_type': {'C2': 3000, 'A2': 3000, 'B2': 5000}}}}, 'E6': {'company_type': {'A2': {'major': {'F7': 3000, 'B7': 3000, 'A7': 3000, 'D7': 3000, 'E7': 3000, 'C7': 3000}}, 'B2': {'major': {'E7': 3000, 'B7': 8000, 'C7': 8000, 'A7': 5000, 'F7': 3000}}}}, 'C6': {'company_type': {'C2': {'major': {'B7': 5000}}, 'A2': {'major': {'F7': 3000, 'B7': 3000, 'A7': 3000, 'D7': 5000, 'E7': 3000, 'C7': 3000}}, 'D2': 3000, 'B2': {'major': {'F7': 5000, 'B7': 5000, 'A7': 5000, 'D7': 3000, 'E7': 5000, 'C7': 10000}}}}, 'A6': {'major': {'F7': {'company_type': {'A2': 10000, 'B2': 3000}}, 'B7': {'company_type': {'C2': 10000, 'A2': 10000, 'B2': 10000}}, 'A7': {'company_type': {'C2': 5000, 'A2': 5000, 'D2': 3000, 'B2': 5000}}, 'D7': {'company_type': {'C2': 5000, 'A2': 5000, 'B2': 5000}}, 'E7': {'company_type': {'C2': 3000, 'A2': 5000, 'D2': 1000, 'B2': 5000}}, 'C7': {'company_type': {'C2': 5000, 'A2': 5000, 'B2': 5000}}}}}}}}, 'D1': {'education': {'C3': {'company_type': {'C2': {'major': {'B7': {'location': {'D6': 8000, 'A6': 3000}}}}, 'A2': {'major': {'F7': 10000, 'B7': {'location': {'D6': 5000, 'B6': 5000, 'E6': 5000, 'C6': 5000, 'A6': 10000}}, 'A7': {'location': {'D6': 8000, 'B6': 3000, 'E6': 8000, 'C6': 5000, 'A6': 5000}}, 'D7': {'location': {'B6': 1000, 'C6': 10000, 'A6': 10000}}, 'E7': {'location': {'B6': 5000, 'A6': 10000}}, 'C7': {'location': {'D6': 10000, 'B6': 10000, 'E6': 5000, 'C6': 1000, 'A6': 5000}}}}, 'D2': 3000, 'B2': {'location': {'B6': {'major': {'E7': 5000, 'F7': 10000}}, 'A6': 3000}}}}, 'A3': {'company_type': {'C2': 10000, 'A2': {'location': {'D6': {'major': {'F7': 5000, 'B7': 5000, 'A7': 10000, 'D7': 3000, 'E7': 8000, 'C7': 10000}}, 'B6': {'major': {'F7': 5000, 'B7': 10000, 'A7': 10000, 'D7': 10000, 'E7': 10000, 'C7': 10000}}, 'E6': {'major': {'D7': 10000, 'B7': 10000, 'C7': 8000, 'A7': 10000, 'F7': 5000}}, 'C6': {'major': {'F7': 5000, 'B7': 10000, 'A7': 10000, 'D7': 10000, 'E7': 10000, 'C7': 10000}}, 'A6': {'major': {'F7': 8000, 'B7': 10000, 'A7': 10000, 'D7': 10000, 'E7': 10000, 'C7': 10000}}}}, 'D2': {'major': {'D7': 3000, 'A7': 1000}}, 'B2': {'major': {'E7': {'location': {'D6': 5000, 'B6': 5000, 'A6': 10000}}, 'B7': {'location': {'D6': 10000, 'B6': 3000, 'E6': 10000, 'C6': 3000, 'A6': 10000}}, 'C7': {'location': {'D6': 5000, 'B6': 10000, 'E6': 3000, 'C6': 10000, 'A6': 10000}}, 'A7': {'location': {'D6': 5000, 'B6': 5000, 'C6': 5000, 'A6': 5000}}, 'F7': {'location': {'D6': 5000, 'A6': 10000}}}}}}, 'D3': {'major': {'B7': 10000, 'C7': 10000, 'A7': 40000}}, 'E3': {'major': {'F7': 10000, 'B7': {'location': {'D6': 10000, 'B6': {'company_type': {'A2': 10000}}, 'C6': 10000, 'A6': {'company_type': {'A2': 10000, 'B2': 10000}}}}, 'A7': {'location': {'B6': 10000, 'C6': 40000, 'A6': {'company_type': {'C2': 10000, 'A2': 10000, 'B2': 10000}}}}, 'D7': {'company_type': {'A2': {'location': {'D6': 10000, 'B6': 10000, 'C6': 10000, 'A6': 10000}}}}, 'E7': {'company_type': {'A2': {'location': {'B6': 5000, 'C6': 5000, 'A6': 10000}}}}, 'C7': {'location': {'A6': {'company_type': {'A2': 10000, 'B2': 8000}}}}}}, 'B3': {'location': {'D6': {'major': {'F7': {'company_type': {'A2': 10000}}, 'B7': {'company_type': {'C2': 5000, 'A2': 10000, 'B2': 8000}}, 'A7': {'company_type': {'A2': 10000, 'B2': 10000}}, 'D7': {'company_type': {'A2': 10000}}, 'E7': {'company_type': {'A2': 10000}}, 'C7': {'company_type': {'A2': 10000, 'B2': 10000}}}}, 'B6': {'major': {'F7': {'company_type': {'A2': 10000}}, 'B7': {'company_type': {'A2': 10000, 'B2': 10000}}, 'A7': {'company_type': {'C2': 3000, 'A2': 10000, 'B2': 10000}}, 'D7': {'company_type': {'A2': 10000}}, 'E7': {'company_type': {'A2': 10000}}, 'C7': {'company_type': {'A2': 10000, 'B2': 10000}}}}, 'E6': {'major': {'F7': 8000, 'B7': {'company_type': {'A2': 10000, 'B2': 5000}}, 'A7': {'company_type': {'A2': 10000, 'B2': 10000}}, 'D7': {'company_type': {'A2': 3000}}, 'E7': {'company_type': {'A2': 10000}}, 'C7': {'company_type': {'A2': 10000, 'B2': 5000}}}}, 'C6': {'major': {'F7': 8000, 'B7': {'company_type': {'C2': 10000, 'A2': 10000, 'B2': 10000}}, 'A7': {'company_type': {'A2': 10000, 'B2': 10000}}, 'D7': {'company_type': {'A2': 10000, 'B2': 10000}}, 'E7': {'company_type': {'A2': 10000}}, 'C7': {'company_type': {'A2': 10000, 'B2': 8000}}}}, 'A6': {'company_type': {'C2': 10000, 'A2': {'major': {'F7': 10000, 'B7': 10000, 'A7': 10000, 'D7': 10000, 'E7': 10000, 'C7': 10000}}, 'B2': {'major': {'E7': 10000, 'B7': 10000, 'C7': 10000, 'A7': 10000}}}}}}}}}}\n" 277 | ] 278 | } 279 | ], 280 | "source": [ 281 | "DT = createDecideTree(dataSet,attributes)\n", 282 | "print DT" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": 144, 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [ 291 | "#使用训练的模型进行分类\n", 292 | "def classify(tree,feat,featValue):\n", 293 | " firstFeat = list(tree.keys())[0]\n", 294 | " secondDict = tree[firstFeat]\n", 295 | " featIndex = feat.index(firstFeat)\n", 296 | " key = featValue[featIndex]\n", 297 | " if key in secondDict.keys():\n", 298 | " if type(secondDict[key]).__name__ == 'dict':\n", 299 | " classLabel = classify(secondDict[key],feat,featValue)\n", 300 | " else:\n", 301 | " classLabel = secondDict[key]\n", 302 | " else:\n", 303 | " classLabel = \"OH NO!\"\n", 304 | " return classLabel" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 172, 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [ 313 | "#计算测试集准确率\n", 314 | "feat = ['experience','education','major','location','company_type']\n", 315 | "def precisionCal(tree,feat,test_data):\n", 316 | " tests = [dataVec[:-1] for dataVec in test_data]\n", 317 | " true = 0\n", 318 | " sum = len(tests)\n", 319 | " for i in range(len(tests)):\n", 320 | " result = classify(tree,feat,tests[i])\n", 321 | " if result == \"OH NO!\":\n", 322 | " sum = sum - 1\n", 323 | " else:\n", 324 | " #print result,\" versus \",test_data[i][5]\n", 325 | " if result == test_data[i][-1]:\n", 326 | " true = true + 1\n", 327 | " return float(true)/float(sum)" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": 173, 333 | "metadata": {}, 334 | "outputs": [ 335 | { 336 | "name": "stdout", 337 | "output_type": "stream", 338 | "text": [ 339 | "0.518890675241\n" 340 | ] 341 | } 342 | ], 343 | "source": [ 344 | "print precisionCal(DT,feat,test_data)" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": 174, 350 | "metadata": {}, 351 | "outputs": [], 352 | "source": [ 353 | "#C4.5剪枝\n", 354 | "import copy\n", 355 | "\n", 356 | "def isTree(obj):#判断当前节点是否是叶节点\n", 357 | " return (type(obj).__name__ == 'dict')\n", 358 | "\n", 359 | "def prune(tree, testData, featList):\n", 360 | " newTree = copy.deepcopy(tree)\n", 361 | " firstFeat = list(tree.keys())[0]\n", 362 | " secondDict = tree[firstFeat]\n", 363 | " #print firstFeat\n", 364 | " for key in secondDict.keys():\n", 365 | " #print key\n", 366 | " if isTree(secondDict[key]):\n", 367 | " #print \"is tree\"\n", 368 | " #print secondDict[key]\n", 369 | " featNum = featList.index(firstFeat)\n", 370 | " subSample = splitDataSet(testData, featNum, key)\n", 371 | " subclassList = [dataVec[-1] for dataVec in subSample]\n", 372 | " if len(subclassList)!=0:\n", 373 | " #print \"majority\", majorityCnt(subclassList)\n", 374 | " if len(subclassList) == subclassList.count(subclassList[0]):\n", 375 | " tree[firstFeat][key] = majorityCnt(subclassList)\n", 376 | " else:\n", 377 | " newTree[firstFeat][key] = majorityCnt(subclassList)\n", 378 | " precison_old = precisionCal(tree, featList, testData)\n", 379 | " precision_new = precisionCal(newTree, featList, testData)\n", 380 | " #print precison_old, \"versus\", precision_new\n", 381 | " if precision_new > precison_old:\n", 382 | " #print \"new tree\"\n", 383 | " tree[firstFeat][key] = majorityCnt(subclassList)\n", 384 | " else:\n", 385 | " #print \"go on\"\n", 386 | " newfeatList = featList[:featNum] + featList[featNum + 1:]\n", 387 | " tree[firstFeat][key] = prune(secondDict[key], subSample, newfeatList)\n", 388 | " else:\n", 389 | " continue\n", 390 | " return tree " 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": 175, 396 | "metadata": {}, 397 | "outputs": [], 398 | "source": [ 399 | "import copy\n", 400 | "tree = copy.deepcopy(DT)\n", 401 | "atrributes = ['experience','education','major','location','company_type']\n", 402 | "newDT = prune(tree, test_data, atrributes)" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": 149, 408 | "metadata": {}, 409 | "outputs": [ 410 | { 411 | "data": { 412 | "text/plain": [ 413 | "0.5334535256410257" 414 | ] 415 | }, 416 | "execution_count": 149, 417 | "metadata": {}, 418 | "output_type": "execute_result" 419 | } 420 | ], 421 | "source": [ 422 | "atrributes = ['experience','education','major','location','company_type']\n", 423 | "precisionCal(newDT,atrributes,test_data)" 424 | ] 425 | }, 426 | { 427 | "cell_type": "markdown", 428 | "metadata": {}, 429 | "source": [ 430 | "# CART" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": 150, 436 | "metadata": {}, 437 | "outputs": [], 438 | "source": [ 439 | "#划分样本集\n", 440 | "def binarySplitDataSet(dataset,feature,value):\n", 441 | " matLeft = dataset[np.nonzero(dataset[:,feature] <= value)[0],:]\n", 442 | " matRight = dataset[np.nonzero(dataset[:,feature] > value)[0],:]\n", 443 | " return matLeft,matRight" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": 151, 449 | "metadata": {}, 450 | "outputs": [], 451 | "source": [ 452 | "def regressErr(dataset):\n", 453 | " return np.var(dataset[:,-1]) * np.shape(dataset)[0]\n", 454 | "\n", 455 | "def chooseBestSplit(dataset, threshold=(1, 200)):\n", 456 | " thresholdErr = threshold[0]\n", 457 | " thresholdSamples = threshold[1]\n", 458 | " if len(set(dataset[:,-1].tolist())) == 1:\n", 459 | " return None,np.mean(dataset[:,-1])\n", 460 | " m,n = np.shape(dataset)\n", 461 | " Err = regressErr(dataset)\n", 462 | " bestErr = np.inf \n", 463 | " bestFeatureIndex = 0\n", 464 | " bestFeatureValue = 0\n", 465 | " for featureindex in range(n-1):#最优划分特征选择\n", 466 | " #print featureindex\n", 467 | " for featurevalue in set(dataset[:,featureindex]):#最优划分特征值选择\n", 468 | " #print featurevalue\n", 469 | " matLeft,matRight = binarySplitDataSet(dataset,featureindex,featurevalue)\n", 470 | " if (np.shape(matLeft)[0] < thresholdSamples) or (np.shape(matRight)[0] < thresholdSamples):\n", 471 | " continue\n", 472 | " temErr = regressErr(matLeft) + regressErr(matRight)\n", 473 | " #print \"temErr\",temErr\n", 474 | " if temErr < bestErr:\n", 475 | " bestErr = temErr\n", 476 | " bestFeatureIndex = featureindex\n", 477 | " bestFeatureValue = featurevalue\n", 478 | " #检验在所选出的最优划分特征及其取值下,误差平方和与未划分时的差是否小于阈值,若是,则不适合划分\n", 479 | " if (Err - bestErr) < thresholdErr:\n", 480 | " return None,np.mean(dataset[:,-1])\n", 481 | " matLeft,matRight = binarySplitDataSet(dataset,bestFeatureIndex,bestFeatureValue)\n", 482 | " #检验在所选出的最优划分特征及其取值下,划分的左右数据集的样本数是否小于阈值,若是,则不适合划分\n", 483 | " if (np.shape(matLeft)[0] < thresholdSamples) or (np.shape(matRight)[0] < thresholdSamples):\n", 484 | " return None,np.mean(dataset[:,-1])\n", 485 | " return bestFeatureIndex,bestFeatureValue" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": 152, 491 | "metadata": {}, 492 | "outputs": [], 493 | "source": [ 494 | "def createCARTtree(dataset,threshold=(1,200)):\n", 495 | " feature,value = chooseBestSplit(dataset,threshold)\n", 496 | " #当不满足阈值或某一子数据集下输出全相等时,返回叶节点\n", 497 | " if feature == None: return value\n", 498 | " returnTree = {}\n", 499 | " returnTree['bestSplitFeature'] = feature\n", 500 | " returnTree['bestSplitFeatValue'] = value\n", 501 | " leftSet,rightSet = binarySplitDataSet(dataset,feature,value)\n", 502 | " returnTree['left'] = createCARTtree(leftSet,threshold)\n", 503 | " returnTree['right'] = createCARTtree(rightSet,threshold)\n", 504 | " return returnTree" 505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "execution_count": 153, 510 | "metadata": {}, 511 | "outputs": [ 512 | { 513 | "name": "stdout", 514 | "output_type": "stream", 515 | "text": [ 516 | "{'bestSplitFeature': 0, 'bestSplitFeatValue': 'B1', 'right': {'bestSplitFeature': 0, 'bestSplitFeatValue': 'C1', 'right': {'bestSplitFeature': 3, 'bestSplitFeatValue': 'A6', 'right': {'bestSplitFeature': 0, 'bestSplitFeatValue': 'D1', 'right': 11546.703296703297, 'left': {'bestSplitFeature': 1, 'bestSplitFeatValue': 'A3', 'right': {'bestSplitFeature': 3, 'bestSplitFeatValue': 'C6', 'right': 8593.272171253822, 'left': {'bestSplitFeature': 3, 'bestSplitFeatValue': 'B6', 'right': 9386.533665835412, 'left': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'A7', 'right': 9289.473684210527, 'left': 9099.514563106795}}}, 'left': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'A7', 'right': 7508.87573964497, 'left': 7426.751592356688}}}, 'left': {'bestSplitFeature': 1, 'bestSplitFeatValue': 'A3', 'right': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'A7', 'right': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'B7', 'right': 11502.732240437159, 'left': 11117.936117936119}, 'left': 13829.683698296836}, 'left': 9000.0}}, 'left': {'bestSplitFeature': 3, 'bestSplitFeatValue': 'A6', 'right': {'bestSplitFeature': 1, 'bestSplitFeatValue': 'A3', 'right': {'bestSplitFeature': 1, 'bestSplitFeatValue': 'B3', 'right': 5013.368983957219, 'left': {'bestSplitFeature': 3, 'bestSplitFeatValue': 'B6', 'right': {'bestSplitFeature': 3, 'bestSplitFeatValue': 'C6', 'right': 6487.421383647798, 'left': 6237.410071942446}, 'left': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'A7', 'right': 7199.579831932773, 'left': 6808.463251670379}}}, 'left': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'A7', 'right': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'C7', 'right': 5680.216802168022, 'left': {'bestSplitFeature': 3, 'bestSplitFeatValue': 'B6', 'right': 5625.0, 'left': 5546.3917525773195}}, 'left': {'bestSplitFeature': 3, 'bestSplitFeatValue': 'B6', 'right': 4942.105263157895, 'left': 5248.076923076923}}}, 'left': {'bestSplitFeature': 1, 'bestSplitFeatValue': 'A3', 'right': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'B7', 'right': 8258.536585365853, 'left': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'A7', 'right': 9043.844856661046, 'left': 8448.520710059172}}, 'left': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'A7', 'right': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'B7', 'right': 6620.12987012987, 'left': 7310.0}, 'left': 6429.921259842519}}}}, 'left': {'bestSplitFeature': 3, 'bestSplitFeatValue': 'A6', 'right': {'bestSplitFeature': 1, 'bestSplitFeatValue': 'A3', 'right': {'bestSplitFeature': 1, 'bestSplitFeatValue': 'B3', 'right': {'bestSplitFeature': 0, 'bestSplitFeatValue': 'A1', 'right': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'B7', 'right': 4076.923076923077, 'left': {'bestSplitFeature': 3, 'bestSplitFeatValue': 'B6', 'right': 3476.0705289672546, 'left': 3716.89497716895}}, 'left': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'A7', 'right': 3135.6589147286822, 'left': {'bestSplitFeature': 3, 'bestSplitFeatValue': 'B6', 'right': 3980.6451612903224, 'left': 3608.1504702194356}}}, 'left': {'bestSplitFeature': 0, 'bestSplitFeatValue': 'A1', 'right': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'A7', 'right': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'C7', 'right': 4578.4114052953155, 'left': {'bestSplitFeature': 3, 'bestSplitFeatValue': 'B6', 'right': 5099.009900990099, 'left': 5597.964376590331}}, 'left': {'bestSplitFeature': 3, 'bestSplitFeatValue': 'B6', 'right': 4088.095238095238, 'left': 4325.367647058823}}, 'left': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'B7', 'right': 3347.417840375587, 'left': {'bestSplitFeature': 3, 'bestSplitFeatValue': 'B6', 'right': 3912.2340425531916, 'left': 4025.839793281654}}}}, 'left': {'bestSplitFeature': 0, 'bestSplitFeatValue': 'A1', 'right': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'A7', 'right': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'C7', 'right': {'bestSplitFeature': 3, 'bestSplitFeatValue': 'B6', 'right': 3437.037037037037, 'left': 3578.2312925170068}, 'left': {'bestSplitFeature': 3, 'bestSplitFeatValue': 'B6', 'right': 3706.997084548105, 'left': 4190.962099125364}}, 'left': {'bestSplitFeature': 3, 'bestSplitFeatValue': 'B6', 'right': {'bestSplitFeature': 3, 'bestSplitFeatValue': 'C6', 'right': 3360.4240282685514, 'left': 3420.256111757858}, 'left': 3588.9407061958696}}, 'left': {'bestSplitFeature': 3, 'bestSplitFeatValue': 'B6', 'right': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'C7', 'right': 2926.9841269841268, 'left': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'A7', 'right': 3068.281938325991, 'left': {'bestSplitFeature': 3, 'bestSplitFeatValue': 'C6', 'right': 3122.5806451612902, 'left': 3094.6236559139784}}}, 'left': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'A7', 'right': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'C7', 'right': 3114.4578313253014, 'left': 3220.0}, 'left': 3328.159645232816}}}}, 'left': {'bestSplitFeature': 1, 'bestSplitFeatValue': 'A3', 'right': {'bestSplitFeature': 1, 'bestSplitFeatValue': 'B3', 'right': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'A7', 'right': {'bestSplitFeature': 0, 'bestSplitFeatValue': 'A1', 'right': 4585.751978891821, 'left': 4339.74358974359}, 'left': {'bestSplitFeature': 0, 'bestSplitFeatValue': 'A1', 'right': 4508.474576271186, 'left': 5022.058823529412}}, 'left': {'bestSplitFeature': 0, 'bestSplitFeatValue': 'A1', 'right': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'B7', 'right': 5858.851674641149, 'left': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'A7', 'right': 7442.708333333333, 'left': 6144.300144300144}}, 'left': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'A7', 'right': 5085.995085995086, 'left': 4757.76397515528}}}, 'left': {'bestSplitFeature': 0, 'bestSplitFeatValue': 'A1', 'right': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'A7', 'right': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'B7', 'right': 4571.593533487298, 'left': 5172.413793103448}, 'left': 4532.02846975089}, 'left': {'bestSplitFeature': 2, 'bestSplitFeatValue': 'A7', 'right': 3909.531502423263, 'left': 4412.71676300578}}}}}\n" 517 | ] 518 | } 519 | ], 520 | "source": [ 521 | "cart_DT = createCARTtree(train_data,threshold=(1,300))\n", 522 | "print cart_DT" 523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "execution_count": 154, 528 | "metadata": {}, 529 | "outputs": [], 530 | "source": [ 531 | "#使用模型进行预测\n", 532 | "def isTree(obj):#判断当前节点是否是叶节点\n", 533 | " return type(obj).__name__ == 'dict'\n", 534 | "\n", 535 | "def treeForeCast(tree, inputData):\n", 536 | " if not isTree(tree): return float(tree)\n", 537 | " if inputData[tree['bestSplitFeature']] <= tree['bestSplitFeatValue']:\n", 538 | " if isTree(tree['left']):\n", 539 | " return treeForeCast(tree['left'],inputData)\n", 540 | " else:\n", 541 | " return float(tree['left'])\n", 542 | " else:\n", 543 | " if isTree(tree['right']):\n", 544 | " return treeForeCast(tree['right'],inputData)\n", 545 | " else:\n", 546 | " return float(tree['right'])" 547 | ] 548 | }, 549 | { 550 | "cell_type": "code", 551 | "execution_count": 182, 552 | "metadata": {}, 553 | "outputs": [], 554 | "source": [ 555 | "#计算测试集准确率\n", 556 | "def tranverse(data):\n", 557 | " if 1000<=data<3000:\n", 558 | " data = 1000\n", 559 | " if 3000<=data<5000:\n", 560 | " data = 3000\n", 561 | " if 5000<=data<8000:\n", 562 | " data = 5000\n", 563 | " if 8000<=data<10000:\n", 564 | " data = 8000\n", 565 | " if 10000<=data<40000:\n", 566 | " data = 10000\n", 567 | " if data>=40000:\n", 568 | " data = 40000\n", 569 | " return data\n", 570 | "\n", 571 | "def precisionCal(tree,test_data):\n", 572 | " tests = [dataVec[:-1] for dataVec in test_data]\n", 573 | " true = 0\n", 574 | " sum = len(tests)\n", 575 | " for i in range(len(tests)):\n", 576 | " result = treeForeCast(tree, tests[i])\n", 577 | " #print result,\" versus \",test_data[i][-1]\n", 578 | " if abs(result - test_data[i][-1]) <= 2000:\n", 579 | " true = true + 1\n", 580 | " return float(true)/float(sum)" 581 | ] 582 | }, 583 | { 584 | "cell_type": "code", 585 | "execution_count": 183, 586 | "metadata": {}, 587 | "outputs": [ 588 | { 589 | "data": { 590 | "text/plain": [ 591 | "0.6199239847969594" 592 | ] 593 | }, 594 | "execution_count": 183, 595 | "metadata": {}, 596 | "output_type": "execute_result" 597 | } 598 | ], 599 | "source": [ 600 | "precisionCal(cart_DT,test_data)" 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "execution_count": 184, 606 | "metadata": {}, 607 | "outputs": [], 608 | "source": [ 609 | "#CART剪枝\n", 610 | "def getMean(tree):\n", 611 | " if isTree(tree['left']): tree['left'] = getMean(tree['left'])\n", 612 | " if isTree(tree['right']): tree['right'] = getMean(tree['right'])\n", 613 | " return (tree['left'] + tree['right'])/2.0\n", 614 | "\n", 615 | "def cart_prune(tree, testData):\n", 616 | " if np.shape(testData)[0] == 0: return getMean(tree)#存在测试集中没有训练集中数据的情况\n", 617 | " if isTree(tree['left']) or isTree(tree['right']):\n", 618 | " leftTestData, rightTestData = binarySplitDataSet(testData,tree['bestSplitFeature'],tree['bestSplitFeatValue'])\n", 619 | " #递归调用prune函数对左右子树,注意与左右子树对应的左右子测试数据集\n", 620 | " if isTree(tree['left']): tree['left'] = cart_prune(tree['left'],leftTestData)\n", 621 | " if isTree(tree['right']): tree['right'] = cart_prune(tree['right'],rightTestData)\n", 622 | " #当递归搜索到左右子树均为叶节点时,计算测试数据集的误差平方和\n", 623 | " if not isTree(tree['left']) and not isTree(tree['right']):\n", 624 | " leftTestData, rightTestData = binarySplitDataSet(testData,tree['bestSplitFeature'],tree['bestSplitFeatValue'])\n", 625 | " errorNOmerge = sum(np.power(leftTestData[:,-1] - tree['left'],2)) +sum(np.power(rightTestData[:,-1] - tree['right'],2))\n", 626 | " errorMerge = sum(np.power(testData[:,-1] - getMean(tree),2))\n", 627 | " if errorMerge < errorNOmerge:\n", 628 | " return getMean(tree)\n", 629 | " else: return tree\n", 630 | " else: return tree" 631 | ] 632 | }, 633 | { 634 | "cell_type": "code", 635 | "execution_count": 185, 636 | "metadata": {}, 637 | "outputs": [ 638 | { 639 | "data": { 640 | "text/plain": [ 641 | "0.6263252650530106" 642 | ] 643 | }, 644 | "execution_count": 185, 645 | "metadata": {}, 646 | "output_type": "execute_result" 647 | } 648 | ], 649 | "source": [ 650 | "#计算剪枝后准确率\n", 651 | "import copy\n", 652 | "\n", 653 | "cart_tree = copy.deepcopy(cart_DT)\n", 654 | "cart_tree = cart_prune(cart_tree, test_data)\n", 655 | "precisionCal(cart_tree,test_data)" 656 | ] 657 | }, 658 | { 659 | "cell_type": "markdown", 660 | "metadata": {}, 661 | "source": [ 662 | "# Bagging" 663 | ] 664 | }, 665 | { 666 | "cell_type": "code", 667 | "execution_count": 176, 668 | "metadata": {}, 669 | "outputs": [], 670 | "source": [ 671 | "#生成C4.5随机森林\n", 672 | "Forest = []\n", 673 | "\n", 674 | "def createSample():\n", 675 | " index = np.random.randint(35000, size=30000)\n", 676 | " train_sample = train_data[index,:]\n", 677 | " atrributes = ['experience','education','major','location','company_type']\n", 678 | " return train_sample, atrributes\n", 679 | "\n", 680 | "def createForest():\n", 681 | " for i in range(7):\n", 682 | " print \"creating tree\",i,\"...\"\n", 683 | " train_sample, attributes = createSample()\n", 684 | " DT_sample = createDecideTree(train_sample,attributes)\n", 685 | " atrributes = ['experience','education','major','location','company_type']\n", 686 | " DT_sample = prune(DT_sample, train_sample, atrributes)\n", 687 | " Forest.append(DT_sample) " 688 | ] 689 | }, 690 | { 691 | "cell_type": "code", 692 | "execution_count": 177, 693 | "metadata": {}, 694 | "outputs": [ 695 | { 696 | "name": "stdout", 697 | "output_type": "stream", 698 | "text": [ 699 | "creating tree 0 ...\n", 700 | "creating tree 1 ...\n", 701 | "creating tree 2 ...\n", 702 | "creating tree 3 ...\n", 703 | "creating tree 4 ...\n", 704 | "creating tree 5 ...\n", 705 | "creating tree 6 ...\n" 706 | ] 707 | } 708 | ], 709 | "source": [ 710 | "createForest()" 711 | ] 712 | }, 713 | { 714 | "cell_type": "code", 715 | "execution_count": 178, 716 | "metadata": {}, 717 | "outputs": [], 718 | "source": [ 719 | "def tranverse(data):\n", 720 | " if 1000<=data<3000:\n", 721 | " data = 1000\n", 722 | " if 3000<=data<5000:\n", 723 | " data = 3000\n", 724 | " if 5000<=data<8000:\n", 725 | " data = 5000\n", 726 | " if 8000<=data<10000:\n", 727 | " data = 8000\n", 728 | " if 10000<=data<40000:\n", 729 | " data = 10000\n", 730 | " if data>=40000:\n", 731 | " data = 40000\n", 732 | " return data\n", 733 | "\n", 734 | "def precisionForest(Forest,feat,test_data):\n", 735 | " tests = [dataVec[:-1] for dataVec in test_data]\n", 736 | " true = 0\n", 737 | " sum = len(tests)\n", 738 | " for i in range(len(tests)):\n", 739 | " results = []\n", 740 | " for j in range(len(Forest)):\n", 741 | " result = classify(Forest[j],feat,tests[i])\n", 742 | " if result != \"OH NO!\":\n", 743 | " results.append(result)\n", 744 | " #print results\n", 745 | " if len(results)>0:\n", 746 | " forestResult = majorityCnt(results)\n", 747 | " #print forestResult,\"versus\",test_data[i][-1]\n", 748 | " if forestResult == test_data[i][-1]:\n", 749 | " true = true + 1\n", 750 | " else:\n", 751 | " sum = sum - 1\n", 752 | " return float(true)/float(sum)\n", 753 | " " 754 | ] 755 | }, 756 | { 757 | "cell_type": "code", 758 | "execution_count": 179, 759 | "metadata": {}, 760 | "outputs": [ 761 | { 762 | "data": { 763 | "text/plain": [ 764 | "0.5145319703347364" 765 | ] 766 | }, 767 | "execution_count": 179, 768 | "metadata": {}, 769 | "output_type": "execute_result" 770 | } 771 | ], 772 | "source": [ 773 | "attributes = ['experience','education','major','location','company_type']\n", 774 | "precisionForest(Forest,attributes,test_data)" 775 | ] 776 | }, 777 | { 778 | "cell_type": "code", 779 | "execution_count": 186, 780 | "metadata": {}, 781 | "outputs": [], 782 | "source": [ 783 | "cart_Forest = []\n", 784 | "\n", 785 | "def createSample():\n", 786 | " index = np.random.randint(35000, size=30000)\n", 787 | " train_sample = train_data[index,:]\n", 788 | " atrributes = ['experience','education','major','location','company_type']\n", 789 | " return train_sample, atrributes\n", 790 | "\n", 791 | "def createCARTForest():\n", 792 | " for i in range(7):\n", 793 | " print \"creating tree\",i,\"...\"\n", 794 | " train_sample, attributes = createSample()\n", 795 | " DT_sample = createCARTtree(train_sample,threshold=(1,250))\n", 796 | " atrributes = ['experience','education','major','location','company_type']\n", 797 | " DT_sample = cart_prune(DT_sample, train_sample)\n", 798 | " cart_Forest.append(DT_sample) " 799 | ] 800 | }, 801 | { 802 | "cell_type": "code", 803 | "execution_count": 187, 804 | "metadata": {}, 805 | "outputs": [ 806 | { 807 | "name": "stdout", 808 | "output_type": "stream", 809 | "text": [ 810 | "creating tree 0 ...\n", 811 | "creating tree 1 ...\n", 812 | "creating tree 2 ...\n", 813 | "creating tree 3 ...\n", 814 | "creating tree 4 ...\n", 815 | "creating tree 5 ...\n", 816 | "creating tree 6 ...\n" 817 | ] 818 | } 819 | ], 820 | "source": [ 821 | "createCARTForest()" 822 | ] 823 | }, 824 | { 825 | "cell_type": "code", 826 | "execution_count": 188, 827 | "metadata": {}, 828 | "outputs": [], 829 | "source": [ 830 | "def precisionCARTForest(Forest,feat,test_data):\n", 831 | " tests = [dataVec[:-1] for dataVec in test_data]\n", 832 | " true = 0\n", 833 | " sum = len(tests)\n", 834 | " for i in range(len(tests)):\n", 835 | " results = []\n", 836 | " for j in range(len(Forest)):\n", 837 | " result = treeForeCast(cart_Forest[j], tests[i])\n", 838 | " results.append(result)\n", 839 | " #print results\n", 840 | " if len(results)>0:\n", 841 | " forestResult = np.mean(results)\n", 842 | " #print forestResult,\"versus\",test_data[i][-1]\n", 843 | " if abs(forestResult - test_data[i][-1]) <= 2500:\n", 844 | " true = true + 1\n", 845 | " else:\n", 846 | " sum = sum - 1\n", 847 | " return float(true)/float(sum)" 848 | ] 849 | }, 850 | { 851 | "cell_type": "code", 852 | "execution_count": 189, 853 | "metadata": {}, 854 | "outputs": [ 855 | { 856 | "data": { 857 | "text/plain": [ 858 | "0.7181436287257451" 859 | ] 860 | }, 861 | "execution_count": 189, 862 | "metadata": {}, 863 | "output_type": "execute_result" 864 | } 865 | ], 866 | "source": [ 867 | "attributes = ['experience','education','major','location','company_type']\n", 868 | "precisionCARTForest(cart_Forest,attributes,test_data)" 869 | ] 870 | }, 871 | { 872 | "cell_type": "code", 873 | "execution_count": null, 874 | "metadata": {}, 875 | "outputs": [], 876 | "source": [] 877 | }, 878 | { 879 | "cell_type": "code", 880 | "execution_count": null, 881 | "metadata": {}, 882 | "outputs": [], 883 | "source": [] 884 | } 885 | ], 886 | "metadata": { 887 | "kernelspec": { 888 | "display_name": "Python 2", 889 | "language": "python", 890 | "name": "python2" 891 | }, 892 | "language_info": { 893 | "codemirror_mode": { 894 | "name": "ipython", 895 | "version": 2 896 | }, 897 | "file_extension": ".py", 898 | "mimetype": "text/x-python", 899 | "name": "python", 900 | "nbconvert_exporter": "python", 901 | "pygments_lexer": "ipython2", 902 | "version": "2.7.15" 903 | } 904 | }, 905 | "nbformat": 4, 906 | "nbformat_minor": 2 907 | } 908 | --------------------------------------------------------------------------------