├── LICENSE
├── README.md
├── SLM.jpg
├── 第01章 统计学习方法概论
├── 20190522135945.jpg
└── least_sqaure_method.ipynb
├── 第02章 感知机
└── perceptron.ipynb
├── 第03章 k近邻法
└── KNN.ipynb
├── 第04章 朴素贝叶斯法
└── NaiveBayes.ipynb
├── 第05章 决策树
└── DT.ipynb
├── 第06章 逻辑斯蒂回归与最大熵模型
└── LR.ipynb
├── 第07章 支持向量机
└── SVM.ipynb
├── 第08章 提升方法
└── Adaboost.ipynb
├── 第09章 EM算法及其推广
└── EM.ipynb
├── 第10章 隐马尔可夫模型
└── HMM.ipynb
├── 第11章 条件随机场
└── CRF.ipynb
├── 第12章 监督学习方法总结
└── Summary_of_Supervised_Learning_Methods.ipynb
├── 第13章 无监督学习概论
└── Introduction_to_Unsupervised_Learning.ipynb
├── 第14章 聚类方法
└── Clustering.ipynb
├── 第15章 奇异值分解
└── SVD.ipynb
├── 第16章 主成分分析
└── PCA.ipynb
├── 第17章 潜在语义分析
└── LSA.ipynb
├── 第18章 概率潜在语义分析
└── PLSA.ipynb
├── 第19章 马尔可夫链蒙特卡洛法
└── MCMC.ipynb
├── 第20章 潜在狄利克雷分配
└── LDA.ipynb
└── 第21章 PageRank算法
└── PageRank.ipynb
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Max
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learn-Statistical-Learning-Method, Second Edition
2 | 
3 | Implementation of Statistical Learning Method
4 | 《统计学习方法》第二版,算法实现。
5 |
6 |
7 | 第1章:统计学习方法概论 [least_sqaure_method.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC01%E7%AB%A0%20%E7%BB%9F%E8%AE%A1%E5%AD%A6%E4%B9%A0%E6%96%B9%E6%B3%95%E6%A6%82%E8%AE%BA/least_sqaure_method.ipynb)
8 | 第2章:感知机 [perceptron.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC02%E7%AB%A0%20%E6%84%9F%E7%9F%A5%E6%9C%BA/perceptron.ipynb)
9 | 第3章:k近邻法 [KNN.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC03%E7%AB%A0%20k%E8%BF%91%E9%82%BB%E6%B3%95/KNN.ipynb)
10 | 第4章:朴素贝叶斯法 [NaiveBayes.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC04%E7%AB%A0%20%E6%9C%B4%E7%B4%A0%E8%B4%9D%E5%8F%B6%E6%96%AF%E6%B3%95/NaiveBayes.ipynb)
11 | 第5章:决策树 [DT.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC05%E7%AB%A0%20%E5%86%B3%E7%AD%96%E6%A0%91/DT.ipynb)
12 | 第6章:逻辑斯蒂回归与最大熵模型 [LR.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC06%E7%AB%A0%20%E9%80%BB%E8%BE%91%E6%96%AF%E8%92%82%E5%9B%9E%E5%BD%92%E4%B8%8E%E6%9C%80%E5%A4%A7%E7%86%B5%E6%A8%A1%E5%9E%8B/LR.ipynb)
13 | 第7章:支持向量机 [SVM.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC07%E7%AB%A0%20%E6%94%AF%E6%8C%81%E5%90%91%E9%87%8F%E6%9C%BA/SVM.ipynb)
14 | 第8章:提升方法 [Adaboost.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC08%E7%AB%A0%20%E6%8F%90%E5%8D%87%E6%96%B9%E6%B3%95/Adaboost.ipynb)
15 | 第9章:EM算法及其推广 [EM.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC09%E7%AB%A0%20EM%E7%AE%97%E6%B3%95%E5%8F%8A%E5%85%B6%E6%8E%A8%E5%B9%BF/EM.ipynb)
16 | 第10章:隐马尔可夫模型 [HMM.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC10%E7%AB%A0%20%E9%9A%90%E9%A9%AC%E5%B0%94%E5%8F%AF%E5%A4%AB%E6%A8%A1%E5%9E%8B/HMM.ipynb)
17 | 第11章:条件随机场 [CRF.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC11%E7%AB%A0%20%E6%9D%A1%E4%BB%B6%E9%9A%8F%E6%9C%BA%E5%9C%BA/CRF.ipynb)
18 | 第12章: 监督学习方法总结 [Summary_of_Supervised_Learning_Methods.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC12%E7%AB%A0%20%E7%9B%91%E7%9D%A3%E5%AD%A6%E4%B9%A0%E6%96%B9%E6%B3%95%E6%80%BB%E7%BB%93/Summary_of_Supervised_Learning_Methods.ipynb)
19 | 第13章:无监督学习概论 [Introduction_to_Unsupervised_Learning.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC13%E7%AB%A0%20%E6%97%A0%E7%9B%91%E7%9D%A3%E5%AD%A6%E4%B9%A0%E6%A6%82%E8%AE%BA/Introduction_to_Unsupervised_Learning.ipynb)
20 | 第14章:聚类方法 [Clustering.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC14%E7%AB%A0%20%E8%81%9A%E7%B1%BB%E6%96%B9%E6%B3%95/Clustering.ipynb)
21 | 第15章:奇异值分解 [SVD.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC15%E7%AB%A0%20%E5%A5%87%E5%BC%82%E5%80%BC%E5%88%86%E8%A7%A3/SVD.ipynb)
22 | 第16章:主成分分析 [PCA.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC16%E7%AB%A0%20%E4%B8%BB%E6%88%90%E5%88%86%E5%88%86%E6%9E%90/PCA.ipynb)
23 | 第17章:潜在语义分析 [LSA.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC17%E7%AB%A0%20%E6%BD%9C%E5%9C%A8%E8%AF%AD%E4%B9%89%E5%88%86%E6%9E%90/LSA.ipynb)
24 | 第18章:概率潜在语义分析 [PLSA.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC18%E7%AB%A0%20%E6%A6%82%E7%8E%87%E6%BD%9C%E5%9C%A8%E8%AF%AD%E4%B9%89%E5%88%86%E6%9E%90/PLSA.ipynb)
25 | 第19章:马尔可夫链蒙特卡洛法 [MCMC.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC19%E7%AB%A0%20%E9%A9%AC%E5%B0%94%E5%8F%AF%E5%A4%AB%E9%93%BE%E8%92%99%E7%89%B9%E5%8D%A1%E6%B4%9B%E6%B3%95/MCMC.ipynb)
26 | 第20章:潜在狄利克雷分配 [LDA.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC20%E7%AB%A0%20%E6%BD%9C%E5%9C%A8%E7%8B%84%E5%88%A9%E5%85%8B%E9%9B%B7%E5%88%86%E9%85%8D/LDA.ipynb)
27 | 第21章:PageRank算法 [PageRank.ipynb](https://nbviewer.jupyter.org/github/hktxt/Learn-Statistical-Learning-Method/blob/master/%E7%AC%AC21%E7%AB%A0%20PageRank%E7%AE%97%E6%B3%95/PageRank.ipynb)
28 |
29 |
30 |
31 | ## Acknowledgment
32 |
33 | This is still a work in progress. Some of the algorithms I do not yet fully understand; for those I simply followed the equations to produce an implementation. Some implementations are entirely my own, while others draw on online resources; the specific links are given in the corresponding files. I will keep updating this project until I have mastered all the algorithms in the book.
34 |
35 |
36 |
--------------------------------------------------------------------------------
/SLM.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hktxt/Learn-Statistical-Learning-Method/14a3b65db3ff35c42743a1062a2a95cecbfad094/SLM.jpg
--------------------------------------------------------------------------------
/第01章 统计学习方法概论/20190522135945.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hktxt/Learn-Statistical-Learning-Method/14a3b65db3ff35c42743a1062a2a95cecbfad094/第01章 统计学习方法概论/20190522135945.jpg
--------------------------------------------------------------------------------
/第04章 朴素贝叶斯法/NaiveBayes.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "NaiveBayes.ipynb",
7 | "version": "0.3.2",
8 | "provenance": [],
9 | "collapsed_sections": []
10 | },
11 | "language_info": {
12 | "codemirror_mode": {
13 | "name": "ipython",
14 | "version": 3
15 | },
16 | "file_extension": ".py",
17 | "mimetype": "text/x-python",
18 | "name": "python",
19 | "nbconvert_exporter": "python",
20 | "pygments_lexer": "ipython3",
21 | "version": "3.6.2"
22 | },
23 | "kernelspec": {
24 | "display_name": "Python 3",
25 | "language": "python",
26 | "name": "python3"
27 | }
28 | },
29 | "cells": [
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {
33 | "id": "WDqA-VKvfWix",
34 | "colab_type": "text"
35 | },
36 | "source": [
37 | "# 第4章 朴素贝叶斯"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {
43 | "id": "rMWMEdyUfWix",
44 | "colab_type": "text"
45 | },
46 | "source": [
47 | "基于贝叶斯定理与特征条件独立假设的分类方法。\n",
48 | "\n",
49 | "模型:\n",
50 | "\n",
51 | "- 高斯模型\n",
52 | "- 多项式模型\n",
53 | "- 伯努利模型"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "metadata": {
59 | "id": "mahnF7NFfWiy",
60 | "colab_type": "code",
61 | "colab": {}
62 | },
63 | "source": [
64 | "import numpy as np\n",
65 | "import pandas as pd\n",
66 | "import matplotlib.pyplot as plt\n",
67 | "%matplotlib inline\n",
68 | "\n",
69 | "from sklearn.datasets import load_iris\n",
70 | "from sklearn.model_selection import train_test_split\n",
71 | "\n",
72 | "from collections import Counter\n",
73 | "import math"
74 | ],
75 | "execution_count": 0,
76 | "outputs": []
77 | },
78 | {
79 | "cell_type": "code",
80 | "metadata": {
81 | "id": "6tRQt9QFf27Y",
82 | "colab_type": "code",
83 | "colab": {}
84 | },
85 | "source": [
86 | "# 例 4.1 \n",
87 | "lambda_ = 0.2\n",
88 | "x = [2, 'S']\n",
89 | "\n",
90 | "X1 = [1,2,3]\n",
91 | "X2 = ['S', 'M', 'L']\n",
92 | "Y = [1, -1]"
93 | ],
94 | "execution_count": 0,
95 | "outputs": []
96 | },
97 | {
98 | "cell_type": "markdown",
99 | "metadata": {
100 | "id": "CCzHK_v2hafm",
101 | "colab_type": "text"
102 | },
103 | "source": [
104 | "$P_\lambda(Y=1)=(9+\lambda)/(15 + 2\lambda) = (9+0.2)/(15+2\times 0.2)=0.5974025974025974$ \n",
105 | "$P_\lambda(Y=-1)=(6+\lambda)/(15 + 2\lambda) = (6+0.2)/(15+2\times 0.2)=0.40259740259740264$ \n",
106 | "$P(X^{(1)}=1|Y=1) = (2+0.2)/(9+3*0.2)=0.22916666666666669 $ \n",
107 | "$P(X^{(1)}=2|Y=1) = (3+0.2)/(9+3*0.2)=0.33333333333333337 $ \n",
108 | "$P(X^{(1)}=3|Y=1) = (4+0.2)/(9+3*0.2)=0.43750000000000006 $ \n",
109 | "$P(X^{(2)}=S|Y=1) = (1+0.2)/(9+3*0.2)=0.125 $ \n",
110 | "$P(X^{(2)}=M|Y=1) = (4+0.2)/(9+3*0.2)=0.43750000000000006 $ \n",
111 | "$P(X^{(2)}=L|Y=1) = (4+0.2)/(9+3*0.2)=0.43750000000000006 $ \n",
112 | "$P(X^{(1)}=1|Y=-1) = (3+0.2)/(6+3*0.2)=0.4848484848484849 $ \n",
113 | "$P(X^{(1)}=2|Y=-1) = (2+0.2)/(6+3*0.2)=0.33333333333333337 $ \n",
114 | "$P(X^{(1)}=3|Y=-1) = (1+0.2)/(6+3*0.2)=0.18181818181818182 $ \n",
115 | "$P(X^{(2)}=S|Y=-1) = (3+0.2)/(6+3*0.2)=0.4848484848484849 $ \n",
116 | "$P(X^{(2)}=M|Y=-1) = (2+0.2)/(6+3*0.2)=0.33333333333333337 $ \n",
117 | "$P(X^{(2)}=L|Y=-1) = (1+0.2)/(6+3*0.2)=0.18181818181818182 $ \n",
118 | "so \n",
119 | "$P(Y=1)P(X^{(1)}=2|Y=1)P(X^{(2)}=S|Y=1) =0.5974025974025974* 0.33333333333333337*0.125=0.024891774891774892$ \n",
120 | "$P(Y=-1)P(X^{(1)}=2|Y=-1)P(X^{(2)}=S|Y=-1) =0.40259740259740264* 0.33333333333333337*0.4848484848484849=0.06506624688442873$ \n",
121 | "\n",
122 | "So, the predicted class is $y=-1$."
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "metadata": {
128 | "id": "86WQkGZefWi1",
129 | "colab_type": "code",
130 | "colab": {}
131 | },
132 | "source": [
133 | "class NB:\n",
134 | "    def __init__(self, lambda_):\n",
135 | "        self.lambda_ = lambda_\n",
136 | "    \n",
137 | "    def fit(self, X, y):\n",
138 | "        N, M = X.shape\n",
139 | "        data = np.hstack((X, y.reshape(N, 1)))\n",
140 | "        \n",
141 | "        py = {}\n",
142 | "        pxy = {}\n",
143 | "        uniquey, countsy = np.unique(y, return_counts=True)\n",
144 | "        tmp = dict(zip(uniquey, countsy))\n",
145 | "        for k, v in tmp.items():\n",
146 | "            py[k] = (v + self.lambda_)/(N + len(uniquey) * self.lambda_)\n",
147 | "            tmp_data = data[data[:, -1] == k]\n",
148 | "            for col in range(M):\n",
149 | "                uniquecol, countscol = np.unique(tmp_data[:, col], return_counts=True)\n",
150 | "                tmp1 = dict(zip(uniquecol, countscol))\n",
151 | "                for kk, vv in tmp1.items():\n",
152 | "                    # 分母使用该特征在全体数据中的取值个数 Sj,而非类内取值个数\n",
153 | "                    pxy['X({})={}|Y={}'.format(col+1, kk, k)] = (vv + self.lambda_)/(v + len(np.unique(data[:, col])) * self.lambda_)\n",
154 | "        \n",
155 | "        self.py = py\n",
156 | "        self.pxy = pxy\n",
157 | "\n",
158 | "        # return self.py, self.pxy\n",
159 | "    \n",
160 | "    def predict(self, x):\n",
161 | "        res = {}\n",
162 | "        for k, v in self.py.items():\n",
163 | "            p = v\n",
164 | "            for i in range(len(x)):\n",
165 | "                p = p * self.pxy['X({})={}|Y={}'.format(i+1, x[i], k)]\n",
166 | "            res[k] = p\n",
167 | "        print(res)\n",
168 | "        maxp = -1\n",
169 | "        maxk = None\n",
170 | "        for kk, vv in res.items():\n",
171 | "            if vv > maxp:\n",
172 | "                maxp = vv\n",
173 | "                maxk = kk\n",
174 | "        \n",
175 | "        return maxk"
176 | ],
177 | "execution_count": 0,
178 | "outputs": []
179 | },
180 | {
181 | "cell_type": "code",
182 | "metadata": {
183 | "id": "3hPRglhJfWi3",
184 | "colab_type": "code",
185 | "colab": {}
186 | },
187 | "source": [
188 | "lambda_ = 0.2\n",
189 | "d = {'S':0, 'M':1, 'L':2}\n",
190 | "\n",
191 | "X = np.array([[1, d['S']], [1, d['M']], [1, d['M']],\n",
192 | "              [1, d['S']], [1, d['S']], [2, d['S']],\n",
193 | "              [2, d['M']], [2, d['M']], [2, d['L']],\n",
194 | "              [2, d['L']], [3, d['L']], [3, d['M']],\n",
195 | "              [3, d['M']], [3, d['L']], [3, d['L']]])\n",
196 | "\n",
197 | "y = np.array([-1, -1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, -1])"
198 | ],
199 | "execution_count": 0,
200 | "outputs": []
201 | },
202 | {
203 | "cell_type": "code",
204 | "metadata": {
205 | "id": "fs8vvcpWfWi5",
206 | "colab_type": "code",
207 | "outputId": "5b1fafb5-b4de-4618-fdad-baa42385751f",
208 | "colab": {
209 | "base_uri": "https://localhost:8080/",
210 | "height": 287
211 | }
212 | },
213 | "source": [
214 | "X"
215 | ],
216 | "execution_count": 129,
217 | "outputs": [
218 | {
219 | "output_type": "execute_result",
220 | "data": {
221 | "text/plain": [
222 | "array([[1, 0],\n",
223 | " [1, 1],\n",
224 | " [1, 1],\n",
225 | " [1, 0],\n",
226 | " [1, 0],\n",
227 | " [2, 0],\n",
228 | " [2, 1],\n",
229 | " [2, 1],\n",
230 | " [2, 2],\n",
231 | " [2, 2],\n",
232 | " [3, 2],\n",
233 | " [3, 1],\n",
234 | " [3, 1],\n",
235 | " [3, 2],\n",
236 | " [3, 2]])"
237 | ]
238 | },
239 | "metadata": {
240 | "tags": []
241 | },
242 | "execution_count": 129
243 | }
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "metadata": {
249 | "id": "FM89mRoZRiP1",
250 | "colab_type": "code",
251 | "colab": {
252 | "base_uri": "https://localhost:8080/",
253 | "height": 35
254 | },
255 | "outputId": "93afb1dd-3be8-4c55-cbe0-f0549365e38a"
256 | },
257 | "source": [
258 | "y"
259 | ],
260 | "execution_count": 130,
261 | "outputs": [
262 | {
263 | "output_type": "execute_result",
264 | "data": {
265 | "text/plain": [
266 | "array([-1, -1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, -1])"
267 | ]
268 | },
269 | "metadata": {
270 | "tags": []
271 | },
272 | "execution_count": 130
273 | }
274 | ]
275 | },
276 | {
277 | "cell_type": "code",
278 | "metadata": {
279 | "id": "l08zoOrQRlcN",
280 | "colab_type": "code",
281 | "colab": {
282 | "base_uri": "https://localhost:8080/",
283 | "height": 53
284 | },
285 | "outputId": "71096a11-5d50-4122-84ab-e0ab415dfc81"
286 | },
287 | "source": [
288 | "model = NB(lambda_)\n",
289 | "model.fit(X,y)\n",
290 | "model.predict(np.array([2, 0]))"
291 | ],
292 | "execution_count": 77,
293 | "outputs": [
294 | {
295 | "output_type": "stream",
296 | "text": [
297 | "{-1: 0.06506624688442873, 1: 0.024891774891774892}\n"
298 | ],
299 | "name": "stdout"
300 | },
301 | {
302 | "output_type": "execute_result",
303 | "data": {
304 | "text/plain": [
305 | "-1"
306 | ]
307 | },
308 | "metadata": {
309 | "tags": []
310 | },
311 | "execution_count": 77
312 | }
313 | ]
314 | },
315 | {
316 | "cell_type": "code",
317 | "metadata": {
318 | "id": "wULij7sgcQna",
319 | "colab_type": "code",
320 | "colab": {}
321 | },
322 | "source": [
323 | "# data\n",
324 | "def create_data():\n",
325 | "    iris = load_iris()\n",
326 | "    df = pd.DataFrame(iris.data, columns=iris.feature_names)\n",
327 | "    df['label'] = iris.target\n",
328 | "    df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n",
329 | "    data = np.array(df.iloc[:100, :])\n",
330 | "    # print(data)\n",
331 | "    return data[:,:-1], data[:,-1]"
332 | ],
333 | "execution_count": 0,
334 | "outputs": []
335 | },
336 | {
337 | "cell_type": "code",
338 | "metadata": {
339 | "id": "wniDd3wMcTRW",
340 | "colab_type": "code",
341 | "colab": {}
342 | },
343 | "source": [
344 | "X, y = create_data()\n",
345 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)"
346 | ],
347 | "execution_count": 0,
348 | "outputs": []
349 | },
350 | {
351 | "cell_type": "code",
352 | "metadata": {
353 | "id": "G6NBwGCxcUur",
354 | "colab_type": "code",
355 | "colab": {
356 | "base_uri": "https://localhost:8080/",
357 | "height": 35
358 | },
359 | "outputId": "e85d5a75-de23-4ebf-fb00-7d1ce166b0ee"
360 | },
361 | "source": [
362 | "X_test[0], y_test[0]"
363 | ],
364 | "execution_count": 80,
365 | "outputs": [
366 | {
367 | "output_type": "execute_result",
368 | "data": {
369 | "text/plain": [
370 | "(array([5.6, 3. , 4.5, 1.5]), 1.0)"
371 | ]
372 | },
373 | "metadata": {
374 | "tags": []
375 | },
376 | "execution_count": 80
377 | }
378 | ]
379 | },
380 | {
381 | "cell_type": "code",
382 | "metadata": {
383 | "id": "J_xJWo5GcVya",
384 | "colab_type": "code",
385 | "colab": {
386 | "base_uri": "https://localhost:8080/",
387 | "height": 35
388 | },
389 | "outputId": "03afa1be-a569-4857-d05d-7b3f70cc9a02"
390 | },
391 | "source": [
392 | "X_train.shape"
393 | ],
394 | "execution_count": 82,
395 | "outputs": [
396 | {
397 | "output_type": "execute_result",
398 | "data": {
399 | "text/plain": [
400 | "(70, 4)"
401 | ]
402 | },
403 | "metadata": {
404 | "tags": []
405 | },
406 | "execution_count": 82
407 | }
408 | ]
409 | },
410 | {
411 | "cell_type": "markdown",
412 | "metadata": {
413 | "id": "GyXsq6VvfWi-",
414 | "colab_type": "text"
415 | },
416 | "source": [
417 | "## GaussianNB 高斯朴素贝叶斯\n",
418 | "\n",
419 | "特征的可能性被假设为高斯\n",
420 | "\n",
421 | "概率密度函数:\n",
422 | "$$P(x_i \mid y_k)=\frac{1}{\sqrt{2\pi\sigma^2_{y_k}}}\exp\left(-\frac{(x_i-\mu_{y_k})^2}{2\sigma^2_{y_k}}\right)$$\n",
423 | "\n",
424 | "数学期望(mean):$\\mu$,方差:$\\sigma^2=\\frac{\\sum(X-\\mu)^2}{N}$"
425 | ]
426 | },
427 | {
428 | "cell_type": "code",
429 | "metadata": {
430 | "id": "uqJnMoUrfWi-",
431 | "colab_type": "code",
432 | "colab": {}
433 | },
434 | "source": [
435 | "class NaiveBayes:\n",
436 | "    def fit(self, X, y):\n",
437 | "        self.classes = list(np.unique(y))\n",
438 | "        self.parameters = {}\n",
439 | "        \n",
440 | "        for c in self.classes:\n",
441 | "            # 计算每个种类的平均值,方差,先验概率\n",
442 | "            X_Index_c = X[np.where(y == c)]\n",
443 | "            X_index_c_mean = np.mean(X_Index_c, axis=0, keepdims=True)\n",
444 | "            X_index_c_var = np.var(X_Index_c, axis=0, keepdims=True)\n",
445 | "            parameters = {\"mean\": X_index_c_mean, \"var\": X_index_c_var, \"prior\": X_Index_c.shape[0] / X.shape[0]}\n",
446 | "            self.parameters[\"class\" + str(c)] = parameters\n",
447 | "            print(self.parameters)\n",
448 | "    \n",
449 | "    def _pdf(self, X, classes):\n",
450 | "        # 一维高斯分布的概率密度函数\n",
451 | "        eps = 1e-4\n",
452 | "        mean = self.parameters[\"class\" + str(classes)][\"mean\"]\n",
453 | "        var = self.parameters[\"class\" + str(classes)][\"var\"]\n",
454 | "        \n",
455 | "        numerator = np.exp(-(X - mean) ** 2 / (2 * var + eps))\n",
456 | "        denominator = np.sqrt(2 * np.pi * var + eps)\n",
457 | "        \n",
458 | "        # 取对数防止数值溢出\n",
459 | "        result = np.sum(np.log(numerator / denominator), axis=1, keepdims=True)\n",
460 | "        \n",
461 | "        return result.T\n",
462 | "    \n",
463 | "    def _predict(self, X):\n",
464 | "        output = []\n",
465 | "        for y in self.classes:\n",
466 | "            prior = np.log(self.parameters[\"class\" + str(y)][\"prior\"])\n",
467 | "            posterior = self._pdf(X, y)\n",
468 | "            prediction = prior + posterior\n",
469 | "            output.append(prediction)\n",
470 | "        return output\n",
471 | "    \n",
472 | "    def predict(self, X):\n",
473 | "        # 取概率最大的类别返回预测值\n",
474 | "        output = self._predict(X)\n",
475 | "        output = np.reshape(output, (len(self.classes), X.shape[0]))\n",
476 | "        prediction = np.argmax(output, axis=0)\n",
477 | "        return prediction\n",
478 | "    \n",
479 | "    def score(self, X_test, y_test):\n",
480 | "        right = 0\n",
481 | "        pred = self.predict(X_test)\n",
482 | "        right = (y_test - pred == 0).sum()\n",
483 | "\n",
484 | "        return right / float(len(X_test))"
485 | ],
486 | "execution_count": 0,
487 | "outputs": []
488 | },
489 | {
490 | "cell_type": "code",
491 | "metadata": {
492 | "id": "NpeBcwKJfWjA",
493 | "colab_type": "code",
494 | "colab": {}
495 | },
496 | "source": [
497 | "model = NaiveBayes()"
498 | ],
499 | "execution_count": 0,
500 | "outputs": []
501 | },
502 | {
503 | "cell_type": "code",
504 | "metadata": {
505 | "id": "JLj3a70GfWjD",
506 | "colab_type": "code",
507 | "outputId": "ef5182a8-668e-4f62-93fb-314f95f68220",
508 | "colab": {
509 | "base_uri": "https://localhost:8080/",
510 | "height": 73
511 | }
512 | },
513 | "source": [
514 | "model.fit(X_train, y_train)"
515 | ],
516 | "execution_count": 123,
517 | "outputs": [
518 | {
519 | "output_type": "stream",
520 | "text": [
521 | "{'class0.0': {'mean': array([[5.02571429, 3.42857143, 1.49142857, 0.24857143]]), 'var': array([[0.10648163, 0.15918367, 0.02478367, 0.01278367]]), 'prior': 0.5}}\n",
522 | "{'class0.0': {'mean': array([[5.02571429, 3.42857143, 1.49142857, 0.24857143]]), 'var': array([[0.10648163, 0.15918367, 0.02478367, 0.01278367]]), 'prior': 0.5}, 'class1.0': {'mean': array([[5.94285714, 2.77714286, 4.27428571, 1.34 ]]), 'var': array([[0.18816327, 0.09833469, 0.17505306, 0.03954286]]), 'prior': 0.5}}\n"
523 | ],
524 | "name": "stdout"
525 | }
526 | ]
527 | },
528 | {
529 | "cell_type": "code",
530 | "metadata": {
531 | "id": "x9PDXudxfWjF",
532 | "colab_type": "code",
533 | "outputId": "68a6c9c7-9524-47fe-800e-d8a2d080ddea",
534 | "colab": {
535 | "base_uri": "https://localhost:8080/",
536 | "height": 35
537 | }
538 | },
539 | "source": [
540 | "print(model.predict(X_test))"
541 | ],
542 | "execution_count": 124,
543 | "outputs": [
544 | {
545 | "output_type": "stream",
546 | "text": [
547 | "[1 0 0 0 1 0 0 0 1 0 1 0 1 1 1 1 0 1 1 1 0 0 0 0 0 1 1 1 1 0]\n"
548 | ],
549 | "name": "stdout"
550 | }
551 | ]
552 | },
553 | {
554 | "cell_type": "code",
555 | "metadata": {
556 | "id": "xMO7vvVvfWjI",
557 | "colab_type": "code",
558 | "outputId": "86148657-b876-4277-b3e0-678254b0ddaf",
559 | "colab": {
560 | "base_uri": "https://localhost:8080/",
561 | "height": 35
562 | }
563 | },
564 | "source": [
565 | "model.score(X_test, y_test)"
566 | ],
567 | "execution_count": 125,
568 | "outputs": [
569 | {
570 | "output_type": "execute_result",
571 | "data": {
572 | "text/plain": [
573 | "1.0"
574 | ]
575 | },
576 | "metadata": {
577 | "tags": []
578 | },
579 | "execution_count": 125
580 | }
581 | ]
582 | },
583 | {
584 | "cell_type": "markdown",
585 | "metadata": {
586 | "collapsed": true,
587 | "id": "TGnmFQFkfWjL",
588 | "colab_type": "text"
589 | },
590 | "source": [
591 | "scikit-learn实例\n",
592 | "\n",
593 | "# sklearn.naive_bayes"
594 | ]
595 | },
596 | {
597 | "cell_type": "code",
598 | "metadata": {
599 | "id": "EBKRlWmsfWjM",
600 | "colab_type": "code",
601 | "colab": {}
602 | },
603 | "source": [
604 | "from sklearn.naive_bayes import GaussianNB"
605 | ],
606 | "execution_count": 0,
607 | "outputs": []
608 | },
609 | {
610 | "cell_type": "code",
611 | "metadata": {
612 | "id": "S7Q8mOzmfWjO",
613 | "colab_type": "code",
614 | "outputId": "d7fbefa1-b855-4ce5-9352-81049d510232",
615 | "colab": {
616 | "base_uri": "https://localhost:8080/",
617 | "height": 35
618 | }
619 | },
620 | "source": [
621 | "clf = GaussianNB()\n",
622 | "clf.fit(X, y)"
623 | ],
624 | "execution_count": 133,
625 | "outputs": [
626 | {
627 | "output_type": "execute_result",
628 | "data": {
629 | "text/plain": [
630 | "GaussianNB(priors=None, var_smoothing=1e-09)"
631 | ]
632 | },
633 | "metadata": {
634 | "tags": []
635 | },
636 | "execution_count": 133
637 | }
638 | ]
639 | },
640 | {
641 | "cell_type": "code",
642 | "metadata": {
643 | "id": "BdAKVtxXfWjT",
644 | "colab_type": "code",
645 | "outputId": "e994d698-ee17-4c39-aa4f-0ed491f7fad5",
646 | "colab": {
647 | "base_uri": "https://localhost:8080/",
648 | "height": 35
649 | }
650 | },
651 | "source": [
652 | "clf.predict([[2, 0]])"
653 | ],
654 | "execution_count": 134,
655 | "outputs": [
656 | {
657 | "output_type": "execute_result",
658 | "data": {
659 | "text/plain": [
660 | "array([-1])"
661 | ]
662 | },
663 | "metadata": {
664 | "tags": []
665 | },
666 | "execution_count": 134
667 | }
668 | ]
669 | },
670 | {
671 | "cell_type": "code",
672 | "metadata": {
673 | "colab_type": "code",
674 | "id": "7qit4xK_0aka",
675 | "colab": {}
676 | },
677 | "source": [
678 | "from sklearn.naive_bayes import BernoulliNB, MultinomialNB # 伯努利模型和多项式模型"
679 | ],
680 | "execution_count": 0,
681 | "outputs": []
682 | },
683 | {
684 | "cell_type": "code",
685 | "metadata": {
686 | "id": "l4sFBX_u0drg",
687 | "colab_type": "code",
688 | "colab": {
689 | "base_uri": "https://localhost:8080/",
690 | "height": 35
691 | },
692 | "outputId": "a7a74c51-e62e-46d1-b966-d00b7ac266f8"
693 | },
694 | "source": [
695 | "clf1 = BernoulliNB()\n",
696 | "clf1.fit(X, y)\n",
697 | "clf1.predict([[2, 0]])"
698 | ],
699 | "execution_count": 138,
700 | "outputs": [
701 | {
702 | "output_type": "execute_result",
703 | "data": {
704 | "text/plain": [
705 | "array([-1])"
706 | ]
707 | },
708 | "metadata": {
709 | "tags": []
710 | },
711 | "execution_count": 138
712 | }
713 | ]
714 | },
715 | {
716 | "cell_type": "code",
717 | "metadata": {
718 | "id": "QEClQeuV0qw2",
719 | "colab_type": "code",
720 | "colab": {
721 | "base_uri": "https://localhost:8080/",
722 | "height": 35
723 | },
724 | "outputId": "bef9b428-fc41-4a06-b120-8c66ab380a56"
725 | },
726 | "source": [
727 | "clf2 = MultinomialNB()\n",
728 | "clf2.fit(X, y)\n",
729 | "clf2.predict([[2, 0]])"
730 | ],
731 | "execution_count": 139,
732 | "outputs": [
733 | {
734 | "output_type": "execute_result",
735 | "data": {
736 | "text/plain": [
737 | "array([1])"
738 | ]
739 | },
740 | "metadata": {
741 | "tags": []
742 | },
743 | "execution_count": 139
744 | }
745 | ]
746 | }
747 | ]
748 | }
--------------------------------------------------------------------------------
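As a standalone check of the 例 4.1 figures in NaiveBayes.ipynb above, the Bayesian (λ-smoothed) estimate can be re-derived in plain Python. This is a minimal sketch: the data and λ = 0.2 come from the notebook, but the names `scores`, `x_query`, `S`, and `K` are illustrative and not part of the repository.

```python
# Re-derive 例 4.1 (Bayesian estimation, lambda = 0.2) without NumPy.
lam = 0.2
X = [(1, 'S'), (1, 'M'), (1, 'M'), (1, 'S'), (1, 'S'),
     (2, 'S'), (2, 'M'), (2, 'M'), (2, 'L'), (2, 'L'),
     (3, 'L'), (3, 'M'), (3, 'M'), (3, 'L'), (3, 'L')]
y = [-1, -1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, -1]
x_query = (2, 'S')   # the point classified in the notebook
S = [3, 3]           # number of possible values of X(1) and X(2)
K = 2                # number of classes
N = len(y)

scores = {}
for c in (1, -1):
    idx = [i for i in range(N) if y[i] == c]
    # smoothed prior: (class count + lam) / (N + K * lam)
    p = (len(idx) + lam) / (N + K * lam)
    for j, Sj in enumerate(S):
        cnt = sum(1 for i in idx if X[i][j] == x_query[j])
        # smoothed conditional: (count + lam) / (class count + Sj * lam)
        p *= (cnt + lam) / (len(idx) + Sj * lam)
    scores[c] = p
pred = max(scores, key=scores.get)
print(scores, pred)
```

Run as-is, this reproduces the two posteriors printed by the notebook's `NB` model (about 0.0249 for y=1 and 0.0651 for y=-1) and therefore predicts -1 for (2, S).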
/第05章 决策树/DT.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "DT.ipynb",
7 | "version": "0.3.2",
8 | "provenance": [],
9 | "collapsed_sections": [],
10 | "toc_visible": true
11 | },
12 | "language_info": {
13 | "codemirror_mode": {
14 | "name": "ipython",
15 | "version": 3
16 | },
17 | "file_extension": ".py",
18 | "mimetype": "text/x-python",
19 | "name": "python",
20 | "nbconvert_exporter": "python",
21 | "pygments_lexer": "ipython3",
22 | "version": "3.6.2"
23 | },
24 | "kernelspec": {
25 | "display_name": "Python 3",
26 | "language": "python",
27 | "name": "python3"
28 | }
29 | },
30 | "cells": [
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {
34 | "id": "Ajhk76kTn8L4",
35 | "colab_type": "text"
36 | },
37 | "source": [
38 | "# 第5章 决策树"
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {
44 | "id": "-iCRuHmOn8L5",
45 | "colab_type": "text"
46 | },
47 | "source": [
48 | "- ID3(基于信息增益)\n",
49 | "- C4.5(基于信息增益比)\n",
50 | "- CART 二叉决策树(gini指数)"
51 | ]
52 | },
53 | {
54 | "cell_type": "markdown",
55 | "metadata": {
56 | "id": "EbDeryI9n8L6",
57 | "colab_type": "text"
58 | },
59 | "source": [
60 | "#### entropy:$H(x) = -\\sum_{i=1}^{n}p_i\\log{p_i}$\n",
61 | "\n",
62 | "#### conditional entropy: $H(Y|X)=\sum_{i=1}^{n}p_iH(Y|X=x_i),\quad p_i=P(X=x_i)$\n",
63 | "\n",
64 | "#### information gain : $g(D, A)=H(D)-H(D|A)$\n",
65 | "\n",
66 | "#### information gain ratio: $g_R(D, A) = \\frac{g(D,A)}{H_{A}(D)}$\n",
67 | "\n",
68 | "#### gini index:$Gini(D)=\sum_{k=1}^{K}p_k(1-p_k)=1-\sum_{k=1}^{K}p_k^2$"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "metadata": {
74 | "id": "qemEGcJ7n8L6",
75 | "colab_type": "code",
76 | "colab": {}
77 | },
78 | "source": [
79 | "import numpy as np\n",
80 | "import pandas as pd\n",
81 | "import matplotlib.pyplot as plt\n",
82 | "%matplotlib inline\n",
83 | "\n",
84 | "from sklearn.datasets import load_iris\n",
85 | "from sklearn.model_selection import train_test_split\n",
86 | "\n",
87 | "from collections import Counter\n",
88 | "import math\n",
89 | "from math import log\n",
90 | "\n",
91 | "import pprint"
92 | ],
93 | "execution_count": 0,
94 | "outputs": []
95 | },
96 | {
97 | "cell_type": "markdown",
98 | "metadata": {
99 | "id": "XyUeXAC_n8L-",
100 | "colab_type": "text"
101 | },
102 | "source": [
103 | "### 例 5.1"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "metadata": {
109 | "id": "YtNCGcaHn8L_",
110 | "colab_type": "code",
111 | "colab": {}
112 | },
113 | "source": [
114 | "def create_data():\n",
115 | "    datasets = [['青年', '否', '否', '一般', '否'],\n",
116 | "                ['青年', '否', '否', '好', '否'],\n",
117 | "                ['青年', '是', '否', '好', '是'],\n",
118 | "                ['青年', '是', '是', '一般', '是'],\n",
119 | "                ['青年', '否', '否', '一般', '否'],\n",
120 | "                ['中年', '否', '否', '一般', '否'],\n",
121 | "                ['中年', '否', '否', '好', '否'],\n",
122 | "                ['中年', '是', '是', '好', '是'],\n",
123 | "                ['中年', '否', '是', '非常好', '是'],\n",
124 | "                ['中年', '否', '是', '非常好', '是'],\n",
125 | "                ['老年', '否', '是', '非常好', '是'],\n",
126 | "                ['老年', '否', '是', '好', '是'],\n",
127 | "                ['老年', '是', '否', '好', '是'],\n",
128 | "                ['老年', '是', '否', '非常好', '是'],\n",
129 | "                ['老年', '否', '否', '一般', '否'],\n",
130 | "                ]\n",
131 | "    labels = [u'年龄', u'有工作', u'有自己的房子', u'信贷情况', u'类别']\n",
132 | "    # 返回数据集和每个维度的名称\n",
133 | "    return datasets, labels"
134 | ],
135 | "execution_count": 0,
136 | "outputs": []
137 | },
138 | {
139 | "cell_type": "code",
140 | "metadata": {
141 | "id": "Ji3uUZS-n8MB",
142 | "colab_type": "code",
143 | "outputId": "bec9dbfe-5016-44ff-a080-6e3eea61bbd6",
144 | "colab": {
145 | "base_uri": "https://localhost:8080/",
146 | "height": 514
147 | }
148 | },
149 | "source": [
150 | "datasets, labels = create_data()\n",
151 | "train_data = pd.DataFrame(datasets, columns=labels)\n",
152 | "train_data"
153 | ],
154 | "execution_count": 68,
155 | "outputs": [
156 | {
157 | "output_type": "execute_result",
158 | "data": {
159 | "text/html": [
160 | "
\n",
161 | "\n",
174 | "
\n",
175 | " \n",
176 | " \n",
177 | " | \n",
178 | " 年龄 | \n",
179 | " 有工作 | \n",
180 | " 有自己的房子 | \n",
181 | " 信贷情况 | \n",
182 | " 类别 | \n",
183 | "
\n",
184 | " \n",
185 | " \n",
186 | " \n",
187 | " 0 | \n",
188 | " 青年 | \n",
189 | " 否 | \n",
190 | " 否 | \n",
191 | " 一般 | \n",
192 | " 否 | \n",
193 | "
\n",
194 | " \n",
195 | " 1 | \n",
196 | " 青年 | \n",
197 | " 否 | \n",
198 | " 否 | \n",
199 | " 好 | \n",
200 | " 否 | \n",
201 | "
\n",
202 | " \n",
203 | " 2 | \n",
204 | " 青年 | \n",
205 | " 是 | \n",
206 | " 否 | \n",
207 | " 好 | \n",
208 | " 是 | \n",
209 | "
\n",
210 | " \n",
211 | " 3 | \n",
212 | " 青年 | \n",
213 | " 是 | \n",
214 | " 是 | \n",
215 | " 一般 | \n",
216 | " 是 | \n",
217 | "
\n",
218 | " \n",
219 | " 4 | \n",
220 | " 青年 | \n",
221 | " 否 | \n",
222 | " 否 | \n",
223 | " 一般 | \n",
224 | " 否 | \n",
225 | "
\n",
226 | " \n",
227 | " 5 | \n",
228 | " 中年 | \n",
229 | " 否 | \n",
230 | " 否 | \n",
231 | " 一般 | \n",
232 | " 否 | \n",
233 | "
\n",
234 | " \n",
235 | " 6 | \n",
236 | " 中年 | \n",
237 | " 否 | \n",
238 | " 否 | \n",
239 | " 好 | \n",
240 | " 否 | \n",
241 | "
\n",
242 | " \n",
243 | " 7 | \n",
244 | " 中年 | \n",
245 | " 是 | \n",
246 | " 是 | \n",
247 | " 好 | \n",
248 | " 是 | \n",
249 | "
\n",
250 | " \n",
251 | " 8 | \n",
252 | " 中年 | \n",
253 | " 否 | \n",
254 | " 是 | \n",
255 | " 非常好 | \n",
256 | " 是 | \n",
257 | "
\n",
258 | " \n",
259 | " 9 | \n",
260 | " 中年 | \n",
261 | " 否 | \n",
262 | " 是 | \n",
263 | " 非常好 | \n",
264 | " 是 | \n",
265 | "
\n",
266 | " \n",
267 | " 10 | \n",
268 | " 老年 | \n",
269 | " 否 | \n",
270 | " 是 | \n",
271 | " 非常好 | \n",
272 | " 是 | \n",
273 | "
\n",
274 | " \n",
275 | " 11 | \n",
276 | " 老年 | \n",
277 | " 否 | \n",
278 | " 是 | \n",
279 | " 好 | \n",
280 | " 是 | \n",
281 | "
\n",
282 | " \n",
283 | " 12 | \n",
284 | " 老年 | \n",
285 | " 是 | \n",
286 | " 否 | \n",
287 | " 好 | \n",
288 | " 是 | \n",
289 | "
\n",
290 | " \n",
291 | " 13 | \n",
292 | " 老年 | \n",
293 | " 是 | \n",
294 | " 否 | \n",
295 | " 非常好 | \n",
296 | " 是 | \n",
297 | "
\n",
298 | " \n",
299 | " 14 | \n",
300 | " 老年 | \n",
301 | " 否 | \n",
302 | " 否 | \n",
303 | " 一般 | \n",
304 | " 否 | \n",
305 | "
\n",
306 | " \n",
307 | "
\n",
308 | "
"
309 | ],
310 | "text/plain": [
311 | " 年龄 有工作 有自己的房子 信贷情况 类别\n",
312 | "0 青年 否 否 一般 否\n",
313 | "1 青年 否 否 好 否\n",
314 | "2 青年 是 否 好 是\n",
315 | "3 青年 是 是 一般 是\n",
316 | "4 青年 否 否 一般 否\n",
317 | "5 中年 否 否 一般 否\n",
318 | "6 中年 否 否 好 否\n",
319 | "7 中年 是 是 好 是\n",
320 | "8 中年 否 是 非常好 是\n",
321 | "9 中年 否 是 非常好 是\n",
322 | "10 老年 否 是 非常好 是\n",
323 | "11 老年 否 是 好 是\n",
324 | "12 老年 是 否 好 是\n",
325 | "13 老年 是 否 非常好 是\n",
326 | "14 老年 否 否 一般 否"
327 | ]
328 | },
329 | "metadata": {
330 | "tags": []
331 | },
332 | "execution_count": 68
333 | }
334 | ]
335 | },
336 | {
337 | "cell_type": "code",
338 | "metadata": {
339 | "id": "bWbCbJSmrdkp",
340 | "colab_type": "code",
341 | "outputId": "3d16ea53-103e-4ad9-c3d6-c40b86614406",
342 | "colab": {
343 | "base_uri": "https://localhost:8080/",
344 | "height": 287
345 | }
346 | },
347 | "source": [
348 | "datasets"
349 | ],
350 | "execution_count": 53,
351 | "outputs": [
352 | {
353 | "output_type": "execute_result",
354 | "data": {
355 | "text/plain": [
356 | "[['青年', '否', '否', '一般', '否'],\n",
357 | " ['青年', '否', '否', '好', '否'],\n",
358 | " ['青年', '是', '否', '好', '是'],\n",
359 | " ['青年', '是', '是', '一般', '是'],\n",
360 | " ['青年', '否', '否', '一般', '否'],\n",
361 | " ['中年', '否', '否', '一般', '否'],\n",
362 | " ['中年', '否', '否', '好', '否'],\n",
363 | " ['中年', '是', '是', '好', '是'],\n",
364 | " ['中年', '否', '是', '非常好', '是'],\n",
365 | " ['中年', '否', '是', '非常好', '是'],\n",
366 | " ['老年', '否', '是', '非常好', '是'],\n",
367 | " ['老年', '否', '是', '好', '是'],\n",
368 | " ['老年', '是', '否', '好', '是'],\n",
369 | " ['老年', '是', '否', '非常好', '是'],\n",
370 | " ['老年', '否', '否', '一般', '否']]"
371 | ]
372 | },
373 | "metadata": {
374 | "tags": []
375 | },
376 | "execution_count": 53
377 | }
378 | ]
379 | },
380 | {
381 | "cell_type": "code",
382 | "metadata": {
383 | "id": "9zcwdE1uiTOO",
384 | "colab_type": "code",
385 | "colab": {
386 | "base_uri": "https://localhost:8080/",
387 | "height": 35
388 | },
389 | "outputId": "30378d95-6bb8-4d63-b57b-33f70b357b3e"
390 | },
391 | "source": [
392 | "labels"
393 | ],
394 | "execution_count": 54,
395 | "outputs": [
396 | {
397 | "output_type": "execute_result",
398 | "data": {
399 | "text/plain": [
400 | "['年龄', '有工作', '有自己的房子', '信贷情况', '类别']"
401 | ]
402 | },
403 | "metadata": {
404 | "tags": []
405 | },
406 | "execution_count": 54
407 | }
408 | ]
409 | },
410 | {
411 | "cell_type": "code",
412 | "metadata": {
413 | "id": "UP3X4BaVrgYQ",
414 | "colab_type": "code",
415 | "colab": {}
416 | },
417 | "source": [
418 | "d = {'青年':1, '中年':2, '老年':3, '一般':1, '好':2, '非常好':3, '是':0, '否':1}\n",
419 | "data = []\n",
420 | "for i in range(15):\n",
421 | " tmp = []\n",
422 | " t = datasets[i]\n",
423 | " for tt in t:\n",
424 | " tmp.append(d[tt])\n",
425 | " data.append(tmp)"
426 | ],
427 | "execution_count": 0,
428 | "outputs": []
429 | },
430 | {
431 | "cell_type": "code",
432 | "metadata": {
433 | "id": "5tV-TRIQqftJ",
434 | "colab_type": "code",
435 | "outputId": "16be14d9-1d5d-4080-db56-0dfcf329830f",
436 | "colab": {
437 | "base_uri": "https://localhost:8080/",
438 | "height": 287
439 | }
440 | },
441 | "source": [
442 | "data = np.array(data);data"
443 | ],
444 | "execution_count": 56,
445 | "outputs": [
446 | {
447 | "output_type": "execute_result",
448 | "data": {
449 | "text/plain": [
450 | "array([[1, 1, 1, 1, 1],\n",
451 | " [1, 1, 1, 2, 1],\n",
452 | " [1, 0, 1, 2, 0],\n",
453 | " [1, 0, 0, 1, 0],\n",
454 | " [1, 1, 1, 1, 1],\n",
455 | " [2, 1, 1, 1, 1],\n",
456 | " [2, 1, 1, 2, 1],\n",
457 | " [2, 0, 0, 2, 0],\n",
458 | " [2, 1, 0, 3, 0],\n",
459 | " [2, 1, 0, 3, 0],\n",
460 | " [3, 1, 0, 3, 0],\n",
461 | " [3, 1, 0, 2, 0],\n",
462 | " [3, 0, 1, 2, 0],\n",
463 | " [3, 0, 1, 3, 0],\n",
464 | " [3, 1, 1, 1, 1]])"
465 | ]
466 | },
467 | "metadata": {
468 | "tags": []
469 | },
470 | "execution_count": 56
471 | }
472 | ]
473 | },
474 | {
475 | "cell_type": "code",
476 | "metadata": {
477 | "id": "sN169YUn2LvE",
478 | "colab_type": "code",
479 | "outputId": "b0561124-c930-4706-fed3-095308e4f53f",
480 | "colab": {
481 | "base_uri": "https://localhost:8080/",
482 | "height": 35
483 | }
484 | },
485 | "source": [
486 | "data.shape"
487 | ],
488 | "execution_count": 57,
489 | "outputs": [
490 | {
491 | "output_type": "execute_result",
492 | "data": {
493 | "text/plain": [
494 | "(15, 5)"
495 | ]
496 | },
497 | "metadata": {
498 | "tags": []
499 | },
500 | "execution_count": 57
501 | }
502 | ]
503 | },
504 | {
505 | "cell_type": "code",
506 | "metadata": {
507 | "id": "oN7QSJC72UN-",
508 | "colab_type": "code",
509 | "colab": {}
510 | },
511 | "source": [
512 | "X, y = data[:,:-1], data[:, -1]"
513 | ],
514 | "execution_count": 0,
515 | "outputs": []
516 | },
517 | {
518 | "cell_type": "code",
519 | "metadata": {
520 | "id": "1KsMqBec5Cwb",
521 | "colab_type": "code",
522 | "colab": {}
523 | },
524 | "source": [
525 | "# entropy\n",
526 | "def entropy(y):\n",
527 | " N = len(y)\n",
528 | " count = []\n",
529 | " for value in set(y):\n",
530 | " count.append(len(y[y == value]))\n",
531 | " count = np.array(count)\n",
532 | " entro = -np.sum((count / N) * (np.log2(count / N)))\n",
533 | " return entro"
534 | ],
535 | "execution_count": 0,
536 | "outputs": []
537 | },
538 | {
539 | "cell_type": "code",
540 | "metadata": {
541 | "id": "DWb2n4RcDflB",
542 | "colab_type": "code",
543 | "outputId": "b01d7bde-33c9-467a-bf65-28d8c4d3e77f",
544 | "colab": {
545 | "base_uri": "https://localhost:8080/",
546 | "height": 35
547 | }
548 | },
549 | "source": [
550 | "entropy(y)"
551 | ],
552 | "execution_count": 10,
553 | "outputs": [
554 | {
555 | "output_type": "execute_result",
556 | "data": {
557 | "text/plain": [
558 | "0.9709505944546686"
559 | ]
560 | },
561 | "metadata": {
562 | "tags": []
563 | },
564 | "execution_count": 10
565 | }
566 | ]
567 | },
568 | {
569 | "cell_type": "code",
570 | "metadata": {
571 | "id": "ApbpfKpcxw6y",
572 | "colab_type": "code",
573 | "colab": {}
574 | },
575 | "source": [
576 | "# conditional entropy\n",
577 | "def cond_entropy(X, y, cond):\n",
578 | " N = len(y)\n",
579 | " cond_X = X[:, cond]\n",
580 | " tmp_entro = []\n",
581 | " for val in set(cond_X):\n",
582 | " tmp_y = y[np.where(cond_X == val)]\n",
583 | " tmp_entro.append(len(tmp_y)/N * entropy(tmp_y))\n",
584 | " cond_entro = sum(tmp_entro)\n",
585 | " return cond_entro"
586 | ],
587 | "execution_count": 0,
588 | "outputs": []
589 | },
590 | {
591 | "cell_type": "code",
592 | "metadata": {
593 | "id": "NF-g7udFK5qN",
594 | "colab_type": "code",
595 | "outputId": "9ddf18ff-6ddc-48fb-d766-4dc2c5c4034e",
596 | "colab": {
597 | "base_uri": "https://localhost:8080/",
598 | "height": 35
599 | }
600 | },
601 | "source": [
602 | "cond_entropy(X, y, 0)"
603 | ],
604 | "execution_count": 12,
605 | "outputs": [
606 | {
607 | "output_type": "execute_result",
608 | "data": {
609 | "text/plain": [
610 | "0.8879430945988998"
611 | ]
612 | },
613 | "metadata": {
614 | "tags": []
615 | },
616 | "execution_count": 12
617 | }
618 | ]
619 | },
620 | {
621 | "cell_type": "code",
622 | "metadata": {
623 | "id": "QXrKL-4mS3Q5",
624 | "colab_type": "code",
625 | "colab": {}
626 | },
627 | "source": [
628 | "# information gain\n",
629 | "def info_gain(X, y, cond):\n",
630 | " return entropy(y) - cond_entropy(X, y, cond)"
631 | ],
632 | "execution_count": 0,
633 | "outputs": []
634 | },
635 | {
636 | "cell_type": "code",
637 | "metadata": {
638 | "id": "KjLX_NtqezON",
639 | "colab_type": "code",
640 | "colab": {}
641 | },
642 | "source": [
643 | "# information gain ratio (C4.5): the gain divided by the entropy of the feature's own value distribution\n",
644 | "def info_gain_ratio(X, y, cond):\n",
645 | "    return info_gain(X, y, cond) / entropy(X[:, cond])"
646 | ],
647 | "execution_count": 0,
648 | "outputs": []
649 | },
650 | {
651 | "cell_type": "code",
652 | "metadata": {
653 | "id": "kKY7AVPeF4kh",
654 | "colab_type": "code",
655 | "outputId": "670c66c9-8f2c-46b0-b626-633f7fca4cd8",
656 | "colab": {
657 | "base_uri": "https://localhost:8080/",
658 | "height": 35
659 | }
660 | },
661 | "source": [
662 | "# A1, A2, A3, A4 => age, job, house, credit\n",
663 | "# information gain\n",
664 | "\n",
665 | "gain_a1 = info_gain(X, y, 0);gain_a1"
666 | ],
667 | "execution_count": 15,
668 | "outputs": [
669 | {
670 | "output_type": "execute_result",
671 | "data": {
672 | "text/plain": [
673 | "0.08300749985576883"
674 | ]
675 | },
676 | "metadata": {
677 | "tags": []
678 | },
679 | "execution_count": 15
680 | }
681 | ]
682 | },
683 | {
684 | "cell_type": "code",
685 | "metadata": {
686 | "id": "VVTUqG4tSgwn",
687 | "colab_type": "code",
688 | "outputId": "72b043b5-a4c0-42db-b12d-5a31f577ef80",
689 | "colab": {
690 | "base_uri": "https://localhost:8080/",
691 | "height": 34
692 | }
693 | },
694 | "source": [
695 | "gain_a2 = info_gain(X, y, 1);gain_a2"
696 | ],
697 | "execution_count": 0,
698 | "outputs": [
699 | {
700 | "output_type": "execute_result",
701 | "data": {
702 | "text/plain": [
703 | "0.32365019815155627"
704 | ]
705 | },
706 | "metadata": {
707 | "tags": []
708 | },
709 | "execution_count": 16
710 | }
711 | ]
712 | },
713 | {
714 | "cell_type": "code",
715 | "metadata": {
716 | "id": "242jN12HSj_F",
717 | "colab_type": "code",
718 | "outputId": "a620d840-ac93-4adb-da8e-151bbe04e95c",
719 | "colab": {
720 | "base_uri": "https://localhost:8080/",
721 | "height": 34
722 | }
723 | },
724 | "source": [
725 | "gain_a3 = info_gain(X, y, 2);gain_a3"
726 | ],
727 | "execution_count": 0,
728 | "outputs": [
729 | {
730 | "output_type": "execute_result",
731 | "data": {
732 | "text/plain": [
733 | "0.4199730940219749"
734 | ]
735 | },
736 | "metadata": {
737 | "tags": []
738 | },
739 | "execution_count": 17
740 | }
741 | ]
742 | },
743 | {
744 | "cell_type": "code",
745 | "metadata": {
746 | "id": "m9prl_iaSmM1",
747 | "colab_type": "code",
748 | "outputId": "b440c3ed-8bc7-4f36-caf5-d13524d43957",
749 | "colab": {
750 | "base_uri": "https://localhost:8080/",
751 | "height": 34
752 | }
753 | },
754 | "source": [
755 | "gain_a4 = info_gain(X, y, 3);gain_a4"
756 | ],
757 | "execution_count": 0,
758 | "outputs": [
759 | {
760 | "output_type": "execute_result",
761 | "data": {
762 | "text/plain": [
763 | "0.36298956253708536"
764 | ]
765 | },
766 | "metadata": {
767 | "tags": []
768 | },
769 | "execution_count": 18
770 | }
771 | ]
772 | },
773 | {
774 | "cell_type": "code",
775 | "metadata": {
776 | "id": "eIuVibAjpXSr",
777 | "colab_type": "code",
778 | "colab": {}
779 | },
780 | "source": [
781 | "def best_split(X, y, method='info_gain'):\n",
782 | "    \"\"\"Compute each feature's information gain (or gain ratio, depending on `method`) and return the axis of the best feature.\"\"\"\n",
783 | "    _, M = X.shape\n",
784 | "    info_gains = []\n",
785 | "    if method == 'info_gain':\n",
786 | "        split = info_gain\n",
787 | "    elif method == 'info_gain_ratio':\n",
788 | "        split = info_gain_ratio\n",
789 | "    else:\n",
790 | "        raise ValueError('no such method: ' + method)\n",
791 | "    for i in range(M):\n",
792 | "        tmp_gain = split(X, y, i)\n",
793 | "        info_gains.append(tmp_gain)\n",
794 | "    best_feature = np.argmax(info_gains)\n",
795 | "\n",
796 | "    return best_feature"
798 | ],
799 | "execution_count": 0,
800 | "outputs": []
801 | },
802 | {
803 | "cell_type": "code",
804 | "metadata": {
805 | "id": "Tr6ckR8wriYm",
806 | "colab_type": "code",
807 | "outputId": "d2db3308-ce72-4f5d-c74e-893944909892",
808 | "colab": {
809 | "base_uri": "https://localhost:8080/",
810 | "height": 35
811 | }
812 | },
813 | "source": [
814 | "best_split(X,y)"
815 | ],
816 | "execution_count": 27,
817 | "outputs": [
818 | {
819 | "output_type": "execute_result",
820 | "data": {
821 | "text/plain": [
822 | "2"
823 | ]
824 | },
825 | "metadata": {
826 | "tags": []
827 | },
828 | "execution_count": 27
829 | }
830 | ]
831 | },
832 | {
833 | "cell_type": "code",
834 | "metadata": {
835 | "id": "iv2hm3ueTKa6",
836 | "colab_type": "code",
837 | "colab": {}
838 | },
839 | "source": [
840 | "def majorityCnt(y):\n",
841 | "    \"\"\"Return the most frequent class; used when the features are exhausted.\"\"\"\n",
842 | " unique, counts = np.unique(y, return_counts=True)\n",
843 | " max_idx = np.argmax(counts)\n",
844 | " return unique[max_idx]"
845 | ],
846 | "execution_count": 0,
847 | "outputs": []
848 | },
849 | {
850 | "cell_type": "code",
851 | "metadata": {
852 | "id": "FXlY9UPxT80q",
853 | "colab_type": "code",
854 | "colab": {
855 | "base_uri": "https://localhost:8080/",
856 | "height": 35
857 | },
858 | "outputId": "e356a964-3f83-40b4-ed66-41fae5266a89"
859 | },
860 | "source": [
861 | "majorityCnt(y)"
862 | ],
863 | "execution_count": 20,
864 | "outputs": [
865 | {
866 | "output_type": "execute_result",
867 | "data": {
868 | "text/plain": [
869 | "0"
870 | ]
871 | },
872 | "metadata": {
873 | "tags": []
874 | },
875 | "execution_count": 20
876 | }
877 | ]
878 | },
879 | {
880 | "cell_type": "markdown",
881 | "metadata": {
882 | "collapsed": true,
883 | "id": "EDx9vrfcn8MQ",
884 | "colab_type": "text"
885 | },
886 | "source": [
887 | "#### ID3 and C4.5 algorithms\n",
888 | "\n",
889 | "Example 5.3"
890 | ]
891 | },
892 | {
893 | "cell_type": "code",
894 | "metadata": {
895 | "id": "kpgCEMIKRo8_",
896 | "colab_type": "code",
897 | "colab": {}
898 | },
899 | "source": [
900 | "class DecisionTreeClassifer:\n",
901 | "    \"\"\"\n",
902 | "    Decision tree generation.\n",
903 | "    `method` selects ID3 or C4.5; the two differ only in the feature-selection criterion:\n",
904 | "    info_gain: information gain, i.e. ID3\n",
905 | "    info_gain_ratio: information gain ratio, i.e. C4.5\n",
906 | "    \"\"\"\n",
907 | "    def __init__(self, threshold, method='info_gain'):\n",
908 | "        self.threshold = threshold\n",
909 | "        self.method = method\n",
910 | "\n",
911 | "    def fit(self, X, y, labels):\n",
912 | "        labels = labels.copy()\n",
913 | "        M, N = X.shape\n",
914 | "        if len(np.unique(y)) == 1:  # all samples share one class: return it as a leaf\n",
915 | "            return y[0]\n",
916 | "        if N == 0:  # no features left to split on\n",
917 | "            return majorityCnt(y)\n",
918 | "\n",
919 | "        bestSplit = best_split(X, y, method=self.method)\n",
920 | "        # stop when the best gain falls below the threshold (epsilon in the book)\n",
921 | "        split = info_gain if self.method == 'info_gain' else info_gain_ratio\n",
922 | "        if split(X, y, bestSplit) < self.threshold:\n",
923 | "            return majorityCnt(y)\n",
924 | "        bestFeaLable = labels[bestSplit]\n",
925 | "        Tree = {bestFeaLable: {}}\n",
926 | "        del labels[bestSplit]\n",
927 | "\n",
928 | "        feaVals = np.unique(X[:, bestSplit])\n",
929 | "        for val in feaVals:\n",
930 | "            idx = np.where(X[:, bestSplit] == val)\n",
931 | "            sub_X = np.delete(X[idx], bestSplit, axis=1)  # drop the used column so sub_X stays aligned with labels\n",
932 | "            sub_y = y[idx]\n",
933 | "            Tree[bestFeaLable][val] = self.fit(sub_X, sub_y, labels)\n",
934 | "\n",
935 | "        return Tree"
936 | ],
937 | "execution_count": 0,
938 | "outputs": []
939 | },
940 | {
941 | "cell_type": "code",
942 | "metadata": {
943 | "id": "8k4cgeqBn8MQ",
944 | "colab_type": "code",
945 | "colab": {
946 | "base_uri": "https://localhost:8080/",
947 | "height": 35
948 | },
949 | "outputId": "f3c4ca27-9f09-4d42-bd58-3afa41cc32e0"
950 | },
951 | "source": [
952 | "My_Tree = DecisionTreeClassifer(threshold=0.1)\n",
953 | "My_Tree.fit(X, y, labels)"
954 | ],
955 | "execution_count": 69,
956 | "outputs": [
957 | {
958 | "output_type": "execute_result",
959 | "data": {
960 | "text/plain": [
961 | "{'有自己的房子': {0: 0, 1: {'有工作': {0: 0, 1: 1}}}}"
962 | ]
963 | },
964 | "metadata": {
965 | "tags": []
966 | },
967 | "execution_count": 69
968 | }
969 | ]
970 | },
971 | {
972 | "cell_type": "markdown",
973 | "metadata": {
974 | "id": "XaGNaDfAoivJ",
975 | "colab_type": "text"
976 | },
977 | "source": [
978 | "#### CART tree"
979 | ]
980 | },
981 | {
982 | "cell_type": "code",
983 | "metadata": {
984 | "id": "yXTTfkLCmsdP",
985 | "colab_type": "code",
986 | "colab": {}
987 | },
988 | "source": [
989 | "class CART:\n",
990 | "    \"\"\"CART tree\"\"\"\n",
991 | "    def __init__(self):\n",
992 | "        pass  # to be continued"
993 | ],
994 | "execution_count": 0,
995 | "outputs": []
996 | },
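  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "CART chooses splits by the Gini index rather than entropy. As a minimal illustrative sketch (this cell is an addition for illustration, not part of the book's example), the helper below computes $Gini(D)=1-\sum_k p_k^2$ for a label array, mirroring the `entropy` function above."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "# Gini index of a label array: 1 - sum_k p_k^2 (illustrative sketch)\n",
    "def gini(y):\n",
    "    _, counts = np.unique(y, return_counts=True)\n",
    "    p = counts / len(y)\n",
    "    return 1 - np.sum(p ** 2)"
   ],
   "execution_count": 0,
   "outputs": []
  },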
997 | {
998 | "cell_type": "markdown",
999 | "metadata": {
1000 | "id": "6nxK8duGo37e",
1001 | "colab_type": "text"
1002 | },
1003 | "source": [
1004 | "#### Decision tree pruning"
1005 | ]
1006 | },
1007 | {
1008 | "cell_type": "code",
1009 | "metadata": {
1010 | "id": "N79jPbwWo6rv",
1011 | "colab_type": "code",
1012 | "colab": {}
1013 | },
1014 | "source": [
1015 | "\"to be continued\""
1016 | ],
1017 | "execution_count": 0,
1018 | "outputs": []
1019 | },
1020 | {
1021 | "cell_type": "markdown",
1022 | "metadata": {
1023 | "id": "Gop3ocYDn8MZ",
1024 | "colab_type": "text"
1025 | },
1026 | "source": [
1027 | "---\n",
1028 | "\n",
1029 | "## sklearn.tree.DecisionTreeClassifier\n",
1030 | "\n",
1031 | "### criterion : string, optional (default=”gini”)\n",
1032 | "The function to measure the quality of a split. Supported criteria are “gini” for the Gini impurity and “entropy” for the information gain."
1033 | ]
1034 | },
1035 | {
1036 | "cell_type": "code",
1037 | "metadata": {
1038 | "id": "nxE7F4sqn8Ma",
1039 | "colab_type": "code",
1040 | "colab": {}
1041 | },
1042 | "source": [
1043 | "# data\n",
1044 | "def create_data():\n",
1045 | " iris = load_iris()\n",
1046 | " df = pd.DataFrame(iris.data, columns=iris.feature_names)\n",
1047 | " df['label'] = iris.target\n",
1048 | " df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n",
1049 | " data = np.array(df.iloc[:100, [0, 1, -1]])\n",
1050 | " # print(data)\n",
1051 | " return data[:,:2], data[:,-1]\n",
1052 | "\n",
1053 | "X, y = create_data()\n",
1054 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)"
1055 | ],
1056 | "execution_count": 0,
1057 | "outputs": []
1058 | },
1059 | {
1060 | "cell_type": "code",
1061 | "metadata": {
1062 | "id": "LyqL3F8un8Mc",
1063 | "colab_type": "code",
1064 | "colab": {}
1065 | },
1066 | "source": [
1067 | "from sklearn.tree import DecisionTreeClassifier\n",
1068 | "\n",
1069 | "from sklearn.tree import export_graphviz\n",
1070 | "import graphviz"
1071 | ],
1072 | "execution_count": 0,
1073 | "outputs": []
1074 | },
1075 | {
1076 | "cell_type": "code",
1077 | "metadata": {
1078 | "id": "nNDjw1Phn8Me",
1079 | "colab_type": "code",
1080 | "outputId": "d2dd416b-1a48-4564-c53f-c801416c6df0",
1081 | "colab": {
1082 | "base_uri": "https://localhost:8080/",
1083 | "height": 125
1084 | }
1085 | },
1086 | "source": [
1087 | "clf = DecisionTreeClassifier()\n",
1088 | "clf.fit(data[:,:-1], data[:,-1])"
1089 | ],
1090 | "execution_count": 0,
1091 | "outputs": [
1092 | {
1093 | "output_type": "execute_result",
1094 | "data": {
1095 | "text/plain": [
1096 | "DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,\n",
1097 | " max_features=None, max_leaf_nodes=None,\n",
1098 | " min_impurity_decrease=0.0, min_impurity_split=None,\n",
1099 | " min_samples_leaf=1, min_samples_split=2,\n",
1100 | " min_weight_fraction_leaf=0.0, presort=False,\n",
1101 | " random_state=None, splitter='best')"
1102 | ]
1103 | },
1104 | "metadata": {
1105 | "tags": []
1106 | },
1107 | "execution_count": 25
1108 | }
1109 | ]
1110 | },
1111 | {
1112 | "cell_type": "code",
1113 | "metadata": {
1114 | "id": "RsB_iiLZn8Mh",
1115 | "colab_type": "code",
1116 | "outputId": "af553ea2-ec41-496d-ceb9-a750be3d8088",
1117 | "colab": {
1118 | "base_uri": "https://localhost:8080/",
1119 | "height": 35
1120 | }
1121 | },
1122 | "source": [
1123 | "clf.predict(np.array([1, 1, 0, 1]).reshape(1,-1)) # A"
1124 | ],
1125 | "execution_count": 0,
1126 | "outputs": [
1127 | {
1128 | "output_type": "execute_result",
1129 | "data": {
1130 | "text/plain": [
1131 | "array([0])"
1132 | ]
1133 | },
1134 | "metadata": {
1135 | "tags": []
1136 | },
1137 | "execution_count": 28
1138 | }
1139 | ]
1140 | },
1141 | {
1142 | "cell_type": "code",
1143 | "metadata": {
1144 | "id": "Sd2ScBfQu1Bo",
1145 | "colab_type": "code",
1146 | "outputId": "cbdf0ec2-1fd6-48a3-ae24-f4de9d6a442b",
1147 | "colab": {
1148 | "base_uri": "https://localhost:8080/",
1149 | "height": 35
1150 | }
1151 | },
1152 | "source": [
1153 | "clf.predict(np.array([2, 0, 1, 2]).reshape(1,-1)) # B"
1154 | ],
1155 | "execution_count": 0,
1156 | "outputs": [
1157 | {
1158 | "output_type": "execute_result",
1159 | "data": {
1160 | "text/plain": [
1161 | "array([0])"
1162 | ]
1163 | },
1164 | "metadata": {
1165 | "tags": []
1166 | },
1167 | "execution_count": 29
1168 | }
1169 | ]
1170 | },
1171 | {
1172 | "cell_type": "code",
1173 | "metadata": {
1174 | "id": "0E9mMz34u1a3",
1175 | "colab_type": "code",
1176 | "outputId": "7078b7fc-3322-4ee5-9e5f-6f8d575347f6",
1177 | "colab": {
1178 | "base_uri": "https://localhost:8080/",
1179 | "height": 35
1180 | }
1181 | },
1182 | "source": [
1183 | "clf.predict(np.array([2, 1, 0, 1]).reshape(1,-1)) # C"
1184 | ],
1185 | "execution_count": 0,
1186 | "outputs": [
1187 | {
1188 | "output_type": "execute_result",
1189 | "data": {
1190 | "text/plain": [
1191 | "array([0])"
1192 | ]
1193 | },
1194 | "metadata": {
1195 | "tags": []
1196 | },
1197 | "execution_count": 30
1198 | }
1199 | ]
1200 | },
1201 | {
1202 | "cell_type": "code",
1203 | "metadata": {
1204 | "id": "rmZHZjbYn8Mm",
1205 | "colab_type": "code",
1206 | "colab": {}
1207 | },
1208 | "source": [
1209 | "export_graphviz(clf, out_file=\"mytree.dot\")  # writes DOT source, not a PDF\n",
1210 | "with open('mytree.dot') as f:\n",
1211 | "    dot_graph = f.read()"
1212 | ],
1213 | "execution_count": 0,
1214 | "outputs": []
1215 | },
1216 | {
1217 | "cell_type": "code",
1218 | "metadata": {
1219 | "id": "AeRk07sYn8Mq",
1220 | "colab_type": "code",
1221 | "outputId": "c526e824-6d1d-4f3e-f231-53b9d6447d90",
1222 | "colab": {
1223 | "base_uri": "https://localhost:8080/",
1224 | "height": 379
1225 | }
1226 | },
1227 | "source": [
1228 | "graphviz.Source(dot_graph)"
1229 | ],
1230 | "execution_count": 0,
1231 | "outputs": [
1232 | {
1233 | "output_type": "execute_result",
1234 | "data": {
1235 | "text/plain": [
1236 | ""
1237 | ],
238 | "image/svg+xml": ""
1239 | },
1240 | "metadata": {
1241 | "tags": []
1242 | },
1243 | "execution_count": 32
1244 | }
1245 | ]
1246 | },
1247 | {
1248 | "cell_type": "code",
1249 | "metadata": {
1250 | "id": "dlk_DsGByMix",
1251 | "colab_type": "code",
1252 | "colab": {}
1253 | },
1254 | "source": [
1255 | ""
1256 | ],
1257 | "execution_count": 0,
1258 | "outputs": []
1259 | }
1260 | ]
1261 | }
--------------------------------------------------------------------------------
/第07章 支持向量机/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hktxt/Learn-Statistical-Learning-Method/14a3b65db3ff35c42743a1062a2a95cecbfad094/第07章 支持向量机/.DS_Store
--------------------------------------------------------------------------------
/第07章 支持向量机/SVM.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Chapter 7: Support Vector Machines"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "----\n",
15 | "Separating hyperplane: $w^Tx+b=0$\n",
16 | "\n",
17 | "Distance from a point to the hyperplane: $r=\frac{|w^Tx+b|}{||w||_2}$\n",
18 | "\n",
19 | "$||w||_2$ is the 2-norm: $||w||_2=\sqrt[2]{\sum^m_{i=1}w_i^2}$\n",
20 | "\n",
21 | "Taking the hyperplane as the decision boundary, the two classes of samples satisfy:\n",
22 | "\n",
23 | "$w^Tx+b\ \geq+1$\n",
24 | "\n",
25 | "$w^Tx+b\ \leq-1$\n",
26 | "\n",
27 | "#### margin:\n",
28 | "\n",
29 | "**Functional margin**: $label(w^Tx+b)\ or\ y_i(w^Tx+b)$\n",
30 | "\n",
31 | "**Geometric margin**: $r=\frac{label(w^Tx+b)}{||w||_2}$; when a sample is correctly classified, the geometric margin is its distance to the hyperplane\n",
32 | "\n",
33 | "To maximize the geometric margin, the basic SVM problem can be cast as ($\frac{r^*}{||w||}$ is the geometric margin, ${r^*}$ the functional margin):\n",
34 | "\n",
35 | "$$\max\ \frac{r^*}{||w||}$$\n",
36 | "\n",
37 | "$$(subject\ to)\ y_i({w^T}x_i+{b})\geq {r^*},\ i=1,2,..,m$$\n",
38 | "\n",
39 | "that is, maximize the margin while every point is correctly classified. The problem as stated is not convex, so we ① convert it into a convex problem and ② solve the dual via Lagrange multipliers and the KKT conditions.\n",
40 | "\n",
41 | "① Converting to a convex problem:\n",
42 | "\n",
43 | "First set ${r^*}=1$ for convenience (a rescaling that does not affect the optimum)\n",
44 | "\n",
45 | "$$\max\ \frac{1}{||w||}$$\n",
46 | "\n",
47 | "$$s.t.\ y_i({w^T}x_i+{b})\geq {1},\ i=1,2,..,m$$\n",
48 | "\n",
49 | "Then replace $\max\ \frac{1}{||w||}$ with the equivalent convex problem $\min\ \frac{1}{2}||w||^2$; the 1/2 simply makes the derivative cleaner.\n",
50 | "\n",
51 | "$$\min\ \frac{1}{2}||w||^2$$\n",
52 | "\n",
53 | "$$s.t.\ y_i(w^Tx_i+b)\geq 1,\ i=1,2,..,m$$\n",
54 | "\n",
55 | "② Solving for the optimum via Lagrange multipliers and the KKT conditions:\n",
56 | "\n",
57 | "$$\min\ \frac{1}{2}||w||^2$$\n",
58 | "\n",
59 | "$$s.t.\ -y_i(w^Tx_i+b)+1\leq 0,\ i=1,2,..,m$$\n",
60 | "\n",
61 | "Combined into the Lagrangian:\n",
62 | "\n",
63 | "$$L(w, b, \alpha) = \frac{1}{2}||w||^2+\sum^m_{i=1}\alpha_i(-y_i(w^Tx_i+b)+1)$$\n",
64 | "\n",
65 | "By weak duality: $\min\ f(x)=\min \max\ L(w, b, \alpha)\geq \max \min\ L(w, b, \alpha)$\n",
66 | "\n",
67 | "From the KKT conditions:\n",
68 | "\n",
69 | "$$\frac{\partial }{\partial w}L(w, b, \alpha)=w-\sum\alpha_iy_ix_i=0,\ w=\sum\alpha_iy_ix_i$$\n",
70 | "\n",
71 | "$$\frac{\partial }{\partial b}L(w, b, \alpha)=\sum\alpha_iy_i=0$$\n",
72 | "\n",
73 | "Substituting back into $ L(w, b, \alpha)$:\n",
74 | "\n",
75 | "$\min\ L(w, b, \alpha)=\frac{1}{2}||w||^2+\sum^m_{i=1}\alpha_i(-y_i(w^Tx_i+b)+1)$\n",
76 | "\n",
77 | "$\qquad\qquad\qquad=\frac{1}{2}w^Tw-\sum^m_{i=1}\alpha_iy_iw^Tx_i-b\sum^m_{i=1}\alpha_iy_i+\sum^m_{i=1}\alpha_i$\n",
78 | "\n",
79 | "$\qquad\qquad\qquad=\frac{1}{2}w^T\sum\alpha_iy_ix_i-\sum^m_{i=1}\alpha_iy_iw^Tx_i+\sum^m_{i=1}\alpha_i$\n",
80 | "\n",
81 | "$\qquad\qquad\qquad=\sum^m_{i=1}\alpha_i-\frac{1}{2}\sum^m_{i=1}\alpha_iy_iw^Tx_i$\n",
82 | "\n",
83 | "$\qquad\qquad\qquad=\sum^m_{i=1}\alpha_i-\frac{1}{2}\sum^m_{i,j=1}\alpha_i\alpha_jy_iy_j(x_ix_j)$\n",
84 | "\n",
85 | "Then turn the max problem into a min problem:\n",
86 | "\n",
87 | "$\max\ \sum^m_{i=1}\alpha_i-\frac{1}{2}\sum^m_{i,j=1}\alpha_i\alpha_jy_iy_j(x_ix_j)=\min \frac{1}{2}\sum^m_{i,j=1}\alpha_i\alpha_jy_iy_j(x_ix_j)-\sum^m_{i=1}\alpha_i$\n",
88 | "\n",
89 | "$s.t.\ \sum^m_{i=1}\alpha_iy_i=0,$\n",
90 | "\n",
91 | "$ \alpha_i \geq 0,i=1,2,...,m$\n",
92 | "\n",
93 | "The above is the dual form of the SVM problem.\n",
94 | "\n",
95 | "-----\n",
96 | "#### kernel\n",
97 | "\n",
98 | "A kernel computes, within the low-dimensional space, the inner product of the high-dimensional feature space; the result is the high-dimensional quantity needed to establish linear separability there.\n",
99 | "\n",
100 | "#### soft margin & slack variable\n",
101 | "\n",
102 | "Introduce slack variables $\xi\geq0$, the amount by which each point is allowed to violate the functional margin.\n",
103 | "\n",
104 | "Objective: $\min\ \frac{1}{2}||w||^2+C\sum\xi_i\qquad s.t.\ y_i(w^Tx_i+b)\geq1-\xi_i$ \n",
105 | "\n",
106 | "Dual problem:\n",
107 | "\n",
108 | "$$\max\ \sum^m_{i=1}\alpha_i-\frac{1}{2}\sum^m_{i,j=1}\alpha_i\alpha_jy_iy_j(x_ix_j)=\min \frac{1}{2}\sum^m_{i,j=1}\alpha_i\alpha_jy_iy_j(x_ix_j)-\sum^m_{i=1}\alpha_i$$\n",
109 | "\n",
110 | "$$s.t.\ C\geq\alpha_i \geq 0,i=1,2,...,m\quad \sum^m_{i=1}\alpha_iy_i=0$$\n",
111 | "\n",
112 | "-----\n",
113 | "\n",
114 | "#### Sequential Minimal Optimization\n",
115 | "\n",
116 | "First define the output function from features to a score: $u=w^Tx+b$.\n",
117 | "\n",
118 | "Since $w=\sum\alpha_iy_ix_i$,\n",
119 | "\n",
120 | "we have $u=\sum y_i\alpha_iK(x_i, x)+b$\n",
121 | "\n",
122 | "\n",
123 | "----\n",
124 | "\n",
125 | "$\\max \\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i=1}\\sum^m_{j=1}\\alpha_i\\alpha_jy_iy_j<\\phi(x_i)^T,\\phi(x_j)>$\n",
126 | "\n",
127 | "$s.t.\\ \\sum^m_{i=1}\\alpha_iy_i=0,$\n",
128 | "\n",
129 | "$ \\alpha_i \\geq 0,i=1,2,...,m$\n",
130 | "\n",
131 | "Reference: \n",
132 | "https://www.youtube.com/watch?v=_PwhiWxHK8o \n",
133 | "https://www.youtube.com/watch?v=vywmP6Ud1HA \n",
134 | "https://www.youtube.com/watch?v=iB2VK7qPfjg\n"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": 1,
140 | "metadata": {},
141 | "outputs": [],
142 | "source": [
143 | "import numpy as np\n",
144 | "import pandas as pd\n",
145 | "from sklearn.datasets import load_iris\n",
146 | "from sklearn.model_selection import train_test_split\n",
147 | "import matplotlib.pyplot as plt\n",
148 | "%matplotlib inline"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": 2,
154 | "metadata": {},
155 | "outputs": [],
156 | "source": [
157 | "# data\n",
158 | "def create_data():\n",
159 | " iris = load_iris()\n",
160 | " df = pd.DataFrame(iris.data, columns=iris.feature_names)\n",
161 | " df['label'] = iris.target\n",
162 | " df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n",
163 | " data = np.array(df.iloc[:100, [0, 1, -1]])\n",
164 | " for i in range(len(data)):\n",
165 | " if data[i,-1] == 0:\n",
166 | " data[i,-1] = -1\n",
167 | " # print(data)\n",
168 | " return data[:,:2], data[:,-1]"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 3,
174 | "metadata": {},
175 | "outputs": [],
176 | "source": [
177 | "X, y = create_data()\n",
178 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": 161,
184 | "metadata": {},
185 | "outputs": [
186 | {
187 | "data": {
188 | "text/plain": [
189 | ""
190 | ]
191 | },
192 | "execution_count": 161,
193 | "metadata": {},
194 | "output_type": "execute_result"
195 | },
196 | {
197 | "data": {
198 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAGXxJREFUeJzt3X+MXWWdx/H3d4dZOiowaRkWmCmWVdM/bLsWRrBpQlxxF8VaGmShjb+qrN01uGBwMdYQ1IYEDQaV1WhayALCVrsVu4XlxyIs8UekyZTWdrWQoIt2CixDsa2shW3Ld/+4d+jM7Z2597n3nnuf57mfV9J07rkPp9/nHP329pzPea65OyIikpc/6XQBIiLSemruIiIZUnMXEcmQmruISIbU3EVEMqTmLiKSITV3EZEMqbmLiGRIzV1EJEPH1TvQzHqAEWCPuy+peG8lcCOwp7zpm+5+y3T7O/nkk33OnDlBxYqIdLutW7e+4O4DtcbV3dyBq4BdwIlTvP99d/9UvTubM2cOIyMjAX+8iIiY2W/rGVfXZRkzGwLeB0z7aVxEROJQ7zX3rwOfBV6dZswHzGyHmW00s9nVBpjZKjMbMbORsbGx0FpFRKRONZu7mS0Bnnf3rdMMuweY4+4LgB8Bt1cb5O5r3X3Y3YcHBmpeMhIRkQbVc819MbDUzC4EZgAnmtmd7v6h8QHuvnfC+HXAV1pbpohI4w4dOsTo6Cgvv/xyp0up24wZMxgaGqK3t7eh/75mc3f31cBqADN7J/CPExt7eftp7v5s+eVSSjdeRUSiMDo6ygknnMCcOXMws06XU5O7s3fvXkZHRznzzDMb2kfDOXczW2NmS8svrzSzX5rZL4ArgZWN7ldEpNVefvllZs2alURjBzAzZs2a1dS/NEKikLj7o8Cj5Z+vm7D9tU/3IrnZtG0PNz74JM/sO8jp/X1cc8Fcli0c7HRZEiiVxj6u2XqDmrtIt9m0bQ+r797JwUNHANiz7yCr794JoAYvUdPyAyLTuPHBJ19r7OMOHjrCjQ8+2aGKJHVPPPEEixYt4vjjj+erX/1qYX+OPrmLTOOZfQeDtovUMnPmTG6++WY2bdpU6J+jT+4i0zi9vy9ou+Rh07Y9LP7yI5z5uX9n8ZcfYdO2PbX/ozqdcsopvP3tb2844lgvNXeRaVxzwVz6ensmbevr7eGaC+Z2qCIp2vh9lj37DuIcvc/SygbfDmruItNYtnCQGy6ez2B/HwYM9vdxw8XzdTM1Y7ncZ9E1d5Eali0cVDPvIkXcZ/nWt77FunXrALjvvvs4/fTTG95XvfTJXURkgiLus1xxxRVs376d7du3t6Wxg5q7iMgkRd9nee655xgaGuKmm27i+uuvZ2hoiAMHDrRk3xPpsoyIyATjl+CKeir51FNPZXR0tCX7mo6au4hIhRzus+iyjIhIhtTcRUQypOYuIpIhNXcRkQypuYuIZEjNXbJR5GJPIs36+Mc/zimnnMK8efPa8uepuUsWclnsSfK1cuVKHnjggbb9eWrukoVcFnuSSOzYAF+bB1/sL/2+Y0PTuzzvvPOYOXNmC4qrjx5ikizoSzWkZXZsgHuuhEPl/+3s3116DbDg0s7VFUif3CUL+lINaZmH1xxt7OMOHSxtT4iau2RBX6ohLbN/inVfptoeKV2WkSwUvdiTdJGThkqXYqptT4iau2Qjh8WeJALnXzf5mjtAb19pexNWrFjBo48+ygsvvMDQ0BBf+tKXuPzyy5ssdmpq7tK0Tdv26BOz5GP8punDa0qXYk4aKjX2Jm+mrl+/vgXF1U/NXZoyni8fjyGO58sBNXhJ14JLk0rGVKMbqtIU5ctF4qTmLk1RvlxS4e6dLiFIs/WquUtTlC+XFMyYMYO9e/cm0+Ddnb179zJjxoyG96Fr7tKUay6YO+maOyhfLvEZGhpidHSUsbGxTpdStxkzZjA01Hj8Us1d
mqJ8uaSgt7eXM888s9NltFXdzd3MeoARYI+7L6l473jgDuBsYC9wmbs/3cI6JWLKl4vEJ+ST+1XALuDEKu9dDvze3d9sZsuBrwCXtaA+kaQo8y+xqOuGqpkNAe8DbpliyEXA7eWfNwLnm5k1X55IOrSmvMSk3rTM14HPAq9O8f4gsBvA3Q8D+4FZTVcnkhBl/iUmNZu7mS0Bnnf3rdMNq7LtmMyRma0ysxEzG0nprrVIPZT5l5jU88l9MbDUzJ4Gvge8y8zurBgzCswGMLPjgJOAFyt35O5r3X3Y3YcHBgaaKlwkNsr8S0xqNnd3X+3uQ+4+B1gOPOLuH6oYthn4aPnnS8pj0nhaQKRFtKa8xKThnLuZrQFG3H0zcCvwXTN7itIn9uUtqk8kGcr8S0ysUx+wh4eHfWRkpCN/tohIqsxsq7sP1xqnJ1QlWtdu2sn6Lbs54k6PGSvOnc31y+Z3uiyRJKi5S5Su3bSTOx/73Wuvj7i/9loNXqQ2rQopUVq/pcp3WE6zXUQmU3OXKB2Z4l7QVNtFZDI1d4lSzxSrV0y1XUQmU3OXKK04d3bQdhGZTDdUJUrjN02VlhFpjHLuIiIJUc5dmvLBdT/nZ78+ujzQ4jfN5K5PLOpgRZ2jNdolRbrmLseobOwAP/v1i3xw3c87VFHnaI12SZWauxyjsrHX2p4zrdEuqVJzF5mG1miXVKm5i0xDa7RLqtTc5RiL3zQzaHvOtEa7pErNXY5x1ycWHdPIuzUts2zhIDdcPJ/B/j4MGOzv44aL5ystI9FTzl1EJCHKuUtTisp2h+xX+XKRxqm5yzHGs93jEcDxbDfQVHMN2W9RNYh0C11zl2MUle0O2a/y5SLNUXOXYxSV7Q7Zr/LlIs1Rc5djFJXtDtmv8uUizVFzl2MUle0O2a/y5SLN0Q1VOcb4DctWJ1VC9ltUDSLdQjl3EZGEKOdesBgy2KE1xFCziLSHmnsDYshgh9YQQ80i0j66odqAGDLYoTXEULOItI+aewNiyGCH1hBDzSLSPmruDYghgx1aQww1i0j7qLk3IIYMdmgNMdQsIu2jG6oNiCGDHVpDDDWLSPvUzLmb2Qzgx8DxlP4y2OjuX6gYsxK4ERj/Svhvuvst0+1XOXcRkXCtzLm/ArzL3V8ys17gp2Z2v7s/VjHu++7+qUaKlfa4dtNO1m/ZzRF3esxYce5srl82v+mxseTnY6lDJAY1m7uXPtq/VH7ZW/7VmcdapWHXbtrJnY/97rXXR9xfe13ZtEPGxpKfj6UOkVjUdUPVzHrMbDvwPPCQu2+pMuwDZrbDzDaa2eyWVilNW79ld93bQ8bGkp+PpQ6RWNTV3N39iLu/DRgCzjGzeRVD7gHmuPsC4EfA7dX2Y2arzGzEzEbGxsaaqVsCHZni3kq17SFjY8nPx1KHSCyCopDuvg94FHhPxfa97v5K+eU64Owp/vu17j7s7sMDAwMNlCuN6jGre3vI2Fjy87HUIRKLms3dzAbMrL/8cx/wbuCJijGnTXi5FNjVyiKleSvOrX6lrNr2kLGx5OdjqUMkFvWkZU4DbjezHkp/GWxw93vNbA0w4u6bgSvNbClwGHgRWFlUwdKY8Ruh9SRgQsbGkp+PpQ6RWGg9dxGRhGg994IVlakOyZcXue+Q+aV4LJKzYwM8vAb2j8JJQ3D+dbDg0k5XJRFTc29AUZnqkHx5kfsOmV+KxyI5OzbAPVfCoXLyZ//u0mtQg5cpaeGwBhSVqQ7Jlxe575D5pXgskvPwmqONfdyhg6XtIlNQc29AUZnqkHx5kfsOmV+KxyI5+0fDtoug5t6QojLVIfnyIvcdMr8Uj0VyThoK2y6CmntDispUh+TLi9x3yPxSPBbJOf866K34y7K3r7RdZAq6odqAojLVIfnyIvcdMr8Uj0Vyxm+aKi0jAZRzFxFJiHLucowYsuuSOOXtk6Hm3iViyK5L4pS3T4puqHaJGLLrkjjl7ZOi5t4lYsiuS+KUt0+KmnuX
iCG7LolT3j4pau5dIobsuiROefuk6IZql4ghuy6JU94+Kcq5i4gkRDn3sqLy2iH7jWVdcmXXI5N7Zjz3+YXowLHIurkXldcO2W8s65Irux6Z3DPjuc8vRIeORdY3VIvKa4fsN5Z1yZVdj0zumfHc5xeiQ8ci6+ZeVF47ZL+xrEuu7Hpkcs+M5z6/EB06Flk396Ly2iH7jWVdcmXXI5N7Zjz3+YXo0LHIurkXldcO2W8s65Irux6Z3DPjuc8vRIeORdY3VIvKa4fsN5Z1yZVdj0zumfHc5xeiQ8dCOXcRkYQo514w5edFEnHv1bD1NvAjYD1w9kpYclPz+408x6/m3gDl50USce/VMHLr0dd+5OjrZhp8Ajn+rG+oFkX5eZFEbL0tbHu9Esjxq7k3QPl5kUT4kbDt9Uogx6/m3gDl50USYT1h2+uVQI5fzb0Bys+LJOLslWHb65VAjl83VBug/LxIIsZvmrY6LZNAjl85dxGRhLQs525mM4AfA8eXx2909y9UjDkeuAM4G9gLXObuTzdQd02h+fLU1jAPya7nfiwKzRGHZJ+LqqPI+UWewW5K6NxyPhbTqOeyzCvAu9z9JTPrBX5qZve7+2MTxlwO/N7d32xmy4GvAJe1utjQfHlqa5iHZNdzPxaF5ohDss9F1VHk/BLIYDcsdG45H4saat5Q9ZKXyi97y78qr+VcBNxe/nkjcL5Z62Mbofny1NYwD8mu534sCs0Rh2Sfi6qjyPklkMFuWOjccj4WNdSVljGzHjPbDjwPPOTuWyqGDAK7Adz9MLAfmFVlP6vMbMTMRsbGxoKLDc2Xp7aGeUh2PfdjUWiOOCT7XFQdRc4vgQx2w0LnlvOxqKGu5u7uR9z9bcAQcI6ZzasYUu1T+jEdyd3Xuvuwuw8PDAwEFxuaL09tDfOQ7Hrux6LQHHFI9rmoOoqcXwIZ7IaFzi3nY1FDUM7d3fcBjwLvqXhrFJgNYGbHAScBL7agvklC8+WprWEekl3P/VgUmiMOyT4XVUeR80sgg92w0LnlfCxqqCctMwAccvd9ZtYHvJvSDdOJNgMfBX4OXAI84gVkLEPz5amtYR6SXc/9WBSaIw7JPhdVR5HzSyCD3bDQueV8LGqomXM3swWUbpb2UPqkv8Hd15jZGmDE3TeX45LfBRZS+sS+3N1/M91+lXMXEQnXspy7u++g1LQrt1834eeXgb8JLVJERIqR/fIDyT24I+0R8mBLDA/BFPngTmoPacVwPhKQdXNP7sEdaY+QB1tieAimyAd3UntIK4bzkYisV4VM7sEdaY+QB1tieAimyAd3UntIK4bzkYism3tyD+5Ie4Q82BLDQzBFPriT2kNaMZyPRGTd3JN7cEfaI+TBlhgeginywZ3UHtKK4XwkIuvmntyDO9IeIQ+2xPAQTJEP7qT2kFYM5yMRWTf3ZQsHueHi+Qz292HAYH8fN1w8XzdTu92CS+H9N8NJswEr/f7+m6vfkAsZG0O9oeOLml9q+82QvqxDRCQhLXuISaTrhXyxRyxSqzmW7HosdbSAmrvIdEK+2CMWqdUcS3Y9ljpaJOtr7iJNC/lij1ikVnMs2fVY6mgRNXeR6YR8sUcsUqs5lux6LHW0iJq7yHRCvtgjFqnVHEt2PZY6WkTNXWQ6IV/sEYvUao4lux5LHS2i5i4ynSU3wfDlRz/1Wk/pdYw3JselVnMs2fVY6mgR5dxFRBKinLu0T4rZ4KJqLipfnuIxlo5Sc5fmpJgNLqrmovLlKR5j6Thdc5fmpJgNLqrmovLlKR5j6Tg1d2lOitngomouKl+e4jGWjlNzl+akmA0uquai8uUpHmPpODV3aU6K2eCiai4qX57iMZaOU3OX5qSYDS6q5qLy5SkeY+k45dxFRBJSb85dn9wlHzs2wNfmwRf7S7/v2ND+/RZVg0gg5dwlD0VlwUP2qzy6RESf3CUPRWXBQ/arPLpERM1d8lBUFjxkv8qjS0TU3CUPRWXBQ/arPLpERM1d
8lBUFjxkv8qjS0TU3CUPRWXBQ/arPLpEpGbO3cxmA3cApwKvAmvd/RsVY94J/Bvw3+VNd7v7tHeRlHMXEQnXyvXcDwOfcffHzewEYKuZPeTuv6oY9xN3X9JIsRKhFNcPD6k5xfnFQMctGTWbu7s/Czxb/vkPZrYLGAQqm7vkIsW8tvLoxdNxS0rQNXczmwMsBLZUeXuRmf3CzO43s7e2oDbplBTz2sqjF0/HLSl1P6FqZm8AfgB82t0PVLz9OPBGd3/JzC4ENgFvqbKPVcAqgDPOOKPhoqVgKea1lUcvno5bUur65G5mvZQa+13ufnfl++5+wN1fKv98H9BrZidXGbfW3YfdfXhgYKDJ0qUwKea1lUcvno5bUmo2dzMz4FZgl7tXXbvUzE4tj8PMzinvd28rC5U2SjGvrTx68XTcklLPZZnFwIeBnWa2vbzt88AZAO7+HeAS4JNmdhg4CCz3Tq0lLM0bvzmWUioipOYU5xcDHbekaD13EZGEtDLnLrFS5niye6+GrbeVvpDaekpfb9fstyCJJErNPVXKHE9279UwcuvR137k6Gs1eOlCWlsmVcocT7b1trDtIplTc0+VMseT+ZGw7SKZU3NPlTLHk1lP2HaRzKm5p0qZ48nOXhm2XSRzau6p0trhky25CYYvP/pJ3XpKr3UzVbqUcu4iIglRzr0Bm7bt4cYHn+SZfQc5vb+Pay6Yy7KFg50uq3Vyz8XnPr8Y6BgnQ829bNO2Pay+eycHD5XSFXv2HWT13TsB8mjwuefic59fDHSMk6Jr7mU3Pvjka4193MFDR7jxwSc7VFGL5Z6Lz31+MdAxToqae9kz+w4GbU9O7rn43OcXAx3jpKi5l53e3xe0PTm55+Jzn18MdIyTouZeds0Fc+nrnfzAS19vD9dcMLdDFbVY7rn43OcXAx3jpOiGatn4TdNs0zK5r8Wd+/xioGOcFOXcRUQSUm/OXZdlRFKwYwN8bR58sb/0+44NaexbOkaXZURiV2S+XNn1bOmTu0jsisyXK7ueLTV3kdgVmS9Xdj1bau4isSsyX67serbU3EViV2S+XNn1bKm5i8SuyLX79b0A2VLOXUQkIcq5i4h0MTV3EZEMqbmLiGRIzV1EJENq7iIiGVJzFxHJkJq7iEiG1NxFRDJUs7mb2Wwz+08z22VmvzSzq6qMMTO72cyeMrMdZnZWMeVKU7Rut0jXqGc998PAZ9z9cTM7AdhqZg+5+68mjHkv8Jbyr3OBb5d/l1ho3W6RrlLzk7u7P+vuj5d//gOwC6j8YtGLgDu85DGg38xOa3m10jit2y3SVYKuuZvZHGAhsKXirUFg94TXoxz7FwBmtsrMRsxsZGxsLKxSaY7W7RbpKnU3dzN7A/AD4NPufqDy7Sr/yTErkrn7WncfdvfhgYGBsEqlOVq3W6Sr1NXczayXUmO/y93vrjJkFJg94fUQ8Ezz5UnLaN1uka5ST1rGgFuBXe5+0xTDNgMfKadm3gHsd/dnW1inNEvrdot0lXrSMouBDwM7zWx7edvngTMA3P07wH3AhcBTwB+Bj7W+VGnagkvVzEW6RM3m7u4/pfo19YljHLiiVUWJiEhz9ISqiEiG1NxFRDKk5i4ikiE1dxGRDKm5i4hkSM1dRCRDau4iIhmyUkS9A3+w2Rjw24784bWdDLzQ6SIKpPmlK+e5geZXjze6e83FuTrW3GNmZiPuPtzpOoqi+aUr57mB5tdKuiwjIpIhNXcRkQypuVe3ttMFFEzzS1fOcwPNr2V0zV1EJEP65C4ikqGubu5m1mNm28zs3irvrTSzMTPbXv71t52osRlm9rSZ7SzXP1LlfTOzm83sKTPbYWZndaLORtQxt3ea2f4J5y+pr5wys34z22hmT5jZLjNbVPF+sucO6ppfsufPzOZOqHu7mR0ws09XjCn8/NXzZR05uwrYBZw4xfvfd/dPtbGeIvylu0+Vq30v8Jbyr3OBb5d/T8V0cwP4ibsvaVs1rfUN4AF3
v8TM/hR4XcX7qZ+7WvODRM+fuz8JvA1KHyCBPcAPK4YVfv669pO7mQ0B7wNu6XQtHXQRcIeXPAb0m9lpnS6q25nZicB5lL7eEnf/P3ffVzEs2XNX5/xycT7wa3evfGCz8PPXtc0d+DrwWeDVacZ8oPxPpo1mNnuacbFy4D/MbKuZrary/iCwe8Lr0fK2FNSaG8AiM/uFmd1vZm9tZ3FN+nNgDPjn8mXDW8zs9RVjUj539cwP0j1/Ey0H1lfZXvj568rmbmZLgOfdfes0w+4B5rj7AuBHwO1tKa61Frv7WZT+CXiFmZ1X8X61r09MJT5Va26PU3pM+y+AfwI2tbvAJhwHnAV8290XAv8LfK5iTMrnrp75pXz+AChfbloK/Gu1t6tsa+n568rmTulLv5ea2dPA94B3mdmdEwe4+153f6X8ch1wdntLbJ67P1P+/XlK1/zOqRgyCkz8F8kQ8Ex7qmtOrbm5+wF3f6n8831Ar5md3PZCGzMKjLr7lvLrjZSaYeWYJM8ddcwv8fM37r3A4+7+P1XeK/z8dWVzd/fV7j7k7nMo/bPpEXf/0MQxFde/llK68ZoMM3u9mZ0w/jPw18B/VQzbDHykfOf+HcB+d3+2zaUGq2duZnaqmVn553Mo/W99b7trbYS7PwfsNrO55U3nA7+qGJbkuYP65pfy+ZtgBdUvyUAbzl+3p2UmMbM1wIi7bwauNLOlwGHgRWBlJ2trwJ8BPyz//+M44F/c/QEz+3sAd/8OcB9wIfAU8EfgYx2qNVQ9c7sE+KSZHQYOAss9rSf2/gG4q/xP+98AH8vk3I2rNb+kz5+ZvQ74K+DvJmxr6/nTE6oiIhnqyssyIiK5U3MXEcmQmruISIbU3EVEMqTmLiKSITV3EZEMqbmLiGRIzV1EJEP/D+1KgcwTy4s9AAAAAElFTkSuQmCC\n",
199 | "text/plain": [
200 | ""
201 | ]
202 | },
203 | "metadata": {
204 | "needs_background": "light"
205 | },
206 | "output_type": "display_data"
207 | }
208 | ],
209 | "source": [
210 | "plt.scatter(X[:50,0],X[:50,1], label='-1')\n",
211 | "plt.scatter(X[50:,0],X[50:,1], label='1')\n",
212 | "plt.legend()"
213 | ]
214 | },
215 | {
216 | "cell_type": "markdown",
217 | "metadata": {},
218 | "source": [
219 | "----\n",
220 | "##### SMO Algorithm\n",
221 | "Algorithm 7.5, P130"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": 155,
227 | "metadata": {},
228 | "outputs": [],
229 | "source": [
230 | "class SVM:\n",
231 | " def __init__(self, max_iter=100, epsilon=0.001, C=1.0, kernel='linear'):\n",
232 | " self.max_iter = max_iter\n",
233 | " self.kernel = kernel\n",
234 | " self.epsilon = epsilon\n",
235 | " self.C = C\n",
236 | " \n",
237 | " def _init_parameters(self, X, y):\n",
238 | " '''\n",
239 | "        Initialize parameters\n",
240 | " '''\n",
241 | " self.X = X\n",
242 | " self.y = y\n",
243 | "\n",
244 | " self.b = 0.0\n",
245 | " self.M, self.N = X.shape\n",
246 | " self.alpha = np.ones(self.M)\n",
247 | " self.E = [self._E(i) for i in range(self.M)]\n",
248 | "\n",
249 | " def _kernel(self, x1, x2):\n",
250 | "        # kernel function\n",
251 | " if self.kernel == 'linear':\n",
252 | " return np.dot(x1, x2)\n",
253 | " \n",
254 | " def _gx(self, i):\n",
255 | "        # g(x_i), Eq. 7.104\n",
256 | " #return np.sum(self.alpha * self.y * self._kernel(self.X, self.X[i]) + self.b)\n",
257 | " \n",
258 | " r = self.b\n",
259 | " for j in range(self.M):\n",
260 | " r += self.alpha[j]*self.y[j]*self._kernel(self.X[i], self.X[j])\n",
261 | " return r\n",
262 | " \n",
263 | " def _E(self, i):\n",
264 | "        # Eq. 7.105\n",
265 | " return self._gx(i) - self.y[i]\n",
266 | " \n",
267 | "    def _KKT(self, i):\n",
268 | "        # KKT conditions, P130; the float equality case is checked with epsilon\n",
269 | "        ygx = self.y[i] * self._gx(i)\n",
270 | "        if self.alpha[i] == 0:\n",
271 | "            return ygx >= 1\n",
272 | "        elif 0 < self.alpha[i] < self.C:\n",
273 | "            return abs(ygx - 1) < self.epsilon\n",
274 | "        else:\n",
275 | "            return ygx <= 1\n",
276 | " \n",
277 | " def _init_alpha(self):\n",
278 | "        # Choose the two variables per Section 7.4.2 of the book.\n",
279 | "        # The outer loop first scans samples with 0 < alpha < C and checks KKT;\n",
280 | "        # if they all satisfy it, fall back to the rest of the training set.\n",
281 | "        index_list = [i for i in range(self.M) if 0 < self.alpha[i] < self.C]\n",
282 | "        non_satisfy_list = [i for i in range(self.M) if i not in index_list]\n",
283 | "        index_list.extend(non_satisfy_list)\n",
284 | "        \n",
285 | "        # take the first sample that violates KKT as the first variable i\n",
286 | "        for i in index_list:\n",
287 | "            if not self._KKT(i):\n",
288 | "                break\n",
289 | "        E1 = self.E[i]\n",
290 | "        if E1 >= 0:\n",
291 | " j = np.argmin(self.E)\n",
292 | " else:\n",
293 | " j = np.argmax(self.E)\n",
294 | " return i, j\n",
295 | " \n",
296 | " def _clip(self, alpha, L, H):\n",
297 | " if alpha > H:\n",
298 | " return H\n",
299 | " elif alpha < L:\n",
300 | " return L\n",
301 | " else:\n",
302 | " return alpha\n",
303 | " \n",
304 | " def fit(self, X, y):\n",
305 | " self._init_parameters(X, y)\n",
306 | " \n",
307 | " for _iter in range(self.max_iter):\n",
308 | " i1, i2 = self._init_alpha()\n",
309 | " \n",
310 | "            # clipping bounds L, H for alpha2, P126\n",
311 | " if self.y[i1] == self.y[i2]:\n",
312 | " L = np.max((0, self.alpha[i2] + self.alpha[i1] - self.C))\n",
313 | " H = np.min((self.C, self.alpha[i2] + self.alpha[i1]))\n",
314 | " else:\n",
315 | " L = np.max((0, self.alpha[i2] - self.alpha[i1]))\n",
316 | " H = np.min((self.C, self.C + self.alpha[i2] - self.alpha[i1]))\n",
317 | " \n",
318 | " E1 = self.E[i1]\n",
319 | " E2 = self.E[i2]\n",
320 | " \n",
321 | " #eta = K11 + K22 - 2K12, 7.107\n",
322 | " eta = self._kernel(self.X[i1], self.X[i1]) + self._kernel(self.X[i2], self.X[i2]) - \\\n",
323 | " 2 * self._kernel(self.X[i1], self.X[i2])\n",
324 | " \n",
325 | " alpha2_new_unc = self.alpha[i2] + self.y[i2] * (E1 - E2) / (eta + 1e-4) # 7.106\n",
326 | " \n",
327 | " alpha2_new = self._clip(alpha2_new_unc, L, H) # 7.108\n",
328 | " \n",
329 | " alpha1_new = self.alpha[i1] + self.y[i1] * self.y[i2] * (self.alpha[i2] - alpha2_new) # 7.109\n",
330 | " \n",
331 | " b1_new = -E1 - self.y[i1] * self._kernel(self.X[i1], self.X[i1]) * (alpha1_new - self.alpha[i1]) - \\\n",
332 | " self.y[i2] * self._kernel(self.X[i2], self.X[i1]) * (alpha2_new - self.alpha[i2]) + self.b # 7.115\n",
333 | " \n",
334 | " b2_new = -E2 - self.y[i1] * self._kernel(self.X[i1], self.X[i2]) * (alpha1_new - self.alpha[i1]) - \\\n",
335 | " self.y[i2] * self._kernel(self.X[i2], self.X[i2]) * (alpha2_new - self.alpha[i2]) + self.b # 7.116\n",
336 | " \n",
337 | " if 0 < alpha1_new < self.C and 0 < alpha2_new < self.C:\n",
338 | " b_new = b1_new\n",
339 | " else:\n",
340 | "                b_new = (b1_new + b2_new) / 2 # midpoint, P130\n",
341 | " \n",
342 | " # update parameters\n",
343 | " self.alpha[i1] = alpha1_new\n",
344 | " self.alpha[i2] = alpha2_new\n",
345 | " self.b = b_new\n",
346 | " \n",
347 | " self.E[i1] = self._E(i1)\n",
348 | " self.E[i2] = self._E(i2)\n",
349 | " \n",
350 | " return 'Done.'\n",
351 | " \n",
352 | " def predict(self, data):\n",
353 | " r = self.b\n",
354 | " for i in range(self.M):\n",
355 | " r += self.alpha[i] * self.y[i] * self._kernel(data, self.X[i])\n",
356 | " \n",
357 | " return 1 if r > 0 else -1\n",
358 | " \n",
359 | " def score(self, X_test, y_test):\n",
360 | " right_item = 0\n",
361 | " for i in range(len(X_test)):\n",
362 | " res = self.predict(X_test[i])\n",
363 | " if res == y_test[i]:\n",
364 | " right_item += 1\n",
365 | " return right_item / len(X_test)\n",
366 | " \n",
367 | " def _weight(self):\n",
368 | " yx = self.y.reshape(-1, 1) * self.X\n",
369 | " self.w = np.dot(yx.T, self.alpha)\n",
370 | " return self.w, self.b\n",
371 | " \n",
372 | "\n",
373 | "#https://blog.csdn.net/wds2006sdo/article/details/53156589\n",
374 | "#https://github.com/fengdu78/lihang-code/blob/master/code/%E7%AC%AC7%E7%AB% \\\n",
375 | "#A0%20%E6%94%AF%E6%8C%81%E5%90%91%E9%87%8F%E6%9C%BA(SVM)/support-vector-machine.ipynb"
376 | ]
377 | },
378 | {
379 | "cell_type": "code",
380 | "execution_count": 156,
381 | "metadata": {},
382 | "outputs": [],
383 | "source": [
384 | "svm = SVM(max_iter=1000)"
385 | ]
386 | },
387 | {
388 | "cell_type": "code",
389 | "execution_count": 157,
390 | "metadata": {},
391 | "outputs": [
392 | {
393 | "data": {
394 | "text/plain": [
395 | "'Done.'"
396 | ]
397 | },
398 | "execution_count": 157,
399 | "metadata": {},
400 | "output_type": "execute_result"
401 | }
402 | ],
403 | "source": [
404 | "svm.fit(X_train, y_train)"
405 | ]
406 | },
407 | {
408 | "cell_type": "code",
409 | "execution_count": 158,
410 | "metadata": {},
411 | "outputs": [
412 | {
413 | "data": {
414 | "text/plain": [
415 | "1.0"
416 | ]
417 | },
418 | "execution_count": 158,
419 | "metadata": {},
420 | "output_type": "execute_result"
421 | }
422 | ],
423 | "source": [
424 | "svm.score(X_test, y_test)"
425 | ]
426 | },
427 | {
428 | "cell_type": "code",
429 | "execution_count": 159,
430 | "metadata": {},
431 | "outputs": [
432 | {
433 | "data": {
434 | "text/plain": [
435 | "(array([ 3.6, -5.7]), -3.8699999999999815)"
436 | ]
437 | },
438 | "execution_count": 159,
439 | "metadata": {},
440 | "output_type": "execute_result"
441 | }
442 | ],
443 | "source": [
444 | "svm._weight() #array([ 3.6, -5.7])"
445 | ]
446 | },
447 | {
448 | "cell_type": "markdown",
449 | "metadata": {},
450 | "source": [
451 | "## sklearn.svm.SVC"
452 | ]
453 | },
454 | {
455 | "cell_type": "code",
456 | "execution_count": 169,
457 | "metadata": {},
458 | "outputs": [
459 | {
460 | "name": "stderr",
461 | "output_type": "stream",
462 | "text": [
463 | "/Users/max/anaconda2/envs/pytorch/lib/python3.6/site-packages/sklearn/svm/base.py:196: FutureWarning: The default value of gamma will change from 'auto' to 'scale' in version 0.22 to account better for unscaled features. Set gamma explicitly to 'auto' or 'scale' to avoid this warning.\n",
464 | " \"avoid this warning.\", FutureWarning)\n"
465 | ]
466 | },
467 | {
468 | "data": {
469 | "text/plain": [
470 | "SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n",
471 | " decision_function_shape='ovr', degree=3, gamma='auto_deprecated',\n",
472 | " kernel='rbf', max_iter=-1, probability=False, random_state=None,\n",
473 | " shrinking=True, tol=0.001, verbose=False)"
474 | ]
475 | },
476 | "execution_count": 169,
477 | "metadata": {},
478 | "output_type": "execute_result"
479 | }
480 | ],
481 | "source": [
482 | "from sklearn.svm import SVC\n",
483 | "clf = SVC()\n",
484 | "clf.fit(X_train, y_train)"
485 | ]
486 | },
487 | {
488 | "cell_type": "code",
489 | "execution_count": 170,
490 | "metadata": {},
491 | "outputs": [
492 | {
493 | "data": {
494 | "text/plain": [
495 | "1.0"
496 | ]
497 | },
498 | "execution_count": 170,
499 | "metadata": {},
500 | "output_type": "execute_result"
501 | }
502 | ],
503 | "source": [
504 | "clf.score(X_test, y_test)"
505 | ]
506 | },
507 | {
508 | "cell_type": "markdown",
509 | "metadata": {},
510 | "source": [
511 | "### sklearn.svm.SVC\n",
512 | "\n",
513 | "*(C=1.0, kernel='rbf', degree=3, gamma='auto', coef0=0.0, shrinking=True, probability=False, tol=0.001, cache_size=200, class_weight=None, verbose=False, max_iter=-1, decision_function_shape=None, random_state=None)*\n",
514 | "\n",
515 | "Parameters:\n",
516 | "\n",
517 | "- C: penalty parameter of C-SVC. Default 1.0.\n",
518 | "\n",
519 | "A larger C penalizes the slack variables more, pushing them toward 0: misclassification costs more, the model tends to classify the whole training set correctly, so training accuracy is high but generalization is weak. A smaller C penalizes misclassification less, tolerating some errors as noise, and generalizes better.\n",
520 | "\n",
521 | "- kernel: kernel function. Default 'rbf'; one of 'linear', 'poly', 'rbf', 'sigmoid', 'precomputed'.\n",
522 | "    \n",
523 | "    - linear: u'v\n",
524 | "    \n",
525 | "    - poly: (gamma*u'*v + coef0)^degree\n",
526 | "\n",
527 | "    - rbf: exp(-gamma*|u-v|^2)\n",
528 | "\n",
529 | "    - sigmoid: tanh(gamma*u'*v + coef0)\n",
530 | "\n",
531 | "\n",
532 | "- degree: degree of the polynomial ('poly') kernel. Default 3; ignored by all other kernels.\n",
533 | "\n",
534 | "\n",
535 | "- gamma: kernel coefficient for 'rbf', 'poly' and 'sigmoid'. Default 'auto', which uses 1/n_features.\n",
536 | "\n",
537 | "\n",
538 | "- coef0: independent term of the kernel function. Only significant for 'poly' and 'sigmoid'.\n",
539 | "\n",
540 | "\n",
541 | "- probability: whether to enable probability estimates. Default False.\n",
542 | "\n",
543 | "\n",
544 | "- shrinking: whether to use the shrinking heuristic. Default True.\n",
545 | "\n",
546 | "\n",
547 | "- tol: tolerance for the stopping criterion. Default 1e-3.\n",
548 | "\n",
549 | "\n",
550 | "- cache_size: size of the kernel cache, in MB. Default 200.\n",
551 | "\n",
552 | "\n",
553 | "- class_weight: class weights, passed as a dict; sets the parameter C of a class to weight*C (the C of C-SVC).\n",
554 | "\n",
555 | "\n",
556 | "- verbose: enable verbose output. Default False.\n",
557 | "\n",
558 | "\n",
559 | "- max_iter: hard limit on the number of iterations; -1 means no limit.\n",
560 | "\n",
561 | "\n",
562 | "- decision_function_shape: 'ovo', 'ovr' or None. Default None.\n",
563 | "\n",
564 | "\n",
565 | "- random_state: seed used when shuffling the data (int).\n",
566 | "\n",
567 | "\n",
568 | "The parameters most worth tuning are C, kernel, degree, gamma and coef0."
569 | ]
570 | }
571 | ],
572 | "metadata": {
573 | "kernelspec": {
574 | "display_name": "Python 3",
575 | "language": "python",
576 | "name": "python3"
577 | },
578 | "language_info": {
579 | "codemirror_mode": {
580 | "name": "ipython",
581 | "version": 3
582 | },
583 | "file_extension": ".py",
584 | "mimetype": "text/x-python",
585 | "name": "python",
586 | "nbconvert_exporter": "python",
587 | "pygments_lexer": "ipython3",
588 | "version": "3.6.7"
589 | }
590 | },
591 | "nbformat": 4,
592 | "nbformat_minor": 2
593 | }
594 |
--------------------------------------------------------------------------------
/第08章 提升方法/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hktxt/Learn-Statistical-Learning-Method/14a3b65db3ff35c42743a1062a2a95cecbfad094/第08章 提升方法/.DS_Store
--------------------------------------------------------------------------------
/第08章 提升方法/Adaboost.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Adaboost.ipynb",
7 | "version": "0.3.2",
8 | "provenance": [],
9 | "collapsed_sections": []
10 | },
11 | "language_info": {
12 | "codemirror_mode": {
13 | "name": "ipython",
14 | "version": 3
15 | },
16 | "file_extension": ".py",
17 | "mimetype": "text/x-python",
18 | "name": "python",
19 | "nbconvert_exporter": "python",
20 | "pygments_lexer": "ipython3",
21 | "version": "3.6.2"
22 | },
23 | "kernelspec": {
24 | "display_name": "Python 3",
25 | "language": "python",
26 | "name": "python3"
27 | }
28 | },
29 | "cells": [
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {
33 | "id": "CGJ1QiK3cnzN",
34 | "colab_type": "text"
35 | },
36 | "source": [
37 | "# Chapter 8: Boosting Methods"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {
43 | "collapsed": true,
44 | "id": "v_MmENfgcnzN",
45 | "colab_type": "text"
46 | },
47 | "source": [
48 | "# Boosting\n",
49 | "\n",
50 | "Bagging and boosting are the two main ways to build ensemble models. An ensemble model combines several base models, and its predictions are usually better than those of any single base model.\n",
51 | "\n",
52 | "- Bagging: each base model is trained on a different dataset drawn by random resampling from the full sample; this resampling process is called bagging.\n",
53 | "\n",
54 | "- Boosting: each base model is trained with different sample weights; samples misclassified by the previous base model get larger weights, so the new model focuses on them.\n",
55 | "\n",
56 | "### AdaBoost\n",
57 | "\n",
58 | "AdaBoost is short for Adaptive Boosting: a boosting algorithm that adapts to the errors of its base models.\n",
59 | "\n",
60 | "The algorithm proceeds as follows:\n",
61 | "\n",
62 | "1) Assign a weight to each training sample ($x_{1},x_{2},\\ldots,x_{N}$); the initial weights $w_{1i}$ are all 1/N.\n",
63 | "\n",
64 | "2) Train on the weighted samples to obtain the model $G_m$ (the initial model is $G_1$).\n",
65 | "\n",
66 | "3) Compute the error rate of $G_m$: $e_m=\\sum_{i=1}^N w_{mi}I(y_i\\not= G_m(x_i))$\n",
67 | "\n",
68 | "4) Compute the coefficient of $G_m$: $\\alpha_m=0.5\\log[(1-e_m)/e_m]$\n",
69 | "\n",
70 | "5) Update the weight vector $w_{m+1}$ from the error rate and the current weights $w_m$ (Eq. 8.4): $w_{m+1,i}=\\frac{w_{mi}}{Z_m}\\exp(-\\alpha_m y_i G_m(x_i))$, where $Z_m=\\sum_{i=1}^N w_{mi}\\exp(-\\alpha_m y_i G_m(x_i))$.\n",
71 | "\n",
72 | "6) Compute the error rate of the combined model $f(x)=\\sum_{m=1}^M\\alpha_m G_m(x)$.\n",
73 | "\n",
74 | "7) Stop when the combined model's error rate falls below a threshold or the iteration count reaches its limit; otherwise, go back to step 2)."
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "metadata": {
80 | "id": "WkmYWWexcnzO",
81 | "colab_type": "code",
82 | "colab": {}
83 | },
84 | "source": [
85 | "import numpy as np\n",
86 | "import pandas as pd\n",
87 | "from sklearn.datasets import load_iris\n",
88 | "from sklearn.tree import DecisionTreeClassifier\n",
89 | "from sklearn.model_selection import train_test_split\n",
90 | "import matplotlib.pyplot as plt\n",
91 | "%matplotlib inline"
92 | ],
93 | "execution_count": 0,
94 | "outputs": []
95 | },
96 | {
97 | "cell_type": "code",
98 | "metadata": {
99 | "id": "kWFOcuTKcnzR",
100 | "colab_type": "code",
101 | "colab": {}
102 | },
103 | "source": [
104 | "# data\n",
105 | "def create_data():\n",
106 | " iris = load_iris()\n",
107 | " df = pd.DataFrame(iris.data, columns=iris.feature_names)\n",
108 | " df['label'] = iris.target\n",
109 | " df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n",
110 | " data = np.array(df.iloc[:100, [0, 1, -1]])\n",
111 | " for i in range(len(data)):\n",
112 | " if data[i,-1] == 0:\n",
113 | " data[i,-1] = -1\n",
114 | " # print(data)\n",
115 | " return data[:,:2], data[:,-1]"
116 | ],
117 | "execution_count": 0,
118 | "outputs": []
119 | },
120 | {
121 | "cell_type": "code",
122 | "metadata": {
123 | "id": "uk2Mg38UcnzT",
124 | "colab_type": "code",
125 | "colab": {}
126 | },
127 | "source": [
128 | "X, y = create_data()\n",
129 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)"
130 | ],
131 | "execution_count": 0,
132 | "outputs": []
133 | },
134 | {
135 | "cell_type": "code",
136 | "metadata": {
137 | "id": "FNCiiDMycnzW",
138 | "colab_type": "code",
139 | "outputId": "abb8a27d-9db0-449e-e78e-b019f70c2586",
140 | "colab": {
141 | "base_uri": "https://localhost:8080/",
142 | "height": 287
143 | }
144 | },
145 | "source": [
146 | "plt.scatter(X[:50,0],X[:50,1], label='0')\n",
147 | "plt.scatter(X[50:,0],X[50:,1], label='1')\n",
148 | "plt.legend()"
149 | ],
150 | "execution_count": 8,
151 | "outputs": [
152 | {
153 | "output_type": "execute_result",
154 | "data": {
155 | "text/plain": [
156 | ""
157 | ]
158 | },
159 | "metadata": {
160 | "tags": []
161 | },
162 | "execution_count": 8
163 | },
164 | {
165 | "output_type": "display_data",
166 | "data": {
167 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGZhJREFUeJzt3X+MXWWdx/H3d4dZOqvQCWVUmCk7\naE2jQNfCCJJuiAtxq7WWBtlS4q8qa3cNLhhcjBiC2piAS4LKkmgqZAFhi92K5cdCWQISf0RqpoDt\n2kpEQTsDuwyDLbIWaMfv/nHvtDO3M3Pvc+89c5/nuZ9X0sycc0/PfJ9z4Ns753zOc83dERGRvPxZ\nqwsQEZHmU3MXEcmQmruISIbU3EVEMqTmLiKSITV3EZEMqbmLiGRIzV1EJENq7iIiGTqi1g3NrAMY\nBIbdfXnFa2uAa4Hh8qob3P3GmfZ37LHHen9/f1CxIiLtbtu2bS+4e0+17Wpu7sClwC7g6Gle/667\nf7rWnfX39zM4OBjw40VExMx+W8t2NV2WMbM+4P3AjO/GRUQkDrVec/868DngTzNs80Ez225mm8xs\n/lQbmNlaMxs0s8GRkZHQWkVEpEZVm7uZLQeed/dtM2x2D9Dv7ouAB4FbptrI3de7+4C7D/T0VL1k\nJCIidarlmvsSYIWZLQPmAEeb2W3u/uHxDdx9dML2NwL/0twyRUSaZ//+/QwNDfHKK6+0upRpzZkz\nh76+Pjo7O+v6+1Wbu7tfAVwBYGbvBv55YmMvrz/O3Z8rL66gdONVRCRKQ0NDHHXUUfT392NmrS7n\nMO7O6OgoQ0NDnHjiiXXto+6cu5mtM7MV5cVLzOwXZvZz4BJgTb37FREp2iuvvMK8efOibOwAZsa8\nefMa+s0iJAqJuz8CPFL+/qoJ6w++uxfJzebHh7n2gSd5ds8+ju/u4vKlC1m5uLfVZUmDYm3s4xqt\nL6i5i7SbzY8Pc8WdO9i3fwyA4T37uOLOHQBq8BI1TT8gMoNrH3jyYGMft2//GNc+8GSLKpJcbNmy\nhYULF7JgwQKuueaapu9fzV1kBs/u2Re0XqQWY2NjXHzxxdx///3s3LmTDRs2sHPnzqb+DF2WEZnB\n8d1dDE/RyI/v7mpBNdIqzb7v8rOf/YwFCxbw5je/GYDVq1dz11138fa3v71ZJeudu8hMLl+6kK7O\njknrujo7uHzpwhZVJLNt/L7L8J59OIfuu2x+fLjq353O8PAw8+cfepC/r6+P4eH69zcVNXeRGaxc\n3MvV551Cb3cXBvR2d3H1eafoZmobSfW+iy7LiFSxcnGvmnkbK+K+S29vL7t37z64PDQ0RG9vc/8b\n0zt3EZEZTHd/pZH7Lu985zv51a9+xdNPP81rr73GHXfcwYoVK6r/xQBq7iIiMyjivssRRxzBDTfc\nwNKlS3nb297GqlWrOOmkkxotdfLPaOreREQyM35JrtlPKS9btoxly5Y1o8QpqbmLiFSR4n0XXZYR\nEcmQmruISIbU3EVEMqTmLiKSITV3EZEMqblLNjY/PsySax7mxM//J0uuebihuT9EivaJT3yCN7zh\nDZx88smF7F/NXbJQxOROIkVas2YNW7ZsKWz/au6ShVQnd5JEbN8IXzsZvtRd+rp9Y8O7POusszjm\nmGOaUNzU9BCTZEEfqiGF2b4R7rkE9pf/W9q7u7QMsGhV6+qqQu/cJQtFTO4kAsBD6w419nH795XW\nR0zNXbKgD9WQwuwdClsfCV2WkSwUNbmTCHP7SpdiplofMTV3yUaKkztJAs65avI1d4DOrtL6Blx4\n4YU88sgjvPDCC/T19fHlL3+Ziy66qMFiD1Fzl4Y1+8ODRaIyftP0oXWlSzFz+0qNvcGbqRs2bGhC\ncdNTc5eGjOfLx2OI4/lyQA1e8rFoVd
TJmKnohqo0RPlykTipuUtDlC+XVLl7q0uYUaP1qblLQ5Qv\nlxTNmTOH0dHRaBu8uzM6OsqcOXPq3oeuuUtDLl+6cNI1d1C+XOLX19fH0NAQIyMjrS5lWnPmzKGv\nr/64pZq7NET5cklRZ2cnJ554YqvLKFTNzd3MOoBBYNjdl1e8diRwK3AaMApc4O7PNLFOiZjy5SLx\nCXnnfimwCzh6itcuAn7v7gvMbDXwVeCCJtQnkhRl/iUWNd1QNbM+4P3AjdNsci5wS/n7TcA5ZmaN\nlyeSDs0pLzGpNS3zdeBzwJ+meb0X2A3g7geAvcC8hqsTSYgy/xKTqs3dzJYDz7v7tkZ/mJmtNbNB\nMxuM+S61SD2U+ZeY1PLOfQmwwsyeAe4Azjaz2yq2GQbmA5jZEcBcSjdWJ3H39e4+4O4DPT09DRUu\nEhtl/iUmVZu7u1/h7n3u3g+sBh529w9XbHY38LHy9+eXt4nz6QCRgmhOeYlJ3Tl3M1sHDLr73cBN\nwHfM7CngRUr/CIi0FWX+JSbWqjfYAwMDPjg42JKfLSKSKjPb5u4D1bbTE6oSrSs372DD1t2MudNh\nxoVnzOcrK09pdVkiSVBzlyhduXkHtz36u4PLY+4Hl9XgRarTrJASpQ1bp/jMyhnWi8hkau4SpbFp\n7gVNt15EJlNzlyh1TDN7xXTrRWQyNXeJ0oVnzA9aLyKT6YaqRGn8pqnSMiL1Uc5dRCQhyrlLQz70\n7Z/yk1+/eHB5yVuO4fZPntnCilpHc7RLinTNXQ5T2dgBfvLrF/nQt3/aoopaR3O0S6rU3OUwlY29\n2vqcaY52SZWau8gMNEe7pErNXWQGmqNdUqXmLodZ8pZjgtbnTHO0S6rU3OUwt3/yzMMaebumZVYu\n7uXq806ht7sLA3q7u7j6vFOUlpHoKecuIpIQ5dylIUVlu0P2q3y5SP3U3OUw49nu8QjgeLYbaKi5\nhuy3qBpE2oWuucthisp2h+xX+XKRxqi5y2GKynaH7Ff5cpHGqLnLYYrKdofsV/lykcaoucthisp2\nh+xX+XKRxuiGqhxm/IZls5MqIfstqgaRdqGcu4hIQpRzL1gMGezQGmKoWURmh5p7HWLIYIfWEEPN\nIjJ7dEO1DjFksENriKFmEZk9au51iCGDHVpDDDWLyOxRc69DDBns0BpiqFlEZo+aex1iyGCH1hBD\nzSIye3RDtQ4xZLBDa4ihZhGZPVVz7mY2B/ghcCSlfww2ufsXK7ZZA1wLjH8k/A3ufuNM+1XOXUQk\nXDNz7q8CZ7v7y2bWCfzYzO5390crtvuuu3+6nmJldly5eQcbtu5mzJ0OMy48Yz5fWXlKw9vGkp+P\npQ6RGFRt7l56a/9yebGz/Kc1j7VK3a7cvIPbHv3dweUx94PLlU07ZNtY8vOx1CESi5puqJpZh5k9\nATwPPOjuW6fY7INmtt3MNpnZ/KZWKQ3bsHV3zetDto0lPx9LHSKxqKm5u/uYu78D6ANON7OTKza5\nB+h390XAg8AtU+3HzNaa2aCZDY6MjDRStwQam+beylTrQ7aNJT8fSx0isQiKQrr7HuAHwHsr1o+6\n+6vlxRuB06b5++vdfcDdB3p6euqpV+rUYVbz+pBtY8nPx1KHSCyqNncz6zGz7vL3XcB7gF9WbHPc\nhMUVwK5mFimNu/CMqa+UTbU+ZNtY8vOx1CESi1rSMscBt5hZB6V/DDa6+71mtg4YdPe7gUvMbAVw\nAHgRWFNUwVKf8RuhtSRgQraNJT8fSx0isdB87iIiCdF87gUrKlMdki8vct8h40vxWCRn+0Z4aB3s\nHYK5fXDOVbBoVaurkoipudehqEx1SL68yH2HjC/FY5Gc7Rvhnktgfzn5s3d3aRnU4GVamjisDkVl\nqkPy5UXuO2R8KR6L5Dy07lBjH7d/X2m9yDTU3OtQVKY6JF9e5L5DxpfisUjO3qGw9SKoudelqEx1\nSL
68yH2HjC/FY5GcuX1h60VQc69LUZnqkHx5kfsOGV+KxyI551wFnRX/WHZ2ldaLTEM3VOtQVKY6\nJF9e5L5DxpfisUjO+E1TpWUkgHLuIiIJUc5dDhNDdl0Sp7x9MtTc20QM2XVJnPL2SdEN1TYRQ3Zd\nEqe8fVLU3NtEDNl1SZzy9klRc28TMWTXJXHK2ydFzb1NxJBdl8Qpb58U3VBtEzFk1yVxytsnRTl3\nEZGEKOdeVlReO2S/scxLrux6ZHLPjOc+vhAtOBZZN/ei8toh+41lXnJl1yOTe2Y89/GFaNGxyPqG\nalF57ZD9xjIvubLrkck9M577+EK06Fhk3dyLymuH7DeWecmVXY9M7pnx3McXokXHIuvmXlReO2S/\nscxLrux6ZHLPjOc+vhAtOhZZN/ei8toh+41lXnJl1yOTe2Y89/GFaNGxyPqGalF57ZD9xjIvubLr\nkck9M577+EK06Fgo5y4ikhDl3Aum/LxIIu69DLbdDD4G1gGnrYHl1zW+38hz/GrudVB+XiQR914G\ngzcdWvaxQ8uNNPgEcvxZ31AtivLzIonYdnPY+lolkONXc6+D8vMiifCxsPW1SiDHr+ZeB+XnRRJh\nHWHra5VAjl/NvQ7Kz4sk4rQ1YetrlUCOXzdU66D8vEgixm+aNjstk0COXzl3EZGENC3nbmZzgB8C\nR5a33+TuX6zY5kjgVuA0YBS4wN2fqaPuqkLz5anNYR6SXc/9WBSaIw7JPhdVR5HjizyD3ZDQseV8\nLGZQy2WZV4Gz3f1lM+sEfmxm97v7oxO2uQj4vbsvMLPVwFeBC5pdbGi+PLU5zEOy67kfi0JzxCHZ\n56LqKHJ8CWSw6xY6tpyPRRVVb6h6ycvlxc7yn8prOecCt5S/3wScY9b82EZovjy1OcxDsuu5H4tC\nc8Qh2eei6ihyfAlksOsWOracj0UVNaVlzKzDzJ4AngcedPetFZv0ArsB3P0AsBeYN8V+1prZoJkN\njoyMBBcbmi9PbQ7zkOx67sei0BxxSPa5qDqKHF8CGey6hY4t52NRRU3N3d3H3P0dQB9wupmdXM8P\nc/f17j7g7gM9PT3Bfz80X57aHOYh2fXcj0WhOeKQ7HNRdRQ5vgQy2HULHVvOx6KKoJy7u+8BfgC8\nt+KlYWA+gJkdAcyldGO1qULz5anNYR6SXc/9WBSaIw7JPhdVR5HjSyCDXbfQseV8LKqoJS3TA+x3\n9z1m1gW8h9IN04nuBj4G/BQ4H3jYC8hYhubLU5vDPCS7nvuxKDRHHJJ9LqqOIseXQAa7bqFjy/lY\nVFE1525miyjdLO2g9E5/o7uvM7N1wKC7312OS34HWAy8CKx299/MtF/l3EVEwjUt5+7u2yk17cr1\nV034/hXg70KLFBGRYmQ//UByD+7I7Ah5sCWGh2CKfHAntYe0YjgfCci6uSf34I7MjpAHW2J4CKbI\nB3dSe0grhvORiKxnhUzuwR2ZHSEPtsTwEEyRD+6k9pBWDOcjEVk39+Qe3JHZEfJgSwwPwRT54E5q\nD2nFcD4SkXVzT+7BHZkdIQ+2xPAQTJEP7qT2kFYM5yMRWTf35B7ckdkR8mBLDA/BFPngTmoPacVw\nPhKRdXNfubiXq887hd7uLgzo7e7i6vNO0c3UdrdoFXzgepg7H7DS1w9cP/UNuZBtY6g3dPuixpfa\nfjOkD+sQEUlI0x5iEml7IR/sEYvUao4lux5LHU2g5i4yk5AP9ohFajXHkl2PpY4myfqau0jDQj7Y\nIxap1RxLdj2WOppEzV1kJiEf7BGL1GqOJbseSx1NouYuMpOQD/aIRWo1x5Jdj6WOJlFzF5lJyAd7\nxCK1mmPJrsdSR5OouYvMZPl1MHDRoXe91lFajvHG5LjUao4lux5LHU2inLuISEKUc5fZk2I2uKia\ni8qXp3iMpaXU3KUxKWaDi6q5qHx5isdYWk7X3KUxKWaDi6q5qHx5
isdYWk7NXRqTYja4qJqLypen\neIyl5dTcpTEpZoOLqrmofHmKx1haTs1dGpNiNriomovKl6d4jKXl1NylMSlmg4uquah8eYrHWFpO\nOXcRkYTUmnPXO3fJx/aN8LWT4Uvdpa/bN87+fouqQSSQcu6Sh6Ky4CH7VR5dIqJ37pKHorLgIftV\nHl0iouYueSgqCx6yX+XRJSJq7pKHorLgIftVHl0iouYueSgqCx6yX+XRJSJq7pKHorLgIftVHl0i\nUjXnbmbzgVuBNwIOrHf3b1Rs827gLuDp8qo73X3Gu0jKuYuIhGvmfO4HgM+6+2NmdhSwzcwedPed\nFdv9yN2X11OsRCjF+cNDak5xfDHQcUtG1ebu7s8Bz5W//4OZ7QJ6gcrmLrlIMa+tPHrxdNySEnTN\n3cz6gcXA1ilePtPMfm5m95vZSU2oTVolxby28ujF03FLSs1PqJrZ64HvAZ9x95cqXn4M+Et3f9nM\nlgGbgbdOsY+1wFqAE044oe6ipWAp5rWVRy+ejltSanrnbmadlBr77e5+Z+Xr7v6Su79c/v4+oNPM\njp1iu/XuPuDuAz09PQ2WLoVJMa+tPHrxdNySUrW5m5kBNwG73H3KuUvN7E3l7TCz08v7HW1moTKL\nUsxrK49ePB23pNRyWWYJ8BFgh5k9UV73BeAEAHf/FnA+8CkzOwDsA1Z7q+YSlsaN3xxLKRURUnOK\n44uBjltSNJ+7iEhCmplzl1gpczzZvZfBtptLH0htHaWPt2v0U5BEEqXmnipljie79zIYvOnQso8d\nWlaDlzakuWVSpczxZNtuDlsvkjk191QpczyZj4WtF8mcmnuqlDmezDrC1otkTs09VcocT3bamrD1\nIplTc0+V5g6fbPl1MHDRoXfq1lFa1s1UaVPKuYuIJEQ59zpsfnyYax94kmf37OP47i4uX7qQlYt7\nW11W8+Sei899fDHQMU6GmnvZ5seHueLOHezbX0pXDO/ZxxV37gDIo8HnnovPfXwx0DFOiq65l137\nwJMHG/u4ffvHuPaBJ1tUUZPlnovPfXwx0DFOipp72bN79gWtT07uufjcxxcDHeOkqLmXHd/dFbQ+\nObnn4nMfXwx0jJOi5l52+dKFdHVOfuClq7ODy5cubFFFTZZ7Lj738cVAxzgpuqFaNn7TNNu0TO5z\ncec+vhjoGCdFOXcRkYTUmnPXZRmRFGzfCF87Gb7UXfq6fWMa+5aW0WUZkdgVmS9Xdj1beucuErsi\n8+XKrmdLzV0kdkXmy5Vdz5aau0jsisyXK7ueLTV3kdgVmS9Xdj1bau4isSty7n59LkC2lHMXEUmI\ncu4iIm1MzV1EJENq7iIiGVJzFxHJkJq7iEiG1NxFRDKk5i4ikiE1dxGRDFVt7mY238x+YGY7zewX\nZnbpFNuYmV1vZk+Z2XYzO7WYcqUhmrdbpG3UMp/7AeCz7v6YmR0FbDOzB91954Rt3ge8tfznDOCb\n5a8SC83bLdJWqr5zd/fn3P2x8vd/AHYBlR8sei5wq5c8CnSb2XFNr1bqp3m7RdpK0DV3M+sHFgNb\nK17qBXZPWB7i8H8AMLO1ZjZoZoMjIyNhlUpjNG+3SFupubmb2euB7wGfcfeX6vlh7r7e3QfcfaCn\np6eeXUi9NG+3SFupqbmbWSelxn67u985xSbDwPwJy33ldRILzdst0lZqScsYcBOwy92vm2azu4GP\nllMz7wL2uvtzTaxTGqV5u0XaSi1pmSXAR4AdZvZEed0XgBMA3P1bwH3AMuAp4I/Ax5tfqjRs0So1\nc5E2UbW5u/uPAauyjQMXN6soERFpjJ5QFRHJkJq7iEiG1NxFRDKk5i4ikiE1dxGRDKm5i4hkSM1d\nRCRDVoqot+AHm40Av23JD6/uWOCFVhdRII0vXTmPDTS+Wvylu1ednKtlzT1mZjbo7gOtrqMoGl+6\nch4baHzNpMsyIiIZUnMXEcmQ
mvvU1re6gIJpfOnKeWyg8TWNrrmLiGRI79xFRDLU1s3dzDrM7HEz\nu3eK19aY2YiZPVH+8/etqLERZvaMme0o1z84xetmZteb2VNmtt3MTm1FnfWoYWzvNrO9E85fUh85\nZWbdZrbJzH5pZrvM7MyK15M9d1DT+JI9f2a2cELdT5jZS2b2mYptCj9/tXxYR84uBXYBR0/z+nfd\n/dOzWE8R/sbdp8vVvg94a/nPGcA3y19TMdPYAH7k7stnrZrm+gawxd3PN7M/B/6i4vXUz1218UGi\n58/dnwTeAaU3kJQ+cvT7FZsVfv7a9p27mfUB7wdubHUtLXQucKuXPAp0m9lxrS6q3ZnZXOAsSh9v\nibu/5u57KjZL9tzVOL5cnAP82t0rH9gs/Py1bXMHvg58DvjTDNt8sPwr0yYzmz/DdrFy4L/MbJuZ\nrZ3i9V5g94TlofK6FFQbG8CZZvZzM7vfzE6azeIadCIwAvxb+bLhjWb2uoptUj53tYwP0j1/E60G\nNkyxvvDz15bN3cyWA8+7+7YZNrsH6Hf3RcCDwC2zUlxz/bW7n0rpV8CLzeysVhfURNXG9hilx7T/\nCvhXYPNsF9iAI4BTgW+6+2Lg/4DPt7akpqplfCmfPwDKl5tWAP/Rip/fls2d0od+rzCzZ4A7gLPN\n7LaJG7j7qLu/Wl68EThtdktsnLsPl78+T+ma3+kVmwwDE38j6Suvi161sbn7S+7+cvn7+4BOMzt2\n1gutzxAw5O5by8ubKDXDiZI9d9QwvsTP37j3AY+5+/9O8Vrh568tm7u7X+Hufe7eT+nXpofd/cMT\nt6m4/rWC0o3XZJjZ68zsqPHvgb8F/rtis7uBj5bv3L8L2Ovuz81yqcFqGZuZvcnMrPz96ZT+Wx+d\n7Vrr4e7/A+w2s4XlVecAOys2S/LcQW3jS/n8TXAhU1+SgVk4f+2elpnEzNYBg+5+N3CJma0ADgAv\nAmtaWVsd3gh8v/z/xxHAv7v7FjP7RwB3/xZwH7AMeAr4I/DxFtUaqpaxnQ98yswOAPuA1Z7WE3v/\nBNxe/tX+N8DHMzl346qNL+nzV37T8R7gHyasm9XzpydURUQy1JaXZUREcqfmLiKSITV3EZEMqbmL\niGRIzV1EJENq7iIiGVJzFxHJkJq7iEiG/h86qpKOmdh1nwAAAABJRU5ErkJggg==\n",
168 | "text/plain": [
169 | ""
170 | ]
171 | },
172 | "metadata": {
173 | "tags": []
174 | }
175 | }
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "metadata": {
181 | "id": "9AlbKsaNhY-J",
182 | "colab_type": "code",
183 | "colab": {}
184 | },
185 | "source": [
186 | "# weak classifier\n",
187 | "weak_cla = DecisionTreeClassifier(max_depth=1)"
188 | ],
189 | "execution_count": 0,
190 | "outputs": []
191 | },
192 | {
193 | "cell_type": "code",
194 | "metadata": {
195 | "id": "AS7dimPwhp2r",
196 | "colab_type": "code",
197 | "colab": {
198 | "base_uri": "https://localhost:8080/",
199 | "height": 125
200 | },
201 | "outputId": "3f3c28e2-7fa9-437f-db5c-da74f6acb084"
202 | },
203 | "source": [
204 | "# fit\n",
205 | "weak_cla.fit(X_train, y_train)"
206 | ],
207 | "execution_count": 46,
208 | "outputs": [
209 | {
210 | "output_type": "execute_result",
211 | "data": {
212 | "text/plain": [
213 | "DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=1,\n",
214 | " max_features=None, max_leaf_nodes=None,\n",
215 | " min_impurity_decrease=0.0, min_impurity_split=None,\n",
216 | " min_samples_leaf=1, min_samples_split=2,\n",
217 | " min_weight_fraction_leaf=0.0, presort=False,\n",
218 | " random_state=None, splitter='best')"
219 | ]
220 | },
221 | "metadata": {
222 | "tags": []
223 | },
224 | "execution_count": 46
225 | }
226 | ]
227 | },
228 | {
229 | "cell_type": "code",
230 | "metadata": {
231 | "id": "xOWWIY5ph00J",
232 | "colab_type": "code",
233 | "colab": {
234 | "base_uri": "https://localhost:8080/",
235 | "height": 35
236 | },
237 | "outputId": "088e2e42-4db2-46e7-e76f-d7b2d048327c"
238 | },
239 | "source": [
240 | "weak_cla_accuracy = weak_cla.score(X_test, y_test);weak_cla_accuracy"
241 | ],
242 | "execution_count": 47,
243 | "outputs": [
244 | {
245 | "output_type": "execute_result",
246 | "data": {
247 | "text/plain": [
248 | "0.85"
249 | ]
250 | },
251 | "metadata": {
252 | "tags": []
253 | },
254 | "execution_count": 47
255 | }
256 | ]
257 | },
258 | {
259 | "cell_type": "markdown",
260 | "metadata": {
261 | "id": "9ezbeV0Tcnza",
262 | "colab_type": "text"
263 | },
264 | "source": [
265 | "----\n",
266 | "\n",
267 | "### AdaBoost\n",
268 | "算法 8.1"
269 | ]
270 | },
271 | {
272 | "cell_type": "code",
273 | "metadata": {
274 | "id": "5Ftb2TtQcnzb",
275 | "colab_type": "code",
276 | "colab": {}
277 | },
278 | "source": [
279 | "class AdaBoost:\n",
280 | " def __init__(self, n_estimators=100):\n",
281 | " self.clf_num = n_estimators\n",
282 | " \n",
283 | " def init_args(self, X, y):\n",
284 | " \n",
285 | " self.X = X\n",
286 | " self.y = y\n",
287 | " M, _ = X.shape\n",
288 | " \n",
289 | " self.models = []\n",
290 | " self.alphas = []\n",
291 | " self.weights = np.ones(M) / M # 1\n",
292 | " \n",
293 | " def fit(self, X, y):\n",
294 | " self.init_args(X, y)\n",
295 | " \n",
296 | " for n in range(self.clf_num):\n",
297 | " cla = DecisionTreeClassifier(max_depth=1) # weak cla\n",
298 | " cla.fit(X, y, sample_weight=self.weights) # 2(a)\n",
299 | " P = cla.predict(X) \n",
300 | " \n",
301 | " err = self.weights.dot(P != y) # 2(b) 8.1\n",
302 | " alpha = 0.5*(np.log(1 - err) - np.log(err)) # 2(c) 8.2\n",
303 | " \n",
304 | " self.weights = self.weights * np.exp(-alpha * y * P)\n",
305 | " self.weights = self.weights / self.weights.sum() # 2(d) 8.3, 8.4, 8.5\n",
306 | " \n",
307 | " self.models.append(cla)\n",
308 | " self.alphas.append(alpha)\n",
309 | " \n",
310 | " return 'Done!'\n",
311 | " \n",
312 | " def predict(self, x):\n",
313 | " N, _ = x.shape\n",
314 | " FX = np.zeros(N)\n",
315 | " \n",
316 | " for alpha, cla in zip(self.alphas, self.models):\n",
317 | " FX += alpha * cla.predict(x)\n",
318 | "\n",
319 | " return np.sign(FX)\n",
320 | " \n",
321 | " def score(self, X_test, y_test):\n",
322 | " p = self.predict(X_test)\n",
323 | " r = np.sum(p == y_test)\n",
324 | " \n",
325 | " return r/len(X_test)\n",
326 | " \n",
327 | " def _weights(self):\n",
328 | " return self.alphas, self.weights, self.models"
329 | ],
330 | "execution_count": 0,
331 | "outputs": []
332 | },
333 | {
334 | "cell_type": "code",
335 | "metadata": {
336 | "id": "onbrb8Xsr5DE",
337 | "colab_type": "code",
338 | "colab": {}
339 | },
340 | "source": [
341 | "adaboost = AdaBoost()"
342 | ],
343 | "execution_count": 0,
344 | "outputs": []
345 | },
346 | {
347 | "cell_type": "code",
348 | "metadata": {
349 | "id": "kae_0imhr-Ld",
350 | "colab_type": "code",
351 | "colab": {
352 | "base_uri": "https://localhost:8080/",
353 | "height": 35
354 | },
355 | "outputId": "e6dc8e11-3ede-4ead-e35a-1ba25b26e056"
356 | },
357 | "source": [
358 | "adaboost.fit(X_train, y_train)"
359 | ],
360 | "execution_count": 112,
361 | "outputs": [
362 | {
363 | "output_type": "execute_result",
364 | "data": {
365 | "text/plain": [
366 | "'Done!'"
367 | ]
368 | },
369 | "metadata": {
370 | "tags": []
371 | },
372 | "execution_count": 112
373 | }
374 | ]
375 | },
376 | {
377 | "cell_type": "code",
378 | "metadata": {
379 | "id": "bhWzbeAEsl64",
380 | "colab_type": "code",
381 | "colab": {
382 | "base_uri": "https://localhost:8080/",
383 | "height": 35
384 | },
385 | "outputId": "0944a671-36b0-4347-9da4-271ce20315be"
386 | },
387 | "source": [
388 | "adaboost.score(X_test, y_test)"
389 | ],
390 | "execution_count": 113,
391 | "outputs": [
392 | {
393 | "output_type": "execute_result",
394 | "data": {
395 | "text/plain": [
396 | "1.0"
397 | ]
398 | },
399 | "metadata": {
400 | "tags": []
401 | },
402 | "execution_count": 113
403 | }
404 | ]
405 | },
406 | {
407 | "cell_type": "markdown",
408 | "metadata": {
409 | "id": "VQhHJ_rDcnzd",
410 | "colab_type": "text"
411 | },
412 | "source": [
413 | "### 例8.1"
414 | ]
415 | },
416 | {
417 | "cell_type": "code",
418 | "metadata": {
419 | "id": "qEKxLH3Hcnzd",
420 | "colab_type": "code",
421 | "colab": {}
422 | },
423 | "source": [
424 | "X_ = np.arange(10).reshape(10, 1)\n",
425 | "y_ = np.array([1, 1, 1, -1, -1, -1, 1, 1, 1, -1])"
426 | ],
427 | "execution_count": 0,
428 | "outputs": []
429 | },
430 | {
431 | "cell_type": "code",
432 | "metadata": {
433 | "id": "iGr8lCCicnzg",
434 | "colab_type": "code",
435 | "colab": {
436 | "base_uri": "https://localhost:8080/",
437 | "height": 35
438 | },
439 | "outputId": "c1bffec1-902d-4076-e4f4-0318b9acfddd"
440 | },
441 | "source": [
442 | "clf = AdaBoost()\n",
443 | "clf.fit(X_, y_)"
444 | ],
445 | "execution_count": 115,
446 | "outputs": [
447 | {
448 | "output_type": "execute_result",
449 | "data": {
450 | "text/plain": [
451 | "'Done!'"
452 | ]
453 | },
454 | "metadata": {
455 | "tags": []
456 | },
457 | "execution_count": 115
458 | }
459 | ]
460 | },
461 | {
462 | "cell_type": "markdown",
463 | "metadata": {
464 | "id": "YTbvHicmcnzq",
465 | "colab_type": "text"
466 | },
467 | "source": [
468 | "-----\n",
469 | "# sklearn.ensemble.AdaBoostClassifier\n",
470 | "\n",
471 | "- algorithm:这个参数只有AdaBoostClassifier有。主要原因是scikit-learn实现了两种Adaboost分类算法,SAMME和SAMME.R。两者的主要区别是弱学习器权重的度量,SAMME使用了和我们的原理篇里二元分类Adaboost算法的扩展,即用对样本集分类效果作为弱学习器权重,而SAMME.R使用了对样本集分类的预测概率大小来作为弱学习器权重。由于SAMME.R使用了概率度量的连续值,迭代一般比SAMME快,因此AdaBoostClassifier的默认算法algorithm的值也是SAMME.R。我们一般使用默认的SAMME.R就够了,但是要注意的是使用了SAMME.R, 则弱分类学习器参数base_estimator必须限制使用支持概率预测的分类器。SAMME算法则没有这个限制。\n",
472 | "\n",
473 | "- n_estimators: AdaBoostClassifier和AdaBoostRegressor都有,就是我们的弱学习器的最大迭代次数,或者说最大的弱学习器的个数。一般来说n_estimators太小,容易欠拟合,n_estimators太大,又容易过拟合,一般选择一个适中的数值。默认是50。在实际调参的过程中,我们常常将n_estimators和下面介绍的参数learning_rate一起考虑。\n",
474 | "\n",
475 | "- learning_rate: AdaBoostClassifier和AdaBoostRegressor都有,即每个弱学习器的权重缩减系数$\\nu$。较小的$\\nu$意味着达到同样的拟合效果需要更多的弱学习器,通常和n_estimators一起调参。默认是1。\n",
476 | "\n",
477 | "- base_estimator:AdaBoostClassifier和AdaBoostRegressor都有,即我们的弱分类学习器或者弱回归学习器。理论上可以选择任何一个分类或者回归学习器,不过需要支持样本权重。我们常用的一般是CART决策树或者神经网络MLP。"
478 | ]
479 | },
480 | {
481 | "cell_type": "code",
482 | "metadata": {
483 | "id": "CGLto18Ycnzr",
484 | "colab_type": "code",
485 | "outputId": "4c41d66b-b820-4dda-a212-9d18db023e1b",
486 | "colab": {
487 | "base_uri": "https://localhost:8080/",
488 | "height": 53
489 | }
490 | },
491 | "source": [
492 | "from sklearn.ensemble import AdaBoostClassifier\n",
493 | "clf = AdaBoostClassifier(n_estimators=100, learning_rate=0.5)\n",
494 | "clf.fit(X_train, y_train)"
495 | ],
496 | "execution_count": 86,
497 | "outputs": [
498 | {
499 | "output_type": "execute_result",
500 | "data": {
501 | "text/plain": [
502 | "AdaBoostClassifier(algorithm='SAMME.R', base_estimator=None, learning_rate=0.5,\n",
503 | " n_estimators=100, random_state=None)"
504 | ]
505 | },
506 | "metadata": {
507 | "tags": []
508 | },
509 | "execution_count": 86
510 | }
511 | ]
512 | },
513 | {
514 | "cell_type": "code",
515 | "metadata": {
516 | "id": "XHy-rTRwcnzu",
517 | "colab_type": "code",
518 | "outputId": "41b0c4e7-100f-4b18-f25d-a15e33d6060e",
519 | "colab": {
520 | "base_uri": "https://localhost:8080/",
521 | "height": 35
522 | }
523 | },
524 | "source": [
525 | "clf.score(X_test, y_test)"
526 | ],
527 | "execution_count": 87,
528 | "outputs": [
529 | {
530 | "output_type": "execute_result",
531 | "data": {
532 | "text/plain": [
533 | "1.0"
534 | ]
535 | },
536 | "metadata": {
537 | "tags": []
538 | },
539 | "execution_count": 87
540 | }
541 | ]
542 | }
543 | ]
544 | }
--------------------------------------------------------------------------------
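The AdaBoost cells above delegate the weak learner to `DecisionTreeClassifier(max_depth=1)`. As a cross-check of 例8.1 without sklearn, here is a minimal pure-NumPy sketch; the helper names (`stump_fit`, `adaboost_train`, `adaboost_predict`) and the exhaustive threshold/polarity search are assumptions of this sketch, not code from the notebook:

```python
import numpy as np

def stump_fit(X, y, w):
    # search every threshold/polarity stump; return (weighted error, thr, pol) with minimal error
    best = None
    for thr in np.arange(X.min() - 0.5, X.max() + 1.0, 1.0):
        for pol in (1, -1):
            pred = np.where(X < thr, pol, -pol)
            err = w.dot(pred != y)                      # weighted error, 式8.1
            if best is None or err < best[0]:
                best = (err, thr, pol)
    return best

def adaboost_train(X, y, rounds=3):
    w = np.ones(len(X)) / len(X)                        # uniform initial weights
    clfs = []
    for _ in range(rounds):
        err, thr, pol = stump_fit(X, y, w)
        alpha = 0.5 * np.log((1 - err) / err)           # 式8.2
        pred = np.where(X < thr, pol, -pol)
        w = w * np.exp(-alpha * y * pred)               # 式8.3-8.5
        w /= w.sum()
        clfs.append((alpha, thr, pol))
    return clfs

def adaboost_predict(clfs, X):
    F = sum(a * np.where(X < thr, pol, -pol) for a, thr, pol in clfs)
    return np.sign(F)

# data of 例8.1
X_ = np.arange(10)
y_ = np.array([1, 1, 1, -1, -1, -1, 1, 1, 1, -1])
clfs = adaboost_train(X_, y_)
print(round(clfs[0][0], 4))                       # 0.4236, the book's alpha_1
print((adaboost_predict(clfs, X_) == y_).mean())  # 1.0: zero training error after 3 rounds
```

The three rounds recover the book's stumps (thresholds 2.5, 8.5, 5.5), which is why the combined classifier reaches zero training error exactly as in 例8.1.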
/第10章 隐马尔可夫模型/HMM.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "HMM.ipynb",
7 | "version": "0.3.2",
8 | "provenance": [],
9 | "collapsed_sections": []
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | }
15 | },
16 | "cells": [
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {
20 | "id": "_Esy4bIi3E4L",
21 | "colab_type": "text"
22 | },
23 | "source": [
24 | "# HMM"
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "metadata": {
30 | "id": "jN7qEs6b3UAf",
31 | "colab_type": "text"
32 | },
33 | "source": [
34 | "Hidden Markov Model, HMM 是可用于标注问题的统计学习模型,属于**生成模型**。由**初始概率分布**,**状态转移概率分布**和**观测概率分布**确定。\n",
35 | "$\\lambda = (A, B, \\pi)$. \n",
36 | "\n",
37 | "\n",
38 | "### **两个基本假设:** \n",
39 | "1). 齐次马尔可夫性假设,即假设隐藏的马尔可夫链在任意时刻$t$的状态只依赖于其前一时刻的状态,与其他时刻状态及观测无关,也与时刻$t$无关: \n",
40 | "$P(i_{t}|i_{t-1},o_{t-1},...,i_{1},o_{1}) = P(i_{t}|i_{t-1}), t = 1,2,...,T$ \n",
41 | "\n",
42 | "2). 观察独立性假设,即假设任意时刻的观测只依赖于该时刻的马尔可夫链的状态,与其他观测及状态无关: \n",
43 | "$P(o_{t}|i_{T},o_{T},i_{T-1},o_{T-1},...,i_{t+1},o_{t+1},i_{t},i_{t-1},o_{t-1},...,i_{1},o_{1}) = P(o_{t}|i_{t})$\n"
44 | ]
45 | },
46 | {
47 | "cell_type": "markdown",
48 | "metadata": {
49 | "id": "W1ADpieg1V3d",
50 | "colab_type": "text"
51 | },
52 | "source": [
53 | "### **三个基本问题**: \n",
54 | "\n",
55 | "1. **概率计算问题**。给定模型$\\lambda = (A, B, \\pi)$ 和观测序列 $O=(o_{1},o_{2},...,o_{T})$, 计算在模型 $\\lambda$ 下观测序列 $O$ 出现的概率 $P(O|\\lambda)$. \n",
56 | "Evaluate $P(O|\\lambda)$.\n",
57 | "\n",
58 | "2. **学习问题**。已知观测序列 $O=(o_{1},o_{2},...,o_{T})$, 估计模型 $\\lambda = (A, B, \\pi)$ 的参数,使得在该模型下观测序列概率 $P(O|\\lambda)$ 最大。即用极大似然估计的方法估计参数。 \n",
59 | "$\\lambda_{MLE} = argmax_{\\lambda}P(O|\\lambda)$.\n",
60 | "\n",
61 | "3. **预测问题**。也称为解码(decoding) 问题。已知模型 $\\lambda = (A, B, \\pi)$ 和观测序列 $O=(o_{1},o_{2},...,o_{T})$,求给定观测序列条件概率 $P(I|O)$ 最大的状态序列 $I = (i_{1}, i_{2}, i_{3},...,i_{T})$. 即给定观测序列,求最有可能的对应的状态序列。 \n",
62 | "$argmax_{I}P(I|O,\\lambda)$\n"
63 | ]
64 | },
65 | {
66 | "cell_type": "markdown",
67 | "metadata": {
68 | "id": "oguS4aSW4Y_Q",
69 | "colab_type": "text"
70 | },
71 | "source": [
72 | "问题1. 前向(forward)和后向(backward)算法。 \n",
73 | "问题2. Baum-Welch 算法。 \n",
74 | "问题3. 近似算法,维特比算法。 "
75 | ]
76 | },
77 | {
78 | "cell_type": "markdown",
79 | "metadata": {
80 | "id": "L6QU6kS17PgN",
81 | "colab_type": "text"
82 | },
83 | "source": [
84 | "---------------------------------------------------------------------------------------------------------------------------------"
85 | ]
86 | },
87 | {
88 | "cell_type": "markdown",
89 | "metadata": {
90 | "id": "6aNZwSt17DFG",
91 | "colab_type": "text"
92 | },
93 | "source": [
94 | "以下来自徐亦达的课程 \n",
95 | "视频地址:https://www.youtube.com/watch?v=Ji6KbkyNmk8 \n",
96 | "lecture: https://github.com/roboticcam/machine-learning-notes/blob/master/dynamic_model.pdf"
97 | ]
98 | },
99 | {
100 | "cell_type": "markdown",
101 | "metadata": {
102 | "id": "lyPGx2RqGWrp",
103 | "colab_type": "text"
104 | },
105 | "source": [
106 | "# 概率计算问题"
107 | ]
108 | },
109 | {
110 | "cell_type": "markdown",
111 | "metadata": {
112 | "id": "F_XYKDZa9wdQ",
113 | "colab_type": "text"
114 | },
115 | "source": [
116 | "---------------------------------------------------------------------------------------------------------------------------------"
117 | ]
118 | },
119 | {
120 | "cell_type": "markdown",
121 | "metadata": {
122 | "id": "596GcIzAqwRs",
123 | "colab_type": "text"
124 | },
125 | "source": [
126 | "### 例 10.2 \n"
127 | ]
128 | },
129 | {
130 | "cell_type": "markdown",
131 | "metadata": {
132 | "id": "MuSecU91q2ey",
133 | "colab_type": "text"
134 | },
135 | "source": [
136 | "考虑盒子和球模型$\\lambda = (A, B, \\pi)$, 状态集合 $Q = \\{1,2,3\\}$, 观测集合 $V = \\{红,白\\}$, \n",
137 | "\n",
138 | "\n",
139 | "$A=\\begin{bmatrix}\n",
140 | " 0.5& 0.2& 0.3\\\\ \n",
141 | " 0.3& 0.5& 0.2\\\\ \n",
142 | " 0.2& 0.3& 0.5\n",
143 | "\\end{bmatrix}, \n",
144 | "B=\\begin{bmatrix}\n",
145 | " 0.5& 0.5\\\\ \n",
146 | " 0.4& 0.6\\\\ \n",
147 | " 0.7& 0.3\n",
148 | "\\end{bmatrix},\n",
149 | "\\pi=\\begin{bmatrix}\n",
150 | " 0.2\\\\ \n",
151 | " 0.4\\\\ \n",
152 | " 0.4\n",
153 | "\\end{bmatrix}$ \n",
154 | "\n",
155 | "设$T=3, O=(红,白,红)$,试用前向算法计算$P(O|\\lambda)$. \n",
156 | "\n",
157 | "\n",
158 | "\n"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "metadata": {
164 | "id": "gDkjQ0uJy_-M",
165 | "colab_type": "code",
166 | "colab": {}
167 | },
168 | "source": [
169 | "A = [[0.5, 0.2, 0.3], [0.3, 0.5, 0.2],[0.2, 0.3, 0.5]]\n",
170 | "B = [[0.5, 0.5], [0.4, 0.6], [0.7, 0.3]]\n",
171 | "pi = [0.2, 0.4, 0.4]\n",
172 | "Q = [0,1,0]"
173 | ],
174 | "execution_count": 0,
175 | "outputs": []
176 | },
177 | {
178 | "cell_type": "code",
179 | "metadata": {
180 | "id": "zE5Yt4jdu_VP",
181 | "colab_type": "code",
182 | "colab": {}
183 | },
184 | "source": [
185 | "import numpy as np"
186 | ],
187 | "execution_count": 0,
188 | "outputs": []
189 | },
190 | {
191 | "cell_type": "code",
192 | "metadata": {
193 | "id": "ZwRR7eFrHCnJ",
194 | "colab_type": "code",
195 | "colab": {}
196 | },
197 | "source": [
198 | "class HMM_fw:\n",
199 | " def __init__(self, A, B, pi):\n",
200 | " self.A = A # 状态转移概率\n",
201 | " self.B = B # 观测概率\n",
202 | " self.pi = pi # 初始状态\n",
203 | " \n",
204 | " def forward(self, Q):\n",
205 | " T = len(Q) # 观测序列长度,时刻T\n",
206 | " M = len(self.A) # 状态数\n",
207 | " alpha = np.zeros((T, M))\n",
208 | " \n",
209 | " for t in range(T):\n",
210 | " for m in range(M):\n",
211 | " if t == 0:\n",
212 | " alpha[t][m] = self.pi[m] * self.B[m][Q[t]]\n",
213 | " print(\"alpha[{}][{}] = pi[{}] * B[{}](Q{}) = {:.2f}\".format(t+1, m+1, m+1, m+1, Q[t]+1, alpha[t][m]))\n",
214 | " else:\n",
215 | " alpha[t][m] = sum([alpha[t-1][i] * self.A[i][m] for i in range(len(alpha[t-1]))]) * self.B[m][Q[t]]\n",
216 | " print(\"alpha[{}][{}] = {:.5f}\".format(t+1, m+1, alpha[t][m]))\n",
217 | " \n",
218 | " p = sum(alpha[T-1])\n",
219 | " #print(p)\n",
220 | " return p"
221 | ],
222 | "execution_count": 0,
223 | "outputs": []
224 | },
225 | {
226 | "cell_type": "code",
227 | "metadata": {
228 | "id": "HfqXwAO1y9G2",
229 | "colab_type": "code",
230 | "colab": {
231 | "base_uri": "https://localhost:8080/",
232 | "height": 197
233 | },
234 | "outputId": "a6dcb119-4da5-4dd6-a18c-593bfd70eaa4"
235 | },
236 | "source": [
237 | "m = HMM_fw(A, B, pi)\n",
238 | "m.forward(Q)"
239 | ],
240 | "execution_count": 79,
241 | "outputs": [
242 | {
243 | "output_type": "stream",
244 | "text": [
245 | "alpha[1][1] = pi[1] * B[1](Q1) = 0.10\n",
246 | "alpha[1][2] = pi[2] * B[2](Q1) = 0.16\n",
247 | "alpha[1][3] = pi[3] * B[3](Q1) = 0.28\n",
248 | "alpha[2][1] = 0.07700\n",
249 | "alpha[2][2] = 0.11040\n",
250 | "alpha[2][3] = 0.06060\n",
251 | "alpha[3][1] = 0.04187\n",
252 | "alpha[3][2] = 0.03551\n",
253 | "alpha[3][3] = 0.05284\n"
254 | ],
255 | "name": "stdout"
256 | },
257 | {
258 | "output_type": "execute_result",
259 | "data": {
260 | "text/plain": [
261 | "0.130218"
262 | ]
263 | },
264 | "metadata": {
265 | "tags": []
266 | },
267 | "execution_count": 79
268 | }
269 | ]
270 | },
271 | {
272 | "cell_type": "code",
273 | "metadata": {
274 | "id": "9LcfAGUVzaU3",
275 | "colab_type": "code",
276 | "colab": {}
277 | },
278 | "source": [
279 | "class HMM_bw:\n",
280 | " def __init__(self, A, B, pi):\n",
281 | " self.A = A\n",
282 | " self.B = B\n",
283 | " self.pi = pi\n",
284 | " \n",
285 | " def backward(self, Q):\n",
286 | " T = len(Q) # 观测序列长度,时刻T\n",
287 | " N = len(self.A) # 状态数\n",
288 | " beta = np.zeros((T, N))\n",
289 | " \n",
290 | " for t in range(T-1, -1, -1):\n",
291 | " for n in range(N):\n",
292 | " if t == T - 1:\n",
293 | " beta[t][n] = 1\n",
294 | " print(\"beta[{}][{}] = {:.2f}\".format(t+1, n, beta[t][n]))\n",
295 | " else:\n",
296 | " beta[t][n] = sum(self.A[n][j] * self.B[j][Q[t+1]] * beta[t+1][j] for j in range(N))\n",
297 | " print(\"beta[{}][{}] = {:.5f}\".format(t+1, n, beta[t][n]))\n",
298 | " \n",
299 | " p = sum(self.pi[i] * self.B[i][Q[0]] * beta[0][i] for i in range(N))\n",
300 | " #print(p)\n",
301 | " return p"
302 | ],
303 | "execution_count": 0,
304 | "outputs": []
305 | },
306 | {
307 | "cell_type": "code",
308 | "metadata": {
309 | "id": "rIyCzvnbBD8K",
310 | "colab_type": "code",
311 | "colab": {
312 | "base_uri": "https://localhost:8080/",
313 | "height": 197
314 | },
315 | "outputId": "48682810-a2c9-47ca-d46c-2cb3eeb545b5"
316 | },
317 | "source": [
318 | "m = HMM_bw(A, B, pi)\n",
319 | "m.backward(Q)"
320 | ],
321 | "execution_count": 77,
322 | "outputs": [
323 | {
324 | "output_type": "stream",
325 | "text": [
326 | "beta[3][0] = 1.00\n",
327 | "beta[3][1] = 1.00\n",
328 | "beta[3][2] = 1.00\n",
329 | "beta[2][0] = 0.54000\n",
330 | "beta[2][1] = 0.49000\n",
331 | "beta[2][2] = 0.57000\n",
332 | "beta[1][0] = 0.24510\n",
333 | "beta[1][1] = 0.26220\n",
334 | "beta[1][2] = 0.22770\n"
335 | ],
336 | "name": "stdout"
337 | },
338 | {
339 | "output_type": "execute_result",
340 | "data": {
341 | "text/plain": [
342 | "0.130218"
343 | ]
344 | },
345 | "metadata": {
346 | "tags": []
347 | },
348 | "execution_count": 77
349 | }
350 | ]
351 | },
352 | {
353 | "cell_type": "markdown",
354 | "metadata": {
355 | "id": "uc5CfmwRGdb6",
356 | "colab_type": "text"
357 | },
358 | "source": [
359 | "# 学习问题"
360 | ]
361 | },
362 | {
363 | "cell_type": "markdown",
364 | "metadata": {
365 | "id": "qTYQ6fO8GfMI",
366 | "colab_type": "text"
367 | },
368 | "source": [
369 | "HMM的学习,根据训练数据是否包括观测序列和对应的状态序列还是只有观测序列,可以分别为**有监督学习**和**无监督学习**来实现。"
370 | ]
371 | },
372 | {
373 | "cell_type": "markdown",
374 | "metadata": {
375 | "id": "qTRkuN4GIfnW",
376 | "colab_type": "text"
377 | },
378 | "source": [
379 | "### 监督学习 \n",
380 | "\n",
381 | "假设已给训练数据包含$S$个长度相同的观测序列和对应的状态序列$\\{(O_{1}, I_{1}), (O_{2}, I_{2}),..., (O_{S}, I_{S})\\}$. 那么可以利用**极大似然估计**法来估计HMM的参数。"
382 | ]
383 | },
384 | {
385 | "cell_type": "markdown",
386 | "metadata": {
387 | "id": "neoLTatUJgks",
388 | "colab_type": "text"
389 | },
390 | "source": [
391 | "### 无监督学习 \n",
392 | "\n",
393 | "假设已给训练数据只包含$S$个长度为$T$的观测序列$\\{O_{1}, O_{2},..., O_{S}\\}$, 而没有对应的状态序列, 目标是学习HMM $\\lambda = (A, B, \\pi)$ 的参数。 我们将观测序列数据看作观测数据$O$, 状态序列数据看作不可观测的隐数据$I$, 那么HMM则是一个含有隐变量的概率模型: \n",
394 | "\n",
395 | "$P(O|\\lambda) = \\sum_{I}P(O|I, \\lambda)P(I|\\lambda)$ \n",
396 | "\n",
397 | "它的参数可以由EM算法来学习。"
398 | ]
399 | },
400 | {
401 | "cell_type": "code",
402 | "metadata": {
403 | "id": "9862N0kFBFBb",
404 | "colab_type": "code",
405 | "colab": {}
406 | },
407 | "source": [
408 | ""
409 | ],
410 | "execution_count": 0,
411 | "outputs": []
412 | },
413 | {
414 | "cell_type": "markdown",
415 | "metadata": {
416 | "id": "0lKwVVTPPpjt",
417 | "colab_type": "text"
418 | },
419 | "source": [
420 | "# 预测问题\n"
421 | ]
422 | },
423 | {
424 | "cell_type": "markdown",
425 | "metadata": {
426 | "id": "796IZ5QqPrdS",
427 | "colab_type": "text"
428 | },
429 | "source": [
430 | "### 近似算法 \n",
431 | "近似算法的想法是, 在每个时刻$t$ 选择在该时刻最有可能出现的状态$i^{*}_{t}$,从而得到一个状态序列 $I^{*} = (i^{*}_{1}, i^{*}_{2}, ..., i^{*}_{T})$, 将他作为预测的结果。"
432 | ]
433 | },
434 | {
435 | "cell_type": "markdown",
436 | "metadata": {
437 | "id": "G5XHAHOdQlKE",
438 | "colab_type": "text"
439 | },
440 | "source": [
441 | "### 维特比算法 \n"
442 | ]
443 | },
444 | {
445 | "cell_type": "markdown",
446 | "metadata": {
447 | "id": "TlCSEMGhQnRM",
448 | "colab_type": "text"
449 | },
450 | "source": [
451 | "维特比算法实际上是用动态规划求解HMM的预测问题,即用动态规划求概率最大路径。"
452 | ]
453 | },
454 | {
455 | "cell_type": "code",
456 | "metadata": {
457 | "id": "gSOX42cFSM2l",
458 | "colab_type": "code",
459 | "colab": {}
460 | },
461 | "source": [
462 | "class HMM_viterbi:\n",
463 | " def __init__(self, A, B, pi):\n",
464 | " self.A = A\n",
465 | " self.B = B\n",
466 | " self.pi = pi\n",
467 | " \n",
468 | " def viterbi(self, Q):\n",
469 | " T = len(Q) # 观测序列长度,时刻T\n",
470 | " N = len(self.A) # 状态数\n",
471 | " sigma = np.zeros((T, N))\n",
472 | " delta = np.zeros((T, N))\n",
473 | " for t in range(T):\n",
474 | " for n in range(N):\n",
475 | " if t == 0:\n",
476 | " sigma[t][n] = self.pi[n] * self.B[n][Q[t]]\n",
477 | " delta[t][n] = 0\n",
478 | " print(\"sigma[{}][{}] = {:.2f}\".format(t+1, n+1, sigma[t][n]))\n",
479 | " print(\"delta[{}][{}] = {}\".format(t+1, n+1, delta[t][n]))\n",
480 | " \n",
481 | " else:\n",
482 | " sigma[t][n] = np.max([sigma[t-1][j] * self.A[j][n] for j in range(N)]) * self.B[n][Q[t]]\n",
483 | " print(\"sigma[{}][{}] = {:.5f}\".format(t+1, n+1, sigma[t][n]))\n",
484 | " \n",
485 | " delta[t][n] = np.argmax([sigma[t-1][j] * self.A[j][n] for j in range(N)]) + 1\n",
486 | " print(\"delta[{}][{}] = {}\".format(t+1, n+1, delta[t][n]))\n",
487 | " \n",
488 | " P = np.max(sigma[T-1])\n",
489 | " print(P)\n",
490 | " pth = []\n",
491 | " for t in range(T-1, -1, -1):\n",
492 | " if t == T - 1:\n",
493 | " i_t = np.argmax(sigma[t])\n",
494 | " pth.append(i_t + 1)\n",
495 | " else:\n",
496 | " i_t = int(delta[t+1][i_t]) - 1\n",
497 | " pth.append(i_t + 1)\n",
498 | " \n",
499 | " return pth[::-1] # 回溯按 t=T..1 收集,反转为正序路径"
500 | ],
501 | "execution_count": 0,
502 | "outputs": []
503 | },
504 | {
505 | "cell_type": "markdown",
506 | "metadata": {
507 | "id": "LDEr3xz9SNSE",
508 | "colab_type": "text"
509 | },
510 | "source": [
511 | "#### 例 10.3"
512 | ]
513 | },
514 | {
515 | "cell_type": "markdown",
516 | "metadata": {
517 | "id": "oFJgnsvySRNF",
518 | "colab_type": "text"
519 | },
520 | "source": [
521 | "$A=\\begin{bmatrix}\n",
522 | " 0.5& 0.2& 0.3\\\\ \n",
523 | " 0.3& 0.5& 0.2\\\\ \n",
524 | " 0.2& 0.3& 0.5\n",
525 | "\\end{bmatrix}, \n",
526 | "B=\\begin{bmatrix}\n",
527 | " 0.5& 0.5\\\\ \n",
528 | " 0.4& 0.6\\\\ \n",
529 | " 0.7& 0.3\n",
530 | "\\end{bmatrix},\n",
531 | "\\pi=\\begin{bmatrix}\n",
532 | " 0.2\\\\ \n",
533 | " 0.4\\\\ \n",
534 | " 0.4\n",
535 | "\\end{bmatrix}$ \n",
536 | "\n",
537 | "已知观测序列$O=(红, 白, 红)$,试求最优状态序列,即最优路径 $I^{*}=(i^{*}_{1}, i^{*}_{2}, i^{*}_{3})$."
538 | ]
539 | },
540 | {
541 | "cell_type": "code",
542 | "metadata": {
543 | "id": "MXWapkQESxlq",
544 | "colab_type": "code",
545 | "colab": {}
546 | },
547 | "source": [
548 | "A = [[0.5, 0.2, 0.3], [0.3, 0.5, 0.2],[0.2, 0.3, 0.5]]\n",
549 | "B = [[0.5, 0.5], [0.4, 0.6], [0.7, 0.3]]\n",
550 | "pi = [0.2, 0.4, 0.4]\n",
551 | "Q = [0,1,0]"
552 | ],
553 | "execution_count": 0,
554 | "outputs": []
555 | },
556 | {
557 | "cell_type": "code",
558 | "metadata": {
559 | "id": "a55NPE7AS-Eu",
560 | "colab_type": "code",
561 | "colab": {
562 | "base_uri": "https://localhost:8080/",
563 | "height": 377
564 | },
565 | "outputId": "5f86553f-345a-466d-fc8a-755b65a1d167"
566 | },
567 | "source": [
568 | "m = HMM_viterbi(A, B, pi)\n",
569 | "m.viterbi(Q)"
570 | ],
571 | "execution_count": 135,
572 | "outputs": [
573 | {
574 | "output_type": "stream",
575 | "text": [
576 | "sigma[1][1] = 0.10\n",
577 | "delta[1][1] = 0.0\n",
578 | "sigma[1][2] = 0.16\n",
579 | "delta[1][2] = 0.0\n",
580 | "sigma[1][3] = 0.28\n",
581 | "delta[1][3] = 0.0\n",
582 | "sigma[2][1] = 0.02800\n",
583 | "delta[2][1] = 3.0\n",
584 | "sigma[2][2] = 0.05040\n",
585 | "delta[2][2] = 3.0\n",
586 | "sigma[2][3] = 0.04200\n",
587 | "delta[2][3] = 3.0\n",
588 | "sigma[3][1] = 0.00756\n",
589 | "delta[3][1] = 2.0\n",
590 | "sigma[3][2] = 0.01008\n",
591 | "delta[3][2] = 2.0\n",
592 | "sigma[3][3] = 0.01470\n",
593 | "delta[3][3] = 3.0\n",
594 | "0.014699999999999998\n"
595 | ],
596 | "name": "stdout"
597 | },
598 | {
599 | "output_type": "execute_result",
600 | "data": {
601 | "text/plain": [
602 | "[3, 3, 3]"
603 | ]
604 | },
605 | "metadata": {
606 | "tags": []
607 | },
608 | "execution_count": 135
609 | }
610 | ]
611 | },
612 | {
613 | "cell_type": "code",
614 | "metadata": {
615 | "id": "GKPLvAFzXuva",
616 | "colab_type": "code",
617 | "colab": {}
618 | },
619 | "source": [
620 | ""
621 | ],
622 | "execution_count": 0,
623 | "outputs": []
624 | }
625 | ]
626 | }
--------------------------------------------------------------------------------
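The `HMM_fw` and `HMM_bw` classes above can be cross-checked with a vectorized sketch of the same recursions on 例10.2. The model matrices are copied from the notebook cells; the variable names here (`alpha`, `beta`, `O`) are this sketch's own:

```python
import numpy as np

# model of 例10.2 (box-and-ball), as defined in the notebook
A = np.array([[0.5, 0.2, 0.3], [0.3, 0.5, 0.2], [0.2, 0.3, 0.5]])
B = np.array([[0.5, 0.5], [0.4, 0.6], [0.7, 0.3]])
pi = np.array([0.2, 0.4, 0.4])
O = [0, 1, 0]  # 红, 白, 红

# forward recursion: alpha_1 = pi * b(o_1); alpha_{t+1} = (alpha_t A) * b(o_{t+1})
alpha = pi * B[:, O[0]]
for o in O[1:]:
    alpha = (alpha @ A) * B[:, o]
print(round(alpha.sum(), 6))  # 0.130218, matching the HMM_fw output

# backward recursion: beta_T = 1; beta_t = A (b(o_{t+1}) * beta_{t+1})
beta = np.ones(3)
for o in reversed(O[1:]):
    beta = A @ (B[:, o] * beta)
print(round((pi * B[:, O[0]] * beta).sum(), 6))  # 0.130218 again, as 前向/后向 must agree
```

Both recursions collapse to one matrix-vector product per time step, which is the usual way to implement them for longer sequences.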
/第11章 条件随机场/CRF.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "CRF.ipynb",
7 | "version": "0.3.2",
8 | "provenance": [],
9 | "collapsed_sections": []
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | }
15 | },
16 | "cells": [
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {
20 | "id": "fo3mvewZFRAt",
21 | "colab_type": "text"
22 | },
23 | "source": [
24 | "# CRF, Conditional Random Field"
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "metadata": {
30 | "id": "nXwJKCfSFcnL",
31 | "colab_type": "text"
32 | },
33 | "source": [
34 | "1.概率无向图模型是由无向图表示的联合概率分布。无向图上的结点之间的连接关系表示了联合分布的随机变量集合之间的条件独立性,即马尔可夫性。因此,概率无向图模型也称为马尔可夫随机场。\n",
35 | "\n",
36 | "概率无向图模型或马尔可夫随机场的联合概率分布可以分解为无向图最大团上的正值函数的乘积的形式。\n",
37 | "\n",
38 | "2.条件随机场是给定输入随机变量X条件下,输出随机变量Y的条件概率分布模型, 其形式为参数化的对数线性模型。条件随机场的最大特点是假设输出变量之间的联合概率分布构成概率无向图模型,即马尔可夫随机场。条件随机场是判别模型。\n",
39 | "\n",
40 | "3.线性链条件随机场是定义在观测序列与标记序列上的条件随机场。线性链条件随机场一般表示为给定观测序列条件下的标记序列的条件概率分布,由参数化的对数线性模型表示。模型包含特征及相应的权值,特征是定义在线性链的边与结点上的。线性链条件随机场的数学表达式是 \n",
41 | "\n",
42 | "$P(y|x)=\\frac{1}{Z(x)}\\exp(\\sum_{i,k} \\lambda_{k}t_{k}(y_{i-1}, y_{i}, x, i) + \\sum_{i,l}\\mu_{l}s_{l}(y_{i}, x, i))$ \n",
43 | "\n",
44 | "其中, \n",
45 | "\n",
46 | "$Z(x)=\\sum_{y}\\exp(\\sum_{i,k}\\lambda_{k}t_{k}(y_{i-1}, y_{i}, x, i) + \\sum_{i,l}\\mu_{l}s_{l}(y_{i}, x, i))$\n",
47 | "\n",
48 | "\n",
49 | "4.线性链条件随机场的概率计算通常利用前向-后向算法。\n",
50 | "\n",
51 | "5.条件随机场的学习方法通常是极大似然估计方法或正则化的极大似然估计,即在给定训练数据下,通过极大化训练数据的对数似然函数以估计模型参数。具体的算法有改进的迭代尺度算法、梯度下降法、拟牛顿法等。\n",
52 | "\n",
53 | "6.线性链条件随机场的一个重要应用是标注。维特比算法是给定观测序列求条件概率最大的标记序列的方法。"
54 | ]
55 | },
56 | {
57 | "cell_type": "markdown",
58 | "metadata": {
59 | "id": "SL_XKnt7H4Q2",
60 | "colab_type": "text"
61 | },
62 | "source": [
63 | "#### 例 11.1"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "metadata": {
69 | "id": "7kx_vjWOFGSd",
70 | "colab_type": "code",
71 | "colab": {}
72 | },
73 | "source": [
74 | "import numpy as np"
75 | ],
76 | "execution_count": 0,
77 | "outputs": []
78 | },
79 | {
80 | "cell_type": "code",
81 | "metadata": {
82 | "id": "RXUzhi7VHy2B",
83 | "colab_type": "code",
84 | "colab": {
85 | "base_uri": "https://localhost:8080/",
86 | "height": 53
87 | },
88 | "outputId": "19c4e0bf-f4e7-4533-8d25-d8b2a5822d31"
89 | },
90 | "source": [
91 | "#这里定义T为转移矩阵,行代表前一个y,元素(i,j)代表由状态i转到状态j的非规范化得分,T后的数字对应时间序列中的位置\n",
92 | "#这里将书上的转移特征转换为如下以时间轴为区别的三个多维列表,维度为输出的维度\n",
93 | "T1 = [[0.6, 1], [1, 0]]\n",
94 | "T2 = [[0, 1], [1, 0.2]]\n",
95 | "#将书上的状态特征同样转换成列表,第一个为y=1的未规范化得分,第二个为y=2的未规范化得分\n",
96 | "S0 = [1, 0.5]\n",
97 | "S1 = [0.8, 0.5]\n",
98 | "S2 = [0.8, 0.5]\n",
99 | "Y = [1, 2, 2] #即书上例11.1需要计算非规范化条件概率的标记序列\n",
100 | "Y = np.array(Y) - 1 #这里为了将数与索引相对应即从零开始\n",
101 | "P = np.exp(S0[Y[0]])\n",
102 | "for i in range(1, len(Y)):\n",
103 | " P *= np.exp((eval('S%d' % i)[Y[i]]) + eval('T%d' % i)[Y[i - 1]][Y[i]])\n",
104 | "print(P)\n",
105 | "print(np.exp(3.2))"
106 | ],
107 | "execution_count": 6,
108 | "outputs": [
109 | {
110 | "output_type": "stream",
111 | "text": [
112 | "24.532530197109345\n",
113 | "24.532530197109352\n"
114 | ],
115 | "name": "stdout"
116 | }
117 | ]
118 | },
119 | {
120 | "cell_type": "markdown",
121 | "metadata": {
122 | "id": "7lUipRWFMVB7",
123 | "colab_type": "text"
124 | },
125 | "source": [
126 | "#### 例 11.2"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "metadata": {
132 | "id": "TG6SEYCbMXty",
133 | "colab_type": "code",
134 | "colab": {
135 | "base_uri": "https://localhost:8080/",
136 | "height": 35
137 | },
138 | "outputId": "b54d58b8-60b1-4dac-b03e-8816458c5ae2"
139 | },
140 | "source": [
141 | "#这里根据例11.2的启发整合为一个矩阵\n",
142 | "F0 = S0\n",
143 | "F1 = T1 + np.array(S1 * len(T1)).reshape(np.asarray(T1).shape)\n",
144 | "F2 = T2 + np.array(S2 * len(T2)).reshape(np.asarray(T2).shape)\n",
145 | "Y = [1, 2, 2] #即书上例11.1需要计算非规范化条件概率的标记序列\n",
146 | "Y = np.array(Y) - 1\n",
147 | "\n",
148 | "P = np.exp(F0[Y[0]])\n",
149 | "Sum = P\n",
150 | "for i in range(1, len(Y)):\n",
151 | " PIter = np.exp((eval('F%d' % i)[Y[i - 1]][Y[i]]))\n",
152 | " P *= PIter\n",
153 | " Sum += PIter\n",
154 | "print('非规范化概率', P)"
155 | ],
156 | "execution_count": 14,
157 | "outputs": [
158 | {
159 | "output_type": "stream",
160 | "text": [
161 | "非规范化概率 24.532530197109345\n"
162 | ],
163 | "name": "stdout"
164 | }
165 | ]
166 | },
167 | {
168 | "cell_type": "markdown",
169 | "metadata": {
170 | "id": "fW5RZz89NPsD",
171 | "colab_type": "text"
172 | },
173 | "source": [
174 | "#### Reference: https://nbviewer.jupyter.org/github/fengdu78/lihang-code/blob/master/%E7%AC%AC11%E7%AB%A0%20%E6%9D%A1%E4%BB%B6%E9%9A%8F%E6%9C%BA%E5%9C%BA/11.CRF.ipynb"
175 | ]
176 | },
177 | {
178 | "cell_type": "markdown",
179 | "metadata": {
180 | "id": "P6FeGVCCM7zC",
181 | "colab_type": "text"
182 | },
183 | "source": [
184 | "### 其实,我还是没搞懂CRF,没有在具体的项目中使用。PGM本身就是一个很大的topic,就这简简单单的一章无法全部解释。"
185 | ]
186 | }
187 | ]
188 | }
--------------------------------------------------------------------------------
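The 例11.1/11.2 cells above can also be organized with the linear-chain CRF's M-matrices, which give both the unnormalized probability of one path and the normalization Z(x) from the same objects. A sketch under stated assumptions: the feature scores are copied from the notebook, while the names `M1`, `M2`, `start` and the matrix-chain formulation are this sketch's own:

```python
import numpy as np

# transition and state feature scores of 例11.1, as in the notebook
T1 = np.array([[0.6, 1.0], [1.0, 0.0]])
T2 = np.array([[0.0, 1.0], [1.0, 0.2]])
S0 = np.array([1.0, 0.5])
S1 = np.array([0.8, 0.5])
S2 = np.array([0.8, 0.5])

# M_i[y_{i-1}, y_i] = exp(transition score + state score)
M1 = np.exp(T1 + S1)   # position 2
M2 = np.exp(T2 + S2)   # position 3
start = np.exp(S0)     # position 1

y = [0, 1, 1]          # 标记序列 (1, 2, 2), zero-indexed
w = start[y[0]] * M1[y[0], y[1]] * M2[y[1], y[2]]
print(round(w, 4))     # 24.5325 = exp(3.2), the notebook's unnormalized probability

# Z(x): sum over all label paths, computed as a chain of matrix products
Z = start @ M1 @ M2 @ np.ones(2)
print(round(w / Z, 4)) # normalized conditional probability P(y|x)
```

Summing over paths via `start @ M1 @ M2 @ 1` is exactly the forward computation for linear-chain CRFs, so this also avoids the `eval('T%d' % i)` indirection used in the cells above.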
/第13章 无监督学习概论/Introduction_to_Unsupervised_Learning.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "kernelspec": {
6 | "display_name": "Python 3",
7 | "language": "python",
8 | "name": "python3"
9 | },
10 | "language_info": {
11 | "codemirror_mode": {
12 | "name": "ipython",
13 | "version": 3
14 | },
15 | "file_extension": ".py",
16 | "mimetype": "text/x-python",
17 | "name": "python",
18 | "nbconvert_exporter": "python",
19 | "pygments_lexer": "ipython3",
20 | "version": "3.6.4"
21 | },
22 | "colab": {
23 | "name": "Introduction_to_Unsupervised_Learning.ipynb",
24 | "version": "0.3.2",
25 | "provenance": [],
26 | "collapsed_sections": []
27 | }
28 | },
29 | "cells": [
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {
33 | "id": "qipcvQWGoxDB",
34 | "colab_type": "text"
35 | },
36 | "source": [
37 | "# 第13章 无监督学习概论"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {
43 | "id": "j5m7ppYdoxDP",
44 | "colab_type": "text"
45 | },
46 | "source": [
47 | "1.机器学习或统计学习一般包括监督学习、无监督学习、强化学习。\n",
48 | "\n",
49 | "无监督学习是指从无标注数据中学习模型的机器学习问题。无标注数据是自然得到的数据,模型表示数据的类别、转换或概率。无监督学习的本质是学习数据中的统计规律或潜在结构,主要包括聚类、降维、概率估计。\n",
50 | "\n",
51 | "2.无监督学习可以用于对已有数据的分析,也可以用于对未来数据的预测。学习得到的模型有函数$z=g(x)$,条件概率分布$P(z|x)$,或条件概率分布$P(x|z)$。\n",
52 | "\n",
53 | "无监督学习的基本想法是对给定数据(矩阵数据)进行某种“压缩”,从而找到数据的潜在结构,假定损失最小的压缩得到的结果就是最本质的结构。可以考虑发掘数据的纵向结构,对应聚类。也可以考虑发掘数据的横向结构,对应降维。还可以同时考虑发掘数据的纵向与横向结构,对应概率模型估计。\n",
54 | "\n",
55 | "3.聚类是将样本集合中相似的样本(实例)分配到相同的类,不相似的样本分配到不同的类。聚类分硬聚类和软聚类。聚类方法有层次聚类和$k$均值聚类。\n",
56 | "\n",
57 | "4.降维是将样本集合中的样本(实例)从高维空间转换到低维空间。假设样本原本存在于低维空间,或近似地存在于低维空间,通过降维则可以更好地表示样本数据的结构,即更好地表示样本之间的关系。降维有线性降维和非线性降维,降维方法有主成分分析。\n",
58 | "\n",
59 | "5.概率模型估计假设训练数据由一个概率模型生成,同时利用训练数据学习概率模型的结构和参数。概率模型包括混合模型、概率图模型等。概率图模型又包括有向图模型和无向图模型。\n",
60 | "\n",
61 | "6.话题分析是文本分析的一种技术。给定一个文本集合,话题分析旨在发现文本集合中每个文本的话题,而话题由单词的集合表示。话题分析方法有潜在语义分析、概率潜在语义分析和潜在狄利克雷分配。\n",
62 | "\n",
63 | "7.图分析的目的是发掘隐藏在图中的统计规律或潜在结构。链接分析是图分析的一种,主要是发现有向图中的重要结点,包括 **PageRank**算法。"
64 | ]
65 | },
66 | {
67 | "cell_type": "markdown",
68 | "metadata": {
69 | "id": "5Y2Wgt2Lo20z",
70 | "colab_type": "text"
71 | },
72 | "source": [
73 | "#### Reference: https://github.com/fengdu78/lihang-code/blob/master/%E7%AC%AC13%E7%AB%A0%20%E6%97%A0%E7%9B%91%E7%9D%A3%E5%AD%A6%E4%B9%A0%E6%A6%82%E8%AE%BA/13.Introduction_to_Unsupervised_Learning.ipynb"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "metadata": {
79 | "id": "uScYNYproxDR",
80 | "colab_type": "code",
81 | "colab": {}
82 | },
83 | "source": [
84 | ""
85 | ],
86 | "execution_count": 0,
87 | "outputs": []
88 | }
89 | ]
90 | }
--------------------------------------------------------------------------------
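The contrast drawn in the chapter above between mining the data's vertical structure (clustering) and its horizontal structure (dimensionality reduction) can be sketched with a toy example; the data matrix and parameter choices below are illustrative assumptions, not taken from the book:

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

# Toy data matrix: 6 samples (rows) in 2 dimensions, forming two groups.
X = np.array([[0.0, 0.1], [0.2, 0.0], [0.1, 0.2],
              [5.0, 5.1], [5.2, 5.0], [5.1, 5.2]])

# Vertical (row) structure: hard clustering with k-means
# assigns each sample to exactly one of two classes.
labels = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(X)

# Horizontal (column) structure: linear dimensionality reduction
# with PCA compresses each 2-D sample to a 1-D representation.
Z = PCA(n_components=1).fit_transform(X)

print(labels)     # cluster assignment per row
print(Z.shape)    # (6, 1): fewer columns, same rows
```

Hard k-means outputs one class per sample, while PCA re-expresses every sample with fewer coordinates; soft clustering and probability model estimation would instead output distributions.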
/第17章 潜在语义分析/LSA.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "LSA.ipynb",
7 | "version": "0.3.2",
8 | "provenance": [],
9 | "collapsed_sections": []
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | }
15 | },
16 | "cells": [
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {
20 | "id": "LOWANK49Pi27",
21 | "colab_type": "text"
22 | },
23 | "source": [
24 | "# Latent Semantic Analysis (LSA)"
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "metadata": {
30 | "id": "0WD7XVWRPkX1",
31 | "colab_type": "text"
32 | },
33 | "source": [
34 | "**LSA** is an unsupervised learning method used mainly for topic analysis of texts. Its characteristic is that it discovers topic-based semantic relations between texts and words through matrix factorization. It is also called latent semantic indexing (LSI).\n",
35 | "\n",
36 | "LSA uses a non-probabilistic topic model: the text collection is represented as a **word-document matrix**, and a **singular value decomposition** of that matrix yields the topic vector space and the representation of the texts in it.\n",
37 | "\n",
38 | "**Non-negative matrix factorization** (NMF) is another matrix factorization method, characterized by non-negative factor matrices."
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {
44 | "id": "P1sWKgTGQ7r-",
45 | "colab_type": "text"
46 | },
47 | "source": [
48 | "## Word Vector Space \n",
49 | "word vector space model"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {
55 | "id": "CqXj1777RM8y",
56 | "colab_type": "text"
57 | },
58 | "source": [
59 | "Given a text, a vector is used to represent its “semantics”: **each dimension of the vector corresponds to a word**, and its value is the frequency or weight of that word in the text. The basic assumption is that the occurrences of all the words in a text represent the text's semantic content; each text in a collection is represented as a vector in a single vector space, and a metric on that space, such as the inner product or the normalized **inner product**, expresses the **similarity** between texts."
60 | ]
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "metadata": {
65 | "id": "3HVCXf6CSmTT",
66 | "colab_type": "text"
67 | },
68 | "source": [
69 | "Given a collection of $n$ texts $D=\\{d_{1}, d_{2},...,d_{n}\\}$ and the set $W=\\{w_{1},w_{2},...,w_{m}\\}$ of the $m$ words appearing in them, the occurrences of words in texts are recorded in a word-document matrix $X$:\n",
70 | "\n",
71 | "$\n",
72 | "X = \\begin{bmatrix}\n",
73 | "x_{11} & x_{12} & \\cdots & x_{1n} \\\\ \n",
74 | "x_{21} & x_{22} & \\cdots & x_{2n} \\\\ \n",
75 | "\\vdots & \\vdots & & \\vdots \\\\ \n",
76 | "x_{m1} & x_{m2} & \\cdots & x_{mn} \n",
77 | "\\end{bmatrix}\n",
78 | "$\n",
79 | "\n",
80 | "This is an $m \\times n$ matrix whose element $x_{ij}$ is the frequency or weight of word $w_{i}$ in text $d_{j}$. Since there are many word types but each text contains relatively few of them, the word-document matrix is sparse.\n",
81 | "\n"
82 | ]
83 | },
84 | {
85 | "cell_type": "markdown",
86 | "metadata": {
87 | "id": "K2ncB3cde1Ab",
88 | "colab_type": "text"
89 | },
90 | "source": [
91 | "Weights are usually given by **term frequency-inverse document frequency** (TF-IDF):\n",
92 | "\n",
93 | "$\\text{TF-IDF}(t, d) = TF(t, d) \\cdot IDF(t)$, \n",
94 | "\n",
95 | "where $TF(t,d)$ is the frequency of word $t$ in text $d$, and $IDF(t)$ is the inverse document frequency, which measures the importance of word $t$ for expressing semantics, \n",
96 | "\n",
97 | "$IDF(t) = \\log(\\frac{len(D)}{len(t \\in D) + 1})$."
98 | ]
99 | },
100 | {
101 | "cell_type": "markdown",
102 | "metadata": {
103 | "id": "bpu7MycIgu65",
104 | "colab_type": "text"
105 | },
106 | "source": [
107 | "The advantages of the word vector space model are that it is **simple and computationally efficient**, since word vectors are typically sparse. The model also has limitations: inner-product similarity may not accurately express the semantic similarity of two texts, because natural-language words exhibit polysemy and synonymy."
108 | ]
109 | },
110 | {
111 | "cell_type": "markdown",
112 | "metadata": {
113 | "id": "ns5wncZohn-z",
114 | "colab_type": "text"
115 | },
116 | "source": [
117 | "## Topic Vector Space"
118 | ]
119 | },
120 | {
121 | "cell_type": "markdown",
122 | "metadata": {
123 | "id": "mmZpPHIdhrAy",
124 | "colab_type": "text"
125 | },
126 | "source": [
127 | "**1. The topic vector space**:\n",
128 | "\n",
129 | "Given a collection of $n$ texts $D=\\{d_{1}, d_{2},...,d_{n}\\}$ and the set $W=\\{w_{1},w_{2},...,w_{m}\\}$ of the $m$ words appearing in them, we form the word-document matrix $X$: \n",
130 | "\n",
131 | "$\n",
132 | "X = \\begin{bmatrix}\n",
133 | "x_{11} & x_{12} & \\cdots & x_{1n} \\\\ \n",
134 | "x_{21} & x_{22} & \\cdots & x_{2n} \\\\ \n",
135 | "\\vdots & \\vdots & & \\vdots \\\\ \n",
136 | "x_{m1} & x_{m2} & \\cdots & x_{mn} \n",
137 | "\\end{bmatrix}\n",
138 | "$\n",
139 | "\n",
140 | "\n",
141 | "Suppose the texts collectively contain $k$ topics, and suppose each topic is represented by an $m$-dimensional vector defined over the word set $W$, called a topic vector: \n",
142 | "$t_{l} = \\begin{bmatrix}\n",
143 | "t_{1l}\\\\ \n",
144 | "t_{2l}\\\\ \n",
145 | "\\vdots \\\\ \n",
146 | "t_{ml}\\end{bmatrix}, l=1,2,...,k$\n",
147 | "\n",
148 | "where $t_{il}$ is the weight of word $w_{i}$ in topic $t_{l}$, $i=1,2,...,m$; the larger the weight, the more important the word is in the topic. The $k$ topic vectors $t_{1},t_{2},...,t_{k}$ span a topic vector space of dimension $k$. **The topic vector space is a subspace of the word vector space.**\n",
149 | "\n",
150 | "The topic vector space $T$: \n",
151 | "\n",
152 | "\n",
153 | "$\n",
154 | "T = \\begin{bmatrix}\n",
155 | "t_{11} & t_{12} & \\cdots & t_{1k} \\\\ \n",
156 | "t_{21} & t_{22} & \\cdots & t_{2k} \\\\ \n",
157 | "\\vdots & \\vdots & & \\vdots \\\\ \n",
158 | "t_{m1} & t_{m2} & \\cdots & t_{mk} \n",
159 | "\\end{bmatrix}\n",
160 | "$ \n",
161 | "\n",
162 | "The matrix $T$ is called the **word-topic matrix**, also written $T = [t_{1}, t_{2}, ..., t_{k}]$."
163 | ]
164 | },
165 | {
166 | "cell_type": "markdown",
167 | "metadata": {
168 | "id": "Oc1c3JcKlTBD",
169 | "colab_type": "text"
170 | },
171 | "source": [
172 | "**2. Representation of texts in the topic vector space**:\n",
173 | "\n",
174 | "Consider a text $d_{j}$ of the collection $D$, represented in the word vector space by a vector $x_{j}$. Projecting $x_{j}$ onto the topic vector space $T$ gives a vector $y_{j}$ in the topic vector space; $y_{j}$ is a $k$-dimensional vector: \n",
175 | "\n",
176 | "$y_{j} = \\begin{bmatrix}\n",
177 | "y_{1j}\\\\ \n",
178 | "y_{2j}\\\\ \n",
179 | "\\vdots \\\\ \n",
180 | "y_{kj}\\end{bmatrix}, j=1,2,...,n$ \n",
181 | "\n",
182 | "where $y_{lj}$ is the weight of text $d_{j}$ on topic $t_{l}$, $l = 1,2,..., k$; the larger the weight, the more important the topic is in the text. \n",
183 | "\n",
184 | "The matrix $Y$ records the occurrence of topics in texts and is called the topic-document matrix: \n",
185 | "\n",
186 | "$\n",
187 | "Y = \\begin{bmatrix}\n",
188 | "y_{11} & y_{12} & \\cdots & y_{1n} \\\\ \n",
189 | "y_{21} & y_{22} & \\cdots & y_{2n} \\\\ \n",
190 | "\\vdots & \\vdots & & \\vdots \\\\ \n",
191 | "y_{k1} & y_{k2} & \\cdots & y_{kn} \n",
192 | "\\end{bmatrix}\n",
193 | "$ \n",
194 | "\n",
195 | "which can also be written as $Y = [y_{1}, y_{2}, ..., y_{n}]$."
196 | ]
197 | },
198 | {
199 | "cell_type": "markdown",
200 | "metadata": {
201 | "id": "YcU3xwYindTo",
202 | "colab_type": "text"
203 | },
204 | "source": [
205 | "**3. The linear transformation from the word vector space to the topic vector space**: \n",
206 | "\n",
207 | "Thus a text vector $x_{j}$ in the word vector space can be approximately represented by its vector $y_{j}$ in the topic vector space, namely as a linear combination of the $k$ topic vectors with the components of $y_{j}$ as coefficients: \n",
208 | "\n",
209 | "$x_{j} = y_{1j}t_{1} + y_{2j}t_{2} + ... + y_{kj}t_{k}, \\quad j = 1,2,..., n$ \n",
210 | "\n",
211 | "Hence the word-document matrix $X$ can be approximately expressed as the product of the word-topic matrix $T$ and the topic-document matrix $Y$:\n",
212 | "\n",
213 | "$X \\approx TY$ \n",
214 | "\n",
215 | "Intuitively, latent semantic analysis converts the representation of texts in the word vector space into a representation in the topic vector space via a linear transformation, and this transformation is embodied in a matrix factorization."
216 | ]
217 | },
218 | {
219 | "cell_type": "markdown",
220 | "metadata": {
221 | "id": "Cu4JekfXFMqs",
222 | "colab_type": "text"
223 | },
224 | "source": [
225 | "### The Latent Semantic Analysis Algorithm "
226 | ]
227 | },
228 | {
229 | "cell_type": "markdown",
230 | "metadata": {
231 | "id": "0awNXCy1Gw0K",
232 | "colab_type": "text"
233 | },
234 | "source": [
235 | "Latent semantic analysis uses matrix singular value decomposition: specifically, it applies singular value decomposition to the word-document matrix, takes the left matrix as the topic vector space, and takes the product of the diagonal matrix and the right matrix as the representation of the texts in the topic vector space."
236 | ]
237 | },
238 | {
239 | "cell_type": "markdown",
240 | "metadata": {
241 | "id": "otq3HMu5HVoK",
242 | "colab_type": "text"
243 | },
244 | "source": [
245 | "Given a collection of $n$ texts $D=\\{d_{1}, d_{2},...,d_{n}\\}$ and the set $W=\\{w_{1},w_{2},...,w_{m}\\}$ of the $m$ words appearing in them, we form the word-document matrix $X$: \n",
246 | "$\n",
247 | "X = \\begin{bmatrix}\n",
248 | "x_{11} & x_{12} & \\cdots & x_{1n} \\\\ \n",
249 | "x_{21} & x_{22} & \\cdots & x_{2n} \\\\ \n",
250 | "\\vdots & \\vdots & & \\vdots \\\\ \n",
251 | "x_{m1} & x_{m2} & \\cdots & x_{mn} \n",
252 | "\\end{bmatrix}\n",
253 | "$\n",
254 | "\n",
255 | "\n",
256 | "\n",
257 | "\n"
258 | ]
259 | },
260 | {
261 | "cell_type": "markdown",
262 | "metadata": {
263 | "id": "mwNGRDgrHmmV",
264 | "colab_type": "text"
265 | },
266 | "source": [
267 | "**Truncated singular value decomposition**:\n",
268 | "\n",
269 | "Given the number of topics $k$, latent semantic analysis applies truncated singular value decomposition to the word-document matrix $X$: \n",
270 | "\n",
271 | "$\n",
272 | "X \\approx U_{k}\\Sigma _{k}V_{k}^{T} = \\begin{bmatrix}\n",
273 | "u_{1} & u_{2}& \\cdots & u_{k}\n",
274 | "\\end{bmatrix}\\begin{bmatrix}\n",
275 | "\\sigma_{1} & 0& 0& 0\\\\ \n",
276 | " 0& \\sigma_{2}& 0& 0\\\\ \n",
277 | " 0& 0& \\ddots & 0\\\\ \n",
278 | " 0& 0& 0& \\sigma_{k}\n",
279 | "\\end{bmatrix}\\begin{bmatrix}\n",
280 | "v_{1}^{T}\\\\ \n",
281 | "v_{2}^{T}\\\\ \n",
282 | "\\vdots \\\\ \n",
283 | "v_{k}^{T}\\end{bmatrix}\n",
284 | "$\n",
285 | "\n",
286 | "Each column vector $u_{1}, u_{2},..., u_{k}$ of the matrix $U_{k}$ represents a topic and is called a **topic vector**. These $k$ topic vectors span a subspace: \n",
287 | "\n",
288 | "$\n",
289 | "U_{k} = \\begin{bmatrix}\n",
290 | "u_{1} & u_{2}& \\cdots & u_{k}\n",
291 | "\\end{bmatrix}\n",
292 | "$\n",
293 | "\n",
294 | "called the **topic vector space**. \n",
295 | "\n",
296 | "In summary, latent semantic analysis can be carried out through the singular value decomposition of the word-document matrix: \n",
297 | "\n",
298 | "$ X \\approx U_{k} \\Sigma_{k} V_{k}^{T} = U_{k}(\\Sigma_{k}V_{k}^{T})$ \n",
299 | "\n",
300 | "which yields the topic vector space $U_{k}$ and the representation of the texts in it, $\\Sigma_{k}V_{k}^{T}$. "
301 | ]
302 | },
303 | {
304 | "cell_type": "markdown",
305 | "metadata": {
306 | "id": "UTNHyq8mK8l5",
307 | "colab_type": "text"
308 | },
309 | "source": [
310 | "### The Non-negative Matrix Factorization Algorithm"
311 | ]
312 | },
313 | {
314 | "cell_type": "markdown",
315 | "metadata": {
316 | "id": "RQvqaMDYK_jf",
317 | "colab_type": "text"
318 | },
319 | "source": [
320 | "Non-negative matrix factorization can also be used for topic analysis: factor the word-document matrix, taking **the left matrix as the topic vector space** and **the right matrix as the representation of the texts in the topic vector space**."
321 | ]
322 | },
323 | {
324 | "cell_type": "markdown",
325 | "metadata": {
326 | "id": "ApM8tE3MLqpP",
327 | "colab_type": "text"
328 | },
329 | "source": [
330 | "#### Non-negative Matrix Factorization "
331 | ]
332 | },
333 | {
334 | "cell_type": "markdown",
335 | "metadata": {
336 | "id": "glMwmkiwLyIn",
337 | "colab_type": "text"
338 | },
339 | "source": [
340 | "A matrix is non-negative if all of its elements are non-negative; if $X$ is non-negative, we write $X \\geq 0$. \n",
341 | "\n",
342 | "Given a non-negative matrix $X$, find two non-negative matrices $W \\geq 0$ and $H \\geq 0$ such that \n",
343 | "\n",
344 | "$ X \\approx WH$\n",
345 | "\n",
346 | "i.e., the non-negative matrix $X$ is factored into a product of two non-negative matrices $W$ and $H$; this is called non-negative matrix factorization. Since it is difficult for $WH$ to equal $X$ exactly, only approximate equality is required. \n",
347 | "\n",
348 | "Suppose the non-negative matrix $X$ is $m\\times n$, and the non-negative matrices $W$ and $H$ are $m\\times k$ and $k\\times n$ respectively, with $k < \\min(m, n)$, so that $W$ and $H$ are smaller than the original matrix $X$; non-negative matrix factorization is therefore a compression of the original data.\n",
349 | "\n",
350 | "$W$ is called the basis matrix and $H$ the coefficient matrix. Non-negative matrix factorization aims to represent a large data matrix with a small number of basis vectors and coefficient vectors.\n",
351 | "\n",
352 | "Let $W = \\begin{bmatrix}\n",
353 | "w_{1} & w_{2}& \\cdots& w_{k} \n",
354 | "\\end{bmatrix}$\n",
355 | "be the topic vector space, where $w_{1}, w_{2}, ..., w_{k}$ represent the $k$ topics of the text collection, and let $H = \\begin{bmatrix}\n",
356 | "h_{1} & h_{2}& \\cdots& h_{n} \n",
357 | "\\end{bmatrix}$\n",
358 | "be the representation of the texts in the topic vector space, where $h_{1}, h_{2},..., h_{n}$ represent the $n$ texts of the collection."
359 | ]
360 | },
361 | {
362 | "cell_type": "markdown",
363 | "metadata": {
364 | "id": "1DcvVSR0N_CF",
365 | "colab_type": "text"
366 | },
367 | "source": [
368 | "##### Algorithm"
369 | ]
370 | },
371 | {
372 | "cell_type": "markdown",
373 | "metadata": {
374 | "id": "hvZyHT85O5qt",
375 | "colab_type": "text"
376 | },
377 | "source": [
378 | "Non-negative matrix factorization can be formalized and solved as an optimization problem, with the squared loss or the divergence as the loss function.\n",
379 | "\n",
380 | "Minimize the objective $|| X - WH ||^{2}$ over $W$ and $H$ subject to the constraints $W, H \\geq 0$, i.e.: \n",
381 | "\n",
382 | "$\\underset{W,H}{\\min} || X - WH ||^{2}$ \n",
383 | "\n",
384 | "\n",
385 | "$\\text{s.t.} \\quad W, H \\geq 0$"
386 | ]
387 | },
388 | {
389 | "cell_type": "markdown",
390 | "metadata": {
391 | "id": "-zIGS1AEQdWp",
392 | "colab_type": "text"
393 | },
394 | "source": [
395 | "Multiplicative update rules: \n",
396 | "\n",
397 | "\n",
398 | "$W_{il} \\leftarrow W_{il}\\frac{(XH^{T})_{il}}{(WHH^{T})_{il}}$ (17.33)\n",
399 | "\n",
400 | "\n",
401 | "$H_{lj} \\leftarrow H_{lj}\\frac{(W^{T}X)_{lj}}{(W^{T}WH)_{lj}}$ (17.34)\n",
402 | "\n",
403 | "\n",
404 | "Choosing the initial matrices $W$ and $H$ to be non-negative guarantees that $W$ and $H$ remain non-negative throughout the iteration and in the result."
405 | ]
406 | },
407 | {
408 | "cell_type": "markdown",
409 | "metadata": {
410 | "id": "MeiA0REkRpRi",
411 | "colab_type": "text"
412 | },
413 | "source": [
414 | "**Algorithm 17.1 (iterative algorithm for non-negative matrix factorization)**\n",
415 | "\n",
416 | "Input: word-document matrix $X \\geq 0$, number of topics $k$ for the text collection, maximum number of iterations $t$; \n",
417 | "Output: topic matrix $W$, text representation matrix $H$. \n",
418 | "\n",
419 | "**1)** Initialization\n",
420 | "\n",
421 | "$W \\geq 0$, and normalize each column of $W$; \n",
422 | "$H \\geq 0$;\n",
423 | "\n",
424 | "**2)** Iteration \n",
425 | "\n",
426 | "For iteration count from 1 to $t$, execute the following steps: \n",
427 | "a. update the elements of $W$: for $l$ from 1 to $k$ and $i$ from 1 to $m$, update $W_{il}$ by (17.33); \n",
428 | "b. update the elements of $H$: for $l$ from 1 to $k$ and $j$ from 1 to $n$, update $H_{lj}$ by (17.34). "
429 | ]
430 | },
431 | {
432 | "cell_type": "markdown",
433 | "metadata": {
434 | "id": "rIw6a0HITg08",
435 | "colab_type": "text"
436 | },
437 | "source": [
438 | "### Example 17.1"
439 | ]
440 | },
441 | {
442 | "cell_type": "code",
443 | "metadata": {
444 | "id": "0hPH9VEMPVGu",
445 | "colab_type": "code",
446 | "colab": {}
447 | },
448 | "source": [
449 | "import numpy as np\n",
450 | "from sklearn.decomposition import TruncatedSVD"
451 | ],
452 | "execution_count": 0,
453 | "outputs": []
454 | },
455 | {
456 | "cell_type": "code",
457 | "metadata": {
458 | "id": "kjHirYzQWItl",
459 | "colab_type": "code",
460 | "outputId": "0e6f2615-6a0b-4c4f-e74c-559727519eab",
461 | "colab": {
462 | "base_uri": "https://localhost:8080/",
463 | "height": 125
464 | }
465 | },
466 | "source": [
467 | "X = [[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 2, 3], [0, 0, 0, 1], [1, 2, 2, 1]]\n",
468 | "X = np.asarray(X);X"
469 | ],
470 | "execution_count": 0,
471 | "outputs": [
472 | {
473 | "output_type": "execute_result",
474 | "data": {
475 | "text/plain": [
476 | "array([[2, 0, 0, 0],\n",
477 | " [0, 2, 0, 0],\n",
478 | " [0, 0, 1, 0],\n",
479 | " [0, 0, 2, 3],\n",
480 | " [0, 0, 0, 1],\n",
481 | " [1, 2, 2, 1]])"
482 | ]
483 | },
484 | "metadata": {
485 | "tags": []
486 | },
487 | "execution_count": 2
488 | }
489 | ]
490 | },
491 | {
492 | "cell_type": "code",
493 | "metadata": {
494 | "id": "I2yFnNJKWcPP",
495 | "colab_type": "code",
496 | "colab": {}
497 | },
498 | "source": [
499 | "# singular value decomposition\n",
500 | "U,sigma,VT=np.linalg.svd(X)"
501 | ],
502 | "execution_count": 0,
503 | "outputs": []
504 | },
505 | {
506 | "cell_type": "code",
507 | "metadata": {
508 | "id": "ollDH_QNXAdY",
509 | "colab_type": "code",
510 | "outputId": "8d599d82-4bae-4047-941c-8ca524142349",
511 | "colab": {
512 | "base_uri": "https://localhost:8080/",
513 | "height": 233
514 | }
515 | },
516 | "source": [
517 | "U"
518 | ],
519 | "execution_count": 0,
520 | "outputs": [
521 | {
522 | "output_type": "execute_result",
523 | "data": {
524 | "text/plain": [
525 | "array([[-7.84368672e-02, 2.84423033e-01, 8.94427191e-01,\n",
526 | " 2.15138396e-01, -2.68931121e-02, -2.56794523e-01],\n",
527 | " [-1.56873734e-01, 5.68846066e-01, -4.47213595e-01,\n",
528 | " 4.30276793e-01, -5.37862243e-02, -5.13589047e-01],\n",
529 | " [-1.42622354e-01, -1.37930417e-02, 4.16333634e-17,\n",
530 | " -6.53519444e-01, 4.77828945e-01, -5.69263078e-01],\n",
531 | " [-7.28804669e-01, -5.53499910e-01, 3.33066907e-16,\n",
532 | " 1.56161345e-01, -2.92700697e-01, -2.28957508e-01],\n",
533 | " [-1.47853320e-01, -1.75304609e-01, 1.04083409e-16,\n",
534 | " 4.87733411e-01, 8.24315866e-01, 1.73283476e-01],\n",
535 | " [-6.29190197e-01, 5.08166890e-01, -4.44089210e-16,\n",
536 | " -2.81459486e-01, 5.37862243e-02, 5.13589047e-01]])"
537 | ]
538 | },
539 | "metadata": {
540 | "tags": []
541 | },
542 | "execution_count": 8
543 | }
544 | ]
545 | },
546 | {
547 | "cell_type": "code",
548 | "metadata": {
549 | "id": "lmxB_JViXFAF",
550 | "colab_type": "code",
551 | "outputId": "9cdcba5c-74f8-42e4-9a4b-fad1c06b83cd",
552 | "colab": {
553 | "base_uri": "https://localhost:8080/",
554 | "height": 35
555 | }
556 | },
557 | "source": [
558 | "sigma"
559 | ],
560 | "execution_count": 0,
561 | "outputs": [
562 | {
563 | "output_type": "execute_result",
564 | "data": {
565 | "text/plain": [
566 | "array([4.47696617, 2.7519661 , 2. , 1.17620428])"
567 | ]
568 | },
569 | "metadata": {
570 | "tags": []
571 | },
572 | "execution_count": 13
573 | }
574 | ]
575 | },
576 | {
577 | "cell_type": "code",
578 | "metadata": {
579 | "id": "AiXURUScXMsj",
580 | "colab_type": "code",
581 | "outputId": "d6fc0576-8bf1-491e-caed-02035079f0d3",
582 | "colab": {
583 | "base_uri": "https://localhost:8080/",
584 | "height": 161
585 | }
586 | },
587 | "source": [
588 | "VT"
589 | ],
590 | "execution_count": 0,
591 | "outputs": [
592 | {
593 | "output_type": "execute_result",
594 | "data": {
595 | "text/plain": [
596 | "array([[-1.75579600e-01, -3.51159201e-01, -6.38515454e-01,\n",
597 | " -6.61934313e-01],\n",
598 | " [ 3.91361272e-01, 7.82722545e-01, -3.79579831e-02,\n",
599 | " -4.82432341e-01],\n",
600 | " [ 8.94427191e-01, -4.47213595e-01, 0.00000000e+00,\n",
601 | " 8.32667268e-17],\n",
602 | " [ 1.26523351e-01, 2.53046702e-01, -7.68672366e-01,\n",
603 | " 5.73674125e-01]])"
604 | ]
605 | },
606 | "metadata": {
607 | "tags": []
608 | },
609 | "execution_count": 14
610 | }
611 | ]
612 | },
613 | {
614 | "cell_type": "code",
615 | "metadata": {
616 | "id": "DKOxld5lXRCK",
617 | "colab_type": "code",
618 | "outputId": "0832796a-9952-4f39-e1ef-79347c495d16",
619 | "colab": {
620 | "base_uri": "https://localhost:8080/",
621 | "height": 53
622 | }
623 | },
624 | "source": [
625 | "# truncated singular value decomposition\n",
626 | "\n",
627 | "svd = TruncatedSVD(n_components=3, n_iter=7, random_state=42)\n",
628 | "svd.fit(X) "
629 | ],
630 | "execution_count": 0,
631 | "outputs": [
632 | {
633 | "output_type": "execute_result",
634 | "data": {
635 | "text/plain": [
636 | "TruncatedSVD(algorithm='randomized', n_components=3, n_iter=7, random_state=42,\n",
637 | " tol=0.0)"
638 | ]
639 | },
640 | "metadata": {
641 | "tags": []
642 | },
643 | "execution_count": 16
644 | }
645 | ]
646 | },
647 | {
648 | "cell_type": "code",
649 | "metadata": {
650 | "id": "btnGrF0LXzZI",
651 | "colab_type": "code",
652 | "outputId": "ba85127a-4fa6-4092-828c-9036c47f82f6",
653 | "colab": {
654 | "base_uri": "https://localhost:8080/",
655 | "height": 35
656 | }
657 | },
658 | "source": [
659 | "print(svd.explained_variance_ratio_)"
660 | ],
661 | "execution_count": 0,
662 | "outputs": [
663 | {
664 | "output_type": "stream",
665 | "text": [
666 | "[0.39945801 0.34585056 0.18861789]\n"
667 | ],
668 | "name": "stdout"
669 | }
670 | ]
671 | },
672 | {
673 | "cell_type": "code",
674 | "metadata": {
675 | "id": "F1hSe5NxX1zw",
676 | "colab_type": "code",
677 | "outputId": "b0d0b87d-195b-4653-a857-48ff1eca887a",
678 | "colab": {
679 | "base_uri": "https://localhost:8080/",
680 | "height": 35
681 | }
682 | },
683 | "source": [
684 | "print(svd.explained_variance_ratio_.sum())"
685 | ],
686 | "execution_count": 0,
687 | "outputs": [
688 | {
689 | "output_type": "stream",
690 | "text": [
691 | "0.9339264600284481\n"
692 | ],
693 | "name": "stdout"
694 | }
695 | ]
696 | },
697 | {
698 | "cell_type": "code",
699 | "metadata": {
700 | "id": "cV4L2i9WX30R",
701 | "colab_type": "code",
702 | "outputId": "6c313215-d095-41b2-a384-3dc1575729a2",
703 | "colab": {
704 | "base_uri": "https://localhost:8080/",
705 | "height": 35
706 | }
707 | },
708 | "source": [
709 | "print(svd.singular_values_)"
710 | ],
711 | "execution_count": 0,
712 | "outputs": [
713 | {
714 | "output_type": "stream",
715 | "text": [
716 | "[4.47696617 2.7519661 2. ]\n"
717 | ],
718 | "name": "stdout"
719 | }
720 | ]
721 | },
722 | {
723 | "cell_type": "markdown",
724 | "metadata": {
725 | "id": "4CbG9kJXictK",
726 | "colab_type": "text"
727 | },
728 | "source": [
729 | "#### Non-negative Matrix Factorization"
730 | ]
731 | },
732 | {
733 | "cell_type": "code",
734 | "metadata": {
735 | "id": "KcA2Rd4Df_DE",
736 | "colab_type": "code",
737 | "colab": {}
738 | },
739 | "source": [
740 | "def inverse_transform(W, H):\n",
741 | " # reconstruct X from W and H\n",
742 | " return W.dot(H)\n",
743 | "\n",
744 | "def loss(X, X_):\n",
745 | " # compute the squared reconstruction error\n",
746 | " return ((X - X_) * (X - X_)).sum()"
747 | ],
748 | "execution_count": 0,
749 | "outputs": []
750 | },
751 | {
752 | "cell_type": "code",
753 | "metadata": {
754 | "id": "yRXZt6CfYPJq",
755 | "colab_type": "code",
756 | "colab": {}
757 | },
758 | "source": [
759 | "# Algorithm 17.1\n",
760 | "\n",
761 | "class MyNMF:\n",
762 | " def fit(self, X, k, t):\n",
763 | " m, n = X.shape\n",
764 | " \n",
765 | " W = np.random.rand(m, k)\n",
766 | " W = W/W.sum(axis=0)\n",
767 | " \n",
768 | " H = np.random.rand(k, n)\n",
769 | " \n",
770 | " i = 0  # run exactly t iterations\n",
771 | " while i < t:\n",
772 | " \n",
773 | " W = W * X.dot(H.T) / W.dot(H).dot(H.T)\n",
774 | " \n",
775 | " H = H * (W.T).dot(X) / (W.T).dot(W).dot(H)\n",
776 | " \n",
777 | " i += 1\n",
778 | " \n",
779 | " return W, H"
780 | ],
781 | "execution_count": 0,
782 | "outputs": []
783 | },
784 | {
785 | "cell_type": "code",
786 | "metadata": {
787 | "id": "zc1IFBBIajXk",
788 | "colab_type": "code",
789 | "colab": {}
790 | },
791 | "source": [
792 | "model = MyNMF()\n",
793 | "W, H = model.fit(X, 3, 200)"
794 | ],
795 | "execution_count": 0,
796 | "outputs": []
797 | },
798 | {
799 | "cell_type": "code",
800 | "metadata": {
801 | "id": "EYo7JyXXbNpJ",
802 | "colab_type": "code",
803 | "outputId": "716dbb03-5229-4c93-a75b-81d07ff4be85",
804 | "colab": {
805 | "base_uri": "https://localhost:8080/",
806 | "height": 125
807 | }
808 | },
809 | "source": [
810 | "W"
811 | ],
812 | "execution_count": 0,
813 | "outputs": [
814 | {
815 | "output_type": "execute_result",
816 | "data": {
817 | "text/plain": [
818 | "array([[7.80747563e-282, 1.59350147e-085, 4.46818285e-001],\n",
819 | " [1.36585053e-173, 1.82432253e-099, 8.93962574e-001],\n",
820 | " [5.15393770e-058, 5.75011993e-001, 4.19426683e-039],\n",
821 | " [1.73830597e+000, 1.09961986e+000, 4.87814473e-031],\n",
822 | " [5.89525831e-001, 3.53091403e-065, 1.51262003e-035],\n",
823 | " [5.54346760e-001, 1.12753836e+000, 1.11381284e+000]])"
824 | ]
825 | },
826 | "metadata": {
827 | "tags": []
828 | },
829 | "execution_count": 113
830 | }
831 | ]
832 | },
833 | {
834 | "cell_type": "code",
835 | "metadata": {
836 | "id": "1cEFDsgXbnXZ",
837 | "colab_type": "code",
838 | "outputId": "e3c8eaf0-bcd8-48f5-edee-57f643b07fed",
839 | "colab": {
840 | "base_uri": "https://localhost:8080/",
841 | "height": 71
842 | }
843 | },
844 | "source": [
845 | "H"
846 | ],
847 | "execution_count": 0,
848 | "outputs": [
849 | {
850 | "output_type": "execute_result",
851 | "data": {
852 | "text/plain": [
853 | "array([[3.02557029e-05, 2.18916926e-04, 5.10981068e-02, 1.69486742e+00],\n",
854 | " [3.08284998e-03, 5.45494813e-03, 1.73785466e+00, 4.89822454e-02],\n",
855 | " [8.94680268e-01, 1.79000896e+00, 1.09648099e-02, 4.54347640e-03]])"
856 | ]
857 | },
858 | "metadata": {
859 | "tags": []
860 | },
861 | "execution_count": 114
862 | }
863 | ]
864 | },
865 | {
866 | "cell_type": "code",
867 | "metadata": {
868 | "id": "JFqGat0JdVlL",
869 | "colab_type": "code",
870 | "outputId": "567c82ff-997f-49e7-b829-f9d4c828f9c9",
871 | "colab": {
872 | "base_uri": "https://localhost:8080/",
873 | "height": 125
874 | }
875 | },
876 | "source": [
877 | "# reconstruction\n",
878 | "X_ = inverse_transform(W, H);X_"
879 | ],
880 | "execution_count": 0,
881 | "outputs": [
882 | {
883 | "output_type": "execute_result",
884 | "data": {
885 | "text/plain": [
886 | "array([[3.99759503e-01, 7.99808736e-01, 4.89927756e-03, 2.03010833e-03],\n",
887 | " [7.99810675e-01, 1.60020102e+00, 9.80212969e-03, 4.06169786e-03],\n",
888 | " [1.77267571e-03, 3.13666059e-03, 9.99287272e-01, 2.81653785e-02],\n",
889 | " [3.44255674e-03, 6.37891391e-03, 1.99980365e+00, 3.00006000e+00],\n",
890 | " [1.78365184e-05, 1.29057183e-04, 3.01236539e-02, 9.99168124e-01],\n",
891 | " [9.99999171e-01, 2.00000698e+00, 2.00003661e+00, 9.99834205e-01]])"
892 | ]
893 | },
894 | "metadata": {
895 | "tags": []
896 | },
897 | "execution_count": 115
898 | }
899 | ]
900 | },
901 | {
902 | "cell_type": "code",
903 | "metadata": {
904 | "id": "FmXCjjnyfcfY",
905 | "colab_type": "code",
906 | "outputId": "819c8029-12d3-4344-e5af-b352b51507a3",
907 | "colab": {
908 | "base_uri": "https://localhost:8080/",
909 | "height": 35
910 | }
911 | },
912 | "source": [
913 | "# reconstruction error\n",
914 | "\n",
915 | "loss(X, X_)"
916 | ],
917 | "execution_count": 0,
918 | "outputs": [
919 | {
920 | "output_type": "execute_result",
921 | "data": {
922 | "text/plain": [
923 | "4.001908238790242"
924 | ]
925 | },
926 | "metadata": {
927 | "tags": []
928 | },
929 | "execution_count": 116
930 | }
931 | ]
932 | },
933 | {
934 | "cell_type": "markdown",
935 | "metadata": {
936 | "id": "dGu-tGDxhEcf",
937 | "colab_type": "text"
938 | },
939 | "source": [
940 | "### Computing with sklearn"
941 | ]
942 | },
943 | {
944 | "cell_type": "code",
945 | "metadata": {
946 | "id": "sLN4FLmvb6tt",
947 | "colab_type": "code",
948 | "colab": {}
949 | },
950 | "source": [
951 | "from sklearn.decomposition import NMF\n",
952 | "model = NMF(n_components=3, init='random', max_iter=200, random_state=0)\n",
953 | "W = model.fit_transform(X)\n",
954 | "H = model.components_"
955 | ],
956 | "execution_count": 0,
957 | "outputs": []
958 | },
959 | {
960 | "cell_type": "code",
961 | "metadata": {
962 | "id": "Fm5W6xQ0b_jl",
963 | "colab_type": "code",
964 | "outputId": "d2ae79f6-0fea-47ba-a2ba-1742b78f71b1",
965 | "colab": {
966 | "base_uri": "https://localhost:8080/",
967 | "height": 125
968 | }
969 | },
970 | "source": [
971 | "W"
972 | ],
973 | "execution_count": 0,
974 | "outputs": [
975 | {
976 | "output_type": "execute_result",
977 | "data": {
978 | "text/plain": [
979 | "array([[0. , 0.53849498, 0. ],\n",
980 | " [0. , 1.07698996, 0. ],\n",
981 | " [0.69891361, 0. , 0. ],\n",
982 | " [1.39782972, 0. , 1.97173859],\n",
983 | " [0. , 0. , 0.65783848],\n",
984 | " [1.39783002, 1.34623756, 0.65573258]])"
985 | ]
986 | },
987 | "metadata": {
988 | "tags": []
989 | },
990 | "execution_count": 108
991 | }
992 | ]
993 | },
994 | {
995 | "cell_type": "code",
996 | "metadata": {
997 | "id": "wheFi8ZwcCY9",
998 | "colab_type": "code",
999 | "outputId": "9dd57216-889a-4cc4-ccb8-82e790595d59",
1000 | "colab": {
1001 | "base_uri": "https://localhost:8080/",
1002 | "height": 71
1003 | }
1004 | },
1005 | "source": [
1006 | "H"
1007 | ],
1008 | "execution_count": 0,
1009 | "outputs": [
1010 | {
1011 | "output_type": "execute_result",
1012 | "data": {
1013 | "text/plain": [
1014 | "array([[0.00000000e+00, 0.00000000e+00, 1.43078959e+00, 1.71761682e-03],\n",
1015 | " [7.42810976e-01, 1.48562195e+00, 0.00000000e+00, 3.30264644e-04],\n",
1016 | " [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.52030365e+00]])"
1017 | ]
1018 | },
1019 | "metadata": {
1020 | "tags": []
1021 | },
1022 | "execution_count": 109
1023 | }
1024 | ]
1025 | },
1026 | {
1027 | "cell_type": "code",
1028 | "metadata": {
1029 | "id": "9hfj3bRXgHRb",
1030 | "colab_type": "code",
1031 | "outputId": "3be1a4d8-a160-4200-d4e7-f5c92f2aa1af",
1032 | "colab": {
1033 | "base_uri": "https://localhost:8080/",
1034 | "height": 125
1035 | }
1036 | },
1037 | "source": [
1038 | "X__ = inverse_transform(W, H);X__"
1039 | ],
1040 | "execution_count": 0,
1041 | "outputs": [
1042 | {
1043 | "output_type": "execute_result",
1044 | "data": {
1045 | "text/plain": [
1046 | "array([[3.99999983e-01, 7.99999966e-01, 0.00000000e+00, 1.77845853e-04],\n",
1047 | " [7.99999966e-01, 1.59999993e+00, 0.00000000e+00, 3.55691707e-04],\n",
1048 | " [0.00000000e+00, 0.00000000e+00, 9.99998311e-01, 1.20046577e-03],\n",
1049 | " [0.00000000e+00, 0.00000000e+00, 2.00000021e+00, 3.00004230e+00],\n",
1050 | " [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.00011424e+00],\n",
1051 | " [1.00000003e+00, 2.00000007e+00, 2.00000064e+00, 9.99758185e-01]])"
1052 | ]
1053 | },
1054 | "metadata": {
1055 | "tags": []
1056 | },
1057 | "execution_count": 110
1058 | }
1059 | ]
1060 | },
1061 | {
1062 | "cell_type": "code",
1063 | "metadata": {
1064 | "id": "-iALOTKfgKzP",
1065 | "colab_type": "code",
1066 | "outputId": "80c73327-dd36-403f-aa20-d4e2d3f796a7",
1067 | "colab": {
1068 | "base_uri": "https://localhost:8080/",
1069 | "height": 35
1070 | }
1071 | },
1072 | "source": [
1073 | "loss(X, X__)"
1074 | ],
1075 | "execution_count": 0,
1076 | "outputs": [
1077 | {
1078 | "output_type": "execute_result",
1079 | "data": {
1080 | "text/plain": [
1081 | "4.0000016725824565"
1082 | ]
1083 | },
1084 | "metadata": {
1085 | "tags": []
1086 | },
1087 | "execution_count": 111
1088 | }
1089 | ]
1090 | }
1091 | ]
1092 | }
--------------------------------------------------------------------------------
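To make the factorization $X \approx U_k(\Sigma_k V_k^T)$ from the LSA chapter concrete, here is a minimal numpy sketch applied to the same 6×4 word-document matrix used in the notebook's example, with $k=3$ topics as in its TruncatedSVD call:

```python
import numpy as np

# Word-document matrix from the example above (6 words, 4 documents).
X = np.array([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0],
              [0, 0, 2, 3], [0, 0, 0, 1], [1, 2, 2, 1]], dtype=float)

k = 3  # number of topics, matching the TruncatedSVD example

# Truncated SVD: X ≈ U_k Σ_k V_k^T
U, sigma, VT = np.linalg.svd(X, full_matrices=False)
U_k = U[:, :k]                    # topic vector space (m x k)
Y = np.diag(sigma[:k]) @ VT[:k]   # texts in topic space: Σ_k V_k^T (k x n)

# The squared Frobenius reconstruction error equals the sum of the
# squares of the discarded singular values.
X_approx = U_k @ Y
err = np.linalg.norm(X - X_approx) ** 2
print(err)
```

The columns of `U_k` are the topic vectors, and each column of `Y` is one document expressed in the topic space; truncating at $k$ discards only the smallest singular values.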
/第18章 概率潜在语义分析/PLSA.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "PLSA.ipynb",
7 | "version": "0.3.2",
8 | "provenance": [],
9 | "collapsed_sections": []
10 | },
11 | "kernelspec": {
12 | "name": "python3",
13 | "display_name": "Python 3"
14 | }
15 | },
16 | "cells": [
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {
20 | "id": "0I-Es-jovJzm",
21 | "colab_type": "text"
22 | },
23 | "source": [
24 | "# Probabilistic Latent Semantic Analysis"
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "metadata": {
30 | "id": "EeHek0KrvNbO",
31 | "colab_type": "text"
32 | },
33 | "source": [
34 | "Probabilistic latent semantic analysis (PLSA), also called probabilistic latent semantic indexing (PLSI), is an unsupervised learning method that uses a probabilistic generative model to perform topic analysis on a collection of texts.\n",
35 | "\n",
36 | "The model's most distinctive feature is that topics are represented by a latent variable; the whole model describes the process in which texts generate topics and topics generate words, yielding the word-text co-occurrence data. Each text is assumed to be determined by a topic distribution, and each topic by a word distribution."
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "metadata": {
42 | "id": "ZpnNY-eRwjq3",
43 | "colab_type": "text"
44 | },
45 | "source": [
46 | "### **18.1.2 The Generative Model**\n",
47 | "\n",
48 | "Suppose there is a word set $W = $ {$w_{1}, w_{2}, ..., w_{M}$}, where $M$ is the number of words; a text (index) set $D = $ {$d_{1}, d_{2}, ..., d_{N}$}, where $N$ is the number of texts; and a topic set $Z = $ {$z_{1}, z_{2}, ..., z_{K}$}, where $K$ is a pre-specified number of topics. The random variable $w$ takes values in the word set, the random variable $d$ in the text set, and the random variable $z$ in the topic set. The probability distribution $P(d)$, the conditional distribution $P(z|d)$, and the conditional distribution $P(w|z)$ are all multinomial, where $P(d)$ is the probability of generating text $d$, $P(z|d)$ the probability that text $d$ generates topic $z$, and $P(w|z)$ the probability that topic $z$ generates word $w$.\n",
49 | "\n",
50 | " Each text $d$ has its own topic distribution $P(z|d)$, and each topic $z$ has its own word distribution $P(w|z)$; in other words, **the content of a text is determined by its topics, and the content of a topic is determined by its words**.\n",
51 | " \n",
52 | " The generative model produces the text-word co-occurrence data by the following steps: \n",
53 | " (1) following the distribution $P(d)$, randomly draw a text $d$ from the text (index) set, generating $N$ texts in total; for each text, perform the following operations; \n",
54 | " (2) given text $d$, following the conditional distribution $P(z|d)$, randomly draw a topic $z$ from the topic set, generating $L$ topics, where $L$ is the text length; \n",
55 | " (3) given topic $z$, following the conditional distribution $P(w|z)$, randomly draw a word $w$ from the word set. \n",
56 | " \n",
57 | " Note that texts are assumed equal in length for convenience of exposition; this assumption is not needed in practice."
58 | ]
59 | },
60 | {
61 | "cell_type": "markdown",
62 | "metadata": {
63 | "id": "_YwFFCuCgugI",
64 | "colab_type": "text"
65 | },
66 | "source": [
67 | "In the generative model, the word variable $w$ and the text variable $d$ are observed variables, while the topic variable $z$ is a latent variable; that is, the model generates the set of word-topic-text triples ($w, z, d$), but what is observed is the set of word-text pairs ($w, d$). The observed data are represented as a word-text matrix $T$, whose rows correspond to words and columns to texts, and whose elements are the occurrence counts of the word-text pairs ($w, d$). \n",
68 | "\n",
69 | "From the data-generating process it follows that the probability of the text-word co-occurrence data $T$ is the product of the generation probabilities of all word-text pairs ($w,d$): \n",
70 | "\n",
71 | "$P(T) = \\prod_{w,d}P(w,d)^{n(w,d)}$ \n",
72 | "\n",
73 | "Here $n(w,d)$ is the number of occurrences of ($w,d$); the total number of word-text pairs is $N \\times L$. The generation probability of each word-text pair ($w,d$) is given by the following formula: \n",
74 | "\n",
75 | "$P(w,d) = P(d)P(w|d)$ \n",
76 | "\n",
77 | "$= P(d)\\sum_{z}P(w,z|d)$ \n",
78 | "\n",
79 | "$=P(d)\\sum_{z}P(z|d)P(w|z)$"
80 | ]
81 | },
82 | {
83 | "cell_type": "markdown",
84 | "metadata": {
85 | "id": "rIUH6dILnmQs",
86 | "colab_type": "text"
87 | },
88 | "source": [
89 | "### **18.1.3 The Co-occurrence Model**\n",
90 | "\n",
91 | "$P(w,d) = \\sum_{z\\in Z}P(z)P(w|z)P(d|z)$"
92 | ]
93 | },
94 | {
95 | "cell_type": "markdown",
96 | "metadata": {
97 | "id": "JSt5kq4LoFJT",
98 | "colab_type": "text"
99 | },
100 | "source": [
101 | "Although the generative model and the co-occurrence model are equivalent as probability formulas, they have different properties: the generative model describes the process by which the text-word co-occurrence data are generated, while the co-occurrence model describes the pattern possessed by the text-word co-occurrence data. \n",
102 | "\n",
103 | "如果直接定义单词与文本的共现概率 $P(w,d)$, 模型参数的个数是 $O(M*N)$, 其中 $M$ 是单词数, $N$ 是文本数。 概率潜在语义分析的生成模型和共现模型的参数个数是 $O(M*K + N*K)$, 其中 $K$ 是话题数。 现实中 $K< maxerr:\n",
70 | " ro = r.copy()\n",
71 | " # calculate each pagerank at a time\n",
72 | " for i in range(0,n):\n",
73 | " # inlinks of state i\n",
74 | " Ai = np.array(A[:,i].todense())[:,0]\n",
75 | " # account for sink states\n",
76 | " Di = sink / float(n)\n",
77 | " # account for teleportation to state i\n",
78 | " Ei = np.ones(n) / float(n)\n",
79 | "\n",
80 | " r[i] = ro.dot( Ai*s + Di*s + Ei*(1-s) )\n",
81 | "\n",
82 | " # return normalized pagerank\n",
83 | " return r/float(sum(r))"
84 | ],
85 | "execution_count": 0,
86 | "outputs": []
87 | },
88 | {
89 | "cell_type": "code",
90 | "metadata": {
91 | "id": "Ds-wQEFFZ1F7",
92 | "colab_type": "code",
93 | "colab": {
94 | "base_uri": "https://localhost:8080/",
95 | "height": 53
96 | },
97 | "outputId": "b2860902-8712-4583-ab47-bec602c6791b"
98 | },
99 | "source": [
100 | "# Example extracted from 'Introduction to Information Retrieval'\n",
101 | "G = np.array([[0,0,1,0,0,0,0],\n",
102 | " [0,1,1,0,0,0,0],\n",
103 | " [1,0,1,1,0,0,0],\n",
104 | " [0,0,0,1,1,0,0],\n",
105 | " [0,0,0,0,0,0,1],\n",
106 | " [0,0,0,0,0,1,1],\n",
107 | " [0,0,0,1,1,0,1]])\n",
108 | "print(pageRank(G,s=.86))"
109 | ],
110 | "execution_count": 6,
111 | "outputs": [
112 | {
113 | "output_type": "stream",
114 | "text": [
115 | "[0.12727557 0.03616954 0.12221594 0.22608452 0.28934412 0.03616954\n",
116 | " 0.16274076]\n"
117 | ],
118 | "name": "stdout"
119 | }
120 | ]
121 | }
122 | ]
123 | }
--------------------------------------------------------------------------------
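The PLSA chapter above defines the generative and co-occurrence models but includes no fitting code in this chunk; the parameters are commonly estimated with the EM algorithm. Below is a minimal sketch for the co-occurrence model $P(w,d)=\sum_{z}P(z)P(w|z)P(d|z)$; the function and variable names are our own assumptions, not the book's code:

```python
import numpy as np

def plsa_em(N, K, iters=100, seed=0):
    """EM for the PLSA co-occurrence model P(w,d) = sum_z P(z)P(w|z)P(d|z).

    N is the word-text count matrix (M words x D texts).
    A minimal sketch; not an optimized implementation.
    """
    rng = np.random.default_rng(seed)
    M, D = N.shape
    p_z = np.full(K, 1.0 / K)                                 # P(z)
    p_w_z = rng.random((M, K)); p_w_z /= p_w_z.sum(axis=0)    # P(w|z)
    p_d_z = rng.random((D, K)); p_d_z /= p_d_z.sum(axis=0)    # P(d|z)
    for _ in range(iters):
        # E-step: posterior P(z|w,d) ∝ P(z)P(w|z)P(d|z), shape (M, D, K)
        post = p_z * p_w_z[:, None, :] * p_d_z[None, :, :]
        post /= post.sum(axis=2, keepdims=True) + 1e-12
        # M-step: re-estimate the three multinomial distributions
        nz = np.einsum('md,mdk->k', N, post)      # expected counts per topic
        p_w_z = np.einsum('md,mdk->mk', N, post) / (nz + 1e-12)
        p_d_z = np.einsum('md,mdk->dk', N, post) / (nz + 1e-12)
        p_z = nz / N.sum()
    return p_z, p_w_z, p_d_z

# Toy count matrix (4 words x 3 texts), illustrative only.
N = np.array([[3, 0, 1], [0, 2, 0], [1, 1, 4], [2, 0, 0]], dtype=float)
p_z, p_w_z, p_d_z = plsa_em(N, K=2, iters=50)
```

Each column of `p_w_z` and `p_d_z` sums to one, matching the constraint that $P(w|z)$ and $P(d|z)$ are probability distributions for every topic $z$.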