├── AdaBoost └── Adaboost.ipynb ├── DecisonTree ├── DT.ipynb ├── dt.py └── mytree.pdf ├── EM └── em.ipynb ├── KNearestNeighbors └── KNN.ipynb ├── LeastSquaresMethod ├── README.md └── least_sqaure_method.ipynb ├── LogisticRegression └── LR.ipynb ├── NaiveBayes └── GaussianNB.ipynb ├── Perceptron ├── Iris_perceptron.ipynb └── README.md ├── README.md ├── SVM ├── .ipynb_checkpoints │ └── support-vector-machine-checkpoint.ipynb └── support-vector-machine.ipynb └── notebooks ├── 2-knn.ipynb ├── 3-naive_bayes.ipynb ├── 4-decision_tree.ipynb ├── 5-logistic_regression.ipynb ├── 6-svm.ipynb └── 7-adaboost.ipynb /AdaBoost/Adaboost.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true 7 | }, 8 | "source": [ 9 | "# Boosting\n", 10 | "\n", 11 | "Bagging and boosting are the two main ways to build ensemble models. An ensemble model combines several base models, and its predictions are usually better than those of any single base model.\n", 12 | "\n", 13 | "- Bagging: each base model is trained on a different dataset obtained by randomly resampling the full sample; this resampling procedure is called bagging.\n", 14 | "\n", 15 | "- Boosting: each base model is trained on the same samples but with different weights; samples misclassified by the previous base model receive larger weights, so the new model focuses on the misclassified samples.\n", 16 | "\n", 17 | "### AdaBoost\n", 18 | "\n", 19 | "AdaBoost is short for Adaptive Boosting: a boosting algorithm that adapts to the errors of the previous base models.\n", 20 | "\n", 21 | "The algorithm proceeds as follows:\n", 22 | "\n", 23 | "1) Assign a weight to each training sample $(x_1,x_2,\\ldots,x_N)$; the initial weights $w_{1}$ are all $1/N$.\n", 24 | "\n", 25 | "2) Train on the weighted samples to obtain the model $G_m$ (the initial model is $G_1$).\n", 26 | "\n", 27 | "3) Compute the weighted error rate of $G_m$: $e_m=\\sum_{i=1}^Nw_iI(y_i\\not= G_m(x_i))$\n", 28 | "\n", 29 | "4) Compute the coefficient of $G_m$: $\\alpha_m=0.5\\log[(1-e_m)/e_m]$\n", 30 | "\n", 31 | "5) Update the weight vector $w_{m+1}$ from the error rate $e_m$ and the current weight vector $w_m$.\n", 32 | "\n", 33 | "6) Compute the error rate of the combined model $f(x)=\\sum_{m=1}^M\\alpha_mG_m(x)$.\n", 34 | "\n", 35 | "7) Stop when the error rate of the combined model falls below a threshold or the maximum number of iterations is reached; otherwise, go back to step 2)." 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 1, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "import numpy as np\n", 45 | "import pandas as pd\n", 46 | "from sklearn.datasets import load_iris\n", 47 | "from sklearn.model_selection import 
train_test_split\n", 48 | "import matplotlib.pyplot as plt\n", 49 | "%matplotlib inline" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 2, 55 | "metadata": { 56 | "collapsed": true 57 | }, 58 | "outputs": [], 59 | "source": [ 60 | "# data\n", 61 | "def create_data():\n", 62 | " iris = load_iris()\n", 63 | " df = pd.DataFrame(iris.data, columns=iris.feature_names)\n", 64 | " df['label'] = iris.target\n", 65 | " df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n", 66 | " data = np.array(df.iloc[:100, [0, 1, -1]])\n", 67 | " for i in range(len(data)):\n", 68 | " if data[i,-1] == 0:\n", 69 | " data[i,-1] = -1\n", 70 | " # print(data)\n", 71 | " return data[:,:2], data[:,-1]" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 3, 77 | "metadata": { 78 | "collapsed": true 79 | }, 80 | "outputs": [], 81 | "source": [ 82 | "X, y = create_data()\n", 83 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 4, 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "data": { 93 | "text/plain": [ 94 | "" 95 | ] 96 | }, 97 | "execution_count": 4, 98 | "metadata": {}, 99 | "output_type": "execute_result" 100 | }, 101 | { 102 | "data": { 103 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGZ9JREFUeJzt3X9sHOWdx/H394yv8bWAReMWsJNLCihqSSLSugTICXGg\nXkqaQoRQlAiKQhE5ELpS0aNqKtQfqBJISLRQdEQBdBTBBeVoGihHgjgoKkUklZMg5y4pKhxtY8MV\nE5TQHKYE93t/7Dqx12vvzu6O93me/bwky97Zyfj7zMA3m5nPPGPujoiIpOWvml2AiIg0npq7iEiC\n1NxFRBKk5i4ikiA1dxGRBKm5i4gkSM1dRCRBau4iIglScxcRSdBx1a5oZm1AHzDo7stL3rsAeBx4\nvbhos7vfOtX2Zs6c6XPmzMlUrIhIq9u5c+fb7t5Vab2qmztwI7APOGGS918obfpTmTNnDn19fRl+\nvYiImNnvq1mvqtMyZtYDfAm4v56iRERkelR7zv1HwDeBv0yxznlm1m9mW83szHIrmNlaM+szs76h\noaGstYqISJUqNnczWw685e47p1htFzDb3RcCPwa2lFvJ3Te4e6+793Z1VTxlJCIiNarmnPsS4BIz\nWwbMAE4ws4fd/crRFdz93TE/P2Vm/2JmM9397caXLCJSnyNHjjAwMMD777/f7FImNWPGDHp6emhv\nb6/pz1ds7u6+DlgHR1Mx/zy2sReXnwz80d3dzM6m8C+CAzVVJCKSs4GBAY4//njmzJmDmTW7nAnc\nnQMHDjAwMMDcuXNr2kaWtMw4ZnZdsYj1wOXA9Wb2ITAMrHI9BUREAvX+++8H29gBzIyPf/zj1HNt\nMlNzd/fngeeLP68fs/we4J6aqxAJ2Jbdg9zx9Cu8cXCYUzs7uHnpPFYs6m52WVKnUBv7qHrrq/mT\nu0gr2LJ7kHWb9zB8ZASAwYPDrNu8B0ANXoKm6QdEpnDH068cbeyjho+McMfTrzSpIknFtm3bmDdv\nHqeffjq33357w7ev5i4yhTcODmdaLlKNkZERbrjhBrZu3crevXvZuHEje/fubejv0GkZkSmc2tnB\nYJlGfmpnRxOqkWZp9HWXX//615x++ul86lOfAmDVqlU8/vjjfOYzn2lUyfrkLjKVm5fOo6O9bdyy\njvY2bl46r0kVyXQbve4yeHAY59h1ly27B2ve5uDgILNmzTr6uqenh8HB2rdXjpq7yBRWLOrmtssW\n0N3ZgQHdnR3cdtkCXUxtIbFed9FpGZEKVizqVjNvYXlcd+nu7mb//v1HXw8MDNDd3dj/xvTJXURk\nCpNdX6nnusvnP/95fvvb3/L666/zwQcf8Oijj3LJJZfUvL1y1NxFRKaQx3WX4447jnvuuYelS5fy\n6U9/mpUrV3LmmWUn0639dzR0ayIiiRk9Jdfou5SXLVvGsmXLGlFiWWruIiIVxHjdRadlREQSpOYu\nIpIgNXcRkQSpuYuIJEjNXUQkQWrukowtuwdZcvtzzP3Wf7Dk9ufqmvtDJG9f/epX+cQnPsH8+fNz\n2b6auyQhj8mdRPK0Zs0atm3bltv21dwlCbFO7iSR6N8EP5wP3+ssfO/fVPcmzz//fE466aQGFFee\nbmKSJOihGpKb/k3w86/BkeJ/S4f2F14DLFzZvLoq0Cd3SUIekzuJAPDsrcca+6gjw4XlAVNzlyTo\noRqSm0MD2ZYHQqdlJAl5Te4kwok9hVMx5ZYHTM1dkhHj5E4SgYu+M/6cO0B7R2F5HVavXs3zzz/P\n22+/TU9PD9///ve55ppr6iz2GDV3qVujHx4sEpTRi6bP3lo4FXNiT6Gx13kxdePGjQ0obnJq7lKX\n0Xz5aAxxNF8OqMFLOhauDDoZU44uqEpdlC8XCZOau9RF+XKJlbs3u4Qp1VufmrvURflyidGMGTM4\ncOBAsA3e3Tlw4AAzZsyoeRs65y51uXnpvHHn3EH5cglfT08PA
wMDDA0NNbuUSc2YMYOentrjlmru\nUhflyyVG7e3tzJ07t9ll5Krq5m5mbUAfMOjuy0veM+AuYBnwHrDG3Xc1slAJl/LlIuHJ8sn9RmAf\ncEKZ9y4Gzih+LQbuLX4XaSnK/EsoqrqgamY9wJeA+ydZ5VLgIS/YDnSa2SkNqlEkCppTXkJSbVrm\nR8A3gb9M8n43MHbyhYHiMpGWocy/hKRiczez5cBb7r6z3l9mZmvNrM/M+kK+Si1SC2X+JSTVfHJf\nAlxiZr8DHgUuNLOHS9YZBGaNed1TXDaOu29w91537+3q6qqxZJEwKfMvIanY3N19nbv3uPscYBXw\nnLtfWbLaE8BVVnAOcMjd32x8uSLh0pzyEpKac+5mdh2Au68HnqIQg3yVQhTy6oZUJxIRZf4lJNas\n2297e3u9r6+vKb9bRCRWZrbT3Xsrrac7VCVYt2zZw8Yd+xlxp82M1Ytn8YMVC5pdlkgU1NwlSLds\n2cPD2/9w9PWI+9HXavAilWlWSAnSxh1lnlk5xXIRGU/NXYI0Msm1oMmWi8h4au4SpDazTMtFZDw1\ndwnS6sWzMi0XkfF0QVWCNHrRVGkZkdoo5y4iEhHl3KUuV9z3Ei++9s7R10tOO4lHrj23iRU1j+Zo\nlxjpnLtMUNrYAV587R2uuO+lJlXUPJqjXWKl5i4TlDb2SstTpjnaJVZq7iJT0BztEis1d5EpaI52\niZWau0yw5LSTMi1PmeZol1ipucsEj1x77oRG3qppmRWLurntsgV0d3ZgQHdnB7ddtkBpGQmecu4i\nIhFRzl3qkle2O8t2lS8XqZ2au0wwmu0ejQCOZruBupprlu3mVYNIq9A5d5kgr2x3lu0qXy5SHzV3\nmSCvbHeW7SpfLlIfNXeZIK9sd5btKl8uUh81d5kgr2x3lu0qXy5SH11QlQlGL1g2OqmSZbt51SDS\nKpRzFxGJiHLuOYsxgx1jzSJSGzX3GsSYwY6xZhGpnS6o1iDGDHaMNYtI7dTcaxBjBjvGmkWkdmru\nNYgxgx1jzSJSOzX3GsSYwY6xZhGpnS6o1iDGDHaMNYtI7Srm3M1sBvBL4CMU/jJ4zN2/W7LOBcDj\nwOvFRZvd/daptqucu4hIdo3Muf8ZuNDdD5tZO/ArM9vq7ttL1nvB3ZfXUqxMj1u27GHjjv2MuNNm\nxurFs/jBigV1rxtKfj6UOkRCULG5e+Gj/eHiy/biV3Nua5Wa3bJlDw9v/8PR1yPuR1+XNu0s64aS\nnw+lDpFQVHVB1czazOxl4C3gGXffUWa188ys38y2mtmZDa1S6rZxx/6ql2dZN5T8fCh1iISiqubu\n7iPufhbQA5xtZvNLVtkFzHb3hcCPgS3ltmNma82sz8z6hoaG6qlbMhqZ5NpKueVZ1g0lPx9KHSKh\nyBSFdPeDwC+AL5Ysf9fdDxd/fgpoN7OZZf78Bnfvdfferq6uOsqWrNrMql6eZd1Q8vOh1CESiorN\n3cy6zKyz+HMH8AXgNyXrnGxW+D/fzM4ubvdA48uVWq1ePKvq5VnWDSU/H0odIqGoJi1zCvATM2uj\n0LQ3ufuTZnYdgLuvBy4HrjezD4FhYJU3ay5hKWv0Qmg1CZgs64aSnw+lDpFQaD53EZGIaD73nOWV\nqc6SL89z21nGF+O+iE7/Jnj2Vjg0ACf2wEXfgYUrm12VBEzNvQZ5Zaqz5Mvz3HaW8cW4L6LTvwl+\n/jU4Ukz+HNpfeA1q8DIpTRxWg7wy1Vny5XluO8v4YtwX0Xn21mONfdSR4cJykUmoudcgr0x1lnx5\nntvOMr4Y90V0Dg1kWy6CmntN8spUZ8mX57ntLOOLcV9E58SebMtFUHOvSV6Z6iz58jy3nWV8Me6L\n6Fz0HWgv+cuyvaOwXGQSuqBag7wy1Vny5XluO8v4YtwX0Rm9aKq0jGSgnLuISESUc5cJQsiuS+SU\nt4+GmnuLCCG7LpFT3j4qu
qDaIkLIrkvklLePipp7iwghuy6RU94+KmruLSKE7LpETnn7qKi5t4gQ\nsusSOeXto6ILqi0ihOy6RE55+6go5y4iEhHl3Ivyymtn2W4o85Irux6Y1DPjqY8viybsi6Sbe155\n7SzbDWVecmXXA5N6Zjz18WXRpH2R9AXVvPLaWbYbyrzkyq4HJvXMeOrjy6JJ+yLp5p5XXjvLdkOZ\nl1zZ9cCknhlPfXxZNGlfJN3c88prZ9luKPOSK7semNQz46mPL4sm7Yukm3teee0s2w1lXnJl1wOT\nemY89fFl0aR9kfQF1bzy2lm2G8q85MquByb1zHjq48uiSftCOXcRkYgo556zEPLzV9z3Ei++9s7R\n10tOO4lHrj237hpEkvLkTbDzQfARsDb43BpYfmf92w08x5/0Ofe8jGbGBw8O4xzLjG/ZPTht2y1t\n7AAvvvYOV9z3Ul01iCTlyZug74FCY4fC974HCsvrMZpdP7Qf8GPZ9f5NdZfcKGruNQghP1/a2Cst\nF2lJOx/MtrxaEeT41dxrEEJ+XkSq4CPZllcrghy/mnsNQsjPi0gVrC3b8mpFkONXc69BCPn5Jaed\nVHYbky0XaUmfW5NtebUiyPGruddgxaJubrtsAd2dHRjQ3dnBbZctaEh+vtrtPnLtuRMaudIyIiWW\n3wm91xz7pG5thdf1pmUWroQv3w0nzgKs8P3LdweVllHOXUQkIg3LuZvZDOCXwEeK6z/m7t8tWceA\nu4BlwHvAGnffVUvhlWTNl8c2h3mWud9T3xe55oizZJ/zqiPP8QWewa5L1rGlvC+mUM1NTH8GLnT3\nw2bWDvzKzLa6+/Yx61wMnFH8WgzcW/zeUFnnJI9tDvMsc7+nvi9ynQN7NPs8ajT7DBMbfF515Dm+\nlOdSzzq2lPdFBRXPuXvB4eLL9uJX6bmcS4GHiutuBzrN7JTGlpo9Xx7bHOZZ5n5PfV/kmiPOkn3O\nq448xxdBBrtmWceW8r6ooKoLqmbWZmYvA28Bz7j7jpJVuoGxHWiguKx0O2vNrM/M+oaGhjIXmzUH\nHltuPMvc76nvi1xzxFmyz3nVkef4Ishg1yzr2FLeFxVU1dzdfcTdzwJ6gLPNbH4tv8zdN7h7r7v3\ndnV1Zf7zWXPgseXGs8z9nvq+yDVHnCX7nFcdeY4vggx2zbKOLeV9UUGmKKS7HwR+AXyx5K1BYOwE\n5T3FZQ2VNV8e2xzmWeZ+T31f5JojzpJ9zquOPMcXQQa7ZlnHlvK+qKBiczezLjPrLP7cAXwB+E3J\nak8AV1nBOcAhd3+z0cVmzZfnlUfPyw9WLODKc2Yf/aTeZsaV58wum5ZJfV/kmiPOkn3Oq448xxdB\nBrtmWceW8r6ooGLO3cwWAj8B2ij8ZbDJ3W81s+sA3H19MQp5D4VP9O8BV7v7lCF25dxFRLJrWM7d\n3fuBRWWWrx/zswM3ZC1SRETykfzDOqK7cUemR5YbW0K4CSbPG3diu0krhOMRgaSbe3Q37sj0yHJj\nSwg3weR5405sN2mFcDwikfTEYdHduCPTI8uNLSHcBJPnjTux3aQVwvGIRNLNPbobd2R6ZLmxJYSb\nYPK8cSe2m7RCOB6RSLq5R3fjjkyPLDe2hHATTJ437sR2k1YIxyMSSTf36G7ckemR5caWEG6CyfPG\nndhu0grheEQi6eYe3Y07Mj2y3NgSwk0wed64E9tNWiEcj0joYR0iIhFp2E1MIi0vy4M9QhFbzaFk\n10OpowHU3EWmkuXBHqGIreZQsuuh1NEgSZ9zF6lblgd7hCK2mkPJrodSR4OouYtMJcuDPUIRW82h\nZNdDqaNB1NxFppLlwR6hiK3mULLrodTRIGruIlPJ8mCPUMRWcyjZ9VDqaBA1d5GpZHmwRyhiqzmU\n7HoodTSIcu4iIhFRzl2mT4zZ4LxqzitfHuM+lqZSc5f6xJgNzqvmvPLlMe5jaTqdc5f6xJg
Nzqvm\nvPLlMe5jaTo1d6lPjNngvGrOK18e4z6WplNzl/rEmA3Oq+a88uUx7mNpOjV3qU+M2eC8as4rXx7j\nPpamU3OX+sSYDc6r5rzy5THuY2k65dxFRCJSbc5dn9wlHf2b4Ifz4Xudhe/9m6Z/u3nVIJKRcu6S\nhryy4Fm2qzy6BESf3CUNeWXBs2xXeXQJiJq7pCGvLHiW7SqPLgFRc5c05JUFz7Jd5dElIGrukoa8\nsuBZtqs8ugREzV3SkFcWPMt2lUeXgFTMuZvZLOAh4JOAAxvc/a6SdS4AHgdeLy7a7O5TXkVSzl1E\nJLtGzuf+IfANd99lZscDO83sGXffW7LeC+6+vJZiJUAxzh+epeYYxxcC7bdoVGzu7v4m8Gbx5z+Z\n2T6gGyht7pKKGPPayqPnT/stKpnOuZvZHGARsKPM2+eZWb+ZbTWzMxtQmzRLjHlt5dHzp/0Wlarv\nUDWzjwE/Bb7u7u+WvL0LmO3uh81sGbAFOKPMNtYCawFmz55dc9GSsxjz2sqj50/7LSpVfXI3s3YK\njf0Rd99c+r67v+vuh4s/PwW0m9nMMuttcPded+/t6uqqs3TJTYx5beXR86f9FpWKzd3MDHgA2Ofu\nZecuNbOTi+thZmcXt3ugkYXKNIoxr608ev6036JSzWmZJcBXgD1m9nJx2beB2QDuvh64HLjezD4E\nhoFV3qy5hKV+oxfHYkpFZKk5xvGFQPstKprPXUQkIo3MuUuolDke78mbYOeDhQdSW1vh8Xb1PgVJ\nJFJq7rFS5ni8J2+CvgeOvfaRY6/V4KUFaW6ZWClzPN7OB7MtF0mcmnuslDkez0eyLRdJnJp7rJQ5\nHs/asi0XSZyae6yUOR7vc2uyLRdJnJp7rDR3+HjL74Tea459Ure2wmtdTJUWpZy7iEhElHOvwZbd\ng9zx9Cu8cXCYUzs7uHnpPFYs6m52WY2Tei4+9fGFQPs4GmruRVt2D7Ju8x6GjxTSFYMHh1m3eQ9A\nGg0+9Vx86uMLgfZxVHTOveiOp1852thHDR8Z4Y6nX2lSRQ2Wei4+9fGFQPs4KmruRW8cHM60PDqp\n5+JTH18ItI+jouZedGpnR6bl0Uk9F5/6+EKgfRwVNfeim5fOo6N9/A0vHe1t3Lx0XpMqarDUc/Gp\njy8E2sdR0QXVotGLpsmmZVKfizv18YVA+zgqyrmLiESk2py7TsuIxKB/E/xwPnyvs/C9f1Mc25am\n0WkZkdDlmS9Xdj1Z+uQuEro88+XKridLzV0kdHnmy5VdT5aau0jo8syXK7ueLDV3kdDlmS9Xdj1Z\nau4ioctz7n49FyBZyrmLiEREOXcRkRam5i4ikiA1dxGRBKm5i4gkSM1dRCRBau4iIglScxcRSZCa\nu4hIgio2dzObZWa/MLO9ZvbfZnZjmXXMzO42s1fNrN/MPptPuVIXzdst0jKqmc/9Q+Ab7r7LzI4H\ndprZM+6+d8w6FwNnFL8WA/cWv0soNG+3SEup+Mnd3d90913Fn/8E7ANKHyx6KfCQF2wHOs3slIZX\nK7XTvN0iLSXTOXczmwMsAnaUvNUN7B/zeoCJfwFgZmvNrM/M+oaGhrJVKvXRvN0iLaXq5m5mHwN+\nCnzd3d+t5Ze5+wZ373X33q6urlo2IbXSvN0iLaWq5m5m7RQa+yPuvrnMKoPArDGve4rLJBSat1uk\npVSTljHgAWCfu985yWpPAFcVUzPnAIfc/c0G1in10rzdIi2lmrTMEuArwB4ze7m47NvAbAB3Xw88\nBSwDXgXeA65ufKlSt4Ur1cxFWkTF5u7uvwKswjoO3NCookREpD66Q1VEJEFq7iIiCVJzFxFJkJq7\niEiC1NxFRBKk5i4ikiA1dxGRBFkhot6EX2w2BPy+Kb+8spnA280uIkcaX7xSHhtofNX4W3evODlX\n05p7yMysz917m11HXjS+eKU8NtD4GkmnZUREEqTmLiK
SIDX38jY0u4CcaXzxSnlsoPE1jM65i4gk\nSJ/cRUQS1NLN3czazGy3mT1Z5r0LzOyQmb1c/IrqkUVm9jsz21Osva/M+2Zmd5vZq2bWb2afbUad\ntapifLEfv04ze8zMfmNm+8zs3JL3Yz9+lcYX7fEzs3lj6n7ZzN41s6+XrJP78avmYR0puxHYB5ww\nyfsvuPvyaayn0f7e3SfL1F4MnFH8WgzcW/wek6nGB3Efv7uAbe5+uZn9NfA3Je/HfvwqjQ8iPX7u\n/gpwFhQ+QFJ45OjPSlbL/fi17Cd3M+sBvgTc3+xamuRS4CEv2A50mtkpzS5KwMxOBM6n8HhL3P0D\ndz9Yslq0x6/K8aXiIuA1dy+9YTP349eyzR34EfBN4C9TrHNe8Z9MW83szGmqq1Ec+E8z22lma8u8\n3w3sH/N6oLgsFpXGB/Eev7nAEPCvxdOG95vZR0vWifn4VTM+iPf4jbUK2Fhmee7HryWbu5ktB95y\n951TrLYLmO3uC4EfA1umpbjG+Tt3P4vCP/9uMLPzm11Qg1UaX8zH7zjgs8C97r4I+D/gW80tqaGq\nGV/Mxw+A4ummS4B/b8bvb8nmTuGh35eY2e+AR4ELzezhsSu4+7vufrj481NAu5nNnPZKa+Tug8Xv\nb1E433d2ySqDwKwxr3uKy6JQaXyRH78BYMDddxRfP0ahGY4V8/GrOL7Ij9+oi4Fd7v7HMu/lfvxa\nsrm7+zp373H3ORT+2fScu185dh0zO9nMrPjz2RT21YFpL7YGZvZRMzt+9GfgH4D/KlntCeCq4lX7\nc4BD7v7mNJdak2rGF/Pxc/f/Bfab2bzioouAvSWrRXv8qhlfzMdvjNWUPyUD03D8Wj0tM46ZXQfg\n7uuBy4HrzexDYBhY5fHc8fVJ4GfF/zeOA/7N3beVjO8pYBnwKvAecHWTaq1FNeOL+fgB/BPwSPGf\n9v8DXJ3Q8YPK44v6+BU/dHwB+Mcxy6b1+OkOVRGRBLXkaRkRkdSpuYuIJEjNXUQkQWruIiIJUnMX\nEUmQmruISILU3EVEEqTmLiKSoP8H2fNC9uxjMHwAAAAASUVORK5CYII=\n", 104 | "text/plain": [ 105 | "" 106 | ] 107 | }, 108 | "metadata": {}, 109 | "output_type": "display_data" 110 | } 111 | ], 112 | "source": [ 113 | "plt.scatter(X[:50,0],X[:50,1], label='0')\n", 114 | "plt.scatter(X[50:,0],X[50:,1], label='1')\n", 115 | "plt.legend()" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "----\n", 123 | "\n", 124 | "### AdaBoost in Python" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 59, 130 | "metadata": { 131 | "collapsed": true 132 | }, 133 | "outputs": [], 134 | "source": [ 135 | "class AdaBoost:\n", 136 | " def __init__(self, n_estimators=50, learning_rate=1.0):\n", 137 | " self.clf_num = n_estimators\n", 138 | " self.learning_rate = learning_rate\n", 139 | " \n", 140 | " def init_args(self, datasets, labels):\n", 141 | " \n", 142 | " self.X = datasets\n", 143 
| "        self.Y = labels\n", 144 | "        self.M, self.N = datasets.shape\n", 145 | "        \n", 146 | "        # weak classifiers (feature axis, threshold, direction)\n", 147 | "        self.clf_sets = []\n", 148 | "        \n", 149 | "        # initialize sample weights\n", 150 | "        self.weights = [1.0/self.M]*self.M\n", 151 | "        \n", 152 | "        # coefficients alpha of G(x)\n", 153 | "        self.alpha = []\n", 154 | "    \n", 155 | "    def _G(self, features, labels, weights):\n", 156 | "        m = len(features)\n", 157 | "        error = 100000.0 # effectively infinity\n", 158 | "        best_v = 0.0\n", 159 | "        # a single feature dimension\n", 160 | "        features_min = min(features)\n", 161 | "        features_max = max(features)\n", 162 | "        n_step = (features_max - features_min + self.learning_rate) // self.learning_rate\n", 163 | "        # print('n_step:{}'.format(n_step))\n", 164 | "        direct, compare_array = None, None\n", 165 | "        for i in range(1, int(n_step)):\n", 166 | "            v = features_min + self.learning_rate * i\n", 167 | "            \n", 168 | "            if v not in features:\n", 169 | "                # weighted misclassification error for both stump directions\n", 170 | "                compare_array_positive = np.array([1 if features[k] > v else -1 for k in range(m)])\n", 171 | "                weight_error_positive = sum([weights[k] for k in range(m) if compare_array_positive[k] != labels[k]])\n", 172 | "                \n", 173 | "                compare_array_nagetive = np.array([-1 if features[k] > v else 1 for k in range(m)])\n", 174 | "                weight_error_nagetive = sum([weights[k] for k in range(m) if compare_array_nagetive[k] != labels[k]])\n", 175 | "\n", 176 | "                if weight_error_positive < weight_error_nagetive:\n", 177 | "                    weight_error = weight_error_positive\n", 178 | "                    _compare_array = compare_array_positive\n", 179 | "                    direct = 'positive'\n", 180 | "                else:\n", 181 | "                    weight_error = weight_error_nagetive\n", 182 | "                    _compare_array = compare_array_nagetive\n", 183 | "                    direct = 'nagetive'\n", 184 | "                \n", 185 | "                # print('v:{} error:{}'.format(v, weight_error))\n", 186 | "                if weight_error < error:\n", 187 | "                    error = weight_error\n", 188 | "                    compare_array = _compare_array\n", 189 | "                    best_v = v\n", 190 | "        return best_v, direct, error, compare_array\n", 191 | "        \n", 192 | "    # compute alpha\n", 
193 | "    def _alpha(self, error):\n", 194 | "        return 0.5 * np.log((1-error)/error)\n", 195 | "    \n", 196 | "    # normalization factor\n", 197 | "    def _Z(self, weights, a, clf):\n", 198 | "        return sum([weights[i]*np.exp(-1*a*self.Y[i]*clf[i]) for i in range(self.M)])\n", 199 | "        \n", 200 | "    # weight update\n", 201 | "    def _w(self, a, clf, Z):\n", 202 | "        for i in range(self.M):\n", 203 | "            self.weights[i] = self.weights[i]*np.exp(-1*a*self.Y[i]*clf[i])/ Z\n", 204 | "    \n", 205 | "    # linear combination of G(x) (stub, not used; see predict)\n", 206 | "    def _f(self, alpha, clf_sets):\n", 207 | "        pass\n", 208 | "    \n", 209 | "    def G(self, x, v, direct):\n", 210 | "        if direct == 'positive':\n", 211 | "            return 1 if x > v else -1 \n", 212 | "        else:\n", 213 | "            return -1 if x > v else 1 \n", 214 | "    \n", 215 | "    def fit(self, X, y):\n", 216 | "        self.init_args(X, y)\n", 217 | "        \n", 218 | "        for epoch in range(self.clf_num):\n", 219 | "            best_clf_error, best_v, clf_result = 100000, None, None\n", 220 | "            # over the feature dimensions, pick the one with the smallest error\n", 221 | "            for j in range(self.N):\n", 222 | "                features = self.X[:, j]\n", 223 | "                # threshold, direction, error and predictions of the stump\n", 224 | "                v, direct, error, compare_array = self._G(features, self.Y, self.weights)\n", 225 | "                \n", 226 | "                if error < best_clf_error:\n", 227 | "                    best_clf_error = error\n", 228 | "                    best_v = v\n", 229 | "                    final_direct = direct\n", 230 | "                    clf_result = compare_array\n", 231 | "                    axis = j\n", 232 | "                \n", 233 | "                # print('epoch:{}/{} feature:{} error:{} v:{}'.format(epoch, self.clf_num, j, error, best_v))\n", 234 | "                if best_clf_error == 0:\n", 235 | "                    break\n", 236 | "                \n", 237 | "            # compute the coefficient a of G(x)\n", 238 | "            a = self._alpha(best_clf_error)\n", 239 | "            self.alpha.append(a)\n", 240 | "            # record the classifier\n", 241 | "            self.clf_sets.append((axis, best_v, final_direct))\n", 242 | "            # normalization factor\n", 243 | "            Z = self._Z(self.weights, a, clf_result)\n", 244 | "            # update the weights\n", 245 | "            self._w(a, clf_result, Z)\n", 246 | "            \n", 247 | "#             print('classifier:{}/{} error:{:.3f} v:{} direct:{} a:{:.5f}'.format(epoch+1, self.clf_num, error, best_v, final_direct, a))\n", 248 
| "#             print('weight:{}'.format(self.weights))\n", 249 | "#             print('\\n')\n", 250 | "    \n", 251 | "    def predict(self, feature):\n", 252 | "        result = 0.0\n", 253 | "        for i in range(len(self.clf_sets)):\n", 254 | "            axis, clf_v, direct = self.clf_sets[i]\n", 255 | "            f_input = feature[axis]\n", 256 | "            result += self.alpha[i] * self.G(f_input, clf_v, direct)\n", 257 | "        # sign\n", 258 | "        return 1 if result > 0 else -1\n", 259 | "    \n", 260 | "    def score(self, X_test, y_test):\n", 261 | "        right_count = 0\n", 262 | "        for i in range(len(X_test)):\n", 263 | "            feature = X_test[i]\n", 264 | "            if self.predict(feature) == y_test[i]:\n", 265 | "                right_count += 1\n", 266 | "        \n", 267 | "        return right_count / len(X_test)" 268 | ] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "metadata": {}, 273 | "source": [ 274 | "### Example 8.1" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 22, 280 | "metadata": { 281 | "collapsed": true 282 | }, 283 | "outputs": [], 284 | "source": [ 285 | "X = np.arange(10).reshape(10, 1)\n", 286 | "y = np.array([1, 1, 1, -1, -1, -1, 1, 1, 1, -1])" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": 23, 292 | "metadata": {}, 293 | "outputs": [ 294 | { 295 | "name": "stdout", 296 | "output_type": "stream", 297 | "text": [ 298 | "classifier:1/3 error:0.300 v:2.5 direct:nagetive a:0.42365\n", 299 | "weight:[0.071428571428571425, 0.071428571428571425, 0.071428571428571425, 0.071428571428571425, 0.071428571428571425, 0.071428571428571425, 0.16666666666666663, 0.16666666666666663, 0.16666666666666663, 0.071428571428571425]\n", 300 | "\n", 301 | "\n", 302 | "classifier:2/3 error:0.214 v:8.5 direct:nagetive a:0.64964\n", 303 | "weight:[0.045454545454545463, 0.045454545454545463, 0.045454545454545463, 0.16666666666666669, 0.16666666666666669, 0.16666666666666669, 0.10606060606060606, 0.10606060606060606, 0.10606060606060606, 0.045454545454545463]\n", 304 | "\n", 305 | "\n", 306 | "classifier:3/3 error:0.182 v:5.5 
direct:nagetive a:0.75204\n", 307 | "weight:[0.12499999999999996, 0.12499999999999996, 0.12499999999999996, 0.10185185185185185, 0.10185185185185185, 0.10185185185185185, 0.064814814814814797, 0.064814814814814797, 0.064814814814814797, 0.12499999999999996]\n", 308 | "\n", 309 | "\n" 310 | ] 311 | } 312 | ], 313 | "source": [ 314 | "clf = AdaBoost(n_estimators=3, learning_rate=0.5)\n", 315 | "clf.fit(X, y)" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": 50, 321 | "metadata": { 322 | "collapsed": true 323 | }, 324 | "outputs": [], 325 | "source": [ 326 | "X, y = create_data()\n", 327 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": 51, 333 | "metadata": {}, 334 | "outputs": [ 335 | { 336 | "data": { 337 | "text/plain": [ 338 | "0.8484848484848485" 339 | ] 340 | }, 341 | "execution_count": 51, 342 | "metadata": {}, 343 | "output_type": "execute_result" 344 | } 345 | ], 346 | "source": [ 347 | "clf = AdaBoost(n_estimators=10, learning_rate=0.2)\n", 348 | "clf.fit(X_train, y_train)\n", 349 | "clf.score(X_test, y_test)" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": 60, 355 | "metadata": {}, 356 | "outputs": [ 357 | { 358 | "name": "stdout", 359 | "output_type": "stream", 360 | "text": [ 361 | "average score:63.061%\n" 362 | ] 363 | } 364 | ], 365 | "source": [ 366 | "# average accuracy over 100 random train/test splits\n", 367 | "result = []\n", 368 | "for i in range(1, 101):\n", 369 | "    X, y = create_data()\n", 370 | "    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)\n", 371 | "    clf = AdaBoost(n_estimators=100, learning_rate=0.2)\n", 372 | "    clf.fit(X_train, y_train)\n", 373 | "    r = clf.score(X_test, y_test)\n", 374 | "    # print('{}/100 score:{}'.format(i, r))\n", 375 | "    result.append(r)\n", 376 | "\n", 377 | "print('average score:{:.3f}%'.format(100 * sum(result) / len(result)))" 378 | ] 379 | }, 380 | { 381 | "cell_type": "markdown", 382 | 
"metadata": {}, 383 | "source": [ 384 | "-----\n", 385 | "# sklearn.ensemble.AdaBoostClassifier\n", 386 | "\n", 387 | "- algorithm: only AdaBoostClassifier has this parameter, because scikit-learn implements two AdaBoost classification algorithms, SAMME and SAMME.R. They differ in how weak-learner weights are measured: SAMME extends the binary-classification AdaBoost algorithm described above and weights weak learners by their classification accuracy on the sample set, while SAMME.R weights them by the predicted class probabilities. Because SAMME.R uses continuous probability values, it usually converges in fewer iterations than SAMME, so SAMME.R is also the default value of algorithm. The default SAMME.R is normally sufficient, but note that with SAMME.R the weak learner passed as base_estimator must support probability predictions; the SAMME algorithm has no such restriction.\n", 388 | "\n", 389 | "- n_estimators: present in both AdaBoostClassifier and AdaBoostRegressor; the maximum number of boosting iterations, i.e. the maximum number of weak learners. Too small a value tends to underfit, too large a value tends to overfit, so a moderate value is usually chosen; the default is 50. In practice, n_estimators is often tuned together with the learning_rate parameter described below.\n", 390 | "\n", 391 | "- learning_rate: present in both AdaBoostClassifier and AdaBoostRegressor; the shrinkage factor ν applied to each weak learner's weight.\n", 392 | "\n", 393 | "- base_estimator: present in both AdaBoostClassifier and AdaBoostRegressor; the weak classifier or regressor. In theory any classification or regression learner can be used, as long as it supports sample weights. CART decision trees and MLP neural networks are the most common choices." 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "execution_count": 57, 399 | "metadata": {}, 400 | "outputs": [ 401 | { 402 | "data": { 403 | "text/plain": [ 404 | "AdaBoostClassifier(algorithm='SAMME.R', base_estimator=None,\n", 405 | "          learning_rate=0.5, n_estimators=100, random_state=None)" 406 | ] 407 | }, 408 | "execution_count": 57, 409 | "metadata": {}, 410 | "output_type": "execute_result" 411 | } 412 | ], 413 | "source": [ 414 | "from sklearn.ensemble import AdaBoostClassifier\n", 415 | "clf = AdaBoostClassifier(n_estimators=100, learning_rate=0.5)\n", 416 | "clf.fit(X_train, y_train)" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 58, 422 | "metadata": {}, 423 | "outputs": [ 424 | { 425 | "data": { 426 | "text/plain": [ 427 | "0.90909090909090906" 428 | ] 429 | }, 430 | "execution_count": 58, 431 | "metadata": {}, 432 | "output_type": "execute_result" 433 | } 434 | ], 435 | "source": [ 436 | "clf.score(X_test, y_test)" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 
441 | "execution_count": null, 442 | "metadata": { 443 | "collapsed": true 444 | }, 445 | "outputs": [], 446 | "source": [] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": null, 451 | "metadata": { 452 | "collapsed": true 453 | }, 454 | "outputs": [], 455 | "source": [] 456 | } 457 | ], 458 | "metadata": { 459 | "kernelspec": { 460 | "display_name": "Python 3", 461 | "language": "python", 462 | "name": "python3" 463 | }, 464 | "language_info": { 465 | "codemirror_mode": { 466 | "name": "ipython", 467 | "version": 3 468 | }, 469 | "file_extension": ".py", 470 | "mimetype": "text/x-python", 471 | "name": "python", 472 | "nbconvert_exporter": "python", 473 | "pygments_lexer": "ipython3", 474 | "version": "3.6.1" 475 | } 476 | }, 477 | "nbformat": 4, 478 | "nbformat_minor": 2 479 | } 480 | -------------------------------------------------------------------------------- /DecisonTree/dt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time    : 2018/1/3 14:02 3 | # @Author  : wangzy 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from math import log 9 | 10 | 11 | # tree node class (children stored in a dict, keyed by feature value) 12 | class Node: 13 | def __init__(self, root=True, label=None, feature_name=None, feature=None): 14 | self.root = root 15 | self.label = label 16 | self.feature_name = feature_name 17 | self.feature = feature 18 | self.tree = {} 19 | self.result = {'label:': self.label, 'feature': self.feature, 'tree': self.tree} 20 | 21 | def __repr__(self): 22 | return '{}'.format(self.result) 23 | 24 | def add_node(self, val, node): 25 | self.tree[val] = node 26 | 27 | def predict(self, features): 28 | if self.root is True: 29 | return self.label 30 | return self.tree[features[self.feature]].predict(features) 31 | 32 | 33 | # problem 5.1 from the book 34 | def create_data(): 35 | datasets = [['青年', '否', '否', '一般', '否'], 36 | ['青年', '否', '否', '好', '否'], 37 | ['青年', '是', '否', '好', '是'], 38 | ['青年', '是', '是', '一般', '是'], 39 | ['青年', 
'否', '否', '一般', '否'], 40 | ['中年', '否', '否', '一般', '否'], 41 | ['中年', '否', '否', '好', '否'], 42 | ['中年', '是', '是', '好', '是'], 43 | ['中年', '否', '是', '非常好', '是'], 44 | ['中年', '否', '是', '非常好', '是'], 45 | ['老年', '否', '是', '非常好', '是'], 46 | ['老年', '否', '是', '好', '是'], 47 | ['老年', '是', '否', '好', '是'], 48 | ['老年', '是', '否', '非常好', '是'], 49 | ['老年', '否', '否', '一般', '否'], 50 | ] 51 | labels = [u'年龄', u'有工作', u'有自己的房子', u'信贷情况', u'类别'] 52 | # return the dataset and the name of each feature 53 | return datasets, labels 54 | 55 | 56 | class DTree: 57 | def __init__(self, epsilon=0.1): 58 | self.epsilon = epsilon 59 | self._tree = {} 60 | 61 | # entropy 62 | @staticmethod 63 | def calc_ent(datasets): 64 | data_length = len(datasets) 65 | label_count = {} 66 | for i in range(data_length): 67 | label = datasets[i][-1] 68 | if label not in label_count: 69 | label_count[label] = 0 70 | label_count[label] += 1 71 | ent = -sum([(p/data_length)*log(p/data_length, 2) for p in label_count.values()]) 72 | return ent 73 | 74 | # empirical conditional entropy 75 | def cond_ent(self, datasets, axis=0): 76 | data_length = len(datasets) 77 | feature_sets = {} 78 | for i in range(data_length): 79 | feature = datasets[i][axis] 80 | if feature not in feature_sets: 81 | feature_sets[feature] = [] 82 | feature_sets[feature].append(datasets[i]) 83 | cond_ent = sum([(len(p)/data_length)*self.calc_ent(p) for p in feature_sets.values()]) 84 | return cond_ent 85 | 86 | # information gain 87 | @staticmethod 88 | def info_gain(ent, cond_ent): 89 | return ent - cond_ent 90 | 91 | def info_gain_train(self, datasets): 92 | count = len(datasets[0]) - 1 93 | ent = self.calc_ent(datasets) 94 | best_feature = [] 95 | for c in range(count): 96 | c_info_gain = self.info_gain(ent, self.cond_ent(datasets, axis=c)) 97 | best_feature.append((c, c_info_gain)) 98 | # pick the feature with the largest information gain 99 | best_ = max(best_feature, key=lambda x: x[-1]) 100 | return best_ 101 | 102 | def train(self, train_data): 103 | """ 104 | input: dataset D (as a DataFrame), feature set A, threshold eta 105 | output: decision tree T 106 | """ 107 | _, y_train, features = 
train_data.iloc[:, :-1], train_data.iloc[:, -1], train_data.columns[:-1] 108 | # 1. If all instances in D belong to one class Ck, T is a single-node tree labeled Ck; return T 109 | if len(y_train.value_counts()) == 1: 110 | return Node(root=True, 111 | label=y_train.iloc[0]) 112 | 113 | # 2. If A is empty, T is a single-node tree labeled with the majority class Ck in D; return T 114 | if len(features) == 0: 115 | return Node(root=True, label=y_train.value_counts().sort_values(ascending=False).index[0]) 116 | 117 | # 3. As in 5.1, compute the information gains; Ag is the feature with the largest gain 118 | max_feature, max_info_gain = self.info_gain_train(np.array(train_data)) 119 | max_feature_name = features[max_feature] 120 | 121 | # 4. If Ag's information gain is below the threshold eta, T is a single-node tree labeled with the majority class Ck in D; return T 122 | if max_info_gain < self.epsilon: 123 | return Node(root=True, label=y_train.value_counts().sort_values(ascending=False).index[0]) 124 | 125 | # 5. Split D into subsets by the values of Ag 126 | node_tree = Node(root=False, feature_name=max_feature_name, feature=max_feature) 127 | 128 | feature_list = train_data[max_feature_name].value_counts().index 129 | for f in feature_list: 130 | sub_train_df = train_data.loc[train_data[max_feature_name] == f].drop([max_feature_name], axis=1) 131 | 132 | # 6. Recursively build the subtree 133 | sub_tree = self.train(sub_train_df) 134 | node_tree.add_node(f, sub_tree) 135 | 136 | # pprint.pprint(node_tree.tree) 137 | return node_tree 138 | 139 | def fit(self, train_data): 140 | self._tree = self.train(train_data) 141 | return self._tree 142 | 143 | def predict(self, X_test): 144 | return self._tree.predict(X_test) 145 | 146 | 147 | if __name__ == '__main__': 148 | datasets, labels = create_data() 149 | data_df = pd.DataFrame(datasets, columns=labels) 150 | dt = DTree() 151 | tree = dt.fit(data_df) 152 | print(dt.predict(['老年', '否', '否', '一般'])) 153 | 154 | -------------------------------------------------------------------------------- /DecisonTree/mytree.pdf: -------------------------------------------------------------------------------- 1 | digraph Tree { 2 | node [shape=box] ; 3 | 0 [label="X[0] <= 5.45\ngini = 0.4996\nsamples = 70\nvalue = [34, 
36]"] ; 4 | 1 [label="X[1] <= 2.85\ngini = 0.2392\nsamples = 36\nvalue = [31, 5]"] ; 5 | 0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ; 6 | 2 [label="gini = 0.0\nsamples = 4\nvalue = [0, 4]"] ; 7 | 1 -> 2 ; 8 | 3 [label="X[0] <= 5.3\ngini = 0.0605\nsamples = 32\nvalue = [31, 1]"] ; 9 | 1 -> 3 ; 10 | 4 [label="gini = 0.0\nsamples = 27\nvalue = [27, 0]"] ; 11 | 3 -> 4 ; 12 | 5 [label="X[1] <= 3.2\ngini = 0.32\nsamples = 5\nvalue = [4, 1]"] ; 13 | 3 -> 5 ; 14 | 6 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]"] ; 15 | 5 -> 6 ; 16 | 7 [label="gini = 0.0\nsamples = 4\nvalue = [4, 0]"] ; 17 | 5 -> 7 ; 18 | 8 [label="X[1] <= 3.35\ngini = 0.1609\nsamples = 34\nvalue = [3, 31]"] ; 19 | 0 -> 8 [labeldistance=2.5, labelangle=-45, headlabel="False"] ; 20 | 9 [label="gini = 0.0\nsamples = 31\nvalue = [0, 31]"] ; 21 | 8 -> 9 ; 22 | 10 [label="gini = 0.0\nsamples = 3\nvalue = [3, 0]"] ; 23 | 8 -> 10 ; 24 | } -------------------------------------------------------------------------------- /EM/em.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# EM Algorithm\n", 8 | "\n", 9 | "# Expectation Maximization algorithm\n", 10 | "\n", 11 | "### Maximum likelihood function\n", 12 | "\n", 13 | "[likelihood & maximum likelihood](http://fangs.in/post/thinkstats/likelihood/)\n", 14 | "\n", 15 | "> In statistics, the likelihood function (usually shortened to likelihood) is a very important concept. Informally, likelihood and probability are almost synonyms, but in statistics they are two different things. Probability is the chance that an event occurs in a given setting: before the outcome is produced, we predict it from the parameters of that setting. For a coin toss, we do not know which side will land up before the toss, but from the properties of the coin we infer that either side comes up with probability 50%; this probability is only meaningful before the toss, since afterwards the outcome is fixed. Likelihood is the reverse: given an observed outcome, we infer the setting (the parameters) that could have produced it. Take the coin again: suppose we toss it 1,000 times and get 500 heads and 500 tails (real data are rarely this ideal; it is just an example). We readily conclude that this is a fair coin, with each side coming up with probability 50%. Using observed outcomes to judge the underlying properties (parameters) of the process is likelihood." 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "$$P(Y|\\theta) = \\prod[\\pi 
p^{y_i}(1-p)^{1-y_i}+(1-\\pi) q^{y_i}(1-q)^{1-y_i}]$$\n", 23 | "\n", 24 | "### E step:\n", 25 | "\n", 26 | "$$\\mu_j^{i+1}=\\frac{\\pi (p^i)^{y_j}(1-(p^i))^{1-y_j}}{\\pi (p^i)^{y_j}(1-(p^i))^{1-y_j}+(1-\\pi) (q^i)^{y_j}(1-(q^i))^{1-y_j}}$$" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 1, 32 | "metadata": { 33 | "collapsed": true 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "import numpy as np\n", 38 | "import math" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "metadata": { 45 | "collapsed": true 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "pro_A, pro_B, pro_C = 0.5, 0.5, 0.5\n", 50 | "\n", 51 | "def pmf(i, pro_A, pro_B, pro_C):\n", 52 | "    pro_1 = pro_A * math.pow(pro_B, data[i]) * math.pow((1-pro_B), 1-data[i])\n", 53 | "    pro_2 = (1 - pro_A) * math.pow(pro_C, data[i]) * math.pow((1-pro_C), 1-data[i])\n", 54 | "    return pro_1 / (pro_1 + pro_2)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "### M step:\n", 62 | "\n", 63 | "$$\\pi^{i+1}=\\frac{1}{n}\\sum_{j=1}^n\\mu^{i+1}_j$$\n", 64 | "\n", 65 | "$$p^{i+1}=\\frac{\\sum_{j=1}^n\\mu^{i+1}_jy_j}{\\sum_{j=1}^n\\mu^{i+1}_j}$$\n", 66 | "\n", 67 | "$$q^{i+1}=\\frac{\\sum_{j=1}^n(1-\\mu^{i+1}_j)y_j}{\\sum_{j=1}^n(1-\\mu^{i+1}_j)}$$" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 3, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "class EM:\n", 77 | "    def __init__(self, prob):\n", 78 | "        self.pro_A, self.pro_B, self.pro_C = prob\n", 79 | "    \n", 80 | "    # E-step\n", 81 | "    def pmf(self, i):\n", 82 | "        pro_1 = self.pro_A * math.pow(self.pro_B, data[i]) * math.pow((1-self.pro_B), 1-data[i])\n", 83 | "        pro_2 = (1 - self.pro_A) * math.pow(self.pro_C, data[i]) * math.pow((1-self.pro_C), 1-data[i])\n", 84 | "        return pro_1 / (pro_1 + pro_2)\n", 85 | "    \n", 86 | "    # M-step\n", 87 | "    def fit(self, data):\n", 88 | "        count = len(data)\n", 89 | "        print('init prob:{}, {}, 
{}'.format(self.pro_A, self.pro_B, self.pro_C))\n", 90 | " for d in range(count):\n", 91 | " _ = yield\n", 92 | " _pmf = [self.pmf(k) for k in range(count)]\n", 93 | " pro_A = 1/ count * sum(_pmf)\n", 94 | " pro_B = sum([_pmf[k]*data[k] for k in range(count)]) / sum([_pmf[k] for k in range(count)])\n", 95 | " pro_C = sum([(1-_pmf[k])*data[k] for k in range(count)]) / sum([(1-_pmf[k]) for k in range(count)])\n", 96 | " print('{}/{} pro_a:{:.3f}, pro_b:{:.3f}, pro_c:{:.3f}'.format(d+1, count, pro_A, pro_B, pro_C))\n", 97 | " self.pro_A = pro_A\n", 98 | " self.pro_B = pro_B\n", 99 | " self.pro_C = pro_C\n", 100 | " " 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 4, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "data=[1,1,0,1,0,0,1,0,1,1]" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 5, 115 | "metadata": {}, 116 | "outputs": [ 117 | { 118 | "name": "stdout", 119 | "output_type": "stream", 120 | "text": [ 121 | "init prob:0.5, 0.5, 0.5\n" 122 | ] 123 | } 124 | ], 125 | "source": [ 126 | "em = EM(prob=[0.5, 0.5, 0.5])\n", 127 | "f = em.fit(data)\n", 128 | "next(f)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 6, 134 | "metadata": {}, 135 | "outputs": [ 136 | { 137 | "name": "stdout", 138 | "output_type": "stream", 139 | "text": [ 140 | "1/10 pro_a:0.500, pro_b:0.600, pro_c:0.600\n" 141 | ] 142 | } 143 | ], 144 | "source": [ 145 | "# 第一次迭代\n", 146 | "f.send(1)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 7, 152 | "metadata": {}, 153 | "outputs": [ 154 | { 155 | "name": "stdout", 156 | "output_type": "stream", 157 | "text": [ 158 | "2/10 pro_a:0.500, pro_b:0.600, pro_c:0.600\n" 159 | ] 160 | } 161 | ], 162 | "source": [ 163 | "# 第二次\n", 164 | "f.send(2)" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 8, 170 | "metadata": {}, 171 | "outputs": [ 172 | { 173 | "name": "stdout", 174 | "output_type": 
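The generator-based `fit` above interleaves `yield` with full EM sweeps; the same E/M updates can be written as a plain loop. A sketch (`em_fit` and `e_step` are illustrative names, not part of the notebook):

```python
def e_step(y, pi, p, q):
    # Responsibility that observation y came from coin B rather than coin C.
    p1 = pi * (p ** y) * ((1 - p) ** (1 - y))
    p2 = (1 - pi) * (q ** y) * ((1 - q) ** (1 - y))
    return p1 / (p1 + p2)

def em_fit(data, pi, p, q, n_iter=10):
    # Alternate E and M steps; each M step re-estimates (pi, p, q)
    # from the responsibilities mu_j of the current E step.
    for _ in range(n_iter):
        mu = [e_step(y, pi, p, q) for y in data]
        pi = sum(mu) / len(data)
        p = sum(m * y for m, y in zip(mu, data)) / sum(mu)
        q = sum((1 - m) * y for m, y in zip(mu, data)) / sum(1 - m for m in mu)
    return pi, p, q

data = [1, 1, 0, 1, 0, 0, 1, 0, 1, 1]
print(em_fit(data, 0.5, 0.5, 0.5))  # (0.5, 0.6, 0.6)
```

Starting from (0.5, 0.5, 0.5) this lands on the fixed point (0.5, 0.6, 0.6) after one sweep and stays there, matching the printed output above.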
"stream", 175 | "text": [ 176 | "init prob:0.4, 0.6, 0.7\n" 177 | ] 178 | } 179 | ], 180 | "source": [ 181 | "em = EM(prob=[0.4, 0.6, 0.7])\n", 182 | "f2 = em.fit(data)\n", 183 | "next(f2)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 9, 189 | "metadata": {}, 190 | "outputs": [ 191 | { 192 | "name": "stdout", 193 | "output_type": "stream", 194 | "text": [ 195 | "1/10 pro_a:0.406, pro_b:0.537, pro_c:0.643\n" 196 | ] 197 | } 198 | ], 199 | "source": [ 200 | "f2.send(1)" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": 10, 206 | "metadata": {}, 207 | "outputs": [ 208 | { 209 | "name": "stdout", 210 | "output_type": "stream", 211 | "text": [ 212 | "2/10 pro_a:0.406, pro_b:0.537, pro_c:0.643\n" 213 | ] 214 | } 215 | ], 216 | "source": [ 217 | "f2.send(2)" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "metadata": { 224 | "collapsed": true 225 | }, 226 | "outputs": [], 227 | "source": [] 228 | } 229 | ], 230 | "metadata": { 231 | "kernelspec": { 232 | "display_name": "Python 3", 233 | "language": "python", 234 | "name": "python3" 235 | }, 236 | "language_info": { 237 | "codemirror_mode": { 238 | "name": "ipython", 239 | "version": 3 240 | }, 241 | "file_extension": ".py", 242 | "mimetype": "text/x-python", 243 | "name": "python", 244 | "nbconvert_exporter": "python", 245 | "pygments_lexer": "ipython3", 246 | "version": "3.6.1" 247 | } 248 | }, 249 | "nbformat": 4, 250 | "nbformat_minor": 2 251 | } 252 | -------------------------------------------------------------------------------- /LeastSquaresMethod/README.md: -------------------------------------------------------------------------------- 1 | # Least Squares Method 2 | 3 | -------------------------------------------------------------------------------- /NaiveBayes/GaussianNB.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 
5 | "metadata": {}, 6 | "source": [ 7 | "# 朴素贝叶斯\n", 8 | "\n", 9 | "基于贝叶斯定理与特征条件独立假设的分类方法。\n", 10 | "\n", 11 | "模型:\n", 12 | "\n", 13 | "- 高斯模型\n", 14 | "- 多项式模型\n", 15 | "- 伯努利模型" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "metadata": { 22 | "collapsed": true 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "import numpy as np\n", 27 | "import pandas as pd\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "%matplotlib inline\n", 30 | "\n", 31 | "from sklearn.datasets import load_iris\n", 32 | "from sklearn.model_selection import train_test_split\n", 33 | "\n", 34 | "from collections import Counter\n", 35 | "import math" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "# data\n", 45 | "def create_data():\n", 46 | " iris = load_iris()\n", 47 | " df = pd.DataFrame(iris.data, columns=iris.feature_names)\n", 48 | " df['label'] = iris.target\n", 49 | " df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n", 50 | " data = np.array(df.iloc[:100, :])\n", 51 | " # print(data)\n", 52 | " return data[:,:-1], data[:,-1]" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 3, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "X, y = create_data()\n", 62 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 4, 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "data": { 72 | "text/plain": [ 73 | "(array([ 4.8, 3. 
, 1.4, 0.1]), 0.0)" 74 | ] 75 | }, 76 | "execution_count": 4, 77 | "metadata": {}, 78 | "output_type": "execute_result" 79 | } 80 | ], 81 | "source": [ 82 | "X_test[0], y_test[0]" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "参考:https://machinelearningmastery.com/naive-bayes-classifier-scratch-python/\n", 90 | "\n", 91 | "## GaussianNB 高斯朴素贝叶斯\n", 92 | "\n", 93 | "特征的可能性被假设为高斯\n", 94 | "\n", 95 | "概率密度函数:\n", 96 | "$$P(x_i | y_k)=\\frac{1}{\\sqrt{2\\pi\\sigma^2_{yk}}}exp(-\\frac{(x_i-\\mu_{yk})^2}{2\\sigma^2_{yk}})$$\n", 97 | "\n", 98 | "数学期望(mean):$\\mu$,方差:$\\sigma^2=\\frac{\\sum(X-\\mu)^2}{N}$" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 7, 104 | "metadata": { 105 | "collapsed": true 106 | }, 107 | "outputs": [], 108 | "source": [ 109 | "class NaiveBayes:\n", 110 | " def __init__(self):\n", 111 | " self.model = None\n", 112 | "\n", 113 | " # 数学期望\n", 114 | " @staticmethod\n", 115 | " def mean(X):\n", 116 | " return sum(X) / float(len(X))\n", 117 | "\n", 118 | " # 标准差(方差)\n", 119 | " def stdev(self, X):\n", 120 | " avg = self.mean(X)\n", 121 | " return math.sqrt(sum([pow(x-avg, 2) for x in X]) / float(len(X)))\n", 122 | "\n", 123 | " # 概率密度函数\n", 124 | " def gaussian_probability(self, x, mean, stdev):\n", 125 | " exponent = math.exp(-(math.pow(x-mean,2)/(2*math.pow(stdev,2))))\n", 126 | " return (1 / (math.sqrt(2*math.pi) * stdev)) * exponent\n", 127 | "\n", 128 | " # 处理X_train\n", 129 | " def summarize(self, train_data):\n", 130 | " summaries = [(self.mean(i), self.stdev(i)) for i in zip(*train_data)]\n", 131 | " return summaries\n", 132 | "\n", 133 | " # 分类别求出数学期望和标准差\n", 134 | " def fit(self, X, y):\n", 135 | " labels = list(set(y))\n", 136 | " data = {label:[] for label in labels}\n", 137 | " for f, label in zip(X, y):\n", 138 | " data[label].append(f)\n", 139 | " self.model = {label: self.summarize(value) for label, value in data.items()}\n", 140 | " return 'gaussianNB train 
done!'\n", 141 | "\n", 142 | " # 计算概率\n", 143 | " def calculate_probabilities(self, input_data):\n", 144 | " # summaries:{0.0: [(5.0, 0.37),(3.42, 0.40)], 1.0: [(5.8, 0.449),(2.7, 0.27)]}\n", 145 | " # input_data:[1.1, 2.2]\n", 146 | " probabilities = {}\n", 147 | " for label, value in self.model.items():\n", 148 | " probabilities[label] = 1\n", 149 | " for i in range(len(value)):\n", 150 | " mean, stdev = value[i]\n", 151 | " probabilities[label] *= self.gaussian_probability(input_data[i], mean, stdev)\n", 152 | " return probabilities\n", 153 | "\n", 154 | " # 类别\n", 155 | " def predict(self, X_test):\n", 156 | " # {0.0: 2.9680340789325763e-27, 1.0: 3.5749783019849535e-26}\n", 157 | " label = sorted(self.calculate_probabilities(X_test).items(), key=lambda x: x[-1])[-1][0]\n", 158 | " return label\n", 159 | "\n", 160 | " def score(self, X_test, y_test):\n", 161 | " right = 0\n", 162 | " for X, y in zip(X_test, y_test):\n", 163 | " label = self.predict(X)\n", 164 | " if label == y:\n", 165 | " right += 1\n", 166 | "\n", 167 | " return right / float(len(X_test))" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 8, 173 | "metadata": { 174 | "collapsed": true 175 | }, 176 | "outputs": [], 177 | "source": [ 178 | "model = NaiveBayes()" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 9, 184 | "metadata": {}, 185 | "outputs": [ 186 | { 187 | "data": { 188 | "text/plain": [ 189 | "'gaussianNB train done!'" 190 | ] 191 | }, 192 | "execution_count": 9, 193 | "metadata": {}, 194 | "output_type": "execute_result" 195 | } 196 | ], 197 | "source": [ 198 | "model.fit(X_train, y_train)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 10, 204 | "metadata": {}, 205 | "outputs": [ 206 | { 207 | "name": "stdout", 208 | "output_type": "stream", 209 | "text": [ 210 | "0.0\n" 211 | ] 212 | } 213 | ], 214 | "source": [ 215 | "print(model.predict([4.4, 3.2, 1.3, 0.2]))" 216 | ] 217 | }, 218 | { 219 | 
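The Gaussian density used by the classifier can be sanity-checked in isolation; a minimal sketch mirroring `gaussian_probability` (the density peaks at $1/(\sqrt{2\pi}\sigma)$ when $x=\mu$):

```python
import math

# Gaussian density, as in the notebook's gaussian_probability.
def gaussian_probability(x, mean, stdev):
    exponent = math.exp(-((x - mean) ** 2) / (2 * stdev ** 2))
    return (1 / (math.sqrt(2 * math.pi) * stdev)) * exponent

# At x = mean the exponent is 1, so the density equals its peak value.
mu, sigma = 5.0, 0.37
peak = gaussian_probability(mu, mu, sigma)
assert abs(peak - 1 / (math.sqrt(2 * math.pi) * sigma)) < 1e-12
```

The density is also symmetric about the mean, which is easy to verify by evaluating it at `mu + d` and `mu - d`.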
"cell_type": "code", 220 | "execution_count": 11, 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "data": { 225 | "text/plain": [ 226 | "1.0" 227 | ] 228 | }, 229 | "execution_count": 11, 230 | "metadata": {}, 231 | "output_type": "execute_result" 232 | } 233 | ], 234 | "source": [ 235 | "model.score(X_test, y_test)" 236 | ] 237 | }, 238 | { 239 | "cell_type": "markdown", 240 | "metadata": { 241 | "collapsed": true 242 | }, 243 | "source": [ 244 | "scikit-learn实例\n", 245 | "\n", 246 | "# sklearn.naive_bayes" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 12, 252 | "metadata": { 253 | "collapsed": true 254 | }, 255 | "outputs": [], 256 | "source": [ 257 | "from sklearn.naive_bayes import GaussianNB" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 13, 263 | "metadata": {}, 264 | "outputs": [ 265 | { 266 | "data": { 267 | "text/plain": [ 268 | "GaussianNB(priors=None)" 269 | ] 270 | }, 271 | "execution_count": 13, 272 | "metadata": {}, 273 | "output_type": "execute_result" 274 | } 275 | ], 276 | "source": [ 277 | "clf = GaussianNB()\n", 278 | "clf.fit(X_train, y_train)" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 14, 284 | "metadata": {}, 285 | "outputs": [ 286 | { 287 | "data": { 288 | "text/plain": [ 289 | "1.0" 290 | ] 291 | }, 292 | "execution_count": 14, 293 | "metadata": {}, 294 | "output_type": "execute_result" 295 | } 296 | ], 297 | "source": [ 298 | "clf.score(X_test, y_test)" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 16, 304 | "metadata": {}, 305 | "outputs": [ 306 | { 307 | "name": "stderr", 308 | "output_type": "stream", 309 | "text": [ 310 | "E:\\Anaconda3\\lib\\site-packages\\sklearn\\utils\\validation.py:395: DeprecationWarning: Passing 1d arrays as data is deprecated in 0.17 and will raise ValueError in 0.19. 
Reshape your data either using X.reshape(-1, 1) if your data has a single feature or X.reshape(1, -1) if it contains a single sample.\n", 311 | " DeprecationWarning)\n" 312 | ] 313 | }, 314 | { 315 | "data": { 316 | "text/plain": [ 317 | "array([ 0.])" 318 | ] 319 | }, 320 | "execution_count": 16, 321 | "metadata": {}, 322 | "output_type": "execute_result" 323 | } 324 | ], 325 | "source": [ 326 | "clf.predict([4.4, 3.2, 1.3, 0.2])" 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": 17, 332 | "metadata": { 333 | "collapsed": true 334 | }, 335 | "outputs": [], 336 | "source": [ 337 | "from sklearn.naive_bayes import BernoulliNB, MultinomialNB # 伯努利模型和多项式模型" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": null, 343 | "metadata": { 344 | "collapsed": true 345 | }, 346 | "outputs": [], 347 | "source": [] 348 | } 349 | ], 350 | "metadata": { 351 | "kernelspec": { 352 | "display_name": "Python 3", 353 | "language": "python", 354 | "name": "python3" 355 | }, 356 | "language_info": { 357 | "codemirror_mode": { 358 | "name": "ipython", 359 | "version": 3 360 | }, 361 | "file_extension": ".py", 362 | "mimetype": "text/x-python", 363 | "name": "python", 364 | "nbconvert_exporter": "python", 365 | "pygments_lexer": "ipython3", 366 | "version": "3.6.1" 367 | } 368 | }, 369 | "nbformat": 4, 370 | "nbformat_minor": 2 371 | } 372 | -------------------------------------------------------------------------------- /Perceptron/README.md: -------------------------------------------------------------------------------- 1 | # Perceptron 感知机 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # statistical-learning-method 2 | 《统计学习方法》笔记-基于Python算法实现 3 | 4 | 5 | notebook中2-knn和3-naive_bayes是基于这个项目的解读: [lihang_book_algorithm](https://github.com/nicolas-chan/lihang_book_algorithm) 6 | 7 | 
4-decision_tree之后的笔记都是本项目的解读。 8 | 9 | 10 | 11 | 第一章 [最小二乘法](http://nbviewer.jupyter.org/github/BrambleXu/statistical-learning-method/blob/master/LeastSquaresMethod/least_sqaure_method.ipynb) 12 | 13 | 14 | 第二章 [感知机](http://nbviewer.jupyter.org/github/BrambleXu/statistical-learning-method/blob/master/Perceptron/Iris_perceptron.ipynb) 15 | 16 | 第三章 [k近邻法](http://nbviewer.jupyter.org/github/BrambleXu/statistical-learning-method/blob/master/KNearestNeighbors/KNN.ipynb),[note](http://nbviewer.jupyter.org/github/BrambleXu/statistical-learning-method/blob/master/notebooks/2-knn.ipynb) 17 | 18 | 第四章 [朴素贝叶斯](http://nbviewer.jupyter.org/github/BrambleXu/statistical-learning-method/blob/master/NaiveBayes/GaussianNB.ipynb),[note](http://nbviewer.jupyter.org/github/BrambleXu/statistical-learning-method/blob/master/notebooks/3-naive_bayes.ipynb) 19 | 20 | 第五章 [决策树](http://nbviewer.jupyter.org/github/BrambleXu/statistical-learning-method/blob/master/DecisonTree/DT.ipynb),[note](http://nbviewer.jupyter.org/github/BrambleXu/statistical-learning-method/blob/master/notebooks/4-decision_tree.ipynb) 21 | 22 | 第六章 [逻辑斯谛回归](http://nbviewer.jupyter.org/github/BrambleXu/statistical-learning-method/blob/master/LogisticRegression/LR.ipynb),[note](http://nbviewer.jupyter.org/github/BrambleXu/statistical-learning-method/blob/master/notebooks/5-logistic_regression.ipynb) 23 | 24 | 第七章 [支持向量机](http://nbviewer.jupyter.org/github/BrambleXu/statistical-learning-method/blob/master/SVM/support-vector-machine.ipynb),[note](http://nbviewer.jupyter.org/github/BrambleXu/statistical-learning-method/blob/master/notebooks/6-svm.ipynb) 25 | 26 | 第八章 [AdaBoost](http://nbviewer.jupyter.org/github/BrambleXu/statistical-learning-method/blob/master/AdaBoost/Adaboost.ipynb),[note](http://nbviewer.jupyter.org/github/BrambleXu/statistical-learning-method/blob/master/notebooks/7-adaboost.ipynb) 27 | 28 | 第九章 [EM算法](http://nbviewer.jupyter.org/github/BrambleXu/statistical-learning-method/blob/master/EM/em.ipynb) 29 | 
-------------------------------------------------------------------------------- /SVM/.ipynb_checkpoints/support-vector-machine-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 支持向量机" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "----\n", 15 | "分离超平面:$w^Tx+b=0$\n", 16 | "\n", 17 | "点到直线距离:$r=\\frac{|w^Tx+b|}{||w||_2}$\n", 18 | "\n", 19 | "$||w||_2$为2-范数:$||w||_2=\\sqrt[2]{\\sum^m_{i=1}w_i^2}$\n", 20 | "\n", 21 | "直线为超平面,样本可表示为:\n", 22 | "\n", 23 | "$w^Tx+b\\ \\geq+1$\n", 24 | "\n", 25 | "$w^Tx+b\\ \\leq-1$\n", 26 | "\n", 27 | "#### margin:\n", 28 | "\n", 29 | "**函数间隔**:$label(w^Tx+b)\\ or\\ y_i(w^Tx+b)$\n", 30 | "\n", 31 | "**几何间隔**:$r=\\frac{label(w^Tx+b)}{||w||_2}$,当数据被正确分类时,几何间隔就是点到超平面的距离\n", 32 | "\n", 33 | "为了求几何间隔最大,SVM基本问题可以转化为求解:($\\frac{r^*}{||w||}$为几何间隔,${r^*}$为函数间隔)\n", 34 | "\n", 35 | "$$\\max\\ \\frac{r^*}{||w||}$$\n", 36 | "\n", 37 | "$$(subject\\ to)\\ y_i({w^T}x_i+{b})\\geq {r^*},\\ i=1,2,..,m$$\n", 38 | "\n", 39 | "分类点几何间隔最大,同时被正确分类。但这个方程并非凸函数求解,所以要先①将方程转化为凸函数,②用拉格朗日乘子法和KKT条件求解对偶问题。\n", 40 | "\n", 41 | "①转化为凸函数:\n", 42 | "\n", 43 | "先令${r^*}=1$,方便计算(参照衡量,不影响评价结果)\n", 44 | "\n", 45 | "$$\\max\\ \\frac{1}{||w||}$$\n", 46 | "\n", 47 | "$$s.t.\\ y_i({w^T}x_i+{b})\\geq {1},\\ i=1,2,..,m$$\n", 48 | "\n", 49 | "再将$\\max\\ \\frac{1}{||w||}$转化成$\\min\\ \\frac{1}{2}||w||^2$求解凸函数,1/2是为了求导之后方便计算。\n", 50 | "\n", 51 | "$$\\min\\ \\frac{1}{2}||w||^2$$\n", 52 | "\n", 53 | "$$s.t.\\ y_i(w^Tx_i+b)\\geq 1,\\ i=1,2,..,m$$\n", 54 | "\n", 55 | "②用拉格朗日乘子法和KKT条件求解最优值:\n", 56 | "\n", 57 | "$$\\min\\ \\frac{1}{2}||w||^2$$\n", 58 | "\n", 59 | "$$s.t.\\ -y_i(w^Tx_i+b)+1\\leq 0,\\ i=1,2,..,m$$\n", 60 | "\n", 61 | "整合成:\n", 62 | "\n", 63 | "$$L(w, b, \\alpha) = \\frac{1}{2}||w||^2+\\sum^m_{i=1}\\alpha_i(-y_i(w^Tx_i+b)+1)$$\n", 64 | "\n", 65 | "推导:$\\min\\ f(x)=\\min \\max\\ L(w, b, 
\\alpha)\\geq \\max \\min\\ L(w, b, \\alpha)$\n", 66 | "\n", 67 | "根据KKT条件:\n", 68 | "\n", 69 | "$$\\frac{\\partial }{\\partial w}L(w, b, \\alpha)=w-\\sum\\alpha_iy_ix_i=0,\\ w=\\sum\\alpha_iy_ix_i$$\n", 70 | "\n", 71 | "$$\\frac{\\partial }{\\partial b}L(w, b, \\alpha)=\\sum\\alpha_iy_i=0$$\n", 72 | "\n", 73 | "带入$ L(w, b, \\alpha)$\n", 74 | "\n", 75 | "$\\min\\ L(w, b, \\alpha)=\\frac{1}{2}||w||^2+\\sum^m_{i=1}\\alpha_i(-y_i(w^Tx_i+b)+1)$\n", 76 | "\n", 77 | "$\\qquad\\qquad\\qquad=\\frac{1}{2}w^Tw-\\sum^m_{i=1}\\alpha_iy_iw^Tx_i-b\\sum^m_{i=1}\\alpha_iy_i+\\sum^m_{i=1}\\alpha_i$\n", 78 | "\n", 79 | "$\\qquad\\qquad\\qquad=\\frac{1}{2}w^T\\sum\\alpha_iy_ix_i-\\sum^m_{i=1}\\alpha_iy_iw^Tx_i+\\sum^m_{i=1}\\alpha_i$\n", 80 | "\n", 81 | "$\\qquad\\qquad\\qquad=\\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i=1}\\alpha_iy_iw^Tx_i$\n", 82 | "\n", 83 | "$\\qquad\\qquad\\qquad=\\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)$\n", 84 | "\n", 85 | "再把max问题转成min问题:\n", 86 | "\n", 87 | "$\\max\\ \\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)=\\min \\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)-\\sum^m_{i=1}\\alpha_i$\n", 88 | "\n", 89 | "$s.t.\\ \\sum^m_{i=1}\\alpha_iy_i=0,$\n", 90 | "\n", 91 | "$ \\alpha_i \\geq 0,i=1,2,...,m$\n", 92 | "\n", 93 | "以上为SVM对偶问题的对偶形式\n", 94 | "\n", 95 | "-----\n", 96 | "#### kernel\n", 97 | "\n", 98 | "在低维空间计算获得高维空间的计算结果,也就是说计算结果满足高维(满足高维,才能说明高维下线性可分)。\n", 99 | "\n", 100 | "#### soft margin & slack variable\n", 101 | "\n", 102 | "引入松弛变量$\\xi\\geq0$,对应数据点允许偏离的functional margin 的量。\n", 103 | "\n", 104 | "目标函数:$\\min\\ \\frac{1}{2}||w||^2+C\\sum\\xi_i\\qquad s.t.\\ y_i(w^Tx_i+b)\\geq1-\\xi_i$ \n", 105 | "\n", 106 | "对偶问题:\n", 107 | "\n", 108 | "$$\\max\\ \\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)=\\min \\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)-\\sum^m_{i=1}\\alpha_i$$\n", 109 | "\n", 110 | "$$s.t.\\ 
C\\geq\\alpha_i \\geq 0,i=1,2,...,m\\quad \\sum^m_{i=1}\\alpha_iy_i=0,$$\n", 111 | "\n", 112 | "-----\n", 113 | "\n", 114 | "#### Sequential Minimal Optimization\n", 115 | "\n", 116 | "首先定义特征到结果的输出函数:$u=w^Tx+b$.\n", 117 | "\n", 118 | "因为$w=\\sum\\alpha_iy_ix_i$\n", 119 | "\n", 120 | "有$u=\\sum y_i\\alpha_iK(x_i, x)-b$\n", 121 | "\n", 122 | "\n", 123 | "----\n", 124 | "\n", 125 | "$\\max \\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i=1}\\sum^m_{j=1}\\alpha_i\\alpha_jy_iy_j<\\phi(x_i)^T,\\phi(x_j)>$\n", 126 | "\n", 127 | "$s.t.\\ \\sum^m_{i=1}\\alpha_iy_i=0,$\n", 128 | "\n", 129 | "$ \\alpha_i \\geq 0,i=1,2,...,m$\n", 130 | "\n", 131 | "-----\n", 132 | "参考资料:\n", 133 | "\n", 134 | "[1] :[Lagrange Multiplier and KKT](http://blog.csdn.net/xianlingmao/article/details/7919597)\n", 135 | "\n", 136 | "[2] :[推导SVM](https://my.oschina.net/dfsj66011/blog/517766)\n", 137 | "\n", 138 | "[3] :[机器学习算法实践-支持向量机(SVM)算法原理](http://pytlab.org/2017/08/15/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E7%AE%97%E6%B3%95%E5%AE%9E%E8%B7%B5-%E6%94%AF%E6%8C%81%E5%90%91%E9%87%8F%E6%9C%BA-SVM-%E7%AE%97%E6%B3%95%E5%8E%9F%E7%90%86/)\n", 139 | "\n", 140 | "[4] :[Python实现SVM](http://blog.csdn.net/wds2006sdo/article/details/53156589)" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 1, 146 | "metadata": {}, 147 | "outputs": [ 148 | { 149 | "name": "stderr", 150 | "output_type": "stream", 151 | "text": [ 152 | "E:\\Anaconda3\\lib\\site-packages\\sklearn\\cross_validation.py:44: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. 
This module will be removed in 0.20.\n", 153 | " \"This module will be removed in 0.20.\", DeprecationWarning)\n" 154 | ] 155 | } 156 | ], 157 | "source": [ 158 | "import numpy as np\n", 159 | "import pandas as pd\n", 160 | "from sklearn.datasets import load_iris\n", 161 | "from sklearn.cross_validation import train_test_split\n", 162 | "import matplotlib.pyplot as plt\n", 163 | "%matplotlib inline" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 2, 169 | "metadata": { 170 | "collapsed": true 171 | }, 172 | "outputs": [], 173 | "source": [ 174 | "# data\n", 175 | "def create_data():\n", 176 | " iris = load_iris()\n", 177 | " df = pd.DataFrame(iris.data, columns=iris.feature_names)\n", 178 | " df['label'] = iris.target\n", 179 | " df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n", 180 | " data = np.array(df.iloc[:100, [0, 1, -1]])\n", 181 | " for i in range(len(data)):\n", 182 | " if data[i,-1] == 0:\n", 183 | " data[i,-1] = -1\n", 184 | " # print(data)\n", 185 | " return data[:,:2], data[:,-1]" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 9, 191 | "metadata": { 192 | "collapsed": true 193 | }, 194 | "outputs": [], 195 | "source": [ 196 | "X, y = create_data()\n", 197 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 10, 203 | "metadata": {}, 204 | "outputs": [ 205 | { 206 | "data": { 207 | "text/plain": [ 208 | "" 209 | ] 210 | }, 211 | "execution_count": 10, 212 | "metadata": {}, 213 | "output_type": "execute_result" 214 | }, 215 | { 216 | "data": { 217 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGZ9JREFUeJzt3X9sHOWdx/H394yv8bWAReMWsJNLCihqSSLSugTICXGg\nXkqaQoRQlAiKQhE5ELpS0aNqKtQfqBJISLRQdEQBdBTBBeVoGihHgjgoKkUklZMg5y4pKhxtY8MV\nE5TQHKYE93t/7Dqx12vvzu6O93me/bwky97Zyfj7zMA3m5nPPGPujoiIpOWvml2AiIg0npq7iEiC\n1NxFRBKk5i4ikiA1dxGRBKm5i4gkSM1dRCRBau4iIglScxcRSdBx1a5oZm1AHzDo7stL3rsAeBx4\nvbhos7vfOtX2Zs6c6XPmzMlUrIhIq9u5c+fb7t5Vab2qmztwI7APOGGS918obfpTmTNnDn19fRl+\nvYiImNnvq1mvqtMyZtYDfAm4v56iRERkelR7zv1HwDeBv0yxznlm1m9mW83szHIrmNlaM+szs76h\noaGstYqISJUqNnczWw685e47p1htFzDb3RcCPwa2lFvJ3Te4e6+793Z1VTxlJCIiNarmnPsS4BIz\nWwbMAE4ws4fd/crRFdz93TE/P2Vm/2JmM9397caXLCJSnyNHjjAwMMD777/f7FImNWPGDHp6emhv\nb6/pz1ds7u6+DlgHR1Mx/zy2sReXnwz80d3dzM6m8C+CAzVVJCKSs4GBAY4//njmzJmDmTW7nAnc\nnQMHDjAwMMDcuXNr2kaWtMw4ZnZdsYj1wOXA9Wb2ITAMrHI9BUREAvX+++8H29gBzIyPf/zj1HNt\nMlNzd/fngeeLP68fs/we4J6aqxAJ2Jbdg9zx9Cu8cXCYUzs7uHnpPFYs6m52WVKnUBv7qHrrq/mT\nu0gr2LJ7kHWb9zB8ZASAwYPDrNu8B0ANXoKm6QdEpnDH068cbeyjho+McMfTrzSpIknFtm3bmDdv\nHqeffjq33357w7ev5i4yhTcODmdaLlKNkZERbrjhBrZu3crevXvZuHEje/fubejv0GkZkSmc2tnB\nYJlGfmpnRxOqkWZp9HWXX//615x++ul86lOfAmDVqlU8/vjjfOYzn2lUyfrkLjKVm5fOo6O9bdyy\njvY2bl46r0kVyXQbve4yeHAY59h1ly27B2ve5uDgILNmzTr6uqenh8HB2rdXjpq7yBRWLOrmtssW\n0N3ZgQHdnR3cdtkCXUxtIbFed9FpGZEKVizqVjNvYXlcd+nu7mb//v1HXw8MDNDd3dj/xvTJXURk\nCpNdX6nnusvnP/95fvvb3/L666/zwQcf8Oijj3LJJZfUvL1y1NxFRKaQx3WX4447jnvuuYelS5fy\n6U9/mpUrV3LmmWUn0639dzR0ayIiiRk9Jdfou5SXLVvGsmXLGlFiWWruIiIVxHjdRadlREQSpOYu\nIpIgNXcRkQSpuYuIJEjNXUQkQWrukowtuwdZcvtzzP3Wf7Dk9ufqmvtDJG9f/epX+cQnPsH8+fNz\n2b6auyQhj8mdRPK0Zs0atm3bltv21dwlCbFO7iSR6N8EP5wP3+ssfO/fVPcmzz//fE466aQGFFee\nbmKSJOihGpKb/k3w86/BkeJ/S4f2F14DLFzZvLoq0Cd3SUIekzuJAPDsrcca+6gjw4XlAVNzlyTo\noRqSm0MD2ZYHQqdlJAl5Te4kwok9hVMx5ZYHTM1dkhHj5E4SgYu+M/6cO0B7R2F5HVavXs3zzz/P\n22+/TU9PD9///ve55ppr6iz2GDV3qVujHx4sEpTRi6bP3lo4FXNiT6Gx13kxdePGjQ0obnJq7lKX\n0Xz5aAxxNF8OqMFLOhauDDoZU44uqEpdlC8XCZOau9RF+XKJlbs3u4Qp1VufmrvURflyidGMGTM4\ncOBAsA3e3Tlw4AAzZsyoeRs65y51uXnpvHHn3EH5cglfT08PA
wMDDA0NNbuUSc2YMYOentrjlmru\nUhflyyVG7e3tzJ07t9ll5Krq5m5mbUAfMOjuy0veM+AuYBnwHrDG3Xc1slAJl/LlIuHJ8sn9RmAf\ncEKZ9y4Gzih+LQbuLX4XaSnK/EsoqrqgamY9wJeA+ydZ5VLgIS/YDnSa2SkNqlEkCppTXkJSbVrm\nR8A3gb9M8n43MHbyhYHiMpGWocy/hKRiczez5cBb7r6z3l9mZmvNrM/M+kK+Si1SC2X+JSTVfHJf\nAlxiZr8DHgUuNLOHS9YZBGaNed1TXDaOu29w91537+3q6qqxZJEwKfMvIanY3N19nbv3uPscYBXw\nnLtfWbLaE8BVVnAOcMjd32x8uSLh0pzyEpKac+5mdh2Au68HnqIQg3yVQhTy6oZUJxIRZf4lJNas\n2297e3u9r6+vKb9bRCRWZrbT3Xsrrac7VCVYt2zZw8Yd+xlxp82M1Ytn8YMVC5pdlkgU1NwlSLds\n2cPD2/9w9PWI+9HXavAilWlWSAnSxh1lnlk5xXIRGU/NXYI0Msm1oMmWi8h4au4SpDazTMtFZDw1\ndwnS6sWzMi0XkfF0QVWCNHrRVGkZkdoo5y4iEhHl3KUuV9z3Ei++9s7R10tOO4lHrj23iRU1j+Zo\nlxjpnLtMUNrYAV587R2uuO+lJlXUPJqjXWKl5i4TlDb2SstTpjnaJVZq7iJT0BztEis1d5EpaI52\niZWau0yw5LSTMi1PmeZol1ipucsEj1x77oRG3qppmRWLurntsgV0d3ZgQHdnB7ddtkBpGQmecu4i\nIhFRzl3qkle2O8t2lS8XqZ2au0wwmu0ejQCOZruBupprlu3mVYNIq9A5d5kgr2x3lu0qXy5SHzV3\nmSCvbHeW7SpfLlIfNXeZIK9sd5btKl8uUh81d5kgr2x3lu0qXy5SH11QlQlGL1g2OqmSZbt51SDS\nKpRzFxGJiHLuOYsxgx1jzSJSGzX3GsSYwY6xZhGpnS6o1iDGDHaMNYtI7dTcaxBjBjvGmkWkdmru\nNYgxgx1jzSJSOzX3GsSYwY6xZhGpnS6o1iDGDHaMNYtI7Srm3M1sBvBL4CMU/jJ4zN2/W7LOBcDj\nwOvFRZvd/daptqucu4hIdo3Muf8ZuNDdD5tZO/ArM9vq7ttL1nvB3ZfXUqxMj1u27GHjjv2MuNNm\nxurFs/jBigV1rxtKfj6UOkRCULG5e+Gj/eHiy/biV3Nua5Wa3bJlDw9v/8PR1yPuR1+XNu0s64aS\nnw+lDpFQVHVB1czazOxl4C3gGXffUWa188ys38y2mtmZDa1S6rZxx/6ql2dZN5T8fCh1iISiqubu\n7iPufhbQA5xtZvNLVtkFzHb3hcCPgS3ltmNma82sz8z6hoaG6qlbMhqZ5NpKueVZ1g0lPx9KHSKh\nyBSFdPeDwC+AL5Ysf9fdDxd/fgpoN7OZZf78Bnfvdfferq6uOsqWrNrMql6eZd1Q8vOh1CESiorN\n3cy6zKyz+HMH8AXgNyXrnGxW+D/fzM4ubvdA48uVWq1ePKvq5VnWDSU/H0odIqGoJi1zCvATM2uj\n0LQ3ufuTZnYdgLuvBy4HrjezD4FhYJU3ay5hKWv0Qmg1CZgs64aSnw+lDpFQaD53EZGIaD73nOWV\nqc6SL89z21nGF+O+iE7/Jnj2Vjg0ACf2wEXfgYUrm12VBEzNvQZ5Zaqz5Mvz3HaW8cW4L6LTvwl+\n/jU4Ukz+HNpfeA1q8DIpTRxWg7wy1Vny5XluO8v4YtwX0Xn21mONfdSR4cJykUmoudcgr0x1lnx5\nntvOMr4Y90V0Dg1kWy6CmntN8spUZ8mX57ntLOOLcV9E58SebMtFUHOvSV6Z6iz58jy3nWV8Me6L\n6Fz0HWgv+cuyvaOwXGQSuqBag7wy1Vny5XluO8v4YtwX0Rm9aKq0jGSgnLuISESUc5cJQsiuS+SU\nt4+GmnuLCCG7LpFT3j4qu
qDaIkLIrkvklLePipp7iwghuy6RU94+KmruLSKE7LpETnn7qKi5t4gQ\nsusSOeXto6ILqi0ihOy6RE55+6go5y4iEhHl3Ivyymtn2W4o85Irux6Y1DPjqY8viybsi6Sbe155\n7SzbDWVecmXXA5N6Zjz18WXRpH2R9AXVvPLaWbYbyrzkyq4HJvXMeOrjy6JJ+yLp5p5XXjvLdkOZ\nl1zZ9cCknhlPfXxZNGlfJN3c88prZ9luKPOSK7semNQz46mPL4sm7Yukm3teee0s2w1lXnJl1wOT\nemY89fFl0aR9kfQF1bzy2lm2G8q85MquByb1zHjq48uiSftCOXcRkYgo556zEPLzV9z3Ei++9s7R\n10tOO4lHrj237hpEkvLkTbDzQfARsDb43BpYfmf92w08x5/0Ofe8jGbGBw8O4xzLjG/ZPTht2y1t\n7AAvvvYOV9z3Ul01iCTlyZug74FCY4fC974HCsvrMZpdP7Qf8GPZ9f5NdZfcKGruNQghP1/a2Cst\nF2lJOx/MtrxaEeT41dxrEEJ+XkSq4CPZllcrghy/mnsNQsjPi0gVrC3b8mpFkONXc69BCPn5Jaed\nVHYbky0XaUmfW5NtebUiyPGruddgxaJubrtsAd2dHRjQ3dnBbZctaEh+vtrtPnLtuRMaudIyIiWW\n3wm91xz7pG5thdf1pmUWroQv3w0nzgKs8P3LdweVllHOXUQkIg3LuZvZDOCXwEeK6z/m7t8tWceA\nu4BlwHvAGnffVUvhlWTNl8c2h3mWud9T3xe55oizZJ/zqiPP8QWewa5L1rGlvC+mUM1NTH8GLnT3\nw2bWDvzKzLa6+/Yx61wMnFH8WgzcW/zeUFnnJI9tDvMsc7+nvi9ynQN7NPs8ajT7DBMbfF515Dm+\nlOdSzzq2lPdFBRXPuXvB4eLL9uJX6bmcS4GHiutuBzrN7JTGlpo9Xx7bHOZZ5n5PfV/kmiPOkn3O\nq448xxdBBrtmWceW8r6ooKoLqmbWZmYvA28Bz7j7jpJVuoGxHWiguKx0O2vNrM/M+oaGhjIXmzUH\nHltuPMvc76nvi1xzxFmyz3nVkef4Ishg1yzr2FLeFxVU1dzdfcTdzwJ6gLPNbH4tv8zdN7h7r7v3\ndnV1Zf7zWXPgseXGs8z9nvq+yDVHnCX7nFcdeY4vggx2zbKOLeV9UUGmKKS7HwR+AXyx5K1BYOwE\n5T3FZQ2VNV8e2xzmWeZ+T31f5JojzpJ9zquOPMcXQQa7ZlnHlvK+qKBiczezLjPrLP7cAXwB+E3J\nak8AV1nBOcAhd3+z0cVmzZfnlUfPyw9WLODKc2Yf/aTeZsaV58wum5ZJfV/kmiPOkn3Oq448xxdB\nBrtmWceW8r6ooGLO3cwWAj8B2ij8ZbDJ3W81s+sA3H19MQp5D4VP9O8BV7v7lCF25dxFRLJrWM7d\n3fuBRWWWrx/zswM3ZC1SRETykfzDOqK7cUemR5YbW0K4CSbPG3diu0krhOMRgaSbe3Q37sj0yHJj\nSwg3weR5405sN2mFcDwikfTEYdHduCPTI8uNLSHcBJPnjTux3aQVwvGIRNLNPbobd2R6ZLmxJYSb\nYPK8cSe2m7RCOB6RSLq5R3fjjkyPLDe2hHATTJ437sR2k1YIxyMSSTf36G7ckemR5caWEG6CyfPG\nndhu0grheEQi6eYe3Y07Mj2y3NgSwk0wed64E9tNWiEcj0joYR0iIhFp2E1MIi0vy4M9QhFbzaFk\n10OpowHU3EWmkuXBHqGIreZQsuuh1NEgSZ9zF6lblgd7hCK2mkPJrodSR4OouYtMJcuDPUIRW82h\nZNdDqaNB1NxFppLlwR6hiK3mULLrodTRIGruIlPJ8mCPUMRWcyjZ9VDqaBA1d5GpZHmwRyhiqzmU\n7HoodTSIcu4iIhFRzl2mT4zZ4LxqzitfHuM+lqZSc5f6xJgNzqvmvPLlMe5jaTqdc5f6xJg
Nzqvm\nvPLlMe5jaTo1d6lPjNngvGrOK18e4z6WplNzl/rEmA3Oq+a88uUx7mNpOjV3qU+M2eC8as4rXx7j\nPpamU3OX+sSYDc6r5rzy5THuY2k65dxFRCJSbc5dn9wlHf2b4Ifz4Xudhe/9m6Z/u3nVIJKRcu6S\nhryy4Fm2qzy6BESf3CUNeWXBs2xXeXQJiJq7pCGvLHiW7SqPLgFRc5c05JUFz7Jd5dElIGrukoa8\nsuBZtqs8ugREzV3SkFcWPMt2lUeXgFTMuZvZLOAh4JOAAxvc/a6SdS4AHgdeLy7a7O5TXkVSzl1E\nJLtGzuf+IfANd99lZscDO83sGXffW7LeC+6+vJZiJUAxzh+epeYYxxcC7bdoVGzu7v4m8Gbx5z+Z\n2T6gGyht7pKKGPPayqPnT/stKpnOuZvZHGARsKPM2+eZWb+ZbTWzMxtQmzRLjHlt5dHzp/0Wlarv\nUDWzjwE/Bb7u7u+WvL0LmO3uh81sGbAFOKPMNtYCawFmz55dc9GSsxjz2sqj50/7LSpVfXI3s3YK\njf0Rd99c+r67v+vuh4s/PwW0m9nMMuttcPded+/t6uqqs3TJTYx5beXR86f9FpWKzd3MDHgA2Ofu\nZecuNbOTi+thZmcXt3ugkYXKNIoxr608ev6036JSzWmZJcBXgD1m9nJx2beB2QDuvh64HLjezD4E\nhoFV3qy5hKV+oxfHYkpFZKk5xvGFQPstKprPXUQkIo3MuUuolDke78mbYOeDhQdSW1vh8Xb1PgVJ\nJFJq7rFS5ni8J2+CvgeOvfaRY6/V4KUFaW6ZWClzPN7OB7MtF0mcmnuslDkez0eyLRdJnJp7rJQ5\nHs/asi0XSZyae6yUOR7vc2uyLRdJnJp7rDR3+HjL74Tea459Ure2wmtdTJUWpZy7iEhElHOvwZbd\ng9zx9Cu8cXCYUzs7uHnpPFYs6m52WY2Tei4+9fGFQPs4GmruRVt2D7Ju8x6GjxTSFYMHh1m3eQ9A\nGg0+9Vx86uMLgfZxVHTOveiOp1852thHDR8Z4Y6nX2lSRQ2Wei4+9fGFQPs4KmruRW8cHM60PDqp\n5+JTH18ItI+jouZedGpnR6bl0Uk9F5/6+EKgfRwVNfeim5fOo6N9/A0vHe1t3Lx0XpMqarDUc/Gp\njy8E2sdR0QXVotGLpsmmZVKfizv18YVA+zgqyrmLiESk2py7TsuIxKB/E/xwPnyvs/C9f1Mc25am\n0WkZkdDlmS9Xdj1Z+uQuEro88+XKridLzV0kdHnmy5VdT5aau0jo8syXK7ueLDV3kdDlmS9Xdj1Z\nau4ioctz7n49FyBZyrmLiEREOXcRkRam5i4ikiA1dxGRBKm5i4gkSM1dRCRBau4iIglScxcRSZCa\nu4hIgio2dzObZWa/MLO9ZvbfZnZjmXXMzO42s1fNrN/MPptPuVIXzdst0jKqmc/9Q+Ab7r7LzI4H\ndprZM+6+d8w6FwNnFL8WA/cWv0soNG+3SEup+Mnd3d90913Fn/8E7ANKHyx6KfCQF2wHOs3slIZX\nK7XTvN0iLSXTOXczmwMsAnaUvNUN7B/zeoCJfwFgZmvNrM/M+oaGhrJVKvXRvN0iLaXq5m5mHwN+\nCnzd3d+t5Ze5+wZ373X33q6urlo2IbXSvN0iLaWq5m5m7RQa+yPuvrnMKoPArDGve4rLJBSat1uk\npVSTljHgAWCfu985yWpPAFcVUzPnAIfc/c0G1in10rzdIi2lmrTMEuArwB4ze7m47NvAbAB3Xw88\nBSwDXgXeA65ufKlSt4Ur1cxFWkTF5u7uvwKswjoO3NCookREpD66Q1VEJEFq7iIiCVJzFxFJkJq7\niEiC1NxFRBKk5i4ikiA1dxGRBFkhot6EX2w2BPy+Kb+8spnA280uIkcaX7xSHhtofNX4W3evODlX\n05p7yMysz917m11HXjS+eKU8NtD4GkmnZUREEqTmLiK
SIDX38jY0u4CcaXzxSnlsoPE1jM65i4gk\nSJ/cRUQS1NLN3czazGy3mT1Z5r0LzOyQmb1c/IrqkUVm9jsz21Osva/M+2Zmd5vZq2bWb2afbUad\ntapifLEfv04ze8zMfmNm+8zs3JL3Yz9+lcYX7fEzs3lj6n7ZzN41s6+XrJP78avmYR0puxHYB5ww\nyfsvuPvyaayn0f7e3SfL1F4MnFH8WgzcW/wek6nGB3Efv7uAbe5+uZn9NfA3Je/HfvwqjQ8iPX7u\n/gpwFhQ+QFJ45OjPSlbL/fi17Cd3M+sBvgTc3+xamuRS4CEv2A50mtkpzS5KwMxOBM6n8HhL3P0D\ndz9Yslq0x6/K8aXiIuA1dy+9YTP349eyzR34EfBN4C9TrHNe8Z9MW83szGmqq1Ec+E8z22lma8u8\n3w3sH/N6oLgsFpXGB/Eev7nAEPCvxdOG95vZR0vWifn4VTM+iPf4jbUK2Fhmee7HryWbu5ktB95y\n951TrLYLmO3uC4EfA1umpbjG+Tt3P4vCP/9uMLPzm11Qg1UaX8zH7zjgs8C97r4I+D/gW80tqaGq\nGV/Mxw+A4ummS4B/b8bvb8nmTuGh35eY2e+AR4ELzezhsSu4+7vufrj481NAu5nNnPZKa+Tug8Xv\nb1E433d2ySqDwKwxr3uKy6JQaXyRH78BYMDddxRfP0ahGY4V8/GrOL7Ij9+oi4Fd7v7HMu/lfvxa\nsrm7+zp373H3ORT+2fScu185dh0zO9nMrPjz2RT21YFpL7YGZvZRMzt+9GfgH4D/KlntCeCq4lX7\nc4BD7v7mNJdak2rGF/Pxc/f/Bfab2bzioouAvSWrRXv8qhlfzMdvjNWUPyUD03D8Wj0tM46ZXQfg\n7uuBy4HrzexDYBhY5fHc8fVJ4GfF/zeOA/7N3beVjO8pYBnwKvAecHWTaq1FNeOL+fgB/BPwSPGf\n9v8DXJ3Q8YPK44v6+BU/dHwB+Mcxy6b1+OkOVRGRBLXkaRkRkdSpuYuIJEjNXUQkQWruIiIJUnMX\nEUmQmruISILU3EVEEqTmLiKSoP8H2fNC9uxjMHwAAAAASUVORK5CYII=\n", 218 | "text/plain": [ 219 | "" 220 | ] 221 | }, 222 | "metadata": {}, 223 | "output_type": "display_data" 224 | } 225 | ], 226 | "source": [ 227 | "plt.scatter(X[:50,0],X[:50,1], label='0')\n", 228 | "plt.scatter(X[50:,0],X[50:,1], label='1')\n", 229 | "plt.legend()" 230 | ] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "metadata": {}, 235 | "source": [ 236 | "----\n", 237 | "\n" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 11, 243 | "metadata": { 244 | "collapsed": true 245 | }, 246 | "outputs": [], 247 | "source": [ 248 | "class SVM:\n", 249 | " def __init__(self, max_iter=100, kernel='linear'):\n", 250 | " self.max_iter = max_iter\n", 251 | " self._kernel = kernel\n", 252 | " \n", 253 | " def init_args(self, features, labels):\n", 254 | " self.m, self.n = features.shape\n", 255 | " self.X = features\n", 256 | " self.Y = labels\n", 257 
| "        self.b = 0.0\n", 258 |     "        \n", 259 |     "        # 将Ei保存在一个列表里\n", 260 |     "        self.alpha = np.ones(self.m)\n", 261 |     "        self.E = [self._E(i) for i in range(self.m)]\n", 262 |     "        # 松弛变量\n", 263 |     "        self.C = 1.0\n", 264 |     "        \n", 265 |     "    def _KKT(self, i):\n", 266 |     "        y_g = self._g(i)*self.Y[i]\n", 267 |     "        if self.alpha[i] == 0:\n", 268 |     "            return y_g >= 1\n", 269 |     "        elif 0 < self.alpha[i] < self.C:\n", 270 |     "            return y_g == 1\n", 271 |     "        else:\n", 272 |     "            return y_g <= 1\n", 273 |     "        \n", 274 |     "    # g(x)预测值,输入xi(X[i])\n", 275 |     "    def _g(self, i):\n", 276 |     "        r = self.b\n", 277 |     "        for j in range(self.m):\n", 278 |     "            r += self.alpha[j]*self.Y[j]*self.kernel(self.X[i], self.X[j])\n", 279 |     "        return r\n", 280 |     "        \n", 281 |     "    # 核函数\n", 282 |     "    def kernel(self, x1, x2):\n", 283 |     "        if self._kernel == 'linear':\n", 284 |     "            return sum([x1[k]*x2[k] for k in range(self.n)])\n", 285 |     "        elif self._kernel == 'poly':\n", 286 |     "            return (sum([x1[k]*x2[k] for k in range(self.n)]) + 1)**2\n", 287 |     "    \n", 288 |     "        return 0\n", 289 |     "    \n", 290 |     "    # E(x)为g(x)对输入x的预测值和y的差\n", 291 |     "    def _E(self, i):\n", 292 |     "        return self._g(i) - self.Y[i]\n", 293 |     "    \n", 294 |     "    def _init_alpha(self):\n", 295 |     "        # 外层循环首先遍历所有满足0<alpha<C的样本点,检验是否满足KKT\n", 296 |     "        index_list = [i for i in range(self.m) if 0 < self.alpha[i] < self.C]\n", 297 |     "        # 否则遍历整个训练集\n", 298 |     "        non_satisfy_list = [i for i in range(self.m) if i not in index_list]\n", 299 |     "        index_list.extend(non_satisfy_list)\n", 300 |     "        \n", 301 |     "        for i in index_list:\n", 302 |     "            if self._KKT(i):\n", 303 |     "                continue\n", 304 |     "            \n", 305 |     "            E1 = self.E[i]\n", 306 |     "            # 如果E1是+,选择最小的;如果E1是负的,选择最大的\n", 307 |     "            if E1 >= 0:\n", 308 |     "                j = min(range(self.m), key=lambda x: self.E[x])\n", 309 |     "            else:\n", 310 |     "                j = max(range(self.m), key=lambda x: self.E[x])\n", 311 |     "            return i, j\n", 312 |     "    \n", 313 |     "    def _compare(self, _alpha, L, H):\n", 314 |     "        if _alpha > H:\n", 315 |     "            return H\n", 316 |     "        elif _alpha < L:\n", 317 |     "            return L\n", 318 |     "        else:\n", 319 |     "            return _alpha \n", 320 |     "    \n", 321 |     "    def fit(self, features, labels):\n", 322 |     "        self.init_args(features, labels)\n", 323 |     "        \n", 324 |     "        for t in range(self.max_iter):\n", 325 |     "            # train\n", 326 |     "            i1, i2 = self._init_alpha()\n", 327 |     "            \n", 328 |     "            # 边界\n", 329 |     "            if self.Y[i1] == self.Y[i2]:\n", 330 |     "                L = max(0, self.alpha[i1]+self.alpha[i2]-self.C)\n", 331 |     "                H = min(self.C, self.alpha[i1]+self.alpha[i2])\n", 332 |     "            
else:\n", 333 | " L = max(0, self.alpha[i2]-self.alpha[i1])\n", 334 | " H = min(self.C, self.C+self.alpha[i2]-self.alpha[i1])\n", 335 | " \n", 336 | " E1 = self.E[i1]\n", 337 | " E2 = self.E[i2]\n", 338 | " # eta=K11+K22-2K12\n", 339 | " eta = self.kernel(self.X[i1], self.X[i1]) + self.kernel(self.X[i2], self.X[i2]) - 2*self.kernel(self.X[i1], self.X[i2])\n", 340 | " if eta <= 0:\n", 341 | " # print('eta <= 0')\n", 342 | " continue\n", 343 | " \n", 344 | " alpha2_new_unc = self.alpha[i2] + self.Y[i2] * (E2 - E1) / eta\n", 345 | " alpha2_new = self._compare(alpha2_new_unc, L, H)\n", 346 | " \n", 347 | " alpha1_new = self.alpha[i1] + self.Y[i1] * self.Y[i2] * (self.alpha[i2] - alpha2_new)\n", 348 | " \n", 349 | " b1_new = -E1 - self.Y[i1] * self.kernel(self.X[i1], self.X[i1]) * (alpha1_new-self.alpha[i1]) - self.Y[i2] * self.kernel(self.X[i2], self.X[i1]) * (alpha2_new-self.alpha[i2])+ self.b \n", 350 | " b2_new = -E2 - self.Y[i1] * self.kernel(self.X[i1], self.X[i2]) * (alpha1_new-self.alpha[i1]) - self.Y[i2] * self.kernel(self.X[i2], self.X[i2]) * (alpha2_new-self.alpha[i2])+ self.b \n", 351 | " \n", 352 | " if 0 < alpha1_new < self.C:\n", 353 | " b_new = b1_new\n", 354 | " elif 0 < alpha2_new < self.C:\n", 355 | " b_new = b2_new\n", 356 | " else:\n", 357 | " # 选择中点\n", 358 | " b_new = (b1_new + b2_new) / 2\n", 359 | " \n", 360 | " # 更新参数\n", 361 | " self.alpha[i1] = alpha1_new\n", 362 | " self.alpha[i2] = alpha2_new\n", 363 | " self.b = b_new\n", 364 | " \n", 365 | " self.E[i1] = self._E(i1)\n", 366 | " self.E[i2] = self._E(i2)\n", 367 | " return 'train done!'\n", 368 | " \n", 369 | " def predict(self, data):\n", 370 | " r = self.b\n", 371 | " for i in range(self.m):\n", 372 | " r += self.alpha[i] * self.Y[i] * self.kernel(data, self.X[i])\n", 373 | " \n", 374 | " return 1 if r > 0 else -1\n", 375 | " \n", 376 | " def score(self, X_test, y_test):\n", 377 | " right_count = 0\n", 378 | " for i in range(len(X_test)):\n", 379 | " result = self.predict(X_test[i])\n", 
380 | " if result == y_test[i]:\n", 381 | " right_count += 1\n", 382 | " return right_count / len(X_test)\n", 383 | " \n", 384 | " def _weight(self):\n", 385 | " # linear model\n", 386 | " yx = self.Y.reshape(-1, 1)*self.X\n", 387 | " self.w = np.dot(yx.T, self.alpha)\n", 388 | " return self.w" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": 12, 394 | "metadata": { 395 | "collapsed": true 396 | }, 397 | "outputs": [], 398 | "source": [ 399 | "svm = SVM(max_iter=200)" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": 13, 405 | "metadata": {}, 406 | "outputs": [ 407 | { 408 | "data": { 409 | "text/plain": [ 410 | "'train done!'" 411 | ] 412 | }, 413 | "execution_count": 13, 414 | "metadata": {}, 415 | "output_type": "execute_result" 416 | } 417 | ], 418 | "source": [ 419 | "svm.fit(X_train, y_train)" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": 14, 425 | "metadata": {}, 426 | "outputs": [ 427 | { 428 | "data": { 429 | "text/plain": [ 430 | "0.96" 431 | ] 432 | }, 433 | "execution_count": 14, 434 | "metadata": {}, 435 | "output_type": "execute_result" 436 | } 437 | ], 438 | "source": [ 439 | "svm.score(X_test, y_test)" 440 | ] 441 | }, 442 | { 443 | "cell_type": "markdown", 444 | "metadata": {}, 445 | "source": [ 446 | "## sklearn.svm.SVC" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": 15, 452 | "metadata": {}, 453 | "outputs": [ 454 | { 455 | "data": { 456 | "text/plain": [ 457 | "SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n", 458 | " decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',\n", 459 | " max_iter=-1, probability=False, random_state=None, shrinking=True,\n", 460 | " tol=0.001, verbose=False)" 461 | ] 462 | }, 463 | "execution_count": 15, 464 | "metadata": {}, 465 | "output_type": "execute_result" 466 | } 467 | ], 468 | "source": [ 469 | "from sklearn.svm import SVC\n", 470 | "clf = SVC()\n", 471 | 
"clf.fit(X_train, y_train)" 472 | ] 473 | }, 474 | { 475 | "cell_type": "code", 476 | "execution_count": 16, 477 | "metadata": {}, 478 | "outputs": [ 479 | { 480 | "data": { 481 | "text/plain": [ 482 | "1.0" 483 | ] 484 | }, 485 | "execution_count": 16, 486 | "metadata": {}, 487 | "output_type": "execute_result" 488 | } 489 | ], 490 | "source": [ 491 | "clf.score(X_test, y_test)" 492 | ] 493 | }, 494 | { 495 | "cell_type": "markdown", 496 | "metadata": {}, 497 | "source": [ 498 | "### sklearn.svm.SVC\n", 499 | "\n", 500 | "*(C=1.0, kernel='rbf', degree=3, gamma='auto', coef0=0.0, shrinking=True, probability=False,tol=0.001, cache_size=200, class_weight=None, verbose=False, max_iter=-1, decision_function_shape=None,random_state=None)*\n", 501 | "\n", 502 | "参数:\n", 503 | "\n", 504 | "- C:C-SVC的惩罚参数C?默认值是1.0\n", 505 | "\n", 506 | "C越大,相当于惩罚松弛变量,希望松弛变量接近0,即对误分类的惩罚增大,趋向于对训练集全分对的情况,这样对训练集测试时准确率很高,但泛化能力弱。C值小,对误分类的惩罚减小,允许容错,将他们当成噪声点,泛化能力较强。\n", 507 | "\n", 508 | "- kernel :核函数,默认是rbf,可以是‘linear’, ‘poly’, ‘rbf’, ‘sigmoid’, ‘precomputed’ \n", 509 | " \n", 510 | " – 线性:u'v\n", 511 | " \n", 512 | " – 多项式:(gamma*u'*v + coef0)^degree\n", 513 | "\n", 514 | " – RBF函数:exp(-gamma|u-v|^2)\n", 515 | "\n", 516 | " – sigmoid:tanh(gamma*u'*v + coef0)\n", 517 | "\n", 518 | "\n", 519 | "- degree :多项式poly函数的维度,默认是3,选择其他核函数时会被忽略。\n", 520 | "\n", 521 | "\n", 522 | "- gamma : ‘rbf’,‘poly’ 和‘sigmoid’的核函数参数。默认是’auto’,则会选择1/n_features\n", 523 | "\n", 524 | "\n", 525 | "- coef0 :核函数的常数项。对于‘poly’和 ‘sigmoid’有用。\n", 526 | "\n", 527 | "\n", 528 | "- probability :是否采用概率估计?.默认为False\n", 529 | "\n", 530 | "\n", 531 | "- shrinking :是否采用shrinking heuristic方法,默认为true\n", 532 | "\n", 533 | "\n", 534 | "- tol :停止训练的误差值大小,默认为1e-3\n", 535 | "\n", 536 | "\n", 537 | "- cache_size :核函数cache缓存大小,默认为200\n", 538 | "\n", 539 | "\n", 540 | "- class_weight :类别的权重,字典形式传递。设置第几类的参数C为weight*C(C-SVC中的C)\n", 541 | "\n", 542 | "\n", 543 | "- verbose :允许冗余输出?\n", 544 | "\n", 545 | "\n", 546 | "- max_iter :最大迭代次数。-1为无限制。\n", 
547 | "\n", 548 | "\n", 549 | "- decision_function_shape :‘ovo’, ‘ovr’ or None, default=None3\n", 550 | "\n", 551 | "\n", 552 | "- random_state :数据洗牌时的种子值,int值\n", 553 | "\n", 554 | "\n", 555 | "主要调节的参数有:C、kernel、degree、gamma、coef0。" 556 | ] 557 | }, 558 | { 559 | "cell_type": "code", 560 | "execution_count": null, 561 | "metadata": { 562 | "collapsed": true 563 | }, 564 | "outputs": [], 565 | "source": [] 566 | } 567 | ], 568 | "metadata": { 569 | "kernelspec": { 570 | "display_name": "Python 3", 571 | "language": "python", 572 | "name": "python3" 573 | }, 574 | "language_info": { 575 | "codemirror_mode": { 576 | "name": "ipython", 577 | "version": 3 578 | }, 579 | "file_extension": ".py", 580 | "mimetype": "text/x-python", 581 | "name": "python", 582 | "nbconvert_exporter": "python", 583 | "pygments_lexer": "ipython3", 584 | "version": "3.6.1" 585 | } 586 | }, 587 | "nbformat": 4, 588 | "nbformat_minor": 2 589 | } 590 | -------------------------------------------------------------------------------- /SVM/support-vector-machine.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 支持向量机" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "----\n", 15 | "分离超平面:$w^Tx+b=0$\n", 16 | "\n", 17 | "点到直线距离:$r=\\frac{|w^Tx+b|}{||w||_2}$\n", 18 | "\n", 19 | "$||w||_2$为2-范数:$||w||_2=\\sqrt[2]{\\sum^m_{i=1}w_i^2}$\n", 20 | "\n", 21 | "直线为超平面,样本可表示为:\n", 22 | "\n", 23 | "$w^Tx+b\\ \\geq+1$\n", 24 | "\n", 25 | "$w^Tx+b\\ \\leq+1$\n", 26 | "\n", 27 | "#### margin:\n", 28 | "\n", 29 | "**函数间隔**:$label(w^Tx+b)\\ or\\ y_i(w^Tx+b)$\n", 30 | "\n", 31 | "**几何间隔**:$r=\\frac{label(w^Tx+b)}{||w||_2}$,当数据被正确分类时,几何间隔就是点到超平面的距离\n", 32 | "\n", 33 | "为了求几何间隔最大,SVM基本问题可以转化为求解:($\\frac{r^*}{||w||}$为几何间隔,(${r^*}$为函数间隔)\n", 34 | "\n", 35 | "$$\\max\\ \\frac{r^*}{||w||}$$\n", 36 | "\n", 37 | "$$(subject\\ to)\\ 
y_i({w^T}x_i+{b})\\geq {r^*},\\ i=1,2,..,m$$\n", 38 | "\n", 39 | "分类点几何间隔最大,同时被正确分类。但这个方程并非凸函数求解,所以要先①将方程转化为凸函数,②用拉格朗日乘子法和KKT条件求解对偶问题。\n", 40 | "\n", 41 | "①转化为凸函数:\n", 42 | "\n", 43 | "先令${r^*}=1$,方便计算(参照衡量,不影响评价结果)\n", 44 | "\n", 45 | "$$\\max\\ \\frac{1}{||w||}$$\n", 46 | "\n", 47 | "$$s.t.\\ y_i({w^T}x_i+{b})\\geq {1},\\ i=1,2,..,m$$\n", 48 | "\n", 49 | "再将$\\max\\ \\frac{1}{||w||}$转化成$\\min\\ \\frac{1}{2}||w||^2$求解凸函数,1/2是为了求导之后方便计算。\n", 50 | "\n", 51 | "$$\\min\\ \\frac{1}{2}||w||^2$$\n", 52 | "\n", 53 | "$$s.t.\\ y_i(w^Tx_i+b)\\geq 1,\\ i=1,2,..,m$$\n", 54 | "\n", 55 | "②用拉格朗日乘子法和KKT条件求解最优值:\n", 56 | "\n", 57 | "$$\\min\\ \\frac{1}{2}||w||^2$$\n", 58 | "\n", 59 | "$$s.t.\\ -y_i(w^Tx_i+b)+1\\leq 0,\\ i=1,2,..,m$$\n", 60 | "\n", 61 | "整合成:\n", 62 | "\n", 63 | "$$L(w, b, \\alpha) = \\frac{1}{2}||w||^2+\\sum^m_{i=1}\\alpha_i(-y_i(w^Tx_i+b)+1)$$\n", 64 | "\n", 65 | "推导:$\\min\\ f(x)=\\min \\max\\ L(w, b, \\alpha)\\geq \\max \\min\\ L(w, b, \\alpha)$\n", 66 | "\n", 67 | "根据KKT条件:\n", 68 | "\n", 69 | "$$\\frac{\\partial }{\\partial w}L(w, b, \\alpha)=w-\\sum\\alpha_iy_ix_i=0,\\ w=\\sum\\alpha_iy_ix_i$$\n", 70 | "\n", 71 | "$$\\frac{\\partial }{\\partial b}L(w, b, \\alpha)=\\sum\\alpha_iy_i=0$$\n", 72 | "\n", 73 | "带入$ L(w, b, \\alpha)$\n", 74 | "\n", 75 | "$\\min\\ L(w, b, \\alpha)=\\frac{1}{2}||w||^2+\\sum^m_{i=1}\\alpha_i(-y_i(w^Tx_i+b)+1)$\n", 76 | "\n", 77 | "$\\qquad\\qquad\\qquad=\\frac{1}{2}w^Tw-\\sum^m_{i=1}\\alpha_iy_iw^Tx_i-b\\sum^m_{i=1}\\alpha_iy_i+\\sum^m_{i=1}\\alpha_i$\n", 78 | "\n", 79 | "$\\qquad\\qquad\\qquad=\\frac{1}{2}w^T\\sum\\alpha_iy_ix_i-\\sum^m_{i=1}\\alpha_iy_iw^Tx_i+\\sum^m_{i=1}\\alpha_i$\n", 80 | "\n", 81 | "$\\qquad\\qquad\\qquad=\\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i=1}\\alpha_iy_iw^Tx_i$\n", 82 | "\n", 83 | "$\\qquad\\qquad\\qquad=\\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)$\n", 84 | "\n", 85 | "再把max问题转成min问题:\n", 86 | "\n", 87 | "$\\max\\ 
\\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)=\\min \\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)-\\sum^m_{i=1}\\alpha_i$\n", 88 | "\n", 89 | "$s.t.\\ \\sum^m_{i=1}\\alpha_iy_i=0,$\n", 90 | "\n", 91 | "$ \\alpha_i \\geq 0,i=1,2,...,m$\n", 92 | "\n", 93 | "以上为SVM对偶问题的对偶形式\n", 94 | "\n", 95 | "-----\n", 96 | "#### kernel\n", 97 | "\n", 98 | "在低维空间计算获得高维空间的计算结果,也就是说计算结果满足高维(满足高维,才能说明高维下线性可分)。\n", 99 | "\n", 100 | "#### soft margin & slack variable\n", 101 | "\n", 102 | "引入松弛变量$\\xi\\geq0$,对应数据点允许偏离的functional margin 的量。\n", 103 | "\n", 104 | "目标函数:$\\min\\ \\frac{1}{2}||w||^2+C\\sum\\xi_i\\qquad s.t.\\ y_i(w^Tx_i+b)\\geq1-\\xi_i$ \n", 105 | "\n", 106 | "对偶问题:\n", 107 | "\n", 108 | "$$\\max\\ \\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)=\\min \\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)-\\sum^m_{i=1}\\alpha_i$$\n", 109 | "\n", 110 | "$$s.t.\\ C\\geq\\alpha_i \\geq 0,i=1,2,...,m\\quad \\sum^m_{i=1}\\alpha_iy_i=0,$$\n", 111 | "\n", 112 | "-----\n", 113 | "\n", 114 | "#### Sequential Minimal Optimization\n", 115 | "\n", 116 | "首先定义特征到结果的输出函数:$u=w^Tx+b$.\n", 117 | "\n", 118 | "因为$w=\\sum\\alpha_iy_ix_i$\n", 119 | "\n", 120 | "有$u=\\sum y_i\\alpha_iK(x_i, x)-b$\n", 121 | "\n", 122 | "\n", 123 | "----\n", 124 | "\n", 125 | "$\\max \\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i=1}\\sum^m_{j=1}\\alpha_i\\alpha_jy_iy_j<\\phi(x_i)^T,\\phi(x_j)>$\n", 126 | "\n", 127 | "$s.t.\\ \\sum^m_{i=1}\\alpha_iy_i=0,$\n", 128 | "\n", 129 | "$ \\alpha_i \\geq 0,i=1,2,...,m$\n", 130 | "\n", 131 | "-----\n", 132 | "参考资料:\n", 133 | "\n", 134 | "[1] :[Lagrange Multiplier and KKT](http://blog.csdn.net/xianlingmao/article/details/7919597)\n", 135 | "\n", 136 | "[2] :[推导SVM](https://my.oschina.net/dfsj66011/blog/517766)\n", 137 | "\n", 138 | "[3] 
:[机器学习算法实践-支持向量机(SVM)算法原理](http://pytlab.org/2017/08/15/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E7%AE%97%E6%B3%95%E5%AE%9E%E8%B7%B5-%E6%94%AF%E6%8C%81%E5%90%91%E9%87%8F%E6%9C%BA-SVM-%E7%AE%97%E6%B3%95%E5%8E%9F%E7%90%86/)\n", 139 | "\n", 140 | "[4] :[Python实现SVM](http://blog.csdn.net/wds2006sdo/article/details/53156589)" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 1, 146 | "metadata": {}, 147 | "outputs": [ 148 | { 149 | "name": "stderr", 150 | "output_type": "stream", 151 | "text": [ 152 | "E:\\Anaconda3\\lib\\site-packages\\sklearn\\cross_validation.py:44: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n", 153 | " \"This module will be removed in 0.20.\", DeprecationWarning)\n" 154 | ] 155 | } 156 | ], 157 | "source": [ 158 | "import numpy as np\n", 159 | "import pandas as pd\n", 160 | "from sklearn.datasets import load_iris\n", 161 | "from sklearn.cross_validation import train_test_split\n", 162 | "import matplotlib.pyplot as plt\n", 163 | "%matplotlib inline" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 2, 169 | "metadata": { 170 | "collapsed": true 171 | }, 172 | "outputs": [], 173 | "source": [ 174 | "# data\n", 175 | "def create_data():\n", 176 | " iris = load_iris()\n", 177 | " df = pd.DataFrame(iris.data, columns=iris.feature_names)\n", 178 | " df['label'] = iris.target\n", 179 | " df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n", 180 | " data = np.array(df.iloc[:100, [0, 1, -1]])\n", 181 | " for i in range(len(data)):\n", 182 | " if data[i,-1] == 0:\n", 183 | " data[i,-1] = -1\n", 184 | " # print(data)\n", 185 | " return data[:,:2], data[:,-1]" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | 
"execution_count": 9, 191 | "metadata": { 192 | "collapsed": true 193 | }, 194 | "outputs": [], 195 | "source": [ 196 | "X, y = create_data()\n", 197 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 10, 203 | "metadata": {}, 204 | "outputs": [ 205 | { 206 | "data": { 207 | "text/plain": [ 208 | "" 209 | ] 210 | }, 211 | "execution_count": 10, 212 | "metadata": {}, 213 | "output_type": "execute_result" 214 | }, 215 | { 216 | "data": { 217 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGZ9JREFUeJzt3X9sHOWdx/H394yv8bWAReMWsJNLCihqSSLSugTICXGg\nXkqaQoRQlAiKQhE5ELpS0aNqKtQfqBJISLRQdEQBdBTBBeVoGihHgjgoKkUklZMg5y4pKhxtY8MV\nE5TQHKYE93t/7Dqx12vvzu6O93me/bwky97Zyfj7zMA3m5nPPGPujoiIpOWvml2AiIg0npq7iEiC\n1NxFRBKk5i4ikiA1dxGRBKm5i4gkSM1dRCRBau4iIglScxcRSdBx1a5oZm1AHzDo7stL3rsAeBx4\nvbhos7vfOtX2Zs6c6XPmzMlUrIhIq9u5c+fb7t5Vab2qmztwI7APOGGS918obfpTmTNnDn19fRl+\nvYiImNnvq1mvqtMyZtYDfAm4v56iRERkelR7zv1HwDeBv0yxznlm1m9mW83szHIrmNlaM+szs76h\noaGstYqISJUqNnczWw685e47p1htFzDb3RcCPwa2lFvJ3Te4e6+793Z1VTxlJCIiNarmnPsS4BIz\nWwbMAE4ws4fd/crRFdz93TE/P2Vm/2JmM9397caXLCJSnyNHjjAwMMD777/f7FImNWPGDHp6emhv\nb6/pz1ds7u6+DlgHR1Mx/zy2sReXnwz80d3dzM6m8C+CAzVVJCKSs4GBAY4//njmzJmDmTW7nAnc\nnQMHDjAwMMDcuXNr2kaWtMw4ZnZdsYj1wOXA9Wb2ITAMrHI9BUREAvX+++8H29gBzIyPf/zj1HNt\nMlNzd/fngeeLP68fs/we4J6aqxAJ2Jbdg9zx9Cu8cXCYUzs7uHnpPFYs6m52WVKnUBv7qHrrq/mT\nu0gr2LJ7kHWb9zB8ZASAwYPDrNu8B0ANXoKm6QdEpnDH068cbeyjho+McMfTrzSpIknFtm3bmDdv\nHqeffjq33357w7ev5i4yhTcODmdaLlKNkZERbrjhBrZu3crevXvZuHEje/fubejv0GkZkSmc2tnB\nYJlGfmpnRxOqkWZp9HWXX//615x++ul86lOfAmDVqlU8/vjjfOYzn2lUyfrkLjKVm5fOo6O9bdyy\njvY2bl46r0kVyXQbve4yeHAY59h1ly27B2ve5uDgILNmzTr6uqenh8HB2rdXjpq7yBRWLOrmtssW\n0N3ZgQHdnR3cdtkCXUxtIbFed9FpGZEKVizqVjNvYXlcd+nu7mb//v1HXw8MDNDd3dj/xvTJXURk\nCpNdX6nnusvnP/95fvvb3/L666/zwQcf8Oijj3LJJZfUvL1y1NxFRKaQx3WX4447jnvuuYelS5fy\n6U9/mpUrV3LmmWUn0639dzR0ayIiiRk9
Jdfou5SXLVvGsmXLGlFiWWruIiIVxHjdRadlREQSpOYu\nIpIgNXcRkQSpuYuIJEjNXUQkQWrukowtuwdZcvtzzP3Wf7Dk9ufqmvtDJG9f/epX+cQnPsH8+fNz\n2b6auyQhj8mdRPK0Zs0atm3bltv21dwlCbFO7iSR6N8EP5wP3+ssfO/fVPcmzz//fE466aQGFFee\nbmKSJOihGpKb/k3w86/BkeJ/S4f2F14DLFzZvLoq0Cd3SUIekzuJAPDsrcca+6gjw4XlAVNzlyTo\noRqSm0MD2ZYHQqdlJAl5Te4kwok9hVMx5ZYHTM1dkhHj5E4SgYu+M/6cO0B7R2F5HVavXs3zzz/P\n22+/TU9PD9///ve55ppr6iz2GDV3qVujHx4sEpTRi6bP3lo4FXNiT6Gx13kxdePGjQ0obnJq7lKX\n0Xz5aAxxNF8OqMFLOhauDDoZU44uqEpdlC8XCZOau9RF+XKJlbs3u4Qp1VufmrvURflyidGMGTM4\ncOBAsA3e3Tlw4AAzZsyoeRs65y51uXnpvHHn3EH5cglfT08PAwMDDA0NNbuUSc2YMYOentrjlmru\nUhflyyVG7e3tzJ07t9ll5Krq5m5mbUAfMOjuy0veM+AuYBnwHrDG3Xc1slAJl/LlIuHJ8sn9RmAf\ncEKZ9y4Gzih+LQbuLX4XaSnK/EsoqrqgamY9wJeA+ydZ5VLgIS/YDnSa2SkNqlEkCppTXkJSbVrm\nR8A3gb9M8n43MHbyhYHiMpGWocy/hKRiczez5cBb7r6z3l9mZmvNrM/M+kK+Si1SC2X+JSTVfHJf\nAlxiZr8DHgUuNLOHS9YZBGaNed1TXDaOu29w91537+3q6qqxZJEwKfMvIanY3N19nbv3uPscYBXw\nnLtfWbLaE8BVVnAOcMjd32x8uSLh0pzyEpKac+5mdh2Au68HnqIQg3yVQhTy6oZUJxIRZf4lJNas\n2297e3u9r6+vKb9bRCRWZrbT3Xsrrac7VCVYt2zZw8Yd+xlxp82M1Ytn8YMVC5pdlkgU1NwlSLds\n2cPD2/9w9PWI+9HXavAilWlWSAnSxh1lnlk5xXIRGU/NXYI0Msm1oMmWi8h4au4SpDazTMtFZDw1\ndwnS6sWzMi0XkfF0QVWCNHrRVGkZkdoo5y4iEhHl3KUuV9z3Ei++9s7R10tOO4lHrj23iRU1j+Zo\nlxjpnLtMUNrYAV587R2uuO+lJlXUPJqjXWKl5i4TlDb2SstTpjnaJVZq7iJT0BztEis1d5EpaI52\niZWau0yw5LSTMi1PmeZol1ipucsEj1x77oRG3qppmRWLurntsgV0d3ZgQHdnB7ddtkBpGQmecu4i\nIhFRzl3qkle2O8t2lS8XqZ2au0wwmu0ejQCOZruBupprlu3mVYNIq9A5d5kgr2x3lu0qXy5SHzV3\nmSCvbHeW7SpfLlIfNXeZIK9sd5btKl8uUh81d5kgr2x3lu0qXy5SH11QlQlGL1g2OqmSZbt51SDS\nKpRzFxGJiHLuOYsxgx1jzSJSGzX3GsSYwY6xZhGpnS6o1iDGDHaMNYtI7dTcaxBjBjvGmkWkdmru\nNYgxgx1jzSJSOzX3GsSYwY6xZhGpnS6o1iDGDHaMNYtI7Srm3M1sBvBL4CMU/jJ4zN2/W7LOBcDj\nwOvFRZvd/daptqucu4hIdo3Muf8ZuNDdD5tZO/ArM9vq7ttL1nvB3ZfXUqxMj1u27GHjjv2MuNNm\nxurFs/jBigV1rxtKfj6UOkRCULG5e+Gj/eHiy/biV3Nua5Wa3bJlDw9v/8PR1yPuR1+XNu0s64aS\nnw+lDpFQVHVB1czazOxl4C3gGXffUWa188ys38y2mtmZDa1S6rZxx/6ql2dZN5T8fCh1iISiqubu\n7iPufhbQA5xtZvNLVtkFzHb3hcCPgS3ltmNma82sz8z6hoaG6qlbMhqZ5NpKueVZ1g0lPx9KHSKh\nyBSF
dPeDwC+AL5Ysf9fdDxd/fgpoN7OZZf78Bnfvdfferq6uOsqWrNrMql6eZd1Q8vOh1CESiorN\n3cy6zKyz+HMH8AXgNyXrnGxW+D/fzM4ubvdA48uVWq1ePKvq5VnWDSU/H0odIqGoJi1zCvATM2uj\n0LQ3ufuTZnYdgLuvBy4HrjezD4FhYJU3ay5hKWv0Qmg1CZgs64aSnw+lDpFQaD53EZGIaD73nOWV\nqc6SL89z21nGF+O+iE7/Jnj2Vjg0ACf2wEXfgYUrm12VBEzNvQZ5Zaqz5Mvz3HaW8cW4L6LTvwl+\n/jU4Ukz+HNpfeA1q8DIpTRxWg7wy1Vny5XluO8v4YtwX0Xn21mONfdSR4cJykUmoudcgr0x1lnx5\nntvOMr4Y90V0Dg1kWy6CmntN8spUZ8mX57ntLOOLcV9E58SebMtFUHOvSV6Z6iz58jy3nWV8Me6L\n6Fz0HWgv+cuyvaOwXGQSuqBag7wy1Vny5XluO8v4YtwX0Rm9aKq0jGSgnLuISESUc5cJQsiuS+SU\nt4+GmnuLCCG7LpFT3j4quqDaIkLIrkvklLePipp7iwghuy6RU94+KmruLSKE7LpETnn7qKi5t4gQ\nsusSOeXto6ILqi0ihOy6RE55+6go5y4iEhHl3Ivyymtn2W4o85Irux6Y1DPjqY8viybsi6Sbe155\n7SzbDWVecmXXA5N6Zjz18WXRpH2R9AXVvPLaWbYbyrzkyq4HJvXMeOrjy6JJ+yLp5p5XXjvLdkOZ\nl1zZ9cCknhlPfXxZNGlfJN3c88prZ9luKPOSK7semNQz46mPL4sm7Yukm3teee0s2w1lXnJl1wOT\nemY89fFl0aR9kfQF1bzy2lm2G8q85MquByb1zHjq48uiSftCOXcRkYgo556zEPLzV9z3Ei++9s7R\n10tOO4lHrj237hpEkvLkTbDzQfARsDb43BpYfmf92w08x5/0Ofe8jGbGBw8O4xzLjG/ZPTht2y1t\n7AAvvvYOV9z3Ul01iCTlyZug74FCY4fC974HCsvrMZpdP7Qf8GPZ9f5NdZfcKGruNQghP1/a2Cst\nF2lJOx/MtrxaEeT41dxrEEJ+XkSq4CPZllcrghy/mnsNQsjPi0gVrC3b8mpFkONXc69BCPn5Jaed\nVHYbky0XaUmfW5NtebUiyPGruddgxaJubrtsAd2dHRjQ3dnBbZctaEh+vtrtPnLtuRMaudIyIiWW\n3wm91xz7pG5thdf1pmUWroQv3w0nzgKs8P3LdweVllHOXUQkIg3LuZvZDOCXwEeK6z/m7t8tWceA\nu4BlwHvAGnffVUvhlWTNl8c2h3mWud9T3xe55oizZJ/zqiPP8QWewa5L1rGlvC+mUM1NTH8GLnT3\nw2bWDvzKzLa6+/Yx61wMnFH8WgzcW/zeUFnnJI9tDvMsc7+nvi9ynQN7NPs8ajT7DBMbfF515Dm+\nlOdSzzq2lPdFBRXPuXvB4eLL9uJX6bmcS4GHiutuBzrN7JTGlpo9Xx7bHOZZ5n5PfV/kmiPOkn3O\nq448xxdBBrtmWceW8r6ooKoLqmbWZmYvA28Bz7j7jpJVuoGxHWiguKx0O2vNrM/M+oaGhjIXmzUH\nHltuPMvc76nvi1xzxFmyz3nVkef4Ishg1yzr2FLeFxVU1dzdfcTdzwJ6gLPNbH4tv8zdN7h7r7v3\ndnV1Zf7zWXPgseXGs8z9nvq+yDVHnCX7nFcdeY4vggx2zbKOLeV9UUGmKKS7HwR+AXyx5K1BYOwE\n5T3FZQ2VNV8e2xzmWeZ+T31f5JojzpJ9zquOPMcXQQa7ZlnHlvK+qKBiczezLjPrLP7cAXwB+E3J\nak8AV1nBOcAhd3+z0cVmzZfnlUfPyw9WLODKc2Yf/aTeZsaV58wum5ZJfV/kmiPOkn3Oq448xxdB\nBrtmWceW8r6ooGLO3cwWAj8B2ij8ZbDJ3W81s+sA3H19MQp5D4VP9O
8BV7v7lCF25dxFRLJrWM7d\n3fuBRWWWrx/zswM3ZC1SRETykfzDOqK7cUemR5YbW0K4CSbPG3diu0krhOMRgaSbe3Q37sj0yHJj\nSwg3weR5405sN2mFcDwikfTEYdHduCPTI8uNLSHcBJPnjTux3aQVwvGIRNLNPbobd2R6ZLmxJYSb\nYPK8cSe2m7RCOB6RSLq5R3fjjkyPLDe2hHATTJ437sR2k1YIxyMSSTf36G7ckemR5caWEG6CyfPG\nndhu0grheEQi6eYe3Y07Mj2y3NgSwk0wed64E9tNWiEcj0joYR0iIhFp2E1MIi0vy4M9QhFbzaFk\n10OpowHU3EWmkuXBHqGIreZQsuuh1NEgSZ9zF6lblgd7hCK2mkPJrodSR4OouYtMJcuDPUIRW82h\nZNdDqaNB1NxFppLlwR6hiK3mULLrodTRIGruIlPJ8mCPUMRWcyjZ9VDqaBA1d5GpZHmwRyhiqzmU\n7HoodTSIcu4iIhFRzl2mT4zZ4LxqzitfHuM+lqZSc5f6xJgNzqvmvPLlMe5jaTqdc5f6xJgNzqvm\nvPLlMe5jaTo1d6lPjNngvGrOK18e4z6WplNzl/rEmA3Oq+a88uUx7mNpOjV3qU+M2eC8as4rXx7j\nPpamU3OX+sSYDc6r5rzy5THuY2k65dxFRCJSbc5dn9wlHf2b4Ifz4Xudhe/9m6Z/u3nVIJKRcu6S\nhryy4Fm2qzy6BESf3CUNeWXBs2xXeXQJiJq7pCGvLHiW7SqPLgFRc5c05JUFz7Jd5dElIGrukoa8\nsuBZtqs8ugREzV3SkFcWPMt2lUeXgFTMuZvZLOAh4JOAAxvc/a6SdS4AHgdeLy7a7O5TXkVSzl1E\nJLtGzuf+IfANd99lZscDO83sGXffW7LeC+6+vJZiJUAxzh+epeYYxxcC7bdoVGzu7v4m8Gbx5z+Z\n2T6gGyht7pKKGPPayqPnT/stKpnOuZvZHGARsKPM2+eZWb+ZbTWzMxtQmzRLjHlt5dHzp/0Wlarv\nUDWzjwE/Bb7u7u+WvL0LmO3uh81sGbAFOKPMNtYCawFmz55dc9GSsxjz2sqj50/7LSpVfXI3s3YK\njf0Rd99c+r67v+vuh4s/PwW0m9nMMuttcPded+/t6uqqs3TJTYx5beXR86f9FpWKzd3MDHgA2Ofu\nZecuNbOTi+thZmcXt3ugkYXKNIoxr608ev6036JSzWmZJcBXgD1m9nJx2beB2QDuvh64HLjezD4E\nhoFV3qy5hKV+oxfHYkpFZKk5xvGFQPstKprPXUQkIo3MuUuolDke78mbYOeDhQdSW1vh8Xb1PgVJ\nJFJq7rFS5ni8J2+CvgeOvfaRY6/V4KUFaW6ZWClzPN7OB7MtF0mcmnuslDkez0eyLRdJnJp7rJQ5\nHs/asi0XSZyae6yUOR7vc2uyLRdJnJp7rDR3+HjL74Tea459Ure2wmtdTJUWpZy7iEhElHOvwZbd\ng9zx9Cu8cXCYUzs7uHnpPFYs6m52WY2Tei4+9fGFQPs4GmruRVt2D7Ju8x6GjxTSFYMHh1m3eQ9A\nGg0+9Vx86uMLgfZxVHTOveiOp1852thHDR8Z4Y6nX2lSRQ2Wei4+9fGFQPs4KmruRW8cHM60PDqp\n5+JTH18ItI+jouZedGpnR6bl0Uk9F5/6+EKgfRwVNfeim5fOo6N9/A0vHe1t3Lx0XpMqarDUc/Gp\njy8E2sdR0QXVotGLpsmmZVKfizv18YVA+zgqyrmLiESk2py7TsuIxKB/E/xwPnyvs/C9f1Mc25am\n0WkZkdDlmS9Xdj1Z+uQuEro88+XKridLzV0kdHnmy5VdT5aau0jo8syXK7ueLDV3kdDlmS9Xdj1Z\nau4ioctz7n49FyBZyrmLiEREOXcRkRam5i4ikiA1dxGRBKm5i4gkSM1dRCRBau4iIglScxcRSZCa\nu4hIgio2dzObZWa/MLO9ZvbfZn
ZjmXXMzO42s1fNrN/MPptPuVIXzdst0jKqmc/9Q+Ab7r7LzI4H\ndprZM+6+d8w6FwNnFL8WA/cWv0soNG+3SEup+Mnd3d90913Fn/8E7ANKHyx6KfCQF2wHOs3slIZX\nK7XTvN0iLSXTOXczmwMsAnaUvNUN7B/zeoCJfwFgZmvNrM/M+oaGhrJVKvXRvN0iLaXq5m5mHwN+\nCnzd3d+t5Ze5+wZ373X33q6urlo2IbXSvN0iLaWq5m5m7RQa+yPuvrnMKoPArDGve4rLJBSat1uk\npVSTljHgAWCfu985yWpPAFcVUzPnAIfc/c0G1in10rzdIi2lmrTMEuArwB4ze7m47NvAbAB3Xw88\nBSwDXgXeA65ufKlSt4Ur1cxFWkTF5u7uvwKswjoO3NCookREpD66Q1VEJEFq7iIiCVJzFxFJkJq7\niEiC1NxFRBKk5i4ikiA1dxGRBFkhot6EX2w2BPy+Kb+8spnA280uIkcaX7xSHhtofNX4W3evODlX\n05p7yMysz917m11HXjS+eKU8NtD4GkmnZUREEqTmLiKSIDX38jY0u4CcaXzxSnlsoPE1jM65i4gk\nSJ/cRUQS1NLN3czazGy3mT1Z5r0LzOyQmb1c/IrqkUVm9jsz21Osva/M+2Zmd5vZq2bWb2afbUad\ntapifLEfv04ze8zMfmNm+8zs3JL3Yz9+lcYX7fEzs3lj6n7ZzN41s6+XrJP78avmYR0puxHYB5ww\nyfsvuPvyaayn0f7e3SfL1F4MnFH8WgzcW/wek6nGB3Efv7uAbe5+uZn9NfA3Je/HfvwqjQ8iPX7u\n/gpwFhQ+QFJ45OjPSlbL/fi17Cd3M+sBvgTc3+xamuRS4CEv2A50mtkpzS5KwMxOBM6n8HhL3P0D\ndz9Yslq0x6/K8aXiIuA1dy+9YTP349eyzR34EfBN4C9TrHNe8Z9MW83szGmqq1Ec+E8z22lma8u8\n3w3sH/N6oLgsFpXGB/Eev7nAEPCvxdOG95vZR0vWifn4VTM+iPf4jbUK2Fhmee7HryWbu5ktB95y\n951TrLYLmO3uC4EfA1umpbjG+Tt3P4vCP/9uMLPzm11Qg1UaX8zH7zjgs8C97r4I+D/gW80tqaGq\nGV/Mxw+A4ummS4B/b8bvb8nmTuGh35eY2e+AR4ELzezhsSu4+7vufrj481NAu5nNnPZKa+Tug8Xv\nb1E433d2ySqDwKwxr3uKy6JQaXyRH78BYMDddxRfP0ahGY4V8/GrOL7Ij9+oi4Fd7v7HMu/lfvxa\nsrm7+zp373H3ORT+2fScu185dh0zO9nMrPjz2RT21YFpL7YGZvZRMzt+9GfgH4D/KlntCeCq4lX7\nc4BD7v7mNJdak2rGF/Pxc/f/Bfab2bzioouAvSWrRXv8qhlfzMdvjNWUPyUD03D8Wj0tM46ZXQfg\n7uuBy4HrzexDYBhY5fHc8fVJ4GfF/zeOA/7N3beVjO8pYBnwKvAecHWTaq1FNeOL+fgB/BPwSPGf\n9v8DXJ3Q8YPK44v6+BU/dHwB+Mcxy6b1+OkOVRGRBLXkaRkRkdSpuYuIJEjNXUQkQWruIiIJUnMX\nEUmQmruISILU3EVEEqTmLiKSoP8H2fNC9uxjMHwAAAAASUVORK5CYII=\n", 218 | "text/plain": [ 219 | "" 220 | ] 221 | }, 222 | "metadata": {}, 223 | "output_type": "display_data" 224 | } 225 | ], 226 | "source": [ 227 | "plt.scatter(X[:50,0],X[:50,1], label='0')\n", 228 | "plt.scatter(X[50:,0],X[50:,1], label='1')\n", 229 | "plt.legend()" 230 | ] 231 | }, 232 | { 233 | "cell_type": "markdown", 
234 |    "metadata": {},
 235 |    "source": [
 236 |     "----\n",
 237 |     "\n"
 238 |    ]
 239 |   },
 240 |   {
 241 |    "cell_type": "code",
 242 |    "execution_count": 11,
 243 |    "metadata": {
 244 |     "collapsed": true
 245 |    },
 246 |    "outputs": [],
 247 |    "source": [
 248 |     "class SVM:\n", 249 |     "    def __init__(self, max_iter=100, kernel='linear'):\n", 250 |     "        self.max_iter = max_iter\n", 251 |     "        self._kernel = kernel\n", 252 |     "    \n", 253 |     "    def init_args(self, features, labels):\n", 254 |     "        self.m, self.n = features.shape\n", 255 |     "        self.X = features\n", 256 |     "        self.Y = labels\n", 257 |     "        self.b = 0.0\n", 258 |     "        \n", 259 |     "        # 将Ei保存在一个列表里\n", 260 |     "        self.alpha = np.ones(self.m)\n", 261 |     "        self.E = [self._E(i) for i in range(self.m)]\n", 262 |     "        # 松弛变量\n", 263 |     "        self.C = 1.0\n", 264 |     "        \n", 265 |     "    def _KKT(self, i):\n", 266 |     "        y_g = self._g(i)*self.Y[i]\n", 267 |     "        if self.alpha[i] == 0:\n", 268 |     "            return y_g >= 1\n", 269 |     "        elif 0 < self.alpha[i] < self.C:\n", 270 |     "            return y_g == 1\n", 271 |     "        else:\n", 272 |     "            return y_g <= 1\n", 273 |     "        \n", 274 |     "    # g(x)预测值,输入xi(X[i])\n", 275 |     "    def _g(self, i):\n", 276 |     "        r = self.b\n", 277 |     "        for j in range(self.m):\n", 278 |     "            r += self.alpha[j]*self.Y[j]*self.kernel(self.X[i], self.X[j])\n", 279 |     "        return r\n", 280 |     "        \n", 281 |     "    # 核函数\n", 282 |     "    def kernel(self, x1, x2):\n", 283 |     "        if self._kernel == 'linear':\n", 284 |     "            return sum([x1[k]*x2[k] for k in range(self.n)])\n", 285 |     "        elif self._kernel == 'poly':\n", 286 |     "            return (sum([x1[k]*x2[k] for k in range(self.n)]) + 1)**2\n", 287 |     "    \n", 288 |     "        return 0\n", 289 |     "    \n", 290 |     "    # E(x)为g(x)对输入x的预测值和y的差\n", 291 |     "    def _E(self, i):\n", 292 |     "        return self._g(i) - self.Y[i]\n", 293 |     "    \n", 294 |     "    def _init_alpha(self):\n", 295 |     "        # 外层循环首先遍历所有满足0<alpha<C的样本点,检验是否满足KKT\n", 296 |     "        index_list = [i for i in range(self.m) if 0 < self.alpha[i] < self.C]\n", 297 |     "        # 否则遍历整个训练集\n", 298 |     "        non_satisfy_list = [i for i in range(self.m) if i not in index_list]\n", 299 |     "        index_list.extend(non_satisfy_list)\n", 300 |     "        \n", 301 |     "        for i in index_list:\n", 302 |     "            if self._KKT(i):\n", 303 |     "                continue\n", 304 |     "            \n", 305 |     "            E1 = self.E[i]\n", 306 |     "            # 如果E1是+,选择最小的;如果E1是负的,选择最大的\n", 307 |     "            if E1 >= 0:\n", 308 |     "                j = min(range(self.m), key=lambda x: self.E[x])\n", 309 |     "            else:\n", 310 |     "                j = max(range(self.m), key=lambda x: self.E[x])\n", 311 |     "            return i, j\n", 312 |     "    \n", 313 |     "    def _compare(self, _alpha, L, H):\n",
314 | " if _alpha > H:\n", 315 | " return H\n", 316 | " elif _alpha < L:\n", 317 | " return L\n", 318 | " else:\n", 319 | " return _alpha \n", 320 | " \n", 321 | " def fit(self, features, labels):\n", 322 | " self.init_args(features, labels)\n", 323 | " \n", 324 | " for t in range(self.max_iter):\n", 325 | " # train\n", 326 | " i1, i2 = self._init_alpha()\n", 327 | " \n", 328 | " # 边界\n", 329 | " if self.Y[i1] == self.Y[i2]:\n", 330 | " L = max(0, self.alpha[i1]+self.alpha[i2]-self.C)\n", 331 | " H = min(self.C, self.alpha[i1]+self.alpha[i2])\n", 332 | " else:\n", 333 | " L = max(0, self.alpha[i2]-self.alpha[i1])\n", 334 | " H = min(self.C, self.C+self.alpha[i2]-self.alpha[i1])\n", 335 | " \n", 336 | " E1 = self.E[i1]\n", 337 | " E2 = self.E[i2]\n", 338 | " # eta=K11+K22-2K12\n", 339 | " eta = self.kernel(self.X[i1], self.X[i1]) + self.kernel(self.X[i2], self.X[i2]) - 2*self.kernel(self.X[i1], self.X[i2])\n", 340 | " if eta <= 0:\n", 341 | " # print('eta <= 0')\n", 342 | " continue\n", 343 | " \n", 344 | " alpha2_new_unc = self.alpha[i2] + self.Y[i2] * (E2 - E1) / eta\n", 345 | " alpha2_new = self._compare(alpha2_new_unc, L, H)\n", 346 | " \n", 347 | " alpha1_new = self.alpha[i1] + self.Y[i1] * self.Y[i2] * (self.alpha[i2] - alpha2_new)\n", 348 | " \n", 349 | " b1_new = -E1 - self.Y[i1] * self.kernel(self.X[i1], self.X[i1]) * (alpha1_new-self.alpha[i1]) - self.Y[i2] * self.kernel(self.X[i2], self.X[i1]) * (alpha2_new-self.alpha[i2])+ self.b \n", 350 | " b2_new = -E2 - self.Y[i1] * self.kernel(self.X[i1], self.X[i2]) * (alpha1_new-self.alpha[i1]) - self.Y[i2] * self.kernel(self.X[i2], self.X[i2]) * (alpha2_new-self.alpha[i2])+ self.b \n", 351 | " \n", 352 | " if 0 < alpha1_new < self.C:\n", 353 | " b_new = b1_new\n", 354 | " elif 0 < alpha2_new < self.C:\n", 355 | " b_new = b2_new\n", 356 | " else:\n", 357 | " # 选择中点\n", 358 | " b_new = (b1_new + b2_new) / 2\n", 359 | " \n", 360 | " # 更新参数\n", 361 | " self.alpha[i1] = alpha1_new\n", 362 | " self.alpha[i2] = 
alpha2_new\n", 363 | " self.b = b_new\n", 364 | " \n", 365 | " self.E[i1] = self._E(i1)\n", 366 | " self.E[i2] = self._E(i2)\n", 367 | " return 'train done!'\n", 368 | " \n", 369 | " def predict(self, data):\n", 370 | " r = self.b\n", 371 | " for i in range(self.m):\n", 372 | " r += self.alpha[i] * self.Y[i] * self.kernel(data, self.X[i])\n", 373 | " \n", 374 | " return 1 if r > 0 else -1\n", 375 | " \n", 376 | " def score(self, X_test, y_test):\n", 377 | " right_count = 0\n", 378 | " for i in range(len(X_test)):\n", 379 | " result = self.predict(X_test[i])\n", 380 | " if result == y_test[i]:\n", 381 | " right_count += 1\n", 382 | " return right_count / len(X_test)\n", 383 | " \n", 384 | " def _weight(self):\n", 385 | " # linear model\n", 386 | " yx = self.Y.reshape(-1, 1)*self.X\n", 387 | " self.w = np.dot(yx.T, self.alpha)\n", 388 | " return self.w" 389 | ] 390 | }, 391 | { 392 | "cell_type": "code", 393 | "execution_count": 12, 394 | "metadata": { 395 | "collapsed": true 396 | }, 397 | "outputs": [], 398 | "source": [ 399 | "svm = SVM(max_iter=200)" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": 13, 405 | "metadata": {}, 406 | "outputs": [ 407 | { 408 | "data": { 409 | "text/plain": [ 410 | "'train done!'" 411 | ] 412 | }, 413 | "execution_count": 13, 414 | "metadata": {}, 415 | "output_type": "execute_result" 416 | } 417 | ], 418 | "source": [ 419 | "svm.fit(X_train, y_train)" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": 14, 425 | "metadata": {}, 426 | "outputs": [ 427 | { 428 | "data": { 429 | "text/plain": [ 430 | "0.96" 431 | ] 432 | }, 433 | "execution_count": 14, 434 | "metadata": {}, 435 | "output_type": "execute_result" 436 | } 437 | ], 438 | "source": [ 439 | "svm.score(X_test, y_test)" 440 | ] 441 | }, 442 | { 443 | "cell_type": "markdown", 444 | "metadata": {}, 445 | "source": [ 446 | "## sklearn.svm.SVC" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": 15, 452 | 
"metadata": {}, 453 | "outputs": [ 454 | { 455 | "data": { 456 | "text/plain": [ 457 | "SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n", 458 | " decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',\n", 459 | " max_iter=-1, probability=False, random_state=None, shrinking=True,\n", 460 | " tol=0.001, verbose=False)" 461 | ] 462 | }, 463 | "execution_count": 15, 464 | "metadata": {}, 465 | "output_type": "execute_result" 466 | } 467 | ], 468 | "source": [ 469 | "from sklearn.svm import SVC\n", 470 | "clf = SVC()\n", 471 | "clf.fit(X_train, y_train)" 472 | ] 473 | }, 474 | { 475 | "cell_type": "code", 476 | "execution_count": 16, 477 | "metadata": {}, 478 | "outputs": [ 479 | { 480 | "data": { 481 | "text/plain": [ 482 | "1.0" 483 | ] 484 | }, 485 | "execution_count": 16, 486 | "metadata": {}, 487 | "output_type": "execute_result" 488 | } 489 | ], 490 | "source": [ 491 | "clf.score(X_test, y_test)" 492 | ] 493 | }, 494 | { 495 | "cell_type": "markdown", 496 | "metadata": {}, 497 | "source": [ 498 | "### sklearn.svm.SVC\n", 499 | "\n", 500 | "*(C=1.0, kernel='rbf', degree=3, gamma='auto', coef0=0.0, shrinking=True, probability=False,tol=0.001, cache_size=200, class_weight=None, verbose=False, max_iter=-1, decision_function_shape=None,random_state=None)*\n", 501 | "\n", 502 | "参数:\n", 503 | "\n", 504 | "- C:C-SVC的惩罚参数C?默认值是1.0\n", 505 | "\n", 506 | "C越大,相当于惩罚松弛变量,希望松弛变量接近0,即对误分类的惩罚增大,趋向于对训练集全分对的情况,这样对训练集测试时准确率很高,但泛化能力弱。C值小,对误分类的惩罚减小,允许容错,将他们当成噪声点,泛化能力较强。\n", 507 | "\n", 508 | "- kernel :核函数,默认是rbf,可以是‘linear’, ‘poly’, ‘rbf’, ‘sigmoid’, ‘precomputed’ \n", 509 | " \n", 510 | " – 线性:u'v\n", 511 | " \n", 512 | " – 多项式:(gamma*u'*v + coef0)^degree\n", 513 | "\n", 514 | " – RBF函数:exp(-gamma|u-v|^2)\n", 515 | "\n", 516 | " – sigmoid:tanh(gamma*u'*v + coef0)\n", 517 | "\n", 518 | "\n", 519 | "- degree :多项式poly函数的维度,默认是3,选择其他核函数时会被忽略。\n", 520 | "\n", 521 | "\n", 522 | "- gamma : ‘rbf’,‘poly’ 和‘sigmoid’的核函数参数。默认是’auto’,则会选择1/n_features\n", 523 | "\n", 524 
| "\n", 525 | "- coef0 :核函数的常数项。对于‘poly’和 ‘sigmoid’有用。\n", 526 | "\n", 527 | "\n", 528 | "- probability :是否采用概率估计?.默认为False\n", 529 | "\n", 530 | "\n", 531 | "- shrinking :是否采用shrinking heuristic方法,默认为true\n", 532 | "\n", 533 | "\n", 534 | "- tol :停止训练的误差值大小,默认为1e-3\n", 535 | "\n", 536 | "\n", 537 | "- cache_size :核函数cache缓存大小,默认为200\n", 538 | "\n", 539 | "\n", 540 | "- class_weight :类别的权重,字典形式传递。设置第几类的参数C为weight*C(C-SVC中的C)\n", 541 | "\n", 542 | "\n", 543 | "- verbose :允许冗余输出?\n", 544 | "\n", 545 | "\n", 546 | "- max_iter :最大迭代次数。-1为无限制。\n", 547 | "\n", 548 | "\n", 549 | "- decision_function_shape :‘ovo’, ‘ovr’ or None, default=None3\n", 550 | "\n", 551 | "\n", 552 | "- random_state :数据洗牌时的种子值,int值\n", 553 | "\n", 554 | "\n", 555 | "主要调节的参数有:C、kernel、degree、gamma、coef0。" 556 | ] 557 | }, 558 | { 559 | "cell_type": "code", 560 | "execution_count": null, 561 | "metadata": { 562 | "collapsed": true 563 | }, 564 | "outputs": [], 565 | "source": [] 566 | } 567 | ], 568 | "metadata": { 569 | "kernelspec": { 570 | "display_name": "Python 3", 571 | "language": "python", 572 | "name": "python3" 573 | }, 574 | "language_info": { 575 | "codemirror_mode": { 576 | "name": "ipython", 577 | "version": 3 578 | }, 579 | "file_extension": ".py", 580 | "mimetype": "text/x-python", 581 | "name": "python", 582 | "nbconvert_exporter": "python", 583 | "pygments_lexer": "ipython3", 584 | "version": "3.6.1" 585 | } 586 | }, 587 | "nbformat": 4, 588 | "nbformat_minor": 2 589 | } 590 | -------------------------------------------------------------------------------- /notebooks/2-knn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import pandas as pd\n", 12 | "import numpy as np\n", 13 | "import cv2\n", 14 | "import random\n", 15 | "import time\n", 16 | "\n", 17 | "from 
sklearn.model_selection import train_test_split\n", 18 | "from sklearn.metrics import accuracy_score" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "数据集是mnist,28\\*28,这里选择提取HOG特征,方向梯度直方图(Histogram of Oriented Gradient, HOG):" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": { 32 | "collapsed": true 33 | }, 34 | "outputs": [], 35 | "source": [ 36 | "raw_data = pd.read_csv('../data/train.csv',header=0)" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 4, 42 | "metadata": { 43 | "collapsed": false 44 | }, 45 | "outputs": [ 46 | { 47 | "data": { 48 | "text/html": [ 49 | "
\n", 50 | "\n", 63 | "\n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | "
labelpixel0pixel1pixel2pixel3pixel4pixel5pixel6pixel7pixel8...pixel774pixel775pixel776pixel777pixel778pixel779pixel780pixel781pixel782pixel783
01000000000...0000000000
10000000000...0000000000
21000000000...0000000000
34000000000...0000000000
40000000000...0000000000
\n", 213 | "

5 rows × 785 columns

\n", 214 | "
" 215 | ], 216 | "text/plain": [ 217 | " label pixel0 pixel1 pixel2 pixel3 pixel4 pixel5 pixel6 pixel7 \\\n", 218 | "0 1 0 0 0 0 0 0 0 0 \n", 219 | "1 0 0 0 0 0 0 0 0 0 \n", 220 | "2 1 0 0 0 0 0 0 0 0 \n", 221 | "3 4 0 0 0 0 0 0 0 0 \n", 222 | "4 0 0 0 0 0 0 0 0 0 \n", 223 | "\n", 224 | " pixel8 ... pixel774 pixel775 pixel776 pixel777 pixel778 \\\n", 225 | "0 0 ... 0 0 0 0 0 \n", 226 | "1 0 ... 0 0 0 0 0 \n", 227 | "2 0 ... 0 0 0 0 0 \n", 228 | "3 0 ... 0 0 0 0 0 \n", 229 | "4 0 ... 0 0 0 0 0 \n", 230 | "\n", 231 | " pixel779 pixel780 pixel781 pixel782 pixel783 \n", 232 | "0 0 0 0 0 0 \n", 233 | "1 0 0 0 0 0 \n", 234 | "2 0 0 0 0 0 \n", 235 | "3 0 0 0 0 0 \n", 236 | "4 0 0 0 0 0 \n", 237 | "\n", 238 | "[5 rows x 785 columns]" 239 | ] 240 | }, 241 | "execution_count": 4, 242 | "metadata": {}, 243 | "output_type": "execute_result" 244 | } 245 | ], 246 | "source": [ 247 | "raw_data.head()" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": 7, 253 | "metadata": { 254 | "collapsed": false 255 | }, 256 | "outputs": [ 257 | { 258 | "data": { 259 | "text/plain": [ 260 | "(42000, 785)" 261 | ] 262 | }, 263 | "execution_count": 7, 264 | "metadata": {}, 265 | "output_type": "execute_result" 266 | } 267 | ], 268 | "source": [ 269 | "raw_data.shape" 270 | ] 271 | }, 272 | { 273 | "cell_type": "markdown", 274 | "metadata": {}, 275 | "source": [ 276 | "两个冒号的语法:\n", 277 | " seq[start:end:step]\n", 278 | "原来是\n", 279 | " imgs = data[0::,1::]\n", 280 | " labels = data[::,0]\n", 281 | "没必要这样写" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 14, 287 | "metadata": { 288 | "collapsed": true 289 | }, 290 | "outputs": [], 291 | "source": [ 292 | "data = raw_data.values\n", 293 | "imgs = data[:, 1:]\n", 294 | "labels = data[:, 0]" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 15, 300 | "metadata": { 301 | "collapsed": false 302 | }, 303 | "outputs": [ 304 | { 305 | "data": { 306 | "text/plain": [ 307 | 
"(42000, 784)" 308 | ] 309 | }, 310 | "execution_count": 15, 311 | "metadata": {}, 312 | "output_type": "execute_result" 313 | } 314 | ], 315 | "source": [ 316 | "imgs.shape" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 29, 322 | "metadata": { 323 | "collapsed": false 324 | }, 325 | "outputs": [ 326 | { 327 | "data": { 328 | "text/plain": [ 329 | "array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])" 330 | ] 331 | }, 332 | "execution_count": 29, 333 | "metadata": {}, 334 | "output_type": "execute_result" 335 | } 336 | ], 337 | "source": [ 338 | "np.unique(labels)" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": 18, 344 | "metadata": { 345 | "collapsed": true 346 | }, 347 | "outputs": [], 348 | "source": [ 349 | "# 利用opencv获取图像hog特征\n", 350 | "def get_hog_features(trainset):\n", 351 | " features = []\n", 352 | "\n", 353 | " hog = cv2.HOGDescriptor('../hog.xml')\n", 354 | "\n", 355 | " for img in trainset:\n", 356 | " img = np.reshape(img,(28,28))\n", 357 | " cv_img = img.astype(np.uint8)\n", 358 | "\n", 359 | " hog_feature = hog.compute(cv_img)\n", 360 | " # hog_feature = np.transpose(hog_feature)\n", 361 | " features.append(hog_feature)\n", 362 | "\n", 363 | " features = np.array(features)\n", 364 | " features = np.reshape(features,(-1,324))\n", 365 | "\n", 366 | " return features" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": 19, 372 | "metadata": { 373 | "collapsed": true 374 | }, 375 | "outputs": [], 376 | "source": [ 377 | "features = get_hog_features(imgs)" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": 21, 383 | "metadata": { 384 | "collapsed": false 385 | }, 386 | "outputs": [ 387 | { 388 | "data": { 389 | "text/plain": [ 390 | "(42000, 324)" 391 | ] 392 | }, 393 | "execution_count": 21, 394 | "metadata": {}, 395 | "output_type": "execute_result" 396 | } 397 | ], 398 | "source": [ 399 | "features.shape" 400 | ] 401 | }, 402 | { 403 | "cell_type": 
"code", 404 | "execution_count": 112, 405 | "metadata": { 406 | "collapsed": false 407 | }, 408 | "outputs": [ 409 | { 410 | "data": { 411 | "text/plain": [ 412 | "(42000,)" 413 | ] 414 | }, 415 | "execution_count": 112, 416 | "metadata": {}, 417 | "output_type": "execute_result" 418 | } 419 | ], 420 | "source": [ 421 | "labels.shape" 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": 22, 427 | "metadata": { 428 | "collapsed": true 429 | }, 430 | "outputs": [], 431 | "source": [ 432 | "train_features, test_features, train_labels, test_labels = train_test_split(features, labels, test_size=0.33, random_state=23323)" 433 | ] 434 | }, 435 | { 436 | "cell_type": "markdown", 437 | "metadata": {}, 438 | "source": [ 439 | "# 预测\n", 440 | "\n", 441 | "因为knn不需要训练,我们可以直接进行预测。不过因为4万个数据即使是预测也非常花时间,这里只取前100个样本做训练集,去30个样本做测试集:" 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": 113, 447 | "metadata": { 448 | "collapsed": true 449 | }, 450 | "outputs": [], 451 | "source": [ 452 | "testset, trainset, train_labels = test_features[:30], train_features[:100], train_labels[:100]" 453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": 121, 458 | "metadata": { 459 | "collapsed": true 460 | }, 461 | "outputs": [], 462 | "source": [ 463 | "k = 10 # 最近的10个点\n", 464 | "\n", 465 | "predict = []\n", 466 | "count = 0" 467 | ] 468 | }, 469 | { 470 | "cell_type": "code", 471 | "execution_count": 122, 472 | "metadata": { 473 | "collapsed": false 474 | }, 475 | "outputs": [ 476 | { 477 | "data": { 478 | "text/plain": [ 479 | "5.0" 480 | ] 481 | }, 482 | "execution_count": 122, 483 | "metadata": {}, 484 | "output_type": "execute_result" 485 | } 486 | ], 487 | "source": [ 488 | "# 计算两个点的欧氏距离\n", 489 | "np.linalg.norm(np.array([0, 3]) - np.array([4, 0]))" 490 | ] 491 | }, 492 | { 493 | "cell_type": "code", 494 | "execution_count": 123, 495 | "metadata": { 496 | "collapsed": false 497 | }, 498 | "outputs": [], 499 | "source": [ 
500 | "time_1 = time.time()\n", 501 | "\n", 502 | "for test_vec in testset:\n", 503 | " # 输出当前运行的测试用例坐标,用于测试\n", 504 | " count += 1\n", 505 | " if count % 5000 == 0:\n", 506 | " print(count)\n", 507 | " \n", 508 | " knn_list = np.zeros((1, 2)) # 初始化,存放当前k个最近邻居\n", 509 | " \n", 510 | " # 先将前k个点放入k个最近邻居中,填充满knn_list\n", 511 | " for i in range(k):\n", 512 | " label = train_labels[i]\n", 513 | " train_vec = trainset[i]\n", 514 | "\n", 515 | " dist = np.linalg.norm(train_vec - test_vec) # 计算两个点的欧氏距离\n", 516 | " knn_list = np.append(knn_list, [[dist, label]], axis=0)\n", 517 | " \n", 518 | " # 剩下的点\n", 519 | " for i in range(k, len(train_labels)):\n", 520 | " label = train_labels[i]\n", 521 | " train_vec = trainset[i]\n", 522 | "\n", 523 | " dist = np.linalg.norm(train_vec - test_vec) # 计算两个点的欧氏距离\n", 524 | "\n", 525 | " # 寻找10个邻近点中距离最远的点\n", 526 | " max_index = np.argmax(knn_list[:, 0])\n", 527 | " max_dist = np.max(knn_list[:, 0])\n", 528 | "\n", 529 | " # 如果当前k个最近邻居中存在点距离比当前点距离远,则替换\n", 530 | " if dist < max_dist:\n", 531 | " knn_list[max_index] = [dist, label]\n", 532 | " \n", 533 | " \n", 534 | " # 上面代码计算全部运算完之后,即说明已经找到了离当前test_vec最近的10个train_vec\n", 535 | " # 统计选票\n", 536 | " class_total = 10\n", 537 | " class_count = [0 for i in range(class_total)]\n", 538 | " for dist, label in knn_list:\n", 539 | " class_count[int(label)] += 1\n", 540 | "\n", 541 | " # 找出最大选票数\n", 542 | " label_max = max(class_count)\n", 543 | "\n", 544 | " # 最大选票数对应的class\n", 545 | " predict.append(class_count.index(label_max))\n", 546 | "\n", 547 | "time_2 = time.time()\n" 548 | ] 549 | }, 550 | { 551 | "cell_type": "code", 552 | "execution_count": 124, 553 | "metadata": { 554 | "collapsed": false 555 | }, 556 | "outputs": [ 557 | { 558 | "name": "stdout", 559 | "output_type": "stream", 560 | "text": [ 561 | "train time is 0.07612895965576172\n" 562 | ] 563 | } 564 | ], 565 | "source": [ 566 | "print('train time is %s' % (time_2 - time_1))" 567 | ] 568 | }, 569 | { 570 | "cell_type": "code", 
571 | "execution_count": 109, 572 | "metadata": { 573 | "collapsed": false 574 | }, 575 | "outputs": [ 576 | { 577 | "name": "stdout", 578 | "output_type": "stream", 579 | "text": [ 580 | "train time is 3\n" 581 | ] 582 | } 583 | ], 584 | "source": [ 585 | "print('train time is %s' % (5-2))" 586 | ] 587 | }, 588 | { 589 | "cell_type": "code", 590 | "execution_count": 106, 591 | "metadata": { 592 | "collapsed": false 593 | }, 594 | "outputs": [ 595 | { 596 | "data": { 597 | "text/plain": [ 598 | "array([[ 0. , 0. ],\n", 599 | " [ 1.10036302, 3. ],\n", 600 | " [ 1.09803486, 3. ],\n", 601 | " [ 1.09235775, 3. ],\n", 602 | " [ 1.03992426, 3. ],\n", 603 | " [ 1.04467952, 3. ],\n", 604 | " [ 1.06501627, 3. ],\n", 605 | " [ 0.93764162, 3. ],\n", 606 | " [ 1.05351973, 3. ],\n", 607 | " [ 1.04691565, 3. ],\n", 608 | " [ 0.9816038 , 3. ]])" 609 | ] 610 | }, 611 | "execution_count": 106, 612 | "metadata": {}, 613 | "output_type": "execute_result" 614 | } 615 | ], 616 | "source": [ 617 | "knn_list" 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "execution_count": 90, 623 | "metadata": { 624 | "collapsed": false 625 | }, 626 | "outputs": [ 627 | { 628 | "data": { 629 | "text/plain": [ 630 | "array([], dtype=float64)" 631 | ] 632 | }, 633 | "execution_count": 90, 634 | "metadata": {}, 635 | "output_type": "execute_result" 636 | } 637 | ], 638 | "source": [ 639 | "knn_list = np.array([]) # 当前k个最近邻居\n", 640 | " \n", 641 | "# 先将前k个点放入k个最近邻居中,填充满knn_list\n", 642 | "for i in range(k):\n", 643 | " label = train_labels[i]\n", 644 | " train_vec = trainset[i]\n", 645 | "\n", 646 | " dist = np.linalg.norm(train_vec - test_vec) # 计算两个点的欧氏距离\n", 647 | " knn_list_test = np.append(knn_list_test, [[8.5, 9]], axis=0)\n" 648 | ] 649 | }, 650 | { 651 | "cell_type": "markdown", 652 | "metadata": {}, 653 | "source": [ 654 | "# 测试用\n", 655 | "\n", 656 | "下面自己写一个寻找10个领近点中距离最远的点:" 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "execution_count": 96, 662 | "metadata": { 663 | 
"collapsed": false 664 | }, 665 | "outputs": [ 666 | { 667 | "data": { 668 | "text/plain": [ 669 | "array([[ 0., 0.]])" 670 | ] 671 | }, 672 | "execution_count": 96, 673 | "metadata": {}, 674 | "output_type": "execute_result" 675 | } 676 | ], 677 | "source": [ 678 | "knn_list = np.zeros((1, 2)) # 当前k个最近邻居\n", 679 | "knn_list" 680 | ] 681 | }, 682 | { 683 | "cell_type": "code", 684 | "execution_count": 94, 685 | "metadata": { 686 | "collapsed": false 687 | }, 688 | "outputs": [ 689 | { 690 | "ename": "ValueError", 691 | "evalue": "all the input array dimensions except for the concatenation axis must match exactly", 692 | "output_type": "error", 693 | "traceback": [ 694 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 695 | "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", 696 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mknn_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m8.5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m9\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 697 | "\u001b[0;32m/Users/xu/anaconda/envs/py35/lib/python3.5/site-packages/numpy/lib/function_base.py\u001b[0m in \u001b[0;36mappend\u001b[0;34m(arr, values, axis)\u001b[0m\n\u001b[1;32m 5145\u001b[0m \u001b[0mvalues\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mravel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5146\u001b[0m \u001b[0maxis\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0marr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 5147\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 698 | "\u001b[0;31mValueError\u001b[0m: all the input array dimensions except for the concatenation axis must match exactly" 699 | ] 700 | } 701 | ], 702 | "source": [ 703 | "np.append(knn_list, [[8.5, 9]], axis=0)" 704 | ] 705 | }, 706 | { 707 | "cell_type": "code", 708 | "execution_count": 78, 709 | "metadata": { 710 | "collapsed": false 711 | }, 712 | "outputs": [ 713 | { 714 | "data": { 715 | "text/plain": [ 716 | "array([[ 2.3, 1. ],\n", 717 | " [ 3.5, 1. ],\n", 718 | " [ 1.5, 4. ],\n", 719 | " [ 6.5, 2. ],\n", 720 | " [ 5.5, 8. 
]])" 721 | ] 722 | }, 723 | "execution_count": 78, 724 | "metadata": {}, 725 | "output_type": "execute_result" 726 | } 727 | ], 728 | "source": [ 729 | "knn_list_test = np.array([[2.3, 1], [3.5, 1], [1.5, 4], [6.5, 2], [5.5, 8]])\n", 730 | "# 每个元组里,第一个是距离,第二个是对应标签\n", 731 | "knn_list_test" 732 | ] 733 | }, 734 | { 735 | "cell_type": "code", 736 | "execution_count": 79, 737 | "metadata": { 738 | "collapsed": false 739 | }, 740 | "outputs": [ 741 | { 742 | "data": { 743 | "text/plain": [ 744 | "array([ 2.3, 3.5, 1.5, 6.5, 5.5])" 745 | ] 746 | }, 747 | "execution_count": 79, 748 | "metadata": {}, 749 | "output_type": "execute_result" 750 | } 751 | ], 752 | "source": [ 753 | "knn_list_test[:, 0]" 754 | ] 755 | }, 756 | { 757 | "cell_type": "code", 758 | "execution_count": 80, 759 | "metadata": { 760 | "collapsed": true 761 | }, 762 | "outputs": [], 763 | "source": [ 764 | "knn_list_test[2] = [9.5, 5]" 765 | ] 766 | }, 767 | { 768 | "cell_type": "code", 769 | "execution_count": 81, 770 | "metadata": { 771 | "collapsed": false 772 | }, 773 | "outputs": [ 774 | { 775 | "data": { 776 | "text/plain": [ 777 | "array([[ 2.3, 1. ],\n", 778 | " [ 3.5, 1. ],\n", 779 | " [ 9.5, 5. ],\n", 780 | " [ 6.5, 2. ],\n", 781 | " [ 5.5, 8. ]])" 782 | ] 783 | }, 784 | "execution_count": 81, 785 | "metadata": {}, 786 | "output_type": "execute_result" 787 | } 788 | ], 789 | "source": [ 790 | "knn_list_test" 791 | ] 792 | }, 793 | { 794 | "cell_type": "markdown", 795 | "metadata": {}, 796 | "source": [ 797 | "要想给一个ndarray添加一个元素,必须是同样的格式,即必须是`[[8.5, 9]]`,不能使`[8.5, 9]`,而且必须要用axis指定才行。" 798 | ] 799 | }, 800 | { 801 | "cell_type": "code", 802 | "execution_count": 86, 803 | "metadata": { 804 | "collapsed": false 805 | }, 806 | "outputs": [ 807 | { 808 | "data": { 809 | "text/plain": [ 810 | "array([[ 2.3, 1. ],\n", 811 | " [ 3.5, 1. ],\n", 812 | " [ 9.5, 5. ],\n", 813 | " [ 6.5, 2. ],\n", 814 | " [ 5.5, 8. ],\n", 815 | " [ 8.5, 9. ],\n", 816 | " [ 8.5, 9. 
]])" 817 | ] 818 | }, 819 | "execution_count": 86, 820 | "metadata": {}, 821 | "output_type": "execute_result" 822 | } 823 | ], 824 | "source": [ 825 | "np.append(knn_list_test, [[8.5, 9]], axis=0)" 826 | ] 827 | }, 828 | { 829 | "cell_type": "code", 830 | "execution_count": 87, 831 | "metadata": { 832 | "collapsed": false 833 | }, 834 | "outputs": [ 835 | { 836 | "data": { 837 | "text/plain": [ 838 | "array([[ 2.3, 1. ],\n", 839 | " [ 3.5, 1. ],\n", 840 | " [ 9.5, 5. ],\n", 841 | " [ 6.5, 2. ],\n", 842 | " [ 5.5, 8. ],\n", 843 | " [ 8.5, 9. ]])" 844 | ] 845 | }, 846 | "execution_count": 87, 847 | "metadata": {}, 848 | "output_type": "execute_result" 849 | } 850 | ], 851 | "source": [ 852 | "knn_list_test" 853 | ] 854 | }, 855 | { 856 | "cell_type": "code", 857 | "execution_count": 37, 858 | "metadata": { 859 | "collapsed": false 860 | }, 861 | "outputs": [ 862 | { 863 | "data": { 864 | "text/plain": [ 865 | "3" 866 | ] 867 | }, 868 | "execution_count": 37, 869 | "metadata": {}, 870 | "output_type": "execute_result" 871 | } 872 | ], 873 | "source": [ 874 | "knn_list_test[:, 0].argmax()" 875 | ] 876 | }, 877 | { 878 | "cell_type": "code", 879 | "execution_count": 41, 880 | "metadata": { 881 | "collapsed": false 882 | }, 883 | "outputs": [ 884 | { 885 | "data": { 886 | "text/plain": [ 887 | "array([], dtype=float64)" 888 | ] 889 | }, 890 | "execution_count": 41, 891 | "metadata": {}, 892 | "output_type": "execute_result" 893 | } 894 | ], 895 | "source": [ 896 | "np.array([])" 897 | ] 898 | }, 899 | { 900 | "cell_type": "markdown", 901 | "metadata": {}, 902 | "source": [ 903 | "# 输出评分\n", 904 | "\n", 905 | "统计结束后,得到predict" 906 | ] 907 | }, 908 | { 909 | "cell_type": "code", 910 | "execution_count": 125, 911 | "metadata": { 912 | "collapsed": false 913 | }, 914 | "outputs": [ 915 | { 916 | "data": { 917 | "text/plain": [ 918 | "30" 919 | ] 920 | }, 921 | "execution_count": 125, 922 | "metadata": {}, 923 | "output_type": "execute_result" 924 | } 925 | ], 926 | 
"source": [ 927 | "len(predict)" 928 | ] 929 | }, 930 | { 931 | "cell_type": "code", 932 | "execution_count": 127, 933 | "metadata": { 934 | "collapsed": true 935 | }, 936 | "outputs": [], 937 | "source": [ 938 | "test_predict = np.array(predict)" 939 | ] 940 | }, 941 | { 942 | "cell_type": "code", 943 | "execution_count": 128, 944 | "metadata": { 945 | "collapsed": false 946 | }, 947 | "outputs": [], 948 | "source": [ 949 | "score = accuracy_score(test_labels[:30], test_predict)\n" 950 | ] 951 | }, 952 | { 953 | "cell_type": "code", 954 | "execution_count": 129, 955 | "metadata": { 956 | "collapsed": false 957 | }, 958 | "outputs": [ 959 | { 960 | "data": { 961 | "text/plain": [ 962 | "0.6333333333333333" 963 | ] 964 | }, 965 | "execution_count": 129, 966 | "metadata": {}, 967 | "output_type": "execute_result" 968 | } 969 | ], 970 | "source": [ 971 | "score" 972 | ] 973 | }, 974 | { 975 | "cell_type": "code", 976 | "execution_count": null, 977 | "metadata": { 978 | "collapsed": true 979 | }, 980 | "outputs": [], 981 | "source": [] 982 | } 983 | ], 984 | "metadata": { 985 | "kernelspec": { 986 | "display_name": "Python [py35]", 987 | "language": "python", 988 | "name": "Python [py35]" 989 | }, 990 | "language_info": { 991 | "codemirror_mode": { 992 | "name": "ipython", 993 | "version": 3 994 | }, 995 | "file_extension": ".py", 996 | "mimetype": "text/x-python", 997 | "name": "python", 998 | "nbconvert_exporter": "python", 999 | "pygments_lexer": "ipython3", 1000 | "version": "3.5.2" 1001 | } 1002 | }, 1003 | "nbformat": 4, 1004 | "nbformat_minor": 0 1005 | } 1006 | -------------------------------------------------------------------------------- /notebooks/3-naive_bayes.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import pandas as pd\n", 12 | "import numpy as 
np\n", 13 | "import cv2\n", 14 | "import random\n", 15 | "import time\n", 16 | "\n", 17 | "from sklearn.model_selection import train_test_split\n", 18 | "from sklearn.metrics import accuracy_score" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "# 数据预处理" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 6, 31 | "metadata": { 32 | "collapsed": false 33 | }, 34 | "outputs": [ 35 | { 36 | "name": "stdout", 37 | "output_type": "stream", 38 | "text": [ 39 | "(28140, 784)\n", 40 | "(13860, 784)\n" 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "raw_data = pd.read_csv('../data/train.csv',header=0)\n", 46 | "data = raw_data.values\n", 47 | "imgs = data[:, 1:]\n", 48 | "labels = data[:, 0]\n", 49 | "# 选取 2/3 数据作为训练集, 1/3 数据作为测试集\n", 50 | "train_features, test_features, train_labels, test_labels = train_test_split(imgs, labels, test_size=0.33, random_state=23323)\n", 51 | "\n", 52 | "print(train_features.shape)\n", 53 | "print(test_features.shape)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "# train" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 7, 66 | "metadata": { 67 | "collapsed": true 68 | }, 69 | "outputs": [], 70 | "source": [ 71 | "# 二值化\n", 72 | "def binaryzation(img):\n", 73 | " cv_img = img.astype(np.uint8)\n", 74 | " cv2.threshold(cv_img, 50, 1, cv2.THRESH_BINARY_INV, cv_img)\n", 75 | " return cv_img" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": {}, 81 | "source": [ 82 | " cv2.threshold(cv_img, 50, 1, cv2.THRESH_BINARY_INV, cv_img)\n", 83 | "这句代码中,cv_img是输入的784个pixel数字(0~255),50表示阈值,1表示最大值,cv2.THRESH_BINARY_INV表示二值化的类型。这句代码表示pixel数字大于50的部分,为1,小于50的部分,为0。\n", 84 | "\n", 85 | "看一下经过二值化处理后是什么效果:" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 16, 91 | "metadata": { 92 | "collapsed": false 93 | }, 94 | "outputs": [ 95 | { 96 | "data": { 97 | "text/plain": [ 98 | "array([ 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 99 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 100 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 101 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 102 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 103 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 104 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 105 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 106 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 107 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 108 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 109 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 110 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 111 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 63,\n", 112 | " 255, 253, 253, 244, 120, 22, 0, 0, 0, 0, 0, 0, 0,\n", 113 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 12,\n", 114 | " 100, 209, 253, 252, 252, 252, 252, 187, 6, 0, 0, 0, 0,\n", 115 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 116 | " 144, 217, 252, 161, 253, 183, 153, 106, 218, 252, 70, 0, 0,\n", 117 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 118 | " 87, 180, 242, 243, 202, 68, 10, 3, 0, 0, 60, 194, 31,\n", 119 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 120 | " 0, 5, 184, 252, 226, 93, 23, 0, 0, 0, 0, 0, 32,\n", 121 | " 142, 179, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 122 | " 0, 0, 0, 195, 252, 183, 29, 0, 0, 0, 0, 0, 0,\n", 123 | " 0, 141, 252, 45, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 124 | " 0, 0, 0, 0, 48, 247, 173, 38, 0, 0, 0, 0, 0,\n", 125 | " 0, 0, 26, 245, 252, 74, 0, 0, 0, 0, 0, 0, 0,\n", 126 | " 0, 0, 0, 0, 0, 0, 100, 229, 72, 0, 0, 0, 0,\n", 127 | " 0, 0, 0, 0, 132, 252, 252, 131, 0, 0, 0, 0, 0,\n", 128 | " 0, 0, 0, 0, 0, 0, 0, 0, 26, 153, 27, 0, 0,\n", 129 | " 0, 0, 0, 0, 0, 34, 132, 252, 252, 98, 0, 0, 0,\n", 130 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 242, 159,\n", 131 | " 111, 58, 68, 77, 0, 15, 34, 14, 180, 252, 252, 21, 0,\n", 132 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 133 | " 181, 253, 253, 253, 253, 114, 0, 0, 0, 100, 253, 253, 141,\n", 134 | " 10, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 135 | " 0, 0, 50, 229, 252, 249, 120, 20, 0, 0, 0, 176, 252,\n", 136 | " 252, 55, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 137 | " 0, 0, 0, 0, 0, 29, 44, 42, 0, 0, 0, 0, 0,\n", 138 | " 209, 252, 206, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 139 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 140 | " 3, 128, 251, 252, 92, 0, 0, 0, 0, 0, 0, 0, 0,\n", 141 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 142 | " 0, 0, 58, 252, 252, 238, 31, 0, 0, 0, 0, 0, 0,\n", 143 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 144 | " 0, 0, 0, 39, 230, 252, 252, 143, 0, 0, 0, 0, 0,\n", 145 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 146 | " 0, 0, 0, 0, 0, 116, 253, 252, 252, 20, 0, 0, 0,\n", 147 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 148 | " 0, 0, 0, 0, 0, 0, 14, 226, 253, 252, 172, 4, 0,\n", 149 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 150 | " 0, 0, 0, 0, 0, 0, 0, 0, 159, 252, 253, 232, 30,\n", 151 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 152 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 20, 216, 252, 253,\n", 153 | " 186, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 154 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 155 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 156 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 157 | " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", 158 | " 0, 0, 0, 0])" 159 | ] 160 | }, 161 | "execution_count": 16, 162 | "metadata": {}, 163 | "output_type": "execute_result" 164 | } 165 | ], 166 | "source": [ 167 | "trainset[0]" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 17, 173 | "metadata": { 174 | "collapsed": false 175 | }, 176 | "outputs": [ 177 | { 178 | "data": { 179 | "text/plain": [ 180 | "array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", 181 | " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", 182 | " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", 183 | " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1,\n", 184 | " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", 185 | " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", 186 | " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", 187 | " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,\n", 188 | " 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", 189 | " 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", 190 | " 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n", 191 | " 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1,\n", 192 | " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,\n", 193 | " 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1,\n", 194 | " 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,\n", 195 | " 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", 196 | " 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,\n", 197 | " 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1,\n", 198 | " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,\n", 199 | " 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,\n", 200 | " 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", 201 | " 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", 202 | " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1,\n", 203 | " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1,\n", 204 | " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,\n", 205 | " 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", 206 | " 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", 207 | " 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", 208 | " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,\n", 209 | " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1,\n", 210 | " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,\n", 211 | " 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", 212 | " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", 213 | " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", 214 | " 1, 1], dtype=uint8)" 215 | ] 216 | }, 217 | "execution_count": 17, 218 | "metadata": {}, 219 | "output_type": "execute_result" 220 | } 221 | ], 222 | "source": [ 223 | "binaryzation(trainset[0]) # 图片二值化" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 8, 229 | "metadata": { 230 | "collapsed": true 231 | }, 232 | "outputs": [], 233 | "source": [ 234 | "trainset, train_labels = train_features, train_labels" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 9, 240 | "metadata": { 241 | "collapsed": true 242 | }, 243 | "outputs": [], 244 | "source": [ 245 | "class_num = 10\n", 246 | "feature_len = 784" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 11, 252 | "metadata": { 253 | "collapsed": false 254 | }, 255 | "outputs": [ 256 | { 257 | "name": "stdout", 258 | "output_type": "stream", 259 | "text": [ 260 | "(10,)\n", 261 | "(10, 784, 2)\n" 262 | ] 263 | } 264 | ], 265 | "source": [ 266 | "# 存放先验概率\n", 267 | "prior_probability = np.zeros(class_num) \n", 268 | "print(prior_probability.shape)\n", 269 | "# 存放条件概率\n", 270 | "conditional_probability = np.zeros((class_num, feature_len, 2)) \n", 271 | "print(conditional_probability.shape)" 272 | ] 273 | }, 274 | { 275 | "cell_type": "markdown", 276 | "metadata": {}, 277 | "source": [ 278 | "具体文章参考这一篇[机器学习通俗入门-朴素贝叶斯分类器](http://blog.csdn.net/TaiJi1985/article/details/73657994)\n", 279 | "\n", 280 | "$x^{(i)}$ 为一个28维的向量表示第i个样本, $y^{(i)}$ 为标注的类别。我们求解的目标是:\n", 281 | "\n", 282 | "$$f = \\underset{j}{arg 
max}\\; p (y^{(i)} = j \\mid x^{(i)}) $$\n", 283 | "\n", 284 | "In plain terms: compute $p (y^{(i)} = 0 \\mid x^{(i)}) $, $p (y^{(i)} = 1 \\mid x^{(i)}) $ ... $p (y^{(i)} = 9 \\mid x^{(i)}) $ and pick the largest; if class $j$ has the highest probability, the image is assigned to class $j$.\n" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 23, 290 | "metadata": { 291 | "collapsed": true 292 | }, 293 | "outputs": [], 294 | "source": [ 295 | "# accumulate counts for the prior and conditional probabilities\n", 296 | "for i in range(len(train_labels)):\n", 297 | " img = binaryzation(trainset[i]) # binarize the image\n", 298 | " label = train_labels[i]\n", 299 | "\n", 300 | " prior_probability[label] += 1 # count of images per label\n", 301 | "\n", 302 | " for j in range(feature_len):\n", 303 | " conditional_probability[label][j][img[j]] += 1 \n", 304 | " # img[j] is the value at pixel j: a 0 increments the first slot, a 1 increments the second\n", 305 | " # e.g. conditional_probability[0][0] below comes out as [0, 2711],\n", 306 | " # i.e. among the samples with label 0, how many have pixel 0 equal to 0 vs. equal to 1" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": 39, 312 | "metadata": { 313 | "collapsed": false 314 | }, 315 | "outputs": [ 316 | { 317 | "data": { 318 | "text/plain": [ 319 | "array([ 2711., 3197., 2828., 2897., 2751., 2565., 2769., 2964.,\n", 320 | " 2654., 2804.])" 321 | ] 322 | }, 323 | "execution_count": 39, 324 | "metadata": {}, 325 | "output_type": "execute_result" 326 | } 327 | ], 328 | "source": [ 329 | "prior_probability" 330 | ] 331 | }, 332 | { 333 | "cell_type": "markdown", 334 | "metadata": {}, 335 | "source": [ 336 | "Let's unpack the loop above, using the first training sample:\n" 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": 25, 342 | "metadata": { 343 | "collapsed": false 344 | }, 345 | "outputs": [ 346 | { 347 | "data": { 348 | "text/plain": [ 349 | "9" 350 | ] 351 | }, 352 | "execution_count": 25, 353 | "metadata": {}, 354 | "output_type": "execute_result" 355 | } 356 | ], 357 | "source": [ 358 | "img1 = binaryzation(trainset[0]) # binarize: each of the 784 pixels (features) becomes 0 or 1\n", 359 | "label1 = train_labels[0]\n", 360 | "label1" 361 | ]
362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": 26, 366 | "metadata": { 367 | "collapsed": false 368 | }, 369 | "outputs": [ 370 | { 371 | "data": { 372 | "text/plain": [ 373 | "1" 374 | ] 375 | }, 376 | "execution_count": 26, 377 | "metadata": {}, 378 | "output_type": "execute_result" 379 | } 380 | ], 381 | "source": [ 382 | "img1[0]" 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": 24, 388 | "metadata": { 389 | "collapsed": false 390 | }, 391 | "outputs": [ 392 | { 393 | "data": { 394 | "text/plain": [ 395 | "array([ 0., 2711.])" 396 | ] 397 | }, 398 | "execution_count": 24, 399 | "metadata": {}, 400 | "output_type": "execute_result" 401 | } 402 | ], 403 | "source": [ 404 | "print(conditional_probability[0][0]) # " 405 | ] 406 | }, 407 | { 408 | "cell_type": "markdown", 409 | "metadata": {}, 410 | "source": [ 411 | "说明在img中,标签为0的样本中,像素点为500的对应位置,在img中为0的样本数量是199,为1的样本数量是2512" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": 27, 417 | "metadata": { 418 | "collapsed": false 419 | }, 420 | "outputs": [ 421 | { 422 | "name": "stdout", 423 | "output_type": "stream", 424 | "text": [ 425 | "[ 199. 
2512.]\n" 426 | ] 427 | } 428 | ], 429 | "source": [ 430 | "print(conditional_probability[0][500]) # " 431 | ] 432 | }, 433 | { 434 | "cell_type": "markdown", 435 | "metadata": {}, 436 | "source": [ 437 | "The code below maps the probabilities into [1, 1000001]. Everything so far has been recorded as raw sample counts rather than actual probabilities; this looks like an engineering trick to avoid numerical overflow, but the optimization makes it harder to match the code against the formulas in the book." 438 | ] 439 | }, 440 | { 441 | "cell_type": "markdown", 442 | "metadata": {}, 443 | "source": [ 444 | "At first glance the probability computation below looks questionable:\n", 445 | "    probalility_0 = (float(pix_0)/float(pix_0+pix_1))*1000000 + 1\n", 446 | "The denominator is the number of pixel-j observations among the class-i (0~9) images; since that is just pix_0 plus pix_1, i.e. the total number of class-i images, the formula is in fact correct." 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": 28, 452 | "metadata": { 453 | "collapsed": true 454 | }, 455 | "outputs": [], 456 | "source": [ 457 | "# map the probabilities into [1, 1000001]\n", 458 | "for i in range(class_num):\n", 459 | " for j in range(feature_len):\n", 460 | "\n", 461 | " # after binarization each pixel takes only the values 0 and 1\n", 462 | " pix_0 = conditional_probability[i][j][0] # number of class-i (0~9) images whose pixel j (0~783) is 0\n", 463 | " pix_1 = conditional_probability[i][j][1] # number of class-i (0~9) images whose pixel j (0~783) is 1\n", 464 | "\n", 465 | " # conditional probabilities for pixel values 0 and 1\n", 466 | " probalility_0 = (float(pix_0)/float(pix_0+pix_1))*1000000 + 1\n", 467 | " probalility_1 = (float(pix_1)/float(pix_0+pix_1))*1000000 + 1\n", 468 | "\n", 469 | " conditional_probability[i][j][0] = probalility_0\n", 470 | " conditional_probability[i][j][1] = probalility_1" 471 | ] 472 | }, 473 | { 474 | "cell_type": "code", 475 | "execution_count": 29, 476 | "metadata": { 477 | "collapsed": false 478 | }, 479 | "outputs": [ 480 | { 481 | "data": { 482 | "text/plain": [ 483 | "array([ 1.00000000e+00, 1.00000100e+06])" 484 | ] 485 | }, 486 | "execution_count": 29, 487 | "metadata": {}, 488 | "output_type": "execute_result" 489 | } 490 | ], 491 | "source": [ 492 | "conditional_probability[0][0]" 493 | ] 494 | }, 495 | { 496 | "cell_type": "markdown", 497 | "metadata": {}, 498 | "source": [ 499 |
"得到了prior_probability和conditional_probability,这就算是训练结束了。\n", 500 | "\n", 501 | "# test (predict)\n" 502 | ] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "execution_count": 34, 507 | "metadata": { 508 | "collapsed": false 509 | }, 510 | "outputs": [ 511 | { 512 | "data": { 513 | "text/plain": [ 514 | "(100, 784)" 515 | ] 516 | }, 517 | "execution_count": 34, 518 | "metadata": {}, 519 | "output_type": "execute_result" 520 | } 521 | ], 522 | "source": [ 523 | "# 为了加快预测速度,这里直接取100个测试样本\n", 524 | "\n", 525 | "testset = test_features[:100]\n", 526 | "testset.shape" 527 | ] 528 | }, 529 | { 530 | "cell_type": "markdown", 531 | "metadata": {}, 532 | "source": [ 533 | "\n", 534 | "\n", 535 | "$$p (y^{(i)} = j \\mid x_{k}^{(i)}) = \\frac{p (x_{k}^{(i)} \\mid y^{(i)} = j) \\cdot p(y^{(i)} = j)}{p(x_{k}^{(i)})}$$\n", 536 | "\n", 537 | "$p (y^{(i)} = j \\mid x_{k}^{(i)}) $中,$y^{(i)} = j$表示从属于哪一类,$x_{k}^{(i)}$表示哪一个像素点。\n", 538 | "\n", 539 | "下面calculate_probability函数就是在计算分子部分。\n", 540 | "\n", 541 | "`probability *= int(conditional_probability[label][i][img[i]])`\n", 542 | "\n", 543 | "这行代码中:\n", 544 | "- probability表示先验概率 $p(y^{(i)} = j)$\n", 545 | "- `conditional_probability[label][i][img[i]]`表示 $p (x_{k}^{(i)} \\mid y^{(i)} = j) $\n" 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "execution_count": 35, 551 | "metadata": { 552 | "collapsed": true 553 | }, 554 | "outputs": [], 555 | "source": [ 556 | "# 计算不同标签下,testdata的概率\n", 557 | "def calculate_probability(img, label):\n", 558 | " probability = int(prior_probability[label])\n", 559 | "\n", 560 | " for i in range(len(img)):\n", 561 | " probability *= int(conditional_probability[label][i][img[i]])\n", 562 | "\n", 563 | " return probability" 564 | ] 565 | }, 566 | { 567 | "cell_type": "code", 568 | "execution_count": 36, 569 | "metadata": { 570 | "collapsed": false 571 | }, 572 | "outputs": [], 573 | "source": [ 574 | "predict = []\n", 575 | "\n", 576 | "for img in testset:\n", 577 | "\n", 578 | " # 图像二值化\n", 579 
| " img = binaryzation(img)\n", 580 | "\n", 581 | " max_label = 0\n", 582 | " max_probability = calculate_probability(img, 0)\n", 583 | "\n", 584 | " for j in range(1, 10):\n", 585 | " probability = calculate_probability(img, j)\n", 586 | "\n", 587 | " if max_probability < probability:\n", 588 | " max_label = j\n", 589 | " max_probability = probability\n", 590 | "\n", 591 | " predict.append(max_label)" 592 | ] 593 | }, 594 | { 595 | "cell_type": "code", 596 | "execution_count": 37, 597 | "metadata": { 598 | "collapsed": true 599 | }, 600 | "outputs": [], 601 | "source": [ 602 | "test_predict = np.array(predict)" 603 | ] 604 | }, 605 | { 606 | "cell_type": "code", 607 | "execution_count": 38, 608 | "metadata": { 609 | "collapsed": false 610 | }, 611 | "outputs": [ 612 | { 613 | "data": { 614 | "text/plain": [ 615 | "0.76000000000000001" 616 | ] 617 | }, 618 | "execution_count": 38, 619 | "metadata": {}, 620 | "output_type": "execute_result" 621 | } 622 | ], 623 | "source": [ 624 | "score = accuracy_score(test_labels[:100], test_predict)\n", 625 | "score" 626 | ] 627 | }, 628 | { 629 | "cell_type": "markdown", 630 | "metadata": {}, 631 | "source": [ 632 | "# 重构朴素贝叶斯算法\n", 633 | "\n", 634 | "![](https://pic1.zhimg.com/v2-e17426fd0627560f1fc82118dd1d5d14_r.jpg)\n", 635 | "\n", 636 | "朴素贝叶斯认为所有特征都是独立的,然后得出一个样本出现的概率使其所有特征出现概率的联乘。\n", 637 | "\n", 638 | "首先求每一个标签的先验概率:" 639 | ] 640 | }, 641 | { 642 | "cell_type": "code", 643 | "execution_count": 46, 644 | "metadata": { 645 | "collapsed": false 646 | }, 647 | "outputs": [ 648 | { 649 | "name": "stdout", 650 | "output_type": "stream", 651 | "text": [ 652 | "(10,)\n", 653 | "(10, 784, 2)\n" 654 | ] 655 | } 656 | ], 657 | "source": [ 658 | "class_num = 10\n", 659 | "feature_len = 784\n", 660 | "\n", 661 | "# 存放每个label的数量\n", 662 | "class_number = np.zeros(class_num) \n", 663 | "\n", 664 | "# 存放先验概率\n", 665 | "prior_probability = np.zeros(class_num) \n", 666 | "print(prior_probability.shape)\n", 667 | "# 存放条件概率\n", 668 | 
"conditional_probability = np.zeros((class_num, feature_len, 2)) \n", 669 | "print(conditional_probability.shape)" 670 | ] 671 | }, 672 | { 673 | "cell_type": "code", 674 | "execution_count": 47, 675 | "metadata": { 676 | "collapsed": false 677 | }, 678 | "outputs": [], 679 | "source": [ 680 | "# 计算先验概率\n", 681 | "for i in range(len(train_labels)):\n", 682 | " img = binaryzation(trainset[i]) # 图片二值化\n", 683 | " label = train_labels[i]\n", 684 | "\n", 685 | " class_number[label] += 1 # 每个label的图片各有多少个\n" 686 | ] 687 | }, 688 | { 689 | "cell_type": "code", 690 | "execution_count": 48, 691 | "metadata": { 692 | "collapsed": false 693 | }, 694 | "outputs": [ 695 | { 696 | "data": { 697 | "text/plain": [ 698 | "array([ 0.09633973, 0.11361052, 0.10049751, 0.10294954, 0.09776119,\n", 699 | " 0.09115139, 0.09840085, 0.10533049, 0.09431414, 0.09964463])" 700 | ] 701 | }, 702 | "execution_count": 48, 703 | "metadata": {}, 704 | "output_type": "execute_result" 705 | } 706 | ], 707 | "source": [ 708 | "class_number/len(train_labels)" 709 | ] 710 | }, 711 | { 712 | "cell_type": "code", 713 | "execution_count": 49, 714 | "metadata": { 715 | "collapsed": false 716 | }, 717 | "outputs": [ 718 | { 719 | "data": { 720 | "text/plain": [ 721 | "array([ 2711., 3197., 2828., 2897., 2751., 2565., 2769., 2964.,\n", 722 | " 2654., 2804.])" 723 | ] 724 | }, 725 | "execution_count": 49, 726 | "metadata": {}, 727 | "output_type": "execute_result" 728 | } 729 | ], 730 | "source": [ 731 | "class_number" 732 | ] 733 | }, 734 | { 735 | "cell_type": "code", 736 | "execution_count": 50, 737 | "metadata": { 738 | "collapsed": true 739 | }, 740 | "outputs": [], 741 | "source": [ 742 | "prior_probability = class_number / len(train_labels)" 743 | ] 744 | }, 745 | { 746 | "cell_type": "markdown", 747 | "metadata": {}, 748 | "source": [ 749 | "计算条件概率: \n", 750 | "\n", 751 | "$$p (X^{(i)} = a_{jl} \\mid Y = c_k)$$\n", 752 | "\n", 753 | 
"在标签为$c_k$的前提下,样本x的第$j$个特征(像素点)的第$l$个值(经过二值化处理,这里的$l$只有0或1两种可能)。conditional_probability的维度是`(10, 784, 2)`,最后的那个2,指的就是每个特征可以取的值。如果不做二值化处理,那么每个像素点应该有256种取值。" 754 | ] 755 | }, 756 | { 757 | "cell_type": "code", 758 | "execution_count": 66, 759 | "metadata": { 760 | "collapsed": false 761 | }, 762 | "outputs": [], 763 | "source": [ 764 | "# 条件概率\n", 765 | "conditional_probability = np.zeros((class_num, feature_len, 2)) \n", 766 | "\n", 767 | "for i in range(len(train_labels)):\n", 768 | " img = binaryzation(trainset[i]) # 图片二值化\n", 769 | " label = train_labels[i]\n", 770 | " for j in range(feature_len):\n", 771 | " conditional_probability[label][j][img[j]] += 1 # 这里只得到a_jl的数量" 772 | ] 773 | }, 774 | { 775 | "cell_type": "code", 776 | "execution_count": 59, 777 | "metadata": { 778 | "collapsed": false 779 | }, 780 | "outputs": [ 781 | { 782 | "data": { 783 | "text/plain": [ 784 | "array([ 199., 2512.])" 785 | ] 786 | }, 787 | "execution_count": 59, 788 | "metadata": {}, 789 | "output_type": "execute_result" 790 | } 791 | ], 792 | "source": [ 793 | "conditional_probability[0][500] " 794 | ] 795 | }, 796 | { 797 | "cell_type": "code", 798 | "execution_count": 65, 799 | "metadata": { 800 | "collapsed": false 801 | }, 802 | "outputs": [ 803 | { 804 | "data": { 805 | "text/plain": [ 806 | "2711.0" 807 | ] 808 | }, 809 | "execution_count": 65, 810 | "metadata": {}, 811 | "output_type": "execute_result" 812 | } 813 | ], 814 | "source": [ 815 | "class_number[0]" 816 | ] 817 | }, 818 | { 819 | "cell_type": "code", 820 | "execution_count": 60, 821 | "metadata": { 822 | "collapsed": false 823 | }, 824 | "outputs": [ 825 | { 826 | "data": { 827 | "text/plain": [ 828 | "array([ 0.07340465, 0.92659535])" 829 | ] 830 | }, 831 | "execution_count": 60, 832 | "metadata": {}, 833 | "output_type": "execute_result" 834 | } 835 | ], 836 | "source": [ 837 | "conditional_probability[0][500] / class_number[0]" 838 | ] 839 | }, 840 | { 841 | "cell_type": "code", 842 | "execution_count": 69, 843 
| "metadata": { 844 | "collapsed": true 845 | }, 846 | "outputs": [], 847 | "source": [ 848 | "conditional_probability_fraction = np.zeros((class_num, feature_len, 2)) \n", 849 | "\n", 850 | "for i in range(len(train_labels)):\n", 851 | " label = train_labels[i]\n", 852 | " for j in range(feature_len):\n", 853 | " conditional_probability_fraction[label][j] = conditional_probability[label][j] / class_number[label]" 854 | ] 855 | }, 856 | { 857 | "cell_type": "code", 858 | "execution_count": 70, 859 | "metadata": { 860 | "collapsed": false 861 | }, 862 | "outputs": [ 863 | { 864 | "data": { 865 | "text/plain": [ 866 | "array([ 0.07340465, 0.92659535])" 867 | ] 868 | }, 869 | "execution_count": 70, 870 | "metadata": {}, 871 | "output_type": "execute_result" 872 | } 873 | ], 874 | "source": [ 875 | "conditional_probability_fraction[0][500]" 876 | ] 877 | }, 878 | { 879 | "cell_type": "markdown", 880 | "metadata": {}, 881 | "source": [ 882 | "Writing this as two separate loops is long-winded; better to combine everything into one pass:" 883 | ] 884 | }, 885 | { 886 | "cell_type": "code", 887 | "execution_count": null, 888 | "metadata": { 889 | "collapsed": true 890 | }, 891 | "outputs": [], 892 | "source": [ 893 | "# compute the prior and conditional probabilities\n", 894 | "for i in range(len(train_labels)):\n", 895 | " img = binaryzation(trainset[i]) # binarize the image\n", 896 | " label = train_labels[i]\n", 897 | "\n", 898 | " class_number[label] += 1 # count of images per label\n", 899 | " prior_probability = class_number / len(train_labels)\n", 900 | "\n", 901 | " for j in range(feature_len):\n", 902 | " conditional_probability[label][j][img[j]] += 1 \n", 903 | " # for the samples with a given label, count how many have each pixel equal to 0 vs. equal to 1\n", 904 | " \n", 905 | "# divide the counts by each class's total to turn them into probabilities (index i, not label)\n", 906 | "for i in range(class_num):\n", 907 | " for j in range(feature_len):\n", 908 | " conditional_probability[i][j] = conditional_probability[i][j] / class_number[i]" 909 | ] 910 | }, 911 | { 912 | "cell_type": "markdown", 913 | "metadata": {}, 914 | "source": [ 915 | "That completes step one: the prior and conditional probabilities are computed. Step two is prediction on the test set:" 916 | ] 917 |
}, 918 | { 919 | "cell_type": "code", 920 | "execution_count": 71, 921 | "metadata": { 922 | "collapsed": false 923 | }, 924 | "outputs": [ 925 | { 926 | "data": { 927 | "text/plain": [ 928 | "(100, 784)" 929 | ] 930 | }, 931 | "execution_count": 71, 932 | "metadata": {}, 933 | "output_type": "execute_result" 934 | } 935 | ], 936 | "source": [ 937 | "testset.shape" 938 | ] 939 | }, 940 | { 941 | "cell_type": "code", 942 | "execution_count": 72, 943 | "metadata": { 944 | "collapsed": true 945 | }, 946 | "outputs": [], 947 | "source": [ 948 | "# compute the probability of a test image under each label\n", 949 | "def calculate_probability(img, label):\n", 950 | " probability = prior_probability[label] # prior probability\n", 951 | "\n", 952 | " # iterate over the pixels: with the label fixed, multiply together each pixel's conditional probability\n", 953 | " for i in range(len(img)):\n", 954 | " probability *= conditional_probability[label][i][img[i]] \n", 955 | " # [i] is the i-th pixel of the test sample\n", 956 | " # img[i] is whether that pixel is 0 or 1\n", 957 | "\n", 958 | " return probability" 959 | ] 960 | }, 961 | { 962 | "cell_type": "code", 963 | "execution_count": 73, 964 | "metadata": { 965 | "collapsed": false 966 | }, 967 | "outputs": [ 968 | { 969 | "name": "stderr", 970 | "output_type": "stream", 971 | "text": [ 972 | "/Users/xu/anaconda/envs/py35/lib/python3.5/site-packages/ipykernel/__main__.py:7: RuntimeWarning: overflow encountered in double_scalars\n", 973 | "/Users/xu/anaconda/envs/py35/lib/python3.5/site-packages/ipykernel/__main__.py:7: RuntimeWarning: invalid value encountered in double_scalars\n" 974 | ] 975 | } 976 | ], 977 | "source": [ 978 | "predict = []\n", 979 | "\n", 980 | "for img in testset:\n", 981 | " img = binaryzation(img)\n", 982 | " \n", 983 | " max_label = 0\n", 984 | " max_probability = calculate_probability(img, 0)\n", 985 | " \n", 986 | " for j in range(1, 10):\n", 987 | " probability = calculate_probability(img, j)\n", 988 | " \n", 989 | " if max_probability < probability:\n", 990 | " max_label = j\n", 991 | " max_probability = probability\n", 992 | "\n", 993 
| " predict.append(max_label)" 994 | ] 995 | }, 996 | { 997 | "cell_type": "markdown", 998 | "metadata": {}, 999 | "source": [ 1000 | "看来确实是这样,为了防止溢出,所以源代码里才一直用数量代替。\n", 1001 | "\n", 1002 | "看来这个算法不需要我改了,源代码其实已经考虑了溢出的问题。" 1003 | ] 1004 | } 1005 | ], 1006 | "metadata": { 1007 | "kernelspec": { 1008 | "display_name": "Python [py35]", 1009 | "language": "python", 1010 | "name": "Python [py35]" 1011 | }, 1012 | "language_info": { 1013 | "codemirror_mode": { 1014 | "name": "ipython", 1015 | "version": 3 1016 | }, 1017 | "file_extension": ".py", 1018 | "mimetype": "text/x-python", 1019 | "name": "python", 1020 | "nbconvert_exporter": "python", 1021 | "pygments_lexer": "ipython3", 1022 | "version": "3.5.2" 1023 | } 1024 | }, 1025 | "nbformat": 4, 1026 | "nbformat_minor": 0 1027 | } 1028 | -------------------------------------------------------------------------------- /notebooks/6-svm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "import pandas as pd\n", 13 | "from sklearn.datasets import load_iris\n", 14 | "from sklearn.model_selection import train_test_split\n", 15 | "import matplotlib.pyplot as plt\n", 16 | "%matplotlib inline" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 3, 22 | "metadata": { 23 | "collapsed": true 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "# data\n", 28 | "def create_data():\n", 29 | " iris = load_iris()\n", 30 | " df = pd.DataFrame(iris.data, columns=iris.feature_names)\n", 31 | " df['label'] = iris.target\n", 32 | " df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n", 33 | " data = np.array(df.iloc[:100, [0, 1, -1]])\n", 34 | " for i in range(len(data)):\n", 35 | " if data[i,-1] == 0:\n", 36 | " data[i,-1] = -1\n", 37 | " # print(data)\n", 38 | 
" return data[:,:2], data[:,-1]" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 4, 44 | "metadata": { 45 | "collapsed": true 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "X, y = create_data()\n", 50 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 9, 56 | "metadata": { 57 | "collapsed": false 58 | }, 59 | "outputs": [ 60 | { 61 | "data": { 62 | "text/plain": [ 63 | "array([[ 5.1, 3.5],\n", 64 | " [ 4.9, 3. ],\n", 65 | " [ 4.7, 3.2]])" 66 | ] 67 | }, 68 | "execution_count": 9, 69 | "metadata": {}, 70 | "output_type": "execute_result" 71 | } 72 | ], 73 | "source": [ 74 | "X[:3]" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 5, 80 | "metadata": { 81 | "collapsed": false 82 | }, 83 | "outputs": [ 84 | { 85 | "data": { 86 | "text/plain": [ 87 | "(75, 2)" 88 | ] 89 | }, 90 | "execution_count": 5, 91 | "metadata": {}, 92 | "output_type": "execute_result" 93 | } 94 | ], 95 | "source": [ 96 | "X_train.shape" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 8, 102 | "metadata": { 103 | "collapsed": false 104 | }, 105 | "outputs": [ 106 | { 107 | "data": { 108 | "text/plain": [ 109 | "array([[ 6.9, 3.1],\n", 110 | " [ 5.2, 2.7],\n", 111 | " [ 6.1, 2.8],\n", 112 | " [ 6.7, 3. ],\n", 113 | " [ 4.4, 3. ],\n", 114 | " [ 6.7, 3.1],\n", 115 | " [ 5. , 3. ],\n", 116 | " [ 5.4, 3.4],\n", 117 | " [ 4.9, 3. 
],\n", 118 | " [ 5.2, 3.4]])" 119 | ] 120 | }, 121 | "execution_count": 8, 122 | "metadata": {}, 123 | "output_type": "execute_result" 124 | } 125 | ], 126 | "source": [ 127 | "X_train[:10]" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 6, 133 | "metadata": { 134 | "collapsed": false 135 | }, 136 | "outputs": [ 137 | { 138 | "data": { 139 | "text/plain": [ 140 | "" 141 | ] 142 | }, 143 | "execution_count": 6, 144 | "metadata": {}, 145 | "output_type": "execute_result" 146 | }, 147 | { 148 | "data": { 149 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGZ9JREFUeJzt3X9sHOWdx/H394yv8bWAReMWsJNLCihqSSLSugTICXGg\nXkqaQoRQlAiKQhE5ELpS0aNqKtQfqBJISLRQdEQBdBTBBeVoGihHgjgoKkUklZMg5y4pKhxtY8MV\nE5TQHKYE93t/7Dqx12vvzu6O93me/bwky97Zyfj7zMA3m5nPPGPujoiIpOWvml2AiIg0npq7iEiC\n1NxFRBKk5i4ikiA1dxGRBKm5i4gkSM1dRCRBau4iIglScxcRSdBx1a5oZm1AHzDo7stL3rsAeBx4\nvbhos7vfOtX2Zs6c6XPmzMlUrIhIq9u5c+fb7t5Vab2qmztwI7APOGGS918obfpTmTNnDn19fRl+\nvYiImNnvq1mvqtMyZtYDfAm4v56iRERkelR7zv1HwDeBv0yxznlm1m9mW83szHIrmNlaM+szs76h\noaGstYqISJUqNnczWw685e47p1htFzDb3RcCPwa2lFvJ3Te4e6+793Z1VTxlJCIiNarmnPsS4BIz\nWwbMAE4ws4fd/crRFdz93TE/P2Vm/2JmM9397caXLCJSnyNHjjAwMMD777/f7FImNWPGDHp6emhv\nb6/pz1ds7u6+DlgHR1Mx/zy2sReXnwz80d3dzM6m8C+CAzVVJCKSs4GBAY4//njmzJmDmTW7nAnc\nnQMHDjAwMMDcuXNr2kaWtMw4ZnZdsYj1wOXA9Wb2ITAMrHI9BUREAvX+++8H29gBzIyPf/zj1HNt\nMlNzd/fngeeLP68fs/we4J6aqxAJ2Jbdg9zx9Cu8cXCYUzs7uHnpPFYs6m52WVKnUBv7qHrrq/mT\nu0gr2LJ7kHWb9zB8ZASAwYPDrNu8B0ANXoKm6QdEpnDH068cbeyjho+McMfTrzSpIknFtm3bmDdv\nHqeffjq33357w7ev5i4yhTcODmdaLlKNkZERbrjhBrZu3crevXvZuHEje/fubejv0GkZkSmc2tnB\nYJlGfmpnRxOqkWZp9HWXX//615x++ul86lOfAmDVqlU8/vjjfOYzn2lUyfrkLjKVm5fOo6O9bdyy\njvY2bl46r0kVyXQbve4yeHAY59h1ly27B2ve5uDgILNmzTr6uqenh8HB2rdXjpq7yBRWLOrmtssW\n0N3ZgQHdnR3cdtkCXUxtIbFed9FpGZEKVizqVjNvYXlcd+nu7mb//v1HXw8MDNDd3dj/xvTJXURk\nCpNdX6nnusvnP/95fvvb3/L666/zwQcf8Oijj3LJJZfUvL1y1NxFRKaQx3WX4447jnvuuYelS5fy\n6U9/mpUrV3LmmWUn0639dzR0ayIiiRk9Jdfou5S
XLVvGsmXLGlFiWWruIiIVxHjdRadlREQSpOYu\nIpIgNXcRkQSpuYuIJEjNXUQkQWrukowtuwdZcvtzzP3Wf7Dk9ufqmvtDJG9f/epX+cQnPsH8+fNz\n2b6auyQhj8mdRPK0Zs0atm3bltv21dwlCbFO7iSR6N8EP5wP3+ssfO/fVPcmzz//fE466aQGFFee\nbmKSJOihGpKb/k3w86/BkeJ/S4f2F14DLFzZvLoq0Cd3SUIekzuJAPDsrcca+6gjw4XlAVNzlyTo\noRqSm0MD2ZYHQqdlJAl5Te4kwok9hVMx5ZYHTM1dkhHj5E4SgYu+M/6cO0B7R2F5HVavXs3zzz/P\n22+/TU9PD9///ve55ppr6iz2GDV3qVujHx4sEpTRi6bP3lo4FXNiT6Gx13kxdePGjQ0obnJq7lKX\n0Xz5aAxxNF8OqMFLOhauDDoZU44uqEpdlC8XCZOau9RF+XKJlbs3u4Qp1VufmrvURflyidGMGTM4\ncOBAsA3e3Tlw4AAzZsyoeRs65y51uXnpvHHn3EH5cglfT08PAwMDDA0NNbuUSc2YMYOentrjlmru\nUhflyyVG7e3tzJ07t9ll5Krq5m5mbUAfMOjuy0veM+AuYBnwHrDG3Xc1slAJl/LlIuHJ8sn9RmAf\ncEKZ9y4Gzih+LQbuLX4XaSnK/EsoqrqgamY9wJeA+ydZ5VLgIS/YDnSa2SkNqlEkCppTXkJSbVrm\nR8A3gb9M8n43MHbyhYHiMpGWocy/hKRiczez5cBb7r6z3l9mZmvNrM/M+kK+Si1SC2X+JSTVfHJf\nAlxiZr8DHgUuNLOHS9YZBGaNed1TXDaOu29w91537+3q6qqxZJEwKfMvIanY3N19nbv3uPscYBXw\nnLtfWbLaE8BVVnAOcMjd32x8uSLh0pzyEpKac+5mdh2Au68HnqIQg3yVQhTy6oZUJxIRZf4lJNas\n2297e3u9r6+vKb9bRCRWZrbT3Xsrrac7VCVYt2zZw8Yd+xlxp82M1Ytn8YMVC5pdlkgU1NwlSLds\n2cPD2/9w9PWI+9HXavAilWlWSAnSxh1lnlk5xXIRGU/NXYI0Msm1oMmWi8h4au4SpDazTMtFZDw1\ndwnS6sWzMi0XkfF0QVWCNHrRVGkZkdoo5y4iEhHl3KUuV9z3Ei++9s7R10tOO4lHrj23iRU1j+Zo\nlxjpnLtMUNrYAV587R2uuO+lJlXUPJqjXWKl5i4TlDb2SstTpjnaJVZq7iJT0BztEis1d5EpaI52\niZWau0yw5LSTMi1PmeZol1ipucsEj1x77oRG3qppmRWLurntsgV0d3ZgQHdnB7ddtkBpGQmecu4i\nIhFRzl3qkle2O8t2lS8XqZ2au0wwmu0ejQCOZruBupprlu3mVYNIq9A5d5kgr2x3lu0qXy5SHzV3\nmSCvbHeW7SpfLlIfNXeZIK9sd5btKl8uUh81d5kgr2x3lu0qXy5SH11QlQlGL1g2OqmSZbt51SDS\nKpRzFxGJiHLuOYsxgx1jzSJSGzX3GsSYwY6xZhGpnS6o1iDGDHaMNYtI7dTcaxBjBjvGmkWkdmru\nNYgxgx1jzSJSOzX3GsSYwY6xZhGpnS6o1iDGDHaMNYtI7Srm3M1sBvBL4CMU/jJ4zN2/W7LOBcDj\nwOvFRZvd/daptqucu4hIdo3Muf8ZuNDdD5tZO/ArM9vq7ttL1nvB3ZfXUqxMj1u27GHjjv2MuNNm\nxurFs/jBigV1rxtKfj6UOkRCULG5e+Gj/eHiy/biV3Nua5Wa3bJlDw9v/8PR1yPuR1+XNu0s64aS\nnw+lDpFQVHVB1czazOxl4C3gGXffUWa188ys38y2mtmZDa1S6rZxx/6ql2dZN5T8fCh1iISiqubu\n7iPufhbQA5xtZvNLVtkFzHb3hcCPgS3ltmNma82sz8z6hoaG6qlbMhqZ5NpKueVZ1g0lPx9KHSKh\nyBSFdPeDwC+
AL5Ysf9fdDxd/fgpoN7OZZf78Bnfvdfferq6uOsqWrNrMql6eZd1Q8vOh1CESiorN\n3cy6zKyz+HMH8AXgNyXrnGxW+D/fzM4ubvdA48uVWq1ePKvq5VnWDSU/H0odIqGoJi1zCvATM2uj\n0LQ3ufuTZnYdgLuvBy4HrjezD4FhYJU3ay5hKWv0Qmg1CZgs64aSnw+lDpFQaD53EZGIaD73nOWV\nqc6SL89z21nGF+O+iE7/Jnj2Vjg0ACf2wEXfgYUrm12VBEzNvQZ5Zaqz5Mvz3HaW8cW4L6LTvwl+\n/jU4Ukz+HNpfeA1q8DIpTRxWg7wy1Vny5XluO8v4YtwX0Xn21mONfdSR4cJykUmoudcgr0x1lnx5\nntvOMr4Y90V0Dg1kWy6CmntN8spUZ8mX57ntLOOLcV9E58SebMtFUHOvSV6Z6iz58jy3nWV8Me6L\n6Fz0HWgv+cuyvaOwXGQSuqBag7wy1Vny5XluO8v4YtwX0Rm9aKq0jGSgnLuISESUc5cJQsiuS+SU\nt4+GmnuLCCG7LpFT3j4quqDaIkLIrkvklLePipp7iwghuy6RU94+KmruLSKE7LpETnn7qKi5t4gQ\nsusSOeXto6ILqi0ihOy6RE55+6go5y4iEhHl3Ivyymtn2W4o85Irux6Y1DPjqY8viybsi6Sbe155\n7SzbDWVecmXXA5N6Zjz18WXRpH2R9AXVvPLaWbYbyrzkyq4HJvXMeOrjy6JJ+yLp5p5XXjvLdkOZ\nl1zZ9cCknhlPfXxZNGlfJN3c88prZ9luKPOSK7semNQz46mPL4sm7Yukm3teee0s2w1lXnJl1wOT\nemY89fFl0aR9kfQF1bzy2lm2G8q85MquByb1zHjq48uiSftCOXcRkYgo556zEPLzV9z3Ei++9s7R\n10tOO4lHrj237hpEkvLkTbDzQfARsDb43BpYfmf92w08x5/0Ofe8jGbGBw8O4xzLjG/ZPTht2y1t\n7AAvvvYOV9z3Ul01iCTlyZug74FCY4fC974HCsvrMZpdP7Qf8GPZ9f5NdZfcKGruNQghP1/a2Cst\nF2lJOx/MtrxaEeT41dxrEEJ+XkSq4CPZllcrghy/mnsNQsjPi0gVrC3b8mpFkONXc69BCPn5Jaed\nVHYbky0XaUmfW5NtebUiyPGruddgxaJubrtsAd2dHRjQ3dnBbZctaEh+vtrtPnLtuRMaudIyIiWW\n3wm91xz7pG5thdf1pmUWroQv3w0nzgKs8P3LdweVllHOXUQkIg3LuZvZDOCXwEeK6z/m7t8tWceA\nu4BlwHvAGnffVUvhlWTNl8c2h3mWud9T3xe55oizZJ/zqiPP8QWewa5L1rGlvC+mUM1NTH8GLnT3\nw2bWDvzKzLa6+/Yx61wMnFH8WgzcW/zeUFnnJI9tDvMsc7+nvi9ynQN7NPs8ajT7DBMbfF515Dm+\nlOdSzzq2lPdFBRXPuXvB4eLL9uJX6bmcS4GHiutuBzrN7JTGlpo9Xx7bHOZZ5n5PfV/kmiPOkn3O\nq448xxdBBrtmWceW8r6ooKoLqmbWZmYvA28Bz7j7jpJVuoGxHWiguKx0O2vNrM/M+oaGhjIXmzUH\nHltuPMvc76nvi1xzxFmyz3nVkef4Ishg1yzr2FLeFxVU1dzdfcTdzwJ6gLPNbH4tv8zdN7h7r7v3\ndnV1Zf7zWXPgseXGs8z9nvq+yDVHnCX7nFcdeY4vggx2zbKOLeV9UUGmKKS7HwR+AXyx5K1BYOwE\n5T3FZQ2VNV8e2xzmWeZ+T31f5JojzpJ9zquOPMcXQQa7ZlnHlvK+qKBiczezLjPrLP7cAXwB+E3J\nak8AV1nBOcAhd3+z0cVmzZfnlUfPyw9WLODKc2Yf/aTeZsaV58wum5ZJfV/kmiPOkn3Oq448xxdB\nBrtmWceW8r6ooGLO3cwWAj8B2ij8ZbDJ3W81s+sA3H19MQp5D4VP9O8BV7v7l
CF25dxFRLJrWM7d\n3fuBRWWWrx/zswM3ZC1SRETykfzDOqK7cUemR5YbW0K4CSbPG3diu0krhOMRgaSbe3Q37sj0yHJj\nSwg3weR5405sN2mFcDwikfTEYdHduCPTI8uNLSHcBJPnjTux3aQVwvGIRNLNPbobd2R6ZLmxJYSb\nYPK8cSe2m7RCOB6RSLq5R3fjjkyPLDe2hHATTJ437sR2k1YIxyMSSTf36G7ckemR5caWEG6CyfPG\nndhu0grheEQi6eYe3Y07Mj2y3NgSwk0wed64E9tNWiEcj0joYR0iIhFp2E1MIi0vy4M9QhFbzaFk\n10OpowHU3EWmkuXBHqGIreZQsuuh1NEgSZ9zF6lblgd7hCK2mkPJrodSR4OouYtMJcuDPUIRW82h\nZNdDqaNB1NxFppLlwR6hiK3mULLrodTRIGruIlPJ8mCPUMRWcyjZ9VDqaBA1d5GpZHmwRyhiqzmU\n7HoodTSIcu4iIhFRzl2mT4zZ4LxqzitfHuM+lqZSc5f6xJgNzqvmvPLlMe5jaTqdc5f6xJgNzqvm\nvPLlMe5jaTo1d6lPjNngvGrOK18e4z6WplNzl/rEmA3Oq+a88uUx7mNpOjV3qU+M2eC8as4rXx7j\nPpamU3OX+sSYDc6r5rzy5THuY2k65dxFRCJSbc5dn9wlHf2b4Ifz4Xudhe/9m6Z/u3nVIJKRcu6S\nhryy4Fm2qzy6BESf3CUNeWXBs2xXeXQJiJq7pCGvLHiW7SqPLgFRc5c05JUFz7Jd5dElIGrukoa8\nsuBZtqs8ugREzV3SkFcWPMt2lUeXgFTMuZvZLOAh4JOAAxvc/a6SdS4AHgdeLy7a7O5TXkVSzl1E\nJLtGzuf+IfANd99lZscDO83sGXffW7LeC+6+vJZiJUAxzh+epeYYxxcC7bdoVGzu7v4m8Gbx5z+Z\n2T6gGyht7pKKGPPayqPnT/stKpnOuZvZHGARsKPM2+eZWb+ZbTWzMxtQmzRLjHlt5dHzp/0Wlarv\nUDWzjwE/Bb7u7u+WvL0LmO3uh81sGbAFOKPMNtYCawFmz55dc9GSsxjz2sqj50/7LSpVfXI3s3YK\njf0Rd99c+r67v+vuh4s/PwW0m9nMMuttcPded+/t6uqqs3TJTYx5beXR86f9FpWKzd3MDHgA2Ofu\nZecuNbOTi+thZmcXt3ugkYXKNIoxr608ev6036JSzWmZJcBXgD1m9nJx2beB2QDuvh64HLjezD4E\nhoFV3qy5hKV+oxfHYkpFZKk5xvGFQPstKprPXUQkIo3MuUuolDke78mbYOeDhQdSW1vh8Xb1PgVJ\nJFJq7rFS5ni8J2+CvgeOvfaRY6/V4KUFaW6ZWClzPN7OB7MtF0mcmnuslDkez0eyLRdJnJp7rJQ5\nHs/asi0XSZyae6yUOR7vc2uyLRdJnJp7rDR3+HjL74Tea459Ure2wmtdTJUWpZy7iEhElHOvwZbd\ng9zx9Cu8cXCYUzs7uHnpPFYs6m52WY2Tei4+9fGFQPs4GmruRVt2D7Ju8x6GjxTSFYMHh1m3eQ9A\nGg0+9Vx86uMLgfZxVHTOveiOp1852thHDR8Z4Y6nX2lSRQ2Wei4+9fGFQPs4KmruRW8cHM60PDqp\n5+JTH18ItI+jouZedGpnR6bl0Uk9F5/6+EKgfRwVNfeim5fOo6N9/A0vHe1t3Lx0XpMqarDUc/Gp\njy8E2sdR0QXVotGLpsmmZVKfizv18YVA+zgqyrmLiESk2py7TsuIxKB/E/xwPnyvs/C9f1Mc25am\n0WkZkdDlmS9Xdj1Z+uQuEro88+XKridLzV0kdHnmy5VdT5aau0jo8syXK7ueLDV3kdDlmS9Xdj1Z\nau4ioctz7n49FyBZyrmLiEREOXcRkRam5i4ikiA1dxGRBKm5i4gkSM1dRCRBau4iIglScxcRSZCa\nu4hIgio2dzObZWa/MLO9ZvbfZnZjmXXMz
O42s1fNrN/MPptPuVIXzdst0jKqmc/9Q+Ab7r7LzI4H\ndprZM+6+d8w6FwNnFL8WA/cWv0soNG+3SEup+Mnd3d90913Fn/8E7ANKHyx6KfCQF2wHOs3slIZX\nK7XTvN0iLSXTOXczmwMsAnaUvNUN7B/zeoCJfwFgZmvNrM/M+oaGhrJVKvXRvN0iLaXq5m5mHwN+\nCnzd3d+t5Ze5+wZ373X33q6urlo2IbXSvN0iLaWq5m5m7RQa+yPuvrnMKoPArDGve4rLJBSat1uk\npVSTljHgAWCfu985yWpPAFcVUzPnAIfc/c0G1in10rzdIi2lmrTMEuArwB4ze7m47NvAbAB3Xw88\nBSwDXgXeA65ufKlSt4Ur1cxFWkTF5u7uvwKswjoO3NCookREpD66Q1VEJEFq7iIiCVJzFxFJkJq7\niEiC1NxFRBKk5i4ikiA1dxGRBFkhot6EX2w2BPy+Kb+8spnA280uIkcaX7xSHhtofNX4W3evODlX\n05p7yMysz917m11HXjS+eKU8NtD4GkmnZUREEqTmLiKSIDX38jY0u4CcaXzxSnlsoPE1jM65i4gk\nSJ/cRUQS1NLN3czazGy3mT1Z5r0LzOyQmb1c/IrqkUVm9jsz21Osva/M+2Zmd5vZq2bWb2afbUad\ntapifLEfv04ze8zMfmNm+8zs3JL3Yz9+lcYX7fEzs3lj6n7ZzN41s6+XrJP78avmYR0puxHYB5ww\nyfsvuPvyaayn0f7e3SfL1F4MnFH8WgzcW/wek6nGB3Efv7uAbe5+uZn9NfA3Je/HfvwqjQ8iPX7u\n/gpwFhQ+QFJ45OjPSlbL/fi17Cd3M+sBvgTc3+xamuRS4CEv2A50mtkpzS5KwMxOBM6n8HhL3P0D\ndz9Yslq0x6/K8aXiIuA1dy+9YTP349eyzR34EfBN4C9TrHNe8Z9MW83szGmqq1Ec+E8z22lma8u8\n3w3sH/N6oLgsFpXGB/Eev7nAEPCvxdOG95vZR0vWifn4VTM+iPf4jbUK2Fhmee7HryWbu5ktB95y\n951TrLYLmO3uC4EfA1umpbjG+Tt3P4vCP/9uMLPzm11Qg1UaX8zH7zjgs8C97r4I+D/gW80tqaGq\nGV/Mxw+A4ummS4B/b8bvb8nmTuGh35eY2e+AR4ELzezhsSu4+7vufrj481NAu5nNnPZKa+Tug8Xv\nb1E433d2ySqDwKwxr3uKy6JQaXyRH78BYMDddxRfP0ahGY4V8/GrOL7Ij9+oi4Fd7v7HMu/lfvxa\nsrm7+zp373H3ORT+2fScu185dh0zO9nMrPjz2RT21YFpL7YGZvZRMzt+9GfgH4D/KlntCeCq4lX7\nc4BD7v7mNJdak2rGF/Pxc/f/Bfab2bzioouAvSWrRXv8qhlfzMdvjNWUPyUD03D8Wj0tM46ZXQfg\n7uuBy4HrzexDYBhY5fHc8fVJ4GfF/zeOA/7N3beVjO8pYBnwKvAecHWTaq1FNeOL+fgB/BPwSPGf\n9v8DXJ3Q8YPK44v6+BU/dHwB+Mcxy6b1+OkOVRGRBLXkaRkRkdSpuYuIJEjNXUQkQWruIiIJUnMX\nEUmQmruISILU3EVEEqTmLiKSoP8H2fNC9uxjMHwAAAAASUVORK5CYII=\n", 150 | "text/plain": [ 151 | "" 152 | ] 153 | }, 154 | "metadata": {}, 155 | "output_type": "display_data" 156 | } 157 | ], 158 | "source": [ 159 | "plt.scatter(X[:50,0],X[:50,1], label='0')\n", 160 | "plt.scatter(X[50:,0],X[50:,1], label='1')\n", 161 | "plt.legend()" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | 
"execution_count": 7, 167 | "metadata": { 168 | "collapsed": true 169 | }, 170 | "outputs": [], 171 | "source": [ 172 | "class SVM:\n", 173 | " def __init__(self, max_iter=100, kernel='linear'):\n", 174 | " self.max_iter = max_iter\n", 175 | " self._kernel = kernel\n", 176 | " \n", 177 | " def init_args(self, features, labels):\n", 178 | " self.m, self.n = features.shape\n", 179 | " self.X = features\n", 180 | " self.Y = labels\n", 181 | " self.b = 0.0\n", 182 | " \n", 183 | " # initialize alpha and store the E_i values in a list\n", 184 | " self.alpha = np.ones(self.m)\n", 185 | " self.E = [self._E(i) for i in range(self.m)]\n", 186 | " # penalty parameter C\n", 187 | " self.C = 1.0\n", 188 | " \n", 189 | " def _KKT(self, i):\n", 190 | " y_g = self._g(i)*self.Y[i]\n", 191 | " if self.alpha[i] == 0:\n", 192 | " return y_g >= 1\n", 193 | " elif 0 < self.alpha[i] < self.C:\n", 194 | " return y_g == 1\n", 195 | " else:\n", 196 | " return y_g <= 1\n", 197 | " \n", 198 | " # g(x): prediction for input x_i (X[i])\n", 199 | " def _g(self, i):\n", 200 | " r = self.b\n", 201 | " for j in range(self.m):\n", 202 | " r += self.alpha[j]*self.Y[j]*self.kernel(self.X[i], self.X[j])\n", 203 | " return r\n", 204 | " \n", 205 | " # kernel function\n", 206 | " def kernel(self, x1, x2):\n", 207 | " if self._kernel == 'linear':\n", 208 | " return sum([x1[k]*x2[k] for k in range(self.n)])\n", 209 | " elif self._kernel == 'poly':\n", 210 | " return (sum([x1[k]*x2[k] for k in range(self.n)]) + 1)**2\n", 211 | " \n", 212 | " return 0\n", 213 | " \n", 214 | " # E(x): difference between the prediction g(x) and the label y\n", 215 | " def _E(self, i):\n", 216 | " return self._g(i) - self.Y[i]\n", 217 | " \n", 218 | " def _init_alpha(self):\n", 219 | " # outer loop: first go through the samples with 0 < alpha < C and check the KKT conditions\n        index_list = [i for i in range(self.m) if 0 < self.alpha[i] < self.C]\n        # otherwise fall back to the whole training set\n        non_satisfy_list = [i for i in range(self.m) if i not in index_list]\n        index_list.extend(non_satisfy_list)\n\n        for i in index_list:\n            if self._KKT(i):\n                continue\n\n            E1 = self.E[i]\n            # if E1 is positive, choose the smallest E as E2; if negative, choose the largest\n            if E1 >= 0:\n", 232 | " j = min(range(self.m), key=lambda x: self.E[x])\n", 233 | " else:\n", 234 | " j = max(range(self.m), key=lambda x: self.E[x])\n", 235 | " return i, j\n", 236 | " \n", 237 | " def _compare(self, _alpha, L, H):\n", 238 | " if _alpha > H:\n", 239 | " return H\n", 240 | " elif _alpha < L:\n", 241 | " return L\n", 242 | " else:\n", 243 | " 
return _alpha \n", 244 | " \n", 245 | " def fit(self, features, labels):\n", 246 | " self.init_args(features, labels)\n", 247 | " \n", 248 | " for t in range(self.max_iter):\n", 249 | " # train\n", 250 | " i1, i2 = self._init_alpha()\n", 251 | " \n", 252 | " # 边界\n", 253 | " if self.Y[i1] == self.Y[i2]:\n", 254 | " L = max(0, self.alpha[i1]+self.alpha[i2]-self.C)\n", 255 | " H = min(self.C, self.alpha[i1]+self.alpha[i2])\n", 256 | " else:\n", 257 | " L = max(0, self.alpha[i2]-self.alpha[i1])\n", 258 | " H = min(self.C, self.C+self.alpha[i2]-self.alpha[i1])\n", 259 | " \n", 260 | " E1 = self.E[i1]\n", 261 | " E2 = self.E[i2]\n", 262 | " # eta=K11+K22-2K12\n", 263 | " eta = self.kernel(self.X[i1], self.X[i1]) + self.kernel(self.X[i2], self.X[i2]) - 2*self.kernel(self.X[i1], self.X[i2])\n", 264 | " if eta <= 0:\n", 265 | " # print('eta <= 0')\n", 266 | " continue\n", 267 | " \n", 268 | " alpha2_new_unc = self.alpha[i2] + self.Y[i2] * (E2 - E1) / eta\n", 269 | " alpha2_new = self._compare(alpha2_new_unc, L, H)\n", 270 | " \n", 271 | " alpha1_new = self.alpha[i1] + self.Y[i1] * self.Y[i2] * (self.alpha[i2] - alpha2_new)\n", 272 | " \n", 273 | " b1_new = -E1 - self.Y[i1] * self.kernel(self.X[i1], self.X[i1]) * (alpha1_new-self.alpha[i1]) - self.Y[i2] * self.kernel(self.X[i2], self.X[i1]) * (alpha2_new-self.alpha[i2])+ self.b \n", 274 | " b2_new = -E2 - self.Y[i1] * self.kernel(self.X[i1], self.X[i2]) * (alpha1_new-self.alpha[i1]) - self.Y[i2] * self.kernel(self.X[i2], self.X[i2]) * (alpha2_new-self.alpha[i2])+ self.b \n", 275 | " \n", 276 | " if 0 < alpha1_new < self.C:\n", 277 | " b_new = b1_new\n", 278 | " elif 0 < alpha2_new < self.C:\n", 279 | " b_new = b2_new\n", 280 | " else:\n", 281 | " # 选择中点\n", 282 | " b_new = (b1_new + b2_new) / 2\n", 283 | " \n", 284 | " # 更新参数\n", 285 | " self.alpha[i1] = alpha1_new\n", 286 | " self.alpha[i2] = alpha2_new\n", 287 | " self.b = b_new\n", 288 | " \n", 289 | " self.E[i1] = self._E(i1)\n", 290 | " self.E[i2] = self._E(i2)\n", 291 
| " return 'train done!'\n", 292 | " \n", 293 | " def predict(self, data):\n", 294 | " r = self.b\n", 295 | " for i in range(self.m):\n", 296 | " r += self.alpha[i] * self.Y[i] * self.kernel(data, self.X[i])\n", 297 | " \n", 298 | " return 1 if r > 0 else -1\n", 299 | " \n", 300 | " def score(self, X_test, y_test):\n", 301 | " right_count = 0\n", 302 | " for i in range(len(X_test)):\n", 303 | " result = self.predict(X_test[i])\n", 304 | " if result == y_test[i]:\n", 305 | " right_count += 1\n", 306 | " return right_count / len(X_test)\n", 307 | " \n", 308 | " def _weight(self):\n", 309 | " # linear model\n", 310 | " yx = self.Y.reshape(-1, 1)*self.X\n", 311 | " self.w = np.dot(yx.T, self.alpha)\n", 312 | " return self.w" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": 10, 318 | "metadata": { 319 | "collapsed": true 320 | }, 321 | "outputs": [], 322 | "source": [ 323 | "svm = SVM(max_iter=200)" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 11, 329 | "metadata": { 330 | "collapsed": false 331 | }, 332 | "outputs": [ 333 | { 334 | "data": { 335 | "text/plain": [ 336 | "'train done!'" 337 | ] 338 | }, 339 | "execution_count": 11, 340 | "metadata": {}, 341 | "output_type": "execute_result" 342 | } 343 | ], 344 | "source": [ 345 | "svm.fit(X_train, y_train)" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": 12, 351 | "metadata": { 352 | "collapsed": false 353 | }, 354 | "outputs": [ 355 | { 356 | "data": { 357 | "text/plain": [ 358 | "0.88" 359 | ] 360 | }, 361 | "execution_count": 12, 362 | "metadata": {}, 363 | "output_type": "execute_result" 364 | } 365 | ], 366 | "source": [ 367 | "svm.score(X_test, y_test)" 368 | ] 369 | }, 370 | { 371 | "cell_type": "markdown", 372 | "metadata": {}, 373 | "source": [ 374 | "初始化" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": 14, 380 | "metadata": { 381 | "collapsed": true 382 | }, 383 | "outputs": [], 384 | "source": 
[ 385 | "max_iter = 200\n", 386 | "_kernel = 'linear'" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 16, 392 | "metadata": { 393 | "collapsed": true 394 | }, 395 | "outputs": [], 396 | "source": [ 397 | "def init_args(self, features, labels):\n", 398 | "    m, n = features.shape\n", 399 | "    X = features\n", 400 | "    Y = labels\n", 401 | "    b = 0.0\n", 402 | "    \n", 403 | "    # keep the error terms E_i in a list\n", 404 | "    alpha = np.ones(m)\n", 405 | "    E = [_E(i) for i in range(m)]\n", 406 | "    # penalty parameter for the slack variables\n", 407 | "    C = 1.0" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": 17, 413 | "metadata": { 414 | "collapsed": false 415 | }, 416 | "outputs": [ 417 | { 418 | "name": "stdout", 419 | "output_type": "stream", 420 | "text": [ 421 | "75 2\n" 422 | ] 423 | } 424 | ], 425 | "source": [ 426 | "features = X_train\n", 427 | "labels = y_train\n", 428 | "m, n = features.shape\n", 429 | "print(m, n)" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": 18, 435 | "metadata": { 436 | "collapsed": true 437 | }, 438 | "outputs": [], 439 | "source": [ 440 | "b = 0.0" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": 19, 446 | "metadata": { 447 | "collapsed": false 448 | }, 449 | "outputs": [ 450 | { 451 | "data": { 452 | "text/plain": [ 453 | "array([ 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", 454 | " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", 455 | " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", 456 | " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", 457 | " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", 458 | " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])" 459 | ] 460 | }, 461 | "execution_count": 19, 462 | "metadata": {}, 463 | "output_type": "execute_result" 464 | } 465 | ], 466 | "source": [ 467 | "alpha = np.ones(m)\n", 468 | "alpha" 469 | ] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "execution_count": null, 474 | "metadata": { 475 | 
"collapsed": true 476 | }, 477 | "outputs": [], 478 | "source": [ 479 | "# E(x)为g(x)对输入x的预测值和y的差\n", 480 | "def _E(self, i):\n", 481 | " return self._g(i) - self.Y[i]\n" 482 | ] 483 | }, 484 | { 485 | "cell_type": "code", 486 | "execution_count": 20, 487 | "metadata": { 488 | "collapsed": false 489 | }, 490 | "outputs": [ 491 | { 492 | "ename": "NameError", 493 | "evalue": "name '_E' is not defined", 494 | "output_type": "error", 495 | "traceback": [ 496 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 497 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 498 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mE\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0m_E\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 499 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mE\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0m_E\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 500 | "\u001b[0;31mNameError\u001b[0m: name '_E' is not defined" 501 | ] 502 | } 503 | ], 504 | "source": [ 505 | "E = [_E(i) for i in range(m)]\n" 506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "execution_count": null, 511 | "metadata": { 512 | "collapsed": true 513 | }, 514 | "outputs": [], 515 | "source": [] 516 | } 517 | ], 518 | "metadata": { 519 | "kernelspec": { 520 | "display_name": "Python 
[py35]", 521 | "language": "python", 522 | "name": "Python [py35]" 523 | }, 524 | "language_info": { 525 | "codemirror_mode": { 526 | "name": "ipython", 527 | "version": 3 528 | }, 529 | "file_extension": ".py", 530 | "mimetype": "text/x-python", 531 | "name": "python", 532 | "nbconvert_exporter": "python", 533 | "pygments_lexer": "ipython3", 534 | "version": "3.5.2" 535 | } 536 | }, 537 | "nbformat": 4, 538 | "nbformat_minor": 0 539 | } 540 | -------------------------------------------------------------------------------- /notebooks/7-adaboost.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "import pandas as pd\n", 13 | "from sklearn.datasets import load_iris\n", 14 | "from sklearn.model_selection import train_test_split\n", 15 | "import matplotlib.pyplot as plt\n", 16 | "%matplotlib inline" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": { 23 | "collapsed": true 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "# data\n", 28 | "def create_data():\n", 29 | " iris = load_iris()\n", 30 | " df = pd.DataFrame(iris.data, columns=iris.feature_names)\n", 31 | " df['label'] = iris.target\n", 32 | " df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n", 33 | " data = np.array(df.iloc[:100, [0, 1, -1]])\n", 34 | " for i in range(len(data)):\n", 35 | " if data[i,-1] == 0:\n", 36 | " data[i,-1] = -1\n", 37 | " # print(data)\n", 38 | " return data[:,:2], data[:,-1]" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 8, 44 | "metadata": { 45 | "collapsed": true 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "X, y = create_data()\n", 50 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)" 51 | ] 52 | }, 53 | { 54 | 
"cell_type": "code", 55 | "execution_count": 4, 56 | "metadata": { 57 | "collapsed": false 58 | }, 59 | "outputs": [ 60 | { 61 | "data": { 62 | "text/plain": [ 63 | "(100, 2)" 64 | ] 65 | }, 66 | "execution_count": 4, 67 | "metadata": {}, 68 | "output_type": "execute_result" 69 | } 70 | ], 71 | "source": [ 72 | "X.shape" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 9, 78 | "metadata": { 79 | "collapsed": false 80 | }, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "array([[ 5.1, 3.5],\n", 86 | " [ 4.9, 3. ],\n", 87 | " [ 4.7, 3.2],\n", 88 | " [ 4.6, 3.1],\n", 89 | " [ 5. , 3.6],\n", 90 | " [ 5.4, 3.9],\n", 91 | " [ 4.6, 3.4],\n", 92 | " [ 5. , 3.4],\n", 93 | " [ 4.4, 2.9],\n", 94 | " [ 4.9, 3.1]])" 95 | ] 96 | }, 97 | "execution_count": 9, 98 | "metadata": {}, 99 | "output_type": "execute_result" 100 | } 101 | ], 102 | "source": [ 103 | "X[:10]" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 10, 109 | "metadata": { 110 | "collapsed": true 111 | }, 112 | "outputs": [], 113 | "source": [ 114 | "X = np.arange(10).reshape(10, 1)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 11, 120 | "metadata": { 121 | "collapsed": false 122 | }, 123 | "outputs": [ 124 | { 125 | "data": { 126 | "text/plain": [ 127 | "(10, 1)" 128 | ] 129 | }, 130 | "execution_count": 11, 131 | "metadata": {}, 132 | "output_type": "execute_result" 133 | } 134 | ], 135 | "source": [ 136 | "X.shape" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 12, 142 | "metadata": { 143 | "collapsed": false 144 | }, 145 | "outputs": [ 146 | { 147 | "data": { 148 | "text/plain": [ 149 | "array([[0],\n", 150 | " [1],\n", 151 | " [2],\n", 152 | " [3],\n", 153 | " [4],\n", 154 | " [5],\n", 155 | " [6],\n", 156 | " [7],\n", 157 | " [8],\n", 158 | " [9]])" 159 | ] 160 | }, 161 | "execution_count": 12, 162 | "metadata": {}, 163 | "output_type": "execute_result" 164 | } 165 | ], 166 | 
"source": [ 167 | "X[:10]" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 13, 173 | "metadata": { 174 | "collapsed": true 175 | }, 176 | "outputs": [], 177 | "source": [ 178 | "y = np.array([1, 1, 1, -1, -1, -1, 1, 1, 1, -1])" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 16, 184 | "metadata": { 185 | "collapsed": false 186 | }, 187 | "outputs": [ 188 | { 189 | "data": { 190 | "text/plain": [ 191 | "array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])" 192 | ] 193 | }, 194 | "execution_count": 16, 195 | "metadata": {}, 196 | "output_type": "execute_result" 197 | } 198 | ], 199 | "source": [ 200 | "features = X[:, 0]\n", 201 | "features" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 19, 207 | "metadata": { 208 | "collapsed": true 209 | }, 210 | "outputs": [], 211 | "source": [ 212 | "learning_rate = 0.5" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 20, 218 | "metadata": { 219 | "collapsed": false 220 | }, 221 | "outputs": [], 222 | "source": [ 223 | "# 单维features\n", 224 | "features_min = min(features)\n", 225 | "features_max = max(features)\n", 226 | "n_step = (features_max - features_min + learning_rate) // learning_rate" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 23, 232 | "metadata": { 233 | "collapsed": false 234 | }, 235 | "outputs": [ 236 | { 237 | "name": "stdout", 238 | "output_type": "stream", 239 | "text": [ 240 | "0 \n", 241 | " 9 \n", 242 | " 19.0\n" 243 | ] 244 | } 245 | ], 246 | "source": [ 247 | "print(features_min, '\\n', features_max,'\\n', n_step)" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": 24, 253 | "metadata": { 254 | "collapsed": false 255 | }, 256 | "outputs": [ 257 | { 258 | "data": { 259 | "text/plain": [ 260 | "19" 261 | ] 262 | }, 263 | "execution_count": 24, 264 | "metadata": {}, 265 | "output_type": "execute_result" 266 | } 267 | ], 268 | "source": [ 269 | "95 // 
5 # floor division (quotient)" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 25, 275 | "metadata": { 276 | "collapsed": false 277 | }, 278 | "outputs": [ 279 | { 280 | "data": { 281 | "text/plain": [ 282 | "0" 283 | ] 284 | }, 285 | "execution_count": 25, 286 | "metadata": {}, 287 | "output_type": "execute_result" 288 | } 289 | ], 290 | "source": [ 291 | "95 % 5 # remainder (modulo)" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 26, 297 | "metadata": { 298 | "collapsed": false 299 | }, 300 | "outputs": [ 301 | { 302 | "data": { 303 | "text/plain": [ 304 | "19.0" 305 | ] 306 | }, 307 | "execution_count": 26, 308 | "metadata": {}, 309 | "output_type": "execute_result" 310 | } 311 | ], 312 | "source": [ 313 | "95 / 5 " 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "metadata": {}, 319 | "source": [ 320 | "`_G(self, features, labels, weights)`: for a single feature, return the best threshold and compute the resulting classification" 321 | ] 322 | } 323 | ], 324 | "metadata": { 325 | "anaconda-cloud": {}, 326 | "kernelspec": { 327 | "display_name": "Python [py35]", 328 | "language": "python", 329 | "name": "Python [py35]" 330 | }, 331 | "language_info": { 332 | "codemirror_mode": { 333 | "name": "ipython", 334 | "version": 3 335 | }, 336 | "file_extension": ".py", 337 | "mimetype": "text/x-python", 338 | "name": "python", 339 | "nbconvert_exporter": "python", 340 | "pygments_lexer": "ipython3", 341 | "version": "3.5.2" 342 | } 343 | }, 344 | "nbformat": 4, 345 | "nbformat_minor": 0 346 | } 347 | --------------------------------------------------------------------------------
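The last markdown cell of 7-adaboost.ipynb describes the contract of `_G(self, features, labels, weights)`: for one feature, find the threshold with the smallest weighted classification error. A minimal stand-alone sketch of that idea, using the toy arrays and `learning_rate` step scheme prepared in the cells above — the function name `best_stump` and the `'positive'`/`'negative'` direction flags are illustrative assumptions, not code from the notebook:

```python
import numpy as np

def best_stump(features, labels, weights, learning_rate=0.5):
    # Hypothetical sketch of the _G step: scan candidate thresholds over a
    # single feature and keep the one with the smallest weighted error.
    f_min, f_max = features.min(), features.max()
    n_step = int((f_max - f_min + learning_rate) // learning_rate)

    best = (np.inf, None, None, None)  # (error, threshold, direction, predictions)
    for i in range(1, n_step):
        v = f_min + learning_rate * i
        for direct in ('positive', 'negative'):
            # 'positive': predict +1 when feature > v; 'negative': the opposite
            pred = np.where(features > v, 1, -1)
            if direct == 'negative':
                pred = -pred
            # weighted misclassification error of this stump
            error = weights[pred != labels].sum()
            if error < best[0]:
                best = (error, v, direct, pred)
    return best
```

On the toy data above (`features = np.arange(10)`, labels `[1,1,1,-1,-1,-1,1,1,1,-1]`, uniform weights of 0.1), the minimum weighted error is 0.3, with three misclassified points — the best any single threshold can do on this sequence.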