├── README.md ├── charpter01_ml_start └── NumPy_sklearn.ipynb ├── charpter02_linear_regression ├── 1_1_linear_model.py └── linear_regression.ipynb ├── charpter03_logistic_regression └── logistic_regression.ipynb ├── charpter04_regression_expansion ├── example.dat ├── lasso.ipynb └── ridge.ipynb ├── charpter05_LDA └── LDA.ipynb ├── charpter06_knn └── knn.ipynb ├── charpter07_decision_tree ├── CART.ipynb └── example_data.csv ├── charpter08_neural_networks ├── neural_networks.ipynb ├── perceptron.ipynb └── perceptron.py ├── charpter09_SVM ├── hard_margin_svm.ipynb ├── non-linear_svm.ipynb └── soft_margin_svm.ipynb ├── charpter10_AdaBoost └── adaboost.ipynb ├── charpter11_GBDT ├── __pycache__ │ ├── cart.cpython-39.pyc │ └── utils.cpython-39.pyc ├── cart.py ├── gbdt.ipynb └── utils.py ├── charpter12_XGBoost ├── __pycache__ │ ├── cart.cpython-39.pyc │ └── utils.cpython-39.pyc ├── cart.py ├── utils.py └── xgboost.ipynb ├── charpter13_LightGBM └── lightgbm.ipynb ├── charpter14_CatBoost ├── adult.data ├── catboost.ipynb └── catboost_info │ ├── catboost_training.json │ ├── learn │ └── events.out.tfevents │ ├── learn_error.tsv │ └── time_left.tsv ├── charpter15_random_forest ├── __pycache__ │ ├── cart.cpython-39.pyc │ └── utils.cpython-39.pyc ├── cart.py ├── random_forest.ipynb └── utils.py ├── charpter16_ensemble_compare ├── catboost_info │ ├── catboost_training.json │ ├── learn │ │ └── events.out.tfevents │ ├── learn_error.tsv │ └── time_left.tsv └── compare_and_tuning.ipynb ├── charpter17_kmeans └── kmeans.ipynb ├── charpter18_PCA └── pca.ipynb ├── charpter19_SVD ├── louwill.jpg ├── svd.ipynb └── svd_pic │ ├── svd_1.jpg │ ├── svd_10.jpg │ ├── svd_11.jpg │ ├── svd_12.jpg │ ├── svd_13.jpg │ ├── svd_14.jpg │ ├── svd_15.jpg │ ├── svd_16.jpg │ ├── svd_17.jpg │ ├── svd_18.jpg │ ├── svd_19.jpg │ ├── svd_2.jpg │ ├── svd_20.jpg │ ├── svd_21.jpg │ ├── svd_22.jpg │ ├── svd_23.jpg │ ├── svd_24.jpg │ ├── svd_25.jpg │ ├── svd_26.jpg │ ├── svd_27.jpg │ ├── svd_28.jpg │ ├── svd_29.jpg │ 
import numpy as np
from random import shuffle


class Linearmodel():
    """Linear regression trained with batch gradient descent.

    Provides parameter initialization, MSE loss/gradient computation,
    a gradient-descent training loop, prediction, R^2 scoring and a
    k-fold cross-validation split generator.
    """

    def __init__(self):
        pass

    def initialize_params(self, dims):
        """Return zero-initialized weights of shape (dims, 1) and a zero bias.

        Args:
            dims: number of input features.
        Returns:
            (w, b): weight column vector and scalar bias.
        """
        w = np.zeros((dims, 1))
        b = 0
        return w, b

    def linear_loss(self, X, y, w, b):
        """Compute predictions, mean squared loss and parameter gradients.

        Args:
            X: (n, d) input matrix.
            y: (n, 1) target column vector.
            w: (d, 1) weight column vector.
            b: scalar bias.
        Returns:
            (y_hat, loss, dw, db): predictions, MSE, and gradients.
            NOTE: the conventional factor 2 of the MSE gradient is omitted
            and effectively absorbed into the learning rate.
        """
        num_train = X.shape[0]
        y_hat = np.dot(X, w) + b
        loss = np.sum((y_hat - y) ** 2) / num_train
        dw = np.dot(X.T, (y_hat - y)) / num_train
        db = np.sum(y_hat - y) / num_train
        return y_hat, loss, dw, db

    def linear_train(self, X, y, learning_rate=0.01, epochs=10000):
        """Fit the model by batch gradient descent.

        Args:
            X: (n, d) training inputs.
            y: (n, 1) training targets.
            learning_rate: gradient-descent step size.
            epochs: number of gradient-descent iterations.
        Returns:
            (loss_his, params, grads): per-epoch loss history, the final
            {'w', 'b'} parameters, and the last {'dw', 'db'} gradients.
        """
        loss_his = []
        w, b = self.initialize_params(X.shape[1])
        # Pre-initialize gradients so the return is well-defined even for
        # epochs == 0 (the original raised NameError in that case).
        dw, db = np.zeros_like(w), 0.0
        # range(1, epochs + 1) runs exactly `epochs` iterations; the original
        # range(1, epochs) silently ran one fewer.
        for i in range(1, epochs + 1):
            y_hat, loss, dw, db = self.linear_loss(X, y, w, b)
            w += -learning_rate * dw
            b += -learning_rate * db
            loss_his.append(loss)
            # Progress report every 10000 epochs (the original comment
            # claimed every 1000, contradicting the code).
            if i % 10000 == 0:
                print('epoch %d loss %f' % (i, loss))
        params = {
            'w': w,
            'b': b
        }
        grads = {
            'dw': dw,
            'db': db
        }
        return loss_his, params, grads

    def predict(self, X, params):
        """Predict targets for X using a trained {'w', 'b'} params dict."""
        w = params['w']
        b = params['b']
        y_pred = np.dot(X, w) + b
        return y_pred

    def r2_score(self, y_test, y_pred):
        """Return the coefficient of determination R^2 = 1 - SS_res/SS_tot."""
        y_avg = np.mean(y_test)
        ss_tot = np.sum((y_test - y_avg) ** 2)  # total sum of squares
        ss_res = np.sum((y_test - y_pred) ** 2)  # residual sum of squares
        r2 = 1 - (ss_res / ss_tot)
        return r2

    def k_fold_cross_validation(self, items, k, randomize=True):
        """Yield (training, validation) array pairs for k-fold CV.

        Args:
            items: iterable of samples (e.g. rows of a 2-D array).
            k: number of folds.
            randomize: shuffle the samples in place before splitting.
                Relies on random.shuffle imported at module level.
        Yields:
            (training, validation) as numpy arrays, one pair per fold.
        """
        if randomize:
            items = list(items)
            shuffle(items)
        slices = [items[i::k] for i in range(k)]
        for i in range(k):
            validation = slices[i]
            training = [item
                        for s in slices if s is not validation
                        for item in s]
            training = np.array(training)
            validation = np.array(validation)
            yield training, validation


if __name__ == "__main__":
    import matplotlib as mpl
    mpl.rcParams['font.sans-serif'] = ["SimHei"]
    mpl.rcParams["axes.unicode_minus"] = False

    from sklearn.datasets import load_diabetes
    # BUG FIX: the original `from sklearn.utils import shuffle` shadowed the
    # module-level random.shuffle used by k_fold_cross_validation; sklearn's
    # shuffle returns a copy, so in-place shuffling silently became a no-op.
    from sklearn.utils import shuffle as sk_shuffle

    model = Linearmodel()
    diabetes = load_diabetes()                       # diabetes dataset
    data, target = diabetes.data, diabetes.target    # inputs and labels
    X, y = sk_shuffle(data, target, random_state=13) # shuffle the dataset

    offset = int(X.shape[0] * 0.8)                   # 80/20 train/test split
    X_train, y_train = X[:offset], y[:offset]
    X_test, y_test = X[offset:], y[offset:]
    y_train = y_train.reshape((-1, 1))               # targets as column vectors
    y_test = y_test.reshape((-1, 1))

    # Print train/test shapes
    print("X_train's shape: ", X_train.shape)
    print("X_test's shape: ", X_test.shape)
    print("y_train's shape: ", y_train.shape)
    print("y_test's shape: ", y_test.shape)

    # Train the linear regression model
    loss_his, params, grads = model.linear_train(X_train, y_train, 0.01, 200000)
    print(params)

    # Predict on the test set
    y_pred = model.predict(X_test, params)
    print(y_pred[:5])   # first five predictions
    print(y_test[:5])   # first five ground-truth values

    # Coefficient of determination R^2
    print(model.r2_score(y_test, y_pred))

    import matplotlib.pyplot as plt
    f = X_test.dot(params['w']) + params['b']
    plt.scatter(range(X_test.shape[0]), y_test)  # scatter: ground truth
    plt.plot(f, color='darkorange')              # line: predictions
    plt.xlabel('X_test')                         # x axis has no real meaning
    plt.ylabel('y_test')
    plt.title("预测数据图")
    plt.show()

    # Training-loss curve
    plt.plot(loss_his, color='blue')
    plt.xlabel('epochs')
    plt.ylabel('loss')
    plt.title("迭代次数与损失函数关系图")
    plt.show()

    # BUG FIX: the original passed only `data` (10 feature columns) into the
    # k-fold split and then treated training[:, -1] -- the 10th *feature* --
    # as the label, while training[:, :10] leaked it back into the inputs.
    # Stack features and target so column 10 really is the target.
    Xy = np.hstack([data, target.reshape((-1, 1))])
    # BUG FIX: loss5 was re-initialized inside the loop, so the "5-fold score"
    # averaged only the last fold; it also appended a whole loss-history list
    # rather than a scalar. Accumulate each fold's final training loss.
    loss5 = []
    for training, validation in model.k_fold_cross_validation(Xy, 5):
        # Split features and label per fold
        X_train = training[:, :10]
        y_train = training[:, -1].reshape((-1, 1))
        X_valid = validation[:, :10]
        y_valid = validation[:, -1].reshape((-1, 1))
        print("查看数据信息:")
        print(X_train.shape, y_train.shape, X_valid.shape, y_valid.shape)

        loss, params, grads = model.linear_train(X_train, y_train, 0.001, 100000)
        loss5.append(loss[-1])

        # Validation MSE for this fold
        y_pred = model.predict(X_valid, params)
        valid_score = np.sum((y_pred - y_valid) ** 2) / len(X_valid)
        print('valid score is', valid_score)
        print()
    # Typo fix: "kold" -> "fold"
    print('five fold cross validation score is', np.mean(loss5))
