├── code
├── gongzhong.jpg
├── readme.md
├── 第10章 隐马尔可夫模型(HMM)
│   └── HMM.ipynb
├── 第11章 条件随机场(CRF)
│   └── CRF.ipynb
├── 第1章 统计学习方法概论(LeastSquaresMethod)
│   └── least_sqaure_method.ipynb
├── 第2章 感知机(Perceptron)
│   └── Iris_perceptron.ipynb
├── 第3章 k近邻法(KNearestNeighbors)
│   ├── KDT.py
│   └── KNN.ipynb
├── 第4章 朴素贝叶斯(NaiveBayes)
│   └── GaussianNB.ipynb
├── 第5章 决策树(DecisonTree)
│   ├── DT.ipynb
│   ├── Decision Tree (ID3 剪枝)
│   └── mytree.pdf
├── 第6章 逻辑斯谛回归(LogisticRegression)
│   ├── LR.ipynb
│   └── 最大熵模型 IIS.py
├── 第7章 支持向量机(SVM)
│   └── support-vector-machine.ipynb
├── 第8章 提升方法(AdaBoost)
│   └── Adaboost.ipynb
└── 第9章 EM算法及其推广(EM)
│   └── em.ipynb
├── images
├── 1543246677825.png
└── gongzhong.png
├── ppt
├── readme.md
├── 第10章 隐马尔科夫模型.pdf
├── 第11章 条件随机场.pdf
├── 第12章 统计学习方法总结.pdf
├── 第1章 统计学习方法概论.pdf
├── 第2章 感知机.pdf
├── 第3章 k 近邻法.pdf
├── 第4章 朴素贝叶斯法.pdf
├── 第5章 决策树-2016-ID3CART.pdf
├── 第6章 Logistic回归.pdf
├── 第7章 支持向量机.pdf
├── 第8章 提升方法.pdf
└── 第9章 EM算法及其推广.pdf
└── readme.md
/code/gongzhong.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MingchaoZhu/lihang-code/3b833f78fd99ef6fb60b651708c61792e56c61bc/code/gongzhong.jpg
--------------------------------------------------------------------------------
/code/readme.md:
--------------------------------------------------------------------------------
 1 | **代码目录**
 2 | 
 3 | 第1章 统计学习方法概论
 4 | 
 5 | 第2章 感知机
 6 | 
 7 | 第3章 k近邻法
 8 | 
 9 | 第4章 朴素贝叶斯
10 | 
11 | 第5章 决策树
12 | 
13 | 第6章 逻辑斯谛回归
14 | 
15 | 第7章 支持向量机
16 | 
17 | 第8章 提升方法
18 | 
19 | 第9章 EM算法及其推广
20 | 
21 | 第10章 隐马尔可夫模型
22 | 
23 | 第11章 条件随机场
24 | 
25 | -----------
26 | 参考:
27 | https://github.com/wzyonggege/statistical-learning-method
28 | https://github.com/WenDesi/lihang_book_algorithm
29 | https://blog.csdn.net/tudaodiaozhale
30 | 
31 | 代码整理和修改:机器学习初学者 (微信公众号,ID:ai-start-com)
32 | 
33 | 
34 | 
--------------------------------------------------------------------------------
/code/第10章 隐马尔可夫模型(HMM)/HMM.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "原文代码作者:https://blog.csdn.net/tudaodiaozhale\n", 8 | "\n", 9 | "中文注释制作:机器学习初学者(微信公众号:ID:ai-start-com)\n", 10 | "\n", 11 | "配置环境:python 3.6\n", 12 | "\n", 13 | "代码全部测试通过。\n", 14 | "![gongzhong](../gongzhong.jpg)" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "# 第10章 隐马尔可夫模型" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 5, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import numpy as np" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 6, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "class HiddenMarkov:\n", 40 | " def forward(self, Q, V, A, B, O, PI): # 使用前向算法\n", 41 | " N = len(Q) # 状态序列的大小\n", 42 | " M = len(O) # 观测序列的大小\n", 43 | " alphas = np.zeros((N, M)) # alpha值\n", 44 | " T = M # 有几个时刻,有几个观测序列,就有几个时刻\n", 45 | " for t in range(T): # 遍历每一时刻,算出alpha值\n", 46 | " indexOfO = V.index(O[t]) # 找出序列对应的索引\n", 47 | " for i in range(N):\n", 48 | " if t == 0: # 计算初值\n", 49 | " alphas[i][t] = PI[t][i] * B[i][indexOfO] # P176(10.15)\n", 50 | " print('alpha1(%d)=p%db%db(o1)=%f' % (i, i, i, alphas[i][t]))\n", 51 | " else:\n", 52 | " alphas[i][t] = np.dot([alpha[t - 1] for alpha in alphas], [a[i] for a in A]) * B[i][\n", 53 | " indexOfO] # 对应P176(10.16)\n", 54 | " print('alpha%d(%d)=[sigma alpha%d(i)ai%d]b%d(o%d)=%f' % (t, i, t - 1, i, i, t, alphas[i][t]))\n", 55 | " # print(alphas)\n", 56 | " P = np.sum([alpha[M - 1] for alpha in alphas]) # P176(10.17)\n", 57 | " # alpha11 = pi[0][0] * B[0][0] #代表a1(1)\n", 58 | " # alpha12 = pi[0][1] * B[1][0] #代表a1(2)\n", 59 | " # alpha13 = pi[0][2] * B[2][0] #代表a1(3)\n", 60 | "\n", 61 | " def backward(self, Q, V, A, B, O, PI): # 后向算法\n", 62 | " N = len(Q) # 状态序列的大小\n", 63 | " M = len(O) # 观测序列的大小\n", 64 | " betas = np.ones((N, M)) # 
beta\n", 65 | " for i in range(N):\n", 66 | " print('beta%d(%d)=1' % (M, i))\n", 67 | " for t in range(M - 2, -1, -1):\n", 68 | " indexOfO = V.index(O[t + 1]) # 找出序列对应的索引\n", 69 | " for i in range(N):\n", 70 | " betas[i][t] = np.dot(np.multiply(A[i], [b[indexOfO] for b in B]), [beta[t + 1] for beta in betas])\n", 71 | " realT = t + 1\n", 72 | " realI = i + 1\n", 73 | " print('beta%d(%d)=[sigma a%djbj(o%d)]beta%d(j)=(' % (realT, realI, realI, realT + 1, realT + 1),\n", 74 | " end='')\n", 75 | " for j in range(N):\n", 76 | " print(\"%.2f*%.2f*%.2f+\" % (A[i][j], B[j][indexOfO], betas[j][t + 1]), end='')\n", 77 | " print(\"0)=%.3f\" % betas[i][t])\n", 78 | " # print(betas)\n", 79 | " indexOfO = V.index(O[0])\n", 80 | " P = np.dot(np.multiply(PI, [b[indexOfO] for b in B]), [beta[0] for beta in betas])\n", 81 | " print(\"P(O|lambda)=\", end=\"\")\n", 82 | " for i in range(N):\n", 83 | " print(\"%.1f*%.1f*%.5f+\" % (PI[0][i], B[i][indexOfO], betas[i][0]), end=\"\")\n", 84 | " print(\"0=%f\" % P)\n", 85 | "\n", 86 | " def viterbi(self, Q, V, A, B, O, PI):\n", 87 | " N = len(Q) # 状态序列的大小\n", 88 | " M = len(O) # 观测序列的大小\n", 89 | " deltas = np.zeros((N, M))\n", 90 | " psis = np.zeros((N, M))\n", 91 | " I = np.zeros((1, M))\n", 92 | " for t in range(M):\n", 93 | " realT = t+1\n", 94 | " indexOfO = V.index(O[t]) # 找出序列对应的索引\n", 95 | " for i in range(N):\n", 96 | " realI = i+1\n", 97 | " if t == 0:\n", 98 | " deltas[i][t] = PI[0][i] * B[i][indexOfO]\n", 99 | " psis[i][t] = 0\n", 100 | " print('delta1(%d)=pi%d * b%d(o1)=%.2f * %.2f=%.2f'%(realI, realI, realI, PI[0][i], B[i][indexOfO], deltas[i][t]))\n", 101 | " print('psis1(%d)=0' % (realI))\n", 102 | " else:\n", 103 | " deltas[i][t] = np.max(np.multiply([delta[t-1] for delta in deltas], [a[i] for a in A])) * B[i][indexOfO]\n", 104 | " print('delta%d(%d)=max[delta%d(j)aj%d]b%d(o%d)=%.2f*%.2f=%.5f'%(realT, realI, realT-1, realI, realI, realT, np.max(np.multiply([delta[t-1] for delta in deltas], [a[i] for a in A])), 
B[i][indexOfO], deltas[i][t]))\n", 105 | " psis[i][t] = np.argmax(np.multiply([delta[t-1] for delta in deltas], [a[i] for a in A]))\n", 106 | " print('psis%d(%d)=argmax[delta%d(j)aj%d]=%d' % (realT, realI, realT-1, realI, psis[i][t]))\n", 107 | " print(deltas)\n", 108 | " print(psis)\n", 109 | " I[0][M-1] = np.argmax([delta[M-1] for delta in deltas])\n", 110 | " print('i%d=argmax[deltaT(i)]=%d' % (M, I[0][M-1]+1))\n", 111 | " for t in range(M-2, -1, -1):\n", 112 | " I[0][t] = psis[int(I[0][t+1])][t+1]\n", 113 | " print('i%d=psis%d(i%d)=%d' % (t+1, t+2, t+2, I[0][t]+1))\n", 114 | " print(I)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": {}, 120 | "source": [ 121 | "### 习题10.1" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 7, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "#习题10.1\n", 131 | "Q = [1, 2, 3]\n", 132 | "V = ['红', '白']\n", 133 | "A = [[0.5, 0.2, 0.3], [0.3, 0.5, 0.2], [0.2, 0.3, 0.5]]\n", 134 | "B = [[0.5, 0.5], [0.4, 0.6], [0.7, 0.3]]\n", 135 | "# O = ['红', '白', '红', '红', '白', '红', '白', '白']\n", 136 | "O = ['红', '白', '红', '白'] #习题10.1的例子\n", 137 | "PI = [[0.2, 0.4, 0.4]]" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 8, 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "name": "stdout", 147 | "output_type": "stream", 148 | "text": [ 149 | "delta1(1)=pi1 * b1(o1)=0.20 * 0.50=0.10\n", 150 | "psis1(1)=0\n", 151 | "delta1(2)=pi2 * b2(o1)=0.40 * 0.40=0.16\n", 152 | "psis1(2)=0\n", 153 | "delta1(3)=pi3 * b3(o1)=0.40 * 0.70=0.28\n", 154 | "psis1(3)=0\n", 155 | "delta2(1)=max[delta1(j)aj1]b1(o2)=0.06*0.50=0.02800\n", 156 | "psis2(1)=argmax[delta1(j)aj1]=2\n", 157 | "delta2(2)=max[delta1(j)aj2]b2(o2)=0.08*0.60=0.05040\n", 158 | "psis2(2)=argmax[delta1(j)aj2]=2\n", 159 | "delta2(3)=max[delta1(j)aj3]b3(o2)=0.14*0.30=0.04200\n", 160 | "psis2(3)=argmax[delta1(j)aj3]=2\n", 161 | "delta3(1)=max[delta2(j)aj1]b1(o3)=0.02*0.50=0.00756\n", 162 | 
"psis3(1)=argmax[delta2(j)aj1]=1\n", 163 | "delta3(2)=max[delta2(j)aj2]b2(o3)=0.03*0.40=0.01008\n", 164 | "psis3(2)=argmax[delta2(j)aj2]=1\n", 165 | "delta3(3)=max[delta2(j)aj3]b3(o3)=0.02*0.70=0.01470\n", 166 | "psis3(3)=argmax[delta2(j)aj3]=2\n", 167 | "delta4(1)=max[delta3(j)aj1]b1(o4)=0.00*0.50=0.00189\n", 168 | "psis4(1)=argmax[delta3(j)aj1]=0\n", 169 | "delta4(2)=max[delta3(j)aj2]b2(o4)=0.01*0.60=0.00302\n", 170 | "psis4(2)=argmax[delta3(j)aj2]=1\n", 171 | "delta4(3)=max[delta3(j)aj3]b3(o4)=0.01*0.30=0.00220\n", 172 | "psis4(3)=argmax[delta3(j)aj3]=2\n", 173 | "[[0.1 0.028 0.00756 0.00189 ]\n", 174 | " [0.16 0.0504 0.01008 0.003024]\n", 175 | " [0.28 0.042 0.0147 0.002205]]\n", 176 | "[[0. 2. 1. 0.]\n", 177 | " [0. 2. 1. 1.]\n", 178 | " [0. 2. 2. 2.]]\n", 179 | "i4=argmax[deltaT(i)]=2\n", 180 | "i3=psis4(i4)=2\n", 181 | "i2=psis3(i3)=2\n", 182 | "i1=psis2(i2)=3\n", 183 | "[[2. 1. 1. 1.]]\n" 184 | ] 185 | } 186 | ], 187 | "source": [ 188 | "HMM = HiddenMarkov()\n", 189 | "# HMM.forward(Q, V, A, B, O, PI)\n", 190 | "# HMM.backward(Q, V, A, B, O, PI)\n", 191 | "HMM.viterbi(Q, V, A, B, O, PI)" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "metadata": {}, 197 | "source": [ 198 | "### 习题10.2" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 9, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "Q = [1, 2, 3]\n", 208 | "V = ['红', '白']\n", 209 | "A = [[0.5, 0.2, 0.3], [0.3, 0.5, 0.2], [0.2, 0.3, 0.5]]\n", 210 | "B = [[0.5, 0.5], [0.4, 0.6], [0.7, 0.3]]\n", 211 | "O = ['红', '白', '红', '红', '白', '红', '白', '白']\n", 212 | "PI = [[0.2, 0.3, 0.5]]" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 10, 218 | "metadata": {}, 219 | "outputs": [ 220 | { 221 | "name": "stdout", 222 | "output_type": "stream", 223 | "text": [ 224 | "alpha1(0)=p0b0b(o1)=0.100000\n", 225 | "alpha1(1)=p1b1b(o1)=0.120000\n", 226 | "alpha1(2)=p2b2b(o1)=0.350000\n", 227 | "alpha1(0)=[sigma 
alpha0(i)ai0]b0(o1)=0.078000\n", 228 | "alpha1(1)=[sigma alpha0(i)ai1]b1(o1)=0.111000\n", 229 | "alpha1(2)=[sigma alpha0(i)ai2]b2(o1)=0.068700\n", 230 | "alpha2(0)=[sigma alpha1(i)ai0]b0(o2)=0.043020\n", 231 | "alpha2(1)=[sigma alpha1(i)ai1]b1(o2)=0.036684\n", 232 | "alpha2(2)=[sigma alpha1(i)ai2]b2(o2)=0.055965\n", 233 | "alpha3(0)=[sigma alpha2(i)ai0]b0(o3)=0.021854\n", 234 | "alpha3(1)=[sigma alpha2(i)ai1]b1(o3)=0.017494\n", 235 | "alpha3(2)=[sigma alpha2(i)ai2]b2(o3)=0.033758\n", 236 | "alpha4(0)=[sigma alpha3(i)ai0]b0(o4)=0.011463\n", 237 | "alpha4(1)=[sigma alpha3(i)ai1]b1(o4)=0.013947\n", 238 | "alpha4(2)=[sigma alpha3(i)ai2]b2(o4)=0.008080\n", 239 | "alpha5(0)=[sigma alpha4(i)ai0]b0(o5)=0.005766\n", 240 | "alpha5(1)=[sigma alpha4(i)ai1]b1(o5)=0.004676\n", 241 | "alpha5(2)=[sigma alpha4(i)ai2]b2(o5)=0.007188\n", 242 | "alpha6(0)=[sigma alpha5(i)ai0]b0(o6)=0.002862\n", 243 | "alpha6(1)=[sigma alpha5(i)ai1]b1(o6)=0.003389\n", 244 | "alpha6(2)=[sigma alpha5(i)ai2]b2(o6)=0.001878\n", 245 | "alpha7(0)=[sigma alpha6(i)ai0]b0(o7)=0.001411\n", 246 | "alpha7(1)=[sigma alpha6(i)ai1]b1(o7)=0.001698\n", 247 | "alpha7(2)=[sigma alpha6(i)ai2]b2(o7)=0.000743\n", 248 | "beta8(0)=1\n", 249 | "beta8(1)=1\n", 250 | "beta8(2)=1\n", 251 | "beta7(1)=[sigma a1jbj(o8)]beta8(j)=(0.50*0.50*1.00+0.20*0.60*1.00+0.30*0.30*1.00+0)=0.460\n", 252 | "beta7(2)=[sigma a2jbj(o8)]beta8(j)=(0.30*0.50*1.00+0.50*0.60*1.00+0.20*0.30*1.00+0)=0.510\n", 253 | "beta7(3)=[sigma a3jbj(o8)]beta8(j)=(0.20*0.50*1.00+0.30*0.60*1.00+0.50*0.30*1.00+0)=0.430\n", 254 | "beta6(1)=[sigma a1jbj(o7)]beta7(j)=(0.50*0.50*0.46+0.20*0.60*0.51+0.30*0.30*0.43+0)=0.215\n", 255 | "beta6(2)=[sigma a2jbj(o7)]beta7(j)=(0.30*0.50*0.46+0.50*0.60*0.51+0.20*0.30*0.43+0)=0.248\n", 256 | "beta6(3)=[sigma a3jbj(o7)]beta7(j)=(0.20*0.50*0.46+0.30*0.60*0.51+0.50*0.30*0.43+0)=0.202\n", 257 | "beta5(1)=[sigma a1jbj(o6)]beta6(j)=(0.50*0.50*0.21+0.20*0.40*0.25+0.30*0.70*0.20+0)=0.116\n", 258 | "beta5(2)=[sigma 
a2jbj(o6)]beta6(j)=(0.30*0.50*0.21+0.50*0.40*0.25+0.20*0.70*0.20+0)=0.110\n", 259 | "beta5(3)=[sigma a3jbj(o6)]beta6(j)=(0.20*0.50*0.21+0.30*0.40*0.25+0.50*0.70*0.20+0)=0.122\n", 260 | "beta4(1)=[sigma a1jbj(o5)]beta5(j)=(0.50*0.50*0.12+0.20*0.60*0.11+0.30*0.30*0.12+0)=0.053\n", 261 | "beta4(2)=[sigma a2jbj(o5)]beta5(j)=(0.30*0.50*0.12+0.50*0.60*0.11+0.20*0.30*0.12+0)=0.058\n", 262 | "beta4(3)=[sigma a3jbj(o5)]beta5(j)=(0.20*0.50*0.12+0.30*0.60*0.11+0.50*0.30*0.12+0)=0.050\n", 263 | "beta3(1)=[sigma a1jbj(o4)]beta4(j)=(0.50*0.50*0.05+0.20*0.40*0.06+0.30*0.70*0.05+0)=0.028\n", 264 | "beta3(2)=[sigma a2jbj(o4)]beta4(j)=(0.30*0.50*0.05+0.50*0.40*0.06+0.20*0.70*0.05+0)=0.026\n", 265 | "beta3(3)=[sigma a3jbj(o4)]beta4(j)=(0.20*0.50*0.05+0.30*0.40*0.06+0.50*0.70*0.05+0)=0.030\n", 266 | "beta2(1)=[sigma a1jbj(o3)]beta3(j)=(0.50*0.50*0.03+0.20*0.40*0.03+0.30*0.70*0.03+0)=0.015\n", 267 | "beta2(2)=[sigma a2jbj(o3)]beta3(j)=(0.30*0.50*0.03+0.50*0.40*0.03+0.20*0.70*0.03+0)=0.014\n", 268 | "beta2(3)=[sigma a3jbj(o3)]beta3(j)=(0.20*0.50*0.03+0.30*0.40*0.03+0.50*0.70*0.03+0)=0.016\n", 269 | "beta1(1)=[sigma a1jbj(o2)]beta2(j)=(0.50*0.50*0.02+0.20*0.60*0.01+0.30*0.30*0.02+0)=0.007\n", 270 | "beta1(2)=[sigma a2jbj(o2)]beta2(j)=(0.30*0.50*0.02+0.50*0.60*0.01+0.20*0.30*0.02+0)=0.007\n", 271 | "beta1(3)=[sigma a3jbj(o2)]beta2(j)=(0.20*0.50*0.02+0.30*0.60*0.01+0.50*0.30*0.02+0)=0.006\n", 272 | "P(O|lambda)=0.2*0.5*0.00698+0.3*0.4*0.00741+0.5*0.7*0.00647+0=0.003852\n" 273 | ] 274 | } 275 | ], 276 | "source": [ 277 | "HMM.forward(Q, V, A, B, O, PI)\n", 278 | "HMM.backward(Q, V, A, B, O, PI)" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [] 287 | } 288 | ], 289 | "metadata": { 290 | "kernelspec": { 291 | "display_name": "Python 3", 292 | "language": "python", 293 | "name": "python3" 294 | }, 295 | "language_info": { 296 | "codemirror_mode": { 297 | "name": "ipython", 298 | "version": 3 299 | 
}, 300 | "file_extension": ".py", 301 | "mimetype": "text/x-python", 302 | "name": "python", 303 | "nbconvert_exporter": "python", 304 | "pygments_lexer": "ipython3", 305 | "version": "3.6.2" 306 | } 307 | }, 308 | "nbformat": 4, 309 | "nbformat_minor": 2 310 | } 311 | -------------------------------------------------------------------------------- /code/第11章 条件随机场(CRF)/CRF.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "原文代码作者:https://blog.csdn.net/GrinAndBearIt/article/details/79229803\n", 8 | "\n", 9 | "中文注释制作:机器学习初学者(微信公众号:ID:ai-start-com)\n", 10 | "\n", 11 | "配置环境:python 3.6\n", 12 | "\n", 13 | "代码全部测试通过。\n", 14 | "![gongzhong](../gongzhong.jpg)" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "# 第11章 条件随机场\n" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "### 例11.1" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 1, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "from numpy import *" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 4, 43 | "metadata": {}, 44 | "outputs": [ 45 | { 46 | "name": "stdout", 47 | "output_type": "stream", 48 | "text": [ 49 | "24.532530197109345\n", 50 | "24.532530197109352\n" 51 | ] 52 | } 53 | ], 54 | "source": [ 55 | "#这里定义T为转移矩阵:行代表前一个y,元素(i,j)代表由状态i转到状态j的权值,Tx矩阵的x对应于时间序列\n", 56 | "#这里将书上的转移特征转换为如下以时间轴为区别的多维列表,维度为输出的维度\n", 57 | "T1=[[0.6,1],[1,0]];T2=[[0,1],[1,0.2]]\n", 58 | "#将书上的状态特征同样转换成列表,第一个是y1的未规范化概率,第二个为y2的未规范化概率\n", 59 | "S0=[1,0.5];S1=[0.8,0.5];S2=[0.8,0.5]\n", 60 | "Y=[1,2,2] #即书上例11.1需要计算的非规范化条件概率的标记序列\n", 61 | "Y=array(Y)-1 #这里为了将数与索引相对应,即从零开始\n", 62 | "P=exp(S0[Y[0]])\n", 63 | "for i in range(1,len(Y)):\n", 64 | " P *= exp((eval('S%d' % i)[Y[i]])+eval('T%d' % i)[Y[i-1]][Y[i]])\n", 65 | "print(P)\n", 66 | "print(exp(3.2))\n" 67 | ] 68 | 
}, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "### 例11.2" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 3, 79 | "metadata": {}, 80 | "outputs": [ 81 | { 82 | "name": "stdout", 83 | "output_type": "stream", 84 | "text": [ 85 | "非规范化概率 24.532530197109345\n" 86 | ] 87 | } 88 | ], 89 | "source": [ 90 | "#这里根据例11.2的启发整合为一个矩阵\n", 91 | "F0=S0;F1=T1+array(S1*len(T1)).reshape(shape(T1));F2=T2+array(S2*len(T2)).reshape(shape(T2))\n", 92 | "Y=[1,2,2] #即书上例11.1需要计算的非规范化条件概率的标记序列\n", 93 | "Y=array(Y)-1\n", 94 | "\n", 95 | "P=exp(F0[Y[0]])\n", 96 | "Sum=P\n", 97 | "for i in range(1,len(Y)):\n", 98 | " PIter=exp((eval('F%d' % i)[Y[i-1]][Y[i]]))\n", 99 | " P *= PIter\n", 100 | " Sum += PIter\n", 101 | "print('非规范化概率',P)\n" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [] 110 | } 111 | ], 112 | "metadata": { 113 | "kernelspec": { 114 | "display_name": "Python 3", 115 | "language": "python", 116 | "name": "python3" 117 | }, 118 | "language_info": { 119 | "codemirror_mode": { 120 | "name": "ipython", 121 | "version": 3 122 | }, 123 | "file_extension": ".py", 124 | "mimetype": "text/x-python", 125 | "name": "python", 126 | "nbconvert_exporter": "python", 127 | "pygments_lexer": "ipython3", 128 | "version": "3.6.2" 129 | } 130 | }, 131 | "nbformat": 4, 132 | "nbformat_minor": 2 133 | } 134 | -------------------------------------------------------------------------------- /code/第3章 k近邻法(KNearestNeighbors)/KDT.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from math import sqrt 3 | import pandas as pd 4 | from sklearn.datasets import load_iris 5 | import matplotlib.pyplot as plt 6 | from sklearn.model_selection import train_test_split 7 | 8 | iris = load_iris() 9 | df = pd.DataFrame(iris.data, columns=iris.feature_names) 10 | df['label'] = iris.target 11 | df.columns = 
['sepal length', 'sepal width', 'petal length', 'petal width', 'label'] 12 | 13 | data = np.array(df.iloc[:100, [0, 1, -1]]) 14 | train, test = train_test_split(data, test_size=0.1) 15 | x0 = np.array([x0 for i, x0 in enumerate(train) if train[i][-1] == 0]) 16 | x1 = np.array([x1 for i, x1 in enumerate(train) if train[i][-1] == 1]) 17 | 18 | 19 | def show_train(): 20 | plt.scatter(x0[:, 0], x0[:, 1], c='pink', label='[0]') 21 | plt.scatter(x1[:, 0], x1[:, 1], c='orange', label='[1]') 22 | plt.xlabel('sepal length') 23 | plt.ylabel('sepal width') 24 | 25 | 26 | class Node: 27 | def __init__(self, data, depth=0, lchild=None, rchild=None): 28 | self.data = data 29 | self.depth = depth 30 | self.lchild = lchild 31 | self.rchild = rchild 32 | 33 | 34 | class KdTree: 35 | def __init__(self): 36 | self.KdTree = None 37 | self.n = 0 38 | self.nearest = None 39 | 40 | def create(self, dataSet, depth=0): 41 | if len(dataSet) > 0: 42 | m, n = np.shape(dataSet) 43 | self.n = n - 1 44 | axis = depth % self.n 45 | mid = int(m / 2) 46 | dataSetcopy = sorted(dataSet, key=lambda x: x[axis]) 47 | node = Node(dataSetcopy[mid], depth) 48 | if depth == 0: 49 | self.KdTree = node 50 | node.lchild = self.create(dataSetcopy[:mid], depth+1) 51 | node.rchild = self.create(dataSetcopy[mid+1:], depth+1) 52 | return node 53 | return None 54 | 55 | def preOrder(self, node): 56 | if node is not None: 57 | print(node.depth, node.data) 58 | self.preOrder(node.lchild) 59 | self.preOrder(node.rchild) 60 | 61 | def search(self, x, count=1): 62 | nearest = [] 63 | for i in range(count): 64 | nearest.append([-1, None]) 65 | self.nearest = np.array(nearest) 66 | 67 | def recurve(node): 68 | if node is not None: 69 | axis = node.depth % self.n 70 | daxis = x[axis] - node.data[axis] 71 | if daxis < 0: 72 | recurve(node.lchild) 73 | else: 74 | recurve(node.rchild) 75 | 76 | dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(x, node.data))) 77 | for i, d in enumerate(self.nearest): 78 | if d[0] < 0 or dist < 
d[0]: 79 | self.nearest = np.insert(self.nearest, i, [dist, node], axis=0) 80 | self.nearest = self.nearest[:-1] 81 | break 82 | 83 | n = list(self.nearest[:, 0]).count(-1) 84 | if self.nearest[-n-1, 0] > abs(daxis): 85 | if daxis < 0: 86 | recurve(node.rchild) 87 | else: 88 | recurve(node.lchild) 89 | 90 | recurve(self.KdTree) 91 | 92 | knn = self.nearest[:, 1] 93 | belong = [] 94 | for i in knn: 95 | belong.append(i.data[-1]) 96 | b = max(set(belong), key=belong.count) 97 | 98 | return self.nearest, b 99 | 100 | 101 | kdt = KdTree() 102 | kdt.create(train) 103 | kdt.preOrder(kdt.KdTree) 104 | 105 | score = 0 106 | for x in test: 107 | input('press Enter to show next:') 108 | show_train() 109 | plt.scatter(x[0], x[1], c='red', marker='x') # 测试点 110 | near, belong = kdt.search(x[:-1], 5) # 设置临近点的个数 111 | if belong == x[-1]: 112 | score += 1 113 | print("test:") 114 | print(x, "predict:", belong) 115 | print("nearest:") 116 | for n in near: 117 | print(n[1].data, "dist:", n[0]) 118 | plt.scatter(n[1].data[0], n[1].data[1], c='green', marker='+') # k个最近邻点 119 | plt.legend() 120 | plt.show() 121 | 122 | score /= len(test) 123 | print("score:", score) 124 | -------------------------------------------------------------------------------- /code/第4章 朴素贝叶斯(NaiveBayes)/GaussianNB.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "原文代码作者:https://github.com/wzyonggege/statistical-learning-method\n", 8 | "\n", 9 | "中文注释制作:机器学习初学者(微信公众号:ID:ai-start-com)\n", 10 | "\n", 11 | "配置环境:python 3.6\n", 12 | "\n", 13 | "代码全部测试通过。\n", 14 | "![gongzhong](../gongzhong.jpg)" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "# 第4章 朴素贝叶斯" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "基于贝叶斯定理与特征条件独立假设的分类方法。\n", 29 | "\n", 30 | "模型:\n", 31 | "\n", 32 | "- 高斯模型\n", 33 | 
"- 多项式模型\n", 34 | "- 伯努利模型" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 1, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "import numpy as np\n", 44 | "import pandas as pd\n", 45 | "import matplotlib.pyplot as plt\n", 46 | "%matplotlib inline\n", 47 | "\n", 48 | "from sklearn.datasets import load_iris\n", 49 | "from sklearn.model_selection import train_test_split\n", 50 | "\n", 51 | "from collections import Counter\n", 52 | "import math" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 2, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "# data\n", 62 | "def create_data():\n", 63 | " iris = load_iris()\n", 64 | " df = pd.DataFrame(iris.data, columns=iris.feature_names)\n", 65 | " df['label'] = iris.target\n", 66 | " df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n", 67 | " data = np.array(df.iloc[:100, :])\n", 68 | " # print(data)\n", 69 | " return data[:,:-1], data[:,-1]" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 3, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "X, y = create_data()\n", 79 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 4, 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "data": { 89 | "text/plain": [ 90 | "(array([4.6, 3.4, 1.4, 0.3]), 0.0)" 91 | ] 92 | }, 93 | "execution_count": 4, 94 | "metadata": {}, 95 | "output_type": "execute_result" 96 | } 97 | ], 98 | "source": [ 99 | "X_test[0], y_test[0]" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "参考:https://machinelearningmastery.com/naive-bayes-classifier-scratch-python/\n", 107 | "\n", 108 | "## GaussianNB 高斯朴素贝叶斯\n", 109 | "\n", 110 | "特征的可能性被假设为高斯\n", 111 | "\n", 112 | "概率密度函数:\n", 113 | "$$P(x_i | 
y_k)=\\frac{1}{\\sqrt{2\\pi\\sigma^2_{yk}}}exp(-\\frac{(x_i-\\mu_{yk})^2}{2\\sigma^2_{yk}})$$\n", 114 | "\n", 115 | "数学期望(mean):$\\mu$,方差:$\\sigma^2=\\frac{\\sum(X-\\mu)^2}{N}$" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 5, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "class NaiveBayes:\n", 125 | " def __init__(self):\n", 126 | " self.model = None\n", 127 | "\n", 128 | " # 数学期望\n", 129 | " @staticmethod\n", 130 | " def mean(X):\n", 131 | " return sum(X) / float(len(X))\n", 132 | "\n", 133 | " # 标准差(方差)\n", 134 | " def stdev(self, X):\n", 135 | " avg = self.mean(X)\n", 136 | " return math.sqrt(sum([pow(x-avg, 2) for x in X]) / float(len(X)))\n", 137 | "\n", 138 | " # 概率密度函数\n", 139 | " def gaussian_probability(self, x, mean, stdev):\n", 140 | " exponent = math.exp(-(math.pow(x-mean,2)/(2*math.pow(stdev,2))))\n", 141 | " return (1 / (math.sqrt(2*math.pi) * stdev)) * exponent\n", 142 | "\n", 143 | " # 处理X_train\n", 144 | " def summarize(self, train_data):\n", 145 | " summaries = [(self.mean(i), self.stdev(i)) for i in zip(*train_data)]\n", 146 | " return summaries\n", 147 | "\n", 148 | " # 分类别求出数学期望和标准差\n", 149 | " def fit(self, X, y):\n", 150 | " labels = list(set(y))\n", 151 | " data = {label:[] for label in labels}\n", 152 | " for f, label in zip(X, y):\n", 153 | " data[label].append(f)\n", 154 | " self.model = {label: self.summarize(value) for label, value in data.items()}\n", 155 | " return 'gaussianNB train done!'\n", 156 | "\n", 157 | " # 计算概率\n", 158 | " def calculate_probabilities(self, input_data):\n", 159 | " # summaries:{0.0: [(5.0, 0.37),(3.42, 0.40)], 1.0: [(5.8, 0.449),(2.7, 0.27)]}\n", 160 | " # input_data:[1.1, 2.2]\n", 161 | " probabilities = {}\n", 162 | " for label, value in self.model.items():\n", 163 | " probabilities[label] = 1\n", 164 | " for i in range(len(value)):\n", 165 | " mean, stdev = value[i]\n", 166 | " probabilities[label] *= self.gaussian_probability(input_data[i], mean, 
stdev)\n", 167 | " return probabilities\n", 168 | "\n", 169 | " # 类别\n", 170 | " def predict(self, X_test):\n", 171 | " # {0.0: 2.9680340789325763e-27, 1.0: 3.5749783019849535e-26}\n", 172 | " label = sorted(self.calculate_probabilities(X_test).items(), key=lambda x: x[-1])[-1][0]\n", 173 | " return label\n", 174 | "\n", 175 | " def score(self, X_test, y_test):\n", 176 | " right = 0\n", 177 | " for X, y in zip(X_test, y_test):\n", 178 | " label = self.predict(X)\n", 179 | " if label == y:\n", 180 | " right += 1\n", 181 | "\n", 182 | " return right / float(len(X_test))" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 6, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "model = NaiveBayes()" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 7, 197 | "metadata": {}, 198 | "outputs": [ 199 | { 200 | "data": { 201 | "text/plain": [ 202 | "'gaussianNB train done!'" 203 | ] 204 | }, 205 | "execution_count": 7, 206 | "metadata": {}, 207 | "output_type": "execute_result" 208 | } 209 | ], 210 | "source": [ 211 | "model.fit(X_train, y_train)" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 8, 217 | "metadata": {}, 218 | "outputs": [ 219 | { 220 | "name": "stdout", 221 | "output_type": "stream", 222 | "text": [ 223 | "0.0\n" 224 | ] 225 | } 226 | ], 227 | "source": [ 228 | "print(model.predict([4.4, 3.2, 1.3, 0.2]))" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 9, 234 | "metadata": {}, 235 | "outputs": [ 236 | { 237 | "data": { 238 | "text/plain": [ 239 | "1.0" 240 | ] 241 | }, 242 | "execution_count": 9, 243 | "metadata": {}, 244 | "output_type": "execute_result" 245 | } 246 | ], 247 | "source": [ 248 | "model.score(X_test, y_test)" 249 | ] 250 | }, 251 | { 252 | "cell_type": "markdown", 253 | "metadata": { 254 | "collapsed": true 255 | }, 256 | "source": [ 257 | "scikit-learn实例\n", 258 | "\n", 259 | "# sklearn.naive_bayes" 260 | ] 261 | }, 
262 | { 263 | "cell_type": "code", 264 | "execution_count": 10, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "from sklearn.naive_bayes import GaussianNB" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 11, 274 | "metadata": {}, 275 | "outputs": [ 276 | { 277 | "data": { 278 | "text/plain": [ 279 | "GaussianNB(priors=None)" 280 | ] 281 | }, 282 | "execution_count": 11, 283 | "metadata": {}, 284 | "output_type": "execute_result" 285 | } 286 | ], 287 | "source": [ 288 | "clf = GaussianNB()\n", 289 | "clf.fit(X_train, y_train)" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 12, 295 | "metadata": {}, 296 | "outputs": [ 297 | { 298 | "data": { 299 | "text/plain": [ 300 | "1.0" 301 | ] 302 | }, 303 | "execution_count": 12, 304 | "metadata": {}, 305 | "output_type": "execute_result" 306 | } 307 | ], 308 | "source": [ 309 | "clf.score(X_test, y_test)" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": 14, 315 | "metadata": {}, 316 | "outputs": [ 317 | { 318 | "data": { 319 | "text/plain": [ 320 | "array([0.])" 321 | ] 322 | }, 323 | "execution_count": 14, 324 | "metadata": {}, 325 | "output_type": "execute_result" 326 | } 327 | ], 328 | "source": [ 329 | "clf.predict([[4.4, 3.2, 1.3, 0.2]])" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 15, 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [ 338 | "from sklearn.naive_bayes import BernoulliNB, MultinomialNB # 伯努利模型和多项式模型" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": null, 344 | "metadata": { 345 | "collapsed": true 346 | }, 347 | "outputs": [], 348 | "source": [] 349 | } 350 | ], 351 | "metadata": { 352 | "kernelspec": { 353 | "display_name": "Python 3", 354 | "language": "python", 355 | "name": "python3" 356 | }, 357 | "language_info": { 358 | "codemirror_mode": { 359 | "name": "ipython", 360 | "version": 3 361 | }, 362 | "file_extension": ".py", 
363 | "mimetype": "text/x-python", 364 | "name": "python", 365 | "nbconvert_exporter": "python", 366 | "pygments_lexer": "ipython3", 367 | "version": "3.6.2" 368 | } 369 | }, 370 | "nbformat": 4, 371 | "nbformat_minor": 2 372 | } 373 | -------------------------------------------------------------------------------- /code/第5章 决策树(DecisonTree)/DT.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "原文代码作者:https://github.com/wzyonggege/statistical-learning-method\n", 8 | "\n", 9 | "中文注释制作:机器学习初学者(微信公众号:ID:ai-start-com)\n", 10 | "\n", 11 | "配置环境:python 3.6\n", 12 | "\n", 13 | "代码全部测试通过。\n", 14 | "![gongzhong](../gongzhong.jpg)" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "# 第5章 决策树" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "- ID3(基于信息增益)\n", 29 | "- C4.5(基于信息增益比)\n", 30 | "- CART(gini指数)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "#### entropy:$H(X) = -\\sum_{i=1}^{n}p_i\\log{p_i}$\n", 38 | "\n", 39 | "#### conditional entropy: $H(X|Y)=-\\sum{P(X,Y)}\\log{P(X|Y)}$\n", 40 | "\n", 41 | "#### information gain: $g(D, A)=H(D)-H(D|A)$\n", 42 | "\n", 43 | "#### information gain ratio: $g_R(D, A) = \\frac{g(D,A)}{H_A(D)}$\n", 44 | "\n", 45 | "#### gini index:$Gini(D)=\\sum_{k=1}^{K}p_k(1-p_k)=1-\\sum_{k=1}^{K}p_k^2$" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 1, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "import numpy as np\n", 55 | "import pandas as pd\n", 56 | "import matplotlib.pyplot as plt\n", 57 | "%matplotlib inline\n", 58 | "\n", 59 | "from sklearn.datasets import load_iris\n", 60 | "from sklearn.model_selection import train_test_split\n", 61 | "\n", 62 | "from collections import Counter\n", 63 | "import math\n", 64 | "from math import 
log\n", 65 | "\n", 66 | "import pprint" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "### 书上题目5.1" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 2, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "# 书上题目5.1\n", 83 | "def create_data():\n", 84 | " datasets = [['青年', '否', '否', '一般', '否'],\n", 85 | " ['青年', '否', '否', '好', '否'],\n", 86 | " ['青年', '是', '否', '好', '是'],\n", 87 | " ['青年', '是', '是', '一般', '是'],\n", 88 | " ['青年', '否', '否', '一般', '否'],\n", 89 | " ['中年', '否', '否', '一般', '否'],\n", 90 | " ['中年', '否', '否', '好', '否'],\n", 91 | " ['中年', '是', '是', '好', '是'],\n", 92 | " ['中年', '否', '是', '非常好', '是'],\n", 93 | " ['中年', '否', '是', '非常好', '是'],\n", 94 | " ['老年', '否', '是', '非常好', '是'],\n", 95 | " ['老年', '否', '是', '好', '是'],\n", 96 | " ['老年', '是', '否', '好', '是'],\n", 97 | " ['老年', '是', '否', '非常好', '是'],\n", 98 | " ['老年', '否', '否', '一般', '否'],\n", 99 | " ]\n", 100 | " labels = [u'年龄', u'有工作', u'有自己的房子', u'信贷情况', u'类别']\n", 101 | " # 返回数据集和每个维度的名称\n", 102 | " return datasets, labels" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 3, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "datasets, labels = create_data()" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 4, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "train_data = pd.DataFrame(datasets, columns=labels)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 5, 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "data": { 130 | "text/html": [ 131 | "
\n", 132 | "\n", 145 | "\n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | "
年龄有工作有自己的房子信贷情况类别
0青年一般
1青年
2青年
3青年一般
4青年一般
5中年一般
6中年
7中年
8中年非常好
9中年非常好
10老年非常好
11老年
12老年
13老年非常好
14老年一般
\n", 279 | "
" 280 | ], 281 | "text/plain": [ 282 | " 年龄 有工作 有自己的房子 信贷情况 类别\n", 283 | "0 青年 否 否 一般 否\n", 284 | "1 青年 否 否 好 否\n", 285 | "2 青年 是 否 好 是\n", 286 | "3 青年 是 是 一般 是\n", 287 | "4 青年 否 否 一般 否\n", 288 | "5 中年 否 否 一般 否\n", 289 | "6 中年 否 否 好 否\n", 290 | "7 中年 是 是 好 是\n", 291 | "8 中年 否 是 非常好 是\n", 292 | "9 中年 否 是 非常好 是\n", 293 | "10 老年 否 是 非常好 是\n", 294 | "11 老年 否 是 好 是\n", 295 | "12 老年 是 否 好 是\n", 296 | "13 老年 是 否 非常好 是\n", 297 | "14 老年 否 否 一般 否" 298 | ] 299 | }, 300 | "execution_count": 5, 301 | "metadata": {}, 302 | "output_type": "execute_result" 303 | } 304 | ], 305 | "source": [ 306 | "train_data" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": 6, 312 | "metadata": {}, 313 | "outputs": [], 314 | "source": [ 315 | "# 熵\n", 316 | "def calc_ent(datasets):\n", 317 | " data_length = len(datasets)\n", 318 | " label_count = {}\n", 319 | " for i in range(data_length):\n", 320 | " label = datasets[i][-1]\n", 321 | " if label not in label_count:\n", 322 | " label_count[label] = 0\n", 323 | " label_count[label] += 1\n", 324 | " ent = -sum([(p/data_length)*log(p/data_length, 2) for p in label_count.values()])\n", 325 | " return ent\n", 326 | "\n", 327 | "# 经验条件熵\n", 328 | "def cond_ent(datasets, axis=0):\n", 329 | " data_length = len(datasets)\n", 330 | " feature_sets = {}\n", 331 | " for i in range(data_length):\n", 332 | " feature = datasets[i][axis]\n", 333 | " if feature not in feature_sets:\n", 334 | " feature_sets[feature] = []\n", 335 | " feature_sets[feature].append(datasets[i])\n", 336 | " cond_ent = sum([(len(p)/data_length)*calc_ent(p) for p in feature_sets.values()])\n", 337 | " return cond_ent\n", 338 | "\n", 339 | "# 信息增益\n", 340 | "def info_gain(ent, cond_ent):\n", 341 | " return ent - cond_ent\n", 342 | "\n", 343 | "def info_gain_train(datasets):\n", 344 | " count = len(datasets[0]) - 1\n", 345 | " ent = calc_ent(datasets)\n", 346 | " best_feature = []\n", 347 | " for c in range(count):\n", 348 | " c_info_gain = info_gain(ent, 
cond_ent(datasets, axis=c))\n", 349 | " best_feature.append((c, c_info_gain))\n", 350 | " print('特征({}) - info_gain - {:.3f}'.format(labels[c], c_info_gain))\n", 351 | " # 比较大小\n", 352 | " best_ = max(best_feature, key=lambda x: x[-1])\n", 353 | " return '特征({})的信息增益最大,选择为根节点特征'.format(labels[best_[0]])" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": 7, 359 | "metadata": {}, 360 | "outputs": [ 361 | { 362 | "name": "stdout", 363 | "output_type": "stream", 364 | "text": [ 365 | "特征(年龄) - info_gain - 0.083\n", 366 | "特征(有工作) - info_gain - 0.324\n", 367 | "特征(有自己的房子) - info_gain - 0.420\n", 368 | "特征(信贷情况) - info_gain - 0.363\n" 369 | ] 370 | }, 371 | { 372 | "data": { 373 | "text/plain": [ 374 | "'特征(有自己的房子)的信息增益最大,选择为根节点特征'" 375 | ] 376 | }, 377 | "execution_count": 7, 378 | "metadata": {}, 379 | "output_type": "execute_result" 380 | } 381 | ], 382 | "source": [ 383 | "info_gain_train(np.array(datasets))" 384 | ] 385 | }, 386 | { 387 | "cell_type": "markdown", 388 | "metadata": { 389 | "collapsed": true 390 | }, 391 | "source": [ 392 | "---\n", 393 | "\n", 394 | "利用ID3算法生成决策树,例5.3" 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": 8, 400 | "metadata": {}, 401 | "outputs": [], 402 | "source": [ 403 | "# 定义节点类 二叉树\n", 404 | "class Node:\n", 405 | " def __init__(self, root=True, label=None, feature_name=None, feature=None):\n", 406 | " self.root = root\n", 407 | " self.label = label\n", 408 | " self.feature_name = feature_name\n", 409 | " self.feature = feature\n", 410 | " self.tree = {}\n", 411 | " self.result = {'label:': self.label, 'feature': self.feature, 'tree': self.tree}\n", 412 | "\n", 413 | " def __repr__(self):\n", 414 | " return '{}'.format(self.result)\n", 415 | "\n", 416 | " def add_node(self, val, node):\n", 417 | " self.tree[val] = node\n", 418 | "\n", 419 | " def predict(self, features):\n", 420 | " if self.root is True:\n", 421 | " return self.label\n", 422 | " return 
self.tree[features[self.feature]].predict(features)\n", 423 | " \n", 424 | "class DTree:\n", 425 | " def __init__(self, epsilon=0.1):\n", 426 | " self.epsilon = epsilon\n", 427 | " self._tree = {}\n", 428 | "\n", 429 | " # 熵\n", 430 | " @staticmethod\n", 431 | " def calc_ent(datasets):\n", 432 | " data_length = len(datasets)\n", 433 | " label_count = {}\n", 434 | " for i in range(data_length):\n", 435 | " label = datasets[i][-1]\n", 436 | " if label not in label_count:\n", 437 | " label_count[label] = 0\n", 438 | " label_count[label] += 1\n", 439 | " ent = -sum([(p/data_length)*log(p/data_length, 2) for p in label_count.values()])\n", 440 | " return ent\n", 441 | "\n", 442 | " # 经验条件熵\n", 443 | " def cond_ent(self, datasets, axis=0):\n", 444 | " data_length = len(datasets)\n", 445 | " feature_sets = {}\n", 446 | " for i in range(data_length):\n", 447 | " feature = datasets[i][axis]\n", 448 | " if feature not in feature_sets:\n", 449 | " feature_sets[feature] = []\n", 450 | " feature_sets[feature].append(datasets[i])\n", 451 | " cond_ent = sum([(len(p)/data_length)*self.calc_ent(p) for p in feature_sets.values()])\n", 452 | " return cond_ent\n", 453 | "\n", 454 | " # 信息增益\n", 455 | " @staticmethod\n", 456 | " def info_gain(ent, cond_ent):\n", 457 | " return ent - cond_ent\n", 458 | "\n", 459 | " def info_gain_train(self, datasets):\n", 460 | " count = len(datasets[0]) - 1\n", 461 | " ent = self.calc_ent(datasets)\n", 462 | " best_feature = []\n", 463 | " for c in range(count):\n", 464 | " c_info_gain = self.info_gain(ent, self.cond_ent(datasets, axis=c))\n", 465 | " best_feature.append((c, c_info_gain))\n", 466 | " # 比较大小\n", 467 | " best_ = max(best_feature, key=lambda x: x[-1])\n", 468 | " return best_\n", 469 | "\n", 470 | " def train(self, train_data):\n", 471 | " \"\"\"\n", 472 | " input:数据集D(DataFrame格式),特征集A,阈值eta\n", 473 | " output:决策树T\n", 474 | " \"\"\"\n", 475 | " _, y_train, features = train_data.iloc[:, :-1], train_data.iloc[:, -1], 
train_data.columns[:-1]\n", 476 | " # 1,若D中实例属于同一类Ck,则T为单节点树,并将类Ck作为结点的类标记,返回T\n", 477 | " if len(y_train.value_counts()) == 1:\n", 478 | " return Node(root=True,\n", 479 | " label=y_train.iloc[0])\n", 480 | "\n", 481 | " # 2, 若A为空,则T为单节点树,将D中实例树最大的类Ck作为该节点的类标记,返回T\n", 482 | " if len(features) == 0:\n", 483 | " return Node(root=True, label=y_train.value_counts().sort_values(ascending=False).index[0])\n", 484 | "\n", 485 | " # 3,计算最大信息增益 同5.1,Ag为信息增益最大的特征\n", 486 | " max_feature, max_info_gain = self.info_gain_train(np.array(train_data))\n", 487 | " max_feature_name = features[max_feature]\n", 488 | "\n", 489 | " # 4,Ag的信息增益小于阈值eta,则置T为单节点树,并将D中是实例数最大的类Ck作为该节点的类标记,返回T\n", 490 | " if max_info_gain < self.epsilon:\n", 491 | " return Node(root=True, label=y_train.value_counts().sort_values(ascending=False).index[0])\n", 492 | "\n", 493 | " # 5,构建Ag子集\n", 494 | " node_tree = Node(root=False, feature_name=max_feature_name, feature=max_feature)\n", 495 | "\n", 496 | " feature_list = train_data[max_feature_name].value_counts().index\n", 497 | " for f in feature_list:\n", 498 | " sub_train_df = train_data.loc[train_data[max_feature_name] == f].drop([max_feature_name], axis=1)\n", 499 | "\n", 500 | " # 6, 递归生成树\n", 501 | " sub_tree = self.train(sub_train_df)\n", 502 | " node_tree.add_node(f, sub_tree)\n", 503 | "\n", 504 | " # pprint.pprint(node_tree.tree)\n", 505 | " return node_tree\n", 506 | "\n", 507 | " def fit(self, train_data):\n", 508 | " self._tree = self.train(train_data)\n", 509 | " return self._tree\n", 510 | "\n", 511 | " def predict(self, X_test):\n", 512 | " return self._tree.predict(X_test)" 513 | ] 514 | }, 515 | { 516 | "cell_type": "code", 517 | "execution_count": 9, 518 | "metadata": {}, 519 | "outputs": [], 520 | "source": [ 521 | "datasets, labels = create_data()\n", 522 | "data_df = pd.DataFrame(datasets, columns=labels)\n", 523 | "dt = DTree()\n", 524 | "tree = dt.fit(data_df)" 525 | ] 526 | }, 527 | { 528 | "cell_type": "code", 529 | 
"execution_count": 10, 530 | "metadata": { 531 | "scrolled": true 532 | }, 533 | "outputs": [ 534 | { 535 | "data": { 536 | "text/plain": [ 537 | "{'label:': None, 'feature': 2, 'tree': {'否': {'label:': None, 'feature': 1, 'tree': {'否': {'label:': '否', 'feature': None, 'tree': {}}, '是': {'label:': '是', 'feature': None, 'tree': {}}}}, '是': {'label:': '是', 'feature': None, 'tree': {}}}}" 538 | ] 539 | }, 540 | "execution_count": 10, 541 | "metadata": {}, 542 | "output_type": "execute_result" 543 | } 544 | ], 545 | "source": [ 546 | "tree" 547 | ] 548 | }, 549 | { 550 | "cell_type": "code", 551 | "execution_count": 11, 552 | "metadata": {}, 553 | "outputs": [ 554 | { 555 | "data": { 556 | "text/plain": [ 557 | "'否'" 558 | ] 559 | }, 560 | "execution_count": 11, 561 | "metadata": {}, 562 | "output_type": "execute_result" 563 | } 564 | ], 565 | "source": [ 566 | "dt.predict(['老年', '否', '否', '一般'])" 567 | ] 568 | }, 569 | { 570 | "cell_type": "markdown", 571 | "metadata": {}, 572 | "source": [ 573 | "---\n", 574 | "\n", 575 | "## sklearn.tree.DecisionTreeClassifier\n", 576 | "\n", 577 | "### criterion : string, optional (default=”gini”)\n", 578 | "The function to measure the quality of a split. Supported criteria are “gini” for the Gini impurity and “entropy” for the information gain." 
579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "execution_count": 12, 584 | "metadata": {}, 585 | "outputs": [], 586 | "source": [ 587 | "# data\n", 588 | "def create_data():\n", 589 | " iris = load_iris()\n", 590 | " df = pd.DataFrame(iris.data, columns=iris.feature_names)\n", 591 | " df['label'] = iris.target\n", 592 | " df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n", 593 | " data = np.array(df.iloc[:100, [0, 1, -1]])\n", 594 | " # print(data)\n", 595 | " return data[:,:2], data[:,-1]\n", 596 | "\n", 597 | "X, y = create_data()\n", 598 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)" 599 | ] 600 | }, 601 | { 602 | "cell_type": "code", 603 | "execution_count": 13, 604 | "metadata": {}, 605 | "outputs": [], 606 | "source": [ 607 | "from sklearn.tree import DecisionTreeClassifier\n", 608 | "\n", 609 | "from sklearn.tree import export_graphviz\n", 610 | "import graphviz" 611 | ] 612 | }, 613 | { 614 | "cell_type": "code", 615 | "execution_count": 14, 616 | "metadata": {}, 617 | "outputs": [ 618 | { 619 | "data": { 620 | "text/plain": [ 621 | "DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,\n", 622 | " max_features=None, max_leaf_nodes=None,\n", 623 | " min_impurity_decrease=0.0, min_impurity_split=None,\n", 624 | " min_samples_leaf=1, min_samples_split=2,\n", 625 | " min_weight_fraction_leaf=0.0, presort=False, random_state=None,\n", 626 | " splitter='best')" 627 | ] 628 | }, 629 | "execution_count": 14, 630 | "metadata": {}, 631 | "output_type": "execute_result" 632 | } 633 | ], 634 | "source": [ 635 | "clf = DecisionTreeClassifier()\n", 636 | "clf.fit(X_train, y_train,)" 637 | ] 638 | }, 639 | { 640 | "cell_type": "code", 641 | "execution_count": 15, 642 | "metadata": {}, 643 | "outputs": [ 644 | { 645 | "data": { 646 | "text/plain": [ 647 | "1.0" 648 | ] 649 | }, 650 | "execution_count": 15, 651 | "metadata": {}, 652 | "output_type": "execute_result" 653 | } 
654 | ], 655 | "source": [ 656 | "clf.score(X_test, y_test)" 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "execution_count": 16, 662 | "metadata": {}, 663 | "outputs": [], 664 | "source": [ 665 | "tree_pic = export_graphviz(clf, out_file=\"mytree.pdf\")\n", 666 | "with open('mytree.pdf') as f:\n", 667 | " dot_graph = f.read()" 668 | ] 669 | }, 670 | { 671 | "cell_type": "code", 672 | "execution_count": 17, 673 | "metadata": {}, 674 | "outputs": [ 675 | { 676 | "data": { 677 | "image/svg+xml": [ 678 | "\r\n", 679 | "\r\n", 681 | "\r\n", 683 | "\r\n", 684 | "\r\n", 686 | "\r\n", 687 | "Tree\r\n", 688 | "\r\n", 689 | "\r\n", 690 | "0\r\n", 691 | "\r\n", 692 | "X[0] <= 5.45\r\n", 693 | "gini = 0.5\r\n", 694 | "samples = 70\r\n", 695 | "value = [35, 35]\r\n", 696 | "\r\n", 697 | "\r\n", 698 | "1\r\n", 699 | "\r\n", 700 | "X[1] <= 2.85\r\n", 701 | "gini = 0.234\r\n", 702 | "samples = 37\r\n", 703 | "value = [32, 5]\r\n", 704 | "\r\n", 705 | "\r\n", 706 | "0->1\r\n", 707 | "\r\n", 708 | "\r\n", 709 | "True\r\n", 710 | "\r\n", 711 | "\r\n", 712 | "10\r\n", 713 | "\r\n", 714 | "X[1] <= 3.45\r\n", 715 | "gini = 0.165\r\n", 716 | "samples = 33\r\n", 717 | "value = [3, 30]\r\n", 718 | "\r\n", 719 | "\r\n", 720 | "0->10\r\n", 721 | "\r\n", 722 | "\r\n", 723 | "False\r\n", 724 | "\r\n", 725 | "\r\n", 726 | "2\r\n", 727 | "\r\n", 728 | "X[0] <= 4.7\r\n", 729 | "gini = 0.32\r\n", 730 | "samples = 5\r\n", 731 | "value = [1, 4]\r\n", 732 | "\r\n", 733 | "\r\n", 734 | "1->2\r\n", 735 | "\r\n", 736 | "\r\n", 737 | "\r\n", 738 | "\r\n", 739 | "5\r\n", 740 | "\r\n", 741 | "X[0] <= 5.35\r\n", 742 | "gini = 0.061\r\n", 743 | "samples = 32\r\n", 744 | "value = [31, 1]\r\n", 745 | "\r\n", 746 | "\r\n", 747 | "1->5\r\n", 748 | "\r\n", 749 | "\r\n", 750 | "\r\n", 751 | "\r\n", 752 | "3\r\n", 753 | "\r\n", 754 | "gini = 0.0\r\n", 755 | "samples = 1\r\n", 756 | "value = [1, 0]\r\n", 757 | "\r\n", 758 | "\r\n", 759 | "2->3\r\n", 760 | "\r\n", 761 | "\r\n", 762 | "\r\n", 763 | 
"\r\n", 764 | "4\r\n", 765 | "\r\n", 766 | "gini = 0.0\r\n", 767 | "samples = 4\r\n", 768 | "value = [0, 4]\r\n", 769 | "\r\n", 770 | "\r\n", 771 | "2->4\r\n", 772 | "\r\n", 773 | "\r\n", 774 | "\r\n", 775 | "\r\n", 776 | "6\r\n", 777 | "\r\n", 778 | "gini = 0.0\r\n", 779 | "samples = 28\r\n", 780 | "value = [28, 0]\r\n", 781 | "\r\n", 782 | "\r\n", 783 | "5->6\r\n", 784 | "\r\n", 785 | "\r\n", 786 | "\r\n", 787 | "\r\n", 788 | "7\r\n", 789 | "\r\n", 790 | "X[1] <= 3.2\r\n", 791 | "gini = 0.375\r\n", 792 | "samples = 4\r\n", 793 | "value = [3, 1]\r\n", 794 | "\r\n", 795 | "\r\n", 796 | "5->7\r\n", 797 | "\r\n", 798 | "\r\n", 799 | "\r\n", 800 | "\r\n", 801 | "8\r\n", 802 | "\r\n", 803 | "gini = 0.0\r\n", 804 | "samples = 1\r\n", 805 | "value = [0, 1]\r\n", 806 | "\r\n", 807 | "\r\n", 808 | "7->8\r\n", 809 | "\r\n", 810 | "\r\n", 811 | "\r\n", 812 | "\r\n", 813 | "9\r\n", 814 | "\r\n", 815 | "gini = 0.0\r\n", 816 | "samples = 3\r\n", 817 | "value = [3, 0]\r\n", 818 | "\r\n", 819 | "\r\n", 820 | "7->9\r\n", 821 | "\r\n", 822 | "\r\n", 823 | "\r\n", 824 | "\r\n", 825 | "11\r\n", 826 | "\r\n", 827 | "gini = 0.0\r\n", 828 | "samples = 30\r\n", 829 | "value = [0, 30]\r\n", 830 | "\r\n", 831 | "\r\n", 832 | "10->11\r\n", 833 | "\r\n", 834 | "\r\n", 835 | "\r\n", 836 | "\r\n", 837 | "12\r\n", 838 | "\r\n", 839 | "gini = 0.0\r\n", 840 | "samples = 3\r\n", 841 | "value = [3, 0]\r\n", 842 | "\r\n", 843 | "\r\n", 844 | "10->12\r\n", 845 | "\r\n", 846 | "\r\n", 847 | "\r\n", 848 | "\r\n", 849 | "\r\n" 850 | ], 851 | "text/plain": [ 852 | "" 853 | ] 854 | }, 855 | "execution_count": 17, 856 | "metadata": {}, 857 | "output_type": "execute_result" 858 | } 859 | ], 860 | "source": [ 861 | "graphviz.Source(dot_graph)" 862 | ] 863 | } 864 | ], 865 | "metadata": { 866 | "kernelspec": { 867 | "display_name": "Python 3", 868 | "language": "python", 869 | "name": "python3" 870 | }, 871 | "language_info": { 872 | "codemirror_mode": { 873 | "name": "ipython", 874 | "version": 3 875 | }, 
876 | "file_extension": ".py", 877 | "mimetype": "text/x-python", 878 | "name": "python", 879 | "nbconvert_exporter": "python", 880 | "pygments_lexer": "ipython3", 881 | "version": "3.6.2" 882 | } 883 | }, 884 | "nbformat": 4, 885 | "nbformat_minor": 2 886 | } 887 | -------------------------------------------------------------------------------- /code/第5章 决策树(DecisonTree)/Decision Tree (ID3 剪枝): -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from collections import Counter 4 | import math 5 | 6 | 7 | class Node: 8 | def __init__(self, x=None, label=None, y=None, data=None): 9 | self.label = label # label:子节点分类依据的特征 10 | self.x = x # x:特征 11 | self.child = [] # child:子节点 12 | self.y = y # y:类标记(叶节点才有) 13 | self.data = data # data:包含数据(叶节点才有) 14 | 15 | def append(self, node): # 添加子节点 16 | self.child.append(node) 17 | 18 | def predict(self, features): # 预测数据所述类 19 | if self.y is not None: 20 | return self.y 21 | for c in self.child: 22 | if c.x == features[self.label]: 23 | return c.predict(features) 24 | 25 | 26 | def printnode(node, depth=0): # 打印树所有节点 27 | if node.label is None: 28 | print(depth, (node.label, node.x, node.y, len(node.data))) 29 | else: 30 | print(depth, (node.label, node.x)) 31 | for c in node.child: 32 | printnode(c, depth+1) 33 | 34 | 35 | class DTree: 36 | def __init__(self, epsilon=0, alpha=0): # 预剪枝、后剪枝参数 37 | self.epsilon = epsilon 38 | self.alpha = alpha 39 | self.tree = Node() 40 | 41 | def prob(self, datasets): # 求概率 42 | datalen = len(datasets) 43 | labelx = set(datasets) 44 | p = {l: 0 for l in labelx} 45 | for d in datasets: 46 | p[d] += 1 47 | for i in p.items(): 48 | p[i[0]] /= datalen 49 | return p 50 | 51 | def calc_ent(self, datasets): # 求熵 52 | p = self.prob(datasets) 53 | ent = sum([-v * math.log(v, 2) for v in p.values()]) 54 | return ent 55 | 56 | def cond_ent(self, datasets, col): # 求条件熵 57 | labelx = set(datasets.iloc[col]) 58 | p = {x: [] 
for x in labelx} 59 | for i, d in enumerate(datasets.iloc[-1]): 60 | p[datasets.iloc[col][i]].append(d) 61 | return sum([self.prob(datasets.iloc[col])[k] * self.calc_ent(p[k]) for k in p.keys()]) 62 | 63 | def info_gain_train(self, datasets, datalabels): # 求信息增益(互信息) 64 | #print('----信息增益----') 65 | datasets = datasets.T 66 | ent = self.calc_ent(datasets.iloc[-1]) 67 | gainmax = {} 68 | for i in range(len(datasets) - 1): 69 | cond = self.cond_ent(datasets, i) 70 | #print(datalabels[i], ent - cond) 71 | gainmax[ent - cond] = i 72 | m = max(gainmax.keys()) 73 | return gainmax[m], m 74 | 75 | def train(self, datasets, node): 76 | labely = datasets.columns[-1] 77 | if len(datasets[labely].value_counts()) == 1: 78 | node.data = datasets[labely] 79 | node.y = datasets[labely][0] 80 | return 81 | if len(datasets.columns[:-1]) == 0: 82 | node.data = datasets[labely] 83 | node.y = datasets[labely].value_counts().index[0] 84 | return 85 | gainmaxi, gainmax = self.info_gain_train(datasets, datasets.columns) 86 | #print('选择特征:', gainmaxi) 87 | if gainmax <= self.epsilon: # 若信息增益(互信息)为0意为输入特征x完全相同而标签y相反 88 | node.data = datasets[labely] 89 | node.y = datasets[labely].value_counts().index[0] 90 | return 91 | 92 | vc = datasets[datasets.columns[gainmaxi]].value_counts() 93 | for Di in vc.index: 94 | node.label = gainmaxi 95 | child = Node(Di) 96 | node.append(child) 97 | new_datasets = pd.DataFrame([list(i) for i in datasets.values if i[gainmaxi]==Di], columns=datasets.columns) 98 | self.train(new_datasets, child) 99 | 100 | def fit(self, datasets): 101 | self.train(datasets, self.tree) 102 | 103 | def findleaf(self, node, leaf): # 找到所有叶节点 104 | for t in node.child: 105 | if t.y is not None: 106 | leaf.append(t.data) 107 | else: 108 | for c in node.child: 109 | self.findleaf(c, leaf) 110 | 111 | def findfather(self, node, errormin): 112 | if node.label is not None: 113 | cy = [c.y for c in node.child] 114 | if None not in cy: # 全是叶节点 115 | childdata = [] 116 | for c in 
node.child: 117 | for d in list(c.data): 118 | childdata.append(d) 119 | childcounter = Counter(childdata) 120 | 121 | old_child = node.child # 剪枝前先拷贝一下 122 | old_label = node.label 123 | old_y = node.y 124 | old_data = node.data 125 | 126 | node.label = None # 剪枝 127 | node.y = childcounter.most_common(1)[0][0] 128 | node.data = childdata 129 | 130 | error = self.c_error() 131 | if error <= errormin: # 剪枝前后损失比较 132 | errormin = error 133 | return 1 134 | else: 135 | node.child = old_child # 剪枝效果不好,则复原 136 | node.label = old_label 137 | node.y = old_y 138 | node.data = old_data 139 | else: 140 | re = 0 141 | i = 0 142 | while i < len(node.child): 143 | if_re = self.findfather(node.child[i], errormin) # 若剪过枝,则其父节点要重新检测 144 | if if_re == 1: 145 | re = 1 146 | elif if_re == 2: 147 | i -= 1 148 | i += 1 149 | if re: 150 | return 2 151 | return 0 152 | 153 | def c_error(self): # 求C(T) 154 | leaf = [] 155 | self.findleaf(self.tree, leaf) 156 | leafnum = [len(l) for l in leaf] 157 | ent = [self.calc_ent(l) for l in leaf] 158 | print("Ent:", ent) 159 | error = self.alpha*len(leafnum) 160 | for l, e in zip(leafnum, ent): 161 | error += l*e 162 | print("C(T):", error) 163 | return error 164 | 165 | def cut(self, alpha=0): # 剪枝 166 | if alpha: 167 | self.alpha = alpha 168 | errormin = self.c_error() 169 | self.findfather(self.tree, errormin) 170 | 171 | 172 | datasets = np.array([['青年', '否', '否', '一般', '否'], 173 | ['青年', '否', '否', '好', '否'], 174 | ['青年', '是', '否', '好', '是'], 175 | ['青年', '是', '是', '一般', '是'], 176 | ['青年', '否', '否', '一般', '否'], 177 | ['中年', '否', '否', '一般', '否'], 178 | ['中年', '否', '否', '好', '否'], 179 | ['中年', '是', '是', '好', '是'], 180 | ['中年', '否', '是', '非常好', '是'], 181 | ['中年', '否', '是', '非常好', '是'], 182 | ['老年', '否', '是', '非常好', '是'], 183 | ['老年', '否', '是', '好', '是'], 184 | ['老年', '是', '否', '好', '是'], 185 | ['老年', '是', '否', '非常好', '是'], 186 | ['老年', '否', '否', '一般', '否'], 187 | ['青年', '否', '否', '一般', '是']]) # 在李航原始数据上多加了最后这行数据,以便体现剪枝效果 188 | 189 | datalabels = 
np.array(['年龄', '有工作', '有自己的房子', '信贷情况', '类别']) 190 | train_data = pd.DataFrame(datasets, columns=datalabels) 191 | test_data = ['老年', '否', '否', '一般'] 192 | 193 | dt = DTree(epsilon=0) # 可修改epsilon查看预剪枝效果 194 | dt.fit(train_data) 195 | 196 | print('DTree:') 197 | printnode(dt.tree) 198 | y = dt.tree.predict(test_data) 199 | print('result:', y) 200 | 201 | dt.cut(alpha=0.5) # 可修改正则化参数alpha查看后剪枝效果 202 | 203 | print('DTree:') 204 | printnode(dt.tree) 205 | y = dt.tree.predict(test_data) 206 | print('result:', y) 207 | -------------------------------------------------------------------------------- /code/第5章 决策树(DecisonTree)/mytree.pdf: -------------------------------------------------------------------------------- 1 | digraph Tree { 2 | node [shape=box] ; 3 | 0 [label="X[0] <= 5.45\ngini = 0.5\nsamples = 70\nvalue = [35, 35]"] ; 4 | 1 [label="X[1] <= 2.85\ngini = 0.234\nsamples = 37\nvalue = [32, 5]"] ; 5 | 0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ; 6 | 2 [label="X[0] <= 4.7\ngini = 0.32\nsamples = 5\nvalue = [1, 4]"] ; 7 | 1 -> 2 ; 8 | 3 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]"] ; 9 | 2 -> 3 ; 10 | 4 [label="gini = 0.0\nsamples = 4\nvalue = [0, 4]"] ; 11 | 2 -> 4 ; 12 | 5 [label="X[0] <= 5.35\ngini = 0.061\nsamples = 32\nvalue = [31, 1]"] ; 13 | 1 -> 5 ; 14 | 6 [label="gini = 0.0\nsamples = 28\nvalue = [28, 0]"] ; 15 | 5 -> 6 ; 16 | 7 [label="X[1] <= 3.2\ngini = 0.375\nsamples = 4\nvalue = [3, 1]"] ; 17 | 5 -> 7 ; 18 | 8 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]"] ; 19 | 7 -> 8 ; 20 | 9 [label="gini = 0.0\nsamples = 3\nvalue = [3, 0]"] ; 21 | 7 -> 9 ; 22 | 10 [label="X[1] <= 3.45\ngini = 0.165\nsamples = 33\nvalue = [3, 30]"] ; 23 | 0 -> 10 [labeldistance=2.5, labelangle=-45, headlabel="False"] ; 24 | 11 [label="gini = 0.0\nsamples = 30\nvalue = [0, 30]"] ; 25 | 10 -> 11 ; 26 | 12 [label="gini = 0.0\nsamples = 3\nvalue = [3, 0]"] ; 27 | 10 -> 12 ; 28 | } -------------------------------------------------------------------------------- 
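The post-pruning script above minimizes the cost-complexity criterion $C_\alpha(T)=\sum_t N_t H_t(T)+\alpha|T|$ via its `c_error`/`cut` methods. As a minimal, self-contained sketch of just that criterion (the function names `leaf_entropy` and `cost_complexity` are illustrative, not from the file above):

```python
# Standalone sketch (not part of the repo) of the cost-complexity criterion
# the pruning code above minimizes:
#   C_alpha(T) = sum_t N_t * H_t(T) + alpha * |T|
# where each leaf t holds N_t samples with empirical entropy H_t(T), and |T|
# is the number of leaves.
import math

def leaf_entropy(labels):
    """Empirical entropy (base 2) of the class labels in one leaf."""
    n = len(labels)
    counts = {}
    for label in labels:
        counts[label] = counts.get(label, 0) + 1
    return -sum((c / n) * math.log2(c / n) for c in counts.values())

def cost_complexity(leaves, alpha):
    """C_alpha(T): size-weighted leaf entropies plus the alpha*|T| penalty."""
    return sum(len(l) * leaf_entropy(l) for l in leaves) + alpha * len(leaves)

# Two pure leaves versus the single impure leaf obtained by merging them.
pure = [['是', '是', '是'], ['否', '否']]
merged = [['是', '是', '是', '否', '否']]

# With alpha = 0 the unpruned tree is cheaper (entropy term only); once
# alpha is large enough, the per-leaf penalty dominates and pruning wins.
print(cost_complexity(pure, alpha=0.0) < cost_complexity(merged, alpha=0.0))
print(cost_complexity(merged, alpha=5.0) < cost_complexity(pure, alpha=5.0))
```

This mirrors what `findfather` does in the script above: a parent whose children are all leaves is collapsed whenever the collapsed tree's $C_\alpha$ is no larger than before.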
/code/第6章 逻辑斯谛回归(LogisticRegression)/LR.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "原文代码作者:https://github.com/wzyonggege/statistical-learning-method\n", 8 | "\n", 9 | "中文注释制作:机器学习初学者(微信公众号:ID:ai-start-com)\n", 10 | "\n", 11 | "配置环境:python 3.6\n", 12 | "\n", 13 | "代码全部测试通过。\n", 14 | "![gongzhong](../gongzhong.jpg)" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "# 第6章 逻辑斯谛回归" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "LR是经典的分类方法" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "\n", 36 | "回归模型:$f(x) = \\frac{1}{1+e^{-wx}}$\n", 37 | "\n", 38 | "其中wx线性函数:$wx =w_0*x_0 + w_1*x_1 + w_2*x_2 +...+w_n*x_n,(x_0=1)$\n" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 1, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "from math import exp\n", 48 | "import numpy as np\n", 49 | "import pandas as pd\n", 50 | "import matplotlib.pyplot as plt\n", 51 | "%matplotlib inline\n", 52 | "\n", 53 | "from sklearn.datasets import load_iris\n", 54 | "from sklearn.model_selection import train_test_split" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 2, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "# data\n", 64 | "def create_data():\n", 65 | " iris = load_iris()\n", 66 | " df = pd.DataFrame(iris.data, columns=iris.feature_names)\n", 67 | " df['label'] = iris.target\n", 68 | " df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n", 69 | " data = np.array(df.iloc[:100, [0,1,-1]])\n", 70 | " # print(data)\n", 71 | " return data[:,:2], data[:,-1]" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 3, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "X, y = 
create_data()\n", 81 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 4, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "class LogisticReressionClassifier:\n", 91 | " def __init__(self, max_iter=200, learning_rate=0.01):\n", 92 | " self.max_iter = max_iter\n", 93 | " self.learning_rate = learning_rate\n", 94 | " \n", 95 | " def sigmoid(self, x):\n", 96 | " return 1 / (1 + exp(-x))\n", 97 | "\n", 98 | " def data_matrix(self, X):\n", 99 | " data_mat = []\n", 100 | " for d in X:\n", 101 | " data_mat.append([1.0, *d])\n", 102 | " return data_mat\n", 103 | "\n", 104 | " def fit(self, X, y):\n", 105 | " # label = np.mat(y)\n", 106 | " data_mat = self.data_matrix(X) # m*n\n", 107 | " self.weights = np.zeros((len(data_mat[0]),1), dtype=np.float32)\n", 108 | "\n", 109 | " for iter_ in range(self.max_iter):\n", 110 | " for i in range(len(X)):\n", 111 | " result = self.sigmoid(np.dot(data_mat[i], self.weights))\n", 112 | " error = y[i] - result \n", 113 | " self.weights += self.learning_rate * error * np.transpose([data_mat[i]])\n", 114 | " print('LogisticRegression Model(learning_rate={},max_iter={})'.format(self.learning_rate, self.max_iter))\n", 115 | "\n", 116 | " # def f(self, x):\n", 117 | " # return -(self.weights[0] + self.weights[1] * x) / self.weights[2]\n", 118 | "\n", 119 | " def score(self, X_test, y_test):\n", 120 | " right = 0\n", 121 | " X_test = self.data_matrix(X_test)\n", 122 | " for x, y in zip(X_test, y_test):\n", 123 | " result = np.dot(x, self.weights)\n", 124 | " if (result > 0 and y == 1) or (result < 0 and y == 0):\n", 125 | " right += 1\n", 126 | " return right / len(X_test)" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 5, 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "name": "stdout", 136 | "output_type": "stream", 137 | "text": [ 138 | "LogisticRegression 
Model(learning_rate=0.01,max_iter=200)\n" 139 | ] 140 | } 141 | ], 142 | "source": [ 143 | "lr_clf = LogisticReressionClassifier()\n", 144 | "lr_clf.fit(X_train, y_train)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 6, 150 | "metadata": {}, 151 | "outputs": [ 152 | { 153 | "data": { 154 | "text/plain": [ 155 | "0.9666666666666667" 156 | ] 157 | }, 158 | "execution_count": 6, 159 | "metadata": {}, 160 | "output_type": "execute_result" 161 | } 162 | ], 163 | "source": [ 164 | "lr_clf.score(X_test, y_test)" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 7, 170 | "metadata": {}, 171 | "outputs": [ 172 | { 173 | "data": { 174 | "text/plain": [ 175 | "" 176 | ] 177 | }, 178 | "execution_count": 7, 179 | "metadata": {}, 180 | "output_type": "execute_result" 181 | }, 182 | { 183 | "data": { 184 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xl8VPXZ9/HPRQgQEEEWFQgQSBQFBBEUkQoRFypu1BW3lmpF0T62d62t9uljW+4+T7W2te1tgyLeroh6uyBal1oh7KJBEBBEk7AlIEuAsEOW6/ljBggxJCfJmTm/OXO9Xy9eZM4czvx+M3pxOPO9fkdUFWOMMeHSJOgBGGOM8Z8Vd2OMCSEr7sYYE0JW3I0xJoSsuBtjTAhZcTfGmBCy4m6MMSFkxd0YY0LIirsxxoRQ06BeuEOHDpqRkRHUyxtjTEJatGjRVlXtWNd+nou7iKQAeUCxql5e7bmxwKNAcXTT46o6ubbjZWRkkJeX5/XljTHGACKy1st+9Tlz/wmwEjj+GM+/oqo/rsfxjDHGxIina+4ikg5cBtR6Nm6MMcYNXr9Q/SvwC6Cyln2uEZGlIvKaiHStaQcRGScieSKSt2XLlvqO1RhjjEd1XpYRkcuBzaq6SESyj7Hb28BUVT0gIncBzwEjqu+kqpOASQCDBg2ytYaNMYEoKyujqKiI/fv3Bz2UY2rRogXp6emkpqY26M97ueY+FLhSREYBLYDjReRFVb3l0A6qWlJl/6eARxo0GmOMiYOioiJat25NRkYGIhL0cL5FVSkpKaGoqIgePXo06Bh1XpZR1QdVNV1VM4AxwIyqhR1ARDpVeXglkS9ejTHGSfv376d9+/ZOFnYAEaF9+/aN+pdFg3PuIjIByFPV6cC9InIlUA5sA8Y2eETGGBMHrhb2Qxo7vnoVd1XNBXKjPz9UZfuDwIONGokxITJtcTGPfrCKDTv20bltGveP7MXoAV2CHpZJIrb8gDE+m7a4mAffWEbxjn0oULxjHw++sYxpi4vr/LMmebz//vv06tWLrKwsHn74Yd+Pb8XdGJ89+sEq9pVVHLVtX1kFj36wKqARGddUVFRwzz338N5777FixQq
(base64-encoded PNG image data omitted)\n", 185 | "text/plain": [ 186 | "
" 187 | ] 188 | }, 189 | "metadata": {}, 190 | "output_type": "display_data" 191 | } 192 | ], 193 | "source": [ 194 | "x_ponits = np.arange(4, 8)\n", 195 | "y_ = -(lr_clf.weights[1]*x_ponits + lr_clf.weights[0])/lr_clf.weights[2]\n", 196 | "plt.plot(x_ponits, y_)\n", 197 | "\n", 198 | "#lr_clf.show_graph()\n", 199 | "plt.scatter(X[:50,0],X[:50,1], label='0')\n", 200 | "plt.scatter(X[50:,0],X[50:,1], label='1')\n", 201 | "plt.legend()" 202 | ] 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "metadata": { 207 | "collapsed": true 208 | }, 209 | "source": [ 210 | "## sklearn\n", 211 | "\n", 212 | "### sklearn.linear_model.LogisticRegression\n", 213 | "\n", 214 | "solver参数决定了我们对逻辑回归损失函数的优化方法,有四种算法可以选择,分别是:\n", 215 | "- a) liblinear:使用了开源的liblinear库实现,内部使用了坐标轴下降法来迭代优化损失函数。\n", 216 | "- b) lbfgs:拟牛顿法的一种,利用损失函数二阶导数矩阵即海森矩阵来迭代优化损失函数。\n", 217 | "- c) newton-cg:也是牛顿法家族的一种,利用损失函数二阶导数矩阵即海森矩阵来迭代优化损失函数。\n", 218 | "- d) sag:即随机平均梯度下降,是梯度下降法的变种,和普通梯度下降法的区别是每次迭代仅仅用一部分的样本来计算梯度,适合于样本数据多的时候。" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 8, 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "from sklearn.linear_model import LogisticRegression" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 9, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "clf = LogisticRegression(max_iter=200)" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": 10, 242 | "metadata": {}, 243 | "outputs": [ 244 | { 245 | "data": { 246 | "text/plain": [ 247 | "LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n", 248 | " intercept_scaling=1, max_iter=200, multi_class='ovr', n_jobs=1,\n", 249 | " penalty='l2', random_state=None, solver='liblinear', tol=0.0001,\n", 250 | " verbose=0, warm_start=False)" 251 | ] 252 | }, 253 | "execution_count": 10, 254 | "metadata": {}, 255 | "output_type": "execute_result" 256 | } 257 | ], 258 | "source": [ 259 | "clf.fit(X_train, 
y_train)" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 11, 265 | "metadata": {}, 266 | "outputs": [ 267 | { 268 | "data": { 269 | "text/plain": [ 270 | "0.9666666666666667" 271 | ] 272 | }, 273 | "execution_count": 11, 274 | "metadata": {}, 275 | "output_type": "execute_result" 276 | } 277 | ], 278 | "source": [ 279 | "clf.score(X_test, y_test)" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": 12, 285 | "metadata": {}, 286 | "outputs": [ 287 | { 288 | "name": "stdout", 289 | "output_type": "stream", 290 | "text": [ 291 | "[[ 1.94474283 -3.29077674]] [-0.53064339]\n" 292 | ] 293 | } 294 | ], 295 | "source": [ 296 | "print(clf.coef_, clf.intercept_)" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 13, 302 | "metadata": {}, 303 | "outputs": [ 304 | { 305 | "data": { 306 | "text/plain": [ 307 | "" 308 | ] 309 | }, 310 | "execution_count": 13, 311 | "metadata": {}, 312 | "output_type": "execute_result" 313 | }, 314 | { 315 | "data": { 316 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYUAAAEKCAYAAAD9xUlFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xl8VPW5+PHPQwLEyKIsKhCSQOIGKKu4IDtVKxatWpeit26lArb2drOtXe7PXm+rt4tt2cSltYrora0bVdsCCQgICAKioJAQlgCyBAhLDCHJ8/vjTEKMk+QkOWfmzMnzfr3mNZkzZ8483xnIk3O+3+f7FVXFGGOMAWgV7wCMMcYEhyUFY4wx1SwpGGOMqWZJwRhjTDVLCsYYY6pZUjDGGFPNkoIxxphqvicFEUkSkTUiMi/Kc3eIyD4RWRu53eN3PMYYY+qWHIP3uB/YCHSo4/kXVfW+GMRhjDGmAb4mBRFJA8YDDwPf8eKYXbp00czMTC8OZYwxLcbq1av3q2rXhvbz+0zhMeAHQPt69rlBREYAm4D/VNUd9R0wMzOTVatWeRiiMcaEn4hsc7Ofb30KInINsFdVV9ez2+tApqpeCMwHnqnjWJNEZJWIrNq3b58P0RpjjAF/O5qHARNEZCvwAjBGRJ6ruYOqFqnq8cjDJ4DB0Q6kqrNVdYiqDunatcGzH2OMMU3kW1JQ1R+papqqZgK3AAtV9baa+4hItxoPJ+B0SBtjjImTWIw++gwReQhYpaqvAd8SkQlAOXAAuCPW8RhjjFsnTpygsLCQ0tLSeIdSp5SUFNLS0mjdunWTXi+Jtp7CkCFD1DqajTHxUFBQQPv27encuTMiEu9wPkdVKSoq4siRI/Tq1eszz4nIalUd0tAxrKLZmACaMwcyM6FVK+d+zpx4R2QASktLA5sQAESEzp07N+tMJuaXj4wx9ZszByZNgpIS5/G2bc5jgIkT4xeXcQQ1IVRpbnx2pmBMwDz44MmEUKWkxNlujN8sKRgTMNu3N267aVneeustzj33XLKzs/nVr37l+fEtKRgTMOnpjdtugsvrvqGKigqmTp3Km2++yYYNG5g7dy4bNmzwItRqlhSMCZiHH4bU1M9uS011tpvEUdU3tG0bqJ7sG2pOYli5ciXZ2dn07t2bNm3acMstt/Dqq696FzSWFIwJnIkTYfZsyMgAEed+9mzrZE40fvQN7dy5k549e1Y/TktLY+fOnU0/YBQ2+siYAJo40ZJAovOjbyhaXZnXo6HsTMEYY3zgR99QWloaO3acnEi6sLCQ7t27N/2AUVhSMMYYH/jRN3TRRRexefNmCgoKKCsr44UXXmDChAnNC7QWSwrGGOMDP/qGkpOTmTZtGldeeSXnn38+N910E3379vUuaKxPwRhjfONH39DVV1/N1Vdf7e1Ba7AzBWOMMdUsKRhjjKlmScEYY0w1SwrGGGOqWVIwxhhTzZKCMR6wRXFMWFhSMKaZ/Jj4zJi63HXXXZxxxhn069fPl+NbUjCmmWxRHFOngjnwSiY838q5L2j+Xwp33HEHb731VrOPUxdLCsY0ky2KY6IqmAMrJ0HJNkCd+5WTmp0YRowYQadOnbyJMQpLCsY0ky2KY6Ja9yBU1DqFrChxtgeYJQVjmskWxTFRldRxqljX9oCwpGBMM9miOCaq1DpOFevaHhCWFIzxwMSJsHUrVFY695YQDP0fhqRap5BJqc72ALOkYELDagVMoPSaCENnQ2oGIM790NnO9ma49dZbufTSS/n4449JS0vjqaee8ibeCJs624RCVa1A1dDQqloBsL/aTRz1mtjsJFDb3LlzPT1ebXamYELBagWM8YYlBRMKVitgjDcsKZhQsFoBEyuqGu8Q6tXc+CwpmFCwWgETCykpKRQVFQU2MagqRUVFpKSkNPkY1tFsQqGqM/nBB51LRunpTkKwTmbjpbS0NAoLC9m3b1+8Q6lTSkoKaWlpTX69BDX
j1WXIkCG6atWqeIdhjDEJRURWq+qQhvbz/fKRiCSJyBoRmRflubYi8qKI5InIChHJ9DseY8LEajOM12LRp3A/sLGO5+4GDqpqNvA74JEYxGNMKNg6DsYPviYFEUkDxgNP1rHLtcAzkZ9fAsaKiPgZkzFhYbUZxg9+nyk8BvwAqKzj+R7ADgBVLQeKgc61dxKRSSKySkRWBbmDx5hYstoM4wffkoKIXAPsVdXV9e0WZdvner5VdbaqDlHVIV27dvUsRmMSmdVmtByHSsp4bP4mluXv9/29/DxTGAZMEJGtwAvAGBF5rtY+hUBPABFJBjoCB3yMyZjQsNqM8Nt7pJRfvrGRYb9ayGPzN7M8v8j39/StTkFVfwT8CEBERgHfU9Xbau32GvA14B3gRmChJtoYWWPixGozwqvwYAmzF2/hxXd3cKKikvEXdmfKqCzO79bB9/eOefGaiDwErFLV14CngGdFJA/nDOGWWMdjTCKbONGSQJhs2XeUGbn5vLJmJwDXD+rB5FHZ9OpyasxiiElSUNVcIDfy889qbC8FvhKLGIxprClTnBXUKiogKckZ7jljRryjMmG0Yddhpufm8cb63bRJasVtl2Tw9RG96XHaKTGPxaa5MCaKKVNg5syTjysqTj62xGC88t72g0xfmMeCj/bSrm0y3xiRxd2X96Jr+7Zxi8mmuTAmiuRkJxHUlpQE5eWxj8eEh6qyLL+IaQvzeGdLEaeltuauYb342qWZdExt7dv7up3mws4UjIkiWkKob7sxDVFVFmzcy7ScPNbuOETX9m158Orz+erF6ZzaNji/ioMTiTEBkpRU95mCMY1RUam8sX4303Py+OiTI/Q47RR+cV0/vjI4jZTWwfsHZUnBmCgmTfpsn0LN7ca4UVZeyStrdjJzUT4F+4+R1fVUfvOV/kwY0J3WScFdysaSgjFRVHUm2+gj01ilJyp48d0dPL4on13FpfTt3oEZEwdxZd+zSGoV/KndrKPZGGM8cKT0BHNWbOfJtwvYf/Q4QzJOZ+qYbEad05UgzPNpHc3GGBMDB4+V8adlW/nz0gIOl5Yz/OwuTB09kIt7dQpEMmis4F7YMiZi3DgQOXkbNy7eEfnPFs8Jvr2HS/mfNzYy7JGF/GHBZi7u3ZlXpw7j2bsv5pLenRMyIYCdKZiAGzcOFiz47LYFC5zt8+fHJya/VS2eU7VWQtXiOWBTWgTBjgOReYlW7aC8opIv9e/O5FFZnHeW//MSxYL1KZhAq++PrQT7p+taZqaTCGrLyICtW2MdjamSv+8oM3LyeXXtTkTghkFp3Dsyi8wYzkvUHNanYEyCssVzguXDXcXMyMnnjQ920zbZmZdo0ojedI/DvESxYEnBmIBJT49+pmCL58TW6m0HmLYwj5yP99G+bTKTR2Zx1+W96NIufvMSxYIlBRNoY8d+vk+hantYPfzwZ/sUwBbPiRVVZWleEdNyNrN8ywFOT23Nd79wDv9xWSYdT/FvXqIgsaRgAm3+/M93No8dG95OZrDFc+KhslKZv3EP03PzWbfjEGd2aMtPxp/PrUODNS9RLFhHszGmxaqoVOa9v4sZOfl8vOcIPTudwr0js7hxcBptk4M3L1FzuO1otjoFE3hej9l3ezyrFQivsvJKXnx3O2N/k8v9L6ylQpXf3tSfnO+OYuLFGaFLCI3Rss6LTMLxesy+2+NZrUA4fVpWwQvvbmf24i3sLi6lX48OzLptEFf0OYtWCTAvUSzY5SMTaF6P2Xd7PKsVCJcjpSd4dvk2nl5SwP6jZVyUeTpTR2czMiDzEsWC1SmYUPB6zL7b41mtQDgcPFbGn5YW8OdlWzlcWs6Ic7py3+hshvbqFO/QAsuSggk0r8fsuz2e1Qoktj2HS3ny7S3MWbGdkrIKrux7JlNHZ3Nh2mnxDi3wrKPZBNrDDztj9Gtqzph9t8fz+n1NbOw4UMKDL69n+CM5PLWkgCv6nMm//nMEj98+xBKCS3amYALN6zH7bo9
ntQKJJW/vUWbk5vHq2l0kiXDD4DTuHdmbjM6JMS9RkFhHszEmYX2ws5jpOXm89eEntE1uxVeHOvMSndUxJd6hBY7VKbRwiTDGPhFiNMG0ausB7vjTSq754xKWbN7PlFFZLH1gDD/7Uh9LCM1kl49CKBHG2CdCjCZYVJUlefuZtjCPFQUH6HRqG75/5bncfmkGHVJaxrxEsWCXj0IoEcbYJ0KMJhgqK5V/b9zDjJw81hUWc2aHtkwakcWtQ3uS2sb+rnXL6hRasEQYY58IMZr4Kq+o5B/rd1fPS5TeKZVfXn8B1w/q0aKnofCbJYUQSoQx9okQo4mP4+UV/P29ncxalM+2ohLOPqMdj908gGsu7EZyknWD+s2SQgglwnz8iRCjia1PyyqYu9KZl+iTw6Vc0KMjs24bzBV9zrR5iWLIkkIIJcIY+0SI0cTG4dITPPuOMy9R0bEyhvbqxKM3Xsjws7u0mHmJgsQ6mo0xcXHgWBlPLyngmXe2cqS0nJHndOW+MdlclGnzEvkh7nUKIpIiIitFZJ2IfCgi/y/KPneIyD4RWRu53eNXPCZ4pkyB5GQQce6nTGnefvGse7CaC/f2HC7lF/M2MOxXC5mem8fl2V2Y983LeeauoZYQAqDBy0ci0ha4Acisub+qPtTAS48DY1T1qIi0BpaIyJuqurzWfi+q6n2NC9skuilTYObMk48rKk4+njGj8fvFs+7Bai7c2XGghJmL8nlpVSEVqlzbvzuTR2Vx9pnt4x2aqaHBy0ci8hZQDKwGKqq2q+pvXL+JSCqwBJisqitqbL8DGNKYpGCXj8IhOdn5BV9bUhKUlzd+v3jWPVjNRf027znCzNx8Xl3nzEt045A07h2RRXrn1IZfbDzjZZ1Cmqpe1cQgknCSSTYwvWZCqOEGERkBbAL+U1V3RDnOJGASQLqNWQyFaL/oo213u1886x6s5iK69YXOvET/3PAJKclJ3HlZJvcMt3mJgs5NUlgmIheo6vrGHlxVK4ABInIa8LKI9FPVD2rs8jowV1WPi8i9wDPAmCjHmQ3MBudMobFxmOBJSqr7DKAp+8Wz7sFqLj5rZcEBpufksWjTPtqnJHPf6GzuHNaLTqe2iXdoxoU6O5pFZL2IvA9cDrwnIh+LyPs1trumqoeAXOCqWtuLVPV45OETwOBGRW8SVtU194a2u90vnusf2NoLzrxEizbt46ZZ73DT4+/wwc5ivn/luSz94Ri+e8W5lhASSH1nCtc058Ai0hU4oaqHROQUYBzwSK19uqnq7sjDCcDG5rynSRxVncSzZztnAklJzi/6mp3HjdkvnnUPLbnmorJS+deGPUzPyWP9zmK6dUzh51/qwy0XpXNKG5uKIhG56Wh+VlVvb2hblNddiHM5KAnnjOT/VPUhEXkIWKWqr4nIL3GSQTlwAKcj+qP6jmsdzcbEX3lFJfPe382M3Dw27TlKRudUJo/M4vpBabRJtqkogsjLjua+tQ6chIvLPKr6PjAwyvaf1fj5R8CPXMRgjAmA4+UV/G21My/R9gMlnHNmO35/ywDGX2DzEoVFfX0KPxKRI8CFInI4cjsC7AVejVmEpkn8KKZyW0Tm9fHctsXrNnvd3oRQMAdeyYTnWzn3Bc6HWFJWzlNLChj5aC4/fnk9p6e2Zvbtg3nr/hFcO6CHJYQwUdV6b8AvG9onlrfBgwerqd9zz6mmpqrCyVtqqrO9qSZP/uzxqm6TJ/t7PLdt8brNXrc3IWx5TvWFVNU5VN8qX0jVt17/tQ586F+a8cA8vWnWMl28aa9WVlbGO1rTSDiX7Rv8HVtnn4KIDGogmbznfYpqmPUpNMyPYiq3RWReH89tW7xus9ftTQivZELJ5z/EwrKu/FTnMXV0NkNsGoqE5bZPob6kkBP5MQUYAqwDBLgQWKGql3sUa6NYUmhYq1bO37W1iUBlZdOOWd9klU2ZU9Ht8dy2xes2e93ehPB8K+DzjVME+WoT/+GYwGj2hHiqOlpVRwP
bgEGqOkRVB+N0Hud5F6rxWl1FU80ppqpdLNbQdq+O57YtXrfZ6/YG3baiYxzkrKjPSWoLrcJrodz0Dp2nNaqZ1alIHuBfSKa5/CimcltE5vXx3LbF6zZ73d6g2rTnCN9+YQ2jf53LLwpvo4xaU1AkpUL/FlSFZ1x1NM8FngRGASNxKo/nuumw8ONmHc3uPPecakaGqohz35xO5iqTJ6smJTkdrklJze90dXs8t23xus1etzdI3t9xSCf95V3NeGCenv/TN/W/532oe4o/dTqbX85QnSPO/RYP/uGYQKC5Hc1VRCQFmAyMiGxaDMxU1VJ/0lT9rE/BmKZbsaWI6bn5LN60jw4pydwxrBd3XpbJ6TYNReh5tsiOqpaq6u9U9cuR2+/ilRBMuMSr/qClUVVyP97LV2Yt4+bZy9mwq5gHrjqPpT8cw3e+cE5wEkIdNRImtuqsaBaR/1PVm0RkPVGGJKjqhb5GZkLN7cI0toBN01VWKv/88BOm5+bxwc7DdO+Ywn99qQ83B3FeooI5sHISVES+6JJtzmOAXvZFx1J9Q1K7qepuEcmI9ryqRhkV7j+7fBQO8ao/aAnKKyp5bd0uZuTmk7f3KL26nMrkkVlcN7BHcOclqqNGgtQMuG5rrKMJpWbPfaQnZy8dC7ytqpu9Cs4YtwvT2AI27h0vr+Cl1YXMWpTPjgOfct5Z7fnDrQMZf0E3klrVU3gRBCV1fKF1bTe+cTMhXiZwW+SMYTXwNk6SWOtnYCbc3C5MYwvYNKykrJznV2znibe3sOfwcfr3PI2fXdOXseedQaugJ4Mqqel1nCnYFx1rbjqaf6aqY4B+OOssfx8nORjTZPGqPwiT4k9P8McFmxn2q4X89z820rtLO+bcczGvTLmML/Q5M3ESAji1EEm1vmirkYiLBs8UROQnwDCgHbAG+B7O2YIxTeZ2YZqWvIBNXfYfPc5TSwp49p1tHD1ezpjzzmDq6GwGZ5we79Carqozed2DziWj1HQnIVgnc8y5qVN4D2cRnH8Ai4Dl8RySah3NpqXaXfwpjy/awgvvbud4eSVXX9CNKaOy6Nu9Y7xDMwnAyzqFQTidzSuBLwDrRWRJ80M0NXk9Ft/t8eK5ZoDVH7izdf8xfvi39xnxaA7PLd/GNRd2Z/53RjL9q4PcJ4Qw1QCEqS1uxbDNbi4f9QOG40xxMQTYgV0+8pTXY/HdHm/KFJg58+TjioqTj2uvgew1qz9o2MefHGFGbh6vr9tFclIrbh2azqQRvUk7PbXhF9cUphqAMLXFrRi32c3lo6rLRkuAd1X1hOdRNEIYLx95PRbf7fHiuWaA1R/Ubd2OQ0zLyePfG/Zwapskbrskg7uH9+KM9ikNvziaMNUAhKktbnnUZs/WaFbV8a7f1TSJ12Px3R4vWkKob7uXrP7gs1SVFQUHmJ6Tx9ub99PxlNbcP/Zs7hyWyWmpzZyGIkw1AGFqi1sxbrObOgXjM6/H4rs9XlJS3WcKfrP6A4czL9E+pufksWrbQbq0a8uPvngeEy/JoF1bj/57hqkGIExtcSvGbQ5ozXvL4vVYfLfHi+eaAS29/qCyUnlj/W6u+eMS7vzzu+wuLuWha/uy5IHRfGNklncJAcJVAxCmtrgV6za7mV87SLewrqfg9VoAbo8XzzUD/FjzIejKyiv0pVU7dMyvczTjgXk6+n9z9MV3t+vxExX+vnGY1kkIU1vc8qDNNHc9BRF5nWgLtp5MJhP8SVP1C2NHswm/0hMV/HV1IY8vyqfwoDMv0X1jsvlivwSYl8iEghd1Cr8GflPPzQSYHzUAbo85bpxT91B1Gzeu+e+dqI4dL+eJxVsY8WgOP33lA7q2b8tTXxvCm/cP55oLu1tCSCQrp8DcZHhenPuVHhT1BLDmor5ZUhfFMhDjHT9qANwec9w4WLDgs69dsMDZPn9+0947ERWXnODPy7byp2UFHCo5wbDszjx28wAuzeqMiCWChLNyCuTVKOrRipOPhzaxqCe
gNRdu6hTOBn4J9IGTq3qram9/Q4vOLh81zI8aALfHrO/3XQP/1EJh3xFnXqLnljvzEo07/wymjM5mUHoCz0tknDMDjTJUT5Lg1iYW9cS45sKzOgXgT8DPgd8Bo4E7AftTJ8D8qAGwuoL67Tr0KbMXb2Huyu2UVVQy/oJuTB2dzfndOsQ7NOOFaAmhvu1uBLTmwk1SOEVVF4iIqLPa2n+JyNs4icIEkB81AFZXEF3B/mPMzM3j5TU7UYUvD+zB5FFZ9O7aLt6hGS9JUt1nCk0V0JoLN3UKpSLSCtgsIveJyJeBM3yOyzSDHzUAbo85dmz019e1PVF99Mlhvjl3DWN/k8ura3fx1aHpLPrBaP73K/0tIYRRVh3FO3VtdyOgNRduzhS+DaQC3wJ+AYwBvuZnUKZ5/FiDwO0x58//fGfz2LHh6WRes/0g03Pymb/RmZdo0ogs7r68F13bt413aMZPVZ3J+bOdMwZJchJCUzuZIbBrSDTY0Vy9o0gHQFX1iL8h1c86mk2sqSrvbCliek4eS/OKOC21NXde1os7LsukY2rreIdnjCuedTSLyBCczub2kcfFwF2qWu+SnCKSAiwG2kbe5yVV/XmtfdoCfwEGA0XAzaq6taGYjIkFVSXn471MW5jHe9sP0bV9W3589Xl89WIP5yUyJmDc9Ck8DUxR1UxVzQSm4iSJhhwHxqhqf2AAcJWIXFJrn7uBg6qajTO66RHXkScItwVfibDgjNsFeRK9zRWVyrz3d/Ho4z/nnOUX8FLH4Xww8BssvXk3k0Y0c14itwVQXhc1Bf148eS2LWFqcz3c/Os+oqrVi+qo6hIRafASUmSujaORh60jt9rXqq4F/ivy80vAtMgop1CMaHdb8JUIC864XZAnkdt8oqKSV9bsZOaifPqdmMejadNIaXUcgHYVO2H1vZDUqunXfN0WQHld1BT048WT27aEqc0NcFO89jucjua5OL/UbwYOAn8DUNX36nltErAayAamq+oDtZ7/ALhKVQsjj/OBi1V1f13HTKQ+BbcFX4mw4IzbBXkSsc2lJyr466odzFq0hZ2HPuX8bh34W/dbST2x8/M7N6ewyG0BlNdFTUE/Xjy5bUsI2uxl8dqAyH3tuoTLcJLEmLpeqKoVwAAROQ14WUT6qeoHNeOM9rLaG0RkEjAJID2BBsa7LfhKhMIwtwvyJFKbjx4vZ87ybTzxdgH7jx5nUPpp/OK6vow+9wxk7q7oL2pOYZHbAiivi5qCfrx4ctuWMLW5AW5WXhvd3DdR1UMikgtcBdRMCoVAT6BQRJKBjsCBKK+fDcwG50yhufHEituCr0QoDHO7IE8itPlQSZkzL9HSrRR/eoLLs7swdfRALund6eS8RH4UFrktgPL6vYN+vHhy25YwtbkBDXY0i8iZIvKUiLwZedxHRO528bqukTMEROQUYBzwUa3dXuNkzcONwMKw9CeA+4KvRFhwxu2CPEFu874jx/nlmxsZ9quFPDZ/MxdlduKVqcN47p6LPz9RnR+FRW4LoLx+76AfL57ctiVMbW5IQwsuAG8CNwHrIo+TgfUuXnchsAZ4H+fs4GeR7Q8BEyI/pwB/BfKAlUDvho6baIvsuF1IJhEWnHG7IE/Q2lx4sER/+sp6PefBN7TXD+fpN59/TzfuLm74hX4s5rJisurzSapzcO5X1PEhev3eQT9ePLltS4K3meYuslNFRN5V1YtEZI2qDoxsW6uqA+p9oU8SqaPZxNeWfUeZmZvPy2t2IgLXD0zj3lFZ9OpyarxDMybmvFhkp8oxEelMpAM4UmtQ3Mz4TC1BHbOfiDbsOszU599j7G8X8dq6Xdx2SQaLvj+aR268MP4JIehj4v2IL+g1Ei2k/sAtN6OPvoNz7T9LRJYCXXGu/xuPBHHMfiJ6b/tBpi/MY8FHe2nXNpl7RzrzEnVpF5B5iYI+Jt6P+IJeI9GC6g/ccjX3UWRk0Lk4Q0g/VtUTfgdWlzBePgrSmP1Eo6q8k1/
EtJw8luU78xLdNawXX7s0gPMSBX1MvB/xBb1GIgT1B255OffRV4C3VPVDEfkJMEhE/lvrKVozjROEMfuJRlVZsHEv03PzWLP9EGe0b8tPxp/PrUPTOTWo8xIFfUy8H/EFvUaiBdUfuOWmT+GnqnpERC4HrgSeAWY28BrTCHWNzQ9SnUJQVFQqr6/bxRd//zb3/GUV+44c57+v68fiH4zmnuG9g5sQoO4x7dHGxDfm9V7xIz6v2xL044WAm6RQVW0zHpipqq8CbfwLqeVJhDqFeCsrr+T/3t3BuN8u4ptz13CiopLffKU/Od8bxW2XZJDSuhkrYMVK0MfE+xFf0GskWlL9gUtu/qzaKSKP4xSfPRKZ7tpNMjEu+bEoTliUnqjgxXd3MHuxMy9R3+4dmDlxEFf2PYtWrRJsqXC3i6rEa/EVP+Lzui1BP14IuKlTSMWZnmK9qm4WkW7ABar6r1gEWFsYO5rN5x09Xs5zy7fxZGReoiEZpzN1TDajzun62cpjY4wrnnU0q2oJ8Pcaj3cDu5sXnjHRHTzmzEv052XOvETDz3bmJbq4V6eWlQxWTvF26UevBT0+cIabxuMMIF7v65EA98qZlmTv4VKeXFLAc8u3UVJWwRV9zmTq6Gz69zwt3qHFntt1F+Il6PFB8Gs9Asz1Gs1BYZePwqXwYAmPL9rCi6t2UF5RyZf6d2fKqGzOPat9vEOLH7frLsRL0OOD4Nd6xIGX6ykY47n8fUeZkZPPq2udeYluHJzGN0ZkkRnvaSiCwO26C/ES9Pgg+LUeAWZJwcTUh7uKmZGTzxsf7KZtcituvzSDSSN6063jKfEOLTjcrrsQL0GPD+K3/kEI1l2woaUmJlZvO8hdf36X8X9YwuJN+5g8MoslD4zh51/qawmhNrfrLsRL0OOD4Nd6BJidKRjfqCpL84qYlrOZ5VsOcHpqa753xTncfmkmHU8J2LxEQVLVWRvU0T1Bjw+CX+sRYNbRbDxXWaks+Ggv03LyWLfjEGd2aMvXh/fmqxenk9rG/g4xJh6so9nEXEWlMu/9XczMzeejT47Qs9MpPPzlftw4OI22yTG43pwI48O9jtHreoFE+AyNrywpmGYrK6/k5TWFzMx7pxsCAAARvklEQVTNZ2tRCdlntON3N/fnSxd2JzkpRt1WiTA+3OsYva4XSITP0PjOLh+ZJis9UcELK7cze/EWdhWX0q9HB+4bnc0VfeIwL1GAx4dX8zpGr+sFEuEzNE1ml4+Mb46UnuDZ5dt4ekkB+4+WcVHm6fzP9RcwMp7zEiXC+HCvY/S6XiARPkPjO0sKxrWDx8r409IC/rxsK4dLyxlxTlfuG53N0F6d4h1aYowP9zpGr+sFEuEzNL6zOgXToD2HS3n4HxsY9shC/rAwj0uzOvPafcP4y11Dg5EQIDHGh3sdo9f1AonwGRrf2ZmCqdOOAyXMWpTPX1cVUqHKhP7dmTwqi3PODOC8RIkwPtzrGL2uF0iEz9D4zjqazefk7T3KjNw8Xl27iyQRbhicxuSRWaR3Tm34xcaYQLKOZtNoH+wsZnpOHm99+Altk1vxtUszmTSiN2d1TIl3aMHn9fh+t8ezugLjMUsKhlVbDzAtJ4/cj/fRvm0yU0dlc+ewTDq3axvv0BKD1+P73R7P6gqMD+zyUQulqizJ28+0hXmsKDhAp1PbcPflvbj90gw6pNi8RI3i9fh+t8ezugLTCHb5yERVWan8e+MeZuTksa6wmLM6pPDTa/pw69CeNi9RU3k9vt/t8ayuwPjAfgu0EOUVlfxj/W6m5+Sxac9R0jul8svrL+D6QT1iMy9RmHk9vt/t8ayuwPjA6hRC7nh5BXNXbmfsbxdx/wtrUYXHbh7Awu+O5Nah6ZYQvOD1+H63x7O6AuMDO1MIqU/LnGQwe/EWPjlcygU9OjLrtsFc0efM2M9LFHZej+93ezyrKzA+sI7mkDlceoJn33HmJSo6VsbQXp24b3Q2w8/uEr95iYwxcRf
3jmYR6Qn8BTgLqARmq+rva+0zCngVKIhs+ruqPuRXTGF24FgZTy8p4Jl3tnKktJyR53TlvjHZXJQZkGkojDEJwc/LR+XAd1X1PRFpD6wWkX+r6oZa+72tqtf4GEeo7TlcyuzFW3h+xXZKyyu4qu9ZTB2dTb8eHeMdmncSoUDLis2azz6bQPAtKajqbmB35OcjIrIR6AHUTgqmCbYXlTBrcT4vReYlujYyL9HZQZyXqDkSoUDLis2azz6bwIhJn4KIZAKLgX6qerjG9lHA34BCYBfwPVX9sL5jtfQ+hc17jjAjN5/X1jnzEt04JI17R4R4XqJEKNCyYrPms8/Gd3HvU6gRSDucX/zfrpkQIt4DMlT1qIhcDbwCnB3lGJOASQDp6S1zDPb6Qmdeon9u+ISU5CTuvCyTe4a3gHmJEqFAy4rNms8+m8DwNSmISGuchDBHVf9e+/maSUJV3xCRGSLSRVX319pvNjAbnDMFP2MOmpUFzrxEizfto31KMveNzubOYb3odGqbeIcWG4lQoGXFZs1nn01g+Fa8Js74x6eAjar62zr2OSuyHyIyNBJPkV8xJQpVZdGmfdw06x1uevwdPtxZzPevPJelPxzDd684t+UkBEiMAi0rNms++2wCw88zhWHA7cB6EVkb2fZjIB1AVWcBNwKTRaQc+BS4RROtcMJDlZXKvzbsYXpOHut3FtOtYwo//1IfbrkonVPatNDK40Qo0LJis+azzyYwrHgtAMorKnn9/V3MyMln896jZHROZfLILK4flEabZJuJxBjTfIHpaDZ1O15ewd9W72TWony2HyjhnDPb8ftbBjD+gm4kJyVwMmiJ481XTvFuWUxj4siSQhyUlJXz/IrtPPH2FvYcPk7/tI78ZPxgxp0fgnmJWuJ485VTIG/mycdacfKxJQaTYOzyUQwVf3qCZ9/ZytNLt3LgWBkX9+rEfWOyuTw7RPMStcTx5nOTnURQmyTBreWxj8eYKOzyUYAUHT3O00sL+MuybRw5Xs6oc7ty3+hshoRxXqKWON48WkKob7sxAWZJwUe7iz9l9uItzF25nePllXyx31lMGRWyeYlqa4njzSWp7jMFYxKMJQUfbCs6xqxF+by0upBKhesG9GDyqN5knxGyeYmi6f/wZ/sUIPzjzbMmfbZPoeZ2YxKMJQUPbdpzhBk5eby2bhfJSa24+aKefGNEFj07hXReomha4njzqs5kG31kQsA6mj3wfuEhpi3M418b9pDaJomJF6fz9eG9OaNDyOclMsYkDLcdzQk8GD7+Vmwp4vanVjBh2lKWbyniW2OyWfrAGB4c38e3hDBnDmRmQqtWzv2cOb68TWwUzHFGKz3fyrkvSODGhKkt8WKfYSDY5aNGUlVyN+1jRk4e7249SJd2bXjgqvO47ZJ02qe09vW958yBSZOgJHK5fts25zHAxES7OhOmeoYwtSVe7DMMDLt85FJlpfLPDz9hem4eH+w8TLeOKXxjRG9ujuG8RJmZTiKoLSMDtm6NSQjeCVM9Q5jaEi/2GfrO6hQ8cqKiktfX7WJGbj55e4+S2TmVR264gC8PjP28RNvrGOpf1/ZAC1M9Q5jaEi/2GQaGJYU6lJ6o4KXVhcxalE/hwU8576z2/OHWgYy/oBtJcZqKIj09+plCQq47FKZ6hjC1JV7sMwwM62iupaSsnCff3sKIR3P4ySsf0LldW574jyG88a3hTOjfPW4JAeDhhyG11ujW1FRne8IJ0/z5YWpLvNhnGBh2phBR/OkJnlm2lT8tLeBgyQku7d2Z3908gMuyOgdmXqKqzuQHH3QuGaWnOwkh4TqZIVz1DGFqS7zYZxgYLb6jef/R4zy1pIBn39nG0ePljDnvDKaOzmZwxumevYcxxsSbdTQ3YNchZ16iF9515iW6ul83pozOom/3EM9LZExD/FgLoyWur5HAWlxS2Lr/GDNz8/n7mkJU4bqBPZg8Kousru3iHZox8eVHrYDVHyScFpMU8vcd5ff
zNzPvfWdeolsuSmfSiN4ta14iY+qz7sHPTmQIzuN1Dzb9F7gfxzS+ajFJYfuBEhZs3MPXh/fm7st72bxExtTmR62A1R8knBaTFEad05VlPxxLx1R/p6IwJmH5UStg9QcJp8XUKYiIJQRj6uNHrYDVHyScFpMUjDEN6DURhs525htCnPuhs5t37d+PYxpftfg6BWOMaQlsPQVjjDGNZknBGGNMNUsKxhhjqllSMMYYU82SgjHGmGqWFIwxxlSzpGCMMaaaJQVjjDHVfEsKItJTRHJEZKOIfCgi90fZR0TkDyKSJyLvi8ggv+IxxhjTMD/PFMqB76rq+cAlwFQR6VNrny8CZ0duk4CZPsZjgqZgDrySCc+3cu4L5sQ7ImNaPN+SgqruVtX3Ij8fATYCPWrtdi3wF3UsB04TkW5+xWQCpGrxlZJtgJ5cfMUSgzFxFZM+BRHJBAYCK2o91QPYUeNxIZ9PHCaM6lt8xRgTN74nBRFpB/wN+LaqHq79dJSXfG6GPhGZJCKrRGTVvn37/AjTxJotvmJMIPmaFESkNU5CmKOqf4+ySyHQs8bjNGBX7Z1UdbaqDlHVIV27dvUnWBNbdS2yYouvGBNXfo4+EuApYKOq/raO3V4D/iMyCukSoFhVd/sVkwkQW3zFmEDycznOYcDtwHoRWRvZ9mMgHUBVZwFvAFcDeUAJcKeP8ZggqVpkZd2DziWj1HQnIdjiK8bElW9JQVWXEL3PoOY+Ckz1KwYTcL0mWhIwJmCsotkYY0w1SwrGGGOqWVIwxhhTzZKCMcaYapYUjDHGVBNnAFDiEJF9wLYmvrwLsN/DcOLJ2hJMYWlLWNoB1pYqGaraYPVvwiWF5hCRVao6JN5xeMHaEkxhaUtY2gHWlsayy0fGGGOqWVIwxhhTraUlhdnxDsBD1pZgCktbwtIOsLY0SovqUzDGGFO/lnamYIwxph6hTQoikiQia0RkXpTn2orIiyKSJyIrIivDBVYDbblDRPaJyNrI7Z54xOiGiGwVkfWROFdFeV5E5A+R7+V9ERkUjzgb4qIdo0SkuMZ38rN4xOmGiJwmIi+JyEcislFELq31fEJ8J+CqLQnxvYjIuTViXCsih0Xk27X28e178XPq7Hi7H2dd6A5RnrsbOKiq2SJyC/AIcHMsg2uk+toC8KKq3hfDeJpjtKrWNc76i8DZkdvFwMzIfRDV1w6At1X1mphF03S/B95S1RtFpA1Qa5GLhPpOGmoLJMD3oqofAwPA+YMQ2Am8XGs3376XUJ4piEgaMB54so5drgWeifz8EjA2sihQ4LhoS5hcC/xFHcuB00SkW7yDCisR6QCMwFkMC1UtU9VDtXZLiO/EZVsS0VggX1VrF+z69r2EMikAjwE/ACrreL4HsANAVcuBYqBzbEJrtIbaAnBD5BTyJRHpWc9+8abAv0RktYhMivJ89fcSURjZFjQNtQPgUhFZJyJvikjfWAbXCL2BfcCfIpcnnxSRU2vtkyjfiZu2QGJ8LzXdAsyNst237yV0SUFErgH2qurq+naLsi1ww7BctuV1IFNVLwTmc/IMKIiGqeognFPfqSIyotbzCfG90HA73sOZUqA/8EfglVgH6FIyMAiYqaoDgWPAD2vtkyjfiZu2JMr3AkDkEtgE4K/Rno6yzZPvJXRJAWcZ0AkishV4ARgjIs/V2qcQ6AkgIslAR+BALIN0qcG2qGqRqh6PPHwCGBzbEN1T1V2R+70410iH1tql+nuJSAN2xSY69xpqh6oeVtWjkZ/fAFqLSJeYB9qwQqBQVVdEHr+E84u19j6B/05w0ZYE+l6qfBF4T1X3RHnOt+8ldElBVX+kqmmqmolz6rVQVW+rtdtrwNciP98Y2Sdwf/24aUut64gTcDqkA0dEThWR9lU/A1cAH9Ta7TXgPyIjKy4BilV1d4xDrZebdojIWVV9VCIyFOf/WVGsY22Iqn4C7BCRcyObxgIbau0W+O8E3LUlUb6XGm4l+qUj8PF7CfP
oo88QkYeAVar6Gk5n1LMikodzhnBLXINrpFpt+ZaITADKcdpyRzxjq8eZwMuR/5PJwPOq+paI3AugqrOAN4CrgTygBLgzTrHWx007bgQmi0g58ClwSxD/6Ij4JjAncqliC3BnAn4nVRpqS8J8LyKSCnwB+EaNbTH5Xqyi2RhjTLXQXT4yxhjTdJYUjDHGVLOkYIwxppolBWOMMdUsKRhjjKlmScGYRorMthltxtqo2z14v+tEpE+Nx7kiEoo1h03wWFIwJviuA/o0uJcxHrCkYEInUnX8j8jEZx+IyM2R7YNFZFFkIrt/VlWDR/7yfkxElkX2HxrZPjSybU3k/tz63jdKDE+LyLuR118b2X6HiPxdRN4Skc0i8miN19wtIpsi8TwhItNE5DKcSvX/FWdu/azI7l8RkZWR/Yd79NEZ03Iqmk2LchWwS1XHA4hIRxFpjTMJ2rWqui+SKB4G7oq85lRVvSwyud3TQD/gI2CEqpaLyDjgf4AbXMbwIM60JHeJyGnAShGZH3luADAQOA58LCJ/BCqAn+LM13MEWAisU9VlIvIaME9VX4q0ByBZVYeKyNXAz4FxTfmgjKnNkoIJo/XAr0XkEZxfpm+LSD+cX/T/jvxSTQJqzhUzF0BVF4tIh8gv8vbAMyJyNs4MlK0bEcMVOJMZfi/yOAVIj/y8QFWLAURkA5ABdAEWqeqByPa/AufUc/y/R+5XA5mNiMuYellSMKGjqptEZDDO3DC/FJF/4cxm+qGqXlrXy6I8/gWQo6pfFmfJ1txGhCHADZFVtE5uFLkY5wyhSgXO/8PGLvJUdYyq1xvjCetTMKEjIt2BElV9Dvg1ziWZj4GuElm3V0Ray2cXWanqd7gcZ8bJYpwp1XdGnr+jkWH8E/hmjVk5Bzaw/0pgpIicLs507jUvUx3BOWsxxnf2F4YJowtwOmYrgRPAZFUtE5EbgT+ISEecf/uPAR9GXnNQRJbhrINd1c/wKM7lo+/gXONvjF9Ejv9+JDFsBepcG1hVd4rI/wArcObF34CzIiA4a2k8ISLfwpnp0xjf2CyppsUTkVzge6q6Ks5xtFPVo5EzhZeBp1W19oLtxvjKLh8ZExz/JSJrcRbtKSDgy0WacLIzBWOMMdXsTMEYY0w1SwrGGGOqWVIwxhhTzZKCMcaYapYUjDHGVLOkYIwxptr/B/cTZEbiZPByAAAAAElFTkSuQmCC\n", 317 | "text/plain": [ 318 | "
" 319 | ] 320 | }, 321 | "metadata": {}, 322 | "output_type": "display_data" 323 | } 324 | ], 325 | "source": [ 326 | "x_ponits = np.arange(4, 8)\n", 327 | "y_ = -(clf.coef_[0][0]*x_ponits + clf.intercept_)/clf.coef_[0][1]\n", 328 | "plt.plot(x_ponits, y_)\n", 329 | "\n", 330 | "plt.plot(X[:50, 0], X[:50, 1], 'bo', color='blue', label='0')\n", 331 | "plt.plot(X[50:, 0], X[50:, 1], 'bo', color='orange', label='1')\n", 332 | "plt.xlabel('sepal length')\n", 333 | "plt.ylabel('sepal width')\n", 334 | "plt.legend()" 335 | ] 336 | } 337 | ], 338 | "metadata": { 339 | "kernelspec": { 340 | "display_name": "Python 3", 341 | "language": "python", 342 | "name": "python3" 343 | }, 344 | "language_info": { 345 | "codemirror_mode": { 346 | "name": "ipython", 347 | "version": 3 348 | }, 349 | "file_extension": ".py", 350 | "mimetype": "text/x-python", 351 | "name": "python", 352 | "nbconvert_exporter": "python", 353 | "pygments_lexer": "ipython3", 354 | "version": "3.6.2" 355 | } 356 | }, 357 | "nbformat": 4, 358 | "nbformat_minor": 2 359 | } 360 | -------------------------------------------------------------------------------- /code/第6章 逻辑斯谛回归(LogisticRegression)/最大熵模型 IIS.py: -------------------------------------------------------------------------------- 1 | import math 2 | from copy import deepcopy 3 | 4 | 5 | class MaxEntropy: 6 | def __init__(self, EPS=0.005): 7 | self._samples = [] 8 | self._Y = set() # 标签集合,相当去去重后的y 9 | self._numXY = {} # key为(x,y),value为出现次数 10 | self._N = 0 # 样本数 11 | self._Ep_ = [] # 样本分布的特征期望值 12 | self._xyID = {} # key记录(x,y),value记录id号 13 | self._n = 0 # 特征键值(x,y)的个数 14 | self._C = 0 # 最大特征数 15 | self._IDxy = {} # key为(x,y),value为对应的id号 16 | self._w = [] 17 | self._EPS = EPS # 收敛条件 18 | self._lastw = [] # 上一次w参数值 19 | 20 | def loadData(self, dataset): 21 | self._samples = deepcopy(dataset) 22 | for items in self._samples: 23 | y = items[0] 24 | X = items[1:] 25 | self._Y.add(y) # 集合中y若已存在则会自动忽略 26 | for x in X: 27 | if (x, y) in self._numXY: 28 | 
self._numXY[(x, y)] += 1 29 | else: 30 | self._numXY[(x, y)] = 1 31 | 32 | self._N = len(self._samples) 33 | self._n = len(self._numXY) 34 | self._C = max([len(sample)-1 for sample in self._samples]) 35 | self._w = [0]*self._n 36 | self._lastw = self._w[:] 37 | 38 | self._Ep_ = [0] * self._n 39 | for i, xy in enumerate(self._numXY): # 计算特征函数fi关于经验分布的期望 40 | self._Ep_[i] = self._numXY[xy]/self._N 41 | self._xyID[xy] = i 42 | self._IDxy[i] = xy 43 | 44 | def _Zx(self, X): # 计算每个Z(x)值 45 | zx = 0 46 | for y in self._Y: 47 | ss = 0 48 | for x in X: 49 | if (x, y) in self._numXY: 50 | ss += self._w[self._xyID[(x, y)]] 51 | zx += math.exp(ss) 52 | return zx 53 | 54 | def _model_pyx(self, y, X): # 计算每个P(y|x) 55 | zx = self._Zx(X) 56 | ss = 0 57 | for x in X: 58 | if (x, y) in self._numXY: 59 | ss += self._w[self._xyID[(x, y)]] 60 | pyx = math.exp(ss)/zx 61 | return pyx 62 | 63 | def _model_ep(self, index): # 计算特征函数fi关于模型的期望 64 | x, y = self._IDxy[index] 65 | ep = 0 66 | for sample in self._samples: 67 | if x not in sample: 68 | continue 69 | pyx = self._model_pyx(y, sample) 70 | ep += pyx/self._N 71 | return ep 72 | 73 | def _convergence(self): # 判断是否全部收敛 74 | for last, now in zip(self._lastw, self._w): 75 | if abs(last - now) >= self._EPS: 76 | return False 77 | return True 78 | 79 | def predict(self, X): # 计算预测概率 80 | Z = self._Zx(X) 81 | result = {} 82 | for y in self._Y: 83 | ss = 0 84 | for x in X: 85 | if (x, y) in self._numXY: 86 | ss += self._w[self._xyID[(x, y)]] 87 | pyx = math.exp(ss)/Z 88 | result[y] = pyx 89 | return result 90 | 91 | def train(self, maxiter=1000): # 训练数据 92 | for loop in range(maxiter): # 最大训练次数 93 | print("iter:%d" % loop) 94 | self._lastw = self._w[:] 95 | for i in range(self._n): 96 | ep = self._model_ep(i) # 计算第i个特征的模型期望 97 | self._w[i] += math.log(self._Ep_[i]/ep)/self._C # 更新参数 98 | print("w:", self._w) 99 | if self._convergence(): # 判断是否收敛 100 | break 101 | 102 | 103 | dataset = [['no', 'sunny', 'hot', 'high', 'FALSE'], 104 | ['no', 
'sunny', 'hot', 'high', 'TRUE'], 105 | ['yes', 'overcast', 'hot', 'high', 'FALSE'], 106 | ['yes', 'rainy', 'mild', 'high', 'FALSE'], 107 | ['yes', 'rainy', 'cool', 'normal', 'FALSE'], 108 | ['no', 'rainy', 'cool', 'normal', 'TRUE'], 109 | ['yes', 'overcast', 'cool', 'normal', 'TRUE'], 110 | ['no', 'sunny', 'mild', 'high', 'FALSE'], 111 | ['yes', 'sunny', 'cool', 'normal', 'FALSE'], 112 | ['yes', 'rainy', 'mild', 'normal', 'FALSE'], 113 | ['yes', 'sunny', 'mild', 'normal', 'TRUE'], 114 | ['yes', 'overcast', 'mild', 'high', 'TRUE'], 115 | ['yes', 'overcast', 'hot', 'normal', 'FALSE'], 116 | ['no', 'rainy', 'mild', 'high', 'TRUE']] 117 | 118 | maxent = MaxEntropy() 119 | x = ['overcast', 'mild', 'high', 'FALSE'] 120 | maxent.loadData(dataset) 121 | maxent.train() 122 | print('predict:', maxent.predict(x)) 123 | -------------------------------------------------------------------------------- /code/第7章 支持向量机(SVM)/support-vector-machine.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Original code author: https://github.com/wzyonggege/statistical-learning-method\n", 8 | "\n", 9 | "Chinese annotations by: 机器学习初学者 (WeChat public account ID: ai-start-com)\n", 10 | "\n", 11 | "Environment: python 3.6\n", 12 | "\n", 13 | "All code has been tested and passes.\n", 14 | "![gongzhong](../gongzhong.jpg)" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "# Chapter 7: Support Vector Machines" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "----\n", 29 | "Separating hyperplane: $w^Tx+b=0$\n", 30 | "\n", 31 | "Distance from a point to the hyperplane: $r=\\frac{|w^Tx+b|}{||w||_2}$\n", 32 | "\n", 33 | "where $||w||_2$ is the 2-norm: $||w||_2=\\sqrt[2]{\\sum^m_{i=1}w_i^2}$\n", 34 | "\n", 35 | "Taking this hyperplane as the decision boundary, correctly classified samples satisfy:\n", 36 | "\n", 37 | "$w^Tx+b\\ \\geq+1$\n", 38 | "\n", 39 | "$w^Tx+b\\ \\leq-1$\n", 40 | "\n", 41 | "#### margin:\n", 42 | "\n", 43 | "**Functional margin**: $label(w^Tx+b)\\ or\\ y_i(w^Tx+b)$\n", 44 | "\n", 45 |
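The functional margin above, and the geometric margin obtained by normalizing it, are easy to compute directly; a minimal sketch, where `w`, `b`, `X` and `y` are made-up illustrative values (not taken from the notebook):

```python
import numpy as np

# Made-up separating hyperplane and labeled points, for illustration only.
w = np.array([1.0, -1.0])
b = 0.5
X = np.array([[2.0, 0.0], [0.0, 2.0]])
y = np.array([1.0, -1.0])

functional = y * (X @ w + b)                # label * (w^T x + b)
geometric = functional / np.linalg.norm(w)  # divide by ||w||_2
print(functional, geometric)
```

A positive functional margin means the point is on the correct side of the hyperplane; the geometric margin is its signed Euclidean distance to the hyperplane.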
"**Geometric margin**: $r=\\frac{label(w^Tx+b)}{||w||_2}$; when a sample is correctly classified, the geometric margin is exactly its distance to the hyperplane\n", 46 | "\n", 47 | "To maximize the geometric margin, the basic SVM problem can be posed as (here $\\frac{r^*}{||w||}$ is the geometric margin and ${r^*}$ the functional margin):\n", 48 | "\n", 49 | "$$\\max\\ \\frac{r^*}{||w||}$$\n", 50 | "\n", 51 | "$$(subject\\ to)\\ y_i({w^T}x_i+{b})\\geq {r^*},\\ i=1,2,..,m$$\n", 52 | "\n", 53 | "i.e. maximize the geometric margin while classifying every point correctly. This problem is not convex as stated, so we first (1) convert it into a convex problem, and then (2) solve the dual problem using Lagrange multipliers and the KKT conditions.\n", 54 | "\n", 55 | "(1) Conversion to a convex problem:\n", 56 | "\n", 57 | "First set ${r^*}=1$ for convenience (this only rescales the reference and does not change the optimizer)\n", 58 | "\n", 59 | "$$\\max\\ \\frac{1}{||w||}$$\n", 60 | "\n", 61 | "$$s.t.\\ y_i({w^T}x_i+{b})\\geq {1},\\ i=1,2,..,m$$\n", 62 | "\n", 63 | "Then replace $\\max\\ \\frac{1}{||w||}$ with the equivalent convex problem $\\min\\ \\frac{1}{2}||w||^2$; the factor 1/2 just simplifies the derivative.\n", 64 | "\n", 65 | "$$\\min\\ \\frac{1}{2}||w||^2$$\n", 66 | "\n", 67 | "$$s.t.\\ y_i(w^Tx_i+b)\\geq 1,\\ i=1,2,..,m$$\n", 68 | "\n", 69 | "(2) Solve for the optimum using Lagrange multipliers and the KKT conditions:\n", 70 | "\n", 71 | "$$\\min\\ \\frac{1}{2}||w||^2$$\n", 72 | "\n", 73 | "$$s.t.\\ -y_i(w^Tx_i+b)+1\\leq 0,\\ i=1,2,..,m$$\n", 74 | "\n", 75 | "Combined into the Lagrangian:\n", 76 | "\n", 77 | "$$L(w, b, \\alpha) = \\frac{1}{2}||w||^2+\\sum^m_{i=1}\\alpha_i(-y_i(w^Tx_i+b)+1)$$\n", 78 | "\n", 79 | "By duality: $\\min\\ f(x)=\\min \\max\\ L(w, b, \\alpha)\\geq \\max \\min\\ L(w, b, \\alpha)$\n", 80 | "\n", 81 | "From the KKT conditions:\n", 82 | "\n", 83 | "$$\\frac{\\partial }{\\partial w}L(w, b, \\alpha)=w-\\sum\\alpha_iy_ix_i=0,\\ w=\\sum\\alpha_iy_ix_i$$\n", 84 | "\n", 85 | "$$\\frac{\\partial }{\\partial b}L(w, b, \\alpha)=\\sum\\alpha_iy_i=0$$\n", 86 | "\n", 87 | "Substituting back into $ L(w, b, \\alpha)$:\n", 88 | "\n", 89 | "$\\min\\ L(w, b, \\alpha)=\\frac{1}{2}||w||^2+\\sum^m_{i=1}\\alpha_i(-y_i(w^Tx_i+b)+1)$\n", 90 | "\n", 91 | "$\\qquad\\qquad\\qquad=\\frac{1}{2}w^Tw-\\sum^m_{i=1}\\alpha_iy_iw^Tx_i-b\\sum^m_{i=1}\\alpha_iy_i+\\sum^m_{i=1}\\alpha_i$\n", 92 | "\n", 93 | "$\\qquad\\qquad\\qquad=\\frac{1}{2}w^T\\sum\\alpha_iy_ix_i-\\sum^m_{i=1}\\alpha_iy_iw^Tx_i+\\sum^m_{i=1}\\alpha_i$\n", 94 | "\n", 95 |
"$\\qquad\\qquad\\qquad=\\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i=1}\\alpha_iy_iw^Tx_i$\n", 96 | "\n", 97 | "$\\qquad\\qquad\\qquad=\\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)$\n", 98 | "\n", 99 | "Converting the max problem into a min problem:\n", 100 | "\n", 101 | "$\\max\\ \\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)=\\min \\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)-\\sum^m_{i=1}\\alpha_i$\n", 102 | "\n", 103 | "$s.t.\\ \\sum^m_{i=1}\\alpha_iy_i=0,$\n", 104 | "\n", 105 | "$ \\alpha_i \\geq 0,i=1,2,...,m$\n", 106 | "\n", 107 | "This is the dual form of the SVM optimization problem.\n", 108 | "\n", 109 | "-----\n", 110 | "#### kernel\n", 111 | "\n", 112 | "A kernel obtains, by computation in the low-dimensional input space, the inner products of a high-dimensional feature space; only in that high-dimensional space does the data become linearly separable.\n", 113 | "\n", 114 | "#### soft margin & slack variable\n", 115 | "\n", 116 | "Introduce slack variables $\\xi\\geq0$, the amount by which each point is allowed to violate the functional margin.\n", 117 | "\n", 118 | "Objective: $\\min\\ \\frac{1}{2}||w||^2+C\\sum\\xi_i\\qquad s.t.\\ y_i(w^Tx_i+b)\\geq1-\\xi_i$ \n", 119 | "\n", 120 | "Dual problem:\n", 121 | "\n", 122 | "$$\\max\\ \\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)=\\min \\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)-\\sum^m_{i=1}\\alpha_i$$\n", 123 | "\n", 124 | "$$s.t.\\ C\\geq\\alpha_i \\geq 0,i=1,2,...,m\\quad \\sum^m_{i=1}\\alpha_iy_i=0,$$\n", 125 | "\n", 126 | "-----\n", 127 | "\n", 128 | "#### Sequential Minimal Optimization\n", 129 | "\n", 130 | "First define the output function from features to prediction: $u=w^Tx+b$.\n", 131 | "\n", 132 | "Since $w=\\sum\\alpha_iy_ix_i$\n", 133 | "\n", 134 | "we have $u=\\sum y_i\\alpha_iK(x_i, x)+b$\n", 135 | "\n", 136 | "\n", 137 | "----\n", 138 | "\n", 139 | "$\\max \\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i=1}\\sum^m_{j=1}\\alpha_i\\alpha_jy_iy_j<\\phi(x_i)^T,\\phi(x_j)>$\n", 140 | "\n", 141 | "$s.t.\\ \\sum^m_{i=1}\\alpha_iy_i=0,$\n", 142 | "\n", 143 | "$ \\alpha_i \\geq 0,i=1,2,...,m$\n", 144 | "\n", 145 | "-----\n", 146 | "References:\n", 147 | "\n", 148 | "[1] :[Lagrange Multiplier and
KKT](http://blog.csdn.net/xianlingmao/article/details/7919597)\n", 149 | "\n", 150 | "[2] :[推导SVM](https://my.oschina.net/dfsj66011/blog/517766)\n", 151 | "\n", 152 | "[3] :[机器学习算法实践-支持向量机(SVM)算法原理](http://pytlab.org/2017/08/15/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E7%AE%97%E6%B3%95%E5%AE%9E%E8%B7%B5-%E6%94%AF%E6%8C%81%E5%90%91%E9%87%8F%E6%9C%BA-SVM-%E7%AE%97%E6%B3%95%E5%8E%9F%E7%90%86/)\n", 153 | "\n", 154 | "[4] :[Python实现SVM](http://blog.csdn.net/wds2006sdo/article/details/53156589)" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 1, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "import numpy as np\n", 164 | "import pandas as pd\n", 165 | "from sklearn.datasets import load_iris\n", 166 | "from sklearn.model_selection import train_test_split\n", 167 | "import matplotlib.pyplot as plt\n", 168 | "%matplotlib inline" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 2, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "# data\n", 178 | "def create_data():\n", 179 | " iris = load_iris()\n", 180 | " df = pd.DataFrame(iris.data, columns=iris.feature_names)\n", 181 | " df['label'] = iris.target\n", 182 | " df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n", 183 | " data = np.array(df.iloc[:100, [0, 1, -1]])\n", 184 | " for i in range(len(data)):\n", 185 | " if data[i,-1] == 0:\n", 186 | " data[i,-1] = -1\n", 187 | " # print(data)\n", 188 | " return data[:,:2], data[:,-1]" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 3, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "X, y = create_data()\n", 198 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 4, 204 | "metadata": {}, 205 | "outputs": [ 206 | { 207 | "data": { 208 | "text/plain": [ 209 | "" 210 | ] 211 | }, 212 | "execution_count": 4, 213 
| "metadata": {}, 214 | "output_type": "execute_result" 215 | }, 216 | { 217 | "data": { 218 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAGihJREFUeJzt3X+MXWWdx/H3d4dZOiowoYwrzJQtP0yjQNfCCJImxAV3q7UWgiyU4I8qC7sGFwwuRgxBbUzAkOCPJdEUyALCFrsVS2H5sQhLVAI1U8B2bSWCoJ2BXYZii6wFyvDdP+6ddubOnbn3ufeeuc/z3M8raTr33Ken3+cc/XJ7zuc819wdERHJy5+1uwAREWk9NXcRkQypuYuIZEjNXUQkQ2ruIiIZUnMXEcmQmruISIbU3EVEMqTmLiKSof3qHWhmXcAQMOLuyyreWwlcA4yUN13n7jfMtL9DDjnE58+fH1SsiEin27Rp00vu3ldrXN3NHbgE2AYcOM37P3T3z9e7s/nz5zM0NBTw14uIiJn9rp5xdV2WMbMB4KPAjJ/GRUQkDvVec/828CXgrRnGfNzMNpvZOjObV22AmV1oZkNmNjQ6Ohpaq4iI1KlmczezZcCL7r5phmF3AfPdfSHwE+DmaoPcfbW7D7r7YF9fzUtGIiLSoHquuS8GlpvZUmAOcKCZ3erunxgf4O47Joy/Hvhma8sUEWmdPXv2MDw8zGuvvdbuUqY1Z84cBgYG6O7ubujP12zu7n45cDmAmX0Q+OeJjb28/VB3f6H8cjmlG68iIlEaHh7mgAMOYP78+ZhZu8uZwt3ZsWMHw8PDHHHEEQ3to+Gcu5mtMrPl5ZcXm9mvzOyXwMXAykb3KyJStNdee425c+dG2dgBzIy5c+c29S+LkCgk7v4w8HD55ysnbN/76V4kN+ufGOGa+5/i+Z27Oay3h8uWLOCMRf3tLkuaFGtjH9dsfUHNXaTTrH9ihMvv2MLuPWMAjOzczeV3bAFQg5eoafkBkRlcc/9Texv7uN17xrjm/qfaVJHk4r777mPBggUcffTRXH311S3fv5q7yAye37k7aLtIPcbGxrjooou499572bp1K2vWrGHr1q0t/Tt0WUZkBof19jBSpZEf1tvThmqkXVp93+UXv/gFRx99NEceeSQAK1as4M477+S9731vq0rWJ3eRmVy2ZAE93V2TtvV0d3HZkgVtqkhm2/h9l5Gdu3H23XdZ/8RIzT87nZGREebN2/cg/8DAACMjje+vGjV3kRmcsaifq848jv7eHgzo7+3hqjOP083UDlLEfRd3n7Kt1ekdXZYRqeGMRf1q5h2siPsuAwMDbN++fe/r4eFhDjvssIb3V40+uYuIzGC6+yvN3Hd5//vfz29+8xueffZZ3njjDW6//XaWL19e+w8GUHMXEZlBEfdd9ttvP6677jqWLFnCe97zHs4++2yOOeaYZkud/He0dG8iIpkZvyTX6qeUly5dytKlS1tRYlVq7iIiNaR430WXZUREMqTmLiKSITV3EZEMqbmLiGRIzV1EJENq7pKN9U+MsPjqhzjiy//B4qsfamrtD5Giffazn+Wd73wnxx57bCH7V3OXLBSxuJNIkVauXMl9991X2P7V3CUL+lINKdTmtfCtY+FrvaXfN69tepennHIKBx98cAuKq04PMUkW9KUaUpjNa+Gui2FP+X9Lu7aXXgMsPLt9ddWgT+6ShSIWdxIB4MFV+xr7uD27S9sjpuYuWdCXakhhdg2HbY+ELstIFopa3EmEgwZKl2KqbY+YmrtkI8XFnSQBp105+Zo7QHdPaXsTzj33XB5++GFeeuklBgYG+PrXv87555/fZLH7qLlL01r95cEiURm/afr
gqtKlmIMGSo29yZupa9asaUFx01Nzl6aM58vHY4jj+XJADV7ysfDsqJMx1eiGqjRF+XKROKm5S1OUL5dUuXu7S5hRs/WpuUtTlC+XFM2ZM4cdO3ZE2+DdnR07djBnzpyG96Fr7tKUy5YsmHTNHZQvl/gNDAwwPDzM6Ohou0uZ1pw5cxgYaDxuqeYuTVG+XFLU3d3NEUcc0e4yClV3czezLmAIGHH3ZRXv7Q/cApwA7ADOcffnWlinREz5cpH4hHxyvwTYBhxY5b3zgT+4+9FmtgL4JnBOC+oTSYoy/xKLum6omtkA8FHghmmGnA7cXP55HXCamVnz5YmkQ2vKS0zqTct8G/gS8NY07/cD2wHc/U1gFzC36epEEqLMv8SkZnM3s2XAi+6+aaZhVbZNyRiZ2YVmNmRmQzHfpRZphDL/EpN6PrkvBpab2XPA7cCpZnZrxZhhYB6Ame0HHAS8XLkjd1/t7oPuPtjX19dU4SKxUeZfYlKzubv75e4+4O7zgRXAQ+7+iYphG4BPl38+qzwmzqcDRAqiNeUlJg3n3M1sFTDk7huAG4EfmNnTlD6xr2hRfSLJUOZfYmLt+oA9ODjoQ0NDbfm7RURSZWab3H2w1jg9oSrRumL9FtZs3M6YO11mnHvSPL5xxnHtLkskCWruEqUr1m/h1sd+v/f1mPve12rwIrVpVUiJ0pqNVb6zcobtIjKZmrtEaWyae0HTbReRydTcJUpd06xeMd12EZlMzV2idO5J84K2i8hkuqEqURq/aaq0jEhjlHMXEUmIcu7SlPOuf5RHntm3PNDiow7mtgtObmNF7aM12iVFuuYuU1Q2doBHnnmZ865/tE0VtY/WaJdUqbnLFJWNvdb2nGmNdkmVmrvIDLRGu6RKzV1kBlqjXVKl5i5TLD7q4KDtOdMa7ZIqNXeZ4rYLTp7SyDs1LXPGon6uOvM4+nt7MKC/t4erzjxOaRmJnnLuIiIJUc5dmlJUtjtkv8qXizROzV2mGM92j0cAx7PdQFPNNWS/RdUg0il0zV2mKCrbHbJf5ctFmqPmLlMUle0O2a/y5SLNUXOXKYrKdofsV/lykeaoucsURWW7Q/arfLlIc3RDVaYYv2HZ6qRKyH6LqkGkUyjnLiKSEOXcC5ZiBjvFmkWkMWruDUgxg51izSLSON1QbUCKGewUaxaRxqm5NyDFDHaKNYtI49TcG5BiBjvFmkWkcWruDUgxg51izSLSON1QbUCKGewUaxaRxtXMuZvZHOCnwP6U/mOwzt2/WjFmJXANMP6V8Ne5+w0z7Vc5dxGRcK3Mub8OnOrur5pZN/BzM7vX3R+rGPdDd/98I8XK7Lhi/RbWbNzOmDtdZpx70jy+ccZxTY+NJT8fSx0iMajZ3L300f7V8svu8q/2PNYqDbti/RZufez3e1+Pue99Xdm0Q8bGkp+PpQ6RWNR1Q9XMuszsSeBF4AF331hl2MfNbLOZrTOzeS2tUpq2ZuP2ureHjI0lPx9LHSKxqKu5u/uYu78PGABONLNjK4bcBcx394XAT4Cbq+3HzC40syEzGxodHW2mbgk0Ns29lWrbQ8bGkp+PpQ6RWARFId19J/Aw8OGK7Tvc/fXyy+uBE6b586vdfdDdB/v6+hooVxrVZVb39pCxseTnY6lDJBY1m7uZ9ZlZb/nnHuBDwK8rxhw64eVyYFsri5TmnXtS9Stl1baHjI0lPx9LHSKxqCctcyhws5l1UfqPwVp3v9vMVgFD7r4BuNjMlgNvAi8DK4sqWBozfiO0ngRMyNhY8vOx1CESC63nLiKSEK3nXrCiMtUh+fIi9x0yvxSPRXI2r4UHV8GuYThoAE67Ehae3e6qJGJq7g0oKlMdki8vct8h80vxWCRn81q462LYU07+7Npeeg1q8DItLRzWgKIy1SH58iL3HTK/FI9Fch5cta+xj9uzu7RdZBpq7g0oKlMdki8vct8h80vxWCRn13DYdhHU3BtSVKY6JF9e5L5D5pfisUjOQQN
h20VQc29IUZnqkHx5kfsOmV+KxyI5p10J3RX/sezuKW0XmYZuqDagqEx1SL68yH2HzC/FY5Gc8ZumSstIAOXcRUQSopy7TBFDdl0Sp7x9MtTcO0QM2XVJnPL2SdEN1Q4RQ3ZdEqe8fVLU3DtEDNl1SZzy9klRc+8QMWTXJXHK2ydFzb1DxJBdl8Qpb58U3VDtEDFk1yVxytsnRTl3EZGEKOdeVlReO2S/saxLrux6ZHLPjOc+vxBtOBZZN/ei8toh+41lXXJl1yOTe2Y89/mFaNOxyPqGalF57ZD9xrIuubLrkck9M577/EK06Vhk3dyLymuH7DeWdcmVXY9M7pnx3OcXok3HIuvmXlReO2S/saxLrux6ZHLPjOc+vxBtOhZZN/ei8toh+41lXXJl1yOTe2Y89/mFaNOxyPqGalF57ZD9xrIuubLrkck9M577/EK06Vgo5y4ikhDl3AsWQ37+vOsf5ZFnXt77evFRB3PbBSc3XYNIVu6+FDbdBD4G1gUnrIRl1za/38hz/Flfcy/KeGZ8ZOdunH2Z8fVPjMzafisbO8Ajz7zMedc/2lQNIlm5+1IYurHU2KH0+9CNpe3NGM+u79oO+L7s+ua1TZfcKmruDYghP1/Z2GttF+lIm24K216vBHL8au4NiCE/LyJ18LGw7fVKIMev5t6AGPLzIlIH6wrbXq8Ecvxq7g2IIT+/+KiDq+5juu0iHemElWHb65VAjl/NvQFnLOrnqjOPo7+3BwP6e3u46szjWpKfr3e/t11w8pRGrrSMSIVl18Lg+fs+qVtX6XWzaZmFZ8PHvgsHzQOs9PvHvhtVWkY5dxGRhLQs525mc4CfAvuXx69z969WjNkfuAU4AdgBnOPuzzVQd02h+fLU1jAPWfs992NRaI44JPtcVB1Fzi/yDHZTQueW87GYQT0PMb0OnOrur5pZN/BzM7vX3R+bMOZ84A/ufrSZrQC+CZzT6mJD1yRPbQ3zkLXfcz8Wha6BPZ59HjeefYapDb6oOoqcX85rqYfOLedjUUPNa+5e8mr5ZXf5V+W1nNOBm8s/rwNOM2v9soeh+fLU1jAPWfs992NRaI44JPtcVB1Fzi+BDHbDQueW87Gooa4bqmbWZWZPAi8CD7j7xooh/cB2AHd/E9gFzK2ynwvNbMjMhkZHR4OLDc2Bp5YbD1n7PfdjUWiOOCT7XFQdRc4vgQx2w0LnlvOxqKGu5u7uY+7+PmAAONHMjq0YUu1T+pSO5O6r3X3Q3Qf7+vqCiw3NgaeWGw9Z+z33Y1Fojjgk+1xUHUXOL4EMdsNC55bzsaghKArp7juBh4EPV7w1DMwDMLP9gIOAlj8HH5ovT20N85C133M/FoXmiEOyz0XVUeT8EshgNyx0bjkfixrqScv0AXvcfaeZ9QAfonTDdKINwKeBR4GzgIe8gIxl6Jrkqa1hHrL2e+7HotA1sMdvmtaTlimqjiLnl/Na6qFzy/lY1FAz525mCyndLO2i9El/rbuvMrNVwJC7byjHJX8ALKL0iX2Fu/92pv0q5y4iEq5lOXd330ypaVduv3LCz68BfxdapIiIFCP7L+tI7sEdmR0hD7bE8BBMkQ/upPaQVgznIwFZN/fkHtyR2RHyYEsMD8EU+eBOag9pxXA+EpH1wmHJPbgjsyPkwZYYHoIp8sGd1B7SiuF8JCLr5p7cgzsyO0IebInhIZgiH9xJ7SGtGM5HIrJu7sk9uCOzI+TBlhgeginywZ3UHtKK4XwkIuvmntyDOzI7Qh5sieEhmCIf3EntIa0Yzkcism7uRX2phiQu5IsWYvhShtAaYphfavvNkL6sQ0QkIS17iEmk44V8sUcsUqs5lux6LHW0gJq7yExCvtgjFqnVHEt2PZY6WiTra+4iTQv5Yo9YpFZzLNn1WOpoETV3kZmEfLFHLFKrOZbseix1tIiau8hMQr7YIxap1RxLdj2WOlpEzV1kJiFf7BGL1GqOJbseSx0touYuMpNl18Lg+fs+9Vp
X6XWMNybHpVZzLNn1WOpoEeXcRUQSopy7zJ4Us8FF1VxUvjzFYyxtpeYuzUkxG1xUzUXly1M8xtJ2uuYuzUkxG1xUzUXly1M8xtJ2au7SnBSzwUXVXFS+PMVjLG2n5i7NSTEbXFTNReXLUzzG0nZq7tKcFLPBRdVcVL48xWMsbafmLs1JMRtcVM1F5ctTPMbSdsq5i4gkpN6cuz65Sz42r4VvHQtf6y39vnnt7O+3qBpEAinnLnkoKgsesl/l0SUi+uQueSgqCx6yX+XRJSJq7pKHorLgIftVHl0iouYueSgqCx6yX+XRJSJq7pKHorLgIftVHl0iouYueSgqCx6yX+XRJSI1c+5mNg+4BXgX8Baw2t2/UzHmg8CdwLPlTXe4+4x3kZRzFxEJ18r13N8Evujuj5vZAcAmM3vA3bdWjPuZuy9rpFiJUIrrh4fUnOL8YqDjloyazd3dXwBeKP/8RzPbBvQDlc1dcpFiXlt59OLpuCUl6Jq7mc0HFgEbq7x9spn90szuNbNjWlCbtEuKeW3l0Yun45aUup9QNbN3AD8CvuDur1S8/Tjwl+7+qpktBdYD766yjwuBCwEOP/zwhouWgqWY11YevXg6bkmp65O7mXVTauy3ufsdle+7+yvu/mr553uAbjM7pMq41e4+6O6DfX19TZYuhUkxr608evF03JJSs7mbmQE3AtvcverapWb2rvI4zOzE8n53tLJQmUUp5rWVRy+ejltS6rkssxj4JLDFzJ4sb/sKcDiAu38fOAv4nJm9CewGVni71hKW5o3fHEspFRFSc4rzi4GOW1K0nruISEJamXOXWClzPNndl8Kmm0pfSG1dpa+3a/ZbkEQSpeaeKmWOJ7v7Uhi6cd9rH9v3Wg1eOpDWlkmVMseTbbopbLtI5tTcU6XM8WQ+FrZdJHNq7qlS5ngy6wrbLpI5NfdUKXM82Qkrw7aLZE7NPVVaO3yyZdfC4Pn7PqlbV+m1bqZKh1LOXUQkIcq5N2D9EyNcc/9TPL9zN4f19nDZkgWcsai/3WW1Tu65+NznFwMd42SouZetf2KEy+/Ywu49pXTFyM7dXH7HFoA8Gnzuufjc5xcDHeOk6Jp72TX3P7W3sY/bvWeMa+5/qk0VtVjuufjc5xcDHeOkqLmXPb9zd9D25OSei899fjHQMU6KmnvZYb09QduTk3suPvf5xUDHOClq7mWXLVlAT/fkB156uru4bMmCNlXUYrnn4nOfXwx0jJOiG6pl4zdNs03L5L4Wd+7zi4GOcVKUcxcRSUi9OXddlhFJwea18K1j4Wu9pd83r01j39I2uiwjErsi8+XKrmdLn9xFYldkvlzZ9WypuYvErsh8ubLr2VJzF4ldkflyZdezpeYuErsi8+XKrmdLzV0kdkWu3a/vBciWcu4iIglRzl1EpIOpuYuIZEjNXUQkQ2ruIiIZUnMXEcmQmruISIbU3EVEMqTmLiKSoZrN3czmmdl/mdk2M/uVmV1SZYyZ2XfN7Gkz22xmxxdTrjRF63aLdIx61nN/E/iiuz9uZgcAm8zsAXffOmHMR4B3l3+dBHyv/LvEQut2i3SUmp/c3f0Fd3+8/PMfgW1A5ReLng7c4iWPAb1mdmjLq5XGad1ukY4SdM3dzOYDi4CNFW/1A9snvB5m6n8AMLMLzWzIzIZGR0fDKpXmaN1ukY5Sd3M3s3cAPwK+4O6vVL5d5Y9MWZHM3Ve7+6C7D/b19YVVKs3Rut0iHaWu5m5m3ZQa+23ufkeVIcPAvAmvB4Dnmy9PWkbrdot0lHrSMgbcCGxz92unGbYB+FQ5NfMBYJe7v9DCOqVZWrdbpKPUk5ZZDHwS2GJmT5a3fQU4HMDdvw/cAywFngb+BHym9aVK0xaerWYu0iFqNnd3/znVr6lPHOPARa0qSkREmqMnVEVEMqTmLiKSITV3EZEMqbmLiGRIzV1EJENq7iIiGVJzFxHJkJUi6m34i81Ggd+15S+v7RDgpXYXUSDNL105zw0
0v3r8pbvXXJyrbc09ZmY25O6D7a6jKJpfunKeG2h+raTLMiIiGVJzFxHJkJp7davbXUDBNL905Tw30PxaRtfcRUQypE/uIiIZ6ujmbmZdZvaEmd1d5b2VZjZqZk+Wf/19O2pshpk9Z2ZbyvUPVXnfzOy7Zva0mW02s+PbUWcj6pjbB81s14Tzl9RXTplZr5mtM7Nfm9k2Mzu54v1kzx3UNb9kz5+ZLZhQ95Nm9oqZfaFiTOHnr54v68jZJcA24MBp3v+hu39+Fuspwl+7+3S52o8A7y7/Ogn4Xvn3VMw0N4CfufuyWaumtb4D3OfuZ5nZnwNvq3g/9XNXa36Q6Plz96eA90HpAyQwAvy4Yljh569jP7mb2QDwUeCGdtfSRqcDt3jJY0CvmR3a7qI6nZkdCJxC6estcfc33H1nxbBkz12d88vFacAz7l75wGbh569jmzvwbeBLwFszjPl4+Z9M68xs3gzjYuXAf5rZJjO7sMr7/cD2Ca+Hy9tSUGtuACeb2S/N7F4zO2Y2i2vSkcAo8K/ly4Y3mNnbK8akfO7qmR+ke/4mWgGsqbK98PPXkc3dzJYBL7r7phmG3QXMd/eFwE+Am2eluNZa7O7HU/on4EVmdkrF+9W+PjGV+FStuT1O6THtvwL+BVg/2wU2YT/geOB77r4I+D/gyxVjUj539cwv5fMHQPly03Lg36u9XWVbS89fRzZ3Sl/6vdzMngNuB041s1snDnD3He7+evnl9cAJs1ti89z9+fLvL1K65ndixZBhYOK/SAaA52enuubUmpu7v+Lur5Z/vgfoNrNDZr3QxgwDw+6+sfx6HaVmWDkmyXNHHfNL/PyN+wjwuLv/b5X3Cj9/Hdnc3f1ydx9w9/mU/tn0kLt/YuKYiutfyyndeE2Gmb3dzA4Y/xn4W+C/K4ZtAD5VvnP/AWCXu78wy6UGq2duZvYuM7PyzydS+t/6jtmutRHu/j/AdjNbUN50GrC1YliS5w7qm1/K52+Cc6l+SQZm4fx1elpmEjNbBQy5+wbgYjNbDrwJvAysbGdtDfgL4Mfl/3/sB/ybu99nZv8I4O7fB+4BlgJPA38CPtOmWkPVM7ezgM+Z2ZvAbmCFp/XE3j8Bt5X/af9b4DOZnLtxteaX9Pkzs7cBfwP8w4Rts3r+9ISqiEiGOvKyjIhI7tTcRUQypOYuIpIhNXcRkQypuYuIZEjNXUQkQ2ruIiIZUnMXEcnQ/wPmMFqpaGCFHwAAAABJRU5ErkJggg==\n", 219 | "text/plain": [ 220 | "
" 221 | ] 222 | }, 223 | "metadata": {}, 224 | "output_type": "display_data" 225 | } 226 | ], 227 | "source": [ 228 | "plt.scatter(X[:50,0],X[:50,1], label='0')\n", 229 | "plt.scatter(X[50:,0],X[50:,1], label='1')\n", 230 | "plt.legend()" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": {}, 236 | "source": [ 237 | "----\n", 238 | "\n" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 5, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "class SVM:\n", 248 | " def __init__(self, max_iter=100, kernel='linear'):\n", 249 | " self.max_iter = max_iter\n", 250 | " self._kernel = kernel\n", 251 | " \n", 252 | " def init_args(self, features, labels):\n", 253 | " self.m, self.n = features.shape\n", 254 | " self.X = features\n", 255 | " self.Y = labels\n", 256 | " self.b = 0.0\n", 257 | " \n", 258 | " # 将Ei保存在一个列表里\n", 259 | " self.alpha = np.ones(self.m)\n", 260 | " self.E = [self._E(i) for i in range(self.m)]\n", 261 | " # 松弛变量\n", 262 | " self.C = 1.0\n", 263 | " \n", 264 | " def _KKT(self, i):\n", 265 | " y_g = self._g(i)*self.Y[i]\n", 266 | " if self.alpha[i] == 0:\n", 267 | " return y_g >= 1\n", 268 | " elif 0 < self.alpha[i] < self.C:\n", 269 | " return y_g == 1\n", 270 | " else:\n", 271 | " return y_g <= 1\n", 272 | " \n", 273 | " # g(x)预测值,输入xi(X[i])\n", 274 | " def _g(self, i):\n", 275 | " r = self.b\n", 276 | " for j in range(self.m):\n", 277 | " r += self.alpha[j]*self.Y[j]*self.kernel(self.X[i], self.X[j])\n", 278 | " return r\n", 279 | " \n", 280 | " # 核函数\n", 281 | " def kernel(self, x1, x2):\n", 282 | " if self._kernel == 'linear':\n", 283 | " return sum([x1[k]*x2[k] for k in range(self.n)])\n", 284 | " elif self._kernel == 'poly':\n", 285 | " return (sum([x1[k]*x2[k] for k in range(self.n)]) + 1)**2\n", 286 | " \n", 287 | " return 0\n", 288 | " \n", 289 | " # E(x)为g(x)对输入x的预测值和y的差\n", 290 | " def _E(self, i):\n", 291 | " return self._g(i) - self.Y[i]\n", 292 | " \n", 293 | " def 
_init_alpha(self):\n", 294 | " # 外层循环首先遍历所有满足0= 0:\n", 307 | " j = min(range(self.m), key=lambda x: self.E[x])\n", 308 | " else:\n", 309 | " j = max(range(self.m), key=lambda x: self.E[x])\n", 310 | " return i, j\n", 311 | " \n", 312 | " def _compare(self, _alpha, L, H):\n", 313 | " if _alpha > H:\n", 314 | " return H\n", 315 | " elif _alpha < L:\n", 316 | " return L\n", 317 | " else:\n", 318 | " return _alpha \n", 319 | " \n", 320 | " def fit(self, features, labels):\n", 321 | " self.init_args(features, labels)\n", 322 | " \n", 323 | " for t in range(self.max_iter):\n", 324 | " # train\n", 325 | " i1, i2 = self._init_alpha()\n", 326 | " \n", 327 | " # 边界\n", 328 | " if self.Y[i1] == self.Y[i2]:\n", 329 | " L = max(0, self.alpha[i1]+self.alpha[i2]-self.C)\n", 330 | " H = min(self.C, self.alpha[i1]+self.alpha[i2])\n", 331 | " else:\n", 332 | " L = max(0, self.alpha[i2]-self.alpha[i1])\n", 333 | " H = min(self.C, self.C+self.alpha[i2]-self.alpha[i1])\n", 334 | " \n", 335 | " E1 = self.E[i1]\n", 336 | " E2 = self.E[i2]\n", 337 | " # eta=K11+K22-2K12\n", 338 | " eta = self.kernel(self.X[i1], self.X[i1]) + self.kernel(self.X[i2], self.X[i2]) - 2*self.kernel(self.X[i1], self.X[i2])\n", 339 | " if eta <= 0:\n", 340 | " # print('eta <= 0')\n", 341 | " continue\n", 342 | " \n", 343 | " alpha2_new_unc = self.alpha[i2] + self.Y[i2] * (E1 - E2) / eta#此处有修改,根据书上应该是E1 - E2,书上130-131页\n", 344 | " alpha2_new = self._compare(alpha2_new_unc, L, H)\n", 345 | " \n", 346 | " alpha1_new = self.alpha[i1] + self.Y[i1] * self.Y[i2] * (self.alpha[i2] - alpha2_new)\n", 347 | " \n", 348 | " b1_new = -E1 - self.Y[i1] * self.kernel(self.X[i1], self.X[i1]) * (alpha1_new-self.alpha[i1]) - self.Y[i2] * self.kernel(self.X[i2], self.X[i1]) * (alpha2_new-self.alpha[i2])+ self.b \n", 349 | " b2_new = -E2 - self.Y[i1] * self.kernel(self.X[i1], self.X[i2]) * (alpha1_new-self.alpha[i1]) - self.Y[i2] * self.kernel(self.X[i2], self.X[i2]) * (alpha2_new-self.alpha[i2])+ self.b \n", 350 | " \n", 351 | " if 0 
< alpha1_new < self.C:\n", 352 | " b_new = b1_new\n", 353 | " elif 0 < alpha2_new < self.C:\n", 354 | " b_new = b2_new\n", 355 | " else:\n", 356 | " # 选择中点\n", 357 | " b_new = (b1_new + b2_new) / 2\n", 358 | " \n", 359 | " # 更新参数\n", 360 | " self.alpha[i1] = alpha1_new\n", 361 | " self.alpha[i2] = alpha2_new\n", 362 | " self.b = b_new\n", 363 | " \n", 364 | " self.E[i1] = self._E(i1)\n", 365 | " self.E[i2] = self._E(i2)\n", 366 | " return 'train done!'\n", 367 | " \n", 368 | " def predict(self, data):\n", 369 | " r = self.b\n", 370 | " for i in range(self.m):\n", 371 | " r += self.alpha[i] * self.Y[i] * self.kernel(data, self.X[i])\n", 372 | " \n", 373 | " return 1 if r > 0 else -1\n", 374 | " \n", 375 | " def score(self, X_test, y_test):\n", 376 | " right_count = 0\n", 377 | " for i in range(len(X_test)):\n", 378 | " result = self.predict(X_test[i])\n", 379 | " if result == y_test[i]:\n", 380 | " right_count += 1\n", 381 | " return right_count / len(X_test)\n", 382 | " \n", 383 | " def _weight(self):\n", 384 | " # linear model\n", 385 | " yx = self.Y.reshape(-1, 1)*self.X\n", 386 | " self.w = np.dot(yx.T, self.alpha)\n", 387 | " return self.w" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": 6, 393 | "metadata": {}, 394 | "outputs": [], 395 | "source": [ 396 | "svm = SVM(max_iter=200)" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": 7, 402 | "metadata": {}, 403 | "outputs": [ 404 | { 405 | "data": { 406 | "text/plain": [ 407 | "'train done!'" 408 | ] 409 | }, 410 | "execution_count": 7, 411 | "metadata": {}, 412 | "output_type": "execute_result" 413 | } 414 | ], 415 | "source": [ 416 | "svm.fit(X_train, y_train)" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 8, 422 | "metadata": {}, 423 | "outputs": [ 424 | { 425 | "data": { 426 | "text/plain": [ 427 | "0.92" 428 | ] 429 | }, 430 | "execution_count": 8, 431 | "metadata": {}, 432 | "output_type": "execute_result" 433 | } 434 | 
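The core two-variable update inside `fit` above can be exercised in isolation; a minimal sketch of the unconstrained α₂ step and its clipping, where every numeric input is a made-up illustrative value:

```python
def smo_alpha2_step(alpha2, y2, E1, E2, K11, K22, K12, L, H):
    """One SMO update of the second multiplier, mirroring the fit() loop above."""
    eta = K11 + K22 - 2 * K12      # eta = K11 + K22 - 2*K12
    if eta <= 0:                   # skip degenerate pairs, as fit() does
        return alpha2
    alpha2_unc = alpha2 + y2 * (E1 - E2) / eta  # unconstrained optimum
    return min(H, max(L, alpha2_unc))           # clip to the feasible box [L, H]

# Made-up numbers, purely for illustration.
print(smo_alpha2_step(alpha2=0.5, y2=1.0, E1=-0.3, E2=0.2,
                      K11=1.0, K22=1.0, K12=0.0, L=0.0, H=1.0))
```

The clipped value is then used to recompute α₁ and the threshold b, exactly as in the body of `fit`.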
], 435 | "source": [ 436 | "svm.score(X_test, y_test)" 437 | ] 438 | }, 439 | { 440 | "cell_type": "markdown", 441 | "metadata": {}, 442 | "source": [ 443 | "## sklearn.svm.SVC" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": 9, 449 | "metadata": {}, 450 | "outputs": [ 451 | { 452 | "data": { 453 | "text/plain": [ 454 | "SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n", 455 | " decision_function_shape='ovr', degree=3, gamma='auto', kernel='rbf',\n", 456 | " max_iter=-1, probability=False, random_state=None, shrinking=True,\n", 457 | " tol=0.001, verbose=False)" 458 | ] 459 | }, 460 | "execution_count": 9, 461 | "metadata": {}, 462 | "output_type": "execute_result" 463 | } 464 | ], 465 | "source": [ 466 | "from sklearn.svm import SVC\n", 467 | "clf = SVC()\n", 468 | "clf.fit(X_train, y_train)" 469 | ] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "execution_count": 10, 474 | "metadata": {}, 475 | "outputs": [ 476 | { 477 | "data": { 478 | "text/plain": [ 479 | "0.96" 480 | ] 481 | }, 482 | "execution_count": 10, 483 | "metadata": {}, 484 | "output_type": "execute_result" 485 | } 486 | ], 487 | "source": [ 488 | "clf.score(X_test, y_test)" 489 | ] 490 | }, 491 | { 492 | "cell_type": "markdown", 493 | "metadata": {}, 494 | "source": [ 495 | "### sklearn.svm.SVC\n", 496 | "\n", 497 | "*(C=1.0, kernel='rbf', degree=3, gamma='auto', coef0=0.0, shrinking=True, probability=False,tol=0.001, cache_size=200, class_weight=None, verbose=False, max_iter=-1, decision_function_shape=None,random_state=None)*\n", 498 | "\n", 499 | "参数:\n", 500 | "\n", 501 | "- C:C-SVC的惩罚参数C?默认值是1.0\n", 502 | "\n", 503 | "C越大,相当于惩罚松弛变量,希望松弛变量接近0,即对误分类的惩罚增大,趋向于对训练集全分对的情况,这样对训练集测试时准确率很高,但泛化能力弱。C值小,对误分类的惩罚减小,允许容错,将他们当成噪声点,泛化能力较强。\n", 504 | "\n", 505 | "- kernel :核函数,默认是rbf,可以是‘linear’, ‘poly’, ‘rbf’, ‘sigmoid’, ‘precomputed’ \n", 506 | " \n", 507 | " – 线性:u'v\n", 508 | " \n", 509 | " – 多项式:(gamma*u'*v + coef0)^degree\n", 510 | "\n", 511 | " – 
RBF function: exp(-gamma|u-v|^2)\n", 512 | "\n", 513 | "    – sigmoid: tanh(gamma*u'*v + coef0)\n", 514 | "\n", 515 | "\n", 516 | "- degree: degree of the 'poly' kernel polynomial, default 3; ignored by all other kernels.\n", 517 | "\n", 518 | "\n", 519 | "- gamma: kernel coefficient for 'rbf', 'poly' and 'sigmoid'. The default 'auto' uses 1/n_features\n", 520 | "\n", 521 | "\n", 522 | "- coef0: independent term of the kernel function. Only used by 'poly' and 'sigmoid'.\n", 523 | "\n", 524 | "\n", 525 | "- probability: whether to enable probability estimates. Default False\n", 526 | "\n", 527 | "\n", 528 | "- shrinking: whether to use the shrinking heuristic, default True\n", 529 | "\n", 530 | "\n", 531 | "- tol: tolerance for the stopping criterion, default 1e-3\n", 532 | "\n", 533 | "\n", 534 | "- cache_size: size of the kernel cache, default 200\n", 535 | "\n", 536 | "\n", 537 | "- class_weight: per-class weights, passed as a dict. Sets the parameter C of a given class to weight*C (the C of C-SVC)\n", 538 | "\n", 539 | "\n", 540 | "- verbose: enable verbose output\n", 541 | "\n", 542 | "\n", 543 | "- max_iter: hard limit on the number of iterations; -1 means no limit.\n", 544 | "\n", 545 | "\n", 546 | "- decision_function_shape: 'ovo', 'ovr' or None, default=None\n", 547 | "\n", 548 | "\n", 549 | "- random_state: int seed used when shuffling the data\n", 550 | "\n", 551 | "\n", 552 | "The parameters most worth tuning are C, kernel, degree, gamma and coef0." 553 | ] 554 | } 555 | ], 556 | "metadata": { 557 | "kernelspec": { 558 | "display_name": "Python 3", 559 | "language": "python", 560 | "name": "python3" 561 | }, 562 | "language_info": { 563 | "codemirror_mode": { 564 | "name": "ipython", 565 | "version": 3 566 | }, 567 | "file_extension": ".py", 568 | "mimetype": "text/x-python", 569 | "name": "python", 570 | "nbconvert_exporter": "python", 571 | "pygments_lexer": "ipython3", 572 | "version": "3.6.4" 573 | } 574 | }, 575 | "nbformat": 4, 576 | "nbformat_minor": 2 577 | } 578 | -------------------------------------------------------------------------------- /code/第8章 提升方法(AdaBoost)/Adaboost.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Original code author: https://github.com/wzyonggege/statistical-learning-method\n", 8 | "\n", 9 |
"Chinese annotations by: 机器学习初学者 (WeChat public account ID: ai-start-com)\n", 10 | "\n", 11 | "Environment: python 3.6\n", 12 | "\n", 13 | "All code has been tested and passes.\n", 14 | "![gongzhong](../gongzhong.jpg)" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "# Chapter 8: Boosting Methods" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": { 27 | "collapsed": true 28 | }, 29 | "source": [ 30 | "# Boost\n", 31 | "\n", 32 | "Bagging and boosting are the two main ways of building ensemble models, i.e. models composed of several base models; an ensemble usually predicts better than any single one of its base models.\n", 33 | "\n", 34 | "- Bagging: each base model is trained on a different dataset drawn from the full sample by random resampling; this resampling procedure is called bagging.\n", 35 | "\n", 36 | "- Boosting: each base model is trained on a reweighted dataset, with higher weights on the samples the previous base model misclassified, so that the new model focuses on those samples\n", 37 | "\n", 38 | "### AdaBoost\n", 39 | "\n", 40 | "AdaBoost is short for Adaptive Boosting, indicating that this boosting algorithm adapts as it goes.\n", 41 | "\n", 42 | "The algorithm proceeds as follows:\n", 43 | "\n", 44 | "1) Assign each training sample ($x_{1},x_{2},….,x_{N}$) a weight, with all initial weights $w_{1}$ equal to 1/N.\n", 45 | "\n", 46 | "2) Train on the weighted samples to obtain model $G_m$ (the first model is $G_1$).\n", 47 | "\n", 48 | "3) Compute the weighted error rate of $G_m$: $e_m=\\sum_{i=1}^Nw_iI(y_i\\not= G_m(x_i))$\n", 49 | "\n", 50 | "4) Compute the coefficient of $G_m$: $\\alpha_m=0.5\\log[(1-e_m)/e_m]$\n", 51 | "\n", 52 | "5) Update the weight vector $w_{m+1}$ from the error rate $e_m$ and the current weight vector $w_m$.\n", 53 | "\n", 54 | "6) Compute the error rate of the combined model $f(x)=\\sum_{m=1}^M\\alpha_mG_m(x)$.\n", 55 | "\n", 56 | "7) Stop when the combined model's error rate falls below a threshold or the iteration count reaches its limit; otherwise, return to step 2)" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 13, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "import numpy as np\n", 66 | "import pandas as pd\n", 67 | "from sklearn.datasets import load_iris\n", 68 | "from sklearn.model_selection import train_test_split\n", 69 | "import matplotlib.pyplot as plt\n", 70 | "%matplotlib inline" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 2, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "# data\n", 80 | "def create_data():\n", 81 | "    iris = load_iris()\n", 82 | "    df = pd.DataFrame(iris.data, columns=iris.feature_names)\n", 83 | "    df['label'] = iris.target\n", 84 | "    df.columns = ['sepal length', 'sepal
width', 'petal length', 'petal width', 'label']\n", 85 | " data = np.array(df.iloc[:100, [0, 1, -1]])\n", 86 | " for i in range(len(data)):\n", 87 | " if data[i,-1] == 0:\n", 88 | " data[i,-1] = -1\n", 89 | " # print(data)\n", 90 | " return data[:,:2], data[:,-1]" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 3, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "X, y = create_data()\n", 100 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 4, 106 | "metadata": {}, 107 | "outputs": [ 108 | { 109 | "data": { 110 | "text/plain": [ 111 | "" 112 | ] 113 | }, 114 | "execution_count": 4, 115 | "metadata": {}, 116 | "output_type": "execute_result" 117 | }, 118 | { 119 | "data": { 120 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAGihJREFUeJzt3X+MXWWdx/H3d4dZOiowoYwrzJQtP0yjQNfCCJImxAV3q7UWgiyU4I8qC7sGFwwuRgxBbUzAkOCPJdEUyALCFrsVS2H5sQhLVAI1U8B2bSWCoJ2BXYZii6wFyvDdP+6ddubOnbn3ufeeuc/z3M8raTr33Ken3+cc/XJ7zuc819wdERHJy5+1uwAREWk9NXcRkQypuYuIZEjNXUQkQ2ruIiIZUnMXEcmQmruISIbU3EVEMqTmLiKSof3qHWhmXcAQMOLuyyreWwlcA4yUN13n7jfMtL9DDjnE58+fH1SsiEin27Rp00vu3ldrXN3NHbgE2AYcOM37P3T3z9e7s/nz5zM0NBTw14uIiJn9rp5xdV2WMbMB4KPAjJ/GRUQkDvVec/828CXgrRnGfNzMNpvZOjObV22AmV1oZkNmNjQ6Ohpaq4iI1KlmczezZcCL7r5phmF3AfPdfSHwE+DmaoPcfbW7D7r7YF9fzUtGIiLSoHquuS8GlpvZUmAOcKCZ3erunxgf4O47Joy/Hvhma8sUEWmdPXv2MDw8zGuvvdbuUqY1Z84cBgYG6O7ubujP12zu7n45cDmAmX0Q+OeJjb28/VB3f6H8cjmlG68iIlEaHh7mgAMOYP78+ZhZu8uZwt3ZsWMHw8PDHHHEEQ3to+Gcu5mtMrPl5ZcXm9mvzOyXwMXAykb3KyJStNdee425c+dG2dgBzIy5c+c29S+LkCgk7v4w8HD55ysnbN/76V4kN+ufGOGa+5/i+Z27Oay3h8uWLOCMRf3tLkuaFGtjH9dsfUHNXaTTrH9ihMvv2MLuPWMAjOzczeV3bAFQg5eoafkBkRlcc/9Texv7uN17xrjm/qfaVJHk4r777mPBggUcffTRXH311S3fv5q7yAye37k7aLtIPcbGxrjooou499572bp1K2vWrGHr1q0t/Tt0WUZkBof19jBSpZEf1tvT
i4GOcVKUcxcRSUi9OXddlhFJwea18K1j4Wu9pd83r01j39I2uiwjErsi8+XKrmdLn9xFYldkvlzZ9WypuYvErsh8ubLr2VJzF4ldkflyZdezpeYuErsi8+XKrmdLzV0kdkWu3a/vBciWcu4iIglRzl1EpIOpuYuIZEjNXUQkQ2ruIiIZUnMXEcmQmruISIbU3EVEMqTmLiKSoZrN3czmmdl/mdk2M/uVmV1SZYyZ2XfN7Gkz22xmxxdTrjRF63aLdIx61nN/E/iiuz9uZgcAm8zsAXffOmHMR4B3l3+dBHyv/LvEQut2i3SUmp/c3f0Fd3+8/PMfgW1A5ReLng7c4iWPAb1mdmjLq5XGad1ukY4SdM3dzOYDi4CNFW/1A9snvB5m6n8AMLMLzWzIzIZGR0fDKpXmaN1ukY5Sd3M3s3cAPwK+4O6vVL5d5Y9MWZHM3Ve7+6C7D/b19YVVKs3Rut0iHaWu5m5m3ZQa+23ufkeVIcPAvAmvB4Dnmy9PWkbrdot0lHrSMgbcCGxz92unGbYB+FQ5NfMBYJe7v9DCOqVZWrdbpKPUk5ZZDHwS2GJmT5a3fQU4HMDdvw/cAywFngb+BHym9aVK0xaerWYu0iFqNnd3/znVr6lPHOPARa0qSkREmqMnVEVEMqTmLiKSITV3EZEMqbmLiGRIzV1EJENq7iIiGVJzFxHJkJUi6m34i81Ggd+15S+v7RDgpXYXUSDNL105zw00v3r8pbvXXJyrbc09ZmY25O6D7a6jKJpfunKeG2h+raTLMiIiGVJzFxHJkJp7davbXUDBNL905Tw30PxaRtfcRUQypE/uIiIZ6ujmbmZdZvaEmd1d5b2VZjZqZk+Wf/19O2pshpk9Z2ZbyvUPVXnfzOy7Zva0mW02s+PbUWcj6pjbB81s14Tzl9RXTplZr5mtM7Nfm9k2Mzu54v1kzx3UNb9kz5+ZLZhQ95Nm9oqZfaFiTOHnr54v68jZJcA24MBp3v+hu39+Fuspwl+7+3S52o8A7y7/Ogn4Xvn3VMw0N4CfufuyWaumtb4D3OfuZ5nZnwNvq3g/9XNXa36Q6Plz96eA90HpAyQwAvy4Yljh569jP7mb2QDwUeCGdtfSRqcDt3jJY0CvmR3a7qI6nZkdCJxC6estcfc33H1nxbBkz12d88vFacAz7l75wGbh569jmzvwbeBLwFszjPl4+Z9M68xs3gzjYuXAf5rZJjO7sMr7/cD2Ca+Hy9tSUGtuACeb2S/N7F4zO2Y2i2vSkcAo8K/ly4Y3mNnbK8akfO7qmR+ke/4mWgGsqbK98PPXkc3dzJYBL7r7phmG3QXMd/eFwE+Am2eluNZa7O7HU/on4EVmdkrF+9W+PjGV+FStuT1O6THtvwL+BVg/2wU2YT/geOB77r4I+D/gyxVjUj539cwv5fMHQPly03Lg36u9XWVbS89fRzZ3Sl/6vdzMngNuB041s1snDnD3He7+evnl9cAJs1ti89z9+fLvL1K65ndixZBhYOK/SAaA52enuubUmpu7v+Lur5Z/vgfoNrNDZr3QxgwDw+6+sfx6HaVmWDkmyXNHHfNL/PyN+wjwuLv/b5X3Cj9/Hdnc3f1ydx9w9/mU/tn0kLt/YuKYiutfyyndeE2Gmb3dzA4Y/xn4W+C/K4ZtAD5VvnP/AWCXu78wy6UGq2duZvYuM7PyzydS+t/6jtmutRHu/j/AdjNbUN50GrC1YliS5w7qm1/K52+Cc6l+SQZm4fx1elpmEjNbBQy5+wbgYjNbDrwJvAysbGdtDfgL4Mfl/3/sB/ybu99nZv8I4O7fB+4BlgJPA38CPtOmWkPVM7ezgM+Z2ZvAbmCFp/XE3j8Bt5X/af9b4DOZnLtxteaX9Pkzs7cBfwP8w4Rts3r+9ISqiEiGOvKyjIhI7tTcRUQypOYuIpIhNXcRkQypuYuIZEjNXUQkQ2ruIiIZUnMXEcnQ/wPmMFqpaGCFHwAAAABJRU5ErkJggg==\n", 
121 | "text/plain": [ 122 | "
" 123 | ] 124 | }, 125 | "metadata": {}, 126 | "output_type": "display_data" 127 | } 128 | ], 129 | "source": [ 130 | "plt.scatter(X[:50,0],X[:50,1], label='0')\n", 131 | "plt.scatter(X[50:,0],X[50:,1], label='1')\n", 132 | "plt.legend()" 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "metadata": {}, 138 | "source": [ 139 | "----\n", 140 | "\n", 141 | "### AdaBoost in Python" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 5, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "class AdaBoost:\n", 151 | " def __init__(self, n_estimators=50, learning_rate=1.0):\n", 152 | " self.clf_num = n_estimators\n", 153 | " self.learning_rate = learning_rate\n", 154 | " \n", 155 | " def init_args(self, datasets, labels):\n", 156 | " \n", 157 | " self.X = datasets\n", 158 | " self.Y = labels\n", 159 | " self.M, self.N = datasets.shape\n", 160 | " \n", 161 | " # 弱分类器数目和集合\n", 162 | " self.clf_sets = []\n", 163 | " \n", 164 | " # 初始化weights\n", 165 | " self.weights = [1.0/self.M]*self.M\n", 166 | " \n", 167 | " # G(x)系数 alpha\n", 168 | " self.alpha = []\n", 169 | " \n", 170 | " def _G(self, features, labels, weights):\n", 171 | " m = len(features)\n", 172 | " error = 100000.0 # 无穷大\n", 173 | " best_v = 0.0\n", 174 | " # 单维features\n", 175 | " features_min = min(features)\n", 176 | " features_max = max(features)\n", 177 | " n_step = (features_max - features_min + self.learning_rate) // self.learning_rate\n", 178 | " # print('n_step:{}'.format(n_step))\n", 179 | " direct, compare_array = None, None\n", 180 | " for i in range(1, int(n_step)):\n", 181 | " v = features_min + self.learning_rate * i\n", 182 | " \n", 183 | " if v not in features:\n", 184 | " # 误分类计算\n", 185 | " compare_array_positive = np.array([1 if features[k] > v else -1 for k in range(m)])\n", 186 | " weight_error_positive = sum([weights[k] for k in range(m) if compare_array_positive[k] != labels[k]])\n", 187 | " \n", 188 | " compare_array_nagetive = 
np.array([-1 if features[k] > v else 1 for k in range(m)])\n", 189 | " weight_error_nagetive = sum([weights[k] for k in range(m) if compare_array_nagetive[k] != labels[k]])\n", 190 | "\n", 191 | " if weight_error_positive < weight_error_nagetive:\n", 192 | " weight_error = weight_error_positive\n", 193 | " _compare_array = compare_array_positive\n", 194 | " direct = 'positive'\n", 195 | " else:\n", 196 | " weight_error = weight_error_nagetive\n", 197 | " _compare_array = compare_array_nagetive\n", 198 | " direct = 'nagetive'\n", 199 | " \n", 200 | " # print('v:{} error:{}'.format(v, weight_error))\n", 201 | " if weight_error < error:\n", 202 | " error = weight_error\n", 203 | " compare_array = _compare_array\n", 204 | " best_v = v\n", 205 | " return best_v, direct, error, compare_array\n", 206 | " \n", 207 | " # 计算alpha\n", 208 | " def _alpha(self, error):\n", 209 | " return 0.5 * np.log((1-error)/error)\n", 210 | " \n", 211 | " # 规范化因子\n", 212 | " def _Z(self, weights, a, clf):\n", 213 | " return sum([weights[i]*np.exp(-1*a*self.Y[i]*clf[i]) for i in range(self.M)])\n", 214 | " \n", 215 | " # 权值更新\n", 216 | " def _w(self, a, clf, Z):\n", 217 | " for i in range(self.M):\n", 218 | " self.weights[i] = self.weights[i]*np.exp(-1*a*self.Y[i]*clf[i])/ Z\n", 219 | " \n", 220 | " # G(x)的线性组合\n", 221 | " def _f(self, alpha, clf_sets):\n", 222 | " pass\n", 223 | " \n", 224 | " def G(self, x, v, direct):\n", 225 | " if direct == 'positive':\n", 226 | " return 1 if x > v else -1 \n", 227 | " else:\n", 228 | " return -1 if x > v else 1 \n", 229 | " \n", 230 | " def fit(self, X, y):\n", 231 | " self.init_args(X, y)\n", 232 | " \n", 233 | " for epoch in range(self.clf_num):\n", 234 | " best_clf_error, best_v, clf_result = 100000, None, None\n", 235 | " # 根据特征维度, 选择误差最小的\n", 236 | " for j in range(self.N):\n", 237 | " features = self.X[:, j]\n", 238 | " # 分类阈值,分类误差,分类结果\n", 239 | " v, direct, error, compare_array = self._G(features, self.Y, self.weights)\n", 240 | " \n", 241 | " 
if error < best_clf_error:\n", 242 | " best_clf_error = error\n", 243 | " best_v = v\n", 244 | " final_direct = direct\n", 245 | " clf_result = compare_array\n", 246 | " axis = j\n", 247 | " \n", 248 | " # print('epoch:{}/{} feature:{} error:{} v:{}'.format(epoch, self.clf_num, j, error, best_v))\n", 249 | " if best_clf_error == 0:\n", 250 | " break\n", 251 | " \n", 252 | " # 计算G(x)系数a\n", 253 | " a = self._alpha(best_clf_error)\n", 254 | " self.alpha.append(a)\n", 255 | " # 记录分类器\n", 256 | " self.clf_sets.append((axis, best_v, final_direct))\n", 257 | " # 规范化因子\n", 258 | " Z = self._Z(self.weights, a, clf_result)\n", 259 | " # 权值更新\n", 260 | " self._w(a, clf_result, Z)\n", 261 | " \n", 262 | "# print('classifier:{}/{} error:{:.3f} v:{} direct:{} a:{:.5f}'.format(epoch+1, self.clf_num, error, best_v, final_direct, a))\n", 263 | "# print('weight:{}'.format(self.weights))\n", 264 | "# print('\\n')\n", 265 | " \n", 266 | " def predict(self, feature):\n", 267 | " result = 0.0\n", 268 | " for i in range(len(self.clf_sets)):\n", 269 | " axis, clf_v, direct = self.clf_sets[i]\n", 270 | " f_input = feature[axis]\n", 271 | " result += self.alpha[i] * self.G(f_input, clf_v, direct)\n", 272 | " # sign\n", 273 | " return 1 if result > 0 else -1\n", 274 | " \n", 275 | " def score(self, X_test, y_test):\n", 276 | " right_count = 0\n", 277 | " for i in range(len(X_test)):\n", 278 | " feature = X_test[i]\n", 279 | " if self.predict(feature) == y_test[i]:\n", 280 | " right_count += 1\n", 281 | " \n", 282 | " return right_count / len(X_test)" 283 | ] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "metadata": {}, 288 | "source": [ 289 | "### 例8.1" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 6, 295 | "metadata": {}, 296 | "outputs": [], 297 | "source": [ 298 | "X = np.arange(10).reshape(10, 1)\n", 299 | "y = np.array([1, 1, 1, -1, -1, -1, 1, 1, 1, -1])" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": 7, 305 | 
"metadata": {}, 306 | "outputs": [], 307 | "source": [ 308 | "clf = AdaBoost(n_estimators=3, learning_rate=0.5)\n", 309 | "clf.fit(X, y)" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": 8, 315 | "metadata": {}, 316 | "outputs": [], 317 | "source": [ 318 | "X, y = create_data()\n", 319 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": 9, 325 | "metadata": {}, 326 | "outputs": [ 327 | { 328 | "data": { 329 | "text/plain": [ 330 | "0.6363636363636364" 331 | ] 332 | }, 333 | "execution_count": 9, 334 | "metadata": {}, 335 | "output_type": "execute_result" 336 | } 337 | ], 338 | "source": [ 339 | "clf = AdaBoost(n_estimators=10, learning_rate=0.2)\n", 340 | "clf.fit(X_train, y_train)\n", 341 | "clf.score(X_test, y_test)" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": 10, 347 | "metadata": {}, 348 | "outputs": [ 349 | { 350 | "name": "stdout", 351 | "output_type": "stream", 352 | "text": [ 353 | "average score:65.000%\n" 354 | ] 355 | } 356 | ], 357 | "source": [ 358 | "# 100次结果\n", 359 | "result = []\n", 360 | "for i in range(1, 101):\n", 361 | " X, y = create_data()\n", 362 | " X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)\n", 363 | " clf = AdaBoost(n_estimators=100, learning_rate=0.2)\n", 364 | " clf.fit(X_train, y_train)\n", 365 | " r = clf.score(X_test, y_test)\n", 366 | " # print('{}/100 score:{}'.format(i, r))\n", 367 | " result.append(r)\n", 368 | "\n", 369 | "print('average score:{:.3f}%'.format(sum(result)))" 370 | ] 371 | }, 372 | { 373 | "cell_type": "markdown", 374 | "metadata": {}, 375 | "source": [ 376 | "-----\n", 377 | "# sklearn.ensemble.AdaBoostClassifier\n", 378 | "\n", 379 | "- 
algorithm:这个参数只有AdaBoostClassifier有。主要原因是scikit-learn实现了两种Adaboost分类算法,SAMME和SAMME.R。两者的主要区别是弱学习器权重的度量,SAMME使用了和我们的原理篇里二元分类Adaboost算法的扩展,即用对样本集分类效果作为弱学习器权重,而SAMME.R使用了对样本集分类的预测概率大小来作为弱学习器权重。由于SAMME.R使用了概率度量的连续值,迭代一般比SAMME快,因此AdaBoostClassifier的默认算法algorithm的值也是SAMME.R。我们一般使用默认的SAMME.R就够了,但是要注意的是使用了SAMME.R, 则弱分类学习器参数base_estimator必须限制使用支持概率预测的分类器。SAMME算法则没有这个限制。\n", 380 | "\n", 381 | "- n_estimators: AdaBoostClassifier和AdaBoostRegressor都有,就是我们的弱学习器的最大迭代次数,或者说最大的弱学习器的个数。一般来说n_estimators太小,容易欠拟合,n_estimators太大,又容易过拟合,一般选择一个适中的数值。默认是50。在实际调参的过程中,我们常常将n_estimators和下面介绍的参数learning_rate一起考虑。\n", 382 | "\n", 383 | "- learning_rate: AdaBoostClassifier和AdaBoostRegressor都有,即每个弱学习器的权重缩减系数ν\n", 384 | "\n", 385 | "- base_estimator:AdaBoostClassifier和AdaBoostRegressor都有,即我们的弱分类学习器或者弱回归学习器。理论上可以选择任何一个分类或者回归学习器,不过需要支持样本权重。我们常用的一般是CART决策树或者神经网络MLP。" 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": 11, 391 | "metadata": {}, 392 | "outputs": [ 393 | { 394 | "name": "stderr", 395 | "output_type": "stream", 396 | "text": [ 397 | "c:\\programdata\\anaconda3\\envs\\tf\\lib\\site-packages\\sklearn\\ensemble\\weight_boosting.py:29: DeprecationWarning: numpy.core.umath_tests is an internal NumPy module and should not be imported. 
It will be removed in a future NumPy release.\n", 398 | " from numpy.core.umath_tests import inner1d\n" 399 | ] 400 | }, 401 | { 402 | "data": { 403 | "text/plain": [ 404 | "AdaBoostClassifier(algorithm='SAMME.R', base_estimator=None,\n", 405 | " learning_rate=0.5, n_estimators=100, random_state=None)" 406 | ] 407 | }, 408 | "execution_count": 11, 409 | "metadata": {}, 410 | "output_type": "execute_result" 411 | } 412 | ], 413 | "source": [ 414 | "from sklearn.ensemble import AdaBoostClassifier\n", 415 | "clf = AdaBoostClassifier(n_estimators=100, learning_rate=0.5)\n", 416 | "clf.fit(X_train, y_train)" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 12, 422 | "metadata": {}, 423 | "outputs": [ 424 | { 425 | "data": { 426 | "text/plain": [ 427 | "0.9393939393939394" 428 | ] 429 | }, 430 | "execution_count": 12, 431 | "metadata": {}, 432 | "output_type": "execute_result" 433 | } 434 | ], 435 | "source": [ 436 | "clf.score(X_test, y_test)" 437 | ] 438 | } 439 | ], 440 | "metadata": { 441 | "kernelspec": { 442 | "display_name": "Python 3", 443 | "language": "python", 444 | "name": "python3" 445 | }, 446 | "language_info": { 447 | "codemirror_mode": { 448 | "name": "ipython", 449 | "version": 3 450 | }, 451 | "file_extension": ".py", 452 | "mimetype": "text/x-python", 453 | "name": "python", 454 | "nbconvert_exporter": "python", 455 | "pygments_lexer": "ipython3", 456 | "version": "3.6.2" 457 | } 458 | }, 459 | "nbformat": 4, 460 | "nbformat_minor": 2 461 | } 462 | -------------------------------------------------------------------------------- /code/第9章 EM算法及其推广(EM)/em.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "原文代码作者:https://github.com/wzyonggege/statistical-learning-method\n", 8 | "\n", 9 | "中文注释制作:机器学习初学者(微信公众号:ID:ai-start-com)\n", 10 | "\n", 11 | "配置环境:python 3.6\n", 12 | "\n", 13 | 
"代码全部测试通过。\n", 14 | "![gongzhong](../gongzhong.jpg)" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "# 第9章 EM算法及其推广" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "# Expectation Maximization algorithm\n", 29 | "\n", 30 | "### Maximum likelihood function\n", 31 | "\n", 32 | "[likelihood & maximum likelihood](http://fangs.in/post/thinkstats/likelihood/)\n", 33 | "\n", 34 | "> 在统计学中,似然函数(likelihood function,通常简写为likelihood,似然)是一个非常重要的内容,在非正式场合似然和概率(Probability)几乎是一对同义词,但是在统计学中似然和概率却是两个不同的概念。概率是在特定环境下某件事情发生的可能性,也就是结果没有产生之前依据环境所对应的参数来预测某件事情发生的可能性,比如抛硬币,抛之前我们不知道最后是哪一面朝上,但是根据硬币的性质我们可以推测任何一面朝上的可能性均为50%,这个概率只有在抛硬币之前才是有意义的,抛完硬币后的结果便是确定的;而似然刚好相反,是在确定的结果下去推测产生这个结果的可能环境(参数),还是抛硬币的例子,假设我们随机抛掷一枚硬币1,000次,结果500次人头朝上,500次数字朝上(实际情况一般不会这么理想,这里只是举个例子),我们很容易判断这是一枚标准的硬币,两面朝上的概率均为50%,这个过程就是我们运用出现的结果来判断这个事情本身的性质(参数),也就是似然。" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "$$P(Y|\\theta) = \\prod_{j=1}^n[\\pi p^{y_j}(1-p)^{1-y_j}+(1-\\pi) q^{y_j}(1-q)^{1-y_j}]$$\n", 42 | "\n", 43 | "### E step:\n", 44 | "\n", 45 | "$$\\mu_j^{i+1}=\\frac{\\pi (p^i)^{y_j}(1-(p^i))^{1-y_j}}{\\pi (p^i)^{y_j}(1-(p^i))^{1-y_j}+(1-\\pi) (q^i)^{y_j}(1-(q^i))^{1-y_j}}$$" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 1, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "import numpy as np\n", 55 | "import math" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 2, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "pro_A, pro_B, pro_C = 0.5, 0.5, 0.5\n", 65 | "\n", 66 | "def pmf(i, pro_A, pro_B, pro_C):\n", 67 | "    pro_1 = pro_A * math.pow(pro_B, data[i]) * math.pow((1-pro_B), 1-data[i])\n", 68 | "    # 注意第二项的系数是 (1 - pro_A),与上面的似然公式保持一致\n", 69 | "    pro_2 = (1 - pro_A) * math.pow(pro_C, data[i]) * math.pow((1-pro_C), 1-data[i])\n", 70 | "    return pro_1 / (pro_1 + pro_2)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "### M 
step:\n", 78 | "\n", 79 | "$$\\pi^{i+1}=\\frac{1}{n}\\sum_{j=1}^n\\mu^{i+1}_j$$\n", 80 | "\n", 81 | "$$p^{i+1}=\\frac{\\sum_{j=1}^n\\mu^{i+1}_jy_j}{\\sum_{j=1}^n\\mu^{i+1}_j}$$\n", 82 | "\n", 83 | "$$q^{i+1}=\\frac{\\sum_{j=1}^n(1-\\mu^{i+1}_j)y_j}{\\sum_{j=1}^n(1-\\mu^{i+1}_j)}$$" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 3, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "class EM:\n", 93 | "    def __init__(self, prob):\n", 94 | "        self.pro_A, self.pro_B, self.pro_C = prob\n", 95 | "    \n", 96 | "    # e_step\n", 97 | "    def pmf(self, i):\n", 98 | "        pro_1 = self.pro_A * math.pow(self.pro_B, data[i]) * math.pow((1-self.pro_B), 1-data[i])\n", 99 | "        pro_2 = (1 - self.pro_A) * math.pow(self.pro_C, data[i]) * math.pow((1-self.pro_C), 1-data[i])\n", 100 | "        return pro_1 / (pro_1 + pro_2)\n", 101 | "    \n", 102 | "    # m_step\n", 103 | "    def fit(self, data):\n", 104 | "        count = len(data)\n", 105 | "        print('init prob:{}, {}, {}'.format(self.pro_A, self.pro_B, self.pro_C))\n", 106 | "        for d in range(count):\n", 107 | "            _ = yield\n", 108 | "            _pmf = [self.pmf(k) for k in range(count)]\n", 109 | "            pro_A = 1/ count * sum(_pmf)\n", 110 | "            pro_B = sum([_pmf[k]*data[k] for k in range(count)]) / sum([_pmf[k] for k in range(count)])\n", 111 | "            pro_C = sum([(1-_pmf[k])*data[k] for k in range(count)]) / sum([(1-_pmf[k]) for k in range(count)])\n", 112 | "            print('{}/{} pro_a:{:.3f}, pro_b:{:.3f}, pro_c:{:.3f}'.format(d+1, count, pro_A, pro_B, pro_C))\n", 113 | "            self.pro_A = pro_A\n", 114 | "            self.pro_B = pro_B\n", 115 | "            self.pro_C = pro_C\n", 116 | "    " 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 4, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "data=[1,1,0,1,0,0,1,0,1,1]" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 5, 131 | "metadata": {}, 132 | "outputs": [ 133 | { 134 | "name": "stdout", 135 | "output_type": "stream", 136 | "text": [ 137 | "init prob:0.5, 0.5, 
0.5\n" 137 | ] 138 | } 139 | ], 140 | "source": [ 141 | "em = EM(prob=[0.5, 0.5, 0.5])\n", 142 | "f = em.fit(data)\n", 143 | "next(f)" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 6, 149 | "metadata": {}, 150 | "outputs": [ 151 | { 152 | "name": "stdout", 153 | "output_type": "stream", 154 | "text": [ 155 | "1/10 pro_a:0.500, pro_b:0.600, pro_c:0.600\n" 156 | ] 157 | } 158 | ], 159 | "source": [ 160 | "# 第一次迭代\n", 161 | "f.send(1)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 7, 167 | "metadata": {}, 168 | "outputs": [ 169 | { 170 | "name": "stdout", 171 | "output_type": "stream", 172 | "text": [ 173 | "2/10 pro_a:0.500, pro_b:0.600, pro_c:0.600\n" 174 | ] 175 | } 176 | ], 177 | "source": [ 178 | "# 第二次\n", 179 | "f.send(2)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 8, 185 | "metadata": {}, 186 | "outputs": [ 187 | { 188 | "name": "stdout", 189 | "output_type": "stream", 190 | "text": [ 191 | "init prob:0.4, 0.6, 0.7\n" 192 | ] 193 | } 194 | ], 195 | "source": [ 196 | "em = EM(prob=[0.4, 0.6, 0.7])\n", 197 | "f2 = em.fit(data)\n", 198 | "next(f2)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 9, 204 | "metadata": {}, 205 | "outputs": [ 206 | { 207 | "name": "stdout", 208 | "output_type": "stream", 209 | "text": [ 210 | "1/10 pro_a:0.406, pro_b:0.537, pro_c:0.643\n" 211 | ] 212 | } 213 | ], 214 | "source": [ 215 | "f2.send(1)" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 10, 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "name": "stdout", 225 | "output_type": "stream", 226 | "text": [ 227 | "2/10 pro_a:0.406, pro_b:0.537, pro_c:0.643\n" 228 | ] 229 | } 230 | ], 231 | "source": [ 232 | "f2.send(2)" 233 | ] 234 | } 235 | ], 236 | "metadata": { 237 | "kernelspec": { 238 | "display_name": "Python 3", 239 | "language": "python", 240 | "name": "python3" 241 | }, 242 | "language_info": { 243 | 
"codemirror_mode": { 244 | "name": "ipython", 245 | "version": 3 246 | }, 247 | "file_extension": ".py", 248 | "mimetype": "text/x-python", 249 | "name": "python", 250 | "nbconvert_exporter": "python", 251 | "pygments_lexer": "ipython3", 252 | "version": "3.6.2" 253 | } 254 | }, 255 | "nbformat": 4, 256 | "nbformat_minor": 2 257 | } 258 | -------------------------------------------------------------------------------- /images/1543246677825.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingchaoZhu/lihang-code/3b833f78fd99ef6fb60b651708c61792e56c61bc/images/1543246677825.png -------------------------------------------------------------------------------- /images/gongzhong.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingchaoZhu/lihang-code/3b833f78fd99ef6fb60b651708c61792e56c61bc/images/gongzhong.png -------------------------------------------------------------------------------- /ppt/readme.md: -------------------------------------------------------------------------------- 1 | **《统计学习方法》简介** 2 | 3 | **《统计学习方法》**,作者李航,本书全面系统地介绍了统计学习的主要方法,特别是监督学习方法,包括感知机、k近邻法、朴素贝叶斯法、决策树、逻辑斯谛回归与支持向量机、提升方法、EM算法、隐马尔可夫模型和条件随机场等。除第1章概论和最后一章总结外,每章介绍一种方法。叙述从具体问题或实例入手,由浅入深,阐明思路,给出必要的数学推导,便于读者掌握统计学习方法的实质,学会运用。 4 | 5 | **目录:** 6 | 7 | 第1章 统计学习方法概论 8 | 9 | 第2章 感知机 10 | 11 | 第3章 k近邻法 12 | 13 | 第4章 朴素贝叶斯 14 | 15 | 第5章 决策树 16 | 17 | 第6章 逻辑斯谛回归 18 | 19 | 第7章 支持向量机 20 | 21 | 第8章 提升方法 22 | 23 | 第9章 EM算法及其推广 24 | 25 | 第10章 隐马尔可夫模型 26 | 27 | 第11章 条件随机场 28 | 29 | 第12章 统计学习方法总结 30 | 31 | 32 | 33 | **《统计学习方法》课件** 34 | 35 | 作者袁春: 清华大学深圳研究生院,提供了全书12章的PPT课件。 36 | 37 | 整理:机器学习初学者 (微信公众号,ID:ai-start-com) 38 | -------------------------------------------------------------------------------- /ppt/第10章 隐马尔科夫模型.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MingchaoZhu/lihang-code/3b833f78fd99ef6fb60b651708c61792e56c61bc/ppt/第10章 隐马尔科夫模型.pdf -------------------------------------------------------------------------------- /ppt/第11章 条件随机场.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingchaoZhu/lihang-code/3b833f78fd99ef6fb60b651708c61792e56c61bc/ppt/第11章 条件随机场.pdf -------------------------------------------------------------------------------- /ppt/第12章 统计学习方法总结.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingchaoZhu/lihang-code/3b833f78fd99ef6fb60b651708c61792e56c61bc/ppt/第12章 统计学习方法总结.pdf -------------------------------------------------------------------------------- /ppt/第1章 统计学习方法概论.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingchaoZhu/lihang-code/3b833f78fd99ef6fb60b651708c61792e56c61bc/ppt/第1章 统计学习方法概论.pdf -------------------------------------------------------------------------------- /ppt/第2章 感知机.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingchaoZhu/lihang-code/3b833f78fd99ef6fb60b651708c61792e56c61bc/ppt/第2章 感知机.pdf -------------------------------------------------------------------------------- /ppt/第3章 k 近邻法.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingchaoZhu/lihang-code/3b833f78fd99ef6fb60b651708c61792e56c61bc/ppt/第3章 k 近邻法.pdf -------------------------------------------------------------------------------- /ppt/第4章 朴素贝叶斯法.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingchaoZhu/lihang-code/3b833f78fd99ef6fb60b651708c61792e56c61bc/ppt/第4章 朴素贝叶斯法.pdf 
-------------------------------------------------------------------------------- /ppt/第5章 决策树-2016-ID3CART.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingchaoZhu/lihang-code/3b833f78fd99ef6fb60b651708c61792e56c61bc/ppt/第5章 决策树-2016-ID3CART.pdf -------------------------------------------------------------------------------- /ppt/第6章 Logistic回归.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingchaoZhu/lihang-code/3b833f78fd99ef6fb60b651708c61792e56c61bc/ppt/第6章 Logistic回归.pdf -------------------------------------------------------------------------------- /ppt/第7章 支持向量机.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingchaoZhu/lihang-code/3b833f78fd99ef6fb60b651708c61792e56c61bc/ppt/第7章 支持向量机.pdf -------------------------------------------------------------------------------- /ppt/第8章 提升方法.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingchaoZhu/lihang-code/3b833f78fd99ef6fb60b651708c61792e56c61bc/ppt/第8章 提升方法.pdf -------------------------------------------------------------------------------- /ppt/第9章 EM算法及其推广.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingchaoZhu/lihang-code/3b833f78fd99ef6fb60b651708c61792e56c61bc/ppt/第9章 EM算法及其推广.pdf -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | 《**统计学习方法**》可以说是机器学习的入门宝典,许多机器学习培训班、互联网企业的面试、笔试题目,很多都参考这本书。本站根据网上资料用**Python**复现了课程内容,并提供本书的代码实现、课件下载。 2 | 3 | 
**《统计学习方法》**,作者李航,本书全面系统地介绍了统计学习的主要方法,特别是监督学习方法,包括感知机、k近邻法、朴素贝叶斯法、决策树、逻辑斯谛回归与支持向量机、提升方法、EM算法、隐马尔可夫模型和条件随机场等。除第1章概论和最后一章总结外,每章介绍一种方法。叙述从具体问题或实例入手,由浅入深,阐明思路,给出必要的数学推导,便于读者掌握统计学习方法的实质,学会运用。 4 | 5 | **目录:** 6 | 7 | 第1章 统计学习方法概论 8 | 9 | 第2章 感知机 10 | 11 | 第3章 k近邻法 12 | 13 | 第4章 朴素贝叶斯 14 | 15 | 第5章 决策树 16 | 17 | 第6章 逻辑斯谛回归 18 | 19 | 第7章 支持向量机 20 | 21 | 第8章 提升方法 22 | 23 | 第9章 EM算法及其推广 24 | 25 | 第10章 隐马尔可夫模型 26 | 27 | 第11章 条件随机场 28 | 29 | 第12章 统计学习方法总结 30 | 31 | **1.统计学习方法的代码实现(code文件夹)** 32 | 33 | **《统计学习方法》**官方没有提供代码实现,但是网上有许多机器学习爱好者尝试对每一章的内容进行了代码实现。 本站在github网站搜集了一些代码进行整理,并作了一定的修改,使用**Python3.6**实现了第1-11章的课程代码。 34 | 35 | **代码目录与截图:** 36 | 37 | ![1543246677825](images/1543246677825.png) 38 | 39 | **2.《统计学习方法》课件(ppt文件夹)** 40 | 41 | 作者袁春: 清华大学深圳研究生院,提供了全书12章的PPT课件。 42 | 43 | **参考** 44 | 45 | [https://github.com/wzyonggege/statistical-learning-method](http://link.zhihu.com/?target=https%3A//github.com/wzyonggege/statistical-learning-method) 46 | 47 | [https://github.com/WenDesi/lihang_book_algorithm](http://link.zhihu.com/?target=https%3A//github.com/WenDesi/lihang_book_algorithm) 48 | 49 | [https://blog.csdn.net/tudaodiaozhale](http://link.zhihu.com/?target=https%3A//blog.csdn.net/tudaodiaozhale) 50 | 51 | 代码整理和修改:机器学习初学者 (微信公众号,ID:ai-start-com),qq群:554839127。 52 | 53 | ![gongzhong](/images/gongzhong.png) 54 | 55 | [我的知乎](https://www.zhihu.com/people/fengdu78) --------------------------------------------------------------------------------