├── Errata └── Errata.md ├── README.md ├── charpter10_AdaBoost └── adaboost.ipynb ├── charpter11_GBDT ├── cart.py ├── gbdt.ipynb └── utils.py ├── charpter12_XGBoost ├── cart.py ├── utils.py └── xgboost.ipynb ├── charpter13_LightGBM └── lightgbm.ipynb ├── charpter14_CatBoost ├── adult.data └── catboost.ipynb ├── charpter15_random_forest ├── cart.py └── random_forest.ipynb ├── charpter16_ensemble_compare └── compare_and_tuning.ipynb ├── charpter17_kmeans └── kmeans.ipynb ├── charpter18_PCA └── pca.ipynb ├── charpter19_SVD ├── louwill.jpg └── svd.ipynb ├── charpter1_ml_start └── NumPy_sklearn.ipynb ├── charpter20_MEM └── max_entropy_model.ipynb ├── charpter21_Bayesian_models ├── bayesian_network.ipynb └── naive_bayes.ipynb ├── charpter22_EM └── em.ipynb ├── charpter23_HMM └── hmm.ipynb ├── charpter24_CRF └── crf.ipynb ├── charpter25_MCMC └── mcmc.ipynb ├── charpter2_linear_regression └── linear_regression.ipynb ├── charpter3_logistic_regression └── logistic_regression.ipynb ├── charpter4_regression_expansion ├── example.dat ├── lasso.ipynb └── ridge.ipynb ├── charpter5_LDA └── LDA.ipynb ├── charpter6_knn └── knn.ipynb ├── charpter7_decision_tree ├── CART.ipynb ├── ID3.ipynb ├── example_data.csv └── utils.py ├── charpter8_neural_networks ├── neural_networks.ipynb ├── perceptron.ipynb └── perceptron.py ├── charpter9_SVM ├── hard_margin_svm.ipynb ├── non-linear_svm.ipynb └── soft_margin_svm.ipynb └── pic ├── cover.jpg ├── ml_xmind.png ├── ppt_1.png ├── ppt_2.png ├── ppt_3.png ├── ppt_4.png └── ppt_5.png /Errata/Errata.md: -------------------------------------------------------------------------------- 1 | #### 第一版第一次印刷勘误 2 | | 序号 | 所在页码 | 具体问题 | 勘误类型 | 修改 | 3 | | ---- | ---- | ---- | ---- | ---- | 4 | | 1 | 彩插第1页 | 图3-6逻辑回归描述与正文不统一 | 文字或格式错误 | 逻辑回归应改为对数几率回归 | 5 | | 2 | 11 | 代码第4行 | 技术错误 | #矩阵点乘 应改为 #矩阵乘法 | 6 | | 3 | 32 | 式3-9 | 技术错误 | 中间+号应为x号,式3-10应改为求和式并补充负号,L应令为-lnp(y\|x) | 7 | | 4 | 37 | 代码3-5倒数第4和第6行 | 文字或格式错误 | lables 应改为labels | 8 | | 5 | 47 | 代码4-3倒数第19行参数少了个0.1 
| 技术错误 | 应补全为l1_loss(x,y,w,b,0.1) | 9 | | 6 | 52 | 代码4-8第7行参数少了个0.1 | 技术错误 | 应补全为l2_loss(x,y,w,b,0.1) | 10 | | 7 | 68 | 代码6-6倒数第一行多一个右括号 | 文字或格式错误 | 去掉该括号 | 11 | | 8 | 71 | 该页Kneighbors中的n应大写 | 文字或格式错误 | 应改为KNeighbors | 12 | | 9 | 77 | 表7-3数字统计错误 | 技术错误 | 应改为正确的统计数值:将“晴”和“雨”两行数值互换,相应的表格后第一个式子E(3,2)改为E(2,3),E(2,3)改为E(3,2) | 13 | | 10 | 79 | 倒数第四行名词错误 | 技术错误 | 应将信息增益比改为基尼指数 | 14 | | 11 | 90 | 代码7-10中部分变量命名不统一 | 技术错误 | 应统一best_subsets、left_branch、feature_ix、leaf_value_calc等变量 | 15 | | 12 | 92 | 代码7-11倒数第3行变量有误 | 技术错误 | 应将impurity_calculation改为gini_impurity_calc,\_leaf_value_calculation改为leaf_value_calc | 16 | | 13 | 111 | 代码8-12倒数第9行 | 技术错误 | 多了一个parameter参数,应去掉 | 17 | | 14 | 119 | 第一段第一行L(x,α,β)有误 | 技术错误 | 应改为L(w,b,α) | 18 | | 15 | 121 | 式9-35缺少q矩阵部分 | 技术错误 | 应补充q矩阵 | 19 | | 16 | 167 | 图12-1倒数第2行最优点式子有误 | 技术错误 | 应改为式12-18 | 20 | | 17 | 211 | 第二段n个样本描述有误 | 技术错误 | 应改为m个样本,相应的式17-8求和上标改为n| 21 | | 18 | 216 | 代码17-7第12行函数参数写反 | 技术错误 | 应对调修改 | 22 | 23 | **注意:以上勘误在2022.2第二次印刷版本中均已更正!** 24 | 25 | #### 第一版第二次印刷需勘误之处 26 | | 序号 | 所在页码 | 具体问题 | 勘误类型 | 修改 | 27 | | ---- | ---- | ---- | ---- | ---- | 28 | | 1 | 21 | 式2-8和2-9后两项w漏掉转置T | 技术错误 | 应补上转置T| 29 | | 2 | 21 | 式2-6少一列 | 技术错误 | 应在矩阵最后补充1向量列,相应的描述应改为m*(d+1)维 | 30 | | 3 | 29 | 代码2-8输出第二行 | 技术错误 | 应保留两位小数:0.54 | 31 | | 4 | 33 | 第3段第3行“则式(3-13)可表示为”应补充描述 | 建议 | 应补充描述为“则式(3-13)似然项可表示为” | 32 | | 5 | 40 | 代码3-9第一行和最后一行函数名有误 | 技术错误 | 应统一为plot_decision_boundary | 33 | | 6 | 99 | 第2段第2行漏掉>0且式8-3w\*x+b缺少绝对值符号 | 技术错误 | 应改为“都有-y_i(wx_i+b)>0成立”并补充式8-3绝对值符号 | 34 | | 7 | 145 | 倒数第2段中的t | 技术错误 | t应改为T | 35 | | 8 | 155 | 最后一段式(11-10) | 文字或格式错误 | (11-10)应改为(11-9)| 36 | | 9 | 169-170 | 代码12-1倒数第2、3行impurity_calculation、\_leaf_value_calculation与决策树一章变量不统一 | 文字或格式错误 | 应统一为gini_impurity_calc、leaf_value_calc | 37 | | 10 | 225 | 式(19-11)多了一个VTV | 技术错误 | 去掉末尾的VTV即可 | 38 | | 11 | 240 | 式(21-8)分母连乘符号倒置 | 印刷错误 | 应改正 | 39 | | 12 | 241 | 代码21-1倒数第2行和第12行class_condition_prob变量名未统一 | 技术错误 | 统一为prior_condition_prob | 40 | | 13 | 242 | 
代码21-2倒数第5行未缩进,以及prior命名未统一 | 技术错误 | 应缩进该行,并将prior命名统一为class_prior | 41 | | 14 | 248 | 代码21-8倒数第4和第6行student_model和student_infer命名未统一 | 技术错误 | 应统一为letter_model和letter_infer | 42 | 43 | **注意:以上勘误在2022.3第三次印刷版本中均已更正!** 44 | 45 | #### 第一版第三次印刷需勘误之处 46 | | 序号 | 所在页码 | 具体问题 | 勘误类型 | 修改 | 47 | | ---- | ---- | ---- | ---- | ---- | 48 | | 1 | 10 | 代码1-6倒数第6行 | 复制错误 | 第三个False应改为True | 49 | | 2 | 48 | 代码4-5第2行大小写和多参数问题 | 技术错误 | 去掉loss参数输出以及LASSO改为小写lasso | 50 | | 3 | 86 | 倒数第三行均方损失 | 技术错误 | 应改为平方损失 | 51 | | 4 | 104 | 式8-14 1/m | 技术错误 | 应与其他式保持统一,这里可去掉1/m | 52 | | 5 | 126 | 第二段式(9-40) | 技术错误 | 应改为式(9-39) | 53 | | 6 | 211 | 式(17-10) x_j有误 | 复制错误 | 应与17-9一致,改为x^bar_l | 54 | | 7 | 183 | 第一段第二行英文statisitics单词拼写错误 | 文字或格式错误 | 应改为statistics | 55 | | 8 | 225 | 式19-8与式19-9 A与AT写反 | 技术错误 | 式19-8应改为AAT,式19-9应改为ATA,对应公式下的变量和描述也随公式进行修正 | 56 | | 9 | 227 | 19.3.2标题描述不准确 | 建议 | 应改为 “基于SVD的图像压缩” 为宜 | 57 | | 10 | 248 | 代码21-8倒数第5行注释描述不准确 | 技术错误 | 应改为“天赋较好且考试不难的情况下学生获得成绩的好坏” | 58 | | 11 | 253 | 式22-9和后一行\theta_A 和 \theta_B 符号错误 | 技术错误 | 应改为\theta_B 和 \theta_C | 59 | | 12 | 261 | 最后一行红白球写反 | 技术错误 | 应改为“1表示白球,0表示红球” | 60 | | 13 | 264 | 式(23-20)和(23-21)alpha_ij有误 | 技术错误 | 应改为a_ij | 61 | | 14 | 265 | 式23-22、23-24、式23-25局部有误 | 技术错误 | 式23-22、23-24条件概率\|应在Ot后,式23-25倒数第二个式最后一个\|应为逗号,q_j应改为q_i | 62 | | 15 | 277-278 | 277页第5行维特比算法少个法字,式(23-19)和(23-20)有误 | 技术错误 | start应改为stop,M_i应改为M_i+1 | 63 | | 16 | 280 | 式(24-36)和(24-37)下标有误 | 技术错误 | δ_i应改为δ_1,δ_l(j)应改为δ_i(l) | 64 | | 17 | 288 | 式(25-14)后一个α(i,j)有误 | 技术错误 | 应改为 α(j,i) | 65 | 66 | 67 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 机器学习 公式推导与代码实现 2 | 李航老师的《统计学习方法》和周志华老师的西瓜书《机器学习》一直是国内机器学习领域的经典教材。本书在这两本书理论框架的基础上,补充了必要的代码实现思路和逻辑过程。 3 | 4 | 本书在对全部机器学习算法进行分类梳理的基础之上,分别对监督学习单模型、监督学习集成模型、无监督学习模型、概率模型4个大类26个经典算法进行了相对完整的公式推导和必要的代码实现,旨在帮助机器学习入门读者完整地掌握算法细节、实现方法以及内在逻辑。本书可作为《统计学习方法》和西瓜书《机器学习》的补充材料。 5 | 6 | --- 7 | ### 使用说明 8 | 
本仓库为《机器学习 公式推导与代码实现》一书配套代码库,相较于书中代码而言,仓库代码随时保持更新和迭代。目前仓库只开源了全书的代码,全书内容后续也会在仓库中开源。本仓库已经根据书中章节将代码分目录整理好,读者可直接点击相关章节使用该章节代码。 9 | 10 | --- 11 | ### 纸质版 12 | 15 |
16 |
19 | 20 | 购买链接:[京东](https://item.jd.com/13581834.html) | [当当](http://product.dangdang.com/29354670.html) 21 | 22 | --- 23 | ### 配套PPT 24 | 为方便大家更好的使用本书,本书也配套了随书的PPT,购买过纸质书的读者可以在机器学习实验室公众号联系作者获取。 25 | 26 | 29 |
30 |
第1章示例
33 | 34 | 35 | 38 |
39 |
第2章示例
42 | 43 | 44 | 47 |
48 |
第7章示例
51 | 52 | 55 |
56 |
第12章示例
59 | 60 | 61 | 64 |
65 |
第23章示例
68 | 69 | 70 | --- 71 | ### 配套视频讲解(更新中) 72 | 为了帮助广大读者更好地学习和掌握机器学习的一般理论和方法,笔者在PPT基础上同时在为全书配套讲解视频。包括模型的公式手推和代码的讲解。 73 | 74 | 第一章:[机器学习入门](https://www.bilibili.com/video/BV1jR4y1A7aH#reply112207884144) 75 | 76 | --- 77 | ### 全书勘误表 78 | 勘误表:[勘误表](https://github.com/luwill/Machine_Learning_Code_Implementation/blob/master/Errata/Errata.md) 79 | 80 | --- 81 | ### LICENSE 82 | 本项目采用[知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议](https://creativecommons.org/licenses/by-nc-sa/4.0/)进行许可。 83 | -------------------------------------------------------------------------------- /charpter11_GBDT/cart.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from utils import feature_split, calculate_gini 3 | 4 | ### 定义树结点 5 | class TreeNode(): 6 | def __init__(self, feature_i=None, threshold=None, 7 | leaf_value=None, left_branch=None, right_branch=None): 8 | # 特征索引 9 | self.feature_i = feature_i 10 | # 特征划分阈值 11 | self.threshold = threshold 12 | # 叶子节点取值 13 | self.leaf_value = leaf_value 14 | # 左子树 15 | self.left_branch = left_branch 16 | # 右子树 17 | self.right_branch = right_branch 18 | 19 | 20 | ### 定义二叉决策树 21 | class BinaryDecisionTree(object): 22 | ### 决策树初始参数 23 | def __init__(self, min_samples_split=2, min_gini_impurity=float("inf"), 24 | max_depth=float("inf"), loss=None): 25 | # 根结点 26 | self.root = None 27 | # 节点最小分裂样本数 28 | self.min_samples_split = min_samples_split 29 | # 节点初始化基尼不纯度 30 | self.mini_gini_impurity = min_gini_impurity 31 | # 树最大深度 32 | self.max_depth = max_depth 33 | # 基尼不纯度计算函数 34 | self.gini_impurity_calculation = None 35 | # 叶子节点值预测函数 36 | self._leaf_value_calculation = None 37 | # 损失函数 38 | self.loss = loss 39 | 40 | ### 决策树拟合函数 41 | def fit(self, X, y, loss=None): 42 | # 递归构建决策树 43 | self.root = self._build_tree(X, y) 44 | self.loss=None 45 | 46 | ### 决策树构建函数 47 | def _build_tree(self, X, y, current_depth=0): 48 | # 初始化最小基尼不纯度 49 | init_gini_impurity = self.gini_impurity_calculation(y) 50 | # 初始化最佳特征索引和阈值 51 | 
best_criteria = None 52 | # 初始化数据子集 53 | best_sets = None 54 | 55 | # 合并输入和标签 56 | Xy = np.concatenate((X, y), axis=1) 57 | # 获取样本数和特征数 58 | n_samples, n_features = X.shape 59 | # 设定决策树构建条件 60 | # 训练样本数量大于节点最小分裂样本数且当前树深度小于最大深度 61 | if n_samples >= self.min_samples_split and current_depth <= self.max_depth: 62 | # 遍历计算每个特征的基尼不纯度 63 | for feature_i in range(n_features): 64 | # 获取第i个特征的所有取值 65 | feature_values = np.expand_dims(X[:, feature_i], axis=1) 66 | # 获取第i个特征的唯一取值 67 | unique_values = np.unique(feature_values) 68 | 69 | # 遍历取值并寻找最佳特征分裂阈值 70 | for threshold in unique_values: 71 | # 特征节点二叉分裂 72 | Xy1, Xy2 = feature_split(Xy, feature_i, threshold) 73 | # 如果分裂后的子集大小都不为0 74 | if len(Xy1) > 0 and len(Xy2) > 0: 75 | # 获取两个子集的标签值 76 | y1 = Xy1[:, n_features:] 77 | y2 = Xy2[:, n_features:] 78 | 79 | # 计算基尼不纯度 80 | impurity = self.impurity_calculation(y, y1, y2) 81 | 82 | # 获取最小基尼不纯度 83 | # 最佳特征索引和分裂阈值 84 | if impurity < init_gini_impurity: 85 | init_gini_impurity = impurity 86 | best_criteria = {"feature_i": feature_i, "threshold": threshold} 87 | best_sets = { 88 | "leftX": Xy1[:, :n_features], 89 | "lefty": Xy1[:, n_features:], 90 | "rightX": Xy2[:, :n_features], 91 | "righty": Xy2[:, n_features:] 92 | } 93 | 94 | # 如果best_criteria不为None,且计算的最小不纯度小于设定的最小不纯度 95 | if best_criteria and init_gini_impurity < self.mini_gini_impurity: 96 | # 分别构建左右子树 97 | left_branch = self._build_tree(best_sets["leftX"], best_sets["lefty"], current_depth + 1) 98 | right_branch = self._build_tree(best_sets["rightX"], best_sets["righty"], current_depth + 1) 99 | return TreeNode(feature_i=best_criteria["feature_i"], threshold=best_criteria[ 100 | "threshold"], left_branch=left_branch, right_branch=right_branch) 101 | 102 | # 计算叶子节点取值 103 | leaf_value = self._leaf_value_calculation(y) 104 | 105 | return TreeNode(leaf_value=leaf_value) 106 | 107 | ### 定义二叉树值预测函数 108 | def predict_value(self, x, tree=None): 109 | if tree is None: 110 | tree = self.root 111 | 112 | # 如果叶子节点已有值,则直接返回已有值 113 | if 
tree.leaf_value is not None: 114 | return tree.leaf_value 115 | 116 | # 选择特征并获取特征值 117 | feature_value = x[tree.feature_i] 118 | 119 | # 判断落入左子树还是右子树 120 | branch = tree.right_branch 121 | if isinstance(feature_value, int) or isinstance(feature_value, float): 122 | if feature_value <= tree.threshold: 123 | branch = tree.left_branch 124 | elif feature_value == tree.threshold: 125 | branch = tree.left_branch 126 | 127 | # 测试子集 128 | return self.predict_value(x, branch) 129 | 130 | ### 数据集预测函数 131 | def predict(self, X): 132 | y_pred = [self.predict_value(sample) for sample in X] 133 | return y_pred 134 | 135 | # CART分类树 136 | class ClassificationTree(BinaryDecisionTree): 137 | ### 定义基尼不纯度计算过程 138 | def _calculate_gini_impurity(self, y, y1, y2): 139 | p = len(y1) / len(y) 140 | gini = calculate_gini(y) 141 | gini_impurity = p * calculate_gini(y1) + (1-p) * calculate_gini(y2) 142 | return gini_impurity 143 | 144 | ### 多数投票 145 | def _majority_vote(self, y): 146 | most_common = None 147 | max_count = 0 148 | for label in np.unique(y): 149 | # 统计多数 150 | count = len(y[y == label]) 151 | if count > max_count: 152 | most_common = label 153 | max_count = count 154 | return most_common 155 | 156 | # 分类树拟合 157 | def fit(self, X, y): 158 | self.impurity_calculation = self._calculate_gini_impurity 159 | self._leaf_value_calculation = self._majority_vote 160 | self.gini_impurity_calculation = calculate_gini 161 | super(ClassificationTree, self).fit(X, y) 162 | 163 | 164 | ### CART回归树 165 | class RegressionTree(BinaryDecisionTree): 166 | def _calculate_weighted_mse(self, y, y1, y2): 167 | var_y1 = np.var(y1, axis=0) 168 | var_y2 = np.var(y2, axis=0) 169 | frac_1 = len(y1) / len(y) 170 | frac_2 = len(y2) / len(y) 171 | # 计算左右子树加权总均方误差 172 | weighted_mse = frac_1 * var_y1 + frac_2 * var_y2 173 | 174 | return sum(weighted_mse) 175 | 176 | # 节点值取平均 177 | def _mean_of_y(self, y): 178 | value = np.mean(y, axis=0) 179 | return value if len(value) > 1 else value[0] 180 | 181 | def 
fit(self, X, y): 182 | self.impurity_calculation = self._calculate_weighted_mse 183 | self._leaf_value_calculation = self._mean_of_y 184 | self.gini_impurity_calculation = lambda y: np.var(y, axis=0) 185 | super(RegressionTree, self).fit(X, y) 186 | -------------------------------------------------------------------------------- /charpter11_GBDT/gbdt.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### GBDT" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "from cart import TreeNode, BinaryDecisionTree, ClassificationTree, RegressionTree\n", 18 | "from sklearn.model_selection import train_test_split\n", 19 | "from sklearn.metrics import mean_squared_error\n", 20 | "from utils import feature_split, calculate_gini, data_shuffle" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "### GBDT定义\n", 30 | "class GBDT(object):\n", 31 | " def __init__(self, n_estimators, learning_rate, min_samples_split,\n", 32 | " min_gini_impurity, max_depth, regression):\n", 33 | " ### 常用超参数\n", 34 | " # 树的棵树\n", 35 | " self.n_estimators = n_estimators\n", 36 | " # 学习率\n", 37 | " self.learning_rate = learning_rate\n", 38 | " # 结点最小分裂样本数\n", 39 | " self.min_samples_split = min_samples_split\n", 40 | " # 结点最小基尼不纯度\n", 41 | " self.min_gini_impurity = min_gini_impurity\n", 42 | " # 最大深度\n", 43 | " self.max_depth = max_depth\n", 44 | " # 默认为回归树\n", 45 | " self.regression = regression\n", 46 | " # 损失为平方损失\n", 47 | " self.loss = SquareLoss()\n", 48 | " # 如果是分类树,需要定义分类树损失函数\n", 49 | " # 这里省略,如需使用,需自定义分类损失函数\n", 50 | " if not self.regression:\n", 51 | " self.loss = None\n", 52 | " # 多棵树叠加\n", 53 | " self.estimators = []\n", 54 | " for i in 
range(self.n_estimators):\n", 55 | " self.estimators.append(RegressionTree(min_samples_split=self.min_samples_split,\n", 56 | " min_gini_impurity=self.min_gini_impurity,\n", 57 | " max_depth=self.max_depth))\n", 58 | " # 拟合方法\n", 59 | " def fit(self, X, y):\n", 60 | " # 初始化预测结果(不同的初始值会影响收敛速度,通常使用均值)\n", 61 | " self.initial_prediction = np.mean(y)\n", 62 | " y_pred = np.full_like(y, self.initial_prediction)\n", 63 | " # 前向分步迭代训练\n", 64 | " for i in range(self.n_estimators):\n", 65 | " gradient = self.loss.gradient(y, y_pred)\n", 66 | " self.estimators[i].fit(X, gradient)\n", 67 | " y_pred -= self.learning_rate * np.array(self.estimators[i].predict(X)).reshape(-1, 1)\n", 68 | " \n", 69 | " # 预测方法\n", 70 | " def predict(self, X):\n", 71 | " # 回归树预测\n", 72 | " y_pred = np.zeros(X.shape[0]).reshape(-1, 1) + self.initial_prediction\n", 73 | " for i in range(self.n_estimators):\n", 74 | " y_pred -= self.learning_rate * np.array(self.estimators[i].predict(X)).reshape(-1, 1)\n", 75 | " # 分类树预测\n", 76 | " if not self.regression:\n", 77 | " # 将预测值转化为概率\n", 78 | " y_pred = np.exp(y_pred) / np.expand_dims(np.sum(np.exp(y_pred), axis=1), axis=1)\n", 79 | " # 转化为预测标签\n", 80 | " y_pred = np.argmax(y_pred, axis=1)\n", 81 | " return y_pred" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 3, 87 | "metadata": { 88 | "code_folding": [] 89 | }, 90 | "outputs": [], 91 | "source": [ 92 | "### GBDT分类树\n", 93 | "class GBDTClassifier(GBDT):\n", 94 | " def __init__(self, n_estimators=200, learning_rate=.5, min_samples_split=2,\n", 95 | " min_info_gain=1e-6, max_depth=2):\n", 96 | " super(GBDTClassifier, self).__init__(n_estimators=n_estimators,\n", 97 | " learning_rate=learning_rate,\n", 98 | " min_samples_split=min_samples_split,\n", 99 | " min_gini_impurity=min_info_gain,\n", 100 | " max_depth=max_depth,\n", 101 | " regression=False)\n", 102 | " # 拟合方法\n", 103 | " def fit(self, X, y):\n", 104 | " super(GBDTClassifier, self).fit(X, y)\n", 105 | " \n", 106 | "### 
GBDT回归树\n", 107 | "class GBDTRegressor(GBDT):\n", 108 | " def __init__(self, n_estimators=15, learning_rate=0.5, min_samples_split=2,\n", 109 | " min_var_reduction=float(\"inf\"), max_depth=3):\n", 110 | " super(GBDTRegressor, self).__init__(n_estimators=n_estimators,\n", 111 | " learning_rate=learning_rate,\n", 112 | " min_samples_split=min_samples_split,\n", 113 | " min_gini_impurity=min_var_reduction,\n", 114 | " max_depth=max_depth,\n", 115 | " regression=True)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 4, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "### 定义回归树的平方损失\n", 125 | "class SquareLoss():\n", 126 | " # 定义平方损失\n", 127 | " def loss(self, y, y_pred):\n", 128 | " return 0.5 * np.power((y - y_pred), 2)\n", 129 | " # 定义平方损失的梯度\n", 130 | " def gradient(self, y, y_pred):\n", 131 | " return -(y - y_pred)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 5, 137 | "metadata": {}, 138 | "outputs": [ 139 | { 140 | "name": "stdout", 141 | "output_type": "stream", 142 | "text": [ 143 | "Mean Squared Error of NumPy GBRT: 13.782569688896887\n" 144 | ] 145 | } 146 | ], 147 | "source": [ 148 | "import pandas as pd\n", 149 | "\n", 150 | "# 波士顿房价数据集的原始 URL\n", 151 | "data_url = \"http://lib.stat.cmu.edu/datasets/boston\"\n", 152 | "\n", 153 | "# 从 URL 加载数据\n", 154 | "raw_df = pd.read_csv(data_url, sep=\"\\s+\", skiprows=22, header=None)\n", 155 | "\n", 156 | "# 处理数据\n", 157 | "data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]]) # 拼接特征数据\n", 158 | "target = raw_df.values[1::2, 2] # 目标变量\n", 159 | "\n", 160 | "# 将数据和目标变量转换为 NumPy 数组\n", 161 | "X = np.array(data)\n", 162 | "y = np.array(target)\n", 163 | "y = y.reshape(-1,1)\n", 164 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)\n", 165 | "# 创建GBRT实例\n", 166 | "model = GBDTRegressor(n_estimators=15, learning_rate=0.5, min_samples_split=4,\n", 167 | " min_var_reduction=float(\"inf\"), max_depth=3)\n", 
168 | "# 模型训练\n", 169 | "model.fit(X_train, y_train)\n", 170 | "# 模型预测\n", 171 | "y_pred = model.predict(X_test)\n", 172 | "# 计算模型预测的均方误差\n", 173 | "mse = mean_squared_error(y_test, y_pred)\n", 174 | "print (\"Mean Squared Error of NumPy GBRT:\", mse)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 6, 180 | "metadata": {}, 181 | "outputs": [ 182 | { 183 | "name": "stdout", 184 | "output_type": "stream", 185 | "text": [ 186 | "Mean Squared Error of sklearn GBRT: 14.54623458175684\n" 187 | ] 188 | } 189 | ], 190 | "source": [ 191 | "# 导入sklearn GBDT模块\n", 192 | "from sklearn.ensemble import GradientBoostingRegressor\n", 193 | "# 创建模型实例\n", 194 | "reg = GradientBoostingRegressor(n_estimators=15, learning_rate=0.5,\n", 195 | " max_depth=3, random_state=0)\n", 196 | "# 模型拟合\n", 197 | "reg.fit(X_train, y_train.ravel())\n", 198 | "# 模型预测\n", 199 | "y_pred = reg.predict(X_test)\n", 200 | "# 计算模型预测的均方误差\n", 201 | "mse = mean_squared_error(y_test, y_pred)\n", 202 | "print (\"Mean Squared Error of sklearn GBRT:\", mse)" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [] 211 | } 212 | ], 213 | "metadata": { 214 | "kernelspec": { 215 | "display_name": "Python 3", 216 | "language": "python", 217 | "name": "python3" 218 | }, 219 | "language_info": { 220 | "codemirror_mode": { 221 | "name": "ipython", 222 | "version": 3 223 | }, 224 | "file_extension": ".py", 225 | "mimetype": "text/x-python", 226 | "name": "python", 227 | "nbconvert_exporter": "python", 228 | "pygments_lexer": "ipython3", 229 | "version": "3.10.14" 230 | }, 231 | "toc": { 232 | "base_numbering": 1, 233 | "nav_menu": {}, 234 | "number_sections": true, 235 | "sideBar": true, 236 | "skip_h1_title": false, 237 | "title_cell": "Table of Contents", 238 | "title_sidebar": "Contents", 239 | "toc_cell": false, 240 | "toc_position": {}, 241 | "toc_section_display": true, 242 | "toc_window_display": 
false 243 | } 244 | }, 245 | "nbformat": 4, 246 | "nbformat_minor": 2 247 | } 248 | -------------------------------------------------------------------------------- /charpter11_GBDT/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | ### 定义二叉特征分裂函数 4 | def feature_split(X, feature_i, threshold): 5 | split_func = None 6 | if isinstance(threshold, int) or isinstance(threshold, float): 7 | split_func = lambda sample: sample[feature_i] <= threshold 8 | else: 9 | split_func = lambda sample: sample[feature_i] == threshold 10 | 11 | X_left = np.array([sample for sample in X if split_func(sample)]) 12 | X_right = np.array([sample for sample in X if not split_func(sample)]) 13 | 14 | return X_left, X_right 15 | 16 | 17 | ### 计算基尼指数 18 | def calculate_gini(y): 19 | y = y.tolist() 20 | probs = [y.count(i)/len(y) for i in np.unique(y)] 21 | gini = sum([p*(1-p) for p in probs]) 22 | return gini 23 | 24 | 25 | ### 打乱数据 26 | def data_shuffle(X, y, seed=None): 27 | if seed: 28 | np.random.seed(seed) 29 | idx = np.arange(X.shape[0]) 30 | np.random.shuffle(idx) 31 | return X[idx], y[idx] 32 | -------------------------------------------------------------------------------- /charpter12_XGBoost/cart.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from utils import feature_split, calculate_gini 3 | 4 | ### 定义树结点 5 | class TreeNode(): 6 | def __init__(self, feature_i=None, threshold=None, 7 | leaf_value=None, left_branch=None, right_branch=None): 8 | # 特征索引 9 | self.feature_i = feature_i 10 | # 特征划分阈值 11 | self.threshold = threshold 12 | # 叶子节点取值 13 | self.leaf_value = leaf_value 14 | # 左子树 15 | self.left_branch = left_branch 16 | # 右子树 17 | self.right_branch = right_branch 18 | 19 | ### 定义二叉决策树 20 | class BinaryDecisionTree(object): 21 | ### 决策树初始参数 22 | def __init__(self, min_samples_split=2, min_gini_impurity=999, 23 | max_depth=float("inf"), 
loss=None): 24 | # 根结点 25 | self.root = None 26 | # 节点最小分裂样本数 27 | self.min_samples_split = min_samples_split 28 | # 节点初始化基尼不纯度 29 | self.min_gini_impurity = min_gini_impurity 30 | # 树最大深度 31 | self.max_depth = max_depth 32 | # 基尼不纯度计算函数 33 | self.gini_impurity_calculation = None 34 | # 叶子节点值预测函数 35 | self._leaf_value_calculation = None 36 | # 损失函数 37 | self.loss = loss 38 | 39 | ### 决策树拟合函数 40 | def fit(self, X, y, loss=None): 41 | # 递归构建决策树 42 | self.root = self._build_tree(X, y) 43 | self.loss=None 44 | 45 | ### 决策树构建函数 46 | def _build_tree(self, X, y, current_depth=0): 47 | # 初始化最小基尼不纯度 48 | init_gini_impurity = 999 49 | # 初始化最佳特征索引和阈值 50 | best_criteria = None 51 | # 初始化数据子集 52 | best_sets = None 53 | 54 | if len(np.shape(y)) == 1: 55 | y = np.expand_dims(y, axis=1) 56 | 57 | # 合并输入和标签 58 | Xy = np.concatenate((X, y), axis=1) 59 | # 获取样本数和特征数 60 | n_samples, n_features = X.shape 61 | # 设定决策树构建条件 62 | # 训练样本数量大于节点最小分裂样本数且当前树深度小于最大深度 63 | if n_samples >= self.min_samples_split and current_depth <= self.max_depth: 64 | # 遍历计算每个特征的基尼不纯度 65 | for feature_i in range(n_features): 66 | # 获取第i特征的所有取值 67 | feature_values = np.expand_dims(X[:, feature_i], axis=1) 68 | # 获取第i个特征的唯一取值 69 | unique_values = np.unique(feature_values) 70 | 71 | # 遍历取值并寻找最佳特征分裂阈值 72 | for threshold in unique_values: 73 | # 特征节点二叉分裂 74 | Xy1, Xy2 = feature_split(Xy, feature_i, threshold) 75 | # 如果分裂后的子集大小都不为0 76 | if len(Xy1) > 0 and len(Xy2) > 0: 77 | # 获取两个子集的标签值 78 | y1 = Xy1[:, n_features:] 79 | y2 = Xy2[:, n_features:] 80 | 81 | # 计算基尼不纯度 82 | impurity = self.impurity_calculation(y, y1, y2) 83 | 84 | # 获取最小基尼不纯度 85 | # 最佳特征索引和分裂阈值 86 | if impurity < init_gini_impurity: 87 | init_gini_impurity = impurity 88 | best_criteria = {"feature_i": feature_i, "threshold": threshold} 89 | best_sets = { 90 | "leftX": Xy1[:, :n_features], 91 | "lefty": Xy1[:, n_features:], 92 | "rightX": Xy2[:, :n_features], 93 | "righty": Xy2[:, n_features:] 94 | } 95 | 96 | # 如果计算的最小不纯度小于设定的最小不纯度 97 | if 
init_gini_impurity < self.min_gini_impurity: 98 | # 分别构建左右子树 99 | left_branch = self._build_tree(best_sets["leftX"], best_sets["lefty"], current_depth + 1) 100 | right_branch = self._build_tree(best_sets["rightX"], best_sets["righty"], current_depth + 1) 101 | return TreeNode(feature_i=best_criteria["feature_i"], threshold=best_criteria["threshold"], left_branch=left_branch, right_branch=right_branch) 102 | 103 | # 计算叶子节点取值 104 | leaf_value = self._leaf_value_calculation(y) 105 | return TreeNode(leaf_value=leaf_value) 106 | 107 | ### 定义二叉树值预测函数 108 | def predict_value(self, x, tree=None): 109 | if tree is None: 110 | tree = self.root 111 | # 如果叶子节点已有值,则直接返回已有值 112 | if tree.leaf_value is not None: 113 | return tree.leaf_value 114 | # 选择特征并获取特征值 115 | feature_value = x[tree.feature_i] 116 | # 判断落入左子树还是右子树 117 | branch = tree.right_branch 118 | if isinstance(feature_value, int) or isinstance(feature_value, float): 119 | if feature_value >= tree.threshold: 120 | branch = tree.left_branch 121 | elif feature_value == tree.threshold: 122 | # 类别型特征与阈值相等时应落入左子树,与feature_split的划分方式保持一致 123 | branch = tree.left_branch 124 | # 测试子集 125 | return self.predict_value(x, branch) 126 | 127 | ### 数据集预测函数 128 | def predict(self, X): 129 | y_pred = [self.predict_value(sample) for sample in X] 130 | return y_pred 131 | 132 | 133 | 134 | class ClassificationTree(BinaryDecisionTree): 135 | ### 定义基尼不纯度计算过程 136 | def _calculate_gini_impurity(self, y, y1, y2): 137 | p = len(y1) / len(y) 138 | gini = calculate_gini(y) 139 | gini_impurity = p * calculate_gini(y1) + (1-p) * calculate_gini(y2) 140 | return gini_impurity 141 | 142 | ### 多数投票 143 | def _majority_vote(self, y): 144 | most_common = None 145 | max_count = 0 146 | for label in np.unique(y): 147 | # 统计多数 148 | count = len(y[y == label]) 149 | if count > max_count: 150 | most_common = label 151 | max_count = count 152 | return most_common 153 | 154 | # 分类树拟合 155 | def fit(self, X, y): 156 | self.impurity_calculation = self._calculate_gini_impurity 157 | self._leaf_value_calculation = 
self._majority_vote 157 | super(ClassificationTree, self).fit(X, y) 158 | 159 | 160 | ### CART回归树 161 | class RegressionTree(BinaryDecisionTree): 162 | # 计算方差减少量 163 | def _calculate_variance_reduction(self, y, y1, y2): 164 | var_tot = np.var(y, axis=0) 165 | var_y1 = np.var(y1, axis=0) 166 | var_y2 = np.var(y2, axis=0) 167 | frac_1 = len(y1) / len(y) 168 | frac_2 = len(y2) / len(y) 169 | # 计算方差减少量 170 | variance_reduction = var_tot - (frac_1 * var_y1 + frac_2 * var_y2) 171 | return sum(variance_reduction) 172 | 173 | # 节点值取平均 174 | def _mean_of_y(self, y): 175 | value = np.mean(y, axis=0) 176 | return value if len(value) > 1 else value[0] 177 | 178 | # 回归树拟合 179 | def fit(self, X, y): 180 | self.impurity_calculation = self._calculate_variance_reduction 181 | self._leaf_value_calculation = self._mean_of_y 182 | super(RegressionTree, self).fit(X, y) 183 | -------------------------------------------------------------------------------- /charpter12_XGBoost/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | ### 定义二叉特征分裂函数 4 | def feature_split(X, feature_i, threshold): 5 | split_func = None 6 | if isinstance(threshold, int) or isinstance(threshold, float): 7 | split_func = lambda sample: sample[feature_i] >= threshold 8 | else: 9 | split_func = lambda sample: sample[feature_i] == threshold 10 | 11 | X_left = np.array([sample for sample in X if split_func(sample)]) 12 | X_right = np.array([sample for sample in X if not split_func(sample)]) 13 | # 直接返回元组,避免np.array组合两个长度不等的子集时报错 14 | return X_left, X_right 15 | 16 | ### 计算基尼指数 17 | def calculate_gini(y): 18 | y = y.tolist() 19 | probs = [y.count(i)/len(y) for i in np.unique(y)] 20 | gini = sum([p*(1-p) for p in probs]) 21 | return gini 22 | 23 | ### 打乱数据 24 | def data_shuffle(X, y, seed=None): 25 | if seed is not None: 26 | np.random.seed(seed) 27 | idx = np.arange(X.shape[0]) 28 | np.random.shuffle(idx) 29 | return X[idx], y[idx] 30 | 31 | ### 类别标签转换 32 | def cat_label_convert(y, 
n_col=None): 32 | if not n_col: 33 | n_col = np.amax(y) + 1 34 | one_hot = np.zeros((y.shape[0], n_col)) 35 | one_hot[np.arange(y.shape[0]), y] = 1 36 | return one_hot 37 | -------------------------------------------------------------------------------- /charpter12_XGBoost/xgboost.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### XGBoost" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "from cart import TreeNode, BinaryDecisionTree\n", 18 | "from sklearn.model_selection import train_test_split\n", 19 | "from sklearn.metrics import accuracy_score\n", 20 | "from utils import cat_label_convert" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 8, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "### XGBoost单棵树类\n", 30 | "class XGBoost_Single_Tree(BinaryDecisionTree):\n", 31 | " # 结点分裂方法\n", 32 | " def node_split(self, y):\n", 33 | " # 中间特征所在列\n", 34 | " feature = int(np.shape(y)[1]/2)\n", 35 | " # 左子树为真实值,右子树为预测值\n", 36 | " y_true, y_pred = y[:, :feature], y[:, feature:]\n", 37 | " return y_true, y_pred\n", 38 | "\n", 39 | " # 信息增益计算方法\n", 40 | " def gain(self, y, y_pred):\n", 41 | " # 梯度计算\n", 42 | " Gradient = np.power((y * self.loss.gradient(y, y_pred)).sum(), 2)\n", 43 | " # Hessian矩阵计算\n", 44 | " Hessian = self.loss.hess(y, y_pred).sum()\n", 45 | " return 0.5 * (Gradient / Hessian)\n", 46 | "\n", 47 | " # 树分裂增益计算\n", 48 | " # 式(12.28)\n", 49 | " def gain_xgb(self, y, y1, y2):\n", 50 | " # 结点分裂\n", 51 | " y_true, y_pred = self.node_split(y)\n", 52 | " y1, y1_pred = self.node_split(y1)\n", 53 | " y2, y2_pred = self.node_split(y2)\n", 54 | " true_gain = self.gain(y1, y1_pred)\n", 55 | " false_gain = self.gain(y2, y2_pred)\n", 56 | " gain = self.gain(y_true, y_pred)\n", 57 | " 
return true_gain + false_gain - gain\n", 58 | "\n", 59 | " # 计算叶子结点最优权重\n", 60 | " def leaf_weight(self, y):\n", 61 | " y_true, y_pred = self.node_split(y)\n", 62 | " # 梯度计算\n", 63 | " gradient = np.sum(y_true * self.loss.gradient(y_true, y_pred), axis=0)\n", 64 | " # hessian矩阵计算\n", 65 | " hessian = np.sum(self.loss.hess(y_true, y_pred), axis=0)\n", 66 | " # 叶子结点得分\n", 67 | " leaf_weight = gradient / hessian\n", 68 | " return leaf_weight\n", 69 | "\n", 70 | " # 树拟合方法\n", 71 | " def fit(self, X, y):\n", 72 | " self.impurity_calculation = self.gain_xgb\n", 73 | " self._leaf_value_calculation = self.leaf_weight\n", 74 | " super(XGBoost_Single_Tree, self).fit(X, y)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 9, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "### 分类损失函数定义\n", 84 | "# 定义Sigmoid类\n", 85 | "class Sigmoid:\n", 86 | " def __call__(self, x):\n", 87 | " return 1 / (1 + np.exp(-x))\n", 88 | "\n", 89 | " def gradient(self, x):\n", 90 | " return self.__call__(x) * (1 - self.__call__(x))\n", 91 | "\n", 92 | "# 定义Logit损失\n", 93 | "class LogisticLoss:\n", 94 | " def __init__(self):\n", 95 | " sigmoid = Sigmoid()\n", 96 | " self._func = sigmoid\n", 97 | " self._grad = sigmoid.gradient\n", 98 | " \n", 99 | " # 定义损失函数形式\n", 100 | " def loss(self, y, y_pred):\n", 101 | " y_pred = np.clip(y_pred, 1e-15, 1 - 1e-15)\n", 102 | " p = self._func(y_pred)\n", 103 | " return y * np.log(p) + (1 - y) * np.log(1 - p)\n", 104 | "\n", 105 | " # 定义一阶梯度\n", 106 | " def gradient(self, y, y_pred):\n", 107 | " p = self._func(y_pred)\n", 108 | " return -(y - p)\n", 109 | "\n", 110 | " # 定义二阶梯度\n", 111 | " def hess(self, y, y_pred):\n", 112 | " p = self._func(y_pred)\n", 113 | " return p * (1 - p)" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 10, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "### XGBoost定义\n", 123 | "class XGBoost:\n", 124 | " def __init__(self, n_estimators=300, 
learning_rate=0.001, \n", 125 | " min_samples_split=2,\n", 126 | " min_gini_impurity=999, \n", 127 | " max_depth=2):\n", 128 | " # Number of trees\n", 129 | " self.n_estimators = n_estimators\n", 130 | " # Learning rate\n", 131 | " self.learning_rate = learning_rate \n", 132 | " # Minimum samples required to split a node\n", 133 | " self.min_samples_split = min_samples_split \n", 134 | " # Minimum node Gini impurity\n", 135 | " self.min_gini_impurity = min_gini_impurity \n", 136 | " # Maximum tree depth\n", 137 | " self.max_depth = max_depth \n", 138 | " # Log loss for classification\n", 139 | " # For regression, a squared loss could be used instead \n", 140 | " # self.loss = SquaresLoss()\n", 141 | " self.loss = LogisticLoss()\n", 142 | " # Initialize the list of trees\n", 143 | " self.trees = []\n", 144 | " # Construct each decision tree\n", 145 | " for _ in range(n_estimators):\n", 146 | " tree = XGBoost_Single_Tree(\n", 147 | " min_samples_split=self.min_samples_split,\n", 148 | " min_gini_impurity=self.min_gini_impurity,\n", 149 | " max_depth=self.max_depth,\n", 150 | " loss=self.loss)\n", 151 | " self.trees.append(tree)\n", 152 | " \n", 153 | " # XGBoost fitting method\n", 154 | " def fit(self, X, y):\n", 155 | " y = cat_label_convert(y)\n", 156 | " y_pred = np.zeros(np.shape(y))\n", 157 | " # Fit each tree and accumulate its predictions\n", 158 | " for i in range(self.n_estimators):\n", 159 | " tree = self.trees[i]\n", 160 | " y_true_pred = np.concatenate((y, y_pred), axis=1)\n", 161 | " tree.fit(X, y_true_pred)\n", 162 | " iter_pred = tree.predict(X)\n", 163 | " y_pred -= np.multiply(self.learning_rate, iter_pred)\n", 164 | "\n", 165 | " # XGBoost prediction method\n", 166 | " def predict(self, X):\n", 167 | " y_pred = None\n", 168 | " # Accumulate predictions over all trees\n", 169 | " for tree in self.trees:\n", 170 | " iter_pred = tree.predict(X)\n", 171 | " if y_pred is None:\n", 172 | " y_pred = np.zeros_like(iter_pred)\n", 173 | " y_pred -= np.multiply(self.learning_rate, iter_pred)\n", 174 | " y_pred = np.exp(y_pred) / np.sum(np.exp(y_pred), axis=1, keepdims=True)\n", 175 | " # Convert probability predictions to labels\n", 176 | " y_pred = np.argmax(y_pred, axis=1)\n", 177 | " return y_pred" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 
182 | "execution_count": 11, 183 | "metadata": {}, 184 | "outputs": [ 185 | { 186 | "name": "stdout", 187 | "output_type": "stream", 188 | "text": [ 189 | "Accuracy: 0.9333333333333333\n" 190 | ] 191 | } 192 | ], 193 | "source": [ 194 | "from sklearn import datasets\n", 195 | "# Load the iris dataset\n", 196 | "data = datasets.load_iris()\n", 197 | "# Get inputs and targets\n", 198 | "X, y = data.data, data.target\n", 199 | "# Train/test split\n", 200 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=43) \n", 201 | "# Create the XGBoost classifier\n", 202 | "clf = XGBoost()\n", 203 | "# Fit the model\n", 204 | "clf.fit(X_train, y_train)\n", 205 | "# Predict on the test set\n", 206 | "y_pred = clf.predict(X_test)\n", 207 | "# Evaluate accuracy\n", 208 | "accuracy = accuracy_score(y_test, y_pred)\n", 209 | "print(\"Accuracy: \", accuracy)" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 13, 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "name": "stdout", 219 | "output_type": "stream", 220 | "text": [ 221 | "[14:56:53] WARNING: C:/Users/Administrator/workspace/xgboost-win64_release_1.3.0/src/learner.cc:1061: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'multi:softmax' was changed from 'merror' to 'mlogloss'. 
Explicitly set eval_metric if you'd like to restore the old behavior.\n", 222 | "Accuracy: 0.9666666666666667\n" 223 | ] 224 | }, 225 | { 226 | "data": { 227 | "image/png": "(base64-encoded feature-importance plot omitted)\n", 228 | "text/plain": [ 229 | "
" 230 | ] 231 | }, 232 | "metadata": { 233 | "needs_background": "light" 234 | }, 235 | "output_type": "display_data" 236 | } 237 | ], 238 | "source": [ 239 | "import xgboost as xgb\n", 240 | "from xgboost import plot_importance\n", 241 | "from matplotlib import pyplot as plt\n", 242 | "\n", 243 | "# 设置模型参数\n", 244 | "params = {\n", 245 | " 'booster': 'gbtree',\n", 246 | " 'objective': 'multi:softmax', \n", 247 | " 'num_class': 3, \n", 248 | " 'gamma': 0.1,\n", 249 | " 'max_depth': 2,\n", 250 | " 'lambda': 2,\n", 251 | " 'subsample': 0.7,\n", 252 | " 'colsample_bytree': 0.7,\n", 253 | " 'min_child_weight': 3,\n", 254 | " 'eta': 0.001,\n", 255 | " 'seed': 1000,\n", 256 | " 'nthread': 4,\n", 257 | "}\n", 258 | "\n", 259 | "\n", 260 | "dtrain = xgb.DMatrix(X_train, y_train)\n", 261 | "num_rounds = 200\n", 262 | "model = xgb.train(params, dtrain, num_rounds)\n", 263 | "# 对测试集进行预测\n", 264 | "dtest = xgb.DMatrix(X_test)\n", 265 | "y_pred = model.predict(dtest)\n", 266 | "\n", 267 | "# 计算准确率\n", 268 | "accuracy = accuracy_score(y_test, y_pred)\n", 269 | "print (\"Accuracy:\", accuracy)\n", 270 | "# 绘制特征重要性\n", 271 | "plot_importance(model)\n", 272 | "plt.show();" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "\n" 282 | ] 283 | } 284 | ], 285 | "metadata": { 286 | "kernelspec": { 287 | "display_name": "Python 3", 288 | "language": "python", 289 | "name": "python3" 290 | }, 291 | "language_info": { 292 | "codemirror_mode": { 293 | "name": "ipython", 294 | "version": 3 295 | }, 296 | "file_extension": ".py", 297 | "mimetype": "text/x-python", 298 | "name": "python", 299 | "nbconvert_exporter": "python", 300 | "pygments_lexer": "ipython3", 301 | "version": "3.7.3" 302 | }, 303 | "toc": { 304 | "base_numbering": 1, 305 | "nav_menu": {}, 306 | "number_sections": true, 307 | "sideBar": true, 308 | "skip_h1_title": false, 309 | "title_cell": "Table of Contents", 310 | 
"title_sidebar": "Contents", 311 | "toc_cell": false, 312 | "toc_position": {}, 313 | "toc_section_display": true, 314 | "toc_window_display": false 315 | } 316 | }, 317 | "nbformat": 4, 318 | "nbformat_minor": 2 319 | } 320 | -------------------------------------------------------------------------------- /charpter13_LightGBM/lightgbm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "[1]\tvalid_0's multi_logloss: 1.02277\n", 13 | "Training until validation scores don't improve for 5 rounds\n", 14 | "[2]\tvalid_0's multi_logloss: 0.943765\n", 15 | "[3]\tvalid_0's multi_logloss: 0.873274\n", 16 | "[4]\tvalid_0's multi_logloss: 0.810478\n", 17 | "[5]\tvalid_0's multi_logloss: 0.752973\n", 18 | "[6]\tvalid_0's multi_logloss: 0.701621\n", 19 | "[7]\tvalid_0's multi_logloss: 0.654982\n", 20 | "[8]\tvalid_0's multi_logloss: 0.611268\n", 21 | "[9]\tvalid_0's multi_logloss: 0.572202\n", 22 | "[10]\tvalid_0's multi_logloss: 0.53541\n", 23 | "[11]\tvalid_0's multi_logloss: 0.502582\n", 24 | "[12]\tvalid_0's multi_logloss: 0.472856\n", 25 | "[13]\tvalid_0's multi_logloss: 0.443853\n", 26 | "[14]\tvalid_0's multi_logloss: 0.417764\n", 27 | "[15]\tvalid_0's multi_logloss: 0.393613\n", 28 | "[16]\tvalid_0's multi_logloss: 0.370679\n", 29 | "[17]\tvalid_0's multi_logloss: 0.349936\n", 30 | "[18]\tvalid_0's multi_logloss: 0.330669\n", 31 | "[19]\tvalid_0's multi_logloss: 0.312805\n", 32 | "[20]\tvalid_0's multi_logloss: 0.296973\n", 33 | "Did not meet early stopping. 
Best iteration is:\n", 34 | "[20]\tvalid_0's multi_logloss: 0.296973\n", 35 | "Accuracy of lightgbm: 1.0\n" 36 | ] 37 | }, 38 | { 39 | "data": { 40 | "image/png": "(base64-encoded feature-importance plot omitted)\n", 41 | "text/plain": [ 42 | "
" 43 | ] 44 | }, 45 | "metadata": { 46 | "needs_background": "light" 47 | }, 48 | "output_type": "display_data" 49 | } 50 | ], 51 | "source": [ 52 | "# 导入相关模块\n", 53 | "import lightgbm as lgb\n", 54 | "from sklearn.metrics import accuracy_score\n", 55 | "from sklearn.datasets import load_iris\n", 56 | "from sklearn.model_selection import train_test_split\n", 57 | "import matplotlib.pyplot as plt\n", 58 | "# 导入iris数据集\n", 59 | "iris = load_iris()\n", 60 | "data = iris.data\n", 61 | "target = iris.target\n", 62 | "# 数据集划分\n", 63 | "X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2, random_state=43)\n", 64 | "# 创建lightgbm分类模型\n", 65 | "gbm = lgb.LGBMClassifier(objective='multiclass',\n", 66 | " num_class=3,\n", 67 | " num_leaves=31,\n", 68 | " learning_rate=0.05,\n", 69 | " n_estimators=20)\n", 70 | "# 模型训练\n", 71 | "gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=5)\n", 72 | "# 预测测试集\n", 73 | "y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration_)\n", 74 | "# 模型评估\n", 75 | "print('Accuracy of lightgbm:', accuracy_score(y_test, y_pred))\n", 76 | "lgb.plot_importance(gbm)\n", 77 | "plt.show();" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [] 86 | } 87 | ], 88 | "metadata": { 89 | "kernelspec": { 90 | "display_name": "Python 3", 91 | "language": "python", 92 | "name": "python3" 93 | }, 94 | "language_info": { 95 | "codemirror_mode": { 96 | "name": "ipython", 97 | "version": 3 98 | }, 99 | "file_extension": ".py", 100 | "mimetype": "text/x-python", 101 | "name": "python", 102 | "nbconvert_exporter": "python", 103 | "pygments_lexer": "ipython3", 104 | "version": "3.7.3" 105 | }, 106 | "toc": { 107 | "base_numbering": 1, 108 | "nav_menu": {}, 109 | "number_sections": true, 110 | "sideBar": true, 111 | "skip_h1_title": false, 112 | "title_cell": "Table of Contents", 113 | "title_sidebar": "Contents", 114 | "toc_cell": 
false, 115 | "toc_position": {}, 116 | "toc_section_display": true, 117 | "toc_window_display": false 118 | } 119 | }, 120 | "nbformat": 4, 121 | "nbformat_minor": 2 122 | } 123 | -------------------------------------------------------------------------------- /charpter15_random_forest/cart.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from utils import feature_split, calculate_gini 3 | 4 | ### 定义树结点 5 | class TreeNode(): 6 | def __init__(self, feature_i=None, threshold=None, 7 | leaf_value=None, left_branch=None, right_branch=None): 8 | # 特征索引 9 | self.feature_i = feature_i 10 | # 特征划分阈值 11 | self.threshold = threshold 12 | # 叶子节点取值 13 | self.leaf_value = leaf_value 14 | # 左子树 15 | self.left_branch = left_branch 16 | # 右子树 17 | self.right_branch = right_branch 18 | 19 | 20 | ### 定义二叉决策树 21 | class BinaryDecisionTree(object): 22 | ### 决策树初始参数 23 | def __init__(self, min_samples_split=2, min_gini_impurity=999, 24 | max_depth=float("inf"), loss=None): 25 | # 根结点 26 | self.root = None 27 | # 节点最小分裂样本数 28 | self.min_samples_split = min_samples_split 29 | # 节点初始化基尼不纯度 30 | self.min_gini_impurity = min_gini_impurity 31 | # 树最大深度 32 | self.max_depth = max_depth 33 | # 基尼不纯度计算函数,由子类fit时赋值 34 | self.impurity_calculation = None 35 | # 叶子节点值预测函数 36 | self._leaf_value_calculation = None 37 | # 损失函数 38 | self.loss = loss 39 | 40 | ### 决策树拟合函数 41 | def fit(self, X, y, loss=None): 42 | # 递归构建决策树 43 | self.root = self._build_tree(X, y) 44 | self.loss = None 45 | 46 | ### 决策树构建函数 47 | def _build_tree(self, X, y, current_depth=0): 48 | # 初始化最小基尼不纯度 49 | init_gini_impurity = 999 50 | # 初始化最佳特征索引和阈值 51 | best_criteria = None 52 | # 初始化数据子集 53 | best_sets = None 54 | 55 | if len(np.shape(y)) == 1: 56 | y = np.expand_dims(y, axis=1) 57 | 58 | # 合并输入和标签 59 | Xy = np.concatenate((X, y), axis=1) 60 | # 获取样本数和特征数 61 | n_samples, n_features = X.shape 62 | # 设定决策树构建条件 63 | # 训练样本数量大于节点最小分裂样本数且当前树深度小于最大深度 64 | if n_samples >=
self.min_samples_split and current_depth <= self.max_depth: 65 | # 遍历计算每个特征的基尼不纯度 66 | for feature_i in range(n_features): 67 | # 获取第i特征的所有取值 68 | feature_values = np.expand_dims(X[:, feature_i], axis=1) 69 | # 获取第i个特征的唯一取值 70 | unique_values = np.unique(feature_values) 71 | 72 | # 遍历取值并寻找最佳特征分裂阈值 73 | for threshold in unique_values: 74 | # 特征节点二叉分裂 75 | Xy1, Xy2 = feature_split(Xy, feature_i, threshold) 76 | # 如果分裂后的子集大小都不为0 77 | if len(Xy1) > 0 and len(Xy2) > 0: 78 | # 获取两个子集的标签值 79 | y1 = Xy1[:, n_features:] 80 | y2 = Xy2[:, n_features:] 81 | 82 | # 计算基尼不纯度 83 | impurity = self.impurity_calculation(y, y1, y2) 84 | 85 | # 获取最小基尼不纯度 86 | # 最佳特征索引和分裂阈值 87 | if impurity < init_gini_impurity: 88 | init_gini_impurity = impurity 89 | best_criteria = {"feature_i": feature_i, "threshold": threshold} 90 | best_sets = { 91 | "leftX": Xy1[:, :n_features], 92 | "lefty": Xy1[:, n_features:], 93 | "rightX": Xy2[:, :n_features], 94 | "righty": Xy2[:, n_features:] 95 | } 96 | 97 | # 如果计算的最小不纯度小于设定的最小不纯度 98 | if init_gini_impurity < self.min_gini_impurity: 99 | # 分别构建左右子树 100 | left_branch = self._build_tree(best_sets["leftX"], best_sets["lefty"], current_depth + 1) 101 | right_branch = self._build_tree(best_sets["rightX"], best_sets["righty"], current_depth + 1) 102 | return TreeNode(feature_i=best_criteria["feature_i"], threshold=best_criteria["threshold"], left_branch=left_branch, right_branch=right_branch) 103 | 104 | # 计算叶子计算取值 105 | leaf_value = self._leaf_value_calculation(y) 106 | return TreeNode(leaf_value=leaf_value) 107 | 108 | ### 定义二叉树值预测函数 109 | def predict_value(self, x, tree=None): 110 | if tree is None: 111 | tree = self.root 112 | # 如果叶子节点已有值,则直接返回已有值 113 | if tree.leaf_value is not None: 114 | return tree.leaf_value 115 | # 选择特征并获取特征值 116 | feature_value = x[tree.feature_i] 117 | # 判断落入左子树还是右子树 118 | branch = tree.right_branch 119 | if isinstance(feature_value, int) or isinstance(feature_value, float): 120 | if feature_value >= tree.threshold: 121 | branch = 
tree.left_branch 122 | elif feature_value == tree.threshold: 123 | branch = tree.left_branch 124 | # 测试子集 125 | return self.predict_value(x, branch) 126 | 127 | ### 数据集预测函数 128 | def predict(self, X): 129 | y_pred = [self.predict_value(sample) for sample in X] 130 | return y_pred 131 | 132 | 133 | class ClassificationTree(BinaryDecisionTree): 134 | ### 定义基尼不纯度计算过程 135 | def _calculate_gini_impurity(self, y, y1, y2): 136 | p = len(y1) / len(y) 137 | gini = calculate_gini(y) 138 | gini_impurity = p * calculate_gini(y1) + (1-p) * calculate_gini(y2) 139 | return gini_impurity 140 | 141 | ### 多数投票 142 | def _majority_vote(self, y): 143 | most_common = None 144 | max_count = 0 145 | for label in np.unique(y): 146 | # 统计多数 147 | count = len(y[y == label]) 148 | if count > max_count: 149 | most_common = label 150 | max_count = count 151 | return most_common 152 | 153 | # 分类树拟合 154 | def fit(self, X, y): 155 | self.impurity_calculation = self._calculate_gini_impurity 156 | self._leaf_value_calculation = self._majority_vote 157 | super(ClassificationTree, self).fit(X, y) 158 | 159 | 160 | ### CART回归树 161 | class RegressionTree(BinaryDecisionTree): 162 | # 计算方差减少量 163 | def _calculate_variance_reduction(self, y, y1, y2): 164 | var_tot = np.var(y, axis=0) 165 | var_y1 = np.var(y1, axis=0) 166 | var_y2 = np.var(y2, axis=0) 167 | frac_1 = len(y1) / len(y) 168 | frac_2 = len(y2) / len(y) 169 | # 计算方差减少量 170 | variance_reduction = var_tot - (frac_1 * var_y1 + frac_2 * var_y2) 171 | return sum(variance_reduction) 172 | 173 | # 节点值取平均 174 | def _mean_of_y(self, y): 175 | value = np.mean(y, axis=0) 176 | return value if len(value) > 1 else value[0] 177 | 178 | # 回归树拟合 179 | def fit(self, X, y): 180 | self.impurity_calculation = self._calculate_variance_reduction 181 | self._leaf_value_calculation = self._mean_of_y 182 | super(RegressionTree, self).fit(X, y) 183 | --------------------------------------------------------------------------------
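The split criterion in `cart.py` weights each child node's Gini impurity by its share of the parent's samples. As a standalone illustration (the `gini` and `split_impurity` helper names below are chosen for this sketch and are not part of `utils.py`):

```python
import numpy as np

def gini(y):
    # Gini impurity: 1 - sum_k p_k^2 over the class proportions p_k
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def split_impurity(y1, y2):
    # Size-weighted impurity of a binary split, mirroring _calculate_gini_impurity
    n = len(y1) + len(y2)
    return len(y1) / n * gini(y1) + len(y2) / n * gini(y2)

# A perfectly separating split has zero weighted impurity
print(split_impurity(np.array([0, 0, 0]), np.array([1, 1, 1])))  # → 0.0
```

The tree builder simply scans candidate thresholds and keeps the one minimizing this weighted quantity.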
/charpter15_random_forest/random_forest.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "(700, 20) (700,) (300, 20) (300,)\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "import numpy as np\n", 18 | "# 该模块为自定义模块,封装了构建决策树的基本方法\n", 19 | "from cart import *\n", 20 | "from sklearn.datasets import make_classification\n", 21 | "from sklearn.model_selection import train_test_split\n", 22 | "# 树的棵数\n", 23 | "n_estimators = 10\n", 24 | "# 列抽样最大特征数\n", 25 | "max_features = 15\n", 26 | "# 生成模拟二分类数据集\n", 27 | "X, y = make_classification(n_samples=1000, n_features=20, n_redundant=0, n_informative=2,\n", 28 | " random_state=1, n_clusters_per_class=1)\n", 29 | "rng = np.random.RandomState(2)\n", 30 | "X += 2 * rng.uniform(size=X.shape)\n", 31 | "# 划分数据集\n", 32 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)\n", 33 | "print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "# 合并训练数据和标签\n", 43 | "X_y = np.concatenate([X, y.reshape(-1,1)], axis=1)\n", 44 | "np.random.shuffle(X_y)\n", 45 | "m = X_y.shape[0]\n", 46 | "sampling_subsets = []\n", 47 | "\n", 48 | "for _ in range(n_estimators):\n", 49 | " idx = np.random.choice(m, m, replace=True)\n", 50 | " bootstrap_Xy = X_y[idx, :]\n", 51 | " bootstrap_X = bootstrap_Xy[:, :-1]\n", 52 | " bootstrap_y = bootstrap_Xy[:, -1]\n", 53 | " sampling_subsets.append([bootstrap_X, bootstrap_y])" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "data": { 63 | "text/plain": [ 64 | "(1000, 20)" 65 | ] 66 | }, 67 | "execution_count": 3, 68 | "metadata": {}, 69 | "output_type": 
"execute_result" 70 | } 71 | ], 72 | "source": [ 73 | "sampling_subsets[0][0].shape" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 4, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "# 自助抽样选择训练数据子集\n", 83 | "def bootstrap_sampling(X, y):\n", 84 | " X_y = np.concatenate([X, y.reshape(-1,1)], axis=1)\n", 85 | " np.random.shuffle(X_y)\n", 86 | " n_samples = X.shape[0]\n", 87 | " sampling_subsets = []\n", 88 | "\n", 89 | " for _ in range(n_estimators):\n", 90 | " # 第一个随机性,行抽样\n", 91 | " idx1 = np.random.choice(n_samples, n_samples, replace=True)\n", 92 | " bootstrap_Xy = X_y[idx1, :]\n", 93 | " bootstrap_X = bootstrap_Xy[:, :-1]\n", 94 | " bootstrap_y = bootstrap_Xy[:, -1]\n", 95 | " sampling_subsets.append([bootstrap_X, bootstrap_y])\n", 96 | " return sampling_subsets" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 5, 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "name": "stdout", 106 | "output_type": "stream", 107 | "text": [ 108 | "(700, 20) (700,)\n" 109 | ] 110 | } 111 | ], 112 | "source": [ 113 | "sampling_subsets = bootstrap_sampling(X_train, y_train)\n", 114 | "sub_X, sub_y = sampling_subsets[0]\n", 115 | "print(sub_X.shape, sub_y.shape)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 6, 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "data": { 125 | "text/plain": [ 126 | "" 127 | ] 128 | }, 129 | "execution_count": 6, 130 | "metadata": {}, 131 | "output_type": "execute_result" 132 | } 133 | ], 134 | "source": [ 135 | "trees = []\n", 136 | "# 基于决策树构建森林\n", 137 | "for _ in range(n_estimators):\n", 138 | " tree = ClassificationTree(min_samples_split=2, min_gini_impurity=999,\n", 139 | " max_depth=3)\n", 140 | " trees.append(tree)\n", 141 | "\n", 142 | "trees[0]" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 7, 148 | "metadata": { 149 | "scrolled": true 150 | }, 151 | "outputs": [ 152 | { 153 | "name": "stdout", 154 | 
"output_type": "stream", 155 | "text": [ 156 | "The 1th tree is trained done...\n", 157 | "The 2th tree is trained done...\n", 158 | "The 3th tree is trained done...\n", 159 | "The 4th tree is trained done...\n", 160 | "The 5th tree is trained done...\n", 161 | "The 6th tree is trained done...\n", 162 | "The 7th tree is trained done...\n", 163 | "The 8th tree is trained done...\n", 164 | "The 9th tree is trained done...\n", 165 | "The 10th tree is trained done...\n" 166 | ] 167 | } 168 | ], 169 | "source": [ 170 | "# 随机森林训练\n", 171 | "def fit(X, y):\n", 172 | " # 对森林中每棵树训练一个双随机抽样子集\n", 173 | " n_features = X.shape[1]\n", 174 | " sub_sets = bootstrap_sampling(X, y)\n", 175 | " for i in range(n_estimators):\n", 176 | " sub_X, sub_y = sub_sets[i]\n", 177 | " # 第二个随机性,列抽样\n", 178 | " idx2 = np.random.choice(n_features, max_features, replace=True)\n", 179 | " sub_X = sub_X[:, idx2]\n", 180 | " trees[i].fit(sub_X, sub_y)\n", 181 | " trees[i].feature_indices = idx2\n", 182 | " print('The {}th tree is trained done...'.format(i+1))\n", 183 | "\n", 184 | "fit(X_train, y_train)" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 8, 190 | "metadata": {}, 191 | "outputs": [ 192 | { 193 | "data": { 194 | "text/plain": [ 195 | "300" 196 | ] 197 | }, 198 | "execution_count": 8, 199 | "metadata": {}, 200 | "output_type": "execute_result" 201 | } 202 | ], 203 | "source": [ 204 | "y_preds = []\n", 205 | "for i in range(n_estimators):\n", 206 | " idx = trees[i].feature_indices\n", 207 | " sub_X = X_test[:, idx]\n", 208 | " y_pred = trees[i].predict(sub_X)\n", 209 | " y_preds.append(y_pred)\n", 210 | " \n", 211 | "len(y_preds[0])" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 9, 217 | "metadata": {}, 218 | "outputs": [ 219 | { 220 | "name": "stdout", 221 | "output_type": "stream", 222 | "text": [ 223 | "(300, 10)\n", 224 | "[0, 0, 0, 0, 0, 1, 0, 1, 0, 0]\n" 225 | ] 226 | } 227 | ], 228 | "source": [ 229 | "y_preds = 
np.array(y_preds).T\n", 230 | "print(y_preds.shape)\n", 231 | "y_pred = []\n", 232 | "for y_p in y_preds:\n", 233 | " y_pred.append(np.bincount(y_p.astype('int')).argmax())\n", 234 | "\n", 235 | "print(y_pred[:10])" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 10, 241 | "metadata": {}, 242 | "outputs": [ 243 | { 244 | "name": "stdout", 245 | "output_type": "stream", 246 | "text": [ 247 | "0.7366666666666667\n" 248 | ] 249 | } 250 | ], 251 | "source": [ 252 | "from sklearn.metrics import accuracy_score\n", 253 | "print(accuracy_score(y_test, y_pred))" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 3, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "class RandomForest():\n", 263 | " def __init__(self, n_estimators=100, min_samples_split=2, min_gain=0,\n", 264 | " max_depth=float(\"inf\"), max_features=None):\n", 265 | " # 树的棵数\n", 266 | " self.n_estimators = n_estimators\n", 267 | " # 树最小分裂样本数\n", 268 | " self.min_samples_split = min_samples_split\n", 269 | " # 最小增益\n", 270 | " self.min_gain = min_gain\n", 271 | " # 树最大深度\n", 272 | " self.max_depth = max_depth\n", 273 | " # 所使用最大特征数\n", 274 | " self.max_features = max_features\n", 275 | "\n", 276 | " self.trees = []\n", 277 | " # 基于决策树构建森林\n", 278 | " for _ in range(self.n_estimators):\n", 279 | " tree = ClassificationTree(min_samples_split=self.min_samples_split, min_gini_impurity=self.min_gain,\n", 280 | " max_depth=self.max_depth)\n", 281 | " self.trees.append(tree)\n", 282 | " \n", 283 | " # 自助抽样\n", 284 | " def bootstrap_sampling(self, X, y):\n", 285 | " X_y = np.concatenate([X, y.reshape(-1,1)], axis=1)\n", 286 | " np.random.shuffle(X_y)\n", 287 | " n_samples = X.shape[0]\n", 288 | " sampling_subsets = []\n", 289 | "\n", 290 | " for _ in range(self.n_estimators):\n", 291 | " # 第一个随机性,行抽样\n", 292 | " idx1 = np.random.choice(n_samples, n_samples, replace=True)\n", 293 | " bootstrap_Xy = X_y[idx1, :]\n", 294 | " bootstrap_X = 
bootstrap_Xy[:, :-1]\n", 295 | " bootstrap_y = bootstrap_Xy[:, -1]\n", 296 | " sampling_subsets.append([bootstrap_X, bootstrap_y])\n", 297 | " return sampling_subsets\n", 298 | " \n", 299 | " # 随机森林训练\n", 300 | " def fit(self, X, y):\n", 301 | " # 对森林中每棵树训练一个双随机抽样子集\n", 302 | " sub_sets = self.bootstrap_sampling(X, y)\n", 303 | " n_features = X.shape[1]\n", 304 | " # 设置max_feature\n", 305 | " if self.max_features == None:\n", 306 | " self.max_features = int(np.sqrt(n_features))\n", 307 | " \n", 308 | " for i in range(self.n_estimators):\n", 309 | " # 第二个随机性,列抽样\n", 310 | " sub_X, sub_y = sub_sets[i]\n", 311 | " idx2 = np.random.choice(n_features, self.max_features, replace=True)\n", 312 | " sub_X = sub_X[:, idx2]\n", 313 | " self.trees[i].fit(sub_X, sub_y)\n", 314 | " # 保存每次列抽样的列索引,方便预测时每棵树调用\n", 315 | " self.trees[i].feature_indices = idx2\n", 316 | " print('The {}th tree is trained done...'.format(i+1))\n", 317 | " \n", 318 | " # 随机森林预测\n", 319 | " def predict(self, X):\n", 320 | " y_preds = []\n", 321 | " for i in range(self.n_estimators):\n", 322 | " idx = self.trees[i].feature_indices\n", 323 | " sub_X = X[:, idx]\n", 324 | " y_pred = self.trees[i].predict(sub_X)\n", 325 | " y_preds.append(y_pred)\n", 326 | " \n", 327 | " y_preds = np.array(y_preds).T\n", 328 | " res = []\n", 329 | " for j in y_preds:\n", 330 | " res.append(np.bincount(j.astype('int')).argmax())\n", 331 | " return res" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 4, 337 | "metadata": {}, 338 | "outputs": [ 339 | { 340 | "name": "stdout", 341 | "output_type": "stream", 342 | "text": [ 343 | "The 1th tree is trained done...\n", 344 | "The 2th tree is trained done...\n", 345 | "The 3th tree is trained done...\n", 346 | "The 4th tree is trained done...\n", 347 | "The 5th tree is trained done...\n", 348 | "The 6th tree is trained done...\n", 349 | "The 7th tree is trained done...\n", 350 | "The 8th tree is trained done...\n", 351 | "The 9th tree is trained 
done...\n", 352 | "The 10th tree is trained done...\n" 353 | ] 354 | } 355 | ], 356 | "source": [ 357 | "rf = RandomForest(n_estimators=10, max_features=15)\n", 358 | "rf.fit(X_train, y_train)\n", 359 | "y_pred = rf.predict(X_test)\n", 360 | "print(accuracy_score(y_test, y_pred))" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": 11, 366 | "metadata": {}, 367 | "outputs": [ 368 | { 369 | "name": "stdout", 370 | "output_type": "stream", 371 | "text": [ 372 | "0.82\n" 373 | ] 374 | }, 375 | { 376 | "name": "stderr", 377 | "output_type": "stream", 378 | "text": [ 379 | "D:\\Installation\\anaconda\\install\\lib\\site-packages\\sklearn\\ensemble\\forest.py:246: FutureWarning: The default value of n_estimators will change from 10 in version 0.20 to 100 in 0.22.\n", 380 | " \"10 in version 0.20 to 100 in 0.22.\", FutureWarning)\n" 381 | ] 382 | } 383 | ], 384 | "source": [ 385 | "from sklearn.ensemble import RandomForestClassifier\n", 386 | "clf = RandomForestClassifier(max_depth=3, random_state=0)\n", 387 | "clf.fit(X_train, y_train)\n", 388 | "y_pred = clf.predict(X_test)\n", 389 | "print(accuracy_score(y_test, y_pred))" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": null, 395 | "metadata": {}, 396 | "outputs": [], 397 | "source": [] 398 | } 399 | ], 400 | "metadata": { 401 | "kernelspec": { 402 | "display_name": "Python 3", 403 | "language": "python", 404 | "name": "python3" 405 | }, 406 | "language_info": { 407 | "codemirror_mode": { 408 | "name": "ipython", 409 | "version": 3 410 | }, 411 | "file_extension": ".py", 412 | "mimetype": "text/x-python", 413 | "name": "python", 414 | "nbconvert_exporter": "python", 415 | "pygments_lexer": "ipython3", 416 | "version": "3.7.3" 417 | }, 418 | "toc": { 419 | "base_numbering": 1, 420 | "nav_menu": {}, 421 | "number_sections": true, 422 | "sideBar": true, 423 | "skip_h1_title": false, 424 | "title_cell": "Table of Contents", 425 | "title_sidebar": "Contents", 426 | 
"toc_cell": false, 427 | "toc_position": {}, 428 | "toc_section_display": true, 429 | "toc_window_display": false 430 | } 431 | }, 432 | "nbformat": 4, 433 | "nbformat_minor": 4 434 | } 435 | -------------------------------------------------------------------------------- /charpter17_kmeans/kmeans.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### kmeans" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "from sklearn.cluster import KMeans" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "data": { 27 | "text/plain": [ 28 | "array([[0, 2],\n", 29 | " [0, 0],\n", 30 | " [1, 0],\n", 31 | " [5, 0],\n", 32 | " [5, 2]])" 33 | ] 34 | }, 35 | "execution_count": 2, 36 | "metadata": {}, 37 | "output_type": "execute_result" 38 | } 39 | ], 40 | "source": [ 41 | "X = np.array([[0,2],[0,0],[1,0],[5,0],[5,2]])\n", 42 | "X" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 14, 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "name": "stdout", 52 | "output_type": "stream", 53 | "text": [ 54 | "[1 1 1 0 0]\n" 55 | ] 56 | } 57 | ], 58 | "source": [ 59 | "from sklearn.cluster import KMeans\n", 60 | "kmeans = KMeans(n_clusters=2, random_state=0).fit(X)\n", 61 | "print(kmeans.labels_)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 4, 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "data": { 71 | "text/plain": [ 72 | "array([1, 0])" 73 | ] 74 | }, 75 | "execution_count": 4, 76 | "metadata": {}, 77 | "output_type": "execute_result" 78 | } 79 | ], 80 | "source": [ 81 | "kmeans.predict([[0, 0], [12, 3]])" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 5, 87 | "metadata": {}, 88 | 
"outputs": [ 89 | { 90 | "data": { 91 | "text/plain": [ 92 | "array([[5. , 1. ],\n", 93 | " [0.33333333, 0.66666667]])" 94 | ] 95 | }, 96 | "execution_count": 5, 97 | "metadata": {}, 98 | "output_type": "execute_result" 99 | } 100 | ], 101 | "source": [ 102 | "kmeans.cluster_centers_" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 6, 108 | "metadata": {}, 109 | "outputs": [ 110 | { 111 | "name": "stdout", 112 | "output_type": "stream", 113 | "text": [ 114 | "5.0\n" 115 | ] 116 | } 117 | ], 118 | "source": [ 119 | "import numpy as np\n", 120 | "# 定义欧式距离\n", 121 | "def euclidean_distance(x1, x2):\n", 122 | " distance = 0\n", 123 | " # 距离的平方项再开根号\n", 124 | " for i in range(len(x1)):\n", 125 | " distance += pow((x1[i] - x2[i]), 2)\n", 126 | " return np.sqrt(distance)\n", 127 | "\n", 128 | "print(euclidean_distance(X[0], X[4]))" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 7, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "# 定义中心初始化函数\n", 138 | "def centroids_init(k, X):\n", 139 | " m, n = X.shape\n", 140 | " centroids = np.zeros((k, n))\n", 141 | " for i in range(k):\n", 142 | " # 每一次循环随机选择一个类别中心\n", 143 | " centroid = X[np.random.choice(range(m))]\n", 144 | " centroids[i] = centroid\n", 145 | " return centroids" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 8, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "# 定义样本的最近质心点所属的类别索引\n", 155 | "def closest_centroid(sample, centroids):\n", 156 | " closest_i = 0\n", 157 | " closest_dist = float('inf')\n", 158 | " for i, centroid in enumerate(centroids):\n", 159 | " # 根据欧式距离判断,选择最小距离的中心点所属类别\n", 160 | " distance = euclidean_distance(sample, centroid)\n", 161 | " if distance < closest_dist:\n", 162 | " closest_i = i\n", 163 | " closest_dist = distance\n", 164 | " return closest_i" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 9, 170 | "metadata": {}, 171 | "outputs": 
[], 172 | "source": [ 173 | "# 定义构建类别过程\n", 174 | "def build_clusters(centroids, k, X):\n", 175 | " clusters = [[] for _ in range(k)]\n", 176 | " for x_i, x in enumerate(X):\n", 177 | " # 将样本划分到最近的类别区域\n", 178 | " centroid_i = closest_centroid(x, centroids)\n", 179 | " clusters[centroid_i].append(x_i)\n", 180 | " return clusters" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 10, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "# 根据上一步聚类结果计算新的中心点\n", 190 | "def calculate_centroids(clusters, k, X):\n", 191 | " n = X.shape[1]\n", 192 | " centroids = np.zeros((k, n))\n", 193 | " # 以当前每个类样本的均值为新的中心点\n", 194 | " for i, cluster in enumerate(clusters):\n", 195 | " centroid = np.mean(X[cluster], axis=0)\n", 196 | " centroids[i] = centroid\n", 197 | " return centroids" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 11, 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "# 获取每个样本所属的聚类类别\n", 207 | "def get_cluster_labels(clusters, X):\n", 208 | " y_pred = np.zeros(X.shape[0])\n", 209 | " for cluster_i, cluster in enumerate(clusters):\n", 210 | " for X_i in cluster:\n", 211 | " y_pred[X_i] = cluster_i\n", 212 | " return y_pred" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 12, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "# 根据上述各流程定义kmeans算法流程\n", 222 | "def kmeans(X, k, max_iterations):\n", 223 | " # 1.初始化中心点\n", 224 | " centroids = centroids_init(k, X)\n", 225 | " # 遍历迭代求解\n", 226 | " for _ in range(max_iterations):\n", 227 | " # 2.根据当前中心点进行聚类\n", 228 | " clusters = build_clusters(centroids, k, X)\n", 229 | " # 保存当前中心点\n", 230 | " prev_centroids = centroids\n", 231 | " # 3.根据聚类结果计算新的中心点\n", 232 | " centroids = calculate_centroids(clusters, k, X)\n", 233 | " # 4.设定收敛条件为中心点是否发生变化\n", 234 | " diff = centroids - prev_centroids\n", 235 | " if not diff.any():\n", 236 | " break\n", 237 | " # 返回最终的聚类标签\n", 238 | " return 
get_cluster_labels(clusters, X)" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 13, 244 | "metadata": {}, 245 | "outputs": [ 246 | { 247 | "name": "stdout", 248 | "output_type": "stream", 249 | "text": [ 250 | "[0. 0. 0. 1. 1.]\n" 251 | ] 252 | } 253 | ], 254 | "source": [ 255 | "# 测试数据\n", 256 | "X = np.array([[0,2],[0,0],[1,0],[5,0],[5,2]])\n", 257 | "# 设定聚类类别为2个,最大迭代次数为10次\n", 258 | "labels = kmeans(X, 2, 10)\n", 259 | "# 打印每个样本所属的类别标签\n", 260 | "print(labels)" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [] 269 | } 270 | ], 271 | "metadata": { 272 | "kernelspec": { 273 | "display_name": "Python 3", 274 | "language": "python", 275 | "name": "python3" 276 | }, 277 | "language_info": { 278 | "codemirror_mode": { 279 | "name": "ipython", 280 | "version": 3 281 | }, 282 | "file_extension": ".py", 283 | "mimetype": "text/x-python", 284 | "name": "python", 285 | "nbconvert_exporter": "python", 286 | "pygments_lexer": "ipython3", 287 | "version": "3.7.3" 288 | }, 289 | "toc": { 290 | "base_numbering": 1, 291 | "nav_menu": {}, 292 | "number_sections": true, 293 | "sideBar": true, 294 | "skip_h1_title": false, 295 | "title_cell": "Table of Contents", 296 | "title_sidebar": "Contents", 297 | "toc_cell": false, 298 | "toc_position": {}, 299 | "toc_section_display": true, 300 | "toc_window_display": false 301 | } 302 | }, 303 | "nbformat": 4, 304 | "nbformat_minor": 4 305 | } 306 | -------------------------------------------------------------------------------- /charpter19_SVD/louwill.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luwill/Machine_Learning_Code_Implementation/f557538b329d4004da7900b140671a31c43faa7b/charpter19_SVD/louwill.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "#### SVD" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "data": { 26 | "text/plain": [ 27 | "array([[0, 1],\n", 28 | " [1, 1],\n", 29 | " [1, 0]])" 30 | ] 31 | }, 32 | "execution_count": 2, 33 | "metadata": {}, 34 | "output_type": "execute_result" 35 | } 36 | ], 37 | "source": [ 38 | "A = np.array([[0,1],[1,1],[1,0]])\n", 39 | "A" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "name": "stdout", 49 | "output_type": "stream", 50 | "text": [ 51 | "(3, 3) (2,) (2, 2)\n" 52 | ] 53 | } 54 | ], 55 | "source": [ 56 | "u, s, vt = np.linalg.svd(A, full_matrices=True)\n", 57 | "print(u.shape, s.shape, vt.shape)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 4, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "data": { 67 | "text/plain": [ 68 | "array([[-4.08248290e-01, 7.07106781e-01, 5.77350269e-01],\n", 69 | " [-8.16496581e-01, 5.55111512e-17, -5.77350269e-01],\n", 70 | " [-4.08248290e-01, -7.07106781e-01, 5.77350269e-01]])" 71 | ] 72 | }, 73 | "execution_count": 4, 74 | "metadata": {}, 75 | "output_type": "execute_result" 76 | } 77 | ], 78 | "source": [ 79 | "u" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 5, 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "data": { 89 | "text/plain": [ 90 | "array([1.73205081, 1. 
])" 91 | ] 92 | }, 93 | "execution_count": 5, 94 | "metadata": {}, 95 | "output_type": "execute_result" 96 | } 97 | ], 98 | "source": [ 99 | "s" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 6, 105 | "metadata": {}, 106 | "outputs": [ 107 | { 108 | "data": { 109 | "text/plain": [ 110 | "array([[-0.70710678, -0.70710678],\n", 111 | " [-0.70710678, 0.70710678]])" 112 | ] 113 | }, 114 | "execution_count": 6, 115 | "metadata": {}, 116 | "output_type": "execute_result" 117 | } 118 | ], 119 | "source": [ 120 | "vt.T" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 8, 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "data": { 130 | "text/plain": [ 131 | "True" 132 | ] 133 | }, 134 | "execution_count": 8, 135 | "metadata": {}, 136 | "output_type": "execute_result" 137 | } 138 | ], 139 | "source": [ 140 | "np.allclose(A, np.dot(u[:,:2]*s, vt))" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 10, 146 | "metadata": {}, 147 | "outputs": [ 148 | { 149 | "data": { 150 | "text/plain": [ 151 | "array([[ 1.11022302e-16, 1.00000000e+00],\n", 152 | " [ 1.00000000e+00, 1.00000000e+00],\n", 153 | " [ 1.00000000e+00, -3.33066907e-16]])" 154 | ] 155 | }, 156 | "execution_count": 10, 157 | "metadata": {}, 158 | "output_type": "execute_result" 159 | } 160 | ], 161 | "source": [ 162 | "np.dot(u[:,:2]*s, vt)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 17, 168 | "metadata": {}, 169 | "outputs": [ 170 | { 171 | "data": { 172 | "text/plain": [ 173 | "array([[1.73205081, 0. ],\n", 174 | " [0. , 1. ],\n", 175 | " [0. , 0. 
]])" 176 | ] 177 | }, 178 | "execution_count": 17, 179 | "metadata": {}, 180 | "output_type": "execute_result" 181 | } 182 | ], 183 | "source": [ 184 | "s_ = np.zeros((3,2))\n", 185 | "for i in range(2):\n", 186 | " s_[i][i] = s[i]\n", 187 | "\n", 188 | "s_" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 22, 194 | "metadata": {}, 195 | "outputs": [ 196 | { 197 | "data": { 198 | "text/plain": [ 199 | "array([[ 1.11022302e-16, 1.00000000e+00],\n", 200 | " [ 1.00000000e+00, 1.00000000e+00],\n", 201 | " [ 1.00000000e+00, -3.33066907e-16]])" 202 | ] 203 | }, 204 | "execution_count": 22, 205 | "metadata": {}, 206 | "output_type": "execute_result" 207 | } 208 | ], 209 | "source": [ 210 | "np.dot(np.dot(u, s_), vt)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 7, 216 | "metadata": {}, 217 | "outputs": [ 218 | { 219 | "name": "stderr", 220 | "output_type": "stream", 221 | "text": [ 222 | "100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [02:11<00:00, 2.63s/it]\n" 223 | ] 224 | } 225 | ], 226 | "source": [ 227 | "import numpy as np\n", 228 | "import os\n", 229 | "from PIL import Image\n", 230 | "from tqdm import tqdm\n", 231 | "\n", 232 | "# 定义恢复函数,由分解后的矩阵恢复到原矩阵\n", 233 | "def restore(u, s, v, K): \n", 234 | " '''\n", 235 | " u:左奇异矩阵\n", 236 | " v:右奇异矩阵\n", 237 | " s:奇异值矩阵\n", 238 | " K:奇异值个数\n", 239 | " '''\n", 240 | " m, n = len(u), len(v[0])\n", 241 | " a = np.zeros((m, n))\n", 242 | " for k in range(K):\n", 243 | " uk = u[:, k].reshape(m, 1)\n", 244 | " vk = v[k].reshape(1, n)\n", 245 | " # 前k个奇异值的加总\n", 246 | " a += s[k] * np.dot(uk, vk) \n", 247 | " a = a.clip(0, 255)\n", 248 | " return np.rint(a).astype('uint8')\n", 249 | "\n", 250 | "A = np.array(Image.open(\"./louwill.jpg\", 'r'))\n", 251 | "# 对RGB图像进行奇异值分解\n", 252 | "u_r, s_r, v_r = np.linalg.svd(A[:, :, 0]) \n", 253 | "u_g, s_g, v_g = np.linalg.svd(A[:, :, 1])\n", 254 | "u_b, s_b, v_b = np.linalg.svd(A[:, :, 
2])\n", 255 | "\n", 256 | "# 使用前50个奇异值\n", 257 | "K = 50 \n", 258 | "output_path = r'./svd_pic'\n", 259 | "# \n", 260 | "for k in tqdm(range(1, K+1)):\n", 261 | " R = restore(u_r, s_r, v_r, k)\n", 262 | " G = restore(u_g, s_g, v_g, k)\n", 263 | " B = restore(u_b, s_b, v_b, k)\n", 264 | " I = np.stack((R, G, B), axis=2) \n", 265 | " Image.fromarray(I).save('%s\\\\svd_%d.jpg' % (output_path, k))" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 4, 271 | "metadata": {}, 272 | "outputs": [ 273 | { 274 | "data": { 275 | "text/plain": [ 276 | "(959, 959, 3)" 277 | ] 278 | }, 279 | "execution_count": 4, 280 | "metadata": {}, 281 | "output_type": "execute_result" 282 | } 283 | ], 284 | "source": [ 285 | "A.shape" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": null, 291 | "metadata": {}, 292 | "outputs": [], 293 | "source": [] 294 | } 295 | ], 296 | "metadata": { 297 | "kernelspec": { 298 | "display_name": "Python 3", 299 | "language": "python", 300 | "name": "python3" 301 | }, 302 | "language_info": { 303 | "codemirror_mode": { 304 | "name": "ipython", 305 | "version": 3 306 | }, 307 | "file_extension": ".py", 308 | "mimetype": "text/x-python", 309 | "name": "python", 310 | "nbconvert_exporter": "python", 311 | "pygments_lexer": "ipython3", 312 | "version": "3.7.3" 313 | }, 314 | "toc": { 315 | "base_numbering": 1, 316 | "nav_menu": {}, 317 | "number_sections": true, 318 | "sideBar": true, 319 | "skip_h1_title": false, 320 | "title_cell": "Table of Contents", 321 | "title_sidebar": "Contents", 322 | "toc_cell": false, 323 | "toc_position": {}, 324 | "toc_section_display": true, 325 | "toc_window_display": false 326 | } 327 | }, 328 | "nbformat": 4, 329 | "nbformat_minor": 2 330 | } 331 | -------------------------------------------------------------------------------- /charpter1_ml_start/NumPy_sklearn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 
| "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## 机器学习入门" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "data": { 17 | "text/plain": [ 18 | "array([1, 2, 3])" 19 | ] 20 | }, 21 | "execution_count": 1, 22 | "metadata": {}, 23 | "output_type": "execute_result" 24 | } 25 | ], 26 | "source": [ 27 | "# 导入numpy模块\n", 28 | "import numpy as np\n", 29 | "# 将整数列表转换为NumPy数组\n", 30 | "a = np.array([1,2,3])\n", 31 | "# 查看数组对象\n", 32 | "a" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 2, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "data": { 42 | "text/plain": [ 43 | "dtype('int32')" 44 | ] 45 | }, 46 | "execution_count": 2, 47 | "metadata": {}, 48 | "output_type": "execute_result" 49 | } 50 | ], 51 | "source": [ 52 | "# 查看整数数组对象类型\n", 53 | "a.dtype" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "data": { 63 | "text/plain": [ 64 | "dtype('float64')" 65 | ] 66 | }, 67 | "execution_count": 3, 68 | "metadata": {}, 69 | "output_type": "execute_result" 70 | } 71 | ], 72 | "source": [ 73 | "# 将浮点数列表转换为NumPy数组\n", 74 | "b = np.array([1.2, 2.3, 3.4])\n", 75 | "# 查看浮点数数组对象类型\n", 76 | "b.dtype" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 5, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "data": { 86 | "text/plain": [ 87 | "array([[1, 2, 3],\n", 88 | " [4, 5, 6]])" 89 | ] 90 | }, 91 | "execution_count": 5, 92 | "metadata": {}, 93 | "output_type": "execute_result" 94 | } 95 | ], 96 | "source": [ 97 | "# 将两个整数列表转换为二维NumPy数组\n", 98 | "c = np.array([[1,2,3], [4,5,6]])\n", 99 | "c" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 6, 105 | "metadata": {}, 106 | "outputs": [ 107 | { 108 | "data": { 109 | "text/plain": [ 110 | "array([[0., 0., 0.],\n", 111 | " [0., 0., 0.]])" 112 | ] 113 | }, 114 | "execution_count": 6, 115 | 
"metadata": {}, 116 | "output_type": "execute_result" 117 | } 118 | ], 119 | "source": [ 120 | "# 生成2×3的全0数组\n", 121 | "np.zeros((2, 3))" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 8, 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "data": { 131 | "text/plain": [ 132 | "array([[1, 1, 1, 1],\n", 133 | " [1, 1, 1, 1],\n", 134 | " [1, 1, 1, 1]], dtype=int16)" 135 | ] 136 | }, 137 | "execution_count": 8, 138 | "metadata": {}, 139 | "output_type": "execute_result" 140 | } 141 | ], 142 | "source": [ 143 | "# 生成3×4的全1数组\n", 144 | "np.ones((3, 4), dtype=np.int16)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 9, 150 | "metadata": {}, 151 | "outputs": [ 152 | { 153 | "data": { 154 | "text/plain": [ 155 | "array([[0., 0., 0.],\n", 156 | " [0., 0., 0.]])" 157 | ] 158 | }, 159 | "execution_count": 9, 160 | "metadata": {}, 161 | "output_type": "execute_result" 162 | } 163 | ], 164 | "source": [ 165 | "# 生成2×3的随机数数组\n", 166 | "np.empty([2, 3])" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 10, 172 | "metadata": {}, 173 | "outputs": [ 174 | { 175 | "data": { 176 | "text/plain": [ 177 | "array([10, 15, 20, 25])" 178 | ] 179 | }, 180 | "execution_count": 10, 181 | "metadata": {}, 182 | "output_type": "execute_result" 183 | } 184 | ], 185 | "source": [ 186 | "# arange方法用于创建给定范围内的数组\n", 187 | "np.arange(10, 30, 5 )" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 11, 193 | "metadata": {}, 194 | "outputs": [ 195 | { 196 | "data": { 197 | "text/plain": [ 198 | "array([[0.63400839, 0.00645055],\n", 199 | " [0.05525655, 0.11816511],\n", 200 | " [0.72091326, 0.18560605]])" 201 | ] 202 | }, 203 | "execution_count": 11, 204 | "metadata": {}, 205 | "output_type": "execute_result" 206 | } 207 | ], 208 | "source": [ 209 | "# 生成3×2的符合(0,1)均匀分布的随机数数组\n", 210 | "np.random.rand(3, 2)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 12, 216 
| "metadata": {}, 217 | "outputs": [ 218 | { 219 | "data": { 220 | "text/plain": [ 221 | "array([2, 2, 1, 0, 0])" 222 | ] 223 | }, 224 | "execution_count": 12, 225 | "metadata": {}, 226 | "output_type": "execute_result" 227 | } 228 | ], 229 | "source": [ 230 | "# 生成0到2范围内长度为5的数组\n", 231 | "np.random.randint(3, size=5)" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 13, 237 | "metadata": {}, 238 | "outputs": [ 239 | { 240 | "data": { 241 | "text/plain": [ 242 | "array([-1.17524897, 0.53607197, -0.79819063])" 243 | ] 244 | }, 245 | "execution_count": 13, 246 | "metadata": {}, 247 | "output_type": "execute_result" 248 | } 249 | ], 250 | "source": [ 251 | "# 生成一组符合标准正态分布的随机数数组\n", 252 | "np.random.randn(3)" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 14, 258 | "metadata": {}, 259 | "outputs": [ 260 | { 261 | "data": { 262 | "text/plain": [ 263 | "array([ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81], dtype=int32)" 264 | ] 265 | }, 266 | "execution_count": 14, 267 | "metadata": {}, 268 | "output_type": "execute_result" 269 | } 270 | ], 271 | "source": [ 272 | "# 创建一个一维数组 \n", 273 | "a = np.arange(10)**2\n", 274 | "a" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 15, 280 | "metadata": {}, 281 | "outputs": [ 282 | { 283 | "data": { 284 | "text/plain": [ 285 | "4" 286 | ] 287 | }, 288 | "execution_count": 15, 289 | "metadata": {}, 290 | "output_type": "execute_result" 291 | } 292 | ], 293 | "source": [ 294 | "# 获取数组的第3个元素\n", 295 | "a[2]" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 16, 301 | "metadata": {}, 302 | "outputs": [ 303 | { 304 | "data": { 305 | "text/plain": [ 306 | "array([1, 4, 9], dtype=int32)" 307 | ] 308 | }, 309 | "execution_count": 16, 310 | "metadata": {}, 311 | "output_type": "execute_result" 312 | } 313 | ], 314 | "source": [ 315 | "# 获取第2个到第4个数组元素\n", 316 | "a[1:4]" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | 
"execution_count": 17, 322 | "metadata": {}, 323 | "outputs": [ 324 | { 325 | "data": { 326 | "text/plain": [ 327 | "array([81, 64, 49, 36, 25, 16, 9, 4, 1, 0], dtype=int32)" 328 | ] 329 | }, 330 | "execution_count": 17, 331 | "metadata": {}, 332 | "output_type": "execute_result" 333 | } 334 | ], 335 | "source": [ 336 | "# 一维数组翻转\n", 337 | "a[::-1]" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": 18, 343 | "metadata": {}, 344 | "outputs": [ 345 | { 346 | "data": { 347 | "text/plain": [ 348 | "array([[0.15845869, 0.7332507 , 0.08502049],\n", 349 | " [0.39413397, 0.69465393, 0.68838422],\n", 350 | " [0.89048929, 0.12574353, 0.59993755]])" 351 | ] 352 | }, 353 | "execution_count": 18, 354 | "metadata": {}, 355 | "output_type": "execute_result" 356 | } 357 | ], 358 | "source": [ 359 | "# 创建一个多维数组\n", 360 | "b = np.random.random((3,3))\n", 361 | "b" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": 19, 367 | "metadata": {}, 368 | "outputs": [ 369 | { 370 | "data": { 371 | "text/plain": [ 372 | "0.6883842231755073" 373 | ] 374 | }, 375 | "execution_count": 19, 376 | "metadata": {}, 377 | "output_type": "execute_result" 378 | } 379 | ], 380 | "source": [ 381 | "# 获取第2行第3列的数组元素\n", 382 | "b[1,2]" 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": 20, 388 | "metadata": {}, 389 | "outputs": [ 390 | { 391 | "data": { 392 | "text/plain": [ 393 | "array([0.7332507 , 0.69465393, 0.12574353])" 394 | ] 395 | }, 396 | "execution_count": 20, 397 | "metadata": {}, 398 | "output_type": "execute_result" 399 | } 400 | ], 401 | "source": [ 402 | "# 获取第2列数据\n", 403 | "b[:,1]" 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": 21, 409 | "metadata": {}, 410 | "outputs": [ 411 | { 412 | "data": { 413 | "text/plain": [ 414 | "array([0.08502049, 0.68838422])" 415 | ] 416 | }, 417 | "execution_count": 21, 418 | "metadata": {}, 419 | "output_type": "execute_result" 420 | } 421 | ], 422 | 
"source": [ 423 | "# 获取第3列前两行数据\n", 424 | "b[:2, 2]" 425 | ] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": 22, 430 | "metadata": {}, 431 | "outputs": [ 432 | { 433 | "data": { 434 | "text/plain": [ 435 | "array([ 5, 9, 13, 17])" 436 | ] 437 | }, 438 | "execution_count": 22, 439 | "metadata": {}, 440 | "output_type": "execute_result" 441 | } 442 | ], 443 | "source": [ 444 | "# 创建两个不同的数组\n", 445 | "a = np.arange(4)\n", 446 | "b = np.array([5,10,15,20])\n", 447 | "# 两个数组做减法运算\n", 448 | "b-a" 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": 23, 454 | "metadata": {}, 455 | "outputs": [ 456 | { 457 | "data": { 458 | "text/plain": [ 459 | "array([ 25, 100, 225, 400], dtype=int32)" 460 | ] 461 | }, 462 | "execution_count": 23, 463 | "metadata": {}, 464 | "output_type": "execute_result" 465 | } 466 | ], 467 | "source": [ 468 | "# 计算数组的平方\n", 469 | "b**2" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": 24, 475 | "metadata": {}, 476 | "outputs": [ 477 | { 478 | "data": { 479 | "text/plain": [ 480 | "array([0. 
, 0.84147098, 0.90929743, 0.14112001])" 481 | ] 482 | }, 483 | "execution_count": 24, 484 | "metadata": {}, 485 | "output_type": "execute_result" 486 | } 487 | ], 488 | "source": [ 489 | "# 计算数组的正弦值\n", 490 | "np.sin(a)" 491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": 25, 496 | "metadata": {}, 497 | "outputs": [ 498 | { 499 | "data": { 500 | "text/plain": [ 501 | "array([ True, True, True, False])" 502 | ] 503 | }, 504 | "execution_count": 25, 505 | "metadata": {}, 506 | "output_type": "execute_result" 507 | } 508 | ], 509 | "source": [ 510 | "# 数组的逻辑运算\n", 511 | "b<20" 512 | ] 513 | }, 514 | { 515 | "cell_type": "code", 516 | "execution_count": 26, 517 | "metadata": {}, 518 | "outputs": [ 519 | { 520 | "data": { 521 | "text/plain": [ 522 | "12.5" 523 | ] 524 | }, 525 | "execution_count": 26, 526 | "metadata": {}, 527 | "output_type": "execute_result" 528 | } 529 | ], 530 | "source": [ 531 | "# 数组求均值\n", 532 | "np.mean(b)" 533 | ] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "execution_count": 27, 538 | "metadata": {}, 539 | "outputs": [ 540 | { 541 | "data": { 542 | "text/plain": [ 543 | "31.25" 544 | ] 545 | }, 546 | "execution_count": 27, 547 | "metadata": {}, 548 | "output_type": "execute_result" 549 | } 550 | ], 551 | "source": [ 552 | "# 数组求方差\n", 553 | "np.var(b)" 554 | ] 555 | }, 556 | { 557 | "cell_type": "code", 558 | "execution_count": 31, 559 | "metadata": {}, 560 | "outputs": [ 561 | { 562 | "data": { 563 | "text/plain": [ 564 | "array([[2, 0],\n", 565 | " [0, 4]])" 566 | ] 567 | }, 568 | "execution_count": 31, 569 | "metadata": {}, 570 | "output_type": "execute_result" 571 | } 572 | ], 573 | "source": [ 574 | "# 创建两个不同的数组\n", 575 | "A = np.array([[1,1],\n", 576 | " [0,1]])\n", 577 | "B = np.array([[2,0],\n", 578 | " [3,4]])\n", 579 | "# 矩阵元素乘积\n", 580 | "A * B" 581 | ] 582 | }, 583 | { 584 | "cell_type": "code", 585 | "execution_count": 32, 586 | "metadata": {}, 587 | "outputs": [ 588 | { 589 | "data": { 590 | 
"text/plain": [ 591 | "array([[5, 4],\n", 592 | " [3, 4]])" 593 | ] 594 | }, 595 | "execution_count": 32, 596 | "metadata": {}, 597 | "output_type": "execute_result" 598 | } 599 | ], 600 | "source": [ 601 | "# 矩阵点乘\n", 602 | "A.dot(B)" 603 | ] 604 | }, 605 | { 606 | "cell_type": "code", 607 | "execution_count": 33, 608 | "metadata": {}, 609 | "outputs": [ 610 | { 611 | "data": { 612 | "text/plain": [ 613 | "array([[ 1., -1.],\n", 614 | " [ 0., 1.]])" 615 | ] 616 | }, 617 | "execution_count": 33, 618 | "metadata": {}, 619 | "output_type": "execute_result" 620 | } 621 | ], 622 | "source": [ 623 | "# 矩阵求逆\n", 624 | "np.linalg.inv(A)" 625 | ] 626 | }, 627 | { 628 | "cell_type": "code", 629 | "execution_count": 34, 630 | "metadata": {}, 631 | "outputs": [ 632 | { 633 | "data": { 634 | "text/plain": [ 635 | "1.0" 636 | ] 637 | }, 638 | "execution_count": 34, 639 | "metadata": {}, 640 | "output_type": "execute_result" 641 | } 642 | ], 643 | "source": [ 644 | "# 矩阵求行列式\n", 645 | "np.linalg.det(A)" 646 | ] 647 | }, 648 | { 649 | "cell_type": "code", 650 | "execution_count": 36, 651 | "metadata": {}, 652 | "outputs": [ 653 | { 654 | "data": { 655 | "text/plain": [ 656 | "array([[8., 6., 7., 9.],\n", 657 | " [7., 0., 6., 2.],\n", 658 | " [9., 7., 3., 4.]])" 659 | ] 660 | }, 661 | "execution_count": 36, 662 | "metadata": {}, 663 | "output_type": "execute_result" 664 | } 665 | ], 666 | "source": [ 667 | "# 创建一个3×4的数组\n", 668 | "a = np.floor(10*np.random.random((3,4)))\n", 669 | "a" 670 | ] 671 | }, 672 | { 673 | "cell_type": "code", 674 | "execution_count": 37, 675 | "metadata": {}, 676 | "outputs": [ 677 | { 678 | "data": { 679 | "text/plain": [ 680 | "(3, 4)" 681 | ] 682 | }, 683 | "execution_count": 37, 684 | "metadata": {}, 685 | "output_type": "execute_result" 686 | } 687 | ], 688 | "source": [ 689 | "# 查看数组维度\n", 690 | "a.shape" 691 | ] 692 | }, 693 | { 694 | "cell_type": "code", 695 | "execution_count": 38, 696 | "metadata": {}, 697 | "outputs": [ 698 | { 699 | "data": { 
700 | "text/plain": [ 701 | "array([8., 6., 7., 9., 7., 0., 6., 2., 9., 7., 3., 4.])" 702 | ] 703 | }, 704 | "execution_count": 38, 705 | "metadata": {}, 706 | "output_type": "execute_result" 707 | } 708 | ], 709 | "source": [ 710 | "# 数组展平\n", 711 | "a.ravel()" 712 | ] 713 | }, 714 | { 715 | "cell_type": "code", 716 | "execution_count": 39, 717 | "metadata": {}, 718 | "outputs": [ 719 | { 720 | "data": { 721 | "text/plain": [ 722 | "array([[8., 6., 7., 9., 7., 0.],\n", 723 | " [6., 2., 9., 7., 3., 4.]])" 724 | ] 725 | }, 726 | "execution_count": 39, 727 | "metadata": {}, 728 | "output_type": "execute_result" 729 | } 730 | ], 731 | "source": [ 732 | "# 将数组变换为2×6数组\n", 733 | "a.reshape(2,6)" 734 | ] 735 | }, 736 | { 737 | "cell_type": "code", 738 | "execution_count": 40, 739 | "metadata": {}, 740 | "outputs": [ 741 | { 742 | "data": { 743 | "text/plain": [ 744 | "array([[8., 7., 9.],\n", 745 | " [6., 0., 7.],\n", 746 | " [7., 6., 3.],\n", 747 | " [9., 2., 4.]])" 748 | ] 749 | }, 750 | "execution_count": 40, 751 | "metadata": {}, 752 | "output_type": "execute_result" 753 | } 754 | ], 755 | "source": [ 756 | "# 求数组的转置\n", 757 | "a.T" 758 | ] 759 | }, 760 | { 761 | "cell_type": "code", 762 | "execution_count": 42, 763 | "metadata": {}, 764 | "outputs": [ 765 | { 766 | "data": { 767 | "text/plain": [ 768 | "(4, 3)" 769 | ] 770 | }, 771 | "execution_count": 42, 772 | "metadata": {}, 773 | "output_type": "execute_result" 774 | } 775 | ], 776 | "source": [ 777 | "a.T.shape" 778 | ] 779 | }, 780 | { 781 | "cell_type": "code", 782 | "execution_count": 43, 783 | "metadata": {}, 784 | "outputs": [ 785 | { 786 | "data": { 787 | "text/plain": [ 788 | "array([[8., 6., 7., 9.],\n", 789 | " [7., 0., 6., 2.],\n", 790 | " [9., 7., 3., 4.]])" 791 | ] 792 | }, 793 | "execution_count": 43, 794 | "metadata": {}, 795 | "output_type": "execute_result" 796 | } 797 | ], 798 | "source": [ 799 | "# -1维度表示NumPy会自动计算该维度\n", 800 | "a.reshape(3,-1)" 801 | ] 802 | }, 803 | { 804 | "cell_type": 
"code", 805 | "execution_count": 44, 806 | "metadata": {}, 807 | "outputs": [ 808 | { 809 | "data": { 810 | "text/plain": [ 811 | "array([[1, 1, 2, 0],\n", 812 | " [0, 1, 3, 4]])" 813 | ] 814 | }, 815 | "execution_count": 44, 816 | "metadata": {}, 817 | "output_type": "execute_result" 818 | } 819 | ], 820 | "source": [ 821 | "# 按行合并代码清单1-7中的A数组和B数组\n", 822 | "np.hstack((A,B))" 823 | ] 824 | }, 825 | { 826 | "cell_type": "code", 827 | "execution_count": 45, 828 | "metadata": {}, 829 | "outputs": [ 830 | { 831 | "data": { 832 | "text/plain": [ 833 | "array([[1, 1],\n", 834 | " [0, 1],\n", 835 | " [2, 0],\n", 836 | " [3, 4]])" 837 | ] 838 | }, 839 | "execution_count": 45, 840 | "metadata": {}, 841 | "output_type": "execute_result" 842 | } 843 | ], 844 | "source": [ 845 | "# 按列合并A数组和B数组\n", 846 | "np.vstack((A,B))" 847 | ] 848 | }, 849 | { 850 | "cell_type": "code", 851 | "execution_count": 46, 852 | "metadata": {}, 853 | "outputs": [ 854 | { 855 | "data": { 856 | "text/plain": [ 857 | "array([[ 0., 1., 2., 3.],\n", 858 | " [ 4., 5., 6., 7.],\n", 859 | " [ 8., 9., 10., 11.],\n", 860 | " [12., 13., 14., 15.]])" 861 | ] 862 | }, 863 | "execution_count": 46, 864 | "metadata": {}, 865 | "output_type": "execute_result" 866 | } 867 | ], 868 | "source": [ 869 | "# 创建一个新数组\n", 870 | "C = np.arange(16.0).reshape(4, 4)\n", 871 | "C" 872 | ] 873 | }, 874 | { 875 | "cell_type": "code", 876 | "execution_count": 47, 877 | "metadata": {}, 878 | "outputs": [ 879 | { 880 | "data": { 881 | "text/plain": [ 882 | "[array([[ 0., 1.],\n", 883 | " [ 4., 5.],\n", 884 | " [ 8., 9.],\n", 885 | " [12., 13.]]),\n", 886 | " array([[ 2., 3.],\n", 887 | " [ 6., 7.],\n", 888 | " [10., 11.],\n", 889 | " [14., 15.]])]" 890 | ] 891 | }, 892 | "execution_count": 47, 893 | "metadata": {}, 894 | "output_type": "execute_result" 895 | } 896 | ], 897 | "source": [ 898 | "# 按水平方向将数组C切分为两个数组\n", 899 | "np.hsplit(C, 2)" 900 | ] 901 | }, 902 | { 903 | "cell_type": "code", 904 | "execution_count": 48, 905 | 
"metadata": {}, 906 | "outputs": [ 907 | { 908 | "data": { 909 | "text/plain": [ 910 | "[array([[0., 1., 2., 3.],\n", 911 | " [4., 5., 6., 7.]]),\n", 912 | " array([[ 8., 9., 10., 11.],\n", 913 | " [12., 13., 14., 15.]])]" 914 | ] 915 | }, 916 | "execution_count": 48, 917 | "metadata": {}, 918 | "output_type": "execute_result" 919 | } 920 | ], 921 | "source": [ 922 | "# 按垂直方向将数组C切分为两个数组\n", 923 | "np.vsplit(C, 2)" 924 | ] 925 | }, 926 | { 927 | "cell_type": "code", 928 | "execution_count": 49, 929 | "metadata": {}, 930 | "outputs": [ 931 | { 932 | "name": "stderr", 933 | "output_type": "stream", 934 | "text": [ 935 | "D:\\installation\\anaconda\\install_files\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:764: ConvergenceWarning: lbfgs failed to converge (status=1):\n", 936 | "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", 937 | "\n", 938 | "Increase the number of iterations (max_iter) or scale the data as shown in:\n", 939 | " https://scikit-learn.org/stable/modules/preprocessing.html\n", 940 | "Please also refer to the documentation for alternative solver options:\n", 941 | " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", 942 | " extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n" 943 | ] 944 | }, 945 | { 946 | "data": { 947 | "text/plain": [ 948 | "array([0, 0])" 949 | ] 950 | }, 951 | "execution_count": 49, 952 | "metadata": {}, 953 | "output_type": "execute_result" 954 | } 955 | ], 956 | "source": [ 957 | "# 导入iris数据集和逻辑回归算法模块\n", 958 | "from sklearn.datasets import load_iris\n", 959 | "from sklearn.linear_model import LogisticRegression\n", 960 | "# 导入数据\n", 961 | "X, y = load_iris(return_X_y=True)\n", 962 | "# 拟合模型\n", 963 | "clf = LogisticRegression(random_state=0).fit(X, y)\n", 964 | "# 预测\n", 965 | "clf.predict(X[:2, :])" 966 | ] 967 | }, 968 | { 969 | "cell_type": "code", 970 | "execution_count": 50, 971 | "metadata": {}, 972 | "outputs": [ 973 | { 974 | "data": { 975 | "text/plain": [ 976 | 
"array([[9.81797141e-01, 1.82028445e-02, 1.44269293e-08],\n", 977 | " [9.71725476e-01, 2.82744937e-02, 3.01659208e-08]])" 978 | ] 979 | }, 980 | "execution_count": 50, 981 | "metadata": {}, 982 | "output_type": "execute_result" 983 | } 984 | ], 985 | "source": [ 986 | "# 概率预测\n", 987 | "clf.predict_proba(X[:2, :])" 988 | ] 989 | }, 990 | { 991 | "cell_type": "code", 992 | "execution_count": 51, 993 | "metadata": {}, 994 | "outputs": [ 995 | { 996 | "data": { 997 | "text/plain": [ 998 | "0.9733333333333334" 999 | ] 1000 | }, 1001 | "execution_count": 51, 1002 | "metadata": {}, 1003 | "output_type": "execute_result" 1004 | } 1005 | ], 1006 | "source": [ 1007 | "# 模型准确率\n", 1008 | "clf.score(X, y)" 1009 | ] 1010 | }, 1011 | { 1012 | "cell_type": "code", 1013 | "execution_count": null, 1014 | "metadata": {}, 1015 | "outputs": [], 1016 | "source": [] 1017 | } 1018 | ], 1019 | "metadata": { 1020 | "kernelspec": { 1021 | "display_name": "Python 3", 1022 | "language": "python", 1023 | "name": "python3" 1024 | }, 1025 | "language_info": { 1026 | "codemirror_mode": { 1027 | "name": "ipython", 1028 | "version": 3 1029 | }, 1030 | "file_extension": ".py", 1031 | "mimetype": "text/x-python", 1032 | "name": "python", 1033 | "nbconvert_exporter": "python", 1034 | "pygments_lexer": "ipython3", 1035 | "version": "3.7.5" 1036 | } 1037 | }, 1038 | "nbformat": 4, 1039 | "nbformat_minor": 4 1040 | } 1041 | -------------------------------------------------------------------------------- /charpter20_MEM/max_entropy_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import numpy as np\n", 11 | "from collections import defaultdict\n", 12 | "\n", 13 | "class MaxEnt:\n", 14 | " def __init__(self, max_iter=100):\n", 15 | " # 训练输入\n", 16 | " self.X_ = None\n", 17 | " # 训练标签\n", 18 | " 
self.y_ = None\n", 19 | " # 标签类别数量\n", 20 | " self.m = None \n", 21 | " # 特征数量\n", 22 | " self.n = None \n", 23 | " # 训练样本量\n", 24 | " self.N = None \n", 25 | " # 常数特征取值\n", 26 | " self.M = None\n", 27 | " # 权重系数\n", 28 | " self.w = None\n", 29 | " # 标签名称\n", 30 | " self.labels = defaultdict(int)\n", 31 | " # 特征名称\n", 32 | " self.features = defaultdict(int)\n", 33 | " # 最大迭代次数\n", 34 | " self.max_iter = max_iter\n", 35 | "\n", 36 | " ### 计算特征函数关于经验联合分布P(X,Y)的期望\n", 37 | " def _EP_hat_f(self, x, y):\n", 38 | " self.Pxy = np.zeros((self.m, self.n))\n", 39 | " self.Px = np.zeros(self.n)\n", 40 | " for x_, y_ in zip(x, y):\n", 41 | " # 遍历每个样本\n", 42 | " for x__ in set(x_):\n", 43 | " self.Pxy[self.labels[y_], self.features[x__]] += 1\n", 44 | " self.Px[self.features[x__]] += 1 \n", 45 | " self.EP_hat_f = self.Pxy/self.N\n", 46 | " \n", 47 | " ### 计算特征函数关于模型P(Y|X)与经验分布P(X)的期望\n", 48 | " def _EP_f(self):\n", 49 | " # 每次迭代前重置期望矩阵(原代码误写为self.EPf,导致self.EP_f未初始化)\n", 50 | " self.EP_f = np.zeros((self.m, self.n))\n", 51 | " for X in self.X_:\n", 52 | " pw = self._pw(X)\n", 53 | " pw = pw.reshape(self.m, 1)\n", 54 | " px = self.Px.reshape(1, self.n)\n", 55 | " self.EP_f += pw*px / self.N\n", 56 | " \n", 57 | " ### 最大熵模型P(y|x)\n", 58 | " def _pw(self, x):\n", 59 | " mask = np.zeros(self.n+1)\n", 60 | " for ix in x:\n", 61 | " mask[self.features[ix]] = 1\n", 62 | " tmp = self.w * mask[1:]\n", 63 | " pw = np.exp(np.sum(tmp, axis=1))\n", 64 | " Z = np.sum(pw)\n", 65 | " pw = pw/Z\n", 66 | " return pw\n", 67 | "\n", 68 | " ### 熵模型拟合\n", 69 | " ### 基于改进的迭代尺度方法IIS\n", 70 | " def fit(self, x, y):\n", 71 | " # 训练输入\n", 72 | " self.X_ = x\n", 73 | " # 训练输出\n", 74 | " self.y_ = list(set(y))\n", 75 | " # 输入数据展平后集合\n", 76 | " tmp = set(self.X_.flatten())\n", 77 | " # 特征命名\n", 78 | " self.features = defaultdict(int, zip(tmp, range(1, len(tmp)+1))) \n", 79 | " # 标签命名\n", 80 | " self.labels = dict(zip(self.y_, range(len(self.y_))))\n", 81 | " # 特征数\n", 82 | " self.n = len(self.features)+1 \n", 83 | " # 标签类别数量\n", 84 | " self.m = 
len(self.labels)\n", 84 | " # 训练样本量\n", 85 | " self.N = len(x) \n", 86 | " # 计算EP_hat_f\n", 87 | " self._EP_hat_f(x, y)\n", 88 | " # 初始化系数矩阵\n", 89 | " self.w = np.zeros((self.m, self.n))\n", 90 | " # 循环迭代\n", 91 | " i = 0\n", 92 | " while i <= self.max_iter:\n", 93 | " # 计算EPf\n", 94 | " self._EP_f()\n", 95 | " # 令常数特征函数为M\n", 96 | " self.M = 100\n", 97 | " # IIS算法步骤(3)\n", 98 | " tmp = np.true_divide(self.EP_hat_f, self.EP_f)\n", 99 | " tmp[tmp == np.inf] = 0\n", 100 | " tmp = np.nan_to_num(tmp)\n", 101 | " sigma = np.where(tmp != 0, 1/self.M*np.log(tmp), 0) \n", 102 | " # 更新系数:IIS步骤(4)\n", 103 | " self.w = self.w + sigma\n", 104 | " i += 1\n", 105 | " print('training done.')\n", 106 | " return self\n", 107 | "\n", 108 | " # 定义最大熵模型预测函数\n", 109 | " def predict(self, x):\n", 110 | " res = np.zeros(len(x), dtype=np.int64)\n", 111 | " for ix, x_ in enumerate(x):\n", 112 | " tmp = self._pw(x_)\n", 113 | " print(tmp, np.argmax(tmp), self.labels)\n", 114 | " res[ix] = self.labels[self.y_[np.argmax(tmp)]]\n", 115 | " return np.array([self.y_[ix] for ix in res])" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 2, 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "name": "stdout", 125 | "output_type": "stream", 126 | "text": [ 127 | "(105, 4) (105,)\n" 128 | ] 129 | } 130 | ], 131 | "source": [ 132 | "from sklearn.datasets import load_iris\n", 133 | "from sklearn.model_selection import train_test_split\n", 134 | "raw_data = load_iris()\n", 135 | "X, labels = raw_data.data, raw_data.target\n", 136 | "X_train, X_test, y_train, y_test = train_test_split(X, labels, test_size=0.3, random_state=43)\n", 137 | "print(X_train.shape, y_train.shape)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 3, 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "data": { 147 | "text/plain": [ 148 | "array([2, 2, 2, 2, 2])" 149 | ] 150 | }, 151 | "execution_count": 3, 152 | "metadata": {}, 153 | "output_type": "execute_result" 
154 | } 155 | ], 156 | "source": [ 157 | "labels[-5:]" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 4, 163 | "metadata": {}, 164 | "outputs": [ 165 | { 166 | "name": "stderr", 167 | "output_type": "stream", 168 | "text": [ 169 | "D:\\Installation\\anaconda\\install\\lib\\site-packages\\ipykernel_launcher.py:90: RuntimeWarning: invalid value encountered in true_divide\n", 170 | "D:\\Installation\\anaconda\\install\\lib\\site-packages\\ipykernel_launcher.py:93: RuntimeWarning: divide by zero encountered in log\n" 171 | ] 172 | }, 173 | { 174 | "name": "stdout", 175 | "output_type": "stream", 176 | "text": [ 177 | "training done.\n", 178 | "[0.87116843 0.04683368 0.08199789] 0 {0: 0, 1: 1, 2: 2}\n", 179 | "[0.00261138 0.49573305 0.50165557] 2 {0: 0, 1: 1, 2: 2}\n", 180 | "[0.12626693 0.017157 0.85657607] 2 {0: 0, 1: 1, 2: 2}\n", 181 | "[1.55221378e-04 4.45985560e-05 9.99800180e-01] 2 {0: 0, 1: 1, 2: 2}\n", 182 | "[7.29970746e-03 9.92687370e-01 1.29226740e-05] 1 {0: 0, 1: 1, 2: 2}\n", 183 | "[0.01343943 0.01247887 0.9740817 ] 2 {0: 0, 1: 1, 2: 2}\n", 184 | "[0.85166079 0.05241898 0.09592023] 0 {0: 0, 1: 1, 2: 2}\n", 185 | "[0.00371481 0.00896982 0.98731537] 2 {0: 0, 1: 1, 2: 2}\n", 186 | "[2.69340079e-04 9.78392776e-01 2.13378835e-02] 1 {0: 0, 1: 1, 2: 2}\n", 187 | "[0.01224702 0.02294254 0.96481044] 2 {0: 0, 1: 1, 2: 2}\n", 188 | "[0.00323508 0.98724246 0.00952246] 1 {0: 0, 1: 1, 2: 2}\n", 189 | "[0.00196548 0.01681989 0.98121463] 2 {0: 0, 1: 1, 2: 2}\n", 190 | "[0.00480966 0.00345107 0.99173927] 2 {0: 0, 1: 1, 2: 2}\n", 191 | "[0.00221101 0.01888735 0.97890163] 2 {0: 0, 1: 1, 2: 2}\n", 192 | "[9.87528545e-01 3.25313387e-04 1.21461416e-02] 0 {0: 0, 1: 1, 2: 2}\n", 193 | "[3.84153917e-05 5.25603786e-01 4.74357798e-01] 1 {0: 0, 1: 1, 2: 2}\n", 194 | "[0.91969448 0.00730851 0.07299701] 0 {0: 0, 1: 1, 2: 2}\n", 195 | "[3.48493252e-03 9.96377722e-01 1.37345863e-04] 1 {0: 0, 1: 1, 2: 2}\n", 196 | "[0.00597935 0.02540794 0.96861271] 2 {0: 0, 
1: 1, 2: 2}\n", 197 | "[0.96593729 0.01606867 0.01799404] 0 {0: 0, 1: 1, 2: 2}\n", 198 | "[7.07324443e-01 2.92672257e-01 3.29961259e-06] 0 {0: 0, 1: 1, 2: 2}\n", 199 | "[0.96122092 0.03604362 0.00273547] 0 {0: 0, 1: 1, 2: 2}\n", 200 | "[9.92671813e-01 7.31265179e-03 1.55352641e-05] 0 {0: 0, 1: 1, 2: 2}\n", 201 | "[9.99997290e-01 2.58555077e-06 1.24081335e-07] 0 {0: 0, 1: 1, 2: 2}\n", 202 | "[1.77991802e-05 4.62006560e-04 9.99520194e-01] 2 {0: 0, 1: 1, 2: 2}\n", 203 | "[9.99995176e-01 3.85240188e-06 9.72067357e-07] 0 {0: 0, 1: 1, 2: 2}\n", 204 | "[0.15306343 0.21405142 0.63288515] 2 {0: 0, 1: 1, 2: 2}\n", 205 | "[0.25817329 0.28818997 0.45363674] 2 {0: 0, 1: 1, 2: 2}\n", 206 | "[2.43530473e-04 4.07929999e-01 5.91826471e-01] 2 {0: 0, 1: 1, 2: 2}\n", 207 | "[0.71160155 0.27290911 0.01548934] 0 {0: 0, 1: 1, 2: 2}\n", 208 | "[2.94976826e-06 2.51510534e-02 9.74845997e-01] 2 {0: 0, 1: 1, 2: 2}\n", 209 | "[0.97629163 0.00331591 0.02039245] 0 {0: 0, 1: 1, 2: 2}\n", 210 | "[0.04513811 0.01484173 0.94002015] 2 {0: 0, 1: 1, 2: 2}\n", 211 | "[0.61382753 0.38321073 0.00296174] 0 {0: 0, 1: 1, 2: 2}\n", 212 | "[9.65538451e-01 3.86322918e-06 3.44576854e-02] 0 {0: 0, 1: 1, 2: 2}\n", 213 | "[0.00924088 0.01731108 0.97344804] 2 {0: 0, 1: 1, 2: 2}\n", 214 | "[0.02511142 0.93818613 0.03670245] 1 {0: 0, 1: 1, 2: 2}\n", 215 | "[9.99127831e-01 3.29723254e-04 5.42445518e-04] 0 {0: 0, 1: 1, 2: 2}\n", 216 | "[0.05081665 0.0038204 0.94536295] 2 {0: 0, 1: 1, 2: 2}\n", 217 | "[9.99985376e-01 6.85280694e-06 7.77081022e-06] 0 {0: 0, 1: 1, 2: 2}\n", 218 | "[9.99791732e-01 2.06536005e-04 1.73191035e-06] 0 {0: 0, 1: 1, 2: 2}\n", 219 | "[2.72323181e-04 2.99692548e-03 9.96730751e-01] 2 {0: 0, 1: 1, 2: 2}\n", 220 | "[0.02005139 0.97151852 0.00843009] 1 {0: 0, 1: 1, 2: 2}\n", 221 | "[0.95642409 0.02485912 0.01871679] 0 {0: 0, 1: 1, 2: 2}\n", 222 | "[0.00297317 0.01261126 0.98441558] 2 {0: 0, 1: 1, 2: 2}\n", 223 | "0.37777777777777777\n" 224 | ] 225 | } 226 | ], 227 | "source": [ 228 | "from 
sklearn.metrics import accuracy_score\n", 229 | "maxent = MaxEnt()\n", 230 | "maxent.fit(X_train, y_train)\n", 231 | "y_pred = maxent.predict(X_test)\n", 232 | "print(accuracy_score(y_test, y_pred))" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [] 241 | } 242 | ], 243 | "metadata": { 244 | "kernelspec": { 245 | "display_name": "Python 3", 246 | "language": "python", 247 | "name": "python3" 248 | }, 249 | "language_info": { 250 | "codemirror_mode": { 251 | "name": "ipython", 252 | "version": 3 253 | }, 254 | "file_extension": ".py", 255 | "mimetype": "text/x-python", 256 | "name": "python", 257 | "nbconvert_exporter": "python", 258 | "pygments_lexer": "ipython3", 259 | "version": "3.7.3" 260 | }, 261 | "toc": { 262 | "base_numbering": 1, 263 | "nav_menu": {}, 264 | "number_sections": true, 265 | "sideBar": true, 266 | "skip_h1_title": false, 267 | "title_cell": "Table of Contents", 268 | "title_sidebar": "Contents", 269 | "toc_cell": false, 270 | "toc_position": {}, 271 | "toc_section_display": true, 272 | "toc_window_display": false 273 | } 274 | }, 275 | "nbformat": 4, 276 | "nbformat_minor": 2 277 | } 278 | -------------------------------------------------------------------------------- /charpter21_Bayesian_models/bayesian_network.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### bayesian network" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# 导入pgmpy相关模块\n", 17 | "from pgmpy.factors.discrete import TabularCPD\n", 18 | "from pgmpy.models import BayesianModel\n", 19 | "letter_model = BayesianModel([('D', 'G'),\n", 20 | " ('I', 'G'),\n", 21 | " ('G', 'L'),\n", 22 | " ('I', 'S')])" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | 
"execution_count": 4, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "# 学生成绩的条件概率分布\n", 32 | "grade_cpd = TabularCPD(\n", 33 | " variable='G', # 节点名称\n", 34 | " variable_card=3, # 节点取值个数\n", 35 | " values=[[0.3, 0.05, 0.9, 0.5], # 该节点的概率表\n", 36 | " [0.4, 0.25, 0.08, 0.3],\n", 37 | " [0.3, 0.7, 0.02, 0.2]],\n", 38 | " evidence=['I', 'D'], # 该节点的依赖节点\n", 39 | " evidence_card=[2, 2] # 依赖节点的取值个数\n", 40 | ")\n", 41 | "# 考试难度的条件概率分布\n", 42 | "difficulty_cpd = TabularCPD(\n", 43 | " variable='D',\n", 44 | " variable_card=2,\n", 45 | " values=[[0.6], [0.4]]\n", 46 | ")\n", 47 | "# 个人天赋的条件概率分布\n", 48 | "intel_cpd = TabularCPD(\n", 49 | " variable='I',\n", 50 | " variable_card=2,\n", 51 | " values=[[0.7], [0.3]]\n", 52 | ")\n", 53 | "# 推荐信质量的条件概率分布\n", 54 | "letter_cpd = TabularCPD(\n", 55 | " variable='L',\n", 56 | " variable_card=2,\n", 57 | " values=[[0.1, 0.4, 0.99],\n", 58 | " [0.9, 0.6, 0.01]],\n", 59 | " evidence=['G'],\n", 60 | " evidence_card=[3]\n", 61 | ")\n", 62 | "# SAT考试分数的条件概率分布\n", 63 | "sat_cpd = TabularCPD(\n", 64 | " variable='S',\n", 65 | " variable_card=2,\n", 66 | " values=[[0.95, 0.2],\n", 67 | " [0.05, 0.8]],\n", 68 | " evidence=['I'],\n", 69 | " evidence_card=[2]\n", 70 | ")" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 7, 76 | "metadata": {}, 77 | "outputs": [ 78 | { 79 | "name": "stderr", 80 | "output_type": "stream", 81 | "text": [ 82 | "WARNING:root:Replacing existing CPD for G\n", 83 | "WARNING:root:Replacing existing CPD for D\n", 84 | "WARNING:root:Replacing existing CPD for I\n", 85 | "WARNING:root:Replacing existing CPD for L\n", 86 | "WARNING:root:Replacing existing CPD for S\n", 87 | "Finding Elimination Order: : 100%|██████████████████████████████████████████████████████| 2/2 [00:00<00:00, 668.95it/s]\n", 88 | "Eliminating: L: 100%|███████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 285.72it/s]" 89 | ] 90 | }, 91 | { 92 | "name": "stdout", 93 | "output_type": 
"stream", 94 | "text": [ 95 | "+------+----------+\n", 96 | "| G | phi(G) |\n", 97 | "+======+==========+\n", 98 | "| G(0) | 0.9000 |\n", 99 | "+------+----------+\n", 100 | "| G(1) | 0.0800 |\n", 101 | "+------+----------+\n", 102 | "| G(2) | 0.0200 |\n", 103 | "+------+----------+\n" 104 | ] 105 | }, 106 | { 107 | "name": "stderr", 108 | "output_type": "stream", 109 | "text": [ 110 | "\n" 111 | ] 112 | } 113 | ], 114 | "source": [ 115 | "# 将各节点添加到模型中,构建贝叶斯网络\n", 116 | "letter_model.add_cpds(\n", 117 | " grade_cpd, \n", 118 | " difficulty_cpd,\n", 119 | " intel_cpd,\n", 120 | " letter_cpd,\n", 121 | " sat_cpd\n", 122 | ")\n", 123 | "# 导入pgmpy贝叶斯推断模块\n", 124 | "from pgmpy.inference import VariableElimination\n", 125 | "# 贝叶斯网络推断\n", 126 | "letter_infer = VariableElimination(letter_model)\n", 127 | "# 天赋较好且考试不难的情况下推断该学生获得成绩等级\n", 128 | "prob_G = letter_infer.query(\n", 129 | " variables=['G'],\n", 130 | " evidence={'I': 1, 'D': 0})\n", 131 | "print(prob_G)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [] 140 | } 141 | ], 142 | "metadata": { 143 | "kernelspec": { 144 | "display_name": "Python 3", 145 | "language": "python", 146 | "name": "python3" 147 | }, 148 | "language_info": { 149 | "codemirror_mode": { 150 | "name": "ipython", 151 | "version": 3 152 | }, 153 | "file_extension": ".py", 154 | "mimetype": "text/x-python", 155 | "name": "python", 156 | "nbconvert_exporter": "python", 157 | "pygments_lexer": "ipython3", 158 | "version": "3.7.3" 159 | }, 160 | "toc": { 161 | "base_numbering": 1, 162 | "nav_menu": {}, 163 | "number_sections": true, 164 | "sideBar": true, 165 | "skip_h1_title": false, 166 | "title_cell": "Table of Contents", 167 | "title_sidebar": "Contents", 168 | "toc_cell": false, 169 | "toc_position": {}, 170 | "toc_section_display": true, 171 | "toc_window_display": false 172 | } 173 | }, 174 | "nbformat": 4, 175 | "nbformat_minor": 4 176 | } 177 | 
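A note on the inference result recorded above: because the query `P(G | I=1, D=0)` observes both parents of `G`, variable elimination collapses to reading a single column of the grade CPD, so the pgmpy output can be cross-checked with plain NumPy. A minimal standalone sketch (not part of the notebook; the column layout follows the `evidence=['I', 'D']`, `evidence_card=[2, 2]` ordering used in `bayesian_network.ipynb`):

```python
import numpy as np

# Grade CPD from bayesian_network.ipynb: rows are G = 0, 1, 2; columns
# enumerate the parent combinations (I, D) as (0,0), (0,1), (1,0), (1,1).
grade_cpd = np.array([[0.3, 0.05, 0.9, 0.5],
                      [0.4, 0.25, 0.08, 0.3],
                      [0.3, 0.7, 0.02, 0.2]])

def grade_given(i, d):
    """P(G | I=i, D=d): with both parents observed, inference reduces
    to selecting the matching CPD column."""
    return grade_cpd[:, 2 * i + d]

# Matches the VariableElimination output: phi(G) = [0.9, 0.08, 0.02]
print(grade_given(i=1, d=0))
```

When some parents are unobserved this shortcut no longer applies and the full sum-product elimination that pgmpy performs is required.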
-------------------------------------------------------------------------------- /charpter21_Bayesian_models/naive_bayes.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### Naive Bayes" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "import pandas as pd" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "data": { 27 | "text/html": [ 28 | "
\n", 29 | "\n", 42 | "\n", 43 | " \n", 44 | " \n", 45 | " \n", 46 | " \n", 47 | " \n", 48 | " \n", 49 | " \n", 50 | " \n", 51 | " \n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | "
x1x2y
01S-1
11M-1
21M1
31S1
41S-1
\n", 84 | "
" 85 | ], 86 | "text/plain": [ 87 | " x1 x2 y\n", 88 | "0 1 S -1\n", 89 | "1 1 M -1\n", 90 | "2 1 M 1\n", 91 | "3 1 S 1\n", 92 | "4 1 S -1" 93 | ] 94 | }, 95 | "execution_count": 2, 96 | "metadata": {}, 97 | "output_type": "execute_result" 98 | } 99 | ], 100 | "source": [ 101 | "### 构造数据集\n", 102 | "### 来自于李航统计学习方法表4.1\n", 103 | "x1 = [1,1,1,1,1,2,2,2,2,2,3,3,3,3,3]\n", 104 | "x2 = ['S','M','M','S','S','S','M','M','L','L','L','M','M','L','L']\n", 105 | "y = [-1,-1,1,1,-1,-1,-1,1,1,1,1,1,1,1,-1]\n", 106 | "\n", 107 | "df = pd.DataFrame({'x1':x1, 'x2':x2, 'y':y})\n", 108 | "df.head()" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 3, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "X = df[['x1', 'x2']]\n", 118 | "y = df[['y']]" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 6, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "def nb_fit(X, y):\n", 128 | " classes = y[y.columns[0]].unique()\n", 129 | " class_count = y[y.columns[0]].value_counts()\n", 130 | " class_prior = class_count/len(y)\n", 131 | " \n", 132 | " prior = dict()\n", 133 | " for col in X.columns:\n", 134 | " for j in classes:\n", 135 | " p_x_y = X[(y==j).values][col].value_counts()\n", 136 | " for i in p_x_y.index:\n", 137 | " prior[(col, i, j)] = p_x_y[i]/class_count[j]\n", 138 | " return classes, class_prior, prior" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 8, 144 | "metadata": {}, 145 | "outputs": [ 146 | { 147 | "name": "stdout", 148 | "output_type": "stream", 149 | "text": [ 150 | "{('x1', 1, -1): 0.5, ('x1', 2, -1): 0.3333333333333333, ('x1', 3, -1): 0.16666666666666666, ('x1', 3, 1): 0.4444444444444444, ('x1', 2, 1): 0.3333333333333333, ('x1', 1, 1): 0.2222222222222222, ('x2', 'S', -1): 0.5, ('x2', 'M', -1): 0.3333333333333333, ('x2', 'L', -1): 0.16666666666666666, ('x2', 'L', 1): 0.4444444444444444, ('x2', 'M', 1): 0.4444444444444444, ('x2', 'S', 1): 
0.1111111111111111}\n" 151 | ] 152 | } 153 | ], 154 | "source": [ 155 | "classes, class_prior, prior = nb_fit(X, y)\n", 156 | "print(classes, class_prior, prior)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 6, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "X_test = {'x1': 2, 'x2': 'S'}" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 7, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "classes, class_prior, prior = nb_fit(X, y)\n", 175 | "\n", 176 | "def predict(X_test):\n", 177 | " res = []\n", 178 | " for c in classes:\n", 179 | " p_y = class_prior[c]\n", 180 | " p_x_y = 1\n", 181 | " for i in X_test.items():\n", 182 | " p_x_y *= prior[tuple(list(i)+[c])]\n", 183 | " res.append(p_y*p_x_y)\n", 184 | " return classes[np.argmax(res)]" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 10, 190 | "metadata": {}, 191 | "outputs": [ 192 | { 193 | "name": "stdout", 194 | "output_type": "stream", 195 | "text": [ 196 | "测试数据预测类别为: -1\n" 197 | ] 198 | } 199 | ], 200 | "source": [ 201 | "print('测试数据预测类别为:', predict(X_test))" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 15, 207 | "metadata": {}, 208 | "outputs": [ 209 | { 210 | "name": "stdout", 211 | "output_type": "stream", 212 | "text": [ 213 | "Accuracy of GaussianNB in iris data test: 0.9466666666666667\n" 214 | ] 215 | } 216 | ], 217 | "source": [ 218 | "from sklearn.datasets import load_iris\n", 219 | "from sklearn.model_selection import train_test_split\n", 220 | "from sklearn.naive_bayes import GaussianNB\n", 221 | "from sklearn.metrics import accuracy_score\n", 222 | "X, y = load_iris(return_X_y=True)\n", 223 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)\n", 224 | "gnb = GaussianNB()\n", 225 | "y_pred = gnb.fit(X_train, y_train).predict(X_test)\n", 226 | "print(\"Accuracy of GaussianNB in iris data test:\", \n", 
227 | " accuracy_score(y_test, y_pred))" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [] 236 | } 237 | ], 238 | "metadata": { 239 | "kernelspec": { 240 | "display_name": "Python 3", 241 | "language": "python", 242 | "name": "python3" 243 | }, 244 | "language_info": { 245 | "codemirror_mode": { 246 | "name": "ipython", 247 | "version": 3 248 | }, 249 | "file_extension": ".py", 250 | "mimetype": "text/x-python", 251 | "name": "python", 252 | "nbconvert_exporter": "python", 253 | "pygments_lexer": "ipython3", 254 | "version": "3.7.3" 255 | }, 256 | "toc": { 257 | "base_numbering": 1, 258 | "nav_menu": {}, 259 | "number_sections": true, 260 | "sideBar": true, 261 | "skip_h1_title": false, 262 | "title_cell": "Table of Contents", 263 | "title_sidebar": "Contents", 264 | "toc_cell": false, 265 | "toc_position": {}, 266 | "toc_section_display": true, 267 | "toc_window_display": false 268 | } 269 | }, 270 | "nbformat": 4, 271 | "nbformat_minor": 2 272 | } 273 | -------------------------------------------------------------------------------- /charpter22_EM/em.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### EM" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# 导入numpy库 \n", 17 | "import numpy as np\n", 18 | "\n", 19 | "### EM算法过程函数定义\n", 20 | "def em(data, thetas, max_iter=30, eps=1e-3):\n", 21 | " '''\n", 22 | " 输入:\n", 23 | " data:观测数据\n", 24 | " thetas:初始化的估计参数值\n", 25 | " max_iter:最大迭代次数\n", 26 | " eps:收敛阈值\n", 27 | " 输出:\n", 28 | " thetas:估计参数\n", 29 | " '''\n", 30 | " # 初始化似然函数值\n", 31 | " ll_old = -np.infty\n", 32 | " for i in range(max_iter):\n", 33 | " ### E步:求隐变量分布\n", 34 | " # 对数似然\n", 35 | " log_like = np.array([np.sum(data * np.log(theta), 
axis=1) for theta in thetas])\n", 36 | " # 似然\n", 37 | " like = np.exp(log_like)\n", 38 | " # 求隐变量分布\n", 39 | " ws = like/like.sum(0)\n", 40 | " # 概率加权\n", 41 | " vs = np.array([w[:, None] * data for w in ws])\n", 42 | " ### M步:更新参数值\n", 43 | " thetas = np.array([v.sum(0)/v.sum() for v in vs])\n", 44 | " # 更新似然函数\n", 45 | " ll_new = np.sum([w*l for w, l in zip(ws, log_like)])\n", 46 | " print(\"Iteration: %d\" % (i+1))\n", 47 | " print(\"theta_B = %.2f, theta_C = %.2f, ll = %.2f\" \n", 48 | " % (thetas[0,0], thetas[1,0], ll_new))\n", 49 | " # 满足迭代条件即退出迭代\n", 50 | " if np.abs(ll_new - ll_old) < eps:\n", 51 | " break\n", 52 | " ll_old = ll_new\n", 53 | " return thetas" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 2, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "Iteration: 1\n", 66 | "theta_B = 0.71, theta_C = 0.58, ll = -32.69\n", 67 | "Iteration: 2\n", 68 | "theta_B = 0.75, theta_C = 0.57, ll = -31.26\n", 69 | "Iteration: 3\n", 70 | "theta_B = 0.77, theta_C = 0.55, ll = -30.76\n", 71 | "Iteration: 4\n", 72 | "theta_B = 0.78, theta_C = 0.53, ll = -30.33\n", 73 | "Iteration: 5\n", 74 | "theta_B = 0.79, theta_C = 0.53, ll = -30.07\n", 75 | "Iteration: 6\n", 76 | "theta_B = 0.79, theta_C = 0.52, ll = -29.95\n", 77 | "Iteration: 7\n", 78 | "theta_B = 0.80, theta_C = 0.52, ll = -29.90\n", 79 | "Iteration: 8\n", 80 | "theta_B = 0.80, theta_C = 0.52, ll = -29.88\n", 81 | "Iteration: 9\n", 82 | "theta_B = 0.80, theta_C = 0.52, ll = -29.87\n", 83 | "Iteration: 10\n", 84 | "theta_B = 0.80, theta_C = 0.52, ll = -29.87\n", 85 | "Iteration: 11\n", 86 | "theta_B = 0.80, theta_C = 0.52, ll = -29.87\n", 87 | "Iteration: 12\n", 88 | "theta_B = 0.80, theta_C = 0.52, ll = -29.87\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "# 观测数据,5次独立试验,每次试验10次抛掷的正反次数\n", 94 | "# 比如第一次试验为5次正面5次反面\n", 95 | "observed_data = np.array([(5,5), (9,1), (8,2), (4,6), (7,3)])\n", 96 | "# 
初始化参数值,即硬币B的正面概率为0.6,硬币C的正面概率为0.5\n", 97 | "thetas = np.array([[0.6, 0.4], [0.5, 0.5]])\n", 98 | "thetas = em(observed_data, thetas, max_iter=30, eps=1e-3)" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 3, 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "data": { 108 | "text/plain": [ 109 | "array([[0.7967829 , 0.2032171 ],\n", 110 | " [0.51959543, 0.48040457]])" 111 | ] 112 | }, 113 | "execution_count": 3, 114 | "metadata": {}, 115 | "output_type": "execute_result" 116 | } 117 | ], 118 | "source": [ 119 | "thetas" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [] 128 | } 129 | ], 130 | "metadata": { 131 | "kernelspec": { 132 | "display_name": "Python 3", 133 | "language": "python", 134 | "name": "python3" 135 | }, 136 | "language_info": { 137 | "codemirror_mode": { 138 | "name": "ipython", 139 | "version": 3 140 | }, 141 | "file_extension": ".py", 142 | "mimetype": "text/x-python", 143 | "name": "python", 144 | "nbconvert_exporter": "python", 145 | "pygments_lexer": "ipython3", 146 | "version": "3.7.3" 147 | }, 148 | "toc": { 149 | "base_numbering": 1, 150 | "nav_menu": {}, 151 | "number_sections": true, 152 | "sideBar": true, 153 | "skip_h1_title": false, 154 | "title_cell": "Table of Contents", 155 | "title_sidebar": "Contents", 156 | "toc_cell": false, 157 | "toc_position": {}, 158 | "toc_section_display": true, 159 | "toc_window_display": false 160 | } 161 | }, 162 | "nbformat": 4, 163 | "nbformat_minor": 4 164 | } 165 | -------------------------------------------------------------------------------- /charpter23_HMM/hmm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### HMM" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 3, 13 | "metadata": {}, 14 | "outputs": [], 15 | 
"source": [ 16 | "import numpy as np\n", 17 | "\n", 18 | "### 定义HMM模型\n", 19 | "class HMM:\n", 20 | " def __init__(self, N, M, pi=None, A=None, B=None):\n", 21 | " # 可能的状态数\n", 22 | " self.N = N\n", 23 | " # 可能的观测数\n", 24 | " self.M = M\n", 25 | " # 初始状态概率向量\n", 26 | " self.pi = pi\n", 27 | " # 状态转移概率矩阵\n", 28 | " self.A = A\n", 29 | " # 观测概率矩阵\n", 30 | " self.B = B\n", 31 | "\n", 32 | " # 根据给定的概率分布随机返回数据\n", 33 | " def rdistribution(self, dist): \n", 34 | " r = np.random.rand()\n", 35 | " for ix, p in enumerate(dist):\n", 36 | " if r < p: \n", 37 | " return ix\n", 38 | " r -= p\n", 39 | "\n", 40 | " # 生成HMM观测序列\n", 41 | " def generate(self, T):\n", 42 | " # 根据初始概率分布生成第一个状态\n", 43 | " i = self.rdistribution(self.pi) \n", 44 | " # 生成第一个观测数据\n", 45 | " o = self.rdistribution(self.B[i]) \n", 46 | " observed_data = [o]\n", 47 | " # 遍历生成剩下的状态和观测数据\n", 48 | " for _ in range(T-1): \n", 49 | " i = self.rdistribution(self.A[i])\n", 50 | " o = self.rdistribution(self.B[i])\n", 51 | " observed_data.append(o)\n", 52 | " return observed_data" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 4, 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "name": "stdout", 62 | "output_type": "stream", 63 | "text": [ 64 | "[1, 0, 0, 1, 0]\n" 65 | ] 66 | } 67 | ], 68 | "source": [ 69 | "# 初始状态概率分布\n", 70 | "pi = np.array([0.25, 0.25, 0.25, 0.25])\n", 71 | "# 状态转移概率矩阵\n", 72 | "A = np.array([\n", 73 | " [0, 1, 0, 0],\n", 74 | " [0.4, 0, 0.6, 0],\n", 75 | " [0, 0.4, 0, 0.6],\n", 76 | "[0, 0, 0.5, 0.5]])\n", 77 | "# 观测概率矩阵\n", 78 | "B = np.array([\n", 79 | " [0.5, 0.5],\n", 80 | " [0.6, 0.4],\n", 81 | " [0.2, 0.8],\n", 82 | " [0.3, 0.7]])\n", 83 | "# 可能的状态数和观测数\n", 84 | "N = 4\n", 85 | "M = 2\n", 86 | "# 创建HMM实例\n", 87 | "hmm = HMM(N, M, pi, A, B)\n", 88 | "# 生成观测序列\n", 89 | "print(hmm.generate(5))" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 6, 95 | "metadata": {}, 96 | "outputs": [ 97 | { 98 | "name": "stdout", 99 | "output_type": "stream", 
100 | "text": [ 101 | "0.01983169125\n" 102 | ] 103 | } 104 | ], 105 | "source": [ 106 | "### 前向算法计算条件概率\n", 107 | "def prob_calc(O):\n", 108 | " '''\n", 109 | " 输入:\n", 110 | " O:观测序列\n", 111 | " 输出:\n", 112 | " alpha.sum():条件概率\n", 113 | " '''\n", 114 | " # 初值\n", 115 | " alpha = pi * B[:, O[0]]\n", 116 | " # 递推\n", 117 | " for o in O[1:]:\n", 118 | " alpha_next = np.empty(4)\n", 119 | " for j in range(4):\n", 120 | " alpha_next[j] = np.sum(A[:,j] * alpha * B[j,o])\n", 121 | " alpha = alpha_next\n", 122 | " return alpha.sum()\n", 123 | "\n", 124 | "# 给定观测\n", 125 | "O = [1,0,1,0,0]\n", 126 | "print(prob_calc(O))" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 7, 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "name": "stdout", 136 | "output_type": "stream", 137 | "text": [ 138 | "[0, 1, 2, 3, 3]\n" 139 | ] 140 | } 141 | ], 142 | "source": [ 143 | "### 序列标注问题和维特比算法\n", 144 | "def viterbi_decode(O):\n", 145 | " '''\n", 146 | " 输入:\n", 147 | " O:观测序列\n", 148 | " 输出:\n", 149 | " path:最优隐状态路径\n", 150 | " ''' \n", 151 | " # 序列长度和初始观测\n", 152 | " T, o = len(O), O[0]\n", 153 | " # 初始化delta变量\n", 154 | " delta = pi * B[:, o]\n", 155 | " # 初始化varphi变量\n", 156 | " varphi = np.zeros((T, 4), dtype=int)\n", 157 | " path = [0] * T\n", 158 | " # 递推\n", 159 | " for i in range(1, T):\n", 160 | " delta = delta.reshape(-1, 1) \n", 161 | " tmp = delta * A\n", 162 | " varphi[i, :] = np.argmax(tmp, axis=0)\n", 163 | " delta = np.max(tmp, axis=0) * B[:, O[i]]\n", 164 | " # 终止\n", 165 | " path[-1] = np.argmax(delta)\n", 166 | " # 回溯最优路径\n", 167 | " for i in range(T-1, 0, -1):\n", 168 | " path[i-1] = varphi[i, path[i]]\n", 169 | " return path\n", 170 | "\n", 171 | "# 给定观测序列\n", 172 | "O = [1,0,1,1,0]\n", 173 | "# 输出最可能的隐状态序列\n", 174 | "print(viterbi_decode(O))" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [] 183 | } 184 | ], 185 | "metadata": { 186 | 
"kernelspec": { 187 | "display_name": "Python 3", 188 | "language": "python", 189 | "name": "python3" 190 | }, 191 | "language_info": { 192 | "codemirror_mode": { 193 | "name": "ipython", 194 | "version": 3 195 | }, 196 | "file_extension": ".py", 197 | "mimetype": "text/x-python", 198 | "name": "python", 199 | "nbconvert_exporter": "python", 200 | "pygments_lexer": "ipython3", 201 | "version": "3.7.3" 202 | }, 203 | "toc": { 204 | "base_numbering": 1, 205 | "nav_menu": {}, 206 | "number_sections": true, 207 | "sideBar": true, 208 | "skip_h1_title": false, 209 | "title_cell": "Table of Contents", 210 | "title_sidebar": "Contents", 211 | "toc_cell": false, 212 | "toc_position": {}, 213 | "toc_section_display": true, 214 | "toc_window_display": false 215 | } 216 | }, 217 | "nbformat": 4, 218 | "nbformat_minor": 2 219 | } 220 | -------------------------------------------------------------------------------- /charpter24_CRF/crf.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### CRF" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "基于sklearn_crfsuite NER系统搭建,本例来自于sklearn_crfsuite官方tutorial" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 3, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "# 导入相关库\n", 24 | "import nltk\n", 25 | "import sklearn\n", 26 | "import scipy.stats\n", 27 | "from sklearn.metrics import make_scorer\n", 28 | "from sklearn.model_selection import cross_val_score\n", 29 | "from sklearn.model_selection import RandomizedSearchCV\n", 30 | "\n", 31 | "import sklearn_crfsuite\n", 32 | "from sklearn_crfsuite import scorers\n", 33 | "from sklearn_crfsuite import metrics" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 5, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stderr", 43 | "output_type": "stream", 
44 | "text": [ 45 | "[nltk_data] Downloading package conll2002 to\n", 46 | "[nltk_data] C:\\Users\\92070\\AppData\\Roaming\\nltk_data...\n", 47 | "[nltk_data] Unzipping corpora\\conll2002.zip.\n" 48 | ] 49 | }, 50 | { 51 | "data": { 52 | "text/plain": [ 53 | "True" 54 | ] 55 | }, 56 | "execution_count": 5, 57 | "metadata": {}, 58 | "output_type": "execute_result" 59 | } 60 | ], 61 | "source": [ 62 | "# 基于NLTK下载示例数据集\n", 63 | "nltk.download('conll2002')" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 6, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "# 设置训练和测试样本\n", 73 | "train_sents = list(nltk.corpus.conll2002.iob_sents('esp.train'))\n", 74 | "test_sents = list(nltk.corpus.conll2002.iob_sents('esp.testb'))" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 7, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "[('Melbourne', 'NP', 'B-LOC'),\n", 86 | " ('(', 'Fpa', 'O'),\n", 87 | " ('Australia', 'NP', 'B-LOC'),\n", 88 | " (')', 'Fpt', 'O'),\n", 89 | " (',', 'Fc', 'O'),\n", 90 | " ('25', 'Z', 'O'),\n", 91 | " ('may', 'NC', 'O'),\n", 92 | " ('(', 'Fpa', 'O'),\n", 93 | " ('EFE', 'NC', 'B-ORG'),\n", 94 | " (')', 'Fpt', 'O'),\n", 95 | " ('.', 'Fp', 'O')]" 96 | ] 97 | }, 98 | "execution_count": 7, 99 | "metadata": {}, 100 | "output_type": "execute_result" 101 | } 102 | ], 103 | "source": [ 104 | "train_sents[0]" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 8, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "# 单词转化为数值特征\n", 114 | "def word2features(sent, i):\n", 115 | " word = sent[i][0]\n", 116 | " postag = sent[i][1]\n", 117 | "\n", 118 | " features = {\n", 119 | " 'bias': 1.0,\n", 120 | " 'word.lower()': word.lower(),\n", 121 | " 'word[-3:]': word[-3:],\n", 122 | " 'word[-2:]': word[-2:],\n", 123 | " 'word.isupper()': word.isupper(),\n", 124 | " 'word.istitle()': word.istitle(),\n", 125 | " 'word.isdigit()': 
word.isdigit(),\n", 126 | " 'postag': postag,\n", 127 | " 'postag[:2]': postag[:2],\n", 128 | " }\n", 129 | " if i > 0:\n", 130 | " word1 = sent[i-1][0]\n", 131 | " postag1 = sent[i-1][1]\n", 132 | " features.update({\n", 133 | " '-1:word.lower()': word1.lower(),\n", 134 | " '-1:word.istitle()': word1.istitle(),\n", 135 | " '-1:word.isupper()': word1.isupper(),\n", 136 | " '-1:postag': postag1,\n", 137 | " '-1:postag[:2]': postag1[:2],\n", 138 | " })\n", 139 | " else:\n", 140 | " features['BOS'] = True\n", 141 | "\n", 142 | " if i < len(sent)-1:\n", 143 | " word1 = sent[i+1][0]\n", 144 | " postag1 = sent[i+1][1]\n", 145 | " features.update({\n", 146 | " '+1:word.lower()': word1.lower(),\n", 147 | " '+1:word.istitle()': word1.istitle(),\n", 148 | " '+1:word.isupper()': word1.isupper(),\n", 149 | " '+1:postag': postag1,\n", 150 | " '+1:postag[:2]': postag1[:2],\n", 151 | " })\n", 152 | " else:\n", 153 | " features['EOS'] = True\n", 154 | "\n", 155 | " return features\n", 156 | "\n", 157 | "\n", 158 | "def sent2features(sent):\n", 159 | " return [word2features(sent, i) for i in range(len(sent))]\n", 160 | "\n", 161 | "def sent2labels(sent):\n", 162 | " return [label for token, postag, label in sent]\n", 163 | "\n", 164 | "def sent2tokens(sent):\n", 165 | " return [token for token, postag, label in sent]" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 9, 171 | "metadata": {}, 172 | "outputs": [ 173 | { 174 | "data": { 175 | "text/plain": [ 176 | "{'bias': 1.0,\n", 177 | " 'word.lower()': 'melbourne',\n", 178 | " 'word[-3:]': 'rne',\n", 179 | " 'word[-2:]': 'ne',\n", 180 | " 'word.isupper()': False,\n", 181 | " 'word.istitle()': True,\n", 182 | " 'word.isdigit()': False,\n", 183 | " 'postag': 'NP',\n", 184 | " 'postag[:2]': 'NP',\n", 185 | " 'BOS': True,\n", 186 | " '+1:word.lower()': '(',\n", 187 | " '+1:word.istitle()': False,\n", 188 | " '+1:word.isupper()': False,\n", 189 | " '+1:postag': 'Fpa',\n", 190 | " '+1:postag[:2]': 'Fp'}" 191 | 
] 192 | }, 193 | "execution_count": 9, 194 | "metadata": {}, 195 | "output_type": "execute_result" 196 | } 197 | ], 198 | "source": [ 199 | "sent2features(train_sents[0])[0]" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 10, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "# 构造训练集和测试集\n", 209 | "X_train = [sent2features(s) for s in train_sents]\n", 210 | "y_train = [sent2labels(s) for s in train_sents]\n", 211 | "\n", 212 | "X_test = [sent2features(s) for s in test_sents]\n", 213 | "y_test = [sent2labels(s) for s in test_sents]" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 11, 219 | "metadata": {}, 220 | "outputs": [ 221 | { 222 | "name": "stdout", 223 | "output_type": "stream", 224 | "text": [ 225 | "8323 1517\n" 226 | ] 227 | } 228 | ], 229 | "source": [ 230 | "print(len(X_train), len(X_test))" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 18, 236 | "metadata": {}, 237 | "outputs": [ 238 | { 239 | "data": { 240 | "text/plain": [ 241 | "0.7964686316443963" 242 | ] 243 | }, 244 | "execution_count": 18, 245 | "metadata": {}, 246 | "output_type": "execute_result" 247 | } 248 | ], 249 | "source": [ 250 | "# 创建CRF模型实例\n", 251 | "crf = sklearn_crfsuite.CRF(\n", 252 | " algorithm='lbfgs',\n", 253 | " c1=0.1,\n", 254 | " c2=0.1,\n", 255 | " max_iterations=100,\n", 256 | " all_possible_transitions=True\n", 257 | ")\n", 258 | "# 模型训练\n", 259 | "crf.fit(X_train, y_train)\n", 260 | "# 类别标签\n", 261 | "labels = list(crf.classes_)\n", 262 | "labels.remove('O')\n", 263 | "# 模型预测\n", 264 | "y_pred = crf.predict(X_test)\n", 265 | "# 计算F1得分\n", 266 | "metrics.flat_f1_score(y_test, y_pred,\n", 267 | " average='weighted', labels=labels)" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 19, 273 | "metadata": {}, 274 | "outputs": [ 275 | { 276 | "name": "stdout", 277 | "output_type": "stream", 278 | "text": [ 279 | " precision recall f1-score 
support\n", 280 | "\n", 281 | " B-LOC 0.810 0.784 0.797 1084\n", 282 | " I-LOC 0.690 0.637 0.662 325\n", 283 | " B-MISC 0.731 0.569 0.640 339\n", 284 | " I-MISC 0.699 0.589 0.639 557\n", 285 | " B-ORG 0.807 0.832 0.820 1400\n", 286 | " I-ORG 0.852 0.786 0.818 1104\n", 287 | " B-PER 0.850 0.884 0.867 735\n", 288 | " I-PER 0.893 0.943 0.917 634\n", 289 | "\n", 290 | " micro avg 0.813 0.787 0.799 6178\n", 291 | " macro avg 0.791 0.753 0.770 6178\n", 292 | "weighted avg 0.809 0.787 0.796 6178\n", 293 | "\n" 294 | ] 295 | } 296 | ], 297 | "source": [ 298 | "# 打印B和I组的模型结果\n", 299 | "sorted_labels = sorted(\n", 300 | " labels,\n", 301 | " key=lambda name: (name[1:], name[0])\n", 302 | ")\n", 303 | "print(metrics.flat_classification_report(\n", 304 | " y_test, y_pred, labels=sorted_labels, digits=3\n", 305 | "))" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": null, 311 | "metadata": {}, 312 | "outputs": [], 313 | "source": [] 314 | } 315 | ], 316 | "metadata": { 317 | "kernelspec": { 318 | "display_name": "Python 3", 319 | "language": "python", 320 | "name": "python3" 321 | }, 322 | "language_info": { 323 | "codemirror_mode": { 324 | "name": "ipython", 325 | "version": 3 326 | }, 327 | "file_extension": ".py", 328 | "mimetype": "text/x-python", 329 | "name": "python", 330 | "nbconvert_exporter": "python", 331 | "pygments_lexer": "ipython3", 332 | "version": "3.7.3" 333 | }, 334 | "toc": { 335 | "base_numbering": 1, 336 | "nav_menu": {}, 337 | "number_sections": true, 338 | "sideBar": true, 339 | "skip_h1_title": false, 340 | "title_cell": "Table of Contents", 341 | "title_sidebar": "Contents", 342 | "toc_cell": false, 343 | "toc_position": {}, 344 | "toc_section_display": true, 345 | "toc_window_display": false 346 | } 347 | }, 348 | "nbformat": 4, 349 | "nbformat_minor": 2 350 | } 351 | -------------------------------------------------------------------------------- /charpter5_LDA/LDA.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### LDA" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 14, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "### numpy版本为1.20.3\n", 17 | "import numpy as np\n", 18 | "\n", 19 | "class LDA():\n", 20 | " def __init__(self):\n", 21 | " # 初始化权重矩阵\n", 22 | " self.w = None\n", 23 | " \n", 24 | " # 计算协方差矩阵\n", 25 | " def calc_cov(self, X, Y=None):\n", 26 | " m = X.shape[0]\n", 27 | " # 数据标准化\n", 28 | " X = (X - np.mean(X, axis=0))/np.std(X, axis=0)\n", 29 | " Y = X if Y is None else (Y - np.mean(Y, axis=0))/np.std(Y, axis=0)\n", 30 | " return 1 / m * np.matmul(X.T, Y)\n", 31 | " \n", 32 | " # 对数据进行投影\n", 33 | " def project(self, X, y):\n", 34 | " self.fit(X, y)\n", 35 | " X_projection = X.dot(self.w)\n", 36 | " return X_projection\n", 37 | " \n", 38 | " # LDA拟合过程\n", 39 | " def fit(self, X, y):\n", 40 | " # 按类分组\n", 41 | " X0 = X[y == 0]\n", 42 | " X1 = X[y == 1]\n", 43 | "\n", 44 | " # 分别计算两类数据自变量的协方差矩阵\n", 45 | " sigma0 = self.calc_cov(X0)\n", 46 | " sigma1 = self.calc_cov(X1)\n", 47 | " # 计算类内散度矩阵\n", 48 | " Sw = sigma0 + sigma1\n", 49 | "\n", 50 | " # 分别计算两类数据自变量的均值和差\n", 51 | " u0, u1 = np.mean(X0, axis=0), np.mean(X1, axis=0)\n", 52 | " mean_diff = np.atleast_1d(u0 - u1)\n", 53 | "\n", 54 | " # 对类内散度矩阵进行奇异值分解\n", 55 | " U, S, V = np.linalg.svd(Sw)\n", 56 | " # 计算类内散度矩阵的逆\n", 57 | " Sw_ = np.dot(np.dot(V.T, np.linalg.pinv(np.diag(S))), U.T)\n", 58 | " # 计算w\n", 59 | " self.w = Sw_.dot(mean_diff)\n", 60 | "\n", 61 | " \n", 62 | " # LDA分类预测\n", 63 | " def predict(self, X):\n", 64 | " y_pred = []\n", 65 | " for sample in X:\n", 66 | " h = sample.dot(self.w)\n", 67 | " y = 1 * (h < 0)\n", 68 | " y_pred.append(y)\n", 69 | " return y_pred" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 2, 75 | "metadata": {}, 76 | "outputs": [], 77 | 
"source": [ 78 | "from sklearn import datasets\n", 79 | "import matplotlib.pyplot as plt\n", 80 | "from sklearn.model_selection import train_test_split\n", 81 | "\n", 82 | "data = datasets.load_iris()\n", 83 | "X = data.data\n", 84 | "y = data.target" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 3, 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "(80, 4) (20, 4) (80,) (20,)\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "X = X[y != 2]\n", 102 | "y = y[y != 2]\n", 103 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=41)\n", 104 | "print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 4, 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "name": "stdout", 114 | "output_type": "stream", 115 | "text": [ 116 | "0.85\n" 117 | ] 118 | } 119 | ], 120 | "source": [ 121 | "lda = LDA()\n", 122 | "lda.fit(X_train, y_train)\n", 123 | "y_pred = lda.predict(X_test)\n", 124 | "\n", 125 | "from sklearn.metrics import accuracy_score\n", 126 | "accuracy = accuracy_score(y_test, y_pred)\n", 127 | "print(accuracy)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 5, 133 | "metadata": {}, 134 | "outputs": [ 135 | { 136 | "data": { 137 | "text/plain": [ 138 | "array([0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0])" 139 | ] 140 | }, 141 | "execution_count": 5, 142 | "metadata": {}, 143 | "output_type": "execute_result" 144 | } 145 | ], 146 | "source": [ 147 | "y_test" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 6, 153 | "metadata": {}, 154 | "outputs": [ 155 | { 156 | "data": { 157 | "text/plain": [ 158 | "[0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1]" 159 | ] 160 | }, 161 | "execution_count": 6, 162 | "metadata": {}, 163 | "output_type": "execute_result" 164 | } 165 | ], 
166 | "source": [ 167 | "y_pred" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 10, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "def calculate_covariance_matrix(X, Y=None):\n", 177 | " if Y is None:\n", 178 | " Y = X\n", 179 | " n_samples = np.shape(X)[0]\n", 180 | " covariance_matrix = (1 / (n_samples-1)) * (X - X.mean(axis=0)).T.dot(Y - Y.mean(axis=0))\n", 181 | "\n", 182 | " return np.array(covariance_matrix, dtype=float)" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 11, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "import matplotlib.pyplot as plt\n", 192 | "import matplotlib.cm as cmx\n", 193 | "import matplotlib.colors as colors\n", 194 | "\n", 195 | "class Plot():\n", 196 | " def __init__(self): \n", 197 | " self.cmap = plt.get_cmap('viridis')\n", 198 | "\n", 199 | " def _transform(self, X, dim):\n", 200 | " covariance = calculate_covariance_matrix(X)\n", 201 | " eigenvalues, eigenvectors = np.linalg.eig(covariance)\n", 202 | " # Sort eigenvalues and eigenvector by largest eigenvalues\n", 203 | " idx = eigenvalues.argsort()[::-1]\n", 204 | " eigenvalues = eigenvalues[idx][:dim]\n", 205 | " eigenvectors = np.atleast_1d(eigenvectors[:, idx])[:, :dim]\n", 206 | " # Project the data onto principal components\n", 207 | " X_transformed = X.dot(eigenvectors)\n", 208 | "\n", 209 | " return X_transformed\n", 210 | "\n", 211 | "\n", 212 | " def plot_regression(self, lines, title, axis_labels=None, mse=None, scatter=None, legend={\"type\": \"lines\", \"loc\": \"lower right\"}):\n", 213 | " \n", 214 | " if scatter:\n", 215 | " scatter_plots, scatter_labels = [], []\n", 216 | " for s in scatter:\n", 217 | " scatter_plots += [plt.scatter(s[\"x\"], s[\"y\"], color=s[\"color\"], s=s[\"size\"])]\n", 218 | " scatter_labels += [s[\"label\"]]\n", 219 | " scatter_plots = tuple(scatter_plots)\n", 220 | " scatter_labels = tuple(scatter_labels)\n", 221 | "\n", 222 | " for l in 
lines:\n", 224 | " li = plt.plot(l[\"x\"], l[\"y\"], color=l[\"color\"], linewidth=l[\"width\"], label=l[\"label\"])\n", 225 | "\n", 226 | " if mse:\n", 227 | " plt.suptitle(title)\n", 228 | " plt.title(\"MSE: %.2f\" % mse, fontsize=10)\n", 229 | " else:\n", 230 | " plt.title(title)\n", 231 | "\n", 232 | " if axis_labels:\n", 233 | " plt.xlabel(axis_labels[\"x\"])\n", 234 | " plt.ylabel(axis_labels[\"y\"])\n", 235 | "\n", 236 | " if legend[\"type\"] == \"lines\":\n", 237 | " plt.legend(loc=\"lower left\")\n", 238 | " elif legend[\"type\"] == \"scatter\" and scatter:\n", 239 | " plt.legend(scatter_plots, scatter_labels, loc=legend[\"loc\"])\n", 240 | "\n", 241 | " plt.show()\n", 242 | "\n", 243 | "\n", 244 | "\n", 245 | " # Plot the dataset X and the corresponding labels y in 2D using PCA.\n", 246 | " def plot_in_2d(self, X, y=None, title=None, accuracy=None, legend_labels=None):\n", 247 | " X_transformed = self._transform(X, dim=2)\n", 248 | " x1 = X_transformed[:, 0]\n", 249 | " x2 = X_transformed[:, 1]\n", 250 | " class_distr = []\n", 251 | "\n", 252 | " y = np.array(y).astype(int)\n", 253 | "\n", 254 | " colors = [self.cmap(i) for i in np.linspace(0, 1, len(np.unique(y)))]\n", 255 | "\n", 256 | " # Plot the different class distributions\n", 257 | " for i, l in enumerate(np.unique(y)):\n", 258 | " _x1 = x1[y == l]\n", 259 | " _x2 = x2[y == l]\n", 260 | " _y = y[y == l]\n", 261 | " class_distr.append(plt.scatter(_x1, _x2, color=colors[i]))\n", 262 | "\n", 263 | " # Plot legend\n", 264 | " if legend_labels is not None: \n", 265 | " plt.legend(class_distr, legend_labels, loc=1)\n", 266 | "\n", 267 | " # Plot title\n", 268 | " if title:\n", 269 | " if accuracy:\n", 270 | " perc = 100 * accuracy\n", 271 | " plt.suptitle(title)\n", 272 | " plt.title(\"Accuracy: %.1f%%\" % perc, fontsize=10)\n", 273 | " else:\n", 274 | " plt.title(title)\n", 275 | "\n", 276 | " # Axis labels\n", 277 | " plt.xlabel('class 1')\n", 278 | " plt.ylabel('class 2')\n", 279 | "\n", 280 | 
plt.show()" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": 12, 285 | "metadata": {}, 286 | "outputs": [ 287 | { 288 | "data": { 289 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEjCAYAAAAomJYLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAe7UlEQVR4nO3dfZRU9Z3n8fenQWAaRBNsFRVEEw+JT1GmQ+I4PoWEaDZizBpH4eRB3OmFRGdXT2ZiCHEziXLOxqxO4hhZ4mBi1h43oyHpzSjicM5EEqOhiU/4lGGMkRYYGnwg2CrB+u4f9zYWxe2mCvvWQ9fndQ6n6v7uvVXfqqP16Xt/v3t/igjMzMxKtdS6ADMzq08OCDMzy+SAMDOzTA4IMzPL5IAwM7NMDggzM8vkgDAzs0wOCLMySHpO0odL2s6UVJC0Pf3XI+lHkt6fsf9R6bbfrV7VZm+PA8Ls7dkQEeOA/YEPAk8DqyTNKNnuM8BLwEWSRle5RrN94oAwGwKR6ImIq4FbgP9ZsslngIXAH4Fzq12f2b5wQJgNvR8D0ySNBZB0GnAEcAfwI5KwMKt7DgizobcBEHBguvxZ4J6IeAnoBM6RdHCNajMrmwPCbOgdDgTwsqQ/AT4F3A4QEb8Cngdm1648s/I4IMyG3vnAbyLi1fT5eOC7kjZJ2kQSID7NZHVvZK0LMGsg+0kaU7S86/8fSQIOA/5L+m9WuuqzwFLgK0X7HQ6slnRCRDyeb8lm+06eD8Js7yQ9BxxZ0vxL4M+APpI+h1eAB4BvRcSDkg4Hfg+cXBoEku4GnoyIL+Zdu9m+ckCYmVkm90GYmVkmB4SZmWVyQJiZWSYHhJmZZXJAWMOSdL6kkPSeWtfydkj6lKQn0ru9tpes+7KkdZKekfTRovY/lfR4uu476TDbrNfeY39JoyUtl7RW0ueLtl0i6eS8Pqc1HgeENbKLgV8AF+X5JpJG5Pn6wFrgk8D9Je97LMlnOw44m+Riu/5abgY6gGPSf2eXvugg+38UWAOcmL4Gkt4HtETEw0P94axxOSCsIUkaB5wKXEpRQEgaIelb6V/Xj0m6PG1/v6QHJD0q6deS9pf0OUl/X7TvzySdmT7fLunrkh4CTpF0taTV6V/dS/r/Ypf0bkn/kr7ubyS9S9IPJZ1X9Lq3S+q/cG4PEfFURDyTseo84I6IeCMifgesA6ZLmgiMj4hfRTJO/TbgE+XuT3JH2T9h9wtlvwFcPVCN1pwcENaoPgEsj4jfAi9Kmpa2dwBHkVycdiJwu6RRwP8F/ltEvA/4MPDaXl5/LLA2Ij4QEb8A/j4i3h8Rx5P8uH483e524Kb0df8M2Ehyu+9LACQdkLbfLeluSYdV8BkPB9YXLfekbYenz0vby93/PuBQ4CHgm2l4rYmIDRXUZk3At9qwRnUx8Hfp8zvS5d+Q/PgvjoidABHxoqQTgI0RsTpt2wYwwGn7fm8CdxUtnyXpb4BW4J3AE5L+FTg8Ipalr/t6uu3PJd2U3rH1k8BdaT0fq/AzZhUYg7SXtX9ay2wASfsB9wKzJF0PTAZui4iuCmu1YcgBYQ1H0gTgQ8DxkgIYAUT6Ay72/LHMagPYye5H0cX3WXo9It5M328M8F2gPSLWS/pauu1gCfNDYA7J6a+5ZX60Uj3ApKLlI0huJd6TPi9tL3f/Yp8HfgCcAuwA/gL4FeCAMJ9isoZ0AclfuUdGxJSImAT8DvhzYAUwT9JIAEnvJJkG9LD+uaLT/oeRwHPASZJaJE0iOT+fpT84tqR9HxfAriORHkmfSF93tKTWdNvvA/893e6JffycXaRTlEo6iq
Qz+tcRsRH4g6QPpn0hnwF+Wu7+/SslvYPkVNltJEdGBZIgHZPxWtaEHBDWiC4GlpW03UVy2uQWkvkWHpP0KDA7Ivr/Mr4xbbuP5EfwlyTB8jjwLZJTVHuIiJeB76Xb/QRYXbT608BfSXqM5EZ9h6b7/AfwFHBr/4YD9UGkw3V7SP6K/2dJ96av8QTJDHRPAsuBL/Qf1QDz08+6Dvh34J70tWZJ+noZ+0PSKX1N2tF9L9CefsbvZX0P1nx8sz6zHKRHEo8D0yLilVrXY7YvfARhNsQkfZjktNaNDgdrZD6CMDOzTD6CMDOzTA4IMzPLNKyugzjooINiypQptS7DzKxhrFmzZktEtGWtG1YBMWXKFLq7u2tdhplZw5D0+4HW+RSTmZllckCYmVkmB4SZmWVyQJiZWSYHhJmZZXJAmJnVuZWdq5gzZT4zR1zInCnzWdm5qirvO6yGuZqZDTcrO1dxQ8di3ujbAcDm57dwQ8diAGbMPi3X9/YRhJlZHVu6oHNXOPR7o28HSxd05v7eDggzszrWu35rRe1DyQFhZlbH2iZNqKh9KDkgzMzq2NxFsxndOmq3ttGto5i7aHbu7+1OajOzOtbfEb10QSe967fSNmkCcxfNzr2DGobZhEHt7e3hm/WZmZVP0pqIaM9a51NMZmaWyQFhZmaZHBBmZpbJAWFmZpkcEGZmlinXYa6SDgRuAY4HApgbEb8qWv/XwJyiWt4LtEXEi5KeA/4AvAnsHKiX3czM8pH3dRDfBpZHxAWSRgGtxSsj4jrgOgBJ5wJXRMSLRZucFRFbcq7RzMwy5BYQksYDpwOfA4iIHcCOQXa5GPjHvOoxM7PK5NkHcTTQC9wq6WFJt0gam7WhpFbgbOCuouYAVkhaI6kjxzrNzCxDngExEpgG3BwRJwOvAlcNsO25wC9LTi+dGhHTgHOAL0g6PWtHSR2SuiV19/b2DmH5ZmbNLc+A6AF6IuKhdPlOksDIchElp5ciYkP6uBlYBkzP2jEilkREe0S0t7W1DUnhZmaWY0BExCZgvaSpadMM4MnS7SQdAJwB/LSobayk/fufAzOBtXnVamZme8p7FNPlwO3pCKZngUskzQOIiMXpNucDKyLi1aL9DgGWSeqvsTMiludcq5mZFfHdXM3Mmpjv5lplhb4uCpvPpLBpavLY11XrkszMKuYJg4ZYoa8Lti0EXk8bNsC2hRSAltZZtSzNzKwiPoIYatuvZ1c47PJ62m5m1jgcEEOtsLGydjOzOuWAGGotEytrNzOrUw6IoTbuSmBMSeOYtN3MrHG4k3qItbTOogBJn0NhY3LkMO5Kd1CbWcNxQOSgpXUWOBDMrMH5FJOZWZ1b2bmKOVPmM3PEhcyZMp+Vnauq8r4+gjAzq0MrO1exdEEnm5/fAiKZAAHY/PwWbuhI7lQ0Y/ZpudbgIwgzszqzsnMVN3QsTsIBdoVDvzf6drB0QWfudTggzMzqzNIFnbzRN9gEnNC7fmvudTggzMzqTDk//m2TJuRehwPCzKzO7O3Hf3TrKOYump17HQ4IM7M6M3fRbEa3jtq9UcnDwZMP4ool83LvoAaPYjIzqzv9P/5LF3TSu34rbZMmMHfR7KqEQjFPGGRm1sQ8YZCZmVUs14CQdKCkOyU9LekpSaeUrD9T0iuSHkn/XV207mxJz0haJ+mqPOs0M7M95d0H8W1geURcIGkU0JqxzaqI+Hhxg6QRwE3AR4AeYLWkroh4Mud6zcwsldsRhKTxwOnAPwBExI6IeLnM3acD6yLi2YjYAdwBnJdLoWZmlinPU0xHA73ArZIelnSLpLEZ250i6VFJ90g6Lm07HFhftE1P2rYHSR2SuiV19/b2DukHMDNrZnkGxEhgGnBzRJwMvAqU9iX8BjgyIt4H3Aj8JG1XxutlDreKiCUR0R4R7W1tbUNSuJmZ5RsQPUBPRDyULt9JEhi7RMS2iNiePr8b2E/SQem+k4o2PQLYkGOtZmZWIreAiIhNwHpJU9
OmGcBuncySDpWk9Pn0tJ6twGrgGElHpZ3bFwFdedVqZmZ7ynsU0+XA7emP/LPAJZLmAUTEYuACYL6kncBrwEWRXLm3U9JlwL3ACGBpRDyRc61mZlbEV1KbmTUxX0ltZmYVc0CYmVkmB4SZmWVyQNSxQl8Xhc1nUtg0NXns80AuM6sezwdRpwp9XbBtIfB62rABti2kALS0zqplaWbWJHwEUa+2X8+ucNjl9bTdzCx/Doh6VdhYWbuZ2RBzQNSrlomVtZuZDTEHRL0adyUwpqRxTNpuZpY/d1LXqZbWWRQg6XMobEyOHMZd6Q5qM6uapg+IQl9X3f4It7TOgjqpxcyaT1MHhIeSmpkNrLn7IDyU1MxsQM0dEB5KamY2oOYOCA8lNTMbUHMHhIeSmpkNqKk7qT2U1MxsYE0dEOChpGZmA8n1FJOkAyXdKelpSU9JOqVk/RxJj6X/HpD0vqJ1z0l6XNIjkjyPqJlZleV9BPFtYHlEXCBpFNBasv53wBkR8ZKkc4AlwAeK1p8VEVtyrtHMzDLkFhCSxgOnA58DiIgdwI7ibSLigaLFB4Ej8qrHzMwqk+cppqOBXuBWSQ9LukXS2EG2vxS4p2g5gBWS1kjqGGgnSR2SuiV19/b2Dk3lZmaWa0CMBKYBN0fEycCrwFVZG0o6iyQgvlTUfGpETAPOAb4g6fSsfSNiSUS0R0R7W1vbkH4AM7NmlmdA9AA9EfFQunwnSWDsRtKJwC3AeRGxtb89Ijakj5uBZcD0HGs1M7MSuQVERGwC1kuamjbNAJ4s3kbSZODHwKcj4rdF7WMl7d//HJgJrM2rVjMz21Peo5guB25PRzA9C1wiaR5ARCwGrgYmAN+VBLAzItqBQ4BladtIoDMiludcq5mZFVFE1LqGIdPe3h7d3b5kwsysXJLWpH+Y76G578VkZlYlKztXMWfKfGaOuJA5U+azsnNVrUvaq6a/1YaZWd5Wdq7iho7FvNGXXAq2+fkt3NCxGIAZs0+rZWmD8hGEmVnOli7o3BUO/d7o28HSBZ01qqg8Dggzs5z1rt9aUXu9cECYmeWsbdKEitrrhQPCzCxncxfNZnTrqN3aRreOYu6i2TWqqDzupDYzy1l/R/TSBZ30rt9K26QJzF00u647qMHXQZiZ1YWVnatqEiCDXQfhIwgzsxqr12Gw7oMwM6uxeh0G64AwM6uxeh0GO2hASPqopEslTSlpn5trVWZmTaReh8EOGBCSFgFfAU4AVkq6vGj1ZXkXZmbWLOp1GOxgndTnAidHxE5JXwM6JR0dEVcAqkp1ZmZNoF6HwQ44zFXSUxHx3qLlEcASYDxwbEQcV50Sy+dhrmZmldnX233/u6Qz+hci4s2IuBR4BnjvwLs1t0JfF4XNZ1LYNDV57OuqdUlmNgzU4nbhg51i+lRWY0QslHRzTvU0tEJfF2xbCLyeNmyAbQspAC2ts2pZmpk1sFpdJzHgEUREvBYRrw2w7oVyXlzSgZLulPS0pKcknVKyXpK+I2mdpMckTStad7akZ9J1V5X7gWpq+/XsCoddXk/bzcz2Ta2uk8j7OohvA8sj4j3A+4CnStafAxyT/usAboZd/R03peuPBS6WdGzOtb59hY2VtZuZlaFW10nkFhCSxgOnA/8AEBE7IuLlks3OA26LxIPAgZImAtOBdRHxbETsAO5It61vLRMrazczK0OtrpPYa0BIepek0enzMyX9laQDy3jto4Fe4FZJD0u6RdLYkm0OB9YXLfekbQO117dxVwJjShrHpO1mZvumVtdJlHMEcRfwpqR3kxwNHAWUc+JrJDANuDkiTgZeBUr7ErKup4hB2vcgqUNSt6Tu3t7eMsrKT0vrLBh/DbQcBih5HH+NO6jN7G2ZMfs0rlgyj4MnH4QkDp58EFcsmZf7dRLl3M21kF4sdz7wdxFxo6SHy9ivB+iJiIfS5TvZMyB6gElFy0cAG4BRA7TvISKWkFyfQXt7e83vXd7SOgsaMBAKfV1JZ3
phY3JKbNyVDjazOjJj9mlVv3CunCOIP0q6GPgs8LO0bb+97RQRm4D1kqamTTOAJ0s26wI+k45m+iDwSkRsBFYDx0g6StIo4KJ0W8vBruG5hQ1AvDU819dwmDW1cgLiEuAU4NqI+J2ko4D/U+brXw7cLukx4CRgkaR5kual6+8GngXWAd8DPg8QETtJ7vd0L8nIpx9FxBNlvqdVqkGH5/qiRLN8VTSjnKR3AJMi4rH8Stp3vtXGvilsmkp2F49oOfSZapdTlj0uSgRgjPt8zCq0r7fa6N/5XyWNl/RO4FGSUUn1/aelVaYRh+c26FGPWSMp5xTTARGxDfgkcGtE/Cnw4XzLsqpqxOG5vijRLHflBMTI9OK1C3mrk9qGkYYcntuIRz1mDaacYa5fJ+ks/kVErJZ0NPBv+ZZl1dZww3PHXZndB1HPRz1mDWavARER/wT8U9Hys8B/zrMos71paZ1FAXzthlmO9hoQksYAlwLHUXSiOiI8L7XVVMMd9Zg1mHL6IH4IHAp8FPg5yVXNf8izKDMzq71yAuLdEfFV4NWI+AHwn4AT8i3LzMxqraxbbaSPL0s6HjgAmJJbRWZmVhfKCYgl6RXUXyW5H9KTwDdzrcrMrA7UYh7oelLOKKZb0qc/J5njwcxs2KvVPND1ZMCAkDTogPKI8D0NzGzYGmwe6KYPCGD/qlVhZlZnajUPdD0ZMCAi4m+rWYiZWT1pmzSBzc9vyWxvFuXczfUHxXNQS3qHpKW5VmVmVmO1mge6npRzL6YTI+Ll/oWIeEnSyfmVZGZWe/39DEsXdNK7fittkyYwd9Hspul/gPICokXSOyLiJYB0Xohy9jMza2i1mAe6npTzQ/+/gAck3Uky7diFwLW5VmVmZjVXznUQt0nqBj4ECPhkRDxZzotLeo7kvk1vAjtLp7WT9NfAnKJa3gu0RcSLe9vXzMzyVdapojQQygqFDGdFxJ5DAZLXvQ64DkDSucAVEfFiOfuamVm+yrnVRrVcDPxjrYswM7NE3gERwApJayR1DLSRpFbgbOCufdi3Q1K3pO7e3t4hK9zMrNnlPRrp1IjYIOlg4D5JT0fE/RnbnQv8suT0Uln7RsQSYAlAe3t75PEhzMyaUa5HEBGxIX3cDCwDpg+w6UWUnF6qYF8zM8tBbgEhaayk/fufAzOBtRnbHQCcAfy00n3NzCw/eZ5iOgRYJqn/fTojYrmkeQARsTjd7nxgRUS8urd9c6zVzMxKKGL4nLZvb2+P7u7uWpdhZsPcys5VVbkFRzXeR9Kaga4z8y0zzMwqUK2JhOphwqJ6ug7CzKzuDTaRUCO+z2AcEGZmFajWREL1MGGRA8LMrAIDTRg01BMJVet9BuOAMDOrQLUmEqqHCYvcSW1mVoFqTSRUDxMWeZirmVkTG2yYq08xmZnVuZWdq5gzZT4zR1zInCnzWdm5qirv61NMZmZ1rJbXQ/gIwsysjtXyeggHhJlZHavl9RAOCDOzOlbL6yEcEGZmdayW10O4k9rMrI7V8noIXwdhZtbEfB2EmZlVzAFhZmaZcg0ISc9JelzSI5L2OPcj6UxJr6TrH5F0ddG6syU9I2mdpKvyrNPMzPZUjU7qsyJiyyDrV0XEx4sbJI0AbgI+AvQAqyV1RcSTOdZpZmZF6vUU03RgXUQ8GxE7gDuA82pck5lZU8k7IAJYIWmNpI4BtjlF0qOS7pF0XNp2OLC+aJuetM3MzKok71NMp0bEBkkHA/dJejoi7i9a/xvgyIjYLuljwE+AYwBlvFbmeNw0eDoAJk+ePKTFm5k1s1yPICJiQ/q4GVhGcuqoeP22iNiePr8b2E/SQSRHDJOKNj0C2DDAeyyJiPaIaG9ra8vhU5iZNafcAkLSWEn79z8HZgJrS7Y5VJLS59PTerYCq4FjJB0laRRwEdCVV61mZranPE8xHQIsS3//RwKdEbFc0jyAiFgMXADMl7QTeA24KJJLu3dKugy4FxgBLI2IJ3Ks1czMSvhWG2
ZmTcy32jAzs4o5IMzMLJMDwszMMjkgzMwa1MrOVcyZMp+ZIy5kzpT5rOxcNaSv7wmDzMwa0MrOVdzQsZg3+nYAsPn5LdzQsRhgyCYT8hGEmVkDWrqgc1c49HujbwdLF3QO2Xs4IMzMGlDv+q0Vte8LB4SZWQNqmzShovZ94YAwM2tAcxfNZnTrqN3aRreOYu6i2UP2Hu6kNjNrQP0d0UsXdNK7fittkyYwd9HsIeugBt9qw8ysqflWG2ZmVjEHhJmZZXJAmJlZJgeEmZllckCYmVkmB4SZmWVyQJiZWaZcL5ST9BzwB+BNYGfpWFtJc4AvpYvbgfkR8Wg5+5qZWb6qcSX1WRGxZYB1vwPOiIiXJJ0DLAE+UOa+ZmaWo5qeYoqIByLipXTxQeCIWtZjVqzQ10Vh85kUNk1NHvu6al2SWVXlfQQRwApJAfzviFgyyLaXAvdUuq+kDqADYPLkyUNTtTWlQl8XbL8eChuBA4BXgT+mKzfAtoUUgJbWWbUr0qyK8g6IUyNig6SDgfskPR0R95duJOkskoD480r3TYNjCST3YsrnY9hwV+jrgm0LgdfTlpcztno9CRAHhDWJXE8xRcSG9HEzsAyYXrqNpBOBW4DzImJrJfuaDZnt1/NWOAyisDH3UszqRW4BIWmspP37nwMzgbUl20wGfgx8OiJ+W8m+ZkOq3B/+lon51mFWR/I8xXQIsExS//t0RsRySfMAImIxcDUwAfhuul3/cNbMfXOs1Zpdy8Skn2FQY2DclVUpx6weeD4IM7L6IABGgsZBvJIEyLgr3UFtw85g80F4RjkzkpFJBXhrFJMDwcwBYdavpXWWRyiZFfG9mMzMLJMDwszMMjkgzMwskwPCzMwyOSDMzCyTA8LMzDI5IMzMLJMDwszMMjkgzMwskwPCzMwyOSDMzCyTA8LMzDI5IMzMLJMDwszMMjkgrKEV+roobD6TwqapyWNfV61LMhs2cg0ISc9JelzSI5L2mOpNie9IWifpMUnTitadLemZdN1VedZpjWnXLHCFDUAkj9sWOiTMhkg1jiDOioiTBpjS7hzgmPRfB3AzgKQRwE3p+mOBiyUdW4VarZFsv57dpwglWd5+fS2qMRt2an2K6Tzgtkg8CBwoaSIwHVgXEc9GxA7gjnRbs7cUNlbWbmYVyTsgAlghaY2kjoz1hwPri5Z70raB2s3e0jKxsnYzq0jeAXFqREwjOVX0BUmnl6xXxj4xSPseJHVI6pbU3dvb+/aqtcYy7kpgTEnjmLTdzN6uXAMiIjakj5uBZSSnjor1AJOKlo8ANgzSnvUeSyKiPSLa29rahqp0awAtrbNg/DXQchig5HH8NUm7mb1tI/N6YUljgZaI+EP6fCbw9ZLNuoDLJN0BfAB4JSI2SuoFjpF0FPACcBEwO69arXG1tM4CB4JZLnILCOAQYJmk/vfpjIjlkuYBRMRi4G7gY8A6oA+4JF23U9JlwL3ACGBpRDyRY61mZlZCEZmn9htSe3t7dHfvcbmFmZkNQNKaAS5DqPkwVzMzq1MOCDMzy+SAMDOzTMOqDyId/fT7WtdRRQcBW2pdRB3z9zM4fz+Da5bv58iIyLxGYFgFRLOR1D1Q55L5+9kbfz+D8/fjU0xmZjYAB4SZmWVyQDS2JbUuoM75+xmcv5/BNf334z4IMzPL5CMIMzPL5IBocJK+JumFdFrXRyR9rNY11SNJX5QUkg6qdS31RNI30ul+H5G0QtJhta6pnki6TtLT6Xe0TNKBta6pmhwQw8MN6bSuJ0XE3bUupt5ImgR8BHi+1rXUoesi4sSIOAn4GXB1jeupN/cBx0fEicBvgS/XuJ6qckBYM7gB+BsGmHSqmUXEtqLFsfg72k1ErIiInenigyRz0zQNB8TwcFl6CLxU0jtqXUw9kTQLeCEiHq11LfVK0rWS1gNz8BHEYOYC99S6iGryKKYGIOlfgEMzVn2F5K+aLSR/+X0DmBgRc6tYXs3t5ftZAMyMiFckPQ
e0R0Qz3D5hl8G+n4j4adF2XwbGRMT/qFpxdaCc70fSV4B24JPRRD+aDohhRNIU4GcRcXyta6kHkk4AVpJMRgVvTV07PSI21aywOiXpSOCf/d/P7iR9FpgHzIiIvr1tP5zkOaOcVYGkiRGxMV08H1hby3rqSUQ8Dhzcv9ysRxCDkXRMRPxbujgLeLqW9dQbSWcDXwLOaLZwAB9BNDxJPwROIjnF9BzwX4sCw4o4IPYk6S5gKlAguRPyvIh4obZV1Q9J64DRwNa06cGImFfDkqrKAWFmZpk8isnMzDI5IMzMLJMDwszMMjkgzMwskwPCzMwyOSDM9lF6J90v5vTa10paL2l7Hq9vVg4HhFl9+n/A9FoXYc3NAWFWBkmfSW+I+Gh6cWLp+r+UtDpdf5ek1rT9U5LWpu33p23HSfp1OgfDY5KOKX29iHjQFzxarflCObO9kHQc8GPg1IjYIumdEfGipK8B2yPiW5ImRMTWdPtrgP+IiBslPQ6cHREvSDowIl6WdCPJFbm3SxoFjIiI1wZ47+0RMa46n9Rsdz6CMNu7DwF39t+iIyJezNjmeEmr0kCYAxyXtv8S+L6kvwRGpG2/AhZI+hJw5EDhYFZrDgizvRN7n0jn+8BlEXEC8LfAGID0vj0LgUnAI+mRRifJjfFeA+6V9KG8Cjd7OxwQZnu3ErhQ0gQASe/M2GZ/YKOk/UiOIEi3fVdEPBQRV5PM2zFJ0tHAsxHxHaALODH3T2C2DxwQZnsREU8A1wI/l/QocH3GZl8FHiKZw7j4ltnXSXpc0lrgfuBR4C+AtZIeAd4D3Fb6YpK+KakHaJXUk/Z3mFWVO6nNzCyTjyDMzCyTA8LMzDI5IMzMLJMDwszMMjkgzMwskwPCzMwyOSDMzCyTA8LMzDL9fzs5N8Lc289JAAAAAElFTkSuQmCC\n", 290 | "text/plain": [ 291 | "
" 292 | ] 293 | }, 294 | "metadata": { 295 | "needs_background": "light" 296 | }, 297 | "output_type": "display_data" 298 | } 299 | ], 300 | "source": [ 301 | "Plot().plot_in_2d(X_test, y_pred, title=\"LDA\", accuracy=accuracy)" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 13, 307 | "metadata": {}, 308 | "outputs": [ 309 | { 310 | "name": "stdout", 311 | "output_type": "stream", 312 | "text": [ 313 | "1.0\n" 314 | ] 315 | } 316 | ], 317 | "source": [ 318 | "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n", 319 | "clf = LinearDiscriminantAnalysis()\n", 320 | "clf.fit(X_train, y_train)\n", 321 | "y_pred = clf.predict(X_test)\n", 322 | "accuracy = accuracy_score(y_test, y_pred)\n", 323 | "print(accuracy)" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": null, 329 | "metadata": {}, 330 | "outputs": [], 331 | "source": [] 332 | } 333 | ], 334 | "metadata": { 335 | "kernelspec": { 336 | "display_name": "Python 3 (ipykernel)", 337 | "language": "python", 338 | "name": "python3" 339 | }, 340 | "language_info": { 341 | "codemirror_mode": { 342 | "name": "ipython", 343 | "version": 3 344 | }, 345 | "file_extension": ".py", 346 | "mimetype": "text/x-python", 347 | "name": "python", 348 | "nbconvert_exporter": "python", 349 | "pygments_lexer": "ipython3", 350 | "version": "3.9.7" 351 | }, 352 | "toc": { 353 | "base_numbering": 1, 354 | "nav_menu": {}, 355 | "number_sections": true, 356 | "sideBar": true, 357 | "skip_h1_title": false, 358 | "title_cell": "Table of Contents", 359 | "title_sidebar": "Contents", 360 | "toc_cell": false, 361 | "toc_position": {}, 362 | "toc_section_display": true, 363 | "toc_window_display": false 364 | } 365 | }, 366 | "nbformat": 4, 367 | "nbformat_minor": 2 368 | } 369 | -------------------------------------------------------------------------------- /charpter7_decision_tree/CART.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### CART" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "from sklearn.model_selection import train_test_split\n", 18 | "from sklearn.metrics import accuracy_score, mean_squared_error\n", 19 | "from utils import feature_split, calculate_gini" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "### 定义树结点\n", 29 | "class TreeNode():\n", 30 | " def __init__(self, feature_i=None, threshold=None,\n", 31 | " leaf_value=None, left_branch=None, right_branch=None):\n", 32 | " # 特征索引\n", 33 | " self.feature_i = feature_i \n", 34 | " # 特征划分阈值\n", 35 | " self.threshold = threshold \n", 36 | " # 叶子节点取值\n", 37 | " self.leaf_value = leaf_value \n", 38 | " # 左子树\n", 39 | " self.left_branch = left_branch \n", 40 | " # 右子树\n", 41 | " self.right_branch = right_branch " 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "### 定义二叉决策树\n", 51 | "class BinaryDecisionTree(object):\n", 52 | " ### 决策树初始参数\n", 53 | " def __init__(self, min_samples_split=2, min_gini_impurity=float(\"inf\"),\n", 54 | " max_depth=float(\"inf\"), loss=None):\n", 55 | " # 根结点\n", 56 | " self.root = None \n", 57 | " # 节点最小分裂样本数\n", 58 | " self.min_samples_split = min_samples_split\n", 59 | " # 节点初始化基尼不纯度\n", 60 | " self.mini_gini_impurity = min_gini_impurity\n", 61 | " # 树最大深度\n", 62 | " self.max_depth = max_depth\n", 63 | " # 基尼不纯度计算函数\n", 64 | " self.gini_impurity_calculation = None\n", 65 | " # 叶子节点值预测函数\n", 66 | " self._leaf_value_calculation = None\n", 67 | " # 损失函数\n", 68 | " self.loss = loss\n", 69 | "\n", 70 | " ### 决策树拟合函数\n", 71 | " def 
fit(self, X, y, loss=None):\n", 72 | " # 递归构建决策树\n", 73 | " self.root = self._build_tree(X, y)\n", 74 | " self.loss=None\n", 75 | "\n", 76 | " ### 决策树构建函数\n", 77 | " def _build_tree(self, X, y, current_depth=0):\n", 78 | " # 初始化最小基尼不纯度\n", 79 | " init_gini_impurity = self.gini_impurity_calculation(y)\n", 80 | " # 初始化最佳特征索引和阈值\n", 81 | " best_criteria = None \n", 82 | " # 初始化数据子集\n", 83 | " best_sets = None \n", 84 | "\n", 85 | " # 合并输入和标签\n", 86 | " Xy = np.concatenate((X, y), axis=1)\n", 87 | " # 获取样本数和特征数\n", 88 | " n_samples, n_features = X.shape\n", 89 | " # 设定决策树构建条件\n", 90 | " # 训练样本数量大于节点最小分裂样本数且当前树深度小于最大深度\n", 91 | " if n_samples >= self.min_samples_split and current_depth <= self.max_depth:\n", 92 | " # 遍历计算每个特征的基尼不纯度\n", 93 | " for feature_i in range(n_features):\n", 94 | " # 获取第i特征的所有取值\n", 95 | " feature_values = np.expand_dims(X[:, feature_i], axis=1)\n", 96 | " # 获取第i个特征的唯一取值\n", 97 | " unique_values = np.unique(feature_values)\n", 98 | "\n", 99 | " # 遍历取值并寻找最佳特征分裂阈值\n", 100 | " for threshold in unique_values:\n", 101 | " # 特征节点二叉分裂\n", 102 | " Xy1, Xy2 = feature_split(Xy, feature_i, threshold)\n", 103 | " # 如果分裂后的子集大小都不为0\n", 104 | " if len(Xy1) > 0 and len(Xy2) > 0:\n", 105 | " # 获取两个子集的标签值\n", 106 | " y1 = Xy1[:, n_features:]\n", 107 | " y2 = Xy2[:, n_features:]\n", 108 | "\n", 109 | " # 计算基尼不纯度\n", 110 | " impurity = self.impurity_calculation(y, y1, y2)\n", 111 | "\n", 112 | " # 获取最小基尼不纯度\n", 113 | " # 最佳特征索引和分裂阈值\n", 114 | " if impurity < init_gini_impurity:\n", 115 | " init_gini_impurity = impurity\n", 116 | " best_criteria = {\"feature_i\": feature_i, \"threshold\": threshold}\n", 117 | " best_sets = {\n", 118 | " \"leftX\": Xy1[:, :n_features], \n", 119 | " \"lefty\": Xy1[:, n_features:], \n", 120 | " \"rightX\": Xy2[:, :n_features], \n", 121 | " \"righty\": Xy2[:, n_features:] \n", 122 | " }\n", 123 | " \n", 124 | " # 如果best_criteria不为None,且计算的最小不纯度小于设定的最小不纯度\n", 125 | " if best_criteria and init_gini_impurity < self.mini_gini_impurity:\n", 
126 | " # 分别构建左右子树\n", 127 | " left_branch = self._build_tree(best_sets[\"leftX\"], best_sets[\"lefty\"], current_depth + 1)\n", 128 | " right_branch = self._build_tree(best_sets[\"rightX\"], best_sets[\"righty\"], current_depth + 1)\n", 129 | " return TreeNode(feature_i=best_criteria[\"feature_i\"], threshold=best_criteria[\n", 130 | " \"threshold\"], left_branch=left_branch, right_branch=right_branch)\n", 131 | "\n", 132 | " # 计算叶子计算取值\n", 133 | " leaf_value = self._leaf_value_calculation(y)\n", 134 | "\n", 135 | " return TreeNode(leaf_value=leaf_value)\n", 136 | "\n", 137 | " ### 定义二叉树值预测函数\n", 138 | " def predict_value(self, x, tree=None):\n", 139 | " if tree is None:\n", 140 | " tree = self.root\n", 141 | "\n", 142 | " # 如果叶子节点已有值,则直接返回已有值\n", 143 | " if tree.leaf_value is not None:\n", 144 | " return tree.leaf_value\n", 145 | "\n", 146 | " # 选择特征并获取特征值\n", 147 | " feature_value = x[tree.feature_i]\n", 148 | "\n", 149 | " # 判断落入左子树还是右子树\n", 150 | " branch = tree.right_branch\n", 151 | " if isinstance(feature_value, int) or isinstance(feature_value, float):\n", 152 | " if feature_value <= tree.threshold:\n", 153 | " branch = tree.left_branch\n", 154 | " elif feature_value == tree.threshold:\n", 155 | " branch = tree.left_branch\n", 156 | "\n", 157 | " # 测试子集\n", 158 | " return self.predict_value(x, branch)\n", 159 | "\n", 160 | " ### 数据集预测函数\n", 161 | " def predict(self, X):\n", 162 | " y_pred = [self.predict_value(sample) for sample in X]\n", 163 | " return y_pred" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 4, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "### CART回归树\n", 173 | "class RegressionTree(BinaryDecisionTree):\n", 174 | " def _calculate_weighted_mse(self, y, y1, y2):\n", 175 | " var_y1 = np.var(y1, axis=0)\n", 176 | " var_y2 = np.var(y2, axis=0)\n", 177 | " frac_1 = len(y1) / len(y)\n", 178 | " frac_2 = len(y2) / len(y)\n", 179 | " # 计算左右子树加权总均方误差\n", 180 | " weighted_mse = frac_1 * var_y1 + 
frac_2 * var_y2\n", 181 | " \n", 182 | " return sum(weighted_mse)\n", 183 | "\n", 184 | " # 节点值取平均\n", 185 | " def _mean_of_y(self, y):\n", 186 | " value = np.mean(y, axis=0)\n", 187 | " return value if len(value) > 1 else value[0]\n", 188 | "\n", 189 | " def fit(self, X, y):\n", 190 | " self.impurity_calculation = self._calculate_weighted_mse\n", 191 | " self._leaf_value_calculation = self._mean_of_y\n", 192 | " self.gini_impurity_calculation = lambda y: np.var(y, axis=0)\n", 193 | " super(RegressionTree, self).fit(X, y)" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 5, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "### CART决策树\n", 203 | "class ClassificationTree(BinaryDecisionTree):\n", 204 | " ### 定义基尼不纯度计算过程\n", 205 | " def _calculate_gini_impurity(self, y, y1, y2):\n", 206 | " p = len(y1) / len(y)\n", 207 | " gini = calculate_gini(y)\n", 208 | " gini_impurity = p * calculate_gini(y1) + (1-p) * calculate_gini(y2)\n", 209 | " return gini_impurity\n", 210 | " \n", 211 | " ### 多数投票\n", 212 | " def _majority_vote(self, y):\n", 213 | " most_common = None\n", 214 | " max_count = 0\n", 215 | " for label in np.unique(y):\n", 216 | " # 统计多数\n", 217 | " count = len(y[y == label])\n", 218 | " if count > max_count:\n", 219 | " most_common = label\n", 220 | " max_count = count\n", 221 | " return most_common\n", 222 | " \n", 223 | " # 分类树拟合\n", 224 | " def fit(self, X, y):\n", 225 | " self.impurity_calculation = self._calculate_gini_impurity\n", 226 | " self._leaf_value_calculation = self._majority_vote\n", 227 | " self.gini_impurity_calculation = calculate_gini\n", 228 | " super(ClassificationTree, self).fit(X, y)" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 6, 234 | "metadata": {}, 235 | "outputs": [ 236 | { 237 | "name": "stdout", 238 | "output_type": "stream", 239 | "text": [ 240 | "0.9777777777777777\n" 241 | ] 242 | } 243 | ], 244 | "source": [ 245 | "from sklearn import datasets\n", 
246 | "data = datasets.load_iris()\n", 247 | "X, y = data.data, data.target\n", 248 | "# 注意!是否要对y进行reshape取决于numpy版本\n", 249 | "y = y.reshape(-1,1)\n", 250 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)\n", 251 | "clf = ClassificationTree()\n", 252 | "clf.fit(X_train, y_train)\n", 253 | "y_pred = clf.predict(X_test)\n", 254 | "\n", 255 | "print(accuracy_score(y_test, y_pred))" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 7, 261 | "metadata": {}, 262 | "outputs": [ 263 | { 264 | "name": "stdout", 265 | "output_type": "stream", 266 | "text": [ 267 | "0.9777777777777777\n" 268 | ] 269 | } 270 | ], 271 | "source": [ 272 | "from sklearn.tree import DecisionTreeClassifier\n", 273 | "clf = DecisionTreeClassifier()\n", 274 | "clf.fit(X_train, y_train)\n", 275 | "y_pred = clf.predict(X_test)\n", 276 | "\n", 277 | "print(accuracy_score(y_test, y_pred))" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 8, 283 | "metadata": {}, 284 | "outputs": [ 285 | { 286 | "name": "stdout", 287 | "output_type": "stream", 288 | "text": [ 289 | "Mean Squared Error: 23.03842105263158\n" 290 | ] 291 | } 292 | ], 293 | "source": [ 294 | "import pandas as pd\n", 295 | "\n", 296 | "# 波士顿房价数据集的原始 URL\n", 297 | "data_url = \"http://lib.stat.cmu.edu/datasets/boston\"\n", 298 | "\n", 299 | "# 从 URL 加载数据\n", 300 | "raw_df = pd.read_csv(data_url, sep=\"\\s+\", skiprows=22, header=None)\n", 301 | "\n", 302 | "# 处理数据\n", 303 | "data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]]) # 拼接特征数据\n", 304 | "target = raw_df.values[1::2, 2] # 目标变量\n", 305 | "\n", 306 | "# 将数据和目标变量转换为 NumPy 数组\n", 307 | "X = np.array(data)\n", 308 | "y = np.array(target)\n", 309 | "\n", 310 | "y = y.reshape(-1,1)\n", 311 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)\n", 312 | "model = RegressionTree()\n", 313 | "model.fit(X_train, y_train)\n", 314 | "y_pred = model.predict(X_test)\n", 315 | "mse = 
mean_squared_error(y_test, y_pred)\n", 316 | "\n", 317 | "print(\"Mean Squared Error:\", mse)" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": 9, 323 | "metadata": {}, 324 | "outputs": [ 325 | { 326 | "name": "stdout", 327 | "output_type": "stream", 328 | "text": [ 329 | "Mean Squared Error: 23.600592105263157\n" 330 | ] 331 | } 332 | ], 333 | "source": [ 334 | "from sklearn.tree import DecisionTreeRegressor\n", 335 | "reg = DecisionTreeRegressor()\n", 336 | "reg.fit(X_train, y_train)\n", 337 | "y_pred = reg.predict(X_test)\n", 338 | "mse = mean_squared_error(y_test, y_pred)\n", 339 | "\n", 340 | "print(\"Mean Squared Error:\", mse)" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": null, 346 | "metadata": {}, 347 | "outputs": [], 348 | "source": [] 349 | } 350 | ], 351 | "metadata": { 352 | "kernelspec": { 353 | "display_name": "Python 3", 354 | "language": "python", 355 | "name": "python3" 356 | }, 357 | "language_info": { 358 | "codemirror_mode": { 359 | "name": "ipython", 360 | "version": 3 361 | }, 362 | "file_extension": ".py", 363 | "mimetype": "text/x-python", 364 | "name": "python", 365 | "nbconvert_exporter": "python", 366 | "pygments_lexer": "ipython3", 367 | "version": "3.9.19" 368 | }, 369 | "toc": { 370 | "base_numbering": 1, 371 | "nav_menu": {}, 372 | "number_sections": true, 373 | "sideBar": true, 374 | "skip_h1_title": false, 375 | "title_cell": "Table of Contents", 376 | "title_sidebar": "Contents", 377 | "toc_cell": false, 378 | "toc_position": {}, 379 | "toc_section_display": true, 380 | "toc_window_display": false 381 | } 382 | }, 383 | "nbformat": 4, 384 | "nbformat_minor": 2 385 | } 386 | -------------------------------------------------------------------------------- /charpter7_decision_tree/example_data.csv: -------------------------------------------------------------------------------- 1 | humility,outlook,play,temp,windy 2 | high,sunny,no,hot,false 3 | high,sunny,no,hot,true 4 | 
high,overcast,yes,hot,false 5 | high,rainy,yes,mild,false 6 | normal,rainy,yes,cool,false 7 | normal,rainy,no,cool,true 8 | normal,overcast,yes,cool,true 9 | high,sunny,no,mild,false 10 | normal,sunny,yes,cool,false 11 | normal,rainy,yes,mild,false 12 | normal,sunny,yes,mild,true 13 | high,overcast,yes,mild,true 14 | normal,overcast,yes,hot,false 15 | high,rainy,no,mild,true 16 | -------------------------------------------------------------------------------- /charpter7_decision_tree/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | ### 定义二叉特征分裂函数 4 | def feature_split(X, feature_i, threshold): 5 | split_func = None 6 | if isinstance(threshold, int) or isinstance(threshold, float): 7 | split_func = lambda sample: sample[feature_i] <= threshold 8 | else: 9 | split_func = lambda sample: sample[feature_i] == threshold 10 | 11 | X_left = np.array([sample for sample in X if split_func(sample)]) 12 | X_right = np.array([sample for sample in X if not split_func(sample)]) 13 | 14 | return X_left, X_right 15 | 16 | 17 | ### 计算基尼指数 18 | def calculate_gini(y): 19 | # 将数组转化为列表 20 | y = y.tolist() 21 | probs = [y.count(i)/len(y) for i in np.unique(y)] 22 | gini = sum([p*(1-p) for p in probs]) 23 | return gini 24 | -------------------------------------------------------------------------------- /charpter8_neural_networks/perceptron.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # 定义单层感知机类 4 | class Perceptron: 5 | def __init__(self): 6 | pass 7 | 8 | def sign(self, x, w, b): 9 | return np.dot(x, w) + b 10 | 11 | def train(self, X_train, y_train, learning_rate): 12 | # 参数初始化 13 | w, b = self.initilize_with_zeros(X_train.shape[1]) 14 | # 初始化误分类 15 | is_wrong = False 16 | while not is_wrong: 17 | wrong_count = 0 18 | for i in range(len(X_train)): 19 | X = X_train[i] 20 | y = y_train[i] 21 | 22 | # 如果存在误分类点 23 | # 更新参数 24 | # 直到没有误分类点 25 | if y 
* self.sign(X, w, b) <= 0: 26 | w = w + learning_rate*np.dot(y, X) 27 | b = b + learning_rate*y 28 | wrong_count += 1 29 | if wrong_count == 0: 30 | is_wrong = True 31 | print('There is no misclassification!') 32 | 33 | # 保存更新后的参数 34 | params = { 35 | 'w': w, 36 | 'b': b 37 | } 38 | return params 39 | -------------------------------------------------------------------------------- /pic/cover.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luwill/Machine_Learning_Code_Implementation/f557538b329d4004da7900b140671a31c43faa7b/pic/cover.jpg -------------------------------------------------------------------------------- /pic/ml_xmind.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luwill/Machine_Learning_Code_Implementation/f557538b329d4004da7900b140671a31c43faa7b/pic/ml_xmind.png -------------------------------------------------------------------------------- /pic/ppt_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luwill/Machine_Learning_Code_Implementation/f557538b329d4004da7900b140671a31c43faa7b/pic/ppt_1.png -------------------------------------------------------------------------------- /pic/ppt_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luwill/Machine_Learning_Code_Implementation/f557538b329d4004da7900b140671a31c43faa7b/pic/ppt_2.png -------------------------------------------------------------------------------- /pic/ppt_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luwill/Machine_Learning_Code_Implementation/f557538b329d4004da7900b140671a31c43faa7b/pic/ppt_3.png -------------------------------------------------------------------------------- /pic/ppt_4.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/luwill/Machine_Learning_Code_Implementation/f557538b329d4004da7900b140671a31c43faa7b/pic/ppt_4.png -------------------------------------------------------------------------------- /pic/ppt_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luwill/Machine_Learning_Code_Implementation/f557538b329d4004da7900b140671a31c43faa7b/pic/ppt_5.png --------------------------------------------------------------------------------
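A note on `charpter8_neural_networks/perceptron.py` above: its `train` method calls `self.initilize_with_zeros(X_train.shape[1])`, but no such method appears in the excerpt. A minimal sketch of that missing helper, together with the same sign/update rule run end-to-end on a tiny hand-made dataset — the helper body, the learning rate of 0.1, and the toy points are assumptions for illustration, not taken from the source:

```python
import numpy as np

# Assumed body of the helper that train() expects (not present in the excerpt):
# zero weight vector with one entry per feature, plus a scalar bias.
def initilize_with_zeros(dim):
    w = np.zeros(dim)
    b = 0.0
    return w, b

# Tiny linearly separable toy set with labels in {-1, +1},
# matching the y * sign(x) <= 0 misclassification test used in train().
X_train = np.array([[3.0, 3.0], [4.0, 3.0], [1.0, 1.0]])
y_train = np.array([1, 1, -1])

w, b = initilize_with_zeros(X_train.shape[1])
learning_rate = 0.1
while True:
    wrong_count = 0
    for x, y in zip(X_train, y_train):
        if y * (np.dot(x, w) + b) <= 0:   # misclassified point
            w = w + learning_rate * y * x  # same update rule as train()
            b = b + learning_rate * y
            wrong_count += 1
    if wrong_count == 0:                   # converged: no misclassification
        break

print(w, b)  # a separating hyperplane for the toy data
```

Since the toy data are linearly separable, the perceptron convergence theorem guarantees the loop terminates with every point on the correct side of the hyperplane.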