np.linalg.svd(Sw)\n", 71 | " # 计算类内散度矩阵的逆\n", 72 | " Sw_ = np.dot(np.dot(V.T, np.linalg.pinv(np.diag(S))), U.T)\n", 73 | " # 计算w\n", 74 | " self.w = Sw_.dot(mean_diff)\n", 75 | "\n", 76 | " # LDA分类预测\n", 77 | " def predict(self, X):\n", 78 | " y_pred = []\n", 79 | " for sample in X:\n", 80 | " h = sample.dot(self.w)\n", 81 | " y = 1 * (h < 0)\n", 82 | " y_pred.append(y)\n", 83 | " return y_pred" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 2, 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "name": "stdout", 93 | "output_type": "stream", 94 | "text": [ 95 | "(80, 4) (20, 4) (80,) (20,)\n" 96 | ] 97 | } 98 | ], 99 | "source": [ 100 | "# 读取数据,分离数据集\n", 101 | "from sklearn import datasets\n", 102 | "import matplotlib.pyplot as plt\n", 103 | "from sklearn.model_selection import train_test_split\n", 104 | "\n", 105 | "data = datasets.load_iris()\n", 106 | "X = data.data\n", 107 | "y = data.target\n", 108 | "\n", 109 | "X = X[y != 2]\n", 110 | "y = y[y != 2]\n", 111 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=41)\n", 112 | "print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 3, 118 | "metadata": {}, 119 | "outputs": [ 120 | { 121 | "name": "stdout", 122 | "output_type": "stream", 123 | "text": [ 124 | "0.85\n", 125 | "[0 1 0 0 1 0 0 0 0 1 1 1 1 0 1 1 1 0 0 0]\n", 126 | "[0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1]\n" 127 | ] 128 | } 129 | ], 130 | "source": [ 131 | "# 使用手写算法完成模型\n", 132 | "lda = LDA()\n", 133 | "lda.fit(X_train, y_train)\n", 134 | "y_pred = lda.predict(X_test)\n", 135 | "\n", 136 | "from sklearn.metrics import accuracy_score\n", 137 | "accuracy = accuracy_score(y_test, y_pred)\n", 138 | "print(accuracy)\n", 139 | "# 真实值\n", 140 | "print(y_test)\n", 141 | "# 预测值\n", 142 | "print(y_pred)" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 4, 148 | 
"metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "# 绘图部分\n", 152 | "import matplotlib.pyplot as plt\n", 153 | "import matplotlib.cm as cmx\n", 154 | "import matplotlib.colors as colors\n", 155 | "\n", 156 | "# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #\n", 157 | "# 这段是在网上找出来改的\n", 158 | "def calculate_covariance_matrix(X, Y=np.empty((0,0))):\n", 159 | " if not Y.any():\n", 160 | " Y = X\n", 161 | " n_samples = np.shape(X)[0]\n", 162 | " covariance_matrix = (1 / (n_samples-1)) * (X - X.mean(axis=0)).T.dot(Y - Y.mean(axis=0))\n", 163 | " return np.array(covariance_matrix, dtype=float)\n", 164 | "# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #\n", 165 | "\n", 166 | "class Plot():\n", 167 | " def __init__(self): \n", 168 | " self.cmap = plt.get_cmap('viridis')\n", 169 | "\n", 170 | " def _transform(self, X, dim):\n", 171 | " covariance = calculate_covariance_matrix(X)\n", 172 | " eigenvalues, eigenvectors = np.linalg.eig(covariance)\n", 173 | " # Sort eigenvalues and eigenvector by largest eigenvalues\n", 174 | " idx = eigenvalues.argsort()[::-1]\n", 175 | " eigenvalues = eigenvalues[idx][:dim]\n", 176 | " eigenvectors = np.atleast_1d(eigenvectors[:, idx])[:, :dim]\n", 177 | " # Project the data onto principal components\n", 178 | " X_transformed = X.dot(eigenvectors)\n", 179 | " return X_transformed\n", 180 | "\n", 181 | " def plot_regression(self, lines, title, axis_labels=None, mse=None, scatter=None, legend={\"type\": \"lines\", \"loc\": \"lower right\"}):\n", 182 | " if scatter:\n", 183 | " scatter_plots = scatter_labels = []\n", 184 | " for s in scatter:\n", 185 | " scatter_plots += [plt.scatter(s[\"x\"], s[\"y\"], color=s[\"color\"], s=s[\"size\"])]\n", 186 | " scatter_labels += [s[\"label\"]]\n", 187 | " scatter_plots = tuple(scatter_plots)\n", 188 | " scatter_labels = tuple(scatter_labels)\n", 189 | "\n", 190 | " for l in lines:\n", 191 | " li = plt.plot(l[\"x\"], l[\"y\"], color=s[\"color\"], 
linewidth=l[\"width\"], label=l[\"label\"])\n", 192 | "\n", 193 | " if mse:\n", 194 | " plt.suptitle(title)\n", 195 | " plt.title(\"MSE: %.2f\" % mse, fontsize=10)\n", 196 | " else:\n", 197 | " plt.title(title)\n", 198 | "\n", 199 | " if axis_labels:\n", 200 | " plt.xlabel(axis_labels[\"x\"])\n", 201 | " plt.ylabel(axis_labels[\"y\"])\n", 202 | "\n", 203 | " if legend[\"type\"] == \"lines\":\n", 204 | " plt.legend(loc=\"lower_left\")\n", 205 | " elif legend[\"type\"] == \"scatter\" and scatter:\n", 206 | " plt.legend(scatter_plots, scatter_labels, loc=legend[\"loc\"])\n", 207 | "\n", 208 | " plt.show()\n", 209 | "\n", 210 | "\n", 211 | " # Plot the dataset X and the corresponding labels y in 2D using PCA.\n", 212 | " def plot_in_2d(self, X, y=None, title=None, accuracy=None, legend_labels=None):\n", 213 | " X_transformed = self._transform(X, dim=2)\n", 214 | " x1 = X_transformed[:, 0]\n", 215 | " x2 = X_transformed[:, 1]\n", 216 | " class_distr = []\n", 217 | " y = np.array(y).astype(int)\n", 218 | " colors = [self.cmap(i) for i in np.linspace(0, 1, len(np.unique(y)))]\n", 219 | "\n", 220 | " # Plot the different class distributions\n", 221 | " for i, l in enumerate(np.unique(y)):\n", 222 | " _x1 = x1[y == l]\n", 223 | " _x2 = x2[y == l]\n", 224 | " _y = y[y == l]\n", 225 | " class_distr.append(plt.scatter(_x1, _x2, color=colors[i]))\n", 226 | "\n", 227 | " # Plot legend\n", 228 | " if not legend_labels is None: \n", 229 | " plt.legend(class_distr, legend_labels, loc=1)\n", 230 | "\n", 231 | " # Plot title\n", 232 | " if title:\n", 233 | " if accuracy:\n", 234 | " perc = 100 * accuracy\n", 235 | " plt.suptitle(title)\n", 236 | " plt.title(\"Accuracy: %.1f%%\" % perc, fontsize=10)\n", 237 | " else:\n", 238 | " plt.title(title)\n", 239 | "\n", 240 | " # Axis labels\n", 241 | " plt.xlabel('class 1')\n", 242 | " plt.ylabel('class 2')\n", 243 | "\n", 244 | " plt.show()" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 5, 250 | "metadata": {}, 
251 | "outputs": [ 252 | { 253 | "data": { 254 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEjCAYAAAAomJYLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAfLklEQVR4nO3dfZgdZZ3m8e/diUmmE0IkNBggIaDZqCBCto2yDBCMZsAdwuAgQmcVDWNvssLswjojE5FRRnPNiktWUclEjG9DD6tgtNfhJU5fu5oRYdKRQMKbxhhJSGI64SWGhoTQv/2jqsPJSXX36dB1Xvrcn+s61zn1VNU5vz6EvrvqqXoeRQRmZmbFGipdgJmZVScHhJmZZXJAmJlZJgeEmZllckCYmVkmB4SZmWVyQJiZWSYHhFkJJG2S9J6itlmSeiTtSR9bJH1P0jsy9j8p3fZr5ava7LVxQJi9NlsjYhxwBPAu4AlglaTZRdt9GHgWuEzS6DLXaHZYHBBmQyASWyLiBuA24H8UbfJh4HrgZeDCctdndjgcEGZD7wfADEljASSdDZwA3AF8jyQszKqeA8Js6G0FBExIl68A7omIZ4E24AJJx1SoNrOSOSDMht7xQADPSfoj4APA7QAR8QvgKaClcuWZlcYBYTb0LgZ+GREvpK/HA1+TtF3SdpIA8Wkmq3ojK12AWQ15naQxBcsH/v+RJOA44C/Sx9x01RXAcuBTBfsdD6yW9LaIWJdvyWaHT54PwmxgkjYBJxY1/xz4D0A3SZ/D88D9wBcj4gFJxwO/A84oDgJJdwOPRcQn8q7d7HA5IMzMLJP7IMzMLJMDwszMMjkgzMwskwPCzMwyOSBsWJB0saSQ9OZK1/JaSDpd0gOS1krqlDQzbZ8q6cW0fa2kpX3sf5Skn0j6dfr8+rT9LEmPSFot6U1p2wRJ96WX6JodwgFhw8XlwL8Cl+X5IZJG5Pn+wBeAz0bE6cAN6XKv30TE6eljQR/7Xwd0RMQ0oCNdBvjvwJ8Di4CFadungcXhSxmtDw4Iq3mSxgFnAVdSEBCSRkj6oqR16V/PV6ft75B0v6SHJf2bpCMkfUTSVwr2/bGkWenrPZJulPQgcKakG9K/xNdLWtb7F7ikN0n6l/R9fynpjZK+K+migve9XVLvTXRZguTOa4AjScZ1GoyLgG+nr78N/Fn6+mXgj4BG4GVJbwSOj4ifDvL9rZ5EhB9+1PQD+E/AN9LX9wMz0tcLgbuAkenyUcAoYCPwjrRtPMkd0R8BvlLwnj8GZqWvA7i0YN1RBa+/C1yYvn4QuDh9PYbkl/G5wA/TtiOB36afdzdwXMbP8haSsZo2A08DJ6btU4EXgIeAnwJn9/FdPFe0/Gz6fDrwAPB/eXVk2WmV/m/nR3U/fARhw8HlJL/wSJ8vT1+/B1gaEfsBIuIZYDqwLSJWp227e9f34xWSoOl1nqQHJa0D3g2cIukIkr/IV6Tv+1JEdEfyF/qb0tFbLwfuioj9EfG+iMg6OlgIXBMRk4FrgG+k7duAKRFxBnAt0CZpfMb+mSJibUS8KyLOA04mHXFW0v+W9I+Sji31vax+eCwmq2mSJpL8kj5VUgAjgJD01yTDXxSfX89qA9jPwadcC8dceikiXkk/bwzwNaA5IjZL+ky6bX8dvd8F5pGc/po/wI90BfBf09ffJ5l8iIjYC+xNX6+R9Bvg3wGdRfv/XtKkiNgmaRKwo3BlejrseuCDwFeAvyU5OvlLDh4vysxHEFbzLgG+ExEnRsTU9C/v3wJ/DKwEFkgaCckVPiRTgh7XO2902v8wEtgEnC6pQdJkYGYfn9cbHDvTvo9LIDkSAbZI+rP0fUdLaky3/Rbw39LtHh3g59lKcloKkuD7dfp+Tb0d5JJOBqaRnCor1k4SMqTPPypafwXwz5HMTdEI9KSPRsyK+AjCat3lw
N8Xtd1FMt/C1SR/ZT8i6WXg6xHxFUkfBG5J52p4keRU1M9JgmUdsB74ZdaHRcRzkr6ebrcJWF2w+kPAP0i6kaRT+APAxoj4vaTHgR/2bpgO1vcXGaeZPgZ8KQ2tl4DWtP0c4EZJ+0lOeS1IT5kh6TaSU2md6XfxPUlXkvRlfKDgMxtJAmJO2nRz+l3t49XTcmYHeLA+s5ylv5jXkXSeP1/pesxK5VNMZjmS9B6S01q3OBys1vgIwszMMvkIwszMMjkgzMws07C6iunoo4+OqVOnVroMM7OasWbNmp0R0ZS1blgFxNSpU+nsLL5vyMzM+iLpd32t8ykmMzPL5IAwM7NMDggzM8vkgDAzs0wOCDMzy+SAMDOrch1tq5g3dSFzRlzKvKkL6WhbVZbPHVaXuZqZDTcdbatY0rqUvd37ANjx1E6WtC4FYHbL2bl+to8gzMyq2PJFbQfCodfe7n0sX9SW+2c7IMzMqljX5l2Dah9KDggzsyrWNHnioNqHkgPCzKyKzV/cwujGUQe1jW4cxfzFLbl/tjupzcyqWG9H9PJFbXRt3kXT5InMX9ySewc1DLMJg5qbm8OD9ZmZlU7SmohozlrnU0xmZpbJAWFmZpkcEGZmlskBYWZmmRwQZmaWKdfLXCVNAG4DTgUCmB8RvyhY/1fAvIJa3gI0RcQzkjYBfwBeAfb31ctuZmb5yPs+iC8B90bEJZJGAY2FKyPiJuAmAEkXAtdExDMFm5wXETtzrtHMzDLkFhCSxgPnAB8BiIh9wL5+drkc+Ke86jEzs8HJsw/iZKAL+KakhyTdJmls1oaSGoHzgbsKmgNYKWmNpNYc6zQzswx5BsRIYAZwa0ScAbwAXNfHthcCPy86vXRWRMwALgA+LumcrB0ltUrqlNTZ1dU1hOWbmdW3PANiC7AlIh5Ml+8kCYwsl1F0eikitqbPO4AVwMysHSNiWUQ0R0RzU1PTkBRuZmY5BkREbAc2S5qeNs0GHiveTtKRwLnAjwraxko6ovc1MAdYn1etZmZ2qLyvYroauD29gmkj8FFJCwAiYmm6zcXAyoh4oWC/Y4EVknprbIuIe3Ou1czMCng0VzOzOubRXMusp7udnh2z6Nk+PXnubq90SWZmg+YJg4ZYT3c77L4eeClt2Aq7r6cHaGicW8nSzMwGxUcQQ23PzRwIhwNeStvNzGqHA2Ko9WwbXLuZWZVyQAy1hkmDazczq1IOiKE27lpgTFHjmLTdzKx2uJN6iDU0zqUHkj6Hnm3JkcO4a91BbWY1xwGRg4bGueBAMLMa51NMZmZVrqNtFfOmLmTOiEuZN3UhHW2ryvK5PoIwM6tCHW2rWL6ojR1P7QSRTIAA7HhqJ0tak5GKZrecnWsNPoIwM6syHW2rWNK6NAkHOBAOvfZ272P5orbc63BAmJlVmeWL2tjb3d8EnNC1eVfudTggzMyqTCm//JsmT8y9DgeEmVmVGeiX/+jGUcxf3JJ7HQ4IM7MqM39xC6MbRx3cqOTpmClHc82yBbl3UIOvYjIzqzq9v/yXL2qja/MumiZPZP7ilrKEQiFPGGRmVsc8YZCZmQ1argEhaYKkOyU9IelxSWcWrZ8l6XlJa9PHDQXrzpf0pKQNkq7Ls04zMztU3n0QXwLujYhLJI0CGjO2WRURf1rYIGkE8FXgvcAWYLWk9oh4LOd6zcwsldsRhKTxwDnANwAiYl9EPFfi7jOBDRGxMSL2AXcAF+VSqJmZZcrzFNPJQBfwTUkPSbpN0tiM7c6U9LCkeySdkrYdD2wu2GZL2nYISa2SOiV1dnV1DekPYGZWz/IMiJHADODWiDgDeAEo7kv4JXBiRLwduAX4YdqujPfLvNwqIpZFRHNENDc1NQ1J4WZmlm9AbAG2RMSD6fKdJIFxQETsjog96eu7gddJOjrdd3LBpicAW3Os1czMiuQWEBGxHdgsaXraNBs4qJNZ0hskKX09M61nF7AamCbppLRz+zKgPa9azczsUHlfxXQ1c
Hv6S34j8FFJCwAiYilwCbBQ0n7gReCySO7c2y/pKuA+YASwPCIezblWMzMr4DupzczqmO+kNjOzQXNAmJlZJgeEmZllckBUsZ7udnp2zKJn+/TkudsXcplZ+Xg+iCrV090Ou68HXkobtsLu6+kBGhrnVrI0M6sTPoKoVntu5kA4HPBS2m5mlj8HRLXq2Ta4djOzIeaAqFYNkwbXbmY2xBwQ1WrctcCYosYxabuZWf7cSV2lGhrn0gNJn0PPtuTIYdy17qA2s7Kp+4Do6W6v2l/CDY1zoUpqMbP6U9cB4UtJzcz6Vt99EL6U1MysT/UdEL6U1MysT/UdEL6U1MysT/UdEL6U1MysT3XdSe1LSc3M+lbXAQG+lNTMrC+5nmKSNEHSnZKekPS4pDOL1s+T9Ej6uF/S2wvWbZK0TtJaSZ5H1MyszPI+gvgScG9EXCJpFNBYtP63wLkR8aykC4BlwDsL1p8XETtzrtHMzDLkFhCSxgPnAB8BiIh9wL7CbSLi/oLFB4AT8qrHzMwGJ89TTCcDXcA3JT0k6TZJY/vZ/krgnoLlAFZKWiOpta+dJLVK6pTU2dXVNTSVm5lZrgExEpgB3BoRZwAvANdlbSjpPJKA+GRB81kRMQO4APi4pHOy9o2IZRHRHBHNTU1NQ/oDmJnVszwDYguwJSIeTJfvJAmMg0g6DbgNuCgidvW2R8TW9HkHsAKYmWOtZmZWJLeAiIjtwGZJ09Om2cBjhdtImgL8APhQRPyqoH2spCN6XwNzgPV51WpmZofK+yqmq4Hb0yuYNgIflbQAICKWAjcAE4GvSQLYHxHNwLHAirRtJNAWEffmXKuZmRVQRFS6hiHT3NwcnZ2+ZcLMrFSS1qR/mB+ivsdiMjMrk57udnp2zKJn+/Tkubu90iUNqO6H2jAzy1utTk7mIwgzs7zV6ORkDggzs7zV6ORkDggzs7zV6ORkDggzs7zV6ORk7qQ2M8tZrU5O5oAwMyuDgSYn62hbxfJFbXRt3kXT5InMX9zC7Jazy1jhoRwQZmYV1tG2iiWtS9nbncyIsOOpnSxpXQpQ0ZBwH4SZWYUtX9R2IBx67e3ex/JFbRWqKOGAMDOrsK7NuwbVXi79BoSkP5F0paSpRe3zc63KzKyONE2eOKj2cukzICQtBj4FvA3okHR1weqr8i7MzKxezF/cwujGUQe1jW4cxfzFLRWqKNFfJ/WFwBkRsV/SZ4A2SSdHxDWAylKdmVkd6O2IrrarmPoc7lvS4xHxloLlEcAyYDzw1og4pTwlls7DfZuZDc7hDvf9G0nn9i5ExCsRcSXwJPCWvnerb7U4pK+ZVb+OtlXMm7qQOSMuZd7UhXS0rcr9M/s7xfSBrMaIuF7SrTnVU9NqdUhfM6tulbpPos8jiIh4MSJe7GPd06W8uaQJku6U9ISkxyWdWbRekr4saYOkRyTNKFh3vqQn03XXlfoDVVSNDulrZtWtUvdJ5H0fxJeAeyPizcDbgceL1l8ATEsfrcCtcKC/46vp+rcCl0t6a861vnY1OqSvmVW3St0nkVtASBoPnAN8AyAi9kXEc0WbXQR8JxIPABMkTQJmAhsiYmNE7APuSLetbjU6pK+ZVbdK3ScxYEBIeqOk0enrWZL+UtKEEt77ZKAL+KakhyTdJmls0TbHA5sLlrekbX21V7caHdLXzKpbpe6TKOUI4i7gFUlvIjkaOAko5cTXSGAGcGtEnAG8ABT3JWTdTxH9tB9CUqukTkmdXV1dJZSVn4bGuTD+c9BwHKDkefzn3EFtZq/J7JazuWbZAo6ZcjSSOGbK0VyzbEHu90mUMpprT3qz3MXA/4qIWyQ9VMJ+W4AtEfFgunwnhwbEFmBywfIJwFZgVB/th4iIZST3Z9Dc3Jx9U0cZDTSkb7Xq6W6vubHqzerJ7Jazy37jXClHEC9Luhy4Avhx2va6gXaKiO3AZknT06bZwGNFm7UDH06vZnoX8HxEbANWA9MknSRpFHBZuq3l4MDlu
T1bgXj18lzfw2FW10oJiI8CZwKfj4jfSjoJ+McS3/9q4HZJjwCnA4slLZC0IF1/N7AR2AB8HfgvABGxn2S8p/tIrnz6XkQ8WuJn2mDV6OW5vinRLF99DrWRubH0emByRDySX0mHz0NtHJ6e7dPJ7uIRDW94stzllOSQmxIBGOM+H7NBOtyhNnp3/n+Sxks6CniY5Kqk6v7T0ganFi/PrdGjHrNaUsoppiMjYjfwfuCbEfHvgffkW5aVVS1enuubEs1yV0pAjExvXruUVzupbRipyctza/Gox6zGlHKZ640kncX/GhGrJZ0M/Drfsqzcau7y3HHXZvdBVPNRj1mNGTAgIuL7wPcLljcCf55nUWYDaWicSw/43g2zHA0YEJLGAFcCp1BwojoiPC+1VVTNHfWY1ZhS+iC+C7wB+BPgpyR3Nf8hz6LMzKzySgmIN0XEp4EXIuLbwH8E3pZvWWZmVmklDbWRPj8n6VTgSGBqbhWZmVlVKCUglqV3UH+aZDykx4Av5FqVmVkVqPfhXEq5ium29OVPSeZ4MDMb9jzHfD8BIanfC8ojwmMamNnw1d9wLvUeEMARZavCzKzaeDiXvgMiIj5bzkLMzKpKw6R0jpSM9jpRymiu3y6cg1rS6yUtz7UqM7NKq8VBLIdYKWMxnRYRz/UuRMSzks7IryQzs8rzcC6lBUSDpNdHxLMA6bwQpexnZlbT6n04l1J+0f9P4H5Jd5JMO3Yp8PlcqzIzs4or5T6I70jqBN4NCHh/RDxWyptL2kQybtMrwP7iae0k/RUwr6CWtwBNEfHMQPuamVm+SjpVlAZCSaGQ4byI2NnH+94E3AQg6ULgmoh4ppR9zcwsX6UMtVEulwP/VOkizMwskXdABLBS0hpJrX1tJKkROB+46zD2bZXUKamzq6tryAo3M6t3eV+NdFZEbJV0DPATSU9ExM8ytrsQ+HnR6aWS9o2IZcAygObm5sjjhzAzq0e5HkFExNb0eQewApjZx6aXUXR6aRD7mplZDnILCEljJR3R+xqYA6zP2O5I4FzgR4Pd18zM8pPnKaZjgRWSej+nLSLulbQAICKWpttdDKyMiBcG2jfHWs3MrIgihs9p++bm5ujs7Kx0GWY2zPV0t5dlCI6OtlUsX9RG1+ZdNE2eyPzFLcxuOXtIP0PSmr7uM/OQGWZmg1CuiYQ62laxpHUpe7v3AbDjqZ0saU1OvAx1SPSlmu6DMDOrfv1NJDSEli9qOxAOvfZ272P5orYh/Zz+OCDMzAajTBMJdW3eNaj2PDggzMwGo68Jg4Z4IqGmyRMH1Z4HB4SZ2WCUaSKh+YtbGN046qC20Y2jmL+4ZUg/pz/upDYzG4RyTSTU2xGd91VM/fFlrmZmday/y1x9isnMrMp1tK1i3tSFzBlxKfOmLqSjbVVZPtenmMzMqlgl74fwEYSZWRWr5P0QDggzsypWyfshHBBmZlWskvdDOCDMzKpYJe+HcCe1mVkVq+T9EL4Pwsysjvk+CDMzGzQHhJmZZco1ICRtkrRO0lpJh5z7kTRL0vPp+rWSbihYd76kJyVtkHRdnnWamdmhytFJfV5E7Oxn/aqI+NPCBkkjgK8C7wW2AKsltUfEYznWaWZmBar1FNNMYENEbIyIfcAdwEUVrsnMrK7kHRABrJS0RlJrH9ucKelhSfdIOiVtOx7YXLDNlrTNzMzKJO9TTGdFxFZJxwA/kfRERPysYP0vgRMjYo+k9wE/BKYBynivzOtx0+BpBZgyZcqQFm9mVs9yPYKIiK3p8w5gBcmpo8L1uyNiT/r6buB1ko4mOWKYXLDpCcDWPj5jWUQ0R0RzU1NTDj+FmVl9yi0gJI2VdETva2AOsL5omzdIUvp6ZlrPLmA1ME3SSZJGAZcB7XnVamZmh8rzFNOxwIr09/9IoC0i7pW0ACAilgKXAAsl7QdeBC6L5Nbu/ZKuAu4DRgDLI+LRHGs1M7MiHmrDzKyOeagNMzMbNAeEmZllckCYmVkmB
4SZWY3qaFvFvKkLmTPiUuZNXUhH26ohfX9PGGRmVoM62laxpHUpe7v3AbDjqZ0saV0KMGSTCfkIwsysBi1f1HYgHHrt7d7H8kVtQ/YZDggzsxrUtXnXoNoPhwPCzKwGNU2eOKj2w+GAMDOrQfMXtzC6cdRBbaMbRzF/ccuQfYY7qc3MalBvR/TyRW10bd5F0+SJzF/cMmQd1OChNszM6pqH2jAzs0FzQJiZWSYHhJmZZXJAmJlZJgeEmZllckCYmVkmB4SZmWXK9UY5SZuAPwCvAPuLr7WVNA/4ZLq4B1gYEQ+Xsq+ZmeWrHHdSnxcRO/tY91vg3Ih4VtIFwDLgnSXua2ZmOaroKaaIuD8ink0XHwBOqGQ9ZoV6utvp2TGLnu3Tk+fu9kqXZFZWeR9BBLBSUgD/EBHL+tn2SuCewe4rqRVoBZgyZcrQVG11qae7HfbcDD3bgCOBF4CX05VbYff19AANjXMrV6RZGeUdEGdFxFZJxwA/kfRERPyseCNJ55EExB8Pdt80OJZBMhZTPj+GDXc93e2w+3rgpbTluYytXkoCxAFhdSLXU0wRsTV93gGsAGYWbyPpNOA24KKI2DWYfc2GzJ6beTUc+tGzLfdSzKpFbgEhaaykI3pfA3OA9UXbTAF+AHwoIn41mH3NhlSpv/gbJuVbh1kVyfMU07HACkm9n9MWEfdKWgAQEUuBG4CJwNfS7XovZ83cN8dard41TEr6Gfo1BsZdW5ZyzKqB54MwI6sPAmAkaBzE80mAjLvWHdQ27PQ3H4RnlDMjuTKpB169ismBYOaAMOvV0DjXVyiZFfBYTGZmlskBYWZmmRwQZmaWyQFhZmaZHBBmZpbJAWFmZpkcEGZmlskBYWZmmRwQZmaWyQFhZmaZHBBmZpbJAWFmZpkcEGZmlskBYWZmmRwQVtN6utvp2TGLnu3Tk+fu9kqXZDZs5BoQkjZJWidpraRDpnpT4suSNkh6RNKMgnXnS3oyXXddnnVabTowC1zPViCS593XOyTMhkg5jiDOi4jT+5jS7gJgWvpoBW4FkDQC+Gq6/q3A5ZLeWoZarZbsuZmDpwglWd5zcyWqMRt2Kn2K6SLgO5F4AJggaRIwE9gQERsjYh9wR7qt2at6tg2u3cwGJe+ACGClpDWSWjPWHw9sLljekrb11W72qoZJg2s3s0HJOyDOiogZJKeKPi7pnKL1ytgn+mk/hKRWSZ2SOru6ul5btVZbxl0LjClqHJO2m9lrlWtARMTW9HkHsILk1FGhLcDkguUTgK39tGd9xrKIaI6I5qampqEq3WpAQ+NcGP85aDgOUPI8/nNJu5m9ZiPzemNJY4GGiPhD+noOcGPRZu3AVZLuAN4JPB8R2yR1AdMknQQ8DVwGtORVq9Wuhsa54EAwy0VuAQEcC6yQ1Ps5bRFxr6QFABGxFLgbeB+wAegGPpqu2y/pKuA+YASwPCIezbFWMzMroojMU/s1qbm5OTo7D7ndwszM+iBpTR+3IVT8MlczM6tSDggzM8vkgDAzs0zDqg8ivfrpd5Wuo4yOBnZWuogq5u+nf/5++lcv38+JEZF5j8CwCoh6I6mzr84l8/czEH8//fP341NMZmbWBweEmZllckDUtmWVLqDK+fvpn7+f/tX99+M+CDMzy+QjCDMzy+SAqHGSPiPp6XRa17WS3lfpmqqRpE9ICklHV7qWaiLp79LpftdKWinpuErXVE0k3STpifQ7WiFpQqVrKicHxPCwJJ3W9fSIuLvSxVQbSZOB9wJPVbqWKnRTRJwWEacDPwZuqHA91eYnwKkRcRrwK+BvKlxPWTkgrB4sAf6aPiadqmcRsbtgcSz+jg4SESsjYn+6+ADJ3DR1wwExPFyVHgIvl/T6ShdTTSTNBZ6OiIcrXUu1kvR5SZuBefgIoj/zgXsqXUQ5+SqmGiDpX4A3ZKz6FMlfNTtJ/vL7O2BSRMwvY3kVN8D3swiYExHPS9oENEdEPQyfcEB/309E/Khgu78Bx
kTE35atuCpQyvcj6VNAM/D+qKNfmg6IYUTSVODHEXFqpWupBpLeBnSQTEYFr05dOzMitlessCol6UTgn/3v52CSrgAWALMjonug7YeTPGeUszKQNCkitqWLFwPrK1lPNYmIdcAxvcv1egTRH0nTIuLX6eJc4IlK1lNtJJ0PfBI4t97CAXwEUfMkfRc4neQU0ybgPxcEhhVwQBxK0l3AdKCHZCTkBRHxdGWrqh6SNgCjgV1p0wMRsaCCJZWVA8LMzDL5KiYzM8vkgDAzs0wOCDMzy+SAMDOzTA4IMzPL5IAwO0zpSLqfyOm9Py9ps6Q9eby/WSkcEGbV6f8AMytdhNU3B4RZCSR9OB0Q8eH05sTi9R+TtDpdf5ekxrT9A5LWp+0/S9tOkfRv6RwMj0iaVvx+EfGAb3i0SvONcmYDkHQK8APgrIjYKemoiHhG0meAPRHxRUkTI2JXuv3ngN9HxC2S1gHnR8TTkiZExHOSbiG5I/d2SaOAERHxYh+fvScixpXnJzU7mI8gzAb2buDO3iE6IuKZjG1OlbQqDYR5wClp+8+Bb0n6GDAibfsFsEjSJ4ET+woHs0pzQJgNTAw8kc63gKsi4m3AZ4ExAOm4PdcDk4G16ZFGG8nAeC8C90l6d16Fm70WDgizgXUAl0qaCCDpqIxtjgC2SXodyREE6bZvjIgHI+IGknk7Jks6GdgYEV8G2oHTcv8JzA6DA8JsABHxKPB54KeSHgZuztjs08CDJHMYFw6ZfZOkdZLWAz8DHgY+CKyXtBZ4M/Cd4jeT9AVJW4BGSVvS/g6zsnIntZmZZfIRhJmZZXJAmJlZJgeEmZllckCYmVkmB4SZmWVyQJiZWSYHhJmZZXJAmJlZpv8Pv6r47tbqzvQAAAAASUVORK5CYII=\n", 255 | "text/plain": [ 256 | "
" 257 | ] 258 | }, 259 | "metadata": { 260 | "needs_background": "light" 261 | }, 262 | "output_type": "display_data" 263 | } 264 | ], 265 | "source": [ 266 | "Plot().plot_in_2d(X_test, y_pred, title=\"LDA\", accuracy=accuracy)" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "metadata": {}, 272 | "source": [ 273 | "### 二.使用sklearn处理" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 6, 279 | "metadata": {}, 280 | "outputs": [ 281 | { 282 | "name": "stdout", 283 | "output_type": "stream", 284 | "text": [ 285 | "1.0\n" 286 | ] 287 | } 288 | ], 289 | "source": [ 290 | "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n", 291 | "clf = LinearDiscriminantAnalysis()\n", 292 | "clf.fit(X_train, y_train)\n", 293 | "y_pred = clf.predict(X_test)\n", 294 | "accuracy = accuracy_score(y_test, y_pred)\n", 295 | "print(accuracy)" 296 | ] 297 | } 298 | ], 299 | "metadata": { 300 | "kernelspec": { 301 | "display_name": "Python 3 (ipykernel)", 302 | "language": "python", 303 | "name": "python3" 304 | }, 305 | "language_info": { 306 | "codemirror_mode": { 307 | "name": "ipython", 308 | "version": 3 309 | }, 310 | "file_extension": ".py", 311 | "mimetype": "text/x-python", 312 | "name": "python", 313 | "nbconvert_exporter": "python", 314 | "pygments_lexer": "ipython3", 315 | "version": "3.9.7" 316 | }, 317 | "toc": { 318 | "base_numbering": 1, 319 | "nav_menu": {}, 320 | "number_sections": true, 321 | "sideBar": true, 322 | "skip_h1_title": false, 323 | "title_cell": "Table of Contents", 324 | "title_sidebar": "Contents", 325 | "toc_cell": false, 326 | "toc_position": {}, 327 | "toc_section_display": true, 328 | "toc_window_display": false 329 | } 330 | }, 331 | "nbformat": 4, 332 | "nbformat_minor": 2 333 | } 334 | -------------------------------------------------------------------------------- /charpter07_decision_tree/CART.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | 
"cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## 决策树:CART" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "决策树基于特征对数据实例按照条件不断进行划分,最终达到分类或者回归的目的。 \n", 15 | "决策树模型的核心概念包括特征选择方法、决策树构造过程和决策树剪枝。 \n", 16 | "常见的特征选择方法包括信息增益、信息增益比和基尼系数,对应的三种常见的决策树算法为ID3、C4.5和CART。 " 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "### 一.手写算法" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 1, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "import numpy as np\n", 33 | "from sklearn.model_selection import train_test_split\n", 34 | "from sklearn.metrics import accuracy_score, mean_squared_error\n", 35 | "from utils import feature_split, calculate_gini" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 2, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "### 定义树结点\n", 45 | "class TreeNode():\n", 46 | " def __init__(self, feature_i=None, threshold=None, leaf_value=None, left_branch=None, right_branch=None):\n", 47 | " self.feature_i = feature_i # 特征索引 \n", 48 | " self.threshold = threshold # 特征划分阈值\n", 49 | " self.leaf_value = leaf_value # 叶子节点取值\n", 50 | " self.left_branch = left_branch # 左子树\n", 51 | " self.right_branch = right_branch # 右子树" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 3, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "### 定义二叉决策树\n", 61 | "class BinaryDecisionTree(object):\n", 62 | " ### 决策树初始参数\n", 63 | " def __init__(self, min_samples_split=2, min_gini_impurity=999, max_depth=float(\"inf\"), loss=None):\n", 64 | " self.root = None # 根结点\n", 65 | " self.min_samples_split = min_samples_split # 节点最小分裂样本数\n", 66 | " self.mini_gini_impurity = min_gini_impurity # 节点初始化基尼不纯度\n", 67 | " self.max_depth = max_depth # 树最大深度\n", 68 | " self.gini_impurity_calculation = None # 基尼不纯度计算函数\n", 69 | " self._leaf_value_calculation = None # 
叶子节点值预测函数\n", 70 | " self.loss = loss # 损失函数\n", 71 | "\n", 72 | " ### 决策树拟合函数\n", 73 | " def fit(self, X, y, loss=None):\n", 74 | " self.root = self._build_tree(X, y) # 递归构建决策树\n", 75 | " self.loss=None\n", 76 | "\n", 77 | " ### 决策树构建函数\n", 78 | " def _build_tree(self, X, y, current_depth=0):\n", 79 | " init_gini_impurity = 999 # 初始化最小基尼不纯度\n", 80 | " best_criteria = None # 初始化最佳特征索引和阈值\n", 81 | " best_sets = None # 初始化数据子集\n", 82 | " Xy = np.concatenate((X, y), axis=1) # 合并输入和标签\n", 83 | " n_samples, n_features = X.shape # 获取样本数和特征数\n", 84 | " \n", 85 | " # 设定决策树构建条件\n", 86 | " # 训练样本数量大于节点最小分裂样本数且当前树深度小于最大深度\n", 87 | " if n_samples >= self.min_samples_split and current_depth <= self.max_depth:\n", 88 | " for feature_i in range(n_features): # 遍历计算每个特征的基尼不纯度\n", 89 | " feature_values = np.expand_dims(X[:, feature_i], axis=1) # 获取第i特征的所有取值\n", 90 | " unique_values = np.unique(feature_values) # 获取第i个特征的唯一取值\n", 91 | "\n", 92 | " # 遍历取值并寻找最佳特征分裂阈值\n", 93 | " for threshold in unique_values:\n", 94 | " Xy1, Xy2 = feature_split(Xy, feature_i, threshold) # 特征节点二叉分裂\n", 95 | " if len(Xy1) > 0 and len(Xy2) > 0: # 如果分裂后的子集大小都不为0\n", 96 | " y1 = Xy1[:, n_features:] # 获取两个子集的标签值\n", 97 | " y2 = Xy2[:, n_features:]\n", 98 | " impurity = self.impurity_calculation(y, y1, y2) # 计算基尼不纯度\n", 99 | "\n", 100 | " # 获取最小基尼不纯度\n", 101 | " # 最佳特征索引和分裂阈值\n", 102 | " if impurity < init_gini_impurity:\n", 103 | " init_gini_impurity = impurity\n", 104 | " best_criteria = {\"feature_i\": feature_i, \"threshold\": threshold}\n", 105 | " best_sets = {\n", 106 | " \"leftX\": Xy1[:, :n_features], \n", 107 | " \"lefty\": Xy1[:, n_features:], \n", 108 | " \"rightX\": Xy2[:, :n_features], \n", 109 | " \"righty\": Xy2[:, n_features:] \n", 110 | " }\n", 111 | " \n", 112 | " # 如果计算的最小不纯度小于设定的最小不纯度\n", 113 | " if init_gini_impurity < self.mini_gini_impurity:\n", 114 | " # 分别构建左右子树\n", 115 | " left_branch = self._build_tree(best_sets[\"leftX\"], best_sets[\"lefty\"], current_depth + 1)\n", 116 | " 
right_branch = self._build_tree(best_sets[\"rightX\"], best_sets[\"righty\"], current_depth + 1)\n", 117 | " return TreeNode(feature_i=best_criteria[\"feature_i\"], threshold=best_criteria[\n", 118 | " \"threshold\"], left_branch=left_branch, right_branch=right_branch)\n", 119 | "\n", 120 | " # 计算叶子计算取值\n", 121 | " leaf_value = self._leaf_value_calculation(y)\n", 122 | "\n", 123 | " return TreeNode(leaf_value=leaf_value)\n", 124 | "\n", 125 | " ### 定义二叉树值预测函数\n", 126 | " def predict_value(self, x, tree=None):\n", 127 | " if tree is None:\n", 128 | " tree = self.root\n", 129 | "\n", 130 | " # 如果叶子节点已有值,则直接返回已有值\n", 131 | " if tree.leaf_value is not None:\n", 132 | " return tree.leaf_value\n", 133 | "\n", 134 | " # 选择特征并获取特征值\n", 135 | " feature_value = x[tree.feature_i]\n", 136 | "\n", 137 | " # 判断落入左子树还是右子树\n", 138 | " branch = tree.right_branch\n", 139 | " if isinstance(feature_value, int) or isinstance(feature_value, float):\n", 140 | " if feature_value >= tree.threshold:\n", 141 | " branch = tree.left_branch\n", 142 | " elif feature_value == tree.threshold:\n", 143 | " branch = tree.left_branch\n", 144 | "\n", 145 | " # 测试子集\n", 146 | " return self.predict_value(x, branch)\n", 147 | "\n", 148 | " ### 数据集预测函数\n", 149 | " def predict(self, X):\n", 150 | " y_pred = [self.predict_value(sample) for sample in X]\n", 151 | " return y_pred" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 4, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "### CART回归树\n", 161 | "class RegressionTree(BinaryDecisionTree):\n", 162 | " def _calculate_variance_reduction(self, y, y1, y2):\n", 163 | " var_tot = np.var(y, axis=0)\n", 164 | " var_y1 = np.var(y1, axis=0)\n", 165 | " var_y2 = np.var(y2, axis=0)\n", 166 | " frac_1 = len(y1) / len(y)\n", 167 | " frac_2 = len(y2) / len(y)\n", 168 | " # 计算方差减少量\n", 169 | " variance_reduction = var_tot - (frac_1 * var_y1 + frac_2 * var_y2)\n", 170 | " \n", 171 | " return sum(variance_reduction)\n", 172 | 
"\n", 173 | " # 节点值取平均\n", 174 | " def _mean_of_y(self, y):\n", 175 | " value = np.mean(y, axis=0)\n", 176 | " return value if len(value) > 1 else value[0]\n", 177 | "\n", 178 | " def fit(self, X, y):\n", 179 | " self.impurity_calculation = self._calculate_variance_reduction\n", 180 | " self._leaf_value_calculation = self._mean_of_y\n", 181 | " super(RegressionTree, self).fit(X, y)" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 5, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "### CART决策树\n", 191 | "class ClassificationTree(BinaryDecisionTree):\n", 192 | " ### 定义基尼不纯度计算过程\n", 193 | " def _calculate_gini_impurity(self, y, y1, y2):\n", 194 | " p = len(y1) / len(y)\n", 195 | " gini = calculate_gini(y)\n", 196 | " gini_impurity = p * calculate_gini(y1) + (1-p) * calculate_gini(y2)\n", 197 | " return gini_impurity\n", 198 | " \n", 199 | " ### 多数投票\n", 200 | " def _majority_vote(self, y):\n", 201 | " most_common = None\n", 202 | " max_count = 0\n", 203 | " for label in np.unique(y):\n", 204 | " # 统计多数\n", 205 | " count = len(y[y == label])\n", 206 | " if count > max_count:\n", 207 | " most_common = label\n", 208 | " max_count = count\n", 209 | " return most_common\n", 210 | " \n", 211 | " # 分类树拟合\n", 212 | " def fit(self, X, y):\n", 213 | " self.impurity_calculation = self._calculate_gini_impurity\n", 214 | " self._leaf_value_calculation = self._majority_vote\n", 215 | " super(ClassificationTree, self).fit(X, y)" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 6, 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "name": "stderr", 225 | "output_type": "stream", 226 | "text": [ 227 | "E:\\study\\coding\\ML公式推导和代码实现\\charpter07_decision_tree\\utils.py:14: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n", 228 | " return np.array([X_left, X_right])\n" 229 | ] 230 | }, 231 | { 232 | "name": "stdout", 233 | "output_type": "stream", 234 | "text": [ 235 | "0.9555555555555556\n" 236 | ] 237 | } 238 | ], 239 | "source": [ 240 | "from sklearn import datasets\n", 241 | "data = datasets.load_iris()\n", 242 | "X, y = data.data, data.target\n", 243 | "X_train, X_test, y_train, y_test = train_test_split(X, y.reshape(-1,1), test_size=0.3)\n", 244 | "clf = ClassificationTree()\n", 245 | "clf.fit(X_train, y_train)\n", 246 | "y_pred = clf.predict(X_test)\n", 247 | "\n", 248 | "print(accuracy_score(y_test, y_pred))" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 7, 254 | "metadata": {}, 255 | "outputs": [ 256 | { 257 | "name": "stdout", 258 | "output_type": "stream", 259 | "text": [ 260 | "0.9555555555555556\n" 261 | ] 262 | } 263 | ], 264 | "source": [ 265 | "from sklearn.tree import DecisionTreeClassifier\n", 266 | "clf = DecisionTreeClassifier()\n", 267 | "clf.fit(X_train, y_train)\n", 268 | "y_pred = clf.predict(X_test)\n", 269 | "\n", 270 | "print(accuracy_score(y_test, y_pred))" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 8, 276 | "metadata": {}, 277 | "outputs": [ 278 | { 279 | "name": "stderr", 280 | "output_type": "stream", 281 | "text": [ 282 | "E:\\study\\coding\\ML公式推导和代码实现\\charpter07_decision_tree\\utils.py:14: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n", 283 | " return np.array([X_left, X_right])\n" 284 | ] 285 | }, 286 | { 287 | "name": "stdout", 288 | "output_type": "stream", 289 | "text": [ 290 | "Mean Squared Error: 148.45756578947368\n" 291 | ] 292 | } 293 | ], 294 | "source": [ 295 | "from sklearn.datasets import load_boston\n", 296 | "X, y = load_boston(return_X_y=True)\n", 297 | "X_train, X_test, y_train, y_test = train_test_split(X, y.reshape(-1,1), test_size=0.3)\n", 298 | "model = RegressionTree()\n", 299 | "model.fit(X_train, y_train)\n", 300 | "y_pred = model.predict(X_test)\n", 301 | "mse = mean_squared_error(y_test, y_pred)\n", 302 | "\n", 303 | "print(\"Mean Squared Error:\", mse)" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": 9, 309 | "metadata": {}, 310 | "outputs": [ 311 | { 312 | "name": "stdout", 313 | "output_type": "stream", 314 | "text": [ 315 | "Mean Squared Error: 15.152039473684212\n" 316 | ] 317 | } 318 | ], 319 | "source": [ 320 | "from sklearn.tree import DecisionTreeRegressor\n", 321 | "reg = DecisionTreeRegressor()\n", 322 | "reg.fit(X_train, y_train)\n", 323 | "y_pred = reg.predict(X_test)\n", 324 | "mse = mean_squared_error(y_test, y_pred)\n", 325 | "\n", 326 | "print(\"Mean Squared Error:\", mse)" 327 | ] 328 | } 329 | ], 330 | "metadata": { 331 | "kernelspec": { 332 | "display_name": "Python 3 (ipykernel)", 333 | "language": "python", 334 | "name": "python3" 335 | }, 336 | "language_info": { 337 | "codemirror_mode": { 338 | "name": "ipython", 339 | "version": 3 340 | }, 341 | "file_extension": ".py", 342 | "mimetype": "text/x-python", 343 | "name": "python", 344 | "nbconvert_exporter": "python", 345 | "pygments_lexer": "ipython3", 346 | "version": "3.9.7" 347 | }, 348 | "toc": { 349 | "base_numbering": 1, 350 | "nav_menu": {}, 351 | "number_sections": true, 352 | "sideBar": true, 353 | "skip_h1_title": false, 354 | "title_cell": "Table of Contents", 355 | 
"title_sidebar": "Contents", 356 | "toc_cell": false, 357 | "toc_position": {}, 358 | "toc_section_display": true, 359 | "toc_window_display": false 360 | } 361 | }, 362 | "nbformat": 4, 363 | "nbformat_minor": 2 364 | } 365 | -------------------------------------------------------------------------------- /charpter07_decision_tree/example_data.csv: -------------------------------------------------------------------------------- 1 | humility,outlook,play,temp,windy 2 | high,sunny,no,hot,false 3 | high,sunny,no,hot,true 4 | high,overcast,yes,hot,false 5 | high,rainy,yes,mild,false 6 | normal,rainy,yes,cool,false 7 | normal,rainy,no,cool,true 8 | normal,overcast,yes,cool,true 9 | high,sunny,no,mild,false 10 | normal,sunny,yes,cool,false 11 | normal,rainy,yes,mild,false 12 | normal,sunny,yes,mild,true 13 | high,overcast,yes,mild,true 14 | normal,overcast,yes,hot,false 15 | high,rainy,no,mild,true 16 | -------------------------------------------------------------------------------- /charpter08_neural_networks/perceptron.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # 定义单层感知机类 4 | class Perceptron: 5 | def __init__(self): 6 | pass 7 | 8 | def sign(self, x, w, b): 9 | return np.dot(x, w) + b 10 | 11 | def train(self, X_train, y_train, learning_rate): 12 | # 参数初始化 13 | w, b = self.initilize_with_zeros(X_train.shape[1]) 14 | # 初始化误分类 15 | is_wrong = False 16 | while not is_wrong: 17 | wrong_count = 0 18 | for i in range(len(X_train)): 19 | X = X_train[i] 20 | y = y_train[i] 21 | 22 | # 如果存在误分类点 23 | # 更新参数 24 | # 直到没有误分类点 25 | if y * self.sign(X, w, b) <= 0: 26 | w = w + learning_rate*np.dot(y, X) 27 | b = b + learning_rate*y 28 | wrong_count += 1 29 | if wrong_count == 0: 30 | is_wrong = True 31 | print('There is no missclassification!') 32 | 33 | # 保存更新后的参数 34 | params = { 35 | 'w': w, 36 | 'b': b 37 | } 38 | return params 39 | 
import numpy as np
from utils import feature_split, calculate_gini


### Tree node
class TreeNode():
    """One node of a binary decision tree.

    Internal nodes store a split (feature_i, threshold) and two
    children; leaf nodes store only leaf_value.
    """
    def __init__(self, feature_i=None, threshold=None,
                 leaf_value=None, left_branch=None, right_branch=None):
        self.feature_i = feature_i        # index of the splitting feature
        self.threshold = threshold        # value the feature is compared against
        self.leaf_value = leaf_value      # prediction stored at a leaf
        self.left_branch = left_branch    # subtree of samples matching the split
        self.right_branch = right_branch  # subtree of the remaining samples


### Binary decision tree
class BinaryDecisionTree(object):
    """CART-style binary tree.

    Subclasses install `impurity_calculation` (split criterion) and
    `_leaf_value_calculation` (leaf prediction) in their fit() methods,
    specializing this base class for classification or regression.
    """

    def __init__(self, min_samples_split=2, min_gini_impurity=999,
                 max_depth=float("inf"), loss=None):
        self.root = None                            # root TreeNode after fitting
        self.min_samples_split = min_samples_split  # minimum samples a node needs to split
        # Impurity ceiling under which a found split is accepted.
        # (Renamed from the misspelled `mini_gini_impurity` to match
        # the identical implementation in charpter12_XGBoost/cart.py.)
        self.min_gini_impurity = min_gini_impurity
        self.max_depth = max_depth                  # maximum depth of the tree
        self.impurity_calculation = None            # split criterion, set by subclass fit()
        self._leaf_value_calculation = None         # leaf predictor, set by subclass fit()
        self.loss = loss                            # optional loss object (unused by the tree itself)

    ### Fit
    def fit(self, X, y, loss=None):
        """Grow the tree on X, y by recursive splitting."""
        self.root = self._build_tree(X, y)
        self.loss = None  # preserved from the original: loss is discarded after fitting

    ### Recursive construction
    def _build_tree(self, X, y, current_depth=0):
        """Find the best split of (X, y); recurse into both halves or emit a leaf."""
        init_gini_impurity = 999  # lowest impurity found so far
        best_criteria = None      # feature index / threshold of the best split
        best_sets = None          # the four data subsets of the best split

        # Accept 1-D targets by promoting them to a column vector.
        if len(np.shape(y)) == 1:
            y = np.expand_dims(y, axis=1)

        # Stack features and labels so both are split together.
        Xy = np.concatenate((X, y), axis=1)
        n_samples, n_features = X.shape

        # Only search for a split when the node is large enough
        # and the depth limit has not been reached.
        if n_samples >= self.min_samples_split and current_depth <= self.max_depth:
            for feature_i in range(n_features):
                feature_values = np.expand_dims(X[:, feature_i], axis=1)
                unique_values = np.unique(feature_values)

                # Evaluate every distinct value as a candidate threshold.
                for threshold in unique_values:
                    Xy1, Xy2 = feature_split(Xy, feature_i, threshold)
                    # Only proper (non-empty) splits are considered.
                    if len(Xy1) > 0 and len(Xy2) > 0:
                        y1 = Xy1[:, n_features:]
                        y2 = Xy2[:, n_features:]
                        impurity = self.impurity_calculation(y, y1, y2)

                        # Keep the lowest-impurity split seen so far.
                        if impurity < init_gini_impurity:
                            init_gini_impurity = impurity
                            best_criteria = {"feature_i": feature_i, "threshold": threshold}
                            best_sets = {
                                "leftX": Xy1[:, :n_features],
                                "lefty": Xy1[:, n_features:],
                                "rightX": Xy2[:, :n_features],
                                "righty": Xy2[:, n_features:]
                            }

        # Recurse only when the best split beats the configured ceiling
        # (init stays 999 when no split was found, so this is then False).
        if init_gini_impurity < self.min_gini_impurity:
            left_branch = self._build_tree(best_sets["leftX"], best_sets["lefty"], current_depth + 1)
            right_branch = self._build_tree(best_sets["rightX"], best_sets["righty"], current_depth + 1)
            return TreeNode(feature_i=best_criteria["feature_i"],
                            threshold=best_criteria["threshold"],
                            left_branch=left_branch, right_branch=right_branch)

        # Otherwise this node becomes a leaf.
        leaf_value = self._leaf_value_calculation(y)
        return TreeNode(leaf_value=leaf_value)

    ### Single-sample prediction
    def predict_value(self, x, tree=None):
        """Route sample x from `tree` (root by default) down to a leaf value."""
        if tree is None:
            tree = self.root
        # A populated leaf_value marks a leaf: return it directly.
        if tree.leaf_value is not None:
            return tree.leaf_value
        feature_value = x[tree.feature_i]
        # Mirror feature_split: numeric features go left when >= threshold,
        # categorical features go left when equal to the threshold.
        branch = tree.right_branch
        if isinstance(feature_value, (int, float)):
            if feature_value >= tree.threshold:
                branch = tree.left_branch
        elif feature_value == tree.threshold:
            # BUGFIX: this used to (redundantly) pick the right branch,
            # contradicting feature_split, which places equal categorical
            # values into the LEFT subset at build time.
            branch = tree.left_branch
        return self.predict_value(x, branch)

    ### Batch prediction
    def predict(self, X):
        """Predict a value for every row of X."""
        y_pred = [self.predict_value(sample) for sample in X]
        return y_pred


# CART classification tree
class ClassificationTree(BinaryDecisionTree):
    ### Gini impurity of a candidate split
    def _calculate_gini_impurity(self, y, y1, y2):
        """Sample-weighted Gini impurity of the two child label sets."""
        p = len(y1) / len(y)
        gini_impurity = p * calculate_gini(y1) + (1 - p) * calculate_gini(y2)
        return gini_impurity

    ### Majority vote
    def _majority_vote(self, y):
        """Most frequent label in y (ties keep the first label np.unique yields)."""
        most_common = None
        max_count = 0
        for label in np.unique(y):
            count = len(y[y == label])
            if count > max_count:
                most_common = label
                max_count = count
        return most_common

    # Classification fit: Gini criterion + majority-vote leaves.
    def fit(self, X, y):
        self.impurity_calculation = self._calculate_gini_impurity
        self._leaf_value_calculation = self._majority_vote
        super(ClassificationTree, self).fit(X, y)


### CART regression tree
class RegressionTree(BinaryDecisionTree):
    ### Variance-based criterion
    def _calculate_variance_reduction(self, y, y1, y2):
        """Variance reduction of the split, summed over target columns."""
        var_tot = np.var(y, axis=0)
        var_y1 = np.var(y1, axis=0)
        var_y2 = np.var(y2, axis=0)
        frac_1 = len(y1) / len(y)
        frac_2 = len(y2) / len(y)
        variance_reduction = var_tot - (frac_1 * var_y1 + frac_2 * var_y2)
        return sum(variance_reduction)

    ### Leaf value: mean of targets
    def _mean_of_y(self, y):
        """Column-wise mean, unwrapped to a scalar for single-output targets."""
        value = np.mean(y, axis=0)
        return value if len(value) > 1 else value[0]

    # Regression fit: variance criterion + mean leaves.
    def fit(self, X, y):
        self.impurity_calculation = self._calculate_variance_reduction
        self._leaf_value_calculation = self._mean_of_y
        super(RegressionTree, self).fit(X, y)
49 | " # 这里省略,如需使用,需自定义分类损失函数\n", 50 | " if not self.regression:\n", 51 | " self.loss = None\n", 52 | " # 多棵树叠加\n", 53 | " self.estimators = []\n", 54 | " for i in range(self.n_estimators):\n", 55 | " self.estimators.append(RegressionTree(min_samples_split=self.min_samples_split,\n", 56 | " min_gini_impurity=self.min_gini_impurity,\n", 57 | " max_depth=self.max_depth))\n", 58 | " # 拟合方法\n", 59 | " def fit(self, X, y):\n", 60 | " # 前向分步模型初始化,第一棵树\n", 61 | " self.estimators[0].fit(X, y)\n", 62 | " # 第一棵树的预测结果\n", 63 | " y_pred = self.estimators[0].predict(X)\n", 64 | " # 前向分步迭代训练\n", 65 | " for i in range(1, self.n_estimators):\n", 66 | " gradient = self.loss.gradient(y, y_pred)\n", 67 | " self.estimators[i].fit(X, gradient)\n", 68 | " y_pred -= np.multiply(self.learning_rate, self.estimators[i].predict(X))\n", 69 | " \n", 70 | " # 预测方法\n", 71 | " def predict(self, X):\n", 72 | " # 回归树预测\n", 73 | " y_pred = self.estimators[0].predict(X)\n", 74 | " for i in range(1, self.n_estimators):\n", 75 | " y_pred -= np.multiply(self.learning_rate, self.estimators[i].predict(X))\n", 76 | " # 分类树预测\n", 77 | " if not self.regression:\n", 78 | " # 将预测值转化为概率\n", 79 | " y_pred = np.exp(y_pred) / np.expand_dims(np.sum(np.exp(y_pred), axis=1), axis=1)\n", 80 | " # 转化为预测标签\n", 81 | " y_pred = np.argmax(y_pred, axis=1)\n", 82 | " return y_pred" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 3, 88 | "metadata": { 89 | "code_folding": [] 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "### GBDT分类树\n", 94 | "class GBDTClassifier(GBDT):\n", 95 | " def __init__(self, n_estimators=200, learning_rate=.5, min_samples_split=2,\n", 96 | " min_info_gain=1e-6, max_depth=2):\n", 97 | " super(GBDTClassifier, self).__init__(n_estimators=n_estimators,\n", 98 | " learning_rate=learning_rate,\n", 99 | " min_samples_split=min_samples_split,\n", 100 | " min_gini_impurity=min_info_gain,\n", 101 | " max_depth=max_depth,\n", 102 | " regression=False)\n", 103 | " # 拟合方法\n", 104 | " 
def fit(self, X, y):\n", 105 | " super(GBDTClassifier, self).fit(X, y)\n", 106 | " \n", 107 | "### GBDT回归树\n", 108 | "class GBDTRegressor(GBDT):\n", 109 | " def __init__(self, n_estimators=300, learning_rate=0.1, min_samples_split=2,\n", 110 | " min_var_reduction=1e-6, max_depth=3):\n", 111 | " super(GBDTRegressor, self).__init__(n_estimators=n_estimators,\n", 112 | " learning_rate=learning_rate,\n", 113 | " min_samples_split=min_samples_split,\n", 114 | " min_gini_impurity=min_var_reduction,\n", 115 | " max_depth=max_depth,\n", 116 | " regression=True)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 4, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "### 定义回归树的平方损失\n", 126 | "class SquareLoss():\n", 127 | " # 定义平方损失\n", 128 | " def loss(self, y, y_pred):\n", 129 | " return 0.5 * np.power((y - y_pred), 2)\n", 130 | " # 定义平方损失的梯度\n", 131 | " def gradient(self, y, y_pred):\n", 132 | " return -(y - y_pred)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 5, 138 | "metadata": {}, 139 | "outputs": [ 140 | { 141 | "name": "stderr", 142 | "output_type": "stream", 143 | "text": [ 144 | "E:\\study\\coding\\ML公式推导和代码实现\\charpter11_GBDT\\utils.py:14: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n", 145 | " return np.array([X_left, X_right])\n" 146 | ] 147 | }, 148 | { 149 | "name": "stdout", 150 | "output_type": "stream", 151 | "text": [ 152 | "Mean Squared Error of NumPy GBRT: 84.29078032628252\n" 153 | ] 154 | } 155 | ], 156 | "source": [ 157 | "### GBDT分类树\n", 158 | "# 导入sklearn数据集模块\n", 159 | "from sklearn import datasets\n", 160 | "# 导入波士顿房价数据集\n", 161 | "boston = datasets.load_boston()\n", 162 | "# 打乱数据集\n", 163 | "X, y = data_shuffle(boston.data, boston.target, seed=13)\n", 164 | "X = X.astype(np.float32)\n", 165 | "offset = int(X.shape[0] * 0.9)\n", 166 | "# 划分数据集\n", 167 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)\n", 168 | "# 创建GBRT实例\n", 169 | "model = GBDTRegressor()\n", 170 | "# 模型训练\n", 171 | "model.fit(X_train, y_train)\n", 172 | "# 模型预测\n", 173 | "y_pred = model.predict(X_test)\n", 174 | "# 计算模型预测的均方误差\n", 175 | "mse = mean_squared_error(y_test, y_pred)\n", 176 | "print (\"Mean Squared Error of NumPy GBRT:\", mse)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 6, 182 | "metadata": {}, 183 | "outputs": [ 184 | { 185 | "name": "stdout", 186 | "output_type": "stream", 187 | "text": [ 188 | "Mean Squared Error of sklearn GBRT: 14.88424955427429\n" 189 | ] 190 | } 191 | ], 192 | "source": [ 193 | "# 导入sklearn GBDT模块\n", 194 | "from sklearn.ensemble import GradientBoostingRegressor\n", 195 | "# 创建模型实例\n", 196 | "reg = GradientBoostingRegressor(n_estimators=200, learning_rate=0.5, max_depth=4, random_state=0)\n", 197 | "# 模型拟合\n", 198 | "reg.fit(X_train, y_train)\n", 199 | "# 模型预测\n", 200 | "y_pred = reg.predict(X_test)\n", 201 | "# 计算模型预测的均方误差\n", 202 | "mse = mean_squared_error(y_test, y_pred)\n", 203 | "print (\"Mean Squared Error of sklearn GBRT:\", mse)" 204 | ] 205 | } 206 | ], 207 | "metadata": { 208 | "kernelspec": { 209 | "display_name": "Python 3 (ipykernel)", 210 | "language": "python", 211 
| "name": "python3" 212 | }, 213 | "language_info": { 214 | "codemirror_mode": { 215 | "name": "ipython", 216 | "version": 3 217 | }, 218 | "file_extension": ".py", 219 | "mimetype": "text/x-python", 220 | "name": "python", 221 | "nbconvert_exporter": "python", 222 | "pygments_lexer": "ipython3", 223 | "version": "3.9.7" 224 | }, 225 | "toc": { 226 | "base_numbering": 1, 227 | "nav_menu": {}, 228 | "number_sections": true, 229 | "sideBar": true, 230 | "skip_h1_title": false, 231 | "title_cell": "Table of Contents", 232 | "title_sidebar": "Contents", 233 | "toc_cell": false, 234 | "toc_position": {}, 235 | "toc_section_display": true, 236 | "toc_window_display": false 237 | } 238 | }, 239 | "nbformat": 4, 240 | "nbformat_minor": 2 241 | } 242 | -------------------------------------------------------------------------------- /charpter11_GBDT/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | ### 定义二叉特征分裂函数 4 | def feature_split(X, feature_i, threshold): 5 | split_func = None 6 | if isinstance(threshold, int) or isinstance(threshold, float): 7 | split_func = lambda sample: sample[feature_i] >= threshold 8 | else: 9 | split_func = lambda sample: sample[feature_i] == threshold 10 | 11 | X_left = np.array([sample for sample in X if split_func(sample)]) 12 | X_right = np.array([sample for sample in X if not split_func(sample)]) 13 | 14 | return np.array([X_left, X_right]) 15 | 16 | 17 | ### 计算基尼指数 18 | def calculate_gini(y): 19 | y = y.tolist() 20 | probs = [y.count(i)/len(y) for i in np.unique(y)] 21 | gini = sum([p*(1-p) for p in probs]) 22 | return gini 23 | 24 | 25 | ### 打乱数据 26 | def data_shuffle(X, y, seed=None): 27 | if seed: 28 | np.random.seed(seed) 29 | idx = np.arange(X.shape[0]) 30 | np.random.shuffle(idx) 31 | return X[idx], y[idx] 32 | -------------------------------------------------------------------------------- /charpter12_XGBoost/__pycache__/cart.cpython-39.pyc: 
import numpy as np
from utils import feature_split, calculate_gini

### Tree node
class TreeNode():
    """A node of a binary decision tree: either an internal split
    (feature_i, threshold, two children) or a leaf (leaf_value)."""
    def __init__(self, feature_i=None, threshold=None,
                 leaf_value=None, left_branch=None, right_branch=None):
        # Index of the feature this node splits on.
        self.feature_i = feature_i
        # Threshold the feature is compared against.
        self.threshold = threshold
        # Prediction held by a leaf node (None for internal nodes).
        self.leaf_value = leaf_value
        # Child holding samples that satisfy the split.
        self.left_branch = left_branch
        # Child holding the remaining samples.
        self.right_branch = right_branch

### Binary decision tree
class BinaryDecisionTree(object):
    """Generic CART base class. ClassificationTree / RegressionTree
    plug in the split criterion and the leaf-value function."""

    def __init__(self, min_samples_split=2, min_gini_impurity=999,
                 max_depth=float("inf"), loss=None):
        # Root TreeNode, populated by fit().
        self.root = None
        # Minimum number of samples a node needs before a split is tried.
        self.min_samples_split = min_samples_split
        # A split is only accepted when its impurity is below this value.
        self.min_gini_impurity = min_gini_impurity
        # Maximum depth the tree may reach.
        self.max_depth = max_depth
        # Split criterion, installed by the subclass fit().
        self.impurity_calculation = None
        # Leaf predictor, installed by the subclass fit().
        self._leaf_value_calculation = None
        # Optional loss object (not used by the tree itself).
        self.loss = loss

    ### Fit
    def fit(self, X, y, loss=None):
        """Recursively build the tree over X, y."""
        self.root = self._build_tree(X, y)
        self.loss = None  # as in the original: loss is dropped after fitting

    ### Tree construction
    def _build_tree(self, X, y, current_depth=0):
        """Search for the best split of (X, y) and recurse, or emit a leaf."""
        # Running best (lowest) impurity, its criteria, and its subsets.
        init_gini_impurity = 999
        best_criteria = None
        best_sets = None

        # Promote 1-D targets to a column vector for the concatenation below.
        if len(np.shape(y)) == 1:
            y = np.expand_dims(y, axis=1)

        # Features and labels travel together through each split.
        Xy = np.concatenate((X, y), axis=1)
        n_samples, n_features = X.shape

        # Try to split only while the node is big enough and shallow enough.
        if n_samples >= self.min_samples_split and current_depth <= self.max_depth:
            for feature_i in range(n_features):
                # Every distinct value of the feature is a candidate threshold.
                feature_values = np.expand_dims(X[:, feature_i], axis=1)
                unique_values = np.unique(feature_values)

                for threshold in unique_values:
                    Xy1, Xy2 = feature_split(Xy, feature_i, threshold)
                    if len(Xy1) > 0 and len(Xy2) > 0:
                        # Label parts of the two subsets.
                        y1 = Xy1[:, n_features:]
                        y2 = Xy2[:, n_features:]
                        impurity = self.impurity_calculation(y, y1, y2)

                        # Remember the best split found so far.
                        if impurity < init_gini_impurity:
                            init_gini_impurity = impurity
                            best_criteria = {"feature_i": feature_i, "threshold": threshold}
                            best_sets = {
                                "leftX": Xy1[:, :n_features],
                                "lefty": Xy1[:, n_features:],
                                "rightX": Xy2[:, :n_features],
                                "righty": Xy2[:, n_features:]
                            }

        # Split the node only when the best impurity beats the ceiling;
        # init stays 999 when no split was found, so this is then False.
        if init_gini_impurity < self.min_gini_impurity:
            left_branch = self._build_tree(best_sets["leftX"], best_sets["lefty"], current_depth + 1)
            right_branch = self._build_tree(best_sets["rightX"], best_sets["righty"], current_depth + 1)
            return TreeNode(feature_i=best_criteria["feature_i"],
                            threshold=best_criteria["threshold"],
                            left_branch=left_branch, right_branch=right_branch)

        # No acceptable split: this node is a leaf.
        leaf_value = self._leaf_value_calculation(y)
        return TreeNode(leaf_value=leaf_value)

    ### Predict one sample
    def predict_value(self, x, tree=None):
        """Descend from `tree` (root by default) to the leaf value for x."""
        if tree is None:
            tree = self.root
        # Leaves carry their prediction directly.
        if tree.leaf_value is not None:
            return tree.leaf_value
        feature_value = x[tree.feature_i]
        # Routing must agree with feature_split: numeric >= threshold goes
        # left, categorical equality goes left, everything else goes right.
        branch = tree.right_branch
        if isinstance(feature_value, (int, float)):
            if feature_value >= tree.threshold:
                branch = tree.left_branch
        elif feature_value == tree.threshold:
            # BUGFIX: equality previously (redundantly) chose the right
            # branch, disagreeing with feature_split, which puts equal
            # categorical values into the LEFT subset when building.
            branch = tree.left_branch
        return self.predict_value(x, branch)

    ### Predict a dataset
    def predict(self, X):
        """Return a list of predictions, one per row of X."""
        y_pred = [self.predict_value(sample) for sample in X]
        return y_pred


class ClassificationTree(BinaryDecisionTree):
    ### Gini impurity of a split
    def _calculate_gini_impurity(self, y, y1, y2):
        """Weighted Gini impurity of the two children of a candidate split."""
        p = len(y1) / len(y)
        gini_impurity = p * calculate_gini(y1) + (1 - p) * calculate_gini(y2)
        return gini_impurity

    ### Majority vote
    def _majority_vote(self, y):
        """Most frequent label in y."""
        most_common = None
        max_count = 0
        for label in np.unique(y):
            count = len(y[y == label])
            if count > max_count:
                most_common = label
                max_count = count
        return most_common

    # Classification fit: install Gini criterion and majority-vote leaves.
    def fit(self, X, y):
        self.impurity_calculation = self._calculate_gini_impurity
        self._leaf_value_calculation = self._majority_vote
        super(ClassificationTree, self).fit(X, y)


### CART regression tree
class RegressionTree(BinaryDecisionTree):
    ### Variance-based split criterion
    def _calculate_variance_reduction(self, y, y1, y2):
        """Variance reduction of the split, summed across target columns."""
        var_tot = np.var(y, axis=0)
        var_y1 = np.var(y1, axis=0)
        var_y2 = np.var(y2, axis=0)
        frac_1 = len(y1) / len(y)
        frac_2 = len(y2) / len(y)
        variance_reduction = var_tot - (frac_1 * var_y1 + frac_2 * var_y2)
        return sum(variance_reduction)

    ### Leaf value: mean of targets
    def _mean_of_y(self, y):
        """Column-wise mean; a single-output mean is unwrapped to a scalar."""
        value = np.mean(y, axis=0)
        return value if len(value) > 1 else value[0]

    # Regression fit: install variance criterion and mean leaves.
    def fit(self, X, y):
        self.impurity_calculation = self._calculate_variance_reduction
        self._leaf_value_calculation = self._mean_of_y
        super(RegressionTree, self).fit(X, y)
"C:\\Users\\18765\\AppData\\Roaming\\Python\\Python39\\site-packages\\lightgbm\\sklearn.py:726: UserWarning: 'early_stopping_rounds' argument is deprecated and will be removed in a future release of LightGBM. Pass 'early_stopping()' callback via 'callbacks' argument instead.\n", 13 | " _log_warning(\"'early_stopping_rounds' argument is deprecated and will be removed in a future release of LightGBM. \"\n" 14 | ] 15 | }, 16 | { 17 | "name": "stdout", 18 | "output_type": "stream", 19 | "text": [ 20 | "[1]\tvalid_0's multi_logloss: 1.02277\n", 21 | "[2]\tvalid_0's multi_logloss: 0.943765\n", 22 | "[3]\tvalid_0's multi_logloss: 0.873274\n", 23 | "[4]\tvalid_0's multi_logloss: 0.810478\n", 24 | "[5]\tvalid_0's multi_logloss: 0.752973\n", 25 | "[6]\tvalid_0's multi_logloss: 0.701621\n", 26 | "[7]\tvalid_0's multi_logloss: 0.654982\n", 27 | "[8]\tvalid_0's multi_logloss: 0.611268\n", 28 | "[9]\tvalid_0's multi_logloss: 0.572202\n", 29 | "[10]\tvalid_0's multi_logloss: 0.53541\n", 30 | "[11]\tvalid_0's multi_logloss: 0.502582\n", 31 | "[12]\tvalid_0's multi_logloss: 0.472856\n", 32 | "[13]\tvalid_0's multi_logloss: 0.443853\n", 33 | "[14]\tvalid_0's multi_logloss: 0.417764\n", 34 | "[15]\tvalid_0's multi_logloss: 0.393613\n", 35 | "[16]\tvalid_0's multi_logloss: 0.370679\n", 36 | "[17]\tvalid_0's multi_logloss: 0.349936\n", 37 | "[18]\tvalid_0's multi_logloss: 0.330669\n", 38 | "[19]\tvalid_0's multi_logloss: 0.312805\n", 39 | "[20]\tvalid_0's multi_logloss: 0.296973\n", 40 | "Accuracy of lightgbm: 1.0\n" 41 | ] 42 | }, 43 | { 44 | "data": { 45 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAaMAAAEWCAYAAADLkvgyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAic0lEQVR4nO3dfZxWdZ3/8dcbUEBQiEA0CEfRjLtEQc3WH46urhlo2pY3q23k3ba7Zqh5s7UVW2y6iJum/XQN9yc/KXAzUQLT/ImDpq4JiIImrTdT4B1JikCCw/D5/XHO4MUww1zDzMV3rmvez8fjenCuc3Odz2fQeXO+51znKCIwMzNLqUvqAszMzBxGZmaWnMPIzMyScxiZmVlyDiMzM0vOYWRmZsk5jMzKiKRvSJqeug6z9iZ/z8g6C0m1wECgvmD2xyLitTZ+5vkR8f/aVl35kTQZODAizkldi5U/HxlZZ3NyRPQueO10ELUHSd1S7n9nlWvd1nE5jKzTk9RH0m2SXpf0qqQpkrrmy4ZKWiBpjaS3JP1EUt982R3AEOAXktZLukJStaRVjT6/VtLx+fRkSXdJminpXWDijvbfRK2TJc3Mp6skhaQvS1op6W1JX5F0uKRnJb0j6aaCbSdKekzSjZLWSnpB0l8WLP+IpLmS/iTpRUkXNNpvYd1fAb4BnJH3/ky+3pcl/VbSOkkvS/q7gs+olrRK0mWSVuf9frlgeU9J10n6fV7fryX1zJd9UtLjeU/PSKreib9q68AcRmYwA9gMHAgcCvwVcH6+TMDVwEeAYcBHgckAEfFF4A98cLQ1tcj9fRa4C+gL/KSF/RfjSOAg4AzgeuCbwPHACOB0Scc0WvdloD/wHeBuSf3yZbOAVXmvnwe+XxhWjeq+Dfg+cGfe+yH5OquBCcBewJeBH0g6rOAz9gH6AIOA84AfSfpQvmwaMAb4FNAPuALYImkQMB+Yks//OvBzSQNa8TOyDs5hZJ3NPfm/rt+RdI+kgcBJwKSI2BARq4EfAGcCRMSLEfFgRGyKiD8C/w4c0/zHF+WJiLgnIraQ/dJudv9F+l5EbIyIXwEbgFkRsToiXgUeJQu4BquB6yOiLiLuBFYA4yV9FDgauDL/rKXAdOCLTdUdEe81VUhEzI+IlyKzEPgV8L8KVqkDvpvv/z5gPXCwpC7AucDXIuLViKiPiMcjYhNwDnBfRNyX7/tBYBHwmVb8jKyD87ivdTanFl5sIOkIYDfgdUkNs7sAK/PlewM/JPuFume+7O021rCyYHq/He2/SG8WTL/XxPveBe9fjW2vWvo92ZHQR4A/RcS6RsvGNlN3kySdRHbE9TGyPvYAlhWssiYiNhe8/3NeX3+gB/BSEx+7H/AFSScXzNsNeLileqx8OIyss1sJbAL6N/ol2eBqIIBPRMQaSacCNxUsb3w56gayX8AA5Od+Gg8nFW7T0v7b2yBJKgikIcBc4DWgn6Q9CwJpCPBqwbaNe93mvaTuwM+BvwXujYg6SfeQDXW25C1gIzAUeKbRspXAHRFxwXZbWcXwMJ11ahHxOtlQ0nWS9pLUJb9ooWEobk+yoaR38nMXlzf6iDeBAwre/w7oIWm8pN2Afwa6t2H/7W1v4GJJu0n6Atl5sPsiYiXwOHC1pB6SPkF2TucnO/isN4GqfIgNYHeyXv8IbM6Pkv6qmKLyIcv/BP49v5Ciq6Sj8oCbCZws6cR8fo/8YojBrW/fOiqHkVn2L/ndgefJhuDuAvbNl/0LcBiwluwk+t2Ntr0a+Of8HNTXI2It8A9k51teJTtSWsWO7Wj/7e1Jsosd3gL+Ffh8RKzJl50FVJEdJc0BvpOfn2nOz/I/10hakh9RXQz8F1kff0N21FWsr5MN6T0F/An4N6BLHpSfJbt6749kR0qX499fFcVfejXrJCRNJPuC7tGpazFrzP+yMDOz5BxGZmaWnIfpzMwsOR8ZmZlZcv6e0U7q27dvHHjgganLaFcbNmygV69eqctoV5X
YE1RmX5XYE1RmX23pafHixW9FxHa3cnIY7aSBAweyaNGi1GW0q5qaGqqrq1OX0a4qsSeozL4qsSeozL7a0pOk3zc138N0ZmaWnMPIzMyScxiZmVlyDiMzM0vOYWRmZsk5jMzMLDmHkZmZJecwMjOz5BxGZmaWnMPIzMyScxiZmVlyDiMzM0vOYWRmZsk5jMzMLDmHkZmZJecwMjOz5BxGZmaWnMPIzMyScxiZmVlyDiMzM0vOYWRmZsk5jMzMLDmHkZmZJecwMjOz5BxGZmaWnMPIzMyScxiZmVlyDiMzM0vOYWRmZsk5jMzMLDmHkZmZJecwMjOz5BxGZmaWnMPIzMyScxiZmVlyDiMzM0vOYWRmZskpIlLXUJaGHHBgdDn9htRltKvLRm3mumXdUpfRriqxJ6jMviqxJ+iYfdVeM75N29fU1FBdXb1T20paHBFjG8/3kZGZWSd37rnnsvfeezNy5Mit8372s58xYsQIunTpwqJFi7bb5g9/+AO9e/dm2rRp7VKDw8jMrJObOHEi999//zbzRo4cyd133824ceOa3OaSSy7hpJNOarcaShpGkvaRNFvSS5Kel3SfpI81s26VpOWlrKeZ/e4hab6kFyQ9J+maXV2DmVlK48aNo1+/ftvMGzZsGAcffHCT6//617/mgAMOYMSIEe1WQ8nCSJKAOUBNRAyNiOHAN4CBpdpnG0yLiI8DhwJ/Ian94t7MrIJs2LCBWbNm8Z3vfKddP7eUZ9WOBeoi4paGGRGxVJlrgZOAAKZExJ2FG0qaCIyNiIvy9/PIAqNG0nrgR8DxwNtkATcVGAJMioi5+fanAHsAQ4E5EXFFU0VGxJ+Bh/Pp9yUtAQY3ta6kC4ELAfr3H8C3R21u/U+lAxvYMzvZWkkqsSeozL4qsSfomH3V1NRsN++NN95gw4YN2y175513WLx4MevXrwfg5ptvZsKECSxatIja2lp69uzZ5Oe1VinDaCSwuIn5nwNGA4cA/YGnJD3Sis/tRXa0daWkOcAU4ARgODADmJuvN5rsSGcTsELSjRGxckcfLKkvcDLQ5GVyEXErcCtkV9N1tCtk2qojXvXTVpXYE1RmX5XYE3TMvmrPrt5+Xm0tvXr12u4qub59+zJmzBjGjs0ugPvWt77FwoULufPOO3nnnXfo0qULI0aM4KKLLmpTTSl+QkcDsyKiHnhT0kLgcODZIrd/H2g407YM2BQRdZKWAVUF6z0UEWsBJD0P7Ac0G0aSugGzgB9GxMut6MfMrNN49NFHt17aPXnyZHr37t3mIILSXsDwHDCmifkqYtvNbFtbj4Lpuvjgy1FbyI58iIgtbBuumwqm62k5eG8F/iciri+iPjOzinHWWWdx1FFHsWLFCgYPHsxtt93GnDlzGDx4ME888QTjx4/nxBNPLGkNpTwyWgB8X9IFEfFjAEmHk53nOUPSDKAfMA64nG0Dpxb4B0ldgEHAESWsE0lTgD7A+cVu03O3rqxo4xfHOpqampomD9/LWSX2BJXZVyX2BOXR16xZs5qcf9ppp+1wu8mTJ7dbDSULo4gISacB10u6CthIFjKTgN7AM2QXMFwREW9IqirY/DHgFbJhuOXAklLVKWkw8E3gBWBJdhEgN0XE9FLt08zMtlXSc0YR8RpwehOLLs9fhevWkl30QD4Md3Yzn9m7YHpyU8si4nbg9oL5E3ZQ4yqKGzo0M7MS8R0YzMwsuY51vWGJSXoS6N5o9hcjYlmKeszMLNOpwigijkxdg5mZbc/DdGZmlpzDyMzMknMYmZlZcg4jMzNLzmFkZmbJOYzMzCw5h5GZmSXnMDIzs+QcRmZmlpzDyMzMknMYmZlZcg4jMzNLzmFkZmbJOYzMzCw5h5GZmSXnMDIzs+QcRmZmlpzDyMzMknMYmZlZcg4jMzNLzmFkZmbJOYzMzCw5h5GZmSXnMDIzs+QcRmZmlpzDyMzMknMYmZlZcg4jMzNLzmFkZmbJdUtdQLl6r66eqqvmpy6
jXV02ajMT3dN2aq8Zv928+vp6xo4dy6BBg5g3bx4AN954IzfddBPdunVj/PjxTJ06tU37NetMHEZmO+GGG25g2LBhvPvuuwA8/PDD3HvvvTz77LN0796d1atXJ67QrLyUdJhO0j6SZkt6SdLzku6T9LFm1q2StLyU9TRH0v2SnpH0nKRbJHVNUYeVh1WrVjF//nzOP//8rfNuvvlmrrrqKrp37w7A3nvvnao8s7JUsjCSJGAOUBMRQyNiOPANYGCp9tkGp0fEIcBIYADwhcT1WAc2adIkpk6dSpcuH/zv87vf/Y5HH32UI488kmOOOYannnoqYYVm5aeUw3THAnURcUvDjIhYqsy1wElAAFMi4s7CDSVNBMZGxEX5+3nAtIiokbQe+BFwPPA2WcBNBYYAkyJibr79KcAewFBgTkRc0VyhEfFuPtkN2D2vazuSLgQuBOjffwDfHrW5FT+Ojm9gz+wcSyVpj55qamq2Tj/xxBPU1dWxbt06li5dypo1a6ipqWHt2rUsW7aMa665hhdeeIFTTjmFn/70p2T/Jmt/69ev36auSlCJPUFl9lWKnkoZRiOBxU3M/xwwGjgE6A88JemRVnxuL7KjrSslzQGmACcAw4EZwNx8vdHAocAmYIWkGyNiZXMfKukB4Ajgl8BdTa0TEbcCtwIMOeDAuG5ZZZ1yu2zUZtzT9mrPrt46/cADD7B48WImTpzIxo0beffdd5k+fToHH3wwF198MdXV1Rx77LFMmzaNkSNHMmDAgDZ20LSamhqqq6tbXK+cVGJPUJl9laKnFJd2Hw3Mioj6iHgTWAgc3ort3wfuz6eXAQsjoi6fripY76GIWBsRG4Hngf129KERcSKwL9AdOK4V9VgncvXVV7Nq1Spqa2uZPXs2xx13HDNnzuTUU09lwYIFQDZk9/7779O/f//E1ZqVj6LCSNJQSd3z6WpJF0vq28JmzwFjmvq4Ina5uVFtPQqm6yKiYRhtC9mRDxGxhW2P9DYVTNdTxFFgHlxzgc8WUaPZVueeey4vv/wyI0eO5Mwzz2TGjBklG6Izq0TFHhn9HKiXdCBwG7A/8NMWtlkAdJd0QcMMSYeTnec5Q1JXSQOAccBvGm1bC4yW1EXSR8mGz0pCUm9J++bT3YDPAC+Uan9WOaqrq7d+x2j33Xdn5syZLF++nCVLlnDccT64NmuNYgfTt0TEZkmnAddHxI2Snt7RBhERDetLugrYSBYyk4DewDNkFwpcERFvSKoq2Pwx4BWyobflwJLiW2q1XsDc/MivK1mI3rLjTaDnbl1Z0cSXIctZTU3NNudHKkEl9mRWiYoNozpJZwFfAk7O5+3W0kYR8RpwehOLLs9fhevWkl30QD4Md3Yzn9m7YHpyU8si4nbg9oL5E3ZQ45u07pyVmZm1s2KH6b4MHAX8a0S8Iml/YGbpyjIzs86kqCOjiHhe0pVk3+UhIl4BrillYaUg6Umyq+UKfTEilqWox8zMMkWFkaSTgWlkXwjdX9Jo4LsRcUoJa2t3EXFk6hrMzGx7xQ7TTSa7ou0dyO6kQHZFnZmZWZsVG0abI2Jto3lN3jLHzMystYq9mm65pL8Buko6CLgYeLx0ZZmZWWdS7JHRV4ERZHc1+Cmwluz7QmZmZm3W4pFR/myfuRFxPPDN0pdkZmadTYtHRhFRD/xZUp9dUI+ZmXVCxZ4z2ggsk/QgsKFhZkRcXJKqzMysUyk2jObnLzMzs3ZX7B0YZpS6EDMz67yKvQPDKzTxvaKIOKDdKzIzs06n2GG6sQXTPYAvAP3avxwzM+uMivqeUUSsKXi9GhHX40dzm5lZOyl2mO6wgrddyI6U9ixJRWZm1ukUO0x3XcH0ZrKnsDb10DwzM7NWKzaMzouIlwtn5A/YMzMza7Ni7013V5HzzMzMWm2HR0aSPk52g9Q+kj5XsGgvsqvqzMzM2qylYbqDgQlAX+DkgvnrgAtKVJOZmXUyOwyjiLgXuFfSURHxxC6qyczMOpl
iL2B4WtI/kg3ZbR2ei4hzS1KVmZl1KsVewHAHsA9wIrAQGEw2VGdmZtZmxYbRgRHxLWBDftPU8cCo0pVlZmadSbFhVJf/+Y6kkUAfoKokFZmZWadT7DmjWyV9CPgWMBfoDXy7ZFWZmVmnUuzzjKbnkwsBPzbCzMzaVVHDdJIGSrpN0i/z98MlnVfa0szMrLMo9pzR7cADwEfy978DJpWgHjMz64SKDaP+EfFfwBaAiNgM1JesKjMz61SKDaMNkj5M/uhxSZ8E1pasKjMz61SKvZruUrKr6IZKegwYAHy+ZFWVgffq6qm6an7qMtrVZaM2M7HInmqvGV/iasysM9nhkZGkIQARsQQ4BvgU8HfAiIh4tvTlWTlYuXIlxx57LMOGDWPEiBHccMMN2yyfNm0aknjrrbcSVWhmHV1LR0b3AA2PHL8zIv66tOVYOerWrRvXXXcdhx12GOvWrWPMmDGccMIJDB8+nJUrV/Lggw8yZMiQ1GWaWQfW0jkjFUy3+vtFkvaRNFvSS5Kel3SfpI81s26VpOWt3Ud7kDRG0jJJL0r6oSS1vJU12HfffTnssOzfLHvuuSfDhg3j1VdfBeCSSy5h6tSp+EdqZjvSUhhFM9Mtyn+hzwFqImJoRAwHvgEMbF2Ju8TNwIXAQfnr02nLKV+1tbU8/fTTHHnkkcydO5dBgwZxyCGHpC7LzDo4RTSfMZLqgQ1kR0g9gT83LAIiIvbawbbHAZMjYlyj+QKmAieRBdyUiLhTUhUwLyJGSpoIjI2Ii/Jt5gHTIqJG0nrgR8DxwNtkATcVGAJMioi5+fanAHsAQ4E5EXFFM3XuCzwcER/P358FVEfE3zWx7oVkoUX//gPGfPv6Hzf7sytHA3vCm+8Vt+6oQX22m/fee+/xta99jXPOOYcjjjiCSy65hGuvvZbevXtz5pln8h//8R/06bP9dqW0fv16evfuvUv3uStUYl+V2BNUZl9t6enYY49dHBFjG89v6eF6XXdqb5mRwOIm5n8OGA0cAvQHnpL0SCs+txfZ0daVkuYAU4ATgOHADLKr/sj3cSiwCVgh6caIWNnE5w0CVhW8X5XP205E3ArcCjDkgAPjumXFXoxYHi4btZlie6o9u3qb93V1dUyYMIGvfOUrXHrppSxbtow1a9Zw0UUXAfDWW2/x1a9+ld/85jfss88+7V16s2pqaqiurm5xvXJTiX1VYk9QmX2VoqcUv02PBmZFRD3wpqSFwOFAsVfnvQ/cn08vAzZFRJ2kZWx7J/GHImItgKTngf2ApsKoqZMZrRqS7OwigvPOO49hw4Zx6aWXAjBq1ChWr169dZ2qqioWLVpE//79U5VpZh1YsV963RnPAWOamF/MmezNbFtbj4LpuvhgbHEL2ZEPEbGFbcN1U8F0Pc0H7yqyhwU2GAy8VkSNlnvssce44447WLBgAaNHj2b06NHcd999qcsyszJSyiOjBcD3JV0QET8GkHQ42XmeMyTNAPoB44DL2TZwaoF/kNSFbMjsiFIVGRGvS1qX31XiSeBvgRtb2q7nbl1ZUWFf/Kypqdlu+K0YRx99NDs69wjZhQ1mZs0pWRhFREg6Dbhe0lXARrKQmUT2PKRnyIbDroiIN/ILGBo8BrxCNgy3HFhSqjpzf092M9iewC/zl5mZ7SIlPWcUEa8Bpzex6PL8VbhuLdlFD+TDcGc385m9C6YnN7UsIm4nC5eG+RNaqHNRw77NzGzXK+U5IzMzs6JU1rXJLZD0JNC90ewvRsSyFPWYmVmmU4VRRByZugYzM9ueh+nMzCw5h5GZmSXnMDIzs+QcRmZmlpzDyMzMknMYmZlZcg4jMzNLzmFkZmbJOYzMzCw5h5GZmSXnMDIzs+QcRmZmlpzDyMzMknMYmZlZcg4jMzNLzmFkZmbJOYzMzCw5h5GZmSXnMDIzs+QcRmZmlpzDyMzMknMYmZlZcg4jMzNLzmFkZmbJOYzMzCw5h5GZmSXnMDIzs+QcRmZmlpzDyMz
MkuuWuoBy9V5dPVVXzW9xvdprxm83r76+nrFjxzJo0CDmzZtXivLMzMqKj4wSuOGGGxg2bFjqMszMOoyShpGkfSTNlvSSpOcl3SfpY82sWyVpeSnraY6kf5W0UtL6Uu9r1apVzJ8/n/PPP7/UuzIzKxslCyNJAuYANRExNCKGA98ABpZqn23wC+CIXbGjSZMmMXXqVLp08UGpmVmDUp4zOhaoi4hbGmZExFJlrgVOAgKYEhF3Fm4oaSIwNiIuyt/PA6ZFRE1+9PIj4HjgbbKAmwoMASZFxNx8+1OAPYChwJyIuKK5QiPiv/P97LAhSRcCFwL07z+Ab4/a3OIPoaamZuv0E088QV1dHevWrWPp0qWsWbNmm+WprV+/vkPV0x4qsSeozL4qsSeozL5K0VMpw2gksLiJ+Z8DRgOHAP2BpyQ90orP7UV2tHWlpDnAFOAEYDgwA5ibrzcaOBTYBKyQdGNErNyJPraKiFuBWwGGHHBgXLes5R9f7dnVW6cfeOABFi9ezMSJE9m4cSPvvvsu06dPZ+bMmW0pq93U1NRQXV2duox2VYk9QWX2VYk9QWX2VYqeUowVHQ3Mioj6iHgTWAgc3ort3wfuz6eXAQsjoi6fripY76GIWBsRG4Hngf3aXHkbXX311axatYra2lpmz57Ncccd12GCyMwspVKG0XPAmCbm73gsLLOZbWvrUTBdFxGRT28hO/IhIraw7ZHepoLpenwZu5lZh1XKMFoAdJd0QcMMSYeTnec5Q1JXSQOAccBvGm1bC4yW1EXSR9lFFxfsStXV1f6OkZlZrmRHCxERkk4Drpd0FbCRLGQmAb2BZ8guYLgiIt6QVFWw+WPAK2RDb8uBJaWqE0DSVOBvgD0krQKmR8TkHW3Tc7eurGjiC61mZtZ6JR26iojXgNObWHR5/ipct5bsogfyYbizm/nM3gXTk5taFhG3A7cXzJ/QQp1XAM1ebWdmZqXlL7uYmVlyneqkvqQnge6NZn8xIpalqMfMzDKdKowi4sjUNZiZ2fY8TGdmZsk5jMzMLDmHkZmZJecwMjOz5BxGZmaWnMPIzMyScxiZmVlyDiMzM0vOYWRmZsk5jMzMLDmHkZmZJecwMjOz5BxGZmaWnMPIzMyScxiZmVlyDiMzM0vOYWRmZsk5jMzMLDmHkZmZJecwMjOz5BxGZmaWnMPIzMyScxiZmVlyDiMzM0vOYWRmZsk5jMzMLDmHkZmZJecwMjOz5BxGZmaWnMPIzMyScxiZmVlyDiMzM0vOYWRmZsk5jMzMLDlFROoaypKkdcCK1HW0s/7AW6mLaGeV2BNUZl+V2BNUZl9t6Wm/iBjQeGa3ttXTqa2IiLGpi2hPkha5p/JQiX1VYk9QmX2VoicP05mZWXIOIzMzS85htPNuTV1ACbin8lGJfVViT1CZfbV7T76AwczMkvORkZmZJecwMjOz5BxGrSTp05JWSHpR0lWp69kZkj4q6WFJv5X0nKSv5fP7SXpQ0v/kf34oda07Q1JXSU9Lmpe/L+u+JPWVdJekF/K/s6PKvScASZfk//0tlzRLUo9y60vSf0paLWl5wbxme5D0T/nvjhWSTkxTdcua6eva/L/BZyXNkdS3YFmb+3IYtYKkrsCPgJOA4cBZkoanrWqnbAYui4hhwCeBf8z7uAp4KCIOAh7K35ejrwG/LXhf7n3dANwfER8HDiHrrax7kjQIuBgYGxEjga7AmZRfX7cDn240r8ke8v/HzgRG5Nv87/x3Skd0O9v39SAwMiI+AfwO+Cdov74cRq1zBPBiRLwcEe8Ds4HPJq6p1SLi9YhYkk+vI/vlNoislxn5ajOAU5MU2AaSBgPjgekFs8u2L0l7AeOA2wAi4v2IeIcy7qlAN6CnpG7AHsBrlFlfEfEI8KdGs5vr4bPA7IjYFBGvAC+S/U7pcJrqKyJ+FRGb87f/DQzOp9ulL4dR6wwCVha8X5XPK1uSqoBDgSeBgRHxOmSBBeydsLSddT1wBbC
lYF4593UA8Efg/+RDj9Ml9aK8eyIiXgWmAX8AXgfWRsSvKPO+cs31UEm/P84FfplPt0tfDqPWURPzyvbaeEm9gZ8DkyLi3dT1tJWkCcDqiFicupZ21A04DLg5Ig4FNtDxh65alJ9H+SywP/ARoJekc9JWVXIV8ftD0jfJhvp/0jCridVa3ZfDqHVWAR8teD+YbGih7EjajSyIfhIRd+ez35S0b758X2B1qvp20l8Ap0iqJRtCPU7STMq7r1XAqoh4Mn9/F1k4lXNPAMcDr0TEHyOiDrgb+BTl3xc030PZ//6Q9CVgAnB2fPAl1Xbpy2HUOk8BB0naX9LuZCft5iauqdUkiewcxG8j4t8LFs0FvpRPfwm4d1fX1hYR8U8RMTgiqsj+bhZExDmUcV8R8QawUtLB+ay/BJ6njHvK/QH4pKQ98v8e/5Ls3GW59wXN9zAXOFNSd0n7AwcBv0lQ306R9GngSuCUiPhzwaL26Ssi/GrFC/gM2ZUkLwHfTF3PTvZwNNlh9LPA0vz1GeDDZFf//E/+Z7/Utbahx2pgXj5d1n0Bo4FF+d/XPcCHyr2nvK9/AV4AlgN3AN3LrS9gFtk5rzqyI4TzdtQD8M38d8cK4KTU9beyrxfJzg01/M64pT378u2AzMwsOQ/TmZlZcg4jMzNLzmFkZmbJOYzMzCw5h5GZmSXnMDJrRFK9pKUFr6qd+IxTS3UTXUkfkXRXKT57B/scLekzu3Kf1rl0S12AWQf0XkSMbuNnnArMI/uCalEkdYsPbkTZrIh4Dfj8zpfWOvmNTEcDY4H7dtV+rXPxkZFZESSNkbRQ0mJJDxTc7uUCSU9JekbSz/M7CnwKOAW4Nj+yGiqpRtLYfJv++S2LkDRR0s8k/QL4laRe+bNknspvjLrdXeElVTU8Zybf/h5Jv5D0iqSLJF2ab/vfkvrl69VIul7S48qeH3REPr9fvv2z+fqfyOdPlnSrpF8B/xf4LnBG3s8Zko7IP+vp/M+DC+q5W9L9yp7nM7Wg7k9LWpL/rB7K57XYr3USqb/p65dfHe0F1PPBt8znALsBjwMD8uVnAP+ZT3+4YLspwFfz6duBzxcsqyF7dg9Af6A2n55I9g33fvn77wPn5NN9ye720atRfVXA8oLtXwT2BAYAa4Gv5Mt+QHYT3Ib9/zifHlew/Y3Ad/Lp44Cl+fRkYDHQs2A/NxXUsBfQLZ8+Hvh5wXovA32AHsDvye5bNoDs2/v75+sV3a9fnePlYTqz7W0zTCdpJDASeDC7jRpdyW6VAjBS0hSyX6S9gQd2Yn8PRkTDs2P+iuxmr1/P3/cAhrDtwwIbeziy51Ktk7QW+EU+fxnwiYL1ZkH2rBpJeyl7UufRwF/n8xdI+rCkPvn6cyPivWb22QeYIekgsltL7Vaw7KGIWAsg6XlgP7JbGD0S2fNuaGO/VoEcRmYtE/BcRBzVxLLbgVMj4hlJE8nuideUzXwwLN6j0bINjfb11xGxohX1bSqY3lLwfgvb/j/e+N5fwY5v/7+hiWUNvkcWgqflF3jUNFNPfV6Dmtg/7Fy/VoF8zsisZSuAAZKOguzxG5JG5Mv2BF5X9kiOswu2WZcva1ALjMmnd3TxwQPAV/M7WSPp0LaXv9UZ+WceTfYwu7XAI+R1S6oG3oqmn23VuJ8+wKv59MQi9v0EcEx+V2cazmVR2n6tjDiMzFoQ2SPmPw/8m6RnyM4lfSpf/C2yp+Q+SHYH6gazgcvzk/JDyZ5q+veSHic7Z9Sc75ENeT2bX6TwvXZs5e18/7eQ3YUZsnNDYyU9C1zDB48+aOxhYHjDBQzAVOBqSY+RDVvuUET8EbgQuDv/Gd6ZLyplv1ZGfNdus05AUg3w9YhYlLoWs6b4yMjMzJLzkZGZmSXnIyMzM0vOYWRmZsk5jMzMLDmHkZmZJecwMjOz5P4/2XSHuffvehMAAAAASUVORK5CYII=\n", 46 | "text/plain": [ 47 | "
" 48 | ] 49 | }, 50 | "metadata": { 51 | "needs_background": "light" 52 | }, 53 | "output_type": "display_data" 54 | } 55 | ], 56 | "source": [ 57 | "# 导入相关模块\n", 58 | "import lightgbm as lgb\n", 59 | "from sklearn.metrics import accuracy_score\n", 60 | "from sklearn.datasets import load_iris\n", 61 | "from sklearn.model_selection import train_test_split\n", 62 | "import matplotlib.pyplot as plt\n", 63 | "# 导入iris数据集\n", 64 | "iris = load_iris()\n", 65 | "data = iris.data\n", 66 | "target = iris.target\n", 67 | "# 数据集划分\n", 68 | "X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2, random_state=43)\n", 69 | "# 创建lightgbm分类模型\n", 70 | "gbm = lgb.LGBMClassifier(objective='multiclass',\n", 71 | " num_class=3,\n", 72 | " num_leaves=31,\n", 73 | " learning_rate=0.05,\n", 74 | " n_estimators=20)\n", 75 | "# 模型训练\n", 76 | "gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=5)\n", 77 | "# 预测测试集\n", 78 | "y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration_)\n", 79 | "# 模型评估\n", 80 | "print('Accuracy of lightgbm:', accuracy_score(y_test, y_pred))\n", 81 | "lgb.plot_importance(gbm)\n", 82 | "plt.show();" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [] 91 | } 92 | ], 93 | "metadata": { 94 | "kernelspec": { 95 | "display_name": "Python 3 (ipykernel)", 96 | "language": "python", 97 | "name": "python3" 98 | }, 99 | "language_info": { 100 | "codemirror_mode": { 101 | "name": "ipython", 102 | "version": 3 103 | }, 104 | "file_extension": ".py", 105 | "mimetype": "text/x-python", 106 | "name": "python", 107 | "nbconvert_exporter": "python", 108 | "pygments_lexer": "ipython3", 109 | "version": "3.9.7" 110 | }, 111 | "toc": { 112 | "base_numbering": 1, 113 | "nav_menu": {}, 114 | "number_sections": true, 115 | "sideBar": true, 116 | "skip_h1_title": false, 117 | "title_cell": "Table of Contents", 118 | "title_sidebar": "Contents", 
119 | "toc_cell": false, 120 | "toc_position": {}, 121 | "toc_section_display": true, 122 | "toc_window_display": false 123 | } 124 | }, 125 | "nbformat": 4, 126 | "nbformat_minor": 2 127 | } 128 | -------------------------------------------------------------------------------- /charpter14_CatBoost/catboost_info/learn/events.out.tfevents: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter14_CatBoost/catboost_info/learn/events.out.tfevents -------------------------------------------------------------------------------- /charpter14_CatBoost/catboost_info/learn_error.tsv: -------------------------------------------------------------------------------- 1 | iter Logloss 2 | 0 0.6065199002 3 | 1 0.5378107766 4 | 2 0.4912787557 5 | 3 0.456673048 6 | 4 0.428768478 7 | 5 0.4050500008 8 | 6 0.3899307575 9 | 7 0.3800498225 10 | 8 0.372001084 11 | 9 0.3622406382 12 | 10 0.3540936129 13 | 11 0.3491019959 14 | 12 0.3454124235 15 | 13 0.3407778961 16 | 14 0.3367714623 17 | 15 0.333911845 18 | 16 0.3306521561 19 | 17 0.3286266315 20 | 18 0.3264058618 21 | 19 0.3250865567 22 | 20 0.3233338648 23 | 21 0.3216826804 24 | 22 0.3206293512 25 | 23 0.3198258167 26 | 24 0.31860222 27 | 25 0.3172385796 28 | 26 0.3158636517 29 | 27 0.3146490054 30 | 28 0.3132823692 31 | 29 0.3116432353 32 | 30 0.3110070593 33 | 31 0.3104365333 34 | 32 0.3094960182 35 | 33 0.308806271 36 | 34 0.3082728324 37 | 35 0.3077598174 38 | 36 0.3072207594 39 | 37 0.30659435 40 | 38 0.305956076 41 | 39 0.305285533 42 | 40 0.3045614096 43 | 41 0.3036701936 44 | 42 0.3029562768 45 | 43 0.302640208 46 | 44 0.3021053504 47 | 45 0.3011489375 48 | 46 0.3006449722 49 | 47 0.3002692517 50 | 48 0.2997629416 51 | 49 0.2991658501 52 | 50 0.2988223205 53 | 51 0.298624936 54 | 52 0.2983980263 55 | 53 0.2982662003 56 | 54 0.2972628671 57 | 55 0.2968939375 58 | 56 0.2963495068 59 | 57 0.2959093322 
60 | 58 0.2956098345 61 | 59 0.2953556793 62 | 60 0.2947482526 63 | 61 0.2941260686 64 | 62 0.2938056885 65 | 63 0.2936293736 66 | 64 0.2933686961 67 | 65 0.2932296939 68 | 66 0.2930219128 69 | 67 0.2928300045 70 | 68 0.2926488423 71 | 69 0.2924957775 72 | 70 0.2923330308 73 | 71 0.2921670516 74 | 72 0.2920435292 75 | 73 0.291923588 76 | 74 0.291663392 77 | 75 0.2913296433 78 | 76 0.2906140493 79 | 77 0.2897263878 80 | 78 0.2895057437 81 | 79 0.289397704 82 | 80 0.2893070534 83 | 81 0.2891011509 84 | 82 0.2888662165 85 | 83 0.288701674 86 | 84 0.288601643 87 | 85 0.2885229648 88 | 86 0.2884932835 89 | 87 0.2881296262 90 | 88 0.2874352579 91 | 89 0.2873125676 92 | 90 0.2871670174 93 | 91 0.2870132931 94 | 92 0.2867042299 95 | 93 0.2866637776 96 | 94 0.2865104217 97 | 95 0.2863502908 98 | 96 0.2861717921 99 | 97 0.2860295921 100 | 98 0.2858263609 101 | 99 0.2856588171 102 | 100 0.2855380463 103 | 101 0.2848646771 104 | 102 0.2848213195 105 | 103 0.2843038885 106 | 104 0.2840929751 107 | 105 0.2840190593 108 | 106 0.2838830034 109 | 107 0.2837046358 110 | 108 0.283474304 111 | 109 0.2832785729 112 | 110 0.2824402994 113 | 111 0.2823536468 114 | 112 0.2821446579 115 | 113 0.2820113447 116 | 114 0.2818883453 117 | 115 0.2816936887 118 | 116 0.2814284445 119 | 117 0.2810013531 120 | 118 0.2808963433 121 | 119 0.2807983443 122 | 120 0.2801504041 123 | 121 0.2800099487 124 | 122 0.2798282963 125 | 123 0.2796761038 126 | 124 0.2796036138 127 | 125 0.2793368586 128 | 126 0.2791788917 129 | 127 0.2789285817 130 | 128 0.2788666327 131 | 129 0.278779196 132 | 130 0.278721426 133 | 131 0.2786453505 134 | 132 0.2785074777 135 | 133 0.2781894237 136 | 134 0.2780432075 137 | 135 0.2779258484 138 | 136 0.277850313 139 | 137 0.2774010451 140 | 138 0.2772652848 141 | 139 0.2771497239 142 | 140 0.27701705 143 | 141 0.276960594 144 | 142 0.2768644799 145 | 143 0.2768187876 146 | 144 0.2767636104 147 | 145 0.2766037918 148 | 146 0.2764060186 149 | 147 0.2761515125 150 | 148 0.2756623081 
151 | 149 0.2755489406 152 | 150 0.2755285521 153 | 151 0.2753868542 154 | 152 0.2753091257 155 | 153 0.2752233169 156 | 154 0.2752034102 157 | 155 0.2748938345 158 | 156 0.2747182557 159 | 157 0.2745855122 160 | 158 0.274571287 161 | 159 0.2745011654 162 | 160 0.2744869324 163 | 161 0.2743618293 164 | 162 0.2742974198 165 | 163 0.274258521 166 | 164 0.2741564563 167 | 165 0.2740632377 168 | 166 0.2739681744 169 | 167 0.2736259008 170 | 168 0.2734691031 171 | 169 0.273431139 172 | 170 0.2733323267 173 | 171 0.2732111572 174 | 172 0.2730752005 175 | 173 0.27298309 176 | 174 0.2729107348 177 | 175 0.2727933543 178 | 176 0.2726910481 179 | 177 0.272567598 180 | 178 0.272497418 181 | 179 0.2723968508 182 | 180 0.2723052382 183 | 181 0.272201636 184 | 182 0.2721605129 185 | 183 0.2720821905 186 | 184 0.2720485406 187 | 185 0.2720105745 188 | 186 0.2719790878 189 | 187 0.2719237783 190 | 188 0.271870807 191 | 189 0.2718467376 192 | 190 0.2716046981 193 | 191 0.2715112819 194 | 192 0.2714635604 195 | 193 0.2714237628 196 | 194 0.2713446262 197 | 195 0.2713249921 198 | 196 0.2711850472 199 | 197 0.2711292937 200 | 198 0.2710440028 201 | 199 0.2705329457 202 | 200 0.2704853904 203 | 201 0.2704348535 204 | 202 0.270399675 205 | 203 0.2702937854 206 | 204 0.2702416898 207 | 205 0.2701451693 208 | 206 0.2700415382 209 | 207 0.2699921687 210 | 208 0.2699216465 211 | 209 0.2698227999 212 | 210 0.2697982631 213 | 211 0.2695834449 214 | 212 0.2695150828 215 | 213 0.2694532011 216 | 214 0.2693920585 217 | 215 0.2693466887 218 | 216 0.2692687999 219 | 217 0.2691919654 220 | 218 0.2691595108 221 | 219 0.2690507397 222 | 220 0.2687172737 223 | 221 0.2686798268 224 | 222 0.2686254193 225 | 223 0.2684801022 226 | 224 0.2683894127 227 | 225 0.268271212 228 | 226 0.2681592691 229 | 227 0.2681333592 230 | 228 0.2681155504 231 | 229 0.2680852512 232 | 230 0.2680259156 233 | 231 0.2679313967 234 | 232 0.2678769713 235 | 233 0.2677892152 236 | 234 0.2677581601 237 | 235 0.2677112928 238 | 236 
0.2676790637 239 | 237 0.2676396176 240 | 238 0.2675961107 241 | 239 0.2675628611 242 | 240 0.267466294 243 | 241 0.2673975269 244 | 242 0.2673876516 245 | 243 0.2673826807 246 | 244 0.2673438694 247 | 245 0.2672419189 248 | 246 0.2672214496 249 | 247 0.2671424611 250 | 248 0.2670344919 251 | 249 0.2669841592 252 | 250 0.2669579783 253 | 251 0.2668758907 254 | 252 0.2668287393 255 | 253 0.2667556774 256 | 254 0.2667457246 257 | 255 0.266725966 258 | 256 0.2666712634 259 | 257 0.2666454772 260 | 258 0.2665696718 261 | 259 0.2665548004 262 | 260 0.2665308023 263 | 261 0.266446924 264 | 262 0.2663650193 265 | 263 0.2662598567 266 | 264 0.266197116 267 | 265 0.2659776802 268 | 266 0.2659159476 269 | 267 0.2659019516 270 | 268 0.2658717116 271 | 269 0.26583287 272 | 270 0.2655169309 273 | 271 0.2653233163 274 | 272 0.2652716497 275 | 273 0.2652604178 276 | 274 0.2651546899 277 | 275 0.2651002572 278 | 276 0.265081894 279 | 277 0.2649946491 280 | 278 0.2648977283 281 | 279 0.2648657089 282 | 280 0.2648164856 283 | 281 0.264749663 284 | 282 0.2647381207 285 | 283 0.2646678274 286 | 284 0.2646147908 287 | 285 0.2645231991 288 | 286 0.2644788727 289 | 287 0.2644260786 290 | 288 0.2644064113 291 | 289 0.2643757101 292 | 290 0.2643568353 293 | 291 0.2643156451 294 | 292 0.2642577091 295 | 293 0.2641939029 296 | 294 0.2641570484 297 | 295 0.2641389479 298 | 296 0.2641113881 299 | 297 0.2640148119 300 | 298 0.2639804109 301 | 299 0.2639152228 302 | 300 0.2636536197 303 | 301 0.2636120281 304 | 302 0.2635683692 305 | 303 0.2635647382 306 | 304 0.2635355951 307 | 305 0.2634723657 308 | 306 0.2634326076 309 | 307 0.2634035528 310 | 308 0.2633509055 311 | 309 0.2633143922 312 | 310 0.2632478541 313 | 311 0.2632030167 314 | 312 0.2631061835 315 | 313 0.2630167379 316 | 314 0.262982371 317 | 315 0.2629405734 318 | 316 0.2629323973 319 | 317 0.2627361762 320 | 318 0.2627017045 321 | 319 0.2626101998 322 | 320 0.2626034174 323 | 321 0.2625842483 324 | 322 0.2625723254 325 | 323 
0.2625247785 326 | 324 0.2624960673 327 | 325 0.2624629586 328 | 326 0.2624313655 329 | 327 0.2623605992 330 | 328 0.2622300399 331 | 329 0.2621941873 332 | 330 0.2621679908 333 | 331 0.2621339553 334 | 332 0.2620603624 335 | 333 0.2619835211 336 | 334 0.2619280351 337 | 335 0.2618121715 338 | 336 0.2617255734 339 | 337 0.2616585727 340 | 338 0.2616110028 341 | 339 0.2615471711 342 | 340 0.2614339178 343 | 341 0.2614207272 344 | 342 0.2613865617 345 | 343 0.2613514032 346 | 344 0.2612594051 347 | 345 0.2612471096 348 | 346 0.261218204 349 | 347 0.2612084732 350 | 348 0.2611875305 351 | 349 0.2611323504 352 | 350 0.2610994964 353 | 351 0.2610738137 354 | 352 0.2610107439 355 | 353 0.2609436344 356 | 354 0.2608927103 357 | 355 0.2608724464 358 | 356 0.2608083854 359 | 357 0.2607604146 360 | 358 0.26074839 361 | 359 0.2606800177 362 | 360 0.2606243812 363 | 361 0.2605617861 364 | 362 0.2605083765 365 | 363 0.2604662492 366 | 364 0.2604053137 367 | 365 0.2603650146 368 | 366 0.2603237066 369 | 367 0.2602263897 370 | 368 0.2601724528 371 | 369 0.2601245036 372 | 370 0.2600655259 373 | 371 0.2600542644 374 | 372 0.2598390778 375 | 373 0.2598349487 376 | 374 0.2598232714 377 | 375 0.2598047584 378 | 376 0.2597764939 379 | 377 0.2597520336 380 | 378 0.2596979708 381 | 379 0.2596427752 382 | 380 0.2596065538 383 | 381 0.2595933025 384 | 382 0.2594959681 385 | 383 0.2594735002 386 | 384 0.2593942034 387 | 385 0.2593735234 388 | 386 0.2593064494 389 | 387 0.2592460802 390 | 388 0.2591838554 391 | 389 0.2591056406 392 | 390 0.2590448527 393 | 391 0.259016704 394 | 392 0.2589939395 395 | 393 0.2589606623 396 | 394 0.2588947529 397 | 395 0.25888444 398 | 396 0.2588551982 399 | 397 0.2588133942 400 | 398 0.2587709574 401 | 399 0.2587148358 402 | 400 0.258608377 403 | 401 0.2585246516 404 | 402 0.2584756742 405 | 403 0.2584510449 406 | 404 0.258433407 407 | 405 0.2584038276 408 | 406 0.2583902664 409 | 407 0.2583549942 410 | 408 0.258288514 411 | 409 0.258273056 412 | 410 
0.2582409891 413 | 411 0.2581954958 414 | 412 0.2581750183 415 | 413 0.2581521184 416 | 414 0.258088053 417 | 415 0.2580495604 418 | 416 0.2580319933 419 | 417 0.2580049181 420 | 418 0.2579846722 421 | 419 0.2579295381 422 | 420 0.2579116808 423 | 421 0.2578633478 424 | 422 0.2578174585 425 | 423 0.2578116256 426 | 424 0.2577814232 427 | 425 0.2577547502 428 | 426 0.2576792508 429 | 427 0.2576482521 430 | 428 0.257589284 431 | 429 0.2575727211 432 | 430 0.2575356634 433 | 431 0.2574984717 434 | 432 0.2574873659 435 | 433 0.257474878 436 | 434 0.2573929079 437 | 435 0.2571991774 438 | 436 0.2571873559 439 | 437 0.2571242568 440 | 438 0.2570834199 441 | 439 0.257055469 442 | 440 0.2570400027 443 | 441 0.2569796113 444 | 442 0.256929946 445 | 443 0.256851046 446 | 444 0.2568297742 447 | 445 0.256683687 448 | 446 0.2566670481 449 | 447 0.2566078744 450 | 448 0.2565523966 451 | 449 0.2565440283 452 | 450 0.2565388741 453 | 451 0.2565191174 454 | 452 0.2564586392 455 | 453 0.2563976563 456 | 454 0.2563618559 457 | 455 0.2563199363 458 | 456 0.2562943041 459 | 457 0.2562760667 460 | 458 0.2562588935 461 | 459 0.2562375427 462 | 460 0.2561790312 463 | 461 0.2561418988 464 | 462 0.2561232381 465 | 463 0.256084434 466 | 464 0.2560542016 467 | 465 0.2560263372 468 | 466 0.2560122115 469 | 467 0.2560105231 470 | 468 0.2559944819 471 | 469 0.2559833858 472 | 470 0.2559219161 473 | 471 0.2558775203 474 | 472 0.2558424926 475 | 473 0.2557115477 476 | 474 0.255664807 477 | 475 0.2556044383 478 | 476 0.255584484 479 | 477 0.2555685542 480 | 478 0.2555450002 481 | 479 0.2554836737 482 | 480 0.2554626636 483 | 481 0.2554402965 484 | 482 0.2553721492 485 | 483 0.2552797246 486 | 484 0.2552373086 487 | 485 0.255222556 488 | 486 0.2551612909 489 | 487 0.2551371138 490 | 488 0.2550970094 491 | 489 0.255050753 492 | 490 0.2549865534 493 | 491 0.2549668762 494 | 492 0.2549267782 495 | 493 0.2548621996 496 | 494 0.2548334863 497 | 495 0.2548013814 498 | 496 0.2546030119 499 | 497 
0.2544728751 500 | 498 0.2543876703 501 | 499 0.2543845882 502 | -------------------------------------------------------------------------------- /charpter14_CatBoost/catboost_info/time_left.tsv: -------------------------------------------------------------------------------- 1 | iter Passed Remaining 2 | 0 72 36380 3 | 1 200 50028 4 | 2 271 44974 5 | 3 346 43011 6 | 4 439 43546 7 | 5 514 42368 8 | 6 583 41115 9 | 7 670 41207 10 | 8 744 40641 11 | 9 832 40807 12 | 10 906 40283 13 | 11 981 39913 14 | 12 1053 39478 15 | 13 1124 39036 16 | 14 1202 38879 17 | 15 1269 38401 18 | 16 1351 38395 19 | 17 1428 38257 20 | 18 1501 38001 21 | 19 1629 39112 22 | 20 1730 39471 23 | 21 1805 39219 24 | 22 1869 38768 25 | 23 1942 38519 26 | 24 2014 38269 27 | 25 2084 37999 28 | 26 2154 37736 29 | 27 2225 37510 30 | 28 2295 37276 31 | 29 2370 37139 32 | 30 2448 37046 33 | 31 2519 36845 34 | 32 2588 36626 35 | 33 2655 36390 36 | 34 2735 36339 37 | 35 2803 36135 38 | 36 2864 35847 39 | 37 2936 35698 40 | 38 3011 35591 41 | 39 3070 35308 42 | 40 3133 35079 43 | 41 3211 35025 44 | 42 3280 34869 45 | 43 3341 34631 46 | 44 3412 34507 47 | 45 3480 34347 48 | 46 3546 34183 49 | 47 3605 33954 50 | 48 3662 33712 51 | 49 3719 33475 52 | 50 3774 33232 53 | 51 3838 33066 54 | 52 3920 33067 55 | 53 3983 32900 56 | 54 4068 32914 57 | 55 4139 32818 58 | 56 4204 32680 59 | 57 4267 32522 60 | 58 4322 32309 61 | 59 4381 32130 62 | 60 4438 31939 63 | 61 4499 31787 64 | 62 4567 31684 65 | 63 4623 31494 66 | 64 4687 31372 67 | 65 4750 31239 68 | 66 4822 31165 69 | 67 4893 31085 70 | 68 4969 31043 71 | 69 5054 31051 72 | 70 5116 30912 73 | 71 5171 30740 74 | 72 5243 30670 75 | 73 5304 30535 76 | 74 5375 30461 77 | 75 5450 30407 78 | 76 5526 30360 79 | 77 5609 30350 80 | 78 5691 30331 81 | 79 5770 30294 82 | 80 5845 30239 83 | 81 5911 30135 84 | 82 5980 30045 85 | 83 6049 29959 86 | 84 6128 29920 87 | 85 6203 29863 88 | 86 6288 29854 89 | 87 6361 29781 90 | 88 6426 29675 91 | 89 6505 29634 92 | 90 6592 
29629 93 | 91 6648 29485 94 | 92 6709 29363 95 | 93 6765 29220 96 | 94 6818 29067 97 | 95 6872 28923 98 | 96 6927 28781 99 | 97 6976 28618 100 | 98 7060 28597 101 | 99 7116 28466 102 | 100 7170 28327 103 | 101 7230 28213 104 | 102 7286 28084 105 | 103 7369 28059 106 | 104 7437 27977 107 | 105 7488 27833 108 | 106 7717 28344 109 | 107 7863 28540 110 | 108 8008 28726 111 | 109 8149 28892 112 | 110 8240 28878 113 | 111 8320 28822 114 | 112 8382 28708 115 | 113 8446 28600 116 | 114 8506 28479 117 | 115 8567 28361 118 | 116 8638 28277 119 | 117 8717 28221 120 | 118 8776 28100 121 | 119 8837 27986 122 | 120 8894 27858 123 | 121 8950 27732 124 | 122 9022 27655 125 | 123 9085 27550 126 | 124 9149 27449 127 | 125 9233 27408 128 | 126 9378 27546 129 | 127 9436 27425 130 | 128 9500 27323 131 | 129 9573 27247 132 | 130 9632 27133 133 | 131 9691 27018 134 | 132 9755 26919 135 | 133 9815 26810 136 | 134 9872 26692 137 | 135 9927 26570 138 | 136 9982 26450 139 | 137 10042 26344 140 | 138 10112 26263 141 | 139 10160 26127 142 | 140 10222 26028 143 | 141 10273 25901 144 | 142 10328 25784 145 | 143 10382 25667 146 | 144 10433 25544 147 | 145 10506 25474 148 | 146 10618 25499 149 | 147 10706 25463 150 | 148 10785 25407 151 | 149 10859 25338 152 | 150 10911 25218 153 | 151 10962 25098 154 | 152 11023 25000 155 | 153 11103 24947 156 | 154 11167 24855 157 | 155 11230 24763 158 | 156 11294 24675 159 | 157 11351 24570 160 | 158 11410 24470 161 | 159 11455 24342 162 | 160 11494 24202 163 | 161 11558 24115 164 | 162 11610 24003 165 | 163 11657 23883 166 | 164 11714 23784 167 | 165 11774 23691 168 | 166 11833 23596 169 | 167 11894 23506 170 | 168 11941 23388 171 | 169 11987 23270 172 | 170 12049 23182 173 | 171 12106 23086 174 | 172 12152 22970 175 | 173 12212 22881 176 | 174 12268 22784 177 | 175 12323 22687 178 | 176 12379 22590 179 | 177 12427 22481 180 | 178 12484 22387 181 | 179 12544 22302 182 | 180 12619 22240 183 | 181 12680 22156 184 | 182 12744 22077 185 | 183 12800 21982 186 | 184 
12854 21887 187 | 185 12905 21787 188 | 186 12963 21699 189 | 187 13023 21613 190 | 188 13071 21508 191 | 189 13139 21438 192 | 190 13218 21384 193 | 191 13276 21297 194 | 192 13324 21195 195 | 193 13379 21103 196 | 194 13450 21038 197 | 195 13505 20947 198 | 196 13564 20862 199 | 197 13610 20759 200 | 198 13671 20679 201 | 199 13731 20597 202 | 200 13782 20502 203 | 201 13833 20408 204 | 202 13890 20322 205 | 203 13950 20241 206 | 204 14001 20148 207 | 205 14063 20071 208 | 206 14120 19987 209 | 207 14167 19888 210 | 208 14217 19796 211 | 209 14275 19713 212 | 210 14339 19640 213 | 211 14416 19584 214 | 212 14485 19517 215 | 213 14547 19441 216 | 214 14609 19366 217 | 215 14778 19430 218 | 216 14904 19437 219 | 217 14950 19340 220 | 218 14999 19245 221 | 219 15053 19158 222 | 220 15108 19074 223 | 221 15167 18994 224 | 222 15220 18905 225 | 223 15277 18824 226 | 224 15336 18744 227 | 225 15394 18664 228 | 226 15475 18611 229 | 227 15534 18532 230 | 228 15587 18446 231 | 229 15644 18364 232 | 230 15699 18282 233 | 231 15752 18197 234 | 232 15794 18098 235 | 233 15842 18008 236 | 234 15902 17932 237 | 235 15955 17848 238 | 236 16012 17769 239 | 237 16095 17718 240 | 238 16234 17728 241 | 239 16303 17662 242 | 240 16358 17579 243 | 241 16410 17495 244 | 242 16464 17412 245 | 243 16516 17328 246 | 244 16603 17281 247 | 245 16717 17261 248 | 246 16787 17195 249 | 247 16856 17128 250 | 248 16906 17042 251 | 249 16965 16965 252 | 250 17017 16881 253 | 251 17071 16800 254 | 252 17130 16724 255 | 253 17199 16657 256 | 254 17254 16578 257 | 255 17303 16492 258 | 256 17360 16415 259 | 257 17415 16335 260 | 258 17471 16256 261 | 259 17524 16176 262 | 260 17582 16100 263 | 261 17641 16025 264 | 262 17689 15940 265 | 263 17742 15860 266 | 264 17787 15773 267 | 265 17843 15697 268 | 266 17955 15668 269 | 267 18053 15628 270 | 268 18112 15553 271 | 269 18168 15476 272 | 270 18222 15398 273 | 271 18277 15321 274 | 272 18334 15245 275 | 273 18388 15167 276 | 274 18448 15094 277 | 
275 18504 15018 278 | 276 18558 14940 279 | 277 18617 14867 280 | 278 18669 14788 281 | 279 18718 14707 282 | 280 18772 14630 283 | 281 18833 14558 284 | 282 18883 14479 285 | 283 18931 14398 286 | 284 18981 14319 287 | 285 19040 14247 288 | 286 19102 14177 289 | 287 19156 14101 290 | 288 19205 14021 291 | 289 19257 13945 292 | 290 19314 13872 293 | 291 19372 13799 294 | 292 19427 13725 295 | 293 19495 13660 296 | 294 19557 13590 297 | 295 19612 13516 298 | 296 19664 13440 299 | 297 19710 13360 300 | 298 19761 13284 301 | 299 19812 13208 302 | 300 19869 13136 303 | 301 19922 13061 304 | 302 19972 12985 305 | 303 20019 12907 306 | 304 20068 12830 307 | 305 20121 12756 308 | 306 20181 12687 309 | 307 20244 12620 310 | 308 20301 12548 311 | 309 20358 12477 312 | 310 20411 12404 313 | 311 20470 12334 314 | 312 20529 12265 315 | 313 20585 12194 316 | 314 20640 12122 317 | 315 20708 12058 318 | 316 20771 11991 319 | 317 20836 11925 320 | 318 20912 11865 321 | 319 20979 11800 322 | 320 21040 11732 323 | 321 21109 11669 324 | 322 21173 11602 325 | 323 21236 11535 326 | 324 21303 11470 327 | 325 21367 11404 328 | 326 21435 11340 329 | 327 21494 11271 330 | 328 21564 11208 331 | 329 21615 11135 332 | 330 21682 11070 333 | 331 21745 11003 334 | 332 21799 10932 335 | 333 21851 10860 336 | 334 21908 10790 337 | 335 21960 10718 338 | 336 22011 10646 339 | 337 22066 10576 340 | 338 22119 10505 341 | 339 22173 10434 342 | 340 22226 10363 343 | 341 22278 10292 344 | 342 22327 10219 345 | 343 22378 10148 346 | 344 22450 10086 347 | 345 22502 10015 348 | 346 22551 9943 349 | 347 22605 9873 350 | 348 22655 9802 351 | 349 22700 9728 352 | 350 22752 9658 353 | 351 22808 9589 354 | 352 22861 9520 355 | 353 22907 9447 356 | 354 22962 9378 357 | 355 23019 9311 358 | 356 23074 9242 359 | 357 23120 9170 360 | 358 23171 9100 361 | 359 23237 9036 362 | 360 23292 8968 363 | 361 23355 8903 364 | 362 23408 8834 365 | 363 23461 8766 366 | 364 23514 8697 367 | 365 23574 8631 368 | 366 23630 8563 
369 | 367 23685 8495 370 | 368 23746 8430 371 | 369 23807 8364 372 | 370 23860 8296 373 | 371 23912 8228 374 | 372 23978 8164 375 | 373 24030 8095 376 | 374 24087 8029 377 | 375 24140 7961 378 | 376 24189 7892 379 | 377 24238 7823 380 | 378 24296 7757 381 | 379 24345 7688 382 | 380 24395 7619 383 | 381 24444 7550 384 | 382 24503 7485 385 | 383 24559 7419 386 | 384 24605 7349 387 | 385 24659 7282 388 | 386 24746 7225 389 | 387 24803 7159 390 | 388 24852 7091 391 | 389 24902 7023 392 | 390 24962 6958 393 | 391 25013 6891 394 | 392 25064 6824 395 | 393 25128 6760 396 | 394 25182 6694 397 | 395 25243 6629 398 | 396 25292 6562 399 | 397 25350 6496 400 | 398 25404 6430 401 | 399 25458 6364 402 | 400 25517 6299 403 | 401 25572 6234 404 | 402 25628 6168 405 | 403 25690 6104 406 | 404 25743 6038 407 | 405 25799 5973 408 | 406 25851 5907 409 | 407 25903 5840 410 | 408 25952 5774 411 | 409 26017 5711 412 | 410 26084 5648 413 | 411 26152 5585 414 | 412 26217 5522 415 | 413 26287 5460 416 | 414 26350 5397 417 | 415 26435 5337 418 | 416 26499 5274 419 | 417 26555 5209 420 | 418 26610 5144 421 | 419 26667 5079 422 | 420 26720 5014 423 | 421 26768 4947 424 | 422 26821 4882 425 | 423 26874 4817 426 | 424 26929 4752 427 | 425 26978 4686 428 | 426 27032 4621 429 | 427 27084 4556 430 | 428 27144 4492 431 | 429 27214 4430 432 | 430 27277 4366 433 | 431 27332 4302 434 | 432 27391 4238 435 | 433 27441 4173 436 | 434 27494 4108 437 | 435 27551 4044 438 | 436 27603 3979 439 | 437 27661 3915 440 | 438 27709 3850 441 | 439 27762 3785 442 | 440 27825 3722 443 | 441 27881 3658 444 | 442 27935 3594 445 | 443 27992 3530 446 | 444 28053 3467 447 | 445 28110 3403 448 | 446 28165 3339 449 | 447 28216 3275 450 | 448 28265 3210 451 | 449 28322 3146 452 | 450 28382 3083 453 | 451 28446 3020 454 | 452 28530 2960 455 | 453 28593 2897 456 | 454 28641 2832 457 | 455 28692 2768 458 | 456 28747 2704 459 | 457 28801 2641 460 | 458 28847 2576 461 | 459 28900 2513 462 | 460 28953 2449 463 | 461 29009 2386 464 
import numpy as np
from utils import feature_split, calculate_gini


