├── .DS_Store ├── LICENSE ├── README.md ├── SLM.jpg ├── 第01章 统计学习方法概论 ├── 20190522135945.jpg └── least_sqaure_method.ipynb ├── 第02章 感知机 └── perceptron.ipynb ├── 第03章 k近邻法 └── KNN.ipynb ├── 第04章 朴素贝叶斯法 └── NaiveBayes.ipynb ├── 第05章 决策树 └── DT.ipynb ├── 第06章 逻辑斯蒂回归与最大熵模型 └── LR.ipynb ├── 第07章 支持向量机 ├── .DS_Store └── SVM.ipynb ├── 第08章 提升方法 ├── .DS_Store └── Adaboost.ipynb ├── 第09章 EM算法及其推广 └── EM.ipynb ├── 第10章 隐马尔可夫模型 └── HMM.ipynb ├── 第11章 条件随机场 └── CRF.ipynb ├── 第12章 监督学习方法总结 └── Summary_of_Supervised_Learning_Methods.ipynb ├── 第13章 无监督学习概论 └── Introduction_to_Unsupervised_Learning.ipynb ├── 第14章 聚类方法 └── Clustering.ipynb ├── 第15章 奇异值分解 └── SVD.ipynb ├── 第16章 主成分分析 └── PCA.ipynb ├── 第17章 潜在语义分析 └── LSA.ipynb ├── 第18章 概率潜在语义分析 └── PLSA.ipynb ├── 第19章 马尔可夫链蒙特卡洛法 └── MCMC.ipynb ├── 第20章 潜在狄利克雷分配 └── LDA.ipynb └── 第21章 PageRank算法 └── PageRank.ipynb /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hktxt/Learn-Statistical-Learning-Method/14a3b65db3ff35c42743a1062a2a95cecbfad094/.DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Max 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learn-Statistical-Learning-Method, Second Edition 2 | ![alt text](SLM.jpg) 3 | Implementation of Statistical Learning Method 4 | 《统计学习方法》第二版,算法实现。 5 | 6 | 7 | 第1章:统计学习方法概论 [least_sqaure_method.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC01%E7%AB%A0%20%E7%BB%9F%E8%AE%A1%E5%AD%A6%E4%B9%A0%E6%96%B9%E6%B3%95%E6%A6%82%E8%AE%BA/least_sqaure_method.ipynb) 8 | 第2章:感知机 [perceptron.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC02%E7%AB%A0%20%E6%84%9F%E7%9F%A5%E6%9C%BA/perceptron.ipynb) 9 | 第3章:k近邻法 [KNN.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC03%E7%AB%A0%20k%E8%BF%91%E9%82%BB%E6%B3%95/KNN.ipynb) 10 | 第4章:朴素贝叶斯法 [NaiveBayes.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC04%E7%AB%A0%20%E6%9C%B4%E7%B4%A0%E8%B4%9D%E5%8F%B6%E6%96%AF%E6%B3%95/NaiveBayes.ipynb) 11 | 第5章:决策树 [DT.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC05%E7%AB%A0%20%E5%86%B3%E7%AD%96%E6%A0%91/DT.ipynb) 12 | 第6章:逻辑斯蒂回归与最大熵模型 
[LR.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC06%E7%AB%A0%20%E9%80%BB%E8%BE%91%E6%96%AF%E8%92%82%E5%9B%9E%E5%BD%92%E4%B8%8E%E6%9C%80%E5%A4%A7%E7%86%B5%E6%A8%A1%E5%9E%8B/LR.ipynb) 13 | 第7章:支持向量机 [SVM.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC07%E7%AB%A0%20%E6%94%AF%E6%8C%81%E5%90%91%E9%87%8F%E6%9C%BA/SVM.ipynb) 14 | 第8章:提升方法 [Adaboost.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC08%E7%AB%A0%20%E6%8F%90%E5%8D%87%E6%96%B9%E6%B3%95/Adaboost.ipynb) 15 | 第9章:EM算法及其推广 [EM.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC09%E7%AB%A0%20EM%E7%AE%97%E6%B3%95%E5%8F%8A%E5%85%B6%E6%8E%A8%E5%B9%BF/EM.ipynb) 16 | 第10章:隐马尔可夫模型 [HMM.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC10%E7%AB%A0%20%E9%9A%90%E9%A9%AC%E5%B0%94%E5%8F%AF%E5%A4%AB%E6%A8%A1%E5%9E%8B/HMM.ipynb) 17 | 第11章:条件随机场 [CRF.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC11%E7%AB%A0%20%E6%9D%A1%E4%BB%B6%E9%9A%8F%E6%9C%BA%E5%9C%BA/CRF.ipynb) 18 | 第12章: 监督学习方法总结 [Summary_of_Supervised_Learning_Methods.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC12%E7%AB%A0%20%E7%9B%91%E7%9D%A3%E5%AD%A6%E4%B9%A0%E6%96%B9%E6%B3%95%E6%80%BB%E7%BB%93/Summary_of_Supervised_Learning_Methods.ipynb) 19 | 第13章:无监督学习概论 [Introduction_to_Unsupervised_Learning.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC13%E7%AB%A0%20%E6%97%A0%E7%9B%91%E7%9D%A3%E5%AD%A6%E4%B9%A0%E6%A6%82%E8%AE%BA/Introduction_to_Unsupervised_Learning.ipynb) 20 | 第14章:聚类方法 
[Clustering.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC14%E7%AB%A0%20%E8%81%9A%E7%B1%BB%E6%96%B9%E6%B3%95/Clustering.ipynb) 21 | 第15章:奇异值分解 [SVD.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC15%E7%AB%A0%20%E5%A5%87%E5%BC%82%E5%80%BC%E5%88%86%E8%A7%A3/SVD.ipynb) 22 | 第16章:主成分分析 [PCA.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC16%E7%AB%A0%20%E4%B8%BB%E6%88%90%E5%88%86%E5%88%86%E6%9E%90/PCA.ipynb) 23 | 第17章:潜在语义分析 [LSA.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC17%E7%AB%A0%20%E6%BD%9C%E5%9C%A8%E8%AF%AD%E4%B9%89%E5%88%86%E6%9E%90/LSA.ipynb) 24 | 第18章:概率潜在语义分析 [PLSA.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC18%E7%AB%A0%20%E6%A6%82%E7%8E%87%E6%BD%9C%E5%9C%A8%E8%AF%AD%E4%B9%89%E5%88%86%E6%9E%90/PLSA.ipynb) 25 | 第19章:马尔可夫链蒙特卡洛法 [MCMC.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC19%E7%AB%A0%20%E9%A9%AC%E5%B0%94%E5%8F%AF%E5%A4%AB%E9%93%BE%E8%92%99%E7%89%B9%E5%8D%A1%E6%B4%9B%E6%B3%95/MCMC.ipynb) 26 | 第20章:潜在狄利克雷分配 [LDA.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC20%E7%AB%A0%20%E6%BD%9C%E5%9C%A8%E7%8B%84%E5%88%A9%E5%85%8B%E9%9B%B7%E5%88%86%E9%85%8D/LDA.ipynb) 27 | 第21章:PageRank算法 [PageRank.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC21%E7%AB%A0%20PageRank%E7%AE%97%E6%B3%95/PageRank.ipynb) 28 | 29 | 30 | 31 | ## acknowledgment 32 | 33 | At present, this is still an incomplete project. For some algorithms, I am still ignorant, just followed the math equations to implement. 
Some algorithms I reproduced independently; for others I referred to online resources, and the specific links are noted in each file. I will keep updating this project until I have mastered all the algorithms in the book. 34 | 35 | 36 | -------------------------------------------------------------------------------- /SLM.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hktxt/Learn-Statistical-Learning-Method/14a3b65db3ff35c42743a1062a2a95cecbfad094/SLM.jpg -------------------------------------------------------------------------------- /第01章 统计学习方法概论/20190522135945.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hktxt/Learn-Statistical-Learning-Method/14a3b65db3ff35c42743a1062a2a95cecbfad094/第01章 统计学习方法概论/20190522135945.jpg -------------------------------------------------------------------------------- /第04章 朴素贝叶斯法/NaiveBayes.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "NaiveBayes.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "language_info": { 12 | "codemirror_mode": { 13 | "name": "ipython", 14 | "version": 3 15 | }, 16 | "file_extension": ".py", 17 | "mimetype": "text/x-python", 18 | "name": "python", 19 | "nbconvert_exporter": "python", 20 | "pygments_lexer": "ipython3", 21 | "version": "3.6.2" 22 | }, 23 | "kernelspec": { 24 | "display_name": "Python 3", 25 | "language": "python", 26 | "name": "python3" 27 | } 28 | }, 29 | "cells": [ 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "id": "WDqA-VKvfWix", 34 | "colab_type": "text" 35 | }, 36 | "source": [ 37 | "# 第4章 朴素贝叶斯" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": { 43 | "id": "rMWMEdyUfWix", 44 | "colab_type": "text" 45 | }, 46 |
"source": [ 47 | "基于贝叶斯定理与特征条件独立假设的分类方法。\n", 48 | "\n", 49 | "模型:\n", 50 | "\n", 51 | "- 高斯模型\n", 52 | "- 多项式模型\n", 53 | "- 伯努利模型" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "metadata": { 59 | "id": "mahnF7NFfWiy", 60 | "colab_type": "code", 61 | "colab": {} 62 | }, 63 | "source": [ 64 | "import numpy as np\n", 65 | "import pandas as pd\n", 66 | "import matplotlib.pyplot as plt\n", 67 | "%matplotlib inline\n", 68 | "\n", 69 | "from sklearn.datasets import load_iris\n", 70 | "from sklearn.model_selection import train_test_split\n", 71 | "\n", 72 | "from collections import Counter\n", 73 | "import math" 74 | ], 75 | "execution_count": 0, 76 | "outputs": [] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "metadata": { 81 | "id": "6tRQt9QFf27Y", 82 | "colab_type": "code", 83 | "colab": {} 84 | }, 85 | "source": [ 86 | "# 例 4.1 \n", 87 | "lambda_ = 0.2\n", 88 | "x = [2, 'S']\n", 89 | "\n", 90 | "X1 = [1,2,3]\n", 91 | "X2 = ['S', 'M', 'L']\n", 92 | "Y = [1, -1]" 93 | ], 94 | "execution_count": 0, 95 | "outputs": [] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": { 100 | "id": "CCzHK_v2hafm", 101 | "colab_type": "text" 102 | }, 103 | "source": [ 104 | "$P_\\lambda(Y=1)=(9+lambda\\_)/(15 + 2*lambda\\_) = (9+0.2)/(15+2*0.2)=0.5974025974025974$\n", 105 | "$P_\\lambda(Y=-1)=(6+lambda\\_)/(15 + 2*lambda\\_) = (6+0.2)/(15+2*0.2)=0.40259740259740264$ \n", 106 | "$P(X^{(1)}=1|Y=1) = (2+0.2)/(9+3*0.2)=0.22916666666666669 $ \n", 107 | "$P(X^{(1)}=2|Y=1) = (3+0.2)/(9+3*0.2)=0.33333333333333337 $ \n", 108 | "$P(X^{(1)}=3|Y=1) = (4+0.2)/(9+3*0.2)=0.43750000000000006 $ \n", 109 | "$P(X^{(2)}=S|Y=1) = (1+0.2)/(9+3*0.2)=0.125 $ \n", 110 | "$P(X^{(2)}=M|Y=1) = (4+0.2)/(9+3*0.2)=0.43750000000000006 $ \n", 111 | "$P(X^{(2)}=L|Y=1) = (4+0.2)/(9+3*0.2)=0.43750000000000006 $ \n", 112 | "$P(X^{(1)}=1|Y=-1) = (3+0.2)/(6+3*0.2)=0.4848484848484849 $ \n", 113 | "$P(X^{(1)}=2|Y=-1) = (2+0.2)/(6+3*0.2)=0.33333333333333337 $ \n", 114 | "$P(X^{(1)}=3|Y=-1) = 
(1+0.2)/(6+3*0.2)=0.18181818181818182 $ \n", 115 | "$P(X^{(2)}=S|Y=-1) = (3+0.2)/(6+3*0.2)=0.4848484848484849 $ \n", 116 | "$P(X^{(2)}=M|Y=-1) = (2+0.2)/(6+3*0.2)=0.33333333333333337 $ \n", 117 | "$P(X^{(2)}=L|Y=-1) = (1+0.2)/(6+3*0.2)=0.18181818181818182 $ \n", 118 | "so \n", 119 | "$P(Y=1)P(X^{(1)}=2|Y=1)P(X^{(2)}=S|Y=1) =0.5974025974025974* 0.33333333333333337*0.125=0.024891774891774892$ \n", 120 | "$P(Y=-1)P(X^{(1)}=2|Y=-1)P(X^{(2)}=S|Y=-1) =0.40259740259740264* 0.33333333333333337*0.4848484848484849=0.06506624688442873$ \n", 121 | "\n", 122 | "so, it should be -1." 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "metadata": { 128 | "id": "86WQkGZefWi1", 129 | "colab_type": "code", 130 | "colab": {} 131 | }, 132 | "source": [ 133 | "class NB:\n", 134 | " def __init__(self, lambda_):\n", 135 | " self.lambda_ = lambda_\n", 136 | " \n", 137 | " def fit(self, X, y):\n", 138 | " N, M = X.shape\n", 139 | " data = np.hstack((X, y.reshape(N, 1)))\n", 140 | " \n", 141 | " py = {}\n", 142 | " pxy = {}\n", 143 | " uniquey, countsy = np.unique(y, return_counts=True)\n", 144 | " tmp = dict(zip(uniquey, countsy))\n", 145 | " for k,v in tmp.items():\n", 146 | " py[k] = (v + self.lambda_)/(N + len(uniquey) * self.lambda_)\n", 147 | " tmp_data = data[data[:, -1] == k]\n", 148 | " for col in range(M):\n", 149 | " uniquecol, countscol = np.unique(tmp_data[:,col], return_counts=True)\n", 150 | " tmp1 = dict(zip(uniquecol, countscol))\n", 151 | " for kk, vv in tmp1.items():\n", 152 | " pxy['X({})={}|Y={}'.format(col+1, kk, k)] = (vv + self.lambda_)/(v + len(uniquecol) * self.lambda_)\n", 153 | " \n", 154 | " self.py = py\n", 155 | " self.pxy = pxy\n", 156 | "\n", 157 | " #return self.py, self.pxy\n", 158 | " \n", 159 | " def predict(self, x):\n", 160 | " M = len(x)\n", 161 | " res = {}\n", 162 | " for k,v in self.py.items():\n", 163 | " p = v\n", 164 | " for i in range(len(x)):\n", 165 | " p = p * self.pxy['X({})={}|Y={}'.format(i+1, x[i], k)]\n", 166 | " res[k] = p\n", 
167 | " print(res)\n", 168 | " maxp = -1\n", 169 | " maxk = -1\n", 170 | " for kk,vv in res.items():\n", 171 | " if vv > maxp:\n", 172 | " maxp = vv\n", 173 | " maxk = kk\n", 174 | " \n", 175 | " return maxk" 176 | ], 177 | "execution_count": 0, 178 | "outputs": [] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "metadata": { 183 | "id": "3hPRglhJfWi3", 184 | "colab_type": "code", 185 | "colab": {} 186 | }, 187 | "source": [ 188 | "lambda_ = 0.2\n", 189 | "d = {'S':0, 'M':1, 'L':2}\n", 190 | "\n", 191 | "X = np.array([[1, d['S']], [1, d['M']], [1, d['M']],\n", 192 | " [1, d['S']], [1, d['S']], [2, d['S']],\n", 193 | " [2, d['M']], [2, d['M']], [2, d['L']],\n", 194 | " [2, d['L']], [3, d['L']], [3, d['M']],\n", 195 | " [3, d['M']], [3, d['L']], [3, d['L']]])\n", 196 | "\n", 197 | "y = np.array([-1, -1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, -1])" 198 | ], 199 | "execution_count": 0, 200 | "outputs": [] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "metadata": { 205 | "id": "fs8vvcpWfWi5", 206 | "colab_type": "code", 207 | "outputId": "5b1fafb5-b4de-4618-fdad-baa42385751f", 208 | "colab": { 209 | "base_uri": "https://localhost:8080/", 210 | "height": 287 211 | } 212 | }, 213 | "source": [ 214 | "X" 215 | ], 216 | "execution_count": 129, 217 | "outputs": [ 218 | { 219 | "output_type": "execute_result", 220 | "data": { 221 | "text/plain": [ 222 | "array([[1, 0],\n", 223 | " [1, 1],\n", 224 | " [1, 1],\n", 225 | " [1, 0],\n", 226 | " [1, 0],\n", 227 | " [2, 0],\n", 228 | " [2, 1],\n", 229 | " [2, 1],\n", 230 | " [2, 2],\n", 231 | " [2, 2],\n", 232 | " [3, 2],\n", 233 | " [3, 1],\n", 234 | " [3, 1],\n", 235 | " [3, 2],\n", 236 | " [3, 2]])" 237 | ] 238 | }, 239 | "metadata": { 240 | "tags": [] 241 | }, 242 | "execution_count": 129 243 | } 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "metadata": { 249 | "id": "FM89mRoZRiP1", 250 | "colab_type": "code", 251 | "colab": { 252 | "base_uri": "https://localhost:8080/", 253 | "height": 35 254 | }, 255 | 
"outputId": "93afb1dd-3be8-4c55-cbe0-f0549365e38a" 256 | }, 257 | "source": [ 258 | "y" 259 | ], 260 | "execution_count": 130, 261 | "outputs": [ 262 | { 263 | "output_type": "execute_result", 264 | "data": { 265 | "text/plain": [ 266 | "array([-1, -1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, -1])" 267 | ] 268 | }, 269 | "metadata": { 270 | "tags": [] 271 | }, 272 | "execution_count": 130 273 | } 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "metadata": { 279 | "id": "l08zoOrQRlcN", 280 | "colab_type": "code", 281 | "colab": { 282 | "base_uri": "https://localhost:8080/", 283 | "height": 53 284 | }, 285 | "outputId": "71096a11-5d50-4122-84ab-e0ab415dfc81" 286 | }, 287 | "source": [ 288 | "model = NB(lambda_)\n", 289 | "model.fit(X,y)\n", 290 | "model.predict(np.array([2, 0]))" 291 | ], 292 | "execution_count": 77, 293 | "outputs": [ 294 | { 295 | "output_type": "stream", 296 | "text": [ 297 | "{-1: 0.06506624688442873, 1: 0.024891774891774892}\n" 298 | ], 299 | "name": "stdout" 300 | }, 301 | { 302 | "output_type": "execute_result", 303 | "data": { 304 | "text/plain": [ 305 | "-1" 306 | ] 307 | }, 308 | "metadata": { 309 | "tags": [] 310 | }, 311 | "execution_count": 77 312 | } 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "metadata": { 318 | "id": "wULij7sgcQna", 319 | "colab_type": "code", 320 | "colab": {} 321 | }, 322 | "source": [ 323 | "# data\n", 324 | "def create_data():\n", 325 | " iris = load_iris()\n", 326 | " df = pd.DataFrame(iris.data, columns=iris.feature_names)\n", 327 | " df['label'] = iris.target\n", 328 | " df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n", 329 | " data = np.array(df.iloc[:100, :])\n", 330 | " # print(data)\n", 331 | " return data[:,:-1], data[:,-1]" 332 | ], 333 | "execution_count": 0, 334 | "outputs": [] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "metadata": { 339 | "id": "wniDd3wMcTRW", 340 | "colab_type": "code", 341 | "colab": {} 342 | }, 343 | "source": [ 344 
| "X, y = create_data()\n", 345 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)" 346 | ], 347 | "execution_count": 0, 348 | "outputs": [] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "metadata": { 353 | "id": "G6NBwGCxcUur", 354 | "colab_type": "code", 355 | "colab": { 356 | "base_uri": "https://localhost:8080/", 357 | "height": 35 358 | }, 359 | "outputId": "e85d5a75-de23-4ebf-fb00-7d1ce166b0ee" 360 | }, 361 | "source": [ 362 | "X_test[0], y_test[0]" 363 | ], 364 | "execution_count": 80, 365 | "outputs": [ 366 | { 367 | "output_type": "execute_result", 368 | "data": { 369 | "text/plain": [ 370 | "(array([5.6, 3. , 4.5, 1.5]), 1.0)" 371 | ] 372 | }, 373 | "metadata": { 374 | "tags": [] 375 | }, 376 | "execution_count": 80 377 | } 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "metadata": { 383 | "id": "J_xJWo5GcVya", 384 | "colab_type": "code", 385 | "colab": { 386 | "base_uri": "https://localhost:8080/", 387 | "height": 35 388 | }, 389 | "outputId": "03afa1be-a569-4857-d05d-7b3f70cc9a02" 390 | }, 391 | "source": [ 392 | "X_train.shape" 393 | ], 394 | "execution_count": 82, 395 | "outputs": [ 396 | { 397 | "output_type": "execute_result", 398 | "data": { 399 | "text/plain": [ 400 | "(70, 4)" 401 | ] 402 | }, 403 | "metadata": { 404 | "tags": [] 405 | }, 406 | "execution_count": 82 407 | } 408 | ] 409 | }, 410 | { 411 | "cell_type": "markdown", 412 | "metadata": { 413 | "id": "GyXsq6VvfWi-", 414 | "colab_type": "text" 415 | }, 416 | "source": [ 417 | "## GaussianNB 高斯朴素贝叶斯\n", 418 | "\n", 419 | "特征的可能性被假设为高斯\n", 420 | "\n", 421 | "概率密度函数:\n", 422 | "$$P(x_i | y_k)=\\frac{1}{\\sqrt{2\\pi\\sigma^2_{yk}}}exp(-\\frac{(x_i-\\mu_{yk})^2}{2\\sigma^2_{yk}})$$\n", 423 | "\n", 424 | "数学期望(mean):$\\mu$,方差:$\\sigma^2=\\frac{\\sum(X-\\mu)^2}{N}$" 425 | ] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "metadata": { 430 | "id": "uqJnMoUrfWi-", 431 | "colab_type": "code", 432 | "colab": {} 433 | }, 434 | "source": [ 435 | "class 
NaiveBayes:\n", 436 | " def fit(self, X, y):\n", 437 | " self.classes = list(np.unique(y))\n", 438 | " self.parameters = {}\n", 439 | " \n", 440 | " for c in self.classes:\n", 441 | " # 计算每个种类的平均值,方差,先验概率\n", 442 | " X_Index_c = X[np.where(y == c)]\n", 443 | " X_index_c_mean = np.mean(X_Index_c, axis=0, keepdims=True)\n", 444 | " X_index_c_var = np.var(X_Index_c, axis=0, keepdims=True)\n", 445 | " parameters = {\"mean\": X_index_c_mean, \"var\": X_index_c_var, \"prior\": X_Index_c.shape[0] / X.shape[0]}\n", 446 | " self.parameters[\"class\" + str(c)] = parameters\n", 447 | " print(self.parameters)\n", 448 | " \n", 449 | " def _pdf(self, X, classes):\n", 450 | " # 一维高斯分布的概率密度函数\n", 451 | " eps = 1e-4\n", 452 | " mean = self.parameters[\"class\" + str(classes)][\"mean\"]\n", 453 | " var = self.parameters[\"class\" + str(classes)][\"var\"]\n", 454 | " \n", 455 | " numerator = np.exp(-(X - mean) ** 2 / (2 * var + eps))\n", 456 | " denominator = np.sqrt(2 * np.pi * var + eps)\n", 457 | " \n", 458 | " # 取对数防止数值溢出\n", 459 | " result = np.sum(np.log(numerator / denominator), axis=1, keepdims=True)\n", 460 | " \n", 461 | " return result.T\n", 462 | " \n", 463 | " def _predict(self, X):\n", 464 | " output = []\n", 465 | " for y in self.classes:\n", 466 | " prior = np.log(self.parameters[\"class\" + str(y)][\"prior\"])\n", 467 | " posterior = self._pdf(X, y)\n", 468 | " prediction = prior + posterior\n", 469 | " output.append(prediction)\n", 470 | " return output\n", 471 | " \n", 472 | " def predict(self, X):\n", 473 | " # 取概率最大的类别返回预测值\n", 474 | " output = self._predict(X)\n", 475 | " output = np.reshape(output, (len(self.classes), X.shape[0]))\n", 476 | " prediction = np.argmax(output, axis=0)\n", 477 | " return prediction\n", 478 | " \n", 479 | " def score(self, X_test, y_test):\n", 480 | " right = 0\n", 481 | " pred = self.predict(X_test)\n", 482 | " right = (y_test - pred == 0).sum()\n", 483 | "\n", 484 | " return right / float(len(X_test))" 485 | ], 486 | 
"execution_count": 0, 487 | "outputs": [] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "metadata": { 492 | "id": "NpeBcwKJfWjA", 493 | "colab_type": "code", 494 | "colab": {} 495 | }, 496 | "source": [ 497 | "model = NaiveBayes()" 498 | ], 499 | "execution_count": 0, 500 | "outputs": [] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "metadata": { 505 | "id": "JLj3a70GfWjD", 506 | "colab_type": "code", 507 | "outputId": "ef5182a8-668e-4f62-93fb-314f95f68220", 508 | "colab": { 509 | "base_uri": "https://localhost:8080/", 510 | "height": 73 511 | } 512 | }, 513 | "source": [ 514 | "model.fit(X_train, y_train)" 515 | ], 516 | "execution_count": 123, 517 | "outputs": [ 518 | { 519 | "output_type": "stream", 520 | "text": [ 521 | "{'class0.0': {'mean': array([[5.02571429, 3.42857143, 1.49142857, 0.24857143]]), 'var': array([[0.10648163, 0.15918367, 0.02478367, 0.01278367]]), 'prior': 0.5}}\n", 522 | "{'class0.0': {'mean': array([[5.02571429, 3.42857143, 1.49142857, 0.24857143]]), 'var': array([[0.10648163, 0.15918367, 0.02478367, 0.01278367]]), 'prior': 0.5}, 'class1.0': {'mean': array([[5.94285714, 2.77714286, 4.27428571, 1.34 ]]), 'var': array([[0.18816327, 0.09833469, 0.17505306, 0.03954286]]), 'prior': 0.5}}\n" 523 | ], 524 | "name": "stdout" 525 | } 526 | ] 527 | }, 528 | { 529 | "cell_type": "code", 530 | "metadata": { 531 | "id": "x9PDXudxfWjF", 532 | "colab_type": "code", 533 | "outputId": "68a6c9c7-9524-47fe-800e-d8a2d080ddea", 534 | "colab": { 535 | "base_uri": "https://localhost:8080/", 536 | "height": 35 537 | } 538 | }, 539 | "source": [ 540 | "print(model.predict(X_test))" 541 | ], 542 | "execution_count": 124, 543 | "outputs": [ 544 | { 545 | "output_type": "stream", 546 | "text": [ 547 | "[1 0 0 0 1 0 0 0 1 0 1 0 1 1 1 1 0 1 1 1 0 0 0 0 0 1 1 1 1 0]\n" 548 | ], 549 | "name": "stdout" 550 | } 551 | ] 552 | }, 553 | { 554 | "cell_type": "code", 555 | "metadata": { 556 | "id": "xMO7vvVvfWjI", 557 | "colab_type": "code", 558 | "outputId": 
"86148657-b876-4277-b3e0-678254b0ddaf", 559 | "colab": { 560 | "base_uri": "https://localhost:8080/", 561 | "height": 35 562 | } 563 | }, 564 | "source": [ 565 | "model.score(X_test, y_test)" 566 | ], 567 | "execution_count": 125, 568 | "outputs": [ 569 | { 570 | "output_type": "execute_result", 571 | "data": { 572 | "text/plain": [ 573 | "1.0" 574 | ] 575 | }, 576 | "metadata": { 577 | "tags": [] 578 | }, 579 | "execution_count": 125 580 | } 581 | ] 582 | }, 583 | { 584 | "cell_type": "markdown", 585 | "metadata": { 586 | "collapsed": true, 587 | "id": "TGnmFQFkfWjL", 588 | "colab_type": "text" 589 | }, 590 | "source": [ 591 | "scikit-learn实例\n", 592 | "\n", 593 | "# sklearn.naive_bayes" 594 | ] 595 | }, 596 | { 597 | "cell_type": "code", 598 | "metadata": { 599 | "id": "EBKRlWmsfWjM", 600 | "colab_type": "code", 601 | "colab": {} 602 | }, 603 | "source": [ 604 | "from sklearn.naive_bayes import GaussianNB" 605 | ], 606 | "execution_count": 0, 607 | "outputs": [] 608 | }, 609 | { 610 | "cell_type": "code", 611 | "metadata": { 612 | "id": "S7Q8mOzmfWjO", 613 | "colab_type": "code", 614 | "outputId": "d7fbefa1-b855-4ce5-9352-81049d510232", 615 | "colab": { 616 | "base_uri": "https://localhost:8080/", 617 | "height": 35 618 | } 619 | }, 620 | "source": [ 621 | "clf = GaussianNB()\n", 622 | "clf.fit(X, y)" 623 | ], 624 | "execution_count": 133, 625 | "outputs": [ 626 | { 627 | "output_type": "execute_result", 628 | "data": { 629 | "text/plain": [ 630 | "GaussianNB(priors=None, var_smoothing=1e-09)" 631 | ] 632 | }, 633 | "metadata": { 634 | "tags": [] 635 | }, 636 | "execution_count": 133 637 | } 638 | ] 639 | }, 640 | { 641 | "cell_type": "code", 642 | "metadata": { 643 | "id": "BdAKVtxXfWjT", 644 | "colab_type": "code", 645 | "outputId": "e994d698-ee17-4c39-aa4f-0ed491f7fad5", 646 | "colab": { 647 | "base_uri": "https://localhost:8080/", 648 | "height": 35 649 | } 650 | }, 651 | "source": [ 652 | "clf.predict([[2, 0]])" 653 | ], 654 | "execution_count": 134, 655 | 
"outputs": [ 656 | { 657 | "output_type": "execute_result", 658 | "data": { 659 | "text/plain": [ 660 | "array([-1])" 661 | ] 662 | }, 663 | "metadata": { 664 | "tags": [] 665 | }, 666 | "execution_count": 134 667 | } 668 | ] 669 | }, 670 | { 671 | "cell_type": "code", 672 | "metadata": { 673 | "colab_type": "code", 674 | "id": "7qit4xK_0aka", 675 | "colab": {} 676 | }, 677 | "source": [ 678 | "from sklearn.naive_bayes import BernoulliNB, MultinomialNB # 伯努利模型和多项式模型" 679 | ], 680 | "execution_count": 0, 681 | "outputs": [] 682 | }, 683 | { 684 | "cell_type": "code", 685 | "metadata": { 686 | "id": "l4sFBX_u0drg", 687 | "colab_type": "code", 688 | "colab": { 689 | "base_uri": "https://localhost:8080/", 690 | "height": 35 691 | }, 692 | "outputId": "a7a74c51-e62e-46d1-b966-d00b7ac266f8" 693 | }, 694 | "source": [ 695 | "clf1 = BernoulliNB()\n", 696 | "clf1.fit(X, y)\n", 697 | "clf1.predict([[2, 0]])" 698 | ], 699 | "execution_count": 138, 700 | "outputs": [ 701 | { 702 | "output_type": "execute_result", 703 | "data": { 704 | "text/plain": [ 705 | "array([-1])" 706 | ] 707 | }, 708 | "metadata": { 709 | "tags": [] 710 | }, 711 | "execution_count": 138 712 | } 713 | ] 714 | }, 715 | { 716 | "cell_type": "code", 717 | "metadata": { 718 | "id": "QEClQeuV0qw2", 719 | "colab_type": "code", 720 | "colab": { 721 | "base_uri": "https://localhost:8080/", 722 | "height": 35 723 | }, 724 | "outputId": "bef9b428-fc41-4a06-b120-8c66ab380a56" 725 | }, 726 | "source": [ 727 | "clf2 = MultinomialNB()\n", 728 | "clf2.fit(X, y)\n", 729 | "clf2.predict([[2, 0]])" 730 | ], 731 | "execution_count": 139, 732 | "outputs": [ 733 | { 734 | "output_type": "execute_result", 735 | "data": { 736 | "text/plain": [ 737 | "array([1])" 738 | ] 739 | }, 740 | "metadata": { 741 | "tags": [] 742 | }, 743 | "execution_count": 139 744 | } 745 | ] 746 | } 747 | ] 748 | } -------------------------------------------------------------------------------- /第05章 决策树/DT.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "DT.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [], 10 | "toc_visible": true 11 | }, 12 | "language_info": { 13 | "codemirror_mode": { 14 | "name": "ipython", 15 | "version": 3 16 | }, 17 | "file_extension": ".py", 18 | "mimetype": "text/x-python", 19 | "name": "python", 20 | "nbconvert_exporter": "python", 21 | "pygments_lexer": "ipython3", 22 | "version": "3.6.2" 23 | }, 24 | "kernelspec": { 25 | "display_name": "Python 3", 26 | "language": "python", 27 | "name": "python3" 28 | } 29 | }, 30 | "cells": [ 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "Ajhk76kTn8L4", 35 | "colab_type": "text" 36 | }, 37 | "source": [ 38 | "# 第5章 决策树" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": { 44 | "id": "-iCRuHmOn8L5", 45 | "colab_type": "text" 46 | }, 47 | "source": [ 48 | "- ID3(基于信息增益)\n", 49 | "- C4.5(基于信息增益比)\n", 50 | "- CART 二叉决策树(Gini指数)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": { 56 | "id": "EbDeryI9n8L6", 57 | "colab_type": "text" 58 | }, 59 | "source": [ 60 | "#### entropy: $H(X) = -\sum_{i=1}^{n}p_i\log{p_i}$\n", 61 | "\n", 62 | "#### conditional entropy: $H(X|Y)=-\sum{P(X,Y)}\log{P(X|Y)}$\n", 63 | "\n", 64 | "#### information gain: $g(D, A)=H(D)-H(D|A)$\n", 65 | "\n", 66 | "#### information gain ratio: $g_R(D, A) = \frac{g(D,A)}{H_{A}(D)}$\n", 67 | "\n", 68 | "#### Gini index: $Gini(D)=\sum_{k=1}^{K}p_k(1-p_k)=1-\sum_{k=1}^{K}p_k^2$" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "metadata": { 74 | "id": "qemEGcJ7n8L6", 75 | "colab_type": "code", 76 | "colab": {} 77 | }, 78 | "source": [ 79 | "import numpy as np\n", 80 | "import pandas as pd\n", 81 | "import matplotlib.pyplot as plt\n", 82 | "%matplotlib inline\n", 83 | "\n", 84 | "from sklearn.datasets import load_iris\n",
85 | "from sklearn.model_selection import train_test_split\n", 86 | "\n", 87 | "from collections import Counter\n", 88 | "import math\n", 89 | "from math import log\n", 90 | "\n", 91 | "import pprint" 92 | ], 93 | "execution_count": 0, 94 | "outputs": [] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": { 99 | "id": "XyUeXAC_n8L-", 100 | "colab_type": "text" 101 | }, 102 | "source": [ 103 | "### 例 5.1" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "metadata": { 109 | "id": "YtNCGcaHn8L_", 110 | "colab_type": "code", 111 | "colab": {} 112 | }, 113 | "source": [ 114 | "def create_data():\n", 115 | " datasets = [['青年', '否', '否', '一般', '否'],\n", 116 | " ['青年', '否', '否', '好', '否'],\n", 117 | " ['青年', '是', '否', '好', '是'],\n", 118 | " ['青年', '是', '是', '一般', '是'],\n", 119 | " ['青年', '否', '否', '一般', '否'],\n", 120 | " ['中年', '否', '否', '一般', '否'],\n", 121 | " ['中年', '否', '否', '好', '否'],\n", 122 | " ['中年', '是', '是', '好', '是'],\n", 123 | " ['中年', '否', '是', '非常好', '是'],\n", 124 | " ['中年', '否', '是', '非常好', '是'],\n", 125 | " ['老年', '否', '是', '非常好', '是'],\n", 126 | " ['老年', '否', '是', '好', '是'],\n", 127 | " ['老年', '是', '否', '好', '是'],\n", 128 | " ['老年', '是', '否', '非常好', '是'],\n", 129 | " ['老年', '否', '否', '一般', '否'],\n", 130 | " ]\n", 131 | " labels = [u'年龄', u'有工作', u'有自己的房子', u'信贷情况', u'类别']\n", 132 | " # 返回数据集和每个维度的名称\n", 133 | " return datasets, labels" 134 | ], 135 | "execution_count": 0, 136 | "outputs": [] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "metadata": { 141 | "id": "Ji3uUZS-n8MB", 142 | "colab_type": "code", 143 | "outputId": "bec9dbfe-5016-44ff-a080-6e3eea61bbd6", 144 | "colab": { 145 | "base_uri": "https://localhost:8080/", 146 | "height": 514 147 | } 148 | }, 149 | "source": [ 150 | "datasets, labels = create_data()\n", 151 | "train_data = pd.DataFrame(datasets, columns=labels)\n", 152 | "train_data" 153 | ], 154 | "execution_count": 68, 155 | "outputs": [ 156 | { 157 | "output_type": "execute_result", 158 | "data": { 159 | 
"text/html": [ 160 | "
\n", 161 | "\n", 174 | "\n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | "
年龄有工作有自己的房子信贷情况类别
0青年一般
1青年
2青年
3青年一般
4青年一般
5中年一般
6中年
7中年
8中年非常好
9中年非常好
10老年非常好
11老年
12老年
13老年非常好
14老年一般
\n", 308 | "
" 309 | ], 310 | "text/plain": [ 311 | " 年龄 有工作 有自己的房子 信贷情况 类别\n", 312 | "0 青年 否 否 一般 否\n", 313 | "1 青年 否 否 好 否\n", 314 | "2 青年 是 否 好 是\n", 315 | "3 青年 是 是 一般 是\n", 316 | "4 青年 否 否 一般 否\n", 317 | "5 中年 否 否 一般 否\n", 318 | "6 中年 否 否 好 否\n", 319 | "7 中年 是 是 好 是\n", 320 | "8 中年 否 是 非常好 是\n", 321 | "9 中年 否 是 非常好 是\n", 322 | "10 老年 否 是 非常好 是\n", 323 | "11 老年 否 是 好 是\n", 324 | "12 老年 是 否 好 是\n", 325 | "13 老年 是 否 非常好 是\n", 326 | "14 老年 否 否 一般 否" 327 | ] 328 | }, 329 | "metadata": { 330 | "tags": [] 331 | }, 332 | "execution_count": 68 333 | } 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "metadata": { 339 | "id": "bWbCbJSmrdkp", 340 | "colab_type": "code", 341 | "outputId": "3d16ea53-103e-4ad9-c3d6-c40b86614406", 342 | "colab": { 343 | "base_uri": "https://localhost:8080/", 344 | "height": 287 345 | } 346 | }, 347 | "source": [ 348 | "datasets" 349 | ], 350 | "execution_count": 53, 351 | "outputs": [ 352 | { 353 | "output_type": "execute_result", 354 | "data": { 355 | "text/plain": [ 356 | "[['青年', '否', '否', '一般', '否'],\n", 357 | " ['青年', '否', '否', '好', '否'],\n", 358 | " ['青年', '是', '否', '好', '是'],\n", 359 | " ['青年', '是', '是', '一般', '是'],\n", 360 | " ['青年', '否', '否', '一般', '否'],\n", 361 | " ['中年', '否', '否', '一般', '否'],\n", 362 | " ['中年', '否', '否', '好', '否'],\n", 363 | " ['中年', '是', '是', '好', '是'],\n", 364 | " ['中年', '否', '是', '非常好', '是'],\n", 365 | " ['中年', '否', '是', '非常好', '是'],\n", 366 | " ['老年', '否', '是', '非常好', '是'],\n", 367 | " ['老年', '否', '是', '好', '是'],\n", 368 | " ['老年', '是', '否', '好', '是'],\n", 369 | " ['老年', '是', '否', '非常好', '是'],\n", 370 | " ['老年', '否', '否', '一般', '否']]" 371 | ] 372 | }, 373 | "metadata": { 374 | "tags": [] 375 | }, 376 | "execution_count": 53 377 | } 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "metadata": { 383 | "id": "9zcwdE1uiTOO", 384 | "colab_type": "code", 385 | "colab": { 386 | "base_uri": "https://localhost:8080/", 387 | "height": 35 388 | }, 389 | "outputId": "30378d95-6bb8-4d63-b57b-33f70b357b3e" 390 | }, 391 
| "source": [ 392 | "labels" 393 | ], 394 | "execution_count": 54, 395 | "outputs": [ 396 | { 397 | "output_type": "execute_result", 398 | "data": { 399 | "text/plain": [ 400 | "['年龄', '有工作', '有自己的房子', '信贷情况', '类别']" 401 | ] 402 | }, 403 | "metadata": { 404 | "tags": [] 405 | }, 406 | "execution_count": 54 407 | } 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "metadata": { 413 | "id": "UP3X4BaVrgYQ", 414 | "colab_type": "code", 415 | "colab": {} 416 | }, 417 | "source": [ 418 | "d = {'青年':1, '中年':2, '老年':3, '一般':1, '好':2, '非常好':3, '是':0, '否':1}\n", 419 | "data = []\n", 420 | "for i in range(15):\n", 421 | " tmp = []\n", 422 | " t = datasets[i]\n", 423 | " for tt in t:\n", 424 | " tmp.append(d[tt])\n", 425 | " data.append(tmp)" 426 | ], 427 | "execution_count": 0, 428 | "outputs": [] 429 | }, 430 | { 431 | "cell_type": "code", 432 | "metadata": { 433 | "id": "5tV-TRIQqftJ", 434 | "colab_type": "code", 435 | "outputId": "16be14d9-1d5d-4080-db56-0dfcf329830f", 436 | "colab": { 437 | "base_uri": "https://localhost:8080/", 438 | "height": 287 439 | } 440 | }, 441 | "source": [ 442 | "data = np.array(data);data" 443 | ], 444 | "execution_count": 56, 445 | "outputs": [ 446 | { 447 | "output_type": "execute_result", 448 | "data": { 449 | "text/plain": [ 450 | "array([[1, 1, 1, 1, 1],\n", 451 | " [1, 1, 1, 2, 1],\n", 452 | " [1, 0, 1, 2, 0],\n", 453 | " [1, 0, 0, 1, 0],\n", 454 | " [1, 1, 1, 1, 1],\n", 455 | " [2, 1, 1, 1, 1],\n", 456 | " [2, 1, 1, 2, 1],\n", 457 | " [2, 0, 0, 2, 0],\n", 458 | " [2, 1, 0, 3, 0],\n", 459 | " [2, 1, 0, 3, 0],\n", 460 | " [3, 1, 0, 3, 0],\n", 461 | " [3, 1, 0, 2, 0],\n", 462 | " [3, 0, 1, 2, 0],\n", 463 | " [3, 0, 1, 3, 0],\n", 464 | " [3, 1, 1, 1, 1]])" 465 | ] 466 | }, 467 | "metadata": { 468 | "tags": [] 469 | }, 470 | "execution_count": 56 471 | } 472 | ] 473 | }, 474 | { 475 | "cell_type": "code", 476 | "metadata": { 477 | "id": "sN169YUn2LvE", 478 | "colab_type": "code", 479 | "outputId": 
"b0561124-c930-4706-fed3-095308e4f53f", 480 | "colab": { 481 | "base_uri": "https://localhost:8080/", 482 | "height": 35 483 | } 484 | }, 485 | "source": [ 486 | "data.shape" 487 | ], 488 | "execution_count": 57, 489 | "outputs": [ 490 | { 491 | "output_type": "execute_result", 492 | "data": { 493 | "text/plain": [ 494 | "(15, 5)" 495 | ] 496 | }, 497 | "metadata": { 498 | "tags": [] 499 | }, 500 | "execution_count": 57 501 | } 502 | ] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "metadata": { 507 | "id": "oN7QSJC72UN-", 508 | "colab_type": "code", 509 | "colab": {} 510 | }, 511 | "source": [ 512 | "X, y = data[:,:-1], data[:, -1]" 513 | ], 514 | "execution_count": 0, 515 | "outputs": [] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "metadata": { 520 | "id": "1KsMqBec5Cwb", 521 | "colab_type": "code", 522 | "colab": {} 523 | }, 524 | "source": [ 525 | "# 熵\n", 526 | "def entropy(y):\n", 527 | " N = len(y)\n", 528 | " count = []\n", 529 | " for value in set(y):\n", 530 | " count.append(len(y[y == value]))\n", 531 | " count = np.array(count)\n", 532 | " entro = -np.sum((count / N) * (np.log2(count / N)))\n", 533 | " return entro" 534 | ], 535 | "execution_count": 0, 536 | "outputs": [] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "metadata": { 541 | "id": "DWb2n4RcDflB", 542 | "colab_type": "code", 543 | "outputId": "b01d7bde-33c9-467a-bf65-28d8c4d3e77f", 544 | "colab": { 545 | "base_uri": "https://localhost:8080/", 546 | "height": 35 547 | } 548 | }, 549 | "source": [ 550 | "entropy(y)" 551 | ], 552 | "execution_count": 10, 553 | "outputs": [ 554 | { 555 | "output_type": "execute_result", 556 | "data": { 557 | "text/plain": [ 558 | "0.9709505944546686" 559 | ] 560 | }, 561 | "metadata": { 562 | "tags": [] 563 | }, 564 | "execution_count": 10 565 | } 566 | ] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "metadata": { 571 | "id": "ApbpfKpcxw6y", 572 | "colab_type": "code", 573 | "colab": {} 574 | }, 575 | "source": [ 576 | "# 条件熵\n", 577 | "def 
cond_entropy(X, y, cond):\n", 578 | "    N = len(y)\n", 579 | "    cond_X = X[:, cond]\n", 580 | "    tmp_entro = []\n", 581 | "    for val in set(cond_X):\n", 582 | "        tmp_y = y[np.where(cond_X == val)]\n", 583 | "        tmp_entro.append(len(tmp_y)/N * entropy(tmp_y))\n", 584 | "    cond_entro = sum(tmp_entro)\n", 585 | "    return cond_entro" 586 | ], 587 | "execution_count": 0, 588 | "outputs": [] 589 | }, 590 | { 591 | "cell_type": "code", 592 | "metadata": { 593 | "id": "NF-g7udFK5qN", 594 | "colab_type": "code", 595 | "outputId": "9ddf18ff-6ddc-48fb-d766-4dc2c5c4034e", 596 | "colab": { 597 | "base_uri": "https://localhost:8080/", 598 | "height": 35 599 | } 600 | }, 601 | "source": [ 602 | "cond_entropy(X, y, 0)" 603 | ], 604 | "execution_count": 12, 605 | "outputs": [ 606 | { 607 | "output_type": "execute_result", 608 | "data": { 609 | "text/plain": [ 610 | "0.8879430945988998" 611 | ] 612 | }, 613 | "metadata": { 614 | "tags": [] 615 | }, 616 | "execution_count": 12 617 | } 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "metadata": { 623 | "id": "QXrKL-4mS3Q5", 624 | "colab_type": "code", 625 | "colab": {} 626 | }, 627 | "source": [ 628 | "# 信息增益\n", 629 | "def info_gain(X, y, cond):\n", 630 | "    return entropy(y) - cond_entropy(X, y, cond)" 631 | ], 632 | "execution_count": 0, 633 | "outputs": [] 634 | }, 635 | { 636 | "cell_type": "code", 637 | "metadata": { 638 | "id": "KjLX_NtqezON", 639 | "colab_type": "code", 640 | "colab": {} 641 | }, 642 | "source": [ 643 | "# 信息增益比:g_R(D,A) = g(D,A) / H_A(D),分母是训练集关于特征A取值的熵(书中式5.10),而非条件熵\n", 644 | "def info_gain_ratio(X, y, cond):\n", 645 | "    return info_gain(X, y, cond) / entropy(X[:, cond])" 646 | ], 647 | "execution_count": 0, 648 | "outputs": [] 649 | }, 650 | { 651 | "cell_type": "code", 652 | "metadata": { 653 | "id": "kKY7AVPeF4kh", 654 | "colab_type": "code", 655 | "outputId": "670c66c9-8f2c-46b0-b626-633f7fca4cd8", 656 | "colab": { 657 | "base_uri": "https://localhost:8080/", 658 | "height": 35 659 | } 660 | }, 661 | "source": [ 662 | "# A1, A2, A3, A4 =》年龄 
工作 房子 信贷\n", 663 | "# 信息增益\n", 664 | "\n", 665 | "gain_a1 = info_gain(X, y, 0);gain_a1" 666 | ], 667 | "execution_count": 15, 668 | "outputs": [ 669 | { 670 | "output_type": "execute_result", 671 | "data": { 672 | "text/plain": [ 673 | "0.08300749985576883" 674 | ] 675 | }, 676 | "metadata": { 677 | "tags": [] 678 | }, 679 | "execution_count": 15 680 | } 681 | ] 682 | }, 683 | { 684 | "cell_type": "code", 685 | "metadata": { 686 | "id": "VVTUqG4tSgwn", 687 | "colab_type": "code", 688 | "outputId": "72b043b5-a4c0-42db-b12d-5a31f577ef80", 689 | "colab": { 690 | "base_uri": "https://localhost:8080/", 691 | "height": 34 692 | } 693 | }, 694 | "source": [ 695 | "gain_a2 = info_gain(X, y, 1);gain_a2" 696 | ], 697 | "execution_count": 0, 698 | "outputs": [ 699 | { 700 | "output_type": "execute_result", 701 | "data": { 702 | "text/plain": [ 703 | "0.32365019815155627" 704 | ] 705 | }, 706 | "metadata": { 707 | "tags": [] 708 | }, 709 | "execution_count": 16 710 | } 711 | ] 712 | }, 713 | { 714 | "cell_type": "code", 715 | "metadata": { 716 | "id": "242jN12HSj_F", 717 | "colab_type": "code", 718 | "outputId": "a620d840-ac93-4adb-da8e-151bbe04e95c", 719 | "colab": { 720 | "base_uri": "https://localhost:8080/", 721 | "height": 34 722 | } 723 | }, 724 | "source": [ 725 | "gain_a3 = info_gain(X, y, 2);gain_a3" 726 | ], 727 | "execution_count": 0, 728 | "outputs": [ 729 | { 730 | "output_type": "execute_result", 731 | "data": { 732 | "text/plain": [ 733 | "0.4199730940219749" 734 | ] 735 | }, 736 | "metadata": { 737 | "tags": [] 738 | }, 739 | "execution_count": 17 740 | } 741 | ] 742 | }, 743 | { 744 | "cell_type": "code", 745 | "metadata": { 746 | "id": "m9prl_iaSmM1", 747 | "colab_type": "code", 748 | "outputId": "b440c3ed-8bc7-4f36-caf5-d13524d43957", 749 | "colab": { 750 | "base_uri": "https://localhost:8080/", 751 | "height": 34 752 | } 753 | }, 754 | "source": [ 755 | "gain_a4 = info_gain(X, y, 3);gain_a4" 756 | ], 757 | "execution_count": 0, 758 | "outputs": [ 759 | { 
760 | "output_type": "execute_result", 761 | "data": { 762 | "text/plain": [ 763 | "0.36298956253708536" 764 | ] 765 | }, 766 | "metadata": { 767 | "tags": [] 768 | }, 769 | "execution_count": 18 770 | } 771 | ] 772 | }, 773 | { 774 | "cell_type": "code", 775 | "metadata": { 776 | "id": "eIuVibAjpXSr", 777 | "colab_type": "code", 778 | "colab": {} 779 | }, 780 | "source": [ 781 | "def best_split(X,y, method='info_gain'):\n", 782 | " \"\"\"根据method指定的方法使用信息增益或信息增益比来计算各个维度的最大信息增益(比),返回特征的axis\"\"\"\n", 783 | " _, M = X.shape\n", 784 | " info_gains = []\n", 785 | " if method == 'info_gain':\n", 786 | " split = info_gain\n", 787 | " elif method == 'info_gain_ratio':\n", 788 | " split = info_gain_ratio\n", 789 | " else:\n", 790 | " print('No such method')\n", 791 | " return\n", 792 | " for i in range(M):\n", 793 | " tmp_gain = split(X, y, i)\n", 794 | " info_gains.append(tmp_gain)\n", 795 | " best_feature = np.argmax(info_gains)\n", 796 | " \n", 797 | " return best_feature" 798 | ], 799 | "execution_count": 0, 800 | "outputs": [] 801 | }, 802 | { 803 | "cell_type": "code", 804 | "metadata": { 805 | "id": "Tr6ckR8wriYm", 806 | "colab_type": "code", 807 | "outputId": "d2db3308-ce72-4f5d-c74e-893944909892", 808 | "colab": { 809 | "base_uri": "https://localhost:8080/", 810 | "height": 35 811 | } 812 | }, 813 | "source": [ 814 | "best_split(X,y)" 815 | ], 816 | "execution_count": 27, 817 | "outputs": [ 818 | { 819 | "output_type": "execute_result", 820 | "data": { 821 | "text/plain": [ 822 | "2" 823 | ] 824 | }, 825 | "metadata": { 826 | "tags": [] 827 | }, 828 | "execution_count": 27 829 | } 830 | ] 831 | }, 832 | { 833 | "cell_type": "code", 834 | "metadata": { 835 | "id": "iv2hm3ueTKa6", 836 | "colab_type": "code", 837 | "colab": {} 838 | }, 839 | "source": [ 840 | "def majorityCnt(y):\n", 841 | " \"\"\"当特征使用完时,返回类别数最多的类别\"\"\"\n", 842 | " unique, counts = np.unique(y, return_counts=True)\n", 843 | " max_idx = np.argmax(counts)\n", 844 | " return unique[max_idx]" 845 | ], 
846 | "execution_count": 0, 847 | "outputs": [] 848 | }, 849 | { 850 | "cell_type": "code", 851 | "metadata": { 852 | "id": "FXlY9UPxT80q", 853 | "colab_type": "code", 854 | "colab": { 855 | "base_uri": "https://localhost:8080/", 856 | "height": 35 857 | }, 858 | "outputId": "e356a964-3f83-40b4-ed66-41fae5266a89" 859 | }, 860 | "source": [ 861 | "majorityCnt(y)" 862 | ], 863 | "execution_count": 20, 864 | "outputs": [ 865 | { 866 | "output_type": "execute_result", 867 | "data": { 868 | "text/plain": [ 869 | "0" 870 | ] 871 | }, 872 | "metadata": { 873 | "tags": [] 874 | }, 875 | "execution_count": 20 876 | } 877 | ] 878 | }, 879 | { 880 | "cell_type": "markdown", 881 | "metadata": { 882 | "collapsed": true, 883 | "id": "EDx9vrfcn8MQ", 884 | "colab_type": "text" 885 | }, 886 | "source": [ 887 | "#### ID3, C4.5算法\n", 888 | "\n", 889 | "例5.3" 890 | ] 891 | }, 892 | { 893 | "cell_type": "code", 894 | "metadata": { 895 | "id": "kpgCEMIKRo8_", 896 | "colab_type": "code", 897 | "colab": {} 898 | }, 899 | "source": [ 900 | "class DecisionTreeClassifer:\n", 901 | " \"\"\"\n", 902 | " 决策树生成算法,\n", 903 | " method指定ID3或C4.5,两方法唯一不同在于特征选择方法不同\n", 904 | " info_gain: 信息增益即ID3\n", 905 | " info_gain_ratio: 信息增益比即C4.5\n", 906 | " \n", 907 | " \n", 908 | " \"\"\"\n", 909 | " def __init__(self, threshold, method='info_gain'):\n", 910 | " self.threshold = threshold\n", 911 | " self.method = method\n", 912 | " \n", 913 | " def fit(self, X, y, labels):\n", 914 | " labels = labels.copy()\n", 915 | " M, N = X.shape\n", 916 | " if len(np.unique(y)) == 1:\n", 917 | " return y[0]\n", 918 | " \n", 919 | " if N == 1:\n", 920 | " return majorityCnt(y)\n", 921 | " \n", 922 | " bestSplit = best_split(X,y, method=self.method)\n", 923 | " bestFeaLable = labels[bestSplit]\n", 924 | " Tree = {bestFeaLable: {}}\n", 925 | " del (labels[bestSplit])\n", 926 | " \n", 927 | " feaVals = np.unique(X[:, bestSplit])\n", 928 | " for val in feaVals:\n", 929 | " idx = np.where(X[:, bestSplit] == val)\n", 930 | " 
sub_X = np.delete(X[idx], bestSplit, axis=1)  # 删除已使用的特征列,使列下标与labels保持对齐\n", 931 | "            sub_y = y[idx]\n", 932 | "            sub_labels = labels\n", 933 | "            Tree[bestFeaLable][val] = self.fit(sub_X, sub_y, sub_labels)\n", 934 | "        \n", 935 | "        return Tree" 936 | ], 937 | "execution_count": 0, 938 | "outputs": [] 939 | }, 940 | { 941 | "cell_type": "code", 942 | "metadata": { 943 | "id": "8k4cgeqBn8MQ", 944 | "colab_type": "code", 945 | "colab": { 946 | "base_uri": "https://localhost:8080/", 947 | "height": 35 948 | }, 949 | "outputId": "f3c4ca27-9f09-4d42-bd58-3afa41cc32e0" 950 | }, 951 | "source": [ 952 | "My_Tree = DecisionTreeClassifer(threshold=0.1)\n", 953 | "My_Tree.fit(X, y, labels)" 954 | ], 955 | "execution_count": 69, 956 | "outputs": [ 957 | { 958 | "output_type": "execute_result", 959 | "data": { 960 | "text/plain": [ 961 | "{'有自己的房子': {0: 0, 1: {'有工作': {0: 0, 1: 1}}}}" 962 | ] 963 | }, 964 | "metadata": { 965 | "tags": [] 966 | }, 967 | "execution_count": 69 968 | } 969 | ] 970 | }, 971 | { 972 | "cell_type": "markdown", 973 | "metadata": { 974 | "id": "XaGNaDfAoivJ", 975 | "colab_type": "text" 976 | }, 977 | "source": [ 978 | "#### CART树" 979 | ] 980 | }, 981 | { 982 | "cell_type": "code", 983 | "metadata": { 984 | "id": "yXTTfkLCmsdP", 985 | "colab_type": "code", 986 | "colab": {} 987 | }, 988 | "source": [ 989 | "class CART:\n", 990 | "    \"\"\"CART树\"\"\"\n", 991 | "    def __init__(self, ):\n", 992 | "        \"to be continued\"" 993 | ], 994 | "execution_count": 0, 995 | "outputs": [] 996 | }, 997 | { 998 | "cell_type": "markdown", 999 | "metadata": { 1000 | "id": "6nxK8duGo37e", 1001 | "colab_type": "text" 1002 | }, 1003 | "source": [ 1004 | "#### 决策树的剪枝" 1005 | ] 1006 | }, 1007 | { 1008 | "cell_type": "code", 1009 | "metadata": { 1010 | "id": "N79jPbwWo6rv", 1011 | "colab_type": "code", 1012 | "colab": {} 1013 | }, 1014 | "source": [ 1015 | "\"to be continued\"" 1016 | ], 1017 | "execution_count": 0, 1018 | "outputs": [] 1019 | }, 1020 | { 1021 | "cell_type": "markdown", 1022 | "metadata": { 1023 | "id": "Gop3ocYDn8MZ", 1024 | 
"colab_type": "text" 1025 | }, 1026 | "source": [ 1027 | "---\n", 1028 | "\n", 1029 | "## sklearn.tree.DecisionTreeClassifier\n", 1030 | "\n", 1031 | "### criterion : string, optional (default=”gini”)\n", 1032 | "The function to measure the quality of a split. Supported criteria are “gini” for the Gini impurity and “entropy” for the information gain." 1033 | ] 1034 | }, 1035 | { 1036 | "cell_type": "code", 1037 | "metadata": { 1038 | "id": "nxE7F4sqn8Ma", 1039 | "colab_type": "code", 1040 | "colab": {} 1041 | }, 1042 | "source": [ 1043 | "# data\n", 1044 | "def create_data():\n", 1045 | " iris = load_iris()\n", 1046 | " df = pd.DataFrame(iris.data, columns=iris.feature_names)\n", 1047 | " df['label'] = iris.target\n", 1048 | " df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n", 1049 | " data = np.array(df.iloc[:100, [0, 1, -1]])\n", 1050 | " # print(data)\n", 1051 | " return data[:,:2], data[:,-1]\n", 1052 | "\n", 1053 | "X, y = create_data()\n", 1054 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)" 1055 | ], 1056 | "execution_count": 0, 1057 | "outputs": [] 1058 | }, 1059 | { 1060 | "cell_type": "code", 1061 | "metadata": { 1062 | "id": "LyqL3F8un8Mc", 1063 | "colab_type": "code", 1064 | "colab": {} 1065 | }, 1066 | "source": [ 1067 | "from sklearn.tree import DecisionTreeClassifier\n", 1068 | "\n", 1069 | "from sklearn.tree import export_graphviz\n", 1070 | "import graphviz" 1071 | ], 1072 | "execution_count": 0, 1073 | "outputs": [] 1074 | }, 1075 | { 1076 | "cell_type": "code", 1077 | "metadata": { 1078 | "id": "nNDjw1Phn8Me", 1079 | "colab_type": "code", 1080 | "outputId": "d2dd416b-1a48-4564-c53f-c801416c6df0", 1081 | "colab": { 1082 | "base_uri": "https://localhost:8080/", 1083 | "height": 125 1084 | } 1085 | }, 1086 | "source": [ 1087 | "clf = DecisionTreeClassifier()\n", 1088 | "clf.fit(data[:,:-1], data[:,-1])" 1089 | ], 1090 | "execution_count": 0, 1091 | "outputs": [ 1092 | { 1093 | 
"output_type": "execute_result", 1094 | "data": { 1095 | "text/plain": [ 1096 | "DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,\n", 1097 | " max_features=None, max_leaf_nodes=None,\n", 1098 | " min_impurity_decrease=0.0, min_impurity_split=None,\n", 1099 | " min_samples_leaf=1, min_samples_split=2,\n", 1100 | " min_weight_fraction_leaf=0.0, presort=False,\n", 1101 | " random_state=None, splitter='best')" 1102 | ] 1103 | }, 1104 | "metadata": { 1105 | "tags": [] 1106 | }, 1107 | "execution_count": 25 1108 | } 1109 | ] 1110 | }, 1111 | { 1112 | "cell_type": "code", 1113 | "metadata": { 1114 | "id": "RsB_iiLZn8Mh", 1115 | "colab_type": "code", 1116 | "outputId": "af553ea2-ec41-496d-ceb9-a750be3d8088", 1117 | "colab": { 1118 | "base_uri": "https://localhost:8080/", 1119 | "height": 35 1120 | } 1121 | }, 1122 | "source": [ 1123 | "clf.predict(np.array([1, 1, 0, 1]).reshape(1,-1)) # A" 1124 | ], 1125 | "execution_count": 0, 1126 | "outputs": [ 1127 | { 1128 | "output_type": "execute_result", 1129 | "data": { 1130 | "text/plain": [ 1131 | "array([0])" 1132 | ] 1133 | }, 1134 | "metadata": { 1135 | "tags": [] 1136 | }, 1137 | "execution_count": 28 1138 | } 1139 | ] 1140 | }, 1141 | { 1142 | "cell_type": "code", 1143 | "metadata": { 1144 | "id": "Sd2ScBfQu1Bo", 1145 | "colab_type": "code", 1146 | "outputId": "cbdf0ec2-1fd6-48a3-ae24-f4de9d6a442b", 1147 | "colab": { 1148 | "base_uri": "https://localhost:8080/", 1149 | "height": 35 1150 | } 1151 | }, 1152 | "source": [ 1153 | "clf.predict(np.array([2, 0, 1, 2]).reshape(1,-1)) # B" 1154 | ], 1155 | "execution_count": 0, 1156 | "outputs": [ 1157 | { 1158 | "output_type": "execute_result", 1159 | "data": { 1160 | "text/plain": [ 1161 | "array([0])" 1162 | ] 1163 | }, 1164 | "metadata": { 1165 | "tags": [] 1166 | }, 1167 | "execution_count": 29 1168 | } 1169 | ] 1170 | }, 1171 | { 1172 | "cell_type": "code", 1173 | "metadata": { 1174 | "id": "0E9mMz34u1a3", 1175 | "colab_type": "code", 1176 | 
"outputId": "7078b7fc-3322-4ee5-9e5f-6f8d575347f6", 1177 | "colab": { 1178 | "base_uri": "https://localhost:8080/", 1179 | "height": 35 1180 | } 1181 | }, 1182 | "source": [ 1183 | "clf.predict(np.array([2, 1, 0, 1]).reshape(1,-1)) # C" 1184 | ], 1185 | "execution_count": 0, 1186 | "outputs": [ 1187 | { 1188 | "output_type": "execute_result", 1189 | "data": { 1190 | "text/plain": [ 1191 | "array([0])" 1192 | ] 1193 | }, 1194 | "metadata": { 1195 | "tags": [] 1196 | }, 1197 | "execution_count": 30 1198 | } 1199 | ] 1200 | }, 1201 | { 1202 | "cell_type": "code", 1203 | "metadata": { 1204 | "id": "rmZHZjbYn8Mm", 1205 | "colab_type": "code", 1206 | "colab": {} 1207 | }, 1208 | "source": [ 1209 | "tree_pic = export_graphviz(clf, out_file=\"mytree.pdf\")\n", 1210 | "with open('mytree.pdf') as f:\n", 1211 | " dot_graph = f.read()" 1212 | ], 1213 | "execution_count": 0, 1214 | "outputs": [] 1215 | }, 1216 | { 1217 | "cell_type": "code", 1218 | "metadata": { 1219 | "id": "AeRk07sYn8Mq", 1220 | "colab_type": "code", 1221 | "outputId": "c526e824-6d1d-4f3e-f231-53b9d6447d90", 1222 | "colab": { 1223 | "base_uri": "https://localhost:8080/", 1224 | "height": 379 1225 | } 1226 | }, 1227 | "source": [ 1228 | "graphviz.Source(dot_graph)" 1229 | ], 1230 | "execution_count": 0, 1231 | "outputs": [ 1232 | { 1233 | "output_type": "execute_result", 1234 | "data": { 1235 | "text/plain": [ 1236 | "" 1237 | ], 1238 | "image/svg+xml": "\n\n\n\n\n\nTree\n\n\n\n0\n\nX[2] <= 0.5\ngini = 0.48\nsamples = 15\nvalue = [9, 6]\n\n\n\n1\n\ngini = 0.0\nsamples = 6\nvalue = [6, 0]\n\n\n\n0->1\n\n\nTrue\n\n\n\n2\n\nX[1] <= 0.5\ngini = 0.444\nsamples = 9\nvalue = [3, 6]\n\n\n\n0->2\n\n\nFalse\n\n\n\n3\n\ngini = 0.0\nsamples = 3\nvalue = [3, 0]\n\n\n\n2->3\n\n\n\n\n\n4\n\ngini = 0.0\nsamples = 6\nvalue = [0, 6]\n\n\n\n2->4\n\n\n\n\n\n" 1239 | }, 1240 | "metadata": { 1241 | "tags": [] 1242 | }, 1243 | "execution_count": 32 1244 | } 1245 | ] 1246 | }, 1247 | { 1248 | "cell_type": "code", 1249 | "metadata": { 
1250 | "id": "dlk_DsGByMix", 1251 | "colab_type": "code", 1252 | "colab": {} 1253 | }, 1254 | "source": [ 1255 | "" 1256 | ], 1257 | "execution_count": 0, 1258 | "outputs": [] 1259 | } 1260 | ] 1261 | } -------------------------------------------------------------------------------- /第07章 支持向量机/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hktxt/Learn-Statistical-Learning-Method/14a3b65db3ff35c42743a1062a2a95cecbfad094/第07章 支持向量机/.DS_Store -------------------------------------------------------------------------------- /第07章 支持向量机/SVM.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 第7章 支持向量机" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "----\n", 15 | "分离超平面:$w^Tx+b=0$\n", 16 | "\n", 17 | "点到直线距离:$r=\\frac{|w^Tx+b|}{||w||_2}$\n", 18 | "\n", 19 | "$||w||_2$为2-范数:$||w||_2=\\sqrt[2]{\\sum^m_{i=1}w_i^2}$\n", 20 | "\n", 21 | "直线为超平面,样本可表示为:\n", 22 | "\n", 23 | "$w^Tx+b\\ \\geq+1$\n", 24 | "\n", 25 | "$w^Tx+b\\ \\leq+1$\n", 26 | "\n", 27 | "#### margin:\n", 28 | "\n", 29 | "**函数间隔**:$label(w^Tx+b)\\ or\\ y_i(w^Tx+b)$\n", 30 | "\n", 31 | "**几何间隔**:$r=\\frac{label(w^Tx+b)}{||w||_2}$,当数据被正确分类时,几何间隔就是点到超平面的距离\n", 32 | "\n", 33 | "为了求几何间隔最大,SVM基本问题可以转化为求解:($\\frac{r^*}{||w||}$为几何间隔,(${r^*}$为函数间隔)\n", 34 | "\n", 35 | "$$\\max\\ \\frac{r^*}{||w||}$$\n", 36 | "\n", 37 | "$$(subject\\ to)\\ y_i({w^T}x_i+{b})\\geq {r^*},\\ i=1,2,..,m$$\n", 38 | "\n", 39 | "分类点几何间隔最大,同时被正确分类。但这个方程并非凸函数求解,所以要先①将方程转化为凸函数,②用拉格朗日乘子法和KKT条件求解对偶问题。\n", 40 | "\n", 41 | "①转化为凸函数:\n", 42 | "\n", 43 | "先令${r^*}=1$,方便计算(参照衡量,不影响评价结果)\n", 44 | "\n", 45 | "$$\\max\\ \\frac{1}{||w||}$$\n", 46 | "\n", 47 | "$$s.t.\\ y_i({w^T}x_i+{b})\\geq {1},\\ i=1,2,..,m$$\n", 48 | "\n", 49 | "再将$\\max\\ \\frac{1}{||w||}$转化成$\\min\\ 
\frac{1}{2}||w||^2$求解凸函数,1/2是为了求导之后方便计算。\n", 50 | "\n", 51 | "$$\min\ \frac{1}{2}||w||^2$$\n", 52 | "\n", 53 | "$$s.t.\ y_i(w^Tx_i+b)\geq 1,\ i=1,2,..,m$$\n", 54 | "\n", 55 | "②用拉格朗日乘子法和KKT条件求解最优值:\n", 56 | "\n", 57 | "$$\min\ \frac{1}{2}||w||^2$$\n", 58 | "\n", 59 | "$$s.t.\ -y_i(w^Tx_i+b)+1\leq 0,\ i=1,2,..,m$$\n", 60 | "\n", 61 | "整合成:\n", 62 | "\n", 63 | "$$L(w, b, \alpha) = \frac{1}{2}||w||^2+\sum^m_{i=1}\alpha_i(-y_i(w^Tx_i+b)+1)$$\n", 64 | "\n", 65 | "推导:$\min\ f(x)=\min \max\ L(w, b, \alpha)\geq \max \min\ L(w, b, \alpha)$\n", 66 | "\n", 67 | "根据KKT条件:\n", 68 | "\n", 69 | "$$\frac{\partial }{\partial w}L(w, b, \alpha)=w-\sum\alpha_iy_ix_i=0,\ w=\sum\alpha_iy_ix_i$$\n", 70 | "\n", 71 | "$$\frac{\partial }{\partial b}L(w, b, \alpha)=\sum\alpha_iy_i=0$$\n", 72 | "\n", 73 | "代入$ L(w, b, \alpha)$:\n", 74 | "\n", 75 | "$\min\ L(w, b, \alpha)=\frac{1}{2}||w||^2+\sum^m_{i=1}\alpha_i(-y_i(w^Tx_i+b)+1)$\n", 76 | "\n", 77 | "$\qquad\qquad\qquad=\frac{1}{2}w^Tw-\sum^m_{i=1}\alpha_iy_iw^Tx_i-b\sum^m_{i=1}\alpha_iy_i+\sum^m_{i=1}\alpha_i$\n", 78 | "\n", 79 | "$\qquad\qquad\qquad=\frac{1}{2}w^T\sum\alpha_iy_ix_i-\sum^m_{i=1}\alpha_iy_iw^Tx_i+\sum^m_{i=1}\alpha_i$\n", 80 | "\n", 81 | "$\qquad\qquad\qquad=\sum^m_{i=1}\alpha_i-\frac{1}{2}\sum^m_{i=1}\alpha_iy_iw^Tx_i$\n", 82 | "\n", 83 | "$\qquad\qquad\qquad=\sum^m_{i=1}\alpha_i-\frac{1}{2}\sum^m_{i,j=1}\alpha_i\alpha_jy_iy_j(x_ix_j)$\n", 84 | "\n", 85 | "再把max问题转成min问题:\n", 86 | "\n", 87 | "$\max\ \sum^m_{i=1}\alpha_i-\frac{1}{2}\sum^m_{i,j=1}\alpha_i\alpha_jy_iy_j(x_ix_j)=\min \frac{1}{2}\sum^m_{i,j=1}\alpha_i\alpha_jy_iy_j(x_ix_j)-\sum^m_{i=1}\alpha_i$\n", 88 | "\n", 89 | "$s.t.\ \sum^m_{i=1}\alpha_iy_i=0,$\n", 90 | "\n", 91 | "$ \alpha_i \geq 0,i=1,2,...,m$\n", 92 | "\n", 93 | "以上即为SVM原问题的对偶形式\n", 94 | "\n", 95 | "-----\n", 96 | "#### kernel\n", 97 | "\n", 98 | "核技巧:通过核函数在低维输入空间中直接计算高维特征空间里的内积,从而不必显式地做特征映射,就能得到高维空间的计算结果(数据在高维特征空间中才线性可分)。\n",
99 | "\n", 100 | "#### soft margin & slack variable\n", 101 | "\n", 102 | "引入松弛变量$\\xi\\geq0$,对应数据点允许偏离的functional margin 的量。\n", 103 | "\n", 104 | "目标函数:$\\min\\ \\frac{1}{2}||w||^2+C\\sum\\xi_i\\qquad s.t.\\ y_i(w^Tx_i+b)\\geq1-\\xi_i$ \n", 105 | "\n", 106 | "对偶问题:\n", 107 | "\n", 108 | "$$\\max\\ \\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)=\\min \\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_j(x_ix_j)-\\sum^m_{i=1}\\alpha_i$$\n", 109 | "\n", 110 | "$$s.t.\\ C\\geq\\alpha_i \\geq 0,i=1,2,...,m\\quad \\sum^m_{i=1}\\alpha_iy_i=0,$$\n", 111 | "\n", 112 | "-----\n", 113 | "\n", 114 | "#### Sequential Minimal Optimization\n", 115 | "\n", 116 | "首先定义特征到结果的输出函数:$u=w^Tx+b$.\n", 117 | "\n", 118 | "因为$w=\\sum\\alpha_iy_ix_i$\n", 119 | "\n", 120 | "有$u=\\sum y_i\\alpha_iK(x_i, x)-b$\n", 121 | "\n", 122 | "\n", 123 | "----\n", 124 | "\n", 125 | "$\\max \\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i=1}\\sum^m_{j=1}\\alpha_i\\alpha_jy_iy_j<\\phi(x_i)^T,\\phi(x_j)>$\n", 126 | "\n", 127 | "$s.t.\\ \\sum^m_{i=1}\\alpha_iy_i=0,$\n", 128 | "\n", 129 | "$ \\alpha_i \\geq 0,i=1,2,...,m$\n", 130 | "\n", 131 | "Reference: \n", 132 | "https://www.youtube.com/watch?v=_PwhiWxHK8o \n", 133 | "https://www.youtube.com/watch?v=vywmP6Ud1HA \n", 134 | "https://www.youtube.com/watch?v=iB2VK7qPfjg\n" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 1, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "import numpy as np\n", 144 | "import pandas as pd\n", 145 | "from sklearn.datasets import load_iris\n", 146 | "from sklearn.model_selection import train_test_split\n", 147 | "import matplotlib.pyplot as plt\n", 148 | "%matplotlib inline" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 2, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "# data\n", 158 | "def create_data():\n", 159 | " iris = load_iris()\n", 160 | " df = pd.DataFrame(iris.data, 
columns=iris.feature_names)\n", 161 | " df['label'] = iris.target\n", 162 | " df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n", 163 | " data = np.array(df.iloc[:100, [0, 1, -1]])\n", 164 | " for i in range(len(data)):\n", 165 | " if data[i,-1] == 0:\n", 166 | " data[i,-1] = -1\n", 167 | " # print(data)\n", 168 | " return data[:,:2], data[:,-1]" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 3, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "X, y = create_data()\n", 178 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 161, 184 | "metadata": {}, 185 | "outputs": [ 186 | { 187 | "data": { 188 | "text/plain": [ 189 | "" 190 | ] 191 | }, 192 | "execution_count": 161, 193 | "metadata": {}, 194 | "output_type": "execute_result" 195 | }, 196 | { 197 | "data": { 198 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAGXxJREFUeJzt3X+MXWWdx/H3d4dZOiowaRkWmCmWVdM/bLsWRrBpQlxxF8VaGmShjb+qrN01uGBwMdYQ1IYEDQaV1WhayALCVrsVu4XlxyIs8UekyZTWdrWQoIt2CixDsa2shW3Ld/+4d+jM7Z2597n3nnuf57mfV9J07rkPp9/nHP329pzPea65OyIikpc/6XQBIiLSemruIiIZUnMXEcmQmruISIbU3EVEMqTmLiKSITV3EZEMqbmLiGRIzV1EJEPH1TvQzHqAEWCPuy+peG8lcCOwp7zpm+5+y3T7O/nkk33OnDlBxYqIdLutW7e+4O4DtcbV3dyBq4BdwIlTvP99d/9UvTubM2cOIyMjAX+8iIiY2W/rGVfXZRkzGwLeB0z7aVxEROJQ7zX3rwOfBV6dZswHzGyHmW00s9nVBpjZKjMbMbORsbGx0FpFRKRONZu7mS0Bnnf3rdMMuweY4+4LgB8Bt1cb5O5r3X3Y3YcHBmpeMhIRkQbVc819MbDUzC4EZgAnmtmd7v6h8QHuvnfC+HXAV1pbpohI4w4dOsTo6Cgvv/xyp0up24wZMxgaGqK3t7eh/75mc3f31cBqADN7J/CPExt7eftp7v5s+eVSSjdeRUSiMDo6ygknnMCcOXMws06XU5O7s3fvXkZHRznzzDMb2kfDOXczW2NmS8svrzSzX5rZL4ArgZWN7ldEpNVefvllZs2alURjBzAzZs2a1dS/NEKikLj7o8Cj5Z+vm7D9tU/3IrnZtG0PNz74JM/sO8jp/X1cc8Fcli0c7HRZEiiVxj6u2XqDmrtIt9m0bQ+r797JwUNHANiz7yCr7
94JoAYvUdPyAyLTuPHBJ19r7OMOHjrCjQ8+2aGKJHVPPPEEixYt4vjjj+erX/1qYX+OPrmLTOOZfQeDtovUMnPmTG6++WY2bdpU6J+jT+4i0zi9vy9ou+Rh07Y9LP7yI5z5uX9n8ZcfYdO2PbX/ozqdcsopvP3tb2844lgvNXeRaVxzwVz6ensmbevr7eGaC+Z2qCIp2vh9lj37DuIcvc/SygbfDmruItNYtnCQGy6ez2B/HwYM9vdxw8XzdTM1Y7ncZ9E1d5Eali0cVDPvIkXcZ/nWt77FunXrALjvvvs4/fTTG95XvfTJXURkgiLus1xxxRVs376d7du3t6Wxg5q7iMgkRd9nee655xgaGuKmm27i+uuvZ2hoiAMHDrRk3xPpsoyIyATjl+CKeir51FNPZXR0tCX7mo6au4hIhRzus+iyjIhIhtTcRUQypOYuIpIhNXcRkQypuYuIZEjNXbJR5GJPIs36+Mc/zimnnMK8efPa8uepuUsWclnsSfK1cuVKHnjggbb9eWrukoVcFnuSSOzYAF+bB1/sL/2+Y0PTuzzvvPOYOXNmC4qrjx5ikizoSzWkZXZsgHuuhEPl/+3s3116DbDg0s7VFUif3CUL+lINaZmH1xxt7OMOHSxtT4iau2RBX6ohLbN/inVfptoeKV2WkSwUvdiTdJGThkqXYqptT4iau2Qjh8WeJALnXzf5mjtAb19pexNWrFjBo48+ygsvvMDQ0BBf+tKXuPzyy5ssdmpq7tK0Tdv26BOz5GP8punDa0qXYk4aKjX2Jm+mrl+/vgXF1U/NXZoyni8fjyGO58sBNXhJ14JLk0rGVKMbqtIU5ctF4qTmLk1RvlxS4e6dLiFIs/WquUtTlC+XFMyYMYO9e/cm0+Ddnb179zJjxoyG96Fr7tKUay6YO+maOyhfLvEZGhpidHSUsbGxTpdStxkzZjA01Hj8Us1dmqJ8uaSgt7eXM888s9NltFXdzd3MeoARYI+7L6l473jgDuBsYC9wmbs/3cI6JWLKl4vEJ+ST+1XALuDEKu9dDvze3d9sZsuBrwCXtaA+kaQo8y+xqOuGqpkNAe8DbpliyEXA7eWfNwLnm5k1X55IOrSmvMSk3rTM14HPAq9O8f4gsBvA3Q8D+4FZTVcnkhBl/iUmNZu7mS0Bnnf3rdMNq7LtmMyRma0ysxEzG0nprrVIPZT5l5jU88l9MbDUzJ4Gvge8y8zurBgzCswGMLPjgJOAFyt35O5r3X3Y3YcHBgaaKlwkNsr8S0xqNnd3X+3uQ+4+B1gOPOLuH6oYthn4aPnnS8pj0nhaQKRFtKa8xKThnLuZrQFG3H0zcCvwXTN7itIn9uUtqk8kGcr8S0ysUx+wh4eHfWRkpCN/tohIqsxsq7sP1xqnJ1QlWtdu2sn6Lbs54k6PGSvOnc31y+Z3uiyRJKi5S5Su3bSTOx/73Wuvj7i/9loNXqQ2rQopUVq/pcp3WE6zXUQmU3OXKB2Z4l7QVNtFZDI1d4lSzxSrV0y1XUQmU3OXKK04d3bQdhGZTDdUJUrjN02VlhFpjHLuIiIJUc5dmvLBdT/nZ78+ujzQ4jfN5K5PLOpgRZ2jNdolRbrmLseobOwAP/v1i3xw3c87VFHnaI12SZWauxyjsrHX2p4zrdEuqVJzF5mG1miXVKm5i0xDa7RLqtTc5RiL3zQzaHvOtEa7pErNXY5x1ycWHdPIuzUts2zhIDdcPJ/B/j4MGOzv44aL5ystI9FTzl1EJCHKuUtTisp2h+xX+XKRxqm5yzHGs93jEcDxbDfQVHMN2W9RNYh0C11zl2MUle0O2a/y5SLNUXOXYxSV7Q7Zr/LlIs1Rc5djFJXtDtmv8uUizVFzl2MUle0O2a/y5SLN0Q1VOcb4DctWJ1VC9ltUDSLdQjl3EZGEKOdesBgy2KE1xFCziLSHmnsDYshgh9YQQ80i0j66odqAGDLYoTXEULOItI+aewNiyGCH1hBDzSLSPmruDYghgx1aQ
ww1i0j7qLk3IIYMdmgNMdQsIu2jG6oNiCGDHVpDDDWLSPvUzLmb2Qzgx8DxlP4y2OjuX6gYsxK4ERj/Svhvuvst0+1XOXcRkXCtzLm/ArzL3V8ys17gp2Z2v7s/VjHu++7+qUaKlfa4dtNO1m/ZzRF3esxYce5srl82v+mxseTnY6lDJAY1m7uXPtq/VH7ZW/7VmcdapWHXbtrJnY/97rXXR9xfe13ZtEPGxpKfj6UOkVjUdUPVzHrMbDvwPPCQu2+pMuwDZrbDzDaa2eyWVilNW79ld93bQ8bGkp+PpQ6RWNTV3N39iLu/DRgCzjGzeRVD7gHmuPsC4EfA7dX2Y2arzGzEzEbGxsaaqVsCHZni3kq17SFjY8nPx1KHSCyCopDuvg94FHhPxfa97v5K+eU64Owp/vu17j7s7sMDAwMNlCuN6jGre3vI2Fjy87HUIRKLms3dzAbMrL/8cx/wbuCJijGnTXi5FNjVyiKleSvOrX6lrNr2kLGx5OdjqUMkFvWkZU4DbjezHkp/GWxw93vNbA0w4u6bgSvNbClwGHgRWFlUwdKY8Ruh9SRgQsbGkp+PpQ6RWGg9dxGRhGg994IVlakOyZcXue+Q+aV4LJKzYwM8vAb2j8JJQ3D+dbDg0k5XJRFTc29AUZnqkHx5kfsOmV+KxyI5OzbAPVfCoXLyZ//u0mtQg5cpaeGwBhSVqQ7Jlxe575D5pXgskvPwmqONfdyhg6XtIlNQc29AUZnqkHx5kfsOmV+KxyI5+0fDtoug5t6QojLVIfnyIvcdMr8Uj0VyThoK2y6CmntDispUh+TLi9x3yPxSPBbJOf866K34y7K3r7RdZAq6odqAojLVIfnyIvcdMr8Uj0Vyxm+aKi0jAZRzFxFJiHLucowYsuuSOOXtk6Hm3iViyK5L4pS3T4puqHaJGLLrkjjl7ZOi5t4lYsiuS+KUt0+KmnuXiCG7LolT3j4pau5dIobsuiROefuk6IZql4ghuy6JU94+Kcq5i4gkRDn3sqLy2iH7jWVdcmXXI5N7Zjz3+YXowLHIurkXldcO2W8s65Irux6Z3DPjuc8vRIeORdY3VIvKa4fsN5Z1yZVdj0zumfHc5xeiQ8ci6+ZeVF47ZL+xrEuu7Hpkcs+M5z6/EB06Flk396Ly2iH7jWVdcmXXI5N7Zjz3+YXo0LHIurkXldcO2W8s65Irux6Z3DPjuc8vRIeORdY3VIvKa4fsN5Z1yZVdj0zumfHc5xeiQ8dCOXcRkYQo514w5edFEnHv1bD1NvAjYD1w9kpYclPz+408x6/m3gDl50USce/VMHLr0dd+5OjrZhp8Ajn+rG+oFkX5eZFEbL0tbHu9Esjxq7k3QPl5kUT4kbDt9Uogx6/m3gDl50USYT1h2+uVQI5fzb0Bys+LJOLslWHb65VAjl83VBug/LxIIsZvmrY6LZNAjl85dxGRhLQs525mM4AfA8eXx2909y9UjDkeuAM4G9gLXObuTzdQd02h+fLU1jAPya7nfiwKzRGHZJ+LqqPI+UWewW5K6NxyPhbTqOeyzCvAu9z9JTPrBX5qZve7+2MTxlwO/N7d32xmy4GvAJe1utjQfHlqa5iHZNdzPxaF5ohDss9F1VHk/BLIYDcsdG45H4saat5Q9ZKXyi97y78qr+VcBNxe/nkjcL5Z62Mbofny1NYwD8mu534sCs0Rh2Sfi6qjyPklkMFuWOjccj4WNdSVljGzHjPbDjwPPOTuWyqGDAK7Adz9MLAfmFVlP6vMbMTMRsbGxoKLDc2Xp7aGeUh2PfdjUWiOOCT7XFQdRc4vgQx2w0LnlvOxqKGu5u7uR9z9bcAQcI6ZzasYUu1T+jEdyd3Xuvuwuw8PDAwEFxuaL09tDfOQ7Hrux6LQHHFI9rmoOoqcXwIZ7IaFzi3nY1FDUM7d3fcBjwLvqXhrFJgNYGbHAScBL7agvklC8+WprWEekl3P/VgUmiMOyT4XVUeR80sgg92w0LnlfCxqqCctM
wAccvd9ZtYHvJvSDdOJNgMfBX4OXAI84gVkLEPz5amtYR6SXc/9WBSaIw7JPhdVR5HzSyCD3bDQueV8LGqomXM3swWUbpb2UPqkv8Hd15jZGmDE3TeX45LfBRZS+sS+3N1/M91+lXMXEQnXspy7u++g1LQrt1834eeXgb8JLVJERIqR/fIDyT24I+0R8mBLDA/BFPngTmoPacVwPhKQdXNP7sEdaY+QB1tieAimyAd3UntIK4bzkYisV4VM7sEdaY+QB1tieAimyAd3UntIK4bzkYism3tyD+5Ie4Q82BLDQzBFPriT2kNaMZyPRGTd3JN7cEfaI+TBlhgeginywZ3UHtKK4XwkIuvmntyDO9IeIQ+2xPAQTJEP7qT2kFYM5yMRWTf3ZQsHueHi+Qz292HAYH8fN1w8XzdTu92CS+H9N8NJswEr/f7+m6vfkAsZG0O9oeOLml9q+82QvqxDRCQhLXuISaTrhXyxRyxSqzmW7HosdbSAmrvIdEK+2CMWqdUcS3Y9ljpaJOtr7iJNC/lij1ikVnMs2fVY6mgRNXeR6YR8sUcsUqs5lux6LHW0iJq7yHRCvtgjFqnVHEt2PZY6WkTNXWQ6IV/sEYvUao4lux5LHS2i5i4ynSU3wfDlRz/1Wk/pdYw3JselVnMs2fVY6mgR5dxFRBKinLu0T4rZ4KJqLipfnuIxlo5Sc5fmpJgNLqrmovLlKR5j6Thdc5fmpJgNLqrmovLlKR5j6Tg1d2lOitngomouKl+e4jGWjlNzl+akmA0uquai8uUpHmPpODV3aU6K2eCiai4qX57iMZaOU3OX5qSYDS6q5qLy5SkeY+k45dxFRBJSb85dn9wlHzs2wNfmwRf7S7/v2ND+/RZVg0gg5dwlD0VlwUP2qzy6RESf3CUPRWXBQ/arPLpERM1d8lBUFjxkv8qjS0TU3CUPRWXBQ/arPLpERM1d8lBUFjxkv8qjS0TU3CUPRWXBQ/arPLpEpGbO3cxmA3cApwKvAmvd/RsVY94J/Bvw3+VNd7v7tHeRlHMXEQnXyvXcDwOfcffHzewEYKuZPeTuv6oY9xN3X9JIsRKhFNcPD6k5xfnFQMctGTWbu7s/Czxb/vkPZrYLGAQqm7vkIsW8tvLoxdNxS0rQNXczmwMsBLZUeXuRmf3CzO43s7e2oDbplBTz2sqjF0/HLSl1P6FqZm8AfgB82t0PVLz9OPBGd3/JzC4ENgFvqbKPVcAqgDPOOKPhoqVgKea1lUcvno5bUur65G5mvZQa+13ufnfl++5+wN1fKv98H9BrZidXGbfW3YfdfXhgYKDJ0qUwKea1lUcvno5bUmo2dzMz4FZgl7tXXbvUzE4tj8PMzinvd28rC5U2SjGvrTx68XTcklLPZZnFwIeBnWa2vbzt88AZAO7+HeAS4JNmdhg4CCz3Tq0lLM0bvzmWUioipOYU5xcDHbekaD13EZGEtDLnLrFS5niye6+GrbeVvpDaekpfb9fstyCJJErNPVXKHE9279UwcuvR137k6Gs1eOlCWlsmVcocT7b1trDtIplTc0+VMseT+ZGw7SKZU3NPlTLHk1lP2HaRzKm5p0qZ48nOXhm2XSRzau6p0trhky25CYYvP/pJ3XpKr3UzVbqUcu4iIglRzr0Bm7bt4cYHn+SZfQc5vb+Pay6Yy7KFg50uq3Vyz8XnPr8Y6BgnQ829bNO2Pay+eycHD5XSFXv2HWT13TsB8mjwuefic59fDHSMk6Jr7mU3Pvjka4193MFDR7jxwSc7VFGL5Z6Lz31+MdAxToqae9kz+w4GbU9O7rn43OcXAx3jpKi5l53e3xe0PTm55+Jzn18MdIyTouZeds0Fc+nrnfzAS19vD9dcMLdDFbVY7rn43OcXAx3jpOiGatn4TdNs0zK5r8Wd+/xioGOcFOXcRUQSUm/OXZdlRFKwYwN8bR58sb/0+44NaexbOkaXZURiV2S+XNn1bOmTu0jsisyXK7ueLTV3kdgVmS9Xd
j1bau4isSsyX67serbU3EViV2S+XNn1bKm5i8SuyLX79b0A2VLOXUQkIcq5i4h0MTV3EZEMqbmLiGRIzV1EJENq7iIiGVJzFxHJkJq7iEiG1NxFRDJUs7mb2Wwz+08z22VmvzSzq6qMMTO72cyeMrMdZnZWMeVKU7Rut0jXqGc998PAZ9z9cTM7AdhqZg+5+68mjHkv8Jbyr3OBb5d/l1ho3W6RrlLzk7u7P+vuj5d//gOwC6j8YtGLgDu85DGg38xOa3m10jit2y3SVYKuuZvZHGAhsKXirUFg94TXoxz7FwBmtsrMRsxsZGxsLKxSaY7W7RbpKnU3dzN7A/AD4NPufqDy7Sr/yTErkrn7WncfdvfhgYGBsEqlOVq3W6Sr1NXczayXUmO/y93vrjJkFJg94fUQ8Ezz5UnLaN1uka5ST1rGgFuBXe5+0xTDNgMfKadm3gHsd/dnW1inNEvrdot0lXrSMouBDwM7zWx7edvngTMA3P07wH3AhcBTwB+Bj7W+VGnagkvVzEW6RM3m7u4/pfo19YljHLiiVUWJiEhz9ISqiEiG1NxFRDKk5i4ikiE1dxGRDKm5i4hkSM1dRCRDau4iIhmyUkS9A3+w2Rjw24784bWdDLzQ6SIKpPmlK+e5geZXjze6e83FuTrW3GNmZiPuPtzpOoqi+aUr57mB5tdKuiwjIpIhNXcRkQypuVe3ttMFFEzzS1fOcwPNr2V0zV1EJEP65C4ikqGubu5m1mNm28zs3irvrTSzMTPbXv71t52osRlm9rSZ7SzXP1LlfTOzm83sKTPbYWZndaLORtQxt3ea2f4J5y+pr5wys34z22hmT5jZLjNbVPF+sucO6ppfsufPzOZOqHu7mR0ws09XjCn8/NXzZR05uwrYBZw4xfvfd/dPtbGeIvylu0+Vq30v8Jbyr3OBb5d/T8V0cwP4ibsvaVs1rfUN4AF3v8TM/hR4XcX7qZ+7WvODRM+fuz8JvA1KHyCBPcAPK4YVfv669pO7mQ0B7wNu6XQtHXQRcIeXPAb0m9lpnS6q25nZicB5lL7eEnf/P3ffVzEs2XNX5/xycT7wa3evfGCz8PPXtc0d+DrwWeDVacZ8oPxPpo1mNnuacbFy4D/MbKuZrary/iCwe8Lr0fK2FNSaG8AiM/uFmd1vZm9tZ3FN+nNgDPjn8mXDW8zs9RVjUj539cwP0j1/Ey0H1lfZXvj568rmbmZLgOfdfes0w+4B5rj7AuBHwO1tKa61Frv7WZT+CXiFmZ1X8X61r09MJT5Va26PU3pM+y+AfwI2tbvAJhwHnAV8290XAv8LfK5iTMrnrp75pXz+AChfbloK/Gu1t6tsa+n568rmTulLv5ea2dPA94B3mdmdEwe4+153f6X8ch1wdntLbJ67P1P+/XlK1/zOqRgyCkz8F8kQ8Ex7qmtOrbm5+wF3f6n8831Ar5md3PZCGzMKjLr7lvLrjZSaYeWYJM8ddcwv8fM37r3A4+7+P1XeK/z8dWVzd/fV7j7k7nMo/bPpEXf/0MQxFde/llK68ZoMM3u9mZ0w/jPw18B/VQzbDHykfOf+HcB+d3+2zaUGq2duZnaqmVn553Mo/W99b7trbYS7PwfsNrO55U3nA7+qGJbkuYP65pfy+ZtgBdUvyUAbzl+3p2UmMbM1wIi7bwauNLOlwGHgRWBlJ2trwJ8BPyz//+M44F/c/QEz+3sAd/8OcB9wIfAU8EfgYx2qNVQ9c7sE+KSZHQYOAss9rSf2/gG4q/xP+98AH8vk3I2rNb+kz5+ZvQ74K+DvJmxr6/nTE6oiIhnqyssyIiK5U3MXEcmQmruISIbU3EVEMqTmLiKSITV3EZEMqbmLiGRIzV1EJEP/D+1KgcwTy4s9AAAAAElFTkSuQmCC\n", 199 | "text/plain": [ 200 | "
" 201 | ] 202 | }, 203 | "metadata": { 204 | "needs_background": "light" 205 | }, 206 | "output_type": "display_data" 207 | } 208 | ], 209 | "source": [ 210 | "plt.scatter(X[:50,0],X[:50,1], label='-1')\n", 211 | "plt.scatter(X[50:,0],X[50:,1], label='1')\n", 212 | "plt.legend()" 213 | ] 214 | }, 215 | { 216 | "cell_type": "markdown", 217 | "metadata": {}, 218 | "source": [ 219 | "----\n", 220 | "##### SMO算法\n", 221 | "算法7.5 P130" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 155, 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "class SVM:\n", 231 | " def __init__(self, max_iter=100, epsilon=0.001, C=1.0, kernel='linear'):\n", 232 | " self.max_iter = max_iter\n", 233 | " self.kernel = kernel\n", 234 | " self.epsilon = epsilon\n", 235 | " self.C = C\n", 236 | " \n", 237 | " def _init_parameters(self, X, y):\n", 238 | " '''\n", 239 | " 初始化一些参数\n", 240 | " '''\n", 241 | " self.X = X\n", 242 | " self.y = y\n", 243 | "\n", 244 | " self.b = 0.0\n", 245 | " self.M, self.N = X.shape\n", 246 | " self.alpha = np.ones(self.M)\n", 247 | " self.E = [self._E(i) for i in range(self.M)]\n", 248 | "\n", 249 | " def _kernel(self, x1, x2):\n", 250 | " #核函数\n", 251 | " if self.kernel == 'linear':\n", 252 | " return np.dot(x1, x2)\n", 253 | " \n", 254 | " def _gx(self, i):\n", 255 | " # g(x_i) 公式7.104\n", 256 | " #return np.sum(self.alpha * self.y * self._kernel(self.X, self.X[i]) + self.b)\n", 257 | " \n", 258 | " r = self.b\n", 259 | " for j in range(self.M):\n", 260 | " r += self.alpha[j]*self.y[j]*self._kernel(self.X[i], self.X[j])\n", 261 | " return r\n", 262 | " \n", 263 | " def _E(self, i):\n", 264 | " # 公式 7.105\n", 265 | " return self._gx(i) - self.y[i]\n", 266 | " \n", 267 | " def _KKT(self, i):\n", 268 | " # P130\n", 269 | " ygx = self.y[i] * self._gx(i)\n", 270 | " if self.alpha[i] == 0:\n", 271 | " return ygx >= 1\n", 272 | " elif 0 < self.alpha[i] < self.C:\n", 273 | " return ygx ==1\n", 274 | " else:\n", 275 | " return 
ygx <= 1\n", 276 | "    \n", 277 | "    def _init_alpha(self):\n", 278 | "        # choose the two variables as in section 7.4.2 of the book\n", 279 | "        # the outer loop first scans the points with 0 < alpha < C and checks the KKT conditions\n", 280 | "        index_list = [i for i in range(self.M) if 0 < self.alpha[i] < self.C]\n", 281 | "        # otherwise scan the whole training set\n", 282 | "        non_satisfy_list = [i for i in range(self.M) if i not in index_list]\n", 283 | "        index_list.extend(non_satisfy_list)\n", 284 | "        \n", 285 | "        for i in index_list:\n", 286 | "            if self._KKT(i):\n", 287 | "                continue\n", 288 | "            E1 = self.E[i]\n", 289 | "            # if E1 is positive, pick the smallest E as E2; if E1 is negative, pick the largest\n", 290 | "            if E1 >= 0:\n", 291 | "                j = np.argmin(self.E)\n", 292 | "            else:\n", 293 | "                j = np.argmax(self.E)\n", 294 | "            return i, j\n", 295 | "    \n", 296 | "    def _clip(self, alpha, L, H):\n", 297 | "        if alpha > H:\n", 298 | "            return H\n", 299 | "        elif alpha < L:\n", 300 | "            return L\n", 301 | "        else:\n", 302 | "            return alpha\n", 303 | "    \n", 304 | "    def fit(self, X, y):\n", 305 | "        self._init_parameters(X, y)\n", 306 | "        \n", 307 | "        for _iter in range(self.max_iter):\n", 308 | "            i1, i2 = self._init_alpha()\n", 309 | "            \n", 310 | "            #bound, P126\n", 311 | "            if self.y[i1] == self.y[i2]:\n", 312 | "                L = np.max((0, self.alpha[i2] + self.alpha[i1] - self.C))\n", 313 | "                H = np.min((self.C, self.alpha[i2] + self.alpha[i1]))\n", 314 | "            else:\n", 315 | "                L = np.max((0, self.alpha[i2] - self.alpha[i1]))\n", 316 | "                H = np.min((self.C, self.C + self.alpha[i2] - self.alpha[i1]))\n", 317 | "            \n", 318 | "            E1 = self.E[i1]\n", 319 | "            E2 = self.E[i2]\n", 320 | "            \n", 321 | "            #eta = K11 + K22 - 2K12, 7.107\n", 322 | "            eta = self._kernel(self.X[i1], self.X[i1]) + self._kernel(self.X[i2], self.X[i2]) - \\\n", 323 | "                  2 * self._kernel(self.X[i1], self.X[i2])\n", 324 | "            \n", 325 | "            alpha2_new_unc = self.alpha[i2] + self.y[i2] * (E1 - E2) / (eta + 1e-4) # 7.106\n", 326 | "            \n", 327 | "            alpha2_new = self._clip(alpha2_new_unc, L, H) # 7.108\n", 328 | "            \n", 329 | "            alpha1_new = self.alpha[i1] + self.y[i1] * self.y[i2] * (self.alpha[i2] - alpha2_new) # 7.109\n", 330 | "            \n", 331 | "            b1_new = -E1 - self.y[i1] * self._kernel(self.X[i1], self.X[i1]) * (alpha1_new - self.alpha[i1]) - \\\n", 332 | "                     self.y[i2] * self._kernel(self.X[i2], self.X[i1]) * (alpha2_new - self.alpha[i2]) + self.b # 7.115\n", 333 | "            \n", 334 | "            b2_new = -E2 - self.y[i1] * self._kernel(self.X[i1], self.X[i2]) * (alpha1_new - self.alpha[i1]) - \\\n", 335 | "                     self.y[i2] * self._kernel(self.X[i2], self.X[i2]) * (alpha2_new 
- self.alpha[i2]) + self.b # 7.116\n", 336 | " \n", 337 | " if 0 < alpha1_new < self.C and 0 < alpha2_new < self.C:\n", 338 | " b_new = b1_new\n", 339 | " else:\n", 340 | " b_new = (b1_new + b2_new) / 2 # 中点, P130\n", 341 | " \n", 342 | " # update parameters\n", 343 | " self.alpha[i1] = alpha1_new\n", 344 | " self.alpha[i2] = alpha2_new\n", 345 | " self.b = b_new\n", 346 | " \n", 347 | " self.E[i1] = self._E(i1)\n", 348 | " self.E[i2] = self._E(i2)\n", 349 | " \n", 350 | " return 'Done.'\n", 351 | " \n", 352 | " def predict(self, data):\n", 353 | " r = self.b\n", 354 | " for i in range(self.M):\n", 355 | " r += self.alpha[i] * self.y[i] * self._kernel(data, self.X[i])\n", 356 | " \n", 357 | " return 1 if r > 0 else -1\n", 358 | " \n", 359 | " def score(self, X_test, y_test):\n", 360 | " right_item = 0\n", 361 | " for i in range(len(X_test)):\n", 362 | " res = self.predict(X_test[i])\n", 363 | " if res == y_test[i]:\n", 364 | " right_item += 1\n", 365 | " return right_item / len(X_test)\n", 366 | " \n", 367 | " def _weight(self):\n", 368 | " yx = self.y.reshape(-1, 1) * self.X\n", 369 | " self.w = np.dot(yx.T, self.alpha)\n", 370 | " return self.w, self.b\n", 371 | " \n", 372 | "\n", 373 | "#https://blog.csdn.net/wds2006sdo/article/details/53156589\n", 374 | "#https://github.com/fengdu78/lihang-code/blob/master/code/%E7%AC%AC7%E7%AB% \\\n", 375 | "#A0%20%E6%94%AF%E6%8C%81%E5%90%91%E9%87%8F%E6%9C%BA(SVM)/support-vector-machine.ipynb" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 156, 381 | "metadata": {}, 382 | "outputs": [], 383 | "source": [ 384 | "svm = SVM(max_iter=1000)" 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": 157, 390 | "metadata": {}, 391 | "outputs": [ 392 | { 393 | "data": { 394 | "text/plain": [ 395 | "'Done.'" 396 | ] 397 | }, 398 | "execution_count": 157, 399 | "metadata": {}, 400 | "output_type": "execute_result" 401 | } 402 | ], 403 | "source": [ 404 | "svm.fit(X_train, y_train)" 405 | 
] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "execution_count": 158, 410 | "metadata": {}, 411 | "outputs": [ 412 | { 413 | "data": { 414 | "text/plain": [ 415 | "1.0" 416 | ] 417 | }, 418 | "execution_count": 158, 419 | "metadata": {}, 420 | "output_type": "execute_result" 421 | } 422 | ], 423 | "source": [ 424 | "svm.score(X_test, y_test)" 425 | ] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": 159, 430 | "metadata": {}, 431 | "outputs": [ 432 | { 433 | "data": { 434 | "text/plain": [ 435 | "(array([ 3.6, -5.7]), -3.8699999999999815)" 436 | ] 437 | }, 438 | "execution_count": 159, 439 | "metadata": {}, 440 | "output_type": "execute_result" 441 | } 442 | ], 443 | "source": [ 444 | "svm._weight() #array([ 3.6, -5.7])" 445 | ] 446 | }, 447 | { 448 | "cell_type": "markdown", 449 | "metadata": {}, 450 | "source": [ 451 | "## sklearn.svm.SVC" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": 169, 457 | "metadata": {}, 458 | "outputs": [ 459 | { 460 | "name": "stderr", 461 | "output_type": "stream", 462 | "text": [ 463 | "/Users/max/anaconda2/envs/pytorch/lib/python3.6/site-packages/sklearn/svm/base.py:196: FutureWarning: The default value of gamma will change from 'auto' to 'scale' in version 0.22 to account better for unscaled features. 
Set gamma explicitly to 'auto' or 'scale' to avoid this warning.\n", 464 | "  \"avoid this warning.\", FutureWarning)\n" 465 | ] 466 | }, 467 | { 468 | "data": { 469 | "text/plain": [ 470 | "SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n", 471 | "  decision_function_shape='ovr', degree=3, gamma='auto_deprecated',\n", 472 | "  kernel='rbf', max_iter=-1, probability=False, random_state=None,\n", 473 | "  shrinking=True, tol=0.001, verbose=False)" 474 | ] 475 | }, 476 | "execution_count": 169, 477 | "metadata": {}, 478 | "output_type": "execute_result" 479 | } 480 | ], 481 | "source": [ 482 | "from sklearn.svm import SVC\n", 483 | "clf = SVC()\n", 484 | "clf.fit(X_train, y_train)" 485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": 170, 490 | "metadata": {}, 491 | "outputs": [ 492 | { 493 | "data": { 494 | "text/plain": [ 495 | "1.0" 496 | ] 497 | }, 498 | "execution_count": 170, 499 | "metadata": {}, 500 | "output_type": "execute_result" 501 | } 502 | ], 503 | "source": [ 504 | "clf.score(X_test, y_test)" 505 | ] 506 | }, 507 | { 508 | "cell_type": "markdown", 509 | "metadata": {}, 510 | "source": [ 511 | "### sklearn.svm.SVC\n", 512 | "\n", 513 | "*(C=1.0, kernel='rbf', degree=3, gamma='auto', coef0=0.0, shrinking=True, probability=False, tol=0.001, cache_size=200, class_weight=None, verbose=False, max_iter=-1, decision_function_shape=None, random_state=None)*\n", 514 | "\n", 515 | "Parameters:\n", 516 | "\n", 517 | "- C: penalty parameter of the C-SVC. Default is 1.0.\n", 518 | "\n", 519 | "The larger C is, the more heavily the slack variables are penalized (pushed toward 0), so misclassification costs more and the model tends toward classifying the whole training set correctly: training accuracy is high, but generalization is weak. A small C penalizes misclassification less, tolerating some errors as noise, and usually generalizes better.\n", 520 | "\n", 521 | "- kernel: kernel function, 'rbf' by default; one of 'linear', 'poly', 'rbf', 'sigmoid', 'precomputed' \n", 522 | " \n", 523 | " – linear: u'v\n", 524 | " \n", 525 | " – polynomial: (gamma*u'*v + coef0)^degree\n", 526 | "\n", 527 | " – RBF: exp(-gamma|u-v|^2)\n", 528 | "\n", 529 | " – sigmoid: tanh(gamma*u'*v + coef0)\n", 530 | "\n", 531 | "\n", 532 | "- degree: degree of the 'poly' kernel, 3 by default; ignored by all other kernels.\n", 
533 | "\n", 534 | "\n", 535 | "- gamma: kernel coefficient for 'rbf', 'poly' and 'sigmoid'. The default 'auto' uses 1/n_features.\n", 536 | "\n", 537 | "\n", 538 | "- coef0: constant term of the kernel function; only used by 'poly' and 'sigmoid'.\n", 539 | "\n", 540 | "\n", 541 | "- probability: whether to enable probability estimates. Default is False.\n", 542 | "\n", 543 | "\n", 544 | "- shrinking: whether to use the shrinking heuristic. Default is True.\n", 545 | "\n", 546 | "\n", 547 | "- tol: tolerance for the stopping criterion. Default is 1e-3.\n", 548 | "\n", 549 | "\n", 550 | "- cache_size: size of the kernel cache in MB. Default is 200.\n", 551 | "\n", 552 | "\n", 553 | "- class_weight: class weights, passed as a dict; sets the parameter C of a given class to weight*C (the C of the C-SVC).\n", 554 | "\n", 555 | "\n", 556 | "- verbose: enable verbose output.\n", 557 | "\n", 558 | "\n", 559 | "- max_iter: maximum number of iterations; -1 means no limit.\n", 560 | "\n", 561 | "\n", 562 | "- decision_function_shape: 'ovo', 'ovr' or None, default=None\n", 563 | "\n", 564 | "\n", 565 | "- random_state: seed used when shuffling the data, an int.\n", 566 | "\n", 567 | "\n", 568 | "The parameters most worth tuning are C, kernel, degree, gamma and coef0." 569 | ] 570 | } 571 | ], 572 | "metadata": { 573 | "kernelspec": { 574 | "display_name": "Python 3", 575 | "language": "python", 576 | "name": "python3" 577 | }, 578 | "language_info": { 579 | "codemirror_mode": { 580 | "name": "ipython", 581 | "version": 3 582 | }, 583 | "file_extension": ".py", 584 | "mimetype": "text/x-python", 585 | "name": "python", 586 | "nbconvert_exporter": "python", 587 | "pygments_lexer": "ipython3", 588 | "version": "3.6.7" 589 | } 590 | }, 591 | "nbformat": 4, 592 | "nbformat_minor": 2 593 | } 594 | -------------------------------------------------------------------------------- /第08章 提升方法/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hktxt/Learn-Statistical-Learning-Method/14a3b65db3ff35c42743a1062a2a95cecbfad094/第08章 提升方法/.DS_Store -------------------------------------------------------------------------------- /第08章 提升方法/Adaboost.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | 
"metadata": { 5 | "colab": { 6 | "name": "Adaboost.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "language_info": { 12 | "codemirror_mode": { 13 | "name": "ipython", 14 | "version": 3 15 | }, 16 | "file_extension": ".py", 17 | "mimetype": "text/x-python", 18 | "name": "python", 19 | "nbconvert_exporter": "python", 20 | "pygments_lexer": "ipython3", 21 | "version": "3.6.2" 22 | }, 23 | "kernelspec": { 24 | "display_name": "Python 3", 25 | "language": "python", 26 | "name": "python3" 27 | } 28 | }, 29 | "cells": [ 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "id": "CGJ1QiK3cnzN", 34 | "colab_type": "text" 35 | }, 36 | "source": [ 37 | "# 第8章 提升方法" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": { 43 | "collapsed": true, 44 | "id": "v_MmENfgcnzN", 45 | "colab_type": "text" 46 | }, 47 | "source": [ 48 | "# Boost\n", 49 | "\n", 50 | "“装袋”(bagging)和“提升”(boost)是构建组合模型的两种最主要的方法,所谓的组合模型是由多个基本模型构成的模型,组合模型的预测效果往往比任意一个基本模型的效果都要好。\n", 51 | "\n", 52 | "- 装袋:每个基本模型由从总体样本中随机抽样得到的不同数据集进行训练得到,通过重抽样得到不同训练数据集的过程称为装袋。\n", 53 | "\n", 54 | "- 提升:每个基本模型训练时的数据集采用不同权重,针对上一个基本模型分类错误的样本增加权重,使得新的模型重点关注误分类样本\n", 55 | "\n", 56 | "### AdaBoost\n", 57 | "\n", 58 | "AdaBoost是AdaptiveBoost的缩写,表明该算法是具有适应性的提升算法。\n", 59 | "\n", 60 | "算法的步骤如下:\n", 61 | "\n", 62 | "1)给每个训练样本($x_{1},x_{2},….,x_{N}$)分配权重,初始权重$w_{1}$均为1/N。\n", 63 | "\n", 64 | "2)针对带有权值的样本进行训练,得到模型$G_m$(初始模型为G1)。\n", 65 | "\n", 66 | "3)计算模型$G_m$的误分率$e_m=\\sum_{i=1}^Nw_iI(y_i\\not= G_m(x_i))$\n", 67 | "\n", 68 | "4)计算模型$G_m$的系数$\\alpha_m=0.5\\log[(1-e_m)/e_m]$\n", 69 | "\n", 70 | "5)根据误分率e和当前权重向量$w_m$更新权重向量$w_{m+1}$。\n", 71 | "\n", 72 | "6)计算组合模型$f(x)=\\sum_{m=1}^M\\alpha_mG_m(x_i)$的误分率。\n", 73 | "\n", 74 | "7)当组合模型的误分率或迭代次数低于一定阈值,停止迭代;否则,回到步骤2)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "metadata": { 80 | "id": "WkmYWWexcnzO", 81 | "colab_type": "code", 82 | "colab": {} 83 | }, 84 | "source": [ 85 | "import numpy as np\n", 86 | "import pandas as 
pd\n", 87 | "from sklearn.datasets import load_iris\n", 88 | "from sklearn.tree import DecisionTreeClassifier\n", 89 | "from sklearn.model_selection import train_test_split\n", 90 | "import matplotlib.pyplot as plt\n", 91 | "%matplotlib inline" 92 | ], 93 | "execution_count": 0, 94 | "outputs": [] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "metadata": { 99 | "id": "kWFOcuTKcnzR", 100 | "colab_type": "code", 101 | "colab": {} 102 | }, 103 | "source": [ 104 | "# data\n", 105 | "def create_data():\n", 106 | " iris = load_iris()\n", 107 | " df = pd.DataFrame(iris.data, columns=iris.feature_names)\n", 108 | " df['label'] = iris.target\n", 109 | " df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n", 110 | " data = np.array(df.iloc[:100, [0, 1, -1]])\n", 111 | " for i in range(len(data)):\n", 112 | " if data[i,-1] == 0:\n", 113 | " data[i,-1] = -1\n", 114 | " # print(data)\n", 115 | " return data[:,:2], data[:,-1]" 116 | ], 117 | "execution_count": 0, 118 | "outputs": [] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "metadata": { 123 | "id": "uk2Mg38UcnzT", 124 | "colab_type": "code", 125 | "colab": {} 126 | }, 127 | "source": [ 128 | "X, y = create_data()\n", 129 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)" 130 | ], 131 | "execution_count": 0, 132 | "outputs": [] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "metadata": { 137 | "id": "FNCiiDMycnzW", 138 | "colab_type": "code", 139 | "outputId": "abb8a27d-9db0-449e-e78e-b019f70c2586", 140 | "colab": { 141 | "base_uri": "https://localhost:8080/", 142 | "height": 287 143 | } 144 | }, 145 | "source": [ 146 | "plt.scatter(X[:50,0],X[:50,1], label='0')\n", 147 | "plt.scatter(X[50:,0],X[50:,1], label='1')\n", 148 | "plt.legend()" 149 | ], 150 | "execution_count": 8, 151 | "outputs": [ 152 | { 153 | "output_type": "execute_result", 154 | "data": { 155 | "text/plain": [ 156 | "" 157 | ] 158 | }, 159 | "metadata": { 160 | "tags": [] 161 | }, 162 | 
"execution_count": 8 163 | }, 164 | { 165 | "output_type": "display_data", 166 | "data": { 167 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGZhJREFUeJzt3X+MXWWdx/H3d4dZOqvQCWVUmCk7\naE2jQNfCCJJuiAtxq7WWBtlS4q8qa3cNLhhcjBiC2piAS4LKkmgqZAFhi92K5cdCWQISf0RqpoDt\n2kpEQTsDuwyDLbIWaMfv/nHvtDO3M3Pvc+89c5/nuZ9X0sycc0/PfJ9z4Ns753zOc83dERGRvPxZ\nqwsQEZHmU3MXEcmQmruISIbU3EVEMqTmLiKSITV3EZEMqbmLiGRIzV1EJENq7iIiGTqi1g3NrAMY\nBIbdfXnFa2uAa4Hh8qob3P3GmfZ37LHHen9/f1CxIiLtbtu2bS+4e0+17Wpu7sClwC7g6Gle/667\nf7rWnfX39zM4OBjw40VExMx+W8t2NV2WMbM+4P3AjO/GRUQkDrVec/868DngTzNs80Ez225mm8xs\n/lQbmNlaMxs0s8GRkZHQWkVEpEZVm7uZLQeed/dtM2x2D9Dv7ouAB4FbptrI3de7+4C7D/T0VL1k\nJCIidarlmvsSYIWZLQPmAEeb2W3u/uHxDdx9dML2NwL/0twyRUSaZ//+/QwNDfHKK6+0upRpzZkz\nh76+Pjo7O+v6+1Wbu7tfAVwBYGbvBv55YmMvrz/O3Z8rL66gdONVRCRKQ0NDHHXUUfT392NmrS7n\nMO7O6OgoQ0NDnHjiiXXto+6cu5mtM7MV5cVLzOwXZvZz4BJgTb37FREp2iuvvMK8efOibOwAZsa8\nefMa+s0iJAqJuz8CPFL+/qoJ6w++uxfJzebHh7n2gSd5ds8+ju/u4vKlC1m5uLfVZUmDYm3s4xqt\nL6i5i7SbzY8Pc8WdO9i3fwyA4T37uOLOHQBq8BI1TT8gMoNrH3jyYGMft2//GNc+8GSLKpJcbNmy\nhYULF7JgwQKuueaapu9fzV1kBs/u2Re0XqQWY2NjXHzxxdx///3s3LmTDRs2sHPnzqb+DF2WEZnB\n8d1dDE/RyI/v7mpBNdIqzb7v8rOf/YwFCxbw5je/GYDVq1dz11138fa3v71ZJeudu8hMLl+6kK7O\njknrujo7uHzpwhZVJLNt/L7L8J59OIfuu2x+fLjq353O8PAw8+cfepC/r6+P4eH69zcVNXeRGaxc\n3MvV551Cb3cXBvR2d3H1eafoZmobSfW+iy7LiFSxcnGvmnkbK+K+S29vL7t37z64PDQ0RG9vc/8b\n0zt3EZEZTHd/pZH7Lu985zv51a9+xdNPP81rr73GHXfcwYoVK6r/xQBq7iIiMyjivssRRxzBDTfc\nwNKlS3nb297GqlWrOOmkkxotdfLPaOreREQyM35JrtlPKS9btoxly5Y1o8QpqbmLiFSR4n0XXZYR\nEcmQmruISIbU3EVEMqTmLiKSITV3EZEMqblLNjY/PsySax7mxM//J0uuebihuT9EivaJT3yCN7zh\nDZx88smF7F/NXbJQxOROIkVas2YNW7ZsKWz/au6ShVQnd5JEbN8IXzsZvtRd+rp9Y8O7POusszjm\nmGOaUNzU9BCTZEEfqiGF2b4R7rkE9pf/W9q7u7QMsGhV6+qqQu/cJQtFTO4kAsBD6w419nH795XW\nR0zNXbKgD9WQwuwdClsfCV2WkSwUNbmTCHP7SpdiplofMTV3yUaKkztJAs65avI1d4DOrtL6Blx4\n4YU88sgjvPDCC/T19
fHlL3+Ziy66qMFiD1Fzl4Y1+8ODRaIyftP0oXWlSzFz+0qNvcGbqRs2bGhC\ncdNTc5eGjOfLx2OI4/lyQA1e8rFoVdTJmKnohqo0RPlykTipuUtDlC+XVLl7q0uYUaP1qblLQ5Qv\nlxTNmTOH0dHRaBu8uzM6OsqcOXPq3oeuuUtDLl+6cNI1d1C+XOLX19fH0NAQIyMjrS5lWnPmzKGv\nr/64pZq7NET5cklRZ2cnJ554YqvLKFTNzd3MOoBBYNjdl1e8diRwK3AaMApc4O7PNLFOiZjy5SLx\nCXnnfimwCzh6itcuAn7v7gvMbDXwVeCCJtQnkhRl/iUWNd1QNbM+4P3AjdNsci5wS/n7TcA5ZmaN\nlyeSDs0pLzGpNS3zdeBzwJ+meb0X2A3g7geAvcC8hqsTSYgy/xKTqs3dzJYDz7v7tkZ/mJmtNbNB\nMxuM+S61SD2U+ZeY1PLOfQmwwsyeAe4Azjaz2yq2GQbmA5jZEcBcSjdWJ3H39e4+4O4DPT09DRUu\nEhtl/iUmVZu7u1/h7n3u3g+sBh529w9XbHY38LHy9+eXt4nz6QCRgmhOeYlJ3Tl3M1sHDLr73cBN\nwHfM7CngRUr/CIi0FWX+JSbWqjfYAwMDPjg42JKfLSKSKjPb5u4D1bbTE6oSrSs372DD1t2MudNh\nxoVnzOcrK09pdVkiSVBzlyhduXkHtz36u4PLY+4Hl9XgRarTrJASpQ1bp/jMyhnWi8hkau4SpbFp\n7gVNt15EJlNzlyh1TDN7xXTrRWQyNXeJ0oVnzA9aLyKT6YaqRGn8pqnSMiL1Uc5dRCQhyrlLQz70\n7Z/yk1+/eHB5yVuO4fZPntnCilpHc7RLinTNXQ5T2dgBfvLrF/nQt3/aoopaR3O0S6rU3OUwlY29\n2vqcaY52SZWau8gMNEe7pErNXWQGmqNdUqXmLodZ8pZjgtbnTHO0S6rU3OUwt3/yzMMaebumZVYu\n7uXq806ht7sLA3q7u7j6vFOUlpHoKecuIpIQ5dylIUVlu0P2q3y5SP3U3OUw49nu8QjgeLYbaKi5\nhuy3qBpE2oWuucthisp2h+xX+XKRxqi5y2GKynaH7Ff5cpHGqLnLYYrKdofsV/lykcaoucthisp2\nh+xX+XKRxuiGqhxm/IZls5MqIfstqgaRdqGcu4hIQpRzL1gMGezQGmKoWURmh5p7HWLIYIfWEEPN\nIjJ7dEO1DjFksENriKFmEZk9au51iCGDHVpDDDWLyOxRc69DDBns0BpiqFlEZo+aex1iyGCH1hBD\nzSIye3RDtQ4xZLBDa4ihZhGZPVVz7mY2B/ghcCSlfww2ufsXK7ZZA1wLjH8k/A3ufuNM+1XOXUQk\nXDNz7q8CZ7v7y2bWCfzYzO5390crtvuuu3+6nmJldly5eQcbtu5mzJ0OMy48Yz5fWXlKw9vGkp+P\npQ6RGFRt7l56a/9yebGz/Kc1j7VK3a7cvIPbHv3dweUx94PLlU07ZNtY8vOx1CESi5puqJpZh5k9\nATwPPOjuW6fY7INmtt3MNpnZ/KZWKQ3bsHV3zetDto0lPx9LHSKxqKm5u/uYu78D6ANON7OTKza5\nB+h390XAg8AtU+3HzNaa2aCZDY6MjDRStwQam+beylTrQ7aNJT8fSx0isQiKQrr7HuAHwHsr1o+6\n+6vlxRuB06b5++vdfcDdB3p6euqpV+rUYVbz+pBtY8nPx1KHSCyqNncz6zGz7vL3XcB7gF9WbHPc\nhMUVwK5mFimNu/CMqa+UTbU+ZNtY8vOx1CESi1rSMscBt5hZB6V/DDa6+71mtg4YdPe7gUvMbAVw\nAHgRWFNUwVKf8RuhtSRgQraNJT8fSx0isdB87iIiCdF87gUrKlMdki8vct8h40vxWCRn+0Z4aB3s\nHYK5fXDOVbBoVaurkoipudehqEx1SL68yH2HjC/FY5Gc7Rvhnktgfzn5s3d3aRnU4GV
amjisDkVl\nqkPy5UXuO2R8KR6L5Dy07lBjH7d/X2m9yDTU3OtQVKY6JF9e5L5DxpfisUjO3qGw9SKoudelqEx1\nSL68yH2HjC/FY5GcuX1h60VQc69LUZnqkHx5kfsOGV+KxyI551wFnRX/WHZ2ldaLTEM3VOtQVKY6\nJF9e5L5DxpfisUjO+E1TpWUkgHLuIiIJUc5dDhNDdl0Sp7x9MtTc20QM2XVJnPL2SdEN1TYRQ3Zd\nEqe8fVLU3NtEDNl1SZzy9klRc28TMWTXJXHK2ydFzb1NxJBdl8Qpb58U3VBtEzFk1yVxytsnRTl3\nEZGEKOdeVlReO2S/scxLrux6ZHLPjOc+vhAtOBZZN/ei8toh+41lXnJl1yOTe2Y89/GFaNGxyPqG\nalF57ZD9xjIvubLrkck9M577+EK06Fhk3dyLymuH7DeWecmVXY9M7pnx3McXokXHIuvmXlReO2S/\nscxLrux6ZHLPjOc+vhAtOhZZN/ei8toh+41lXnJl1yOTe2Y89/GFaNGxyPqGalF57ZD9xjIvubLr\nkck9M577+EK06Fgo5y4ikhDl3Aum/LxIIu69DLbdDD4G1gGnrYHl1zW+38hz/GrudVB+XiQR914G\ngzcdWvaxQ8uNNPgEcvxZ31AtivLzIonYdnPY+lolkONXc6+D8vMiifCxsPW1SiDHr+ZeB+XnRRJh\nHWHra5VAjl/NvQ7Kz4sk4rQ1YetrlUCOXzdU66D8vEgixm+aNjstk0COXzl3EZGENC3nbmZzgB8C\nR5a33+TuX6zY5kjgVuA0YBS4wN2fqaPuqkLz5anNYR6SXc/9WBSaIw7JPhdVR5HjizyD3ZDQseV8\nLGZQy2WZV4Gz3f1lM+sEfmxm97v7oxO2uQj4vbsvMLPVwFeBC5pdbGi+PLU5zEOy67kfi0JzxCHZ\n56LqKHJ8CWSw6xY6tpyPRRVVb6h6ycvlxc7yn8prOecCt5S/3wScY9b82EZovjy1OcxDsuu5H4tC\nc8Qh2eei6ihyfAlksOsWOracj0UVNaVlzKzDzJ4AngcedPetFZv0ArsB3P0AsBeYN8V+1prZoJkN\njoyMBBcbmi9PbQ7zkOx67sei0BxxSPa5qDqKHF8CGey6hY4t52NRRU3N3d3H3P0dQB9wupmdXM8P\nc/f17j7g7gM9PT3Bfz80X57aHOYh2fXcj0WhOeKQ7HNRdRQ5vgQy2HULHVvOx6KKoJy7u+8BfgC8\nt+KlYWA+gJkdAcyldGO1qULz5anNYR6SXc/9WBSaIw7JPhdVR5HjSyCDXbfQseV8LKqoJS3TA+x3\n9z1m1gW8h9IN04nuBj4G/BQ4H3jYC8hYhubLU5vDPCS7nvuxKDRHHJJ9LqqOIseXQAa7bqFjy/lY\nVFE1525miyjdLO2g9E5/o7uvM7N1wKC7312OS34HWAy8CKx299/MtF/l3EVEwjUt5+7u2yk17cr1\nV034/hXg70KLFBGRYmQ//UByD+7I7Ah5sCWGh2CKfHAntYe0YjgfCci6uSf34I7MjpAHW2J4CKbI\nB3dSe0grhvORiKxnhUzuwR2ZHSEPtsTwEEyRD+6k9pBWDOcjEVk39+Qe3JHZEfJgSwwPwRT54E5q\nD2nFcD4SkXVzT+7BHZkdIQ+2xPAQTJEP7qT2kFYM5yMRWTf35B7ckdkR8mBLDA/BFPngTmoPacVw\nPhKRdXNfubiXq887hd7uLgzo7e7i6vNO0c3UdrdoFXzgepg7H7DS1w9cP/UNuZBtY6g3dPuixpfa\nfjOkD+sQEUlI0x5iEml7IR/sEYvUao4lux5LHU2g5i4yk5AP9ohFajXHkl2PpY4myfqau0jDQj7Y\nIxap1RxLdj2WOppEzV1kJiEf7BGL1GqOJbseSx1NouYuMpOQD/aIRWo1x5Jdj6WOJlFzF5lJyAd7\nxCK1mmPJrsdSR5OouYvMZPl1MHDRoXe91lFajvH
G5LjUao4lux5LHU2inLuISEKUc5fZk2I2uKia\ni8qXp3iMpaXU3KUxKWaDi6q5qHx5isdYWk7X3KUxKWaDi6q5qHx5isdYWk7NXRqTYja4qJqLypen\neIyl5dTcpTEpZoOLqrmofHmKx1haTs1dGpNiNriomovKl6d4jKXl1NylMSlmg4uquah8eYrHWFpO\nOXcRkYTUmnPXO3fJx/aN8LWT4Uvdpa/bN87+fouqQSSQcu6Sh6Ky4CH7VR5dIqJ37pKHorLgIftV\nHl0iouYueSgqCx6yX+XRJSJq7pKHorLgIftVHl0iouYueSgqCx6yX+XRJSJq7pKHorLgIftVHl0i\nUjXnbmbzgVuBNwIOrHf3b1Rs827gLuDp8qo73X3Gu0jKuYuIhGvmfO4HgM+6+2NmdhSwzcwedPed\nFdv9yN2X11OsRCjF+cNDak5xfDHQcUtG1ebu7s8Bz5W//4OZ7QJ6gcrmLrlIMa+tPHrxdNySEnTN\n3cz6gcXA1ilePtPMfm5m95vZSU2oTVolxby28ujF03FLSs1PqJrZ64HvAZ9x95cqXn4M+Et3f9nM\nlgGbgbdOsY+1wFqAE044oe6ipWAp5rWVRy+ejltSanrnbmadlBr77e5+Z+Xr7v6Su79c/v4+oNPM\njp1iu/XuPuDuAz09PQ2WLoVJMa+tPHrxdNySUrW5m5kBNwG73H3KuUvN7E3l7TCz08v7HW1moTKL\nUsxrK49ePB23pNRyWWYJ8BFgh5k9UV73BeAEAHf/FnA+8CkzOwDsA1Z7q+YSlsaN3xxLKRURUnOK\n44uBjltSNJ+7iEhCmplzl1gpczzZvZfBtptLH0htHaWPt2v0U5BEEqXmnipljie79zIYvOnQso8d\nWlaDlzakuWVSpczxZNtuDlsvkjk191QpczyZj4WtF8mcmnuqlDmezDrC1otkTs09VcocT3bamrD1\nIplTc0+V5g6fbPl1MHDRoXfq1lFa1s1UaVPKuYuIJEQ59zpsfnyYax94kmf37OP47i4uX7qQlYt7\nW11W8+Sei899fDHQMU6GmnvZ5seHueLOHezbX0pXDO/ZxxV37gDIo8HnnovPfXwx0DFOiq65l137\nwJMHG/u4ffvHuPaBJ1tUUZPlnovPfXwx0DFOipp72bN79gWtT07uufjcxxcDHeOkqLmXHd/dFbQ+\nObnn4nMfXwx0jJOi5l52+dKFdHVOfuClq7ODy5cubFFFTZZ7Lj738cVAxzgpuqFaNn7TNNu0TO5z\ncec+vhjoGCdFOXcRkYTUmnPXZRmRFGzfCF87Gb7UXfq6fWMa+5aW0WUZkdgVmS9Xdj1beucuErsi\n8+XKrmdLzV0kdkXmy5Vdz5aau0jsisyXK7ueLTV3kdgVmS9Xdj1bau4isSty7n59LkC2lHMXEUmI\ncu4iIm1MzV1EJENq7iIiGVJzFxHJkJq7iEiG1NxFRDKk5i4ikiE1dxGRDFVt7mY238x+YGY7zewX\nZnbpFNuYmV1vZk+Z2XYzO7WYcqUhmrdbpG3UMp/7AeCz7v6YmR0FbDOzB91954Rt3ge8tfznDOCb\n5a8SC83bLdJWqr5zd/fn3P2x8vd/AHYBlR8sei5wq5c8CnSb2XFNr1bqp3m7RdpK0DV3M+sHFgNb\nK17qBXZPWB7i8H8AMLO1ZjZoZoMjIyNhlUpjNG+3SFupubmb2euB7wGfcfeX6vlh7r7e3QfcfaCn\np6eeXUi9NG+3SFupqbmbWSelxn67u985xSbDwPwJy33ldRILzdst0lZqScsYcBOwy92vm2azu4GP\nllMz7wL2uvtzTaxTGqV5u0XaSi1pmSXAR4AdZvZEed0XgBMA3P1bwH3AMuAp4I/Ax5tfqjRs0So1\nc5E2UbW5u/uPAauyjQMXN6soERFpjJ5QFRHJkJq7iEiG1NxFRDKk5i4ikiE1dxGRDKm5i4hkSM1d\nRCRDVoqot+A
Hm40Av23JD6/uWOCFVhdRII0vXTmPDTS+Wvylu1ednKtlzT1mZjbo7gOtrqMoGl+6\nch4baHzNpMsyIiIZUnMXEcmQmvvU1re6gIJpfOnKeWyg8TWNrrmLiGRI79xFRDLU1s3dzDrM7HEz\nu3eK19aY2YiZPVH+8/etqLERZvaMme0o1z84xetmZteb2VNmtt3MTm1FnfWoYWzvNrO9E85fUh85\nZWbdZrbJzH5pZrvM7MyK15M9d1DT+JI9f2a2cELdT5jZS2b2mYptCj9/tXxYR84uBXYBR0/z+nfd\n/dOzWE8R/sbdp8vVvg94a/nPGcA3y19TMdPYAH7k7stnrZrm+gawxd3PN7M/B/6i4vXUz1218UGi\n58/dnwTeAaU3kJQ+cvT7FZsVfv7a9p27mfUB7wdubHUtLXQucKuXPAp0m9lxrS6q3ZnZXOAsSh9v\nibu/5u57KjZL9tzVOL5cnAP82t0rH9gs/Py1bXMHvg58DvjTDNt8sPwr0yYzmz/DdrFy4L/MbJuZ\nrZ3i9V5g94TlofK6FFQbG8CZZvZzM7vfzE6azeIadCIwAvxb+bLhjWb2uoptUj53tYwP0j1/E60G\nNkyxvvDz15bN3cyWA8+7+7YZNrsH6Hf3RcCDwC2zUlxz/bW7n0rpV8CLzeysVhfURNXG9hilx7T/\nCvhXYPNsF9iAI4BTgW+6+2Lg/4DPt7akpqplfCmfPwDKl5tWAP/Rip/fls2d0od+rzCzZ4A7gLPN\n7LaJG7j7qLu/Wl68EThtdktsnLsPl78+T+ma3+kVmwwDE38j6Suvi161sbn7S+7+cvn7+4BOMzt2\n1gutzxAw5O5by8ubKDXDiZI9d9QwvsTP37j3AY+5+/9O8Vrh568tm7u7X+Hufe7eT+nXpofd/cMT\nt6m4/rWC0o3XZJjZ68zsqPHvgb8F/rtis7uBj5bv3L8L2Ovuz81yqcFqGZuZvcnMrPz96ZT+Wx+d\n7Vrr4e7/A+w2s4XlVecAOys2S/LcQW3jS/n8TXAhU1+SgVk4f+2elpnEzNYBg+5+N3CJma0ADgAv\nAmtaWVsd3gh8v/z/xxHAv7v7FjP7RwB3/xZwH7AMeAr4I/DxFtUaqpaxnQ98yswOAPuA1Z7WE3v/\nBNxe/tX+N8DHMzl346qNL+nzV37T8R7gHyasm9XzpydURUQy1JaXZUREcqfmLiKSITV3EZEMqbmL\niGRIzV1EJENq7iIiGVJzFxHJkJq7iEiG/h86qpKOmdh1nwAAAABJRU5ErkJggg==\n", 168 | "text/plain": [ 169 | "
" 170 | ] 171 | }, 172 | "metadata": { 173 | "tags": [] 174 | } 175 | } 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "metadata": { 181 | "id": "9AlbKsaNhY-J", 182 | "colab_type": "code", 183 | "colab": {} 184 | }, 185 | "source": [ 186 | "# weak classifier\n", 187 | "weak_cla = DecisionTreeClassifier(max_depth=1)" 188 | ], 189 | "execution_count": 0, 190 | "outputs": [] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "metadata": { 195 | "id": "AS7dimPwhp2r", 196 | "colab_type": "code", 197 | "colab": { 198 | "base_uri": "https://localhost:8080/", 199 | "height": 125 200 | }, 201 | "outputId": "3f3c28e2-7fa9-437f-db5c-da74f6acb084" 202 | }, 203 | "source": [ 204 | "# fit\n", 205 | "weak_cla.fit(X_train, y_train)" 206 | ], 207 | "execution_count": 46, 208 | "outputs": [ 209 | { 210 | "output_type": "execute_result", 211 | "data": { 212 | "text/plain": [ 213 | "DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=1,\n", 214 | " max_features=None, max_leaf_nodes=None,\n", 215 | " min_impurity_decrease=0.0, min_impurity_split=None,\n", 216 | " min_samples_leaf=1, min_samples_split=2,\n", 217 | " min_weight_fraction_leaf=0.0, presort=False,\n", 218 | " random_state=None, splitter='best')" 219 | ] 220 | }, 221 | "metadata": { 222 | "tags": [] 223 | }, 224 | "execution_count": 46 225 | } 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "metadata": { 231 | "id": "xOWWIY5ph00J", 232 | "colab_type": "code", 233 | "colab": { 234 | "base_uri": "https://localhost:8080/", 235 | "height": 35 236 | }, 237 | "outputId": "088e2e42-4db2-46e7-e76f-d7b2d048327c" 238 | }, 239 | "source": [ 240 | "weak_cla_accuracy = weak_cla.score(X_test, y_test);weak_cla_accuracy" 241 | ], 242 | "execution_count": 47, 243 | "outputs": [ 244 | { 245 | "output_type": "execute_result", 246 | "data": { 247 | "text/plain": [ 248 | "0.85" 249 | ] 250 | }, 251 | "metadata": { 252 | "tags": [] 253 | }, 254 | "execution_count": 47 255 | } 256 | ] 257 | }, 258 | { 259 | 
"cell_type": "markdown", 260 | "metadata": { 261 | "id": "9ezbeV0Tcnza", 262 | "colab_type": "text" 263 | }, 264 | "source": [ 265 | "----\n", 266 | "\n", 267 | "### AdaBoost\n", 268 | "算法 8.1" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "metadata": { 274 | "id": "5Ftb2TtQcnzb", 275 | "colab_type": "code", 276 | "colab": {} 277 | }, 278 | "source": [ 279 | "class AdaBoost:\n", 280 | " def __init__(self, n_estimators=100):\n", 281 | " self.clf_num = n_estimators\n", 282 | " \n", 283 | " def init_args(self, X, y):\n", 284 | " \n", 285 | " self.X = X\n", 286 | " self.y = y\n", 287 | " M, _ = X.shape\n", 288 | " \n", 289 | " self.models = []\n", 290 | " self.alphas = []\n", 291 | " self.weights = np.ones(M) / M # 1\n", 292 | " \n", 293 | " def fit(self, X, y):\n", 294 | " self.init_args(X, y)\n", 295 | " \n", 296 | " for n in range(self.clf_num):\n", 297 | " cla = DecisionTreeClassifier(max_depth=1) # weak cla\n", 298 | " cla.fit(X, y, sample_weight=self.weights) # 2(a)\n", 299 | " P = cla.predict(X) \n", 300 | " \n", 301 | " err = self.weights.dot(P != y) # 2(b) 8.1\n", 302 | " alpha = 0.5*(np.log(1 - err) - np.log(err)) # 2(c) 8.2\n", 303 | " \n", 304 | " self.weights = self.weights * np.exp(-alpha * y * P)\n", 305 | " self.weights = self.weights / self.weights.sum() # 2(d) 8.3, 8.4, 8.5\n", 306 | " \n", 307 | " self.models.append(cla)\n", 308 | " self.alphas.append(alpha)\n", 309 | " \n", 310 | " return 'Done!'\n", 311 | " \n", 312 | " def predict(self, x):\n", 313 | " N, _ = x.shape\n", 314 | " FX = np.zeros(N)\n", 315 | " \n", 316 | " for alpha, cla in zip(self.alphas, self.models):\n", 317 | " FX += alpha * cla.predict(x)\n", 318 | "\n", 319 | " return np.sign(FX)\n", 320 | " \n", 321 | " def score(self, X_test, y_test):\n", 322 | " p = self.predict(X_test)\n", 323 | " r = np.sum(p == y_test)\n", 324 | " \n", 325 | " return r/len(X_test)\n", 326 | " \n", 327 | " def _weights(self):\n", 328 | " return self.alphas, self.weights, self.models" 329 | ], 
330 | "execution_count": 0, 331 | "outputs": [] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "metadata": { 336 | "id": "onbrb8Xsr5DE", 337 | "colab_type": "code", 338 | "colab": {} 339 | }, 340 | "source": [ 341 | "adaboost = AdaBoost()" 342 | ], 343 | "execution_count": 0, 344 | "outputs": [] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "metadata": { 349 | "id": "kae_0imhr-Ld", 350 | "colab_type": "code", 351 | "colab": { 352 | "base_uri": "https://localhost:8080/", 353 | "height": 35 354 | }, 355 | "outputId": "e6dc8e11-3ede-4ead-e35a-1ba25b26e056" 356 | }, 357 | "source": [ 358 | "adaboost.fit(X_train, y_train)" 359 | ], 360 | "execution_count": 112, 361 | "outputs": [ 362 | { 363 | "output_type": "execute_result", 364 | "data": { 365 | "text/plain": [ 366 | "'Done!'" 367 | ] 368 | }, 369 | "metadata": { 370 | "tags": [] 371 | }, 372 | "execution_count": 112 373 | } 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "metadata": { 379 | "id": "bhWzbeAEsl64", 380 | "colab_type": "code", 381 | "colab": { 382 | "base_uri": "https://localhost:8080/", 383 | "height": 35 384 | }, 385 | "outputId": "0944a671-36b0-4347-9da4-271ce20315be" 386 | }, 387 | "source": [ 388 | "adaboost.score(X_test, y_test)" 389 | ], 390 | "execution_count": 113, 391 | "outputs": [ 392 | { 393 | "output_type": "execute_result", 394 | "data": { 395 | "text/plain": [ 396 | "1.0" 397 | ] 398 | }, 399 | "metadata": { 400 | "tags": [] 401 | }, 402 | "execution_count": 113 403 | } 404 | ] 405 | }, 406 | { 407 | "cell_type": "markdown", 408 | "metadata": { 409 | "id": "VQhHJ_rDcnzd", 410 | "colab_type": "text" 411 | }, 412 | "source": [ 413 | "### 例8.1" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "metadata": { 419 | "id": "qEKxLH3Hcnzd", 420 | "colab_type": "code", 421 | "colab": {} 422 | }, 423 | "source": [ 424 | "X_ = np.arange(10).reshape(10, 1)\n", 425 | "y_ = np.array([1, 1, 1, -1, -1, -1, 1, 1, 1, -1])" 426 | ], 427 | "execution_count": 0, 428 | "outputs": [] 429 
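The notebook uses scikit-learn's `DecisionTreeClassifier` as the weak learner. For reference, 算法 8.1 can also be made fully self-contained by brute-forcing the decision stump in NumPy; the sketch below is my own condensed rework (the stump search and the `1e-10` error clipping are mine, not from the notebook), checked against the Example 8.1 data defined above, where three rounds of boosting already classify all ten points correctly.

```python
import numpy as np

def fit_stump(X, y, w):
    """Weighted decision stump: returns ((j, t, s), err), predicting s where
    X[:, j] <= t and -s otherwise, minimizing the weighted error (step 2(a))."""
    best, best_err = None, np.inf
    for j in range(X.shape[1]):
        for t in np.unique(X[:, j]):
            for s in (1, -1):
                pred = np.where(X[:, j] <= t, s, -s)
                err = w @ (pred != y)
                if err < best_err:
                    best, best_err = (j, t, s), err
    return best, best_err

def adaboost_fit(X, y, n_estimators=10):
    """Algorithm 8.1; y must take values in {-1, +1}."""
    w = np.full(len(X), 1.0 / len(X))              # step 1: uniform weights
    ensemble = []
    for _ in range(n_estimators):
        (j, t, s), err = fit_stump(X, y, w)        # 2(a)
        err = min(max(err, 1e-10), 1 - 1e-10)      # guard against log(0)
        alpha = 0.5 * np.log((1 - err) / err)      # 2(c), eq. 8.2
        pred = np.where(X[:, j] <= t, s, -s)
        w = w * np.exp(-alpha * y * pred)          # 2(d), eqs. 8.3-8.5
        w /= w.sum()
        ensemble.append((alpha, (j, t, s)))
    return ensemble

def adaboost_predict(ensemble, X):
    # G(x) = sign(sum_m alpha_m * G_m(x))
    fx = sum(a * np.where(X[:, j] <= t, s, -s) for a, (j, t, s) in ensemble)
    return np.sign(fx)

# Example 8.1 data: three boosting rounds suffice for zero training error
X_ = np.arange(10).reshape(-1, 1)
y_ = np.array([1, 1, 1, -1, -1, -1, 1, 1, 1, -1])
ens = adaboost_fit(X_, y_, n_estimators=3)
```

The first round recovers the book's stump (split at x < 2.5 with α₁ ≈ 0.4236), and `adaboost_predict(ens, X_)` matches `y_` on all ten points.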
| }, 430 | { 431 | "cell_type": "code", 432 | "metadata": { 433 | "id": "iGr8lCCicnzg", 434 | "colab_type": "code", 435 | "colab": { 436 | "base_uri": "https://localhost:8080/", 437 | "height": 35 438 | }, 439 | "outputId": "c1bffec1-902d-4076-e4f4-0318b9acfddd" 440 | }, 441 | "source": [ 442 | "clf = AdaBoost()\n", 443 | "clf.fit(X_, y_)" 444 | ], 445 | "execution_count": 115, 446 | "outputs": [ 447 | { 448 | "output_type": "execute_result", 449 | "data": { 450 | "text/plain": [ 451 | "'Done!'" 452 | ] 453 | }, 454 | "metadata": { 455 | "tags": [] 456 | }, 457 | "execution_count": 115 458 | } 459 | ] 460 | }, 461 | { 462 | "cell_type": "markdown", 463 | "metadata": { 464 | "id": "YTbvHicmcnzq", 465 | "colab_type": "text" 466 | }, 467 | "source": [ 468 | "-----\n", 469 | "# sklearn.ensemble.AdaBoostClassifier\n", 470 | "\n", 471 | "- algorithm:这个参数只有AdaBoostClassifier有。主要原因是scikit-learn实现了两种Adaboost分类算法,SAMME和SAMME.R。两者的主要区别是弱学习器权重的度量,SAMME使用了和我们的原理篇里二元分类Adaboost算法的扩展,即用对样本集分类效果作为弱学习器权重,而SAMME.R使用了对样本集分类的预测概率大小来作为弱学习器权重。由于SAMME.R使用了概率度量的连续值,迭代一般比SAMME快,因此AdaBoostClassifier的默认算法algorithm的值也是SAMME.R。我们一般使用默认的SAMME.R就够了,但是要注意的是使用了SAMME.R, 则弱分类学习器参数base_estimator必须限制使用支持概率预测的分类器。SAMME算法则没有这个限制。\n", 472 | "\n", 473 | "- n_estimators: AdaBoostClassifier和AdaBoostRegressor都有,就是我们的弱学习器的最大迭代次数,或者说最大的弱学习器的个数。一般来说n_estimators太小,容易欠拟合,n_estimators太大,又容易过拟合,一般选择一个适中的数值。默认是50。在实际调参的过程中,我们常常将n_estimators和下面介绍的参数learning_rate一起考虑。\n", 474 | "\n", 475 | "- learning_rate: AdaBoostClassifier和AdaBoostRegressor都有,即每个弱学习器的权重缩减系数ν\n", 476 | "\n", 477 | "- base_estimator:AdaBoostClassifier和AdaBoostRegressor都有,即我们的弱分类学习器或者弱回归学习器。理论上可以选择任何一个分类或者回归学习器,不过需要支持样本权重。我们常用的一般是CART决策树或者神经网络MLP。" 478 | ] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "metadata": { 483 | "id": "CGLto18Ycnzr", 484 | "colab_type": "code", 485 | "outputId": "4c41d66b-b820-4dda-a212-9d18db023e1b", 486 | "colab": { 487 | "base_uri": "https://localhost:8080/", 488 | "height": 53 489 | } 490 | }, 491 | "source": [ 492 | 
"from sklearn.ensemble import AdaBoostClassifier\n", 493 | "clf = AdaBoostClassifier(n_estimators=100, learning_rate=0.5)\n", 494 | "clf.fit(X_train, y_train)" 495 | ], 496 | "execution_count": 86, 497 | "outputs": [ 498 | { 499 | "output_type": "execute_result", 500 | "data": { 501 | "text/plain": [ 502 | "AdaBoostClassifier(algorithm='SAMME.R', base_estimator=None, learning_rate=0.5,\n", 503 | " n_estimators=100, random_state=None)" 504 | ] 505 | }, 506 | "metadata": { 507 | "tags": [] 508 | }, 509 | "execution_count": 86 510 | } 511 | ] 512 | }, 513 | { 514 | "cell_type": "code", 515 | "metadata": { 516 | "id": "XHy-rTRwcnzu", 517 | "colab_type": "code", 518 | "outputId": "41b0c4e7-100f-4b18-f25d-a15e33d6060e", 519 | "colab": { 520 | "base_uri": "https://localhost:8080/", 521 | "height": 35 522 | } 523 | }, 524 | "source": [ 525 | "clf.score(X_test, y_test)" 526 | ], 527 | "execution_count": 87, 528 | "outputs": [ 529 | { 530 | "output_type": "execute_result", 531 | "data": { 532 | "text/plain": [ 533 | "1.0" 534 | ] 535 | }, 536 | "metadata": { 537 | "tags": [] 538 | }, 539 | "execution_count": 87 540 | } 541 | ] 542 | } 543 | ] 544 | } -------------------------------------------------------------------------------- /第10章 隐马尔可夫模型/HMM.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "HMM.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | } 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "metadata": { 20 | "id": "_Esy4bIi3E4L", 21 | "colab_type": "text" 22 | }, 23 | "source": [ 24 | "# HMM" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": { 30 | "id": "jN7qEs6b3UAf", 31 | "colab_type": "text" 32 | }, 33 | "source": [ 34 | "Hidden Markov Model, HMM 
是可用于标注问题的统计学习模型,属于**生成模型**。由**初始概率分布**,**状态转移概率分布**和**观测概率分布**确定。\n", 35 | "$\\lambda = (A, B, \\pi)$. \n", 36 | "\n", 37 | "\n", 38 | "### **两个基本假设:** \n", 39 | "1). 齐次马尔可夫性假设,即假设隐藏的马尔可夫链在任意时刻$t$的状态只依赖于其前一时刻的状态,与其他时刻状态及观测无关,也与时刻$t$无关: \n", 40 | "$P(i_{t}|i_{t-1},o_{t-1},...,i_{1},o_{1}) = P(i_{t}|i_{t-1}), t = 1,2,...,T$ \n", 41 | "\n", 42 | "2). 观测独立性假设,即假设任意时刻的观测只依赖于该时刻的马尔可夫链的状态,与其他观测及状态无关: \n", 43 | "$P(o_{t}|i_{T},o_{T},i_{T-1},o_{T-1},...,i_{t+1},o_{t+1},i_{t},i_{t-1},o_{t-1},...,i_{1},o_{1}) = P(o_{t}|i_{t})$\n" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": { 49 | "id": "W1ADpieg1V3d", 50 | "colab_type": "text" 51 | }, 52 | "source": [ 53 | "### **三个基本问题**: \n", 54 | "\n", 55 | "1. **概率计算问题**。给定模型$\\lambda = (A, B, \\pi)$ 和观测序列 $O=(o_{1},o_{2},...,o_{T})$, 计算在模型 $\\lambda$ 下观测序列 $O$ 出现的概率 $P(O|\\lambda)$. \n", 56 | "Evaluate $P(O|\\lambda)$.\n", 57 | "\n", 58 | "2. **学习问题**。已知观测序列 $O=(o_{1},o_{2},...,o_{T})$, 估计模型 $\\lambda = (A, B, \\pi)$ 的参数,使得在该模型下观测序列概率 $P(O|\\lambda)$ 最大。即用极大似然估计的方法估计参数。 \n", 59 | "$\\lambda_{MLE} = argmax_{\\lambda}P(O|\\lambda)$.\n", 60 | "\n", 61 | "3. **预测问题**。也称为解码(decoding) 问题。已知模型 $\\lambda = (A, B, \\pi)$ 和观测序列 $O=(o_{1},o_{2},...,o_{T})$,求给定观测序列条件概率 $P(I|O)$ 最大的状态序列 $I = (i_{1}, i_{2}, i_{3},...,i_{T})$. 即给定观测序列,求最有可能的对应的状态序列。 \n", 62 | "$argmax_{I}P(I|O,\\lambda)$\n" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": { 68 | "id": "oguS4aSW4Y_Q", 69 | "colab_type": "text" 70 | }, 71 | "source": [ 72 | "问题1. 前向(forward)和后向(backward)算法。 \n", 73 | "问题2. Baum-Welch 算法。 \n", 74 | "问题3. 
近似算法,维特比算法。 " 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": { 80 | "id": "L6QU6kS17PgN", 81 | "colab_type": "text" 82 | }, 83 | "source": [ 84 | "---------------------------------------------------------------------------------------------------------------------------------" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": { 90 | "id": "6aNZwSt17DFG", 91 | "colab_type": "text" 92 | }, 93 | "source": [ 94 | "以下来自徐亦达的课程 \n", 95 | "视频地址:https://www.youtube.com/watch?v=Ji6KbkyNmk8 \n", 96 | "lecture: https://github.com/roboticcam/machine-learning-notes/blob/master/dynamic_model.pdf" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": { 102 | "id": "lyPGx2RqGWrp", 103 | "colab_type": "text" 104 | }, 105 | "source": [ 106 | "# 概率计算问题" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": { 112 | "id": "F_XYKDZa9wdQ", 113 | "colab_type": "text" 114 | }, 115 | "source": [ 116 | "---------------------------------------------------------------------------------------------------------------------------------" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": { 122 | "id": "596GcIzAqwRs", 123 | "colab_type": "text" 124 | }, 125 | "source": [ 126 | "### 例 10.2 \n" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": { 132 | "id": "MuSecU91q2ey", 133 | "colab_type": "text" 134 | }, 135 | "source": [ 136 | "考虑盒子和球模型$\\lambda = (A, B, \\pi)$ , 状态集合 $Q = {1,2,3}$, 观测集合 $V = {红,白}$, \n", 137 | "\n", 138 | "\n", 139 | "$A=\\begin{bmatrix}\n", 140 | " 0.5& 0.2& 0.3\\\\ \n", 141 | " 0.3& 0.5& 0.2\\\\ \n", 142 | " 0.2& 0.3& 0.5\n", 143 | "\\end{bmatrix}, \n", 144 | "B=\\begin{bmatrix}\n", 145 | " 0.5& 0.5\\\\ \n", 146 | " 0.4& 0.6\\\\ \n", 147 | " 0.7& 0.3\n", 148 | "\\end{bmatrix},\n", 149 | "\\pi=\\begin{bmatrix}\n", 150 | " 0.2\\\\ \n", 151 | " 0.4\\\\ \n", 152 | " 0.4\n", 153 | "\\end{bmatrix}$ \n", 154 | "\n", 155 | "设$T=3, 
Q={红,白,红}$,试用前向算法计算$P(O|\\lambda)$. \n", 156 | "\n", 157 | "\n", 158 | "\n" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "metadata": { 164 | "id": "gDkjQ0uJy_-M", 165 | "colab_type": "code", 166 | "colab": {} 167 | }, 168 | "source": [ 169 | "A = [[0.5, 0.2, 0.3], [0.3, 0.5, 0.2],[0.2, 0.3, 0.5]]\n", 170 | "B = [[0.5, 0.5], [0.4, 0.6], [0.7, 0.3]]\n", 171 | "pi = [0.2, 0.4, 0.4]\n", 172 | "Q = [0,1,0]" 173 | ], 174 | "execution_count": 0, 175 | "outputs": [] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "metadata": { 180 | "id": "zE5Yt4jdu_VP", 181 | "colab_type": "code", 182 | "colab": {} 183 | }, 184 | "source": [ 185 | "import numpy as np" 186 | ], 187 | "execution_count": 0, 188 | "outputs": [] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "metadata": { 193 | "id": "ZwRR7eFrHCnJ", 194 | "colab_type": "code", 195 | "colab": {} 196 | }, 197 | "source": [ 198 | "class HMM_fw:\n", 199 | " def __init__(self, A, B, pi):\n", 200 | " self.A = A # 状态转移概率\n", 201 | " self.B = B # 观测概率\n", 202 | " self.pi = pi # 初始状态\n", 203 | " \n", 204 | " def forward(self, Q):\n", 205 | " T = len(Q) # 观测序列长度,时刻T\n", 206 | " M = len(self.A) # 状态数\n", 207 | " alpha = np.zeros((T, M))\n", 208 | " \n", 209 | " for t in range(T):\n", 210 | " for m in range(M):\n", 211 | " if t == 0:\n", 212 | " alpha[t][m] = self.pi[m] * self.B[m][Q[t]]\n", 213 | " print(\"alpha[{}][{}] = pi[{}] * B[{}](Q{}) = {:.2f}\".format(t+1, m+1, m+1, m+1, Q[t]+1, alpha[t][m]))\n", 214 | " else:\n", 215 | " alpha[t][m] = sum([alpha[t-1][i] * self.A[i][m] for i in range(len(alpha[t-1]))]) * self.B[m][Q[t]]\n", 216 | " print(\"alpha[{}][{}] = {:.5f}\".format(t+1, m+1, alpha[t][m]))\n", 217 | " \n", 218 | " p = sum(alpha[T-1])\n", 219 | " #print(p)\n", 220 | " return p" 221 | ], 222 | "execution_count": 0, 223 | "outputs": [] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "metadata": { 228 | "id": "HfqXwAO1y9G2", 229 | "colab_type": "code", 230 | "colab": { 231 | "base_uri": 
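As a cross-check on the loop-based `HMM_fw` (and the backward pass used below), the same computation can be vectorized, one matrix operation per time step; the following standalone sketch (variable and function names are my own) re-runs Example 10.2 and reproduces P(O|λ) = 0.130218 from both directions, a useful sanity check since forward and backward must agree exactly.

```python
import numpy as np

# Box-and-ball model of Example 10.2
A = np.array([[0.5, 0.2, 0.3], [0.3, 0.5, 0.2], [0.2, 0.3, 0.5]])  # transitions
B = np.array([[0.5, 0.5], [0.4, 0.6], [0.7, 0.3]])                 # emissions
pi = np.array([0.2, 0.4, 0.4])                                     # initial dist.
O = [0, 1, 0]                                                      # red, white, red

def forward_prob(A, B, pi, O):
    alpha = pi * B[:, O[0]]                # initialization
    for o in O[1:]:
        alpha = (alpha @ A) * B[:, o]      # recursion over t = 2, ..., T
    return alpha.sum()                     # termination

def backward_prob(A, B, pi, O):
    beta = np.ones(len(pi))                # initialization: beta_T(i) = 1
    for o in O[:0:-1]:                     # observations o_T, ..., o_2
        beta = A @ (B[:, o] * beta)        # recursion over t = T-1, ..., 1
    return (pi * B[:, O[0]] * beta).sum()  # termination

p_fwd = forward_prob(A, B, pi, O)
p_bwd = backward_prob(A, B, pi, O)
print(round(p_fwd, 6))  # 0.130218
```

The intermediate `alpha` vectors equal the notebook's printed values ([0.10, 0.16, 0.28], then [0.077, 0.1104, 0.0606], then [0.04187, 0.03551, 0.05284]).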
"https://localhost:8080/", 232 | "height": 197 233 | }, 234 | "outputId": "a6dcb119-4da5-4dd6-a18c-593bfd70eaa4" 235 | }, 236 | "source": [ 237 | "m = HMM_fw(A, B, pi)\n", 238 | "m.forward(Q)" 239 | ], 240 | "execution_count": 79, 241 | "outputs": [ 242 | { 243 | "output_type": "stream", 244 | "text": [ 245 | "alpha[1][1] = pi[1] * B[1](Q1) = 0.10\n", 246 | "alpha[1][2] = pi[2] * B[2](Q1) = 0.16\n", 247 | "alpha[1][3] = pi[3] * B[3](Q1) = 0.28\n", 248 | "alpha[2][1] = 0.07700\n", 249 | "alpha[2][2] = 0.11040\n", 250 | "alpha[2][3] = 0.06060\n", 251 | "alpha[3][1] = 0.04187\n", 252 | "alpha[3][2] = 0.03551\n", 253 | "alpha[3][3] = 0.05284\n" 254 | ], 255 | "name": "stdout" 256 | }, 257 | { 258 | "output_type": "execute_result", 259 | "data": { 260 | "text/plain": [ 261 | "0.130218" 262 | ] 263 | }, 264 | "metadata": { 265 | "tags": [] 266 | }, 267 | "execution_count": 79 268 | } 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "metadata": { 274 | "id": "9LcfAGUVzaU3", 275 | "colab_type": "code", 276 | "colab": {} 277 | }, 278 | "source": [ 279 | "class HMM_bw:\n", 280 | " def __init__(self, A, B, pi):\n", 281 | " self.A = A\n", 282 | " self.B = B\n", 283 | " self.pi = pi\n", 284 | " \n", 285 | " def backward(self, Q):\n", 286 | " T = len(Q) # 观测序列长度,时刻T\n", 287 | " N = len(self.A) # 状态数\n", 288 | " beta = np.zeros((T, N))\n", 289 | " \n", 290 | " for t in range(T-1, -1, -1):\n", 291 | " for n in range(N):\n", 292 | " if t == T - 1:\n", 293 | " beta[t][n] = 1\n", 294 | " print(\"beta[{}][{}] = {:.2f}\".format(t+1, n, beta[t][n]))\n", 295 | " else:\n", 296 | " beta[t][n] = sum(self.A[n][j] * self.B[j][Q[t+1]] * beta[t+1][j] for j in range(N))\n", 297 | " print(\"beta[{}][{}] = {:.5f}\".format(t+1, n, beta[t][n]))\n", 298 | " \n", 299 | " p = sum(self.pi[i] * self.B[i][Q[0]] * beta[0][i] for i in range(N))\n", 300 | " #print(p)\n", 301 | " return p" 302 | ], 303 | "execution_count": 0, 304 | "outputs": [] 305 | }, 306 | { 307 | "cell_type": "code", 308 | 
"metadata": { 309 | "id": "rIyCzvnbBD8K", 310 | "colab_type": "code", 311 | "colab": { 312 | "base_uri": "https://localhost:8080/", 313 | "height": 197 314 | }, 315 | "outputId": "48682810-a2c9-47ca-d46c-2cb3eeb545b5" 316 | }, 317 | "source": [ 318 | "m = HMM_bw(A, B, pi)\n", 319 | "m.backward(Q)" 320 | ], 321 | "execution_count": 77, 322 | "outputs": [ 323 | { 324 | "output_type": "stream", 325 | "text": [ 326 | "beta[3][0] = 1.00\n", 327 | "beta[3][1] = 1.00\n", 328 | "beta[3][2] = 1.00\n", 329 | "beta[2][0] = 0.54000\n", 330 | "beta[2][1] = 0.49000\n", 331 | "beta[2][2] = 0.57000\n", 332 | "beta[1][0] = 0.24510\n", 333 | "beta[1][1] = 0.26220\n", 334 | "beta[1][2] = 0.22770\n" 335 | ], 336 | "name": "stdout" 337 | }, 338 | { 339 | "output_type": "execute_result", 340 | "data": { 341 | "text/plain": [ 342 | "0.130218" 343 | ] 344 | }, 345 | "metadata": { 346 | "tags": [] 347 | }, 348 | "execution_count": 77 349 | } 350 | ] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "metadata": { 355 | "id": "uc5CfmwRGdb6", 356 | "colab_type": "text" 357 | }, 358 | "source": [ 359 | "# 学习问题" 360 | ] 361 | }, 362 | { 363 | "cell_type": "markdown", 364 | "metadata": { 365 | "id": "qTYQ6fO8GfMI", 366 | "colab_type": "text" 367 | }, 368 | "source": [ 369 | "HMM的学习,根据训练数据是否包括观测序列和对应的状态序列还是只有观测序列,可以分别为**有监督学习**和**无监督学习**来实现。" 370 | ] 371 | }, 372 | { 373 | "cell_type": "markdown", 374 | "metadata": { 375 | "id": "qTRkuN4GIfnW", 376 | "colab_type": "text" 377 | }, 378 | "source": [ 379 | "### 监督学习 \n", 380 | "\n", 381 | "假设已给训练数据包含$S$个长度相同的观测序列和对应的状态序列${(O_{1}, I_{1}), (O_{2}, I_{2}),..., (O_{S}, I_{S})}$. 
那么可以利用**极大似然估计**法来估计HMM的参数。" 382 | ] 383 | }, 384 | { 385 | "cell_type": "markdown", 386 | "metadata": { 387 | "id": "neoLTatUJgks", 388 | "colab_type": "text" 389 | }, 390 | "source": [ 391 | "### 无监督学习 \n", 392 | "\n", 393 | "假设已给训练数据只包含$S$个长度为$T$的观测序列${O_{1}, O_{2},..., O_{S}}$, 而没有对应的状态序列, 目标是学习HMM $\\lambda = (A, B, \\pi)$ 的参数。 我们将观测序列数据看作观测数据$Q$, 状态序列数据看作不可观测的隐数据$I$, 那么HMM则是一个含有隐变量的概率模型: \n", 394 | "\n", 395 | "$P(O|\\lambda) = \\sum_{I}P(O|I, \\lambda)P(I|\\lambda)$ \n", 396 | "\n", 397 | "他的参数可以由EM算法来学习。" 398 | ] 399 | }, 400 | { 401 | "cell_type": "code", 402 | "metadata": { 403 | "id": "9862N0kFBFBb", 404 | "colab_type": "code", 405 | "colab": {} 406 | }, 407 | "source": [ 408 | "" 409 | ], 410 | "execution_count": 0, 411 | "outputs": [] 412 | }, 413 | { 414 | "cell_type": "markdown", 415 | "metadata": { 416 | "id": "0lKwVVTPPpjt", 417 | "colab_type": "text" 418 | }, 419 | "source": [ 420 | "# 预测问题\n" 421 | ] 422 | }, 423 | { 424 | "cell_type": "markdown", 425 | "metadata": { 426 | "id": "796IZ5QqPrdS", 427 | "colab_type": "text" 428 | }, 429 | "source": [ 430 | "### 近似算法 \n", 431 | "近似算法的想法是, 在每个时刻$t$ 选择在该时刻最有可能出现的状态$i^{*}_{t}$,从而得到一个状态序列 $I^{*} = (i^{*}_{1}, i^{*}_{2}, ..., i^{*}_{T})$, 将他作为预测的结果。" 432 | ] 433 | }, 434 | { 435 | "cell_type": "markdown", 436 | "metadata": { 437 | "id": "G5XHAHOdQlKE", 438 | "colab_type": "text" 439 | }, 440 | "source": [ 441 | "### 维特比算法 \n" 442 | ] 443 | }, 444 | { 445 | "cell_type": "markdown", 446 | "metadata": { 447 | "id": "TlCSEMGhQnRM", 448 | "colab_type": "text" 449 | }, 450 | "source": [ 451 | "维特比算法实际是用动态规划,解HMM的预测问题,即用动态规划求概率最大路径。" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "metadata": { 457 | "id": "gSOX42cFSM2l", 458 | "colab_type": "code", 459 | "colab": {} 460 | }, 461 | "source": [ 462 | "class HMM_viterbi:\n", 463 | " def __init__(self, A, B, pi):\n", 464 | " self.A = A\n", 465 | " self.B = B\n", 466 | " self.pi = pi\n", 467 | " \n", 468 | " def viterbi(self, Q):\n", 469 | " T = len(Q) # 
观测序列长度,时刻T\n", 470 | " N = len(self.A) # 状态数\n", 471 | " sigma = np.zeros((T, N))\n", 472 | " delta = np.zeros((T, N))\n", 473 | " for t in range(T):\n", 474 | " for n in range(N):\n", 475 | " if t == 0:\n", 476 | " sigma[t][n] = self.pi[n] * self.B[n][Q[t]]\n", 477 | " delta[t][n] = 0\n", 478 | " print(\"sigmia[{}][{}] = {:.2f}\".format(t+1, n+1, sigma[t][n]))\n", 479 | " print(\"delta[{}][{}] = {}\".format(t+1, n+1, delta[t][n]))\n", 480 | " \n", 481 | " else:\n", 482 | " sigma[t][n] = np.max([sigma[t-1][j] * self.A[j][n] for j in range(N)]) * self.B[n][Q[t]]\n", 483 | " print(\"sigma[{}][{}] = {:.5f}\".format(t+1, n+1, sigma[t][n]))\n", 484 | " \n", 485 | " delta[t][n] = np.argmax([sigma[t-1][j] * self.A[j][n] for j in range(N)]) + 1\n", 486 | " print(\"delta[{}][{}] = {}\".format(t+1, n+1, delta[t][n]))\n", 487 | " \n", 488 | " P = np.max(sigma[T-1])\n", 489 | " print(P)\n", 490 | " pth = []\n", 491 | " for t in range(T-1, -1, -1):\n", 492 | " if t == T - 1:\n", 493 | " i_t = np.argmax(sigma[t])\n", 494 | " pth.append(i_t + 1)\n", 495 | " else:\n", 496 | " i_t = int(delta[t+1][i_t]) - 1\n", 497 | " pth.append(i_t + 1)\n", 498 | " \n", 499 | " return pth" 500 | ], 501 | "execution_count": 0, 502 | "outputs": [] 503 | }, 504 | { 505 | "cell_type": "markdown", 506 | "metadata": { 507 | "id": "LDEr3xz9SNSE", 508 | "colab_type": "text" 509 | }, 510 | "source": [ 511 | "#### 例 10.3" 512 | ] 513 | }, 514 | { 515 | "cell_type": "markdown", 516 | "metadata": { 517 | "id": "oFJgnsvySRNF", 518 | "colab_type": "text" 519 | }, 520 | "source": [ 521 | "$A=\\begin{bmatrix}\n", 522 | " 0.5& 0.2& 0.3\\\\ \n", 523 | " 0.3& 0.5& 0.2\\\\ \n", 524 | " 0.2& 0.3& 0.5\n", 525 | "\\end{bmatrix}, \n", 526 | "B=\\begin{bmatrix}\n", 527 | " 0.5& 0.5\\\\ \n", 528 | " 0.4& 0.6\\\\ \n", 529 | " 0.7& 0.3\n", 530 | "\\end{bmatrix},\n", 531 | "\\pi=\\begin{bmatrix}\n", 532 | " 0.2\\\\ \n", 533 | " 0.4\\\\ \n", 534 | " 0.4\n", 535 | "\\end{bmatrix}$ \n", 536 | "\n", 537 | "已知观测序列$O=(红, 白, 
红)$,试求最优状态序列,即最优路径 $I^{*}=(i^{*}_{1}, i^{*}_{2}, i^{*}_{3})$." 538 | ] 539 | }, 540 | { 541 | "cell_type": "code", 542 | "metadata": { 543 | "id": "MXWapkQESxlq", 544 | "colab_type": "code", 545 | "colab": {} 546 | }, 547 | "source": [ 548 | "A = [[0.5, 0.2, 0.3], [0.3, 0.5, 0.2],[0.2, 0.3, 0.5]]\n", 549 | "B = [[0.5, 0.5], [0.4, 0.6], [0.7, 0.3]]\n", 550 | "pi = [0.2, 0.4, 0.4]\n", 551 | "Q = [0,1,0]" 552 | ], 553 | "execution_count": 0, 554 | "outputs": [] 555 | }, 556 | { 557 | "cell_type": "code", 558 | "metadata": { 559 | "id": "a55NPE7AS-Eu", 560 | "colab_type": "code", 561 | "colab": { 562 | "base_uri": "https://localhost:8080/", 563 | "height": 377 564 | }, 565 | "outputId": "5f86553f-345a-466d-fc8a-755b65a1d167" 566 | }, 567 | "source": [ 568 | "m = HMM_viterbi(A, B, pi)\n", 569 | "m.viterbi(Q)" 570 | ], 571 | "execution_count": 135, 572 | "outputs": [ 573 | { 574 | "output_type": "stream", 575 | "text": [ 576 | "sigmia[1][1] = 0.10\n", 577 | "delta[1][1] = 0.0\n", 578 | "sigmia[1][2] = 0.16\n", 579 | "delta[1][2] = 0.0\n", 580 | "sigmia[1][3] = 0.28\n", 581 | "delta[1][3] = 0.0\n", 582 | "sigma[2][1] = 0.02800\n", 583 | "delta[2][1] = 3.0\n", 584 | "sigma[2][2] = 0.05040\n", 585 | "delta[2][2] = 3.0\n", 586 | "sigma[2][3] = 0.04200\n", 587 | "delta[2][3] = 3.0\n", 588 | "sigma[3][1] = 0.00756\n", 589 | "delta[3][1] = 2.0\n", 590 | "sigma[3][2] = 0.01008\n", 591 | "delta[3][2] = 2.0\n", 592 | "sigma[3][3] = 0.01470\n", 593 | "delta[3][3] = 3.0\n", 594 | "0.014699999999999998\n" 595 | ], 596 | "name": "stdout" 597 | }, 598 | { 599 | "output_type": "execute_result", 600 | "data": { 601 | "text/plain": [ 602 | "[3, 3, 3]" 603 | ] 604 | }, 605 | "metadata": { 606 | "tags": [] 607 | }, 608 | "execution_count": 135 609 | } 610 | ] 611 | }, 612 | { 613 | "cell_type": "code", 614 | "metadata": { 615 | "id": "GKPLvAFzXuva", 616 | "colab_type": "code", 617 | "colab": {} 618 | }, 619 | "source": [ 620 | "" 621 | ], 622 | "execution_count": 0, 623 | "outputs": [] 624 
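The `HMM_viterbi` run above can be cross-checked with a compact vectorized rework (names are my own); on Example 10.3 it recovers the same maximal path probability P* = 0.0147 and the same optimal state sequence I* = (3, 3, 3).

```python
import numpy as np

A = np.array([[0.5, 0.2, 0.3], [0.3, 0.5, 0.2], [0.2, 0.3, 0.5]])
B = np.array([[0.5, 0.5], [0.4, 0.6], [0.7, 0.3]])
pi = np.array([0.2, 0.4, 0.4])
O = [0, 1, 0]  # red, white, red

def viterbi(A, B, pi, O):
    T, N = len(O), len(pi)
    delta = pi * B[:, O[0]]                     # best path prob ending in each state
    psi = np.zeros((T, N), dtype=int)           # backpointers
    for t in range(1, T):
        trans = delta[:, None] * A              # delta_{t-1}(i) * a_{ij}
        psi[t] = trans.argmax(axis=0)           # best predecessor of each state
        delta = trans.max(axis=0) * B[:, O[t]]
    path = [int(delta.argmax())]                # best final state
    for t in range(T - 1, 0, -1):
        path.insert(0, int(psi[t][path[0]]))    # backtrack
    return delta.max(), [i + 1 for i in path]   # 1-based state labels

p_star, path = viterbi(A, B, pi, O)
print(path)  # [3, 3, 3]
```

The intermediate `delta` rows match the printed sigma table above ([0.10, 0.16, 0.28], [0.028, 0.0504, 0.042], [0.00756, 0.01008, 0.0147]).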
| } 625 | ] 626 | } -------------------------------------------------------------------------------- /第11章 条件随机场/CRF.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "CRF.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | } 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "metadata": { 20 | "id": "fo3mvewZFRAt", 21 | "colab_type": "text" 22 | }, 23 | "source": [ 24 | "# CRF, Conditional Random Field" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": { 30 | "id": "nXwJKCfSFcnL", 31 | "colab_type": "text" 32 | }, 33 | "source": [ 34 | "1.概率无向图模型是由无向图表示的联合概率分布。无向图上的结点之间的连接关系表示了联合分布的随机变量集合之间的条件独立性,即马尔可夫性。因此,概率无向图模型也称为马尔可夫随机场。\n", 35 | "\n", 36 | "概率无向图模型或马尔可夫随机场的联合概率分布可以分解为无向图最大团上的正值函数的乘积的形式。\n", 37 | "\n", 38 | "2.条件随机场是给定输入随机变量X条件下,输出随机变量Y的条件概率分布模型, 其形式为参数化的对数线性模型。条件随机场的最大特点是假设输出变量之间的联合概率分布构成概率无向图模型,即马尔可夫随机场。条件随机场是判别模型。\n", 39 | "\n", 40 | "3.线性链条件随机场是定义在观测序列与标记序列上的条件随机场。线性链条件随机场一般表示为给定观测序列条件下的标记序列的条件概率分布,由参数化的对数线性模型表示。模型包含特征及相应的权值,特征是定义在线性链的边与结点上的。线性链条件随机场的数学表达式是 \n", 41 | "\n", 42 | "$P(y|x)=\\frac{1}{Z(x)}exp(\\sum_{i,k} \\lambda_{k}t_{k}(y_{i-1}, y_{i}, x, i) + \\sum_{i,l}\\mu_{l}S_{l}(y_{i}, x, i))$ \n", 43 | "\n", 44 | "其中, \n", 45 | "\n", 46 | "$Z(x)=\\sum_{y}exp(\\sum_{i,k}\\lambda_{k}t_{k}(y_{i-1}, y_{i}, x, i) + \\sum_{i,l}\\mu_{l}S_{l}(y_{i}, x, i))$\n", 47 | "\n", 48 | "\n", 49 | "4.线性链条件随机场的概率计算通常利用前向-后向算法。\n", 50 | "\n", 51 | "5.条件随机场的学习方法通常是极大似然估计方法或正则化的极大似然估计,即在给定训练数据下,通过极大化训练数据的对数似然函数以估计模型参数。具体的算法有改进的迭代尺度算法、梯度下降法、拟牛顿法等。\n", 52 | "\n", 53 | "6.线性链条件随机场的一个重要应用是标注。维特比算法是给定观测序列求条件概率最大的标记序列的方法。" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": { 59 | "id": "SL_XKnt7H4Q2", 60 | "colab_type": "text" 61 | }, 62 | "source": [ 63 | "#### 例 11.1" 64 | ] 
65 | }, 66 | { 67 | "cell_type": "code", 68 | "metadata": { 69 | "id": "7kx_vjWOFGSd", 70 | "colab_type": "code", 71 | "colab": {} 72 | }, 73 | "source": [ 74 | "import numpy as np" 75 | ], 76 | "execution_count": 0, 77 | "outputs": [] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "metadata": { 82 | "id": "RXUzhi7VHy2B", 83 | "colab_type": "code", 84 | "colab": { 85 | "base_uri": "https://localhost:8080/", 86 | "height": 53 87 | }, 88 | "outputId": "19c4e0bf-f4e7-4533-8d25-d8b2a5822d31" 89 | }, 90 | "source": [ 91 | "#这里定义T为转移矩阵列代表前一个y(ij)代表由状态i转到状态j的概率,Tx矩阵x对应于时间序列\n", 92 | "#这里将书上的转移特征转换为如下以时间轴为区别的三个多维列表,维度为输出的维度\n", 93 | "T1 = [[0.6, 1], [1, 0]]\n", 94 | "T2 = [[0, 1], [1, 0.2]]\n", 95 | "#将书上的状态特征同样转换成列表,第一个是为y1的未规划概率,第二个为y2的未规划概率\n", 96 | "S0 = [1, 0.5]\n", 97 | "S1 = [0.8, 0.5]\n", 98 | "S2 = [0.8, 0.5]\n", 99 | "Y = [1, 2, 2] #即书上例一需要计算的非规划条件概率的标记序列\n", 100 | "Y = np.array(Y) - 1 #这里为了将数与索引相对应即从零开始\n", 101 | "P = np.exp(S0[Y[0]])\n", 102 | "for i in range(1, len(Y)):\n", 103 | " P *= np.exp((eval('S%d' % i)[Y[i]]) + eval('T%d' % i)[Y[i - 1]][Y[i]])\n", 104 | "print(P)\n", 105 | "print(np.exp(3.2))" 106 | ], 107 | "execution_count": 6, 108 | "outputs": [ 109 | { 110 | "output_type": "stream", 111 | "text": [ 112 | "24.532530197109345\n", 113 | "24.532530197109352\n" 114 | ], 115 | "name": "stdout" 116 | } 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": { 122 | "id": "7lUipRWFMVB7", 123 | "colab_type": "text" 124 | }, 125 | "source": [ 126 | "#### 例 11.2" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "metadata": { 132 | "id": "TG6SEYCbMXty", 133 | "colab_type": "code", 134 | "colab": { 135 | "base_uri": "https://localhost:8080/", 136 | "height": 35 137 | }, 138 | "outputId": "b54d58b8-60b1-4dac-b03e-8816458c5ae2" 139 | }, 140 | "source": [ 141 | "#这里根据例11.2的启发整合为一个矩阵\n", 142 | "F0 = S0\n", 143 | "F1 = T1 + np.array(S1 * len(T1)).reshape(np.asarray(T1).shape)\n", 144 | "F2 = T2 + np.array(S2 * 
len(T2)).reshape(np.asarray(T2).shape)\n", 145 | "Y = [1, 2, 2] #即书上例一需要计算的非规划条件概率的标记序列\n", 146 | "Y = np.array(Y) - 1\n", 147 | "\n", 148 | "P = np.exp(F0[Y[0]])\n", 149 | "Sum = P\n", 150 | "for i in range(1, len(Y)):\n", 151 | " PIter = np.exp((eval('F%d' % i)[Y[i - 1]][Y[i]]))\n", 152 | " P *= PIter\n", 153 | " Sum += PIter\n", 154 | "print('非规范化概率', P)" 155 | ], 156 | "execution_count": 14, 157 | "outputs": [ 158 | { 159 | "output_type": "stream", 160 | "text": [ 161 | "非规范化概率 24.532530197109345\n" 162 | ], 163 | "name": "stdout" 164 | } 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "metadata": { 170 | "id": "fW5RZz89NPsD", 171 | "colab_type": "text" 172 | }, 173 | "source": [ 174 | "#### Reference: https://nbviewer.jupyter.org/github/fengdu78/lihang-code/blob/master/%E7%AC%AC11%E7%AB%A0%20%E6%9D%A1%E4%BB%B6%E9%9A%8F%E6%9C%BA%E5%9C%BA/11.CRF.ipynb" 175 | ] 176 | }, 177 | { 178 | "cell_type": "markdown", 179 | "metadata": { 180 | "id": "P6FeGVCCM7zC", 181 | "colab_type": "text" 182 | }, 183 | "source": [ 184 | "### 其实,我还是没搞懂CRF,没有在具体的项目中使用。PGM本身就是一个很大的topic,就这简简单单的一章无法全部解释。" 185 | ] 186 | } 187 | ] 188 | } -------------------------------------------------------------------------------- /第13章 无监督学习概论/Introduction_to_Unsupervised_Learning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "display_name": "Python 3", 7 | "language": "python", 8 | "name": "python3" 9 | }, 10 | "language_info": { 11 | "codemirror_mode": { 12 | "name": "ipython", 13 | "version": 3 14 | }, 15 | "file_extension": ".py", 16 | "mimetype": "text/x-python", 17 | "name": "python", 18 | "nbconvert_exporter": "python", 19 | "pygments_lexer": "ipython3", 20 | "version": "3.6.4" 21 | }, 22 | "colab": { 23 | "name": "Introduction_to_Unsupervised_Learning.ipynb", 24 | "version": "0.3.2", 25 | "provenance": [], 26 | "collapsed_sections": [] 27 | } 28 | }, 
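The `eval`-based indexing in the Example 11.1 cell above works but is easy to misread; the sketch below restates the same computation with explicit per-step score arrays (the `T`/`S` layout is my reading of the notebook's `T1`/`T2` and `S0`/`S1`/`S2`), and it reproduces the same unnormalized conditional probability exp(3.2) ≈ 24.5325.

```python
import numpy as np

# Per-step feature scores from Example 11.1 (layout assumed from the notebook):
#   T[t-1][i][j]: total transition-feature score for y_{t-1} = i -> y_t = j
#   S[t][i]:      total state-feature score for y_t = i
T = [np.array([[0.6, 1.0], [1.0, 0.0]]),
     np.array([[0.0, 1.0], [1.0, 0.2]])]
S = [np.array([1.0, 0.5]), np.array([0.8, 0.5]), np.array([0.8, 0.5])]

y = [0, 1, 1]  # label sequence (1, 2, 2), zero-based

score = S[0][y[0]]                    # state score at the first position
for t in range(1, len(y)):
    score += T[t - 1][y[t - 1]][y[t]] + S[t][y[t]]

unnormalized = np.exp(score)          # P(y|x) up to the factor 1/Z(x)
print(round(float(score), 1))  # 3.2
```

The total score is 1.0 + (1.0 + 0.5) + (0.2 + 0.5) = 3.2, so `unnormalized` equals the notebook's `np.exp(3.2)` value 24.5325….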
29 | "cells": [ 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "id": "qipcvQWGoxDB", 34 | "colab_type": "text" 35 | }, 36 | "source": [ 37 | "# 第13章 无监督学习概论" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": { 43 | "id": "j5m7ppYdoxDP", 44 | "colab_type": "text" 45 | }, 46 | "source": [ 47 | "1.机器学习或统计学习一般包括监督学习、无监督学习、强化学习。\n", 48 | "\n", 49 | "无监督学习是指从无标注数据中学习模型的机器学习问题。无标注数据是自然得到的数据,模型表示数据的类别、转换或概率。无监督学习的本质是学习数据中的统计规律或潜在结构,主要包括聚类、降维、概率估计。\n", 50 | "\n", 51 | "2.无监督学习可以用于对已有数据的分析,也可以用于对未来数据的预测。学习得到的模型有函数$z=g(x)$,条件概率分布$P(z|x)$,或条件概率分布$P(x|z)$。\n", 52 | "\n", 53 | "无监督学习的基本想法是对给定数据(矩阵数据)进行某种“压缩”,从而找到数据的潜在结构,假定损失最小的压缩得到的结果就是最本质的结构。可以考虑发掘数据的纵向结构,对应聚类。也可以考虑发掘数据的横向结构,对应降维。还可以同时考虑发掘数据的纵向与横向结构,对应概率模型估计。\n", 54 | "\n", 55 | "3.聚类是将样本集合中相似的样本(实例)分配到相同的类,不相似的样本分配到不同的类。聚类分硬聚类和软聚类。聚类方法有层次聚类和$k$均值聚类。\n", 56 | "\n", 57 | "4.降维是将样本集合中的样本(实例)从高维空间转换到低维空间。假设样本原本存在于低维空间,或近似地存在于低维空间,通过降维则可以更好地表示样本数据的结构,即更好地表示样本之间的关系。降维有线性降维和非线性降维,降维方法有主成分分析。\n", 58 | "\n", 59 | "5.概率模型估计假设训练数据由一个概率模型生成,同时利用训练数据学习概率模型的结构和参数。概率模型包括混合模型、概率图模型等。概率图模型又包括有向图模型和无向图模型。\n", 60 | "\n", 61 | "6.话题分析是文本分析的一种技术。给定一个文本集合,话题分析旨在发现文本集合中每个文本的话题,而话题由单词的集合表示。话题分析方法有潜在语义分析、概率潜在语义分析和潜在狄利克雷分配。\n", 62 | "\n", 63 | "7.图分析的目的是发掘隐藏在图中的统计规律或潜在结构。链接分析是图分析的一种,主要是发现有向图中的重要结点,包括 **PageRank**算法。" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": { 69 | "id": "5Y2Wgt2Lo20z", 70 | "colab_type": "text" 71 | }, 72 | "source": [ 73 | "#### Reference: https://github.com/fengdu78/lihang-code/blob/master/%E7%AC%AC13%E7%AB%A0%20%E6%97%A0%E7%9B%91%E7%9D%A3%E5%AD%A6%E4%B9%A0%E6%A6%82%E8%AE%BA/13.Introduction_to_Unsupervised_Learning.ipynb" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "metadata": { 79 | "id": "uScYNYproxDR", 80 | "colab_type": "code", 81 | "colab": {} 82 | }, 83 | "source": [ 84 | "" 85 | ], 86 | "execution_count": 0, 87 | "outputs": [] 88 | } 89 | ] 90 | } -------------------------------------------------------------------------------- /第17章 潜在语义分析/LSA.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "LSA.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | } 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "metadata": { 20 | "id": "LOWANK49Pi27", 21 | "colab_type": "text" 22 | }, 23 | "source": [ 24 | "# 潜在语义分析(Latent semantic analysis, LSA)" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": { 30 | "id": "0WD7XVWRPkX1", 31 | "colab_type": "text" 32 | }, 33 | "source": [ 34 | "**LSA** 是一种无监督学习方法,主要用于文本的话题分析,其特点是通过矩阵分解发现文本与单词之间的基于话题的语义关系。也称为潜在语义索引(Latent semantic indexing, LSI)。\n", 35 | "\n", 36 | "LSA 使用的是非概率的话题分析模型。将文本集合表示为**单词-文本矩阵**,对单词-文本矩阵进行**奇异值分解**,从而得到话题向量空间,以及文本在话题向量空间的表示。\n", 37 | "\n", 38 | "**非负矩阵分解**(non-negative matrix factorization, NMF)是另一种矩阵的因子分解方法,其特点是分解的矩阵非负。" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": { 44 | "id": "P1sWKgTGQ7r-", 45 | "colab_type": "text" 46 | }, 47 | "source": [ 48 | "## 单词向量空间 \n", 49 | "word vector space model" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": { 55 | "id": "CqXj1777RM8y", 56 | "colab_type": "text" 57 | }, 58 | "source": [ 59 | "给定一个文本,用一个向量表示该文本的”语义“, 向量的**每一维对应一个单词**,其数值为该单词在该文本中出现的频数或权值;基本假设是文本中所有单词的出现情况表示了文本的语义内容,文本集合中的每个文本都表示为一个向量,存在于一个向量空间;向量空间的度量,如内积或标准化**内积**表示文本之间的**相似度**。" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": { 65 | "id": "3HVCXf6CSmTT", 66 | "colab_type": "text" 67 | }, 68 | "source": [ 69 | "给定一个含有$n$个文本的集合$D=({d_{1}, d_{2},...,d_{n}})$,以及在所有文本中出现的$m$个单词的集合$W=({w_{1},w_{2},...,w_{m}})$. 
将单词在文本中出现的数据用一个单词-文本矩阵(word-document matrix)表示,记作$X$:\n", 70 | "\n", 71 | "$\n", 72 | "X = \begin{bmatrix}\n", 73 | "x_{11} & x_{12} & \cdots & x_{1n} \\ \n", 74 | "x_{21} & x_{22} & \cdots & x_{2n} \\ \n", 75 | "\vdots & \vdots & \ddots & \vdots \\ \n", 76 | "x_{m1} & x_{m2} & \cdots & x_{mn} \n", 77 | "\end{bmatrix}\n", 78 | "$\n", 79 | "\n", 80 | "这是一个$m*n$矩阵,元素$x_{ij}$表示单词$w_{i}$在文本$d_{j}$中出现的频数或权值。由于单词的种类很多,而每个文本中出现单词的种类通常较少,所以单词-文本矩阵是一个稀疏矩阵。\n", 81 | "\n" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "metadata": { 87 | "id": "K2ncB3cde1Ab", 88 | "colab_type": "text" 89 | }, 90 | "source": [ 91 | "权值通常用单词**频率-逆文本率**(term frequency-inverse document frequency, TF-IDF)表示:\n", 92 | "\n", 93 | "$TF-IDF(t, d) = TF(t, d) * IDF(t)$, \n", 94 | "\n", 95 | "其中,$TF(t,d)$为单词$t$在文本$d$中出现的频率,$IDF(t)$是逆文本率,用来衡量单词$t$对表示语义所起的重要性, \n", 96 | "\n", 97 | "$IDF(t) = log(\frac{len(D)}{len(t \in D) + 1})$." 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": { 103 | "id": "bpu7MycIgu65", 104 | "colab_type": "text" 105 | }, 106 | "source": [ 107 | "单词向量空间模型的优点是**模型简单,计算效率高**,因为单词向量通常是稀疏的。但单词向量空间模型也有一定的局限性,体现在内积相似度未必能够准确表达两个文本的语义相似度上,因为自然语言的单词具有一词多义性(polysemy)及多词一义性(synonymy)。" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": { 113 | "id": "ns5wncZohn-z", 114 | "colab_type": "text" 115 | }, 116 | "source": [ 117 | "## 话题向量空间" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": { 123 | "id": "mmZpPHIdhrAy", 124 | "colab_type": "text" 125 | }, 126 | "source": [ 127 | "**1. 话题向量空间**:\n", 128 | "\n", 129 | "给定一个含有$n$个文本的集合$D=({d_{1}, d_{2},...,d_{n}})$,以及在所有文本中出现的$m$个单词的集合$W=({w_{1},w_{2},...,w_{m}})$.
可以获得其单词-文本矩阵$X$: \n", 130 | "\n", 131 | "$\n", 132 | "X = \begin{bmatrix}\n", 133 | "x_{11} & x_{12} & \cdots & x_{1n} \\ \n", 134 | "x_{21} & x_{22} & \cdots & x_{2n} \\ \n", 135 | "\vdots & \vdots & \ddots & \vdots \\ \n", 136 | "x_{m1} & x_{m2} & \cdots & x_{mn} \n", 137 | "\end{bmatrix}\n", 138 | "$\n", 139 | "\n", 140 | "\n", 141 | "假设所有文本共含有$k$个话题。假设每个话题由一个定义在单词集合$W$上的$m$维向量表示,称为话题向量,即: \n", 142 | "$t_{l} = \begin{bmatrix}\n", 143 | "t_{1l}\\ \n", 144 | "t_{2l}\\ \n", 145 | "\vdots \\ \n", 146 | "t_{ml}\end{bmatrix}, l=1,2,...,k$\n", 147 | "\n", 148 | "其中$t_{il}$为单词$w_{i}$在话题$t_{l}$的权值,$i=1,2,...,m$, 权值越大,该单词在该话题中的重要程度就越高。这$k$个话题向量 $t_{1},t_{2},...,t_{k}$张成一个话题向量空间(topic vector space), 维数为$k$。**话题向量空间是单词向量空间的一个子空间**。\n", 149 | "\n", 150 | "话题向量空间$T$: \n", 151 | "\n", 152 | "\n", 153 | "$\n", 154 | "T = \begin{bmatrix}\n", 155 | "t_{11} & t_{12} & \cdots & t_{1k} \\ \n", 156 | "t_{21} & t_{22} & \cdots & t_{2k} \\ \n", 157 | "\vdots & \vdots & \ddots & \vdots \\ \n", 158 | "t_{m1} & t_{m2} & \cdots & t_{mk} \n", 159 | "\end{bmatrix}\n", 160 | "$ \n", 161 | "\n", 162 | "矩阵$T$称为**单词-话题矩阵**。 $T = [t_{1}, t_{2}, ..., t_{k}]$" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": { 168 | "id": "Oc1c3JcKlTBD", 169 | "colab_type": "text" 170 | }, 171 | "source": [ 172 | "**2.
文本在话题向量空间中的表示**:\n", 173 | "\n", 174 | "考虑文本集合$D$的文本$d_{j}$, 在单词向量空间中由一个向量$x_{j}$表示,将$x_{j}$投影到话题向量空间$T$中,得到话题向量空间的一个向量$y_{j}$, $y_{j}$是一个$k$维向量: \n", 175 | "\n", 176 | "$y_{j} = \begin{bmatrix}\n", 177 | "y_{1j}\\ \n", 178 | "y_{2j}\\ \n", 179 | "\vdots \\ \n", 180 | "y_{kj}\end{bmatrix}, j=1,2,...,n$ \n", 181 | "\n", 182 | "其中,$y_{lj}$是文本$d_{j}$在话题$t_{l}$中的权值, $l = 1,2,..., k$, 权值越大,该话题在该文本中的重要程度就越高。 \n", 183 | "\n", 184 | "矩阵$Y$ 表示话题在文本中出现的情况,称为话题-文本矩阵(topic-document matrix),记作: \n", 185 | "\n", 186 | "$\n", 187 | "Y = \begin{bmatrix}\n", 188 | "y_{11} & y_{12} & \cdots & y_{1n} \\ \n", 189 | "y_{21} & y_{22} & \cdots & y_{2n} \\ \n", 190 | "\vdots & \vdots & \ddots & \vdots \\ \n", 191 | "y_{k1} & y_{k2} & \cdots & y_{kn} \n", 192 | "\end{bmatrix}\n", 193 | "$ \n", 194 | "\n", 195 | "也可写成: $Y = [y_{1}, y_{2}, ..., y_{n}]$" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": { 201 | "id": "YcU3xwYindTo", 202 | "colab_type": "text" 203 | }, 204 | "source": [ 205 | "**3. 从单词向量空间到话题向量空间的线性变换**: \n", 206 | "\n", 207 | "如此,单词向量空间的文本向量$x_{j}$可以通过它在话题空间中的向量$y_{j}$近似表示,具体地由$k$个话题向量以$y_{j}$为系数的线性组合近似表示: \n", 208 | "\n", 209 | "$x_{j} = y_{1j}t_{1} + y_{2j}t_{2} + ...
+ y_{kj}t_{k}, j = 1,2,..., n$ \n", 210 | "\n", 211 | "所以,单词-文本矩阵$X$可以近似地表示为单词-话题矩阵$T$与话题-文本矩阵$Y$的乘积形式。\n", 212 | "\n", 213 | "$X \approx TY$ \n", 214 | "\n", 215 | "直观上,潜在语义分析是将单词向量空间的表示通过线性变换转换为在话题向量空间中的表示。这个线性变换由矩阵因子分解式的形式体现。" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": { 221 | "id": "Cu4JekfXFMqs", 222 | "colab_type": "text" 223 | }, 224 | "source": [ 225 | "### 潜在语义分析算法 " 226 | ] 227 | }, 228 | { 229 | "cell_type": "markdown", 230 | "metadata": { 231 | "id": "0awNXCy1Gw0K", 232 | "colab_type": "text" 233 | }, 234 | "source": [ 235 | "潜在语义分析利用矩阵奇异值分解,具体地,对单词-文本矩阵进行奇异值分解,将其左矩阵作为话题向量空间,将其对角矩阵与右矩阵的乘积作为文本在话题向量空间的表示。" 236 | ] 237 | }, 238 | { 239 | "cell_type": "markdown", 240 | "metadata": { 241 | "id": "otq3HMu5HVoK", 242 | "colab_type": "text" 243 | }, 244 | "source": [ 245 | "给定一个含有$n$个文本的集合$D=({d_{1}, d_{2},...,d_{n}})$,以及在所有文本中出现的$m$个单词的集合$W=({w_{1},w_{2},...,w_{m}})$. 可以获得其单词-文本矩阵$X$: \n", 246 | "$\n", 247 | "X = \begin{bmatrix}\n", 248 | "x_{11} & x_{12} & \cdots & x_{1n} \\ \n", 249 | "x_{21} & x_{22} & \cdots & x_{2n} \\ \n", 250 | "\vdots & \vdots & \ddots & \vdots \\ \n", 251 | "x_{m1} & x_{m2} & \cdots & x_{mn} \n", 252 | "\end{bmatrix}\n", 253 | "$\n", 254 | "\n", 255 | "\n", 256 | "\n", 257 | "\n" 258 | ] 259 | }, 260 | { 261 | "cell_type": "markdown", 262 | "metadata": { 263 | "id": "mwNGRDgrHmmV", 264 | "colab_type": "text" 265 | }, 266 | "source": [ 267 | "**截断奇异值分解**:\n", 268 | "\n", 269 | "潜在语义分析根据确定的话题数$k$对单词-文本矩阵$X$进行截断奇异值分解: \n", 270 | "\n", 271 | "$\n", 272 | "X \approx U_{k}\Sigma _{k}V_{k}^{T} = \begin{bmatrix}\n", 273 | "u_{1} & u_{2} & \cdots & u_{k}\n", 274 | "\end{bmatrix}\begin{bmatrix}\n", 275 | "\sigma_{1} & 0& 0& 0\\ \n", 276 | " 0& \sigma_{2}& 0& 0\\ \n", 277 | " 0& 0& \ddots & 0\\ \n", 278 | " 0& 0& 0& \sigma_{k}\n", 279 | "\end{bmatrix}\begin{bmatrix}\n", 280 | "v_{1}^{T}\\ \n", 281 | "v_{2}^{T}\\ \n", 282 | "\vdots \\ \n", 283 | "v_{k}^{T}\end{bmatrix}\n", 284 | "$\n", 285 | "\n", 286 |
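上面的截断奇异值分解可以用 NumPy 直接验证。下面的草稿取本章后面用到的示例矩阵,令 $k=3$ 重构 $X$ 并查看近似误差(变量命名与写法仅为示意):

```python
import numpy as np

# 正文示例中的单词-文本矩阵 X(6 个单词,4 个文本)
X = np.array([[2, 0, 0, 0],
              [0, 2, 0, 0],
              [0, 0, 1, 0],
              [0, 0, 2, 3],
              [0, 0, 0, 1],
              [1, 2, 2, 1]], dtype=float)

# 完整奇异值分解
U, sigma, VT = np.linalg.svd(X, full_matrices=False)

# 截断:只保留前 k 个奇异值对应的分量
k = 3
U_k, sigma_k, VT_k = U[:, :k], sigma[:k], VT[:k, :]

# X ≈ U_k Σ_k V_k^T:U_k 是话题向量空间,Σ_k V_k^T 是文本在话题空间的表示
X_approx = U_k @ np.diag(sigma_k) @ VT_k

print(np.round(X_approx, 3))
print('重构误差:', np.linalg.norm(X - X_approx))
```

由于截断只丢弃了最小的奇异值 $\sigma_{4}$,重构误差 $||X - U_{k}\Sigma_{k}V_{k}^{T}||_{F}$ 恰好等于 $\sigma_{4}$。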
"矩阵$U_{k}$的每一个列向量 $u_{1}, u_{2},..., u_{k}$ 表示一个话题,称为**话题向量**。由这 $k$ 个话题向量张成一个子空间: \n", 287 | "\n", 288 | "$\n", 289 | "U_{k} = \begin{bmatrix}\n", 290 | "u_{1} & u_{2}& \cdots & u_{k}\n", 291 | "\end{bmatrix}\n", 292 | "$\n", 293 | "\n", 294 | "称为**话题向量空间**。 \n", 295 | "\n", 296 | "综上, 可以通过对单词-文本矩阵的奇异值分解进行潜在语义分析: \n", 297 | "\n", 298 | "$ X \approx U_{k} \Sigma_{k} V_{k}^{T} = U_{k}(\Sigma_{k}V_{k}^{T})$ \n", 299 | "\n", 300 | "得到话题空间 $U_{k}$ , 以及文本在话题空间的表示($\Sigma_{k}V_{k}^{T}$). " 301 | ] 302 | }, 303 | { 304 | "cell_type": "markdown", 305 | "metadata": { 306 | "id": "UTNHyq8mK8l5", 307 | "colab_type": "text" 308 | }, 309 | "source": [ 310 | "### 非负矩阵分解算法" 311 | ] 312 | }, 313 | { 314 | "cell_type": "markdown", 315 | "metadata": { 316 | "id": "RQvqaMDYK_jf", 317 | "colab_type": "text" 318 | }, 319 | "source": [ 320 | "非负矩阵分解也可以用于话题分析。对单词-文本矩阵进行非负矩阵分解,将**其左矩阵作为话题向量空间**,将其**右矩阵作为文本在话题向量空间的表示**。" 321 | ] 322 | }, 323 | { 324 | "cell_type": "markdown", 325 | "metadata": { 326 | "id": "ApM8tE3MLqpP", 327 | "colab_type": "text" 328 | }, 329 | "source": [ 330 | "#### 非负矩阵分解 " 331 | ] 332 | }, 333 | { 334 | "cell_type": "markdown", 335 | "metadata": { 336 | "id": "glMwmkiwLyIn", 337 | "colab_type": "text" 338 | }, 339 | "source": [ 340 | "若一个矩阵的所有元素非负,则该矩阵为非负矩阵。若$X$是非负矩阵,则: $X \geq 0$.
\n", 341 | "\n", 342 | "给定一个非负矩阵$X$, 找到两个非负矩阵$W \geq 0$ 和 $H \geq 0$, 使得: \n", 343 | "\n", 344 | "$ X \approx WH$\n", 345 | "\n", 346 | "即非负矩阵$X$分解为两个非负矩阵$W$和$H$的乘积形式,称为非负矩阵分解。因为$WH$与$X$完全相等很难实现,所以只要求近似相等。 \n", 347 | "\n", 348 | "假设非负矩阵$X$是$m\times n$矩阵,非负矩阵$W$和$H$分别为 $m\times k$ 矩阵和 $k\times n$ 矩阵。假设 $k < min(m, n)$, 即 $W$ 和 $H$ 均小于原矩阵 $X$, 所以非负矩阵分解是对原数据的压缩。\n", 349 | "\n", 350 | "称 $W$ 为基矩阵, $H$ 为系数矩阵。非负矩阵分解旨在用较少的基向量、系数向量来表示较大的数据矩阵。\n", 351 | "\n", 352 | "令 $W = \begin{bmatrix}\n", 353 | "w_{1} & w_{2}& \cdots& w_{k} \n", 354 | "\end{bmatrix}$\n", 355 | "为话题向量空间, $w_{1}, w_{2}, ..., w_{k}$ 表示文本集合的 $k$ 个话题, 令 $H = \begin{bmatrix}\n", 356 | "h_{1} & h_{2}& \cdots& h_{n} \n", 357 | "\end{bmatrix}$\n", 358 | "为文本在话题向量空间的表示, $h_{1}, h_{2},..., h_{n}$ 表示文本集合的 $n$ 个文本。" 359 | ] 360 | }, 361 | { 362 | "cell_type": "markdown", 363 | "metadata": { 364 | "id": "1DcvVSR0N_CF", 365 | "colab_type": "text" 366 | }, 367 | "source": [ 368 | "##### 算法" 369 | ] 370 | }, 371 | { 372 | "cell_type": "markdown", 373 | "metadata": { 374 | "id": "hvZyHT85O5qt", 375 | "colab_type": "text" 376 | }, 377 | "source": [ 378 | "非负矩阵分解可以形式化为最优化问题求解。可以利用平方损失或散度来作为损失函数。\n", 379 | "\n", 380 | "目标函数 $|| X - WH ||^{2}$ 关于 $W$ 和 $H$ 的最小化,满足约束条件 $W, H \geq 0$, 即: \n", 381 | "\n", 382 | "$\underset{W,H}{min} || X - WH ||^{2}$ \n", 383 | "\n", 384 | "\n", 385 | "$s.t.
W, H \geq 0$" 386 | ] 387 | }, 388 | { 389 | "cell_type": "markdown", 390 | "metadata": { 391 | "id": "-zIGS1AEQdWp", 392 | "colab_type": "text" 393 | }, 394 | "source": [ 395 | "乘法更新规则: \n", 396 | "\n", 397 | "\n", 398 | "$W_{il} \leftarrow W_{il}\frac{(XH^{T})_{il}}{(WHH^{T})_{il}}$ (17.33)\n", 399 | "\n", 400 | "\n", 401 | "$H_{lj} \leftarrow H_{lj}\frac{(W^{T}X)_{lj}}{(W^{T}WH)_{lj}}$ (17.34)\n", 402 | "\n", 403 | "\n", 404 | "选择初始矩阵 $W$ 和 $H$ 为非负矩阵,可以保证迭代过程及结果的矩阵 $W$ 和 $H$ 非负。" 405 | ] 406 | }, 407 | { 408 | "cell_type": "markdown", 409 | "metadata": { 410 | "id": "MeiA0REkRpRi", 411 | "colab_type": "text" 412 | }, 413 | "source": [ 414 | "**算法 17.1 (非负矩阵分解的迭代算法)**\n", 415 | "\n", 416 | "输入: 单词-文本矩阵 $X \geq 0$, 文本集合的话题个数 $k$, 最大迭代次数 $t$; \n", 417 | "输出: 话题矩阵 $W$, 文本表示矩阵 $H$。 \n", 418 | "\n", 419 | "**1)**. 初始化\n", 420 | "\n", 421 | "$W \geq 0$, 并对 $W$ 的每一列数据归一化; \n", 422 | "$H \geq 0$;\n", 423 | "\n", 424 | "**2)**. 迭代 \n", 425 | "\n", 426 | "对迭代次数由1到$t$执行下列步骤: \n", 427 | "a. 更新$W$的元素,对 $l$ 从1到 $k,i$从1到$m$按(17.33)更新 $W_{il}$; \n", 428 | "b.
更新$H$的元素,对 $l$ 从1到 $k,j$从1到$m$按(17.34)更新 $H_{lj}$; " 429 | ] 430 | }, 431 | { 432 | "cell_type": "markdown", 433 | "metadata": { 434 | "id": "rIw6a0HITg08", 435 | "colab_type": "text" 436 | }, 437 | "source": [ 438 | "### 图例 17.1" 439 | ] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "metadata": { 444 | "id": "0hPH9VEMPVGu", 445 | "colab_type": "code", 446 | "colab": {} 447 | }, 448 | "source": [ 449 | "import numpy as np\n", 450 | "from sklearn.decomposition import TruncatedSVD" 451 | ], 452 | "execution_count": 0, 453 | "outputs": [] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "metadata": { 458 | "id": "kjHirYzQWItl", 459 | "colab_type": "code", 460 | "outputId": "0e6f2615-6a0b-4c4f-e74c-559727519eab", 461 | "colab": { 462 | "base_uri": "https://localhost:8080/", 463 | "height": 125 464 | } 465 | }, 466 | "source": [ 467 | "X = [[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 2, 3], [0, 0, 0, 1], [1, 2, 2, 1]]\n", 468 | "X = np.asarray(X);X" 469 | ], 470 | "execution_count": 0, 471 | "outputs": [ 472 | { 473 | "output_type": "execute_result", 474 | "data": { 475 | "text/plain": [ 476 | "array([[2, 0, 0, 0],\n", 477 | " [0, 2, 0, 0],\n", 478 | " [0, 0, 1, 0],\n", 479 | " [0, 0, 2, 3],\n", 480 | " [0, 0, 0, 1],\n", 481 | " [1, 2, 2, 1]])" 482 | ] 483 | }, 484 | "metadata": { 485 | "tags": [] 486 | }, 487 | "execution_count": 2 488 | } 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "metadata": { 494 | "id": "I2yFnNJKWcPP", 495 | "colab_type": "code", 496 | "colab": {} 497 | }, 498 | "source": [ 499 | "# 奇异值分解\n", 500 | "U,sigma,VT=np.linalg.svd(X)" 501 | ], 502 | "execution_count": 0, 503 | "outputs": [] 504 | }, 505 | { 506 | "cell_type": "code", 507 | "metadata": { 508 | "id": "ollDH_QNXAdY", 509 | "colab_type": "code", 510 | "outputId": "8d599d82-4bae-4047-941c-8ca524142349", 511 | "colab": { 512 | "base_uri": "https://localhost:8080/", 513 | "height": 233 514 | } 515 | }, 516 | "source": [ 517 | "U" 518 | ], 519 | "execution_count": 0, 520 
| "outputs": [ 521 | { 522 | "output_type": "execute_result", 523 | "data": { 524 | "text/plain": [ 525 | "array([[-7.84368672e-02, 2.84423033e-01, 8.94427191e-01,\n", 526 | " 2.15138396e-01, -2.68931121e-02, -2.56794523e-01],\n", 527 | " [-1.56873734e-01, 5.68846066e-01, -4.47213595e-01,\n", 528 | " 4.30276793e-01, -5.37862243e-02, -5.13589047e-01],\n", 529 | " [-1.42622354e-01, -1.37930417e-02, 4.16333634e-17,\n", 530 | " -6.53519444e-01, 4.77828945e-01, -5.69263078e-01],\n", 531 | " [-7.28804669e-01, -5.53499910e-01, 3.33066907e-16,\n", 532 | " 1.56161345e-01, -2.92700697e-01, -2.28957508e-01],\n", 533 | " [-1.47853320e-01, -1.75304609e-01, 1.04083409e-16,\n", 534 | " 4.87733411e-01, 8.24315866e-01, 1.73283476e-01],\n", 535 | " [-6.29190197e-01, 5.08166890e-01, -4.44089210e-16,\n", 536 | " -2.81459486e-01, 5.37862243e-02, 5.13589047e-01]])" 537 | ] 538 | }, 539 | "metadata": { 540 | "tags": [] 541 | }, 542 | "execution_count": 8 543 | } 544 | ] 545 | }, 546 | { 547 | "cell_type": "code", 548 | "metadata": { 549 | "id": "lmxB_JViXFAF", 550 | "colab_type": "code", 551 | "outputId": "9cdcba5c-74f8-42e4-9a4b-fad1c06b83cd", 552 | "colab": { 553 | "base_uri": "https://localhost:8080/", 554 | "height": 35 555 | } 556 | }, 557 | "source": [ 558 | "sigma" 559 | ], 560 | "execution_count": 0, 561 | "outputs": [ 562 | { 563 | "output_type": "execute_result", 564 | "data": { 565 | "text/plain": [ 566 | "array([4.47696617, 2.7519661 , 2. 
, 1.17620428])" 567 | ] 568 | }, 569 | "metadata": { 570 | "tags": [] 571 | }, 572 | "execution_count": 13 573 | } 574 | ] 575 | }, 576 | { 577 | "cell_type": "code", 578 | "metadata": { 579 | "id": "AiXURUScXMsj", 580 | "colab_type": "code", 581 | "outputId": "d6fc0576-8bf1-491e-caed-02035079f0d3", 582 | "colab": { 583 | "base_uri": "https://localhost:8080/", 584 | "height": 161 585 | } 586 | }, 587 | "source": [ 588 | "VT" 589 | ], 590 | "execution_count": 0, 591 | "outputs": [ 592 | { 593 | "output_type": "execute_result", 594 | "data": { 595 | "text/plain": [ 596 | "array([[-1.75579600e-01, -3.51159201e-01, -6.38515454e-01,\n", 597 | " -6.61934313e-01],\n", 598 | " [ 3.91361272e-01, 7.82722545e-01, -3.79579831e-02,\n", 599 | " -4.82432341e-01],\n", 600 | " [ 8.94427191e-01, -4.47213595e-01, 0.00000000e+00,\n", 601 | " 8.32667268e-17],\n", 602 | " [ 1.26523351e-01, 2.53046702e-01, -7.68672366e-01,\n", 603 | " 5.73674125e-01]])" 604 | ] 605 | }, 606 | "metadata": { 607 | "tags": [] 608 | }, 609 | "execution_count": 14 610 | } 611 | ] 612 | }, 613 | { 614 | "cell_type": "code", 615 | "metadata": { 616 | "id": "DKOxld5lXRCK", 617 | "colab_type": "code", 618 | "outputId": "0832796a-9952-4f39-e1ef-79347c495d16", 619 | "colab": { 620 | "base_uri": "https://localhost:8080/", 621 | "height": 53 622 | } 623 | }, 624 | "source": [ 625 | "# 截断奇异值分解\n", 626 | "\n", 627 | "svd = TruncatedSVD(n_components=3, n_iter=7, random_state=42)\n", 628 | "svd.fit(X) " 629 | ], 630 | "execution_count": 0, 631 | "outputs": [ 632 | { 633 | "output_type": "execute_result", 634 | "data": { 635 | "text/plain": [ 636 | "TruncatedSVD(algorithm='randomized', n_components=3, n_iter=7, random_state=42,\n", 637 | " tol=0.0)" 638 | ] 639 | }, 640 | "metadata": { 641 | "tags": [] 642 | }, 643 | "execution_count": 16 644 | } 645 | ] 646 | }, 647 | { 648 | "cell_type": "code", 649 | "metadata": { 650 | "id": "btnGrF0LXzZI", 651 | "colab_type": "code", 652 | "outputId": 
"ba85127a-4fa6-4092-828c-9036c47f82f6", 653 | "colab": { 654 | "base_uri": "https://localhost:8080/", 655 | "height": 35 656 | } 657 | }, 658 | "source": [ 659 | "print(svd.explained_variance_ratio_)" 660 | ], 661 | "execution_count": 0, 662 | "outputs": [ 663 | { 664 | "output_type": "stream", 665 | "text": [ 666 | "[0.39945801 0.34585056 0.18861789]\n" 667 | ], 668 | "name": "stdout" 669 | } 670 | ] 671 | }, 672 | { 673 | "cell_type": "code", 674 | "metadata": { 675 | "id": "F1hSe5NxX1zw", 676 | "colab_type": "code", 677 | "outputId": "b0d0b87d-195b-4653-a857-48ff1eca887a", 678 | "colab": { 679 | "base_uri": "https://localhost:8080/", 680 | "height": 35 681 | } 682 | }, 683 | "source": [ 684 | "print(svd.explained_variance_ratio_.sum())" 685 | ], 686 | "execution_count": 0, 687 | "outputs": [ 688 | { 689 | "output_type": "stream", 690 | "text": [ 691 | "0.9339264600284481\n" 692 | ], 693 | "name": "stdout" 694 | } 695 | ] 696 | }, 697 | { 698 | "cell_type": "code", 699 | "metadata": { 700 | "id": "cV4L2i9WX30R", 701 | "colab_type": "code", 702 | "outputId": "6c313215-d095-41b2-a384-3dc1575729a2", 703 | "colab": { 704 | "base_uri": "https://localhost:8080/", 705 | "height": 35 706 | } 707 | }, 708 | "source": [ 709 | "print(svd.singular_values_)" 710 | ], 711 | "execution_count": 0, 712 | "outputs": [ 713 | { 714 | "output_type": "stream", 715 | "text": [ 716 | "[4.47696617 2.7519661 2. 
]\n" 717 | ], 718 | "name": "stdout" 719 | } 720 | ] 721 | }, 722 | { 723 | "cell_type": "markdown", 724 | "metadata": { 725 | "id": "4CbG9kJXictK", 726 | "colab_type": "text" 727 | }, 728 | "source": [ 729 | "#### 非负矩阵分解" 730 | ] 731 | }, 732 | { 733 | "cell_type": "code", 734 | "metadata": { 735 | "id": "KcA2Rd4Df_DE", 736 | "colab_type": "code", 737 | "colab": {} 738 | }, 739 | "source": [ 740 | "def inverse_transform(W, H):\n", 741 | " # 重构\n", 742 | " return W.dot(H)\n", 743 | "\n", 744 | "def loss(X, X_):\n", 745 | " #计算重构误差\n", 746 | " return ((X - X_) * (X - X_)).sum()" 747 | ], 748 | "execution_count": 0, 749 | "outputs": [] 750 | }, 751 | { 752 | "cell_type": "code", 753 | "metadata": { 754 | "id": "yRXZt6CfYPJq", 755 | "colab_type": "code", 756 | "colab": {} 757 | }, 758 | "source": [ 759 | "# 算法 17.1\n", 760 | "\n", 761 | "class MyNMF:\n", 762 | " def fit(self, X, k, t):\n", 763 | " m, n = X.shape\n", 764 | " \n", 765 | " W = np.random.rand(m, k)\n", 766 | " W = W/W.sum(axis=0)\n", 767 | " \n", 768 | " H = np.random.rand(k, n)\n", 769 | " \n", 770 | " i = 1\n", 771 | " while i < t:\n", 772 | " \n", 773 | " W = W * X.dot(H.T) / W.dot(H).dot(H.T)\n", 774 | " \n", 775 | " H = H * (W.T).dot(X) / (W.T).dot(W).dot(H)\n", 776 | " \n", 777 | " i += 1\n", 778 | " \n", 779 | " return W, H" 780 | ], 781 | "execution_count": 0, 782 | "outputs": [] 783 | }, 784 | { 785 | "cell_type": "code", 786 | "metadata": { 787 | "id": "zc1IFBBIajXk", 788 | "colab_type": "code", 789 | "colab": {} 790 | }, 791 | "source": [ 792 | "model = MyNMF()\n", 793 | "W, H = model.fit(X, 3, 200)" 794 | ], 795 | "execution_count": 0, 796 | "outputs": [] 797 | }, 798 | { 799 | "cell_type": "code", 800 | "metadata": { 801 | "id": "EYo7JyXXbNpJ", 802 | "colab_type": "code", 803 | "outputId": "716dbb03-5229-4c93-a75b-81d07ff4be85", 804 | "colab": { 805 | "base_uri": "https://localhost:8080/", 806 | "height": 125 807 | } 808 | }, 809 | "source": [ 810 | "W" 811 | ], 812 | "execution_count": 0, 
813 | "outputs": [ 814 | { 815 | "output_type": "execute_result", 816 | "data": { 817 | "text/plain": [ 818 | "array([[7.80747563e-282, 1.59350147e-085, 4.46818285e-001],\n", 819 | " [1.36585053e-173, 1.82432253e-099, 8.93962574e-001],\n", 820 | " [5.15393770e-058, 5.75011993e-001, 4.19426683e-039],\n", 821 | " [1.73830597e+000, 1.09961986e+000, 4.87814473e-031],\n", 822 | " [5.89525831e-001, 3.53091403e-065, 1.51262003e-035],\n", 823 | " [5.54346760e-001, 1.12753836e+000, 1.11381284e+000]])" 824 | ] 825 | }, 826 | "metadata": { 827 | "tags": [] 828 | }, 829 | "execution_count": 113 830 | } 831 | ] 832 | }, 833 | { 834 | "cell_type": "code", 835 | "metadata": { 836 | "id": "1cEFDsgXbnXZ", 837 | "colab_type": "code", 838 | "outputId": "e3c8eaf0-bcd8-48f5-edee-57f643b07fed", 839 | "colab": { 840 | "base_uri": "https://localhost:8080/", 841 | "height": 71 842 | } 843 | }, 844 | "source": [ 845 | "H" 846 | ], 847 | "execution_count": 0, 848 | "outputs": [ 849 | { 850 | "output_type": "execute_result", 851 | "data": { 852 | "text/plain": [ 853 | "array([[3.02557029e-05, 2.18916926e-04, 5.10981068e-02, 1.69486742e+00],\n", 854 | " [3.08284998e-03, 5.45494813e-03, 1.73785466e+00, 4.89822454e-02],\n", 855 | " [8.94680268e-01, 1.79000896e+00, 1.09648099e-02, 4.54347640e-03]])" 856 | ] 857 | }, 858 | "metadata": { 859 | "tags": [] 860 | }, 861 | "execution_count": 114 862 | } 863 | ] 864 | }, 865 | { 866 | "cell_type": "code", 867 | "metadata": { 868 | "id": "JFqGat0JdVlL", 869 | "colab_type": "code", 870 | "outputId": "567c82ff-997f-49e7-b829-f9d4c828f9c9", 871 | "colab": { 872 | "base_uri": "https://localhost:8080/", 873 | "height": 125 874 | } 875 | }, 876 | "source": [ 877 | "# 重构\n", 878 | "X_ = inverse_transform(W, H);X_" 879 | ], 880 | "execution_count": 0, 881 | "outputs": [ 882 | { 883 | "output_type": "execute_result", 884 | "data": { 885 | "text/plain": [ 886 | "array([[3.99759503e-01, 7.99808736e-01, 4.89927756e-03, 2.03010833e-03],\n", 887 | " [7.99810675e-01, 
1.60020102e+00, 9.80212969e-03, 4.06169786e-03],\n", 888 | " [1.77267571e-03, 3.13666059e-03, 9.99287272e-01, 2.81653785e-02],\n", 889 | " [3.44255674e-03, 6.37891391e-03, 1.99980365e+00, 3.00006000e+00],\n", 890 | " [1.78365184e-05, 1.29057183e-04, 3.01236539e-02, 9.99168124e-01],\n", 891 | " [9.99999171e-01, 2.00000698e+00, 2.00003661e+00, 9.99834205e-01]])" 892 | ] 893 | }, 894 | "metadata": { 895 | "tags": [] 896 | }, 897 | "execution_count": 115 898 | } 899 | ] 900 | }, 901 | { 902 | "cell_type": "code", 903 | "metadata": { 904 | "id": "FmXCjjnyfcfY", 905 | "colab_type": "code", 906 | "outputId": "819c8029-12d3-4344-e5af-b352b51507a3", 907 | "colab": { 908 | "base_uri": "https://localhost:8080/", 909 | "height": 35 910 | } 911 | }, 912 | "source": [ 913 | "# 重构误差\n", 914 | "\n", 915 | "loss(X, X_)" 916 | ], 917 | "execution_count": 0, 918 | "outputs": [ 919 | { 920 | "output_type": "execute_result", 921 | "data": { 922 | "text/plain": [ 923 | "4.001908238790242" 924 | ] 925 | }, 926 | "metadata": { 927 | "tags": [] 928 | }, 929 | "execution_count": 116 930 | } 931 | ] 932 | }, 933 | { 934 | "cell_type": "markdown", 935 | "metadata": { 936 | "id": "dGu-tGDxhEcf", 937 | "colab_type": "text" 938 | }, 939 | "source": [ 940 | "### 使用 sklearn 计算" 941 | ] 942 | }, 943 | { 944 | "cell_type": "code", 945 | "metadata": { 946 | "id": "sLN4FLmvb6tt", 947 | "colab_type": "code", 948 | "colab": {} 949 | }, 950 | "source": [ 951 | "from sklearn.decomposition import NMF\n", 952 | "model = NMF(n_components=3, init='random', max_iter=200, random_state=0)\n", 953 | "W = model.fit_transform(X)\n", 954 | "H = model.components_" 955 | ], 956 | "execution_count": 0, 957 | "outputs": [] 958 | }, 959 | { 960 | "cell_type": "code", 961 | "metadata": { 962 | "id": "Fm5W6xQ0b_jl", 963 | "colab_type": "code", 964 | "outputId": "d2ae79f6-0fea-47ba-a2ba-1742b78f71b1", 965 | "colab": { 966 | "base_uri": "https://localhost:8080/", 967 | "height": 125 968 | } 969 | }, 970 | "source": [ 971 | 
"W" 972 | ], 973 | "execution_count": 0, 974 | "outputs": [ 975 | { 976 | "output_type": "execute_result", 977 | "data": { 978 | "text/plain": [ 979 | "array([[0. , 0.53849498, 0. ],\n", 980 | " [0. , 1.07698996, 0. ],\n", 981 | " [0.69891361, 0. , 0. ],\n", 982 | " [1.39782972, 0. , 1.97173859],\n", 983 | " [0. , 0. , 0.65783848],\n", 984 | " [1.39783002, 1.34623756, 0.65573258]])" 985 | ] 986 | }, 987 | "metadata": { 988 | "tags": [] 989 | }, 990 | "execution_count": 108 991 | } 992 | ] 993 | }, 994 | { 995 | "cell_type": "code", 996 | "metadata": { 997 | "id": "wheFi8ZwcCY9", 998 | "colab_type": "code", 999 | "outputId": "9dd57216-889a-4cc4-ccb8-82e790595d59", 1000 | "colab": { 1001 | "base_uri": "https://localhost:8080/", 1002 | "height": 71 1003 | } 1004 | }, 1005 | "source": [ 1006 | "H" 1007 | ], 1008 | "execution_count": 0, 1009 | "outputs": [ 1010 | { 1011 | "output_type": "execute_result", 1012 | "data": { 1013 | "text/plain": [ 1014 | "array([[0.00000000e+00, 0.00000000e+00, 1.43078959e+00, 1.71761682e-03],\n", 1015 | " [7.42810976e-01, 1.48562195e+00, 0.00000000e+00, 3.30264644e-04],\n", 1016 | " [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.52030365e+00]])" 1017 | ] 1018 | }, 1019 | "metadata": { 1020 | "tags": [] 1021 | }, 1022 | "execution_count": 109 1023 | } 1024 | ] 1025 | }, 1026 | { 1027 | "cell_type": "code", 1028 | "metadata": { 1029 | "id": "9hfj3bRXgHRb", 1030 | "colab_type": "code", 1031 | "outputId": "3be1a4d8-a160-4200-d4e7-f5c92f2aa1af", 1032 | "colab": { 1033 | "base_uri": "https://localhost:8080/", 1034 | "height": 125 1035 | } 1036 | }, 1037 | "source": [ 1038 | "X__ = inverse_transform(W, H);X__" 1039 | ], 1040 | "execution_count": 0, 1041 | "outputs": [ 1042 | { 1043 | "output_type": "execute_result", 1044 | "data": { 1045 | "text/plain": [ 1046 | "array([[3.99999983e-01, 7.99999966e-01, 0.00000000e+00, 1.77845853e-04],\n", 1047 | " [7.99999966e-01, 1.59999993e+00, 0.00000000e+00, 3.55691707e-04],\n", 1048 | " [0.00000000e+00, 
0.00000000e+00, 9.99998311e-01, 1.20046577e-03],\n", 1049 | " [0.00000000e+00, 0.00000000e+00, 2.00000021e+00, 3.00004230e+00],\n", 1050 | " [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.00011424e+00],\n", 1051 | " [1.00000003e+00, 2.00000007e+00, 2.00000064e+00, 9.99758185e-01]])" 1052 | ] 1053 | }, 1054 | "metadata": { 1055 | "tags": [] 1056 | }, 1057 | "execution_count": 110 1058 | } 1059 | ] 1060 | }, 1061 | { 1062 | "cell_type": "code", 1063 | "metadata": { 1064 | "id": "-iALOTKfgKzP", 1065 | "colab_type": "code", 1066 | "outputId": "80c73327-dd36-403f-aa20-d4e2d3f796a7", 1067 | "colab": { 1068 | "base_uri": "https://localhost:8080/", 1069 | "height": 35 1070 | } 1071 | }, 1072 | "source": [ 1073 | "loss(X, X__)" 1074 | ], 1075 | "execution_count": 0, 1076 | "outputs": [ 1077 | { 1078 | "output_type": "execute_result", 1079 | "data": { 1080 | "text/plain": [ 1081 | "4.0000016725824565" 1082 | ] 1083 | }, 1084 | "metadata": { 1085 | "tags": [] 1086 | }, 1087 | "execution_count": 111 1088 | } 1089 | ] 1090 | } 1091 | ] 1092 | } -------------------------------------------------------------------------------- /第18章 概率潜在语义分析/PLSA.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "PLSA.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | } 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "metadata": { 20 | "id": "0I-Es-jovJzm", 21 | "colab_type": "text" 22 | }, 23 | "source": [ 24 | "# 概率潜在语义分析" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": { 30 | "id": "EeHek0KrvNbO", 31 | "colab_type": "text" 32 | }, 33 | "source": [ 34 | "概率潜在语义分析(probabilistic latent semantic analysis, PLSA),也称概率潜在语义索引(probabilistic latent semantic indexing, PLSI),是一种利用概率生成模型对文本集合进行话题分析的无监督学习方法。\n", 35 
| "\n", 36 | "模型最大特点是用隐变量表示话题,整个模型表示文本生成话题,话题生成单词,从而得到单词-文本共现数据的过程;假设每个文本由一个话题分布决定,每个话题由一个单词分布决定。" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": { 42 | "id": "ZpnNY-eRwjq3", 43 | "colab_type": "text" 44 | }, 45 | "source": [ 46 | "### **18.1.2 生成模型**\n", 47 | "\n", 48 | "假设有单词集合 $W = $ {$w_{1}, w_{2}, ..., w_{M}$}, 其中$M$是单词个数;文本(指标)集合$D = $ {$d_{1}, d_{2}, ..., d_{N}$}, 其中$N$是文本个数;话题集合$Z = $ {$z_{1}, z_{2}, ..., z_{K}$},其中$K$是预先设定的话题个数。随机变量 $w$ 取值于单词集合;随机变量 $d$ 取值于文本集合;随机变量 $z$ 取值于话题集合。概率分布 $P(d)$、条件概率分布 $P(z|d)$、条件概率分布 $P(w|z)$ 皆属于多项分布,其中 $P(d)$ 表示生成文本 $d$ 的概率,$P(z|d)$ 表示文本 $d$ 生成话题 $z$ 的概率,$P(w|z)$ 表示话题 $z$ 生成单词 $w$ 的概率。\n", 49 | "\n", 50 | " 每个文本 $d$ 拥有自己的话题概率分布 $P(z|d)$,每个话题 $z$ 拥有自己的单词概率分布 $P(w|z)$;也就是说**一个文本的内容由其相关话题决定,一个话题的内容由其相关单词决定**。\n", 51 | "\n", 52 | " 生成模型通过以下步骤生成文本-单词共现数据: \n", 53 | " (1)依据概率分布 $P(d)$,从文本(指标)集合中随机选取一个文本 $d$ , 共生成 $N$ 个文本;针对每个文本,执行以下操作: \n", 54 | " (2)在文本$d$ 给定条件下,依据条件概率分布 $P(z|d)$, 从话题集合随机选取一个话题 $z$, 共生成 $L$ 个话题,这里 $L$ 是文本长度; \n", 55 | " (3)在话题 $z$ 给定条件下,依据条件概率分布 $P(w|z)$ , 从单词集合中随机选取一个单词 $w$.
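上述三个步骤可以用几行 Python 模拟。注意其中的 $P(z|d)$、$P(w|z)$ 以及各个规模参数都是随意假设的示例取值,仅用于演示按生成模型采样得到单词-文本共现计数 $n(w,d)$ 的过程:

```python
import numpy as np

rng = np.random.default_rng(0)

M, N, K, L = 5, 3, 2, 4                      # 单词数、文本数、话题数、文本长度(假设等长)
P_z_d = rng.dirichlet(np.ones(K), size=N)    # P(z|d):每个文本一个话题分布
P_w_z = rng.dirichlet(np.ones(M), size=K)    # P(w|z):每个话题一个单词分布

n_wd = np.zeros((M, N), dtype=int)           # 单词-文本共现计数 n(w, d)
for d in range(N):                           # (1) 本例取 P(d) 为均匀分布,直接遍历每个文本
    for _ in range(L):
        z = rng.choice(K, p=P_z_d[d])        # (2) 依据 P(z|d) 选取话题 z
        w = rng.choice(M, p=P_w_z[z])        # (3) 依据 P(w|z) 选取单词 w
        n_wd[w, d] += 1

print(n_wd)                                  # 每列之和等于文本长度 L
```

得到的计数矩阵每列之和均为 $L$,对应每个文本生成 $L$ 个单词;这样的共现计数正是后续参数估计(EM 算法)的输入。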
\n", 56 | " \n", 57 | " 注意这里为叙述方便,假设文本都是等长的,现实中不需要这个假设。" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": { 63 | "id": "_YwFFCuCgugI", 64 | "colab_type": "text" 65 | }, 66 | "source": [ 67 | "生成模型中, 单词变量 $w$ 与文本变量 $d$ 是观测变量, 话题变量 $z$ 是隐变量, 也就是说模型生成的是单词-话题-文本三元组合 ($w, z ,d$)的集合, 但观测到的单词-文本二元组 ($w, d$)的集合, 观测数据表示为单词-文本矩阵 $T$的形式,矩阵 $T$ 的行表示单词,列表示文本, 元素表示单词-文本对($w, d$)的出现次数。 \n", 68 | "\n", 69 | "从数据的生成过程可以推出,文本-单词共现数据$T$的生成概率为所有单词-文本对($w,d$)的生成概率的乘积: \n", 70 | "\n", 71 | "$P(T) = \\prod_{w,d}P(w,d)^{n(w,d)}$ \n", 72 | "\n", 73 | "这里 $n(w,d)$ 表示 ($w,d$)的出现次数,单词-文本对出现的总次数是 $N*L$。 每个单词-文本对($w,d$)的生成概率由一下公式决定: \n", 74 | "\n", 75 | "$P(w,d) = P(d)P(w|d)$ \n", 76 | "\n", 77 | "$= P(d)\\sum_{z}P(w,z|d)$ \n", 78 | "\n", 79 | "$=P(d)\\sum_{z}P(z|d)P(w|z)$" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": { 85 | "id": "rIUH6dILnmQs", 86 | "colab_type": "text" 87 | }, 88 | "source": [ 89 | "### **18.1.3 共现模型**\n", 90 | "\n", 91 | "$P(w,d) = \\sum_{z\\in Z}P(z)P(w|z)P(d|z)$" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": { 97 | "id": "JSt5kq4LoFJT", 98 | "colab_type": "text" 99 | }, 100 | "source": [ 101 | "虽然生成模型与共现模型在概率公式意义上是等价的,但是拥有不同的性质。生成模型刻画文本-单词共现数据生成的过程,共现模型描述文本-单词共现数据拥有的模式。 \n", 102 | "\n", 103 | "如果直接定义单词与文本的共现概率 $P(w,d)$, 模型参数的个数是 $O(M*N)$, 其中 $M$ 是单词数, $N$ 是文本数。 概率潜在语义分析的生成模型和共现模型的参数个数是 $O(M*K + N*K)$, 其中 $K$ 是话题数。 现实中 $K< maxerr:\n", 70 | " ro = r.copy()\n", 71 | " # calculate each pagerank at a time\n", 72 | " for i in range(0,n):\n", 73 | " # inlinks of state i\n", 74 | " Ai = np.array(A[:,i].todense())[:,0]\n", 75 | " # account for sink states\n", 76 | " Di = sink / float(n)\n", 77 | " # account for teleportation to state i\n", 78 | " Ei = np.ones(n) / float(n)\n", 79 | "\n", 80 | " r[i] = ro.dot( Ai*s + Di*s + Ei*(1-s) )\n", 81 | "\n", 82 | " # return normalized pagerank\n", 83 | " return r/float(sum(r))" 84 | ], 85 | "execution_count": 0, 86 | "outputs": [] 87 | }, 88 | { 89 | "cell_type": "code", 90 
| "metadata": { 91 | "id": "Ds-wQEFFZ1F7", 92 | "colab_type": "code", 93 | "colab": { 94 | "base_uri": "https://localhost:8080/", 95 | "height": 53 96 | }, 97 | "outputId": "b2860902-8712-4583-ab47-bec602c6791b" 98 | }, 99 | "source": [ 100 | "# Example extracted from 'Introduction to Information Retrieval'\n", 101 | "G = np.array([[0,0,1,0,0,0,0],\n", 102 | " [0,1,1,0,0,0,0],\n", 103 | " [1,0,1,1,0,0,0],\n", 104 | " [0,0,0,1,1,0,0],\n", 105 | " [0,0,0,0,0,0,1],\n", 106 | " [0,0,0,0,0,1,1],\n", 107 | " [0,0,0,1,1,0,1]])\n", 108 | "print(pageRank(G,s=.86))" 109 | ], 110 | "execution_count": 6, 111 | "outputs": [ 112 | { 113 | "output_type": "stream", 114 | "text": [ 115 | "[0.12727557 0.03616954 0.12221594 0.22608452 0.28934412 0.03616954\n", 116 | " 0.16274076]\n" 117 | ], 118 | "name": "stdout" 119 | } 120 | ] 121 | } 122 | ] 123 | } --------------------------------------------------------------------------------