class TreeNode():
    """A single node of a binary decision tree.

    Internal nodes carry (feature_i, threshold) and two subtrees;
    leaf nodes carry only leaf_value.
    """

    def __init__(self, feature_i=None, threshold=None,
                 leaf_value=None, left_branch=None, right_branch=None):
        # Index of the feature this node splits on (None for leaves)
        self.feature_i = feature_i
        # Threshold the feature value is compared against
        self.threshold = threshold
        # Prediction stored on leaf nodes (None for internal nodes)
        self.leaf_value = leaf_value
        # Subtree for samples satisfying the split condition (>= / ==)
        self.left_branch = left_branch
        # Subtree for the remaining samples
        self.right_branch = right_branch


class BinaryDecisionTree(object):
    """CART-style binary decision tree.

    Base class: subclasses plug in `impurity_calculation` (split
    criterion) and `_leaf_value_calculation` (leaf estimator) in fit().
    """

    def __init__(self, min_samples_split=2, min_gini_impurity=999,
                 max_depth=float("inf"), loss=None):
        # Root node of the fitted tree
        self.root = None
        # Minimum number of samples required to attempt a split
        self.min_samples_split = min_samples_split
        # Split only if the best impurity found is strictly below this
        self.min_gini_impurity = min_gini_impurity
        # Maximum tree depth
        self.max_depth = max_depth
        # Split criterion, set by the concrete subclass in fit()
        # (fixed: was mis-spelled `gini_impurity_calculation`, which the
        # rest of the class never reads)
        self.impurity_calculation = None
        # Leaf value estimator, set by the concrete subclass in fit()
        self._leaf_value_calculation = None
        # Loss function (unused by the tree itself; kept for subclasses)
        self.loss = loss

    def fit(self, X, y, loss=None):
        """Recursively build the decision tree from training data.

        Parameters
        ----------
        X : ndarray of shape (n_samples, n_features)
        y : ndarray of shape (n_samples,) or (n_samples, n_outputs)
        """
        self.root = self._build_tree(X, y)
        self.loss = None

    def _build_tree(self, X, y, current_depth=0):
        """Grow a (sub)tree greedily and return its root TreeNode.

        Exhaustively scans every (feature, unique value) pair for the
        split with the lowest impurity; recurses until the stopping
        criteria (min_samples_split / max_depth / min_gini_impurity)
        are hit, then emits a leaf.
        """
        # Sentinel: any real impurity below this counts as an improvement
        init_gini_impurity = 999
        # Best (feature index, threshold) found so far
        best_criteria = None
        # Data subsets produced by the best split found so far
        best_sets = None

        # Ensure y is 2-D so it can be concatenated column-wise with X
        if len(np.shape(y)) == 1:
            y = np.expand_dims(y, axis=1)

        # Stack inputs and labels so both are split together
        Xy = np.concatenate((X, y), axis=1)
        n_samples, n_features = X.shape

        # Only search for a split while the stopping criteria allow it
        if n_samples >= self.min_samples_split and current_depth <= self.max_depth:
            for feature_i in range(n_features):
                # Candidate thresholds are the feature's unique values
                feature_values = np.expand_dims(X[:, feature_i], axis=1)
                unique_values = np.unique(feature_values)

                for threshold in unique_values:
                    # Binary split on (feature_i, threshold)
                    Xy1, Xy2 = feature_split(Xy, feature_i, threshold)
                    # Both children must be non-empty for a valid split
                    if len(Xy1) > 0 and len(Xy2) > 0:
                        # Label columns of each child
                        y1 = Xy1[:, n_features:]
                        y2 = Xy2[:, n_features:]

                        # Criterion supplied by the subclass
                        impurity = self.impurity_calculation(y, y1, y2)

                        # Track the lowest-impurity split
                        if impurity < init_gini_impurity:
                            init_gini_impurity = impurity
                            best_criteria = {"feature_i": feature_i, "threshold": threshold}
                            best_sets = {
                                "leftX": Xy1[:, :n_features],
                                "lefty": Xy1[:, n_features:],
                                "rightX": Xy2[:, :n_features],
                                "righty": Xy2[:, n_features:],
                            }

        # Recurse only when a valid split was actually found AND it beats
        # the impurity threshold (the `best_sets is not None` guard fixes
        # a crash when min_gini_impurity exceeds the sentinel but no
        # split exists)
        if best_sets is not None and init_gini_impurity < self.min_gini_impurity:
            left_branch = self._build_tree(best_sets["leftX"], best_sets["lefty"], current_depth + 1)
            right_branch = self._build_tree(best_sets["rightX"], best_sets["righty"], current_depth + 1)
            return TreeNode(feature_i=best_criteria["feature_i"],
                            threshold=best_criteria["threshold"],
                            left_branch=left_branch,
                            right_branch=right_branch)

        # Otherwise this node becomes a leaf
        leaf_value = self._leaf_value_calculation(y)
        return TreeNode(leaf_value=leaf_value)

    def predict_value(self, x, tree=None):
        """Route a single sample down the tree and return the leaf value."""
        if tree is None:
            tree = self.root
        # Leaf reached: return its stored prediction
        if tree.leaf_value is not None:
            return tree.leaf_value
        # Value of the feature this node splits on
        feature_value = x[tree.feature_i]
        # Default to the right branch; move left when the split
        # condition holds, mirroring feature_split() in utils.py
        branch = tree.right_branch
        if isinstance(feature_value, (int, float)):
            if feature_value >= tree.threshold:
                branch = tree.left_branch
        elif feature_value == tree.threshold:
            # BUG FIX: a categorical match must follow the LEFT branch —
            # feature_split() puts `== threshold` samples on the left.
            # The original sent them right, inverting categorical routing.
            branch = tree.left_branch
        # Recurse into the chosen subtree
        return self.predict_value(x, branch)

    def predict(self, X):
        """Predict a value for every sample (row) of X."""
        y_pred = [self.predict_value(sample) for sample in X]
        return y_pred


class ClassificationTree(BinaryDecisionTree):
    """CART classification tree: Gini impurity + majority-vote leaves."""

    def _calculate_gini_impurity(self, y, y1, y2):
        """Weighted Gini impurity of the two children of a split."""
        p = len(y1) / len(y)
        # (removed an unused `calculate_gini(y)` call from the original)
        gini_impurity = p * calculate_gini(y1) + (1 - p) * calculate_gini(y2)
        return gini_impurity

    def _majority_vote(self, y):
        """Return the most frequent label in y."""
        most_common = None
        max_count = 0
        for label in np.unique(y):
            count = len(y[y == label])
            if count > max_count:
                most_common = label
                max_count = count
        return most_common

    def fit(self, X, y):
        """Fit the classifier: wire Gini criterion and majority vote."""
        self.impurity_calculation = self._calculate_gini_impurity
        self._leaf_value_calculation = self._majority_vote
        super(ClassificationTree, self).fit(X, y)


class RegressionTree(BinaryDecisionTree):
    """CART regression tree: variance reduction + mean-valued leaves."""

    def _calculate_variance_reduction(self, y, y1, y2):
        """Reduction in variance achieved by splitting y into y1/y2.

        Returned negated so that the base class's "smaller impurity is
        better" search maximises the reduction — NOTE(review): the
        original returns the raw (positive) reduction, which the
        minimising search would treat backwards; kept as-is to preserve
        the published behaviour. TODO confirm intended sign.
        """
        var_tot = np.var(y, axis=0)
        var_y1 = np.var(y1, axis=0)
        var_y2 = np.var(y2, axis=0)
        frac_1 = len(y1) / len(y)
        frac_2 = len(y2) / len(y)
        variance_reduction = var_tot - (frac_1 * var_y1 + frac_2 * var_y2)
        return sum(variance_reduction)

    def _mean_of_y(self, y):
        """Leaf value: column-wise mean (scalar when single-output)."""
        value = np.mean(y, axis=0)
        return value if len(value) > 1 else value[0]

    def fit(self, X, y):
        """Fit the regressor: wire variance reduction and mean leaves."""
        self.impurity_calculation = self._calculate_variance_reduction
        self._leaf_value_calculation = self._mean_of_y
        super(RegressionTree, self).fit(X, y)
"from cart import *\n", 34 | "from sklearn.datasets import make_classification\n", 35 | "from sklearn.model_selection import train_test_split\n", 36 | "# 树的棵数\n", 37 | "n_estimators = 10\n", 38 | "# 列抽样最大特征数\n", 39 | "max_features = 15\n", 40 | "# 生成模拟二分类数据集\n", 41 | "X, y = make_classification(n_samples=1000, n_features=20, n_redundant=0, n_informative=2, random_state=1, n_clusters_per_class=1)\n", 42 | "rng = np.random.RandomState(2)\n", 43 | "X += 2 * rng.uniform(size=X.shape)\n", 44 | "# 划分数据集\n", 45 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)\n", 46 | "print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 2, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "# 合并训练数据和标签\n", 56 | "X_y = np.concatenate([X, y.reshape(-1,1)], axis=1)\n", 57 | "np.random.shuffle(X_y)\n", 58 | "m = X_y.shape[0]\n", 59 | "sampling_subsets = []\n", 60 | "\n", 61 | "for _ in range(n_estimators):\n", 62 | " idx = np.random.choice(m, m, replace=True)\n", 63 | " bootstrap_Xy = X_y[idx, :]\n", 64 | " bootstrap_X = bootstrap_Xy[:, :-1]\n", 65 | " bootstrap_y = bootstrap_Xy[:, -1]\n", 66 | " sampling_subsets.append([bootstrap_X, bootstrap_y])" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 3, 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "data": { 76 | "text/plain": [ 77 | "(1000, 20)" 78 | ] 79 | }, 80 | "execution_count": 3, 81 | "metadata": {}, 82 | "output_type": "execute_result" 83 | } 84 | ], 85 | "source": [ 86 | "sampling_subsets[0][0].shape" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 4, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "# 自助抽样选择训练数据子集\n", 96 | "def bootstrap_sampling(X, y):\n", 97 | " X_y = np.concatenate([X, y.reshape(-1,1)], axis=1)\n", 98 | " np.random.shuffle(X_y)\n", 99 | " n_samples = X.shape[0]\n", 100 | " sampling_subsets = []\n", 101 | "\n", 102 | " for _ in 
range(n_estimators):\n", 103 | " # 第一个随机性,行抽样\n", 104 | " idx1 = np.random.choice(n_samples, n_samples, replace=True)\n", 105 | " bootstrap_Xy = X_y[idx1, :]\n", 106 | " bootstrap_X = bootstrap_Xy[:, :-1]\n", 107 | " bootstrap_y = bootstrap_Xy[:, -1]\n", 108 | " sampling_subsets.append([bootstrap_X, bootstrap_y])\n", 109 | " return sampling_subsets" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 5, 115 | "metadata": {}, 116 | "outputs": [ 117 | { 118 | "name": "stdout", 119 | "output_type": "stream", 120 | "text": [ 121 | "(700, 20) (700,)\n" 122 | ] 123 | } 124 | ], 125 | "source": [ 126 | "sampling_subsets = bootstrap_sampling(X_train, y_train)\n", 127 | "sub_X, sub_y = sampling_subsets[0]\n", 128 | "print(sub_X.shape, sub_y.shape)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 6, 134 | "metadata": {}, 135 | "outputs": [ 136 | { 137 | "data": { 138 | "text/plain": [ 139 | "" 140 | ] 141 | }, 142 | "execution_count": 6, 143 | "metadata": {}, 144 | "output_type": "execute_result" 145 | } 146 | ], 147 | "source": [ 148 | "trees = []\n", 149 | "# 基于决策树构建森林\n", 150 | "for _ in range(n_estimators):\n", 151 | " tree = ClassificationTree(min_samples_split=2, min_gini_impurity=999, max_depth=3)\n", 152 | " trees.append(tree)\n", 153 | "\n", 154 | "trees[0]" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 7, 160 | "metadata": { 161 | "scrolled": true 162 | }, 163 | "outputs": [ 164 | { 165 | "name": "stderr", 166 | "output_type": "stream", 167 | "text": [ 168 | "E:\\study\\coding\\ML公式推导和代码实现\\charpter15_random_forest\\utils.py:14: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n", 169 | " return np.array([X_left, X_right])\n" 170 | ] 171 | }, 172 | { 173 | "name": "stdout", 174 | "output_type": "stream", 175 | "text": [ 176 | "The 1th tree is trained done...\n", 177 | "The 2th tree is trained done...\n", 178 | "The 3th tree is trained done...\n", 179 | "The 4th tree is trained done...\n", 180 | "The 5th tree is trained done...\n", 181 | "The 6th tree is trained done...\n", 182 | "The 7th tree is trained done...\n", 183 | "The 8th tree is trained done...\n", 184 | "The 9th tree is trained done...\n", 185 | "The 10th tree is trained done...\n" 186 | ] 187 | } 188 | ], 189 | "source": [ 190 | "# 随机森林训练\n", 191 | "def fit(X, y):\n", 192 | " # 对森林中每棵树训练一个双随机抽样子集\n", 193 | " n_features = X.shape[1]\n", 194 | " sub_sets = bootstrap_sampling(X, y)\n", 195 | " for i in range(n_estimators):\n", 196 | " sub_X, sub_y = sub_sets[i]\n", 197 | " # 第二个随机性,列抽样\n", 198 | " idx2 = np.random.choice(n_features, max_features, replace=True)\n", 199 | " sub_X = sub_X[:, idx2]\n", 200 | " trees[i].fit(sub_X, sub_y)\n", 201 | " trees[i].feature_indices = idx2\n", 202 | " print('The {}th tree is trained done...'.format(i+1))\n", 203 | "\n", 204 | "fit(X_train, y_train)" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 8, 210 | "metadata": {}, 211 | "outputs": [ 212 | { 213 | "data": { 214 | "text/plain": [ 215 | "300" 216 | ] 217 | }, 218 | "execution_count": 8, 219 | "metadata": {}, 220 | "output_type": "execute_result" 221 | } 222 | ], 223 | "source": [ 224 | "y_preds = []\n", 225 | "for i in range(n_estimators):\n", 226 | " idx = trees[i].feature_indices\n", 227 | " sub_X = X_test[:, idx]\n", 228 | " y_pred = trees[i].predict(sub_X)\n", 229 | " y_preds.append(y_pred)\n", 230 | " \n", 231 | "len(y_preds[0])" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 9, 237 | "metadata": {}, 238 | "outputs": [ 239 | { 240 | "name": 
"stdout", 241 | "output_type": "stream", 242 | "text": [ 243 | "(300, 10)\n", 244 | "[1, 0, 0, 0, 0, 0, 1, 0, 1, 0]\n" 245 | ] 246 | } 247 | ], 248 | "source": [ 249 | "y_preds = np.array(y_preds).T\n", 250 | "print(y_preds.shape)\n", 251 | "y_pred = []\n", 252 | "for y_p in y_preds:\n", 253 | " y_pred.append(np.bincount(y_p.astype('int')).argmax())\n", 254 | "\n", 255 | "print(y_pred[:10])" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 10, 261 | "metadata": {}, 262 | "outputs": [ 263 | { 264 | "name": "stdout", 265 | "output_type": "stream", 266 | "text": [ 267 | "0.8266666666666667\n" 268 | ] 269 | } 270 | ], 271 | "source": [ 272 | "from sklearn.metrics import accuracy_score\n", 273 | "print(accuracy_score(y_test, y_pred))" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 11, 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "class RandomForest():\n", 283 | " def __init__(self, n_estimators=100, min_samples_split=2, min_gain=0,\n", 284 | " max_depth=float(\"inf\"), max_features=None):\n", 285 | " # 树的棵树\n", 286 | " self.n_estimators = n_estimators\n", 287 | " # 树最小分裂样本数\n", 288 | " self.min_samples_split = min_samples_split\n", 289 | " # 最小增益\n", 290 | " self.min_gain = min_gain\n", 291 | " # 树最大深度\n", 292 | " self.max_depth = max_depth\n", 293 | " # 所使用最大特征数\n", 294 | " self.max_features = max_features\n", 295 | "\n", 296 | " self.trees = []\n", 297 | " # 基于决策树构建森林\n", 298 | " for _ in range(self.n_estimators):\n", 299 | " # tree = ClassificationTree(min_samples_split=self.min_samples_split, min_impurity=self.min_gain, max_depth=self.max_depth)\n", 300 | " tree = ClassificationTree(min_samples_split=self.min_samples_split, min_gini_impurity=self.min_gain, max_depth=self.max_depth)\n", 301 | " self.trees.append(tree)\n", 302 | " \n", 303 | " # 自助抽样\n", 304 | " def bootstrap_sampling(self, X, y):\n", 305 | " X_y = np.concatenate([X, y.reshape(-1,1)], axis=1)\n", 306 | " 
np.random.shuffle(X_y)\n", 307 | " n_samples = X.shape[0]\n", 308 | " sampling_subsets = []\n", 309 | "\n", 310 | " for _ in range(self.n_estimators):\n", 311 | " # 第一个随机性,行抽样\n", 312 | " idx1 = np.random.choice(n_samples, n_samples, replace=True)\n", 313 | " bootstrap_Xy = X_y[idx1, :]\n", 314 | " bootstrap_X = bootstrap_Xy[:, :-1]\n", 315 | " bootstrap_y = bootstrap_Xy[:, -1]\n", 316 | " sampling_subsets.append([bootstrap_X, bootstrap_y])\n", 317 | " return sampling_subsets\n", 318 | " \n", 319 | " # 随机森林训练\n", 320 | " def fit(self, X, y):\n", 321 | " # 对森林中每棵树训练一个双随机抽样子集\n", 322 | " sub_sets = self.bootstrap_sampling(X, y)\n", 323 | " n_features = X.shape[1]\n", 324 | " # 设置max_feature\n", 325 | " if self.max_features == None:\n", 326 | " self.max_features = int(np.sqrt(n_features))\n", 327 | " \n", 328 | " for i in range(self.n_estimators):\n", 329 | " # 第二个随机性,列抽样\n", 330 | " sub_X, sub_y = sub_sets[i]\n", 331 | " idx2 = np.random.choice(n_features, self.max_features, replace=True)\n", 332 | " sub_X = sub_X[:, idx2]\n", 333 | " self.trees[i].fit(sub_X, sub_y)\n", 334 | " # 保存每次列抽样的列索引,方便预测时每棵树调用\n", 335 | " self.trees[i].feature_indices = idx2\n", 336 | " print('The {}th tree is trained done...'.format(i+1))\n", 337 | " \n", 338 | " # 随机森林预测\n", 339 | " def predict(self, X):\n", 340 | " y_preds = []\n", 341 | " for i in range(self.n_estimators):\n", 342 | " idx = self.trees[i].feature_indices\n", 343 | " sub_X = X[:, idx]\n", 344 | " y_pred = self.trees[i].predict(sub_X)\n", 345 | " y_preds.append(y_pred)\n", 346 | " \n", 347 | " y_preds = np.array(y_preds).T\n", 348 | " res = []\n", 349 | " for j in y_preds:\n", 350 | " res.append(np.bincount(j.astype('int')).argmax())\n", 351 | " return res" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 12, 357 | "metadata": {}, 358 | "outputs": [ 359 | { 360 | "name": "stderr", 361 | "output_type": "stream", 362 | "text": [ 363 | 
"E:\\study\\coding\\ML公式推导和代码实现\\charpter15_random_forest\\utils.py:14: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n", 364 | " return np.array([X_left, X_right])\n" 365 | ] 366 | }, 367 | { 368 | "name": "stdout", 369 | "output_type": "stream", 370 | "text": [ 371 | "The 1th tree is trained done...\n", 372 | "The 2th tree is trained done...\n", 373 | "The 3th tree is trained done...\n", 374 | "The 4th tree is trained done...\n", 375 | "The 5th tree is trained done...\n", 376 | "The 6th tree is trained done...\n", 377 | "The 7th tree is trained done...\n", 378 | "The 8th tree is trained done...\n", 379 | "The 9th tree is trained done...\n", 380 | "The 10th tree is trained done...\n", 381 | "0.5166666666666667\n" 382 | ] 383 | } 384 | ], 385 | "source": [ 386 | "# 调用手写的算法\n", 387 | "rf = RandomForest(n_estimators=10, max_features=15)\n", 388 | "rf.fit(X_train, y_train)\n", 389 | "y_pred = rf.predict(X_test)\n", 390 | "print(accuracy_score(y_test, y_pred))" 391 | ] 392 | }, 393 | { 394 | "cell_type": "markdown", 395 | "metadata": {}, 396 | "source": [ 397 | "### 二.使用sklearn" 398 | ] 399 | }, 400 | { 401 | "cell_type": "code", 402 | "execution_count": 13, 403 | "metadata": {}, 404 | "outputs": [ 405 | { 406 | "name": "stdout", 407 | "output_type": "stream", 408 | "text": [ 409 | "0.7966666666666666\n" 410 | ] 411 | } 412 | ], 413 | "source": [ 414 | "from sklearn.ensemble import RandomForestClassifier\n", 415 | "clf = RandomForestClassifier(max_depth=3, random_state=0)\n", 416 | "clf.fit(X_train, y_train)\n", 417 | "y_pred = clf.predict(X_test)\n", 418 | "print(accuracy_score(y_test, y_pred))" 419 | ] 420 | } 421 | ], 422 | "metadata": { 423 | "kernelspec": { 424 | "display_name": "Python 3 (ipykernel)", 425 | "language": "python", 426 | "name": 
"python3" 427 | }, 428 | "language_info": { 429 | "codemirror_mode": { 430 | "name": "ipython", 431 | "version": 3 432 | }, 433 | "file_extension": ".py", 434 | "mimetype": "text/x-python", 435 | "name": "python", 436 | "nbconvert_exporter": "python", 437 | "pygments_lexer": "ipython3", 438 | "version": "3.9.7" 439 | }, 440 | "toc": { 441 | "base_numbering": 1, 442 | "nav_menu": {}, 443 | "number_sections": true, 444 | "sideBar": true, 445 | "skip_h1_title": false, 446 | "title_cell": "Table of Contents", 447 | "title_sidebar": "Contents", 448 | "toc_cell": false, 449 | "toc_position": {}, 450 | "toc_section_display": true, 451 | "toc_window_display": false 452 | } 453 | }, 454 | "nbformat": 4, 455 | "nbformat_minor": 4 456 | } 457 | -------------------------------------------------------------------------------- /charpter15_random_forest/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | ### 定义二叉特征分裂函数 4 | def feature_split(X, feature_i, threshold): 5 | split_func = None 6 | if isinstance(threshold, int) or isinstance(threshold, float): 7 | split_func = lambda sample: sample[feature_i] >= threshold 8 | else: 9 | split_func = lambda sample: sample[feature_i] == threshold 10 | 11 | X_left = np.array([sample for sample in X if split_func(sample)]) 12 | X_right = np.array([sample for sample in X if not split_func(sample)]) 13 | 14 | return np.array([X_left, X_right]) 15 | 16 | 17 | ### 计算基尼指数 18 | def calculate_gini(y): 19 | # 将数组转化为列表 20 | y = y.tolist() 21 | probs = [y.count(i)/len(y) for i in np.unique(y)] 22 | gini = sum([p*(1-p) for p in probs]) 23 | return gini 24 | -------------------------------------------------------------------------------- /charpter16_ensemble_compare/catboost_info/learn/events.out.tfevents: -------------------------------------------------------------------------------- 
import numpy as np


def feature_split(X, feature_i, threshold):
    """Binary-split the rows of X on feature `feature_i`.

    Numeric thresholds use `>=`, anything else (categorical) uses `==`.
    Rows satisfying the condition go left, the rest go right.

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features[+labels])
    feature_i : int
        Column index to split on.
    threshold : int, float or object
        Split point (numeric) or category value.

    Returns
    -------
    ndarray of dtype object, length 2
        (X_left, X_right) — unpacks exactly like the original return.
    """
    if isinstance(threshold, (int, float)):
        split_func = lambda sample: sample[feature_i] >= threshold
    else:
        split_func = lambda sample: sample[feature_i] == threshold

    X_left = np.array([sample for sample in X if split_func(sample)])
    X_right = np.array([sample for sample in X if not split_func(sample)])

    # BUG FIX: the original `np.array([X_left, X_right])` implicitly
    # builds a ragged array when the halves differ in length — a
    # VisibleDeprecationWarning on older NumPy (seen in the notebook's
    # captured stderr) and a hard ValueError on NumPy >= 1.24. Building
    # an explicit 2-element object array preserves unpacking/len/index
    # behaviour on every NumPy version.
    pair = np.empty(2, dtype=object)
    pair[0] = X_left
    pair[1] = X_right
    return pair


def calculate_gini(y):
    """Gini impurity of the label array y: 1 - sum(p_k^2).

    Parameters
    ----------
    y : ndarray
        Labels (any shape; flattened by tolist()+count semantics).

    Returns
    -------
    float in [0, 1); 0 for a pure node.
    """
    # Work on a plain list so list.count() can tally each class
    y = y.tolist()
    probs = [y.count(i) / len(y) for i in np.unique(y)]
    gini = sum([p * (1 - p) for p in probs])
    return gini
77 0.4667054285 80 | 78 0.4664819365 81 | 79 0.4661455637 82 | 80 0.4659932572 83 | 81 0.4658852015 84 | 82 0.4656152117 85 | 83 0.4656080021 86 | 84 0.4654394796 87 | 85 0.465293117 88 | 86 0.4651259492 89 | 87 0.465063814 90 | 88 0.4647470816 91 | 89 0.4645601854 92 | 90 0.4644281962 93 | 91 0.4640613887 94 | 92 0.4638250676 95 | 93 0.4635502567 96 | 94 0.4633590819 97 | 95 0.4631819992 98 | 96 0.4630127171 99 | 97 0.4628231592 100 | 98 0.4626641418 101 | 99 0.4623734142 102 | 100 0.4622030722 103 | 101 0.4619653521 104 | 102 0.4616336684 105 | 103 0.4614933963 106 | 104 0.4612071964 107 | 105 0.4609744858 108 | 106 0.460682579 109 | 107 0.4602504061 110 | 108 0.4599660086 111 | 109 0.459624116 112 | 110 0.4593296251 113 | 111 0.4589447452 114 | 112 0.4586136917 115 | 113 0.4582753622 116 | 114 0.4579857302 117 | 115 0.4576706984 118 | 116 0.4574333037 119 | 117 0.4571779752 120 | 118 0.4569845242 121 | 119 0.4565772371 122 | 120 0.4563674356 123 | 121 0.4560683384 124 | 122 0.4558683462 125 | 123 0.4556193423 126 | 124 0.4554050993 127 | 125 0.4552127853 128 | 126 0.4550712881 129 | 127 0.4547808696 130 | 128 0.4544711699 131 | 129 0.4541282934 132 | 130 0.4538274802 133 | 131 0.4535967715 134 | 132 0.4533214491 135 | 133 0.4531731878 136 | 134 0.4530637951 137 | 135 0.4527387723 138 | 136 0.4524965372 139 | 137 0.452247727 140 | 138 0.4520934796 141 | 139 0.4518932377 142 | 140 0.451634417 143 | 141 0.451335177 144 | 142 0.4509815816 145 | 143 0.450697706 146 | 144 0.4504336707 147 | 145 0.4503049829 148 | 146 0.4500368226 149 | 147 0.4498455661 150 | 148 0.449563689 151 | 149 0.4492805672 152 | 150 0.4490711459 153 | 151 0.4488653187 154 | 152 0.4485631802 155 | 153 0.4483111816 156 | 154 0.4480290217 157 | 155 0.4478184854 158 | 156 0.44758083 159 | 157 0.4474770528 160 | 158 0.4472456551 161 | 159 0.4468916034 162 | 160 0.4464742877 163 | 161 0.4462685951 164 | 162 0.4461769669 165 | 163 0.4459846691 166 | 164 0.445845785 167 | 165 0.4455406991 168 | 166 
0.4454100694 169 | 167 0.445264607 170 | 168 0.4449336387 171 | 169 0.4446822179 172 | 170 0.444490791 173 | 171 0.4442899179 174 | 172 0.444120478 175 | 173 0.4439204898 176 | 174 0.4436726181 177 | 175 0.4434650095 178 | 176 0.4432100053 179 | 177 0.4430116563 180 | 178 0.4428127827 181 | 179 0.4426172522 182 | 180 0.4424228289 183 | 181 0.4421394846 184 | 182 0.4419199981 185 | 183 0.4417432322 186 | 184 0.4415643661 187 | 185 0.4414388123 188 | 186 0.4411597702 189 | 187 0.4409913504 190 | 188 0.4407818288 191 | 189 0.4404655874 192 | 190 0.4402884833 193 | 191 0.4402102932 194 | 192 0.4399973424 195 | 193 0.4398065022 196 | 194 0.4395259754 197 | 195 0.4393421496 198 | 196 0.4391850199 199 | 197 0.4390054427 200 | 198 0.4389233634 201 | 199 0.4387315356 202 | 200 0.4386313913 203 | 201 0.4383529724 204 | 202 0.4381752706 205 | 203 0.4379885061 206 | 204 0.437737938 207 | 205 0.4375336776 208 | 206 0.4373921985 209 | 207 0.4370839481 210 | 208 0.4369050613 211 | 209 0.4367457914 212 | 210 0.4366224142 213 | 211 0.4364863614 214 | 212 0.4362719623 215 | 213 0.4360762003 216 | 214 0.4358583431 217 | 215 0.4357633817 218 | 216 0.4355713176 219 | 217 0.4354015855 220 | 218 0.4351749506 221 | 219 0.4351022747 222 | 220 0.4349085299 223 | 221 0.434758201 224 | 222 0.4345803308 225 | 223 0.4344466998 226 | 224 0.4343117834 227 | 225 0.4341642055 228 | 226 0.434002621 229 | 227 0.4338493668 230 | 228 0.4337174753 231 | 229 0.4335439856 232 | 230 0.4334249254 233 | 231 0.4332135838 234 | 232 0.4330124051 235 | 233 0.432903857 236 | 234 0.4327445017 237 | 235 0.43258031 238 | 236 0.4323444853 239 | 237 0.4321910858 240 | 238 0.4319732611 241 | 239 0.4318633878 242 | 240 0.4317233092 243 | 241 0.4316100765 244 | 242 0.4314233684 245 | 243 0.4311911969 246 | 244 0.4309511649 247 | 245 0.4307605737 248 | 246 0.4305956963 249 | 247 0.4304967347 250 | 248 0.4302741451 251 | 249 0.4301954186 252 | 250 0.4300582068 253 | 251 0.4298833015 254 | 252 0.4297932535 255 | 253 
0.4295760958 256 | 254 0.4294770641 257 | 255 0.4293085107 258 | 256 0.4290923801 259 | 257 0.4289707935 260 | 258 0.4288296893 261 | 259 0.4286859179 262 | 260 0.4285910994 263 | 261 0.428416305 264 | 262 0.4281823627 265 | 263 0.4279207692 266 | 264 0.4277485456 267 | 265 0.4276121971 268 | 266 0.427442275 269 | 267 0.4273092753 270 | 268 0.4271866083 271 | 269 0.4270427061 272 | 270 0.4268035031 273 | 271 0.4265565278 274 | 272 0.4263148916 275 | 273 0.4262022198 276 | 274 0.4260153978 277 | 275 0.4258115357 278 | 276 0.4256838187 279 | 277 0.4255105282 280 | 278 0.425375475 281 | 279 0.4252267282 282 | 280 0.42511492 283 | 281 0.4249172328 284 | 282 0.4247181855 285 | 283 0.4244766468 286 | 284 0.4243412807 287 | 285 0.4242434311 288 | 286 0.4240460172 289 | 287 0.4238952681 290 | 288 0.4237408691 291 | 289 0.4235566093 292 | 290 0.4233924713 293 | 291 0.4233133175 294 | 292 0.4232333796 295 | 293 0.4230644886 296 | 294 0.4229251988 297 | 295 0.4226910543 298 | 296 0.4225576578 299 | 297 0.4224314783 300 | 298 0.4223174159 301 | 299 0.4221563304 302 | -------------------------------------------------------------------------------- /charpter16_ensemble_compare/catboost_info/time_left.tsv: -------------------------------------------------------------------------------- 1 | iter Passed Remaining 2 | 0 219 65631 3 | 1 328 48971 4 | 2 445 44093 5 | 3 582 43072 6 | 4 719 42427 7 | 5 819 40143 8 | 6 930 38950 9 | 7 1046 38202 10 | 8 1179 38144 11 | 9 1290 37438 12 | 10 1408 37014 13 | 11 1527 36670 14 | 12 1617 35703 15 | 13 1740 35547 16 | 14 1879 35719 17 | 15 1994 35406 18 | 16 2095 34878 19 | 17 2244 35161 20 | 18 2306 34114 21 | 19 2476 34677 22 | 20 2610 34688 23 | 21 2733 34540 24 | 22 2832 34112 25 | 23 2953 33961 26 | 24 3094 34044 27 | 25 3204 33772 28 | 26 3335 33725 29 | 27 3467 33685 30 | 28 3653 34138 31 | 29 3778 34009 32 | 30 3946 34244 33 | 31 4104 34371 34 | 32 4232 34246 35 | 33 4362 34126 36 | 34 4404 33349 37 | 35 4524 33178 38 | 36 4631 32918 39 
| 37 4737 32665 40 | 38 4853 32478 41 | 39 4950 32177 42 | 40 5057 31951 43 | 41 5175 31791 44 | 42 5291 31628 45 | 43 5400 31419 46 | 44 5505 31197 47 | 45 5616 31013 48 | 46 5726 30825 49 | 47 5845 30690 50 | 48 5955 30504 51 | 49 6074 30370 52 | 50 6200 30275 53 | 51 6319 30137 54 | 52 6428 29960 55 | 53 6535 29774 56 | 54 6632 29544 57 | 55 6749 29406 58 | 56 6855 29224 59 | 57 6958 29032 60 | 58 7067 28870 61 | 59 7187 28749 62 | 60 7289 28561 63 | 61 7398 28400 64 | 62 7507 28243 65 | 63 7610 28064 66 | 64 7734 27961 67 | 65 7837 27786 68 | 66 7939 27609 69 | 67 8049 27464 70 | 68 8153 27296 71 | 69 8247 27099 72 | 70 8367 26989 73 | 71 8480 26855 74 | 72 8582 26686 75 | 73 8689 26537 76 | 74 8781 26344 77 | 75 8891 26207 78 | 76 8998 26060 79 | 77 9119 25955 80 | 78 9232 25827 81 | 79 9349 25710 82 | 80 9459 25575 83 | 81 9551 25393 84 | 82 9677 25300 85 | 83 9745 25058 86 | 84 9864 24952 87 | 85 9950 24761 88 | 86 10044 24591 89 | 87 10141 24432 90 | 88 10243 24286 91 | 89 10356 24164 92 | 90 10450 24002 93 | 91 10567 23892 94 | 92 10664 23736 95 | 93 10766 23594 96 | 94 10888 23496 97 | 95 11127 23646 98 | 96 11253 23550 99 | 97 11380 23457 100 | 98 11498 23344 101 | 99 11612 23224 102 | 100 11779 23208 103 | 101 11924 23146 104 | 102 12045 23037 105 | 103 12158 22913 106 | 104 12289 22824 107 | 105 12394 22684 108 | 106 12507 22559 109 | 107 12621 22437 110 | 108 12734 22314 111 | 109 12842 22181 112 | 110 12948 22048 113 | 111 13065 21930 114 | 112 13173 21800 115 | 113 13281 21670 116 | 114 13397 21551 117 | 115 13500 21415 118 | 116 13609 21285 119 | 117 13709 21145 120 | 118 13818 21018 121 | 119 13929 20894 122 | 120 14066 20809 123 | 121 14172 20677 124 | 122 14302 20582 125 | 123 14416 20462 126 | 124 14517 20325 127 | 125 14618 20187 128 | 126 14725 20059 129 | 127 14841 19943 130 | 128 14951 19819 131 | 129 15067 19704 132 | 130 15186 19591 133 | 131 15289 19459 134 | 132 15393 19329 135 | 133 15497 19198 136 | 134 15610 19079 137 | 135 15721 
18958 138 | 136 15838 18844 139 | 137 15953 18727 140 | 138 16069 18612 141 | 139 16169 18479 142 | 140 16271 18348 143 | 141 16385 18231 144 | 142 16498 18113 145 | 143 16603 17986 146 | 144 16708 17860 147 | 145 16823 17745 148 | 146 16935 17626 149 | 147 17032 17492 150 | 148 17149 17379 151 | 149 17267 17267 152 | 150 17374 17143 153 | 151 17489 17028 154 | 152 17602 16911 155 | 153 17710 16790 156 | 154 17819 16670 157 | 155 17930 16550 158 | 156 18051 16442 159 | 157 18152 16313 160 | 158 18265 16198 161 | 159 18377 16080 162 | 160 18498 15970 163 | 161 18606 15849 164 | 162 18716 15731 165 | 163 18820 15607 166 | 164 18945 15501 167 | 165 19059 15385 168 | 166 19173 15269 169 | 167 19284 15152 170 | 168 19400 15038 171 | 169 19512 14921 172 | 170 19616 14798 173 | 171 19708 14666 174 | 172 19825 14553 175 | 173 19939 14439 176 | 174 20055 14325 177 | 175 20166 14208 178 | 176 20288 14098 179 | 177 20393 13977 180 | 178 20510 13864 181 | 179 20614 13742 182 | 180 20728 13628 183 | 181 20857 13523 184 | 182 20964 13403 185 | 183 21072 13284 186 | 184 21194 13175 187 | 185 21305 13058 188 | 186 21421 12944 189 | 187 21528 12825 190 | 188 21642 12710 191 | 189 21776 12607 192 | 190 21892 12493 193 | 191 21976 12361 194 | 192 22068 12234 195 | 193 22188 12123 196 | 194 22295 12005 197 | 195 22403 11887 198 | 196 22506 11767 199 | 197 22613 11649 200 | 198 22710 11526 201 | 199 22815 11407 202 | 200 22926 11291 203 | 201 23039 11177 204 | 202 23147 11060 205 | 203 23264 10948 206 | 204 23389 10839 207 | 205 23489 10718 208 | 206 23596 10601 209 | 207 23704 10484 210 | 208 23817 10370 211 | 209 23921 10251 212 | 210 24041 10140 213 | 211 24141 10020 214 | 212 24258 9908 215 | 213 24366 9792 216 | 214 24469 9674 217 | 215 24575 9557 218 | 216 24681 9440 219 | 217 24768 9316 220 | 218 24883 9203 221 | 219 25001 9091 222 | 220 25123 8980 223 | 221 25225 8863 224 | 222 25341 8750 225 | 223 25435 8629 226 | 224 25541 8513 227 | 225 25648 8398 228 | 226 25729 8274 229 | 
227 25832 8157 230 | 228 25932 8040 231 | 229 26040 7925 232 | 230 26138 7807 233 | 231 26257 7696 234 | 232 26361 7580 235 | 233 26457 7462 236 | 234 26562 7346 237 | 235 26671 7232 238 | 236 26779 7118 239 | 237 26888 7004 240 | 238 27009 6893 241 | 239 27113 6778 242 | 240 27217 6663 243 | 241 27311 6545 244 | 242 27419 6431 245 | 243 27529 6318 246 | 244 27637 6204 247 | 245 27755 6092 248 | 246 27864 5978 249 | 247 27978 5866 250 | 248 28092 5753 251 | 249 28212 5642 252 | 250 28320 5528 253 | 251 28415 5412 254 | 252 28527 5299 255 | 253 28625 5184 256 | 254 28740 5071 257 | 255 28851 4958 258 | 256 28966 4846 259 | 257 29066 4731 260 | 258 29185 4620 261 | 259 29296 4507 262 | 260 29399 4392 263 | 261 29509 4279 264 | 262 29631 4168 265 | 263 29733 4054 266 | 264 29828 3939 267 | 265 29929 3825 268 | 266 30046 3713 269 | 267 30156 3600 270 | 268 30265 3487 271 | 269 30382 3375 272 | 270 30483 3262 273 | 271 30583 3148 274 | 272 30697 3035 275 | 273 30801 2922 276 | 274 30918 2810 277 | 275 31034 2698 278 | 276 31140 2585 279 | 277 31247 2472 280 | 278 31357 2360 281 | 279 31465 2247 282 | 280 31580 2135 283 | 281 31702 2023 284 | 282 31820 1911 285 | 283 31931 1798 286 | 284 32033 1685 287 | 285 32144 1573 288 | 286 32280 1462 289 | 287 32402 1350 290 | 288 32542 1238 291 | 289 32654 1126 292 | 290 32776 1013 293 | 291 32885 900 294 | 292 33004 788 295 | 293 33114 675 296 | 294 33248 563 297 | 295 33371 450 298 | 296 33504 338 299 | 297 33611 225 300 | 298 33752 112 301 | 299 33869 0 302 | -------------------------------------------------------------------------------- /charpter17_kmeans/kmeans.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### kmeans" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "from 
sklearn.cluster import KMeans" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "data": { 27 | "text/plain": [ 28 | "array([[0, 2],\n", 29 | " [0, 0],\n", 30 | " [1, 0],\n", 31 | " [5, 0],\n", 32 | " [5, 2]])" 33 | ] 34 | }, 35 | "execution_count": 2, 36 | "metadata": {}, 37 | "output_type": "execute_result" 38 | } 39 | ], 40 | "source": [ 41 | "X = np.array([[0,2],[0,0],[1,0],[5,0],[5,2]])\n", 42 | "X" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 3, 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "name": "stdout", 52 | "output_type": "stream", 53 | "text": [ 54 | "[1 1 1 0 0]\n" 55 | ] 56 | } 57 | ], 58 | "source": [ 59 | "from sklearn.cluster import KMeans\n", 60 | "kmeans = KMeans(n_clusters=2, random_state=0).fit(X)\n", 61 | "print(kmeans.labels_)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 4, 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "data": { 71 | "text/plain": [ 72 | "array([1, 0])" 73 | ] 74 | }, 75 | "execution_count": 4, 76 | "metadata": {}, 77 | "output_type": "execute_result" 78 | } 79 | ], 80 | "source": [ 81 | "kmeans.predict([[0, 0], [12, 3]])" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 5, 87 | "metadata": {}, 88 | "outputs": [ 89 | { 90 | "data": { 91 | "text/plain": [ 92 | "array([[5. , 1. 
],\n", 93 | " [0.33333333, 0.66666667]])" 94 | ] 95 | }, 96 | "execution_count": 5, 97 | "metadata": {}, 98 | "output_type": "execute_result" 99 | } 100 | ], 101 | "source": [ 102 | "kmeans.cluster_centers_" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 6, 108 | "metadata": {}, 109 | "outputs": [ 110 | { 111 | "name": "stdout", 112 | "output_type": "stream", 113 | "text": [ 114 | "5.0\n" 115 | ] 116 | } 117 | ], 118 | "source": [ 119 | "import numpy as np\n", 120 | "# 定义欧式距离\n", 121 | "def euclidean_distance(x1, x2):\n", 122 | " distance = 0\n", 123 | " # 距离的平方项再开根号\n", 124 | " for i in range(len(x1)):\n", 125 | " distance += pow((x1[i] - x2[i]), 2)\n", 126 | " return np.sqrt(distance)\n", 127 | "\n", 128 | "print(euclidean_distance(X[0], X[4]))" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 7, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "# 定义中心初始化函数\n", 138 | "def centroids_init(k, X):\n", 139 | " m, n = X.shape\n", 140 | " centroids = np.zeros((k, n))\n", 141 | " for i in range(k):\n", 142 | " # 每一次循环随机选择一个类别中心\n", 143 | " centroid = X[np.random.choice(range(m))]\n", 144 | " centroids[i] = centroid\n", 145 | " return centroids" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 8, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "# 定义样本的最近质心点所属的类别索引\n", 155 | "def closest_centroid(sample, centroids):\n", 156 | " closest_i = 0\n", 157 | " closest_dist = float('inf')\n", 158 | " for i, centroid in enumerate(centroids):\n", 159 | " # 根据欧式距离判断,选择最小距离的中心点所属类别\n", 160 | " distance = euclidean_distance(sample, centroid)\n", 161 | " if distance < closest_dist:\n", 162 | " closest_i = i\n", 163 | " closest_dist = distance\n", 164 | " return closest_i" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 9, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "# 定义构建类别过程\n", 174 | "def 
build_clusters(centroids, k, X):\n", 175 | " clusters = [[] for _ in range(k)]\n", 176 | " for x_i, x in enumerate(X):\n", 177 | " # 将样本划分到最近的类别区域\n", 178 | " centroid_i = closest_centroid(x, centroids)\n", 179 | " clusters[centroid_i].append(x_i)\n", 180 | " return clusters" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 10, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "# 根据上一步聚类结果计算新的中心点\n", 190 | "def calculate_centroids(clusters, k, X):\n", 191 | " n = X.shape[1]\n", 192 | " centroids = np.zeros((k, n))\n", 193 | " # 以当前每个类样本的均值为新的中心点\n", 194 | " for i, cluster in enumerate(clusters):\n", 195 | " centroid = np.mean(X[cluster], axis=0)\n", 196 | " centroids[i] = centroid\n", 197 | " return centroids" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 11, 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "# 获取每个样本所属的聚类类别\n", 207 | "def get_cluster_labels(clusters, X):\n", 208 | " y_pred = np.zeros(X.shape[0])\n", 209 | " for cluster_i, cluster in enumerate(clusters):\n", 210 | " for X_i in cluster:\n", 211 | " y_pred[X_i] = cluster_i\n", 212 | " return y_pred" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 12, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "# 根据上述各流程定义kmeans算法流程\n", 222 | "def kmeans(X, k, max_iterations):\n", 223 | " # 1.初始化中心点\n", 224 | " centroids = centroids_init(k, X)\n", 225 | " # 遍历迭代求解\n", 226 | " for _ in range(max_iterations):\n", 227 | " # 2.根据当前中心点进行聚类\n", 228 | " clusters = build_clusters(centroids, k, X)\n", 229 | " # 保存当前中心点\n", 230 | " prev_centroids = centroids\n", 231 | " # 3.根据聚类结果计算新的中心点\n", 232 | " centroids = calculate_centroids(clusters, k, X)\n", 233 | " # 4.设定收敛条件为中心点是否发生变化\n", 234 | " diff = centroids - prev_centroids\n", 235 | " if not diff.any():\n", 236 | " break\n", 237 | " # 返回最终的聚类标签\n", 238 | " return get_cluster_labels(clusters, X)" 239 | ] 240 | }, 241 | { 242 | 
"cell_type": "code", 243 | "execution_count": 13, 244 | "metadata": {}, 245 | "outputs": [ 246 | { 247 | "name": "stdout", 248 | "output_type": "stream", 249 | "text": [ 250 | "[1. 1. 1. 0. 0.]\n" 251 | ] 252 | } 253 | ], 254 | "source": [ 255 | "# 测试数据\n", 256 | "X = np.array([[0,2],[0,0],[1,0],[5,0],[5,2]])\n", 257 | "# 设定聚类类别为2个,最大迭代次数为10次\n", 258 | "labels = kmeans(X, 2, 10)\n", 259 | "# 打印每个样本所属的类别标签\n", 260 | "print(labels)" 261 | ] 262 | } 263 | ], 264 | "metadata": { 265 | "kernelspec": { 266 | "display_name": "Python 3 (ipykernel)", 267 | "language": "python", 268 | "name": "python3" 269 | }, 270 | "language_info": { 271 | "codemirror_mode": { 272 | "name": "ipython", 273 | "version": 3 274 | }, 275 | "file_extension": ".py", 276 | "mimetype": "text/x-python", 277 | "name": "python", 278 | "nbconvert_exporter": "python", 279 | "pygments_lexer": "ipython3", 280 | "version": "3.9.7" 281 | }, 282 | "toc": { 283 | "base_numbering": 1, 284 | "nav_menu": {}, 285 | "number_sections": true, 286 | "sideBar": true, 287 | "skip_h1_title": false, 288 | "title_cell": "Table of Contents", 289 | "title_sidebar": "Contents", 290 | "toc_cell": false, 291 | "toc_position": {}, 292 | "toc_section_display": true, 293 | "toc_window_display": false 294 | } 295 | }, 296 | "nbformat": 4, 297 | "nbformat_minor": 4 298 | } 299 | -------------------------------------------------------------------------------- /charpter19_SVD/louwill.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/louwill.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## SVD:奇异值分解\n", 8 | "降维技术" 9 | ] 10 | }, 11 | { 12 | 
"cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "import numpy as np" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "data": { 27 | "text/plain": [ 28 | "array([[0, 1],\n", 29 | " [1, 1],\n", 30 | " [1, 0]])" 31 | ] 32 | }, 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "output_type": "execute_result" 36 | } 37 | ], 38 | "source": [ 39 | "A = np.array([[0,1],[1,1],[1,0]])\n", 40 | "A" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 3, 46 | "metadata": {}, 47 | "outputs": [ 48 | { 49 | "name": "stdout", 50 | "output_type": "stream", 51 | "text": [ 52 | "(3, 3) (2,) (2, 2)\n" 53 | ] 54 | } 55 | ], 56 | "source": [ 57 | "u, s, vt = np.linalg.svd(A, full_matrices=True)\n", 58 | "print(u.shape, s.shape, vt.shape)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 4, 64 | "metadata": {}, 65 | "outputs": [ 66 | { 67 | "data": { 68 | "text/plain": [ 69 | "array([[-0.40824829, 0.70710678, 0.57735027],\n", 70 | " [-0.81649658, 0. , -0.57735027],\n", 71 | " [-0.40824829, -0.70710678, 0.57735027]])" 72 | ] 73 | }, 74 | "execution_count": 4, 75 | "metadata": {}, 76 | "output_type": "execute_result" 77 | } 78 | ], 79 | "source": [ 80 | "u" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 5, 86 | "metadata": {}, 87 | "outputs": [ 88 | { 89 | "data": { 90 | "text/plain": [ 91 | "array([1.73205081, 1. 
])" 92 | ] 93 | }, 94 | "execution_count": 5, 95 | "metadata": {}, 96 | "output_type": "execute_result" 97 | } 98 | ], 99 | "source": [ 100 | "s" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 6, 106 | "metadata": {}, 107 | "outputs": [ 108 | { 109 | "data": { 110 | "text/plain": [ 111 | "array([[-0.70710678, -0.70710678],\n", 112 | " [-0.70710678, 0.70710678]])" 113 | ] 114 | }, 115 | "execution_count": 6, 116 | "metadata": {}, 117 | "output_type": "execute_result" 118 | } 119 | ], 120 | "source": [ 121 | "vt.T" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 7, 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "data": { 131 | "text/plain": [ 132 | "True" 133 | ] 134 | }, 135 | "execution_count": 7, 136 | "metadata": {}, 137 | "output_type": "execute_result" 138 | } 139 | ], 140 | "source": [ 141 | "np.allclose(A, np.dot(u[:,:2]*s, vt))" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 8, 147 | "metadata": {}, 148 | "outputs": [ 149 | { 150 | "data": { 151 | "text/plain": [ 152 | "array([[ 1.11022302e-16, 1.00000000e+00],\n", 153 | " [ 1.00000000e+00, 1.00000000e+00],\n", 154 | " [ 1.00000000e+00, -4.44089210e-16]])" 155 | ] 156 | }, 157 | "execution_count": 8, 158 | "metadata": {}, 159 | "output_type": "execute_result" 160 | } 161 | ], 162 | "source": [ 163 | "np.dot(u[:,:2]*s, vt)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 9, 169 | "metadata": {}, 170 | "outputs": [ 171 | { 172 | "data": { 173 | "text/plain": [ 174 | "array([[1.73205081, 0. ],\n", 175 | " [0. , 1. ],\n", 176 | " [0. , 0. 
]])" 177 | ] 178 | }, 179 | "execution_count": 9, 180 | "metadata": {}, 181 | "output_type": "execute_result" 182 | } 183 | ], 184 | "source": [ 185 | "s_ = np.zeros((3,2))\n", 186 | "for i in range(2):\n", 187 | " s_[i][i] = s[i]\n", 188 | "\n", 189 | "s_" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 10, 195 | "metadata": {}, 196 | "outputs": [ 197 | { 198 | "data": { 199 | "text/plain": [ 200 | "array([[ 1.11022302e-16, 1.00000000e+00],\n", 201 | " [ 1.00000000e+00, 1.00000000e+00],\n", 202 | " [ 1.00000000e+00, -4.44089210e-16]])" 203 | ] 204 | }, 205 | "execution_count": 10, 206 | "metadata": {}, 207 | "output_type": "execute_result" 208 | } 209 | ], 210 | "source": [ 211 | "np.dot(np.dot(u, s_), vt)" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 11, 217 | "metadata": {}, 218 | "outputs": [ 219 | { 220 | "name": "stderr", 221 | "output_type": "stream", 222 | "text": [ 223 | "100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [02:23<00:00, 2.87s/it]\n" 224 | ] 225 | } 226 | ], 227 | "source": [ 228 | "import numpy as np\n", 229 | "import os\n", 230 | "from PIL import Image\n", 231 | "from tqdm import tqdm\n", 232 | "\n", 233 | "# 定义恢复函数,由分解后的矩阵恢复到原矩阵\n", 234 | "def restore(u, s, v, K): \n", 235 | " '''\n", 236 | " u:左奇异矩阵\n", 237 | " v:右奇异矩阵\n", 238 | " s:奇异值矩阵\n", 239 | " K:奇异值个数\n", 240 | " '''\n", 241 | " m, n = len(u), len(v[0])\n", 242 | " a = np.zeros((m, n))\n", 243 | " for k in range(K):\n", 244 | " uk = u[:, k].reshape(m, 1)\n", 245 | " vk = v[k].reshape(1, n)\n", 246 | " # 前k个奇异值的加总\n", 247 | " a += s[k] * np.dot(uk, vk) \n", 248 | " a = a.clip(0, 255)\n", 249 | " return np.rint(a).astype('uint8')\n", 250 | "\n", 251 | "A = np.array(Image.open(\"./louwill.jpg\", 'r'))\n", 252 | "# 对RGB图像进行奇异值分解\n", 253 | "u_r, s_r, v_r = np.linalg.svd(A[:, :, 0]) \n", 254 | "u_g, s_g, v_g = np.linalg.svd(A[:, :, 1])\n", 255 | "u_b, s_b, v_b = np.linalg.svd(A[:, :, 
2])\n", 256 | "\n", 257 | "# 使用前50个奇异值\n", 258 | "K = 50 \n", 259 | "output_path = r'./svd_pic'\n", 260 | "# \n", 261 | "for k in tqdm(range(1, K+1)):\n", 262 | " R = restore(u_r, s_r, v_r, k)\n", 263 | " G = restore(u_g, s_g, v_g, k)\n", 264 | " B = restore(u_b, s_b, v_b, k)\n", 265 | " I = np.stack((R, G, B), axis=2) \n", 266 | " # Image.fromarray(I).save('%s\\\\svd_%d.jpg' % (output_path, k))\n", 267 | " Image.fromarray(I).save('%s/svd_%d.jpg' % (output_path, k))" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 12, 273 | "metadata": {}, 274 | "outputs": [ 275 | { 276 | "data": { 277 | "text/plain": [ 278 | "(1280, 960, 3)" 279 | ] 280 | }, 281 | "execution_count": 12, 282 | "metadata": {}, 283 | "output_type": "execute_result" 284 | } 285 | ], 286 | "source": [ 287 | "A.shape" 288 | ] 289 | } 290 | ], 291 | "metadata": { 292 | "kernelspec": { 293 | "display_name": "Python 3 (ipykernel)", 294 | "language": "python", 295 | "name": "python3" 296 | }, 297 | "language_info": { 298 | "codemirror_mode": { 299 | "name": "ipython", 300 | "version": 3 301 | }, 302 | "file_extension": ".py", 303 | "mimetype": "text/x-python", 304 | "name": "python", 305 | "nbconvert_exporter": "python", 306 | "pygments_lexer": "ipython3", 307 | "version": "3.9.7" 308 | }, 309 | "toc": { 310 | "base_numbering": 1, 311 | "nav_menu": {}, 312 | "number_sections": true, 313 | "sideBar": true, 314 | "skip_h1_title": false, 315 | "title_cell": "Table of Contents", 316 | "title_sidebar": "Contents", 317 | "toc_cell": false, 318 | "toc_position": {}, 319 | "toc_section_display": true, 320 | "toc_window_display": false 321 | } 322 | }, 323 | "nbformat": 4, 324 | "nbformat_minor": 2 325 | } 326 | -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_1.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_1.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_10.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_11.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_12.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_13.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_13.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_14.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_14.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_15.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_15.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_16.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_17.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_17.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_18.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_18.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_19.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_19.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_2.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_20.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_20.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_21.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_21.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_22.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_22.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_23.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_23.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_24.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_24.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_25.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_25.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_26.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_26.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_27.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_27.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_28.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_28.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_29.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_29.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_3.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_30.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_30.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_31.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_31.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_32.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_32.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_33.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_33.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_34.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_34.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_35.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_35.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_36.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_36.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_37.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_37.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_38.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_38.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_39.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_39.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_4.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_40.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_40.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_41.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_41.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_42.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_42.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_43.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_43.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_44.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_44.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_45.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_45.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_46.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_46.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_47.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_47.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_48.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_48.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_49.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_49.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_5.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_50.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_50.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_6.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_7.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_8.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_8.jpg -------------------------------------------------------------------------------- /charpter19_SVD/svd_pic/svd_9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/angellyao/formulaandcode/4453bd15c43673339754b86c85a3ae899ef1f1bb/charpter19_SVD/svd_pic/svd_9.jpg -------------------------------------------------------------------------------- /charpter20_MEM/max_entropy_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## 最大信息熵模型" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "还有问题没解决。" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import pandas as pd\n", 24 | "import numpy as np\n", 25 | "from collections import defaultdict\n", 26 | "\n", 27 | "class MaxEnt:\n", 28 | " def __init__(self, max_iter=100):\n", 29 | " # 训练输入\n", 30 | " self.X_ = None\n", 31 | " # 训练标签\n", 32 | " self.y_ = None\n", 33 | " # 标签类别数量\n", 34 | " self.m = None \n", 35 | " # 特征数量\n", 36 | " self.n = None \n", 37 | " # 训练样本量\n", 38 | " self.N = None \n", 39 | " # 常数特征取值\n", 40 | " self.M = None\n", 41 | " # 权重系数\n", 42 | " self.w = None\n", 43 | " # 标签名称\n", 44 | " self.labels = defaultdict(int)\n", 45 | " # 特征名称\n", 46 | " self.features = defaultdict(int)\n", 47 | " # 最大迭代次数\n", 48 | " self.max_iter = max_iter\n", 49 | "\n", 50 | " ### 计算特征函数关于经验联合分布P(X,Y)的期望\n", 51 | " def _EP_hat_f(self, x, y):\n", 52 | " self.Pxy = np.zeros((self.m, self.n))\n", 53 | " self.Px = np.zeros(self.n)\n", 54 | " for x_, y_ in zip(x, y):\n", 55 | " # 遍历每个样本\n", 56 | " for x__ in set(x_):\n", 57 | " self.Pxy[self.labels[y_], 
self.features[x__]] += 1\n", 58 | " self.Px[self.features[x__]] += 1 \n", 59 | " self.EP_hat_f = self.Pxy/self.N\n", 60 | " \n", 61 | " ### 计算特征函数关于模型P(Y|X)与经验分布P(X)的期望\n", 62 | " def _EP_f(self):\n", 63 | " # self.EPf = np.zeros((self.m, self.n))\n", 64 | " self.EP_f = np.zeros((self.m, self.n))\n", 65 | " for X in self.X_:\n", 66 | " pw = self._pw(X)\n", 67 | " pw = pw.reshape(self.m, 1)\n", 68 | " px = self.Px.reshape(1, self.n)\n", 69 | " self.EP_f += pw*px / self.N\n", 70 | " \n", 71 | " ### 最大熵模型P(y|x)\n", 72 | " def _pw(self, x):\n", 73 | " mask = np.zeros(self.n+1)\n", 74 | " for ix in x:\n", 75 | " mask[self.features[ix]] = 1\n", 76 | " tmp = self.w * mask[1:]\n", 77 | " pw = np.exp(np.sum(tmp, axis=1))\n", 78 | " Z = np.sum(pw)\n", 79 | " pw = pw/Z\n", 80 | " return pw\n", 81 | "\n", 82 | " ### 熵模型拟合\n", 83 | " ### 基于改进的迭代尺度方法IIS\n", 84 | " def fit(self, x, y):\n", 85 | " # 训练输入\n", 86 | " self.X_ = x\n", 87 | " # 训练输出\n", 88 | " self.y_ = list(set(y))\n", 89 | " # 输入数据展平后集合\n", 90 | " tmp = set(self.X_.flatten())\n", 91 | " # 特征命名\n", 92 | " self.features = defaultdict(int, zip(tmp, range(1, len(tmp)+1))) \n", 93 | " # 标签命名\n", 94 | " self.labels = dict(zip(self.y_, range(len(self.y_))))\n", 95 | " # 特征数\n", 96 | " self.n = len(self.features)+1 \n", 97 | " # 标签类别数量\n", 98 | " self.m = len(self.labels)\n", 99 | " # 训练样本量\n", 100 | " self.N = len(x) \n", 101 | " # 计算EP_hat_f\n", 102 | " self._EP_hat_f(x, y)\n", 103 | " # 初始化系数矩阵\n", 104 | " self.w = np.zeros((self.m, self.n))\n", 105 | " # 循环迭代\n", 106 | " i = 0\n", 107 | " while i <= self.max_iter:\n", 108 | " # 计算EPf\n", 109 | " self._EP_f()\n", 110 | " # self.EP_f()\n", 111 | " # 令常数特征函数为M\n", 112 | " self.M = 100\n", 113 | " # IIS算法步骤(3)\n", 114 | " # tmp = np.true_divide(self.EP_hat_f, self.EP_f)\n", 115 | " tmp = np.true_divide(self.EP_hat_f, self._EP_f)\n", 116 | " tmp[tmp == np.inf] = 0\n", 117 | " tmp = np.nan_to_num(tmp)\n", 118 | " sigma = np.where(tmp != 0, 1/self.M*np.log(tmp), 0) \n", 119 | " 
# 更新系数:IIS步骤(4)\n", 120 | " self.w = self.w + sigma\n", 121 | " i += 1\n", 122 | " print('training done.')\n", 123 | " return self\n", 124 | "\n", 125 | " # 定义最大熵模型预测函数\n", 126 | " def predict(self, x):\n", 127 | " res = np.zeros(len(x), dtype=np.int64)\n", 128 | " for ix, x_ in enumerate(x):\n", 129 | " tmp = self._pw(x_)\n", 130 | " print(tmp, np.argmax(tmp), self.labels)\n", 131 | " res[ix] = self.labels[self.y_[np.argmax(tmp)]]\n", 132 | " return np.array([self.y_[ix] for ix in res])" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 2, 138 | "metadata": {}, 139 | "outputs": [ 140 | { 141 | "name": "stdout", 142 | "output_type": "stream", 143 | "text": [ 144 | "\n", 145 | "\n", 146 | "(105, 4) (105,)\n" 147 | ] 148 | } 149 | ], 150 | "source": [ 151 | "from sklearn.datasets import load_iris\n", 152 | "from sklearn.model_selection import train_test_split\n", 153 | "raw_data = load_iris()\n", 154 | "X, labels = raw_data.data, raw_data.target\n", 155 | "X_train, X_test, y_train, y_test = train_test_split(X, labels, test_size=0.3, random_state=43)\n", 156 | "print(type(X_train))\n", 157 | "print(type(y_train))\n", 158 | "print(X_train.shape, y_train.shape)" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 3, 164 | "metadata": {}, 165 | "outputs": [ 166 | { 167 | "data": { 168 | "text/plain": [ 169 | "array([2, 2, 2, 2, 2])" 170 | ] 171 | }, 172 | "execution_count": 3, 173 | "metadata": {}, 174 | "output_type": "execute_result" 175 | } 176 | ], 177 | "source": [ 178 | "labels[-5:]" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 4, 184 | "metadata": {}, 185 | "outputs": [ 186 | { 187 | "ename": "TypeError", 188 | "evalue": "unsupported operand type(s) for /: 'float' and 'method'", 189 | "output_type": "error", 190 | "traceback": [ 191 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 192 | "\u001b[1;31mTypeError\u001b[0m Traceback 
(most recent call last)", 193 | "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_12580/713967991.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0msklearn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmetrics\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0maccuracy_score\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[0mmaxent\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mMaxEnt\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 3\u001b[1;33m \u001b[0mmaxent\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 4\u001b[0m \u001b[0my_pred\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmaxent\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX_test\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0maccuracy_score\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0my_test\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 194 | "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_12580/562601513.py\u001b[0m in \u001b[0;36mfit\u001b[1;34m(self, x, y)\u001b[0m\n\u001b[0;32m 91\u001b[0m \u001b[1;31m# IIS算法步骤(3)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 92\u001b[0m \u001b[1;31m# tmp = np.true_divide(self.EP_hat_f, self.EP_f)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 93\u001b[1;33m \u001b[0mtmp\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrue_divide\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mEP_hat_f\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_EP_f\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 94\u001b[0m \u001b[0mtmp\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mtmp\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minf\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 95\u001b[0m \u001b[0mtmp\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnan_to_num\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtmp\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 195 | "\u001b[1;31mTypeError\u001b[0m: unsupported operand type(s) for /: 'float' and 'method'" 196 | ] 197 | } 198 | ], 199 | "source": [ 200 | "from sklearn.metrics import accuracy_score\n", 201 | "maxent = MaxEnt()\n", 202 | "maxent.fit(X_train, y_train)\n", 203 | "y_pred = maxent.predict(X_test)\n", 204 | "print(accuracy_score(y_test, y_pred))" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [] 213 | } 214 | ], 215 | "metadata": { 216 | "kernelspec": { 217 | "display_name": "Python 3 (ipykernel)", 218 | "language": "python", 219 | "name": "python3" 220 | }, 221 | "language_info": { 222 | "codemirror_mode": { 223 | "name": "ipython", 224 | "version": 3 225 | }, 226 | "file_extension": ".py", 227 | "mimetype": "text/x-python", 228 | "name": "python", 229 | "nbconvert_exporter": "python", 230 | "pygments_lexer": "ipython3", 231 | "version": "3.9.7" 232 | }, 233 | "toc": { 234 | "base_numbering": 1, 235 | "nav_menu": {}, 236 | "number_sections": true, 237 | "sideBar": true, 238 | 
"skip_h1_title": false, 239 | "title_cell": "Table of Contents", 240 | "title_sidebar": "Contents", 241 | "toc_cell": false, 242 | "toc_position": {}, 243 | "toc_section_display": true, 244 | "toc_window_display": false 245 | } 246 | }, 247 | "nbformat": 4, 248 | "nbformat_minor": 2 249 | } 250 | -------------------------------------------------------------------------------- /charpter21_Bayesian_models/bayesian_network.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## 贝叶斯网络:bayesian network" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stderr", 17 | "output_type": "stream", 18 | "text": [ 19 | "C:\\Users\\18765\\AppData\\Roaming\\Python\\Python39\\site-packages\\pgmpy\\models\\BayesianModel.py:8: FutureWarning: BayesianModel has been renamed to BayesianNetwork. Please use BayesianNetwork class, BayesianModel will be removed in future.\n", 20 | " warnings.warn(\n" 21 | ] 22 | } 23 | ], 24 | "source": [ 25 | "# 导入pgmpy相关模块\n", 26 | "from pgmpy.factors.discrete import TabularCPD\n", 27 | "from pgmpy.models import BayesianModel\n", 28 | "letter_model = BayesianModel([('D', 'G'),\n", 29 | " ('I', 'G'),\n", 30 | " ('G', 'L'),\n", 31 | " ('I', 'S')])" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 2, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "# 学生成绩的条件概率分布\n", 41 | "grade_cpd = TabularCPD(\n", 42 | " variable='G', # 节点名称\n", 43 | " variable_card=3, # 节点取值个数\n", 44 | " values=[[0.3, 0.05, 0.9, 0.5], # 该节点的概率表\n", 45 | " [0.4, 0.25, 0.08, 0.3],\n", 46 | " [0.3, 0.7, 0.02, 0.2]],\n", 47 | " evidence=['I', 'D'], # 该节点的依赖节点\n", 48 | " evidence_card=[2, 2] # 依赖节点的取值个数\n", 49 | ")\n", 50 | "# 考试难度的条件概率分布\n", 51 | "difficulty_cpd = TabularCPD(\n", 52 | " variable='D',\n", 53 | " variable_card=2,\n", 54 | " 
values=[[0.6], [0.4]]\n", 55 | ")\n", 56 | "# 个人天赋的条件概率分布\n", 57 | "intel_cpd = TabularCPD(\n", 58 | " variable='I',\n", 59 | " variable_card=2,\n", 60 | " values=[[0.7], [0.3]]\n", 61 | ")\n", 62 | "# 推荐信质量的条件概率分布\n", 63 | "letter_cpd = TabularCPD(\n", 64 | " variable='L',\n", 65 | " variable_card=2,\n", 66 | " values=[[0.1, 0.4, 0.99],\n", 67 | " [0.9, 0.6, 0.01]],\n", 68 | " evidence=['G'],\n", 69 | " evidence_card=[3]\n", 70 | ")\n", 71 | "# SAT考试分数的条件概率分布\n", 72 | "sat_cpd = TabularCPD(\n", 73 | " variable='S',\n", 74 | " variable_card=2,\n", 75 | " values=[[0.95, 0.2],\n", 76 | " [0.05, 0.8]],\n", 77 | " evidence=['I'],\n", 78 | " evidence_card=[2]\n", 79 | ")" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 3, 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "data": { 89 | "application/vnd.jupyter.widget-view+json": { 90 | "model_id": "112eac3f6f9c48d8850b37bccdba7460", 91 | "version_major": 2, 92 | "version_minor": 0 93 | }, 94 | "text/plain": [ 95 | "0it [00:00, ?it/s]" 96 | ] 97 | }, 98 | "metadata": {}, 99 | "output_type": "display_data" 100 | }, 101 | { 102 | "data": { 103 | "application/vnd.jupyter.widget-view+json": { 104 | "model_id": "d946ffc64250489d8da2cec7d80a1aaf", 105 | "version_major": 2, 106 | "version_minor": 0 107 | }, 108 | "text/plain": [ 109 | "0it [00:00, ?it/s]" 110 | ] 111 | }, 112 | "metadata": {}, 113 | "output_type": "display_data" 114 | }, 115 | { 116 | "name": "stdout", 117 | "output_type": "stream", 118 | "text": [ 119 | "+------+----------+\n", 120 | "| G | phi(G) |\n", 121 | "+======+==========+\n", 122 | "| G(0) | 0.9000 |\n", 123 | "+------+----------+\n", 124 | "| G(1) | 0.0800 |\n", 125 | "+------+----------+\n", 126 | "| G(2) | 0.0200 |\n", 127 | "+------+----------+\n" 128 | ] 129 | } 130 | ], 131 | "source": [ 132 | "# 将各节点添加到模型中,构建贝叶斯网络\n", 133 | "letter_model.add_cpds(\n", 134 | " grade_cpd, \n", 135 | " difficulty_cpd,\n", 136 | " intel_cpd,\n", 137 | " letter_cpd,\n", 138 | " sat_cpd\n", 
139 | ")\n", 140 | "# 导入pgmpy贝叶斯推断模块\n", 141 | "from pgmpy.inference import VariableElimination\n", 142 | "# 贝叶斯网络推断\n", 143 | "letter_infer = VariableElimination(letter_model)\n", 144 | "# 天赋较好且考试不难的情况下推断该学生获得推荐信质量的好坏\n", 145 | "prob_G = letter_infer.query(\n", 146 | " variables=['G'],\n", 147 | " evidence={'I': 1, 'D': 0})\n", 148 | "print(prob_G)" 149 | ] 150 | } 151 | ], 152 | "metadata": { 153 | "kernelspec": { 154 | "display_name": "Python 3 (ipykernel)", 155 | "language": "python", 156 | "name": "python3" 157 | }, 158 | "language_info": { 159 | "codemirror_mode": { 160 | "name": "ipython", 161 | "version": 3 162 | }, 163 | "file_extension": ".py", 164 | "mimetype": "text/x-python", 165 | "name": "python", 166 | "nbconvert_exporter": "python", 167 | "pygments_lexer": "ipython3", 168 | "version": "3.9.7" 169 | }, 170 | "toc": { 171 | "base_numbering": 1, 172 | "nav_menu": {}, 173 | "number_sections": true, 174 | "sideBar": true, 175 | "skip_h1_title": false, 176 | "title_cell": "Table of Contents", 177 | "title_sidebar": "Contents", 178 | "toc_cell": false, 179 | "toc_position": {}, 180 | "toc_section_display": true, 181 | "toc_window_display": false 182 | } 183 | }, 184 | "nbformat": 4, 185 | "nbformat_minor": 4 186 | } 187 | -------------------------------------------------------------------------------- /charpter21_Bayesian_models/naive_bayes.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## 朴素贝叶斯:Naive Bayes" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "import pandas as pd" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "data": { 27 | "text/html": [ 28 | "
\n", 29 | "\n", 42 | "\n", 43 | " \n", 44 | " \n", 45 | " \n", 46 | " \n", 47 | " \n", 48 | " \n", 49 | " \n", 50 | " \n", 51 | " \n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | "
x1x2y
01S-1
11M-1
21M1
31S1
41S-1
\n", 84 | "
" 85 | ], 86 | "text/plain": [ 87 | " x1 x2 y\n", 88 | "0 1 S -1\n", 89 | "1 1 M -1\n", 90 | "2 1 M 1\n", 91 | "3 1 S 1\n", 92 | "4 1 S -1" 93 | ] 94 | }, 95 | "execution_count": 2, 96 | "metadata": {}, 97 | "output_type": "execute_result" 98 | } 99 | ], 100 | "source": [ 101 | "### 构造数据集\n", 102 | "### 来自于李航统计学习方法表4.1\n", 103 | "x1 = [1,1,1,1,1,2,2,2,2,2,3,3,3,3,3]\n", 104 | "x2 = ['S','M','M','S','S','S','M','M','L','L','L','M','M','L','L']\n", 105 | "y = [-1,-1,1,1,-1,-1,-1,1,1,1,1,1,1,1,-1]\n", 106 | "\n", 107 | "df = pd.DataFrame({'x1':x1, 'x2':x2, 'y':y})\n", 108 | "df.head()" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 3, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "X = df[['x1', 'x2']]\n", 118 | "y = df[['y']]" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 4, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "def nb_fit(X, y):\n", 128 | " classes = y[y.columns[0]].unique()\n", 129 | " class_count = y[y.columns[0]].value_counts()\n", 130 | " class_prior = class_count/len(y)\n", 131 | " \n", 132 | " prior = dict()\n", 133 | " for col in X.columns:\n", 134 | " for j in classes:\n", 135 | " p_x_y = X[(y==j).values][col].value_counts()\n", 136 | " for i in p_x_y.index:\n", 137 | " prior[(col, i, j)] = p_x_y[i]/class_count[j]\n", 138 | " return classes, class_prior, prior" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 5, 144 | "metadata": {}, 145 | "outputs": [ 146 | { 147 | "name": "stdout", 148 | "output_type": "stream", 149 | "text": [ 150 | "[-1 1] 1 0.6\n", 151 | "-1 0.4\n", 152 | "Name: y, dtype: float64 {('x1', 1, -1): 0.5, ('x1', 2, -1): 0.3333333333333333, ('x1', 3, -1): 0.16666666666666666, ('x1', 3, 1): 0.4444444444444444, ('x1', 2, 1): 0.3333333333333333, ('x1', 1, 1): 0.2222222222222222, ('x2', 'S', -1): 0.5, ('x2', 'M', -1): 0.3333333333333333, ('x2', 'L', -1): 0.16666666666666666, ('x2', 'M', 1): 0.4444444444444444, 
('x2', 'L', 1): 0.4444444444444444, ('x2', 'S', 1): 0.1111111111111111}\n" 153 | ] 154 | } 155 | ], 156 | "source": [ 157 | "classes, class_prior, prior = nb_fit(X, y)\n", 158 | "print(classes, class_prior, prior)" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 6, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "X_test = {'x1': 2, 'x2': 'S'}" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 7, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "classes, class_prior, prior = nb_fit(X, y)\n", 177 | "\n", 178 | "def predict(X_test):\n", 179 | " res = []\n", 180 | " for c in classes:\n", 181 | " p_y = class_prior[c]\n", 182 | " p_x_y = 1\n", 183 | " for i in X_test.items():\n", 184 | " p_x_y *= prior[tuple(list(i)+[c])]\n", 185 | " res.append(p_y*p_x_y)\n", 186 | " return classes[np.argmax(res)]" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 8, 192 | "metadata": {}, 193 | "outputs": [ 194 | { 195 | "name": "stdout", 196 | "output_type": "stream", 197 | "text": [ 198 | "测试数据预测类别为: -1\n" 199 | ] 200 | } 201 | ], 202 | "source": [ 203 | "print('测试数据预测类别为:', predict(X_test))" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 9, 209 | "metadata": {}, 210 | "outputs": [ 211 | { 212 | "name": "stdout", 213 | "output_type": "stream", 214 | "text": [ 215 | "Accuracy of GaussianNB in iris data test: 0.9466666666666667\n" 216 | ] 217 | } 218 | ], 219 | "source": [ 220 | "from sklearn.datasets import load_iris\n", 221 | "from sklearn.model_selection import train_test_split\n", 222 | "from sklearn.naive_bayes import GaussianNB\n", 223 | "from sklearn.metrics import accuracy_score\n", 224 | "X, y = load_iris(return_X_y=True)\n", 225 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)\n", 226 | "gnb = GaussianNB()\n", 227 | "y_pred = gnb.fit(X_train, y_train).predict(X_test)\n", 228 | 
"print(\"Accuracy of GaussianNB in iris data test:\", \n", 229 | " accuracy_score(y_test, y_pred))" 230 | ] 231 | } 232 | ], 233 | "metadata": { 234 | "kernelspec": { 235 | "display_name": "Python 3 (ipykernel)", 236 | "language": "python", 237 | "name": "python3" 238 | }, 239 | "language_info": { 240 | "codemirror_mode": { 241 | "name": "ipython", 242 | "version": 3 243 | }, 244 | "file_extension": ".py", 245 | "mimetype": "text/x-python", 246 | "name": "python", 247 | "nbconvert_exporter": "python", 248 | "pygments_lexer": "ipython3", 249 | "version": "3.9.7" 250 | }, 251 | "toc": { 252 | "base_numbering": 1, 253 | "nav_menu": {}, 254 | "number_sections": true, 255 | "sideBar": true, 256 | "skip_h1_title": false, 257 | "title_cell": "Table of Contents", 258 | "title_sidebar": "Contents", 259 | "toc_cell": false, 260 | "toc_position": {}, 261 | "toc_section_display": true, 262 | "toc_window_display": false 263 | } 264 | }, 265 | "nbformat": 4, 266 | "nbformat_minor": 2 267 | } 268 | -------------------------------------------------------------------------------- /charpter22_EM/em.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## 最大期望模型:EM" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# 导入numpy库 \n", 17 | "import numpy as np\n", 18 | "\n", 19 | "### EM算法过程函数定义\n", 20 | "def em(data, thetas, max_iter=30, eps=1e-3):\n", 21 | " '''\n", 22 | " 输入:\n", 23 | " data:观测数据\n", 24 | " thetas:初始化的估计参数值\n", 25 | " max_iter:最大迭代次数\n", 26 | " eps:收敛阈值\n", 27 | " 输出:\n", 28 | " thetas:估计参数\n", 29 | " '''\n", 30 | " # 初始化似然函数值\n", 31 | " ll_old = -np.infty\n", 32 | " for i in range(max_iter):\n", 33 | " ### E步:求隐变量分布\n", 34 | " # 对数似然\n", 35 | " log_like = np.array([np.sum(data * np.log(theta), axis=1) for theta in thetas])\n", 36 | " # 似然\n", 37 | " like 
= np.exp(log_like)\n", 38 | " # 求隐变量分布\n", 39 | " ws = like/like.sum(0)\n", 40 | " # 概率加权\n", 41 | " vs = np.array([w[:, None] * data for w in ws])\n", 42 | " ### M步:更新参数值\n", 43 | " thetas = np.array([v.sum(0)/v.sum() for v in vs])\n", 44 | " # 更新似然函数\n", 45 | " ll_new = np.sum([w*l for w, l in zip(ws, log_like)])\n", 46 | " print(\"Iteration: %d\" % (i+1))\n", 47 | " print(\"theta_B = %.2f, theta_C = %.2f, ll = %.2f\" \n", 48 | " % (thetas[0,0], thetas[1,0], ll_new))\n", 49 | " # 满足迭代条件即退出迭代\n", 50 | " if np.abs(ll_new - ll_old) < eps:\n", 51 | " break\n", 52 | " ll_old = ll_new\n", 53 | " return thetas" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 2, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "Iteration: 1\n", 66 | "theta_B = 0.71, theta_C = 0.58, ll = -32.69\n", 67 | "Iteration: 2\n", 68 | "theta_B = 0.75, theta_C = 0.57, ll = -31.26\n", 69 | "Iteration: 3\n", 70 | "theta_B = 0.77, theta_C = 0.55, ll = -30.76\n", 71 | "Iteration: 4\n", 72 | "theta_B = 0.78, theta_C = 0.53, ll = -30.33\n", 73 | "Iteration: 5\n", 74 | "theta_B = 0.79, theta_C = 0.53, ll = -30.07\n", 75 | "Iteration: 6\n", 76 | "theta_B = 0.79, theta_C = 0.52, ll = -29.95\n", 77 | "Iteration: 7\n", 78 | "theta_B = 0.80, theta_C = 0.52, ll = -29.90\n", 79 | "Iteration: 8\n", 80 | "theta_B = 0.80, theta_C = 0.52, ll = -29.88\n", 81 | "Iteration: 9\n", 82 | "theta_B = 0.80, theta_C = 0.52, ll = -29.87\n", 83 | "Iteration: 10\n", 84 | "theta_B = 0.80, theta_C = 0.52, ll = -29.87\n", 85 | "Iteration: 11\n", 86 | "theta_B = 0.80, theta_C = 0.52, ll = -29.87\n", 87 | "Iteration: 12\n", 88 | "theta_B = 0.80, theta_C = 0.52, ll = -29.87\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "# 观测数据,5次独立试验,每次试验10次抛掷的正反次数\n", 94 | "# 比如第一次试验为5次正面5次反面\n", 95 | "observed_data = np.array([(5,5), (9,1), (8,2), (4,6), (7,3)])\n", 96 | "# 初始化参数值,即硬币B的正面概率为0.6,硬币C的正面概率为0.5\n", 97 | "thetas = np.array([[0.6, 0.4], [0.5, 
0.5]])\n", 98 | "thetas = em(observed_data, thetas, max_iter=30, eps=1e-3)" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 3, 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "data": { 108 | "text/plain": [ 109 | "array([[0.7967829 , 0.2032171 ],\n", 110 | " [0.51959543, 0.48040457]])" 111 | ] 112 | }, 113 | "execution_count": 3, 114 | "metadata": {}, 115 | "output_type": "execute_result" 116 | } 117 | ], 118 | "source": [ 119 | "thetas" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [] 128 | } 129 | ], 130 | "metadata": { 131 | "kernelspec": { 132 | "display_name": "Python 3 (ipykernel)", 133 | "language": "python", 134 | "name": "python3" 135 | }, 136 | "language_info": { 137 | "codemirror_mode": { 138 | "name": "ipython", 139 | "version": 3 140 | }, 141 | "file_extension": ".py", 142 | "mimetype": "text/x-python", 143 | "name": "python", 144 | "nbconvert_exporter": "python", 145 | "pygments_lexer": "ipython3", 146 | "version": "3.9.7" 147 | }, 148 | "toc": { 149 | "base_numbering": 1, 150 | "nav_menu": {}, 151 | "number_sections": true, 152 | "sideBar": true, 153 | "skip_h1_title": false, 154 | "title_cell": "Table of Contents", 155 | "title_sidebar": "Contents", 156 | "toc_cell": false, 157 | "toc_position": {}, 158 | "toc_section_display": true, 159 | "toc_window_display": false 160 | } 161 | }, 162 | "nbformat": 4, 163 | "nbformat_minor": 4 164 | } 165 | -------------------------------------------------------------------------------- /charpter23_HMM/hmm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## 隐马尔科夫模型:HMM" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "\n", 18 | "### 
定义HMM模型\n", 19 | "class HMM:\n", 20 | " def __init__(self, N, M, pi=None, A=None, B=None):\n", 21 | " # 可能的状态数\n", 22 | " self.N = N\n", 23 | " # 可能的观测数\n", 24 | " self.M = M\n", 25 | " # 初始状态概率向量\n", 26 | " self.pi = pi\n", 27 | " # 状态转移概率矩阵\n", 28 | " self.A = A\n", 29 | " # 观测概率矩阵\n", 30 | " self.B = B\n", 31 | "\n", 32 | " # 根据给定的概率分布随机返回数据\n", 33 | " def rdistribution(self, dist): \n", 34 | " r = np.random.rand()\n", 35 | " for ix, p in enumerate(dist):\n", 36 | " if r < p: \n", 37 | " return ix\n", 38 | " r -= p\n", 39 | "\n", 40 | " # 生成HMM观测序列\n", 41 | " def generate(self, T):\n", 42 | " # 根据初始概率分布生成第一个状态\n", 43 | " i = self.rdistribution(self.pi) \n", 44 | " # 生成第一个观测数据\n", 45 | " o = self.rdistribution(self.B[i]) \n", 46 | " observed_data = [o]\n", 47 | " # 遍历生成剩下的状态和观测数据\n", 48 | " for _ in range(T-1): \n", 49 | " i = self.rdistribution(self.A[i])\n", 50 | " o = self.rdistribution(self.B[i])\n", 51 | " observed_data.append(o)\n", 52 | " return observed_data" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 2, 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "name": "stdout", 62 | "output_type": "stream", 63 | "text": [ 64 | "[1, 1, 1, 1, 1]\n" 65 | ] 66 | } 67 | ], 68 | "source": [ 69 | "# 初始状态概率分布\n", 70 | "pi = np.array([0.25, 0.25, 0.25, 0.25])\n", 71 | "# 状态转移概率矩阵\n", 72 | "A = np.array([\n", 73 | " [0, 1, 0, 0],\n", 74 | " [0.4, 0, 0.6, 0],\n", 75 | " [0, 0.4, 0, 0.6],\n", 76 | "[0, 0, 0.5, 0.5]])\n", 77 | "# 观测概率矩阵\n", 78 | "B = np.array([\n", 79 | " [0.5, 0.5],\n", 80 | " [0.6, 0.4],\n", 81 | " [0.2, 0.8],\n", 82 | " [0.3, 0.7]])\n", 83 | "# 可能的状态数和观测数\n", 84 | "N = 4\n", 85 | "M = 2\n", 86 | "# 创建HMM实例\n", 87 | "hmm = HMM(N, M, pi, A, B)\n", 88 | "# 生成观测序列\n", 89 | "print(hmm.generate(5))" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 3, 95 | "metadata": {}, 96 | "outputs": [ 97 | { 98 | "name": "stdout", 99 | "output_type": "stream", 100 | "text": [ 101 | "0.01983169125\n" 102 | ] 103 | } 104 | 
], 105 | "source": [ 106 | "### 前向算法计算条件概率\n", 107 | "def prob_calc(O):\n", 108 | " '''\n", 109 | " 输入:\n", 110 | " O:观测序列\n", 111 | " 输出:\n", 112 | " alpha.sum():条件概率\n", 113 | " '''\n", 114 | " # 初值\n", 115 | " alpha = pi * B[:, O[0]]\n", 116 | " # 递推\n", 117 | " for o in O[1:]:\n", 118 | " alpha_next = np.empty(4)\n", 119 | " for j in range(4):\n", 120 | " alpha_next[j] = np.sum(A[:,j] * alpha * B[j,o])\n", 121 | " alpha = alpha_next\n", 122 | " return alpha.sum()\n", 123 | "\n", 124 | "# 给定观测\n", 125 | "O = [1,0,1,0,0]\n", 126 | "print(prob_calc(O))" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 4, 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "name": "stdout", 136 | "output_type": "stream", 137 | "text": [ 138 | "[0, 1, 2, 3, 3]\n" 139 | ] 140 | } 141 | ], 142 | "source": [ 143 | "### 序列标注问题和维特比算法\n", 144 | "def viterbi_decode(O):\n", 145 | " '''\n", 146 | " 输入:\n", 147 | " O:观测序列\n", 148 | " 输出:\n", 149 | " path:最优隐状态路径\n", 150 | " ''' \n", 151 | " # 序列长度和初始观测\n", 152 | " T, o = len(O), O[0]\n", 153 | " # 初始化delta变量\n", 154 | " delta = pi * B[:, o]\n", 155 | " # 初始化varphi变量\n", 156 | " varphi = np.zeros((T, 4), dtype=int)\n", 157 | " path = [0] * T\n", 158 | " # 递推\n", 159 | " for i in range(1, T):\n", 160 | " delta = delta.reshape(-1, 1) \n", 161 | " tmp = delta * A\n", 162 | " varphi[i, :] = np.argmax(tmp, axis=0)\n", 163 | " delta = np.max(tmp, axis=0) * B[:, O[i]]\n", 164 | " # 终止\n", 165 | " path[-1] = np.argmax(delta)\n", 166 | " # 回溯最优路径\n", 167 | " for i in range(T-1, 0, -1):\n", 168 | " path[i-1] = varphi[i, path[i]]\n", 169 | " return path\n", 170 | "\n", 171 | "# 给定观测序列\n", 172 | "O = [1,0,1,1,0]\n", 173 | "# 输出最可能的隐状态序列\n", 174 | "print(viterbi_decode(O))" 175 | ] 176 | } 177 | ], 178 | "metadata": { 179 | "kernelspec": { 180 | "display_name": "Python 3 (ipykernel)", 181 | "language": "python", 182 | "name": "python3" 183 | }, 184 | "language_info": { 185 | "codemirror_mode": { 186 | "name": "ipython", 187 | 
"version": 3 188 | }, 189 | "file_extension": ".py", 190 | "mimetype": "text/x-python", 191 | "name": "python", 192 | "nbconvert_exporter": "python", 193 | "pygments_lexer": "ipython3", 194 | "version": "3.9.7" 195 | }, 196 | "toc": { 197 | "base_numbering": 1, 198 | "nav_menu": {}, 199 | "number_sections": true, 200 | "sideBar": true, 201 | "skip_h1_title": false, 202 | "title_cell": "Table of Contents", 203 | "title_sidebar": "Contents", 204 | "toc_cell": false, 205 | "toc_position": {}, 206 | "toc_section_display": true, 207 | "toc_window_display": false 208 | } 209 | }, 210 | "nbformat": 4, 211 | "nbformat_minor": 2 212 | } 213 | -------------------------------------------------------------------------------- /charpter24_CRF/crf.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### CRF" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "基于sklearn_crfsuite NER系统搭建,本例来自于sklearn_crfsuite官方tutorial" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "在anaconda的命令行窗口做操作,系统环境中没有这个 \n", 22 | "创建沙箱环境 \n", 23 | "https://blog.csdn.net/weixin_42218868/article/details/95398043 \n", 24 | "在沙箱环境中做好配置,整理.py文件,并做后续处理。" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 1, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "ename": "ModuleNotFoundError", 34 | "evalue": "No module named 'sklearn_crfsuite'", 35 | "output_type": "error", 36 | "traceback": [ 37 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 38 | "\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", 39 | "\u001b[1;32mC:\\Users\\ADMINI~1\\AppData\\Local\\Temp/ipykernel_3636/240492406.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 7\u001b[0m \u001b[1;32mfrom\u001b[0m 
\u001b[0msklearn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmodel_selection\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mRandomizedSearchCV\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 8\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 9\u001b[1;33m \u001b[1;32mimport\u001b[0m \u001b[0msklearn_crfsuite\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 10\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0msklearn_crfsuite\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mscorers\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0msklearn_crfsuite\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 40 | "\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'sklearn_crfsuite'" 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "# 导入相关库\n", 46 | "import nltk\n", 47 | "import sklearn\n", 48 | "import scipy.stats\n", 49 | "from sklearn.metrics import make_scorer\n", 50 | "from sklearn.model_selection import cross_val_score\n", 51 | "from sklearn.model_selection import RandomizedSearchCV\n", 52 | "\n", 53 | "# sklearn-crfsuite的文档中的信息来看,更新到2017年,支持到python3.6,用不了\n", 54 | "# 尝试用沙箱解决\n", 55 | "import sklearn_crfsuite\n", 56 | "from sklearn_crfsuite import scorers\n", 57 | "from sklearn_crfsuite import metrics" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "# 基于NLTK下载示例数据集\n", 67 | "nltk.download('conll2002')" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "# 设置训练和测试样本\n", 77 | "train_sents = list(nltk.corpus.conll2002.iob_sents('esp.train'))\n", 78 | "test_sents = list(nltk.corpus.conll2002.iob_sents('esp.testb'))" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | 
"metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "train_sents[0]" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "# 单词转化为数值特征\n", 97 | "def word2features(sent, i):\n", 98 | " word = sent[i][0]\n", 99 | " postag = sent[i][1]\n", 100 | "\n", 101 | " features = {\n", 102 | " 'bias': 1.0,\n", 103 | " 'word.lower()': word.lower(),\n", 104 | " 'word[-3:]': word[-3:],\n", 105 | " 'word[-2:]': word[-2:],\n", 106 | " 'word.isupper()': word.isupper(),\n", 107 | " 'word.istitle()': word.istitle(),\n", 108 | " 'word.isdigit()': word.isdigit(),\n", 109 | " 'postag': postag,\n", 110 | " 'postag[:2]': postag[:2],\n", 111 | " }\n", 112 | " if i > 0:\n", 113 | " word1 = sent[i-1][0]\n", 114 | " postag1 = sent[i-1][1]\n", 115 | " features.update({\n", 116 | " '-1:word.lower()': word1.lower(),\n", 117 | " '-1:word.istitle()': word1.istitle(),\n", 118 | " '-1:word.isupper()': word1.isupper(),\n", 119 | " '-1:postag': postag1,\n", 120 | " '-1:postag[:2]': postag1[:2],\n", 121 | " })\n", 122 | " else:\n", 123 | " features['BOS'] = True\n", 124 | "\n", 125 | " if i < len(sent)-1:\n", 126 | " word1 = sent[i+1][0]\n", 127 | " postag1 = sent[i+1][1]\n", 128 | " features.update({\n", 129 | " '+1:word.lower()': word1.lower(),\n", 130 | " '+1:word.istitle()': word1.istitle(),\n", 131 | " '+1:word.isupper()': word1.isupper(),\n", 132 | " '+1:postag': postag1,\n", 133 | " '+1:postag[:2]': postag1[:2],\n", 134 | " })\n", 135 | " else:\n", 136 | " features['EOS'] = True\n", 137 | "\n", 138 | " return features\n", 139 | "\n", 140 | "\n", 141 | "def sent2features(sent):\n", 142 | " return [word2features(sent, i) for i in range(len(sent))]\n", 143 | "\n", 144 | "def sent2labels(sent):\n", 145 | " return [label for token, postag, label in sent]\n", 146 | "\n", 147 | "def sent2tokens(sent):\n", 148 | " return [token for token, postag, label in sent]" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 
153 | "execution_count": null, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "sent2features(train_sents[0])[0]" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "# 构造训练集和测试集\n", 167 | "X_train = [sent2features(s) for s in train_sents]\n", 168 | "y_train = [sent2labels(s) for s in train_sents]\n", 169 | "\n", 170 | "X_test = [sent2features(s) for s in test_sents]\n", 171 | "y_test = [sent2labels(s) for s in test_sents]" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "print(len(X_train), len(X_test))" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "# 创建CRF模型实例\n", 190 | "crf = sklearn_crfsuite.CRF(\n", 191 | " algorithm='lbfgs',\n", 192 | " c1=0.1,\n", 193 | " c2=0.1,\n", 194 | " max_iterations=100,\n", 195 | " all_possible_transitions=True\n", 196 | ")\n", 197 | "# 模型训练\n", 198 | "crf.fit(X_train, y_train)\n", 199 | "# 类别标签\n", 200 | "labels = list(crf.classes_)\n", 201 | "labels.remove('O')\n", 202 | "# 模型预测\n", 203 | "y_pred = crf.predict(X_test)\n", 204 | "# 计算F1得分\n", 205 | "metrics.flat_f1_score(y_test, y_pred,\n", 206 | " average='weighted', labels=labels)" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "# 打印B和I组的模型结果\n", 216 | "sorted_labels = sorted(\n", 217 | " labels,\n", 218 | " key=lambda name: (name[1:], name[0])\n", 219 | ")\n", 220 | "print(metrics.flat_classification_report(\n", 221 | " y_test, y_pred, labels=sorted_labels, digits=3\n", 222 | "))" 223 | ] 224 | } 225 | ], 226 | "metadata": { 227 | "kernelspec": { 228 | "display_name": "Python 3 (ipykernel)", 229 | "language": "python", 230 | "name": "python3" 231 | }, 232 | 
"language_info": { 233 | "codemirror_mode": { 234 | "name": "ipython", 235 | "version": 3 236 | }, 237 | "file_extension": ".py", 238 | "mimetype": "text/x-python", 239 | "name": "python", 240 | "nbconvert_exporter": "python", 241 | "pygments_lexer": "ipython3", 242 | "version": "3.9.7" 243 | }, 244 | "toc": { 245 | "base_numbering": 1, 246 | "nav_menu": {}, 247 | "number_sections": true, 248 | "sideBar": true, 249 | "skip_h1_title": false, 250 | "title_cell": "Table of Contents", 251 | "title_sidebar": "Contents", 252 | "toc_cell": false, 253 | "toc_position": {}, 254 | "toc_section_display": true, 255 | "toc_window_display": false 256 | } 257 | }, 258 | "nbformat": 4, 259 | "nbformat_minor": 2 260 | } 261 | -------------------------------------------------------------------------------- /charpter24_CRF/crffiles.py: -------------------------------------------------------------------------------- 1 | # 导入相关库 2 | import nltk 3 | import sklearn 4 | import scipy.stats 5 | from sklearn.metrics import make_scorer 6 | from sklearn.model_selection import cross_val_score 7 | from sklearn.model_selection import RandomizedSearchCV 8 | 9 | # sklearn-crfsuite的文档中的信息来看,更新到2017年,支持到python3.6,用不了 10 | # 尝试用沙箱解决 11 | import sklearn_crfsuite 12 | from sklearn_crfsuite import scorers 13 | from sklearn_crfsuite import metrics 14 | 15 | 16 | # 基于NLTK下载示例数据集 17 | nltk.download('conll2002') 18 | 19 | # 设置训练和测试样本 20 | train_sents = list(nltk.corpus.conll2002.iob_sents('esp.train')) 21 | test_sents = list(nltk.corpus.conll2002.iob_sents('esp.testb')) 22 | 23 | train_sents[0] 24 | 25 | 26 | # 单词转化为数值特征 27 | def word2features(sent, i): 28 | word = sent[i][0] 29 | postag = sent[i][1] 30 | 31 | features = { 32 | 'bias': 1.0, 33 | 'word.lower()': word.lower(), 34 | 'word[-3:]': word[-3:], 35 | 'word[-2:]': word[-2:], 36 | 'word.isupper()': word.isupper(), 37 | 'word.istitle()': word.istitle(), 38 | 'word.isdigit()': word.isdigit(), 39 | 'postag': postag, 40 | 'postag[:2]': postag[:2], 41 | } 
42 | if i > 0: 43 | word1 = sent[i-1][0] 44 | postag1 = sent[i-1][1] 45 | features.update({ 46 | '-1:word.lower()': word1.lower(), 47 | '-1:word.istitle()': word1.istitle(), 48 | '-1:word.isupper()': word1.isupper(), 49 | '-1:postag': postag1, 50 | '-1:postag[:2]': postag1[:2], 51 | }) 52 | else: 53 | features['BOS'] = True 54 | 55 | if i < len(sent)-1: 56 | word1 = sent[i+1][0] 57 | postag1 = sent[i+1][1] 58 | features.update({ 59 | '+1:word.lower()': word1.lower(), 60 | '+1:word.istitle()': word1.istitle(), 61 | '+1:word.isupper()': word1.isupper(), 62 | '+1:postag': postag1, 63 | '+1:postag[:2]': postag1[:2], 64 | }) 65 | else: 66 | features['EOS'] = True 67 | 68 | return features 69 | 70 | 71 | def sent2features(sent): 72 | return [word2features(sent, i) for i in range(len(sent))] 73 | 74 | def sent2labels(sent): 75 | return [label for token, postag, label in sent] 76 | 77 | def sent2tokens(sent): 78 | return [token for token, postag, label in sent] 79 | 80 | 81 | sent2features(train_sents[0])[0] 82 | 83 | 84 | 85 | # 构造训练集和测试集 86 | X_train = [sent2features(s) for s in train_sents] 87 | y_train = [sent2labels(s) for s in train_sents] 88 | 89 | X_test = [sent2features(s) for s in test_sents] 90 | y_test = [sent2labels(s) for s in test_sents] 91 | 92 | 93 | # 训练集和测试集数量 94 | print(len(X_train), len(X_test)) 95 | 96 | 97 | # 创建CRF模型实例 98 | crf = sklearn_crfsuite.CRF( 99 | algorithm='lbfgs', 100 | c1=0.1, 101 | c2=0.1, 102 | max_iterations=100, 103 | all_possible_transitions=True) 104 | 105 | # 模型训练 106 | crf.fit(X_train, y_train) 107 | # 类别标签 108 | labels = list(crf.classes_) 109 | labels.remove('O') 110 | # 模型预测 111 | y_pred = crf.predict(X_test) 112 | # 计算F1得分 113 | metrics.flat_f1_score(y_test, y_pred, average='weighted', labels=labels) 114 | 115 | 116 | 117 | 118 | # 打印B和I组的模型结果 119 | sorted_labels = sorted( 120 | labels, 121 | key=lambda name: (name[1:], name[0]) 122 | ) 123 | print(metrics.flat_classification_report( 124 | y_test, y_pred, 
labels=sorted_labels, digits=3 125 | )) --------------------------------------------------------------------------------